first commit
This commit is contained in:
234
src/cli.rs
Normal file
234
src/cli.rs
Normal file
@@ -0,0 +1,234 @@
|
||||
/// 命令行参数和配置管理
|
||||
/// 支持优先级:CLI 参数 > 环境变量 > 配置文件 > 默认值
|
||||
use clap::{Parser, ValueEnum};
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// 运行环境(强类型)
|
||||
#[derive(ValueEnum, Clone, Debug)]
|
||||
pub enum Environment {
|
||||
/// 开发环境
|
||||
Development,
|
||||
/// 生产环境
|
||||
Production,
|
||||
}
|
||||
|
||||
impl Environment {
|
||||
/// 转换为小写字符串
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
Environment::Development => "development",
|
||||
Environment::Production => "production",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 命令行参数
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(name = "web-rust-template")]
|
||||
#[command(about = "Web Server Template", long_about = None)]
|
||||
#[command(author = "Your Name <your.email@example.com>")]
|
||||
#[command(version = "0.1.0")]
|
||||
#[command(propagate_version = true)]
|
||||
pub struct CliArgs {
|
||||
/// 指定配置文件路径
|
||||
///
|
||||
/// 支持相对路径和绝对路径
|
||||
/// 例如:-c config/production.toml
|
||||
#[arg(short, long, value_name = "FILE")]
|
||||
pub config: Option<PathBuf>,
|
||||
|
||||
/// 指定运行环境
|
||||
///
|
||||
/// 自动加载对应环境的配置文件(如 config/development.toml)
|
||||
/// 可通过环境变量 ENV 设置
|
||||
#[arg(
|
||||
short = 'e',
|
||||
long,
|
||||
value_enum,
|
||||
env = "ENV",
|
||||
default_value = "development"
|
||||
)]
|
||||
pub env: Environment,
|
||||
|
||||
/// 指定服务器监听端口
|
||||
///
|
||||
/// 覆盖配置文件中的 port 设置
|
||||
/// 可通过环境变量 SERVER_PORT 设置
|
||||
#[arg(short, long, global = true, env = "SERVER_PORT")]
|
||||
pub port: Option<u16>,
|
||||
|
||||
/// 指定服务器监听地址
|
||||
///
|
||||
/// 覆盖配置文件中的 host 设置
|
||||
/// 可通过环境变量 SERVER_HOST 设置
|
||||
#[arg(long, global = true, env = "SERVER_HOST")]
|
||||
pub host: Option<String>,
|
||||
|
||||
/// 启用调试日志
|
||||
///
|
||||
/// 输出详细的日志信息,包括 SQL 查询
|
||||
/// 可通过环境变量 DEBUG 设置
|
||||
/// 注意:与 -v 冲突,推荐使用 -v/-vv/-vvv
|
||||
#[arg(long, global = true, env = "DEBUG", conflicts_with = "verbose")]
|
||||
pub debug: bool,
|
||||
|
||||
/// 工作目录
|
||||
///
|
||||
/// 指定配置文件和数据库的基准目录
|
||||
#[arg(short, long, global = true)]
|
||||
pub work_dir: Option<PathBuf>,
|
||||
|
||||
/// 显示详细日志(多级 verbose)
|
||||
///
|
||||
/// -v : info 级别日志
|
||||
/// -vv : debug 级别日志(等同于 --debug)
|
||||
/// -vvv : trace 级别日志(最详细)
|
||||
#[arg(short, long, global = true, action = clap::ArgAction::Count)]
|
||||
pub verbose: u8,
|
||||
}
|
||||
|
||||
impl CliArgs {
|
||||
/// 获取是否启用调试
|
||||
pub fn is_debug_enabled(&self) -> bool {
|
||||
self.debug || self.verbose >= 2
|
||||
}
|
||||
|
||||
/// 获取日志级别
|
||||
pub fn get_log_level(&self) -> &'static str {
|
||||
if self.debug {
|
||||
return "debug";
|
||||
}
|
||||
match self.verbose {
|
||||
0 => "info",
|
||||
1 => "debug",
|
||||
_ => "trace",
|
||||
}
|
||||
}
|
||||
|
||||
/// 获取环境变量的日志过滤器(工程化版本)
|
||||
pub fn get_log_filter(&self) -> String {
|
||||
let level = self.get_log_level();
|
||||
match level {
|
||||
"trace" => "web_rust_template=trace,tower_http=trace,axum=trace,sqlx=debug".into(),
|
||||
"debug" => "web_rust_template=debug,tower_http=debug,axum=debug,sqlx=debug".into(),
|
||||
_ => "web_rust_template=info,tower_http=info,axum=info".into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// 获取配置文件路径
|
||||
///
|
||||
/// 优先级:
|
||||
/// 1. CLI 参数 --config
|
||||
/// 2. 环境变量 CONFIG
|
||||
/// 3. {work_dir}/config/{env}.toml
|
||||
/// 4. ./config/{env}.toml
|
||||
/// 5. ./config/default.toml
|
||||
///
|
||||
/// 如果找不到配置文件,返回 None(允许仅使用环境变量运行)
|
||||
pub fn resolve_config_path(&self) -> Option<PathBuf> {
|
||||
use std::env;
|
||||
|
||||
// 1. CLI 参数优先
|
||||
if let Some(ref config) = self.config {
|
||||
if config.exists() {
|
||||
return Some(config.clone());
|
||||
}
|
||||
eprintln!("⚠ 警告:指定的配置文件不存在: {}", config.display());
|
||||
eprintln!(" 将仅使用环境变量运行");
|
||||
return None;
|
||||
}
|
||||
|
||||
// 2. 环境变量
|
||||
if let Ok(config_path) = env::var("CONFIG") {
|
||||
let config = PathBuf::from(&config_path);
|
||||
if config.exists() {
|
||||
return Some(config);
|
||||
}
|
||||
eprintln!("⚠ 警告:环境变量 CONFIG 指定的配置文件不存在: {}", config_path);
|
||||
eprintln!(" 将仅使用环境变量运行");
|
||||
return None;
|
||||
}
|
||||
|
||||
// 3-6. 查找配置文件
|
||||
let work_dir = self
|
||||
.work_dir
|
||||
.clone()
|
||||
.or_else(|| env::current_dir().ok())
|
||||
.unwrap_or_else(|| PathBuf::from("."));
|
||||
|
||||
let env_name = self.env.as_str();
|
||||
|
||||
// 按优先级尝试的位置
|
||||
let candidates = [
|
||||
// 工作目录下的环境配置
|
||||
work_dir.join("config").join(format!("{}.toml", env_name)),
|
||||
// 当前目录的环境配置
|
||||
PathBuf::from(format!("config/{}.toml", env_name)),
|
||||
// 工作目录下的默认配置
|
||||
work_dir.join("config").join("default.toml"),
|
||||
// 当前目录的默认配置
|
||||
PathBuf::from("config/default.toml"),
|
||||
];
|
||||
|
||||
for candidate in &candidates {
|
||||
if candidate.exists() {
|
||||
// 使用 println! 而非 tracing::info!
|
||||
println!("✓ Found configuration file: {}", candidate.display());
|
||||
return Some(candidate.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// 所有候选路径都找不到配置文件,返回 None
|
||||
eprintln!("ℹ 未找到配置文件,将仅使用环境变量和默认值");
|
||||
None
|
||||
}
|
||||
|
||||
/// 获取覆盖配置
|
||||
///
|
||||
/// CLI 参数可以覆盖配置文件中的值(仅 Web 服务器参数)
|
||||
pub fn get_overrides(&self) -> ConfigOverrides {
|
||||
ConfigOverrides {
|
||||
host: self.host.clone(),
|
||||
port: self.port,
|
||||
}
|
||||
}
|
||||
|
||||
/// 显示启动信息(工程化版本:打印实际解析的配置)
|
||||
///
|
||||
/// 使用 println! 而非 tracing::info!,因为 logger 可能尚未初始化
|
||||
pub fn print_startup_info(&self) {
|
||||
let separator = "=".repeat(60);
|
||||
println!("{}", separator);
|
||||
println!("Web Rust Template Server v0.1.0");
|
||||
println!("Environment: {}", self.env.as_str());
|
||||
|
||||
// 打印实际解析的配置路径(而非 CLI 参数)
|
||||
if let Some(config_path) = self.resolve_config_path() {
|
||||
println!("Config file: {}", config_path.display());
|
||||
} else {
|
||||
println!("Config file: None (using environment variables)");
|
||||
}
|
||||
|
||||
if let Some(ref work_dir) = self.work_dir {
|
||||
println!("Work directory: {}", work_dir.display());
|
||||
}
|
||||
|
||||
// 打印实际的日志级别
|
||||
println!("Log level: {}", self.get_log_level());
|
||||
|
||||
if self.is_debug_enabled() {
|
||||
println!("Debug mode: ENABLED");
|
||||
}
|
||||
println!("{}", separator);
|
||||
}
|
||||
}
|
||||
|
||||
/// CLI 参数覆盖的配置(仅 Web 服务器参数)
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ConfigOverrides {
|
||||
/// Web 服务器主机覆盖
|
||||
pub host: Option<String>,
|
||||
|
||||
/// Web 服务器端口覆盖
|
||||
pub port: Option<u16>,
|
||||
}
|
||||
135
src/config/app.rs
Normal file
135
src/config/app.rs
Normal file
@@ -0,0 +1,135 @@
|
||||
use super::{auth::AuthConfig, database::DatabaseConfig, redis::RedisConfig, server::ServerConfig};
|
||||
use config::{Config, ConfigError, Environment, File};
|
||||
use serde::Deserialize;
|
||||
use std::path::PathBuf;
|
||||
|
||||
// 导入 redis 默认值函数(使用完整路径)
|
||||
use crate::config::redis::default_redis_host;
|
||||
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
pub struct AppConfig {
|
||||
pub server: ServerConfig,
|
||||
pub database: DatabaseConfig,
|
||||
pub auth: AuthConfig,
|
||||
pub redis: RedisConfig,
|
||||
}
|
||||
|
||||
impl AppConfig {
|
||||
/// 加载配置(支持 CLI 覆盖)
|
||||
///
|
||||
/// 如果 config_path 为 None,则仅使用环境变量和默认值
|
||||
pub fn load_with_overrides(
|
||||
cli_config_path: Option<std::path::PathBuf>,
|
||||
overrides: crate::cli::ConfigOverrides,
|
||||
_environment: &str,
|
||||
) -> Result<Self, ConfigError> {
|
||||
// 使用 ConfigBuilder 设置配置
|
||||
let mut builder = Config::builder();
|
||||
|
||||
// 如果提供了配置文件,先加载它
|
||||
if let Some(config_path) = cli_config_path {
|
||||
if !config_path.exists() {
|
||||
tracing::error!("Configuration file not found: {}", config_path.display());
|
||||
return Err(ConfigError::NotFound(
|
||||
config_path.to_string_lossy().to_string(),
|
||||
));
|
||||
}
|
||||
tracing::info!("Loading configuration from: {}", config_path.display());
|
||||
builder = builder.add_source(File::from(config_path));
|
||||
} else {
|
||||
tracing::info!("No configuration file found, using environment variables and defaults");
|
||||
tracing::warn!("⚠️ 没有找到配置文件,将使用 SQLite 作为默认数据库");
|
||||
tracing::warn!(" 默认数据库路径: db.sqlite3");
|
||||
tracing::warn!(" 如需使用其他数据库,请创建配置文件或设置环境变量");
|
||||
|
||||
// 直接使用 set_default 设置默认值
|
||||
// 注意:这些值会被环境变量覆盖
|
||||
builder = builder.set_default("server.host", default_server_host())?;
|
||||
builder = builder.set_default("server.port", default_server_port())?;
|
||||
|
||||
// 设置 database 默认值(使用 SQLite 作为默认数据库)
|
||||
builder = builder.set_default("database.database_type", "sqlite")?;
|
||||
builder = builder.set_default("database.path", "db.sqlite3")?;
|
||||
builder = builder.set_default("database.max_connections", 10)?;
|
||||
|
||||
// 设置 auth 默认值
|
||||
builder = builder.set_default("auth.jwt_secret", default_jwt_secret())?;
|
||||
builder = builder.set_default("auth.access_token_expiration_minutes", 15)?;
|
||||
builder = builder.set_default("auth.refresh_token_expiration_days", 7)?;
|
||||
|
||||
// 设置 redis 默认值
|
||||
builder = builder.set_default("redis.host", default_redis_host())?;
|
||||
builder = builder.set_default("redis.port", 6379)?;
|
||||
builder = builder.set_default("redis.db", 0)?;
|
||||
}
|
||||
|
||||
// 添加环境变量源(会覆盖配置文件的值)
|
||||
builder = builder.add_source(Environment::default().separator("_"));
|
||||
|
||||
// 应用 CLI 覆盖(仅 Web 服务器参数)
|
||||
if let Some(host) = overrides.host {
|
||||
builder = builder.set_override("server.host", host)?;
|
||||
}
|
||||
if let Some(port) = overrides.port {
|
||||
builder = builder.set_override("server.port", port)?;
|
||||
}
|
||||
|
||||
let settings = builder.build()?;
|
||||
let config: AppConfig = settings.try_deserialize()?;
|
||||
|
||||
// 安全警告:检查是否使用了默认的 JWT 密钥
|
||||
if config.auth.jwt_secret == "change-this-to-a-strong-secret-key-in-production" {
|
||||
tracing::warn!("⚠️ 警告:正在使用不安全的默认 JWT 密钥!");
|
||||
tracing::warn!(" 请通过环境变量 AUTH_JWT_SECRET 或配置文件设置强密钥");
|
||||
tracing::warn!(" 示例:AUTH_JWT_SECRET=your-secure-random-string-here");
|
||||
}
|
||||
|
||||
// 验证数据库配置
|
||||
if let Err(e) = config.database.validate() {
|
||||
tracing::error!("数据库配置无效: {}", e);
|
||||
return Err(ConfigError::Message(format!("数据库配置无效: {}", e)));
|
||||
}
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
/// 从指定路径加载配置
|
||||
pub fn load_from_path(path: &str) -> Result<Self, ConfigError> {
|
||||
tracing::info!("Loading configuration from: {}", path);
|
||||
|
||||
let settings = Config::builder()
|
||||
.add_source(File::from(PathBuf::from(path)))
|
||||
.add_source(Environment::default().separator("_"))
|
||||
.build()?;
|
||||
|
||||
let config: AppConfig = settings.try_deserialize()?;
|
||||
|
||||
// 安全警告:检查是否使用了默认的 JWT 密钥
|
||||
if config.auth.jwt_secret == "change-this-to-a-strong-secret-key-in-production" {
|
||||
tracing::warn!("⚠️ 警告:正在使用不安全的默认 JWT 密钥!");
|
||||
tracing::warn!(" 请通过环境变量 AUTH_JWT_SECRET 或配置文件设置强密钥");
|
||||
tracing::warn!(" 示例:AUTH_JWT_SECRET=your-secure-random-string-here");
|
||||
}
|
||||
|
||||
// 验证数据库配置
|
||||
if let Err(e) = config.database.validate() {
|
||||
tracing::error!("数据库配置无效: {}", e);
|
||||
return Err(ConfigError::Message(format!("数据库配置无效: {}", e)));
|
||||
}
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
}
|
||||
|
||||
// 默认值函数(复用)
|
||||
fn default_server_host() -> String {
|
||||
"127.0.0.1".to_string()
|
||||
}
|
||||
|
||||
fn default_server_port() -> u16 {
|
||||
3000
|
||||
}
|
||||
|
||||
fn default_jwt_secret() -> String {
|
||||
"change-this-to-a-strong-secret-key-in-production".to_string()
|
||||
}
|
||||
25
src/config/auth.rs
Normal file
25
src/config/auth.rs
Normal file
@@ -0,0 +1,25 @@
|
||||
use serde::Deserialize;
|
||||
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
pub struct AuthConfig {
|
||||
#[serde(default = "default_jwt_secret")]
|
||||
pub jwt_secret: String,
|
||||
#[serde(default = "default_access_token_expiration_minutes")]
|
||||
pub access_token_expiration_minutes: u64,
|
||||
#[serde(default = "default_refresh_token_expiration_days")]
|
||||
pub refresh_token_expiration_days: i64,
|
||||
}
|
||||
|
||||
fn default_jwt_secret() -> String {
|
||||
// ⚠️ 警告:这是一个不安全的默认值,仅用于开发测试
|
||||
// 生产环境必须通过环境变量或配置文件设置强密钥
|
||||
"change-this-to-a-strong-secret-key-in-production".to_string()
|
||||
}
|
||||
|
||||
fn default_access_token_expiration_minutes() -> u64 {
|
||||
15
|
||||
}
|
||||
|
||||
fn default_refresh_token_expiration_days() -> i64 {
|
||||
7
|
||||
}
|
||||
168
src/config/database.rs
Normal file
168
src/config/database.rs
Normal file
@@ -0,0 +1,168 @@
|
||||
use serde::Deserialize;
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// 数据库类型
|
||||
#[derive(Debug, Deserialize, Clone, PartialEq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum DatabaseType {
|
||||
MySQL,
|
||||
SQLite,
|
||||
PostgreSQL,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
pub struct DatabaseConfig {
|
||||
/// 数据库类型
|
||||
#[serde(default = "default_database_type")]
|
||||
pub database_type: DatabaseType,
|
||||
|
||||
/// 网络数据库配置(MySQL/PostgreSQL)
|
||||
pub host: Option<String>,
|
||||
#[serde(default)]
|
||||
pub port: Option<u16>,
|
||||
pub user: Option<String>,
|
||||
pub password: Option<String>,
|
||||
pub database: Option<String>,
|
||||
|
||||
/// SQLite 文件路径
|
||||
pub path: Option<PathBuf>,
|
||||
|
||||
/// 连接池最大连接数
|
||||
#[serde(default = "default_max_connections")]
|
||||
pub max_connections: u32,
|
||||
}
|
||||
|
||||
impl DatabaseConfig {
|
||||
/// 获取端口号(根据数据库类型返回默认值)
|
||||
pub fn get_port(&self) -> u16 {
|
||||
self.port.unwrap_or_else(|| match self.database_type {
|
||||
DatabaseType::MySQL => 3306,
|
||||
DatabaseType::PostgreSQL => 5432,
|
||||
DatabaseType::SQLite => 0,
|
||||
})
|
||||
}
|
||||
|
||||
/// 构建数据库连接 URL
|
||||
///
|
||||
/// # 错误
|
||||
///
|
||||
/// 当缺少必需的配置字段时返回错误
|
||||
pub fn build_url(&self) -> Result<String, String> {
|
||||
match self.database_type {
|
||||
DatabaseType::MySQL => {
|
||||
let host = self
|
||||
.host
|
||||
.as_ref()
|
||||
.ok_or_else(|| "MySQL 需要配置 database.host".to_string())?;
|
||||
let user = self
|
||||
.user
|
||||
.as_ref()
|
||||
.ok_or_else(|| "MySQL 需要配置 database.user".to_string())?;
|
||||
let password = self
|
||||
.password
|
||||
.as_ref()
|
||||
.ok_or_else(|| "MySQL 需要配置 database.password".to_string())?;
|
||||
let database = self
|
||||
.database
|
||||
.as_ref()
|
||||
.ok_or_else(|| "MySQL 需要配置 database.database".to_string())?;
|
||||
|
||||
Ok(format!(
|
||||
"mysql://{}:{}@{}:{}/{}",
|
||||
user,
|
||||
password,
|
||||
host,
|
||||
self.get_port(),
|
||||
database
|
||||
))
|
||||
}
|
||||
DatabaseType::SQLite => {
|
||||
let path = self
|
||||
.path
|
||||
.as_ref()
|
||||
.ok_or_else(|| "SQLite 需要配置 database.path".to_string())?;
|
||||
|
||||
// SQLite URL 格式
|
||||
// 相对路径:sqlite:./db.sqlite3
|
||||
// 绝对路径:sqlite:C:/path/to/db.sqlite3
|
||||
let path_str = path.to_string_lossy().replace('\\', "/");
|
||||
Ok(format!("sqlite:{}", path_str))
|
||||
}
|
||||
DatabaseType::PostgreSQL => {
|
||||
let host = self
|
||||
.host
|
||||
.as_ref()
|
||||
.ok_or_else(|| "PostgreSQL 需要配置 database.host".to_string())?;
|
||||
let user = self
|
||||
.user
|
||||
.as_ref()
|
||||
.ok_or_else(|| "PostgreSQL 需要配置 database.user".to_string())?;
|
||||
let password = self
|
||||
.password
|
||||
.as_ref()
|
||||
.ok_or_else(|| "PostgreSQL 需要配置 database.password".to_string())?;
|
||||
let database = self
|
||||
.database
|
||||
.as_ref()
|
||||
.ok_or_else(|| "PostgreSQL 需要配置 database.database".to_string())?;
|
||||
|
||||
Ok(format!(
|
||||
"postgresql://{}:{}@{}:{}/{}",
|
||||
user,
|
||||
password,
|
||||
host,
|
||||
self.get_port(),
|
||||
database
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 验证配置是否完整
|
||||
pub fn validate(&self) -> Result<(), String> {
|
||||
match self.database_type {
|
||||
DatabaseType::MySQL => {
|
||||
if self.host.is_none() {
|
||||
return Err("MySQL 需要配置 database.host".to_string());
|
||||
}
|
||||
if self.user.is_none() {
|
||||
return Err("MySQL 需要配置 database.user".to_string());
|
||||
}
|
||||
if self.password.is_none() {
|
||||
return Err("MySQL 需要配置 database.password".to_string());
|
||||
}
|
||||
if self.database.is_none() {
|
||||
return Err("MySQL 需要配置 database.database".to_string());
|
||||
}
|
||||
}
|
||||
DatabaseType::SQLite => {
|
||||
if self.path.is_none() {
|
||||
return Err("SQLite 需要配置 database.path".to_string());
|
||||
}
|
||||
}
|
||||
DatabaseType::PostgreSQL => {
|
||||
if self.host.is_none() {
|
||||
return Err("PostgreSQL 需要配置 database.host".to_string());
|
||||
}
|
||||
if self.user.is_none() {
|
||||
return Err("PostgreSQL 需要配置 database.user".to_string());
|
||||
}
|
||||
if self.password.is_none() {
|
||||
return Err("PostgreSQL 需要配置 database.password".to_string());
|
||||
}
|
||||
if self.database.is_none() {
|
||||
return Err("PostgreSQL 需要配置 database.database".to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn default_database_type() -> DatabaseType {
|
||||
DatabaseType::MySQL
|
||||
}
|
||||
|
||||
fn default_max_connections() -> u32 {
|
||||
10
|
||||
}
|
||||
5
src/config/mod.rs
Normal file
5
src/config/mod.rs
Normal file
@@ -0,0 +1,5 @@
|
||||
pub mod app;
|
||||
pub mod auth;
|
||||
pub mod database;
|
||||
pub mod redis;
|
||||
pub mod server;
|
||||
52
src/config/redis.rs
Normal file
52
src/config/redis.rs
Normal file
@@ -0,0 +1,52 @@
|
||||
use serde::Deserialize;
|
||||
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
pub struct RedisConfig {
|
||||
/// Redis 主机地址
|
||||
#[serde(default = "default_redis_host")]
|
||||
pub host: String,
|
||||
|
||||
/// Redis 端口
|
||||
#[serde(default = "default_redis_port")]
|
||||
pub port: u16,
|
||||
|
||||
/// Redis 密码(可选)
|
||||
#[serde(default)]
|
||||
pub password: Option<String>,
|
||||
|
||||
/// Redis 数据库编号(可选)
|
||||
#[serde(default = "default_redis_db")]
|
||||
pub db: u8,
|
||||
}
|
||||
|
||||
pub fn default_redis_host() -> String {
|
||||
"localhost".to_string()
|
||||
}
|
||||
|
||||
pub fn default_redis_port() -> u16 {
|
||||
6379
|
||||
}
|
||||
|
||||
pub fn default_redis_db() -> u8 {
|
||||
0
|
||||
}
|
||||
|
||||
impl RedisConfig {
|
||||
/// 构建 Redis 连接 URL
|
||||
pub fn build_url(&self) -> String {
|
||||
// 判断密码是否存在且非空
|
||||
match &self.password {
|
||||
Some(password) if !password.is_empty() => {
|
||||
// 有密码:redis://:password@host:port/db
|
||||
format!(
|
||||
"redis://:{}@{}:{}/{}",
|
||||
password, self.host, self.port, self.db
|
||||
)
|
||||
}
|
||||
_ => {
|
||||
// 无密码(None 或空字符串):redis://host:port/db
|
||||
format!("redis://{}:{}/{}", self.host, self.port, self.db)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
17
src/config/server.rs
Normal file
17
src/config/server.rs
Normal file
@@ -0,0 +1,17 @@
|
||||
use serde::Deserialize;
|
||||
|
||||
#[derive(Debug, Deserialize, Clone)]
|
||||
pub struct ServerConfig {
|
||||
#[serde(default = "default_server_host")]
|
||||
pub host: String,
|
||||
#[serde(default = "default_server_port")]
|
||||
pub port: u16,
|
||||
}
|
||||
|
||||
fn default_server_host() -> String {
|
||||
"127.0.0.1".to_string()
|
||||
}
|
||||
|
||||
fn default_server_port() -> u16 {
|
||||
3000
|
||||
}
|
||||
329
src/db.rs
Normal file
329
src/db.rs
Normal file
@@ -0,0 +1,329 @@
|
||||
use crate::config::database::{DatabaseConfig, DatabaseType};
|
||||
use sea_orm::{
|
||||
ConnectionTrait, Database, DatabaseConnection, DbBackend, EntityName, EntityTrait, ConnectOptions, Schema,
|
||||
Statement,
|
||||
};
|
||||
use std::time::Duration;
|
||||
|
||||
/// 数据库连接池(SeaORM 统一接口)
|
||||
pub type DbPool = DatabaseConnection;
|
||||
|
||||
/// 创建数据库连接池
|
||||
pub async fn create_pool(config: &DatabaseConfig) -> anyhow::Result<DbPool> {
|
||||
let url = config
|
||||
.build_url()
|
||||
.map_err(|e| anyhow::anyhow!("数据库配置错误: {}", e))?;
|
||||
|
||||
tracing::debug!("数据库连接 URL: {}", url);
|
||||
|
||||
let mut opt = ConnectOptions::new(&url);
|
||||
opt.max_connections(config.max_connections)
|
||||
.min_connections(1)
|
||||
.connect_timeout(Duration::from_secs(8))
|
||||
.idle_timeout(Duration::from_secs(8))
|
||||
.max_lifetime(Duration::from_secs(7200))
|
||||
.sqlx_logging(true);
|
||||
|
||||
let pool = Database::connect(opt)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("数据库连接失败: {}", e))?;
|
||||
|
||||
tracing::info!("已连接到数据库: {}", sanitize_url(&url));
|
||||
|
||||
Ok(pool)
|
||||
}
|
||||
|
||||
/// 隐藏 URL 中的敏感信息(用于日志输出)
|
||||
fn sanitize_url(url: &str) -> String {
|
||||
// 隐藏密码:mysql://user:password@host -> mysql://user:***@host
|
||||
if let Some(at_pos) = url.find('@') {
|
||||
if let Some(scheme_end) = url.find("://") {
|
||||
if scheme_end < at_pos {
|
||||
return format!("{}***@{}", &url[..scheme_end + 3], &url[at_pos + 1..]);
|
||||
}
|
||||
}
|
||||
}
|
||||
url.to_string()
|
||||
}
|
||||
|
||||
/// 健康检查(保持向后兼容)
|
||||
pub async fn health_check(pool: &DbPool) -> anyhow::Result<()> {
|
||||
// 使用官方推荐的 ping 方法
|
||||
pool.ping()
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("数据库健康检查失败: {}", e))
|
||||
}
|
||||
|
||||
/// 初始化数据库和表结构
|
||||
/// 每次启动时检查数据库和表是否存在,不存在则创建
|
||||
pub async fn init_database(config: &DatabaseConfig) -> anyhow::Result<DatabaseConnection> {
|
||||
match config.database_type {
|
||||
DatabaseType::MySQL => {
|
||||
init_mysql_database(config).await?;
|
||||
}
|
||||
DatabaseType::PostgreSQL => {
|
||||
init_postgresql_database(config).await?;
|
||||
}
|
||||
DatabaseType::SQLite => {
|
||||
// 确保 SQLite 数据库文件的目录存在
|
||||
init_sqlite_database(config).await?;
|
||||
}
|
||||
}
|
||||
|
||||
// 连接到数据库
|
||||
let pool = create_pool(config).await?;
|
||||
|
||||
// 创建表
|
||||
create_tables(&pool).await?;
|
||||
|
||||
Ok(pool)
|
||||
}
|
||||
|
||||
/// 获取端口号(根据数据库类型返回默认值)
|
||||
fn get_database_port(config: &DatabaseConfig) -> u16 {
|
||||
config.port.unwrap_or_else(|| match config.database_type {
|
||||
DatabaseType::MySQL => 3306,
|
||||
DatabaseType::PostgreSQL => 5432,
|
||||
DatabaseType::SQLite => 0,
|
||||
})
|
||||
}
|
||||
|
||||
/// 为 MySQL 创建数据库(如果不存在)
|
||||
async fn init_mysql_database(config: &DatabaseConfig) -> anyhow::Result<()> {
|
||||
let database_name = config
|
||||
.database
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("MySQL 需要配置 database.database"))?;
|
||||
|
||||
let host = config
|
||||
.host
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("MySQL 需要配置 database.host"))?;
|
||||
let user = config
|
||||
.user
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("MySQL 需要配置 database.user"))?;
|
||||
let password = config
|
||||
.password
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("MySQL 需要配置 database.password"))?;
|
||||
|
||||
// 连接到 MySQL 服务器(不指定数据库)
|
||||
let url = format!(
|
||||
"mysql://{}:{}@{}:{}",
|
||||
user,
|
||||
password,
|
||||
host,
|
||||
get_database_port(config)
|
||||
);
|
||||
|
||||
let mut opt = ConnectOptions::new(&url);
|
||||
opt.max_connections(1)
|
||||
.connect_timeout(Duration::from_secs(8))
|
||||
.sqlx_logging(true);
|
||||
|
||||
let conn = Database::connect(opt)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("连接 MySQL 服务器失败: {}", e))?;
|
||||
|
||||
// 检查数据库是否存在,不存在则创建
|
||||
let query = format!(
|
||||
"CREATE DATABASE IF NOT EXISTS `{}` CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci",
|
||||
database_name
|
||||
);
|
||||
|
||||
conn.execute(Statement::from_string(
|
||||
sea_orm::DatabaseBackend::MySql,
|
||||
query,
|
||||
))
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("创建 MySQL 数据库失败: {}", e))?;
|
||||
|
||||
tracing::info!("✅ MySQL 数据库 '{}' 检查完成", database_name);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 为 PostgreSQL 创建数据库(如果不存在)
|
||||
async fn init_postgresql_database(config: &DatabaseConfig) -> anyhow::Result<()> {
|
||||
let database_name = config
|
||||
.database
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("PostgreSQL 需要配置 database.database"))?;
|
||||
|
||||
let host = config
|
||||
.host
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("PostgreSQL 需要配置 database.host"))?;
|
||||
let user = config
|
||||
.user
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("PostgreSQL 需要配置 database.user"))?;
|
||||
let password = config
|
||||
.password
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("PostgreSQL 需要配置 database.password"))?;
|
||||
|
||||
// 连接到 PostgreSQL 默认数据库(postgres)
|
||||
let url = format!(
|
||||
"postgresql://{}:{}@{}:{}/postgres",
|
||||
user,
|
||||
password,
|
||||
host,
|
||||
get_database_port(config)
|
||||
);
|
||||
|
||||
let mut opt = ConnectOptions::new(&url);
|
||||
opt.max_connections(1)
|
||||
.connect_timeout(Duration::from_secs(8))
|
||||
.sqlx_logging(true);
|
||||
|
||||
let conn = Database::connect(opt)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("连接 PostgreSQL 服务器失败: {}", e))?;
|
||||
|
||||
// 检查数据库是否存在,不存在则创建
|
||||
// PostgreSQL 不支持 CREATE DATABASE IF NOT EXISTS,需要先查询
|
||||
let check_query = format!(
|
||||
"SELECT 1 FROM pg_database WHERE datname='{}'",
|
||||
database_name
|
||||
);
|
||||
|
||||
let result = conn
|
||||
.execute(Statement::from_string(
|
||||
sea_orm::DatabaseBackend::Postgres,
|
||||
check_query,
|
||||
))
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(_) => {
|
||||
tracing::info!("PostgreSQL 数据库 '{}' 已存在", database_name);
|
||||
}
|
||||
Err(_) => {
|
||||
// 数据库不存在,创建它
|
||||
let create_query = format!(
|
||||
"CREATE DATABASE {} WITH ENCODING 'UTF8' LC_COLLATE='en_US.UTF-8' LC_CTYPE='en_US.UTF-8'",
|
||||
database_name
|
||||
);
|
||||
|
||||
conn.execute(Statement::from_string(
|
||||
sea_orm::DatabaseBackend::Postgres,
|
||||
create_query,
|
||||
))
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("创建 PostgreSQL 数据库失败: {}", e))?;
|
||||
|
||||
tracing::info!("✅ PostgreSQL 数据库 '{}' 创建成功", database_name);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 为 SQLite 确保数据库文件目录存在
|
||||
async fn init_sqlite_database(config: &DatabaseConfig) -> anyhow::Result<()> {
|
||||
let path = config
|
||||
.path
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("SQLite 需要配置 database.path"))?;
|
||||
|
||||
// 如果是相对路径,转换为绝对路径
|
||||
let absolute_path = if path.is_absolute() {
|
||||
path.clone()
|
||||
} else {
|
||||
std::env::current_dir()
|
||||
.map_err(|e| anyhow::anyhow!("获取当前目录失败: {}", e))?
|
||||
.join(path)
|
||||
};
|
||||
|
||||
tracing::info!("SQLite 数据库路径: {}", absolute_path.display());
|
||||
|
||||
// 获取数据库文件的父目录
|
||||
if let Some(parent) = absolute_path.parent() {
|
||||
// 如果父目录不存在,则创建
|
||||
if !parent.exists() {
|
||||
std::fs::create_dir_all(parent)
|
||||
.map_err(|e| anyhow::anyhow!("创建 SQLite 数据库目录失败: {}", e))?;
|
||||
tracing::info!("✅ SQLite 数据库目录创建成功: {}", parent.display());
|
||||
} else {
|
||||
tracing::info!("SQLite 数据库目录已存在: {}", parent.display());
|
||||
}
|
||||
}
|
||||
|
||||
// 如果数据库文件不存在,创建空文件
|
||||
if !absolute_path.exists() {
|
||||
std::fs::File::create(&absolute_path)
|
||||
.map_err(|e| anyhow::anyhow!("创建 SQLite 数据库文件失败: {}", e))?;
|
||||
tracing::info!("✅ SQLite 数据库文件创建成功: {}", absolute_path.display());
|
||||
} else {
|
||||
tracing::info!("SQLite 数据库文件已存在: {}", absolute_path.display());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 辅助函数:创建单个表(如果不存在)
|
||||
async fn create_single_table<E>(
|
||||
db: &DatabaseConnection,
|
||||
schema: &Schema,
|
||||
builder: &DbBackend,
|
||||
entity: E,
|
||||
table_name: &str,
|
||||
) -> anyhow::Result<()>
|
||||
where
|
||||
E: EntityName + EntityTrait,
|
||||
{
|
||||
let create_table = schema.create_table_from_entity(entity);
|
||||
|
||||
let sql = match builder {
|
||||
DbBackend::MySql => {
|
||||
use sea_orm::sea_query::MysqlQueryBuilder;
|
||||
create_table.to_string(MysqlQueryBuilder {})
|
||||
}
|
||||
DbBackend::Postgres => {
|
||||
use sea_orm::sea_query::PostgresQueryBuilder;
|
||||
create_table.to_string(PostgresQueryBuilder {})
|
||||
}
|
||||
DbBackend::Sqlite => {
|
||||
use sea_orm::sea_query::SqliteQueryBuilder;
|
||||
create_table.to_string(SqliteQueryBuilder {})
|
||||
}
|
||||
};
|
||||
|
||||
let sql = sql.replace("CREATE TABLE", "CREATE TABLE IF NOT EXISTS");
|
||||
|
||||
match db.execute(Statement::from_string(*builder, sql)).await {
|
||||
Ok(_) => {
|
||||
tracing::info!("✅ {}检查完成", table_name);
|
||||
}
|
||||
Err(e) => {
|
||||
let err_msg = e.to_string();
|
||||
if err_msg.contains("already exists") || (err_msg.contains("table") && err_msg.contains("exists")) {
|
||||
tracing::info!("✅ {}已存在", table_name);
|
||||
} else {
|
||||
return Err(anyhow::anyhow!("创建{}失败: {}", table_name, e));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 创建数据库表结构
|
||||
async fn create_tables(db: &DatabaseConnection) -> anyhow::Result<()> {
|
||||
tracing::info!("检查数据库表结构...");
|
||||
|
||||
let builder = db.get_database_backend();
|
||||
let schema = Schema::new(builder);
|
||||
|
||||
// 导入所有 entities
|
||||
use crate::domain::entities::users;
|
||||
|
||||
// 创建所有表(添加新表只需一行!)
|
||||
create_single_table(db, &schema, &builder, users::Entity, "用户表").await?;
|
||||
|
||||
tracing::info!("✅ 数据库表结构检查完成");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
57
src/domain/dto/auth.rs
Normal file
57
src/domain/dto/auth.rs
Normal file
@@ -0,0 +1,57 @@
|
||||
use serde::Deserialize;
|
||||
use std::fmt;
|
||||
|
||||
/// 注册请求
|
||||
#[derive(Deserialize)]
|
||||
pub struct RegisterRequest {
|
||||
pub email: String,
|
||||
pub password: String,
|
||||
}
|
||||
|
||||
// 实现 Debug trait,对密码进行脱敏
|
||||
impl fmt::Debug for RegisterRequest {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "RegisterRequest {{ email: {}, password: *** }}", self.email)
|
||||
}
|
||||
}
|
||||
|
||||
/// 登录请求
|
||||
#[derive(Deserialize)]
|
||||
pub struct LoginRequest {
|
||||
pub email: String,
|
||||
pub password: String,
|
||||
}
|
||||
|
||||
// 实现 Debug trait
|
||||
impl fmt::Debug for LoginRequest {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "LoginRequest {{ email: {}, password: *** }}", self.email)
|
||||
}
|
||||
}
|
||||
|
||||
/// 删除用户请求
|
||||
#[derive(Deserialize)]
|
||||
pub struct DeleteUserRequest {
|
||||
pub user_id: String,
|
||||
pub password: String,
|
||||
}
|
||||
|
||||
// 实现 Debug trait
|
||||
impl fmt::Debug for DeleteUserRequest {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "DeleteUserRequest {{ user_id: {}, password: *** }}", self.user_id)
|
||||
}
|
||||
}
|
||||
|
||||
/// 刷新令牌请求
|
||||
#[derive(Deserialize)]
|
||||
pub struct RefreshRequest {
|
||||
pub refresh_token: String,
|
||||
}
|
||||
|
||||
// RefreshRequest 的 refresh_token 是敏感字段,需要脱敏
|
||||
impl fmt::Debug for RefreshRequest {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "RefreshRequest {{ refresh_token: *** }}")
|
||||
}
|
||||
}
|
||||
1
src/domain/dto/mod.rs
Normal file
1
src/domain/dto/mod.rs
Normal file
@@ -0,0 +1 @@
|
||||
pub mod auth;
|
||||
1
src/domain/dto/user.rs
Normal file
1
src/domain/dto/user.rs
Normal file
@@ -0,0 +1 @@
|
||||
// 用户相关 DTO(预留)
|
||||
2
src/domain/entities/mod.rs
Normal file
2
src/domain/entities/mod.rs
Normal file
@@ -0,0 +1,2 @@
|
||||
pub mod users;
|
||||
|
||||
41
src/domain/entities/users.rs
Normal file
41
src/domain/entities/users.rs
Normal file
@@ -0,0 +1,41 @@
|
||||
use sea_orm::entity::prelude::*;
|
||||
use sea_orm::Set;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, DeriveEntityModel, Serialize, Deserialize)]
|
||||
#[sea_orm(table_name = "users")]
|
||||
pub struct Model {
|
||||
#[sea_orm(primary_key, auto_increment = false)]
|
||||
pub id: String,
|
||||
#[sea_orm(unique)]
|
||||
pub email: String,
|
||||
pub password_hash: String,
|
||||
pub created_at: DateTime,
|
||||
pub updated_at: DateTime,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
|
||||
pub enum Relation {}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl ActiveModelBehavior for ActiveModel {
|
||||
/// 在保存前自动填充时间戳
|
||||
async fn before_save<C>(self, _db: &C, insert: bool) -> Result<Self, DbErr>
|
||||
where
|
||||
C: ConnectionTrait,
|
||||
{
|
||||
let mut this = self;
|
||||
let now = chrono::Utc::now().naive_utc();
|
||||
|
||||
if insert {
|
||||
// 插入时:设置创建时间和更新时间
|
||||
this.created_at = Set(now);
|
||||
this.updated_at = Set(now);
|
||||
} else {
|
||||
// 更新时:只更新更新时间
|
||||
this.updated_at = Set(now);
|
||||
}
|
||||
|
||||
Ok(this)
|
||||
}
|
||||
}
|
||||
3
src/domain/mod.rs
Normal file
3
src/domain/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
pub mod dto;
|
||||
pub mod vo;
|
||||
pub mod entities;
|
||||
50
src/domain/vo/auth.rs
Normal file
50
src/domain/vo/auth.rs
Normal file
@@ -0,0 +1,50 @@
|
||||
use serde::Serialize;
|
||||
|
||||
/// 注册结果
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct RegisterResult {
|
||||
pub email: String,
|
||||
pub created_at: String, // ISO 8601 格式
|
||||
pub access_token: String,
|
||||
pub refresh_token: String,
|
||||
}
|
||||
|
||||
impl From<(crate::domain::entities::users::Model, String, String)> for RegisterResult {
|
||||
fn from((user_model, access_token, refresh_token): (crate::domain::entities::users::Model, String, String)) -> Self {
|
||||
Self {
|
||||
email: user_model.email,
|
||||
created_at: user_model.created_at.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string(),
|
||||
access_token,
|
||||
refresh_token,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 登录结果
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct LoginResult {
|
||||
pub id: String,
|
||||
pub email: String,
|
||||
pub created_at: String, // ISO 8601 格式
|
||||
pub access_token: String,
|
||||
pub refresh_token: String,
|
||||
}
|
||||
|
||||
impl From<(crate::domain::entities::users::Model, String, String)> for LoginResult {
|
||||
fn from((user_model, access_token, refresh_token): (crate::domain::entities::users::Model, String, String)) -> Self {
|
||||
Self {
|
||||
id: user_model.id,
|
||||
email: user_model.email,
|
||||
created_at: user_model.created_at.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string(),
|
||||
access_token,
|
||||
refresh_token,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 刷新 Token 结果
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct RefreshResult {
|
||||
pub access_token: String,
|
||||
pub refresh_token: String,
|
||||
}
|
||||
56
src/domain/vo/mod.rs
Normal file
56
src/domain/vo/mod.rs
Normal file
@@ -0,0 +1,56 @@
|
||||
pub mod auth;
|
||||
pub mod user;
|
||||
|
||||
/// 统一的 API 响应结构
|
||||
use serde::Serialize;
|
||||
use axum::http::StatusCode;
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct ApiResponse<T> {
|
||||
/// HTTP 状态码
|
||||
pub code: u16,
|
||||
/// 响应消息
|
||||
pub message: String,
|
||||
/// 响应数据
|
||||
pub data: Option<T>,
|
||||
}
|
||||
|
||||
impl<T: Serialize> ApiResponse<T> {
|
||||
/// 成功响应(200)
|
||||
pub fn success(data: T) -> Self {
|
||||
Self {
|
||||
code: 200,
|
||||
message: "Success".to_string(),
|
||||
data: Some(data),
|
||||
}
|
||||
}
|
||||
|
||||
/// 成功响应(自定义消息)
|
||||
pub fn success_with_message(data: T, message: &str) -> Self {
|
||||
Self {
|
||||
code: 200,
|
||||
message: message.to_string(),
|
||||
data: Some(data),
|
||||
}
|
||||
}
|
||||
|
||||
/// 错误响应
|
||||
#[allow(dead_code)]
|
||||
pub fn error(status_code: StatusCode, message: &str) -> ApiResponse<()> {
|
||||
ApiResponse {
|
||||
code: status_code.as_u16(),
|
||||
message: message.to_string(),
|
||||
data: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// 错误响应(带数据)
|
||||
#[allow(dead_code)]
|
||||
pub fn error_with_data(status_code: StatusCode, message: &str, data: T) -> ApiResponse<T> {
|
||||
ApiResponse {
|
||||
code: status_code.as_u16(),
|
||||
message: message.to_string(),
|
||||
data: Some(data),
|
||||
}
|
||||
}
|
||||
}
|
||||
1
src/domain/vo/user.rs
Normal file
1
src/domain/vo/user.rs
Normal file
@@ -0,0 +1 @@
|
||||
// 用户相关 VO(预留)
|
||||
89
src/error.rs
Normal file
89
src/error.rs
Normal file
@@ -0,0 +1,89 @@
|
||||
use crate::domain::vo::ApiResponse;
|
||||
use axum::{
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
Json,
|
||||
};
|
||||
|
||||
/// 应用错误类型
|
||||
#[allow(dead_code)]
|
||||
pub struct AppError(pub anyhow::Error);
|
||||
|
||||
impl IntoResponse for AppError {
|
||||
fn into_response(self) -> Response {
|
||||
tracing::error!("Application error: {:?}", self.0);
|
||||
|
||||
let (status, message) = match self.0.downcast_ref::<&str>() {
|
||||
Some(&"not_found") => (StatusCode::NOT_FOUND, "Resource not found"),
|
||||
Some(&"unauthorized") => (StatusCode::UNAUTHORIZED, "Unauthorized"),
|
||||
_ => (StatusCode::INTERNAL_SERVER_ERROR, "Internal server error"),
|
||||
};
|
||||
|
||||
let body = ApiResponse::<()> {
|
||||
code: status.as_u16(),
|
||||
message: message.to_string(),
|
||||
data: None,
|
||||
};
|
||||
|
||||
(status, Json(body)).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
// 为具体类型实现 From
|
||||
impl From<anyhow::Error> for AppError {
|
||||
fn from(err: anyhow::Error) -> Self {
|
||||
Self(err)
|
||||
}
|
||||
}
|
||||
|
||||
/// 统一的 API 错误响应结构
|
||||
#[derive(Debug)]
|
||||
pub struct ErrorResponse {
|
||||
pub status: StatusCode,
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
impl ErrorResponse {
|
||||
pub fn new(message: impl Into<String>) -> Self {
|
||||
Self {
|
||||
status: StatusCode::BAD_REQUEST,
|
||||
message: message.into(),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn not_found(message: impl Into<String>) -> Self {
|
||||
Self {
|
||||
status: StatusCode::NOT_FOUND,
|
||||
message: message.into(),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn unauthorized(message: impl Into<String>) -> Self {
|
||||
Self {
|
||||
status: StatusCode::UNAUTHORIZED,
|
||||
message: message.into(),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn internal(message: impl Into<String>) -> Self {
|
||||
Self {
|
||||
status: StatusCode::INTERNAL_SERVER_ERROR,
|
||||
message: message.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for ErrorResponse {
|
||||
fn into_response(self) -> Response {
|
||||
let body = ApiResponse::<()> {
|
||||
code: self.status.as_u16(),
|
||||
message: self.message,
|
||||
data: None,
|
||||
};
|
||||
|
||||
(self.status, Json(body)).into_response()
|
||||
}
|
||||
}
|
||||
157
src/handlers/auth.rs
Normal file
157
src/handlers/auth.rs
Normal file
@@ -0,0 +1,157 @@
|
||||
use crate::error::ErrorResponse;
|
||||
use crate::infra::middleware::logging::{log_info, RequestId};
|
||||
use crate::domain::dto::auth::{RegisterRequest, LoginRequest, RefreshRequest, DeleteUserRequest};
|
||||
use crate::domain::vo::auth::{RegisterResult, LoginResult, RefreshResult};
|
||||
use crate::domain::vo::ApiResponse;
|
||||
use crate::repositories::user_repository::UserRepository;
|
||||
use crate::services::auth_service::AuthService;
|
||||
use crate::AppState;
|
||||
use axum::{
|
||||
extract::{Extension, State},
|
||||
Json,
|
||||
};
|
||||
use serde_json::json;
|
||||
|
||||
/// 注册
|
||||
pub async fn register(
|
||||
Extension(request_id): Extension<RequestId>,
|
||||
State(state): State<AppState>,
|
||||
Json(payload): Json<RegisterRequest>,
|
||||
) -> Result<Json<ApiResponse<RegisterResult>>, ErrorResponse> {
|
||||
log_info(&request_id, "注册请求参数", &payload);
|
||||
|
||||
let user_repo = UserRepository::new(state.pool.clone());
|
||||
let service = AuthService::new(user_repo, state.redis_client.clone(), state.config.auth.clone());
|
||||
|
||||
match service.register(payload).await {
|
||||
Ok((user_model, access_token, refresh_token)) => {
|
||||
let data = RegisterResult::from((user_model, access_token, refresh_token));
|
||||
let response = ApiResponse::success(data);
|
||||
log_info(&request_id, "注册成功", &response);
|
||||
Ok(Json(response))
|
||||
}
|
||||
Err(e) => {
|
||||
log_info(&request_id, "注册失败", &e.to_string());
|
||||
Err(ErrorResponse::new(e.to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 登录
|
||||
pub async fn login(
|
||||
Extension(request_id): Extension<RequestId>,
|
||||
State(state): State<AppState>,
|
||||
Json(payload): Json<LoginRequest>,
|
||||
) -> Result<Json<ApiResponse<LoginResult>>, ErrorResponse> {
|
||||
log_info(&request_id, "登录请求参数", &payload);
|
||||
|
||||
let user_repo = UserRepository::new(state.pool.clone());
|
||||
let service = AuthService::new(user_repo, state.redis_client.clone(), state.config.auth.clone());
|
||||
|
||||
match service.login(payload).await {
|
||||
Ok((user_model, access_token, refresh_token)) => {
|
||||
let data = LoginResult::from((user_model, access_token, refresh_token));
|
||||
let response = ApiResponse::success(data);
|
||||
log_info(&request_id, "登录成功", &response);
|
||||
Ok(Json(response))
|
||||
}
|
||||
Err(e) => {
|
||||
log_info(&request_id, "登录失败", &e.to_string());
|
||||
Err(ErrorResponse::new(e.to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 刷新 Token
|
||||
pub async fn refresh(
|
||||
Extension(request_id): Extension<RequestId>,
|
||||
State(state): State<AppState>,
|
||||
Json(payload): Json<RefreshRequest>,
|
||||
) -> Result<Json<ApiResponse<RefreshResult>>, ErrorResponse> {
|
||||
log_info(
|
||||
&request_id,
|
||||
"刷新 token 请求",
|
||||
&json!({"device_id": "default"}),
|
||||
);
|
||||
|
||||
let user_repo = UserRepository::new(state.pool.clone());
|
||||
let service = AuthService::new(user_repo, state.redis_client.clone(), state.config.auth.clone());
|
||||
|
||||
match service
|
||||
.refresh_access_token(&payload.refresh_token)
|
||||
.await
|
||||
{
|
||||
Ok((access_token, refresh_token)) => {
|
||||
let data = RefreshResult {
|
||||
access_token,
|
||||
refresh_token,
|
||||
};
|
||||
let response = ApiResponse::success(data);
|
||||
|
||||
log_info(
|
||||
&request_id,
|
||||
"刷新成功",
|
||||
&json!({"access_token": "***"}),
|
||||
);
|
||||
Ok(Json(response))
|
||||
}
|
||||
Err(e) => {
|
||||
log_info(&request_id, "刷新失败", &e.to_string());
|
||||
Err(ErrorResponse::new(e.to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 删除账号
|
||||
pub async fn delete_account(
|
||||
Extension(request_id): Extension<RequestId>,
|
||||
State(state): State<AppState>,
|
||||
Extension(user_id): Extension<String>,
|
||||
Json(payload): Json<DeleteUserRequest>,
|
||||
) -> Result<Json<ApiResponse<()>>, ErrorResponse> {
|
||||
log_info(&request_id, "删除账号请求", &format!("user_id={}", user_id));
|
||||
|
||||
let user_repo = UserRepository::new(state.pool.clone());
|
||||
let service = AuthService::new(user_repo, state.redis_client.clone(), state.config.auth.clone());
|
||||
|
||||
let delete_request = DeleteUserRequest {
|
||||
user_id: user_id.clone(),
|
||||
password: payload.password,
|
||||
};
|
||||
|
||||
match service.delete_user(delete_request).await {
|
||||
Ok(_) => {
|
||||
log_info(&request_id, "账号删除成功", &format!("user_id={}", user_id));
|
||||
let response = ApiResponse::success_with_message((), "账号删除成功");
|
||||
Ok(Json(response))
|
||||
}
|
||||
Err(e) => {
|
||||
log_info(&request_id, "账号删除失败", &e.to_string());
|
||||
Err(ErrorResponse::new(e.to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 刷新令牌
|
||||
pub async fn delete_refresh_token(
|
||||
Extension(request_id): Extension<RequestId>,
|
||||
State(state): State<AppState>,
|
||||
Extension(user_id): Extension<String>,
|
||||
) -> Result<Json<ApiResponse<()>>, ErrorResponse> {
|
||||
log_info(&request_id, "删除刷新令牌请求", &format!("user_id={}", user_id));
|
||||
|
||||
let user_repo = UserRepository::new(state.pool.clone());
|
||||
let service = AuthService::new(user_repo, state.redis_client.clone(), state.config.auth.clone());
|
||||
|
||||
match service.delete_refresh_token(&user_id).await {
|
||||
Ok(_) => {
|
||||
log_info(&request_id, "刷新令牌删除成功", &format!("user_id={}", user_id));
|
||||
let response = ApiResponse::success_with_message((), "刷新令牌删除成功");
|
||||
Ok(Json(response))
|
||||
}
|
||||
Err(e) => {
|
||||
log_info(&request_id, "刷新令牌删除失败", &e.to_string());
|
||||
Err(ErrorResponse::new(e.to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
25
src/handlers/health.rs
Normal file
25
src/handlers/health.rs
Normal file
@@ -0,0 +1,25 @@
|
||||
use crate::AppState;
|
||||
use crate::db;
|
||||
use axum::{
|
||||
extract::State,
|
||||
response::{IntoResponse, Json},
|
||||
};
|
||||
use serde_json::json;
|
||||
|
||||
/// 健康检查端点
|
||||
pub async fn health_check(State(state): State<AppState>) -> impl IntoResponse {
|
||||
match db::health_check(&state.pool).await {
|
||||
Ok(_) => Json(json!({"status": "ok"})),
|
||||
Err(_) => Json(json!({"status": "unavailable"})),
|
||||
}
|
||||
}
|
||||
|
||||
/// 获取服务器信息
|
||||
pub async fn server_info() -> impl IntoResponse {
|
||||
Json(json!({
|
||||
"name": "web-rust-template",
|
||||
"version": "0.1.0",
|
||||
"status": "running",
|
||||
"timestamp": chrono::Utc::now().timestamp()
|
||||
}))
|
||||
}
|
||||
2
src/handlers/mod.rs
Normal file
2
src/handlers/mod.rs
Normal file
@@ -0,0 +1,2 @@
|
||||
pub mod auth;
|
||||
pub mod health;
|
||||
51
src/infra/middleware/auth.rs
Normal file
51
src/infra/middleware/auth.rs
Normal file
@@ -0,0 +1,51 @@
|
||||
use crate::AppState;
|
||||
use axum::{
|
||||
extract::{Request, State},
|
||||
http::{HeaderMap, StatusCode},
|
||||
middleware::Next,
|
||||
response::Response,
|
||||
};
|
||||
use jsonwebtoken::{decode, DecodingKey, Validation};
|
||||
use serde::Deserialize;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct Claims {
|
||||
pub sub: String, // user_id
|
||||
#[allow(dead_code)]
|
||||
pub exp: usize,
|
||||
}
|
||||
|
||||
/// JWT 认证中间件
|
||||
pub async fn auth_middleware(
|
||||
State(state): State<AppState>,
|
||||
headers: HeaderMap,
|
||||
mut req: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, StatusCode> {
|
||||
// 1. 提取 Authorization header
|
||||
let auth_header = headers
|
||||
.get("Authorization")
|
||||
.and_then(|h| h.to_str().ok())
|
||||
.ok_or(StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
if !auth_header.starts_with("Bearer ") {
|
||||
return Err(StatusCode::UNAUTHORIZED);
|
||||
}
|
||||
|
||||
let token = &auth_header[7..];
|
||||
|
||||
// 2. 验证 JWT
|
||||
let jwt_secret = &state.config.auth.jwt_secret;
|
||||
|
||||
let token_data = decode::<Claims>(
|
||||
token,
|
||||
&DecodingKey::from_secret(jwt_secret.as_ref()),
|
||||
&Validation::default(),
|
||||
)
|
||||
.map_err(|_| StatusCode::UNAUTHORIZED)?;
|
||||
|
||||
// 3. 将 user_id 添加到请求扩展
|
||||
req.extensions_mut().insert(token_data.claims.sub);
|
||||
|
||||
Ok(next.run(req).await)
|
||||
}
|
||||
71
src/infra/middleware/logging.rs
Normal file
71
src/infra/middleware/logging.rs
Normal file
@@ -0,0 +1,71 @@
|
||||
use axum::{extract::Request, response::Response};
|
||||
use std::time::Instant;
|
||||
|
||||
/// Request ID 标记
|
||||
#[derive(Clone)]
|
||||
pub struct RequestId(pub String);
|
||||
|
||||
/// 请求日志中间件
|
||||
pub async fn request_logging_middleware(
|
||||
mut req: Request,
|
||||
next: axum::middleware::Next,
|
||||
) -> Response {
|
||||
let start = Instant::now();
|
||||
|
||||
// 提取请求信息
|
||||
let method = req.method().clone();
|
||||
let path = req.uri().path().to_string();
|
||||
let query = req.uri().query().map(|s| s.to_string());
|
||||
|
||||
// 生成请求 ID
|
||||
let request_id = uuid::Uuid::new_v4().to_string();
|
||||
|
||||
// 将 request_id 存储到请求扩展中
|
||||
req.extensions_mut().insert(RequestId(request_id.clone()));
|
||||
|
||||
// 第1条日志:请求开始
|
||||
let separator = "=".repeat(80);
|
||||
let header = format!("{} {}", method, path);
|
||||
|
||||
tracing::info!("{}", separator);
|
||||
tracing::info!("{}", header);
|
||||
tracing::info!("{}", separator);
|
||||
|
||||
let now_beijing = chrono::Local::now().format("%Y-%m-%d %H:%M:%S%.3f");
|
||||
let query_str = query.as_deref().unwrap_or("无");
|
||||
tracing::info!(
|
||||
"[{}] 📥 查询参数: {} | 时间: {}",
|
||||
request_id,
|
||||
query_str,
|
||||
now_beijing
|
||||
);
|
||||
|
||||
// 调用下一个处理器
|
||||
let response = next.run(req).await;
|
||||
|
||||
// 第3条日志:请求完成
|
||||
let duration = start.elapsed();
|
||||
let status = response.status();
|
||||
tracing::info!(
|
||||
"[{}] ✅ 状态码: {} | 耗时: {}ms",
|
||||
request_id,
|
||||
status.as_u16(),
|
||||
duration.as_millis()
|
||||
);
|
||||
|
||||
tracing::info!("{}", separator);
|
||||
|
||||
response
|
||||
}
|
||||
|
||||
/// 请求日志辅助工具
|
||||
pub fn log_info<T: std::fmt::Debug>(request_id: &RequestId, label: &str, data: T) {
|
||||
let data_str = format!("{:?}", data);
|
||||
let truncated = if data_str.len() > 300 {
|
||||
format!("{}...", &data_str[..300])
|
||||
} else {
|
||||
data_str
|
||||
};
|
||||
|
||||
tracing::info!("[{}] 🔧 {} | {}", request_id.0, label, truncated);
|
||||
}
|
||||
2
src/infra/middleware/mod.rs
Normal file
2
src/infra/middleware/mod.rs
Normal file
@@ -0,0 +1,2 @@
|
||||
pub mod auth;
|
||||
pub mod logging;
|
||||
2
src/infra/mod.rs
Normal file
2
src/infra/mod.rs
Normal file
@@ -0,0 +1,2 @@
|
||||
pub mod middleware;
|
||||
pub mod redis;
|
||||
20
src/infra/redis/errors.rs
Normal file
20
src/infra/redis/errors.rs
Normal file
@@ -0,0 +1,20 @@
|
||||
use thiserror::Error;
|
||||
|
||||
/// Redis 错误类型
|
||||
#[derive(Error, Debug)]
|
||||
pub enum RedisError {
|
||||
#[error("Redis 连接失败: {0}")]
|
||||
ConnectionError(#[from] redis::RedisError),
|
||||
|
||||
#[error("Redis 序列化失败: {0}")]
|
||||
SerializationError(#[from] serde_json::Error),
|
||||
|
||||
#[error("Redis 数据不存在: {key}")]
|
||||
NotFound { key: String },
|
||||
|
||||
#[error("Redis 操作失败: {message}")]
|
||||
OperationError { message: String },
|
||||
|
||||
#[error("Failed to create redis pool: {0}")]
|
||||
PoolCreation(#[from] deadpool_redis::CreatePoolError),
|
||||
}
|
||||
2
src/infra/redis/mod.rs
Normal file
2
src/infra/redis/mod.rs
Normal file
@@ -0,0 +1,2 @@
|
||||
pub mod redis_client;
|
||||
pub mod redis_key;
|
||||
135
src/infra/redis/redis_client.rs
Normal file
135
src/infra/redis/redis_client.rs
Normal file
@@ -0,0 +1,135 @@
|
||||
use super::redis_key::RedisKey;
|
||||
use redis::aio::MultiplexedConnection;
|
||||
use redis::{AsyncCommands, Client};
|
||||
use serde::Serialize;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
/// Redis 客户端(使用 MultiplexedConnection)
|
||||
#[derive(Clone)]
|
||||
pub struct RedisClient {
|
||||
conn: Arc<Mutex<MultiplexedConnection>>,
|
||||
}
|
||||
|
||||
impl RedisClient {
|
||||
/// 创建新的 Redis 客户端
|
||||
pub async fn new(url: &str) -> redis::RedisResult<Self> {
|
||||
let client = Client::open(url)?;
|
||||
let conn = client.get_multiplexed_async_connection().await?;
|
||||
Ok(Self {
|
||||
conn: Arc::new(Mutex::new(conn)),
|
||||
})
|
||||
}
|
||||
|
||||
/// 设置字符串值
|
||||
pub async fn set(&self, k: &str, v: &str) -> redis::RedisResult<()> {
|
||||
let mut c = self.conn.lock().await;
|
||||
c.set(k, v).await
|
||||
}
|
||||
|
||||
/// 获取字符串值
|
||||
pub async fn get(&self, k: &str) -> redis::RedisResult<Option<String>> {
|
||||
let mut c = self.conn.lock().await;
|
||||
c.get(k).await
|
||||
}
|
||||
|
||||
/// 设置字符串值并指定过期时间(秒)
|
||||
pub async fn set_ex(&self, k: &str, v: &str, seconds: u64) -> redis::RedisResult<()> {
|
||||
let mut c = self.conn.lock().await;
|
||||
c.set_ex(k, v, seconds).await
|
||||
}
|
||||
|
||||
/// 删除键
|
||||
pub async fn del(&self, k: &str) -> redis::RedisResult<()> {
|
||||
let mut c = self.conn.lock().await;
|
||||
c.del(k).await
|
||||
}
|
||||
|
||||
/// 设置键的过期时间(秒)
|
||||
pub async fn expire(&self, k: &str, seconds: u64) -> redis::RedisResult<()> {
|
||||
let mut c = self.conn.lock().await;
|
||||
c.expire(k, seconds as i64).await
|
||||
}
|
||||
|
||||
/// 使用 RedisKey 设置 JSON 值
|
||||
pub async fn set_key<T: Serialize>(
|
||||
&self,
|
||||
key: &RedisKey,
|
||||
value: &T,
|
||||
) -> redis::RedisResult<()> {
|
||||
let json = serde_json::to_string(value).map_err(|e| {
|
||||
redis::RedisError::from((
|
||||
redis::ErrorKind::TypeError,
|
||||
"JSON serialization failed",
|
||||
e.to_string(),
|
||||
))
|
||||
})?;
|
||||
let mut c = self.conn.lock().await;
|
||||
c.set(key.build(), json).await
|
||||
}
|
||||
|
||||
/// 使用 RedisKey 设置 JSON 值并指定过期时间(秒)
|
||||
pub async fn set_key_ex<T: Serialize>(
|
||||
&self,
|
||||
key: &RedisKey,
|
||||
value: &T,
|
||||
expiration_seconds: u64,
|
||||
) -> redis::RedisResult<()> {
|
||||
let json = serde_json::to_string(value).map_err(|e| {
|
||||
redis::RedisError::from((
|
||||
redis::ErrorKind::TypeError,
|
||||
"JSON serialization failed",
|
||||
e.to_string(),
|
||||
))
|
||||
})?;
|
||||
let mut c = self.conn.lock().await;
|
||||
c.set_ex(key.build(), json, expiration_seconds).await
|
||||
}
|
||||
|
||||
/// 使用 RedisKey 获取字符串值
|
||||
pub async fn get_key(&self, key: &RedisKey) -> redis::RedisResult<Option<String>> {
|
||||
let mut c = self.conn.lock().await;
|
||||
let json: Option<String> = c.get(key.build()).await?;
|
||||
Ok(json)
|
||||
}
|
||||
|
||||
/// 使用 RedisKey 获取并反序列化 JSON 值
|
||||
pub async fn get_key_json<T: for<'de> serde::Deserialize<'de>>(
|
||||
&self,
|
||||
key: &RedisKey,
|
||||
) -> redis::RedisResult<Option<T>> {
|
||||
let mut c = self.conn.lock().await;
|
||||
let json: Option<String> = c.get(key.build()).await?;
|
||||
match json {
|
||||
Some(data) => {
|
||||
let value = serde_json::from_str(&data).map_err(|e| {
|
||||
redis::RedisError::from((
|
||||
redis::ErrorKind::TypeError,
|
||||
"JSON deserialization failed",
|
||||
e.to_string(),
|
||||
))
|
||||
})?;
|
||||
Ok(Some(value))
|
||||
}
|
||||
None => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
/// 使用 RedisKey 删除键
|
||||
pub async fn delete_key(&self, key: &RedisKey) -> redis::RedisResult<()> {
|
||||
let mut c = self.conn.lock().await;
|
||||
c.del(key.build()).await
|
||||
}
|
||||
|
||||
/// 使用 RedisKey 检查键是否存在
|
||||
pub async fn exists_key(&self, key: &RedisKey) -> redis::RedisResult<bool> {
|
||||
let mut c = self.conn.lock().await;
|
||||
c.exists(key.build()).await
|
||||
}
|
||||
|
||||
/// 使用 RedisKey 设置键的过期时间(秒)
|
||||
pub async fn expire_key(&self, key: &RedisKey, seconds: u64) -> redis::RedisResult<()> {
|
||||
let mut c = self.conn.lock().await;
|
||||
c.expire(key.build(), seconds as i64).await
|
||||
}
|
||||
}
|
||||
61
src/infra/redis/redis_key.rs
Normal file
61
src/infra/redis/redis_key.rs
Normal file
@@ -0,0 +1,61 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
|
||||
/// 业务类型枚举
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
pub enum BusinessType {
|
||||
#[serde(rename = "auth")]
|
||||
Auth,
|
||||
#[serde(rename = "user")]
|
||||
User,
|
||||
#[serde(rename = "cache")]
|
||||
Cache,
|
||||
#[serde(rename = "session")]
|
||||
Session,
|
||||
#[serde(rename = "rate_limit")]
|
||||
RateLimit,
|
||||
}
|
||||
|
||||
impl BusinessType {
|
||||
pub fn prefix(self) -> &'static str {
|
||||
match self {
|
||||
BusinessType::Auth => "auth",
|
||||
BusinessType::User => "user",
|
||||
BusinessType::Cache => "cache",
|
||||
BusinessType::Session => "session",
|
||||
BusinessType::RateLimit => "rate_limit",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Redis 键构建器
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RedisKey {
|
||||
business: BusinessType,
|
||||
identifiers: Vec<String>,
|
||||
}
|
||||
|
||||
impl RedisKey {
|
||||
pub fn new(business: BusinessType) -> Self {
|
||||
Self {
|
||||
business,
|
||||
identifiers: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_identifier(mut self, id: impl Into<String>) -> Self {
|
||||
self.identifiers.push(id.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(&self) -> String {
|
||||
format!("{}:{}", self.business.prefix(), self.identifiers.join(":"))
|
||||
}
|
||||
}
|
||||
|
||||
// 兼容现有格式
|
||||
impl fmt::Display for RedisKey {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{}", self.build())
|
||||
}
|
||||
}
|
||||
131
src/main.rs
Normal file
131
src/main.rs
Normal file
@@ -0,0 +1,131 @@
|
||||
mod cli;
|
||||
mod config;
|
||||
mod db;
|
||||
mod domain;
|
||||
mod error;
|
||||
mod handlers;
|
||||
mod infra;
|
||||
mod repositories;
|
||||
mod services;
|
||||
mod utils;
|
||||
|
||||
use axum::{
|
||||
routing::{get, post},
|
||||
Router,
|
||||
};
|
||||
use clap::Parser;
|
||||
use cli::CliArgs;
|
||||
use tower_http::cors::{Any, CorsLayer};
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
|
||||
/// 应用状态
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
pub pool: db::DbPool,
|
||||
pub config: config::app::AppConfig,
|
||||
pub redis_client: infra::redis::redis_client::RedisClient,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
// 解析命令行参数
|
||||
let args = CliArgs::parse();
|
||||
|
||||
// 初始化日志
|
||||
tracing_subscriber::registry()
|
||||
.with(
|
||||
tracing_subscriber::EnvFilter::try_from_default_env()
|
||||
.unwrap_or_else(|_| args.get_log_filter().into()),
|
||||
)
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.init();
|
||||
|
||||
// 打印启动信息
|
||||
args.print_startup_info();
|
||||
|
||||
// 设置工作目录(如果指定)
|
||||
if let Some(ref work_dir) = args.work_dir {
|
||||
std::env::set_current_dir(work_dir).ok();
|
||||
println!("Working directory set to: {}", work_dir.display());
|
||||
}
|
||||
|
||||
// 解析配置文件路径(可选)
|
||||
let config_path = args.resolve_config_path();
|
||||
|
||||
// 加载配置(支持 CLI 覆盖)
|
||||
// 如果没有配置文件,将仅使用环境变量和默认值
|
||||
let config = config::app::AppConfig::load_with_overrides(
|
||||
config_path,
|
||||
args.get_overrides(),
|
||||
args.env.as_str(),
|
||||
)?;
|
||||
|
||||
tracing::info!("Configuration loaded successfully");
|
||||
tracing::info!("Environment: {}", args.env.as_str());
|
||||
tracing::info!("Debug mode: {}", args.is_debug_enabled());
|
||||
|
||||
// 初始化数据库(自动创建数据库和表)
|
||||
let pool = db::init_database(&config.database).await?;
|
||||
|
||||
// 初始化 Redis 客户端
|
||||
let redis_client = infra::redis::redis_client::RedisClient::new(&config.redis.build_url())
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("Redis 初始化失败: {}", e))?;
|
||||
|
||||
tracing::info!("Redis 连接池初始化成功");
|
||||
|
||||
// 创建应用状态
|
||||
let app_state = AppState {
|
||||
pool: pool.clone(),
|
||||
config: config.clone(),
|
||||
redis_client,
|
||||
};
|
||||
|
||||
// ========== 公开路由 ==========
|
||||
let public_routes = Router::new()
|
||||
.route("/health", get(handlers::health::health_check))
|
||||
.route("/info", get(handlers::health::server_info))
|
||||
.route("/auth/register", post(handlers::auth::register))
|
||||
.route("/auth/login", post(handlers::auth::login))
|
||||
.route("/auth/refresh", post(handlers::auth::refresh));
|
||||
|
||||
// ========== 受保护路由 ==========
|
||||
let protected_routes = Router::new()
|
||||
.route("/auth/delete", post(handlers::auth::delete_account))
|
||||
.route(
|
||||
"/auth/delete-refresh-token",
|
||||
post(handlers::auth::delete_refresh_token),
|
||||
)
|
||||
// JWT 认证中间件(仅应用于受保护路由)
|
||||
.route_layer(axum::middleware::from_fn_with_state(
|
||||
app_state.clone(),
|
||||
infra::middleware::auth::auth_middleware,
|
||||
));
|
||||
|
||||
// ========== 合并路由 ==========
|
||||
let app = public_routes
|
||||
.merge(protected_routes)
|
||||
// CORS(应用于所有路由)
|
||||
.layer(
|
||||
CorsLayer::new()
|
||||
.allow_origin(Any)
|
||||
.allow_methods(Any)
|
||||
.allow_headers(Any)
|
||||
)
|
||||
// 日志中间件(应用于所有路由)
|
||||
.layer(axum::middleware::from_fn_with_state(
|
||||
app_state.clone(),
|
||||
infra::middleware::logging::request_logging_middleware,
|
||||
))
|
||||
.with_state(app_state);
|
||||
|
||||
// 启动服务器
|
||||
let addr = format!("{}:{}", config.server.host, config.server.port);
|
||||
let listener = tokio::net::TcpListener::bind(&addr).await?;
|
||||
tracing::info!("Server listening on {}", addr);
|
||||
tracing::info!("Press Ctrl+C to stop");
|
||||
|
||||
axum::serve(listener, app).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
2
src/repositories/mod.rs
Normal file
2
src/repositories/mod.rs
Normal file
@@ -0,0 +1,2 @@
|
||||
pub mod user_repository;
|
||||
|
||||
85
src/repositories/user_repository.rs
Normal file
85
src/repositories/user_repository.rs
Normal file
@@ -0,0 +1,85 @@
|
||||
use sea_orm::{EntityTrait, QueryFilter, ColumnTrait, DatabaseConnection, Set, ActiveModelTrait, PaginatorTrait};
|
||||
use crate::domain::entities::users;
|
||||
use anyhow::Result;
|
||||
|
||||
/// 用户数据访问仓库
|
||||
pub struct UserRepository {
|
||||
db: DatabaseConnection,
|
||||
}
|
||||
|
||||
impl UserRepository {
|
||||
pub fn new(db: DatabaseConnection) -> Self {
|
||||
Self { db }
|
||||
}
|
||||
|
||||
/// 根据 email 查询用户
|
||||
pub async fn find_by_email(&self, email: &str) -> Result<Option<users::Model>> {
|
||||
let user = users::Entity::find()
|
||||
.filter(users::Column::Email.eq(email))
|
||||
.one(&self.db)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("查询失败: {}", e))?;
|
||||
|
||||
Ok(user)
|
||||
}
|
||||
|
||||
/// 统计邮箱数量
|
||||
pub async fn count_by_email(&self, email: &str) -> Result<i64> {
|
||||
let count = users::Entity::find()
|
||||
.filter(users::Column::Email.eq(email))
|
||||
.count(&self.db)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("查询失败: {}", e))?;
|
||||
|
||||
Ok(count as i64)
|
||||
}
|
||||
|
||||
/// 统计用户 ID 数量
|
||||
pub async fn count_by_id(&self, id: &str) -> Result<i64> {
|
||||
let count = users::Entity::find()
|
||||
.filter(users::Column::Id.eq(id))
|
||||
.count(&self.db)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("查询失败: {}", e))?;
|
||||
|
||||
Ok(count as i64)
|
||||
}
|
||||
|
||||
/// 获取密码哈希
|
||||
pub async fn get_password_hash(&self, email: &str) -> Result<Option<String>> {
|
||||
let user = users::Entity::find()
|
||||
.filter(users::Column::Email.eq(email))
|
||||
.one(&self.db)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("查询失败: {}", e))?;
|
||||
|
||||
Ok(user.map(|u| u.password_hash))
|
||||
}
|
||||
|
||||
/// 插入用户(created_at 和 updated_at 会自动填充),返回插入后的用户对象
|
||||
pub async fn insert(&self, id: String, email: String, password_hash: String) -> Result<users::Model> {
|
||||
let user_model = users::ActiveModel {
|
||||
id: Set(id),
|
||||
email: Set(email),
|
||||
password_hash: Set(password_hash),
|
||||
// created_at 和 updated_at 由 ActiveModelBehavior 自动填充
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let inserted_user = user_model.insert(&self.db)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("插入失败: {}", e))?;
|
||||
|
||||
Ok(inserted_user)
|
||||
}
|
||||
|
||||
/// 根据 ID 删除用户
|
||||
pub async fn delete_by_id(&self, id: &str) -> Result<()> {
|
||||
users::Entity::delete_by_id(id)
|
||||
.exec(&self.db)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("删除失败: {}", e))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
232
src/services/auth_service.rs
Normal file
232
src/services/auth_service.rs
Normal file
@@ -0,0 +1,232 @@
|
||||
use anyhow::Result;
|
||||
use argon2::{
|
||||
password_hash::{rand_core::OsRng, PasswordHash, PasswordHasher, PasswordVerifier, SaltString},
|
||||
Argon2,
|
||||
};
|
||||
use rand::Rng;
|
||||
|
||||
use crate::utils::jwt::TokenService;
|
||||
use crate::domain::dto::auth::{RegisterRequest, LoginRequest, DeleteUserRequest};
|
||||
use crate::domain::entities::users;
|
||||
use crate::config::auth::AuthConfig;
|
||||
use crate::infra::redis::{redis_client::RedisClient, redis_key::{BusinessType, RedisKey}};
|
||||
use crate::repositories::user_repository::UserRepository;
|
||||
|
||||
pub struct AuthService {
|
||||
user_repo: UserRepository,
|
||||
redis_client: RedisClient,
|
||||
auth_config: AuthConfig,
|
||||
}
|
||||
|
||||
impl AuthService {
|
||||
pub fn new(user_repo: UserRepository, redis_client: RedisClient, auth_config: AuthConfig) -> Self {
|
||||
Self { user_repo, redis_client, auth_config }
|
||||
}
|
||||
|
||||
/// 哈希密码
|
||||
pub fn hash_password(&self, password: &str) -> Result<String> {
|
||||
let salt = SaltString::generate(&mut OsRng);
|
||||
let argon2 = Argon2::default();
|
||||
let password_hash = argon2
|
||||
.hash_password(password.as_bytes(), &salt)
|
||||
.map_err(|e| anyhow::anyhow!("密码哈希失败: {}", e))?
|
||||
.to_string();
|
||||
Ok(password_hash)
|
||||
}
|
||||
|
||||
/// 生成用户 ID
|
||||
pub fn generate_user_id(&self) -> String {
|
||||
let mut rng = rand::thread_rng();
|
||||
rng.gen_range(1_000_000_000i64..10_000_000_000i64)
|
||||
.to_string()
|
||||
}
|
||||
|
||||
/// 生成唯一的用户 ID
|
||||
pub async fn generate_unique_user_id(&self) -> Result<String> {
|
||||
let mut attempts = 0;
|
||||
const MAX_ATTEMPTS: u32 = 10;
|
||||
|
||||
loop {
|
||||
let candidate_id = self.generate_user_id();
|
||||
|
||||
let existing = self.user_repo.count_by_id(&candidate_id).await?;
|
||||
if existing == 0 {
|
||||
return Ok(candidate_id);
|
||||
}
|
||||
|
||||
attempts += 1;
|
||||
if attempts >= MAX_ATTEMPTS {
|
||||
return Err(anyhow::anyhow!("生成唯一用户 ID 失败"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 保存 refresh_token 到 Redis
|
||||
async fn save_refresh_token(&self, user_id: &str, refresh_token: &str, expiration_days: i64) -> Result<()> {
|
||||
let key = RedisKey::new(BusinessType::Auth)
|
||||
.add_identifier("refresh_token")
|
||||
.add_identifier(user_id);
|
||||
|
||||
let expiration_seconds = expiration_days * 24 * 3600;
|
||||
|
||||
self.redis_client
|
||||
.set_ex(&key.build(), refresh_token, expiration_seconds as u64)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("Redis 保存失败: {}", e))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 获取并删除 refresh_token
|
||||
async fn get_and_delete_refresh_token(&self, user_id: &str) -> Result<String> {
|
||||
let key = RedisKey::new(BusinessType::Auth)
|
||||
.add_identifier("refresh_token")
|
||||
.add_identifier(user_id);
|
||||
|
||||
let token: Option<String> = self.redis_client
|
||||
.get(&key.build())
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("Redis 查询失败: {}", e))?;
|
||||
|
||||
if token.is_some() {
|
||||
self.redis_client
|
||||
.delete_key(&key)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("Redis 删除失败: {}", e))?;
|
||||
}
|
||||
|
||||
token.ok_or_else(|| anyhow::anyhow!("刷新令牌无效或已过期"))
|
||||
}
|
||||
|
||||
/// 删除用户的 refresh_token
|
||||
pub async fn delete_refresh_token(&self, user_id: &str) -> Result<()> {
|
||||
let key = RedisKey::new(BusinessType::Auth)
|
||||
.add_identifier("refresh_token")
|
||||
.add_identifier(user_id);
|
||||
|
||||
self.redis_client
|
||||
.delete_key(&key)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("Redis 删除失败: {}", e))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// 注册用户
|
||||
pub async fn register(
|
||||
&self,
|
||||
request: RegisterRequest,
|
||||
) -> Result<(users::Model, String, String)> {
|
||||
// 1. 检查邮箱是否已存在
|
||||
let existing = self.user_repo.count_by_email(&request.email).await?;
|
||||
|
||||
if existing > 0 {
|
||||
return Err(anyhow::anyhow!("邮箱已注册"));
|
||||
}
|
||||
|
||||
// 2. 哈希密码
|
||||
let password_hash = self.hash_password(&request.password)?;
|
||||
|
||||
// 3. 生成用户 ID
|
||||
let user_id = self.generate_unique_user_id().await?;
|
||||
|
||||
// 4. 插入数据库并获取包含真实 created_at 的用户对象
|
||||
let user = self.user_repo.insert(user_id.clone(), request.email, password_hash).await?;
|
||||
|
||||
// 5. 生成 token
|
||||
let (access_token, refresh_token) = TokenService::generate_token_pair(
|
||||
&user_id,
|
||||
self.auth_config.access_token_expiration_minutes,
|
||||
self.auth_config.refresh_token_expiration_days,
|
||||
&self.auth_config.jwt_secret,
|
||||
)?;
|
||||
|
||||
// 6. 保存 refresh_token
|
||||
self.save_refresh_token(&user_id, &refresh_token, self.auth_config.refresh_token_expiration_days as i64).await?;
|
||||
|
||||
Ok((user, access_token, refresh_token))
|
||||
}
|
||||
|
||||
/// 登录
|
||||
pub async fn login(
|
||||
&self,
|
||||
request: LoginRequest,
|
||||
) -> Result<(users::Model, String, String)> {
|
||||
// 1. 查询用户
|
||||
let user = self.user_repo.find_by_email(&request.email).await?
|
||||
.ok_or_else(|| anyhow::anyhow!("邮箱或密码错误"))?;
|
||||
|
||||
// 2. 验证密码
|
||||
let password_hash = self.user_repo.get_password_hash(&request.email).await?
|
||||
.ok_or_else(|| anyhow::anyhow!("邮箱或密码错误"))?;
|
||||
|
||||
let parsed_hash = PasswordHash::new(&password_hash)
|
||||
.map_err(|e| anyhow::anyhow!("解析密码哈希失败: {}", e))?;
|
||||
let argon2 = Argon2::default();
|
||||
|
||||
argon2
|
||||
.verify_password(request.password.as_bytes(), &parsed_hash)
|
||||
.map_err(|_| anyhow::anyhow!("邮箱或密码错误"))?;
|
||||
|
||||
// 3. 生成 token
|
||||
let (access_token, refresh_token) = TokenService::generate_token_pair(
|
||||
&user.id,
|
||||
self.auth_config.access_token_expiration_minutes,
|
||||
self.auth_config.refresh_token_expiration_days,
|
||||
&self.auth_config.jwt_secret,
|
||||
)?;
|
||||
|
||||
// 4. 保存 refresh_token
|
||||
self.save_refresh_token(&user.id, &refresh_token, self.auth_config.refresh_token_expiration_days as i64).await?;
|
||||
|
||||
Ok((user, access_token, refresh_token))
|
||||
}
|
||||
|
||||
/// 使用 refresh_token 刷新 access_token
|
||||
pub async fn refresh_access_token(
|
||||
&self,
|
||||
refresh_token: &str,
|
||||
) -> Result<(String, String)> {
|
||||
// 1. 从 refresh_token 中解码出 user_id
|
||||
let user_id = TokenService::decode_user_id(refresh_token, &self.auth_config.jwt_secret)?;
|
||||
|
||||
// 2. 从 Redis 获取存储的 token 并删除
|
||||
let stored_token = self.get_and_delete_refresh_token(&user_id).await?;
|
||||
|
||||
// 3. 验证 token 是否匹配
|
||||
if stored_token != refresh_token {
|
||||
return Err(anyhow::anyhow!("刷新令牌无效"));
|
||||
}
|
||||
|
||||
// 4. 生成新的 token 对
|
||||
let (new_access_token, new_refresh_token) = TokenService::generate_token_pair(
|
||||
&user_id,
|
||||
self.auth_config.access_token_expiration_minutes,
|
||||
self.auth_config.refresh_token_expiration_days,
|
||||
&self.auth_config.jwt_secret,
|
||||
)?;
|
||||
|
||||
// 5. 保存新的 refresh_token
|
||||
self.save_refresh_token(&user_id, &new_refresh_token, self.auth_config.refresh_token_expiration_days as i64).await?;
|
||||
|
||||
Ok((new_access_token, new_refresh_token))
|
||||
}
|
||||
|
||||
/// 删除用户
|
||||
pub async fn delete_user(&self, request: DeleteUserRequest) -> Result<()> {
|
||||
let password_hash = self.user_repo.get_password_hash(&request.user_id).await?
|
||||
.ok_or_else(|| anyhow::anyhow!("用户不存在"))?;
|
||||
|
||||
let parsed_hash = PasswordHash::new(&password_hash)
|
||||
.map_err(|e| anyhow::anyhow!("解析密码哈希失败: {}", e))?;
|
||||
let argon2 = Argon2::default();
|
||||
|
||||
argon2
|
||||
.verify_password(request.password.as_bytes(), &parsed_hash)
|
||||
.map_err(|_| anyhow::anyhow!("密码错误"))?;
|
||||
|
||||
self.user_repo.delete_by_id(&request.user_id).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
1
src/services/mod.rs
Normal file
1
src/services/mod.rs
Normal file
@@ -0,0 +1 @@
|
||||
pub mod auth_service;
|
||||
101
src/utils/jwt.rs
Normal file
101
src/utils/jwt.rs
Normal file
@@ -0,0 +1,101 @@
|
||||
use anyhow::Result;
|
||||
use chrono::{Duration, Utc};
|
||||
use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// JWT 工具类,负责生成和验证 JWT token
|
||||
pub struct TokenService;
|
||||
|
||||
impl TokenService {
|
||||
/// 生成 JWT access token
|
||||
pub fn generate_access_token(
|
||||
user_id: &str,
|
||||
expiration_minutes: u64,
|
||||
jwt_secret: &str,
|
||||
) -> Result<String> {
|
||||
let expiration = Utc::now()
|
||||
.checked_add_signed(Duration::minutes(expiration_minutes as i64))
|
||||
.expect("invalid expiration timestamp")
|
||||
.timestamp() as usize;
|
||||
|
||||
let claims = Claims {
|
||||
sub: user_id.to_string(),
|
||||
exp: expiration,
|
||||
token_type: TokenType::Access,
|
||||
};
|
||||
|
||||
let token = encode(
|
||||
&Header::default(),
|
||||
&claims,
|
||||
&EncodingKey::from_secret(jwt_secret.as_ref()),
|
||||
)?;
|
||||
|
||||
Ok(token)
|
||||
}
|
||||
|
||||
/// 生成 refresh token
|
||||
pub fn generate_refresh_token(
|
||||
user_id: &str,
|
||||
expiration_days: i64,
|
||||
jwt_secret: &str,
|
||||
) -> Result<String> {
|
||||
let expiration = Utc::now()
|
||||
.checked_add_signed(Duration::days(expiration_days))
|
||||
.expect("invalid expiration timestamp")
|
||||
.timestamp() as usize;
|
||||
|
||||
let claims = Claims {
|
||||
sub: user_id.to_string(),
|
||||
exp: expiration,
|
||||
token_type: TokenType::Refresh,
|
||||
};
|
||||
|
||||
let token = encode(
|
||||
&Header::default(),
|
||||
&claims,
|
||||
&EncodingKey::from_secret(jwt_secret.as_ref()),
|
||||
)?;
|
||||
|
||||
Ok(token)
|
||||
}
|
||||
|
||||
/// 生成 access token 和 refresh token
|
||||
pub fn generate_token_pair(
|
||||
user_id: &str,
|
||||
access_token_expiration_minutes: u64,
|
||||
refresh_token_expiration_days: i64,
|
||||
jwt_secret: &str,
|
||||
) -> Result<(String, String)> {
|
||||
let access_token =
|
||||
Self::generate_access_token(user_id, access_token_expiration_minutes, jwt_secret)?;
|
||||
let refresh_token =
|
||||
Self::generate_refresh_token(user_id, refresh_token_expiration_days, jwt_secret)?;
|
||||
|
||||
Ok((access_token, refresh_token))
|
||||
}
|
||||
|
||||
/// 从 token 中解码出 user_id
|
||||
pub fn decode_user_id(token: &str, jwt_secret: &str) -> Result<String> {
|
||||
let token_data = decode::<Claims>(
|
||||
token,
|
||||
&DecodingKey::from_secret(jwt_secret.as_ref()),
|
||||
&Validation::default(),
|
||||
)
|
||||
.map_err(|e| anyhow::anyhow!("Token 解码失败: {}", e))?;
|
||||
|
||||
Ok(token_data.claims.sub)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Claims {
|
||||
pub sub: String, // user_id
|
||||
pub exp: usize, // 过期时间
|
||||
pub token_type: TokenType,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
|
||||
pub enum TokenType {
|
||||
Access,
|
||||
Refresh,
|
||||
}
|
||||
1
src/utils/mod.rs
Normal file
1
src/utils/mod.rs
Normal file
@@ -0,0 +1 @@
|
||||
pub mod jwt;
|
||||
Reference in New Issue
Block a user