From 1ee766720b4abf07a897c7ab02e5f52007d50509 Mon Sep 17 00:00:00 2001 From: shenjianZ Date: Sat, 4 Apr 2026 07:47:32 +0800 Subject: [PATCH] optimize shap response latency --- backend/api/shap_routes.py | 26 ++++++++++++++++++++++++++ backend/core/shap_analysis.py | 22 ++++++++++++++-------- frontend/src/views/JDRAnalysis.vue | 3 +++ 3 files changed, 43 insertions(+), 8 deletions(-) diff --git a/backend/api/shap_routes.py b/backend/api/shap_routes.py index 67270a6..a3bc261 100644 --- a/backend/api/shap_routes.py +++ b/backend/api/shap_routes.py @@ -1,17 +1,31 @@ from flask import Blueprint, jsonify, request +import time from services.shap_service import shap_service shap_bp = Blueprint('shap', __name__, url_prefix='/api/shap') +def log_shap(route, stage, extra=''): + message = f'[SHAP] route={route} stage={stage}' + if extra: + message = f'{message} {extra}' + print(message, flush=True) + + @shap_bp.route('/global', methods=['GET']) def get_global_importance(): + started_at = time.time() try: model_type = request.args.get('model', 'random_forest') + log_shap('/global', 'start', f'model={model_type}') result = shap_service.get_global_importance(model_type) + elapsed = round(time.time() - started_at, 3) + log_shap('/global', 'success', f'model={model_type} elapsed={elapsed}s') return jsonify({'code': 200, 'message': 'success', 'data': result}) except Exception as e: + elapsed = round(time.time() - started_at, 3) + log_shap('/global', 'error', f'elapsed={elapsed}s error={e}') return jsonify({'code': 500, 'message': str(e), 'data': None}), 500 @@ -30,21 +44,33 @@ def get_local_explanation(): @shap_bp.route('/interaction', methods=['GET']) def get_interactions(): + started_at = time.time() try: model_type = request.args.get('model', 'random_forest') top_n = int(request.args.get('top_n', 10)) + log_shap('/interaction', 'start', f'model={model_type} top_n={top_n}') result = shap_service.get_interactions(model_type, top_n) + elapsed = round(time.time() - started_at, 3) + log_shap('/interaction', 'success', f'model={model_type} elapsed={elapsed}s') return jsonify({'code': 200, 'message': 'success', 'data': result}) except Exception as e: + elapsed = round(time.time() - started_at, 3) + log_shap('/interaction', 'error', f'elapsed={elapsed}s error={e}') return jsonify({'code': 500, 'message': str(e), 'data': None}), 500 @shap_bp.route('/dependence', methods=['GET']) def get_dependence(): + started_at = time.time() try: feature = request.args.get('feature', '月均加班时长') model_type = request.args.get('model', 'random_forest') + log_shap('/dependence', 'start', f'model={model_type} feature={feature}') result = shap_service.get_dependence(feature, model_type) + elapsed = round(time.time() - started_at, 3) + log_shap('/dependence', 'success', f'model={model_type} elapsed={elapsed}s') return jsonify({'code': 200, 'message': 'success', 'data': result}) except Exception as e: + elapsed = round(time.time() - started_at, 3) + log_shap('/dependence', 'error', f'elapsed={elapsed}s error={e}') return jsonify({'code': 500, 'message': str(e), 'data': None}), 500 diff --git a/backend/core/shap_analysis.py b/backend/core/shap_analysis.py index 0f22d60..a3995f4 100644 --- a/backend/core/shap_analysis.py +++ b/backend/core/shap_analysis.py @@ -23,7 +23,8 @@ class SHAPAnalyzer: self.feature_names = None self.selected_features = None self.label_encoders = {} - self.background_data = None + self.background_data = {} + self.global_result_cache = {} self._initialized = False def _ensure_initialized(self): @@ -85,8 +86,8 @@ class SHAPAnalyzer: def _get_background_sample(self, n_samples=500): """获取背景数据样本""" - if self.background_data is not None: - return self.background_data + if n_samples in self.background_data: + return self.background_data[n_samples] try: from core.preprocessing import get_clean_data @@ -123,7 +124,7 @@ class SHAPAnalyzer: if selected_indices: X = X[:, selected_indices] - self.background_data = X + self.background_data[n_samples] = X return X except Exception: return None @@ -151,12 +152,15 @@ class SHAPAnalyzer: if not SHAP_AVAILABLE: return {'error': 'SHAP library not installed'} + if model_type in self.global_result_cache: + return self.global_result_cache[model_type] + self._ensure_initialized() explainer = self._get_tree_explainer(model_type) if explainer is None: return {'error': f'No tree model available for {model_type}'} - X = self._get_background_sample() + X = self._get_background_sample(n_samples=32) if X is None: return {'error': 'Failed to prepare background data'} @@ -215,11 +219,13 @@ class SHAPAnalyzer: 'dimension': self._map_feature_to_dimension(fname), }) - return { + result = { 'model_type': model_type, 'dimensions': dimensions, 'top_features': top_features, } + self.global_result_cache[model_type] = result + return result except Exception as exc: return {'error': str(exc)} @@ -306,7 +312,7 @@ class SHAPAnalyzer: if explainer is None: return {'error': f'No tree model available for {model_type}'} - X = self._get_background_sample(n_samples=200) + X = self._get_background_sample(n_samples=12) if X is None: return {'error': 'Failed to prepare background data'} @@ -362,7 +368,7 @@ class SHAPAnalyzer: if explainer is None: return {'error': f'No tree model available for {model_type}'} - X = self._get_background_sample() + X = self._get_background_sample(n_samples=24) if X is None: return {'error': 'Failed to prepare background data'} diff --git a/frontend/src/views/JDRAnalysis.vue b/frontend/src/views/JDRAnalysis.vue index 0d4963d..764e1a6 100644 --- a/frontend/src/views/JDRAnalysis.vue +++ b/frontend/src/views/JDRAnalysis.vue @@ -385,6 +385,7 @@ function renderRiskChart() { // ── Tab 4: SHAP ── async function loadShapGlobal() { + if (activeTab.value !== 'shap') return try { const data = await getGlobalImportance(shapModel.value) if (data.error) { ElMessage.error(data.error); return } @@ -460,6 +461,7 @@ function renderShapDimPie() { } async function loadDependence() { + if (activeTab.value !== 'shap') return if (!dependenceFeature.value) return try { const data = await getDependence(dependenceFeature.value, shapModel.value) @@ -483,6 +485,7 @@ async function loadDependence() { } async function loadInteractions() { + if (activeTab.value !== 'shap') return try { const data = await getInteractions(shapModel.value, 10) if (data.error || !data.top_interactions) return