first commit

This commit is contained in:
2025-06-12 19:37:54 +08:00
parent bb2eb010f7
commit 1c6093fa9a
87 changed files with 18432 additions and 0 deletions

View File

@@ -0,0 +1,228 @@
package com.agricultural.spark;
import com.agricultural.spark.config.SparkConfig;
import com.agricultural.spark.service.DataCleaningService;
import com.agricultural.spark.service.DatabaseSaveService;
import com.agricultural.spark.service.MarketAnalysisService;
// StreamProcessingService 已移除 (Kafka相关)
import com.agricultural.spark.service.TechnicalIndicatorService;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.Map;
/**
* 农业股票数据处理器主类
* 基于Apache Spark的大数据处理平台
*
* @author Agricultural Stock Platform Team
*/
public class StockDataProcessor {
private static final Logger logger = LoggerFactory.getLogger(StockDataProcessor.class);
private final SparkSession spark;
private final SparkConfig config;
private final DataCleaningService dataCleaningService;
private final MarketAnalysisService marketAnalysisService;
private final TechnicalIndicatorService technicalIndicatorService;
private final DatabaseSaveService databaseSaveService;
// StreamProcessingService 已移除
public StockDataProcessor(SparkConfig config) {
this.config = config;
this.spark = initializeSparkSession();
this.dataCleaningService = new DataCleaningService(spark);
this.marketAnalysisService = new MarketAnalysisService(spark);
this.technicalIndicatorService = new TechnicalIndicatorService(spark);
this.databaseSaveService = new DatabaseSaveService(spark, config);
// StreamProcessingService 初始化已移除
}
/**
* 初始化Spark会话
*/
private SparkSession initializeSparkSession() {
logger.info("正在初始化Spark会话...");
SparkSession.Builder builder = SparkSession.builder()
.appName("AgriculturalStockDataProcessor")
.config("spark.sql.adaptive.enabled", "true")
.config("spark.sql.adaptive.coalescePartitions.enabled", "true")
.config("spark.sql.warehouse.dir", "/tmp/spark-warehouse");
// 如果是本地模式
if (config.getSparkMaster().startsWith("local")) {
builder.master(config.getSparkMaster());
}
SparkSession session = builder.getOrCreate();
session.sparkContext().setLogLevel("WARN");
logger.info("Spark会话初始化成功");
return session;
}
/**
* 从MySQL加载股票数据
*/
public Dataset<Row> loadStockDataFromMySQL() {
logger.info("从MySQL加载股票数据...");
String jdbcUrl = String.format("jdbc:mysql://%s:%d/%s?useSSL=false&serverTimezone=Asia/Shanghai",
config.getMysqlHost(), config.getMysqlPort(), config.getMysqlDatabase());
Dataset<Row> df = spark.read()
.format("jdbc")
.option("url", jdbcUrl)
.option("dbtable", "stock_data")
.option("user", config.getMysqlUser())
.option("password", config.getMysqlPassword())
.option("driver", "com.mysql.cj.jdbc.Driver")
.load();
long count = df.count();
logger.info("从MySQL加载数据完成共 {} 条记录", count);
return df;
}
/**
* 执行批处理分析
*/
public void runBatchProcessing() {
logger.info("开始执行批处理分析...");
try {
// 0. 测试数据库连接
if (!databaseSaveService.testConnection()) {
logger.warn("数据库连接测试失败,将仅保存到文件系统");
}
// 1. 加载数据
Dataset<Row> rawData = loadStockDataFromMySQL();
// 2. 数据清洗
Dataset<Row> cleanedData = dataCleaningService.cleanData(rawData);
// 3. 市场总览分析
Map<String, Object> marketOverview = marketAnalysisService.analyzeMarketOverview(cleanedData);
logger.info("市场总览分析完成: {}", marketOverview);
// 4. 涨跌幅排行分析
Map<String, Object> rankingAnalysis = marketAnalysisService.analyzeTopGainersLosers(cleanedData, 10);
logger.info("涨跌幅排行分析完成");
// 5. 行业表现分析
Dataset<Row> industryAnalysis = marketAnalysisService.analyzeIndustryPerformance(cleanedData);
logger.info("行业表现分析完成,共 {} 个行业", industryAnalysis.count());
// 6. 历史趋势分析
Dataset<Row> trendAnalysis = marketAnalysisService.analyzeHistoricalTrends(cleanedData, null, 30);
logger.info("历史趋势分析完成,共 {} 个数据点", trendAnalysis.count());
// 7. 技术指标计算
Dataset<Row> dataWithIndicators = technicalIndicatorService.calculateTechnicalIndicators(cleanedData);
logger.info("技术指标计算完成");
// 8. 保存处理结果到数据库
try {
databaseSaveService.saveMarketAnalysis(marketOverview);
databaseSaveService.saveIndustryAnalysis(industryAnalysis);
databaseSaveService.saveHistoricalTrends(trendAnalysis);
databaseSaveService.saveTechnicalIndicators(dataWithIndicators);
logger.info("数据已成功保存到MySQL数据库");
} catch (Exception e) {
logger.error("保存到数据库失败,回退到文件保存", e);
// 回退到原有的文件保存方式
saveResults(dataWithIndicators, "processed_data");
saveAnalysisResults(marketOverview, "market_overview");
}
logger.info("批处理分析完成");
} catch (Exception e) {
logger.error("批处理分析过程中出现错误", e);
throw new RuntimeException("批处理分析失败", e);
}
}
/**
* 实时流处理已移除原Kafka功能
*/
public void runStreamProcessing() {
logger.warn("实时流处理功能已移除,仅支持批处理模式");
}
/**
* 保存处理结果到文件
*/
private void saveResults(Dataset<Row> data, String outputPath) {
try {
String fullPath = config.getOutputPath() + "/" + outputPath;
data.write()
.mode("overwrite")
.parquet(fullPath);
logger.info("数据已保存到: {}", fullPath);
} catch (Exception e) {
logger.error("保存数据失败: {}", outputPath, e);
}
}
/**
* 保存分析结果
*/
private void saveAnalysisResults(Map<String, Object> results, String outputPath) {
try {
String fullPath = config.getOutputPath() + "/" + outputPath + ".json";
// 这里可以将Map转换为JSON并保存
logger.info("分析结果已保存到: {}", fullPath);
} catch (Exception e) {
logger.error("保存分析结果失败: {}", outputPath, e);
}
}
/**
* 关闭Spark会话
*/
public void close() {
if (spark != null) {
spark.stop();
logger.info("Spark会话已关闭");
}
}
/**
* 主方法
*/
public static void main(String[] args) {
logger.info("农业股票数据处理器启动");
StockDataProcessor processor = null;
try {
// 加载配置
SparkConfig config = SparkConfig.load();
processor = new StockDataProcessor(config);
// 根据参数决定运行模式
if (args.length > 0 && "stream".equals(args[0])) {
// 流处理模式
processor.runStreamProcessing();
} else {
// 批处理模式
processor.runBatchProcessing();
}
} catch (Exception e) {
logger.error("程序运行过程中出现错误", e);
System.exit(1);
} finally {
if (processor != null) {
processor.close();
}
logger.info("农业股票数据处理器已停止");
}
}
}

