22 KiB
22 KiB
新闻文本分类系统 - 模块开发任务清单
目录
1. 爬虫模块 (Python)
2. 后端服务模块 (Spring Boot)
4. 机器学习分类模块 (Python)
模块目录结构
ml-module/
├── data/
│ ├── raw/ # 原始数据
│ ├── processed/ # 处理后的数据
│ │ ├── training_data.csv
│ │ └── test_data.csv
│ └── external/ # 外部数据集
├── models/ # 训练好的模型
│ ├── traditional/
│ │ ├── nb_vectorizer.pkl
│ │ ├── nb_classifier.pkl
│ │ ├── svm_vectorizer.pkl
│ │ └── svm_classifier.pkl
│ ├── deep_learning/
│ │ └── bert_finetuned/
│ └── hybrid/
│ └── config.json
├── src/
│ ├── __init__.py
│ ├── traditional/ # 传统机器学习
│ │ ├── __init__.py
│ │ ├── train_model.py # (已有)
│ │ ├── predict.py
│ │ └── evaluate.py
│ ├── deep_learning/ # 深度学习
│ │ ├── __init__.py
│ │ ├── bert_model.py
│ │ ├── train_bert.py
│ │ └── predict_bert.py
│ ├── hybrid/ # 混合策略
│ │ ├── __init__.py
│ │ ├── hybrid_classifier.py
│ │ └── rule_engine.py
│ ├── utils/
│ │ ├── __init__.py
│ │ ├── preprocessing.py # 数据预处理
│ │ └── metrics.py # 评估指标
│ └── api/ # API服务
│ ├── __init__.py
│ └── server.py # FastAPI服务
├── notebooks/ # Jupyter notebooks
│ ├── data_exploration.ipynb
│ └── model_comparison.ipynb
├── tests/ # 测试
│ ├── test_traditional.py
│ ├── test_bert.py
│ └── test_hybrid.py
├── requirements.txt
├── setup.py
└── README.md
4.1 需要完成的具体文件
任务 4.1.1: src/traditional/predict.py - 传统模型预测
"""
传统机器学习模型预测
"""
import os
import joblib
import jieba
from typing import Dict, Any
# Mapping from internal category codes to Chinese display names.
CATEGORY_MAP = {
    'POLITICS': '时政',
    'FINANCE': '财经',
    'TECHNOLOGY': '科技',
    'SPORTS': '体育',
    'ENTERTAINMENT': '娱乐',
    'HEALTH': '健康',
    'EDUCATION': '教育',
    'LIFE': '生活',
    'INTERNATIONAL': '国际',
    'MILITARY': '军事'
}
class TraditionalPredictor:
    """Classic ML predictor: a pickled TF-IDF vectorizer feeding a classifier."""

    def __init__(self, model_type='nb', model_dir='../../models/traditional'):
        """
        :param model_type: model prefix, e.g. 'nb' or 'svm'
        :param model_dir: directory holding <type>_vectorizer.pkl / <type>_classifier.pkl
        """
        self.model_type = model_type
        self.model_dir = model_dir
        self.vectorizer = None
        self.classifier = None
        self._load_model()

    def _load_model(self):
        """Load the pickled vectorizer/classifier pair for this model type."""
        prefix = os.path.join(self.model_dir, self.model_type)
        self.vectorizer = joblib.load(f'{prefix}_vectorizer.pkl')
        self.classifier = joblib.load(f'{prefix}_classifier.pkl')
        print(f"模型加载成功: {self.model_type}")

    def preprocess(self, title: str, content: str) -> str:
        """Concatenate title and body, then whitespace-join jieba tokens."""
        combined = title + ' ' + content
        return ' '.join(jieba.cut(combined))

    def predict(self, title: str, content: str) -> Dict[str, Any]:
        """
        Classify one article.

        :return: dict with categoryCode, categoryName, confidence and
                 per-class probabilities
        """
        features = self.vectorizer.transform([self.preprocess(title, content)])
        label = self.classifier.predict(features)[0]
        probs = self.classifier.predict_proba(features)[0]

        # Map each class code to its probability.
        prob_dict = {
            code: float(p)
            for code, p in zip(self.classifier.classes_, probs)
        }

        return {
            'categoryCode': label,
            'categoryName': CATEGORY_MAP.get(label, '未知'),
            'confidence': float(probs.max()),
            'probabilities': prob_dict,
        }
