fix:评估对齐
This commit is contained in:
@@ -506,7 +506,7 @@ def train_lstm_mlp(
|
|||||||
val_seq_num, val_seq_cat, val_static_num, val_static_cat, y_val = _build_sequence_arrays(
|
val_seq_num, val_seq_cat, val_static_num, val_static_cat, y_val = _build_sequence_arrays(
|
||||||
validation_df, feature_layout, category_maps, target_transform
|
validation_df, feature_layout, category_maps, target_transform
|
||||||
)
|
)
|
||||||
test_seq_num, test_seq_cat, test_static_num, test_static_cat, _ = _build_sequence_arrays(
|
test_seq_num, test_seq_cat, test_static_num, test_static_cat, y_test_aligned = _build_sequence_arrays(
|
||||||
test_df, feature_layout, category_maps, target_transform
|
test_df, feature_layout, category_maps, target_transform
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -620,9 +620,10 @@ def train_lstm_mlp(
|
|||||||
|
|
||||||
if target_transform == 'log1p':
|
if target_transform == 'log1p':
|
||||||
y_pred = np.expm1(predictions)
|
y_pred = np.expm1(predictions)
|
||||||
|
y_true = np.expm1(y_test_aligned)
|
||||||
else:
|
else:
|
||||||
y_pred = predictions
|
y_pred = predictions
|
||||||
y_true = test_df[config.TARGET_COLUMN].astype(float).to_numpy(dtype=np.float32)
|
y_true = y_test_aligned
|
||||||
y_pred = np.clip(y_pred, a_min=0, a_max=None)
|
y_pred = np.clip(y_pred, a_min=0, a_max=None)
|
||||||
mse = mean_squared_error(y_true, y_pred)
|
mse = mean_squared_error(y_true, y_pred)
|
||||||
|
|
||||||
|
|||||||
@@ -56,11 +56,15 @@ def patch_lightgbm_sklearn_compatibility():
|
|||||||
return
|
return
|
||||||
|
|
||||||
params = inspect.signature(check_X_y).parameters
|
params = inspect.signature(check_X_y).parameters
|
||||||
if 'force_all_finite' in params or 'ensure_all_finite' not in params:
|
if 'force_all_finite' in params:
|
||||||
return
|
return
|
||||||
|
|
||||||
def wrapped_check_X_y(*args, force_all_finite=None, **kwargs):
|
def wrapped_check_X_y(*args, force_all_finite=None, **kwargs):
|
||||||
if force_all_finite is not None and 'ensure_all_finite' not in kwargs:
|
if (
|
||||||
|
force_all_finite is not None
|
||||||
|
and 'ensure_all_finite' in params
|
||||||
|
and 'ensure_all_finite' not in kwargs
|
||||||
|
):
|
||||||
kwargs['ensure_all_finite'] = force_all_finite
|
kwargs['ensure_all_finite'] = force_all_finite
|
||||||
return check_X_y(*args, **kwargs)
|
return check_X_y(*args, **kwargs)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user