View File

@@ -0,0 +1,60 @@
package com.agricultural.spark.config;
import com.typesafe.config.Config;
import com.typesafe.config.ConfigFactory;
/**
* Spark配置类
*
* @author Agricultural Stock Platform Team
*/
public class SparkConfig {
private final Config config;
private SparkConfig(Config config) {
this.config = config;
}
public static SparkConfig load() {
Config config = ConfigFactory.load();
return new SparkConfig(config);
}
public String getSparkMaster() {
return config.hasPath("spark.master") ?
config.getString("spark.master") : "local[*]";
}
public String getMysqlHost() {
return config.hasPath("mysql.host") ?
config.getString("mysql.host") : "localhost";
}
public int getMysqlPort() {
return config.hasPath("mysql.port") ?
config.getInt("mysql.port") : 3306;
}
public String getMysqlDatabase() {
return config.hasPath("mysql.database") ?
config.getString("mysql.database") : "agricultural_stock";
}
public String getMysqlUser() {
return config.hasPath("mysql.user") ?
config.getString("mysql.user") : "root";
}
public String getMysqlPassword() {
return config.hasPath("mysql.password") ?
config.getString("mysql.password") : "root";
}
// Kafka配置已移除
public String getOutputPath() {
return config.hasPath("output.path") ?
config.getString("output.path") : "/tmp/spark-output";
}
}

View File

