switch shap endpoints to offline cache
This commit is contained in:
@@ -7,6 +7,7 @@ DATA_DIR = os.path.join(BASE_DIR, 'data')
|
|||||||
RAW_DATA_DIR = os.path.join(DATA_DIR, 'raw')
|
RAW_DATA_DIR = os.path.join(DATA_DIR, 'raw')
|
||||||
PROCESSED_DATA_DIR = os.path.join(DATA_DIR, 'processed')
|
PROCESSED_DATA_DIR = os.path.join(DATA_DIR, 'processed')
|
||||||
MODELS_DIR = os.path.join(BASE_DIR, 'models')
|
MODELS_DIR = os.path.join(BASE_DIR, 'models')
|
||||||
|
SHAP_CACHE_DIR = os.path.join(MODELS_DIR, 'shap_cache')
|
||||||
|
|
||||||
RAW_DATA_FILENAME = 'china_enterprise_absence_events.csv'
|
RAW_DATA_FILENAME = 'china_enterprise_absence_events.csv'
|
||||||
RAW_DATA_PATH = os.path.join(RAW_DATA_DIR, RAW_DATA_FILENAME)
|
RAW_DATA_PATH = os.path.join(RAW_DATA_DIR, RAW_DATA_FILENAME)
|
||||||
|
|||||||
63
backend/core/generate_shap_cache.py
Normal file
63
backend/core/generate_shap_cache.py
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
BASE_DIR = os.path.dirname(CURRENT_DIR)
|
||||||
|
if BASE_DIR not in sys.path:
|
||||||
|
sys.path.insert(0, BASE_DIR)
|
||||||
|
|
||||||
|
import config
|
||||||
|
from core.shap_analysis import SHAPAnalyzer
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_MODELS = ['random_forest']
|
||||||
|
|
||||||
|
|
||||||
|
def build_cache(model_type):
|
||||||
|
analyzer = SHAPAnalyzer()
|
||||||
|
global_data = analyzer.global_shap_values(model_type)
|
||||||
|
if global_data.get('error'):
|
||||||
|
raise RuntimeError(global_data['error'])
|
||||||
|
|
||||||
|
top_features = [
|
||||||
|
item['name']
|
||||||
|
for item in global_data.get('top_features', [])[:15]
|
||||||
|
]
|
||||||
|
|
||||||
|
dependence = {}
|
||||||
|
for feature_name in top_features:
|
||||||
|
data = analyzer.shap_dependence(feature_name, model_type)
|
||||||
|
if not data.get('error'):
|
||||||
|
dependence[feature_name] = data
|
||||||
|
|
||||||
|
interaction = analyzer.shap_interaction(model_type, top_n=10)
|
||||||
|
if interaction.get('error'):
|
||||||
|
raise RuntimeError(interaction['error'])
|
||||||
|
|
||||||
|
return {
|
||||||
|
'model_type': model_type,
|
||||||
|
'global': global_data,
|
||||||
|
'dependence': dependence,
|
||||||
|
'interaction': interaction,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def save_cache(model_type, payload):
|
||||||
|
os.makedirs(config.SHAP_CACHE_DIR, exist_ok=True)
|
||||||
|
cache_path = os.path.join(config.SHAP_CACHE_DIR, f'{model_type}.json')
|
||||||
|
with open(cache_path, 'w', encoding='utf-8') as fp:
|
||||||
|
json.dump(payload, fp, ensure_ascii=False)
|
||||||
|
print(f'Saved SHAP cache: {cache_path}')
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
model_types = sys.argv[1:] or DEFAULT_MODELS
|
||||||
|
for model_type in model_types:
|
||||||
|
print(f'Generating SHAP cache for {model_type}...')
|
||||||
|
payload = build_cache(model_type)
|
||||||
|
save_cache(model_type, payload)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
@@ -1,3 +1,7 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
|
||||||
|
import config
|
||||||
from core.shap_analysis import SHAPAnalyzer
|
from core.shap_analysis import SHAPAnalyzer
|
||||||
|
|
||||||
|
|
||||||
@@ -11,21 +15,60 @@ class SHAPService:
|
|||||||
if self._analyzer is None:
|
if self._analyzer is None:
|
||||||
self._analyzer = SHAPAnalyzer()
|
self._analyzer = SHAPAnalyzer()
|
||||||
|
|
||||||
|
def _get_cache_path(self, model_type):
|
||||||
|
return os.path.join(config.SHAP_CACHE_DIR, f'{model_type}.json')
|
||||||
|
|
||||||
|
def _load_cache(self, model_type):
|
||||||
|
cache_path = self._get_cache_path(model_type)
|
||||||
|
if not os.path.exists(cache_path):
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
with open(cache_path, 'r', encoding='utf-8') as fp:
|
||||||
|
return json.load(fp)
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
def get_global_importance(self, model_type='random_forest'):
|
def get_global_importance(self, model_type='random_forest'):
|
||||||
self._ensure_analyzer()
|
cache = self._load_cache(model_type)
|
||||||
return self._analyzer.global_shap_values(model_type)
|
if not cache:
|
||||||
|
return {
|
||||||
|
'error': f'SHAP cache not found for {model_type}. '
|
||||||
|
f'Run backend/core/generate_shap_cache.py first.'
|
||||||
|
}
|
||||||
|
return cache.get('global', {'error': f'Invalid SHAP cache for {model_type}'})
|
||||||
|
|
||||||
def get_local_explanation(self, data, model_type='random_forest'):
|
def get_local_explanation(self, data, model_type='random_forest'):
|
||||||
self._ensure_analyzer()
|
self._ensure_analyzer()
|
||||||
return self._analyzer.local_shap_values(data, model_type)
|
return self._analyzer.local_shap_values(data, model_type)
|
||||||
|
|
||||||
def get_interactions(self, model_type='random_forest', top_n=10):
|
def get_interactions(self, model_type='random_forest', top_n=10):
|
||||||
self._ensure_analyzer()
|
cache = self._load_cache(model_type)
|
||||||
return self._analyzer.shap_interaction(model_type, top_n)
|
if not cache:
|
||||||
|
return {
|
||||||
|
'error': f'SHAP cache not found for {model_type}. '
|
||||||
|
f'Run backend/core/generate_shap_cache.py first.'
|
||||||
|
}
|
||||||
|
data = cache.get('interaction')
|
||||||
|
if not data:
|
||||||
|
return {'error': f'Interaction cache missing for {model_type}'}
|
||||||
|
if top_n and data.get('top_interactions'):
|
||||||
|
result = dict(data)
|
||||||
|
result['top_interactions'] = data['top_interactions'][:top_n]
|
||||||
|
return result
|
||||||
|
return data
|
||||||
|
|
||||||
def get_dependence(self, feature_name, model_type='random_forest'):
|
def get_dependence(self, feature_name, model_type='random_forest'):
|
||||||
self._ensure_analyzer()
|
cache = self._load_cache(model_type)
|
||||||
return self._analyzer.shap_dependence(feature_name, model_type)
|
if not cache:
|
||||||
|
return {
|
||||||
|
'error': f'SHAP cache not found for {model_type}. '
|
||||||
|
f'Run backend/core/generate_shap_cache.py first.'
|
||||||
|
}
|
||||||
|
dependence_map = cache.get('dependence', {})
|
||||||
|
data = dependence_map.get(feature_name)
|
||||||
|
if data:
|
||||||
|
return data
|
||||||
|
return {'error': f'Dependence cache missing for feature {feature_name}'}
|
||||||
|
|
||||||
|
|
||||||
shap_service = SHAPService()
|
shap_service = SHAPService()
|
||||||
|
|||||||
Reference in New Issue
Block a user