optimize shap response latency

This commit is contained in:
2026-04-04 07:47:32 +08:00
parent 11ba5c535b
commit 1ee766720b
3 changed files with 43 additions and 8 deletions

View File

@@ -1,17 +1,31 @@
from flask import Blueprint, jsonify, request from flask import Blueprint, jsonify, request
import time
from services.shap_service import shap_service from services.shap_service import shap_service
shap_bp = Blueprint('shap', __name__, url_prefix='/api/shap') shap_bp = Blueprint('shap', __name__, url_prefix='/api/shap')
def log_shap(route, stage, extra=''):
    """Print a single ``[SHAP]`` trace line for a route handler.

    Args:
        route: Route label written after ``route=`` (e.g. ``'/global'``).
        stage: Stage label written after ``stage=``; the handlers in this
            module pass ``'start'``, ``'success'`` or ``'error'``.
        extra: Optional free-form suffix. When non-empty it is appended to
            the line after a single space; when empty nothing is appended.
    """
    message = f'[SHAP] route={route} stage={stage}'
    if extra:
        message = f'{message} {extra}'
    # flush=True pushes the line out immediately instead of leaving it in
    # the stdout buffer.
    print(message, flush=True)
@shap_bp.route('/global', methods=['GET'])
def get_global_importance():
    """Handle ``GET /api/shap/global``: global SHAP importance for a model.

    Query parameters:
        model: Model key forwarded to ``shap_service.get_global_importance``;
            defaults to ``'random_forest'``.

    Returns:
        On success, a JSON envelope ``{'code': 200, 'message': 'success',
        'data': <service result>}``. If anything raises, a JSON envelope
        ``{'code': 500, 'message': str(e), 'data': None}`` with HTTP
        status 500.
    """
    # perf_counter() is monotonic, so the elapsed figure cannot be skewed by
    # wall-clock adjustments the way time.time() differences can.
    started_at = time.perf_counter()
    try:
        model_type = request.args.get('model', 'random_forest')
        log_shap('/global', 'start', f'model={model_type}')

        result = shap_service.get_global_importance(model_type)

        elapsed = round(time.perf_counter() - started_at, 3)
        log_shap('/global', 'success', f'model={model_type} elapsed={elapsed}s')
        return jsonify({'code': 200, 'message': 'success', 'data': result})
    except Exception as e:
        # Route boundary: report the failure to the client as a JSON 500
        # rather than letting the exception propagate.
        elapsed = round(time.perf_counter() - started_at, 3)
        log_shap('/global', 'error', f'elapsed={elapsed}s error={e}')
        return jsonify({'code': 500, 'message': str(e), 'data': None}), 500
@@ -30,21 +44,33 @@ def get_local_explanation():
@shap_bp.route('/interaction', methods=['GET'])
def get_interactions():
    """Handle ``GET /api/shap/interaction``: SHAP feature interactions.

    Query parameters:
        model: Model key forwarded to ``shap_service.get_interactions``;
            defaults to ``'random_forest'``.
        top_n: Passed to the service as an ``int``; defaults to ``10``.

    Returns:
        On success, a JSON envelope ``{'code': 200, 'message': 'success',
        'data': <service result>}``. If anything raises, a JSON envelope
        ``{'code': 500, 'message': str(e), 'data': None}`` with HTTP
        status 500.
    """
    # perf_counter() is monotonic, so the elapsed figure cannot be skewed by
    # wall-clock adjustments the way time.time() differences can.
    started_at = time.perf_counter()
    try:
        model_type = request.args.get('model', 'random_forest')
        # A non-integer top_n makes int() raise ValueError, which the
        # handler below turns into the 500 envelope.
        top_n = int(request.args.get('top_n', 10))
        log_shap('/interaction', 'start', f'model={model_type} top_n={top_n}')

        result = shap_service.get_interactions(model_type, top_n)

        elapsed = round(time.perf_counter() - started_at, 3)
        log_shap('/interaction', 'success', f'model={model_type} elapsed={elapsed}s')
        return jsonify({'code': 200, 'message': 'success', 'data': result})
    except Exception as e:
        # Route boundary: report the failure to the client as a JSON 500
        # rather than letting the exception propagate.
        elapsed = round(time.perf_counter() - started_at, 3)
        log_shap('/interaction', 'error', f'elapsed={elapsed}s error={e}')
        return jsonify({'code': 500, 'message': str(e), 'data': None}), 500
@shap_bp.route('/dependence', methods=['GET'])
def get_dependence():
    """Handle ``GET /api/shap/dependence``: SHAP dependence for one feature.

    Query parameters:
        feature: Feature name forwarded to ``shap_service.get_dependence``;
            defaults to ``'月均加班时长'``.
        model: Model key forwarded to the service; defaults to
            ``'random_forest'``.

    Returns:
        On success, a JSON envelope ``{'code': 200, 'message': 'success',
        'data': <service result>}``. If anything raises, a JSON envelope
        ``{'code': 500, 'message': str(e), 'data': None}`` with HTTP
        status 500.
    """
    # perf_counter() is monotonic, so the elapsed figure cannot be skewed by
    # wall-clock adjustments the way time.time() differences can.
    started_at = time.perf_counter()
    try:
        feature = request.args.get('feature', '月均加班时长')
        model_type = request.args.get('model', 'random_forest')
        log_shap('/dependence', 'start', f'model={model_type} feature={feature}')

        result = shap_service.get_dependence(feature, model_type)

        elapsed = round(time.perf_counter() - started_at, 3)
        log_shap('/dependence', 'success', f'model={model_type} elapsed={elapsed}s')
        return jsonify({'code': 200, 'message': 'success', 'data': result})
    except Exception as e:
        # Route boundary: report the failure to the client as a JSON 500
        # rather than letting the exception propagate.
        elapsed = round(time.perf_counter() - started_at, 3)
        log_shap('/dependence', 'error', f'elapsed={elapsed}s error={e}')
        return jsonify({'code': 500, 'message': str(e), 'data': None}), 500