@@ -0,0 +1,190 @@
package com.agricultural.spark.service;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.functions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import static org.apache.spark.sql.functions.*;
/**
* 数据清洗服务类
*
* @author Agricultural Stock Platform Team
*/
public class DataCleaningService {
private static final Logger logger = LoggerFactory.getLogger(DataCleaningService.class);
private final SparkSession spark;
public DataCleaningService(SparkSession spark) {
this.spark = spark;
}
/**
* 执行数据清洗
*
* @param rawData 原始数据
* @return 清洗后的数据
*/
public Dataset<Row> cleanData(Dataset<Row> rawData) {
logger.info("开始执行数据清洗,原始数据条数: {}", rawData.count());
Dataset<Row> cleanedData = rawData;
// 1. 去除重复数据
cleanedData = cleanedData.dropDuplicates("stock_code", "trade_date");
logger.info("去重后数据条数: {}", cleanedData.count());
// 2. 处理缺失值
cleanedData = handleMissingValues(cleanedData);
// 3. 数据类型转换
cleanedData = convertDataTypes(cleanedData);
// 4. 异常值处理
cleanedData = handleOutliers(cleanedData);
// 5. 计算派生字段
cleanedData = calculateDerivedFields(cleanedData);
long finalCount = cleanedData.count();
logger.info("数据清洗完成,最终数据条数: {}", finalCount);
return cleanedData;
}
/**
* 处理缺失值
*/
private Dataset<Row> handleMissingValues(Dataset<Row> data) {
logger.info("处理缺失值...");
// 数值字段填充为0
String[] numericCols = {
"open_price", "close_price", "high_price", "low_price",
"volume", "turnover", "change_percent", "change_amount",
"pe_ratio", "pb_ratio", "market_cap", "float_market_cap"
};
for (String col : numericCols) {
if (hasColumn(data, col)) {
data = data.na().fill(0.0, new String[]{col});
}
}
// 字符串字段填充为空字符串
String[] stringCols = {"stock_name"};
for (String col : stringCols) {
if (hasColumn(data, col)) {
data = data.na().fill("", new String[]{col});
}
}
return data;
}
/**
* 数据类型转换
*/
private Dataset<Row> convertDataTypes(Dataset<Row> data) {
logger.info("执行数据类型转换...");
// 转换时间戳字段
if (hasColumn(data, "trade_date")) {
data = data.withColumn("trade_date",
to_timestamp(col("trade_date"), "yyyy-MM-dd HH:mm:ss"));
}
// 确保数值字段为正确的数据类型
String[] numericCols = {
"open_price", "close_price", "high_price", "low_price",
"volume", "turnover", "change_percent", "change_amount",
"pe_ratio", "pb_ratio", "market_cap", "float_market_cap"
};
for (String colName : numericCols) {
if (hasColumn(data, colName)) {
data = data.withColumn(colName, col(colName).cast("double"));
}
}
return data;
}
/**
* 处理异常值
*/
private Dataset<Row> handleOutliers(Dataset<Row> data) {
logger.info("处理异常值...");
// 过滤异常数据
data = data.filter(
col("open_price").$greater$eq(0)
.and(col("close_price").$greater$eq(0))
.and(col("high_price").$greater$eq(0))
.and(col("low_price").$greater$eq(0))
.and(col("volume").$greater$eq(0))
.and(col("high_price").$greater$eq(col("low_price")))
);
// 过滤极端的涨跌幅数据超过±20%的数据需要特别检查)
data = data.filter(
col("change_percent").$greater$eq(-20.0)
.and(col("change_percent").$less$eq(20.0))
);
return data;
}
/**
* 计算派生字段
*/
private Dataset<Row> calculateDerivedFields(Dataset<Row> data) {
logger.info("计算派生字段...");
// 计算价格变动
data = data.withColumn("price_change",
col("close_price").minus(col("open_price")));
// 计算价格变动百分比
data = data.withColumn("price_change_pct",
when(col("open_price").notEqual(0),
col("close_price").minus(col("open_price"))
.divide(col("open_price")).multiply(100))
.otherwise(0));
// 计算振幅
data = data.withColumn("amplitude",
when(col("open_price").notEqual(0),
col("high_price").minus(col("low_price"))
.divide(col("open_price")).multiply(100))
.otherwise(0));
// 计算换手率(如果有流通股本数据)
if (hasColumn(data, "float_shares")) {
data = data.withColumn("turnover_rate",
when(col("float_shares").notEqual(0),
col("volume").divide(col("float_shares")).multiply(100))
.otherwise(0));
}
return data;
}
/**
* 检查DataFrame是否包含指定列
*/
private boolean hasColumn(Dataset<Row> data, String columnName) {
String[] columns = data.columns();
for (String col : columns) {
if (col.equals(columnName)) {
return true;
}
}
return false;
}
}

