Polish absence analysis demo experience
This commit is contained in:
@@ -2,17 +2,21 @@ from core.clustering import KMeansAnalyzer
|
||||
|
||||
|
||||
class ClusterService:
    """Thin service layer over ``KMeansAnalyzer`` for the clustering endpoints.

    Every public method builds its own analyzer instead of sharing one on the
    instance, so no analyzer state is carried from one call to the next.
    """

    def _create_analyzer(self):
        """Return a fresh ``KMeansAnalyzer`` for a single call."""
        # The clustering endpoints are called concurrently by the front end;
        # reusing one mutable analyzer instance could mix up their results.
        return KMeansAnalyzer()

    def get_cluster_result(self, n_clusters=3):
        """Return ``KMeansAnalyzer.get_cluster_results`` for ``n_clusters``."""
        analyzer = self._create_analyzer()
        return analyzer.get_cluster_results(n_clusters)

    def get_cluster_profile(self, n_clusters=3):
        """Return ``KMeansAnalyzer.get_cluster_profile`` for ``n_clusters``."""
        analyzer = self._create_analyzer()
        return analyzer.get_cluster_profile(n_clusters)

    def get_scatter_data(self, n_clusters=3, x_axis='月均加班时长', y_axis='缺勤时长(小时)'):
        """Return scatter-plot data for the two named axes.

        ``x_axis`` and ``y_axis`` are passed through unchanged to
        ``KMeansAnalyzer.get_scatter_data``.
        """
        analyzer = self._create_analyzer()
        return analyzer.get_scatter_data(n_clusters, x_axis, y_axis)
