news-classifier/ml-module/train_pipeline.py

279 lines
9.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
混合策略完整训练流程
整合基线模型和混合分类器的训练和评估
这是毕业设计的核心入口脚本,实现了参考方案中的完整流程:
1. 阶段一:基线模型 (TF-IDF + NB/SVM)
2. 阶段二:混合策略 (BERT特征提取 + SVM/LR)
3. 阶段三:可视化分析和对比
"""
import os
import sys
import argparse
import logging
import yaml
import pandas as pd
from typing import Dict, Any
# Add the project directory (the directory containing this script) to the front
# of the import path so the `src.*` packages imported below resolve regardless
# of the current working directory. This must run before those imports.
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, ROOT_DIR)
from src.baseline.train import BaselineTrainer
from src.hybrid.classifier import HybridClassifier
from src.utils.visualizer import Visualizer
# Logging: every record at INFO and above is written both to `training.log`
# (created in ROOT_DIR, i.e. next to this script; UTF-8 so the Chinese log
# messages are stored intact) and to the console through a stream handler.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
_log_handlers = [
    logging.FileHandler(os.path.join(ROOT_DIR, 'training.log'), encoding='utf-8'),
    logging.StreamHandler(),
]
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, handlers=_log_handlers)
logger = logging.getLogger(__name__)
def load_config(config_path: str) -> Dict[str, Any]:
    """Load a YAML configuration file and return its top-level mapping.

    The file is read as UTF-8 and parsed with ``yaml.safe_load``, which does
    not construct arbitrary Python objects.

    Args:
        config_path: Path to the YAML file.

    Returns:
        The parsed top-level mapping. An empty or comment-only file, which
        ``safe_load`` parses to ``None``, yields an empty dict so callers can
        keep using ``.get`` on the result.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the file is not valid YAML.
        ValueError: If the top-level YAML value is not a mapping.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        data = yaml.safe_load(f)
    if data is None:
        # Empty document: treat as "no settings" instead of returning None,
        # which would make every later `config.get(...)` fail.
        return {}
    if not isinstance(data, dict):
        raise ValueError(
            f"Config file {config_path!r} must contain a YAML mapping at the top level, "
            f"got {type(data).__name__}"
        )
    return data
def train_baseline(config: Dict[str, Any]) -> Dict[str, Any]:
    """Train the baseline model (TF-IDF + traditional classifier).

    Stage one of the pipeline: the baseline's scores serve as the control
    group against which the hybrid model is compared.

    Args:
        config: Full pipeline config. Only the ``baseline`` section is read
            (``model_type``, ``data_path``, ``test_size``). A missing section,
            or one that is present but empty, falls back to the defaults below.

    Returns:
        ``{'baseline': {...}}`` holding the model type plus the accuracy, F1
        score, confusion matrix, test labels/predictions, test features and
        classification report taken from ``BaselineTrainer.train``'s result.
    """
    logger.info("=" * 60)
    logger.info("阶段一:训练基线模型 (TF-IDF + 传统分类器)")
    logger.info("=" * 60)

    # `or {}` also covers a YAML section written as a bare `baseline:` key,
    # which safe_load parses to None (the `.get` default would not apply).
    baseline_config = config.get('baseline') or {}
    model_type = baseline_config.get('model_type', 'nb')
    # NOTE(review): the path is handed to the trainer as-is, so a relative path
    # is presumably resolved against the current working directory, not ROOT_DIR.
    data_path = baseline_config.get('data_path', './data/processed/news_data.csv')
    test_size = baseline_config.get('test_size', 0.2)

    # Initialise the trainer, load the data, train, then persist the model.
    trainer = BaselineTrainer(model_type=model_type)
    df = trainer.load_data(data_path)
    results = trainer.train(df, test_size=test_size)
    trainer.save_model()

    logger.info(f"基线模型训练完成 - 准确率: {results['accuracy']:.4f}, F1: {results['f1_score']:.4f}")
    return {
        'baseline': {
            'model_type': model_type,
            'accuracy': results['accuracy'],
            'f1_score': results['f1_score'],
            'confusion_matrix': results['confusion_matrix'],
            'y_test': results['y_test'],
            'y_pred': results['y_pred'],
            'X_test': results['X_test'],
            'classification_report': results['classification_report'],
        }
    }