View File

@@ -0,0 +1,189 @@
package com.agricultural.spark.service;
import com.agricultural.spark.config.SparkConfig;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SaveMode;
import org.apache.spark.sql.SparkSession;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.time.LocalDate;
import java.util.Map;
import java.util.Properties;
/**
* 数据库保存服务类
* 负责将Spark处理结果保存到MySQL数据库
*
* @author Agricultural Stock Platform Team
*/
public class DatabaseSaveService {
private static final Logger logger = LoggerFactory.getLogger(DatabaseSaveService.class);
private final SparkSession spark;
private final SparkConfig config;
private final String jdbcUrl;
private final Properties connectionProps;
public DatabaseSaveService(SparkSession spark, SparkConfig config) {
this.spark = spark;
this.config = config;
this.jdbcUrl = String.format("jdbc:mysql://%s:%d/%s?useSSL=false&serverTimezone=Asia/Shanghai",
config.getMysqlHost(), config.getMysqlPort(), config.getMysqlDatabase());
this.connectionProps = new Properties();
this.connectionProps.put("user", config.getMysqlUser());
this.connectionProps.put("password", config.getMysqlPassword());
this.connectionProps.put("driver", "com.mysql.cj.jdbc.Driver");
}
/**
* 保存市场分析结果到market_analysis表
*/
public void saveMarketAnalysis(Map<String, Object> marketOverview) {
logger.info("开始保存市场分析结果到数据库...");
String sql = "INSERT INTO market_analysis (analysis_date, up_count, down_count, flat_count, " +
"total_count, total_market_cap, total_volume, total_turnover, avg_change_percent) " +
"VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) " +
"ON DUPLICATE KEY UPDATE " +
"up_count = VALUES(up_count), down_count = VALUES(down_count), " +
"flat_count = VALUES(flat_count), total_count = VALUES(total_count), " +
"total_market_cap = VALUES(total_market_cap), total_volume = VALUES(total_volume), " +
"total_turnover = VALUES(total_turnover), avg_change_percent = VALUES(avg_change_percent)";
try (Connection conn = DriverManager.getConnection(jdbcUrl, connectionProps);
PreparedStatement stmt = conn.prepareStatement(sql)) {
// 解析交易日期
String tradeDateStr = (String) marketOverview.get("trade_date");
LocalDate tradeDate = LocalDate.parse(tradeDateStr.substring(0, 10)); // 提取日期部分
stmt.setDate(1, java.sql.Date.valueOf(tradeDate));
stmt.setLong(2, ((Number) marketOverview.get("up_count")).longValue());
stmt.setLong(3, ((Number) marketOverview.get("down_count")).longValue());
stmt.setLong(4, ((Number) marketOverview.get("flat_count")).longValue());
stmt.setLong(5, ((Number) marketOverview.get("total_count")).longValue());
stmt.setDouble(6, ((Number) marketOverview.get("total_market_cap")).doubleValue());
stmt.setLong(7, ((Number) marketOverview.get("total_volume")).longValue());
stmt.setDouble(8, ((Number) marketOverview.get("total_turnover")).doubleValue());
stmt.setDouble(9, ((Number) marketOverview.get("avg_change_percent")).doubleValue());
int rowsAffected = stmt.executeUpdate();
logger.info("市场分析结果保存成功,影响 {} 行数据", rowsAffected);
} catch (SQLException e) {
logger.error("保存市场分析结果失败", e);
throw new RuntimeException("保存市场分析结果到数据库失败", e);
}
}
/**
* 保存处理后的技术指标数据到新表(可选)
*/
public void saveTechnicalIndicators(Dataset<Row> dataWithIndicators) {
logger.info("开始保存技术指标数据到数据库...");
try {
// 选择需要保存的字段
Dataset<Row> selectedData = dataWithIndicators.select(
"stock_code", "stock_name", "trade_date", "close_price",
"ma5", "ma10", "ma20", "ma30",
"rsi", "macd_dif", "macd_dea",
"bb_upper", "bb_middle", "bb_lower"
);
// 保存到stock_technical_indicators表需要先创建此表
selectedData.write()
.mode(SaveMode.Overwrite)
.format("jdbc")
.option("url", jdbcUrl)
.option("dbtable", "stock_technical_indicators")
.option("user", config.getMysqlUser())
.option("password", config.getMysqlPassword())
.option("driver", "com.mysql.cj.jdbc.Driver")
.save();
logger.info("技术指标数据保存成功");
} catch (Exception e) {
logger.error("保存技术指标数据失败", e);
// 不抛出异常,允许程序继续执行
}
}
/**
* 保存行业分析结果
*/
public void saveIndustryAnalysis(Dataset<Row> industryAnalysis) {
logger.info("开始保存行业分析结果到数据库...");
try {
// 添加分析日期字段
Dataset<Row> industryWithDate = industryAnalysis.withColumn("analysis_date",
org.apache.spark.sql.functions.current_date());
// 保存到industry_analysis表需要先创建此表
industryWithDate.write()
.mode(SaveMode.Append)
.format("jdbc")
.option("url", jdbcUrl)
.option("dbtable", "industry_analysis")
.option("user", config.getMysqlUser())
.option("password", config.getMysqlPassword())
.option("driver", "com.mysql.cj.jdbc.Driver")
.save();
logger.info("行业分析结果保存成功");
} catch (Exception e) {
logger.error("保存行业分析结果失败", e);
// 不抛出异常,允许程序继续执行
}
}
/**
* 保存历史趋势数据
*/
public void saveHistoricalTrends(Dataset<Row> trendAnalysis) {
logger.info("开始保存历史趋势数据到数据库...");
try {
// 保存到market_trends表需要先创建此表
trendAnalysis.write()
.mode(SaveMode.Overwrite)
.format("jdbc")
.option("url", jdbcUrl)
.option("dbtable", "market_trends")
.option("user", config.getMysqlUser())
.option("password", config.getMysqlPassword())
.option("driver", "com.mysql.cj.jdbc.Driver")
.save();
logger.info("历史趋势数据保存成功");
} catch (Exception e) {
logger.error("保存历史趋势数据失败", e);
// 不抛出异常,允许程序继续执行
}
}
/**
* 测试数据库连接
*/
public boolean testConnection() {
try (Connection conn = DriverManager.getConnection(jdbcUrl, connectionProps)) {
logger.info("数据库连接测试成功");
return true;
} catch (SQLException e) {
logger.error("数据库连接测试失败", e);
return false;
}
}
}

