fix(training): patch lightgbm sklearn compatibility

This commit is contained in:
2026-03-12 18:15:09 +08:00
parent d7c8019f96
commit d70bd54c41
16 changed files with 885 additions and 203 deletions

View File

@@ -4,6 +4,7 @@ import joblib
import numpy as np
import config
from core.deep_learning_model import load_lstm_mlp_bundle, predict_lstm_mlp
from core.model_features import (
align_feature_frame,
apply_label_encoders,
@@ -20,6 +21,7 @@ MODEL_INFO = {
'gradient_boosting': {'name': 'gradient_boosting', 'name_cn': 'GBDT', 'description': '梯度提升决策树'},
'extra_trees': {'name': 'extra_trees', 'name_cn': '极端随机树', 'description': '高随机性的树模型'},
'stacking': {'name': 'stacking', 'name_cn': 'Stacking集成', 'description': '多模型融合'},
'lstm_mlp': {'name': 'lstm_mlp', 'name_cn': 'LSTM+MLP', 'description': '时序与静态特征融合的深度学习模型'},
}
@@ -50,6 +52,7 @@ class PredictService:
'gradient_boosting': 'gradient_boosting_model.pkl',
'extra_trees': 'extra_trees_model.pkl',
'stacking': 'stacking_model.pkl',
'lstm_mlp': 'lstm_mlp_model.pt',
}
allowed_models = self.training_metadata.get('available_models')
if allowed_models:
@@ -59,7 +62,12 @@ class PredictService:
path = os.path.join(config.MODELS_DIR, filename)
if os.path.exists(path):
try:
self.models[name] = joblib.load(path)
if name == 'lstm_mlp':
bundle = load_lstm_mlp_bundle(path)
if bundle is not None:
self.models[name] = bundle
else:
self.models[name] = joblib.load(path)
except Exception as exc:
print(f'Failed to load model {name}: {exc}')
@@ -107,8 +115,12 @@ class PredictService:
features = self._prepare_features(data)
try:
predicted_hours = self.models[model_type].predict([features])[0]
predicted_hours = self._inverse_transform_prediction(predicted_hours)
if model_type == 'lstm_mlp':
current_df = build_prediction_dataframe(data)
predicted_hours = predict_lstm_mlp(self.models[model_type], current_df)
else:
predicted_hours = self.models[model_type].predict([features])[0]
predicted_hours = self._inverse_transform_prediction(predicted_hours)
predicted_hours = max(0.5, float(predicted_hours))
except Exception:
return self._get_default_prediction(data)
@@ -196,6 +208,8 @@ class PredictService:
'test_samples': self.training_metadata.get('test_samples', 0),
'feature_count': self.training_metadata.get('feature_count_after_selection', 0),
'training_date': self.training_metadata.get('training_date', ''),
'sequence_window_size': self.training_metadata.get('sequence_window_size', 0),
'deep_learning_available': self.training_metadata.get('deep_learning_available', False),
},
}