news-classifier/ml-module/src/traditional/predict.py

94 lines
2.6 KiB
Python

"""
传统机器学习模型预测
"""
import os
import joblib
import jieba
from typing import Dict, Any
# Category code (as stored in the database) -> Chinese display name.
# Codes not listed here are reported by the predictor as '未知' ("unknown").
CATEGORY_MAP = dict(
    ENTERTAINMENT='娱乐',
    SPORTS='体育',
    FINANCE='财经',
    TECHNOLOGY='科技',
    MILITARY='军事',
    AUTOMOTIVE='汽车',
    GOVERNMENT='政务',
    HEALTH='健康',
    AI='AI',
)
class TraditionalPredictor:
    """Classify news articles with a saved vectorizer + classifier pair.

    On construction the predictor loads
    ``<model_dir>/<model_type>_vectorizer.pkl`` and
    ``<model_dir>/<model_type>_classifier.pkl`` with ``joblib.load``.
    """

    def __init__(self, model_type='nb', model_dir='../../models/traditional'):
        """
        :param model_type: file-name prefix of the saved model pair (e.g. ``'nb'``).
        :param model_dir: directory containing the ``.pkl`` files. A relative path
            is first tried against the current working directory (the original
            behaviour). If no such directory exists there, it is resolved against
            this module's own directory, so the default works regardless of where
            the process was started.
        :raises FileNotFoundError: (from ``joblib.load``) if a model file is missing.
        """
        self.model_type = model_type
        self.model_dir = self._resolve_model_dir(model_dir)
        self.vectorizer = None
        self.classifier = None
        self._load_model()

    @staticmethod
    def _resolve_model_dir(model_dir):
        """Return *model_dir*, falling back to a module-relative path when the
        CWD-relative one does not exist."""
        if os.path.isabs(model_dir) or os.path.isdir(model_dir):
            return model_dir
        module_relative = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), model_dir)
        if os.path.isdir(module_relative):
            return os.path.normpath(module_relative)
        # Neither location exists: keep the caller's value so the load error
        # reports the path exactly as it was given.
        return model_dir

    def _load_model(self):
        """Load the vectorizer and classifier for ``self.model_type``."""
        vectorizer_path = os.path.join(self.model_dir, f'{self.model_type}_vectorizer.pkl')
        classifier_path = os.path.join(self.model_dir, f'{self.model_type}_classifier.pkl')
        # joblib.load unpickles arbitrary objects: only point this at trusted files.
        self.vectorizer = joblib.load(vectorizer_path)
        self.classifier = joblib.load(classifier_path)
        print(f"模型加载成功: {self.model_type}")

    @staticmethod
    def _to_native(value):
        """Convert a NumPy scalar (e.g. ``numpy.str_`` / ``numpy.int64``) to the
        equivalent plain Python value so the result is JSON-serializable.
        Values without an ``item()`` method are returned unchanged."""
        item = getattr(value, 'item', None)
        return item() if callable(item) else value

    def preprocess(self, title: str, content: str) -> str:
        """Join title and content and segment the text with jieba.

        ``None`` for either field is treated as an empty string.

        :return: the segmented words joined by single spaces.
        """
        text = (title or '') + ' ' + (content or '')
        return ' '.join(jieba.cut(text))

    def predict(self, title: str, content: str) -> Dict[str, Any]:
        """Predict the category of one article.

        :return: dict with keys
            ``categoryCode`` (label returned by ``classifier.predict``),
            ``categoryName`` (display name from ``CATEGORY_MAP``, ``'未知'`` if unmapped),
            ``confidence`` (probability assigned to ``categoryCode``), and
            ``probabilities`` (label -> probability for every class in ``classifier.classes_``).
        """
        processed = self.preprocess(title, content)
        features = self.vectorizer.transform([processed])

        prediction = self._to_native(self.classifier.predict(features)[0])
        probabilities = self.classifier.predict_proba(features)[0]
        prob_dict = {
            self._to_native(code): float(prob)
            for code, prob in zip(self.classifier.classes_, probabilities)
        }

        return {
            'categoryCode': prediction,
            'categoryName': CATEGORY_MAP.get(prediction, '未知'),
            # Report the probability of the predicted label itself. For some
            # classifiers predict() and argmax(predict_proba()) can disagree,
            # so the max probability may belong to a different class.
            'confidence': prob_dict.get(prediction, float(probabilities.max())),
            'probabilities': prob_dict,
        }
# API entry point.
# Loaded predictors, keyed by model type, so repeated API calls do not re-read
# the pickled vectorizer/classifier from disk on every request. A failed load
# is not cached, so the next call retries it.
_PREDICTOR_CACHE: Dict[str, 'TraditionalPredictor'] = {}


def predict_single(title: str, content: str, model_type='nb') -> Dict[str, Any]:
    """Classify a single article (API entry point).

    The predictor for *model_type* is built on first use and reused by later
    calls. Model files changed on disk afterwards are not picked up until the
    process restarts.

    :param title: article title.
    :param content: article body.
    :param model_type: model-file prefix passed to ``TraditionalPredictor``.
    :return: the result dict produced by ``TraditionalPredictor.predict``.
    """
    predictor = _PREDICTOR_CACHE.get(model_type)
    if predictor is None:
        predictor = TraditionalPredictor(model_type)
        _PREDICTOR_CACHE[model_type] = predictor
    return predictor.predict(title, content)
if __name__ == '__main__':
    # Manual smoke test: classify one hard-coded sample article with the
    # default model type and print the resulting dict.
    sample_title = "华为发布新款折叠屏手机"
    sample_content = "华为今天正式发布了新一代折叠屏手机,搭载最新麒麟芯片..."
    outcome = predict_single(title=sample_title, content=sample_content)
    print(outcome)