diff --git a/crawler-module/sina-auto.txt b/crawler-module/sina-auto.txt
new file mode 100644
index 0000000..0bac1d5
--- /dev/null
+++ b/crawler-module/sina-auto.txt
@@ -0,0 +1,82 @@
+
+这是爬取新浪网汽车相关新闻的代码
+```python
+import requests
+from bs4 import BeautifulSoup
+
+
+URL = "https://auto.sina.com.cn/"
+
+headers = {
+    "User-Agent": (
+        "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
+        "AppleWebKit/537.36 (KHTML, like Gecko) "
+        "Chrome/120.0.0.0 Safari/537.36"
+    )
+}
+resp = requests.get(URL, headers=headers, timeout=10)
+# resp.raise_for_status()
+# resp.encoding = "utf-8"
+# print(resp.text)
+with open("example/example-10.html", "r", encoding="utf-8") as f:
+    html = f.read()
+
+# soup = BeautifulSoup(resp.text,"lxml")
+soup = BeautifulSoup(html, "lxml")
+div_list = soup.select("div.feed_card.ty-feed-card-container div.cardlist-a__list div.ty-card.ty-card-type1")
+
+for item in div_list:
+    a = item.select_one("div.ty-card-l a")
+    href = a.get("href")
+    # print(a.get('href'),a.get_text().strip())
+
+    resp = requests.get(url=href, headers=headers, timeout=10)
+    resp.encoding = resp.apparent_encoding  # requests 会尝试猜测编码
+    soup = BeautifulSoup(resp.text, "lxml")
+    # 获取文章标题
+    article_title_tag = soup.select_one("div.main-content h1.main-title")
+    if article_title_tag:
+        article_title = article_title_tag.get_text(strip=True)
+        if not article_title:
+            article_title = "未知标题"
+    else:
+        article_title = "未知标题"
+    # print("标题:", article_title)
+    # 获取文章发布时间
+    from datetime import datetime
+
+    # 日期时间格式化函数
+    def normalize_time(time_str):
+        for fmt in ("%Y年%m月%d日 %H:%M", "%Y-%m-%d %H:%M:%S"):
+            try:
+                dt = datetime.strptime(time_str, fmt)
+                return dt.strftime("%Y-%m-%d %H:%M:%S")
+            except ValueError:
+                continue
+        return time_str  # 如果都不匹配,返回原字符串
+
+    time_tag = soup.select_one("div.main-content div.top-bar-wrap div.date-source span.date")
+    if time_tag:  # 只有存在时间标签才进行格式化
+        publish_time = normalize_time(time_tag.get_text(strip=True))
+    else:
+        publish_time = "1949-01-01 12:00:00"
+    # print(publish_time)
+
+    # 获取文章作者
+    author_tag = soup.select_one("div.main-content div.top-bar-wrap div.date-source a")
+    if author_tag:
+        author = author_tag.get_text(strip=True)
+    else:
+        author = "未知"
+    # print(author)
+    # 获取文章正文段落
+    article_div = soup.select_one("div.main-content div.article")  # 核心文章容器
+    if not article_div:
+        # print("不是文章详情页,跳过")
+        continue  # 如果不是详情页就跳过
+    paragraphs = article_div.find_all('p')
+    article_text = '\n'.join(p.get_text(strip=True) for p in paragraphs if p.get_text(strip=True))
+    # print("正文:\n", article_text)
+
+
+```
diff --git a/ml-module/.claude/settings.local.json b/ml-module/.claude/settings.local.json
new file mode 100644
index 0000000..c01124d
--- /dev/null
+++ b/ml-module/.claude/settings.local.json
@@ -0,0 +1,7 @@
+{
+  "permissions": {
+    "allow": [
+      "Bash(mkdir:*)"
+    ]
+  }
+}
diff --git a/ml-module/README.md b/ml-module/README.md
new file mode 100644
index 0000000..dfd0d03
--- /dev/null
+++ b/ml-module/README.md
@@ -0,0 +1,197 @@
+# 新闻文本分类系统 - 机器学习模块
+
+## 功能特性
+
+### GPU/CPU自动检测
+- 自动检测可用GPU(包括8GB显存)
+- 自动回退到CPU(如果GPU不可用)
+- 显示设备信息和显存信息
+
+### 动态参数调整
+- 通过配置文件调整训练参数
+- 支持命令行参数覆盖
+- 混合精度自动检测
+
+## 使用方法
+
+### 1. 安装依赖
+```bash
+pip install -r requirements.txt
+```
+
+### 从 MySQL 拉取训练数据
+
+```bash
+python .\src\utils\data_loader.py
+```
+
+### 2. 配置训练参数
+编辑 `config.yaml` 文件调整训练参数:
+```yaml
+database:
+  url: "mysql+pymysql://root:root@localhost/news_classifier"
+  data_limit: 1000
+
+model:
+  name: "bert-base-chinese"
+  num_labels: 9
+  output_dir: "./models/deep_learning/bert_finetuned"
+
+training:
+  use_gpu: true  # 自动检测GPU
+  epochs: 3
+  batch_size: 8
+  learning_rate: 2e-5
+  warmup_steps: 500
+  weight_decay: 0.01
+```
+
+### 3. 训练BERT模型
+```bash
+python train_bert.py
+```
+
+### 4. 使用命令行参数覆盖配置
+```bash
+# 使用GPU训练
+python train_bert.py --use_gpu
+
+# 调整训练轮数
+python train_bert.py --epochs 5
+
+# 调整批大小
+python train_bert.py --batch_size 16
+```
+
+### 5. 启动API服务
+```bash
+python src/api/server.py
+```
+
+## 参数说明
+
+### 训练参数
+- `epochs`: 训练轮数(默认:3)
+- `batch_size`: 批大小(默认:8)
+- `learning_rate`: 学习率(默认:2e-5)
+- `warmup_steps`: 预热步数(默认:500)
+- `weight_decay`: 权重衰减(默认:0.01)
+
+### 设备配置
+- `use_gpu`: 是否使用GPU(自动检测)
+- `fp16`: 混合精度(自动检测)
+
+## 8GB显存优化建议
+
+对于8GB显存的GPU,建议配置:
+```yaml
+training:
+  use_gpu: true
+  epochs: 3
+  batch_size: 16   # 可在 8~16 之间按显存调整
+  fp16: true       # 启用混合精度
+```
+
+## 设备检测
+训练时会自动检测设备:
+- **GPU可用**:使用GPU训练,显示GPU名称和显存信息
+- **GPU不可用**:自动回退到CPU训练
+
+## 性能优化
+- 混合精度训练(FP16)
+- 梯度累积
+- 自动批大小调整
+- 内存优化
+
+## 注意事项
+1. 确保安装了CUDA和cuDNN(如果使用GPU)
+2. 8GB显存可以训练BERT-base模型
+3. 可以通过调整batch_size适应不同显存大小
+4. 训练时间取决于数据量和硬件配置
+
+## API接口
+```bash
+# 单条预测
+POST /api/predict
+{
+    "title": "新闻标题",
+    "content": "新闻内容",
+    "mode": "hybrid"  # traditional, hybrid
+}
+
+# 批量预测
+POST /api/batch-predict
+[
+    {
+        "title": "新闻标题1",
+        "content": "新闻内容1",
+        "mode": "hybrid"
+    },
+    {
+        "title": "新闻标题2",
+        "content": "新闻内容2",
+        "mode": "traditional"
+    }
+]
+```
+
+---
+
+## **分组方案**
+
+### **1️⃣ 基础科学计算 / 机器学习**
+
+```cmd
+pip install "numpy>=1.24.0" "pandas>=2.0.0" "scikit-learn>=1.3.0" "joblib>=1.3.0"
+```
+
+💡 版本约束要用引号包住,否则 CMD 会把 `>=` 中的 `>` 解析为输出重定向。
+
+---
+
+### **2️⃣ 深度学习 / NLP**
+
+```cmd
+pip install "torch>=2.0.0" "transformers>=4.30.0" "jieba>=0.42.0"
+```
+
+💡 如果你有 **GPU** 并希望安装 GPU 版 PyTorch,需要单独去 PyTorch 官网生成命令。
+
+---
+
+### **3️⃣ API 服务**
+
+```cmd
+pip install "fastapi>=0.100.0" "uvicorn[standard]>=0.23.0" "pydantic>=2.0.0"
+```
+
+---
+
+### **4️⃣ 数据库相关**
+
+```cmd
+pip install "sqlalchemy>=2.0.0" "pymysql>=1.1.0"
+```
+
+---
+
+### **5️⃣ 数据可视化**
+
+```cmd
+pip install "matplotlib>=3.7.0" "seaborn>=0.12.0"
+```
+
+---
+
+### **6️⃣ 工具 / 配置文件处理**
+
+```cmd
+pip install "python-dotenv>=1.0.0" "pyyaml>=6.0"
+```
+
+---
+
+✅ **这样分批安装的好处**:
+
+1. 出现安装错误时,更容易定位是哪个模块有问题。
+2. 每组依赖关系相对独立,减少冲突。
+3. 
CMD 一次执行一条命令,不需要处理复杂的换行符或引号问题。 + diff --git a/ml-module/config.yaml b/ml-module/config.yaml new file mode 100644 index 0000000..2e594a5 --- /dev/null +++ b/ml-module/config.yaml @@ -0,0 +1,33 @@ +# BERT模型训练配置 +database: + url: "mysql+pymysql://root:root@localhost/news_classifier" # 数据库连接URL + data_limit: 1000 # 加载数据量限制 + +model: + name: "bert-base-chinese" # 模型名称 + num_labels: 9 # 分类数量 + output_dir: "./models/deep_learning/bert_finetuned" # 模型输出目录 + +training: + use_gpu: true # 是否使用GPU(自动检测) + epochs: 3 # 训练轮数 + batch_size: 8 # 训练批大小 + learning_rate: 2e-5 # 学习率 + warmup_steps: 500 # 预热步数 + weight_decay: 0.01 # 权重衰减 + fp16: null # 混合精度(null表示自动检测) + +# 日志和输出配置 +logging: + level: "INFO" # 日志级别 + file: "./training.log" # 日志文件 + +# 设备配置 +device: + max_memory: "8GB" # 最大内存限制 + gradient_checkpointing: true # 梯度检查点 + +# 性能优化 +optimization: + gradient_accumulation_steps: 1 # 梯度累积步数 + mixed_precision: true # 混合精度(如果支持) \ No newline at end of file diff --git a/ml-module/data/news_stopwords.txt b/ml-module/data/news_stopwords.txt new file mode 100644 index 0000000..0b9800e --- /dev/null +++ b/ml-module/data/news_stopwords.txt @@ -0,0 +1,160 @@ +的 +了 +是 +在 +有 +和 +与 +及 +或 +而 +但 +被 +把 +将 +对 +于 +中 +上 +下 +内 +外 +等 +为 +以 +从 +到 +由 +就 +也 +都 +还 +又 +很 +并 +及其 +以及 +之一 +一些 +一个 +一种 +多种 +各自 +各个 +各类 +记者 +通讯员 +编辑 +报道 +表示 +指出 +认为 +称 +说 +介绍 +透露 +强调 +分析 +称之为 +近日 +今日 +今天 +昨日 +目前 +当前 +今年 +去年 +明年 +此前 +之后 +当日 +当天 +近日来 +近来 +近期 +未来 +方面 +情况 +问题 +工作 +任务 +活动 +会议 +会议上 +会议中 +会议期间 +会议指出 +相关 +有关 +一定 +一些 +部分 +整体 +进一步 +持续 +不断 +不断地 +继续 +推进 +加强 +提升 +改善 +推动 +加快 +实现 +完成 +开展 +进行 +同时 +此外 +因此 +所以 +但是 +然而 +不过 +如果 +因为 +虽然 +由于 +其中 +其中之一 +对此 +对此次 +对此前 +对此后 +通过 +按照 +根据 +依据 +围绕 +围绕着 +针对 +关于 +对于 +面对 +着力 +积极 +主动 +有效 +充分 +全面 +已经 +正在 +正在进行 +正在推进 +正在开展 +开始 +结束 +完成后 +之后 +以来 +以来的 +相关人士 +业内人士 +专家表示 +专家认为 +业内认为 +市场认为 +分析人士 +有关人士 +知情人士 \ No newline at end of file diff --git a/ml-module/database.md b/ml-module/database.md new file mode 100644 index 0000000..e997560 --- /dev/null +++ b/ml-module/database.md @@ -0,0 +1,51 @@ + +### 新闻分类表 +```sql +CREATE TABLE news_category ( + id INT NOT NULL AUTO_INCREMENT COMMENT '分类ID', + name VARCHAR(50) NOT NULL COMMENT '分类名称', + PRIMARY KEY (id), + UNIQUE KEY uk_name (name) +) ENGINE=InnoDB +DEFAULT CHARSET=utf8mb4 +COLLATE=utf8mb4_0900_ai_ci +COMMENT='新闻分类表'; +``` +数据: +```text +1 娱乐 +2 体育 +3 财经 +4 科技 +5 军事 +6 汽车 +7 政务 +8 健康 +9 AI +``` + +### 新闻表 +```sql +CREATE TABLE news ( + id BIGINT NOT NULL AUTO_INCREMENT COMMENT '自增主键', + url VARCHAR(500) NOT NULL COMMENT '新闻原始URL', + title VARCHAR(255) NOT NULL COMMENT '新闻标题', + category_id INT NULL COMMENT '新闻分类ID', + publish_time DATETIME NULL COMMENT '发布时间', + author VARCHAR(100) NULL COMMENT '作者/来源', + source VARCHAR(50) NULL COMMENT '新闻来源(网易/36kr)', + content LONGTEXT NOT NULL COMMENT '新闻正文', + content_hash CHAR(64) NOT NULL COMMENT '正文内容hash,用于去重', + created_at TIMESTAMP NULL DEFAULT CURRENT_TIMESTAMP COMMENT '入库时间', + + PRIMARY KEY (id), + UNIQUE KEY uk_url (url), + UNIQUE KEY uk_content_hash (content_hash), + KEY idx_category_id (category_id), + KEY idx_source (source) + +) ENGINE=InnoDB +DEFAULT CHARSET=utf8mb4 +COLLATE=utf8mb4_0900_ai_ci +COMMENT='新闻表'; +``` diff --git a/ml-module/models/traditional/nb_classifier.pkl b/ml-module/models/traditional/nb_classifier.pkl new file mode 100644 index 0000000..8a30c75 Binary files /dev/null and b/ml-module/models/traditional/nb_classifier.pkl differ diff --git a/ml-module/models/traditional/nb_vectorizer.pkl b/ml-module/models/traditional/nb_vectorizer.pkl new file mode 100644 index 0000000..c05705d Binary files /dev/null and 
b/ml-module/models/traditional/nb_vectorizer.pkl differ
diff --git a/ml-module/requirements.txt b/ml-module/requirements.txt
index bf1dffd..de042ae 100644
--- a/ml-module/requirements.txt
+++ b/ml-module/requirements.txt
@@ -1,11 +1,30 @@
 # 机器学习模块依赖
-scikit-learn==1.4.0
-numpy==1.26.3
-pandas==2.1.4
-jieba==0.42.1
-scipy==1.12.0
-joblib==1.3.2
+numpy>=1.24.0
+pandas>=2.0.0
+scikit-learn>=1.3.0
+jieba>=0.42.0
+joblib>=1.3.0
 
-# 深度学习 (可选)
-# torch==2.1.2
-# transformers==4.37.2
+# 深度学习
+torch>=2.0.0
+transformers>=4.30.0
+
+# API服务
+fastapi>=0.100.0
+uvicorn[standard]>=0.23.0
+pydantic>=2.0.0
+
+# 数据库
+sqlalchemy>=2.0.0
+pymysql>=1.1.0
+
+# 数据可视化
+matplotlib>=3.7.0
+seaborn>=0.12.0
+
+# 工具(jieba 已在上方声明,不再重复)
+python-dotenv>=1.0.0
+pyyaml>=6.0
diff --git a/ml-module/src/api/server.py b/ml-module/src/api/server.py
new file mode 100644
index 0000000..e0698aa
--- /dev/null
+++ b/ml-module/src/api/server.py
@@ -0,0 +1,157 @@
+"""
+机器学习模型API服务
+使用FastAPI提供RESTful API
+"""
+
+from fastapi import FastAPI, HTTPException
+from fastapi.middleware.cors import CORSMiddleware
+from pydantic import BaseModel
+from typing import Optional, List
+import logging
+import sys
+import os
+import time
+
+# 导入分类器
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from traditional.predict import TraditionalPredictor
+from hybrid.hybrid_classifier import HybridClassifier
+
+# 配置日志
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# 创建FastAPI应用
+app = FastAPI(
+    title="新闻分类API",
+    description="提供新闻文本分类服务",
+    version="1.0.0"
+)
+
+# 配置CORS
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+
+# 请求模型
+class ClassifyRequest(BaseModel):
+    title: str
+    content: str
+    mode: Optional[str] = 'hybrid'  # traditional, hybrid
+
+# 响应模型
+class ClassifyResponse(BaseModel):
+    categoryCode: str
+    categoryName: str
+    confidence: float
+    classifierType: str
+    duration: int
+    probabilities: Optional[dict] = None
+
+
+# 初始化分类器
+nb_predictor = None
+hybrid_classifier = None
+
+
+@app.on_event("startup")
+async def startup_event():
+    """启动时加载模型"""
+    global nb_predictor, hybrid_classifier
+
+    logger.info("加载模型...")
+
+    try:
+        nb_predictor = TraditionalPredictor('nb')
+        logger.info("朴素贝叶斯模型加载成功")
+    except Exception as e:
+        logger.error(f"朴素贝叶斯模型加载失败: {e}")
+
+    try:
+        hybrid_classifier = HybridClassifier()
+        logger.info("混合分类器初始化成功")
+    except Exception as e:
+        logger.error(f"混合分类器初始化失败: {e}")
+
+
+@app.get("/")
+async def root():
+    """健康检查"""
+    return {
+        "status": "ok",
+        "message": "新闻分类API服务运行中"
+    }
+
+
+@app.get("/health")
+async def health_check():
+    """健康检查"""
+    return {
+        "status": "healthy",
+        "models": {
+            "nb_loaded": nb_predictor is not None,
+            "hybrid_loaded": hybrid_classifier is not None
+        }
+    }
+
+
+@app.post("/api/predict", response_model=ClassifyResponse)
+async def predict(request: ClassifyRequest):
+    """
+    文本分类接口
+
+    - **title**: 新闻标题
+    - **content**: 新闻内容
+    - **mode**: 分类模式 (traditional, hybrid)
+    """
+    try:
+        if request.mode == 'traditional':
+            start_time = time.time()
+            result = nb_predictor.predict(request.title, request.content)
+            result['classifierType'] = 'ML'
+            # 传统模式的结果不含耗时,这里补齐,否则 ClassifyResponse 校验 duration 字段会失败
+            result['duration'] = int((time.time() - start_time) * 1000)
+        else:  # hybrid
+            result = hybrid_classifier.predict(request.title, request.content)
+
+        return ClassifyResponse(**result)
+
+    except Exception as e:
+        logger.error(f"预测失败: {e}")
+        raise HTTPException(status_code=500, detail=str(e))
+
+
+@app.post("/api/batch-predict")
+async def batch_predict(requests: List[ClassifyRequest]):
+    """
+    批量分类接口
+    """
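+    # 逐条串行处理:单条失败不会中断整批,
+    # 失败项以 {'error': ..., 'title': ...} 的形式返回,调用方需检查各结果中的 error 字段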
+    results = []
+    for req in requests:
+        try:
+            if req.mode == 'traditional':
+                result = nb_predictor.predict(req.title, req.content)
+                result['classifierType'] = 'ML'
+            else:
+                result = hybrid_classifier.predict(req.title, req.content)
+            results.append(result)
+        except Exception as e:
+            results.append({
+                'error': str(e),
+                'title': req.title
+            })
+
+    return {"results": results}
+
+
+if __name__ == '__main__':
+    import uvicorn
+
+    uvicorn.run(
+        app,
+        host="0.0.0.0",
+        port=5000,
+        log_level="info"
+    )
\ No newline at end of file
diff --git a/ml-module/src/deep_learning/bert_model.py b/ml-module/src/deep_learning/bert_model.py
new file mode 100644
index 0000000..8685094
--- /dev/null
+++ b/ml-module/src/deep_learning/bert_model.py
@@ -0,0 +1,209 @@
+"""
+BERT文本分类模型
+"""
+
+import torch
+from transformers import (
+    BertTokenizer,
+    BertForSequenceClassification,
+    Trainer,
+    TrainingArguments
+)
+from typing import Dict, Any, List
+
+# 分类映射
+CATEGORY_MAP = {
+    'ENTERTAINMENT': '娱乐',
+    'SPORTS': '体育',
+    'FINANCE': '财经',
+    'TECHNOLOGY': '科技',
+    'MILITARY': '军事',
+    'AUTOMOTIVE': '汽车',
+    'GOVERNMENT': '政务',
+    'HEALTH': '健康',
+    'AI': 'AI'
+}
+
+# 反向映射
+ID_TO_LABEL = {i: label for i, label in enumerate(CATEGORY_MAP.keys())}
+LABEL_TO_ID = {label: i for i, label in enumerate(CATEGORY_MAP.keys())}
+
+
+class BertClassifier:
+    """BERT文本分类器 - 支持GPU/CPU自动检测和动态参数"""
+
+    def __init__(self, model_name='bert-base-chinese', num_labels=9, use_gpu=True):
+        self.model_name = model_name
+        self.num_labels = num_labels
+        self.tokenizer = None
+        self.model = None
+
+        # GPU/CPU自动检测
+        self.use_gpu = use_gpu and torch.cuda.is_available()
+        self.device = torch.device('cuda' if self.use_gpu else 'cpu')
+
+        # 打印设备信息
+        if self.use_gpu:
+            print(f"使用GPU训练: {torch.cuda.get_device_name(0)}")
+            print(f"GPU显存: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
+        else:
+            print("使用CPU训练")
+
+    def load_model(self, model_path):
+        """加载微调后的模型"""
+        self.tokenizer = BertTokenizer.from_pretrained(model_path)
+        self.model = BertForSequenceClassification.from_pretrained(
+            model_path,
+            num_labels=self.num_labels
+        )
+        self.model.to(self.device)
+        self.model.eval()
+        print(f"BERT模型加载成功: {model_path}")
+
+    def train_model(self, train_dataset, eval_dataset, output_dir='./models/deep_learning/bert_finetuned',
+                    num_train_epochs=3, per_device_train_batch_size=8, per_device_eval_batch_size=16,
+                    learning_rate=2e-5, warmup_steps=500, weight_decay=0.01, fp16=None):
+        """
+        训练BERT模型
+
+        :param train_dataset: 训练数据集
+        :param eval_dataset: 验证数据集
+        :param output_dir: 模型输出目录
+        :param num_train_epochs: 训练轮数
+        :param per_device_train_batch_size: 每个设备的训练batch大小
+        :param per_device_eval_batch_size: 每个设备的验证batch大小
+        :param learning_rate: 学习率
+        :param warmup_steps: 预热步数
+        :param weight_decay: 权重衰减
+        :param fp16: 是否使用混合精度(None表示自动检测)
+        """
+        # 训练场景下若尚未调用 load_model,则从预训练权重初始化模型
+        if self.model is None or self.tokenizer is None:
+            self.tokenizer = BertTokenizer.from_pretrained(self.model_name)
+            self.model = BertForSequenceClassification.from_pretrained(
+                self.model_name,
+                num_labels=self.num_labels
+            )
+            self.model.to(self.device)
+
+        # 自动检测是否使用混合精度
+        if fp16 is None:
+            fp16 = self.use_gpu
+
+        # 评估指标:供 metric_for_best_model="accuracy" 选择最优模型使用
+        def compute_metrics(eval_pred):
+            logits, labels = eval_pred
+            preds = logits.argmax(axis=-1)
+            return {"accuracy": float((preds == labels).mean())}
+
+        # 配置训练参数
+        training_args = TrainingArguments(
+            output_dir=output_dir,
+            num_train_epochs=num_train_epochs,
+            per_device_train_batch_size=per_device_train_batch_size,
+            per_device_eval_batch_size=per_device_eval_batch_size,
+            learning_rate=learning_rate,
+            warmup_steps=warmup_steps,
+            weight_decay=weight_decay,
+            logging_dir='./logs',
+            logging_steps=10,
+            evaluation_strategy="epoch",
+            save_strategy="epoch",
+            load_best_model_at_end=True,
+            metric_for_best_model="accuracy",
+            fp16=fp16,  # 混合精度
+            gradient_accumulation_steps=1,
+        )
+
+        # 创建训练器
+        trainer = Trainer(
+            model=self.model,
+            args=training_args,
+            compute_metrics=compute_metrics,
+            train_dataset=train_dataset,
+            
eval_dataset=eval_dataset + ) + + # 开始训练 + print("开始训练BERT模型...") + trainer.train() + print("训练完成!") + + # 保存模型 + trainer.save_model(output_dir) + self.tokenizer.save_pretrained(output_dir) + print(f"模型已保存到: {output_dir}") + + def predict(self, title: str, content: str) -> Dict[str, Any]: + """ + 预测 + """ + if self.model is None or self.tokenizer is None: + raise ValueError("模型未加载,请先调用load_model") + + # 组合标题和内容 + text = f"{title} [SEP] {content}" + + # 分词 + inputs = self.tokenizer( + text, + return_tensors='pt', + truncation=True, + max_length=512, + padding='max_length' + ) + + # 预测 + with torch.no_grad(): + inputs = {k: v.to(self.device) for k, v in inputs.items()} + outputs = self.model(**inputs) + logits = outputs.logits + + # 获取预测结果 + probs = torch.softmax(logits, dim=-1) + confidence, predicted_id = torch.max(probs, dim=-1) + + predicted_id = predicted_id.item() + confidence = confidence.item() + + # 获取各类别概率 + prob_dict = {} + for i, prob in enumerate(probs[0].cpu().numpy()): + category_code = ID_TO_LABEL[i] + prob_dict[category_code] = float(prob) + + return { + 'categoryCode': ID_TO_LABEL[predicted_id], + 'categoryName': CATEGORY_MAP.get(ID_TO_LABEL[predicted_id], '未知'), + 'confidence': confidence, + 'probabilities': prob_dict + } + + +# 数据集类 +class NewsDataset(torch.utils.data.Dataset): + """新闻数据集""" + + def __init__(self, texts, labels, tokenizer, max_length=512): + self.texts = texts + self.labels = labels + self.tokenizer = tokenizer + self.max_length = max_length + + def __len__(self): + return len(self.texts) + + def __getitem__(self, idx): + text = self.texts[idx] + label = self.labels[idx] + + encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + padding='max_length', + return_tensors='pt' + ) + + return { + 'input_ids': encoding['input_ids'].flatten(), + 'attention_mask': encoding['attention_mask'].flatten(), + 'labels': torch.tensor(label, dtype=torch.long) + } + + +if __name__ == '__main__': + # 测试 + classifier = BertClassifier() + # classifier.load_model('./models/deep_learning/bert_finetuned') + # + # result = classifier.predict( + # title="华为发布新款折叠屏手机", + # content="华为今天正式发布了新一代折叠屏手机..." 
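+    # 训练用法见 ml-module/train_bert.py:先构造分词后的数据集,再调用 train_model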
+    # )
+    # print(result)
+    print("BERT分类器初始化成功")
\ No newline at end of file
diff --git a/ml-module/src/hybrid/hybrid_classifier.py b/ml-module/src/hybrid/hybrid_classifier.py
new file mode 100644
index 0000000..7e54217
--- /dev/null
+++ b/ml-module/src/hybrid/hybrid_classifier.py
@@ -0,0 +1,148 @@
+"""
+混合策略分类器
+结合规则引擎和机器学习模型
+"""
+
+import time
+from typing import Dict, Any
+from ..traditional.predict import TraditionalPredictor, CATEGORY_MAP
+from ..deep_learning.bert_model import BertClassifier
+
+
+class HybridClassifier:
+    """混合分类器"""
+
+    def __init__(self):
+        # 初始化各个分类器
+        self.nb_predictor = TraditionalPredictor('nb')
+        self.bert_classifier = BertClassifier()
+
+        # 配置参数
+        self.config = {
+            'confidence_threshold': 0.75,   # 高置信度阈值
+            'hybrid_min_confidence': 0.60,  # 混合模式最低阈值
+            'use_bert_threshold': 0.70,     # 使用BERT的阈值
+            'rule_priority': True           # 规则优先
+        }
+
+        # 规则关键词字典
+        self.rule_keywords = {
+            'ENTERTAINMENT': ['明星', '电影', '电视剧', '娱乐圈', '歌手', '娱乐'],
+            'SPORTS': ['比赛', '冠军', '联赛', '球员', '教练', 'NBA', '足球', '篮球'],
+            'FINANCE': ['股市', '经济', '金融', '投资', '基金', '银行', '理财'],
+            'TECHNOLOGY': ['芯片', 'AI', '人工智能', '5G', '互联网', '科技', '数码'],
+            'MILITARY': ['军事', '武器', '军队', '国防', '战争'],
+            'AUTOMOTIVE': ['汽车', '车', '车型', '驾驶', '交通'],
+            'GOVERNMENT': ['政府', '政策', '选举', '国务院', '主席', '总理', '政务'],
+            'HEALTH': ['健康', '医疗', '疾病', '治疗', '疫苗'],
+            'AI': ['AI', '人工智能', '机器学习', '深度学习', '算法']
+        }
+
+    def rule_match(self, title: str, content: str) -> tuple[str | None, float]:
+        """
+        规则匹配
+        :return: (category_code, confidence)
+        """
+        text = title + ' ' + content
+
+        # 计算每个类别的关键词匹配数
+        matches = {}
+        for category, keywords in self.rule_keywords.items():
+            count = sum(1 for kw in keywords if kw in text)
+            if count > 0:
+                matches[category] = count
+
+        if not matches:
+            return None, 0.0
+
+        # 返回匹配最多的类别
+        best_category = max(matches, key=matches.get)
+        confidence = min(0.9, matches[best_category] * 0.15)  # 规则置信度
+
+        return best_category, confidence
+
+    def predict(self, title: str, content: str, use_bert=True) -> Dict[str, Any]:
+        """
+        混合预测
+        """
+        start_time = time.time()
+
+        # 1. 先尝试规则匹配
+        rule_category, rule_confidence = self.rule_match(title, content)
+
+        # 2. 传统机器学习预测
+        nb_result = self.nb_predictor.predict(title, content)
+        nb_confidence = nb_result['confidence']
+
+        # 决策逻辑
+        final_result = None
+
+        # 规则优先且规则置信度高
+        if self.config['rule_priority'] and rule_confidence >= self.config['confidence_threshold']:
+            final_result = {
+                'categoryCode': rule_category,
+                'categoryName': CATEGORY_MAP.get(rule_category, '未知'),  # 按规则命中的类别从映射获取,而非沿用 ML 结果
+                'confidence': rule_confidence,
+                'classifierType': 'RULE',
+                'reason': '规则匹配'
+            }
+        # 传统模型置信度足够高
+        elif nb_confidence >= self.config['confidence_threshold']:
+            final_result = {
+                **nb_result,
+                'classifierType': 'ML',
+                'reason': '传统模型高置信度'
+            }
+        # 需要使用BERT
+        elif use_bert:
+            # TODO: 加载BERT模型预测
+            # bert_result = self.bert_classifier.predict(title, content)
+            # 如果BERT置信度也不高,选择最高的
+            final_result = {
+                **nb_result,
+                'classifierType': 'HYBRID',
+                'reason': '混合决策'
+            }
+        else:
+            # 不使用BERT,直接返回传统模型结果
+            final_result = {
+                **nb_result,
+                'classifierType': 'ML',
+                'reason': '默认传统模型'
+            }
+
+        # 计算耗时
+        duration = int((time.time() - start_time) * 1000)
+        final_result['duration'] = duration
+
+        return final_result
+
+
+if __name__ == '__main__':
+    # 测试
+    classifier = HybridClassifier()
+
+    test_cases = [
+        {
+            'title': '国务院发布最新经济政策',
+            'content': '国务院今天发布了新的经济政策...'
+        },
+        {
+            'title': '华为发布新款折叠屏手机',
+            'content': '华为今天正式发布了新一代折叠屏手机...'
+        },
+        {
+            'title': 'NBA总决赛精彩瞬间',
+            'content': '湖人队与勇士队的总决赛...'
+        }
+    ]
+
+    for case in test_cases:
+        result = classifier.predict(case['title'], case['content'])
+        print(f"标题: {case['title']}")
+        print(f"结果: {result['categoryName']} ({result['confidence']:.2f})")
+        print(f"分类器: {result['classifierType']}")
+        print(f"原因: {result.get('reason', 'N/A')}")
+        print(f"耗时: {result['duration']}ms")
+        print("-" * 50)
\ No newline at end of file
diff --git a/ml-module/src/traditional/predict.py b/ml-module/src/traditional/predict.py
new file mode 100644
index 0000000..3d11719
--- /dev/null
+++ b/ml-module/src/traditional/predict.py
@@ -0,0 +1,94 @@
+"""
+传统机器学习模型预测
+"""
+
+import os
+import joblib
+import jieba
+from typing import Dict, Any
+
+# 分类映射(根据数据库中的分类)
+CATEGORY_MAP = {
+    'ENTERTAINMENT': '娱乐',
+    'SPORTS': '体育',
+    'FINANCE': '财经',
+    'TECHNOLOGY': '科技',
+    'MILITARY': '军事',
+    'AUTOMOTIVE': '汽车',
+    'GOVERNMENT': '政务',
+    'HEALTH': '健康',
+    'AI': 'AI'
+}
+
+
+class TraditionalPredictor:
+    """传统机器学习预测器"""
+
+    def __init__(self, model_type='nb', model_dir=None):
+        self.model_type = model_type
+        # 默认模型目录相对本文件定位,避免受运行时工作目录影响(如从 ml-module 根目录启动 API 服务)
+        if model_dir is None:
+            model_dir = os.path.join(
+                os.path.dirname(os.path.abspath(__file__)),
+                '..', '..', 'models', 'traditional'
+            )
+        self.model_dir = model_dir
+        self.vectorizer = None
+        self.classifier = None
+        self._load_model()
+
+    def _load_model(self):
+        """加载模型"""
+        vectorizer_path = os.path.join(self.model_dir, f'{self.model_type}_vectorizer.pkl')
+        classifier_path = os.path.join(self.model_dir, f'{self.model_type}_classifier.pkl')
+
+        self.vectorizer = joblib.load(vectorizer_path)
+        self.classifier = joblib.load(classifier_path)
+        print(f"模型加载成功: {self.model_type}")
+
+    def preprocess(self, title: str, content: str) -> str:
+        """预处理文本"""
+        text = title + ' ' + content
+        # jieba分词
+        words = jieba.cut(text)
+        return ' '.join(words)
+
+    def predict(self, title: str, content: str) -> Dict[str, Any]:
+        """
+        预测
+        :return: 预测结果字典
+        """
+        # 预处理
+        processed = self.preprocess(title, content)
+
+        # 特征提取
+        tfidf = self.vectorizer.transform([processed])
+
+        # 预测
+        prediction = self.classifier.predict(tfidf)[0]
+        probabilities = self.classifier.predict_proba(tfidf)[0]
+
+        # 获取各类别概率
+        prob_dict = {}
+        for i, prob in enumerate(probabilities):
+            category_code = self.classifier.classes_[i]
+            prob_dict[category_code] = float(prob)
+
+        return {
+            'categoryCode': prediction,
+            'categoryName': CATEGORY_MAP.get(prediction, '未知'),
+            'confidence': float(probabilities.max()),
+            'probabilities': prob_dict
+        }
+
+
+# API入口
+def predict_single(title: str, content: str, model_type='nb') -> Dict[str, Any]:
+    """
+    单条预测API
+    """
+    predictor = TraditionalPredictor(model_type)
+    return predictor.predict(title, content)
+
+
+if __name__ == '__main__':
+    # 测试
+    result = predict_single(
+        title="华为发布新款折叠屏手机",
+        content="华为今天正式发布了新一代折叠屏手机,搭载最新麒麟芯片..."
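+        # 注:predict_single 每次调用都会重新加载模型;高频调用时建议复用 TraditionalPredictor 实例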
+ ) + print(result) \ No newline at end of file diff --git a/ml-module/src/traditional/train_model.py b/ml-module/src/traditional/train_model.py index 607df4a..7615eab 100644 --- a/ml-module/src/traditional/train_model.py +++ b/ml-module/src/traditional/train_model.py @@ -14,45 +14,95 @@ from sklearn.svm import SVC from sklearn.model_selection import train_test_split from sklearn.metrics import classification_report, accuracy_score, f1_score -# 分类映射 +# 分类映射(与数据库表一致) CATEGORY_MAP = { - 'POLITICS': '时政', + 'ENTERTAINMENT': '娱乐', + 'SPORTS': '体育', 'FINANCE': '财经', 'TECHNOLOGY': '科技', - 'SPORTS': '体育', - 'ENTERTAINMENT': '娱乐', + 'MILITARY': '军事', + 'AUTOMOTIVE': '汽车', + 'GOVERNMENT': '政务', 'HEALTH': '健康', - 'EDUCATION': '教育', - 'LIFE': '生活', - 'INTERNATIONAL': '国际', - 'MILITARY': '军事' + 'AI': 'AI' } REVERSE_CATEGORY_MAP = {v: k for k, v in CATEGORY_MAP.items()} +# 分类专属停用词(与数据库表一致) +CATEGORY_STOPWORDS = { + 'ENTERTAINMENT': {'主演', '影片', '电影', '电视剧', '节目', '导演', '角色', '上映', '粉丝'}, + 'SPORTS': {'比赛', '赛事', '赛季', '球队', '选手', '球员', '主场', '客场', '对阵', '比分'}, + 'FINANCE': {'亿元', '万元', '同比', '环比', '增长率', '数据', '报告', '统计', '财报', '季度', '年度'}, + 'TECHNOLOGY': {'技术', '系统', '平台', '方案', '应用', '功能', '版本', '升级', '研发', '推出'}, + 'MILITARY': {'部队', '军方', '演习', '训练', '装备', '武器', '作战', '行动', '部署'}, + 'AUTOMOTIVE': {'汽车', '车型', '上市', '发布', '销量', '市场', '品牌', '厂商'}, + 'GOVERNMENT': {'会议', '讲话', '指出', '强调', '部署', '落实', '推进', '要求', '精神', '决定', '意见', '方案', '安排'}, + 'HEALTH': {'医生', '专家', '建议', '提示', '提醒', '研究', '发现', '可能', '有助于'}, + 'AI': {'技术', '系统', '模型', '算法', '应用', '功能', '版本', '升级', '研发', '推出', '人工智能'} +} + + class NewsClassifier: """新闻文本分类器""" - def __init__(self, model_type='nb'): + def __init__(self, model_type='nb', use_stopwords=True, use_category_stopwords=False): """ 初始化分类器 :param model_type: 模型类型 'nb' 朴素贝叶斯 或 'svm' 支持向量机 + :param use_stopwords: 是否使用通用停用词 + :param use_category_stopwords: 是否使用分类专属停用词 """ self.model_type = model_type self.vectorizer = None self.classifier = None self.categories = list(CATEGORY_MAP.keys()) + self.use_stopwords = use_stopwords + self.use_category_stopwords = use_category_stopwords + self.stopwords = set() - def preprocess_text(self, text): + if self.use_stopwords: + self._load_stopwords() + + def _load_stopwords(self, stopwords_path='../../data/news_stopwords.txt'): """ - 文本预处理:使用jieba分词 + 加载停用词表 + """ + try: + with open(stopwords_path, 'r', encoding='utf-8') as f: + self.stopwords = set(line.strip() for line in f if line.strip()) + print(f"已加载 {len(self.stopwords)} 个停用词") + except FileNotFoundError: + print(f"警告: 停用词文件不存在: {stopwords_path}") + + def preprocess_text(self, text, category=None): + """ + 文本预处理:使用jieba分词 + 停用词过滤 + :param text: 待处理文本 + :param category: 可选,指定分类时使用分类专属停用词 """ # 移除多余空格和换行 text = ' '.join(text.split()) # jieba分词 words = jieba.cut(text) - return ' '.join(words) + + # 过滤停用词和单字词 + result = [] + for w in words: + # 过滤单字词 + if len(w) <= 1: + continue + # 过滤通用停用词 + if self.use_stopwords and w in self.stopwords: + continue + # 过滤分类专属停用词 + if self.use_category_stopwords and category: + if w in CATEGORY_STOPWORDS.get(category, set()): + continue + result.append(w) + + return ' '.join(result) def load_data(self, csv_path): """ @@ -61,10 +111,22 @@ class NewsClassifier: df = pd.read_csv(csv_path) # 合并标题和内容作为特征 df['text'] = df['title'] + ' ' + df['content'] - # 预处理 - df['processed_text'] = df['text'].apply(self.preprocess_text) - # 转换分类名称为代码 - df['category_code'] = df['category'].map(REVERSE_CATEGORY_MAP) + + # 预处理(如果启用了分类专属停用词,需要传入分类信息) + if 
self.use_category_stopwords:
+            # 先转换为分类代码
+            df['category_code'] = df['category_name'].map(REVERSE_CATEGORY_MAP)
+            # 逐行预处理,传入分类信息
+            df['processed_text'] = df.apply(
+                lambda row: self.preprocess_text(row['text'], row['category_code']),
+                axis=1
+            )
+        else:
+            # 不使用分类专属停用词,直接批量预处理
+            df['processed_text'] = df['text'].apply(self.preprocess_text)
+            # 转换分类名称为代码
+            df['category_code'] = df['category_name'].map(REVERSE_CATEGORY_MAP)
+
         return df
 
     def train(self, df):
@@ -74,6 +136,10 @@ class NewsClassifier:
         X = df['processed_text'].values
         y = df['category_code'].values
 
+        # 获取实际数据中的分类
+        actual_categories = sorted(df['category_code'].unique().tolist())
+        actual_category_names = [CATEGORY_MAP[cat] for cat in actual_categories]
+
         # 划分训练集和测试集
         X_train, X_test, y_train, y_test = train_test_split(
             X, y, test_size=0.2, random_state=42, stratify=y
@@ -108,7 +174,7 @@ class NewsClassifier:
         print(f"准确率: {accuracy:.4f}")
         print(f"F1-Score: {f1:.4f}")
         print("\n分类报告:")
-        print(classification_report(y_test, y_pred, target_names=self.categories))
+        print(classification_report(y_test, y_pred, target_names=actual_category_names))
 
         return accuracy, f1
 
@@ -120,6 +186,7 @@ class NewsClassifier:
             raise ValueError("模型未训练,请先调用train方法")
 
         text = title + ' ' + content
+        # 预测时不指定分类,只使用通用停用词
        processed = self.preprocess_text(text)
         tfidf = self.vectorizer.transform([processed])
 
@@ -157,12 +224,12 @@ if __name__ == '__main__':
     classifier = NewsClassifier(model_type='nb')
 
     # 假设有训练数据文件
-    train_data_path = '../data/processed/training_data.csv'
+    train_data_path = '../../data/processed/training_data.csv'
 
     if os.path.exists(train_data_path):
         df = classifier.load_data(train_data_path)
         classifier.train(df)
-        classifier.save_model('../models')
+        classifier.save_model('../../models/traditional')
 
         # 测试预测
         test_title = "华为发布新款折叠屏手机"
diff --git a/ml-module/src/utils/data_loader.py b/ml-module/src/utils/data_loader.py
new file mode 100644
index 0000000..1ba76ba
--- /dev/null
+++ b/ml-module/src/utils/data_loader.py
@@ -0,0 +1,163 @@
+"""
+数据加载和本地存储工具
+"""
+
+import pandas as pd
+import os
+from sqlalchemy import create_engine
+from typing import List, Dict, Any, Optional
+import logging
+
+logger = logging.getLogger(__name__)
+
+class DataLoader:
+    """数据加载和本地存储类"""
+
+    def __init__(self, db_url: str, data_dir: str = 'data'):
+        """
+        初始化数据加载器
+
+        :param db_url: 数据库连接URL
+        :param data_dir: 数据存储目录
+        """
+        self.db_url = db_url
+        self.data_dir = data_dir
+        self.engine = create_engine(db_url)
+
+        # 确保数据目录存在
+        os.makedirs(os.path.join(data_dir, 'raw'), exist_ok=True)
+        os.makedirs(os.path.join(data_dir, 'processed'), exist_ok=True)
+
+    def load_news_from_db(self, limit: int = None, category_ids: List[int] = None) -> pd.DataFrame:
+        """
+        从数据库加载新闻数据
+
+        :param limit: 限制加载数据量
+        :param category_ids: 指定分类ID列表
+        :return: 新闻数据DataFrame
+        """
+        query = """
+            SELECT n.id, n.title, n.content, c.name as category_name
+            FROM news n
+            LEFT JOIN news_category c ON n.category_id = c.id
+            WHERE n.content IS NOT NULL AND n.title IS NOT NULL
+        """
+
+        if category_ids:
+            ids_str = ",".join(str(i) for i in category_ids)
+            query += f" AND c.id IN ({ids_str})"
+
+        if limit:
+            query += f" LIMIT {limit}"
+
+        logger.info(f"从数据库加载数据,条件: limit={limit}, category_ids={category_ids}")
+
+        try:
+            df = pd.read_sql(query, self.engine)
+            logger.info(f"成功加载 {len(df)} 条新闻数据")
+            return df
+        except Exception as e:
+            logger.error(f"加载数据失败: {e}")
+            raise
+
+    def save_to_local(self, df: pd.DataFrame, filename: str, subdir: str = 'processed'):
+        """
+        
将数据保存到本地 + + :param df: 要保存的DataFrame + :param filename: 文件名 + :param subdir: 子目录 + """ + file_path = os.path.join(self.data_dir, subdir, filename) + + # 确保目录存在 + os.makedirs(os.path.join(self.data_dir, subdir), exist_ok=True) + + # 保存为CSV + df.to_csv(file_path, index=False, encoding='utf-8-sig') + logger.info(f"数据已保存到: {file_path}") + + def load_from_local(self, filename: str, subdir: str = 'processed') -> Optional[pd.DataFrame]: + """ + 从本地加载已保存的数据 + + :param filename: 文件名 + :param subdir: 子目录 + :return: DataFrame或None(如果文件不存在) + """ + file_path = os.path.join(self.data_dir, subdir, filename) + + if os.path.exists(file_path): + df = pd.read_csv(file_path, encoding='utf-8-sig') + logger.info(f"从本地加载 {len(df)} 条数据: {file_path}") + return df + else: + logger.warning(f"本地文件不存在: {file_path}") + return None + + def update_local_data(self, limit: int = None, category_ids: List[int] = None): + """ + 更新本地数据(从数据库重新加载数据并保存) + + :param limit: 限制加载数据量 + :param category_ids: 指定分类ID列表 + """ + # 从数据库加载数据 + df = self.load_news_from_db(limit=limit, category_ids=category_ids) + + if not df.empty: + # 生成文件名(包含时间戳) + import time + timestamp = time.strftime("%Y%m%d_%H%M%S") + filename = f"news_data_{timestamp}.csv" + + # 保存到本地 + self.save_to_local(df, filename) + + # 可选:删除旧的本地文件,只保留最新的 + self._cleanup_old_files(filename) + + return df + else: + logger.warning("没有可用的数据需要保存") + return None + + def _cleanup_old_files(self, current_filename: str): + """ + 清理旧的本地文件(可选) + + :param current_filename: 当前文件名 + """ + subdir = 'processed' + dir_path = os.path.join(self.data_dir, subdir) + + if os.path.exists(dir_path): + for filename in os.listdir(dir_path): + if filename != current_filename and filename.startswith('news_data_'): + file_path = os.path.join(dir_path, filename) + try: + os.remove(file_path) + logger.info(f"已删除旧文件: {file_path}") + except Exception as e: + logger.error(f"删除文件失败 {file_path}: {e}") + + +# 示例用法 +if __name__ == '__main__': + # 配置数据库连接 + db_url = "mysql+pymysql://root:root@localhost/news" + + # 创建数据加载器 + loader = DataLoader(db_url) + + # 从数据库加载数据并保存到本地 + data = loader.update_local_data(limit=800) + + if data is not None: + print(f"成功加载数据,共 {len(data)} 条记录") + print(data.head()) + else: + print("没有数据可加载") \ No newline at end of file diff --git a/ml-module/src/utils/metrics.py b/ml-module/src/utils/metrics.py new file mode 100644 index 0000000..6c0f69d --- /dev/null +++ b/ml-module/src/utils/metrics.py @@ -0,0 +1,96 @@ +""" +模型评估指标工具 +""" + +import numpy as np +from sklearn.metrics import ( + accuracy_score, + precision_recall_fscore_support, + confusion_matrix, + classification_report +) +from typing import List, Dict, Any +import matplotlib.pyplot as plt +import seaborn as sns + +class ClassificationMetrics: + """分类评估指标""" + + @staticmethod + def compute_all(y_true: List, y_pred: List, labels: List[str]) -> Dict[str, Any]: + """ + 计算所有指标 + """ + accuracy = accuracy_score(y_true, y_pred) + + precision, recall, f1, support = precision_recall_fscore_support( + y_true, y_pred, average='weighted', zero_division=0 + ) + + # 每个类别的指标 + precision_per_class, recall_per_class, f1_per_class, support_per_class = \ + precision_recall_fscore_support(y_true, y_pred, average=None, zero_division=0) + + per_class_metrics = {} + for i, label in enumerate(labels): + per_class_metrics[label] = { + 'precision': float(precision_per_class[i]), + 'recall': float(recall_per_class[i]), + 'f1': float(f1_per_class[i]), + 'support': int(support_per_class[i]) + } + + return { + 'accuracy': float(accuracy), + 'precision': 
float(precision),
+            'recall': float(recall),
+            'f1': float(f1),
+            'per_class': per_class_metrics
+        }
+
+    @staticmethod
+    def plot_confusion_matrix(y_true: List, y_pred: List, labels: List[str], save_path: str = None):
+        """
+        绘制混淆矩阵
+        """
+        cm = confusion_matrix(y_true, y_pred)
+
+        plt.figure(figsize=(10, 8))
+        sns.heatmap(
+            cm,
+            annot=True,
+            fmt='d',
+            cmap='Blues',
+            xticklabels=labels,
+            yticklabels=labels
+        )
+        plt.xlabel('预测标签')
+        plt.ylabel('真实标签')
+        plt.title('混淆矩阵')
+
+        if save_path:
+            plt.savefig(save_path, dpi=300, bbox_inches='tight')
+        plt.close()
+
+    @staticmethod
+    def print_report(y_true: List, y_pred: List, labels: List[str]):
+        """
+        打印分类报告
+        """
+        report = classification_report(
+            y_true, y_pred,
+            target_names=labels,
+            zero_division=0
+        )
+        print(report)
+
+
+if __name__ == '__main__':
+    # 测试
+    y_true = ['ENTERTAINMENT', 'TECHNOLOGY', 'FINANCE', 'ENTERTAINMENT', 'TECHNOLOGY']
+    y_pred = ['ENTERTAINMENT', 'TECHNOLOGY', 'FINANCE', 'TECHNOLOGY', 'TECHNOLOGY']
+    labels = ['ENTERTAINMENT', 'TECHNOLOGY', 'FINANCE']
+
+    metrics = ClassificationMetrics()
+    result = metrics.compute_all(y_true, y_pred, labels)
+    print(result)
\ No newline at end of file
diff --git a/ml-module/train_bert.py b/ml-module/train_bert.py
new file mode 100644
index 0000000..8a75cfc
--- /dev/null
+++ b/ml-module/train_bert.py
@@ -0,0 +1,114 @@
+"""
+BERT模型训练脚本
+支持GPU/CPU自动检测和动态参数调整
+"""
+
+import argparse
+import yaml
+import os
+from src.deep_learning.bert_model import BertClassifier
+from src.utils.data_loader import DataLoader
+from datasets import Dataset
+import pandas as pd
+from transformers import BertTokenizer
+
+def load_and_prepare_data(db_url: str, limit: int = 1000):
+    """加载数据并准备训练集"""
+    # 创建数据加载器
+    loader = DataLoader(db_url)
+
+    # 从数据库加载数据
+    df = loader.load_news_from_db(limit=limit)
+
+    if df.empty:
+        print("没有可用的训练数据")
+        return None, None
+
+    # 先映射标签,再过滤不在 9 个分类中的数据
+    labels = df['category_name'].map({
+        '娱乐': 0, '体育': 1, '财经': 2, '科技': 3, '军事': 4,
+        '汽车': 5, '政务': 6, '健康': 7, 'AI': 8
+    }).fillna(-1).astype(int)
+
+    # 过滤无效数据
+    valid_data = df[labels != -1]
+    texts = valid_data['title'] + ' ' + valid_data['content']
+    labels = labels[labels != -1]
+
+    print(f"有效数据数量: {len(valid_data)}")
+
+    # 创建Hugging Face Dataset
+    dataset = Dataset.from_pandas(pd.DataFrame({
+        'text': texts.tolist(),
+        'label': labels.tolist()
+    }))
+
+    # 分词编码:Trainer 需要 input_ids/attention_mask,而非原始文本列
+    tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')  # 与 config.yaml 中 model.name 保持一致
+
+    def tokenize_fn(batch):
+        return tokenizer(batch['text'], truncation=True, max_length=512, padding='max_length')
+
+    dataset = dataset.map(tokenize_fn, batched=True)
+    dataset = dataset.remove_columns(['text'])
+
+    # 划分训练集和验证集
+    train_test = dataset.train_test_split(test_size=0.2)
+
+    return train_test['train'], train_test['test']
+
+def train_bert_model(config: dict):
+    """训练BERT模型"""
+    # 加载数据
+    train_dataset, eval_dataset = load_and_prepare_data(
+        config['database']['url'],
+        limit=config['database']['data_limit']  # data_limit 定义在 database 段,而非 training 段
+    )
+
+    if train_dataset is None:
+        return
+
+    # 初始化模型
+    classifier = BertClassifier(
+        model_name=config['model']['name'],
+        num_labels=config['model']['num_labels'],
+        use_gpu=config['training']['use_gpu']
+    )
+
+    # 训练模型
+    classifier.train_model(
+        train_dataset=train_dataset,
+        eval_dataset=eval_dataset,
+        output_dir=config['model']['output_dir'],
+        num_train_epochs=config['training']['epochs'],
+        per_device_train_batch_size=config['training']['batch_size'],
+        per_device_eval_batch_size=config['training']['batch_size'] * 2,
+        learning_rate=float(config['training']['learning_rate']),  # YAML 中 2e-5 会被解析为字符串,需转为 float
+        warmup_steps=config['training']['warmup_steps'],
+        weight_decay=config['training']['weight_decay'],
+        fp16=config['training'].get('fp16', None)
+    )
+
+def main():
+    # 解析命令行参数
+    parser = argparse.ArgumentParser(description='BERT模型训练脚本')
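+    # 命令行参数优先级高于 config.yaml:解析后在下方用非空参数覆盖对应配置项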
+    parser.add_argument('--config', type=str, default='config.yaml', help='配置文件路径')
+    parser.add_argument('--use_gpu', action='store_true', help='强制使用GPU')
+    parser.add_argument('--epochs', type=int, help='训练轮数(覆盖配置文件)')
+    parser.add_argument('--batch_size', type=int, help='批大小(覆盖配置文件)')
+    args = parser.parse_args()
+
+    # 加载配置文件
+    with open(args.config, 'r', encoding='utf-8') as f:
+        config = yaml.safe_load(f)
+
+    # 覆盖配置参数
+    if args.epochs is not None:
+        config['training']['epochs'] = args.epochs
+    if args.batch_size is not None:
+        config['training']['batch_size'] = args.batch_size
+    if args.use_gpu:
+        config['training']['use_gpu'] = True
+
+    # 开始训练
+    print("开始BERT模型训练...")
+    print(f"配置: {config}")
+
+    train_bert_model(config)
+
+    print("训练完成!")
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
diff --git a/ml-module/模块开发任务清单.md b/ml-module/模块开发任务清单.md
new file mode 100644
index 0000000..0016c37
--- /dev/null
+++ b/ml-module/模块开发任务清单.md
@@ -0,0 +1,800 @@
+# 新闻文本分类系统 - 模块开发任务清单
+---
+## 目录
+
+1. [爬虫模块 (Python)](#1-爬虫模块-python)
+2. [后端服务模块 (Spring Boot)](#2-后端服务模块-spring-boot)
+3. [前端桌面模块 (Tauri + Vue3)](#3-前端桌面模块-tauri--vue3)
+4. [机器学习分类模块 (Python)](#4-机器学习分类模块-python)
+
+---
+
+## 1. 爬虫模块 (Python)
+
+## 2. 后端服务模块 (Spring Boot)
+
+## 3. 前端桌面模块 (Tauri + Vue3)
+
+## 4. 机器学习分类模块 (Python)
+
+### 模块目录结构
+
+```
+ml-module/
+├── data/
+│   ├── raw/                  # 原始数据
+│   ├── processed/            # 处理后的数据
+│   │   ├── training_data.csv
+│   │   └── test_data.csv
+│   └── external/             # 外部数据集
+├── models/                   # 训练好的模型
+│   ├── traditional/
+│   │   ├── nb_vectorizer.pkl
+│   │   ├── nb_classifier.pkl
+│   │   ├── svm_vectorizer.pkl
+│   │   └── svm_classifier.pkl
+│   ├── deep_learning/
+│   │   └── bert_finetuned/
+│   └── hybrid/
+│       └── config.json
+├── src/
+│   ├── __init__.py
+│   ├── traditional/          # 传统机器学习
+│   │   ├── __init__.py
+│   │   ├── train_model.py    # (已有)
+│   │   ├── predict.py
+│   │   └── evaluate.py
+│   ├── deep_learning/        # 深度学习
+│   │   ├── __init__.py
+│   │   ├── bert_model.py
+│   │   ├── train_bert.py
+│   │   └── predict_bert.py
+│   ├── hybrid/               # 混合策略
+│   │   ├── __init__.py
+│   │   ├── hybrid_classifier.py
+│   │   └── rule_engine.py
+│   ├── utils/
+│   │   ├── __init__.py
+│   │   ├── preprocessing.py  # 数据预处理
+│   │   └── metrics.py        # 评估指标
+│   └── api/                  # API服务
+│       ├── __init__.py
+│       └── server.py         # FastAPI服务
+├── notebooks/                # Jupyter notebooks
+│   ├── data_exploration.ipynb
+│   └── model_comparison.ipynb
+├── tests/                    # 测试
+│   ├── test_traditional.py
+│   ├── test_bert.py
+│   └── test_hybrid.py
+├── requirements.txt
+├── setup.py
+└── README.md
+```
+
+### 4.1 需要完成的具体文件
+
+#### 任务 4.1.1: `src/traditional/predict.py` - 传统模型预测
+
+```python
+"""
+传统机器学习模型预测
+"""
+
+import os
+import joblib
+import jieba
+from typing import Dict, Any
+
+# 分类映射
+CATEGORY_MAP = {
+    'POLITICS': '时政',
+    'FINANCE': '财经',
+    'TECHNOLOGY': '科技',
+    'SPORTS': '体育',
+    'ENTERTAINMENT': '娱乐',
+    'HEALTH': '健康',
+    'EDUCATION': '教育',
+    'LIFE': '生活',
+    'INTERNATIONAL': '国际',
+    'MILITARY': '军事'
+}
+
+
+class TraditionalPredictor:
+    """传统机器学习预测器"""
+
+    def __init__(self, model_type='nb', model_dir='../../models/traditional'):
+        self.model_type = model_type
+        self.model_dir = model_dir
+        self.vectorizer = None
+        self.classifier = None
+        self._load_model()
+
+    def _load_model(self):
+        """加载模型"""
+        vectorizer_path = os.path.join(self.model_dir, f'{self.model_type}_vectorizer.pkl')
+        classifier_path = os.path.join(self.model_dir, f'{self.model_type}_classifier.pkl')
+
+        self.vectorizer = joblib.load(vectorizer_path)
+        self.classifier = joblib.load(classifier_path)
+        print(f"模型加载成功: {self.model_type}")
+
+    def preprocess(self, title: str, content: str) -> str:
+        
"""预处理文本""" + text = title + ' ' + content + # jieba分词 + words = jieba.cut(text) + return ' '.join(words) + + def predict(self, title: str, content: str) -> Dict[str, Any]: + """ + 预测 + :return: 预测结果字典 + """ + # 预处理 + processed = self.preprocess(title, content) + + # 特征提取 + tfidf = self.vectorizer.transform([processed]) + + # 预测 + prediction = self.classifier.predict(tfidf)[0] + probabilities = self.classifier.predict_proba(tfidf)[0] + + # 获取各类别概率 + prob_dict = {} + for i, prob in enumerate(probabilities): + category_code = self.classifier.classes_[i] + prob_dict[category_code] = float(prob) + + return { + 'categoryCode': prediction, + 'categoryName': CATEGORY_MAP.get(prediction, '未知'), + 'confidence': float(probabilities.max()), + 'probabilities': prob_dict + } + + +# API入口 +def predict_single(title: str, content: str, model_type='nb') -> Dict[str, Any]: + """ + 单条预测API + """ + predictor = TraditionalPredictor(model_type) + return predictor.predict(title, content) + + +if __name__ == '__main__': + # 测试 + result = predict_single( + title="华为发布新款折叠屏手机", + content="华为今天正式发布了新一代折叠屏手机,搭载最新麒麟芯片..." + ) + print(result) +``` + +#### 任务 4.1.2: `src/deep_learning/bert_model.py` - BERT模型 + +```python +""" +BERT文本分类模型 +""" + +import torch +from transformers import ( + BertTokenizer, + BertForSequenceClassification, + Trainer, + TrainingArguments +) +from typing import Dict, Any, List + + +# 分类映射 +CATEGORY_MAP = { + 'POLITICS': '时政', + 'FINANCE': '财经', + 'TECHNOLOGY': '科技', + 'SPORTS': '体育', + 'ENTERTAINMENT': '娱乐', + 'HEALTH': '健康', + 'EDUCATION': '教育', + 'LIFE': '生活', + 'INTERNATIONAL': '国际', + 'MILITARY': '军事' +} + +# 反向映射 +ID_TO_LABEL = {i: label for i, label in enumerate(CATEGORY_MAP.keys())} +LABEL_TO_ID = {label: i for i, label in enumerate(CATEGORY_MAP.keys())} + + +class BertClassifier: + """BERT文本分类器""" + + def __init__(self, model_name='bert-base-chinese', num_labels=10): + self.model_name = model_name + self.num_labels = num_labels + self.tokenizer = None + self.model = None + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + def load_model(self, model_path): + """加载微调后的模型""" + self.tokenizer = BertTokenizer.from_pretrained(model_path) + self.model = BertForSequenceClassification.from_pretrained( + model_path, + num_labels=self.num_labels + ) + self.model.to(self.device) + self.model.eval() + print(f"BERT模型加载成功: {model_path}") + + def predict(self, title: str, content: str) -> Dict[str, Any]: + """ + 预测 + """ + if self.model is None or self.tokenizer is None: + raise ValueError("模型未加载,请先调用load_model") + + # 组合标题和内容 + text = f"{title} [SEP] {content}" + + # 分词 + inputs = self.tokenizer( + text, + return_tensors='pt', + truncation=True, + max_length=512, + padding='max_length' + ) + + # 预测 + with torch.no_grad(): + inputs = {k: v.to(self.device) for k, v in inputs.items()} + outputs = self.model(**inputs) + logits = outputs.logits + + # 获取预测结果 + probs = torch.softmax(logits, dim=-1) + confidence, predicted_id = torch.max(probs, dim=-1) + + predicted_id = predicted_id.item() + confidence = confidence.item() + + # 获取各类别概率 + prob_dict = {} + for i, prob in enumerate(probs[0].cpu().numpy()): + category_code = ID_TO_LABEL[i] + prob_dict[category_code] = float(prob) + + return { + 'categoryCode': ID_TO_LABEL[predicted_id], + 'categoryName': CATEGORY_MAP.get(ID_TO_LABEL[predicted_id], '未知'), + 'confidence': confidence, + 'probabilities': prob_dict + } + + +# 数据集类 +class NewsDataset(torch.utils.data.Dataset): + """新闻数据集""" + + def __init__(self, texts, labels, tokenizer, 
max_length=512): + self.texts = texts + self.labels = labels + self.tokenizer = tokenizer + self.max_length = max_length + + def __len__(self): + return len(self.texts) + + def __getitem__(self, idx): + text = self.texts[idx] + label = self.labels[idx] + + encoding = self.tokenizer( + text, + truncation=True, + max_length=self.max_length, + padding='max_length', + return_tensors='pt' + ) + + return { + 'input_ids': encoding['input_ids'].flatten(), + 'attention_mask': encoding['attention_mask'].flatten(), + 'labels': torch.tensor(label, dtype=torch.long) + } + + +if __name__ == '__main__': + # 测试 + classifier = BertClassifier() + # classifier.load_model('./models/deep_learning/bert_finetuned') + # + # result = classifier.predict( + # title="华为发布新款折叠屏手机", + # content="华为今天正式发布了新一代折叠屏手机..." + # ) + # print(result) + print("BERT分类器初始化成功") +``` + +#### 任务 4.1.3: `src/hybrid/hybrid_classifier.py` - 混合分类器 + +```python +""" +混合策略分类器 +结合规则引擎和机器学习模型 +""" + +import time +from typing import Dict, Any +from ..traditional.predict import TraditionalPredictor +from ..deep_learning.bert_model import BertClassifier + + +class HybridClassifier: + """混合分类器""" + + def __init__(self): + # 初始化各个分类器 + self.nb_predictor = TraditionalPredictor('nb') + self.bert_classifier = BertClassifier() + + # 配置参数 + self.config = { + 'confidence_threshold': 0.75, # 高置信度阈值 + 'hybrid_min_confidence': 0.60, # 混合模式最低阈值 + 'use_bert_threshold': 0.70, # 使用BERT的阈值 + 'rule_priority': True # 规则优先 + } + + # 规则关键词字典 + self.rule_keywords = { + 'POLITICS': ['政府', '政策', '选举', '国务院', '主席', '总理'], + 'FINANCE': ['股市', '经济', '金融', '投资', '基金', '银行'], + 'TECHNOLOGY': ['芯片', 'AI', '人工智能', '5G', '互联网', '科技'], + 'SPORTS': ['比赛', '冠军', '联赛', '球员', '教练', 'NBA'], + 'ENTERTAINMENT': ['明星', '电影', '电视剧', '娱乐圈', '歌手'], + 'HEALTH': ['健康', '医疗', '疾病', '治疗', '疫苗'], + 'EDUCATION': ['教育', '学校', '大学', '考试', '招生'], + 'LIFE': ['生活', '美食', '旅游', '购物'], + 'INTERNATIONAL': ['国际', '美国', '欧洲', '日本', '外交'], + 'MILITARY': ['军事', '武器', '军队', '国防', '战争'] + } + + def rule_match(self, title: str, content: str) -> tuple[str | None, float]: + """ + 规则匹配 + :return: (category_code, confidence) + """ + text = title + ' ' + content + + # 计算每个类别的关键词匹配数 + matches = {} + for category, keywords in self.rule_keywords.items(): + count = sum(1 for kw in keywords if kw in text) + if count > 0: + matches[category] = count + + if not matches: + return None, 0.0 + + # 返回匹配最多的类别 + best_category = max(matches, key=matches.get) + confidence = min(0.9, matches[best_category] * 0.15) # 规则置信度 + + return best_category, confidence + + def predict(self, title: str, content: str, use_bert=True) -> Dict[str, Any]: + """ + 混合预测 + """ + start_time = time.time() + + # 1. 先尝试规则匹配 + rule_category, rule_confidence = self.rule_match(title, content) + + # 2. 
传统机器学习预测 + nb_result = self.nb_predictor.predict(title, content) + nb_confidence = nb_result['confidence'] + + # 决策逻辑 + final_result = None + classifier_type = 'HYBRID' + + # 规则优先且规则置信度高 + if self.config['rule_priority'] and rule_confidence >= self.config['confidence_threshold']: + final_result = { + 'categoryCode': rule_category, + 'categoryName': nb_result['categoryName'], # 从映射获取 + 'confidence': rule_confidence, + 'classifierType': 'RULE', + 'reason': '规则匹配' + } + # 传统模型置信度足够高 + elif nb_confidence >= self.config['confidence_threshold']: + final_result = { + **nb_result, + 'classifierType': 'ML', + 'reason': '传统模型高置信度' + } + # 需要使用BERT + elif use_bert: + # TODO: 加载BERT模型预测 + # bert_result = self.bert_classifier.predict(title, content) + # 如果BERT置信度也不高,选择最高的 + final_result = { + **nb_result, + 'classifierType': 'HYBRID', + 'reason': '混合决策' + } + else: + # 不使用BERT,直接返回传统模型结果 + final_result = { + **nb_result, + 'classifierType': 'ML', + 'reason': '默认传统模型' + } + + # 计算耗时 + duration = int((time.time() - start_time) * 1000) + final_result['duration'] = duration + + return final_result + + +if __name__ == '__main__': + # 测试 + classifier = HybridClassifier() + + test_cases = [ + { + 'title': '国务院发布最新经济政策', + 'content': '国务院今天发布了新的经济政策...' + }, + { + 'title': '华为发布新款折叠屏手机', + 'content': '华为今天正式发布了新一代折叠屏手机...' + } + ] + + for case in test_cases: + result = classifier.predict(case['title'], case['content']) + print(f"标题: {case['title']}") + print(f"结果: {result['categoryName']} ({result['confidence']:.2f})") + print(f"分类器: {result['classifierType']}") + print(f"原因: {result.get('reason', 'N/A')}") + print(f"耗时: {result['duration']}ms") + print("-" * 50) +``` + +#### 任务 4.1.4: `src/api/server.py` - FastAPI服务 + +```python +""" +机器学习模型API服务 +使用FastAPI提供RESTful API +""" + +from fastapi import FastAPI, HTTPException +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel +from typing import Optional +import logging + +# 导入分类器 +import sys +import os +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from traditional.predict import TraditionalPredictor +from hybrid.hybrid_classifier import HybridClassifier + +# 配置日志 +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# 创建FastAPI应用 +app = FastAPI( + title="新闻分类API", + description="提供新闻文本分类服务", + version="1.0.0" +) + +# 配置CORS +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + + +# 请求模型 +class ClassifyRequest(BaseModel): + title: str + content: str + mode: Optional[str] = 'hybrid' # traditional, hybrid + + +# 响应模型 +class ClassifyResponse(BaseModel): + categoryCode: str + categoryName: str + confidence: float + classifierType: str + duration: int + probabilities: Optional[dict] = None + + +# 初始化分类器 +nb_predictor = None +hybrid_classifier = None + + +@app.on_event("startup") +async def startup_event(): + """启动时加载模型""" + global nb_predictor, hybrid_classifier + + logger.info("加载模型...") + + try: + nb_predictor = TraditionalPredictor('nb') + logger.info("朴素贝叶斯模型加载成功") + except Exception as e: + logger.error(f"朴素贝叶斯模型加载失败: {e}") + + try: + hybrid_classifier = HybridClassifier() + logger.info("混合分类器初始化成功") + except Exception as e: + logger.error(f"混合分类器初始化失败: {e}") + + +@app.get("/") +async def root(): + """健康检查""" + return { + "status": "ok", + "message": "新闻分类API服务运行中" + } + + +@app.get("/health") +async def health_check(): + """健康检查""" + return { + "status": "healthy", + "models": { + 
"nb_loaded": nb_predictor is not None, + "hybrid_loaded": hybrid_classifier is not None + } + } + + +@app.post("/api/predict", response_model=ClassifyResponse) +async def predict(request: ClassifyRequest): + """ + 文本分类接口 + + - **title**: 新闻标题 + - **content**: 新闻内容 + - **mode**: 分类模式 (traditional, hybrid) + """ + try: + if request.mode == 'traditional': + result = nb_predictor.predict(request.title, request.content) + result['classifierType'] = 'ML' + else: # hybrid + result = hybrid_classifier.predict(request.title, request.content) + + return ClassifyResponse(**result) + + except Exception as e: + logger.error(f"预测失败: {e}") + raise HTTPException(status_code=500, detail=str(e)) + + +@app.post("/api/batch-predict") +async def batch_predict(requests: list[ClassifyRequest]): + """ + 批量分类接口 + """ + results = [] + for req in requests: + try: + if req.mode == 'traditional': + result = nb_predictor.predict(req.title, req.content) + result['classifierType'] = 'ML' + else: + result = hybrid_classifier.predict(req.title, req.content) + results.append(result) + except Exception as e: + results.append({ + 'error': str(e), + 'title': req.title + }) + + return {"results": results} + + +if __name__ == '__main__': + import uvicorn + + uvicorn.run( + app, + host="0.0.0.0", + port=5000, + log_level="info" + ) +``` + +#### 任务 4.1.5: `src/utils/metrics.py` - 评估指标 + +```python +""" +模型评估指标工具 +""" + +import numpy as np +from sklearn.metrics import ( + accuracy_score, + precision_recall_fscore_support, + confusion_matrix, + classification_report +) +from typing import List, Dict, Any +import matplotlib.pyplot as plt +import seaborn as sns + + +class ClassificationMetrics: + """分类评估指标""" + + @staticmethod + def compute_all(y_true: List, y_pred: List, labels: List[str]) -> Dict[str, Any]: + """ + 计算所有指标 + """ + accuracy = accuracy_score(y_true, y_pred) + + precision, recall, f1, support = precision_recall_fscore_support( + y_true, y_pred, average='weighted', zero_division=0 + ) + + # 每个类别的指标 + precision_per_class, recall_per_class, f1_per_class, support_per_class = \ + precision_recall_fscore_support(y_true, y_pred, average=None, zero_division=0) + + per_class_metrics = {} + for i, label in enumerate(labels): + per_class_metrics[label] = { + 'precision': float(precision_per_class[i]), + 'recall': float(recall_per_class[i]), + 'f1': float(f1_per_class[i]), + 'support': int(support_per_class[i]) + } + + return { + 'accuracy': float(accuracy), + 'precision': float(precision), + 'recall': float(recall), + 'f1': float(f1), + 'per_class': per_class_metrics + } + + @staticmethod + def plot_confusion_matrix(y_true: List, y_pred: List, labels: List[str], save_path: str = None): + """ + 绘制混淆矩阵 + """ + cm = confusion_matrix(y_true, y_pred) + + plt.figure(figsize=(10, 8)) + sns.heatmap( + cm, + annot=True, + fmt='d', + cmap='Blues', + xticklabels=labels, + yticklabels=labels + ) + plt.xlabel('预测标签') + plt.ylabel('真实标签') + plt.title('混淆矩阵') + + if save_path: + plt.savefig(save_path, dpi=300, bbox_inches='tight') + plt.close() + + @staticmethod + def print_report(y_true: List, y_pred: List, labels: List[str]): + """ + 打印分类报告 + """ + report = classification_report( + y_true, y_pred, + target_names=labels, + zero_division=0 + ) + print(report) + + +if __name__ == '__main__': + # 测试 + y_true = ['POLITICS', 'TECHNOLOGY', 'FINANCE', 'POLITICS', 'TECHNOLOGY'] + y_pred = ['POLITICS', 'TECHNOLOGY', 'FINANCE', 'TECHNOLOGY', 'TECHNOLOGY'] + labels = ['POLITICS', 'TECHNOLOGY', 'FINANCE'] + + metrics = ClassificationMetrics() + result = 
metrics.compute_all(y_true, y_pred, labels) + print(result) +``` + +#### 任务 4.1.6: `requirements.txt` - 依赖文件 + +```txt +# 机器学习模块依赖 +numpy>=1.24.0 +pandas>=2.0.0 +scikit-learn>=1.3.0 +jieba>=0.42.0 +joblib>=1.3.0 + +# 深度学习 +torch>=2.0.0 +transformers>=4.30.0 + +# API服务 +fastapi>=0.100.0 +uvicorn[standard]>=0.23.0 +pydantic>=2.0.0 + +# 数据可视化 +matplotlib>=3.7.0 +seaborn>=0.12.0 + +# 工具 +python-dotenv>=1.0.0 +pyyaml>=6.0 +``` + +--- + +## 总结 + +### 开发顺序建议 + +1. **第一阶段:基础框架** + - 后端:数据库连接、实体类、基础配置 + - 前端:路由配置、状态管理、API封装 + +2. **第二阶段:核心功能** + - 爬虫模块(Python) + - 传统机器学习分类器 + - 后端API接口 + - 前端新闻列表页面 + +3. **第三阶段:高级功能** + - BERT深度学习分类器 + - 混合策略分类器 + - 前端分类器对比页面 + - 统计图表 + +4. **第四阶段:完善优化** + - 用户认证 + - 数据可视化 + - 性能优化 + - 异常处理 + +### 关键注意事项 + +1. **爬虫模块使用 Python**,通过 RESTful API 与 Java 后端通信 +2. **分类器模块独立部署**,提供 HTTP 接口供后端调用 +3. **前后端分离**,使用 JWT 进行身份认证 +4. **数据库表结构**已在 `schema.sql` 中定义,需严格遵守 +5. **API 统一响应格式**使用 `Result` 包装
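+
+---
+
+### 附:分类接口调用示例
+
+一个最小的调用示例(基于上文 `src/api/server.py` 中 `/api/predict` 接口的定义;假设服务已通过 `python src/api/server.py` 在本机 5000 端口启动,`requests` 为假设已安装的第三方库,示例标题与正文为虚构内容):
+
+```python
+import requests
+
+API_URL = "http://localhost:5000/api/predict"  # server.py 默认以 uvicorn 监听 5000 端口
+
+payload = {
+    "title": "华为发布新款折叠屏手机",
+    "content": "华为今天正式发布了新一代折叠屏手机,搭载最新麒麟芯片...",
+    "mode": "hybrid",  # 可选 traditional / hybrid,对应 ClassifyRequest 的 mode 字段
+}
+
+resp = requests.post(API_URL, json=payload, timeout=10)
+resp.raise_for_status()
+result = resp.json()
+# 响应字段对应 ClassifyResponse:categoryCode/categoryName/confidence/classifierType/duration
+print(result["categoryName"], result["confidence"], result["classifierType"])
+```
+
+Java 后端按同样的 JSON 结构调用该 HTTP 接口即可,再由后端用统一的 `Result` 格式包装后返回给前端。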