75 lines
2.5 KiB
Python
75 lines
2.5 KiB
Python
import json
|
|
import os
|
|
|
|
import config
|
|
from core.shap_analysis import SHAPAnalyzer
|
|
|
|
|
|
class SHAPService:
    """SHAP explainability analysis service.

    Global importance, interaction and dependence results are read from
    per-model JSON cache files located under ``config.SHAP_CACHE_DIR``.
    Local explanations are delegated to a lazily created ``SHAPAnalyzer``.

    The cache-backed methods never raise for a missing or malformed cache;
    they return a dict of the form ``{'error': <message>}`` instead.
    """

    def __init__(self):
        # Built on first use by _ensure_analyzer().
        self._analyzer = None

    def _ensure_analyzer(self):
        """Create the ``SHAPAnalyzer`` the first time it is needed."""
        if self._analyzer is None:
            self._analyzer = SHAPAnalyzer()

    def _get_cache_path(self, model_type):
        """Return the path of the JSON cache file for ``model_type``."""
        return os.path.join(config.SHAP_CACHE_DIR, f'{model_type}.json')

    @staticmethod
    def _missing_cache_error(model_type):
        """Build the error payload returned when no usable cache exists."""
        return {
            'error': f'SHAP cache not found for {model_type}. '
                     f'Run backend/core/generate_shap_cache.py first.'
        }

    def _load_cache(self, model_type):
        """Load the cached SHAP results for ``model_type``.

        Returns:
            The parsed cache as a dict, or ``None`` when the file is
            missing, cannot be read, is not valid JSON, or its top-level
            value is not a JSON object.
        """
        cache_path = self._get_cache_path(model_type)
        try:
            with open(cache_path, 'r', encoding='utf-8') as fp:
                cache = json.load(fp)
        except (OSError, ValueError):
            # OSError: missing/unreadable file. ValueError: covers both
            # json.JSONDecodeError and UnicodeDecodeError.
            return None
        # Callers use dict lookups on the result, so any other JSON value
        # (list, string, number, ...) is treated as an invalid cache.
        return cache if isinstance(cache, dict) else None

    def get_global_importance(self, model_type='random_forest'):
        """Return the cached ``'global'`` section for ``model_type``.

        Returns an ``{'error': ...}`` dict when the cache cannot be loaded
        or has no ``'global'`` key.
        """
        cache = self._load_cache(model_type)
        if not cache:
            return self._missing_cache_error(model_type)
        return cache.get('global', {'error': f'Invalid SHAP cache for {model_type}'})

    def get_local_explanation(self, data, model_type='random_forest'):
        """Return ``SHAPAnalyzer.local_shap_values(data, model_type)``."""
        self._ensure_analyzer()
        return self._analyzer.local_shap_values(data, model_type)

    def get_interactions(self, model_type='random_forest', top_n=10):
        """Return the cached ``'interaction'`` section for ``model_type``.

        Args:
            model_type: Name of the model whose cache file is read.
            top_n: Maximum number of ``'top_interactions'`` entries to keep.
                ``None``, ``0`` or a negative value disables truncation.

        Returns:
            The interaction data (a shallow copy with a shortened
            ``'top_interactions'`` list when truncation applies), or an
            ``{'error': ...}`` dict when the cache or section is missing.
        """
        cache = self._load_cache(model_type)
        if not cache:
            return self._missing_cache_error(model_type)

        data = cache.get('interaction')
        if not data:
            return {'error': f'Interaction cache missing for {model_type}'}

        # Only truncate when the section really is a dict holding a list;
        # anything else is returned untouched instead of raising.
        top = data.get('top_interactions') if isinstance(data, dict) else None
        if top_n and top_n > 0 and isinstance(top, list) and top:
            # Copy so the loaded cache dict is not modified.
            result = dict(data)
            result['top_interactions'] = top[:top_n]
            return result
        return data

    def get_dependence(self, feature_name, model_type='random_forest'):
        """Return the cached dependence data for ``feature_name``.

        Returns an ``{'error': ...}`` dict when the cache cannot be loaded
        or holds no (non-empty) entry for the feature.
        """
        cache = self._load_cache(model_type)
        if not cache:
            return self._missing_cache_error(model_type)

        # The key may be absent or hold null / a non-object value.
        dependence_map = cache.get('dependence')
        data = (
            dependence_map.get(feature_name)
            if isinstance(dependence_map, dict)
            else None
        )
        if data:
            return data
        return {'error': f'Dependence cache missing for feature {feature_name}'}
|
|
|
|
|
|
# Module-level SHAPService instance, created when this module is imported.
shap_service = SHAPService()
|