# news-classifier/ml-module/train_bert.py
"""
BERT模型训练脚本
支持GPU/CPU自动检测和动态参数调整
"""
import argparse
import yaml
import os
from src.deep_learning.bert_model import BertClassifier
from src.utils.data_loader import DataLoader
from datasets import Dataset
import pandas as pd
from transformers import BertTokenizer, DataCollatorWithPadding
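
# The config.yaml consumed by this script is expected to provide at least the
# keys accessed below (database.url, model.*, training.*). A minimal sketch with
# illustrative values only; the actual URL, model name, and hyperparameters
# belong in the project's own config file:
#
#   database:
#     url: "mysql+pymysql://user:pass@localhost/news"
#   model:
#     name: "bert-base-chinese"
#     num_labels: 9
#     output_dir: "./models/bert"
#   training:
#     data_limit: 1000
#     use_gpu: false
#     epochs: 3
#     batch_size: 16
#     learning_rate: 2.0e-5
#     warmup_steps: 500
#     weight_decay: 0.01
#     fp16: false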


def load_and_prepare_data(db_url: str, limit: int = 1000):
    """Load news from the database and prepare train/eval splits."""
    # Create the data loader
    loader = DataLoader(db_url)
    # Load records from the database
    df = loader.load_news_from_db(limit=limit)
    if df.empty:
        print("No training data available")
        return None, None

    # Map category names to integer labels; unknown categories become -1
    labels = df['category_name'].map({
        '娱乐': 0, '体育': 1, '财经': 2, '科技': 3, '军事': 4,
        '汽车': 5, '政务': 6, '健康': 7, 'AI': 8
    }).fillna(-1).astype(int)

    # Keep only rows with a known category
    valid_data = df[labels != -1]
    texts = valid_data['title'] + ' ' + valid_data['content']
    labels = labels[labels != -1]
    print(f"Number of valid samples: {len(valid_data)}")

    # Build a Hugging Face Dataset
    dataset = Dataset.from_pandas(pd.DataFrame({
        'text': texts.tolist(),
        'label': labels.tolist()
    }))
    # Split into training and validation sets
    train_test = dataset.train_test_split(test_size=0.2)
    return train_test['train'], train_test['test']


def train_bert_model(config: dict):
    """Train the BERT model with settings from the config dict."""
    # Load and split the data
    train_dataset, eval_dataset = load_and_prepare_data(
        config['database']['url'],
        limit=config['training']['data_limit']
    )
    if train_dataset is None:
        return

    # Initialize the classifier
    classifier = BertClassifier(
        model_name=config['model']['name'],
        num_labels=config['model']['num_labels'],
        use_gpu=config['training']['use_gpu']
    )
    # Train the model
    classifier.train_model(
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        output_dir=config['model']['output_dir'],
        num_train_epochs=config['training']['epochs'],
        per_device_train_batch_size=config['training']['batch_size'],
        per_device_eval_batch_size=config['training']['batch_size'] * 2,
        learning_rate=config['training']['learning_rate'],
        warmup_steps=config['training']['warmup_steps'],
        weight_decay=config['training']['weight_decay'],
        fp16=config['training'].get('fp16', False)
    )


def main():
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description='BERT model training script')
    parser.add_argument('--config', type=str, default='config.yaml', help='Path to the config file')
    parser.add_argument('--use_gpu', action='store_true', help='Force GPU usage')
    parser.add_argument('--epochs', type=int, help='Number of training epochs (overrides the config file)')
    parser.add_argument('--batch_size', type=int, help='Batch size (overrides the config file)')
    args = parser.parse_args()

    # Load the config file
    with open(args.config, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    # Apply command-line overrides
    if args.epochs is not None:
        config['training']['epochs'] = args.epochs
    if args.batch_size is not None:
        config['training']['batch_size'] = args.batch_size
    if args.use_gpu:
        config['training']['use_gpu'] = True

    # Start training
    print("Starting BERT model training...")
    print(f"Config: {config}")
    train_bert_model(config)
    print("Training finished!")


if __name__ == '__main__':
    main()
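
# Example invocations (flags match the argparse definitions in main(); the
# config path is whatever YAML file the project actually ships):
#
#   python train_bert.py --config config.yaml
#   python train_bert.py --use_gpu --epochs 5 --batch_size 32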