|
||||
|
||||
|
||||
cluster_service = ClusterService()
|
||||
|
||||
@@ -16,18 +16,26 @@ from core.model_features import (
|
||||
|
||||
# Display metadata for every supported model, keyed by the internal model id.
# Each id is listed exactly once: a repeated key in a dict literal is silently
# replaced by its last occurrence, so stale duplicates only hide which label
# is actually in effect.
MODEL_INFO = {
    'random_forest': {'name': 'random_forest', 'name_cn': '随机森林', 'description': '稳健的树模型集成'},
    'xgboost': {'name': 'xgboost', 'name_cn': '增强树模型一', 'description': '梯度提升树模型'},
    'lightgbm': {'name': 'lightgbm', 'name_cn': '增强树模型二', 'description': '轻量级梯度提升树'},
    'gradient_boosting': {'name': 'gradient_boosting', 'name_cn': '梯度提升树', 'description': '梯度提升决策树'},
    'extra_trees': {'name': 'extra_trees', 'name_cn': '极端随机树', 'description': '高随机性的树模型'},
    'stacking': {'name': 'stacking', 'name_cn': '集成模型', 'description': '多模型融合'},
    'lstm_mlp': {
        'name': 'lstm_mlp',
        'name_cn': '时序注意力融合网络',
        'description': 'Transformer 时序编码与静态特征融合的深度学习模型',
    },
}
|
||||
|
||||
# Tree-based model ids that can be used for explanations, in the order in
# which they are tried as a fallback.
EXPLAINABLE_TREE_MODELS = tuple(
    'random_forest xgboost lightgbm gradient_boosting extra_trees'.split()
)
|
||||
|
||||
|
||||
class PredictService:
|
||||
def __init__(self):
|
||||
@@ -96,7 +104,6 @@ class PredictService:
|
||||
if valid_metrics:
|
||||
self.default_model = max(valid_metrics.items(), key=lambda item: item[1]['r2'])[0]
|
||||
|
||||
# 加载风险分类模型
|
||||
for name in ['random_forest', 'gradient_boosting', 'lightgbm', 'xgboost']:
|
||||
path = os.path.join(config.MODELS_DIR, f'risk_{name}_classifier.pkl')
|
||||
if os.path.exists(path):
|
||||
@@ -123,18 +130,22 @@ class PredictService:
|
||||
models.sort(key=lambda item: item['metrics']['r2'], reverse=True)
|
||||
return models
|
||||
|
||||
def predict_single(self, data, model_type=None):
|
||||
def predict_single(self, data, model_type=None, include_explanation=True):
|
||||
self._ensure_models_loaded()
|
||||
model_type = model_type or self.default_model
|
||||
if model_type not in self.models:
|
||||
fallback = next(iter(self.models), None)
|
||||
if fallback is None:
|
||||
return self._get_default_prediction(data)
|
||||
model_type = fallback
|
||||
if self.scaler is None or self.feature_names is None:
|
||||
return self._get_default_prediction(data)
|
||||
model_type = self._resolve_prediction_model(model_type or self.default_model)
|
||||
_, engineered_df = self._build_prediction_frames(data)
|
||||
engineered_row = engineered_df.iloc[0]
|
||||
|
||||
if model_type is None or self.scaler is None or self.feature_names is None:
|
||||
result = self._get_default_prediction(data)
|
||||
return self._augment_prediction_result(result, data, engineered_row) if include_explanation else result
|
||||
|
||||
try:
|
||||
features = self._prepare_features_from_engineered(engineered_df)
|
||||
except Exception:
|
||||
result = self._get_default_prediction(data)
|
||||
return self._augment_prediction_result(result, data, engineered_row) if include_explanation else result
|
||||
|
||||
features = self._prepare_features(data)
|
||||
try:
|
||||
if model_type == 'lstm_mlp':
|
||||
current_df = build_prediction_dataframe(data)
|
||||
@@ -144,15 +155,14 @@ class PredictService:
|
||||
predicted_hours = self._inverse_transform_prediction(predicted_hours)
|
||||
predicted_hours = max(0.5, float(predicted_hours))
|
||||
except Exception:
|
||||
return self._get_default_prediction(data)
|
||||
result = self._get_default_prediction(data)
|
||||
return self._augment_prediction_result(result, data, engineered_row) if include_explanation else result
|
||||
|
||||
risk_level, risk_label = self._get_risk_level(predicted_hours)
|
||||
confidence = max(0.5, self.model_metrics.get(model_type, {}).get('r2', 0.82))
|
||||
|
||||
# 风险分类概率
|
||||
risk_probability = self._get_risk_probability(features, model_type)
|
||||
|
||||
return {
|
||||
result = {
|
||||
'predicted_hours': round(predicted_hours, 2),
|
||||
'risk_level': risk_level,
|
||||
'risk_label': risk_label,
|
||||
@@ -161,12 +171,13 @@ class PredictService:
|
||||
'model_used': model_type,
|
||||
'model_name_cn': MODEL_INFO.get(model_type, {}).get('name_cn', model_type),
|
||||
}
|
||||
return self._augment_prediction_result(result, data, engineered_row) if include_explanation else result
|
||||
|
||||
def predict_compare(self, data):
|
||||
self._ensure_models_loaded()
|
||||
results = []
|
||||
for name in self.models.keys():
|
||||
result = self.predict_single(data, name)
|
||||
result = self.predict_single(data, name, include_explanation=False)
|
||||
result['model'] = name
|
||||
result['model_name_cn'] = MODEL_INFO.get(name, {}).get('name_cn', name)
|
||||
result['r2'] = self.model_metrics.get(name, {}).get('r2', 0)
|
||||
@@ -176,10 +187,17 @@ class PredictService:
|
||||
results[0]['recommended'] = True
|
||||
return results
|
||||
|
||||
def _build_prediction_frames(self, data):
    """Build the raw and the feature-engineered frame for one request.

    Returns a ``(raw_frame, engineered_frame)`` pair. ``engineer_features``
    is given a copy of the raw frame rather than the frame itself.
    """
    raw_frame = build_prediction_dataframe(data)
    engineered_frame = engineer_features(raw_frame.copy())
    return raw_frame, engineered_frame
|
||||
|
||||
def _prepare_features(self, data):
    """Turn raw request ``data`` into the model-ready feature vector.

    Delegates frame construction to ``_build_prediction_frames`` and the
    encoding/alignment/scaling to ``_prepare_features_from_engineered`` so
    the pipeline exists in exactly one place. The previous inline copy of the
    pipeline built and encoded a second frame whose result was discarded.
    """
    _, engineered_df = self._build_prediction_frames(data)
    return self._prepare_features_from_engineered(engineered_df)
|
||||
|
||||
def _prepare_features_from_engineered(self, engineered_df):
|
||||
X_df = apply_label_encoders(engineered_df.copy(), self.label_encoders)
|
||||
X_df = align_feature_frame(X_df, self.feature_names)
|
||||
features = self.scaler.transform(to_float_array(X_df))[0]
|
||||
if self.selected_features:
|
||||
@@ -188,6 +206,338 @@ class PredictService:
|
||||
features = features[selected_indices]
|
||||
return features
|
||||
|
||||
def _resolve_prediction_model(self, requested_model):
|
||||
if requested_model in self.models:
|
||||
return requested_model
|
||||
if self.default_model in self.models:
|
||||
return self.default_model
|
||||
return next(iter(self.models), None)
|
||||
|
||||
def _resolve_explanation_model(self, prediction_model):
    """Pick the tree model used to explain a prediction.

    The prediction model itself is used when it is an explainable tree model
    that is loaded; otherwise the first loaded entry of
    ``EXPLAINABLE_TREE_MODELS`` is used. Returns ``None`` when none is loaded.
    """
    loaded = self.models
    if prediction_model in EXPLAINABLE_TREE_MODELS and prediction_model in loaded:
        return prediction_model
    return next((name for name in EXPLAINABLE_TREE_MODELS if name in loaded), None)
|
||||
|
||||
def _augment_prediction_result(self, result, data, engineered_row):
    """Return a copy of ``result`` enriched with explanation fields.

    Adds the JD-R snapshot, the mechanism summary, the intervention
    suggestions, the explanation model id and display name, and the local
    SHAP explanation. ``result`` itself is not modified.
    """
    explainer_id = self._resolve_explanation_model(result.get('model_used'))
    local_shap = self._get_local_explanation(data, explainer_id)
    snapshot = self._build_jdr_snapshot(engineered_row)
    summary = self._build_mechanism_summary(result, data, snapshot, local_shap)
    suggestions = self._build_intervention_suggestions(data, snapshot, local_shap)
    explainer_name = MODEL_INFO.get(explainer_id, {}).get('name_cn', '机制解释')

    return {
        **result,
        'jdr_snapshot': snapshot,
        'mechanism_summary': summary,
        'intervention_suggestions': suggestions,
        'explanation_model_used': explainer_id,
        'explanation_model_name_cn': explainer_name,
        'shap_local': local_shap,
    }
|
||||
|
||||
def _get_local_explanation(self, data, model_type):
|
||||
if not model_type:
|
||||
return None
|
||||
try:
|
||||
from services.shap_service import shap_service
|
||||
|
||||
explanation = shap_service.get_local_explanation(data, model_type)
|
||||
if explanation and not explanation.get('error'):
|
||||
return explanation
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
def _build_jdr_snapshot(self, engineered_row):
|
||||
snapshot = {
|
||||
'job_demands': self._build_snapshot_item(
|
||||
'job_demands',
|
||||
'工作要求',
|
||||
engineered_row.get('工作要求指数', 0.0),
|
||||
*self._classify_job_demands(engineered_row.get('工作要求指数', 0.0)),
|
||||
),
|
||||
'job_resources': self._build_snapshot_item(
|
||||
'job_resources',
|
||||
'工作资源',
|
||||
engineered_row.get('工作资源指数', 0.0),
|
||||
*self._classify_resource_stock(engineered_row.get('工作资源指数', 0.0)),
|
||||
),
|
||||
'personal_resources': self._build_snapshot_item(
|
||||
'personal_resources',
|
||||
'个人资源',
|
||||
engineered_row.get('个人资源指数', 0.0),
|
||||
*self._classify_resource_stock(engineered_row.get('个人资源指数', 0.0)),
|
||||
),
|
||||
'balance': self._build_snapshot_item(
|
||||
'balance',
|
||||
'平衡度',
|
||||
engineered_row.get('JD-R平衡度', 0.0),
|
||||
*self._classify_balance(engineered_row.get('JD-R平衡度', 0.0)),
|
||||
),
|
||||
'burnout_risk': self._build_snapshot_item(
|
||||
'burnout_risk',
|
||||
'倦怠风险',
|
||||
engineered_row.get('倦怠风险指数', 0.0),
|
||||
*self._classify_burnout(engineered_row.get('倦怠风险指数', 0.0)),
|
||||
),
|
||||
'engagement': self._build_snapshot_item(
|
||||
'engagement',
|
||||
'工作投入',
|
||||
engineered_row.get('工作投入指数', 0.0),
|
||||
*self._classify_resource_stock(engineered_row.get('工作投入指数', 0.0)),
|
||||
),
|
||||
}
|
||||
return snapshot
|
||||
|
||||
def _build_snapshot_item(self, key, label, score, status, tone):
|
||||
return {
|
||||
'key': key,
|
||||
'label': label,
|
||||
'score': round(self._safe_float(score), 2),
|
||||
'status': status,
|
||||
'tone': tone,
|
||||
}
|
||||
|
||||
def _build_mechanism_summary(self, result, data, jdr_snapshot, shap_local):
|
||||
dimension_scores = self._extract_dimension_scores(shap_local)
|
||||
top_drivers = self._extract_feature_effects(shap_local, positive=True, limit=3)
|
||||
protective_factors = self._extract_feature_effects(shap_local, positive=False, limit=2)
|
||||
|
||||
pathway_label, pathway_tone, pathway_detail = self._infer_pathway(jdr_snapshot, dimension_scores)
|
||||
mechanism = self._build_mechanism_text(data, jdr_snapshot, dimension_scores, top_drivers)
|
||||
buffer_text = self._build_buffer_text(jdr_snapshot, protective_factors)
|
||||
scenario_hint = self._build_scenario_hint(data)
|
||||
|
||||
return {
|
||||
'conclusion': f"本次预测为{result['risk_label']},预计缺勤时长约 {result['predicted_hours']} 小时。",
|
||||
'mechanism': mechanism,
|
||||
'pathway_label': pathway_label,
|
||||
'pathway_tone': pathway_tone,
|
||||
'pathway_detail': pathway_detail,
|
||||
'buffer_text': buffer_text,
|
||||
'scenario_hint': scenario_hint,
|
||||
'top_drivers': top_drivers,
|
||||
'protective_factors': protective_factors,
|
||||
}
|
||||
|
||||
def _build_mechanism_text(self, data, jdr_snapshot, dimension_scores, top_drivers):
|
||||
if top_drivers:
|
||||
driver_names = '、'.join(item['name_cn'] for item in top_drivers)
|
||||
if dimension_scores.get('工作要求', 0.0) > 0.03:
|
||||
return f'主要推高因素集中在{driver_names},说明高工作要求正在直接抬升本次缺勤风险。'
|
||||
if dimension_scores.get('事件上下文', 0.0) > 0.03:
|
||||
return f'主要推高因素集中在{driver_names},当前结果更容易受到请假事件情境的直接触发。'
|
||||
if dimension_scores.get('工作资源', 0.0) > 0.03 or dimension_scores.get('个人资源', 0.0) > 0.03:
|
||||
return f'主要推高因素集中在{driver_names},说明资源缓冲不足正在放大本次缺勤时长。'
|
||||
return f'主要推高因素集中在{driver_names},它们共同推动了本次缺勤时长上升。'
|
||||
|
||||
fragments = []
|
||||
if jdr_snapshot['job_demands']['tone'] in {'warning', 'danger'}:
|
||||
fragments.append('工作要求偏高')
|
||||
if jdr_snapshot['job_resources']['tone'] == 'danger':
|
||||
fragments.append('工作资源不足')
|
||||
if jdr_snapshot['personal_resources']['tone'] == 'danger':
|
||||
fragments.append('个人资源偏弱')
|
||||
if self._as_flag(data.get('medical_certificate_flag')) or self._as_flag(data.get('near_holiday_flag')):
|
||||
fragments.append('事件情境触发明显')
|
||||
if not fragments:
|
||||
return '当前结果更多体现为常规缺勤波动,整体压力与资源结构暂时可控。'
|
||||
return f"当前结果主要由{'、'.join(fragments)}共同驱动。"
|
||||
|
||||
def _build_buffer_text(self, jdr_snapshot, protective_factors):
|
||||
if protective_factors:
|
||||
names = '、'.join(item['name_cn'] for item in protective_factors)
|
||||
return f'{names}对当前风险仍有一定缓冲作用,但尚不足以完全抵消主要压力来源。'
|
||||
if jdr_snapshot['job_resources']['tone'] in {'success', 'info'} and jdr_snapshot['personal_resources']['tone'] in {'success', 'info'}:
|
||||
return '当前资源支持和个人恢复能力对风险有一定缓冲,但事件性因素仍需持续关注。'
|
||||
return ''
|
||||
|
||||
def _build_scenario_hint(self, data):
|
||||
actions = []
|
||||
if self._safe_float(data.get('monthly_overtime_hours', 0.0)) >= 25:
|
||||
actions.append('将月均加班控制在 20 小时以内')
|
||||
if self._safe_float(data.get('commute_minutes', 0.0)) >= 45:
|
||||
actions.append('把通勤时长压缩到 30 分钟左右')
|
||||
if self._as_flag(data.get('is_night_shift')):
|
||||
actions.append('减少连续夜班或延长轮休恢复时间')
|
||||
if not actions:
|
||||
return ''
|
||||
if len(actions) == 1:
|
||||
return f'情境判断:若能{actions[0]},当前风险通常会有所回落。'
|
||||
return f"情境判断:若能{',并'.join(actions[:-1])},同时{actions[-1]},当前风险通常会有所回落。"
|
||||
|
||||
def _infer_pathway(self, jdr_snapshot, dimension_scores):
|
||||
demands_pressure = dimension_scores.get('工作要求', 0.0)
|
||||
mediator_pressure = dimension_scores.get('中介变量', 0.0)
|
||||
resource_pressure = dimension_scores.get('工作资源', 0.0) + dimension_scores.get('个人资源', 0.0)
|
||||
event_pressure = dimension_scores.get('事件上下文', 0.0)
|
||||
|
||||
demands_high = jdr_snapshot['job_demands']['tone'] == 'danger'
|
||||
burnout_high = jdr_snapshot['burnout_risk']['tone'] in {'warning', 'danger'}
|
||||
resources_low = (
|
||||
jdr_snapshot['job_resources']['tone'] == 'danger'
|
||||
or jdr_snapshot['personal_resources']['tone'] == 'danger'
|
||||
or jdr_snapshot['engagement']['tone'] == 'danger'
|
||||
)
|
||||
|
||||
if demands_high or burnout_high or demands_pressure > 0.03 or mediator_pressure > 0.03:
|
||||
if resources_low or resource_pressure > 0.03:
|
||||
return (
|
||||
'健康损耗与资源缓冲不足',
|
||||
'danger',
|
||||
'当前结果同时表现出高要求累积与资源缓冲不足,更接近“工作要求上升 → 倦怠累积 → 缺勤增加”的复合路径。',
|
||||
)
|
||||
return (
|
||||
'健康损耗路径为主',
|
||||
'warning',
|
||||
'当前结果更接近“工作要求上升 → 倦怠累积 → 缺勤增加”的健康损耗路径。',
|
||||
)
|
||||
if resources_low or resource_pressure > 0.03:
|
||||
return (
|
||||
'激励支撑不足路径',
|
||||
'warning',
|
||||
'当前资源与个人恢复能力偏弱,工作投入对缺勤风险的缓冲作用有限。',
|
||||
)
|
||||
if event_pressure > 0.04:
|
||||
return (
|
||||
'事件触发型波动',
|
||||
'info',
|
||||
'当前结果更容易受到请假类型、医院证明和节假日前后等事件情境直接触发。',
|
||||
)
|
||||
return (
|
||||
'混合影响路径',
|
||||
'info',
|
||||
'当前结果同时受到工作要求、资源结构与事件情境的共同影响,尚不属于单一路径主导。',
|
||||
)
|
||||
|
||||
def _build_intervention_suggestions(self, data, jdr_snapshot, shap_local):
|
||||
suggestions = []
|
||||
|
||||
demand_items = []
|
||||
overtime_hours = self._safe_float(data.get('monthly_overtime_hours', 0.0))
|
||||
commute_minutes = self._safe_float(data.get('commute_minutes', 0.0))
|
||||
if overtime_hours >= 25 or jdr_snapshot['job_demands']['tone'] == 'danger':
|
||||
demand_items.append('优先压降连续高负荷排班,尽量把月均加班控制在 20 小时以内。')
|
||||
if commute_minutes >= 45:
|
||||
demand_items.append('若条件允许,可通过弹性到岗、调班或就近安排缓和通勤压力。')
|
||||
if self._as_flag(data.get('is_night_shift')):
|
||||
demand_items.append('夜班岗位建议增加轮休和班后恢复时段,避免疲劳持续累积。')
|
||||
if self._as_flag(data.get('near_holiday_flag')):
|
||||
demand_items.append('节假日前后可提前做好替班和排班缓冲,减少事件性缺勤波动。')
|
||||
if not demand_items:
|
||||
demand_items.append('当前工作要求未明显失衡,重点保持排班稳定并持续监控波动。')
|
||||
suggestions.append({'category': '减要求', 'items': self._limit_unique_items(demand_items)})
|
||||
|
||||
resource_items = []
|
||||
if jdr_snapshot['job_resources']['tone'] in {'warning', 'danger'}:
|
||||
resource_items.append('增加主管沟通、临时替班支持和班组协同,补足组织支持资源。')
|
||||
if jdr_snapshot['balance']['tone'] in {'warning', 'danger'}:
|
||||
resource_items.append('对高风险岗位提供更清晰的任务边界和优先级,降低角色冲突。')
|
||||
if str(data.get('leave_reason_category', '')) == '子女照护':
|
||||
resource_items.append('可结合弹性工时或家庭照护支持,缓解家庭事务对缺勤的放大作用。')
|
||||
if not resource_items:
|
||||
resource_items.append('当前资源面整体可用,建议继续维持支持性排班和沟通反馈机制。')
|
||||
suggestions.append({'category': '增资源', 'items': self._limit_unique_items(resource_items)})
|
||||
|
||||
personal_items = []
|
||||
if self._as_flag(data.get('chronic_disease_flag')) or self._as_flag(data.get('medical_certificate_flag')):
|
||||
personal_items.append('结合健康监测、复诊安排和短期工作调整,降低身体不适带来的持续缺勤风险。')
|
||||
if jdr_snapshot['burnout_risk']['tone'] in {'warning', 'danger'}:
|
||||
personal_items.append('建议通过休息恢复、情绪支持和短周期工作调整,缓冲倦怠累积。')
|
||||
if jdr_snapshot['personal_resources']['tone'] == 'danger':
|
||||
personal_items.append('可通过辅导、复盘和岗位支持增强员工自我效能与心理韧性。')
|
||||
if not personal_items:
|
||||
personal_items.append('当前个体恢复能力整体可控,重点维持规律作息和健康管理即可。')
|
||||
suggestions.append({'category': '补个人资源', 'items': self._limit_unique_items(personal_items)})
|
||||
|
||||
return suggestions
|
||||
|
||||
def _extract_dimension_scores(self, shap_local):
|
||||
if not shap_local:
|
||||
return {}
|
||||
dimension_contribution = shap_local.get('dimension_contribution', {})
|
||||
return {
|
||||
key: self._safe_float(value)
|
||||
for key, value in dimension_contribution.items()
|
||||
if isinstance(value, (int, float))
|
||||
}
|
||||
|
||||
def _extract_feature_effects(self, shap_local, positive=True, limit=3):
|
||||
if not shap_local:
|
||||
return []
|
||||
features = shap_local.get('features', [])
|
||||
filtered = []
|
||||
for item in features:
|
||||
shap_value = self._safe_float(item.get('shap_value', 0.0))
|
||||
if positive and shap_value <= 0:
|
||||
continue
|
||||
if not positive and shap_value >= 0:
|
||||
continue
|
||||
filtered.append({
|
||||
'name': item.get('name'),
|
||||
'name_cn': item.get('name_cn') or item.get('name') or '未命名特征',
|
||||
'dimension': self._dimension_label(item.get('dimension')),
|
||||
'shap_value': round(shap_value, 4),
|
||||
})
|
||||
filtered.sort(key=lambda entry: entry['shap_value'], reverse=positive)
|
||||
if not positive:
|
||||
filtered.sort(key=lambda entry: abs(entry['shap_value']), reverse=True)
|
||||
return filtered[:limit]
|
||||
|
||||
def _dimension_label(self, key):
    """Map a dimension key to its display label.

    Keys found in ``config.JDR_DIMENSIONS`` use that entry's ``name_cn``;
    ``'event_context'`` and ``'other'`` have fixed labels; any other key is
    returned unchanged, and a falsy key falls back to the "other" label.
    """
    known_dimensions = config.JDR_DIMENSIONS
    if key in known_dimensions:
        label = known_dimensions[key]['name_cn']
    elif key == 'event_context':
        label = '事件上下文'
    elif key == 'other':
        label = '其他因素'
    else:
        label = key or '其他因素'
    return label
|
||||
|
||||
def _limit_unique_items(self, items, limit=3):
|
||||
unique_items = []
|
||||
for item in items:
|
||||
if item not in unique_items:
|
||||
unique_items.append(item)
|
||||
return unique_items[:limit]
|
||||
|
||||
def _classify_job_demands(self, score):
|
||||
score = self._safe_float(score)
|
||||
if score >= 5.2:
|
||||
return '偏高', 'danger'
|
||||
if score >= 4.0:
|
||||
return '中等', 'warning'
|
||||
return '适中', 'success'
|
||||
|
||||
def _classify_resource_stock(self, score):
|
||||
score = self._safe_float(score)
|
||||
if score >= 3.8:
|
||||
return '充足', 'success'
|
||||
if score >= 3.0:
|
||||
return '中等', 'warning'
|
||||
return '偏低', 'danger'
|
||||
|
||||
def _classify_balance(self, score):
|
||||
score = self._safe_float(score)
|
||||
if score >= 0.8:
|
||||
return '资源占优', 'success'
|
||||
if score >= 0.0:
|
||||
return '基本平衡', 'info'
|
||||
if score >= -0.8:
|
||||
return '轻度失衡', 'warning'
|
||||
return '明显失衡', 'danger'
|
||||
|
||||
def _classify_burnout(self, score):
|
||||
score = self._safe_float(score)
|
||||
if score >= 2.8:
|
||||
return '偏高', 'danger'
|
||||
if score >= 2.0:
|
||||
return '中等', 'warning'
|
||||
return '可控', 'success'
|
||||
|
||||
def _inverse_transform_prediction(self, prediction):
|
||||
if self.training_metadata.get('target_transform') == 'log1p':
|
||||
return float(np.expm1(prediction))
|
||||
@@ -202,13 +552,13 @@ class PredictService:
|
||||
|
||||
def _get_default_prediction(self, data):
|
||||
base_hours = 3.8
|
||||
base_hours += min(float(data.get('monthly_overtime_hours', 24)) / 20, 3.0)
|
||||
base_hours += min(float(data.get('commute_minutes', 40)) / 50, 2.0)
|
||||
base_hours += 1.6 if int(data.get('is_night_shift', 0)) == 1 else 0
|
||||
base_hours += 1.8 if int(data.get('chronic_disease_flag', 0)) == 1 else 0
|
||||
base_hours += 0.9 if int(data.get('near_holiday_flag', 0)) == 1 else 0
|
||||
base_hours += 0.8 if int(data.get('medical_certificate_flag', 0)) == 1 else 0
|
||||
base_hours += 0.5 * int(data.get('children_count', 0))
|
||||
base_hours += min(self._safe_float(data.get('monthly_overtime_hours', 24)) / 20, 3.0)
|
||||
base_hours += min(self._safe_float(data.get('commute_minutes', 40)) / 50, 2.0)
|
||||
base_hours += 1.6 if self._as_flag(data.get('is_night_shift')) else 0
|
||||
base_hours += 1.8 if self._as_flag(data.get('chronic_disease_flag')) else 0
|
||||
base_hours += 0.9 if self._as_flag(data.get('near_holiday_flag')) else 0
|
||||
base_hours += 0.8 if self._as_flag(data.get('medical_certificate_flag')) else 0
|
||||
base_hours += 0.5 * int(self._safe_float(data.get('children_count', 0)))
|
||||
if data.get('leave_type') in ['病假', '工伤假', '婚假', '丧假']:
|
||||
base_hours += 2.5
|
||||
if data.get('stress_level') == '高':
|
||||
@@ -227,7 +577,6 @@ class PredictService:
|
||||
}
|
||||
|
||||
def _get_risk_probability(self, features, model_type):
|
||||
"""获取分类器预测的风险概率"""
|
||||
classifier = self.classifiers.get(model_type)
|
||||
if classifier is None:
|
||||
classifier = self.classifiers.get('random_forest')
|
||||
@@ -246,7 +595,6 @@ class PredictService:
|
||||
return {'low': 0.0, 'medium': 1.0, 'high': 0.0}
|
||||
|
||||
def predict_risk_classification(self, data, model_type=None):
|
||||
"""使用分类模型直接预测风险等级"""
|
||||
self._ensure_models_loaded()
|
||||
model_type = model_type or self.default_model
|
||||
classifier = self.classifiers.get(model_type)
|
||||
@@ -293,5 +641,17 @@ class PredictService:
|
||||
},
|
||||
}
|
||||
|
||||
def _safe_float(self, value, default=0.0):
|
||||
try:
|
||||
return float(value)
|
||||
except (TypeError, ValueError):
|
||||
return default
|
||||
|
||||
def _as_flag(self, value):
|
||||
try:
|
||||
return int(value) == 1
|
||||
except (TypeError, ValueError):
|
||||
return False
|
||||
|
||||
|
||||
predict_service = PredictService()
|
||||
|
||||
@@ -28,13 +28,58 @@ class SHAPService:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_global_importance(self, model_type='random_forest'):
|
||||
def _save_cache(self, model_type, payload):
    """Write ``payload`` as JSON to the SHAP cache file of ``model_type``.

    The payload is serialised before the file is opened: opening with ``'w'``
    truncates the file, so serialising afterwards would leave an empty or
    partial cache behind whenever the payload is not JSON-serialisable.
    Raises whatever ``json.dumps`` or the file operations raise.
    """
    serialized = json.dumps(payload, ensure_ascii=False)
    os.makedirs(config.SHAP_CACHE_DIR, exist_ok=True)
    cache_path = self._get_cache_path(model_type)
    with open(cache_path, 'w', encoding='utf-8') as fp:
        fp.write(serialized)