def train_hybrid(config: Dict[str, Any], use_cached_features: bool = True) -> Dict[str, Any]:
    """Train the hybrid classifier (BERT feature extraction + SVM/LR).

    Stage two of the pipeline.

    Args:
        config: Full pipeline config. Only the ``hybrid`` section is read
            (``classifier_type``, ``bert_model``, ``data_path``, ``test_size``,
            ``C``). A missing section, or one that is present but empty, falls
            back to the defaults below.
        use_cached_features: Forwarded to ``HybridClassifier.load_data``;
            presumably False forces the BERT features to be re-extracted
            (this is what the ``--reextract-features`` CLI flag controls).

    Returns:
        ``{'hybrid': {...}}`` with a human-readable model label, accuracy, F1,
        confusion matrix, test labels/predictions/features, classification
        report and the category ids/names from ``HybridClassifier.train``.
    """
    logger.info("=" * 60)
    logger.info("阶段二:训练混合分类器 (BERT特征 + SVM/LR)")
    logger.info("=" * 60)

    # `or {}` also covers a bare `hybrid:` key, which safe_load parses to None.
    hybrid_config = config.get('hybrid') or {}
    classifier_type = hybrid_config.get('classifier_type', 'svm')
    bert_model = hybrid_config.get('bert_model', 'hfl/chinese-roberta-wwm-ext')
    # NOTE(review): relative paths are presumably resolved against the CWD.
    data_path = hybrid_config.get('data_path', './data/processed/news_data.csv')

    classifier = HybridClassifier(
        classifier_type=classifier_type,
        bert_model_name=bert_model,
    )

    # Load the data and extract (or reuse cached) BERT features. The third
    # return value (the raw DataFrame) is not needed here.
    features, labels, _ = classifier.load_data(
        data_path,
        use_cached_features=use_cached_features,
    )

    results = classifier.train(
        features,
        labels,
        test_size=hybrid_config.get('test_size', 0.2),
        # presumably the SVM/LR regularisation parameter — confirm in HybridClassifier
        C=hybrid_config.get('C', 1.0),
    )
    classifier.save_model()

    logger.info(f"混合分类器训练完成 - 准确率: {results['accuracy']:.4f}, F1: {results['f1_score']:.4f}")
    return {
        'hybrid': {
            'model_type': f'{bert_model} + {classifier_type.upper()}',
            'accuracy': results['accuracy'],
            'f1_score': results['f1_score'],
            'confusion_matrix': results['confusion_matrix'],
            'y_test': results['y_test'],
            'y_pred': results['y_pred'],
            'X_test': results['X_test'],
            'classification_report': results['classification_report'],
            'categories': results['categories'],
            'category_names': results['category_names'],
        }
    }
def _tsne_perplexity(requested, n_samples: int):
    """Return a t-SNE perplexity valid for ``n_samples`` points, or None.

    scikit-learn's ``TSNE`` rejects ``perplexity >= n_samples``; this assumes
    ``Visualizer.plot_tsne`` is backed by it (TODO confirm). The requested
    value is kept when valid, otherwise capped at ``n_samples - 1``. Returns
    None when there are fewer than two samples (no valid perplexity exists).
    """
    if n_samples < 2:
        return None
    if requested < n_samples:
        return requested
    return n_samples - 1


def visualize_and_compare(all_results: Dict[str, Dict], config: Dict[str, Any]):
    """Generate the analysis plots: confusion matrices, t-SNE and a model comparison.

    Stage three of the pipeline. Each plot is produced only when the results it
    needs are present in ``all_results``.

    Args:
        all_results: Mapping with optional ``'baseline'`` / ``'hybrid'`` entries,
            as returned by ``train_baseline`` / ``train_hybrid``.
        config: Full pipeline config; only the ``visualization`` section is read
            (``output_dir``, ``show``, ``tsne_perplexity``). A missing or empty
            section falls back to the defaults below.
    """
    logger.info("=" * 60)
    logger.info("阶段三:生成可视化分析图表")
    logger.info("=" * 60)

    # `or {}` also covers a bare `visualization:` key, which safe_load parses to None.
    viz_config = config.get('visualization') or {}
    viz = Visualizer(output_dir=viz_config.get('output_dir', './outputs/visualizations'))
    show = viz_config.get('show', False)
    baseline_results = all_results.get('baseline')
    hybrid_results = all_results.get('hybrid')

    # 1. Confusion matrices
    if baseline_results:
        logger.info("生成基线模型混淆矩阵...")
        viz.plot_confusion_matrix(
            baseline_results['y_test'],
            baseline_results['y_pred'],
            categories=None,  # baseline results carry no category list; should come from the data
            model_name='Baseline',
            show=show,
        )
    if hybrid_results:
        logger.info("生成混合分类器混淆矩阵...")
        viz.plot_confusion_matrix(
            hybrid_results['y_test'],
            hybrid_results['y_pred'],
            categories=hybrid_results['categories'],
            category_names=hybrid_results['category_names'],
            model_name='Hybrid',
            show=show,
        )

    # 2. t-SNE projection (hybrid model only)
    if hybrid_results and hybrid_results['X_test'] is not None:
        x_test = hybrid_results['X_test']
        # `.shape[0]` also works for array types where len() is ambiguous.
        n_samples = x_test.shape[0] if hasattr(x_test, 'shape') else len(x_test)
        requested = viz_config.get('tsne_perplexity', 30)
        perplexity = _tsne_perplexity(requested, n_samples)
        if perplexity is None:
            logger.warning("测试样本数 (%d) 不足,跳过t-SNE可视化", n_samples)
        else:
            if perplexity != requested:
                logger.warning(
                    "t-SNE perplexity %s 不小于样本数 %d,已调整为 %s",
                    requested, n_samples, perplexity,
                )
            logger.info("生成t-SNE降维可视化...")
            viz.plot_tsne(
                x_test,
                hybrid_results['y_test'],
                categories=hybrid_results['categories'],
                category_names=hybrid_results['category_names'],
                model_name='BERT_Features',
                perplexity=perplexity,
                show=show,
            )

    # 3. Performance comparison (needs both models)
    if baseline_results and hybrid_results:
        logger.info("生成模型性能对比图...")
        comparison_results = {
            'Baseline (TF-IDF+NB/SVM)': {
                'accuracy': baseline_results['accuracy'],
                'f1_score': baseline_results['f1_score'],
            },
            'Hybrid (BERT+SVM/LR)': {
                'accuracy': hybrid_results['accuracy'],
                'f1_score': hybrid_results['f1_score'],
            },
        }
        viz.plot_model_comparison(comparison_results, show=show)

    logger.info(f"所有可视化图表已保存到: {viz.output_dir}")
