import json
import logging
import os
import uuid

import config
from core.shap_analysis import SHAPAnalyzer

logger = logging.getLogger(__name__)


class SHAPService:
    """SHAP explainability analysis service.

    Global importance, dependence and interaction results are computed once per
    ``model_type`` and cached as a JSON file under ``config.SHAP_CACHE_DIR``;
    local (per-sample) explanations are always computed on demand. The
    underlying ``SHAPAnalyzer`` is only constructed when first needed.
    """

    # How many top-ranked global features get dependence data precomputed.
    _DEPENDENCE_TOP_K = 15
    # How many feature interactions are requested when building the cache.
    _INTERACTION_TOP_N = 10

    def __init__(self):
        self._analyzer = None

    def _ensure_analyzer(self):
        """Create the SHAPAnalyzer on first use (construction is deferred)."""
        if self._analyzer is None:
            self._analyzer = SHAPAnalyzer()

    def _get_cache_path(self, model_type):
        """Return the JSON cache file path for ``model_type``."""
        # NOTE(review): model_type is interpolated into a file path unchecked.
        # If it can originate from request input, values containing path
        # separators or '..' could escape SHAP_CACHE_DIR — confirm callers
        # restrict it to known model names.
        return os.path.join(config.SHAP_CACHE_DIR, f'{model_type}.json')

    def _load_cache(self, model_type):
        """Read the cached payload for ``model_type``.

        Returns:
            The decoded dict, or None when the cache file is missing,
            unreadable, not valid JSON, or does not hold a JSON object.
        """
        cache_path = self._get_cache_path(model_type)
        try:
            with open(cache_path, 'r', encoding='utf-8') as fp:
                data = json.load(fp)
        except FileNotFoundError:
            return None
        except (OSError, ValueError):
            # ValueError covers JSONDecodeError and UnicodeDecodeError. A corrupt
            # cache is treated as a miss so that it gets rebuilt.
            logger.warning('Ignoring unreadable SHAP cache %s', cache_path,
                           exc_info=True)
            return None
        # Callers use dict methods on the result; anything else is a miss.
        return data if isinstance(data, dict) else None

    def _save_cache(self, model_type, payload):
        """Atomically write ``payload`` as the cache for ``model_type``.

        The JSON is first written to a uniquely named temporary file next to
        the target and then moved into place with ``os.replace``. Concurrent
        readers therefore never observe a partially written file, and a failed
        dump (e.g. a non-serializable value) leaves any previous cache intact.

        Raises:
            OSError if the directory or file cannot be written; TypeError or
            ValueError if ``payload`` is not JSON-serializable.
        """
        os.makedirs(config.SHAP_CACHE_DIR, exist_ok=True)
        cache_path = self._get_cache_path(model_type)
        # Same directory as the target so os.replace stays a same-filesystem
        # rename; plain open() keeps the default (umask-based) file permissions.
        tmp_path = f'{cache_path}.{uuid.uuid4().hex}.tmp'
        try:
            with open(tmp_path, 'w', encoding='utf-8') as fp:
                json.dump(payload, fp, ensure_ascii=False)
            os.replace(tmp_path, cache_path)
        except BaseException:
            # Do not leave orphaned temp files behind; re-raise the original error.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    def _build_cache_payload(self, model_type):
        """Compute the full cacheable SHAP payload for ``model_type``.

        Returns:
            ``{'model_type', 'global', 'dependence', 'interaction'}`` where
            ``dependence`` maps feature name -> dependence data for those top
            global features whose dependence computation succeeded; or
            ``{'error': <message>}`` if the global or interaction analysis
            reported an error.
        """
        self._ensure_analyzer()
        global_data = self._analyzer.global_shap_values(model_type)
        if global_data.get('error'):
            return {'error': global_data['error']}

        # `or []` also tolerates an explicit None for 'top_features'.
        ranked = global_data.get('top_features') or []
        top_features = [item['name'] for item in ranked[:self._DEPENDENCE_TOP_K]]

        dependence = {}
        for feature_name in top_features:
            data = self._analyzer.shap_dependence(feature_name, model_type)
            # A failing feature is skipped rather than failing the whole payload.
            if not data.get('error'):
                dependence[feature_name] = data

        interaction = self._analyzer.shap_interaction(
            model_type, top_n=self._INTERACTION_TOP_N)
        if interaction.get('error'):
            return {'error': interaction['error']}

        return {
            'model_type': model_type,
            'global': global_data,
            'dependence': dependence,
            'interaction': interaction,
        }

    def _ensure_cache(self, model_type):
        """Return the cached payload for ``model_type``, building it on a miss.

        Error results are never persisted, so a later call retries the build.
        If persisting a freshly built payload fails, the failure is logged and
        the payload is still returned.

        Returns:
            The payload dict, or ``{'error': <message>}`` if building failed.
        """
        cache = self._load_cache(model_type)
        if cache:
            return cache

        payload = self._build_cache_payload(model_type)
        if payload.get('error'):
            return {
                'error': f'{model_type} 的贡献解释数据暂时不可用:{payload["error"]}'
            }
        try:
            self._save_cache(model_type, payload)
        except Exception:
            # If the cache write fails, still let the current request return
            # its result.
            logger.warning('Failed to write SHAP cache for %s', model_type,
                           exc_info=True)
        return payload

    def get_global_importance(self, model_type='random_forest'):
        """Return the cached global SHAP importance data for ``model_type``.

        Returns an ``{'error': ...}`` dict if the cache could not be built or
        lacks the 'global' section.
        """
        cache = self._ensure_cache(model_type)
        if cache.get('error'):
            return cache
        return cache.get('global', {'error': f'Invalid SHAP cache for {model_type}'})

    def get_local_explanation(self, data, model_type='random_forest'):
        """Compute a per-sample SHAP explanation for ``data`` (never cached)."""
        self._ensure_analyzer()
        return self._analyzer.local_shap_values(data, model_type)

    def get_interactions(self, model_type='random_forest', top_n=10):
        """Return cached interaction data, truncating 'top_interactions' to ``top_n``.

        A falsy ``top_n`` returns the cached list untruncated. The cached dict
        is shallow-copied before truncation.
        """
        cache = self._ensure_cache(model_type)
        if cache.get('error'):
            return cache
        data = cache.get('interaction')
        if not data:
            return {'error': f'Interaction cache missing for {model_type}'}
        if top_n and data.get('top_interactions'):
            result = dict(data)
            result['top_interactions'] = data['top_interactions'][:top_n]
            return result
        return data

    def get_dependence(self, feature_name, model_type='random_forest'):
        """Return SHAP dependence data for ``feature_name``.

        Served from the cache when present; otherwise computed on demand and
        added to the cache (best-effort — a write failure is only logged).
        Returns an ``{'error': ...}`` dict on failure.
        """
        cache = self._ensure_cache(model_type)
        if cache.get('error'):
            return cache
        dependence_map = cache.get('dependence', {})
        data = dependence_map.get(feature_name)
        if data:
            return data

        self._ensure_analyzer()
        data = self._analyzer.shap_dependence(feature_name, model_type)
        if data.get('error'):
            return {'error': f'特征 {feature_name} 的依赖解释不可用:{data["error"]}'}

        dependence_map[feature_name] = data
        cache['dependence'] = dependence_map
        try:
            self._save_cache(model_type, cache)
        except Exception:
            # The computed result is still returned even if it cannot be cached.
            logger.warning('Failed to update SHAP cache for %s', model_type,
                           exc_info=True)
        return data


# Module-level shared service instance.
shap_service = SHAPService()