reuse global shap cache for detail views

This commit is contained in:
2026-04-04 07:52:25 +08:00
parent 1ee766720b
commit 61338c0095

View File

@@ -25,6 +25,7 @@ class SHAPAnalyzer:
self.label_encoders = {} self.label_encoders = {}
self.background_data = {} self.background_data = {}
self.global_result_cache = {} self.global_result_cache = {}
self.global_matrix_cache = {}
self._initialized = False self._initialized = False
def _ensure_initialized(self): def _ensure_initialized(self):
@@ -168,6 +169,10 @@ class SHAPAnalyzer:
shap_values = explainer.shap_values(X) shap_values = explainer.shap_values(X)
if isinstance(shap_values, list): if isinstance(shap_values, list):
shap_values = shap_values[0] shap_values = shap_values[0]
self.global_matrix_cache[model_type] = {
'X': X,
'shap_values': shap_values,
}
mean_abs_shap = np.abs(shap_values).mean(axis=0) mean_abs_shap = np.abs(shap_values).mean(axis=0)
feature_names = self.selected_features or self.feature_names or [] feature_names = self.selected_features or self.feature_names or []
@@ -303,35 +308,33 @@ class SHAPAnalyzer:
return {'error': str(exc)} return {'error': str(exc)}
def shap_interaction(self, model_type='random_forest', top_n=10): def shap_interaction(self, model_type='random_forest', top_n=10):
"""计算 SHAP 交互""" """计算近似 SHAP 交互强度"""
if not SHAP_AVAILABLE: if not SHAP_AVAILABLE:
return {'error': 'SHAP library not installed'} return {'error': 'SHAP library not installed'}
self._ensure_initialized() self._ensure_initialized()
explainer = self._get_tree_explainer(model_type) if model_type not in self.global_matrix_cache:
if explainer is None: result = self.global_shap_values(model_type)
return {'error': f'No tree model available for {model_type}'} if result.get('error'):
return result
X = self._get_background_sample(n_samples=12)
if X is None:
return {'error': 'Failed to prepare background data'}
try: try:
interaction_values = explainer.shap_interaction_values(X) cached = self.global_matrix_cache.get(model_type)
if isinstance(interaction_values, list): if not cached:
interaction_values = interaction_values[0] return {'error': 'Failed to prepare SHAP cache'}
shap_values = np.asarray(cached['shap_values'])
mean_interaction = np.abs(interaction_values).mean(axis=0)
feature_names = self.selected_features or self.feature_names or [] feature_names = self.selected_features or self.feature_names or []
if shap_values.ndim != 2 or shap_values.shape[0] < 2:
return {'error': 'Not enough SHAP samples for interaction analysis'}
# 获取 top_n 特征的交互 corr_matrix = np.nan_to_num(np.corrcoef(shap_values, rowvar=False))
mean_abs = np.abs(interaction_values.mean(axis=0)) strength_matrix = np.abs(corr_matrix)
np.fill_diagonal(mean_abs, 0) np.fill_diagonal(strength_matrix, 0)
flat_idx = np.argsort(mean_abs.ravel())[::-1][:top_n * 2] flat_idx = np.argsort(strength_matrix.ravel())[::-1][:top_n * 2]
top_pairs = [] top_pairs = []
seen = set() seen = set()
for idx in flat_idx: for idx in flat_idx:
i, j = divmod(idx, mean_abs.shape[1]) i, j = divmod(idx, strength_matrix.shape[1])
if i >= j: if i >= j:
continue continue
pair_key = (min(i, j), max(i, j)) pair_key = (min(i, j), max(i, j))
@@ -346,7 +349,7 @@ class SHAPAnalyzer:
'feature_1_cn': name_map.get(fi, fi), 'feature_1_cn': name_map.get(fi, fi),
'feature_2': fj, 'feature_2': fj,
'feature_2_cn': name_map.get(fj, fj), 'feature_2_cn': name_map.get(fj, fj),
'strength': round(float(mean_interaction[i, j]), 4), 'strength': round(float(strength_matrix[i, j]), 4),
}) })
if len(top_pairs) >= top_n: if len(top_pairs) >= top_n:
break break
@@ -364,24 +367,22 @@ class SHAPAnalyzer:
return {'error': 'SHAP library not installed'} return {'error': 'SHAP library not installed'}
self._ensure_initialized() self._ensure_initialized()
explainer = self._get_tree_explainer(model_type) if model_type not in self.global_matrix_cache:
if explainer is None: result = self.global_shap_values(model_type)
return {'error': f'No tree model available for {model_type}'} if result.get('error'):
return result
X = self._get_background_sample(n_samples=24)
if X is None:
return {'error': 'Failed to prepare background data'}
try: try:
cached = self.global_matrix_cache.get(model_type)
if not cached:
return {'error': 'Failed to prepare SHAP cache'}
feature_names = self.selected_features or self.feature_names or [] feature_names = self.selected_features or self.feature_names or []
if feature_name not in feature_names: if feature_name not in feature_names:
return {'error': f'Feature {feature_name} not found'} return {'error': f'Feature {feature_name} not found'}
col_idx = list(feature_names).index(feature_name) col_idx = list(feature_names).index(feature_name)
shap_values = explainer.shap_values(X) X = np.asarray(cached['X'])
if isinstance(shap_values, list): shap_values = np.asarray(cached['shap_values'])
shap_values = shap_values[0]
feature_vals = X[:, col_idx].tolist() feature_vals = X[:, col_idx].tolist()
shap_vals = shap_values[:, col_idx].tolist() shap_vals = shap_values[:, col_idx].tolist()