View File

@@ -23,7 +23,8 @@ class SHAPAnalyzer:
self.feature_names = None self.feature_names = None
self.selected_features = None self.selected_features = None
self.label_encoders = {} self.label_encoders = {}
self.background_data = None self.background_data = {}
self.global_result_cache = {}
self._initialized = False self._initialized = False
def _ensure_initialized(self): def _ensure_initialized(self):
@@ -85,8 +86,8 @@ class SHAPAnalyzer:
def _get_background_sample(self, n_samples=500): def _get_background_sample(self, n_samples=500):
"""获取背景数据样本""" """获取背景数据样本"""
if self.background_data is not None: if n_samples in self.background_data:
return self.background_data return self.background_data[n_samples]
try: try:
from core.preprocessing import get_clean_data from core.preprocessing import get_clean_data
@@ -123,7 +124,7 @@ class SHAPAnalyzer:
if selected_indices: if selected_indices:
X = X[:, selected_indices] X = X[:, selected_indices]
self.background_data = X self.background_data[n_samples] = X
return X return X
except Exception: except Exception:
return None return None
@@ -151,12 +152,15 @@ class SHAPAnalyzer:
if not SHAP_AVAILABLE: if not SHAP_AVAILABLE:
return {'error': 'SHAP library not installed'} return {'error': 'SHAP library not installed'}
if model_type in self.global_result_cache:
return self.global_result_cache[model_type]
self._ensure_initialized() self._ensure_initialized()
explainer = self._get_tree_explainer(model_type) explainer = self._get_tree_explainer(model_type)
if explainer is None: if explainer is None:
return {'error': f'No tree model available for {model_type}'} return {'error': f'No tree model available for {model_type}'}
X = self._get_background_sample() X = self._get_background_sample(n_samples=32)
if X is None: if X is None:
return {'error': 'Failed to prepare background data'} return {'error': 'Failed to prepare background data'}
@@ -215,11 +219,13 @@ class SHAPAnalyzer:
'dimension': self._map_feature_to_dimension(fname), 'dimension': self._map_feature_to_dimension(fname),
}) })
return { result = {
'model_type': model_type, 'model_type': model_type,
'dimensions': dimensions, 'dimensions': dimensions,
'top_features': top_features, 'top_features': top_features,
} }
self.global_result_cache[model_type] = result
return result
except Exception as exc: except Exception as exc:
return {'error': str(exc)} return {'error': str(exc)}
@@ -306,7 +312,7 @@ class SHAPAnalyzer:
if explainer is None: if explainer is None:
return {'error': f'No tree model available for {model_type}'} return {'error': f'No tree model available for {model_type}'}
X = self._get_background_sample(n_samples=200) X = self._get_background_sample(n_samples=12)
if X is None: if X is None:
return {'error': 'Failed to prepare background data'} return {'error': 'Failed to prepare background data'}
@@ -362,7 +368,7 @@ class SHAPAnalyzer:
if explainer is None: if explainer is None:
return {'error': f'No tree model available for {model_type}'} return {'error': f'No tree model available for {model_type}'}
X = self._get_background_sample() X = self._get_background_sample(n_samples=24)
if X is None: if X is None:
return {'error': 'Failed to prepare background data'} return {'error': 'Failed to prepare background data'}

View File

@@ -385,6 +385,7 @@ function renderRiskChart() {
// ── Tab 4: SHAP ── // ── Tab 4: SHAP ──
async function loadShapGlobal() { async function loadShapGlobal() {
if (activeTab.value !== 'shap') return
try { try {
const data = await getGlobalImportance(shapModel.value) const data = await getGlobalImportance(shapModel.value)
if (data.error) { ElMessage.error(data.error); return } if (data.error) { ElMessage.error(data.error); return }
@@ -460,6 +461,7 @@ function renderShapDimPie() {
} }
async function loadDependence() { async function loadDependence() {
if (activeTab.value !== 'shap') return
if (!dependenceFeature.value) return if (!dependenceFeature.value) return
try { try {
const data = await getDependence(dependenceFeature.value, shapModel.value) const data = await getDependence(dependenceFeature.value, shapModel.value)
@@ -483,6 +485,7 @@ async function loadDependence() {
} }
async function loadInteractions() { async function loadInteractions() {
if (activeTab.value !== 'shap') return
try { try {
const data = await getInteractions(shapModel.value, 10) const data = await getInteractions(shapModel.value, 10)
if (data.error || !data.top_interactions) return if (data.error || !data.top_interactions) return