fix:评估对齐

This commit is contained in:
2026-03-20 16:52:24 +08:00
parent 77e38fd15b
commit cc85e3807a
2 changed files with 9 additions and 4 deletions

View File

@@ -506,7 +506,7 @@ def train_lstm_mlp(
val_seq_num, val_seq_cat, val_static_num, val_static_cat, y_val = _build_sequence_arrays(
validation_df, feature_layout, category_maps, target_transform
)
test_seq_num, test_seq_cat, test_static_num, test_static_cat, _ = _build_sequence_arrays(
test_seq_num, test_seq_cat, test_static_num, test_static_cat, y_test_aligned = _build_sequence_arrays(
test_df, feature_layout, category_maps, target_transform
)
@@ -620,9 +620,10 @@ def train_lstm_mlp(
if target_transform == 'log1p':
y_pred = np.expm1(predictions)
y_true = np.expm1(y_test_aligned)
else:
y_pred = predictions
y_true = test_df[config.TARGET_COLUMN].astype(float).to_numpy(dtype=np.float32)
y_true = y_test_aligned
y_pred = np.clip(y_pred, a_min=0, a_max=None)
mse = mean_squared_error(y_true, y_pred)