optimize shap response latency
This commit is contained in:
@@ -1,17 +1,31 @@
|
|||||||
from flask import Blueprint, jsonify, request
|
from flask import Blueprint, jsonify, request
|
||||||
|
import time
|
||||||
|
|
||||||
from services.shap_service import shap_service
|
from services.shap_service import shap_service
|
||||||
|
|
||||||
shap_bp = Blueprint('shap', __name__, url_prefix='/api/shap')
|
shap_bp = Blueprint('shap', __name__, url_prefix='/api/shap')
|
||||||
|
|
||||||
|
|
||||||
|
def log_shap(route, stage, extra=''):
|
||||||
|
message = f'[SHAP] route={route} stage={stage}'
|
||||||
|
if extra:
|
||||||
|
message = f'{message} {extra}'
|
||||||
|
print(message, flush=True)
|
||||||
|
|
||||||
|
|
||||||
@shap_bp.route('/global', methods=['GET'])
|
@shap_bp.route('/global', methods=['GET'])
|
||||||
def get_global_importance():
|
def get_global_importance():
|
||||||
|
started_at = time.time()
|
||||||
try:
|
try:
|
||||||
model_type = request.args.get('model', 'random_forest')
|
model_type = request.args.get('model', 'random_forest')
|
||||||
|
log_shap('/global', 'start', f'model={model_type}')
|
||||||
result = shap_service.get_global_importance(model_type)
|
result = shap_service.get_global_importance(model_type)
|
||||||
|
elapsed = round(time.time() - started_at, 3)
|
||||||
|
log_shap('/global', 'success', f'model={model_type} elapsed={elapsed}s')
|
||||||
return jsonify({'code': 200, 'message': 'success', 'data': result})
|
return jsonify({'code': 200, 'message': 'success', 'data': result})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
elapsed = round(time.time() - started_at, 3)
|
||||||
|
log_shap('/global', 'error', f'elapsed={elapsed}s error={e}')
|
||||||
return jsonify({'code': 500, 'message': str(e), 'data': None}), 500
|
return jsonify({'code': 500, 'message': str(e), 'data': None}), 500
|
||||||
|
|
||||||
|
|
||||||
@@ -30,21 +44,33 @@ def get_local_explanation():
|
|||||||
|
|
||||||
@shap_bp.route('/interaction', methods=['GET'])
|
@shap_bp.route('/interaction', methods=['GET'])
|
||||||
def get_interactions():
|
def get_interactions():
|
||||||
|
started_at = time.time()
|
||||||
try:
|
try:
|
||||||
model_type = request.args.get('model', 'random_forest')
|
model_type = request.args.get('model', 'random_forest')
|
||||||
top_n = int(request.args.get('top_n', 10))
|
top_n = int(request.args.get('top_n', 10))
|
||||||
|
log_shap('/interaction', 'start', f'model={model_type} top_n={top_n}')
|
||||||
result = shap_service.get_interactions(model_type, top_n)
|
result = shap_service.get_interactions(model_type, top_n)
|
||||||
|
elapsed = round(time.time() - started_at, 3)
|
||||||
|
log_shap('/interaction', 'success', f'model={model_type} elapsed={elapsed}s')
|
||||||
return jsonify({'code': 200, 'message': 'success', 'data': result})
|
return jsonify({'code': 200, 'message': 'success', 'data': result})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
elapsed = round(time.time() - started_at, 3)
|
||||||
|
log_shap('/interaction', 'error', f'elapsed={elapsed}s error={e}')
|
||||||
return jsonify({'code': 500, 'message': str(e), 'data': None}), 500
|
return jsonify({'code': 500, 'message': str(e), 'data': None}), 500
|
||||||
|
|
||||||
|
|
||||||
@shap_bp.route('/dependence', methods=['GET'])
|
@shap_bp.route('/dependence', methods=['GET'])
|
||||||
def get_dependence():
|
def get_dependence():
|
||||||
|
started_at = time.time()
|
||||||
try:
|
try:
|
||||||
feature = request.args.get('feature', '月均加班时长')
|
feature = request.args.get('feature', '月均加班时长')
|
||||||
model_type = request.args.get('model', 'random_forest')
|
model_type = request.args.get('model', 'random_forest')
|
||||||
|
log_shap('/dependence', 'start', f'model={model_type} feature={feature}')
|
||||||
result = shap_service.get_dependence(feature, model_type)
|
result = shap_service.get_dependence(feature, model_type)
|
||||||
|
elapsed = round(time.time() - started_at, 3)
|
||||||
|
log_shap('/dependence', 'success', f'model={model_type} elapsed={elapsed}s')
|
||||||
return jsonify({'code': 200, 'message': 'success', 'data': result})
|
return jsonify({'code': 200, 'message': 'success', 'data': result})
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
elapsed = round(time.time() - started_at, 3)
|
||||||
|
log_shap('/dependence', 'error', f'elapsed={elapsed}s error={e}')
|
||||||
return jsonify({'code': 500, 'message': str(e), 'data': None}), 500
|
return jsonify({'code': 500, 'message': str(e), 'data': None}), 500
|
||||||
|
|||||||
@@ -23,7 +23,8 @@ class SHAPAnalyzer:
|
|||||||
self.feature_names = None
|
self.feature_names = None
|
||||||
self.selected_features = None
|
self.selected_features = None
|
||||||
self.label_encoders = {}
|
self.label_encoders = {}
|
||||||
self.background_data = None
|
self.background_data = {}
|
||||||
|
self.global_result_cache = {}
|
||||||
self._initialized = False
|
self._initialized = False
|
||||||
|
|
||||||
def _ensure_initialized(self):
|
def _ensure_initialized(self):
|
||||||
@@ -85,8 +86,8 @@ class SHAPAnalyzer:
|
|||||||
|
|
||||||
def _get_background_sample(self, n_samples=500):
|
def _get_background_sample(self, n_samples=500):
|
||||||
"""获取背景数据样本"""
|
"""获取背景数据样本"""
|
||||||
if self.background_data is not None:
|
if n_samples in self.background_data:
|
||||||
return self.background_data
|
return self.background_data[n_samples]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from core.preprocessing import get_clean_data
|
from core.preprocessing import get_clean_data
|
||||||
@@ -123,7 +124,7 @@ class SHAPAnalyzer:
|
|||||||
if selected_indices:
|
if selected_indices:
|
||||||
X = X[:, selected_indices]
|
X = X[:, selected_indices]
|
||||||
|
|
||||||
self.background_data = X
|
self.background_data[n_samples] = X
|
||||||
return X
|
return X
|
||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
@@ -151,12 +152,15 @@ class SHAPAnalyzer:
|
|||||||
if not SHAP_AVAILABLE:
|
if not SHAP_AVAILABLE:
|
||||||
return {'error': 'SHAP library not installed'}
|
return {'error': 'SHAP library not installed'}
|
||||||
|
|
||||||
|
if model_type in self.global_result_cache:
|
||||||
|
return self.global_result_cache[model_type]
|
||||||
|
|
||||||
self._ensure_initialized()
|
self._ensure_initialized()
|
||||||
explainer = self._get_tree_explainer(model_type)
|
explainer = self._get_tree_explainer(model_type)
|
||||||
if explainer is None:
|
if explainer is None:
|
||||||
return {'error': f'No tree model available for {model_type}'}
|
return {'error': f'No tree model available for {model_type}'}
|
||||||
|
|
||||||
X = self._get_background_sample()
|
X = self._get_background_sample(n_samples=32)
|
||||||
if X is None:
|
if X is None:
|
||||||
return {'error': 'Failed to prepare background data'}
|
return {'error': 'Failed to prepare background data'}
|
||||||
|
|
||||||
@@ -215,11 +219,13 @@ class SHAPAnalyzer:
|
|||||||
'dimension': self._map_feature_to_dimension(fname),
|
'dimension': self._map_feature_to_dimension(fname),
|
||||||
})
|
})
|
||||||
|
|
||||||
return {
|
result = {
|
||||||
'model_type': model_type,
|
'model_type': model_type,
|
||||||
'dimensions': dimensions,
|
'dimensions': dimensions,
|
||||||
'top_features': top_features,
|
'top_features': top_features,
|
||||||
}
|
}
|
||||||
|
self.global_result_cache[model_type] = result
|
||||||
|
return result
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
return {'error': str(exc)}
|
return {'error': str(exc)}
|
||||||
|
|
||||||
@@ -306,7 +312,7 @@ class SHAPAnalyzer:
|
|||||||
if explainer is None:
|
if explainer is None:
|
||||||
return {'error': f'No tree model available for {model_type}'}
|
return {'error': f'No tree model available for {model_type}'}
|
||||||
|
|
||||||
X = self._get_background_sample(n_samples=200)
|
X = self._get_background_sample(n_samples=12)
|
||||||
if X is None:
|
if X is None:
|
||||||
return {'error': 'Failed to prepare background data'}
|
return {'error': 'Failed to prepare background data'}
|
||||||
|
|
||||||
@@ -362,7 +368,7 @@ class SHAPAnalyzer:
|
|||||||
if explainer is None:
|
if explainer is None:
|
||||||
return {'error': f'No tree model available for {model_type}'}
|
return {'error': f'No tree model available for {model_type}'}
|
||||||
|
|
||||||
X = self._get_background_sample()
|
X = self._get_background_sample(n_samples=24)
|
||||||
if X is None:
|
if X is None:
|
||||||
return {'error': 'Failed to prepare background data'}
|
return {'error': 'Failed to prepare background data'}
|
||||||
|
|
||||||
|
|||||||
@@ -385,6 +385,7 @@ function renderRiskChart() {
|
|||||||
|
|
||||||
// ── Tab 4: SHAP ──
|
// ── Tab 4: SHAP ──
|
||||||
async function loadShapGlobal() {
|
async function loadShapGlobal() {
|
||||||
|
if (activeTab.value !== 'shap') return
|
||||||
try {
|
try {
|
||||||
const data = await getGlobalImportance(shapModel.value)
|
const data = await getGlobalImportance(shapModel.value)
|
||||||
if (data.error) { ElMessage.error(data.error); return }
|
if (data.error) { ElMessage.error(data.error); return }
|
||||||
@@ -460,6 +461,7 @@ function renderShapDimPie() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async function loadDependence() {
|
async function loadDependence() {
|
||||||
|
if (activeTab.value !== 'shap') return
|
||||||
if (!dependenceFeature.value) return
|
if (!dependenceFeature.value) return
|
||||||
try {
|
try {
|
||||||
const data = await getDependence(dependenceFeature.value, shapModel.value)
|
const data = await getDependence(dependenceFeature.value, shapModel.value)
|
||||||
@@ -483,6 +485,7 @@ async function loadDependence() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async function loadInteractions() {
|
async function loadInteractions() {
|
||||||
|
if (activeTab.value !== 'shap') return
|
||||||
try {
|
try {
|
||||||
const data = await getInteractions(shapModel.value, 10)
|
const data = await getInteractions(shapModel.value, 10)
|
||||||
if (data.error || !data.top_interactions) return
|
if (data.error || !data.top_interactions) return
|
||||||
|
|||||||
Reference in New Issue
Block a user