新闻文本分类系统 - 模块开发任务清单
本文档详细列出各模块需要完成的具体代码任务，任务划分参考现有工程结构。注意：爬虫模块使用 Python 实现，而非 Java。
目录
1. 爬虫模块 (Python)
2. 后端服务模块 (Spring Boot)
3. 前端桌面模块 (Tauri + Vue3)
模块目录结构
crawler-module/
├── src/
│ ├── __init__.py
│ ├── base/ # 基础爬虫框架
│ │ ├── __init__.py
│ │ ├── base_crawler.py # 爬虫基类
│ │ ├── http_client.py # HTTP客户端封装
│ │ └── proxy_pool.py # 代理池(可选)
│ ├── parsers/ # 解析器
│ │ ├── __init__.py
│ │ ├── base_parser.py # 解析器基类
│ │ ├── sina_parser.py # 新浪新闻解析器
│ │ ├── sohu_parser.py # 搜狐新闻解析器
│ │ └── ifeng_parser.py # 凤凰网解析器
│ ├── cleaners/ # 数据清洗
│ │ ├── __init__.py
│ │ ├── text_cleaner.py # 文本清洗
│ │ └── deduplicator.py # 去重处理
│ ├── storage/ # 存储层
│ │ ├── __init__.py
│ │ ├── database.py # 数据库操作
│ │ └── storage_factory.py # 存储工厂
│ ├── utils/ # 工具类
│ │ ├── __init__.py
│ │ ├── user_agent.py # User-Agent池
│ │ └── date_parser.py # 日期解析
│ └── crawler.py # 爬虫主入口
├── config/
│ ├── __init__.py
│ ├── settings.py # 配置文件
│ └── sources.json # 数据源配置
├── requirements.txt
└── main.py
1.1 需要完成的具体文件
任务 1.1.1: config/settings.py - 爬虫配置文件
"""
爬虫模块配置文件
"""
import os
from typing import List
class CrawlerConfig:
    """Crawler settings.

    Every setting can be overridden by an environment variable with the same
    name (e.g. ``DOWNLOAD_DELAY=0.5``); the literal defaults below apply
    otherwise. Values are read once, when this module is imported.
    """

    # Project root: two directory levels above this file (config/settings.py).
    BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

    # Database connection
    DB_HOST = os.getenv('DB_HOST', 'localhost')
    DB_PORT = int(os.getenv('DB_PORT', 3306))
    DB_NAME = os.getenv('DB_NAME', 'news_classifier')
    DB_USER = os.getenv('DB_USER', 'root')
    DB_PASSWORD = os.getenv('DB_PASSWORD', '')

    # Crawling behaviour
    CONCURRENT_REQUESTS = int(os.getenv('CONCURRENT_REQUESTS', 5))  # concurrent requests
    DOWNLOAD_DELAY = float(os.getenv('DOWNLOAD_DELAY', 1))  # base delay between requests, seconds
    REQUEST_TIMEOUT = int(os.getenv('REQUEST_TIMEOUT', 10))  # request timeout, seconds
    MAX_RETRIES = int(os.getenv('MAX_RETRIES', 3))  # max retry attempts

    # Complete browser User-Agent strings. Truncated UAs that stop after
    # "AppleWebKit/537.36" are not sent by any real browser.
    USER_AGENT_LIST = [
        'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 '
        '(KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36',
        'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 '
        '(KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36',
    ]

    # Optional proxies. PROXY_ENABLED accepts 1/true/yes/on;
    # PROXY_LIST is a comma-separated list of proxy URLs.
    PROXY_ENABLED = os.getenv('PROXY_ENABLED', 'false').strip().lower() in ('1', 'true', 'yes', 'on')
    PROXY_LIST: List[str] = [p.strip() for p in os.getenv('PROXY_LIST', '').split(',') if p.strip()]

    # Logging
    LOG_LEVEL = os.getenv('LOG_LEVEL', 'INFO')
    LOG_FILE = os.getenv('LOG_FILE', os.path.join(BASE_DIR, 'logs', 'crawler.log'))

    # Storage
    BATCH_INSERT_SIZE = int(os.getenv('BATCH_INSERT_SIZE', 100))  # batch insert size
任务 1.1.2: config/sources.json - 数据源配置
{
"sources": [
{
"name": "新浪新闻",
"code": "sina",
"enabled": true,
"base_url": "https://news.sina.com.cn",
"list_url_template": "https://news.sina.com.cn/{category}/index.shtml",
"categories": [
{"code": "POLITICS", "url_key": "pol"},
{"code": "FINANCE", "url_key": "finance"},
{"code": "TECHNOLOGY", "url_key": "tech"},
{"code": "SPORTS", "url_key": "sports"},
{"code": "ENTERTAINMENT", "url_key": "ent"}
],
"parser": "sina_parser.SinaNewsParser"
},
{
"name": "搜狐新闻",
"code": "sohu",
"enabled": true,
"base_url": "https://www.sohu.com",
"list_url_template": "https://www.sohu.com/{category}",
"categories": [
{"code": "POLITICS", "url_key": "politics"},
{"code": "FINANCE", "url_key": "business"},
{"code": "TECHNOLOGY", "url_key": "tech"},
{"code": "SPORTS", "url_key": "sports"}
],
"parser": "sohu_parser.SohuNewsParser"
}
]
}
任务 1.1.3: base/base_crawler.py - 爬虫基类
"""
爬虫基类
"""
from abc import ABC, abstractmethod
from typing import List, Dict, Any
import logging
import time
import random
logger = logging.getLogger(__name__)


class BaseCrawler(ABC):
    """Abstract base for site-specific news crawlers.

    Subclasses implement :meth:`fetch_news_list` (list page -> article URLs)
    and :meth:`fetch_news_detail` (article URL -> record); :meth:`crawl`
    drives both across categories and pages.
    """

    def __init__(self, config: Dict[str, Any]):
        """
        :param config: crawler settings mapping. ``crawl`` reads
            ``DOWNLOAD_DELAY`` (seconds, default 1) from it to pace requests.
        """
        self.config = config
        self.name = self.__class__.__name__
        logger.info(f"初始化爬虫: {self.name}")

    @abstractmethod
    def fetch_news_list(self, category: str, page: int = 1) -> List[str]:
        """
        Fetch one list page.

        :param category: news category
        :param page: 1-based page number
        :return: article URLs found on the page (empty when there are no more pages)
        """
        pass

    @abstractmethod
    def fetch_news_detail(self, url: str) -> Dict[str, Any]:
        """
        Fetch and parse one article.

        :param url: article URL
        :return: article record; a falsy value means "skip this article"
        """
        pass

    def _pause(self) -> None:
        """Sleep a random interval in [DOWNLOAD_DELAY, 3 * DOWNLOAD_DELAY] seconds."""
        delay = float(self.config.get('DOWNLOAD_DELAY', 1))
        time.sleep(random.uniform(delay, delay * 3))

    def crawl(self, categories: List[str], max_pages: int = 5) -> List[Dict[str, Any]]:
        """
        Run the crawl.

        Failures on a single article or a single list page are logged and
        skipped so one bad page does not abort the whole run. Paging for a
        category stops early once a list page yields no URLs.

        :param categories: categories to crawl
        :param max_pages: maximum number of list pages per category
        :return: all successfully parsed article records
        """
        results = []
        for category in categories:
            for page in range(1, max_pages + 1):
                try:
                    news_urls = self.fetch_news_list(category, page)
                    if not news_urls:
                        # Past the last page: further pages would only waste
                        # requests (and delay time).
                        break
                    for news_url in news_urls:
                        try:
                            detail = self.fetch_news_detail(news_url)
                            if detail:
                                results.append(detail)
                        except Exception as e:
                            logger.error(f"解析新闻详情失败: {news_url}, 错误: {e}")
                    # Random delay so requests are not sent too quickly
                    self._pause()
                except Exception as e:
                    logger.error(f"爬取失败: category={category}, page={page}, 错误: {e}")
        return results
任务 1.1.4: parsers/base_parser.py - 解析器基类
"""
解析器基类
"""
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional
from datetime import datetime
class BaseParser(ABC):
    """Base class for news page parsers (list pages and article pages)."""

    @abstractmethod
    def parse_news_list(self, html: str) -> list[str]:
        """
        Parse a news list page.

        :param html: page HTML
        :return: article URLs
        """
        pass

    @abstractmethod
    def parse_news_detail(self, html: str, url: str) -> Optional[Dict[str, Any]]:
        """
        Parse an article page.

        :param html: page HTML
        :param url: article URL
        :return: parsed article dict, or None when nothing could be parsed
        """
        pass

    def clean_html(self, html: str) -> str:
        """Strip HTML tags, joining text nodes with single spaces."""
        from bs4 import BeautifulSoup
        soup = BeautifulSoup(html, 'html.parser')
        return soup.get_text(separator=' ', strip=True)

    def parse_publish_time(self, time_str: Optional[str]) -> Optional[datetime]:
        """
        Extract the first date/time found in *time_str*.

        Accepts ``2024年01月15日 10:30``, ``2024-01-15 10:30:45``,
        ``2024/1/5`` and ``2024.01.15T10:30`` style values; surrounding text
        such as ``发布时间:`` is ignored. Returns a naive datetime, or None
        when no (valid) date is present.
        """
        import re

        if not time_str:
            return None
        match = re.search(
            r'(?<!\d)(\d{4})[年/.\-](\d{1,2})[月/.\-](\d{1,2})日?'
            r'(?:[\sT]*(\d{1,2}):(\d{2})(?::(\d{2}))?)?',
            time_str,
        )
        if not match:
            return None
        year, month, day, hour, minute, second = match.groups()
        try:
            return datetime(
                int(year), int(month), int(day),
                int(hour or 0), int(minute or 0), int(second or 0),
            )
        except ValueError:
            # Matched the shape but not a real date, e.g. 2024-13-45
            return None
任务 1.1.5: parsers/sina_parser.py - 新浪新闻解析器
"""
新浪新闻解析器
"""
from typing import Dict, Any, Optional
from bs4 import BeautifulSoup
import requests
from .base_parser import BaseParser
class SinaNewsParser(BaseParser):
    """Parser for Sina news list and article pages.

    NOTE(review): the CSS selectors below are placeholders for Sina's markup;
    verify them against the live pages.
    """

    def __init__(self):
        # Base used to resolve relative / protocol-relative links on list pages
        self.base_url = "https://news.sina.com.cn"

    def parse_news_list(self, html: str) -> list[str]:
        """
        Return unique, absolute http(s) article URLs from a list page, in page order.

        :param html: list page HTML
        """
        from urllib.parse import urljoin, urlparse

        soup = BeautifulSoup(html, 'html.parser')
        urls: list[str] = []
        seen = set()
        for item in soup.select('.news-item'):
            link = item.select_one('a')
            href = link.get('href') if link else None
            if not href:
                continue
            # hrefs may be relative ("/c/...") or protocol-relative
            # ("//news.sina.com.cn/..."); make them requestable.
            absolute = urljoin(self.base_url + '/', href.strip())
            if urlparse(absolute).scheme not in ('http', 'https'):
                continue  # javascript:, mailto:, ...
            if absolute not in seen:
                seen.add(absolute)
                urls.append(absolute)
        return urls

    def parse_news_detail(self, html: str, url: str) -> Optional[Dict[str, Any]]:
        """
        Parse a Sina article page.

        :param html: article HTML
        :param url: article URL (stored as ``source_url``)
        :return: article dict, or None when neither title nor body was found
        """
        soup = BeautifulSoup(html, 'html.parser')

        def text_of(selector: str, default):
            node = soup.select_one(selector)
            return node.get_text(strip=True) if node else default

        title = text_of('h1.main-title', '')
        content_node = soup.select_one('.article-content')
        content = self.clean_html(str(content_node)) if content_node else ''
        if not title and not content:
            # Not an article page (or the markup changed): report "nothing
            # parsed" rather than emitting a record of empty strings.
            return None

        source = text_of('.source', '新浪新闻')
        publish_time = text_of('.date', None)
        author = text_of('.author', '')

        return {
            'title': title,
            'content': content,
            'summary': content[:200] if content else '',
            'source': source,
            'source_url': url,
            'author': author,
            'publish_time': self.parse_publish_time(publish_time) if publish_time else None
        }
任务 1.1.6: cleaners/text_cleaner.py - 文本清洗
"""
文本清洗工具
"""
import re
from typing import List
class TextCleaner:
    """Stateless text normalisation helpers for crawled news."""

    # Control characters to strip: C0 controls except \t, \n and \r, plus DEL and C1.
    INVALID_CHARS = r'[\x00-\x08\x0b-\x0c\x0e-\x1f\x7f-\x9f]'

    # Chinese stop words (not referenced by any method of this class).
    STOP_WORDS = {
        '的', '了', '在', '是', '我', '有', '和', '就', '不', '人',
        '都', '一', '一个', '上', '也', '很', '到', '说', '要', '去'
    }

    @classmethod
    def clean_title(cls, title: str) -> str:
        """Remove control characters and collapse runs of whitespace in a title."""
        if not title:
            return ''
        title = re.sub(cls.INVALID_CHARS, '', title)
        title = ' '.join(title.split())
        return title.strip()

    @classmethod
    def clean_content(cls, content: str, min_sentence_length: int = 10) -> str:
        """
        Clean article body text.

        Strips HTML tags and control characters, collapses whitespace, then
        drops '。'-terminated sentences of at most *min_sentence_length*
        characters. Kept sentences retain their '。' terminator.

        :param content: raw body text (may contain HTML)
        :param min_sentence_length: sentences this short or shorter are dropped
        """
        if not content:
            return ''
        content = re.sub(r'<[^>]+>', '', content)
        content = re.sub(cls.INVALID_CHARS, '', content)
        content = ' '.join(content.split())

        pieces = content.split('。')
        last_index = len(pieces) - 1
        kept = []
        for index, piece in enumerate(pieces):
            piece = piece.strip()
            if len(piece) <= min_sentence_length:
                continue
            # Every piece except the last was followed by '。' in the input.
            kept.append(piece + '。' if index < last_index else piece)
        return ''.join(kept)

    @classmethod
    def extract_summary(cls, content: str, max_length: int = 200) -> str:
        """
        Build a summary of at most *max_length* characters (plus '...').

        Content that already fits is returned unchanged. Otherwise the text is
        cut at the last '。' when that falls within the final 30% of the
        window, else hard-cut and suffixed with '...'.
        """
        if not content:
            return ''
        if len(content) <= max_length:
            return content
        summary = content[:max_length]
        last_period = summary.rfind('。')
        if last_period > max_length * 0.7:
            return summary[:last_period + 1]
        return summary + '...'
任务 1.1.7: cleaners/deduplicator.py - 去重处理
"""
新闻去重处理
"""
import hashlib
from typing import Set, Dict, Any
class NewsDeduplicator:
    """In-memory news de-duplication by source URL and by content fingerprint.

    State lives only for the lifetime of the instance.
    """

    def __init__(self):
        self.seen_hashes: Set[str] = set()
        self.seen_urls: Set[str] = set()

    def compute_hash(self, title: str, content: str) -> str:
        """MD5 fingerprint of the title plus the first 500 characters of content.

        ``None`` for either argument is treated as an empty string.
        """
        text = f"{title or ''}|{(content or '')[:500]}"
        return hashlib.md5(text.encode('utf-8')).hexdigest()

    def is_duplicate(self, news: Dict[str, Any]) -> bool:
        """
        Return True if *news* was already seen; otherwise remember it.

        Items without a ``source_url`` are de-duplicated by content only, so
        several URL-less items never collide with each other on URL.

        :param news: news dict (``source_url``, ``title``, ``content``)
        :return: True when the item is a duplicate
        """
        url = news.get('source_url') or ''
        if url and url in self.seen_urls:
            return True

        content_hash = self.compute_hash(news.get('title'), news.get('content'))
        if content_hash in self.seen_hashes:
            return True

        if url:
            self.seen_urls.add(url)
        self.seen_hashes.add(content_hash)
        return False

    def deduplicate_batch(self, news_list: list[Dict[str, Any]]) -> list[Dict[str, Any]]:
        """Return the items of *news_list* not seen before, preserving order."""
        return [news for news in news_list if not self.is_duplicate(news)]
任务 1.1.8: storage/database.py - 数据库存储
"""
数据库存储层
"""
import pymysql
from typing import List, Dict, Any, Optional
from datetime import datetime
import logging
logger = logging.getLogger(__name__)


class NewsStorage:
    """Persist crawled news into the MySQL ``news`` table via PyMySQL.

    The connection is opened in ``__init__``. Use the instance as a context
    manager so the connection is released even if crawling fails::

        with NewsStorage(config) as storage:
            storage.save_news(news, 'SPORTS')
    """

    def __init__(self, config: Dict[str, Any]):
        """
        :param config: mapping with optional DB_HOST / DB_PORT / DB_USER /
            DB_PASSWORD / DB_NAME keys
        :raises Exception: whatever ``pymysql.connect`` raises when the
            initial connection fails
        """
        self.config = config
        self.connection = None
        self.connect()

    def __enter__(self) -> 'NewsStorage':
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> bool:
        self.close()
        return False  # never swallow the caller's exception

    def connect(self):
        """Open a connection (utf8mb4, rows returned as dicts); re-raises on failure."""
        try:
            self.connection = pymysql.connect(
                host=self.config.get('DB_HOST', 'localhost'),
                # int() so a port supplied as a string (env/JSON) still works
                port=int(self.config.get('DB_PORT', 3306)),
                user=self.config.get('DB_USER', 'root'),
                password=self.config.get('DB_PASSWORD', ''),
                database=self.config.get('DB_NAME', 'news_classifier'),
                charset='utf8mb4',
                cursorclass=pymysql.cursors.DictCursor
            )
            logger.info("数据库连接成功")
        except Exception as e:
            logger.error(f"数据库连接失败: {e}")
            raise

    def close(self):
        """Close the connection. Safe to call more than once."""
        if self.connection is not None:
            try:
                self.connection.close()
            finally:
                self.connection = None

    def _ensure_connection(self) -> None:
        """(Re)connect if needed, e.g. after the server dropped an idle connection."""
        if self.connection is None:
            self.connect()
        else:
            self.connection.ping(reconnect=True)

    def _safe_rollback(self) -> None:
        """Roll back, but never let a failing rollback escape an error handler."""
        if self.connection is None:
            return
        try:
            self.connection.rollback()
        except Exception:
            logger.warning("回滚失败", exc_info=True)

    def save_news(self, news: Dict[str, Any], category_code: str) -> Optional[int]:
        """
        Insert one news row and commit it.

        :param news: news dict (title, content, summary, source, source_url,
            author, publish_time)
        :param category_code: category code stored with the row
        :return: the inserted row id, or None if the insert failed
        """
        sql = """
        INSERT INTO news (title, content, summary, source, source_url,
        author, category_code, publish_time, status)
        VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
        """
        try:
            self._ensure_connection()
            with self.connection.cursor() as cursor:
                cursor.execute(sql, (
                    news.get('title'),
                    news.get('content'),
                    news.get('summary'),
                    news.get('source'),
                    news.get('source_url'),
                    news.get('author'),
                    category_code,
                    news.get('publish_time'),
                    1  # status: published
                ))
                new_id = cursor.lastrowid
            self.connection.commit()
            return new_id
        except Exception as e:
            logger.exception(f"保存新闻失败: {e}")
            self._safe_rollback()
            return None

    def batch_save_news(self, news_list: List[Dict[str, Any]], category_code: str) -> int:
        """
        Save rows one by one (each committed separately, so one bad row does
        not discard the others).

        :return: number of rows saved successfully
        """
        return sum(1 for news in news_list if self.save_news(news, category_code) is not None)

    def news_exists(self, source_url: str) -> bool:
        """Return True if a row with *source_url* exists; False on lookup errors."""
        try:
            self._ensure_connection()
            with self.connection.cursor() as cursor:
                sql = "SELECT id FROM news WHERE source_url = %s LIMIT 1"
                cursor.execute(sql, (source_url,))
                return cursor.fetchone() is not None
        except Exception as e:
            logger.error(f"检查新闻存在性失败: {e}")
            return False
任务 1.1.9: crawler.py - 爬虫主入口
"""
爬虫主程序入口
"""
import argparse
import logging
from config.settings import CrawlerConfig
from storage.database import NewsStorage
from cleaners.text_cleaner import TextCleaner
from cleaners.deduplicator import NewsDeduplicator
from parsers.sina_parser import SinaNewsParser
from parsers.sohu_parser import SohuNewsParser
def setup_logging(log_file=None, level=None):
    """
    Configure root logging to write to a file and to the console.

    :param log_file: log file path; defaults to ``CrawlerConfig.LOG_FILE``.
        Its parent directory is created when missing, because
        ``logging.FileHandler`` raises FileNotFoundError otherwise.
    :param level: level name such as ``'INFO'``; defaults to
        ``CrawlerConfig.LOG_LEVEL``. Unknown names fall back to INFO.
    """
    import os

    log_file = log_file or CrawlerConfig.LOG_FILE
    level_name = str(level or CrawlerConfig.LOG_LEVEL).upper()
    numeric_level = logging.getLevelName(level_name)
    if not isinstance(numeric_level, int):
        numeric_level = logging.INFO

    log_dir = os.path.dirname(log_file)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    logging.basicConfig(
        level=numeric_level,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(log_file, encoding='utf-8'),
            logging.StreamHandler()
        ]
    )
def main():
    """CLI entry point: parse arguments, validate the source, then crawl.

    The source is validated before any database connection is opened, and
    the connection is always closed on exit.
    """
    parser = argparse.ArgumentParser(description='新闻爬虫')
    parser.add_argument('--source', type=str, help='数据源代码')
    parser.add_argument('--category', type=str, help='新闻类别')
    parser.add_argument('--pages', type=int, default=5, help='爬取页数')
    args = parser.parse_args()

    setup_logging()
    logger = logging.getLogger(__name__)

    # Map source codes to parser classes; only the selected one is instantiated.
    parser_classes = {
        'sina': SinaNewsParser,
        'sohu': SohuNewsParser
    }
    parser_cls = parser_classes.get(args.source)
    if parser_cls is None:
        logger.error(f"不支持的数据源: {args.source}")
        return
    selected_parser = parser_cls()

    storage = NewsStorage(CrawlerConfig.__dict__)
    try:
        # Components for the crawl loop below (still to be implemented).
        cleaner = TextCleaner()
        deduplicator = NewsDeduplicator()

        logger.info(f"开始爬取: source={args.source}, category={args.category}, pages={args.pages}")
        # TODO: fetch with selected_parser, clean with cleaner, filter with
        # deduplicator, then persist via storage.
        logger.info("爬取完成")
    finally:
        storage.close()


if __name__ == '__main__':
    main()
任务 1.1.10: requirements.txt - 依赖文件
# 爬虫模块依赖
requests>=2.31.0
beautifulsoup4>=4.12.0
lxml>=4.9.0
pymysql>=1.1.0
python-dotenv>=1.0.0
2. 后端服务模块 (Spring Boot)
模块目录结构
backend/src/main/java/com/newsclassifier/
├── controller/ # 控制器层
│ ├── AuthController.java
│ ├── NewsController.java
│ ├── CategoryController.java
│ ├── ClassifierController.java
│ └── AdminController.java
├── service/ # 服务层
│ ├── AuthService.java
│ ├── NewsService.java
│ ├── CategoryService.java
│ ├── ClassifierService.java
│ └── impl/
│ ├── AuthServiceImpl.java
│ ├── NewsServiceImpl.java
│ └── ClassifierServiceImpl.java
├── mapper/ # MyBatis Mapper
│ ├── UserMapper.java
│ ├── NewsMapper.java
│ └── CategoryMapper.java
├── entity/ # 实体类
│ ├── User.java (已完成)
│ ├── News.java (已完成)
│ └── NewsCategory.java
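├── dto/ # 数据传输对象
│ ├── LoginDTO.java
│ ├── RegisterDTO.java
│ ├── NewsQueryDTO.java
│ ├── ClassifyRequestDTO.java
│ ├── BatchClassifyRequestDTO.java
│ ├── ClassificationResultDTO.java
│ ├── BatchClassifyResultDTO.java
│ ├── ClassifierStatusDTO.java
│ └── PageResult.java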
├── vo/ # 视图对象
│ ├── UserVO.java
│ ├── NewsVO.java
│ └── CategoryVO.java
├── common/ # 公共类
│ ├── Result.java (已完成)
│ ├── PageRequest.java
│ └── PageResponse.java
├── config/ # 配置类
│ ├── SecurityConfig.java
│ ├── CorsConfig.java
│ ├── MyBatisConfig.java
│ └── AsyncConfig.java (已完成)
├── security/ # 安全认证
│ ├── JwtTokenProvider.java
│ ├── JwtAuthenticationFilter.java
│ ├── UserDetailsServiceImpl.java
│ └── PasswordEncoder.java
├── classifier/ # 文本分类器
│ ├── IClassifier.java (已完成)
│ ├── ClassificationResult.java
│ ├── TraditionalMLClassifier.java (已完成)
│ ├── BERTClassifier.java
│ └── HybridClassifier.java (已完成)
├── exception/ # 异常处理
│ ├── GlobalExceptionHandler.java
│ ├── BusinessException.java
│ └── ErrorCode.java
├── util/ # 工具类
│ ├── JwtUtil.java
│ ├── DateUtil.java
│ └── ValidationUtil.java
└── NewsClassifierApplication.java (已完成)
2.1 需要完成的具体文件
任务 2.1.1: security/JwtTokenProvider.java - JWT令牌提供者
package com.newsclassifier.security;
import io.jsonwebtoken.*;
import io.jsonwebtoken.security.Keys;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.security.core.Authentication;
import org.springframework.stereotype.Component;
import javax.crypto.SecretKey;
import java.util.Date;
/**
* JWT令牌提供者
*/
@Component
public class JwtTokenProvider {
@Value("${jwt.secret}")
private String jwtSecret;
@Value("${jwt.expiration:86400000}") // 默认24小时
private long jwtExpiration;
private SecretKey getSigningKey() {
return Keys.hmacShaKeyFor(jwtSecret.getBytes());
}
/**
* 生成JWT令牌
*/
public String generateToken(Authentication authentication) {
String username = authentication.getName();
Date now = new Date();
Date expiryDate = new Date(now.getTime() + jwtExpiration);
return Jwts.builder()
.subject(username)
.issuedAt(now)
.expiration(expiryDate)
.signWith(getSigningKey())
.compact();
}
/**
* 从令牌获取用户名
*/
public String getUsernameFromToken(String token) {
Claims claims = Jwts.parser()
.verifyWith(getSigningKey())
.build()
.parseSignedClaims(token)
.getPayload();
return claims.getSubject();
}
/**
* 验证令牌
*/
public boolean validateToken(String token) {
try {
Jwts.parser()
.verifyWith(getSigningKey())
.build()
.parseSignedClaims(token);
return true;
} catch (JwtException ex) {
// 日志记录
}
return false;
}
}
任务 2.1.2: config/SecurityConfig.java - 安全配置
package com.newsclassifier.config;
import com.newsclassifier.security.JwtAuthenticationFilter;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
import org.springframework.security.config.http.SessionCreationPolicy;
import org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder;
import org.springframework.security.crypto.password.PasswordEncoder;
import org.springframework.security.web.SecurityFilterChain;
import org.springframework.security.web.authentication.UsernamePasswordAuthenticationFilter;
/**
* Spring Security配置
*/
@Configuration
@EnableWebSecurity
public class SecurityConfig {
private final JwtAuthenticationFilter jwtAuthenticationFilter;
public SecurityConfig(JwtAuthenticationFilter jwtAuthenticationFilter) {
this.jwtAuthenticationFilter = jwtAuthenticationFilter;
}
@Bean
public SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception {
http
.csrf(csrf -> csrf.disable())
.sessionManagement(session ->
session.sessionCreationPolicy(SessionCreationPolicy.STATELESS)
)
.authorizeHttpRequests(auth -> auth
.requestMatchers("/api/auth/**").permitAll()
.requestMatchers("/api/doc.html", "/api/swagger/**").permitAll()
.anyRequest().authenticated()
)
.addFilterBefore(jwtAuthenticationFilter, UsernamePasswordAuthenticationFilter.class);
return http.build();
}
@Bean
public PasswordEncoder passwordEncoder() {
return new BCryptPasswordEncoder();
}
}
任务 2.1.3: controller/AuthController.java - 认证控制器
package com.newsclassifier.controller;
import com.newsclassifier.common.Result;
import com.newsclassifier.dto.LoginDTO;
import com.newsclassifier.dto.RegisterDTO;
import com.newsclassifier.service.AuthService;
import com.newsclassifier.vo.UserVO;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.validation.Valid;
import lombok.RequiredArgsConstructor;
import org.springframework.web.bind.annotation.*;
/**
 * Authentication endpoints: login, registration, token refresh and logout.
 * All paths live under /api/auth.
 */
@Tag(name = "认证接口")
@RestController
@RequestMapping("/api/auth")
@RequiredArgsConstructor
public class AuthController {

    private final AuthService authService;

    /** Authenticates the user and returns the user view object produced by AuthService. */
    @Operation(summary = "用户登录")
    @PostMapping("/login")
    public Result<UserVO> login(@Valid @RequestBody LoginDTO loginDTO) {
        return Result.success(authService.login(loginDTO));
    }

    /** Registers a new account; the response carries no data. */
    @Operation(summary = "用户注册")
    @PostMapping("/register")
    public Result<Void> register(@Valid @RequestBody RegisterDTO registerDTO) {
        authService.register(registerDTO);
        return Result.success();
    }

    /**
     * Exchanges the current token for a new one.
     * NOTE(review): the raw Authorization header is forwarded unchanged, so it may
     * still carry the "Bearer " prefix; confirm that AuthService strips it.
     */
    @Operation(summary = "刷新令牌")
    @PostMapping("/refresh")
    public Result<String> refreshToken(@RequestHeader("Authorization") String token) {
        return Result.success(authService.refreshToken(token));
    }

    /** Logout does no server-side work here; the client discards its token. */
    @Operation(summary = "用户登出")
    @PostMapping("/logout")
    public Result<Void> logout() {
        return Result.success();
    }
}
任务 2.1.4: controller/NewsController.java - 新闻控制器
package com.newsclassifier.controller;
import com.newsclassifier.common.PageResponse;
import com.newsclassifier.common.Result;
import com.newsclassifier.dto.NewsQueryDTO;
import com.newsclassifier.service.NewsService;
import com.newsclassifier.vo.NewsVO;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.tags.Tag;
import lombok.RequiredArgsConstructor;
import org.springframework.web.bind.annotation.*;
/**
 * News endpoints: paged listing, detail, keyword search and manual classification.
 */
@Tag(name = "新闻接口")
@RestController
@RequestMapping("/api/news")
@RequiredArgsConstructor
public class NewsController {

    /** Upper bound for page size on search, so one request cannot pull the whole table. */
    private static final int MAX_PAGE_SIZE = 100;

    private final NewsService newsService;

    /** Paged listing; filters are bound from query parameters onto NewsQueryDTO. */
    @Operation(summary = "分页查询新闻")
    @GetMapping("/page")
    public Result<PageResponse<NewsVO>> getNewsPage(NewsQueryDTO queryDTO) {
        PageResponse<NewsVO> page = newsService.getNewsPage(queryDTO);
        return Result.success(page);
    }

    @Operation(summary = "获取新闻详情")
    @GetMapping("/{id}")
    public Result<NewsVO> getNewsDetail(@PathVariable Long id) {
        NewsVO newsVO = newsService.getNewsDetail(id);
        return Result.success(newsVO);
    }

    /**
     * Keyword search.
     * page is clamped to at least 1; size is clamped to [1, MAX_PAGE_SIZE].
     */
    @Operation(summary = "搜索新闻")
    @GetMapping("/search")
    public Result<PageResponse<NewsVO>> searchNews(
            @RequestParam String keyword,
            @RequestParam(defaultValue = "1") Integer page,
            @RequestParam(defaultValue = "20") Integer size
    ) {
        int safePage = Math.max(page, 1);
        int safeSize = Math.min(Math.max(size, 1), MAX_PAGE_SIZE);
        PageResponse<NewsVO> result = newsService.searchNews(keyword, safePage, safeSize);
        return Result.success(result);
    }

    /** Manually assigns a news item to a category. */
    @Operation(summary = "手动分类新闻")
    @PostMapping("/{id}/classify")
    public Result<Void> manualClassify(
            @PathVariable Long id,
            @RequestParam Long categoryId
    ) {
        newsService.manualClassify(id, categoryId);
        return Result.success();
    }
}
任务 2.1.5: controller/ClassifierController.java - 分类控制器
package com.newsclassifier.controller;

import com.newsclassifier.common.Result;
import com.newsclassifier.dto.BatchClassifyRequestDTO;
import com.newsclassifier.dto.BatchClassifyResultDTO;
import com.newsclassifier.dto.ClassificationResultDTO;
import com.newsclassifier.dto.ClassifierStatusDTO;
import com.newsclassifier.dto.ClassifyRequestDTO;
import com.newsclassifier.service.ClassifierService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.tags.Tag;
import lombok.RequiredArgsConstructor;
import org.springframework.web.bind.annotation.*;

/**
 * Text classification endpoints.
 * NOTE(review): the batch/status DTOs are assumed to live in com.newsclassifier.dto
 * alongside ClassifyRequestDTO; create them there.
 */
@Tag(name = "分类接口")
@RestController
@RequestMapping("/api/classifier")
@RequiredArgsConstructor
public class ClassifierController {

    private final ClassifierService classifierService;

    /** Classifies one piece of text (title + content) with the requested mode. */
    @Operation(summary = "对单条新闻进行分类")
    @PostMapping("/classify")
    public Result<ClassificationResultDTO> classify(@RequestBody ClassifyRequestDTO request) {
        ClassificationResultDTO result = classifierService.classify(
                request.getTitle(),
                request.getContent(),
                request.getMode()
        );
        return Result.success(result);
    }

    /** Classifies stored news items by id. */
    @Operation(summary = "批量分类")
    @PostMapping("/batch-classify")
    public Result<BatchClassifyResultDTO> batchClassify(
            @RequestBody BatchClassifyRequestDTO request
    ) {
        BatchClassifyResultDTO result = classifierService.batchClassify(
                request.getNewsIds(),
                request.getMode()
        );
        return Result.success(result);
    }

    /** Reports classifier status. */
    @Operation(summary = "获取分类器状态")
    @GetMapping("/status")
    public Result<ClassifierStatusDTO> getStatus() {
        ClassifierStatusDTO status = classifierService.getStatus();
        return Result.success(status);
    }
}
任务 2.1.6: service/impl/NewsServiceImpl.java - 新闻服务实现
package com.newsclassifier.service.impl;
import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.newsclassifier.common.PageResponse;
import com.newsclassifier.dto.NewsQueryDTO;
import com.newsclassifier.entity.News;
import com.newsclassifier.mapper.NewsMapper;
import com.newsclassifier.service.NewsService;
import com.newsclassifier.vo.NewsVO;
import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Service;
/**
* 新闻服务实现
*/
@Service
@RequiredArgsConstructor
public class NewsServiceImpl implements NewsService {
private final NewsMapper newsMapper;
@Override
public PageResponse<NewsVO> getNewsPage(NewsQueryDTO queryDTO) {
Page<News> page = new Page<>(queryDTO.getPage(), queryDTO.getSize());
LambdaQueryWrapper<News> wrapper = new LambdaQueryWrapper<News>()
.eq(queryDTO.getCategoryId() != null, News::getCategoryId, queryDTO.getCategoryId())
.eq(queryDTO.getCategoryCode() != null, News::getCategoryCode, queryDTO.getCategoryCode())
.eq(queryDTO.getStatus() != null, News::getStatus, queryDTO.getStatus())
.like(queryDTO.getKeyword() != null, News::getTitle, queryDTO.getKeyword())
.orderByDesc(News::getPublishTime);
Page<News> resultPage = newsMapper.selectPage(page, wrapper);
// 转换为VO
List<NewsVO> voList = resultPage.getRecords().stream()
.map(this::convertToVO)
.collect(Collectors.toList());
return PageResponse.of(resultPage.getTotal(), voList);
}
@Override
public NewsVO getNewsDetail(Long id) {
News news = newsMapper.selectById(id);
if (news == null) {
throw new BusinessException(ErrorCode.NEWS_NOT_FOUND);
}
// 增加浏览次数
newsMapper.addViewCount(id);
return convertToVO(news);
}
private NewsVO convertToVO(News news) {
// 实现Entity到VO的转换
return new NewsVO();
}
}
任务 2.1.7: classifier/ClassificationResult.java - 分类结果类
package com.newsclassifier.classifier;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.math.BigDecimal;
/**
 * Classification outcome for one news item, returned by IClassifier#classify.
 *
 * <p>NOTE: field order matters. Lombok's {@code @AllArgsConstructor} generates its
 * parameters in declaration order, so reordering fields changes that constructor.
 */
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class ClassificationResult {
    /**
     * Category code, e.g. "TECHNOLOGY".
     */
    private String categoryCode;
    /**
     * Human-readable category name, e.g. "科技".
     */
    private String categoryName;
    /**
     * Confidence in the range 0-1.
     */
    private BigDecimal confidence;
    /**
     * Type of classifier that produced the result (e.g. "DL" for BERTClassifier).
     */
    private String classifierType;
    /**
     * Probability per category; keys presumably are category codes (confirm with producers).
     */
    private java.util.Map<String, BigDecimal> probabilities;
    /**
     * Elapsed classification time in milliseconds.
     */
    private Long duration;
}
任务 2.1.8: classifier/BERTClassifier.java - BERT分类器
package com.newsclassifier.classifier;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.math.BigDecimal;
import java.util.HashMap;
import java.util.Map;
/**
* BERT文本分类器
* 通过调用Python服务实现
*/
@Slf4j
@Component
public class BERTClassifier implements IClassifier {
@Value("${classifier.bert.service-url:http://localhost:5000/api/predict}")
private String bertServiceUrl;
@Value("${classifier.bert.timeout:5000}")
private int timeout;
@Override
public ClassificationResult classify(String title, String content) {
long startTime = System.currentTimeMillis();
try {
// 调用Python BERT服务
String result = callBERTService(title, content);
// 解析结果
return parseResult(result);
} catch (Exception e) {
log.error("BERT分类失败", e);
// 返回默认结果或降级处理
return getDefaultResult();
} finally {
long duration = System.currentTimeMillis() - startTime;
log.info("BERT分类耗时: {}ms", duration);
}
}
@Override
public String getType() {
return "DL";
}
private String callBERTService(String title, String content) {
// 使用HttpClient调用Python服务
// 实现HTTP请求逻辑
return "";
}
private ClassificationResult parseResult(String jsonResponse) {
// 解析JSON响应
return ClassificationResult.builder()
.categoryCode("TECHNOLOGY")
.categoryName("科技")
.confidence(new BigDecimal("0.95"))
.classifierType("DL")
.build();
}
}
任务 2.1.9: exception/GlobalExceptionHandler.java - 全局异常处理
package com.newsclassifier.exception;
import com.newsclassifier.common.Result;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.HttpStatus;
import org.springframework.validation.BindException;
import org.springframework.web.bind.annotation.ExceptionHandler;
import org.springframework.web.bind.annotation.ResponseStatus;
import org.springframework.web.bind.annotation.RestControllerAdvice;
/**
* 全局异常处理器
*/
@Slf4j
@RestControllerAdvice
public class GlobalExceptionHandler {
@ExceptionHandler(BusinessException.class)
@ResponseStatus(HttpStatus.OK)
public Result<Void> handleBusinessException(BusinessException e) {
log.error("业务异常: {}", e.getMessage());
return Result.error(e.getErrorCode(), e.getMessage());
}
@ExceptionHandler(BindException.class)
@ResponseStatus(HttpStatus.BAD_REQUEST)
public Result<Void> handleBindException(BindException e) {
String message = e.getBindingResult().getAllErrors().get(0).getDefaultMessage();
return Result.error(400, message);
}
@ExceptionHandler(Exception.class)
@ResponseStatus(HttpStatus.INTERNAL_SERVER_ERROR)
public Result<Void> handleException(Exception e) {
log.error("系统异常", e);
return Result.error(500, "系统内部错误");
}
}
任务 2.1.10: application.yml - 应用配置文件
# Values written as ${ENV_VAR:default} can be overridden through environment
# variables, so real credentials do not need to be committed.
spring:
  application:
    name: news-classifier
  datasource:
    driver-class-name: com.mysql.cj.jdbc.Driver
    # characterEncoding takes a Java charset name; UTF-8 is negotiated as utf8mb4 by Connector/J.
    url: jdbc:mysql://${DB_HOST:localhost}:${DB_PORT:3306}/${DB_NAME:news_classifier}?useUnicode=true&characterEncoding=UTF-8&serverTimezone=Asia/Shanghai
    username: ${DB_USER:root}
    password: ${DB_PASSWORD:your_password}
  data:
    redis:
      host: ${REDIS_HOST:localhost}
      port: ${REDIS_PORT:6379}
      database: 0

# MyBatis-Plus configuration
mybatis-plus:
  mapper-locations: classpath:mapper/**/*.xml
  type-aliases-package: com.newsclassifier.entity
  configuration:
    map-underscore-to-camel-case: true
    # Route SQL logging through SLF4J so it follows logging.level (StdOutImpl always prints).
    log-impl: org.apache.ibatis.logging.slf4j.Slf4jImpl
  global-config:
    db-config:
      id-type: auto
      logic-delete-field: deleted
      logic-delete-value: 1
      logic-not-delete-value: 0

# JWT configuration (HS256 needs a secret of at least 256 bits)
jwt:
  secret: ${JWT_SECRET:your-secret-key-at-least-256-bits-long-for-hs256-algorithm}
  expiration: 86400000

# Classifier configuration
classifier:
  mode: hybrid # traditional, deep_learning, hybrid
  confidence:
    threshold: 0.75
    hybrid-min: 0.6
  bert:
    service-url: http://localhost:5000/api/predict
    timeout: 5000

# Logging configuration
logging:
  level:
    com.newsclassifier: debug
  pattern:
    console: "%d{yyyy-MM-dd HH:mm:ss} [%thread] %-5level %logger{36} - %msg%n"
3. 前端桌面模块 (Tauri + Vue3)
模块目录结构
client/src/
├── api/ # API接口
│ ├── index.ts
│ ├── auth.ts
│ ├── news.ts
│ ├── category.ts
│ └── classifier.ts
├── assets/ # 静态资源
│ ├── images/
│ ├── styles/
│ │ ├── main.css
│ │ └── tailwind.css
│ └── fonts/
├── components/ # 组件
│ ├── ui/ # 基础UI组件
│ │ ├── button/
│ │ ├── input/
│ │ ├── dialog/
│ │ └── table/
│ ├── layout/ # 布局组件
│ │ ├── Header.vue
│ │ ├── Sidebar.vue
│ │ └── Footer.vue
│ ├── news/ # 新闻相关组件
│ │ ├── NewsCard.vue
│ │ ├── NewsList.vue
│ │ ├── NewsDetail.vue
│ │ └── CategoryFilter.vue
│ └── charts/ # 图表组件
│ ├── CategoryChart.vue
│ └── TrendChart.vue
├── composables/ # 组合式函数
│ ├── useAuth.ts
│ ├── useNews.ts
│ ├── useClassifier.ts
│ └── useToast.ts
├── layouts/ # 布局
│ ├── DefaultLayout.vue
│ ├── AuthLayout.vue
│ └── EmptyLayout.vue
├── router/ # 路由 (部分完成)
│ └── index.ts
├── stores/ # 状态管理 (部分完成)
│ ├── user.ts
│ ├── news.ts
│ └── category.ts
├── types/ # TypeScript类型
│ ├── api.d.ts
│ ├── news.d.ts
│ └── user.d.ts
├── utils/ # 工具函数
│ ├── request.ts
│ ├── storage.ts
│ ├── format.ts
│ └── validate.ts
├── views/ # 页面
│ ├── auth/
│ │ ├── Login.vue
│ │ └── Register.vue
│ ├── news/
│ │ ├── NewsList.vue
│ │ ├── NewsDetail.vue
│ │ └── NewsSearch.vue
│ ├── category/
│ │ ├── CategoryManage.vue
│ │ └── CategoryStats.vue
│ ├── classifier/
│ │ ├── ClassifierPage.vue
│ │ └── ModelCompare.vue
│ └── admin/
│ ├── Dashboard.vue
│ ├── UserManage.vue
│ └── SystemLog.vue
├── App.vue
└── main.ts
3.1 需要完成的具体文件
任务 3.1.1: utils/request.ts - HTTP请求封装
import axios, {
  type AxiosError,
  type AxiosInstance,
  type AxiosRequestConfig,
  type AxiosResponse
} from 'axios'

// Response envelope returned by the backend (mirrors its Result<T>)
interface ApiResponse<T = any> {
  code: number
  message: string
  data: T
}

const TOKEN_KEY = 'token'

// Shared axios instance
const service: AxiosInstance = axios.create({
  baseURL: import.meta.env.VITE_API_BASE_URL || 'http://localhost:8080/api',
  timeout: 15000,
  headers: {
    'Content-Type': 'application/json'
  }
})

// Request interceptor: attach the stored JWT as a Bearer token
service.interceptors.request.use(
  (config) => {
    const token = localStorage.getItem(TOKEN_KEY)
    if (token) {
      config.headers.Authorization = `Bearer ${token}`
    }
    return config
  },
  (error) => Promise.reject(error)
)

// Response interceptor: unwrap `data` on code 200, reject otherwise
service.interceptors.response.use(
  (response: AxiosResponse<ApiResponse>) => {
    const { code, message, data } = response.data
    if (code === 200) {
      return data
    }
    return Promise.reject(new Error(message || '请求失败'))
  },
  (error: AxiosError<Partial<ApiResponse>>) => {
    if (error.response?.status === 401) {
      // The server rejected the stored token; drop it so later requests don't resend it.
      localStorage.removeItem(TOKEN_KEY)
    }
    // Surface the backend's message (e.g. validation errors) while keeping the
    // AxiosError so callers can still inspect error.response.
    const serverMessage = error.response?.data?.message
    if (serverMessage) {
      error.message = serverMessage
    }
    return Promise.reject(error)
  }
)

// Typed helpers. The response interceptor already unwraps the envelope, so each
// call resolves to the payload T rather than an AxiosResponse.
export const http = {
  get<T = any>(url: string, config?: AxiosRequestConfig): Promise<T> {
    return service.get<T, T>(url, config)
  },
  post<T = any>(url: string, data?: any, config?: AxiosRequestConfig): Promise<T> {
    return service.post<T, T>(url, data, config)
  },
  put<T = any>(url: string, data?: any, config?: AxiosRequestConfig): Promise<T> {
    return service.put<T, T>(url, data, config)
  },
  delete<T = any>(url: string, config?: AxiosRequestConfig): Promise<T> {
    return service.delete<T, T>(url, config)
  }
}

export default service
任务 3.1.2: api/news.ts - 新闻API
// request.ts lives in src/utils/, not next to this file in src/api/
import { http } from '@/utils/request'

// Query parameters for paged news listing
export interface NewsQueryParams {
  page?: number
  size?: number
  categoryId?: number
  categoryCode?: string
  keyword?: string
  status?: number
}

// News item as returned by the backend
export interface NewsDetail {
  id: number
  title: string
  content: string
  summary: string
  source: string
  sourceUrl: string
  author: string
  categoryId: number
  categoryCode: string
  coverImage: string
  publishTime: string
  viewCount: number
  likeCount: number
  commentCount: number
  classifierType: string
  confidence: number
}

// Paged response.
// NOTE(review): the backend builds this with PageResponse.of(total, list); confirm
// that `current` and `size` are actually present in the JSON.
export interface PageResponse<T> {
  total: number
  records: T[]
  current: number
  size: number
}

// News API
export const newsApi = {
  // Paged query
  getNewsPage(params: NewsQueryParams): Promise<PageResponse<NewsDetail>> {
    return http.get('/news/page', { params })
  },

  // Detail by id
  getNewsDetail(id: number): Promise<NewsDetail> {
    return http.get(`/news/${id}`)
  },

  // Keyword search
  searchNews(keyword: string, page = 1, size = 20): Promise<PageResponse<NewsDetail>> {
    return http.get('/news/search', { params: { keyword, page, size } })
  },

  // Manual classification (categoryId sent as a query parameter)
  manualClassify(id: number, categoryId: number): Promise<void> {
    return http.post(`/news/${id}/classify`, null, { params: { categoryId } })
  }
}
任务 3.1.3: composables/useNews.ts - 新闻组合式函数
import { ref, computed } from 'vue'
import { newsApi, type NewsQueryParams, type NewsDetail, type PageResponse } from '@/api/news'

/**
 * Reactive news list/detail state plus the actions that load it.
 *
 * Only the most recent request of each kind (list or detail) updates state,
 * so a slow earlier response, e.g. after a quick category switch, can no
 * longer overwrite newer results. `loading` stays true while any request is
 * still in flight.
 */
export function useNews() {
  const pendingRequests = ref(0)
  const newsList = ref<NewsDetail[]>([])
  const total = ref(0)
  const currentNews = ref<NewsDetail | null>(null)

  // Ids of the latest list / detail requests; stale responses are ignored.
  let latestListRequest = 0
  let latestDetailRequest = 0

  // Runs `task` while counting it as in flight for `loading`.
  const track = async <T>(task: () => Promise<T>): Promise<T> => {
    pendingRequests.value++
    try {
      return await task()
    } finally {
      pendingRequests.value--
    }
  }

  // Loads a list page and applies it unless a newer list request started meanwhile.
  const loadList = async (request: () => Promise<PageResponse<NewsDetail>>) => {
    const requestId = ++latestListRequest
    const result = await track(request)
    if (requestId !== latestListRequest) {
      return
    }
    newsList.value = result.records
    total.value = result.total
  }

  // Paged query
  const fetchNewsPage = async (params: NewsQueryParams) => {
    try {
      await loadList(() => newsApi.getNewsPage(params))
    } catch (error) {
      console.error('获取新闻列表失败:', error)
      throw error
    }
  }

  // Detail; resolves to the fetched item
  const fetchNewsDetail = async (id: number) => {
    const requestId = ++latestDetailRequest
    try {
      const detail = await track(() => newsApi.getNewsDetail(id))
      if (requestId === latestDetailRequest) {
        currentNews.value = detail
      }
      return detail
    } catch (error) {
      console.error('获取新闻详情失败:', error)
      throw error
    }
  }

  // Keyword search (shares the list state with fetchNewsPage)
  const searchNews = async (keyword: string, page = 1, size = 20) => {
    try {
      await loadList(() => newsApi.searchNews(keyword, page, size))
    } catch (error) {
      console.error('搜索新闻失败:', error)
      throw error
    }
  }

  return {
    loading: computed(() => pendingRequests.value > 0),
    newsList: computed(() => newsList.value),
    total: computed(() => total.value),
    currentNews: computed(() => currentNews.value),
    fetchNewsPage,
    fetchNewsDetail,
    searchNews
  }
}
任务 3.1.4: views/news/NewsList.vue - 新闻列表页面
<template>
  <div class="news-list-container">
    <!-- Category filter: one button per category; the selected one is highlighted -->
    <div class="category-filter">
      <button
        v-for="cat in categories"
        :key="cat.code"
        :class="{ active: cat.code === selectedCategory }"
        @click="selectCategory(cat.code)"
      >
        {{ cat.name }}
      </button>
    </div>

    <!-- Search box: Enter or the button runs handleSearch -->
    <div class="search-bar">
      <input v-model="searchKeyword" type="text" placeholder="搜索新闻..." @keyup.enter="handleSearch" />
      <button @click="handleSearch">搜索</button>
    </div>

    <!-- Exactly one of: loading indicator, news cards, empty state -->
    <div v-if="loading" class="loading">加载中...</div>
    <div v-else-if="newsList.length > 0" class="news-list">
      <div v-for="item in newsList" :key="item.id" class="news-card" @click="viewDetail(item.id)">
        <h3>{{ item.title }}</h3>
        <p class="summary">{{ item.summary }}</p>
        <div class="meta">
          <span class="category">{{ getCategoryName(item.categoryCode) }}</span>
          <span class="source">{{ item.source }}</span>
          <span class="time">{{ formatDate(item.publishTime) }}</span>
        </div>
        <!-- Classifier badge and confidence, only for classified articles -->
        <div v-if="item.classifierType" class="classifier-info">
          <span class="badge">{{ item.classifierType }}</span>
          <span class="confidence">{{ (item.confidence * 100).toFixed(1) }}%</span>
        </div>
      </div>
    </div>
    <div v-else class="empty">暂无数据</div>

    <!-- Pager, shown whenever there is at least one result -->
    <div v-if="total > 0" class="pagination">
      <button :disabled="currentPage <= 1" @click="changePage(currentPage - 1)">上一页</button>
      <span>{{ currentPage }} / {{ totalPages }}</span>
      <button :disabled="currentPage >= totalPages" @click="changePage(currentPage + 1)">下一页</button>
    </div>
  </div>
</template>
<script setup lang="ts">
import { ref, computed, onMounted } from 'vue'
import { useRouter } from 'vue-router'
import { useNews } from '@/composables/useNews'

const router = useRouter()
const { loading, newsList, total, fetchNewsPage, searchNews } = useNews()

// Filter buttons; the empty code means "all categories".
const categories = ref([
  { code: '', name: '全部' },
  { code: 'POLITICS', name: '时政' },
  { code: 'FINANCE', name: '财经' },
  { code: 'TECHNOLOGY', name: '科技' },
  { code: 'SPORTS', name: '体育' }
])

const selectedCategory = ref('')
// Text currently typed into the search box.
const searchKeyword = ref('')
// Keyword of the search whose results are listed ('' = category browsing). Paging uses
// this rather than the live input, so editing the box does not change the listed results.
const activeKeyword = ref('')
const currentPage = ref(1)
const pageSize = ref(20)
const totalPages = computed(() => Math.ceil(total.value / pageSize.value))

/**
 * Load `currentPage` of whatever is being shown: search results while a search is active,
 * otherwise the (optionally category-filtered) list.
 *
 * useNews logs failures before rethrowing them; they are caught here so the un-awaited
 * calls from click handlers do not surface as unhandled promise rejections. The previous
 * list stays on screen after a failure.
 */
const loadNews = async () => {
  try {
    if (activeKeyword.value) {
      await searchNews(activeKeyword.value, currentPage.value, pageSize.value)
    } else {
      await fetchNewsPage({
        page: currentPage.value,
        size: pageSize.value,
        categoryCode: selectedCategory.value || undefined
      })
    }
  } catch {
    // Already reported by useNews.
  }
}

// Switch category: leaves search mode and restarts from the first page.
const selectCategory = (code: string) => {
  selectedCategory.value = code
  activeKeyword.value = ''
  searchKeyword.value = ''
  currentPage.value = 1
  loadNews()
}

// Start a new search from page 1; an empty keyword returns to category browsing.
const handleSearch = async () => {
  activeKeyword.value = searchKeyword.value.trim()
  currentPage.value = 1
  await loadNews()
}

// Open the detail page of an article.
const viewDetail = (id: number) => {
  router.push(`/news/${id}`)
}

// Go to another page of the current list or search; out-of-range pages are ignored.
const changePage = (page: number) => {
  if (page < 1 || page > totalPages.value) return
  currentPage.value = page
  loadNews()
}

// TODO: format publish times for display; they are currently shown exactly as received.
const formatDate = (dateStr: string) => {
  return dateStr
}

// Display name for a category code; unknown codes are shown as-is.
const getCategoryName = (code: string) => {
  const cat = categories.value.find(c => c.code === code)
  return cat?.name || code
}

onMounted(() => {
  loadNews()
})
</script>
<style scoped>
/* Page padding */
.news-list-container {
  padding: 20px;
}
/* Category buttons laid out in one row */
.category-filter {
  display: flex;
  gap: 10px;
  margin-bottom: 20px;
}
.category-filter button {
  padding: 8px 16px;
  border: 1px solid #ddd;
  border-radius: 4px;
  background: white;
  cursor: pointer;
}
/* Highlight for the selected category button */
.category-filter button.active {
  background: #1890ff;
  color: white;
  border-color: #1890ff;
}
/* Clickable article card with a hover shadow */
.news-card {
  padding: 15px;
  border: 1px solid #eee;
  border-radius: 8px;
  margin-bottom: 15px;
  cursor: pointer;
  transition: box-shadow 0.2s;
}
.news-card:hover {
  box-shadow: 0 2px 8px rgba(0, 0, 0, 0.1);
}
/* Category / source / time line below the summary */
.meta {
  display: flex;
  gap: 15px;
  font-size: 12px;
  color: #999;
  margin-top: 10px;
}
/* Classifier badge + confidence row */
.classifier-info {
  display: flex;
  gap: 8px;
  margin-top: 8px;
}
.badge {
  padding: 2px 8px;
  background: #f0f0f0;
  border-radius: 4px;
  font-size: 12px;
}
/* Centered pager controls */
.pagination {
  display: flex;
  justify-content: center;
  align-items: center;
  gap: 15px;
  margin-top: 20px;
}
</style>
任务 3.1.5: views/classifier/ClassifierPage.vue - 分类器页面
<template>
  <div class="classifier-page">
    <div class="page-header">
      <h2>文本分类</h2>
    </div>
    <!-- Classification mode selector (bound to selectedMode and sent as `mode`) -->
    <!-- NOTE(review): if requests reach the FastAPI service in ml-module unchanged, only
         'traditional' is special-cased there; any other mode, including 'deep_learning',
         is served by the hybrid classifier. -->
    <div class="mode-selector">
      <label>分类模式:</label>
      <select v-model="selectedMode">
        <option value="traditional">传统机器学习 (TF-IDF + NB)</option>
        <option value="deep_learning">深度学习 (BERT)</option>
        <option value="hybrid">混合模式</option>
      </select>
    </div>
    <!-- Input area: title and content are both required by handleClassify -->
    <div class="input-area">
      <div class="form-group">
        <label>新闻标题</label>
        <input v-model="formData.title" type="text" placeholder="请输入新闻标题" />
      </div>
      <div class="form-group">
        <label>新闻内容</label>
        <textarea v-model="formData.content" placeholder="请输入新闻内容" rows="10"></textarea>
      </div>
      <!-- Disabled while a request is in flight -->
      <button @click="handleClassify" :disabled="classifying">
        {{ classifying ? '分类中...' : '开始分类' }}
      </button>
    </div>
    <!-- Classification result -->
    <div v-if="result" class="result-area">
      <h3>分类结果</h3>
      <div class="result-item">
        <span class="label">分类:</span>
        <span class="value">{{ result.categoryName }} ({{ result.categoryCode }})</span>
      </div>
      <div class="result-item">
        <span class="label">置信度:</span>
        <span class="value">{{ (result.confidence * 100).toFixed(2) }}%</span>
      </div>
      <div class="result-item">
        <span class="label">分类器:</span>
        <span class="value">{{ result.classifierType }}</span>
      </div>
      <div class="result-item">
        <span class="label">耗时:</span>
        <span class="value">{{ result.duration }}ms</span>
      </div>
      <!-- Per-category probability bars (only when the response includes probabilities) -->
      <div v-if="result.probabilities" class="probabilities">
        <h4>各类别概率</h4>
        <div v-for="(prob, code) in result.probabilities" :key="code" class="prob-bar">
          <span class="cat-name">{{ getCategoryName(code) }}</span>
          <div class="bar-container">
            <div class="bar" :style="{ width: (prob * 100) + '%' }"></div>
          </div>
          <span class="prob-value">{{ (prob * 100).toFixed(1) }}%</span>
        </div>
      </div>
    </div>
  </div>
</template>
<script setup lang="ts">
import { ref } from 'vue'
import { classifierApi } from '@/api/classifier'

// Shape of a classification result as rendered by the template.
interface ClassificationResult {
  categoryCode: string
  categoryName: string
  confidence: number
  classifierType: string
  duration: number // milliseconds
  probabilities?: Record<string, number> // category code -> probability
}

// Display names for every category code the ML classifiers can return (same codes as
// CATEGORY_MAP in the Python modules). Previously only four were mapped, so the
// probability list showed raw codes for the other six.
const CATEGORY_NAMES: Record<string, string> = {
  POLITICS: '时政',
  FINANCE: '财经',
  TECHNOLOGY: '科技',
  SPORTS: '体育',
  ENTERTAINMENT: '娱乐',
  HEALTH: '健康',
  EDUCATION: '教育',
  LIFE: '生活',
  INTERNATIONAL: '国际',
  MILITARY: '军事'
}

const formData = ref({
  title: '',
  content: ''
})
const selectedMode = ref('hybrid')
const classifying = ref(false)
const result = ref<ClassificationResult | null>(null)

/** Validate the form, send it to the classifier API and show the result. */
const handleClassify = async () => {
  if (!formData.value.title.trim() || !formData.value.content.trim()) {
    alert('请输入标题和内容')
    return
  }
  classifying.value = true
  // Clear the previous result so it is not shown next to the new input while the
  // request runs, nor left on screen if the request fails.
  result.value = null
  try {
    result.value = await classifierApi.classify({
      title: formData.value.title,
      content: formData.value.content,
      mode: selectedMode.value
    })
  } catch (error) {
    console.error('分类失败:', error)
    alert('分类失败,请重试')
  } finally {
    classifying.value = false
  }
}

// Display name for a category code; unknown codes are shown as-is.
const getCategoryName = (code: string) => CATEGORY_NAMES[code] || code
</script>
<style scoped>
/* Centered page column */
.classifier-page {
  padding: 20px;
  max-width: 800px;
  margin: 0 auto;
}
/* Mode dropdown */
.mode-selector {
  margin-bottom: 20px;
}
.mode-selector select {
  padding: 8px;
  border-radius: 4px;
  border: 1px solid #ddd;
}
/* Title/content input panel */
.input-area {
  background: #f9f9f9;
  padding: 20px;
  border-radius: 8px;
}
.form-group {
  margin-bottom: 15px;
}
.form-group label {
  display: block;
  margin-bottom: 5px;
  font-weight: 500;
}
/* Full-width fields; border-box keeps padding inside the 100% width */
.form-group input,
.form-group textarea {
  width: 100%;
  padding: 10px;
  border: 1px solid #ddd;
  border-radius: 4px;
  box-sizing: border-box;
}
/* Result panel */
.result-area {
  margin-top: 20px;
  padding: 20px;
  background: #f0f7ff;
  border-radius: 8px;
}
.result-item {
  display: flex;
  padding: 8px 0;
}
.result-item .label {
  width: 80px;
  font-weight: 500;
}
/* Probability bars: name | bar | percentage */
.probabilities {
  margin-top: 20px;
}
.prob-bar {
  display: flex;
  align-items: center;
  margin-bottom: 10px;
}
.cat-name {
  width: 80px;
}
.bar-container {
  flex: 1;
  height: 20px;
  background: #e0e0e0;
  border-radius: 4px;
  overflow: hidden;
  margin: 0 10px;
}
/* Bar fill; its width is set inline from the probability */
.bar {
  height: 100%;
  background: linear-gradient(90deg, #1890ff, #52c41a);
  transition: width 0.3s;
}
.prob-value {
  width: 60px;
  text-align: right;
}
</style>
任务 3.1.6: router/index.ts - 路由配置 (更新)
import { createRouter, createWebHistory } from 'vue-router'
import type { RouteRecordRaw } from 'vue-router'
// Application routes; all views are lazy-loaded.
// `meta.requiresAuth` / `meta.requiresAdmin` are checked by the global beforeEach guard;
// `meta.layout` is presumably read by the app shell to pick a layout — TODO confirm.
const routes: RouteRecordRaw[] = [
  {
    path: '/login',
    name: 'Login',
    component: () => import('@/views/auth/Login.vue'),
    meta: { layout: 'EmptyLayout' }
  },
  {
    // Home page: renders the same NewsList view as /news.
    path: '/',
    name: 'Home',
    component: () => import('@/views/news/NewsList.vue'),
    meta: { requiresAuth: true }
  },
  {
    path: '/news',
    name: 'NewsList',
    component: () => import('@/views/news/NewsList.vue'),
    meta: { requiresAuth: true }
  },
  {
    // Article detail; NewsList navigates here with the article id.
    path: '/news/:id',
    name: 'NewsDetail',
    component: () => import('@/views/news/NewsDetail.vue'),
    meta: { requiresAuth: true }
  },
  {
    path: '/classifier',
    name: 'Classifier',
    component: () => import('@/views/classifier/ClassifierPage.vue'),
    meta: { requiresAuth: true }
  },
  {
    path: '/category',
    name: 'CategoryStats',
    component: () => import('@/views/category/CategoryStats.vue'),
    meta: { requiresAuth: true }
  },
  {
    // Admin-only dashboard: needs both a login and the ADMIN role.
    path: '/admin',
    name: 'AdminDashboard',
    component: () => import('@/views/admin/Dashboard.vue'),
    meta: { requiresAuth: true, requiresAdmin: true }
  }
]
const router = createRouter({
  history: createWebHistory(),
  routes
})

/**
 * Global navigation guard (return-value form).
 * - Routes with meta.requiresAuth need a `token` in localStorage; otherwise go to /login.
 * - Routes with meta.requiresAdmin additionally need localStorage `userRole` === 'ADMIN';
 *   otherwise go to /.
 * These checks only read client-side storage; the backend must still enforce access.
 */
router.beforeEach((to) => {
  const isLoggedIn = Boolean(localStorage.getItem('token'))
  if (to.meta.requiresAuth && !isLoggedIn) {
    return '/login'
  }
  if (to.meta.requiresAdmin) {
    const isAdmin = localStorage.getItem('userRole') === 'ADMIN'
    return isAdmin ? true : '/'
  }
  return true
})

export default router
4. 机器学习分类模块 (Python)
模块目录结构
ml-module/
├── data/
│ ├── raw/ # 原始数据
│ ├── processed/ # 处理后的数据
│ │ ├── training_data.csv
│ │ └── test_data.csv
│ └── external/ # 外部数据集
├── models/ # 训练好的模型
│ ├── traditional/
│ │ ├── nb_vectorizer.pkl
│ │ ├── nb_classifier.pkl
│ │ ├── svm_vectorizer.pkl
│ │ └── svm_classifier.pkl
│ ├── deep_learning/
│ │ └── bert_finetuned/
│ └── hybrid/
│ └── config.json
├── src/
│ ├── __init__.py
│ ├── traditional/ # 传统机器学习
│ │ ├── __init__.py
│ │ ├── train_model.py # (已有)
│ │ ├── predict.py
│ │ └── evaluate.py
│ ├── deep_learning/ # 深度学习
│ │ ├── __init__.py
│ │ ├── bert_model.py
│ │ ├── train_bert.py
│ │ └── predict_bert.py
│ ├── hybrid/ # 混合策略
│ │ ├── __init__.py
│ │ ├── hybrid_classifier.py
│ │ └── rule_engine.py
│ ├── utils/
│ │ ├── __init__.py
│ │ ├── preprocessing.py # 数据预处理
│ │ └── metrics.py # 评估指标
│ └── api/ # API服务
│ ├── __init__.py
│ └── server.py # FastAPI服务
├── notebooks/ # Jupyter notebooks
│ ├── data_exploration.ipynb
│ └── model_comparison.ipynb
├── tests/ # 测试
│ ├── test_traditional.py
│ ├── test_bert.py
│ └── test_hybrid.py
├── requirements.txt
├── setup.py
└── README.md
4.1 需要完成的具体文件
任务 4.1.1: src/traditional/predict.py - 传统模型预测
"""
传统机器学习模型预测
"""
import os
import joblib
import jieba
from typing import Dict, Any
# Category code (the label the classifiers emit) -> Chinese display name.
CATEGORY_MAP = dict(
    POLITICS='时政',
    FINANCE='财经',
    TECHNOLOGY='科技',
    SPORTS='体育',
    ENTERTAINMENT='娱乐',
    HEALTH='健康',
    EDUCATION='教育',
    LIFE='生活',
    INTERNATIONAL='国际',
    MILITARY='军事',
)
class TraditionalPredictor:
    """Classical-ML news classifier backed by a pickled vectorizer + sklearn classifier.

    The pair is loaded from ``{model_dir}/{model_type}_vectorizer.pkl`` and
    ``{model_dir}/{model_type}_classifier.pkl`` when the predictor is created.
    """

    # Default model directory, resolved against this source file
    # (src/traditional/ -> ../../models/traditional) instead of the process's current
    # working directory, so the models load no matter where the service is started from.
    DEFAULT_MODEL_DIR = os.path.normpath(
        os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', '..', 'models', 'traditional')
    )

    def __init__(self, model_type='nb', model_dir=None):
        """
        :param model_type: file-name prefix of the model pair, e.g. ``'nb'`` or ``'svm'``
        :param model_dir: directory holding the pickles; ``None`` uses
            :attr:`DEFAULT_MODEL_DIR`. An explicit relative path is resolved against the
            current working directory.
        """
        self.model_type = model_type
        self.model_dir = self.DEFAULT_MODEL_DIR if model_dir is None else model_dir
        self.vectorizer = None
        self.classifier = None
        self._load_model()

    def _load_model(self):
        """Load the vectorizer and classifier pickles for ``self.model_type``.

        joblib unpickles the files, which can execute arbitrary code: only load trusted
        model files.
        """
        vectorizer_path = os.path.join(self.model_dir, f'{self.model_type}_vectorizer.pkl')
        classifier_path = os.path.join(self.model_dir, f'{self.model_type}_classifier.pkl')
        self.vectorizer = joblib.load(vectorizer_path)
        self.classifier = joblib.load(classifier_path)
        print(f"模型加载成功: {self.model_type}")

    def preprocess(self, title: str, content: str) -> str:
        """Join title and content and segment them with jieba into space-separated tokens.

        ``None`` for either part is treated as an empty string.
        """
        text = f"{title or ''} {content or ''}"
        return ' '.join(jieba.cut(text))

    def predict(self, title: str, content: str) -> Dict[str, Any]:
        """
        Classify one article.

        :return: dict with ``categoryCode``, ``categoryName`` (``'未知'`` for codes missing
            from CATEGORY_MAP), ``confidence`` and ``probabilities`` (code -> probability)
        """
        features = self.vectorizer.transform([self.preprocess(title, content)])
        probabilities = self.classifier.predict_proba(features)[0]
        classes = self.classifier.classes_

        # Take the label from the probability vector rather than calling predict(): for
        # some estimators (e.g. SVC with probability=True) predict() and the argmax of
        # predict_proba() can disagree, which reported one class with another's confidence.
        best = int(probabilities.argmax())
        category_code = str(classes[best])  # plain str instead of a numpy scalar

        return {
            'categoryCode': category_code,
            'categoryName': CATEGORY_MAP.get(category_code, '未知'),
            'confidence': float(probabilities[best]),
            'probabilities': {str(code): float(prob) for code, prob in zip(classes, probabilities)},
        }
# API entry point.
# Loaded predictors keyed by model type: building a TraditionalPredictor unpickles the
# vectorizer and classifier from disk, which is too expensive to repeat on every call.
_PREDICTOR_CACHE: Dict[str, 'TraditionalPredictor'] = {}


def predict_single(title: str, content: str, model_type='nb') -> Dict[str, Any]:
    """
    Classify one article with the cached predictor for ``model_type``.

    The predictor (and its model files) is loaded on the first call for each
    ``model_type`` and reused afterwards; model files replaced on disk are only picked
    up by a new process. A failed load is not cached.

    :return: the dict produced by ``TraditionalPredictor.predict``
    """
    predictor = _PREDICTOR_CACHE.get(model_type)
    if predictor is None:
        predictor = TraditionalPredictor(model_type)
        _PREDICTOR_CACHE[model_type] = predictor
    return predictor.predict(title, content)
if __name__ == '__main__':
    # Manual smoke test: classify one sample article with the default NB model.
    sample = {
        'title': "华为发布新款折叠屏手机",
        'content': "华为今天正式发布了新一代折叠屏手机,搭载最新麒麟芯片...",
    }
    print(predict_single(**sample))
任务 4.1.2: src/deep_learning/bert_model.py - BERT模型
"""
BERT文本分类模型
"""
import torch
from transformers import (
BertTokenizer,
BertForSequenceClassification,
Trainer,
TrainingArguments
)
from typing import Dict, Any, List
# Category code -> Chinese display name (same codes as the traditional predictor).
CATEGORY_MAP = dict(
    POLITICS='时政',
    FINANCE='财经',
    TECHNOLOGY='科技',
    SPORTS='体育',
    ENTERTAINMENT='娱乐',
    HEALTH='健康',
    EDUCATION='教育',
    LIFE='生活',
    INTERNATIONAL='国际',
    MILITARY='军事',
)
# Label ids follow CATEGORY_MAP's insertion order; the model head's output index i is
# the category ID_TO_LABEL[i], so this order must not change after training.
ID_TO_LABEL = dict(enumerate(CATEGORY_MAP))
LABEL_TO_ID = {label: label_id for label_id, label in ID_TO_LABEL.items()}
class BertClassifier:
    """Fine-tuned BERT classifier mapping a news title/body to a category code.

    A new instance holds no weights: call :meth:`load_model` with a checkpoint
    directory (anything ``from_pretrained`` accepts) before :meth:`predict`.

    :param model_name: base checkpoint name; stored for reference only, since
        :meth:`load_model` loads from the path it is given
    :param num_labels: size of the classification head; output index i is the
        category ``ID_TO_LABEL[i]``
    :param max_length: maximum number of tokens fed to the model; longer input is truncated
    """

    def __init__(self, model_name='bert-base-chinese', num_labels=10, max_length=512):
        self.model_name = model_name
        self.num_labels = num_labels
        self.max_length = max_length
        self.tokenizer = None
        self.model = None
        # Run on the GPU when one is available.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def load_model(self, model_path):
        """Load the fine-tuned tokenizer and model from ``model_path`` and switch to eval mode."""
        self.tokenizer = BertTokenizer.from_pretrained(model_path)
        self.model = BertForSequenceClassification.from_pretrained(
            model_path,
            num_labels=self.num_labels
        )
        self.model.to(self.device)
        self.model.eval()
        print(f"BERT模型加载成功: {model_path}")

    def predict(self, title: str, content: str) -> Dict[str, Any]:
        """
        Classify one article.

        :return: dict with ``categoryCode``, ``categoryName``, ``confidence`` (softmax
            probability of the chosen class) and ``probabilities`` (code -> probability)
        :raises ValueError: if :meth:`load_model` has not been called
        """
        if self.model is None or self.tokenizer is None:
            raise ValueError("模型未加载,请先调用load_model")

        # Title and body joined by a literal "[SEP]"; presumably this must match the
        # format used at training time — TODO confirm against train_bert.py.
        # None is treated as an empty string rather than the text "None".
        text = f"{title or ''} [SEP] {content or ''}"

        # A single sequence needs no padding: padding to max_length only appended
        # masked-out tokens and made every call pay for max_length tokens.
        inputs = self.tokenizer(
            text,
            return_tensors='pt',
            truncation=True,
            max_length=self.max_length,
        )
        inputs = {name: tensor.to(self.device) for name, tensor in inputs.items()}

        with torch.no_grad():
            logits = self.model(**inputs).logits

        probs = torch.softmax(logits, dim=-1)[0]
        confidence, predicted_id = torch.max(probs, dim=-1)
        category_code = ID_TO_LABEL[predicted_id.item()]

        prob_dict = {
            ID_TO_LABEL[i]: float(prob)
            for i, prob in enumerate(probs.cpu().tolist())
        }
        return {
            'categoryCode': category_code,
            'categoryName': CATEGORY_MAP.get(category_code, '未知'),
            'confidence': confidence.item(),
            'probabilities': prob_dict,
        }
class NewsDataset(torch.utils.data.Dataset):
    """Map-style dataset of news texts and class labels for BERT training.

    Each item is tokenized on access (truncated and padded to ``max_length``) and
    returned as 1-D ``input_ids`` / ``attention_mask`` tensors plus a ``labels``
    tensor of dtype ``torch.long``.
    """

    def __init__(self, texts, labels, tokenizer, max_length=512):
        self.texts, self.labels = texts, labels
        self.tokenizer, self.max_length = tokenizer, max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text, label = self.texts[idx], self.labels[idx]
        encoded = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            padding='max_length',
            return_tensors='pt'
        )
        # The tokenizer returns a batch of one; flatten drops the batch dimension.
        sample = {key: encoded[key].flatten() for key in ('input_ids', 'attention_mask')}
        sample['labels'] = torch.tensor(label, dtype=torch.long)
        return sample
if __name__ == '__main__':
    # Smoke test: only constructs the classifier. For a real prediction, call
    # classifier.load_model('./models/deep_learning/bert_finetuned') and then
    # classifier.predict(title=..., content=...).
    classifier = BertClassifier()
    print("BERT分类器初始化成功")
任务 4.1.3: src/hybrid/hybrid_classifier.py - 混合分类器
"""
混合策略分类器
结合规则引擎和机器学习模型
"""
import time
from typing import Dict, Any

# server.py puts src/ itself on sys.path and imports this module as the top-level
# package `hybrid`; there a `..` import points above the top-level package and raises
# ImportError. Use relative imports only when this module sits inside a parent package
# (e.g. imported as `src.hybrid.hybrid_classifier`).
if __package__ and '.' in __package__:
    from ..traditional.predict import CATEGORY_MAP, TraditionalPredictor
    from ..deep_learning.bert_model import BertClassifier
else:
    from traditional.predict import CATEGORY_MAP, TraditionalPredictor
    from deep_learning.bert_model import BertClassifier
class HybridClassifier:
"""混合分类器"""
def __init__(self):
# 初始化各个分类器
self.nb_predictor = TraditionalPredictor('nb')
self.bert_classifier = BertClassifier()
# 配置参数
self.config = {
'confidence_threshold': 0.75, # 高置信度阈值
'hybrid_min_confidence': 0.60, # 混合模式最低阈值
'use_bert_threshold': 0.70, # 使用BERT的阈值
'rule_priority': True # 规则优先
}
# 规则关键词字典
self.rule_keywords = {
'POLITICS': ['政府', '政策', '选举', '国务院', '主席', '总理'],
'FINANCE': ['股市', '经济', '金融', '投资', '基金', '银行'],
'TECHNOLOGY': ['芯片', 'AI', '人工智能', '5G', '互联网', '科技'],
'SPORTS': ['比赛', '冠军', '联赛', '球员', '教练', 'NBA'],
'ENTERTAINMENT': ['明星', '电影', '电视剧', '娱乐圈', '歌手'],
'HEALTH': ['健康', '医疗', '疾病', '治疗', '疫苗'],
'EDUCATION': ['教育', '学校', '大学', '考试', '招生'],
'LIFE': ['生活', '美食', '旅游', '购物'],
'INTERNATIONAL': ['国际', '美国', '欧洲', '日本', '外交'],
'MILITARY': ['军事', '武器', '军队', '国防', '战争']
}
def rule_match(self, title: str, content: str) -> tuple[str | None, float]:
"""
规则匹配
:return: (category_code, confidence)
"""
text = title + ' ' + content
# 计算每个类别的关键词匹配数
matches = {}
for category, keywords in self.rule_keywords.items():
count = sum(1 for kw in keywords if kw in text)
if count > 0:
matches[category] = count
if not matches:
return None, 0.0
# 返回匹配最多的类别
best_category = max(matches, key=matches.get)
confidence = min(0.9, matches[best_category] * 0.15) # 规则置信度
return best_category, confidence
def predict(self, title: str, content: str, use_bert=True) -> Dict[str, Any]:
"""
混合预测
"""
start_time = time.time()
# 1. 先尝试规则匹配
rule_category, rule_confidence = self.rule_match(title, content)
# 2. 传统机器学习预测
nb_result = self.nb_predictor.predict(title, content)
nb_confidence = nb_result['confidence']
# 决策逻辑
final_result = None
classifier_type = 'HYBRID'
# 规则优先且规则置信度高
if self.config['rule_priority'] and rule_confidence >= self.config['confidence_threshold']:
final_result = {
'categoryCode': rule_category,
'categoryName': nb_result['categoryName'], # 从映射获取
'confidence': rule_confidence,
'classifierType': 'RULE',
'reason': '规则匹配'
}
# 传统模型置信度足够高
elif nb_confidence >= self.config['confidence_threshold']:
final_result = {
**nb_result,
'classifierType': 'ML',
'reason': '传统模型高置信度'
}
# 需要使用BERT
elif use_bert:
# TODO: 加载BERT模型预测
# bert_result = self.bert_classifier.predict(title, content)
# 如果BERT置信度也不高,选择最高的
final_result = {
**nb_result,
'classifierType': 'HYBRID',
'reason': '混合决策'
}
else:
# 不使用BERT,直接返回传统模型结果
final_result = {
**nb_result,
'classifierType': 'ML',
'reason': '默认传统模型'
}
# 计算耗时
duration = int((time.time() - start_time) * 1000)
final_result['duration'] = duration
return final_result
if __name__ == '__main__':
    # Manual smoke test over two sample headlines.
    classifier = HybridClassifier()
    samples = [
        ('国务院发布最新经济政策', '国务院今天发布了新的经济政策...'),
        ('华为发布新款折叠屏手机', '华为今天正式发布了新一代折叠屏手机...'),
    ]
    separator = "-" * 50
    for title, content in samples:
        result = classifier.predict(title, content)
        print(f"标题: {title}")
        print(f"结果: {result['categoryName']} ({result['confidence']:.2f})")
        print(f"分类器: {result['classifierType']}")
        print(f"原因: {result.get('reason', 'N/A')}")
        print(f"耗时: {result['duration']}ms")
        print(separator)
任务 4.1.4: src/api/server.py - FastAPI服务
"""
机器学习模型API服务
使用FastAPI提供RESTful API
"""
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional
import logging
# 导入分类器
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from traditional.predict import TraditionalPredictor
from hybrid.hybrid_classifier import HybridClassifier
# Logging: INFO level for the whole process; module logger for this service.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# FastAPI application; title/description/version feed the generated OpenAPI docs.
app = FastAPI(
    title="新闻分类API",
    description="提供新闻文本分类服务",
    version="1.0.0"
)
# CORS: allow any origin, method and header.
# NOTE(review): with allow_origins=["*"] and allow_credentials=True, Starlette's
# CORSMiddleware echoes the caller's Origin for credentialed requests instead of "*",
# so any website can make credentialed cross-origin calls. The design notes describe this
# service as called by the backend; if browsers never call it directly, consider an
# explicit origin list or allow_credentials=False — verify against the installed Starlette.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Request body for /api/predict and for each item of /api/batch-predict.
class ClassifyRequest(BaseModel):
    title: str
    content: str
    # Only 'traditional' is special-cased by the endpoints; any other value
    # (including None) is handled by the hybrid classifier.
    mode: Optional[str] = 'hybrid'  # traditional, hybrid
# Response body for /api/predict.
class ClassifyResponse(BaseModel):
    categoryCode: str  # category code, e.g. 'TECHNOLOGY'
    categoryName: str  # Chinese display name of the category
    confidence: float  # confidence of the chosen category
    classifierType: str  # path that produced the result: 'ML', 'RULE' or 'HYBRID'
    duration: int  # elapsed time in milliseconds; required, so every result must set it
    probabilities: Optional[dict] = None  # category code -> probability, when available
# Model instances, set by startup_event(); each stays None if its loading fails.
nb_predictor = None
hybrid_classifier = None


@app.on_event("startup")
async def startup_event():
    """Load the models once when the service starts.

    The two models are loaded independently: a failure is logged and leaves the
    corresponding global as None instead of stopping the service.
    """
    global nb_predictor, hybrid_classifier
    logger.info("加载模型...")

    try:
        loaded_nb = TraditionalPredictor('nb')
    except Exception as exc:
        logger.error(f"朴素贝叶斯模型加载失败: {exc}")
    else:
        nb_predictor = loaded_nb
        logger.info("朴素贝叶斯模型加载成功")

    try:
        loaded_hybrid = HybridClassifier()
    except Exception as exc:
        logger.error(f"混合分类器初始化失败: {exc}")
    else:
        hybrid_classifier = loaded_hybrid
        logger.info("混合分类器初始化成功")
@app.get("/")
async def root():
    """Liveness check: the service process is up and answering requests."""
    return dict(status="ok", message="新闻分类API服务运行中")


@app.get("/health")
async def health_check():
    """Health check that also reports which models finished loading at startup."""
    models = {
        "nb_loaded": nb_predictor is not None,
        "hybrid_loaded": hybrid_classifier is not None,
    }
    return {"status": "healthy", "models": models}
class _ModelUnavailableError(RuntimeError):
    """Raised when the model needed for the requested mode failed to load at startup."""


def _classify(title: str, content: str, mode: Optional[str]) -> dict:
    """
    Run one classification and return the result dict.

    ``mode == 'traditional'`` uses the Naive Bayes predictor directly; any other value
    (including ``None`` and the frontend's ``'deep_learning'``) is handled by the hybrid
    classifier.

    :raises _ModelUnavailableError: the model needed for ``mode`` is not loaded
    """
    import time

    if mode == 'traditional':
        if nb_predictor is None:
            raise _ModelUnavailableError('朴素贝叶斯模型未加载')
        start = time.perf_counter()
        result = nb_predictor.predict(title, content)
        result['classifierType'] = 'ML'
        # TraditionalPredictor.predict() does not report a duration, but ClassifyResponse
        # requires one; without it every traditional-mode request failed validation.
        result['duration'] = int((time.perf_counter() - start) * 1000)
        return result

    if hybrid_classifier is None:
        raise _ModelUnavailableError('混合分类器未初始化')
    return hybrid_classifier.predict(title, content)


# Plain `def` (not `async def`): model inference is CPU-bound and synchronous, so FastAPI
# runs these handlers in its threadpool instead of blocking the event loop.
@app.post("/api/predict", response_model=ClassifyResponse)
def predict(request: ClassifyRequest):
    """
    文本分类接口 (single-article classification)

    - **title**: news title
    - **content**: news body
    - **mode**: classification mode (traditional, hybrid)

    Returns 503 when the required model is not loaded and 500 on any other failure.
    """
    try:
        result = _classify(request.title, request.content, request.mode)
        return ClassifyResponse(**result)
    except _ModelUnavailableError as e:
        logger.error(f"预测失败: {e}")
        raise HTTPException(status_code=503, detail=str(e)) from e
    except Exception as e:
        logger.exception(f"预测失败: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/api/batch-predict")
def batch_predict(requests: list[ClassifyRequest]):
    """
    批量分类接口 (batch classification)

    Items are classified independently: a failing item yields
    ``{'error': ..., 'title': ...}`` in its position instead of failing the batch.
    """
    results = []
    for req in requests:
        try:
            results.append(_classify(req.title, req.content, req.mode))
        except Exception as e:
            logger.warning(f"批量预测单条失败: {e}")
            results.append({
                'error': str(e),
                'title': req.title
            })
    return {"results": results}
if __name__ == '__main__':
    import uvicorn

    # Development entry point: listen on all interfaces, port 5000.
    server_options = dict(host="0.0.0.0", port=5000, log_level="info")
    uvicorn.run(app, **server_options)
任务 4.1.5: src/utils/metrics.py - 评估指标
"""
模型评估指标工具
"""
import numpy as np
from sklearn.metrics import (
accuracy_score,
precision_recall_fscore_support,
confusion_matrix,
classification_report
)
from typing import List, Dict, Any
import matplotlib.pyplot as plt
import seaborn as sns
class ClassificationMetrics:
    """Evaluation helpers for multi-class classifiers (thin wrappers over scikit-learn)."""

    @staticmethod
    def compute_all(y_true: List, y_pred: List, labels: List[str]) -> Dict[str, Any]:
        """
        Compute overall and per-class metrics.

        :param y_true: ground-truth labels
        :param y_pred: predicted labels, aligned with ``y_true``
        :param labels: classes to report in ``per_class``, in this order
        :return: dict with ``accuracy``, support-weighted ``precision``/``recall``/``f1``
            (over all labels present in the data) and ``per_class`` mapping each entry of
            ``labels`` to its precision/recall/f1/support
        """
        accuracy = accuracy_score(y_true, y_pred)
        precision, recall, f1, _ = precision_recall_fscore_support(
            y_true, y_pred, average='weighted', zero_division=0
        )

        # Pass `labels` explicitly so the per-class arrays follow its order. Without it
        # scikit-learn orders them by the sorted labels found in the data, so zipping them
        # with `labels` attached metrics to the wrong class (and raised IndexError when
        # `labels` had more entries than the distinct labels present).
        per_precision, per_recall, per_f1, per_support = precision_recall_fscore_support(
            y_true, y_pred, labels=labels, average=None, zero_division=0
        )
        per_class_metrics = {
            label: {
                'precision': float(p),
                'recall': float(r),
                'f1': float(f),
                'support': int(s),
            }
            for label, p, r, f, s in zip(labels, per_precision, per_recall, per_f1, per_support)
        }

        return {
            'accuracy': float(accuracy),
            'precision': float(precision),
            'recall': float(recall),
            'f1': float(f1),
            'per_class': per_class_metrics
        }

    @staticmethod
    def plot_confusion_matrix(y_true: List, y_pred: List, labels: List[str], save_path: str = None):
        """
        Draw the confusion matrix as an annotated heatmap (rows: true, columns: predicted).

        :param labels: row/column order of the matrix, also used as tick labels
        :param save_path: image file to write; when falsy the figure is closed without
            being saved or shown
        """
        # `labels` fixes the row/column order so it matches the tick labels below.
        cm = confusion_matrix(y_true, y_pred, labels=labels)

        plt.figure(figsize=(10, 8))
        sns.heatmap(
            cm,
            annot=True,
            fmt='d',
            cmap='Blues',
            xticklabels=labels,
            yticklabels=labels
        )
        # NOTE: the Chinese axis titles may render as boxes unless a CJK font is configured.
        plt.xlabel('预测标签')
        plt.ylabel('真实标签')
        plt.title('混淆矩阵')

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close()

    @staticmethod
    def print_report(y_true: List, y_pred: List, labels: List[str]):
        """Print scikit-learn's per-class classification report, with rows in ``labels`` order."""
        report = classification_report(
            y_true, y_pred,
            labels=labels,  # keep target_names aligned with the reported rows
            target_names=labels,
            zero_division=0
        )
        print(report)
if __name__ == '__main__':
    # Smoke test on a tiny hand-made example: one POLITICS item predicted as TECHNOLOGY.
    labels = ['POLITICS', 'TECHNOLOGY', 'FINANCE']
    y_true = ['POLITICS', 'TECHNOLOGY', 'FINANCE', 'POLITICS', 'TECHNOLOGY']
    y_pred = ['POLITICS', 'TECHNOLOGY', 'FINANCE', 'TECHNOLOGY', 'TECHNOLOGY']
    print(ClassificationMetrics().compute_all(y_true, y_pred, labels))
任务 4.1.6: requirements.txt - 依赖文件
# 机器学习模块依赖
numpy>=1.24.0
pandas>=2.0.0
scikit-learn>=1.3.0
jieba>=0.42.0
joblib>=1.3.0
# 深度学习
torch>=2.0.0
transformers>=4.30.0
# API服务
fastapi>=0.100.0
uvicorn[standard]>=0.23.0
pydantic>=2.0.0
# 数据可视化
matplotlib>=3.7.0
seaborn>=0.12.0
# 工具
python-dotenv>=1.0.0
pyyaml>=6.0
总结
开发顺序建议
1. 第一阶段:基础框架
   - 后端:数据库连接、实体类、基础配置
   - 前端:路由配置、状态管理、API封装
2. 第二阶段:核心功能
   - 爬虫模块(Python)
   - 传统机器学习分类器
   - 后端API接口
   - 前端新闻列表页面
3. 第三阶段:高级功能
   - BERT深度学习分类器
   - 混合策略分类器
   - 前端分类器对比页面
   - 统计图表
4. 第四阶段:完善优化
   - 用户认证
   - 数据可视化
   - 性能优化
   - 异常处理
关键注意事项
- 爬虫模块使用 Python,通过 RESTful API 与 Java 后端通信
- 分类器模块独立部署,提供 HTTP 接口供后端调用
- 前后端分离,使用 JWT 进行身份认证
- 数据库表结构已在 `schema.sql` 中定义,需严格遵守
- API 统一响应格式使用 `Result<T>` 包装