first commit

This commit is contained in:
2025-06-12 19:37:54 +08:00
parent bb2eb010f7
commit 1c6093fa9a
87 changed files with 18432 additions and 0 deletions

View File

@@ -0,0 +1,192 @@
# 🗄️ 农业股票数据处理器 - MySQL数据库存储
## 📋 概述
本文档说明如何配置和使用农业股票数据处理器的MySQL数据库存储功能。处理器现在支持将Spark分析结果直接保存到MySQL数据库中而不仅仅是文件系统。
## 🔧 数据库配置
### 1. 数据库表结构
首先需要创建必要的数据库表。执行以下SQL脚本
```bash
# 执行扩展表结构脚本
mysql -u your_username -p your_database < database_tables.sql
```
### 2. 配置文件
`application.conf` 中配置数据库连接(如果文件不存在,程序会使用默认值):
```hocon
mysql {
host = "localhost"
port = 3306
database = "agricultural_stock"
user = "root"
password = "your_password"
}
spark {
master = "local[*]"
}
output {
path = "/tmp/spark-output" # 备用文件输出路径
}
```
## 📊 数据存储结构
### 处理结果存储到以下表:
#### 1. `market_analysis` - 市场分析表
存储每日市场总览数据:
- 上涨/下跌/平盘股票数量
- 总市值、成交量、成交额
- 平均涨跌幅
#### 2. `stock_technical_indicators` - 技术指标表
存储股票技术指标:
- 移动平均线 (MA5, MA10, MA20, MA30)
- RSI相对强弱指标
- MACD指标 (DIF, DEA)
- 布林带 (上轨、中轨、下轨)
#### 3. `industry_analysis` - 行业分析表
存储行业表现数据:
- 行业平均涨跌幅
- 行业股票数量
- 行业总市值和成交量
#### 4. `market_trends` - 市场趋势表
存储历史趋势数据:
- 每日平均价格和涨跌幅
- 每日总成交量和成交额
## 🚀 运行方式
### 1. 标准运行
```bash
# 编译项目
mvn clean compile
# 运行处理器(批处理模式)
mvn exec:java -Dexec.mainClass="com.agricultural.spark.StockDataProcessor"
```
### 2. 打包运行
```bash
# 打包成可执行JAR
mvn package
# 运行JAR包
java -jar target/spark-data-processor-1.0.0.jar
```
## 📈 数据流程
```
MySQL stock_data (输入)
数据清洗和处理
技术指标计算
市场分析计算
保存到MySQL数据库表
↓ ↓
主要表 备用文件
```
## 🔍 数据查询示例
### 查看最新市场分析
```sql
SELECT * FROM market_analysis
ORDER BY analysis_date DESC
LIMIT 1;
```
### 查看特定股票技术指标
```sql
SELECT stock_name, trade_date, close_price, ma5, ma20, rsi
FROM stock_technical_indicators
WHERE stock_code = 'sz000876'
ORDER BY trade_date DESC
LIMIT 30;
```
### 查看行业表现排行
```sql
SELECT industry, avg_change_percent, stock_count, total_market_cap
FROM industry_analysis
WHERE analysis_date = (SELECT MAX(analysis_date) FROM industry_analysis)
ORDER BY avg_change_percent DESC;
```
### 查看市场趋势
```sql
SELECT trade_date, avg_change_percent, total_volume
FROM market_trends
ORDER BY trade_date DESC
LIMIT 30;
```
## ⚠️ 注意事项
### 1. 错误处理
- 程序会首先测试数据库连接
- 如果数据库连接失败,会回退到文件保存模式
- 所有数据库操作都有异常处理,不会中断主流程
### 2. 性能考虑
- 技术指标数据量较大,采用批量写入方式
- 市场分析数据使用 `ON DUPLICATE KEY UPDATE` 避免重复
- 添加了必要的数据库索引优化查询性能
### 3. 数据一致性
- `market_analysis` 表按日期去重
- 技术指标表按股票代码和日期建立组合索引
- 所有表都包含创建时间和更新时间字段
## 🛠️ 故障排除
### 1. 数据库连接失败
检查配置文件中的数据库连接信息:
- 主机地址和端口
- 数据库名称
- 用户名和密码
- 确保MySQL服务正在运行
### 2. 表不存在错误
确保已执行 `database_tables.sql` 脚本创建所有必要的表。
### 3. 权限问题
确保数据库用户具有以下权限:
- SELECT (读取 stock_data)
- INSERT (写入分析结果)
- UPDATE (更新已有数据)
### 4. 内存不足
对于大量数据可能需要调整Spark配置
```bash
# 增加内存限制
export SPARK_DRIVER_MEMORY=2g
export SPARK_EXECUTOR_MEMORY=2g
```
## 📝 日志说明
程序运行时会输出详细日志,包括:
- 数据库连接状态
- 各阶段处理进度
- 数据保存结果
- 错误信息和回退操作
关键日志信息:
- `数据库连接测试成功` - 数据库连接正常
- `市场分析结果保存成功` - 数据已保存到数据库
- `保存到数据库失败,回退到文件保存` - 使用备用保存方式

View File

