first commit
This commit is contained in:
135
src/config/app.rs
Normal file
135
src/config/app.rs
Normal file
@@ -0,0 +1,135 @@
|
||||
use super::{auth::AuthConfig, database::DatabaseConfig, redis::RedisConfig, server::ServerConfig};
|
||||
use config::{Config, ConfigError, Environment, File};
|
||||
use serde::Deserialize;
|
||||
use std::path::PathBuf;
|
||||
|
||||
// 导入 redis 默认值函数(使用完整路径)
|
||||
use crate::config::redis::default_redis_host;
|
||||
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
pub struct AppConfig {
|
||||
pub server: ServerConfig,
|
||||
pub database: DatabaseConfig,
|
||||
pub auth: AuthConfig,
|
||||
pub redis: RedisConfig,
|
||||
}
|
||||
|
||||
impl AppConfig {
|
||||
/// 加载配置(支持 CLI 覆盖)
|
||||
///
|
||||
/// 如果 config_path 为 None,则仅使用环境变量和默认值
|
||||
pub fn load_with_overrides(
|
||||
cli_config_path: Option<std::path::PathBuf>,
|
||||
overrides: crate::cli::ConfigOverrides,
|
||||
_environment: &str,
|
||||
) -> Result<Self, ConfigError> {
|
||||
// 使用 ConfigBuilder 设置配置
|
||||
let mut builder = Config::builder();
|
||||
|
||||
// 如果提供了配置文件,先加载它
|
||||
if let Some(config_path) = cli_config_path {
|
||||
if !config_path.exists() {
|
||||
tracing::error!("Configuration file not found: {}", config_path.display());
|
||||
return Err(ConfigError::NotFound(
|
||||
config_path.to_string_lossy().to_string(),
|
||||
));
|
||||
}
|
||||
tracing::info!("Loading configuration from: {}", config_path.display());
|
||||
builder = builder.add_source(File::from(config_path));
|
||||
} else {
|
||||
tracing::info!("No configuration file found, using environment variables and defaults");
|
||||
tracing::warn!("⚠️ 没有找到配置文件,将使用 SQLite 作为默认数据库");
|
||||
tracing::warn!(" 默认数据库路径: db.sqlite3");
|
||||
tracing::warn!(" 如需使用其他数据库,请创建配置文件或设置环境变量");
|
||||
|
||||
// 直接使用 set_default 设置默认值
|
||||
// 注意:这些值会被环境变量覆盖
|
||||
builder = builder.set_default("server.host", default_server_host())?;
|
||||
builder = builder.set_default("server.port", default_server_port())?;
|
||||
|
||||
// 设置 database 默认值(使用 SQLite 作为默认数据库)
|
||||
builder = builder.set_default("database.database_type", "sqlite")?;
|
||||
builder = builder.set_default("database.path", "db.sqlite3")?;
|
||||
builder = builder.set_default("database.max_connections", 10)?;
|
||||
|
||||
// 设置 auth 默认值
|
||||
builder = builder.set_default("auth.jwt_secret", default_jwt_secret())?;
|
||||
builder = builder.set_default("auth.access_token_expiration_minutes", 15)?;
|
||||
builder = builder.set_default("auth.refresh_token_expiration_days", 7)?;
|
||||
|
||||
// 设置 redis 默认值
|
||||
builder = builder.set_default("redis.host", default_redis_host())?;
|
||||
builder = builder.set_default("redis.port", 6379)?;
|
||||
builder = builder.set_default("redis.db", 0)?;
|
||||
}
|
||||
|
||||
// 添加环境变量源(会覆盖配置文件的值)
|
||||
builder = builder.add_source(Environment::default().separator("_"));
|
||||
|
||||
// 应用 CLI 覆盖(仅 Web 服务器参数)
|
||||
if let Some(host) = overrides.host {
|
||||
builder = builder.set_override("server.host", host)?;
|
||||
}
|
||||
if let Some(port) = overrides.port {
|
||||
builder = builder.set_override("server.port", port)?;
|
||||
}
|
||||
|
||||
let settings = builder.build()?;
|
||||
let config: AppConfig = settings.try_deserialize()?;
|
||||
|
||||
// 安全警告:检查是否使用了默认的 JWT 密钥
|
||||
if config.auth.jwt_secret == "change-this-to-a-strong-secret-key-in-production" {
|
||||
tracing::warn!("⚠️ 警告:正在使用不安全的默认 JWT 密钥!");
|
||||
tracing::warn!(" 请通过环境变量 AUTH_JWT_SECRET 或配置文件设置强密钥");
|
||||
tracing::warn!(" 示例:AUTH_JWT_SECRET=your-secure-random-string-here");
|
||||
}
|
||||
|
||||
// 验证数据库配置
|
||||
if let Err(e) = config.database.validate() {
|
||||
tracing::error!("数据库配置无效: {}", e);
|
||||
return Err(ConfigError::Message(format!("数据库配置无效: {}", e)));
|
||||
}
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
/// 从指定路径加载配置
|
||||
pub fn load_from_path(path: &str) -> Result<Self, ConfigError> {
|
||||
tracing::info!("Loading configuration from: {}", path);
|
||||
|
||||
let settings = Config::builder()
|
||||
.add_source(File::from(PathBuf::from(path)))
|
||||
.add_source(Environment::default().separator("_"))
|
||||
.build()?;
|
||||
|
||||
let config: AppConfig = settings.try_deserialize()?;
|
||||
|
||||
// 安全警告:检查是否使用了默认的 JWT 密钥
|
||||
if config.auth.jwt_secret == "change-this-to-a-strong-secret-key-in-production" {
|
||||
tracing::warn!("⚠️ 警告:正在使用不安全的默认 JWT 密钥!");
|
||||
tracing::warn!(" 请通过环境变量 AUTH_JWT_SECRET 或配置文件设置强密钥");
|
||||
tracing::warn!(" 示例:AUTH_JWT_SECRET=your-secure-random-string-here");
|
||||
}
|
||||
|
||||
// 验证数据库配置
|
||||
if let Err(e) = config.database.validate() {
|
||||
tracing::error!("数据库配置无效: {}", e);
|
||||
return Err(ConfigError::Message(format!("数据库配置无效: {}", e)));
|
||||
}
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
}
|
||||
|
||||
// 默认值函数(复用)
|
||||
fn default_server_host() -> String {
|
||||
"127.0.0.1".to_string()
|
||||
}
|
||||
|
||||
fn default_server_port() -> u16 {
|
||||
3000
|
||||
}
|
||||
|
||||
fn default_jwt_secret() -> String {
|
||||
"change-this-to-a-strong-secret-key-in-production".to_string()
|
||||
}
|
||||
25
src/config/auth.rs
Normal file
25
src/config/auth.rs
Normal file
@@ -0,0 +1,25 @@
|
||||
use serde::Deserialize;
|
||||
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
pub struct AuthConfig {
|
||||
#[serde(default = "default_jwt_secret")]
|
||||
pub jwt_secret: String,
|
||||
#[serde(default = "default_access_token_expiration_minutes")]
|
||||
pub access_token_expiration_minutes: u64,
|
||||
#[serde(default = "default_refresh_token_expiration_days")]
|
||||
pub refresh_token_expiration_days: i64,
|
||||
}
|
||||
|
||||
fn default_jwt_secret() -> String {
|
||||
// ⚠️ 警告:这是一个不安全的默认值,仅用于开发测试
|
||||
// 生产环境必须通过环境变量或配置文件设置强密钥
|
||||
"change-this-to-a-strong-secret-key-in-production".to_string()
|
||||
}
|
||||
|
||||
fn default_access_token_expiration_minutes() -> u64 {
|
||||
15
|
||||
}
|
||||
|
||||
fn default_refresh_token_expiration_days() -> i64 {
|
||||
7
|
||||
}
|
||||
168
src/config/database.rs
Normal file
168
src/config/database.rs
Normal file
@@ -0,0 +1,168 @@
|
||||
use serde::Deserialize;
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// 数据库类型
|
||||
#[derive(Debug, Deserialize, Clone, PartialEq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum DatabaseType {
|
||||
MySQL,
|
||||
SQLite,
|
||||
PostgreSQL,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
pub struct DatabaseConfig {
|
||||
/// 数据库类型
|
||||
#[serde(default = "default_database_type")]
|
||||
pub database_type: DatabaseType,
|
||||
|
||||
/// 网络数据库配置(MySQL/PostgreSQL)
|
||||
pub host: Option<String>,
|
||||
#[serde(default)]
|
||||
pub port: Option<u16>,
|
||||
pub user: Option<String>,
|
||||
pub password: Option<String>,
|
||||
pub database: Option<String>,
|
||||
|
||||
/// SQLite 文件路径
|
||||
pub path: Option<PathBuf>,
|
||||
|
||||
/// 连接池最大连接数
|
||||
#[serde(default = "default_max_connections")]
|
||||
pub max_connections: u32,
|
||||
}
|
||||
|
||||
impl DatabaseConfig {
|
||||
/// 获取端口号(根据数据库类型返回默认值)
|
||||
pub fn get_port(&self) -> u16 {
|
||||
self.port.unwrap_or_else(|| match self.database_type {
|
||||
DatabaseType::MySQL => 3306,
|
||||
DatabaseType::PostgreSQL => 5432,
|
||||
DatabaseType::SQLite => 0,
|
||||
})
|
||||
}
|
||||
|
||||
/// 构建数据库连接 URL
|
||||
///
|
||||
/// # 错误
|
||||
///
|
||||
/// 当缺少必需的配置字段时返回错误
|
||||
pub fn build_url(&self) -> Result<String, String> {
|
||||
match self.database_type {
|
||||
DatabaseType::MySQL => {
|
||||
let host = self
|
||||
.host
|
||||
.as_ref()
|
||||
.ok_or_else(|| "MySQL 需要配置 database.host".to_string())?;
|
||||
let user = self
|
||||
.user
|
||||
.as_ref()
|
||||
.ok_or_else(|| "MySQL 需要配置 database.user".to_string())?;
|
||||
let password = self
|
||||
.password
|
||||
.as_ref()
|
||||
.ok_or_else(|| "MySQL 需要配置 database.password".to_string())?;
|
||||
let database = self
|
||||
.database
|
||||
.as_ref()
|
||||
.ok_or_else(|| "MySQL 需要配置 database.database".to_string())?;
|
||||
|
||||
Ok(format!(
|
||||
"mysql://{}:{}@{}:{}/{}",
|
||||
user,
|
||||
password,
|
||||
host,
|
||||
self.get_port(),
|
||||
database
|
||||
))
|
||||
}
|
||||
DatabaseType::SQLite => {
|
||||
let path = self
|
||||
.path
|
||||
.as_ref()
|
||||
.ok_or_else(|| "SQLite 需要配置 database.path".to_string())?;
|
||||
|
||||
// SQLite URL 格式
|
||||
// 相对路径:sqlite:./db.sqlite3
|
||||
// 绝对路径:sqlite:C:/path/to/db.sqlite3
|
||||
let path_str = path.to_string_lossy().replace('\\', "/");
|
||||
Ok(format!("sqlite:{}", path_str))
|
||||
}
|
||||
DatabaseType::PostgreSQL => {
|
||||
let host = self
|
||||
.host
|
||||
.as_ref()
|
||||
.ok_or_else(|| "PostgreSQL 需要配置 database.host".to_string())?;
|
||||
let user = self
|
||||
.user
|
||||
.as_ref()
|
||||
.ok_or_else(|| "PostgreSQL 需要配置 database.user".to_string())?;
|
||||
let password = self
|
||||
.password
|
||||
.as_ref()
|
||||
.ok_or_else(|| "PostgreSQL 需要配置 database.password".to_string())?;
|
||||
let database = self
|
||||
.database
|
||||
.as_ref()
|
||||
.ok_or_else(|| "PostgreSQL 需要配置 database.database".to_string())?;
|
||||
|
||||
Ok(format!(
|
||||
"postgresql://{}:{}@{}:{}/{}",
|
||||
user,
|
||||
password,
|
||||
host,
|
||||
self.get_port(),
|
||||
database
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 验证配置是否完整
|
||||
pub fn validate(&self) -> Result<(), String> {
|
||||
match self.database_type {
|
||||
DatabaseType::MySQL => {
|
||||
if self.host.is_none() {
|
||||
return Err("MySQL 需要配置 database.host".to_string());
|
||||
}
|
||||
if self.user.is_none() {
|
||||
return Err("MySQL 需要配置 database.user".to_string());
|
||||
}
|
||||
if self.password.is_none() {
|
||||
return Err("MySQL 需要配置 database.password".to_string());
|
||||
}
|
||||
if self.database.is_none() {
|
||||
return Err("MySQL 需要配置 database.database".to_string());
|
||||
}
|
||||
}
|
||||
DatabaseType::SQLite => {
|
||||
if self.path.is_none() {
|
||||
return Err("SQLite 需要配置 database.path".to_string());
|
||||
}
|
||||
}
|
||||
DatabaseType::PostgreSQL => {
|
||||
if self.host.is_none() {
|
||||
return Err("PostgreSQL 需要配置 database.host".to_string());
|
||||
}
|
||||
if self.user.is_none() {
|
||||
return Err("PostgreSQL 需要配置 database.user".to_string());
|
||||
}
|
||||
if self.password.is_none() {
|
||||
return Err("PostgreSQL 需要配置 database.password".to_string());
|
||||
}
|
||||
if self.database.is_none() {
|
||||
return Err("PostgreSQL 需要配置 database.database".to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn default_database_type() -> DatabaseType {
|
||||
DatabaseType::MySQL
|
||||
}
|
||||
|
||||
fn default_max_connections() -> u32 {
|
||||
10
|
||||
}
|
||||
5
src/config/mod.rs
Normal file
5
src/config/mod.rs
Normal file
@@ -0,0 +1,5 @@
|
||||
pub mod app;
|
||||
pub mod auth;
|
||||
pub mod database;
|
||||
pub mod redis;
|
||||
pub mod server;
|
||||
52
src/config/redis.rs
Normal file
52
src/config/redis.rs
Normal file
@@ -0,0 +1,52 @@
|
||||
use serde::Deserialize;
|
||||
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
pub struct RedisConfig {
|
||||
/// Redis 主机地址
|
||||
#[serde(default = "default_redis_host")]
|
||||
pub host: String,
|
||||
|
||||
/// Redis 端口
|
||||
#[serde(default = "default_redis_port")]
|
||||
pub port: u16,
|
||||
|
||||
/// Redis 密码(可选)
|
||||
#[serde(default)]
|
||||
pub password: Option<String>,
|
||||
|
||||
/// Redis 数据库编号(可选)
|
||||
#[serde(default = "default_redis_db")]
|
||||
pub db: u8,
|
||||
}
|
||||
|
||||
pub fn default_redis_host() -> String {
|
||||
"localhost".to_string()
|
||||
}
|
||||
|
||||
pub fn default_redis_port() -> u16 {
|
||||
6379
|
||||
}
|
||||
|
||||
pub fn default_redis_db() -> u8 {
|
||||
0
|
||||
}
|
||||
|
||||
impl RedisConfig {
|
||||
/// 构建 Redis 连接 URL
|
||||
pub fn build_url(&self) -> String {
|
||||
// 判断密码是否存在且非空
|
||||
match &self.password {
|
||||
Some(password) if !password.is_empty() => {
|
||||
// 有密码:redis://:password@host:port/db
|
||||
format!(
|
||||
"redis://:{}@{}:{}/{}",
|
||||
password, self.host, self.port, self.db
|
||||
)
|
||||
}
|
||||
_ => {
|
||||
// 无密码(None 或空字符串):redis://host:port/db
|
||||
format!("redis://{}:{}/{}", self.host, self.port, self.db)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
17
src/config/server.rs
Normal file
17
src/config/server.rs
Normal file
@@ -0,0 +1,17 @@
|
||||
use serde::Deserialize;
|
||||
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
pub struct ServerConfig {
|
||||
#[serde(default = "default_server_host")]
|
||||
pub host: String,
|
||||
#[serde(default = "default_server_port")]
|
||||
pub port: u16,
|
||||
}
|
||||
|
||||
fn default_server_host() -> String {
|
||||
"127.0.0.1".to_string()
|
||||
}
|
||||
|
||||
fn default_server_port() -> u16 {
|
||||
3000
|
||||
}
|
||||
Reference in New Issue
Block a user