first commit

This commit is contained in:
2026-02-13 15:57:29 +08:00
commit aacda0b66a
53 changed files with 10029 additions and 0 deletions

234
src/cli.rs Normal file
View File

@@ -0,0 +1,234 @@
/// 命令行参数和配置管理
/// 支持优先级CLI 参数 > 环境变量 > 配置文件 > 默认值
use clap::{Parser, ValueEnum};
use std::path::PathBuf;
/// 运行环境(强类型)
#[derive(ValueEnum, Clone, Debug)]
pub enum Environment {
/// 开发环境
Development,
/// 生产环境
Production,
}
impl Environment {
/// 转换为小写字符串
pub fn as_str(&self) -> &'static str {
match self {
Environment::Development => "development",
Environment::Production => "production",
}
}
}
/// 命令行参数
#[derive(Parser, Debug)]
#[command(name = "web-rust-template")]
#[command(about = "Web Server Template", long_about = None)]
#[command(author = "Your Name <your.email@example.com>")]
#[command(version = "0.1.0")]
#[command(propagate_version = true)]
pub struct CliArgs {
/// 指定配置文件路径
///
/// 支持相对路径和绝对路径
/// 例如:-c config/production.toml
#[arg(short, long, value_name = "FILE")]
pub config: Option<PathBuf>,
/// 指定运行环境
///
/// 自动加载对应环境的配置文件(如 config/development.toml
/// 可通过环境变量 ENV 设置
#[arg(
short = 'e',
long,
value_enum,
env = "ENV",
default_value = "development"
)]
pub env: Environment,
/// 指定服务器监听端口
///
/// 覆盖配置文件中的 port 设置
/// 可通过环境变量 SERVER_PORT 设置
#[arg(short, long, global = true, env = "SERVER_PORT")]
pub port: Option<u16>,
/// 指定服务器监听地址
///
/// 覆盖配置文件中的 host 设置
/// 可通过环境变量 SERVER_HOST 设置
#[arg(long, global = true, env = "SERVER_HOST")]
pub host: Option<String>,
/// 启用调试日志
///
/// 输出详细的日志信息,包括 SQL 查询
/// 可通过环境变量 DEBUG 设置
/// 注意:与 -v 冲突,推荐使用 -v/-vv/-vvv
#[arg(long, global = true, env = "DEBUG", conflicts_with = "verbose")]
pub debug: bool,
/// 工作目录
///
/// 指定配置文件和数据库的基准目录
#[arg(short, long, global = true)]
pub work_dir: Option<PathBuf>,
/// 显示详细日志(多级 verbose
///
/// -v : info 级别日志
/// -vv : debug 级别日志(等同于 --debug
/// -vvv : trace 级别日志(最详细)
#[arg(short, long, global = true, action = clap::ArgAction::Count)]
pub verbose: u8,
}
impl CliArgs {
/// 获取是否启用调试
pub fn is_debug_enabled(&self) -> bool {
self.debug || self.verbose >= 2
}
/// 获取日志级别
pub fn get_log_level(&self) -> &'static str {
if self.debug {
return "debug";
}
match self.verbose {
0 => "info",
1 => "debug",
_ => "trace",
}
}
/// 获取环境变量的日志过滤器(工程化版本)
pub fn get_log_filter(&self) -> String {
let level = self.get_log_level();
match level {
"trace" => "web_rust_template=trace,tower_http=trace,axum=trace,sqlx=debug".into(),
"debug" => "web_rust_template=debug,tower_http=debug,axum=debug,sqlx=debug".into(),
_ => "web_rust_template=info,tower_http=info,axum=info".into(),
}
}
/// 获取配置文件路径
///
/// 优先级:
/// 1. CLI 参数 --config
/// 2. 环境变量 CONFIG
/// 3. {work_dir}/config/{env}.toml
/// 4. ./config/{env}.toml
/// 5. ./config/default.toml
///
/// 如果找不到配置文件,返回 None允许仅使用环境变量运行
pub fn resolve_config_path(&self) -> Option<PathBuf> {
use std::env;
// 1. CLI 参数优先
if let Some(ref config) = self.config {
if config.exists() {
return Some(config.clone());
}
eprintln!("⚠ 警告:指定的配置文件不存在: {}", config.display());
eprintln!(" 将仅使用环境变量运行");
return None;
}
// 2. 环境变量
if let Ok(config_path) = env::var("CONFIG") {
let config = PathBuf::from(&config_path);
if config.exists() {
return Some(config);
}
eprintln!("⚠ 警告:环境变量 CONFIG 指定的配置文件不存在: {}", config_path);
eprintln!(" 将仅使用环境变量运行");
return None;
}
// 3-6. 查找配置文件
let work_dir = self
.work_dir
.clone()
.or_else(|| env::current_dir().ok())
.unwrap_or_else(|| PathBuf::from("."));
let env_name = self.env.as_str();
// 按优先级尝试的位置
let candidates = [
// 工作目录下的环境配置
work_dir.join("config").join(format!("{}.toml", env_name)),
// 当前目录的环境配置
PathBuf::from(format!("config/{}.toml", env_name)),
// 工作目录下的默认配置
work_dir.join("config").join("default.toml"),
// 当前目录的默认配置
PathBuf::from("config/default.toml"),
];
for candidate in &candidates {
if candidate.exists() {
// 使用 println! 而非 tracing::info!
println!("✓ Found configuration file: {}", candidate.display());
return Some(candidate.clone());
}
}
// 所有候选路径都找不到配置文件,返回 None
eprintln!(" 未找到配置文件,将仅使用环境变量和默认值");
None
}
/// 获取覆盖配置
///
/// CLI 参数可以覆盖配置文件中的值(仅 Web 服务器参数)
pub fn get_overrides(&self) -> ConfigOverrides {
ConfigOverrides {
host: self.host.clone(),
port: self.port,
}
}
/// 显示启动信息(工程化版本:打印实际解析的配置)
///
/// 使用 println! 而非 tracing::info!,因为 logger 可能尚未初始化
pub fn print_startup_info(&self) {
let separator = "=".repeat(60);
println!("{}", separator);
println!("Web Rust Template Server v0.1.0");
println!("Environment: {}", self.env.as_str());
// 打印实际解析的配置路径(而非 CLI 参数)
if let Some(config_path) = self.resolve_config_path() {
println!("Config file: {}", config_path.display());
} else {
println!("Config file: None (using environment variables)");
}
if let Some(ref work_dir) = self.work_dir {
println!("Work directory: {}", work_dir.display());
}
// 打印实际的日志级别
println!("Log level: {}", self.get_log_level());
if self.is_debug_enabled() {
println!("Debug mode: ENABLED");
}
println!("{}", separator);
}
}
/// CLI 参数覆盖的配置(仅 Web 服务器参数)
#[derive(Debug, Clone)]
pub struct ConfigOverrides {
/// Web 服务器主机覆盖
pub host: Option<String>,
/// Web 服务器端口覆盖
pub port: Option<u16>,
}

135
src/config/app.rs Normal file
View File

