diff --git a/backend/config.py b/backend/config.py
index 020f9a9..2531cff 100644
--- a/backend/config.py
+++ b/backend/config.py
@@ -7,6 +7,7 @@ DATA_DIR = os.path.join(BASE_DIR, 'data')
 RAW_DATA_DIR = os.path.join(DATA_DIR, 'raw')
 PROCESSED_DATA_DIR = os.path.join(DATA_DIR, 'processed')
 MODELS_DIR = os.path.join(BASE_DIR, 'models')
+SHAP_CACHE_DIR = os.path.join(MODELS_DIR, 'shap_cache')
 
 RAW_DATA_FILENAME = 'china_enterprise_absence_events.csv'
 RAW_DATA_PATH = os.path.join(RAW_DATA_DIR, RAW_DATA_FILENAME)
diff --git a/backend/core/generate_shap_cache.py b/backend/core/generate_shap_cache.py
new file mode 100644
index 0000000..e74054c
--- /dev/null
+++ b/backend/core/generate_shap_cache.py
@@ -0,0 +1,63 @@
+import json
+import os
+import sys
+
+CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
+BASE_DIR = os.path.dirname(CURRENT_DIR)
+if BASE_DIR not in sys.path:
+    sys.path.insert(0, BASE_DIR)
+
+import config
+from core.shap_analysis import SHAPAnalyzer
+
+
+DEFAULT_MODELS = ['random_forest']
+
+
+def build_cache(model_type):
+    analyzer = SHAPAnalyzer()
+    global_data = analyzer.global_shap_values(model_type)
+    if global_data.get('error'):
+        raise RuntimeError(global_data['error'])
+
+    top_features = [
+        item['name']
+        for item in global_data.get('top_features', [])[:15]
+    ]
+
+    dependence = {}
+    for feature_name in top_features:
+        data = analyzer.shap_dependence(feature_name, model_type)
+        if not data.get('error'):
+            dependence[feature_name] = data
+
+    interaction = analyzer.shap_interaction(model_type, top_n=10)
+    if interaction.get('error'):
+        raise RuntimeError(interaction['error'])
+
+    return {
+        'model_type': model_type,
+        'global': global_data,
+        'dependence': dependence,
+        'interaction': interaction,
+    }
+
+
+def save_cache(model_type, payload):
+    os.makedirs(config.SHAP_CACHE_DIR, exist_ok=True)
+    cache_path = os.path.join(config.SHAP_CACHE_DIR, f'{model_type}.json')
+    with open(cache_path, 'w', encoding='utf-8') as fp:
+        json.dump(payload, fp, ensure_ascii=False)
+    print(f'Saved SHAP cache: {cache_path}')
+
+
+def main():
+    model_types = sys.argv[1:] or DEFAULT_MODELS
+    for model_type in model_types:
+        print(f'Generating SHAP cache for {model_type}...')
+        payload = build_cache(model_type)
+        save_cache(model_type, payload)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/backend/services/shap_service.py b/backend/services/shap_service.py
index 931e646..f902d94 100644
--- a/backend/services/shap_service.py
+++ b/backend/services/shap_service.py
@@ -1,3 +1,7 @@
+import json
+import os
+
+import config
 from core.shap_analysis import SHAPAnalyzer
 
 
@@ -11,21 +15,60 @@ class SHAPService:
         if self._analyzer is None:
             self._analyzer = SHAPAnalyzer()
 
+    def _get_cache_path(self, model_type):
+        return os.path.join(config.SHAP_CACHE_DIR, f'{model_type}.json')
+
+    def _load_cache(self, model_type):
+        cache_path = self._get_cache_path(model_type)
+        if not os.path.exists(cache_path):
+            return None
+        try:
+            with open(cache_path, 'r', encoding='utf-8') as fp:
+                return json.load(fp)
+        except Exception:
+            return None
+
     def get_global_importance(self, model_type='random_forest'):
-        self._ensure_analyzer()
-        return self._analyzer.global_shap_values(model_type)
+        cache = self._load_cache(model_type)
+        if not cache:
+            return {
+                'error': f'SHAP cache not found for {model_type}. '
+                         f'Run backend/core/generate_shap_cache.py first.'
+            }
+        return cache.get('global', {'error': f'Invalid SHAP cache for {model_type}'})
 
     def get_local_explanation(self, data, model_type='random_forest'):
         self._ensure_analyzer()
         return self._analyzer.local_shap_values(data, model_type)
 
     def get_interactions(self, model_type='random_forest', top_n=10):
-        self._ensure_analyzer()
-        return self._analyzer.shap_interaction(model_type, top_n)
+        cache = self._load_cache(model_type)
+        if not cache:
+            return {
+                'error': f'SHAP cache not found for {model_type}. '
+                         f'Run backend/core/generate_shap_cache.py first.'
+            }
+        data = cache.get('interaction')
+        if not data:
+            return {'error': f'Interaction cache missing for {model_type}'}
+        if top_n and data.get('top_interactions'):
+            result = dict(data)
+            result['top_interactions'] = data['top_interactions'][:top_n]
+            return result
+        return data
 
     def get_dependence(self, feature_name, model_type='random_forest'):
-        self._ensure_analyzer()
-        return self._analyzer.shap_dependence(feature_name, model_type)
+        cache = self._load_cache(model_type)
+        if not cache:
+            return {
+                'error': f'SHAP cache not found for {model_type}. '
+                         f'Run backend/core/generate_shap_cache.py first.'
+            }
+        dependence_map = cache.get('dependence', {})
+        data = dependence_map.get(feature_name)
+        if data:
+            return data
+        return {'error': f'Dependence cache missing for feature {feature_name}'}
 
 
 shap_service = SHAPService()
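
For reference (not part of the patch): a minimal smoke-test sketch of the new cache-backed flow. It assumes the cache for random_forest has already been built with `python backend/core/generate_shap_cache.py random_forest`, and that it is run with backend/ on sys.path (e.g. from the backend/ directory) so that config and services.shap_service resolve.

# Hypothetical smoke test, not included in the diff above.
import os

import config
from services.shap_service import shap_service

# The service now reads models/shap_cache/<model_type>.json instead of
# recomputing SHAP values on every request.
cache_path = os.path.join(config.SHAP_CACHE_DIR, 'random_forest.json')
print('cache exists:', os.path.exists(cache_path))

result = shap_service.get_global_importance('random_forest')
if result.get('error'):
    # Without a pre-built cache the service returns an error payload
    # pointing at generate_shap_cache.py rather than computing on the fly.
    print(result['error'])
else:
    # 'top_features' mirrors what generate_shap_cache.py stored under 'global'.
    print([item['name'] for item in result.get('top_features', [])[:5]])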