"""End-to-end training pipeline for the hybrid strategy.

Combines training and evaluation of the baseline model and the hybrid
classifier. This is the core entry script of the graduation project and
implements the complete flow of the reference design:

1. Stage 1: baseline model (TF-IDF + NB/SVM)
2. Stage 2: hybrid strategy (BERT feature extraction + SVM/LR)
3. Stage 3: visual analysis and comparison
"""

import os
import sys
import argparse
import logging
import yaml
import pandas as pd
from typing import Dict, Any

# Put the directory containing this script on sys.path so that the
# ``src.*`` imports below resolve regardless of the current working directory.
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, ROOT_DIR)

from src.baseline.train import BaselineTrainer
from src.hybrid.classifier import HybridClassifier
from src.utils.visualizer import Visualizer

# Configure logging: write to training.log next to this script and to the console.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(os.path.join(ROOT_DIR, 'training.log'), encoding='utf-8'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)


def load_config(config_path: str) -> Dict[str, Any]:
    """Load the YAML configuration file.

    Args:
        config_path: Path of the YAML file to read (opened as UTF-8).

    Returns:
        The parsed top-level mapping. An empty file yields an empty dict
        instead of ``None`` so callers can always use ``config.get(...)``.

    Raises:
        ValueError: If the top-level YAML node is not a mapping.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    # yaml.safe_load returns None for an empty document.
    if config is None:
        return {}
    if not isinstance(config, dict):
        raise ValueError(f"配置文件顶层必须是映射(字典): {config_path}")
    return config


def _get_section(config: Dict[str, Any], name: str) -> Dict[str, Any]:
    """Return config[name], or an empty dict when it is missing or empty."""
    # A YAML key with no value (e.g. "baseline:") loads as None; ``or {}``
    # covers that case as well as a missing key.
    return config.get(name) or {}


def train_baseline(config: Dict[str, Any]) -> Dict[str, Any]:
    """Train the baseline model (stage 1).

    Serves as the control group showing how the traditional approach performs.

    Args:
        config: Full configuration; the ``baseline`` section is read for
            ``model_type`` (default ``'nb'``), ``data_path`` and ``test_size``
            (default ``0.2``).

    Returns:
        ``{'baseline': {...}}`` holding the model type and the metrics and
        test-split entries copied from the trainer's result dict.
    """
    logger.info("=" * 60)
    logger.info("阶段一:训练基线模型 (TF-IDF + 传统分类器)")
    logger.info("=" * 60)

    baseline_config = _get_section(config, 'baseline')
    model_type = baseline_config.get('model_type', 'nb')
    data_path = baseline_config.get('data_path', './data/processed/news_data.csv')

    # Initialise the trainer
    trainer = BaselineTrainer(model_type=model_type)

    # Load the data
    df = trainer.load_data(data_path)

    # Train
    results = trainer.train(df, test_size=baseline_config.get('test_size', 0.2))

    # Persist the model
    trainer.save_model()

    logger.info(f"基线模型训练完成 - 准确率: {results['accuracy']:.4f}, F1: {results['f1_score']:.4f}")

    return {
        'baseline': {
            'model_type': model_type,
            'accuracy': results['accuracy'],
            'f1_score': results['f1_score'],
            'confusion_matrix': results['confusion_matrix'],
            'y_test': results['y_test'],
            'y_pred': results['y_pred'],
            'X_test': results['X_test'],
            'classification_report': results['classification_report']
        }
    }


def train_hybrid(config: Dict[str, Any], use_cached_features: bool = True) -> Dict[str, Any]:
    """Train the hybrid classifier (stage 2): BERT feature extraction + SVM/LR.

    Args:
        config: Full configuration; the ``hybrid`` section is read for
            ``classifier_type`` (default ``'svm'``), ``bert_model``,
            ``data_path``, ``test_size`` (default ``0.2``) and ``C``
            (default ``1.0``).
        use_cached_features: Forwarded to ``HybridClassifier.load_data``.

    Returns:
        ``{'hybrid': {...}}`` holding a model description string and the
        metrics, test-split and category entries copied from the
        classifier's result dict.
    """
    logger.info("=" * 60)
    logger.info("阶段二:训练混合分类器 (BERT特征 + SVM/LR)")
    logger.info("=" * 60)

    hybrid_config = _get_section(config, 'hybrid')
    classifier_type = hybrid_config.get('classifier_type', 'svm')
    bert_model = hybrid_config.get('bert_model', 'hfl/chinese-roberta-wwm-ext')
    data_path = hybrid_config.get('data_path', './data/processed/news_data.csv')

    # Initialise the hybrid classifier
    classifier = HybridClassifier(
        classifier_type=classifier_type,
        bert_model_name=bert_model
    )

    # Load the data and extract BERT features
    features, labels, df = classifier.load_data(
        data_path,
        use_cached_features=use_cached_features
    )

    # Train
    results = classifier.train(
        features,
        labels,
        test_size=hybrid_config.get('test_size', 0.2),
        C=hybrid_config.get('C', 1.0)
    )

    # Persist the model
    classifier.save_model()

    logger.info(f"混合分类器训练完成 - 准确率: {results['accuracy']:.4f}, F1: {results['f1_score']:.4f}")

    return {
        'hybrid': {
            'model_type': f'{bert_model} + {classifier_type.upper()}',
            'accuracy': results['accuracy'],
            'f1_score': results['f1_score'],
            'confusion_matrix': results['confusion_matrix'],
            'y_test': results['y_test'],
            'y_pred': results['y_pred'],
            'X_test': results['X_test'],
            'classification_report': results['classification_report'],
            'categories': results['categories'],
            'category_names': results['category_names']
        }
    }


def visualize_and_compare(all_results: Dict[str, Dict], config: Dict[str, Any]):
    """Generate the visualisation charts (stage 3).

    Produces confusion matrices, a t-SNE plot and a performance comparison,
    each only when the results it needs are present in ``all_results``.

    Args:
        all_results: Mapping that may contain ``'baseline'`` and/or
            ``'hybrid'`` result dicts as returned by the training functions.
        config: Full configuration; the ``visualization`` section is read
            for ``output_dir``, ``show`` and ``tsne_perplexity``.
    """
    logger.info("=" * 60)
    logger.info("阶段三:生成可视化分析图表")
    logger.info("=" * 60)

    viz_config = _get_section(config, 'visualization')
    viz = Visualizer(output_dir=viz_config.get('output_dir', './outputs/visualizations'))

    baseline_results = all_results.get('baseline')
    hybrid_results = all_results.get('hybrid')

    # 1. Confusion matrices
    if baseline_results:
        logger.info("生成基线模型混淆矩阵...")
        viz.plot_confusion_matrix(
            baseline_results['y_test'],
            baseline_results['y_pred'],
            categories=None,  # not in the baseline results; would have to come from the data
            model_name='Baseline',
            show=viz_config.get('show', False)
        )

    if hybrid_results:
        logger.info("生成混合分类器混淆矩阵...")
        viz.plot_confusion_matrix(
            hybrid_results['y_test'],
            hybrid_results['y_pred'],
            categories=hybrid_results['categories'],
            category_names=hybrid_results['category_names'],
            model_name='Hybrid',
            show=viz_config.get('show', False)
        )

    # 2. t-SNE dimensionality-reduction plot (hybrid model only)
    if hybrid_results and hybrid_results['X_test'] is not None:
        logger.info("生成t-SNE降维可视化...")
        viz.plot_tsne(
            hybrid_results['X_test'],
            hybrid_results['y_test'],
            categories=hybrid_results['categories'],
            category_names=hybrid_results['category_names'],
            model_name='BERT_Features',
            perplexity=viz_config.get('tsne_perplexity', 30),
            show=viz_config.get('show', False)
        )

    # 3. Model performance comparison (needs both result sets)
    if baseline_results and hybrid_results:
        logger.info("生成模型性能对比图...")
        comparison_results = {
            'Baseline (TF-IDF+NB/SVM)': {
                'accuracy': baseline_results['accuracy'],
                'f1_score': baseline_results['f1_score']
            },
            'Hybrid (BERT+SVM/LR)': {
                'accuracy': hybrid_results['accuracy'],
                'f1_score': hybrid_results['f1_score']
            }
        }
        viz.plot_model_comparison(
            comparison_results,
            show=viz_config.get('show', False)
        )

    logger.info(f"所有可视化图表已保存到: {viz.output_dir}")


def main():
    """Parse command-line options and run the selected training stages.

    The baseline stage runs unless ``--hybrid-only`` or ``--skip-baseline``
    is given; the hybrid stage runs unless ``--baseline-only`` is given.
    Visualisation runs only when neither ``--baseline-only`` nor
    ``--hybrid-only`` is set. Flag combinations that would leave no stage
    to run are rejected with a usage error.
    """
    parser = argparse.ArgumentParser(description='混合策略新闻分类器训练流程')
    parser.add_argument('--config', type=str, default='config_hybrid.yaml',
                        help='配置文件路径')
    parser.add_argument('--baseline-only', action='store_true',
                        help='仅训练基线模型')
    parser.add_argument('--hybrid-only', action='store_true',
                        help='仅训练混合分类器')
    parser.add_argument('--reextract-features', action='store_true',
                        help='重新提取BERT特征')
    parser.add_argument('--skip-baseline', action='store_true',
                        help='跳过基线模型训练')

    args = parser.parse_args()

    run_baseline = not args.hybrid_only and not args.skip_baseline
    run_hybrid = not args.baseline_only

    # --baseline-only together with --hybrid-only or --skip-baseline disables
    # both stages; fail fast instead of reporting an empty "training complete".
    if not run_baseline and not run_hybrid:
        parser.error('参数冲突: --baseline-only 不能与 --hybrid-only 或 --skip-baseline 同时使用')

    # Load the configuration (path is resolved relative to this script's directory)
    config_path = os.path.join(ROOT_DIR, args.config)
    if not os.path.exists(config_path):
        # Fall back to the built-in default configuration
        logger.warning(f"配置文件不存在: {args.config},使用默认配置")
        config = {
            'baseline': {
                'model_type': 'svm',
                'data_path': './data/processed/news_data.csv',
                'test_size': 0.2
            },
            'hybrid': {
                'classifier_type': 'svm',
                'bert_model': 'hfl/chinese-roberta-wwm-ext',
                'data_path': './data/processed/news_data.csv',
                'test_size': 0.2,
                'C': 1.0
            },
            'visualization': {
                'output_dir': './outputs/visualizations',
                'show': False,
                'tsne_perplexity': 30
            }
        }
    else:
        config = load_config(config_path)

    all_results = {}

    # Training stages
    if run_baseline:
        baseline_results = train_baseline(config)
        all_results.update(baseline_results)

    if run_hybrid:
        hybrid_results = train_hybrid(config, use_cached_features=not args.reextract_features)
        all_results.update(hybrid_results)

    # Generate visualisations
    if not args.baseline_only and not args.hybrid_only:
        visualize_and_compare(all_results, config)

    # Print the summary
    logger.info("=" * 60)
    logger.info("训练完成!性能总结")
    logger.info("=" * 60)

    for model_name, results in all_results.items():
        logger.info(f"{model_name.upper()}:")
        logger.info(f" 准确率: {results['accuracy']:.4f}")
        logger.info(f" F1-Score: {results['f1_score']:.4f}")

    logger.info("=" * 60)


if __name__ == '__main__':
    main()