42 KiB
42 KiB
新闻文本分类系统 - 模块开发任务清单
本文档详细列出每个模块需要完成的具体代码任务,参考现有工程结构。 注意:爬虫模块使用 Python 实现,而非 Java。
目录
1. 爬虫模块 (Python)
2. 后端服务模块 (Spring Boot)
模块目录结构
#### 任务 2.1.10: `application.yml` - 应用配置文件
```yaml
spring:
  application:
    name: news-classifier
  datasource:
    driver-class-name: com.mysql.cj.jdbc.Driver
    # characterEncoding takes a Java charset name; UTF-8 is accepted by all Connector/J
    # versions ("utf8mb4" is a MySQL charset name and is rejected by older drivers).
    url: jdbc:mysql://localhost:3306/news_classifier?useUnicode=true&characterEncoding=UTF-8&serverTimezone=Asia/Shanghai
    # Credentials can be overridden via environment variables; the defaults are placeholders.
    username: ${DB_USERNAME:root}
    password: ${DB_PASSWORD:your_password}
  data:
    redis:
      host: localhost
      port: 6379
      database: 0

# MyBatis-Plus
mybatis-plus:
  mapper-locations: classpath:mapper/**/*.xml
  type-aliases-package: com.newsclassifier.entity
  configuration:
    map-underscore-to-camel-case: true
    # Prints every SQL statement to stdout; switch off outside development.
    log-impl: org.apache.ibatis.logging.stdout.StdOutImpl
  global-config:
    db-config:
      id-type: auto
      # Soft delete: rows with deleted = 1 are treated as removed.
      logic-delete-field: deleted
      logic-delete-value: 1
      logic-not-delete-value: 0

# JWT
jwt:
  # Override with JWT_SECRET in every real deployment; never commit a real secret.
  secret: ${JWT_SECRET:your-secret-key-at-least-256-bits-long-for-hs256-algorithm}
  # presumably milliseconds (86400000 = 24h) — confirm against the JWT utility
  expiration: 86400000

# Classifier
classifier:
  mode: hybrid # traditional, deep_learning, hybrid
  confidence:
    threshold: 0.75
    hybrid-min: 0.6
  bert:
    # FastAPI service of the ML module (src/api/server.py listens on port 5000).
    service-url: http://localhost:5000/api/predict
    timeout: 5000

# Logging
logging:
  level:
    com.newsclassifier: debug
  pattern:
    console: "%d{yyyy-MM-dd HH:mm:ss} [%thread] %-5level %logger{36} - %msg%n"
```
3. 前端桌面模块 (Tauri + Vue3)
模块目录结构
client/src/
├── api/ # API接口
│ ├── index.ts
│ ├── auth.ts
│ ├── news.ts
│ ├── category.ts
│ └── classifier.ts
├── assets/ # 静态资源
│ ├── images/
│ ├── styles/
│ │ ├── main.css
│ │ └── tailwind.css
│ └── fonts/
├── components/ # 组件
│ ├── ui/ # 基础UI组件
│ │ ├── button/
│ │ ├── input/
│ │ ├── dialog/
│ │ └── table/
│ ├── layout/ # 布局组件
│ │ ├── Header.vue
│ │ ├── Sidebar.vue
│ │ └── Footer.vue
│ ├── news/ # 新闻相关组件
│ │ ├── NewsCard.vue
│ │ ├── NewsList.vue
│ │ ├── NewsDetail.vue
│ │ └── CategoryFilter.vue
│ └── charts/ # 图表组件
│ ├── CategoryChart.vue
│ └── TrendChart.vue
├── composables/ # 组合式函数
│ ├── useAuth.ts
│ ├── useNews.ts
│ ├── useClassifier.ts
│ └── useToast.ts
├── layouts/ # 布局
│ ├── DefaultLayout.vue
│ ├── AuthLayout.vue
│ └── EmptyLayout.vue
├── router/ # 路由 (部分完成)
│ └── index.ts
├── stores/ # 状态管理 (部分完成)
│ ├── user.ts
│ ├── news.ts
│ └── category.ts
├── types/ # TypeScript类型
│ ├── api.d.ts
│ ├── news.d.ts
│ └── user.d.ts
├── utils/ # 工具函数
│ ├── request.ts
│ ├── storage.ts
│ ├── format.ts
│ └── validate.ts
├── views/ # 页面
│ ├── auth/
│ │ ├── Login.vue
│ │ └── Register.vue
│ ├── news/
│ │ ├── NewsList.vue
│ │ ├── NewsDetail.vue
│ │ └── NewsSearch.vue
│ ├── category/
│ │ ├── CategoryManage.vue
│ │ └── CategoryStats.vue
│ ├── classifier/
│ │ ├── ClassifierPage.vue
│ │ └── ModelCompare.vue
│ └── admin/
│ ├── Dashboard.vue
│ ├── UserManage.vue
│ └── SystemLog.vue
├── App.vue
└── main.ts
3.1 需要完成的具体文件
任务 3.1.1: utils/request.ts - HTTP请求封装
import axios, { AxiosInstance, AxiosRequestConfig, AxiosResponse } from 'axios'
/**
 * Envelope wrapping every backend response (the backend's `Result<T>`).
 * A `code` of 200 marks success; `data` carries the payload.
 */
interface ApiResponse<T = any> {
  code: number
  message: string
  data: T
}

/** Used when `VITE_API_BASE_URL` is unset or empty. */
const DEFAULT_BASE_URL = 'http://localhost:8080/api'

/** Shared axios instance that every request helper in this module goes through. */
const service: AxiosInstance = axios.create({
  baseURL: import.meta.env.VITE_API_BASE_URL || DEFAULT_BASE_URL,
  timeout: 15_000,
  headers: { 'Content-Type': 'application/json' }
})
// Request interceptor: attach the stored bearer token (if any) to every request.
service.interceptors.request.use(
  (config) => {
    const token = localStorage.getItem('token')
    if (token) {
      config.headers.Authorization = `Bearer ${token}`
    }
    return config
  },
  (error) => Promise.reject(error)
)

/**
 * Response interceptor.
 *
 * - Success: unwraps the `ApiResponse` envelope so callers receive `data` directly.
 * - Business error (`code !== 200`): rejects with an `Error` carrying the backend message.
 * - HTTP error: prefers the backend `message` from the error body over axios'
 *   generic "Request failed with status code ..." text.
 * - HTTP 401: removes the stored credentials, so the router guard (which checks
 *   'token' / 'userRole' in localStorage) sends the user to the login page on the
 *   next navigation instead of replaying an expired token indefinitely.
 */
service.interceptors.response.use(
  (response: AxiosResponse<ApiResponse>) => {
    const { code, message, data } = response.data
    if (code === 200) {
      return data
    }
    return Promise.reject(new Error(message || '请求失败'))
  },
  (error: unknown) => {
    if (axios.isAxiosError(error)) {
      if (error.response?.status === 401) {
        localStorage.removeItem('token')
        localStorage.removeItem('userRole')
      }
      const body: Partial<ApiResponse> | undefined = error.response?.data
      if (body?.message) {
        return Promise.reject(new Error(body.message))
      }
    }
    return Promise.reject(error)
  }
)
/**
 * Typed request helpers.
 *
 * The response interceptor on `service` already unwraps the backend envelope,
 * so each helper resolves with the payload itself (typed as `T`), not with an
 * AxiosResponse.
 */