def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for the training pipeline."""
    parser = argparse.ArgumentParser(description='混合策略新闻分类器训练流程')
    parser.add_argument('--config', type=str, default='config_hybrid.yaml', help='配置文件路径')
    # Each "-only" flag disables the other stage, so passing both used to
    # train nothing at all; argparse now rejects the combination.
    only_group = parser.add_mutually_exclusive_group()
    only_group.add_argument('--baseline-only', action='store_true', help='仅训练基线模型')
    only_group.add_argument('--hybrid-only', action='store_true', help='仅训练混合分类器')
    parser.add_argument('--reextract-features', action='store_true', help='重新提取BERT特征')
    parser.add_argument('--skip-baseline', action='store_true', help='跳过基线模型训练')
    return parser


def _parse_args(argv=None) -> argparse.Namespace:
    """Parse and validate CLI arguments; ``argv=None`` reads ``sys.argv``.

    Exits via ``parser.error`` (SystemExit) on flag combinations that would
    run no training stage.
    """
    parser = _build_arg_parser()
    args = parser.parse_args(argv)
    if args.baseline_only and args.skip_baseline:
        # --baseline-only disables the hybrid stage and --skip-baseline the
        # baseline stage: together they would run nothing.
        parser.error('--baseline-only 与 --skip-baseline 不能同时使用')
    return args


def _default_config() -> Dict[str, Any]:
    """Return a fresh copy of the built-in config used when no config file exists."""
    return {
        'baseline': {
            'model_type': 'svm',
            'data_path': './data/processed/news_data.csv',
            'test_size': 0.2,
        },
        'hybrid': {
            'classifier_type': 'svm',
            'bert_model': 'hfl/chinese-roberta-wwm-ext',
            'data_path': './data/processed/news_data.csv',
            'test_size': 0.2,
            'C': 1.0,
        },
        'visualization': {
            'output_dir': './outputs/visualizations',
            'show': False,
            'tsne_perplexity': 30,
        },
    }


def _load_pipeline_config(config_arg: str) -> Dict[str, Any]:
    """Load the config file named on the command line.

    ``config_arg`` is joined onto ROOT_DIR (``os.path.join`` keeps an absolute
    path as-is). A missing file logs a warning and falls back to
    ``_default_config()``.
    """
    config_path = os.path.join(ROOT_DIR, config_arg)
    if not os.path.exists(config_path):
        logger.warning(f"配置文件不存在: {config_arg},使用默认配置")
        return _default_config()
    # An empty YAML file parses to None; normalise so the stages can call .get.
    return load_config(config_path) or {}


def _log_summary(all_results: Dict[str, Any]) -> None:
    """Log accuracy and F1 for every model that was trained."""
    logger.info("=" * 60)
    logger.info("训练完成!性能总结")
    logger.info("=" * 60)
    for model_name, results in all_results.items():
        logger.info(f"{model_name.upper()}:")
        logger.info(f" 准确率: {results['accuracy']:.4f}")
        logger.info(f" F1-Score: {results['f1_score']:.4f}")
    logger.info("=" * 60)


def main():
    """Command-line entry point.

    Parses arguments, loads the config, trains the selected stages, produces
    the plots and logs a performance summary.
    """
    args = _parse_args()
    config = _load_pipeline_config(args.config)

    all_results: Dict[str, Any] = {}
    if not args.hybrid_only and not args.skip_baseline:
        all_results.update(train_baseline(config))
    if not args.baseline_only:
        all_results.update(
            train_hybrid(config, use_cached_features=not args.reextract_features)
        )

    # Plots are skipped in the "-only" modes; with --skip-baseline only the
    # hybrid plots are produced (the comparison chart needs both models).
    if not args.baseline_only and not args.hybrid_only:
        visualize_and_compare(all_results, config)

    _log_summary(all_results)


if __name__ == '__main__':
    main()