@@ -0,0 +1,135 @@
use super::{auth::AuthConfig, database::DatabaseConfig, redis::RedisConfig, server::ServerConfig};
use config::{Config, ConfigError, Environment, File};
use serde::Deserialize;
use std::path::PathBuf;
// 导入 redis 默认值函数(使用完整路径)
use crate::config::redis::default_redis_host;
#[derive(Debug, Deserialize, Clone)]
pub struct AppConfig {
pub server: ServerConfig,
pub database: DatabaseConfig,
pub auth: AuthConfig,
pub redis: RedisConfig,
}
impl AppConfig {
/// 加载配置(支持 CLI 覆盖)
///
/// 如果 config_path 为 None则仅使用环境变量和默认值
pub fn load_with_overrides(
cli_config_path: Option<std::path::PathBuf>,
overrides: crate::cli::ConfigOverrides,
_environment: &str,
) -> Result<Self, ConfigError> {
// 使用 ConfigBuilder 设置配置
let mut builder = Config::builder();
// 如果提供了配置文件,先加载它
if let Some(config_path) = cli_config_path {
if !config_path.exists() {
tracing::error!("Configuration file not found: {}", config_path.display());
return Err(ConfigError::NotFound(
config_path.to_string_lossy().to_string(),
));
}
tracing::info!("Loading configuration from: {}", config_path.display());
builder = builder.add_source(File::from(config_path));
} else {
tracing::info!("No configuration file found, using environment variables and defaults");
tracing::warn!("⚠️ 没有找到配置文件,将使用 SQLite 作为默认数据库");
tracing::warn!(" 默认数据库路径: db.sqlite3");
tracing::warn!(" 如需使用其他数据库,请创建配置文件或设置环境变量");
// 直接使用 set_default 设置默认值
// 注意:这些值会被环境变量覆盖
builder = builder.set_default("server.host", default_server_host())?;
builder = builder.set_default("server.port", default_server_port())?;
// 设置 database 默认值(使用 SQLite 作为默认数据库)
builder = builder.set_default("database.database_type", "sqlite")?;
builder = builder.set_default("database.path", "db.sqlite3")?;
builder = builder.set_default("database.max_connections", 10)?;
// 设置 auth 默认值
builder = builder.set_default("auth.jwt_secret", default_jwt_secret())?;
builder = builder.set_default("auth.access_token_expiration_minutes", 15)?;
builder = builder.set_default("auth.refresh_token_expiration_days", 7)?;
// 设置 redis 默认值
builder = builder.set_default("redis.host", default_redis_host())?;
builder = builder.set_default("redis.port", 6379)?;
builder = builder.set_default("redis.db", 0)?;
}
// 添加环境变量源(会覆盖配置文件的值)
builder = builder.add_source(Environment::default().separator("_"));
// 应用 CLI 覆盖(仅 Web 服务器参数)
if let Some(host) = overrides.host {
builder = builder.set_override("server.host", host)?;
}
if let Some(port) = overrides.port {
builder = builder.set_override("server.port", port)?;
}
let settings = builder.build()?;
let config: AppConfig = settings.try_deserialize()?;
// 安全警告:检查是否使用了默认的 JWT 密钥
if config.auth.jwt_secret == "change-this-to-a-strong-secret-key-in-production" {
tracing::warn!("⚠️ 警告:正在使用不安全的默认 JWT 密钥!");
tracing::warn!(" 请通过环境变量 AUTH_JWT_SECRET 或配置文件设置强密钥");
tracing::warn!(" 示例AUTH_JWT_SECRET=your-secure-random-string-here");
}
// 验证数据库配置
if let Err(e) = config.database.validate() {
tracing::error!("数据库配置无效: {}", e);
return Err(ConfigError::Message(format!("数据库配置无效: {}", e)));
}
Ok(config)
}
/// 从指定路径加载配置
pub fn load_from_path(path: &str) -> Result<Self, ConfigError> {
tracing::info!("Loading configuration from: {}", path);
let settings = Config::builder()
.add_source(File::from(PathBuf::from(path)))
.add_source(Environment::default().separator("_"))
.build()?;
let config: AppConfig = settings.try_deserialize()?;
// 安全警告:检查是否使用了默认的 JWT 密钥
if config.auth.jwt_secret == "change-this-to-a-strong-secret-key-in-production" {
tracing::warn!("⚠️ 警告:正在使用不安全的默认 JWT 密钥!");
tracing::warn!(" 请通过环境变量 AUTH_JWT_SECRET 或配置文件设置强密钥");
tracing::warn!(" 示例AUTH_JWT_SECRET=your-secure-random-string-here");
}
// 验证数据库配置
if let Err(e) = config.database.validate() {
tracing::error!("数据库配置无效: {}", e);
return Err(ConfigError::Message(format!("数据库配置无效: {}", e)));
}
Ok(config)
}
}
// 默认值函数(复用)
fn default_server_host() -> String {
"127.0.0.1".to_string()
}
fn default_server_port() -> u16 {
3000
}
fn default_jwt_secret() -> String {
"change-this-to-a-strong-secret-key-in-production".to_string()
}

25
src/config/auth.rs Normal file
View File

@@ -0,0 +1,25 @@
use serde::Deserialize;
#[derive(Debug, Deserialize, Clone)]
pub struct AuthConfig {
#[serde(default = "default_jwt_secret")]
pub jwt_secret: String,
#[serde(default = "default_access_token_expiration_minutes")]
pub access_token_expiration_minutes: u64,
#[serde(default = "default_refresh_token_expiration_days")]
pub refresh_token_expiration_days: i64,
}
fn default_jwt_secret() -> String {
// ⚠️ 警告:这是一个不安全的默认值,仅用于开发测试
// 生产环境必须通过环境变量或配置文件设置强密钥
"change-this-to-a-strong-secret-key-in-production".to_string()
}
fn default_access_token_expiration_minutes() -> u64 {
15
}
fn default_refresh_token_expiration_days() -> i64 {
7
}

168
src/config/database.rs Normal file
View File

@@ -0,0 +1,168 @@
use serde::Deserialize;
use std::path::PathBuf;
/// 数据库类型
#[derive(Debug, Deserialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum DatabaseType {
MySQL,
SQLite,
PostgreSQL,
}
#[derive(Debug, Deserialize, Clone)]
pub struct DatabaseConfig {
/// 数据库类型
#[serde(default = "default_database_type")]
pub database_type: DatabaseType,
/// 网络数据库配置MySQL/PostgreSQL
pub host: Option<String>,
#[serde(default)]
pub port: Option<u16>,
pub user: Option<String>,
pub password: Option<String>,
pub database: Option<String>,
/// SQLite 文件路径
pub path: Option<PathBuf>,
/// 连接池最大连接数
#[serde(default = "default_max_connections")]
pub max_connections: u32,
}
impl DatabaseConfig {
/// 获取端口号(根据数据库类型返回默认值)
pub fn get_port(&self) -> u16 {
self.port.unwrap_or_else(|| match self.database_type {
DatabaseType::MySQL => 3306,
DatabaseType::PostgreSQL => 5432,
DatabaseType::SQLite => 0,
})
}
/// 构建数据库连接 URL
///
/// # 错误
///
/// 当缺少必需的配置字段时返回错误
pub fn build_url(&self) -> Result<String, String> {
match self.database_type {
DatabaseType::MySQL => {
let host = self
.host
.as_ref()
.ok_or_else(|| "MySQL 需要配置 database.host".to_string())?;
let user = self
.user
.as_ref()
.ok_or_else(|| "MySQL 需要配置 database.user".to_string())?;
let password = self
.password
.as_ref()
.ok_or_else(|| "MySQL 需要配置 database.password".to_string())?;
let database = self
.database
.as_ref()
.ok_or_else(|| "MySQL 需要配置 database.database".to_string())?;
Ok(format!(
"mysql://{}:{}@{}:{}/{}",
user,
password,
host,
self.get_port(),
database
))
}
DatabaseType::SQLite => {
let path = self
.path
.as_ref()
.ok_or_else(|| "SQLite 需要配置 database.path".to_string())?;
// SQLite URL 格式
// 相对路径sqlite:./db.sqlite3
// 绝对路径sqlite:C:/path/to/db.sqlite3
let path_str = path.to_string_lossy().replace('\\', "/");
Ok(format!("sqlite:{}", path_str))
}
DatabaseType::PostgreSQL => {
let host = self
.host
.as_ref()
.ok_or_else(|| "PostgreSQL 需要配置 database.host".to_string())?;
let user = self
.user
.as_ref()
.ok_or_else(|| "PostgreSQL 需要配置 database.user".to_string())?;
let password = self
.password
.as_ref()
.ok_or_else(|| "PostgreSQL 需要配置 database.password".to_string())?;
let database = self
.database
.as_ref()
.ok_or_else(|| "PostgreSQL 需要配置 database.database".to_string())?;
Ok(format!(
"postgresql://{}:{}@{}:{}/{}",
user,
password,
host,
self.get_port(),
database
))
}
}
}
/// 验证配置是否完整
pub fn validate(&self) -> Result<(), String> {
match self.database_type {
DatabaseType::MySQL => {
if self.host.is_none() {
return Err("MySQL 需要配置 database.host".to_string());
}
if self.user.is_none() {
return Err("MySQL 需要配置 database.user".to_string());
}
if self.password.is_none() {
return Err("MySQL 需要配置 database.password".to_string());
}
if self.database.is_none() {
return Err("MySQL 需要配置 database.database".to_string());
}
}
DatabaseType::SQLite => {
if self.path.is_none() {
return Err("SQLite 需要配置 database.path".to_string());
}
}
DatabaseType::PostgreSQL => {
if self.host.is_none() {
return Err("PostgreSQL 需要配置 database.host".to_string());
}
if self.user.is_none() {
return Err("PostgreSQL 需要配置 database.user".to_string());
}
if self.password.is_none() {
return Err("PostgreSQL 需要配置 database.password".to_string());
}
if self.database.is_none() {
return Err("PostgreSQL 需要配置 database.database".to_string());
}
}
}
Ok(())
}
}
fn default_database_type() -> DatabaseType {
DatabaseType::MySQL
}
fn default_max_connections() -> u32 {
10
}

5
src/config/mod.rs Normal file
View File

@@ -0,0 +1,5 @@
pub mod app;
pub mod auth;
pub mod database;
pub mod redis;
pub mod server;

52
src/config/redis.rs Normal file
View File