|
||||
|
||||
def _build_cache_payload(self, model_type):
    """Compute the full SHAP cache payload for ``model_type``.

    Returns a dict with ``model_type``, ``global``, ``dependence`` (for up to
    the first 15 top features, skipping features whose dependence data
    reports an error) and ``interaction``. When the global or interaction
    analysis reports an error, ``{'error': ...}`` is returned instead.
    """
    self._ensure_analyzer()

    global_summary = self._analyzer.global_shap_values(model_type)
    global_error = global_summary.get('error')
    if global_error:
        return {'error': global_error}

    feature_names = [entry['name'] for entry in global_summary.get('top_features', [])[:15]]
    dependence_by_feature = {}
    for name in feature_names:
        curve = self._analyzer.shap_dependence(name, model_type)
        if curve.get('error'):
            continue
        dependence_by_feature[name] = curve

    interaction_summary = self._analyzer.shap_interaction(model_type, top_n=10)
    interaction_error = interaction_summary.get('error')
    if interaction_error:
        return {'error': interaction_error}

    return dict(
        model_type=model_type,
        **{
            'global': global_summary,
            'dependence': dependence_by_feature,
            'interaction': interaction_summary,
        },
    )
|
||||
|
||||
def _ensure_cache(self, model_type):
    """Return the SHAP cache for ``model_type``, building it when missing.

    A loaded cache is returned as is. Otherwise the payload is computed; if
    that fails, an ``{'error': ...}`` dict with a user-facing message is
    returned. A freshly built payload is written to disk on a best-effort
    basis and returned either way.
    """
    cache = self._load_cache(model_type)
    if cache:
        return cache

    payload = self._build_cache_payload(model_type)
    if payload.get('error'):
        return {
            'error': f'{model_type} 的贡献解释数据暂时不可用:{payload["error"]}'
        }

    try:
        self._save_cache(model_type, payload)
    except Exception:
        # A failed cache write must not fail the request: the payload was
        # computed successfully and can still be returned.
        pass
    return payload
