news-classifier/ml-module/README.md

198 lines
3.6 KiB
Markdown
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# 新闻文本分类系统 - 机器学习模块
## 功能特性
### GPU/CPU自动检测
- 自动检测可用GPU包括8GB显存
- 自动回退到CPU如果GPU不可用
- 显示设备信息和显存信息
### 动态参数调整
- 通过配置文件调整训练参数
- 支持命令行参数覆盖
- 混合精度自动检测
## 使用方法
### 1. 安装依赖
```bash
pip install -r requirements.txt
```
### 从mysql拉去训练数据
```bash
python .\src\utils\data_loader.py
```
### 2. 配置训练参数
编辑 `config.yaml` 文件调整训练参数:
```yaml
database:
url: "mysql+pymysql://root:root@localhost/news_classifier"
data_limit: 1000
model:
name: "bert-base-chinese"
num_labels: 9
output_dir: "./models/deep_learning/bert_finetuned"
training:
use_gpu: true # 自动检测GPU
epochs: 3
batch_size: 8
learning_rate: 2e-5
warmup_steps: 500
weight_decay: 0.01
```
### 3. 训练BERT模型
```bash
python train_bert.py
```
### 4. 使用命令行参数覆盖配置
```bash
# 使用GPU训练
python train_bert.py --use_gpu
# 调整训练轮数
python train_bert.py --epochs 5
# 调整批大小
python train_bert.py --batch_size 16
```
### 5. 启动API服务
```bash
python src/api/server.py
```
## 参数说明
### 训练参数
- `epochs`: 训练轮数默认3
- `batch_size`: 批大小默认8
- `learning_rate`: 学习率默认2e-5
- `warmup_steps`: 预热步数默认500
- `weight_decay`: 权重衰减默认0.01
### 设备配置
- `use_gpu`: 是否使用GPU自动检测
- `fp16`: 混合精度(自动检测)
## 8GB显存优化建议
对于8GB显存的GPU建议配置
```yaml
training:
use_gpu: true
epochs: 3
batch_size: 8-16
fp16: true # 启用混合精度
```
## 设备检测
训练时会自动检测设备:
- **GPU可用**使用GPU训练显示GPU名称和显存信息
- **GPU不可用**自动回退到CPU训练
## 性能优化
- 混合精度训练FP16
- 梯度累积
- 自动批大小调整
- 内存优化
## 注意事项
1. 确保安装了CUDA和cuDNN如果使用GPU
2. 8GB显存可以训练BERT-base模型
3. 可以通过调整batch_size适应不同显存大小
4. 训练时间取决于数据量和硬件配置
## API接口
```bash
# 单条预测
POST /api/predict
{
"title": "新闻标题",
"content": "新闻内容",
"mode": "hybrid" # traditional, hybrid
}
# 批量预测
POST /api/batch-predict
[
{
"title": "新闻标题1",
"content": "新闻内容1",
"mode": "hybrid"
},
{
"title": "新闻标题2",
"content": "新闻内容2",
"mode": "traditional"
}
]
```
---
## **分组方案**
### **1⃣ 基础科学计算 / 机器学习**
```cmd
pip install numpy>=1.24.0 pandas>=2.0.0 scikit-learn>=1.3.0 joblib>=1.3.0
```
---
### **2⃣ 深度学习 / NLP**
```cmd
pip install torch>=2.0.0 transformers>=4.30.0 jieba>=0.42.0
```
💡 如果你有 **GPU** 并希望安装 GPU 版 PyTorch需要单独去 PyTorch 官网生成命令。
---
### **3⃣ API 服务**
```cmd
pip install fastapi>=0.100.0 "uvicorn[standard]>=0.23.0" pydantic>=2.0.0
```
---
### **4⃣ 数据库相关**
```cmd
pip install sqlalchemy>=2.0.0 pymysql>=1.1.0
```
---
### **5⃣ 数据可视化**
```cmd
pip install matplotlib>=3.7.0 seaborn>=0.12.0
```
---
### **6⃣ 工具 / 配置文件处理**
```cmd
pip install python-dotenv>=1.0.0 pyyaml>=6.0
```
---
**这样分批安装的好处**
1. 出现安装错误时,更容易定位是哪个模块有问题。
2. 每组依赖关系相对独立,减少冲突。
3. CMD 一次执行一条命令,不需要处理复杂的换行符或引号问题。