feat(training): strengthen lstm-mlp with embeddings and early stopping
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import copy
|
||||
import os
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
@@ -11,15 +12,15 @@ from core.model_features import engineer_features
|
||||
try:
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils.data import DataLoader, TensorDataset
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
except ImportError:
|
||||
torch = None
|
||||
nn = None
|
||||
DataLoader = None
|
||||
TensorDataset = None
|
||||
Dataset = object
|
||||
|
||||
|
||||
WINDOW_SIZE = 5
|
||||
WINDOW_SIZE = 8
|
||||
SEQUENCE_FEATURES = [
|
||||
'缺勤月份',
|
||||
'星期几',
|
||||
@@ -52,34 +53,125 @@ STATIC_FEATURES = [
|
||||
'家庭负担指数',
|
||||
'岗位稳定性指数',
|
||||
]
|
||||
DEFAULT_EPOCHS = 80
|
||||
DEFAULT_BATCH_SIZE = 256
|
||||
EARLY_STOPPING_PATIENCE = 12
|
||||
|
||||
|
||||
class SequenceStaticDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
seq_num: np.ndarray,
|
||||
seq_cat: np.ndarray,
|
||||
static_num: np.ndarray,
|
||||
static_cat: np.ndarray,
|
||||
targets: np.ndarray,
|
||||
):
|
||||
self.seq_num = torch.tensor(seq_num, dtype=torch.float32)
|
||||
self.seq_cat = torch.tensor(seq_cat, dtype=torch.long)
|
||||
self.static_num = torch.tensor(static_num, dtype=torch.float32)
|
||||
self.static_cat = torch.tensor(static_cat, dtype=torch.long)
|
||||
self.targets = torch.tensor(targets, dtype=torch.float32)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.targets)
|
||||
|
||||
def __getitem__(self, index: int):
|
||||
return (
|
||||
self.seq_num[index],
|
||||
self.seq_cat[index],
|
||||
self.static_num[index],
|
||||
self.static_cat[index],
|
||||
self.targets[index],
|
||||
)
|
||||
|
||||
|
||||
class LSTMMLPRegressor(nn.Module):
|
||||
def __init__(self, seq_input_dim: int, static_input_dim: int):
|
||||
def __init__(
|
||||
self,
|
||||
seq_num_dim: int,
|
||||
static_num_dim: int,
|
||||
seq_cat_cardinalities: List[int],
|
||||
static_cat_cardinalities: List[int],
|
||||
):
|
||||
super().__init__()
|
||||
self.seq_cat_embeddings = nn.ModuleList(
|
||||
[nn.Embedding(cardinality, _embedding_dim(cardinality)) for cardinality in seq_cat_cardinalities]
|
||||
)
|
||||
self.static_cat_embeddings = nn.ModuleList(
|
||||
[nn.Embedding(cardinality, _embedding_dim(cardinality)) for cardinality in static_cat_cardinalities]
|
||||
)
|
||||
|
||||
seq_cat_dim = sum(embedding.embedding_dim for embedding in self.seq_cat_embeddings)
|
||||
static_cat_dim = sum(embedding.embedding_dim for embedding in self.static_cat_embeddings)
|
||||
seq_input_dim = seq_num_dim + seq_cat_dim
|
||||
static_input_dim = static_num_dim + static_cat_dim
|
||||
|
||||
self.seq_projection = nn.Sequential(
|
||||
nn.Linear(seq_input_dim, 128),
|
||||
nn.LayerNorm(128),
|
||||
nn.GELU(),
|
||||
nn.Dropout(0.15),
|
||||
)
|
||||
self.lstm = nn.LSTM(
|
||||
input_size=seq_input_dim,
|
||||
hidden_size=48,
|
||||
num_layers=1,
|
||||
input_size=128,
|
||||
hidden_size=96,
|
||||
num_layers=2,
|
||||
batch_first=True,
|
||||
dropout=0.0,
|
||||
dropout=0.2,
|
||||
bidirectional=True,
|
||||
)
|
||||
self.sequence_head = nn.Sequential(
|
||||
nn.Linear(96 * 2 * 2, 128),
|
||||
nn.GELU(),
|
||||
nn.Dropout(0.2),
|
||||
)
|
||||
self.static_net = nn.Sequential(
|
||||
nn.Linear(static_input_dim, 32),
|
||||
nn.ReLU(),
|
||||
nn.Linear(static_input_dim, 96),
|
||||
nn.LayerNorm(96),
|
||||
nn.GELU(),
|
||||
nn.Dropout(0.15),
|
||||
nn.Linear(96, 64),
|
||||
nn.GELU(),
|
||||
nn.Dropout(0.1),
|
||||
)
|
||||
self.fusion = nn.Sequential(
|
||||
nn.Linear(48 + 32, 48),
|
||||
nn.ReLU(),
|
||||
nn.Linear(128 + 64, 128),
|
||||
nn.LayerNorm(128),
|
||||
nn.GELU(),
|
||||
nn.Dropout(0.2),
|
||||
nn.Linear(128, 64),
|
||||
nn.GELU(),
|
||||
nn.Dropout(0.1),
|
||||
nn.Linear(48, 1),
|
||||
nn.Linear(64, 1),
|
||||
)
|
||||
|
||||
def forward(self, sequence_x, static_x):
|
||||
lstm_output, _ = self.lstm(sequence_x)
|
||||
sequence_repr = lstm_output[:, -1, :]
|
||||
static_repr = self.static_net(static_x)
|
||||
def _embed_categorical(self, inputs: torch.Tensor, embeddings: nn.ModuleList) -> Optional[torch.Tensor]:
|
||||
if not embeddings:
|
||||
return None
|
||||
parts = [embedding(inputs[..., index]) for index, embedding in enumerate(embeddings)]
|
||||
return torch.cat(parts, dim=-1)
|
||||
|
||||
def forward(self, seq_num_x, seq_cat_x, static_num_x, static_cat_x):
|
||||
seq_parts = [seq_num_x]
|
||||
seq_embedded = self._embed_categorical(seq_cat_x, self.seq_cat_embeddings)
|
||||
if seq_embedded is not None:
|
||||
seq_parts.append(seq_embedded)
|
||||
seq_input = torch.cat(seq_parts, dim=-1)
|
||||
seq_input = self.seq_projection(seq_input)
|
||||
|
||||
lstm_output, _ = self.lstm(seq_input)
|
||||
sequence_last = lstm_output[:, -1, :]
|
||||
sequence_mean = lstm_output.mean(dim=1)
|
||||
sequence_repr = self.sequence_head(torch.cat([sequence_last, sequence_mean], dim=1))
|
||||
|
||||
static_parts = [static_num_x]
|
||||
static_embedded = self._embed_categorical(static_cat_x, self.static_cat_embeddings)
|
||||
if static_embedded is not None:
|
||||
static_parts.append(static_embedded)
|
||||
static_input = torch.cat(static_parts, dim=-1)
|
||||
static_repr = self.static_net(static_input)
|
||||
|
||||
fused = torch.cat([sequence_repr, static_repr], dim=1)
|
||||
return self.fusion(fused).squeeze(1)
|
||||
|
||||
@@ -88,175 +180,406 @@ def is_available() -> bool:
|
||||
return torch is not None
|
||||
|
||||
|
||||
def _embedding_dim(cardinality: int) -> int:
|
||||
return int(min(24, max(4, round(cardinality ** 0.35 * 2))))
|
||||
|
||||
|
||||
def _split_feature_types(df: pd.DataFrame, features: List[str]) -> Tuple[List[str], List[str]]:
|
||||
categorical = []
|
||||
numerical = []
|
||||
for feature in features:
|
||||
if feature not in df.columns:
|
||||
continue
|
||||
if pd.api.types.is_numeric_dtype(df[feature]):
|
||||
numerical.append(feature)
|
||||
else:
|
||||
categorical.append(feature)
|
||||
return categorical, numerical
|
||||
|
||||
|
||||
def _fit_category_maps(df: pd.DataFrame, features: List[str]) -> Dict[str, Dict[str, int]]:
|
||||
category_maps = {}
|
||||
for feature in features:
|
||||
if feature not in df.columns:
|
||||
continue
|
||||
if pd.api.types.is_numeric_dtype(df[feature]):
|
||||
continue
|
||||
values = sorted(df[feature].astype(str).unique().tolist())
|
||||
category_maps[feature] = {value: idx for idx, value in enumerate(values)}
|
||||
values = sorted(df[feature].astype(str).fillna('__MISSING__').unique().tolist())
|
||||
category_maps[feature] = {value: idx + 1 for idx, value in enumerate(values)}
|
||||
return category_maps
|
||||
|
||||
|
||||
def _apply_category_maps(df: pd.DataFrame, features: List[str], category_maps: Dict[str, Dict[str, int]]) -> pd.DataFrame:
|
||||
encoded = df.copy()
|
||||
for feature in features:
|
||||
if feature not in encoded.columns:
|
||||
encoded[feature] = 0
|
||||
continue
|
||||
if feature in category_maps:
|
||||
mapper = category_maps[feature]
|
||||
encoded[feature] = encoded[feature].astype(str).map(lambda value: mapper.get(value, 0))
|
||||
return encoded
|
||||
def _encode_categorical_series(values: pd.Series, mapping: Dict[str, int]) -> np.ndarray:
|
||||
return values.astype(str).fillna('__MISSING__').map(lambda value: mapping.get(value, 0)).to_numpy(dtype=np.int64)
|
||||
|
||||
|
||||
def _safe_standardize(values: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
||||
if values.shape[1] == 0:
|
||||
return np.zeros((0,), dtype=np.float32), np.ones((0,), dtype=np.float32)
|
||||
mean = values.mean(axis=0)
|
||||
std = values.std(axis=0)
|
||||
std = np.where(std < 1e-6, 1.0, std)
|
||||
return mean.astype(np.float32), std.astype(np.float32)
|
||||
|
||||
|
||||
def _build_feature_layout(train_df: pd.DataFrame) -> Dict[str, List[str]]:
|
||||
used_features = sorted(set(SEQUENCE_FEATURES + STATIC_FEATURES))
|
||||
seq_cat_features, seq_num_features = _split_feature_types(train_df, SEQUENCE_FEATURES)
|
||||
static_cat_features, static_num_features = _split_feature_types(train_df, STATIC_FEATURES)
|
||||
all_cat_features = sorted(set(seq_cat_features + static_cat_features))
|
||||
return {
|
||||
'used_features': used_features,
|
||||
'seq_cat_features': seq_cat_features,
|
||||
'seq_num_features': seq_num_features,
|
||||
'static_cat_features': static_cat_features,
|
||||
'static_num_features': static_num_features,
|
||||
'all_cat_features': all_cat_features,
|
||||
}
|
||||
|
||||
|
||||
def _build_sequence_arrays(
|
||||
df: pd.DataFrame,
|
||||
feature_layout: Dict[str, List[str]],
|
||||
category_maps: Dict[str, Dict[str, int]],
|
||||
target_transform: str,
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
df = engineer_features(df.copy())
|
||||
features = sorted(set(SEQUENCE_FEATURES + STATIC_FEATURES))
|
||||
df = _apply_category_maps(df, features, category_maps)
|
||||
|
||||
for feature in feature_layout['used_features']:
|
||||
if feature not in df.columns:
|
||||
df[feature] = 0
|
||||
|
||||
df = df.sort_values(
|
||||
[config.EMPLOYEE_ID_COLUMN, config.EVENT_DATE_INDEX_COLUMN, config.EVENT_SEQUENCE_COLUMN]
|
||||
).reset_index(drop=True)
|
||||
|
||||
sequence_samples = []
|
||||
static_samples = []
|
||||
sequence_num_samples = []
|
||||
sequence_cat_samples = []
|
||||
static_num_samples = []
|
||||
static_cat_samples = []
|
||||
targets = []
|
||||
|
||||
for _, group in df.groupby(config.EMPLOYEE_ID_COLUMN, sort=False):
|
||||
seq_values = group[SEQUENCE_FEATURES].astype(float).values
|
||||
static_values = group[STATIC_FEATURES].astype(float).values
|
||||
target_values = group[config.TARGET_COLUMN].astype(float).values
|
||||
seq_num_values = group[feature_layout['seq_num_features']].astype(float).to_numpy(dtype=np.float32)
|
||||
static_num_values = group[feature_layout['static_num_features']].astype(float).to_numpy(dtype=np.float32)
|
||||
target_values = group[config.TARGET_COLUMN].astype(float).to_numpy(dtype=np.float32)
|
||||
|
||||
if feature_layout['seq_cat_features']:
|
||||
seq_cat_values = np.column_stack(
|
||||
[
|
||||
_encode_categorical_series(group[feature], category_maps[feature])
|
||||
for feature in feature_layout['seq_cat_features']
|
||||
]
|
||||
).astype(np.int64)
|
||||
else:
|
||||
seq_cat_values = np.zeros((len(group), 0), dtype=np.int64)
|
||||
|
||||
if feature_layout['static_cat_features']:
|
||||
static_cat_values = np.column_stack(
|
||||
[
|
||||
_encode_categorical_series(group[feature], category_maps[feature])
|
||||
for feature in feature_layout['static_cat_features']
|
||||
]
|
||||
).astype(np.int64)
|
||||
else:
|
||||
static_cat_values = np.zeros((len(group), 0), dtype=np.int64)
|
||||
|
||||
for index in range(len(group)):
|
||||
window_slice = seq_values[max(0, index - WINDOW_SIZE + 1): index + 1]
|
||||
sequence_window = np.zeros((WINDOW_SIZE, len(SEQUENCE_FEATURES)), dtype=np.float32)
|
||||
sequence_window[-len(window_slice):] = window_slice
|
||||
sequence_samples.append(sequence_window)
|
||||
static_samples.append(static_values[index].astype(np.float32))
|
||||
start_index = max(0, index - WINDOW_SIZE + 1)
|
||||
num_slice = seq_num_values[start_index: index + 1]
|
||||
cat_slice = seq_cat_values[start_index: index + 1]
|
||||
|
||||
num_window = np.zeros((WINDOW_SIZE, len(feature_layout['seq_num_features'])), dtype=np.float32)
|
||||
cat_window = np.zeros((WINDOW_SIZE, len(feature_layout['seq_cat_features'])), dtype=np.int64)
|
||||
num_window[-len(num_slice):] = num_slice
|
||||
if len(feature_layout['seq_cat_features']) > 0:
|
||||
cat_window[-len(cat_slice):] = cat_slice
|
||||
|
||||
sequence_num_samples.append(num_window)
|
||||
sequence_cat_samples.append(cat_window)
|
||||
static_num_samples.append(static_num_values[index].astype(np.float32))
|
||||
static_cat_samples.append(static_cat_values[index].astype(np.int64))
|
||||
targets.append(float(target_values[index]))
|
||||
|
||||
targets = np.array(targets, dtype=np.float32)
|
||||
targets_array = np.array(targets, dtype=np.float32)
|
||||
if target_transform == 'log1p':
|
||||
targets = np.log1p(np.clip(targets, a_min=0, a_max=None)).astype(np.float32)
|
||||
targets_array = np.log1p(np.clip(targets_array, a_min=0, a_max=None)).astype(np.float32)
|
||||
|
||||
return (
|
||||
np.array(sequence_samples, dtype=np.float32),
|
||||
np.array(static_samples, dtype=np.float32),
|
||||
targets,
|
||||
np.array(sequence_num_samples, dtype=np.float32),
|
||||
np.array(sequence_cat_samples, dtype=np.int64),
|
||||
np.array(static_num_samples, dtype=np.float32),
|
||||
np.array(static_cat_samples, dtype=np.int64),
|
||||
targets_array,
|
||||
)
|
||||
|
||||
|
||||
def _train_validation_split(train_df: pd.DataFrame, validation_ratio: float = 0.15) -> Tuple[pd.DataFrame, pd.DataFrame]:
|
||||
employee_ids = train_df[config.EMPLOYEE_ID_COLUMN].dropna().astype(str).unique().tolist()
|
||||
rng = np.random.default_rng(config.RANDOM_STATE)
|
||||
rng.shuffle(employee_ids)
|
||||
validation_count = max(1, int(len(employee_ids) * validation_ratio))
|
||||
validation_ids = set(employee_ids[:validation_count])
|
||||
|
||||
validation_df = train_df[train_df[config.EMPLOYEE_ID_COLUMN].astype(str).isin(validation_ids)].copy()
|
||||
fit_df = train_df[~train_df[config.EMPLOYEE_ID_COLUMN].astype(str).isin(validation_ids)].copy()
|
||||
if fit_df.empty or validation_df.empty:
|
||||
split_index = max(1, int(len(train_df) * (1 - validation_ratio)))
|
||||
fit_df = train_df.iloc[:split_index].copy()
|
||||
validation_df = train_df.iloc[split_index:].copy()
|
||||
return fit_df, validation_df
|
||||
|
||||
|
||||
def _prepare_inference_window(
|
||||
df: pd.DataFrame,
|
||||
feature_layout: Dict[str, List[str]],
|
||||
category_maps: Dict[str, Dict[str, int]],
|
||||
default_sequence_num_prefix: np.ndarray,
|
||||
default_sequence_cat_prefix: np.ndarray,
|
||||
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
||||
df = engineer_features(df.copy())
|
||||
for feature in feature_layout['used_features']:
|
||||
if feature not in df.columns:
|
||||
df[feature] = 0
|
||||
|
||||
row = df.iloc[0]
|
||||
|
||||
seq_num_row = row[feature_layout['seq_num_features']].astype(float).to_numpy(dtype=np.float32)
|
||||
static_num_row = row[feature_layout['static_num_features']].astype(float).to_numpy(dtype=np.float32)
|
||||
|
||||
if feature_layout['seq_cat_features']:
|
||||
seq_cat_row = np.array(
|
||||
[category_maps[feature].get(str(row[feature]), 0) for feature in feature_layout['seq_cat_features']],
|
||||
dtype=np.int64,
|
||||
)
|
||||
else:
|
||||
seq_cat_row = np.zeros((0,), dtype=np.int64)
|
||||
|
||||
if feature_layout['static_cat_features']:
|
||||
static_cat_row = np.array(
|
||||
[category_maps[feature].get(str(row[feature]), 0) for feature in feature_layout['static_cat_features']],
|
||||
dtype=np.int64,
|
||||
)
|
||||
else:
|
||||
static_cat_row = np.zeros((0,), dtype=np.int64)
|
||||
|
||||
sequence_num_window = np.vstack([default_sequence_num_prefix, seq_num_row.reshape(1, -1)]).astype(np.float32)
|
||||
if len(feature_layout['seq_cat_features']) > 0:
|
||||
sequence_cat_window = np.vstack([default_sequence_cat_prefix, seq_cat_row.reshape(1, -1)]).astype(np.int64)
|
||||
else:
|
||||
sequence_cat_window = np.zeros((WINDOW_SIZE, 0), dtype=np.int64)
|
||||
|
||||
return sequence_num_window, sequence_cat_window, static_num_row, static_cat_row
|
||||
|
||||
|
||||
def _evaluate_model(
|
||||
model: nn.Module,
|
||||
loader: DataLoader,
|
||||
device: torch.device,
|
||||
target_transform: str,
|
||||
) -> Tuple[float, Dict[str, float]]:
|
||||
model.eval()
|
||||
predictions = []
|
||||
targets = []
|
||||
with torch.no_grad():
|
||||
for batch_seq_num, batch_seq_cat, batch_static_num, batch_static_cat, batch_target in loader:
|
||||
batch_seq_num = batch_seq_num.to(device)
|
||||
batch_seq_cat = batch_seq_cat.to(device)
|
||||
batch_static_num = batch_static_num.to(device)
|
||||
batch_static_cat = batch_static_cat.to(device)
|
||||
batch_predictions = model(batch_seq_num, batch_seq_cat, batch_static_num, batch_static_cat)
|
||||
predictions.append(batch_predictions.cpu().numpy())
|
||||
targets.append(batch_target.numpy())
|
||||
|
||||
y_pred = np.concatenate(predictions) if predictions else np.array([], dtype=np.float32)
|
||||
y_true = np.concatenate(targets) if targets else np.array([], dtype=np.float32)
|
||||
|
||||
if target_transform == 'log1p':
|
||||
y_pred_eval = np.expm1(y_pred)
|
||||
y_true_eval = np.expm1(y_true)
|
||||
else:
|
||||
y_pred_eval = y_pred
|
||||
y_true_eval = y_true
|
||||
y_pred_eval = np.clip(y_pred_eval, a_min=0, a_max=None)
|
||||
mse = mean_squared_error(y_true_eval, y_pred_eval)
|
||||
metrics = {
|
||||
'r2': float(r2_score(y_true_eval, y_pred_eval)),
|
||||
'mse': float(mse),
|
||||
'rmse': float(np.sqrt(mse)),
|
||||
'mae': float(mean_absolute_error(y_true_eval, y_pred_eval)),
|
||||
}
|
||||
return metrics['rmse'], metrics
|
||||
|
||||
|
||||
def train_lstm_mlp(
|
||||
train_df: pd.DataFrame,
|
||||
test_df: pd.DataFrame,
|
||||
model_path: str,
|
||||
target_transform: str = 'log1p',
|
||||
epochs: int = 24,
|
||||
batch_size: int = 128,
|
||||
epochs: int = DEFAULT_EPOCHS,
|
||||
batch_size: int = DEFAULT_BATCH_SIZE,
|
||||
) -> Optional[Dict]:
|
||||
if torch is None:
|
||||
return None
|
||||
|
||||
used_features = sorted(set(SEQUENCE_FEATURES + STATIC_FEATURES))
|
||||
category_maps = _fit_category_maps(train_df, used_features)
|
||||
train_seq, train_static, y_train = _build_sequence_arrays(train_df, category_maps, target_transform)
|
||||
test_seq, test_static, y_test_transformed = _build_sequence_arrays(test_df, category_maps, target_transform)
|
||||
fit_df, validation_df = _train_validation_split(train_df)
|
||||
feature_layout = _build_feature_layout(fit_df)
|
||||
category_maps = _fit_category_maps(fit_df, feature_layout['all_cat_features'])
|
||||
|
||||
seq_mean, seq_std = _safe_standardize(train_seq.reshape(-1, train_seq.shape[-1]))
|
||||
static_mean, static_std = _safe_standardize(train_static)
|
||||
train_seq_num, train_seq_cat, train_static_num, train_static_cat, y_train = _build_sequence_arrays(
|
||||
fit_df, feature_layout, category_maps, target_transform
|
||||
)
|
||||
val_seq_num, val_seq_cat, val_static_num, val_static_cat, y_val = _build_sequence_arrays(
|
||||
validation_df, feature_layout, category_maps, target_transform
|
||||
)
|
||||
test_seq_num, test_seq_cat, test_static_num, test_static_cat, _ = _build_sequence_arrays(
|
||||
test_df, feature_layout, category_maps, target_transform
|
||||
)
|
||||
|
||||
train_seq = ((train_seq - seq_mean) / seq_std).astype(np.float32)
|
||||
test_seq = ((test_seq - seq_mean) / seq_std).astype(np.float32)
|
||||
train_static = ((train_static - static_mean) / static_std).astype(np.float32)
|
||||
test_static = ((test_static - static_mean) / static_std).astype(np.float32)
|
||||
seq_mean, seq_std = _safe_standardize(train_seq_num.reshape(-1, train_seq_num.shape[-1]))
|
||||
static_mean, static_std = _safe_standardize(train_static_num)
|
||||
|
||||
train_seq_num = ((train_seq_num - seq_mean) / seq_std).astype(np.float32)
|
||||
val_seq_num = ((val_seq_num - seq_mean) / seq_std).astype(np.float32)
|
||||
test_seq_num = ((test_seq_num - seq_mean) / seq_std).astype(np.float32)
|
||||
|
||||
train_static_num = ((train_static_num - static_mean) / static_std).astype(np.float32)
|
||||
val_static_num = ((val_static_num - static_mean) / static_std).astype(np.float32)
|
||||
test_static_num = ((test_static_num - static_mean) / static_std).astype(np.float32)
|
||||
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
if device.type == 'cuda':
|
||||
device_name = torch.cuda.get_device_name(device)
|
||||
print(f'[lstm_mlp] Training device: CUDA ({device_name})')
|
||||
print(f'[lstm_mlp] Training device: CUDA ({torch.cuda.get_device_name(device)})')
|
||||
else:
|
||||
print('[lstm_mlp] Training device: CPU')
|
||||
model = LSTMMLPRegressor(seq_input_dim=train_seq.shape[-1], static_input_dim=train_static.shape[-1]).to(device)
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
|
||||
criterion = nn.MSELoss()
|
||||
|
||||
train_dataset = TensorDataset(
|
||||
torch.tensor(train_seq),
|
||||
torch.tensor(train_static),
|
||||
torch.tensor(y_train),
|
||||
model = LSTMMLPRegressor(
|
||||
seq_num_dim=train_seq_num.shape[-1],
|
||||
static_num_dim=train_static_num.shape[-1],
|
||||
seq_cat_cardinalities=[len(category_maps[feature]) + 1 for feature in feature_layout['seq_cat_features']],
|
||||
static_cat_cardinalities=[len(category_maps[feature]) + 1 for feature in feature_layout['static_cat_features']],
|
||||
).to(device)
|
||||
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0012, weight_decay=1e-4)
|
||||
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
||||
optimizer, mode='min', factor=0.6, patience=4, min_lr=1e-5
|
||||
)
|
||||
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
|
||||
criterion = nn.SmoothL1Loss(beta=0.35)
|
||||
|
||||
train_loader = DataLoader(
|
||||
SequenceStaticDataset(train_seq_num, train_seq_cat, train_static_num, train_static_cat, y_train),
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
)
|
||||
val_loader = DataLoader(
|
||||
SequenceStaticDataset(val_seq_num, val_seq_cat, val_static_num, val_static_cat, y_val),
|
||||
batch_size=batch_size,
|
||||
shuffle=False,
|
||||
)
|
||||
|
||||
best_state = None
|
||||
best_metrics = None
|
||||
best_val_rmse = float('inf')
|
||||
stale_epochs = 0
|
||||
|
||||
for epoch in range(epochs):
|
||||
model.train()
|
||||
for _ in range(epochs):
|
||||
for batch_seq, batch_static, batch_target in train_loader:
|
||||
batch_seq = batch_seq.to(device)
|
||||
batch_static = batch_static.to(device)
|
||||
running_loss = 0.0
|
||||
for batch_seq_num, batch_seq_cat, batch_static_num, batch_static_cat, batch_target in train_loader:
|
||||
batch_seq_num = batch_seq_num.to(device)
|
||||
batch_seq_cat = batch_seq_cat.to(device)
|
||||
batch_static_num = batch_static_num.to(device)
|
||||
batch_static_cat = batch_static_cat.to(device)
|
||||
batch_target = batch_target.to(device)
|
||||
|
||||
optimizer.zero_grad()
|
||||
predictions = model(batch_seq, batch_static)
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
predictions = model(batch_seq_num, batch_seq_cat, batch_static_num, batch_static_cat)
|
||||
loss = criterion(predictions, batch_target)
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
|
||||
optimizer.step()
|
||||
running_loss += float(loss.item()) * len(batch_target)
|
||||
|
||||
train_loss = running_loss / max(1, len(train_loader.dataset))
|
||||
val_rmse, val_metrics = _evaluate_model(model, val_loader, device, target_transform)
|
||||
scheduler.step(val_rmse)
|
||||
|
||||
improved = val_rmse + 1e-4 < best_val_rmse
|
||||
if improved:
|
||||
best_val_rmse = val_rmse
|
||||
best_metrics = val_metrics
|
||||
best_state = copy.deepcopy(model.state_dict())
|
||||
stale_epochs = 0
|
||||
else:
|
||||
stale_epochs += 1
|
||||
|
||||
if epoch == 0 or (epoch + 1) % 5 == 0 or improved:
|
||||
print(
|
||||
f'[lstm_mlp] epoch={epoch + 1:02d} train_loss={train_loss:.4f} '
|
||||
f'val_r2={val_metrics["r2"]:.4f} val_rmse={val_metrics["rmse"]:.4f}'
|
||||
)
|
||||
|
||||
if stale_epochs >= EARLY_STOPPING_PATIENCE:
|
||||
print(f'[lstm_mlp] Early stopping at epoch {epoch + 1}')
|
||||
break
|
||||
|
||||
if best_state is None:
|
||||
best_state = copy.deepcopy(model.state_dict())
|
||||
model.load_state_dict(best_state)
|
||||
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
predictions = model(
|
||||
torch.tensor(test_seq).to(device),
|
||||
torch.tensor(test_static).to(device),
|
||||
torch.tensor(test_seq_num, dtype=torch.float32).to(device),
|
||||
torch.tensor(test_seq_cat, dtype=torch.long).to(device),
|
||||
torch.tensor(test_static_num, dtype=torch.float32).to(device),
|
||||
torch.tensor(test_static_cat, dtype=torch.long).to(device),
|
||||
).cpu().numpy()
|
||||
|
||||
if target_transform == 'log1p':
|
||||
y_pred = np.expm1(predictions)
|
||||
else:
|
||||
y_pred = predictions
|
||||
y_true = test_df[config.TARGET_COLUMN].astype(float).values
|
||||
y_true = test_df[config.TARGET_COLUMN].astype(float).to_numpy(dtype=np.float32)
|
||||
y_pred = np.clip(y_pred, a_min=0, a_max=None)
|
||||
mse = mean_squared_error(y_true, y_pred)
|
||||
|
||||
default_prefix = train_seq[:, :-1, :].mean(axis=0).astype(np.float32)
|
||||
default_sequence_num_prefix = train_seq_num[:, :-1, :].mean(axis=0).astype(np.float32)
|
||||
if train_seq_cat.shape[-1] > 0:
|
||||
default_sequence_cat_prefix = np.rint(train_seq_cat[:, :-1, :].mean(axis=0)).astype(np.int64)
|
||||
else:
|
||||
default_sequence_cat_prefix = np.zeros((WINDOW_SIZE - 1, 0), dtype=np.int64)
|
||||
|
||||
bundle = {
|
||||
'state_dict': model.state_dict(),
|
||||
'sequence_features': SEQUENCE_FEATURES,
|
||||
'static_features': STATIC_FEATURES,
|
||||
'window_size': WINDOW_SIZE,
|
||||
'target_transform': target_transform,
|
||||
'feature_layout': feature_layout,
|
||||
'category_maps': category_maps,
|
||||
'seq_mean': seq_mean,
|
||||
'seq_std': seq_std,
|
||||
'static_mean': static_mean,
|
||||
'static_std': static_std,
|
||||
'default_sequence_prefix': default_prefix,
|
||||
'window_size': WINDOW_SIZE,
|
||||
'target_transform': target_transform,
|
||||
'sequence_input_dim': train_seq.shape[-1],
|
||||
'static_input_dim': train_static.shape[-1],
|
||||
'default_sequence_num_prefix': default_sequence_num_prefix,
|
||||
'default_sequence_cat_prefix': default_sequence_cat_prefix,
|
||||
'seq_num_dim': train_seq_num.shape[-1],
|
||||
'static_num_dim': train_static_num.shape[-1],
|
||||
'seq_cat_cardinalities': [len(category_maps[feature]) + 1 for feature in feature_layout['seq_cat_features']],
|
||||
'static_cat_cardinalities': [len(category_maps[feature]) + 1 for feature in feature_layout['static_cat_features']],
|
||||
'best_validation_metrics': best_metrics,
|
||||
}
|
||||
torch.save(bundle, model_path)
|
||||
|
||||
return {
|
||||
'metrics': {
|
||||
'r2': round(r2_score(y_true, y_pred), 4),
|
||||
'mse': round(mse, 4),
|
||||
'r2': round(float(r2_score(y_true, y_pred)), 4),
|
||||
'mse': round(float(mse), 4),
|
||||
'rmse': round(float(np.sqrt(mse)), 4),
|
||||
'mae': round(mean_absolute_error(y_true, y_pred), 4),
|
||||
'mae': round(float(mean_absolute_error(y_true, y_pred)), 4),
|
||||
},
|
||||
'metadata': {
|
||||
'sequence_window_size': WINDOW_SIZE,
|
||||
'sequence_feature_names': SEQUENCE_FEATURES,
|
||||
'static_feature_names': STATIC_FEATURES,
|
||||
'deep_validation_r2': round(float(best_metrics['r2']), 4) if best_metrics else None,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -264,10 +587,12 @@ def train_lstm_mlp(
|
||||
def load_lstm_mlp_bundle(model_path: str) -> Optional[Dict]:
|
||||
if torch is None or not os.path.exists(model_path):
|
||||
return None
|
||||
bundle = torch.load(model_path, map_location='cpu')
|
||||
bundle = torch.load(model_path, map_location='cpu', weights_only=False)
|
||||
model = LSTMMLPRegressor(
|
||||
seq_input_dim=bundle['sequence_input_dim'],
|
||||
static_input_dim=bundle['static_input_dim'],
|
||||
seq_num_dim=bundle['seq_num_dim'],
|
||||
static_num_dim=bundle['static_num_dim'],
|
||||
seq_cat_cardinalities=bundle['seq_cat_cardinalities'],
|
||||
static_cat_cardinalities=bundle['static_cat_cardinalities'],
|
||||
)
|
||||
model.load_state_dict(bundle['state_dict'])
|
||||
model.eval()
|
||||
@@ -276,22 +601,23 @@ def load_lstm_mlp_bundle(model_path: str) -> Optional[Dict]:
|
||||
|
||||
|
||||
def predict_lstm_mlp(bundle: Dict, current_df: pd.DataFrame) -> float:
|
||||
df = engineer_features(current_df.copy())
|
||||
used_features = sorted(set(bundle['sequence_features'] + bundle['static_features']))
|
||||
df = _apply_category_maps(df, used_features, bundle['category_maps'])
|
||||
sequence_num_window, sequence_cat_window, static_num_row, static_cat_row = _prepare_inference_window(
|
||||
current_df,
|
||||
bundle['feature_layout'],
|
||||
bundle['category_maps'],
|
||||
bundle['default_sequence_num_prefix'],
|
||||
bundle['default_sequence_cat_prefix'],
|
||||
)
|
||||
|
||||
sequence_row = df[bundle['sequence_features']].astype(float).values[0].astype(np.float32)
|
||||
static_row = df[bundle['static_features']].astype(float).values[0].astype(np.float32)
|
||||
|
||||
prefix = bundle['default_sequence_prefix']
|
||||
sequence_window = np.vstack([prefix, sequence_row.reshape(1, -1)]).astype(np.float32)
|
||||
sequence_window = (sequence_window - bundle['seq_mean']) / bundle['seq_std']
|
||||
static_row = ((static_row - bundle['static_mean']) / bundle['static_std']).astype(np.float32)
|
||||
sequence_num_window = ((sequence_num_window - bundle['seq_mean']) / bundle['seq_std']).astype(np.float32)
|
||||
static_num_row = ((static_num_row - bundle['static_mean']) / bundle['static_std']).astype(np.float32)
|
||||
|
||||
with torch.no_grad():
|
||||
prediction = bundle['model'](
|
||||
torch.tensor(sequence_window).unsqueeze(0),
|
||||
torch.tensor(static_row).unsqueeze(0),
|
||||
torch.tensor(sequence_num_window, dtype=torch.float32).unsqueeze(0),
|
||||
torch.tensor(sequence_cat_window, dtype=torch.long).unsqueeze(0),
|
||||
torch.tensor(static_num_row, dtype=torch.float32).unsqueeze(0),
|
||||
torch.tensor(static_cat_row, dtype=torch.long).unsqueeze(0),
|
||||
).cpu().numpy()[0]
|
||||
|
||||
if bundle.get('target_transform') == 'log1p':
|
||||
|
||||
Reference in New Issue
Block a user