reuse global shap cache for detail views
This commit is contained in:
@@ -25,6 +25,7 @@ class SHAPAnalyzer:
|
||||
self.label_encoders = {}
|
||||
self.background_data = {}
|
||||
self.global_result_cache = {}
|
||||
self.global_matrix_cache = {}
|
||||
self._initialized = False
|
||||
|
||||
def _ensure_initialized(self):
|
||||
@@ -168,6 +169,10 @@ class SHAPAnalyzer:
|
||||
shap_values = explainer.shap_values(X)
|
||||
if isinstance(shap_values, list):
|
||||
shap_values = shap_values[0]
|
||||
self.global_matrix_cache[model_type] = {
|
||||
'X': X,
|
||||
'shap_values': shap_values,
|
||||
}
|
||||
|
||||
mean_abs_shap = np.abs(shap_values).mean(axis=0)
|
||||
feature_names = self.selected_features or self.feature_names or []
|
||||
@@ -303,35 +308,33 @@ class SHAPAnalyzer:
|
||||
return {'error': str(exc)}
|
||||
|
||||
def shap_interaction(self, model_type='random_forest', top_n=10):
|
||||
"""计算 SHAP 交互值"""
|
||||
"""计算近似 SHAP 交互强度"""
|
||||
if not SHAP_AVAILABLE:
|
||||
return {'error': 'SHAP library not installed'}
|
||||
|
||||
self._ensure_initialized()
|
||||
explainer = self._get_tree_explainer(model_type)
|
||||
if explainer is None:
|
||||
return {'error': f'No tree model available for {model_type}'}
|
||||
|
||||
X = self._get_background_sample(n_samples=12)
|
||||
if X is None:
|
||||
return {'error': 'Failed to prepare background data'}
|
||||
if model_type not in self.global_matrix_cache:
|
||||
result = self.global_shap_values(model_type)
|
||||
if result.get('error'):
|
||||
return result
|
||||
|
||||
try:
|
||||
interaction_values = explainer.shap_interaction_values(X)
|
||||
if isinstance(interaction_values, list):
|
||||
interaction_values = interaction_values[0]
|
||||
|
||||
mean_interaction = np.abs(interaction_values).mean(axis=0)
|
||||
cached = self.global_matrix_cache.get(model_type)
|
||||
if not cached:
|
||||
return {'error': 'Failed to prepare SHAP cache'}
|
||||
shap_values = np.asarray(cached['shap_values'])
|
||||
feature_names = self.selected_features or self.feature_names or []
|
||||
if shap_values.ndim != 2 or shap_values.shape[0] < 2:
|
||||
return {'error': 'Not enough SHAP samples for interaction analysis'}
|
||||
|
||||
# 获取 top_n 特征的交互
|
||||
mean_abs = np.abs(interaction_values.mean(axis=0))
|
||||
np.fill_diagonal(mean_abs, 0)
|
||||
flat_idx = np.argsort(mean_abs.ravel())[::-1][:top_n * 2]
|
||||
corr_matrix = np.nan_to_num(np.corrcoef(shap_values, rowvar=False))
|
||||
strength_matrix = np.abs(corr_matrix)
|
||||
np.fill_diagonal(strength_matrix, 0)
|
||||
flat_idx = np.argsort(strength_matrix.ravel())[::-1][:top_n * 2]
|
||||
top_pairs = []
|
||||
seen = set()
|
||||
for idx in flat_idx:
|
||||
i, j = divmod(idx, mean_abs.shape[1])
|
||||
i, j = divmod(idx, strength_matrix.shape[1])
|
||||
if i >= j:
|
||||
continue
|
||||
pair_key = (min(i, j), max(i, j))
|
||||
@@ -346,7 +349,7 @@ class SHAPAnalyzer:
|
||||
'feature_1_cn': name_map.get(fi, fi),
|
||||
'feature_2': fj,
|
||||
'feature_2_cn': name_map.get(fj, fj),
|
||||
'strength': round(float(mean_interaction[i, j]), 4),
|
||||
'strength': round(float(strength_matrix[i, j]), 4),
|
||||
})
|
||||
if len(top_pairs) >= top_n:
|
||||
break
|
||||
@@ -364,24 +367,22 @@ class SHAPAnalyzer:
|
||||
return {'error': 'SHAP library not installed'}
|
||||
|
||||
self._ensure_initialized()
|
||||
explainer = self._get_tree_explainer(model_type)
|
||||
if explainer is None:
|
||||
return {'error': f'No tree model available for {model_type}'}
|
||||
|
||||
X = self._get_background_sample(n_samples=24)
|
||||
if X is None:
|
||||
return {'error': 'Failed to prepare background data'}
|
||||
if model_type not in self.global_matrix_cache:
|
||||
result = self.global_shap_values(model_type)
|
||||
if result.get('error'):
|
||||
return result
|
||||
|
||||
try:
|
||||
cached = self.global_matrix_cache.get(model_type)
|
||||
if not cached:
|
||||
return {'error': 'Failed to prepare SHAP cache'}
|
||||
feature_names = self.selected_features or self.feature_names or []
|
||||
if feature_name not in feature_names:
|
||||
return {'error': f'Feature {feature_name} not found'}
|
||||
|
||||
col_idx = list(feature_names).index(feature_name)
|
||||
shap_values = explainer.shap_values(X)
|
||||
if isinstance(shap_values, list):
|
||||
shap_values = shap_values[0]
|
||||
|
||||
X = np.asarray(cached['X'])
|
||||
shap_values = np.asarray(cached['shap_values'])
|
||||
feature_vals = X[:, col_idx].tolist()
|
||||
shap_vals = shap_values[:, col_idx].tolist()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user