Compare commits
2 Commits
83f4fd4d58
...
2afdd698b2
| Author | SHA1 | Date |
|---|---|---|
|
|
2afdd698b2 | |
|
|
61a5b7d301 |
|
|
@ -110,3 +110,13 @@ sources:
|
||||||
category_id: 9
|
category_id: 9
|
||||||
name: "AI"
|
name: "AI"
|
||||||
css_selector: "div.kr-information-left"
|
css_selector: "div.kr-information-left"
|
||||||
|
|
||||||
|
sina:
|
||||||
|
base_url: "https://sina.com.cn"
|
||||||
|
categories:
|
||||||
|
auto:
|
||||||
|
url: "https://auto.sina.com.cn/"
|
||||||
|
category_id: 6
|
||||||
|
name: "汽车"
|
||||||
|
css_selector: "div.feed_card.ty-feed-card-container div.cardlist-a__list div.ty-card.ty-card-type1"
|
||||||
|
detail_css_selector: "div.main-content"
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,82 @@
|
||||||
|
|
||||||
|
这是新浪网关于爬取汽车相关新闻的代码
|
||||||
|
```python
|
||||||
|
import requests
|
||||||
|
from bs4 import BeautifulSoup
|
||||||
|
|
||||||
|
|
||||||
|
URL = "https://auto.sina.com.cn/"
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"User-Agent": (
|
||||||
|
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
|
||||||
|
"AppleWebKit/537.36 (KHTML, like Gecko) "
|
||||||
|
"Chrome/120.0.0.0 Safari/537.36"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
resp = requests.get(URL,headers=headers,timeout=10)
|
||||||
|
# resp.raise_for_status()
|
||||||
|
# resp.encoding = "utf-8"
|
||||||
|
# print(resp.text)
|
||||||
|
with open("example/example-10.html","r",encoding="utf-8") as f:
|
||||||
|
html = f.read()
|
||||||
|
|
||||||
|
# soup = BeautifulSoup(resp.text,"lxml")
|
||||||
|
soup = BeautifulSoup(html,"lxml")
|
||||||
|
div_list = soup.select("div.feed_card.ty-feed-card-container div.cardlist-a__list div.ty-card.ty-card-type1")
|
||||||
|
|
||||||
|
for item in div_list:
|
||||||
|
a = item.select_one("div.ty-card-l a")
|
||||||
|
href = a.get("href")
|
||||||
|
# print(a.get('href'),a.get_text().strip())
|
||||||
|
|
||||||
|
resp = requests.get(url=href,headers=headers)
|
||||||
|
resp.encoding = resp.apparent_encoding # requests 会尝试猜测编码
|
||||||
|
soup = BeautifulSoup(resp.text,"lxml")
|
||||||
|
# 获取文章标题
|
||||||
|
article_title_tag = soup.select_one("div.main-content h1.main-title")
|
||||||
|
if article_title_tag:
|
||||||
|
article_title = article_title_tag.get_text(strip=True)
|
||||||
|
if not article_title:
|
||||||
|
article_title = "未知标题"
|
||||||
|
else:
|
||||||
|
article_title = "未知标题"
|
||||||
|
# print("标题:", article_title)
|
||||||
|
# 获取文章发布时间
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
# 日期时间格式化函数
|
||||||
|
def normalize_time(time_str):
|
||||||
|
for fmt in ("%Y年%m月%d日 %H:%M", "%Y-%m-%d %H:%M:%S"):
|
||||||
|
try:
|
||||||
|
dt = datetime.strptime(time_str, fmt)
|
||||||
|
return dt.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
except:
|
||||||
|
continue
|
||||||
|
return time_str # 如果都不匹配,返回原字符串
|
||||||
|
|
||||||
|
time_tag = soup.select_one("div.main-content div.top-bar-wrap div.date-source span.date")
|
||||||
|
if time_tag: # 只有存在时间标签才进行格式化
|
||||||
|
publish_time = normalize_time(time_tag.get_text(strip=True))
|
||||||
|
else:
|
||||||
|
publish_time = "1949-01-01 12:00:00"
|
||||||
|
#print(publish_time)
|
||||||
|
|
||||||
|
# 获取文章作者
|
||||||
|
author_tag = soup.select_one("div.main-content div.top-bar-wrap div.date-source a")
|
||||||
|
if author_tag:
|
||||||
|
author = author_tag.get_text(strip=True)
|
||||||
|
else:
|
||||||
|
author = "未知"
|
||||||
|
# print(author)
|
||||||
|
# 获取文章正文段落
|
||||||
|
article_div = soup.select_one("div.main-content div.article") # 核心文章容器
|
||||||
|
if not article_div:
|
||||||
|
# print("不是文章详情页,跳过")
|
||||||
|
continue # 如果不是详情页就跳过
|
||||||
|
paragraphs = article_div.find_all('p')
|
||||||
|
article_text = '\n'.join(p.get_text(strip=True) for p in paragraphs if p.get_text(strip=True))
|
||||||
|
# print("正文:\n", article_text)
|
||||||
|
|
||||||
|
|
||||||
|
```
|
||||||
|
|
@ -102,7 +102,6 @@ class BaseCrawler(ABC):
|
||||||
try:
|
try:
|
||||||
# 获取页面HTML
|
# 获取页面HTML
|
||||||
html = self._fetch_page()
|
html = self._fetch_page()
|
||||||
|
|
||||||
# 解析文章列表
|
# 解析文章列表
|
||||||
article_urls = self._extract_article_urls(html)
|
article_urls = self._extract_article_urls(html)
|
||||||
self.logger.info(f"找到 {len(article_urls)} 篇文章")
|
self.logger.info(f"找到 {len(article_urls)} 篇文章")
|
||||||
|
|
|
||||||
|
|
@ -31,6 +31,9 @@ CRAWLER_CLASSES = {
|
||||||
'kr36': {
|
'kr36': {
|
||||||
'ai': ('crawlers.kr36.ai', 'AICrawler'),
|
'ai': ('crawlers.kr36.ai', 'AICrawler'),
|
||||||
},
|
},
|
||||||
|
'sina': {
|
||||||
|
'auto': ('crawlers.sina.auto', 'SinaAutoCrawler'),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -53,6 +56,11 @@ def list_crawlers() -> List[str]:
|
||||||
for category in kr36_categories.keys():
|
for category in kr36_categories.keys():
|
||||||
crawlers.append(f"kr36:{category}")
|
crawlers.append(f"kr36:{category}")
|
||||||
|
|
||||||
|
# 新浪爬虫
|
||||||
|
sina_categories = config.get('sources.sina.categories', {})
|
||||||
|
for category in sina_categories.keys():
|
||||||
|
crawlers.append(f"sina:{category}")
|
||||||
|
|
||||||
return crawlers
|
return crawlers
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -101,7 +109,7 @@ def run_crawler(source: str, category: str, max_articles: int = None) -> bool:
|
||||||
|
|
||||||
# 创建并运行爬虫
|
# 创建并运行爬虫
|
||||||
crawler = crawler_class(source, category)
|
crawler = crawler_class(source, category)
|
||||||
|
print("创建并运行爬虫")
|
||||||
# 覆盖最大文章数
|
# 覆盖最大文章数
|
||||||
if max_articles:
|
if max_articles:
|
||||||
crawler.max_articles = max_articles
|
crawler.max_articles = max_articles
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,60 @@
|
||||||
|
"""
|
||||||
|
新浪汽车新闻爬虫
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import List
|
||||||
|
from bs4 import BeautifulSoup
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||||
|
|
||||||
|
from base.crawler_base import DynamicCrawler, Article
|
||||||
|
from parsers.sina_parser import SinaAutoParser
|
||||||
|
|
||||||
|
|
||||||
|
class SinaAutoCrawler(DynamicCrawler):
|
||||||
|
"""新浪汽车新闻爬虫"""
|
||||||
|
|
||||||
|
def _extract_article_urls(self, html: str) -> List[str]:
|
||||||
|
"""从HTML中提取文章URL列表"""
|
||||||
|
soup = BeautifulSoup(html, "lxml")
|
||||||
|
urls = []
|
||||||
|
|
||||||
|
# 尝试不同的选择器
|
||||||
|
div_list = soup.select("div.cardlist-a__list div.ty-card.ty-card-type1")
|
||||||
|
if not div_list:
|
||||||
|
div_list = soup.select("div.news-list li.news-item")
|
||||||
|
if not div_list:
|
||||||
|
div_list = soup.select("div.feed_card.ty-feed-card-container div.cardlist-a__list div.ty-card.ty-card-type1")
|
||||||
|
|
||||||
|
for item in div_list:
|
||||||
|
a = item.select_one("a")
|
||||||
|
if a and a.get("href"):
|
||||||
|
urls.append(a.get("href"))
|
||||||
|
|
||||||
|
return urls
|
||||||
|
|
||||||
|
def _fetch_articles(self, urls: List[str]) -> List[Article]:
|
||||||
|
"""爬取文章详情"""
|
||||||
|
articles = []
|
||||||
|
parser = SinaAutoParser()
|
||||||
|
|
||||||
|
for i, url in enumerate(urls[:self.max_articles]):
|
||||||
|
try:
|
||||||
|
article = parser.parse(url)
|
||||||
|
article.category_id = self.category_id
|
||||||
|
article.source = "新浪"
|
||||||
|
|
||||||
|
if not article.author:
|
||||||
|
article.author = "新浪汽车"
|
||||||
|
|
||||||
|
if article.is_valid():
|
||||||
|
articles.append(article)
|
||||||
|
self.logger.info(f"[{i+1}/{len(urls)}] {article.title}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"解析文章失败: {url} - {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
return articles
|
||||||
|
|
@ -54,6 +54,7 @@ class NewsRepository:
|
||||||
try:
|
try:
|
||||||
with db_pool.get_connection() as conn:
|
with db_pool.get_connection() as conn:
|
||||||
cursor = conn.cursor()
|
cursor = conn.cursor()
|
||||||
|
|
||||||
# 批量查询已存在的URL
|
# 批量查询已存在的URL
|
||||||
if urls:
|
if urls:
|
||||||
placeholders = ','.join(['%s'] * len(urls))
|
placeholders = ','.join(['%s'] * len(urls))
|
||||||
|
|
@ -61,24 +62,25 @@ class NewsRepository:
|
||||||
cursor.execute(check_sql, urls)
|
cursor.execute(check_sql, urls)
|
||||||
existing_urls = {row[0] for row in cursor.fetchall()}
|
existing_urls = {row[0] for row in cursor.fetchall()}
|
||||||
|
|
||||||
# 只插入不存在的记录
|
# 只插入URL不存在的记录
|
||||||
new_data = [item for item in data if item[0] not in existing_urls]
|
new_data = [item for item in data if item[0] not in existing_urls]
|
||||||
|
|
||||||
if not new_data:
|
if not new_data:
|
||||||
self.logger.info(f"所有 {len(data)} 条新闻已存在,跳过插入")
|
self.logger.info(f"所有 {len(data)} 条新闻已存在,跳过插入")
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
# 执行插入
|
# 执行插入,使用 INSERT IGNORE 忽略 content_hash 重复的记录
|
||||||
sql = """
|
sql = """
|
||||||
INSERT INTO news (url, title, category_id, publish_time, author, source, content, content_hash)
|
INSERT IGNORE INTO news (url, title, category_id, publish_time, author, source, content, content_hash)
|
||||||
VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
|
VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
cursor.executemany(sql, new_data)
|
cursor.executemany(sql, new_data)
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
||||||
inserted = len(new_data)
|
# 获取实际插入的行数
|
||||||
self.logger.info(f"成功插入 {inserted} 条新新闻,{len(data) - inserted} 条已存在")
|
inserted = cursor.rowcount
|
||||||
|
self.logger.info(f"成功插入 {inserted} 条新新闻,{len(new_data) - inserted} 条因内容重复被忽略")
|
||||||
return inserted
|
return inserted
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,71 @@
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||||
|
|
||||||
|
from base.parser_base import BaseParser
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional
|
||||||
|
from base.crawler_base import Article
|
||||||
|
from bs4 import BeautifulSoup
|
||||||
|
from utils.logger import get_logger
|
||||||
|
from utils.http_client import HttpClient
|
||||||
|
|
||||||
|
|
||||||
|
class SinaAutoParser(BaseParser):
|
||||||
|
"""新浪网汽车新闻解析器"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.logger = get_logger(__name__)
|
||||||
|
self.http_client = HttpClient()
|
||||||
|
|
||||||
|
def parse(self, url: str) -> Article:
|
||||||
|
"""
|
||||||
|
解析新浪网文章详情页
|
||||||
|
|
||||||
|
Args:
|
||||||
|
url: 文章URL
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
文章对象
|
||||||
|
"""
|
||||||
|
html = self.http_client.get(url)
|
||||||
|
soup = BeautifulSoup(html, "lxml")
|
||||||
|
|
||||||
|
# 获取文章标题
|
||||||
|
article_title_tag = soup.select_one("div.main-content h1.main-title")
|
||||||
|
article_title = article_title_tag.get_text(strip=True) if article_title_tag else "未知标题"
|
||||||
|
|
||||||
|
# 获取文章发布时间
|
||||||
|
def normalize_time(time_str):
|
||||||
|
for fmt in ("%Y年%m月%d日 %H:%M", "%Y-%m-%d %H:%M:%S"):
|
||||||
|
try:
|
||||||
|
dt = datetime.strptime(time_str, fmt)
|
||||||
|
return dt.strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
except:
|
||||||
|
continue
|
||||||
|
return time_str # 如果都不匹配,返回原字符串
|
||||||
|
|
||||||
|
time_tag = soup.select_one("div.main-content div.top-bar-wrap div.date-source span.date")
|
||||||
|
publish_time = normalize_time(time_tag.get_text(strip=True)) if time_tag else "1949-01-01 12:00:00"
|
||||||
|
|
||||||
|
# 获取文章作者
|
||||||
|
author_tag = soup.select_one("div.main-content div.top-bar-wrap div.date-source a")
|
||||||
|
author = author_tag.get_text(strip=True) if author_tag else "未知"
|
||||||
|
|
||||||
|
# 获取文章正文段落
|
||||||
|
article_div = soup.select_one("div.main-content div.article")
|
||||||
|
if not article_div:
|
||||||
|
raise ValueError("无法找到文章内容")
|
||||||
|
|
||||||
|
paragraphs = article_div.find_all('p')
|
||||||
|
content = '\n'.join(p.get_text(strip=True) for p in paragraphs if p.get_text(strip=True))
|
||||||
|
|
||||||
|
return Article(
|
||||||
|
url=url,
|
||||||
|
title=article_title,
|
||||||
|
publish_time=publish_time,
|
||||||
|
author=author,
|
||||||
|
content=content,
|
||||||
|
category_id=6, # 汽车分类ID
|
||||||
|
source="sina"
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,7 @@
|
||||||
|
{
|
||||||
|
"permissions": {
|
||||||
|
"allow": [
|
||||||
|
"Bash(mkdir:*)"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,197 @@
|
||||||
|
# 新闻文本分类系统 - 机器学习模块
|
||||||
|
|
||||||
|
## 功能特性
|
||||||
|
|
||||||
|
### GPU/CPU自动检测
|
||||||
|
- 自动检测可用GPU(包括8GB显存)
|
||||||
|
- 自动回退到CPU(如果GPU不可用)
|
||||||
|
- 显示设备信息和显存信息
|
||||||
|
|
||||||
|
### 动态参数调整
|
||||||
|
- 通过配置文件调整训练参数
|
||||||
|
- 支持命令行参数覆盖
|
||||||
|
- 混合精度自动检测
|
||||||
|
|
||||||
|
## 使用方法
|
||||||
|
|
||||||
|
### 1. 安装依赖
|
||||||
|
```bash
|
||||||
|
pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
### 从mysql拉去训练数据
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python .\src\utils\data_loader.py
|
||||||
|
```
|
||||||
|
### 2. 配置训练参数
|
||||||
|
编辑 `config.yaml` 文件调整训练参数:
|
||||||
|
```yaml
|
||||||
|
database:
|
||||||
|
url: "mysql+pymysql://root:root@localhost/news_classifier"
|
||||||
|
data_limit: 1000
|
||||||
|
|
||||||
|
model:
|
||||||
|
name: "bert-base-chinese"
|
||||||
|
num_labels: 9
|
||||||
|
output_dir: "./models/deep_learning/bert_finetuned"
|
||||||
|
|
||||||
|
training:
|
||||||
|
use_gpu: true # 自动检测GPU
|
||||||
|
epochs: 3
|
||||||
|
batch_size: 8
|
||||||
|
learning_rate: 2e-5
|
||||||
|
warmup_steps: 500
|
||||||
|
weight_decay: 0.01
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. 训练BERT模型
|
||||||
|
```bash
|
||||||
|
python train_bert.py
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. 使用命令行参数覆盖配置
|
||||||
|
```bash
|
||||||
|
# 使用GPU训练
|
||||||
|
python train_bert.py --use_gpu
|
||||||
|
|
||||||
|
# 调整训练轮数
|
||||||
|
python train_bert.py --epochs 5
|
||||||
|
|
||||||
|
# 调整批大小
|
||||||
|
python train_bert.py --batch_size 16
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5. 启动API服务
|
||||||
|
```bash
|
||||||
|
python src/api/server.py
|
||||||
|
```
|
||||||
|
|
||||||
|
## 参数说明
|
||||||
|
|
||||||
|
### 训练参数
|
||||||
|
- `epochs`: 训练轮数(默认:3)
|
||||||
|
- `batch_size`: 批大小(默认:8)
|
||||||
|
- `learning_rate`: 学习率(默认:2e-5)
|
||||||
|
- `warmup_steps`: 预热步数(默认:500)
|
||||||
|
- `weight_decay`: 权重衰减(默认:0.01)
|
||||||
|
|
||||||
|
### 设备配置
|
||||||
|
- `use_gpu`: 是否使用GPU(自动检测)
|
||||||
|
- `fp16`: 混合精度(自动检测)
|
||||||
|
|
||||||
|
## 8GB显存优化建议
|
||||||
|
|
||||||
|
对于8GB显存的GPU,建议配置:
|
||||||
|
```yaml
|
||||||
|
training:
|
||||||
|
use_gpu: true
|
||||||
|
epochs: 3
|
||||||
|
batch_size: 8-16
|
||||||
|
fp16: true # 启用混合精度
|
||||||
|
```
|
||||||
|
|
||||||
|
## 设备检测
|
||||||
|
训练时会自动检测设备:
|
||||||
|
- **GPU可用**:使用GPU训练,显示GPU名称和显存信息
|
||||||
|
- **GPU不可用**:自动回退到CPU训练
|
||||||
|
|
||||||
|
## 性能优化
|
||||||
|
- 混合精度训练(FP16)
|
||||||
|
- 梯度累积
|
||||||
|
- 自动批大小调整
|
||||||
|
- 内存优化
|
||||||
|
|
||||||
|
## 注意事项
|
||||||
|
1. 确保安装了CUDA和cuDNN(如果使用GPU)
|
||||||
|
2. 8GB显存可以训练BERT-base模型
|
||||||
|
3. 可以通过调整batch_size适应不同显存大小
|
||||||
|
4. 训练时间取决于数据量和硬件配置
|
||||||
|
|
||||||
|
## API接口
|
||||||
|
```bash
|
||||||
|
# 单条预测
|
||||||
|
POST /api/predict
|
||||||
|
{
|
||||||
|
"title": "新闻标题",
|
||||||
|
"content": "新闻内容",
|
||||||
|
"mode": "hybrid" # traditional, hybrid
|
||||||
|
}
|
||||||
|
|
||||||
|
# 批量预测
|
||||||
|
POST /api/batch-predict
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"title": "新闻标题1",
|
||||||
|
"content": "新闻内容1",
|
||||||
|
"mode": "hybrid"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"title": "新闻标题2",
|
||||||
|
"content": "新闻内容2",
|
||||||
|
"mode": "traditional"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## **分组方案**
|
||||||
|
|
||||||
|
### **1️⃣ 基础科学计算 / 机器学习**
|
||||||
|
|
||||||
|
```cmd
|
||||||
|
pip install numpy>=1.24.0 pandas>=2.0.0 scikit-learn>=1.3.0 joblib>=1.3.0
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### **2️⃣ 深度学习 / NLP**
|
||||||
|
|
||||||
|
```cmd
|
||||||
|
pip install torch>=2.0.0 transformers>=4.30.0 jieba>=0.42.0
|
||||||
|
```
|
||||||
|
|
||||||
|
💡 如果你有 **GPU** 并希望安装 GPU 版 PyTorch,需要单独去 PyTorch 官网生成命令。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### **3️⃣ API 服务**
|
||||||
|
|
||||||
|
```cmd
|
||||||
|
pip install fastapi>=0.100.0 "uvicorn[standard]>=0.23.0" pydantic>=2.0.0
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### **4️⃣ 数据库相关**
|
||||||
|
|
||||||
|
```cmd
|
||||||
|
pip install sqlalchemy>=2.0.0 pymysql>=1.1.0
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### **5️⃣ 数据可视化**
|
||||||
|
|
||||||
|
```cmd
|
||||||
|
pip install matplotlib>=3.7.0 seaborn>=0.12.0
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### **6️⃣ 工具 / 配置文件处理**
|
||||||
|
|
||||||
|
```cmd
|
||||||
|
pip install python-dotenv>=1.0.0 pyyaml>=6.0
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
✅ **这样分批安装的好处**:
|
||||||
|
|
||||||
|
1. 出现安装错误时,更容易定位是哪个模块有问题。
|
||||||
|
2. 每组依赖关系相对独立,减少冲突。
|
||||||
|
3. CMD 一次执行一条命令,不需要处理复杂的换行符或引号问题。
|
||||||
|
|
||||||
|
|
@ -0,0 +1,33 @@
|
||||||
|
# BERT模型训练配置
|
||||||
|
database:
|
||||||
|
url: "mysql+pymysql://root:root@localhost/news_classifier" # 数据库连接URL
|
||||||
|
data_limit: 1000 # 加载数据量限制
|
||||||
|
|
||||||
|
model:
|
||||||
|
name: "bert-base-chinese" # 模型名称
|
||||||
|
num_labels: 9 # 分类数量
|
||||||
|
output_dir: "./models/deep_learning/bert_finetuned" # 模型输出目录
|
||||||
|
|
||||||
|
training:
|
||||||
|
use_gpu: true # 是否使用GPU(自动检测)
|
||||||
|
epochs: 3 # 训练轮数
|
||||||
|
batch_size: 8 # 训练批大小
|
||||||
|
learning_rate: 2e-5 # 学习率
|
||||||
|
warmup_steps: 500 # 预热步数
|
||||||
|
weight_decay: 0.01 # 权重衰减
|
||||||
|
fp16: null # 混合精度(null表示自动检测)
|
||||||
|
|
||||||
|
# 日志和输出配置
|
||||||
|
logging:
|
||||||
|
level: "INFO" # 日志级别
|
||||||
|
file: "./training.log" # 日志文件
|
||||||
|
|
||||||
|
# 设备配置
|
||||||
|
device:
|
||||||
|
max_memory: "8GB" # 最大内存限制
|
||||||
|
gradient_checkpointing: true # 梯度检查点
|
||||||
|
|
||||||
|
# 性能优化
|
||||||
|
optimization:
|
||||||
|
gradient_accumulation_steps: 1 # 梯度累积步数
|
||||||
|
mixed_precision: true # 混合精度(如果支持)
|
||||||
|
|
@ -0,0 +1,160 @@
|
||||||
|
的
|
||||||
|
了
|
||||||
|
是
|
||||||
|
在
|
||||||
|
有
|
||||||
|
和
|
||||||
|
与
|
||||||
|
及
|
||||||
|
或
|
||||||
|
而
|
||||||
|
但
|
||||||
|
被
|
||||||
|
把
|
||||||
|
将
|
||||||
|
对
|
||||||
|
于
|
||||||
|
中
|
||||||
|
上
|
||||||
|
下
|
||||||
|
内
|
||||||
|
外
|
||||||
|
等
|
||||||
|
为
|
||||||
|
以
|
||||||
|
从
|
||||||
|
到
|
||||||
|
由
|
||||||
|
就
|
||||||
|
也
|
||||||
|
都
|
||||||
|
还
|
||||||
|
又
|
||||||
|
很
|
||||||
|
并
|
||||||
|
及其
|
||||||
|
以及
|
||||||
|
之一
|
||||||
|
一些
|
||||||
|
一个
|
||||||
|
一种
|
||||||
|
多种
|
||||||
|
各自
|
||||||
|
各个
|
||||||
|
各类
|
||||||
|
记者
|
||||||
|
通讯员
|
||||||
|
编辑
|
||||||
|
报道
|
||||||
|
表示
|
||||||
|
指出
|
||||||
|
认为
|
||||||
|
称
|
||||||
|
说
|
||||||
|
介绍
|
||||||
|
透露
|
||||||
|
强调
|
||||||
|
分析
|
||||||
|
称之为
|
||||||
|
近日
|
||||||
|
今日
|
||||||
|
今天
|
||||||
|
昨日
|
||||||
|
目前
|
||||||
|
当前
|
||||||
|
今年
|
||||||
|
去年
|
||||||
|
明年
|
||||||
|
此前
|
||||||
|
之后
|
||||||
|
当日
|
||||||
|
当天
|
||||||
|
近日来
|
||||||
|
近来
|
||||||
|
近期
|
||||||
|
未来
|
||||||
|
方面
|
||||||
|
情况
|
||||||
|
问题
|
||||||
|
工作
|
||||||
|
任务
|
||||||
|
活动
|
||||||
|
会议
|
||||||
|
会议上
|
||||||
|
会议中
|
||||||
|
会议期间
|
||||||
|
会议指出
|
||||||
|
相关
|
||||||
|
有关
|
||||||
|
一定
|
||||||
|
一些
|
||||||
|
部分
|
||||||
|
整体
|
||||||
|
进一步
|
||||||
|
持续
|
||||||
|
不断
|
||||||
|
不断地
|
||||||
|
继续
|
||||||
|
推进
|
||||||
|
加强
|
||||||
|
提升
|
||||||
|
改善
|
||||||
|
推动
|
||||||
|
加快
|
||||||
|
实现
|
||||||
|
完成
|
||||||
|
开展
|
||||||
|
进行
|
||||||
|
同时
|
||||||
|
此外
|
||||||
|
因此
|
||||||
|
所以
|
||||||
|
但是
|
||||||
|
然而
|
||||||
|
不过
|
||||||
|
如果
|
||||||
|
因为
|
||||||
|
虽然
|
||||||
|
由于
|
||||||
|
其中
|
||||||
|
其中之一
|
||||||
|
对此
|
||||||
|
对此次
|
||||||
|
对此前
|
||||||
|
对此后
|
||||||
|
通过
|
||||||
|
按照
|
||||||
|
根据
|
||||||
|
依据
|
||||||
|
围绕
|
||||||
|
围绕着
|
||||||
|
针对
|
||||||
|
关于
|
||||||
|
对于
|
||||||
|
面对
|
||||||
|
着力
|
||||||
|
积极
|
||||||
|
主动
|
||||||
|
有效
|
||||||
|
充分
|
||||||
|
全面
|
||||||
|
已经
|
||||||
|
正在
|
||||||
|
正在进行
|
||||||
|
正在推进
|
||||||
|
正在开展
|
||||||
|
开始
|
||||||
|
结束
|
||||||
|
完成后
|
||||||
|
之后
|
||||||
|
以来
|
||||||
|
以来的
|
||||||
|
相关人士
|
||||||
|
业内人士
|
||||||
|
专家表示
|
||||||
|
专家认为
|
||||||
|
业内认为
|
||||||
|
市场认为
|
||||||
|
分析人士
|
||||||
|
有关人士
|
||||||
|
知情人士
|
||||||
|
|
@ -0,0 +1,51 @@
|
||||||
|
|
||||||
|
### 新闻分类表
|
||||||
|
```sql
|
||||||
|
CREATE TABLE news_category (
|
||||||
|
id INT NOT NULL AUTO_INCREMENT COMMENT '分类ID',
|
||||||
|
name VARCHAR(50) NOT NULL COMMENT '分类名称',
|
||||||
|
PRIMARY KEY (id),
|
||||||
|
UNIQUE KEY uk_name (name)
|
||||||
|
) ENGINE=InnoDB
|
||||||
|
DEFAULT CHARSET=utf8mb4
|
||||||
|
COLLATE=utf8mb4_0900_ai_ci
|
||||||
|
COMMENT='新闻分类表';
|
||||||
|
```
|
||||||
|
数据:
|
||||||
|
```text
|
||||||
|
1 娱乐
|
||||||
|
2 体育
|
||||||
|
3 财经
|
||||||
|
4 科技
|
||||||
|
5 军事
|
||||||
|
6 汽车
|
||||||
|
7 政务
|
||||||
|
8 健康
|
||||||
|
9 AI
|
||||||
|
```
|
||||||
|
|
||||||
|
### 新闻表
|
||||||
|
```sql
|
||||||
|
CREATE TABLE news (
|
||||||
|
id BIGINT NOT NULL AUTO_INCREMENT COMMENT '自增主键',
|
||||||
|
url VARCHAR(500) NOT NULL COMMENT '新闻原始URL',
|
||||||
|
title VARCHAR(255) NOT NULL COMMENT '新闻标题',
|
||||||
|
category_id INT NULL COMMENT '新闻分类ID',
|
||||||
|
publish_time DATETIME NULL COMMENT '发布时间',
|
||||||
|
author VARCHAR(100) NULL COMMENT '作者/来源',
|
||||||
|
source VARCHAR(50) NULL COMMENT '新闻来源(网易/36kr)',
|
||||||
|
content LONGTEXT NOT NULL COMMENT '新闻正文',
|
||||||
|
content_hash CHAR(64) NOT NULL COMMENT '正文内容hash,用于去重',
|
||||||
|
created_at TIMESTAMP NULL DEFAULT CURRENT_TIMESTAMP COMMENT '入库时间',
|
||||||
|
|
||||||
|
PRIMARY KEY (id),
|
||||||
|
UNIQUE KEY uk_url (url),
|
||||||
|
UNIQUE KEY uk_content_hash (content_hash),
|
||||||
|
KEY idx_category_id (category_id),
|
||||||
|
KEY idx_source (source)
|
||||||
|
|
||||||
|
) ENGINE=InnoDB
|
||||||
|
DEFAULT CHARSET=utf8mb4
|
||||||
|
COLLATE=utf8mb4_0900_ai_ci
|
||||||
|
COMMENT='新闻表';
|
||||||
|
```
|
||||||
Binary file not shown.
Binary file not shown.
|
|
@ -1,11 +1,30 @@
|
||||||
# 机器学习模块依赖
|
# 机器学习模块依赖
|
||||||
scikit-learn==1.4.0
|
numpy>=1.24.0
|
||||||
numpy==1.26.3
|
pandas>=2.0.0
|
||||||
pandas==2.1.4
|
scikit-learn>=1.3.0
|
||||||
jieba==0.42.1
|
jieba>=0.42.0
|
||||||
scipy==1.12.0
|
joblib>=1.3.0
|
||||||
joblib==1.3.2
|
|
||||||
|
|
||||||
# 深度学习 (可选)
|
# 深度学习
|
||||||
# torch==2.1.2
|
torch>=2.0.0
|
||||||
# transformers==4.37.2
|
transformers>=4.30.0
|
||||||
|
|
||||||
|
# API服务
|
||||||
|
fastapi>=0.100.0
|
||||||
|
uvicorn[standard]>=0.23.0
|
||||||
|
pydantic>=2.0.0
|
||||||
|
|
||||||
|
# 数据库
|
||||||
|
sqlalchemy>=2.0.0
|
||||||
|
pymysql>=1.1.0
|
||||||
|
|
||||||
|
# 数据可视化
|
||||||
|
matplotlib>=3.7.0
|
||||||
|
seaborn>=0.12.0
|
||||||
|
|
||||||
|
# 工具
|
||||||
|
python-dotenv>=1.0.0
|
||||||
|
pyyaml>=6.0
|
||||||
|
|
||||||
|
# 中文处理
|
||||||
|
jieba>=0.42.0
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,157 @@
|
||||||
|
"""
|
||||||
|
机器学习模型API服务
|
||||||
|
使用FastAPI提供RESTful API
|
||||||
|
"""
|
||||||
|
|
||||||
|
from fastapi import FastAPI, HTTPException
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import Optional, List
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
# 导入分类器
|
||||||
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
|
from traditional.predict import TraditionalPredictor
|
||||||
|
from hybrid.hybrid_classifier import HybridClassifier
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 创建FastAPI应用
|
||||||
|
app = FastAPI(
|
||||||
|
title="新闻分类API",
|
||||||
|
description="提供新闻文本分类服务",
|
||||||
|
version="1.0.0"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 配置CORS
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"],
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# 请求模型
|
||||||
|
class ClassifyRequest(BaseModel):
|
||||||
|
title: str
|
||||||
|
content: str
|
||||||
|
mode: Optional[str] = 'hybrid' # traditional, hybrid
|
||||||
|
|
||||||
|
# 响应模型
|
||||||
|
class ClassifyResponse(BaseModel):
|
||||||
|
categoryCode: str
|
||||||
|
categoryName: str
|
||||||
|
confidence: float
|
||||||
|
classifierType: str
|
||||||
|
duration: int
|
||||||
|
probabilities: Optional[dict] = None
|
||||||
|
|
||||||
|
|
||||||
|
# 初始化分类器
|
||||||
|
nb_predictor = None
|
||||||
|
hybrid_classifier = None
|
||||||
|
|
||||||
|
|
||||||
|
@app.on_event("startup")
|
||||||
|
async def startup_event():
|
||||||
|
"""启动时加载模型"""
|
||||||
|
global nb_predictor, hybrid_classifier
|
||||||
|
|
||||||
|
logger.info("加载模型...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
nb_predictor = TraditionalPredictor('nb')
|
||||||
|
logger.info("朴素贝叶斯模型加载成功")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"朴素贝叶斯模型加载失败: {e}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
hybrid_classifier = HybridClassifier()
|
||||||
|
logger.info("混合分类器初始化成功")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"混合分类器初始化失败: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/")
|
||||||
|
async def root():
|
||||||
|
"""健康检查"""
|
||||||
|
return {
|
||||||
|
"status": "ok",
|
||||||
|
"message": "新闻分类API服务运行中"
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
async def health_check():
|
||||||
|
"""健康检查"""
|
||||||
|
return {
|
||||||
|
"status": "healthy",
|
||||||
|
"models": {
|
||||||
|
"nb_loaded": nb_predictor is not None,
|
||||||
|
"hybrid_loaded": hybrid_classifier is not None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/api/predict", response_model=ClassifyResponse)
|
||||||
|
async def predict(request: ClassifyRequest):
|
||||||
|
"""
|
||||||
|
文本分类接口
|
||||||
|
|
||||||
|
- **title**: 新闻标题
|
||||||
|
- **content**: 新闻内容
|
||||||
|
- **mode**: 分类模式 (traditional, hybrid)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if request.mode == 'traditional':
|
||||||
|
result = nb_predictor.predict(request.title, request.content)
|
||||||
|
result['classifierType'] = 'ML'
|
||||||
|
else: # hybrid
|
||||||
|
result = hybrid_classifier.predict(request.title, request.content)
|
||||||
|
|
||||||
|
return ClassifyResponse(**result)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"预测失败: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/api/batch-predict")
|
||||||
|
async def batch_predict(requests: List[ClassifyRequest]):
|
||||||
|
"""
|
||||||
|
批量分类接口
|
||||||
|
"""
|
||||||
|
results = []
|
||||||
|
for req in requests:
|
||||||
|
try:
|
||||||
|
if req.mode == 'traditional':
|
||||||
|
result = nb_predictor.predict(req.title, req.content)
|
||||||
|
result['classifierType'] = 'ML'
|
||||||
|
else:
|
||||||
|
result = hybrid_classifier.predict(req.title, req.content)
|
||||||
|
results.append(result)
|
||||||
|
except Exception as e:
|
||||||
|
results.append({
|
||||||
|
'error': str(e),
|
||||||
|
'title': req.title
|
||||||
|
})
|
||||||
|
|
||||||
|
return {"results": results}
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
uvicorn.run(
|
||||||
|
app,
|
||||||
|
host="0.0.0.0",
|
||||||
|
port=5000,
|
||||||
|
log_level="info"
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,209 @@
|
||||||
|
"""
|
||||||
|
BERT文本分类模型
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers import (
|
||||||
|
BertTokenizer,
|
||||||
|
BertForSequenceClassification,
|
||||||
|
Trainer,
|
||||||
|
TrainingArguments
|
||||||
|
)
|
||||||
|
from typing import Dict, Any, List
|
||||||
|
|
||||||
|
# 分类映射
|
||||||
|
CATEGORY_MAP = {
|
||||||
|
'ENTERTAINMENT': '娱乐',
|
||||||
|
'SPORTS': '体育',
|
||||||
|
'FINANCE': '财经',
|
||||||
|
'TECHNOLOGY': '科技',
|
||||||
|
'MILITARY': '军事',
|
||||||
|
'AUTOMOTIVE': '汽车',
|
||||||
|
'GOVERNMENT': '政务',
|
||||||
|
'HEALTH': '健康',
|
||||||
|
'AI': 'AI'
|
||||||
|
}
|
||||||
|
|
||||||
|
# 反向映射
|
||||||
|
ID_TO_LABEL = {i: label for i, label in enumerate(CATEGORY_MAP.keys())}
|
||||||
|
LABEL_TO_ID = {label: i for i, label in enumerate(CATEGORY_MAP.keys())}
|
||||||
|
|
||||||
|
|
||||||
|
class BertClassifier:
|
||||||
|
"""BERT文本分类器 - 支持GPU/CPU自动检测和动态参数"""
|
||||||
|
|
||||||
|
def __init__(self, model_name='bert-base-chinese', num_labels=9, use_gpu=True):
|
||||||
|
self.model_name = model_name
|
||||||
|
self.num_labels = num_labels
|
||||||
|
self.tokenizer = None
|
||||||
|
self.model = None
|
||||||
|
|
||||||
|
# GPU/CPU自动检测
|
||||||
|
self.use_gpu = use_gpu and torch.cuda.is_available()
|
||||||
|
self.device = torch.device('cuda' if self.use_gpu else 'cpu')
|
||||||
|
|
||||||
|
# 打印设备信息
|
||||||
|
if self.use_gpu:
|
||||||
|
print(f"使用GPU训练: {torch.cuda.get_device_name(0)}")
|
||||||
|
print(f"GPU显存: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
|
||||||
|
else:
|
||||||
|
print("使用CPU训练")
|
||||||
|
|
||||||
|
def load_model(self, model_path):
|
||||||
|
"""加载微调后的模型"""
|
||||||
|
self.tokenizer = BertTokenizer.from_pretrained(model_path)
|
||||||
|
self.model = BertForSequenceClassification.from_pretrained(
|
||||||
|
model_path,
|
||||||
|
num_labels=self.num_labels
|
||||||
|
)
|
||||||
|
self.model.to(self.device)
|
||||||
|
self.model.eval()
|
||||||
|
print(f"BERT模型加载成功: {model_path}")
|
||||||
|
|
||||||
|
def train_model(self, train_dataset, eval_dataset, output_dir='./models/deep_learning/bert_finetuned',
|
||||||
|
num_train_epochs=3, per_device_train_batch_size=8, per_device_eval_batch_size=16,
|
||||||
|
learning_rate=2e-5, warmup_steps=500, weight_decay=0.01, fp16=None):
|
||||||
|
"""
|
||||||
|
训练BERT模型
|
||||||
|
|
||||||
|
:param train_dataset: 训练数据集
|
||||||
|
:param eval_dataset: 验证数据集
|
||||||
|
:param output_dir: 模型输出目录
|
||||||
|
:param num_train_epochs: 训练轮数
|
||||||
|
:param per_device_train_batch_size: 每个设备的训练batch大小
|
||||||
|
:param per_device_eval_batch_size: 每个设备的验证batch大小
|
||||||
|
:param learning_rate: 学习率
|
||||||
|
:param warmup_steps: 预热步数
|
||||||
|
:param weight_decay: 权重衰减
|
||||||
|
:param fp16: 是否使用混合精度(None表示自动检测)
|
||||||
|
"""
|
||||||
|
# 自动检测是否使用混合精度
|
||||||
|
if fp16 is None:
|
||||||
|
fp16 = self.use_gpu
|
||||||
|
|
||||||
|
# 配置训练参数
|
||||||
|
training_args = TrainingArguments(
|
||||||
|
output_dir=output_dir,
|
||||||
|
num_train_epochs=num_train_epochs,
|
||||||
|
per_device_train_batch_size=per_device_train_batch_size,
|
||||||
|
per_device_eval_batch_size=per_device_eval_batch_size,
|
||||||
|
learning_rate=learning_rate,
|
||||||
|
warmup_steps=warmup_steps,
|
||||||
|
weight_decay=weight_decay,
|
||||||
|
logging_dir='./logs',
|
||||||
|
logging_steps=10,
|
||||||
|
evaluation_strategy="epoch",
|
||||||
|
save_strategy="epoch",
|
||||||
|
load_best_model_at_end=True,
|
||||||
|
metric_for_best_model="accuracy",
|
||||||
|
fp16=fp16, # 混合精度
|
||||||
|
gradient_accumulation_steps=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建训练器
|
||||||
|
trainer = Trainer(
|
||||||
|
model=self.model,
|
||||||
|
args=training_args,
|
||||||
|
train_dataset=train_dataset,
|
||||||
|
eval_dataset=eval_dataset
|
||||||
|
)
|
||||||
|
|
||||||
|
# 开始训练
|
||||||
|
print("开始训练BERT模型...")
|
||||||
|
trainer.train()
|
||||||
|
print("训练完成!")
|
||||||
|
|
||||||
|
# 保存模型
|
||||||
|
trainer.save_model(output_dir)
|
||||||
|
self.tokenizer.save_pretrained(output_dir)
|
||||||
|
print(f"模型已保存到: {output_dir}")
|
||||||
|
|
||||||
|
def predict(self, title: str, content: str) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
预测
|
||||||
|
"""
|
||||||
|
if self.model is None or self.tokenizer is None:
|
||||||
|
raise ValueError("模型未加载,请先调用load_model")
|
||||||
|
|
||||||
|
# 组合标题和内容
|
||||||
|
text = f"{title} [SEP] {content}"
|
||||||
|
|
||||||
|
# 分词
|
||||||
|
inputs = self.tokenizer(
|
||||||
|
text,
|
||||||
|
return_tensors='pt',
|
||||||
|
truncation=True,
|
||||||
|
max_length=512,
|
||||||
|
padding='max_length'
|
||||||
|
)
|
||||||
|
|
||||||
|
# 预测
|
||||||
|
with torch.no_grad():
|
||||||
|
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
||||||
|
outputs = self.model(**inputs)
|
||||||
|
logits = outputs.logits
|
||||||
|
|
||||||
|
# 获取预测结果
|
||||||
|
probs = torch.softmax(logits, dim=-1)
|
||||||
|
confidence, predicted_id = torch.max(probs, dim=-1)
|
||||||
|
|
||||||
|
predicted_id = predicted_id.item()
|
||||||
|
confidence = confidence.item()
|
||||||
|
|
||||||
|
# 获取各类别概率
|
||||||
|
prob_dict = {}
|
||||||
|
for i, prob in enumerate(probs[0].cpu().numpy()):
|
||||||
|
category_code = ID_TO_LABEL[i]
|
||||||
|
prob_dict[category_code] = float(prob)
|
||||||
|
|
||||||
|
return {
|
||||||
|
'categoryCode': ID_TO_LABEL[predicted_id],
|
||||||
|
'categoryName': CATEGORY_MAP.get(ID_TO_LABEL[predicted_id], '未知'),
|
||||||
|
'confidence': confidence,
|
||||||
|
'probabilities': prob_dict
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# 数据集类
|
||||||
|
class NewsDataset(torch.utils.data.Dataset):
|
||||||
|
"""新闻数据集"""
|
||||||
|
|
||||||
|
def __init__(self, texts, labels, tokenizer, max_length=512):
|
||||||
|
self.texts = texts
|
||||||
|
self.labels = labels
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.max_length = max_length
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.texts)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
text = self.texts[idx]
|
||||||
|
label = self.labels[idx]
|
||||||
|
|
||||||
|
encoding = self.tokenizer(
|
||||||
|
text,
|
||||||
|
truncation=True,
|
||||||
|
max_length=self.max_length,
|
||||||
|
padding='max_length',
|
||||||
|
return_tensors='pt'
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
'input_ids': encoding['input_ids'].flatten(),
|
||||||
|
'attention_mask': encoding['attention_mask'].flatten(),
|
||||||
|
'labels': torch.tensor(label, dtype=torch.long)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# 测试
|
||||||
|
classifier = BertClassifier()
|
||||||
|
# classifier.load_model('./models/deep_learning/bert_finetuned')
|
||||||
|
#
|
||||||
|
# result = classifier.predict(
|
||||||
|
# title="华为发布新款折叠屏手机",
|
||||||
|
# content="华为今天正式发布了新一代折叠屏手机..."
|
||||||
|
# )
|
||||||
|
# print(result)
|
||||||
|
print("BERT分类器初始化成功")
|
||||||
|
|
@ -0,0 +1,148 @@
|
||||||
|
"""
|
||||||
|
混合策略分类器
|
||||||
|
结合规则引擎和机器学习模型
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from typing import Dict, Any
|
||||||
|
from ..traditional.predict import TraditionalPredictor
|
||||||
|
from ..deep_learning.bert_model import BertClassifier
|
||||||
|
|
||||||
|
|
||||||
|
class HybridClassifier:
|
||||||
|
"""混合分类器"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
# 初始化各个分类器
|
||||||
|
self.nb_predictor = TraditionalPredictor('nb')
|
||||||
|
self.bert_classifier = BertClassifier()
|
||||||
|
|
||||||
|
# 配置参数
|
||||||
|
self.config = {
|
||||||
|
'confidence_threshold': 0.75, # 高置信度阈值
|
||||||
|
'hybrid_min_confidence': 0.60, # 混合模式最低阈值
|
||||||
|
'use_bert_threshold': 0.70, # 使用BERT的阈值
|
||||||
|
'rule_priority': True # 规则优先
|
||||||
|
}
|
||||||
|
|
||||||
|
# 规则关键词字典
|
||||||
|
self.rule_keywords = {
|
||||||
|
'ENTERTAINMENT': ['明星', '电影', '电视剧', '娱乐圈', '歌手', '娱乐'],
|
||||||
|
'SPORTS': ['比赛', '冠军', '联赛', '球员', '教练', 'NBA', '足球', '篮球'],
|
||||||
|
'FINANCE': ['股市', '经济', '金融', '投资', '基金', '银行', '理财'],
|
||||||
|
'TECHNOLOGY': ['芯片', 'AI', '人工智能', '5G', '互联网', '科技', '数码'],
|
||||||
|
'MILITARY': ['军事', '武器', '军队', '国防', '战争', '军事'],
|
||||||
|
'AUTOMOTIVE': ['汽车', '车', '车型', '驾驶', '交通', ' automotive'],
|
||||||
|
'GOVERNMENT': ['政府', '政策', '选举', '国务院', '主席', '总理', '政务'],
|
||||||
|
'HEALTH': ['健康', '医疗', '疾病', '治疗', '疫苗', '健康'],
|
||||||
|
'AI': ['AI', '人工智能', '机器学习', '深度学习', '算法', 'AI']
|
||||||
|
}
|
||||||
|
|
||||||
|
def rule_match(self, title: str, content: str) -> tuple[str | None, float]:
|
||||||
|
"""
|
||||||
|
规则匹配
|
||||||
|
:return: (category_code, confidence)
|
||||||
|
"""
|
||||||
|
text = title + ' ' + content
|
||||||
|
|
||||||
|
# 计算每个类别的关键词匹配数
|
||||||
|
matches = {}
|
||||||
|
for category, keywords in self.rule_keywords.items():
|
||||||
|
count = sum(1 for kw in keywords if kw in text)
|
||||||
|
if count > 0:
|
||||||
|
matches[category] = count
|
||||||
|
|
||||||
|
if not matches:
|
||||||
|
return None, 0.0
|
||||||
|
|
||||||
|
# 返回匹配最多的类别
|
||||||
|
best_category = max(matches, key=matches.get)
|
||||||
|
confidence = min(0.9, matches[best_category] * 0.15) # 规则置信度
|
||||||
|
|
||||||
|
return best_category, confidence
|
||||||
|
|
||||||
|
def predict(self, title: str, content: str, use_bert=True) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
混合预测
|
||||||
|
"""
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# 1. 先尝试规则匹配
|
||||||
|
rule_category, rule_confidence = self.rule_match(title, content)
|
||||||
|
|
||||||
|
# 2. 传统机器学习预测
|
||||||
|
nb_result = self.nb_predictor.predict(title, content)
|
||||||
|
nb_confidence = nb_result['confidence']
|
||||||
|
|
||||||
|
# 决策逻辑
|
||||||
|
final_result = None
|
||||||
|
classifier_type = 'HYBRID'
|
||||||
|
|
||||||
|
# 规则优先且规则置信度高
|
||||||
|
if self.config['rule_priority'] and rule_confidence >= self.config['confidence_threshold']:
|
||||||
|
final_result = {
|
||||||
|
'categoryCode': rule_category,
|
||||||
|
'categoryName': nb_result['categoryName'], # 从映射获取
|
||||||
|
'confidence': rule_confidence,
|
||||||
|
'classifierType': 'RULE',
|
||||||
|
'reason': '规则匹配'
|
||||||
|
}
|
||||||
|
# 传统模型置信度足够高
|
||||||
|
elif nb_confidence >= self.config['confidence_threshold']:
|
||||||
|
final_result = {
|
||||||
|
**nb_result,
|
||||||
|
'classifierType': 'ML',
|
||||||
|
'reason': '传统模型高置信度'
|
||||||
|
}
|
||||||
|
# 需要使用BERT
|
||||||
|
elif use_bert:
|
||||||
|
# TODO: 加载BERT模型预测
|
||||||
|
# bert_result = self.bert_classifier.predict(title, content)
|
||||||
|
# 如果BERT置信度也不高,选择最高的
|
||||||
|
final_result = {
|
||||||
|
**nb_result,
|
||||||
|
'classifierType': 'HYBRID',
|
||||||
|
'reason': '混合决策'
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# 不使用BERT,直接返回传统模型结果
|
||||||
|
final_result = {
|
||||||
|
**nb_result,
|
||||||
|
'classifierType': 'ML',
|
||||||
|
'reason': '默认传统模型'
|
||||||
|
}
|
||||||
|
|
||||||
|
# 计算耗时
|
||||||
|
duration = int((time.time() - start_time) * 1000)
|
||||||
|
final_result['duration'] = duration
|
||||||
|
|
||||||
|
return final_result
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# 测试
|
||||||
|
classifier = HybridClassifier()
|
||||||
|
|
||||||
|
test_cases = [
|
||||||
|
{
|
||||||
|
'title': '国务院发布最新经济政策',
|
||||||
|
'content': '国务院今天发布了新的经济政策...'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'title': '华为发布新款折叠屏手机',
|
||||||
|
'content': '华为今天正式发布了新一代折叠屏手机...'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'title': 'NBA总决赛精彩瞬间',
|
||||||
|
'content': '湖人队与勇士队的总决赛...'
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
for case in test_cases:
|
||||||
|
result = classifier.predict(case['title'], case['content'])
|
||||||
|
print(f"标题: {case['title']}")
|
||||||
|
print(f"结果: {result['categoryName']} ({result['confidence']:.2f})")
|
||||||
|
print(f"分类器: {result['classifierType']}")
|
||||||
|
print(f"原因: {result.get('reason', 'N/A')}")
|
||||||
|
print(f"耗时: {result['duration']}ms")
|
||||||
|
print("-" * 50)
|
||||||
|
|
@ -0,0 +1,94 @@
|
||||||
|
"""
|
||||||
|
传统机器学习模型预测
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import joblib
|
||||||
|
import jieba
|
||||||
|
from typing import Dict, Any
|
||||||
|
|
||||||
|
# 分类映射(根据数据库中的分类)
|
||||||
|
CATEGORY_MAP = {
|
||||||
|
'ENTERTAINMENT': '娱乐',
|
||||||
|
'SPORTS': '体育',
|
||||||
|
'FINANCE': '财经',
|
||||||
|
'TECHNOLOGY': '科技',
|
||||||
|
'MILITARY': '军事',
|
||||||
|
'AUTOMOTIVE': '汽车',
|
||||||
|
'GOVERNMENT': '政务',
|
||||||
|
'HEALTH': '健康',
|
||||||
|
'AI': 'AI'
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TraditionalPredictor:
|
||||||
|
"""传统机器学习预测器"""
|
||||||
|
|
||||||
|
def __init__(self, model_type='nb', model_dir='../../models/traditional'):
|
||||||
|
self.model_type = model_type
|
||||||
|
self.model_dir = model_dir
|
||||||
|
self.vectorizer = None
|
||||||
|
self.classifier = None
|
||||||
|
self._load_model()
|
||||||
|
|
||||||
|
def _load_model(self):
|
||||||
|
"""加载模型"""
|
||||||
|
vectorizer_path = os.path.join(self.model_dir, f'{self.model_type}_vectorizer.pkl')
|
||||||
|
classifier_path = os.path.join(self.model_dir, f'{self.model_type}_classifier.pkl')
|
||||||
|
|
||||||
|
self.vectorizer = joblib.load(vectorizer_path)
|
||||||
|
self.classifier = joblib.load(classifier_path)
|
||||||
|
print(f"模型加载成功: {self.model_type}")
|
||||||
|
|
||||||
|
def preprocess(self, title: str, content: str) -> str:
|
||||||
|
"""预处理文本"""
|
||||||
|
text = title + ' ' + content
|
||||||
|
# jieba分词
|
||||||
|
words = jieba.cut(text)
|
||||||
|
return ' '.join(words)
|
||||||
|
|
||||||
|
def predict(self, title: str, content: str) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
预测
|
||||||
|
:return: 预测结果字典
|
||||||
|
"""
|
||||||
|
# 预处理
|
||||||
|
processed = self.preprocess(title, content)
|
||||||
|
|
||||||
|
# 特征提取
|
||||||
|
tfidf = self.vectorizer.transform([processed])
|
||||||
|
|
||||||
|
# 预测
|
||||||
|
prediction = self.classifier.predict(tfidf)[0]
|
||||||
|
probabilities = self.classifier.predict_proba(tfidf)[0]
|
||||||
|
|
||||||
|
# 获取各类别概率
|
||||||
|
prob_dict = {}
|
||||||
|
for i, prob in enumerate(probabilities):
|
||||||
|
category_code = self.classifier.classes_[i]
|
||||||
|
prob_dict[category_code] = float(prob)
|
||||||
|
|
||||||
|
return {
|
||||||
|
'categoryCode': prediction,
|
||||||
|
'categoryName': CATEGORY_MAP.get(prediction, '未知'),
|
||||||
|
'confidence': float(probabilities.max()),
|
||||||
|
'probabilities': prob_dict
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# API入口
|
||||||
|
def predict_single(title: str, content: str, model_type='nb') -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
单条预测API
|
||||||
|
"""
|
||||||
|
predictor = TraditionalPredictor(model_type)
|
||||||
|
return predictor.predict(title, content)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# 测试
|
||||||
|
result = predict_single(
|
||||||
|
title="华为发布新款折叠屏手机",
|
||||||
|
content="华为今天正式发布了新一代折叠屏手机,搭载最新麒麟芯片..."
|
||||||
|
)
|
||||||
|
print(result)
|
||||||
|
|
@ -14,45 +14,95 @@ from sklearn.svm import SVC
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
from sklearn.metrics import classification_report, accuracy_score, f1_score
|
from sklearn.metrics import classification_report, accuracy_score, f1_score
|
||||||
|
|
||||||
# 分类映射
|
# 分类映射(与数据库表一致)
|
||||||
CATEGORY_MAP = {
|
CATEGORY_MAP = {
|
||||||
'POLITICS': '时政',
|
'ENTERTAINMENT': '娱乐',
|
||||||
|
'SPORTS': '体育',
|
||||||
'FINANCE': '财经',
|
'FINANCE': '财经',
|
||||||
'TECHNOLOGY': '科技',
|
'TECHNOLOGY': '科技',
|
||||||
'SPORTS': '体育',
|
'MILITARY': '军事',
|
||||||
'ENTERTAINMENT': '娱乐',
|
'AUTOMOTIVE': '汽车',
|
||||||
|
'GOVERNMENT': '政务',
|
||||||
'HEALTH': '健康',
|
'HEALTH': '健康',
|
||||||
'EDUCATION': '教育',
|
'AI': 'AI'
|
||||||
'LIFE': '生活',
|
|
||||||
'INTERNATIONAL': '国际',
|
|
||||||
'MILITARY': '军事'
|
|
||||||
}
|
}
|
||||||
|
|
||||||
REVERSE_CATEGORY_MAP = {v: k for k, v in CATEGORY_MAP.items()}
|
REVERSE_CATEGORY_MAP = {v: k for k, v in CATEGORY_MAP.items()}
|
||||||
|
|
||||||
|
|
||||||
|
# 分类专属停用词(与数据库表一致)
|
||||||
|
CATEGORY_STOPWORDS = {
|
||||||
|
'ENTERTAINMENT': {'主演', '影片', '电影', '电视剧', '节目', '导演', '角色', '上映', '粉丝'},
|
||||||
|
'SPORTS': {'比赛', '赛事', '赛季', '球队', '选手', '球员', '主场', '客场', '对阵', '比分'},
|
||||||
|
'FINANCE': {'亿元', '万元', '同比', '环比', '增长率', '数据', '报告', '统计', '财报', '季度', '年度'},
|
||||||
|
'TECHNOLOGY': {'技术', '系统', '平台', '方案', '应用', '功能', '版本', '升级', '研发', '推出'},
|
||||||
|
'MILITARY': {'部队', '军方', '演习', '训练', '装备', '武器', '作战', '行动', '部署'},
|
||||||
|
'AUTOMOTIVE': {'汽车', '车型', '上市', '发布', '销量', '市场', '品牌', '厂商'},
|
||||||
|
'GOVERNMENT': {'会议', '讲话', '指出', '强调', '部署', '落实', '推进', '要求', '精神', '决定', '意见', '方案', '安排'},
|
||||||
|
'HEALTH': {'医生', '专家', '建议', '提示', '提醒', '研究', '发现', '可能', '有助于'},
|
||||||
|
'AI': {'技术', '系统', '模型', '算法', '应用', '功能', '版本', '升级', '研发', '推出', '人工智能'}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class NewsClassifier:
|
class NewsClassifier:
|
||||||
"""新闻文本分类器"""
|
"""新闻文本分类器"""
|
||||||
|
|
||||||
def __init__(self, model_type='nb'):
|
def __init__(self, model_type='nb', use_stopwords=True, use_category_stopwords=False):
|
||||||
"""
|
"""
|
||||||
初始化分类器
|
初始化分类器
|
||||||
:param model_type: 模型类型 'nb' 朴素贝叶斯 或 'svm' 支持向量机
|
:param model_type: 模型类型 'nb' 朴素贝叶斯 或 'svm' 支持向量机
|
||||||
|
:param use_stopwords: 是否使用通用停用词
|
||||||
|
:param use_category_stopwords: 是否使用分类专属停用词
|
||||||
"""
|
"""
|
||||||
self.model_type = model_type
|
self.model_type = model_type
|
||||||
self.vectorizer = None
|
self.vectorizer = None
|
||||||
self.classifier = None
|
self.classifier = None
|
||||||
self.categories = list(CATEGORY_MAP.keys())
|
self.categories = list(CATEGORY_MAP.keys())
|
||||||
|
self.use_stopwords = use_stopwords
|
||||||
|
self.use_category_stopwords = use_category_stopwords
|
||||||
|
self.stopwords = set()
|
||||||
|
|
||||||
def preprocess_text(self, text):
|
if self.use_stopwords:
|
||||||
|
self._load_stopwords()
|
||||||
|
|
||||||
|
def _load_stopwords(self, stopwords_path='../../data/news_stopwords.txt'):
|
||||||
"""
|
"""
|
||||||
文本预处理:使用jieba分词
|
加载停用词表
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
with open(stopwords_path, 'r', encoding='utf-8') as f:
|
||||||
|
self.stopwords = set(line.strip() for line in f if line.strip())
|
||||||
|
print(f"已加载 {len(self.stopwords)} 个停用词")
|
||||||
|
except FileNotFoundError:
|
||||||
|
print(f"警告: 停用词文件不存在: {stopwords_path}")
|
||||||
|
|
||||||
|
def preprocess_text(self, text, category=None):
|
||||||
|
"""
|
||||||
|
文本预处理:使用jieba分词 + 停用词过滤
|
||||||
|
:param text: 待处理文本
|
||||||
|
:param category: 可选,指定分类时使用分类专属停用词
|
||||||
"""
|
"""
|
||||||
# 移除多余空格和换行
|
# 移除多余空格和换行
|
||||||
text = ' '.join(text.split())
|
text = ' '.join(text.split())
|
||||||
# jieba分词
|
# jieba分词
|
||||||
words = jieba.cut(text)
|
words = jieba.cut(text)
|
||||||
return ' '.join(words)
|
|
||||||
|
# 过滤停用词和单字词
|
||||||
|
result = []
|
||||||
|
for w in words:
|
||||||
|
# 过滤单字词
|
||||||
|
if len(w) <= 1:
|
||||||
|
continue
|
||||||
|
# 过滤通用停用词
|
||||||
|
if self.use_stopwords and w in self.stopwords:
|
||||||
|
continue
|
||||||
|
# 过滤分类专属停用词
|
||||||
|
if self.use_category_stopwords and category:
|
||||||
|
if w in CATEGORY_STOPWORDS.get(category, set()):
|
||||||
|
continue
|
||||||
|
result.append(w)
|
||||||
|
|
||||||
|
return ' '.join(result)
|
||||||
|
|
||||||
def load_data(self, csv_path):
|
def load_data(self, csv_path):
|
||||||
"""
|
"""
|
||||||
|
|
@ -61,10 +111,22 @@ class NewsClassifier:
|
||||||
df = pd.read_csv(csv_path)
|
df = pd.read_csv(csv_path)
|
||||||
# 合并标题和内容作为特征
|
# 合并标题和内容作为特征
|
||||||
df['text'] = df['title'] + ' ' + df['content']
|
df['text'] = df['title'] + ' ' + df['content']
|
||||||
# 预处理
|
|
||||||
df['processed_text'] = df['text'].apply(self.preprocess_text)
|
# 预处理(如果启用了分类专属停用词,需要传入分类信息)
|
||||||
# 转换分类名称为代码
|
if self.use_category_stopwords:
|
||||||
df['category_code'] = df['category'].map(REVERSE_CATEGORY_MAP)
|
# 先转换为分类代码
|
||||||
|
df['category_code'] = df['category_name'].map(REVERSE_CATEGORY_MAP)
|
||||||
|
# 逐行预处理,传入分类信息
|
||||||
|
df['processed_text'] = df.apply(
|
||||||
|
lambda row: self.preprocess_text(row['text'], row['category_code']),
|
||||||
|
axis=1
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 不使用分类专属停用词,直接批量预处理
|
||||||
|
df['processed_text'] = df['text'].apply(self.preprocess_text)
|
||||||
|
# 转换分类名称为代码
|
||||||
|
df['category_code'] = df['category_name'].map(REVERSE_CATEGORY_MAP)
|
||||||
|
|
||||||
return df
|
return df
|
||||||
|
|
||||||
def train(self, df):
|
def train(self, df):
|
||||||
|
|
@ -74,6 +136,10 @@ class NewsClassifier:
|
||||||
X = df['processed_text'].values
|
X = df['processed_text'].values
|
||||||
y = df['category_code'].values
|
y = df['category_code'].values
|
||||||
|
|
||||||
|
# 获取实际数据中的分类
|
||||||
|
actual_categories = sorted(df['category_code'].unique().tolist())
|
||||||
|
actual_category_names = [CATEGORY_MAP[cat] for cat in actual_categories]
|
||||||
|
|
||||||
# 划分训练集和测试集
|
# 划分训练集和测试集
|
||||||
X_train, X_test, y_train, y_test = train_test_split(
|
X_train, X_test, y_train, y_test = train_test_split(
|
||||||
X, y, test_size=0.2, random_state=42, stratify=y
|
X, y, test_size=0.2, random_state=42, stratify=y
|
||||||
|
|
@ -108,7 +174,7 @@ class NewsClassifier:
|
||||||
print(f"准确率: {accuracy:.4f}")
|
print(f"准确率: {accuracy:.4f}")
|
||||||
print(f"F1-Score: {f1:.4f}")
|
print(f"F1-Score: {f1:.4f}")
|
||||||
print("\n分类报告:")
|
print("\n分类报告:")
|
||||||
print(classification_report(y_test, y_pred, target_names=self.categories))
|
print(classification_report(y_test, y_pred, target_names=actual_category_names))
|
||||||
|
|
||||||
return accuracy, f1
|
return accuracy, f1
|
||||||
|
|
||||||
|
|
@ -120,6 +186,7 @@ class NewsClassifier:
|
||||||
raise ValueError("模型未训练,请先调用train方法")
|
raise ValueError("模型未训练,请先调用train方法")
|
||||||
|
|
||||||
text = title + ' ' + content
|
text = title + ' ' + content
|
||||||
|
# 预测时不指定分类,只使用通用停用词
|
||||||
processed = self.preprocess_text(text)
|
processed = self.preprocess_text(text)
|
||||||
tfidf = self.vectorizer.transform([processed])
|
tfidf = self.vectorizer.transform([processed])
|
||||||
|
|
||||||
|
|
@ -157,12 +224,12 @@ if __name__ == '__main__':
|
||||||
classifier = NewsClassifier(model_type='nb')
|
classifier = NewsClassifier(model_type='nb')
|
||||||
|
|
||||||
# 假设有训练数据文件
|
# 假设有训练数据文件
|
||||||
train_data_path = '../data/processed/training_data.csv'
|
train_data_path = '../../data/processed/training_data.csv'
|
||||||
|
|
||||||
if os.path.exists(train_data_path):
|
if os.path.exists(train_data_path):
|
||||||
df = classifier.load_data(train_data_path)
|
df = classifier.load_data(train_data_path)
|
||||||
classifier.train(df)
|
classifier.train(df)
|
||||||
classifier.save_model('../models')
|
classifier.save_model('../../models/traditional')
|
||||||
|
|
||||||
# 测试预测
|
# 测试预测
|
||||||
test_title = "华为发布新款折叠屏手机"
|
test_title = "华为发布新款折叠屏手机"
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,163 @@
|
||||||
|
"""
|
||||||
|
数据加载和本地存储工具
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import os
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
from typing import List, Dict, Any, Optional
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class DataLoader:
|
||||||
|
"""数据加载和本地存储类"""
|
||||||
|
|
||||||
|
def __init__(self, db_url: str, data_dir: str = 'data'):
|
||||||
|
"""
|
||||||
|
初始化数据加载器
|
||||||
|
|
||||||
|
:param db_url: 数据库连接URL
|
||||||
|
:param data_dir: 数据存储目录
|
||||||
|
"""
|
||||||
|
self.db_url = db_url
|
||||||
|
self.data_dir = data_dir
|
||||||
|
self.engine = create_engine(db_url)
|
||||||
|
|
||||||
|
# 确保数据目录存在
|
||||||
|
os.makedirs(os.path.join(data_dir, 'raw'), exist_ok=True)
|
||||||
|
os.makedirs(os.path.join(data_dir, 'processed'), exist_ok=True)
|
||||||
|
|
||||||
|
def load_news_from_db(self, limit: int = None, category_ids: List[int] = None) -> pd.DataFrame:
|
||||||
|
"""
|
||||||
|
从数据库加载新闻数据
|
||||||
|
|
||||||
|
:param limit: 限制加载数据量
|
||||||
|
:param category_ids: 指定分类ID列表
|
||||||
|
:return: 新闻数据DataFrame
|
||||||
|
"""
|
||||||
|
query = """
|
||||||
|
SELECT n.id, n.title, n.content, c.name as category_name
|
||||||
|
FROM news n
|
||||||
|
LEFT JOIN news_category c ON n.category_id = c.id
|
||||||
|
WHERE n.content IS NOT NULL AND n.title IS NOT NULL
|
||||||
|
"""
|
||||||
|
|
||||||
|
if category_ids:
|
||||||
|
ids_str = ",".join(str(i) for i in category_ids)
|
||||||
|
query += f" AND c.id IN ({ids_str})"
|
||||||
|
|
||||||
|
if limit:
|
||||||
|
query += f" LIMIT {limit}"
|
||||||
|
|
||||||
|
df = pd.read_sql(query, self.engine)
|
||||||
|
|
||||||
|
|
||||||
|
logger.info(f"从数据库加载数据,条件: limit={limit}, category_ids={category_ids}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
df = pd.read_sql(query, self.engine)
|
||||||
|
logger.info(f"成功加载 {len(df)} 条新闻数据")
|
||||||
|
return df
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"加载数据失败: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
def save_to_local(self, df: pd.DataFrame, filename: str, subdir: str = 'processed'):
|
||||||
|
"""
|
||||||
|
将数据保存到本地
|
||||||
|
|
||||||
|
:param df: 要保存的DataFrame
|
||||||
|
:param filename: 文件名
|
||||||
|
:param subdir: 子目录
|
||||||
|
"""
|
||||||
|
file_path = os.path.join(self.data_dir, subdir, filename)
|
||||||
|
|
||||||
|
# 确保目录存在
|
||||||
|
os.makedirs(os.path.join(self.data_dir, subdir), exist_ok=True)
|
||||||
|
|
||||||
|
# 保存为CSV
|
||||||
|
df.to_csv(file_path, index=False, encoding='utf-8-sig')
|
||||||
|
logger.info(f"数据已保存到: {file_path}")
|
||||||
|
|
||||||
|
def load_from_local(self, filename: str, subdir: str = 'processed') -> Optional[pd.DataFrame]:
|
||||||
|
"""
|
||||||
|
从本地加载已保存的数据
|
||||||
|
|
||||||
|
:param filename: 文件名
|
||||||
|
:param subdir: 子目录
|
||||||
|
:return: DataFrame或None(如果文件不存在)
|
||||||
|
"""
|
||||||
|
file_path = os.path.join(self.data_dir, subdir, filename)
|
||||||
|
|
||||||
|
if os.path.exists(file_path):
|
||||||
|
df = pd.read_csv(file_path, encoding='utf-8-sig')
|
||||||
|
logger.info(f"从本地加载 {len(df)} 条数据: {file_path}")
|
||||||
|
return df
|
||||||
|
else:
|
||||||
|
logger.warning(f"本地文件不存在: {file_path}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def update_local_data(self, limit: int = None, category_ids: List[int] = None):
|
||||||
|
"""
|
||||||
|
更新本地数据(从数据库重新加载数据并保存)
|
||||||
|
|
||||||
|
:param limit: 限制加载数据量
|
||||||
|
:param category_ids: 指定分类ID列表
|
||||||
|
"""
|
||||||
|
# 从数据库加载数据
|
||||||
|
df = self.load_news_from_db(limit=limit, category_ids=category_ids)
|
||||||
|
|
||||||
|
if not df.empty:
|
||||||
|
# 生成文件名(包含时间戳)
|
||||||
|
import time
|
||||||
|
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
||||||
|
filename = f"news_data_{timestamp}.csv"
|
||||||
|
|
||||||
|
# 保存到本地
|
||||||
|
self.save_to_local(df, filename)
|
||||||
|
|
||||||
|
# 可选:删除旧的本地文件,只保留最新的
|
||||||
|
self._cleanup_old_files(filename)
|
||||||
|
|
||||||
|
return df
|
||||||
|
else:
|
||||||
|
logger.warning("没有可用的数据需要保存")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _cleanup_old_files(self, current_filename: str):
|
||||||
|
"""
|
||||||
|
清理旧的本地文件(可选)
|
||||||
|
|
||||||
|
:param current_filename: 当前文件名
|
||||||
|
"""
|
||||||
|
subdir = 'processed'
|
||||||
|
dir_path = os.path.join(self.data_dir, subdir)
|
||||||
|
|
||||||
|
if os.path.exists(dir_path):
|
||||||
|
for filename in os.listdir(dir_path):
|
||||||
|
if filename != current_filename and filename.startswith('news_data_'):
|
||||||
|
file_path = os.path.join(dir_path, filename)
|
||||||
|
try:
|
||||||
|
os.remove(file_path)
|
||||||
|
logger.info(f"已删除旧文件: {file_path}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"删除文件失败 {file_path}: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
# 示例用法
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# 配置数据库连接
|
||||||
|
db_url = "mysql+pymysql://root:root@localhost/news"
|
||||||
|
|
||||||
|
# 创建数据加载器
|
||||||
|
loader = DataLoader(db_url)
|
||||||
|
|
||||||
|
# 从数据库加载数据并保存到本地
|
||||||
|
data = loader.update_local_data(limit=800)
|
||||||
|
|
||||||
|
if data is not None:
|
||||||
|
print(f"成功加载数据,共 {len(data)} 条记录")
|
||||||
|
print(data.head())
|
||||||
|
else:
|
||||||
|
print("没有数据可加载")
|
||||||
|
|
@ -0,0 +1,96 @@
|
||||||
|
"""
|
||||||
|
模型评估指标工具
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from sklearn.metrics import (
|
||||||
|
accuracy_score,
|
||||||
|
precision_recall_fscore_support,
|
||||||
|
confusion_matrix,
|
||||||
|
classification_report
|
||||||
|
)
|
||||||
|
from typing import List, Dict, Any
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import seaborn as sns
|
||||||
|
|
||||||
|
class ClassificationMetrics:
|
||||||
|
"""分类评估指标"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def compute_all(y_true: List, y_pred: List, labels: List[str]) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
计算所有指标
|
||||||
|
"""
|
||||||
|
accuracy = accuracy_score(y_true, y_pred)
|
||||||
|
|
||||||
|
precision, recall, f1, support = precision_recall_fscore_support(
|
||||||
|
y_true, y_pred, average='weighted', zero_division=0
|
||||||
|
)
|
||||||
|
|
||||||
|
# 每个类别的指标
|
||||||
|
precision_per_class, recall_per_class, f1_per_class, support_per_class = \
|
||||||
|
precision_recall_fscore_support(y_true, y_pred, average=None, zero_division=0)
|
||||||
|
|
||||||
|
per_class_metrics = {}
|
||||||
|
for i, label in enumerate(labels):
|
||||||
|
per_class_metrics[label] = {
|
||||||
|
'precision': float(precision_per_class[i]),
|
||||||
|
'recall': float(recall_per_class[i]),
|
||||||
|
'f1': float(f1_per_class[i]),
|
||||||
|
'support': int(support_per_class[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
'accuracy': float(accuracy),
|
||||||
|
'precision': float(precision),
|
||||||
|
'recall': float(recall),
|
||||||
|
'f1': float(f1),
|
||||||
|
'per_class': per_class_metrics
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def plot_confusion_matrix(y_true: List, y_pred: List, labels: List[str], save_path: str = None):
|
||||||
|
"""
|
||||||
|
绘制混淆矩阵
|
||||||
|
"""
|
||||||
|
cm = confusion_matrix(y_true, y_pred)
|
||||||
|
|
||||||
|
plt.figure(figsize=(10, 8))
|
||||||
|
sns.heatmap(
|
||||||
|
cm,
|
||||||
|
annot=True,
|
||||||
|
fmt='d',
|
||||||
|
cmap='Blues',
|
||||||
|
xticklabels=labels,
|
||||||
|
yticklabels=labels
|
||||||
|
)
|
||||||
|
plt.xlabel('预测标签')
|
||||||
|
plt.ylabel('真实标签')
|
||||||
|
plt.title('混淆矩阵')
|
||||||
|
|
||||||
|
if save_path:
|
||||||
|
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def print_report(y_true: List, y_pred: List, labels: List[str]):
|
||||||
|
"""
|
||||||
|
打印分类报告
|
||||||
|
"""
|
||||||
|
report = classification_report(
|
||||||
|
y_true, y_pred,
|
||||||
|
target_names=labels,
|
||||||
|
zero_division=0
|
||||||
|
)
|
||||||
|
print(report)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# 测试
|
||||||
|
y_true = ['ENTERTAINMENT', 'TECHNOLOGY', 'FINANCE', 'ENTERTAINMENT', 'TECHNOLOGY']
|
||||||
|
y_pred = ['ENTERTAINMENT', 'TECHNOLOGY', 'FINANCE', 'TECHNOLOGY', 'TECHNOLOGY']
|
||||||
|
labels = ['ENTERTAINMENT', 'TECHNOLOGY', 'FINANCE']
|
||||||
|
|
||||||
|
metrics = ClassificationMetrics()
|
||||||
|
result = metrics.compute_all(y_true, y_pred, labels)
|
||||||
|
print(result)
|
||||||
|
|
@ -0,0 +1,114 @@
|
||||||
|
"""
|
||||||
|
BERT模型训练脚本
|
||||||
|
支持GPU/CPU自动检测和动态参数调整
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import yaml
|
||||||
|
import os
|
||||||
|
from src.deep_learning.bert_model import BertClassifier
|
||||||
|
from src.utils.data_loader import DataLoader
|
||||||
|
from datasets import Dataset
|
||||||
|
import pandas as pd
|
||||||
|
from transformers import BertTokenizer, DataCollatorWithPadding
|
||||||
|
|
||||||
|
def load_and_prepare_data(db_url: str, limit: int = 1000):
|
||||||
|
"""加载数据并准备训练集"""
|
||||||
|
# 创建数据加载器
|
||||||
|
loader = DataLoader(db_url)
|
||||||
|
|
||||||
|
# 从数据库加载数据
|
||||||
|
df = loader.load_news_from_db(limit=limit)
|
||||||
|
|
||||||
|
if df.empty:
|
||||||
|
print("没有可用的训练数据")
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
# 准备数据
|
||||||
|
texts = df['title'] + ' ' + df['content']
|
||||||
|
labels = df['category_name'].map({
|
||||||
|
'娱乐': 0, '体育': 1, '财经': 2, '科技': 3, '军事': 4,
|
||||||
|
'汽车': 5, '政务': 6, '健康': 7, 'AI': 8
|
||||||
|
}).fillna(-1).astype(int)
|
||||||
|
|
||||||
|
# 过滤无效数据
|
||||||
|
valid_data = df[labels != -1]
|
||||||
|
texts = valid_data['title'] + ' ' + valid_data['content']
|
||||||
|
labels = labels[labels != -1]
|
||||||
|
|
||||||
|
print(f"有效数据数量: {len(valid_data)}")
|
||||||
|
|
||||||
|
# 创建Hugging Face Dataset
|
||||||
|
dataset = Dataset.from_pandas(pd.DataFrame({
|
||||||
|
'text': texts.tolist(),
|
||||||
|
'label': labels.tolist()
|
||||||
|
}))
|
||||||
|
|
||||||
|
# 划分训练集和验证集
|
||||||
|
train_test = dataset.train_test_split(test_size=0.2)
|
||||||
|
|
||||||
|
return train_test['train'], train_test['test']
|
||||||
|
|
||||||
|
def train_bert_model(config: dict):
|
||||||
|
"""训练BERT模型"""
|
||||||
|
# 加载数据
|
||||||
|
train_dataset, eval_dataset = load_and_prepare_data(
|
||||||
|
config['database']['url'],
|
||||||
|
limit=config['training']['data_limit']
|
||||||
|
)
|
||||||
|
|
||||||
|
if train_dataset is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
# 初始化模型
|
||||||
|
classifier = BertClassifier(
|
||||||
|
model_name=config['model']['name'],
|
||||||
|
num_labels=config['model']['num_labels'],
|
||||||
|
use_gpu=config['training']['use_gpu']
|
||||||
|
)
|
||||||
|
|
||||||
|
# 训练模型
|
||||||
|
classifier.train_model(
|
||||||
|
train_dataset=train_dataset,
|
||||||
|
eval_dataset=eval_dataset,
|
||||||
|
output_dir=config['model']['output_dir'],
|
||||||
|
num_train_epochs=config['training']['epochs'],
|
||||||
|
per_device_train_batch_size=config['training']['batch_size'],
|
||||||
|
per_device_eval_batch_size=config['training']['batch_size'] * 2,
|
||||||
|
learning_rate=config['training']['learning_rate'],
|
||||||
|
warmup_steps=config['training']['warmup_steps'],
|
||||||
|
weight_decay=config['training']['weight_decay'],
|
||||||
|
fp16=config['training'].get('fp16', None)
|
||||||
|
)
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# 解析命令行参数
|
||||||
|
parser = argparse.ArgumentParser(description='BERT模型训练脚本')
|
||||||
|
parser.add_argument('--config', type=str, default='config.yaml', help='配置文件路径')
|
||||||
|
parser.add_argument('--use_gpu', action='store_true', help='强制使用GPU')
|
||||||
|
parser.add_argument('--epochs', type=int, help='训练轮数(覆盖配置文件)')
|
||||||
|
parser.add_argument('--batch_size', type=int, help='批大小(覆盖配置文件)')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# 加载配置文件
|
||||||
|
with open(args.config, 'r', encoding='utf-8') as f:
|
||||||
|
config = yaml.safe_load(f)
|
||||||
|
|
||||||
|
# 覆盖配置参数
|
||||||
|
if args.epochs is not None:
|
||||||
|
config['training']['epochs'] = args.epochs
|
||||||
|
if args.batch_size is not None:
|
||||||
|
config['training']['batch_size'] = args.batch_size
|
||||||
|
if args.use_gpu:
|
||||||
|
config['training']['use_gpu'] = True
|
||||||
|
|
||||||
|
# 开始训练
|
||||||
|
print("开始BERT模型训练...")
|
||||||
|
print(f"配置: {config}")
|
||||||
|
|
||||||
|
train_bert_model(config)
|
||||||
|
|
||||||
|
print("训练完成!")
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
|
|
@ -0,0 +1,800 @@
|
||||||
|
# 新闻文本分类系统 - 模块开发任务清单
|
||||||
|
---
|
||||||
|
## 目录
|
||||||
|
|
||||||
|
1. [爬虫模块 (Python)](#1-爬虫模块-python)
|
||||||
|
2. [后端服务模块 (Spring Boot)](#2-后端服务模块-spring-boot)
|
||||||
|
3. [前端桌面模块 (Tauri + Vue3)](#3-前端桌面模块-tauri--vue3)
|
||||||
|
4. [机器学习分类模块 (Python)](#4-机器学习分类模块-python)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 1. 爬虫模块 (Python)
|
||||||
|
|
||||||
|
## 2. 后端服务模块 (Spring Boot)
|
||||||
|
|
||||||
|
## 4. 机器学习分类模块 (Python)
|
||||||
|
|
||||||
|
### 模块目录结构
|
||||||
|
|
||||||
|
```
|
||||||
|
ml-module/
|
||||||
|
├── data/
|
||||||
|
│ ├── raw/ # 原始数据
|
||||||
|
│ ├── processed/ # 处理后的数据
|
||||||
|
│ │ ├── training_data.csv
|
||||||
|
│ │ └── test_data.csv
|
||||||
|
│ └── external/ # 外部数据集
|
||||||
|
├── models/ # 训练好的模型
|
||||||
|
│ ├── traditional/
|
||||||
|
│ │ ├── nb_vectorizer.pkl
|
||||||
|
│ │ ├── nb_classifier.pkl
|
||||||
|
│ │ ├── svm_vectorizer.pkl
|
||||||
|
│ │ └── svm_classifier.pkl
|
||||||
|
│ ├── deep_learning/
|
||||||
|
│ │ └── bert_finetuned/
|
||||||
|
│ └── hybrid/
|
||||||
|
│ └── config.json
|
||||||
|
├── src/
|
||||||
|
│ ├── __init__.py
|
||||||
|
│ ├── traditional/ # 传统机器学习
|
||||||
|
│ │ ├── __init__.py
|
||||||
|
│ │ ├── train_model.py # (已有)
|
||||||
|
│ │ ├── predict.py
|
||||||
|
│ │ └── evaluate.py
|
||||||
|
│ ├── deep_learning/ # 深度学习
|
||||||
|
│ │ ├── __init__.py
|
||||||
|
│ │ ├── bert_model.py
|
||||||
|
│ │ ├── train_bert.py
|
||||||
|
│ │ └── predict_bert.py
|
||||||
|
│ ├── hybrid/ # 混合策略
|
||||||
|
│ │ ├── __init__.py
|
||||||
|
│ │ ├── hybrid_classifier.py
|
||||||
|
│ │ └── rule_engine.py
|
||||||
|
│ ├── utils/
|
||||||
|
│ │ ├── __init__.py
|
||||||
|
│ │ ├── preprocessing.py # 数据预处理
|
||||||
|
│ │ └── metrics.py # 评估指标
|
||||||
|
│ └── api/ # API服务
|
||||||
|
│ ├── __init__.py
|
||||||
|
│ └── server.py # FastAPI服务
|
||||||
|
├── notebooks/ # Jupyter notebooks
|
||||||
|
│ ├── data_exploration.ipynb
|
||||||
|
│ └── model_comparison.ipynb
|
||||||
|
├── tests/ # 测试
|
||||||
|
│ ├── test_traditional.py
|
||||||
|
│ ├── test_bert.py
|
||||||
|
│ └── test_hybrid.py
|
||||||
|
├── requirements.txt
|
||||||
|
├── setup.py
|
||||||
|
└── README.md
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4.1 需要完成的具体文件
|
||||||
|
|
||||||
|
#### 任务 4.1.1: `src/traditional/predict.py` - 传统模型预测
|
||||||
|
|
||||||
|
```python
|
||||||
|
"""
|
||||||
|
传统机器学习模型预测
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import joblib
|
||||||
|
import jieba
|
||||||
|
from typing import Dict, Any
|
||||||
|
|
||||||
|
# 分类映射
|
||||||
|
CATEGORY_MAP = {
|
||||||
|
'POLITICS': '时政',
|
||||||
|
'FINANCE': '财经',
|
||||||
|
'TECHNOLOGY': '科技',
|
||||||
|
'SPORTS': '体育',
|
||||||
|
'ENTERTAINMENT': '娱乐',
|
||||||
|
'HEALTH': '健康',
|
||||||
|
'EDUCATION': '教育',
|
||||||
|
'LIFE': '生活',
|
||||||
|
'INTERNATIONAL': '国际',
|
||||||
|
'MILITARY': '军事'
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TraditionalPredictor:
|
||||||
|
"""传统机器学习预测器"""
|
||||||
|
|
||||||
|
def __init__(self, model_type='nb', model_dir='../../models/traditional'):
|
||||||
|
self.model_type = model_type
|
||||||
|
self.model_dir = model_dir
|
||||||
|
self.vectorizer = None
|
||||||
|
self.classifier = None
|
||||||
|
self._load_model()
|
||||||
|
|
||||||
|
def _load_model(self):
|
||||||
|
"""加载模型"""
|
||||||
|
vectorizer_path = os.path.join(self.model_dir, f'{self.model_type}_vectorizer.pkl')
|
||||||
|
classifier_path = os.path.join(self.model_dir, f'{self.model_type}_classifier.pkl')
|
||||||
|
|
||||||
|
self.vectorizer = joblib.load(vectorizer_path)
|
||||||
|
self.classifier = joblib.load(classifier_path)
|
||||||
|
print(f"模型加载成功: {self.model_type}")
|
||||||
|
|
||||||
|
def preprocess(self, title: str, content: str) -> str:
|
||||||
|
"""预处理文本"""
|
||||||
|
text = title + ' ' + content
|
||||||
|
# jieba分词
|
||||||
|
words = jieba.cut(text)
|
||||||
|
return ' '.join(words)
|
||||||
|
|
||||||
|
def predict(self, title: str, content: str) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
预测
|
||||||
|
:return: 预测结果字典
|
||||||
|
"""
|
||||||
|
# 预处理
|
||||||
|
processed = self.preprocess(title, content)
|
||||||
|
|
||||||
|
# 特征提取
|
||||||
|
tfidf = self.vectorizer.transform([processed])
|
||||||
|
|
||||||
|
# 预测
|
||||||
|
prediction = self.classifier.predict(tfidf)[0]
|
||||||
|
probabilities = self.classifier.predict_proba(tfidf)[0]
|
||||||
|
|
||||||
|
# 获取各类别概率
|
||||||
|
prob_dict = {}
|
||||||
|
for i, prob in enumerate(probabilities):
|
||||||
|
category_code = self.classifier.classes_[i]
|
||||||
|
prob_dict[category_code] = float(prob)
|
||||||
|
|
||||||
|
return {
|
||||||
|
'categoryCode': prediction,
|
||||||
|
'categoryName': CATEGORY_MAP.get(prediction, '未知'),
|
||||||
|
'confidence': float(probabilities.max()),
|
||||||
|
'probabilities': prob_dict
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# API入口
|
||||||
|
def predict_single(title: str, content: str, model_type='nb') -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
单条预测API
|
||||||
|
"""
|
||||||
|
predictor = TraditionalPredictor(model_type)
|
||||||
|
return predictor.predict(title, content)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# 测试
|
||||||
|
result = predict_single(
|
||||||
|
title="华为发布新款折叠屏手机",
|
||||||
|
content="华为今天正式发布了新一代折叠屏手机,搭载最新麒麟芯片..."
|
||||||
|
)
|
||||||
|
print(result)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 任务 4.1.2: `src/deep_learning/bert_model.py` - BERT模型
|
||||||
|
|
||||||
|
```python
|
||||||
|
"""
|
||||||
|
BERT文本分类模型
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from transformers import (
|
||||||
|
BertTokenizer,
|
||||||
|
BertForSequenceClassification,
|
||||||
|
Trainer,
|
||||||
|
TrainingArguments
|
||||||
|
)
|
||||||
|
from typing import Dict, Any, List
|
||||||
|
|
||||||
|
|
||||||
|
# 分类映射
|
||||||
|
CATEGORY_MAP = {
|
||||||
|
'POLITICS': '时政',
|
||||||
|
'FINANCE': '财经',
|
||||||
|
'TECHNOLOGY': '科技',
|
||||||
|
'SPORTS': '体育',
|
||||||
|
'ENTERTAINMENT': '娱乐',
|
||||||
|
'HEALTH': '健康',
|
||||||
|
'EDUCATION': '教育',
|
||||||
|
'LIFE': '生活',
|
||||||
|
'INTERNATIONAL': '国际',
|
||||||
|
'MILITARY': '军事'
|
||||||
|
}
|
||||||
|
|
||||||
|
# 反向映射
|
||||||
|
ID_TO_LABEL = {i: label for i, label in enumerate(CATEGORY_MAP.keys())}
|
||||||
|
LABEL_TO_ID = {label: i for i, label in enumerate(CATEGORY_MAP.keys())}
|
||||||
|
|
||||||
|
|
||||||
|
class BertClassifier:
|
||||||
|
"""BERT文本分类器"""
|
||||||
|
|
||||||
|
def __init__(self, model_name='bert-base-chinese', num_labels=10):
|
||||||
|
self.model_name = model_name
|
||||||
|
self.num_labels = num_labels
|
||||||
|
self.tokenizer = None
|
||||||
|
self.model = None
|
||||||
|
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
|
||||||
|
def load_model(self, model_path):
|
||||||
|
"""加载微调后的模型"""
|
||||||
|
self.tokenizer = BertTokenizer.from_pretrained(model_path)
|
||||||
|
self.model = BertForSequenceClassification.from_pretrained(
|
||||||
|
model_path,
|
||||||
|
num_labels=self.num_labels
|
||||||
|
)
|
||||||
|
self.model.to(self.device)
|
||||||
|
self.model.eval()
|
||||||
|
print(f"BERT模型加载成功: {model_path}")
|
||||||
|
|
||||||
|
def predict(self, title: str, content: str) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
预测
|
||||||
|
"""
|
||||||
|
if self.model is None or self.tokenizer is None:
|
||||||
|
raise ValueError("模型未加载,请先调用load_model")
|
||||||
|
|
||||||
|
# 组合标题和内容
|
||||||
|
text = f"{title} [SEP] {content}"
|
||||||
|
|
||||||
|
# 分词
|
||||||
|
inputs = self.tokenizer(
|
||||||
|
text,
|
||||||
|
return_tensors='pt',
|
||||||
|
truncation=True,
|
||||||
|
max_length=512,
|
||||||
|
padding='max_length'
|
||||||
|
)
|
||||||
|
|
||||||
|
# 预测
|
||||||
|
with torch.no_grad():
|
||||||
|
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
||||||
|
outputs = self.model(**inputs)
|
||||||
|
logits = outputs.logits
|
||||||
|
|
||||||
|
# 获取预测结果
|
||||||
|
probs = torch.softmax(logits, dim=-1)
|
||||||
|
confidence, predicted_id = torch.max(probs, dim=-1)
|
||||||
|
|
||||||
|
predicted_id = predicted_id.item()
|
||||||
|
confidence = confidence.item()
|
||||||
|
|
||||||
|
# 获取各类别概率
|
||||||
|
prob_dict = {}
|
||||||
|
for i, prob in enumerate(probs[0].cpu().numpy()):
|
||||||
|
category_code = ID_TO_LABEL[i]
|
||||||
|
prob_dict[category_code] = float(prob)
|
||||||
|
|
||||||
|
return {
|
||||||
|
'categoryCode': ID_TO_LABEL[predicted_id],
|
||||||
|
'categoryName': CATEGORY_MAP.get(ID_TO_LABEL[predicted_id], '未知'),
|
||||||
|
'confidence': confidence,
|
||||||
|
'probabilities': prob_dict
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# 数据集类
|
||||||
|
class NewsDataset(torch.utils.data.Dataset):
|
||||||
|
"""新闻数据集"""
|
||||||
|
|
||||||
|
def __init__(self, texts, labels, tokenizer, max_length=512):
|
||||||
|
self.texts = texts
|
||||||
|
self.labels = labels
|
||||||
|
self.tokenizer = tokenizer
|
||||||
|
self.max_length = max_length
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.texts)
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
text = self.texts[idx]
|
||||||
|
label = self.labels[idx]
|
||||||
|
|
||||||
|
encoding = self.tokenizer(
|
||||||
|
text,
|
||||||
|
truncation=True,
|
||||||
|
max_length=self.max_length,
|
||||||
|
padding='max_length',
|
||||||
|
return_tensors='pt'
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
'input_ids': encoding['input_ids'].flatten(),
|
||||||
|
'attention_mask': encoding['attention_mask'].flatten(),
|
||||||
|
'labels': torch.tensor(label, dtype=torch.long)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# 测试
|
||||||
|
classifier = BertClassifier()
|
||||||
|
# classifier.load_model('./models/deep_learning/bert_finetuned')
|
||||||
|
#
|
||||||
|
# result = classifier.predict(
|
||||||
|
# title="华为发布新款折叠屏手机",
|
||||||
|
# content="华为今天正式发布了新一代折叠屏手机..."
|
||||||
|
# )
|
||||||
|
# print(result)
|
||||||
|
print("BERT分类器初始化成功")
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 任务 4.1.3: `src/hybrid/hybrid_classifier.py` - 混合分类器
|
||||||
|
|
||||||
|
```python
|
||||||
|
"""
|
||||||
|
混合策略分类器
|
||||||
|
结合规则引擎和机器学习模型
|
||||||
|
"""
|
||||||
|
|
||||||
|
import time
|
||||||
|
from typing import Dict, Any
|
||||||
|
from ..traditional.predict import TraditionalPredictor
|
||||||
|
from ..deep_learning.bert_model import BertClassifier
|
||||||
|
|
||||||
|
|
||||||
|
class HybridClassifier:
|
||||||
|
"""混合分类器"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
# 初始化各个分类器
|
||||||
|
self.nb_predictor = TraditionalPredictor('nb')
|
||||||
|
self.bert_classifier = BertClassifier()
|
||||||
|
|
||||||
|
# 配置参数
|
||||||
|
self.config = {
|
||||||
|
'confidence_threshold': 0.75, # 高置信度阈值
|
||||||
|
'hybrid_min_confidence': 0.60, # 混合模式最低阈值
|
||||||
|
'use_bert_threshold': 0.70, # 使用BERT的阈值
|
||||||
|
'rule_priority': True # 规则优先
|
||||||
|
}
|
||||||
|
|
||||||
|
# 规则关键词字典
|
||||||
|
self.rule_keywords = {
|
||||||
|
'POLITICS': ['政府', '政策', '选举', '国务院', '主席', '总理'],
|
||||||
|
'FINANCE': ['股市', '经济', '金融', '投资', '基金', '银行'],
|
||||||
|
'TECHNOLOGY': ['芯片', 'AI', '人工智能', '5G', '互联网', '科技'],
|
||||||
|
'SPORTS': ['比赛', '冠军', '联赛', '球员', '教练', 'NBA'],
|
||||||
|
'ENTERTAINMENT': ['明星', '电影', '电视剧', '娱乐圈', '歌手'],
|
||||||
|
'HEALTH': ['健康', '医疗', '疾病', '治疗', '疫苗'],
|
||||||
|
'EDUCATION': ['教育', '学校', '大学', '考试', '招生'],
|
||||||
|
'LIFE': ['生活', '美食', '旅游', '购物'],
|
||||||
|
'INTERNATIONAL': ['国际', '美国', '欧洲', '日本', '外交'],
|
||||||
|
'MILITARY': ['军事', '武器', '军队', '国防', '战争']
|
||||||
|
}
|
||||||
|
|
||||||
|
def rule_match(self, title: str, content: str) -> tuple[str | None, float]:
|
||||||
|
"""
|
||||||
|
规则匹配
|
||||||
|
:return: (category_code, confidence)
|
||||||
|
"""
|
||||||
|
text = title + ' ' + content
|
||||||
|
|
||||||
|
# 计算每个类别的关键词匹配数
|
||||||
|
matches = {}
|
||||||
|
for category, keywords in self.rule_keywords.items():
|
||||||
|
count = sum(1 for kw in keywords if kw in text)
|
||||||
|
if count > 0:
|
||||||
|
matches[category] = count
|
||||||
|
|
||||||
|
if not matches:
|
||||||
|
return None, 0.0
|
||||||
|
|
||||||
|
# 返回匹配最多的类别
|
||||||
|
best_category = max(matches, key=matches.get)
|
||||||
|
confidence = min(0.9, matches[best_category] * 0.15) # 规则置信度
|
||||||
|
|
||||||
|
return best_category, confidence
|
||||||
|
|
||||||
|
def predict(self, title: str, content: str, use_bert=True) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
混合预测
|
||||||
|
"""
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
# 1. 先尝试规则匹配
|
||||||
|
rule_category, rule_confidence = self.rule_match(title, content)
|
||||||
|
|
||||||
|
# 2. 传统机器学习预测
|
||||||
|
nb_result = self.nb_predictor.predict(title, content)
|
||||||
|
nb_confidence = nb_result['confidence']
|
||||||
|
|
||||||
|
# 决策逻辑
|
||||||
|
final_result = None
|
||||||
|
classifier_type = 'HYBRID'
|
||||||
|
|
||||||
|
# 规则优先且规则置信度高
|
||||||
|
if self.config['rule_priority'] and rule_confidence >= self.config['confidence_threshold']:
|
||||||
|
final_result = {
|
||||||
|
'categoryCode': rule_category,
|
||||||
|
'categoryName': nb_result['categoryName'], # 从映射获取
|
||||||
|
'confidence': rule_confidence,
|
||||||
|
'classifierType': 'RULE',
|
||||||
|
'reason': '规则匹配'
|
||||||
|
}
|
||||||
|
# 传统模型置信度足够高
|
||||||
|
elif nb_confidence >= self.config['confidence_threshold']:
|
||||||
|
final_result = {
|
||||||
|
**nb_result,
|
||||||
|
'classifierType': 'ML',
|
||||||
|
'reason': '传统模型高置信度'
|
||||||
|
}
|
||||||
|
# 需要使用BERT
|
||||||
|
elif use_bert:
|
||||||
|
# TODO: 加载BERT模型预测
|
||||||
|
# bert_result = self.bert_classifier.predict(title, content)
|
||||||
|
# 如果BERT置信度也不高,选择最高的
|
||||||
|
final_result = {
|
||||||
|
**nb_result,
|
||||||
|
'classifierType': 'HYBRID',
|
||||||
|
'reason': '混合决策'
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# 不使用BERT,直接返回传统模型结果
|
||||||
|
final_result = {
|
||||||
|
**nb_result,
|
||||||
|
'classifierType': 'ML',
|
||||||
|
'reason': '默认传统模型'
|
||||||
|
}
|
||||||
|
|
||||||
|
# 计算耗时
|
||||||
|
duration = int((time.time() - start_time) * 1000)
|
||||||
|
final_result['duration'] = duration
|
||||||
|
|
||||||
|
return final_result
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# 测试
|
||||||
|
classifier = HybridClassifier()
|
||||||
|
|
||||||
|
test_cases = [
|
||||||
|
{
|
||||||
|
'title': '国务院发布最新经济政策',
|
||||||
|
'content': '国务院今天发布了新的经济政策...'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'title': '华为发布新款折叠屏手机',
|
||||||
|
'content': '华为今天正式发布了新一代折叠屏手机...'
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
for case in test_cases:
|
||||||
|
result = classifier.predict(case['title'], case['content'])
|
||||||
|
print(f"标题: {case['title']}")
|
||||||
|
print(f"结果: {result['categoryName']} ({result['confidence']:.2f})")
|
||||||
|
print(f"分类器: {result['classifierType']}")
|
||||||
|
print(f"原因: {result.get('reason', 'N/A')}")
|
||||||
|
print(f"耗时: {result['duration']}ms")
|
||||||
|
print("-" * 50)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 任务 4.1.4: `src/api/server.py` - FastAPI服务
|
||||||
|
|
||||||
|
```python
|
||||||
|
"""
|
||||||
|
机器学习模型API服务
|
||||||
|
使用FastAPI提供RESTful API
|
||||||
|
"""
|
||||||
|
|
||||||
|
from fastapi import FastAPI, HTTPException
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import Optional
|
||||||
|
import logging
|
||||||
|
|
||||||
|
# 导入分类器
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
|
||||||
|
from traditional.predict import TraditionalPredictor
|
||||||
|
from hybrid.hybrid_classifier import HybridClassifier
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 创建FastAPI应用
|
||||||
|
app = FastAPI(
|
||||||
|
title="新闻分类API",
|
||||||
|
description="提供新闻文本分类服务",
|
||||||
|
version="1.0.0"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 配置CORS
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"],
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# 请求模型
|
||||||
|
class ClassifyRequest(BaseModel):
|
||||||
|
title: str
|
||||||
|
content: str
|
||||||
|
mode: Optional[str] = 'hybrid' # traditional, hybrid
|
||||||
|
|
||||||
|
|
||||||
|
# 响应模型
|
||||||
|
class ClassifyResponse(BaseModel):
|
||||||
|
categoryCode: str
|
||||||
|
categoryName: str
|
||||||
|
confidence: float
|
||||||
|
classifierType: str
|
||||||
|
duration: int
|
||||||
|
probabilities: Optional[dict] = None
|
||||||
|
|
||||||
|
|
||||||
|
# 初始化分类器
|
||||||
|
nb_predictor = None
|
||||||
|
hybrid_classifier = None
|
||||||
|
|
||||||
|
|
||||||
|
@app.on_event("startup")
|
||||||
|
async def startup_event():
|
||||||
|
"""启动时加载模型"""
|
||||||
|
global nb_predictor, hybrid_classifier
|
||||||
|
|
||||||
|
logger.info("加载模型...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
nb_predictor = TraditionalPredictor('nb')
|
||||||
|
logger.info("朴素贝叶斯模型加载成功")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"朴素贝叶斯模型加载失败: {e}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
hybrid_classifier = HybridClassifier()
|
||||||
|
logger.info("混合分类器初始化成功")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"混合分类器初始化失败: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/")
|
||||||
|
async def root():
|
||||||
|
"""健康检查"""
|
||||||
|
return {
|
||||||
|
"status": "ok",
|
||||||
|
"message": "新闻分类API服务运行中"
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
async def health_check():
|
||||||
|
"""健康检查"""
|
||||||
|
return {
|
||||||
|
"status": "healthy",
|
||||||
|
"models": {
|
||||||
|
"nb_loaded": nb_predictor is not None,
|
||||||
|
"hybrid_loaded": hybrid_classifier is not None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/api/predict", response_model=ClassifyResponse)
|
||||||
|
async def predict(request: ClassifyRequest):
|
||||||
|
"""
|
||||||
|
文本分类接口
|
||||||
|
|
||||||
|
- **title**: 新闻标题
|
||||||
|
- **content**: 新闻内容
|
||||||
|
- **mode**: 分类模式 (traditional, hybrid)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if request.mode == 'traditional':
|
||||||
|
result = nb_predictor.predict(request.title, request.content)
|
||||||
|
result['classifierType'] = 'ML'
|
||||||
|
else: # hybrid
|
||||||
|
result = hybrid_classifier.predict(request.title, request.content)
|
||||||
|
|
||||||
|
return ClassifyResponse(**result)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"预测失败: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/api/batch-predict")
|
||||||
|
async def batch_predict(requests: list[ClassifyRequest]):
|
||||||
|
"""
|
||||||
|
批量分类接口
|
||||||
|
"""
|
||||||
|
results = []
|
||||||
|
for req in requests:
|
||||||
|
try:
|
||||||
|
if req.mode == 'traditional':
|
||||||
|
result = nb_predictor.predict(req.title, req.content)
|
||||||
|
result['classifierType'] = 'ML'
|
||||||
|
else:
|
||||||
|
result = hybrid_classifier.predict(req.title, req.content)
|
||||||
|
results.append(result)
|
||||||
|
except Exception as e:
|
||||||
|
results.append({
|
||||||
|
'error': str(e),
|
||||||
|
'title': req.title
|
||||||
|
})
|
||||||
|
|
||||||
|
return {"results": results}
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
uvicorn.run(
|
||||||
|
app,
|
||||||
|
host="0.0.0.0",
|
||||||
|
port=5000,
|
||||||
|
log_level="info"
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 任务 4.1.5: `src/utils/metrics.py` - 评估指标
|
||||||
|
|
||||||
|
```python
|
||||||
|
"""
|
||||||
|
模型评估指标工具
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from sklearn.metrics import (
|
||||||
|
accuracy_score,
|
||||||
|
precision_recall_fscore_support,
|
||||||
|
confusion_matrix,
|
||||||
|
classification_report
|
||||||
|
)
|
||||||
|
from typing import List, Dict, Any
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import seaborn as sns
|
||||||
|
|
||||||
|
|
||||||
|
class ClassificationMetrics:
|
||||||
|
"""分类评估指标"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def compute_all(y_true: List, y_pred: List, labels: List[str]) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
计算所有指标
|
||||||
|
"""
|
||||||
|
accuracy = accuracy_score(y_true, y_pred)
|
||||||
|
|
||||||
|
precision, recall, f1, support = precision_recall_fscore_support(
|
||||||
|
y_true, y_pred, average='weighted', zero_division=0
|
||||||
|
)
|
||||||
|
|
||||||
|
# 每个类别的指标
|
||||||
|
precision_per_class, recall_per_class, f1_per_class, support_per_class = \
|
||||||
|
precision_recall_fscore_support(y_true, y_pred, average=None, zero_division=0)
|
||||||
|
|
||||||
|
per_class_metrics = {}
|
||||||
|
for i, label in enumerate(labels):
|
||||||
|
per_class_metrics[label] = {
|
||||||
|
'precision': float(precision_per_class[i]),
|
||||||
|
'recall': float(recall_per_class[i]),
|
||||||
|
'f1': float(f1_per_class[i]),
|
||||||
|
'support': int(support_per_class[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
'accuracy': float(accuracy),
|
||||||
|
'precision': float(precision),
|
||||||
|
'recall': float(recall),
|
||||||
|
'f1': float(f1),
|
||||||
|
'per_class': per_class_metrics
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def plot_confusion_matrix(y_true: List, y_pred: List, labels: List[str], save_path: str = None):
|
||||||
|
"""
|
||||||
|
绘制混淆矩阵
|
||||||
|
"""
|
||||||
|
cm = confusion_matrix(y_true, y_pred)
|
||||||
|
|
||||||
|
plt.figure(figsize=(10, 8))
|
||||||
|
sns.heatmap(
|
||||||
|
cm,
|
||||||
|
annot=True,
|
||||||
|
fmt='d',
|
||||||
|
cmap='Blues',
|
||||||
|
xticklabels=labels,
|
||||||
|
yticklabels=labels
|
||||||
|
)
|
||||||
|
plt.xlabel('预测标签')
|
||||||
|
plt.ylabel('真实标签')
|
||||||
|
plt.title('混淆矩阵')
|
||||||
|
|
||||||
|
if save_path:
|
||||||
|
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def print_report(y_true: List, y_pred: List, labels: List[str]):
|
||||||
|
"""
|
||||||
|
打印分类报告
|
||||||
|
"""
|
||||||
|
report = classification_report(
|
||||||
|
y_true, y_pred,
|
||||||
|
target_names=labels,
|
||||||
|
zero_division=0
|
||||||
|
)
|
||||||
|
print(report)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# 测试
|
||||||
|
y_true = ['POLITICS', 'TECHNOLOGY', 'FINANCE', 'POLITICS', 'TECHNOLOGY']
|
||||||
|
y_pred = ['POLITICS', 'TECHNOLOGY', 'FINANCE', 'TECHNOLOGY', 'TECHNOLOGY']
|
||||||
|
labels = ['POLITICS', 'TECHNOLOGY', 'FINANCE']
|
||||||
|
|
||||||
|
metrics = ClassificationMetrics()
|
||||||
|
result = metrics.compute_all(y_true, y_pred, labels)
|
||||||
|
print(result)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 任务 4.1.6: `requirements.txt` - 依赖文件
|
||||||
|
|
||||||
|
```txt
|
||||||
|
# 机器学习模块依赖
|
||||||
|
numpy>=1.24.0
|
||||||
|
pandas>=2.0.0
|
||||||
|
scikit-learn>=1.3.0
|
||||||
|
jieba>=0.42.0
|
||||||
|
joblib>=1.3.0
|
||||||
|
|
||||||
|
# 深度学习
|
||||||
|
torch>=2.0.0
|
||||||
|
transformers>=4.30.0
|
||||||
|
|
||||||
|
# API服务
|
||||||
|
fastapi>=0.100.0
|
||||||
|
uvicorn[standard]>=0.23.0
|
||||||
|
pydantic>=2.0.0
|
||||||
|
|
||||||
|
# 数据可视化
|
||||||
|
matplotlib>=3.7.0
|
||||||
|
seaborn>=0.12.0
|
||||||
|
|
||||||
|
# 工具
|
||||||
|
python-dotenv>=1.0.0
|
||||||
|
pyyaml>=6.0
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 总结
|
||||||
|
|
||||||
|
### 开发顺序建议
|
||||||
|
|
||||||
|
1. **第一阶段:基础框架**
|
||||||
|
- 后端:数据库连接、实体类、基础配置
|
||||||
|
- 前端:路由配置、状态管理、API封装
|
||||||
|
|
||||||
|
2. **第二阶段:核心功能**
|
||||||
|
- 爬虫模块(Python)
|
||||||
|
- 传统机器学习分类器
|
||||||
|
- 后端API接口
|
||||||
|
- 前端新闻列表页面
|
||||||
|
|
||||||
|
3. **第三阶段:高级功能**
|
||||||
|
- BERT深度学习分类器
|
||||||
|
- 混合策略分类器
|
||||||
|
- 前端分类器对比页面
|
||||||
|
- 统计图表
|
||||||
|
|
||||||
|
4. **第四阶段:完善优化**
|
||||||
|
- 用户认证
|
||||||
|
- 数据可视化
|
||||||
|
- 性能优化
|
||||||
|
- 异常处理
|
||||||
|
|
||||||
|
### 关键注意事项
|
||||||
|
|
||||||
|
1. **爬虫模块使用 Python**,通过 RESTful API 与 Java 后端通信
|
||||||
|
2. **分类器模块独立部署**,提供 HTTP 接口供后端调用
|
||||||
|
3. **前后端分离**,使用 JWT 进行身份认证
|
||||||
|
4. **数据库表结构**已在 `schema.sql` 中定义,需严格遵守
|
||||||
|
5. **API 统一响应格式**使用 `Result<T>` 包装
|
||||||
Loading…
Reference in New Issue