news-classifier/docs/模块开发任务清单.md

42 KiB
Raw Blame History

新闻文本分类系统 - 模块开发任务清单

本文档详细列出每个模块需要完成的具体代码任务,参考现有工程结构。 注意:爬虫模块使用 Python 实现,而非 Java。


目录

  1. 爬虫模块 (Python)
  2. 后端服务模块 (Spring Boot)
  3. 前端桌面模块 (Tauri + Vue3)
  4. 机器学习分类模块 (Python)

1. 爬虫模块 (Python)

2. 后端服务模块 (Spring Boot)

模块目录结构


#### 任务 2.1.10: `application.yml` - 应用配置文件

```yaml
# Spring core settings: application name, MySQL datasource and Redis connection.
spring:
  application:
    name: news-classifier
  datasource:
    driver-class-name: com.mysql.cj.jdbc.Driver
    # characterEncoding expects a Java charset name. The MySQL-side name
    # "utf8mb4" is rejected by Connector/J as an unsupported encoding; use
    # UTF-8 instead (recent Connector/J 8.x releases map it to utf8mb4).
    url: jdbc:mysql://localhost:3306/news_classifier?useUnicode=true&characterEncoding=UTF-8&serverTimezone=Asia/Shanghai
    username: root
    password: your_password
  data:
    redis:
      host: localhost
      port: 6379
      database: 0

# MyBatis-Plus配置
mybatis-plus:
  mapper-locations: classpath:mapper/**/*.xml
  type-aliases-package: com.newsclassifier.entity
  configuration:
    map-underscore-to-camel-case: true
    log-impl: org.apache.ibatis.logging.stdout.StdOutImpl
  global-config:
    db-config:
      id-type: auto
      logic-delete-field: deleted
      logic-delete-value: 1
      logic-not-delete-value: 0

# JWT配置
jwt:
  secret: your-secret-key-at-least-256-bits-long-for-hs256-algorithm
  expiration: 86400000

# 分类器配置
classifier:
  mode: hybrid  # traditional, deep_learning, hybrid
  confidence:
    threshold: 0.75
    hybrid-min: 0.6
  bert:
    service-url: http://localhost:5000/api/predict
    timeout: 5000

# 日志配置
logging:
  level:
    com.newsclassifier: debug
  pattern:
    console: "%d{yyyy-MM-dd HH:mm:ss} [%thread] %-5level %logger{36} - %msg%n"

```

3. 前端桌面模块 (Tauri + Vue3)

模块目录结构

client/src/
├── api/                    # API接口
│   ├── index.ts
│   ├── auth.ts
│   ├── news.ts
│   ├── category.ts
│   └── classifier.ts
├── assets/                 # 静态资源
│   ├── images/
│   ├── styles/
│   │   ├── main.css
│   │   └── tailwind.css
│   └── fonts/
├── components/             # 组件
│   ├── ui/                 # 基础UI组件
│   │   ├── button/
│   │   ├── input/
│   │   ├── dialog/
│   │   └── table/
│   ├── layout/             # 布局组件
│   │   ├── Header.vue
│   │   ├── Sidebar.vue
│   │   └── Footer.vue
│   ├── news/               # 新闻相关组件
│   │   ├── NewsCard.vue
│   │   ├── NewsList.vue
│   │   ├── NewsDetail.vue
│   │   └── CategoryFilter.vue
│   └── charts/             # 图表组件
│       ├── CategoryChart.vue
│       └── TrendChart.vue
├── composables/            # 组合式函数
│   ├── useAuth.ts
│   ├── useNews.ts
│   ├── useClassifier.ts
│   └── useToast.ts
├── layouts/                # 布局
│   ├── DefaultLayout.vue
│   ├── AuthLayout.vue
│   └── EmptyLayout.vue
├── router/                 # 路由 (部分完成)
│   └── index.ts
├── stores/                 # 状态管理 (部分完成)
│   ├── user.ts
│   ├── news.ts
│   └── category.ts
├── types/                  # TypeScript类型
│   ├── api.d.ts
│   ├── news.d.ts
│   └── user.d.ts
├── utils/                  # 工具函数
│   ├── request.ts
│   ├── storage.ts
│   ├── format.ts
│   └── validate.ts
├── views/                  # 页面
│   ├── auth/
│   │   ├── Login.vue
│   │   └── Register.vue
│   ├── news/
│   │   ├── NewsList.vue
│   │   ├── NewsDetail.vue
│   │   └── NewsSearch.vue
│   ├── category/
│   │   ├── CategoryManage.vue
│   │   └── CategoryStats.vue
│   ├── classifier/
│   │   ├── ClassifierPage.vue
│   │   └── ModelCompare.vue
│   └── admin/
│       ├── Dashboard.vue
│       ├── UserManage.vue
│       └── SystemLog.vue
├── App.vue
└── main.ts

3.1 需要完成的具体文件

任务 3.1.1: utils/request.ts - HTTP请求封装

import axios, { AxiosInstance, AxiosRequestConfig, AxiosResponse } from 'axios'

/**
 * Envelope the response interceptor in this file expects every backend
 * response body to follow.
 */
interface ApiResponse<T = any> {
  /** Business status code; the response interceptor treats 200 as success. */
  code: number
  /** Text used for the rejection error when `code` is not 200. */
  message: string
  /** Payload handed back to the caller on success. */
  data: T
}

// Shared axios instance. The base URL comes from VITE_API_BASE_URL and falls
// back to the local backend; `timeout` is 15000 (axios interprets it as
// milliseconds) and request bodies default to JSON.
const service: AxiosInstance = axios.create({
  baseURL: import.meta.env.VITE_API_BASE_URL || 'http://localhost:8080/api',
  timeout: 15000,
  headers: {
    'Content-Type': 'application/json'
  }
})

// Request interceptor: when a token is stored under the 'token' key, send it
// as a Bearer Authorization header; otherwise pass the request through as-is.
service.interceptors.request.use(
  (requestConfig) => {
    const storedToken = localStorage.getItem('token')
    if (!storedToken) {
      return requestConfig
    }
    requestConfig.headers.Authorization = `Bearer ${storedToken}`
    return requestConfig
  },
  (requestError) => Promise.reject(requestError)
)

// Response interceptor: unwrap the ApiResponse envelope so callers receive
// the `data` payload directly.
service.interceptors.response.use(
  (response: AxiosResponse<ApiResponse>) => {
    const body = response.data
    if (body.code !== 200) {
      // Business-level failure: reject with the backend message (or a default).
      return Promise.reject(new Error(body.message || '请求失败'))
    }
    return body.data
  },
  // Transport/HTTP-level failure: pass the original error through untouched.
  (httpError) => Promise.reject(httpError)
)

