""" BERT模型训练脚本 支持GPU/CPU自动检测和动态参数调整 """ import argparse import yaml import os from src.deep_learning.bert_model import BertClassifier from src.utils.data_loader import DataLoader from datasets import Dataset import pandas as pd from transformers import BertTokenizer, DataCollatorWithPadding def load_and_prepare_data(db_url: str, limit: int = 1000): """加载数据并准备训练集""" # 创建数据加载器 loader = DataLoader(db_url) # 从数据库加载数据 df = loader.load_news_from_db(limit=limit) if df.empty: print("没有可用的训练数据") return None, None # 准备数据 texts = df['title'] + ' ' + df['content'] labels = df['category_name'].map({ '娱乐': 0, '体育': 1, '财经': 2, '科技': 3, '军事': 4, '汽车': 5, '政务': 6, '健康': 7, 'AI': 8 }).fillna(-1).astype(int) # 过滤无效数据 valid_data = df[labels != -1] texts = valid_data['title'] + ' ' + valid_data['content'] labels = labels[labels != -1] print(f"有效数据数量: {len(valid_data)}") # 创建Hugging Face Dataset dataset = Dataset.from_pandas(pd.DataFrame({ 'text': texts.tolist(), 'label': labels.tolist() })) # 划分训练集和验证集 train_test = dataset.train_test_split(test_size=0.2) return train_test['train'], train_test['test'] def train_bert_model(config: dict): """训练BERT模型""" # 加载数据 train_dataset, eval_dataset = load_and_prepare_data( config['database']['url'], limit=config['training']['data_limit'] ) if train_dataset is None: return # 初始化模型 classifier = BertClassifier( model_name=config['model']['name'], num_labels=config['model']['num_labels'], use_gpu=config['training']['use_gpu'] ) # 训练模型 classifier.train_model( train_dataset=train_dataset, eval_dataset=eval_dataset, output_dir=config['model']['output_dir'], num_train_epochs=config['training']['epochs'], per_device_train_batch_size=config['training']['batch_size'], per_device_eval_batch_size=config['training']['batch_size'] * 2, learning_rate=config['training']['learning_rate'], warmup_steps=config['training']['warmup_steps'], weight_decay=config['training']['weight_decay'], fp16=config['training'].get('fp16', None) ) def main(): # 解析命令行参数 parser = argparse.ArgumentParser(description='BERT模型训练脚本') parser.add_argument('--config', type=str, default='config.yaml', help='配置文件路径') parser.add_argument('--use_gpu', action='store_true', help='强制使用GPU') parser.add_argument('--epochs', type=int, help='训练轮数(覆盖配置文件)') parser.add_argument('--batch_size', type=int, help='批大小(覆盖配置文件)') args = parser.parse_args() # 加载配置文件 with open(args.config, 'r', encoding='utf-8') as f: config = yaml.safe_load(f) # 覆盖配置参数 if args.epochs is not None: config['training']['epochs'] = args.epochs if args.batch_size is not None: config['training']['batch_size'] = args.batch_size if args.use_gpu: config['training']['use_gpu'] = True # 开始训练 print("开始BERT模型训练...") print(f"配置: {config}") train_bert_model(config) print("训练完成!") if __name__ == '__main__': main()