export const http = {
  get: <T = any>(url: string, config?: AxiosRequestConfig): Promise<T> =>
    service.get(url, config),
  post: <T = any>(url: string, data?: any, config?: AxiosRequestConfig): Promise<T> =>
    service.post(url, data, config),
  put: <T = any>(url: string, data?: any, config?: AxiosRequestConfig): Promise<T> =>
    service.put(url, data, config),
  delete: <T = any>(url: string, config?: AxiosRequestConfig): Promise<T> =>
    service.delete(url, config)
}

/** The underlying axios instance; requests made through it still pass through the interceptors. */
export default service
任务 3.1.2: api/news.ts - 新闻API
// request.ts lives in src/utils/, not next to this file in src/api/.
import { http } from '@/utils/request'

/** Query parameters for the paged news list; omitted fields are not sent. */
export interface NewsQueryParams {
  page?: number
  size?: number
  categoryId?: number
  categoryCode?: string
  keyword?: string
  status?: number
}

/** A news article as returned by the backend. */
export interface NewsDetail {
  id: number
  title: string
  content: string
  summary: string
  source: string
  sourceUrl: string
  author: string
  categoryId: number
  categoryCode: string
  coverImage: string
  publishTime: string
  viewCount: number
  likeCount: number
  commentCount: number
  classifierType: string
  confidence: number
}

/** One page of results (MyBatis-Plus style page object). */
export interface PageResponse<T> {
  total: number
  records: T[]
  current: number
  size: number
}

/** News endpoints. Every call resolves with the unwrapped `data` of the backend response. */
export const newsApi = {
  /** Paged, optionally filtered, news list. */
  getNewsPage(params: NewsQueryParams): Promise<PageResponse<NewsDetail>> {
    return http.get('/news/page', { params })
  },
  /** Single article by id. */
  getNewsDetail(id: number): Promise<NewsDetail> {
    return http.get(`/news/${id}`)
  },
  /** Keyword search, paged (1-based page). */
  searchNews(keyword: string, page = 1, size = 20): Promise<PageResponse<NewsDetail>> {
    return http.get('/news/search', { params: { keyword, page, size } })
  },
  /** Manually assign a category; categoryId is sent as a query parameter with an empty body. */
  manualClassify(id: number, categoryId: number): Promise<void> {
    return http.post(`/news/${id}/classify`, null, { params: { categoryId } })
  }
}
任务 3.1.3: composables/useNews.ts - 新闻组合式函数
import { ref, computed } from 'vue'
import { newsApi, type NewsQueryParams, type NewsDetail, type PageResponse } from '@/api/news'
/**
 * News list/detail state plus the actions that load it.
 *
 * State is exposed through read-only `computed` wrappers, so only the actions
 * below mutate it. Every action logs and rethrows on failure so the calling
 * view decides how to react.
 *
 * Overlapping requests are handled explicitly:
 * - list requests (`fetchNewsPage` / `searchNews`) are last-request-wins: a
 *   response is applied only if no newer list request started meanwhile, so a
 *   slow earlier page can no longer overwrite a newer one;
 * - detail requests are last-request-wins among themselves;
 * - `loading` stays `true` until every in-flight request has settled.
 */
export function useNews() {
  const loading = ref(false)
  const newsList = ref<NewsDetail[]>([])
  const total = ref(0)
  const currentNews = ref<NewsDetail | null>(null)

  // Number of requests currently in flight; drives `loading`.
  let pending = 0
  // Sequence numbers of the newest list / detail request.
  let listSeq = 0
  let detailSeq = 0

  const begin = () => {
    pending += 1
    loading.value = true
  }
  const settle = () => {
    pending -= 1
    loading.value = pending > 0
  }

  /** Load one page of news (see `NewsQueryParams` for filters). */
  const fetchNewsPage = async (params: NewsQueryParams) => {
    const seq = ++listSeq
    begin()
    try {
      const result: PageResponse<NewsDetail> = await newsApi.getNewsPage(params)
      if (seq === listSeq) {
        newsList.value = result.records
        total.value = result.total
      }
    } catch (error) {
      console.error('获取新闻列表失败:', error)
      throw error
    } finally {
      settle()
    }
  }

  /** Load one article into `currentNews` and return it. */
  const fetchNewsDetail = async (id: number) => {
    const seq = ++detailSeq
    begin()
    try {
      const detail = await newsApi.getNewsDetail(id)
      if (seq === detailSeq) {
        currentNews.value = detail
      }
      return detail
    } catch (error) {
      console.error('获取新闻详情失败:', error)
      throw error
    } finally {
      settle()
    }
  }

  /** Keyword search; results replace the current list (same state as fetchNewsPage). */
  const searchNews = async (keyword: string, page = 1, size = 20) => {
    const seq = ++listSeq
    begin()
    try {
      const result: PageResponse<NewsDetail> = await newsApi.searchNews(keyword, page, size)
      if (seq === listSeq) {
        newsList.value = result.records
        total.value = result.total
      }
    } catch (error) {
      console.error('搜索新闻失败:', error)
      throw error
    } finally {
      settle()
    }
  }

  return {
    loading: computed(() => loading.value),
    newsList: computed(() => newsList.value),
    total: computed(() => total.value),
    currentNews: computed(() => currentNews.value),
    fetchNewsPage,
    fetchNewsDetail,
    searchNews
  }
}
任务 3.1.4: views/news/NewsList.vue - 新闻列表页面
<template>
  <div class="news-list-container">
    <!-- Category filter: one button per entry in `categories`; the selected one gets the `active` class -->
    <div class="category-filter">
      <button
        v-for="cat in categories"
        :key="cat.code"
        :class="{ active: selectedCategory === cat.code }"
        @click="selectCategory(cat.code)"
      >
        {{ cat.name }}
      </button>
    </div>
    <!-- Search box: pressing Enter or clicking the button runs handleSearch -->
    <div class="search-bar">
      <input
        v-model="searchKeyword"
        type="text"
        placeholder="搜索新闻..."
        @keyup.enter="handleSearch"
      />
      <button @click="handleSearch">搜索</button>
    </div>
    <!-- News list: rendered only when not loading and there is at least one item; clicking a card opens its detail page -->
    <div v-if="!loading && newsList.length > 0" class="news-list">
      <div v-for="news in newsList" :key="news.id" class="news-card" @click="viewDetail(news.id)">
        <h3>{{ news.title }}</h3>
        <p class="summary">{{ news.summary }}</p>
        <div class="meta">
          <span class="category">{{ getCategoryName(news.categoryCode) }}</span>
          <span class="source">{{ news.source }}</span>
          <span class="time">{{ formatDate(news.publishTime) }}</span>
        </div>
        <!-- Classifier badge and confidence (as a percentage), only when classifierType is set -->
        <div v-if="news.classifierType" class="classifier-info">
          <span class="badge">{{ news.classifierType }}</span>
          <span class="confidence">{{ (news.confidence * 100).toFixed(1) }}%</span>
        </div>
      </div>
    </div>
    <!-- Loading indicator -->
    <div v-if="loading" class="loading">加载中...</div>
    <!-- Empty state: finished loading with no results -->
    <div v-if="!loading && newsList.length === 0" class="empty">暂无数据</div>
    <!-- Pagination: hidden when total is 0; buttons are disabled at the first/last page -->
    <div v-if="total > 0" class="pagination">
      <button :disabled="currentPage <= 1" @click="changePage(currentPage - 1)">上一页</button>
      <span>{{ currentPage }} / {{ totalPages }}</span>
      <button :disabled="currentPage >= totalPages" @click="changePage(currentPage + 1)">下一页</button>
    </div>
  </div>
