# news-classifier/ml-module/src/traditional/train_model.py

"""
Training script for the traditional machine-learning news text classifier.
Supports two algorithms: Naive Bayes and SVM.
"""
import os
import jieba
import joblib
import pandas as pd
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score, f1_score
# Category code -> display name (kept in sync with the database table).
CATEGORY_MAP = {
    'ENTERTAINMENT': '娱乐',
    'SPORTS': '体育',
    'FINANCE': '财经',
    'TECHNOLOGY': '科技',
    'MILITARY': '军事',
    'AUTOMOTIVE': '汽车',
    'GOVERNMENT': '政务',
    'HEALTH': '健康',
    'AI': 'AI',
}

# Inverse mapping: display name -> category code.
REVERSE_CATEGORY_MAP = {name: code for code, name in CATEGORY_MAP.items()}

# Per-category stopword sets (kept in sync with the database table).
# These are high-frequency words that carry little discriminative signal
# *within* their own category.
CATEGORY_STOPWORDS = {
    'ENTERTAINMENT': {'主演', '影片', '电影', '电视剧', '节目', '导演', '角色', '上映', '粉丝'},
    'SPORTS': {'比赛', '赛事', '赛季', '球队', '选手', '球员', '主场', '客场', '对阵', '比分'},
    'FINANCE': {'亿元', '万元', '同比', '环比', '增长率', '数据', '报告', '统计', '财报', '季度', '年度'},
    'TECHNOLOGY': {'技术', '系统', '平台', '方案', '应用', '功能', '版本', '升级', '研发', '推出'},
    'MILITARY': {'部队', '军方', '演习', '训练', '装备', '武器', '作战', '行动', '部署'},
    'AUTOMOTIVE': {'汽车', '车型', '上市', '发布', '销量', '市场', '品牌', '厂商'},
    'GOVERNMENT': {'会议', '讲话', '指出', '强调', '部署', '落实', '推进', '要求', '精神', '决定', '意见', '方案', '安排'},
    'HEALTH': {'医生', '专家', '建议', '提示', '提醒', '研究', '发现', '可能', '有助于'},
    'AI': {'技术', '系统', '模型', '算法', '应用', '功能', '版本', '升级', '研发', '推出', '人工智能'},
}
class NewsClassifier:
    """News text classifier built on TF-IDF features.

    Supports two algorithms: Multinomial Naive Bayes ('nb') and a
    linear-kernel SVM ('svm'). Text is segmented with jieba and can be
    filtered against a generic stopword list and/or the per-category
    stopword sets in ``CATEGORY_STOPWORDS``.
    """

    def __init__(self, model_type='nb', use_stopwords=True, use_category_stopwords=False):
        """Initialise the classifier.

        :param model_type: 'nb' for Multinomial Naive Bayes or 'svm'
                           for a linear-kernel SVC.
        :param use_stopwords: whether to filter generic stopwords.
        :param use_category_stopwords: whether to also filter the
                                       category-specific stopword sets.
        """
        self.model_type = model_type
        self.vectorizer = None   # fitted TfidfVectorizer; set by train()/load_model()
        self.classifier = None   # fitted estimator; set by train()/load_model()
        self.categories = list(CATEGORY_MAP.keys())
        self.use_stopwords = use_stopwords
        self.use_category_stopwords = use_category_stopwords
        self.stopwords = set()
        if self.use_stopwords:
            self._load_stopwords()

    def _load_stopwords(self, stopwords_path='../../data/news_stopwords.txt'):
        """Load the generic stopword list (one word per line, UTF-8).

        A missing file is tolerated: a warning is printed and the
        classifier simply runs without generic stopwords.
        """
        try:
            with open(stopwords_path, 'r', encoding='utf-8') as f:
                self.stopwords = {line.strip() for line in f if line.strip()}
            print(f"已加载 {len(self.stopwords)} 个停用词")
        except FileNotFoundError:
            print(f"警告: 停用词文件不存在: {stopwords_path}")

    def preprocess_text(self, text, category=None):
        """Segment *text* with jieba and drop noise tokens.

        Filters single-character tokens, generic stopwords (when
        enabled) and, when *category* is given and category stopwords
        are enabled, the category-specific stopword set.

        :param text: raw text to process.
        :param category: optional category code (a key of CATEGORY_MAP).
        :return: space-joined token string suitable for TfidfVectorizer.
        """
        # Collapse runs of whitespace/newlines into single spaces.
        text = ' '.join(text.split())
        # Hoist the category-stopword lookup out of the token loop.
        if self.use_category_stopwords and category:
            category_stopwords = CATEGORY_STOPWORDS.get(category, set())
        else:
            category_stopwords = set()
        result = []
        for w in jieba.cut(text):
            # Drop single-character tokens (mostly function words/noise).
            if len(w) <= 1:
                continue
            if self.use_stopwords and w in self.stopwords:
                continue
            if w in category_stopwords:
                continue
            result.append(w)
        return ' '.join(result)

    def load_data(self, csv_path):
        """Load training data from a CSV file.

        Expects columns ``title``, ``content`` and ``category_name``
        (Chinese display name). Adds ``processed_text`` (segmented,
        filtered text) and ``category_code`` (CATEGORY_MAP key) columns.
        """
        df = pd.read_csv(csv_path)
        # Use title + content as the feature text; guard against NaN
        # cells, which would otherwise poison the concatenation.
        df['text'] = df['title'].fillna('') + ' ' + df['content'].fillna('')
        # Map display names to category codes once, up front (the
        # original code computed this twice on one path).
        df['category_code'] = df['category_name'].map(REVERSE_CATEGORY_MAP)
        if self.use_category_stopwords:
            # Row-wise so each row's category-specific stopwords apply.
            df['processed_text'] = df.apply(
                lambda row: self.preprocess_text(row['text'], row['category_code']),
                axis=1
            )
        else:
            df['processed_text'] = df['text'].apply(self.preprocess_text)
        return df

    def train(self, df):
        """Train on a DataFrame produced by :meth:`load_data`.

        Splits 80/20 (stratified), fits TF-IDF features and the chosen
        classifier, prints an evaluation report, and returns
        ``(accuracy, weighted_f1)``.

        :raises ValueError: if ``model_type`` is neither 'nb' nor 'svm'.
        """
        X = df['processed_text'].values
        y = df['category_code'].values
        # Categories actually present in the data (may be a subset of CATEGORY_MAP).
        actual_categories = sorted(df['category_code'].unique().tolist())
        actual_category_names = [CATEGORY_MAP[cat] for cat in actual_categories]
        # Stratified 80/20 train/test split, fixed seed for reproducibility.
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42, stratify=y
        )
        # TF-IDF features over unigrams and bigrams.
        self.vectorizer = TfidfVectorizer(
            max_features=5000,
            ngram_range=(1, 2),
            min_df=2
        )
        X_train_tfidf = self.vectorizer.fit_transform(X_train)
        X_test_tfidf = self.vectorizer.transform(X_test)
        if self.model_type == 'nb':
            self.classifier = MultinomialNB(alpha=0.1)
        elif self.model_type == 'svm':
            # probability=True so predict() can report a confidence score.
            self.classifier = SVC(kernel='linear', probability=True, random_state=42)
        else:
            raise ValueError(f"不支持的模型类型: {self.model_type}")
        self.classifier.fit(X_train_tfidf, y_train)
        y_pred = self.classifier.predict(X_test_tfidf)
        accuracy = accuracy_score(y_test, y_pred)
        f1 = f1_score(y_test, y_pred, average='weighted')
        print(f"模型类型: {self.model_type}")
        print(f"准确率: {accuracy:.4f}")
        print(f"F1-Score: {f1:.4f}")
        print("\n分类报告:")
        # Pass labels= explicitly so target_names stays aligned with the
        # reported rows even if some category is absent from the test split.
        print(classification_report(
            y_test, y_pred,
            labels=actual_categories,
            target_names=actual_category_names
        ))
        return accuracy, f1

    def predict(self, title, content):
        """Predict the category of a single article.

        :param title: article title.
        :param content: article body.
        :return: dict with 'categoryCode', 'categoryName' and a
                 'confidence' (max class probability, rounded to 4 dp).
        :raises ValueError: if the model has not been trained/loaded.
        """
        if self.classifier is None or self.vectorizer is None:
            raise ValueError("模型未训练请先调用train方法")
        text = title + ' ' + content
        # At prediction time the category is unknown, so only generic
        # stopwords are applied.
        processed = self.preprocess_text(text)
        tfidf = self.vectorizer.transform([processed])
        prediction = self.classifier.predict(tfidf)[0]
        probabilities = self.classifier.predict_proba(tfidf)[0]
        confidence = float(probabilities.max())
        return {
            'categoryCode': prediction,
            'categoryName': CATEGORY_MAP.get(prediction, '未知'),
            'confidence': round(confidence, 4)
        }

    def save_model(self, model_dir):
        """Persist the vectorizer and classifier under *model_dir*."""
        os.makedirs(model_dir, exist_ok=True)
        joblib.dump(self.vectorizer, os.path.join(model_dir, f'{self.model_type}_vectorizer.pkl'))
        joblib.dump(self.classifier, os.path.join(model_dir, f'{self.model_type}_classifier.pkl'))
        print(f"模型已保存到: {model_dir}")

    def load_model(self, model_dir):
        """Load a previously saved vectorizer and classifier from *model_dir*.

        NOTE(review): joblib unpickles arbitrary objects — only load
        model files from trusted sources.
        """
        self.vectorizer = joblib.load(os.path.join(model_dir, f'{self.model_type}_vectorizer.pkl'))
        self.classifier = joblib.load(os.path.join(model_dir, f'{self.model_type}_classifier.pkl'))
        print(f"模型已从 {model_dir} 加载")
if __name__ == '__main__':
    # Example usage: train a Naive Bayes model when training data exists.
    clf = NewsClassifier(model_type='nb')
    data_path = '../../data/processed/training_data.csv'
    if not os.path.exists(data_path):
        print(f"训练数据文件不存在: {data_path}")
        print("请先准备训练数据CSV文件包含列: title, content, category")
    else:
        train_df = clf.load_data(data_path)
        clf.train(train_df)
        clf.save_model('../../models/traditional')
        # Smoke-test a single prediction.
        test_title = "华为发布新款折叠屏手机"
        test_content = "华为今天正式发布了新一代折叠屏手机,搭载最新麒麟芯片..."
        result = clf.predict(test_title, test_content)
        print("\n测试预测结果:", result)