# 新闻文本分类系统 - 模块开发任务清单 > 本文档详细列出每个模块需要完成的具体代码任务,参考现有工程结构。 > 注意:爬虫模块使用 Python 实现,而非 Java。 --- ## 目录 1. [爬虫模块 (Python)](#1-爬虫模块-python) 2. [后端服务模块 (Spring Boot)](#2-后端服务模块-spring-boot) 3. [前端桌面模块 (Tauri + Vue3)](#3-前端桌面模块-tauri--vue3) 4. [机器学习分类模块 (Python)](#4-机器学习分类模块-python) --- ## 1. 爬虫模块 (Python) ### 模块目录结构 ``` crawler-module/ ├── src/ │ ├── __init__.py │ ├── base/ # 基础爬虫框架 │ │ ├── __init__.py │ │ ├── base_crawler.py # 爬虫基类 │ │ ├── http_client.py # HTTP客户端封装 │ │ └── proxy_pool.py # 代理池(可选) │ ├── parsers/ # 解析器 │ │ ├── __init__.py │ │ ├── base_parser.py # 解析器基类 │ │ ├── sina_parser.py # 新浪新闻解析器 │ │ ├── sohu_parser.py # 搜狐新闻解析器 │ │ └── ifeng_parser.py # 凤凰网解析器 │ ├── cleaners/ # 数据清洗 │ │ ├── __init__.py │ │ ├── text_cleaner.py # 文本清洗 │ │ └── deduplicator.py # 去重处理 │ ├── storage/ # 存储层 │ │ ├── __init__.py │ │ ├── database.py # 数据库操作 │ │ └── storage_factory.py # 存储工厂 │ ├── utils/ # 工具类 │ │ ├── __init__.py │ │ ├── user_agent.py # User-Agent池 │ │ └── date_parser.py # 日期解析 │ └── crawler.py # 爬虫主入口 ├── config/ │ ├── __init__.py │ ├── settings.py # 配置文件 │ └── sources.json # 数据源配置 ├── requirements.txt └── main.py ``` ### 1.1 需要完成的具体文件 #### 任务 1.1.1: `config/settings.py` - 爬虫配置文件 ```python """ 爬虫模块配置文件 """ import os from typing import List class CrawlerConfig: """爬虫配置类""" # 项目根目录 BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) # 数据库配置 DB_HOST = os.getenv('DB_HOST', 'localhost') DB_PORT = int(os.getenv('DB_PORT', 3306)) DB_NAME = os.getenv('DB_NAME', 'news_classifier') DB_USER = os.getenv('DB_USER', 'root') DB_PASSWORD = os.getenv('DB_PASSWORD', '') # 爬虫配置 CONCURRENT_REQUESTS = 5 # 并发请求数 DOWNLOAD_DELAY = 1 # 下载延迟(秒) REQUEST_TIMEOUT = 10 # 请求超时(秒) MAX_RETRIES = 3 # 最大重试次数 # User-Agent配置 USER_AGENT_LIST = [ 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36', 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36', ] # 代理配置(可选) PROXY_ENABLED = False PROXY_LIST = [] # 日志配置 LOG_LEVEL = 'INFO' LOG_FILE = os.path.join(BASE_DIR, 'logs', 
'crawler.log') # 存储配置 BATCH_INSERT_SIZE = 100 # 批量插入大小 ``` #### 任务 1.1.2: `config/sources.json` - 数据源配置 ```json { "sources": [ { "name": "新浪新闻", "code": "sina", "enabled": true, "base_url": "https://news.sina.com.cn", "list_url_template": "https://news.sina.com.cn/{category}/index.shtml", "categories": [ {"code": "POLITICS", "url_key": "pol"}, {"code": "FINANCE", "url_key": "finance"}, {"code": "TECHNOLOGY", "url_key": "tech"}, {"code": "SPORTS", "url_key": "sports"}, {"code": "ENTERTAINMENT", "url_key": "ent"} ], "parser": "sina_parser.SinaNewsParser" }, { "name": "搜狐新闻", "code": "sohu", "enabled": true, "base_url": "https://www.sohu.com", "list_url_template": "https://www.sohu.com/{category}", "categories": [ {"code": "POLITICS", "url_key": "politics"}, {"code": "FINANCE", "url_key": "business"}, {"code": "TECHNOLOGY", "url_key": "tech"}, {"code": "SPORTS", "url_key": "sports"} ], "parser": "sohu_parser.SohuNewsParser" } ] } ``` #### 任务 1.1.3: `base/base_crawler.py` - 爬虫基类 ```python """ 爬虫基类 """ from abc import ABC, abstractmethod from typing import List, Dict, Any import logging import time import random logger = logging.getLogger(__name__) class BaseCrawler(ABC): """爬虫基类""" def __init__(self, config: Dict[str, Any]): self.config = config self.name = self.__class__.__name__ logger.info(f"初始化爬虫: {self.name}") @abstractmethod def fetch_news_list(self, category: str, page: int = 1) -> List[Dict[str, Any]]: """ 获取新闻列表 :param category: 新闻类别 :param page: 页码 :return: 新闻URL列表 """ pass @abstractmethod def fetch_news_detail(self, url: str) -> Dict[str, Any]: """ 获取新闻详情 :param url: 新闻URL :return: 新闻详情字典 """ pass def crawl(self, categories: List[str], max_pages: int = 5) -> List[Dict[str, Any]]: """ 执行爬取任务 :param categories: 类别列表 :param max_pages: 最大页数 :return: 爬取的新闻列表 """ results = [] for category in categories: for page in range(1, max_pages + 1): try: news_list = self.fetch_news_list(category, page) for news_url in news_list: try: detail = 
self.fetch_news_detail(news_url) if detail: results.append(detail) except Exception as e: logger.error(f"解析新闻详情失败: {news_url}, 错误: {e}") # 随机延迟,避免请求过快 time.sleep(random.uniform(1, 3)) except Exception as e: logger.error(f"爬取失败: category={category}, page={page}, 错误: {e}") return results ``` #### 任务 1.1.4: `parsers/base_parser.py` - 解析器基类 ```python """ 解析器基类 """ from abc import ABC, abstractmethod from typing import Dict, Any, Optional from datetime import datetime class BaseParser(ABC): """新闻解析器基类""" @abstractmethod def parse_news_list(self, html: str) -> list[str]: """ 解析新闻列表页,获取新闻URL :param html: HTML内容 :return: 新闻URL列表 """ pass @abstractmethod def parse_news_detail(self, html: str, url: str) -> Optional[Dict[str, Any]]: """ 解析新闻详情页 :param html: HTML内容 :param url: 新闻URL :return: 解析后的新闻字典 """ pass def clean_html(self, html: str) -> str: """清理HTML标签""" from bs4 import BeautifulSoup soup = BeautifulSoup(html, 'html.parser') return soup.get_text(separator=' ', strip=True) def parse_publish_time(self, time_str: str) -> Optional[datetime]: """解析发布时间""" # 实现时间解析逻辑 pass ``` #### 任务 1.1.5: `parsers/sina_parser.py` - 新浪新闻解析器 ```python """ 新浪新闻解析器 """ from typing import Dict, Any, Optional from bs4 import BeautifulSoup import requests from .base_parser import BaseParser class SinaNewsParser(BaseParser): """新浪新闻解析器""" def __init__(self): self.base_url = "https://news.sina.com.cn" def parse_news_list(self, html: str) -> list[str]: """解析新浪新闻列表""" soup = BeautifulSoup(html, 'html.parser') urls = [] # 根据新浪新闻的实际HTML结构解析 for item in soup.select('.news-item'): link = item.select_one('a') if link and link.get('href'): urls.append(link['href']) return urls def parse_news_detail(self, html: str, url: str) -> Optional[Dict[str, Any]]: """解析新浪新闻详情""" soup = BeautifulSoup(html, 'html.parser') # 提取标题 title = soup.select_one('h1.main-title') title = title.get_text(strip=True) if title else '' # 提取内容 content = soup.select_one('.article-content') content = self.clean_html(str(content)) if 
content else '' # 提取来源 source = soup.select_one('.source') source = source.get_text(strip=True) if source else '新浪新闻' # 提取发布时间 publish_time = soup.select_one('.date') publish_time = publish_time.get_text(strip=True) if publish_time else None # 提取作者 author = soup.select_one('.author') author = author.get_text(strip=True) if author else '' return { 'title': title, 'content': content, 'summary': content[:200] if content else '', 'source': source, 'source_url': url, 'author': author, 'publish_time': self.parse_publish_time(publish_time) if publish_time else None } ``` #### 任务 1.1.6: `cleaners/text_cleaner.py` - 文本清洗 ```python """ 文本清洗工具 """ import re from typing import List class TextCleaner: """文本清洗器""" # 无效字符模式 INVALID_CHARS = r'[\x00-\x08\x0b-\x0c\x0e-\x1f\x7f-\x9f]' # 无意义词(停用词) STOP_WORDS = set([ '的', '了', '在', '是', '我', '有', '和', '就', '不', '人', '都', '一', '一个', '上', '也', '很', '到', '说', '要', '去' ]) @classmethod def clean_title(cls, title: str) -> str: """清洗标题""" if not title: return '' # 移除无效字符 title = re.sub(cls.INVALID_CHARS, '', title) # 移除多余空格 title = ' '.join(title.split()) return title.strip() @classmethod def clean_content(cls, content: str) -> str: """清洗内容""" if not content: return '' # 移除HTML标签 content = re.sub(r'<[^>]+>', '', content) # 移除无效字符 content = re.sub(cls.INVALID_CHARS, '', content) # 移除多余空白 content = ' '.join(content.split()) # 移除过短段落 paragraphs = content.split('。') paragraphs = [p.strip() for p in paragraphs if len(p.strip()) > 10] return '。'.join(paragraphs) @classmethod def extract_summary(cls, content: str, max_length: int = 200) -> str: """提取摘要""" if not content: return '' # 取前N个字符作为摘要 summary = content[:max_length] # 确保在句子边界截断 last_period = summary.rfind('。') if last_period > max_length * 0.7: return summary[:last_period + 1] return summary + '...' 
``` #### 任务 1.1.7: `cleaners/deduplicator.py` - 去重处理 ```python """ 新闻去重处理 """ import hashlib from typing import Set, Dict, Any class NewsDeduplicator: """新闻去重器""" def __init__(self): self.seen_hashes: Set[str] = set() self.seen_urls: Set[str] = set() def compute_hash(self, title: str, content: str) -> str: """计算新闻内容的哈希值""" text = f"{title}|{content[:500]}" # 使用标题和前500字符 return hashlib.md5(text.encode('utf-8')).hexdigest() def is_duplicate(self, news: Dict[str, Any]) -> bool: """ 判断是否重复 :param news: 新闻字典 :return: True表示重复 """ # URL去重 if news.get('source_url') in self.seen_urls: return True # 内容去重 content_hash = self.compute_hash( news.get('title', ''), news.get('content', '') ) if content_hash in self.seen_hashes: return True # 记录 self.seen_urls.add(news.get('source_url', '')) self.seen_hashes.add(content_hash) return False def deduplicate_batch(self, news_list: list[Dict[str, Any]]) -> list[Dict[str, Any]]: """批量去重""" return [news for news in news_list if not self.is_duplicate(news)] ``` #### 任务 1.1.8: `storage/database.py` - 数据库存储 ```python """ 数据库存储层 """ import pymysql from typing import List, Dict, Any, Optional from datetime import datetime import logging logger = logging.getLogger(__name__) class NewsStorage: """新闻数据库存储""" def __init__(self, config: Dict[str, Any]): self.config = config self.connection = None self.connect() def connect(self): """连接数据库""" try: self.connection = pymysql.connect( host=self.config.get('DB_HOST', 'localhost'), port=self.config.get('DB_PORT', 3306), user=self.config.get('DB_USER', 'root'), password=self.config.get('DB_PASSWORD', ''), database=self.config.get('DB_NAME', 'news_classifier'), charset='utf8mb4', cursorclass=pymysql.cursors.DictCursor ) logger.info("数据库连接成功") except Exception as e: logger.error(f"数据库连接失败: {e}") raise def close(self): """关闭连接""" if self.connection: self.connection.close() def save_news(self, news: Dict[str, Any], category_code: str) -> Optional[int]: """ 保存单条新闻 :return: 插入的新闻ID """ try: with 
self.connection.cursor() as cursor: sql = """ INSERT INTO news (title, content, summary, source, source_url, author, category_code, publish_time, status) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s) """ cursor.execute(sql, ( news.get('title'), news.get('content'), news.get('summary'), news.get('source'), news.get('source_url'), news.get('author'), category_code, news.get('publish_time'), 1 # status: 已发布 )) self.connection.commit() return cursor.lastrowid except Exception as e: logger.error(f"保存新闻失败: {e}") self.connection.rollback() return None def batch_save_news(self, news_list: List[Dict[str, Any]], category_code: str) -> int: """ 批量保存新闻 :return: 成功保存的数量 """ count = 0 for news in news_list: if self.save_news(news, category_code): count += 1 return count def news_exists(self, source_url: str) -> bool: """检查新闻是否已存在""" try: with self.connection.cursor() as cursor: sql = "SELECT id FROM news WHERE source_url = %s LIMIT 1" cursor.execute(sql, (source_url,)) return cursor.fetchone() is not None except Exception as e: logger.error(f"检查新闻存在性失败: {e}") return False ``` #### 任务 1.1.9: `crawler.py` - 爬虫主入口 ```python """ 爬虫主程序入口 """ import argparse import logging from config.settings import CrawlerConfig from storage.database import NewsStorage from cleaners.text_cleaner import TextCleaner from cleaners.deduplicator import NewsDeduplicator from parsers.sina_parser import SinaNewsParser from parsers.sohu_parser import SohuNewsParser def setup_logging(): """配置日志""" logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', handlers=[ logging.FileHandler('logs/crawler.log', encoding='utf-8'), logging.StreamHandler() ] ) def main(): """主函数""" parser = argparse.ArgumentParser(description='新闻爬虫') parser.add_argument('--source', type=str, help='数据源代码') parser.add_argument('--category', type=str, help='新闻类别') parser.add_argument('--pages', type=int, default=5, help='爬取页数') args = parser.parse_args() setup_logging() logger = 
logging.getLogger(__name__) # 初始化组件 storage = NewsStorage(CrawlerConfig.__dict__) cleaner = TextCleaner() deduplicator = NewsDeduplicator() # 选择解析器 parser_map = { 'sina': SinaNewsParser(), 'sohu': SohuNewsParser() } selected_parser = parser_map.get(args.source) if not selected_parser: logger.error(f"不支持的数据源: {args.source}") return logger.info(f"开始爬取: source={args.source}, category={args.category}, pages={args.pages}") # 执行爬取 # ... 具体爬取逻辑 logger.info("爬取完成") if __name__ == '__main__': main() ``` #### 任务 1.1.10: `requirements.txt` - 依赖文件 ```txt # 爬虫模块依赖 requests>=2.31.0 beautifulsoup4>=4.12.0 lxml>=4.9.0 pymysql>=1.1.0 python-dotenv>=1.0.0 ``` --- ## 2. 后端服务模块 (Spring Boot) ### 模块目录结构 ``` backend/src/main/java/com/newsclassifier/ ├── controller/ # 控制器层 │ ├── AuthController.java │ ├── NewsController.java │ ├── CategoryController.java │ ├── ClassifierController.java │ └── AdminController.java ├── service/ # 服务层 │ ├── AuthService.java │ ├── NewsService.java │ ├── CategoryService.java │ ├── ClassifierService.java │ └── impl/ │ ├── AuthServiceImpl.java │ ├── NewsServiceImpl.java │ └── ClassifierServiceImpl.java ├── mapper/ # MyBatis Mapper │ ├── UserMapper.java │ ├── NewsMapper.java │ └── CategoryMapper.java ├── entity/ # 实体类 │ ├── User.java (已完成) │ ├── News.java (已完成) │ └── NewsCategory.java ├── dto/ # 数据传输对象 │ ├── LoginDTO.java │ ├── RegisterDTO.java │ ├── NewsQueryDTO.java │ ├── ClassificationResultDTO.java │ └── PageResult.java ├── vo/ # 视图对象 │ ├── UserVO.java │ ├── NewsVO.java │ └── CategoryVO.java ├── common/ # 公共类 │ ├── Result.java (已完成) │ ├── PageRequest.java │ └── PageResponse.java ├── config/ # 配置类 │ ├── SecurityConfig.java │ ├── CorsConfig.java │ ├── MyBatisConfig.java │ └── AsyncConfig.java (已完成) ├── security/ # 安全认证 │ ├── JwtTokenProvider.java │ ├── JwtAuthenticationFilter.java │ ├── UserDetailsServiceImpl.java │ └── PasswordEncoder.java ├── classifier/ # 文本分类器 │ ├── IClassifier.java (已完成) │ ├── ClassificationResult.java │ ├── TraditionalMLClassifier.java 
(已完成) │ ├── BERTClassifier.java │ └── HybridClassifier.java (已完成) ├── exception/ # 异常处理 │ ├── GlobalExceptionHandler.java │ ├── BusinessException.java │ └── ErrorCode.java ├── util/ # 工具类 │ ├── JwtUtil.java │ ├── DateUtil.java │ └── ValidationUtil.java └── NewsClassifierApplication.java (已完成) ``` ### 2.1 需要完成的具体文件 #### 任务 2.1.1: `security/JwtTokenProvider.java` - JWT令牌提供者 ```java package com.newsclassifier.security; import io.jsonwebtoken.*; import io.jsonwebtoken.security.Keys; import org.springframework.beans.factory.annotation.Value; import org.springframework.security.core.Authentication; import org.springframework.stereotype.Component; import javax.crypto.SecretKey; import java.util.Date; /** * JWT令牌提供者 */ @Component public class JwtTokenProvider { @Value("${jwt.secret}") private String jwtSecret; @Value("${jwt.expiration:86400000}") // 默认24小时 private long jwtExpiration; private SecretKey getSigningKey() { return Keys.hmacShaKeyFor(jwtSecret.getBytes()); } /** * 生成JWT令牌 */ public String generateToken(Authentication authentication) { String username = authentication.getName(); Date now = new Date(); Date expiryDate = new Date(now.getTime() + jwtExpiration); return Jwts.builder() .subject(username) .issuedAt(now) .expiration(expiryDate) .signWith(getSigningKey()) .compact(); } /** * 从令牌获取用户名 */ public String getUsernameFromToken(String token) { Claims claims = Jwts.parser() .verifyWith(getSigningKey()) .build() .parseSignedClaims(token) .getPayload(); return claims.getSubject(); } /** * 验证令牌 */ public boolean validateToken(String token) { try { Jwts.parser() .verifyWith(getSigningKey()) .build() .parseSignedClaims(token); return true; } catch (JwtException ex) { // 日志记录 } return false; } } ``` #### 任务 2.1.2: `security/SecurityConfig.java` - 安全配置 ```java package com.newsclassifier.config; import com.newsclassifier.security.JwtAuthenticationFilter; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; 
import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; import org.springframework.security.config.http.SessionCreationPolicy; import org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder; import org.springframework.security.crypto.password.PasswordEncoder; import org.springframework.security.web.SecurityFilterChain; import org.springframework.security.web.authentication.UsernamePasswordAuthenticationFilter; /** * Spring Security配置 */ @Configuration @EnableWebSecurity public class SecurityConfig { private final JwtAuthenticationFilter jwtAuthenticationFilter; public SecurityConfig(JwtAuthenticationFilter jwtAuthenticationFilter) { this.jwtAuthenticationFilter = jwtAuthenticationFilter; } @Bean public SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception { http .csrf(csrf -> csrf.disable()) .sessionManagement(session -> session.sessionCreationPolicy(SessionCreationPolicy.STATELESS) ) .authorizeHttpRequests(auth -> auth .requestMatchers("/api/auth/**").permitAll() .requestMatchers("/api/doc.html", "/api/swagger/**").permitAll() .anyRequest().authenticated() ) .addFilterBefore(jwtAuthenticationFilter, UsernamePasswordAuthenticationFilter.class); return http.build(); } @Bean public PasswordEncoder passwordEncoder() { return new BCryptPasswordEncoder(); } } ``` #### 任务 2.1.3: `controller/AuthController.java` - 认证控制器 ```java package com.newsclassifier.controller; import com.newsclassifier.common.Result; import com.newsclassifier.dto.LoginDTO; import com.newsclassifier.dto.RegisterDTO; import com.newsclassifier.service.AuthService; import com.newsclassifier.vo.UserVO; import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.tags.Tag; import jakarta.validation.Valid; import lombok.RequiredArgsConstructor; import org.springframework.web.bind.annotation.*; /** * 认证控制器 */ @Tag(name = "认证接口") 
@RestController @RequestMapping("/api/auth") @RequiredArgsConstructor public class AuthController { private final AuthService authService; @Operation(summary = "用户登录") @PostMapping("/login") public Result<UserVO> login(@Valid @RequestBody LoginDTO loginDTO) { UserVO userVO = authService.login(loginDTO); return Result.success(userVO); } @Operation(summary = "用户注册") @PostMapping("/register") public Result<Void> register(@Valid @RequestBody RegisterDTO registerDTO) { authService.register(registerDTO); return Result.success(); } @Operation(summary = "刷新令牌") @PostMapping("/refresh") public Result<String> refreshToken(@RequestHeader("Authorization") String token) { String newToken = authService.refreshToken(token); return Result.success(newToken); } @Operation(summary = "用户登出") @PostMapping("/logout") public Result<Void> logout() { return Result.success(); } } ``` #### 任务 2.1.4: `controller/NewsController.java` - 新闻控制器 ```java package com.newsclassifier.controller; import com.newsclassifier.common.PageResponse; import com.newsclassifier.common.Result; import com.newsclassifier.dto.NewsQueryDTO; import com.newsclassifier.service.NewsService; import com.newsclassifier.vo.NewsVO; import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.tags.Tag; import lombok.RequiredArgsConstructor; import org.springframework.web.bind.annotation.*; /** * 新闻控制器 */ @Tag(name = "新闻接口") @RestController @RequestMapping("/api/news") @RequiredArgsConstructor public class NewsController { private final NewsService newsService; @Operation(summary = "分页查询新闻") @GetMapping("/page") public Result<PageResponse<NewsVO>> getNewsPage(NewsQueryDTO queryDTO) { PageResponse<NewsVO> page = newsService.getNewsPage(queryDTO); return Result.success(page); } @Operation(summary = "获取新闻详情") @GetMapping("/{id}") public Result<NewsVO> getNewsDetail(@PathVariable Long id) { NewsVO newsVO = newsService.getNewsDetail(id); return Result.success(newsVO); } @Operation(summary = "搜索新闻") @GetMapping("/search") public Result<PageResponse<NewsVO>> searchNews( @RequestParam String 
keyword, @RequestParam(defaultValue = "1") Integer page, @RequestParam(defaultValue = "20") Integer size ) { PageResponse<NewsVO> result = newsService.searchNews(keyword, page, size); return Result.success(result); } @Operation(summary = "手动分类新闻") @PostMapping("/{id}/classify") public Result<Void> manualClassify( @PathVariable Long id, @RequestParam Long categoryId ) { newsService.manualClassify(id, categoryId); return Result.success(); } } ``` #### 任务 2.1.5: `controller/ClassifierController.java` - 分类控制器 ```java package com.newsclassifier.controller; import com.newsclassifier.common.Result; import com.newsclassifier.dto.ClassifyRequestDTO; import com.newsclassifier.service.ClassifierService; import io.swagger.v3.oas.annotations.Operation; import io.swagger.v3.oas.annotations.tags.Tag; import lombok.RequiredArgsConstructor; import org.springframework.web.bind.annotation.*; /** * 文本分类控制器 */ @Tag(name = "分类接口") @RestController @RequestMapping("/api/classifier") @RequiredArgsConstructor public class ClassifierController { private final ClassifierService classifierService; @Operation(summary = "对单条新闻进行分类") @PostMapping("/classify") public Result<ClassificationResultDTO> classify(@RequestBody ClassifyRequestDTO request) { ClassificationResultDTO result = classifierService.classify( request.getTitle(), request.getContent(), request.getMode() ); return Result.success(result); } @Operation(summary = "批量分类") @PostMapping("/batch-classify") public Result<BatchClassifyResultDTO> batchClassify( @RequestBody BatchClassifyRequestDTO request ) { BatchClassifyResultDTO result = classifierService.batchClassify( request.getNewsIds(), request.getMode() ); return Result.success(result); } @Operation(summary = "获取分类器状态") @GetMapping("/status") public Result<ClassifierStatusDTO> getStatus() { ClassifierStatusDTO status = classifierService.getStatus(); return Result.success(status); } } ``` #### 任务 2.1.6: `service/impl/NewsServiceImpl.java` - 新闻服务实现 ```java package com.newsclassifier.service.impl; import com.baomidou.mybatisplus.core.conditions.query.LambdaQueryWrapper; 
import com.baomidou.mybatisplus.extension.plugins.pagination.Page; import com.newsclassifier.common.PageResponse; import com.newsclassifier.dto.NewsQueryDTO; import com.newsclassifier.entity.News; import com.newsclassifier.mapper.NewsMapper; import com.newsclassifier.service.NewsService; import com.newsclassifier.vo.NewsVO; import lombok.RequiredArgsConstructor; import org.springframework.stereotype.Service; /** * 新闻服务实现 */ @Service @RequiredArgsConstructor public class NewsServiceImpl implements NewsService { private final NewsMapper newsMapper; @Override public PageResponse<NewsVO> getNewsPage(NewsQueryDTO queryDTO) { Page<News> page = new Page<>(queryDTO.getPage(), queryDTO.getSize()); LambdaQueryWrapper<News> wrapper = new LambdaQueryWrapper<News>() .eq(queryDTO.getCategoryId() != null, News::getCategoryId, queryDTO.getCategoryId()) .eq(queryDTO.getCategoryCode() != null, News::getCategoryCode, queryDTO.getCategoryCode()) .eq(queryDTO.getStatus() != null, News::getStatus, queryDTO.getStatus()) .like(queryDTO.getKeyword() != null, News::getTitle, queryDTO.getKeyword()) .orderByDesc(News::getPublishTime); Page<News> resultPage = newsMapper.selectPage(page, wrapper); // 转换为VO List<NewsVO> voList = resultPage.getRecords().stream() .map(this::convertToVO) .collect(Collectors.toList()); return PageResponse.of(resultPage.getTotal(), voList); } @Override public NewsVO getNewsDetail(Long id) { News news = newsMapper.selectById(id); if (news == null) { throw new BusinessException(ErrorCode.NEWS_NOT_FOUND); } // 增加浏览次数 newsMapper.addViewCount(id); return convertToVO(news); } private NewsVO convertToVO(News news) { // 实现Entity到VO的转换 return new NewsVO(); } } ``` #### 任务 2.1.7: `classifier/ClassificationResult.java` - 分类结果类 ```java package com.newsclassifier.classifier; import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Data; import lombok.NoArgsConstructor; import java.math.BigDecimal; /** * 分类结果 */ @Data @Builder @NoArgsConstructor @AllArgsConstructor public class ClassificationResult { /** * 
分类代码 */ private String categoryCode; /** * 分类名称 */ private String categoryName; /** * 置信度 0-1 */ private BigDecimal confidence; /** * 分类器类型 */ private String classifierType; /** * 各类别概率分布 */ private java.util.Map probabilities; /** * 耗时(毫秒) */ private Long duration; } ``` #### 任务 2.1.8: `classifier/BERTClassifier.java` - BERT分类器 ```java package com.newsclassifier.classifier; import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Component; import java.io.BufferedReader; import java.io.InputStreamReader; import java.math.BigDecimal; import java.util.HashMap; import java.util.Map; /** * BERT文本分类器 * 通过调用Python服务实现 */ @Slf4j @Component public class BERTClassifier implements IClassifier { @Value("${classifier.bert.service-url:http://localhost:5000/api/predict}") private String bertServiceUrl; @Value("${classifier.bert.timeout:5000}") private int timeout; @Override public ClassificationResult classify(String title, String content) { long startTime = System.currentTimeMillis(); try { // 调用Python BERT服务 String result = callBERTService(title, content); // 解析结果 return parseResult(result); } catch (Exception e) { log.error("BERT分类失败", e); // 返回默认结果或降级处理 return getDefaultResult(); } finally { long duration = System.currentTimeMillis() - startTime; log.info("BERT分类耗时: {}ms", duration); } } @Override public String getType() { return "DL"; } private String callBERTService(String title, String content) { // 使用HttpClient调用Python服务 // 实现HTTP请求逻辑 return ""; } private ClassificationResult parseResult(String jsonResponse) { // 解析JSON响应 return ClassificationResult.builder() .categoryCode("TECHNOLOGY") .categoryName("科技") .confidence(new BigDecimal("0.95")) .classifierType("DL") .build(); } } ``` #### 任务 2.1.9: `exception/GlobalExceptionHandler.java` - 全局异常处理 ```java package com.newsclassifier.exception; import com.newsclassifier.common.Result; import lombok.extern.slf4j.Slf4j; import 
org.springframework.http.HttpStatus; import org.springframework.validation.BindException; import org.springframework.web.bind.annotation.ExceptionHandler; import org.springframework.web.bind.annotation.ResponseStatus; import org.springframework.web.bind.annotation.RestControllerAdvice; /** * 全局异常处理器 */ @Slf4j @RestControllerAdvice public class GlobalExceptionHandler { @ExceptionHandler(BusinessException.class) @ResponseStatus(HttpStatus.OK) public Result handleBusinessException(BusinessException e) { log.error("业务异常: {}", e.getMessage()); return Result.error(e.getErrorCode(), e.getMessage()); } @ExceptionHandler(BindException.class) @ResponseStatus(HttpStatus.BAD_REQUEST) public Result handleBindException(BindException e) { String message = e.getBindingResult().getAllErrors().get(0).getDefaultMessage(); return Result.error(400, message); } @ExceptionHandler(Exception.class) @ResponseStatus(HttpStatus.INTERNAL_SERVER_ERROR) public Result handleException(Exception e) { log.error("系统异常", e); return Result.error(500, "系统内部错误"); } } ``` #### 任务 2.1.10: `application.yml` - 应用配置文件 ```yaml spring: application: name: news-classifier datasource: driver-class-name: com.mysql.cj.jdbc.Driver url: jdbc:mysql://localhost:3306/news_classifier?useUnicode=true&characterEncoding=utf8mb4&serverTimezone=Asia/Shanghai username: root password: your_password data: redis: host: localhost port: 6379 database: 0 # MyBatis-Plus配置 mybatis-plus: mapper-locations: classpath:mapper/**/*.xml type-aliases-package: com.newsclassifier.entity configuration: map-underscore-to-camel-case: true log-impl: org.apache.ibatis.logging.stdout.StdOutImpl global-config: db-config: id-type: auto logic-delete-field: deleted logic-delete-value: 1 logic-not-delete-value: 0 # JWT配置 jwt: secret: your-secret-key-at-least-256-bits-long-for-hs256-algorithm expiration: 86400000 # 分类器配置 classifier: mode: hybrid # traditional, deep_learning, hybrid confidence: threshold: 0.75 hybrid-min: 0.6 bert: service-url: 
http://localhost:5000/api/predict timeout: 5000 # 日志配置 logging: level: com.newsclassifier: debug pattern: console: "%d{yyyy-MM-dd HH:mm:ss} [%thread] %-5level %logger{36} - %msg%n" ``` --- ## 3. 前端桌面模块 (Tauri + Vue3) ### 模块目录结构 ``` client/src/ ├── api/ # API接口 │ ├── index.ts │ ├── auth.ts │ ├── news.ts │ ├── category.ts │ └── classifier.ts ├── assets/ # 静态资源 │ ├── images/ │ ├── styles/ │ │ ├── main.css │ │ └── tailwind.css │ └── fonts/ ├── components/ # 组件 │ ├── ui/ # 基础UI组件 │ │ ├── button/ │ │ ├── input/ │ │ ├── dialog/ │ │ └── table/ │ ├── layout/ # 布局组件 │ │ ├── Header.vue │ │ ├── Sidebar.vue │ │ └── Footer.vue │ ├── news/ # 新闻相关组件 │ │ ├── NewsCard.vue │ │ ├── NewsList.vue │ │ ├── NewsDetail.vue │ │ └── CategoryFilter.vue │ └── charts/ # 图表组件 │ ├── CategoryChart.vue │ └── TrendChart.vue ├── composables/ # 组合式函数 │ ├── useAuth.ts │ ├── useNews.ts │ ├── useClassifier.ts │ └── useToast.ts ├── layouts/ # 布局 │ ├── DefaultLayout.vue │ ├── AuthLayout.vue │ └── EmptyLayout.vue ├── router/ # 路由 (部分完成) │ └── index.ts ├── stores/ # 状态管理 (部分完成) │ ├── user.ts │ ├── news.ts │ └── category.ts ├── types/ # TypeScript类型 │ ├── api.d.ts │ ├── news.d.ts │ └── user.d.ts ├── utils/ # 工具函数 │ ├── request.ts │ ├── storage.ts │ ├── format.ts │ └── validate.ts ├── views/ # 页面 │ ├── auth/ │ │ ├── Login.vue │ │ └── Register.vue │ ├── news/ │ │ ├── NewsList.vue │ │ ├── NewsDetail.vue │ │ └── NewsSearch.vue │ ├── category/ │ │ ├── CategoryManage.vue │ │ └── CategoryStats.vue │ ├── classifier/ │ │ ├── ClassifierPage.vue │ │ └── ModelCompare.vue │ └── admin/ │ ├── Dashboard.vue │ ├── UserManage.vue │ └── SystemLog.vue ├── App.vue └── main.ts ``` ### 3.1 需要完成的具体文件 #### 任务 3.1.1: `utils/request.ts` - HTTP请求封装 ```typescript import axios, { AxiosInstance, AxiosRequestConfig, AxiosResponse } from 'axios' // 响应数据类型 interface ApiResponse { code: number message: string data: T } // 创建axios实例 const service: AxiosInstance = axios.create({ baseURL: import.meta.env.VITE_API_BASE_URL || 
'http://localhost:8080/api', timeout: 15000, headers: { 'Content-Type': 'application/json' } }) // 请求拦截器 service.interceptors.request.use( (config) => { const token = localStorage.getItem('token') if (token) { config.headers.Authorization = `Bearer ${token}` } return config }, (error) => { return Promise.reject(error) } ) // 响应拦截器 service.interceptors.response.use( (response: AxiosResponse) => { const { code, message, data } = response.data if (code === 200) { return data } else { // 处理错误 return Promise.reject(new Error(message || '请求失败')) } }, (error) => { // 处理HTTP错误 return Promise.reject(error) } ) // 封装请求方法 export const http = { get<T = any>(url: string, config?: AxiosRequestConfig): Promise<T> { return service.get(url, config) }, post<T = any>(url: string, data?: any, config?: AxiosRequestConfig): Promise<T> { return service.post(url, data, config) }, put<T = any>(url: string, data?: any, config?: AxiosRequestConfig): Promise<T> { return service.put(url, data, config) }, delete<T = any>(url: string, config?: AxiosRequestConfig): Promise<T> { return service.delete(url, config) } } export default service ``` #### 任务 3.1.2: `api/news.ts` - 新闻API ```typescript import { http } from '@/utils/request' // 新闻查询参数 export interface NewsQueryParams { page?: number size?: number categoryId?: number categoryCode?: string keyword?: string status?: number } // 新闻详情 export interface NewsDetail { id: number title: string content: string summary: string source: string sourceUrl: string author: string categoryId: number categoryCode: string coverImage: string publishTime: string viewCount: number likeCount: number commentCount: number classifierType: string confidence: number } // 分页响应 export interface PageResponse<T> { total: number records: T[] current: number size: number } // 新闻API export const newsApi = { // 分页查询 getNewsPage(params: NewsQueryParams): Promise<PageResponse<NewsDetail>> { return http.get('/news/page', { params }) }, // 获取详情 getNewsDetail(id: number): Promise<NewsDetail> { return http.get(`/news/${id}`) }, // 搜索新闻 searchNews(keyword: string, page = 1, 
size = 20): Promise> { return http.get('/news/search', { params: { keyword, page, size } }) }, // 手动分类 manualClassify(id: number, categoryId: number): Promise { return http.post(`/news/${id}/classify`, null, { params: { categoryId } }) } } ``` #### 任务 3.1.3: `composables/useNews.ts` - 新闻组合式函数 ```typescript import { ref, computed } from 'vue' import { newsApi, type NewsQueryParams, type NewsDetail, type PageResponse } from '@/api/news' export function useNews() { const loading = ref(false) const newsList = ref([]) const total = ref(0) const currentNews = ref(null) // 分页查询 const fetchNewsPage = async (params: NewsQueryParams) => { loading.value = true try { const result: PageResponse = await newsApi.getNewsPage(params) newsList.value = result.records total.value = result.total } catch (error) { console.error('获取新闻列表失败:', error) throw error } finally { loading.value = false } } // 获取详情 const fetchNewsDetail = async (id: number) => { loading.value = true try { currentNews.value = await newsApi.getNewsDetail(id) return currentNews.value } catch (error) { console.error('获取新闻详情失败:', error) throw error } finally { loading.value = false } } // 搜索新闻 const searchNews = async (keyword: string, page = 1, size = 20) => { loading.value = true try { const result: PageResponse = await newsApi.searchNews(keyword, page, size) newsList.value = result.records total.value = result.total } catch (error) { console.error('搜索新闻失败:', error) throw error } finally { loading.value = false } } return { loading: computed(() => loading.value), newsList: computed(() => newsList.value), total: computed(() => total.value), currentNews: computed(() => currentNews.value), fetchNewsPage, fetchNewsDetail, searchNews } } ``` #### 任务 3.1.4: `views/news/NewsList.vue` - 新闻列表页面 ```vue ``` #### 任务 3.1.5: `views/classifier/ClassifierPage.vue` - 分类器页面 ```vue ``` #### 任务 3.1.6: `router/index.ts` - 路由配置 (更新) ```typescript import { createRouter, createWebHistory } from 'vue-router' import type { RouteRecordRaw } 
from 'vue-router' const routes: RouteRecordRaw[] = [ { path: '/login', name: 'Login', component: () => import('@/views/auth/Login.vue'), meta: { layout: 'EmptyLayout' } }, { path: '/', name: 'Home', component: () => import('@/views/news/NewsList.vue'), meta: { requiresAuth: true } }, { path: '/news', name: 'NewsList', component: () => import('@/views/news/NewsList.vue'), meta: { requiresAuth: true } }, { path: '/news/:id', name: 'NewsDetail', component: () => import('@/views/news/NewsDetail.vue'), meta: { requiresAuth: true } }, { path: '/classifier', name: 'Classifier', component: () => import('@/views/classifier/ClassifierPage.vue'), meta: { requiresAuth: true } }, { path: '/category', name: 'CategoryStats', component: () => import('@/views/category/CategoryStats.vue'), meta: { requiresAuth: true } }, { path: '/admin', name: 'AdminDashboard', component: () => import('@/views/admin/Dashboard.vue'), meta: { requiresAuth: true, requiresAdmin: true } } ] const router = createRouter({ history: createWebHistory(), routes }) // 路由守卫 router.beforeEach((to, from, next) => { const token = localStorage.getItem('token') if (to.meta.requiresAuth && !token) { next('/login') } else if (to.meta.requiresAdmin) { // 检查管理员权限 const userRole = localStorage.getItem('userRole') if (userRole !== 'ADMIN') { next('/') } else { next() } } else { next() } }) export default router ``` --- ## 4. 
机器学习分类模块 (Python) ### 模块目录结构 ``` ml-module/ ├── data/ │ ├── raw/ # 原始数据 │ ├── processed/ # 处理后的数据 │ │ ├── training_data.csv │ │ └── test_data.csv │ └── external/ # 外部数据集 ├── models/ # 训练好的模型 │ ├── traditional/ │ │ ├── nb_vectorizer.pkl │ │ ├── nb_classifier.pkl │ │ ├── svm_vectorizer.pkl │ │ └── svm_classifier.pkl │ ├── deep_learning/ │ │ └── bert_finetuned/ │ └── hybrid/ │ └── config.json ├── src/ │ ├── __init__.py │ ├── traditional/ # 传统机器学习 │ │ ├── __init__.py │ │ ├── train_model.py # (已有) │ │ ├── predict.py │ │ └── evaluate.py │ ├── deep_learning/ # 深度学习 │ │ ├── __init__.py │ │ ├── bert_model.py │ │ ├── train_bert.py │ │ └── predict_bert.py │ ├── hybrid/ # 混合策略 │ │ ├── __init__.py │ │ ├── hybrid_classifier.py │ │ └── rule_engine.py │ ├── utils/ │ │ ├── __init__.py │ │ ├── preprocessing.py # 数据预处理 │ │ └── metrics.py # 评估指标 │ └── api/ # API服务 │ ├── __init__.py │ └── server.py # FastAPI服务 ├── notebooks/ # Jupyter notebooks │ ├── data_exploration.ipynb │ └── model_comparison.ipynb ├── tests/ # 测试 │ ├── test_traditional.py │ ├── test_bert.py │ └── test_hybrid.py ├── requirements.txt ├── setup.py └── README.md ``` ### 4.1 需要完成的具体文件 #### 任务 4.1.1: `src/traditional/predict.py` - 传统模型预测 ```python """ 传统机器学习模型预测 """ import os import joblib import jieba from typing import Dict, Any # 分类映射 CATEGORY_MAP = { 'POLITICS': '时政', 'FINANCE': '财经', 'TECHNOLOGY': '科技', 'SPORTS': '体育', 'ENTERTAINMENT': '娱乐', 'HEALTH': '健康', 'EDUCATION': '教育', 'LIFE': '生活', 'INTERNATIONAL': '国际', 'MILITARY': '军事' } class TraditionalPredictor: """传统机器学习预测器""" def __init__(self, model_type='nb', model_dir='../../models/traditional'): self.model_type = model_type self.model_dir = model_dir self.vectorizer = None self.classifier = None self._load_model() def _load_model(self): """加载模型""" vectorizer_path = os.path.join(self.model_dir, f'{self.model_type}_vectorizer.pkl') classifier_path = os.path.join(self.model_dir, f'{self.model_type}_classifier.pkl') self.vectorizer = joblib.load(vectorizer_path) 
self.classifier = joblib.load(classifier_path) print(f"模型加载成功: {self.model_type}") def preprocess(self, title: str, content: str) -> str: """预处理文本""" text = title + ' ' + content # jieba分词 words = jieba.cut(text) return ' '.join(words) def predict(self, title: str, content: str) -> Dict[str, Any]: """ 预测 :return: 预测结果字典 """ # 预处理 processed = self.preprocess(title, content) # 特征提取 tfidf = self.vectorizer.transform([processed]) # 预测 prediction = self.classifier.predict(tfidf)[0] probabilities = self.classifier.predict_proba(tfidf)[0] # 获取各类别概率 prob_dict = {} for i, prob in enumerate(probabilities): category_code = self.classifier.classes_[i] prob_dict[category_code] = float(prob) return { 'categoryCode': prediction, 'categoryName': CATEGORY_MAP.get(prediction, '未知'), 'confidence': float(probabilities.max()), 'probabilities': prob_dict } # API入口 def predict_single(title: str, content: str, model_type='nb') -> Dict[str, Any]: """ 单条预测API """ predictor = TraditionalPredictor(model_type) return predictor.predict(title, content) if __name__ == '__main__': # 测试 result = predict_single( title="华为发布新款折叠屏手机", content="华为今天正式发布了新一代折叠屏手机,搭载最新麒麟芯片..." 
) print(result) ``` #### 任务 4.1.2: `src/deep_learning/bert_model.py` - BERT模型 ```python """ BERT文本分类模型 """ import torch from transformers import ( BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments ) from typing import Dict, Any, List # 分类映射 CATEGORY_MAP = { 'POLITICS': '时政', 'FINANCE': '财经', 'TECHNOLOGY': '科技', 'SPORTS': '体育', 'ENTERTAINMENT': '娱乐', 'HEALTH': '健康', 'EDUCATION': '教育', 'LIFE': '生活', 'INTERNATIONAL': '国际', 'MILITARY': '军事' } # 反向映射 ID_TO_LABEL = {i: label for i, label in enumerate(CATEGORY_MAP.keys())} LABEL_TO_ID = {label: i for i, label in enumerate(CATEGORY_MAP.keys())} class BertClassifier: """BERT文本分类器""" def __init__(self, model_name='bert-base-chinese', num_labels=10): self.model_name = model_name self.num_labels = num_labels self.tokenizer = None self.model = None self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') def load_model(self, model_path): """加载微调后的模型""" self.tokenizer = BertTokenizer.from_pretrained(model_path) self.model = BertForSequenceClassification.from_pretrained( model_path, num_labels=self.num_labels ) self.model.to(self.device) self.model.eval() print(f"BERT模型加载成功: {model_path}") def predict(self, title: str, content: str) -> Dict[str, Any]: """ 预测 """ if self.model is None or self.tokenizer is None: raise ValueError("模型未加载,请先调用load_model") # 组合标题和内容 text = f"{title} [SEP] {content}" # 分词 inputs = self.tokenizer( text, return_tensors='pt', truncation=True, max_length=512, padding='max_length' ) # 预测 with torch.no_grad(): inputs = {k: v.to(self.device) for k, v in inputs.items()} outputs = self.model(**inputs) logits = outputs.logits # 获取预测结果 probs = torch.softmax(logits, dim=-1) confidence, predicted_id = torch.max(probs, dim=-1) predicted_id = predicted_id.item() confidence = confidence.item() # 获取各类别概率 prob_dict = {} for i, prob in enumerate(probs[0].cpu().numpy()): category_code = ID_TO_LABEL[i] prob_dict[category_code] = float(prob) return { 'categoryCode': 
ID_TO_LABEL[predicted_id], 'categoryName': CATEGORY_MAP.get(ID_TO_LABEL[predicted_id], '未知'), 'confidence': confidence, 'probabilities': prob_dict } # 数据集类 class NewsDataset(torch.utils.data.Dataset): """新闻数据集""" def __init__(self, texts, labels, tokenizer, max_length=512): self.texts = texts self.labels = labels self.tokenizer = tokenizer self.max_length = max_length def __len__(self): return len(self.texts) def __getitem__(self, idx): text = self.texts[idx] label = self.labels[idx] encoding = self.tokenizer( text, truncation=True, max_length=self.max_length, padding='max_length', return_tensors='pt' ) return { 'input_ids': encoding['input_ids'].flatten(), 'attention_mask': encoding['attention_mask'].flatten(), 'labels': torch.tensor(label, dtype=torch.long) } if __name__ == '__main__': # 测试 classifier = BertClassifier() # classifier.load_model('./models/deep_learning/bert_finetuned') # # result = classifier.predict( # title="华为发布新款折叠屏手机", # content="华为今天正式发布了新一代折叠屏手机..." # ) # print(result) print("BERT分类器初始化成功") ``` #### 任务 4.1.3: `src/hybrid/hybrid_classifier.py` - 混合分类器 ```python """ 混合策略分类器 结合规则引擎和机器学习模型 """ import time from typing import Dict, Any from ..traditional.predict import TraditionalPredictor from ..deep_learning.bert_model import BertClassifier class HybridClassifier: """混合分类器""" def __init__(self): # 初始化各个分类器 self.nb_predictor = TraditionalPredictor('nb') self.bert_classifier = BertClassifier() # 配置参数 self.config = { 'confidence_threshold': 0.75, # 高置信度阈值 'hybrid_min_confidence': 0.60, # 混合模式最低阈值 'use_bert_threshold': 0.70, # 使用BERT的阈值 'rule_priority': True # 规则优先 } # 规则关键词字典 self.rule_keywords = { 'POLITICS': ['政府', '政策', '选举', '国务院', '主席', '总理'], 'FINANCE': ['股市', '经济', '金融', '投资', '基金', '银行'], 'TECHNOLOGY': ['芯片', 'AI', '人工智能', '5G', '互联网', '科技'], 'SPORTS': ['比赛', '冠军', '联赛', '球员', '教练', 'NBA'], 'ENTERTAINMENT': ['明星', '电影', '电视剧', '娱乐圈', '歌手'], 'HEALTH': ['健康', '医疗', '疾病', '治疗', '疫苗'], 'EDUCATION': ['教育', '学校', '大学', '考试', '招生'], 'LIFE': ['生活', 
'美食', '旅游', '购物'], 'INTERNATIONAL': ['国际', '美国', '欧洲', '日本', '外交'], 'MILITARY': ['军事', '武器', '军队', '国防', '战争'] } def rule_match(self, title: str, content: str) -> tuple[str | None, float]: """ 规则匹配 :return: (category_code, confidence) """ text = title + ' ' + content # 计算每个类别的关键词匹配数 matches = {} for category, keywords in self.rule_keywords.items(): count = sum(1 for kw in keywords if kw in text) if count > 0: matches[category] = count if not matches: return None, 0.0 # 返回匹配最多的类别 best_category = max(matches, key=matches.get) confidence = min(0.9, matches[best_category] * 0.15) # 规则置信度 return best_category, confidence def predict(self, title: str, content: str, use_bert=True) -> Dict[str, Any]: """ 混合预测 """ start_time = time.time() # 1. 先尝试规则匹配 rule_category, rule_confidence = self.rule_match(title, content) # 2. 传统机器学习预测 nb_result = self.nb_predictor.predict(title, content) nb_confidence = nb_result['confidence'] # 决策逻辑 final_result = None classifier_type = 'HYBRID' # 规则优先且规则置信度高 if self.config['rule_priority'] and rule_confidence >= self.config['confidence_threshold']: final_result = { 'categoryCode': rule_category, 'categoryName': nb_result['categoryName'], # 从映射获取 'confidence': rule_confidence, 'classifierType': 'RULE', 'reason': '规则匹配' } # 传统模型置信度足够高 elif nb_confidence >= self.config['confidence_threshold']: final_result = { **nb_result, 'classifierType': 'ML', 'reason': '传统模型高置信度' } # 需要使用BERT elif use_bert: # TODO: 加载BERT模型预测 # bert_result = self.bert_classifier.predict(title, content) # 如果BERT置信度也不高,选择最高的 final_result = { **nb_result, 'classifierType': 'HYBRID', 'reason': '混合决策' } else: # 不使用BERT,直接返回传统模型结果 final_result = { **nb_result, 'classifierType': 'ML', 'reason': '默认传统模型' } # 计算耗时 duration = int((time.time() - start_time) * 1000) final_result['duration'] = duration return final_result if __name__ == '__main__': # 测试 classifier = HybridClassifier() test_cases = [ { 'title': '国务院发布最新经济政策', 'content': '国务院今天发布了新的经济政策...' 
}, { 'title': '华为发布新款折叠屏手机', 'content': '华为今天正式发布了新一代折叠屏手机...' } ] for case in test_cases: result = classifier.predict(case['title'], case['content']) print(f"标题: {case['title']}") print(f"结果: {result['categoryName']} ({result['confidence']:.2f})") print(f"分类器: {result['classifierType']}") print(f"原因: {result.get('reason', 'N/A')}") print(f"耗时: {result['duration']}ms") print("-" * 50) ``` #### 任务 4.1.4: `src/api/server.py` - FastAPI服务 ```python """ 机器学习模型API服务 使用FastAPI提供RESTful API """ from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel from typing import Optional import logging # 导入分类器 import sys import os sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from traditional.predict import TraditionalPredictor from hybrid.hybrid_classifier import HybridClassifier # 配置日志 logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # 创建FastAPI应用 app = FastAPI( title="新闻分类API", description="提供新闻文本分类服务", version="1.0.0" ) # 配置CORS app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # 请求模型 class ClassifyRequest(BaseModel): title: str content: str mode: Optional[str] = 'hybrid' # traditional, hybrid # 响应模型 class ClassifyResponse(BaseModel): categoryCode: str categoryName: str confidence: float classifierType: str duration: int probabilities: Optional[dict] = None # 初始化分类器 nb_predictor = None hybrid_classifier = None @app.on_event("startup") async def startup_event(): """启动时加载模型""" global nb_predictor, hybrid_classifier logger.info("加载模型...") try: nb_predictor = TraditionalPredictor('nb') logger.info("朴素贝叶斯模型加载成功") except Exception as e: logger.error(f"朴素贝叶斯模型加载失败: {e}") try: hybrid_classifier = HybridClassifier() logger.info("混合分类器初始化成功") except Exception as e: logger.error(f"混合分类器初始化失败: {e}") @app.get("/") async def root(): """健康检查""" return { "status": "ok", "message": "新闻分类API服务运行中" 
} @app.get("/health") async def health_check(): """健康检查""" return { "status": "healthy", "models": { "nb_loaded": nb_predictor is not None, "hybrid_loaded": hybrid_classifier is not None } } @app.post("/api/predict", response_model=ClassifyResponse) async def predict(request: ClassifyRequest): """ 文本分类接口 - **title**: 新闻标题 - **content**: 新闻内容 - **mode**: 分类模式 (traditional, hybrid) """ try: if request.mode == 'traditional': result = nb_predictor.predict(request.title, request.content) result['classifierType'] = 'ML' else: # hybrid result = hybrid_classifier.predict(request.title, request.content) return ClassifyResponse(**result) except Exception as e: logger.error(f"预测失败: {e}") raise HTTPException(status_code=500, detail=str(e)) @app.post("/api/batch-predict") async def batch_predict(requests: list[ClassifyRequest]): """ 批量分类接口 """ results = [] for req in requests: try: if req.mode == 'traditional': result = nb_predictor.predict(req.title, req.content) result['classifierType'] = 'ML' else: result = hybrid_classifier.predict(req.title, req.content) results.append(result) except Exception as e: results.append({ 'error': str(e), 'title': req.title }) return {"results": results} if __name__ == '__main__': import uvicorn uvicorn.run( app, host="0.0.0.0", port=5000, log_level="info" ) ``` #### 任务 4.1.5: `src/utils/metrics.py` - 评估指标 ```python """ 模型评估指标工具 """ import numpy as np from sklearn.metrics import ( accuracy_score, precision_recall_fscore_support, confusion_matrix, classification_report ) from typing import List, Dict, Any import matplotlib.pyplot as plt import seaborn as sns class ClassificationMetrics: """分类评估指标""" @staticmethod def compute_all(y_true: List, y_pred: List, labels: List[str]) -> Dict[str, Any]: """ 计算所有指标 """ accuracy = accuracy_score(y_true, y_pred) precision, recall, f1, support = precision_recall_fscore_support( y_true, y_pred, average='weighted', zero_division=0 ) # 每个类别的指标 precision_per_class, recall_per_class, f1_per_class, support_per_class 
= \ precision_recall_fscore_support(y_true, y_pred, average=None, zero_division=0) per_class_metrics = {} for i, label in enumerate(labels): per_class_metrics[label] = { 'precision': float(precision_per_class[i]), 'recall': float(recall_per_class[i]), 'f1': float(f1_per_class[i]), 'support': int(support_per_class[i]) } return { 'accuracy': float(accuracy), 'precision': float(precision), 'recall': float(recall), 'f1': float(f1), 'per_class': per_class_metrics } @staticmethod def plot_confusion_matrix(y_true: List, y_pred: List, labels: List[str], save_path: str = None): """ 绘制混淆矩阵 """ cm = confusion_matrix(y_true, y_pred) plt.figure(figsize=(10, 8)) sns.heatmap( cm, annot=True, fmt='d', cmap='Blues', xticklabels=labels, yticklabels=labels ) plt.xlabel('预测标签') plt.ylabel('真实标签') plt.title('混淆矩阵') if save_path: plt.savefig(save_path, dpi=300, bbox_inches='tight') plt.close() @staticmethod def print_report(y_true: List, y_pred: List, labels: List[str]): """ 打印分类报告 """ report = classification_report( y_true, y_pred, target_names=labels, zero_division=0 ) print(report) if __name__ == '__main__': # 测试 y_true = ['POLITICS', 'TECHNOLOGY', 'FINANCE', 'POLITICS', 'TECHNOLOGY'] y_pred = ['POLITICS', 'TECHNOLOGY', 'FINANCE', 'TECHNOLOGY', 'TECHNOLOGY'] labels = ['POLITICS', 'TECHNOLOGY', 'FINANCE'] metrics = ClassificationMetrics() result = metrics.compute_all(y_true, y_pred, labels) print(result) ``` #### 任务 4.1.6: `requirements.txt` - 依赖文件 ```txt # 机器学习模块依赖 numpy>=1.24.0 pandas>=2.0.0 scikit-learn>=1.3.0 jieba>=0.42.0 joblib>=1.3.0 # 深度学习 torch>=2.0.0 transformers>=4.30.0 # API服务 fastapi>=0.100.0 uvicorn[standard]>=0.23.0 pydantic>=2.0.0 # 数据可视化 matplotlib>=3.7.0 seaborn>=0.12.0 # 工具 python-dotenv>=1.0.0 pyyaml>=6.0 ``` --- ## 总结 ### 开发顺序建议 1. **第一阶段:基础框架** - 后端:数据库连接、实体类、基础配置 - 前端:路由配置、状态管理、API封装 2. **第二阶段:核心功能** - 爬虫模块(Python) - 传统机器学习分类器 - 后端API接口 - 前端新闻列表页面 3. **第三阶段:高级功能** - BERT深度学习分类器 - 混合策略分类器 - 前端分类器对比页面 - 统计图表 4. 
**第四阶段:完善优化** - 用户认证 - 数据可视化 - 性能优化 - 异常处理 ### 关键注意事项 1. **爬虫模块使用 Python**,通过 RESTful API 与 Java 后端通信 2. **分类器模块独立部署**,提供 HTTP 接口供后端调用 3. **前后端分离**,使用 JWT 进行身份认证 4. **数据库表结构**已在 `schema.sql` 中定义,需严格遵守 5. **API 统一响应格式**使用 `Result` 包装