</template>
<script setup lang="ts">
import { ref, computed, onMounted } from 'vue'
import { useRouter } from 'vue-router'
import { useNews } from '@/composables/useNews'

const router = useRouter()
const { loading, newsList, total, fetchNewsPage, searchNews } = useNews()

// Filter buttons; code '' means "all categories".
const categories = ref([
  { code: '', name: '全部' },
  { code: 'POLITICS', name: '时政' },
  { code: 'FINANCE', name: '财经' },
  { code: 'TECHNOLOGY', name: '科技' },
  { code: 'SPORTS', name: '体育' }
])

const selectedCategory = ref('')
// Text currently typed in the search box (may not have been submitted yet).
const searchKeyword = ref('')
// Keyword of the last submitted search; '' means the page is browsing by category.
// Paging follows this value, so the listed results and the page numbers always match.
const activeKeyword = ref('')
const currentPage = ref(1)
const pageSize = ref(20)
const totalPages = computed(() => Math.ceil(total.value / pageSize.value))

/**
 * Load `currentPage` of whatever is displayed: search results while a search
 * is active, otherwise the (optionally category-filtered) listing.
 * useNews already logs failures; they are caught here so event handlers do not
 * produce unhandled promise rejections, and the previous list stays visible.
 */
const loadNews = async () => {
  try {
    if (activeKeyword.value) {
      await searchNews(activeKeyword.value, currentPage.value, pageSize.value)
    } else {
      await fetchNewsPage({
        page: currentPage.value,
        size: pageSize.value,
        categoryCode: selectedCategory.value || undefined
      })
    }
  } catch {
    // Already logged by useNews.
  }
}

/** Switch category: leaves search mode and restarts from page 1. */
const selectCategory = (code: string) => {
  selectedCategory.value = code
  activeKeyword.value = ''
  searchKeyword.value = ''
  currentPage.value = 1
  void loadNews()
}

/** Submit the search box; an empty keyword returns to the category listing. Always starts at page 1. */
const handleSearch = async () => {
  activeKeyword.value = searchKeyword.value.trim()
  currentPage.value = 1
  await loadNews()
}

/** Open the detail page of one article. */
const viewDetail = (id: number) => {
  router.push(`/news/${id}`)
}

/** Go to `page` of the current listing or search; out-of-range pages are ignored. */
const changePage = (page: number) => {
  if (page < 1 || (totalPages.value > 0 && page > totalPages.value)) {
    return
  }
  currentPage.value = page
  void loadNews()
}

/**
 * Format a backend timestamp as `YYYY-MM-DD HH:mm` in local time.
 * Accepts ISO strings and (assumption — confirm the backend's Jackson date format)
 * `yyyy-MM-dd HH:mm:ss`; anything unparseable is shown unchanged.
 */
const formatDate = (dateStr: string) => {
  if (!dateStr) {
    return ''
  }
  const date = new Date(dateStr.includes('T') ? dateStr : dateStr.replace(' ', 'T'))
  if (Number.isNaN(date.getTime())) {
    return dateStr
  }
  const pad = (n: number) => String(n).padStart(2, '0')
  return `${date.getFullYear()}-${pad(date.getMonth() + 1)}-${pad(date.getDate())} ${pad(date.getHours())}:${pad(date.getMinutes())}`
}

/** Display name for a category code, falling back to the code itself. */
const getCategoryName = (code: string) => {
  const cat = categories.value.find(c => c.code === code)
  return cat?.name || code
}

onMounted(() => {
  void loadNews()
})
</script>
<style scoped>
/* Page wrapper */
.news-list-container {
  padding: 20px;
}
/* Horizontal row of category buttons */
.category-filter {
  display: flex;
  gap: 10px;
  margin-bottom: 20px;
}
.category-filter button {
  padding: 8px 16px;
  border: 1px solid #ddd;
  border-radius: 4px;
  background: white;
  cursor: pointer;
}
/* Currently selected category */
.category-filter button.active {
  background: #1890ff;
  color: white;
  border-color: #1890ff;
}
/* One clickable news item */
.news-card {
  padding: 15px;
  border: 1px solid #eee;
  border-radius: 8px;
  margin-bottom: 15px;
  cursor: pointer;
  transition: box-shadow 0.2s;
}
.news-card:hover {
  box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1);
}
/* Category / source / time line under the summary */
.meta {
  display: flex;
  gap: 15px;
  font-size: 12px;
  color: #999;
  margin-top: 10px;
}
/* Classifier badge + confidence row */
.classifier-info {
  display: flex;
  gap: 8px;
  margin-top: 8px;
}
.badge {
  padding: 2px 8px;
  background: #f0f0f0;
  border-radius: 4px;
  font-size: 12px;
}
/* Centered previous / page indicator / next controls */
.pagination {
  display: flex;
  justify-content: center;
  align-items: center;
  gap: 15px;
  margin-top: 20px;
}
</style>
任务 3.1.5: views/classifier/ClassifierPage.vue - 分类器页面
<template>
  <div class="classifier-page">
    <div class="page-header">
      <h2>文本分类</h2>
    </div>
    <!-- Mode selector: the selected option value is sent as `mode` with the classify request -->
    <div class="mode-selector">
      <label>分类模式:</label>
      <select v-model="selectedMode">
        <option value="traditional">传统机器学习 (TF-IDF + NB)</option>
        <option value="deep_learning">深度学习 (BERT)</option>
        <option value="hybrid">混合模式</option>
      </select>
    </div>
    <!-- Input form: title and content are both required; the button is disabled while a request runs -->
    <div class="input-area">
      <div class="form-group">
        <label>新闻标题</label>
        <input v-model="formData.title" type="text" placeholder="请输入新闻标题" />
      </div>
      <div class="form-group">
        <label>新闻内容</label>
        <textarea v-model="formData.content" placeholder="请输入新闻内容" rows="10"></textarea>
      </div>
      <button @click="handleClassify" :disabled="classifying">
        {{ classifying ? '分类中...' : '开始分类' }}
      </button>
    </div>
    <!-- Result panel, shown once a result is available -->
    <div v-if="result" class="result-area">
      <h3>分类结果</h3>
      <div class="result-item">
        <span class="label">分类:</span>
        <span class="value">{{ result.categoryName }} ({{ result.categoryCode }})</span>
      </div>
      <div class="result-item">
        <span class="label">置信度:</span>
        <span class="value">{{ (result.confidence * 100).toFixed(2) }}%</span>
      </div>
      <div class="result-item">
        <span class="label">分类器:</span>
        <span class="value">{{ result.classifierType }}</span>
      </div>
      <div class="result-item">
        <span class="label">耗时:</span>
        <span class="value">{{ result.duration }}ms</span>
      </div>
      <!-- Per-category probability bars; only when the result includes probabilities -->
      <div v-if="result.probabilities" class="probabilities">
        <h4>各类别概率</h4>
        <div v-for="(prob, code) in result.probabilities" :key="code" class="prob-bar">
          <span class="cat-name">{{ getCategoryName(code) }}</span>
          <div class="bar-container">
            <div class="bar" :style="{ width: (prob * 100) + '%' }"></div>
          </div>
          <span class="prob-value">{{ (prob * 100).toFixed(1) }}%</span>
        </div>
      </div>
    </div>
  </div>
