first commit
@@ -0,0 +1,228 @@
package com.agricultural.spark;

import com.agricultural.spark.config.SparkConfig;
import com.agricultural.spark.service.DataCleaningService;
import com.agricultural.spark.service.DatabaseSaveService;
import com.agricultural.spark.service.MarketAnalysisService;
// StreamProcessingService removed (Kafka-related)
import com.agricultural.spark.service.TechnicalIndicatorService;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Map;

/**
 * Main class of the agricultural stock data processor,
 * a big data processing platform built on Apache Spark.
 *
 * @author Agricultural Stock Platform Team
 */
public class StockDataProcessor {

    private static final Logger logger = LoggerFactory.getLogger(StockDataProcessor.class);

    private final SparkSession spark;
    private final SparkConfig config;
    private final DataCleaningService dataCleaningService;
    private final MarketAnalysisService marketAnalysisService;
    private final TechnicalIndicatorService technicalIndicatorService;
    private final DatabaseSaveService databaseSaveService;
    // StreamProcessingService removed

    public StockDataProcessor(SparkConfig config) {
        this.config = config;
        this.spark = initializeSparkSession();
        this.dataCleaningService = new DataCleaningService(spark);
        this.marketAnalysisService = new MarketAnalysisService(spark);
        this.technicalIndicatorService = new TechnicalIndicatorService(spark);
        this.databaseSaveService = new DatabaseSaveService(spark, config);
        // StreamProcessingService initialization removed
    }

    /**
     * Initialize the Spark session.
     */
    private SparkSession initializeSparkSession() {
        logger.info("Initializing Spark session...");

        SparkSession.Builder builder = SparkSession.builder()
                .appName("AgriculturalStockDataProcessor")
                .config("spark.sql.adaptive.enabled", "true")
                .config("spark.sql.adaptive.coalescePartitions.enabled", "true")
                .config("spark.sql.warehouse.dir", "/tmp/spark-warehouse");

        // Only set the master explicitly when running in local mode;
        // on a cluster the master is supplied by spark-submit
        if (config.getSparkMaster().startsWith("local")) {
            builder.master(config.getSparkMaster());
        }

        SparkSession session = builder.getOrCreate();
        session.sparkContext().setLogLevel("WARN");

        logger.info("Spark session initialized successfully");
        return session;
    }

    /**
     * Load stock data from MySQL.
     */
    public Dataset<Row> loadStockDataFromMySQL() {
        logger.info("Loading stock data from MySQL...");

        String jdbcUrl = String.format("jdbc:mysql://%s:%d/%s?useSSL=false&serverTimezone=Asia/Shanghai",
                config.getMysqlHost(), config.getMysqlPort(), config.getMysqlDatabase());

        Dataset<Row> df = spark.read()
                .format("jdbc")
                .option("url", jdbcUrl)
                .option("dbtable", "stock_data")
                .option("user", config.getMysqlUser())
                .option("password", config.getMysqlPassword())
                .option("driver", "com.mysql.cj.jdbc.Driver")
                .load();

        long count = df.count();
        logger.info("Loaded {} records from MySQL", count);
        return df;
    }
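
    /*
     * Sketch (not in the original commit): the read above runs as a single
     * JDBC partition. Spark's JDBC source can parallelize the scan via
     * partitioning options. Assumes stock_data has a numeric auto-increment
     * `id` column; the bounds below are illustrative placeholders.
     */
    private Dataset<Row> loadStockDataPartitioned(String jdbcUrl) {
        return spark.read()
                .format("jdbc")
                .option("url", jdbcUrl)
                .option("dbtable", "stock_data")
                .option("user", config.getMysqlUser())
                .option("password", config.getMysqlPassword())
                .option("driver", "com.mysql.cj.jdbc.Driver")
                .option("partitionColumn", "id") // numeric column Spark splits ranges on
                .option("lowerBound", "1")
                .option("upperBound", "1000000")
                .option("numPartitions", "8")    // up to 8 concurrent JDBC readers
                .load();
    }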

    /**
     * Run the batch-processing analysis pipeline.
     */
    public void runBatchProcessing() {
        logger.info("Starting batch-processing analysis...");

        try {
            // 0. Test the database connection
            if (!databaseSaveService.testConnection()) {
                logger.warn("Database connection test failed; results will only be saved to the file system");
            }

            // 1. Load data
            Dataset<Row> rawData = loadStockDataFromMySQL();

            // 2. Data cleaning
            Dataset<Row> cleanedData = dataCleaningService.cleanData(rawData);

            // 3. Market overview analysis
            Map<String, Object> marketOverview = marketAnalysisService.analyzeMarketOverview(cleanedData);
            logger.info("Market overview analysis finished: {}", marketOverview);

            // 4. Top gainers/losers ranking analysis
            // (note: rankingAnalysis is computed but not persisted in this version)
            Map<String, Object> rankingAnalysis = marketAnalysisService.analyzeTopGainersLosers(cleanedData, 10);
            logger.info("Gainers/losers ranking analysis finished");

            // 5. Industry performance analysis
            Dataset<Row> industryAnalysis = marketAnalysisService.analyzeIndustryPerformance(cleanedData);
            logger.info("Industry performance analysis finished, {} industries in total", industryAnalysis.count());

            // 6. Historical trend analysis
            Dataset<Row> trendAnalysis = marketAnalysisService.analyzeHistoricalTrends(cleanedData, null, 30);
            logger.info("Historical trend analysis finished, {} data points in total", trendAnalysis.count());

            // 7. Technical indicator calculation
            Dataset<Row> dataWithIndicators = technicalIndicatorService.calculateTechnicalIndicators(cleanedData);
            logger.info("Technical indicator calculation finished");

            // 8. Save the processing results to the database
            try {
                databaseSaveService.saveMarketAnalysis(marketOverview);
                databaseSaveService.saveIndustryAnalysis(industryAnalysis);
                databaseSaveService.saveHistoricalTrends(trendAnalysis);
                databaseSaveService.saveTechnicalIndicators(dataWithIndicators);
                logger.info("Data saved to the MySQL database successfully");
            } catch (Exception e) {
                logger.error("Saving to the database failed; falling back to file output", e);
                // Fall back to the original file-based persistence
                saveResults(dataWithIndicators, "processed_data");
                saveAnalysisResults(marketOverview, "market_overview");
            }

            logger.info("Batch-processing analysis finished");

        } catch (Exception e) {
            logger.error("Error during batch-processing analysis", e);
            throw new RuntimeException("Batch-processing analysis failed", e);
        }
    }

    /**
     * Real-time stream processing has been removed (former Kafka functionality).
     */
    public void runStreamProcessing() {
        logger.warn("Stream processing has been removed; only batch mode is supported");
    }

    /**
     * Save processing results to files.
     */
    private void saveResults(Dataset<Row> data, String outputPath) {
        try {
            String fullPath = config.getOutputPath() + "/" + outputPath;
            data.write()
                    .mode("overwrite")
                    .parquet(fullPath);
            logger.info("Data saved to: {}", fullPath);
        } catch (Exception e) {
            logger.error("Failed to save data: {}", outputPath, e);
        }
    }

    /**
     * Save analysis results.
     * Note: currently a stub — the Map is not yet serialized, only the
     * target path is logged.
     */
    private void saveAnalysisResults(Map<String, Object> results, String outputPath) {
        try {
            String fullPath = config.getOutputPath() + "/" + outputPath + ".json";
            // The Map could be converted to JSON and written here
            logger.info("Analysis results saved to: {}", fullPath);
        } catch (Exception e) {
            logger.error("Failed to save analysis results: {}", outputPath, e);
        }
    }
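
    /*
     * Sketch (not in the original commit): one way to complete the stub
     * above. Assumes Jackson (com.fasterxml.jackson.databind) is available
     * on the classpath; any JSON library would do.
     */
    private void saveAnalysisResultsAsJson(Map<String, Object> results, String outputPath) throws Exception {
        String fullPath = config.getOutputPath() + "/" + outputPath + ".json";
        // Serialize the result map and write it out with plain java.nio
        String json = new com.fasterxml.jackson.databind.ObjectMapper()
                .writerWithDefaultPrettyPrinter()
                .writeValueAsString(results);
        java.nio.file.Files.write(
                java.nio.file.Paths.get(fullPath),
                json.getBytes(java.nio.charset.StandardCharsets.UTF_8));
    }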

    /**
     * Close the Spark session.
     */
    public void close() {
        if (spark != null) {
            spark.stop();
            logger.info("Spark session closed");
        }
    }

    /**
     * Main method.
     */
    public static void main(String[] args) {
        logger.info("Agricultural stock data processor starting");

        StockDataProcessor processor = null;
        try {
            // Load configuration
            SparkConfig config = SparkConfig.load();
            processor = new StockDataProcessor(config);

            // Choose the run mode based on the arguments
            if (args.length > 0 && "stream".equals(args[0])) {
                // Stream-processing mode (now a no-op, see runStreamProcessing)
                processor.runStreamProcessing();
            } else {
                // Batch-processing mode
                processor.runBatchProcessing();
            }

        } catch (Exception e) {
            logger.error("Error while running the program", e);
            System.exit(1);
        } finally {
            if (processor != null) {
                processor.close();
            }
            logger.info("Agricultural stock data processor stopped");
        }
    }
}
@@ -0,0 +1,60 @@
package com.agricultural.spark.config;

import com.typesafe.config.Config;
import com.typesafe.config.ConfigFactory;

/**
 * Spark configuration class.
 *
 * @author Agricultural Stock Platform Team
 */
public class SparkConfig {

    private final Config config;

    private SparkConfig(Config config) {
        this.config = config;
    }

    public static SparkConfig load() {
        Config config = ConfigFactory.load();
        return new SparkConfig(config);
    }

    public String getSparkMaster() {
        return config.hasPath("spark.master") ?
                config.getString("spark.master") : "local[*]";
    }

    public String getMysqlHost() {
        return config.hasPath("mysql.host") ?
                config.getString("mysql.host") : "localhost";
    }

    public int getMysqlPort() {
        return config.hasPath("mysql.port") ?
                config.getInt("mysql.port") : 3306;
    }

    public String getMysqlDatabase() {
        return config.hasPath("mysql.database") ?
                config.getString("mysql.database") : "agricultural_stock";
    }

    public String getMysqlUser() {
        return config.hasPath("mysql.user") ?
                config.getString("mysql.user") : "root";
    }

    public String getMysqlPassword() {
        return config.hasPath("mysql.password") ?
                config.getString("mysql.password") : "root";
    }

    // Kafka configuration removed

    public String getOutputPath() {
        return config.hasPath("output.path") ?
                config.getString("output.path") : "/tmp/spark-output";
    }
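
    /*
     * Sketch (not in the original commit): the hasPath/default pattern above
     * repeats for every string key; a small helper would collapse it.
     */
    private String getStringOrDefault(String path, String defaultValue) {
        // Typesafe Config throws on missing paths, hence the hasPath guard
        return config.hasPath(path) ? config.getString(path) : defaultValue;
    }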
}
@@ -0,0 +1,190 @@
package com.agricultural.spark.service;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Arrays;

import static org.apache.spark.sql.functions.*;

/**
 * Data cleaning service.
 *
 * @author Agricultural Stock Platform Team
 */
public class DataCleaningService {

    private static final Logger logger = LoggerFactory.getLogger(DataCleaningService.class);

    private final SparkSession spark;

    public DataCleaningService(SparkSession spark) {
        this.spark = spark;
    }

    /**
     * Run the data-cleaning pipeline.
     *
     * @param rawData the raw data
     * @return the cleaned data
     */
    public Dataset<Row> cleanData(Dataset<Row> rawData) {
        logger.info("Starting data cleaning, raw record count: {}", rawData.count());

        Dataset<Row> cleanedData = rawData;

        // 1. Remove duplicate rows (same stock and trading day)
        cleanedData = cleanedData.dropDuplicates("stock_code", "trade_date");
        logger.info("Record count after deduplication: {}", cleanedData.count());

        // 2. Handle missing values
        cleanedData = handleMissingValues(cleanedData);

        // 3. Convert data types
        cleanedData = convertDataTypes(cleanedData);

        // 4. Handle outliers
        cleanedData = handleOutliers(cleanedData);

        // 5. Compute derived fields
        cleanedData = calculateDerivedFields(cleanedData);

        long finalCount = cleanedData.count();
        logger.info("Data cleaning finished, final record count: {}", finalCount);

        return cleanedData;
    }

    /**
     * Handle missing values.
     */
    private Dataset<Row> handleMissingValues(Dataset<Row> data) {
        logger.info("Handling missing values...");

        // Fill numeric fields with 0
        String[] numericCols = {
                "open_price", "close_price", "high_price", "low_price",
                "volume", "turnover", "change_percent", "change_amount",
                "pe_ratio", "pb_ratio", "market_cap", "float_market_cap"
        };

        for (String col : numericCols) {
            if (hasColumn(data, col)) {
                data = data.na().fill(0.0, new String[]{col});
            }
        }

        // Fill string fields with the empty string
        String[] stringCols = {"stock_name"};
        for (String col : stringCols) {
            if (hasColumn(data, col)) {
                data = data.na().fill("", new String[]{col});
            }
        }

        return data;
    }

    /**
     * Convert data types.
     */
    private Dataset<Row> convertDataTypes(Dataset<Row> data) {
        logger.info("Converting data types...");

        // Convert the timestamp field
        if (hasColumn(data, "trade_date")) {
            data = data.withColumn("trade_date",
                    to_timestamp(col("trade_date"), "yyyy-MM-dd HH:mm:ss"));
        }

        // Make sure numeric fields have the correct data type
        String[] numericCols = {
                "open_price", "close_price", "high_price", "low_price",
                "volume", "turnover", "change_percent", "change_amount",
                "pe_ratio", "pb_ratio", "market_cap", "float_market_cap"
        };

        for (String colName : numericCols) {
            if (hasColumn(data, colName)) {
                data = data.withColumn(colName, col(colName).cast("double"));
            }
        }

        return data;
    }

    /**
     * Handle outliers.
     */
    private Dataset<Row> handleOutliers(Dataset<Row> data) {
        logger.info("Handling outliers...");

        // Filter out rows with impossible values
        data = data.filter(
                col("open_price").geq(0)
                        .and(col("close_price").geq(0))
                        .and(col("high_price").geq(0))
                        .and(col("low_price").geq(0))
                        .and(col("volume").geq(0))
                        .and(col("high_price").geq(col("low_price")))
        );

        // Filter out extreme change percentages (moves beyond ±20% need manual review)
        data = data.filter(
                col("change_percent").geq(-20.0)
                        .and(col("change_percent").leq(20.0))
        );

        return data;
    }
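
    /*
     * Sketch (not in the original commit): instead of silently dropping
     * extreme moves, they could be kept and flagged for manual review.
     */
    private Dataset<Row> flagExtremeMoves(Dataset<Row> data) {
        // Adds a boolean marker column rather than filtering rows out
        return data.withColumn("is_extreme_move",
                col("change_percent").lt(-20.0).or(col("change_percent").gt(20.0)));
    }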

    /**
     * Compute derived fields.
     */
    private Dataset<Row> calculateDerivedFields(Dataset<Row> data) {
        logger.info("Computing derived fields...");

        // Price change
        data = data.withColumn("price_change",
                col("close_price").minus(col("open_price")));

        // Price change percentage
        data = data.withColumn("price_change_pct",
                when(col("open_price").notEqual(0),
                        col("close_price").minus(col("open_price"))
                                .divide(col("open_price")).multiply(100))
                        .otherwise(0));

        // Amplitude
        data = data.withColumn("amplitude",
                when(col("open_price").notEqual(0),
                        col("high_price").minus(col("low_price"))
                                .divide(col("open_price")).multiply(100))
                        .otherwise(0));

        // Turnover rate (if float-share data is available)
        if (hasColumn(data, "float_shares")) {
            data = data.withColumn("turnover_rate",
                    when(col("float_shares").notEqual(0),
                            col("volume").divide(col("float_shares")).multiply(100))
                            .otherwise(0));
        }

        return data;
    }

    /**
     * Check whether the DataFrame contains the given column.
     */
    private boolean hasColumn(Dataset<Row> data, String columnName) {
        return Arrays.asList(data.columns()).contains(columnName);
    }
}
@@ -0,0 +1,189 @@
package com.agricultural.spark.service;

import com.agricultural.spark.config.SparkConfig;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SaveMode;
import org.apache.spark.sql.SparkSession;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.time.LocalDate;
import java.util.Map;
import java.util.Properties;

/**
 * Database persistence service.
 * Saves Spark processing results to the MySQL database.
 *
 * @author Agricultural Stock Platform Team
 */
public class DatabaseSaveService {

    private static final Logger logger = LoggerFactory.getLogger(DatabaseSaveService.class);

    private final SparkSession spark;
    private final SparkConfig config;
    private final String jdbcUrl;
    private final Properties connectionProps;

    public DatabaseSaveService(SparkSession spark, SparkConfig config) {
        this.spark = spark;
        this.config = config;
        this.jdbcUrl = String.format("jdbc:mysql://%s:%d/%s?useSSL=false&serverTimezone=Asia/Shanghai",
                config.getMysqlHost(), config.getMysqlPort(), config.getMysqlDatabase());

        this.connectionProps = new Properties();
        this.connectionProps.put("user", config.getMysqlUser());
        this.connectionProps.put("password", config.getMysqlPassword());
        this.connectionProps.put("driver", "com.mysql.cj.jdbc.Driver");
    }

    /**
     * Save the market-overview analysis to the market_analysis table.
     */
    public void saveMarketAnalysis(Map<String, Object> marketOverview) {
        logger.info("Saving market analysis results to the database...");

        String sql = "INSERT INTO market_analysis (analysis_date, up_count, down_count, flat_count, " +
                "total_count, total_market_cap, total_volume, total_turnover, avg_change_percent) " +
                "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) " +
                "ON DUPLICATE KEY UPDATE " +
                "up_count = VALUES(up_count), down_count = VALUES(down_count), " +
                "flat_count = VALUES(flat_count), total_count = VALUES(total_count), " +
                "total_market_cap = VALUES(total_market_cap), total_volume = VALUES(total_volume), " +
                "total_turnover = VALUES(total_turnover), avg_change_percent = VALUES(avg_change_percent)";

        try (Connection conn = DriverManager.getConnection(jdbcUrl, connectionProps);
             PreparedStatement stmt = conn.prepareStatement(sql)) {

            // Parse the trade date (keep only the yyyy-MM-dd part)
            String tradeDateStr = (String) marketOverview.get("trade_date");
            LocalDate tradeDate = LocalDate.parse(tradeDateStr.substring(0, 10));

            stmt.setDate(1, java.sql.Date.valueOf(tradeDate));
            stmt.setLong(2, ((Number) marketOverview.get("up_count")).longValue());
            stmt.setLong(3, ((Number) marketOverview.get("down_count")).longValue());
            stmt.setLong(4, ((Number) marketOverview.get("flat_count")).longValue());
            stmt.setLong(5, ((Number) marketOverview.get("total_count")).longValue());
            stmt.setDouble(6, ((Number) marketOverview.get("total_market_cap")).doubleValue());
            stmt.setLong(7, ((Number) marketOverview.get("total_volume")).longValue());
            stmt.setDouble(8, ((Number) marketOverview.get("total_turnover")).doubleValue());
            stmt.setDouble(9, ((Number) marketOverview.get("avg_change_percent")).doubleValue());

            int rowsAffected = stmt.executeUpdate();
            logger.info("Market analysis saved, {} row(s) affected", rowsAffected);

        } catch (SQLException e) {
            logger.error("Failed to save market analysis results", e);
            throw new RuntimeException("Failed to save market analysis results to the database", e);
        }
    }
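
    /*
     * Sketch (not in the original commit): the casts above throw a
     * NullPointerException if a key is absent from the map; a null-safe
     * accessor with a default would harden the method.
     */
    private static long asLong(Map<String, Object> map, String key, long defaultValue) {
        Object value = map.get(key);
        return value instanceof Number ? ((Number) value).longValue() : defaultValue;
    }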

    /**
     * Save computed technical-indicator data to a dedicated table (optional).
     */
    public void saveTechnicalIndicators(Dataset<Row> dataWithIndicators) {
        logger.info("Saving technical indicator data to the database...");

        try {
            // Select the fields to persist
            Dataset<Row> selectedData = dataWithIndicators.select(
                    "stock_code", "stock_name", "trade_date", "close_price",
                    "ma5", "ma10", "ma20", "ma30",
                    "rsi", "macd_dif", "macd_dea",
                    "bb_upper", "bb_middle", "bb_lower"
            );

            // Save to the stock_technical_indicators table (which must exist).
            // Note: with the JDBC writer, SaveMode.Overwrite drops and recreates
            // the table by default; add .option("truncate", "true") to keep it.
            selectedData.write()
                    .mode(SaveMode.Overwrite)
                    .format("jdbc")
                    .option("url", jdbcUrl)
                    .option("dbtable", "stock_technical_indicators")
                    .option("user", config.getMysqlUser())
                    .option("password", config.getMysqlPassword())
                    .option("driver", "com.mysql.cj.jdbc.Driver")
                    .save();

            logger.info("Technical indicator data saved successfully");

        } catch (Exception e) {
            logger.error("Failed to save technical indicator data", e);
            // Do not rethrow; let the pipeline continue
        }
    }

    /**
     * Save industry analysis results.
     */
    public void saveIndustryAnalysis(Dataset<Row> industryAnalysis) {
        logger.info("Saving industry analysis results to the database...");

        try {
            // Add the analysis-date column
            Dataset<Row> industryWithDate = industryAnalysis.withColumn("analysis_date",
                    org.apache.spark.sql.functions.current_date());

            // Save to the industry_analysis table (which must exist)
            industryWithDate.write()
                    .mode(SaveMode.Append)
                    .format("jdbc")
                    .option("url", jdbcUrl)
                    .option("dbtable", "industry_analysis")
                    .option("user", config.getMysqlUser())
                    .option("password", config.getMysqlPassword())
                    .option("driver", "com.mysql.cj.jdbc.Driver")
                    .save();

            logger.info("Industry analysis results saved successfully");

        } catch (Exception e) {
            logger.error("Failed to save industry analysis results", e);
            // Do not rethrow; let the pipeline continue
        }
    }

    /**
     * Save historical trend data.
     */
    public void saveHistoricalTrends(Dataset<Row> trendAnalysis) {
        logger.info("Saving historical trend data to the database...");

        try {
            // Save to the market_trends table (which must exist)
            trendAnalysis.write()
                    .mode(SaveMode.Overwrite)
                    .format("jdbc")
                    .option("url", jdbcUrl)
                    .option("dbtable", "market_trends")
                    .option("user", config.getMysqlUser())
                    .option("password", config.getMysqlPassword())
                    .option("driver", "com.mysql.cj.jdbc.Driver")
                    .save();

            logger.info("Historical trend data saved successfully");

        } catch (Exception e) {
            logger.error("Failed to save historical trend data", e);
            // Do not rethrow; let the pipeline continue
        }
    }

    /**
     * Test the database connection.
     */
    public boolean testConnection() {
        try (Connection conn = DriverManager.getConnection(jdbcUrl, connectionProps)) {
            logger.info("Database connection test succeeded");
            return true;
        } catch (SQLException e) {
            logger.error("Database connection test failed", e);
            return false;
        }
    }
}
@@ -0,0 +1,205 @@
package com.agricultural.spark.service;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import static org.apache.spark.sql.functions.*;

/**
 * Market analysis service.
 *
 * @author Agricultural Stock Platform Team
 */
public class MarketAnalysisService {

    private static final Logger logger = LoggerFactory.getLogger(MarketAnalysisService.class);

    private final SparkSession spark;

    public MarketAnalysisService(SparkSession spark) {
        this.spark = spark;
    }

    /**
     * Market overview analysis.
     */
    public Map<String, Object> analyzeMarketOverview(Dataset<Row> data) {
        logger.info("Running market overview analysis...");

        // Get the data for the latest trading day
        Row latestDateRow = data.agg(max("trade_date")).head();
        if (latestDateRow.isNullAt(0)) {
            logger.warn("Could not determine the latest trading date");
            return new HashMap<>();
        }

        Object latestDate = latestDateRow.get(0);
        Dataset<Row> latestData = data.filter(col("trade_date").equalTo(latestDate));

        // Count advancing, declining and flat stocks
        long upCount = latestData.filter(col("change_percent").gt(0)).count();
        long downCount = latestData.filter(col("change_percent").lt(0)).count();
        long flatCount = latestData.filter(col("change_percent").equalTo(0)).count();
        long totalCount = latestData.count();

        // Total market cap and volume
        Row aggregateRow = latestData.agg(
                sum("market_cap").alias("total_market_cap"),
                sum("volume").alias("total_volume"),
                sum("turnover").alias("total_turnover"),
                avg("change_percent").alias("avg_change_percent")
        ).head();

        double totalMarketCap = aggregateRow.isNullAt(0) ? 0.0 : aggregateRow.getDouble(0);
        long totalVolume = aggregateRow.isNullAt(1) ? 0L : Math.round(aggregateRow.getDouble(1));
        double totalTurnover = aggregateRow.isNullAt(2) ? 0.0 : aggregateRow.getDouble(2);
        double avgChangePercent = aggregateRow.isNullAt(3) ? 0.0 : aggregateRow.getDouble(3);

        Map<String, Object> result = new HashMap<>();
        result.put("trade_date", latestDate.toString());
        result.put("up_count", upCount);
        result.put("down_count", downCount);
        result.put("flat_count", flatCount);
        result.put("total_count", totalCount);
        result.put("total_market_cap", totalMarketCap);
        result.put("total_volume", totalVolume);
        result.put("total_turnover", totalTurnover);
        result.put("avg_change_percent", Math.round(avgChangePercent * 100.0) / 100.0);

        logger.info("Market overview analysis finished: {}", result);
        return result;
    }
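
    /*
     * Sketch (not in the original commit): the three count() calls above each
     * trigger a separate Spark job; conditional aggregation computes all the
     * counters in a single pass over the data.
     */
    private Row countUpDownFlat(Dataset<Row> latestData) {
        return latestData.agg(
                sum(when(col("change_percent").gt(0), 1).otherwise(0)).alias("up_count"),
                sum(when(col("change_percent").lt(0), 1).otherwise(0)).alias("down_count"),
                sum(when(col("change_percent").equalTo(0), 1).otherwise(0)).alias("flat_count"),
                count("*").alias("total_count")
        ).head();
    }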

    /**
     * Analyze the top gainers/losers rankings.
     */
    public Map<String, Object> analyzeTopGainersLosers(Dataset<Row> data, int limit) {
        logger.info("Analyzing gainers/losers rankings (top {})...", limit);

        // Get the data for the latest trading day
        Row latestDateRow = data.agg(max("trade_date")).head();
        Object latestDate = latestDateRow.get(0);
        Dataset<Row> latestData = data.filter(col("trade_date").equalTo(latestDate));

        // Top gainers
        List<Row> topGainers = latestData
                .orderBy(col("change_percent").desc())
                .select("stock_code", "stock_name", "close_price", "change_percent",
                        "volume", "market_cap")
                .limit(limit)
                .collectAsList();

        // Top losers
        List<Row> topLosers = latestData
                .orderBy(col("change_percent").asc())
                .select("stock_code", "stock_name", "close_price", "change_percent",
                        "volume", "market_cap")
                .limit(limit)
                .collectAsList();

        // Top volume
        List<Row> topVolume = latestData
                .orderBy(col("volume").desc())
                .select("stock_code", "stock_name", "close_price", "change_percent",
                        "volume", "turnover")
                .limit(limit)
                .collectAsList();

        Map<String, Object> result = new HashMap<>();
        result.put("top_gainers", topGainers);
        result.put("top_losers", topLosers);
        result.put("top_volume", topVolume);

        logger.info("Gainers/losers ranking analysis finished");
        return result;
    }

    /**
     * Industry performance analysis.
     */
    public Dataset<Row> analyzeIndustryPerformance(Dataset<Row> data) {
        logger.info("Running industry performance analysis...");

        // Register a UDF for industry classification
        spark.udf().register("getIndustry", (String stockCode) -> {
            if (stockCode == null || stockCode.length() < 6) {
                return "Other agriculture";
            }

            String code = stockCode.substring(stockCode.length() - 6);
            switch (code) {
                case "000876":
                case "002714":
                    return "Livestock";
                case "600519":
                case "000858":
                case "600887":
                case "002304":
                    return "Food & beverage";
                default:
                    return "Other agriculture";
            }
        }, org.apache.spark.sql.types.DataTypes.StringType);

        // Add the industry column
        Dataset<Row> dataWithIndustry = data.withColumn("industry",
                callUDF("getIndustry", col("stock_code")));

        // Get the data for the latest trading day
        Row latestDateRow = dataWithIndustry.agg(max("trade_date")).head();
        Object latestDate = latestDateRow.get(0);
        Dataset<Row> latestData = dataWithIndustry.filter(col("trade_date").equalTo(latestDate));

        // Aggregate by industry
        Dataset<Row> industryStats = latestData.groupBy("industry")
                .agg(
                        count("*").alias("stock_count"),
                        avg("change_percent").alias("avg_change_percent"),
                        sum("market_cap").alias("total_market_cap"),
                        sum("volume").alias("total_volume")
                )
                .orderBy(col("avg_change_percent").desc());

        logger.info("Industry performance analysis finished, {} industries in total", industryStats.count());
        return industryStats;
    }

    /**
     * Historical trend analysis.
     */
    public Dataset<Row> analyzeHistoricalTrends(Dataset<Row> data, String stockCode, int days) {
        logger.info("Analyzing historical trends ({} days)...", days);

        // Determine the end of the date window
        Row latestDateRow = data.agg(max("trade_date")).head();
        Object endDate = latestDateRow.get(0);

        // Simplified for now: the `days` window is not applied yet; a real
        // implementation should filter by date range (see the sketch below)
        Dataset<Row> trendData = data;

        if (stockCode != null && !stockCode.isEmpty()) {
            trendData = trendData.filter(col("stock_code").equalTo(stockCode));
        }

        // Aggregate by date
        Dataset<Row> dailyStats = trendData.groupBy("trade_date")
                .agg(
                        avg("close_price").alias("avg_price"),
                        avg("change_percent").alias("avg_change_percent"),
                        sum("volume").alias("total_volume"),
                        sum("turnover").alias("total_turnover"),
                        count("*").alias("stock_count")
                )
                .orderBy("trade_date");

        logger.info("Historical trend analysis finished");
        return dailyStats;
    }
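
    /*
     * Sketch (not in the original commit): applying the `days` window with
     * Spark date functions, assuming trade_date is a timestamp column whose
     * string form starts with yyyy-MM-dd.
     */
    private Dataset<Row> limitToRecentDays(Dataset<Row> data, int days) {
        // Keep only rows within `days` days of the most recent trade_date
        Object endDate = data.agg(max("trade_date")).head().get(0);
        return data.filter(
                to_date(col("trade_date")).geq(
                        date_sub(to_date(lit(String.valueOf(endDate))), days)));
    }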
}
@@ -0,0 +1,188 @@
package com.agricultural.spark.service;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.expressions.Window;
import org.apache.spark.sql.expressions.WindowSpec;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import static org.apache.spark.sql.functions.*;

/**
 * Technical indicator service.
 *
 * @author Agricultural Stock Platform Team
 */
public class TechnicalIndicatorService {

    private static final Logger logger = LoggerFactory.getLogger(TechnicalIndicatorService.class);

    private final SparkSession spark;

    public TechnicalIndicatorService(SparkSession spark) {
        this.spark = spark;
    }

    /**
     * Compute technical indicators.
     */
    public Dataset<Row> calculateTechnicalIndicators(Dataset<Row> data) {
        logger.info("Starting technical indicator calculation...");

        // Windows partitioned by stock code and ordered by trading date
        WindowSpec windowSpec = Window.partitionBy("stock_code")
                .orderBy("trade_date")
                .rowsBetween(-29, 0); // 30-day window

        WindowSpec windowSpec5 = Window.partitionBy("stock_code")
                .orderBy("trade_date")
                .rowsBetween(-4, 0); // 5-day window

        WindowSpec windowSpec10 = Window.partitionBy("stock_code")
                .orderBy("trade_date")
                .rowsBetween(-9, 0); // 10-day window

        WindowSpec windowSpec20 = Window.partitionBy("stock_code")
                .orderBy("trade_date")
                .rowsBetween(-19, 0); // 20-day window

        Dataset<Row> result = data;

        // 1. Moving averages (MA)
        result = result.withColumn("ma5", avg("close_price").over(windowSpec5))
                .withColumn("ma10", avg("close_price").over(windowSpec10))
                .withColumn("ma20", avg("close_price").over(windowSpec20))
                .withColumn("ma30", avg("close_price").over(windowSpec));

        // 2. Volume moving averages
        result = result.withColumn("volume_ma5", avg("volume").over(windowSpec5))
                .withColumn("volume_ma10", avg("volume").over(windowSpec10));

        // 3. Price position relative to the moving averages
        result = result.withColumn("price_vs_ma5",
                when(col("ma5").notEqual(0),
                        col("close_price").divide(col("ma5")).multiply(100).minus(100))
                        .otherwise(0))
                .withColumn("price_vs_ma20",
                        when(col("ma20").notEqual(0),
                                col("close_price").divide(col("ma20")).multiply(100).minus(100))
                                .otherwise(0));

        // 4. Volatility (30-day)
        result = result.withColumn("volatility_30d",
                stddev("change_percent").over(windowSpec));

        // 5. Relative Strength Index, RSI (simplified)
        result = calculateRSI(result, 14);

        // 6. MACD
        result = calculateMACD(result);

        // 7. Bollinger Bands
        result = calculateBollingerBands(result, 20);

        logger.info("Technical indicator calculation finished");
        return result;
    }
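
    /*
     * Sketch (not in the original commit): rowsBetween windows shrink at the
     * start of each series, so e.g. the first rows of "ma20" average fewer
     * than 20 days. Nulling out incomplete windows makes that explicit.
     */
    private Dataset<Row> nullIncompleteMa20(Dataset<Row> data) {
        WindowSpec w20 = Window.partitionBy("stock_code")
                .orderBy("trade_date")
                .rowsBetween(-19, 0);
        // Keep ma20 only when the window really covers 20 rows (null otherwise)
        return data.withColumn("ma20",
                when(count("close_price").over(w20).equalTo(20),
                        avg("close_price").over(w20)));
    }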

    /**
     * Compute the RSI (Relative Strength Index).
     * Simplified: uses simple averages of gains/losses rather than
     * Wilder's smoothing.
     */
    private Dataset<Row> calculateRSI(Dataset<Row> data, int period) {
        WindowSpec windowSpec = Window.partitionBy("stock_code")
                .orderBy("trade_date")
                .rowsBetween(-(period - 1), 0);

        // Day-over-day price change
        data = data.withColumn("price_diff",
                col("close_price").minus(lag("close_price", 1)
                        .over(Window.partitionBy("stock_code").orderBy("trade_date"))));

        // Split into gains and losses
        data = data.withColumn("gain", when(col("price_diff").gt(0), col("price_diff")).otherwise(0))
                .withColumn("loss", when(col("price_diff").lt(0), abs(col("price_diff"))).otherwise(0));

        // Average gain/loss over the period
        data = data.withColumn("avg_gain", avg("gain").over(windowSpec))
                .withColumn("avg_loss", avg("loss").over(windowSpec));

        // RSI = 100 - 100 / (1 + avg_gain / avg_loss)
        data = data.withColumn("rsi",
                when(col("avg_loss").notEqual(0),
                        lit(100).minus(lit(100).divide(lit(1).plus(col("avg_gain").divide(col("avg_loss"))))))
                        .otherwise(50));

        return data.drop("price_diff", "gain", "loss", "avg_gain", "avg_loss");
    }

    /**
     * Compute the MACD indicator.
     */
    private Dataset<Row> calculateMACD(Dataset<Row> data) {
        // EMA12 and EMA26 (see calculateEMA for the approximation used)
        data = data.withColumn("ema12", calculateEMA(col("close_price"), 12))
                .withColumn("ema26", calculateEMA(col("close_price"), 26));

        // MACD line (DIF)
        data = data.withColumn("macd_dif", col("ema12").minus(col("ema26")));

        // DEA (9-day EMA of the DIF)
        data = data.withColumn("macd_dea", calculateEMA(col("macd_dif"), 9));

        // MACD histogram
        data = data.withColumn("macd_histogram",
                col("macd_dif").minus(col("macd_dea")).multiply(2));

        return data.drop("ema12", "ema26");
    }

    /**
     * Compute Bollinger Bands.
     */
    private Dataset<Row> calculateBollingerBands(Dataset<Row> data, int period) {
        WindowSpec windowSpec = Window.partitionBy("stock_code")
                .orderBy("trade_date")
                .rowsBetween(-(period - 1), 0);

        // Middle band (moving average)
        data = data.withColumn("bb_middle", avg("close_price").over(windowSpec));

        // Standard deviation
        data = data.withColumn("bb_std", stddev("close_price").over(windowSpec));

        // Upper and lower bands (middle ± 2 standard deviations)
        data = data.withColumn("bb_upper",
                col("bb_middle").plus(col("bb_std").multiply(2)))
                .withColumn("bb_lower",
                        col("bb_middle").minus(col("bb_std").multiply(2)));

        // Band width
        data = data.withColumn("bb_width",
                when(col("bb_middle").notEqual(0),
                        col("bb_upper").minus(col("bb_lower")).divide(col("bb_middle")).multiply(100))
                        .otherwise(0));

        // Price position within the bands
        data = data.withColumn("bb_position",
                when(col("bb_upper").notEqual(col("bb_lower")),
                        col("close_price").minus(col("bb_lower"))
                                .divide(col("bb_upper").minus(col("bb_lower"))))
                        .otherwise(0.5));

        return data.drop("bb_std");
    }

    /**
     * Exponential moving average (EMA) - simplified.
     * Approximated here by a simple moving average over the period; a true
     * EMA is recursive (see the sketch below).
     */
    private org.apache.spark.sql.Column calculateEMA(org.apache.spark.sql.Column priceCol, int period) {
        WindowSpec windowSpec = Window.partitionBy("stock_code")
                .orderBy("trade_date")
                .rowsBetween(-(period - 1), 0);

        return avg(priceCol).over(windowSpec);
    }
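
    /*
     * Sketch (not in the original commit): the true EMA recurrence for one
     * stock's chronologically ordered closing prices:
     *   EMA_t = alpha * price_t + (1 - alpha) * EMA_{t-1},  alpha = 2 / (period + 1).
     * Shown in plain Java; a Spark version would need a per-stock ordered
     * aggregation rather than a fixed row window.
     */
    private static double[] exponentialMovingAverage(double[] prices, int period) {
        double alpha = 2.0 / (period + 1);
        double[] ema = new double[prices.length];
        for (int i = 0; i < prices.length; i++) {
            // Seed with the first price, then apply the recurrence
            ema[i] = (i == 0) ? prices[0] : alpha * prices[i] + (1 - alpha) * ema[i - 1];
        }
        return ema;
    }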
}
spark-processor/src/main/resources/logback.xml — 52 lines (new file)
@@ -0,0 +1,52 @@
<?xml version="1.0" encoding="UTF-8"?>
<configuration>

    <!-- Console appender -->
    <appender name="CONSOLE" class="ch.qos.logback.core.ConsoleAppender">
        <encoder>
            <pattern>%d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n</pattern>
        </encoder>
    </appender>

    <!-- File appender (optional). Note: maxFileSize and the %i index require
         SizeAndTimeBasedRollingPolicy; plain TimeBasedRollingPolicy does not
         support them. -->
    <appender name="FILE" class="ch.qos.logback.core.rolling.RollingFileAppender">
        <file>logs/spark-processor.log</file>
        <rollingPolicy class="ch.qos.logback.core.rolling.SizeAndTimeBasedRollingPolicy">
            <fileNamePattern>logs/spark-processor.%d{yyyy-MM-dd}.%i.log</fileNamePattern>
            <maxFileSize>10MB</maxFileSize>
            <maxHistory>30</maxHistory>
        </rollingPolicy>
        <encoder>
            <pattern>%d{yyyy-MM-dd HH:mm:ss.SSS} [%thread] %-5level %logger{50} - %msg%n</pattern>
        </encoder>
    </appender>

    <!-- Per-package log levels -->

    <!-- Our own application logs - keep at INFO -->
    <logger name="com.agricultural.spark" level="INFO" additivity="false">
        <appender-ref ref="CONSOLE"/>
        <appender-ref ref="FILE"/>
    </logger>

    <!-- Spark framework logs - WARN to reduce noise -->
    <logger name="org.apache.spark" level="WARN"/>
    <logger name="org.apache.hadoop" level="WARN"/>
    <logger name="org.apache.hive" level="WARN"/>
    <logger name="org.apache.parquet" level="WARN"/>

    <!-- Database driver logs - WARN -->
    <logger name="com.mysql" level="WARN"/>
    <logger name="mysql" level="WARN"/>

    <!-- Maven and other framework logs - ERROR only -->
    <logger name="org.apache.maven" level="ERROR"/>
    <logger name="org.eclipse.jetty" level="ERROR"/>
    <logger name="io.netty" level="ERROR"/>

    <!-- Root log level -->
    <root level="WARN">
        <appender-ref ref="CONSOLE"/>
    </root>

</configuration>