fix(training): patch lightgbm sklearn compatibility
This commit is contained in:
@@ -4,6 +4,7 @@ import joblib
|
||||
import numpy as np
|
||||
|
||||
import config
|
||||
from core.deep_learning_model import load_lstm_mlp_bundle, predict_lstm_mlp
|
||||
from core.model_features import (
|
||||
align_feature_frame,
|
||||
apply_label_encoders,
|
||||
@@ -20,6 +21,7 @@ MODEL_INFO = {
|
||||
'gradient_boosting': {'name': 'gradient_boosting', 'name_cn': 'GBDT', 'description': '梯度提升决策树'},
|
||||
'extra_trees': {'name': 'extra_trees', 'name_cn': '极端随机树', 'description': '高随机性的树模型'},
|
||||
'stacking': {'name': 'stacking', 'name_cn': 'Stacking集成', 'description': '多模型融合'},
|
||||
'lstm_mlp': {'name': 'lstm_mlp', 'name_cn': 'LSTM+MLP', 'description': '时序与静态特征融合的深度学习模型'},
|
||||
}
|
||||
|
||||
|
||||
@@ -50,6 +52,7 @@ class PredictService:
|
||||
'gradient_boosting': 'gradient_boosting_model.pkl',
|
||||
'extra_trees': 'extra_trees_model.pkl',
|
||||
'stacking': 'stacking_model.pkl',
|
||||
'lstm_mlp': 'lstm_mlp_model.pt',
|
||||
}
|
||||
allowed_models = self.training_metadata.get('available_models')
|
||||
if allowed_models:
|
||||
@@ -59,7 +62,12 @@ class PredictService:
|
||||
path = os.path.join(config.MODELS_DIR, filename)
|
||||
if os.path.exists(path):
|
||||
try:
|
||||
self.models[name] = joblib.load(path)
|
||||
if name == 'lstm_mlp':
|
||||
bundle = load_lstm_mlp_bundle(path)
|
||||
if bundle is not None:
|
||||
self.models[name] = bundle
|
||||
else:
|
||||
self.models[name] = joblib.load(path)
|
||||
except Exception as exc:
|
||||
print(f'Failed to load model {name}: {exc}')
|
||||
|
||||
@@ -107,8 +115,12 @@ class PredictService:
|
||||
|
||||
features = self._prepare_features(data)
|
||||
try:
|
||||
predicted_hours = self.models[model_type].predict([features])[0]
|
||||
predicted_hours = self._inverse_transform_prediction(predicted_hours)
|
||||
if model_type == 'lstm_mlp':
|
||||
current_df = build_prediction_dataframe(data)
|
||||
predicted_hours = predict_lstm_mlp(self.models[model_type], current_df)
|
||||
else:
|
||||
predicted_hours = self.models[model_type].predict([features])[0]
|
||||
predicted_hours = self._inverse_transform_prediction(predicted_hours)
|
||||
predicted_hours = max(0.5, float(predicted_hours))
|
||||
except Exception:
|
||||
return self._get_default_prediction(data)
|
||||
@@ -196,6 +208,8 @@ class PredictService:
|
||||
'test_samples': self.training_metadata.get('test_samples', 0),
|
||||
'feature_count': self.training_metadata.get('feature_count_after_selection', 0),
|
||||
'training_date': self.training_metadata.get('training_date', ''),
|
||||
'sequence_window_size': self.training_metadata.get('sequence_window_size', 0),
|
||||
'deep_learning_available': self.training_metadata.get('deep_learning_available', False),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user