198 lines
3.6 KiB
Markdown
198 lines
3.6 KiB
Markdown
# 新闻文本分类系统 - 机器学习模块
|
||
|
||
## 功能特性
|
||
|
||
### GPU/CPU自动检测
|
||
- 自动检测可用GPU(包括8GB显存)
|
||
- 自动回退到CPU(如果GPU不可用)
|
||
- 显示设备信息和显存信息
|
||
|
||
### 动态参数调整
|
||
- 通过配置文件调整训练参数
|
||
- 支持命令行参数覆盖
|
||
- 混合精度自动检测
|
||
|
||
## 使用方法
|
||
|
||
### 1. 安装依赖
|
||
```bash
|
||
pip install -r requirements.txt
|
||
```
|
||
### 从mysql拉去训练数据
|
||
|
||
```bash
|
||
python .\src\utils\data_loader.py
|
||
```
|
||
### 2. 配置训练参数
|
||
编辑 `config.yaml` 文件调整训练参数:
|
||
```yaml
|
||
database:
|
||
url: "mysql+pymysql://root:root@localhost/news_classifier"
|
||
data_limit: 1000
|
||
|
||
model:
|
||
name: "bert-base-chinese"
|
||
num_labels: 9
|
||
output_dir: "./models/deep_learning/bert_finetuned"
|
||
|
||
training:
|
||
use_gpu: true # 自动检测GPU
|
||
epochs: 3
|
||
batch_size: 8
|
||
learning_rate: 2e-5
|
||
warmup_steps: 500
|
||
weight_decay: 0.01
|
||
```
|
||
|
||
### 3. 训练BERT模型
|
||
```bash
|
||
python train_bert.py
|
||
```
|
||
|
||
### 4. 使用命令行参数覆盖配置
|
||
```bash
|
||
# 使用GPU训练
|
||
python train_bert.py --use_gpu
|
||
|
||
# 调整训练轮数
|
||
python train_bert.py --epochs 5
|
||
|
||
# 调整批大小
|
||
python train_bert.py --batch_size 16
|
||
```
|
||
|
||
### 5. 启动API服务
|
||
```bash
|
||
python src/api/server.py
|
||
```
|
||
|
||
## 参数说明
|
||
|
||
### 训练参数
|
||
- `epochs`: 训练轮数(默认:3)
|
||
- `batch_size`: 批大小(默认:8)
|
||
- `learning_rate`: 学习率(默认:2e-5)
|
||
- `warmup_steps`: 预热步数(默认:500)
|
||
- `weight_decay`: 权重衰减(默认:0.01)
|
||
|
||
### 设备配置
|
||
- `use_gpu`: 是否使用GPU(自动检测)
|
||
- `fp16`: 混合精度(自动检测)
|
||
|
||
## 8GB显存优化建议
|
||
|
||
对于8GB显存的GPU,建议配置:
|
||
```yaml
|
||
training:
|
||
use_gpu: true
|
||
epochs: 3
|
||
batch_size: 8-16
|
||
fp16: true # 启用混合精度
|
||
```
|
||
|
||
## 设备检测
|
||
训练时会自动检测设备:
|
||
- **GPU可用**:使用GPU训练,显示GPU名称和显存信息
|
||
- **GPU不可用**:自动回退到CPU训练
|
||
|
||
## 性能优化
|
||
- 混合精度训练(FP16)
|
||
- 梯度累积
|
||
- 自动批大小调整
|
||
- 内存优化
|
||
|
||
## 注意事项
|
||
1. 确保安装了CUDA和cuDNN(如果使用GPU)
|
||
2. 8GB显存可以训练BERT-base模型
|
||
3. 可以通过调整batch_size适应不同显存大小
|
||
4. 训练时间取决于数据量和硬件配置
|
||
|
||
## API接口
|
||
```bash
|
||
# 单条预测
|
||
POST /api/predict
|
||
{
|
||
"title": "新闻标题",
|
||
"content": "新闻内容",
|
||
"mode": "hybrid" # traditional, hybrid
|
||
}
|
||
|
||
# 批量预测
|
||
POST /api/batch-predict
|
||
[
|
||
{
|
||
"title": "新闻标题1",
|
||
"content": "新闻内容1",
|
||
"mode": "hybrid"
|
||
},
|
||
{
|
||
"title": "新闻标题2",
|
||
"content": "新闻内容2",
|
||
"mode": "traditional"
|
||
}
|
||
]
|
||
```
|
||
|
||
|
||
|
||
---
|
||
|
||
## **分组方案**
|
||
|
||
### **1️⃣ 基础科学计算 / 机器学习**
|
||
|
||
```cmd
|
||
pip install numpy>=1.24.0 pandas>=2.0.0 scikit-learn>=1.3.0 joblib>=1.3.0
|
||
```
|
||
|
||
---
|
||
|
||
### **2️⃣ 深度学习 / NLP**
|
||
|
||
```cmd
|
||
pip install torch>=2.0.0 transformers>=4.30.0 jieba>=0.42.0
|
||
```
|
||
|
||
💡 如果你有 **GPU** 并希望安装 GPU 版 PyTorch,需要单独去 PyTorch 官网生成命令。
|
||
|
||
---
|
||
|
||
### **3️⃣ API 服务**
|
||
|
||
```cmd
|
||
pip install fastapi>=0.100.0 "uvicorn[standard]>=0.23.0" pydantic>=2.0.0
|
||
```
|
||
|
||
---
|
||
|
||
### **4️⃣ 数据库相关**
|
||
|
||
```cmd
|
||
pip install sqlalchemy>=2.0.0 pymysql>=1.1.0
|
||
```
|
||
|
||
---
|
||
|
||
### **5️⃣ 数据可视化**
|
||
|
||
```cmd
|
||
pip install matplotlib>=3.7.0 seaborn>=0.12.0
|
||
```
|
||
|
||
---
|
||
|
||
### **6️⃣ 工具 / 配置文件处理**
|
||
|
||
```cmd
|
||
pip install python-dotenv>=1.0.0 pyyaml>=6.0
|
||
```
|
||
|
||
---
|
||
|
||
✅ **这样分批安装的好处**:
|
||
|
||
1. 出现安装错误时,更容易定位是哪个模块有问题。
|
||
2. 每组依赖关系相对独立,减少冲突。
|
||
3. CMD 一次执行一条命令,不需要处理复杂的换行符或引号问题。
|
||
|