Files
forsetsystem/backend/services/shap_service.py
2026-04-27 11:59:35 +08:00

126 lines
4.1 KiB
Python

import json
import logging
import os
import uuid

import config
from core.shap_analysis import SHAPAnalyzer
class SHAPService:
    """SHAP explainability analysis service backed by per-model JSON caches.

    Global importance, per-feature dependence and interaction results are
    obtained from a ``SHAPAnalyzer`` the first time a model type is requested
    and persisted as ``<model_type>.json`` inside the cache directory; later
    requests are answered from that file. Local explanations are always
    delegated to the analyzer and are never cached.

    Analyzer results are treated as dict-like objects in which a truthy
    ``'error'`` entry signals failure. Such failures are returned to the
    caller as ``{'error': <message>}`` instead of being raised. Exceptions
    raised by the analyzer itself are not caught here.
    """

    # Number of top-ranked features whose dependence data is precomputed.
    TOP_FEATURE_LIMIT = 15
    # Number of interactions requested from the analyzer when building a cache.
    INTERACTION_TOP_N = 10

    _logger = logging.getLogger(__name__)

    def __init__(self, cache_dir=None, analyzer=None):
        """Create the service without touching the disk or the analyzer.

        Args:
            cache_dir: Directory holding the cache files. ``None`` (default)
                means ``config.SHAP_CACHE_DIR`` is read each time it is needed.
            analyzer: Object providing the analyzer methods used by this
                class. ``None`` (default) means a ``SHAPAnalyzer`` is created
                lazily on first use.
        """
        self._cache_dir = cache_dir
        self._analyzer = analyzer

    def _ensure_analyzer(self):
        """Instantiate the analyzer on first use."""
        if self._analyzer is None:
            self._analyzer = SHAPAnalyzer()

    def _resolve_cache_dir(self):
        """Return the explicit cache directory, or the configured default."""
        if self._cache_dir is not None:
            return self._cache_dir
        return config.SHAP_CACHE_DIR

    def _get_cache_path(self, model_type):
        """Return the path of the JSON cache file for ``model_type``."""
        # NOTE(review): model_type is interpolated into the file name as-is;
        # confirm callers restrict it to known model identifiers.
        return os.path.join(self._resolve_cache_dir(), f'{model_type}.json')

    def _load_cache(self, model_type):
        """Read the cached payload for ``model_type``.

        Returns:
            The decoded dict, or ``None`` when the file is missing, cannot be
            read or decoded, or does not contain a JSON object.
        """
        cache_path = self._get_cache_path(model_type)
        try:
            with open(cache_path, 'r', encoding='utf-8') as fp:
                cache = json.load(fp)
        except FileNotFoundError:
            # A missing file is the normal "not cached yet" case.
            return None
        except (OSError, ValueError):
            # ValueError covers both JSON and text decoding failures.
            self._logger.warning(
                'Ignoring unreadable SHAP cache %s', cache_path, exc_info=True
            )
            return None
        if not isinstance(cache, dict):
            # Every caller uses ``.get``; anything but an object is unusable.
            self._logger.warning(
                'Ignoring SHAP cache %s: expected a JSON object, got %s',
                cache_path,
                type(cache).__name__,
            )
            return None
        return cache

    def _save_cache(self, model_type, payload):
        """Write ``payload`` to the cache file for ``model_type``.

        The data is written to a temporary sibling file and then moved into
        place, so a failed or interrupted write never leaves a truncated
        cache file behind.

        Raises:
            OSError: The directory or file could not be written.
            TypeError, ValueError: ``payload`` is not JSON serializable.
        """
        os.makedirs(self._resolve_cache_dir(), exist_ok=True)
        cache_path = self._get_cache_path(model_type)
        tmp_path = f'{cache_path}.{uuid.uuid4().hex}.tmp'
        try:
            with open(tmp_path, 'w', encoding='utf-8') as fp:
                json.dump(payload, fp, ensure_ascii=False)
            os.replace(tmp_path, cache_path)
        except BaseException:
            # Remove the partial temporary file, then let the error propagate.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    def _try_save_cache(self, model_type, payload):
        """Persist ``payload`` on a best-effort basis, logging any failure."""
        try:
            self._save_cache(model_type, payload)
        except Exception:
            # A failed cache write must not break the current request.
            self._logger.warning(
                'Could not write SHAP cache for %s', model_type, exc_info=True
            )

    def _build_cache_payload(self, model_type):
        """Compute the full cache payload for ``model_type``.

        Returns:
            A dict with the keys ``model_type``, ``global``, ``dependence``
            and ``interaction``, or ``{'error': ...}`` when the global or the
            interaction analysis reports an error. Features whose dependence
            analysis reports an error are left out of ``dependence``.
        """
        self._ensure_analyzer()
        global_data = self._analyzer.global_shap_values(model_type)
        if global_data.get('error'):
            return {'error': global_data['error']}

        ranked = global_data.get('top_features', [])[:self.TOP_FEATURE_LIMIT]
        dependence = {}
        for item in ranked:
            feature_name = item['name']
            data = self._analyzer.shap_dependence(feature_name, model_type)
            if not data.get('error'):
                dependence[feature_name] = data

        interaction = self._analyzer.shap_interaction(
            model_type, top_n=self.INTERACTION_TOP_N
        )
        if interaction.get('error'):
            return {'error': interaction['error']}

        return {
            'model_type': model_type,
            'global': global_data,
            'dependence': dependence,
            'interaction': interaction,
        }

    def _ensure_cache(self, model_type):
        """Return the cached payload for ``model_type``, building it if needed.

        Returns:
            The payload dict, or ``{'error': ...}`` when it cannot be built.
            A freshly built payload is returned even if it could not be saved.
        """
        cache = self._load_cache(model_type)
        if cache:
            return cache

        payload = self._build_cache_payload(model_type)
        if payload.get('error'):
            return {
                'error': f'{model_type} 的贡献解释数据暂时不可用：{payload["error"]}'
            }

        self._try_save_cache(model_type, payload)
        return payload

    def get_global_importance(self, model_type='random_forest'):
        """Return the global SHAP data for ``model_type`` or an error dict."""
        cache = self._ensure_cache(model_type)
        if cache.get('error'):
            return cache
        return cache.get('global', {'error': f'Invalid SHAP cache for {model_type}'})

    def get_local_explanation(self, data, model_type='random_forest'):
        """Return the analyzer's local explanation for ``data`` (not cached)."""
        self._ensure_analyzer()
        return self._analyzer.local_shap_values(data, model_type)

    def get_interactions(self, model_type='random_forest', top_n=10):
        """Return cached interaction data, limited to ``top_n`` entries.

        Args:
            model_type: Model whose cache is consulted.
            top_n: Maximum number of ``top_interactions`` entries to return.
                ``None``, ``0`` and negative values return the cached data
                unchanged. The cache is built with ``INTERACTION_TOP_N``
                interactions, so a larger value presumably cannot yield more.

        Returns:
            The interaction dict (a shallow copy when truncated) or an error
            dict.
        """
        cache = self._ensure_cache(model_type)
        if cache.get('error'):
            return cache

        data = cache.get('interaction')
        if not data:
            return {'error': f'Interaction cache missing for {model_type}'}

        # A negative limit would silently drop trailing items when slicing.
        if top_n and top_n > 0 and data.get('top_interactions'):
            result = dict(data)
            result['top_interactions'] = data['top_interactions'][:top_n]
            return result
        return data

    def get_dependence(self, feature_name, model_type='random_forest'):
        """Return dependence data for ``feature_name``.

        A feature missing from the cache is computed through the analyzer
        and, on success, added to the cache file on a best-effort basis.

        Returns:
            The dependence data, or an error dict when the cache cannot be
            built or the analyzer reports an error for this feature.
        """
        cache = self._ensure_cache(model_type)
        if cache.get('error'):
            return cache

        dependence_map = cache.get('dependence')
        if not isinstance(dependence_map, dict):
            # Tolerate a cache whose 'dependence' entry is missing or null.
            dependence_map = {}

        data = dependence_map.get(feature_name)
        if data:
            return data

        self._ensure_analyzer()
        data = self._analyzer.shap_dependence(feature_name, model_type)
        if data.get('error'):
            return {'error': f'特征 {feature_name} 的依赖解释不可用：{data["error"]}'}

        dependence_map[feature_name] = data
        cache['dependence'] = dependence_map
        self._try_save_cache(model_type, cache)
        return data
# Module-level shared service instance. Constructing it only initializes
# fields; the analyzer itself is created lazily on first use.
shap_service = SHAPService()