@@ -0,0 +1,52 @@
use serde::Deserialize;
#[derive(Debug, Deserialize, Clone)]
pub struct RedisConfig {
/// Redis 主机地址
#[serde(default = "default_redis_host")]
pub host: String,
/// Redis 端口
#[serde(default = "default_redis_port")]
pub port: u16,
/// Redis 密码(可选)
#[serde(default)]
pub password: Option<String>,
/// Redis 数据库编号(可选)
#[serde(default = "default_redis_db")]
pub db: u8,
}
pub fn default_redis_host() -> String {
"localhost".to_string()
}
pub fn default_redis_port() -> u16 {
6379
}
pub fn default_redis_db() -> u8 {
0
}
impl RedisConfig {
/// 构建 Redis 连接 URL
pub fn build_url(&self) -> String {
// 判断密码是否存在且非空
match &self.password {
Some(password) if !password.is_empty() => {
// 有密码redis://:password@host:port/db
format!(
"redis://:{}@{}:{}/{}",
password, self.host, self.port, self.db
)
}
_ => {
// 无密码None 或空字符串redis://host:port/db
format!("redis://{}:{}/{}", self.host, self.port, self.db)
}
}
}
}

17
src/config/server.rs Normal file
View File

@@ -0,0 +1,17 @@
use serde::Deserialize;
#[derive(Debug, Deserialize, Clone)]
pub struct ServerConfig {
#[serde(default = "default_server_host")]
pub host: String,
#[serde(default = "default_server_port")]
pub port: u16,
}
fn default_server_host() -> String {
"127.0.0.1".to_string()
}
fn default_server_port() -> u16 {
3000
}

329
src/db.rs Normal file
View File

@@ -0,0 +1,329 @@
use crate::config::database::{DatabaseConfig, DatabaseType};
use sea_orm::{
ConnectionTrait, Database, DatabaseConnection, DbBackend, EntityName, EntityTrait, ConnectOptions, Schema,
Statement,
};
use std::time::Duration;
/// 数据库连接池SeaORM 统一接口)
pub type DbPool = DatabaseConnection;
/// 创建数据库连接池
pub async fn create_pool(config: &DatabaseConfig) -> anyhow::Result<DbPool> {
let url = config
.build_url()
.map_err(|e| anyhow::anyhow!("数据库配置错误: {}", e))?;
tracing::debug!("数据库连接 URL: {}", url);
let mut opt = ConnectOptions::new(&url);
opt.max_connections(config.max_connections)
.min_connections(1)
.connect_timeout(Duration::from_secs(8))
.idle_timeout(Duration::from_secs(8))
.max_lifetime(Duration::from_secs(7200))
.sqlx_logging(true);
let pool = Database::connect(opt)
.await
.map_err(|e| anyhow::anyhow!("数据库连接失败: {}", e))?;
tracing::info!("已连接到数据库: {}", sanitize_url(&url));
Ok(pool)
}
/// 隐藏 URL 中的敏感信息(用于日志输出)
fn sanitize_url(url: &str) -> String {
// 隐藏密码mysql://user:password@host -> mysql://user:***@host
if let Some(at_pos) = url.find('@') {
if let Some(scheme_end) = url.find("://") {
if scheme_end < at_pos {
return format!("{}***@{}", &url[..scheme_end + 3], &url[at_pos + 1..]);
}
}
}
url.to_string()
}
/// 健康检查(保持向后兼容)
pub async fn health_check(pool: &DbPool) -> anyhow::Result<()> {
// 使用官方推荐的 ping 方法
pool.ping()
.await
.map_err(|e| anyhow::anyhow!("数据库健康检查失败: {}", e))
}
/// 初始化数据库和表结构
/// 每次启动时检查数据库和表是否存在,不存在则创建
pub async fn init_database(config: &DatabaseConfig) -> anyhow::Result<DatabaseConnection> {
match config.database_type {
DatabaseType::MySQL => {
init_mysql_database(config).await?;
}
DatabaseType::PostgreSQL => {
init_postgresql_database(config).await?;
}
DatabaseType::SQLite => {
// 确保 SQLite 数据库文件的目录存在
init_sqlite_database(config).await?;
}
}
// 连接到数据库
let pool = create_pool(config).await?;
// 创建表
create_tables(&pool).await?;
Ok(pool)
}
/// 获取端口号(根据数据库类型返回默认值)
fn get_database_port(config: &DatabaseConfig) -> u16 {
config.port.unwrap_or_else(|| match config.database_type {
DatabaseType::MySQL => 3306,
DatabaseType::PostgreSQL => 5432,
DatabaseType::SQLite => 0,
})
}
/// 为 MySQL 创建数据库(如果不存在)
async fn init_mysql_database(config: &DatabaseConfig) -> anyhow::Result<()> {
let database_name = config
.database
.as_ref()
.ok_or_else(|| anyhow::anyhow!("MySQL 需要配置 database.database"))?;
let host = config
.host
.as_ref()
.ok_or_else(|| anyhow::anyhow!("MySQL 需要配置 database.host"))?;
let user = config
.user
.as_ref()
.ok_or_else(|| anyhow::anyhow!("MySQL 需要配置 database.user"))?;
let password = config
.password
.as_ref()
.ok_or_else(|| anyhow::anyhow!("MySQL 需要配置 database.password"))?;
// 连接到 MySQL 服务器(不指定数据库)
let url = format!(
"mysql://{}:{}@{}:{}",
user,
password,
host,
get_database_port(config)
);
let mut opt = ConnectOptions::new(&url);
opt.max_connections(1)
.connect_timeout(Duration::from_secs(8))
.sqlx_logging(true);
let conn = Database::connect(opt)
.await
.map_err(|e| anyhow::anyhow!("连接 MySQL 服务器失败: {}", e))?;
// 检查数据库是否存在,不存在则创建
let query = format!(
"CREATE DATABASE IF NOT EXISTS `{}` CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci",
database_name
);
conn.execute(Statement::from_string(
sea_orm::DatabaseBackend::MySql,
query,
))
.await
.map_err(|e| anyhow::anyhow!("创建 MySQL 数据库失败: {}", e))?;
tracing::info!("✅ MySQL 数据库 '{}' 检查完成", database_name);
Ok(())
}
/// 为 PostgreSQL 创建数据库(如果不存在)
async fn init_postgresql_database(config: &DatabaseConfig) -> anyhow::Result<()> {
let database_name = config
.database
.as_ref()
.ok_or_else(|| anyhow::anyhow!("PostgreSQL 需要配置 database.database"))?;
let host = config
.host
.as_ref()
.ok_or_else(|| anyhow::anyhow!("PostgreSQL 需要配置 database.host"))?;
let user = config
.user
.as_ref()
.ok_or_else(|| anyhow::anyhow!("PostgreSQL 需要配置 database.user"))?;
let password = config
.password
.as_ref()
.ok_or_else(|| anyhow::anyhow!("PostgreSQL 需要配置 database.password"))?;
// 连接到 PostgreSQL 默认数据库postgres
let url = format!(
"postgresql://{}:{}@{}:{}/postgres",
user,
password,
host,
get_database_port(config)
);
let mut opt = ConnectOptions::new(&url);
opt.max_connections(1)
.connect_timeout(Duration::from_secs(8))
.sqlx_logging(true);
let conn = Database::connect(opt)
.await
.map_err(|e| anyhow::anyhow!("连接 PostgreSQL 服务器失败: {}", e))?;
// 检查数据库是否存在,不存在则创建
// PostgreSQL 不支持 CREATE DATABASE IF NOT EXISTS需要先查询
let check_query = format!(
"SELECT 1 FROM pg_database WHERE datname='{}'",
database_name
);
let result = conn
.execute(Statement::from_string(
sea_orm::DatabaseBackend::Postgres,
check_query,
))
.await;
match result {
Ok(_) => {
tracing::info!("PostgreSQL 数据库 '{}' 已存在", database_name);
}
Err(_) => {
// 数据库不存在,创建它
let create_query = format!(
"CREATE DATABASE {} WITH ENCODING 'UTF8' LC_COLLATE='en_US.UTF-8' LC_CTYPE='en_US.UTF-8'",
database_name
);
conn.execute(Statement::from_string(
sea_orm::DatabaseBackend::Postgres,
create_query,
))
.await
.map_err(|e| anyhow::anyhow!("创建 PostgreSQL 数据库失败: {}", e))?;
tracing::info!("✅ PostgreSQL 数据库 '{}' 创建成功", database_name);
}
}
Ok(())
}
/// 为 SQLite 确保数据库文件目录存在
async fn init_sqlite_database(config: &DatabaseConfig) -> anyhow::Result<()> {
let path = config
.path
.as_ref()
.ok_or_else(|| anyhow::anyhow!("SQLite 需要配置 database.path"))?;
// 如果是相对路径,转换为绝对路径
let absolute_path = if path.is_absolute() {
path.clone()
} else {
std::env::current_dir()
.map_err(|e| anyhow::anyhow!("获取当前目录失败: {}", e))?
.join(path)
};
tracing::info!("SQLite 数据库路径: {}", absolute_path.display());
// 获取数据库文件的父目录
if let Some(parent) = absolute_path.parent() {
// 如果父目录不存在,则创建
if !parent.exists() {
std::fs::create_dir_all(parent)
.map_err(|e| anyhow::anyhow!("创建 SQLite 数据库目录失败: {}", e))?;
tracing::info!("✅ SQLite 数据库目录创建成功: {}", parent.display());
} else {
tracing::info!("SQLite 数据库目录已存在: {}", parent.display());
}
}
// 如果数据库文件不存在,创建空文件
if !absolute_path.exists() {
std::fs::File::create(&absolute_path)
.map_err(|e| anyhow::anyhow!("创建 SQLite 数据库文件失败: {}", e))?;
tracing::info!("✅ SQLite 数据库文件创建成功: {}", absolute_path.display());
} else {
tracing::info!("SQLite 数据库文件已存在: {}", absolute_path.display());
}
Ok(())
}
/// 辅助函数:创建单个表(如果不存在)
async fn create_single_table<E>(
db: &DatabaseConnection,
schema: &Schema,
builder: &DbBackend,
entity: E,
table_name: &str,
) -> anyhow::Result<()>
where
E: EntityName + EntityTrait,
{
let create_table = schema.create_table_from_entity(entity);
let sql = match builder {
DbBackend::MySql => {
use sea_orm::sea_query::MysqlQueryBuilder;
create_table.to_string(MysqlQueryBuilder {})
}
DbBackend::Postgres => {
use sea_orm::sea_query::PostgresQueryBuilder;
create_table.to_string(PostgresQueryBuilder {})
}
DbBackend::Sqlite => {
use sea_orm::sea_query::SqliteQueryBuilder;
create_table.to_string(SqliteQueryBuilder {})
}
};
let sql = sql.replace("CREATE TABLE", "CREATE TABLE IF NOT EXISTS");
match db.execute(Statement::from_string(*builder, sql)).await {
Ok(_) => {
tracing::info!("✅ {}检查完成", table_name);
}
Err(e) => {
let err_msg = e.to_string();
if err_msg.contains("already exists") || (err_msg.contains("table") && err_msg.contains("exists")) {
tracing::info!("✅ {}已存在", table_name);
} else {
return Err(anyhow::anyhow!("创建{}失败: {}", table_name, e));
}
}
}
Ok(())
}
/// 创建数据库表结构
async fn create_tables(db: &DatabaseConnection) -> anyhow::Result<()> {
tracing::info!("检查数据库表结构...");
let builder = db.get_database_backend();
let schema = Schema::new(builder);
// 导入所有 entities
use crate::domain::entities::users;
// 创建所有表(添加新表只需一行!)
create_single_table(db, &schema, &builder, users::Entity, "用户表").await?;
tracing::info!("✅ 数据库表结构检查完成");
Ok(())
}

