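# News Text Classification System - Module Development Task List
> This document lists the concrete coding tasks each module still needs, following the existing project structure.
> Note: the crawler module is implemented in Python, not Java.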
---
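## Table of Contents
1. [Crawler Module (Python)](#1-crawler-module-python)
2. [Backend Service Module (Spring Boot)](#2-backend-service-module-spring-boot)
3. [Frontend Desktop Module (Tauri + Vue3)](#3-frontend-desktop-module-tauri--vue3)
4. [Machine Learning Classification Module (Python)](#4-machine-learning-classification-module-python)
---
## 1. Crawler Module (Python)
## 2. Backend Service Module (Spring Boot)
### Module Directory Structure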
#### Task 2.1.10: `application.yml` - Application configuration
```yaml
spring:
  application:
    name: news-classifier
  datasource:
    driver-class-name: com.mysql.cj.jdbc.Driver
    url: jdbc:mysql://localhost:3306/news_classifier?useUnicode=true&characterEncoding=UTF-8&serverTimezone=Asia/Shanghai
    username: root
    password: your_password
  data:
    redis:
      host: localhost
      port: 6379
      database: 0

# MyBatis-Plus configuration
mybatis-plus:
  mapper-locations: classpath:mapper/**/*.xml
  type-aliases-package: com.newsclassifier.entity
  configuration:
    map-underscore-to-camel-case: true
    log-impl: org.apache.ibatis.logging.stdout.StdOutImpl
  global-config:
    db-config:
      id-type: auto
      logic-delete-field: deleted
      logic-delete-value: 1
      logic-not-delete-value: 0

# JWT configuration
jwt:
  secret: your-secret-key-at-least-256-bits-long-for-hs256-algorithm
  expiration: 86400000  # 24 h in milliseconds

# Classifier configuration
classifier:
  mode: hybrid  # traditional, deep_learning, hybrid
  confidence:
    threshold: 0.75
    hybrid-min: 0.6
  bert:
    service-url: http://localhost:5000/api/predict
    timeout: 5000

# Logging configuration
logging:
  level:
    com.newsclassifier: debug
  pattern:
    console: "%d{yyyy-MM-dd HH:mm:ss} [%thread] %-5level %logger{36} - %msg%n"
```
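The configuration declares `threshold` and `hybrid-min` but not how they are combined. Below is a minimal sketch of one possible interpretation, written in Python like the ML module. It assumes (this is not stated anywhere in the configuration) that a traditional-model confidence at or above `threshold` is accepted outright, a confidence between `hybrid-min` and `threshold` triggers a second opinion from the BERT service, and anything lower is flagged for manual classification (the news API has a `manualClassify` call). `ConfidenceConfig`, `choose_route` and the route labels are hypothetical names.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ConfidenceConfig:
    """Mirrors classifier.confidence.* in application.yml."""
    threshold: float = 0.75   # classifier.confidence.threshold
    hybrid_min: float = 0.6   # classifier.confidence.hybrid-min


def choose_route(ml_confidence: float, cfg: ConfidenceConfig = ConfidenceConfig()) -> str:
    """Hypothetical routing rule; an assumption, not the project's documented behavior."""
    if ml_confidence >= cfg.threshold:
        return "ML"        # confident enough: keep the traditional model's answer
    if ml_confidence >= cfg.hybrid_min:
        return "HYBRID"    # borderline: also ask the BERT service and compare
    return "REVIEW"        # too uncertain: e.g. leave it for manual classification


print([choose_route(c) for c in (0.9, 0.7, 0.4)])  # ['ML', 'HYBRID', 'REVIEW']
```

Keeping the two numbers in one frozen object makes it easy to load them from the YAML later without touching the decision function.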
---
## 3. Frontend Desktop Module (Tauri + Vue3)
### Module Directory Structure
```
client/src/
├── api/                  # API clients
│   ├── index.ts
│   ├── auth.ts
│   ├── news.ts
│   ├── category.ts
│   └── classifier.ts
├── assets/               # Static assets
│   ├── images/
│   ├── styles/
│   │   ├── main.css
│   │   └── tailwind.css
│   └── fonts/
├── components/           # Components
│   ├── ui/               # Base UI components
│   │   ├── button/
│   │   ├── input/
│   │   ├── dialog/
│   │   └── table/
│   ├── layout/           # Layout components
│   │   ├── Header.vue
│   │   ├── Sidebar.vue
│   │   └── Footer.vue
│   ├── news/             # News components
│   │   ├── NewsCard.vue
│   │   ├── NewsList.vue
│   │   ├── NewsDetail.vue
│   │   └── CategoryFilter.vue
│   └── charts/           # Chart components
│       ├── CategoryChart.vue
│       └── TrendChart.vue
├── composables/          # Composables
│   ├── useAuth.ts
│   ├── useNews.ts
│   ├── useClassifier.ts
│   └── useToast.ts
├── layouts/              # Layouts
│   ├── DefaultLayout.vue
│   ├── AuthLayout.vue
│   └── EmptyLayout.vue
├── router/               # Routing (partially done)
│   └── index.ts
├── stores/               # State management (partially done)
│   ├── user.ts
│   ├── news.ts
│   └── category.ts
├── types/                # TypeScript types
│   ├── api.d.ts
│   ├── news.d.ts
│   └── user.d.ts
├── utils/                # Utilities
│   ├── request.ts
│   ├── storage.ts
│   ├── format.ts
│   └── validate.ts
├── views/                # Pages
│   ├── auth/
│   │   ├── Login.vue
│   │   └── Register.vue
│   ├── news/
│   │   ├── NewsList.vue
│   │   ├── NewsDetail.vue
│   │   └── NewsSearch.vue
│   ├── category/
│   │   ├── CategoryManage.vue
│   │   └── CategoryStats.vue
│   ├── classifier/
│   │   ├── ClassifierPage.vue
│   │   └── ModelCompare.vue
│   └── admin/
│       ├── Dashboard.vue
│       ├── UserManage.vue
│       └── SystemLog.vue
├── App.vue
└── main.ts
```
### 3.1 Files to Implement
#### Task 3.1.1: `utils/request.ts` - HTTP request wrapper
```typescript
import axios, {
  type AxiosInstance,
  type AxiosRequestConfig,
  type AxiosResponse
} from 'axios'

// Unified response envelope returned by the backend
interface ApiResponse<T = any> {
  code: number
  message: string
  data: T
}

// Create the axios instance
const service: AxiosInstance = axios.create({
  baseURL: import.meta.env.VITE_API_BASE_URL || 'http://localhost:8080/api',
  timeout: 15000,
  headers: {
    'Content-Type': 'application/json'
  }
})

// Request interceptor: attach the JWT if one is stored
service.interceptors.request.use(
  (config) => {
    const token = localStorage.getItem('token')
    if (token) {
      config.headers.Authorization = `Bearer ${token}`
    }
    return config
  },
  (error) => Promise.reject(error)
)

// Response interceptor: unwrap `data` when code === 200, otherwise reject
service.interceptors.response.use(
  (response: AxiosResponse<ApiResponse>) => {
    const { code, message, data } = response.data
    if (code === 200) {
      return data
    }
    // Business-level error reported inside the envelope
    return Promise.reject(new Error(message || '请求失败'))
  },
  (error) => {
    // Transport-level error (network failure, non-2xx HTTP status)
    return Promise.reject(error)
  }
)

// Typed request helpers. The response interceptor already returns `data`,
// so the resolved value (second type argument) is T rather than AxiosResponse<T>.
export const http = {
  get<T = any>(url: string, config?: AxiosRequestConfig): Promise<T> {
    return service.get<T, T>(url, config)
  },
  post<T = any>(url: string, data?: any, config?: AxiosRequestConfig): Promise<T> {
    return service.post<T, T>(url, data, config)
  },
  put<T = any>(url: string, data?: any, config?: AxiosRequestConfig): Promise<T> {
    return service.put<T, T>(url, data, config)
  },
  delete<T = any>(url: string, config?: AxiosRequestConfig): Promise<T> {
    return service.delete<T, T>(url, config)
  }
}

export default service
```
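The response interceptor implements a small envelope contract: the backend always answers `{code, message, data}` (the `Result<T>` wrapper), and the client either unwraps `data` or turns a non-200 `code` into an error. The same contract sketched in Python so it can be checked in isolation; `ApiError` and `unwrap` are hypothetical names, not part of the project.

```python
from typing import Any, Mapping


class ApiError(Exception):
    """Business-level failure reported inside the envelope (code != 200)."""

    def __init__(self, code: Any, message: str):
        super().__init__(message or "请求失败")  # same fallback text as request.ts
        self.code = code


def unwrap(payload: Mapping[str, Any]) -> Any:
    """Return payload['data'] on success, otherwise raise ApiError (mirrors the interceptor)."""
    code = payload.get("code")
    if code == 200:
        return payload.get("data")
    raise ApiError(code, payload.get("message", ""))


print(unwrap({"code": 200, "message": "ok", "data": {"id": 1}}))  # {'id': 1}
```

Because `data` is returned directly, callers see two kinds of rejection: an envelope error raised here, and HTTP-level errors (network failures, 401/500 statuses) that arrive through the interceptor's second callback.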
#### Task 3.1.2: `api/news.ts` - News API
```typescript
// request.ts lives in utils/, not api/
import { http } from '@/utils/request'

// Query parameters for the news list
export interface NewsQueryParams {
  page?: number
  size?: number
  categoryId?: number
  categoryCode?: string
  keyword?: string
  status?: number
}

// News detail
export interface NewsDetail {
  id: number
  title: string
  content: string
  summary: string
  source: string
  sourceUrl: string
  author: string
  categoryId: number
  categoryCode: string
  coverImage: string
  publishTime: string
  viewCount: number
  likeCount: number
  commentCount: number
  classifierType: string
  confidence: number
}

// Paged response
export interface PageResponse<T> {
  total: number
  records: T[]
  current: number
  size: number
}

// News API
export const newsApi = {
  // Paged query
  getNewsPage(params: NewsQueryParams): Promise<PageResponse<NewsDetail>> {
    return http.get<PageResponse<NewsDetail>>('/news/page', { params })
  },
  // Fetch details
  getNewsDetail(id: number): Promise<NewsDetail> {
    return http.get<NewsDetail>(`/news/${id}`)
  },
  // Keyword search
  searchNews(keyword: string, page = 1, size = 20): Promise<PageResponse<NewsDetail>> {
    return http.get<PageResponse<NewsDetail>>('/news/search', { params: { keyword, page, size } })
  },
  // Manually assign a category
  manualClassify(id: number, categoryId: number): Promise<void> {
    return http.post<void>(`/news/${id}/classify`, null, { params: { categoryId } })
  }
}
```
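`NewsList.vue` passes `categoryCode: selectedCategory.value || undefined`; this works because axios omits query parameters whose value is `undefined` or `null`. The Python sketch below builds the resulting `GET /news/page` URL the same way, using `None` for "not set". `build_news_page_url` is a hypothetical helper, and the base URL is the default `baseURL` from `request.ts`.

```python
from typing import Optional
from urllib.parse import urlencode

BASE_URL = "http://localhost:8080/api"  # default baseURL in utils/request.ts


def build_news_page_url(page: int = 1, size: int = 20,
                        category_code: Optional[str] = None,
                        keyword: Optional[str] = None) -> str:
    """Build the /news/page URL, dropping unset parameters as axios does."""
    params = {"page": page, "size": size, "categoryCode": category_code, "keyword": keyword}
    query = urlencode({k: v for k, v in params.items() if v is not None})
    return f"{BASE_URL}/news/page?{query}"


print(build_news_page_url(2, 10, "SPORTS"))
# http://localhost:8080/api/news/page?page=2&size=10&categoryCode=SPORTS
```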
#### Task 3.1.3: `composables/useNews.ts` - News composable
```typescript
import { ref, computed } from 'vue'
import { newsApi, type NewsQueryParams, type NewsDetail, type PageResponse } from '@/api/news'

export function useNews() {
  const loading = ref(false)
  const newsList = ref<NewsDetail[]>([])
  const total = ref(0)
  const currentNews = ref<NewsDetail | null>(null)

  // Paged query
  const fetchNewsPage = async (params: NewsQueryParams) => {
    loading.value = true
    try {
      const result: PageResponse<NewsDetail> = await newsApi.getNewsPage(params)
      newsList.value = result.records
      total.value = result.total
    } catch (error) {
      console.error('Failed to fetch the news list:', error)
      throw error
    } finally {
      loading.value = false
    }
  }

  // Fetch details
  const fetchNewsDetail = async (id: number) => {
    loading.value = true
    try {
      currentNews.value = await newsApi.getNewsDetail(id)
      return currentNews.value
    } catch (error) {
      console.error('Failed to fetch news details:', error)
      throw error
    } finally {
      loading.value = false
    }
  }

  // Keyword search
  const searchNews = async (keyword: string, page = 1, size = 20) => {
    loading.value = true
    try {
      const result: PageResponse<NewsDetail> = await newsApi.searchNews(keyword, page, size)
      newsList.value = result.records
      total.value = result.total
    } catch (error) {
      console.error('News search failed:', error)
      throw error
    } finally {
      loading.value = false
    }
  }

  // Expose state read-only; it changes only through the functions above
  return {
    loading: computed(() => loading.value),
    newsList: computed(() => newsList.value),
    total: computed(() => total.value),
    currentNews: computed(() => currentNews.value),
    fetchNewsPage,
    fetchNewsDetail,
    searchNews
  }
}
```
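`fetchNewsPage` and `searchNews` both write to the same `newsList`, so when a user clicks two categories quickly, a slow first response can overwrite the newer one. A common remedy is to number each request and publish only the result of the newest. A minimal sketch of that pattern with Python's `asyncio`; the class `LatestOnly` is hypothetical, and in `useNews.ts` the same idea would be a counter captured per call (or cancelling the older request with an `AbortController`).

```python
import asyncio
from typing import Any, Awaitable, Callable


class LatestOnly:
    """Keep only the result of the most recently started request."""

    def __init__(self) -> None:
        self._seq = 0
        self.value: Any = None

    async def run(self, fetch: Callable[[], Awaitable[Any]]) -> Any:
        self._seq += 1
        my_seq = self._seq          # ticket for this call
        result = await fetch()
        if my_seq == self._seq:     # no newer request started meanwhile
            self.value = result     # safe to publish
        return result


async def demo() -> Any:
    async def fetch(label: str, delay: float) -> str:
        await asyncio.sleep(delay)
        return label

    state = LatestOnly()
    # The "old" request starts first but is slower, so it finishes last
    await asyncio.gather(
        state.run(lambda: fetch("old", 0.05)),
        state.run(lambda: fetch("new", 0.01)),
    )
    return state.value


print(asyncio.run(demo()))  # new
```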
#### Task 3.1.4: `views/news/NewsList.vue` - News list page
```vue
<template>
  <div class="news-list-container">
    <!-- Category filter -->
    <div class="category-filter">
      <button
        v-for="cat in categories"
        :key="cat.code"
        :class="{ active: selectedCategory === cat.code }"
        @click="selectCategory(cat.code)"
      >
        {{ cat.name }}
      </button>
    </div>

    <!-- Search box -->
    <div class="search-bar">
      <input
        v-model="searchKeyword"
        type="text"
        placeholder="搜索新闻..."
        @keyup.enter="handleSearch"
      />
      <button @click="handleSearch">搜索</button>
    </div>

    <!-- News list -->
    <div v-if="!loading && newsList.length > 0" class="news-list">
      <div v-for="news in newsList" :key="news.id" class="news-card" @click="viewDetail(news.id)">
        <h3>{{ news.title }}</h3>
        <p class="summary">{{ news.summary }}</p>
        <div class="meta">
          <span class="category">{{ getCategoryName(news.categoryCode) }}</span>
          <span class="source">{{ news.source }}</span>
          <span class="time">{{ formatDate(news.publishTime) }}</span>
        </div>
        <div v-if="news.classifierType" class="classifier-info">
          <span class="badge">{{ news.classifierType }}</span>
          <span class="confidence">{{ (news.confidence * 100).toFixed(1) }}%</span>
        </div>
      </div>
    </div>

    <!-- Loading -->
    <div v-if="loading" class="loading">加载中...</div>

    <!-- Empty state -->
    <div v-if="!loading && newsList.length === 0" class="empty">暂无数据</div>

    <!-- Pagination -->
    <div v-if="total > 0" class="pagination">
      <button :disabled="currentPage <= 1" @click="changePage(currentPage - 1)">上一页</button>
      <span>{{ currentPage }} / {{ totalPages }}</span>
      <button :disabled="currentPage >= totalPages" @click="changePage(currentPage + 1)">下一页</button>
    </div>
  </div>
</template>
<script setup lang="ts">
import { ref, computed, onMounted } from 'vue'
import { useRouter } from 'vue-router'
import { useNews } from '@/composables/useNews'

const router = useRouter()
const { loading, newsList, total, fetchNewsPage, searchNews } = useNews()

// Category filter options
const categories = ref([
  { code: '', name: '全部' },
  { code: 'POLITICS', name: '时政' },
  { code: 'FINANCE', name: '财经' },
  { code: 'TECHNOLOGY', name: '科技' },
  { code: 'SPORTS', name: '体育' }
])

const selectedCategory = ref('')
const searchKeyword = ref('')
const currentPage = ref(1)
const pageSize = ref(20)
const totalPages = computed(() => Math.ceil(total.value / pageSize.value))

// Load the current page: keyword search if a keyword is set, otherwise the category list.
// Going through one loader keeps pagination working while a search is active.
const loadNews = async () => {
  const keyword = searchKeyword.value.trim()
  try {
    if (keyword) {
      await searchNews(keyword, currentPage.value, pageSize.value)
    } else {
      await fetchNewsPage({
        page: currentPage.value,
        size: pageSize.value,
        categoryCode: selectedCategory.value || undefined
      })
    }
  } catch {
    // useNews has already logged the error; keep the current list
  }
}

// Select a category: the search API does not filter by category, so clear the keyword
const selectCategory = (code: string) => {
  selectedCategory.value = code
  searchKeyword.value = ''
  currentPage.value = 1
  loadNews()
}

// Search: a new keyword always starts from page 1
const handleSearch = () => {
  currentPage.value = 1
  loadNews()
}

// Open the detail page
const viewDetail = (id: number) => {
  router.push(`/news/${id}`)
}

// Change page (keeps the active keyword or category)
const changePage = (page: number) => {
  currentPage.value = page
  loadNews()
}

// Format the publish time; fall back to the raw string if it cannot be parsed
const formatDate = (dateStr: string) => {
  const date = new Date(dateStr)
  return Number.isNaN(date.getTime()) ? dateStr : date.toLocaleString('zh-CN', { hour12: false })
}

// Look up the display name of a category code
const getCategoryName = (code: string) => {
  const cat = categories.value.find(c => c.code === code)
  return cat?.name || code
}

onMounted(() => {
  loadNews()
})
</script>
<style scoped>
.news-list-container {
padding: 20px;
}
.category-filter {
display: flex;
gap: 10px;
margin-bottom: 20px;
}
.category-filter button {
padding: 8px 16px;
border: 1px solid #ddd;
border-radius: 4px;
background: white;
cursor: pointer;
}
.category-filter button.active {
background: #1890ff;
color: white;
border-color: #1890ff;
}
.news-card {
padding: 15px;
border: 1px solid #eee;
border-radius: 8px;
margin-bottom: 15px;
cursor: pointer;
transition: box-shadow 0.2s;
}
.news-card:hover {
box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1);
}
.meta {
display: flex;
gap: 15px;
font-size: 12px;
color: #999;
margin-top: 10px;
}
.classifier-info {
display: flex;
gap: 8px;
margin-top: 8px;
}
.badge {
padding: 2px 8px;
background: #f0f0f0;
border-radius: 4px;
font-size: 12px;
}
.pagination {
display: flex;
justify-content: center;
align-items: center;
gap: 15px;
margin-top: 20px;
}
</style>
```
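The pager derives `totalPages` from `total` and `pageSize` with `Math.ceil`, and disables its buttons at `currentPage <= 1` and `currentPage >= totalPages`. The same arithmetic as a small Python sketch, plus the zero-based row offset that a 1-based page number conventionally maps to on the server; `page_info` and `PageInfo` are hypothetical names.

```python
import math
from typing import NamedTuple


class PageInfo(NamedTuple):
    total_pages: int
    offset: int      # index of the first row on this page
    has_prev: bool   # "previous" button enabled
    has_next: bool   # "next" button enabled


def page_info(total: int, page: int, size: int) -> PageInfo:
    """Pagination state for a 1-based `page`."""
    total_pages = math.ceil(total / size) if total > 0 else 0
    return PageInfo(
        total_pages=total_pages,
        offset=(page - 1) * size,
        has_prev=page > 1,
        has_next=page < total_pages,
    )


print(page_info(45, 3, 20))  # PageInfo(total_pages=3, offset=40, has_prev=True, has_next=False)
```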
#### Task 3.1.5: `views/classifier/ClassifierPage.vue` - Classifier page
```vue
<template>
  <div class="classifier-page">
    <div class="page-header">
      <h2>文本分类</h2>
    </div>

    <!-- Classification mode (the ML service in task 4.1.4 handles "traditional"
         explicitly and treats any other mode as "hybrid") -->
    <div class="mode-selector">
      <label>分类模式:</label>
      <select v-model="selectedMode">
        <option value="traditional">传统机器学习 (TF-IDF + NB)</option>
        <option value="deep_learning">深度学习 (BERT)</option>
        <option value="hybrid">混合模式</option>
      </select>
    </div>

    <!-- Input area -->
    <div class="input-area">
      <div class="form-group">
        <label>新闻标题</label>
        <input v-model="formData.title" type="text" placeholder="请输入新闻标题" />
      </div>
      <div class="form-group">
        <label>新闻内容</label>
        <textarea v-model="formData.content" placeholder="请输入新闻内容" rows="10"></textarea>
      </div>
      <button @click="handleClassify" :disabled="classifying">
        {{ classifying ? '分类中...' : '开始分类' }}
      </button>
    </div>

    <!-- Classification result -->
    <div v-if="result" class="result-area">
      <h3>分类结果</h3>
      <div class="result-item">
        <span class="label">分类:</span>
        <span class="value">{{ result.categoryName }} ({{ result.categoryCode }})</span>
      </div>
      <div class="result-item">
        <span class="label">置信度:</span>
        <span class="value">{{ (result.confidence * 100).toFixed(2) }}%</span>
      </div>
      <div class="result-item">
        <span class="label">分类器:</span>
        <span class="value">{{ result.classifierType }}</span>
      </div>
      <div class="result-item">
        <span class="label">耗时:</span>
        <span class="value">{{ result.duration }}ms</span>
      </div>

      <!-- Per-category probability bars -->
      <div v-if="result.probabilities" class="probabilities">
        <h4>各类别概率</h4>
        <div v-for="(prob, code) in result.probabilities" :key="code" class="prob-bar">
          <span class="cat-name">{{ getCategoryName(code) }}</span>
          <div class="bar-container">
            <div class="bar" :style="{ width: (prob * 100) + '%' }"></div>
          </div>
          <span class="prob-value">{{ (prob * 100).toFixed(1) }}%</span>
        </div>
      </div>
    </div>
  </div>
</template>
<script setup lang="ts">
import { ref } from 'vue'
import { classifierApi } from '@/api/classifier'

interface ClassificationResult {
  categoryCode: string
  categoryName: string
  confidence: number
  classifierType: string
  duration: number
  probabilities?: Record<string, number>
}

const formData = ref({
  title: '',
  content: ''
})
const selectedMode = ref('hybrid')
const classifying = ref(false)
const result = ref<ClassificationResult | null>(null)

const handleClassify = async () => {
  if (!formData.value.title.trim() || !formData.value.content.trim()) {
    alert('请输入标题和内容')
    return
  }
  classifying.value = true
  try {
    result.value = await classifierApi.classify({
      title: formData.value.title,
      content: formData.value.content,
      mode: selectedMode.value
    })
  } catch (error) {
    console.error('Classification failed:', error)
    alert('分类失败,请重试')
  } finally {
    classifying.value = false
  }
}

// Display names for all ten codes the ML module can return (same as its CATEGORY_MAP)
const getCategoryName = (code: string) => {
  const map: Record<string, string> = {
    POLITICS: '时政',
    FINANCE: '财经',
    TECHNOLOGY: '科技',
    SPORTS: '体育',
    ENTERTAINMENT: '娱乐',
    HEALTH: '健康',
    EDUCATION: '教育',
    LIFE: '生活',
    INTERNATIONAL: '国际',
    MILITARY: '军事'
  }
  return map[code] || code
}
</script>
<style scoped>
.classifier-page {
padding: 20px;
max-width: 800px;
margin: 0 auto;
}
.mode-selector {
margin-bottom: 20px;
}
.mode-selector select {
padding: 8px;
border-radius: 4px;
border: 1px solid #ddd;
}
.input-area {
background: #f9f9f9;
padding: 20px;
border-radius: 8px;
}
.form-group {
margin-bottom: 15px;
}
.form-group label {
display: block;
margin-bottom: 5px;
font-weight: 500;
}
.form-group input,
.form-group textarea {
width: 100%;
padding: 10px;
border: 1px solid #ddd;
border-radius: 4px;
box-sizing: border-box;
}
.result-area {
margin-top: 20px;
padding: 20px;
background: #f0f7ff;
border-radius: 8px;
}
.result-item {
display: flex;
padding: 8px 0;
}
.result-item .label {
width: 80px;
font-weight: 500;
}
.probabilities {
margin-top: 20px;
}
.prob-bar {
display: flex;
align-items: center;
margin-bottom: 10px;
}
.cat-name {
width: 80px;
}
.bar-container {
flex: 1;
height: 20px;
background: #e0e0e0;
border-radius: 4px;
overflow: hidden;
margin: 0 10px;
}
.bar {
height: 100%;
background: linear-gradient(90deg, #1890ff, #52c41a);
transition: width 0.3s;
}
.prob-value {
width: 60px;
text-align: right;
}
</style>
```
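The probability bars are rendered in whatever order the service returns the keys (for the traditional model, the order of the classifier's classes). With ten categories the chart is usually easier to read sorted in descending order and trimmed to the top few. A Python sketch of that transformation; `top_k` is a hypothetical helper, and in the page the same sort would sit in a `computed` property.

```python
from typing import Dict, List, Tuple


def top_k(probabilities: Dict[str, float], k: int = 3) -> List[Tuple[str, float]]:
    """Return the k most probable (category_code, probability) pairs, highest first."""
    ranked = sorted(probabilities.items(), key=lambda item: item[1], reverse=True)
    return ranked[:k]


probs = {"POLITICS": 0.05, "FINANCE": 0.15, "TECHNOLOGY": 0.72, "SPORTS": 0.08}
print(top_k(probs, 2))  # [('TECHNOLOGY', 0.72), ('FINANCE', 0.15)]
```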
#### Task 3.1.6: `router/index.ts` - Route configuration (updated)
```typescript
import { createRouter, createWebHistory } from 'vue-router'
import type { RouteRecordRaw } from 'vue-router'

const routes: RouteRecordRaw[] = [
  {
    path: '/login',
    name: 'Login',
    component: () => import('@/views/auth/Login.vue'),
    meta: { layout: 'EmptyLayout' }
  },
  {
    path: '/',
    name: 'Home',
    component: () => import('@/views/news/NewsList.vue'),
    meta: { requiresAuth: true }
  },
  {
    path: '/news',
    name: 'NewsList',
    component: () => import('@/views/news/NewsList.vue'),
    meta: { requiresAuth: true }
  },
  {
    path: '/news/:id',
    name: 'NewsDetail',
    component: () => import('@/views/news/NewsDetail.vue'),
    meta: { requiresAuth: true }
  },
  {
    path: '/classifier',
    name: 'Classifier',
    component: () => import('@/views/classifier/ClassifierPage.vue'),
    meta: { requiresAuth: true }
  },
  {
    path: '/category',
    name: 'CategoryStats',
    component: () => import('@/views/category/CategoryStats.vue'),
    meta: { requiresAuth: true }
  },
  {
    path: '/admin',
    name: 'AdminDashboard',
    component: () => import('@/views/admin/Dashboard.vue'),
    meta: { requiresAuth: true, requiresAdmin: true }
  }
]

const router = createRouter({
  history: createWebHistory(),
  routes
})

// Navigation guard: returning a location redirects, returning true continues
router.beforeEach((to) => {
  const token = localStorage.getItem('token')
  if (to.meta.requiresAuth && !token) {
    return '/login'
  }
  if (to.meta.requiresAdmin) {
    // Client-side check for UX only; the backend must still enforce admin permissions
    const userRole = localStorage.getItem('userRole')
    if (userRole !== 'ADMIN') {
      return '/'
    }
  }
  return true
})

export default router
```
---
## 4. Machine Learning Classification Module (Python)
### Module Directory Structure
```
ml-module/
├── data/
│   ├── raw/                      # Raw data
│   ├── processed/                # Processed data
│   │   ├── training_data.csv
│   │   └── test_data.csv
│   └── external/                 # External datasets
├── models/                       # Trained models
│   ├── traditional/
│   │   ├── nb_vectorizer.pkl
│   │   ├── nb_classifier.pkl
│   │   ├── svm_vectorizer.pkl
│   │   └── svm_classifier.pkl
│   ├── deep_learning/
│   │   └── bert_finetuned/
│   └── hybrid/
│       └── config.json
├── src/
│   ├── __init__.py
│   ├── traditional/              # Traditional machine learning
│   │   ├── __init__.py
│   │   ├── train_model.py        # (existing)
│   │   ├── predict.py
│   │   └── evaluate.py
│   ├── deep_learning/            # Deep learning
│   │   ├── __init__.py
│   │   ├── bert_model.py
│   │   ├── train_bert.py
│   │   └── predict_bert.py
│   ├── hybrid/                   # Hybrid strategy
│   │   ├── __init__.py
│   │   ├── hybrid_classifier.py
│   │   └── rule_engine.py
│   ├── utils/
│   │   ├── __init__.py
│   │   ├── preprocessing.py      # Data preprocessing
│   │   └── metrics.py            # Evaluation metrics
│   └── api/                      # API service
│       ├── __init__.py
│       └── server.py             # FastAPI service
├── notebooks/                    # Jupyter notebooks
│   ├── data_exploration.ipynb
│   └── model_comparison.ipynb
├── tests/                        # Tests
│   ├── test_traditional.py
│   ├── test_bert.py
│   └── test_hybrid.py
├── requirements.txt
├── setup.py
└── README.md
```
### 4.1 Files to Implement
#### Task 4.1.1: `src/traditional/predict.py` - Traditional model prediction
```python
"""
Prediction with the traditional machine-learning models (TF-IDF + NB/SVM)
"""
import os
from typing import Any, Dict

import jieba
import joblib

# Category code -> display name
CATEGORY_MAP = {
    'POLITICS': '时政',
    'FINANCE': '财经',
    'TECHNOLOGY': '科技',
    'SPORTS': '体育',
    'ENTERTAINMENT': '娱乐',
    'HEALTH': '健康',
    'EDUCATION': '教育',
    'LIFE': '生活',
    'INTERNATIONAL': '国际',
    'MILITARY': '军事'
}

# ml-module/models/traditional, resolved relative to this file instead of the working directory
DEFAULT_MODEL_DIR = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), '..', '..', 'models', 'traditional'
)


class TraditionalPredictor:
    """Predictor for the traditional ML models"""

    def __init__(self, model_type: str = 'nb', model_dir: str = DEFAULT_MODEL_DIR):
        self.model_type = model_type
        self.model_dir = model_dir
        self.vectorizer = None
        self.classifier = None
        self._load_model()

    def _load_model(self):
        """Load the fitted vectorizer and classifier"""
        vectorizer_path = os.path.join(self.model_dir, f'{self.model_type}_vectorizer.pkl')
        classifier_path = os.path.join(self.model_dir, f'{self.model_type}_classifier.pkl')
        self.vectorizer = joblib.load(vectorizer_path)
        self.classifier = joblib.load(classifier_path)
        print(f"Model loaded: {self.model_type}")

    def preprocess(self, title: str, content: str) -> str:
        """Segment with jieba; must match the preprocessing used in train_model.py"""
        text = title + ' ' + content
        words = jieba.cut(text)
        return ' '.join(words)

    def predict(self, title: str, content: str) -> Dict[str, Any]:
        """
        Classify one news item.
        :return: dict with categoryCode, categoryName, confidence and probabilities
        """
        processed = self.preprocess(title, content)
        # Feature extraction
        tfidf = self.vectorizer.transform([processed])
        # Class probabilities; the classifier must support predict_proba
        # (MultinomialNB does; an SVM needs e.g. SVC(probability=True) or CalibratedClassifierCV)
        probabilities = self.classifier.predict_proba(tfidf)[0]
        classes = self.classifier.classes_
        # Take the label from the same vector so it always agrees with `confidence`
        best = int(probabilities.argmax())
        prediction = str(classes[best])

        prob_dict = {str(code): float(prob) for code, prob in zip(classes, probabilities)}

        return {
            'categoryCode': prediction,
            'categoryName': CATEGORY_MAP.get(prediction, '未知'),
            'confidence': float(probabilities[best]),
            'probabilities': prob_dict
        }


# API entry point
def predict_single(title: str, content: str, model_type: str = 'nb') -> Dict[str, Any]:
    """
    Single-item prediction. This reloads the model on every call;
    a long-running service should create one TraditionalPredictor and reuse it.
    """
    predictor = TraditionalPredictor(model_type)
    return predictor.predict(title, content)


if __name__ == '__main__':
    # Smoke test
    result = predict_single(
        title="华为发布新款折叠屏手机",
        content="华为今天正式发布了新一代折叠屏手机,搭载最新麒麟芯片..."
    )
    print(result)
```
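`predict.py` joins jieba tokens with spaces and passes the string to the saved vectorizer, so the vectorizer's own tokenization decides which words survive. Assuming `train_model.py` (not shown here) uses scikit-learn's `TfidfVectorizer` with its default `token_pattern`, `r"(?u)\b\w\w+\b"`, every single-character token is discarded, and many Chinese words are a single character. The sketch reproduces both patterns with the standard `re` module only; the sample sentence is illustrative.

```python
import re
from typing import List

# Default token_pattern of scikit-learn's CountVectorizer / TfidfVectorizer
SKLEARN_DEFAULT_PATTERN = r"(?u)\b\w\w+\b"
# Alternative that also keeps one-character tokens
KEEP_SINGLE_CHAR_PATTERN = r"(?u)\b\w+\b"


def tokens(segmented_text: str, pattern: str) -> List[str]:
    """Tokens the vectorizer would extract from space-joined, pre-segmented text."""
    return re.findall(pattern, segmented_text)


segmented = "华为 发布 新款 手机 与 车"  # e.g. ' '.join(jieba.cut(...))
print(tokens(segmented, SKLEARN_DEFAULT_PATTERN))   # ['华为', '发布', '新款', '手机']
print(tokens(segmented, KEEP_SINGLE_CHAR_PATTERN))  # ['华为', '发布', '新款', '手机', '与', '车']
```

If that assumption holds, passing `token_pattern=r"(?u)\b\w+\b"` to `TfidfVectorizer` at training time keeps those tokens. The fitted vectorizer carries its parameters inside the pickle, so `predict.py` picks up the change after retraining without code changes.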
#### Task 4.1.2: `src/deep_learning/bert_model.py` - BERT model
```python
"""
BERT text classification model
"""
from typing import Any, Dict

import torch
from torch.utils.data import Dataset
from transformers import BertForSequenceClassification, BertTokenizer

# Category code -> display name
CATEGORY_MAP = {
    'POLITICS': '时政',
    'FINANCE': '财经',
    'TECHNOLOGY': '科技',
    'SPORTS': '体育',
    'ENTERTAINMENT': '娱乐',
    'HEALTH': '健康',
    'EDUCATION': '教育',
    'LIFE': '生活',
    'INTERNATIONAL': '国际',
    'MILITARY': '军事'
}

# Label <-> id mappings; the order must match the label ids used during fine-tuning
ID_TO_LABEL = {i: label for i, label in enumerate(CATEGORY_MAP.keys())}
LABEL_TO_ID = {label: i for i, label in enumerate(CATEGORY_MAP.keys())}


class BertClassifier:
    """BERT text classifier"""

    def __init__(self, model_name: str = 'bert-base-chinese', num_labels: int = len(CATEGORY_MAP)):
        self.model_name = model_name
        self.num_labels = num_labels
        self.tokenizer = None
        self.model = None
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def load_model(self, model_path: str):
        """Load a fine-tuned model"""
        self.tokenizer = BertTokenizer.from_pretrained(model_path)
        self.model = BertForSequenceClassification.from_pretrained(
            model_path,
            num_labels=self.num_labels
        )
        self.model.to(self.device)
        self.model.eval()
        print(f"BERT model loaded: {model_path}")

    def predict(self, title: str, content: str) -> Dict[str, Any]:
        """Classify one news item"""
        if self.model is None or self.tokenizer is None:
            raise ValueError("Model not loaded; call load_model() first")

        # Combine title and content; training texts must be built the same way
        text = f"{title} [SEP] {content}"

        # Tokenize (a single example needs no padding)
        inputs = self.tokenizer(
            text,
            return_tensors='pt',
            truncation=True,
            max_length=512
        )

        with torch.no_grad():
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            outputs = self.model(**inputs)
            logits = outputs.logits

        # Predicted class and its probability
        probs = torch.softmax(logits, dim=-1)
        confidence, predicted_id = torch.max(probs, dim=-1)
        predicted_id = predicted_id.item()
        confidence = confidence.item()

        # Probability of every category
        prob_dict = {
            ID_TO_LABEL[i]: float(prob)
            for i, prob in enumerate(probs[0].cpu().tolist())
        }

        label = ID_TO_LABEL[predicted_id]
        return {
            'categoryCode': label,
            'categoryName': CATEGORY_MAP.get(label, '未知'),
            'confidence': confidence,
            'probabilities': prob_dict
        }


class NewsDataset(Dataset):
    """News dataset; `labels` are integer ids from LABEL_TO_ID"""

    def __init__(self, texts, labels, tokenizer, max_length=512):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        encoding = self.tokenizer(
            self.texts[idx],
            truncation=True,
            max_length=self.max_length,
            padding='max_length',
            return_tensors='pt'
        )
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(self.labels[idx], dtype=torch.long)
        }


if __name__ == '__main__':
    # Smoke test: the classifier is only constructed here, not loaded
    classifier = BertClassifier()
    # classifier.load_model('./models/deep_learning/bert_finetuned')
    # result = classifier.predict(
    #     title="华为发布新款折叠屏手机",
    #     content="华为今天正式发布了新一代折叠屏手机..."
    # )
    # print(result)
    print("BERT classifier initialized")
```
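`ID_TO_LABEL` is derived from the insertion order of `CATEGORY_MAP`, so the integer labels passed to `NewsDataset` during fine-tuning must come from the same mapping, or predictions come back under the wrong category. A small sketch of encoding and decoding through one shared mapping; `encode_labels` and `decode_ids` are hypothetical helpers, and only the category codes are copied from `bert_model.py`.

```python
from typing import Dict, List

# Category codes in the same order as CATEGORY_MAP in bert_model.py
CATEGORY_CODES = [
    'POLITICS', 'FINANCE', 'TECHNOLOGY', 'SPORTS', 'ENTERTAINMENT',
    'HEALTH', 'EDUCATION', 'LIFE', 'INTERNATIONAL', 'MILITARY',
]
LABEL_TO_ID: Dict[str, int] = {code: i for i, code in enumerate(CATEGORY_CODES)}
ID_TO_LABEL: Dict[int, str] = {i: code for code, i in LABEL_TO_ID.items()}


def encode_labels(codes: List[str]) -> List[int]:
    """Category codes -> integer labels for NewsDataset; unknown codes fail loudly."""
    unknown = sorted(set(codes) - set(LABEL_TO_ID))
    if unknown:
        raise ValueError(f"unknown category codes: {unknown}")
    return [LABEL_TO_ID[c] for c in codes]


def decode_ids(ids: List[int]) -> List[str]:
    """Integer predictions -> category codes."""
    return [ID_TO_LABEL[i] for i in ids]


labels = encode_labels(['SPORTS', 'POLITICS', 'MILITARY'])
print(labels)              # [3, 0, 9]
print(decode_ids(labels))  # ['SPORTS', 'POLITICS', 'MILITARY']
```

One option worth considering: passing `id2label=ID_TO_LABEL, label2id=LABEL_TO_ID` to `from_pretrained` when starting fine-tuning stores the mapping in the saved model's `config.json`, so it travels with the weights.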
#### Task 4.1.3: `src/hybrid/hybrid_classifier.py` - Hybrid classifier
```python
"""
Hybrid-strategy classifier
Combines a keyword rule engine with the machine-learning models.

Uses package-relative imports, so run it as a module from the ml-module root:
    python -m src.hybrid.hybrid_classifier
"""
import time
from typing import Any, Dict, Optional, Tuple

from ..traditional.predict import CATEGORY_MAP, TraditionalPredictor
from ..deep_learning.bert_model import BertClassifier


class HybridClassifier:
    """Hybrid classifier"""

    def __init__(self):
        # Individual classifiers (BERT is only consulted after bert_classifier.load_model())
        self.nb_predictor = TraditionalPredictor('nb')
        self.bert_classifier = BertClassifier()

        # Configuration
        self.config = {
            'confidence_threshold': 0.75,   # high-confidence threshold
            'hybrid_min_confidence': 0.60,  # minimum threshold in hybrid mode
            'use_bert_threshold': 0.70,     # threshold for using BERT
            'rule_priority': True           # rules take precedence
        }

        # Rule keyword dictionary
        self.rule_keywords = {
            'POLITICS': ['政府', '政策', '选举', '国务院', '主席', '总理'],
            'FINANCE': ['股市', '经济', '金融', '投资', '基金', '银行'],
            'TECHNOLOGY': ['芯片', 'AI', '人工智能', '5G', '互联网', '科技'],
            'SPORTS': ['比赛', '冠军', '联赛', '球员', '教练', 'NBA'],
            'ENTERTAINMENT': ['明星', '电影', '电视剧', '娱乐圈', '歌手'],
            'HEALTH': ['健康', '医疗', '疾病', '治疗', '疫苗'],
            'EDUCATION': ['教育', '学校', '大学', '考试', '招生'],
            'LIFE': ['生活', '美食', '旅游', '购物'],
            'INTERNATIONAL': ['国际', '美国', '欧洲', '日本', '外交'],
            'MILITARY': ['军事', '武器', '军队', '国防', '战争']
        }

    def rule_match(self, title: str, content: str) -> Tuple[Optional[str], float]:
        """
        Keyword rule matching.
        :return: (category_code, confidence)
        """
        text = title + ' ' + content
        # Number of distinct keywords of each category found in the text
        matches = {}
        for category, keywords in self.rule_keywords.items():
            count = sum(1 for kw in keywords if kw in text)
            if count > 0:
                matches[category] = count
        if not matches:
            return None, 0.0
        # Category with the most matches; 0.15 per match, capped at 0.9
        best_category = max(matches, key=matches.get)
        confidence = min(0.9, matches[best_category] * 0.15)
        return best_category, confidence

    def predict(self, title: str, content: str, use_bert: bool = True) -> Dict[str, Any]:
        """Hybrid prediction"""
        start_time = time.perf_counter()

        # 1. Try the keyword rules first
        rule_category, rule_confidence = self.rule_match(title, content)

        # 2. Traditional ML prediction
        nb_result = self.nb_predictor.predict(title, content)
        nb_confidence = nb_result['confidence']

        # Decision logic
        if self.config['rule_priority'] and rule_confidence >= self.config['confidence_threshold']:
            # Rules take priority and are confident enough
            final_result = {
                'categoryCode': rule_category,
                'categoryName': CATEGORY_MAP.get(rule_category, '未知'),  # name of the rule's category
                'confidence': rule_confidence,
                'classifierType': 'RULE',
                'reason': '规则匹配'
            }
        elif nb_confidence >= self.config['confidence_threshold']:
            # The traditional model is confident enough
            final_result = {**nb_result, 'classifierType': 'ML', 'reason': '传统模型高置信度'}
        elif use_bert:
            # Neither is confident: consult BERT if it is loaded and keep the more confident answer
            best = nb_result
            if self.bert_classifier.model is not None:
                bert_result = self.bert_classifier.predict(title, content)
                if bert_result['confidence'] > nb_confidence:
                    best = bert_result
            final_result = {**best, 'classifierType': 'HYBRID', 'reason': '混合决策'}
        else:
            # BERT disabled: return the traditional model's result
            final_result = {**nb_result, 'classifierType': 'ML', 'reason': '默认传统模型'}

        # Elapsed time in ms
        final_result['duration'] = int((time.perf_counter() - start_time) * 1000)
        return final_result


if __name__ == '__main__':
    # Smoke test
    classifier = HybridClassifier()
    test_cases = [
        {'title': '国务院发布最新经济政策', 'content': '国务院今天发布了新的经济政策...'},
        {'title': '华为发布新款折叠屏手机', 'content': '华为今天正式发布了新一代折叠屏手机...'},
    ]
    for case in test_cases:
        result = classifier.predict(case['title'], case['content'])
        print(f"Title: {case['title']}")
        print(f"Result: {result['categoryName']} ({result['confidence']:.2f})")
        print(f"Classifier: {result['classifierType']}")
        print(f"Reason: {result.get('reason', 'N/A')}")
        print(f"Duration: {result['duration']}ms")
        print("-" * 50)
```
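The rule branch only wins when `rule_confidence >= confidence_threshold` (0.75). Each distinct matched keyword adds 0.15, capped at 0.9, so a category needs at least 5 matched keywords to reach 0.75, and a category whose list has fewer than 5 keywords can never take the rule path. The sketch applies the same formula to the keyword lists from `HybridClassifier` to find such categories; `rule_confidence` and `max_rule_confidence` are hypothetical helper names.

```python
from typing import Dict, List

# Keyword lists copied from HybridClassifier.rule_keywords
RULE_KEYWORDS: Dict[str, List[str]] = {
    'POLITICS': ['政府', '政策', '选举', '国务院', '主席', '总理'],
    'FINANCE': ['股市', '经济', '金融', '投资', '基金', '银行'],
    'TECHNOLOGY': ['芯片', 'AI', '人工智能', '5G', '互联网', '科技'],
    'SPORTS': ['比赛', '冠军', '联赛', '球员', '教练', 'NBA'],
    'ENTERTAINMENT': ['明星', '电影', '电视剧', '娱乐圈', '歌手'],
    'HEALTH': ['健康', '医疗', '疾病', '治疗', '疫苗'],
    'EDUCATION': ['教育', '学校', '大学', '考试', '招生'],
    'LIFE': ['生活', '美食', '旅游', '购物'],
    'INTERNATIONAL': ['国际', '美国', '欧洲', '日本', '外交'],
    'MILITARY': ['军事', '武器', '军队', '国防', '战争'],
}
CONFIDENCE_THRESHOLD = 0.75  # config['confidence_threshold']


def rule_confidence(match_count: int) -> float:
    """Same formula as rule_match(): 0.15 per matched keyword, capped at 0.9."""
    return min(0.9, match_count * 0.15)


def max_rule_confidence(keywords: List[str]) -> float:
    """Best case: every keyword of the category appears in the text."""
    return rule_confidence(len(keywords))


unreachable = [cat for cat, kws in RULE_KEYWORDS.items()
               if max_rule_confidence(kws) < CONFIDENCE_THRESHOLD]
print(unreachable)  # ['LIFE']
```

Adding a fifth `LIFE` keyword, or lowering the per-match weight's threshold for rules, would be two ways to make that category eligible; which is appropriate depends on how much weight the rules should carry.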
#### Task 4.1.4: `src/api/server.py` - FastAPI service
```python
"""
Machine-learning model API service
Exposes a RESTful API with FastAPI.
"""
import logging
import os
import sys
import time
from contextlib import asynccontextmanager
from typing import List, Optional

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

# Put the ml-module root on sys.path and import through the `src` package,
# so the relative imports inside src/hybrid/hybrid_classifier.py resolve
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

from src.traditional.predict import TraditionalPredictor  # noqa: E402
from src.hybrid.hybrid_classifier import HybridClassifier  # noqa: E402

# Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Classifiers, loaded at startup
nb_predictor: Optional[TraditionalPredictor] = None
hybrid_classifier: Optional[HybridClassifier] = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the models at startup"""
    global nb_predictor, hybrid_classifier
    logger.info("Loading models...")
    try:
        nb_predictor = TraditionalPredictor('nb')
        logger.info("Naive Bayes model loaded")
    except Exception as e:
        logger.error(f"Failed to load the Naive Bayes model: {e}")
    try:
        hybrid_classifier = HybridClassifier()
        logger.info("Hybrid classifier initialized")
    except Exception as e:
        logger.error(f"Failed to initialize the hybrid classifier: {e}")
    yield


# FastAPI application
app = FastAPI(
    title="News Classification API",
    description="News text classification service",
    version="1.0.0",
    lifespan=lifespan
)

# CORS configuration
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Request model
class ClassifyRequest(BaseModel):
    title: str
    content: str
    mode: Optional[str] = 'hybrid'  # 'traditional' or 'hybrid'; any other value uses hybrid


# Response model
class ClassifyResponse(BaseModel):
    categoryCode: str
    categoryName: str
    confidence: float
    classifierType: str
    duration: int
    probabilities: Optional[dict] = None


class ModelNotLoadedError(RuntimeError):
    """The classifier needed for a request failed to load at startup"""


def _classify(req: ClassifyRequest) -> dict:
    """Classify one request with the model selected by `mode`"""
    if req.mode == 'traditional':
        if nb_predictor is None:
            raise ModelNotLoadedError("Naive Bayes model is not loaded")
        start = time.perf_counter()
        result = nb_predictor.predict(req.title, req.content)
        result['classifierType'] = 'ML'
        # TraditionalPredictor does not time itself, but ClassifyResponse requires `duration`
        result['duration'] = int((time.perf_counter() - start) * 1000)
        return result
    if hybrid_classifier is None:
        raise ModelNotLoadedError("Hybrid classifier is not initialized")
    return hybrid_classifier.predict(req.title, req.content)


@app.get("/")
async def root():
    """Liveness check"""
    return {"status": "ok", "message": "News classification API is running"}


@app.get("/health")
async def health_check():
    """Health check, including which models are loaded"""
    return {
        "status": "healthy",
        "models": {
            "nb_loaded": nb_predictor is not None,
            "hybrid_loaded": hybrid_classifier is not None
        }
    }


# Plain `def` (not `async def`): FastAPI runs it in a thread pool,
# so CPU-bound inference does not block the event loop
@app.post("/api/predict", response_model=ClassifyResponse)
def predict(request: ClassifyRequest):
    """
    Text classification endpoint
    - **title**: news title
    - **content**: news content
    - **mode**: classification mode (traditional, hybrid)
    """
    try:
        return ClassifyResponse(**_classify(request))
    except ModelNotLoadedError as e:
        raise HTTPException(status_code=503, detail=str(e))
    except Exception as e:
        logger.error(f"Prediction failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/api/batch-predict")
def batch_predict(requests: List[ClassifyRequest]):
    """Batch classification; a failing item yields an error entry instead of failing the batch"""
    results = []
    for req in requests:
        try:
            results.append(_classify(req))
        except Exception as e:
            results.append({'error': str(e), 'title': req.title})
    return {"results": results}


if __name__ == '__main__':
    import uvicorn
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=5000,
        log_level="info"
    )
```
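On the caller side, the Spring Boot backend reaches this service through `classifier.bert.service-url` and `classifier.bert.timeout` in `application.yml`. To make the wire format of `POST /api/predict` concrete, here is a standard-library-only Python client sketch. `build_predict_request` and `predict` are hypothetical helpers, and the 5-second default timeout assumes the YAML value 5000 is in milliseconds.

```python
import json
import urllib.request
from typing import Any, Dict

PREDICT_URL = "http://localhost:5000/api/predict"  # classifier.bert.service-url


def build_predict_request(title: str, content: str, mode: str = "hybrid") -> urllib.request.Request:
    """POST request whose JSON body matches ClassifyRequest (title, content, mode)."""
    body = json.dumps({"title": title, "content": content, "mode": mode},
                      ensure_ascii=False).encode("utf-8")
    return urllib.request.Request(
        PREDICT_URL, data=body, method="POST",
        headers={"Content-Type": "application/json"},
    )


def predict(title: str, content: str, mode: str = "hybrid", timeout: float = 5.0) -> Dict[str, Any]:
    """Send the request; urllib raises urllib.error.HTTPError for 4xx/5xx responses."""
    with urllib.request.urlopen(build_predict_request(title, content, mode), timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))
```

With the service running, `predict("华为发布新款折叠屏手机", "华为今天正式发布了新一代折叠屏手机...")` returns a dict with the `ClassifyResponse` fields.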
#### Task 4.1.5: `src/utils/metrics.py` - Evaluation metrics
```python
"""
Evaluation metric utilities for the classifiers
"""
from typing import Any, Dict, List, Optional

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    confusion_matrix,
    precision_recall_fscore_support,
)


class ClassificationMetrics:
    """Classification evaluation metrics"""

    @staticmethod
    def compute_all(y_true: List, y_pred: List, labels: List[str]) -> Dict[str, Any]:
        """Compute overall (weighted) and per-class metrics"""
        accuracy = accuracy_score(y_true, y_pred)
        precision, recall, f1, _ = precision_recall_fscore_support(
            y_true, y_pred, labels=labels, average='weighted', zero_division=0
        )
        # Per-class metrics. labels=labels makes the arrays follow the order of `labels`;
        # without it they follow the sorted label order and would be attached to the wrong class.
        p_cls, r_cls, f1_cls, support_cls = precision_recall_fscore_support(
            y_true, y_pred, labels=labels, average=None, zero_division=0
        )
        per_class_metrics = {
            label: {
                'precision': float(p_cls[i]),
                'recall': float(r_cls[i]),
                'f1': float(f1_cls[i]),
                'support': int(support_cls[i])
            }
            for i, label in enumerate(labels)
        }
        return {
            'accuracy': float(accuracy),
            'precision': float(precision),
            'recall': float(recall),
            'f1': float(f1),
            'per_class': per_class_metrics
        }

    @staticmethod
    def plot_confusion_matrix(y_true: List, y_pred: List, labels: List[str],
                              save_path: Optional[str] = None):
        """Plot the confusion matrix; save it if save_path is given, otherwise show it"""
        # Rows and columns follow `labels`, matching the tick labels below
        cm = confusion_matrix(y_true, y_pred, labels=labels)
        plt.figure(figsize=(10, 8))
        sns.heatmap(
            cm,
            annot=True,
            fmt='d',
            cmap='Blues',
            xticklabels=labels,
            yticklabels=labels
        )
        # English titles: Chinese text would need a CJK font configured in matplotlib
        plt.xlabel('Predicted label')
        plt.ylabel('True label')
        plt.title('Confusion matrix')
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            plt.close()
        else:
            plt.show()

    @staticmethod
    def print_report(y_true: List, y_pred: List, labels: List[str]):
        """Print the per-class classification report"""
        report = classification_report(
            y_true, y_pred,
            labels=labels,
            target_names=labels,
            zero_division=0
        )
        print(report)


if __name__ == '__main__':
    # Smoke test
    y_true = ['POLITICS', 'TECHNOLOGY', 'FINANCE', 'POLITICS', 'TECHNOLOGY']
    y_pred = ['POLITICS', 'TECHNOLOGY', 'FINANCE', 'TECHNOLOGY', 'TECHNOLOGY']
    labels = ['POLITICS', 'TECHNOLOGY', 'FINANCE']
    metrics = ClassificationMetrics()
    result = metrics.compute_all(y_true, y_pred, labels)
    print(result)
```
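When `labels` is not passed, scikit-learn returns per-class arrays in sorted label order (`FINANCE`, `POLITICS`, `TECHNOLOGY` for the sample data), so indexing them by position in a differently ordered `labels` list attaches metrics to the wrong class. A dependency-free sketch that computes per-class precision and recall keyed by label name, usable as a cross-check for `compute_all`; `per_class_precision_recall` is a hypothetical helper.

```python
from typing import Dict, List


def per_class_precision_recall(y_true: List[str], y_pred: List[str],
                               labels: List[str]) -> Dict[str, Dict[str, float]]:
    """Precision and recall per label, keyed by name so ordering cannot go wrong."""
    result = {}
    for label in labels:
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == label and p == label)
        predicted = sum(1 for p in y_pred if p == label)
        actual = sum(1 for t in y_true if t == label)
        result[label] = {
            # Like zero_division=0: report 0.0 when a denominator is 0
            'precision': tp / predicted if predicted else 0.0,
            'recall': tp / actual if actual else 0.0,
        }
    return result


y_true = ['POLITICS', 'TECHNOLOGY', 'FINANCE', 'POLITICS', 'TECHNOLOGY']
y_pred = ['POLITICS', 'TECHNOLOGY', 'FINANCE', 'TECHNOLOGY', 'TECHNOLOGY']
for label, m in per_class_precision_recall(y_true, y_pred, ['POLITICS', 'TECHNOLOGY', 'FINANCE']).items():
    print(label, round(m['precision'], 3), round(m['recall'], 3))
# POLITICS 1.0 0.5
# TECHNOLOGY 0.667 1.0
# FINANCE 1.0 1.0
```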
#### Task 4.1.6: `requirements.txt` - Dependencies
```txt
# ML module dependencies
numpy>=1.24.0
pandas>=2.0.0
scikit-learn>=1.3.0
jieba>=0.42.0
joblib>=1.3.0
# Deep learning
torch>=2.0.0
transformers>=4.30.0
# API service
fastapi>=0.100.0
uvicorn[standard]>=0.23.0
pydantic>=2.0.0
# Data visualization
matplotlib>=3.7.0
seaborn>=0.12.0
# Utilities
python-dotenv>=1.0.0
pyyaml>=6.0
```
---
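## Summary
### Suggested Development Order
1. **Phase 1: Foundations**
   - Backend: database connection, entity classes, base configuration
   - Frontend: routing, state management, API wrappers
2. **Phase 2: Core features**
   - Crawler module (Python)
   - Traditional ML classifier
   - Backend API endpoints
   - Frontend news list page
3. **Phase 3: Advanced features**
   - BERT deep-learning classifier
   - Hybrid-strategy classifier
   - Frontend classifier comparison page
   - Statistics charts
4. **Phase 4: Polish and optimization**
   - User authentication
   - Data visualization
   - Performance optimization
   - Exception handling
### Key Notes
1. **The crawler module is written in Python** and talks to the Java backend through a RESTful API.
2. **The classifier module is deployed independently** and exposes an HTTP API for the backend to call.
3. **Frontend and backend are separated** and use JWT for authentication.
4. **The database schema** is already defined in `schema.sql` and must be followed exactly.
5. **All API responses** use the unified `Result<T>` wrapper.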