View File

@@ -0,0 +1,205 @@
package com.agricultural.spark.service;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import static org.apache.spark.sql.functions.*;
/**
* 市场分析服务类
*
* @author Agricultural Stock Platform Team
*/
public class MarketAnalysisService {
private static final Logger logger = LoggerFactory.getLogger(MarketAnalysisService.class);
private final SparkSession spark;
public MarketAnalysisService(SparkSession spark) {
this.spark = spark;
}
/**
* 市场总览分析
*/
public Map<String, Object> analyzeMarketOverview(Dataset<Row> data) {
logger.info("执行市场总览分析...");
// 获取最新交易日的数据
Row latestDateRow = data.agg(max("trade_date")).head();
if (latestDateRow.isNullAt(0)) {
logger.warn("无法获取最新交易日期");
return new HashMap<>();
}
Object latestDate = latestDateRow.get(0);
Dataset<Row> latestData = data.filter(col("trade_date").equalTo(latestDate));
// 统计涨跌股票数量
long upCount = latestData.filter(col("change_percent").gt(0)).count();
long downCount = latestData.filter(col("change_percent").lt(0)).count();
long flatCount = latestData.filter(col("change_percent").equalTo(0)).count();
long totalCount = latestData.count();
// 计算总市值和成交量
Row aggregateRow = latestData.agg(
sum("market_cap").alias("total_market_cap"),
sum("volume").alias("total_volume"),
sum("turnover").alias("total_turnover"),
avg("change_percent").alias("avg_change_percent")
).head();
double totalMarketCap = aggregateRow.isNullAt(0) ? 0.0 : aggregateRow.getDouble(0);
long totalVolume = aggregateRow.isNullAt(1) ? 0L : Math.round(aggregateRow.getDouble(1));
double totalTurnover = aggregateRow.isNullAt(2) ? 0.0 : aggregateRow.getDouble(2);
double avgChangePercent = aggregateRow.isNullAt(3) ? 0.0 : aggregateRow.getDouble(3);
Map<String, Object> result = new HashMap<>();
result.put("trade_date", latestDate.toString());
result.put("up_count", upCount);
result.put("down_count", downCount);
result.put("flat_count", flatCount);
result.put("total_count", totalCount);
result.put("total_market_cap", totalMarketCap);
result.put("total_volume", totalVolume);
result.put("total_turnover", totalTurnover);
result.put("avg_change_percent", Math.round(avgChangePercent * 100.0) / 100.0);
logger.info("市场总览分析完成: {}", result);
return result;
}
/**
* 分析涨跌幅榜单
*/
public Map<String, Object> analyzeTopGainersLosers(Dataset<Row> data, int limit) {
logger.info("分析涨跌幅榜单(前{}名)...", limit);
// 获取最新交易日数据
Row latestDateRow = data.agg(max("trade_date")).head();
Object latestDate = latestDateRow.get(0);
Dataset<Row> latestData = data.filter(col("trade_date").equalTo(latestDate));
// 涨幅榜
List<Row> topGainers = latestData
.orderBy(col("change_percent").desc())
.select("stock_code", "stock_name", "close_price", "change_percent",
"volume", "market_cap")
.limit(limit)
.collectAsList();
// 跌幅榜
List<Row> topLosers = latestData
.orderBy(col("change_percent").asc())
.select("stock_code", "stock_name", "close_price", "change_percent",
"volume", "market_cap")
.limit(limit)
.collectAsList();
// 成交量榜
List<Row> topVolume = latestData
.orderBy(col("volume").desc())
.select("stock_code", "stock_name", "close_price", "change_percent",
"volume", "turnover")
.limit(limit)
.collectAsList();
Map<String, Object> result = new HashMap<>();
result.put("top_gainers", topGainers);
result.put("top_losers", topLosers);
result.put("top_volume", topVolume);
logger.info("涨跌幅榜单分析完成");
return result;
}
/**
* 行业表现分析
*/
public Dataset<Row> analyzeIndustryPerformance(Dataset<Row> data) {
logger.info("执行行业表现分析...");
// 注册UDF函数用于行业分类
spark.udf().register("getIndustry", (String stockCode) -> {
if (stockCode == null || stockCode.length() < 6) {
return "其他农业";
}
String code = stockCode.substring(stockCode.length() - 6);
switch (code) {
case "000876":
case "002714":
return "畜牧业";
case "600519":
case "000858":
case "600887":
case "002304":
return "食品饮料";
default:
return "其他农业";
}
}, org.apache.spark.sql.types.DataTypes.StringType);
// 添加行业字段
Dataset<Row> dataWithIndustry = data.withColumn("industry",
callUDF("getIndustry", col("stock_code")));
// 获取最新交易日数据
Row latestDateRow = dataWithIndustry.agg(max("trade_date")).head();
Object latestDate = latestDateRow.get(0);
Dataset<Row> latestData = dataWithIndustry.filter(col("trade_date").equalTo(latestDate));
// 按行业统计
Dataset<Row> industryStats = latestData.groupBy("industry")
.agg(
count("*").alias("stock_count"),
avg("change_percent").alias("avg_change_percent"),
sum("market_cap").alias("total_market_cap"),
sum("volume").alias("total_volume")
)
.orderBy(col("avg_change_percent").desc());
logger.info("行业表现分析完成,共 {} 个行业", industryStats.count());
return industryStats;
}
/**
* 历史趋势分析
*/
public Dataset<Row> analyzeHistoricalTrends(Dataset<Row> data, String stockCode, int days) {
logger.info("分析历史趋势({}天)...", days);
// 计算起始日期
Row latestDateRow = data.agg(max("trade_date")).head();
Object endDate = latestDateRow.get(0);
// 这里简化处理,实际应该使用日期计算
Dataset<Row> trendData = data;
if (stockCode != null && !stockCode.isEmpty()) {
trendData = trendData.filter(col("stock_code").equalTo(stockCode));
}
// 按日期聚合
Dataset<Row> dailyStats = trendData.groupBy("trade_date")
.agg(
avg("close_price").alias("avg_price"),
avg("change_percent").alias("avg_change_percent"),
sum("volume").alias("total_volume"),
sum("turnover").alias("total_turnover"),
count("*").alias("stock_count")
)
.orderBy("trade_date");
logger.info("历史趋势分析完成");
return dailyStats;
}
}