57
src/domain/dto/auth.rs Normal file
View File

@@ -0,0 +1,57 @@
use serde::Deserialize;
use std::fmt;
/// 注册请求
#[derive(Deserialize)]
pub struct RegisterRequest {
pub email: String,
pub password: String,
}
// 实现 Debug trait对密码进行脱敏
impl fmt::Debug for RegisterRequest {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "RegisterRequest {{ email: {}, password: *** }}", self.email)
}
}
/// 登录请求
#[derive(Deserialize)]
pub struct LoginRequest {
pub email: String,
pub password: String,
}
// 实现 Debug trait
impl fmt::Debug for LoginRequest {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "LoginRequest {{ email: {}, password: *** }}", self.email)
}
}
/// 删除用户请求
#[derive(Deserialize)]
pub struct DeleteUserRequest {
pub user_id: String,
pub password: String,
}
// 实现 Debug trait
impl fmt::Debug for DeleteUserRequest {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "DeleteUserRequest {{ user_id: {}, password: *** }}", self.user_id)
}
}
/// 刷新令牌请求
#[derive(Deserialize)]
pub struct RefreshRequest {
pub refresh_token: String,
}
// RefreshRequest 的 refresh_token 是敏感字段,需要脱敏
impl fmt::Debug for RefreshRequest {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "RefreshRequest {{ refresh_token: *** }}")
}
}

1
src/domain/dto/mod.rs Normal file
View File

@@ -0,0 +1 @@
pub mod auth;

1
src/domain/dto/user.rs Normal file
View File

@@ -0,0 +1 @@
// 用户相关 DTO预留

View File

@@ -0,0 +1,2 @@
pub mod users;

View File

