64 lines
1.7 KiB
Python
64 lines
1.7 KiB
Python
import json
|
|
import os
|
|
import sys
|
|
|
|
# Put the parent of this script's directory at the front of sys.path so the
# project-level imports below (`config`, `core.shap_analysis`) resolve when the
# script is executed directly.
CURRENT_DIR = os.path.split(os.path.abspath(__file__))[0]
BASE_DIR = os.path.split(CURRENT_DIR)[0]
if BASE_DIR not in sys.path:
    sys.path.insert(0, BASE_DIR)
|
|
|
|
import config
|
|
from core.shap_analysis import SHAPAnalyzer
|
|
|
|
|
|
# Model types processed by main() when no model types are given on the command line.
DEFAULT_MODELS = [
    'random_forest',
]
|
|
|
|
|
|
def build_cache(model_type, *, top_k=15, interaction_top_n=10, analyzer=None):
    """Compute the SHAP payload that is cached for one model type.

    Args:
        model_type: Model identifier forwarded to every analyzer call.
        top_k: How many leading entries of the global ``top_features`` list
            get dependence data computed (default 15).
        interaction_top_n: ``top_n`` passed to ``shap_interaction`` (default 10).
        analyzer: Optional analyzer instance to reuse; a fresh
            ``SHAPAnalyzer()`` is constructed when omitted.

    Returns:
        dict with keys ``'model_type'``, ``'global'``, ``'dependence'``
        (feature name -> dependence data, failed features omitted) and
        ``'interaction'``.

    Raises:
        RuntimeError: if the global SHAP values or the interaction analysis
            come back with an ``'error'`` entry.
    """
    if analyzer is None:
        analyzer = SHAPAnalyzer()

    global_data = analyzer.global_shap_values(model_type)
    if global_data.get('error'):
        raise RuntimeError(global_data['error'])

    # `or []` also covers an explicit None, which `.get(key, [])` would return as-is.
    ranked_features = global_data.get('top_features') or []
    top_features = [item['name'] for item in ranked_features[:top_k]]

    dependence = {}
    for feature_name in top_features:
        data = analyzer.shap_dependence(feature_name, model_type)
        if data.get('error'):
            # Best-effort: one failing feature must not sink the whole cache,
            # but the omission is reported instead of swallowed silently.
            print(
                f'Skipping SHAP dependence for {feature_name!r}: {data["error"]}',
                file=sys.stderr,
            )
            continue
        dependence[feature_name] = data

    interaction = analyzer.shap_interaction(model_type, top_n=interaction_top_n)
    if interaction.get('error'):
        raise RuntimeError(interaction['error'])

    return {
        'model_type': model_type,
        'global': global_data,
        'dependence': dependence,
        'interaction': interaction,
    }
|
|
|
|
|
|
def save_cache(model_type, payload, cache_dir=None):
    """Atomically write *payload* as JSON to ``<cache_dir>/<model_type>.json``.

    The JSON is written to a sibling temporary file first and then moved over
    the target with ``os.replace``. A serialization error or interruption
    mid-write therefore leaves any previously saved cache intact instead of
    replacing it with a truncated file.

    Args:
        model_type: Used as the cache file's base name.
        payload: JSON-serializable object; non-ASCII text is written as-is.
        cache_dir: Target directory, created if missing. Defaults to
            ``config.SHAP_CACHE_DIR``.

    Returns:
        The path of the written cache file.

    Raises:
        TypeError / ValueError: propagated from ``json.dump`` when the payload
            cannot be serialized.
        OSError: on filesystem errors.
    """
    if cache_dir is None:
        cache_dir = config.SHAP_CACHE_DIR
    os.makedirs(cache_dir, exist_ok=True)
    cache_path = os.path.join(cache_dir, f'{model_type}.json')
    # PID suffix keeps concurrent runs from writing into the same temp file.
    tmp_path = f'{cache_path}.{os.getpid()}.tmp'

    try:
        with open(tmp_path, 'w', encoding='utf-8') as fp:
            json.dump(payload, fp, ensure_ascii=False)
        os.replace(tmp_path, cache_path)
    except BaseException:
        # Remove the partial temp file, then re-raise the original error.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise

    print(f'Saved SHAP cache: {cache_path}')
    return cache_path
|
|
|
|
|
|
def main(argv=None):
    """Build and save a SHAP cache for each requested model type.

    Args:
        argv: Model types to process; defaults to ``sys.argv[1:]``. When the
            resulting list is empty, ``DEFAULT_MODELS`` is used.

    Raises:
        SystemExit: with a non-zero status after all models have been
            attempted, if ``build_cache`` raised ``RuntimeError`` for any of them.
    """
    requested = sys.argv[1:] if argv is None else list(argv)
    model_types = requested or DEFAULT_MODELS

    failed = []
    for model_type in model_types:
        print(f'Generating SHAP cache for {model_type}...')
        try:
            payload = build_cache(model_type)
        except RuntimeError as exc:
            # Keep going so one broken model does not block caching the others.
            print(f'Failed to generate SHAP cache for {model_type}: {exc}', file=sys.stderr)
            failed.append(model_type)
            continue
        save_cache(model_type, payload)

    if failed:
        # A string argument makes SystemExit print it to stderr and exit with status 1.
        raise SystemExit(f'SHAP cache generation failed for: {", ".join(failed)}')
|
|
|
|
|
|
if __name__ == '__main__':
    # main() returns None on completion, which sys.exit maps to status 0.
    sys.exit(main())
|