"""Precompute and persist SHAP analysis caches for one or more models.

Usage: run directly, optionally passing model type names as CLI arguments;
with no arguments, DEFAULT_MODELS is used.
"""

import json
import os
import sys

# Make the project root importable when this script is run directly
# (the script lives one directory below the package root).
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
BASE_DIR = os.path.dirname(CURRENT_DIR)
if BASE_DIR not in sys.path:
    sys.path.insert(0, BASE_DIR)

import config
from core.shap_analysis import SHAPAnalyzer

DEFAULT_MODELS = ['random_forest']


def build_cache(model_type, top_features_count=15, interaction_top_n=10):
    """Compute the full SHAP payload for one model type.

    Args:
        model_type: Model identifier understood by ``SHAPAnalyzer``.
        top_features_count: How many of the top global features get
            per-feature dependence data (previously hard-coded to 15).
        interaction_top_n: ``top_n`` forwarded to ``shap_interaction``
            (previously hard-coded to 10).

    Returns:
        dict with keys ``'model_type'``, ``'global'``, ``'dependence'``
        and ``'interaction'``.

    Raises:
        RuntimeError: If the analyzer reports an error for the global or
            interaction computation. Per-feature dependence errors are
            NOT fatal — those features are simply skipped.
    """
    analyzer = SHAPAnalyzer()

    global_data = analyzer.global_shap_values(model_type)
    if global_data.get('error'):
        raise RuntimeError(global_data['error'])

    top_features = [
        item['name']
        for item in global_data.get('top_features', [])[:top_features_count]
    ]

    # Dependence data is best-effort: a feature whose computation fails is
    # left out of the cache rather than aborting the whole build.
    dependence = {}
    for feature_name in top_features:
        data = analyzer.shap_dependence(feature_name, model_type)
        if not data.get('error'):
            dependence[feature_name] = data

    interaction = analyzer.shap_interaction(model_type, top_n=interaction_top_n)
    if interaction.get('error'):
        raise RuntimeError(interaction['error'])

    return {
        'model_type': model_type,
        'global': global_data,
        'dependence': dependence,
        'interaction': interaction,
    }


def save_cache(model_type, payload):
    """Write *payload* as UTF-8 JSON to ``SHAP_CACHE_DIR/<model_type>.json``.

    Creates the cache directory if it does not exist and prints the path
    of the file that was written.
    """
    os.makedirs(config.SHAP_CACHE_DIR, exist_ok=True)
    cache_path = os.path.join(config.SHAP_CACHE_DIR, f'{model_type}.json')
    with open(cache_path, 'w', encoding='utf-8') as fp:
        # ensure_ascii=False keeps any non-ASCII feature names readable.
        json.dump(payload, fp, ensure_ascii=False)
    print(f'Saved SHAP cache: {cache_path}')


def main():
    """Build and save SHAP caches for each model type given on the CLI.

    Falls back to DEFAULT_MODELS when no arguments are supplied. A failure
    for any model aborts the run (errors propagate as RuntimeError).
    """
    model_types = sys.argv[1:] or DEFAULT_MODELS
    for model_type in model_types:
        print(f'Generating SHAP cache for {model_type}...')
        payload = build_cache(model_type)
        save_cache(model_type, payload)


if __name__ == '__main__':
    main()