# 新闻文本分类系统 - 机器学习模块 ## 功能特性 ### GPU/CPU自动检测 - 自动检测可用GPU(包括8GB显存) - 自动回退到CPU(如果GPU不可用) - 显示设备信息和显存信息 ### 动态参数调整 - 通过配置文件调整训练参数 - 支持命令行参数覆盖 - 混合精度自动检测 ## 使用方法 ### 1. 安装依赖 ```bash pip install -r requirements.txt ``` ### 从mysql拉去训练数据 ```bash python .\src\utils\data_loader.py ``` ### 2. 配置训练参数 编辑 `config.yaml` 文件调整训练参数: ```yaml database: url: "mysql+pymysql://root:root@localhost/news_classifier" data_limit: 1000 model: name: "bert-base-chinese" num_labels: 9 output_dir: "./models/deep_learning/bert_finetuned" training: use_gpu: true # 自动检测GPU epochs: 3 batch_size: 8 learning_rate: 2e-5 warmup_steps: 500 weight_decay: 0.01 ``` ### 3. 训练BERT模型 ```bash python train_bert.py ``` ### 4. 使用命令行参数覆盖配置 ```bash # 使用GPU训练 python train_bert.py --use_gpu # 调整训练轮数 python train_bert.py --epochs 5 # 调整批大小 python train_bert.py --batch_size 16 ``` ### 5. 启动API服务 ```bash python src/api/server.py ``` ## 参数说明 ### 训练参数 - `epochs`: 训练轮数(默认:3) - `batch_size`: 批大小(默认:8) - `learning_rate`: 学习率(默认:2e-5) - `warmup_steps`: 预热步数(默认:500) - `weight_decay`: 权重衰减(默认:0.01) ### 设备配置 - `use_gpu`: 是否使用GPU(自动检测) - `fp16`: 混合精度(自动检测) ## 8GB显存优化建议 对于8GB显存的GPU,建议配置: ```yaml training: use_gpu: true epochs: 3 batch_size: 8-16 fp16: true # 启用混合精度 ``` ## 设备检测 训练时会自动检测设备: - **GPU可用**:使用GPU训练,显示GPU名称和显存信息 - **GPU不可用**:自动回退到CPU训练 ## 性能优化 - 混合精度训练(FP16) - 梯度累积 - 自动批大小调整 - 内存优化 ## 注意事项 1. 确保安装了CUDA和cuDNN(如果使用GPU) 2. 8GB显存可以训练BERT-base模型 3. 可以通过调整batch_size适应不同显存大小 4. 训练时间取决于数据量和硬件配置 ## API接口 ```bash # 单条预测 POST /api/predict { "title": "新闻标题", "content": "新闻内容", "mode": "hybrid" # traditional, hybrid } # 批量预测 POST /api/batch-predict [ { "title": "新闻标题1", "content": "新闻内容1", "mode": "hybrid" }, { "title": "新闻标题2", "content": "新闻内容2", "mode": "traditional" } ] ``` --- ## **分组方案** ### **1️⃣ 基础科学计算 / 机器学习** ```cmd pip install numpy>=1.24.0 pandas>=2.0.0 scikit-learn>=1.3.0 joblib>=1.3.0 ``` --- ### **2️⃣ 深度学习 / NLP** ```cmd pip install torch>=2.0.0 transformers>=4.30.0 jieba>=0.42.0 ``` 💡 如果你有 **GPU** 并希望安装 GPU 版 PyTorch,需要单独去 PyTorch 官网生成命令。 --- ### **3️⃣ API 服务** ```cmd pip install fastapi>=0.100.0 "uvicorn[standard]>=0.23.0" pydantic>=2.0.0 ``` --- ### **4️⃣ 数据库相关** ```cmd pip install sqlalchemy>=2.0.0 pymysql>=1.1.0 ``` --- ### **5️⃣ 数据可视化** ```cmd pip install matplotlib>=3.7.0 seaborn>=0.12.0 ``` --- ### **6️⃣ 工具 / 配置文件处理** ```cmd pip install python-dotenv>=1.0.0 pyyaml>=6.0 ``` --- ✅ **这样分批安装的好处**: 1. 出现安装错误时,更容易定位是哪个模块有问题。 2. 每组依赖关系相对独立,减少冲突。 3. CMD 一次执行一条命令,不需要处理复杂的换行符或引号问题。