</template>
<script setup lang="ts">
import { ref } from 'vue'
import { classifierApi } from '@/api/classifier'

/** Result returned by `classifierApi.classify` (presumably forwarded from the ML service response). */
interface ClassificationResult {
  categoryCode: string
  categoryName: string
  confidence: number
  classifierType: string
  duration: number
  probabilities?: Record<string, number>
}

/** Modes offered by the selector; the value is sent to the backend unchanged. */
type ClassifyMode = 'traditional' | 'deep_learning' | 'hybrid'

// Display names for every category code the ML module can return
// (must stay in sync with CATEGORY_MAP in the Python classifiers).
const CATEGORY_NAMES: Record<string, string> = {
  POLITICS: '时政',
  FINANCE: '财经',
  TECHNOLOGY: '科技',
  SPORTS: '体育',
  ENTERTAINMENT: '娱乐',
  HEALTH: '健康',
  EDUCATION: '教育',
  LIFE: '生活',
  INTERNATIONAL: '国际',
  MILITARY: '军事'
}

const formData = ref({
  title: '',
  content: ''
})
const selectedMode = ref<ClassifyMode>('hybrid')
const classifying = ref(false)
const result = ref<ClassificationResult | null>(null)

/** Validate the form and classify it with the selected mode. */
const handleClassify = async () => {
  // The button is disabled while running; this also guards programmatic calls.
  if (classifying.value) {
    return
  }
  if (!formData.value.title.trim() || !formData.value.content.trim()) {
    alert('请输入标题和内容')
    return
  }
  classifying.value = true
  // Drop the previous result so a failed request never leaves stale output on screen.
  result.value = null
  try {
    result.value = await classifierApi.classify({
      title: formData.value.title,
      content: formData.value.content,
      mode: selectedMode.value
    })
  } catch (error) {
    console.error('分类失败:', error)
    alert('分类失败,请重试')
  } finally {
    classifying.value = false
  }
}

