94 lines
2.6 KiB
Python
94 lines
2.6 KiB
Python
"""
|
|
传统机器学习模型预测
|
|
"""
|
|
|
|
import os
from functools import lru_cache
from typing import Dict, Any

import jieba
import joblib
|
|
|
|
# Category code -> display name. The codes mirror the category codes stored in
# the database; the values are the user-facing (Chinese) labels.
CATEGORY_MAP = dict(
    ENTERTAINMENT='娱乐',
    SPORTS='体育',
    FINANCE='财经',
    TECHNOLOGY='科技',
    MILITARY='军事',
    AUTOMOTIVE='汽车',
    GOVERNMENT='政务',
    HEALTH='健康',
    AI='AI',
)
|
|
|
|
|
|
class TraditionalPredictor:
    """Traditional (non-neural) text classifier.

    Wraps a vectorizer + classifier pair that is loaded with ``joblib`` from
    ``model_dir`` when the instance is constructed.
    """

    def __init__(self, model_type='nb', model_dir='../../models/traditional'):
        """Create a predictor and immediately load its model files.

        :param model_type: file-name prefix of the model to load (default ``'nb'``).
        :param model_dir: directory containing the model files. The default is a
            relative path, so it is resolved against the current working
            directory, not against this module's location.
        """
        self.model_type = model_type
        self.model_dir = model_dir
        self.vectorizer = None
        self.classifier = None
        self._load_model()

    def _load_model(self):
        """Load ``{model_type}_vectorizer.pkl`` and ``{model_type}_classifier.pkl``.

        Errors from ``joblib.load`` (e.g. a missing file) propagate to the caller.
        """
        vectorizer_path = os.path.join(self.model_dir, f'{self.model_type}_vectorizer.pkl')
        classifier_path = os.path.join(self.model_dir, f'{self.model_type}_classifier.pkl')

        # NOTE(review): joblib.load unpickles the file, which can execute
        # arbitrary code, and model_type is interpolated into the path without
        # validation. Never pass model_type straight from untrusted input.
        self.vectorizer = joblib.load(vectorizer_path)
        self.classifier = joblib.load(classifier_path)
        print(f"模型加载成功: {self.model_type}")

    def preprocess(self, title: str, content: str) -> str:
        """Concatenate title and content and segment the text with jieba.

        ``None`` for either argument is treated as an empty string, so an
        article with a missing title or body no longer raises ``TypeError``.

        :return: the jieba tokens joined by single spaces (the form passed to
            the vectorizer in ``predict``).
        """
        text = ' '.join((
            '' if title is None else title,
            '' if content is None else content,
        ))
        # jieba.cut yields tokens lazily; join them with spaces.
        return ' '.join(jieba.cut(text))

    def predict(self, title: str, content: str) -> Dict[str, Any]:
        """Classify a single article.

        :param title: article title (``None`` is treated as empty).
        :param content: article body (``None`` is treated as empty).
        :return: dict with keys
            ``categoryCode``: the label returned by ``classifier.predict``;
            ``categoryName``: its display name from ``CATEGORY_MAP``
            (``'未知'`` when the label is not mapped);
            ``confidence``: the highest class probability, as ``float``;
            ``probabilities``: every class label mapped to its probability.
        """
        processed = self.preprocess(title, content)
        features = self.vectorizer.transform([processed])

        # NOTE(review): requires a classifier exposing predict_proba (e.g. an
        # sklearn SVC only does so with probability=True). Confirm this for
        # every model_type that is deployed.
        prediction = self.classifier.predict(features)[0]
        probabilities = self.classifier.predict_proba(features)[0]

        # Pair labels with probabilities positionally: sklearn orders classes_
        # to match the columns of predict_proba.
        prob_dict = {
            category_code: float(prob)
            for category_code, prob in zip(self.classifier.classes_, probabilities)
        }

        return {
            'categoryCode': prediction,
            'categoryName': CATEGORY_MAP.get(prediction, '未知'),
            'confidence': float(probabilities.max()),
            'probabilities': prob_dict,
        }
|
|
|
|
|
|
# API entry point
@lru_cache(maxsize=8)
def _get_predictor(model_type: str) -> 'TraditionalPredictor':
    """Return a shared ``TraditionalPredictor`` for ``model_type``, loading it on first use.

    A failed load raises and is not cached, so it is retried on the next call.
    Call ``_get_predictor.cache_clear()`` to force the model files to be re-read
    (e.g. after retraining).
    """
    return TraditionalPredictor(model_type)


def predict_single(title: str, content: str, model_type='nb') -> Dict[str, Any]:
    """Classify one article (single-item prediction API).

    Predictors are reused across calls, so each model's files are read from
    disk only the first time that ``model_type`` is requested, not once per call.

    :param title: article title.
    :param content: article body.
    :param model_type: which saved model to use (default ``'nb'``).
    :return: the result dict produced by ``TraditionalPredictor.predict``.
    """
    return _get_predictor(model_type).predict(title, content)
|
|
|
|
|
|
if __name__ == '__main__':
    # Manual smoke test: classify one hard-coded sample article and print the result.
    sample_title = "华为发布新款折叠屏手机"
    sample_content = "华为今天正式发布了新一代折叠屏手机,搭载最新麒麟芯片..."
    outcome = predict_single(title=sample_title, content=sample_content)
    print(outcome)