news-classifier/ml-module/README.md

3.6 KiB
Raw Blame History

新闻文本分类系统 - 机器学习模块

功能特性

GPU/CPU自动检测

  • 自动检测可用GPU包括8GB显存
  • 自动回退到CPU如果GPU不可用
  • 显示设备信息和显存信息

动态参数调整

  • 通过配置文件调整训练参数
  • 支持命令行参数覆盖
  • 混合精度自动检测

使用方法

1. 安装依赖

pip install -r requirements.txt

从mysql拉去训练数据

python .\src\utils\data_loader.py

2. 配置训练参数

编辑 config.yaml 文件调整训练参数:

database:
  url: "mysql+pymysql://root:root@localhost/news_classifier"
  data_limit: 1000

model:
  name: "bert-base-chinese"
  num_labels: 9
  output_dir: "./models/deep_learning/bert_finetuned"

training:
  use_gpu: true  # 自动检测GPU
  epochs: 3
  batch_size: 8
  learning_rate: 2e-5
  warmup_steps: 500
  weight_decay: 0.01

3. 训练BERT模型

python train_bert.py

4. 使用命令行参数覆盖配置

# 使用GPU训练
python train_bert.py --use_gpu

# 调整训练轮数
python train_bert.py --epochs 5

# 调整批大小
python train_bert.py --batch_size 16

5. 启动API服务

python src/api/server.py

参数说明

训练参数

  • epochs: 训练轮数默认3
  • batch_size: 批大小默认8
  • learning_rate: 学习率默认2e-5
  • warmup_steps: 预热步数默认500
  • weight_decay: 权重衰减默认0.01

设备配置

  • use_gpu: 是否使用GPU自动检测
  • fp16: 混合精度(自动检测)

8GB显存优化建议

对于8GB显存的GPU建议配置

training:
  use_gpu: true
  epochs: 3
  batch_size: 8-16
  fp16: true  # 启用混合精度

设备检测

训练时会自动检测设备:

  • GPU可用使用GPU训练显示GPU名称和显存信息
  • GPU不可用自动回退到CPU训练

性能优化

  • 混合精度训练FP16
  • 梯度累积
  • 自动批大小调整
  • 内存优化

注意事项

  1. 确保安装了CUDA和cuDNN如果使用GPU
  2. 8GB显存可以训练BERT-base模型
  3. 可以通过调整batch_size适应不同显存大小
  4. 训练时间取决于数据量和硬件配置

API接口

# 单条预测
POST /api/predict
{
  "title": "新闻标题",
  "content": "新闻内容",
  "mode": "hybrid"  # traditional, hybrid
}

# 批量预测
POST /api/batch-predict
[
  {
    "title": "新闻标题1",
    "content": "新闻内容1",
    "mode": "hybrid"
  },
  {
    "title": "新闻标题2",
    "content": "新闻内容2",
    "mode": "traditional"
  }
]

分组方案

1 基础科学计算 / 机器学习

pip install numpy>=1.24.0 pandas>=2.0.0 scikit-learn>=1.3.0 joblib>=1.3.0

2 深度学习 / NLP

pip install torch>=2.0.0 transformers>=4.30.0 jieba>=0.42.0

💡 如果你有 GPU 并希望安装 GPU 版 PyTorch需要单独去 PyTorch 官网生成命令。


3 API 服务

pip install fastapi>=0.100.0 "uvicorn[standard]>=0.23.0" pydantic>=2.0.0

4 数据库相关

pip install sqlalchemy>=2.0.0 pymysql>=1.1.0

5 数据可视化

pip install matplotlib>=3.7.0 seaborn>=0.12.0

6 工具 / 配置文件处理

pip install python-dotenv>=1.0.0 pyyaml>=6.0

这样分批安装的好处

  1. 出现安装错误时,更容易定位是哪个模块有问题。
  2. 每组依赖关系相对独立,减少冲突。
  3. CMD 一次执行一条命令,不需要处理复杂的换行符或引号问题。