/** Display name for a category code, falling back to the code itself. */
const getCategoryName = (code: string) => CATEGORY_NAMES[code] ?? code
</script>
<style scoped>
/* Centered page column */
.classifier-page {
  padding: 20px;
  max-width: 800px;
  margin: 0 auto;
}
.mode-selector {
  margin-bottom: 20px;
}
.mode-selector select {
  padding: 8px;
  border-radius: 4px;
  border: 1px solid #ddd;
}
/* Grey panel holding the title/content form */
.input-area {
  background: #f9f9f9;
  padding: 20px;
  border-radius: 8px;
}
.form-group {
  margin-bottom: 15px;
}
.form-group label {
  display: block;
  margin-bottom: 5px;
  font-weight: 500;
}
/* Full-width inputs; border-box keeps padding inside the 100% width */
.form-group input,
.form-group textarea {
  width: 100%;
  padding: 10px;
  border: 1px solid #ddd;
  border-radius: 4px;
  box-sizing: border-box;
}
/* Light-blue result panel */
.result-area {
  margin-top: 20px;
  padding: 20px;
  background: #f0f7ff;
  border-radius: 8px;
}
.result-item {
  display: flex;
  padding: 8px 0;
}
.result-item .label {
  width: 80px;
  font-weight: 500;
}
.probabilities {
  margin-top: 20px;
}
/* One row: category name | bar | percentage */
.prob-bar {
  display: flex;
  align-items: center;
  margin-bottom: 10px;
}
.cat-name {
  width: 80px;
}
/* Grey track; the inner .bar width is the probability percentage */
.bar-container {
  flex: 1;
  height: 20px;
  background: #e0e0e0;
  border-radius: 4px;
  overflow: hidden;
  margin: 0 10px;
}
.bar {
  height: 100%;
  background: linear-gradient(90deg, #1890ff, #52c41a);
  transition: width 0.3s;
}
.prob-value {
  width: 60px;
  text-align: right;
}
</style>
任务 3.1.6: router/index.ts - 路由配置 (更新)
import { createRouter, createWebHistory } from 'vue-router'
import type { RouteRecordRaw } from 'vue-router'
/**
 * Application routes.
 * `meta.requiresAuth` routes need a stored token; `meta.requiresAdmin` routes
 * additionally need the ADMIN role (both enforced by the global guard).
 */
const routes: RouteRecordRaw[] = [
  {
    path: '/login',
    name: 'Login',
    component: () => import('@/views/auth/Login.vue'),
    meta: { layout: 'EmptyLayout' }
  },
  {
    path: '/',
    name: 'Home',
    component: () => import('@/views/news/NewsList.vue'),
    meta: { requiresAuth: true }
  },
  {
    path: '/news',
    name: 'NewsList',
    component: () => import('@/views/news/NewsList.vue'),
    meta: { requiresAuth: true }
  },
  {
    path: '/news/:id',
    name: 'NewsDetail',
    component: () => import('@/views/news/NewsDetail.vue'),
    meta: { requiresAuth: true }
  },
  {
    path: '/classifier',
    name: 'Classifier',
    component: () => import('@/views/classifier/ClassifierPage.vue'),
    meta: { requiresAuth: true }
  },
  {
    path: '/category',
    name: 'CategoryStats',
    component: () => import('@/views/category/CategoryStats.vue'),
    meta: { requiresAuth: true }
  },
  {
    path: '/admin',
    name: 'AdminDashboard',
    component: () => import('@/views/admin/Dashboard.vue'),
    meta: { requiresAuth: true, requiresAdmin: true }
  },
  {
    // Unknown URLs fall back to the home page instead of rendering an empty view.
    path: '/:pathMatch(.*)*',
    redirect: '/'
  }
]

const router = createRouter({
  // BASE_URL follows Vite's `base` option ('/' by default), so sub-path deployments also work.
  history: createWebHistory(import.meta.env.BASE_URL),
  routes
})
/**
 * Global navigation guard (Vue Router 4 return-value style; no `next` callback).
 * - `requiresAuth` without a stored token -> /login, remembering the target in `?redirect=`.
 * - `requiresAdmin` without the ADMIN role -> home page.
 */
router.beforeEach((to) => {
  const token = localStorage.getItem('token')
  if (to.meta.requiresAuth && !token) {
    return { path: '/login', query: { redirect: to.fullPath } }
  }
  // NOTE(review): the role comes from localStorage and is client-controlled;
  // this only hides UI — the backend must enforce admin access itself.
  if (to.meta.requiresAdmin && localStorage.getItem('userRole') !== 'ADMIN') {
    return '/'
  }
  return true
})

export default router
4. 机器学习分类模块 (Python)
模块目录结构
ml-module/
├── data/
│ ├── raw/ # 原始数据
│ ├── processed/ # 处理后的数据
│ │ ├── training_data.csv
│ │ └── test_data.csv
│ └── external/ # 外部数据集
├── models/ # 训练好的模型
│ ├── traditional/
│ │ ├── nb_vectorizer.pkl
│ │ ├── nb_classifier.pkl
│ │ ├── svm_vectorizer.pkl
│ │ └── svm_classifier.pkl
│ ├── deep_learning/
│ │ └── bert_finetuned/
│ └── hybrid/
│ └── config.json
├── src/
│ ├── __init__.py
│ ├── traditional/ # 传统机器学习
│ │ ├── __init__.py
│ │ ├── train_model.py # (已有)
│ │ ├── predict.py
│ │ └── evaluate.py
│ ├── deep_learning/ # 深度学习
│ │ ├── __init__.py
│ │ ├── bert_model.py
│ │ ├── train_bert.py
│ │ └── predict_bert.py
│ ├── hybrid/ # 混合策略
│ │ ├── __init__.py
│ │ ├── hybrid_classifier.py
│ │ └── rule_engine.py
│ ├── utils/
│ │ ├── __init__.py
│ │ ├── preprocessing.py # 数据预处理
│ │ └── metrics.py # 评估指标
│ └── api/ # API服务
│ ├── __init__.py
│ └── server.py # FastAPI服务
├── notebooks/ # Jupyter notebooks
│ ├── data_exploration.ipynb
│ └── model_comparison.ipynb
├── tests/ # 测试
│ ├── test_traditional.py
│ ├── test_bert.py
│ └── test_hybrid.py
├── requirements.txt
├── setup.py
└── README.md
4.1 需要完成的具体文件
任务 4.1.1: src/traditional/predict.py - 传统模型预测
"""
传统机器学习模型预测
"""
import os
import joblib
import jieba
from typing import Dict, Any
# Category code -> Chinese display name. Keys are expected to match the labels
# produced by the trained classifier (its ``classes_``).
CATEGORY_MAP = dict(
    POLITICS='时政',
    FINANCE='财经',
    TECHNOLOGY='科技',
    SPORTS='体育',
    ENTERTAINMENT='娱乐',
    HEALTH='健康',
    EDUCATION='教育',
    LIFE='生活',
    INTERNATIONAL='国际',
    MILITARY='军事',
)
class TraditionalPredictor:
    """Predictor backed by a pickled TF-IDF vectorizer and a scikit-learn classifier.

    Loads ``{model_type}_vectorizer.pkl`` and ``{model_type}_classifier.pkl``
    from ``model_dir`` when constructed.

    :param model_type: file-name prefix of the model pair, e.g. ``'nb'`` or ``'svm'``.
    :param model_dir: directory holding the pickles. ``None`` (default) resolves to
        ``<ml-module>/models/traditional`` relative to this file, so loading works
        regardless of the current working directory (the previous default
        ``'../../models/traditional'`` only worked when run from ``src/traditional``).
    """

    def __init__(self, model_type='nb', model_dir=None):
        self.model_type = model_type
        self.model_dir = model_dir if model_dir is not None else self._default_model_dir()
        self.vectorizer = None
        self.classifier = None
        self._load_model()

    @staticmethod
    def _default_model_dir() -> str:
        """Return ``<ml-module>/models/traditional`` resolved from this file's location."""
        here = os.path.dirname(os.path.abspath(__file__))
        return os.path.normpath(os.path.join(here, '..', '..', 'models', 'traditional'))

    def _load_model(self):
        """Load the vectorizer/classifier pickles for ``self.model_type``.

        :raises FileNotFoundError: if either pickle file is missing.
        :raises TypeError: if the classifier cannot produce class probabilities.
        """
        vectorizer_path = os.path.join(self.model_dir, f'{self.model_type}_vectorizer.pkl')
        classifier_path = os.path.join(self.model_dir, f'{self.model_type}_classifier.pkl')
        # joblib/pickle executes code while loading: only load trusted model files.
        self.vectorizer = joblib.load(vectorizer_path)
        self.classifier = joblib.load(classifier_path)
        # predict() needs per-class probabilities. E.g. LinearSVC or SVC(probability=False)
        # do not provide them; fail here with a clear message instead of at request time.
        if not hasattr(self.classifier, 'predict_proba'):
            raise TypeError(
                f"分类器 {type(self.classifier).__name__} 不支持 predict_proba,无法输出概率"
            )
        print(f"模型加载成功: {self.model_type}")

    def preprocess(self, title: str, content: str) -> str:
        """Join title and content and segment them with jieba into space-separated tokens."""
        text = (title or '') + ' ' + (content or '')
        words = jieba.cut(text)
        return ' '.join(words)

    def predict(self, title: str, content: str) -> Dict[str, Any]:
        """Classify one article.

        :return: ``categoryCode``, ``categoryName``, ``confidence`` (probability of the
            returned code) and ``probabilities`` (code -> probability for every class).
        """
        processed = self.preprocess(title, content)
        features = self.vectorizer.transform([processed])

        # Labels may come back as numpy scalars; cast to built-in str for JSON/pydantic.
        prediction = str(self.classifier.predict(features)[0])
        probabilities = self.classifier.predict_proba(features)[0]
        prob_dict = {
            str(code): float(prob)
            for code, prob in zip(self.classifier.classes_, probabilities)
        }

        return {
            'categoryCode': prediction,
            'categoryName': CATEGORY_MAP.get(prediction, '未知'),
            # Probability of the *returned* label: for some models (e.g. SVC with Platt
            # scaling) predict() and argmax(predict_proba) can disagree.
            'confidence': prob_dict.get(prediction, float(probabilities.max())),
            'probabilities': prob_dict,
        }
# API entry point.
# Predictors are cached per model_type, so repeated calls reuse the loaded
# vectorizer/classifier instead of unpickling them from disk on every call.
_PREDICTOR_CACHE: Dict[str, 'TraditionalPredictor'] = {}


def predict_single(title: str, content: str, model_type='nb') -> Dict[str, Any]:
    """Classify one article with the traditional model.

    :param title: article title.
    :param content: article body.
    :param model_type: model prefix passed to ``TraditionalPredictor`` (e.g. ``'nb'``).
    :return: the dict produced by ``TraditionalPredictor.predict``.
    """
    predictor = _PREDICTOR_CACHE.get(model_type)
    if predictor is None:
        # A failed load raises before caching, so the next call retries.
        predictor = TraditionalPredictor(model_type)
        _PREDICTOR_CACHE[model_type] = predictor
    return predictor.predict(title, content)
if __name__ == '__main__':
    # Smoke test: classify one sample article with the default (NB) model.
    sample = {
        'title': "华为发布新款折叠屏手机",
        'content': "华为今天正式发布了新一代折叠屏手机,搭载最新麒麟芯片...",
    }
    print(predict_single(**sample))
任务 4.1.2: src/deep_learning/bert_model.py - BERT模型
"""
BERT文本分类模型
"""
import torch
from transformers import (
BertTokenizer,
BertForSequenceClassification,
Trainer,
TrainingArguments
)
from typing import Dict, Any, List
# Category code -> Chinese display name.
# The key order also defines the model's output ids (see ID_TO_LABEL), so it
# presumably must match the label ids used during fine-tuning — do not reorder.
CATEGORY_MAP = dict(
    POLITICS='时政',
    FINANCE='财经',
    TECHNOLOGY='科技',
    SPORTS='体育',
    ENTERTAINMENT='娱乐',
    HEALTH='健康',
    EDUCATION='教育',
    LIFE='生活',
    INTERNATIONAL='国际',
    MILITARY='军事',
)

# Output id <-> category code lookups, derived from CATEGORY_MAP's key order.
ID_TO_LABEL = dict(enumerate(CATEGORY_MAP))
LABEL_TO_ID = {label: idx for idx, label in ID_TO_LABEL.items()}
class BertClassifier:
    """BERT sequence classifier wrapping a fine-tuned checkpoint.

    Construct it, then call :meth:`load_model` before :meth:`predict`.
    Model output ids are mapped to category codes through ``ID_TO_LABEL``.
    """

    # Maximum number of tokens fed to BERT; longer inputs are truncated.
    MAX_LENGTH = 512

    def __init__(self, model_name='bert-base-chinese', num_labels=10):
        self.model_name = model_name
        self.num_labels = num_labels
        self.tokenizer = None
        self.model = None
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def load_model(self, model_path):
        """Load a fine-tuned tokenizer + model from ``model_path`` and switch to eval mode."""
        self.tokenizer = BertTokenizer.from_pretrained(model_path)
        self.model = BertForSequenceClassification.from_pretrained(
            model_path,
            num_labels=self.num_labels
        )
        self.model.to(self.device)
        self.model.eval()
        print(f"BERT模型加载成功: {model_path}")

    def predict(self, title: str, content: str) -> Dict[str, Any]:
        """Classify one article.

        :return: ``categoryCode``, ``categoryName``, ``confidence`` (softmax probability
            of the predicted id) and ``probabilities`` (code -> probability).
        :raises ValueError: if :meth:`load_model` has not been called.
        """
        if self.model is None or self.tokenizer is None:
            raise ValueError("模型未加载,请先调用load_model")

        # Same "title [SEP] content" layout as before; it must match the fine-tuning input format.
        text = f"{title} [SEP] {content}"

        # No padding for a single sequence: padding to MAX_LENGTH only appends masked
        # positions (wasted compute) and leaves the logits unchanged up to float rounding.
        inputs = self.tokenizer(
            text,
            return_tensors='pt',
            truncation=True,
            max_length=self.MAX_LENGTH,
        )
        inputs = {name: tensor.to(self.device) for name, tensor in inputs.items()}

        with torch.inference_mode():
            logits = self.model(**inputs).logits

        probs = torch.softmax(logits, dim=-1)[0]
        predicted_id = int(torch.argmax(probs).item())
        prob_list = probs.cpu().tolist()
        code = ID_TO_LABEL[predicted_id]

        return {
            'categoryCode': code,
            'categoryName': CATEGORY_MAP.get(code, '未知'),
            'confidence': float(prob_list[predicted_id]),
            'probabilities': {ID_TO_LABEL[i]: float(p) for i, p in enumerate(prob_list)},
        }
# Dataset used for fine-tuning
class NewsDataset(torch.utils.data.Dataset):
    """Tokenizes texts lazily, one sample per ``__getitem__`` call.

    :param texts: sequence of input strings.
    :param labels: label ids aligned with ``texts`` (converted with
        ``torch.tensor(..., dtype=torch.long)``; presumably ``LABEL_TO_ID`` values —
        confirm with the training script).
    :param tokenizer: callable Hugging Face tokenizer.
    :param max_length: every sample is truncated/padded to exactly this many tokens.
    """

    def __init__(self, texts, labels, tokenizer, max_length=512):
        self.texts, self.labels = texts, labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text, label = self.texts[idx], self.labels[idx]
        encoded = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            padding='max_length',
            return_tensors='pt'
        )
        # return_tensors='pt' adds a batch dimension of 1; flatten it away so the
        # default collate function can stack samples into a batch.
        sample = {key: encoded[key].flatten() for key in ('input_ids', 'attention_mask')}
        sample['labels'] = torch.tensor(label, dtype=torch.long)
        return sample
