fix:评估对齐
This commit is contained in:
@@ -506,7 +506,7 @@ def train_lstm_mlp(
|
||||
val_seq_num, val_seq_cat, val_static_num, val_static_cat, y_val = _build_sequence_arrays(
|
||||
validation_df, feature_layout, category_maps, target_transform
|
||||
)
|
||||
test_seq_num, test_seq_cat, test_static_num, test_static_cat, _ = _build_sequence_arrays(
|
||||
test_seq_num, test_seq_cat, test_static_num, test_static_cat, y_test_aligned = _build_sequence_arrays(
|
||||
test_df, feature_layout, category_maps, target_transform
|
||||
)
|
||||
|
||||
@@ -620,9 +620,10 @@ def train_lstm_mlp(
|
||||
|
||||
if target_transform == 'log1p':
|
||||
y_pred = np.expm1(predictions)
|
||||
y_true = np.expm1(y_test_aligned)
|
||||
else:
|
||||
y_pred = predictions
|
||||
y_true = test_df[config.TARGET_COLUMN].astype(float).to_numpy(dtype=np.float32)
|
||||
y_true = y_test_aligned
|
||||
y_pred = np.clip(y_pred, a_min=0, a_max=None)
|
||||
mse = mean_squared_error(y_true, y_pred)
|
||||
|
||||
|
||||
@@ -56,11 +56,15 @@ def patch_lightgbm_sklearn_compatibility():
|
||||
return
|
||||
|
||||
params = inspect.signature(check_X_y).parameters
|
||||
if 'force_all_finite' in params or 'ensure_all_finite' not in params:
|
||||
if 'force_all_finite' in params:
|
||||
return
|
||||
|
||||
def wrapped_check_X_y(*args, force_all_finite=None, **kwargs):
|
||||
if force_all_finite is not None and 'ensure_all_finite' not in kwargs:
|
||||
if (
|
||||
force_all_finite is not None
|
||||
and 'ensure_all_finite' in params
|
||||
and 'ensure_all_finite' not in kwargs
|
||||
):
|
||||
kwargs['ensure_all_finite'] = force_all_finite
|
||||
return check_X_y(*args, **kwargs)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user