@@ -0,0 +1,41 @@
use sea_orm::entity::prelude::*;
use sea_orm::Set;
use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Serialize, Deserialize)]
#[sea_orm(table_name = "users")]
pub struct Model {
#[sea_orm(primary_key, auto_increment = false)]
pub id: String,
#[sea_orm(unique)]
pub email: String,
pub password_hash: String,
pub created_at: DateTime,
pub updated_at: DateTime,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {}
#[async_trait::async_trait]
impl ActiveModelBehavior for ActiveModel {
/// 在保存前自动填充时间戳
async fn before_save<C>(self, _db: &C, insert: bool) -> Result<Self, DbErr>
where
C: ConnectionTrait,
{
let mut this = self;
let now = chrono::Utc::now().naive_utc();
if insert {
// 插入时:设置创建时间和更新时间
this.created_at = Set(now);
this.updated_at = Set(now);
} else {
// 更新时:只更新更新时间
this.updated_at = Set(now);
}
Ok(this)
}
}

3
src/domain/mod.rs Normal file
View File

@@ -0,0 +1,3 @@
pub mod dto;
pub mod vo;
pub mod entities;

50
src/domain/vo/auth.rs Normal file
View File

@@ -0,0 +1,50 @@
use serde::Serialize;
/// 注册结果
#[derive(Debug, Serialize)]
pub struct RegisterResult {
pub email: String,
pub created_at: String, // ISO 8601 格式
pub access_token: String,
pub refresh_token: String,
}
impl From<(crate::domain::entities::users::Model, String, String)> for RegisterResult {
fn from((user_model, access_token, refresh_token): (crate::domain::entities::users::Model, String, String)) -> Self {
Self {
email: user_model.email,
created_at: user_model.created_at.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string(),
access_token,
refresh_token,
}
}
}
/// 登录结果
#[derive(Debug, Serialize)]
pub struct LoginResult {
pub id: String,
pub email: String,
pub created_at: String, // ISO 8601 格式
pub access_token: String,
pub refresh_token: String,
}
impl From<(crate::domain::entities::users::Model, String, String)> for LoginResult {
fn from((user_model, access_token, refresh_token): (crate::domain::entities::users::Model, String, String)) -> Self {
Self {
id: user_model.id,
email: user_model.email,
created_at: user_model.created_at.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string(),
access_token,
refresh_token,
}
}
}
/// 刷新 Token 结果
#[derive(Debug, Serialize)]
pub struct RefreshResult {
pub access_token: String,
pub refresh_token: String,
}

56
src/domain/vo/mod.rs Normal file
View File

@@ -0,0 +1,56 @@
pub mod auth;
pub mod user;
/// 统一的 API 响应结构
use serde::Serialize;
use axum::http::StatusCode;
#[derive(Debug, Serialize)]
pub struct ApiResponse<T> {
/// HTTP 状态码
pub code: u16,
/// 响应消息
pub message: String,
/// 响应数据
pub data: Option<T>,
}
impl<T: Serialize> ApiResponse<T> {
/// 成功响应200
pub fn success(data: T) -> Self {
Self {
code: 200,
message: "Success".to_string(),
data: Some(data),
}
}
/// 成功响应(自定义消息)
pub fn success_with_message(data: T, message: &str) -> Self {
Self {
code: 200,
message: message.to_string(),
data: Some(data),
}
}
/// 错误响应
#[allow(dead_code)]
pub fn error(status_code: StatusCode, message: &str) -> ApiResponse<()> {
ApiResponse {
code: status_code.as_u16(),
message: message.to_string(),
data: None,
}
}
/// 错误响应(带数据)
#[allow(dead_code)]
pub fn error_with_data(status_code: StatusCode, message: &str, data: T) -> ApiResponse<T> {
ApiResponse {
code: status_code.as_u16(),
message: message.to_string(),
data: Some(data),
}
}
}

1
src/domain/vo/user.rs Normal file
View File

@@ -0,0 +1 @@
// 用户相关 VO预留

89
src/error.rs Normal file
View File

@@ -0,0 +1,89 @@
use crate::domain::vo::ApiResponse;
use axum::{
http::StatusCode,
response::{IntoResponse, Response},
Json,
};
/// 应用错误类型
#[allow(dead_code)]
pub struct AppError(pub anyhow::Error);
impl IntoResponse for AppError {
fn into_response(self) -> Response {
tracing::error!("Application error: {:?}", self.0);
let (status, message) = match self.0.downcast_ref::<&str>() {
Some(&"not_found") => (StatusCode::NOT_FOUND, "Resource not found"),
Some(&"unauthorized") => (StatusCode::UNAUTHORIZED, "Unauthorized"),
_ => (StatusCode::INTERNAL_SERVER_ERROR, "Internal server error"),
};
let body = ApiResponse::<()> {
code: status.as_u16(),
message: message.to_string(),
data: None,
};
(status, Json(body)).into_response()
}
}
// 为具体类型实现 From
impl From<anyhow::Error> for AppError {
fn from(err: anyhow::Error) -> Self {
Self(err)
}
}
/// 统一的 API 错误响应结构
#[derive(Debug)]
pub struct ErrorResponse {
pub status: StatusCode,
pub message: String,
}
impl ErrorResponse {
pub fn new(message: impl Into<String>) -> Self {
Self {
status: StatusCode::BAD_REQUEST,
message: message.into(),
}
}
#[allow(dead_code)]
pub fn not_found(message: impl Into<String>) -> Self {
Self {
status: StatusCode::NOT_FOUND,
message: message.into(),
}
}
#[allow(dead_code)]
pub fn unauthorized(message: impl Into<String>) -> Self {
Self {
status: StatusCode::UNAUTHORIZED,
message: message.into(),
}
}
#[allow(dead_code)]
pub fn internal(message: impl Into<String>) -> Self {
Self {
status: StatusCode::INTERNAL_SERVER_ERROR,
message: message.into(),
}
}
}
impl IntoResponse for ErrorResponse {
fn into_response(self) -> Response {
let body = ApiResponse::<()> {
code: self.status.as_u16(),
message: self.message,
data: None,
};
(self.status, Json(body)).into_response()
}
}

157
src/handlers/auth.rs Normal file
View File

@@ -0,0 +1,157 @@
use crate::error::ErrorResponse;
use crate::infra::middleware::logging::{log_info, RequestId};
use crate::domain::dto::auth::{RegisterRequest, LoginRequest, RefreshRequest, DeleteUserRequest};
use crate::domain::vo::auth::{RegisterResult, LoginResult, RefreshResult};
use crate::domain::vo::ApiResponse;
use crate::repositories::user_repository::UserRepository;
use crate::services::auth_service::AuthService;
use crate::AppState;
use axum::{
extract::{Extension, State},
Json,
};
use serde_json::json;
/// 注册
pub async fn register(
Extension(request_id): Extension<RequestId>,
State(state): State<AppState>,
Json(payload): Json<RegisterRequest>,
) -> Result<Json<ApiResponse<RegisterResult>>, ErrorResponse> {
log_info(&request_id, "注册请求参数", &payload);
let user_repo = UserRepository::new(state.pool.clone());
let service = AuthService::new(user_repo, state.redis_client.clone(), state.config.auth.clone());
match service.register(payload).await {
Ok((user_model, access_token, refresh_token)) => {
let data = RegisterResult::from((user_model, access_token, refresh_token));
let response = ApiResponse::success(data);
log_info(&request_id, "注册成功", &response);
Ok(Json(response))
}
Err(e) => {
log_info(&request_id, "注册失败", &e.to_string());
Err(ErrorResponse::new(e.to_string()))
}
}
}
/// 登录
pub async fn login(
Extension(request_id): Extension<RequestId>,
State(state): State<AppState>,
Json(payload): Json<LoginRequest>,
) -> Result<Json<ApiResponse<LoginResult>>, ErrorResponse> {
log_info(&request_id, "登录请求参数", &payload);
let user_repo = UserRepository::new(state.pool.clone());
let service = AuthService::new(user_repo, state.redis_client.clone(), state.config.auth.clone());
match service.login(payload).await {
Ok((user_model, access_token, refresh_token)) => {
let data = LoginResult::from((user_model, access_token, refresh_token));
let response = ApiResponse::success(data);
log_info(&request_id, "登录成功", &response);
Ok(Json(response))
}
Err(e) => {
log_info(&request_id, "登录失败", &e.to_string());
Err(ErrorResponse::new(e.to_string()))
}
}
}
/// 刷新 Token
pub async fn refresh(
Extension(request_id): Extension<RequestId>,
State(state): State<AppState>,
Json(payload): Json<RefreshRequest>,
) -> Result<Json<ApiResponse<RefreshResult>>, ErrorResponse> {
log_info(
&request_id,
"刷新 token 请求",
&json!({"device_id": "default"}),
);
let user_repo = UserRepository::new(state.pool.clone());
let service = AuthService::new(user_repo, state.redis_client.clone(), state.config.auth.clone());
match service
.refresh_access_token(&payload.refresh_token)
.await
{
Ok((access_token, refresh_token)) => {
let data = RefreshResult {
access_token,
refresh_token,
};
let response = ApiResponse::success(data);
log_info(
&request_id,
"刷新成功",
&json!({"access_token": "***"}),
);
Ok(Json(response))
}
Err(e) => {
log_info(&request_id, "刷新失败", &e.to_string());
Err(ErrorResponse::new(e.to_string()))
}
}
}
/// 删除账号
pub async fn delete_account(
Extension(request_id): Extension<RequestId>,
State(state): State<AppState>,
Extension(user_id): Extension<String>,
Json(payload): Json<DeleteUserRequest>,
) -> Result<Json<ApiResponse<()>>, ErrorResponse> {
log_info(&request_id, "删除账号请求", &format!("user_id={}", user_id));
let user_repo = UserRepository::new(state.pool.clone());
let service = AuthService::new(user_repo, state.redis_client.clone(), state.config.auth.clone());
let delete_request = DeleteUserRequest {
user_id: user_id.clone(),
password: payload.password,
};
match service.delete_user(delete_request).await {
Ok(_) => {
log_info(&request_id, "账号删除成功", &format!("user_id={}", user_id));
let response = ApiResponse::success_with_message((), "账号删除成功");
Ok(Json(response))
}
Err(e) => {
log_info(&request_id, "账号删除失败", &e.to_string());
Err(ErrorResponse::new(e.to_string()))
}
}
}
/// 刷新令牌
pub async fn delete_refresh_token(
Extension(request_id): Extension<RequestId>,
State(state): State<AppState>,
Extension(user_id): Extension<String>,
) -> Result<Json<ApiResponse<()>>, ErrorResponse> {
log_info(&request_id, "删除刷新令牌请求", &format!("user_id={}", user_id));
let user_repo = UserRepository::new(state.pool.clone());
let service = AuthService::new(user_repo, state.redis_client.clone(), state.config.auth.clone());
match service.delete_refresh_token(&user_id).await {
Ok(_) => {
log_info(&request_id, "刷新令牌删除成功", &format!("user_id={}", user_id));
let response = ApiResponse::success_with_message((), "刷新令牌删除成功");
Ok(Json(response))
}
Err(e) => {
log_info(&request_id, "刷新令牌删除失败", &e.to_string());
Err(ErrorResponse::new(e.to_string()))
}
}
}

25
src/handlers/health.rs Normal file
View File

@@ -0,0 +1,25 @@
use crate::AppState;
use crate::db;
use axum::{
extract::State,
response::{IntoResponse, Json},
};
use serde_json::json;
/// 健康检查端点
pub async fn health_check(State(state): State<AppState>) -> impl IntoResponse {
match db::health_check(&state.pool).await {
Ok(_) => Json(json!({"status": "ok"})),
Err(_) => Json(json!({"status": "unavailable"})),
}
}
/// 获取服务器信息
pub async fn server_info() -> impl IntoResponse {
Json(json!({
"name": "web-rust-template",
"version": "0.1.0",
"status": "running",
"timestamp": chrono::Utc::now().timestamp()
}))
}

2
src/handlers/mod.rs Normal file
View File

@@ -0,0 +1,2 @@
pub mod auth;
pub mod health;

View File

@@ -0,0 +1,51 @@
use crate::AppState;
use axum::{
extract::{Request, State},
http::{HeaderMap, StatusCode},
middleware::Next,
response::Response,
};
use jsonwebtoken::{decode, DecodingKey, Validation};
use serde::Deserialize;
#[derive(Deserialize)]
pub struct Claims {
pub sub: String, // user_id
#[allow(dead_code)]
pub exp: usize,
}
/// JWT 认证中间件
pub async fn auth_middleware(
State(state): State<AppState>,
headers: HeaderMap,
mut req: Request,
next: Next,
) -> Result<Response, StatusCode> {
// 1. 提取 Authorization header
let auth_header = headers
.get("Authorization")
.and_then(|h| h.to_str().ok())
.ok_or(StatusCode::UNAUTHORIZED)?;
if !auth_header.starts_with("Bearer ") {
return Err(StatusCode::UNAUTHORIZED);
}
let token = &auth_header[7..];
// 2. 验证 JWT
let jwt_secret = &state.config.auth.jwt_secret;
let token_data = decode::<Claims>(
token,
&DecodingKey::from_secret(jwt_secret.as_ref()),
&Validation::default(),
)
.map_err(|_| StatusCode::UNAUTHORIZED)?;
// 3. 将 user_id 添加到请求扩展
req.extensions_mut().insert(token_data.claims.sub);
Ok(next.run(req).await)
}

View File

@@ -0,0 +1,71 @@
use axum::{extract::Request, response::Response};
use std::time::Instant;
/// Request ID 标记
#[derive(Clone)]
pub struct RequestId(pub String);
/// 请求日志中间件
pub async fn request_logging_middleware(
mut req: Request,
next: axum::middleware::Next,
) -> Response {
let start = Instant::now();
// 提取请求信息
let method = req.method().clone();
let path = req.uri().path().to_string();
let query = req.uri().query().map(|s| s.to_string());
// 生成请求 ID
let request_id = uuid::Uuid::new_v4().to_string();
// 将 request_id 存储到请求扩展中
req.extensions_mut().insert(RequestId(request_id.clone()));
// 第1条日志请求开始
let separator = "=".repeat(80);
let header = format!("{} {}", method, path);
tracing::info!("{}", separator);
tracing::info!("{}", header);
tracing::info!("{}", separator);
let now_beijing = chrono::Local::now().format("%Y-%m-%d %H:%M:%S%.3f");
let query_str = query.as_deref().unwrap_or("");
tracing::info!(
"[{}] 📥 查询参数: {} | 时间: {}",
request_id,
query_str,
now_beijing
);
// 调用下一个处理器
let response = next.run(req).await;
// 第3条日志请求完成
let duration = start.elapsed();
let status = response.status();
tracing::info!(
"[{}] ✅ 状态码: {} | 耗时: {}ms",
request_id,
status.as_u16(),
duration.as_millis()
);
tracing::info!("{}", separator);
response
}
/// 请求日志辅助工具
pub fn log_info<T: std::fmt::Debug>(request_id: &RequestId, label: &str, data: T) {
let data_str = format!("{:?}", data);
let truncated = if data_str.len() > 300 {
format!("{}...", &data_str[..300])
} else {
data_str
};
tracing::info!("[{}] 🔧 {} | {}", request_id.0, label, truncated);
}

View File

@@ -0,0 +1,2 @@
pub mod auth;
pub mod logging;

2
src/infra/mod.rs Normal file
View File

@@ -0,0 +1,2 @@
pub mod middleware;
pub mod redis;

20
src/infra/redis/errors.rs Normal file
View File

@@ -0,0 +1,20 @@
use thiserror::Error;
/// Redis 错误类型
#[derive(Error, Debug)]
pub enum RedisError {
#[error("Redis 连接失败: {0}")]
ConnectionError(#[from] redis::RedisError),
#[error("Redis 序列化失败: {0}")]
SerializationError(#[from] serde_json::Error),
#[error("Redis 数据不存在: {key}")]
NotFound { key: String },
#[error("Redis 操作失败: {message}")]
OperationError { message: String },
#[error("Failed to create redis pool: {0}")]
PoolCreation(#[from] deadpool_redis::CreatePoolError),
}

2
src/infra/redis/mod.rs Normal file
View File

@@ -0,0 +1,2 @@
pub mod redis_client;
pub mod redis_key;

View File

@@ -0,0 +1,135 @@
use super::redis_key::RedisKey;
use redis::aio::MultiplexedConnection;
use redis::{AsyncCommands, Client};
use serde::Serialize;
use std::sync::Arc;
use tokio::sync::Mutex;
/// Redis 客户端(使用 MultiplexedConnection
#[derive(Clone)]
pub struct RedisClient {
conn: Arc<Mutex<MultiplexedConnection>>,
}
impl RedisClient {
/// 创建新的 Redis 客户端
pub async fn new(url: &str) -> redis::RedisResult<Self> {
let client = Client::open(url)?;
let conn = client.get_multiplexed_async_connection().await?;
Ok(Self {
conn: Arc::new(Mutex::new(conn)),
})
}
/// 设置字符串值
pub async fn set(&self, k: &str, v: &str) -> redis::RedisResult<()> {
let mut c = self.conn.lock().await;
c.set(k, v).await
}
/// 获取字符串值
pub async fn get(&self, k: &str) -> redis::RedisResult<Option<String>> {
let mut c = self.conn.lock().await;
c.get(k).await
}
/// 设置字符串值并指定过期时间(秒)
pub async fn set_ex(&self, k: &str, v: &str, seconds: u64) -> redis::RedisResult<()> {
let mut c = self.conn.lock().await;
c.set_ex(k, v, seconds).await
}
/// 删除键
pub async fn del(&self, k: &str) -> redis::RedisResult<()> {
let mut c = self.conn.lock().await;
c.del(k).await
}
/// 设置键的过期时间(秒)
pub async fn expire(&self, k: &str, seconds: u64) -> redis::RedisResult<()> {
let mut c = self.conn.lock().await;
c.expire(k, seconds as i64).await
}
/// 使用 RedisKey 设置 JSON 值
pub async fn set_key<T: Serialize>(
&self,
key: &RedisKey,
value: &T,
) -> redis::RedisResult<()> {
let json = serde_json::to_string(value).map_err(|e| {
redis::RedisError::from((
redis::ErrorKind::TypeError,
"JSON serialization failed",
e.to_string(),
))
})?;
let mut c = self.conn.lock().await;
c.set(key.build(), json).await
}
/// 使用 RedisKey 设置 JSON 值并指定过期时间(秒)
pub async fn set_key_ex<T: Serialize>(
&self,
key: &RedisKey,
value: &T,
expiration_seconds: u64,
) -> redis::RedisResult<()> {
let json = serde_json::to_string(value).map_err(|e| {
redis::RedisError::from((
redis::ErrorKind::TypeError,
"JSON serialization failed",
e.to_string(),
))
})?;
let mut c = self.conn.lock().await;
c.set_ex(key.build(), json, expiration_seconds).await
}
/// 使用 RedisKey 获取字符串值
pub async fn get_key(&self, key: &RedisKey) -> redis::RedisResult<Option<String>> {
let mut c = self.conn.lock().await;
let json: Option<String> = c.get(key.build()).await?;
Ok(json)
}
/// 使用 RedisKey 获取并反序列化 JSON 值
pub async fn get_key_json<T: for<'de> serde::Deserialize<'de>>(
&self,
key: &RedisKey,
) -> redis::RedisResult<Option<T>> {
let mut c = self.conn.lock().await;
let json: Option<String> = c.get(key.build()).await?;
match json {
Some(data) => {
let value = serde_json::from_str(&data).map_err(|e| {
redis::RedisError::from((
redis::ErrorKind::TypeError,
"JSON deserialization failed",
e.to_string(),
))
})?;
Ok(Some(value))
}
None => Ok(None),
}
}
/// 使用 RedisKey 删除键
pub async fn delete_key(&self, key: &RedisKey) -> redis::RedisResult<()> {
let mut c = self.conn.lock().await;
c.del(key.build()).await
}
/// 使用 RedisKey 检查键是否存在
pub async fn exists_key(&self, key: &RedisKey) -> redis::RedisResult<bool> {
let mut c = self.conn.lock().await;
c.exists(key.build()).await
}
/// 使用 RedisKey 设置键的过期时间(秒)
pub async fn expire_key(&self, key: &RedisKey, seconds: u64) -> redis::RedisResult<()> {
let mut c = self.conn.lock().await;
c.expire(key.build(), seconds as i64).await
}
}

View File

@@ -0,0 +1,61 @@
use serde::{Deserialize, Serialize};
use std::fmt;
/// 业务类型枚举
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum BusinessType {
#[serde(rename = "auth")]
Auth,
#[serde(rename = "user")]
User,
#[serde(rename = "cache")]
Cache,
#[serde(rename = "session")]
Session,
#[serde(rename = "rate_limit")]
RateLimit,
}
impl BusinessType {
pub fn prefix(self) -> &'static str {
match self {
BusinessType::Auth => "auth",
BusinessType::User => "user",
BusinessType::Cache => "cache",
BusinessType::Session => "session",
BusinessType::RateLimit => "rate_limit",
}
}
}
/// Redis 键构建器
#[derive(Debug, Clone)]
pub struct RedisKey {
business: BusinessType,
identifiers: Vec<String>,
}
impl RedisKey {
pub fn new(business: BusinessType) -> Self {
Self {
business,
identifiers: Vec::new(),
}
}
pub fn add_identifier(mut self, id: impl Into<String>) -> Self {
self.identifiers.push(id.into());
self
}
pub fn build(&self) -> String {
format!("{}:{}", self.business.prefix(), self.identifiers.join(":"))
}
}
// 兼容现有格式
impl fmt::Display for RedisKey {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.build())
}
}

131
src/main.rs Normal file
View File

@@ -0,0 +1,131 @@
mod cli;
mod config;
mod db;
mod domain;
mod error;
mod handlers;
mod infra;
mod repositories;
mod services;
mod utils;
use axum::{
routing::{get, post},
Router,
};
use clap::Parser;
use cli::CliArgs;
use tower_http::cors::{Any, CorsLayer};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
/// 应用状态
#[derive(Clone)]
pub struct AppState {
pub pool: db::DbPool,
pub config: config::app::AppConfig,
pub redis_client: infra::redis::redis_client::RedisClient,
}
#[tokio::main]
async fn main() -> anyhow::Result<()> {
// 解析命令行参数
let args = CliArgs::parse();
// 初始化日志
tracing_subscriber::registry()
.with(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| args.get_log_filter().into()),
)
.with(tracing_subscriber::fmt::layer())
.init();
// 打印启动信息
args.print_startup_info();
// 设置工作目录(如果指定)
if let Some(ref work_dir) = args.work_dir {
std::env::set_current_dir(work_dir).ok();
println!("Working directory set to: {}", work_dir.display());
}
// 解析配置文件路径(可选)
let config_path = args.resolve_config_path();
// 加载配置(支持 CLI 覆盖)
// 如果没有配置文件,将仅使用环境变量和默认值
let config = config::app::AppConfig::load_with_overrides(
config_path,
args.get_overrides(),
args.env.as_str(),
)?;
tracing::info!("Configuration loaded successfully");
tracing::info!("Environment: {}", args.env.as_str());
tracing::info!("Debug mode: {}", args.is_debug_enabled());
// 初始化数据库(自动创建数据库和表)
let pool = db::init_database(&config.database).await?;
// 初始化 Redis 客户端
let redis_client = infra::redis::redis_client::RedisClient::new(&config.redis.build_url())
.await
.map_err(|e| anyhow::anyhow!("Redis 初始化失败: {}", e))?;
tracing::info!("Redis 连接池初始化成功");
// 创建应用状态
let app_state = AppState {
pool: pool.clone(),
config: config.clone(),
redis_client,
};
// ========== 公开路由 ==========
let public_routes = Router::new()
.route("/health", get(handlers::health::health_check))
.route("/info", get(handlers::health::server_info))
.route("/auth/register", post(handlers::auth::register))
.route("/auth/login", post(handlers::auth::login))
.route("/auth/refresh", post(handlers::auth::refresh));
// ========== 受保护路由 ==========
let protected_routes = Router::new()
.route("/auth/delete", post(handlers::auth::delete_account))
.route(
"/auth/delete-refresh-token",
post(handlers::auth::delete_refresh_token),
)
// JWT 认证中间件(仅应用于受保护路由)
.route_layer(axum::middleware::from_fn_with_state(
app_state.clone(),
infra::middleware::auth::auth_middleware,
));
// ========== 合并路由 ==========
let app = public_routes
.merge(protected_routes)
// CORS应用于所有路由
.layer(
CorsLayer::new()
.allow_origin(Any)
.allow_methods(Any)
.allow_headers(Any)
)
// 日志中间件(应用于所有路由)
.layer(axum::middleware::from_fn_with_state(
app_state.clone(),
infra::middleware::logging::request_logging_middleware,
))
.with_state(app_state);
// 启动服务器
let addr = format!("{}:{}", config.server.host, config.server.port);
let listener = tokio::net::TcpListener::bind(&addr).await?;
tracing::info!("Server listening on {}", addr);
tracing::info!("Press Ctrl+C to stop");
axum::serve(listener, app).await?;
Ok(())
}