// Thin typed wrappers over the shared axios instance. Because the response
// interceptor resolves with the unwrapped payload, each method is typed as
// returning `Promise<T>` rather than an axios response.
export const http = {
  get: <T = any>(url: string, config?: AxiosRequestConfig): Promise<T> =>
    service.get(url, config),

  post: <T = any>(url: string, data?: any, config?: AxiosRequestConfig): Promise<T> =>
    service.post(url, data, config),

  put: <T = any>(url: string, data?: any, config?: AxiosRequestConfig): Promise<T> =>
    service.put(url, data, config),

  delete: <T = any>(url: string, config?: AxiosRequestConfig): Promise<T> =>
    service.delete(url, config)
}

export default service

任务 3.1.2: api/news.ts - 新闻API

// The request wrapper is utils/request.ts; there is no request.ts next to this
// file in api/, so './request' cannot resolve. Import it through the '@'
// source alias that the rest of the client code uses.
import { http } from '@/utils/request'

/**
 * Optional paging and filter values; `newsApi.getNewsPage` sends them as the
 * query string of `/news/page`.
 */
export interface NewsQueryParams {
  page?: number
  size?: number
  categoryId?: number
  categoryCode?: string
  keyword?: string
  status?: number
}

/**
 * A single news record. The page, search and detail calls in `newsApi` are
 * all typed with this shape.
 */
export interface NewsDetail {
  id: number
  title: string
  content: string
  summary: string
  source: string
  sourceUrl: string
  author: string
  categoryId: number
  categoryCode: string
  coverImage: string
  publishTime: string
  viewCount: number
  likeCount: number
  commentCount: number
  // Classification metadata.
  // NOTE(review): presumably `confidence` is a fraction in [0, 1] — confirm with the backend.
  classifierType: string
  confidence: number
}

/** Paged result wrapper: one page of `records` plus the paging metadata returned with it. */
export interface PageResponse<T> {
  total: number
  records: T[]
  current: number
  size: number
}

// News endpoints, all routed through the shared `http` wrapper.
export const newsApi = {
  // Paged listing; `params` is sent as the query string.
  getNewsPage: (params: NewsQueryParams): Promise<PageResponse<NewsDetail>> =>
    http.get('/news/page', { params }),

  // Single record by id.
  getNewsDetail: (id: number): Promise<NewsDetail> => http.get(`/news/${id}`),

  // Keyword search; defaults to the first page of 20 results.
  searchNews: (keyword: string, page = 1, size = 20): Promise<PageResponse<NewsDetail>> => {
    const params = { keyword, page, size }
    return http.get('/news/search', { params })
  },

  // Manual classification: POST with no body, target category in the query string.
  manualClassify: (id: number, categoryId: number): Promise<void> => {
    const query = { categoryId }
    return http.post(`/news/${id}/classify`, null, { params: query })
  }
}

任务 3.1.3: composables/useNews.ts - 新闻组合式函数

import { ref, computed } from 'vue'
import { newsApi, type NewsQueryParams, type NewsDetail, type PageResponse } from '@/api/news'

/**
 * Composable holding news list/detail state together with the loaders that
 * fill it. State is returned wrapped in `computed`, so consumers only get
 * read-only views of the underlying refs.
 */
export function useNews() {
  const loading = ref(false)
  const newsList = ref<NewsDetail[]>([])
  const total = ref(0)
  const currentNews = ref<NewsDetail | null>(null)

  // Runs one request with the shared loading flag raised. A failure is logged
  // under `failureLabel` and rethrown; the flag is always lowered afterwards.
  const withLoading = async <T>(failureLabel: string, request: () => Promise<T>): Promise<T> => {
    loading.value = true
    try {
      return await request()
    } catch (error) {
      console.error(failureLabel, error)
      throw error
    } finally {
      loading.value = false
    }
  }

  // Copies one page of results into the list state.
  const applyPage = (page: PageResponse<NewsDetail>) => {
    newsList.value = page.records
    total.value = page.total
  }

  // Paged listing.
  const fetchNewsPage = (params: NewsQueryParams) =>
    withLoading('获取新闻列表失败:', async () => {
      applyPage(await newsApi.getNewsPage(params))
    })

  // Single record; also resolves with the loaded value.
  const fetchNewsDetail = (id: number) =>
    withLoading('获取新闻详情失败:', async () => {
      currentNews.value = await newsApi.getNewsDetail(id)
      return currentNews.value
    })

  // Keyword search; results replace the list state just like a page load.
  const searchNews = (keyword: string, page = 1, size = 20) =>
    withLoading('搜索新闻失败:', async () => {
      applyPage(await newsApi.searchNews(keyword, page, size))
    })

  return {
    loading: computed(() => loading.value),
    newsList: computed(() => newsList.value),
    total: computed(() => total.value),
    currentNews: computed(() => currentNews.value),
    fetchNewsPage,
    fetchNewsDetail,
    searchNews
  }
}

任务 3.1.4: views/news/NewsList.vue - 新闻列表页面

<template>
  <div class="news-list-container">
    <!-- 分类筛选 -->
    <div class="category-filter">
      <button
        v-for="cat in categories"
        :key="cat.code"
        :class="{ active: selectedCategory === cat.code }"
        @click="selectCategory(cat.code)"
      >
        {{ cat.name }}
      </button>
    </div>

    <!-- 搜索框 -->
    <div class="search-bar">
      <input
        v-model="searchKeyword"
        type="text"
        placeholder="搜索新闻..."
        @keyup.enter="handleSearch"
      />
      <button @click="handleSearch">搜索</button>
    </div>

    <!-- 新闻列表 -->
    <div v-if="!loading && newsList.length > 0" class="news-list">
      <div v-for="news in newsList" :key="news.id" class="news-card" @click="viewDetail(news.id)">
        <h3>{{ news.title }}</h3>
        <p class="summary">{{ news.summary }}</p>
        <div class="meta">
          <span class="category">{{ getCategoryName(news.categoryCode) }}</span>
          <span class="source">{{ news.source }}</span>
          <span class="time">{{ formatDate(news.publishTime) }}</span>
        </div>
        <div v-if="news.classifierType" class="classifier-info">
          <span class="badge">{{ news.classifierType }}</span>
          <span class="confidence">{{ (news.confidence * 100).toFixed(1) }}%</span>
        </div>
      </div>
    </div>

    <!-- 加载中 -->
    <div v-if="loading" class="loading">加载中...</div>

    <!-- 空状态 -->
    <div v-if="!loading && newsList.length === 0" class="empty">暂无数据</div>

    <!-- 分页 -->
    <div v-if="total > 0" class="pagination">
      <button :disabled="currentPage <= 1" @click="changePage(currentPage - 1)">上一页</button>
      <span>{{ currentPage }} / {{ totalPages }}</span>
      <button :disabled="currentPage >= totalPages" @click="changePage(currentPage + 1)">下一页</button>
    </div>
  </div>