if __name__ == '__main__':
    # Smoke test: only checks that the wrapper can be constructed.
    # For a real prediction, load a fine-tuned checkpoint first, e.g.:
    #   classifier.load_model('./models/deep_learning/bert_finetuned')
    #   print(classifier.predict(title="华为发布新款折叠屏手机",
    #                            content="华为今天正式发布了新一代折叠屏手机..."))
    classifier = BertClassifier()
    print("BERT分类器初始化成功")
任务 4.1.3: src/hybrid/hybrid_classifier.py - 混合分类器
"""
混合策略分类器
结合规则引擎和机器学习模型
"""
import time
from typing import Dict, Any
from ..traditional.predict import TraditionalPredictor
from ..deep_learning.bert_model import BertClassifier
class HybridClassifier:
"""混合分类器"""
def __init__(self):
# 初始化各个分类器
self.nb_predictor = TraditionalPredictor('nb')
self.bert_classifier = BertClassifier()
# 配置参数
self.config = {
'confidence_threshold': 0.75, # 高置信度阈值
'hybrid_min_confidence': 0.60, # 混合模式最低阈值
'use_bert_threshold': 0.70, # 使用BERT的阈值
'rule_priority': True # 规则优先
}
# 规则关键词字典
self.rule_keywords = {
'POLITICS': ['政府', '政策', '选举', '国务院', '主席', '总理'],
'FINANCE': ['股市', '经济', '金融', '投资', '基金', '银行'],
'TECHNOLOGY': ['芯片', 'AI', '人工智能', '5G', '互联网', '科技'],
'SPORTS': ['比赛', '冠军', '联赛', '球员', '教练', 'NBA'],
'ENTERTAINMENT': ['明星', '电影', '电视剧', '娱乐圈', '歌手'],
'HEALTH': ['健康', '医疗', '疾病', '治疗', '疫苗'],
'EDUCATION': ['教育', '学校', '大学', '考试', '招生'],
'LIFE': ['生活', '美食', '旅游', '购物'],
'INTERNATIONAL': ['国际', '美国', '欧洲', '日本', '外交'],
'MILITARY': ['军事', '武器', '军队', '国防', '战争']
}
def rule_match(self, title: str, content: str) -> tuple[str | None, float]:
"""
规则匹配
:return: (category_code, confidence)
"""
text = title + ' ' + content
# 计算每个类别的关键词匹配数
matches = {}
for category, keywords in self.rule_keywords.items():
count = sum(1 for kw in keywords if kw in text)
if count > 0:
matches[category] = count
if not matches:
return None, 0.0
# 返回匹配最多的类别
best_category = max(matches, key=matches.get)
confidence = min(0.9, matches[best_category] * 0.15) # 规则置信度
return best_category, confidence
def predict(self, title: str, content: str, use_bert=True) -> Dict[str, Any]:
"""
混合预测
"""
start_time = time.time()
# 1. 先尝试规则匹配
rule_category, rule_confidence = self.rule_match(title, content)
# 2. 传统机器学习预测
nb_result = self.nb_predictor.predict(title, content)
nb_confidence = nb_result['confidence']
# 决策逻辑
final_result = None
classifier_type = 'HYBRID'
# 规则优先且规则置信度高
if self.config['rule_priority'] and rule_confidence >= self.config['confidence_threshold']:
final_result = {
'categoryCode': rule_category,
'categoryName': nb_result['categoryName'], # 从映射获取
'confidence': rule_confidence,
'classifierType': 'RULE',
'reason': '规则匹配'
}
# 传统模型置信度足够高
elif nb_confidence >= self.config['confidence_threshold']:
final_result = {
**nb_result,
'classifierType': 'ML',
'reason': '传统模型高置信度'
}
# 需要使用BERT
elif use_bert:
# TODO: 加载BERT模型预测
# bert_result = self.bert_classifier.predict(title, content)
# 如果BERT置信度也不高,选择最高的
final_result = {
**nb_result,
'classifierType': 'HYBRID',
'reason': '混合决策'
}
else:
# 不使用BERT,直接返回传统模型结果
final_result = {
**nb_result,
'classifierType': 'ML',
'reason': '默认传统模型'
}
# 计算耗时
duration = int((time.time() - start_time) * 1000)
final_result['duration'] = duration
return final_result
if __name__ == '__main__':
    # Smoke test over two sample articles. Because of the relative imports at the
    # top of this file, run it as a module from the ml-module root:
    #   python -m src.hybrid.hybrid_classifier
    classifier = HybridClassifier()

    samples = [
        ('国务院发布最新经济政策', '国务院今天发布了新的经济政策...'),
        ('华为发布新款折叠屏手机', '华为今天正式发布了新一代折叠屏手机...'),
    ]

    for title, content in samples:
        result = classifier.predict(title, content)
        print(f"标题: {title}")
        print(f"结果: {result['categoryName']} ({result['confidence']:.2f})")
        print(f"分类器: {result['classifierType']}")
        print(f"原因: {result.get('reason', 'N/A')}")
        print(f"耗时: {result['duration']}ms")
        print("-" * 50)
