|
|
||
|---|---|---|
| .. | ||
| .claude | ||
| data | ||
| models/traditional | ||
| src | ||
| README.md | ||
| config.yaml | ||
| database.md | ||
| requirements.txt | ||
| train_bert.py | ||
| 模块开发任务清单.md | ||
README.md
新闻文本分类系统 - 机器学习模块
功能特性
GPU/CPU自动检测
- 自动检测可用GPU(包括8GB显存)
- 自动回退到CPU(如果GPU不可用)
- 显示设备信息和显存信息
动态参数调整
- 通过配置文件调整训练参数
- 支持命令行参数覆盖
- 混合精度自动检测
使用方法
1. 安装依赖
pip install -r requirements.txt
从mysql拉去训练数据
python .\src\utils\data_loader.py
2. 配置训练参数
编辑 config.yaml 文件调整训练参数:
database:
url: "mysql+pymysql://root:root@localhost/news_classifier"
data_limit: 1000
model:
name: "bert-base-chinese"
num_labels: 9
output_dir: "./models/deep_learning/bert_finetuned"
training:
use_gpu: true # 自动检测GPU
epochs: 3
batch_size: 8
learning_rate: 2e-5
warmup_steps: 500
weight_decay: 0.01
3. 训练BERT模型
python train_bert.py
4. 使用命令行参数覆盖配置
# 使用GPU训练
python train_bert.py --use_gpu
# 调整训练轮数
python train_bert.py --epochs 5
# 调整批大小
python train_bert.py --batch_size 16
5. 启动API服务
python src/api/server.py
参数说明
训练参数
epochs: 训练轮数(默认:3)batch_size: 批大小(默认:8)learning_rate: 学习率(默认:2e-5)warmup_steps: 预热步数(默认:500)weight_decay: 权重衰减(默认:0.01)
设备配置
use_gpu: 是否使用GPU(自动检测)fp16: 混合精度(自动检测)
8GB显存优化建议
对于8GB显存的GPU,建议配置:
training:
use_gpu: true
epochs: 3
batch_size: 8-16
fp16: true # 启用混合精度
设备检测
训练时会自动检测设备:
- GPU可用:使用GPU训练,显示GPU名称和显存信息
- GPU不可用:自动回退到CPU训练
性能优化
- 混合精度训练(FP16)
- 梯度累积
- 自动批大小调整
- 内存优化
注意事项
- 确保安装了CUDA和cuDNN(如果使用GPU)
- 8GB显存可以训练BERT-base模型
- 可以通过调整batch_size适应不同显存大小
- 训练时间取决于数据量和硬件配置
API接口
# 单条预测
POST /api/predict
{
"title": "新闻标题",
"content": "新闻内容",
"mode": "hybrid" # traditional, hybrid
}
# 批量预测
POST /api/batch-predict
[
{
"title": "新闻标题1",
"content": "新闻内容1",
"mode": "hybrid"
},
{
"title": "新闻标题2",
"content": "新闻内容2",
"mode": "traditional"
}
]
分组方案
1️⃣ 基础科学计算 / 机器学习
pip install numpy>=1.24.0 pandas>=2.0.0 scikit-learn>=1.3.0 joblib>=1.3.0
2️⃣ 深度学习 / NLP
pip install torch>=2.0.0 transformers>=4.30.0 jieba>=0.42.0
💡 如果你有 GPU 并希望安装 GPU 版 PyTorch,需要单独去 PyTorch 官网生成命令。
3️⃣ API 服务
pip install fastapi>=0.100.0 "uvicorn[standard]>=0.23.0" pydantic>=2.0.0
4️⃣ 数据库相关
pip install sqlalchemy>=2.0.0 pymysql>=1.1.0
5️⃣ 数据可视化
pip install matplotlib>=3.7.0 seaborn>=0.12.0
6️⃣ 工具 / 配置文件处理
pip install python-dotenv>=1.0.0 pyyaml>=6.0
✅ 这样分批安装的好处:
- 出现安装错误时,更容易定位是哪个模块有问题。
- 每组依赖关系相对独立,减少冲突。
- CMD 一次执行一条命令,不需要处理复杂的换行符或引号问题。