</template>

<script setup lang="ts">
import { ref, computed, onMounted } from 'vue'
import { useRouter } from 'vue-router'
import { useNews } from '@/composables/useNews'

const router = useRouter()
const { loading, newsList, total, fetchNewsPage, searchNews } = useNews()

const categories = ref([
  { code: '', name: '全部' },
  { code: 'POLITICS', name: '时政' },
  { code: 'FINANCE', name: '财经' },
  { code: 'TECHNOLOGY', name: '科技' },
  { code: 'SPORTS', name: '体育' }
])

const selectedCategory = ref('')
const searchKeyword = ref('')
const currentPage = ref(1)
const pageSize = ref(20)

const totalPages = computed(() => Math.ceil(total.value / pageSize.value))

// Keyword of the search currently driving the list ('' = plain browsing).
// Kept separate from `searchKeyword`, which mirrors the live input, so that
// paging re-runs the query the user actually submitted.
const activeKeyword = ref('')

// Load the current page: from the search endpoint while a search is active,
// otherwise from the paged listing with the selected category filter.
const loadNews = async () => {
  if (activeKeyword.value) {
    await searchNews(activeKeyword.value, currentPage.value, pageSize.value)
    return
  }
  await fetchNewsPage({
    page: currentPage.value,
    size: pageSize.value,
    categoryCode: selectedCategory.value || undefined
  })
}

// Select a category: leaves search mode and restarts from the first page.
// The promise is returned so a failed load is not an unobserved rejection.
const selectCategory = (code: string) => {
  selectedCategory.value = code
  activeKeyword.value = ''
  currentPage.value = 1
  return loadNews()
}

// Submit the search box. A new query always starts at page 1 (the previous
// page number belongs to a different result set); a blank query returns to
// plain browsing.
const handleSearch = () => {
  activeKeyword.value = searchKeyword.value.trim()
  currentPage.value = 1
  return loadNews()
}

// Open the detail page of one news item.
const viewDetail = (id: number) => {
  router.push(`/news/${id}`)
}

// Change page within whatever is currently shown (search results or listing).
const changePage = (page: number) => {
  currentPage.value = page
  return loadNews()
}

// Format a timestamp string as local-time `YYYY-MM-DD HH:mm`.
// Empty or unparseable input is returned unchanged so the UI still shows
// whatever the backend sent instead of "Invalid Date".
const formatDate = (dateStr: string) => {
  const parsed = new Date(dateStr)
  if (!dateStr || Number.isNaN(parsed.getTime())) {
    return dateStr
  }
  const pad = (value: number) => String(value).padStart(2, '0')
  const day = `${parsed.getFullYear()}-${pad(parsed.getMonth() + 1)}-${pad(parsed.getDate())}`
  return `${day} ${pad(parsed.getHours())}:${pad(parsed.getMinutes())}`
}

// Resolve a category code to its display name using the filter list;
// falls back to the raw code when there is no match (or the name is empty).
const getCategoryName = (code: string) => {
  for (const category of categories.value) {
    if (category.code === code) {
      return category.name || code
    }
  }
  return code
}

onMounted(() => {
  loadNews()
})
</script>

<style scoped>
.news-list-container {
  padding: 20px;
}

.category-filter {
  display: flex;
  gap: 10px;
  margin-bottom: 20px;
}

.category-filter button {
  padding: 8px 16px;
  border: 1px solid #ddd;
  border-radius: 4px;
  background: white;
  cursor: pointer;
}

.category-filter button.active {
  background: #1890ff;
  color: white;
  border-color: #1890ff;
}

.news-card {
  padding: 15px;
  border: 1px solid #eee;
  border-radius: 8px;
  margin-bottom: 15px;
  cursor: pointer;
  transition: box-shadow 0.2s;
}

.news-card:hover {
  box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1);
}

.meta {
  display: flex;
  gap: 15px;
  font-size: 12px;
  color: #999;
  margin-top: 10px;
}

.classifier-info {
  display: flex;
  gap: 8px;
  margin-top: 8px;
}

.badge {
  padding: 2px 8px;
  background: #f0f0f0;
  border-radius: 4px;
  font-size: 12px;
}

.pagination {
  display: flex;
  justify-content: center;
  align-items: center;
  gap: 15px;
  margin-top: 20px;
}
</style>

任务 3.1.5: views/classifier/ClassifierPage.vue - 分类器页面

<template>
  <div class="classifier-page">
    <div class="page-header">
      <h2>文本分类</h2>
    </div>

    <!-- 分类模式选择 -->
    <div class="mode-selector">
      <label>分类模式:</label>
      <select v-model="selectedMode">
        <option value="traditional">传统机器学习 (TF-IDF + NB)</option>
        <option value="deep_learning">深度学习 (BERT)</option>
        <option value="hybrid">混合模式</option>
      </select>
    </div>

    <!-- 输入区域 -->
    <div class="input-area">
      <div class="form-group">
        <label>新闻标题</label>
        <input v-model="formData.title" type="text" placeholder="请输入新闻标题" />
      </div>
      <div class="form-group">
        <label>新闻内容</label>
        <textarea v-model="formData.content" placeholder="请输入新闻内容" rows="10"></textarea>
      </div>
      <button @click="handleClassify" :disabled="classifying">
        {{ classifying ? '分类中...' : '开始分类' }}
      </button>
    </div>

    <!-- 分类结果 -->
    <div v-if="result" class="result-area">
      <h3>分类结果</h3>
      <div class="result-item">
        <span class="label">分类:</span>
        <span class="value">{{ result.categoryName }} ({{ result.categoryCode }})</span>
      </div>
      <div class="result-item">
        <span class="label">置信度:</span>
        <span class="value">{{ (result.confidence * 100).toFixed(2) }}%</span>
      </div>
      <div class="result-item">
        <span class="label">分类器:</span>
        <span class="value">{{ result.classifierType }}</span>
      </div>
      <div class="result-item">
        <span class="label">耗时:</span>
        <span class="value">{{ result.duration }}ms</span>
      </div>

      <!-- 概率分布图 -->
      <div v-if="result.probabilities" class="probabilities">
        <h4>各类别概率</h4>
        <div v-for="(prob, code) in result.probabilities" :key="code" class="prob-bar">
          <span class="cat-name">{{ getCategoryName(code) }}</span>
          <div class="bar-container">
            <div class="bar" :style="{ width: (prob * 100) + '%' }"></div>
          </div>
          <span class="prob-value">{{ (prob * 100).toFixed(1) }}%</span>
        </div>
      </div>
    </div>
  </div>
</template>

<script setup lang="ts">
import { ref } from 'vue'
import { classifierApi } from '@/api/classifier'