@@ -0,0 +1,88 @@
-- ================================
-- 农业股票数据处理系统 - 扩展表结构
-- 用于存储Spark处理后的分析结果
-- ================================
-- 技术指标数据表
CREATE TABLE IF NOT EXISTS stock_technical_indicators (
id BIGINT AUTO_INCREMENT PRIMARY KEY,
stock_code VARCHAR(10) NOT NULL COMMENT '股票代码',
stock_name VARCHAR(100) NOT NULL COMMENT '股票名称',
trade_date DATE NOT NULL COMMENT '交易日期',
close_price DECIMAL(10,2) COMMENT '收盘价',
ma5 DECIMAL(10,2) COMMENT '5日移动平均',
ma10 DECIMAL(10,2) COMMENT '10日移动平均',
ma20 DECIMAL(10,2) COMMENT '20日移动平均',
ma30 DECIMAL(10,2) COMMENT '30日移动平均',
rsi DECIMAL(5,2) COMMENT 'RSI相对强弱指标',
macd_dif DECIMAL(10,4) COMMENT 'MACD DIF值',
macd_dea DECIMAL(10,4) COMMENT 'MACD DEA值',
bb_upper DECIMAL(10,2) COMMENT '布林带上轨',
bb_middle DECIMAL(10,2) COMMENT '布林带中轨',
bb_lower DECIMAL(10,2) COMMENT '布林带下轨',
create_time DATETIME DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间',
update_time DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间'
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='股票技术指标表';
-- 创建索引
CREATE INDEX idx_technical_stock ON stock_technical_indicators(stock_code);
CREATE INDEX idx_technical_date ON stock_technical_indicators(trade_date);
CREATE INDEX idx_technical_stock_date ON stock_technical_indicators(stock_code, trade_date);
-- 行业分析表
CREATE TABLE IF NOT EXISTS industry_analysis (
id BIGINT AUTO_INCREMENT PRIMARY KEY,
industry VARCHAR(50) NOT NULL COMMENT '行业名称',
analysis_date DATE NOT NULL COMMENT '分析日期',
stock_count INT COMMENT '股票数量',
avg_change_percent DECIMAL(5,2) COMMENT '平均涨跌幅',
total_market_cap DECIMAL(15,2) COMMENT '行业总市值',
total_volume BIGINT COMMENT '行业总成交量',
create_time DATETIME DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间'
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='行业分析表';
-- 创建索引
CREATE INDEX idx_industry_date ON industry_analysis(analysis_date);
CREATE INDEX idx_industry_name ON industry_analysis(industry);
CREATE INDEX idx_industry_date_name ON industry_analysis(analysis_date, industry);
-- 市场趋势表
CREATE TABLE IF NOT EXISTS market_trends (
id BIGINT AUTO_INCREMENT PRIMARY KEY,
trade_date DATE NOT NULL COMMENT '交易日期',
avg_price DECIMAL(10,2) COMMENT '平均价格',
avg_change_percent DECIMAL(5,2) COMMENT '平均涨跌幅',
total_volume BIGINT COMMENT '总成交量',
total_turnover DECIMAL(15,2) COMMENT '总成交额',
stock_count INT COMMENT '股票数量',
create_time DATETIME DEFAULT CURRENT_TIMESTAMP COMMENT '创建时间'
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT='市场趋势表';
-- 创建索引
CREATE INDEX idx_trends_date ON market_trends(trade_date);
-- 为已有的market_analysis表添加唯一索引防止重复数据
CREATE UNIQUE INDEX idx_market_analysis_date ON market_analysis(analysis_date);
-- ================================
-- 示例查询语句
-- ================================
-- 查询最新的市场分析数据
-- SELECT * FROM market_analysis ORDER BY analysis_date DESC LIMIT 1;
-- 查询特定股票的技术指标
-- SELECT * FROM stock_technical_indicators
-- WHERE stock_code = 'sz000876'
-- ORDER BY trade_date DESC LIMIT 30;
-- 查询行业表现排行
-- SELECT industry, avg_change_percent, stock_count
-- FROM industry_analysis
-- WHERE analysis_date = (SELECT MAX(analysis_date) FROM industry_analysis)
-- ORDER BY avg_change_percent DESC;
-- 查询市场趋势
-- SELECT trade_date, avg_change_percent, total_volume
-- FROM market_trends
-- ORDER BY trade_date DESC LIMIT 30;

169
spark-processor/pom.xml Normal file
View File

@@ -0,0 +1,169 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0
http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>com.agricultural</groupId>
<artifactId>spark-data-processor</artifactId>
<version>1.0.0</version>
<packaging>jar</packaging>
<name>Agricultural Stock Spark Processor</name>
<description>基于Apache Spark的农业股票数据处理器</description>
<properties>
<maven.compiler.source>8</maven.compiler.source>
<maven.compiler.target>8</maven.compiler.target>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<spark.version>3.4.0</spark.version>
<scala.binary.version>2.12</scala.binary.version>
<mysql.version>8.0.33</mysql.version>
<jackson.version>2.14.3</jackson.version>
</properties>
<dependencies>
<!-- Apache Spark Core -->
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_${scala.binary.version}</artifactId>
<version>${spark.version}</version>
</dependency>
<!-- Apache Spark SQL -->
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql_${scala.binary.version}</artifactId>
<version>${spark.version}</version>
</dependency>
<!-- Apache Spark Streaming -->
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-streaming_${scala.binary.version}</artifactId>
<version>${spark.version}</version>
</dependency>
<!-- Spark Kafka Integration - 已移除 -->
<!-- MySQL JDBC Driver -->
<dependency>
<groupId>mysql</groupId>
<artifactId>mysql-connector-java</artifactId>
<version>${mysql.version}</version>
</dependency>
<!-- Jackson for JSON processing -->
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-core</artifactId>
<version>${jackson.version}</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>${jackson.version}</version>
</dependency>
<!-- Configuration -->
<dependency>
<groupId>com.typesafe</groupId>
<artifactId>config</artifactId>
<version>1.4.2</version>
</dependency>
<!-- Logging -->
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
<version>1.7.36</version>
</dependency>
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>
<version>1.2.12</version>
</dependency>
<!-- Apache Commons -->
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
<version>3.12.0</version>
</dependency>
<!-- Test Dependencies -->
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>4.13.2</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.mockito</groupId>
<artifactId>mockito-core</artifactId>
<version>4.11.0</version>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<plugins>
<!-- Maven Compiler Plugin -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.11.0</version>
<configuration>
<source>8</source>
<target>8</target>
<encoding>UTF-8</encoding>
</configuration>
</plugin>
<!-- Maven Shade Plugin for fat JAR -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-shade-plugin</artifactId>
<version>3.4.1</version>
<executions>
<execution>
<phase>package</phase>
<goals>
<goal>shade</goal>
</goals>
<configuration>
<transformers>
<transformer implementation="org.apache.maven.plugins.shade.resource.ManifestResourceTransformer">
<mainClass>com.agricultural.spark.StockDataProcessor</mainClass>
</transformer>
<transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/>
</transformers>
<filters>
<filter>
<artifact>*:*</artifact>
<excludes>
<exclude>META-INF/*.SF</exclude>
<exclude>META-INF/*.DSA</exclude>
<exclude>META-INF/*.RSA</exclude>
</excludes>
</filter>
</filters>
</configuration>
</execution>
</executions>
</plugin>
<!-- Maven Surefire Plugin for tests -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<version>3.0.0</version>
</plugin>
</plugins>
</build>
</project>

View File

@@ -0,0 +1,228 @@
package com.agricultural.spark;
import com.agricultural.spark.config.SparkConfig;
import com.agricultural.spark.service.DataCleaningService;
import com.agricultural.spark.service.DatabaseSaveService;
import com.agricultural.spark.service.MarketAnalysisService;
// StreamProcessingService 已移除 (Kafka相关)
import com.agricultural.spark.service.TechnicalIndicatorService;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.Map;
/**
* 农业股票数据处理器主类
* 基于Apache Spark的大数据处理平台
*
* @author Agricultural Stock Platform Team
*/
public class StockDataProcessor {
private static final Logger logger = LoggerFactory.getLogger(StockDataProcessor.class);
private final SparkSession spark;
private final SparkConfig config;
private final DataCleaningService dataCleaningService;
private final MarketAnalysisService marketAnalysisService;
private final TechnicalIndicatorService technicalIndicatorService;
private final DatabaseSaveService databaseSaveService;
// StreamProcessingService 已移除
public StockDataProcessor(SparkConfig config) {
this.config = config;
this.spark = initializeSparkSession();
this.dataCleaningService = new DataCleaningService(spark);
this.marketAnalysisService = new MarketAnalysisService(spark);
this.technicalIndicatorService = new TechnicalIndicatorService(spark);
this.databaseSaveService = new DatabaseSaveService(spark, config);
// StreamProcessingService 初始化已移除
}
/**
* 初始化Spark会话
*/
private SparkSession initializeSparkSession() {
logger.info("正在初始化Spark会话...");
SparkSession.Builder builder = SparkSession.builder()
.appName("AgriculturalStockDataProcessor")
.config("spark.sql.adaptive.enabled", "true")
.config("spark.sql.adaptive.coalescePartitions.enabled", "true")
.config("spark.sql.warehouse.dir", "/tmp/spark-warehouse");
// 如果是本地模式
if (config.getSparkMaster().startsWith("local")) {
builder.master(config.getSparkMaster());
}
SparkSession session = builder.getOrCreate();
session.sparkContext().setLogLevel("WARN");
logger.info("Spark会话初始化成功");
return session;
}
/**
* 从MySQL加载股票数据
*/
public Dataset<Row> loadStockDataFromMySQL() {
logger.info("从MySQL加载股票数据...");
String jdbcUrl = String.format("jdbc:mysql://%s:%d/%s?useSSL=false&serverTimezone=Asia/Shanghai",
config.getMysqlHost(), config.getMysqlPort(), config.getMysqlDatabase());
Dataset<Row> df = spark.read()
.format("jdbc")
.option("url", jdbcUrl)
.option("dbtable", "stock_data")
.option("user", config.getMysqlUser())
.option("password", config.getMysqlPassword())
.option("driver", "com.mysql.cj.jdbc.Driver")
.load();
long count = df.count();
logger.info("从MySQL加载数据完成共 {} 条记录", count);
return df;
}
/**
* 执行批处理分析
*/
public void runBatchProcessing() {
logger.info("开始执行批处理分析...");
try {
// 0. 测试数据库连接
if (!databaseSaveService.testConnection()) {
logger.warn("数据库连接测试失败,将仅保存到文件系统");
}
// 1. 加载数据
Dataset<Row> rawData = loadStockDataFromMySQL();
// 2. 数据清洗
Dataset<Row> cleanedData = dataCleaningService.cleanData(rawData);
// 3. 市场总览分析
Map<String, Object> marketOverview = marketAnalysisService.analyzeMarketOverview(cleanedData);
logger.info("市场总览分析完成: {}", marketOverview);
// 4. 涨跌幅排行分析
Map<String, Object> rankingAnalysis = marketAnalysisService.analyzeTopGainersLosers(cleanedData, 10);
logger.info("涨跌幅排行分析完成");
// 5. 行业表现分析
Dataset<Row> industryAnalysis = marketAnalysisService.analyzeIndustryPerformance(cleanedData);
logger.info("行业表现分析完成,共 {} 个行业", industryAnalysis.count());
// 6. 历史趋势分析
Dataset<Row> trendAnalysis = marketAnalysisService.analyzeHistoricalTrends(cleanedData, null, 30);
logger.info("历史趋势分析完成,共 {} 个数据点", trendAnalysis.count());
// 7. 技术指标计算
Dataset<Row> dataWithIndicators = technicalIndicatorService.calculateTechnicalIndicators(cleanedData);
logger.info("技术指标计算完成");
// 8. 保存处理结果到数据库
try {
databaseSaveService.saveMarketAnalysis(marketOverview);
databaseSaveService.saveIndustryAnalysis(industryAnalysis);
databaseSaveService.saveHistoricalTrends(trendAnalysis);
databaseSaveService.saveTechnicalIndicators(dataWithIndicators);
logger.info("数据已成功保存到MySQL数据库");
} catch (Exception e) {
logger.error("保存到数据库失败,回退到文件保存", e);
// 回退到原有的文件保存方式
saveResults(dataWithIndicators, "processed_data");
saveAnalysisResults(marketOverview, "market_overview");
}
logger.info("批处理分析完成");
} catch (Exception e) {
logger.error("批处理分析过程中出现错误", e);
throw new RuntimeException("批处理分析失败", e);
}
}
/**
* 实时流处理已移除原Kafka功能
*/
public void runStreamProcessing() {
logger.warn("实时流处理功能已移除,仅支持批处理模式");
}
/**
* 保存处理结果到文件
*/
private void saveResults(Dataset<Row> data, String outputPath) {
try {
String fullPath = config.getOutputPath() + "/" + outputPath;
data.write()
.mode("overwrite")
.parquet(fullPath);
logger.info("数据已保存到: {}", fullPath);
} catch (Exception e) {
logger.error("保存数据失败: {}", outputPath, e);
}
}
/**
* 保存分析结果
*/
private void saveAnalysisResults(Map<String, Object> results, String outputPath) {
try {
String fullPath = config.getOutputPath() + "/" + outputPath + ".json";
// 这里可以将Map转换为JSON并保存
logger.info("分析结果已保存到: {}", fullPath);
} catch (Exception e) {
logger.error("保存分析结果失败: {}", outputPath, e);
}
}
/**
* 关闭Spark会话
*/
public void close() {
if (spark != null) {
spark.stop();
logger.info("Spark会话已关闭");
}
}
/**
* 主方法
*/
public static void main(String[] args) {
logger.info("农业股票数据处理器启动");
StockDataProcessor processor = null;
try {
// 加载配置
SparkConfig config = SparkConfig.load();
processor = new StockDataProcessor(config);
// 根据参数决定运行模式
if (args.length > 0 && "stream".equals(args[0])) {
// 流处理模式
processor.runStreamProcessing();
} else {
// 批处理模式
processor.runBatchProcessing();
}
} catch (Exception e) {
logger.error("程序运行过程中出现错误", e);
System.exit(1);
} finally {
if (processor != null) {
processor.close();
}
logger.info("农业股票数据处理器已停止");
}
}
}

View File

@@ -0,0 +1,60 @@
package com.agricultural.spark.config;
import com.typesafe.config.Config;
import com.typesafe.config.ConfigFactory;
/**
* Spark配置类
*
* @author Agricultural Stock Platform Team
*/
public class SparkConfig {
private final Config config;
private SparkConfig(Config config) {
this.config = config;
}
public static SparkConfig load() {
Config config = ConfigFactory.load();
return new SparkConfig(config);
}
public String getSparkMaster() {
return config.hasPath("spark.master") ?
config.getString("spark.master") : "local[*]";
}
public String getMysqlHost() {
return config.hasPath("mysql.host") ?
config.getString("mysql.host") : "localhost";
}
public int getMysqlPort() {
return config.hasPath("mysql.port") ?
config.getInt("mysql.port") : 3306;
}
public String getMysqlDatabase() {
return config.hasPath("mysql.database") ?
config.getString("mysql.database") : "agricultural_stock";
}
public String getMysqlUser() {
return config.hasPath("mysql.user") ?
config.getString("mysql.user") : "root";
}
public String getMysqlPassword() {
return config.hasPath("mysql.password") ?
config.getString("mysql.password") : "root";
}
// Kafka配置已移除
public String getOutputPath() {
return config.hasPath("output.path") ?
config.getString("output.path") : "/tmp/spark-output";
}
}

View File

@@ -0,0 +1,190 @@
package com.agricultural.spark.service;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.functions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import static org.apache.spark.sql.functions.*;
/**
* 数据清洗服务类
*
* @author Agricultural Stock Platform Team
*/
public class DataCleaningService {
private static final Logger logger = LoggerFactory.getLogger(DataCleaningService.class);
private final SparkSession spark;
public DataCleaningService(SparkSession spark) {
this.spark = spark;
}
/**
* 执行数据清洗
*
* @param rawData 原始数据
* @return 清洗后的数据
*/
public Dataset<Row> cleanData(Dataset<Row> rawData) {
logger.info("开始执行数据清洗,原始数据条数: {}", rawData.count());
Dataset<Row> cleanedData = rawData;
// 1. 去除重复数据
cleanedData = cleanedData.dropDuplicates("stock_code", "trade_date");
logger.info("去重后数据条数: {}", cleanedData.count());
// 2. 处理缺失值
cleanedData = handleMissingValues(cleanedData);
// 3. 数据类型转换
cleanedData = convertDataTypes(cleanedData);
// 4. 异常值处理
cleanedData = handleOutliers(cleanedData);
// 5. 计算派生字段
cleanedData = calculateDerivedFields(cleanedData);
long finalCount = cleanedData.count();
logger.info("数据清洗完成,最终数据条数: {}", finalCount);
return cleanedData;
}
/**
* 处理缺失值
*/
private Dataset<Row> handleMissingValues(Dataset<Row> data) {
logger.info("处理缺失值...");
// 数值字段填充为0
String[] numericCols = {
"open_price", "close_price", "high_price", "low_price",
"volume", "turnover", "change_percent", "change_amount",
"pe_ratio", "pb_ratio", "market_cap", "float_market_cap"
};
for (String col : numericCols) {
if (hasColumn(data, col)) {
data = data.na().fill(0.0, new String[]{col});
}
}
// 字符串字段填充为空字符串
String[] stringCols = {"stock_name"};
for (String col : stringCols) {
if (hasColumn(data, col)) {
data = data.na().fill("", new String[]{col});
}
}
return data;
}
/**
* 数据类型转换
*/
private Dataset<Row> convertDataTypes(Dataset<Row> data) {
logger.info("执行数据类型转换...");
// 转换时间戳字段
if (hasColumn(data, "trade_date")) {
data = data.withColumn("trade_date",
to_timestamp(col("trade_date"), "yyyy-MM-dd HH:mm:ss"));
}
// 确保数值字段为正确的数据类型
String[] numericCols = {
"open_price", "close_price", "high_price", "low_price",
"volume", "turnover", "change_percent", "change_amount",
"pe_ratio", "pb_ratio", "market_cap", "float_market_cap"
};
for (String colName : numericCols) {
if (hasColumn(data, colName)) {
data = data.withColumn(colName, col(colName).cast("double"));
}
}
return data;
}
/**
* 处理异常值
*/
private Dataset<Row> handleOutliers(Dataset<Row> data) {
logger.info("处理异常值...");
// 过滤异常数据
data = data.filter(
col("open_price").$greater$eq(0)
.and(col("close_price").$greater$eq(0))
.and(col("high_price").$greater$eq(0))
.and(col("low_price").$greater$eq(0))
.and(col("volume").$greater$eq(0))
.and(col("high_price").$greater$eq(col("low_price")))
);
// 过滤极端的涨跌幅数据超过±20%的数据需要特别检查)
data = data.filter(
col("change_percent").$greater$eq(-20.0)
.and(col("change_percent").$less$eq(20.0))
);
return data;
}
/**
* 计算派生字段
*/
private Dataset<Row> calculateDerivedFields(Dataset<Row> data) {
logger.info("计算派生字段...");
// 计算价格变动
data = data.withColumn("price_change",
col("close_price").minus(col("open_price")));
// 计算价格变动百分比
data = data.withColumn("price_change_pct",
when(col("open_price").notEqual(0),
col("close_price").minus(col("open_price"))
.divide(col("open_price")).multiply(100))
.otherwise(0));
// 计算振幅
data = data.withColumn("amplitude",
when(col("open_price").notEqual(0),
col("high_price").minus(col("low_price"))
.divide(col("open_price")).multiply(100))
.otherwise(0));
// 计算换手率(如果有流通股本数据)
if (hasColumn(data, "float_shares")) {
data = data.withColumn("turnover_rate",
when(col("float_shares").notEqual(0),
col("volume").divide(col("float_shares")).multiply(100))
.otherwise(0));
}
return data;
}
/**
* 检查DataFrame是否包含指定列
*/
private boolean hasColumn(Dataset<Row> data, String columnName) {
String[] columns = data.columns();
for (String col : columns) {
if (col.equals(columnName)) {
return true;
}
}
return false;
}
}

View File

@@ -0,0 +1,189 @@
package com.agricultural.spark.service;
import com.agricultural.spark.config.SparkConfig;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SaveMode;
import org.apache.spark.sql.SparkSession;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.time.LocalDate;
import java.util.Map;
import java.util.Properties;
/**
* 数据库保存服务类
* 负责将Spark处理结果保存到MySQL数据库
*
* @author Agricultural Stock Platform Team
*/
public class DatabaseSaveService {
private static final Logger logger = LoggerFactory.getLogger(DatabaseSaveService.class);
private final SparkSession spark;
private final SparkConfig config;
private final String jdbcUrl;
private final Properties connectionProps;
public DatabaseSaveService(SparkSession spark, SparkConfig config) {
this.spark = spark;
this.config = config;
this.jdbcUrl = String.format("jdbc:mysql://%s:%d/%s?useSSL=false&serverTimezone=Asia/Shanghai",
config.getMysqlHost(), config.getMysqlPort(), config.getMysqlDatabase());
this.connectionProps = new Properties();
this.connectionProps.put("user", config.getMysqlUser());
this.connectionProps.put("password", config.getMysqlPassword());
this.connectionProps.put("driver", "com.mysql.cj.jdbc.Driver");
}
/**
* 保存市场分析结果到market_analysis表
*/
public void saveMarketAnalysis(Map<String, Object> marketOverview) {
logger.info("开始保存市场分析结果到数据库...");
String sql = "INSERT INTO market_analysis (analysis_date, up_count, down_count, flat_count, " +
"total_count, total_market_cap, total_volume, total_turnover, avg_change_percent) " +
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) " +
"ON DUPLICATE KEY UPDATE " +
"up_count = VALUES(up_count), down_count = VALUES(down_count), " +
"flat_count = VALUES(flat_count), total_count = VALUES(total_count), " +
"total_market_cap = VALUES(total_market_cap), total_volume = VALUES(total_volume), " +
"total_turnover = VALUES(total_turnover), avg_change_percent = VALUES(avg_change_percent)";
try (Connection conn = DriverManager.getConnection(jdbcUrl, connectionProps);
PreparedStatement stmt = conn.prepareStatement(sql)) {
// 解析交易日期
String tradeDateStr = (String) marketOverview.get("trade_date");
LocalDate tradeDate = LocalDate.parse(tradeDateStr.substring(0, 10)); // 提取日期部分
stmt.setDate(1, java.sql.Date.valueOf(tradeDate));
stmt.setLong(2, ((Number) marketOverview.get("up_count")).longValue());
stmt.setLong(3, ((Number) marketOverview.get("down_count")).longValue());
stmt.setLong(4, ((Number) marketOverview.get("flat_count")).longValue());
stmt.setLong(5, ((Number) marketOverview.get("total_count")).longValue());
stmt.setDouble(6, ((Number) marketOverview.get("total_market_cap")).doubleValue());
stmt.setLong(7, ((Number) marketOverview.get("total_volume")).longValue());
stmt.setDouble(8, ((Number) marketOverview.get("total_turnover")).doubleValue());
stmt.setDouble(9, ((Number) marketOverview.get("avg_change_percent")).doubleValue());
int rowsAffected = stmt.executeUpdate();
logger.info("市场分析结果保存成功,影响 {} 行数据", rowsAffected);
} catch (SQLException e) {
logger.error("保存市场分析结果失败", e);
throw new RuntimeException("保存市场分析结果到数据库失败", e);
}
}
/**
* 保存处理后的技术指标数据到新表(可选)
*/
public void saveTechnicalIndicators(Dataset<Row> dataWithIndicators) {
logger.info("开始保存技术指标数据到数据库...");
try {
// 选择需要保存的字段
Dataset<Row> selectedData = dataWithIndicators.select(
"stock_code", "stock_name", "trade_date", "close_price",
"ma5", "ma10", "ma20", "ma30",
"rsi", "macd_dif", "macd_dea",
"bb_upper", "bb_middle", "bb_lower"
);
// 保存到stock_technical_indicators表需要先创建此表
selectedData.write()
.mode(SaveMode.Overwrite)
.format("jdbc")
.option("url", jdbcUrl)
.option("dbtable", "stock_technical_indicators")
.option("user", config.getMysqlUser())
.option("password", config.getMysqlPassword())
.option("driver", "com.mysql.cj.jdbc.Driver")
.save();
logger.info("技术指标数据保存成功");
} catch (Exception e) {
logger.error("保存技术指标数据失败", e);
// 不抛出异常,允许程序继续执行
}
}
/**
* 保存行业分析结果
*/
public void saveIndustryAnalysis(Dataset<Row> industryAnalysis) {
logger.info("开始保存行业分析结果到数据库...");
try {
// 添加分析日期字段
Dataset<Row> industryWithDate = industryAnalysis.withColumn("analysis_date",
org.apache.spark.sql.functions.current_date());
// 保存到industry_analysis表需要先创建此表
industryWithDate.write()
.mode(SaveMode.Append)
.format("jdbc")
.option("url", jdbcUrl)
.option("dbtable", "industry_analysis")
.option("user", config.getMysqlUser())
.option("password", config.getMysqlPassword())
.option("driver", "com.mysql.cj.jdbc.Driver")
.save();
logger.info("行业分析结果保存成功");
} catch (Exception e) {
logger.error("保存行业分析结果失败", e);
// 不抛出异常,允许程序继续执行
}
}
/**
* 保存历史趋势数据
*/
public void saveHistoricalTrends(Dataset<Row> trendAnalysis) {
logger.info("开始保存历史趋势数据到数据库...");
try {
// 保存到market_trends表需要先创建此表
trendAnalysis.write()
.mode(SaveMode.Overwrite)
.format("jdbc")
.option("url", jdbcUrl)
.option("dbtable", "market_trends")
.option("user", config.getMysqlUser())
.option("password", config.getMysqlPassword())
.option("driver", "com.mysql.cj.jdbc.Driver")
.save();
logger.info("历史趋势数据保存成功");
} catch (Exception e) {
logger.error("保存历史趋势数据失败", e);
// 不抛出异常,允许程序继续执行
}
}
/**
* 测试数据库连接
*/
public boolean testConnection() {
try (Connection conn = DriverManager.getConnection(jdbcUrl, connectionProps)) {
logger.info("数据库连接测试成功");
return true;
} catch (SQLException e) {
logger.error("数据库连接测试失败", e);
return false;
}
}
}

View File

@@ -0,0 +1,205 @@
package com.agricultural.spark.service;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.apache.spark.sql.functions.*;
/**
* 市场分析服务类
*
* @author Agricultural Stock Platform Team
*/
public class MarketAnalysisService {
private static final Logger logger = LoggerFactory.getLogger(MarketAnalysisService.class);
private final SparkSession spark;
public MarketAnalysisService(SparkSession spark) {
this.spark = spark;
}
/**
* 市场总览分析
*/
public Map<String, Object> analyzeMarketOverview(Dataset<Row> data) {
logger.info("执行市场总览分析...");
// 获取最新交易日的数据
Row latestDateRow = data.agg(max("trade_date")).head();
if (latestDateRow.isNullAt(0)) {
logger.warn("无法获取最新交易日期");
return new HashMap<>();
}
Object latestDate = latestDateRow.get(0);
Dataset<Row> latestData = data.filter(col("trade_date").equalTo(latestDate));
// 统计涨跌股票数量
long upCount = latestData.filter(col("change_percent").gt(0)).count();
long downCount = latestData.filter(col("change_percent").lt(0)).count();
long flatCount = latestData.filter(col("change_percent").equalTo(0)).count();
long totalCount = latestData.count();
// 计算总市值和成交量
Row aggregateRow = latestData.agg(
sum("market_cap").alias("total_market_cap"),
sum("volume").alias("total_volume"),
sum("turnover").alias("total_turnover"),
avg("change_percent").alias("avg_change_percent")
).head();
double totalMarketCap = aggregateRow.isNullAt(0) ? 0.0 : aggregateRow.getDouble(0);
long totalVolume = aggregateRow.isNullAt(1) ? 0L : Math.round(aggregateRow.getDouble(1));
double totalTurnover = aggregateRow.isNullAt(2) ? 0.0 : aggregateRow.getDouble(2);
double avgChangePercent = aggregateRow.isNullAt(3) ? 0.0 : aggregateRow.getDouble(3);
Map<String, Object> result = new HashMap<>();
result.put("trade_date", latestDate.toString());
result.put("up_count", upCount);
result.put("down_count", downCount);
result.put("flat_count", flatCount);
result.put("total_count", totalCount);
result.put("total_market_cap", totalMarketCap);
result.put("total_volume", totalVolume);
result.put("total_turnover", totalTurnover);
result.put("avg_change_percent", Math.round(avgChangePercent * 100.0) / 100.0);
logger.info("市场总览分析完成: {}", result);
return result;
}
/**
* 分析涨跌幅榜单
*/
public Map<String, Object> analyzeTopGainersLosers(Dataset<Row> data, int limit) {
logger.info("分析涨跌幅榜单(前{}名)...", limit);
// 获取最新交易日数据
Row latestDateRow = data.agg(max("trade_date")).head();
Object latestDate = latestDateRow.get(0);
Dataset<Row> latestData = data.filter(col("trade_date").equalTo(latestDate));
// 涨幅榜
List<Row> topGainers = latestData
.orderBy(col("change_percent").desc())
.select("stock_code", "stock_name", "close_price", "change_percent",
"volume", "market_cap")
.limit(limit)
.collectAsList();
// 跌幅榜
List<Row> topLosers = latestData
.orderBy(col("change_percent").asc())
.select("stock_code", "stock_name", "close_price", "change_percent",
"volume", "market_cap")
.limit(limit)
.collectAsList();
// 成交量榜
List<Row> topVolume = latestData
.orderBy(col("volume").desc())
.select("stock_code", "stock_name", "close_price", "change_percent",
"volume", "turnover")
.limit(limit)
.collectAsList();
Map<String, Object> result = new HashMap<>();
result.put("top_gainers", topGainers);
result.put("top_losers", topLosers);
result.put("top_volume", topVolume);
logger.info("涨跌幅榜单分析完成");
return result;
}
/**
* 行业表现分析
*/
public Dataset<Row> analyzeIndustryPerformance(Dataset<Row> data) {
logger.info("执行行业表现分析...");
// 注册UDF函数用于行业分类
spark.udf().register("getIndustry", (String stockCode) -> {
if (stockCode == null || stockCode.length() < 6) {
return "其他农业";
}
String code = stockCode.substring(stockCode.length() - 6);
switch (code) {
case "000876":
case "002714":
return "畜牧业";
case "600519":
case "000858":
case "600887":
case "002304":
return "食品饮料";
default:
return "其他农业";
}
}, org.apache.spark.sql.types.DataTypes.StringType);
// 添加行业字段
Dataset<Row> dataWithIndustry = data.withColumn("industry",
callUDF("getIndustry", col("stock_code")));
// 获取最新交易日数据
Row latestDateRow = dataWithIndustry.agg(max("trade_date")).head();
Object latestDate = latestDateRow.get(0);
Dataset<Row> latestData = dataWithIndustry.filter(col("trade_date").equalTo(latestDate));
// 按行业统计
Dataset<Row> industryStats = latestData.groupBy("industry")
.agg(
count("*").alias("stock_count"),
avg("change_percent").alias("avg_change_percent"),
sum("market_cap").alias("total_market_cap"),
sum("volume").alias("total_volume")
)
.orderBy(col("avg_change_percent").desc());
logger.info("行业表现分析完成,共 {} 个行业", industryStats.count());
return industryStats;
}
/**
* 历史趋势分析
*/
public Dataset<Row> analyzeHistoricalTrends(Dataset<Row> data, String stockCode, int days) {
logger.info("分析历史趋势({}天)...", days);
// 计算起始日期
Row latestDateRow = data.agg(max("trade_date")).head();
Object endDate = latestDateRow.get(0);
// 这里简化处理,实际应该使用日期计算
Dataset<Row> trendData = data;
if (stockCode != null && !stockCode.isEmpty()) {
trendData = trendData.filter(col("stock_code").equalTo(stockCode));
}
// 按日期聚合
Dataset<Row> dailyStats = trendData.groupBy("trade_date")
.agg(
avg("close_price").alias("avg_price"),
avg("change_percent").alias("avg_change_percent"),
sum("volume").alias("total_volume"),
sum("turnover").alias("total_turnover"),
count("*").alias("stock_count")
)
.orderBy("trade_date");
logger.info("历史趋势分析完成");
return dailyStats;
}
}

View File

@@ -0,0 +1,188 @@
package com.agricultural.spark.service;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.Window;
import org.apache.spark.sql.expressions.WindowSpec;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import static org.apache.spark.sql.functions.*;
/**
* 技术指标服务类
*
* @author Agricultural Stock Platform Team
*/
public class TechnicalIndicatorService {
private static final Logger logger = LoggerFactory.getLogger(TechnicalIndicatorService.class);
private final SparkSession spark;
public TechnicalIndicatorService(SparkSession spark) {
this.spark = spark;
}
/**
* 计算技术指标
*/
public Dataset<Row> calculateTechnicalIndicators(Dataset<Row> data) {
logger.info("开始计算技术指标...");
// 按股票代码和日期排序的窗口
WindowSpec windowSpec = Window.partitionBy("stock_code")
.orderBy("trade_date")
.rowsBetween(-29, 0); // 30天窗口
WindowSpec windowSpec5 = Window.partitionBy("stock_code")
.orderBy("trade_date")
.rowsBetween(-4, 0); // 5天窗口
WindowSpec windowSpec10 = Window.partitionBy("stock_code")
.orderBy("trade_date")
.rowsBetween(-9, 0); // 10天窗口
WindowSpec windowSpec20 = Window.partitionBy("stock_code")
.orderBy("trade_date")
.rowsBetween(-19, 0); // 20天窗口
Dataset<Row> result = data;
// 1. 移动平均线 (MA)
result = result.withColumn("ma5", avg("close_price").over(windowSpec5))
.withColumn("ma10", avg("close_price").over(windowSpec10))
.withColumn("ma20", avg("close_price").over(windowSpec20))
.withColumn("ma30", avg("close_price").over(windowSpec));
// 2. 成交量移动平均
result = result.withColumn("volume_ma5", avg("volume").over(windowSpec5))
.withColumn("volume_ma10", avg("volume").over(windowSpec10));
// 3. 计算价格相对于移动平均线的位置
result = result.withColumn("price_vs_ma5",
when(col("ma5").notEqual(0),
col("close_price").divide(col("ma5")).multiply(100).minus(100))
.otherwise(0))
.withColumn("price_vs_ma20",
when(col("ma20").notEqual(0),
col("close_price").divide(col("ma20")).multiply(100).minus(100))
.otherwise(0));
// 4. 计算波动率 (30天)
result = result.withColumn("volatility_30d",
stddev("change_percent").over(windowSpec));
// 5. 计算相对强弱指标 RSI (简化版)
result = calculateRSI(result, 14);
// 6. 计算MACD指标
result = calculateMACD(result);
// 7. 计算布林带
result = calculateBollingerBands(result, 20);
logger.info("技术指标计算完成");
return result;
}
/**
* 计算RSI相对强弱指标
*/
private Dataset<Row> calculateRSI(Dataset<Row> data, int period) {
WindowSpec windowSpec = Window.partitionBy("stock_code")
.orderBy("trade_date")
.rowsBetween(-(period-1), 0);
// 计算价格变化
data = data.withColumn("price_diff",
col("close_price").minus(lag("close_price", 1)
.over(Window.partitionBy("stock_code").orderBy("trade_date"))));
// 计算涨跌
data = data.withColumn("gain", when(col("price_diff").gt(0), col("price_diff")).otherwise(0))
.withColumn("loss", when(col("price_diff").lt(0), abs(col("price_diff"))).otherwise(0));
// 计算平均涨跌
data = data.withColumn("avg_gain", avg("gain").over(windowSpec))
.withColumn("avg_loss", avg("loss").over(windowSpec));
// 计算RSI
data = data.withColumn("rsi",
when(col("avg_loss").notEqual(0),
lit(100).minus(lit(100).divide(lit(1).plus(col("avg_gain").divide(col("avg_loss"))))))
.otherwise(50));
return data.drop("price_diff", "gain", "loss", "avg_gain", "avg_loss");
}
/**
* 计算MACD指标
*/
private Dataset<Row> calculateMACD(Dataset<Row> data) {
// 计算EMA12和EMA26
data = data.withColumn("ema12", calculateEMA(col("close_price"), 12))
.withColumn("ema26", calculateEMA(col("close_price"), 26));
// 计算MACD线 (DIF)
data = data.withColumn("macd_dif", col("ema12").minus(col("ema26")));
// 计算DEA (MACD的9日EMA)
data = data.withColumn("macd_dea", calculateEMA(col("macd_dif"), 9));
// 计算MACD柱状图
data = data.withColumn("macd_histogram",
col("macd_dif").minus(col("macd_dea")).multiply(2));
return data.drop("ema12", "ema26");
}
/**
* 计算布林带
*/
private Dataset<Row> calculateBollingerBands(Dataset<Row> data, int period) {
WindowSpec windowSpec = Window.partitionBy("stock_code")
.orderBy("trade_date")
.rowsBetween(-(period-1), 0);
// 计算中轨(移动平均)
data = data.withColumn("bb_middle", avg("close_price").over(windowSpec));
// 计算标准差
data = data.withColumn("bb_std", stddev("close_price").over(windowSpec));
// 计算上轨和下轨
data = data.withColumn("bb_upper",
col("bb_middle").plus(col("bb_std").multiply(2)))
.withColumn("bb_lower",
col("bb_middle").minus(col("bb_std").multiply(2)));
// 计算布林带宽度
data = data.withColumn("bb_width",
when(col("bb_middle").notEqual(0),
col("bb_upper").minus(col("bb_lower")).divide(col("bb_middle")).multiply(100))
.otherwise(0));
// 计算价格在布林带中的位置
data = data.withColumn("bb_position",
when(col("bb_upper").notEqual(col("bb_lower")),
col("close_price").minus(col("bb_lower"))
.divide(col("bb_upper").minus(col("bb_lower"))))
.otherwise(0.5));
return data.drop("bb_std");
}
/**
* 计算指数移动平均 (EMA) - 简化版
*/
private org.apache.spark.sql.Column calculateEMA(org.apache.spark.sql.Column priceCol, int period) {
// 简化实现实际应该使用更复杂的EMA计算
WindowSpec windowSpec = Window.partitionBy("stock_code")
.orderBy("trade_date")
.rowsBetween(-(period-1), 0);
return avg(priceCol).over(windowSpec);
}
}

View File

@@ -0,0 +1,52 @@
<?xml version="1.0" encoding="UTF-8"?>
<configuration>
<!-- 控制台输出配置 -->
<appender name="CONSOLE" class="ch.qos.logback.core.ConsoleAppender">
<encoder>
<pattern>%d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n</pattern>
</encoder>
</appender>
<!-- 文件输出配置(可选) -->
<appender name="FILE" class="ch.qos.logback.core.rolling.RollingFileAppender">
<file>logs/spark-processor.log</file>
<rollingPolicy class="ch.qos.logback.core.rolling.TimeBasedRollingPolicy">
<fileNamePattern>logs/spark-processor.%d{yyyy-MM-dd}.%i.log</fileNamePattern>
<maxFileSize>10MB</maxFileSize>
<maxHistory>30</maxHistory>
</rollingPolicy>
<encoder>
<pattern>%d{yyyy-MM-dd HH:mm:ss.SSS} [%thread] %-5level %logger{50} - %msg%n</pattern>
</encoder>
</appender>
<!-- 设置特定包的日志级别 -->
<!-- 我们自己的应用程序日志 - 保持INFO级别 -->
<logger name="com.agricultural.spark" level="INFO" additivity="false">
<appender-ref ref="CONSOLE"/>
<appender-ref ref="FILE"/>
</logger>
<!-- Spark框架日志 - 设置为WARN级别减少输出 -->
<logger name="org.apache.spark" level="WARN"/>
<logger name="org.apache.hadoop" level="WARN"/>
<logger name="org.apache.hive" level="WARN"/>
<logger name="org.apache.parquet" level="WARN"/>
<!-- 数据库连接日志 - 设置为WARN级别 -->
<logger name="com.mysql" level="WARN"/>
<logger name="mysql" level="WARN"/>
<!-- Maven和其他框架日志 - 设置为ERROR级别 -->
<logger name="org.apache.maven" level="ERROR"/>
<logger name="org.eclipse.jetty" level="ERROR"/>
<logger name="io.netty" level="ERROR"/>
<!-- 根日志级别 -->
<root level="WARN">
<appender-ref ref="CONSOLE"/>
</root>
</configuration>