77 lines
3.2 KiB
Python
77 lines
3.2 KiB
Python
from flask import Blueprint, jsonify, request
|
|
import time
|
|
|
|
from services.shap_service import shap_service
|
|
|
|
# Blueprint that collects the SHAP endpoints defined in this module; routes
# registered on it are mounted under the /api/shap URL prefix.
shap_bp = Blueprint('shap', __name__, url_prefix='/api/shap')
|
|
|
|
|
|
def log_shap(route, stage, extra=''):
    """Print one SHAP trace line to stdout and flush immediately.

    The line has the form ``[SHAP] route=<route> stage=<stage>``. When
    ``extra`` is truthy it is appended after a single space.

    Args:
        route: Route label to include in the line.
        stage: Stage label to include in the line.
        extra: Optional trailing detail; omitted when falsy.
    """
    pieces = [f'[SHAP] route={route} stage={stage}']
    if extra:
        pieces.append(f'{extra}')
    print(' '.join(pieces), flush=True)
|
|
|
|
|
|
@shap_bp.route('/global', methods=['GET'])
def get_global_importance():
    """Handle GET /global.

    Reads the ``model`` query parameter (default ``'random_forest'``), passes
    it to ``shap_service.get_global_importance`` and returns the result inside
    a ``{'code', 'message', 'data'}`` JSON envelope. Start, success and error
    stages are traced through ``log_shap`` with the elapsed time in seconds.

    Returns:
        A JSON response with code 200 on success, or a ``(response, 500)``
        tuple whose message is ``str(exception)`` when anything raises.
    """
    route = '/global'
    t0 = time.time()

    def seconds_since_start():
        # Elapsed time since the handler was entered, rounded to milliseconds.
        return round(time.time() - t0, 3)

    try:
        model_type = request.args.get('model', 'random_forest')
        log_shap(route, 'start', f'model={model_type}')
        payload = {
            'code': 200,
            'message': 'success',
            'data': shap_service.get_global_importance(model_type),
        }
        log_shap(route, 'success', f'model={model_type} elapsed={seconds_since_start()}s')
        return jsonify(payload)
    except Exception as exc:
        # Top-level boundary: any failure is reported to the client as a 500.
        log_shap(route, 'error', f'elapsed={seconds_since_start()}s error={exc}')
        failure = {'code': 500, 'message': str(exc), 'data': None}
        return jsonify(failure), 500
|
|
|
|
|
|
@shap_bp.route('/local', methods=['POST'])
def get_local_explanation():
    """Handle POST /local.

    Expects a JSON object in the request body. The whole object is passed to
    ``shap_service.get_local_explanation`` together with its ``model_type``
    entry (default ``'random_forest'``). Start, success and error stages are
    traced through ``log_shap``, matching the other routes in this module.

    Returns:
        * 200 envelope with the service result on success.
        * 400 envelope when the body is missing, empty, not valid JSON, or
          not a JSON object.
        * 500 envelope whose message is ``str(exception)`` when the service
          call (or response serialisation) raises.
    """
    started_at = time.time()
    try:
        # silent=True makes a wrong Content-Type or malformed JSON yield None
        # instead of raising an HTTP error, which the blanket handler below
        # would otherwise misreport as a 500 server error.
        data = request.get_json(silent=True)
        if not data:
            return jsonify({'code': 400, 'message': 'Request body is required', 'data': None}), 400
        # A JSON list/string/number has no .get(); reject it as a client error.
        if not isinstance(data, dict):
            return jsonify({'code': 400, 'message': 'Request body must be a JSON object', 'data': None}), 400

        model_type = data.get('model_type', 'random_forest')
        log_shap('/local', 'start', f'model={model_type}')
        result = shap_service.get_local_explanation(data, model_type)
        elapsed = round(time.time() - started_at, 3)
        log_shap('/local', 'success', f'model={model_type} elapsed={elapsed}s')
        return jsonify({'code': 200, 'message': 'success', 'data': result})
    except Exception as e:
        # Top-level boundary: any failure is reported to the client as a 500.
        elapsed = round(time.time() - started_at, 3)
        log_shap('/local', 'error', f'elapsed={elapsed}s error={e}')
        return jsonify({'code': 500, 'message': str(e), 'data': None}), 500
|
|
|
|
|
|
@shap_bp.route('/interaction', methods=['GET'])
def get_interactions():
    """Handle GET /interaction.

    Reads the ``model`` (default ``'random_forest'``) and ``top_n``
    (default ``10``) query parameters and passes them to
    ``shap_service.get_interactions``. Start, success and error stages are
    traced through ``log_shap`` with the elapsed time in seconds.

    Returns:
        * 200 envelope with the service result on success.
        * 400 envelope when ``top_n`` is not an integer.
        * 500 envelope whose message is ``str(exception)`` when anything
          else raises.
    """
    started_at = time.time()
    try:
        model_type = request.args.get('model', 'random_forest')

        # Parse top_n separately so that a malformed query value is reported
        # as a client error (400) rather than falling into the 500 handler.
        raw_top_n = request.args.get('top_n', 10)
        try:
            top_n = int(raw_top_n)
        except (TypeError, ValueError):
            elapsed = round(time.time() - started_at, 3)
            log_shap('/interaction', 'error', f'elapsed={elapsed}s error=invalid top_n={raw_top_n!r}')
            return jsonify({'code': 400, 'message': 'top_n must be an integer', 'data': None}), 400

        log_shap('/interaction', 'start', f'model={model_type} top_n={top_n}')
        result = shap_service.get_interactions(model_type, top_n)
        elapsed = round(time.time() - started_at, 3)
        log_shap('/interaction', 'success', f'model={model_type} elapsed={elapsed}s')
        return jsonify({'code': 200, 'message': 'success', 'data': result})
    except Exception as e:
        # Top-level boundary: any failure is reported to the client as a 500.
        elapsed = round(time.time() - started_at, 3)
        log_shap('/interaction', 'error', f'elapsed={elapsed}s error={e}')
        return jsonify({'code': 500, 'message': str(e), 'data': None}), 500
|
|
|
|
|
|
@shap_bp.route('/dependence', methods=['GET'])
def get_dependence():
    """Handle GET /dependence.

    Reads the ``feature`` and ``model`` query parameters (each with a
    built-in default), passes them to ``shap_service.get_dependence`` and
    returns the result inside a ``{'code', 'message', 'data'}`` JSON envelope.
    Start, success and error stages are traced through ``log_shap``.

    Returns:
        A JSON response with code 200 on success, or a ``(response, 500)``
        tuple whose message is ``str(exception)`` when anything raises.
    """
    route = '/dependence'
    begin = time.time()
    try:
        query = request.args
        feature = query.get('feature', '月均加班时长')
        model_type = query.get('model', 'random_forest')
        log_shap(route, 'start', f'model={model_type} feature={feature}')

        dependence = shap_service.get_dependence(feature, model_type)

        took = round(time.time() - begin, 3)
        log_shap(route, 'success', f'model={model_type} elapsed={took}s')
        body = {'code': 200, 'message': 'success', 'data': dependence}
        return jsonify(body)
    except Exception as err:
        # Top-level boundary: any failure is reported to the client as a 500.
        took = round(time.time() - begin, 3)
        log_shap(route, 'error', f'elapsed={took}s error={err}')
        body = {'code': 500, 'message': str(err), 'data': None}
        return jsonify(body), 500
|