"""Utilities for loading news data from a database and storing it locally as CSV."""
import logging
import os
import time
from typing import Any, Dict, List, Optional

import pandas as pd
from sqlalchemy import create_engine

logger = logging.getLogger(__name__)


class DataLoader:
    """Load news rows from a SQL database and store/read them as local CSV files.

    Files live under ``<data_dir>/<subdir>/``; the ``raw`` and ``processed``
    subdirectories are created when the loader is constructed.
    """

    def __init__(self, db_url: str, data_dir: str = 'data'):
        """
        Initialize the data loader.

        :param db_url: database URL, passed to SQLAlchemy ``create_engine``
        :param data_dir: root directory for locally stored data
        """
        self.db_url = db_url
        self.data_dir = data_dir
        self.engine = create_engine(db_url)

        # Make sure the local storage directories exist up front.
        os.makedirs(os.path.join(data_dir, 'raw'), exist_ok=True)
        os.makedirs(os.path.join(data_dir, 'processed'), exist_ok=True)

    @staticmethod
    def _build_news_query(limit: Optional[int] = None,
                          category_ids: Optional[List[int]] = None) -> str:
        """
        Build the SELECT statement for news rows with optional filters.

        Filter values are coerced with ``int()`` before being embedded, so only
        numeric literals can ever reach the SQL text (non-numeric input raises
        ``ValueError`` / ``TypeError`` instead of being injected).

        :param limit: maximum number of rows; a falsy value (None or 0) means no LIMIT
        :param category_ids: category IDs to keep; None or empty means no filter
        :return: SQL query string
        """
        query = """
            SELECT n.id, n.title, n.content, c.name as category_name
            FROM news n
            LEFT JOIN news_category c ON n.category_id = c.id
            WHERE n.content IS NOT NULL AND n.title IS NOT NULL
        """
        if category_ids:
            ids_str = ",".join(str(int(i)) for i in category_ids)
            query += f" AND c.id IN ({ids_str})"
        if limit:
            query += f" LIMIT {int(limit)}"
        return query

    def load_news_from_db(self, limit: Optional[int] = None,
                          category_ids: Optional[List[int]] = None) -> pd.DataFrame:
        """
        Load news rows (id, title, content, category_name) from the database.

        Only rows whose title and content are both non-NULL are returned.

        :param limit: maximum number of rows to load; falsy means no limit
        :param category_ids: restrict to these category IDs; None/empty means all
        :return: DataFrame of news rows
        :raises Exception: whatever ``pd.read_sql`` raises; it is logged and re-raised
        """
        query = self._build_news_query(limit=limit, category_ids=category_ids)
        logger.info("从数据库加载数据,条件: limit=%s, category_ids=%s", limit, category_ids)

        # Run the query exactly once, inside the try, so any failure is logged.
        try:
            df = pd.read_sql(query, self.engine)
        except Exception as e:
            logger.error("加载数据失败: %s", e)
            raise

        logger.info("成功加载 %d 条新闻数据", len(df))
        return df

    def save_to_local(self, df: pd.DataFrame, filename: str, subdir: str = 'processed'):
        """
        Save a DataFrame to ``<data_dir>/<subdir>/<filename>`` as CSV.

        The file is written with a UTF-8 BOM (``utf-8-sig``) and without the index.

        :param df: DataFrame to save
        :param filename: target file name
        :param subdir: subdirectory under ``data_dir``; created if missing
        """
        dir_path = os.path.join(self.data_dir, subdir)
        os.makedirs(dir_path, exist_ok=True)

        file_path = os.path.join(dir_path, filename)
        df.to_csv(file_path, index=False, encoding='utf-8-sig')
        logger.info("数据已保存到: %s", file_path)

    def load_from_local(self, filename: str, subdir: str = 'processed') -> Optional[pd.DataFrame]:
        """
        Load previously saved CSV data from ``<data_dir>/<subdir>/<filename>``.

        :param filename: file name
        :param subdir: subdirectory under ``data_dir``
        :return: the DataFrame, or None if the file does not exist
        """
        file_path = os.path.join(self.data_dir, subdir, filename)
        if not os.path.exists(file_path):
            logger.warning("本地文件不存在: %s", file_path)
            return None

        df = pd.read_csv(file_path, encoding='utf-8-sig')
        logger.info("从本地加载 %d 条数据: %s", len(df), file_path)
        return df

    def update_local_data(self, limit: Optional[int] = None,
                          category_ids: Optional[List[int]] = None) -> Optional[pd.DataFrame]:
        """
        Reload news from the database and save a timestamped CSV snapshot.

        After saving, older ``news_data_*`` snapshots in ``processed`` are removed.

        :param limit: maximum number of rows to load; falsy means no limit
        :param category_ids: restrict to these category IDs; None/empty means all
        :return: the loaded DataFrame, or None if the query returned no rows
        """
        df = self.load_news_from_db(limit=limit, category_ids=category_ids)
        if df.empty:
            logger.warning("没有可用的数据需要保存")
            return None

        # Timestamped name (second resolution) so each refresh gets its own file.
        timestamp = time.strftime("%Y%m%d_%H%M%S")
        filename = f"news_data_{timestamp}.csv"
        self.save_to_local(df, filename)

        # Keep only the newest snapshot.
        self._cleanup_old_files(filename)
        return df

    def _cleanup_old_files(self, current_filename: str, subdir: str = 'processed'):
        """
        Delete every ``news_data_*`` file in ``subdir`` except ``current_filename``.

        Deletion is best-effort: a failure is logged and the loop continues.

        :param current_filename: file to keep
        :param subdir: subdirectory under ``data_dir`` to clean (default ``processed``)
        """
        dir_path = os.path.join(self.data_dir, subdir)
        if not os.path.isdir(dir_path):
            return

        for filename in os.listdir(dir_path):
            if filename == current_filename or not filename.startswith('news_data_'):
                continue
            file_path = os.path.join(dir_path, filename)
            try:
                os.remove(file_path)
                logger.info("已删除旧文件: %s", file_path)
            except OSError as e:
                logger.error("删除文件失败 %s: %s", file_path, e)


# Example usage
if __name__ == '__main__':
    # Without a handler configured, the INFO progress messages above are not shown.
    logging.basicConfig(level=logging.INFO)

    # Database URL: prefer the NEWS_DB_URL environment variable over hard-coded credentials.
    db_url = os.environ.get("NEWS_DB_URL", "mysql+pymysql://root:root@localhost/news")

    loader = DataLoader(db_url)

    # Load from the database and save a local snapshot.
    data = loader.update_local_data(limit=800)

    if data is not None:
        print(f"成功加载数据,共 {len(data)} 条记录")
        print(data.head())
    else:
        print("没有数据可加载")