2
src/repositories/mod.rs Normal file
View File

@@ -0,0 +1,2 @@
pub mod user_repository;

View File

@@ -0,0 +1,85 @@
use sea_orm::{EntityTrait, QueryFilter, ColumnTrait, DatabaseConnection, Set, ActiveModelTrait, PaginatorTrait};
use crate::domain::entities::users;
use anyhow::Result;
/// 用户数据访问仓库
pub struct UserRepository {
db: DatabaseConnection,
}
impl UserRepository {
pub fn new(db: DatabaseConnection) -> Self {
Self { db }
}
/// 根据 email 查询用户
pub async fn find_by_email(&self, email: &str) -> Result<Option<users::Model>> {
let user = users::Entity::find()
.filter(users::Column::Email.eq(email))
.one(&self.db)
.await
.map_err(|e| anyhow::anyhow!("查询失败: {}", e))?;
Ok(user)
}
/// 统计邮箱数量
pub async fn count_by_email(&self, email: &str) -> Result<i64> {
let count = users::Entity::find()
.filter(users::Column::Email.eq(email))
.count(&self.db)
.await
.map_err(|e| anyhow::anyhow!("查询失败: {}", e))?;
Ok(count as i64)
}
/// 统计用户 ID 数量
pub async fn count_by_id(&self, id: &str) -> Result<i64> {
let count = users::Entity::find()
.filter(users::Column::Id.eq(id))
.count(&self.db)
.await
.map_err(|e| anyhow::anyhow!("查询失败: {}", e))?;
Ok(count as i64)
}
/// 获取密码哈希
pub async fn get_password_hash(&self, email: &str) -> Result<Option<String>> {
let user = users::Entity::find()
.filter(users::Column::Email.eq(email))
.one(&self.db)
.await
.map_err(|e| anyhow::anyhow!("查询失败: {}", e))?;
Ok(user.map(|u| u.password_hash))
}
/// 插入用户created_at 和 updated_at 会自动填充),返回插入后的用户对象
pub async fn insert(&self, id: String, email: String, password_hash: String) -> Result<users::Model> {
let user_model = users::ActiveModel {
id: Set(id),
email: Set(email),
password_hash: Set(password_hash),
// created_at 和 updated_at 由 ActiveModelBehavior 自动填充
..Default::default()
};
let inserted_user = user_model.insert(&self.db)
.await
.map_err(|e| anyhow::anyhow!("插入失败: {}", e))?;
Ok(inserted_user)
}
/// 根据 ID 删除用户
pub async fn delete_by_id(&self, id: &str) -> Result<()> {
users::Entity::delete_by_id(id)
.exec(&self.db)
.await
.map_err(|e| anyhow::anyhow!("删除失败: {}", e))?;
Ok(())
}
}

