406 lines
16 KiB
Python
406 lines
16 KiB
Python
import os
|
||
|
||
import joblib
|
||
import numpy as np
|
||
import pandas as pd
|
||
|
||
import config
|
||
|
||
# Optional dependency: SHAP is only needed for the explainability endpoints.
# When it is missing, the analyzer methods degrade gracefully and return an
# error payload instead of raising at import time.
try:
    import shap
except ImportError:
    SHAP_AVAILABLE = False
else:
    SHAP_AVAILABLE = True
class SHAPAnalyzer:
    """SHAP-value-based explainability analyzer that aggregates explanations
    by JD-R dimension."""

    def __init__(self):
        # Lazily populated registries, keyed by model type name.
        self.models = {}
        self.explainers = {}
        # Preprocessing artifacts, loaded on first use by _ensure_initialized.
        self.scaler = None
        self.feature_names = None
        self.selected_features = None
        self.label_encoders = {}
        # Caches: background samples keyed by sample size, and global SHAP
        # results keyed by model type.
        self.background_data = {}
        self.global_result_cache = {}
        # Guard flag so the heavy artifact loading runs at most once.
        self._initialized = False
def _ensure_initialized(self):
|
||
if self._initialized:
|
||
return
|
||
|
||
# 加载回归模型(SHAP 分析基于回归模型)
|
||
models_dir = config.MODELS_DIR
|
||
model_files = {
|
||
'random_forest': 'random_forest_model.pkl',
|
||
'xgboost': 'xgboost_model.pkl',
|
||
'lightgbm': 'lightgbm_model.pkl',
|
||
'gradient_boosting': 'gradient_boosting_model.pkl',
|
||
'extra_trees': 'extra_trees_model.pkl',
|
||
}
|
||
for name, filename in model_files.items():
|
||
path = os.path.join(models_dir, filename)
|
||
if os.path.exists(path):
|
||
try:
|
||
self.models[name] = joblib.load(path)
|
||
except Exception:
|
||
pass
|
||
|
||
# 加载预处理工件
|
||
if os.path.exists(config.SCALER_PATH):
|
||
self.scaler = joblib.load(config.SCALER_PATH)
|
||
for filename, attr in [
|
||
('feature_names.pkl', 'feature_names'),
|
||
('selected_features.pkl', 'selected_features'),
|
||
('label_encoders.pkl', 'label_encoders'),
|
||
]:
|
||
path = os.path.join(models_dir, filename)
|
||
if os.path.exists(path):
|
||
try:
|
||
setattr(self, attr, joblib.load(path))
|
||
except Exception:
|
||
pass
|
||
|
||
self._initialized = True
|
||
|
||
def _get_tree_explainer(self, model_type='random_forest'):
|
||
"""获取或创建 TreeExplainer"""
|
||
if not SHAP_AVAILABLE:
|
||
return None
|
||
|
||
if model_type in self.explainers:
|
||
return self.explainers[model_type]
|
||
|
||
model = self.models.get(model_type)
|
||
if model is None:
|
||
return None
|
||
|
||
try:
|
||
explainer = shap.TreeExplainer(model)
|
||
self.explainers[model_type] = explainer
|
||
return explainer
|
||
except Exception:
|
||
return None
|
||
|
||
def _get_background_sample(self, n_samples=500):
|
||
"""获取背景数据样本"""
|
||
if n_samples in self.background_data:
|
||
return self.background_data[n_samples]
|
||
|
||
try:
|
||
from core.preprocessing import get_clean_data
|
||
from core.model_features import (
|
||
normalize_columns, prepare_modeling_dataframe,
|
||
apply_outlier_bounds, fit_outlier_bounds,
|
||
engineer_features, extract_xy, fit_label_encoders,
|
||
apply_label_encoders, align_feature_frame, to_float_array,
|
||
NUMERICAL_OUTLIER_COLUMNS, ORDINAL_COLUMNS,
|
||
)
|
||
|
||
raw_df = normalize_columns(get_clean_data())
|
||
df = prepare_modeling_dataframe(raw_df)
|
||
|
||
bounds = fit_outlier_bounds(df, NUMERICAL_OUTLIER_COLUMNS)
|
||
df = apply_outlier_bounds(df, bounds)
|
||
df = engineer_features(df)
|
||
X_df, _ = extract_xy(df)
|
||
X_df, encoders = fit_label_encoders(X_df, ORDINAL_COLUMNS)
|
||
|
||
if self.feature_names:
|
||
X_df = align_feature_frame(X_df, self.feature_names)
|
||
|
||
if n_samples < len(X_df):
|
||
X_df = X_df.sample(n=n_samples, random_state=config.RANDOM_STATE)
|
||
|
||
if self.scaler is not None:
|
||
X = self.scaler.transform(to_float_array(X_df))
|
||
else:
|
||
X = to_float_array(X_df)
|
||
|
||
if self.selected_features and self.feature_names:
|
||
selected_indices = [self.feature_names.index(n) for n in self.selected_features if n in self.feature_names]
|
||
if selected_indices:
|
||
X = X[:, selected_indices]
|
||
|
||
self.background_data[n_samples] = X
|
||
return X
|
||
except Exception:
|
||
return None
|
||
|
||
def _get_feature_display_names(self):
|
||
"""获取特征显示名称映射"""
|
||
feature_names = self.selected_features or self.feature_names or []
|
||
return {name: config.FEATURE_NAME_CN.get(name, name) for name in feature_names}
|
||
|
||
def _map_feature_to_dimension(self, feature_name):
|
||
"""将特征映射到 JD-R 维度"""
|
||
for dim_key, dim_info in config.JDR_DIMENSIONS.items():
|
||
if feature_name in dim_info['features']:
|
||
return dim_key
|
||
# 事件/上下文特征
|
||
context_features = ['缺勤月份', '星期几', '是否节假日前后', '季节',
|
||
'请假类型', '请假原因大类', '是否提供医院证明',
|
||
'是否临时请假', '是否连续缺勤', '前一工作日是否加班']
|
||
if feature_name in context_features:
|
||
return 'event_context'
|
||
return 'other'
|
||
|
||
def global_shap_values(self, model_type='random_forest'):
|
||
"""计算全局 SHAP 重要性,按 JD-R 维度分组"""
|
||
if not SHAP_AVAILABLE:
|
||
return {'error': 'SHAP library not installed'}
|
||
|
||
if model_type in self.global_result_cache:
|
||
return self.global_result_cache[model_type]
|
||
|
||
self._ensure_initialized()
|
||
explainer = self._get_tree_explainer(model_type)
|
||
if explainer is None:
|
||
return {'error': f'No tree model available for {model_type}'}
|
||
|
||
X = self._get_background_sample(n_samples=32)
|
||
if X is None:
|
||
return {'error': 'Failed to prepare background data'}
|
||
|
||
try:
|
||
shap_values = explainer.shap_values(X)
|
||
if isinstance(shap_values, list):
|
||
shap_values = shap_values[0]
|
||
|
||
mean_abs_shap = np.abs(shap_values).mean(axis=0)
|
||
feature_names = self.selected_features or self.feature_names or []
|
||
name_map = self._get_feature_display_names()
|
||
|
||
# 按维度分组
|
||
dimensions = {}
|
||
for dim_key, dim_info in config.JDR_DIMENSIONS.items():
|
||
dim_features = []
|
||
for fname in feature_names:
|
||
if fname in dim_info['features']:
|
||
idx = list(feature_names).index(fname)
|
||
dim_features.append({
|
||
'name': fname,
|
||
'name_cn': name_map.get(fname, fname),
|
||
'importance': round(float(mean_abs_shap[idx]), 4),
|
||
})
|
||
if dim_features:
|
||
dimensions[dim_key] = {
|
||
'name_cn': dim_info['name_cn'],
|
||
'features': sorted(dim_features, key=lambda x: x['importance'], reverse=True),
|
||
}
|
||
|
||
# 事件上下文维度
|
||
context_features = []
|
||
for fname in feature_names:
|
||
if self._map_feature_to_dimension(fname) == 'event_context':
|
||
idx = list(feature_names).index(fname)
|
||
context_features.append({
|
||
'name': fname,
|
||
'name_cn': name_map.get(fname, fname),
|
||
'importance': round(float(mean_abs_shap[idx]), 4),
|
||
})
|
||
if context_features:
|
||
dimensions['event_context'] = {
|
||
'name_cn': '事件上下文',
|
||
'features': sorted(context_features, key=lambda x: x['importance'], reverse=True),
|
||
}
|
||
|
||
# Top 特征列表
|
||
top_indices = np.argsort(mean_abs_shap)[::-1][:20]
|
||
top_features = []
|
||
for idx in top_indices:
|
||
fname = feature_names[idx] if idx < len(feature_names) else f'f{idx}'
|
||
top_features.append({
|
||
'name': fname,
|
||
'name_cn': name_map.get(fname, fname),
|
||
'importance': round(float(mean_abs_shap[idx]), 4),
|
||
'dimension': self._map_feature_to_dimension(fname),
|
||
})
|
||
|
||
result = {
|
||
'model_type': model_type,
|
||
'dimensions': dimensions,
|
||
'top_features': top_features,
|
||
}
|
||
self.global_result_cache[model_type] = result
|
||
return result
|
||
except Exception as exc:
|
||
return {'error': str(exc)}
|
||
|
||
def local_shap_values(self, data, model_type='random_forest'):
|
||
"""计算单条预测的 SHAP 解释"""
|
||
if not SHAP_AVAILABLE:
|
||
return {'error': 'SHAP library not installed'}
|
||
|
||
self._ensure_initialized()
|
||
explainer = self._get_tree_explainer(model_type)
|
||
if explainer is None:
|
||
return {'error': f'No tree model available for {model_type}'}
|
||
|
||
try:
|
||
from core.model_features import (
|
||
build_prediction_dataframe, engineer_features,
|
||
apply_label_encoders, align_feature_frame, to_float_array,
|
||
)
|
||
|
||
X_df = build_prediction_dataframe(data)
|
||
X_df = engineer_features(X_df)
|
||
X_df = apply_label_encoders(X_df, self.label_encoders)
|
||
if self.feature_names:
|
||
X_df = align_feature_frame(X_df, self.feature_names)
|
||
features = self.scaler.transform(to_float_array(X_df))
|
||
if self.selected_features and self.feature_names:
|
||
selected_indices = [self.feature_names.index(n) for n in self.selected_features if n in self.feature_names]
|
||
if selected_indices:
|
||
features = features[:, selected_indices]
|
||
|
||
shap_values = explainer.shap_values(features)
|
||
if isinstance(shap_values, list):
|
||
shap_values = shap_values[0]
|
||
|
||
base_value = float(explainer.expected_value)
|
||
if isinstance(base_value, (list, np.ndarray)):
|
||
base_value = float(base_value[0])
|
||
|
||
feature_names = self.selected_features or self.feature_names or []
|
||
name_map = self._get_feature_display_names()
|
||
|
||
feature_contributions = []
|
||
dimension_contribution = {}
|
||
for idx, fname in enumerate(feature_names):
|
||
sv = float(shap_values[0][idx])
|
||
fv = float(features[0][idx])
|
||
dim = self._map_feature_to_dimension(fname)
|
||
feature_contributions.append({
|
||
'name': fname,
|
||
'name_cn': name_map.get(fname, fname),
|
||
'shap_value': round(sv, 4),
|
||
'feature_value': round(fv, 4),
|
||
'dimension': dim,
|
||
})
|
||
dimension_contribution[dim] = dimension_contribution.get(dim, 0) + sv
|
||
|
||
feature_contributions.sort(key=lambda x: abs(x['shap_value']), reverse=True)
|
||
|
||
# 维度标签
|
||
dim_labels = {}
|
||
for dk, di in config.JDR_DIMENSIONS.items():
|
||
dim_labels[dk] = di['name_cn']
|
||
dim_labels['event_context'] = '事件上下文'
|
||
dim_labels['other'] = '其他'
|
||
|
||
return {
|
||
'base_value': round(base_value, 4),
|
||
'features': feature_contributions[:20],
|
||
'dimension_contribution': {
|
||
dim_labels.get(k, k): round(v, 4)
|
||
for k, v in sorted(dimension_contribution.items(), key=lambda x: abs(x[1]), reverse=True)
|
||
},
|
||
}
|
||
except Exception as exc:
|
||
return {'error': str(exc)}
|
||
|
||
def shap_interaction(self, model_type='random_forest', top_n=10):
|
||
"""计算 SHAP 交互值"""
|
||
if not SHAP_AVAILABLE:
|
||
return {'error': 'SHAP library not installed'}
|
||
|
||
self._ensure_initialized()
|
||
explainer = self._get_tree_explainer(model_type)
|
||
if explainer is None:
|
||
return {'error': f'No tree model available for {model_type}'}
|
||
|
||
X = self._get_background_sample(n_samples=12)
|
||
if X is None:
|
||
return {'error': 'Failed to prepare background data'}
|
||
|
||
try:
|
||
interaction_values = explainer.shap_interaction_values(X)
|
||
if isinstance(interaction_values, list):
|
||
interaction_values = interaction_values[0]
|
||
|
||
mean_interaction = np.abs(interaction_values).mean(axis=0)
|
||
feature_names = self.selected_features or self.feature_names or []
|
||
|
||
# 获取 top_n 特征的交互
|
||
mean_abs = np.abs(interaction_values.mean(axis=0))
|
||
np.fill_diagonal(mean_abs, 0)
|
||
flat_idx = np.argsort(mean_abs.ravel())[::-1][:top_n * 2]
|
||
top_pairs = []
|
||
seen = set()
|
||
for idx in flat_idx:
|
||
i, j = divmod(idx, mean_abs.shape[1])
|
||
if i >= j:
|
||
continue
|
||
pair_key = (min(i, j), max(i, j))
|
||
if pair_key in seen:
|
||
continue
|
||
seen.add(pair_key)
|
||
fi = feature_names[i] if i < len(feature_names) else f'f{i}'
|
||
fj = feature_names[j] if j < len(feature_names) else f'f{j}'
|
||
name_map = self._get_feature_display_names()
|
||
top_pairs.append({
|
||
'feature_1': fi,
|
||
'feature_1_cn': name_map.get(fi, fi),
|
||
'feature_2': fj,
|
||
'feature_2_cn': name_map.get(fj, fj),
|
||
'strength': round(float(mean_interaction[i, j]), 4),
|
||
})
|
||
if len(top_pairs) >= top_n:
|
||
break
|
||
|
||
return {
|
||
'model_type': model_type,
|
||
'top_interactions': top_pairs,
|
||
}
|
||
except Exception as exc:
|
||
return {'error': str(exc)}
|
||
|
||
def shap_dependence(self, feature_name, model_type='random_forest'):
|
||
"""计算单个特征的 SHAP 依赖图数据"""
|
||
if not SHAP_AVAILABLE:
|
||
return {'error': 'SHAP library not installed'}
|
||
|
||
self._ensure_initialized()
|
||
explainer = self._get_tree_explainer(model_type)
|
||
if explainer is None:
|
||
return {'error': f'No tree model available for {model_type}'}
|
||
|
||
X = self._get_background_sample(n_samples=24)
|
||
if X is None:
|
||
return {'error': 'Failed to prepare background data'}
|
||
|
||
try:
|
||
feature_names = self.selected_features or self.feature_names or []
|
||
if feature_name not in feature_names:
|
||
return {'error': f'Feature {feature_name} not found'}
|
||
|
||
col_idx = list(feature_names).index(feature_name)
|
||
shap_values = explainer.shap_values(X)
|
||
if isinstance(shap_values, list):
|
||
shap_values = shap_values[0]
|
||
|
||
feature_vals = X[:, col_idx].tolist()
|
||
shap_vals = shap_values[:, col_idx].tolist()
|
||
|
||
# 下采样用于可视化
|
||
max_points = 300
|
||
if len(feature_vals) > max_points:
|
||
indices = np.random.RandomState(config.RANDOM_STATE).choice(
|
||
len(feature_vals), max_points, replace=False
|
||
)
|
||
feature_vals = [feature_vals[i] for i in indices]
|
||
shap_vals = [shap_vals[i] for i in indices]
|
||
|
||
name_map = self._get_feature_display_names()
|
||
return {
|
||
'feature': feature_name,
|
||
'feature_cn': name_map.get(feature_name, feature_name),
|
||
'values': [round(v, 4) for v in feature_vals],
|
||
'shap_values': [round(v, 4) for v in shap_vals],
|
||
}
|
||
except Exception as exc:
|
||
return {'error': str(exc)}
|