
News Text Classification System - Module Development Task Checklist


Table of Contents

  1. Crawler Module (Python)
  2. Backend Service Module (Spring Boot)
  3. Frontend Desktop Module (Tauri + Vue3)
  4. Machine Learning Classification Module (Python)

1. Crawler Module (Python)

2. Backend Service Module (Spring Boot)

3. Frontend Desktop Module (Tauri + Vue3)

4. Machine Learning Classification Module (Python)

Module Directory Structure

ml-module/
├── data/
│   ├── raw/                   # Raw data
│   ├── processed/             # Processed data
│   │   ├── training_data.csv
│   │   └── test_data.csv
│   └── external/              # External datasets
├── models/                    # Trained models
│   ├── traditional/
│   │   ├── nb_vectorizer.pkl
│   │   ├── nb_classifier.pkl
│   │   ├── svm_vectorizer.pkl
│   │   └── svm_classifier.pkl
│   ├── deep_learning/
│   │   └── bert_finetuned/
│   └── hybrid/
│       └── config.json
├── src/
│   ├── __init__.py
│   ├── traditional/           # Traditional machine learning
│   │   ├── __init__.py
│   │   ├── train_model.py     # (already exists)
│   │   ├── predict.py
│   │   └── evaluate.py
│   ├── deep_learning/         # Deep learning
│   │   ├── __init__.py
│   │   ├── bert_model.py
│   │   ├── train_bert.py
│   │   └── predict_bert.py
│   ├── hybrid/                # Hybrid strategy
│   │   ├── __init__.py
│   │   ├── hybrid_classifier.py
│   │   └── rule_engine.py
│   ├── utils/
│   │   ├── __init__.py
│   │   ├── preprocessing.py   # Data preprocessing
│   │   └── metrics.py         # Evaluation metrics
│   └── api/                   # API service
│       ├── __init__.py
│       └── server.py          # FastAPI server
├── notebooks/                 # Jupyter notebooks
│   ├── data_exploration.ipynb
│   └── model_comparison.ipynb
├── tests/                     # Tests
│   ├── test_traditional.py
│   ├── test_bert.py
│   └── test_hybrid.py
├── requirements.txt
├── setup.py
└── README.md

4.1 Specific Files to Complete

Task 4.1.1: src/traditional/predict.py - Traditional model prediction

"""
传统机器学习模型预测
"""

import os
import joblib
import jieba
from typing import Dict, Any

# 分类映射
CATEGORY_MAP = {
    'POLITICS': '时政',
    'FINANCE': '财经',
    'TECHNOLOGY': '科技',
    'SPORTS': '体育',
    'ENTERTAINMENT': '娱乐',
    'HEALTH': '健康',
    'EDUCATION': '教育',
    'LIFE': '生活',
    'INTERNATIONAL': '国际',
    'MILITARY': '军事'
}


class TraditionalPredictor:
    """传统机器学习预测器"""

    def __init__(self, model_type='nb', model_dir='../../models/traditional'):
        self.model_type = model_type
        self.model_dir = model_dir
        self.vectorizer = None
        self.classifier = None
        self._load_model()

    def _load_model(self):
        """加载模型"""
        vectorizer_path = os.path.join(self.model_dir, f'{self.model_type}_vectorizer.pkl')
        classifier_path = os.path.join(self.model_dir, f'{self.model_type}_classifier.pkl')

        self.vectorizer = joblib.load(vectorizer_path)
        self.classifier = joblib.load(classifier_path)
        print(f"模型加载成功: {self.model_type}")

    def preprocess(self, title: str, content: str) -> str:
        """预处理文本"""
        text = title + ' ' + content
        # jieba分词
        words = jieba.cut(text)
        return ' '.join(words)

    def predict(self, title: str, content: str) -> Dict[str, Any]:
        """
        预测
        :return: 预测结果字典
        """
        # 预处理
        processed = self.preprocess(title, content)

        # 特征提取
        tfidf = self.vectorizer.transform([processed])

        # 预测
        prediction = self.classifier.predict(tfidf)[0]
        probabilities = self.classifier.predict_proba(tfidf)[0]

        # 获取各类别概率
        prob_dict = {}
        for i, prob in enumerate(probabilities):
            category_code = self.classifier.classes_[i]
            prob_dict[category_code] = float(prob)

        return {
            'categoryCode': prediction,
            'categoryName': CATEGORY_MAP.get(prediction, '未知'),
            'confidence': float(probabilities.max()),
            'probabilities': prob_dict
        }


# API入口
def predict_single(title: str, content: str, model_type='nb') -> Dict[str, Any]:
    """
    单条预测API
    """
    predictor = TraditionalPredictor(model_type)
    return predictor.predict(title, content)


if __name__ == '__main__':
    # 测试
    result = predict_single(
        title="华为发布新款折叠屏手机",
        content="华为今天正式发布了新一代折叠屏手机,搭载最新麒麟芯片..."
    )
    print(result)
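
Note that the inference-time preprocessing above (plain jieba segmentation) must mirror whatever train_model.py did at training time. The directory tree also lists src/utils/preprocessing.py with no task of its own; below is a minimal sketch of what it could contain, assuming jieba segmentation plus optional stopword filtering — the cleaning rules and stopword file are placeholders to be aligned with the actual training pipeline.

"""
Shared text preprocessing utilities (sketch for src/utils/preprocessing.py).
Assumption: training and prediction both use jieba segmentation; the stopword
handling below is a placeholder and must match what train_model.py actually does.
"""

import re
import jieba
from typing import Optional, Set


def load_stopwords(path: str) -> Set[str]:
    """Load one stopword per line; return an empty set if the file is missing."""
    try:
        with open(path, encoding='utf-8') as f:
            return {line.strip() for line in f if line.strip()}
    except FileNotFoundError:
        return set()


def clean_text(text: str) -> str:
    """Drop URLs and HTML remnants, then collapse whitespace before segmentation."""
    text = re.sub(r'https?://\S+', ' ', text)
    text = re.sub(r'<[^>]+>', ' ', text)
    return re.sub(r'\s+', ' ', text).strip()


def tokenize(title: str, content: str, stopwords: Optional[Set[str]] = None) -> str:
    """Segment title + content with jieba and return a space-joined token string."""
    stopwords = stopwords or set()
    text = clean_text(title + ' ' + content)
    words = (w for w in jieba.cut(text) if w.strip() and w not in stopwords)
    return ' '.join(words)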

Task 4.1.2: src/deep_learning/bert_model.py - BERT model

"""
BERT文本分类模型
"""

import torch
from transformers import (
    BertTokenizer,
    BertForSequenceClassification,
    Trainer,
    TrainingArguments
)
from typing import Dict, Any, List


# 分类映射
CATEGORY_MAP = {
    'POLITICS': '时政',
    'FINANCE': '财经',
    'TECHNOLOGY': '科技',
    'SPORTS': '体育',
    'ENTERTAINMENT': '娱乐',
    'HEALTH': '健康',
    'EDUCATION': '教育',
    'LIFE': '生活',
    'INTERNATIONAL': '国际',
    'MILITARY': '军事'
}

# 反向映射
ID_TO_LABEL = {i: label for i, label in enumerate(CATEGORY_MAP.keys())}
LABEL_TO_ID = {label: i for i, label in enumerate(CATEGORY_MAP.keys())}


class BertClassifier:
    """BERT文本分类器"""

    def __init__(self, model_name='bert-base-chinese', num_labels=10):
        self.model_name = model_name
        self.num_labels = num_labels
        self.tokenizer = None
        self.model = None
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def load_model(self, model_path):
        """加载微调后的模型"""
        self.tokenizer = BertTokenizer.from_pretrained(model_path)
        self.model = BertForSequenceClassification.from_pretrained(
            model_path,
            num_labels=self.num_labels
        )
        self.model.to(self.device)
        self.model.eval()
        print(f"BERT模型加载成功: {model_path}")

    def predict(self, title: str, content: str) -> Dict[str, Any]:
        """
        预测
        """
        if self.model is None or self.tokenizer is None:
            raise ValueError("模型未加载请先调用load_model")

        # 组合标题和内容
        text = f"{title} [SEP] {content}"

        # 分词
        inputs = self.tokenizer(
            text,
            return_tensors='pt',
            truncation=True,
            max_length=512,
            padding='max_length'
        )

        # 预测
        with torch.no_grad():
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            outputs = self.model(**inputs)
            logits = outputs.logits

        # 获取预测结果
        probs = torch.softmax(logits, dim=-1)
        confidence, predicted_id = torch.max(probs, dim=-1)

        predicted_id = predicted_id.item()
        confidence = confidence.item()

        # 获取各类别概率
        prob_dict = {}
        for i, prob in enumerate(probs[0].cpu().numpy()):
            category_code = ID_TO_LABEL[i]
            prob_dict[category_code] = float(prob)

        return {
            'categoryCode': ID_TO_LABEL[predicted_id],
            'categoryName': CATEGORY_MAP.get(ID_TO_LABEL[predicted_id], '未知'),
            'confidence': confidence,
            'probabilities': prob_dict
        }


# 数据集类
class NewsDataset(torch.utils.data.Dataset):
    """新闻数据集"""

    def __init__(self, texts, labels, tokenizer, max_length=512):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]

        encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            padding='max_length',
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)
        }


if __name__ == '__main__':
    # 测试
    classifier = BertClassifier()
    # classifier.load_model('./models/deep_learning/bert_finetuned')
    #
    # result = classifier.predict(
    #     title="华为发布新款折叠屏手机",
    #     content="华为今天正式发布了新一代折叠屏手机..."
    # )
    # print(result)
    print("BERT分类器初始化成功")

Task 4.1.3: src/hybrid/hybrid_classifier.py - Hybrid classifier

"""
混合策略分类器
结合规则引擎和机器学习模型
"""

import time
from typing import Dict, Any
from ..traditional.predict import TraditionalPredictor
from ..deep_learning.bert_model import BertClassifier


class HybridClassifier:
    """混合分类器"""

    def __init__(self):
        # 初始化各个分类器
        self.nb_predictor = TraditionalPredictor('nb')
        self.bert_classifier = BertClassifier()

        # 配置参数
        self.config = {
            'confidence_threshold': 0.75,    # 高置信度阈值
            'hybrid_min_confidence': 0.60,   # 混合模式最低阈值
            'use_bert_threshold': 0.70,      # 使用BERT的阈值
            'rule_priority': True            # 规则优先
        }

        # 规则关键词字典
        self.rule_keywords = {
            'POLITICS': ['政府', '政策', '选举', '国务院', '主席', '总理'],
            'FINANCE': ['股市', '经济', '金融', '投资', '基金', '银行'],
            'TECHNOLOGY': ['芯片', 'AI', '人工智能', '5G', '互联网', '科技'],
            'SPORTS': ['比赛', '冠军', '联赛', '球员', '教练', 'NBA'],
            'ENTERTAINMENT': ['明星', '电影', '电视剧', '娱乐圈', '歌手'],
            'HEALTH': ['健康', '医疗', '疾病', '治疗', '疫苗'],
            'EDUCATION': ['教育', '学校', '大学', '考试', '招生'],
            'LIFE': ['生活', '美食', '旅游', '购物'],
            'INTERNATIONAL': ['国际', '美国', '欧洲', '日本', '外交'],
            'MILITARY': ['军事', '武器', '军队', '国防', '战争']
        }

    def rule_match(self, title: str, content: str) -> tuple[str | None, float]:
        """
        规则匹配
        :return: (category_code, confidence)
        """
        text = title + ' ' + content

        # 计算每个类别的关键词匹配数
        matches = {}
        for category, keywords in self.rule_keywords.items():
            count = sum(1 for kw in keywords if kw in text)
            if count > 0:
                matches[category] = count

        if not matches:
            return None, 0.0

        # 返回匹配最多的类别
        best_category = max(matches, key=matches.get)
        confidence = min(0.9, matches[best_category] * 0.15)  # 规则置信度

        return best_category, confidence

    def predict(self, title: str, content: str, use_bert=True) -> Dict[str, Any]:
        """
        混合预测
        """
        start_time = time.time()

        # 1. 先尝试规则匹配
        rule_category, rule_confidence = self.rule_match(title, content)

        # 2. 传统机器学习预测
        nb_result = self.nb_predictor.predict(title, content)
        nb_confidence = nb_result['confidence']

        # 决策逻辑
        final_result = None
        classifier_type = 'HYBRID'

        # 规则优先且规则置信度高
        if self.config['rule_priority'] and rule_confidence >= self.config['confidence_threshold']:
            final_result = {
                'categoryCode': rule_category,
                'categoryName': nb_result['categoryName'],  # 从映射获取
                'confidence': rule_confidence,
                'classifierType': 'RULE',
                'reason': '规则匹配'
            }
        # 传统模型置信度足够高
        elif nb_confidence >= self.config['confidence_threshold']:
            final_result = {
                **nb_result,
                'classifierType': 'ML',
                'reason': '传统模型高置信度'
            }
        # 需要使用BERT
        elif use_bert:
            # TODO: 加载BERT模型预测
            # bert_result = self.bert_classifier.predict(title, content)
            # 如果BERT置信度也不高选择最高的
            final_result = {
                **nb_result,
                'classifierType': 'HYBRID',
                'reason': '混合决策'
            }
        else:
            # 不使用BERT直接返回传统模型结果
            final_result = {
                **nb_result,
                'classifierType': 'ML',
                'reason': '默认传统模型'
            }

        # 计算耗时
        duration = int((time.time() - start_time) * 1000)
        final_result['duration'] = duration

        return final_result


if __name__ == '__main__':
    # 测试
    classifier = HybridClassifier()

    test_cases = [
        {
            'title': '国务院发布最新经济政策',
            'content': '国务院今天发布了新的经济政策...'
        },
        {
            'title': '华为发布新款折叠屏手机',
            'content': '华为今天正式发布了新一代折叠屏手机...'
        }
    ]

    for case in test_cases:
        result = classifier.predict(case['title'], case['content'])
        print(f"标题: {case['title']}")
        print(f"结果: {result['categoryName']} ({result['confidence']:.2f})")
        print(f"分类器: {result['classifierType']}")
        print(f"原因: {result.get('reason', 'N/A')}")
        print(f"耗时: {result['duration']}ms")
        print("-" * 50)

Task 4.1.4: src/api/server.py - FastAPI service

"""
机器学习模型API服务
使用FastAPI提供RESTful API
"""

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional
import logging

# 导入分类器
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from traditional.predict import TraditionalPredictor
from hybrid.hybrid_classifier import HybridClassifier

# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# 创建FastAPI应用
app = FastAPI(
    title="新闻分类API",
    description="提供新闻文本分类服务",
    version="1.0.0"
)

# 配置CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# 请求模型
class ClassifyRequest(BaseModel):
    title: str
    content: str
    mode: Optional[str] = 'hybrid'  # traditional, hybrid


# 响应模型
class ClassifyResponse(BaseModel):
    categoryCode: str
    categoryName: str
    confidence: float
    classifierType: str
    duration: int
    probabilities: Optional[dict] = None


# 初始化分类器
nb_predictor = None
hybrid_classifier = None


@app.on_event("startup")
async def startup_event():
    """启动时加载模型"""
    global nb_predictor, hybrid_classifier

    logger.info("加载模型...")

    try:
        nb_predictor = TraditionalPredictor('nb')
        logger.info("朴素贝叶斯模型加载成功")
    except Exception as e:
        logger.error(f"朴素贝叶斯模型加载失败: {e}")

    try:
        hybrid_classifier = HybridClassifier()
        logger.info("混合分类器初始化成功")
    except Exception as e:
        logger.error(f"混合分类器初始化失败: {e}")


@app.get("/")
async def root():
    """健康检查"""
    return {
        "status": "ok",
        "message": "新闻分类API服务运行中"
    }


@app.get("/health")
async def health_check():
    """健康检查"""
    return {
        "status": "healthy",
        "models": {
            "nb_loaded": nb_predictor is not None,
            "hybrid_loaded": hybrid_classifier is not None
        }
    }


@app.post("/api/predict", response_model=ClassifyResponse)
async def predict(request: ClassifyRequest):
    """
    文本分类接口

    - **title**: 新闻标题
    - **content**: 新闻内容
    - **mode**: 分类模式 (traditional, hybrid)
    """
    try:
        if request.mode == 'traditional':
            result = nb_predictor.predict(request.title, request.content)
            result['classifierType'] = 'ML'
        else:  # hybrid
            result = hybrid_classifier.predict(request.title, request.content)

        return ClassifyResponse(**result)

    except Exception as e:
        logger.error(f"预测失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/api/batch-predict")
async def batch_predict(requests: list[ClassifyRequest]):
    """
    批量分类接口
    """
    results = []
    for req in requests:
        try:
            if req.mode == 'traditional':
                result = nb_predictor.predict(req.title, req.content)
                result['classifierType'] = 'ML'
            else:
                result = hybrid_classifier.predict(req.title, req.content)
            results.append(result)
        except Exception as e:
            results.append({
                'error': str(e),
                'title': req.title
            })

    return {"results": results}


if __name__ == '__main__':
    import uvicorn

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=5000,
        log_level="info"
    )
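
Once the service is running (for example via python src/api/server.py), the Spring Boot backend or a quick local test can call it over HTTP. A minimal client sketch using the requests library (not listed in requirements.txt) and assuming the default localhost:5000 address:

"""Quick client-side check of the classification API (assumes the service runs on localhost:5000)."""

import requests

resp = requests.post(
    "http://localhost:5000/api/predict",
    json={
        "title": "华为发布新款折叠屏手机",
        "content": "华为今天正式发布了新一代折叠屏手机...",
        "mode": "hybrid"          # or "traditional"
    },
    timeout=10
)
resp.raise_for_status()

data = resp.json()
# Fields mirror ClassifyResponse: categoryCode, categoryName, confidence, classifierType, duration, probabilities
print(data["categoryCode"], data["categoryName"], data["confidence"], data["duration"])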

Task 4.1.5: src/utils/metrics.py - Evaluation metrics

"""
模型评估指标工具
"""

import numpy as np
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    confusion_matrix,
    classification_report
)
from typing import List, Dict, Any
import matplotlib.pyplot as plt
import seaborn as sns


class ClassificationMetrics:
    """分类评估指标"""

    @staticmethod
    def compute_all(y_true: List, y_pred: List, labels: List[str]) -> Dict[str, Any]:
        """
        计算所有指标
        """
        accuracy = accuracy_score(y_true, y_pred)

        precision, recall, f1, support = precision_recall_fscore_support(
            y_true, y_pred, average='weighted', zero_division=0
        )

        # 每个类别的指标
        precision_per_class, recall_per_class, f1_per_class, support_per_class = \
            precision_recall_fscore_support(y_true, y_pred, average=None, zero_division=0)

        per_class_metrics = {}
        for i, label in enumerate(labels):
            per_class_metrics[label] = {
                'precision': float(precision_per_class[i]),
                'recall': float(recall_per_class[i]),
                'f1': float(f1_per_class[i]),
                'support': int(support_per_class[i])
            }

        return {
            'accuracy': float(accuracy),
            'precision': float(precision),
            'recall': float(recall),
            'f1': float(f1),
            'per_class': per_class_metrics
        }

    @staticmethod
    def plot_confusion_matrix(y_true: List, y_pred: List, labels: List[str], save_path: str = None):
        """
        绘制混淆矩阵
        """
        cm = confusion_matrix(y_true, y_pred)

        plt.figure(figsize=(10, 8))
        sns.heatmap(
            cm,
            annot=True,
            fmt='d',
            cmap='Blues',
            xticklabels=labels,
            yticklabels=labels
        )
        plt.xlabel('预测标签')
        plt.ylabel('真实标签')
        plt.title('混淆矩阵')

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close()

    @staticmethod
    def print_report(y_true: List, y_pred: List, labels: List[str]):
        """
        打印分类报告
        """
        report = classification_report(
            y_true, y_pred,
            target_names=labels,
            zero_division=0
        )
        print(report)


if __name__ == '__main__':
    # 测试
    y_true = ['POLITICS', 'TECHNOLOGY', 'FINANCE', 'POLITICS', 'TECHNOLOGY']
    y_pred = ['POLITICS', 'TECHNOLOGY', 'FINANCE', 'TECHNOLOGY', 'TECHNOLOGY']
    labels = ['POLITICS', 'TECHNOLOGY', 'FINANCE']

    metrics = ClassificationMetrics()
    result = metrics.compute_all(y_true, y_pred, labels)
    print(result)
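
src/traditional/evaluate.py from the directory tree has no task of its own; a sketch of how it might combine TraditionalPredictor with ClassificationMetrics on data/processed/test_data.csv follows — the CSV column names (title, content, label) and the output file name are assumptions.

"""
Evaluate a traditional model on the held-out test set
(sketch for src/traditional/evaluate.py; run as `python -m src.traditional.evaluate`
from the ml-module root; CSV columns title/content/label are assumptions).
"""

import pandas as pd

from .predict import TraditionalPredictor
from ..utils.metrics import ClassificationMetrics


def evaluate(model_type: str = 'nb', test_csv: str = 'data/processed/test_data.csv'):
    """Run the model over the test set and report accuracy, per-class metrics and a confusion matrix."""
    df = pd.read_csv(test_csv)
    predictor = TraditionalPredictor(model_type)

    y_true = df['label'].tolist()
    y_pred = [
        predictor.predict(row['title'], row['content'])['categoryCode']
        for _, row in df.iterrows()
    ]

    labels = sorted(set(y_true))
    metrics = ClassificationMetrics.compute_all(y_true, y_pred, labels)
    ClassificationMetrics.print_report(y_true, y_pred, labels)
    ClassificationMetrics.plot_confusion_matrix(
        y_true, y_pred, labels, save_path=f'{model_type}_confusion_matrix.png'
    )
    return metrics


if __name__ == '__main__':
    print(evaluate('nb'))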

Task 4.1.6: requirements.txt - Dependencies

# Machine learning module dependencies
numpy>=1.24.0
pandas>=2.0.0
scikit-learn>=1.3.0
jieba>=0.42.0
joblib>=1.3.0

# Deep learning
torch>=2.0.0
transformers>=4.30.0

# API service
fastapi>=0.100.0
uvicorn[standard]>=0.23.0
pydantic>=2.0.0

# Data visualization
matplotlib>=3.7.0
seaborn>=0.12.0

# Utilities
python-dotenv>=1.0.0
pyyaml>=6.0
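
Install the dependencies with pip install -r requirements.txt. Note that the first run of the BERT code downloads the bert-base-chinese weights from the Hugging Face Hub, so a network connection (or a pre-populated local model cache) is required.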

Summary

Suggested Development Order

  1. Phase 1: Basic framework

    • Backend: database connection, entity classes, base configuration
    • Frontend: routing configuration, state management, API wrappers
  2. Phase 2: Core features

    • Crawler module (Python)
    • Traditional machine learning classifier
    • Backend API endpoints
    • Frontend news list page
  3. Phase 3: Advanced features

    • BERT deep learning classifier
    • Hybrid strategy classifier
    • Frontend classifier comparison page
    • Statistics charts
  4. Phase 4: Polish and optimization

    • User authentication
    • Data visualization
    • Performance optimization
    • Error handling

Key Notes

  1. The crawler module is written in Python and communicates with the Java backend via a RESTful API
  2. The classifier module is deployed independently and exposes an HTTP interface for the backend to call
  3. The frontend and backend are separated; authentication uses JWT
  4. The database schema is defined in schema.sql and must be followed strictly
  5. All API responses use the unified Result<T> wrapper format