# API entry point.
# Cache of constructed predictors keyed by model type, so repeated calls do
# not re-read the pickled model files from disk each time.
_PREDICTORS: Dict[str, 'TraditionalPredictor'] = {}


def predict_single(title: str, content: str, model_type='nb') -> Dict[str, Any]:
    """
    Single-article prediction API.

    :param title: news title
    :param content: news body text
    :param model_type: which traditional model to use (e.g. 'nb', 'svm')
    :return: prediction dict (categoryCode/categoryName/confidence/probabilities)
    """
    predictor = _PREDICTORS.get(model_type)
    if predictor is None:
        # Fix: the original built a new TraditionalPredictor (reloading both
        # .pkl files) on every call; reuse one instance per model type.
        predictor = TraditionalPredictor(model_type)
        _PREDICTORS[model_type] = predictor
    return predictor.predict(title, content)


if __name__ == '__main__':
    # Smoke test
    result = predict_single(
        title="华为发布新款折叠屏手机",
        content="华为今天正式发布了新一代折叠屏手机,搭载最新麒麟芯片..."
    )
    print(result)
任务 4.1.2: src/deep_learning/bert_model.py - BERT模型
"""
BERT文本分类模型
"""
import torch
from transformers import (
BertTokenizer,
BertForSequenceClassification,
Trainer,
TrainingArguments
)
from typing import Dict, Any, List
# Mapping from internal category codes to Chinese display names.
CATEGORY_MAP = {
    'POLITICS': '时政',
    'FINANCE': '财经',
    'TECHNOLOGY': '科技',
    'SPORTS': '体育',
    'ENTERTAINMENT': '娱乐',
    'HEALTH': '健康',
    'EDUCATION': '教育',
    'LIFE': '生活',
    'INTERNATIONAL': '国际',
    'MILITARY': '军事'
}
# Id <-> label lookups. Order follows CATEGORY_MAP insertion order, so it must
# match the label ids used when the model was fine-tuned.
ID_TO_LABEL = {i: label for i, label in enumerate(CATEGORY_MAP.keys())}
LABEL_TO_ID = {label: i for i, label in enumerate(CATEGORY_MAP.keys())}
class BertClassifier:
    """BERT-based news text classifier (fine-tuned checkpoint loaded lazily)."""

    def __init__(self, model_name='bert-base-chinese', num_labels=10):
        self.model_name = model_name
        self.num_labels = num_labels
        self.tokenizer = None
        self.model = None
        # Prefer GPU when one is available.
        use_cuda = torch.cuda.is_available()
        self.device = torch.device('cuda' if use_cuda else 'cpu')

    def load_model(self, model_path):
        """Load a fine-tuned checkpoint and switch the model to eval mode."""
        self.tokenizer = BertTokenizer.from_pretrained(model_path)
        self.model = BertForSequenceClassification.from_pretrained(
            model_path,
            num_labels=self.num_labels
        )
        self.model.to(self.device)
        self.model.eval()
        print(f"BERT模型加载成功: {model_path}")

    def predict(self, title: str, content: str) -> Dict[str, Any]:
        """
        Classify a single article.

        :raises ValueError: if load_model has not been called yet
        """
        if self.tokenizer is None or self.model is None:
            raise ValueError("模型未加载,请先调用load_model")

        # Feed title and body as one sequence separated by [SEP].
        encoded = self.tokenizer(
            f"{title} [SEP] {content}",
            return_tensors='pt',
            truncation=True,
            max_length=512,
            padding='max_length'
        )
        encoded = {name: tensor.to(self.device) for name, tensor in encoded.items()}

        with torch.no_grad():
            logits = self.model(**encoded).logits

        probs = torch.softmax(logits, dim=-1)
        confidence, predicted = torch.max(probs, dim=-1)
        predicted = predicted.item()

        # Per-class probabilities keyed by category code.
        prob_dict = {
            ID_TO_LABEL[idx]: float(p)
            for idx, p in enumerate(probs[0].cpu().numpy())
        }
        code = ID_TO_LABEL[predicted]
        return {
            'categoryCode': code,
            'categoryName': CATEGORY_MAP.get(code, '未知'),
            'confidence': confidence.item(),
            'probabilities': prob_dict
        }