View File

@@ -0,0 +1,188 @@
package com.agricultural.spark.service;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.Window;
import org.apache.spark.sql.expressions.WindowSpec;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import static org.apache.spark.sql.functions.*;
/**
* 技术指标服务类
*
* @author Agricultural Stock Platform Team
*/
public class TechnicalIndicatorService {
private static final Logger logger = LoggerFactory.getLogger(TechnicalIndicatorService.class);
private final SparkSession spark;
public TechnicalIndicatorService(SparkSession spark) {
this.spark = spark;
}
/**
* 计算技术指标
*/
public Dataset<Row> calculateTechnicalIndicators(Dataset<Row> data) {
logger.info("开始计算技术指标...");
// 按股票代码和日期排序的窗口
WindowSpec windowSpec = Window.partitionBy("stock_code")
.orderBy("trade_date")
.rowsBetween(-29, 0); // 30天窗口
WindowSpec windowSpec5 = Window.partitionBy("stock_code")
.orderBy("trade_date")
.rowsBetween(-4, 0); // 5天窗口
WindowSpec windowSpec10 = Window.partitionBy("stock_code")
.orderBy("trade_date")
.rowsBetween(-9, 0); // 10天窗口
WindowSpec windowSpec20 = Window.partitionBy("stock_code")
.orderBy("trade_date")
.rowsBetween(-19, 0); // 20天窗口
Dataset<Row> result = data;
// 1. 移动平均线 (MA)
result = result.withColumn("ma5", avg("close_price").over(windowSpec5))
.withColumn("ma10", avg("close_price").over(windowSpec10))
.withColumn("ma20", avg("close_price").over(windowSpec20))
.withColumn("ma30", avg("close_price").over(windowSpec));
// 2. 成交量移动平均
result = result.withColumn("volume_ma5", avg("volume").over(windowSpec5))
.withColumn("volume_ma10", avg("volume").over(windowSpec10));
// 3. 计算价格相对于移动平均线的位置
result = result.withColumn("price_vs_ma5",
when(col("ma5").notEqual(0),
col("close_price").divide(col("ma5")).multiply(100).minus(100))
.otherwise(0))
.withColumn("price_vs_ma20",
when(col("ma20").notEqual(0),
col("close_price").divide(col("ma20")).multiply(100).minus(100))
.otherwise(0));
// 4. 计算波动率 (30天)
result = result.withColumn("volatility_30d",
stddev("change_percent").over(windowSpec));
// 5. 计算相对强弱指标 RSI (简化版)
result = calculateRSI(result, 14);
// 6. 计算MACD指标
result = calculateMACD(result);
// 7. 计算布林带
result = calculateBollingerBands(result, 20);
logger.info("技术指标计算完成");
return result;
}
/**
* 计算RSI相对强弱指标
*/
private Dataset<Row> calculateRSI(Dataset<Row> data, int period) {
WindowSpec windowSpec = Window.partitionBy("stock_code")
.orderBy("trade_date")
.rowsBetween(-(period-1), 0);
// 计算价格变化
data = data.withColumn("price_diff",
col("close_price").minus(lag("close_price", 1)
.over(Window.partitionBy("stock_code").orderBy("trade_date"))));
// 计算涨跌
data = data.withColumn("gain", when(col("price_diff").gt(0), col("price_diff")).otherwise(0))
.withColumn("loss", when(col("price_diff").lt(0), abs(col("price_diff"))).otherwise(0));
// 计算平均涨跌
data = data.withColumn("avg_gain", avg("gain").over(windowSpec))
.withColumn("avg_loss", avg("loss").over(windowSpec));
// 计算RSI
data = data.withColumn("rsi",
when(col("avg_loss").notEqual(0),
lit(100).minus(lit(100).divide(lit(1).plus(col("avg_gain").divide(col("avg_loss"))))))
.otherwise(50));
return data.drop("price_diff", "gain", "loss", "avg_gain", "avg_loss");
}
/**
* 计算MACD指标
*/
private Dataset<Row> calculateMACD(Dataset<Row> data) {
// 计算EMA12和EMA26
data = data.withColumn("ema12", calculateEMA(col("close_price"), 12))
.withColumn("ema26", calculateEMA(col("close_price"), 26));
// 计算MACD线 (DIF)
data = data.withColumn("macd_dif", col("ema12").minus(col("ema26")));
// 计算DEA (MACD的9日EMA)
data = data.withColumn("macd_dea", calculateEMA(col("macd_dif"), 9));
// 计算MACD柱状图
data = data.withColumn("macd_histogram",
col("macd_dif").minus(col("macd_dea")).multiply(2));
return data.drop("ema12", "ema26");
}
/**
* 计算布林带
*/
private Dataset<Row> calculateBollingerBands(Dataset<Row> data, int period) {
WindowSpec windowSpec = Window.partitionBy("stock_code")
.orderBy("trade_date")
.rowsBetween(-(period-1), 0);
// 计算中轨(移动平均)
data = data.withColumn("bb_middle", avg("close_price").over(windowSpec));
// 计算标准差
data = data.withColumn("bb_std", stddev("close_price").over(windowSpec));
// 计算上轨和下轨
data = data.withColumn("bb_upper",
col("bb_middle").plus(col("bb_std").multiply(2)))
.withColumn("bb_lower",
col("bb_middle").minus(col("bb_std").multiply(2)));
// 计算布林带宽度
data = data.withColumn("bb_width",
when(col("bb_middle").notEqual(0),
col("bb_upper").minus(col("bb_lower")).divide(col("bb_middle")).multiply(100))
.otherwise(0));
// 计算价格在布林带中的位置
data = data.withColumn("bb_position",
when(col("bb_upper").notEqual(col("bb_lower")),
col("close_price").minus(col("bb_lower"))
.divide(col("bb_upper").minus(col("bb_lower"))))
.otherwise(0.5));
return data.drop("bb_std");
}
/**
* 计算指数移动平均 (EMA) - 简化版
*/
private org.apache.spark.sql.Column calculateEMA(org.apache.spark.sql.Column priceCol, int period) {
// 简化实现实际应该使用更复杂的EMA计算
WindowSpec windowSpec = Window.partitionBy("stock_code")
.orderBy("trade_date")
.rowsBetween(-(period-1), 0);
return avg(priceCol).over(windowSpec);
}
}

