# 新闻文本分类系统 - 模块开发任务清单

> 本文档详细列出每个模块需要完成的具体代码任务,参考现有工程结构。
> 注意:爬虫模块使用 Python 实现,而非 Java。

---

## 目录

1. [爬虫模块 (Python)](#1-爬虫模块-python)
2. [后端服务模块 (Spring Boot)](#2-后端服务模块-spring-boot)
3. [前端桌面模块 (Tauri + Vue3)](#3-前端桌面模块-tauri--vue3)
4. [机器学习分类模块 (Python)](#4-机器学习分类模块-python)

---

## 1. 爬虫模块 (Python)

## 2. 后端服务模块 (Spring Boot)

### 模块目录结构

> (此处的目录结构内容缺失,待补充。)

#### 任务 2.1.10: `application.yml` - 应用配置文件

```yaml
spring:
  application:
    name: news-classifier
  datasource:
    driver-class-name: com.mysql.cj.jdbc.Driver
    # characterEncoding takes a Java charset name. "utf8mb4" is a MySQL-side
    # name that Connector/J rejects; UTF-8 is what selects utf8mb4 on the server.
    url: jdbc:mysql://localhost:3306/news_classifier?useUnicode=true&characterEncoding=UTF-8&serverTimezone=Asia/Shanghai
    username: root
    # Placeholder only - supply the real credential via environment/secret store.
    password: your_password
  data:
    redis:
      host: localhost
      port: 6379
      database: 0

# MyBatis-Plus configuration
mybatis-plus:
  mapper-locations: classpath:mapper/**/*.xml
  type-aliases-package: com.newsclassifier.entity
  configuration:
    map-underscore-to-camel-case: true
    log-impl: org.apache.ibatis.logging.stdout.StdOutImpl
  global-config:
    db-config:
      id-type: auto
      logic-delete-field: deleted
      logic-delete-value: 1
      logic-not-delete-value: 0

# JWT configuration
jwt:
  # Placeholder only - replace with a real random key of at least 256 bits.
  secret: your-secret-key-at-least-256-bits-long-for-hs256-algorithm
  expiration: 86400000

# Classifier configuration
classifier:
  mode: hybrid  # traditional, deep_learning, hybrid
  confidence:
    threshold: 0.75
    hybrid-min: 0.6
  bert:
    service-url: http://localhost:5000/api/predict
    timeout: 5000

# Logging configuration
logging:
  level:
    com.newsclassifier: debug
  pattern:
    console: "%d{yyyy-MM-dd HH:mm:ss} [%thread] %-5level %logger{36} - %msg%n"
```

---

## 3. 
前端桌面模块 (Tauri + Vue3)

### 模块目录结构

```
client/src/
├── api/                    # API接口
│   ├── index.ts
│   ├── auth.ts
│   ├── news.ts
│   ├── category.ts
│   └── classifier.ts
├── assets/                 # 静态资源
│   ├── images/
│   ├── styles/
│   │   ├── main.css
│   │   └── tailwind.css
│   └── fonts/
├── components/             # 组件
│   ├── ui/                 # 基础UI组件
│   │   ├── button/
│   │   ├── input/
│   │   ├── dialog/
│   │   └── table/
│   ├── layout/             # 布局组件
│   │   ├── Header.vue
│   │   ├── Sidebar.vue
│   │   └── Footer.vue
│   ├── news/               # 新闻相关组件
│   │   ├── NewsCard.vue
│   │   ├── NewsList.vue
│   │   ├── NewsDetail.vue
│   │   └── CategoryFilter.vue
│   └── charts/             # 图表组件
│       ├── CategoryChart.vue
│       └── TrendChart.vue
├── composables/            # 组合式函数
│   ├── useAuth.ts
│   ├── useNews.ts
│   ├── useClassifier.ts
│   └── useToast.ts
├── layouts/                # 布局
│   ├── DefaultLayout.vue
│   ├── AuthLayout.vue
│   └── EmptyLayout.vue
├── router/                 # 路由 (部分完成)
│   └── index.ts
├── stores/                 # 状态管理 (部分完成)
│   ├── user.ts
│   ├── news.ts
│   └── category.ts
├── types/                  # TypeScript类型
│   ├── api.d.ts
│   ├── news.d.ts
│   └── user.d.ts
├── utils/                  # 工具函数
│   ├── request.ts
│   ├── storage.ts
│   ├── format.ts
│   └── validate.ts
├── views/                  # 页面
│   ├── auth/
│   │   ├── Login.vue
│   │   └── Register.vue
│   ├── news/
│   │   ├── NewsList.vue
│   │   ├── NewsDetail.vue
│   │   └── NewsSearch.vue
│   ├── category/
│   │   ├── CategoryManage.vue
│   │   └── CategoryStats.vue
│   ├── classifier/
│   │   ├── ClassifierPage.vue
│   │   └── ModelCompare.vue
│   └── admin/
│       ├── Dashboard.vue
│       ├── UserManage.vue
│       └── SystemLog.vue
├── App.vue
└── main.ts
```

### 3.1 需要完成的具体文件

#### 任务 3.1.1: `utils/request.ts` - HTTP请求封装

```typescript
import axios, { type AxiosInstance, type AxiosRequestConfig, type AxiosResponse } from 'axios'

/**
 * Envelope the backend wraps every response in.
 *
 * `T` defaults to `any` because the response interceptor below unwraps `data`
 * and hands it to callers, who supply the concrete type through `http.get<T>()`.
 */
interface ApiResponse<T = any> {
  code: number
  message: string
  data: T
}

// Shared axios instance for all backend calls.
const service: AxiosInstance = axios.create({
  baseURL: import.meta.env.VITE_API_BASE_URL || 'http://localhost:8080/api',
  timeout: 15000,
  headers: {
    'Content-Type': 'application/json'
  }
})

// Request interceptor: attach the stored JWT, if any, as a Bearer token.
service.interceptors.request.use(
  (config) => {
    const token = localStorage.getItem('token')
    if (token) {
      config.headers.Authorization = `Bearer ${token}`
    }
    return config
  },
  (error) => {
    return Promise.reject(error)
  }
)

// Response interceptor: unwrap the envelope so callers receive `data` directly.
// A non-200 business code is turned into a rejected promise carrying `message`.
service.interceptors.response.use(
  (response: AxiosResponse<ApiResponse>) => {
    const { code, message, data } = response.data
    if (code === 200) {
      return data
    }
    return Promise.reject(new Error(message || '请求失败'))
  },
  (error) => {
    // Transport-level / HTTP status errors are passed through unchanged.
    return Promise.reject(error)
  }
)

/**
 * Typed request helpers.
 *
 * The second type argument passed to axios overrides its default
 * `AxiosResponse<T>` return type, because the response interceptor has
 * already unwrapped the payload to `T`.
 */
export const http = {
  get<T>(url: string, config?: AxiosRequestConfig): Promise<T> {
    return service.get<T, T>(url, config)
  },

  post<T>(url: string, data?: unknown, config?: AxiosRequestConfig): Promise<T> {
    return service.post<T, T>(url, data, config)
  },

  put<T>(url: string, data?: unknown, config?: AxiosRequestConfig): Promise<T> {
    return service.put<T, T>(url, data, config)
  },

  delete<T>(url: string, config?: AxiosRequestConfig): Promise<T> {
    return service.delete<T, T>(url, config)
  }
}

export default service
```

#### 任务 3.1.2: `api/news.ts` - 新闻API

```typescript
// request.ts lives in utils/, not next to this file, so import it via the `@` alias.
import { http } from '@/utils/request'

/** Query parameters accepted by the paged news listing. */
export interface NewsQueryParams {
  page?: number
  size?: number
  categoryId?: number
  categoryCode?: string
  keyword?: string
  status?: number
}

/** A single news article as returned by the backend. */
export interface NewsDetail {
  id: number
  title: string
  content: string
  summary: string
  source: string
  sourceUrl: string
  author: string
  categoryId: number
  categoryCode: string
  coverImage: string
  publishTime: string
  viewCount: number
  likeCount: number
  commentCount: number
  classifierType: string
  confidence: number
}

/** Generic paged response. */
export interface PageResponse<T> {
  total: number
  records: T[]
  current: number
  size: number
}

/** News endpoints. */
export const newsApi = {
  /** Paged listing. */
  getNewsPage(params: NewsQueryParams): Promise<PageResponse<NewsDetail>> {
    return http.get<PageResponse<NewsDetail>>('/news/page', { params })
  },

  /** Fetch one article by id. */
  getNewsDetail(id: number): Promise<NewsDetail> {
    return http.get<NewsDetail>(`/news/${id}`)
  },

  /** Keyword search. */
  searchNews(keyword: string, page = 1, size = 20): Promise<PageResponse<NewsDetail>> {
    return http.get<PageResponse<NewsDetail>>('/news/search', {
      params: { keyword, page, size }
    })
  },

  /** Manually assign a category; `categoryId` is sent as a query parameter. */
  manualClassify(id: number, categoryId: number): Promise<void> {
    return http.post<void>(`/news/${id}/classify`, null, {
      params: { categoryId }
    })
  }
}
```

#### 任务 3.1.3: `composables/useNews.ts` - 新闻组合式函数

```typescript
import { ref, computed } from 'vue'
import { newsApi, type NewsQueryParams, type NewsDetail, type PageResponse } from '@/api/news'

/**
 * Composable holding the news list / detail state together with the
 * actions that load it. State is exposed through `computed`, so consumers
 * cannot assign to it directly.
 */
export function useNews() {
  const loading = ref(false)
  const newsList = ref<NewsDetail[]>([])
  const total = ref(0)
  const currentNews = ref<NewsDetail | null>(null)

  /** Load one page of news into `newsList` / `total`. Rethrows on failure. */
  const fetchNewsPage = async (params: NewsQueryParams) => {
    loading.value = true
    try {
      const result: PageResponse<NewsDetail> = await newsApi.getNewsPage(params)
      newsList.value = result.records
      total.value = result.total
    } catch (error) {
      console.error('获取新闻列表失败:', error)
      throw error
    } finally {
      loading.value = false
    }
  }

  /** Load a single article into `currentNews` and return it. Rethrows on failure. */
  const fetchNewsDetail = async (id: number) => {
    loading.value = true
    try {
      currentNews.value = await newsApi.getNewsDetail(id)
      return currentNews.value
    } catch (error) {
      console.error('获取新闻详情失败:', error)
      throw error
    } finally {
      loading.value = false
    }
  }

  /** Run a keyword search; results replace `newsList` / `total`. Rethrows on failure. */
  const searchNews = async (keyword: string, page = 1, size = 20) => {
    loading.value = true
    try {
      const result: PageResponse<NewsDetail> = await newsApi.searchNews(keyword, page, size)
      newsList.value = result.records
      total.value = result.total
    } catch (error) {
      console.error('搜索新闻失败:', error)
      throw error
    } finally {
      loading.value = false
    }
  }

  return {
    loading: computed(() => loading.value),
    newsList: computed(() => newsList.value),
    total: computed(() => total.value),
    currentNews: computed(() => currentNews.value),
    fetchNewsPage,
    fetchNewsDetail,
    searchNews
  }
}
```

#### 任务 3.1.4: `views/news/NewsList.vue` - 新闻列表页面

> (此处的示例代码内容缺失,待补充。)

```vue
```

#### 任务 3.1.5: `views/classifier/ClassifierPage.vue` - 分类器页面

> (此处的示例代码内容缺失,待补充。)

```vue
```

#### 任务 3.1.6: `router/index.ts` - 路由配置 (更新)

```typescript
import { createRouter, createWebHistory } from 'vue-router'
import type { RouteRecordRaw } from 'vue-router'

// Route table. `meta.requiresAuth` / `meta.requiresAdmin` are read by the
// navigation guard below.
const routes: RouteRecordRaw[] = [
  {
    path: '/login',
    name: 'Login',
    component: () => import('@/views/auth/Login.vue'),
    meta: { layout: 'EmptyLayout' }
  },
  {
    path: '/',
    name: 'Home',
    component: () => import('@/views/news/NewsList.vue'),
    meta: { requiresAuth: true }
  },
  {
    path: '/news',
    name: 'NewsList',
    component: () => import('@/views/news/NewsList.vue'),
    meta: { requiresAuth: true }
  },
  {
    path: '/news/:id',
    name: 'NewsDetail',
    component: () => import('@/views/news/NewsDetail.vue'),
    meta: { requiresAuth: true }
  },
  {
    path: '/classifier',
    name: 'Classifier',
    component: () => import('@/views/classifier/ClassifierPage.vue'),
    meta: { requiresAuth: true }
  },
  {
    path: '/category',
    name: 'CategoryStats',
    component: () => import('@/views/category/CategoryStats.vue'),
    meta: { requiresAuth: true }
  },
  {
    path: '/admin',
    name: 'AdminDashboard',
    component: () => import('@/views/admin/Dashboard.vue'),
    meta: { requiresAuth: true, requiresAdmin: true }
  }
]

const router = createRouter({
  history: createWebHistory(),
  routes
})

// Navigation guard.
// 1. Routes flagged requiresAuth redirect to /login when no token is stored.
// 2. Routes flagged requiresAdmin redirect to / unless the stored role is ADMIN.
// NOTE(review): both values come from localStorage and can be edited by the
// user, so this is only a UI-level gate - the backend must enforce permissions.
router.beforeEach((to, _from, next) => {
  const token = localStorage.getItem('token')

  if (to.meta.requiresAuth && !token) {
    next('/login')
  } else if (to.meta.requiresAdmin) {
    const userRole = localStorage.getItem('userRole')
    if (userRole !== 'ADMIN') {
      next('/')
    } else {
      next()
    }
  } else {
    next()
  }
})

export default router
```

---

## 4. 
机器学习分类模块 (Python)

### 模块目录结构

```
ml-module/
├── data/
│   ├── raw/                    # 原始数据
│   ├── processed/              # 处理后的数据
│   │   ├── training_data.csv
│   │   └── test_data.csv
│   └── external/               # 外部数据集
├── models/                     # 训练好的模型
│   ├── traditional/
│   │   ├── nb_vectorizer.pkl
│   │   ├── nb_classifier.pkl
│   │   ├── svm_vectorizer.pkl
│   │   └── svm_classifier.pkl
│   ├── deep_learning/
│   │   └── bert_finetuned/
│   └── hybrid/
│       └── config.json
├── src/
│   ├── __init__.py
│   ├── traditional/            # 传统机器学习
│   │   ├── __init__.py
│   │   ├── train_model.py      # (已有)
│   │   ├── predict.py
│   │   └── evaluate.py
│   ├── deep_learning/          # 深度学习
│   │   ├── __init__.py
│   │   ├── bert_model.py
│   │   ├── train_bert.py
│   │   └── predict_bert.py
│   ├── hybrid/                 # 混合策略
│   │   ├── __init__.py
│   │   ├── hybrid_classifier.py
│   │   └── rule_engine.py
│   ├── utils/
│   │   ├── __init__.py
│   │   ├── preprocessing.py    # 数据预处理
│   │   └── metrics.py          # 评估指标
│   └── api/                    # API服务
│       ├── __init__.py
│       └── server.py           # FastAPI服务
├── notebooks/                  # Jupyter notebooks
│   ├── data_exploration.ipynb
│   └── model_comparison.ipynb
├── tests/                      # 测试
│   ├── test_traditional.py
│   ├── test_bert.py
│   └── test_hybrid.py
├── requirements.txt
├── setup.py
└── README.md
```

### 4.1 需要完成的具体文件

#### 任务 4.1.1: `src/traditional/predict.py` - 传统模型预测

```python
"""
Prediction with the traditional (vectorizer + classifier) models.
"""
import os
from functools import lru_cache
from typing import Any, Dict

import jieba
import joblib

# Category code -> display name
CATEGORY_MAP = {
    'POLITICS': '时政',
    'FINANCE': '财经',
    'TECHNOLOGY': '科技',
    'SPORTS': '体育',
    'ENTERTAINMENT': '娱乐',
    'HEALTH': '健康',
    'EDUCATION': '教育',
    'LIFE': '生活',
    'INTERNATIONAL': '国际',
    'MILITARY': '军事'
}

# Default model directory, resolved relative to this file (src/traditional/)
# rather than the current working directory, so loading works no matter
# where the process is started from.
DEFAULT_MODEL_DIR = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), '..', '..', 'models', 'traditional'
)


class TraditionalPredictor:
    """Loads a pickled vectorizer/classifier pair and classifies news text."""

    def __init__(self, model_type='nb', model_dir=DEFAULT_MODEL_DIR):
        """
        :param model_type: file-name prefix of the model pair, e.g. 'nb' or 'svm'
        :param model_dir: directory containing `<model_type>_vectorizer.pkl`
                          and `<model_type>_classifier.pkl`
        """
        self.model_type = model_type
        self.model_dir = model_dir
        self.vectorizer = None
        self.classifier = None
        self._load_model()

    def _load_model(self):
        """Load the vectorizer and classifier from `model_dir`."""
        vectorizer_path = os.path.join(self.model_dir, f'{self.model_type}_vectorizer.pkl')
        classifier_path = os.path.join(self.model_dir, f'{self.model_type}_classifier.pkl')

        # NOTE(review): joblib.load unpickles arbitrary objects - only load
        # model files produced by our own training pipeline.
        self.vectorizer = joblib.load(vectorizer_path)
        self.classifier = joblib.load(classifier_path)
        print(f"模型加载成功: {self.model_type}")

    def preprocess(self, title: str, content: str) -> str:
        """Join title and content, segment with jieba, return space-separated tokens."""
        text = title + ' ' + content
        words = jieba.cut(text)
        return ' '.join(words)

    def predict(self, title: str, content: str) -> Dict[str, Any]:
        """
        Classify one article.

        :return: dict with `categoryCode`, `categoryName`, `confidence`
                 (highest class probability) and `probabilities` (per class)
        """
        processed = self.preprocess(title, content)
        tfidf = self.vectorizer.transform([processed])

        # NOTE(review): predict_proba must be available on the loaded
        # classifier; confirm this for the 'svm' model before using it here.
        prediction = str(self.classifier.predict(tfidf)[0])
        probabilities = self.classifier.predict_proba(tfidf)[0]

        # Plain str/float values keep the result JSON-serializable.
        prob_dict = {
            str(category_code): float(prob)
            for category_code, prob in zip(self.classifier.classes_, probabilities)
        }

        return {
            'categoryCode': prediction,
            'categoryName': CATEGORY_MAP.get(prediction, '未知'),
            'confidence': float(probabilities.max()),
            'probabilities': prob_dict
        }


@lru_cache(maxsize=None)
def _get_predictor(model_type: str) -> TraditionalPredictor:
    """Return one shared predictor per model type so models are loaded from disk once."""
    return TraditionalPredictor(model_type)


def predict_single(title: str, content: str, model_type='nb') -> Dict[str, Any]:
    """Classify a single article with the given model type."""
    return _get_predictor(model_type).predict(title, content)


if __name__ == '__main__':
    # Smoke test
    result = predict_single(
        title="华为发布新款折叠屏手机",
        content="华为今天正式发布了新一代折叠屏手机,搭载最新麒麟芯片..."
    )
    print(result)
```

#### 任务 4.1.2: `src/deep_learning/bert_model.py` - BERT模型

```python
"""
BERT text classification model.
"""
import torch
from transformers import (
    BertTokenizer,
    BertForSequenceClassification,
    Trainer,
    TrainingArguments
)
from typing import Dict, Any, List

# Category code -> display name
CATEGORY_MAP = {
    'POLITICS': '时政',
    'FINANCE': '财经',
    'TECHNOLOGY': '科技',
    'SPORTS': '体育',
    'ENTERTAINMENT': '娱乐',
    'HEALTH': '健康',
    'EDUCATION': '教育',
    'LIFE': '生活',
    'INTERNATIONAL': '国际',
    'MILITARY': '军事'
}

# Label id <-> category code, derived from the key order of CATEGORY_MAP.
# NOTE(review): training must encode labels with the same mapping - confirm
# against train_bert.py.
ID_TO_LABEL = {i: label for i, label in enumerate(CATEGORY_MAP.keys())}
LABEL_TO_ID = {label: i for i, label in enumerate(CATEGORY_MAP.keys())}


class BertClassifier:
    """BERT-based text classifier. Call `load_model` before `predict`."""

    def __init__(self, model_name='bert-base-chinese', num_labels=10):
        self.model_name = model_name
        self.num_labels = num_labels
        self.tokenizer = None
        self.model = None
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def load_model(self, model_path):
        """Load the fine-tuned tokenizer and model from `model_path` and switch to eval mode."""
        self.tokenizer = BertTokenizer.from_pretrained(model_path)
        self.model = BertForSequenceClassification.from_pretrained(
            model_path,
            num_labels=self.num_labels
        )
        self.model.to(self.device)
        self.model.eval()
        print(f"BERT模型加载成功: {model_path}")

    def predict(self, title: str, content: str) -> Dict[str, Any]:
        """
        Classify one article.

        :return: dict with `categoryCode`, `categoryName`, `confidence`
                 and `probabilities`
        :raises ValueError: if `load_model` has not been called
        """
        if self.model is None or self.tokenizer is None:
            raise ValueError("模型未加载,请先调用load_model")

        # Combine title and content into one input sequence
        text = f"{title} [SEP] {content}"

        # Tokenize (truncated / padded to 512 tokens)
        inputs = self.tokenizer(
            text,
            return_tensors='pt',
            truncation=True,
            max_length=512,
            padding='max_length'
        )

        # Forward pass without gradient tracking
        with torch.no_grad():
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            outputs = self.model(**inputs)
            logits = outputs.logits

        # Most probable class and its probability
        probs = torch.softmax(logits, dim=-1)
        confidence, predicted_id = torch.max(probs, dim=-1)
        predicted_id = predicted_id.item()
        confidence = confidence.item()

        # Per-class probabilities
        prob_dict = {}
        for i, prob in enumerate(probs[0].cpu().numpy()):
            category_code = ID_TO_LABEL[i]
            prob_dict[category_code] = float(prob)

        return {
            'categoryCode': ID_TO_LABEL[predicted_id],
            'categoryName': CATEGORY_MAP.get(ID_TO_LABEL[predicted_id], '未知'),
            'confidence': confidence,
            'probabilities': prob_dict
        }


class NewsDataset(torch.utils.data.Dataset):
    """Dataset that tokenizes each text on access and pairs it with its label."""

    def __init__(self, texts, labels, tokenizer, max_length=512):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]

        encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            padding='max_length',
            return_tensors='pt'
        )

        # flatten() drops the batch dimension added by return_tensors='pt'
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }


if __name__ == '__main__':
    # Smoke test
    classifier = BertClassifier()
    # classifier.load_model('./models/deep_learning/bert_finetuned')
    #
    # result = classifier.predict(
    #     title="华为发布新款折叠屏手机",
    #     content="华为今天正式发布了新一代折叠屏手机..."
    # )
    # print(result)
    print("BERT分类器初始化成功")
```

#### 任务 4.1.3: `src/hybrid/hybrid_classifier.py` - 混合分类器

```python
"""
Hybrid classifier combining a keyword rule engine with the ML models.

This module uses package-relative imports, so run it as part of the `src`
package (e.g. `python -m src.hybrid.hybrid_classifier` from ml-module/).
"""
import time
from typing import Any, Dict, Optional, Tuple

from ..traditional.predict import CATEGORY_MAP, TraditionalPredictor
from ..deep_learning.bert_model import BertClassifier


class HybridClassifier:
    """Decides between rule match, traditional model and (later) BERT."""

    def __init__(self):
        # Underlying classifiers
        self.nb_predictor = TraditionalPredictor('nb')
        self.bert_classifier = BertClassifier()

        # Decision thresholds
        self.config = {
            'confidence_threshold': 0.75,     # high-confidence threshold
            'hybrid_min_confidence': 0.60,    # minimum confidence in hybrid mode
            'use_bert_threshold': 0.70,       # threshold for falling back to BERT
            'rule_priority': True             # rules take precedence
        }

        # Keywords per category used by the rule engine
        self.rule_keywords = {
            'POLITICS': ['政府', '政策', '选举', '国务院', '主席', '总理'],
            'FINANCE': ['股市', '经济', '金融', '投资', '基金', '银行'],
            'TECHNOLOGY': ['芯片', 'AI', '人工智能', '5G', '互联网', '科技'],
            'SPORTS': ['比赛', '冠军', '联赛', '球员', '教练', 'NBA'],
            'ENTERTAINMENT': ['明星', '电影', '电视剧', '娱乐圈', '歌手'],
            'HEALTH': ['健康', '医疗', '疾病', '治疗', '疫苗'],
            'EDUCATION': ['教育', '学校', '大学', '考试', '招生'],
            'LIFE': ['生活', '美食', '旅游', '购物'],
            'INTERNATIONAL': ['国际', '美国', '欧洲', '日本', '外交'],
            'MILITARY': ['军事', '武器', '军队', '国防', '战争']
        }

    def rule_match(self, title: str, content: str) -> Tuple[Optional[str], float]:
        """
        Keyword rule matching.

        :return: (category_code, confidence); (None, 0.0) when no keyword matches
        """
        text = title + ' ' + content

        # Number of distinct keywords of each category found in the text
        matches = {}
        for category, keywords in self.rule_keywords.items():
            count = sum(1 for kw in keywords if kw in text)
            if count > 0:
                matches[category] = count

        if not matches:
            return None, 0.0

        # Category with the most keyword hits; 0.15 per hit, capped at 0.9
        best_category = max(matches, key=matches.get)
        confidence = min(0.9, matches[best_category] * 0.15)

        return best_category, confidence

    def predict(self, title: str, content: str, use_bert=True) -> Dict[str, Any]:
        """
        Hybrid prediction.

        :return: result dict including `classifierType`, `reason` and
                 `duration` (milliseconds)
        """
        start_time = time.time()

        # 1. Rule matching
        rule_category, rule_confidence = self.rule_match(title, content)

        # 2. Traditional model prediction
        nb_result = self.nb_predictor.predict(title, content)
        nb_confidence = nb_result['confidence']

        # Rules win when enabled and confident enough
        if self.config['rule_priority'] and rule_confidence >= self.config['confidence_threshold']:
            final_result = {
                'categoryCode': rule_category,
                # Name must belong to the rule's category, not to the NB prediction
                'categoryName': CATEGORY_MAP.get(rule_category, '未知'),
                'confidence': rule_confidence,
                'classifierType': 'RULE',
                'reason': '规则匹配'
            }
        # Traditional model is confident enough
        elif nb_confidence >= self.config['confidence_threshold']:
            final_result = {
                **nb_result,
                'classifierType': 'ML',
                'reason': '传统模型高置信度'
            }
        # BERT fallback requested
        elif use_bert:
            # TODO: run the BERT model and pick the most confident result
            # bert_result = self.bert_classifier.predict(title, content)
            final_result = {
                **nb_result,
                'classifierType': 'HYBRID',
                'reason': '混合决策'
            }
        else:
            # BERT disabled: return the traditional model result as is
            final_result = {
                **nb_result,
                'classifierType': 'ML',
                'reason': '默认传统模型'
            }

        # Elapsed time in milliseconds
        duration = int((time.time() - start_time) * 1000)
        final_result['duration'] = duration

        return final_result


if __name__ == '__main__':
    # Smoke test
    classifier = HybridClassifier()

    test_cases = [
        {
            'title': '国务院发布最新经济政策',
            'content': '国务院今天发布了新的经济政策...'
        },
        {
            'title': '华为发布新款折叠屏手机',
            'content': '华为今天正式发布了新一代折叠屏手机...'
        }
    ]

    for case in test_cases:
        result = classifier.predict(case['title'], case['content'])
        print(f"标题: {case['title']}")
        print(f"结果: {result['categoryName']} ({result['confidence']:.2f})")
        print(f"分类器: {result['classifierType']}")
        print(f"原因: {result.get('reason', 'N/A')}")
        print(f"耗时: {result['duration']}ms")
        print("-" * 50)
```

#### 任务 4.1.4: `src/api/server.py` - FastAPI服务

```python
"""
HTTP API for the classification models, built on FastAPI.
"""
import logging
import sys
import time
from pathlib import Path
from typing import Optional

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

# Put the project root (ml-module/) on sys.path and import through the `src`
# package. hybrid_classifier.py uses relative imports (`..traditional`), which
# only resolve when `src` is the top-level package.
PROJECT_ROOT = Path(__file__).resolve().parents[2]
if str(PROJECT_ROOT) not in sys.path:
    sys.path.append(str(PROJECT_ROOT))

from src.traditional.predict import TraditionalPredictor  # noqa: E402
from src.hybrid.hybrid_classifier import HybridClassifier  # noqa: E402

# Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI application
app = FastAPI(
    title="新闻分类API",
    description="提供新闻文本分类服务",
    version="1.0.0"
)

# CORS. A wildcard origin cannot be combined with credentialed requests, and
# this API uses no cookies, so credentials are disabled.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Request model
class ClassifyRequest(BaseModel):
    title: str
    content: str
    mode: Optional[str] = 'hybrid'  # traditional, hybrid


# Response model
class ClassifyResponse(BaseModel):
    categoryCode: str
    categoryName: str
    confidence: float
    classifierType: str
    duration: int
    probabilities: Optional[dict] = None


# Classifiers, populated at startup; None means loading failed
nb_predictor = None
hybrid_classifier = None


@app.on_event("startup")
async def startup_event():
    """启动时加载模型"""
    global nb_predictor, hybrid_classifier

    logger.info("加载模型...")

    try:
        nb_predictor = TraditionalPredictor('nb')
        logger.info("朴素贝叶斯模型加载成功")
    except Exception as e:
        logger.error(f"朴素贝叶斯模型加载失败: {e}")

    try:
        hybrid_classifier = HybridClassifier()
        logger.info("混合分类器初始化成功")
    except Exception as e:
        logger.error(f"混合分类器初始化失败: {e}")


def _classify(req: ClassifyRequest) -> dict:
    """
    Run one classification request and return the raw result dict.

    Raises HTTPException(503) when the classifier needed for the requested
    mode failed to load at startup.
    """
    if req.mode == 'traditional':
        if nb_predictor is None:
            raise HTTPException(status_code=503, detail="朴素贝叶斯模型未加载")
        start_time = time.time()
        result = nb_predictor.predict(req.title, req.content)
        result['classifierType'] = 'ML'
        # ClassifyResponse requires `duration`; the traditional predictor does
        # not report it, so measure it here (milliseconds).
        result['duration'] = int((time.time() - start_time) * 1000)
        return result

    # Any other mode is treated as hybrid
    if hybrid_classifier is None:
        raise HTTPException(status_code=503, detail="混合分类器未加载")
    return hybrid_classifier.predict(req.title, req.content)


@app.get("/")
async def root():
    """健康检查"""
    return {
        "status": "ok",
        "message": "新闻分类API服务运行中"
    }


@app.get("/health")
async def health_check():
    """健康检查"""
    return {
        "status": "healthy",
        "models": {
            "nb_loaded": nb_predictor is not None,
            "hybrid_loaded": hybrid_classifier is not None
        }
    }


@app.post("/api/predict", response_model=ClassifyResponse)
async def predict(request: ClassifyRequest):
    """
    文本分类接口

    - **title**: 新闻标题
    - **content**: 新闻内容
    - **mode**: 分类模式 (traditional, hybrid)
    """
    try:
        return ClassifyResponse(**_classify(request))
    except HTTPException:
        # Keep the intended status code (e.g. 503) instead of re-wrapping as 500
        raise
    except Exception as e:
        logger.error(f"预测失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/api/batch-predict")
async def batch_predict(requests: list[ClassifyRequest]):
    """
    批量分类接口
    """
    results = []
    for req in requests:
        try:
            results.append(_classify(req))
        except Exception as e:
            # One failing item must not abort the whole batch
            results.append({
                'error': str(e),
                'title': req.title
            })

    return {"results": results}


if __name__ == '__main__':
    import uvicorn
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=5000,
        log_level="info"
    )
```

#### 任务 4.1.5: `src/utils/metrics.py` - 评估指标

```python
"""
Model evaluation metric helpers.
"""
import numpy as np
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    confusion_matrix,
    classification_report
)
from typing import List, Dict, Any, Optional
import matplotlib.pyplot as plt
import seaborn as sns


class ClassificationMetrics:
    """Classification evaluation metrics."""

    @staticmethod
    def compute_all(y_true: List, y_pred: List, labels: List[str]) -> Dict[str, Any]:
        """
        Compute accuracy, weighted precision/recall/F1 and per-class metrics.

        :param labels: class labels; per-class results are keyed by these
        """
        accuracy = accuracy_score(y_true, y_pred)
        precision, recall, f1, support = precision_recall_fscore_support(
            y_true, y_pred, average='weighted', zero_division=0
        )

        # Pass `labels` explicitly: without it scikit-learn orders the
        # per-class arrays by sorted label, which would not line up with the
        # caller's `labels` order and would attribute scores to the wrong class.
        precision_per_class, recall_per_class, f1_per_class, support_per_class = \
            precision_recall_fscore_support(
                y_true, y_pred, labels=labels, average=None, zero_division=0
            )

        per_class_metrics = {}
        for i, label in enumerate(labels):
            per_class_metrics[label] = {
                'precision': float(precision_per_class[i]),
                'recall': float(recall_per_class[i]),
                'f1': float(f1_per_class[i]),
                'support': int(support_per_class[i])
            }

        return {
            'accuracy': float(accuracy),
            'precision': float(precision),
            'recall': float(recall),
            'f1': float(f1),
            'per_class': per_class_metrics
        }

    @staticmethod
    def plot_confusion_matrix(y_true: List, y_pred: List, labels: List[str], save_path: Optional[str] = None):
        """
        Draw the confusion matrix; saved to `save_path` when one is given.
        """
        # `labels` fixes the row/column order so it matches the tick labels
        cm = confusion_matrix(y_true, y_pred, labels=labels)

        plt.figure(figsize=(10, 8))
        sns.heatmap(
            cm,
            annot=True,
            fmt='d',
            cmap='Blues',
            xticklabels=labels,
            yticklabels=labels
        )
        plt.xlabel('预测标签')
        plt.ylabel('真实标签')
        plt.title('混淆矩阵')

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')

        plt.close()

    @staticmethod
    def print_report(y_true: List, y_pred: List, labels: List[str]):
        """
        Print the scikit-learn classification report.
        """
        # `labels` and `target_names` must describe the same order
        report = classification_report(
            y_true, y_pred,
            labels=labels,
            target_names=labels,
            zero_division=0
        )
        print(report)


if __name__ == '__main__':
    # Smoke test
    y_true = ['POLITICS', 'TECHNOLOGY', 'FINANCE', 'POLITICS', 'TECHNOLOGY']
    y_pred = ['POLITICS', 'TECHNOLOGY', 'FINANCE', 'TECHNOLOGY', 'TECHNOLOGY']
    labels = ['POLITICS', 'TECHNOLOGY', 'FINANCE']

    metrics = ClassificationMetrics()
    result = metrics.compute_all(y_true, y_pred, labels)
    print(result)
```

#### 任务 4.1.6: `requirements.txt` - 依赖文件

```txt
# 机器学习模块依赖
numpy>=1.24.0
pandas>=2.0.0
scikit-learn>=1.3.0
jieba>=0.42.0
joblib>=1.3.0

# 深度学习
torch>=2.0.0
transformers>=4.30.0

# API服务
fastapi>=0.100.0
uvicorn[standard]>=0.23.0
pydantic>=2.0.0

# 数据可视化
matplotlib>=3.7.0
seaborn>=0.12.0

# 工具
python-dotenv>=1.0.0
pyyaml>=6.0
```

---

## 总结

### 开发顺序建议

1. **第一阶段:基础框架**
   - 后端:数据库连接、实体类、基础配置
   - 前端:路由配置、状态管理、API封装
2. **第二阶段:核心功能**
   - 爬虫模块(Python)
   - 传统机器学习分类器
   - 后端API接口
   - 前端新闻列表页面
3. **第三阶段:高级功能**
   - BERT深度学习分类器
   - 混合策略分类器
   - 前端分类器对比页面
   - 统计图表
4. 
**第四阶段:完善优化**
   - 用户认证
   - 数据可视化
   - 性能优化
   - 异常处理

### 关键注意事项

1. **爬虫模块使用 Python**,通过 RESTful API 与 Java 后端通信
2. **分类器模块独立部署**,提供 HTTP 接口供后端调用
3. **前后端分离**,使用 JWT 进行身份认证
4. **数据库表结构**已在 `schema.sql` 中定义,需严格遵守
5. **API 统一响应格式**使用 `Result` 包装