任务 4.1.4: src/api/server.py - FastAPI服务
"""
机器学习模型API服务
使用FastAPI提供RESTful API
"""
import logging
import time
from typing import Optional

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
# Import the classifiers through the `src` package with the ml-module root on
# sys.path. Importing them as top-level `traditional` / `hybrid` packages (with
# `src/` on sys.path) breaks hybrid_classifier's relative imports
# (`from ..traditional...`), which need a parent package.
import sys
import os
_ML_MODULE_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
if _ML_MODULE_ROOT not in sys.path:
    sys.path.append(_ML_MODULE_ROOT)
from src.traditional.predict import TraditionalPredictor
from src.hybrid.hybrid_classifier import HybridClassifier

# Logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI application
app = FastAPI(
    title="新闻分类API",
    description="提供新闻文本分类服务",
    version="1.0.0"
)

# CORS: the service is meant to be called by the backend over HTTP and uses no
# cookies or other credentials. A wildcard origin combined with
# allow_credentials=True would let any website make credentialed requests.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Request body for /api/predict and each item of /api/batch-predict.
class ClassifyRequest(BaseModel):
    title: str
    content: str
    # Classification mode. The handlers only special-case 'traditional'; any other
    # value (including 'deep_learning' and None) is routed to the hybrid classifier.
    mode: Optional[str] = 'hybrid' # traditional, hybrid
# Response body for /api/predict; built from the dict returned by a classifier.
class ClassifyResponse(BaseModel):
    categoryCode: str
    categoryName: str
    confidence: float
    classifierType: str
    # Elapsed classification time in milliseconds.
    duration: int
    # Per-category probabilities; absent for rule-based (RULE) results.
    probabilities: Optional[dict] = None
# Classifier singletons, populated by the startup hook. Each stays None if its
# load fails, which /health reports.
nb_predictor = None
hybrid_classifier = None


@app.on_event("startup")
async def startup_event():
    """Load the models once at startup.

    A failure is logged with its traceback but does not stop the server.
    NOTE(review): ``on_event`` is deprecated in recent FastAPI in favour of the
    ``lifespan`` parameter of ``FastAPI(...)``. HybridClassifier also builds its own
    NB predictor, so the NB model is currently loaded twice.
    """
    global nb_predictor, hybrid_classifier
    logger.info("加载模型...")

    try:
        nb_predictor = TraditionalPredictor('nb')
        logger.info("朴素贝叶斯模型加载成功")
    except Exception as e:
        # logger.exception keeps the traceback that a plain error log drops.
        logger.exception("朴素贝叶斯模型加载失败: %s", e)

    try:
        hybrid_classifier = HybridClassifier()
        logger.info("混合分类器初始化成功")
    except Exception as e:
        logger.exception("混合分类器初始化失败: %s", e)
# Basic liveness endpoint: always answers while the process is up.
@app.get("/")
async def root():
    """健康检查"""
    return dict(status="ok", message="新闻分类API服务运行中")


# Reports which models finished loading at startup.
@app.get("/health")
async def health_check():
    """健康检查"""
    models = {
        "nb_loaded": nb_predictor is not None,
        "hybrid_loaded": hybrid_classifier is not None,
    }
    return {"status": "healthy", "models": models}
class ModelNotLoadedError(RuntimeError):
    """Raised when a request needs a model that failed to load at startup."""


def _classify(title: str, content: str, mode: Optional[str]) -> dict:
    """Run one classification and return the classifier's result dict.

    ``'traditional'`` uses the NB model directly. Every other mode (including
    ``'deep_learning'``, which has no dedicated pipeline yet) uses the hybrid classifier.
    The result always contains ``classifierType`` and ``duration`` (ms).

    :raises ModelNotLoadedError: if the required model is not available.
    """
    if mode == 'traditional':
        if nb_predictor is None:
            raise ModelNotLoadedError("朴素贝叶斯模型未加载")
        start = time.perf_counter()
        result = nb_predictor.predict(title, content)
        result['classifierType'] = 'ML'
        # TraditionalPredictor does not time itself, but ClassifyResponse requires
        # `duration`; without it every traditional request failed validation.
        result['duration'] = int((time.perf_counter() - start) * 1000)
        return result

    if hybrid_classifier is None:
        raise ModelNotLoadedError("混合分类器未加载")
    return hybrid_classifier.predict(title, content)