View File

@@ -0,0 +1,52 @@
<?xml version="1.0" encoding="UTF-8"?>
<configuration>
<!-- 控制台输出配置 -->
<appender name="CONSOLE" class="ch.qos.logback.core.ConsoleAppender">
<encoder>
<pattern>%d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n</pattern>
</encoder>
</appender>
<!-- 文件输出配置(可选) -->
<appender name="FILE" class="ch.qos.logback.core.rolling.RollingFileAppender">
<file>logs/spark-processor.log</file>
<rollingPolicy class="ch.qos.logback.core.rolling.TimeBasedRollingPolicy">
<fileNamePattern>logs/spark-processor.%d{yyyy-MM-dd}.%i.log</fileNamePattern>
<maxFileSize>10MB</maxFileSize>
<maxHistory>30</maxHistory>
</rollingPolicy>
<encoder>
<pattern>%d{yyyy-MM-dd HH:mm:ss.SSS} [%thread] %-5level %logger{50} - %msg%n</pattern>
</encoder>
</appender>
<!-- 设置特定包的日志级别 -->
<!-- 我们自己的应用程序日志 - 保持INFO级别 -->
<logger name="com.agricultural.spark" level="INFO" additivity="false">
<appender-ref ref="CONSOLE"/>
<appender-ref ref="FILE"/>
</logger>
<!-- Spark框架日志 - 设置为WARN级别减少输出 -->
<logger name="org.apache.spark" level="WARN"/>
<logger name="org.apache.hadoop" level="WARN"/>
<logger name="org.apache.hive" level="WARN"/>
<logger name="org.apache.parquet" level="WARN"/>
<!-- 数据库连接日志 - 设置为WARN级别 -->
<logger name="com.mysql" level="WARN"/>
<logger name="mysql" level="WARN"/>
<!-- Maven和其他框架日志 - 设置为ERROR级别 -->
<logger name="org.apache.maven" level="ERROR"/>
<logger name="org.eclipse.jetty" level="ERROR"/>
<logger name="io.netty" level="ERROR"/>
<!-- 根日志级别 -->
<root level="WARN">
<appender-ref ref="CONSOLE"/>
</root>
</configuration>