# Dataset wrapper
class NewsDataset(torch.utils.data.Dataset):
    """torch Dataset yielding tokenized news texts with integer labels."""

    def __init__(self, texts, labels, tokenizer, max_length=512):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        encoded = self.tokenizer(
            self.texts[idx],
            truncation=True,
            max_length=self.max_length,
            padding='max_length',
            return_tensors='pt'
        )
        # flatten() drops the batch dimension added by return_tensors='pt'.
        return {
            'input_ids': encoded['input_ids'].flatten(),
            'attention_mask': encoded['attention_mask'].flatten(),
            'labels': torch.tensor(self.labels[idx], dtype=torch.long)
        }
if __name__ == '__main__':
    # Smoke test: construct the classifier without loading a checkpoint.
    classifier = BertClassifier()
    # Example of a real prediction once a fine-tuned checkpoint exists:
    #   classifier.load_model('./models/deep_learning/bert_finetuned')
    #   result = classifier.predict(
    #       title="华为发布新款折叠屏手机",
    #       content="华为今天正式发布了新一代折叠屏手机..."
    #   )
    #   print(result)
    print("BERT分类器初始化成功")
任务 4.1.3: src/hybrid/hybrid_classifier.py - 混合分类器
"""
混合策略分类器
结合规则引擎和机器学习模型
"""
import time
from typing import Dict, Any
from ..traditional.predict import TraditionalPredictor
from ..deep_learning.bert_model import BertClassifier
class HybridClassifier:
"""混合分类器"""
def __init__(self):
# 初始化各个分类器
self.nb_predictor = TraditionalPredictor('nb')
self.bert_classifier = BertClassifier()
# 配置参数
self.config = {
'confidence_threshold': 0.75, # 高置信度阈值
'hybrid_min_confidence': 0.60, # 混合模式最低阈值
'use_bert_threshold': 0.70, # 使用BERT的阈值
'rule_priority': True # 规则优先
}
# 规则关键词字典
self.rule_keywords = {
'POLITICS': ['政府', '政策', '选举', '国务院', '主席', '总理'],
'FINANCE': ['股市', '经济', '金融', '投资', '基金', '银行'],
'TECHNOLOGY': ['芯片', 'AI', '人工智能', '5G', '互联网', '科技'],
'SPORTS': ['比赛', '冠军', '联赛', '球员', '教练', 'NBA'],
'ENTERTAINMENT': ['明星', '电影', '电视剧', '娱乐圈', '歌手'],
'HEALTH': ['健康', '医疗', '疾病', '治疗', '疫苗'],
'EDUCATION': ['教育', '学校', '大学', '考试', '招生'],
'LIFE': ['生活', '美食', '旅游', '购物'],
'INTERNATIONAL': ['国际', '美国', '欧洲', '日本', '外交'],
'MILITARY': ['军事', '武器', '军队', '国防', '战争']
}
def rule_match(self, title: str, content: str) -> tuple[str | None, float]:
"""
规则匹配
:return: (category_code, confidence)
"""
text = title + ' ' + content
# 计算每个类别的关键词匹配数
matches = {}
for category, keywords in self.rule_keywords.items():
count = sum(1 for kw in keywords if kw in text)
if count > 0:
matches[category] = count
if not matches:
return None, 0.0
# 返回匹配最多的类别
best_category = max(matches, key=matches.get)
confidence = min(0.9, matches[best_category] * 0.15) # 规则置信度
return best_category, confidence
def predict(self, title: str, content: str, use_bert=True) -> Dict[str, Any]:
"""
混合预测
"""
start_time = time.time()
# 1. 先尝试规则匹配
rule_category, rule_confidence = self.rule_match(title, content)
# 2. 传统机器学习预测
nb_result = self.nb_predictor.predict(title, content)
nb_confidence = nb_result['confidence']
# 决策逻辑
final_result = None
classifier_type = 'HYBRID'
# 规则优先且规则置信度高
if self.config['rule_priority'] and rule_confidence >= self.config['confidence_threshold']:
final_result = {
'categoryCode': rule_category,
'categoryName': nb_result['categoryName'], # 从映射获取
'confidence': rule_confidence,
'classifierType': 'RULE',
'reason': '规则匹配'
}
# 传统模型置信度足够高
elif nb_confidence >= self.config['confidence_threshold']:
final_result = {
**nb_result,
'classifierType': 'ML',
'reason': '传统模型高置信度'
}
# 需要使用BERT
elif use_bert:
# TODO: 加载BERT模型预测
# bert_result = self.bert_classifier.predict(title, content)
# 如果BERT置信度也不高,选择最高的
final_result = {
**nb_result,
'classifierType': 'HYBRID',
'reason': '混合决策'
}
else:
# 不使用BERT,直接返回传统模型结果
final_result = {
**nb_result,
'classifierType': 'ML',
'reason': '默认传统模型'
}
# 计算耗时
duration = int((time.time() - start_time) * 1000)
final_result['duration'] = duration
return final_result
if __name__ == '__main__':
    # Manual smoke test against two sample articles.
    classifier = HybridClassifier()
    samples = [
        {'title': '国务院发布最新经济政策',
         'content': '国务院今天发布了新的经济政策...'},
        {'title': '华为发布新款折叠屏手机',
         'content': '华为今天正式发布了新一代折叠屏手机...'},
    ]
    for sample in samples:
        prediction = classifier.predict(sample['title'], sample['content'])
        print(f"标题: {sample['title']}")
        print(f"结果: {prediction['categoryName']} ({prediction['confidence']:.2f})")
        print(f"分类器: {prediction['classifierType']}")
        print(f"原因: {prediction.get('reason', 'N/A')}")
        print(f"耗时: {prediction['duration']}ms")
        print("-" * 50)
