diff --git a/backend/core/shap_analysis.py b/backend/core/shap_analysis.py index a3995f4..aaf0218 100644 --- a/backend/core/shap_analysis.py +++ b/backend/core/shap_analysis.py @@ -25,6 +25,7 @@ class SHAPAnalyzer: self.label_encoders = {} self.background_data = {} self.global_result_cache = {} + self.global_matrix_cache = {} self._initialized = False def _ensure_initialized(self): @@ -168,6 +169,10 @@ class SHAPAnalyzer: shap_values = explainer.shap_values(X) if isinstance(shap_values, list): shap_values = shap_values[0] + self.global_matrix_cache[model_type] = { + 'X': X, + 'shap_values': shap_values, + } mean_abs_shap = np.abs(shap_values).mean(axis=0) feature_names = self.selected_features or self.feature_names or [] @@ -303,35 +308,33 @@ class SHAPAnalyzer: return {'error': str(exc)} def shap_interaction(self, model_type='random_forest', top_n=10): - """计算 SHAP 交互值""" + """计算近似 SHAP 交互强度""" if not SHAP_AVAILABLE: return {'error': 'SHAP library not installed'} self._ensure_initialized() - explainer = self._get_tree_explainer(model_type) - if explainer is None: - return {'error': f'No tree model available for {model_type}'} - - X = self._get_background_sample(n_samples=12) - if X is None: - return {'error': 'Failed to prepare background data'} + if model_type not in self.global_matrix_cache: + result = self.global_shap_values(model_type) + if result.get('error'): + return result try: - interaction_values = explainer.shap_interaction_values(X) - if isinstance(interaction_values, list): - interaction_values = interaction_values[0] - - mean_interaction = np.abs(interaction_values).mean(axis=0) + cached = self.global_matrix_cache.get(model_type) + if not cached: + return {'error': 'Failed to prepare SHAP cache'} + shap_values = np.asarray(cached['shap_values']) feature_names = self.selected_features or self.feature_names or [] + if shap_values.ndim != 2 or shap_values.shape[0] < 2: + return {'error': 'Not enough SHAP samples for interaction analysis'} - # 获取 top_n 特征的交互 - mean_abs = np.abs(interaction_values.mean(axis=0)) - np.fill_diagonal(mean_abs, 0) - flat_idx = np.argsort(mean_abs.ravel())[::-1][:top_n * 2] + corr_matrix = np.nan_to_num(np.corrcoef(shap_values, rowvar=False)) + strength_matrix = np.abs(corr_matrix) + np.fill_diagonal(strength_matrix, 0) + flat_idx = np.argsort(strength_matrix.ravel())[::-1][:top_n * 2] top_pairs = [] seen = set() for idx in flat_idx: - i, j = divmod(idx, mean_abs.shape[1]) + i, j = divmod(idx, strength_matrix.shape[1]) if i >= j: continue pair_key = (min(i, j), max(i, j)) @@ -346,7 +349,7 @@ class SHAPAnalyzer: 'feature_1_cn': name_map.get(fi, fi), 'feature_2': fj, 'feature_2_cn': name_map.get(fj, fj), - 'strength': round(float(mean_interaction[i, j]), 4), + 'strength': round(float(strength_matrix[i, j]), 4), }) if len(top_pairs) >= top_n: break @@ -364,24 +367,22 @@ class SHAPAnalyzer: return {'error': 'SHAP library not installed'} self._ensure_initialized() - explainer = self._get_tree_explainer(model_type) - if explainer is None: - return {'error': f'No tree model available for {model_type}'} - - X = self._get_background_sample(n_samples=24) - if X is None: - return {'error': 'Failed to prepare background data'} + if model_type not in self.global_matrix_cache: + result = self.global_shap_values(model_type) + if result.get('error'): + return result try: + cached = self.global_matrix_cache.get(model_type) + if not cached: + return {'error': 'Failed to prepare SHAP cache'} feature_names = self.selected_features or self.feature_names or [] if feature_name not in feature_names: return {'error': f'Feature {feature_name} not found'} col_idx = list(feature_names).index(feature_name) - shap_values = explainer.shap_values(X) - if isinstance(shap_values, list): - shap_values = shap_values[0] - + X = np.asarray(cached['X']) + shap_values = np.asarray(cached['shap_values']) feature_vals = X[:, col_idx].tolist() shap_vals = shap_values[:, col_idx].tolist()