# Plain `def` (not `async def`): inference is CPU-bound and blocking, so FastAPI
# runs these handlers in its threadpool instead of stalling the event loop.
@app.post("/api/predict", response_model=ClassifyResponse)
def predict(request: ClassifyRequest):
    """
    文本分类接口
    - **title**: 新闻标题
    - **content**: 新闻内容
    - **mode**: 分类模式 (traditional, hybrid)
    """
    try:
        result = _classify(request.title, request.content, request.mode)
        return ClassifyResponse(**result)
    except ModelNotLoadedError as e:
        # Service-side unavailability, not a bad request.
        raise HTTPException(status_code=503, detail=str(e))
    except Exception as e:
        logger.exception("预测失败: %s", e)
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/api/batch-predict")
def batch_predict(requests: list[ClassifyRequest]):
    """
    批量分类接口
    """
    # A failure on one item is reported in place and does not abort the batch.
    results = []
    for req in requests:
        try:
            results.append(_classify(req.title, req.content, req.mode))
        except Exception as e:
            results.append({
                'error': str(e),
                'title': req.title
            })
    return {"results": results}
if __name__ == '__main__':
    import uvicorn

    # Port 5000 matches classifier.bert.service-url in the backend application.yml.
    server_options = dict(host="0.0.0.0", port=5000, log_level="info")
    uvicorn.run(app, **server_options)
任务 4.1.5: src/utils/metrics.py - 评估指标
"""
模型评估指标工具
"""
import numpy as np
from sklearn.metrics import (
accuracy_score,
precision_recall_fscore_support,
confusion_matrix,
classification_report
)
from typing import List, Dict, Any
import matplotlib.pyplot as plt
import seaborn as sns
class ClassificationMetrics:
    """Evaluation helpers for single-label classification results.

    In every method, ``labels`` is the authoritative class order: per-class
    metrics, confusion-matrix rows/columns and report rows all follow it.
    Unless ``labels`` is passed through, scikit-learn orders classes by sorted
    label value, which does not match an arbitrary ``labels`` list and silently
    attaches numbers to the wrong class names.
    """

    @staticmethod
    def compute_all(y_true: List, y_pred: List, labels: List[str]) -> Dict[str, Any]:
        """Compute accuracy, weighted averages and per-class metrics.

        :param y_true: ground-truth labels.
        :param y_pred: predicted labels, aligned with ``y_true``.
        :param labels: class names; defines the keys of ``per_class``.
        :return: ``accuracy``, support-weighted ``precision``/``recall``/``f1``
            over the classes present in the data, and ``per_class`` mapping each
            label to its precision/recall/f1/support.
        """
        accuracy = accuracy_score(y_true, y_pred)
        precision, recall, f1, _ = precision_recall_fscore_support(
            y_true, y_pred, average='weighted', zero_division=0
        )

        # labels=labels aligns the per-class arrays with `labels` (index i -> labels[i]).
        class_precision, class_recall, class_f1, class_support = precision_recall_fscore_support(
            y_true, y_pred, labels=labels, average=None, zero_division=0
        )
        per_class_metrics = {
            label: {
                'precision': float(class_precision[i]),
                'recall': float(class_recall[i]),
                'f1': float(class_f1[i]),
                'support': int(class_support[i]),
            }
            for i, label in enumerate(labels)
        }

        return {
            'accuracy': float(accuracy),
            'precision': float(precision),
            'recall': float(recall),
            'f1': float(f1),
            'per_class': per_class_metrics,
        }

    @staticmethod
    def plot_confusion_matrix(y_true: List, y_pred: List, labels: List[str], save_path: str = None):
        """Draw a confusion-matrix heatmap (rows: true labels, columns: predicted labels, both in ``labels`` order).

        The figure is saved to ``save_path`` when given and is always closed
        afterwards; nothing is displayed.
        NOTE(review): the Chinese axis titles need a CJK-capable matplotlib font,
        otherwise they render as empty boxes.
        """
        cm = confusion_matrix(y_true, y_pred, labels=labels)

        plt.figure(figsize=(10, 8))
        sns.heatmap(
            cm,
            annot=True,
            fmt='d',
            cmap='Blues',
            xticklabels=labels,
            yticklabels=labels
        )
        plt.xlabel('预测标签')
        plt.ylabel('真实标签')
        plt.title('混淆矩阵')

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')

        plt.close()

    @staticmethod
    def print_report(y_true: List, y_pred: List, labels: List[str]):
        """Print scikit-learn's text classification report, with rows in ``labels`` order."""
        report = classification_report(
            y_true, y_pred,
            labels=labels,
            target_names=labels,
            zero_division=0
        )
        print(report)
if __name__ == '__main__':
    # Smoke test on a tiny hand-made sample; the label order is not alphabetical.
    y_true = ['POLITICS', 'TECHNOLOGY', 'FINANCE', 'POLITICS', 'TECHNOLOGY']
    y_pred = ['POLITICS', 'TECHNOLOGY', 'FINANCE', 'TECHNOLOGY', 'TECHNOLOGY']
    labels = ['POLITICS', 'TECHNOLOGY', 'FINANCE']

    print(ClassificationMetrics().compute_all(y_true, y_pred, labels))
任务 4.1.6: requirements.txt - 依赖文件
# ML module dependencies
# NOTE(review): the code needs Python >= 3.10 (hybrid_classifier.py evaluates a
# `str | None` annotation at import time); pip cannot express that here, so
# declare it e.g. via python_requires in setup.py.
numpy>=1.24.0
pandas>=2.0.0
scikit-learn>=1.3.0
# Chinese word segmentation used by the TF-IDF preprocessing
jieba>=0.42.0
# Loads the pickled vectorizer/classifier files
joblib>=1.3.0
# Deep learning (BERT)
torch>=2.0.0
transformers>=4.30.0
# API service
fastapi>=0.100.0
uvicorn[standard]>=0.23.0
pydantic>=2.0.0
# Data visualization (confusion-matrix plots)
matplotlib>=3.7.0
seaborn>=0.12.0
# Utilities
python-dotenv>=1.0.0
pyyaml>=6.0
总结
开发顺序建议
1. 第一阶段:基础框架
   - 后端:数据库连接、实体类、基础配置
   - 前端:路由配置、状态管理、API封装
2. 第二阶段:核心功能
   - 爬虫模块(Python)
   - 传统机器学习分类器
   - 后端API接口
   - 前端新闻列表页面
3. 第三阶段:高级功能
   - BERT深度学习分类器
   - 混合策略分类器
   - 前端分类器对比页面
   - 统计图表
4. 第四阶段:完善优化
   - 用户认证
   - 数据可视化
   - 性能优化
   - 异常处理
关键注意事项
- 爬虫模块使用 Python,通过 RESTful API 与 Java 后端通信
- 分类器模块独立部署,提供 HTTP 接口供后端调用
- 前后端分离,使用 JWT 进行身份认证
- 数据库表结构已在 `schema.sql` 中定义,需严格遵守
- API 统一响应格式使用 `Result<T>` 包装