View File

@@ -0,0 +1,232 @@
use anyhow::Result;
use argon2::{
password_hash::{rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
Argon2,
};
use rand::Rng;
use crate::utils::jwt::TokenService;
use crate::domain::dto::auth::{RegisterRequest, LoginRequest, DeleteUserRequest};
use crate::domain::entities::users;
use crate::config::auth::AuthConfig;
use crate::infra::redis::{redis_client::RedisClient, redis_key::{BusinessType, RedisKey}};
use crate::repositories::user_repository::UserRepository;
pub struct AuthService {
user_repo: UserRepository,
redis_client: RedisClient,
auth_config: AuthConfig,
}
impl AuthService {
pub fn new(user_repo: UserRepository, redis_client: RedisClient, auth_config: AuthConfig) -> Self {
Self { user_repo, redis_client, auth_config }
}
/// 哈希密码
pub fn hash_password(&self, password: &str) -> Result<String> {
let salt = SaltString::generate(&mut OsRng);
let argon2 = Argon2::default();
let password_hash = argon2
.hash_password(password.as_bytes(), &salt)
.map_err(|e| anyhow::anyhow!("密码哈希失败: {}", e))?
.to_string();
Ok(password_hash)
}
/// 生成用户 ID
pub fn generate_user_id(&self) -> String {
let mut rng = rand::thread_rng();
rng.gen_range(1_000_000_000i64..10_000_000_000i64)
.to_string()
}
/// 生成唯一的用户 ID
pub async fn generate_unique_user_id(&self) -> Result<String> {
let mut attempts = 0;
const MAX_ATTEMPTS: u32 = 10;
loop {
let candidate_id = self.generate_user_id();
let existing = self.user_repo.count_by_id(&candidate_id).await?;
if existing == 0 {
return Ok(candidate_id);
}
attempts += 1;
if attempts >= MAX_ATTEMPTS {
return Err(anyhow::anyhow!("生成唯一用户 ID 失败"));
}
}
}
/// 保存 refresh_token 到 Redis
async fn save_refresh_token(&self, user_id: &str, refresh_token: &str, expiration_days: i64) -> Result<()> {
let key = RedisKey::new(BusinessType::Auth)
.add_identifier("refresh_token")
.add_identifier(user_id);
let expiration_seconds = expiration_days * 24 * 3600;
self.redis_client
.set_ex(&key.build(), refresh_token, expiration_seconds as u64)
.await
.map_err(|e| anyhow::anyhow!("Redis 保存失败: {}", e))?;
Ok(())
}
/// 获取并删除 refresh_token
async fn get_and_delete_refresh_token(&self, user_id: &str) -> Result<String> {
let key = RedisKey::new(BusinessType::Auth)
.add_identifier("refresh_token")
.add_identifier(user_id);
let token: Option<String> = self.redis_client
.get(&key.build())
.await
.map_err(|e| anyhow::anyhow!("Redis 查询失败: {}", e))?;
if token.is_some() {
self.redis_client
.delete_key(&key)
.await
.map_err(|e| anyhow::anyhow!("Redis 删除失败: {}", e))?;
}
token.ok_or_else(|| anyhow::anyhow!("刷新令牌无效或已过期"))
}
/// 删除用户的 refresh_token
pub async fn delete_refresh_token(&self, user_id: &str) -> Result<()> {
let key = RedisKey::new(BusinessType::Auth)
.add_identifier("refresh_token")
.add_identifier(user_id);
self.redis_client
.delete_key(&key)
.await
.map_err(|e| anyhow::anyhow!("Redis 删除失败: {}", e))?;
Ok(())
}
/// 注册用户
pub async fn register(
&self,
request: RegisterRequest,
) -> Result<(users::Model, String, String)> {
// 1. 检查邮箱是否已存在
let existing = self.user_repo.count_by_email(&request.email).await?;
if existing > 0 {
return Err(anyhow::anyhow!("邮箱已注册"));
}
// 2. 哈希密码
let password_hash = self.hash_password(&request.password)?;
// 3. 生成用户 ID
let user_id = self.generate_unique_user_id().await?;
// 4. 插入数据库并获取包含真实 created_at 的用户对象
let user = self.user_repo.insert(user_id.clone(), request.email, password_hash).await?;
// 5. 生成 token
let (access_token, refresh_token) = TokenService::generate_token_pair(
&user_id,
self.auth_config.access_token_expiration_minutes,
self.auth_config.refresh_token_expiration_days,
&self.auth_config.jwt_secret,
)?;
// 6. 保存 refresh_token
self.save_refresh_token(&user_id, &refresh_token, self.auth_config.refresh_token_expiration_days as i64).await?;
Ok((user, access_token, refresh_token))
}
/// 登录
pub async fn login(
&self,
request: LoginRequest,
) -> Result<(users::Model, String, String)> {
// 1. 查询用户
let user = self.user_repo.find_by_email(&request.email).await?
.ok_or_else(|| anyhow::anyhow!("邮箱或密码错误"))?;
// 2. 验证密码
let password_hash = self.user_repo.get_password_hash(&request.email).await?
.ok_or_else(|| anyhow::anyhow!("邮箱或密码错误"))?;
let parsed_hash = PasswordHash::new(&password_hash)
.map_err(|e| anyhow::anyhow!("解析密码哈希失败: {}", e))?;
let argon2 = Argon2::default();
argon2
.verify_password(request.password.as_bytes(), &parsed_hash)
.map_err(|_| anyhow::anyhow!("邮箱或密码错误"))?;
// 3. 生成 token
let (access_token, refresh_token) = TokenService::generate_token_pair(
&user.id,
self.auth_config.access_token_expiration_minutes,
self.auth_config.refresh_token_expiration_days,
&self.auth_config.jwt_secret,
)?;
// 4. 保存 refresh_token
self.save_refresh_token(&user.id, &refresh_token, self.auth_config.refresh_token_expiration_days as i64).await?;
Ok((user, access_token, refresh_token))
}
/// 使用 refresh_token 刷新 access_token
pub async fn refresh_access_token(
&self,
refresh_token: &str,
) -> Result<(String, String)> {
// 1. 从 refresh_token 中解码出 user_id
let user_id = TokenService::decode_user_id(refresh_token, &self.auth_config.jwt_secret)?;
// 2. 从 Redis 获取存储的 token 并删除
let stored_token = self.get_and_delete_refresh_token(&user_id).await?;
// 3. 验证 token 是否匹配
if stored_token != refresh_token {
return Err(anyhow::anyhow!("刷新令牌无效"));
}
// 4. 生成新的 token 对
let (new_access_token, new_refresh_token) = TokenService::generate_token_pair(
&user_id,
self.auth_config.access_token_expiration_minutes,
self.auth_config.refresh_token_expiration_days,
&self.auth_config.jwt_secret,
)?;
// 5. 保存新的 refresh_token
self.save_refresh_token(&user_id, &new_refresh_token, self.auth_config.refresh_token_expiration_days as i64).await?;
Ok((new_access_token, new_refresh_token))
}
/// 删除用户
pub async fn delete_user(&self, request: DeleteUserRequest) -> Result<()> {
let password_hash = self.user_repo.get_password_hash(&request.user_id).await?
.ok_or_else(|| anyhow::anyhow!("用户不存在"))?;
let parsed_hash = PasswordHash::new(&password_hash)
.map_err(|e| anyhow::anyhow!("解析密码哈希失败: {}", e))?;
let argon2 = Argon2::default();
argon2
.verify_password(request.password.as_bytes(), &parsed_hash)
.map_err(|_| anyhow::anyhow!("密码错误"))?;
self.user_repo.delete_by_id(&request.user_id).await?;
Ok(())
}
}

1
src/services/mod.rs Normal file
View File

@@ -0,0 +1 @@
pub mod auth_service;

101
src/utils/jwt.rs Normal file
View File

@@ -0,0 +1,101 @@
use anyhow::Result;
use chrono::{Duration, Utc};
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
use serde::{Deserialize, Serialize};
/// JWT 工具类,负责生成和验证 JWT token
pub struct TokenService;
impl TokenService {
/// 生成 JWT access token
pub fn generate_access_token(
user_id: &str,
expiration_minutes: u64,
jwt_secret: &str,
) -> Result<String> {
let expiration = Utc::now()
.checked_add_signed(Duration::minutes(expiration_minutes as i64))
.expect("invalid expiration timestamp")
.timestamp() as usize;
let claims = Claims {
sub: user_id.to_string(),
exp: expiration,
token_type: TokenType::Access,
};
let token = encode(
&Header::default(),
&claims,
&EncodingKey::from_secret(jwt_secret.as_ref()),
)?;
Ok(token)
}
/// 生成 refresh token
pub fn generate_refresh_token(
user_id: &str,
expiration_days: i64,
jwt_secret: &str,
) -> Result<String> {
let expiration = Utc::now()
.checked_add_signed(Duration::days(expiration_days))
.expect("invalid expiration timestamp")
.timestamp() as usize;
let claims = Claims {
sub: user_id.to_string(),
exp: expiration,
token_type: TokenType::Refresh,
};
let token = encode(
&Header::default(),
&claims,
&EncodingKey::from_secret(jwt_secret.as_ref()),
)?;
Ok(token)
}
/// 生成 access token 和 refresh token
pub fn generate_token_pair(
user_id: &str,
access_token_expiration_minutes: u64,
refresh_token_expiration_days: i64,
jwt_secret: &str,
) -> Result<(String, String)> {
let access_token =
Self::generate_access_token(user_id, access_token_expiration_minutes, jwt_secret)?;
let refresh_token =
Self::generate_refresh_token(user_id, refresh_token_expiration_days, jwt_secret)?;
Ok((access_token, refresh_token))
}
/// 从 token 中解码出 user_id
pub fn decode_user_id(token: &str, jwt_secret: &str) -> Result<String> {
let token_data = decode::<Claims>(
token,
&DecodingKey::from_secret(jwt_secret.as_ref()),
&Validation::default(),
)
.map_err(|e| anyhow::anyhow!("Token 解码失败: {}", e))?;
Ok(token_data.claims.sub)
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Claims {
pub sub: String, // user_id
pub exp: usize, // 过期时间
pub token_type: TokenType,
}
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum TokenType {
Access,
Refresh,
}

1
src/utils/mod.rs Normal file
View File

@@ -0,0 +1 @@
pub mod jwt;