任务 4.1.4: src/api/server.py - FastAPI服务
"""
机器学习模型API服务
使用FastAPI提供RESTful API
"""
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional
import logging
# 导入分类器
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from traditional.predict import TraditionalPredictor
from hybrid.hybrid_classifier import HybridClassifier
# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# 创建FastAPI应用
app = FastAPI(
title="新闻分类API",
description="提供新闻文本分类服务",
version="1.0.0"
)
# 配置CORS
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Request schema
class ClassifyRequest(BaseModel):
    title: str                       # news title
    content: str                     # news body text
    mode: Optional[str] = 'hybrid'   # 'traditional' or 'hybrid'
# Response schema
class ClassifyResponse(BaseModel):
    categoryCode: str                      # internal category code, e.g. 'TECHNOLOGY'
    categoryName: str                      # Chinese display name
    confidence: float                      # top-class probability in [0, 1]
    classifierType: str                    # 'RULE' / 'ML' / 'HYBRID'
    duration: int                          # prediction latency in ms
    probabilities: Optional[dict] = None   # per-class probabilities, optional
# Classifier singletons, populated by the startup hook below.
nb_predictor = None
hybrid_classifier = None


@app.on_event("startup")
async def startup_event():
    """Load models at service startup; failures are logged, not fatal."""
    global nb_predictor, hybrid_classifier
    logger.info("加载模型...")
    try:
        nb_predictor = TraditionalPredictor('nb')
        logger.info("朴素贝叶斯模型加载成功")
    except Exception as e:
        logger.error(f"朴素贝叶斯模型加载失败: {e}")
    try:
        hybrid_classifier = HybridClassifier()
        logger.info("混合分类器初始化成功")
    except Exception as e:
        logger.error(f"混合分类器初始化失败: {e}")
