switch shap endpoints to offline cache

This commit is contained in:
2026-04-04 07:57:19 +08:00
parent 61338c0095
commit 5655eb0cda
3 changed files with 113 additions and 6 deletions

View File

@@ -1,3 +1,7 @@
import json
import os
import config
from core.shap_analysis import SHAPAnalyzer
@@ -11,21 +15,60 @@ class SHAPService:
if self._analyzer is None:
self._analyzer = SHAPAnalyzer()
def _get_cache_path(self, model_type):
return os.path.join(config.SHAP_CACHE_DIR, f'{model_type}.json')
def _load_cache(self, model_type):
cache_path = self._get_cache_path(model_type)
if not os.path.exists(cache_path):
return None
try:
with open(cache_path, 'r', encoding='utf-8') as fp:
return json.load(fp)
except Exception:
return None
def get_global_importance(self, model_type='random_forest'):
self._ensure_analyzer()
return self._analyzer.global_shap_values(model_type)
cache = self._load_cache(model_type)
if not cache:
return {
'error': f'SHAP cache not found for {model_type}. '
f'Run backend/core/generate_shap_cache.py first.'
}
return cache.get('global', {'error': f'Invalid SHAP cache for {model_type}'})
def get_local_explanation(self, data, model_type='random_forest'):
self._ensure_analyzer()
return self._analyzer.local_shap_values(data, model_type)
def get_interactions(self, model_type='random_forest', top_n=10):
self._ensure_analyzer()
return self._analyzer.shap_interaction(model_type, top_n)
cache = self._load_cache(model_type)
if not cache:
return {
'error': f'SHAP cache not found for {model_type}. '
f'Run backend/core/generate_shap_cache.py first.'
}
data = cache.get('interaction')
if not data:
return {'error': f'Interaction cache missing for {model_type}'}
if top_n and data.get('top_interactions'):
result = dict(data)
result['top_interactions'] = data['top_interactions'][:top_n]
return result
return data
def get_dependence(self, feature_name, model_type='random_forest'):
self._ensure_analyzer()
return self._analyzer.shap_dependence(feature_name, model_type)
cache = self._load_cache(model_type)
if not cache:
return {
'error': f'SHAP cache not found for {model_type}. '
f'Run backend/core/generate_shap_cache.py first.'
}
dependence_map = cache.get('dependence', {})
data = dependence_map.get(feature_name)
if data:
return data
return {'error': f'Dependence cache missing for feature {feature_name}'}
shap_service = SHAPService()