/** Shape of the classification result rendered by this page's result panel. */
interface ClassificationResult {
  categoryCode: string
  categoryName: string
  // Rendered as a percentage via `confidence * 100`.
  confidence: number
  classifierType: string
  // Rendered with an "ms" suffix.
  duration: number
  // Optional per-category values keyed by category code; drives the bar list.
  probabilities?: Record<string, number>
}

const formData = ref({
  title: '',
  content: ''
})

const selectedMode = ref('hybrid')
const classifying = ref(false)
const result = ref<ClassificationResult | null>(null)

// Validate the form, send title/content and the selected mode to the
// classifier API, and store the outcome for the result panel.
const handleClassify = async () => {
  if (!formData.value.title.trim() || !formData.value.content.trim()) {
    alert('请输入标题和内容')
    return
  }

  classifying.value = true
  // Drop the previous result first: if this request fails, the panel must not
  // keep showing an older classification as if it belonged to the new text.
  result.value = null
  try {
    result.value = await classifierApi.classify({
      title: formData.value.title,
      content: formData.value.content,
      mode: selectedMode.value
    })
  } catch (error) {
    console.error('分类失败:', error)
    alert('分类失败,请重试')
  } finally {
    classifying.value = false
  }
}

// Display names for every category code the classifier can return. Mirrors
// CATEGORY_MAP in the Python ML module, which defines ten categories.
const CATEGORY_NAMES: Record<string, string> = {
  POLITICS: '时政',
  FINANCE: '财经',
  TECHNOLOGY: '科技',
  SPORTS: '体育',
  ENTERTAINMENT: '娱乐',
  HEALTH: '健康',
  EDUCATION: '教育',
  LIFE: '生活',
  INTERNATIONAL: '国际',
  MILITARY: '军事'
}

// Map a category code to its display name; unknown codes are shown as-is.
// The own-property check keeps inherited keys such as "constructor" from
// being mistaken for categories.
const getCategoryName = (code: string) => {
  const known = Object.prototype.hasOwnProperty.call(CATEGORY_NAMES, code)
  return (known && CATEGORY_NAMES[code]) || code
}
</script>

<style scoped>
.classifier-page {
  padding: 20px;
  max-width: 800px;
  margin: 0 auto;
}

.mode-selector {
  margin-bottom: 20px;
}

.mode-selector select {
  padding: 8px;
  border-radius: 4px;
  border: 1px solid #ddd;
}

.input-area {
  background: #f9f9f9;
  padding: 20px;
  border-radius: 8px;
}

.form-group {
  margin-bottom: 15px;
}

.form-group label {
  display: block;
  margin-bottom: 5px;
  font-weight: 500;
}

.form-group input,
.form-group textarea {
  width: 100%;
  padding: 10px;
  border: 1px solid #ddd;
  border-radius: 4px;
  box-sizing: border-box;
}

.result-area {
  margin-top: 20px;
  padding: 20px;
  background: #f0f7ff;
  border-radius: 8px;
}

.result-item {
  display: flex;
  padding: 8px 0;
}

.result-item .label {
  width: 80px;
  font-weight: 500;
}

.probabilities {
  margin-top: 20px;
}

.prob-bar {
  display: flex;
  align-items: center;
  margin-bottom: 10px;
}

.cat-name {
  width: 80px;
}

.bar-container {
  flex: 1;
  height: 20px;
  background: #e0e0e0;
  border-radius: 4px;
  overflow: hidden;
  margin: 0 10px;
}