@app.get("/")
async def root():
    """Liveness probe."""
    return {"status": "ok", "message": "新闻分类API服务运行中"}
@app.get("/health")
async def health_check():
    """Health check reporting which models finished loading."""
    model_status = {
        "nb_loaded": nb_predictor is not None,
        "hybrid_loaded": hybrid_classifier is not None,
    }
    return {"status": "healthy", "models": model_status}
@app.post("/api/predict", response_model=ClassifyResponse)
async def predict(request: ClassifyRequest):
    """
    Text classification endpoint.

    - **title**: news title
    - **content**: news body text
    - **mode**: classification mode ('traditional' or 'hybrid')

    Returns 503 when the requested model failed to load at startup,
    500 on unexpected prediction errors.
    """
    # Fix: the original dereferenced a possibly-None predictor (models that
    # failed to load at startup), surfacing as an opaque 500 AttributeError.
    classifier = nb_predictor if request.mode == 'traditional' else hybrid_classifier
    if classifier is None:
        raise HTTPException(status_code=503, detail="模型未加载")
    try:
        if request.mode == 'traditional':
            import time  # stdlib; only needed for this latency measurement
            start = time.time()
            result = nb_predictor.predict(request.title, request.content)
            result['classifierType'] = 'ML'
            # Fix: ClassifyResponse requires 'duration', but the traditional
            # predictor does not supply it — the original response therefore
            # always failed validation on this path.
            result['duration'] = int((time.time() - start) * 1000)
        else:  # hybrid (supplies 'duration' itself)
            result = hybrid_classifier.predict(request.title, request.content)
        return ClassifyResponse(**result)
    except Exception as e:
        logger.error(f"预测失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/batch-predict")
async def batch_predict(requests: list[ClassifyRequest]):
    """
    Batch classification endpoint.

    Per-item failures are reported inline ({'error', 'title'}) instead of
    failing the whole batch.
    """
    results = []
    for req in requests:
        try:
            # Fix: guard against predictors that failed to load at startup;
            # the original raised AttributeError on None, which was then
            # reported with a misleading message.
            if req.mode == 'traditional':
                if nb_predictor is None:
                    raise RuntimeError("模型未加载")
                result = nb_predictor.predict(req.title, req.content)
                result['classifierType'] = 'ML'
            else:
                if hybrid_classifier is None:
                    raise RuntimeError("模型未加载")
                result = hybrid_classifier.predict(req.title, req.content)
            results.append(result)
        except Exception as e:
            results.append({
                'error': str(e),
                'title': req.title
            })
    return {"results": results}
if __name__ == '__main__':
    import uvicorn

    # Standalone development entry point.
    uvicorn.run(app, host="0.0.0.0", port=5000, log_level="info")
