Compare commits

...

2 Commits

24 changed files with 2582 additions and 35 deletions


@ -110,3 +110,13 @@ sources:
        category_id: 9
        name: "AI"
        css_selector: "div.kr-information-left"
  sina:
    base_url: "https://sina.com.cn"
    categories:
      auto:
        url: "https://auto.sina.com.cn/"
        category_id: 6
        name: "汽车"
        css_selector: "div.feed_card.ty-feed-card-container div.cardlist-a__list div.ty-card.ty-card-type1"
        detail_css_selector: "div.main-content"

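For context, a minimal sketch of how a crawler might consume these new keys, assuming the YAML above is loaded into a plain dict with PyYAML (file name and loading style are assumptions):

```python
import yaml

with open("config.yaml", "r", encoding="utf-8") as f:  # path assumed
    cfg = yaml.safe_load(f)

auto = cfg["sources"]["sina"]["categories"]["auto"]
# The listing page is fetched from auto["url"]; auto["css_selector"] picks the
# article cards and auto["detail_css_selector"] scopes the detail-page parsing.
print(auto["url"], auto["category_id"], auto["name"])
```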

@ -0,0 +1,82 @@
This is the Sina (sina.com.cn) code for crawling automotive news:
```python
import requests
from bs4 import BeautifulSoup
from datetime import datetime

URL = "https://auto.sina.com.cn/"
headers = {
    "User-Agent": (
        "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
        "AppleWebKit/537.36 (KHTML, like Gecko) "
        "Chrome/120.0.0.0 Safari/537.36"
    )
}

resp = requests.get(URL, headers=headers, timeout=10)
# resp.raise_for_status()
# resp.encoding = "utf-8"
# print(resp.text)

# During development, parse a locally saved copy of the listing page instead
with open("example/example-10.html", "r", encoding="utf-8") as f:
    html = f.read()
# soup = BeautifulSoup(resp.text, "lxml")
soup = BeautifulSoup(html, "lxml")

def normalize_time(time_str):
    """Normalize a date string to 'YYYY-MM-DD HH:MM:SS'."""
    for fmt in ("%Y年%m月%d日 %H:%M", "%Y-%m-%d %H:%M:%S"):
        try:
            dt = datetime.strptime(time_str, fmt)
            return dt.strftime("%Y-%m-%d %H:%M:%S")
        except ValueError:
            continue
    return time_str  # return the original string if no format matches

div_list = soup.select("div.feed_card.ty-feed-card-container div.cardlist-a__list div.ty-card.ty-card-type1")
for item in div_list:
    a = item.select_one("div.ty-card-l a")
    if a is None or not a.get("href"):
        continue
    href = a.get("href")
    # print(a.get('href'), a.get_text().strip())
    resp = requests.get(url=href, headers=headers, timeout=10)
    resp.encoding = resp.apparent_encoding  # let requests guess the encoding
    soup = BeautifulSoup(resp.text, "lxml")

    # Extract the article title
    article_title_tag = soup.select_one("div.main-content h1.main-title")
    if article_title_tag:
        article_title = article_title_tag.get_text(strip=True) or "未知标题"
    else:
        article_title = "未知标题"
    # print("标题:", article_title)

    # Extract the publish time; only normalize when the tag exists
    time_tag = soup.select_one("div.main-content div.top-bar-wrap div.date-source span.date")
    if time_tag:
        publish_time = normalize_time(time_tag.get_text(strip=True))
    else:
        publish_time = "1949-01-01 12:00:00"
    # print(publish_time)

    # Extract the author
    author_tag = soup.select_one("div.main-content div.top-bar-wrap div.date-source a")
    author = author_tag.get_text(strip=True) if author_tag else "未知"
    # print(author)

    # Extract the body paragraphs
    article_div = soup.select_one("div.main-content div.article")  # main article container
    if not article_div:
        # print("不是文章详情页,跳过")
        continue  # skip pages that are not article detail pages
    paragraphs = article_div.find_all('p')
    article_text = '\n'.join(p.get_text(strip=True) for p in paragraphs if p.get_text(strip=True))
    # print("正文:\n", article_text)
```


@ -102,7 +102,6 @@ class BaseCrawler(ABC):
        try:
            # Fetch the page HTML
            html = self._fetch_page()
            # Parse the article list
            article_urls = self._extract_article_urls(html)
            self.logger.info(f"找到 {len(article_urls)} 篇文章")


@ -31,6 +31,9 @@ CRAWLER_CLASSES = {
    'kr36': {
        'ai': ('crawlers.kr36.ai', 'AICrawler'),
    },
    'sina': {
        'auto': ('crawlers.sina.auto', 'SinaAutoCrawler'),
    },
}
@ -53,6 +56,11 @@ def list_crawlers() -> List[str]:
    for category in kr36_categories.keys():
        crawlers.append(f"kr36:{category}")
    # Sina crawlers
    sina_categories = config.get('sources.sina.categories', {})
    for category in sina_categories.keys():
        crawlers.append(f"sina:{category}")
    return crawlers
@ -101,7 +109,7 @@ def run_crawler(source: str, category: str, max_articles: int = None) -> bool:
    # Create and run the crawler
    crawler = crawler_class(source, category)
    print("创建并运行爬虫")
    # Override the maximum article count if given
    if max_articles:
        crawler.max_articles = max_articles

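Assuming the registry and the two helpers above live in a single module (the module name here is hypothetical), usage would look like:

```python
from crawler_registry import list_crawlers, run_crawler  # module name assumed

print(list_crawlers())  # [..., 'kr36:ai', 'sina:auto']
run_crawler('sina', 'auto', max_articles=20)
```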

@ -0,0 +1,60 @@
"""
新浪汽车新闻爬虫
"""
from typing import List
from bs4 import BeautifulSoup
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from base.crawler_base import DynamicCrawler, Article
from parsers.sina_parser import SinaAutoParser
class SinaAutoCrawler(DynamicCrawler):
"""新浪汽车新闻爬虫"""
def _extract_article_urls(self, html: str) -> List[str]:
"""从HTML中提取文章URL列表"""
soup = BeautifulSoup(html, "lxml")
urls = []
# 尝试不同的选择器
div_list = soup.select("div.cardlist-a__list div.ty-card.ty-card-type1")
if not div_list:
div_list = soup.select("div.news-list li.news-item")
if not div_list:
div_list = soup.select("div.feed_card.ty-feed-card-container div.cardlist-a__list div.ty-card.ty-card-type1")
for item in div_list:
a = item.select_one("a")
if a and a.get("href"):
urls.append(a.get("href"))
return urls
def _fetch_articles(self, urls: List[str]) -> List[Article]:
"""爬取文章详情"""
articles = []
parser = SinaAutoParser()
for i, url in enumerate(urls[:self.max_articles]):
try:
article = parser.parse(url)
article.category_id = self.category_id
article.source = "新浪"
if not article.author:
article.author = "新浪汽车"
if article.is_valid():
articles.append(article)
self.logger.info(f"[{i+1}/{len(urls)}] {article.title}")
except Exception as e:
self.logger.error(f"解析文章失败: {url} - {e}")
continue
return articles


@ -54,6 +54,7 @@ class NewsRepository:
        try:
            with db_pool.get_connection() as conn:
                cursor = conn.cursor()

                # Batch-check which URLs already exist
                if urls:
                    placeholders = ','.join(['%s'] * len(urls))
@ -61,24 +62,25 @@
                    cursor.execute(check_sql, urls)
                    existing_urls = {row[0] for row in cursor.fetchall()}

                # Only insert records whose URL is not present yet
                new_data = [item for item in data if item[0] not in existing_urls]
                if not new_data:
                    self.logger.info(f"所有 {len(data)} 条新闻已存在,跳过插入")
                    return 0

                # Insert with INSERT IGNORE so rows whose content_hash already exists are skipped
                sql = """
                    INSERT IGNORE INTO news (url, title, category_id, publish_time, author, source, content, content_hash)
                    VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
                """
                cursor.executemany(sql, new_data)
                conn.commit()

                # Number of rows actually inserted
                inserted = cursor.rowcount
                self.logger.info(f"成功插入 {inserted} 条新新闻,{len(new_data) - inserted} 条因内容重复被忽略")
                return inserted
        except Exception as e:

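The switch to `INSERT IGNORE` plus `cursor.rowcount` is what makes the skipped-duplicate count accurate: with `executemany`, `rowcount` reports only the rows actually inserted. A minimal sketch of that behaviour (PyMySQL-style API; connection details and rows are hypothetical):

```python
import pymysql

conn = pymysql.connect(host="localhost", user="root", password="root", database="news")  # credentials assumed
cursor = conn.cursor()
rows = [
    ("https://example.com/a", "t1", 6, None, "author", "新浪", "same body", "hash-1"),
    ("https://example.com/b", "t2", 6, None, "author", "新浪", "same body", "hash-1"),  # duplicate content_hash
]
cursor.executemany(
    "INSERT IGNORE INTO news (url, title, category_id, publish_time, author, source, content, content_hash) "
    "VALUES (%s, %s, %s, %s, %s, %s, %s, %s)",
    rows,
)
conn.commit()
print(cursor.rowcount)  # 1 — the second row is skipped by the UNIQUE index on content_hash
```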

@ -0,0 +1,71 @@
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from base.parser_base import BaseParser
from datetime import datetime
from typing import Optional
from base.crawler_base import Article
from bs4 import BeautifulSoup
from utils.logger import get_logger
from utils.http_client import HttpClient

class SinaAutoParser(BaseParser):
    """Sina automotive news parser"""

    def __init__(self):
        self.logger = get_logger(__name__)
        self.http_client = HttpClient()

    def parse(self, url: str) -> Article:
        """
        Parse a Sina article detail page

        Args:
            url: article URL
        Returns:
            Article object
        """
        html = self.http_client.get(url)
        soup = BeautifulSoup(html, "lxml")

        # Extract the article title
        article_title_tag = soup.select_one("div.main-content h1.main-title")
        article_title = article_title_tag.get_text(strip=True) if article_title_tag else "未知标题"

        # Normalize the publish time to 'YYYY-MM-DD HH:MM:SS'
        def normalize_time(time_str):
            for fmt in ("%Y年%m月%d日 %H:%M", "%Y-%m-%d %H:%M:%S"):
                try:
                    dt = datetime.strptime(time_str, fmt)
                    return dt.strftime("%Y-%m-%d %H:%M:%S")
                except ValueError:
                    continue
            return time_str  # return the original string if no format matches

        time_tag = soup.select_one("div.main-content div.top-bar-wrap div.date-source span.date")
        publish_time = normalize_time(time_tag.get_text(strip=True)) if time_tag else "1949-01-01 12:00:00"

        # Extract the author
        author_tag = soup.select_one("div.main-content div.top-bar-wrap div.date-source a")
        author = author_tag.get_text(strip=True) if author_tag else "未知"

        # Extract the body paragraphs
        article_div = soup.select_one("div.main-content div.article")
        if not article_div:
            raise ValueError("无法找到文章内容")
        paragraphs = article_div.find_all('p')
        content = '\n'.join(p.get_text(strip=True) for p in paragraphs if p.get_text(strip=True))

        return Article(
            url=url,
            title=article_title,
            publish_time=publish_time,
            author=author,
            content=content,
            category_id=6,  # automotive category ID
            source="sina"
        )

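A quick way to exercise the parser in isolation (URL is hypothetical; `HttpClient` does the fetching internally):

```python
parser = SinaAutoParser()
article = parser.parse("https://auto.sina.com.cn/example-article.shtml")  # hypothetical URL
print(article.title, article.publish_time, article.author)
```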

@ -0,0 +1,7 @@
{
"permissions": {
"allow": [
"Bash(mkdir:*)"
]
}
}

ml-module/README.md

@ -0,0 +1,197 @@
# News Text Classification System - Machine Learning Module

## Features

### Automatic GPU/CPU detection
- Automatically detects available GPUs (including 8 GB-VRAM cards)
- Falls back to the CPU automatically when no GPU is available
- Prints device and VRAM information

### Dynamic parameter tuning
- Training parameters adjustable via a config file
- Command-line arguments can override the config
- Mixed precision is auto-detected

## Usage

### 1. Install dependencies
```bash
pip install -r requirements.txt
```

### 2. Pull training data from MySQL
```bash
python .\src\utils\data_loader.py
```

### 3. Configure training parameters
Edit `config.yaml` to adjust the training parameters:
```yaml
database:
  url: "mysql+pymysql://root:root@localhost/news_classifier"
  data_limit: 1000

model:
  name: "bert-base-chinese"
  num_labels: 9
  output_dir: "./models/deep_learning/bert_finetuned"

training:
  use_gpu: true  # auto-detected
  epochs: 3
  batch_size: 8
  learning_rate: 2e-5
  warmup_steps: 500
  weight_decay: 0.01
```

### 4. Train the BERT model
```bash
python train_bert.py
```

### 5. Override the config from the command line
```bash
# Train on the GPU
python train_bert.py --use_gpu

# Change the number of epochs
python train_bert.py --epochs 5

# Change the batch size
python train_bert.py --batch_size 16
```

### 6. Start the API service
```bash
python src/api/server.py
```

## Parameter reference

### Training parameters
- `epochs`: number of training epochs (default 3)
- `batch_size`: batch size (default 8)
- `learning_rate`: learning rate (default 2e-5)
- `warmup_steps`: warmup steps (default 500)
- `weight_decay`: weight decay (default 0.01)

### Device settings
- `use_gpu`: whether to use the GPU (auto-detected)
- `fp16`: mixed precision (auto-detected)

## Tuning advice for 8 GB of VRAM
A suggested configuration for GPUs with 8 GB of VRAM:
```yaml
training:
  use_gpu: true
  epochs: 3
  batch_size: 8   # up to 16 may fit
  fp16: true      # enable mixed precision
```

## Device detection
The device is detected automatically at training time:
- **GPU available**: trains on the GPU and prints its name and VRAM size
- **GPU unavailable**: falls back to CPU training

## Performance optimizations
- Mixed-precision training (FP16)
- Gradient accumulation
- Automatic batch-size adjustment
- Memory optimizations

## Notes
1. CUDA and cuDNN must be installed for GPU training
2. 8 GB of VRAM is enough to train a BERT-base model
3. Adjust `batch_size` to fit your VRAM
4. Training time depends on the data volume and hardware

## API endpoints
```bash
# Single prediction
POST /api/predict
{
  "title": "新闻标题",
  "content": "新闻内容",
  "mode": "hybrid"  # traditional, hybrid
}

# Batch prediction
POST /api/batch-predict
[
  {
    "title": "新闻标题1",
    "content": "新闻内容1",
    "mode": "hybrid"
  },
  {
    "title": "新闻标题2",
    "content": "新闻内容2",
    "mode": "traditional"
  }
]
```
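Once the service is running, a quick smoke test might look like this (localhost and port 5000 match the `uvicorn.run` call in `src/api/server.py`):

```bash
curl -X POST http://localhost:5000/api/predict \
  -H "Content-Type: application/json" \
  -d '{"title": "新闻标题", "content": "新闻内容", "mode": "hybrid"}'
```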
---

## Grouped installation plan

### 1. Core scientific computing / machine learning
```cmd
pip install "numpy>=1.24.0" "pandas>=2.0.0" "scikit-learn>=1.3.0" "joblib>=1.3.0"
```

---

### 2. Deep learning / NLP
```cmd
pip install "torch>=2.0.0" "transformers>=4.30.0" "jieba>=0.42.0"
```
💡 If you have a **GPU** and want the GPU build of PyTorch, generate the install command on the PyTorch website instead.

---

### 3. API service
```cmd
pip install "fastapi>=0.100.0" "uvicorn[standard]>=0.23.0" "pydantic>=2.0.0"
```

---

### 4. Database
```cmd
pip install "sqlalchemy>=2.0.0" "pymysql>=1.1.0"
```

---

### 5. Data visualization
```cmd
pip install "matplotlib>=3.7.0" "seaborn>=0.12.0"
```

---

### 6. Tools / config handling
```cmd
pip install "python-dotenv>=1.0.0" "pyyaml>=6.0"
```

---

**Why install in batches:**
1. Installation errors are easier to pin to a specific module.
2. Each group is relatively independent, reducing conflicts.
3. CMD runs one command at a time, with no line-continuation or quoting headaches.

ml-module/config.yaml

@ -0,0 +1,33 @@
# BERT training configuration
database:
  url: "mysql+pymysql://root:root@localhost/news_classifier"  # database connection URL
  data_limit: 1000  # cap on the number of rows to load

model:
  name: "bert-base-chinese"  # model name
  num_labels: 9  # number of classes
  output_dir: "./models/deep_learning/bert_finetuned"  # model output directory

training:
  use_gpu: true  # whether to use the GPU (auto-detected)
  epochs: 3  # number of training epochs
  batch_size: 8  # training batch size
  learning_rate: 2e-5  # learning rate
  warmup_steps: 500  # warmup steps
  weight_decay: 0.01  # weight decay
  fp16: null  # mixed precision; null means auto-detect

# Logging and output
logging:
  level: "INFO"  # log level
  file: "./training.log"  # log file

# Device settings
device:
  max_memory: "8GB"  # maximum memory limit
  gradient_checkpointing: true  # gradient checkpointing

# Performance optimization
optimization:
  gradient_accumulation_steps: 1  # gradient accumulation steps
  mixed_precision: true  # mixed precision (when supported)

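For reference, a minimal sketch of consuming this file, mirroring the `yaml.safe_load` call in `train_bert.py`:

```python
import yaml

with open("config.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

print(config["model"]["name"])           # bert-base-chinese
print(config["database"]["data_limit"])  # 1000
```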

@ -0,0 +1,160 @@
及其
以及
之一
一些
一个
一种
多种
各自
各个
各类
记者
通讯员
编辑
报道
表示
指出
认为
介绍
透露
强调
分析
称之为
近日
今日
今天
昨日
目前
当前
今年
去年
明年
此前
之后
当日
当天
近日来
近来
近期
未来
方面
情况
问题
工作
任务
活动
会议
会议上
会议中
会议期间
会议指出
相关
有关
一定
一些
部分
整体
进一步
持续
不断
不断地
继续
推进
加强
提升
改善
推动
加快
实现
完成
开展
进行
同时
此外
因此
所以
但是
然而
不过
如果
因为
虽然
由于
其中
其中之一
对此
对此次
对此前
对此后
通过
按照
根据
依据
围绕
围绕着
针对
关于
对于
面对
着力
积极
主动
有效
充分
全面
已经
正在
正在进行
正在推进
正在开展
开始
结束
完成后
之后
以来
以来的
相关人士
业内人士
专家表示
专家认为
业内认为
市场认为
分析人士
有关人士
知情人士
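These entries feed the stopword filtering in `train_model.py`; a minimal sketch of the effect (file path as used by `_load_stopwords`; jieba assumed installed):

```python
import jieba

with open("../../data/news_stopwords.txt", encoding="utf-8") as f:  # path from train_model.py
    stopwords = {line.strip() for line in f if line.strip()}

text = "记者今日报道,新能源汽车销量持续增长"
words = [w for w in jieba.cut(text) if len(w) > 1 and w not in stopwords]
print(words)  # reporting filler such as 记者/今日/报道 is dropped
```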

ml-module/database.md

@ -0,0 +1,51 @@
### News category table
```sql
CREATE TABLE news_category (
id INT NOT NULL AUTO_INCREMENT COMMENT '分类ID',
name VARCHAR(50) NOT NULL COMMENT '分类名称',
PRIMARY KEY (id),
UNIQUE KEY uk_name (name)
) ENGINE=InnoDB
DEFAULT CHARSET=utf8mb4
COLLATE=utf8mb4_0900_ai_ci
COMMENT='新闻分类表';
```
Data:
```text
1 娱乐
2 体育
3 财经
4 科技
5 军事
6 汽车
7 政务
8 健康
9 AI
```
### News table
```sql
CREATE TABLE news (
id BIGINT NOT NULL AUTO_INCREMENT COMMENT '自增主键',
url VARCHAR(500) NOT NULL COMMENT '新闻原始URL',
title VARCHAR(255) NOT NULL COMMENT '新闻标题',
category_id INT NULL COMMENT '新闻分类ID',
publish_time DATETIME NULL COMMENT '发布时间',
author VARCHAR(100) NULL COMMENT '作者/来源',
    source VARCHAR(50) NULL COMMENT '新闻来源(网易/36kr)',
content LONGTEXT NOT NULL COMMENT '新闻正文',
content_hash CHAR(64) NOT NULL COMMENT '正文内容hash用于去重',
created_at TIMESTAMP NULL DEFAULT CURRENT_TIMESTAMP COMMENT '入库时间',
PRIMARY KEY (id),
UNIQUE KEY uk_url (url),
UNIQUE KEY uk_content_hash (content_hash),
KEY idx_category_id (category_id),
KEY idx_source (source)
) ENGINE=InnoDB
DEFAULT CHARSET=utf8mb4
COLLATE=utf8mb4_0900_ai_ci
COMMENT='新闻表';
```
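The `uk_content_hash` unique key is what the crawler's `INSERT IGNORE` path relies on. This diff does not show how the hash is computed; a SHA-256 hex digest is a natural fit for the CHAR(64) column (an assumption, not confirmed by the diff):

```python
import hashlib

def content_hash(content: str) -> str:
    # A SHA-256 hex digest is exactly 64 characters, matching CHAR(64)
    return hashlib.sha256(content.encode("utf-8")).hexdigest()
```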

Binary file not shown.

Binary file not shown.


@ -1,11 +1,30 @@
# ML module dependencies
numpy>=1.24.0
pandas>=2.0.0
scikit-learn>=1.3.0
jieba>=0.42.0
joblib>=1.3.0

# Deep learning
torch>=2.0.0
transformers>=4.30.0

# API service
fastapi>=0.100.0
uvicorn[standard]>=0.23.0
pydantic>=2.0.0

# Database
sqlalchemy>=2.0.0
pymysql>=1.1.0

# Data visualization
matplotlib>=3.7.0
seaborn>=0.12.0

# Tools
python-dotenv>=1.0.0
pyyaml>=6.0

ml-module/src/api/server.py

@ -0,0 +1,157 @@
"""
机器学习模型API服务
使用FastAPI提供RESTful API
"""
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional, List
import logging
import sys
import os
# 导入分类器
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from traditional.predict import TraditionalPredictor
from hybrid.hybrid_classifier import HybridClassifier
# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# 创建FastAPI应用
app = FastAPI(
title="新闻分类API",
description="提供新闻文本分类服务",
version="1.0.0"
)
# 配置CORS
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 请求模型
class ClassifyRequest(BaseModel):
title: str
content: str
mode: Optional[str] = 'hybrid' # traditional, hybrid
# 响应模型
class ClassifyResponse(BaseModel):
categoryCode: str
categoryName: str
confidence: float
classifierType: str
duration: int
probabilities: Optional[dict] = None
# 初始化分类器
nb_predictor = None
hybrid_classifier = None
@app.on_event("startup")
async def startup_event():
"""启动时加载模型"""
global nb_predictor, hybrid_classifier
logger.info("加载模型...")
try:
nb_predictor = TraditionalPredictor('nb')
logger.info("朴素贝叶斯模型加载成功")
except Exception as e:
logger.error(f"朴素贝叶斯模型加载失败: {e}")
try:
hybrid_classifier = HybridClassifier()
logger.info("混合分类器初始化成功")
except Exception as e:
logger.error(f"混合分类器初始化失败: {e}")
@app.get("/")
async def root():
"""健康检查"""
return {
"status": "ok",
"message": "新闻分类API服务运行中"
}
@app.get("/health")
async def health_check():
"""健康检查"""
return {
"status": "healthy",
"models": {
"nb_loaded": nb_predictor is not None,
"hybrid_loaded": hybrid_classifier is not None
}
}
@app.post("/api/predict", response_model=ClassifyResponse)
async def predict(request: ClassifyRequest):
"""
文本分类接口
- **title**: 新闻标题
- **content**: 新闻内容
- **mode**: 分类模式 (traditional, hybrid)
"""
try:
if request.mode == 'traditional':
result = nb_predictor.predict(request.title, request.content)
result['classifierType'] = 'ML'
else: # hybrid
result = hybrid_classifier.predict(request.title, request.content)
return ClassifyResponse(**result)
except Exception as e:
logger.error(f"预测失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/batch-predict")
async def batch_predict(requests: List[ClassifyRequest]):
"""
批量分类接口
"""
results = []
for req in requests:
try:
if req.mode == 'traditional':
result = nb_predictor.predict(req.title, req.content)
result['classifierType'] = 'ML'
else:
result = hybrid_classifier.predict(req.title, req.content)
results.append(result)
except Exception as e:
results.append({
'error': str(e),
'title': req.title
})
return {"results": results}
if __name__ == '__main__':
import uvicorn
uvicorn.run(
app,
host="0.0.0.0",
port=5000,
log_level="info"
)


@ -0,0 +1,209 @@
"""
BERT文本分类模型
"""
import torch
from transformers import (
BertTokenizer,
BertForSequenceClassification,
Trainer,
TrainingArguments
)
from typing import Dict, Any, List
# 分类映射
CATEGORY_MAP = {
'ENTERTAINMENT': '娱乐',
'SPORTS': '体育',
'FINANCE': '财经',
'TECHNOLOGY': '科技',
'MILITARY': '军事',
'AUTOMOTIVE': '汽车',
'GOVERNMENT': '政务',
'HEALTH': '健康',
'AI': 'AI'
}
# 反向映射
ID_TO_LABEL = {i: label for i, label in enumerate(CATEGORY_MAP.keys())}
LABEL_TO_ID = {label: i for i, label in enumerate(CATEGORY_MAP.keys())}
class BertClassifier:
"""BERT文本分类器 - 支持GPU/CPU自动检测和动态参数"""
def __init__(self, model_name='bert-base-chinese', num_labels=9, use_gpu=True):
self.model_name = model_name
self.num_labels = num_labels
self.tokenizer = None
self.model = None
# GPU/CPU自动检测
self.use_gpu = use_gpu and torch.cuda.is_available()
self.device = torch.device('cuda' if self.use_gpu else 'cpu')
# 打印设备信息
if self.use_gpu:
print(f"使用GPU训练: {torch.cuda.get_device_name(0)}")
print(f"GPU显存: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
else:
print("使用CPU训练")
def load_model(self, model_path):
"""加载微调后的模型"""
self.tokenizer = BertTokenizer.from_pretrained(model_path)
self.model = BertForSequenceClassification.from_pretrained(
model_path,
num_labels=self.num_labels
)
self.model.to(self.device)
self.model.eval()
print(f"BERT模型加载成功: {model_path}")
def train_model(self, train_dataset, eval_dataset, output_dir='./models/deep_learning/bert_finetuned',
num_train_epochs=3, per_device_train_batch_size=8, per_device_eval_batch_size=16,
learning_rate=2e-5, warmup_steps=500, weight_decay=0.01, fp16=None):
"""
训练BERT模型
:param train_dataset: 训练数据集
:param eval_dataset: 验证数据集
:param output_dir: 模型输出目录
:param num_train_epochs: 训练轮数
:param per_device_train_batch_size: 每个设备的训练batch大小
:param per_device_eval_batch_size: 每个设备的验证batch大小
:param learning_rate: 学习率
:param warmup_steps: 预热步数
:param weight_decay: 权重衰减
:param fp16: 是否使用混合精度None表示自动检测
"""
        # Auto-detect mixed precision when not specified
        if fp16 is None:
            fp16 = self.use_gpu
        # Initialize from pretrained weights if the model has not been loaded yet
        # (train_bert.py calls train_model directly, without load_model)
        if self.model is None:
            self.tokenizer = BertTokenizer.from_pretrained(self.model_name)
            self.model = BertForSequenceClassification.from_pretrained(
                self.model_name,
                num_labels=self.num_labels
            )
            self.model.to(self.device)
        # Configure the training arguments
        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=num_train_epochs,
            per_device_train_batch_size=per_device_train_batch_size,
            per_device_eval_batch_size=per_device_eval_batch_size,
            learning_rate=learning_rate,
            warmup_steps=warmup_steps,
            weight_decay=weight_decay,
            logging_dir='./logs',
            logging_steps=10,
            evaluation_strategy="epoch",
            save_strategy="epoch",
            load_best_model_at_end=True,
            metric_for_best_model="accuracy",
            fp16=fp16,  # mixed precision
            gradient_accumulation_steps=1,
        )
        # metric_for_best_model="accuracy" needs a compute_metrics that reports it
        def compute_metrics(eval_pred):
            logits, labels = eval_pred
            preds = logits.argmax(axis=-1)
            return {"accuracy": float((preds == labels).mean())}
        # Create the trainer
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            compute_metrics=compute_metrics
        )
# 开始训练
print("开始训练BERT模型...")
trainer.train()
print("训练完成!")
# 保存模型
trainer.save_model(output_dir)
self.tokenizer.save_pretrained(output_dir)
print(f"模型已保存到: {output_dir}")
def predict(self, title: str, content: str) -> Dict[str, Any]:
"""
预测
"""
if self.model is None or self.tokenizer is None:
raise ValueError("模型未加载请先调用load_model")
# 组合标题和内容
text = f"{title} [SEP] {content}"
# 分词
inputs = self.tokenizer(
text,
return_tensors='pt',
truncation=True,
max_length=512,
padding='max_length'
)
# 预测
with torch.no_grad():
inputs = {k: v.to(self.device) for k, v in inputs.items()}
outputs = self.model(**inputs)
logits = outputs.logits
# 获取预测结果
probs = torch.softmax(logits, dim=-1)
confidence, predicted_id = torch.max(probs, dim=-1)
predicted_id = predicted_id.item()
confidence = confidence.item()
# 获取各类别概率
prob_dict = {}
for i, prob in enumerate(probs[0].cpu().numpy()):
category_code = ID_TO_LABEL[i]
prob_dict[category_code] = float(prob)
return {
'categoryCode': ID_TO_LABEL[predicted_id],
'categoryName': CATEGORY_MAP.get(ID_TO_LABEL[predicted_id], '未知'),
'confidence': confidence,
'probabilities': prob_dict
}
# 数据集类
class NewsDataset(torch.utils.data.Dataset):
"""新闻数据集"""
def __init__(self, texts, labels, tokenizer, max_length=512):
self.texts = texts
self.labels = labels
self.tokenizer = tokenizer
self.max_length = max_length
def __len__(self):
return len(self.texts)
def __getitem__(self, idx):
text = self.texts[idx]
label = self.labels[idx]
encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
padding='max_length',
return_tensors='pt'
)
return {
'input_ids': encoding['input_ids'].flatten(),
'attention_mask': encoding['attention_mask'].flatten(),
'labels': torch.tensor(label, dtype=torch.long)
}
if __name__ == '__main__':
# 测试
classifier = BertClassifier()
# classifier.load_model('./models/deep_learning/bert_finetuned')
#
# result = classifier.predict(
# title="华为发布新款折叠屏手机",
# content="华为今天正式发布了新一代折叠屏手机..."
# )
# print(result)
print("BERT分类器初始化成功")


@ -0,0 +1,148 @@
"""
混合策略分类器
结合规则引擎和机器学习模型
"""
import time
from typing import Dict, Any
from ..traditional.predict import TraditionalPredictor, CATEGORY_MAP
from ..deep_learning.bert_model import BertClassifier
class HybridClassifier:
"""混合分类器"""
def __init__(self):
# 初始化各个分类器
self.nb_predictor = TraditionalPredictor('nb')
self.bert_classifier = BertClassifier()
# 配置参数
self.config = {
'confidence_threshold': 0.75, # 高置信度阈值
'hybrid_min_confidence': 0.60, # 混合模式最低阈值
'use_bert_threshold': 0.70, # 使用BERT的阈值
'rule_priority': True # 规则优先
}
        # Rule keyword dictionary
        self.rule_keywords = {
            'ENTERTAINMENT': ['明星', '电影', '电视剧', '娱乐圈', '歌手', '娱乐'],
            'SPORTS': ['比赛', '冠军', '联赛', '球员', '教练', 'NBA', '足球', '篮球'],
            'FINANCE': ['股市', '经济', '金融', '投资', '基金', '银行', '理财'],
            'TECHNOLOGY': ['芯片', 'AI', '人工智能', '5G', '互联网', '科技', '数码'],
            'MILITARY': ['军事', '武器', '军队', '国防', '战争'],
            'AUTOMOTIVE': ['汽车', '车型', '驾驶', '交通', 'automotive'],
            'GOVERNMENT': ['政府', '政策', '选举', '国务院', '主席', '总理', '政务'],
            'HEALTH': ['健康', '医疗', '疾病', '治疗', '疫苗'],
            'AI': ['AI', '人工智能', '机器学习', '深度学习', '算法']
        }
def rule_match(self, title: str, content: str) -> tuple[str | None, float]:
"""
规则匹配
:return: (category_code, confidence)
"""
text = title + ' ' + content
# 计算每个类别的关键词匹配数
matches = {}
for category, keywords in self.rule_keywords.items():
count = sum(1 for kw in keywords if kw in text)
if count > 0:
matches[category] = count
if not matches:
return None, 0.0
# 返回匹配最多的类别
best_category = max(matches, key=matches.get)
confidence = min(0.9, matches[best_category] * 0.15) # 规则置信度
return best_category, confidence
def predict(self, title: str, content: str, use_bert=True) -> Dict[str, Any]:
"""
混合预测
"""
start_time = time.time()
# 1. 先尝试规则匹配
rule_category, rule_confidence = self.rule_match(title, content)
# 2. 传统机器学习预测
nb_result = self.nb_predictor.predict(title, content)
nb_confidence = nb_result['confidence']
# 决策逻辑
final_result = None
classifier_type = 'HYBRID'
# 规则优先且规则置信度高
if self.config['rule_priority'] and rule_confidence >= self.config['confidence_threshold']:
            final_result = {
                'categoryCode': rule_category,
                'categoryName': CATEGORY_MAP.get(rule_category, '未知'),  # look the name up from the mapping
                'confidence': rule_confidence,
                'classifierType': 'RULE',
                'reason': '规则匹配'
            }
# 传统模型置信度足够高
elif nb_confidence >= self.config['confidence_threshold']:
final_result = {
**nb_result,
'classifierType': 'ML',
'reason': '传统模型高置信度'
}
# 需要使用BERT
elif use_bert:
# TODO: 加载BERT模型预测
# bert_result = self.bert_classifier.predict(title, content)
# 如果BERT置信度也不高选择最高的
final_result = {
**nb_result,
'classifierType': 'HYBRID',
'reason': '混合决策'
}
else:
# 不使用BERT直接返回传统模型结果
final_result = {
**nb_result,
'classifierType': 'ML',
'reason': '默认传统模型'
}
# 计算耗时
duration = int((time.time() - start_time) * 1000)
final_result['duration'] = duration
return final_result
if __name__ == '__main__':
# 测试
classifier = HybridClassifier()
test_cases = [
{
'title': '国务院发布最新经济政策',
'content': '国务院今天发布了新的经济政策...'
},
{
'title': '华为发布新款折叠屏手机',
'content': '华为今天正式发布了新一代折叠屏手机...'
},
{
'title': 'NBA总决赛精彩瞬间',
'content': '湖人队与勇士队的总决赛...'
}
]
for case in test_cases:
result = classifier.predict(case['title'], case['content'])
print(f"标题: {case['title']}")
print(f"结果: {result['categoryName']} ({result['confidence']:.2f})")
print(f"分类器: {result['classifierType']}")
print(f"原因: {result.get('reason', 'N/A')}")
print(f"耗时: {result['duration']}ms")
print("-" * 50)


@ -0,0 +1,94 @@
"""
传统机器学习模型预测
"""
import os
import joblib
import jieba
from typing import Dict, Any
# 分类映射(根据数据库中的分类)
CATEGORY_MAP = {
'ENTERTAINMENT': '娱乐',
'SPORTS': '体育',
'FINANCE': '财经',
'TECHNOLOGY': '科技',
'MILITARY': '军事',
'AUTOMOTIVE': '汽车',
'GOVERNMENT': '政务',
'HEALTH': '健康',
'AI': 'AI'
}
class TraditionalPredictor:
"""传统机器学习预测器"""
def __init__(self, model_type='nb', model_dir='../../models/traditional'):
self.model_type = model_type
self.model_dir = model_dir
self.vectorizer = None
self.classifier = None
self._load_model()
def _load_model(self):
"""加载模型"""
vectorizer_path = os.path.join(self.model_dir, f'{self.model_type}_vectorizer.pkl')
classifier_path = os.path.join(self.model_dir, f'{self.model_type}_classifier.pkl')
self.vectorizer = joblib.load(vectorizer_path)
self.classifier = joblib.load(classifier_path)
print(f"模型加载成功: {self.model_type}")
def preprocess(self, title: str, content: str) -> str:
"""预处理文本"""
text = title + ' ' + content
# jieba分词
words = jieba.cut(text)
return ' '.join(words)
def predict(self, title: str, content: str) -> Dict[str, Any]:
"""
预测
:return: 预测结果字典
"""
# 预处理
processed = self.preprocess(title, content)
# 特征提取
tfidf = self.vectorizer.transform([processed])
# 预测
prediction = self.classifier.predict(tfidf)[0]
probabilities = self.classifier.predict_proba(tfidf)[0]
# 获取各类别概率
prob_dict = {}
for i, prob in enumerate(probabilities):
category_code = self.classifier.classes_[i]
prob_dict[category_code] = float(prob)
return {
'categoryCode': prediction,
'categoryName': CATEGORY_MAP.get(prediction, '未知'),
'confidence': float(probabilities.max()),
'probabilities': prob_dict
}
# API入口
def predict_single(title: str, content: str, model_type='nb') -> Dict[str, Any]:
"""
单条预测API
"""
predictor = TraditionalPredictor(model_type)
return predictor.predict(title, content)
if __name__ == '__main__':
# 测试
result = predict_single(
title="华为发布新款折叠屏手机",
content="华为今天正式发布了新一代折叠屏手机,搭载最新麒麟芯片..."
)
print(result)


@ -14,45 +14,95 @@ from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score, f1_score

# Category mapping (matches the database tables)
CATEGORY_MAP = {
    'ENTERTAINMENT': '娱乐',
    'SPORTS': '体育',
    'FINANCE': '财经',
    'TECHNOLOGY': '科技',
    'MILITARY': '军事',
    'AUTOMOTIVE': '汽车',
    'GOVERNMENT': '政务',
    'HEALTH': '健康',
    'AI': 'AI'
}
REVERSE_CATEGORY_MAP = {v: k for k, v in CATEGORY_MAP.items()}

# Category-specific stopwords (matches the database tables)
CATEGORY_STOPWORDS = {
    'ENTERTAINMENT': {'主演', '影片', '电影', '电视剧', '节目', '导演', '角色', '上映', '粉丝'},
    'SPORTS': {'比赛', '赛事', '赛季', '球队', '选手', '球员', '主场', '客场', '对阵', '比分'},
    'FINANCE': {'亿元', '万元', '同比', '环比', '增长率', '数据', '报告', '统计', '财报', '季度', '年度'},
    'TECHNOLOGY': {'技术', '系统', '平台', '方案', '应用', '功能', '版本', '升级', '研发', '推出'},
    'MILITARY': {'部队', '军方', '演习', '训练', '装备', '武器', '作战', '行动', '部署'},
    'AUTOMOTIVE': {'汽车', '车型', '上市', '发布', '销量', '市场', '品牌', '厂商'},
    'GOVERNMENT': {'会议', '讲话', '指出', '强调', '部署', '落实', '推进', '要求', '精神', '决定', '意见', '方案', '安排'},
    'HEALTH': {'医生', '专家', '建议', '提示', '提醒', '研究', '发现', '可能', '有助于'},
    'AI': {'技术', '系统', '模型', '算法', '应用', '功能', '版本', '升级', '研发', '推出', '人工智能'}
}

class NewsClassifier:
    """News text classifier"""

    def __init__(self, model_type='nb', use_stopwords=True, use_category_stopwords=False):
        """
        Initialize the classifier
        :param model_type: model type, 'nb' for Naive Bayes, 'svm' for support vector machine
        :param use_stopwords: whether to apply the generic stopword list
        :param use_category_stopwords: whether to apply category-specific stopwords
        """
        self.model_type = model_type
        self.vectorizer = None
        self.classifier = None
        self.categories = list(CATEGORY_MAP.keys())
        self.use_stopwords = use_stopwords
        self.use_category_stopwords = use_category_stopwords
        self.stopwords = set()
        if self.use_stopwords:
            self._load_stopwords()

    def _load_stopwords(self, stopwords_path='../../data/news_stopwords.txt'):
        """
        Load the stopword list
        """
        try:
            with open(stopwords_path, 'r', encoding='utf-8') as f:
                self.stopwords = set(line.strip() for line in f if line.strip())
            print(f"已加载 {len(self.stopwords)} 个停用词")
        except FileNotFoundError:
            print(f"警告: 停用词文件不存在: {stopwords_path}")

    def preprocess_text(self, text, category=None):
        """
        Text preprocessing: jieba tokenization + stopword filtering
        :param text: text to process
        :param category: optional; apply category-specific stopwords when given
        """
        # Remove extra whitespace and line breaks
        text = ' '.join(text.split())
        # Tokenize with jieba
        words = jieba.cut(text)
        # Filter stopwords and single-character tokens
        result = []
        for w in words:
            # Drop single-character tokens
            if len(w) <= 1:
                continue
            # Drop generic stopwords
            if self.use_stopwords and w in self.stopwords:
                continue
            # Drop category-specific stopwords
            if self.use_category_stopwords and category:
                if w in CATEGORY_STOPWORDS.get(category, set()):
                    continue
            result.append(w)
        return ' '.join(result)

    def load_data(self, csv_path):
        """
@ -61,10 +111,22 @@
        df = pd.read_csv(csv_path)
        # Combine title and content as the feature text
        df['text'] = df['title'] + ' ' + df['content']

        # Preprocess (category info is needed when category-specific stopwords are enabled)
        if self.use_category_stopwords:
            # Convert category names to codes first
            df['category_code'] = df['category_name'].map(REVERSE_CATEGORY_MAP)
            # Preprocess row by row, passing the category code
            df['processed_text'] = df.apply(
                lambda row: self.preprocess_text(row['text'], row['category_code']),
                axis=1
            )
        else:
            # No category-specific stopwords: preprocess in one batch
            df['processed_text'] = df['text'].apply(self.preprocess_text)
            # Convert category names to codes
            df['category_code'] = df['category_name'].map(REVERSE_CATEGORY_MAP)
        return df

    def train(self, df):
@ -74,6 +136,10 @@
        X = df['processed_text'].values
        y = df['category_code'].values

        # Categories actually present in the data
        actual_categories = sorted(df['category_code'].unique().tolist())
        actual_category_names = [CATEGORY_MAP[cat] for cat in actual_categories]

        # Split into train and test sets
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42, stratify=y
@ -108,7 +174,7 @@
        print(f"准确率: {accuracy:.4f}")
        print(f"F1-Score: {f1:.4f}")
        print("\n分类报告:")
        print(classification_report(y_test, y_pred, target_names=actual_category_names))

        return accuracy, f1
@ -120,6 +186,7 @@
            raise ValueError("模型未训练请先调用train方法")
        text = title + ' ' + content
        # No category is given at prediction time; only generic stopwords apply
        processed = self.preprocess_text(text)
        tfidf = self.vectorizer.transform([processed])
@ -157,12 +224,12 @@
    classifier = NewsClassifier(model_type='nb')
    # Training data file is expected here
    train_data_path = '../../data/processed/training_data.csv'
    if os.path.exists(train_data_path):
        df = classifier.load_data(train_data_path)
        classifier.train(df)
        classifier.save_model('../../models/traditional')

    # Test prediction
    test_title = "华为发布新款折叠屏手机"

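A short sketch of driving the updated trainer end to end (paths and the `category_name` CSV column are taken from the diff above):

```python
clf = NewsClassifier(model_type='nb', use_stopwords=True, use_category_stopwords=True)
df = clf.load_data('../../data/processed/training_data.csv')  # needs title/content/category_name columns
clf.train(df)
clf.save_model('../../models/traditional')
```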

@ -0,0 +1,163 @@
"""
数据加载和本地存储工具
"""
import pandas as pd
import os
from sqlalchemy import create_engine
from typing import List, Dict, Any, Optional
import logging
logger = logging.getLogger(__name__)
class DataLoader:
"""数据加载和本地存储类"""
def __init__(self, db_url: str, data_dir: str = 'data'):
"""
初始化数据加载器
:param db_url: 数据库连接URL
:param data_dir: 数据存储目录
"""
self.db_url = db_url
self.data_dir = data_dir
self.engine = create_engine(db_url)
# 确保数据目录存在
os.makedirs(os.path.join(data_dir, 'raw'), exist_ok=True)
os.makedirs(os.path.join(data_dir, 'processed'), exist_ok=True)
def load_news_from_db(self, limit: int = None, category_ids: List[int] = None) -> pd.DataFrame:
"""
从数据库加载新闻数据
:param limit: 限制加载数据量
:param category_ids: 指定分类ID列表
:return: 新闻数据DataFrame
"""
query = """
SELECT n.id, n.title, n.content, c.name as category_name
FROM news n
LEFT JOIN news_category c ON n.category_id = c.id
WHERE n.content IS NOT NULL AND n.title IS NOT NULL
"""
if category_ids:
ids_str = ",".join(str(i) for i in category_ids)
query += f" AND c.id IN ({ids_str})"
if limit:
query += f" LIMIT {limit}"
        logger.info(f"从数据库加载数据,条件: limit={limit}, category_ids={category_ids}")
        try:
df = pd.read_sql(query, self.engine)
logger.info(f"成功加载 {len(df)} 条新闻数据")
return df
except Exception as e:
logger.error(f"加载数据失败: {e}")
raise
def save_to_local(self, df: pd.DataFrame, filename: str, subdir: str = 'processed'):
"""
将数据保存到本地
:param df: 要保存的DataFrame
:param filename: 文件名
:param subdir: 子目录
"""
file_path = os.path.join(self.data_dir, subdir, filename)
# 确保目录存在
os.makedirs(os.path.join(self.data_dir, subdir), exist_ok=True)
# 保存为CSV
df.to_csv(file_path, index=False, encoding='utf-8-sig')
logger.info(f"数据已保存到: {file_path}")
def load_from_local(self, filename: str, subdir: str = 'processed') -> Optional[pd.DataFrame]:
"""
从本地加载已保存的数据
:param filename: 文件名
:param subdir: 子目录
:return: DataFrame或None如果文件不存在
"""
file_path = os.path.join(self.data_dir, subdir, filename)
if os.path.exists(file_path):
df = pd.read_csv(file_path, encoding='utf-8-sig')
logger.info(f"从本地加载 {len(df)} 条数据: {file_path}")
return df
else:
logger.warning(f"本地文件不存在: {file_path}")
return None
def update_local_data(self, limit: int = None, category_ids: List[int] = None):
"""
更新本地数据从数据库重新加载数据并保存
:param limit: 限制加载数据量
:param category_ids: 指定分类ID列表
"""
# 从数据库加载数据
df = self.load_news_from_db(limit=limit, category_ids=category_ids)
if not df.empty:
# 生成文件名(包含时间戳)
import time
timestamp = time.strftime("%Y%m%d_%H%M%S")
filename = f"news_data_{timestamp}.csv"
# 保存到本地
self.save_to_local(df, filename)
# 可选:删除旧的本地文件,只保留最新的
self._cleanup_old_files(filename)
return df
else:
logger.warning("没有可用的数据需要保存")
return None
def _cleanup_old_files(self, current_filename: str):
"""
清理旧的本地文件可选
:param current_filename: 当前文件名
"""
subdir = 'processed'
dir_path = os.path.join(self.data_dir, subdir)
if os.path.exists(dir_path):
for filename in os.listdir(dir_path):
if filename != current_filename and filename.startswith('news_data_'):
file_path = os.path.join(dir_path, filename)
try:
os.remove(file_path)
logger.info(f"已删除旧文件: {file_path}")
except Exception as e:
logger.error(f"删除文件失败 {file_path}: {e}")
# 示例用法
if __name__ == '__main__':
# 配置数据库连接
    db_url = "mysql+pymysql://root:root@localhost/news_classifier"
# 创建数据加载器
loader = DataLoader(db_url)
# 从数据库加载数据并保存到本地
data = loader.update_local_data(limit=800)
if data is not None:
print(f"成功加载数据,共 {len(data)} 条记录")
print(data.head())
else:
print("没有数据可加载")


@ -0,0 +1,96 @@
"""
模型评估指标工具
"""
import numpy as np
from sklearn.metrics import (
accuracy_score,
precision_recall_fscore_support,
confusion_matrix,
classification_report
)
from typing import List, Dict, Any
import matplotlib.pyplot as plt
import seaborn as sns
class ClassificationMetrics:
"""分类评估指标"""
@staticmethod
def compute_all(y_true: List, y_pred: List, labels: List[str]) -> Dict[str, Any]:
"""
计算所有指标
"""
accuracy = accuracy_score(y_true, y_pred)
precision, recall, f1, support = precision_recall_fscore_support(
y_true, y_pred, average='weighted', zero_division=0
)
# 每个类别的指标
precision_per_class, recall_per_class, f1_per_class, support_per_class = \
precision_recall_fscore_support(y_true, y_pred, average=None, zero_division=0)
per_class_metrics = {}
for i, label in enumerate(labels):
per_class_metrics[label] = {
'precision': float(precision_per_class[i]),
'recall': float(recall_per_class[i]),
'f1': float(f1_per_class[i]),
'support': int(support_per_class[i])
}
return {
'accuracy': float(accuracy),
'precision': float(precision),
'recall': float(recall),
'f1': float(f1),
'per_class': per_class_metrics
}
@staticmethod
def plot_confusion_matrix(y_true: List, y_pred: List, labels: List[str], save_path: str = None):
"""
绘制混淆矩阵
"""
cm = confusion_matrix(y_true, y_pred)
plt.figure(figsize=(10, 8))
sns.heatmap(
cm,
annot=True,
fmt='d',
cmap='Blues',
xticklabels=labels,
yticklabels=labels
)
plt.xlabel('预测标签')
plt.ylabel('真实标签')
plt.title('混淆矩阵')
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
plt.close()
@staticmethod
def print_report(y_true: List, y_pred: List, labels: List[str]):
"""
打印分类报告
"""
report = classification_report(
y_true, y_pred,
target_names=labels,
zero_division=0
)
print(report)
if __name__ == '__main__':
# 测试
y_true = ['ENTERTAINMENT', 'TECHNOLOGY', 'FINANCE', 'ENTERTAINMENT', 'TECHNOLOGY']
y_pred = ['ENTERTAINMENT', 'TECHNOLOGY', 'FINANCE', 'TECHNOLOGY', 'TECHNOLOGY']
labels = ['ENTERTAINMENT', 'TECHNOLOGY', 'FINANCE']
metrics = ClassificationMetrics()
result = metrics.compute_all(y_true, y_pred, labels)
print(result)

ml-module/train_bert.py

@ -0,0 +1,114 @@
"""
BERT模型训练脚本
支持GPU/CPU自动检测和动态参数调整
"""
import argparse
import yaml
import os
from src.deep_learning.bert_model import BertClassifier
from src.utils.data_loader import DataLoader
from datasets import Dataset
import pandas as pd
from transformers import BertTokenizer

def load_and_prepare_data(db_url: str, limit: int = 1000):
    """Load data and prepare the training set"""
    # Create the data loader
    loader = DataLoader(db_url)
    # Load data from the database
    df = loader.load_news_from_db(limit=limit)
    if df.empty:
        print("没有可用的训练数据")
        return None, None
    # Map category names to label ids
    labels = df['category_name'].map({
        '娱乐': 0, '体育': 1, '财经': 2, '科技': 3, '军事': 4,
        '汽车': 5, '政务': 6, '健康': 7, 'AI': 8
    }).fillna(-1).astype(int)
    # Drop rows with unknown categories
    valid_data = df[labels != -1]
    texts = valid_data['title'] + ' ' + valid_data['content']
    labels = labels[labels != -1]
    print(f"有效数据数量: {len(valid_data)}")
    # Build a Hugging Face Dataset
    dataset = Dataset.from_pandas(pd.DataFrame({
        'text': texts.tolist(),
        'label': labels.tolist()
    }))
    # The Trainer needs tokenized inputs: tokenize and drop the raw text column
    # (tokenizer name assumed to match the config's bert-base-chinese)
    tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
    def tokenize(batch):
        return tokenizer(batch['text'], truncation=True, padding='max_length', max_length=512)
    dataset = dataset.map(tokenize, batched=True, remove_columns=['text'])
    # Split into training and validation sets
    train_test = dataset.train_test_split(test_size=0.2)
    return train_test['train'], train_test['test']

def train_bert_model(config: dict):
    """Train the BERT model"""
    # Load the data (data_limit lives under the database section of config.yaml)
    train_dataset, eval_dataset = load_and_prepare_data(
        config['database']['url'],
        limit=config['database']['data_limit']
    )
    if train_dataset is None:
        return
# 初始化模型
classifier = BertClassifier(
model_name=config['model']['name'],
num_labels=config['model']['num_labels'],
use_gpu=config['training']['use_gpu']
)
# 训练模型
classifier.train_model(
train_dataset=train_dataset,
eval_dataset=eval_dataset,
output_dir=config['model']['output_dir'],
num_train_epochs=config['training']['epochs'],
per_device_train_batch_size=config['training']['batch_size'],
per_device_eval_batch_size=config['training']['batch_size'] * 2,
learning_rate=config['training']['learning_rate'],
warmup_steps=config['training']['warmup_steps'],
weight_decay=config['training']['weight_decay'],
fp16=config['training'].get('fp16', None)
)
def main():
# 解析命令行参数
parser = argparse.ArgumentParser(description='BERT模型训练脚本')
parser.add_argument('--config', type=str, default='config.yaml', help='配置文件路径')
parser.add_argument('--use_gpu', action='store_true', help='强制使用GPU')
parser.add_argument('--epochs', type=int, help='训练轮数(覆盖配置文件)')
parser.add_argument('--batch_size', type=int, help='批大小(覆盖配置文件)')
args = parser.parse_args()
# 加载配置文件
with open(args.config, 'r', encoding='utf-8') as f:
config = yaml.safe_load(f)
# 覆盖配置参数
if args.epochs is not None:
config['training']['epochs'] = args.epochs
if args.batch_size is not None:
config['training']['batch_size'] = args.batch_size
if args.use_gpu:
config['training']['use_gpu'] = True
# 开始训练
print("开始BERT模型训练...")
print(f"配置: {config}")
train_bert_model(config)
print("训练完成!")
if __name__ == '__main__':
main()


@ -0,0 +1,800 @@
# News Text Classification System - Module Development Task List

---

## Contents
1. [Crawler Module (Python)](#1-crawler-module-python)
2. [Backend Service Module (Spring Boot)](#2-backend-service-module-spring-boot)
3. [Frontend Desktop Module (Tauri + Vue3)](#3-frontend-desktop-module-tauri--vue3)
4. [Machine Learning Classification Module (Python)](#4-machine-learning-classification-module-python)

---

## 1. Crawler Module (Python)

## 2. Backend Service Module (Spring Boot)

## 4. Machine Learning Classification Module (Python)

### Module directory structure
```
ml-module/
├── data/
│   ├── raw/                    # raw data
│   ├── processed/              # processed data
│   │   ├── training_data.csv
│   │   └── test_data.csv
│   └── external/               # external datasets
├── models/                     # trained models
│   ├── traditional/
│   │   ├── nb_vectorizer.pkl
│   │   ├── nb_classifier.pkl
│   │   ├── svm_vectorizer.pkl
│   │   └── svm_classifier.pkl
│   ├── deep_learning/
│   │   └── bert_finetuned/
│   └── hybrid/
│       └── config.json
├── src/
│   ├── __init__.py
│   ├── traditional/            # traditional machine learning
│   │   ├── __init__.py
│   │   ├── train_model.py      # (already exists)
│   │   ├── predict.py
│   │   └── evaluate.py
│   ├── deep_learning/          # deep learning
│   │   ├── __init__.py
│   │   ├── bert_model.py
│   │   ├── train_bert.py
│   │   └── predict_bert.py
│   ├── hybrid/                 # hybrid strategy
│   │   ├── __init__.py
│   │   ├── hybrid_classifier.py
│   │   └── rule_engine.py
│   ├── utils/
│   │   ├── __init__.py
│   │   ├── preprocessing.py    # data preprocessing
│   │   └── metrics.py          # evaluation metrics
│   └── api/                    # API service
│       ├── __init__.py
│       └── server.py           # FastAPI service
├── notebooks/                  # Jupyter notebooks
│   ├── data_exploration.ipynb
│   └── model_comparison.ipynb
├── tests/                      # tests
│   ├── test_traditional.py
│   ├── test_bert.py
│   └── test_hybrid.py
├── requirements.txt
├── setup.py
└── README.md
```
### 4.1 Files to implement
#### Task 4.1.1: `src/traditional/predict.py` - Traditional model prediction
```python
"""
传统机器学习模型预测
"""
import os
import joblib
import jieba
from typing import Dict, Any
# 分类映射
CATEGORY_MAP = {
'POLITICS': '时政',
'FINANCE': '财经',
'TECHNOLOGY': '科技',
'SPORTS': '体育',
'ENTERTAINMENT': '娱乐',
'HEALTH': '健康',
'EDUCATION': '教育',
'LIFE': '生活',
'INTERNATIONAL': '国际',
'MILITARY': '军事'
}
class TraditionalPredictor:
"""传统机器学习预测器"""
def __init__(self, model_type='nb', model_dir='../../models/traditional'):
self.model_type = model_type
self.model_dir = model_dir
self.vectorizer = None
self.classifier = None
self._load_model()
def _load_model(self):
"""加载模型"""
vectorizer_path = os.path.join(self.model_dir, f'{self.model_type}_vectorizer.pkl')
classifier_path = os.path.join(self.model_dir, f'{self.model_type}_classifier.pkl')
self.vectorizer = joblib.load(vectorizer_path)
self.classifier = joblib.load(classifier_path)
print(f"模型加载成功: {self.model_type}")
def preprocess(self, title: str, content: str) -> str:
"""预处理文本"""
text = title + ' ' + content
# jieba分词
words = jieba.cut(text)
return ' '.join(words)
def predict(self, title: str, content: str) -> Dict[str, Any]:
"""
预测
:return: 预测结果字典
"""
# 预处理
processed = self.preprocess(title, content)
# 特征提取
tfidf = self.vectorizer.transform([processed])
# 预测
prediction = self.classifier.predict(tfidf)[0]
probabilities = self.classifier.predict_proba(tfidf)[0]
# 获取各类别概率
prob_dict = {}
for i, prob in enumerate(probabilities):
category_code = self.classifier.classes_[i]
prob_dict[category_code] = float(prob)
return {
'categoryCode': prediction,
'categoryName': CATEGORY_MAP.get(prediction, '未知'),
'confidence': float(probabilities.max()),
'probabilities': prob_dict
}
# API入口
def predict_single(title: str, content: str, model_type='nb') -> Dict[str, Any]:
"""
单条预测API
"""
predictor = TraditionalPredictor(model_type)
return predictor.predict(title, content)
if __name__ == '__main__':
# 测试
result = predict_single(
title="华为发布新款折叠屏手机",
content="华为今天正式发布了新一代折叠屏手机,搭载最新麒麟芯片..."
)
print(result)
```
#### Task 4.1.2: `src/deep_learning/bert_model.py` - BERT model
```python
"""
BERT文本分类模型
"""
import torch
from transformers import (
BertTokenizer,
BertForSequenceClassification,
Trainer,
TrainingArguments
)
from typing import Dict, Any, List
# 分类映射
CATEGORY_MAP = {
'POLITICS': '时政',
'FINANCE': '财经',
'TECHNOLOGY': '科技',
'SPORTS': '体育',
'ENTERTAINMENT': '娱乐',
'HEALTH': '健康',
'EDUCATION': '教育',
'LIFE': '生活',
'INTERNATIONAL': '国际',
'MILITARY': '军事'
}
# 反向映射
ID_TO_LABEL = {i: label for i, label in enumerate(CATEGORY_MAP.keys())}
LABEL_TO_ID = {label: i for i, label in enumerate(CATEGORY_MAP.keys())}
class BertClassifier:
"""BERT文本分类器"""
def __init__(self, model_name='bert-base-chinese', num_labels=10):
self.model_name = model_name
self.num_labels = num_labels
self.tokenizer = None
self.model = None
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def load_model(self, model_path):
"""加载微调后的模型"""
self.tokenizer = BertTokenizer.from_pretrained(model_path)
self.model = BertForSequenceClassification.from_pretrained(
model_path,
num_labels=self.num_labels
)
self.model.to(self.device)
self.model.eval()
print(f"BERT模型加载成功: {model_path}")
def predict(self, title: str, content: str) -> Dict[str, Any]:
"""
预测
"""
if self.model is None or self.tokenizer is None:
raise ValueError("模型未加载请先调用load_model")
# 组合标题和内容
text = f"{title} [SEP] {content}"
# 分词
inputs = self.tokenizer(
text,
return_tensors='pt',
truncation=True,
max_length=512,
padding='max_length'
)
# 预测
with torch.no_grad():
inputs = {k: v.to(self.device) for k, v in inputs.items()}
outputs = self.model(**inputs)
logits = outputs.logits
# 获取预测结果
probs = torch.softmax(logits, dim=-1)
confidence, predicted_id = torch.max(probs, dim=-1)
predicted_id = predicted_id.item()
confidence = confidence.item()
# 获取各类别概率
prob_dict = {}
for i, prob in enumerate(probs[0].cpu().numpy()):
category_code = ID_TO_LABEL[i]
prob_dict[category_code] = float(prob)
return {
'categoryCode': ID_TO_LABEL[predicted_id],
'categoryName': CATEGORY_MAP.get(ID_TO_LABEL[predicted_id], '未知'),
'confidence': confidence,
'probabilities': prob_dict
}
# 数据集类
class NewsDataset(torch.utils.data.Dataset):
"""新闻数据集"""
def __init__(self, texts, labels, tokenizer, max_length=512):
self.texts = texts
self.labels = labels
self.tokenizer = tokenizer
self.max_length = max_length
def __len__(self):
return len(self.texts)
def __getitem__(self, idx):
text = self.texts[idx]
label = self.labels[idx]
encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
padding='max_length',
return_tensors='pt'
)
return {
'input_ids': encoding['input_ids'].flatten(),
'attention_mask': encoding['attention_mask'].flatten(),
'labels': torch.tensor(label, dtype=torch.long)
}
if __name__ == '__main__':
# 测试
classifier = BertClassifier()
# classifier.load_model('./models/deep_learning/bert_finetuned')
#
# result = classifier.predict(
# title="华为发布新款折叠屏手机",
# content="华为今天正式发布了新一代折叠屏手机..."
# )
# print(result)
print("BERT分类器初始化成功")
```
#### Task 4.1.3: `src/hybrid/hybrid_classifier.py` - Hybrid classifier
```python
"""
混合策略分类器
结合规则引擎和机器学习模型
"""
import time
from typing import Dict, Any
from ..traditional.predict import TraditionalPredictor
from ..deep_learning.bert_model import BertClassifier
class HybridClassifier:
"""混合分类器"""
def __init__(self):
# 初始化各个分类器
self.nb_predictor = TraditionalPredictor('nb')
self.bert_classifier = BertClassifier()
# 配置参数
self.config = {
'confidence_threshold': 0.75, # 高置信度阈值
'hybrid_min_confidence': 0.60, # 混合模式最低阈值
'use_bert_threshold': 0.70, # 使用BERT的阈值
'rule_priority': True # 规则优先
}
# 规则关键词字典
self.rule_keywords = {
'POLITICS': ['政府', '政策', '选举', '国务院', '主席', '总理'],
'FINANCE': ['股市', '经济', '金融', '投资', '基金', '银行'],
'TECHNOLOGY': ['芯片', 'AI', '人工智能', '5G', '互联网', '科技'],
'SPORTS': ['比赛', '冠军', '联赛', '球员', '教练', 'NBA'],
'ENTERTAINMENT': ['明星', '电影', '电视剧', '娱乐圈', '歌手'],
'HEALTH': ['健康', '医疗', '疾病', '治疗', '疫苗'],
'EDUCATION': ['教育', '学校', '大学', '考试', '招生'],
'LIFE': ['生活', '美食', '旅游', '购物'],
'INTERNATIONAL': ['国际', '美国', '欧洲', '日本', '外交'],
'MILITARY': ['军事', '武器', '军队', '国防', '战争']
}
def rule_match(self, title: str, content: str) -> tuple[str | None, float]:
"""
规则匹配
:return: (category_code, confidence)
"""
text = title + ' ' + content
# 计算每个类别的关键词匹配数
matches = {}
for category, keywords in self.rule_keywords.items():
count = sum(1 for kw in keywords if kw in text)
if count > 0:
matches[category] = count
if not matches:
return None, 0.0
# 返回匹配最多的类别
best_category = max(matches, key=matches.get)
confidence = min(0.9, matches[best_category] * 0.15) # 规则置信度
return best_category, confidence
def predict(self, title: str, content: str, use_bert=True) -> Dict[str, Any]:
"""
混合预测
"""
start_time = time.time()
# 1. 先尝试规则匹配
rule_category, rule_confidence = self.rule_match(title, content)
# 2. 传统机器学习预测
nb_result = self.nb_predictor.predict(title, content)
nb_confidence = nb_result['confidence']
# 决策逻辑
final_result = None
classifier_type = 'HYBRID'
# 规则优先且规则置信度高
if self.config['rule_priority'] and rule_confidence >= self.config['confidence_threshold']:
final_result = {
'categoryCode': rule_category,
'categoryName': nb_result['categoryName'], # 从映射获取
'confidence': rule_confidence,
'classifierType': 'RULE',
'reason': '规则匹配'
}
# 传统模型置信度足够高
elif nb_confidence >= self.config['confidence_threshold']:
final_result = {
**nb_result,
'classifierType': 'ML',
'reason': '传统模型高置信度'
}
# 需要使用BERT
elif use_bert:
# TODO: 加载BERT模型预测
# bert_result = self.bert_classifier.predict(title, content)
# 如果BERT置信度也不高选择最高的
final_result = {
**nb_result,
'classifierType': 'HYBRID',
'reason': '混合决策'
}
else:
# 不使用BERT直接返回传统模型结果
final_result = {
**nb_result,
'classifierType': 'ML',
'reason': '默认传统模型'
}
# 计算耗时
duration = int((time.time() - start_time) * 1000)
final_result['duration'] = duration
return final_result
if __name__ == '__main__':
# 测试
classifier = HybridClassifier()
test_cases = [
{
'title': '国务院发布最新经济政策',
'content': '国务院今天发布了新的经济政策...'
},
{
'title': '华为发布新款折叠屏手机',
'content': '华为今天正式发布了新一代折叠屏手机...'
}
]
for case in test_cases:
result = classifier.predict(case['title'], case['content'])
print(f"标题: {case['title']}")
print(f"结果: {result['categoryName']} ({result['confidence']:.2f})")
print(f"分类器: {result['classifierType']}")
print(f"原因: {result.get('reason', 'N/A')}")
print(f"耗时: {result['duration']}ms")
print("-" * 50)
```
#### Task 4.1.4: `src/api/server.py` - FastAPI service
```python
"""
机器学习模型API服务
使用FastAPI提供RESTful API
"""
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional
import logging
# 导入分类器
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from traditional.predict import TraditionalPredictor
from hybrid.hybrid_classifier import HybridClassifier
# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# 创建FastAPI应用
app = FastAPI(
title="新闻分类API",
description="提供新闻文本分类服务",
version="1.0.0"
)
# 配置CORS
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 请求模型
class ClassifyRequest(BaseModel):
title: str
content: str
mode: Optional[str] = 'hybrid' # traditional, hybrid
# 响应模型
class ClassifyResponse(BaseModel):
categoryCode: str
categoryName: str
confidence: float
classifierType: str
duration: int
probabilities: Optional[dict] = None
# 初始化分类器
nb_predictor = None
hybrid_classifier = None
@app.on_event("startup")
async def startup_event():
"""启动时加载模型"""
global nb_predictor, hybrid_classifier
logger.info("加载模型...")
try:
nb_predictor = TraditionalPredictor('nb')
logger.info("朴素贝叶斯模型加载成功")
except Exception as e:
logger.error(f"朴素贝叶斯模型加载失败: {e}")
try:
hybrid_classifier = HybridClassifier()
logger.info("混合分类器初始化成功")
except Exception as e:
logger.error(f"混合分类器初始化失败: {e}")
@app.get("/")
async def root():
"""健康检查"""
return {
"status": "ok",
"message": "新闻分类API服务运行中"
}
@app.get("/health")
async def health_check():
"""健康检查"""
return {
"status": "healthy",
"models": {
"nb_loaded": nb_predictor is not None,
"hybrid_loaded": hybrid_classifier is not None
}
}
@app.post("/api/predict", response_model=ClassifyResponse)
async def predict(request: ClassifyRequest):
"""
文本分类接口
- **title**: 新闻标题
- **content**: 新闻内容
- **mode**: 分类模式 (traditional, hybrid)
"""
try:
if request.mode == 'traditional':
result = nb_predictor.predict(request.title, request.content)
result['classifierType'] = 'ML'
else: # hybrid
result = hybrid_classifier.predict(request.title, request.content)
return ClassifyResponse(**result)
except Exception as e:
logger.error(f"预测失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/batch-predict")
async def batch_predict(requests: list[ClassifyRequest]):
"""
批量分类接口
"""
results = []
for req in requests:
try:
if req.mode == 'traditional':
result = nb_predictor.predict(req.title, req.content)
result['classifierType'] = 'ML'
else:
result = hybrid_classifier.predict(req.title, req.content)
results.append(result)
except Exception as e:
results.append({
'error': str(e),
'title': req.title
})
return {"results": results}
if __name__ == '__main__':
import uvicorn
uvicorn.run(
app,
host="0.0.0.0",
port=5000,
log_level="info"
)
```
#### Task 4.1.5: `src/utils/metrics.py` - Evaluation metrics
```python
"""
模型评估指标工具
"""
import numpy as np
from sklearn.metrics import (
accuracy_score,
precision_recall_fscore_support,
confusion_matrix,
classification_report
)
from typing import List, Dict, Any
import matplotlib.pyplot as plt
import seaborn as sns
class ClassificationMetrics:
"""分类评估指标"""
@staticmethod
def compute_all(y_true: List, y_pred: List, labels: List[str]) -> Dict[str, Any]:
"""
计算所有指标
"""
accuracy = accuracy_score(y_true, y_pred)
precision, recall, f1, support = precision_recall_fscore_support(
y_true, y_pred, average='weighted', zero_division=0
)
# 每个类别的指标
precision_per_class, recall_per_class, f1_per_class, support_per_class = \
precision_recall_fscore_support(y_true, y_pred, average=None, zero_division=0)
per_class_metrics = {}
for i, label in enumerate(labels):
per_class_metrics[label] = {
'precision': float(precision_per_class[i]),
'recall': float(recall_per_class[i]),
'f1': float(f1_per_class[i]),
'support': int(support_per_class[i])
}
return {
'accuracy': float(accuracy),
'precision': float(precision),
'recall': float(recall),
'f1': float(f1),
'per_class': per_class_metrics
}
@staticmethod
def plot_confusion_matrix(y_true: List, y_pred: List, labels: List[str], save_path: str = None):
"""
绘制混淆矩阵
"""
cm = confusion_matrix(y_true, y_pred)
plt.figure(figsize=(10, 8))
sns.heatmap(
cm,
annot=True,
fmt='d',
cmap='Blues',
xticklabels=labels,
yticklabels=labels
)
plt.xlabel('预测标签')
plt.ylabel('真实标签')
plt.title('混淆矩阵')
if save_path:
plt.savefig(save_path, dpi=300, bbox_inches='tight')
plt.close()
@staticmethod
def print_report(y_true: List, y_pred: List, labels: List[str]):
"""
打印分类报告
"""
report = classification_report(
y_true, y_pred,
target_names=labels,
zero_division=0
)
print(report)
if __name__ == '__main__':
# 测试
y_true = ['POLITICS', 'TECHNOLOGY', 'FINANCE', 'POLITICS', 'TECHNOLOGY']
y_pred = ['POLITICS', 'TECHNOLOGY', 'FINANCE', 'TECHNOLOGY', 'TECHNOLOGY']
labels = ['POLITICS', 'TECHNOLOGY', 'FINANCE']
metrics = ClassificationMetrics()
result = metrics.compute_all(y_true, y_pred, labels)
print(result)
```
#### Task 4.1.6: `requirements.txt` - Dependency file
```txt
# 机器学习模块依赖
numpy>=1.24.0
pandas>=2.0.0
scikit-learn>=1.3.0
jieba>=0.42.0
joblib>=1.3.0
# 深度学习
torch>=2.0.0
transformers>=4.30.0
# API服务
fastapi>=0.100.0
uvicorn[standard]>=0.23.0
pydantic>=2.0.0
# 数据可视化
matplotlib>=3.7.0
seaborn>=0.12.0
# 工具
python-dotenv>=1.0.0
pyyaml>=6.0
```
---

## Summary

### Suggested development order

1. **Phase 1: Foundation**
   - Backend: database connection, entity classes, base configuration
   - Frontend: routing, state management, API wrappers

2. **Phase 2: Core features**
   - Crawler module (Python)
   - Traditional machine-learning classifier
   - Backend API endpoints
   - Frontend news list page

3. **Phase 3: Advanced features**
   - BERT deep-learning classifier
   - Hybrid-strategy classifier
   - Frontend classifier comparison page
   - Statistics charts

4. **Phase 4: Polish and optimization**
   - User authentication
   - Data visualization
   - Performance optimization
   - Error handling

### Key points

1. **The crawler module is written in Python** and talks to the Java backend over a RESTful API
2. **The classifier module is deployed independently** and exposes an HTTP interface for the backend
3. **Frontend and backend are separated**, with JWT-based authentication
4. **The database schema** is defined in `schema.sql` and must be followed strictly
5. **API responses use a unified format**, wrapped in `Result<T>`