news-classifier/ml-module/src/utils/data_loader.py


"""
数据加载和本地存储工具
"""
import os
import time
import logging
from typing import List, Optional

import pandas as pd
from sqlalchemy import create_engine

logger = logging.getLogger(__name__)


class DataLoader:
    """Loads news data from the database and caches it on local disk."""

    def __init__(self, db_url: str, data_dir: str = 'data'):
        """
        Initialize the data loader.

        :param db_url: database connection URL
        :param data_dir: local data directory
        """
        self.db_url = db_url
        self.data_dir = data_dir
        self.engine = create_engine(db_url)
        # Make sure the data directories exist
        os.makedirs(os.path.join(data_dir, 'raw'), exist_ok=True)
        os.makedirs(os.path.join(data_dir, 'processed'), exist_ok=True)

    def load_news_from_db(self, limit: Optional[int] = None,
                          category_ids: Optional[List[int]] = None) -> pd.DataFrame:
        """
        Load news data from the database.

        :param limit: maximum number of rows to load
        :param category_ids: restrict the result to these category IDs
        :return: DataFrame of news records
        """
        query = """
            SELECT n.id, n.title, n.content, c.name AS category_name
            FROM news n
            LEFT JOIN news_category c ON n.category_id = c.id
            WHERE n.content IS NOT NULL AND n.title IS NOT NULL
        """
        if category_ids:
            ids_str = ",".join(str(i) for i in category_ids)
            query += f" AND c.id IN ({ids_str})"
        if limit:
            query += f" LIMIT {limit}"
        logger.info(f"Loading from database: limit={limit}, category_ids={category_ids}")
        try:
            # Run the query once, inside the try block
            df = pd.read_sql(query, self.engine)
            logger.info(f"Loaded {len(df)} news records")
            return df
        except Exception as e:
            logger.error(f"Failed to load data: {e}")
            raise
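
    # Note on query construction: limit and category_ids are interpolated
    # directly into the SQL string above, which is acceptable for the trusted
    # integer inputs used here but injection-prone for arbitrary input. A
    # bound-parameter sketch (assuming SQLAlchemy 1.4+ and the same schema;
    # base_query stands for the SELECT above):
    #
    #     from sqlalchemy import text, bindparam
    #     stmt = text(base_query + " AND c.id IN :ids LIMIT :lim").bindparams(
    #         bindparam("ids", expanding=True))
    #     df = pd.read_sql(stmt, self.engine,
    #                      params={"ids": category_ids, "lim": limit})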

    def save_to_local(self, df: pd.DataFrame, filename: str, subdir: str = 'processed'):
        """
        Save a DataFrame to local disk.

        :param df: DataFrame to save
        :param filename: target file name
        :param subdir: subdirectory under the data directory
        """
        file_path = os.path.join(self.data_dir, subdir, filename)
        # Make sure the target directory exists
        os.makedirs(os.path.join(self.data_dir, subdir), exist_ok=True)
        # Save as CSV; utf-8-sig adds a BOM so Excel renders Chinese text correctly
        df.to_csv(file_path, index=False, encoding='utf-8-sig')
        logger.info(f"Data saved to: {file_path}")

    def load_from_local(self, filename: str, subdir: str = 'processed') -> Optional[pd.DataFrame]:
        """
        Load previously saved data from local disk.

        :param filename: file name
        :param subdir: subdirectory under the data directory
        :return: DataFrame, or None if the file does not exist
        """
        file_path = os.path.join(self.data_dir, subdir, filename)
        if os.path.exists(file_path):
            df = pd.read_csv(file_path, encoding='utf-8-sig')
            logger.info(f"Loaded {len(df)} records from local file: {file_path}")
            return df
        else:
            logger.warning(f"Local file does not exist: {file_path}")
            return None

    def update_local_data(self, limit: Optional[int] = None,
                          category_ids: Optional[List[int]] = None) -> Optional[pd.DataFrame]:
        """
        Refresh the local cache (reload from the database and save to disk).

        :param limit: maximum number of rows to load
        :param category_ids: restrict the result to these category IDs
        """
        # Load fresh data from the database
        df = self.load_news_from_db(limit=limit, category_ids=category_ids)
        if not df.empty:
            # Build a timestamped file name
            timestamp = time.strftime("%Y%m%d_%H%M%S")
            filename = f"news_data_{timestamp}.csv"
            # Save to local disk
            self.save_to_local(df, filename)
            # Optional: delete older local files, keeping only the newest
            self._cleanup_old_files(filename)
            return df
        else:
            logger.warning("No data available to save")
            return None

    def _cleanup_old_files(self, current_filename: str):
        """
        Remove older local data files (optional).

        :param current_filename: file name to keep
        """
        subdir = 'processed'
        dir_path = os.path.join(self.data_dir, subdir)
        if os.path.exists(dir_path):
            for filename in os.listdir(dir_path):
                if filename != current_filename and filename.startswith('news_data_'):
                    file_path = os.path.join(dir_path, filename)
                    try:
                        os.remove(file_path)
                        logger.info(f"Deleted old file: {file_path}")
                    except Exception as e:
                        logger.error(f"Failed to delete {file_path}: {e}")
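

# A load-or-refresh convenience wrapper (a minimal sketch, not part of the
# original API): reuse the newest cached CSV when one exists, otherwise pull
# fresh data from the database and cache it. Relies on the timestamped file
# names sorting lexicographically, newest last.
def load_or_refresh(loader: DataLoader, limit: Optional[int] = None) -> Optional[pd.DataFrame]:
    dir_path = os.path.join(loader.data_dir, 'processed')
    cached = sorted(f for f in os.listdir(dir_path)
                    if f.startswith('news_data_') and f.endswith('.csv'))
    if cached:
        return loader.load_from_local(cached[-1])
    return loader.update_local_data(limit=limit)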


# Example usage
if __name__ == '__main__':
    # Configure the database connection
    db_url = "mysql+pymysql://root:root@localhost/news"
    # Create the data loader
    loader = DataLoader(db_url)
    # Load from the database and cache the result locally
    data = loader.update_local_data(limit=800)
    if data is not None:
        print(f"Loaded {len(data)} records")
        print(data.head())
    else:
        print("No data to load")