"""
|
||
数据加载和本地存储工具
|
||
"""
|
||
|
||
import pandas as pd
|
||
import os
|
||
from sqlalchemy import create_engine
|
||
from typing import List, Dict, Any, Optional
|
||
import logging
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||


class DataLoader:
    """Loads data from the database and stores it locally."""

    def __init__(self, db_url: str, data_dir: str = 'data'):
        """
        Initialize the data loader.

        :param db_url: database connection URL
        :param data_dir: directory for stored data
        """
        self.db_url = db_url
        self.data_dir = data_dir
        self.engine = create_engine(db_url)

        # Make sure the data directories exist
        os.makedirs(os.path.join(data_dir, 'raw'), exist_ok=True)
        os.makedirs(os.path.join(data_dir, 'processed'), exist_ok=True)

    def load_news_from_db(self, limit: Optional[int] = None,
                          category_ids: Optional[List[int]] = None) -> pd.DataFrame:
        """
        Load news data from the database.

        :param limit: maximum number of rows to load
        :param category_ids: restrict the result to these category IDs
        :return: DataFrame of news data
        """
        query = """
        SELECT n.id, n.title, n.content, c.name AS category_name
        FROM news n
        LEFT JOIN news_category c ON n.category_id = c.id
        WHERE n.content IS NOT NULL AND n.title IS NOT NULL
        """

        # The interpolated values are integers, so building the SQL with
        # f-strings is safe here; see the parameterized variant below.
        if category_ids:
            ids_str = ",".join(str(i) for i in category_ids)
            query += f" AND c.id IN ({ids_str})"

        if limit:
            query += f" LIMIT {limit}"

        logger.info(f"Loading from database with limit={limit}, category_ids={category_ids}")

        try:
            df = pd.read_sql(query, self.engine)
            logger.info(f"Loaded {len(df)} news records")
            return df
        except Exception as e:
            logger.error(f"Failed to load data: {e}")
            raise
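
    # Illustrative sketch, not part of the original module: the same query
    # built with SQLAlchemy bound parameters instead of string interpolation,
    # which avoids embedding any values in the SQL text. The method name
    # `load_news_from_db_params` is hypothetical.
    def load_news_from_db_params(self, limit: Optional[int] = None,
                                 category_ids: Optional[List[int]] = None) -> pd.DataFrame:
        """Parameterized variant of load_news_from_db (illustrative sketch)."""
        from sqlalchemy import text

        sql = (
            "SELECT n.id, n.title, n.content, c.name AS category_name "
            "FROM news n "
            "LEFT JOIN news_category c ON n.category_id = c.id "
            "WHERE n.content IS NOT NULL AND n.title IS NOT NULL"
        )
        params = {}
        if category_ids:
            # One named placeholder per ID: IN (:cid0, :cid1, ...)
            placeholders = ", ".join(f":cid{i}" for i in range(len(category_ids)))
            sql += f" AND c.id IN ({placeholders})"
            params.update({f"cid{i}": cid for i, cid in enumerate(category_ids)})
        if limit:
            sql += " LIMIT :limit"
            params["limit"] = limit
        return pd.read_sql(text(sql), self.engine, params=params)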

    def save_to_local(self, df: pd.DataFrame, filename: str, subdir: str = 'processed'):
        """
        Save data to a local file.

        :param df: DataFrame to save
        :param filename: file name
        :param subdir: subdirectory
        """
        file_path = os.path.join(self.data_dir, subdir, filename)

        # Make sure the directory exists
        os.makedirs(os.path.join(self.data_dir, subdir), exist_ok=True)

        # Save as CSV; utf-8-sig prepends a BOM so Excel detects the encoding
        df.to_csv(file_path, index=False, encoding='utf-8-sig')
        logger.info(f"Data saved to: {file_path}")
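
    # Note (an aside, not in the original module): CSV round-trips drop dtype
    # information, so everything is re-inferred on load. If the pyarrow
    # dependency is acceptable, Parquet preserves dtypes; a minimal sketch:
    #
    #     df.to_parquet(file_path.replace('.csv', '.parquet'), index=False)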

    def load_from_local(self, filename: str, subdir: str = 'processed') -> Optional[pd.DataFrame]:
        """
        Load previously saved data from a local file.

        :param filename: file name
        :param subdir: subdirectory
        :return: DataFrame, or None if the file does not exist
        """
        file_path = os.path.join(self.data_dir, subdir, filename)

        if os.path.exists(file_path):
            df = pd.read_csv(file_path, encoding='utf-8-sig')
            logger.info(f"Loaded {len(df)} records from local file: {file_path}")
            return df
        else:
            logger.warning(f"Local file not found: {file_path}")
            return None

    def update_local_data(self, limit: Optional[int] = None,
                          category_ids: Optional[List[int]] = None) -> Optional[pd.DataFrame]:
        """
        Refresh the local data (reload from the database and save).

        :param limit: maximum number of rows to load
        :param category_ids: restrict the result to these category IDs
        """
        # Load data from the database
        df = self.load_news_from_db(limit=limit, category_ids=category_ids)

        if not df.empty:
            # Build a timestamped filename
            timestamp = time.strftime("%Y%m%d_%H%M%S")
            filename = f"news_data_{timestamp}.csv"

            # Save locally
            self.save_to_local(df, filename)

            # Optional: remove older local files, keeping only the latest
            self._cleanup_old_files(filename)

            return df
        else:
            logger.warning("No data available to save")
            return None

    def _cleanup_old_files(self, current_filename: str):
        """
        Remove older local files (optional).

        :param current_filename: name of the file to keep
        """
        subdir = 'processed'
        dir_path = os.path.join(self.data_dir, subdir)

        if os.path.exists(dir_path):
            for filename in os.listdir(dir_path):
                if filename != current_filename and filename.startswith('news_data_'):
                    file_path = os.path.join(dir_path, filename)
                    try:
                        os.remove(file_path)
                        logger.info(f"Removed old file: {file_path}")
                    except Exception as e:
                        logger.error(f"Failed to remove file {file_path}: {e}")


# Example usage
if __name__ == '__main__':
    # Configure the database connection
    db_url = "mysql+pymysql://root:root@localhost/news"

    # Create the data loader
    loader = DataLoader(db_url)

    # Load data from the database and save it locally
    data = loader.update_local_data(limit=800)

    if data is not None:
        print(f"Loaded {len(data)} records")
        print(data.head())
    else:
        print("No data to load")