"""
BERT model training script.

Supports automatic GPU/CPU detection and dynamic parameter adjustment.
"""
|
|
|
|
import argparse
|
|
import yaml
|
|
import os
|
|
from src.deep_learning.bert_model import BertClassifier
|
|
from src.utils.data_loader import DataLoader
|
|
from datasets import Dataset
|
|
import pandas as pd
|
|
from transformers import BertTokenizer, DataCollatorWithPadding
|
|
|
|
def load_and_prepare_data(db_url: str, limit: int = 1000, *,
                          test_size: float = 0.2, seed=None):
    """Load labelled news rows and split them into train/eval datasets.

    Args:
        db_url: Connection URL handed to ``DataLoader``.
        limit: Row limit forwarded to ``DataLoader.load_news_from_db``.
        test_size: Fraction of the valid rows held out for evaluation.
        seed: Optional seed forwarded to ``Dataset.train_test_split``. The
            default ``None`` leaves the shuffle unseeded, as before.

    Returns:
        A ``(train_dataset, eval_dataset)`` tuple whose datasets have the
        columns ``text`` and ``label``, or ``(None, None)`` when the query
        returns no rows or fewer than two rows carry a known category.
    """
    # Category name -> integer class id; any other category is discarded.
    # NOTE(review): nine classes are hard-coded here - confirm this matches
    # the ``num_labels`` value the model is configured with.
    category_to_label = {
        '娱乐': 0, '体育': 1, '财经': 2, '科技': 3, '军事': 4,
        '汽车': 5, '政务': 6, '健康': 7, 'AI': 8,
    }

    loader = DataLoader(db_url)
    df = loader.load_news_from_db(limit=limit)

    if df.empty:
        print("没有可用的训练数据")
        return None, None

    # Unmapped (or missing) categories become NaN and are filtered out.
    labels = df['category_name'].map(category_to_label)
    is_labelled = labels.notna()
    valid_data = df[is_labelled]
    labels = labels[is_labelled].astype(int)

    print(f"有效数据数量: {len(valid_data)}")

    # A train/eval split needs at least one row on each side; treat anything
    # smaller like the "no data" case instead of failing inside the split.
    if len(valid_data) < 2:
        print("没有可用的训练数据")
        return None, None

    # Missing titles/contents would turn the concatenated text into NaN, so
    # replace them with empty strings before joining.
    titles = valid_data['title'].fillna('').astype(str)
    contents = valid_data['content'].fillna('').astype(str)
    texts = titles + ' ' + contents

    # Build the Hugging Face dataset from plain Python lists.
    dataset = Dataset.from_pandas(pd.DataFrame({
        'text': texts.tolist(),
        'label': labels.tolist(),
    }))

    train_test = dataset.train_test_split(test_size=test_size, seed=seed)
    return train_test['train'], train_test['test']
|
|
|
|
def train_bert_model(config: dict):
    """Train a BERT classifier from the settings in ``config``.

    Data is loaded via ``load_and_prepare_data`` using
    ``config['database']['url']`` and ``config['training']['data_limit']``.
    A ``BertClassifier`` is then built from ``config['model']`` and its
    ``train_model`` method is called with the hyper-parameters found in
    ``config['training']``. The evaluation batch size is twice the training
    batch size, and ``fp16`` is passed as ``None`` when the key is absent.

    Returns ``None``; returns early, without training, when no training
    dataset could be loaded.
    """
    train_split, eval_split = load_and_prepare_data(
        config['database']['url'],
        limit=config['training']['data_limit'],
    )

    # Nothing to train on.
    if train_split is None:
        return

    model_cfg = config['model']
    training_cfg = config['training']

    classifier = BertClassifier(
        model_name=model_cfg['name'],
        num_labels=model_cfg['num_labels'],
        use_gpu=training_cfg['use_gpu'],
    )

    run_training = classifier.train_model
    hyperparams = {
        'output_dir': model_cfg['output_dir'],
        'num_train_epochs': training_cfg['epochs'],
        'per_device_train_batch_size': training_cfg['batch_size'],
        'per_device_eval_batch_size': training_cfg['batch_size'] * 2,
        'learning_rate': training_cfg['learning_rate'],
        'warmup_steps': training_cfg['warmup_steps'],
        'weight_decay': training_cfg['weight_decay'],
        'fp16': training_cfg.get('fp16'),
    }
    run_training(
        train_dataset=train_split,
        eval_dataset=eval_split,
        **hyperparams,
    )
|
|
|
|
def main(argv=None):
    """Command-line entry point: parse arguments, load the config, train.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads the process command line.

    Raises:
        ValueError: If the config file does not contain a top-level mapping
            (for example, when the file is empty).
    """
    parser = argparse.ArgumentParser(description='BERT模型训练脚本')
    parser.add_argument('--config', type=str, default='config.yaml', help='配置文件路径')
    parser.add_argument('--use_gpu', action='store_true', help='强制使用GPU')
    parser.add_argument('--epochs', type=int, help='训练轮数(覆盖配置文件)')
    parser.add_argument('--batch_size', type=int, help='批大小(覆盖配置文件)')
    args = parser.parse_args(argv)

    with open(args.config, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    # An empty YAML file loads as None; fail with a clear message instead of
    # an opaque TypeError on the first subscript below.
    if not isinstance(config, dict):
        raise ValueError(f"配置文件为空或不是有效的映射结构: {args.config}")

    # Command-line values take precedence over the config file.
    if args.epochs is not None:
        config['training']['epochs'] = args.epochs
    if args.batch_size is not None:
        config['training']['batch_size'] = args.batch_size
    if args.use_gpu:
        # The flag can only force GPU usage on; when it is absent the value
        # from the config file is kept.
        config['training']['use_gpu'] = True

    print("开始BERT模型训练...")
    # NOTE(review): this echoes the whole config, including database.url -
    # confirm that the URL never carries credentials.
    print(f"配置: {config}")

    train_bert_model(config)

    # NOTE(review): train_bert_model returns None both after training and
    # when no data was loaded, so this message is printed in either case.
    print("训练完成!")
|
|
|
|
# Run the training pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()