# News Text Classification System - Module Development Task List

---

## Table of Contents

1. [Crawler Module (Python)](#1-crawler-module-python)
2. [Backend Service Module (Spring Boot)](#2-backend-service-module-spring-boot)
3. [Frontend Desktop Module (Tauri + Vue3)](#3-frontend-desktop-module-tauri--vue3)
4. [Machine Learning Classification Module (Python)](#4-machine-learning-classification-module-python)

---

## 1. Crawler Module (Python)

## 2. Backend Service Module (Spring Boot)

## 3. Frontend Desktop Module (Tauri + Vue3)

## 4. Machine Learning Classification Module (Python)

### Module Directory Structure

```
ml-module/
├── data/
│   ├── raw/                      # Raw data
│   ├── processed/                # Processed data
│   │   ├── training_data.csv
│   │   └── test_data.csv
│   └── external/                 # External datasets
├── models/                       # Trained models
│   ├── traditional/
│   │   ├── nb_vectorizer.pkl
│   │   ├── nb_classifier.pkl
│   │   ├── svm_vectorizer.pkl
│   │   └── svm_classifier.pkl
│   ├── deep_learning/
│   │   └── bert_finetuned/
│   └── hybrid/
│       └── config.json
├── src/
│   ├── __init__.py
│   ├── traditional/              # Traditional machine learning
│   │   ├── __init__.py
│   │   ├── train_model.py        # (already implemented)
│   │   ├── predict.py
│   │   └── evaluate.py
│   ├── deep_learning/            # Deep learning
│   │   ├── __init__.py
│   │   ├── bert_model.py
│   │   ├── train_bert.py
│   │   └── predict_bert.py
│   ├── hybrid/                   # Hybrid strategy
│   │   ├── __init__.py
│   │   ├── hybrid_classifier.py
│   │   └── rule_engine.py
│   ├── utils/
│   │   ├── __init__.py
│   │   ├── preprocessing.py      # Data preprocessing
│   │   └── metrics.py            # Evaluation metrics
│   └── api/                      # API service
│       ├── __init__.py
│       └── server.py             # FastAPI service
├── notebooks/                    # Jupyter notebooks
│   ├── data_exploration.ipynb
│   └── model_comparison.ipynb
├── tests/                        # Tests
│   ├── test_traditional.py
│   ├── test_bert.py
│   └── test_hybrid.py
├── requirements.txt
├── setup.py
└── README.md
```

### 4.1 Specific Files to Complete

#### Task 4.1.1: `src/traditional/predict.py` - Traditional Model Prediction

```python
"""
Traditional machine learning model prediction.
"""

import os
import joblib
import jieba
from typing import Dict, Any

# Category code -> display name mapping
CATEGORY_MAP = {
    'POLITICS': '时政',
    'FINANCE': '财经',
    'TECHNOLOGY': '科技',
    'SPORTS': '体育',
    'ENTERTAINMENT': '娱乐',
    'HEALTH': '健康',
    'EDUCATION': '教育',
    'LIFE': '生活',
    'INTERNATIONAL': '国际',
    'MILITARY': '军事'
}


class TraditionalPredictor:
    """Traditional machine learning predictor."""

    def __init__(self, model_type='nb', model_dir='../../models/traditional'):
        self.model_type = model_type
        self.model_dir = model_dir
        self.vectorizer = None
        self.classifier = None
        self._load_model()

    def _load_model(self):
        """Load the serialized vectorizer and classifier."""
        vectorizer_path = os.path.join(self.model_dir, f'{self.model_type}_vectorizer.pkl')
        classifier_path = os.path.join(self.model_dir, f'{self.model_type}_classifier.pkl')

        self.vectorizer = joblib.load(vectorizer_path)
        self.classifier = joblib.load(classifier_path)
        print(f"Model loaded: {self.model_type}")

    def preprocess(self, title: str, content: str) -> str:
        """Preprocess text: concatenate title and content, then segment with jieba."""
        text = title + ' ' + content
        # jieba word segmentation
        words = jieba.cut(text)
        return ' '.join(words)

    def predict(self, title: str, content: str) -> Dict[str, Any]:
        """
        Predict the category of a single news item.
        :return: prediction result dictionary
        """
        # Preprocess
        processed = self.preprocess(title, content)

        # Feature extraction
        tfidf = self.vectorizer.transform([processed])

        # Predict
        prediction = self.classifier.predict(tfidf)[0]
        probabilities = self.classifier.predict_proba(tfidf)[0]

        # Per-category probabilities
        prob_dict = {}
        for i, prob in enumerate(probabilities):
            category_code = self.classifier.classes_[i]
            prob_dict[category_code] = float(prob)

        return {
            'categoryCode': prediction,
            'categoryName': CATEGORY_MAP.get(prediction, '未知'),
            'confidence': float(probabilities.max()),
            'probabilities': prob_dict
        }


# API entry point
def predict_single(title: str, content: str, model_type='nb') -> Dict[str, Any]:
    """Single-item prediction API."""
    predictor = TraditionalPredictor(model_type)
    return predictor.predict(title, content)


if __name__ == '__main__':
    # Quick test
    result = predict_single(
        title="华为发布新款折叠屏手机",
        content="华为今天正式发布了新一代折叠屏手机,搭载最新麒麟芯片..."
    )
    print(result)
```
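
The tree above also lists `src/traditional/evaluate.py`, which no task spells out. A minimal sketch is given below, assuming the test set at `data/processed/test_data.csv` has `title`, `content`, and `category` columns (the column names are assumptions); it reuses `TraditionalPredictor` and the `ClassificationMetrics` helper from Task 4.1.5.

```python
"""
Evaluate a traditional model on the held-out test set.
Sketch only: the CSV column names (title, content, category) are assumptions.
Run as a module from the project root, e.g. python -m src.traditional.evaluate
"""

import pandas as pd

from .predict import TraditionalPredictor
from ..utils.metrics import ClassificationMetrics


def evaluate(model_type: str = 'nb',
             test_csv: str = 'data/processed/test_data.csv') -> dict:
    """Run the predictor over the test CSV and compute aggregate metrics."""
    df = pd.read_csv(test_csv)
    predictor = TraditionalPredictor(model_type)

    # Assumed: the 'category' column holds codes such as 'POLITICS'
    y_true = df['category'].tolist()
    y_pred = [
        predictor.predict(row['title'], row['content'])['categoryCode']
        for _, row in df.iterrows()
    ]

    labels = sorted(set(y_true))
    metrics = ClassificationMetrics.compute_all(y_true, y_pred, labels)
    ClassificationMetrics.print_report(y_true, y_pred, labels)
    return metrics


if __name__ == '__main__':
    print(evaluate())
```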

#### Task 4.1.2: `src/deep_learning/bert_model.py` - BERT Model

```python
"""
BERT text classification model.
"""

import torch
from transformers import BertTokenizer, BertForSequenceClassification
from typing import Dict, Any


# Category code -> display name mapping
CATEGORY_MAP = {
    'POLITICS': '时政',
    'FINANCE': '财经',
    'TECHNOLOGY': '科技',
    'SPORTS': '体育',
    'ENTERTAINMENT': '娱乐',
    'HEALTH': '健康',
    'EDUCATION': '教育',
    'LIFE': '生活',
    'INTERNATIONAL': '国际',
    'MILITARY': '军事'
}

# Index <-> label mappings (dict insertion order is stable in Python 3.7+)
ID_TO_LABEL = {i: label for i, label in enumerate(CATEGORY_MAP.keys())}
LABEL_TO_ID = {label: i for i, label in enumerate(CATEGORY_MAP.keys())}


class BertClassifier:
    """BERT text classifier."""

    def __init__(self, model_name='bert-base-chinese', num_labels=10):
        self.model_name = model_name
        self.num_labels = num_labels
        self.tokenizer = None
        self.model = None
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def load_model(self, model_path):
        """Load a fine-tuned model from disk."""
        self.tokenizer = BertTokenizer.from_pretrained(model_path)
        self.model = BertForSequenceClassification.from_pretrained(
            model_path,
            num_labels=self.num_labels
        )
        self.model.to(self.device)
        self.model.eval()
        print(f"BERT model loaded: {model_path}")

    def predict(self, title: str, content: str) -> Dict[str, Any]:
        """Predict the category of a single news item."""
        if self.model is None or self.tokenizer is None:
            raise ValueError("Model not loaded; call load_model first")

        # Combine title and content
        text = f"{title} [SEP] {content}"

        # Tokenize
        inputs = self.tokenizer(
            text,
            return_tensors='pt',
            truncation=True,
            max_length=512,
            padding='max_length'
        )

        # Inference
        with torch.no_grad():
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            outputs = self.model(**inputs)
            logits = outputs.logits

        # Extract the prediction
        probs = torch.softmax(logits, dim=-1)
        confidence, predicted_id = torch.max(probs, dim=-1)

        predicted_id = predicted_id.item()
        confidence = confidence.item()

        # Per-category probabilities
        prob_dict = {}
        for i, prob in enumerate(probs[0].cpu().numpy()):
            category_code = ID_TO_LABEL[i]
            prob_dict[category_code] = float(prob)

        return {
            'categoryCode': ID_TO_LABEL[predicted_id],
            'categoryName': CATEGORY_MAP.get(ID_TO_LABEL[predicted_id], '未知'),
            'confidence': confidence,
            'probabilities': prob_dict
        }


# Dataset class
class NewsDataset(torch.utils.data.Dataset):
    """News dataset for fine-tuning."""

    def __init__(self, texts, labels, tokenizer, max_length=512):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]

        encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            padding='max_length',
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }


if __name__ == '__main__':
    # Quick test
    classifier = BertClassifier()
    # classifier.load_model('./models/deep_learning/bert_finetuned')
    #
    # result = classifier.predict(
    #     title="华为发布新款折叠屏手机",
    #     content="华为今天正式发布了新一代折叠屏手机..."
    # )
    # print(result)
    print("BertClassifier initialized")
```
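
The tree also lists `src/deep_learning/train_bert.py`, which is where the `Trainer` and `TrainingArguments` imports belong. A minimal fine-tuning sketch using the Hugging Face `Trainer` API follows; the training CSV layout (`text` and `label` columns) and the hyperparameters are assumptions, not fixed requirements.

```python
"""
Fine-tune bert-base-chinese on the news dataset.
Sketch only: CSV column names and hyperparameters are assumptions.
Run from src/deep_learning/ so that bert_model is importable.
"""

import pandas as pd
from transformers import (
    BertTokenizer,
    BertForSequenceClassification,
    Trainer,
    TrainingArguments
)

from bert_model import NewsDataset, LABEL_TO_ID


def main():
    # Assumed layout: a 'text' column (title + content) and a 'label'
    # column holding category codes such as 'POLITICS'.
    df = pd.read_csv('../../data/processed/training_data.csv')
    texts = df['text'].tolist()
    labels = [LABEL_TO_ID[code] for code in df['label']]

    tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
    model = BertForSequenceClassification.from_pretrained(
        'bert-base-chinese', num_labels=len(LABEL_TO_ID)
    )

    dataset = NewsDataset(texts, labels, tokenizer, max_length=512)

    args = TrainingArguments(
        output_dir='../../models/deep_learning/bert_finetuned',
        num_train_epochs=3,
        per_device_train_batch_size=16,
        learning_rate=2e-5,
        logging_steps=50,
        save_strategy='epoch'
    )

    trainer = Trainer(model=model, args=args, train_dataset=dataset)
    trainer.train()

    # Save the tokenizer alongside the model so load_model() finds both
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)


if __name__ == '__main__':
    main()
```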

#### Task 4.1.3: `src/hybrid/hybrid_classifier.py` - Hybrid Classifier

```python
"""
Hybrid strategy classifier.
Combines a rule engine with machine learning models.
"""

import time
from typing import Dict, Any, Optional, Tuple
from ..traditional.predict import TraditionalPredictor, CATEGORY_MAP
from ..deep_learning.bert_model import BertClassifier


class HybridClassifier:
    """Hybrid classifier."""

    def __init__(self):
        # Initialize the individual classifiers
        self.nb_predictor = TraditionalPredictor('nb')
        self.bert_classifier = BertClassifier()

        # Configuration parameters
        self.config = {
            'confidence_threshold': 0.75,   # high-confidence threshold
            'hybrid_min_confidence': 0.60,  # minimum threshold in hybrid mode
            'use_bert_threshold': 0.70,     # threshold for falling back to BERT
            'rule_priority': True           # rules take precedence
        }

        # Rule keyword dictionary
        self.rule_keywords = {
            'POLITICS': ['政府', '政策', '选举', '国务院', '主席', '总理'],
            'FINANCE': ['股市', '经济', '金融', '投资', '基金', '银行'],
            'TECHNOLOGY': ['芯片', 'AI', '人工智能', '5G', '互联网', '科技'],
            'SPORTS': ['比赛', '冠军', '联赛', '球员', '教练', 'NBA'],
            'ENTERTAINMENT': ['明星', '电影', '电视剧', '娱乐圈', '歌手'],
            'HEALTH': ['健康', '医疗', '疾病', '治疗', '疫苗'],
            'EDUCATION': ['教育', '学校', '大学', '考试', '招生'],
            'LIFE': ['生活', '美食', '旅游', '购物'],
            'INTERNATIONAL': ['国际', '美国', '欧洲', '日本', '外交'],
            'MILITARY': ['军事', '武器', '军队', '国防', '战争']
        }

    def rule_match(self, title: str, content: str) -> Tuple[Optional[str], float]:
        """
        Keyword rule matching.
        :return: (category_code, confidence)
        """
        text = title + ' ' + content

        # Count keyword hits per category
        matches = {}
        for category, keywords in self.rule_keywords.items():
            count = sum(1 for kw in keywords if kw in text)
            if count > 0:
                matches[category] = count

        if not matches:
            return None, 0.0

        # Return the category with the most hits
        best_category = max(matches, key=matches.get)
        confidence = min(0.9, matches[best_category] * 0.15)  # rule confidence

        return best_category, confidence

    def predict(self, title: str, content: str, use_bert=True) -> Dict[str, Any]:
        """Hybrid prediction."""
        start_time = time.time()

        # 1. Try rule matching first
        rule_category, rule_confidence = self.rule_match(title, content)

        # 2. Traditional machine learning prediction
        nb_result = self.nb_predictor.predict(title, content)
        nb_confidence = nb_result['confidence']

        # Decision logic
        # Rules take precedence when rule confidence is high
        if self.config['rule_priority'] and rule_confidence >= self.config['confidence_threshold']:
            final_result = {
                'categoryCode': rule_category,
                'categoryName': CATEGORY_MAP.get(rule_category, '未知'),  # name of the matched rule's category
                'confidence': rule_confidence,
                'classifierType': 'RULE',
                'reason': 'rule match'
            }
        # The traditional model is confident enough
        elif nb_confidence >= self.config['confidence_threshold']:
            final_result = {
                **nb_result,
                'classifierType': 'ML',
                'reason': 'traditional model, high confidence'
            }
        # Fall back to BERT
        elif use_bert:
            # TODO: load the BERT model and predict
            # bert_result = self.bert_classifier.predict(title, content)
            # If BERT is also unsure, pick the highest-confidence result
            final_result = {
                **nb_result,
                'classifierType': 'HYBRID',
                'reason': 'hybrid decision'
            }
        else:
            # BERT disabled: return the traditional model result directly
            final_result = {
                **nb_result,
                'classifierType': 'ML',
                'reason': 'traditional model by default'
            }

        # Elapsed time
        duration = int((time.time() - start_time) * 1000)
        final_result['duration'] = duration

        return final_result


if __name__ == '__main__':
    # Quick test
    classifier = HybridClassifier()

    test_cases = [
        {
            'title': '国务院发布最新经济政策',
            'content': '国务院今天发布了新的经济政策...'
        },
        {
            'title': '华为发布新款折叠屏手机',
            'content': '华为今天正式发布了新一代折叠屏手机...'
        }
    ]

    for case in test_cases:
        result = classifier.predict(case['title'], case['content'])
        print(f"Title: {case['title']}")
        print(f"Result: {result['categoryName']} ({result['confidence']:.2f})")
        print(f"Classifier: {result['classifierType']}")
        print(f"Reason: {result.get('reason', 'N/A')}")
        print(f"Duration: {result['duration']}ms")
        print("-" * 50)
```
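
`src/hybrid/rule_engine.py` also appears in the tree but has no task of its own. One option, sketched below, is to pull the keyword-matching logic out of `HybridClassifier` into a standalone `RuleEngine` whose keyword table can be loaded from `models/hybrid/config.json`; the JSON shape shown is an assumption.

```python
"""
Standalone rule engine for the hybrid classifier.
Sketch only: the config.json shape ({"rule_keywords": {code: [kw, ...]}}) is an assumption.
"""

import json
from typing import Dict, List, Optional, Tuple


class RuleEngine:
    """Keyword-based rule matcher, decoupled from the ML models."""

    def __init__(self, keywords: Dict[str, List[str]],
                 per_hit_score: float = 0.15, max_confidence: float = 0.9):
        self.keywords = keywords
        self.per_hit_score = per_hit_score
        self.max_confidence = max_confidence

    @classmethod
    def from_config(cls, path: str = '../../models/hybrid/config.json') -> 'RuleEngine':
        """Load the keyword table from a JSON config file."""
        with open(path, encoding='utf-8') as f:
            config = json.load(f)
        return cls(config['rule_keywords'])

    def match(self, title: str, content: str) -> Tuple[Optional[str], float]:
        """Return (category_code, confidence), or (None, 0.0) when nothing matches."""
        text = title + ' ' + content
        matches = {
            category: sum(1 for kw in kws if kw in text)
            for category, kws in self.keywords.items()
        }
        matches = {c, } if False else {c: n for c, n in matches.items() if n > 0}
        if not matches:
            return None, 0.0

        # Same scoring as HybridClassifier.rule_match: capped linear score per hit
        best = max(matches, key=matches.get)
        confidence = min(self.max_confidence, matches[best] * self.per_hit_score)
        return best, confidence
```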

#### Task 4.1.4: `src/api/server.py` - FastAPI Service

```python
"""
Machine learning model API service.
Exposes a RESTful API via FastAPI.
"""

import time
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional
import logging

# Make the sibling packages importable when run as a script
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from traditional.predict import TraditionalPredictor
from hybrid.hybrid_classifier import HybridClassifier

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Create the FastAPI application
app = FastAPI(
    title="News Classification API",
    description="News text classification service",
    version="1.0.0"
)

# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Request model
class ClassifyRequest(BaseModel):
    title: str
    content: str
    mode: Optional[str] = 'hybrid'  # traditional, hybrid


# Response model
class ClassifyResponse(BaseModel):
    categoryCode: str
    categoryName: str
    confidence: float
    classifierType: str
    duration: int
    probabilities: Optional[dict] = None


# Classifier instances, populated at startup
nb_predictor = None
hybrid_classifier = None


@app.on_event("startup")
async def startup_event():
    """Load the models at startup."""
    global nb_predictor, hybrid_classifier

    logger.info("Loading models...")

    try:
        nb_predictor = TraditionalPredictor('nb')
        logger.info("Naive Bayes model loaded")
    except Exception as e:
        logger.error(f"Failed to load Naive Bayes model: {e}")

    try:
        hybrid_classifier = HybridClassifier()
        logger.info("Hybrid classifier initialized")
    except Exception as e:
        logger.error(f"Failed to initialize hybrid classifier: {e}")


def _classify(title: str, content: str, mode: str) -> dict:
    """Dispatch to the requested classifier; the traditional path needs its
    own timing since ClassifyResponse requires a duration field."""
    if mode == 'traditional':
        start = time.time()
        result = nb_predictor.predict(title, content)
        result['classifierType'] = 'ML'
        result['duration'] = int((time.time() - start) * 1000)
    else:  # hybrid (sets its own duration)
        result = hybrid_classifier.predict(title, content)
    return result


@app.get("/")
async def root():
    """Liveness check."""
    return {
        "status": "ok",
        "message": "News classification API is running"
    }


@app.get("/health")
async def health_check():
    """Health check."""
    return {
        "status": "healthy",
        "models": {
            "nb_loaded": nb_predictor is not None,
            "hybrid_loaded": hybrid_classifier is not None
        }
    }


@app.post("/api/predict", response_model=ClassifyResponse)
async def predict(request: ClassifyRequest):
    """
    Text classification endpoint.

    - **title**: news title
    - **content**: news content
    - **mode**: classification mode (traditional, hybrid)
    """
    try:
        result = _classify(request.title, request.content, request.mode)
        return ClassifyResponse(**result)
    except Exception as e:
        logger.error(f"Prediction failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/api/batch-predict")
async def batch_predict(requests: list[ClassifyRequest]):
    """Batch classification endpoint."""
    results = []
    for req in requests:
        try:
            results.append(_classify(req.title, req.content, req.mode))
        except Exception as e:
            results.append({
                'error': str(e),
                'title': req.title
            })

    return {"results": results}


if __name__ == '__main__':
    import uvicorn

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=5000,
        log_level="info"
    )
```
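
For a quick smoke test of the running service, a client call might look like the sketch below (Python with `requests`, which is not listed in requirements.txt; the Java backend would issue the equivalent HTTP request):

```python
# Smoke test against a locally running server (python src/api/server.py).
# Note: `requests` is not in requirements.txt; install it separately.
import requests

resp = requests.post(
    'http://localhost:5000/api/predict',
    json={
        'title': '华为发布新款折叠屏手机',
        'content': '华为今天正式发布了新一代折叠屏手机...',
        'mode': 'hybrid'
    },
    timeout=10
)
resp.raise_for_status()
# Expected shape: {'categoryCode': ..., 'categoryName': ...,
#                  'confidence': ..., 'classifierType': ..., 'duration': ...}
print(resp.json())
```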

#### Task 4.1.5: `src/utils/metrics.py` - Evaluation Metrics

```python
"""
Model evaluation metric utilities.
"""

from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    confusion_matrix,
    classification_report
)
from typing import List, Dict, Any, Optional
import matplotlib.pyplot as plt
import seaborn as sns


class ClassificationMetrics:
    """Classification evaluation metrics."""

    @staticmethod
    def compute_all(y_true: List, y_pred: List, labels: List[str]) -> Dict[str, Any]:
        """Compute aggregate and per-class metrics."""
        accuracy = accuracy_score(y_true, y_pred)

        precision, recall, f1, support = precision_recall_fscore_support(
            y_true, y_pred, average='weighted', zero_division=0
        )

        # Per-class metrics; pass labels= so the arrays line up with `labels`
        precision_per_class, recall_per_class, f1_per_class, support_per_class = \
            precision_recall_fscore_support(
                y_true, y_pred, labels=labels, average=None, zero_division=0
            )

        per_class_metrics = {}
        for i, label in enumerate(labels):
            per_class_metrics[label] = {
                'precision': float(precision_per_class[i]),
                'recall': float(recall_per_class[i]),
                'f1': float(f1_per_class[i]),
                'support': int(support_per_class[i])
            }

        return {
            'accuracy': float(accuracy),
            'precision': float(precision),
            'recall': float(recall),
            'f1': float(f1),
            'per_class': per_class_metrics
        }

    @staticmethod
    def plot_confusion_matrix(y_true: List, y_pred: List, labels: List[str],
                              save_path: Optional[str] = None):
        """Plot the confusion matrix."""
        cm = confusion_matrix(y_true, y_pred, labels=labels)

        plt.figure(figsize=(10, 8))
        sns.heatmap(
            cm,
            annot=True,
            fmt='d',
            cmap='Blues',
            xticklabels=labels,
            yticklabels=labels
        )
        plt.xlabel('Predicted label')
        plt.ylabel('True label')
        plt.title('Confusion Matrix')

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close()

    @staticmethod
    def print_report(y_true: List, y_pred: List, labels: List[str]):
        """Print the classification report."""
        report = classification_report(
            y_true, y_pred,
            labels=labels,
            target_names=labels,
            zero_division=0
        )
        print(report)


if __name__ == '__main__':
    # Quick test
    y_true = ['POLITICS', 'TECHNOLOGY', 'FINANCE', 'POLITICS', 'TECHNOLOGY']
    y_pred = ['POLITICS', 'TECHNOLOGY', 'FINANCE', 'TECHNOLOGY', 'TECHNOLOGY']
    labels = ['POLITICS', 'TECHNOLOGY', 'FINANCE']

    metrics = ClassificationMetrics()
    result = metrics.compute_all(y_true, y_pred, labels)
    print(result)
```

#### Task 4.1.6: `requirements.txt` - Dependencies

```txt
# Machine learning module dependencies
numpy>=1.24.0
pandas>=2.0.0
scikit-learn>=1.3.0
jieba>=0.42.0
joblib>=1.3.0

# Deep learning
torch>=2.0.0
transformers>=4.30.0

# API service
fastapi>=0.100.0
uvicorn[standard]>=0.23.0
pydantic>=2.0.0

# Data visualization
matplotlib>=3.7.0
seaborn>=0.12.0

# Utilities
python-dotenv>=1.0.0
pyyaml>=6.0
```
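
`setup.py` also appears in the directory tree. A minimal sketch follows, assuming the module is installed in editable mode so the packages under `src/` resolve; the package name and metadata are placeholders.

```python
"""
Minimal packaging script for the ML module.
Sketch only: the package name and metadata are placeholders.
"""

from setuptools import setup, find_packages

setup(
    name='ml-module',
    version='0.1.0',
    description='News text classification: traditional ML, BERT, and hybrid classifiers',
    packages=find_packages(where='src'),
    package_dir={'': 'src'},
    python_requires='>=3.9',
    # Runtime dependencies are pinned in requirements.txt;
    # install with: pip install -r requirements.txt && pip install -e .
)
```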

---

## Summary

### Suggested Development Order

1. **Phase 1: Basic framework**
   - Backend: database connection, entity classes, base configuration
   - Frontend: routing, state management, API wrappers

2. **Phase 2: Core features**
   - Crawler module (Python)
   - Traditional machine learning classifier
   - Backend API endpoints
   - Frontend news list page

3. **Phase 3: Advanced features**
   - BERT deep learning classifier
   - Hybrid strategy classifier
   - Frontend classifier comparison page
   - Statistics charts

4. **Phase 4: Polish and optimization**
   - User authentication
   - Data visualization
   - Performance optimization
   - Exception handling

### Key Considerations

1. **The crawler module is written in Python** and communicates with the Java backend over a RESTful API
2. **The classifier module is deployed independently** and exposes an HTTP interface for the backend to call
3. **Frontend and backend are decoupled**, with JWT-based authentication
4. **The database schema** is defined in `schema.sql` and must be followed strictly
5. **The unified API response format** wraps payloads in `Result<T>`