Reuse global SHAP cache for detail views
This commit is contained in:
@@ -25,6 +25,7 @@ class SHAPAnalyzer:
|
|||||||
self.label_encoders = {}
|
self.label_encoders = {}
|
||||||
self.background_data = {}
|
self.background_data = {}
|
||||||
self.global_result_cache = {}
|
self.global_result_cache = {}
|
||||||
|
self.global_matrix_cache = {}
|
||||||
self._initialized = False
|
self._initialized = False
|
||||||
|
|
||||||
def _ensure_initialized(self):
|
def _ensure_initialized(self):
|
||||||
@@ -168,6 +169,10 @@ class SHAPAnalyzer:
|
|||||||
shap_values = explainer.shap_values(X)
|
shap_values = explainer.shap_values(X)
|
||||||
if isinstance(shap_values, list):
|
if isinstance(shap_values, list):
|
||||||
shap_values = shap_values[0]
|
shap_values = shap_values[0]
|
||||||
|
self.global_matrix_cache[model_type] = {
|
||||||
|
'X': X,
|
||||||
|
'shap_values': shap_values,
|
||||||
|
}
|
||||||
|
|
||||||
mean_abs_shap = np.abs(shap_values).mean(axis=0)
|
mean_abs_shap = np.abs(shap_values).mean(axis=0)
|
||||||
feature_names = self.selected_features or self.feature_names or []
|
feature_names = self.selected_features or self.feature_names or []
|
||||||
@@ -303,35 +308,33 @@ class SHAPAnalyzer:
|
|||||||
return {'error': str(exc)}
|
return {'error': str(exc)}
|
||||||
|
|
||||||
def shap_interaction(self, model_type='random_forest', top_n=10):
|
def shap_interaction(self, model_type='random_forest', top_n=10):
|
||||||
"""计算 SHAP 交互值"""
|
"""计算近似 SHAP 交互强度"""
|
||||||
if not SHAP_AVAILABLE:
|
if not SHAP_AVAILABLE:
|
||||||
return {'error': 'SHAP library not installed'}
|
return {'error': 'SHAP library not installed'}
|
||||||
|
|
||||||
self._ensure_initialized()
|
self._ensure_initialized()
|
||||||
explainer = self._get_tree_explainer(model_type)
|
if model_type not in self.global_matrix_cache:
|
||||||
if explainer is None:
|
result = self.global_shap_values(model_type)
|
||||||
return {'error': f'No tree model available for {model_type}'}
|
if result.get('error'):
|
||||||
|
return result
|
||||||
X = self._get_background_sample(n_samples=12)
|
|
||||||
if X is None:
|
|
||||||
return {'error': 'Failed to prepare background data'}
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
interaction_values = explainer.shap_interaction_values(X)
|
cached = self.global_matrix_cache.get(model_type)
|
||||||
if isinstance(interaction_values, list):
|
if not cached:
|
||||||
interaction_values = interaction_values[0]
|
return {'error': 'Failed to prepare SHAP cache'}
|
||||||
|
shap_values = np.asarray(cached['shap_values'])
|
||||||
mean_interaction = np.abs(interaction_values).mean(axis=0)
|
|
||||||
feature_names = self.selected_features or self.feature_names or []
|
feature_names = self.selected_features or self.feature_names or []
|
||||||
|
if shap_values.ndim != 2 or shap_values.shape[0] < 2:
|
||||||
|
return {'error': 'Not enough SHAP samples for interaction analysis'}
|
||||||
|
|
||||||
# 获取 top_n 特征的交互
|
corr_matrix = np.nan_to_num(np.corrcoef(shap_values, rowvar=False))
|
||||||
mean_abs = np.abs(interaction_values.mean(axis=0))
|
strength_matrix = np.abs(corr_matrix)
|
||||||
np.fill_diagonal(mean_abs, 0)
|
np.fill_diagonal(strength_matrix, 0)
|
||||||
flat_idx = np.argsort(mean_abs.ravel())[::-1][:top_n * 2]
|
flat_idx = np.argsort(strength_matrix.ravel())[::-1][:top_n * 2]
|
||||||
top_pairs = []
|
top_pairs = []
|
||||||
seen = set()
|
seen = set()
|
||||||
for idx in flat_idx:
|
for idx in flat_idx:
|
||||||
i, j = divmod(idx, mean_abs.shape[1])
|
i, j = divmod(idx, strength_matrix.shape[1])
|
||||||
if i >= j:
|
if i >= j:
|
||||||
continue
|
continue
|
||||||
pair_key = (min(i, j), max(i, j))
|
pair_key = (min(i, j), max(i, j))
|
||||||
@@ -346,7 +349,7 @@ class SHAPAnalyzer:
|
|||||||
'feature_1_cn': name_map.get(fi, fi),
|
'feature_1_cn': name_map.get(fi, fi),
|
||||||
'feature_2': fj,
|
'feature_2': fj,
|
||||||
'feature_2_cn': name_map.get(fj, fj),
|
'feature_2_cn': name_map.get(fj, fj),
|
||||||
'strength': round(float(mean_interaction[i, j]), 4),
|
'strength': round(float(strength_matrix[i, j]), 4),
|
||||||
})
|
})
|
||||||
if len(top_pairs) >= top_n:
|
if len(top_pairs) >= top_n:
|
||||||
break
|
break
|
||||||
@@ -364,24 +367,22 @@ class SHAPAnalyzer:
|
|||||||
return {'error': 'SHAP library not installed'}
|
return {'error': 'SHAP library not installed'}
|
||||||
|
|
||||||
self._ensure_initialized()
|
self._ensure_initialized()
|
||||||
explainer = self._get_tree_explainer(model_type)
|
if model_type not in self.global_matrix_cache:
|
||||||
if explainer is None:
|
result = self.global_shap_values(model_type)
|
||||||
return {'error': f'No tree model available for {model_type}'}
|
if result.get('error'):
|
||||||
|
return result
|
||||||
X = self._get_background_sample(n_samples=24)
|
|
||||||
if X is None:
|
|
||||||
return {'error': 'Failed to prepare background data'}
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
cached = self.global_matrix_cache.get(model_type)
|
||||||
|
if not cached:
|
||||||
|
return {'error': 'Failed to prepare SHAP cache'}
|
||||||
feature_names = self.selected_features or self.feature_names or []
|
feature_names = self.selected_features or self.feature_names or []
|
||||||
if feature_name not in feature_names:
|
if feature_name not in feature_names:
|
||||||
return {'error': f'Feature {feature_name} not found'}
|
return {'error': f'Feature {feature_name} not found'}
|
||||||
|
|
||||||
col_idx = list(feature_names).index(feature_name)
|
col_idx = list(feature_names).index(feature_name)
|
||||||
shap_values = explainer.shap_values(X)
|
X = np.asarray(cached['X'])
|
||||||
if isinstance(shap_values, list):
|
shap_values = np.asarray(cached['shap_values'])
|
||||||
shap_values = shap_values[0]
|
|
||||||
|
|
||||||
feature_vals = X[:, col_idx].tolist()
|
feature_vals = X[:, col_idx].tolist()
|
||||||
shap_vals = shap_values[:, col_idx].tolist()
|
shap_vals = shap_values[:, col_idx].tolist()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user