279 lines
9.0 KiB
Python
279 lines
9.0 KiB
Python
"""
混合策略完整训练流程

整合基线模型和混合分类器的训练和评估。

这是毕业设计的核心入口脚本,实现了参考方案中的完整流程:

1. 阶段一:基线模型 (TF-IDF + NB/SVM)
2. 阶段二:混合策略 (BERT特征提取 + SVM/LR)
3. 阶段三:可视化分析和对比
"""
|
||
|
||
import os
|
||
import sys
|
||
import argparse
|
||
import logging
|
||
import yaml
|
||
import pandas as pd
|
||
from typing import Dict, Any
|
||
|
||
# Make the directory containing this script importable (placed first on
# sys.path) so the project-local `src.*` packages imported below resolve no
# matter which working directory the script is launched from.
ROOT_DIR = os.path.split(os.path.abspath(__file__))[0]
sys.path.insert(0, ROOT_DIR)
|
||
|
||
from src.baseline.train import BaselineTrainer
|
||
from src.hybrid.classifier import HybridClassifier
|
||
from src.utils.visualizer import Visualizer
|
||
|
||
# Logging setup: records at INFO and above are sent both to
# <ROOT_DIR>/training.log (opened as UTF-8 so the Chinese log messages are
# stored intact) and to a console StreamHandler.
_log_handlers = [
    logging.FileHandler(os.path.join(ROOT_DIR, 'training.log'), encoding='utf-8'),
    logging.StreamHandler(),
]
logging.basicConfig(
    handlers=_log_handlers,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger used by every pipeline stage below.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
def load_config(config_path: str) -> Dict[str, Any]:
    """Load the pipeline configuration from a YAML file.

    Args:
        config_path: Path to the YAML file; it is read as UTF-8 text.

    Returns:
        The parsed top-level mapping. An empty file (for which
        ``yaml.safe_load`` returns ``None``) yields ``{}`` so callers can
        always use ``config.get(...)``.

    Raises:
        ValueError: If the top level of the document is not a mapping.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        # safe_load only builds plain YAML types, never arbitrary Python objects.
        config = yaml.safe_load(f)

    if config is None:
        return {}
    if not isinstance(config, dict):
        raise ValueError(
            f"Config file {config_path!r} must contain a mapping at the top "
            f"level, got {type(config).__name__}"
        )
    return config
|
||
|
||
|
||
def train_baseline(config: Dict[str, Any]) -> Dict[str, Any]:
    """
    Train the baseline model (TF-IDF + traditional classifier).

    Its scores serve as the control group for the hybrid classifier.

    Args:
        config: Full pipeline configuration; only the ``baseline`` section is
            read. Keys used (defaults in parentheses): ``model_type``
            (``'nb'``), ``data_path`` (``'./data/processed/news_data.csv'``)
            and ``test_size`` (0.2). A missing or empty section falls back to
            these defaults.

    Returns:
        ``{'baseline': summary}``, where ``summary`` holds ``model_type`` plus
        the ``accuracy``, ``f1_score``, ``confusion_matrix``, ``y_test``,
        ``y_pred``, ``X_test`` and ``classification_report`` entries copied
        from the result of ``BaselineTrainer.train``.
    """
    logger.info("=" * 60)
    logger.info("阶段一:训练基线模型 (TF-IDF + 传统分类器)")
    logger.info("=" * 60)

    # A YAML section written as `baseline:` with no keys parses to None, on
    # which `.get` would fail; treat it the same as a missing section.
    baseline_config = config.get('baseline') or {}
    model_type = baseline_config.get('model_type', 'nb')
    # NOTE(review): relative paths are passed through unchanged to the trainer,
    # presumably resolved against the current working directory — confirm.
    data_path = baseline_config.get('data_path', './data/processed/news_data.csv')

    trainer = BaselineTrainer(model_type=model_type)
    df = trainer.load_data(data_path)
    results = trainer.train(df, test_size=baseline_config.get('test_size', 0.2))
    trainer.save_model()

    logger.info(f"基线模型训练完成 - 准确率: {results['accuracy']:.4f}, F1: {results['f1_score']:.4f}")

    # Entries forwarded to the visualisation / summary stages.
    copied_keys = (
        'accuracy',
        'f1_score',
        'confusion_matrix',
        'y_test',
        'y_pred',
        'X_test',
        'classification_report',
    )
    summary = {'model_type': model_type}
    summary.update({key: results[key] for key in copied_keys})
    return {'baseline': summary}
|
||
|
||
|
||
def train_hybrid(config: Dict[str, Any], use_cached_features: bool = True) -> Dict[str, Any]:
    """
    Train the hybrid classifier: BERT feature extraction + SVM/LR.

    Args:
        config: Full pipeline configuration; only the ``hybrid`` section is
            read. Keys used (defaults in parentheses): ``classifier_type``
            (``'svm'``), ``bert_model`` (``'hfl/chinese-roberta-wwm-ext'``),
            ``data_path`` (``'./data/processed/news_data.csv'``),
            ``test_size`` (0.2) and ``C`` (1.0, forwarded to
            ``HybridClassifier.train``). A missing or empty section falls back
            to these defaults.
        use_cached_features: Forwarded to ``HybridClassifier.load_data``;
            pass False to force the BERT features to be extracted again.

    Returns:
        ``{'hybrid': summary}``, where ``summary`` holds a ``model_type``
        description (``"<bert_model> + <CLASSIFIER_TYPE>"``) plus the
        ``accuracy``, ``f1_score``, ``confusion_matrix``, ``y_test``,
        ``y_pred``, ``X_test``, ``classification_report``, ``categories`` and
        ``category_names`` entries copied from ``HybridClassifier.train``.
    """
    logger.info("=" * 60)
    logger.info("阶段二:训练混合分类器 (BERT特征 + SVM/LR)")
    logger.info("=" * 60)

    # A YAML section written as `hybrid:` with no keys parses to None, on
    # which `.get` would fail; treat it the same as a missing section.
    hybrid_config = config.get('hybrid') or {}
    classifier_type = hybrid_config.get('classifier_type', 'svm')
    bert_model = hybrid_config.get('bert_model', 'hfl/chinese-roberta-wwm-ext')
    data_path = hybrid_config.get('data_path', './data/processed/news_data.csv')

    classifier = HybridClassifier(
        classifier_type=classifier_type,
        bert_model_name=bert_model,
    )

    # Load the data and obtain BERT features (possibly from a cache).
    features, labels, df = classifier.load_data(
        data_path,
        use_cached_features=use_cached_features,
    )

    results = classifier.train(
        features,
        labels,
        test_size=hybrid_config.get('test_size', 0.2),
        C=hybrid_config.get('C', 1.0),
    )
    classifier.save_model()

    logger.info(f"混合分类器训练完成 - 准确率: {results['accuracy']:.4f}, F1: {results['f1_score']:.4f}")

    # Entries forwarded to the visualisation / summary stages.
    copied_keys = (
        'accuracy',
        'f1_score',
        'confusion_matrix',
        'y_test',
        'y_pred',
        'X_test',
        'classification_report',
        'categories',
        'category_names',
    )
    summary = {'model_type': f'{bert_model} + {classifier_type.upper()}'}
    summary.update({key: results[key] for key in copied_keys})
    return {'hybrid': summary}
|
||
|
||
|
||
def visualize_and_compare(all_results: Dict[str, Dict], config: Dict[str, Any]):
    """
    Generate the visualisation charts: confusion matrices, t-SNE and a
    baseline-vs-hybrid performance comparison.

    Args:
        all_results: Merged outputs of ``train_baseline`` / ``train_hybrid``,
            keyed by ``'baseline'`` and/or ``'hybrid'``. Charts that need a
            missing entry are skipped.
        config: Full pipeline configuration; only the ``visualization``
            section is read (``output_dir``, ``show``, ``tsne_perplexity``).
            A missing or empty section falls back to the defaults used below.
    """
    logger.info("=" * 60)
    logger.info("阶段三:生成可视化分析图表")
    logger.info("=" * 60)

    # A YAML section written as `visualization:` with no keys parses to None.
    viz_config = config.get('visualization') or {}
    show = viz_config.get('show', False)
    viz = Visualizer(output_dir=viz_config.get('output_dir', './outputs/visualizations'))

    baseline_results = all_results.get('baseline')
    hybrid_results = all_results.get('hybrid')

    # 1. Confusion matrices
    if baseline_results:
        logger.info("生成基线模型混淆矩阵...")
        viz.plot_confusion_matrix(
            baseline_results['y_test'],
            baseline_results['y_pred'],
            # TODO: the baseline summary carries no category list yet, so the
            # plot falls back to whatever Visualizer does for None.
            categories=None,
            model_name='Baseline',
            show=show,
        )

    if hybrid_results:
        logger.info("生成混合分类器混淆矩阵...")
        viz.plot_confusion_matrix(
            hybrid_results['y_test'],
            hybrid_results['y_pred'],
            categories=hybrid_results['categories'],
            category_names=hybrid_results['category_names'],
            model_name='Hybrid',
            show=show,
        )

    # 2. t-SNE projection of the BERT features (hybrid model only)
    if hybrid_results and hybrid_results.get('X_test') is not None:
        logger.info("生成t-SNE降维可视化...")
        viz.plot_tsne(
            hybrid_results['X_test'],
            hybrid_results['y_test'],
            categories=hybrid_results['categories'],
            category_names=hybrid_results['category_names'],
            model_name='BERT_Features',
            perplexity=viz_config.get('tsne_perplexity', 30),
            show=show,
        )

    # 3. Performance comparison (needs both models)
    if baseline_results and hybrid_results:
        logger.info("生成模型性能对比图...")
        comparison_results = {
            'Baseline (TF-IDF+NB/SVM)': {
                'accuracy': baseline_results['accuracy'],
                'f1_score': baseline_results['f1_score'],
            },
            'Hybrid (BERT+SVM/LR)': {
                'accuracy': hybrid_results['accuracy'],
                'f1_score': hybrid_results['f1_score'],
            },
        }
        viz.plot_model_comparison(
            comparison_results,
            show=show,
        )

    logger.info(f"所有可视化图表已保存到: {viz.output_dir}")
|
||
|
||
|
||
def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the training pipeline."""
    parser = argparse.ArgumentParser(description='混合策略新闻分类器训练流程')
    parser.add_argument('--config', type=str, default='config_hybrid.yaml', help='配置文件路径')
    # The two "only" modes contradict each other; passing both used to train
    # nothing at all, so argparse now rejects the combination.
    mode = parser.add_mutually_exclusive_group()
    mode.add_argument('--baseline-only', action='store_true', help='仅训练基线模型')
    mode.add_argument('--hybrid-only', action='store_true', help='仅训练混合分类器')
    parser.add_argument('--reextract-features', action='store_true', help='重新提取BERT特征')
    parser.add_argument('--skip-baseline', action='store_true', help='跳过基线模型训练')
    return parser


def _parse_args(argv=None) -> argparse.Namespace:
    """Parse ``argv`` (``sys.argv[1:]`` when None) and reject flag combinations
    that would leave nothing to train; exits via ``parser.error`` on misuse."""
    parser = _build_parser()
    args = parser.parse_args(argv)
    # --baseline-only skips the hybrid stage and --skip-baseline skips the
    # baseline stage, so together they would train no model at all.
    if args.baseline_only and args.skip_baseline:
        parser.error('--baseline-only and --skip-baseline cannot be used together')
    return args


def _default_config() -> Dict[str, Any]:
    """Return a fresh copy of the built-in configuration used when no config
    file is found."""
    return {
        'baseline': {
            'model_type': 'svm',
            'data_path': './data/processed/news_data.csv',
            'test_size': 0.2,
        },
        'hybrid': {
            'classifier_type': 'svm',
            'bert_model': 'hfl/chinese-roberta-wwm-ext',
            'data_path': './data/processed/news_data.csv',
            'test_size': 0.2,
            'C': 1.0,
        },
        'visualization': {
            'output_dir': './outputs/visualizations',
            'show': False,
            'tsne_perplexity': 30,
        },
    }


def main():
    """Command-line entry point: train the selected models, optionally plot
    the comparison charts, and log a final accuracy / F1 summary."""
    args = _parse_args()

    # The config path is resolved relative to the script directory (ROOT_DIR);
    # an absolute --config value is used as-is by os.path.join.
    config_path = os.path.join(ROOT_DIR, args.config)
    if os.path.isfile(config_path):
        # An empty YAML file loads as None; fall back to an empty mapping.
        config = load_config(config_path) or {}
    else:
        logger.warning(f"配置文件不存在: {args.config},使用默认配置")
        config = _default_config()

    all_results = {}

    # Training stages
    if not args.hybrid_only and not args.skip_baseline:
        all_results.update(train_baseline(config))

    if not args.baseline_only:
        all_results.update(
            train_hybrid(config, use_cached_features=not args.reextract_features)
        )

    # Charts are skipped in the two "only" modes.
    if not args.baseline_only and not args.hybrid_only:
        visualize_and_compare(all_results, config)

    # Final summary
    logger.info("=" * 60)
    logger.info("训练完成!性能总结")
    logger.info("=" * 60)

    for model_name, results in all_results.items():
        logger.info(f"{model_name.upper()}:")
        logger.info(f" 准确率: {results['accuracy']:.4f}")
        logger.info(f" F1-Score: {results['f1_score']:.4f}")

    logger.info("=" * 60)


if __name__ == '__main__':
    main()
|