|
||||
|
||||
def get_global_importance(self, model_type='random_forest'):
    """Return the global SHAP importance for ``model_type``.

    An error dict from the cache layer is passed through unchanged; a cache
    without a ``'global'`` section yields an "invalid cache" error dict.
    """
    cache = self._ensure_cache(model_type)
    if not cache.get('error'):
        return cache.get('global', {'error': f'Invalid SHAP cache for {model_type}'})
    return cache
|
||||
|
||||
def get_local_explanation(self, data, model_type='random_forest'):
|
||||
@@ -42,12 +87,9 @@ class SHAPService:
|
||||
return self._analyzer.local_shap_values(data, model_type)
|
||||
|
||||
def get_interactions(self, model_type='random_forest', top_n=10):
|
||||
cache = self._load_cache(model_type)
|
||||
if not cache:
|
||||
return {
|
||||
'error': f'SHAP cache not found for {model_type}. '
|
||||
f'Run backend/core/generate_shap_cache.py first.'
|
||||
}
|
||||
cache = self._ensure_cache(model_type)
|
||||
if cache.get('error'):
|
||||
return cache
|
||||
data = cache.get('interaction')
|
||||
if not data:
|
||||
return {'error': f'Interaction cache missing for {model_type}'}
|
||||
@@ -58,17 +100,26 @@ class SHAPService:
|
||||
return data
|
||||
|
||||
def get_dependence(self, feature_name, model_type='random_forest'):
    """Return the SHAP dependence data of ``feature_name`` for ``model_type``.

    Cached data is returned directly. A feature missing from the cache is
    computed on demand, added to the cache (written back best-effort) and
    returned. Errors are reported as ``{'error': ...}`` dicts, either passed
    through from the cache layer or wrapped with the feature name.
    """
    cache = self._ensure_cache(model_type)
    if cache.get('error'):
        return cache

    dependence_map = cache.get('dependence', {})
    data = dependence_map.get(feature_name)
    if data:
        return data

    # Not cached yet: compute it now instead of reporting a cache miss.
    self._ensure_analyzer()
    data = self._analyzer.shap_dependence(feature_name, model_type)
    if data.get('error'):
        return {'error': f'特征 {feature_name} 的依赖解释不可用:{data["error"]}'}

    dependence_map[feature_name] = data
    cache['dependence'] = dependence_map
    try:
        self._save_cache(model_type, cache)
    except Exception:
        # Persisting the enlarged cache is optional; the computed data is
        # returned even when the write fails.
        pass
    return data
|
||||
|
||||
|
||||
shap_service = SHAPService()
|
||||
|
||||
Reference in New Issue
Block a user