任务 4.1.5: src/utils/metrics.py - 评估指标
"""
模型评估指标工具
"""
import numpy as np
from sklearn.metrics import (
accuracy_score,
precision_recall_fscore_support,
confusion_matrix,
classification_report
)
from typing import List, Dict, Any
import matplotlib.pyplot as plt
import seaborn as sns
class ClassificationMetrics:
    """Classification evaluation helpers (accuracy, P/R/F1, confusion matrix)."""

    @staticmethod
    def compute_all(y_true: List, y_pred: List, labels: List[str]) -> Dict[str, Any]:
        """
        Compute accuracy plus weighted and per-class precision/recall/F1.

        :param y_true: ground-truth label codes
        :param y_pred: predicted label codes
        :param labels: label codes, in the order per-class metrics are reported
        :return: dict with overall metrics and a 'per_class' sub-dict
        """
        accuracy = accuracy_score(y_true, y_pred)
        precision, recall, f1, _ = precision_recall_fscore_support(
            y_true, y_pred, average='weighted', zero_division=0
        )
        # Fix: pass labels= so the per-class arrays follow the caller's
        # `labels` order; without it sklearn orders rows by sorted unique
        # labels and the per_class dict could attach metrics to wrong names.
        precision_pc, recall_pc, f1_pc, support_pc = precision_recall_fscore_support(
            y_true, y_pred, labels=labels, average=None, zero_division=0
        )
        per_class_metrics = {
            label: {
                'precision': float(precision_pc[i]),
                'recall': float(recall_pc[i]),
                'f1': float(f1_pc[i]),
                'support': int(support_pc[i]),
            }
            for i, label in enumerate(labels)
        }
        return {
            'accuracy': float(accuracy),
            'precision': float(precision),
            'recall': float(recall),
            'f1': float(f1),
            'per_class': per_class_metrics,
        }

    @staticmethod
    def plot_confusion_matrix(y_true: List, y_pred: List, labels: List[str],
                              save_path: str | None = None):
        """
        Plot the confusion matrix as a heatmap; save to save_path if given.
        """
        # Fix: pass labels= so matrix rows/columns line up with the tick labels.
        cm = confusion_matrix(y_true, y_pred, labels=labels)
        plt.figure(figsize=(10, 8))
        sns.heatmap(
            cm,
            annot=True,
            fmt='d',
            cmap='Blues',
            xticklabels=labels,
            yticklabels=labels
        )
        plt.xlabel('预测标签')
        plt.ylabel('真实标签')
        plt.title('混淆矩阵')
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close()

    @staticmethod
    def print_report(y_true: List, y_pred: List, labels: List[str]):
        """
        Print sklearn's per-class classification report.
        """
        # Fix: pass labels= so target_names correspond to the report rows.
        report = classification_report(
            y_true, y_pred,
            labels=labels,
            target_names=labels,
            zero_division=0
        )
        print(report)
if __name__ == '__main__':
    # Tiny smoke test: five samples, one POLITICS item misclassified.
    truth = ['POLITICS', 'TECHNOLOGY', 'FINANCE', 'POLITICS', 'TECHNOLOGY']
    preds = ['POLITICS', 'TECHNOLOGY', 'FINANCE', 'TECHNOLOGY', 'TECHNOLOGY']
    label_set = ['POLITICS', 'TECHNOLOGY', 'FINANCE']
    print(ClassificationMetrics().compute_all(truth, preds, label_set))
任务 4.1.6: requirements.txt - 依赖文件
# 机器学习模块依赖
numpy>=1.24.0
pandas>=2.0.0
scikit-learn>=1.3.0
jieba>=0.42.0
joblib>=1.3.0
# 深度学习
torch>=2.0.0
transformers>=4.30.0
# API服务
fastapi>=0.100.0
uvicorn[standard]>=0.23.0
pydantic>=2.0.0
# 数据可视化
matplotlib>=3.7.0
seaborn>=0.12.0
# 工具
python-dotenv>=1.0.0
pyyaml>=6.0
总结
开发顺序建议
-
第一阶段:基础框架
- 后端:数据库连接、实体类、基础配置
- 前端:路由配置、状态管理、API封装
-
第二阶段:核心功能
- 爬虫模块(Python)
- 传统机器学习分类器
- 后端API接口
- 前端新闻列表页面
-
第三阶段:高级功能
- BERT深度学习分类器
- 混合策略分类器
- 前端分类器对比页面
- 统计图表
-
第四阶段:完善优化
- 用户认证
- 数据可视化
- 性能优化
- 异常处理
关键注意事项
- 爬虫模块使用 Python,通过 RESTful API 与 Java 后端通信
- 分类器模块独立部署,提供 HTTP 接口供后端调用
- 前后端分离,使用 JWT 进行身份认证
- 数据库表结构已在 schema.sql 中定义,需严格遵守
- API 统一响应格式使用 Result&lt;T&gt; 包装