reuse global shap cache for detail views

This commit is contained in:
2026-04-04 07:52:25 +08:00
parent 1ee766720b
commit 61338c0095

View File

@@ -25,6 +25,7 @@ class SHAPAnalyzer:
self.label_encoders = {}
self.background_data = {}
self.global_result_cache = {}
self.global_matrix_cache = {}
self._initialized = False
def _ensure_initialized(self):
@@ -168,6 +169,10 @@ class SHAPAnalyzer:
shap_values = explainer.shap_values(X)
if isinstance(shap_values, list):
shap_values = shap_values[0]
self.global_matrix_cache[model_type] = {
'X': X,
'shap_values': shap_values,
}
mean_abs_shap = np.abs(shap_values).mean(axis=0)
feature_names = self.selected_features or self.feature_names or []
@@ -303,35 +308,33 @@ class SHAPAnalyzer:
return {'error': str(exc)}
def shap_interaction(self, model_type='random_forest', top_n=10):
"""计算 SHAP 交互"""
"""计算近似 SHAP 交互强度"""
if not SHAP_AVAILABLE:
return {'error': 'SHAP library not installed'}
self._ensure_initialized()
explainer = self._get_tree_explainer(model_type)
if explainer is None:
return {'error': f'No tree model available for {model_type}'}
X = self._get_background_sample(n_samples=12)
if X is None:
return {'error': 'Failed to prepare background data'}
if model_type not in self.global_matrix_cache:
result = self.global_shap_values(model_type)
if result.get('error'):
return result
try:
interaction_values = explainer.shap_interaction_values(X)
if isinstance(interaction_values, list):
interaction_values = interaction_values[0]
mean_interaction = np.abs(interaction_values).mean(axis=0)
cached = self.global_matrix_cache.get(model_type)
if not cached:
return {'error': 'Failed to prepare SHAP cache'}
shap_values = np.asarray(cached['shap_values'])
feature_names = self.selected_features or self.feature_names or []
if shap_values.ndim != 2 or shap_values.shape[0] < 2:
return {'error': 'Not enough SHAP samples for interaction analysis'}
# 获取 top_n 特征的交互
mean_abs = np.abs(interaction_values.mean(axis=0))
np.fill_diagonal(mean_abs, 0)
flat_idx = np.argsort(mean_abs.ravel())[::-1][:top_n * 2]
corr_matrix = np.nan_to_num(np.corrcoef(shap_values, rowvar=False))
strength_matrix = np.abs(corr_matrix)
np.fill_diagonal(strength_matrix, 0)
flat_idx = np.argsort(strength_matrix.ravel())[::-1][:top_n * 2]
top_pairs = []
seen = set()
for idx in flat_idx:
i, j = divmod(idx, mean_abs.shape[1])
i, j = divmod(idx, strength_matrix.shape[1])
if i >= j:
continue
pair_key = (min(i, j), max(i, j))
@@ -346,7 +349,7 @@ class SHAPAnalyzer:
'feature_1_cn': name_map.get(fi, fi),
'feature_2': fj,
'feature_2_cn': name_map.get(fj, fj),
'strength': round(float(mean_interaction[i, j]), 4),
'strength': round(float(strength_matrix[i, j]), 4),
})
if len(top_pairs) >= top_n:
break
@@ -364,24 +367,22 @@ class SHAPAnalyzer:
return {'error': 'SHAP library not installed'}
self._ensure_initialized()
explainer = self._get_tree_explainer(model_type)
if explainer is None:
return {'error': f'No tree model available for {model_type}'}
X = self._get_background_sample(n_samples=24)
if X is None:
return {'error': 'Failed to prepare background data'}
if model_type not in self.global_matrix_cache:
result = self.global_shap_values(model_type)
if result.get('error'):
return result
try:
cached = self.global_matrix_cache.get(model_type)
if not cached:
return {'error': 'Failed to prepare SHAP cache'}
feature_names = self.selected_features or self.feature_names or []
if feature_name not in feature_names:
return {'error': f'Feature {feature_name} not found'}
col_idx = list(feature_names).index(feature_name)
shap_values = explainer.shap_values(X)
if isinstance(shap_values, list):
shap_values = shap_values[0]
X = np.asarray(cached['X'])
shap_values = np.asarray(cached['shap_values'])
feature_vals = X[:, col_idx].tolist()
shap_vals = shap_values[:, col_idx].tolist()