"""
传统机器学习文本分类器训练脚本

支持朴素贝叶斯和SVM两种算法
"""

import os

import jieba
import joblib
import numpy as np
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import classification_report, accuracy_score, f1_score
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import MultinomialNB
from sklearn.svm import SVC

# 分类映射(与数据库表一致)
|
||
CATEGORY_MAP = {
|
||
'ENTERTAINMENT': '娱乐',
|
||
'SPORTS': '体育',
|
||
'FINANCE': '财经',
|
||
'TECHNOLOGY': '科技',
|
||
'MILITARY': '军事',
|
||
'AUTOMOTIVE': '汽车',
|
||
'GOVERNMENT': '政务',
|
||
'HEALTH': '健康',
|
||
'AI': 'AI'
|
||
}
|
||
|
||
REVERSE_CATEGORY_MAP = {v: k for k, v in CATEGORY_MAP.items()}
|
||
|
||
|
||
# 分类专属停用词(与数据库表一致)
|
||
CATEGORY_STOPWORDS = {
|
||
'ENTERTAINMENT': {'主演', '影片', '电影', '电视剧', '节目', '导演', '角色', '上映', '粉丝'},
|
||
'SPORTS': {'比赛', '赛事', '赛季', '球队', '选手', '球员', '主场', '客场', '对阵', '比分'},
|
||
'FINANCE': {'亿元', '万元', '同比', '环比', '增长率', '数据', '报告', '统计', '财报', '季度', '年度'},
|
||
'TECHNOLOGY': {'技术', '系统', '平台', '方案', '应用', '功能', '版本', '升级', '研发', '推出'},
|
||
'MILITARY': {'部队', '军方', '演习', '训练', '装备', '武器', '作战', '行动', '部署'},
|
||
'AUTOMOTIVE': {'汽车', '车型', '上市', '发布', '销量', '市场', '品牌', '厂商'},
|
||
'GOVERNMENT': {'会议', '讲话', '指出', '强调', '部署', '落实', '推进', '要求', '精神', '决定', '意见', '方案', '安排'},
|
||
'HEALTH': {'医生', '专家', '建议', '提示', '提醒', '研究', '发现', '可能', '有助于'},
|
||
'AI': {'技术', '系统', '模型', '算法', '应用', '功能', '版本', '升级', '研发', '推出', '人工智能'}
|
||
}
|
||
|
||
|
||
class NewsClassifier:
|
||
"""新闻文本分类器"""
|
||
|
||
def __init__(self, model_type='nb', use_stopwords=True, use_category_stopwords=False):
|
||
"""
|
||
初始化分类器
|
||
:param model_type: 模型类型 'nb' 朴素贝叶斯 或 'svm' 支持向量机
|
||
:param use_stopwords: 是否使用通用停用词
|
||
:param use_category_stopwords: 是否使用分类专属停用词
|
||
"""
|
||
self.model_type = model_type
|
||
self.vectorizer = None
|
||
self.classifier = None
|
||
self.categories = list(CATEGORY_MAP.keys())
|
||
self.use_stopwords = use_stopwords
|
||
self.use_category_stopwords = use_category_stopwords
|
||
self.stopwords = set()
|
||
|
||
if self.use_stopwords:
|
||
self._load_stopwords()
|
||
|
||
def _load_stopwords(self, stopwords_path='../../data/news_stopwords.txt'):
|
||
"""
|
||
加载停用词表
|
||
"""
|
||
try:
|
||
with open(stopwords_path, 'r', encoding='utf-8') as f:
|
||
self.stopwords = set(line.strip() for line in f if line.strip())
|
||
print(f"已加载 {len(self.stopwords)} 个停用词")
|
||
except FileNotFoundError:
|
||
print(f"警告: 停用词文件不存在: {stopwords_path}")
|
||
|
||
def preprocess_text(self, text, category=None):
|
||
"""
|
||
文本预处理:使用jieba分词 + 停用词过滤
|
||
:param text: 待处理文本
|
||
:param category: 可选,指定分类时使用分类专属停用词
|
||
"""
|
||
# 移除多余空格和换行
|
||
text = ' '.join(text.split())
|
||
# jieba分词
|
||
words = jieba.cut(text)
|
||
|
||
# 过滤停用词和单字词
|
||
result = []
|
||
for w in words:
|
||
# 过滤单字词
|
||
if len(w) <= 1:
|
||
continue
|
||
# 过滤通用停用词
|
||
if self.use_stopwords and w in self.stopwords:
|
||
continue
|
||
# 过滤分类专属停用词
|
||
if self.use_category_stopwords and category:
|
||
if w in CATEGORY_STOPWORDS.get(category, set()):
|
||
continue
|
||
result.append(w)
|
||
|
||
return ' '.join(result)
|
||
|
||
def load_data(self, csv_path):
|
||
"""
|
||
从CSV文件加载训练数据
|
||
"""
|
||
df = pd.read_csv(csv_path)
|
||
# 合并标题和内容作为特征
|
||
df['text'] = df['title'] + ' ' + df['content']
|
||
|
||
# 预处理(如果启用了分类专属停用词,需要传入分类信息)
|
||
if self.use_category_stopwords:
|
||
# 先转换为分类代码
|
||
df['category_code'] = df['category_name'].map(REVERSE_CATEGORY_MAP)
|
||
# 逐行预处理,传入分类信息
|
||
df['processed_text'] = df.apply(
|
||
lambda row: self.preprocess_text(row['text'], row['category_code']),
|
||
axis=1
|
||
)
|
||
else:
|
||
# 不使用分类专属停用词,直接批量预处理
|
||
df['processed_text'] = df['text'].apply(self.preprocess_text)
|
||
# 转换分类名称为代码
|
||
df['category_code'] = df['category_name'].map(REVERSE_CATEGORY_MAP)
|
||
|
||
return df
|
||
|
||
def train(self, df):
|
||
"""
|
||
训练模型
|
||
"""
|
||
X = df['processed_text'].values
|
||
y = df['category_code'].values
|
||
|
||
# 获取实际数据中的分类
|
||
actual_categories = sorted(df['category_code'].unique().tolist())
|
||
actual_category_names = [CATEGORY_MAP[cat] for cat in actual_categories]
|
||
|
||
# 划分训练集和测试集
|
||
X_train, X_test, y_train, y_test = train_test_split(
|
||
X, y, test_size=0.2, random_state=42, stratify=y
|
||
)
|
||
|
||
# TF-IDF特征提取
|
||
self.vectorizer = TfidfVectorizer(
|
||
max_features=5000,
|
||
ngram_range=(1, 2),
|
||
min_df=2
|
||
)
|
||
X_train_tfidf = self.vectorizer.fit_transform(X_train)
|
||
X_test_tfidf = self.vectorizer.transform(X_test)
|
||
|
||
# 选择分类器
|
||
if self.model_type == 'nb':
|
||
self.classifier = MultinomialNB(alpha=0.1)
|
||
elif self.model_type == 'svm':
|
||
self.classifier = SVC(kernel='linear', probability=True, random_state=42)
|
||
else:
|
||
raise ValueError(f"不支持的模型类型: {self.model_type}")
|
||
|
||
# 训练
|
||
self.classifier.fit(X_train_tfidf, y_train)
|
||
|
||
# 评估
|
||
y_pred = self.classifier.predict(X_test_tfidf)
|
||
accuracy = accuracy_score(y_test, y_pred)
|
||
f1 = f1_score(y_test, y_pred, average='weighted')
|
||
|
||
print(f"模型类型: {self.model_type}")
|
||
print(f"准确率: {accuracy:.4f}")
|
||
print(f"F1-Score: {f1:.4f}")
|
||
print("\n分类报告:")
|
||
print(classification_report(y_test, y_pred, target_names=actual_category_names))
|
||
|
||
return accuracy, f1
|
||
|
||
def predict(self, title, content):
|
||
"""
|
||
预测单个文本
|
||
"""
|
||
if self.classifier is None or self.vectorizer is None:
|
||
raise ValueError("模型未训练,请先调用train方法")
|
||
|
||
text = title + ' ' + content
|
||
# 预测时不指定分类,只使用通用停用词
|
||
processed = self.preprocess_text(text)
|
||
tfidf = self.vectorizer.transform([processed])
|
||
|
||
# 预测
|
||
prediction = self.classifier.predict(tfidf)[0]
|
||
probabilities = self.classifier.predict_proba(tfidf)[0]
|
||
confidence = float(probabilities.max())
|
||
|
||
return {
|
||
'categoryCode': prediction,
|
||
'categoryName': CATEGORY_MAP.get(prediction, '未知'),
|
||
'confidence': round(confidence, 4)
|
||
}
|
||
|
||
def save_model(self, model_dir):
|
||
"""
|
||
保存模型
|
||
"""
|
||
os.makedirs(model_dir, exist_ok=True)
|
||
joblib.dump(self.vectorizer, os.path.join(model_dir, f'{self.model_type}_vectorizer.pkl'))
|
||
joblib.dump(self.classifier, os.path.join(model_dir, f'{self.model_type}_classifier.pkl'))
|
||
print(f"模型已保存到: {model_dir}")
|
||
|
||
def load_model(self, model_dir):
|
||
"""
|
||
加载模型
|
||
"""
|
||
self.vectorizer = joblib.load(os.path.join(model_dir, f'{self.model_type}_vectorizer.pkl'))
|
||
self.classifier = joblib.load(os.path.join(model_dir, f'{self.model_type}_classifier.pkl'))
|
||
print(f"模型已从 {model_dir} 加载")
|
||
|
||
|
||
if __name__ == '__main__':
|
||
# 示例用法
|
||
classifier = NewsClassifier(model_type='nb')
|
||
|
||
# 假设有训练数据文件
|
||
train_data_path = '../../data/processed/training_data.csv'
|
||
|
||
if os.path.exists(train_data_path):
|
||
df = classifier.load_data(train_data_path)
|
||
classifier.train(df)
|
||
classifier.save_model('../../models/traditional')
|
||
|
||
# 测试预测
|
||
test_title = "华为发布新款折叠屏手机"
|
||
test_content = "华为今天正式发布了新一代折叠屏手机,搭载最新麒麟芯片..."
|
||
result = classifier.predict(test_title, test_content)
|
||
print("\n测试预测结果:", result)
|
||
else:
|
||
print(f"训练数据文件不存在: {train_data_path}")
|
||
print("请先准备训练数据CSV文件,包含列: title, content, category")
|