""" 传统机器学习文本分类器训练脚本 支持朴素贝叶斯和SVM两种算法 """ import os import jieba import joblib import pandas as pd import numpy as np from sklearn.feature_extraction.text import TfidfVectorizer from sklearn.naive_bayes import MultinomialNB from sklearn.svm import SVC from sklearn.model_selection import train_test_split from sklearn.metrics import classification_report, accuracy_score, f1_score # 分类映射(与数据库表一致) CATEGORY_MAP = { 'ENTERTAINMENT': '娱乐', 'SPORTS': '体育', 'FINANCE': '财经', 'TECHNOLOGY': '科技', 'MILITARY': '军事', 'AUTOMOTIVE': '汽车', 'GOVERNMENT': '政务', 'HEALTH': '健康', 'AI': 'AI' } REVERSE_CATEGORY_MAP = {v: k for k, v in CATEGORY_MAP.items()} # 分类专属停用词(与数据库表一致) CATEGORY_STOPWORDS = { 'ENTERTAINMENT': {'主演', '影片', '电影', '电视剧', '节目', '导演', '角色', '上映', '粉丝'}, 'SPORTS': {'比赛', '赛事', '赛季', '球队', '选手', '球员', '主场', '客场', '对阵', '比分'}, 'FINANCE': {'亿元', '万元', '同比', '环比', '增长率', '数据', '报告', '统计', '财报', '季度', '年度'}, 'TECHNOLOGY': {'技术', '系统', '平台', '方案', '应用', '功能', '版本', '升级', '研发', '推出'}, 'MILITARY': {'部队', '军方', '演习', '训练', '装备', '武器', '作战', '行动', '部署'}, 'AUTOMOTIVE': {'汽车', '车型', '上市', '发布', '销量', '市场', '品牌', '厂商'}, 'GOVERNMENT': {'会议', '讲话', '指出', '强调', '部署', '落实', '推进', '要求', '精神', '决定', '意见', '方案', '安排'}, 'HEALTH': {'医生', '专家', '建议', '提示', '提醒', '研究', '发现', '可能', '有助于'}, 'AI': {'技术', '系统', '模型', '算法', '应用', '功能', '版本', '升级', '研发', '推出', '人工智能'} } class NewsClassifier: """新闻文本分类器""" def __init__(self, model_type='nb', use_stopwords=True, use_category_stopwords=False): """ 初始化分类器 :param model_type: 模型类型 'nb' 朴素贝叶斯 或 'svm' 支持向量机 :param use_stopwords: 是否使用通用停用词 :param use_category_stopwords: 是否使用分类专属停用词 """ self.model_type = model_type self.vectorizer = None self.classifier = None self.categories = list(CATEGORY_MAP.keys()) self.use_stopwords = use_stopwords self.use_category_stopwords = use_category_stopwords self.stopwords = set() if self.use_stopwords: self._load_stopwords() def _load_stopwords(self, stopwords_path='../../data/news_stopwords.txt'): """ 加载停用词表 """ try: with open(stopwords_path, 'r', encoding='utf-8') as f: self.stopwords = set(line.strip() for line in f if line.strip()) print(f"已加载 {len(self.stopwords)} 个停用词") except FileNotFoundError: print(f"警告: 停用词文件不存在: {stopwords_path}") def preprocess_text(self, text, category=None): """ 文本预处理:使用jieba分词 + 停用词过滤 :param text: 待处理文本 :param category: 可选,指定分类时使用分类专属停用词 """ # 移除多余空格和换行 text = ' '.join(text.split()) # jieba分词 words = jieba.cut(text) # 过滤停用词和单字词 result = [] for w in words: # 过滤单字词 if len(w) <= 1: continue # 过滤通用停用词 if self.use_stopwords and w in self.stopwords: continue # 过滤分类专属停用词 if self.use_category_stopwords and category: if w in CATEGORY_STOPWORDS.get(category, set()): continue result.append(w) return ' '.join(result) def load_data(self, csv_path): """ 从CSV文件加载训练数据 """ df = pd.read_csv(csv_path) # 合并标题和内容作为特征 df['text'] = df['title'] + ' ' + df['content'] # 预处理(如果启用了分类专属停用词,需要传入分类信息) if self.use_category_stopwords: # 先转换为分类代码 df['category_code'] = df['category_name'].map(REVERSE_CATEGORY_MAP) # 逐行预处理,传入分类信息 df['processed_text'] = df.apply( lambda row: self.preprocess_text(row['text'], row['category_code']), axis=1 ) else: # 不使用分类专属停用词,直接批量预处理 df['processed_text'] = df['text'].apply(self.preprocess_text) # 转换分类名称为代码 df['category_code'] = df['category_name'].map(REVERSE_CATEGORY_MAP) return df def train(self, df): """ 训练模型 """ X = df['processed_text'].values y = df['category_code'].values # 获取实际数据中的分类 actual_categories = sorted(df['category_code'].unique().tolist()) actual_category_names = [CATEGORY_MAP[cat] for cat in actual_categories] # 划分训练集和测试集 X_train, X_test, y_train, y_test = train_test_split( X, y, test_size=0.2, random_state=42, stratify=y ) # TF-IDF特征提取 self.vectorizer = TfidfVectorizer( max_features=5000, ngram_range=(1, 2), min_df=2 ) X_train_tfidf = self.vectorizer.fit_transform(X_train) X_test_tfidf = self.vectorizer.transform(X_test) # 选择分类器 if self.model_type == 'nb': self.classifier = MultinomialNB(alpha=0.1) elif self.model_type == 'svm': self.classifier = SVC(kernel='linear', probability=True, random_state=42) else: raise ValueError(f"不支持的模型类型: {self.model_type}") # 训练 self.classifier.fit(X_train_tfidf, y_train) # 评估 y_pred = self.classifier.predict(X_test_tfidf) accuracy = accuracy_score(y_test, y_pred) f1 = f1_score(y_test, y_pred, average='weighted') print(f"模型类型: {self.model_type}") print(f"准确率: {accuracy:.4f}") print(f"F1-Score: {f1:.4f}") print("\n分类报告:") print(classification_report(y_test, y_pred, target_names=actual_category_names)) return accuracy, f1 def predict(self, title, content): """ 预测单个文本 """ if self.classifier is None or self.vectorizer is None: raise ValueError("模型未训练,请先调用train方法") text = title + ' ' + content # 预测时不指定分类,只使用通用停用词 processed = self.preprocess_text(text) tfidf = self.vectorizer.transform([processed]) # 预测 prediction = self.classifier.predict(tfidf)[0] probabilities = self.classifier.predict_proba(tfidf)[0] confidence = float(probabilities.max()) return { 'categoryCode': prediction, 'categoryName': CATEGORY_MAP.get(prediction, '未知'), 'confidence': round(confidence, 4) } def save_model(self, model_dir): """ 保存模型 """ os.makedirs(model_dir, exist_ok=True) joblib.dump(self.vectorizer, os.path.join(model_dir, f'{self.model_type}_vectorizer.pkl')) joblib.dump(self.classifier, os.path.join(model_dir, f'{self.model_type}_classifier.pkl')) print(f"模型已保存到: {model_dir}") def load_model(self, model_dir): """ 加载模型 """ self.vectorizer = joblib.load(os.path.join(model_dir, f'{self.model_type}_vectorizer.pkl')) self.classifier = joblib.load(os.path.join(model_dir, f'{self.model_type}_classifier.pkl')) print(f"模型已从 {model_dir} 加载") if __name__ == '__main__': # 示例用法 classifier = NewsClassifier(model_type='nb') # 假设有训练数据文件 train_data_path = '../../data/processed/training_data.csv' if os.path.exists(train_data_path): df = classifier.load_data(train_data_path) classifier.train(df) classifier.save_model('../../models/traditional') # 测试预测 test_title = "华为发布新款折叠屏手机" test_content = "华为今天正式发布了新一代折叠屏手机,搭载最新麒麟芯片..." result = classifier.predict(test_title, test_content) print("\n测试预测结果:", result) else: print(f"训练数据文件不存在: {train_data_path}") print("请先准备训练数据CSV文件,包含列: title, content, category")