.bar {
  height: 100%;
  background: linear-gradient(90deg, #1890ff, #52c41a);
  transition: width 0.3s;
}

.prob-value {
  width: 60px;
  text-align: right;
}
</style>

任务 3.1.6: router/index.ts - 路由配置 (更新)

import { createRouter, createWebHistory } from 'vue-router'
import type { RouteRecordRaw } from 'vue-router'

// Route table. Every view is lazy-loaded through a dynamic import. `meta`
// flags are read by the navigation guard registered on this router:
// `requiresAuth` needs a stored token, `requiresAdmin` needs the ADMIN role.
const routes: RouteRecordRaw[] = [
  {
    // Login is the only route without `requiresAuth`.
    path: '/login',
    name: 'Login',
    component: () => import('@/views/auth/Login.vue'),
    meta: { layout: 'EmptyLayout' }
  },
  {
    // '/' and '/news' both render the news list view.
    path: '/',
    name: 'Home',
    component: () => import('@/views/news/NewsList.vue'),
    meta: { requiresAuth: true }
  },
  {
    path: '/news',
    name: 'NewsList',
    component: () => import('@/views/news/NewsList.vue'),
    meta: { requiresAuth: true }
  },
  {
    path: '/news/:id',
    name: 'NewsDetail',
    component: () => import('@/views/news/NewsDetail.vue'),
    meta: { requiresAuth: true }
  },
  {
    path: '/classifier',
    name: 'Classifier',
    component: () => import('@/views/classifier/ClassifierPage.vue'),
    meta: { requiresAuth: true }
  },
  {
    path: '/category',
    name: 'CategoryStats',
    component: () => import('@/views/category/CategoryStats.vue'),
    meta: { requiresAuth: true }
  },
  {
    path: '/admin',
    name: 'AdminDashboard',
    component: () => import('@/views/admin/Dashboard.vue'),
    meta: { requiresAuth: true, requiresAdmin: true }
  }
]

// Router instance using HTML5 history mode.
const router = createRouter({
  history: createWebHistory(),
  routes
})

// Navigation guard.
//  - Routes flagged `requiresAuth` redirect to /login when no token is stored.
//  - Routes flagged `requiresAdmin` redirect to / unless the stored role is ADMIN.
//  - Everything else proceeds.
router.beforeEach((to, from, next) => {
  const token = localStorage.getItem('token')

  const needsLogin = Boolean(to.meta.requiresAuth) && !token
  if (needsLogin) {
    next('/login')
    return
  }

  if (!to.meta.requiresAdmin) {
    next()
    return
  }

  // Admin-only route: check the role kept in localStorage.
  const isAdmin = localStorage.getItem('userRole') === 'ADMIN'
  if (isAdmin) {
    next()
  } else {
    next('/')
  }
})

export default router

4. 机器学习分类模块 (Python)

模块目录结构

ml-module/
├── data/
│   ├── raw/                   # 原始数据
│   ├── processed/             # 处理后的数据
│   │   ├── training_data.csv
│   │   └── test_data.csv
│   └── external/              # 外部数据集
├── models/                    # 训练好的模型
│   ├── traditional/
│   │   ├── nb_vectorizer.pkl
│   │   ├── nb_classifier.pkl
│   │   ├── svm_vectorizer.pkl
│   │   └── svm_classifier.pkl
│   ├── deep_learning/
│   │   └── bert_finetuned/
│   └── hybrid/
│       └── config.json
├── src/
│   ├── __init__.py
│   ├── traditional/           # 传统机器学习
│   │   ├── __init__.py
│   │   ├── train_model.py     # (已有)
│   │   ├── predict.py
│   │   └── evaluate.py
│   ├── deep_learning/         # 深度学习
│   │   ├── __init__.py
│   │   ├── bert_model.py
│   │   ├── train_bert.py
│   │   └── predict_bert.py
│   ├── hybrid/                # 混合策略
│   │   ├── __init__.py
│   │   ├── hybrid_classifier.py
│   │   └── rule_engine.py
│   ├── utils/
│   │   ├── __init__.py
│   │   ├── preprocessing.py   # 数据预处理
│   │   └── metrics.py         # 评估指标
│   └── api/                   # API服务
│       ├── __init__.py
│       └── server.py          # FastAPI服务
├── notebooks/                 # Jupyter notebooks
│   ├── data_exploration.ipynb
│   └── model_comparison.ipynb
├── tests/                     # 测试
│   ├── test_traditional.py
│   ├── test_bert.py
│   └── test_hybrid.py
├── requirements.txt
├── setup.py
└── README.md

4.1 需要完成的具体文件

任务 4.1.1: src/traditional/predict.py - 传统模型预测

"""
传统机器学习模型预测
"""

import os
import joblib
import jieba
from typing import Dict, Any

# 分类映射
CATEGORY_MAP = {
    'POLITICS': '时政',
    'FINANCE': '财经',
    'TECHNOLOGY': '科技',
    'SPORTS': '体育',
    'ENTERTAINMENT': '娱乐',
    'HEALTH': '健康',
    'EDUCATION': '教育',
    'LIFE': '生活',
    'INTERNATIONAL': '国际',
    'MILITARY': '军事'
}


class TraditionalPredictor:
    """Predictor backed by a pickled vectorizer/classifier pair.

    Both artifacts are loaded at construction time from
    ``<model_dir>/<model_type>_vectorizer.pkl`` and
    ``<model_dir>/<model_type>_classifier.pkl``.
    """

    def __init__(self, model_type='nb', model_dir=None):
        """
        :param model_type: file-name prefix of the artifacts to load.
        :param model_dir: directory holding the artifacts. When omitted it
            resolves to ``models/traditional`` two levels above this file, so
            loading works regardless of the current working directory.
        """
        if model_dir is None:
            model_dir = os.path.join(
                os.path.dirname(os.path.abspath(__file__)),
                '..', '..', 'models', 'traditional'
            )
        self.model_type = model_type
        self.model_dir = model_dir
        self.vectorizer = None
        self.classifier = None
        self._load_model()

    def _load_model(self):
        """Load the vectorizer and classifier for ``self.model_type``."""
        vectorizer_path = os.path.join(self.model_dir, f'{self.model_type}_vectorizer.pkl')
        classifier_path = os.path.join(self.model_dir, f'{self.model_type}_classifier.pkl')

        # NOTE: joblib.load unpickles the files; only load artifacts from a
        # trusted location.
        self.vectorizer = joblib.load(vectorizer_path)
        self.classifier = joblib.load(classifier_path)
        print(f"模型加载成功: {self.model_type}")

    def preprocess(self, title: str, content: str) -> str:
        """Join title and content, segment with jieba, return space-separated tokens."""
        text = title + ' ' + content
        words = jieba.cut(text)
        return ' '.join(words)

    def predict(self, title: str, content: str) -> Dict[str, Any]:
        """
        Classify one news item.

        :return: dict with ``categoryCode``, ``categoryName``, ``confidence``
            (the highest class probability) and ``probabilities`` (one entry
            per class in ``classifier.classes_``).
        """
        processed = self.preprocess(title, content)
        features = self.vectorizer.transform([processed])

        # Labels are converted to plain ``str`` and scores to plain ``float``
        # so the result holds only built-in types (the classifier may hand
        # back numpy scalars).
        prediction = str(self.classifier.predict(features)[0])
        probabilities = self.classifier.predict_proba(features)[0]

        prob_dict = {
            str(category_code): float(prob)
            for category_code, prob in zip(self.classifier.classes_, probabilities)
        }

        return {
            'categoryCode': prediction,
            'categoryName': CATEGORY_MAP.get(prediction, '未知'),
            'confidence': float(probabilities.max()),
            'probabilities': prob_dict
        }


# Predictors already built by predict_single, keyed by model type, so repeated
# calls reuse the loaded artifacts instead of reading the pickles every time.
_PREDICTOR_CACHE: Dict[str, Any] = {}


# API entry point
def predict_single(title: str, content: str, model_type='nb') -> Dict[str, Any]:
    """
    Classify a single news item.

    :param title: news title.
    :param content: news body.
    :param model_type: which traditional model to use; the predictor for each
        model type is created on first use and reused afterwards.
    :return: the dict produced by ``TraditionalPredictor.predict``.
    """
    predictor = _PREDICTOR_CACHE.get(model_type)
    if predictor is None:
        predictor = TraditionalPredictor(model_type)
        _PREDICTOR_CACHE[model_type] = predictor
    return predictor.predict(title, content)


if __name__ == '__main__':
    # 测试
    result = predict_single(
        title="华为发布新款折叠屏手机",
        content="华为今天正式发布了新一代折叠屏手机,搭载最新麒麟芯片..."
    )
    print(result)

任务 4.1.2: src/deep_learning/bert_model.py - BERT模型

"""
BERT文本分类模型
"""

import torch
from transformers import (
    BertTokenizer,
    BertForSequenceClassification,
    Trainer,
    TrainingArguments
)
from typing import Dict, Any, List


# 分类映射
CATEGORY_MAP = {
    'POLITICS': '时政',
    'FINANCE': '财经',
    'TECHNOLOGY': '科技',
    'SPORTS': '体育',
    'ENTERTAINMENT': '娱乐',
    'HEALTH': '健康',
    'EDUCATION': '教育',
    'LIFE': '生活',
    'INTERNATIONAL': '国际',
    'MILITARY': '军事'
}

# 反向映射
ID_TO_LABEL = {i: label for i, label in enumerate(CATEGORY_MAP.keys())}
LABEL_TO_ID = {label: i for i, label in enumerate(CATEGORY_MAP.keys())}


class BertClassifier:
    """Wrapper around a fine-tuned BERT sequence-classification checkpoint.

    The instance starts empty; call ``load_model`` before ``predict``.
    """

    def __init__(self, model_name='bert-base-chinese', num_labels=10):
        """
        :param model_name: stored on the instance (not used for loading here).
        :param num_labels: number of output classes passed to the model loader.
        """
        self.model_name = model_name
        self.num_labels = num_labels
        self.tokenizer = None
        self.model = None
        # Use the GPU when one is available, otherwise the CPU.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def load_model(self, model_path):
        """Load tokenizer and fine-tuned weights from ``model_path`` and switch to eval mode."""
        self.tokenizer = BertTokenizer.from_pretrained(model_path)
        self.model = BertForSequenceClassification.from_pretrained(
            model_path,
            num_labels=self.num_labels
        )
        self.model.to(self.device)
        self.model.eval()
        print(f"BERT模型加载成功: {model_path}")

    def predict(self, title: str, content: str) -> Dict[str, Any]:
        """
        Classify one news item.

        :return: dict with ``categoryCode``, ``categoryName``, ``confidence``
            (softmax probability of the predicted class) and ``probabilities``
            (one entry per class index, keyed through ``ID_TO_LABEL``).
        :raises ValueError: if ``load_model`` has not been called.
        """
        if self.model is None or self.tokenizer is None:
            raise ValueError("模型未加载请先调用load_model")

        # Title and content are joined into one input string.
        text = f"{title} [SEP] {content}"

        # Only one sequence is encoded, so there is nothing to pad against:
        # truncate to at most 512 tokens but skip padding='max_length', which
        # would make every request run the model over 512 positions.
        inputs = self.tokenizer(
            text,
            return_tensors='pt',
            truncation=True,
            max_length=512
        )

        with torch.no_grad():
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            logits = self.model(**inputs).logits
            # Batch size is 1: keep the single row of class probabilities.
            probs = torch.softmax(logits, dim=-1)[0]

        predicted_id = int(torch.argmax(probs).item())
        confidence = float(probs[predicted_id].item())

        prob_dict = {
            ID_TO_LABEL[i]: float(prob)
            for i, prob in enumerate(probs.cpu().tolist())
        }

        predicted_code = ID_TO_LABEL[predicted_id]
        return {
            'categoryCode': predicted_code,
            'categoryName': CATEGORY_MAP.get(predicted_code, '未知'),
            'confidence': confidence,
            'probabilities': prob_dict
        }


# Dataset class
class NewsDataset(torch.utils.data.Dataset):
    """Torch dataset that tokenizes one text per item when it is accessed."""

    def __init__(self, texts, labels, tokenizer, max_length=512):
        self.texts, self.labels = texts, labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` (both flattened) and ``labels`` for item ``idx``."""
        text, label = self.texts[idx], self.labels[idx]

        tokenizer_options = dict(
            truncation=True,
            max_length=self.max_length,
            padding='max_length',
            return_tensors='pt',
        )
        encoded = self.tokenizer(text, **tokenizer_options)

        sample = {
            field: encoded[field].flatten()
            for field in ('input_ids', 'attention_mask')
        }
        sample['labels'] = torch.tensor(label, dtype=torch.long)
        return sample


if __name__ == '__main__':
    # 测试
    classifier = BertClassifier()
    # classifier.load_model('./models/deep_learning/bert_finetuned')
    #
    # result = classifier.predict(
    #     title="华为发布新款折叠屏手机",
    #     content="华为今天正式发布了新一代折叠屏手机..."
    # )
    # print(result)
    print("BERT分类器初始化成功")

任务 4.1.3: src/hybrid/hybrid_classifier.py - 混合分类器

"""
混合策略分类器
结合规则引擎和机器学习模型
"""

import time
from typing import Dict, Any
from ..traditional.predict import TraditionalPredictor
from ..deep_learning.bert_model import BertClassifier


class HybridClassifier:
    """Hybrid classifier combining keyword rules with the traditional model.

    ``predict`` tries the keyword rules first and falls back to the
    traditional predictor; the BERT branch is still a placeholder that
    returns the traditional result.
    """

    def __init__(self):
        # Imported here so the rule branch can use the same code -> display
        # name table as the traditional predictor.
        from ..traditional.predict import CATEGORY_MAP

        # Underlying classifiers
        self.nb_predictor = TraditionalPredictor('nb')
        self.bert_classifier = BertClassifier()

        # Category code -> display name, used to label rule-based results.
        self.category_names = CATEGORY_MAP

        # Tunable thresholds
        self.config = {
            'confidence_threshold': 0.75,    # high-confidence threshold
            'hybrid_min_confidence': 0.60,   # minimum threshold for hybrid mode
            'use_bert_threshold': 0.70,      # threshold for using BERT
            'rule_priority': True            # rules take precedence
        }

        # Keyword lists per category for rule matching
        self.rule_keywords = {
            'POLITICS': ['政府', '政策', '选举', '国务院', '主席', '总理'],
            'FINANCE': ['股市', '经济', '金融', '投资', '基金', '银行'],
            'TECHNOLOGY': ['芯片', 'AI', '人工智能', '5G', '互联网', '科技'],
            'SPORTS': ['比赛', '冠军', '联赛', '球员', '教练', 'NBA'],
            'ENTERTAINMENT': ['明星', '电影', '电视剧', '娱乐圈', '歌手'],
            'HEALTH': ['健康', '医疗', '疾病', '治疗', '疫苗'],
            'EDUCATION': ['教育', '学校', '大学', '考试', '招生'],
            'LIFE': ['生活', '美食', '旅游', '购物'],
            'INTERNATIONAL': ['国际', '美国', '欧洲', '日本', '外交'],
            'MILITARY': ['军事', '武器', '军队', '国防', '战争']
        }

    # The return annotation is a string so it is not evaluated when the class
    # is defined; ``str | None`` would otherwise fail on Python < 3.10.
    def rule_match(self, title: str, content: str) -> "tuple[str | None, float]":
        """
        Keyword rule matching.

        Counts, per category, how many of its keywords occur in the text and
        picks the category with the most hits.

        :return: ``(category_code, confidence)``; ``(None, 0.0)`` when no
            keyword matches. Confidence is 0.15 per matched keyword, capped at 0.9.
        """
        text = title + ' ' + content

        matches = {}
        for category, keywords in self.rule_keywords.items():
            count = sum(1 for kw in keywords if kw in text)
            if count > 0:
                matches[category] = count

        if not matches:
            return None, 0.0

        best_category = max(matches, key=matches.get)
        confidence = min(0.9, matches[best_category] * 0.15)

        return best_category, confidence

    def predict(self, title: str, content: str, use_bert=True) -> Dict[str, Any]:
        """
        Hybrid prediction.

        :param use_bert: when the traditional model is not confident, mark the
            result as a hybrid decision (BERT itself is not wired in yet).
        :return: result dict with ``categoryCode``, ``categoryName``,
            ``confidence``, ``classifierType``, ``reason`` and ``duration``
            (milliseconds); model-based results also carry ``probabilities``.
        """
        # perf_counter is monotonic, so the measured duration cannot go
        # negative if the wall clock is adjusted mid-request.
        start_time = time.perf_counter()

        # 1. Keyword rules first.
        rule_category, rule_confidence = self.rule_match(title, content)
        threshold = self.config['confidence_threshold']

        if self.config['rule_priority'] and rule_confidence >= threshold:
            # The display name must describe the rule's category, not whatever
            # the traditional model would have predicted.
            final_result = {
                'categoryCode': rule_category,
                'categoryName': self.category_names.get(rule_category, '未知'),
                'confidence': rule_confidence,
                'classifierType': 'RULE',
                'reason': '规则匹配'
            }
        else:
            # 2. Traditional model, only needed when the rules did not decide.
            nb_result = self.nb_predictor.predict(title, content)

            if nb_result['confidence'] >= threshold:
                final_result = {
                    **nb_result,
                    'classifierType': 'ML',
                    'reason': '传统模型高置信度'
                }
            elif use_bert:
                # TODO: run the BERT model here and pick the more confident
                # of the two results:
                # bert_result = self.bert_classifier.predict(title, content)
                final_result = {
                    **nb_result,
                    'classifierType': 'HYBRID',
                    'reason': '混合决策'
                }
            else:
                # BERT disabled: return the traditional result as-is.
                final_result = {
                    **nb_result,
                    'classifierType': 'ML',
                    'reason': '默认传统模型'
                }

        final_result['duration'] = int((time.perf_counter() - start_time) * 1000)

        return final_result


if __name__ == '__main__':
    # 测试
    classifier = HybridClassifier()

    test_cases = [
        {
            'title': '国务院发布最新经济政策',
            'content': '国务院今天发布了新的经济政策...'
        },
        {
            'title': '华为发布新款折叠屏手机',
            'content': '华为今天正式发布了新一代折叠屏手机...'
        }
    ]

    for case in test_cases:
        result = classifier.predict(case['title'], case['content'])
        print(f"标题: {case['title']}")
        print(f"结果: {result['categoryName']} ({result['confidence']:.2f})")
        print(f"分类器: {result['classifierType']}")
        print(f"原因: {result.get('reason', 'N/A')}")
        print(f"耗时: {result['duration']}ms")
        print("-" * 50)

任务 4.1.4: src/api/server.py - FastAPI服务

"""
机器学习模型API服务
使用FastAPI提供RESTful API
"""

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional
import logging

# Import the classifiers through the `src` package.
# hybrid_classifier uses package-relative imports (`from ..traditional ...`),
# which only resolve when `src` is the top-level package. Putting `src/`
# itself on sys.path would make `hybrid` a top-level package and that
# relative import would fail, so add the project root (the parent of `src/`).
import sys
import os

_PROJECT_ROOT = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
if _PROJECT_ROOT not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT)

from src.traditional.predict import TraditionalPredictor
from src.hybrid.hybrid_classifier import HybridClassifier

# 配置日志
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# 创建FastAPI应用
app = FastAPI(
    title="新闻分类API",
    description="提供新闻文本分类服务",
    version="1.0.0"
)

# Configure CORS.
# Credentials must not be combined with a wildcard origin: that would let any
# website make credentialed requests to this service. Nothing in this API
# relies on cookies, so credentials stay off; to enable them, replace "*"
# with an explicit list of allowed origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Request body for the classification endpoints.
class ClassifyRequest(BaseModel):
    title: str
    content: str
    # Only 'traditional' is special-cased by the handlers in this file; any
    # other value takes the hybrid path.
    mode: Optional[str] = 'hybrid'  # traditional, hybrid


# Response body of /api/predict. Every field except `probabilities` is
# required, so the result dict passed to this model must contain all of them.
class ClassifyResponse(BaseModel):
    categoryCode: str
    categoryName: str
    confidence: float
    classifierType: str
    duration: int
    probabilities: Optional[dict] = None


# Classifier singletons. They start as None and are assigned by startup_event;
# a classifier whose loading fails there stays None.
nb_predictor = None
hybrid_classifier = None


@app.on_event("startup")
async def startup_event():
    """Load the classifiers once when the application starts.

    Each classifier is loaded independently: a failure is logged and leaves
    the corresponding module-level global as None, so the service still
    starts and the /health endpoint can report which models are available.
    """
    global nb_predictor, hybrid_classifier

    logger.info("加载模型...")

    try:
        nb_predictor = TraditionalPredictor('nb')
        logger.info("朴素贝叶斯模型加载成功")
    except Exception as e:
        # logger.exception keeps the traceback, which is needed to diagnose
        # why a model file or dependency could not be loaded.
        logger.exception(f"朴素贝叶斯模型加载失败: {e}")

    try:
        hybrid_classifier = HybridClassifier()
        logger.info("混合分类器初始化成功")
    except Exception as e:
        logger.exception(f"混合分类器初始化失败: {e}")


@app.get("/")
async def root():
    """Liveness probe: report that the service is up."""
    payload = dict(
        status="ok",
        message="新闻分类API服务运行中",
    )
    return payload


@app.get("/health")
async def health_check():
    """Health probe: report service status and whether each model is loaded."""
    # A classifier counts as loaded when its module-level global is not None.
    model_state = {
        flag: instance is not None
        for flag, instance in (
            ("nb_loaded", nb_predictor),
            ("hybrid_loaded", hybrid_classifier),
        )
    }
    return {"status": "healthy", "models": model_state}


@app.post("/api/predict", response_model=ClassifyResponse)
async def predict(request: ClassifyRequest):
    """
    Classify a single news article.

    - **title**: news title
    - **content**: news body
    - **mode**: classification mode (traditional, hybrid)

    Responds with 503 when the requested classifier is not loaded and with
    500 when the classifier itself raises.
    """
    # Only 'traditional' is special-cased; every other mode uses the hybrid classifier.
    use_traditional = request.mode == 'traditional'
    model = nb_predictor if use_traditional else hybrid_classifier

    # The global stays None when loading failed at startup. Report that
    # explicitly instead of surfacing an AttributeError as a 500. This check
    # sits outside the try block so the 503 is not rewritten to a 500 below.
    if model is None:
        raise HTTPException(status_code=503, detail="分类模型尚未加载")

    try:
        result = model.predict(request.title, request.content)
        if use_traditional:
            # Build a new dict rather than mutating the predictor's return value.
            result = {**result, 'classifierType': 'ML'}

        return ClassifyResponse(**result)

    except Exception as e:
        logger.error(f"预测失败: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/api/batch-predict")
async def batch_predict(requests: list[ClassifyRequest]):
    """
    Classify a batch of news articles.

    Items are processed independently and in order. A failing item does not
    abort the batch: its slot in the result list holds a dict with 'error'
    and 'title' instead of a classification result.

    Returns a dict of the form {"results": [...]}.
    """
    results = []
    for req in requests:
        # Only 'traditional' is special-cased; every other mode uses the hybrid classifier.
        use_traditional = req.mode == 'traditional'
        model = nb_predictor if use_traditional else hybrid_classifier
        try:
            if model is None:
                # The global stays None when loading failed at startup; give a
                # clear per-item error instead of an opaque AttributeError text.
                raise RuntimeError("分类模型尚未加载")

            result = model.predict(req.title, req.content)
            if use_traditional:
                # Build a new dict rather than mutating the predictor's return value.
                result = {**result, 'classifierType': 'ML'}
            results.append(result)
        except Exception as e:
            results.append({
                'error': str(e),
                'title': req.title
            })

    return {"results": results}


if __name__ == '__main__':
    # Run the API with uvicorn when this file is executed directly.
    import uvicorn

    server_options = {
        "host": "0.0.0.0",
        "port": 5000,
        "log_level": "info",
    }
    uvicorn.run(app, **server_options)

任务 4.1.5: src/utils/metrics.py - 评估指标

"""
模型评估指标工具
"""

import numpy as np
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    confusion_matrix,
    classification_report
)
from typing import List, Dict, Any
import matplotlib.pyplot as plt
import seaborn as sns


class ClassificationMetrics:
    """Static helpers for evaluating a multi-class classifier.

    Every method takes the class names in ``labels`` and passes them to
    scikit-learn explicitly, so rows, columns and per-class entries follow the
    caller's order. Without that, scikit-learn orders classes by sorted unique
    value and the results would be attached to the wrong names.
    """

    @staticmethod
    def compute_all(y_true: List, y_pred: List, labels: List[str]) -> Dict[str, Any]:
        """
        Compute overall and per-class metrics.

        Args:
            y_true: ground-truth labels.
            y_pred: predicted labels, aligned with ``y_true``.
            labels: class names to report, in the desired order.

        Returns:
            A dict with 'accuracy' and weighted-average 'precision', 'recall'
            and 'f1' (all floats), plus 'per_class', which maps each entry of
            ``labels`` to its 'precision', 'recall', 'f1' and 'support'.
        """
        accuracy = accuracy_score(y_true, y_pred)

        # With average='weighted' the support slot is unused, hence the throwaway name.
        precision, recall, f1, _ = precision_recall_fscore_support(
            y_true, y_pred, average='weighted', zero_division=0
        )

        # labels=labels fixes the order of the returned arrays to match `labels`.
        class_precision, class_recall, class_f1, class_support = \
            precision_recall_fscore_support(
                y_true, y_pred, labels=labels, average=None, zero_division=0
            )

        per_class_metrics = {
            label: {
                'precision': float(p),
                'recall': float(r),
                'f1': float(f),
                'support': int(s)
            }
            for label, p, r, f, s in zip(
                labels, class_precision, class_recall, class_f1, class_support
            )
        }

        return {
            'accuracy': float(accuracy),
            'precision': float(precision),
            'recall': float(recall),
            'f1': float(f1),
            'per_class': per_class_metrics
        }

    @staticmethod
    def plot_confusion_matrix(y_true: List, y_pred: List, labels: List[str], save_path: str = None):
        """
        Draw the confusion matrix as a heatmap.

        Args:
            y_true: ground-truth labels.
            y_pred: predicted labels, aligned with ``y_true``.
            labels: class names; they set both the matrix row/column order
                and the tick labels.
            save_path: file path to write the figure to. When falsy, the
                figure is drawn and closed without being saved.
        """
        # labels=labels keeps matrix rows/columns aligned with the tick labels.
        cm = confusion_matrix(y_true, y_pred, labels=labels)

        plt.figure(figsize=(10, 8))
        sns.heatmap(
            cm,
            annot=True,
            fmt='d',
            cmap='Blues',
            xticklabels=labels,
            yticklabels=labels
        )
        plt.xlabel('预测标签')
        plt.ylabel('真实标签')
        plt.title('混淆矩阵')

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        # Always close the figure so it is released after each call.
        plt.close()

    @staticmethod
    def print_report(y_true: List, y_pred: List, labels: List[str]):
        """
        Print scikit-learn's text classification report to stdout.

        Args:
            y_true: ground-truth labels.
            y_pred: predicted labels, aligned with ``y_true``.
            labels: class names; they set both which classes are reported
                and the names shown for them.
        """
        # target_names is positional, so labels= must be passed too for the
        # displayed names to match the classes they describe.
        report = classification_report(
            y_true, y_pred,
            labels=labels,
            target_names=labels,
            zero_division=0
        )
        print(report)


if __name__ == '__main__':
    # Small demo: five samples over three classes, one TECHNOLOGY misprediction.
    labels = ['POLITICS', 'TECHNOLOGY', 'FINANCE']
    y_true = ['POLITICS', 'TECHNOLOGY', 'FINANCE', 'POLITICS', 'TECHNOLOGY']
    y_pred = ['POLITICS', 'TECHNOLOGY', 'FINANCE', 'TECHNOLOGY', 'TECHNOLOGY']

    # compute_all is a static method, so it is called on the class directly.
    print(ClassificationMetrics.compute_all(y_true, y_pred, labels))

任务 4.1.6: requirements.txt - 依赖文件

# 机器学习模块依赖
numpy>=1.24.0
pandas>=2.0.0
scikit-learn>=1.3.0
jieba>=0.42.0
joblib>=1.3.0

# 深度学习
torch>=2.0.0
transformers>=4.30.0

# API服务
fastapi>=0.100.0
uvicorn[standard]>=0.23.0
pydantic>=2.0.0

# 数据可视化
matplotlib>=3.7.0
seaborn>=0.12.0

# 工具
python-dotenv>=1.0.0
pyyaml>=6.0

总结

开发顺序建议

  1. 第一阶段:基础框架

    • 后端:数据库连接、实体类、基础配置
    • 前端:路由配置、状态管理、API封装
  2. 第二阶段:核心功能

    • 爬虫模块 (Python)
    • 传统机器学习分类器
    • 后端API接口
    • 前端新闻列表页面
  3. 第三阶段:高级功能

    • BERT深度学习分类器
    • 混合策略分类器
    • 前端分类器对比页面
    • 统计图表
  4. 第四阶段:完善优化

    • 用户认证
    • 数据可视化
    • 性能优化
    • 异常处理

关键注意事项

  1. 爬虫模块使用 Python,通过 RESTful API 与 Java 后端通信
  2. 分类器模块独立部署,提供 HTTP 接口供后端调用
  3. 前后端分离,使用 JWT 进行身份认证
  4. 数据库表结构已在 schema.sql 中定义,需严格遵守
  5. API 统一响应格式使用 Result<T> 包装