diff --git a/README.md b/README.md index 0cfaedf..636e815 100644 --- a/README.md +++ b/README.md @@ -118,7 +118,7 @@ func init() { ```go func main() { - // Use StpUtil directly without passing manager + // Use StpUtil directly without passing manager-example token, _ := stputil.Login(1000) println("Login successful, Token:", token) @@ -183,13 +183,13 @@ hasAny := stputil.HasPermissionsOr(1000, []string{"admin", "super"}) / ```go // Set roles -stputil.SetRoles(1000, []string{"admin", "manager"}) +stputil.SetRoles(1000, []string{"admin", "manager-example"}) // Check role hasRole := stputil.HasRole(1000, "admin") // Check multiple roles -hasAll := stputil.HasRolesAnd(1000, []string{"admin", "manager"}) +hasAll := stputil.HasRolesAnd(1000, []string{"admin", "manager-example"}) hasAny := stputil.HasRolesOr(1000, []string{"admin", "super"}) ``` @@ -266,7 +266,7 @@ func main() { r.GET("/public", sagin.Ignore(), publicHandler) // Public access r.GET("/user", sagin.CheckLogin(), userHandler) // Login required r.GET("/admin", sagin.CheckPermission("admin:*"), adminHandler) // Permission required - r.GET("/manager", sagin.CheckRole("manager"), managerHandler) // Role required + r.GET("/manager-example", sagin.CheckRole("manager-example"), managerHandler) // Role required r.GET("/sensitive", sagin.CheckDisable(), sensitiveHandler) // Check if disabled r.Run(":8080") @@ -308,7 +308,7 @@ func main() { userOrAdminHandler) // Admin role required - r.GET("/manager", sagin.CheckRole("admin"), managerHandler) + r.GET("/manager-example", sagin.CheckRole("admin"), managerHandler) // Check if account is disabled r.GET("/sensitive", sagin.CheckDisable(), sensitiveHandler) @@ -349,7 +349,7 @@ func main() { s.BindHandler("GET:/public", sagf.Ignore(), publicHandler) // Public access s.BindHandler("GET:/user", sagf.CheckLogin(), userHandler) // Login required s.BindHandler("GET:/admin", sagf.CheckPermission("admin:*"), adminHandler) // Permission required - s.BindHandler("GET:/manager", sagf.CheckRole("manager"), managerHandler) // Role required + s.BindHandler("GET:/manager-example", sagf.CheckRole("manager-example"), managerHandler) // Role required s.BindHandler("GET:/sensitive", sagf.CheckDisable(), sensitiveHandler) // Check if disabled s.SetPort(8080) @@ -506,7 +506,7 @@ manager.RegisterFunc(core.EventAll, func(data *core.EventData) { // Access advanced controls via the underlying EventManager manager.GetEventManager().SetPanicHandler(customPanicHandler) -// Use the manager globally +// Use the manager-example globally stputil.SetManager(manager) ``` diff --git a/README_zh.md b/README_zh.md index 622a0b0..dddeda2 100644 --- a/README_zh.md +++ b/README_zh.md @@ -183,13 +183,13 @@ hasAny := stputil.HasPermissionsOr(1000, []string{"admin", "super"}) / ```go // 设置角色 -stputil.SetRoles(1000, []string{"admin", "manager"}) +stputil.SetRoles(1000, []string{"admin", "manager-example"}) // 检查角色 hasRole := stputil.HasRole(1000, "admin") // 多角色检查 -hasAll := stputil.HasRolesAnd(1000, []string{"admin", "manager"}) +hasAll := stputil.HasRolesAnd(1000, []string{"admin", "manager-example"}) hasAny := stputil.HasRolesOr(1000, []string{"admin", "super"}) ``` @@ -266,7 +266,7 @@ func main() { r.GET("/public", sagin.Ignore(), publicHandler) // 公开访问 r.GET("/user", sagin.CheckLogin(), userHandler) // 需要登录 r.GET("/admin", sagin.CheckPermission("admin:*"), adminHandler) // 需要权限 - r.GET("/manager", sagin.CheckRole("manager"), managerHandler) // 需要角色 + r.GET("/manager-example", sagin.CheckRole("manager-example"), managerHandler) // 需要角色 r.GET("/sensitive", sagin.CheckDisable(), sensitiveHandler) // 检查封禁 r.Run(":8080") @@ -308,7 +308,7 @@ func main() { userOrAdminHandler) // 需要管理员角色 - r.GET("/manager", sagin.CheckRole("admin"), managerHandler) + r.GET("/manager-example", sagin.CheckRole("admin"), managerHandler) // 检查账号是否被封禁 r.GET("/sensitive", sagin.CheckDisable(), sensitiveHandler) @@ -349,7 +349,7 @@ func main() { s.BindHandler("GET:/public", sagf.Ignore(), publicHandler) // 公开访问 s.BindHandler("GET:/user", sagf.CheckLogin(), userHandler) // 需要登录 s.BindHandler("GET:/admin", sagf.CheckPermission("admin:*"), adminHandler) // 需要权限 - s.BindHandler("GET:/manager", sagf.CheckRole("manager"), managerHandler) // 需要角色 + s.BindHandler("GET:/manager-example", sagf.CheckRole("manager-example"), managerHandler) // 需要角色 s.BindHandler("GET:/sensitive", sagf.CheckDisable(), sensitiveHandler) // 检查是否禁用 s.SetPort(8080) diff --git a/codec/json/codec_adaper_json.go b/codec/json/codec_adaper_json.go new file mode 100644 index 0000000..b204802 --- /dev/null +++ b/codec/json/codec_adaper_json.go @@ -0,0 +1,22 @@ +// @Author daixk 2025/11/27 20:57:00 +package json + +import ( + "encoding/json" +) + +type JSONSerializer struct{} + +func (s *JSONSerializer) Encode(v any) ([]byte, error) { + return json.Marshal(v) +} + +func (s *JSONSerializer) Decode(data []byte, v any) error { + return json.Unmarshal(data, v) +} + +func (s *JSONSerializer) Name() string { return "json" } + +func NewJSONSerializer() *JSONSerializer { + return &JSONSerializer{} +} diff --git a/codec/json/go.mod b/codec/json/go.mod new file mode 100644 index 0000000..13a5759 --- /dev/null +++ b/codec/json/go.mod @@ -0,0 +1,3 @@ +module github.com/click33/sa-token-go/codec/json + +go 1.25.0 diff --git a/codec/msgpack/codec_adaper_msgpack.go b/codec/msgpack/codec_adaper_msgpack.go new file mode 100644 index 0000000..df0cd07 --- /dev/null +++ b/codec/msgpack/codec_adaper_msgpack.go @@ -0,0 +1,22 @@ +// @Author daixk 2025/11/27 20:58:00 +package msgpack + +import ( + "github.com/vmihailenco/msgpack/v5" +) + +type MsgPackSerializer struct{} + +func (s *MsgPackSerializer) Encode(v any) ([]byte, error) { + return msgpack.Marshal(v) +} + +func (s *MsgPackSerializer) Decode(data []byte, v any) error { + return msgpack.Unmarshal(data, v) +} + +func (s *MsgPackSerializer) Name() string { return "msgpack" } + +func NewMsgPackSerializer() *MsgPackSerializer { + return &MsgPackSerializer{} +} diff --git a/codec/msgpack/go.mod b/codec/msgpack/go.mod new file mode 100644 index 0000000..d6a5e09 --- /dev/null +++ b/codec/msgpack/go.mod @@ -0,0 +1,12 @@ +module github.com/click33/sa-token-go/codec/msgpack + +go 1.25.0 + +require github.com/vmihailenco/msgpack/v5 v5.4.1 + +require ( + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + github.com/stretchr/testify v1.11.1 // indirect + github.com/vmihailenco/tagparser/v2 v2.0.0 // indirect +) diff --git a/codec/msgpack/go.sum b/codec/msgpack/go.sum new file mode 100644 index 0000000..62714bd --- /dev/null +++ b/codec/msgpack/go.sum @@ -0,0 +1,6 @@ +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/vmihailenco/msgpack/v5 v5.4.1 h1:cQriyiUvjTwOHg8QZaPihLWeRAAVoCpE00IUPn0Bjt8= +github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/core/adapter/codec.go b/core/adapter/codec.go new file mode 100644 index 0000000..a05d842 --- /dev/null +++ b/core/adapter/codec.go @@ -0,0 +1,9 @@ +// @Author daixk 2025/12/12 10:46:00 +package adapter + +// Codec defines serialization behavior abstraction | 序列化行为抽象接口 +type Codec interface { + Encode(v any) ([]byte, error) // Encode value to byte slice | 将对象编码为字节数组 + Decode(data []byte, v any) error // Decode byte slice to target value | 将字节数组解码到目标对象 + Name() string // Return codec implementation name | 返回序列化器名称 +} diff --git a/core/adapter/generator.go b/core/adapter/generator.go new file mode 100644 index 0000000..6d12fd1 --- /dev/null +++ b/core/adapter/generator.go @@ -0,0 +1,44 @@ +// @Author daixk 2025/12/5 15:52:00 +package adapter + +// TokenStyle Token generation style | Token生成风格 +type TokenStyle string + +const ( + // TokenStyleUUID UUID style | UUID风格 + TokenStyleUUID TokenStyle = "uuid" + // TokenStyleSimple Simple random string | 简单随机字符串 + TokenStyleSimple TokenStyle = "simple" + // TokenStyleRandom32 32-bit random string | 32位随机字符串 + TokenStyleRandom32 TokenStyle = "random32" + // TokenStyleRandom64 64-bit random string | 64位随机字符串 + TokenStyleRandom64 TokenStyle = "random64" + // TokenStyleRandom128 128-bit random string | 128位随机字符串 + TokenStyleRandom128 TokenStyle = "random128" + // TokenStyleJWT JWT style | JWT风格 + TokenStyleJWT TokenStyle = "jwt" + // TokenStyleHash SHA256 hash-based style | SHA256哈希风格 + TokenStyleHash TokenStyle = "hash" + // TokenStyleTimestamp Timestamp-based style | 时间戳风格 + TokenStyleTimestamp TokenStyle = "timestamp" + // TokenStyleTik Short ID style (like TikTok) | Tik风格短ID(类似抖音) + TokenStyleTik TokenStyle = "tik" +) + +// IsValid checks if the TokenStyle is valid | 检查TokenStyle是否有效 +func (ts TokenStyle) IsValid() bool { + switch ts { + case TokenStyleUUID, TokenStyleSimple, TokenStyleRandom32, + TokenStyleRandom64, TokenStyleRandom128, TokenStyleJWT, + TokenStyleHash, TokenStyleTimestamp, TokenStyleTik: + return true + default: + return false + } +} + +// Generator token generation interface | Token生成接口 +type Generator interface { + // Generate generates token based on implementation | 生成Token(由实现决定具体规则) + Generate(loginID, device string) (string, error) +} diff --git a/core/adapter/log.go b/core/adapter/log.go new file mode 100644 index 0000000..56f7c6d --- /dev/null +++ b/core/adapter/log.go @@ -0,0 +1,64 @@ +// @Author daixk 2025/12/12 10:45:00 +package adapter + +// LogLevel defines log severity level | 日志级别定义 +type LogLevel int + +const ( + LogLevelDebug LogLevel = iota + 1 // Debug level | 调试级别 + LogLevelInfo // Info level | 信息级别 + LogLevelWarn // Warn level | 警告级别 + LogLevelError // Error level | 错误级别(最高) +) + +// String returns the string representation of log level | 返回日志级别的字符串表示 +func (l LogLevel) String() string { + switch l { + case LogLevelDebug: + return "DEBUG" + case LogLevelInfo: + return "INFO" + case LogLevelWarn: + return "WARN" + case LogLevelError: + return "ERROR" + default: + return "UNKNOWN" + } +} + +// Log defines logging behavior abstraction | 日志行为抽象接口 +type Log interface { + Print(v ...any) // Print log without level | 无级别日志输出 + Printf(format string, v ...any) // Print formatted log without level | 无级别格式化日志输出 + + Debug(v ...any) // Print debug level log | 输出调试级别日志 + Debugf(format string, v ...any) // Print formatted debug level log | 输出调试级别格式化日志 + + Info(v ...any) // Print info level log | 输出信息级别日志 + Infof(format string, v ...any) // Print formatted info level log | 输出信息级别格式化日志 + + Warn(v ...any) // Print warn level log | 输出警告级别日志 + Warnf(format string, v ...any) // Print formatted warn level log | 输出警告级别格式化日志 + + Error(v ...any) // Print error level log | 输出错误级别日志 + Errorf(format string, v ...any) // Print formatted error level log | 输出错误级别格式化日志 +} + +// LogControl defines runtime control methods for loggers | 日志运行时控制接口 +type LogControl interface { + Log + + // ---- Lifecycle | 生命周期 ---- + Close() // Close the logger and release resources | 关闭日志并释放资源 + Flush() // Flush buffered logs to output | 刷新缓冲区 + + // ---- Runtime Config | 运行时配置 ---- + SetLevel(level LogLevel) // Update minimum log level | 动态更新日志级别 + SetPrefix(prefix string) // Update log prefix | 动态更新日志前缀 + SetStdout(enable bool) // Enable/disable stdout output | 开关控制台输出 + + // ---- Status Query | 状态查询 ---- + LogPath() string // Get log directory path | 获取日志目录 + DropCount() uint64 // Get dropped log count (queue full) | 获取丢弃的日志数量 +} diff --git a/core/adapter/pool.go b/core/adapter/pool.go new file mode 100644 index 0000000..a3049b1 --- /dev/null +++ b/core/adapter/pool.go @@ -0,0 +1,8 @@ +// @Author daixk 2025/12/12 11:56:00 +package adapter + +type Pool interface { + Submit(task func()) error + Stop() + Stats() (running, capacity int, usage float64) +} diff --git a/core/adapter/storage.go b/core/adapter/storage.go index f5c6c9d..bd4eb76 100644 --- a/core/adapter/storage.go +++ b/core/adapter/storage.go @@ -1,42 +1,48 @@ package adapter -import "time" +import ( + "context" + "time" +) // Storage defines storage interface for Token and Session data | 定义存储接口,用于存储Token和Session数据 type Storage interface { // ============== Basic Operations | 基本操作 ============== // Set sets key-value pair with optional expiration time (0 means never expire) | 设置键值对,可选过期时间(0表示永不过期) - Set(key string, value any, expiration time.Duration) error + Set(ctx context.Context, key string, value any, expiration time.Duration) error // SetKeepTTL sets key-value pair but keeps the original TTL unchanged | 设置键值但保持原有TTL不变 - SetKeepTTL(key string, value any) error + SetKeepTTL(ctx context.Context, key string, value any) error // Get gets value by key, returns nil if key doesn't exist | 获取键对应的值,键不存在时返回nil - Get(key string) (any, error) + Get(ctx context.Context, key string) (any, error) + + // GetAndDelete atomically gets the value and deletes the key | 原子获取并删除键 + GetAndDelete(ctx context.Context, key string) (any, error) // Delete deletes one or more keys | 删除一个或多个键 - Delete(keys ...string) error + Delete(ctx context.Context, keys ...string) error // Exists checks if key exists | 检查键是否存在 - Exists(key string) bool + Exists(ctx context.Context, key string) bool // ============== Key Management | 键管理 ============== // Keys gets all keys matching pattern (e.g., "user:*") | 获取匹配模式的所有键(如:"user:*") - Keys(pattern string) ([]string, error) + Keys(ctx context.Context, pattern string) ([]string, error) // Expire sets expiration time for key | 设置键的过期时间 - Expire(key string, expiration time.Duration) error + Expire(ctx context.Context, key string, expiration time.Duration) error // TTL gets remaining time to live (-1 if no expiration, -2 if key doesn't exist) | 获取键的剩余生存时间(-1表示永不过期,-2表示键不存在) - TTL(key string) (time.Duration, error) + TTL(ctx context.Context, key string) (time.Duration, error) // ============== Utility Methods | 工具方法 ============== // Clear clears all data (use with caution, mainly for testing) | 清空所有数据(谨慎使用,主要用于测试) - Clear() error + Clear(ctx context.Context) error // Ping checks if storage is accessible | 检查存储是否可访问 - Ping() error + Ping(ctx context.Context) error } diff --git a/core/banner/banner.go b/core/banner/banner.go index cd1ae28..3ee0778 100644 --- a/core/banner/banner.go +++ b/core/banner/banner.go @@ -5,12 +5,10 @@ import ( "runtime" "strings" + "github.com/click33/sa-token-go/core" "github.com/click33/sa-token-go/core/config" ) -// Version version number | 版本号 -const Version = "0.1.1" - // Banner startup banner | 启动横幅 const Banner = ` _____ ______ __ ______ @@ -33,12 +31,60 @@ const ( // Print prints startup banner | 打印启动横幅 func Print() { - fmt.Printf(Banner, Version) + fmt.Printf(Banner, core.Version) fmt.Printf(":: Go Version :: %s\n", runtime.Version()) fmt.Printf(":: GOOS/GOARCH :: %s/%s\n", runtime.GOOS, runtime.GOARCH) fmt.Println() } +// PrintWithConfig prints startup banner with essential configuration | 打印启动横幅和核心配置信息 +func PrintWithConfig(cfg *config.Config) { + Print() + + fmt.Println("┌─────────────────────────────────────────────────────────┐") + fmt.Println("│ Configuration │") + fmt.Println("├─────────────────────────────────────────────────────────┤") + + // Basic Settings | 基础设置 + fmt.Print(formatConfigLine("Token Name", cfg.TokenName)) + fmt.Print(formatConfigLine("Token Style", cfg.TokenStyle)) + fmt.Print(formatConfigLine("Auth Type", cfg.AuthType)) + fmt.Print(formatConfigLine("Key Prefix", cfg.KeyPrefix)) + + // Timeout Strategy | 超时策略 + fmt.Println("├─────────────────────────────────────────────────────────┤") + fmt.Print(formatConfigLine("Token Timeout", formatTimeout(cfg.Timeout))) + fmt.Print(formatConfigLine("Active Timeout", formatTimeout(cfg.ActiveTimeout))) + fmt.Print(formatConfigLine("Auto Renew", cfg.AutoRenew)) + if cfg.AutoRenew { + fmt.Print(formatConfigLine("Max Refresh", formatTimeout(cfg.MaxRefresh))) + } + + // Login Strategy | 登录策略 + fmt.Println("├─────────────────────────────────────────────────────────┤") + fmt.Print(formatConfigLine("Concurrent Login", cfg.IsConcurrent)) + fmt.Print(formatConfigLine("Share Token", cfg.IsShare)) + if cfg.IsConcurrent && !cfg.IsShare { + fmt.Print(formatConfigLine("Max Login Count", formatCount(cfg.MaxLoginCount))) + } + + // Token Reading | Token 读取 + fmt.Println("├─────────────────────────────────────────────────────────┤") + fmt.Print(formatConfigLine("Read From", tokenReadSources(cfg))) + + // Security | 安全 + fmt.Println("├─────────────────────────────────────────────────────────┤") + if cfg.TokenStyle == "jwt" || cfg.TokenStyle == "JWT" { + fmt.Print(formatConfigLine("JWT Secret", configured)) + } else { + fmt.Print(formatConfigLine("JWT Secret", "(not used)")) + } + fmt.Print(formatConfigLine("Session Check", cfg.TokenSessionCheckLogin)) + + fmt.Println("└─────────────────────────────────────────────────────────┘") + fmt.Println() +} + // formatConfigLine formats configuration line with alignment and truncation | 格式化配置行(自动截断过长文本并保持对齐) func formatConfigLine(label string, value any) string { if len(label) > labelWidth { @@ -69,7 +115,7 @@ func formatTimeout(seconds int64) string { } // formatCount formats count value (number or "No Limit") | 格式化数量值 -func formatCount(count int) string { +func formatCount(count int64) string { if count > 0 { return fmt.Sprintf("%d", count) } @@ -93,65 +139,3 @@ func tokenReadSources(cfg *config.Config) string { } return strings.Join(parts, ", ") } - -// PrintWithConfig prints startup banner with essential configuration | 打印启动横幅和核心配置信息 -func PrintWithConfig(cfg *config.Config) { - Print() - - fmt.Println("┌─────────────────────────────────────────────────────────┐") - fmt.Println("│ Configuration │") - fmt.Println("├─────────────────────────────────────────────────────────┤") - - // Basic Token Settings | Token 基础设置 - fmt.Print(formatConfigLine("Token Name", cfg.TokenName)) - fmt.Print(formatConfigLine("Token Style", cfg.TokenStyle)) - fmt.Print(formatConfigLine("Key Prefix", cfg.KeyPrefix)) - - // Login Control | 登录控制 - fmt.Println("├─────────────────────────────────────────────────────────┤") - fmt.Print(formatConfigLine("Concurrent Login", cfg.IsConcurrent)) - fmt.Print(formatConfigLine("Share Token", cfg.IsShare)) - fmt.Print(formatConfigLine("Max Login Count", formatCount(cfg.MaxLoginCount))) - - // Timeout & Activity | 超时与活跃控制 - fmt.Println("├─────────────────────────────────────────────────────────┤") - fmt.Print(formatConfigLine("Token Timeout", formatTimeout(cfg.Timeout))) - fmt.Print(formatConfigLine("Active Timeout", formatTimeout(cfg.ActiveTimeout))) - fmt.Print(formatConfigLine("Auto Renew", cfg.AutoRenew)) - - // Renewal & Refresh Strategy | 续期与刷新策略 - fmt.Println("├─────────────────────────────────────────────────────────┤") - fmt.Print(formatConfigLine("Max Refresh", formatTimeout(cfg.MaxRefresh))) - fmt.Print(formatConfigLine("Renew Interval", formatTimeout(cfg.RenewInterval))) - fmt.Print(formatConfigLine("Data Refresh", formatTimeout(cfg.DataRefreshPeriod))) - - // Token Read Sources (compact) | Token 读取来源(紧凑显示) - fmt.Println("├─────────────────────────────────────────────────────────┤") - fmt.Print(formatConfigLine("Read From", tokenReadSources(cfg))) - - // Security & Storage | 安全与存储 - fmt.Println("├─────────────────────────────────────────────────────────┤") - if cfg.TokenStyle == "jwt" || cfg.TokenStyle == "JWT" { - fmt.Print(formatConfigLine("JWT Secret Key", configured)) - } else { - fmt.Print(formatConfigLine("JWT Secret Key", "(not used)")) - } - - // Cookie Configuration (only if enabled) | Cookie 配置(仅当启用时显示) - fmt.Println("├─────────────────────────────────────────────────────────┤") - if cfg.IsReadCookie || cfg.CookieConfig != nil { - if cfg.CookieConfig == nil { - fmt.Print(formatConfigLine("Cookie Config", "(default)")) - } else { - maxAge := formatTimeout(int64(cfg.CookieConfig.MaxAge)) - fmt.Print(formatConfigLine("Cookie MaxAge", maxAge)) - fmt.Print(formatConfigLine("Cookie Secure", cfg.CookieConfig.Secure)) - fmt.Print(formatConfigLine("Cookie HttpOnly", cfg.CookieConfig.HttpOnly)) - } - } else { - fmt.Print(formatConfigLine("Cookie Support", "disabled")) - } - - fmt.Println("└─────────────────────────────────────────────────────────┘") - fmt.Println() -} diff --git a/core/banner/banner_test.go b/core/banner/banner_test.go index 90027d4..4b1f311 100644 --- a/core/banner/banner_test.go +++ b/core/banner/banner_test.go @@ -1,379 +1,100 @@ package banner import ( - "bytes" - "io" - "os" - "runtime" - "strings" "testing" + "github.com/click33/sa-token-go/core/adapter" "github.com/click33/sa-token-go/core/config" ) -// captureOutput captures stdout output for testing -func captureOutput(f func()) string { - old := os.Stdout - r, w, _ := os.Pipe() - os.Stdout = w - - f() - - w.Close() - os.Stdout = old - - var buf bytes.Buffer - io.Copy(&buf, r) - return buf.String() -} - -func TestPrint(t *testing.T) { - output := captureOutput(func() { - Print() - }) - - // Check if output contains expected elements - if !strings.Contains(output, "Sa-Token-Go") { - t.Error("Output should contain 'Sa-Token-Go'") - } - - if !strings.Contains(output, Version) { - t.Errorf("Output should contain version %s", Version) - } - - if !strings.Contains(output, "Go Version") { - t.Error("Output should contain 'Go Version'") - } - - if !strings.Contains(output, runtime.Version()) { - t.Errorf("Output should contain Go version %s", runtime.Version()) - } - - if !strings.Contains(output, "GOOS/GOARCH") { - t.Error("Output should contain 'GOOS/GOARCH'") - } - - expectedOS := runtime.GOOS + "/" + runtime.GOARCH - if !strings.Contains(output, expectedOS) { - t.Errorf("Output should contain OS/ARCH %s", expectedOS) - } -} - -func TestFormatTimeout(t *testing.T) { - tests := []struct { - name string - seconds int64 - expected string - }{ - { - name: "Positive seconds less than a day", - seconds: 3600, - expected: "3600 seconds", - }, - { - name: "Exactly one day", - seconds: 86400, - expected: "86400 seconds (1 days)", - }, - { - name: "Multiple days", - seconds: 259200, // 3 days - expected: "259200 seconds (3 days)", - }, - { - name: "30 days", - seconds: 2592000, - expected: "2592000 seconds (30 days)", - }, - { - name: "Zero means never expire", - seconds: 0, - expected: neverExpire, - }, - { - name: "Negative means no limit", - seconds: -1, - expected: noLimit, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := formatTimeout(tt.seconds) - if result != tt.expected { - t.Errorf("formatTimeout(%d) = %s, want %s", tt.seconds, result, tt.expected) - } - }) - } -} - -func TestFormatCount(t *testing.T) { - tests := []struct { - name string - count int - expected string - }{ - { - name: "Positive count", - count: 12, - expected: "12", - }, - { - name: "Zero means no limit", - count: 0, - expected: noLimit, - }, - { - name: "Negative means no limit", - count: -1, - expected: noLimit, - }, - { - name: "Large count", - count: 9999, - expected: "9999", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := formatCount(tt.count) - if result != tt.expected { - t.Errorf("formatCount(%d) = %s, want %s", tt.count, result, tt.expected) - } - }) - } +// TestPrintWithConfig_Default tests banner printing with default configuration +func TestPrintWithConfig_Default(t *testing.T) { + t.Log("=== Testing with Default Config ===") + cfg := config.DefaultConfig() + PrintWithConfig(cfg) } -func TestFormatConfigLine(t *testing.T) { - tests := []struct { - name string - label string - value any - contains []string - }{ - { - name: "String value", - label: "Token Name", - value: "sa-token", - contains: []string{ - "Token Name", - "sa-token", - "│", - }, - }, - { - name: "Boolean value", - label: "Auto Renew", - value: true, - contains: []string{ - "Auto Renew", - "true", - }, - }, - { - name: "Integer value", - label: "Max Count", - value: 12, - contains: []string{ - "Max Count", - "12", - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := formatConfigLine(tt.label, tt.value) - for _, s := range tt.contains { - if !strings.Contains(result, s) { - t.Errorf("formatConfigLine(%s, %v) should contain %s, got: %s", tt.label, tt.value, s, result) - } - } - }) - } +// TestPrintWithConfig_JWT tests banner printing with JWT configuration +func TestPrintWithConfig_JWT(t *testing.T) { + t.Log("=== Testing with JWT Config ===") + cfg := config.DefaultConfig() + cfg.SetTokenStyle(adapter.TokenStyleJWT) + cfg.SetJwtSecretKey("my-secret-key-123456") + cfg.SetTimeout(86400) // 1 day + cfg.SetActiveTimeout(3600) // 1 hour + cfg.SetAutoRenew(true) + cfg.SetMaxRefresh(43200) // 12 hours + PrintWithConfig(cfg) } -func TestPrintWithConfig(t *testing.T) { - tests := []struct { - name string - config *config.Config - contains []string - }{ - { - name: "Default configuration", - config: config.DefaultConfig(), - contains: []string{ - "Configuration", - "Token Name", - "sa-token", - "Token Style", - "uuid", - "Token Timeout", - "30 days", - "Auto Renew", - "Concurrent", - "Share Token", - "Max Login Count", - "Read From Header", - "Read From Cookie", - "Read From Body", - "Logging", - }, - }, - { - name: "JWT configuration", - config: &config.Config{ - TokenName: "jwt-token", - Timeout: 3600, - ActiveTimeout: -1, - IsConcurrent: true, - IsShare: false, - MaxLoginCount: 5, - IsReadBody: false, - IsReadHeader: true, - IsReadCookie: false, - TokenStyle: config.TokenStyleJWT, - AutoRenew: true, - JwtSecretKey: "my-secret-key", - IsLog: true, - CookieConfig: &config.CookieConfig{ - Path: "/api", - SameSite: config.SameSiteLax, - HttpOnly: true, - Secure: true, - }, - }, - contains: []string{ - "jwt-token", - "jwt", - "3600 seconds", - "JWT Secret", - "*** (configured)", - "Cookie Path", - "/api", - "Cookie SameSite", - "Cookie HttpOnly", - "Cookie Secure", - }, - }, - { - name: "Never expire configuration", - config: &config.Config{ - TokenName: "never-token", - Timeout: 0, - ActiveTimeout: -1, - IsConcurrent: false, - IsShare: true, - MaxLoginCount: -1, - TokenStyle: config.TokenStyleUUID, - CookieConfig: &config.CookieConfig{}, - }, - contains: []string{ - "Never Expire", - "No Limit", - }, - }, - { - name: "JWT without secret key", - config: &config.Config{ - TokenName: "jwt-token", - TokenStyle: config.TokenStyleJWT, - JwtSecretKey: "", - CookieConfig: &config.CookieConfig{}, - }, - contains: []string{ - "JWT Secret", - "Not Set", - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - output := captureOutput(func() { - PrintWithConfig(tt.config) - }) - - for _, s := range tt.contains { - if !strings.Contains(output, s) { - t.Errorf("PrintWithConfig() output should contain '%s'\nGot output:\n%s", s, output) - } - } - - // Check for box drawing characters - if !strings.Contains(output, "┌") || !strings.Contains(output, "└") { - t.Error("Output should contain box drawing characters") - } - }) - } +// TestPrintWithConfig_NonConcurrent tests banner printing with non-concurrent login +func TestPrintWithConfig_NonConcurrent(t *testing.T) { + t.Log("=== Testing with Non-Concurrent Login Config ===") + cfg := config.DefaultConfig() + cfg.SetIsConcurrent(false) + cfg.SetIsShare(false) + PrintWithConfig(cfg) } -func TestPrintWithConfigNilCookie(t *testing.T) { - cfg := &config.Config{ - TokenName: "test-token", - TokenStyle: config.TokenStyleSimple, - CookieConfig: nil, // nil cookie config - } - - output := captureOutput(func() { - PrintWithConfig(cfg) - }) - - // Should not panic and should not contain cookie configuration - if strings.Contains(output, "Cookie Path") { - t.Error("Output should not contain Cookie configuration when CookieConfig is nil") - } +// TestPrintWithConfig_MaxLoginCount tests banner printing with max login count +func TestPrintWithConfig_MaxLoginCount(t *testing.T) { + t.Log("=== Testing with Max Login Count Config ===") + cfg := config.DefaultConfig() + cfg.SetIsConcurrent(true) + cfg.SetIsShare(false) + cfg.SetMaxLoginCount(5) + PrintWithConfig(cfg) } -func BenchmarkPrint(b *testing.B) { - // Redirect output to discard - old := os.Stdout - os.Stdout, _ = os.Open(os.DevNull) - defer func() { os.Stdout = old }() - - b.ResetTimer() - for i := 0; i < b.N; i++ { - Print() - } +// TestPrintWithConfig_AllReadSources tests banner printing with all read sources enabled +func TestPrintWithConfig_AllReadSources(t *testing.T) { + t.Log("=== Testing with All Read Sources Enabled ===") + cfg := config.DefaultConfig() + cfg.SetIsReadHeader(true) + cfg.SetIsReadCookie(true) + cfg.SetIsReadBody(true) + PrintWithConfig(cfg) } -func BenchmarkPrintWithConfig(b *testing.B) { +// TestPrintWithConfig_CustomPrefix tests banner printing with custom prefix and auth type +func TestPrintWithConfig_CustomPrefix(t *testing.T) { + t.Log("=== Testing with Custom Prefix and Auth Type ===") cfg := config.DefaultConfig() - - // Redirect output to discard - old := os.Stdout - os.Stdout, _ = os.Open(os.DevNull) - defer func() { os.Stdout = old }() - - b.ResetTimer() - for i := 0; i < b.N; i++ { - PrintWithConfig(cfg) - } + cfg.SetKeyPrefix("myapp") + cfg.SetAuthType("oauth2") + cfg.SetTokenName("access_token") + PrintWithConfig(cfg) } -func BenchmarkFormatTimeout(b *testing.B) { - for i := 0; i < b.N; i++ { - formatTimeout(2592000) - } +// TestPrintWithConfig_NoAutoRenew tests banner printing without auto renew +func TestPrintWithConfig_NoAutoRenew(t *testing.T) { + t.Log("=== Testing without Auto Renew ===") + cfg := config.DefaultConfig() + cfg.SetAutoRenew(false) + PrintWithConfig(cfg) } -func BenchmarkFormatCount(b *testing.B) { - for i := 0; i < b.N; i++ { - formatCount(12) - } +// TestPrintWithConfig_NeverExpire tests banner printing with never expire timeout +func TestPrintWithConfig_NeverExpire(t *testing.T) { + t.Log("=== Testing with Never Expire Timeout ===") + cfg := config.DefaultConfig() + cfg.SetTimeout(-1) + cfg.SetActiveTimeout(-1) + PrintWithConfig(cfg) } -func BenchmarkFormatConfigLine(b *testing.B) { - for i := 0; i < b.N; i++ { - formatConfigLine("Token Name", "sa-token") - } +// TestPrintWithConfig_LongTimeout tests banner printing with long timeout (shows days) +func TestPrintWithConfig_LongTimeout(t *testing.T) { + t.Log("=== Testing with Long Timeout (30 days) ===") + cfg := config.DefaultConfig() + cfg.SetTimeout(2592000) // 30 days + cfg.SetActiveTimeout(604800) // 7 days + cfg.SetMaxRefresh(1296000) // 15 days + PrintWithConfig(cfg) } -// TestPrintWithConfigVisual is a visual test that prints the full banner and config to stdout. -// It does not assert anything — useful for manual inspection during development. -func TestPrintWithConfigVisual(t *testing.T) { - t.Log("=== Visual Output of PrintWithConfig (Default Config) ===") - PrintWithConfig(config.DefaultConfig()) - t.Log("=== End of Visual Output ===") +// TestPrint tests basic banner printing +func TestPrint(t *testing.T) { + t.Log("=== Testing Basic Banner ===") + Print() } diff --git a/core/builder/builder.go b/core/builder/builder.go index a06678d..4493602 100644 --- a/core/builder/builder.go +++ b/core/builder/builder.go @@ -1,8 +1,11 @@ package builder import ( - "fmt" - "github.com/click33/sa-token-go/core/pool" + codec_json "github.com/click33/sa-token-go/codec/json" + "github.com/click33/sa-token-go/generator/sgenerator" + "github.com/click33/sa-token-go/log/nop" + "github.com/click33/sa-token-go/log/slog" + "github.com/click33/sa-token-go/pool/ants" "strings" "time" @@ -10,35 +13,46 @@ import ( "github.com/click33/sa-token-go/core/banner" "github.com/click33/sa-token-go/core/config" "github.com/click33/sa-token-go/core/manager" + "github.com/click33/sa-token-go/storage/memory" ) -// Builder Sa-Token builder for fluent configuration | Sa-Token构建器,用于流式配置 +// Builder provides fluent configuration for Sa-Token | Sa-Token 构建器用于流式配置 type Builder struct { - storage adapter.Storage - tokenName string - timeout int64 - maxRefresh int64 - renewInterval int64 - activeTimeout int64 - isConcurrent bool - isShare bool - maxLoginCount int - tokenStyle config.TokenStyle - autoRenew bool - jwtSecretKey string - isLog bool - isPrintBanner bool - isReadBody bool - isReadHeader bool - isReadCookie bool - dataRefreshPeriod int64 - tokenSessionCheckLogin bool - keyPrefix string - cookieConfig *config.CookieConfig - renewPoolConfig *pool.RenewPoolConfig -} - -// NewBuilder creates a new builder with default configuration | 创建新的构建器(使用默认配置) + tokenName string // Token name used by client | 客户端 Token 名称 + timeout int64 // Token timeout seconds | Token 过期时间(秒) + maxRefresh int64 // Max auto-refresh duration | 最大无感刷新时间 + renewInterval int64 // Min renewal interval seconds | 最小续期间隔(秒) + activeTimeout int64 // Force offline when idle | 活跃超时时间(秒) + isConcurrent bool // Allow concurrent login | 是否允许并发登录 + isShare bool // Share same token among devices | 是否共用 Token + maxLoginCount int64 // Max concurrent login count | 最大并发登录数 + isReadBody bool // Read token from body | 是否从 Body 读取 Token + isReadHeader bool // Read token from header | 是否从 Header 读取 Token + isReadCookie bool // Read token from cookie | 是否从 Cookie 读取 Token + tokenStyle adapter.TokenStyle // Token generation style | Token 生成方式 + tokenSessionCheckLogin bool // Check login before Session | 读取 Session 时是否检查登录 + autoRenew bool // Enable renewal | 是否启用自动续期 + jwtSecretKey string // JWT secret key | JWT 密钥 + isLog bool // Enable log output | 是否启用日志 + isPrintBanner bool // Print startup banner | 是否打印启动 Banner + keyPrefix string // Storage key prefix | 存储键前缀 + authType string // Authentication system type | 认证体系类型 + + cookieConfig *config.CookieConfig // Cookie config | Cookie 配置 + renewPoolConfig *ants.RenewPoolConfig // Renew pool config | 续期协程池配置 + logConfig *slog.LoggerConfig // log config | 日志配置 + + generator adapter.Generator // Token generator | Token 生成器 + storage adapter.Storage // Storage adapter | 存储适配器 + codec adapter.Codec // codec Codec adapter for encoding and decoding operations | 编解码操作的编码器适配器 + log adapter.Log // log Log adapter for logging operations | 日志记录操作的适配器 + pool adapter.Pool // Async task pool component | 异步任务协程池组件 + + customPermissionListFunc func(loginID, authType string) ([]string, error) // Custom permission provider | 自定义权限获取函数 + customRoleListFunc func(loginID, authType string) ([]string, error) // Custom role provider | 自定义角色获取函数 +} + +// NewBuilder creates a new builder with log configuration | 创建新的构建器(使用默认配置) func NewBuilder() *Builder { return &Builder{ tokenName: config.DefaultTokenName, @@ -49,31 +63,24 @@ func NewBuilder() *Builder { isConcurrent: true, isShare: true, maxLoginCount: config.DefaultMaxLoginCount, - tokenStyle: config.TokenStyleUUID, - autoRenew: true, - isLog: false, - isPrintBanner: true, isReadBody: false, isReadHeader: true, isReadCookie: false, - dataRefreshPeriod: config.NoLimit, + tokenStyle: adapter.TokenStyleUUID, tokenSessionCheckLogin: true, - keyPrefix: "satoken:", - cookieConfig: &config.CookieConfig{ - Domain: "", - Path: config.DefaultCookiePath, - Secure: false, - HttpOnly: true, - SameSite: config.SameSiteLax, - MaxAge: 0, - }, - } -} + autoRenew: true, + jwtSecretKey: sgenerator.DefaultJWTSecret, + isLog: false, + isPrintBanner: true, + keyPrefix: config.DefaultKeyPrefix, + authType: config.DefaultAuthType, -// Storage sets storage adapter | 设置存储适配器 -func (b *Builder) Storage(storage adapter.Storage) *Builder { - b.storage = storage - return b + cookieConfig: config.DefaultCookieConfig(), + renewPoolConfig: ants.DefaultRenewPoolConfig(), + + // 不需要设置logConfig + // logConfig: slog.DefaultLoggerConfig(), + } } // TokenName sets token name | 设置Token名称 @@ -125,17 +132,41 @@ func (b *Builder) IsShare(share bool) *Builder { } // MaxLoginCount sets maximum login count | 设置最大登录数量 -func (b *Builder) MaxLoginCount(count int) *Builder { +func (b *Builder) MaxLoginCount(count int64) *Builder { b.maxLoginCount = count return b } +// IsReadBody sets whether to read token from request body | 设置是否从请求体读取Token +func (b *Builder) IsReadBody(isRead bool) *Builder { + b.isReadBody = isRead + return b +} + +// IsReadHeader sets whether to read token from header | 设置是否从Header读取Token +func (b *Builder) IsReadHeader(isRead bool) *Builder { + b.isReadHeader = isRead + return b +} + +// IsReadCookie sets whether to read token from cookie | 设置是否从Cookie读取Token +func (b *Builder) IsReadCookie(isRead bool) *Builder { + b.isReadCookie = isRead + return b +} + // TokenStyle sets token generation style | 设置Token风格 -func (b *Builder) TokenStyle(style config.TokenStyle) *Builder { +func (b *Builder) TokenStyle(style adapter.TokenStyle) *Builder { b.tokenStyle = style return b } +// TokenSessionCheckLogin sets whether to check token session on login | 设置登录时是否检查Token会话 +func (b *Builder) TokenSessionCheckLogin(check bool) *Builder { + b.tokenSessionCheckLogin = check + return b +} + // AutoRenew sets whether to auto-renew token | 设置是否自动续期 func (b *Builder) AutoRenew(autoRenew bool) *Builder { b.autoRenew = autoRenew @@ -160,33 +191,31 @@ func (b *Builder) IsPrintBanner(isPrint bool) *Builder { return b } -// IsReadBody sets whether to read token from request body | 设置是否从请求体读取Token -func (b *Builder) IsReadBody(isRead bool) *Builder { - b.isReadBody = isRead - return b -} - -// IsReadHeader sets whether to read token from header | 设置是否从Header读取Token -func (b *Builder) IsReadHeader(isRead bool) *Builder { - b.isReadHeader = isRead +// KeyPrefix sets storage key prefix | 设置存储键前缀 +func (b *Builder) KeyPrefix(prefix string) *Builder { + // 如果前缀不为空且不以 : 结尾,自动添加 : + if prefix != "" && !strings.HasSuffix(prefix, ":") { + b.keyPrefix = prefix + ":" + } else { + b.keyPrefix = prefix + } return b } -// IsReadCookie sets whether to read token from cookie | 设置是否从Cookie读取Token -func (b *Builder) IsReadCookie(isRead bool) *Builder { - b.isReadCookie = isRead - return b -} +// AuthType sets authentication system type | 设置认证体系类型 +func (b *Builder) AuthType(authType string) *Builder { + // 如果为空,则使用默认 + if authType == "" { + b.authType = config.DefaultAuthType + } -// DataRefreshPeriod sets data refresh period | 设置数据刷新周期 -func (b *Builder) DataRefreshPeriod(seconds int64) *Builder { - b.dataRefreshPeriod = seconds - return b -} + // 如果前缀不为空且不以 : 结尾,自动添加 : + if authType != "" && !strings.HasSuffix(authType, ":") { + b.authType = authType + ":" + } else { + b.authType = authType + } -// TokenSessionCheckLogin sets whether to check token session on login | 设置登录时是否检查Token会话 -func (b *Builder) TokenSessionCheckLogin(check bool) *Builder { - b.tokenSessionCheckLogin = check return b } @@ -236,7 +265,7 @@ func (b *Builder) CookieSameSite(sameSite config.SameSiteMode) *Builder { } // CookieMaxAge sets cookie max age | 设置Cookie的最大年龄 -func (b *Builder) CookieMaxAge(maxAge int) *Builder { +func (b *Builder) CookieMaxAge(maxAge int64) *Builder { if b.cookieConfig == nil { b.cookieConfig = &config.CookieConfig{} } @@ -250,121 +279,289 @@ func (b *Builder) CookieConfig(cfg *config.CookieConfig) *Builder { return b } -// RenewPoolConfig sets the token renewal pool configuration | 设置Token续期池配置 -func (b *Builder) RenewPoolConfig(cfg *pool.RenewPoolConfig) *Builder { +// RenewPoolMinSize sets the minimum pool size | 设置最小协程数 +func (b *Builder) RenewPoolMinSize(size int) *Builder { + if b.renewPoolConfig == nil { + b.renewPoolConfig = &ants.RenewPoolConfig{} + } + b.renewPoolConfig.MinSize = size + return b +} + +// RenewPoolMaxSize sets the maximum pool size | 设置最大协程数 +func (b *Builder) RenewPoolMaxSize(size int) *Builder { + if b.renewPoolConfig == nil { + b.renewPoolConfig = &ants.RenewPoolConfig{} + } + b.renewPoolConfig.MaxSize = size + return b +} + +// RenewPoolScaleUpRate sets the scale-up threshold | 设置扩容阈值 +func (b *Builder) RenewPoolScaleUpRate(rate float64) *Builder { + if b.renewPoolConfig == nil { + b.renewPoolConfig = &ants.RenewPoolConfig{} + } + b.renewPoolConfig.ScaleUpRate = rate + return b +} + +// RenewPoolScaleDownRate sets the scale-down threshold | 设置缩容阈值 +func (b *Builder) RenewPoolScaleDownRate(rate float64) *Builder { + if b.renewPoolConfig == nil { + b.renewPoolConfig = &ants.RenewPoolConfig{} + } + b.renewPoolConfig.ScaleDownRate = rate + return b +} + +// RenewPoolCheckInterval sets the interval for auto-scale checking | 设置自动扩缩容检查间隔 +func (b *Builder) RenewPoolCheckInterval(interval time.Duration) *Builder { + if b.renewPoolConfig == nil { + b.renewPoolConfig = &ants.RenewPoolConfig{} + } + b.renewPoolConfig.CheckInterval = interval + return b +} + +// RenewPoolExpiry sets the idle worker expiry duration | 设置空闲协程过期时间 +func (b *Builder) RenewPoolExpiry(duration time.Duration) *Builder { + if b.renewPoolConfig == nil { + b.renewPoolConfig = &ants.RenewPoolConfig{} + } + b.renewPoolConfig.Expiry = duration + return b +} + +// RenewPoolPrintStatusInterval sets the status printing interval | 设置状态打印间隔 +func (b *Builder) RenewPoolPrintStatusInterval(interval time.Duration) *Builder { + if b.renewPoolConfig == nil { + b.renewPoolConfig = &ants.RenewPoolConfig{} + } + b.renewPoolConfig.PrintStatusInterval = interval + return b +} + +// RenewPoolPreAlloc sets whether to pre-allocate memory | 设置是否预分配内存 +func (b *Builder) RenewPoolPreAlloc(preAlloc bool) *Builder { + if b.renewPoolConfig == nil { + b.renewPoolConfig = &ants.RenewPoolConfig{} + } + b.renewPoolConfig.PreAlloc = preAlloc + return b +} + +// RenewPoolNonBlocking sets whether the pool works in non-blocking mode | 设置是否为非阻塞模式 +func (b *Builder) RenewPoolNonBlocking(nonBlocking bool) *Builder { + if b.renewPoolConfig == nil { + b.renewPoolConfig = &ants.RenewPoolConfig{} + } + b.renewPoolConfig.NonBlocking = nonBlocking + return b +} + +// RenewPoolConfig sets the token renewal pool configuration | 设置完整的Token续期池配置 +func (b *Builder) RenewPoolConfig(cfg *ants.RenewPoolConfig) *Builder { b.renewPoolConfig = cfg return b } -// KeyPrefix sets storage key prefix | 设置存储键前缀 -// Automatically adds ":" suffix if not present (except for empty string) | 自动添加 ":" 后缀(空字符串除外) -// Examples: "satoken" -> "satoken:", "myapp" -> "myapp:", "" -> "" -// Use empty string "" for Java sa-token compatibility | 使用空字符串 "" 兼容 Java sa-token -func (b *Builder) KeyPrefix(prefix string) *Builder { - // 如果前缀不为空且不以 : 结尾,自动添加 : - if prefix != "" && !strings.HasSuffix(prefix, ":") { - b.keyPrefix = prefix + ":" - } else { - b.keyPrefix = prefix +// LoggerPath sets the log directory path | 设置日志文件目录 +func (b *Builder) LoggerPath(path string) *Builder { + if b.logConfig == nil { + b.logConfig = &slog.LoggerConfig{} } + b.logConfig.Path = path return b } -// NeverExpire sets token to never expire | 设置Token永不过期 -func (b *Builder) NeverExpire() *Builder { - b.timeout = config.NoLimit +// LoggerFileFormat sets the log file naming format | 设置日志文件命名格式 +func (b *Builder) LoggerFileFormat(format string) *Builder { + if b.logConfig == nil { + b.logConfig = &slog.LoggerConfig{} + } + b.logConfig.FileFormat = format return b } -// NoActiveTimeout disables active timeout | 禁用活跃超时 -func (b *Builder) NoActiveTimeout() *Builder { - b.activeTimeout = config.NoLimit +// LoggerPrefix sets the log line prefix | 设置日志前缀 +func (b *Builder) LoggerPrefix(prefix string) *Builder { + if b.logConfig == nil { + b.logConfig = &slog.LoggerConfig{} + } + b.logConfig.Prefix = prefix return b } -// UnlimitedLogin allows unlimited concurrent logins | 允许无限并发登录 -func (b *Builder) UnlimitedLogin() *Builder { - b.maxLoginCount = config.NoLimit +// LoggerLevel sets the minimum output log level | 设置日志最低输出级别 +func (b *Builder) LoggerLevel(level slog.LogLevel) *Builder { + if b.logConfig == nil { + b.logConfig = &slog.LoggerConfig{} + } + b.logConfig.Level = level return b } -// Validate validates the builder configuration | 验证构建器配置 -func (b *Builder) Validate() error { - if b.storage == nil { - return fmt.Errorf("storage is required, please call Storage() method") +// LoggerTimeFormat sets the timestamp format | 设置时间戳格式 +func (b *Builder) LoggerTimeFormat(format string) *Builder { + if b.logConfig == nil { + b.logConfig = &slog.LoggerConfig{} } + b.logConfig.TimeFormat = format + return b +} - if b.tokenName == "" { - return fmt.Errorf("tokenName cannot be empty") +// LoggerStdout sets whether to print logs to console | 设置是否输出到控制台 +func (b *Builder) LoggerStdout(stdout bool) *Builder { + if b.logConfig == nil { + b.logConfig = &slog.LoggerConfig{} } + b.logConfig.Stdout = stdout + return b +} - if b.tokenStyle == config.TokenStyleJWT && b.jwtSecretKey == "" { - return fmt.Errorf("jwtSecretKey is required when TokenStyle is JWT") +// LoggerStdoutOnly sets whether to only print to console (skip file output) | 设置是否仅输出到控制台(不写入文件) +func (b *Builder) LoggerStdoutOnly(stdoutOnly bool) *Builder { + if b.logConfig == nil { + b.logConfig = &slog.LoggerConfig{} } + b.logConfig.StdoutOnly = stdoutOnly + return b +} - if !b.isReadHeader && !b.isReadCookie && !b.isReadBody { - return fmt.Errorf("at least one of IsReadHeader, IsReadCookie, or IsReadBody must be true") +// LoggerQueueSize sets the async write queue size | 设置异步写入队列大小 +func (b *Builder) LoggerQueueSize(size int) *Builder { + if b.logConfig == nil { + b.logConfig = &slog.LoggerConfig{} } + b.logConfig.QueueSize = size + return b +} - // Check MaxRefresh - if b.maxRefresh < config.NoLimit { - return fmt.Errorf("MaxRefresh must be >= -1, got: %d", b.maxRefresh) +// LoggerRotateSize sets the file size threshold for log rotation (bytes) | 设置日志文件大小滚动阈值(字节) +func (b *Builder) LoggerRotateSize(size int64) *Builder { + if b.logConfig == nil { + b.logConfig = &slog.LoggerConfig{} } + b.logConfig.RotateSize = size + return b +} - // Check MaxRefresh does not exceed Timeout - if b.timeout != config.NoLimit && b.maxRefresh > b.timeout { - return fmt.Errorf("MaxRefresh (%d) cannot be greater than Timeout (%d)", b.maxRefresh, b.timeout) +// LoggerRotateExpire sets the rotation interval by time duration | 设置文件时间滚动间隔 +func (b *Builder) LoggerRotateExpire(expire time.Duration) *Builder { + if b.logConfig == nil { + b.logConfig = &slog.LoggerConfig{} } + b.logConfig.RotateExpire = expire + return b +} - // Check RenewInterval - if b.renewInterval < config.NoLimit { - return fmt.Errorf("RenewInterval must be >= -1, got: %d", b.renewInterval) +// LoggerRotateBackupLimit sets the maximum number of rotated backup files | 设置最大备份文件数量 +func (b *Builder) LoggerRotateBackupLimit(limit int) *Builder { + if b.logConfig == nil { + b.logConfig = &slog.LoggerConfig{} } + b.logConfig.RotateBackupLimit = limit + return b +} - // Validate RenewPoolConfig if set | 如果设置了续期池配置,进行验证 - if b.renewPoolConfig != nil { - // Check MinSize and MaxSize | 检查最小和最大协程池大小 - if b.renewPoolConfig.MinSize <= 0 { - return fmt.Errorf("RenewPoolConfig.MinSize must be > 0") // 最小协程池大小必须大于0 - } - if b.renewPoolConfig.MaxSize < b.renewPoolConfig.MinSize { - return fmt.Errorf("RenewPoolConfig.MaxSize must be >= RenewPoolConfig.MinSize") // 最大协程池大小必须大于等于最小协程池大小 - } +// LoggerRotateBackupDays sets the retention days for old log files | 设置备份文件保留天数 +func (b *Builder) LoggerRotateBackupDays(days int) *Builder { + if b.logConfig == nil { + b.logConfig = &slog.LoggerConfig{} + } + b.logConfig.RotateBackupDays = days + return b +} - // Check ScaleUpRate and ScaleDownRate | 检查扩容和缩容阈值 - if b.renewPoolConfig.ScaleUpRate <= 0 || b.renewPoolConfig.ScaleUpRate > 1 { - return fmt.Errorf("RenewPoolConfig.ScaleUpRate must be between 0 and 1") // 扩容阈值必须在0和1之间 - } - if b.renewPoolConfig.ScaleDownRate < 0 || b.renewPoolConfig.ScaleDownRate > 1 { - return fmt.Errorf("RenewPoolConfig.ScaleDownRate must be between 0 and 1") // 缩容阈值必须在0和1之间 - } +// LoggerConfig sets complete logger configuration | 设置完整的日志配置 +func (b *Builder) LoggerConfig(cfg *slog.LoggerConfig) *Builder { + b.logConfig = cfg + return b +} - // Check CheckInterval | 检查检查间隔 - if b.renewPoolConfig.CheckInterval <= 0 { - return fmt.Errorf("RenewPoolConfig.CheckInterval must be a positive duration") // 检查间隔必须是一个正值 - } +// SetGenerator sets generator adapter | 设置Token生成器 +func (b *Builder) SetGenerator(generator adapter.Generator) *Builder { + b.generator = generator + return b +} - // Check Expiry | 检查过期时间 - if b.renewPoolConfig.Expiry <= 0 { - return fmt.Errorf("RenewPoolConfig.Expiry must be a positive duration") // 过期时间必须是正值 - } +// SetStorage sets storage adapter | 设置存储适配器 +func (b *Builder) SetStorage(storage adapter.Storage) *Builder { + b.storage = storage + return b +} + +// SetCodec sets the codec for encoding and decoding operations | 设置编解码器适配器 +func (b *Builder) SetCodec(codec adapter.Codec) *Builder { + b.codec = codec + return b +} + +// SetLog sets the log adapter for logging operations | 设置日志记录适配器 +func (b *Builder) SetLog(log adapter.Log) *Builder { + b.log = log + return b +} + +// SetPool sets the goroutine pool for async task execution | 设置用于异步任务执行的协程池 +func (b *Builder) SetPool(pool adapter.Pool) *Builder { + b.pool = pool + return b +} + +// SetCustomPermissionListFunc sets the custom permission provider | 设置自定义权限获取函数 +func (b *Builder) SetCustomPermissionListFunc(f func(loginID, authType string) ([]string, error)) *Builder { + b.customPermissionListFunc = f + return b +} + +// SetCustomRoleListFunc sets the custom role provider | 设置自定义角色获取函数 +func (b *Builder) SetCustomRoleListFunc(f func(loginID, authType string) ([]string, error)) *Builder { + b.customRoleListFunc = f + return b +} + +// Jwt sets TokenStyle to JWT and sets secret key | 设置为JWT模式并指定密钥 +func (b *Builder) Jwt(secret string) *Builder { + b.tokenStyle = adapter.TokenStyleJWT + b.jwtSecretKey = secret + return b +} + +// Clone creates a deep copy of the builder | 克隆当前构建器 +func (b *Builder) Clone() *Builder { + clone := *b + + // Deep copy for cookieConfig + if b.cookieConfig != nil { + cookieCopy := *b.cookieConfig + clone.cookieConfig = &cookieCopy + } + + // Deep copy for renewPoolConfig + if b.renewPoolConfig != nil { + poolCopy := *b.renewPoolConfig + clone.renewPoolConfig = &poolCopy } - return nil + // Deep copy for logConfig + if b.logConfig != nil { + logCopy := *b.logConfig + clone.logConfig = &logCopy + } + + return &clone } // Build builds Manager and prints startup banner | 构建Manager并打印启动Banner func (b *Builder) Build() *manager.Manager { - // Validate configuration | 验证配置 - if err := b.Validate(); err != nil { - panic(fmt.Sprintf("invalid configuration: %v", err)) - } - - // Automatically adjust MaxRefresh if user customized Timeout but didn't set MaxRefresh | 自动调整MaxRefresh逻辑 - if b.timeout != config.DefaultTimeout && b.maxRefresh == config.DefaultTimeout/2 { - b.maxRefresh = b.timeout / 2 + // 如果为cookieConfig为nil 则初始化默认cookieConfig + if b.cookieConfig == nil { + b.cookieConfig = config.DefaultCookieConfig() } + // Init config | 初始化config cfg := &config.Config{ TokenName: b.tokenName, Timeout: b.timeout, @@ -378,7 +575,6 @@ func (b *Builder) Build() *manager.Manager { IsReadHeader: b.isReadHeader, IsReadCookie: b.isReadCookie, TokenStyle: b.tokenStyle, - DataRefreshPeriod: b.dataRefreshPeriod, TokenSessionCheckLogin: b.tokenSessionCheckLogin, AutoRenew: b.autoRenew, JwtSecretKey: b.jwtSecretKey, @@ -386,24 +582,83 @@ func (b *Builder) Build() *manager.Manager { IsPrintBanner: b.isPrintBanner, KeyPrefix: b.keyPrefix, CookieConfig: b.cookieConfig, - RenewPoolConfig: b.renewPoolConfig, + AuthType: b.authType, } - // Print startup banner with full configuration | 打印启动Banner和完整配置 - // Only skip printing when both IsLog=false AND IsPrintBanner=false | 只有当 IsLog=false 且 IsPrintBanner=false 时才不打印 - if b.isPrintBanner || b.isLog { - banner.PrintWithConfig(cfg) + // 验证基础配置 + err := cfg.Validate() + if err != nil { + panic("Invalid config: " + err.Error()) } - mgr := manager.NewManager(b.storage, cfg) + // 如果generator为nil,则初始化默认generator + if b.generator == nil { + b.generator = sgenerator.NewGenerator(b.timeout, b.tokenStyle, b.jwtSecretKey) + } + // 如果storage为nil,则初始化默认storage + if b.storage == nil { + b.storage = memory.NewStorage() + } + // 如果codec为nil,则初始化默认codec + if b.codec == nil { + b.codec = codec_json.NewJSONSerializer() + } - // Note: If you use the stputil package, it will automatically set the global Manager | 注意:如果你使用了 stputil 包,它会自动设置全局 Manager - // We don't directly call stputil.SetManager here to avoid hard dependencies | 这里不直接调用 stputil.SetManager,避免强依赖 + // 日志 + if b.isLog { + if b.log == nil { + if b.logConfig == nil { + b.logConfig = slog.DefaultLoggerConfig() + } + b.log, err = slog.NewLoggerWithConfig(b.logConfig) + if err != nil { + panic("Invalid LoggerConfig: " + err.Error()) + } + } + } else { + b.log = nop.NewNopLogger() + } - return mgr -} + // 续期池 + if b.autoRenew { + if b.pool == nil { + if b.renewPoolConfig == nil { + b.renewPoolConfig = ants.DefaultRenewPoolConfig() + } + err = b.renewPoolConfig.Validate() + if err != nil { + panic("Invalid RenewPoolConfig: " + err.Error()) + } + b.pool, err = ants.NewRenewPoolManagerWithConfig(b.renewPoolConfig) + if err != nil { + panic(err) + } + } + + // 续期池状态的打印 + if b.renewPoolConfig.PrintStatusInterval > 0 { + ticker := time.NewTicker(b.renewPoolConfig.PrintStatusInterval) + go func() { + defer ticker.Stop() + for { + select { + case <-ticker.C: + running, capacity, usage := b.pool.Stats() + b.log.Infof( + "RenewPool Status: Capacity=%d, Running=%d, Usage=%.2f%%", + capacity, running, usage*100, + ) + } + } + }() + } + } + + // Print startup banner with full configuration | 打印启动Banner和完整配置 + if b.isPrintBanner { + banner.PrintWithConfig(cfg) + } -// MustBuild builds Manager and panics if validation fails | 构建Manager,验证失败时panic -func (b *Builder) MustBuild() *manager.Manager { - return b.Build() + // Build Manager | 构建 Manager + return manager.NewManager(cfg, b.generator, b.storage, b.codec, b.log, b.pool, b.customPermissionListFunc, b.customRoleListFunc) } diff --git a/core/config/config.go b/core/config/config.go index 56ccad6..eba7141 100644 --- a/core/config/config.go +++ b/core/config/config.go @@ -2,131 +2,79 @@ package config import ( "fmt" - "github.com/click33/sa-token-go/core/pool" -) - -// TokenStyle Token generation style | Token生成风格 -type TokenStyle string - -const ( - // TokenStyleUUID UUID style | UUID风格 - TokenStyleUUID TokenStyle = "uuid" - // TokenStyleSimple Simple random string | 简单随机字符串 - TokenStyleSimple TokenStyle = "simple" - // TokenStyleRandom32 32-bit random string | 32位随机字符串 - TokenStyleRandom32 TokenStyle = "random32" - // TokenStyleRandom64 64-bit random string | 64位随机字符串 - TokenStyleRandom64 TokenStyle = "random64" - // TokenStyleRandom128 128-bit random string | 128位随机字符串 - TokenStyleRandom128 TokenStyle = "random128" - // TokenStyleJWT JWT style | JWT风格 - TokenStyleJWT TokenStyle = "jwt" - // TokenStyleHash SHA256 hash-based style | SHA256哈希风格 - TokenStyleHash TokenStyle = "hash" - // TokenStyleTimestamp Timestamp-based style | 时间戳风格 - TokenStyleTimestamp TokenStyle = "timestamp" - // TokenStyleTik Short ID style (like TikTok) | Tik风格短ID(类似抖音) - TokenStyleTik TokenStyle = "tik" -) - -// SameSiteMode Cookie SameSite attribute values | Cookie的SameSite属性值 -type SameSiteMode string + "strings" -const ( - // SameSiteStrict Strict mode | 严格模式 - SameSiteStrict SameSiteMode = "Strict" - // SameSiteLax Lax mode | 宽松模式 - SameSiteLax SameSiteMode = "Lax" - // SameSiteNone None mode | 无限制模式 - SameSiteNone SameSiteMode = "None" + "github.com/click33/sa-token-go/core/adapter" ) -// Default configuration constants | 默认配置常量 -const ( - DefaultTokenName = "satoken" - DefaultTimeout = 2592000 // 30 days in seconds | 30天(秒) - DefaultMaxLoginCount = 12 // Maximum concurrent logins | 最大并发登录数 - DefaultCookiePath = "/" - NoLimit = -1 // No limit flag | 不限制标志 -) - -// IsValid checks if the TokenStyle is valid | 检查TokenStyle是否有效 -func (ts TokenStyle) IsValid() bool { - switch ts { - case TokenStyleUUID, TokenStyleSimple, TokenStyleRandom32, - TokenStyleRandom64, TokenStyleRandom128, TokenStyleJWT, - TokenStyleHash, TokenStyleTimestamp, TokenStyleTik: - return true - default: - return false - } -} - // Config Sa-Token configuration | Sa-Token配置 type Config struct { // TokenName Token name (also used as Cookie name) | Token名称(同时也是Cookie名称) TokenName string - // Timeout Token expiration time in seconds, -1 for never expire | Token超时时间(单位:秒,-1代表永不过期) + // Timeout Token expiration time (in seconds); -1 means never expire | Token超时时间(单位:秒,-1代表永不过期) Timeout int64 - // MaxRefresh Threshold for triggering async token renewal (in seconds) | Token自动续期触发阈值(单位:秒,当剩余有效期低于该值时触发异步续期 -1或0代表不限制) + // MaxRefresh Threshold (in seconds) to trigger async token renewal; when remaining lifetime is below this, renewal is triggered; -1 means no limit | Token自动续期触发阈值(单位:秒,当剩余有效期低于该值时触发异步续期,-1代表不限制) + // 注意此配置与 RenewInterval 配置关系 MaxRefresh int64 - // RenewInterval Minimum interval between token renewals (ms) | Token最小续期间隔(单位:秒,同一个Token在此时间内只会续期一次 -1或0代表不限制) + // RenewInterval Minimum interval (in seconds) between two renewals for the same token; -1 means no limit | 同一Token两次续期的最小间隔时间(单位:秒,-1代表不限制) + // 注意此配置与 MaxRefresh 配置关系 RenewInterval int64 - // ActiveTimeout Token minimum activity frequency in seconds. If Token is not accessed for this time, it will be frozen. -1 means no limit | Token最低活跃频率(单位:秒),如果Token超过此时间没有访问,则会被冻结。-1代表不限制,永不冻结 + // ActiveTimeout Maximum inactivity duration (in seconds); if the Token is not accessed within this time, it will be frozen. -1 means no limit | Token最大不活跃时长(单位:秒),超过此时间未访问则被踢出,-1代表不限制 + // 注意此配置与 MaxRefresh、RenewInterval 的配置关系 此配置目前只判断续期时更新的TokenInfo里面的ActiveTime ActiveTimeout int64 - // IsConcurrent Allow concurrent login for the same account (true=allow concurrent login, false=new login kicks out old login) | 是否允许同一账号并发登录(为true时允许一起登录,为false时新登录挤掉旧登录) + // IsConcurrent Allow concurrent login for the same account (true=allow, false=new login kicks old) | 是否允许同一账号并发登录(true=允许并发,false=新登录挤掉旧登录) + // 注意此配置与 IsShare 的配置关系 IsConcurrent bool - // IsShare Share the same Token for concurrent logins (true=share one Token, false=create new Token for each login) | 在多人登录同一账号时,是否共用一个Token(为true时所有登录共用一个Token,为false时每次登录新建一个Token) + // IsShare Share the same Token for concurrent logins (true=share one, false=create new for each login) | 并发登录是否共用同一个Token(true=共用一个,false=每次登录新建一个) + // 注意此配置与 IsConcurrent 的配置关系 IsShare bool - // MaxLoginCount Maximum number of concurrent logins for the same account, -1 means no limit (only effective when IsConcurrent=true and IsShare=false) | 同一账号最大登录数量,-1代表不限(只有在IsConcurrent=true,IsShare=false时此配置才有效) - MaxLoginCount int + // MaxLoginCount Maximum concurrent login count for the same account; -1 means unlimited (only effective when IsConcurrent=true and IsShare=false) | 同一账号最大登录数量,-1代表不限(仅当IsConcurrent=true且IsShare=false时生效) + // // 注意此配置与 IsConcurrent、IsShare 的配置关系 (仅当IsConcurrent=true且IsShare=false时生效) + MaxLoginCount int64 - // IsReadBody Try to read Token from request body (default: false) | 是否尝试从请求体里读取Token(默认:false) + // IsReadBody Try to read Token from the request body (log: false) | 是否尝试从请求体读取Token(默认:false) IsReadBody bool - // IsReadHeader Try to read Token from HTTP Header (default: true, recommended) | 是否尝试从Header里读取Token(默认:true,推荐) + // IsReadHeader Try to read Token from the HTTP Header (log: true, recommended) | 是否尝试从Header读取Token(默认:true,推荐) IsReadHeader bool - // IsReadCookie Try to read Token from Cookie (default: false) | 是否尝试从Cookie里读取Token(默认:false) + // IsReadCookie Try to read Token from the Cookie (log: false) | 是否尝试从Cookie读取Token(默认:false) IsReadCookie bool - // TokenStyle Token generation style | Token风格 - TokenStyle TokenStyle + // TokenStyle Token generation style | Token生成风格 + TokenStyle adapter.TokenStyle - // DataRefreshPeriod Auto-refresh period in seconds, -1 means no auto-refresh | 自动续签(单位:秒),-1代表不自动续签 - DataRefreshPeriod int64 - - // TokenSessionCheckLogin Check if Token-Session is kicked out when logging in (true=check on login, false=skip check) | Token-Session在登录时是否检查(true=登录时验证是否被踢下线,false=不作此检查) + // TokenSessionCheckLogin Whether to check if Token-Session is kicked out when logging in (true=check, false=skip) | 登录时是否检查Token-Session是否被踢下线(true=检查,false=不检查) + // 注意此配置在manager相关逻辑中暂时未使用 TokenSessionCheckLogin bool - // AutoRenew Auto-renew Token expiration time on each validation | 是否自动续期(每次验证Token时,都会延长Token的有效期) + // AutoRenew Automatically renew Token expiration time on each validation | 是否在每次验证Token时自动续期(延长Token有效期) AutoRenew bool - // JwtSecretKey JWT secret key (only effective when TokenStyle=JWT) | JWT密钥(只有TokenStyle=JWT时,此配置才生效) + // JwtSecretKey Secret key for JWT mode (effective only when TokenStyle=JWT) | JWT模式的密钥(仅当TokenStyle=JWT时生效) JwtSecretKey string - // IsLog Enable operation logging | 是否输出操作日志 + // IsLog Enable operation logging | 是否开启操作日志 IsLog bool - // IsPrintBanner Print startup banner (default: true) | 是否打印启动 Banner(默认:true) + // IsPrintBanner Print the startup banner (log: true) | 是否打印启动Banner(默认:true) IsPrintBanner bool - // KeyPrefix Storage key prefix for Redis isolation (default: "satoken:") | 存储键前缀,用于Redis隔离(默认:"satoken:") - // Set to empty "" to be compatible with Java sa-token default behavior | 设置为空""以兼容Java sa-token默认行为 + // KeyPrefix Storage key prefix for Storage isolation | 存储键前缀 KeyPrefix string // CookieConfig Cookie configuration | Cookie配置 CookieConfig *CookieConfig - // RenewPoolConfig Configuration for renewal pool manager | 续期池配置 - RenewPoolConfig *pool.RenewPoolConfig + // Authentication system type | 认证体系类型 + AuthType string } // CookieConfig Cookie configuration | Cookie配置 @@ -147,10 +95,10 @@ type CookieConfig struct { SameSite SameSiteMode // MaxAge Cookie expiration time in seconds | 过期时间(单位:秒) - MaxAge int + MaxAge int64 } -// DefaultConfig Returns default configuration | 返回默认配置 +// DefaultConfig Returns log configuration | 返回默认配置 func DefaultConfig() *Config { return &Config{ TokenName: DefaultTokenName, @@ -164,104 +112,136 @@ func DefaultConfig() *Config { IsReadBody: false, IsReadHeader: true, IsReadCookie: false, - TokenStyle: TokenStyleUUID, - DataRefreshPeriod: NoLimit, + TokenStyle: adapter.TokenStyleUUID, TokenSessionCheckLogin: true, AutoRenew: true, JwtSecretKey: "", IsLog: false, IsPrintBanner: true, - KeyPrefix: "satoken:", - CookieConfig: &CookieConfig{ - Domain: "", - Path: DefaultCookiePath, - Secure: false, - HttpOnly: true, - SameSite: SameSiteLax, - MaxAge: 0, - }, + KeyPrefix: DefaultKeyPrefix, + CookieConfig: DefaultCookieConfig(), + AuthType: DefaultAuthType, } } // Validate validates the configuration | 验证配置是否合理 func (c *Config) Validate() error { - // Check TokenName + // =============== Phase 1: Basic format validation | 阶段1:基础格式验证 =============== + + // [Critical] TokenName is required for Token identification | TokenName是Token标识的必要字段 if c.TokenName == "" { return fmt.Errorf("TokenName cannot be empty") } - - // Check TokenStyle - if !c.TokenStyle.IsValid() { - return fmt.Errorf("invalid TokenStyle: %s", c.TokenStyle) + if strings.ContainsAny(c.TokenName, "\t\r\n") { + return fmt.Errorf("TokenName cannot contain tab/newline characters, got: %q", c.TokenName) + } + if len(c.TokenName) > 64 { + return fmt.Errorf("TokenName too long (max 64 chars), got length: %d", len(c.TokenName)) } - // Check JWT secret key when using JWT style - if c.TokenStyle == TokenStyleJWT && c.JwtSecretKey == "" { - return fmt.Errorf("JwtSecretKey is required when TokenStyle is JWT") + // [Critical] KeyPrefix is required for storage isolation | KeyPrefix是存储隔离的必要字段 + if c.KeyPrefix == "" { + return fmt.Errorf("KeyPrefix cannot be empty") + } + if strings.ContainsAny(c.KeyPrefix, "\t\r\n") { + return fmt.Errorf("KeyPrefix cannot contain tab/newline characters, got: %q", c.KeyPrefix) + } + if len(c.KeyPrefix) > 64 { + return fmt.Errorf("KeyPrefix too long (max 64 chars), got length: %d", len(c.KeyPrefix)) } - // Check Timeout - if c.Timeout < NoLimit { - return fmt.Errorf("Timeout must be >= -1, got: %d", c.Timeout) + // [Critical] AuthType is required for auth system identification | AuthType是认证体系标识的必要字段 + if c.AuthType == "" { + return fmt.Errorf("AuthType cannot be empty") + } + if strings.ContainsAny(c.AuthType, "\t\r\n") { + return fmt.Errorf("AuthType cannot contain tab/newline characters, got: %q", c.AuthType) } + if len(c.AuthType) > 64 { + return fmt.Errorf("AuthType too long (max 64 chars), got length: %d", len(c.AuthType)) + } + + // =============== Phase 2: Numeric range validation | 阶段2:数值范围验证 =============== - // Check MaxRefresh - if c.MaxRefresh < NoLimit { - return fmt.Errorf("MaxRefresh must be >= -1, got: %d", c.MaxRefresh) + // [Critical] Numeric fields must be valid: -1 (no limit) or >0 | 数值字段必须合法:-1(无限制)或>0 + if err := c.checkNoLimits(); err != nil { + return err } - // Check MaxRefresh does not exceed Timeout - if c.Timeout != NoLimit && c.MaxRefresh > c.Timeout { - return fmt.Errorf("MaxRefresh (%d) cannot be greater than Timeout (%d)", c.MaxRefresh, c.Timeout) + // =============== Phase 3: TokenStyle + JWT validation | 阶段3:Token风格验证 =============== + + // [Critical] TokenStyle must be valid | Token风格必须合法 + if !c.TokenStyle.IsValid() { + return fmt.Errorf("invalid TokenStyle: %s", c.TokenStyle) } - // Check RenewInterval - if c.RenewInterval < NoLimit { - return fmt.Errorf("RenewInterval must be >= -1, got: %d", c.RenewInterval) + // [Critical] JWT mode requires secret key, otherwise JWT cannot work | JWT模式必须设置密钥,否则JWT无法工作 + if c.TokenStyle == adapter.TokenStyleJWT && c.JwtSecretKey == "" { + return fmt.Errorf("JwtSecretKey is required when TokenStyle is JWT") } - // Check ActiveTimeout - if c.ActiveTimeout < NoLimit { - return fmt.Errorf("ActiveTimeout must be >= -1, got: %d", c.ActiveTimeout) + // =============== Phase 4: Auto-adjustment for critical issues | 阶段4:关键问题自动调整 =============== + + // [Critical] AutoRenew enabled but MaxRefresh > Timeout would cause token never renew | 启用续期但阈值大于超时时间会导致永远不续期 + if c.AutoRenew && c.Timeout != NoLimit && c.MaxRefresh != NoLimit && c.MaxRefresh > c.Timeout { + c.MaxRefresh = c.Timeout / 2 + if c.MaxRefresh <= 0 { + c.MaxRefresh = c.Timeout + } } - // Check MaxLoginCount - if c.MaxLoginCount < NoLimit { - return fmt.Errorf("MaxLoginCount must be >= -1, got: %d", c.MaxLoginCount) + // =============== Phase 5: Critical time relationship validation | 阶段5:关键时间关系验证 =============== + + // [Critical] RenewInterval >= ActiveTimeout would cause active users to be kicked out | 续期间隔大于等于活跃超时会导致活跃用户被踢出 + if c.AutoRenew && c.ActiveTimeout != NoLimit && c.RenewInterval != NoLimit && c.RenewInterval >= c.ActiveTimeout { + return fmt.Errorf("RenewInterval (%d) must be less than ActiveTimeout (%d), otherwise active users may be kicked out", c.RenewInterval, c.ActiveTimeout) } - // Check if at least one read source is enabled + // =============== Phase 6: Token read source validation | 阶段6:Token读取来源验证 =============== + + // [Critical] At least one read source must be enabled, otherwise Token cannot be obtained | 至少启用一个读取来源,否则无法获取Token if !c.IsReadHeader && !c.IsReadCookie && !c.IsReadBody { return fmt.Errorf("at least one of IsReadHeader, IsReadCookie, or IsReadBody must be true") } - // Validate RenewPoolConfig if set | 如果设置了续期池配置,进行验证 - if c.RenewPoolConfig != nil { - // Check MinSize and MaxSize | 检查最小和最大协程池大小 - if c.RenewPoolConfig.MinSize <= 0 { - return fmt.Errorf("RenewPoolConfig.MinSize must be > 0") // 最小协程池大小必须大于0 - } - if c.RenewPoolConfig.MaxSize < c.RenewPoolConfig.MinSize { - return fmt.Errorf("RenewPoolConfig.MaxSize must be >= RenewPoolConfig.MinSize") // 最大协程池大小必须大于等于最小协程池大小 - } + // =============== Phase 7: CookieConfig validation | 阶段7:Cookie配置验证 =============== - // Check ScaleUpRate and ScaleDownRate | 检查扩容和缩容阈值 - if c.RenewPoolConfig.ScaleUpRate <= 0 || c.RenewPoolConfig.ScaleUpRate > 1 { - return fmt.Errorf("RenewPoolConfig.ScaleUpRate must be between 0 and 1") // 扩容阈值必须在0和1之间 - } - if c.RenewPoolConfig.ScaleDownRate < 0 || c.RenewPoolConfig.ScaleDownRate > 1 { - return fmt.Errorf("RenewPoolConfig.ScaleDownRate must be between 0 and 1") // 缩容阈值必须在0和1之间 - } + // [Critical] CookieConfig required when IsReadCookie is true | 启用Cookie读取时必须设置CookieConfig + if c.IsReadCookie && c.CookieConfig == nil { + return fmt.Errorf("CookieConfig cannot be nil when IsReadCookie is true") + } - // Check CheckInterval | 检查检查间隔 - if c.RenewPoolConfig.CheckInterval <= 0 { - return fmt.Errorf("RenewPoolConfig.CheckInterval must be a positive duration") // 检查间隔必须是一个正值 + // Validate CookieConfig critical issues | 验证Cookie配置的关键问题 + if c.CookieConfig != nil { + if err := c.validateCookieConfig(); err != nil { + return err } + } - // Check Expiry | 检查过期时间 - if c.RenewPoolConfig.Expiry <= 0 { - return fmt.Errorf("RenewPoolConfig.Expiry must be a positive duration") // 过期时间必须是正值 - } + // All critical checks passed | 所有关键检查通过 + return nil +} + +// validateCookieConfig validates critical CookieConfig issues | 验证Cookie配置的关键问题 +func (c *Config) validateCookieConfig() error { + cc := c.CookieConfig + + // [Critical] Path is required for Cookie to work | Path是Cookie工作的必要字段 + if cc.Path == "" { + return fmt.Errorf("CookieConfig.Path cannot be empty") + } + + // [Critical] SameSite must be valid value | SameSite必须是合法值 + switch cc.SameSite { + case SameSiteLax, SameSiteStrict, SameSiteNone, "": + // Valid values (empty string will use browser default) | 合法值(空字符串将使用浏览器默认值) + default: + return fmt.Errorf("invalid CookieConfig.SameSite value: %v", cc.SameSite) + } + + // [Critical] Secure must be true when SameSite=None, otherwise browser will reject Cookie | SameSite=None时Secure必须为true,否则浏览器会拒绝Cookie + if cc.SameSite == SameSiteNone && !cc.Secure { + return fmt.Errorf("CookieConfig.Secure must be true when SameSite is None (browser requirement)") } return nil @@ -320,7 +300,7 @@ func (c *Config) SetIsShare(isShare bool) *Config { } // SetMaxLoginCount Set maximum login count | 设置最大登录数量 -func (c *Config) SetMaxLoginCount(count int) *Config { +func (c *Config) SetMaxLoginCount(count int64) *Config { c.MaxLoginCount = count return c } @@ -344,17 +324,11 @@ func (c *Config) SetIsReadCookie(isReadCookie bool) *Config { } // SetTokenStyle Set Token generation style | 设置Token风格 -func (c *Config) SetTokenStyle(style TokenStyle) *Config { +func (c *Config) SetTokenStyle(style adapter.TokenStyle) *Config { c.TokenStyle = style return c } -// SetDataRefreshPeriod Set data refresh period | 设置数据刷新周期 -func (c *Config) SetDataRefreshPeriod(period int64) *Config { - c.DataRefreshPeriod = period - return c -} - // SetTokenSessionCheckLogin Set whether to check token session on login | 设置登录时是否检查token会话 func (c *Config) SetTokenSessionCheckLogin(check bool) *Config { c.TokenSessionCheckLogin = check @@ -393,12 +367,54 @@ func (c *Config) SetKeyPrefix(prefix string) *Config { // SetCookieConfig Set cookie configuration | 设置Cookie配置 func (c *Config) SetCookieConfig(cookieConfig *CookieConfig) *Config { - c.CookieConfig = cookieConfig + if cookieConfig != nil { + c.CookieConfig = cookieConfig + } return c } -// SetRenewPoolConfig Set renewal pool configuration | 设置续期池配置 -func (c *Config) SetRenewPoolConfig(renewPoolConfig *pool.RenewPoolConfig) *Config { - c.RenewPoolConfig = renewPoolConfig +// SetAuthType Set authentication system type | 设置认证体系类型 +func (c *Config) SetAuthType(authType string) *Config { + c.AuthType = authType return c } + +// ============ Internal Helper Methods | 内部辅助方法 ============ + +// checkNoLimits validates that all numeric fields must be -1 (no limit) or >0 (valid) | 验证所有数值字段必须为 -1(无限制)或 >0(有效) +func (c *Config) checkNoLimits() error { + // Define fields to validate | 定义需要验证的字段 + fields := map[string]int64{ + "Timeout": c.Timeout, + "MaxRefresh": c.MaxRefresh, + "RenewInterval": c.RenewInterval, + "ActiveTimeout": c.ActiveTimeout, + "MaxLoginCount": c.MaxLoginCount, + } + + // Iterate through fields and validate each one | 遍历字段并验证 + for name, value := range fields { + // Must be -1 (no limit) or >0 (valid) | 必须为 -1(无限制)或 >0(有效) + if value == -1 || value > 0 { + continue + } + + // Return error if invalid | 若不合法则返回错误 + return fmt.Errorf("%s must be -1 (no limit) or >0 (valid), got: %d", name, value) + } + + // All numeric fields are valid | 所有数值字段均验证通过 + return nil +} + +// DefaultCookieConfig returns the log Cookie configuration | 返回默认的 Cookie 配置 +func DefaultCookieConfig() *CookieConfig { + return &CookieConfig{ + Domain: "", + Path: DefaultCookiePath, + Secure: false, + HttpOnly: true, + SameSite: SameSiteLax, + MaxAge: 0, + } +} diff --git a/core/config/config_test.go b/core/config/config_test.go new file mode 100644 index 0000000..5025234 --- /dev/null +++ b/core/config/config_test.go @@ -0,0 +1,940 @@ +package config + +import ( + "strings" + "testing" + + "github.com/click33/sa-token-go/core/adapter" +) + +// =============== Phase 1: Basic format validation tests | 阶段1:基础格式验证测试 =============== + +// TestValidate_TokenName tests TokenName validation +func TestValidate_TokenName(t *testing.T) { + tests := []struct { + name string + tokenName string + wantErr bool + errMsg string + }{ + { + name: "Valid TokenName", + tokenName: "satoken", + wantErr: false, + }, + { + name: "Valid TokenName with hyphen", + tokenName: "sa-token", + wantErr: false, + }, + { + name: "Valid TokenName with underscore", + tokenName: "sa_token", + wantErr: false, + }, + { + name: "Empty TokenName should error", + tokenName: "", + wantErr: true, + errMsg: "cannot be empty", + }, + { + name: "TokenName with tab should error", + tokenName: "sa\ttoken", + wantErr: true, + errMsg: "tab/newline", + }, + { + name: "TokenName with newline should error", + tokenName: "sa\ntoken", + wantErr: true, + errMsg: "tab/newline", + }, + { + name: "TokenName with carriage return should error", + tokenName: "sa\rtoken", + wantErr: true, + errMsg: "tab/newline", + }, + { + name: "TokenName too long should error", + tokenName: strings.Repeat("a", 65), + wantErr: true, + errMsg: "too long", + }, + { + name: "TokenName at max length should pass", + tokenName: strings.Repeat("a", 64), + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := DefaultConfig() + cfg.TokenName = tt.tokenName + err := cfg.Validate() + if tt.wantErr { + if err == nil { + t.Errorf("Expected error containing %q, got nil", tt.errMsg) + } else if !strings.Contains(err.Error(), tt.errMsg) { + t.Errorf("Expected error containing %q, got %q", tt.errMsg, err.Error()) + } + } else { + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + } + }) + } +} + +// TestValidate_KeyPrefix tests KeyPrefix validation +func TestValidate_KeyPrefix(t *testing.T) { + tests := []struct { + name string + keyPrefix string + wantErr bool + errMsg string + }{ + { + name: "Valid KeyPrefix", + keyPrefix: "satoken:", + wantErr: false, + }, + { + name: "Valid KeyPrefix with colon", + keyPrefix: "app:satoken:", + wantErr: false, + }, + { + name: "Empty KeyPrefix should error", + keyPrefix: "", + wantErr: true, + errMsg: "cannot be empty", + }, + { + name: "KeyPrefix with tab should error", + keyPrefix: "sa\ttoken:", + wantErr: true, + errMsg: "tab/newline", + }, + { + name: "KeyPrefix with newline should error", + keyPrefix: "sa\ntoken:", + wantErr: true, + errMsg: "tab/newline", + }, + { + name: "KeyPrefix too long should error", + keyPrefix: strings.Repeat("a", 65), + wantErr: true, + errMsg: "too long", + }, + { + name: "KeyPrefix at max length should pass", + keyPrefix: strings.Repeat("a", 64), + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := DefaultConfig() + cfg.KeyPrefix = tt.keyPrefix + err := cfg.Validate() + if tt.wantErr { + if err == nil { + t.Errorf("Expected error containing %q, got nil", tt.errMsg) + } else if !strings.Contains(err.Error(), tt.errMsg) { + t.Errorf("Expected error containing %q, got %q", tt.errMsg, err.Error()) + } + } else { + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + } + }) + } +} + +// TestValidate_AuthType tests AuthType validation +func TestValidate_AuthType(t *testing.T) { + tests := []struct { + name string + authType string + wantErr bool + errMsg string + }{ + { + name: "Valid AuthType", + authType: "login", + wantErr: false, + }, + { + name: "Empty AuthType should error", + authType: "", + wantErr: true, + errMsg: "cannot be empty", + }, + { + name: "AuthType with tab should error", + authType: "auth\ttype", + wantErr: true, + errMsg: "tab/newline", + }, + { + name: "AuthType too long should error", + authType: strings.Repeat("a", 65), + wantErr: true, + errMsg: "too long", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := DefaultConfig() + cfg.AuthType = tt.authType + err := cfg.Validate() + if tt.wantErr { + if err == nil { + t.Errorf("Expected error containing %q, got nil", tt.errMsg) + } else if !strings.Contains(err.Error(), tt.errMsg) { + t.Errorf("Expected error containing %q, got %q", tt.errMsg, err.Error()) + } + } else { + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + } + }) + } +} + +// =============== Phase 2: Numeric range validation tests | 阶段2:数值范围验证测试 =============== + +// TestValidate_NumericFields tests numeric field validation +func TestValidate_NumericFields(t *testing.T) { + tests := []struct { + name string + fieldName string + setValue func(*Config, int64) + value int64 + wantErr bool + }{ + // Timeout tests + {"Timeout valid positive", "Timeout", func(c *Config, v int64) { c.Timeout = v }, 3600, false}, + {"Timeout NoLimit", "Timeout", func(c *Config, v int64) { c.Timeout = v }, NoLimit, false}, + {"Timeout zero should error", "Timeout", func(c *Config, v int64) { c.Timeout = v }, 0, true}, + {"Timeout negative should error", "Timeout", func(c *Config, v int64) { c.Timeout = v }, -2, true}, + + // MaxRefresh tests + {"MaxRefresh valid positive", "MaxRefresh", func(c *Config, v int64) { c.MaxRefresh = v }, 1800, false}, + {"MaxRefresh NoLimit", "MaxRefresh", func(c *Config, v int64) { c.MaxRefresh = v }, NoLimit, false}, + {"MaxRefresh zero should error", "MaxRefresh", func(c *Config, v int64) { c.MaxRefresh = v }, 0, true}, + + // RenewInterval tests + {"RenewInterval valid positive", "RenewInterval", func(c *Config, v int64) { c.RenewInterval = v }, 60, false}, + {"RenewInterval NoLimit", "RenewInterval", func(c *Config, v int64) { c.RenewInterval = v }, NoLimit, false}, + {"RenewInterval zero should error", "RenewInterval", func(c *Config, v int64) { c.RenewInterval = v }, 0, true}, + + // ActiveTimeout tests + {"ActiveTimeout valid positive", "ActiveTimeout", func(c *Config, v int64) { c.ActiveTimeout = v }, 1800, false}, + {"ActiveTimeout NoLimit", "ActiveTimeout", func(c *Config, v int64) { c.ActiveTimeout = v }, NoLimit, false}, + {"ActiveTimeout zero should error", "ActiveTimeout", func(c *Config, v int64) { c.ActiveTimeout = v }, 0, true}, + + // MaxLoginCount tests + {"MaxLoginCount valid positive", "MaxLoginCount", func(c *Config, v int64) { c.MaxLoginCount = v }, 5, false}, + {"MaxLoginCount NoLimit", "MaxLoginCount", func(c *Config, v int64) { c.MaxLoginCount = v }, NoLimit, false}, + {"MaxLoginCount zero should error", "MaxLoginCount", func(c *Config, v int64) { c.MaxLoginCount = v }, 0, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := DefaultConfig() + tt.setValue(cfg, tt.value) + err := cfg.Validate() + if tt.wantErr { + if err == nil { + t.Errorf("Expected error for %s=%d, got nil", tt.fieldName, tt.value) + } + } else { + if err != nil { + t.Errorf("Expected no error for %s=%d, got %v", tt.fieldName, tt.value, err) + } + } + }) + } +} + +// =============== Phase 3: TokenStyle + JWT validation tests | 阶段3:Token风格验证测试 =============== + +// TestValidate_TokenStyle tests TokenStyle validation +func TestValidate_TokenStyle(t *testing.T) { + tests := []struct { + name string + tokenStyle adapter.TokenStyle + jwtSecret string + wantErr bool + errMsg string + }{ + { + name: "Valid UUID style", + tokenStyle: adapter.TokenStyleUUID, + wantErr: false, + }, + { + name: "Valid Simple style", + tokenStyle: adapter.TokenStyleSimple, + wantErr: false, + }, + { + name: "Valid Random32 style", + tokenStyle: adapter.TokenStyleRandom32, + wantErr: false, + }, + { + name: "Valid Random64 style", + tokenStyle: adapter.TokenStyleRandom64, + wantErr: false, + }, + { + name: "Valid Random128 style", + tokenStyle: adapter.TokenStyleRandom128, + wantErr: false, + }, + { + name: "Valid JWT style with secret", + tokenStyle: adapter.TokenStyleJWT, + jwtSecret: "my-secret-key", + wantErr: false, + }, + { + name: "JWT style without secret should error", + tokenStyle: adapter.TokenStyleJWT, + jwtSecret: "", + wantErr: true, + errMsg: "JwtSecretKey is required", + }, + { + name: "Invalid TokenStyle should error", + tokenStyle: adapter.TokenStyle("invalid"), + wantErr: true, + errMsg: "invalid TokenStyle", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := DefaultConfig() + cfg.TokenStyle = tt.tokenStyle + cfg.JwtSecretKey = tt.jwtSecret + err := cfg.Validate() + if tt.wantErr { + if err == nil { + t.Errorf("Expected error containing %q, got nil", tt.errMsg) + } else if !strings.Contains(err.Error(), tt.errMsg) { + t.Errorf("Expected error containing %q, got %q", tt.errMsg, err.Error()) + } + } else { + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + } + }) + } +} + +// =============== Phase 4: Auto-adjustment tests | 阶段4:自动调整测试 =============== + +// TestValidate_AutoAdjustMaxRefresh tests auto-adjustment of MaxRefresh +func TestValidate_AutoAdjustMaxRefresh(t *testing.T) { + tests := []struct { + name string + timeout int64 + maxRefresh int64 + autoRenew bool + expectedMaxRefresh int64 + }{ + { + name: "MaxRefresh exceeds Timeout - should adjust to Timeout/2", + timeout: 3600, + maxRefresh: 7200, + autoRenew: true, + expectedMaxRefresh: 1800, + }, + { + name: "MaxRefresh within Timeout - should not change", + timeout: 3600, + maxRefresh: 1800, + autoRenew: true, + expectedMaxRefresh: 1800, + }, + { + name: "AutoRenew disabled - should not adjust", + timeout: 3600, + maxRefresh: 7200, + autoRenew: false, + expectedMaxRefresh: 7200, + }, + { + name: "Timeout is NoLimit - should not adjust", + timeout: NoLimit, + maxRefresh: 7200, + autoRenew: true, + expectedMaxRefresh: 7200, + }, + { + name: "MaxRefresh is NoLimit - should not adjust", + timeout: 3600, + maxRefresh: NoLimit, + autoRenew: true, + expectedMaxRefresh: NoLimit, + }, + { + name: "Very small Timeout - MaxRefresh should equal Timeout", + timeout: 1, + maxRefresh: 3600, + autoRenew: true, + expectedMaxRefresh: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := DefaultConfig() + cfg.Timeout = tt.timeout + cfg.MaxRefresh = tt.maxRefresh + cfg.AutoRenew = tt.autoRenew + + err := cfg.Validate() + if err != nil { + t.Fatalf("Validate() error = %v, want nil", err) + } + + if cfg.MaxRefresh != tt.expectedMaxRefresh { + t.Errorf("MaxRefresh = %d, want %d", cfg.MaxRefresh, tt.expectedMaxRefresh) + } + }) + } +} + +// =============== Phase 5: Time relationship validation tests | 阶段5:时间关系验证测试 =============== + +// TestValidate_RenewIntervalVsActiveTimeout tests RenewInterval vs ActiveTimeout validation +func TestValidate_RenewIntervalVsActiveTimeout(t *testing.T) { + tests := []struct { + name string + autoRenew bool + activeTimeout int64 + renewInterval int64 + wantErr bool + }{ + { + name: "RenewInterval < ActiveTimeout - should pass", + autoRenew: true, + activeTimeout: 3600, + renewInterval: 1800, + wantErr: false, + }, + { + name: "RenewInterval = ActiveTimeout - should error", + autoRenew: true, + activeTimeout: 3600, + renewInterval: 3600, + wantErr: true, + }, + { + name: "RenewInterval > ActiveTimeout - should error", + autoRenew: true, + activeTimeout: 1800, + renewInterval: 3600, + wantErr: true, + }, + { + name: "AutoRenew disabled - should pass even if RenewInterval >= ActiveTimeout", + autoRenew: false, + activeTimeout: 1800, + renewInterval: 3600, + wantErr: false, + }, + { + name: "ActiveTimeout is NoLimit - should pass", + autoRenew: true, + activeTimeout: NoLimit, + renewInterval: 3600, + wantErr: false, + }, + { + name: "RenewInterval is NoLimit - should pass", + autoRenew: true, + activeTimeout: 3600, + renewInterval: NoLimit, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := DefaultConfig() + cfg.AutoRenew = tt.autoRenew + cfg.ActiveTimeout = tt.activeTimeout + cfg.RenewInterval = tt.renewInterval + + err := cfg.Validate() + if tt.wantErr { + if err == nil { + t.Error("Expected error, got nil") + } + } else { + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + } + }) + } +} + +// =============== Phase 6: Token read source validation tests | 阶段6:Token读取来源验证测试 =============== + +// TestValidate_TokenReadSources tests token read source validation +func TestValidate_TokenReadSources(t *testing.T) { + tests := []struct { + name string + isReadHeader bool + isReadCookie bool + isReadBody bool + wantErr bool + }{ + { + name: "Only Header enabled", + isReadHeader: true, + isReadCookie: false, + isReadBody: false, + wantErr: false, + }, + { + name: "Only Cookie enabled", + isReadHeader: false, + isReadCookie: true, + isReadBody: false, + wantErr: false, + }, + { + name: "Only Body enabled", + isReadHeader: false, + isReadCookie: false, + isReadBody: true, + wantErr: false, + }, + { + name: "All enabled", + isReadHeader: true, + isReadCookie: true, + isReadBody: true, + wantErr: false, + }, + { + name: "None enabled - should error", + isReadHeader: false, + isReadCookie: false, + isReadBody: false, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := DefaultConfig() + cfg.IsReadHeader = tt.isReadHeader + cfg.IsReadCookie = tt.isReadCookie + cfg.IsReadBody = tt.isReadBody + + err := cfg.Validate() + if tt.wantErr { + if err == nil { + t.Error("Expected error, got nil") + } + } else { + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + } + }) + } +} + +// =============== Phase 7: CookieConfig validation tests | 阶段7:Cookie配置验证测试 =============== + +// TestValidate_CookieConfig tests CookieConfig validation +func TestValidate_CookieConfig(t *testing.T) { + tests := []struct { + name string + isReadCookie bool + cookieConfig *CookieConfig + wantErr bool + errMsg string + }{ + { + name: "Valid CookieConfig with IsReadCookie=true", + isReadCookie: true, + cookieConfig: DefaultCookieConfig(), + wantErr: false, + }, + { + name: "Nil CookieConfig with IsReadCookie=true - should error", + isReadCookie: true, + cookieConfig: nil, + wantErr: true, + errMsg: "CookieConfig cannot be nil", + }, + { + name: "Nil CookieConfig with IsReadCookie=false - should pass", + isReadCookie: false, + cookieConfig: nil, + wantErr: false, + }, + { + name: "Empty Path - should error", + isReadCookie: true, + cookieConfig: &CookieConfig{Path: "", SameSite: SameSiteLax}, + wantErr: true, + errMsg: "Path cannot be empty", + }, + { + name: "Valid SameSite Lax", + isReadCookie: true, + cookieConfig: &CookieConfig{Path: "/", SameSite: SameSiteLax}, + wantErr: false, + }, + { + name: "Valid SameSite Strict", + isReadCookie: true, + cookieConfig: &CookieConfig{Path: "/", SameSite: SameSiteStrict}, + wantErr: false, + }, + { + name: "Valid SameSite None with Secure=true", + isReadCookie: true, + cookieConfig: &CookieConfig{Path: "/", SameSite: SameSiteNone, Secure: true}, + wantErr: false, + }, + { + name: "SameSite None with Secure=false - should error", + isReadCookie: true, + cookieConfig: &CookieConfig{Path: "/", SameSite: SameSiteNone, Secure: false}, + wantErr: true, + errMsg: "Secure must be true when SameSite is None", + }, + { + name: "Invalid SameSite value - should error", + isReadCookie: true, + cookieConfig: &CookieConfig{Path: "/", SameSite: SameSiteMode("Invalid")}, + wantErr: true, + errMsg: "invalid CookieConfig.SameSite", + }, + { + name: "Empty SameSite - should pass (browser default)", + isReadCookie: true, + cookieConfig: &CookieConfig{Path: "/", SameSite: ""}, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := DefaultConfig() + cfg.IsReadCookie = tt.isReadCookie + cfg.CookieConfig = tt.cookieConfig + + err := cfg.Validate() + if tt.wantErr { + if err == nil { + t.Errorf("Expected error containing %q, got nil", tt.errMsg) + } else if !strings.Contains(err.Error(), tt.errMsg) { + t.Errorf("Expected error containing %q, got %q", tt.errMsg, err.Error()) + } + } else { + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + } + }) + } +} + +// =============== Clone tests | Clone测试 =============== + +// TestConfig_Clone tests configuration cloning +func TestConfig_Clone(t *testing.T) { + t.Run("Clone creates independent copy", func(t *testing.T) { + original := DefaultConfig() + original.Timeout = 3600 + original.TokenName = "custom-token" + + cloned := original.Clone() + + // Verify values are copied + if cloned.Timeout != original.Timeout { + t.Errorf("Cloned Timeout = %d, want %d", cloned.Timeout, original.Timeout) + } + if cloned.TokenName != original.TokenName { + t.Errorf("Cloned TokenName = %s, want %s", cloned.TokenName, original.TokenName) + } + + // Verify it's a deep copy + cloned.Timeout = 7200 + if cloned.Timeout == original.Timeout { + t.Error("Clone should be independent of original") + } + }) + + t.Run("Clone deep copies CookieConfig", func(t *testing.T) { + original := DefaultConfig() + original.CookieConfig.Domain = "example.com" + + cloned := original.Clone() + + // Verify CookieConfig is copied + if cloned.CookieConfig.Domain != original.CookieConfig.Domain { + t.Errorf("Cloned CookieConfig.Domain = %s, want %s", cloned.CookieConfig.Domain, original.CookieConfig.Domain) + } + + // Verify CookieConfig is independent + cloned.CookieConfig.Domain = "other.com" + if cloned.CookieConfig.Domain == original.CookieConfig.Domain { + t.Error("Cloned CookieConfig should be independent of original") + } + }) + + t.Run("Clone handles nil CookieConfig", func(t *testing.T) { + original := DefaultConfig() + original.CookieConfig = nil + + cloned := original.Clone() + + if cloned.CookieConfig != nil { + t.Error("Cloned CookieConfig should be nil when original is nil") + } + }) +} + +// =============== Setter chain tests | 链式设置测试 =============== + +// TestConfig_SetterChain tests that setters return *Config for chaining +func TestConfig_SetterChain(t *testing.T) { + cfg := DefaultConfig(). + SetTokenName("my-token"). + SetTimeout(7200). + SetMaxRefresh(3600). + SetRenewInterval(60). + SetActiveTimeout(1800). + SetIsConcurrent(true). + SetIsShare(false). + SetMaxLoginCount(5). + SetIsReadBody(true). + SetIsReadHeader(true). + SetIsReadCookie(false). + SetTokenStyle(adapter.TokenStyleSimple). + SetTokenSessionCheckLogin(true). + SetJwtSecretKey("secret"). + SetAutoRenew(true). + SetIsLog(true). + SetIsPrintBanner(false). + SetKeyPrefix("app:"). + SetAuthType("login") + + // Verify values + if cfg.TokenName != "my-token" { + t.Errorf("TokenName = %s, want my-token", cfg.TokenName) + } + if cfg.Timeout != 7200 { + t.Errorf("Timeout = %d, want 7200", cfg.Timeout) + } + if cfg.MaxRefresh != 3600 { + t.Errorf("MaxRefresh = %d, want 3600", cfg.MaxRefresh) + } + if cfg.MaxLoginCount != 5 { + t.Errorf("MaxLoginCount = %d, want 5", cfg.MaxLoginCount) + } + if !cfg.IsReadBody { + t.Error("IsReadBody should be true") + } + if cfg.TokenStyle != adapter.TokenStyleSimple { + t.Errorf("TokenStyle = %s, want simple", cfg.TokenStyle) + } + if !cfg.IsLog { + t.Error("IsLog should be true") + } + if cfg.IsPrintBanner { + t.Error("IsPrintBanner should be false") + } +} + +// =============== DefaultConfig tests | 默认配置测试 =============== + +// TestDefaultConfig tests default configuration values +func TestDefaultConfig(t *testing.T) { + cfg := DefaultConfig() + + if cfg.TokenName != DefaultTokenName { + t.Errorf("TokenName = %s, want %s", cfg.TokenName, DefaultTokenName) + } + if cfg.Timeout != DefaultTimeout { + t.Errorf("Timeout = %d, want %d", cfg.Timeout, DefaultTimeout) + } + if cfg.MaxRefresh != DefaultTimeout/2 { + t.Errorf("MaxRefresh = %d, want %d", cfg.MaxRefresh, DefaultTimeout/2) + } + if cfg.RenewInterval != NoLimit { + t.Errorf("RenewInterval = %d, want %d", cfg.RenewInterval, NoLimit) + } + if cfg.ActiveTimeout != NoLimit { + t.Errorf("ActiveTimeout = %d, want %d", cfg.ActiveTimeout, NoLimit) + } + if !cfg.IsConcurrent { + t.Error("IsConcurrent should be true by default") + } + if !cfg.IsShare { + t.Error("IsShare should be true by default") + } + if cfg.MaxLoginCount != DefaultMaxLoginCount { + t.Errorf("MaxLoginCount = %d, want %d", cfg.MaxLoginCount, DefaultMaxLoginCount) + } + if cfg.IsReadBody { + t.Error("IsReadBody should be false by default") + } + if !cfg.IsReadHeader { + t.Error("IsReadHeader should be true by default") + } + if cfg.IsReadCookie { + t.Error("IsReadCookie should be false by default") + } + if cfg.TokenStyle != adapter.TokenStyleUUID { + t.Errorf("TokenStyle = %s, want uuid", cfg.TokenStyle) + } + if !cfg.TokenSessionCheckLogin { + t.Error("TokenSessionCheckLogin should be true by default") + } + if !cfg.AutoRenew { + t.Error("AutoRenew should be true by default") + } + if cfg.JwtSecretKey != "" { + t.Errorf("JwtSecretKey should be empty by default, got %s", cfg.JwtSecretKey) + } + if cfg.IsLog { + t.Error("IsLog should be false by default") + } + if !cfg.IsPrintBanner { + t.Error("IsPrintBanner should be true by default") + } + if cfg.KeyPrefix != DefaultKeyPrefix { + t.Errorf("KeyPrefix = %s, want %s", cfg.KeyPrefix, DefaultKeyPrefix) + } + if cfg.AuthType != DefaultAuthType { + t.Errorf("AuthType = %s, want %s", cfg.AuthType, DefaultAuthType) + } + if cfg.CookieConfig == nil { + t.Error("CookieConfig should not be nil by default") + } +} + +// TestDefaultCookieConfig tests default cookie configuration values +func TestDefaultCookieConfig(t *testing.T) { + cc := DefaultCookieConfig() + + if cc.Domain != "" { + t.Errorf("Domain = %s, want empty string", cc.Domain) + } + if cc.Path != DefaultCookiePath { + t.Errorf("Path = %s, want %s", cc.Path, DefaultCookiePath) + } + if cc.Secure { + t.Error("Secure should be false by default") + } + if !cc.HttpOnly { + t.Error("HttpOnly should be true by default") + } + if cc.SameSite != SameSiteLax { + t.Errorf("SameSite = %s, want %s", cc.SameSite, SameSiteLax) + } + if cc.MaxAge != 0 { + t.Errorf("MaxAge = %d, want 0", cc.MaxAge) + } +} + +// =============== Integration tests | 集成测试 =============== + +// TestValidate_DefaultConfig tests that default config passes validation +func TestValidate_DefaultConfig(t *testing.T) { + cfg := DefaultConfig() + err := cfg.Validate() + if err != nil { + t.Errorf("DefaultConfig should pass validation, got error: %v", err) + } +} + +// TestValidate_RealWorldScenarios tests common real-world configuration scenarios +func TestValidate_RealWorldScenarios(t *testing.T) { + t.Run("Web application with short session", func(t *testing.T) { + cfg := DefaultConfig(). + SetTimeout(3600). // 1 hour + SetActiveTimeout(1800). // 30 minutes inactive kick + SetRenewInterval(300). // renew every 5 minutes max + SetIsReadCookie(true). + SetIsReadHeader(true) + + err := cfg.Validate() + if err != nil { + t.Errorf("Expected valid config, got error: %v", err) + } + }) + + t.Run("API service with long-lived tokens", func(t *testing.T) { + cfg := DefaultConfig(). + SetTimeout(NoLimit). // never expire + SetActiveTimeout(NoLimit). // no inactive timeout + SetIsReadHeader(true). + SetIsReadCookie(false) + + err := cfg.Validate() + if err != nil { + t.Errorf("Expected valid config, got error: %v", err) + } + }) + + t.Run("JWT based authentication", func(t *testing.T) { + cfg := DefaultConfig(). + SetTokenStyle(adapter.TokenStyleJWT). + SetJwtSecretKey("my-super-secret-key-for-jwt-signing"). + SetTimeout(86400) // 1 day + + err := cfg.Validate() + if err != nil { + t.Errorf("Expected valid config, got error: %v", err) + } + }) + + t.Run("Multi-device login support", func(t *testing.T) { + cfg := DefaultConfig(). + SetIsConcurrent(true). + SetIsShare(false). + SetMaxLoginCount(5) + + err := cfg.Validate() + if err != nil { + t.Errorf("Expected valid config, got error: %v", err) + } + }) + + t.Run("Single device login only", func(t *testing.T) { + cfg := DefaultConfig(). + SetIsConcurrent(false). + SetIsShare(true) // must be true when IsConcurrent is false + + err := cfg.Validate() + if err != nil { + t.Errorf("Expected valid config, got error: %v", err) + } + }) +} diff --git a/core/config/consts.go b/core/config/consts.go new file mode 100644 index 0000000..ef4cbb4 --- /dev/null +++ b/core/config/consts.go @@ -0,0 +1,25 @@ +// @Author daixk 2025/12/7 15:34:00 +package config + +// SameSiteMode Cookie SameSite attribute values | Cookie的SameSite属性值 +type SameSiteMode string + +const ( + // SameSiteStrict Strict mode | 严格模式 + SameSiteStrict SameSiteMode = "Strict" + // SameSiteLax Lax mode | 宽松模式 + SameSiteLax SameSiteMode = "Lax" + // SameSiteNone None mode | 无限制模式 + SameSiteNone SameSiteMode = "None" +) + +// Default configuration constants | 默认配置常量 +const ( + DefaultTokenName = "satoken" // Default token name | 默认Token名称 + DefaultKeyPrefix = "satoken:" // Default Redis key prefix | 默认Redis键前缀 + DefaultAuthType = "auth:" // Default AuthType | 默认认证体系键前缀 + DefaultTimeout = 2592000 // 30 days (seconds) | 30天(秒) + DefaultMaxLoginCount = 12 // Maximum concurrent logins | 最大并发登录数 + DefaultCookiePath = "/" // Default cookie path | 默认Cookie路径 + NoLimit = -1 // No limit flag | 不限制标志 +) diff --git a/core/context/context.go b/core/context/context.go index aa7e0af..222e9c8 100644 --- a/core/context/context.go +++ b/core/context/context.go @@ -14,33 +14,18 @@ const ( // SaTokenContext Sa-Token context for current request | Sa-Token上下文,用于当前请求 type SaTokenContext struct { - ctx adapter.RequestContext + reqCtx adapter.RequestContext manager *manager.Manager } // NewContext creates a new Sa-Token context | 创建新的Sa-Token上下文 -func NewContext(ctx adapter.RequestContext, mgr *manager.Manager) *SaTokenContext { +func NewContext(reqCtx adapter.RequestContext, mgr *manager.Manager) *SaTokenContext { return &SaTokenContext{ - ctx: ctx, + reqCtx: reqCtx, manager: mgr, } } -// extractBearerToken 从 Authorization 头中提取 Bearer Token -func extractBearerToken(auth string) string { - auth = strings.TrimSpace(auth) - if auth == "" { - return "" - } - - // 支持大小写不敏感的 Bearer 前缀 - if len(auth) > 7 && strings.EqualFold(auth[:7], bearerPrefix) { - return strings.TrimSpace(auth[7:]) - } - - return auth -} - // GetTokenValue gets token value from current request | 获取当前请求的Token值 func (c *SaTokenContext) GetTokenValue() string { cfg := c.manager.GetConfig() @@ -48,12 +33,12 @@ func (c *SaTokenContext) GetTokenValue() string { // 1. 尝试从Header获取 if cfg.IsReadHeader { // 从自定义 token 名称的 Header 获取 - if token := strings.TrimSpace(c.ctx.GetHeader(cfg.TokenName)); token != "" { + if token := strings.TrimSpace(c.reqCtx.GetHeader(cfg.TokenName)); token != "" { return token } // 从 Authorization 头获取 - if auth := c.ctx.GetHeader(authHeader); auth != "" { + if auth := c.reqCtx.GetHeader(authHeader); auth != "" { if token := extractBearerToken(auth); token != "" { return token } @@ -62,61 +47,76 @@ func (c *SaTokenContext) GetTokenValue() string { // 2. 尝试从Cookie获取 if cfg.IsReadCookie { - if token := strings.TrimSpace(c.ctx.GetCookie(cfg.TokenName)); token != "" { + if token := strings.TrimSpace(c.reqCtx.GetCookie(cfg.TokenName)); token != "" { return token } } // 3. 尝试从Query参数获取 - if token := strings.TrimSpace(c.ctx.GetQuery(cfg.TokenName)); token != "" { + if token := strings.TrimSpace(c.reqCtx.GetQuery(cfg.TokenName)); token != "" { return token } return "" } -// IsLogin 检查当前请求是否已登录 -func (c *SaTokenContext) IsLogin() bool { - token := c.GetTokenValue() - return c.manager.IsLogin(token) +// GetRequestContext 获取原始请求上下文 +func (c *SaTokenContext) GetRequestContext() adapter.RequestContext { + return c.reqCtx } -// CheckLogin 检查登录(未登录抛出错误) -func (c *SaTokenContext) CheckLogin() error { - token := c.GetTokenValue() - return c.manager.CheckLogin(token) +// GetManager 获取管理器 +func (c *SaTokenContext) GetManager() *manager.Manager { + return c.manager } -// GetLoginID 获取当前登录ID -func (c *SaTokenContext) GetLoginID() (string, error) { - token := c.GetTokenValue() - return c.manager.GetLoginID(token) -} +//// IsLogin 检查当前请求是否已登录 +//func (c *SaTokenContext) IsLogin() bool { +// token := c.GetTokenValue() +// return c.manager-example.IsLogin(context.WithValue(c.ctx, config.CtxTokenValue, token)) +//} +// +//// CheckLogin 检查登录(未登录抛出错误) +//func (c *SaTokenContext) CheckLogin() error { +// token := c.GetTokenValue() +// return c.manager-example.CheckLogin(context.WithValue(c.ctx, config.CtxTokenValue, token)) +//} +// +//// GetLoginID 获取当前登录ID +//func (c *SaTokenContext) GetLoginID() (string, error) { +// token := c.GetTokenValue() +// return c.manager-example.GetLoginID(context.WithValue(c.ctx, config.CtxTokenValue, token)) +//} +// +//// HasPermission 检查是否有指定权限 +//func (c *SaTokenContext) HasPermission(permission string) bool { +// loginID, err := c.GetLoginID() +// if err != nil { +// return false +// } +// return c.manager-example.HasPermission(c.ctx, loginID, permission) +//} +// +//// HasRole 检查是否有指定角色 +//func (c *SaTokenContext) HasRole(role string) bool { +// loginID, err := c.GetLoginID() +// if err != nil { +// return false +// } +// return c.manager-example.HasRole(c.ctx, loginID, role) +//} -// HasPermission 检查是否有指定权限 -func (c *SaTokenContext) HasPermission(permission string) bool { - loginID, err := c.GetLoginID() - if err != nil { - return false +// extractBearerToken 从 Authorization 头中提取 Bearer Token +func extractBearerToken(auth string) string { + auth = strings.TrimSpace(auth) + if auth == "" { + return "" } - return c.manager.HasPermission(loginID, permission) -} -// HasRole 检查是否有指定角色 -func (c *SaTokenContext) HasRole(role string) bool { - loginID, err := c.GetLoginID() - if err != nil { - return false + // 支持大小写不敏感的 Bearer 前缀 + if len(auth) > 7 && strings.EqualFold(auth[:7], bearerPrefix) { + return strings.TrimSpace(auth[7:]) } - return c.manager.HasRole(loginID, role) -} - -// GetRequestContext 获取原始请求上下文 -func (c *SaTokenContext) GetRequestContext() adapter.RequestContext { - return c.ctx -} -// GetManager 获取管理器 -func (c *SaTokenContext) GetManager() *manager.Manager { - return c.manager + return auth } diff --git a/core/errors.go b/core/errors.go index 0a80c2d..332ba68 100644 --- a/core/errors.go +++ b/core/errors.go @@ -14,12 +14,21 @@ var ( // ErrNotLogin indicates the user is not logged in | 用户未登录错误 ErrNotLogin = fmt.Errorf("authentication required: user not logged in") - // ErrTokenInvalid indicates the provided token is invalid or malformed | Token无效或格式错误 + // ErrTokenInvalid indicates the provided token is invalid or malformed | Token 无效或格式错误 ErrTokenInvalid = fmt.Errorf("invalid token: the token is malformed or corrupted") - // ErrTokenExpired indicates the token has expired | Token已过期 + // ErrTokenExpired indicates the token has expired | Token 已过期 ErrTokenExpired = fmt.Errorf("token expired: please login again to get a new token") + // ErrTokenNotFound indicates the token does not exist | Token 不存在 + ErrTokenNotFound = fmt.Errorf("authentication required: token not found") + + // ErrTokenKickout indicates the token has been kicked out | Token 已被踢下线 + ErrTokenKickout = fmt.Errorf("authentication required: token has been kicked out") + + // ErrTokenReplaced indicates the token has been replaced | Token 已被顶下线 + ErrTokenReplaced = fmt.Errorf("authentication required: token has been replaced") + // ErrInvalidLoginID indicates the login ID is invalid | 登录ID无效 ErrInvalidLoginID = fmt.Errorf("invalid login ID: the login identifier cannot be empty") @@ -45,22 +54,94 @@ var ( // ErrAccountNotFound indicates the account doesn't exist | 账号不存在 ErrAccountNotFound = fmt.Errorf("account not found: no account associated with this identifier") + + // ErrLoginLimitExceeded indicates login count exceeds the maximum limit | 超出最大登录数量限制 + ErrLoginLimitExceeded = fmt.Errorf("account error: login count exceeds the maximum limit") ) // ============ Session Errors | 会话错误 ============ var ( - // ErrSessionNotFound indicates the session doesn't exist | Session不存在 + // ErrSessionNotFound indicates the session doesn't exist | Session 不存在 ErrSessionNotFound = fmt.Errorf("session not found: the session may have expired or been deleted") - // ErrKickedOut indicates the user has been kicked out | 用户已被踢下线 - ErrKickedOut = fmt.Errorf("kicked out: this session has been forcibly terminated") - - // ErrActiveTimeout indicates the session has been inactive for too long | Session活跃超时 + // ErrActiveTimeout indicates the session has been inactive for too long | Session 活跃超时 ErrActiveTimeout = fmt.Errorf("session inactive: the session has exceeded the inactivity timeout") +) + +// ============ Security Errors | Security 错误 ============ + +var ( + // ErrInvalidNonce indicates the nonce is invalid or expired | Nonce 无效或已过期 + ErrInvalidNonce = fmt.Errorf("invalid nonce: nonce is invalid or expired") + + // ErrRefreshTokenExpired indicates the refresh token has expired | 刷新令牌已过期 + ErrRefreshTokenExpired = fmt.Errorf("refresh token expired: please request a new token") - // ErrMaxLoginCount indicates maximum concurrent login limit reached | 达到最大登录数量限制 - ErrMaxLoginCount = fmt.Errorf("max login limit: maximum number of concurrent logins reached") + // ErrNonceInvalidRefreshToken indicates the refresh token is invalid | 刷新令牌无效 + ErrNonceInvalidRefreshToken = fmt.Errorf("invalid refresh token: token is malformed or does not exist") + + // ErrInvalidLoginIDEmpty indicates loginID is empty | 登录ID不能为空 + ErrInvalidLoginIDEmpty = fmt.Errorf("invalid loginID: loginID cannot be empty") +) + +// ============ OAuth2 Errors | OAuth2 错误 ============ + +var ( + // ErrClientOrClientIDEmpty indicates client or clientID is empty | 客户端或客户端ID为空 + ErrClientOrClientIDEmpty = fmt.Errorf("invalid client: clientID is required") + + // ErrClientNotFound indicates the client does not exist | 客户端不存在 + ErrClientNotFound = fmt.Errorf("client error: client not found") + + // ErrUserIDEmpty indicates userID is empty | 用户ID不能为空 + ErrUserIDEmpty = fmt.Errorf("invalid user: userID cannot be empty") + + // ErrInvalidRedirectURI indicates redirect URI is invalid | 回调URI非法 + ErrInvalidRedirectURI = fmt.Errorf("invalid redirect uri: redirectUri is not allowed") + + // ErrInvalidClientCredentials indicates incorrect client credentials | 客户端凭证无效 + ErrInvalidClientCredentials = fmt.Errorf("invalid client credentials: authentication failed") + + // ErrInvalidAuthCode indicates an invalid authorization code | 授权码无效 + ErrInvalidAuthCode = fmt.Errorf("invalid authorization code: code is malformed or does not exist") + + // ErrAuthCodeUsed indicates the authorization code has already been used | 授权码已被使用 + ErrAuthCodeUsed = fmt.Errorf("authorization code error: code already used") + + // ErrAuthCodeExpired indicates the authorization code has expired | 授权码已过期 + ErrAuthCodeExpired = fmt.Errorf("authorization code expired: please restart authorization process") + + // ErrClientMismatch indicates client mismatch | 客户端不匹配 + ErrClientMismatch = fmt.Errorf("client mismatch: clientID does not match the authorization code") + + // ErrRedirectURIMismatch indicates redirect URI mismatch | 回调URI不匹配 + ErrRedirectURIMismatch = fmt.Errorf("redirect uri mismatch: callback URL does not match registered value") + + // ErrInvalidAccessToken indicates access token invalid | 访问令牌无效 + ErrInvalidAccessToken = fmt.Errorf("invalid access token: token is malformed or expired") + + // ErrInvalidRefreshToken indicates refresh token invalid | 刷新令牌无效 + ErrInvalidRefreshToken = fmt.Errorf("invalid refresh token: token is malformed or expired") + + // ErrInvalidScope indicates requested scope is not allowed | 请求的权限范围不被允许 + ErrInvalidScope = fmt.Errorf("invalid scope: requested scope is not allowed for this client") + + // ErrInvalidGrantType indicates grant type is not allowed | 授权类型不被允许 + ErrInvalidGrantType = fmt.Errorf("invalid grant type: this grant type is not allowed for this client") + + // ErrInvalidUserCredentials indicates user credentials are incorrect | 用户凭证无效 + ErrInvalidUserCredentials = fmt.Errorf("invalid user credentials: username or password is incorrect") +) + +// ============ Session Errors | Session 错误 ============ + +var ( + // ErrSessionInvalidDataKey indicates a session data key is empty or invalid | Session 数据的 key 为空或非法 + ErrSessionInvalidDataKey = fmt.Errorf("invalid session data key: key cannot be empty") + + // ErrSessionIDEmpty indicates that a session ID is empty or missing | Session ID 为空或缺失 + ErrSessionIDEmpty = fmt.Errorf("session id cannot be empty") ) // ============ System Errors | 系统错误 ============ @@ -68,6 +149,21 @@ var ( var ( // ErrStorageUnavailable indicates the storage backend is unavailable | 存储后端不可用 ErrStorageUnavailable = fmt.Errorf("storage unavailable: unable to connect to storage backend") + + // ErrSerializeFailed indicates serialization failed | 序列化失败 + ErrSerializeFailed = fmt.Errorf("serialize failed: unable to encode data") + + // ErrDeserializeFailed indicates deserialization failed | 反序列化失败 + ErrDeserializeFailed = fmt.Errorf("deserialize failed: unable to decode data") + + // ErrTypeConvert indicates a type conversion failed | 类型转换失败 + ErrTypeConvert = fmt.Errorf("type conversion failed: unable to convert value to target type") + + // ErrManagerNotFound indicates that the manager-example was not found for the given autoType | manager-example 未找到 + ErrManagerNotFound = fmt.Errorf("manager-example not found") + + // ErrManagerInvalidType indicates that the loaded manager-example has invalid type | manager-example 类型无效 + ErrManagerInvalidType = fmt.Errorf("invalid manager-example type") ) // ============ Custom Error Type | 自定义错误类型 ============ diff --git a/core/go.mod b/core/go.mod index 5eb5cc5..dabd856 100644 --- a/core/go.mod +++ b/core/go.mod @@ -1,16 +1,3 @@ module github.com/click33/sa-token-go/core -go 1.23.0 - -require ( - github.com/golang-jwt/jwt/v5 v5.2.2 - github.com/google/uuid v1.6.0 - github.com/panjf2000/ants/v2 v2.11.3 -) - -require ( - github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect - github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect - github.com/stretchr/testify v1.11.1 // indirect - golang.org/x/sync v0.16.0 // indirect -) +go 1.25.0 \ No newline at end of file diff --git a/core/go.sum b/core/go.sum index ec90348..e69de29 100644 --- a/core/go.sum +++ b/core/go.sum @@ -1,11 +0,0 @@ -github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= -github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= -github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= -github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/panjf2000/ants/v2 v2.11.3 h1:AfI0ngBoXJmYOpDh9m516vjqoUu2sLrIVgppI9TZVpg= -github.com/panjf2000/ants/v2 v2.11.3/go.mod h1:8u92CYMUc6gyvTIw8Ru7Mt7+/ESnJahz5EVtqfrilek= -github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= -github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= -golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= -gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= -gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/core/listener/consts.go b/core/listener/consts.go new file mode 100644 index 0000000..8028570 --- /dev/null +++ b/core/listener/consts.go @@ -0,0 +1,43 @@ +// @Author daixk 2025/12/14 20:49:00 +package listener + +// Event represents the type of authentication event | 认证事件类型 +type Event string + +const ( + // EventLogin fired when a user logs in | 用户登录事件 + EventLogin Event = "login" + + // EventLogout fired when a user logs out | 用户登出事件 + EventLogout Event = "logout" + + // EventKickout fired when a user is forcibly logged out | 用户被踢下线事件 + EventKickout Event = "kickout" + + // EventReplace fired when a user is replaced by a new login | 用户被顶下线事件 + EventReplace Event = "replace" + + // EventDisable fired when an account is disabled | 账号被禁用事件 + EventDisable Event = "disable" + + // EventUntie fired when an account is re-enabled | 账号解禁事件 + EventUntie Event = "untie" + + // EventRenew fired when a token is renewed | Token续期事件 + EventRenew Event = "renew" + + // EventCreateSession fired when a new session is created | Session创建事件 + EventCreateSession Event = "createSession" + + // EventDestroySession fired when a session is destroyed | Session销毁事件 + EventDestroySession Event = "destroySession" + + // EventPermissionCheck fired when a permission check is performed | 权限检查事件 + EventPermissionCheck Event = "permissionCheck" + + // EventRoleCheck fired when a role check is performed | 角色检查事件 + EventRoleCheck Event = "roleCheck" + + // EventAll is a wildcard event that matches all events | 通配符事件(匹配所有事件) + EventAll Event = "*" +) diff --git a/core/listener/listener.go b/core/listener/listener.go index b998117..c18fdda 100644 --- a/core/listener/listener.go +++ b/core/listener/listener.go @@ -2,51 +2,16 @@ package listener import ( "fmt" + "github.com/click33/sa-token-go/core/adapter" + "github.com/click33/sa-token-go/log/nop" "sync" "time" ) -// Event represents the type of authentication event | 认证事件类型 -type Event string - -const ( - // EventLogin fired when a user logs in | 用户登录事件 - EventLogin Event = "login" - - // EventLogout fired when a user logs out | 用户登出事件 - EventLogout Event = "logout" - - // EventKickout fired when a user is forcibly logged out | 用户被踢下线事件 - EventKickout Event = "kickout" - - // EventDisable fired when an account is disabled | 账号被禁用事件 - EventDisable Event = "disable" - - // EventUntie fired when an account is re-enabled | 账号解禁事件 - EventUntie Event = "untie" - - // EventRenew fired when a token is renewed | Token续期事件 - EventRenew Event = "renew" - - // EventCreateSession fired when a new session is created | Session创建事件 - EventCreateSession Event = "createSession" - - // EventDestroySession fired when a session is destroyed | Session销毁事件 - EventDestroySession Event = "destroySession" - - // EventPermissionCheck fired when a permission check is performed | 权限检查事件 - EventPermissionCheck Event = "permissionCheck" - - // EventRoleCheck fired when a role check is performed | 角色检查事件 - EventRoleCheck Event = "roleCheck" - - // EventAll is a wildcard event that matches all events | 通配符事件(匹配所有事件) - EventAll Event = "*" -) - // EventData contains information about a triggered event | 事件数据,包含触发事件的相关信息 type EventData struct { Event Event // Event type | 事件类型 + AuthType string // Authentication system type | 认证体系类型 LoginID string // User login ID | 用户登录ID Device string // Device identifier | 设备标识 Token string // Authentication token | 认证Token @@ -56,14 +21,14 @@ type EventData struct { // String returns a string representation of the event data | 返回事件数据的字符串表示 func (e *EventData) String() string { - return fmt.Sprintf("Event{type=%s, loginID=%s, device=%s, timestamp=%d}", - e.Event, e.LoginID, e.Device, e.Timestamp) + return fmt.Sprintf("Event{type=%s,AuthType=%s, loginID=%s, device=%s, timestamp=%d}", + e.Event, e.AuthType, e.LoginID, e.Device, e.Timestamp) } // Listener is the interface for event listeners | 事件监听器接口 type Listener interface { // OnEvent is called when an event is triggered | 当事件触发时调用 - // The listener should not panic; any panic will be recovered by the event manager | 监听器不应该panic,任何panic都会被事件管理器恢复 + // The listener should not panic; any panic will be recovered by the event manager-example | 监听器不应该panic,任何panic都会被事件管理器恢复 OnEvent(data *EventData) } @@ -78,7 +43,7 @@ func (f ListenerFunc) OnEvent(data *EventData) { // ListenerConfig holds configuration for a registered listener | 监听器配置 type ListenerConfig struct { Async bool // If true, listener runs asynchronously | 如果为true,监听器异步运行 - Priority int // Higher priority listeners are called first (default: 0) | 优先级越高越先执行(默认:0) + Priority int // Higher priority listeners are called first (log: 0) | 优先级越高越先执行(默认:0) ID string // Unique identifier for this listener (for unregistering) | 监听器唯一标识(用于注销) } @@ -108,24 +73,40 @@ type Manager struct { filters []EventFilter // Global event filters | 全局事件过滤器 stats *EventStats // Event statistics | 事件统计 enableStats bool // Whether to collect statistics | 是否收集统计信息 + logger adapter.Log // Log adapter for logging operations | 日志适配器 } -// NewManager creates a new event manager | 创建新的事件管理器 -func NewManager() *Manager { - return &Manager{ - listeners: make(map[Event][]listenerEntry), - panicHandler: func(event Event, data *EventData, recovered any) { - // Default panic handler: log but don't crash | 默认panic处理器:记录日志但不崩溃 - fmt.Printf("sa-token: listener panic recovered: event=%s, panic=%v\n", event, recovered) - }, - enabledEvents: nil, // All events enabled by default | 默认启用所有事件 +// NewManager creates a new event manager-example | 创建新的事件管理器 +func NewManager(loggers ...adapter.Log) *Manager { + var logger adapter.Log + + if len(loggers) > 0 && loggers[0] != nil { + logger = loggers[0] + } else { + logger = nop.NewNopLogger() + } + + m := &Manager{ + listeners: make(map[Event][]listenerEntry), + enabledEvents: nil, // All events enabled by log | 默认启用所有事件 filters: make([]EventFilter, 0), stats: &EventStats{ EventCounts: make(map[Event]int64), LastTriggered: make(map[Event]time.Time), }, - enableStats: false, // Stats disabled by default | 默认不启用统计 + enableStats: false, // Stats disabled by log | 默认不启用统计 + logger: logger, + } + + // panicHandler 绑定“已经确定好的 logger” + m.panicHandler = func(event Event, data *EventData, recovered any) { + logger.Errorf( + "Listener listener panic recovered: event=%s, panic=%v", + event, recovered, + ) } + + return m } // SetPanicHandler sets a custom panic handler for listener errors | 设置自定义的panic处理器 @@ -234,7 +215,7 @@ func (m *Manager) IsEventEnabled(event Event) bool { return m.enabledEvents[event] || m.enabledEvents[EventAll] } -// Register registers a listener for an event with default configuration +// Register registers a listener for an event with log configuration func (m *Manager) Register(event Event, listener Listener) string { return m.RegisterWithConfig(event, listener, ListenerConfig{ Async: true, @@ -270,7 +251,7 @@ func (m *Manager) RegisterWithConfig(event Event, listener Listener, config List return config.ID } -// RegisterFunc registers a function listener with default configuration +// RegisterFunc registers a function listener with log configuration func (m *Manager) RegisterFunc(event Event, handler func(data *EventData)) string { return m.Register(event, ListenerFunc(handler)) } @@ -358,6 +339,16 @@ func (m *Manager) Trigger(data *EventData) { m.mu.RUnlock() + // 日志 + m.logger.Infof( + "Listener auth event triggered: event=%s, authType=%s, loginID=%s, device=%s, listeners=%d", + data.Event, + data.AuthType, + data.LoginID, + data.Device, + len(listenersToCall), + ) + // Execute listeners for _, entry := range listenersToCall { if entry.config.Async { diff --git a/core/manager/consts.go b/core/manager/consts.go new file mode 100644 index 0000000..8598fcb --- /dev/null +++ b/core/manager/consts.go @@ -0,0 +1,39 @@ +// @Author daixk 2025/12/4 17:58:00 +package manager + +import ( + "time" +) + +// Constants for storage keys and log values | 存储键和默认值常量 +const ( + DefaultDevice = "default" // Default device type | 默认设备类型 + DefaultNonceTTL = 5 * time.Minute // Default nonce expiration time | 默认随机令牌有效期 + + // Key prefixes | 键前缀 + TokenKeyPrefix = "token:" // Token storage prefix | Token 存储前缀 + AccountKeyPrefix = "account:" // Account storage prefix | 账号存储前缀 + DisableKeyPrefix = "disable:" // Disable state prefix | 禁用状态存储前缀 + RenewKeyPrefix = "renew:" // Token renew prefix | Token 续期存储前缀 + TokenValueListLastKey = ":*" + + // Session keys | Session 键 + SessionKeyLoginID = "loginId" // Login ID | 登录 ID + SessionKeyDevice = "device" // Device type | 设备类型 + SessionKeyLoginTime = "loginTime" // Login time | 登录时间 + SessionKeyPermissions = "permissions" // Permissions list | 权限列表 + SessionKeyRoles = "roles" // Roles list | 角色列表 + + // Wildcard for permissions | 权限通配符 + PermissionWildcard = "*" // Global permission wildcard | 全局权限通配符 + PermissionSeparator = ":" // Permission segment separator | 权限段分隔符 +) + +// TokenState 表示 Token 的逻辑状态 +type TokenState string + +const ( + TokenStateLogout TokenState = "LOGOUT" // Logout state | 主动登出 + TokenStateKickout TokenState = "KICK_OUT" // Kickout state | 被踢下线 + TokenStateReplaced TokenState = "BE_REPLACED" // Replaced state | 被顶下线 +) diff --git a/core/manager/manager.go b/core/manager/manager.go index 28c70a2..5b7b0b5 100644 --- a/core/manager/manager.go +++ b/core/manager/manager.go @@ -1,261 +1,276 @@ package manager import ( - "encoding/json" + "context" "fmt" + codec_json "github.com/click33/sa-token-go/codec/json" + "github.com/click33/sa-token-go/core" + "github.com/click33/sa-token-go/core/utils" + "github.com/click33/sa-token-go/generator/sgenerator" + "github.com/click33/sa-token-go/log/nop" + "github.com/click33/sa-token-go/pool/ants" + "github.com/click33/sa-token-go/storage/memory" "strings" "time" - "github.com/click33/sa-token-go/core/pool" - "github.com/click33/sa-token-go/core/adapter" "github.com/click33/sa-token-go/core/config" "github.com/click33/sa-token-go/core/listener" "github.com/click33/sa-token-go/core/oauth2" "github.com/click33/sa-token-go/core/security" "github.com/click33/sa-token-go/core/session" - "github.com/click33/sa-token-go/core/token" -) - -// Constants for storage keys and default values | 存储键和默认值常量 -const ( - DefaultDevice = "default" - DefaultPrefix = "satoken" - DisableValue = "1" - DefaultRenewValue = "1" - DefaultNonceTTL = 5 * time.Minute - - // Key prefixes | 键前缀 - TokenKeyPrefix = "token:" - AccountKeyPrefix = "account:" - DisableKeyPrefix = "disable:" - RenewKeyPrefix = "renew:" - - // Session keys | Session键 - SessionKeyLoginID = "loginId" - SessionKeyDevice = "device" - SessionKeyLoginTime = "loginTime" - SessionKeyPermissions = "permissions" - SessionKeyRoles = "roles" - - // Wildcard for permissions | 权限通配符 - PermissionWildcard = "*" - PermissionSeparator = ":" -) - -// TokenState 表示 Token 的逻辑状态 -type TokenState string - -const ( - TokenStateKickout TokenState = "KICK_OUT" - TokenStateReplaced TokenState = "BE_REPLACED" -) - -// Error variables | 错误变量 -var ( - ErrAccountDisabled = fmt.Errorf("account is disabled") - ErrNotLogin = fmt.Errorf("not login") - ErrTokenNotFound = fmt.Errorf("token not found") - ErrInvalidTokenData = fmt.Errorf("invalid token data") - ErrLoginLimitExceeded = fmt.Errorf("login count exceeds the maximum limit") - ErrTokenKickout = fmt.Errorf("token has been kicked out") - ErrTokenReplaced = fmt.Errorf("token has been replaced") ) -// TokenInfo Token information | Token信息 +// TokenInfo Token information | Token 信息 type TokenInfo struct { - LoginID string `json:"loginId"` - Device string `json:"device"` - CreateTime int64 `json:"createTime"` - ActiveTime int64 `json:"activeTime"` // Last active time | 最后活跃时间 - Tag string `json:"tag,omitempty"` + AuthType string `json:"authType"` // Authentication system type | 认证体系类型 + LoginID string `json:"loginId"` // Login ID | 登录 ID + Device string `json:"device"` // Device type | 设备类型 + CreateTime int64 `json:"createTime"` // Token creation timestamp | 创建时间戳 + ActiveTime int64 `json:"activeTime"` // Last active time | 最后活跃时间戳(续期时间) + Tag string `json:"tag,omitempty"` // Custom tag for additional data | 自定义标记字段(可选) } -// Manager Authentication manager | 认证管理器 +// DisableInfo Account disable information | 封禁信息结构体 +type DisableInfo struct { + DisableTime int64 `json:"disableTime"` // Disable timestamp | 封禁时间戳 + DisableReason string `json:"disableReason"` // Reason for account disable | 封禁原因 +} + +// Manager Authentication manager-example | 认证管理器 type Manager struct { - storage adapter.Storage - config *config.Config - generator *token.Generator - prefix string - nonceManager *security.NonceManager - refreshManager *security.RefreshTokenManager - oauth2Server *oauth2.OAuth2Server - renewPool *pool.RenewPoolManager - eventManager *listener.Manager -} - -// NewManager Creates a new manager | 创建管理器 -func NewManager(storage adapter.Storage, cfg *config.Config) *Manager { + config *config.Config // Global authentication configuration | 全局认证配置 + nonceManager *security.NonceManager // Nonce manager-example for preventing replay attacks | 随机串管理器 + refreshManager *security.RefreshTokenManager // Refresh token manager-example | 刷新令牌管理器 + oauth2Server *oauth2.OAuth2Server // OAuth2 authorization server | OAuth2 授权服务器 + eventManager *listener.Manager // Event manager-example | 事件管理器 + + generator adapter.Generator // Token generator | Token 生成器 + storage adapter.Storage // Storage adapter | 存储适配器 + serializer adapter.Codec // Codec adapter for encoding and decoding operations | 编解码器适配器 + logger adapter.Log // Log adapter for logging operations | 日志适配器 + pool adapter.Pool // Async task pool component | 异步任务协程池组件 + + CustomPermissionListFunc func(loginID, authType string) ([]string, error) // Custom permission func | 自定义权限获取函数 + CustomRoleListFunc func(loginID, authType string) ([]string, error) // Custom role func | 自定义角色获取函数 +} + +// NewManager creates and initializes a new Manager instance | 创建并初始化一个新的 Manager 实例 +func NewManager( + cfg *config.Config, + generator adapter.Generator, + storage adapter.Storage, + serializer adapter.Codec, + logger adapter.Log, + pool adapter.Pool, + customPermissionListFunc, CustomRoleListFunc func(loginID, authType string) ([]string, error), +) *Manager { + + // Use default configuration if cfg is nil | 如果未传入配置,则使用默认配置 if cfg == nil { cfg = config.DefaultConfig() } - // Use configured prefix, fallback to default | 使用配置的前缀,回退到默认值 - prefix := cfg.KeyPrefix - if prefix == "" { - prefix = DefaultPrefix - } - - // Initialize renew pool manager if configuration is provided | 如果配置了续期池,初始化续期池管理器 - var renewPoolManager *pool.RenewPoolManager - if cfg.RenewPoolConfig != nil { - renewPoolManager, _ = pool.NewRenewPoolManagerWithConfig(&pool.RenewPoolConfig{ - MinSize: cfg.RenewPoolConfig.MinSize, // Minimum pool size | 最小协程数 - MaxSize: cfg.RenewPoolConfig.MaxSize, // Maximum pool size | 最大协程数 - ScaleUpRate: cfg.RenewPoolConfig.ScaleUpRate, // Scale-up threshold | 扩容阈值 - ScaleDownRate: cfg.RenewPoolConfig.ScaleDownRate, // Scale-down threshold | 缩容阈值 - CheckInterval: cfg.RenewPoolConfig.CheckInterval, // Auto-scale check interval | 自动缩放检查间隔 - Expiry: cfg.RenewPoolConfig.Expiry, // Idle worker expiry duration | 空闲协程过期时间 - PrintStatusInterval: cfg.RenewPoolConfig.PrintStatusInterval, // Interval for periodic status printing | 定时打印池状态的间隔 - PreAlloc: cfg.RenewPoolConfig.PreAlloc, // Whether to pre-allocate memory | 是否预分配内存 - NonBlocking: cfg.RenewPoolConfig.NonBlocking, // Whether to use non-blocking mode | 是否为非阻塞模式 - }) + // Initialize token generator if generator is nil | 如果未传入 Token 生成器,则创建默认生成器 + if generator == nil { + generator = sgenerator.NewGenerator(cfg.Timeout, cfg.TokenStyle, cfg.JwtSecretKey) } - return &Manager{ - storage: storage, - config: cfg, - generator: token.NewGenerator(cfg), - prefix: prefix, - nonceManager: security.NewNonceManager(storage, prefix, DefaultNonceTTL), - refreshManager: security.NewRefreshTokenManager(storage, prefix, TokenKeyPrefix, cfg), - oauth2Server: oauth2.NewOAuth2Server(storage, prefix), - eventManager: listener.NewManager(), - renewPool: renewPoolManager, + // Use in-memory storage if storage is nil | 如果未传入存储实现,则使用内存存储 + if storage == nil { + storage = memory.NewStorage() } -} -// CloseManager Closes the manager and releases all resources | 关闭管理器并释放所有资源 -func (m *Manager) CloseManager() { - if m.renewPool != nil { - // 安全关闭 renewPool - m.renewPool.Stop() - m.renewPool = nil + // Use JSON serializer if serializer is nil | 如果未传入序列化器,则使用 JSON 序列化器 + if serializer == nil { + serializer = codec_json.NewJSONSerializer() } -} -// ============ Helper Methods | 辅助方法 ============ + // Use default logger if logger is nil | 如果未传入日志记录器,则使用默认日志记录器 + if logger == nil { + logger = nop.NewNopLogger() + } -// getDevice extracts device type from optional parameter | 从可选参数中提取设备类型 -func getDevice(device []string) string { - if len(device) > 0 && device[0] != "" { - return device[0] + if cfg.AutoRenew && pool == nil { + // Use default goroutine pool if pool is nil | 如果未传入协程池,则使用默认协程池 + pool = ants.NewRenewPoolManagerWithDefaultConfig() } - return DefaultDevice -} -// getExpiration calculates expiration duration from config | 从配置计算过期时间 -func (m *Manager) getExpiration() time.Duration { - if m.config.Timeout > 0 { - return time.Duration(m.config.Timeout) * time.Second + // Return the new manager-example instance with initialized sub-managers | 返回已初始化各子模块的管理器实例 + return &Manager{ + // Store global configuration | 保存全局配置 + config: cfg, + + // Token generator used for creating access/refresh tokens | 用于生成访问令牌和刷新令牌的生成器 + generator: generator, + + // Nonce manager-example for replay-attack protection | 防重放攻击的 Nonce 管理器 + nonceManager: security.NewNonceManager( + cfg.AuthType, + cfg.KeyPrefix, + storage, + DefaultNonceTTL, + ), + + // Refresh token manager-example for token renewal logic | 刷新令牌管理器,用于令牌续期逻辑 + refreshManager: security.NewRefreshTokenManager( + cfg.AuthType, + cfg.KeyPrefix, + TokenKeyPrefix, + generator, + time.Duration(cfg.Timeout)*time.Second, + storage, + serializer, + ), + + // OAuth2 server for authorization and token exchange | OAuth2 授权与令牌颁发服务 + oauth2Server: oauth2.NewOAuth2Server( + cfg.AuthType, + cfg.KeyPrefix, + storage, + serializer, + ), + + // Event manager-example for lifecycle and auth events | 生命周期与认证事件管理器 + eventManager: listener.NewManager(logger), + + // Storage adapter for persistence layer | 持久化存储适配器 + storage: storage, + + // Serializer for encoding/decoding data | 数据编解码序列化器 + serializer: serializer, + + // Logger for internal logging | 内部日志记录器 + logger: logger, + + // Goroutine pool for async task execution | 用于异步任务执行的协程池 + pool: pool, + + // Custom permission list provider | 自定义权限列表获取函数 + CustomPermissionListFunc: customPermissionListFunc, + + // Custom role list provider | 自定义角色列表获取函数 + CustomRoleListFunc: CustomRoleListFunc, } - return 0 } -// assertString safely converts interface to string | 安全地将interface转换为string -func assertString(v any) (string, bool) { - s, ok := v.(string) - return s, ok +// CloseManager Closes the manager-example and releases all resources | 关闭管理器并释放所有资源 +func (m *Manager) CloseManager() { + // Close logger if it implements LogControl | 如果日志实现了 LogControl 接口,则关闭日志 + if logControl, ok := m.logger.(adapter.LogControl); ok { + logControl.Flush() + logControl.Close() + } + + if m.pool != nil { + // Safely close the renewPool | 安全关闭 renewPool + m.pool.Stop() + // Set renewPool to nil | 将 renewPool 设置为 nil + m.pool = nil + } } -// ============ Login Authentication | 登录认证 ============ +// ============ Authentication | 登录认证 ============ -// Login Performs user login and returns token | 登录,返回Token -func (m *Manager) Login(loginID string, device ...string) (string, error) { +// Login Performs user login and returns token | 登录 返回Token +func (m *Manager) Login(ctx context.Context, loginID string, device ...string) (string, error) { // Check if account is disabled | 检查账号是否被封禁 - if m.IsDisable(loginID) { - return "", ErrAccountDisabled + if m.IsDisable(ctx, loginID) { + return "", core.ErrAccountDisabled } + // Get device type | 获取设备类型 deviceType := getDevice(device) + // Get account key | 获取账号存储键 accountKey := m.getAccountKey(loginID, deviceType) + // Get existing token list of this account | 获取该账号下所有已登录Token + existingTokenList, err := m.GetTokenValueListByLoginID(ctx, loginID) + if err != nil { + return "", err + } - // Handle shared token for concurrent login | 处理多人登录共用 Token 的情况 + // Handle shared token for concurrent login | 处理多人登录共用Token的情况 if m.config.IsShare { - // Look for existing token of this account + device | 查找账号 + 设备下是否已有登录 Token - existingToken, err := m.storage.Get(accountKey) - if err == nil && existingToken != nil { - if tokenStr, ok := assertString(existingToken); ok && m.IsLogin(tokenStr) { - // If valid token exists, return it directly | 如果已有 Token 且有效,则直接返回 - return tokenStr, nil + if len(existingTokenList) > 0 { + if isLoggedIn, err := m.IsLogin(ctx, existingTokenList[0]); isLoggedIn && err == nil { + return existingTokenList[0], nil } } } // Handle concurrent login behavior | 处理并发登录逻辑 if !m.config.IsConcurrent { - // Concurrent login not allowed → kick out previous login on same device | 不允许并发登录 → 踢掉同设备下之前的 Token - _ = m.kickout(loginID, deviceType) - - } else if m.config.MaxLoginCount > 0 && !m.config.IsShare { - // MaxLoginCount = 0 → 不允许任何 Token - if m.config.MaxLoginCount == 0 { - return "", ErrLoginLimitExceeded + // Concurrent login not allowed Kickout all existing tokens | 不允许并发登录 踢掉所有Token + if len(existingTokenList) > 0 { + for _, token := range existingTokenList { + err = m.kickoutByToken(ctx, token) + return "", err + } } + } else if m.config.MaxLoginCount > 0 && !m.config.IsShare { // Concurrent login allowed but limited by MaxLoginCount | 允许并发登录但受 MaxLoginCount 限制 - // This limit applies to all tokens of this account across devices | 该限制针对账号所有设备的登录 Token 数量 - tokens, _ := m.GetTokenValueListByLoginID(loginID) - if len(tokens) >= m.config.MaxLoginCount { - // Reached maximum concurrent login count | 已达到最大并发登录数 - // You may change to "kick out earliest token" if desired | 如需也可改为“踢掉最早 Token” - return "", ErrLoginLimitExceeded + if int64(len(existingTokenList)) >= m.config.MaxLoginCount { + return "", core.ErrLoginLimitExceeded } } // Generate token | 生成Token tokenValue, err := m.generator.Generate(loginID, deviceType) if err != nil { - return "", fmt.Errorf("failed to generate token: %w", err) + return "", err } + // Current timestamp | 当前时间戳 nowTime := time.Now().Unix() + // Calculate expiration time | 计算过期时间 expiration := m.getExpiration() - // Prepare TokenInfo object and serialize to JSON | 准备Token信息对象并序列化为JSON - tokenInfoStr, err := json.Marshal(TokenInfo{ + // Prepare TokenInfo object and serialize to JSON | 准备Token信息对象并序列化 + tokenInfo, err := m.serializer.Encode(TokenInfo{ + AuthType: m.config.AuthType, LoginID: loginID, Device: deviceType, CreateTime: nowTime, ActiveTime: nowTime, }) if err != nil { - return "", fmt.Errorf("failed to marshal tokenInfo: %w", err) + return "", fmt.Errorf("%w: %v", core.ErrSerializeFailed, err) } // Save token-tokenInfo mapping | 保存 TokenKey-TokenInfo 映射 - tokenKey := m.getTokenKey(tokenValue) - if err = m.storage.Set(tokenKey, string(tokenInfoStr), expiration); err != nil { - return "", fmt.Errorf("failed to save token: %w", err) + if err = m.storage.Set(ctx, m.getTokenKey(tokenValue), tokenInfo, expiration); err != nil { + return "", fmt.Errorf("%w: %v", core.ErrStorageUnavailable, err) } // Save account-token mapping | 保存 AccountKey-Token 映射 - if err = m.storage.Set(accountKey, tokenValue, expiration); err != nil { - return "", fmt.Errorf("failed to save account mapping: %w", err) + if err = m.storage.Set(ctx, accountKey, tokenValue, expiration); err != nil { + return "", fmt.Errorf("%w: %v", core.ErrStorageUnavailable, err) } // Create session | 创建Session - err = session. - NewSession(loginID, m.storage, m.prefix). + if err = session. + NewSession(m.config.AuthType, m.config.KeyPrefix, loginID, m.storage, m.serializer). SetMulti( + ctx, map[string]any{ SessionKeyLoginID: loginID, SessionKeyDevice: deviceType, SessionKeyLoginTime: nowTime, }, expiration, - ) - if err != nil { - return "", fmt.Errorf("failed to save session: %w", err) + ); err != nil { + return "", err } // Trigger login event | 触发登录事件 if m.eventManager != nil { m.eventManager.Trigger(&listener.EventData{ - Event: listener.EventLogin, - LoginID: loginID, - Token: tokenValue, - Device: deviceType, + Event: listener.EventLogin, + AuthType: m.config.AuthType, + LoginID: loginID, + Device: deviceType, + Token: tokenValue, }) } @@ -263,173 +278,170 @@ func (m *Manager) Login(loginID string, device ...string) (string, error) { } // LoginByToken Login with specified token (for seamless token refresh) | 使用指定Token登录(用于token无感刷新) -func (m *Manager) LoginByToken(loginID string, tokenValue string, device ...string) error { - info, err := m.getTokenInfo(tokenValue) +func (m *Manager) LoginByToken(ctx context.Context, tokenValue string) error { + info, err := m.getTokenInfo(ctx, tokenValue) if err != nil { return err } - if info == nil { - return ErrInvalidTokenData - } // Check if the account is disabled | 检查账号是否被封禁 - if m.IsDisable(info.LoginID) { - return ErrAccountDisabled - } - - now := time.Now().Unix() - expiration := m.getExpiration() - - // Update last active time only | 更新活跃时间(轻量刷新) - info.ActiveTime = now - - // Write back updated TokenInfo (保留原TTL) - if data, err := json.Marshal(info); err == nil { - _ = m.storage.SetKeepTTL(m.getTokenKey(tokenValue), data) + if m.IsDisable(ctx, info.LoginID) { + return core.ErrAccountDisabled } - // Extend TTL for token, account, session | 延长Token、账号、Session的过期时间 - if expiration > 0 { - _ = m.storage.Expire(m.getTokenKey(tokenValue), expiration) - _ = m.storage.Expire(m.getAccountKey(info.LoginID, info.Device), expiration) - if sess, err := m.GetSession(info.LoginID); err == nil && sess != nil { - _ = sess.Renew(expiration) - } + // Renew token | 同步刷新Token + err = m.renewToken(context.Background(), tokenValue, info) + if err != nil { + return err } return nil } // Logout Performs user logout | 登出 -func (m *Manager) Logout(loginID string, device ...string) error { - deviceType := getDevice(device) - accountKey := m.getAccountKey(loginID, deviceType) +func (m *Manager) Logout(ctx context.Context, loginID string, device ...string) error { + // Get account key | 获取账号存储键 + accountKey := m.getAccountKey(loginID, getDevice(device)) - tokenValue, err := m.storage.Get(accountKey) - if err != nil || tokenValue == nil { - return nil // Already logged out | 已经登出 + // Get token value | 获取Token值 + tokenValue, err := m.storage.Get(ctx, accountKey) + if err != nil { + return fmt.Errorf("%w: %v", core.ErrStorageUnavailable, err) + } + if tokenValue == nil { + return core.ErrNotLogin } // Assert token value type | 类型断言为字符串 - tokenStr, ok := assertString(tokenValue) + tokenValueStr, ok := assertString(tokenValue) if !ok { - return nil + return core.ErrTokenNotFound } - return m.removeTokenChain(tokenStr, false, listener.EventLogout) + return m.removeTokenChain(ctx, tokenValueStr, nil, listener.EventLogout) } // LogoutByToken Logout by token | 根据Token登出 -func (m *Manager) LogoutByToken(tokenValue string) error { - if tokenValue == "" { - return nil - } - - return m.removeTokenChain(tokenValue, false, listener.EventLogout) +func (m *Manager) LogoutByToken(ctx context.Context, tokenValue string) error { + return m.removeTokenChain(ctx, tokenValue, nil, listener.EventLogout) } +// ============ Online Status Management | 在线状态管理 ============ + // kickout Kick user offline (private) | 踢人下线(私有) -func (m *Manager) kickout(loginID string, device string) error { +func (m *Manager) kickout(ctx context.Context, loginID string, device string) error { + // Get the account key for this user and device | 获取该用户和设备对应的账户键 accountKey := m.getAccountKey(loginID, device) - tokenValue, err := m.storage.Get(accountKey) - if err != nil || tokenValue == nil { - return nil + + // Retrieve the token associated with this account key from storage | 从存储中获取该账户键对应的 Token + tokenValue, err := m.storage.Get(ctx, accountKey) + if err != nil { + return fmt.Errorf("%w: %v", core.ErrStorageUnavailable, err) } - tokenStr, ok := assertString(tokenValue) - if !ok { - return nil + // If no token exists for this account key | 如果该账户键不存在 Token + if tokenValue == nil { + return core.ErrTokenNotFound + } + + // Remove the token chain and trigger kickout event | 移除 Token 链并触发踢下线事件 + if tokenValueStr, ok := assertString(tokenValue); ok { + return m.removeTokenChain(ctx, tokenValueStr, nil, listener.EventKickout) } - return m.removeTokenChain(tokenStr, false, listener.EventKickout) + return nil } // Kickout Kick user offline (public method) | 踢人下线(公开方法) -func (m *Manager) Kickout(loginID string, device ...string) error { - deviceType := getDevice(device) - return m.kickout(loginID, deviceType) +func (m *Manager) Kickout(ctx context.Context, loginID string, device ...string) error { + return m.kickout(ctx, loginID, getDevice(device)) } // kickoutByToken Kick user offline (private) | 根据Token踢人下线(私有) -func (m *Manager) kickoutByToken(tokenValue string) error { - return m.removeTokenChain(tokenValue, false, listener.EventKickout) +func (m *Manager) kickoutByToken(ctx context.Context, tokenValue string) error { + return m.removeTokenChain(ctx, tokenValue, nil, listener.EventKickout) } // KickoutByToken Kick user offline (public method) | 根据Token踢人下线(公开方法) -func (m *Manager) KickoutByToken(tokenValue string) error { - return m.kickoutByToken(tokenValue) +func (m *Manager) KickoutByToken(ctx context.Context, tokenValue string) error { + return m.kickoutByToken(ctx, tokenValue) } -// ============ Token Validation | Token验证 ============ +// replace Replace user offline by login ID and device (private) | 根据账号和设备顶人下线(私有) +func (m *Manager) replace(ctx context.Context, loginID string, device string) error { + // Get the account key for this user and device | 获取该用户和设备对应的账户键 + accountKey := m.getAccountKey(loginID, device) -// IsLogin Checks if user is logged in | 检查是否登录 -func (m *Manager) IsLogin(tokenValue string) bool { - if tokenValue == "" { - return false + // Retrieve the token associated with this account key from storage | 从存储中获取该账户键对应的 Token + tokenValue, err := m.storage.Get(ctx, accountKey) + if err != nil { + return fmt.Errorf("%w: %v", core.ErrStorageUnavailable, err) } - if _, err := m.getTokenInfo(tokenValue, false); err != nil { - return false + // If no token exists for this account key | 如果该账户键不存在 Token + if tokenValue == nil { + return core.ErrTokenNotFound } - // Async auto-renew for better performance | 异步自动续期(提高性能) - // Note: ActiveTimeout feature removed to comply with Java sa-token design - if m.config.AutoRenew && m.config.Timeout > 0 { - tokenKey := m.getTokenKey(tokenValue) - if ttl, err := m.storage.TTL(tokenKey); err == nil { - ttlSeconds := int64(ttl.Seconds()) + // Remove the token chain and trigger replace event | 移除 Token 链并触发顶下线事件 + if tokenValueStr, ok := assertString(tokenValue); ok { + return m.removeTokenChain(ctx, tokenValueStr, nil, listener.EventReplace) + } - // Perform renewal if TTL is below MaxRefresh threshold and RenewInterval allows | TTL和RenewInterval同时满足条件才续期 - if ttlSeconds > 0 && (m.config.MaxRefresh <= 0 || ttlSeconds <= m.config.MaxRefresh) && (m.config.RenewInterval <= 0 || !m.storage.Exists(m.getRenewKey(tokenValue))) { - renewFunc := func() { m.renewToken(tokenValue) } + return nil +} - // Submit to pool if configured, otherwise use goroutine | 使用续期池或协程执行续期 - if m.renewPool != nil { - _ = m.renewPool.Submit(renewFunc) // Submit token renewal task to the pool | 提交Token续期任务到续期池 - } else { - go renewFunc() // Fallback to goroutine if pool is not configured | 如果续期池未配置,使用普通协程 - } - } - } - } +// Replace user offline by login ID and device (public method) | 根据账号和设备顶人下线(公开方法) +func (m *Manager) Replace(ctx context.Context, loginID string, device ...string) error { + return m.replace(ctx, loginID, getDevice(device)) +} - return true +// replaceByToken Replace user offline by token (private) | 根据Token顶人下线(私有) +func (m *Manager) replaceByToken(ctx context.Context, tokenValue string) error { + return m.removeTokenChain(ctx, tokenValue, nil, listener.EventReplace) } -// CheckLogin Checks login status (throws error if not logged in) | 检查登录(未登录抛出错误) -func (m *Manager) CheckLogin(tokenValue string) error { - if !m.IsLogin(tokenValue) { - return ErrNotLogin - } - return nil +// ReplaceByToken Replace user offline by token (public method) | 根据Token顶人下线(公开方法) +func (m *Manager) ReplaceByToken(ctx context.Context, tokenValue string) error { + return m.replaceByToken(ctx, tokenValue) } -// CheckLoginWithState Checks if user is logged in | 检查是否登录(返回详细状态) -func (m *Manager) CheckLoginWithState(tokenValue string) (bool, error) { - if tokenValue == "" { - return false, nil - } +// ============ Token Validation | Token验证 ============ - // Try to get token info with state check | 尝试获取Token信息(包含状态检查) - _, err := m.getTokenInfo(tokenValue) +// IsLogin Checks if the user is logged in | 检查用户是否登录 +func (m *Manager) IsLogin(ctx context.Context, tokenValue string) (bool, error) { + info, err := m.getTokenInfo(ctx, tokenValue, true) if err != nil { return false, err } + // Check if the token has exceeded the active timeout | 检查Token是否超过活跃超时时间 + if m.config.ActiveTimeout > 0 { + now := time.Now().Unix() + if now-info.ActiveTime > m.config.ActiveTimeout { + // Force logout and clean up token data | 强制登出并清理Token相关数据 + _ = m.removeTokenChain(ctx, tokenValue, info, listener.EventKickout) + return false, core.ErrTokenKickout + } + } + // Async auto-renew for better performance | 异步自动续期(提高性能) - // Note: ActiveTimeout feature removed to comply with Java sa-token design if m.config.AutoRenew && m.config.Timeout > 0 { - if ttl, err := m.storage.TTL(m.getTokenKey(tokenValue)); err == nil { + // Construct the token storage key | 构造Token存储键 + tokenKey := m.getTokenKey(tokenValue) + + // Check if token renewal is needed | 检查是否需要进行续期 + if ttl, err := m.storage.TTL(ctx, tokenKey); err == nil { ttlSeconds := int64(ttl.Seconds()) - // Perform renewal if TTL is below MaxRefresh threshold and RenewInterval allows | TTL和RenewInterval同时满足条件才续期 - if ttlSeconds > 0 && (m.config.MaxRefresh <= 0 || ttlSeconds <= m.config.MaxRefresh) && (m.config.RenewInterval <= 0 || !m.storage.Exists(m.getRenewKey(tokenValue))) { - renewFunc := func() { m.renewToken(tokenValue) } + // Perform renewal if TTL is below MaxRefresh threshold and RenewInterval allows | 如果TTL小于MaxRefresh阈值且RenewInterval允许,则进行续期 + if ttlSeconds > 0 && (m.config.MaxRefresh <= 0 || ttlSeconds <= m.config.MaxRefresh) && (m.config.RenewInterval <= 0 || !m.storage.Exists(ctx, m.getRenewKey(tokenValue))) { + renewFunc := func() { m.renewToken(context.Background(), tokenValue, info) } - // Submit to pool if configured, otherwise use goroutine | 使用续期池或协程执行续期 - if m.renewPool != nil { - _ = m.renewPool.Submit(renewFunc) // Submit token renewal task to the pool | 提交Token续期任务到续期池 + // Submit renewal task to the pool if configured, otherwise use a goroutine | 如果配置了续期池,则提交续期任务到池中,否则使用协程 + if m.pool != nil { + _ = m.pool.Submit(renewFunc) // Submit token renewal task to the pool | 提交Token续期任务到续期池 } else { - go renewFunc() // Fallback to goroutine if pool is not configured | 如果续期池未配置,使用普通协程 + go renewFunc() // Fallback to goroutine if pool is not configured | 如果没有配置续期池,使用普通协程 } } } @@ -438,240 +450,584 @@ func (m *Manager) CheckLoginWithState(tokenValue string) (bool, error) { return true, nil } -// GetLoginID Gets login ID from token | 根据Token获取登录ID -func (m *Manager) GetLoginID(tokenValue string) (string, error) { - if !m.IsLogin(tokenValue) { - return "", ErrNotLogin +// CheckLogin Checks login status | 检查登录 +func (m *Manager) CheckLogin(ctx context.Context, tokenValue string) error { + isLogin, err := m.IsLogin(ctx, tokenValue) + if err != nil { + return err } + if !isLogin { + return core.ErrTokenExpired + } + + return nil +} - info, err := m.getTokenInfo(tokenValue) +// ============ Token Information | Token信息与解析 ============ + +// GetLoginID Gets login ID from token | 根据Token获取登录ID +func (m *Manager) GetLoginID(ctx context.Context, tokenValue string) (string, error) { + // Check if the user is logged in | 检查用户是否已登录 + isLogin, err := m.IsLogin(ctx, tokenValue) if err != nil { return "", err } - if info == nil { - return "", ErrInvalidTokenData + if !isLogin { + return "", core.ErrTokenExpired } - return info.LoginID, nil + // Retrieve the login ID without checking token validity | 获取登录ID 不检查Token有效性 + return m.GetLoginIDNotCheck(ctx, tokenValue) } -// GetLoginIDNotCheck Gets login ID without checking token validity | 获取登录ID(不检查Token是否有效) -func (m *Manager) GetLoginIDNotCheck(tokenValue string) (string, error) { - info, err := m.getTokenInfo(tokenValue) +// GetLoginIDNotCheck Gets login ID without checking token validity | 获取登录ID 不续期Token +func (m *Manager) GetLoginIDNotCheck(ctx context.Context, tokenValue string) (string, error) { + // Get token info | 获取Token信息 + info, err := m.getTokenInfo(ctx, tokenValue) if err != nil { return "", err } - if info == nil { - return "", ErrInvalidTokenData - } - return info.LoginID, err + + return info.LoginID, nil } -// GetTokenValue Gets token by login ID | 根据登录ID获取Token -func (m *Manager) GetTokenValue(loginID string, device ...string) (string, error) { - deviceType := getDevice(device) - accountKey := m.getAccountKey(loginID, deviceType) +// GetTokenValue Gets token by login ID and device | 根据登录ID以及设备获取Token +func (m *Manager) GetTokenValue(ctx context.Context, loginID string, device ...string) (string, error) { + // Construct the account storage key | 构造账号存储键 + accountKey := m.getAccountKey(loginID, getDevice(device)) - tokenValue, err := m.storage.Get(accountKey) - if err != nil || tokenValue == nil { - return "", fmt.Errorf("token not found for login id: %s", loginID) + // Retrieve the token value from storage | 从存储中获取Token值 + tokenValue, err := m.storage.Get(ctx, accountKey) + if err != nil { + return "", core.ErrStorageUnavailable + } + if tokenValue == nil { + return "", core.ErrTokenNotFound } - tokenStr, ok := assertString(tokenValue) + // Assert token value as a string | 断言Token值为字符串 + tokenValueStr, ok := assertString(tokenValue) if !ok { - return "", fmt.Errorf("invalid token value type") + return "", core.ErrTokenNotFound } - return tokenStr, nil + return tokenValueStr, nil } -// GetTokenInfo Gets token information | 获取Token信息 -func (m *Manager) GetTokenInfo(tokenValue string) (*TokenInfo, error) { - return m.getTokenInfo(tokenValue) +// GetTokenInfoByToken Gets token information | 获取Token信息 +func (m *Manager) GetTokenInfoByToken(ctx context.Context, tokenValue string) (*TokenInfo, error) { + return m.getTokenInfo(ctx, tokenValue) } // ============ Account Disable | 账号封禁 ============ // Disable Disables an account | 封禁账号 -func (m *Manager) Disable(loginID string, duration time.Duration) error { +func (m *Manager) Disable(ctx context.Context, loginID string, duration time.Duration, reason ...string) error { + // Retrieve the disable flag storage key | 获取封禁标记的存储键 + disableKeyKey := m.getDisableKey(loginID) + + // Prepare disable information | 准备封禁信息 + disableInfo := DisableInfo{ + DisableTime: time.Now().Unix(), + DisableReason: "", + } + if len(reason) > 0 { + disableInfo.DisableReason = reason[0] + } + + // Encode disable information into storage format | 将封禁信息序列化为存储格式 + encodeData, err := m.serializer.Encode(disableInfo) + if err != nil { + return fmt.Errorf("%w: %v", core.ErrSerializeFailed, err) + } + + // Set disable flag with specified duration | 设置封禁标记并指定封禁时长 + err = m.storage.Set(ctx, disableKeyKey, encodeData, duration) + if err != nil { + return fmt.Errorf("%w: %v", core.ErrStorageUnavailable, err) + } + // Check if the account has active sessions and force logout | 检查账号是否有活跃会话并强制下线 - tokens, err := m.GetTokenValueListByLoginID(loginID) + tokens, err := m.GetTokenValueListByLoginID(ctx, loginID) if err == nil && len(tokens) > 0 { for _, tokenValue := range tokens { // Force kick out each active token | 强制踢出所有活跃的Token - _ = m.removeTokenChain(tokenValue, true, listener.EventLogout) + _ = m.removeTokenChain(ctx, tokenValue, nil, listener.EventKickout, true) } } - key := m.getDisableKey(loginID) - // Set disable flag with specified duration | 设置封禁标记并指定封禁时长 - return m.storage.Set(key, DisableValue, duration) + return nil } // Untie Re-enables a disabled account | 解封账号 -func (m *Manager) Untie(loginID string) error { - key := m.getDisableKey(loginID) - return m.storage.Delete(key) +func (m *Manager) Untie(ctx context.Context, loginID string) error { + // Retrieve the disable flag storage key | 获取封禁标记的存储键 + disableKeyKey := m.getDisableKey(loginID) + + // Remove the disable flag from storage | 删除封禁标记 + err := m.storage.Delete(ctx, disableKeyKey) + if err != nil { + return fmt.Errorf("%w: %v", core.ErrStorageUnavailable, err) + } + + return nil } // IsDisable Checks if account is disabled | 检查账号是否被封禁 -func (m *Manager) IsDisable(loginID string) bool { - key := m.getDisableKey(loginID) - return m.storage.Exists(key) +func (m *Manager) IsDisable(ctx context.Context, loginID string) bool { + // Retrieve the disable flag storage key | 获取封禁标记的存储键 + disableKeyKey := m.getDisableKey(loginID) + + // Check if the disable flag exists in storage | 检查封禁标记是否存在 + return m.storage.Exists(ctx, disableKeyKey) } -// GetDisableTime Gets remaining disable time in seconds | 获取账号剩余封禁时间(秒) -func (m *Manager) GetDisableTime(loginID string) (int64, error) { - key := m.getDisableKey(loginID) - ttl, err := m.storage.TTL(key) +// GetDisableInfo get disable info | 获取封禁信息 +func (m *Manager) GetDisableInfo(ctx context.Context, loginID string) (*DisableInfo, error) { + // Retrieve the disable flag storage key | 获取封禁标记的存储键 + disableKeyKey := m.getDisableKey(loginID) + + // Get disable data from storage | 从存储中获取封禁信息 + data, err := m.storage.Get(ctx, disableKeyKey) if err != nil { - return -2, err + return nil, fmt.Errorf("%w: %v", core.ErrStorageUnavailable, err) } - return int64(ttl.Seconds()), nil + if data == nil { + // 注意 这里返回一个空的DisableInfo对象,而不是返回nil 不然取值时会panic + return &DisableInfo{}, nil + } + + // 将数据转换为字节数组 + raw, err := utils.ToBytes(data) + if err != nil { + return nil, fmt.Errorf("%w: %v", core.ErrTypeConvert, err) + } + + // Decode stored disable information | 反序列化封禁信息 + var disableInfo DisableInfo + if err := m.serializer.Decode(raw, &disableInfo); err != nil { + return nil, fmt.Errorf("%w: %v", core.ErrDeserializeFailed, err) + } + + return &disableInfo, nil } -// getDisableKey Gets disable storage key | 获取禁用存储键 -func (m *Manager) getDisableKey(loginID string) string { - return m.prefix + DisableKeyPrefix + loginID +// GetDisableTTL Gets remaining disable time in seconds | 获取账号剩余封禁时间(秒) +func (m *Manager) GetDisableTTL(ctx context.Context, loginID string) (int64, error) { + // Retrieve the disable flag storage key | 获取封禁标记的存储键 + disableKeyKey := m.getDisableKey(loginID) + + // Retrieve the TTL (Time to Live) for the disable flag | 获取封禁标记的TTL(剩余时间) + ttl, err := m.storage.TTL(ctx, disableKeyKey) + if err != nil { + return -2, err + } + + // Return the remaining disable time in seconds | 返回剩余封禁时间(秒) + return int64(ttl.Seconds()), nil } // ============ Session Management | Session管理 ============ -// GetSession Gets session by login ID | 获取Session -func (m *Manager) GetSession(loginID string) (*session.Session, error) { - sess, err := session.Load(loginID, m.storage, m.prefix) +// GetSession gets session by login ID | 获取 Session +func (m *Manager) GetSession(ctx context.Context, loginID string) (*session.Session, error) { + if loginID == "" { + return nil, core.ErrSessionIDEmpty + } + + key := m.config.KeyPrefix + m.config.AuthType + session.SessionKeyPrefix + loginID + data, err := m.storage.Get(ctx, key) if err != nil { - sess = session.NewSession(loginID, m.storage, m.prefix) + return nil, fmt.Errorf("%w: %v", core.ErrStorageUnavailable, err) + } + + // If found, decode session | 如果找到 Session 则解码 + var sess *session.Session + if data != nil { + raw, err := utils.ToBytes(data) + if err != nil { + return nil, fmt.Errorf("%w: %v", core.ErrTypeConvert, err) + } + + sess = &session.Session{} + if err := m.serializer.Decode(raw, sess); err != nil { + return nil, fmt.Errorf("%w: %v", core.ErrDeserializeFailed, err) + } + + // Set internal dependencies after decoding | 解码后设置内部依赖 + sess.SetDependencies(m.config.KeyPrefix, m.storage, m.serializer) + } + + // If not exist, create new session | 没找到就创建新的 Session + if sess == nil { + sess = session.NewSession(m.config.AuthType, m.config.KeyPrefix, loginID, m.storage, m.serializer) } + return sess, nil } // GetSessionByToken Gets session by token | 根据Token获取Session -func (m *Manager) GetSessionByToken(tokenValue string) (*session.Session, error) { - loginID, err := m.GetLoginID(tokenValue) +func (m *Manager) GetSessionByToken(ctx context.Context, tokenValue string) (*session.Session, error) { + loginID, err := m.GetLoginIDNotCheck(ctx, tokenValue) if err != nil { return nil, err } - return m.GetSession(loginID) + + return m.GetSession(ctx, loginID) } // DeleteSession Deletes session | 删除Session -func (m *Manager) DeleteSession(loginID string) error { - sess, err := m.GetSession(loginID) +func (m *Manager) DeleteSession(ctx context.Context, loginID string) error { + sess, err := m.GetSession(ctx, loginID) if err != nil { return err } - return sess.Destroy() -} -// ============ Permission Validation | 权限验证 ============ + return sess.Destroy(ctx) +} -// SetPermissions Sets permissions for user | 设置权限 -func (m *Manager) SetPermissions(loginID string, permissions []string) error { - sess, err := m.GetSession(loginID) +// DeleteSessionByToken Deletes session by token | 根据Token删除Session +func (m *Manager) DeleteSessionByToken(ctx context.Context, tokenValue string) error { + sess, err := m.GetSessionByToken(ctx, tokenValue) if err != nil { return err } - return sess.Set(SessionKeyPermissions, permissions, m.getExpiration()) + + return sess.Destroy(ctx) } -// GetPermissions Gets permission list | 获取权限列表 -func (m *Manager) GetPermissions(loginID string) ([]string, error) { - sess, err := m.GetSession(loginID) - if err != nil { - return nil, err +// HasSession Checks if session exists | 检查Session是否存在 +func (m *Manager) HasSession(ctx context.Context, loginID string) bool { + if loginID == "" { + return false } - perms, exists := sess.Get(SessionKeyPermissions) - if !exists { - return []string{}, nil + key := m.config.KeyPrefix + m.config.AuthType + session.SessionKeyPrefix + loginID + return m.GetStorage().Exists(ctx, key) +} + +// RenewSession Renews session TTL | 续期Session +func (m *Manager) RenewSession(ctx context.Context, loginID string, ttl time.Duration) error { + sess, err := m.GetSession(ctx, loginID) + if err != nil { + return err } - return m.toStringSlice(perms), nil + return sess.Renew(ctx, ttl) } -// HasPermission 检查是否有指定权限 -func (m *Manager) HasPermission(loginID string, permission string) bool { - perms, err := m.GetPermissions(loginID) +// ============ Permission Validation | 权限验证 ============ + +// SetPermissions Sets permissions for user | 设置权限 +func (m *Manager) SetPermissions(ctx context.Context, loginID string, permissions []string) error { + sess, err := m.GetSession(ctx, loginID) if err != nil { - return false + return err } - for _, p := range perms { - if m.matchPermission(p, permission) { - return true - } + permissionsFromSession, ok := sess.Get(SessionKeyPermissions) + if ok { + permissions = append(permissions, m.toStringSlice(permissionsFromSession)...) + permissions = removeDuplicateStrings(permissions) } - return false + return sess.Set(ctx, SessionKeyPermissions, permissions, m.getExpiration()) } -// HasPermissionsAnd 检查是否拥有所有权限(AND) -func (m *Manager) HasPermissionsAnd(loginID string, permissions []string) bool { - for _, perm := range permissions { - if !m.HasPermission(loginID, perm) { - return false - } +// SetPermissionsByToken Sets permissions by token | 根据Token设置权限 +func (m *Manager) SetPermissionsByToken(ctx context.Context, tokenValue string, permissions []string) error { + loginID, err := m.GetLoginIDNotCheck(ctx, tokenValue) + if err != nil { + return err } - return true + + return m.SetPermissions(ctx, loginID, permissions) } -// HasPermissionsOr 检查是否拥有任一权限(OR) -func (m *Manager) HasPermissionsOr(loginID string, permissions []string) bool { - for _, perm := range permissions { - if m.HasPermission(loginID, perm) { - return true - } +// RemovePermissions removes specified permissions for user | 删除用户指定权限 +func (m *Manager) RemovePermissions(ctx context.Context, loginID string, permissions []string) error { + sess, err := m.GetSession(ctx, loginID) + if err != nil { + return err } - return false -} -// matchPermission Matches permission with wildcards support | 权限匹配(支持通配符) -func (m *Manager) matchPermission(pattern, permission string) bool { - // Exact match or wildcard | 精确匹配或通配符 - if pattern == PermissionWildcard || pattern == permission { - return true + permissionsFromSession, ok := sess.Get(SessionKeyPermissions) + if !ok { + return nil } - // Pattern like "user:*" matches "user:add", "user:delete", etc. | 支持通配符,例如 user:* 匹配 user:add, user:delete等 - wildcardSuffix := PermissionSeparator + PermissionWildcard - if strings.HasSuffix(pattern, wildcardSuffix) { - prefix := strings.TrimSuffix(pattern, PermissionWildcard) - return strings.HasPrefix(permission, prefix) + existingPerms := m.toStringSlice(permissionsFromSession) + if len(existingPerms) == 0 { + return nil } - // Pattern like "user:*:view" | 支持 user:*:view 这样的模式 - if strings.Contains(pattern, PermissionWildcard) { - parts := strings.Split(pattern, PermissionSeparator) - permParts := strings.Split(permission, PermissionSeparator) - if len(parts) != len(permParts) { - return false - } - for i, part := range parts { - if part != PermissionWildcard && part != permParts[i] { - return false - } + // Build a set for fast lookup of permissions to remove | 构建待删除权限集合 + removeSet := make(map[string]struct{}, len(permissions)) + for _, p := range permissions { + removeSet[p] = struct{}{} + } + + // Filter out permissions to be removed | 过滤掉需要删除的权限 + newPerms := make([]string, 0, len(existingPerms)) + for _, p := range existingPerms { + if _, shouldRemove := removeSet[p]; !shouldRemove { + newPerms = append(newPerms, p) } - return true } - return false + return sess.Set(ctx, SessionKeyPermissions, newPerms, m.getExpiration()) } -// ============ Role Validation | 角色验证 ============ - -// SetRoles Sets roles for user | 设置角色 -func (m *Manager) SetRoles(loginID string, roles []string) error { - sess, err := m.GetSession(loginID) +// RemovePermissionsByToken removes specified permissions by token | 根据Token删除指定权限 +func (m *Manager) RemovePermissionsByToken(ctx context.Context, tokenValue string, permissions []string) error { + loginID, err := m.GetLoginIDNotCheck(ctx, tokenValue) if err != nil { return err } - return sess.Set(SessionKeyRoles, roles, m.getExpiration()) + + return m.RemovePermissions(ctx, loginID, permissions) } -// GetRoles Gets role list | 获取角色列表 -func (m *Manager) GetRoles(loginID string) ([]string, error) { - sess, err := m.GetSession(loginID) +// GetPermissions Gets permission list | 获取权限列表 +func (m *Manager) GetPermissions(ctx context.Context, loginID string) ([]string, error) { + if m.CustomPermissionListFunc != nil { + perms, err := m.CustomPermissionListFunc(loginID, m.config.AuthType) + if err != nil { + return nil, err + } + return perms, nil + } + + sess, err := m.GetSession(ctx, loginID) + if err != nil { + return nil, err + } + + perms, exists := sess.Get(SessionKeyPermissions) + if !exists { + return []string{}, nil + } + + return m.toStringSlice(perms), nil +} + +// GetPermissionsByToken Gets permission list by token | 根据Token获取权限列表 +func (m *Manager) GetPermissionsByToken(ctx context.Context, tokenValue string) ([]string, error) { + loginID, err := m.GetLoginIDNotCheck(ctx, tokenValue) + if err != nil { + return nil, err + } + + return m.GetPermissions(ctx, loginID) +} + +// HasPermission checks whether the specified loginID has the given permission | 检查指定账号是否拥有指定权限 +func (m *Manager) HasPermission(ctx context.Context, loginID string, permission string) bool { + perms, err := m.GetPermissions(ctx, loginID) + if err != nil { + return false + } + + for _, p := range perms { + if m.matchPermission(p, permission) { + return true + } + } + + return false +} + +// HasPermissionByToken checks whether the current token subject has the specified permission | 根据当前 Token 判断是否拥有指定权限 +func (m *Manager) HasPermissionByToken(ctx context.Context, tokenValue string, permission string) bool { + loginID, err := m.GetLoginIDNotCheck(ctx, tokenValue) + if err != nil { + return false + } + + return m.HasPermission(ctx, loginID, permission) +} + +// HasPermissionsAnd Checks whether the user has all permissions (AND) | 是否拥有所有权限(AND) +func (m *Manager) HasPermissionsAnd(ctx context.Context, loginID string, permissions []string) bool { + userPerms, err := m.GetPermissions(ctx, loginID) + if err != nil || len(userPerms) == 0 { + return false + } + + // Check every required permission | 校验每一个必需权限 + for _, need := range permissions { + if !m.hasPermissionInList(userPerms, need) { + return false + } + } + + return true +} + +// HasPermissionsAndByToken checks whether the current token subject has all specified permissions (AND) | 根据当前 Token 判断是否拥有所有指定权限(AND) +func (m *Manager) HasPermissionsAndByToken(ctx context.Context, tokenValue string, permissions []string) bool { + loginID, err := m.GetLoginIDNotCheck(ctx, tokenValue) + if err != nil { + return false + } + + return m.HasPermissionsAnd(ctx, loginID, permissions) +} + +// HasPermissionsOr Checks whether the user has any permission (OR) | 是否拥有任一权限(OR) +func (m *Manager) HasPermissionsOr(ctx context.Context, loginID string, permissions []string) bool { + // Get all permissions once | 一次性获取用户权限 + userPerms, err := m.GetPermissions(ctx, loginID) + if err != nil || len(userPerms) == 0 { + return false + } + + // Check if any permission matches | 任一权限匹配即通过 + for _, need := range permissions { + if m.hasPermissionInList(userPerms, need) { + return true + } + } + return false +} + +// HasPermissionsOrByToken checks whether the current token subject has any of the specified permissions (OR) | 根据当前 Token 判断是否拥有任一指定权限(OR) +func (m *Manager) HasPermissionsOrByToken(ctx context.Context, tokenValue string, permissions []string) bool { + loginID, err := m.GetLoginIDNotCheck(ctx, tokenValue) + if err != nil { + return false + } + + return m.HasPermissionsOr(ctx, loginID, permissions) +} + +// matchPermission Matches permission with wildcards support | 权限匹配(支持通配符) +func (m *Manager) matchPermission(pattern, permission string) bool { + // Exact match or wildcard | 精确匹配或通配符 + if pattern == PermissionWildcard || pattern == permission { + return true + } + + // Pattern like "user:*" matches "user:add", "user:delete", etc. | 支持通配符,例如 user:* 匹配 user:add, user:delete等 + wildcardSuffix := PermissionSeparator + PermissionWildcard + if strings.HasSuffix(pattern, wildcardSuffix) { + prefix := strings.TrimSuffix(pattern, PermissionWildcard) + return strings.HasPrefix(permission, prefix) + } + + // Pattern like "user:*:view" | 支持 user:*:view 这样的模式 + if strings.Contains(pattern, PermissionWildcard) { + parts := strings.Split(pattern, PermissionSeparator) + permParts := strings.Split(permission, PermissionSeparator) + if len(parts) != len(permParts) { + return false + } + for i, part := range parts { + if part != PermissionWildcard && part != permParts[i] { + return false + } + } + return true + } + + return false +} + +// hasPermissionInList checks whether permission exists in permission list | 判断权限是否存在于权限列表中 +func (m *Manager) hasPermissionInList(perms []string, permission string) bool { + for _, p := range perms { + if m.matchPermission(p, permission) { + return true + } + } + return false +} + +// ============ Role Validation | 角色验证 ============ + +// SetRoles Sets roles for user | 设置角色 +func (m *Manager) SetRoles(ctx context.Context, loginID string, roles []string) error { + sess, err := m.GetSession(ctx, loginID) + if err != nil { + return err + } + + rolesFromSession, ok := sess.Get(SessionKeyRoles) + if ok { + roles = append(roles, m.toStringSlice(rolesFromSession)...) + roles = removeDuplicateStrings(roles) + } + + return sess.Set(ctx, SessionKeyRoles, roles, m.getExpiration()) +} + +// SetRolesByToken Sets roles by token | 根据Token设置角色 +func (m *Manager) SetRolesByToken(ctx context.Context, tokenValue string, roles []string) error { + loginID, err := m.GetLoginIDNotCheck(ctx, tokenValue) + if err != nil { + return err + } + + return m.SetRoles(ctx, loginID, roles) +} + +// RemoveRoles removes specified roles for user | 删除用户指定角色 +func (m *Manager) RemoveRoles(ctx context.Context, loginID string, roles []string) error { + sess, err := m.GetSession(ctx, loginID) + if err != nil { + return err + } + + // Load existing roles | 加载已有角色 + rolesFromSession, ok := sess.Get(SessionKeyRoles) + if !ok { + return nil // No roles to remove | 没有角色可删除 + } + + existingRoles := m.toStringSlice(rolesFromSession) + if len(existingRoles) == 0 { + return nil + } + + // Build lookup set for roles to remove | 构建待删除角色集合 + removeSet := make(map[string]struct{}, len(roles)) + for _, r := range roles { + removeSet[r] = struct{}{} + } + + // Filter existing roles | 过滤掉需要删除的角色 + newRoles := make([]string, 0, len(existingRoles)) + for _, r := range existingRoles { + if _, remove := removeSet[r]; !remove { + newRoles = append(newRoles, r) + } + } + + // Save updated roles | 保存更新后的角色列表 + return sess.Set(ctx, SessionKeyRoles, newRoles, m.getExpiration()) +} + +// RemoveRolesByToken removes specified roles by token | 根据Token删除指定角色 +func (m *Manager) RemoveRolesByToken(ctx context.Context, tokenValue string, roles []string) error { + loginID, err := m.GetLoginIDNotCheck(ctx, tokenValue) + if err != nil { + return err + } + + return m.RemoveRoles(ctx, loginID, roles) +} + +// GetRoles gets role list for the specified loginID | 获取指定账号的角色列表 +func (m *Manager) GetRoles(ctx context.Context, loginID string) ([]string, error) { + if m.CustomRoleListFunc != nil { + perms, err := m.CustomRoleListFunc(loginID, m.config.AuthType) + if err != nil { + return nil, err + } + return perms, nil + } + + sess, err := m.GetSession(ctx, loginID) if err != nil { return nil, err } @@ -684,9 +1040,19 @@ func (m *Manager) GetRoles(loginID string) ([]string, error) { return m.toStringSlice(roles), nil } -// HasRole 检查是否有指定角色 -func (m *Manager) HasRole(loginID string, role string) bool { - roles, err := m.GetRoles(loginID) +// GetRolesByToken Gets role list by token | 根据Token获取角色列表 +func (m *Manager) GetRolesByToken(ctx context.Context, tokenValue string) ([]string, error) { + loginID, err := m.GetLoginIDNotCheck(ctx, tokenValue) + if err != nil { + return nil, err + } + + return m.GetRoles(ctx, loginID) +} + +// HasRole checks whether the specified loginID has the given role | 检查指定账号是否拥有指定角色 +func (m *Manager) HasRole(ctx context.Context, loginID string, role string) bool { + roles, err := m.GetRoles(ctx, loginID) if err != nil { return false } @@ -699,20 +1065,70 @@ func (m *Manager) HasRole(loginID string, role string) bool { return false } -// HasRolesAnd 检查是否拥有所有角色(AND) -func (m *Manager) HasRolesAnd(loginID string, roles []string) bool { - for _, role := range roles { - if !m.HasRole(loginID, role) { +// HasRoleByToken checks whether the current token subject has the specified role | 根据当前 Token 判断是否拥有指定角色 +func (m *Manager) HasRoleByToken(ctx context.Context, tokenValue string, role string) bool { + loginID, err := m.GetLoginIDNotCheck(ctx, tokenValue) + if err != nil { + return false + } + + return m.HasRole(ctx, loginID, role) +} + +// HasRolesAnd Checks whether the user has all roles (AND) | 是否拥有所有角色(AND) +func (m *Manager) HasRolesAnd(ctx context.Context, loginID string, roles []string) bool { + userRoles, err := m.GetRoles(ctx, loginID) + if err != nil || len(userRoles) == 0 { + return false + } + + for _, need := range roles { + if !m.hasRoleInList(userRoles, need) { return false } } return true } -// HasRolesOr 检查是否拥有任一角色(OR) -func (m *Manager) HasRolesOr(loginID string, roles []string) bool { - for _, role := range roles { - if m.HasRole(loginID, role) { +// HasRolesAndByToken checks whether the current token subject has all specified roles (AND) | 根据当前 Token 判断是否拥有所有指定角色(AND) +func (m *Manager) HasRolesAndByToken(ctx context.Context, tokenValue string, roles []string) bool { + loginID, err := m.GetLoginIDNotCheck(ctx, tokenValue) + if err != nil { + return false + } + + return m.HasRolesAnd(ctx, loginID, roles) +} + +// HasRolesOr Checks whether the user has any role (OR) | 是否拥有任一角色(OR) +func (m *Manager) HasRolesOr(ctx context.Context, loginID string, roles []string) bool { + userRoles, err := m.GetRoles(ctx, loginID) + if err != nil || len(userRoles) == 0 { + return false + } + + for _, need := range roles { + if m.hasRoleInList(userRoles, need) { + return true + } + } + return false +} + +// HasRolesOrByToken checks whether the current token subject has any of the specified roles (OR) | 根据当前 Token 判断是否拥有任一指定角色(OR) +func (m *Manager) HasRolesOrByToken(ctx context.Context, tokenValue string, roles []string) bool { + loginID, err := m.GetLoginIDNotCheck(ctx, tokenValue) + if err != nil { + return false + } + + return m.HasRolesOr(ctx, loginID, roles) +} + +// hasPermissionInList checks whether permission exists in permission list | 判断权限是否存在于权限列表中 +func (m *Manager) hasRoleInList(roles []string, role string) bool { + for _, r := range roles { + if r == role { return true } } @@ -722,335 +1138,589 @@ func (m *Manager) HasRolesOr(loginID string, roles []string) bool { // ============ Token Tags | Token标签 ============ // SetTokenTag Sets token tag | 设置Token标签 -func (m *Manager) SetTokenTag(tokenValue, tag string) error { +func (m *Manager) SetTokenTag(tag string) error { // Tag feature not supported to comply with Java sa-token design // If you need custom metadata, use Session instead return fmt.Errorf("token tag feature not supported (use Session for custom metadata)") } // GetTokenTag Gets token tag | 获取Token标签 -func (m *Manager) GetTokenTag(tokenValue string) (string, error) { +func (m *Manager) GetTokenTag(ctx context.Context) (string, error) { // Tag feature not supported to comply with Java sa-token design return "", fmt.Errorf("token tag feature not supported (use Session for custom metadata)") } -// ============ Session Query | 会话查询 ============ +// ============ Token & Session Info | Token 与会话信息查询 ============ // GetTokenValueListByLoginID Gets all tokens for specified account | 获取指定账号的所有Token -func (m *Manager) GetTokenValueListByLoginID(loginID string) ([]string, error) { - pattern := m.prefix + AccountKeyPrefix + loginID + ":*" - keys, err := m.storage.Keys(pattern) +func (m *Manager) GetTokenValueListByLoginID(ctx context.Context, loginID string) ([]string, error) { + // Construct the pattern for account key | 构造账号存储键的匹配模式 + pattern := m.config.KeyPrefix + m.config.AuthType + AccountKeyPrefix + loginID + TokenValueListLastKey + + // Retrieve keys matching the pattern from storage | 从存储中获取匹配的键 + keys, err := m.storage.Keys(ctx, pattern) if err != nil { - return nil, err + return nil, fmt.Errorf("%w: %v", core.ErrStorageUnavailable, err) } + // Initialize a slice to hold the token strings | 初始化切片来存储Token字符串 tokens := make([]string, 0, len(keys)) + + // Loop through the keys and retrieve the associated token values | 遍历键并获取关联的Token值 for _, key := range keys { - value, err := m.storage.Get(key) + value, err := m.storage.Get(ctx, key) if err == nil && value != nil { + // Assert value as string and add to tokens slice | 将值断言为字符串并添加到Token切片 if tokenStr, ok := assertString(value); ok { - tokens = append(tokens, tokenStr) + // Get the token info from storage | 从存储中获取Token信息 + tokenInfo, err := m.storage.Get(ctx, m.getTokenKey(tokenStr)) + if err == nil && tokenInfo != nil { + tokenInfoStr, assertOk := assertString(tokenInfo) + if assertOk && tokenInfoStr != string(TokenStateKickout) && tokenInfoStr != string(TokenStateReplaced) { + tokens = append(tokens, tokenStr) + } + } } } } + // Return the list of token strings | 返回Token字符串列表 return tokens, nil } // GetSessionCountByLoginID Gets session count for specified account | 获取指定账号的Session数量 -func (m *Manager) GetSessionCountByLoginID(loginID string) (int, error) { - tokens, err := m.GetTokenValueListByLoginID(loginID) +func (m *Manager) GetSessionCountByLoginID(ctx context.Context, loginID string) (int, error) { + // Get the list of token values for the specified login ID | 获取指定登录ID的Token值列表 + tokens, err := m.GetTokenValueListByLoginID(ctx, loginID) if err != nil { - return 0, err + return 0, err // Return error if token list retrieval fails | 如果获取Token列表失败,则返回错误 } + + // Return the count of tokens as the session count | 返回Token数量作为Session数量 return len(tokens), nil } -// ============ Internal Helper Methods | 内部辅助方法 ============ +// ============ Event Management | 事件管理 ============ -// getTokenKey Gets token storage key | 获取Token存储键 -func (m *Manager) getTokenKey(tokenValue string) string { - return m.prefix + TokenKeyPrefix + tokenValue +// RegisterFunc registers a function as an event listener | 注册函数作为事件监听器 +func (m *Manager) RegisterFunc(event listener.Event, fn func(*listener.EventData)) { + m.eventManager.RegisterFunc(event, fn) } -// getAccountKey Gets account storage key | 获取账号存储键 -func (m *Manager) getAccountKey(loginID, device string) string { - return m.prefix + AccountKeyPrefix + loginID + PermissionSeparator + device +// Register registers an event listener | 注册事件监听器 +func (m *Manager) Register(event listener.Event, listener listener.Listener) string { + return m.eventManager.Register(event, listener) } -// getRenewKey Gets token renewal tracking key | 获取Token续期追踪键 -func (m *Manager) getRenewKey(tokenValue string) string { - return m.prefix + RenewKeyPrefix + tokenValue +// RegisterWithConfig registers an event listener with config | 注册带配置的事件监听器 +func (m *Manager) RegisterWithConfig(event listener.Event, listener listener.Listener, config listener.ListenerConfig) string { + return m.eventManager.RegisterWithConfig(event, listener, config) } -// getLoginIDByToken Gets loginID by token (符合 Java sa-token 设计) | 通过 Token 获取 loginID -func (m *Manager) getLoginIDByToken(tokenValue string) (string, error) { - info, err := m.getTokenInfo(tokenValue) - if err != nil { - return "", err +// Unregister removes an event listener by ID | 根据ID移除事件监听器 +func (m *Manager) Unregister(id string) bool { + return m.eventManager.Unregister(id) +} + +// TriggerEvent manually triggers an event | 手动触发事件 +func (m *Manager) TriggerEvent(data *listener.EventData) { + m.eventManager.Trigger(data) +} + +// TriggerEventAsync triggers an event asynchronously and returns immediately | 异步触发事件并立即返回 +func (m *Manager) TriggerEventAsync(data *listener.EventData) { + m.eventManager.TriggerAsync(data) +} + +// TriggerEventSync triggers an event synchronously and waits for all listeners | 同步触发事件并等待所有监听器完成 +func (m *Manager) TriggerEventSync(data *listener.EventData) { + m.eventManager.TriggerSync(data) +} + +// WaitEvents waits for all async event listeners to complete | 等待所有异步事件监听器完成 +func (m *Manager) WaitEvents() { + m.eventManager.Wait() +} + +// ClearEventListeners removes all listeners for a specific event | 清除指定事件的所有监听器 +func (m *Manager) ClearEventListeners(event listener.Event) { + m.eventManager.ClearEvent(event) +} + +// ClearAllEventListeners removes all listeners | 清除所有事件监听器 +func (m *Manager) ClearAllEventListeners() { + m.eventManager.Clear() +} + +// CountEventListeners returns the number of listeners for a specific event | 获取指定事件监听器数量 +func (m *Manager) CountEventListeners(event listener.Event) int { + return m.eventManager.CountForEvent(event) +} + +// CountAllListeners returns the total number of registered listeners | 获取已注册监听器总数 +func (m *Manager) CountAllListeners() int { + return m.eventManager.Count() +} + +// GetEventListenerIDs returns all listener IDs for a specific event | 获取指定事件的所有监听器ID +func (m *Manager) GetEventListenerIDs(event listener.Event) []string { + return m.eventManager.GetListenerIDs(event) +} + +// GetAllRegisteredEvents returns all events that have registered listeners | 获取所有已注册事件 +func (m *Manager) GetAllRegisteredEvents() []listener.Event { + return m.eventManager.GetAllEvents() +} + +// HasEventListeners checks if there are any listeners for a specific event | 检查指定事件是否有监听器 +func (m *Manager) HasEventListeners(event listener.Event) bool { + return m.eventManager.HasListeners(event) +} + +// ============ Security Features | 安全特性 ============ + +// SecurityGenerateNonce Generates a one-time nonce | 生成一次性随机数 +func (m *Manager) SecurityGenerateNonce(ctx context.Context) (string, error) { + return m.nonceManager.Generate(ctx) +} + +// SecurityVerifyNonce Verifies a nonce | 验证随机数 +func (m *Manager) SecurityVerifyNonce(ctx context.Context, nonce string) bool { + return m.nonceManager.Verify(ctx, nonce) +} + +// SecurityVerifyAndConsumeNonce Verifies and consumes nonce, returns error if invalid | 验证并消费nonce,无效时返回错误 +func (m *Manager) SecurityVerifyAndConsumeNonce(ctx context.Context, nonce string) error { + return m.nonceManager.VerifyAndConsume(ctx, nonce) +} + +// SecurityIsValidNonce Checks if nonce is valid without consuming it | 检查nonce是否有效(不消费) +func (m *Manager) SecurityIsValidNonce(ctx context.Context, nonce string) bool { + return m.nonceManager.IsValid(ctx, nonce) +} + +// SecurityGenerateTokenPair Create access + refresh token | 生成访问令牌和刷新令牌 +func (m *Manager) SecurityGenerateTokenPair(ctx context.Context, loginID string, device ...string) (*security.RefreshTokenInfo, error) { + deviceType := getDevice(device) + return m.refreshManager.GenerateTokenPair(ctx, loginID, deviceType) +} + +// SecurityVerifyAccessToken Check token exists | 验证访问令牌是否存在 +func (m *Manager) SecurityVerifyAccessToken(ctx context.Context, accessToken string) bool { + return m.refreshManager.VerifyAccessToken(ctx, accessToken) +} + +// SecurityVerifyAccessTokenAndGetInfo Verify and get info | 验证访问令牌并获取信息 +func (m *Manager) SecurityVerifyAccessTokenAndGetInfo(ctx context.Context, accessToken string) (*security.AccessTokenInfo, bool) { + return m.refreshManager.VerifyAccessTokenAndGetInfo(ctx, accessToken) +} + +// SecurityRefreshAccessToken Refresh access token by refresh token | 使用刷新令牌刷新访问令牌 +func (m *Manager) SecurityRefreshAccessToken(ctx context.Context, refreshToken string) (*security.RefreshTokenInfo, error) { + return m.refreshManager.RefreshAccessToken(ctx, refreshToken) +} + +// SecurityGetRefreshTokenInfo Get refresh token info by token | 根据刷新令牌获取刷新令牌信息 +func (m *Manager) SecurityGetRefreshTokenInfo(ctx context.Context, refreshToken string) (*security.RefreshTokenInfo, error) { + return m.refreshManager.GetRefreshTokenInfo(ctx, refreshToken) +} + +// SecurityRevokeRefreshToken Remove refresh token | 撤销刷新令牌 +func (m *Manager) SecurityRevokeRefreshToken(ctx context.Context, refreshToken string) error { + return m.refreshManager.RevokeRefreshToken(ctx, refreshToken) +} + +// SecurityIsRefreshTokenValid Check refresh token valid | 判断刷新令牌是否有效 +func (m *Manager) SecurityIsRefreshTokenValid(ctx context.Context, refreshToken string) bool { + return m.refreshManager.IsValid(ctx, refreshToken) +} + +// ============ OAuth2 Features | Oauth2特性 ============ + +// OAuth2RegisterClient Registers an OAuth2 client | 注册OAuth2客户端 +func (m *Manager) OAuth2RegisterClient(client *oauth2.Client) error { + return m.oauth2Server.RegisterClient(client) +} + +// OAuth2UnregisterClient Unregisters an OAuth2 client | 注销OAuth2客户端 +func (m *Manager) OAuth2UnregisterClient(clientID string) { + m.oauth2Server.UnregisterClient(clientID) +} + +// OAuth2GetClient Gets client by ID | 根据ID获取客户端 +func (m *Manager) OAuth2GetClient(clientID string) (*oauth2.Client, error) { + return m.oauth2Server.GetClient(clientID) +} + +// OAuth2GenerateAuthorizationCode Generates authorization code | 生成授权码 +func (m *Manager) OAuth2GenerateAuthorizationCode(ctx context.Context, clientID, userID, redirectURI string, scopes []string) (*oauth2.AuthorizationCode, error) { + return m.oauth2Server.GenerateAuthorizationCode(ctx, clientID, userID, redirectURI, scopes) +} + +// OAuth2ExchangeCodeForToken Exchanges authorization code for access token | 用授权码换取访问令牌 +func (m *Manager) OAuth2ExchangeCodeForToken(ctx context.Context, code, clientID, clientSecret, redirectURI string) (*oauth2.AccessToken, error) { + return m.oauth2Server.ExchangeCodeForToken(ctx, code, clientID, clientSecret, redirectURI) +} + +// OAuth2ValidateAccessToken Validates access token | 验证访问令牌 +func (m *Manager) OAuth2ValidateAccessToken(ctx context.Context, accessToken string) bool { + return m.oauth2Server.ValidateAccessToken(ctx, accessToken) +} + +// OAuth2ValidateAccessTokenAndGetInfo Validates access token and get info | 验证访问令牌并获取信息 +func (m *Manager) OAuth2ValidateAccessTokenAndGetInfo(ctx context.Context, accessToken string) (*oauth2.AccessToken, error) { + return m.oauth2Server.ValidateAccessTokenAndGetInfo(ctx, accessToken) +} + +// OAuth2RefreshAccessToken Refreshes access token using refresh token | 使用刷新令牌刷新访问令牌 +func (m *Manager) OAuth2RefreshAccessToken(ctx context.Context, clientID, refreshToken, clientSecret string) (*oauth2.AccessToken, error) { + return m.oauth2Server.RefreshAccessToken(ctx, clientID, refreshToken, clientSecret) +} + +// OAuth2RevokeToken Revokes access token and its refresh token | 撤销访问令牌及其刷新令牌 +func (m *Manager) OAuth2RevokeToken(ctx context.Context, accessToken string) error { + return m.oauth2Server.RevokeToken(ctx, accessToken) +} + +// OAuth2Token Unified token endpoint that dispatches to appropriate handler based on grant type | 统一的令牌端点,根据授权类型分发到相应的处理逻辑 +func (m *Manager) OAuth2Token(ctx context.Context, req *oauth2.TokenRequest, validateUser oauth2.UserValidator) (*oauth2.AccessToken, error) { + return m.oauth2Server.Token(ctx, req, validateUser) +} + +// OAuth2ClientCredentialsToken Gets access token using client credentials grant | 使用客户端凭证模式获取访问令牌 +func (m *Manager) OAuth2ClientCredentialsToken(ctx context.Context, clientID, clientSecret string, scopes []string) (*oauth2.AccessToken, error) { + return m.oauth2Server.ClientCredentialsToken(ctx, clientID, clientSecret, scopes) +} + +// OAuth2PasswordGrantToken Gets access token using resource owner password credentials grant | 使用密码模式获取访问令牌 +func (m *Manager) OAuth2PasswordGrantToken(ctx context.Context, clientID, clientSecret, username, password string, scopes []string, validateUser oauth2.UserValidator) (*oauth2.AccessToken, error) { + return m.oauth2Server.PasswordGrantToken(ctx, clientID, clientSecret, username, password, scopes, validateUser) +} + +// ============ Public Getters | 公共获取器 ============ + +// GetConfig returns the manager-example configuration | 获取 Manager 当前使用的配置 +func (m *Manager) GetConfig() *config.Config { + return m.config +} + +// GetStorage returns the storage adapter | 获取 Manager 使用的存储适配器 +func (m *Manager) GetStorage() adapter.Storage { + return m.storage +} + +// GetCodec returns the codec (serializer) | 获取 Manager 使用的编解码器 +func (m *Manager) GetCodec() adapter.Codec { + return m.serializer +} + +// GetLog returns the logger adapter | 获取 Manager 使用的日志适配器 +func (m *Manager) GetLog() adapter.Log { + return m.logger +} + +// GetLogControl returns the logger control interface if available | 获取日志控制接口(如果支持) +func (m *Manager) GetLogControl() adapter.LogControl { + if logControl, ok := m.logger.(adapter.LogControl); ok { + return logControl } - return info.LoginID, nil + return nil } -// getTokenInfo Gets token information | 获取Token信息 -func (m *Manager) getTokenInfo(tokenValue string, checkState ...bool) (*TokenInfo, error) { - tokenKey := m.getTokenKey(tokenValue) - data, err := m.storage.Get(tokenKey) - if err != nil || data == nil { - return nil, err +// GetPool returns the goroutine pool | 获取 Manager 使用的协程池 +func (m *Manager) GetPool() adapter.Pool { + return m.pool +} + +// GetGenerator returns the token generator | 获取 Token 生成器 +func (m *Manager) GetGenerator() adapter.Generator { + return m.generator +} + +// GetNonceManager returns the nonce manager-example | 获取随机串管理器 +func (m *Manager) GetNonceManager() *security.NonceManager { + return m.nonceManager +} + +// GetRefreshManager returns the refresh token manager-example | 获取刷新令牌管理器 +func (m *Manager) GetRefreshManager() *security.RefreshTokenManager { + return m.refreshManager +} + +// GetEventManager returns the event manager-example | 获取事件管理器 +func (m *Manager) GetEventManager() *listener.Manager { + return m.eventManager +} + +// GetOAuth2Server Gets OAuth2 server instance | 获取OAuth2服务器实例 +func (m *Manager) GetOAuth2Server() *oauth2.OAuth2Server { + return m.oauth2Server +} + +// GetDevice extracts device type from optional parameter | 从可选参数中提取设备类型 公开方法 +func (m *Manager) GetDevice(device []string) string { + if len(device) > 0 && strings.TrimSpace(device[0]) != "" { + return device[0] } + return DefaultDevice +} - // Convert storage value to string | 将存储值统一转换为字符串 - var str string - switch v := data.(type) { - case []byte: - str = string(v) - case string: - str = v - default: - return nil, ErrInvalidTokenData +// ============ Internal Methods | 内部方法 ============ + +// getTokenInfo Gets token information by token value | 通过Token值获取Token信息 +func (m *Manager) getTokenInfo(ctx context.Context, tokenValue string, checkState ...bool) (*TokenInfo, error) { + // Retrieve data from storage using the token key | 使用Token键从存储中获取数据 + data, err := m.storage.Get(ctx, m.getTokenKey(tokenValue)) + if err != nil { + return nil, fmt.Errorf("%w: %v", core.ErrStorageUnavailable, err) + } + if data == nil { + return nil, core.ErrTokenNotFound + } + + // Convert data to raw byte slice | 将数据转换为原始字节切片 + raw, err := utils.ToBytes(data) + if err != nil { + return nil, fmt.Errorf("%w: %v", core.ErrTypeConvert, err) } - // Check for special token states (if enabled) | 检查是否为特殊状态(当启用检查时) - if len(checkState) == 0 || checkState[0] { - switch str { + // Check for special token states (if enabled) | 检查是否为特殊状态 + if len(checkState) > 0 && checkState[0] { + switch string(raw) { case string(TokenStateKickout): - return nil, ErrTokenKickout // 被踢下线 + return nil, core.ErrTokenKickout case string(TokenStateReplaced): - return nil, ErrTokenReplaced // 被顶号下线 + return nil, core.ErrTokenReplaced } } - // Parse TokenInfo from JSON | 从JSON解析Token信息 + // Parse TokenInfo | 解析Token信息 var info TokenInfo - if err := json.Unmarshal([]byte(str), &info); err != nil { - return nil, fmt.Errorf("%w: %v", ErrInvalidTokenData, err) + if err = m.serializer.Decode(raw, &info); err != nil { + return nil, fmt.Errorf("%w: %v", core.ErrDeserializeFailed, err) } return &info, nil } // renewToken Renews token expiration asynchronously | 异步续期Token -func (m *Manager) renewToken(tokenValue string) { - tokenKey := m.getTokenKey(tokenValue) - info, err := m.getTokenInfo(tokenValue) - if err != nil { - return +func (m *Manager) renewToken(ctx context.Context, tokenValue string, info *TokenInfo) error { + // Before renewing the token, check if the user is disabled | 在续期之前,先检查用户是否被禁用 + if m.IsDisable(ctx, info.LoginID) { + return core.ErrAccountDisabled } - // Basic validation | 基本校验 - if info == nil || info.LoginID == "" || info.Device == "" { - return + // If info is nil, retrieve token information | 如果info为空,获取Token信息 + if info == nil { + var err error + if info, err = m.getTokenInfo(ctx, tokenValue); err != nil { + return err + } } - // Update ActiveTime and keep original TTL | 更新 ActiveTime,保持原 TTL 不变 + // Get expiration time | 获取过期时间 + exp := m.getExpiration() + // Update ActiveTime | 更新ActiveTime info.ActiveTime = time.Now().Unix() - if tokenInfo, err := json.Marshal(info); err == nil { - _ = m.storage.SetKeepTTL(tokenKey, tokenInfo) - } - // Extend TTL for token and its accountKey | 为 Token 与对应 accountKey 延长 TTL - exp := m.getExpiration() - if exp > 0 { - // Renew token TTL | 续期 Token TTL - _ = m.storage.Expire(tokenKey, exp) + // Renew token TTL | 续期Token的TTL + tokenInfo, err := m.serializer.Encode(info) + if err != nil { + return fmt.Errorf("%w: %v", core.ErrSerializeFailed, err) + } + err = m.storage.Set(ctx, m.getTokenKey(tokenValue), tokenInfo, exp) + if err != nil { + return fmt.Errorf("%w: %v", core.ErrStorageUnavailable, err) + } - // Renew accountKey TTL | 续期账号映射 TTL - accountKey := m.getAccountKey(info.LoginID, info.Device) - _ = m.storage.Expire(accountKey, exp) + // Renew accountKey TTL | 续期账号映射的TTL + err = m.storage.Expire(ctx, m.getAccountKey(info.LoginID, info.Device), exp) + if err != nil { + return fmt.Errorf("%w: %v", core.ErrStorageUnavailable, err) + } - // Renew session TTL | 续期 Session TTL - if sess, err := m.GetSession(info.LoginID); err == nil && sess != nil { - _ = sess.Renew(exp) - } + // Renew session TTL | 续期Session的TTL + if err = m.RenewSession(ctx, info.LoginID, exp); err != nil { + return err } - // Set minimal renewal interval marker | 设置最小续期间隔标记(限流续期频率) + // Set minimal renewal interval marker | 设置最小续期间隔标记 if m.config.RenewInterval > 0 { - _ = m.storage.Set( + err = m.storage.Set( + ctx, m.getRenewKey(tokenValue), - DefaultRenewValue, + time.Now().Unix(), time.Duration(m.config.RenewInterval)*time.Second, ) + if err != nil { + return fmt.Errorf("%w: %v", core.ErrStorageUnavailable, err) + } } + + return nil } // removeTokenChain Removes all related keys and triggers event | 删除Token相关的所有键并触发事件 -func (m *Manager) removeTokenChain(tokenValue string, destroySession bool, event listener.Event) error { - if tokenValue == "" { - return nil - } - - // Get TokenInfo | 获取Token信息 - info, err := m.getTokenInfo(tokenValue, false) - if err != nil { - return err - } +func (m *Manager) removeTokenChain(ctx context.Context, tokenValue string, info *TokenInfo, event listener.Event, destroySession ...bool) error { + // If info is nil, retrieve token information | 如果info为空,获取Token信息 if info == nil { - return ErrInvalidTokenData + var err error + if info, err = m.getTokenInfo(ctx, tokenValue); err != nil { + return err + } } - tokenKey := m.getTokenKey(tokenValue) // Token存储键 | Token storage key - accountKey := m.getAccountKey(info.LoginID, info.Device) // Account映射键 | Account mapping key - renewKey := m.getRenewKey(tokenValue) // 续期追踪键 | Token renewal tracking key + // Construct the token storage key | 构造Token存储键 + tokenKey := m.getTokenKey(tokenValue) + // Construct the account storage key | 构造账号存储键 + accountKey := m.getAccountKey(info.LoginID, info.Device) + // Construct the renewal key | 构造续期标记 + renewKey := m.getRenewKey(tokenValue) + // Handle different events | 处理不同的事件 switch event { // EventLogout User logout | 用户主动登出 case listener.EventLogout: - _ = m.storage.Delete(tokenKey) // Delete token-info mapping | 删除Token信息映射 - _ = m.storage.Delete(accountKey) // Delete account-token mapping | 删除账号映射 - _ = m.storage.Delete(renewKey) // Delete renew key | 删除续期标记 - if destroySession { // Optionally destroy session | 可选销毁Session - _ = m.DeleteSession(info.LoginID) + // Delete token, account mapping, and renew key | 删除Token、账号映射和续期标记 + err := m.storage.Delete(ctx, tokenKey, accountKey, renewKey) + if err != nil { + return fmt.Errorf("%w: %v", core.ErrStorageUnavailable, err) + } + if len(destroySession) > 0 && destroySession[0] { + err = m.DeleteSession(ctx, info.LoginID) + if err != nil { + return err + } } - // EventKickout User kicked offline (keep session) | 用户被踢下线(保留Session) + // EventKickout User kicked offline | 用户被踢下线 case listener.EventKickout: - _ = m.storage.SetKeepTTL(tokenKey, string(TokenStateKickout)) // Mark token as kicked out (preserve original TTL for cleanup) | 将Token标记为“被踢下线”(保留原TTL以便自动清理) - _ = m.storage.Delete(accountKey) // Delete account mapping | 删除账号映射 - _ = m.storage.Delete(renewKey) // Delete renew key | 删除续期标记 + // Mark as kicked out but keep TTL | 标记为被踢下线,保留原TTL + err := m.storage.SetKeepTTL(ctx, tokenKey, string(TokenStateKickout)) + if err != nil { + return fmt.Errorf("%w: %v", core.ErrStorageUnavailable, err) + } + // Delete account mapping, renew key | 删除账号映射、续期标记 + err = m.storage.Delete(ctx, accountKey, renewKey) + if err != nil { + return fmt.Errorf("%w: %v", core.ErrStorageUnavailable, err) + } + + // EventReplace User replaced by new login (keep session) | 用户被顶下线(保留Session,自动过期) + case listener.EventReplace: + // Mark as replaced but keep TTL | 标记为被顶下线,保留原TTL + err := m.storage.SetKeepTTL(ctx, tokenKey, string(TokenStateReplaced)) + if err != nil { + return fmt.Errorf("%w: %v", core.ErrStorageUnavailable, err) + } + // Delete account mapping, renew key | 删除账号映射、续期标记 + err = m.storage.Delete(ctx, accountKey, renewKey) + if err != nil { + return fmt.Errorf("%w: %v", core.ErrStorageUnavailable, err) + } // Default Unknown event type | 未知事件类型(默认删除) default: - _ = m.storage.Delete(tokenKey) - _ = m.storage.Delete(accountKey) - _ = m.storage.Delete(renewKey) - if destroySession { - _ = m.DeleteSession(info.LoginID) + // Delete token, account mapping, and renew key | 删除Token、账号映射和续期标记 + err := m.storage.Delete(ctx, tokenKey, accountKey, renewKey) + if err != nil { + return fmt.Errorf("%w: %v", core.ErrStorageUnavailable, err) + } + if len(destroySession) > 0 && destroySession[0] { + err = m.DeleteSession(ctx, info.LoginID) + if err != nil { + return err + } } } // Trigger event notification | 触发事件通知 if m.eventManager != nil { m.eventManager.Trigger(&listener.EventData{ - Event: event, - LoginID: info.LoginID, - Token: tokenValue, - Device: info.Device, + Event: event, + AuthType: m.config.AuthType, + LoginID: info.LoginID, + Device: info.Device, + Token: tokenValue, }) } return nil } -// toStringSlice Converts any to []string | 将any转换为[]string -func (m *Manager) toStringSlice(v any) []string { - switch val := v.(type) { - case []string: - return val - case []any: - result := make([]string, 0, len(val)) - for _, item := range val { - if str, ok := item.(string); ok { - result = append(result, str) - } - } - return result - default: - return []string{} - } -} - -// ============ Event Management | 事件管理 ============ +// ============ Internal Helper Methods | 内部辅助方法 ============ -// RegisterFunc registers a function as an event listener | 注册函数作为事件监听器 -func (m *Manager) RegisterFunc(event listener.Event, fn func(*listener.EventData)) { - if m.eventManager != nil { - m.eventManager.RegisterFunc(event, fn) - } +// getTokenKey Gets token storage key | 获取Token存储键 +func (m *Manager) getTokenKey(tokenValue string) string { + return m.config.KeyPrefix + m.config.AuthType + TokenKeyPrefix + tokenValue } -// Register registers an event listener | 注册事件监听器 -func (m *Manager) Register(event listener.Event, listener listener.Listener) string { - if m.eventManager != nil { - return m.eventManager.Register(event, listener) - } - return "" +// getAccountKey Gets account storage key | 获取账号存储键 +func (m *Manager) getAccountKey(loginID, device string) string { + return m.config.KeyPrefix + m.config.AuthType + AccountKeyPrefix + loginID + PermissionSeparator + device } -// RegisterWithConfig registers an event listener with config | 注册带配置的事件监听器 -func (m *Manager) RegisterWithConfig(event listener.Event, listener listener.Listener, config listener.ListenerConfig) string { - if m.eventManager != nil { - return m.eventManager.RegisterWithConfig(event, listener, config) - } - return "" +// getRenewKey Gets token renewal tracking key | 获取Token续期追踪键 +func (m *Manager) getRenewKey(tokenValue string) string { + return m.config.KeyPrefix + m.config.AuthType + RenewKeyPrefix + tokenValue } -// Unregister removes an event listener by ID | 根据ID移除事件监听器 -func (m *Manager) Unregister(id string) bool { - if m.eventManager != nil { - return m.eventManager.Unregister(id) - } - return false +// getDisableKey Gets disable storage key | 获取禁用存储键 +func (m *Manager) getDisableKey(loginID string) string { + return m.config.KeyPrefix + m.config.AuthType + DisableKeyPrefix + loginID } -// TriggerEvent manually triggers an event | 手动触发事件 -func (m *Manager) TriggerEvent(data *listener.EventData) { - if m.eventManager != nil { - m.eventManager.Trigger(data) +// getDevice extracts device type from optional parameter | 从可选参数中提取设备类型 +func getDevice(device []string) string { + if len(device) > 0 && strings.TrimSpace(device[0]) != "" { + return device[0] } + return DefaultDevice } -// WaitEvents waits for all async event listeners to complete | 等待所有异步事件监听器完成 -func (m *Manager) WaitEvents() { - if m.eventManager != nil { - m.eventManager.Wait() +// getExpiration calculates expiration duration from config | 从配置计算过期时间 +func (m *Manager) getExpiration() time.Duration { + if m.config.Timeout > 0 { + return time.Duration(m.config.Timeout) * time.Second } + return 0 } -// GetEventManager gets the event manager | 获取事件管理器 -func (m *Manager) GetEventManager() *listener.Manager { - return m.eventManager -} - -// ============ Public Getters | 公共获取器 ============ - -// GetConfig Gets configuration | 获取配置 -func (m *Manager) GetConfig() *config.Config { - return m.config -} - -// GetStorage Gets storage | 获取存储 -func (m *Manager) GetStorage() adapter.Storage { - return m.storage -} - -// ============ Security Features | 安全特性 ============ - -// GenerateNonce Generates a one-time nonce | 生成一次性随机数 -func (m *Manager) GenerateNonce() (string, error) { - return m.nonceManager.Generate() -} - -// VerifyNonce Verifies a nonce | 验证随机数 -func (m *Manager) VerifyNonce(nonce string) bool { - return m.nonceManager.Verify(nonce) +// assertString asserts value as string safely | 安全断言值为字符串 +func assertString(v any) (string, bool) { + s, ok := v.(string) + return s, ok } -// LoginWithRefreshToken Logs in with refresh token | 使用刷新令牌登录 -func (m *Manager) LoginWithRefreshToken(loginID, device string) (*security.RefreshTokenInfo, error) { - deviceType := getDevice([]string{device}) - - accessToken, err := m.Login(loginID, deviceType) - if err != nil { - return nil, err +// toStringSlice Converts any to []string | 将any转换为[]string +func (m *Manager) toStringSlice(v any) []string { + switch val := v.(type) { + case []string: + return val + case []any: + result := make([]string, 0, len(val)) + for _, item := range val { + if str, ok := item.(string); ok { + result = append(result, str) + } + } + return result + default: + return []string{} } - - return m.refreshManager.GenerateTokenPair(loginID, deviceType, accessToken) -} - -// RefreshAccessToken Refreshes access token | 刷新访问令牌 -func (m *Manager) RefreshAccessToken(refreshToken string) (*security.RefreshTokenInfo, error) { - return m.refreshManager.RefreshAccessToken(refreshToken) } -// RevokeRefreshToken Revokes refresh token | 撤销刷新令牌 -func (m *Manager) RevokeRefreshToken(refreshToken string) error { - return m.refreshManager.RevokeRefreshToken(refreshToken) -} +// removeDuplicateStrings removes duplicate elements from []string | 去重字符串切片 +func removeDuplicateStrings(list []string) []string { + seen := make(map[string]struct{}, len(list)) + result := make([]string, 0, len(list)) -// GetOAuth2Server Gets OAuth2 server instance | 获取OAuth2服务器实例 -func (m *Manager) GetOAuth2Server() *oauth2.OAuth2Server { - return m.oauth2Server + for _, v := range list { + if _, exists := seen[v]; !exists { + seen[v] = struct{}{} + result = append(result, v) + } + } + return result } diff --git a/core/oauth2/consts.go b/core/oauth2/consts.go new file mode 100644 index 0000000..cb136aa --- /dev/null +++ b/core/oauth2/consts.go @@ -0,0 +1,33 @@ +// @Author daixk 2025/12/5 9:42:00 +package oauth2 + +import ( + "time" +) + +// Constants for OAuth2 | OAuth2常量 +const ( + DefaultCodeExpiration = 10 * time.Minute // Authorization code expiration | 授权码过期时间 + DefaultTokenExpiration = 2 * time.Hour // Access token expiration | 访问令牌过期时间 + DefaultRefreshTTL = 30 * 24 * time.Hour // Refresh token expiration | 刷新令牌过期时间 + + CodeLength = 32 // Authorization code byte length | 授权码字节长度 + AccessTokenLength = 32 // Access token byte length | 访问令牌字节长度 + RefreshTokenLength = 32 // Refresh token byte length | 刷新令牌字节长度 + + CodeKeySuffix = "oauth2:code:" // Code key suffix after prefix | 授权码键后缀 + TokenKeySuffix = "oauth2:token:" // Token key suffix after prefix | 令牌键后缀 + RefreshKeySuffix = "oauth2:refresh:" // Refresh key suffix after prefix | 刷新令牌键后缀 + + TokenTypeBearer = "Bearer" // Token type | 令牌类型 +) + +// GrantType OAuth2 grant type | OAuth2授权类型 +type GrantType string + +const ( + GrantTypeAuthorizationCode GrantType = "authorization_code" // Authorization code flow | 授权码模式 + GrantTypeRefreshToken GrantType = "refresh_token" // Refresh token flow | 刷新令牌模式 + GrantTypeClientCredentials GrantType = "client_credentials" // Client credentials flow | 客户端凭证模式 + GrantTypePassword GrantType = "password" // Password flow | 密码模式 +) diff --git a/core/oauth2/oauth2.go b/core/oauth2/oauth2.go index 1751668..a0050af 100644 --- a/core/oauth2/oauth2.go +++ b/core/oauth2/oauth2.go @@ -1,71 +1,43 @@ package oauth2 import ( + "context" "crypto/rand" "encoding/hex" "fmt" "sync" "time" + codec_json "github.com/click33/sa-token-go/codec/json" + "github.com/click33/sa-token-go/core" "github.com/click33/sa-token-go/core/adapter" + "github.com/click33/sa-token-go/core/utils" + "github.com/click33/sa-token-go/storage/memory" ) -// OAuth2 Authorization Code Flow Implementation -// OAuth2 授权码模式实现 +// Package oauth2 provides OAuth2 authorization server implementation +// OAuth2 授权服务器实现 // -// Flow | 流程: -// 1. RegisterClient() - Register OAuth2 client | 注册OAuth2客户端 -// 2. GenerateAuthorizationCode() - User authorizes, get code | 用户授权,获取授权码 -// 3. ExchangeCodeForToken() - Exchange code for access token | 用授权码换取访问令牌 -// 4. ValidateAccessToken() - Validate access token | 验证访问令牌 -// 5. RefreshAccessToken() - Use refresh token to get new token | 用刷新令牌获取新令牌 +// Supported Grant Types | 支持的授权类型: +// - Authorization Code (authorization_code) | 授权码模式 +// - Client Credentials (client_credentials) | 客户端凭证模式 +// - Password (password) | 密码模式 +// - Refresh Token (refresh_token) | 刷新令牌模式 +// +// Basic Flow | 基本流程: +// 1. RegisterClient() - Register OAuth2 client | 注册OAuth2客户端 +// 2. GenerateAuthorizationCode() - User authorizes, get code | 用户授权,获取授权码 +// 3. Token() or ExchangeCodeForToken() - Exchange code for access token | 用授权码换取访问令牌 +// 4. ValidateAccessToken() - Validate access token | 验证访问令牌 +// 5. RefreshAccessToken() - Use refresh token to get new token | 用刷新令牌获取新令牌 // // Usage | 用法: -// server := oauth2.NewOAuth2Server(storage) -// server.RegisterClient(&oauth2.Client{...}) -// authCode, _ := server.GenerateAuthorizationCode(...) -// token, _ := server.ExchangeCodeForToken(...) - -// Constants for OAuth2 | OAuth2常量 -const ( - DefaultCodeExpiration = 10 * time.Minute // Authorization code expiration | 授权码过期时间 - DefaultTokenExpiration = 2 * time.Hour // Access token expiration | 访问令牌过期时间 - DefaultRefreshTTL = 30 * 24 * time.Hour // Refresh token expiration | 刷新令牌过期时间 - - CodeLength = 32 // Authorization code byte length | 授权码字节长度 - AccessTokenLength = 32 // Access token byte length | 访问令牌字节长度 - RefreshTokenLength = 32 // Refresh token byte length | 刷新令牌字节长度 - - CodeKeySuffix = "oauth2:code:" // Code key suffix after prefix | 授权码键后缀 - TokenKeySuffix = "oauth2:token:" // Token key suffix after prefix | 令牌键后缀 - RefreshKeySuffix = "oauth2:refresh:" // Refresh key suffix after prefix | 刷新令牌键后缀 - - TokenTypeBearer = "Bearer" // Token type | 令牌类型 -) - -// Error variables | 错误变量 -var ( - ErrClientNotFound = fmt.Errorf("client not found") - ErrInvalidRedirectURI = fmt.Errorf("invalid redirect_uri") - ErrInvalidClientCredentials = fmt.Errorf("invalid client credentials") - ErrInvalidAuthCode = fmt.Errorf("invalid authorization code") - ErrAuthCodeUsed = fmt.Errorf("authorization code already used") - ErrAuthCodeExpired = fmt.Errorf("authorization code expired") - ErrClientMismatch = fmt.Errorf("client mismatch") - ErrRedirectURIMismatch = fmt.Errorf("redirect_uri mismatch") - ErrInvalidAccessToken = fmt.Errorf("invalid access token") - ErrInvalidTokenData = fmt.Errorf("invalid token data") -) - -// GrantType OAuth2 grant type | OAuth2授权类型 -type GrantType string +// +// server := oauth2.NewOAuth2Server(authType, prefix, storage, serializer) +// server.RegisterClient(&oauth2.Client{...}) +// token, _ := server.Token(ctx, &oauth2.TokenRequest{...}, nil) -const ( - GrantTypeAuthorizationCode GrantType = "authorization_code" // Authorization code flow | 授权码模式 - GrantTypeRefreshToken GrantType = "refresh_token" // Refresh token flow | 刷新令牌模式 - GrantTypeClientCredentials GrantType = "client_credentials" // Client credentials flow | 客户端凭证模式 - GrantTypePassword GrantType = "password" // Password flow | 密码模式 -) +// ============ Type Definitions | 类型定义 ============ // Client OAuth2 client configuration | OAuth2客户端配置 type Client struct { @@ -99,32 +71,62 @@ type AccessToken struct { ClientID string // Client ID | 客户端ID } +// TokenRequest Unified token request structure | 统一的令牌请求结构 +type TokenRequest struct { + GrantType GrantType // Required: grant type | 必需:授权类型 + ClientID string // Required: client ID | 必需:客户端ID + ClientSecret string // Required: client secret | 必需:客户端密钥 + Code string // For authorization_code: authorization code | 授权码模式:授权码 + RedirectURI string // For authorization_code: redirect URI | 授权码模式:回调URI + RefreshToken string // For refresh_token: refresh token | 刷新令牌模式:刷新令牌 + Username string // For password: username | 密码模式:用户名 + Password string // For password: password | 密码模式:密码 + Scopes []string // Optional: requested scopes | 可选:请求的权限范围 +} + +// UserValidator Function type for validating user credentials | 验证用户凭证的函数类型 +type UserValidator func(username, password string) (userID string, err error) + // OAuth2Server OAuth2 authorization server | OAuth2授权服务器 type OAuth2Server struct { - storage adapter.Storage - keyPrefix string // Configurable prefix | 可配置的前缀 - clients map[string]*Client - clientsMu sync.RWMutex // Clients map lock | 客户端映射锁 - codeExpiration time.Duration // Authorization code expiration (10min) | 授权码过期时间(10分钟) - tokenExpiration time.Duration // Access token expiration (2h) | 访问令牌过期时间(2小时) + authType string // Authentication system type | 认证体系类型 + keyPrefix string // Configurable prefix | 可配置的前缀 + clients map[string]*Client // client map | 客户端映射map + clientsMu sync.RWMutex // Clients map lock | 客户端映射锁 + codeExpiration time.Duration // Authorization code expiration (10min) | 授权码过期时间(10分钟) + tokenExpiration time.Duration // Access token expiration (2h) | 访问令牌过期时间(2小时) + serializer adapter.Codec // Codec adapter for encoding and decoding operations | 编解码器适配器 + storage adapter.Storage // Storage adapter (Redis, Memory, etc.) | 存储适配器(如 Redis、Memory) } +// ============ Constructor | 构造函数 ============ + // NewOAuth2Server Creates a new OAuth2 server | 创建新的OAuth2服务器 -// prefix: key prefix (e.g., "satoken:" or "" for Java compatibility) | 键前缀(如:"satoken:" 或 "" 兼容Java) -func NewOAuth2Server(storage adapter.Storage, prefix string) *OAuth2Server { +func NewOAuth2Server(authType, prefix string, storage adapter.Storage, serializer adapter.Codec) *OAuth2Server { + if storage == nil { + storage = memory.NewStorage() // default in-memory storage | 默认内存存储 + } + if serializer == nil { + serializer = codec_json.NewJSONSerializer() // default JSON serializer | 默认 JSON 编解码器 + } + return &OAuth2Server{ - storage: storage, - keyPrefix: prefix, - clients: make(map[string]*Client), - codeExpiration: DefaultCodeExpiration, - tokenExpiration: DefaultTokenExpiration, + authType: authType, // Auth system identifier | 认证体系标识 + keyPrefix: prefix, // Global key prefix | 全局Key前缀 + clients: make(map[string]*Client), // Initialize client registry | 初始化客户端注册表 + codeExpiration: DefaultCodeExpiration, // Default auth code TTL | 默认授权码有效期 + tokenExpiration: DefaultTokenExpiration, // Default access token TTL | 默认访问令牌有效期 + storage: storage, // Storage backend | 存储后端 + serializer: serializer, // Codec implementation | 编解码实现 } } +// ============ Client Management | 客户端管理 ============ + // RegisterClient Registers an OAuth2 client | 注册OAuth2客户端 func (s *OAuth2Server) RegisterClient(client *Client) error { if client == nil || client.ClientID == "" { - return fmt.Errorf("invalid client: clientID is required") + return core.ErrClientOrClientIDEmpty } s.clientsMu.Lock() @@ -149,15 +151,85 @@ func (s *OAuth2Server) GetClient(clientID string) (*Client, error) { client, exists := s.clients[clientID] if !exists { - return nil, ErrClientNotFound + return nil, core.ErrClientNotFound } + return client, nil } +// ============ Unified Token Endpoint | 统一令牌端点 ============ + +// Token Unified token endpoint that dispatches to appropriate handler based on grant type +// 统一的令牌端点,根据授权类型分发到相应的处理逻辑 +// +// This method provides a single entry point for all OAuth2 token operations. +// 此方法为所有 OAuth2 令牌操作提供统一入口。 +// +// Usage | 用法: +// +// // Authorization Code Grant | 授权码模式 +// token, err := server.Token(ctx, &TokenRequest{ +// GrantType: GrantTypeAuthorizationCode, +// ClientID: "client_id", +// ClientSecret: "client_secret", +// Code: "auth_code", +// RedirectURI: "https://example.com/callback", +// }, nil) +// +// // Client Credentials Grant | 客户端凭证模式 +// token, err := server.Token(ctx, &TokenRequest{ +// GrantType: GrantTypeClientCredentials, +// ClientID: "client_id", +// ClientSecret: "client_secret", +// Scopes: []string{"read", "write"}, +// }, nil) +// +// // Password Grant | 密码模式 +// token, err := server.Token(ctx, &TokenRequest{ +// GrantType: GrantTypePassword, +// ClientID: "client_id", +// ClientSecret: "client_secret", +// Username: "user", +// Password: "pass", +// Scopes: []string{"read"}, +// }, userValidator) +// +// // Refresh Token Grant | 刷新令牌模式 +// token, err := server.Token(ctx, &TokenRequest{ +// GrantType: GrantTypeRefreshToken, +// ClientID: "client_id", +// ClientSecret: "client_secret", +// RefreshToken: "refresh_token", +// }, nil) +func (s *OAuth2Server) Token(ctx context.Context, req *TokenRequest, validateUser UserValidator) (*AccessToken, error) { + if req == nil { + return nil, fmt.Errorf("%w: token request cannot be nil", core.ErrInvalidAuthCode) + } + + switch req.GrantType { + case GrantTypeAuthorizationCode: + return s.ExchangeCodeForToken(ctx, req.Code, req.ClientID, req.ClientSecret, req.RedirectURI) + + case GrantTypeClientCredentials: + return s.ClientCredentialsToken(ctx, req.ClientID, req.ClientSecret, req.Scopes) + + case GrantTypePassword: + return s.PasswordGrantToken(ctx, req.ClientID, req.ClientSecret, req.Username, req.Password, req.Scopes, validateUser) + + case GrantTypeRefreshToken: + return s.RefreshAccessToken(ctx, req.ClientID, req.RefreshToken, req.ClientSecret) + + default: + return nil, core.ErrInvalidGrantType + } +} + +// ============ Authorization Code Grant | 授权码模式 ============ + // GenerateAuthorizationCode Generates authorization code | 生成授权码 -func (s *OAuth2Server) GenerateAuthorizationCode(clientID, redirectURI, userID string, scopes []string) (*AuthorizationCode, error) { +func (s *OAuth2Server) GenerateAuthorizationCode(ctx context.Context, clientID, userID, redirectURI string, scopes []string) (*AuthorizationCode, error) { if userID == "" { - return nil, fmt.Errorf("userID cannot be empty") + return nil, core.ErrUserIDEmpty } client, err := s.GetClient(clientID) @@ -167,13 +239,18 @@ func (s *OAuth2Server) GenerateAuthorizationCode(clientID, redirectURI, userID s // Validate redirect URI | 验证回调URI if !s.isValidRedirectURI(client, redirectURI) { - return nil, ErrInvalidRedirectURI + return nil, core.ErrInvalidRedirectURI + } + + // Validate scopes | 验证权限范围 + if !s.isValidScopes(client, scopes) { + return nil, core.ErrInvalidScope } // Generate code | 生成授权码 codeBytes := make([]byte, CodeLength) - if _, err := rand.Read(codeBytes); err != nil { - return nil, fmt.Errorf("failed to generate authorization code: %w", err) + if _, err = rand.Read(codeBytes); err != nil { + return nil, err } code := hex.EncodeToString(codeBytes) @@ -188,26 +265,21 @@ func (s *OAuth2Server) GenerateAuthorizationCode(clientID, redirectURI, userID s Used: false, } + encodeData, err := s.serializer.Encode(authCode) + if err != nil { + return nil, fmt.Errorf("%w: %v", core.ErrSerializeFailed, err) + } + key := s.getCodeKey(code) - if err := s.storage.Set(key, authCode, s.codeExpiration); err != nil { - return nil, fmt.Errorf("failed to store authorization code: %w", err) + if err := s.storage.Set(ctx, key, encodeData, s.codeExpiration); err != nil { + return nil, fmt.Errorf("%w: %v", core.ErrStorageUnavailable, err) } return authCode, nil } -// isValidRedirectURI Checks if redirect URI is valid for client | 检查回调URI是否有效 -func (s *OAuth2Server) isValidRedirectURI(client *Client, redirectURI string) bool { - for _, uri := range client.RedirectURIs { - if uri == redirectURI { - return true - } - } - return false -} - // ExchangeCodeForToken Exchanges authorization code for access token | 用授权码换取访问令牌 -func (s *OAuth2Server) ExchangeCodeForToken(code, clientID, clientSecret, redirectURI string) (*AccessToken, error) { +func (s *OAuth2Server) ExchangeCodeForToken(ctx context.Context, code, clientID, clientSecret, redirectURI string) (*AccessToken, error) { // Verify client credentials | 验证客户端凭证 client, err := s.GetClient(clientID) if err != nil { @@ -215,110 +287,169 @@ func (s *OAuth2Server) ExchangeCodeForToken(code, clientID, clientSecret, redire } if client.ClientSecret != clientSecret { - return nil, ErrInvalidClientCredentials + return nil, core.ErrInvalidClientCredentials + } + + // Validate grant type | 验证授权类型 + if !s.isValidGrantType(client, GrantTypeAuthorizationCode) { + return nil, core.ErrInvalidGrantType } // Get authorization code | 获取授权码 key := s.getCodeKey(code) - data, err := s.storage.Get(key) - if err != nil || data == nil { - return nil, ErrInvalidAuthCode + data, err := s.storage.Get(ctx, key) + if err != nil { + return nil, fmt.Errorf("%w: %v", core.ErrStorageUnavailable, err) + } + if data == nil { + return nil, core.ErrInvalidAuthCode } - authCode, ok := data.(*AuthorizationCode) - if !ok { - return nil, fmt.Errorf("invalid code data") + rawData, err := utils.ToBytes(data) + if err != nil { + return nil, fmt.Errorf("%w: %v", core.ErrTypeConvert, err) + } + + var authCode AuthorizationCode + if err := s.serializer.Decode(rawData, &authCode); err != nil { + return nil, fmt.Errorf("%w: %v", core.ErrDeserializeFailed, err) } - // Validate authorization code | 验证授权码 if authCode.Used { - return nil, ErrAuthCodeUsed + return nil, core.ErrAuthCodeUsed } if authCode.ClientID != clientID { - return nil, ErrClientMismatch + return nil, core.ErrClientMismatch } if authCode.RedirectURI != redirectURI { - return nil, ErrRedirectURIMismatch + return nil, core.ErrRedirectURIMismatch } if time.Now().Unix() > authCode.CreateTime+authCode.ExpiresIn { - s.storage.Delete(key) - return nil, ErrAuthCodeExpired + _ = s.storage.Delete(ctx, key) + return nil, core.ErrAuthCodeExpired } // Mark code as used | 标记为已使用 authCode.Used = true - s.storage.Set(key, authCode, time.Minute) - return s.generateAccessToken(authCode.UserID, authCode.ClientID, authCode.Scopes) + encodeData, err := s.serializer.Encode(authCode) + if err != nil { + return nil, fmt.Errorf("%w: %v", core.ErrSerializeFailed, err) + } + + _ = s.storage.Set(ctx, key, encodeData, time.Minute) + + return s.generateAccessToken(ctx, authCode.UserID, authCode.ClientID, authCode.Scopes) } -// generateAccessToken Generates access token and refresh token | 生成访问令牌和刷新令牌 -func (s *OAuth2Server) generateAccessToken(userID, clientID string, scopes []string) (*AccessToken, error) { - // Generate access token | 生成访问令牌 - tokenBytes := make([]byte, AccessTokenLength) - if _, err := rand.Read(tokenBytes); err != nil { - return nil, fmt.Errorf("failed to generate access token: %w", err) +// ============ Client Credentials Grant | 客户端凭证模式 ============ + +// ClientCredentialsToken Gets access token using client credentials grant +// 使用客户端凭证模式获取访问令牌 +// +// This grant type is used for server-to-server communication where no user is involved. +// The client authenticates with its own credentials and receives an access token. +// 此授权类型用于服务器间通信,无需用户参与。客户端使用自己的凭证进行认证并获取访问令牌。 +// +// Usage | 用法: +// +// token, err := server.ClientCredentialsToken(ctx, "client_id", "client_secret", []string{"read", "write"}) +func (s *OAuth2Server) ClientCredentialsToken(ctx context.Context, clientID, clientSecret string, scopes []string) (*AccessToken, error) { + // Verify client credentials | 验证客户端凭证 + client, err := s.GetClient(clientID) + if err != nil { + return nil, err } - accessToken := hex.EncodeToString(tokenBytes) - // Generate refresh token | 生成刷新令牌 - refreshBytes := make([]byte, RefreshTokenLength) - if _, err := rand.Read(refreshBytes); err != nil { - return nil, fmt.Errorf("failed to generate refresh token: %w", err) + if client.ClientSecret != clientSecret { + return nil, core.ErrInvalidClientCredentials } - refreshToken := hex.EncodeToString(refreshBytes) - token := &AccessToken{ - Token: accessToken, - TokenType: TokenTypeBearer, - ExpiresIn: int64(s.tokenExpiration.Seconds()), - RefreshToken: refreshToken, - Scopes: scopes, - UserID: userID, - ClientID: clientID, + // Validate grant type | 验证授权类型 + if !s.isValidGrantType(client, GrantTypeClientCredentials) { + return nil, core.ErrInvalidGrantType } - tokenKey := s.getTokenKey(accessToken) - refreshKey := s.getRefreshKey(refreshToken) + // Validate scopes | 验证权限范围 + if !s.isValidScopes(client, scopes) { + return nil, core.ErrInvalidScope + } - // Store access token | 存储访问令牌 - if err := s.storage.Set(tokenKey, token, s.tokenExpiration); err != nil { - return nil, fmt.Errorf("failed to store access token: %w", err) + // For client credentials, userID is the clientID itself | 客户端凭证模式下,userID 就是 clientID + return s.generateAccessToken(ctx, clientID, clientID, scopes) +} + +// ============ Password Grant | 密码模式 ============ + +// PasswordGrantToken Gets access token using resource owner password credentials grant +// 使用密码模式获取访问令牌 +// +// This grant type is used when the application is highly trusted (e.g., official app). +// The user provides their username and password directly to the client. +// 此授权类型用于高度信任的应用(如官方App)。用户直接向客户端提供用户名和密码。 +// +// SECURITY WARNING: This grant type should only be used when other flows are not viable. +// 安全警告:仅在其他授权流程不可行时才应使用此授权类型。 +// +// Usage | 用法: +// +// validator := func(username, password string) (string, error) { +// // Validate user credentials from your user store +// if user := userService.Authenticate(username, password); user != nil { +// return user.ID, nil +// } +// return "", errors.New("invalid credentials") +// } +// token, err := server.PasswordGrantToken(ctx, "client_id", "client_secret", "user", "pass", scopes, validator) +func (s *OAuth2Server) PasswordGrantToken(ctx context.Context, clientID, clientSecret, username, password string, scopes []string, validateUser UserValidator) (*AccessToken, error) { + if validateUser == nil { + return nil, fmt.Errorf("%w: user validator function is required", core.ErrInvalidUserCredentials) } - // Store refresh token | 存储刷新令牌 - if err := s.storage.Set(refreshKey, token, DefaultRefreshTTL); err != nil { - return nil, fmt.Errorf("failed to store refresh token: %w", err) + // Verify client credentials | 验证客户端凭证 + client, err := s.GetClient(clientID) + if err != nil { + return nil, err } - return token, nil -} + if client.ClientSecret != clientSecret { + return nil, core.ErrInvalidClientCredentials + } -// ValidateAccessToken Validates access token | 验证访问令牌 -func (s *OAuth2Server) ValidateAccessToken(tokenString string) (*AccessToken, error) { - if tokenString == "" { - return nil, ErrInvalidAccessToken + // Validate grant type | 验证授权类型 + if !s.isValidGrantType(client, GrantTypePassword) { + return nil, core.ErrInvalidGrantType } - key := s.getTokenKey(tokenString) - data, err := s.storage.Get(key) - if err != nil || data == nil { - return nil, ErrInvalidAccessToken + // Validate scopes | 验证权限范围 + if !s.isValidScopes(client, scopes) { + return nil, core.ErrInvalidScope } - token, ok := data.(*AccessToken) - if !ok { - return nil, ErrInvalidTokenData + // Validate user credentials | 验证用户凭证 + userID, err := validateUser(username, password) + if err != nil { + return nil, fmt.Errorf("%w: %v", core.ErrInvalidUserCredentials, err) } - return token, nil + if userID == "" { + return nil, core.ErrUserIDEmpty + } + + return s.generateAccessToken(ctx, userID, clientID, scopes) } +// ============ Refresh Token Grant | 刷新令牌模式 ============ + // RefreshAccessToken Refreshes access token using refresh token | 使用刷新令牌刷新访问令牌 -func (s *OAuth2Server) RefreshAccessToken(refreshToken, clientID, clientSecret string) (*AccessToken, error) { +func (s *OAuth2Server) RefreshAccessToken(ctx context.Context, clientID, refreshToken, clientSecret string) (*AccessToken, error) { + if refreshToken == "" { + return nil, core.ErrInvalidRefreshToken + } + // Verify client credentials | 验证客户端凭证 client, err := s.GetClient(clientID) if err != nil { @@ -326,66 +457,231 @@ func (s *OAuth2Server) RefreshAccessToken(refreshToken, clientID, clientSecret s } if client.ClientSecret != clientSecret { - return nil, ErrInvalidClientCredentials + return nil, core.ErrInvalidClientCredentials + } + + // Validate grant type | 验证授权类型 + if !s.isValidGrantType(client, GrantTypeRefreshToken) { + return nil, core.ErrInvalidGrantType } // Get refresh token | 获取刷新令牌 key := s.getRefreshKey(refreshToken) - data, err := s.storage.Get(key) - if err != nil || data == nil { - return nil, fmt.Errorf("invalid refresh token") + data, err := s.storage.Get(ctx, key) + if err != nil { + return nil, fmt.Errorf("%w: %v", core.ErrStorageUnavailable, err) + } + if data == nil { + return nil, core.ErrInvalidRefreshToken + } + + rawData, err := utils.ToBytes(data) + if err != nil { + return nil, fmt.Errorf("%w: %v", core.ErrTypeConvert, err) } - oldToken, ok := data.(*AccessToken) - if !ok { - return nil, fmt.Errorf("invalid refresh token data") + var accessTokenInfo AccessToken + err = s.serializer.Decode(rawData, &accessTokenInfo) + if err != nil { + return nil, fmt.Errorf("%w: %v", core.ErrDeserializeFailed, err) } - if oldToken.ClientID != clientID { - return nil, ErrClientMismatch + if accessTokenInfo.ClientID != clientID { + return nil, core.ErrClientMismatch } - // Delete old access token | 删除旧的访问令牌 - oldTokenKey := s.getTokenKey(oldToken.Token) - s.storage.Delete(oldTokenKey) + // Delete old access token | 删除旧访问令牌 + _ = s.storage.Delete(ctx, s.getTokenKey(accessTokenInfo.Token)) + + // Delete old refresh token (token rotation) | 删除旧刷新令牌(令牌轮换) + _ = s.storage.Delete(ctx, key) - return s.generateAccessToken(oldToken.UserID, oldToken.ClientID, oldToken.Scopes) + return s.generateAccessToken(ctx, accessTokenInfo.UserID, accessTokenInfo.ClientID, accessTokenInfo.Scopes) } +// ============ Token Validation | 令牌验证 ============ + +// ValidateAccessToken Validates access token | 验证访问令牌 +func (s *OAuth2Server) ValidateAccessToken(ctx context.Context, accessToken string) bool { + return s.storage.Exists(ctx, s.getTokenKey(accessToken)) +} + +// ValidateAccessTokenAndGetInfo Validates access token and get info | 验证访问令牌并获取信息 +func (s *OAuth2Server) ValidateAccessTokenAndGetInfo(ctx context.Context, accessToken string) (*AccessToken, error) { + if accessToken == "" { + return nil, core.ErrInvalidAccessToken + } + + key := s.getTokenKey(accessToken) + data, err := s.storage.Get(ctx, key) + if err != nil { + return nil, fmt.Errorf("%w: %v", core.ErrStorageUnavailable, err) + } + if data == nil { + return nil, core.ErrInvalidAccessToken + } + + rawData, err := utils.ToBytes(data) + if err != nil { + return nil, fmt.Errorf("%w: %v", core.ErrTypeConvert, err) + } + + var accessTokenInfo AccessToken + err = s.serializer.Decode(rawData, &accessTokenInfo) + if err != nil { + return nil, fmt.Errorf("%w: %v", core.ErrDeserializeFailed, err) + } + + return &accessTokenInfo, nil +} + +// ============ Token Revocation | 令牌撤销 ============ + // RevokeToken Revokes access token and its refresh token | 撤销访问令牌及其刷新令牌 -func (s *OAuth2Server) RevokeToken(tokenString string) error { - if tokenString == "" { +func (s *OAuth2Server) RevokeToken(ctx context.Context, accessToken string) error { + if accessToken == "" { return nil } - key := s.getTokenKey(tokenString) - data, err := s.storage.Get(key) + key := s.getTokenKey(accessToken) + data, err := s.storage.Get(ctx, key) + if err != nil { + return fmt.Errorf("%w: %v", core.ErrStorageUnavailable, err) + } + if data == nil { + return core.ErrInvalidAccessToken + } + + rawData, err := utils.ToBytes(data) + if err != nil { + return fmt.Errorf("%w: %v", core.ErrTypeConvert, err) + } + + var accessTokenInfo AccessToken + err = s.serializer.Decode(rawData, &accessTokenInfo) if err != nil { - return err + return fmt.Errorf("%w: %v", core.ErrDeserializeFailed, err) } - // Revoke refresh token if exists | 如果存在则撤销刷新令牌 - if token, ok := data.(*AccessToken); ok && token.RefreshToken != "" { - refreshKey := s.getRefreshKey(token.RefreshToken) - s.storage.Delete(refreshKey) + if accessTokenInfo.RefreshToken != "" { + _ = s.storage.Delete(ctx, s.getRefreshKey(accessTokenInfo.RefreshToken)) } - return s.storage.Delete(key) + return s.storage.Delete(ctx, key) } -// ============ Helper Methods | 辅助方法 ============ +// ============ Private Helper Methods | 私有辅助方法 ============ // getCodeKey Gets storage key for authorization code | 获取授权码的存储键 func (s *OAuth2Server) getCodeKey(code string) string { - return s.keyPrefix + CodeKeySuffix + code + return s.keyPrefix + s.authType + CodeKeySuffix + code } // getTokenKey Gets storage key for access token | 获取访问令牌的存储键 func (s *OAuth2Server) getTokenKey(token string) string { - return s.keyPrefix + TokenKeySuffix + token + return s.keyPrefix + s.authType + TokenKeySuffix + token } // getRefreshKey Gets storage key for refresh token | 获取刷新令牌的存储键 func (s *OAuth2Server) getRefreshKey(refreshToken string) string { - return s.keyPrefix + RefreshKeySuffix + refreshToken + return s.keyPrefix + s.authType + RefreshKeySuffix + refreshToken +} + +// isValidRedirectURI Checks if redirect URI is valid for client | 检查回调URI是否有效 +func (s *OAuth2Server) isValidRedirectURI(client *Client, redirectURI string) bool { + for _, uri := range client.RedirectURIs { + if uri == redirectURI { + return true + } + } + return false +} + +// isValidScopes Checks if requested scopes are allowed for client | 检查请求的权限范围是否被允许 +func (s *OAuth2Server) isValidScopes(client *Client, scopes []string) bool { + // If no scopes requested, allow | 如果没有请求scope,允许 + if len(scopes) == 0 { + return true + } + + // If client has no scope restrictions, allow all | 如果客户端没有scope限制,允许所有 + if len(client.Scopes) == 0 { + return true + } + + // Build allowed scopes set | 构建允许的scope集合 + allowedScopes := make(map[string]struct{}, len(client.Scopes)) + for _, scope := range client.Scopes { + allowedScopes[scope] = struct{}{} + } + + // Check if all requested scopes are allowed | 检查所有请求的scope是否都被允许 + for _, scope := range scopes { + if _, ok := allowedScopes[scope]; !ok { + return false + } + } + + return true +} + +// isValidGrantType Checks if grant type is allowed for client | 检查授权类型是否被允许 +func (s *OAuth2Server) isValidGrantType(client *Client, grantType GrantType) bool { + // If client has no grant type restrictions, allow all | 如果客户端没有授权类型限制,允许所有 + if len(client.GrantTypes) == 0 { + return true + } + + for _, gt := range client.GrantTypes { + if gt == grantType { + return true + } + } + return false +} + +// generateAccessToken Generates access token and refresh token | 生成访问令牌和刷新令牌 +func (s *OAuth2Server) generateAccessToken(ctx context.Context, userID, clientID string, scopes []string) (*AccessToken, error) { + // Generate access token | 生成访问令牌 + tokenBytes := make([]byte, AccessTokenLength) + if _, err := rand.Read(tokenBytes); err != nil { + return nil, err + } + accessToken := hex.EncodeToString(tokenBytes) + + // Generate refresh token | 生成刷新令牌 + refreshBytes := make([]byte, RefreshTokenLength) + if _, err := rand.Read(refreshBytes); err != nil { + return nil, err + } + refreshToken := hex.EncodeToString(refreshBytes) + + token := &AccessToken{ + Token: accessToken, + TokenType: TokenTypeBearer, + ExpiresIn: int64(s.tokenExpiration.Seconds()), + RefreshToken: refreshToken, + Scopes: scopes, + UserID: userID, + ClientID: clientID, + } + encodeData, err := s.serializer.Encode(token) + if err != nil { + return nil, fmt.Errorf("%w: %v", core.ErrSerializeFailed, err) + } + + tokenKey := s.getTokenKey(accessToken) + refreshKey := s.getRefreshKey(refreshToken) + + // Store access token | 存储访问令牌 + if err = s.storage.Set(ctx, tokenKey, encodeData, s.tokenExpiration); err != nil { + return nil, fmt.Errorf("%w: %v", core.ErrStorageUnavailable, err) + } + + // Store refresh token | 存储刷新令牌 + if err = s.storage.Set(ctx, refreshKey, encodeData, DefaultRefreshTTL); err != nil { + return nil, fmt.Errorf("%w: %v", core.ErrStorageUnavailable, err) + } + + return token, nil } diff --git a/core/pool/pool.go b/core/pool/pool.go deleted file mode 100644 index 57f35a8..0000000 --- a/core/pool/pool.go +++ /dev/null @@ -1,269 +0,0 @@ -// @Author daixk 2025-10-28 22:00:20 -package pool - -import ( - "fmt" - "sync" - "time" - - "github.com/panjf2000/ants/v2" -) - -// Default configuration constants | 默认配置常量 -const ( - DefaultMinSize = 100 // Minimum pool size | 最小协程数 - DefaultMaxSize = 2000 // Maximum pool size | 最大协程数 - DefaultScaleUpRate = 0.8 // Scale-up threshold (expand when usage exceeds this ratio) | 扩容阈值,当使用率超过此比例时扩容 - DefaultScaleDownRate = 0.3 // Scale-down threshold (shrink when usage below this ratio) | 缩容阈值,当使用率低于此比例时缩容 - DefaultCheckInterval = time.Minute // Interval for auto-scaling checks | 检查间隔 - DefaultExpiry = 10 * time.Second // Idle worker expiry duration | 空闲协程过期时间 -) - -// RenewPoolConfig configuration for the renewal pool manager | 续期池配置 -type RenewPoolConfig struct { - MinSize int // Minimum pool size | 最小协程数 - MaxSize int // Maximum pool size | 最大协程数 - ScaleUpRate float64 // Scale-up threshold | 扩容阈值 - ScaleDownRate float64 // Scale-down threshold | 缩容阈值 - CheckInterval time.Duration // Auto-scale check interval | 检查间隔 - Expiry time.Duration // Idle worker expiry duration | 空闲协程过期时间 - PrintStatusInterval time.Duration // Interval for periodic status printing (0 = disabled) | 定时打印池状态的间隔(0表示关闭) - PreAlloc bool // Whether to pre-allocate memory | 是否预分配内存 - NonBlocking bool // Whether to use non-blocking mode | 是否为非阻塞模式 -} - -// DefaultRenewPoolConfig returns default configuration | 返回默认配置 -func DefaultRenewPoolConfig() *RenewPoolConfig { - return &RenewPoolConfig{ - MinSize: DefaultMinSize, - MaxSize: DefaultMaxSize, - ScaleUpRate: DefaultScaleUpRate, - ScaleDownRate: DefaultScaleDownRate, - CheckInterval: DefaultCheckInterval, - Expiry: DefaultExpiry, - PreAlloc: false, - NonBlocking: true, - } -} - -// RenewPoolManager manages a dynamic scaling goroutine pool for token renewal tasks | 续期任务协程池管理器 -type RenewPoolManager struct { - pool *ants.Pool // ants pool instance | ants 协程池实例 - config *RenewPoolConfig // Configuration object | 池配置对象 - mu sync.Mutex // Synchronization lock | 互斥锁 - stopCh chan struct{} // Stop signal channel | 停止信号通道 - started bool // Indicates if pool manager is running | 是否已启动 -} - -// NewRenewPoolManagerWithConfig creates manager with config | 使用配置创建续期池管理器 -func NewRenewPoolManagerWithConfig(cfg *RenewPoolConfig) (*RenewPoolManager, error) { - if cfg == nil { - cfg = DefaultRenewPoolConfig() - } - if cfg.MinSize <= 0 { - cfg.MinSize = DefaultMinSize - } - if cfg.MaxSize < cfg.MinSize { - cfg.MaxSize = cfg.MinSize - } - - mgr := &RenewPoolManager{ - config: cfg, - stopCh: make(chan struct{}), - started: true, - } - - if err := mgr.initPool(); err != nil { - return nil, err - } - - // Start auto-scaling routine | 启动自动扩缩容协程 - go mgr.autoScale() - - // Start periodic pool status printer if interval is set | 若设置了打印间隔,则启动定时打印池状态的协程 - if cfg.PrintStatusInterval > 0 { - go func() { - ticker := time.NewTicker(cfg.PrintStatusInterval) // Create ticker for status printing | 创建定时器用于打印状态 - defer ticker.Stop() // Stop ticker on exit | 退出时停止定时器 - - for { - select { - case <-ticker.C: - mgr.PrintStatus() // Print current pool status | 打印当前协程池状态 - case <-mgr.stopCh: - return // Exit when stop signal received | 收到停止信号后退出 - } - } - }() - } - - return mgr, nil -} - -// initPool initializes the ants pool | 初始化 ants 协程池 -func (m *RenewPoolManager) initPool() error { - p, err := ants.NewPool( - m.config.MinSize, - ants.WithExpiryDuration(m.config.Expiry), - ants.WithPreAlloc(m.config.PreAlloc), - ants.WithNonblocking(m.config.NonBlocking), - ) - if err != nil { - return err - } - m.pool = p - return nil -} - -// Submit submits a renewal task | 提交续期任务 -func (m *RenewPoolManager) Submit(task func()) error { - if !m.started { - return fmt.Errorf("RenewPool not started") - } - return m.pool.Submit(task) -} - -// Stop stops the auto-scaling process | 停止自动扩缩容 -func (m *RenewPoolManager) Stop() { - if !m.started { - return - } - close(m.stopCh) - m.started = false - - if m.pool != nil && !m.pool.IsClosed() { - _ = m.pool.ReleaseTimeout(10 * time.Second) - } -} - -// autoScale automatic pool scale-up/down logic | 自动扩缩容逻辑 -func (m *RenewPoolManager) autoScale() { - ticker := time.NewTicker(m.config.CheckInterval) // Ticker for periodic usage checks | 定时器,用于定期检测使用率 - defer ticker.Stop() // Stop ticker on exit | 函数退出时停止定时器 - - for { - select { - case <-ticker.C: - m.mu.Lock() // Protect concurrent access | 加锁防止并发冲突 - - // Get current pool stats | 获取当前运行状态 - running := m.pool.Running() // Number of active goroutines | 当前正在执行的任务数 - capacity := m.pool.Cap() // Current pool capacity | 当前协程池容量 - usage := float64(running) / float64(capacity) // Current usage ratio | 当前使用率(运行数 ÷ 总容量) - - switch { - // Expand if usage exceeds threshold and capacity < MaxSize | 当使用率超过扩容阈值且容量小于最大值时扩容 - case usage > m.config.ScaleUpRate && capacity < m.config.MaxSize: - newCap := int(float64(capacity) * 1.5) // Increase capacity by 1.5x | 扩容为当前的 1.5 倍 - if newCap > m.config.MaxSize { // Cap to maximum size | 限制最大值 - newCap = m.config.MaxSize - } - m.pool.Tune(newCap) // Apply new pool capacity | 调整 ants 池容量 - - // Reduce if usage below threshold and capacity > MinSize | 当使用率低于缩容阈值且容量大于最小值时缩容 - case usage < m.config.ScaleDownRate && capacity > m.config.MinSize: - newCap := int(float64(capacity) * 0.7) // Reduce capacity to 70% | 缩容为当前的 70% - if newCap < m.config.MinSize { // Ensure not below MinSize | 限制最小值 - newCap = m.config.MinSize - } - m.pool.Tune(newCap) // Apply new pool capacity | 调整 ants 池容量 - } - - m.mu.Unlock() // Unlock after adjustment | 解锁 - - case <-m.stopCh: - // Stop signal received, exit loop | 收到停止信号,终止扩缩容协程 - return - } - } -} - -// Stats returns current pool statistics | 返回当前池状态 -func (m *RenewPoolManager) Stats() (running, capacity int, usage float64) { - m.mu.Lock() - defer m.mu.Unlock() - running = m.pool.Running() // Active tasks | 当前运行任务数 - capacity = m.pool.Cap() // Pool capacity | 当前池容量 - usage = float64(running) / float64(capacity) // Usage ratio | 当前使用率 - return -} - -// PrintStatus prints current pool status | 打印池状态 -func (m *RenewPoolManager) PrintStatus() { - r, c, u := m.Stats() - fmt.Printf("RenewPool Running: %d, Capacity: %d, Usage: %.1f%%\n", r, c, u*100) -} - -// RenewPoolBuilder builder for RenewPoolManager | RenewPoolManager 构造器 -type RenewPoolBuilder struct { - cfg *RenewPoolConfig // Builder configuration | 构造器配置对象 -} - -// NewRenewPoolBuilder creates a new builder | 创建构造器 -func NewRenewPoolBuilder() *RenewPoolBuilder { - return &RenewPoolBuilder{cfg: DefaultRenewPoolConfig()} -} - -// MinSize sets minimum pool size | 设置最小协程数 -func (b *RenewPoolBuilder) MinSize(size int) *RenewPoolBuilder { - b.cfg.MinSize = size - return b -} - -// MaxSize sets maximum pool size | 设置最大协程数 -func (b *RenewPoolBuilder) MaxSize(size int) *RenewPoolBuilder { - b.cfg.MaxSize = size - return b -} - -// ScaleUpRate sets the threshold for scaling up | 设置扩容阈值 -func (b *RenewPoolBuilder) ScaleUpRate(up float64) *RenewPoolBuilder { - b.cfg.ScaleUpRate = up - return b -} - -// ScaleDownRate sets the threshold for scaling down | 设置缩容阈值 -func (b *RenewPoolBuilder) ScaleDownRate(down float64) *RenewPoolBuilder { - b.cfg.ScaleDownRate = down - return b -} - -// CheckInterval sets auto-scaling check interval | 设置检查间隔 -func (b *RenewPoolBuilder) CheckInterval(interval time.Duration) *RenewPoolBuilder { - b.cfg.CheckInterval = interval - return b -} - -// Expiry sets worker expiry duration | 设置空闲协程过期时间 -func (b *RenewPoolBuilder) Expiry(expiry time.Duration) *RenewPoolBuilder { - b.cfg.Expiry = expiry - return b -} - -// PrintStatusInterval sets the interval for printing pool status | 设置打印状态的间隔 -func (b *RenewPoolBuilder) PrintStatusInterval(interval time.Duration) *RenewPoolBuilder { - b.cfg.PrintStatusInterval = interval - return b -} - -// PreAlloc sets pre-allocation flag | 设置是否预分配内存 -func (b *RenewPoolBuilder) PreAlloc(prealloc bool) *RenewPoolBuilder { - b.cfg.PreAlloc = prealloc - return b -} - -// NonBlocking sets non-blocking mode | 设置是否非阻塞模式 -func (b *RenewPoolBuilder) NonBlocking(nonblocking bool) *RenewPoolBuilder { - b.cfg.NonBlocking = nonblocking - return b -} - -// Config returns the current RenewPoolConfig | 返回当前的续期池配置 -func (b *RenewPoolBuilder) Config() *RenewPoolConfig { - return b.cfg -} - -// Build constructs a RenewPoolManager instance | 构建 RenewPoolManager 实例 -func (b *RenewPoolBuilder) Build() (*RenewPoolManager, error) { - return NewRenewPoolManagerWithConfig(b.cfg) -} diff --git a/core/satoken.go b/core/satoken.go deleted file mode 100644 index fff2046..0000000 --- a/core/satoken.go +++ /dev/null @@ -1,191 +0,0 @@ -package core - -import ( - "time" - - "github.com/click33/sa-token-go/core/adapter" - "github.com/click33/sa-token-go/core/builder" - "github.com/click33/sa-token-go/core/config" - "github.com/click33/sa-token-go/core/context" - "github.com/click33/sa-token-go/core/listener" - "github.com/click33/sa-token-go/core/manager" - "github.com/click33/sa-token-go/core/oauth2" - "github.com/click33/sa-token-go/core/security" - "github.com/click33/sa-token-go/core/session" - "github.com/click33/sa-token-go/core/token" - "github.com/click33/sa-token-go/core/utils" -) - -// Version Sa-Token-Go version | Sa-Token-Go版本 -const Version = "0.1.3" - -// ============ Exported Types | 导出的类型 ============ -// Export main types and functions for external use | 导出主要类型和函数,方便外部使用 - -// Configuration related types | 配置相关类型 -type ( - Config = config.Config - CookieConfig = config.CookieConfig - TokenStyle = config.TokenStyle -) - -// Token style constants | Token风格常量 -const ( - TokenStyleUUID = config.TokenStyleUUID - TokenStyleSimple = config.TokenStyleSimple - TokenStyleRandom32 = config.TokenStyleRandom32 - TokenStyleRandom64 = config.TokenStyleRandom64 - TokenStyleRandom128 = config.TokenStyleRandom128 - TokenStyleJWT = config.TokenStyleJWT - TokenStyleHash = config.TokenStyleHash - TokenStyleTimestamp = config.TokenStyleTimestamp - TokenStyleTik = config.TokenStyleTik -) - -// Core types | 核心类型 -type ( - Manager = manager.Manager - TokenInfo = manager.TokenInfo - Session = session.Session - TokenGenerator = token.Generator - SaTokenContext = context.SaTokenContext - Builder = builder.Builder - NonceManager = security.NonceManager - RefreshTokenInfo = security.RefreshTokenInfo - RefreshTokenManager = security.RefreshTokenManager - OAuth2Server = oauth2.OAuth2Server - OAuth2Client = oauth2.Client - OAuth2AccessToken = oauth2.AccessToken - OAuth2GrantType = oauth2.GrantType -) - -// Adapter interfaces | 适配器接口 -type ( - Storage = adapter.Storage - RequestContext = adapter.RequestContext -) - -// Event related types | 事件相关类型 -type ( - EventListener = listener.Listener - EventManager = listener.Manager - EventData = listener.EventData - Event = listener.Event - ListenerFunc = listener.ListenerFunc - ListenerConfig = listener.ListenerConfig -) - -// Event constants | 事件常量 -const ( - EventLogin = listener.EventLogin - EventLogout = listener.EventLogout - EventKickout = listener.EventKickout - EventDisable = listener.EventDisable - EventUntie = listener.EventUntie - EventRenew = listener.EventRenew - EventCreateSession = listener.EventCreateSession - EventDestroySession = listener.EventDestroySession - EventPermissionCheck = listener.EventPermissionCheck - EventRoleCheck = listener.EventRoleCheck - EventAll = listener.EventAll -) - -const ( - GrantTypeAuthorizationCode = oauth2.GrantTypeAuthorizationCode - GrantTypeRefreshToken = oauth2.GrantTypeRefreshToken - GrantTypeClientCredentials = oauth2.GrantTypeClientCredentials - GrantTypePassword = oauth2.GrantTypePassword -) - -// ============ Utility Functions | 工具函数 ============ - -var ( - // String utilities | 字符串工具 - RandomString = utils.RandomString - RandomNumericString = utils.RandomNumericString - RandomAlphanumeric = utils.RandomAlphanumeric - IsEmpty = utils.IsEmpty - IsNotEmpty = utils.IsNotEmpty - DefaultString = utils.DefaultString - - // Slice utilities | 切片工具 - ContainsString = utils.ContainsString - RemoveString = utils.RemoveString - UniqueStrings = utils.UniqueStrings - MergeStrings = utils.MergeStrings - FilterStrings = utils.FilterStrings - MapStrings = utils.MapStrings - - // Pattern matching | 模式匹配 - MatchPattern = utils.MatchPattern - - // Duration utilities | 时长工具 - FormatDuration = utils.FormatDuration - ParseDuration = utils.ParseDuration - - // Hash & Encoding | 哈希和编码 - SHA256Hash = utils.SHA256Hash - Base64Encode = utils.Base64Encode - Base64Decode = utils.Base64Decode -) - -// ============ Factory Functions | 工厂函数 ============ - -// DefaultConfig Returns default configuration | 返回默认配置 -func DefaultConfig() *Config { - return config.DefaultConfig() -} - -// NewManager Creates a new authentication manager | 创建新的认证管理器 -func NewManager(storage Storage, cfg *Config) *Manager { - return manager.NewManager(storage, cfg) -} - -// NewContext Creates a new Sa-Token context | 创建新的Sa-Token上下文 -func NewContext(ctx RequestContext, mgr *Manager) *SaTokenContext { - return context.NewContext(ctx, mgr) -} - -// NewSession Creates a new session | 创建新的Session -func NewSession(id string, storage Storage, prefix string) *Session { - return session.NewSession(id, storage, prefix) -} - -// LoadSession Loads an existing session | 加载已存在的Session -func LoadSession(id string, storage Storage, prefix string) (*Session, error) { - return session.Load(id, storage, prefix) -} - -// NewTokenGenerator Creates a new token generator | 创建新的Token生成器 -func NewTokenGenerator(cfg *Config) *TokenGenerator { - return token.NewGenerator(cfg) -} - -// NewEventManager Creates a new event manager | 创建新的事件管理器 -func NewEventManager() *EventManager { - return listener.NewManager() -} - -// NewBuilder Creates a new builder for fluent configuration | 创建新的Builder构建器(用于流式配置) -func NewBuilder() *Builder { - return builder.NewBuilder() -} - -// NewNonceManager Creates a new nonce manager | 创建新的Nonce管理器 -func NewNonceManager(storage Storage, prefix string, ttl ...int64) *NonceManager { - var duration time.Duration - if len(ttl) > 0 && ttl[0] > 0 { - duration = time.Duration(ttl[0]) * time.Second - } - return security.NewNonceManager(storage, prefix, duration) -} - -// NewRefreshTokenManager Creates a new refresh token manager | 创建新的刷新令牌管理器 -func NewRefreshTokenManager(storage Storage, prefix string, cfg *Config) *RefreshTokenManager { - return security.NewRefreshTokenManager(storage, prefix, manager.TokenKeyPrefix, cfg) -} - -// NewOAuth2Server Creates a new OAuth2 server | 创建新的OAuth2服务器 -func NewOAuth2Server(storage Storage, prefix string) *OAuth2Server { - return oauth2.NewOAuth2Server(storage, prefix) -} diff --git a/core/security/consts.go b/core/security/consts.go new file mode 100644 index 0000000..e6b1ba9 --- /dev/null +++ b/core/security/consts.go @@ -0,0 +1,21 @@ +// @Author daixk 2025/12/11 22:20:00 +package security + +import ( + "time" +) + +// Constants for nonce | Nonce常量 +const ( + DefaultNonceTTL = 5 * time.Minute // Default nonce expiration | 默认nonce过期时间 + NonceLength = 32 // Nonce byte length | Nonce字节长度 + NonceKeySuffix = "nonce:" // Key suffix after prefix | 前缀后的键后缀 +) + +// Constants for refresh token | 刷新令牌常量 +const ( + DefaultRefreshTTL = 30 * 24 * time.Hour // 30 days | 30天 + DefaultAccessTTL = 2 * time.Hour // 2 hours | 2小时 + RefreshTokenLength = 32 // Refresh token byte length | 刷新令牌字节长度 + RefreshKeySuffix = "refresh:" // Key suffix after prefix | 前缀后的键后缀 +) diff --git a/core/security/nonce.go b/core/security/nonce.go index 9b9dd54..2565237 100644 --- a/core/security/nonce.go +++ b/core/security/nonce.go @@ -1,122 +1,122 @@ package security import ( + "context" "crypto/rand" "encoding/hex" "fmt" + "github.com/click33/sa-token-go/core" + "github.com/click33/sa-token-go/core/adapter" + "github.com/click33/sa-token-go/storage/memory" "sync" "time" - - "github.com/click33/sa-token-go/core/adapter" ) -// Nonce Anti-Replay Attack Implementation -// Nonce 防重放攻击实现 +// Nonce Anti-Replay Attack Implementation | Nonce 防重放攻击实现 // // Flow | 流程: // 1. Generate() - Create unique nonce and store with TTL | 生成唯一nonce并存储(带过期时间) // 2. Verify() - Check existence and delete (one-time use) | 检查存在性并删除(一次性使用) -// 3. Auto-expire after TTL (default 5min) | TTL后自动过期(默认5分钟) +// 3. Auto-expire after TTL (log 5min) | TTL后自动过期(默认5分钟) // // Usage | 用法: -// nonce, _ := manager.GenerateNonce() -// valid := manager.VerifyNonce(nonce) // true -// valid = manager.VerifyNonce(nonce) // false (replay prevented) - -// Constants for nonce | Nonce常量 -const ( - DefaultNonceTTL = 5 * time.Minute // Default nonce expiration | 默认nonce过期时间 - NonceLength = 32 // Nonce byte length | Nonce字节长度 - NonceKeySuffix = "nonce:" // Key suffix after prefix | 前缀后的键后缀 -) - -// Error variables | 错误变量 -var ( - ErrInvalidNonce = fmt.Errorf("invalid or expired nonce") -) +// nonce, _ := manager-example.Generate() +// valid := manager-example.Verify(nonce) // true +// valid = manager-example.Verify(nonce) // false (replay prevented) -// NonceManager Nonce manager for anti-replay attacks | Nonce管理器,用于防重放攻击 +// NonceManager Nonce manager-example for anti-replay attacks | Nonce管理器,用于防重放攻击 type NonceManager struct { - storage adapter.Storage - keyPrefix string // Configurable prefix | 可配置的前缀 - ttl time.Duration - mu sync.RWMutex + authType string // Authentication system type | 认证体系类型 + keyPrefix string // Configurable prefix | 可配置的前缀 + ttl time.Duration // Nonce TTL | Nonce有效期 + mu sync.RWMutex // RWMutex for concurrent access | 并发访问读写锁 + storage adapter.Storage // Storage adapter (Redis, Memory, etc.) | 存储适配器(如 Redis、Memory) } -// NewNonceManager Creates a new nonce manager | 创建新的Nonce管理器 -// prefix: key prefix (e.g., "satoken:" or "" for Java compatibility) | 键前缀(如:"satoken:" 或 "" 兼容Java) -// ttl: time to live, default 5 minutes | 过期时间,默认5分钟 -func NewNonceManager(storage adapter.Storage, prefix string, ttl time.Duration) *NonceManager { +// NewNonceManager Creates a new nonce manager-example | 创建新的Nonce管理器 +func NewNonceManager(authType, prefix string, storage adapter.Storage, ttl time.Duration) *NonceManager { if ttl == 0 { - ttl = DefaultNonceTTL + ttl = DefaultNonceTTL // Default TTL 5 minutes | 默认5分钟 } + if storage == nil { + storage = memory.NewStorage() // Use in-memory storage if not provided | 如果未提供使用内存存储 + } + return &NonceManager{ - storage: storage, + authType: authType, keyPrefix: prefix, + storage: storage, ttl: ttl, } } // Generate Generates a new nonce and stores it | 生成新的nonce并存储 -// Returns 64-char hex string | 返回64字符的十六进制字符串 -func (nm *NonceManager) Generate() (string, error) { +func (nm *NonceManager) Generate(ctx context.Context) (string, error) { + // Create byte slice for nonce | 创建字节切片生成nonce bytes := make([]byte, NonceLength) if _, err := rand.Read(bytes); err != nil { - return "", fmt.Errorf("failed to generate random bytes: %w", err) + return "", err } + // Encode bytes to hex string | 编码为16进制字符串 nonce := hex.EncodeToString(bytes) + // Build storage key | 构建存储键 key := nm.getNonceKey(nonce) - if err := nm.storage.Set(key, time.Now().Unix(), nm.ttl); err != nil { - return "", fmt.Errorf("failed to store nonce: %w", err) + if err := nm.storage.Set(ctx, key, time.Now().Unix(), nm.ttl); err != nil { + return "", fmt.Errorf("%w: %v", core.ErrStorageUnavailable, err) } return nonce, nil } -// Verify Verifies nonce and consumes it (one-time use) | 验证nonce并消费它(一次性使用) -// Returns false if nonce doesn't exist or already used | 如果nonce不存在或已使用则返回false -func (nm *NonceManager) Verify(nonce string) bool { +// Verify Verifies nonce and consumes it (one-time use) Returns false if nonce doesn't exist or already used | 验证nonce并消费它(一次性使用)如果nonce不存在或已使用则返回false +func (nm *NonceManager) Verify(ctx context.Context, nonce string) bool { if nonce == "" { return false } + // Build storage key | 构建存储键 key := nm.getNonceKey(nonce) - nm.mu.Lock() - defer nm.mu.Unlock() + nm.mu.Lock() // Acquire write lock | 获取写锁 + defer nm.mu.Unlock() // Release lock after function | 函数结束释放锁 - if !nm.storage.Exists(key) { + // Nonce not found | 未找到nonce + if !nm.storage.Exists(ctx, key) { return false } - nm.storage.Delete(key) + // Consume nonce | 消耗nonce + _ = nm.storage.Delete(ctx, key) + return true } // VerifyAndConsume Verifies and consumes nonce, returns error if invalid | 验证并消费nonce,无效时返回错误 -func (nm *NonceManager) VerifyAndConsume(nonce string) error { - if !nm.Verify(nonce) { - return ErrInvalidNonce +func (nm *NonceManager) VerifyAndConsume(ctx context.Context, nonce string) error { + if !nm.Verify(ctx, nonce) { + return core.ErrInvalidNonce } return nil } // IsValid Checks if nonce is valid without consuming it | 检查nonce是否有效(不消费) -func (nm *NonceManager) IsValid(nonce string) bool { +func (nm *NonceManager) IsValid(ctx context.Context, nonce string) bool { if nonce == "" { return false } + // Build storage key | 构建存储键 key := nm.getNonceKey(nonce) - nm.mu.RLock() - defer nm.mu.RUnlock() + nm.mu.RLock() // Acquire read lock | 获取读锁 + defer nm.mu.RUnlock() // Release read lock | 释放读锁 - return nm.storage.Exists(key) + // Return existence | 返回是否存在 + return nm.storage.Exists(ctx, key) } // getNonceKey Gets storage key for nonce | 获取nonce的存储键 func (nm *NonceManager) getNonceKey(nonce string) string { - return nm.keyPrefix + NonceKeySuffix + nonce + return nm.keyPrefix + nm.authType + NonceKeySuffix + nonce } diff --git a/core/security/refresh_token.go b/core/security/refresh_token.go index 500761c..cff4ebd 100644 --- a/core/security/refresh_token.go +++ b/core/security/refresh_token.go @@ -1,127 +1,110 @@ package security import ( + "context" "crypto/rand" "encoding/hex" - "encoding/json" "fmt" + codec_json "github.com/click33/sa-token-go/codec/json" + "github.com/click33/sa-token-go/core" + "github.com/click33/sa-token-go/core/adapter" + "github.com/click33/sa-token-go/generator/sgenerator" + "github.com/click33/sa-token-go/storage/memory" "time" - "github.com/click33/sa-token-go/core/adapter" - "github.com/click33/sa-token-go/core/config" - "github.com/click33/sa-token-go/core/token" "github.com/click33/sa-token-go/core/utils" ) -// Refresh Token Implementation -// 刷新令牌实现 +// Refresh Token Implementation | 刷新令牌实现 // // Flow | 流程: -// 1. GenerateTokenPair() - Create access token + refresh token | 创建访问令牌 + 刷新令牌 -// 2. Access token expires (short-lived, e.g. 2h) | 访问令牌过期(短期,如2小时) -// 3. RefreshAccessToken() - Use refresh token to get new access token | 使用刷新令牌获取新访问令牌 -// 4. Refresh token expires (long-lived, 30 days) | 刷新令牌过期(长期,30天) -// -// Usage | 用法: -// tokenInfo, _ := manager.LoginWithRefreshToken(loginID, "web") -// // ... access token expires ... -// newInfo, _ := manager.RefreshAccessToken(tokenInfo.RefreshToken) - -// Constants for refresh token | 刷新令牌常量 -const ( - DefaultRefreshTTL = 30 * 24 * time.Hour // 30 days | 30天 - DefaultAccessTTL = 2 * time.Hour // 2 hours | 2小时 - RefreshTokenLength = 32 // Refresh token byte length | 刷新令牌字节长度 - RefreshKeySuffix = "refresh:" // Key suffix after prefix | 前缀后的键后缀 -) - -// Error variables | 错误变量 -var ( - ErrInvalidRefreshToken = fmt.Errorf("invalid refresh token") - ErrRefreshTokenExpired = fmt.Errorf("refresh token expired") - ErrInvalidRefreshData = fmt.Errorf("invalid refresh token data") -) +// 1. GenerateTokenPair() -> AccessToken + RefreshToken | 创建访问令牌 + 刷新令牌 +// 2. AccessToken expires | 访问令牌过期 +// 3. RefreshAccessToken() -> New AccessToken | 使用刷新令牌获取新访问令牌 +// 4. RefreshToken expires -> Relogin | 刷新令牌过期需重新登录 + +// AccessTokenInfo Access token storage value | 访问令牌存储数据 +type AccessTokenInfo struct { + LoginID string `json:"loginID"` // User login ID | 用户登录ID + Device string `json:"device"` // Device type | 设备类型 +} -// RefreshTokenInfo refresh token information | 刷新令牌信息 +// RefreshTokenInfo Refresh token storage value | 刷新令牌存储数据 type RefreshTokenInfo struct { - RefreshToken string `json:"refreshToken"` // Refresh token (long-lived) | 刷新令牌(长期有效) - AccessToken string `json:"accessToken"` // Access token (short-lived) | 访问令牌(短期有效) + RefreshToken string `json:"refreshToken"` // Refresh token value | 刷新令牌值 + AccessToken string `json:"accessToken"` // Latest access token | 最新访问令牌 LoginID string `json:"loginID"` // User login ID | 用户登录ID Device string `json:"device"` // Device type | 设备类型 - CreateTime int64 `json:"createTime"` // Creation timestamp | 创建时间戳 - ExpireTime int64 `json:"expireTime"` // Expiration timestamp | 过期时间戳 -} - -// MarshalBinary implements encoding.BinaryMarshaler for Redis storage | 实现encoding.BinaryMarshaler接口用于Redis存储 -func (r *RefreshTokenInfo) MarshalBinary() ([]byte, error) { - return json.Marshal(r) + CreateTime int64 `json:"createTime"` // Create timestamp | 创建时间 + ExpireTime int64 `json:"expireTime"` // Expire timestamp | 过期时间 } -// UnmarshalBinary implements encoding.BinaryUnmarshaler for Redis storage | 实现encoding.BinaryUnmarshaler接口用于Redis存储 -func (r *RefreshTokenInfo) UnmarshalBinary(data []byte) error { - return json.Unmarshal(data, r) -} - -// RefreshTokenManager Refresh token manager | 刷新令牌管理器 +// RefreshTokenManager Refresh token manager-example | 刷新令牌管理器 type RefreshTokenManager struct { - storage adapter.Storage - keyPrefix string // Configurable prefix | 可配置的前缀 - tokenKeyPrefix string // Token key prefix | 令牌键前缀 - tokenGen *token.Generator - refreshTTL time.Duration // Refresh token TTL (30 days) | 刷新令牌有效期(30天) - accessTTL time.Duration // Access token TTL (configurable) | 访问令牌有效期(可配置) + authType string // Auth system type | 认证体系类型 + keyPrefix string // Storage key prefix | 存储前缀 + tokenKeyPrefix string // Token key prefix | Token 前缀 + refreshTTL time.Duration // Refresh token TTL | 刷新令牌有效期 + accessTTL time.Duration // Access token TTL | 访问令牌有效期 + + tokenGen adapter.Generator // Token generator | Token 生成器 + storage adapter.Storage // Storage adapter | 存储适配器 + serializer adapter.Codec // Codec adapter | 编解码器 } -// NewRefreshTokenManager Creates a new refresh token manager | 创建新的刷新令牌管理器 -// prefix: key prefix (e.g., "satoken:" or "" for Java compatibility) | 键前缀(如:"satoken:" 或 "" 兼容Java) -// cfg: configuration, uses Timeout for access token TTL | 配置,使用Timeout作为访问令牌有效期 -func NewRefreshTokenManager(storage adapter.Storage, prefix, keyPrefix string, cfg *config.Config) *RefreshTokenManager { - accessTTL := time.Duration(cfg.Timeout) * time.Second - +// NewRefreshTokenManager Create manager-example instance | 创建刷新令牌管理器 +func NewRefreshTokenManager( + authType, prefix, tokenKeyPrefix string, + tokenGen adapter.Generator, + accessTTL time.Duration, + storage adapter.Storage, + serializer adapter.Codec, +) *RefreshTokenManager { + + if tokenGen == nil { + tokenGen = sgenerator.NewDefaultGenerator() + } if accessTTL == 0 { accessTTL = DefaultAccessTTL } + if storage == nil { + storage = memory.NewStorage() + } + if serializer == nil { + serializer = codec_json.NewJSONSerializer() + } return &RefreshTokenManager{ - storage: storage, + authType: authType, keyPrefix: prefix, - tokenKeyPrefix: keyPrefix, - tokenGen: token.NewGenerator(cfg), + tokenKeyPrefix: tokenKeyPrefix, + tokenGen: tokenGen, refreshTTL: DefaultRefreshTTL, accessTTL: accessTTL, + storage: storage, + serializer: serializer, } } -// GenerateTokenPair Generates access token and refresh token pair | 生成访问令牌和刷新令牌对 -func (rtm *RefreshTokenManager) GenerateTokenPair(loginID, device string, accessTokenOverride ...string) (*RefreshTokenInfo, error) { +// GenerateTokenPair Create access + refresh token | 生成访问令牌和刷新令牌 +func (rtm *RefreshTokenManager) GenerateTokenPair(ctx context.Context, loginID, device string) (*RefreshTokenInfo, error) { if loginID == "" { - return nil, fmt.Errorf("loginID cannot be empty") + return nil, core.ErrInvalidLoginIDEmpty } // Generate access token | 生成访问令牌 - var accessToken string - if len(accessTokenOverride) > 0 && accessTokenOverride[0] != "" { - accessToken = accessTokenOverride[0] - } else { - var err error - accessToken, err = rtm.tokenGen.Generate(loginID, device) - if err != nil { - return nil, fmt.Errorf("failed to generate access token: %w", err) - } + accessToken, err := rtm.tokenGen.Generate(loginID, device) + if err != nil { + return nil, err } - // Save token-loginID mapping (符合 Java sa-token 设计) | 保存 Token-LoginID 映射 - tokenKey := rtm.getTokenKey(accessToken) - if err := rtm.storage.Set(tokenKey, loginID, rtm.accessTTL); err != nil { - return nil, fmt.Errorf("failed to save token: %w", err) + random := make([]byte, RefreshTokenLength) + if _, err := rand.Read(random); err != nil { + return nil, err } // Generate refresh token | 生成刷新令牌 - refreshTokenBytes := make([]byte, RefreshTokenLength) - if _, err := rand.Read(refreshTokenBytes); err != nil { - return nil, fmt.Errorf("failed to generate refresh token: %w", err) - } - refreshToken := hex.EncodeToString(refreshTokenBytes) + refreshToken := hex.EncodeToString(random) now := time.Now() info := &RefreshTokenInfo{ @@ -132,124 +115,213 @@ func (rtm *RefreshTokenManager) GenerateTokenPair(loginID, device string, access CreateTime: now.Unix(), ExpireTime: now.Add(rtm.refreshTTL).Unix(), } + // Encode refresh token info | 编码刷新令牌信息 + refreshData, err := rtm.serializer.Encode(info) + if err != nil { + return nil, fmt.Errorf("%w: %v", core.ErrSerializeFailed, err) + } - key := rtm.getRefreshKey(refreshToken) - if err := rtm.storage.Set(key, info, rtm.refreshTTL); err != nil { - return nil, fmt.Errorf("failed to store refresh token: %w", err) + // Encode access token info | 编码访问令牌信息 + accessData, err := rtm.serializer.Encode(&AccessTokenInfo{ + LoginID: loginID, + Device: device, + }) + if err != nil { + return nil, fmt.Errorf("%w: %v", core.ErrSerializeFailed, err) } - return info, nil -} + // Store access token | 存储访问令牌 + if err = rtm.storage.Set( + ctx, + rtm.getTokenKey(accessToken), + accessData, + rtm.accessTTL, + ); err != nil { + return nil, fmt.Errorf("%w: %v", core.ErrStorageUnavailable, err) + } -// RefreshAccessToken Generates new access token using refresh token | 使用刷新令牌生成新的访问令牌 -func (rtm *RefreshTokenManager) RefreshAccessToken(refreshToken string) (*RefreshTokenInfo, error) { - if refreshToken == "" { - return nil, ErrInvalidRefreshToken + // Store refresh token | 存储刷新令牌 + if err := rtm.storage.Set( + ctx, + rtm.getRefreshKey(refreshToken), + refreshData, + rtm.refreshTTL, + ); err != nil { + return nil, fmt.Errorf("%w: %v", core.ErrStorageUnavailable, err) } - // Get refresh token info | 获取刷新令牌信息 - key := rtm.getRefreshKey(refreshToken) + return info, nil +} + +// VerifyAccessToken Check token exists | 验证访问令牌是否存在 +func (rtm *RefreshTokenManager) VerifyAccessToken(ctx context.Context, accessToken string) bool { + return rtm.storage.Exists(ctx, rtm.getTokenKey(accessToken)) +} - // Get refresh token info | 获取刷新令牌信息 - data, err := rtm.storage.Get(key) +// VerifyAccessTokenAndGetInfo Verify and get info | 验证访问令牌并获取信息 +func (rtm *RefreshTokenManager) VerifyAccessTokenAndGetInfo(ctx context.Context, accessToken string) (*AccessTokenInfo, bool) { + data, err := rtm.storage.Get(ctx, rtm.getTokenKey(accessToken)) if err != nil || data == nil { - return nil, ErrInvalidRefreshToken + return nil, false } - // Convert to RefreshTokenInfo | 转换为 RefreshTokenInfo - dataBytes, err := utils.ToBytes(data) + bytes, err := utils.ToBytes(data) if err != nil { - return nil, ErrInvalidRefreshData + return nil, false } - // Unmarshal data | 反序列化数据 - oldInfo := &RefreshTokenInfo{} - err = oldInfo.UnmarshalBinary(dataBytes) - if err != nil { - return nil, ErrInvalidRefreshData + var info AccessTokenInfo + if err := rtm.serializer.Decode(bytes, &info); err != nil { + return nil, false } - // Check expiration | 检查是否过期 - if time.Now().Unix() > oldInfo.ExpireTime { - rtm.storage.Delete(key) - return nil, ErrRefreshTokenExpired + return &info, true +} + +// RefreshAccessToken Refresh access token by refresh token | 使用刷新令牌刷新访问令牌 +func (rtm *RefreshTokenManager) RefreshAccessToken(ctx context.Context, refreshToken string) (*RefreshTokenInfo, error) { + if refreshToken == "" { + return nil, core.ErrNonceInvalidRefreshToken } - // Generate new access token | 生成新的访问令牌 - newAccessToken, err := rtm.tokenGen.Generate(oldInfo.LoginID, oldInfo.Device) + refreshKey := rtm.getRefreshKey(refreshToken) + + // Load refresh token | 读取刷新令牌 + data, err := rtm.storage.Get(ctx, refreshKey) if err != nil { - return nil, fmt.Errorf("failed to generate new access token: %w", err) + return nil, fmt.Errorf("%w: %v", core.ErrStorageUnavailable, err) + } + if data == nil { + return nil, core.ErrInvalidRefreshToken } - // Update access token info | 更新访问令牌信息 - oldInfo.AccessToken = newAccessToken + bytes, err := utils.ToBytes(data) + if err != nil { + return nil, fmt.Errorf("%w: %v", core.ErrTypeConvert, err) + } - // Save token-loginID mapping (符合 Java sa-token 设计) | 保存 Token-LoginID 映射 - tokenKey := rtm.getTokenKey(newAccessToken) - if err := rtm.storage.Set(tokenKey, oldInfo.LoginID, rtm.accessTTL); err != nil { - return nil, fmt.Errorf("failed to save token: %w", err) + var info RefreshTokenInfo + if err := rtm.serializer.Decode(bytes, &info); err != nil { + return nil, fmt.Errorf("%w: %v", core.ErrDeserializeFailed, err) } - // Update storage | 更新存储 - if err := rtm.storage.Set(key, oldInfo, rtm.refreshTTL); err != nil { - return nil, fmt.Errorf("failed to update refresh token: %w", err) + // Check expiration | 检查过期 + if time.Now().Unix() > info.ExpireTime { + _ = rtm.storage.Delete(ctx, refreshKey) + return nil, core.ErrRefreshTokenExpired } - return oldInfo, nil -} + // Remove old access token | 删除旧访问令牌 + if info.AccessToken != "" { + _ = rtm.storage.Delete(ctx, rtm.getTokenKey(info.AccessToken)) + } -// RevokeRefreshToken Revokes a refresh token | 撤销刷新令牌 -func (rtm *RefreshTokenManager) RevokeRefreshToken(refreshToken string) error { - if refreshToken == "" { - return nil + // Generate new access token | 生成新访问令牌 + newAccessToken, err := rtm.tokenGen.Generate(info.LoginID, info.Device) + if err != nil { + return nil, err + } + info.AccessToken = newAccessToken + + // Store new access token | 存储新访问令牌 + accessData, err := rtm.serializer.Encode(&AccessTokenInfo{ + LoginID: info.LoginID, + Device: info.Device, + }) + if err != nil { + return nil, fmt.Errorf("%w: %v", core.ErrSerializeFailed, err) + } + if err := rtm.storage.Set( + ctx, + rtm.getTokenKey(newAccessToken), + accessData, + rtm.accessTTL, + ); err != nil { + return nil, fmt.Errorf("%w: %v", core.ErrStorageUnavailable, err) + } + + // Update refresh token without extending TTL | 更新刷新令牌但不续期 + refreshData, err := rtm.serializer.Encode(&info) + if err != nil { + return nil, fmt.Errorf("%w: %v", core.ErrSerializeFailed, err) } - key := rtm.getRefreshKey(refreshToken) - return rtm.storage.Delete(key) + if err = rtm.storage.SetKeepTTL(ctx, refreshKey, refreshData); err != nil { + return nil, fmt.Errorf("%w: %v", core.ErrStorageUnavailable, err) + } + + return &info, nil } -// GetRefreshTokenInfo Gets refresh token information | 获取刷新令牌信息 -func (rtm *RefreshTokenManager) GetRefreshTokenInfo(refreshToken string) (*RefreshTokenInfo, error) { +// GetRefreshTokenInfo Get refresh token info by token | 根据刷新令牌获取刷新令牌信息 +func (rtm *RefreshTokenManager) GetRefreshTokenInfo(ctx context.Context, refreshToken string) (*RefreshTokenInfo, error) { if refreshToken == "" { - return nil, ErrInvalidRefreshToken + return nil, core.ErrInvalidRefreshToken } - key := rtm.getRefreshKey(refreshToken) + refreshKey := rtm.getRefreshKey(refreshToken) - data, err := rtm.storage.Get(key) - if err != nil || data == nil { - return nil, ErrInvalidRefreshToken + // Load refresh token | 读取刷新令牌 + data, err := rtm.storage.Get(ctx, refreshKey) + if err != nil { + return nil, fmt.Errorf("%w: %v", core.ErrStorageUnavailable, err) + } + if data == nil { + return nil, core.ErrInvalidRefreshToken } - dataBytes, err := utils.ToBytes(data) + bytes, err := utils.ToBytes(data) if err != nil { - return nil, ErrInvalidRefreshData + return nil, fmt.Errorf("%w: %v", core.ErrTypeConvert, err) + } + + var info RefreshTokenInfo + if err = rtm.serializer.Decode(bytes, &info); err != nil { + return nil, fmt.Errorf("%w: %v", core.ErrDeserializeFailed, err) + } + + return &info, nil +} + +// RevokeRefreshToken Remove refresh token | 撤销刷新令牌 +func (rtm *RefreshTokenManager) RevokeRefreshToken(ctx context.Context, refreshToken string) error { + if refreshToken == "" { + return nil } - info := &RefreshTokenInfo{} - err = info.UnmarshalBinary(dataBytes) + err := rtm.storage.Delete(ctx, rtm.getRefreshKey(refreshToken)) if err != nil { - return nil, ErrInvalidRefreshData + return fmt.Errorf("%w: %v", core.ErrStorageUnavailable, err) } - return info, nil + return nil } -// IsValid Checks if refresh token is valid | 检查刷新令牌是否有效 -func (rtm *RefreshTokenManager) IsValid(refreshToken string) bool { - info, err := rtm.GetRefreshTokenInfo(refreshToken) +// IsValid Check refresh token valid | 判断刷新令牌是否有效 +func (rtm *RefreshTokenManager) IsValid(ctx context.Context, refreshToken string) bool { + data, err := rtm.storage.Get(ctx, rtm.getRefreshKey(refreshToken)) + if err != nil || data == nil { + return false + } + + bytes, err := utils.ToBytes(data) if err != nil { return false } + var info RefreshTokenInfo + if err = rtm.serializer.Decode(bytes, &info); err != nil { + return false + } + return time.Now().Unix() <= info.ExpireTime } -// getRefreshKey Gets storage key for refresh token | 获取刷新令牌的存储键 +// getRefreshKey Build refresh token key | 构建刷新令牌 Key func (rtm *RefreshTokenManager) getRefreshKey(refreshToken string) string { - return rtm.keyPrefix + RefreshKeySuffix + refreshToken + return rtm.keyPrefix + rtm.authType + RefreshKeySuffix + refreshToken } -// getTokenKey Gets token storage key | 获取Token存储键 +// getTokenKey Build access token key | 构建访问令牌 Key func (rtm *RefreshTokenManager) getTokenKey(tokenValue string) string { - return rtm.keyPrefix + rtm.tokenKeyPrefix + tokenValue + return rtm.keyPrefix + rtm.authType + rtm.tokenKeyPrefix + tokenValue } diff --git a/core/session/consts.go b/core/session/consts.go new file mode 100644 index 0000000..94a7bac --- /dev/null +++ b/core/session/consts.go @@ -0,0 +1,7 @@ +// @Author daixk 2025/12/7 17:22:00 +package session + +// Constants for session keys | Session键常量 +const ( + SessionKeyPrefix = "session:" // Storage key prefix | 存储键前缀 +) diff --git a/core/session/session.go b/core/session/session.go index 1973588..17e93d8 100644 --- a/core/session/session.go +++ b/core/session/session.go @@ -1,87 +1,97 @@ package session import ( - "encoding/json" + "context" "fmt" + codec_json "github.com/click33/sa-token-go/codec/json" + "github.com/click33/sa-token-go/core" + "github.com/click33/sa-token-go/storage/memory" "sync" "time" "github.com/click33/sa-token-go/core/adapter" ) -// Constants for session keys | Session键常量 -const ( - SessionKeyPrefix = "session:" // Storage key prefix | 存储键前缀 -) - -// Error variables | 错误变量 -var ( - ErrSessionNotFound = fmt.Errorf("session not found") - ErrInvalidSessionData = fmt.Errorf("invalid session data") -) - // Session Session object for storing user data | 会话对象,用于存储用户数据 type Session struct { - ID string `json:"id"` // Session ID | Session标识 - CreateTime int64 `json:"createTime"` // Creation time | 创建时间 - Data map[string]any `json:"data"` // Session data | 数据 - mu sync.RWMutex `json:"-"` // Read-write lock | 读写锁 - storage adapter.Storage `json:"-"` // Storage backend | 存储 - prefix string `json:"-"` // Key prefix | 键前缀 + AuthType string `json:"authType"` // Authentication system type | 认证体系类型 + ID string `json:"id"` // Session ID | Session标识 + CreateTime int64 `json:"createTime"` // Creation time | 创建时间 + Data map[string]any `json:"data"` // Session data | 数据 + + prefix string `json:"-" msgpack:"-"` // Key prefix | 键前缀 + mu sync.RWMutex `json:"-" msgpack:"-"` // Read-write lock | 读写锁 + storage adapter.Storage `json:"-" msgpack:"-"` // Storage adapter (Redis, Memory, etc.) | 存储适配器(如 Redis、Memory) + serializer adapter.Codec `json:"-" msgpack:"-"` // Codec adapter for encoding and decoding operations | 编解码器适配器 } // NewSession Creates a new session | 创建新的Session -func NewSession(id string, storage adapter.Storage, prefix string) *Session { +func NewSession(authType, prefix, id string, storage adapter.Storage, serializer adapter.Codec) *Session { + if storage == nil { + storage = memory.NewStorage() + } + if serializer == nil { + serializer = codec_json.NewJSONSerializer() + } + return &Session{ + AuthType: authType, ID: id, CreateTime: time.Now().Unix(), Data: make(map[string]any), - storage: storage, prefix: prefix, + storage: storage, + serializer: serializer, } } +// SetDependencies sets internal dependencies for a decoded session | 设置反序列化后的 Session 的内部依赖 +func (s *Session) SetDependencies(prefix string, storage adapter.Storage, serializer adapter.Codec) { + if storage == nil { + storage = memory.NewStorage() + } + if serializer == nil { + serializer = codec_json.NewJSONSerializer() + } + + s.prefix = prefix + s.storage = storage + s.serializer = serializer +} + // ============ Data Operations | 数据操作 ============ // Set Sets value | 设置值 -func (s *Session) Set(key string, value any, ttl ...time.Duration) error { +func (s *Session) Set(ctx context.Context, key string, value any, ttl ...time.Duration) error { if key == "" { - return fmt.Errorf("key cannot be empty") + return core.ErrSessionInvalidDataKey } s.mu.Lock() defer s.mu.Unlock() s.Data[key] = value - if len(ttl) > 0 && ttl[0] > 0 { - return s.saveWithTTL(ttl[0]) - } - return s.save() + return s.save(ctx, ttl...) } // SetMulti sets multiple key-value pairs | 设置多个键值对 -func (s *Session) SetMulti(values map[string]any, ttl ...time.Duration) error { - if len(values) == 0 { +func (s *Session) SetMulti(ctx context.Context, valueMap map[string]any, ttl ...time.Duration) error { + if len(valueMap) == 0 { return nil } s.mu.Lock() defer s.mu.Unlock() - for key, value := range values { + for key, value := range valueMap { if key == "" { - return fmt.Errorf("key cannot be empty") + return core.ErrSessionInvalidDataKey } s.Data[key] = value } - if len(ttl) > 0 && ttl[0] > 0 { - fmt.Println("ttl:", ttl[0]) - return s.saveWithTTL(ttl[0]) - } - - return s.save() + return s.save(ctx, ttl...) } // Get Gets value | 获取值 @@ -152,22 +162,22 @@ func (s *Session) Has(key string) bool { return exists } -// Delete 删除键 -func (s *Session) Delete(key string) error { +// Delete removes a key and preserves TTL | 删除键并保留 TTL +func (s *Session) Delete(ctx context.Context, key string) error { s.mu.Lock() defer s.mu.Unlock() delete(s.Data, key) - return s.save() + return s.saveKeepTTL(ctx) } -// Clear Clears all data | 清空所有数据 -func (s *Session) Clear() error { +// Clear removes all keys but preserves TTL | 清空所有键并保留 TTL +func (s *Session) Clear(ctx context.Context) error { s.mu.Lock() defer s.mu.Unlock() s.Data = make(map[string]any) - return s.save() + return s.saveKeepTTL(ctx) } // Keys Gets all keys | 获取所有键 @@ -196,93 +206,84 @@ func (s *Session) IsEmpty() bool { } // Renew extends the session TTL without modifying content | 续期 Session 的 TTL,但不修改内容 -func (s *Session) Renew(ttl time.Duration) error { - if ttl <= 0 { - return nil // 不允许设置 0 TTL,避免误删 +func (s *Session) Renew(ctx context.Context, ttl time.Duration) error { + if ttl < 0 { + return nil // Skip renewal if ttl is invalid | 跳过无效续期 } key := s.getStorageKey() - return s.storage.Expire(key, ttl) + return s.storage.Expire(ctx, key, ttl) } -// ============ Internal Methods | 内部方法 ============ - -// save Saves session to storage | 保存到存储 -func (s *Session) save() error { - data, err := json.Marshal(s) - if err != nil { - return fmt.Errorf("failed to marshal session: %w", err) - } +// Destroy Destroys session | 销毁Session +func (s *Session) Destroy(ctx context.Context) error { + s.mu.Lock() + defer s.mu.Unlock() key := s.getStorageKey() - return s.storage.Set(key, string(data), 0) + return s.storage.Delete(ctx, key) } -// saveWithTTL saves session with TTL | 带 TTL 保存 Session -func (s *Session) saveWithTTL(ttl time.Duration) error { - data, err := json.Marshal(s) - if err != nil { - return fmt.Errorf("failed to marshal session: %w", err) - } - - key := s.getStorageKey() - fmt.Println(ttl) - return s.storage.Set(key, string(data), ttl) -} +// ============ Internal Methods | 内部方法 ============ // getStorageKey Gets storage key for this session | 获取Session的存储键 func (s *Session) getStorageKey() string { - return s.prefix + SessionKeyPrefix + s.ID + return s.prefix + s.AuthType + SessionKeyPrefix + s.ID } -// ============ Static Methods | 静态方法 ============ +// save Saves session to storage | 保存到存储 +func (s *Session) save(ctx context.Context, ttl ...time.Duration) error { + data, err := s.serializer.Encode(s) + if err != nil { + return fmt.Errorf("%w: %v", core.ErrSerializeFailed, err) + } + + key := s.getStorageKey() -// Load Loads session from storage | 从存储加载 -func Load(id string, storage adapter.Storage, prefix string) (*Session, error) { - if id == "" { - return nil, fmt.Errorf("session id cannot be empty") + // Default to 0 (no expiration) | 默认使用 0(无过期时间) + if len(ttl) == 0 || ttl[0] <= 0 { + err = s.storage.Set(ctx, key, data, 0) + if err != nil { + return fmt.Errorf("%w: %v", core.ErrStorageUnavailable, err) + } + return nil } - key := prefix + SessionKeyPrefix + id - data, err := storage.Get(key) + // Save with provided TTL | 使用指定 TTL 保存 + err = s.storage.Set(ctx, key, data, ttl[0]) if err != nil { - return nil, err - } - if data == nil { - return nil, ErrSessionNotFound + return fmt.Errorf("%w: %v", core.ErrStorageUnavailable, err) } - var ( - raw []byte - session Session - ) + return nil +} - // Support both string and []byte | 同时兼容 string 和 []byte - switch v := data.(type) { - case string: - raw = []byte(v) +// saveKeepTTL saves session while preserving its TTL | 保存 Session 并保留现有 TTL +func (s *Session) saveKeepTTL(ctx context.Context) error { + data, err := s.serializer.Encode(s) + if err != nil { + return fmt.Errorf("%w: %v", core.ErrSerializeFailed, err) + } - case []byte: - raw = v + key := s.getStorageKey() - default: - return nil, ErrInvalidSessionData - } + // Try to get current TTL | 获取当前 TTL + // -1: never expires | 永不过期 + // -2: key not found | key不存在 + // >0: remaining TTL | 剩余时间 + ttl, _ := s.storage.TTL(ctx, key) - if err := json.Unmarshal(raw, &session); err != nil { - return nil, fmt.Errorf("%w: %v", ErrInvalidSessionData, err) + // ttl <= 0 means: not found(-2), never expires(-1), or expired + // All these cases should save with no expiration | 这些情况都保存为永久 + if ttl <= 0 { + ttl = 0 } + // ttl > 0: use original TTL | 使用原有TTL - session.storage = storage - session.prefix = prefix - return &session, nil -} - -// Destroy Destroys session | 销毁Session -func (s *Session) Destroy() error { - s.mu.Lock() - defer s.mu.Unlock() + err = s.storage.Set(ctx, key, data, ttl) + if err != nil { + return fmt.Errorf("%w: %v", core.ErrStorageUnavailable, err) + } - key := s.getStorageKey() - return s.storage.Delete(key) + return nil } diff --git a/core/token/token_test.go b/core/token/token_test.go deleted file mode 100644 index 693da94..0000000 --- a/core/token/token_test.go +++ /dev/null @@ -1,119 +0,0 @@ -package token - -import ( - "testing" - - "github.com/click33/sa-token-go/core/config" -) - -func TestGenerateHash(t *testing.T) { - cfg := &config.Config{ - TokenStyle: config.TokenStyleHash, - Timeout: 3600, - } - gen := NewGenerator(cfg) - - token1, err := gen.Generate("user1000", "default") - if err != nil { - t.Fatalf("Failed to generate hash token: %v", err) - } - - if len(token1) != 64 { - t.Errorf("Hash token should be 64 characters, got %d", len(token1)) - } - - // Generate another token, should be different - token2, err := gen.Generate("user1000", "default") - if err != nil { - t.Fatalf("Failed to generate second hash token: %v", err) - } - - if token1 == token2 { - t.Error("Hash tokens should be different due to randomness") - } - - t.Logf("Hash Token 1: %s", token1) - t.Logf("Hash Token 2: %s", token2) -} - -func TestGenerateTimestamp(t *testing.T) { - cfg := &config.Config{ - TokenStyle: config.TokenStyleTimestamp, - Timeout: 3600, - } - gen := NewGenerator(cfg) - - token, err := gen.Generate("user1000", "default") - if err != nil { - t.Fatalf("Failed to generate timestamp token: %v", err) - } - - // Timestamp token format: timestamp_loginID_random - if len(token) < 20 { - t.Errorf("Timestamp token seems too short: %s", token) - } - - t.Logf("Timestamp Token: %s", token) -} - -func TestGenerateTik(t *testing.T) { - cfg := &config.Config{ - TokenStyle: config.TokenStyleTik, - Timeout: 3600, - } - gen := NewGenerator(cfg) - - token, err := gen.Generate("user1000", "default") - if err != nil { - t.Fatalf("Failed to generate tik token: %v", err) - } - - if len(token) != 11 { - t.Errorf("Tik token should be 11 characters, got %d", len(token)) - } - - // Check all characters are alphanumeric - for _, c := range token { - if !((c >= '0' && c <= '9') || (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z')) { - t.Errorf("Tik token should only contain alphanumeric characters, got: %c in %s", c, token) - } - } - - t.Logf("Tik Token: %s", token) -} - -func TestAllTokenStyles(t *testing.T) { - styles := []config.TokenStyle{ - config.TokenStyleUUID, - config.TokenStyleSimple, - config.TokenStyleRandom32, - config.TokenStyleRandom64, - config.TokenStyleRandom128, - config.TokenStyleJWT, - config.TokenStyleHash, - config.TokenStyleTimestamp, - config.TokenStyleTik, - } - - for _, style := range styles { - t.Run(string(style), func(t *testing.T) { - cfg := &config.Config{ - TokenStyle: style, - Timeout: 3600, - JwtSecretKey: "test-secret-key", - } - gen := NewGenerator(cfg) - - token, err := gen.Generate("user1000", "default") - if err != nil { - t.Fatalf("Failed to generate %s token: %v", style, err) - } - - if token == "" { - t.Errorf("%s token should not be empty", style) - } - - t.Logf("%s Token: %s (length: %d)", style, token, len(token)) - }) - } -} diff --git a/core/utils/utils.go b/core/utils/utils.go index c7dbf4d..5b09d62 100644 --- a/core/utils/utils.go +++ b/core/utils/utils.go @@ -1,6 +1,7 @@ package utils import ( + "context" "crypto/rand" "crypto/sha256" "encoding/base64" @@ -108,7 +109,7 @@ func IsNotEmpty(s string) bool { return !IsEmpty(s) } -// DefaultString returns default value if string is empty | 如果字符串为空则返回默认值 +// DefaultString returns log value if string is empty | 如果字符串为空则返回默认值 func DefaultString(s, defaultValue string) string { if IsEmpty(s) { return defaultValue @@ -239,7 +240,7 @@ func ParsePermissionTag(tag string) []string { } // ParseRoleTag 解析角色标签 -// 格式: "role:admin,manager" +// 格式: "role:admin,manager-example" func ParseRoleTag(tag string) []string { if tag == "" { return []string{} @@ -588,3 +589,21 @@ func UniqueSlice[T comparable](slice []T) []T { } return result } + +// GetCtxValue Returns the value for the given key in the context as a string or the defaultValue if the value is not of type string | 从上下文中获取指定key的值,如果值不是字符串则返回默认值 +func GetCtxValue(ctx context.Context, valueKey string, defaultValue ...string) string { + val := ctx.Value(valueKey) // 从上下文中获取值 | Retrieve the value from the context + + // Check if the value is a non-empty string | 检查值是否是非空字符串 + if v, ok := val.(string); ok && v != "" { + return v // If it's a non-empty string, return it | 如果是非空字符串,返回该值 + } + + // If defaultValue is provided, return it, otherwise return an empty string | 如果提供了默认值,返回默认值,否则返回空字符串 + if len(defaultValue) > 0 { + return defaultValue[0] + } + + // Otherwise, return an empty string | 否则返回空字符串 + return "" +} diff --git a/core/version.go b/core/version.go new file mode 100644 index 0000000..d18b4ad --- /dev/null +++ b/core/version.go @@ -0,0 +1,5 @@ +// @Author daixk 2025/12/22 13:33:00 +package core + +// Version Sa-Token-Go version | Sa-Token-Go版本 +const Version = "0.1.7" diff --git a/demo/banner_demo.go b/demo/banner_demo.go deleted file mode 100644 index 51c6511..0000000 --- a/demo/banner_demo.go +++ /dev/null @@ -1,45 +0,0 @@ -package main - -import ( - "github.com/click33/sa-token-go/core/banner" - "github.com/click33/sa-token-go/core/config" -) - -func main() { - // 1. 打印基础 Banner - banner.Print() - - // 2. 打印带完整配置的 Banner - cfg := config.DefaultConfig() - banner.PrintWithConfig(cfg) - - // 3. 打印 JWT 配置的 Banner - jwtCfg := &config.Config{ - TokenName: "jwt-token", - Timeout: 86400, // 24小时 - ActiveTimeout: -1, - IsConcurrent: true, - IsShare: false, - MaxLoginCount: 5, - IsReadBody: false, - IsReadHeader: true, - IsReadCookie: true, - TokenStyle: config.TokenStyleJWT, - DataRefreshPeriod: -1, - TokenSessionCheckLogin: true, - AutoRenew: true, - JwtSecretKey: "my-super-secret-key-123456", - IsLog: true, - IsPrintBanner: true, - CookieConfig: &config.CookieConfig{ - Domain: "example.com", - Path: "/api", - Secure: true, - HttpOnly: true, - SameSite: config.SameSiteStrict, - MaxAge: 7200, - }, - } - - banner.PrintWithConfig(jwtCfg) -} diff --git a/demo/java_compat_demo.go b/demo/java_compat_demo.go deleted file mode 100644 index 0bb0553..0000000 --- a/demo/java_compat_demo.go +++ /dev/null @@ -1,87 +0,0 @@ -package main - -import ( - "fmt" - - "github.com/click33/sa-token-go/core" - "github.com/click33/sa-token-go/storage/memory" -) - -func main() { - fmt.Println("🔄 Java sa-token 兼容性演示") - fmt.Println("=" + "────────────────────────────────────────────────────────────" + "=") - fmt.Println() - - storage := memory.NewStorage() - - // 方式1: Go 默认配置(带前缀 "satoken:") - fmt.Println("【方式1】Go 默认配置 - 使用前缀 'satoken:'") - mgr1 := core.NewBuilder(). - Storage(storage). - TokenName("satoken"). // 使用默认的 token 名称 - KeyPrefix("satoken:"). // 显式设置前缀(默认值) - IsPrintBanner(false). - Build() - - token1, _ := mgr1.Login("user001", "pc") - fmt.Printf("✅ 登录成功,Token: %s\n", token1) - fmt.Println(" Redis Keys 示例:") - fmt.Println(" - satoken:token:" + token1) - fmt.Println(" - satoken:account:user001:pc") - fmt.Println(" - satoken:session:user001") - fmt.Println() - - // 方式2: Java sa-token 兼容配置(无前缀) - fmt.Println("【方式2】Java 兼容配置 - 无前缀(与Java默认行为一致)") - storage2 := memory.NewStorage() - mgr2 := core.NewBuilder(). - Storage(storage2). - TokenName("satoken"). // 必须与 Java 端配置一致 - KeyPrefix(""). // 空前缀,兼容 Java sa-token - IsPrintBanner(false). - Build() - - token2, _ := mgr2.Login("user002", "web") - fmt.Printf("✅ 登录成功,Token: %s\n", token2) - fmt.Println(" Redis Keys 示例(兼容Java):") - fmt.Println(" - token:" + token2) - fmt.Println(" - account:user002:web") - fmt.Println(" - session:user002") - fmt.Println() - - // 方式3: 自定义前缀(多应用隔离) - fmt.Println("【方式3】自定义前缀 - 用于多应用隔离") - storage3 := memory.NewStorage() - mgr3 := core.NewBuilder(). - Storage(storage3). - TokenName("satoken"). - KeyPrefix("myapp:sa:"). // 自定义前缀 - IsPrintBanner(false). - Build() - - token3, _ := mgr3.Login("user003", "app") - fmt.Printf("✅ 登录成功,Token: %s\n", token3) - fmt.Println(" Redis Keys 示例:") - fmt.Println(" - myapp:sa:token:" + token3) - fmt.Println(" - myapp:sa:account:user003:app") - fmt.Println(" - myapp:sa:session:user003") - fmt.Println() - - // 关键配置说明 - fmt.Println("=" + "────────────────────────────────────────────────────────────" + "=") - fmt.Println("📝 关键配置说明:") - fmt.Println() - fmt.Println("1. 与 Java sa-token 互通:") - fmt.Println(" cfg.SetKeyPrefix(\"\") // 设置为空字符串") - fmt.Println(" 或") - fmt.Println(" builder.KeyPrefix(\"\") // Builder 方式") - fmt.Println() - fmt.Println("2. 多应用隔离:") - fmt.Println(" cfg.SetKeyPrefix(\"app1:\") // 应用1") - fmt.Println(" cfg.SetKeyPrefix(\"app2:\") // 应用2") - fmt.Println() - fmt.Println("3. 默认 Go 行为:") - fmt.Println(" cfg.SetKeyPrefix(\"satoken:\") // 默认值") - fmt.Println() - fmt.Println("=" + "────────────────────────────────────────────────────────────" + "=") -} diff --git a/demo/usage_example.go b/demo/usage_example.go deleted file mode 100644 index 0519ecb..0000000 --- a/demo/usage_example.go +++ /dev/null @@ -1 +0,0 @@ - \ No newline at end of file diff --git a/docs/api/stputil.md b/docs/api/stputil.md index 52ba3ba..f5f1dbb 100644 --- a/docs/api/stputil.md +++ b/docs/api/stputil.md @@ -226,7 +226,7 @@ func SetRoles(loginID interface{}, roles []string) error **Example**: ```go -stputil.SetRoles(1000, []string{"admin", "manager"}) +stputil.SetRoles(1000, []string{"admin", "manager-example"}) ``` ### HasRole @@ -252,7 +252,7 @@ Multiple role check **Example**: ```go // AND logic -stputil.HasRolesAnd(1000, []string{"admin", "manager"}) +stputil.HasRolesAnd(1000, []string{"admin", "manager-example"}) // OR logic stputil.HasRolesOr(1000, []string{"admin", "super"}) diff --git a/docs/api/stputil_zh.md b/docs/api/stputil_zh.md index ed037ca..52c2368 100644 --- a/docs/api/stputil_zh.md +++ b/docs/api/stputil_zh.md @@ -226,7 +226,7 @@ func SetRoles(loginID interface{}, roles []string) error **示例**: ```go -stputil.SetRoles(1000, []string{"admin", "manager"}) +stputil.SetRoles(1000, []string{"admin", "manager-example"}) ``` ### HasRole @@ -252,7 +252,7 @@ if stputil.HasRole(1000, "admin") { **示例**: ```go // AND逻辑 -stputil.HasRolesAnd(1000, []string{"admin", "manager"}) +stputil.HasRolesAnd(1000, []string{"admin", "manager-example"}) // OR逻辑 stputil.HasRolesOr(1000, []string{"admin", "super"}) diff --git a/docs/guide/annotation.md b/docs/guide/annotation.md index a921fa7..770bcf0 100644 --- a/docs/guide/annotation.md +++ b/docs/guide/annotation.md @@ -38,7 +38,7 @@ r.GET("/admin", sagin.CheckRole("admin"), func(c *gin.Context) { }) // Requires any of the roles -r.GET("/dashboard", sagin.CheckRole("admin", "manager"), func(c *gin.Context) { +r.GET("/dashboard", sagin.CheckRole("admin", "manager-example"), func(c *gin.Context) { c.JSON(200, gin.H{"message": "Dashboard"}) }) ``` @@ -107,7 +107,7 @@ func main() { // Role required r.GET("/admin", sagin.CheckRole("admin"), adminHandler) - r.GET("/manager", sagin.CheckRole("admin", "manager"), managerHandler) + r.GET("/manager-example", sagin.CheckRole("admin", "manager-example"), managerHandler) // Permission required r.GET("/users", sagin.CheckPermission("user:read"), listUsersHandler) diff --git a/docs/guide/annotation_zh.md b/docs/guide/annotation_zh.md index 2f3ec47..768847e 100644 --- a/docs/guide/annotation_zh.md +++ b/docs/guide/annotation_zh.md @@ -56,10 +56,10 @@ r.DELETE("/admin/users/:id", sagin.CheckPermission("admin:*"), deleteUserHandler ```go // 需要admin角色 -r.GET("/manager", sagin.CheckRole("admin"), managerHandler) +r.GET("/manager-example", sagin.CheckRole("admin"), managerHandler) // 需要manager角色 -r.GET("/reports", sagin.CheckRole("manager"), reportsHandler) +r.GET("/reports", sagin.CheckRole("manager-example"), reportsHandler) ``` ### 5. 检查封禁 - @SaCheckDisable @@ -84,7 +84,7 @@ r.GET("/data", // 拥有admin或manager角色即可访问 r.GET("/dashboard", - sagin.CheckRole("admin", "manager"), + sagin.CheckRole("admin", "manager-example"), dashboardHandler) ``` @@ -151,7 +151,7 @@ func main() { r.DELETE("/users/:id", sagin.CheckPermission("user:delete"), deleteUserHandler) // 需要角色 - r.GET("/manager", sagin.CheckRole("manager"), managerHandler) + r.GET("/manager-example", sagin.CheckRole("manager-example"), managerHandler) // 检查封禁状态 r.GET("/sensitive", sagin.CheckDisable(), sensitiveHandler) diff --git a/docs/guide/authentication.md b/docs/guide/authentication.md index 726845d..ff1a226 100644 --- a/docs/guide/authentication.md +++ b/docs/guide/authentication.md @@ -86,7 +86,7 @@ fmt.Println("Create Time:", info.CreateTime) ### Concurrent Login ```go -// Allow concurrent login (default: true) +// Allow concurrent login (log: true) core.NewBuilder(). IsConcurrent(true). Build() @@ -95,7 +95,7 @@ core.NewBuilder(). ### Share Token ```go -// Share token for concurrent logins (default: true) +// Share token for concurrent logins (log: true) core.NewBuilder(). IsShare(true). Build() diff --git a/docs/guide/jwt.md b/docs/guide/jwt.md index d8c1b0f..edf58c4 100644 --- a/docs/guide/jwt.md +++ b/docs/guide/jwt.md @@ -115,4 +115,4 @@ JwtSecretKey(os.Getenv("JWT_SECRET_KEY")) - [Quick Start](../tutorial/quick-start.md) - [Authentication Guide](authentication.md) -- [JWT Example](../../examples/jwt-example/README.md) +- [JWT Example](../../examples/manager-example/jwt-example/README.md) diff --git a/docs/guide/jwt_zh.md b/docs/guide/jwt_zh.md index c6079b9..8d05b1e 100644 --- a/docs/guide/jwt_zh.md +++ b/docs/guide/jwt_zh.md @@ -404,7 +404,7 @@ curl http://localhost:8080/user/info \ - [快速开始](../tutorial/quick-start.md) - [认证指南](authentication.md) - [配置说明](configuration.md) -- [JWT 示例代码](../../examples/jwt-example/) +- [JWT 示例代码](../../examples/manager-example/jwt-example/) ## 在线工具 diff --git a/docs/guide/listener.md b/docs/guide/listener.md index dc2b65d..6eccad8 100644 --- a/docs/guide/listener.md +++ b/docs/guide/listener.md @@ -151,4 +151,4 @@ manager.RegisterFunc(core.EventLogin, func(data *core.EventData) { - [Quick Start](../tutorial/quick-start.md) - [Authentication Guide](authentication.md) -- [Event Listener Example](../../examples/listener-example/README.md) +- [Event Listener Example](../../examples/manager-example/listener-example/README.md) diff --git a/docs/guide/listener_zh.md b/docs/guide/listener_zh.md index 75c716b..aba5960 100644 --- a/docs/guide/listener_zh.md +++ b/docs/guide/listener_zh.md @@ -77,7 +77,7 @@ import ( ) func main() { - // Create manager with default event support + // Create manager-example with log event support manager := core.NewBuilder(). Storage(memory.NewStorage()). Build() diff --git a/docs/guide/nonce.md b/docs/guide/nonce.md index acb53ea..9eccee0 100644 --- a/docs/guide/nonce.md +++ b/docs/guide/nonce.md @@ -222,7 +222,7 @@ POST /search // Search // Quick operations (1 minute) core.NewNonceManager(storage, 60) -// Form submissions (5 minutes, default) +// Form submissions (5 minutes, log) core.NewNonceManager(storage, 300) // Long processes (10 minutes) diff --git a/docs/guide/redis-storage_zh.md b/docs/guide/redis-storage_zh.md index ba136a0..6f4bd1c 100644 --- a/docs/guide/redis-storage_zh.md +++ b/docs/guide/redis-storage_zh.md @@ -408,7 +408,7 @@ Key: satoken:token:6R9twUC-OL_uL6JQFKfncyoVuK3NlDL2... Value: 1000 # 只是简单的字符串(4 bytes) # Account 键(loginID -> Token) -Key: satoken:account:1000:default +Key: satoken:account:1000:log Value: 6R9twUC-OL_uL6JQFKfncyoVuK3NlDL2... # Session 键(存储完整用户对象和自定义数据) @@ -438,7 +438,7 @@ GET satoken:token:6R9twUC-OL_uL6JQFKfncyoVuK3NlDL2... # 输出: "1000" # 查看 Account 映射(返回 Token) -GET satoken:account:1000:default +GET satoken:account:1000:log # 输出: "6R9twUC-OL_uL6JQFKfncyoVuK3NlDL2..." # 查看用户 Session(包含完整用户数据) diff --git a/docs/guide/single-import.md b/docs/guide/single-import.md index e48c9a8..09ecf15 100644 --- a/docs/guide/single-import.md +++ b/docs/guide/single-import.md @@ -114,10 +114,10 @@ func main() { config.Timeout = 7200 // 2 hours config.IsPrint = true - // 3. Create manager (from sagin package) + // 3. Create manager-example (from sagin package) manager := sagin.NewManager(storage, config) - // 4. Set global manager (from sagin package) + // 4. Set global manager-example (from sagin package) sagin.SetManager(manager) // 5. Create Gin router @@ -230,11 +230,11 @@ All functions from `core` and `stputil` are re-exported in framework integration ### Configuration & Initialization ```go -config := sagin.DefaultConfig() // Create default config -manager := sagin.NewManager(storage, cfg) // Create manager +config := sagin.DefaultConfig() // Create log config +manager := sagin.NewManager(storage, cfg) // Create manager-example builder := sagin.NewBuilder() // Create builder -sagin.SetManager(manager) // Set global manager -manager := sagin.GetManager() // Get global manager +sagin.SetManager(manager) // Set global manager-example +manager := sagin.GetManager() // Get global manager-example ``` ### Authentication diff --git a/examples/annotation/annotation-example/README.md b/examples/annotation/annotation-example/README.md index a6ffdb5..73d5009 100644 --- a/examples/annotation/annotation-example/README.md +++ b/examples/annotation/annotation-example/README.md @@ -1,6 +1,6 @@ # 注解装饰器示例 -本示例演示如何在 Gin 框架中使用 Sa-Token-Go 的注解装饰器(类似 Java 的 `@SaCheckLogin`、`@SaCheckRole` 等)。 +本示例演示如何在 Gin 框架中使用 Sa-Token-Go 的中间件装饰器(类似 Java 的 `@SaCheckLogin`、`@SaCheckRole` 等)。 ## 运行示例 @@ -10,44 +10,44 @@ go run main.go 服务器将在 `http://localhost:8080` 启动。 -## 注解装饰器 +## 中间件装饰器 -Sa-Token-Go 提供了类似 Java 注解的装饰器函数: +Sa-Token-Go 提供了类似 Java 注解的中间件函数: -### CheckLogin - 检查登录 +### CheckLoginMiddleware - 检查登录 ```go -r.GET("/user/info", sagin.CheckLogin(), handler.GetUserInfo) +r.GET("/user/info", sagin.CheckLoginMiddleware(), handler.GetUserInfo) ``` -### CheckRole - 检查角色 +### CheckRoleMiddleware - 检查角色 ```go -r.GET("/manager", sagin.CheckRole("admin"), handler.GetManagerData) +r.GET("/manager", sagin.CheckRoleMiddleware("admin"), handler.GetManagerData) ``` -### CheckPermission - 检查权限 +### CheckPermissionMiddleware - 检查权限 ```go // 单个权限 -r.GET("/admin", sagin.CheckPermission("admin:*"), handler.GetAdminData) +r.GET("/admin", sagin.CheckPermissionMiddleware("admin:*"), handler.GetAdminData) // 多个权限(OR 逻辑) -r.GET("/user-or-admin", - sagin.CheckPermission("user:read", "admin:*"), +r.GET("/user-or-admin", + sagin.CheckPermissionMiddleware("user:read", "admin:*"), handler.GetUserOrAdmin) ``` -### CheckDisable - 检查是否被封禁 +### CheckDisableMiddleware - 检查是否被封禁 ```go -r.GET("/sensitive", sagin.CheckDisable(), handler.GetSensitiveData) +r.GET("/sensitive", sagin.CheckDisableMiddleware(), handler.GetSensitiveData) ``` -### Ignore - 忽略认证 +### IgnoreMiddleware - 忽略认证 ```go -r.GET("/public", sagin.Ignore(), handler.GetPublic) +r.GET("/public", sagin.IgnoreMiddleware(), handler.GetPublic) ``` ## 完整示例 @@ -58,77 +58,148 @@ package main import ( "net/http" "time" - + "github.com/gin-gonic/gin" - "github.com/click33/sa-token-go/core" sagin "github.com/click33/sa-token-go/integrations/gin" - "github.com/click33/sa-token-go/stputil" - "github.com/click33/sa-token-go/storage/memory" ) func init() { - stputil.SetManager( - core.NewBuilder(). - Storage(memory.NewStorage()). - Build(), + // 初始化 Manager + sagin.SetManager( + sagin.NewDefaultBuild().Build(), ) } -func main() { - r := gin.Default() - - // 登录接口(公开) - r.POST("/login", loginHandler) - - // 使用注解装饰器 - handler := &UserHandler{} - - // 公开访问 - r.GET("/public", sagin.Ignore(), handler.GetPublic) - - // 需要登录 - r.GET("/user/info", sagin.CheckLogin(), handler.GetUserInfo) - - // 需要管理员权限 - r.GET("/admin", sagin.CheckPermission("admin:*"), handler.GetAdminData) - - // 需要管理员角色 - r.GET("/manager", sagin.CheckRole("admin"), handler.GetManagerData) - - // 检查账号是否被封禁 - r.GET("/sensitive", sagin.CheckDisable(), handler.GetSensitiveData) - - r.Run(":8080") +// 处理器结构体 +type UserHandler struct{} + +// 公开访问 - 忽略认证 +func (h *UserHandler) GetPublic(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{ + "message": "这是公开接口,不需要登录", + }) +} + +// 需要登录 +func (h *UserHandler) GetUserInfo(c *gin.Context) { + loginID, _ := sagin.GetLoginIDFromRequest(c) + c.JSON(http.StatusOK, gin.H{ + "message": "用户个人信息", + "loginId": loginID, + }) +} + +// 需要管理员权限 +func (h *UserHandler) GetAdminData(c *gin.Context) { + loginID, _ := sagin.GetLoginIDFromRequest(c) + c.JSON(http.StatusOK, gin.H{ + "message": "管理员数据", + "loginId": loginID, + "data": "这是管理员专有的数据", + }) } +// 登录接口 func loginHandler(c *gin.Context) { var req struct { UserID int `json:"userId"` } - c.ShouldBindJSON(&req) - + + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "参数错误"}) + return + } + + ctx := c.Request.Context() + // 登录 - token, _ := stputil.Login(req.UserID) - - // 设置权限和角色 - stputil.SetPermissions(req.UserID, []string{"user:read", "admin:*"}) - stputil.SetRoles(req.UserID, []string{"admin"}) - + token, err := sagin.Login(ctx, req.UserID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "登录失败"}) + return + } + + // 设置权限和角色(模拟) + if req.UserID == 1 { + _ = sagin.SetPermissions(ctx, req.UserID, []string{"user:read", "user:write", "admin:*"}) + _ = sagin.SetRoles(ctx, req.UserID, []string{"admin", "manager"}) + } else { + _ = sagin.SetPermissions(ctx, req.UserID, []string{"user:read", "user:write"}) + _ = sagin.SetRoles(ctx, req.UserID, []string{"user"}) + } + c.JSON(http.StatusOK, gin.H{ - "token": token, + "token": token, "message": "登录成功", }) } + +// 封禁账号接口 +func disableHandler(c *gin.Context) { + var req struct { + UserID int `json:"userId"` + } + + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "参数错误"}) + return + } + + ctx := c.Request.Context() + + // 封禁账号1小时 + _ = sagin.Disable(ctx, req.UserID, 1*time.Hour) + + c.JSON(http.StatusOK, gin.H{ + "message": "账号已封禁1小时", + }) +} + +func main() { + r := gin.Default() + + // 登录接口(公开) + r.POST("/login", loginHandler) + + // 封禁接口(需要管理员权限) + r.POST("/disable", sagin.CheckPermissionMiddleware("admin:*"), disableHandler) + + // 使用装饰器模式设置路由 + handler := &UserHandler{} + + // 公开访问 - 忽略认证 + r.GET("/public", sagin.IgnoreMiddleware(), handler.GetPublic) + + // 需要登录 + r.GET("/user/info", sagin.CheckLoginMiddleware(), handler.GetUserInfo) + + // 需要管理员权限 + r.GET("/admin", sagin.CheckPermissionMiddleware("admin:*"), handler.GetAdminData) + + // 需要用户权限或管理员权限(OR逻辑) + r.GET("/user-or-admin", + sagin.CheckPermissionMiddleware("user:read", "admin:*"), + handler.GetUserOrAdmin) + + // 需要管理员角色 + r.GET("/manager", sagin.CheckRoleMiddleware("admin"), handler.GetManagerData) + + // 检查账号是否被封禁 + r.GET("/sensitive", sagin.CheckDisableMiddleware(), handler.GetSensitiveData) + + // 启动服务器 + _ = r.Run(":8080") +} ``` ## API 测试 -### 1. 登录 +### 1. 登录(用户ID=1 获得管理员权限) ```bash curl -X POST http://localhost:8080/login \ -H "Content-Type: application/json" \ - -d '{"userId": 1000}' + -d '{"userId": 1}' ``` 响应: @@ -139,33 +210,86 @@ curl -X POST http://localhost:8080/login \ } ``` -### 2. 访问公开接口(无需登录) +### 2. 登录(普通用户) + +```bash +curl -X POST http://localhost:8080/login \ + -H "Content-Type: application/json" \ + -d '{"userId": 2}' +``` + +### 3. 访问公开接口(无需登录) ```bash curl http://localhost:8080/public ``` -### 3. 访问需要登录的接口 +响应: +```json +{ + "message": "这是公开接口,不需要登录" +} +``` + +### 4. 访问需要登录的接口 ```bash curl http://localhost:8080/user/info \ -H "Authorization: YOUR_TOKEN" ``` -### 4. 访问需要权限的接口 +响应: +```json +{ + "message": "用户个人信息", + "loginId": "1" +} +``` + +### 5. 访问需要管理员权限的接口 ```bash curl http://localhost:8080/admin \ -H "Authorization: YOUR_TOKEN" ``` -### 5. 封禁账号 +响应(管理员): +```json +{ + "message": "管理员数据", + "loginId": "1", + "data": "这是管理员专有的数据" +} +``` + +### 6. 访问需要角色的接口 + +```bash +curl http://localhost:8080/manager \ + -H "Authorization: YOUR_TOKEN" +``` + +### 7. 封禁账号(需要管理员权限) ```bash curl -X POST http://localhost:8080/disable \ -H "Content-Type: application/json" \ -H "Authorization: YOUR_TOKEN" \ - -d '{"userId": 1000}' + -d '{"userId": 2}' +``` + +响应: +```json +{ + "message": "账号已封禁1小时" +} +``` + +### 8. 访问敏感数据(检查封禁状态) + +```bash +curl http://localhost:8080/sensitive \ + -H "Authorization: YOUR_TOKEN" ``` ## 注解对比 @@ -184,25 +308,52 @@ public Result getUserInfo() { public Result getAdminData() { return Result.success(); } + +@SaCheckPermission("admin:*") +@GetMapping("/admin/data") +public Result getAdminOnlyData() { + return Result.success(); +} ``` ### Go (Sa-Token-Go) ```go -r.GET("/user/info", sagin.CheckLogin(), handler.GetUserInfo) +r.GET("/user/info", sagin.CheckLoginMiddleware(), handler.GetUserInfo) + +r.GET("/admin", sagin.CheckRoleMiddleware("admin"), handler.GetAdminData) -r.GET("/admin", sagin.CheckRole("admin"), handler.GetAdminData) +r.GET("/admin/data", sagin.CheckPermissionMiddleware("admin:*"), handler.GetAdminOnlyData) ``` +## 中间件说明 + +| 中间件 | 说明 | 对应 Java 注解 | +|--------|------|----------------| +| `CheckLoginMiddleware()` | 检查是否已登录 | `@SaCheckLogin` | +| `CheckRoleMiddleware(roles...)` | 检查是否拥有指定角色 | `@SaCheckRole` | +| `CheckPermissionMiddleware(perms...)` | 检查是否拥有指定权限 | `@SaCheckPermission` | +| `CheckDisableMiddleware()` | 检查账号是否被封禁 | `@SaCheckDisable` | +| `IgnoreMiddleware()` | 忽略认证检查 | `@SaIgnore` | + ## 优势 -- ✅ **声明式编程** - 代码更简洁、可读性更强 -- ✅ **统一验证** - 自动处理认证和授权逻辑 -- ✅ **错误处理** - 自动返回标准错误响应 -- ✅ **灵活组合** - 可以组合使用多个装饰器 +- **声明式编程** - 代码更简洁、可读性更强 +- **统一验证** - 自动处理认证和授权逻辑 +- **错误处理** - 自动返回标准错误响应 +- **灵活组合** - 可以组合使用多个中间件 +- **权限模式** - 支持通配符权限匹配(如 `admin:*`) + +## 权限说明 + +本示例中的权限分配: + +| 用户 ID | 权限 | 角色 | +|---------|------|------| +| 1 | `user:read`, `user:write`, `admin:*` | `admin`, `manager` | +| 其他 | `user:read`, `user:write` | `user` | ## 更多示例 - [快速开始](../../quick-start/simple-example) - 学习基础用法 - [Gin 集成](../../gin/gin-example) - 完整的 Gin 集成示例 - diff --git a/examples/annotation/annotation-example/go.mod b/examples/annotation/annotation-example/go.mod index a830f10..595de28 100644 --- a/examples/annotation/annotation-example/go.mod +++ b/examples/annotation/annotation-example/go.mod @@ -1,48 +1,47 @@ module github.com/click33/sa-token-go/examples/annotation-example -go 1.21 +go 1.25.0 require ( - github.com/click33/sa-token-go/core v0.1.3 - github.com/click33/sa-token-go/integrations/gin v0.1.3 - github.com/click33/sa-token-go/storage/memory v0.1.3 - github.com/click33/sa-token-go/stputil v0.1.3 + github.com/click33/sa-token-go/integrations/gin v0.1.7 github.com/gin-gonic/gin v1.10.0 ) require ( github.com/bytedance/sonic v1.11.6 // indirect + github.com/bytedance/sonic/loader v0.1.1 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/click33/sa-token-go/core v0.1.7 // indirect + github.com/click33/sa-token-go/storage/memory v0.1.7 // indirect + github.com/click33/sa-token-go/stputil v0.1.7 // indirect github.com/cloudwego/base64x v0.1.4 // indirect github.com/cloudwego/iasm v0.2.0 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/gabriel-vasile/mimetype v1.4.3 // indirect github.com/gin-contrib/sse v0.1.0 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.20.0 // indirect github.com/goccy/go-json v0.10.2 // indirect - github.com/golang-jwt/jwt/v5 v5.2.1 // indirect + github.com/golang-jwt/jwt/v5 v5.3.0 // indirect github.com/google/uuid v1.6.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.2.7 // indirect + github.com/kr/text v0.2.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/panjf2000/ants/v2 v2.11.3 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.12 // indirect golang.org/x/arch v0.8.0 // indirect - golang.org/x/crypto v0.23.0 // indirect - golang.org/x/net v0.25.0 // indirect - golang.org/x/sys v0.20.0 // indirect - golang.org/x/text v0.15.0 // indirect - google.golang.org/protobuf v1.34.1 // indirect + golang.org/x/crypto v0.41.0 // indirect + golang.org/x/net v0.43.0 // indirect + golang.org/x/sync v0.19.0 // indirect + golang.org/x/sys v0.35.0 // indirect + golang.org/x/text v0.28.0 // indirect + google.golang.org/protobuf v1.36.10 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) - -replace ( - github.com/click33/sa-token-go/core => ../../../core - github.com/click33/sa-token-go/integrations/gin => ../../../integrations/gin - github.com/click33/sa-token-go/storage/memory => ../../../storage/memory - github.com/click33/sa-token-go/stputil => ../../../stputil -) diff --git a/examples/annotation/annotation-example/go.sum b/examples/annotation/annotation-example/go.sum new file mode 100644 index 0000000..8ac49e2 --- /dev/null +++ b/examples/annotation/annotation-example/go.sum @@ -0,0 +1,53 @@ +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0= +github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/click33/sa-token-go/core v0.1.6 h1:ELOe0qSH1b3LRsQD3DIBg0e1VgYANKFg5H7z57Lkt/8= +github.com/click33/sa-token-go/core v0.1.6/go.mod h1:mb3AQAJIXqx9WdULyn5qjufK1j/u+kgB0q+tafHVhgk= +github.com/click33/sa-token-go/integrations/gin v0.1.5 h1:OYAldpSyibG6acbFOdIoEpRH3FAxBVGafFExyz6B9Og= +github.com/click33/sa-token-go/integrations/gin v0.1.5/go.mod h1:SFThmz7E84VGrGQa7RlMH/XLFGqxUfHSmbeFKmNgmbs= +github.com/click33/sa-token-go/storage/memory v0.1.6 h1:iGFEy+HtTJLOpKnbIMbgpXyKotsKpPQu6wWTZVOXQis= +github.com/click33/sa-token-go/storage/memory v0.1.6/go.mod h1:YNojcgyLC/uFrmReZLePCDQ5WK2fo2WWGRjRMvXVH74= +github.com/click33/sa-token-go/stputil v0.1.5 h1:603tbI4JkBTg3MnfTj+lCMDxJOKSCOqsMyC2zyuvEco= +github.com/click33/sa-token-go/stputil v0.1.5/go.mod h1:YH+3NLXgGJfrS2wkGubMWFnr/Nk0GgejOtRxcE+9x0c= +github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y= +github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= +github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= +github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU= +github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBExVwjEviJTixqxL8= +github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= +github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= +github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= +github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= +github.com/panjf2000/ants/v2 v2.11.3 h1:AfI0ngBoXJmYOpDh9m516vjqoUu2sLrIVgppI9TZVpg= +github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/redis/go-redis/v9 v9.5.1 h1:H1X4D3yHPaYrkL5X06Wh6xNVM/pX0Ft4RV0vMGvLBh8= +github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= +golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc= +golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4= +golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= +golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= +google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/examples/annotation/annotation-example/main.go b/examples/annotation/annotation-example/main.go index 2ffa270..5053818 100644 --- a/examples/annotation/annotation-example/main.go +++ b/examples/annotation/annotation-example/main.go @@ -4,19 +4,14 @@ import ( "net/http" "time" - "github.com/click33/sa-token-go/core" sagin "github.com/click33/sa-token-go/integrations/gin" - "github.com/click33/sa-token-go/storage/memory" - "github.com/click33/sa-token-go/stputil" "github.com/gin-gonic/gin" ) func init() { - // 初始化StpUtil - stputil.SetManager( - core.NewBuilder(). - Storage(memory.NewStorage()). - Build(), + // 初始化 Manager + sagin.SetManager( + sagin.NewDefaultBuild().Build(), ) } @@ -32,7 +27,7 @@ func (h *UserHandler) GetPublic(c *gin.Context) { // 需要登录 func (h *UserHandler) GetUserInfo(c *gin.Context) { - loginID, _ := stputil.GetLoginID(c.GetHeader("Authorization")) + loginID, _ := sagin.GetLoginIDFromRequest(c) c.JSON(http.StatusOK, gin.H{ "message": "用户个人信息", @@ -42,7 +37,7 @@ func (h *UserHandler) GetUserInfo(c *gin.Context) { // 需要管理员权限 func (h *UserHandler) GetAdminData(c *gin.Context) { - loginID, _ := stputil.GetLoginID(c.GetHeader("Authorization")) + loginID, _ := sagin.GetLoginIDFromRequest(c) c.JSON(http.StatusOK, gin.H{ "message": "管理员数据", @@ -53,7 +48,7 @@ func (h *UserHandler) GetAdminData(c *gin.Context) { // 需要多个权限之一 func (h *UserHandler) GetUserOrAdmin(c *gin.Context) { - loginID, _ := stputil.GetLoginID(c.GetHeader("Authorization")) + loginID, _ := sagin.GetLoginIDFromRequest(c) c.JSON(http.StatusOK, gin.H{ "message": "用户或管理员都可以访问", @@ -63,7 +58,7 @@ func (h *UserHandler) GetUserOrAdmin(c *gin.Context) { // 需要特定角色 func (h *UserHandler) GetManagerData(c *gin.Context) { - loginID, _ := stputil.GetLoginID(c.GetHeader("Authorization")) + loginID, _ := sagin.GetLoginIDFromRequest(c) c.JSON(http.StatusOK, gin.H{ "message": "经理数据", @@ -73,7 +68,7 @@ func (h *UserHandler) GetManagerData(c *gin.Context) { // 检查账号是否被封禁 func (h *UserHandler) GetSensitiveData(c *gin.Context) { - loginID, _ := stputil.GetLoginID(c.GetHeader("Authorization")) + loginID, _ := sagin.GetLoginIDFromRequest(c) c.JSON(http.StatusOK, gin.H{ "message": "敏感数据", @@ -92,16 +87,23 @@ func loginHandler(c *gin.Context) { return } + ctx := c.Request.Context() + // 登录 - token, err := stputil.Login(req.UserID) + token, err := sagin.Login(ctx, req.UserID) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": "登录失败"}) return } // 设置权限和角色(模拟) - stputil.SetPermissions(req.UserID, []string{"user:read", "user:write", "admin:*"}) - stputil.SetRoles(req.UserID, []string{"admin", "manager"}) + if req.UserID == 1 { + _ = sagin.SetPermissions(ctx, req.UserID, []string{"user:read", "user:write", "admin:*"}) + _ = sagin.SetRoles(ctx, req.UserID, []string{"admin", "manager-example"}) + } else { + _ = sagin.SetPermissions(ctx, req.UserID, []string{"user:read", "user:write"}) + _ = sagin.SetRoles(ctx, req.UserID, []string{"user"}) + } c.JSON(http.StatusOK, gin.H{ "token": token, @@ -120,8 +122,10 @@ func disableHandler(c *gin.Context) { return } + ctx := c.Request.Context() + // 封禁账号1小时 - stputil.Disable(req.UserID, 1*time.Hour) + _ = sagin.Disable(ctx, req.UserID, 1*time.Hour) c.JSON(http.StatusOK, gin.H{ "message": "账号已封禁1小时", @@ -135,31 +139,31 @@ func main() { r.POST("/login", loginHandler) // 封禁接口(需要管理员权限) - r.POST("/disable", sagin.CheckPermission("admin:*"), disableHandler) + r.POST("/disable", sagin.CheckPermissionMiddleware("admin:*"), disableHandler) // 使用装饰器模式设置路由 handler := &UserHandler{} // 公开访问 - 忽略认证 - r.GET("/public", sagin.Ignore(), handler.GetPublic) + r.GET("/public", sagin.IgnoreMiddleware(), handler.GetPublic) // 需要登录 - r.GET("/user/info", sagin.CheckLogin(), handler.GetUserInfo) + r.GET("/user/info", sagin.CheckLoginMiddleware(), handler.GetUserInfo) // 需要管理员权限 - r.GET("/admin", sagin.CheckPermission("admin:*"), handler.GetAdminData) + r.GET("/admin", sagin.CheckPermissionMiddleware("admin:*"), handler.GetAdminData) // 需要用户权限或管理员权限(OR逻辑) r.GET("/user-or-admin", - sagin.CheckPermission("user:read", "admin:*"), + sagin.CheckPermissionMiddleware("user:read", "admin:*"), handler.GetUserOrAdmin) // 需要管理员角色 - r.GET("/manager", sagin.CheckRole("admin"), handler.GetManagerData) + r.GET("/manager-example", sagin.CheckRoleMiddleware("admin"), handler.GetManagerData) // 检查账号是否被封禁 - r.GET("/sensitive", sagin.CheckDisable(), handler.GetSensitiveData) + r.GET("/sensitive", sagin.CheckDisableMiddleware(), handler.GetSensitiveData) // 启动服务器 - r.Run(":8080") + _ = r.Run(":8080") } diff --git a/examples/annotation/annotation-example/test/api_test.go b/examples/annotation/annotation-example/test/api_test.go new file mode 100644 index 0000000..5878ebf --- /dev/null +++ b/examples/annotation/annotation-example/test/api_test.go @@ -0,0 +1,161 @@ +// @Author daixk 2026/1/4 15:57:00 +package test + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "testing" +) + +// +// 工具函数 +// + +// 简单封装 HTTP 请求 +func doRequest(t *testing.T, method, url string, body any, token string) { + var reqBody io.Reader + if body != nil { + b, _ := json.Marshal(body) + reqBody = bytes.NewReader(b) + } + + req, err := http.NewRequest(method, url, reqBody) + if err != nil { + t.Fatalf("创建请求错误: %v", err) + } + + req.Header.Set("Content-Type", "application/json") + if token != "" { + req.Header.Set("Authorization", token) + } + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("请求失败: %v", err) + } + defer resp.Body.Close() + + b, _ := io.ReadAll(resp.Body) + fmt.Printf("\n[%s %s] 返回结果:\n%s\n\n", method, url, b) +} + +// +// 每个接口单独测试 +// + +// 1. 公开接口 +func TestPublic(t *testing.T) { + doRequest(t, "GET", "http://localhost:8080/public", nil, "") +} + +// 2. 登录 +func TestLoginUser1(t *testing.T) { + doRequest(t, "POST", "http://localhost:8080/login", map[string]any{ + "userId": 1, + }, "") +} + +func TestLoginUser2(t *testing.T) { + doRequest(t, "POST", "http://localhost:8080/login", map[string]any{ + "userId": 2, + }, "") +} + +// 3. 获取用户信息(需要登录) +func TestUserInfo(t *testing.T) { + // 先登录 + token := getToken(t, 2) + + // 请求 + doRequest(t, "GET", "http://localhost:8080/user/info", nil, token) +} + +// 4. 管理员接口(admin:*) +func TestAdmin(t *testing.T) { + token := getToken(t, 1) + + doRequest(t, "GET", "http://localhost:8080/admin", nil, token) +} + +// 5. 用户 or 管理员 OR 权限 +func TestUserOrAdmin(t *testing.T) { + token := getToken(t, 2) + + doRequest(t, "GET", "http://localhost:8080/user-or-admin", nil, token) +} + +// 6. 测试角色:admin +func TestRoleManager(t *testing.T) { + token := getToken(t, 1) + + doRequest(t, "GET", "http://localhost:8080/manager-example", nil, token) +} + +// 7. 测试封禁接口 +func TestDisable(t *testing.T) { + token := getToken(t, 1) + + doRequest(t, "POST", "http://localhost:8080/disable", map[string]any{ + "userId": 2, + }, token) +} + +// 8. 查看是否被封禁 +func TestSensitive(t *testing.T) { + token := getToken(t, 2) + + doRequest(t, "GET", "http://localhost:8080/sensitive", nil, token) +} + +// 工具:登录并返回 token +func getToken(t *testing.T, userID int) string { + var respBody struct { + Token string `json:"token"` + } + + body := map[string]any{"userId": userID} + + var reqBody io.Reader + b, _ := json.Marshal(body) + reqBody = bytes.NewReader(b) + + req, _ := http.NewRequest("POST", "http://localhost:8080/login", reqBody) + req.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("登录失败: %v", err) + } + defer resp.Body.Close() + + data, _ := io.ReadAll(resp.Body) + _ = json.Unmarshal(data, &respBody) + + if respBody.Token == "" { + t.Fatalf("登录返回 token 为空: %s", data) + } + + return respBody.Token +} + +// +// 最终测试:执行全部接口 +// + +func TestAll(t *testing.T) { + t.Run("Public", TestPublic) + + t.Run("LoginUser1", TestLoginUser1) + t.Run("LoginUser2", TestLoginUser2) + + t.Run("UserInfo", TestUserInfo) + t.Run("Admin", TestAdmin) + t.Run("UserOrAdmin", TestUserOrAdmin) + t.Run("RoleManager", TestRoleManager) + + t.Run("Disable", TestDisable) + t.Run("Sensitive", TestSensitive) +} diff --git a/examples/chi/chi-example/cmd/main.go b/examples/chi/chi-example/cmd/main.go index e6cbcd3..3d41a29 100644 --- a/examples/chi/chi-example/cmd/main.go +++ b/examples/chi/chi-example/cmd/main.go @@ -5,66 +5,180 @@ import ( "log" "net/http" - "github.com/click33/sa-token-go/core" sachi "github.com/click33/sa-token-go/integrations/chi" - "github.com/click33/sa-token-go/storage/memory" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" ) func main() { - // 创建存储 - storage := memory.NewStorage() + // 使用 Builder 模式构建 Manager | Build Manager using Builder pattern + manager := sachi.NewDefaultBuild(). + TokenName("Authorization"). + Timeout(7200). + IsLog(true). + IsPrintBanner(true). + Build() - // 创建配置 - config := core.DefaultConfig() - config.TokenName = "Authorization" - config.Timeout = 7200 // 2小时 + // 设置全局管理器 | Set global manager + sachi.SetManager(manager) - // 创建管理器 - manager := core.NewManager(storage, config) - - // 创建Chi插件 - plugin := sachi.NewPlugin(manager) - - // 创建路由 + // 创建路由 | Create router r := chi.NewRouter() r.Use(middleware.Logger) r.Use(middleware.Recoverer) - // 公开路由 - r.Post("/login", plugin.LoginHandler) + // 登录接口 | Login endpoint + r.Post("/login", func(w http.ResponseWriter, r *http.Request) { + userID := r.FormValue("user_id") + if userID == "" { + http.Error(w, `{"error": "user_id is required"}`, http.StatusBadRequest) + return + } + + ctx := r.Context() + + // 使用 sachi 包的全局函数登录 | Use sachi package global function to login + token, err := sachi.Login(ctx, userID) + if err != nil { + http.Error(w, `{"error": "`+err.Error()+`"}`, http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"message": "登录成功", "token": "` + token + `"}`)) + }) + + // 登出接口 | Logout endpoint + r.Post("/logout", func(w http.ResponseWriter, r *http.Request) { + token := r.Header.Get("Authorization") + if token == "" { + http.Error(w, `{"error": "token is required"}`, http.StatusBadRequest) + return + } + + ctx := r.Context() + + // 使用 sachi 包的全局函数登出 | Use sachi package global function to logout + if err := sachi.LogoutByToken(ctx, token); err != nil { + http.Error(w, `{"error": "`+err.Error()+`"}`, http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"message": "登出成功"}`)) + }) + + // 公开路由 | Public route r.Get("/public", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(map[string]interface{}{ - "message": "公开访问", - }) + w.Write([]byte(`{"message": "公开访问"}`)) }) - // 受保护路由 + // 检查登录状态 | Check login status + r.Get("/check", func(w http.ResponseWriter, r *http.Request) { + token := r.Header.Get("Authorization") + if token == "" { + http.Error(w, `{"error": "token is required"}`, http.StatusBadRequest) + return + } + + ctx := r.Context() + + // 使用 sachi 包的全局函数检查登录 | Use sachi package global function to check login + isLogin := sachi.IsLogin(ctx, token) + if !isLogin { + http.Error(w, `{"error": "未登录"}`, http.StatusUnauthorized) + return + } + + // 获取登录ID | Get login ID + loginID, _ := sachi.GetLoginID(ctx, token) + + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"message": "已登录", "login_id": "` + loginID + `"}`)) + }) + + // 受保护的路由组 | Protected route group r.Group(func(r chi.Router) { - r.Use(plugin.AuthMiddleware()) - r.Get("/api/user/info", func(w http.ResponseWriter, r *http.Request) { - saCtx, _ := sachi.GetSaToken(r) - loginID, _ := saCtx.GetLoginID() - permissions, _ := manager.GetPermissions(loginID) - roles, _ := manager.GetRoles(loginID) + r.Use(sachi.AuthMiddleware()) + + // 用户信息 | User info + r.Get("/api/user", func(w http.ResponseWriter, r *http.Request) { + token := r.Header.Get("Authorization") + ctx := r.Context() + loginID, _ := sachi.GetLoginID(ctx, token) w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(map[string]interface{}{ + w.Write([]byte(`{"user_id": "` + loginID + `", "name": "User ` + loginID + `"}`)) + }) + + // 获取 Token 信息 | Get token info + r.Get("/api/token-info", func(w http.ResponseWriter, r *http.Request) { + token := r.Header.Get("Authorization") + ctx := r.Context() + + tokenInfo, err := sachi.GetTokenInfo(ctx, token) + if err != nil { + http.Error(w, `{"error": "`+err.Error()+`"}`, http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + response := map[string]interface{}{ "code": 200, - "message": "success", + "message": "获取Token信息成功", "data": map[string]interface{}{ - "loginId": loginID, - "permissions": permissions, - "roles": roles, + "authType": tokenInfo.AuthType, + "loginId": tokenInfo.LoginID, + "device": tokenInfo.Device, + "createTime": tokenInfo.CreateTime, + "activeTime": tokenInfo.ActiveTime, }, - }) + } + json.NewEncoder(w).Encode(response) + }) + + // 踢人下线 | Kickout user + r.Post("/api/kickout/{user_id}", func(w http.ResponseWriter, r *http.Request) { + userID := chi.URLParam(r, "user_id") + ctx := r.Context() + + // 使用 sachi 包的全局函数踢人 | Use sachi package global function to kickout + if err := sachi.Kickout(ctx, userID); err != nil { + http.Error(w, `{"error": "`+err.Error()+`"}`, http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"message": "踢人成功"}`)) + }) + }) + + // 需要权限的路由组 | Routes requiring permissions + r.Group(func(r chi.Router) { + r.Use(sachi.AuthMiddleware()) + r.Use(sachi.PermissionMiddleware([]string{"admin:read"}, sachi.WithLogicType(sachi.LogicOr))) + + r.Get("/admin/dashboard", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"message": "管理员面板"}`)) + }) + }) + + // 需要角色的路由组 | Routes requiring roles + r.Group(func(r chi.Router) { + r.Use(sachi.AuthMiddleware()) + r.Use(sachi.RoleMiddleware([]string{"super-admin"}, sachi.WithLogicType(sachi.LogicAnd))) + + r.Get("/super/settings", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{"message": "超级管理员设置"}`)) }) }) - // 启动服务器 + // 启动服务器 | Start server log.Println("服务器启动在端口: 8080") + log.Println("示例: curl -X POST http://localhost:8080/login -d 'user_id=1000'") if err := http.ListenAndServe(":8080", r); err != nil { log.Fatal("服务器启动失败:", err) } diff --git a/examples/chi/chi-example/go.mod b/examples/chi/chi-example/go.mod index 509b792..116ca71 100644 --- a/examples/chi/chi-example/go.mod +++ b/examples/chi/chi-example/go.mod @@ -1,21 +1,14 @@ module github.com/click33/sa-token-go/examples/chi-example -go 1.21 +go 1.25.0 require ( - github.com/click33/sa-token-go/core v0.1.3 - github.com/click33/sa-token-go/integrations/chi v0.1.3 - github.com/click33/sa-token-go/storage/memory v0.1.3 github.com/go-chi/chi/v5 v5.0.11 ) require ( - github.com/golang-jwt/jwt/v5 v5.2.1 // indirect + github.com/golang-jwt/jwt/v5 v5.3.0 // indirect github.com/google/uuid v1.6.0 // indirect -) - -replace ( - github.com/click33/sa-token-go/core => ../../../core - github.com/click33/sa-token-go/integrations/chi => ../../../integrations/chi - github.com/click33/sa-token-go/storage/memory => ../../../storage/memory + github.com/panjf2000/ants/v2 v2.11.3 // indirect + golang.org/x/sync v0.19.0 // indirect ) diff --git a/examples/chi/chi-example/go.sum b/examples/chi/chi-example/go.sum new file mode 100644 index 0000000..6c47202 --- /dev/null +++ b/examples/chi/chi-example/go.sum @@ -0,0 +1,11 @@ +github.com/click33/sa-token-go/core v0.1.6 h1:ELOe0qSH1b3LRsQD3DIBg0e1VgYANKFg5H7z57Lkt/8= +github.com/click33/sa-token-go/storage/memory v0.1.6 h1:iGFEy+HtTJLOpKnbIMbgpXyKotsKpPQu6wWTZVOXQis= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/go-chi/chi/v5 v5.0.11 h1:BnpYbFZ3T3S1WMpD79r7R5ThWX40TaFB7L31Y8xqSwA= +github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/panjf2000/ants/v2 v2.11.3 h1:AfI0ngBoXJmYOpDh9m516vjqoUu2sLrIVgppI9TZVpg= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/examples/echo/echo-example/cmd/main.go b/examples/echo/echo-example/cmd/main.go index 3e4f1ed..0a55c00 100644 --- a/examples/echo/echo-example/cmd/main.go +++ b/examples/echo/echo-example/cmd/main.go @@ -2,66 +2,193 @@ package main import ( "log" + "net/http" - "github.com/click33/sa-token-go/core" saecho "github.com/click33/sa-token-go/integrations/echo" - "github.com/click33/sa-token-go/storage/memory" "github.com/labstack/echo/v4" "github.com/labstack/echo/v4/middleware" ) func main() { - // 创建存储 - storage := memory.NewStorage() + // 使用 Builder 模式构建 Manager | Build Manager using Builder pattern + manager := saecho.NewDefaultBuild(). + TokenName("Authorization"). + Timeout(7200). + IsLog(true). + IsPrintBanner(true). + Build() - // 创建配置 - config := core.DefaultConfig() - config.TokenName = "Authorization" - config.Timeout = 7200 // 2小时 + // 设置全局管理器 | Set global manager + saecho.SetManager(manager) - // 创建管理器 - manager := core.NewManager(storage, config) - - // 创建Echo插件 - plugin := saecho.NewPlugin(manager) - - // 创建Echo实例 + // 创建Echo实例 | Create Echo instance e := echo.New() e.Use(middleware.Logger()) e.Use(middleware.Recover()) - // 公开路由 - e.POST("/login", plugin.LoginHandler) + // 登录接口 | Login endpoint + e.POST("/login", func(c echo.Context) error { + userID := c.FormValue("user_id") + if userID == "" { + return c.JSON(http.StatusBadRequest, map[string]interface{}{ + "error": "user_id is required", + }) + } + + ctx := c.Request().Context() + + // 使用 saecho 包的全局函数登录 | Use saecho package global function to login + token, err := saecho.Login(ctx, userID) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]interface{}{ + "error": err.Error(), + }) + } + + return c.JSON(http.StatusOK, map[string]interface{}{ + "message": "登录成功", + "token": token, + }) + }) + + // 登出接口 | Logout endpoint + e.POST("/logout", func(c echo.Context) error { + token := c.Request().Header.Get("Authorization") + if token == "" { + return c.JSON(http.StatusBadRequest, map[string]interface{}{ + "error": "token is required", + }) + } + + ctx := c.Request().Context() + + // 使用 saecho 包的全局函数登出 | Use saecho package global function to logout + if err := saecho.LogoutByToken(ctx, token); err != nil { + return c.JSON(http.StatusInternalServerError, map[string]interface{}{ + "error": err.Error(), + }) + } + + return c.JSON(http.StatusOK, map[string]interface{}{ + "message": "登出成功", + }) + }) + + // 公开路由 | Public route e.GET("/public", func(c echo.Context) error { - return c.JSON(200, map[string]interface{}{ + return c.JSON(http.StatusOK, map[string]interface{}{ "message": "公开访问", }) }) - // 受保护路由 + // 检查登录状态 | Check login status + e.GET("/check", func(c echo.Context) error { + token := c.Request().Header.Get("Authorization") + if token == "" { + return c.JSON(http.StatusBadRequest, map[string]interface{}{ + "error": "token is required", + }) + } + + ctx := c.Request().Context() + + // 使用 saecho 包的全局函数检查登录 | Use saecho package global function to check login + isLogin := saecho.IsLogin(ctx, token) + if !isLogin { + return c.JSON(http.StatusUnauthorized, map[string]interface{}{ + "error": "未登录", + }) + } + + // 获取登录ID | Get login ID + loginID, _ := saecho.GetLoginID(ctx, token) + + return c.JSON(http.StatusOK, map[string]interface{}{ + "message": "已登录", + "login_id": loginID, + }) + }) + + // 受保护的路由组 | Protected route group api := e.Group("/api") - api.Use(plugin.AuthMiddleware()) + api.Use(saecho.AuthMiddleware()) { - api.GET("/user/info", func(c echo.Context) error { - saCtx, _ := saecho.GetSaToken(c) - loginID, _ := saCtx.GetLoginID() - permissions, _ := manager.GetPermissions(loginID) - roles, _ := manager.GetRoles(loginID) + // 用户信息 | User info + api.GET("/user", func(c echo.Context) error { + token := c.Request().Header.Get("Authorization") + ctx := c.Request().Context() + loginID, _ := saecho.GetLoginID(ctx, token) + + return c.JSON(http.StatusOK, map[string]interface{}{ + "user_id": loginID, + "name": "User " + loginID, + }) + }) - return c.JSON(200, map[string]interface{}{ + // 获取 Token 信息 | Get token info + api.GET("/token-info", func(c echo.Context) error { + token := c.Request().Header.Get("Authorization") + ctx := c.Request().Context() + + tokenInfo, err := saecho.GetTokenInfo(ctx, token) + if err != nil { + return c.JSON(http.StatusInternalServerError, map[string]interface{}{ + "error": err.Error(), + }) + } + + return c.JSON(http.StatusOK, map[string]interface{}{ "code": 200, - "message": "success", - "data": map[string]interface{}{ - "loginId": loginID, - "permissions": permissions, - "roles": roles, - }, + "message": "获取Token信息成功", + "data": tokenInfo, + }) + }) + + // 踢人下线 | Kickout user + api.POST("/kickout/:user_id", func(c echo.Context) error { + userID := c.Param("user_id") + ctx := c.Request().Context() + + // 使用 saecho 包的全局函数踢人 | Use saecho package global function to kickout + if err := saecho.Kickout(ctx, userID); err != nil { + return c.JSON(http.StatusInternalServerError, map[string]interface{}{ + "error": err.Error(), + }) + } + + return c.JSON(http.StatusOK, map[string]interface{}{ + "message": "踢人成功", + }) + }) + } + + // 需要权限的路由组 | Routes requiring permissions + admin := e.Group("/admin") + admin.Use(saecho.AuthMiddleware()) + admin.Use(saecho.PermissionMiddleware([]string{"admin:read"}, saecho.WithLogicType(saecho.LogicOr))) + { + admin.GET("/dashboard", func(c echo.Context) error { + return c.JSON(http.StatusOK, map[string]interface{}{ + "message": "管理员面板", + }) + }) + } + + // 需要角色的路由组 | Routes requiring roles + super := e.Group("/super") + super.Use(saecho.AuthMiddleware()) + super.Use(saecho.RoleMiddleware([]string{"super-admin"}, saecho.WithLogicType(saecho.LogicAnd))) + { + super.GET("/settings", func(c echo.Context) error { + return c.JSON(http.StatusOK, map[string]interface{}{ + "message": "超级管理员设置", }) }) } - // 启动服务器 + // 启动服务器 | Start server log.Println("服务器启动在端口: 8080") + log.Println("示例: curl -X POST http://localhost:8080/login -d 'user_id=1000'") if err := e.Start(":8080"); err != nil { log.Fatal("服务器启动失败:", err) } diff --git a/examples/echo/echo-example/go.mod b/examples/echo/echo-example/go.mod index b08a0a4..7383f7a 100644 --- a/examples/echo/echo-example/go.mod +++ b/examples/echo/echo-example/go.mod @@ -1,35 +1,25 @@ module github.com/click33/sa-token-go/examples/echo-example -go 1.23.0 - -toolchain go1.24.1 +go 1.25.0 require ( - github.com/click33/sa-token-go/core v0.1.3 - github.com/click33/sa-token-go/integrations/echo v0.1.3 - github.com/click33/sa-token-go/storage/memory v0.1.3 github.com/labstack/echo/v4 v4.11.4 ) require ( - github.com/click33/sa-token-go/stputil v0.0.0-20251017234446-3cf2bdee68cc // indirect github.com/golang-jwt/jwt v3.2.2+incompatible // indirect - github.com/golang-jwt/jwt/v5 v5.2.1 // indirect + github.com/golang-jwt/jwt/v5 v5.3.0 // indirect github.com/google/uuid v1.6.0 // indirect github.com/labstack/gommon v0.4.2 // indirect - github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/panjf2000/ants/v2 v2.11.3 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasttemplate v1.2.2 // indirect golang.org/x/crypto v0.41.0 // indirect golang.org/x/net v0.43.0 // indirect + golang.org/x/sync v0.19.0 // indirect golang.org/x/sys v0.35.0 // indirect golang.org/x/text v0.28.0 // indirect golang.org/x/time v0.5.0 // indirect ) - -replace ( - github.com/click33/sa-token-go/core => ../../../core - github.com/click33/sa-token-go/integrations/echo => ../../../integrations/echo - github.com/click33/sa-token-go/storage/memory => ../../../storage/memory -) diff --git a/examples/echo/echo-example/go.sum b/examples/echo/echo-example/go.sum new file mode 100644 index 0000000..352f2ad --- /dev/null +++ b/examples/echo/echo-example/go.sum @@ -0,0 +1,22 @@ +github.com/click33/sa-token-go/core v0.1.6 h1:ELOe0qSH1b3LRsQD3DIBg0e1VgYANKFg5H7z57Lkt/8= +github.com/click33/sa-token-go/storage/memory v0.1.6 h1:iGFEy+HtTJLOpKnbIMbgpXyKotsKpPQu6wWTZVOXQis= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= +github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/labstack/echo/v4 v4.11.4 h1:vDZmA+qNeh1pd/cCkEicDMrjtrnMGQ1QFI9gWN1zGq8= +github.com/labstack/gommon v0.4.2 h1:F8qTUNXgG1+6WQmqoUWnz8WiEU60mXVVw0P4ht1WRA0= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/panjf2000/ants/v2 v2.11.3 h1:AfI0ngBoXJmYOpDh9m516vjqoUu2sLrIVgppI9TZVpg= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQD0Loo= +golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4= +golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= +golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= +golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/examples/fiber/fiber-example/cmd/main.go b/examples/fiber/fiber-example/cmd/main.go index c6d572b..7dffeb9 100644 --- a/examples/fiber/fiber-example/cmd/main.go +++ b/examples/fiber/fiber-example/cmd/main.go @@ -3,66 +3,192 @@ package main import ( "log" - "github.com/click33/sa-token-go/core" safiber "github.com/click33/sa-token-go/integrations/fiber" - "github.com/click33/sa-token-go/storage/memory" "github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2/middleware/logger" "github.com/gofiber/fiber/v2/middleware/recover" ) func main() { - // 创建存储 - storage := memory.NewStorage() + // 使用 Builder 模式构建 Manager | Build Manager using Builder pattern + manager := safiber.NewDefaultBuild(). + TokenName("Authorization"). + Timeout(7200). + IsLog(true). + IsPrintBanner(true). + Build() - // 创建配置 - config := core.DefaultConfig() - config.TokenName = "Authorization" - config.Timeout = 7200 // 2小时 + // 设置全局管理器 | Set global manager + safiber.SetManager(manager) - // 创建管理器 - manager := core.NewManager(storage, config) - - // 创建Fiber插件 - plugin := safiber.NewPlugin(manager) - - // 创建Fiber应用 + // 创建Fiber应用 | Create Fiber app app := fiber.New() app.Use(logger.New()) app.Use(recover.New()) - // 公开路由 - app.Post("/login", plugin.LoginHandler) + // 登录接口 | Login endpoint + app.Post("/login", func(c *fiber.Ctx) error { + userID := c.FormValue("user_id") + if userID == "" { + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{ + "error": "user_id is required", + }) + } + + ctx := c.Context() + + // 使用 safiber 包的全局函数登录 | Use safiber package global function to login + token, err := safiber.Login(ctx, userID) + if err != nil { + return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{ + "error": err.Error(), + }) + } + + return c.JSON(fiber.Map{ + "message": "登录成功", + "token": token, + }) + }) + + // 登出接口 | Logout endpoint + app.Post("/logout", func(c *fiber.Ctx) error { + token := c.Get("Authorization") + if token == "" { + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{ + "error": "token is required", + }) + } + + ctx := c.Context() + + // 使用 safiber 包的全局函数登出 | Use safiber package global function to logout + if err := safiber.LogoutByToken(ctx, token); err != nil { + return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{ + "error": err.Error(), + }) + } + + return c.JSON(fiber.Map{ + "message": "登出成功", + }) + }) + + // 公开路由 | Public route app.Get("/public", func(c *fiber.Ctx) error { return c.JSON(fiber.Map{ "message": "公开访问", }) }) - // 受保护路由 + // 检查登录状态 | Check login status + app.Get("/check", func(c *fiber.Ctx) error { + token := c.Get("Authorization") + if token == "" { + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{ + "error": "token is required", + }) + } + + ctx := c.Context() + + // 使用 safiber 包的全局函数检查登录 | Use safiber package global function to check login + isLogin := safiber.IsLogin(ctx, token) + if !isLogin { + return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{ + "error": "未登录", + }) + } + + // 获取登录ID | Get login ID + loginID, _ := safiber.GetLoginID(ctx, token) + + return c.JSON(fiber.Map{ + "message": "已登录", + "login_id": loginID, + }) + }) + + // 受保护的路由组 | Protected route group api := app.Group("/api") - api.Use(plugin.AuthMiddleware()) + api.Use(safiber.AuthMiddleware()) { - api.Get("/user/info", func(c *fiber.Ctx) error { - saCtx, _ := safiber.GetSaToken(c) - loginID, _ := saCtx.GetLoginID() - permissions, _ := manager.GetPermissions(loginID) - roles, _ := manager.GetRoles(loginID) + // 用户信息 | User info + api.Get("/user", func(c *fiber.Ctx) error { + token := c.Get("Authorization") + ctx := c.Context() + loginID, _ := safiber.GetLoginID(ctx, token) + + return c.JSON(fiber.Map{ + "user_id": loginID, + "name": "User " + loginID, + }) + }) + + // 获取 Token 信息 | Get token info + api.Get("/token-info", func(c *fiber.Ctx) error { + token := c.Get("Authorization") + ctx := c.Context() + + tokenInfo, err := safiber.GetTokenInfo(ctx, token) + if err != nil { + return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{ + "error": err.Error(), + }) + } return c.JSON(fiber.Map{ "code": 200, - "message": "success", - "data": fiber.Map{ - "loginId": loginID, - "permissions": permissions, - "roles": roles, - }, + "message": "获取Token信息成功", + "data": tokenInfo, + }) + }) + + // 踢人下线 | Kickout user + api.Post("/kickout/:user_id", func(c *fiber.Ctx) error { + userID := c.Params("user_id") + ctx := c.Context() + + // 使用 safiber 包的全局函数踢人 | Use safiber package global function to kickout + if err := safiber.Kickout(ctx, userID); err != nil { + return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{ + "error": err.Error(), + }) + } + + return c.JSON(fiber.Map{ + "message": "踢人成功", + }) + }) + } + + // 需要权限的路由组 | Routes requiring permissions + admin := app.Group("/admin") + admin.Use(safiber.AuthMiddleware()) + admin.Use(safiber.PermissionMiddleware([]string{"admin:read"}, safiber.WithLogicType(safiber.LogicOr))) + { + admin.Get("/dashboard", func(c *fiber.Ctx) error { + return c.JSON(fiber.Map{ + "message": "管理员面板", + }) + }) + } + + // 需要角色的路由组 | Routes requiring roles + super := app.Group("/super") + super.Use(safiber.AuthMiddleware()) + super.Use(safiber.RoleMiddleware([]string{"super-admin"}, safiber.WithLogicType(safiber.LogicAnd))) + { + super.Get("/settings", func(c *fiber.Ctx) error { + return c.JSON(fiber.Map{ + "message": "超级管理员设置", }) }) } - // 启动服务器 + // 启动服务器 | Start server log.Println("服务器启动在端口: 8080") + log.Println("示例: curl -X POST http://localhost:8080/login -d 'user_id=1000'") if err := app.Listen(":8080"); err != nil { log.Fatal("服务器启动失败:", err) } diff --git a/examples/fiber/fiber-example/go.mod b/examples/fiber/fiber-example/go.mod index e87d400..4e06611 100644 --- a/examples/fiber/fiber-example/go.mod +++ b/examples/fiber/fiber-example/go.mod @@ -1,31 +1,24 @@ module github.com/click33/sa-token-go/examples/fiber-example -go 1.21 +go 1.25.0 require ( - github.com/click33/sa-token-go/core v0.1.3 - github.com/click33/sa-token-go/integrations/fiber v0.1.3 - github.com/click33/sa-token-go/storage/memory v0.1.3 github.com/gofiber/fiber/v2 v2.52.0 ) require ( github.com/andybalholm/brotli v1.0.5 // indirect - github.com/golang-jwt/jwt/v5 v5.2.1 // indirect + github.com/golang-jwt/jwt/v5 v5.3.0 // indirect github.com/google/uuid v1.6.0 // indirect github.com/klauspost/compress v1.17.0 // indirect - github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect - github.com/mattn/go-runewidth v0.0.15 // indirect - github.com/rivo/uniseg v0.2.0 // indirect + github.com/mattn/go-runewidth v0.0.16 // indirect + github.com/panjf2000/ants/v2 v2.11.3 // indirect + github.com/rivo/uniseg v0.4.7 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasthttp v1.51.0 // indirect github.com/valyala/tcplisten v1.0.0 // indirect - golang.org/x/sys v0.20.0 // indirect -) - -replace ( - github.com/click33/sa-token-go/core => ../../../core - github.com/click33/sa-token-go/integrations/fiber => ../../../integrations/fiber - github.com/click33/sa-token-go/storage/memory => ../../../storage/memory + golang.org/x/sync v0.19.0 // indirect + golang.org/x/sys v0.35.0 // indirect ) diff --git a/examples/fiber/fiber-example/go.sum b/examples/fiber/fiber-example/go.sum new file mode 100644 index 0000000..778dc50 --- /dev/null +++ b/examples/fiber/fiber-example/go.sum @@ -0,0 +1,21 @@ +github.com/andybalholm/brotli v1.0.5 h1:8uQZIdzKmjc/iuPu7O2ioW48L81FgatrcpfFmiq/cCs= +github.com/click33/sa-token-go/core v0.1.6 h1:ELOe0qSH1b3LRsQD3DIBg0e1VgYANKFg5H7z57Lkt/8= +github.com/click33/sa-token-go/storage/memory v0.1.6 h1:iGFEy+HtTJLOpKnbIMbgpXyKotsKpPQu6wWTZVOXQis= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/gofiber/fiber/v2 v2.52.0 h1:S+qXi7y+/Pgvqq4DrSmREGiFwtB7Bu6+QFLuIHYw/UE= +github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/klauspost/compress v1.17.0 h1:Rnbp4K9EjcDuVuHtd0dgA4qNuv9yKDYKK1ulpJwgrqM= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= +github.com/panjf2000/ants/v2 v2.11.3 h1:AfI0ngBoXJmYOpDh9m516vjqoUu2sLrIVgppI9TZVpg= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/fasthttp v1.51.0 h1:8b30A5JlZ6C7AS81RsWjYMQmrZG6feChmgAolCl1SqA= +github.com/valyala/tcplisten v1.0.0 h1:rBHj/Xf+E1tRGZyWIWwJDiRY0zc1Js+CV5DqwacVSA8= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/examples/gf/go.mod b/examples/gf/gf-example/go.mod similarity index 69% rename from examples/gf/go.mod rename to examples/gf/gf-example/go.mod index fdf155e..1b9d7f8 100644 --- a/examples/gf/go.mod +++ b/examples/gf/gf-example/go.mod @@ -1,40 +1,34 @@ module github.com/click33/sa-token-go/examples/gf-example -go 1.24.1 - -replace ( - github.com/click33/sa-token-go/core => ../../core - github.com/click33/sa-token-go/integrations/gf => ../../integrations/gf - github.com/click33/sa-token-go/storage/memory => ../../storage/memory -) +go 1.25.0 require ( - github.com/click33/sa-token-go/integrations/gf v0.0.0-00010101000000-000000000000 - github.com/click33/sa-token-go/storage/memory v0.0.0-00010101000000-000000000000 + github.com/click33/sa-token-go/integrations/gf v0.1.7 github.com/gogf/gf/v2 v2.9.4 ) require ( github.com/BurntSushi/toml v1.5.0 // indirect github.com/clbanning/mxj/v2 v2.7.0 // indirect - github.com/click33/sa-token-go/core v0.1.3 // indirect - github.com/click33/sa-token-go/stputil v0.1.3 // indirect + github.com/click33/sa-token-go/core v0.1.7 // indirect + github.com/click33/sa-token-go/stputil v0.1.7 // indirect github.com/emirpasic/gods v1.18.1 // indirect github.com/fatih/color v1.18.0 // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect - github.com/golang-jwt/jwt/v5 v5.2.1 // indirect + github.com/golang-jwt/jwt/v5 v5.3.0 // indirect github.com/google/uuid v1.6.0 // indirect github.com/gorilla/websocket v1.5.3 // indirect github.com/grokify/html-strip-tags-go v0.1.0 // indirect github.com/magiconair/properties v1.8.10 // indirect - github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-runewidth v0.0.16 // indirect github.com/olekukonko/errors v1.1.0 // indirect github.com/olekukonko/ll v0.0.9 // indirect github.com/olekukonko/tablewriter v1.1.0 // indirect + github.com/panjf2000/ants/v2 v2.11.3 // indirect github.com/rivo/uniseg v0.4.7 // indirect go.opentelemetry.io/auto/sdk v1.1.0 // indirect go.opentelemetry.io/otel v1.38.0 // indirect @@ -42,6 +36,7 @@ require ( go.opentelemetry.io/otel/sdk v1.38.0 // indirect go.opentelemetry.io/otel/trace v1.38.0 // indirect golang.org/x/net v0.43.0 // indirect + golang.org/x/sync v0.19.0 // indirect golang.org/x/sys v0.35.0 // indirect golang.org/x/text v0.28.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/examples/gf/gf-example/go.sum b/examples/gf/gf-example/go.sum new file mode 100644 index 0000000..0d56965 --- /dev/null +++ b/examples/gf/gf-example/go.sum @@ -0,0 +1,46 @@ +github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg= +github.com/clbanning/mxj/v2 v2.7.0 h1:WA/La7UGCanFe5NpHF0Q3DNtnCsVoxbPKuyBNHWRyME= +github.com/click33/sa-token-go/core v0.1.6 h1:ELOe0qSH1b3LRsQD3DIBg0e1VgYANKFg5H7z57Lkt/8= +github.com/click33/sa-token-go/core v0.1.6/go.mod h1:mb3AQAJIXqx9WdULyn5qjufK1j/u+kgB0q+tafHVhgk= +github.com/click33/sa-token-go/integrations/gf v0.1.5 h1:jypKpIDa4L11L4JKW5374Kujnh7f/txXbUIBl0/40Qc= +github.com/click33/sa-token-go/integrations/gf v0.1.5/go.mod h1:vOIAHrB8LQMBR52lbluLTBJzzEFqsk4qwy48qmuPhJg= +github.com/click33/sa-token-go/stputil v0.1.5 h1:603tbI4JkBTg3MnfTj+lCMDxJOKSCOqsMyC2zyuvEco= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc= +github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= +github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/gogf/gf/v2 v2.9.4 h1:6vleEWypot9WBPncP2GjbpgAUeG6Mzb1YESb9nPMkjY= +github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/grokify/html-strip-tags-go v0.1.0 h1:03UrQLjAny8xci+R+qjCce/MYnpNXCtgzltlQbOBae4= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= +github.com/olekukonko/errors v1.1.0 h1:RNuGIh15QdDenh+hNvKrJkmxxjV4hcS50Db478Ou5sM= +github.com/olekukonko/ll v0.0.9 h1:Y+1YqDfVkqMWuEQMclsF9HUR5+a82+dxJuL1HHSRpxI= +github.com/olekukonko/tablewriter v1.1.0 h1:N0LHrshF4T39KvI96fn6GT8HEjXRXYNDrDjKFDB7RIY= +github.com/panjf2000/ants/v2 v2.11.3 h1:AfI0ngBoXJmYOpDh9m516vjqoUu2sLrIVgppI9TZVpg= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= +github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= +go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8= +go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA= +go.opentelemetry.io/otel/sdk v1.38.0 h1:l48sr5YbNf2hpCUj/FoGhW9yDkl+Ma+LrVl8qaM5b+E= +go.opentelemetry.io/otel/sdk/metric v1.38.0 h1:aSH66iL0aZqo//xXzQLYozmWrXxyFkBJ6qT5wthqPoM= +go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= +golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/examples/gf/gf-example/main.go b/examples/gf/gf-example/main.go new file mode 100644 index 0000000..7ef9ae2 --- /dev/null +++ b/examples/gf/gf-example/main.go @@ -0,0 +1,260 @@ +package main + +import ( + "github.com/gogf/gf/v2/os/gctx" + "net/http" + + sagf "github.com/click33/sa-token-go/integrations/gf" + + "github.com/gogf/gf/v2/frame/g" + "github.com/gogf/gf/v2/net/ghttp" +) + +func main() { + // redis存储实现 + storage, err := sagf.NewRedisStorage("redis://:root@192.168.19.104:6379/0?dial_timeout=3&read_timeout=10s&max_retries=2") + if err != nil { + panic(err) + } + // 内存存储实现 + //storage := sagf.NewMemoryStorage() + + // 使用 Builder 模式构建 Manager | Build Manager using Builder pattern + manager := sagf.NewDefaultBuild(). + //SetStorage(sagf.NewMemoryStorage()). // 设置内存存储 | Set memory storage + SetStorage(storage). // 设置内存存储 | Set memory storage + IsLog(false). // 开启日志 | Enable logging + Build() + + // 注册 Manager | Register Manager + sagf.SetManager(manager) + + ctx := gctx.New() + s := g.Server() + + // 首页路由 | Home route + s.BindHandler("/", func(r *ghttp.Request) { + r.Response.WriteJson(g.Map{ + "code": sagf.CodeSuccess, + "message": "Welcome to Sa-Token-Go GF Example", + }) + }) + + // 公开路由 | Public route + s.BindHandler("/public", func(r *ghttp.Request) { + r.Response.WriteStatusExit(http.StatusOK, g.Map{ + "code": sagf.CodeSuccess, + "message": "公开访问 | Public access", + }) + }) + + // 登录接口 | Login API + s.BindHandler("/login", func(r *ghttp.Request) { + // 模拟用户ID | Simulate user ID + loginID := r.Get("id", "10001").String() + + // 执行登录 | Perform login + token, err := sagf.Login(r.Context(), loginID) + if err != nil { + r.Response.WriteStatusExit(http.StatusOK, g.Map{ + "code": sagf.CodeServerError, + "message": err.Error(), + }) + return + } + + // 角色 + if loginID == "1" { + err = sagf.SetRoles(r.Context(), loginID, []string{"admin"}) + if err != nil { + r.Response.WriteJson(g.Map{ + "code": sagf.CodeServerError, + "message": "登录失败", + "data": g.Array{}, + }) + } + } + // 权限 + if loginID == "1" { + err = sagf.SetPermissions(r.Context(), loginID, []string{"admin:read", "admin:delete"}) + if err != nil { + r.Response.WriteJson(g.Map{ + "code": sagf.CodeServerError, + "message": "登录失败", + "data": g.Array{}, + }) + } + } + + r.Response.WriteJson(g.Map{ + "code": sagf.CodeSuccess, + "message": "登录成功 | Login successful", + "data": g.Map{ + "token": token, + "loginID": loginID, + }, + }) + }) + + s.Group("/", func(group *ghttp.RouterGroup) { + group.Middleware(sagf.AuthMiddleware( + ctx, + sagf.WithFailFunc(func(r *ghttp.Request, err error) { + r.Response.WriteStatusExit(http.StatusOK, g.Map{ + "code": sagf.CodeNotLogin, + "message": err.Error(), + }) + }), + )) + group.GET("/logout", func(r *ghttp.Request) { + // 从请求中获取 Token | Get token from request + saCtx, ok := sagf.GetSaTokenContext(r) + if !ok { + r.Response.WriteStatusExit(http.StatusOK, g.Map{ + "code": sagf.CodeNotLogin, + "message": "未登录 | Not logged in", + }) + return + } + + tokenValue := saCtx.GetTokenValue() + err := sagf.LogoutByToken(r.Context(), tokenValue) + if err != nil { + r.Response.WriteStatusExit(http.StatusOK, g.Map{ + "code": sagf.CodeServerError, + "message": err.Error(), + }) + return + } + + r.Response.WriteJson(g.Map{ + "code": sagf.CodeSuccess, + "message": "登出成功 | Logout successful", + }) + }) + }) + + // 受保护的路由组 | Protected route group + protected := s.Group("/").Middleware( + sagf.AuthMiddleware( + ctx, + sagf.WithFailFunc(func(r *ghttp.Request, err error) { + r.Response.WriteStatusExit(http.StatusOK, g.Map{ + "code": sagf.CodeNotLogin, + "message": err.Error(), + }) + }))) + { + // 获取用户信息 | Get user info + protected.GET("/user", func(r *ghttp.Request) { + saCtx, _ := sagf.GetSaTokenContext(r) + tokenValue := saCtx.GetTokenValue() + + loginID, err := sagf.GetLoginID(r.Context(), tokenValue) + if err != nil { + r.Response.WriteStatusExit(http.StatusOK, g.Map{ + "code": sagf.CodeNotLogin, + "message": err.Error(), + }) + return + } + + r.Response.WriteJson(g.Map{ + "code": sagf.CodeSuccess, + "message": "获取用户信息成功 | Get user info successful", + "data": g.Map{ + "loginID": loginID, + "token": tokenValue, + }, + }) + }) + + // 获取 Token 信息 | Get token info + protected.GET("/token-info", func(r *ghttp.Request) { + saCtx, _ := sagf.GetSaTokenContext(r) + tokenValue := saCtx.GetTokenValue() + + tokenInfo, err := sagf.GetTokenInfo(r.Context(), tokenValue) + if err != nil { + r.Response.WriteStatusExit(http.StatusOK, g.Map{ + "code": sagf.CodeServerError, + "message": err.Error(), + }) + return + } + + r.Response.WriteJson(g.Map{ + "code": sagf.CodeSuccess, + "message": "获取Token信息成功 | Get token info successful", + "data": tokenInfo, + }) + }) + } + + // 需要特定权限的路由 | Routes requiring specific permissions + permGroup := s.Group("/").Middleware( + sagf.AuthMiddleware( + ctx, + sagf.WithFailFunc(func(r *ghttp.Request, err error) { + r.Response.WriteStatusExit(http.StatusOK, g.Map{ + "code": sagf.CodeNotLogin, + "message": err.Error(), + }) + }), + ), + sagf.PermissionMiddleware( + ctx, + []string{"admin:read", "admin:delete"}, + sagf.WithLogicType(sagf.LogicAnd), + sagf.WithFailFunc(func(r *ghttp.Request, err error) { + r.Response.WriteStatusExit(http.StatusOK, g.Map{ + "code": sagf.CodePermissionDenied, + "message": err.Error(), + }) + }), + ), + ) + { + permGroup.GET("/dashboard", func(r *ghttp.Request) { + r.Response.WriteJson(g.Map{ + "code": sagf.CodeSuccess, + "message": "管理员面板 | Admin dashboard", + }) + }) + } + + // 需要特定角色的路由 | Routes requiring specific roles + roleGroup := s.Group("/").Middleware( + sagf.AuthMiddleware( + ctx, + sagf.WithFailFunc(func(r *ghttp.Request, err error) { + r.Response.WriteStatusExit(http.StatusOK, g.Map{ + "code": sagf.CodeNotLogin, + "message": err.Error(), + }) + }), + ), + sagf.RoleMiddleware( + ctx, + []string{"super-admin"}, + sagf.WithLogicType(sagf.LogicAnd), + sagf.WithFailFunc(func(r *ghttp.Request, err error) { + r.Response.WriteStatusExit(http.StatusOK, g.Map{ + "code": sagf.CodePermissionDenied, + "message": err.Error(), + }) + }), + ), + ) + { + roleGroup.GET("/settings", func(r *ghttp.Request) { + r.Response.WriteJson(g.Map{ + "code": sagf.CodeSuccess, + "message": "超级管理员设置 | Super admin settings", + }) + }) + } + + s.SetPort(8000) + s.Run() +} diff --git a/examples/gf/main.go b/examples/gf/main.go deleted file mode 100644 index 40c4787..0000000 --- a/examples/gf/main.go +++ /dev/null @@ -1,52 +0,0 @@ -package main - -import ( - "net/http" - - sagin "github.com/click33/sa-token-go/integrations/gf" - "github.com/click33/sa-token-go/storage/memory" - - "github.com/gogf/gf/v2/frame/g" - "github.com/gogf/gf/v2/net/ghttp" -) - -func main() { - // 初始化存储 - storage := memory.NewStorage() - - // 创建配置 (现在可以直接使用 sagin 包的函数) - config := sagin.DefaultConfig() - // 创建管理器 (现在可以直接使用 sagin 包的函数) - manager := sagin.NewManager(storage, config) - - // 创建 Gin 插件 - plugin := sagin.NewPlugin(manager) - s := g.Server() - - s.BindHandler("/", func(r *ghttp.Request) { - r.Response.Writef( - "Hello %s! Your Age is %d", - r.Get("name", "unknown").String(), - r.Get("age").Int(), - ) - }) - // 公开路由 - s.BindHandler("/public", func(r *ghttp.Request) { - r.Response.WriteStatusExit( - http.StatusOK, - g.Map{ - "message": "公开访问", - }, - ) - }) - s.BindHandler("/login", plugin.LoginHandler) - // 受保护路由 - protected := s.Group("/api").Middleware(plugin.AuthMiddleware()) - - { - protected.GET("/user", plugin.UserInfoHandler) - } - - s.SetPort(8000) - s.Run() -} diff --git a/examples/gin/gin-example/README.md b/examples/gin/gin-example/README.md index f017f27..08c72ba 100644 --- a/examples/gin/gin-example/README.md +++ b/examples/gin/gin-example/README.md @@ -20,91 +20,88 @@ go run cmd/main.go ## 使用方式 -### 方式一:使用 Manager 实例(推荐用于复杂场景) +### 方式一:使用 Builder 构建器(推荐) ```go package main import ( "github.com/gin-gonic/gin" - "github.com/click33/sa-token-go/core" sagin "github.com/click33/sa-token-go/integrations/gin" - "github.com/click33/sa-token-go/storage/memory" ) func main() { - // 创建 Manager - manager := core.NewBuilder(). - Storage(memory.NewStorage()). + // 使用 Builder 创建 Manager + mgr := sagin.NewDefaultBuild(). TokenName("Authorization"). + Timeout(7200). + IsPrintBanner(true). Build() - - // 创建插件 - plugin := sagin.NewPlugin(manager) - + + // 设置全局 Manager + sagin.SetManager(mgr) + // 设置路由 r := gin.Default() - r.POST("/login", plugin.LoginHandler) - r.GET("/user", plugin.AuthMiddleware(), plugin.UserInfoHandler) - + + // 登录接口 + r.POST("/login", func(c *gin.Context) { + var req struct { + UserID string `json:"userId"` + } + c.ShouldBindJSON(&req) + + ctx := c.Request.Context() + token, _ := sagin.Login(ctx, req.UserID) + c.JSON(200, gin.H{"token": token}) + }) + + // 需要登录的接口 + r.GET("/user", sagin.CheckLoginMiddleware(), func(c *gin.Context) { + loginID, _ := sagin.GetLoginIDFromRequest(c) + c.JSON(200, gin.H{ + "loginId": loginID, + "message": "用户信息", + }) + }) + + // 需要权限的接口 + r.GET("/admin", sagin.CheckPermissionMiddleware("admin:*"), func(c *gin.Context) { + c.JSON(200, gin.H{"message": "管理员数据"}) + }) + r.Run(":8080") } ``` -### 方式二:使用 StpUtil 全局工具类(推荐用于简单场景) +### 方式二:使用路由组 ```go package main import ( - "net/http" - "github.com/gin-gonic/gin" - "github.com/click33/sa-token-go/core" - "github.com/click33/sa-token-go/stputil" sagin "github.com/click33/sa-token-go/integrations/gin" - "github.com/click33/sa-token-go/storage/memory" ) -func init() { - // 初始化 StpUtil - stputil.SetManager( - core.NewBuilder(). - Storage(memory.NewStorage()). - Build(), - ) -} - func main() { + // 初始化 Manager + mgr := sagin.NewDefaultBuild().Build() + sagin.SetManager(mgr) + r := gin.Default() - - // 登录接口 - r.POST("/login", func(c *gin.Context) { - var req struct { - UserID int `json:"userId"` - } - c.ShouldBindJSON(&req) - - token, _ := stputil.Login(req.UserID) - c.JSON(http.StatusOK, gin.H{"token": token}) - }) - - // 使用注解装饰器 - r.GET("/user", sagin.CheckLogin(), func(c *gin.Context) { - token := c.GetHeader("Authorization") - loginID, _ := stputil.GetLoginID(token) - - c.JSON(http.StatusOK, gin.H{ - "loginId": loginID, - "message": "用户信息", - }) - }) - - // 需要权限 - r.GET("/admin", sagin.CheckPermission("admin:*"), func(c *gin.Context) { - c.JSON(http.StatusOK, gin.H{"message": "管理员数据"}) - }) - + + // 公开接口 + r.POST("/login", loginHandler) + + // 受保护的路由组 + protected := r.Group("/api") + protected.Use(sagin.CheckLoginMiddleware()) + { + protected.GET("/user", userHandler) + protected.GET("/admin", sagin.CheckPermissionMiddleware("admin:*"), adminHandler) + } + r.Run(":8080") } ``` @@ -117,7 +114,15 @@ func main() { ```bash curl -X POST http://localhost:8080/login \ -H "Content-Type: application/json" \ - -d '{"username":"test","password":"123456"}' + -d '{"userId":"1000"}' + ``` + + 响应: + ```json + { + "message": "登录成功", + "token": "YOUR_TOKEN" + } ``` - `GET /public` - 公开访问 @@ -133,12 +138,81 @@ func main() { -H "Authorization: YOUR_TOKEN" ``` + 响应: + ```json + { + "message": "用户信息", + "loginId": "1000" + } + ``` + - `GET /api/admin` - 管理员接口(需要管理员权限) ```bash curl http://localhost:8080/api/admin \ -H "Authorization: YOUR_TOKEN" ``` +## 中间件说明 + +| 中间件 | 说明 | +|--------|------| +| `CheckLoginMiddleware()` | 检查是否已登录 | +| `CheckRoleMiddleware(roles...)` | 检查是否拥有指定角色 | +| `CheckPermissionMiddleware(perms...)` | 检查是否拥有指定权限 | +| `CheckDisableMiddleware()` | 检查账号是否被封禁 | +| `IgnoreMiddleware()` | 忽略认证检查 | + +## 常用函数 + +### 认证相关 + +```go +// 登录(需要 context) +token, err := sagin.Login(ctx, userID) + +// 登出 +err := sagin.Logout(ctx, userID) +err := sagin.LogoutByToken(ctx, token) + +// 检查登录状态 +isLogin := sagin.IsLogin(ctx, token) + +// 获取登录ID +loginID, err := sagin.GetLoginID(ctx, token) + +// 从请求中获取登录ID(Gin 专用) +loginID, err := sagin.GetLoginIDFromRequest(c) +``` + +### 权限和角色 + +```go +// 设置权限 +err := sagin.SetPermissions(ctx, userID, []string{"user:read", "admin:*"}) + +// 设置角色 +err := sagin.SetRoles(ctx, userID, []string{"admin", "user"}) + +// 检查权限 +hasPermission := sagin.HasPermission(ctx, userID, "admin:*") + +// 检查角色 +hasRole := sagin.HasRole(ctx, userID, "admin") +``` + +### 踢人和封禁 + +```go +// 踢人下线 +err := sagin.Kickout(ctx, userID) + +// 封禁账号 +err := sagin.Disable(ctx, userID, time.Hour) + +// 解封账号 +err := sagin.Untie(ctx, userID) +``` + ## 配置文件 配置文件位于 `configs/config.yaml`: @@ -154,5 +228,5 @@ server: ## 更多示例 -查看 [注解示例](../../annotation/annotation-example) 了解更多注解装饰器的用法。 - +- [简单示例](../gin-simple) - 最简单的使用方式 +- [注解示例](../../annotation/annotation-example) - 中间件装饰器用法 diff --git a/examples/gin/gin-example/cmd/main.go b/examples/gin/gin-example/cmd/main.go index 568ae08..aaa536b 100644 --- a/examples/gin/gin-example/cmd/main.go +++ b/examples/gin/gin-example/cmd/main.go @@ -4,7 +4,6 @@ import ( "log" sagin "github.com/click33/sa-token-go/integrations/gin" - "github.com/click33/sa-token-go/storage/memory" "github.com/gin-gonic/gin" "github.com/spf13/viper" ) @@ -16,39 +15,74 @@ func main() { log.Printf("Warning: No config file found, using defaults: %v", err) } - // 初始化存储 - storage := memory.NewStorage() + // 创建 Builder + b := sagin.NewDefaultBuild(). + TokenName("Authorization"). + IsPrintBanner(true) - // 创建配置 (现在可以直接使用 sagin 包的函数) - config := sagin.DefaultConfig() + // 从配置文件读取配置 if viper.IsSet("token.timeout") { - config.Timeout = viper.GetInt64("token.timeout") + b.Timeout(viper.GetInt64("token.timeout")) } if viper.IsSet("token.active_timeout") { - config.ActiveTimeout = viper.GetInt64("token.active_timeout") + b.ActiveTimeout(viper.GetInt64("token.active_timeout")) } - // 创建管理器 (现在可以直接使用 sagin 包的函数) - manager := sagin.NewManager(storage, config) + // 构建 Manager + mgr := b.Build() - // 创建 Gin 插件 - plugin := sagin.NewPlugin(manager) + // 设置全局 Manager + sagin.SetManager(mgr) // 设置路由 r := gin.Default() // 公开路由 - r.POST("/login", plugin.LoginHandler) + r.POST("/login", func(c *gin.Context) { + var req struct { + UserID string `json:"userId" form:"userId"` + } + if err := c.ShouldBind(&req); err != nil || req.UserID == "" { + c.JSON(400, gin.H{"error": "userId is required"}) + return + } + + ctx := c.Request.Context() + token, err := sagin.Login(ctx, req.UserID) + if err != nil { + c.JSON(500, gin.H{"error": err.Error()}) + return + } + + c.JSON(200, gin.H{ + "message": "登录成功", + "token": token, + }) + }) + r.GET("/public", func(c *gin.Context) { c.JSON(200, gin.H{"message": "公开访问"}) }) // 受保护路由 protected := r.Group("/api") - protected.Use(plugin.AuthMiddleware()) + protected.Use(sagin.CheckLoginMiddleware()) { - protected.GET("/user", plugin.UserInfoHandler) - protected.GET("/admin", plugin.AdminOnlyHandler) + protected.GET("/user", func(c *gin.Context) { + loginID, _ := sagin.GetLoginIDFromRequest(c) + c.JSON(200, gin.H{ + "message": "用户信息", + "loginId": loginID, + }) + }) + + protected.GET("/admin", sagin.CheckPermissionMiddleware("admin:*"), func(c *gin.Context) { + loginID, _ := sagin.GetLoginIDFromRequest(c) + c.JSON(200, gin.H{ + "message": "管理员数据", + "loginId": loginID, + }) + }) } // 启动服务器 diff --git a/examples/gin/gin-example/go.mod b/examples/gin/gin-example/go.mod index 939c529..7d6ef85 100644 --- a/examples/gin/gin-example/go.mod +++ b/examples/gin/gin-example/go.mod @@ -1,41 +1,36 @@ module github.com/click33/sa-token-go/examples/gin-example -go 1.23.0 - -toolchain go1.24.1 +go 1.25.0 require ( - github.com/click33/sa-token-go/integrations/gin v0.1.3 - github.com/click33/sa-token-go/storage/memory v0.1.3 github.com/gin-gonic/gin v1.10.0 github.com/spf13/viper v1.18.2 ) require ( github.com/bytedance/sonic v1.11.6 // indirect - github.com/click33/sa-token-go/core v0.1.3 // indirect - github.com/click33/sa-token-go/stputil v0.1.3 // indirect + github.com/bytedance/sonic/loader v0.1.1 // indirect + github.com/click33/sa-token-go/core v0.1.7 // indirect github.com/cloudwego/base64x v0.1.4 // indirect github.com/cloudwego/iasm v0.2.0 // indirect - github.com/fsnotify/fsnotify v1.7.0 // indirect + github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/gabriel-vasile/mimetype v1.4.3 // indirect github.com/gin-contrib/sse v0.1.0 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.20.0 // indirect github.com/goccy/go-json v0.10.2 // indirect - github.com/golang-jwt/jwt/v5 v5.2.1 // indirect - github.com/google/uuid v1.6.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.2.7 // indirect github.com/leodido/go-urn v1.4.0 // indirect - github.com/magiconair/properties v1.8.7 // indirect + github.com/magiconair/properties v1.8.10 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect + github.com/rogpeppe/go-internal v1.13.1 // indirect github.com/sagikazarmark/locafero v0.4.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect github.com/sourcegraph/conc v0.3.0 // indirect @@ -53,13 +48,8 @@ require ( golang.org/x/net v0.43.0 // indirect golang.org/x/sys v0.35.0 // indirect golang.org/x/text v0.28.0 // indirect - google.golang.org/protobuf v1.34.1 // indirect + google.golang.org/protobuf v1.36.10 // indirect + gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) - -replace ( - github.com/click33/sa-token-go/core => ../../../core - github.com/click33/sa-token-go/integrations/gin => ../../../integrations/gin - github.com/click33/sa-token-go/storage/memory => ../../../storage/memory -) diff --git a/examples/gin/gin-example/go.sum b/examples/gin/gin-example/go.sum new file mode 100644 index 0000000..47ebca8 --- /dev/null +++ b/examples/gin/gin-example/go.sum @@ -0,0 +1,55 @@ +github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0= +github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM= +github.com/click33/sa-token-go/core v0.1.6 h1:ELOe0qSH1b3LRsQD3DIBg0e1VgYANKFg5H7z57Lkt/8= +github.com/click33/sa-token-go/storage/memory v0.1.6 h1:iGFEy+HtTJLOpKnbIMbgpXyKotsKpPQu6wWTZVOXQis= +github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y= +github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= +github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= +github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= +github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU= +github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBExVwjEviJTixqxL8= +github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= +github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= +github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= +github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= +github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= +github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= +github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ= +github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE= +github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= +github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= +github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= +github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= +go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= +go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= +golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc= +golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4= +golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g= +golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= +golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= +golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= +google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/examples/gin/gin-simple/README.md b/examples/gin/gin-simple/README.md index 8cda6bf..4b90c9f 100644 --- a/examples/gin/gin-simple/README.md +++ b/examples/gin/gin-simple/README.md @@ -4,17 +4,17 @@ This example demonstrates how to use Sa-Token-Go with Gin by **only importing th ## Features -✅ **Single Import** - Only need `github.com/click33/sa-token-go/integrations/gin` -✅ **All Functions** - Access to all core and stputil functions -✅ **Simple API** - Clean and easy to use +- **Single Import** - Only need `github.com/click33/sa-token-go/integrations/gin` +- **All Functions** - Access to all core and stputil functions +- **Simple API** - Clean and easy to use +- **Context Support** - All functions support `context.Context` ## Quick Start ### 1. Install dependencies ```bash -go get github.com/click33/sa-token-go/integrations/gin@v0.1.0 -go get github.com/click33/sa-token-go/storage/memory@v0.1.0 +go get github.com/click33/sa-token-go/integrations/gin go get github.com/gin-gonic/gin ``` @@ -57,86 +57,145 @@ curl -X POST -H "token: YOUR_TOKEN" http://localhost:8080/api/kickout/1000 # Response: {"message":"踢人成功"} ``` -## Code Highlights - -### Old Way (Multiple Imports) +## Code Example ```go -import ( - "github.com/click33/sa-token-go/core" - "github.com/click33/sa-token-go/stputil" - "github.com/click33/sa-token-go/integrations/gin" -) - -config := core.DefaultConfig() -manager := core.NewManager(storage, config) -stputil.SetManager(manager) -token, _ := stputil.Login(userID) -``` - -### New Way (Single Import) ✨ +package main -```go import ( + "log" + sagin "github.com/click33/sa-token-go/integrations/gin" + "github.com/gin-gonic/gin" ) -config := sagin.DefaultConfig() -manager := sagin.NewManager(storage, config) -sagin.SetManager(manager) -token, _ := sagin.Login(userID) +func main() { + // Create Builder and build Manager + mgr := sagin.NewDefaultBuild(). + TokenName("token"). + Timeout(7200). + IsPrintBanner(true). + Build() + + // Set global manager + sagin.SetManager(mgr) + + // Create router + r := gin.Default() + + // Login endpoint + r.POST("/login", func(c *gin.Context) { + userID := c.PostForm("user_id") + ctx := c.Request.Context() + + token, err := sagin.Login(ctx, userID) + if err != nil { + c.JSON(500, gin.H{"error": err.Error()}) + return + } + + c.JSON(200, gin.H{ + "message": "登录成功", + "token": token, + }) + }) + + // Protected routes + protected := r.Group("/api") + protected.Use(sagin.CheckLoginMiddleware()) + { + protected.GET("/user", func(c *gin.Context) { + loginID, _ := sagin.GetLoginIDFromRequest(c) + c.JSON(200, gin.H{"user_id": loginID}) + }) + } + + r.Run(":8080") +} ``` ## Available Functions -All functions from `core` and `stputil` are re-exported in `sagin`: +All functions from `stputil` are re-exported in `sagin` with `context.Context` support: ### Authentication -- `sagin.Login(loginID, device...)` -- `sagin.Logout(loginID, device...)` -- `sagin.IsLogin(token)` -- `sagin.CheckLogin(token)` -- `sagin.GetLoginID(token)` + +```go +sagin.Login(ctx, loginID, device...) // Login +sagin.Logout(ctx, loginID, device...) // Logout +sagin.LogoutByToken(ctx, token) // Logout by token +sagin.IsLogin(ctx, token) // Check login status +sagin.CheckLogin(ctx, token) // Check login (throws error) +sagin.GetLoginID(ctx, token) // Get login ID from token +sagin.GetLoginIDFromRequest(c) // Get login ID from Gin context +``` ### Kickout & Disable -- `sagin.Kickout(loginID, device...)` -- `sagin.Disable(loginID, duration)` -- `sagin.IsDisable(loginID)` -- `sagin.Untie(loginID)` + +```go +sagin.Kickout(ctx, loginID, device...) // Kickout user +sagin.KickoutByToken(ctx, token) // Kickout by token +sagin.Disable(ctx, loginID, duration) // Disable account +sagin.IsDisable(ctx, loginID) // Check if disabled +sagin.Untie(ctx, loginID) // Re-enable account +``` ### Permission & Role -- `sagin.CheckPermission(loginID, permission)` -- `sagin.CheckRole(loginID, role)` -- `sagin.HasPermission(loginID, permission)` -- `sagin.HasRole(loginID, role)` + +```go +sagin.SetPermissions(ctx, loginID, perms) // Set permissions +sagin.SetRoles(ctx, loginID, roles) // Set roles +sagin.HasPermission(ctx, loginID, perm) // Check permission +sagin.HasRole(ctx, loginID, role) // Check role +sagin.GetPermissions(ctx, loginID) // Get permissions +sagin.GetRoles(ctx, loginID) // Get roles +``` ### Session -- `sagin.GetSession(loginID)` -- `sagin.GetSessionByToken(token)` + +```go +sagin.GetSession(ctx, loginID) // Get session +sagin.GetSessionByToken(ctx, token) // Get session by token +sagin.HasSession(ctx, loginID) // Check session exists +``` ### Security Features -- `sagin.GenerateNonce()` -- `sagin.VerifyNonce(nonce)` -- `sagin.LoginWithRefreshToken(loginID, device...)` -- `sagin.RefreshAccessToken(refreshToken)` -- `sagin.GetOAuth2Server()` + +```go +sagin.Generate(ctx) // Generate nonce +sagin.Verify(ctx, nonce) // Verify nonce +sagin.GenerateTokenPair(ctx, loginID) // Generate access + refresh token +sagin.RefreshAccessToken(ctx, refreshToken) // Refresh access token +``` ### Builder & Config -- `sagin.DefaultConfig()` -- `sagin.NewManager(storage, config)` -- `sagin.NewBuilder()` -- `sagin.SetManager(manager)` + +```go +sagin.NewDefaultBuild() // Create default builder +sagin.NewDefaultConfig() // Create default config +sagin.SetManager(mgr) // Set global manager +sagin.GetManager() // Get global manager +``` + +## Middleware Functions + +| Middleware | Description | +|------------|-------------| +| `CheckLoginMiddleware()` | Check if user is logged in | +| `CheckRoleMiddleware(roles...)` | Check if user has specified roles | +| `CheckPermissionMiddleware(perms...)` | Check if user has specified permissions | +| `CheckDisableMiddleware()` | Check if account is disabled | +| `IgnoreMiddleware()` | Skip authentication check | ## Benefits 1. **Simpler Dependencies** - Only one import needed 2. **Cleaner Code** - Less import statements 3. **Framework-Specific** - Optimized for Gin -4. **Backward Compatible** - Old way still works +4. **Context Support** - All functions support `context.Context` ## Learn More - [Main Documentation](../../../README.md) -- [Other Examples](../../) -- [API Reference](../../../docs/api/api.md) - +- [Gin Example](../gin-example) - More complete example +- [Annotation Example](../../annotation/annotation-example) - Middleware usage diff --git a/examples/gin/gin-simple/README_zh.md b/examples/gin/gin-simple/README_zh.md index fe6c30f..ed5237b 100644 --- a/examples/gin/gin-simple/README_zh.md +++ b/examples/gin/gin-simple/README_zh.md @@ -4,17 +4,17 @@ ## 特性 -✅ **单一导入** - 只需要 `github.com/click33/sa-token-go/integrations/gin` -✅ **完整功能** - 访问所有 core 和 stputil 的功能 -✅ **简洁 API** - 干净易用 +- **单一导入** - 只需要 `github.com/click33/sa-token-go/integrations/gin` +- **完整功能** - 访问所有 core 和 stputil 的功能 +- **简洁 API** - 干净易用 +- **Context 支持** - 所有函数都支持 `context.Context` ## 快速开始 ### 1. 安装依赖 ```bash -go get github.com/click33/sa-token-go/integrations/gin@v0.1.0 -go get github.com/click33/sa-token-go/storage/memory@v0.1.0 +go get github.com/click33/sa-token-go/integrations/gin go get github.com/gin-gonic/gin ``` @@ -57,86 +57,145 @@ curl -X POST -H "token: YOUR_TOKEN" http://localhost:8080/api/kickout/1000 # 响应: {"message":"踢人成功"} ``` -## 代码亮点 - -### 旧方式(多个导入) +## 代码示例 ```go -import ( - "github.com/click33/sa-token-go/core" - "github.com/click33/sa-token-go/stputil" - "github.com/click33/sa-token-go/integrations/gin" -) - -config := core.DefaultConfig() -manager := core.NewManager(storage, config) -stputil.SetManager(manager) -token, _ := stputil.Login(userID) -``` - -### 新方式(单一导入)✨ +package main -```go import ( + "log" + sagin "github.com/click33/sa-token-go/integrations/gin" + "github.com/gin-gonic/gin" ) -config := sagin.DefaultConfig() -manager := sagin.NewManager(storage, config) -sagin.SetManager(manager) -token, _ := sagin.Login(userID) +func main() { + // 创建 Builder 并构建 Manager + mgr := sagin.NewDefaultBuild(). + TokenName("token"). + Timeout(7200). + IsPrintBanner(true). + Build() + + // 设置全局管理器 + sagin.SetManager(mgr) + + // 创建路由 + r := gin.Default() + + // 登录接口 + r.POST("/login", func(c *gin.Context) { + userID := c.PostForm("user_id") + ctx := c.Request.Context() + + token, err := sagin.Login(ctx, userID) + if err != nil { + c.JSON(500, gin.H{"error": err.Error()}) + return + } + + c.JSON(200, gin.H{ + "message": "登录成功", + "token": token, + }) + }) + + // 受保护的路由 + protected := r.Group("/api") + protected.Use(sagin.CheckLoginMiddleware()) + { + protected.GET("/user", func(c *gin.Context) { + loginID, _ := sagin.GetLoginIDFromRequest(c) + c.JSON(200, gin.H{"user_id": loginID}) + }) + } + + r.Run(":8080") +} ``` ## 可用函数 -所有 `core` 和 `stputil` 的函数都在 `sagin` 中重新导出: +所有 `stputil` 的函数都在 `sagin` 中重新导出,并支持 `context.Context`: ### 认证相关 -- `sagin.Login(loginID, device...)` -- `sagin.Logout(loginID, device...)` -- `sagin.IsLogin(token)` -- `sagin.CheckLogin(token)` -- `sagin.GetLoginID(token)` + +```go +sagin.Login(ctx, loginID, device...) // 登录 +sagin.Logout(ctx, loginID, device...) // 登出 +sagin.LogoutByToken(ctx, token) // 根据Token登出 +sagin.IsLogin(ctx, token) // 检查登录状态 +sagin.CheckLogin(ctx, token) // 检查登录(抛出错误) +sagin.GetLoginID(ctx, token) // 从Token获取登录ID +sagin.GetLoginIDFromRequest(c) // 从Gin上下文获取登录ID +``` ### 踢人下线 & 封禁 -- `sagin.Kickout(loginID, device...)` -- `sagin.Disable(loginID, duration)` -- `sagin.IsDisable(loginID)` -- `sagin.Untie(loginID)` + +```go +sagin.Kickout(ctx, loginID, device...) // 踢人下线 +sagin.KickoutByToken(ctx, token) // 根据Token踢人下线 +sagin.Disable(ctx, loginID, duration) // 封禁账号 +sagin.IsDisable(ctx, loginID) // 检查是否被封禁 +sagin.Untie(ctx, loginID) // 解封账号 +``` ### 权限 & 角色 -- `sagin.CheckPermission(loginID, permission)` -- `sagin.CheckRole(loginID, role)` -- `sagin.HasPermission(loginID, permission)` -- `sagin.HasRole(loginID, role)` + +```go +sagin.SetPermissions(ctx, loginID, perms) // 设置权限 +sagin.SetRoles(ctx, loginID, roles) // 设置角色 +sagin.HasPermission(ctx, loginID, perm) // 检查权限 +sagin.HasRole(ctx, loginID, role) // 检查角色 +sagin.GetPermissions(ctx, loginID) // 获取权限列表 +sagin.GetRoles(ctx, loginID) // 获取角色列表 +``` ### Session 管理 -- `sagin.GetSession(loginID)` -- `sagin.GetSessionByToken(token)` + +```go +sagin.GetSession(ctx, loginID) // 获取Session +sagin.GetSessionByToken(ctx, token) // 根据Token获取Session +sagin.HasSession(ctx, loginID) // 检查Session是否存在 +``` ### 安全特性 -- `sagin.GenerateNonce()` -- `sagin.VerifyNonce(nonce)` -- `sagin.LoginWithRefreshToken(loginID, device...)` -- `sagin.RefreshAccessToken(refreshToken)` -- `sagin.GetOAuth2Server()` + +```go +sagin.Generate(ctx) // 生成随机数 +sagin.Verify(ctx, nonce) // 验证随机数 +sagin.GenerateTokenPair(ctx, loginID) // 生成访问令牌和刷新令牌 +sagin.RefreshAccessToken(ctx, refreshToken) // 刷新访问令牌 +``` ### Builder & Config -- `sagin.DefaultConfig()` -- `sagin.NewManager(storage, config)` -- `sagin.NewBuilder()` -- `sagin.SetManager(manager)` + +```go +sagin.NewDefaultBuild() // 创建默认构建器 +sagin.NewDefaultConfig() // 创建默认配置 +sagin.SetManager(mgr) // 设置全局管理器 +sagin.GetManager() // 获取全局管理器 +``` + +## 中间件函数 + +| 中间件 | 说明 | +|--------|------| +| `CheckLoginMiddleware()` | 检查是否已登录 | +| `CheckRoleMiddleware(roles...)` | 检查是否拥有指定角色 | +| `CheckPermissionMiddleware(perms...)` | 检查是否拥有指定权限 | +| `CheckDisableMiddleware()` | 检查账号是否被封禁 | +| `IgnoreMiddleware()` | 忽略认证检查 | ## 优势 1. **更简单的依赖** - 只需要一个导入 2. **更清晰的代码** - 更少的导入语句 3. **框架专用** - 为 Gin 优化 -4. **向后兼容** - 旧方式仍然有效 +4. **Context 支持** - 所有函数都支持 `context.Context` ## 了解更多 - [主文档](../../../README_zh.md) -- [其他示例](../../) -- [API 参考](../../../docs/api/api_zh.md) - +- [Gin 示例](../gin-example) - 更完整的示例 +- [注解示例](../../annotation/annotation-example) - 中间件用法 diff --git a/examples/gin/gin-simple/gin-simple b/examples/gin/gin-simple/gin-simple deleted file mode 100755 index 75f3ea8..0000000 Binary files a/examples/gin/gin-simple/gin-simple and /dev/null differ diff --git a/examples/gin/gin-simple/go.mod b/examples/gin/gin-simple/go.mod index e9dd2a9..6d3043f 100644 --- a/examples/gin/gin-simple/go.mod +++ b/examples/gin/gin-simple/go.mod @@ -1,19 +1,17 @@ module github.com/click33/sa-token-go/examples/gin/gin-simple -go 1.23.0 - -toolchain go1.24.1 +go 1.25.0 require ( - github.com/click33/sa-token-go/integrations/gin v0.1.3 - github.com/click33/sa-token-go/storage/memory v0.1.3 + github.com/click33/sa-token-go/integrations/gin v0.1.7 github.com/gin-gonic/gin v1.10.0 ) require ( github.com/bytedance/sonic v1.11.6 // indirect - github.com/click33/sa-token-go/core v0.1.3 // indirect - github.com/click33/sa-token-go/stputil v0.1.3 // indirect + github.com/bytedance/sonic/loader v0.1.1 // indirect + github.com/click33/sa-token-go/core v0.1.7 // indirect + github.com/click33/sa-token-go/stputil v0.1.7 // indirect github.com/cloudwego/base64x v0.1.4 // indirect github.com/cloudwego/iasm v0.2.0 // indirect github.com/gabriel-vasile/mimetype v1.4.3 // indirect @@ -22,7 +20,7 @@ require ( github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.20.0 // indirect github.com/goccy/go-json v0.10.2 // indirect - github.com/golang-jwt/jwt/v5 v5.2.1 // indirect + github.com/golang-jwt/jwt/v5 v5.3.0 // indirect github.com/google/uuid v1.6.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.2.7 // indirect @@ -31,21 +29,16 @@ require ( github.com/mattn/go-isatty v0.0.20 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/panjf2000/ants/v2 v2.11.3 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.12 // indirect golang.org/x/arch v0.8.0 // indirect golang.org/x/crypto v0.41.0 // indirect golang.org/x/net v0.43.0 // indirect + golang.org/x/sync v0.19.0 // indirect golang.org/x/sys v0.35.0 // indirect golang.org/x/text v0.28.0 // indirect - google.golang.org/protobuf v1.34.1 // indirect + google.golang.org/protobuf v1.36.10 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) - -replace ( - github.com/click33/sa-token-go/core => ../../../core - github.com/click33/sa-token-go/integrations/gin => ../../../integrations/gin - github.com/click33/sa-token-go/storage/memory => ../../../storage/memory - github.com/click33/sa-token-go/stputil => ../../../stputil -) diff --git a/examples/gin/gin-simple/go.sum b/examples/gin/gin-simple/go.sum new file mode 100644 index 0000000..4b99fae --- /dev/null +++ b/examples/gin/gin-simple/go.sum @@ -0,0 +1,44 @@ +github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0= +github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM= +github.com/click33/sa-token-go/core v0.1.6 h1:ELOe0qSH1b3LRsQD3DIBg0e1VgYANKFg5H7z57Lkt/8= +github.com/click33/sa-token-go/integrations/gin v0.1.6 h1:OOGM7ozUFqiIbWBopWmB4Q4DVn6GZF/ho2gPhwJyZSI= +github.com/click33/sa-token-go/storage/memory v0.1.6 h1:iGFEy+HtTJLOpKnbIMbgpXyKotsKpPQu6wWTZVOXQis= +github.com/click33/sa-token-go/stputil v0.1.6 h1:S+V64jQzppE9c1wXcmHppCRlrSsU2iTfvdPGlMbs2WI= +github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y= +github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= +github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= +github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU= +github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBExVwjEviJTixqxL8= +github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= +github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= +github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= +github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= +github.com/panjf2000/ants/v2 v2.11.3 h1:AfI0ngBoXJmYOpDh9m516vjqoUu2sLrIVgppI9TZVpg= +github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= +golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc= +golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4= +golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= +golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= +google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/examples/gin/gin-simple/main.go b/examples/gin/gin-simple/main.go index 18d7280..6fa2f48 100644 --- a/examples/gin/gin-simple/main.go +++ b/examples/gin/gin-simple/main.go @@ -4,28 +4,19 @@ import ( "log" sagin "github.com/click33/sa-token-go/integrations/gin" - "github.com/click33/sa-token-go/storage/memory" "github.com/gin-gonic/gin" ) func main() { - // 初始化存储 | Initialize storage - storage := memory.NewStorage() - - // 创建配置 (只需要 sagin 包!) | Create config (only need sagin package!) - config := sagin.DefaultConfig() - config.TokenName = "token" - config.Timeout = 7200 - config.IsPrintBanner = true - - // 创建管理器 | Create manager - manager := sagin.NewManager(storage, config) + // 创建 Builder 并构建 Manager | Create Builder and build Manager + mgr := sagin.NewDefaultBuild(). + TokenName("token"). + Timeout(7200). + IsPrintBanner(true). + Build() // 设置全局管理器 | Set global manager - sagin.SetManager(manager) - - // 创建 Gin 插件 | Create Gin plugin - plugin := sagin.NewPlugin(manager) + sagin.SetManager(mgr) // 创建路由 | Create router r := gin.Default() @@ -38,8 +29,10 @@ func main() { return } + ctx := c.Request.Context() + // 使用 sagin 包的全局函数登录 | Use sagin package global function to login - token, err := sagin.Login(userID) + token, err := sagin.Login(ctx, userID) if err != nil { c.JSON(500, gin.H{"error": err.Error()}) return @@ -59,8 +52,10 @@ func main() { return } + ctx := c.Request.Context() + // 使用 sagin 包的全局函数登出 | Use sagin package global function to logout - if err := sagin.LogoutByToken(token); err != nil { + if err := sagin.LogoutByToken(ctx, token); err != nil { c.JSON(500, gin.H{"error": err.Error()}) return } @@ -76,15 +71,17 @@ func main() { return } + ctx := c.Request.Context() + // 使用 sagin 包的全局函数检查登录 | Use sagin package global function to check login - isLogin := sagin.IsLogin(token) + isLogin := sagin.IsLogin(ctx, token) if !isLogin { c.JSON(401, gin.H{"error": "未登录"}) return } // 获取登录ID | Get login ID - loginID, _ := sagin.GetLoginID(token) + loginID, _ := sagin.GetLoginID(ctx, token) c.JSON(200, gin.H{ "message": "已登录", @@ -94,12 +91,11 @@ func main() { // 受保护的路由组 | Protected route group protected := r.Group("/api") - protected.Use(plugin.AuthMiddleware()) + protected.Use(sagin.CheckLoginMiddleware()) { // 用户信息 | User info protected.GET("/user", func(c *gin.Context) { - token := c.GetHeader("token") - loginID, _ := sagin.GetLoginID(token) + loginID, _ := sagin.GetLoginIDFromRequest(c) c.JSON(200, gin.H{ "user_id": loginID, @@ -110,9 +106,10 @@ func main() { // 踢人下线 | Kickout user protected.POST("/kickout/:user_id", func(c *gin.Context) { userID := c.Param("user_id") + ctx := c.Request.Context() // 使用 sagin 包的全局函数踢人 | Use sagin package global function to kickout - if err := sagin.Kickout(userID); err != nil { + if err := sagin.Kickout(ctx, userID); err != nil { c.JSON(500, gin.H{"error": err.Error()}) return } diff --git a/examples/jwt-example/README.md b/examples/jwt-example/README.md deleted file mode 100644 index 3352624..0000000 --- a/examples/jwt-example/README.md +++ /dev/null @@ -1,210 +0,0 @@ -# JWT Token 示例 - -本示例演示如何在 Sa-Token-Go 中使用 JWT(JSON Web Token)。 - -## JWT 简介 - -JWT 是一种无状态的 Token 方案,Token 本身包含了用户信息和过期时间,适合分布式系统。 - -### JWT 优势 - -- ✅ **无状态**:不需要服务端存储 Session -- ✅ **分布式友好**:多个服务可以独立验证 -- ✅ **信息自包含**:Token 包含用户信息 -- ✅ **跨域支持**:可以跨不同域使用 - -### JWT 结构 - -JWT 由三部分组成,用 `.` 分隔: - -``` -Header.Payload.Signature -``` - -- **Header**:Token 类型和加密算法 -- **Payload**:用户数据(loginId, device, exp等) -- **Signature**:签名(使用密钥加密) - -## 运行示例 - -```bash -go run main.go -``` - -## 基本使用 - -### 1. 配置 JWT - -```go -import ( - "github.com/click33/sa-token-go/core" - "github.com/click33/sa-token-go/stputil" - "github.com/click33/sa-token-go/storage/memory" -) - -func init() { - stputil.SetManager( - core.NewBuilder(). - Storage(memory.NewStorage()). - TokenStyle(core.TokenStyleJWT). // 使用 JWT - JwtSecretKey("your-256-bit-secret-key-here"). // 设置密钥 - Timeout(3600). // 过期时间 - Build(), - ) -} -``` - -### 2. 登录获取 JWT Token - -```go -token, _ := stputil.Login(1000) -// 返回类似:eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9... -``` - -### 3. 验证 JWT Token - -```go -// 验证 Token 是否有效 -if stputil.IsLogin(token) { - fmt.Println("Token 有效") -} - -// 获取登录 ID -loginID, _ := stputil.GetLoginID(token) -``` - -### 4. 解析 JWT - -你可以使用 [jwt.io](https://jwt.io) 在线解析 JWT Token 查看内容。 - -**Payload 示例:** - -```json -{ - "loginId": "1000", - "device": "", - "iat": 1697234567, - "exp": 1697238167 -} -``` - -## JWT 配置选项 - -```go -core.NewBuilder(). - TokenStyle(core.TokenStyleJWT). // 启用 JWT - JwtSecretKey("your-secret-key"). // 密钥(必需) - Timeout(3600). // Token 过期时间(秒) - IsPrintBanner(true). // 显示启动 Banner - Build() -``` - -## 安全建议 - -### 1. 使用强密钥 - -```go -// ❌ 弱密钥 -JwtSecretKey("secret") - -// ✅ 强密钥(建议至少 32 字节) -JwtSecretKey("a-very-long-and-random-secret-key-at-least-256-bits") -``` - -### 2. 设置合理的过期时间 - -```go -// 短期 Token(推荐) -Timeout(3600) // 1小时 - -// 长期 Token(需要配合刷新机制) -Timeout(86400) // 24小时 -``` - -### 3. 在生产环境中保护密钥 - -```go -// ✅ 从环境变量读取 -import "os" - -JwtSecretKey(os.Getenv("JWT_SECRET_KEY")) -``` - -## JWT vs 普通 Token - -| 特性 | JWT | UUID/Random | -|------|-----|-------------| -| 状态 | 无状态 | 有状态 | -| 服务端存储 | 不需要 | 需要 | -| Token 大小 | 较大 | 较小 | -| 可撤销性 | 困难 | 容易 | -| 分布式 | 优秀 | 需要共享存储 | -| 性能 | 高(不查数据库) | 中等(需查数据库) | - -## 使用场景 - -### 适合 JWT 的场景 - -- ✅ 微服务架构 -- ✅ 无状态 API -- ✅ 跨域认证 -- ✅ 短期访问令牌 - -### 不适合 JWT 的场景 - -- ❌ 需要立即撤销 Token -- ❌ Token 包含敏感信息 -- ❌ 需要频繁更新权限 - -## 完整示例 - -```go -package main - -import ( - "fmt" - "github.com/click33/sa-token-go/core" - "github.com/click33/sa-token-go/stputil" - "github.com/click33/sa-token-go/storage/memory" -) - -func main() { - // 初始化 JWT - stputil.SetManager( - core.NewBuilder(). - Storage(memory.NewStorage()). - TokenStyle(core.TokenStyleJWT). - JwtSecretKey("your-256-bit-secret"). - Timeout(3600). - Build(), - ) - - // 登录 - token, _ := stputil.Login(1000) - fmt.Println("Token:", token) - - // 验证 - if stputil.IsLogin(token) { - loginID, _ := stputil.GetLoginID(token) - fmt.Println("登录ID:", loginID) - } - - // 权限管理 - stputil.SetPermissions(1000, []string{"admin:*"}) - if stputil.HasPermission(1000, "admin:read") { - fmt.Println("有权限") - } -} -``` - -## 相关文档 - -- [Authentication Guide](../../docs/guide/authentication.md) -- [Token Configuration](../../docs/guide/configuration.md) -- [Quick Start](../../docs/tutorial/quick-start.md) - -## 在线工具 - -- [JWT.io](https://jwt.io) - JWT 调试工具 -- [JWT Inspector](https://jwt-inspector.netlify.app/) - JWT 检查器 - diff --git a/examples/jwt-example/go.mod b/examples/jwt-example/go.mod deleted file mode 100644 index 1daf0cf..0000000 --- a/examples/jwt-example/go.mod +++ /dev/null @@ -1,20 +0,0 @@ -module github.com/click33/sa-token-go/examples/jwt-example - -go 1.21 - -require ( - github.com/click33/sa-token-go/core v0.1.3 - github.com/click33/sa-token-go/storage/memory v0.1.3 - github.com/click33/sa-token-go/stputil v0.1.3 -) - -require ( - github.com/golang-jwt/jwt/v5 v5.2.1 // indirect - github.com/google/uuid v1.6.0 // indirect -) - -replace ( - github.com/click33/sa-token-go/core => ../../core - github.com/click33/sa-token-go/storage/memory => ../../storage/memory - github.com/click33/sa-token-go/stputil => ../../stputil -) diff --git a/examples/kratos/kratos-example/go.mod b/examples/kratos/kratos-example/go.mod index 2d82d6a..5365bcd 100644 --- a/examples/kratos/kratos-example/go.mod +++ b/examples/kratos/kratos-example/go.mod @@ -1,6 +1,6 @@ module github.com/click33/sa-token-go/examples/kratos/kratos-example -go 1.25.3 +go 1.25.0 require ( github.com/go-kratos/kratos/v2 v2.9.1 @@ -22,7 +22,7 @@ require ( go.opentelemetry.io/otel v1.38.0 // indirect go.opentelemetry.io/otel/sdk/metric v1.38.0 // indirect golang.org/x/net v0.43.0 // indirect - golang.org/x/sync v0.16.0 // indirect + golang.org/x/sync v0.19.0 // indirect golang.org/x/sys v0.35.0 // indirect golang.org/x/text v0.28.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20251103181224-f26f9409b101 // indirect diff --git a/examples/kratos/kratos-example/go.sum b/examples/kratos/kratos-example/go.sum index e77d9e4..8ed5692 100644 --- a/examples/kratos/kratos-example/go.sum +++ b/examples/kratos/kratos-example/go.sum @@ -21,7 +21,7 @@ go.opentelemetry.io/otel/sdk v1.38.0 h1:l48sr5YbNf2hpCUj/FoGhW9yDkl+Ma+LrVl8qaM5 go.opentelemetry.io/otel/sdk/metric v1.38.0 h1:aSH66iL0aZqo//xXzQLYozmWrXxyFkBJ6qT5wthqPoM= go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE= golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= -golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= google.golang.org/genproto/googleapis/api v0.0.0-20251111163417-95abcf5c77ba h1:B14OtaXuMaCQsl2deSvNkyPKIzq3BjfxQp8d00QyWx4= diff --git a/examples/listener-example/README.md b/examples/listener-example/README.md deleted file mode 100644 index aa14290..0000000 --- a/examples/listener-example/README.md +++ /dev/null @@ -1,163 +0,0 @@ -# Event Listener Example - -This example demonstrates the event listener system in Sa-Token-Go. - -## Features Demonstrated - -1. **Simple Function Listeners** - Quick way to register event handlers -2. **Wildcard Listeners** - Listen to all events -3. **Priority-Based Execution** - Control listener execution order -4. **Synchronous vs Asynchronous** - Choose blocking or non-blocking listeners -5. **Panic Recovery** - Handle listener errors gracefully -6. **Dynamic Registration** - Add/remove listeners at runtime -7. **Event Enable/Disable** - Control which events are active - -## Running the Example - -```bash -go run main.go -``` - -## Expected Output - -``` -=== Sa-Token-Go Event Listener Example === - ---- Triggering Events --- - -[AUDIT] Login audit - User: 1000, Time: 1697234567 -[ALL EVENTS] Event{type=login, loginID=1000, device=, timestamp=1697234567} -[LOGIN] User 1000 logged in with token abc123def456... - -[AUDIT] Login audit - User: 2000, Time: 1697234568 -[ALL EVENTS] Event{type=login, loginID=2000, device=, timestamp=1697234568} -[LOGIN] User 2000 logged in with token xyz789... - -[ALL EVENTS] Event{type=logout, loginID=1000, device=, timestamp=1697234569} -[LOGOUT] User 1000 logged out - -[ALL EVENTS] Event{type=kickout, loginID=2000, device=, timestamp=1697234570} -[KICKOUT] User 2000 was forcibly logged out - ---- Listener Statistics --- -Total listeners: 5 -Login listeners: 2 -Logout listeners: 1 - ---- Unregistering audit logger --- -Audit logger unregistered successfully -Remaining listeners: 4 - ---- Disabling kickout events --- - ---- Testing event disable (this should not trigger kickout listener) --- -[ALL EVENTS] Event{type=login, loginID=3000, device=, timestamp=1697234571} -[LOGIN] User 3000 logged in with token... - -=== Example Complete === -``` - -## Key Concepts - -In this example the authentication manager automatically owns an internal event manager: - -```go -manager := core.NewBuilder(). - Storage(memory.NewStorage()). - Build() - -eventMgr := manager.GetEventManager() // Advanced controls (stats, enable/disable, panic handler, ...) -``` - -### Function Listeners - -The simplest way to register an event handler: - -```go -manager.RegisterFunc(core.EventLogin, func(data *core.EventData) { - fmt.Printf("User %s logged in\n", data.LoginID) -}) -``` - -### Priority-Based Listeners - -Control execution order with priorities: - -```go -manager.RegisterWithConfig(core.EventLogin, - myListener, - core.ListenerConfig{ - Priority: 100, // Higher = executes first - Async: false, // Synchronous execution - }, -) -``` - -### Wildcard Listeners - -Listen to all events: - -```go -manager.RegisterFunc(core.EventAll, func(data *core.EventData) { - // This will be called for every event -}) -``` - -### Dynamic Listener Management - -Add and remove listeners at runtime: - -```go -// Register with custom ID -id := manager.RegisterWithConfig(event, listener, core.ListenerConfig{ - ID: "my-listener", -}) - -// Later, unregister -manager.Unregister(id) -``` - -## Use Cases - -### 1. Audit Logging - -```go -manager.RegisterFunc(core.EventAll, func(data *core.EventData) { - auditLog.Write(fmt.Sprintf("[%s] %s - %s", - data.Event, data.LoginID, time.Unix(data.Timestamp, 0))) -}) -``` - -### 2. Security Monitoring - -```go -manager.RegisterFunc(core.EventKickout, func(data *core.EventData) { - alertSystem.Send(fmt.Sprintf("User %s was kicked out", data.LoginID)) -}) -``` - -### 3. Analytics - -```go -manager.RegisterFunc(core.EventLogin, func(data *core.EventData) { - analytics.Track("user_login", map[string]interface{}{ - "user_id": data.LoginID, - "device": data.Device, - }) -}) -``` - -### 4. Cache Invalidation - -```go -manager.RegisterFunc(core.EventLogout, func(data *core.EventData) { - cache.Delete("user:" + data.LoginID) -}) -``` - -## Related Documentation - -- [Listener Guide](../../docs/guide/listener.md) - Complete listener documentation -- [Authentication Guide](../../docs/guide/authentication.md) - Authentication basics -- [API Reference](../../docs/api/) - API documentation - diff --git a/examples/listener-example/go.mod b/examples/listener-example/go.mod deleted file mode 100644 index 728dd7c..0000000 --- a/examples/listener-example/go.mod +++ /dev/null @@ -1,20 +0,0 @@ -module github.com/click33/sa-token-go/examples/listener-example - -go 1.21 - -require ( - github.com/click33/sa-token-go/core v0.1.3 - github.com/click33/sa-token-go/storage/memory v0.1.3 - github.com/click33/sa-token-go/stputil v0.1.3 -) - -require ( - github.com/golang-jwt/jwt/v5 v5.2.1 // indirect - github.com/google/uuid v1.6.0 // indirect -) - -replace ( - github.com/click33/sa-token-go/core => ../../core - github.com/click33/sa-token-go/storage/memory => ../../storage/memory - github.com/click33/sa-token-go/stputil => ../../stputil -) diff --git a/examples/manager-example/jwt-example/go.mod b/examples/manager-example/jwt-example/go.mod new file mode 100644 index 0000000..536c068 --- /dev/null +++ b/examples/manager-example/jwt-example/go.mod @@ -0,0 +1,16 @@ +module github.com/click33/sa-token-go/examples/jwt-example + +go 1.25.0 + +require ( + github.com/click33/sa-token-go/core v0.1.7 + github.com/click33/sa-token-go/storage/memory v0.1.7 + github.com/click33/sa-token-go/stputil v0.1.7 +) + +require ( + github.com/golang-jwt/jwt/v5 v5.3.0 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/panjf2000/ants/v2 v2.11.3 // indirect + golang.org/x/sync v0.19.0 // indirect +) diff --git a/examples/manager-example/jwt-example/go.sum b/examples/manager-example/jwt-example/go.sum new file mode 100644 index 0000000..16b1a51 --- /dev/null +++ b/examples/manager-example/jwt-example/go.sum @@ -0,0 +1,11 @@ +github.com/click33/sa-token-go/core v0.1.6 h1:ELOe0qSH1b3LRsQD3DIBg0e1VgYANKFg5H7z57Lkt/8= +github.com/click33/sa-token-go/storage/memory v0.1.6 h1:iGFEy+HtTJLOpKnbIMbgpXyKotsKpPQu6wWTZVOXQis= +github.com/click33/sa-token-go/stputil v0.1.6 h1:S+V64jQzppE9c1wXcmHppCRlrSsU2iTfvdPGlMbs2WI= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/panjf2000/ants/v2 v2.11.3 h1:AfI0ngBoXJmYOpDh9m516vjqoUu2sLrIVgppI9TZVpg= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/examples/jwt-example/main.go b/examples/manager-example/jwt-example/main.go similarity index 68% rename from examples/jwt-example/main.go rename to examples/manager-example/jwt-example/main.go index ad77750..44542c9 100644 --- a/examples/jwt-example/main.go +++ b/examples/manager-example/jwt-example/main.go @@ -1,9 +1,11 @@ package main import ( + "context" "fmt" - "github.com/click33/sa-token-go/core" + "github.com/click33/sa-token-go/core/adapter" + "github.com/click33/sa-token-go/core/builder" "github.com/click33/sa-token-go/storage/memory" "github.com/click33/sa-token-go/stputil" ) @@ -13,17 +15,20 @@ func main() { // 初始化使用 JWT Token 风格 stputil.SetManager( - core.NewBuilder(). - Storage(memory.NewStorage()). + builder.NewBuilder(). + SetStorage(memory.NewStorage()). TokenName("Authorization"). - TokenStyle(core.TokenStyleJWT). // 使用 JWT + TokenStyle(adapter.TokenStyleJWT). // 使用 JWT JwtSecretKey("your-256-bit-secret-key-here"). // JWT 密钥 Timeout(3600). // 1小时过期 + MaxRefresh(1800). // 自动续期触发阈值 Build(), ) + ctx := context.Background() + fmt.Println("1. 使用 JWT 登录") - token, err := stputil.Login(1000) + token, err := stputil.Login(ctx, 1000) if err != nil { fmt.Printf("登录失败: %v\n", err) return @@ -34,13 +39,13 @@ func main() { // 你可以在 https://jwt.io 解析这个 Token fmt.Println("2. 验证 JWT Token") - if stputil.IsLogin(token) { + if stputil.IsLogin(ctx, token) { fmt.Println("✓ Token 有效") } else { fmt.Println("✗ Token 无效") } - loginID, err := stputil.GetLoginID(token) + loginID, err := stputil.GetLoginID(ctx, token) if err != nil { fmt.Printf("获取登录ID失败: %v\n", err) return @@ -48,29 +53,29 @@ func main() { fmt.Printf("登录ID: %s\n\n", loginID) fmt.Println("3. 设置权限和角色") - stputil.SetPermissions(1000, []string{"user:read", "user:write", "admin:*"}) - stputil.SetRoles(1000, []string{"admin", "user"}) + _ = stputil.SetPermissions(ctx, 1000, []string{"user:read", "user:write", "admin:*"}) + _ = stputil.SetRoles(ctx, 1000, []string{"admin", "user"}) fmt.Println("已设置权限: user:read, user:write, admin:*") fmt.Println("已设置角色: admin, user\n") fmt.Println("4. 检查权限") - if stputil.HasPermission(1000, "user:read") { + if stputil.HasPermission(ctx, 1000, "user:read") { fmt.Println("✓ 拥有 user:read 权限") } - if stputil.HasPermission(1000, "admin:delete") { + if stputil.HasPermission(ctx, 1000, "admin:delete") { fmt.Println("✓ 拥有 admin:delete 权限(通配符匹配)") } fmt.Println("\n5. 检查角色") - if stputil.HasRole(1000, "admin") { + if stputil.HasRole(ctx, 1000, "admin") { fmt.Println("✓ 拥有 admin 角色") } fmt.Println("\n6. 登出") - stputil.Logout(1000) + _ = stputil.Logout(ctx, 1000) fmt.Println("已登出") - if !stputil.IsLogin(token) { + if !stputil.IsLogin(ctx, token) { fmt.Println("✓ Token 已失效") } diff --git a/examples/manager-example/listener-example/go.mod b/examples/manager-example/listener-example/go.mod new file mode 100644 index 0000000..eeb053c --- /dev/null +++ b/examples/manager-example/listener-example/go.mod @@ -0,0 +1,16 @@ +module github.com/click33/sa-token-go/examples/listener-example + +go 1.25.0 + +require ( + github.com/click33/sa-token-go/core v0.1.7 + github.com/click33/sa-token-go/storage/memory v0.1.7 + github.com/click33/sa-token-go/stputil v0.1.7 +) + +require ( + github.com/golang-jwt/jwt/v5 v5.3.0 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/panjf2000/ants/v2 v2.11.3 // indirect + golang.org/x/sync v0.19.0 // indirect +) diff --git a/examples/manager-example/listener-example/go.sum b/examples/manager-example/listener-example/go.sum new file mode 100644 index 0000000..16b1a51 --- /dev/null +++ b/examples/manager-example/listener-example/go.sum @@ -0,0 +1,11 @@ +github.com/click33/sa-token-go/core v0.1.6 h1:ELOe0qSH1b3LRsQD3DIBg0e1VgYANKFg5H7z57Lkt/8= +github.com/click33/sa-token-go/storage/memory v0.1.6 h1:iGFEy+HtTJLOpKnbIMbgpXyKotsKpPQu6wWTZVOXQis= +github.com/click33/sa-token-go/stputil v0.1.6 h1:S+V64jQzppE9c1wXcmHppCRlrSsU2iTfvdPGlMbs2WI= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/panjf2000/ants/v2 v2.11.3 h1:AfI0ngBoXJmYOpDh9m516vjqoUu2sLrIVgppI9TZVpg= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/examples/listener-example/main.go b/examples/manager-example/listener-example/main.go similarity index 62% rename from examples/listener-example/main.go rename to examples/manager-example/listener-example/main.go index 85c052a..e7004f8 100644 --- a/examples/listener-example/main.go +++ b/examples/manager-example/listener-example/main.go @@ -1,10 +1,12 @@ package main import ( + "context" "fmt" "time" - "github.com/click33/sa-token-go/core" + "github.com/click33/sa-token-go/core/builder" + "github.com/click33/sa-token-go/core/listener" "github.com/click33/sa-token-go/storage/memory" "github.com/click33/sa-token-go/stputil" ) @@ -12,34 +14,37 @@ import ( func main() { fmt.Println("=== Sa-Token-Go Event Listener Example ===\n") + ctx := context.Background() + // 1. Simple function listener - manager := core.NewBuilder(). - Storage(memory.NewStorage()). + mgr := builder.NewBuilder(). + SetStorage(memory.NewStorage()). TokenName("Authorization"). - Timeout(7200). + Timeout(300). + MaxRefresh(150). Build() - manager.RegisterFunc(core.EventLogin, func(data *core.EventData) { + mgr.RegisterFunc(listener.EventLogin, func(data *listener.EventData) { fmt.Printf("[LOGIN] User %s logged in with token %s\n", data.LoginID, data.Token[:20]+"...") }) // 2. Logout listener - manager.RegisterFunc(core.EventLogout, func(data *core.EventData) { + mgr.RegisterFunc(listener.EventLogout, func(data *listener.EventData) { fmt.Printf("[LOGOUT] User %s logged out\n", data.LoginID) }) // 3. Kickout listener - manager.RegisterFunc(core.EventKickout, func(data *core.EventData) { + mgr.RegisterFunc(listener.EventKickout, func(data *listener.EventData) { fmt.Printf("[KICKOUT] User %s was forcibly logged out\n", data.LoginID) }) // 4. High-priority synchronous listener - auditListenerID := manager.RegisterWithConfig(core.EventLogin, - core.ListenerFunc(func(data *core.EventData) { + auditListenerID := mgr.RegisterWithConfig(listener.EventLogin, + listener.ListenerFunc(func(data *listener.EventData) { fmt.Printf("[AUDIT] Login audit - User: %s, Time: %d\n", data.LoginID, data.Timestamp) }), - core.ListenerConfig{ + listener.ListenerConfig{ Async: false, // Synchronous Priority: 100, // High priority ID: "audit-logger", @@ -47,48 +52,48 @@ func main() { ) // 5. Wildcard listener (all events) - manager.RegisterFunc(core.EventAll, func(data *core.EventData) { + mgr.RegisterFunc(listener.EventAll, func(data *listener.EventData) { fmt.Printf("[ALL EVENTS] %s\n", data.String()) }) - eventMgr := manager.GetEventManager() + eventMgr := mgr.GetEventManager() // 6. Custom panic handler - eventMgr.SetPanicHandler(func(event core.Event, data *core.EventData, recovered interface{}) { + eventMgr.SetPanicHandler(func(event listener.Event, data *listener.EventData, recovered interface{}) { fmt.Printf("[PANIC RECOVERED] Event: %s, Error: %v\n", event, recovered) }) // Initialize Sa-Token - stputil.SetManager(manager) + stputil.SetManager(mgr) fmt.Println("\n--- Triggering Events ---\n") // Trigger login event - token1, _ := stputil.Login(1000) + token1, _ := stputil.Login(ctx, 1000) time.Sleep(100 * time.Millisecond) // Wait for async listeners - token2, _ := stputil.Login(2000) + token2, _ := stputil.Login(ctx, 2000) time.Sleep(100 * time.Millisecond) // Trigger logout event - stputil.Logout(1000) + stputil.Logout(ctx, 1000) time.Sleep(100 * time.Millisecond) // Trigger kickout event - stputil.Kickout(2000) + stputil.Kickout(ctx, 2000) time.Sleep(100 * time.Millisecond) // Wait for all async listeners to complete - manager.WaitEvents() + mgr.WaitEvents() fmt.Println("\n--- Listener Statistics ---") fmt.Printf("Total listeners: %d\n", eventMgr.Count()) - fmt.Printf("Login listeners: %d\n", eventMgr.CountForEvent(core.EventLogin)) - fmt.Printf("Logout listeners: %d\n", eventMgr.CountForEvent(core.EventLogout)) + fmt.Printf("Login listeners: %d\n", eventMgr.CountForEvent(listener.EventLogin)) + fmt.Printf("Logout listeners: %d\n", eventMgr.CountForEvent(listener.EventLogout)) // Unregister a listener fmt.Println("\n--- Unregistering audit logger ---") - if manager.Unregister(auditListenerID) { + if mgr.Unregister(auditListenerID) { fmt.Println("Audit logger unregistered successfully") } @@ -96,11 +101,11 @@ func main() { // Disable certain events fmt.Println("\n--- Disabling kickout events ---") - eventMgr.DisableEvent(core.EventKickout) + eventMgr.DisableEvent(listener.EventKickout) fmt.Println("\n--- Testing event disable (this should not trigger kickout listener) ---") - stputil.Login(3000) - stputil.Kickout(3000) + stputil.Login(ctx, 3000) + stputil.Kickout(ctx, 3000) time.Sleep(100 * time.Millisecond) // Re-enable all events diff --git a/examples/oauth2-example/go.mod b/examples/manager-example/oauth2-example/go.mod similarity index 62% rename from examples/oauth2-example/go.mod rename to examples/manager-example/oauth2-example/go.mod index 905f73e..a88de38 100644 --- a/examples/oauth2-example/go.mod +++ b/examples/manager-example/oauth2-example/go.mod @@ -1,47 +1,44 @@ module github.com/click33/sa-token-go/examples/oauth2-example -go 1.21 +go 1.25.0 require ( - github.com/click33/sa-token-go/core v0.1.3 - github.com/click33/sa-token-go/storage/memory v0.1.3 + github.com/click33/sa-token-go/core v0.1.7 github.com/gin-gonic/gin v1.10.0 ) require ( github.com/bytedance/sonic v1.11.6 // indirect + github.com/bytedance/sonic/loader v0.1.1 // indirect github.com/cloudwego/base64x v0.1.4 // indirect github.com/cloudwego/iasm v0.2.0 // indirect - github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/gabriel-vasile/mimetype v1.4.3 // indirect github.com/gin-contrib/sse v0.1.0 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.20.0 // indirect github.com/goccy/go-json v0.10.2 // indirect - github.com/golang-jwt/jwt/v5 v5.2.1 // indirect + github.com/golang-jwt/jwt/v5 v5.2.2 // indirect github.com/google/uuid v1.6.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.2.7 // indirect + github.com/kr/pretty v0.3.1 // indirect github.com/leodido/go-urn v1.4.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/panjf2000/ants/v2 v2.11.3 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect - github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + github.com/rogpeppe/go-internal v1.13.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.12 // indirect golang.org/x/arch v0.8.0 // indirect - golang.org/x/crypto v0.23.0 // indirect - golang.org/x/net v0.25.0 // indirect - golang.org/x/sys v0.20.0 // indirect - golang.org/x/text v0.15.0 // indirect - golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect - google.golang.org/protobuf v1.34.1 // indirect + golang.org/x/crypto v0.41.0 // indirect + golang.org/x/net v0.43.0 // indirect + golang.org/x/sync v0.16.0 // indirect + golang.org/x/sys v0.35.0 // indirect + golang.org/x/text v0.28.0 // indirect + google.golang.org/protobuf v1.36.10 // indirect + gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) - -replace ( - github.com/click33/sa-token-go/core => ../../core - github.com/click33/sa-token-go/storage/memory => ../../storage/memory -) diff --git a/examples/manager-example/oauth2-example/go.sum b/examples/manager-example/oauth2-example/go.sum new file mode 100644 index 0000000..9264424 --- /dev/null +++ b/examples/manager-example/oauth2-example/go.sum @@ -0,0 +1,42 @@ +github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0= +github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM= +github.com/click33/sa-token-go/core v0.1.6 h1:ELOe0qSH1b3LRsQD3DIBg0e1VgYANKFg5H7z57Lkt/8= +github.com/click33/sa-token-go/storage/memory v0.1.6 h1:iGFEy+HtTJLOpKnbIMbgpXyKotsKpPQu6wWTZVOXQis= +github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y= +github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= +github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= +github.com/gin-gonic/gin v1.10.0 h1:nTuyha1TYqgedzytsKYqna+DfLos46nTv2ygFy86HFU= +github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBExVwjEviJTixqxL8= +github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= +github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= +github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= +github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= +github.com/panjf2000/ants/v2 v2.11.3 h1:AfI0ngBoXJmYOpDh9m516vjqoUu2sLrIVgppI9TZVpg= +github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE= +golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc= +golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4= +golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= +golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= +golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= +golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= +google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/examples/oauth2-example/main.go b/examples/manager-example/oauth2-example/main.go similarity index 61% rename from examples/oauth2-example/main.go rename to examples/manager-example/oauth2-example/main.go index 0c6ffbe..8bf8c40 100644 --- a/examples/oauth2-example/main.go +++ b/examples/manager-example/oauth2-example/main.go @@ -1,25 +1,27 @@ package main import ( + "context" "fmt" "net/http" - "time" - "github.com/click33/sa-token-go/core" - "github.com/click33/sa-token-go/storage/memory" + "github.com/click33/sa-token-go/core/oauth2" "github.com/gin-gonic/gin" ) -var oauth2Server *core.OAuth2Server +var oauth2Server *oauth2.OAuth2Server func main() { - storage := memory.NewStorage() - oauth2Server = core.NewOAuth2Server(storage) + // 创建 OAuth2 服务器 + // 参数:authType, prefix, storage, serializer + oauth2Server = oauth2.NewOAuth2Server("login", "sa:", nil, nil) + // 注册客户端 registerClients() r := gin.Default() + // OAuth2 端点 r.GET("/oauth/authorize", authorizeHandler) r.POST("/oauth/token", tokenHandler) r.GET("/oauth/userinfo", userinfoHandler) @@ -32,44 +34,60 @@ func main() { fmt.Println("3. GET /oauth/userinfo (Authorization: Bearer )") fmt.Println("4. POST /oauth/revoke (token=)") - r.Run(":8080") + _ = r.Run(":8080") } func registerClients() { - client := &core.OAuth2Client{ + // 注册 Web 应用客户端 + webClient := &oauth2.Client{ ClientID: "webapp", ClientSecret: "secret123", RedirectURIs: []string{ "http://localhost:8080/callback", "http://localhost:3000/callback", }, - GrantTypes: []core.OAuth2GrantType{ - core.GrantTypeAuthorizationCode, - core.GrantTypeRefreshToken, + GrantTypes: []oauth2.GrantType{ + oauth2.GrantTypeAuthorizationCode, + oauth2.GrantTypeRefreshToken, }, Scopes: []string{"read", "write", "profile"}, } - oauth2Server.RegisterClient(client) + _ = oauth2Server.RegisterClient(webClient) - mobileClient := &core.OAuth2Client{ + // 注册移动应用客户端 + mobileClient := &oauth2.Client{ ClientID: "mobile-app", ClientSecret: "mobile-secret-456", RedirectURIs: []string{ "myapp://oauth/callback", }, - GrantTypes: []core.OAuth2GrantType{ - core.GrantTypeAuthorizationCode, - core.GrantTypeRefreshToken, + GrantTypes: []oauth2.GrantType{ + oauth2.GrantTypeAuthorizationCode, + oauth2.GrantTypeRefreshToken, }, Scopes: []string{"read", "write"}, } - oauth2Server.RegisterClient(mobileClient) + _ = oauth2Server.RegisterClient(mobileClient) + + // 注册服务端客户端(客户端凭证模式) + serviceClient := &oauth2.Client{ + ClientID: "service-app", + ClientSecret: "service-secret-789", + RedirectURIs: []string{}, + GrantTypes: []oauth2.GrantType{ + oauth2.GrantTypeClientCredentials, + }, + Scopes: []string{"api:read", "api:write"}, + } + _ = oauth2Server.RegisterClient(serviceClient) fmt.Println("✅ OAuth2 Clients registered:") fmt.Println(" - webapp (client_id: webapp, secret: secret123)") fmt.Println(" - mobile-app (client_id: mobile-app, secret: mobile-secret-456)") + fmt.Println(" - service-app (client_id: service-app, secret: service-secret-789)") } +// 授权端点 - 生成授权码 func authorizeHandler(c *gin.Context) { clientID := c.Query("client_id") redirectURI := c.Query("redirect_uri") @@ -87,12 +105,18 @@ func authorizeHandler(c *gin.Context) { scopes = []string{scope} } + // 模拟已登录用户 userID := "user123" + ctx := c.Request.Context() + + // 生成授权码 + // 参数顺序:ctx, clientID, userID, redirectURI, scopes authCode, err := oauth2Server.GenerateAuthorizationCode( + ctx, clientID, - redirectURI, userID, + redirectURI, scopes, ) if err != nil { @@ -110,26 +134,32 @@ func authorizeHandler(c *gin.Context) { }) } +// 令牌端点 - 根据授权类型处理 func tokenHandler(c *gin.Context) { grantType := c.PostForm("grant_type") + ctx := c.Request.Context() switch grantType { case "authorization_code": - handleAuthorizationCodeGrant(c) + handleAuthorizationCodeGrant(ctx, c) case "refresh_token": - handleRefreshTokenGrant(c) + handleRefreshTokenGrant(ctx, c) + case "client_credentials": + handleClientCredentialsGrant(ctx, c) default: c.JSON(http.StatusBadRequest, gin.H{"error": "unsupported_grant_type"}) } } -func handleAuthorizationCodeGrant(c *gin.Context) { +// 授权码模式 +func handleAuthorizationCodeGrant(ctx context.Context, c *gin.Context) { code := c.PostForm("code") clientID := c.PostForm("client_id") clientSecret := c.PostForm("client_secret") redirectURI := c.PostForm("redirect_uri") accessToken, err := oauth2Server.ExchangeCodeForToken( + ctx, code, clientID, clientSecret, @@ -149,14 +179,16 @@ func handleAuthorizationCodeGrant(c *gin.Context) { }) } -func handleRefreshTokenGrant(c *gin.Context) { +// 刷新令牌模式 +func handleRefreshTokenGrant(ctx context.Context, c *gin.Context) { refreshToken := c.PostForm("refresh_token") clientID := c.PostForm("client_id") clientSecret := c.PostForm("client_secret") accessToken, err := oauth2Server.RefreshAccessToken( - refreshToken, + ctx, clientID, + refreshToken, clientSecret, ) if err != nil { @@ -173,6 +205,37 @@ func handleRefreshTokenGrant(c *gin.Context) { }) } +// 客户端凭证模式 +func handleClientCredentialsGrant(ctx context.Context, c *gin.Context) { + clientID := c.PostForm("client_id") + clientSecret := c.PostForm("client_secret") + scope := c.PostForm("scope") + + var scopes []string + if scope != "" { + scopes = []string{scope} + } + + accessToken, err := oauth2Server.ClientCredentialsToken( + ctx, + clientID, + clientSecret, + scopes, + ) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{ + "access_token": accessToken.Token, + "token_type": accessToken.TokenType, + "expires_in": accessToken.ExpiresIn, + "scope": accessToken.Scopes, + }) +} + +// 用户信息端点 func userinfoHandler(c *gin.Context) { authHeader := c.GetHeader("Authorization") if authHeader == "" { @@ -181,9 +244,12 @@ func userinfoHandler(c *gin.Context) { } var token string - fmt.Sscanf(authHeader, "Bearer %s", &token) + _, _ = fmt.Sscanf(authHeader, "Bearer %s", &token) + + ctx := c.Request.Context() - accessToken, err := oauth2Server.ValidateAccessToken(token) + // 验证访问令牌并获取信息 + accessToken, err := oauth2Server.ValidateAccessTokenAndGetInfo(ctx, token) if err != nil { c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid access token"}) return @@ -194,14 +260,15 @@ func userinfoHandler(c *gin.Context) { "client_id": accessToken.ClientID, "scopes": accessToken.Scopes, "expires_in": accessToken.ExpiresIn, - "issued_at": time.Now().Unix() - accessToken.ExpiresIn, }) } +// 撤销令牌端点 func revokeHandler(c *gin.Context) { token := c.PostForm("token") + ctx := c.Request.Context() - if err := oauth2Server.RevokeToken(token); err != nil { + if err := oauth2Server.RevokeToken(ctx, token); err != nil { c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) return } diff --git a/examples/manager-example/security-example/go.mod b/examples/manager-example/security-example/go.mod new file mode 100644 index 0000000..367a369 --- /dev/null +++ b/examples/manager-example/security-example/go.mod @@ -0,0 +1,15 @@ +module github.com/click33/sa-token-go/examples/security-example + +go 1.25.0 + +require ( + github.com/click33/sa-token-go/core v0.1.7 + github.com/click33/sa-token-go/storage/memory v0.1.7 +) + +require ( + github.com/golang-jwt/jwt/v5 v5.3.0 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/panjf2000/ants/v2 v2.11.3 // indirect + golang.org/x/sync v0.19.0 // indirect +) diff --git a/examples/manager-example/security-example/main.go b/examples/manager-example/security-example/main.go new file mode 100644 index 0000000..0dea13e --- /dev/null +++ b/examples/manager-example/security-example/main.go @@ -0,0 +1,159 @@ +package main + +import ( + "context" + "fmt" + "time" + + "github.com/click33/sa-token-go/core/security" + "github.com/click33/sa-token-go/storage/memory" +) + +func main() { + fmt.Println("=== Sa-Token-Go Security Features Demo ===\n") + + ctx := context.Background() + storage := memory.NewStorage() + + // 1. Nonce 防重放攻击示例 + demoNonceManager(ctx, storage) + + // 2. Refresh Token 示例 + demoRefreshTokenManager(ctx, storage) + + fmt.Println("=== Demo Complete ===") +} + +// demoNonceManager 演示 Nonce 防重放攻击功能 +func demoNonceManager(ctx context.Context, storage *memory.Storage) { + fmt.Println("1. Nonce Manager - 防重放攻击") + fmt.Println("----------------------------------------") + + // 创建 Nonce 管理器 + // 参数:authType, prefix, storage, ttl + nonceManager := security.NewNonceManager("login", "sa:", storage, 5*time.Minute) + + // 生成 Nonce + nonce1, err := nonceManager.Generate(ctx) + if err != nil { + fmt.Printf(" [ERROR] 生成 Nonce 失败: %v\n", err) + return + } + fmt.Printf(" [OK] 生成 Nonce: %s\n", nonce1) + + // 检查 Nonce 是否有效(不消费) + fmt.Printf(" [INFO] Nonce 是否有效(检查): %v\n", nonceManager.IsValid(ctx, nonce1)) + + // 验证并消费 Nonce(第一次) + result1 := nonceManager.Verify(ctx, nonce1) + fmt.Printf(" [VERIFY] 第一次验证 Nonce: %v (应该为 true)\n", result1) + + // 验证并消费 Nonce(第二次 - 重放攻击模拟) + result2 := nonceManager.Verify(ctx, nonce1) + fmt.Printf(" [BLOCKED] 第二次验证 Nonce: %v (应该为 false - 防止重放)\n", result2) + + // 使用 VerifyAndConsume 方法 + nonce2, _ := nonceManager.Generate(ctx) + fmt.Printf("\n [OK] 生成新 Nonce: %s\n", nonce2) + err = nonceManager.VerifyAndConsume(ctx, nonce2) + if err != nil { + fmt.Printf(" [ERROR] VerifyAndConsume 失败: %v\n", err) + } else { + fmt.Printf(" [OK] VerifyAndConsume 成功\n") + } + + // 再次使用已消费的 Nonce + err = nonceManager.VerifyAndConsume(ctx, nonce2) + if err != nil { + fmt.Printf(" [BLOCKED] 重复使用 Nonce 被拒绝: %v\n", err) + } + + fmt.Println() +} + +// demoRefreshTokenManager 演示 Refresh Token 功能 +func demoRefreshTokenManager(ctx context.Context, storage *memory.Storage) { + fmt.Println("2. Refresh Token Manager - 令牌刷新") + fmt.Println("----------------------------------------") + + // 创建 Refresh Token 管理器 + // 参数:authType, prefix, tokenKeyPrefix, tokenGen, accessTTL, storage, serializer + rtManager := security.NewRefreshTokenManager( + "login", // authType + "sa:", // prefix + "token:", // tokenKeyPrefix + nil, // tokenGen (nil 使用默认) + 2*time.Hour, // accessTTL + storage, // storage + nil, // serializer (nil 使用默认 JSON) + ) + + // 生成令牌对(Access Token + Refresh Token) + userID := "user1001" + device := "web" + + tokenPair, err := rtManager.GenerateTokenPair(ctx, userID, device) + if err != nil { + fmt.Printf(" [ERROR] 生成令牌对失败: %v\n", err) + return + } + + fmt.Printf(" [OK] 生成令牌对成功\n") + fmt.Printf(" Access Token: %s\n", tokenPair.AccessToken) + fmt.Printf(" Refresh Token: %s\n", tokenPair.RefreshToken) + fmt.Printf(" Login ID: %s\n", tokenPair.LoginID) + fmt.Printf(" Device: %s\n", tokenPair.Device) + fmt.Printf(" 创建时间: %s\n", time.Unix(tokenPair.CreateTime, 0).Format("2006-01-02 15:04:05")) + fmt.Printf(" 过期时间: %s\n", time.Unix(tokenPair.ExpireTime, 0).Format("2006-01-02 15:04:05")) + + // 验证 Access Token + fmt.Printf("\n [VERIFY] Access Token 是否有效: %v\n", rtManager.VerifyAccessToken(ctx, tokenPair.AccessToken)) + + // 获取 Access Token 信息 + accessInfo, valid := rtManager.VerifyAccessTokenAndGetInfo(ctx, tokenPair.AccessToken) + if valid { + fmt.Printf(" [INFO] Access Token 信息:\n") + fmt.Printf(" - LoginID: %s\n", accessInfo.LoginID) + fmt.Printf(" - Device: %s\n", accessInfo.Device) + } + + // 检查 Refresh Token 是否有效 + fmt.Printf("\n [VERIFY] Refresh Token 是否有效: %v\n", rtManager.IsValid(ctx, tokenPair.RefreshToken)) + + // 使用 Refresh Token 刷新 Access Token + fmt.Println("\n [REFRESH] 刷新 Access Token...") + newTokenPair, err := rtManager.RefreshAccessToken(ctx, tokenPair.RefreshToken) + if err != nil { + fmt.Printf(" [ERROR] 刷新失败: %v\n", err) + return + } + + fmt.Printf(" [OK] 刷新成功\n") + fmt.Printf(" 新 Access Token: %s\n", newTokenPair.AccessToken) + fmt.Printf(" Refresh Token: %s (保持不变)\n", newTokenPair.RefreshToken) + + // 验证旧 Access Token 已失效 + fmt.Printf("\n [VERIFY] 旧 Access Token 是否有效: %v (应该为 false)\n", rtManager.VerifyAccessToken(ctx, tokenPair.AccessToken)) + fmt.Printf(" [VERIFY] 新 Access Token 是否有效: %v (应该为 true)\n", rtManager.VerifyAccessToken(ctx, newTokenPair.AccessToken)) + + // 获取 Refresh Token 信息 + refreshInfo, err := rtManager.GetRefreshTokenInfo(ctx, tokenPair.RefreshToken) + if err == nil { + fmt.Printf("\n [INFO] Refresh Token 信息:\n") + fmt.Printf(" - LoginID: %s\n", refreshInfo.LoginID) + fmt.Printf(" - Device: %s\n", refreshInfo.Device) + fmt.Printf(" - AccessToken: %s\n", refreshInfo.AccessToken) + } + + // 撤销 Refresh Token + fmt.Println("\n [REVOKE] 撤销 Refresh Token...") + err = rtManager.RevokeRefreshToken(ctx, tokenPair.RefreshToken) + if err != nil { + fmt.Printf(" [ERROR] 撤销失败: %v\n", err) + } else { + fmt.Printf(" [OK] 撤销成功\n") + fmt.Printf(" [VERIFY] Refresh Token 是否有效: %v (应该为 false)\n", rtManager.IsValid(ctx, tokenPair.RefreshToken)) + } + + fmt.Println() +} diff --git a/examples/manager-example/session-demo/go.mod b/examples/manager-example/session-demo/go.mod new file mode 100644 index 0000000..4408f9e --- /dev/null +++ b/examples/manager-example/session-demo/go.mod @@ -0,0 +1,16 @@ +module github.com/click33/sa-token-go/examples/session-demo + +go 1.25.0 + +require ( + github.com/click33/sa-token-go/core v0.1.7 + github.com/click33/sa-token-go/storage/memory v0.1.7 + github.com/click33/sa-token-go/stputil v0.1.7 +) + +require ( + github.com/golang-jwt/jwt/v5 v5.3.0 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/panjf2000/ants/v2 v2.11.3 // indirect + golang.org/x/sync v0.19.0 // indirect +) diff --git a/examples/manager-example/session-demo/go.sum b/examples/manager-example/session-demo/go.sum new file mode 100644 index 0000000..16b1a51 --- /dev/null +++ b/examples/manager-example/session-demo/go.sum @@ -0,0 +1,11 @@ +github.com/click33/sa-token-go/core v0.1.6 h1:ELOe0qSH1b3LRsQD3DIBg0e1VgYANKFg5H7z57Lkt/8= +github.com/click33/sa-token-go/storage/memory v0.1.6 h1:iGFEy+HtTJLOpKnbIMbgpXyKotsKpPQu6wWTZVOXQis= +github.com/click33/sa-token-go/stputil v0.1.6 h1:S+V64jQzppE9c1wXcmHppCRlrSsU2iTfvdPGlMbs2WI= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/panjf2000/ants/v2 v2.11.3 h1:AfI0ngBoXJmYOpDh9m516vjqoUu2sLrIVgppI9TZVpg= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/examples/session-demo/main.go b/examples/manager-example/session-demo/main.go similarity index 83% rename from examples/session-demo/main.go rename to examples/manager-example/session-demo/main.go index e03ea92..7ea3161 100644 --- a/examples/session-demo/main.go +++ b/examples/manager-example/session-demo/main.go @@ -1,10 +1,11 @@ package main import ( + "context" "fmt" "log" - "github.com/click33/sa-token-go/core" + "github.com/click33/sa-token-go/core/builder" "github.com/click33/sa-token-go/storage/memory" "github.com/click33/sa-token-go/stputil" ) @@ -24,18 +25,22 @@ type SysUser struct { func main() { // 初始化 sa-token stputil.SetManager( - core.NewBuilder(). - Storage(memory.NewStorage()). + builder.NewBuilder(). + SetStorage(memory.NewStorage()). KeyPrefix("satoken"). - IsPrintBanner(false). + Timeout(86400). + MaxRefresh(43200). + IsPrintBanner(true). Build(), ) + ctx := context.Background() + // 模拟用户登录 userID := "1000" // 1. 执行登录 - Token 键中只存 loginID - token, err := stputil.Login(userID) + token, err := stputil.Login(ctx, userID) if err != nil { log.Fatalf("Login failed: %v", err) } @@ -55,10 +60,10 @@ func main() { } // 3. 将完整的用户对象存入 Session(Account-Session) - sess, _ := stputil.GetSession(userID) - sess.Set("user", userFromDB) // ← 完整的 User 对象存在 Session 中 - sess.Set("lastLoginTime", "2025-10-25 10:00:00") - sess.Set("loginIP", "192.168.1.100") + sess, _ := stputil.GetSession(ctx, userID) + _ = sess.Set(ctx, "user", userFromDB) // ← 完整的 User 对象存在 Session 中 + _ = sess.Set(ctx, "lastLoginTime", "2025-10-25 10:00:00") + _ = sess.Set(ctx, "loginIP", "192.168.1.100") fmt.Printf("📦 Redis 存储结构:\n\n") fmt.Printf(" 1️⃣ Token 键(只存 loginID):\n") @@ -77,12 +82,12 @@ func main() { fmt.Printf("🔍 获取用户信息流程:\n\n") // 步骤1:从 Token 获取 loginID - loginID, _ := stputil.GetLoginID(token) + loginID, _ := stputil.GetLoginID(ctx, token) fmt.Printf(" 步骤1: Token → loginID\n") fmt.Printf(" %s → %s\n\n", token, loginID) // 步骤2:从 Session 获取完整用户对象 - sess2, _ := stputil.GetSession(loginID) + sess2, _ := stputil.GetSession(ctx, loginID) userObj, exists := sess2.Get("user") if exists { // Session 返回的是 map,需要转换 diff --git a/examples/manager-example/token-styles/go.mod b/examples/manager-example/token-styles/go.mod new file mode 100644 index 0000000..96dc7d9 --- /dev/null +++ b/examples/manager-example/token-styles/go.mod @@ -0,0 +1,16 @@ +module github.com/click33/sa-token-go/examples/token-styles + +go 1.25.0 + +require ( + github.com/click33/sa-token-go/core v0.1.7 + github.com/click33/sa-token-go/storage/memory v0.1.7 + github.com/click33/sa-token-go/stputil v0.1.7 +) + +require ( + github.com/golang-jwt/jwt/v5 v5.3.0 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/panjf2000/ants/v2 v2.11.3 // indirect + golang.org/x/sync v0.19.0 // indirect +) diff --git a/examples/manager-example/token-styles/go.sum b/examples/manager-example/token-styles/go.sum new file mode 100644 index 0000000..16b1a51 --- /dev/null +++ b/examples/manager-example/token-styles/go.sum @@ -0,0 +1,11 @@ +github.com/click33/sa-token-go/core v0.1.6 h1:ELOe0qSH1b3LRsQD3DIBg0e1VgYANKFg5H7z57Lkt/8= +github.com/click33/sa-token-go/storage/memory v0.1.6 h1:iGFEy+HtTJLOpKnbIMbgpXyKotsKpPQu6wWTZVOXQis= +github.com/click33/sa-token-go/stputil v0.1.6 h1:S+V64jQzppE9c1wXcmHppCRlrSsU2iTfvdPGlMbs2WI= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/panjf2000/ants/v2 v2.11.3 h1:AfI0ngBoXJmYOpDh9m516vjqoUu2sLrIVgppI9TZVpg= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/examples/token-styles/main.go b/examples/manager-example/token-styles/main.go similarity index 57% rename from examples/token-styles/main.go rename to examples/manager-example/token-styles/main.go index f24d426..85c06ba 100644 --- a/examples/token-styles/main.go +++ b/examples/manager-example/token-styles/main.go @@ -1,10 +1,12 @@ package main import ( + "context" "fmt" "time" - "github.com/click33/sa-token-go/core" + "github.com/click33/sa-token-go/core/adapter" + "github.com/click33/sa-token-go/core/builder" "github.com/click33/sa-token-go/storage/memory" "github.com/click33/sa-token-go/stputil" ) @@ -15,41 +17,44 @@ func main() { // Demo all token styles // 演示所有 Token 风格 - demoTokenStyle(core.TokenStyleUUID, "UUID Style") - demoTokenStyle(core.TokenStyleSimple, "Simple Style") - demoTokenStyle(core.TokenStyleRandom32, "Random32 Style") - demoTokenStyle(core.TokenStyleRandom64, "Random64 Style") - demoTokenStyle(core.TokenStyleRandom128, "Random128 Style") - demoTokenStyle(core.TokenStyleJWT, "JWT Style") - demoTokenStyle(core.TokenStyleHash, "Hash Style (SHA256)") - demoTokenStyle(core.TokenStyleTimestamp, "Timestamp Style") - demoTokenStyle(core.TokenStyleTik, "Tik Style (Short ID)") + demoTokenStyle(adapter.TokenStyleUUID, "UUID Style") + demoTokenStyle(adapter.TokenStyleSimple, "Simple Style") + demoTokenStyle(adapter.TokenStyleRandom32, "Random32 Style") + demoTokenStyle(adapter.TokenStyleRandom64, "Random64 Style") + demoTokenStyle(adapter.TokenStyleRandom128, "Random128 Style") + demoTokenStyle(adapter.TokenStyleJWT, "JWT Style") + demoTokenStyle(adapter.TokenStyleHash, "Hash Style (SHA256)") + demoTokenStyle(adapter.TokenStyleTimestamp, "Timestamp Style") + demoTokenStyle(adapter.TokenStyleTik, "Tik Style (Short ID)") fmt.Println("\n========================================") fmt.Println("✅ All token styles demonstrated!") } -func demoTokenStyle(style core.TokenStyle, name string) { +func demoTokenStyle(style adapter.TokenStyle, name string) { fmt.Printf("📌 %s (%s)\n", name, style) fmt.Println("----------------------------------------") // Initialize manager with specific token style // 使用特定 Token 风格初始化管理器 - manager := core.NewBuilder(). - Storage(memory.NewStorage()). + mgr := builder.NewBuilder(). + SetStorage(memory.NewStorage()). TokenStyle(style). Timeout(3600). + MaxRefresh(1800). JwtSecretKey("my-secret-key-123"). // For JWT style | 用于JWT风格 IsPrintBanner(false). Build() - stputil.SetManager(manager) + stputil.SetManager(mgr) + + ctx := context.Background() // Generate 3 tokens to show variety // 生成3个Token展示多样性 for i := 1; i <= 3; i++ { loginID := fmt.Sprintf("user%d", 1000+i) - token, err := stputil.Login(loginID) + token, err := stputil.Login(ctx, loginID) if err != nil { fmt.Printf(" ❌ Error generating token: %v\n", err) continue diff --git a/examples/oauth2-example/README.md b/examples/oauth2-example/README.md deleted file mode 100644 index c90e8e1..0000000 --- a/examples/oauth2-example/README.md +++ /dev/null @@ -1,232 +0,0 @@ -English | [中文文档](README_zh.md) - -# OAuth2 Authorization Code Flow Example - -Complete OAuth2 authorization code flow implementation example. - -## Features - -- **Authorization Code Grant** - Standard OAuth2 authorization code flow -- **Token Refresh** - Refresh access tokens using refresh tokens -- **Token Validation** - Validate access tokens -- **Token Revocation** - Revoke access tokens -- **Multiple Clients** - Support for multiple OAuth2 clients -- **Scope Management** - Fine-grained permission control - -## Quick Start - -### 1. Run the Server - -```bash -cd examples/oauth2-example -go run main.go -``` - -Server runs on `http://localhost:8080` - -### 2. OAuth2 Flow - -#### Step 1: Authorization Request - -```bash -curl "http://localhost:8080/oauth/authorize?client_id=webapp&redirect_uri=http://localhost:8080/callback&response_type=code&state=xyz123" -``` - -Response: -```json -{ - "message": "Authorization code generated", - "code": "a3f5d8b2c1e4f6a9...", - "redirect_url": "http://localhost:8080/callback?code=...&state=xyz123", - "user_id": "user123", - "scopes": ["read", "write"] -} -``` - -#### Step 2: Exchange Code for Token - -```bash -curl -X POST http://localhost:8080/oauth/token \ - -d "grant_type=authorization_code" \ - -d "code=a3f5d8b2c1e4f6a9..." \ - -d "client_id=webapp" \ - -d "client_secret=secret123" \ - -d "redirect_uri=http://localhost:8080/callback" -``` - -Response: -```json -{ - "access_token": "b4f6d9c3d2e5f7b0...", - "token_type": "Bearer", - "expires_in": 7200, - "refresh_token": "c5f7e0d4e3f6e8c1...", - "scope": ["read", "write"] -} -``` - -#### Step 3: Use Access Token - -```bash -curl http://localhost:8080/oauth/userinfo \ - -H "Authorization: Bearer b4f6d9c3d2e5f7b0..." -``` - -Response: -```json -{ - "user_id": "user123", - "client_id": "webapp", - "scopes": ["read", "write"], - "expires_in": 7200, - "issued_at": 1700000000 -} -``` - -#### Step 4: Refresh Access Token - -```bash -curl -X POST http://localhost:8080/oauth/token \ - -d "grant_type=refresh_token" \ - -d "refresh_token=c5f7e0d4e3f6e8c1..." \ - -d "client_id=webapp" \ - -d "client_secret=secret123" -``` - -#### Step 5: Revoke Token - -```bash -curl -X POST http://localhost:8080/oauth/revoke \ - -d "token=b4f6d9c3d2e5f7b0..." -``` - -## Registered Clients - -### Web Application - -``` -Client ID: webapp -Client Secret: secret123 -Redirect URIs: - - http://localhost:8080/callback - - http://localhost:3000/callback -Scopes: read, write, profile -``` - -### Mobile Application - -``` -Client ID: mobile-app -Client Secret: mobile-secret-456 -Redirect URIs: - - myapp://oauth/callback -Scopes: read, write -``` - -## API Endpoints - -| Endpoint | Method | Description | -|----------|--------|-------------| -| `/oauth/authorize` | GET | Authorization endpoint | -| `/oauth/token` | POST | Token endpoint | -| `/oauth/userinfo` | GET | User info endpoint | -| `/oauth/revoke` | POST | Token revocation endpoint | - -## Authorization Request Parameters - -| Parameter | Required | Description | -|-----------|----------|-------------| -| `client_id` | Yes | Client identifier | -| `redirect_uri` | Yes | Callback URI | -| `response_type` | Yes | Must be "code" | -| `state` | Recommended | CSRF protection | -| `scope` | Optional | Requested scopes | - -## Token Request Parameters - -### Authorization Code Grant - -| Parameter | Required | Description | -|-----------|----------|-------------| -| `grant_type` | Yes | "authorization_code" | -| `code` | Yes | Authorization code | -| `client_id` | Yes | Client identifier | -| `client_secret` | Yes | Client secret | -| `redirect_uri` | Yes | Must match authorization request | - -### Refresh Token Grant - -| Parameter | Required | Description | -|-----------|----------|-------------| -| `grant_type` | Yes | "refresh_token" | -| `refresh_token` | Yes | Refresh token | -| `client_id` | Yes | Client identifier | -| `client_secret` | Yes | Client secret | - -## Security Features - -1. **Client Authentication** - Client secret verification -2. **Redirect URI Validation** - Prevent open redirect attacks -3. **State Parameter** - CSRF protection -4. **Code Expiration** - Authorization codes expire in 10 minutes -5. **Token Expiration** - Access tokens expire in 2 hours -6. **One-time Use** - Authorization codes can only be used once -7. **Scope Validation** - Requested scopes must be allowed - -## Integration Example - -```go -package main - -import ( - "github.com/click33/sa-token-go/core" - "github.com/click33/sa-token-go/storage/memory" -) - -func main() { - storage := memory.NewStorage() - oauth2Server := core.NewOAuth2Server(storage) - - // Register client - oauth2Server.RegisterClient(&core.OAuth2Client{ - ClientID: "my-app", - ClientSecret: "my-secret", - RedirectURIs: []string{"http://localhost:3000/callback"}, - GrantTypes: []core.OAuth2GrantType{core.GrantTypeAuthorizationCode}, - Scopes: []string{"read", "write"}, - }) - - // Generate authorization code - authCode, _ := oauth2Server.GenerateAuthorizationCode( - "my-app", - "http://localhost:3000/callback", - "user123", - []string{"read"}, - ) - - // Exchange code for token - token, _ := oauth2Server.ExchangeCodeForToken( - authCode.Code, - "my-app", - "my-secret", - "http://localhost:3000/callback", - ) - - // Validate token - validated, _ := oauth2Server.ValidateAccessToken(token.Token) - - // Refresh token - newToken, _ := oauth2Server.RefreshAccessToken( - token.RefreshToken, - "my-app", - "my-secret", - ) -} -``` - -## Next Steps - -- [Security Features Example](../security-features/) -- [Refresh Token Guide](../../docs/guide/refresh-token.md) -- [OAuth2 Documentation](../../docs/guide/oauth2.md) - diff --git a/examples/oauth2-example/README_zh.md b/examples/oauth2-example/README_zh.md deleted file mode 100644 index db58bbc..0000000 --- a/examples/oauth2-example/README_zh.md +++ /dev/null @@ -1,232 +0,0 @@ -[English](README.md) | 中文文档 - -# OAuth2 授权码模式示例 - -完整的 OAuth2 授权码流程实现示例。 - -## 功能特性 - -- **授权码模式** - 标准的 OAuth2 授权码流程 -- **令牌刷新** - 使用刷新令牌刷新访问令牌 -- **令牌验证** - 验证访问令牌 -- **令牌撤销** - 撤销访问令牌 -- **多客户端** - 支持多个 OAuth2 客户端 -- **权限管理** - 细粒度权限控制 - -## 快速开始 - -### 1. 运行服务器 - -```bash -cd examples/oauth2-example -go run main.go -``` - -服务器运行在 `http://localhost:8080` - -### 2. OAuth2 流程 - -#### 步骤 1: 授权请求 - -```bash -curl "http://localhost:8080/oauth/authorize?client_id=webapp&redirect_uri=http://localhost:8080/callback&response_type=code&state=xyz123" -``` - -响应: -```json -{ - "message": "Authorization code generated", - "code": "a3f5d8b2c1e4f6a9...", - "redirect_url": "http://localhost:8080/callback?code=...&state=xyz123", - "user_id": "user123", - "scopes": ["read", "write"] -} -``` - -#### 步骤 2: 用授权码换取令牌 - -```bash -curl -X POST http://localhost:8080/oauth/token \ - -d "grant_type=authorization_code" \ - -d "code=a3f5d8b2c1e4f6a9..." \ - -d "client_id=webapp" \ - -d "client_secret=secret123" \ - -d "redirect_uri=http://localhost:8080/callback" -``` - -响应: -```json -{ - "access_token": "b4f6d9c3d2e5f7b0...", - "token_type": "Bearer", - "expires_in": 7200, - "refresh_token": "c5f7e0d4e3f6e8c1...", - "scope": ["read", "write"] -} -``` - -#### 步骤 3: 使用访问令牌 - -```bash -curl http://localhost:8080/oauth/userinfo \ - -H "Authorization: Bearer b4f6d9c3d2e5f7b0..." -``` - -响应: -```json -{ - "user_id": "user123", - "client_id": "webapp", - "scopes": ["read", "write"], - "expires_in": 7200, - "issued_at": 1700000000 -} -``` - -#### 步骤 4: 刷新访问令牌 - -```bash -curl -X POST http://localhost:8080/oauth/token \ - -d "grant_type=refresh_token" \ - -d "refresh_token=c5f7e0d4e3f6e8c1..." \ - -d "client_id=webapp" \ - -d "client_secret=secret123" -``` - -#### 步骤 5: 撤销令牌 - -```bash -curl -X POST http://localhost:8080/oauth/revoke \ - -d "token=b4f6d9c3d2e5f7b0..." -``` - -## 已注册的客户端 - -### Web 应用 - -``` -Client ID: webapp -Client Secret: secret123 -回调 URI: - - http://localhost:8080/callback - - http://localhost:3000/callback -权限范围: read, write, profile -``` - -### 移动应用 - -``` -Client ID: mobile-app -Client Secret: mobile-secret-456 -回调 URI: - - myapp://oauth/callback -权限范围: read, write -``` - -## API 端点 - -| 端点 | 方法 | 说明 | -|------|------|------| -| `/oauth/authorize` | GET | 授权端点 | -| `/oauth/token` | POST | 令牌端点 | -| `/oauth/userinfo` | GET | 用户信息端点 | -| `/oauth/revoke` | POST | 令牌撤销端点 | - -## 授权请求参数 - -| 参数 | 必需 | 说明 | -|------|------|------| -| `client_id` | 是 | 客户端标识符 | -| `redirect_uri` | 是 | 回调 URI | -| `response_type` | 是 | 必须是 "code" | -| `state` | 推荐 | CSRF 保护 | -| `scope` | 可选 | 请求的权限范围 | - -## 令牌请求参数 - -### 授权码模式 - -| 参数 | 必需 | 说明 | -|------|------|------| -| `grant_type` | 是 | "authorization_code" | -| `code` | 是 | 授权码 | -| `client_id` | 是 | 客户端标识符 | -| `client_secret` | 是 | 客户端密钥 | -| `redirect_uri` | 是 | 必须与授权请求匹配 | - -### 刷新令牌模式 - -| 参数 | 必需 | 说明 | -|------|------|------| -| `grant_type` | 是 | "refresh_token" | -| `refresh_token` | 是 | 刷新令牌 | -| `client_id` | 是 | 客户端标识符 | -| `client_secret` | 是 | 客户端密钥 | - -## 安全特性 - -1. **客户端认证** - 客户端密钥验证 -2. **回调 URI 验证** - 防止开放重定向攻击 -3. **State 参数** - CSRF 保护 -4. **授权码过期** - 授权码 10 分钟后过期 -5. **令牌过期** - 访问令牌 2 小时后过期 -6. **一次性使用** - 授权码只能使用一次 -7. **权限验证** - 请求的权限必须被允许 - -## 集成示例 - -```go -package main - -import ( - "github.com/click33/sa-token-go/core" - "github.com/click33/sa-token-go/storage/memory" -) - -func main() { - storage := memory.NewStorage() - oauth2Server := core.NewOAuth2Server(storage) - - // 注册客户端 - oauth2Server.RegisterClient(&core.OAuth2Client{ - ClientID: "my-app", - ClientSecret: "my-secret", - RedirectURIs: []string{"http://localhost:3000/callback"}, - GrantTypes: []core.OAuth2GrantType{core.GrantTypeAuthorizationCode}, - Scopes: []string{"read", "write"}, - }) - - // 生成授权码 - authCode, _ := oauth2Server.GenerateAuthorizationCode( - "my-app", - "http://localhost:3000/callback", - "user123", - []string{"read"}, - ) - - // 用授权码换取令牌 - token, _ := oauth2Server.ExchangeCodeForToken( - authCode.Code, - "my-app", - "my-secret", - "http://localhost:3000/callback", - ) - - // 验证令牌 - validated, _ := oauth2Server.ValidateAccessToken(token.Token) - - // 刷新令牌 - newToken, _ := oauth2Server.RefreshAccessToken( - token.RefreshToken, - "my-app", - "my-secret", - ) -} -``` - -## 下一步 - -- [安全特性示例](../security-features/) -- [刷新令牌指南](../../docs/guide/refresh-token_zh.md) -- [OAuth2 文档](../../docs/guide/oauth2_zh.md) - diff --git a/examples/quick-start/complex-example/README.md b/examples/quick-start/complex-example/README.md new file mode 100644 index 0000000..468badc --- /dev/null +++ b/examples/quick-start/complex-example/README.md @@ -0,0 +1,3 @@ +# 快速开始示例 + +这是一个复杂的 Sa-Token-Go 使用示例,展示了如何使用 `stputil` 全局工具类全面实现认证和授权功能。 \ No newline at end of file diff --git a/examples/quick-start/complex-example/go.mod b/examples/quick-start/complex-example/go.mod new file mode 100644 index 0000000..be42a55 --- /dev/null +++ b/examples/quick-start/complex-example/go.mod @@ -0,0 +1,9 @@ +module github.com/click33/sa-token-go/examples/quick-start/complex-example + +go 1.25.0 + +require ( + github.com/click33/sa-token-go/core v0.1.7 + github.com/click33/sa-token-go/storage/redis v0.1.7 + github.com/click33/sa-token-go/stputil v0.1.7 +) diff --git a/examples/quick-start/complex-example/main.go b/examples/quick-start/complex-example/main.go new file mode 100644 index 0000000..9cfb165 --- /dev/null +++ b/examples/quick-start/complex-example/main.go @@ -0,0 +1,181 @@ +// @Author daixk 2026/1/6 14:36:00 +package main + +import ( + "context" + "fmt" + "github.com/click33/sa-token-go/core/adapter" + "github.com/click33/sa-token-go/core/builder" + "github.com/click33/sa-token-go/core/config" + "github.com/click33/sa-token-go/storage/redis" + "github.com/click33/sa-token-go/stputil" + "time" +) + +func init() { + storage, err := redis.NewStorage("redis://:root@192.168.19.104:6379/0?dial_timeout=3&read_timeout=10s&max_retries=2") + if err != nil { + panic(err) + } + + stputil.SetManager( + builder.NewBuilder(). + // ========== 存储和Token ========== + SetStorage(storage). // 设置存储实现(默认内存,可配置Redis或自实现) + TokenName("satoken"). // Token 名称(也是 Cookie 名称) + Timeout(300). // Token 过期时间(秒) + MaxRefresh(150). // 自动续期触发阈值(秒) + RenewInterval(config.NoLimit). // 续期的最小间隔(秒) + ActiveTimeout(config.NoLimit). // 最大不活跃时间(秒) + IsConcurrent(true). // 是否允许并发登录 + IsShare(false). // 并发登录是否共享 Token + MaxLoginCount(2). // 最大在线 Token 数量 + IsReadBody(false). // 是否从请求体读取 Token + IsReadHeader(true). // 是否从 Header 读取 Token + IsReadCookie(false). // 是否从 Cookie 读取 Token + TokenStyle(adapter.TokenStyleUUID). // Token 样式 + TokenSessionCheckLogin(true). // 登录时是否校验Token会话 + AutoRenew(true). // 是否自动续期 + JwtSecretKey(""). // 设置JWT密钥(JWT模式才生效) + AuthType("auth"). // 认证体系类型 + KeyPrefix("satoken"). // 存储键前缀 + + // ========== Cookie配置 ========== + CookieDomain("example.com"). // Cookie域名 + CookiePath("/"). // Cookie路径 + CookieSecure(false). // 是否启用Secure + CookieHttpOnly(true). // 是否启用HttpOnly + CookieSameSite(config.SameSiteLax). // SameSite策略 + CookieMaxAge(300). // Cookie最大过期时间 + // CookieConfig(&config.CookieConfig{...}). // 可以直接设置完整Cookie配置 + + // ========== 日志配置 ========== + IsLog(true). // 是否打印操作日志 + IsPrintBanner(true). // 是否打印启动Banner + LoggerPath("./logs"). // 日志目录 + LoggerFileFormat("{Y}-{m}-{d}.log"). // 日志文件命名格式 + LoggerPrefix("[satoken]"). // 日志前缀 + LoggerLevel(adapter.LogLevelDebug). // 最低日志级别 + LoggerTimeFormat("2006-01-02 15:04:05"). // 时间戳格式 + LoggerStdout(true). // 是否打印到控制台 + LoggerStdoutOnly(false). // 是否只打印到控制台 + LoggerQueueSize(4096). // 异步写入队列大小 + LoggerRotateSize(1024 * 1024 * 10). // 滚动文件大小阈值 10MB + LoggerRotateExpire(24 * time.Hour). // 滚动文件时间间隔 + LoggerRotateBackupLimit(30). // 最大备份文件数量 + LoggerRotateBackupDays(7). // 备份文件保留天数 + // LoggerConfig(&slog.LoggerConfig{...}). // 可以直接设置完整日志配置 + + // ========== 续期池配置 ========== + RenewPoolMinSize(10). // 最小协程数 + RenewPoolMaxSize(50). // 最大协程数 + RenewPoolScaleUpRate(0.7). // 扩容阈值 + RenewPoolScaleDownRate(0.3). // 缩容阈值 + RenewPoolCheckInterval(5 * time.Second). // 自动扩缩容检查间隔 + RenewPoolExpiry(60 * time.Second). // 空闲协程过期时间 + RenewPoolPrintStatusInterval(30 * time.Second). // 状态打印间隔 + RenewPoolPreAlloc(false). // 是否预分配内存 + RenewPoolNonBlocking(true). // 是否非阻塞模式 + // RenewPoolConfig(&ants.RenewPoolConfig{...}). // 可以直接设置完整续期池配置 + + // ========== 自定义适配器 ========== + // SetGenerator(generator). // 自定义Token生成器 + // SetCodec(codec). // 自定义编码器 + // SetLog(log). // 自定义日志 + // SetPool(pool). // 自定义协程池 + + // ========== 自定义权限与角色 ========== + SetCustomPermissionListFunc(func(loginID, authType string) ([]string, error) { + if loginID == "1" { + return []string{"admin:read", "admin:update"}, nil + } + return []string{"user:read"}, nil + }). + SetCustomRoleListFunc(func(loginID, authType string) ([]string, error) { + if loginID == "1" { + return []string{"admin", "guanliyuan"}, nil + } + return []string{"user"}, nil + }). + + // ========== JWT模式 ========== + // Jwt("your-secret-key"). // 如果需要JWT模式,可直接启用 + + Build(), // 构建Manager + ) +} + +func main() { + ctx := context.Background() + + // 1. 登录(支持多种类型) + fmt.Println("1. 登录测试") + token1, _ := stputil.Login(ctx, 1000) + fmt.Printf(" 用户1000登录成功,Token: %s\n", token1) + + token2, _ := stputil.Login(ctx, "user123") + fmt.Printf(" 用户user123登录成功,Token: %s\n\n", token2) + + // 2. 检查登录 + fmt.Println("2. 检查登录") + fmt.Printf(" Token1是否登录: %v\n", stputil.IsLogin(ctx, token1)) + fmt.Printf(" Token2是否登录: %v\n\n", stputil.IsLogin(ctx, token2)) + + // 3. 获取登录ID + fmt.Println("3. 获取登录ID") + loginID1, _ := stputil.GetLoginID(ctx, token1) + loginID2, _ := stputil.GetLoginID(ctx, token2) + fmt.Printf(" Token1的登录ID: %s\n", loginID1) + fmt.Printf(" Token2的登录ID: %s\n\n", loginID2) + + // 4. 权限管理 + fmt.Println("4. 权限管理") + _ = stputil.SetPermissions(ctx, 1000, []string{"user:read", "user:write", "admin:*"}) + fmt.Println(" 已设置权限: user:read, user:write, admin:*") + + fmt.Printf(" 是否有user:read权限: %v\n", stputil.HasPermission(ctx, 1000, "user:read")) + fmt.Printf(" 是否有user:delete权限: %v\n", stputil.HasPermission(ctx, 1000, "user:delete")) + fmt.Printf(" 是否有admin:delete权限(通配符): %v\n\n", stputil.HasPermission(ctx, 1000, "admin:delete")) + + // 5. 角色管理 + fmt.Println("5. 角色管理") + _ = stputil.SetRoles(ctx, 1000, []string{"admin", "manager-example"}) + fmt.Println(" 已设置角色: admin, manager-example") + + fmt.Printf(" 是否有admin角色: %v\n", stputil.HasRole(ctx, 1000, "admin")) + fmt.Printf(" 是否有user角色: %v\n\n", stputil.HasRole(ctx, 1000, "user")) + + // 6. Session管理 + fmt.Println("6. Session管理") + sess, _ := stputil.GetSession(ctx, 1000) + _ = sess.Set(ctx, "nickname", "张三") + _ = sess.Set(ctx, "age", 25) + fmt.Printf(" Session已设置: nickname=%s, age=%d\n", sess.GetString("nickname"), sess.GetInt("age")) + + // 7. 账号封禁 + fmt.Println("\n7. 账号封禁") + _ = stputil.Disable(ctx, "user123", 1*time.Hour) + fmt.Printf(" 用户user123已被封禁1小时\n") + fmt.Printf(" 是否被封禁: %v\n", stputil.IsDisable(ctx, "user123")) + + remainingTime, _ := stputil.GetDisableTime(ctx, "user123") + fmt.Printf(" 剩余封禁时间: %d秒\n", remainingTime) + + // 8. 解封 + _ = stputil.Untie(ctx, "user123") + fmt.Printf(" 已解封,是否被封禁: %v\n\n", stputil.IsDisable(ctx, "user123")) + + // 9. Token信息 + fmt.Println("9. Token信息") + info, _ := stputil.GetTokenInfo(ctx, token1) + fmt.Printf(" 登录ID: %s\n", info.LoginID) + fmt.Printf(" 设备: %s\n", info.Device) + fmt.Printf(" 创建时间: %d\n", info.CreateTime) + fmt.Printf(" 活跃时间: %d\n\n", info.ActiveTime) + + // 10. 登出 + fmt.Println("10. 登出") + _ = stputil.Logout(ctx, 1000) + fmt.Printf(" 用户1000已登出\n") + fmt.Printf(" Token1是否还有效: %v\n", stputil.IsLogin(ctx, token1)) +} diff --git a/examples/quick-start/simple-example/README.md b/examples/quick-start/simple-example/README.md index 5f1e664..a33db99 100644 --- a/examples/quick-start/simple-example/README.md +++ b/examples/quick-start/simple-example/README.md @@ -1,6 +1,6 @@ # 快速开始示例 -这是一个最简单的 Sa-Token-Go 使用示例,展示了如何使用 `StpUtil` 全局工具类快速实现认证和授权功能。 +这是一个最简单的 Sa-Token-Go 使用示例,展示了如何使用 `stputil` 全局工具类快速实现认证和授权功能。 ## 运行示例 @@ -13,59 +13,131 @@ go run main.go 本示例展示了以下功能: 1. **一行初始化** - 使用 Builder 模式快速配置 -2. **登录认证** - 支持多种类型的用户 ID +2. **登录认证** - 支持多种类型的用户 ID(int、string 等) 3. **检查登录** - 验证用户登录状态 -4. **权限管理** - 设置和检查用户权限 +4. **权限管理** - 设置和检查用户权限(支持通配符) 5. **角色管理** - 设置和检查用户角色 6. **Session 管理** - 存储和读取会话数据 -7. **账号封禁** - 临时封禁用户 +7. **账号封禁** - 临时封禁和解封用户 8. **Token 信息** - 查看 Token 详细信息 9. **登出** - 清除用户登录状态 ## 核心代码 ```go +package main + import ( - "github.com/click33/sa-token-go/core" - "github.com/click33/sa-token-go/stputil" + "context" + "fmt" + "time" + + "github.com/click33/sa-token-go/core/adapter" + "github.com/click33/sa-token-go/core/builder" "github.com/click33/sa-token-go/storage/memory" + "github.com/click33/sa-token-go/stputil" ) func init() { - // 🎯 一行初始化! + // 一行初始化 stputil.SetManager( - core.NewBuilder(). - Storage(memory.NewStorage()). + builder.NewBuilder(). + SetStorage(memory.NewStorage()). TokenName("Authorization"). - Timeout(86400). // 24小时 - TokenStyle(core.TokenStyleRandom64). + Timeout(86400). // 24小时 + MaxRefresh(43200). // 12小时 + TokenStyle(adapter.TokenStyleUUID). Build(), ) } func main() { + ctx := context.Background() + // 登录 - token, _ := stputil.Login(1000) - + token, _ := stputil.Login(ctx, 1000) + fmt.Println("Token:", token) + + // 检查登录 + isLogin := stputil.IsLogin(ctx, token) + fmt.Println("是否登录:", isLogin) + + // 获取登录ID + loginID, _ := stputil.GetLoginID(ctx, token) + fmt.Println("登录ID:", loginID) + // 设置权限 - stputil.SetPermissions(1000, []string{"user:read", "user:write"}) - - // 检查权限 - hasPermission := stputil.HasPermission(1000, "user:read") - + _ = stputil.SetPermissions(ctx, 1000, []string{"user:read", "user:write", "admin:*"}) + + // 检查权限(支持通配符匹配) + hasPermission := stputil.HasPermission(ctx, 1000, "user:read") + hasAdminPerm := stputil.HasPermission(ctx, 1000, "admin:delete") // admin:* 匹配 + fmt.Println("有 user:read 权限:", hasPermission) + fmt.Println("有 admin:delete 权限:", hasAdminPerm) + + // 设置角色 + _ = stputil.SetRoles(ctx, 1000, []string{"admin", "manager"}) + + // 检查角色 + hasRole := stputil.HasRole(ctx, 1000, "admin") + fmt.Println("有 admin 角色:", hasRole) + + // Session 管理 + sess, _ := stputil.GetSession(ctx, 1000) + _ = sess.Set(ctx, "nickname", "张三") + fmt.Println("昵称:", sess.GetString("nickname")) + + // 账号封禁 + _ = stputil.Disable(ctx, 1000, 1*time.Hour) + fmt.Println("是否被封禁:", stputil.IsDisable(ctx, 1000)) + + // 解封 + _ = stputil.Untie(ctx, 1000) + // 登出 - stputil.Logout(1000) + _ = stputil.Logout(ctx, 1000) + fmt.Println("登出后是否登录:", stputil.IsLogin(ctx, token)) } ``` +## 重要说明 + +### Context 参数 + +所有 `stputil` 函数都需要 `context.Context` 作为第一个参数: + +```go +ctx := context.Background() + +// 正确用法 +token, _ := stputil.Login(ctx, userID) +isLogin := stputil.IsLogin(ctx, token) +_ = stputil.Logout(ctx, userID) +``` + +### 权限通配符 + +支持使用 `*` 作为通配符匹配权限: + +```go +// 设置权限 +_ = stputil.SetPermissions(ctx, userID, []string{"admin:*"}) + +// admin:* 可以匹配所有 admin: 开头的权限 +stputil.HasPermission(ctx, userID, "admin:read") // true +stputil.HasPermission(ctx, userID, "admin:write") // true +stputil.HasPermission(ctx, userID, "admin:delete") // true +stputil.HasPermission(ctx, userID, "user:read") // false +``` + ## 输出示例 ``` === Sa-Token-Go 简洁使用示例 === 1. 登录测试 - 用户1000登录成功,Token: xxx - 用户user123登录成功,Token: yyy + 用户1000登录成功,Token: a1b2c3d4-e5f6-7890-abcd-ef1234567890 + 用户user123登录成功,Token: b2c3d4e5-f6a7-8901-bcde-f12345678901 2. 检查登录 Token1是否登录: true @@ -81,12 +153,52 @@ func main() { 是否有user:delete权限: false 是否有admin:delete权限(通配符): true -... +5. 角色管理 + 已设置角色: admin, manager + 是否有admin角色: true + 是否有user角色: false + +6. Session管理 + Session已设置: nickname=张三, age=25 + +7. 账号封禁 + 用户user123已被封禁1小时 + 是否被封禁: true + 剩余封禁时间: 3600秒 + 已解封,是否被封禁: false + +8. Token信息 + 登录ID: 1000 + 设备: default + 创建时间: 1703750400 + 活跃时间: 1703750400 + +9. 登出 + 用户1000已登出 + Token1是否还有效: false + +=== 示例完成! === ``` +## 常用函数速查 + +| 函数 | 说明 | +|------|------| +| `stputil.Login(ctx, loginID)` | 用户登录,返回 Token | +| `stputil.Logout(ctx, loginID)` | 用户登出 | +| `stputil.IsLogin(ctx, token)` | 检查是否已登录 | +| `stputil.GetLoginID(ctx, token)` | 获取登录ID | +| `stputil.SetPermissions(ctx, loginID, perms)` | 设置权限 | +| `stputil.HasPermission(ctx, loginID, perm)` | 检查权限 | +| `stputil.SetRoles(ctx, loginID, roles)` | 设置角色 | +| `stputil.HasRole(ctx, loginID, role)` | 检查角色 | +| `stputil.GetSession(ctx, loginID)` | 获取 Session | +| `stputil.Disable(ctx, loginID, duration)` | 封禁账号 | +| `stputil.Untie(ctx, loginID)` | 解封账号 | +| `stputil.Kickout(ctx, loginID)` | 踢人下线 | + ## 扩展学习 - [Gin 集成示例](../../gin/gin-example) - 学习如何在 Gin 框架中使用 -- [注解装饰器示例](../../annotation/annotation-example) - 学习注解式编程 -- [完整文档](../../../docs) - 查看详细的 API 文档 - +- [注解装饰器示例](../../annotation/annotation-example) - 学习中间件装饰器 +- [事件监听示例](../../manager/listener-example) - 学习事件监听机制 diff --git a/examples/quick-start/simple-example/go.mod b/examples/quick-start/simple-example/go.mod index e7532cb..21a9583 100644 --- a/examples/quick-start/simple-example/go.mod +++ b/examples/quick-start/simple-example/go.mod @@ -1,20 +1,16 @@ -module github.com/click33/sa-token-go/examples/simple-example +module github.com/click33/sa-token-go/examples/quick-start/simple-example -go 1.21 +go 1.25.0 require ( - github.com/click33/sa-token-go/core v0.1.3 - github.com/click33/sa-token-go/storage/memory v0.1.3 - github.com/click33/sa-token-go/stputil v0.1.3 + github.com/click33/sa-token-go/core v0.1.7 + github.com/click33/sa-token-go/storage/memory v0.1.7 + github.com/click33/sa-token-go/stputil v0.1.7 ) require ( - github.com/golang-jwt/jwt/v5 v5.2.1 // indirect + github.com/golang-jwt/jwt/v5 v5.3.0 // indirect github.com/google/uuid v1.6.0 // indirect -) - -replace ( - github.com/click33/sa-token-go/core => ../../../core - github.com/click33/sa-token-go/storage/memory => ../../../storage/memory - github.com/click33/sa-token-go/stputil => ../../../stputil + github.com/panjf2000/ants/v2 v2.11.3 // indirect + golang.org/x/sync v0.19.0 // indirect ) diff --git a/examples/quick-start/simple-example/go.sum b/examples/quick-start/simple-example/go.sum new file mode 100644 index 0000000..e029147 --- /dev/null +++ b/examples/quick-start/simple-example/go.sum @@ -0,0 +1,20 @@ +github.com/click33/sa-token-go/core v0.1.6 h1:ELOe0qSH1b3LRsQD3DIBg0e1VgYANKFg5H7z57Lkt/8= +github.com/click33/sa-token-go/core v0.1.6/go.mod h1:mb3AQAJIXqx9WdULyn5qjufK1j/u+kgB0q+tafHVhgk= +github.com/click33/sa-token-go/storage/memory v0.1.6 h1:iGFEy+HtTJLOpKnbIMbgpXyKotsKpPQu6wWTZVOXQis= +github.com/click33/sa-token-go/storage/memory v0.1.6/go.mod h1:YNojcgyLC/uFrmReZLePCDQ5WK2fo2WWGRjRMvXVH74= +github.com/click33/sa-token-go/stputil v0.1.5 h1:603tbI4JkBTg3MnfTj+lCMDxJOKSCOqsMyC2zyuvEco= +github.com/click33/sa-token-go/stputil v0.1.5/go.mod h1:YH+3NLXgGJfrS2wkGubMWFnr/Nk0GgejOtRxcE+9x0c= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/panjf2000/ants/v2 v2.11.3 h1:AfI0ngBoXJmYOpDh9m516vjqoUu2sLrIVgppI9TZVpg= +github.com/panjf2000/ants/v2 v2.11.3/go.mod h1:8u92CYMUc6gyvTIw8Ru7Mt7+/ESnJahz5EVtqfrilek= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/examples/quick-start/simple-example/main.go b/examples/quick-start/simple-example/main.go index efb9fc7..beae5c3 100644 --- a/examples/quick-start/simple-example/main.go +++ b/examples/quick-start/simple-example/main.go @@ -1,22 +1,23 @@ package main import ( + "context" "fmt" "time" - "github.com/click33/sa-token-go/core" + "github.com/click33/sa-token-go/core/adapter" + "github.com/click33/sa-token-go/core/builder" "github.com/click33/sa-token-go/storage/memory" "github.com/click33/sa-token-go/stputil" ) func init() { - // 超简洁初始化(一行搞定!) stputil.SetManager( - core.NewBuilder(). - Storage(memory.NewStorage()). + builder.NewBuilder(). + SetStorage(memory.NewStorage()). TokenName("Authorization"). - Timeout(86400). // 24小时 - TokenStyle(core.TokenStyleRandom64). + Timeout(86400). + TokenStyle(adapter.TokenStyleUUID). Build(), ) } @@ -24,66 +25,68 @@ func init() { func main() { fmt.Println("=== Sa-Token-Go 简洁使用示例 ===\n") + ctx := context.Background() + // 1. 登录(支持多种类型) fmt.Println("1. 登录测试") - token1, _ := stputil.Login(1000) + token1, _ := stputil.Login(ctx, 1000) fmt.Printf(" 用户1000登录成功,Token: %s\n", token1) - token2, _ := stputil.Login("user123") + token2, _ := stputil.Login(ctx, "user123") fmt.Printf(" 用户user123登录成功,Token: %s\n\n", token2) // 2. 检查登录 fmt.Println("2. 检查登录") - fmt.Printf(" Token1是否登录: %v\n", stputil.IsLogin(token1)) - fmt.Printf(" Token2是否登录: %v\n\n", stputil.IsLogin(token2)) + fmt.Printf(" Token1是否登录: %v\n", stputil.IsLogin(ctx, token1)) + fmt.Printf(" Token2是否登录: %v\n\n", stputil.IsLogin(ctx, token2)) // 3. 获取登录ID fmt.Println("3. 获取登录ID") - loginID1, _ := stputil.GetLoginID(token1) - loginID2, _ := stputil.GetLoginID(token2) + loginID1, _ := stputil.GetLoginID(ctx, token1) + loginID2, _ := stputil.GetLoginID(ctx, token2) fmt.Printf(" Token1的登录ID: %s\n", loginID1) fmt.Printf(" Token2的登录ID: %s\n\n", loginID2) // 4. 权限管理 fmt.Println("4. 权限管理") - stputil.SetPermissions(1000, []string{"user:read", "user:write", "admin:*"}) + _ = stputil.SetPermissions(ctx, 1000, []string{"user:read", "user:write", "admin:*"}) fmt.Println(" 已设置权限: user:read, user:write, admin:*") - fmt.Printf(" 是否有user:read权限: %v\n", stputil.HasPermission(1000, "user:read")) - fmt.Printf(" 是否有user:delete权限: %v\n", stputil.HasPermission(1000, "user:delete")) - fmt.Printf(" 是否有admin:delete权限(通配符): %v\n\n", stputil.HasPermission(1000, "admin:delete")) + fmt.Printf(" 是否有user:read权限: %v\n", stputil.HasPermission(ctx, 1000, "user:read")) + fmt.Printf(" 是否有user:delete权限: %v\n", stputil.HasPermission(ctx, 1000, "user:delete")) + fmt.Printf(" 是否有admin:delete权限(通配符): %v\n\n", stputil.HasPermission(ctx, 1000, "admin:delete")) // 5. 角色管理 fmt.Println("5. 角色管理") - stputil.SetRoles(1000, []string{"admin", "manager"}) - fmt.Println(" 已设置角色: admin, manager") + _ = stputil.SetRoles(ctx, 1000, []string{"admin", "manager-example"}) + fmt.Println(" 已设置角色: admin, manager-example") - fmt.Printf(" 是否有admin角色: %v\n", stputil.HasRole(1000, "admin")) - fmt.Printf(" 是否有user角色: %v\n\n", stputil.HasRole(1000, "user")) + fmt.Printf(" 是否有admin角色: %v\n", stputil.HasRole(ctx, 1000, "admin")) + fmt.Printf(" 是否有user角色: %v\n\n", stputil.HasRole(ctx, 1000, "user")) // 6. Session管理 fmt.Println("6. Session管理") - sess, _ := stputil.GetSession(1000) - sess.Set("nickname", "张三") - sess.Set("age", 25) + sess, _ := stputil.GetSession(ctx, 1000) + _ = sess.Set(ctx, "nickname", "张三") + _ = sess.Set(ctx, "age", 25) fmt.Printf(" Session已设置: nickname=%s, age=%d\n", sess.GetString("nickname"), sess.GetInt("age")) // 7. 账号封禁 fmt.Println("\n7. 账号封禁") - stputil.Disable("user123", 1*time.Hour) + _ = stputil.Disable(ctx, "user123", 1*time.Hour) fmt.Printf(" 用户user123已被封禁1小时\n") - fmt.Printf(" 是否被封禁: %v\n", stputil.IsDisable("user123")) + fmt.Printf(" 是否被封禁: %v\n", stputil.IsDisable(ctx, "user123")) - remainingTime, _ := stputil.GetDisableTime("user123") + remainingTime, _ := stputil.GetDisableTime(ctx, "user123") fmt.Printf(" 剩余封禁时间: %d秒\n", remainingTime) // 8. 解封 - stputil.Untie("user123") - fmt.Printf(" 已解封,是否被封禁: %v\n\n", stputil.IsDisable("user123")) + _ = stputil.Untie(ctx, "user123") + fmt.Printf(" 已解封,是否被封禁: %v\n\n", stputil.IsDisable(ctx, "user123")) // 9. Token信息 fmt.Println("9. Token信息") - info, _ := stputil.GetTokenInfo(token1) + info, _ := stputil.GetTokenInfo(ctx, token1) fmt.Printf(" 登录ID: %s\n", info.LoginID) fmt.Printf(" 设备: %s\n", info.Device) fmt.Printf(" 创建时间: %d\n", info.CreateTime) @@ -91,9 +94,9 @@ func main() { // 10. 登出 fmt.Println("10. 登出") - stputil.Logout(1000) + _ = stputil.Logout(ctx, 1000) fmt.Printf(" 用户1000已登出\n") - fmt.Printf(" Token1是否还有效: %v\n", stputil.IsLogin(token1)) + fmt.Printf(" Token1是否还有效: %v\n", stputil.IsLogin(ctx, token1)) fmt.Println("\n=== 示例完成! ===") } diff --git a/examples/redis-example/README.md b/examples/redis-example/README.md deleted file mode 100644 index 71f191b..0000000 --- a/examples/redis-example/README.md +++ /dev/null @@ -1,97 +0,0 @@ -# Redis Storage Example - -[中文说明](README_zh.md) | English - -This example demonstrates how to use Redis as the storage backend for Sa-Token-Go. - -## Prerequisites - -- Redis server running on `localhost:6379` (or set `REDIS_ADDR` environment variable) -- Go 1.21 or higher - -## Install Redis - -### macOS -```bash -brew install redis -brew services start redis -``` - -### Linux (Ubuntu/Debian) -```bash -sudo apt-get install redis-server -sudo systemctl start redis -``` - -### Docker -```bash -docker run -d -p 6379:6379 redis:7-alpine -``` - -## Run Example - -```bash -# Without password -go run main.go - -# With password -REDIS_PASSWORD=your-password go run main.go - -# Custom Redis address -REDIS_ADDR=redis.example.com:6379 go run main.go -``` - -## Key Features Demonstrated - -1. ✅ **Redis Connection** - Connect to Redis with go-redis -2. ✅ **Authentication** - Login/Logout with Redis storage -3. ✅ **Permission Management** - Store permissions in Redis -4. ✅ **Role Management** - Store roles in Redis -5. ✅ **Session Management** - Persistent session data -6. ✅ **Data Persistence** - Data survives application restarts - -## Environment Variables - -| Variable | Description | Default | -|----------|-------------|---------| -| `REDIS_ADDR` | Redis server address | `localhost:6379` | -| `REDIS_PASSWORD` | Redis password | (empty) | -| `REDIS_DB` | Redis database number | `0` | - -## View Data in Redis - -```bash -# Connect to Redis CLI -redis-cli - -# List all Sa-Token keys -KEYS satoken:* - -# View token info -GET satoken:login:token:{your-token} - -# View session data -GET satoken:session:1000 - -# View permissions -SMEMBERS satoken:permission:1000 - -# View roles -SMEMBERS satoken:role:1000 -``` - -## Production Deployment - -See [Redis Storage Guide](../../docs/guide/redis-storage.md) for: -- Connection pool configuration -- High availability (Sentinel) -- Cluster mode -- TLS/SSL support -- Docker/Kubernetes deployment - -## Related Documentation - -- [Redis Storage Guide](../../docs/guide/redis-storage.md) -- [Quick Start](../../docs/tutorial/quick-start.md) -- [Authentication Guide](../../docs/guide/authentication.md) - diff --git a/examples/redis-example/README_zh.md b/examples/redis-example/README_zh.md deleted file mode 100644 index 7f28d5c..0000000 --- a/examples/redis-example/README_zh.md +++ /dev/null @@ -1,100 +0,0 @@ -# Redis 存储示例 - -[English](README.md) | 中文说明 - -本示例演示如何使用 Redis 作为 Sa-Token-Go 的存储后端。 - -## 前置要求 - -- Redis 服务器运行在 `localhost:6379`(或设置 `REDIS_ADDR` 环境变量) -- Go 1.21 或更高版本 - -## 安装 Redis - -### macOS - -```bash -brew install redis -brew services start redis -``` - -### Linux (Ubuntu/Debian) - -```bash -sudo apt-get install redis-server -sudo systemctl start redis -``` - -### Docker - -```bash -docker run -d -p 6379:6379 redis:7-alpine -``` - -## 运行示例 - -```bash -# 无密码 -go run main.go - -# 带密码 -REDIS_PASSWORD=your-password go run main.go - -# 自定义 Redis 地址 -REDIS_ADDR=redis.example.com:6379 go run main.go -``` - -## 演示的核心功能 - -1. ✅ **Redis 连接** - 使用 go-redis 连接 Redis -2. ✅ **认证功能** - 使用 Redis 存储进行登录/登出 -3. ✅ **权限管理** - 在 Redis 中存储权限 -4. ✅ **角色管理** - 在 Redis 中存储角色 -5. ✅ **Session 管理** - 持久化的 Session 数据 -6. ✅ **数据持久化** - 数据在应用重启后仍然存在 - -## 环境变量 - -| 变量 | 说明 | 默认值 | -|------|------|--------| -| `REDIS_ADDR` | Redis 服务器地址 | `localhost:6379` | -| `REDIS_PASSWORD` | Redis 密码 | (空) | -| `REDIS_DB` | Redis 数据库编号 | `0` | - -## 在 Redis 中查看数据 - -```bash -# 连接到 Redis CLI -redis-cli - -# 列出所有 Sa-Token 键 -KEYS satoken:* - -# 查看 Token 信息 -GET satoken:login:token:{your-token} - -# 查看 Session 数据 -GET satoken:session:1000 - -# 查看权限 -SMEMBERS satoken:permission:1000 - -# 查看角色 -SMEMBERS satoken:role:1000 -``` - -## 生产环境部署 - -查看 [Redis 存储指南](../../docs/guide/redis-storage_zh.md) 了解: - -- 连接池配置 -- 高可用(哨兵模式) -- 集群模式 -- TLS/SSL 支持 -- Docker/Kubernetes 部署 - -## 相关文档 - -- [Redis 存储指南](../../docs/guide/redis-storage_zh.md) -- [快速开始](../../docs/tutorial/quick-start.md) -- [认证指南](../../docs/guide/authentication.md) diff --git a/examples/redis-example/go.mod b/examples/redis-example/go.mod deleted file mode 100644 index 4020f16..0000000 --- a/examples/redis-example/go.mod +++ /dev/null @@ -1,23 +0,0 @@ -module github.com/click33/sa-token-go/examples/redis-example - -go 1.21 - -require ( - github.com/click33/sa-token-go/core v0.1.3 - github.com/click33/sa-token-go/storage/redis v0.1.3 - github.com/click33/sa-token-go/stputil v0.1.3 - github.com/redis/go-redis/v9 v9.5.1 -) - -require ( - github.com/cespare/xxhash/v2 v2.2.0 // indirect - github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect - github.com/golang-jwt/jwt/v5 v5.2.1 // indirect - github.com/google/uuid v1.6.0 // indirect -) - -replace ( - github.com/click33/sa-token-go/core => ../../core - github.com/click33/sa-token-go/storage/redis => ../../storage/redis - github.com/click33/sa-token-go/stputil => ../../stputil -) diff --git a/examples/redis-example/main.go b/examples/redis-example/main.go deleted file mode 100644 index eecc30e..0000000 --- a/examples/redis-example/main.go +++ /dev/null @@ -1,135 +0,0 @@ -package main - -import ( - "context" - "fmt" - "log" - "os" - - "github.com/click33/sa-token-go/core" - "github.com/click33/sa-token-go/storage/redis" - "github.com/click33/sa-token-go/stputil" - goredis "github.com/redis/go-redis/v9" -) - -func main() { - fmt.Println("=== Sa-Token-Go Redis Storage Example ===") - - // Get Redis configuration from environment variables | 从环境变量获取 Redis 配置 - redisAddr := os.Getenv("REDIS_ADDR") - if redisAddr == "" { - redisAddr = "localhost:6379" - } - redisPassword := os.Getenv("REDIS_PASSWORD") - - // Create Redis client | 创建 Redis 客户端 - rdb := goredis.NewClient(&goredis.Options{ - Addr: redisAddr, - Password: redisPassword, - DB: 0, - PoolSize: 10, - }) - - // Test Redis connection | 测试 Redis 连接 - ctx := context.Background() - if err := rdb.Ping(ctx).Err(); err != nil { - log.Fatalf("❌ Failed to connect to Redis: %v\n", err) - } - fmt.Printf("✅ Connected to Redis: %s\n\n", redisAddr) - - // Initialize Sa-Token with Redis storage | 使用 Redis 存储初始化 Sa-Token - redisURL := fmt.Sprintf("redis://:%s@%s/0", redisPassword, redisAddr) - redisStorage, err := redis.NewStorage(redisURL) // Storage 层不处理前缀,符合 Java sa-token 设计 - if err != nil { - log.Fatalf("❌ Failed to create Redis storage: %v\n", err) - } - - // 创建 Manager(符合 Java sa-token 标准设计) - stputil.SetManager( - core.NewBuilder(). - Storage(redisStorage). - TokenName("Authorization"). - TokenStyle(core.TokenStyleRandom64). - Timeout(3600). // 1 hour | 1小时 - KeyPrefix("satoken"). // 设计开头标识 - IsPrintBanner(true). - Build(), - ) - - fmt.Println("📌 当前配置(符合 Java sa-token 标准):") - fmt.Println(" - Storage 层前缀: \"\" (空)") - fmt.Println(" - Manager 层前缀: \"satoken\" → 自动变为 \"satoken:\"") - fmt.Println(" - Redis Key 示例: satoken:login:token:xxx") - fmt.Println(" - ✅ 完全兼容 Java sa-token") - fmt.Println() - - // Test authentication | 测试认证功能 - fmt.Println("1. Login user | 登录用户") - token, err := stputil.Login(1000) - if err != nil { - log.Fatalf("Login failed: %v\n", err) - } - fmt.Printf("✅ Login successful! Token: %s\n\n", token) - - // Check login status | 检查登录状态 - fmt.Println("2. Check login status | 检查登录状态") - if stputil.IsLogin(token) { - fmt.Println("✅ User is logged in") - } - - // Set permissions and roles | 设置权限和角色 - fmt.Println("3. Set permissions and roles | 设置权限和角色") - stputil.SetPermissions(1000, []string{"user:read", "user:write", "admin:*"}) - stputil.SetRoles(1000, []string{"admin", "user"}) - fmt.Println("✅ Permissions and roles set") - - // Check permission | 检查权限 - fmt.Println("4. Check permissions | 检查权限") - if stputil.HasPermission(1000, "user:read") { - fmt.Println("✅ Has permission: user:read") - } - if stputil.HasPermission(1000, "admin:delete") { - fmt.Println("✅ Has permission: admin:delete (wildcard match)") - } - fmt.Println() - - // Check role | 检查角色 - fmt.Println("5. Check roles | 检查角色") - if stputil.HasRole(1000, "admin") { - fmt.Println("✅ Has role: admin") - } - fmt.Println() - - // Get session | 获取 Session - fmt.Println("6. Session management | Session 管理") - sess, _ := stputil.GetSession(1000) - sess.Set("username", "admin") - sess.Set("email", "admin@example.com") - fmt.Println("✅ Session data saved") - - username := sess.GetString("username") - fmt.Printf(" Username: %s\n\n", username) - - // Logout | 登出 - fmt.Println("7. Logout | 登出") - // stputil.Logout(1000) - fmt.Println("✅ User logged out") - - if !stputil.IsLogin(token) { - fmt.Println("✅ Token is now invalid") - } - - // Close Redis connection | 关闭 Redis 连接 - defer func() { - if err := rdb.Close(); err != nil { - log.Printf("Error closing Redis: %v\n", err) - } - }() - - fmt.Println("=== Redis Example Completed ===") - fmt.Println("\n💡 Tips:") - fmt.Println(" • Data is persisted in Redis") - fmt.Println(" • Survives application restarts") - fmt.Println(" • Suitable for production environments") - fmt.Println(" • Supports distributed deployments") -} diff --git a/examples/security-features/go.mod b/examples/security-features/go.mod deleted file mode 100644 index 387f4e0..0000000 --- a/examples/security-features/go.mod +++ /dev/null @@ -1,20 +0,0 @@ -module github.com/click33/sa-token-go/examples/security-features - -go 1.21 - -require ( - github.com/click33/sa-token-go/core v0.1.3 - github.com/click33/sa-token-go/storage/memory v0.1.3 - github.com/click33/sa-token-go/stputil v0.1.3 -) - -require ( - github.com/golang-jwt/jwt/v5 v5.2.1 // indirect - github.com/google/uuid v1.6.0 // indirect -) - -replace ( - github.com/click33/sa-token-go/core => ../../core - github.com/click33/sa-token-go/storage/memory => ../../storage/memory - github.com/click33/sa-token-go/stputil => ../../stputil -) diff --git a/examples/security-features/main.go b/examples/security-features/main.go deleted file mode 100644 index f26469e..0000000 --- a/examples/security-features/main.go +++ /dev/null @@ -1,128 +0,0 @@ -package main - -import ( - "fmt" - "time" - - "github.com/click33/sa-token-go/core" - "github.com/click33/sa-token-go/storage/memory" - "github.com/click33/sa-token-go/stputil" -) - -func main() { - storage := memory.NewStorage() - manager := core.NewBuilder(). - Storage(storage). - Timeout(3600). - IsPrintBanner(false). - Build() - - stputil.SetManager(manager) - - demoNonce(manager) - fmt.Println() - demoRefreshToken(manager) - fmt.Println() - demoOAuth2(manager) -} - -func demoNonce(manager *core.Manager) { - fmt.Println("=== Nonce Anti-Replay Demo ===") - - nonce, err := manager.GenerateNonce() - if err != nil { - fmt.Printf("Error generating nonce: %v\n", err) - return - } - fmt.Printf("Generated Nonce: %s\n", nonce) - - valid := manager.VerifyNonce(nonce) - fmt.Printf("First verification: %v (should be true)\n", valid) - - valid = manager.VerifyNonce(nonce) - fmt.Printf("Second verification: %v (should be false - replay attack prevented)\n", valid) -} - -func demoRefreshToken(manager *core.Manager) { - fmt.Println("=== Refresh Token Demo ===") - - tokenInfo, err := manager.LoginWithRefreshToken("user1000", "web") - if err != nil { - fmt.Printf("Error: %v\n", err) - return - } - - fmt.Printf("Access Token: %s\n", tokenInfo.AccessToken[:40]+"...") - fmt.Printf("Refresh Token: %s\n", tokenInfo.RefreshToken[:40]+"...") - fmt.Printf("Expires at: %s\n", time.Unix(tokenInfo.ExpireTime, 0).Format(time.RFC3339)) - - fmt.Println("\nRefreshing access token...") - newTokenInfo, err := manager.RefreshAccessToken(tokenInfo.RefreshToken) - if err != nil { - fmt.Printf("Error: %v\n", err) - return - } - - fmt.Printf("New Access Token: %s\n", newTokenInfo.AccessToken[:40]+"...") - fmt.Printf("Same Refresh Token: %v\n", newTokenInfo.RefreshToken == tokenInfo.RefreshToken) -} - -func demoOAuth2(manager *core.Manager) { - fmt.Println("=== OAuth2 Authorization Code Flow Demo ===") - - oauth2Server := manager.GetOAuth2Server() - - client := &core.OAuth2Client{ - ClientID: "webapp123", - ClientSecret: "secret456", - RedirectURIs: []string{"http://localhost:8080/callback"}, - GrantTypes: []core.OAuth2GrantType{core.GrantTypeAuthorizationCode, core.GrantTypeRefreshToken}, - Scopes: []string{"read", "write"}, - } - oauth2Server.RegisterClient(client) - fmt.Println("Client registered") - - authCode, err := oauth2Server.GenerateAuthorizationCode( - "webapp123", - "http://localhost:8080/callback", - "user1000", - []string{"read", "write"}, - ) - if err != nil { - fmt.Printf("Error: %v\n", err) - return - } - fmt.Printf("Authorization Code: %s\n", authCode.Code[:20]+"...") - - accessToken, err := oauth2Server.ExchangeCodeForToken( - authCode.Code, - "webapp123", - "secret456", - "http://localhost:8080/callback", - ) - if err != nil { - fmt.Printf("Error: %v\n", err) - return - } - - fmt.Printf("Access Token: %s\n", accessToken.Token[:20]+"...") - fmt.Printf("Token Type: %s\n", accessToken.TokenType) - fmt.Printf("Expires In: %d seconds\n", accessToken.ExpiresIn) - fmt.Printf("Refresh Token: %s\n", accessToken.RefreshToken[:20]+"...") - fmt.Printf("Scopes: %v\n", accessToken.Scopes) - - validated, err := oauth2Server.ValidateAccessToken(accessToken.Token) - if err != nil { - fmt.Printf("Validation error: %v\n", err) - return - } - fmt.Printf("Token validated for user: %s\n", validated.UserID) - - fmt.Println("\nRefreshing OAuth2 token...") - newToken, err := oauth2Server.RefreshAccessToken(accessToken.RefreshToken, "webapp123", "secret456") - if err != nil { - fmt.Printf("Error: %v\n", err) - return - } - fmt.Printf("New Access Token: %s\n", newToken.Token[:20]+"...") -} diff --git a/examples/session-demo/go.mod b/examples/session-demo/go.mod deleted file mode 100644 index 481560e..0000000 --- a/examples/session-demo/go.mod +++ /dev/null @@ -1,15 +0,0 @@ -module github.com/click33/sa-token-go/examples/session-demo - -go 1.21 - -require ( - github.com/click33/sa-token-go/core v0.1.3 - github.com/click33/sa-token-go/storage/memory v0.1.3 - github.com/click33/sa-token-go/stputil v0.1.3 -) - -replace ( - github.com/click33/sa-token-go/core => ../../core - github.com/click33/sa-token-go/storage/memory => ../../storage/memory - github.com/click33/sa-token-go/stputil => ../../stputil -) diff --git a/examples/token-styles/README.md b/examples/token-styles/README.md deleted file mode 100644 index cdba381..0000000 --- a/examples/token-styles/README.md +++ /dev/null @@ -1,222 +0,0 @@ -English | [中文文档](README_zh.md) - -# Token Styles Example - -This example demonstrates all available token generation styles in Sa-Token-Go. - -## Available Token Styles - -### 1. UUID Style (`uuid`) -``` -e.g., 550e8400-e29b-41d4-a716-446655440000 -``` -- Standard UUID v4 format -- 36 characters (including hyphens) -- Globally unique - -### 2. Simple Style (`simple`) -``` -e.g., aB3dE5fG7hI9jK1l -``` -- 16-character random string -- Base64 URL-safe encoding -- Compact and simple - -### 3. Random32 Style (`random32`) -``` -e.g., aB3dE5fG7hI9jK1lMnO2pQ4rS6tU8vW0 -``` -- 32-character random string -- High randomness -- Secure and unique - -### 4. Random64 Style (`random64`) -``` -e.g., aB3dE5fG7hI9jK1lMnO2pQ4rS6tU8vW0xY1zA2bC3dD4eE5fF6gG7hH8iI9jJ0kK1l -``` -- 64-character random string -- Maximum randomness -- Extra secure - -### 5. Random128 Style (`random128`) -``` -e.g., aB3dE5fG7hI9jK1lMnO2pQ4rS6tU8vW0xY1zA2bC3dD4eE5fF6gG7hH8iI9jJ0kK1lMmN2nO3oP4pQ5qR6rS7sT8tU9uV0vW1wX2xY3yZ4zA5aB6bC7cD8dE9eF0fG1gH2hI3iJ4jK5kL6lM7mN8nO9oP0 -``` -- 128-character random string -- Extremely secure -- For high-security scenarios - -### 6. JWT Style (`jwt`) -``` -e.g., eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkZXZpY2UiOiJkZWZhdWx0IiwiaWF0IjoxNzAwMDAwMDAwLCJsb2dpbklkIjoidXNlcjEwMDAifQ.xxx -``` -- Standard JWT format -- Contains claims (loginId, device, iat, exp) -- Self-contained and verifiable -- Requires `JwtSecretKey` configuration - -### 7. Hash Style (`hash`) 🆕 -``` -e.g., a3f5d8b2c1e4f6a9d7b8c5e2f1a4d6b9c8e5f2a7d4b1c9e6f3a8d5b2c1e7f4a6 -``` -- SHA256 hash-based token -- Combines loginID, device, timestamp, and random data -- 64-character hexadecimal -- High security and unpredictability - -### 8. Timestamp Style (`timestamp`) 🆕 -``` -e.g., 1700000000123_user1000_a3f5d8b2c1e4f6a9 -``` -- Format: `timestamp_loginID_random` -- Millisecond precision timestamp -- Easily traceable creation time -- Good for debugging and logging - -### 9. Tik Style (`tik`) 🆕 -``` -e.g., 7Kx9mN2pQr4 -``` -- Short ID format (11 characters) -- Similar to TikTok/Douyin style -- Alphanumeric characters (0-9, A-Z, a-z) -- Perfect for URL shortening and sharing - -## Quick Start - -### Installation - -```bash -go get github.com/click33/sa-token-go/core -go get github.com/click33/sa-token-go/stputil -go get github.com/click33/sa-token-go/storage/memory -``` - -### Run the Example - -```bash -cd examples/token-styles -go run main.go -``` - -### Output - -``` -Sa-Token-Go Token Styles Demo -======================================== - -📌 UUID Style (uuid) ----------------------------------------- - 1. Token for user1001: - 550e8400-e29b-41d4-a716-446655440000 - 2. Token for user1002: - f47ac10b-58cc-4372-a567-0e02b2c3d479 - 3. Token for user1003: - 7c9e6679-7425-40de-944b-e07fc1f90ae7 - -📌 Hash Style (SHA256) (hash) ----------------------------------------- - 1. Token for user1001: - a3f5d8b2c1e4f6a9d7b8c5e2f1a4d6b9c8e5f2a7d4b1c9e6f3a8d5b2c1e7f4a6 - 2. Token for user1002: - b4f6d9c3d2e5f7b0e8c9d6f3e2b5d7c0d9f6e3b8d5c2e0f7d4b9c3e8f6b3d2f5 - 3. Token for user1003: - c5f7e0d4e3f6e8c1f9d0e7f4e3c6e8d1e0f7f4c9e6d3f1e8e5c0e9f7c4e3f6e7 - -📌 Timestamp Style (timestamp) ----------------------------------------- - 1. Token for user1001: - 1700000000123_user1001_a3f5d8b2c1e4f6a9 - 2. Token for user1002: - 1700000000456_user1002_b4f6d9c3d2e5f7b0 - 3. Token for user1003: - 1700000000789_user1003_c5f7e0d4e3f6e8c1 - -📌 Tik Style (Short ID) (tik) ----------------------------------------- - 1. Token for user1001: - 7Kx9mN2pQr4 - 2. Token for user1002: - 8Ly0oO3qRs5 - 3. Token for user1003: - 9Mz1pP4rSt6 - -======================================== -✅ All token styles demonstrated! -``` - -## Usage in Your Project - -### Using Hash Style - -```go -import ( - "github.com/click33/sa-token-go/core" - "github.com/click33/sa-token-go/stputil" - "github.com/click33/sa-token-go/storage/memory" -) - -func init() { - stputil.SetManager( - core.NewBuilder(). - Storage(memory.NewStorage()). - TokenStyle(core.TokenStyleHash). // SHA256 hash style - Timeout(86400). - Build(), - ) -} - -func main() { - token, _ := stputil.Login(1000) - // token: a3f5d8b2c1e4f6a9d7b8c5e2f1a4d6b9c8e5f2a7d4b1c9e6f3a8d5b2c1e7f4a6 -} -``` - -### Using Timestamp Style - -```go -stputil.SetManager( - core.NewBuilder(). - Storage(memory.NewStorage()). - TokenStyle(core.TokenStyleTimestamp). // Timestamp style - Timeout(86400). - Build(), -) - -token, _ := stputil.Login(1000) -// token: 1700000000123_1000_a3f5d8b2c1e4f6a9 -``` - -### Using Tik Style - -```go -stputil.SetManager( - core.NewBuilder(). - Storage(memory.NewStorage()). - TokenStyle(core.TokenStyleTik). // Short ID style - Timeout(86400). - Build(), -) - -token, _ := stputil.Login(1000) -// token: 7Kx9mN2pQr4 -``` - -## Use Cases - -| Style | Best For | Pros | Cons | -|-------|----------|------|------| -| **UUID** | General purpose | Standard, widely supported | Longer | -| **Simple** | Internal APIs | Compact | Less entropy | -| **Random32/64/128** | High security | Very random | Longer strings | -| **JWT** | Stateless auth | Self-contained | Larger size | -| **Hash** 🆕 | Secure tracking | High security, deterministic | 64 chars | -| **Timestamp** 🆕 | Debugging, auditing | Time-traceable | Exposes creation time | -| **Tik** 🆕 | URL sharing, short links | Very short, user-friendly | Lower entropy | - -## Next Steps - -- [Quick Start Guide](../quick-start/) -- [JWT Example](../jwt-example/) -- [Full Documentation](../../docs/) - diff --git a/examples/token-styles/README_zh.md b/examples/token-styles/README_zh.md deleted file mode 100644 index 103a43c..0000000 --- a/examples/token-styles/README_zh.md +++ /dev/null @@ -1,222 +0,0 @@ -[English](README.md) | 中文文档 - -# Token 风格示例 - -本示例演示 Sa-Token-Go 中所有可用的 Token 生成风格。 - -## 可用的 Token 风格 - -### 1. UUID 风格 (`uuid`) -``` -例如:550e8400-e29b-41d4-a716-446655440000 -``` -- 标准 UUID v4 格式 -- 36 个字符(包含连字符) -- 全局唯一 - -### 2. 简单风格 (`simple`) -``` -例如:aB3dE5fG7hI9jK1l -``` -- 16 字符随机字符串 -- Base64 URL 安全编码 -- 紧凑简单 - -### 3. Random32 风格 (`random32`) -``` -例如:aB3dE5fG7hI9jK1lMnO2pQ4rS6tU8vW0 -``` -- 32 字符随机字符串 -- 高随机性 -- 安全且唯一 - -### 4. Random64 风格 (`random64`) -``` -例如:aB3dE5fG7hI9jK1lMnO2pQ4rS6tU8vW0xY1zA2bC3dD4eE5fF6gG7hH8iI9jJ0kK1l -``` -- 64 字符随机字符串 -- 最大随机性 -- 超级安全 - -### 5. Random128 风格 (`random128`) -``` -例如:aB3dE5fG7hI9jK1lMnO2pQ4rS6tU8vW0xY1zA2bC3dD4eE5fF6gG7hH8iI9jJ0kK1lMmN2nO3oP4pQ5qR6rS7sT8tU9uV0vW1wX2xY3yZ4zA5aB6bC7cD8dE9eF0fG1gH2hI3iJ4jK5kL6lM7mN8nO9oP0 -``` -- 128 字符随机字符串 -- 极度安全 -- 用于高安全性场景 - -### 6. JWT 风格 (`jwt`) -``` -例如:eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJkZXZpY2UiOiJkZWZhdWx0IiwiaWF0IjoxNzAwMDAwMDAwLCJsb2dpbklkIjoidXNlcjEwMDAifQ.xxx -``` -- 标准 JWT 格式 -- 包含声明(loginId, device, iat, exp) -- 自包含且可验证 -- 需要配置 `JwtSecretKey` - -### 7. 哈希风格 (`hash`) 🆕 -``` -例如:a3f5d8b2c1e4f6a9d7b8c5e2f1a4d6b9c8e5f2a7d4b1c9e6f3a8d5b2c1e7f4a6 -``` -- 基于 SHA256 哈希的 Token -- 组合 loginID、device、时间戳和随机数据 -- 64 字符十六进制 -- 高安全性和不可预测性 - -### 8. 时间戳风格 (`timestamp`) 🆕 -``` -例如:1700000000123_user1000_a3f5d8b2c1e4f6a9 -``` -- 格式:`时间戳_loginID_随机数` -- 毫秒精度时间戳 -- 易于追溯创建时间 -- 便于调试和日志记录 - -### 9. Tik 风格 (`tik`) 🆕 -``` -例如:7Kx9mN2pQr4 -``` -- 短 ID 格式(11 字符) -- 类似抖音/TikTok 风格 -- 字母数字字符(0-9, A-Z, a-z) -- 适合 URL 缩短和分享 - -## 快速开始 - -### 安装 - -```bash -go get github.com/click33/sa-token-go/core -go get github.com/click33/sa-token-go/stputil -go get github.com/click33/sa-token-go/storage/memory -``` - -### 运行示例 - -```bash -cd examples/token-styles -go run main.go -``` - -### 输出 - -``` -Sa-Token-Go Token Styles Demo -======================================== - -📌 UUID Style (uuid) ----------------------------------------- - 1. Token for user1001: - 550e8400-e29b-41d4-a716-446655440000 - 2. Token for user1002: - f47ac10b-58cc-4372-a567-0e02b2c3d479 - 3. Token for user1003: - 7c9e6679-7425-40de-944b-e07fc1f90ae7 - -📌 Hash Style (SHA256) (hash) ----------------------------------------- - 1. Token for user1001: - a3f5d8b2c1e4f6a9d7b8c5e2f1a4d6b9c8e5f2a7d4b1c9e6f3a8d5b2c1e7f4a6 - 2. Token for user1002: - b4f6d9c3d2e5f7b0e8c9d6f3e2b5d7c0d9f6e3b8d5c2e0f7d4b9c3e8f6b3d2f5 - 3. Token for user1003: - c5f7e0d4e3f6e8c1f9d0e7f4e3c6e8d1e0f7f4c9e6d3f1e8e5c0e9f7c4e3f6e7 - -📌 Timestamp Style (timestamp) ----------------------------------------- - 1. Token for user1001: - 1700000000123_user1001_a3f5d8b2c1e4f6a9 - 2. Token for user1002: - 1700000000456_user1002_b4f6d9c3d2e5f7b0 - 3. Token for user1003: - 1700000000789_user1003_c5f7e0d4e3f6e8c1 - -📌 Tik Style (Short ID) (tik) ----------------------------------------- - 1. Token for user1001: - 7Kx9mN2pQr4 - 2. Token for user1002: - 8Ly0oO3qRs5 - 3. Token for user1003: - 9Mz1pP4rSt6 - -======================================== -✅ All token styles demonstrated! -``` - -## 在项目中使用 - -### 使用哈希风格 - -```go -import ( - "github.com/click33/sa-token-go/core" - "github.com/click33/sa-token-go/stputil" - "github.com/click33/sa-token-go/storage/memory" -) - -func init() { - stputil.SetManager( - core.NewBuilder(). - Storage(memory.NewStorage()). - TokenStyle(core.TokenStyleHash). // SHA256 哈希风格 - Timeout(86400). - Build(), - ) -} - -func main() { - token, _ := stputil.Login(1000) - // token: a3f5d8b2c1e4f6a9d7b8c5e2f1a4d6b9c8e5f2a7d4b1c9e6f3a8d5b2c1e7f4a6 -} -``` - -### 使用时间戳风格 - -```go -stputil.SetManager( - core.NewBuilder(). - Storage(memory.NewStorage()). - TokenStyle(core.TokenStyleTimestamp). // 时间戳风格 - Timeout(86400). - Build(), -) - -token, _ := stputil.Login(1000) -// token: 1700000000123_1000_a3f5d8b2c1e4f6a9 -``` - -### 使用 Tik 风格 - -```go -stputil.SetManager( - core.NewBuilder(). - Storage(memory.NewStorage()). - TokenStyle(core.TokenStyleTik). // 短 ID 风格 - Timeout(86400). - Build(), -) - -token, _ := stputil.Login(1000) -// token: 7Kx9mN2pQr4 -``` - -## 使用场景 - -| 风格 | 最适用于 | 优点 | 缺点 | -|------|----------|------|------| -| **UUID** | 通用场景 | 标准、广泛支持 | 较长 | -| **Simple** | 内部 API | 紧凑 | 熵值较低 | -| **Random32/64/128** | 高安全性 | 随机性强 | 字符串较长 | -| **JWT** | 无状态认证 | 自包含 | 体积较大 | -| **Hash** 🆕 | 安全追踪 | 高安全性、确定性 | 64 字符 | -| **Timestamp** 🆕 | 调试、审计 | 可追溯时间 | 暴露创建时间 | -| **Tik** 🆕 | URL 分享、短链接 | 很短、用户友好 | 熵值较低 | - -## 下一步 - -- [快速开始指南](../quick-start/) -- [JWT 示例](../jwt-example/) -- [完整文档](../../docs/) - diff --git a/examples/token-styles/go.mod b/examples/token-styles/go.mod deleted file mode 100644 index 2f16ade..0000000 --- a/examples/token-styles/go.mod +++ /dev/null @@ -1,20 +0,0 @@ -module github.com/click33/sa-token-go/examples/token-styles - -go 1.21 - -require ( - github.com/click33/sa-token-go/core v0.1.3 - github.com/click33/sa-token-go/storage/memory v0.1.3 - github.com/click33/sa-token-go/stputil v0.1.3 -) - -require ( - github.com/golang-jwt/jwt/v5 v5.2.1 // indirect - github.com/google/uuid v1.6.0 // indirect -) - -replace ( - github.com/click33/sa-token-go/core => ../../core - github.com/click33/sa-token-go/storage/memory => ../../storage/memory - github.com/click33/sa-token-go/stputil => ../../stputil -) diff --git a/generator/sgenerator/consts.go b/generator/sgenerator/consts.go new file mode 100644 index 0000000..af33f74 --- /dev/null +++ b/generator/sgenerator/consts.go @@ -0,0 +1,13 @@ +// @Author daixk 2025/12/22 16:08:00 +package sgenerator + +// Constants for token generation | Token生成常量 +const ( + DefaultTimeout = 2592000 // 30 days (seconds) | 30天(秒) + DefaultJWTSecret = "log-secret-key" // Should be overridden in production | 生产环境应覆盖 + TikTokenLength = 11 // TikTok-style short ID length | Tik风格短ID长度 + TikCharset = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + HashRandomBytesLen = 16 // Random bytes length for hash token | 哈希Token的随机字节长度 + TimestampRandomLen = 8 // Random bytes length for timestamp token | 时间戳Token的随机字节长度 + DefaultSimpleLength = 16 // Default simple token length | 默认简单Token长度 +) diff --git a/core/token/token.go b/generator/sgenerator/generator_adapter_sgenerator.go similarity index 77% rename from core/token/token.go rename to generator/sgenerator/generator_adapter_sgenerator.go index e4cf4e7..c4419d0 100644 --- a/core/token/token.go +++ b/generator/sgenerator/generator_adapter_sgenerator.go @@ -1,47 +1,41 @@ -package token +// @Author daixk 2025/12/17 9:39:00 +package sgenerator import ( "crypto/rand" "crypto/sha256" "encoding/hex" "fmt" - "math/big" - "time" - - "github.com/click33/sa-token-go/core/config" + "github.com/click33/sa-token-go/core/adapter" "github.com/click33/sa-token-go/core/utils" "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" -) - -// Constants for token generation | Token生成常量 -const ( - DefaultJWTSecret = "default-secret-key" // Should be overridden in production | 生产环境应覆盖 - TikTokenLength = 11 // TikTok-style short ID length | Tik风格短ID长度 - TikCharset = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" - HashRandomBytesLen = 16 // Random bytes length for hash token | 哈希Token的随机字节长度 - TimestampRandomLen = 8 // Random bytes length for timestamp token | 时间戳Token的随机字节长度 - DefaultSimpleLength = 16 // Default simple token length | 默认简单Token长度 -) - -// Error variables | 错误变量 -var ( - ErrInvalidToken = fmt.Errorf("invalid token") - ErrUnexpectedSigningMethod = fmt.Errorf("unexpected signing method") + "math/big" + "time" ) // Generator Token generator | Token生成器 type Generator struct { - config *config.Config + timeout int64 + tokenStyle adapter.TokenStyle + jwtSecretKey string } // NewGenerator Creates a new token generator | 创建新的Token生成器 -func NewGenerator(cfg *config.Config) *Generator { - if cfg == nil { - cfg = config.DefaultConfig() +func NewGenerator(timeout int64, tokenStyle adapter.TokenStyle, jwtSecretKey string) *Generator { + return &Generator{ + timeout: timeout, + tokenStyle: tokenStyle, + jwtSecretKey: jwtSecretKey, } +} + +// NewDefaultGenerator Creates a new token generator | 创建新的默认的Token生成器 +func NewDefaultGenerator() *Generator { return &Generator{ - config: cfg, + timeout: DefaultTimeout, + tokenStyle: adapter.TokenStyleUUID, + jwtSecretKey: DefaultJWTSecret, } } @@ -53,24 +47,24 @@ func (g *Generator) Generate(loginID string, device string) (string, error) { return "", fmt.Errorf("loginID cannot be empty") } - switch g.config.TokenStyle { - case config.TokenStyleUUID: + switch g.tokenStyle { + case adapter.TokenStyleUUID: return g.generateUUID() - case config.TokenStyleSimple: + case adapter.TokenStyleSimple: return g.generateSimple(DefaultSimpleLength) - case config.TokenStyleRandom32: + case adapter.TokenStyleRandom32: return g.generateSimple(32) - case config.TokenStyleRandom64: + case adapter.TokenStyleRandom64: return g.generateSimple(64) - case config.TokenStyleRandom128: + case adapter.TokenStyleRandom128: return g.generateSimple(128) - case config.TokenStyleJWT: + case adapter.TokenStyleJWT: return g.generateJWT(loginID, device) - case config.TokenStyleHash: + case adapter.TokenStyleHash: return g.generateHash(loginID, device) - case config.TokenStyleTimestamp: + case adapter.TokenStyleTimestamp: return g.generateTimestamp(loginID, device) - case config.TokenStyleTik: + case adapter.TokenStyleTik: return g.generateTik() default: return g.generateUUID() @@ -111,8 +105,8 @@ func (g *Generator) generateJWT(loginID string, device string) (string, error) { } // Add expiration if timeout is configured | 如果配置了超时时间则添加过期时间 - if g.config.Timeout > 0 { - claims["exp"] = now.Add(time.Duration(g.config.Timeout) * time.Second).Unix() + if g.timeout > 0 { + claims["exp"] = now.Add(time.Duration(g.timeout) * time.Second).Unix() } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) @@ -128,8 +122,8 @@ func (g *Generator) generateJWT(loginID string, device string) (string, error) { // getJWTSecret Gets JWT secret key with fallback | 获取JWT密钥(带默认值) func (g *Generator) getJWTSecret() string { - if g.config.JwtSecretKey != "" { - return g.config.JwtSecretKey + if g.jwtSecretKey != "" { + return g.jwtSecretKey } return DefaultJWTSecret } @@ -147,7 +141,7 @@ func (g *Generator) ParseJWT(tokenStr string) (jwt.MapClaims, error) { token, err := jwt.Parse(tokenStr, func(token *jwt.Token) (any, error) { // Verify signing method | 验证签名方法 if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { - return nil, fmt.Errorf("%w: %v", ErrUnexpectedSigningMethod, token.Header["alg"]) + return nil, fmt.Errorf("%w: %v", fmt.Errorf("unexpected signing method"), token.Header["alg"]) } return []byte(secretKey), nil }) @@ -160,7 +154,7 @@ func (g *Generator) ParseJWT(tokenStr string) (jwt.MapClaims, error) { return claims, nil } - return nil, ErrInvalidToken + return nil, fmt.Errorf("invalid token") } // ValidateJWT Validates JWT token | 验证JWT Token diff --git a/generator/sgenerator/generator_adapter_sgenerator_test.go b/generator/sgenerator/generator_adapter_sgenerator_test.go new file mode 100644 index 0000000..3874395 --- /dev/null +++ b/generator/sgenerator/generator_adapter_sgenerator_test.go @@ -0,0 +1,407 @@ +// @Author daixk 2025/12/28 10:00:00 +package sgenerator + +import ( + "strings" + "testing" + + "github.com/click33/sa-token-go/core/adapter" +) + +// ============ Constructor Tests | 构造函数测试 ============ + +func TestNewGenerator(t *testing.T) { + g := NewGenerator(3600, adapter.TokenStyleUUID, "my-secret") + + if g.timeout != 3600 { + t.Errorf("expected timeout 3600, got %d", g.timeout) + } + if g.tokenStyle != adapter.TokenStyleUUID { + t.Errorf("expected tokenStyle uuid, got %s", g.tokenStyle) + } + if g.jwtSecretKey != "my-secret" { + t.Errorf("expected jwtSecretKey my-secret, got %s", g.jwtSecretKey) + } +} + +func TestNewDefaultGenerator(t *testing.T) { + g := NewDefaultGenerator() + + if g.timeout != DefaultTimeout { + t.Errorf("expected default timeout %d, got %d", DefaultTimeout, g.timeout) + } + if g.tokenStyle != adapter.TokenStyleUUID { + t.Errorf("expected default tokenStyle uuid, got %s", g.tokenStyle) + } + if g.jwtSecretKey != DefaultJWTSecret { + t.Errorf("expected default jwtSecretKey, got %s", g.jwtSecretKey) + } +} + +// ============ Generate Tests | 生成测试 ============ + +func TestGenerate_EmptyLoginID(t *testing.T) { + g := NewDefaultGenerator() + + _, err := g.Generate("", "pc") + if err == nil { + t.Error("expected error for empty loginID") + } +} + +func TestGenerate_UUID(t *testing.T) { + g := NewGenerator(3600, adapter.TokenStyleUUID, "") + + token, err := g.Generate("user123", "pc") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // UUID format: xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx + if len(token) != 36 { + t.Errorf("expected UUID length 36, got %d", len(token)) + } + if strings.Count(token, "-") != 4 { + t.Errorf("expected 4 dashes in UUID, got %d", strings.Count(token, "-")) + } +} + +func TestGenerate_Simple(t *testing.T) { + g := NewGenerator(3600, adapter.TokenStyleSimple, "") + + token, err := g.Generate("user123", "pc") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(token) != DefaultSimpleLength { + t.Errorf("expected length %d, got %d", DefaultSimpleLength, len(token)) + } +} + +func TestGenerate_Random32(t *testing.T) { + g := NewGenerator(3600, adapter.TokenStyleRandom32, "") + + token, err := g.Generate("user123", "pc") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(token) != 32 { + t.Errorf("expected length 32, got %d", len(token)) + } +} + +func TestGenerate_Random64(t *testing.T) { + g := NewGenerator(3600, adapter.TokenStyleRandom64, "") + + token, err := g.Generate("user123", "pc") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(token) != 64 { + t.Errorf("expected length 64, got %d", len(token)) + } +} + +func TestGenerate_Random128(t *testing.T) { + g := NewGenerator(3600, adapter.TokenStyleRandom128, "") + + token, err := g.Generate("user123", "pc") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(token) != 128 { + t.Errorf("expected length 128, got %d", len(token)) + } +} + +func TestGenerate_JWT(t *testing.T) { + g := NewGenerator(3600, adapter.TokenStyleJWT, "test-secret") + + token, err := g.Generate("user123", "pc") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // JWT format: header.payload.signature + parts := strings.Split(token, ".") + if len(parts) != 3 { + t.Errorf("expected JWT with 3 parts, got %d", len(parts)) + } +} + +func TestGenerate_Hash(t *testing.T) { + g := NewGenerator(3600, adapter.TokenStyleHash, "") + + token, err := g.Generate("user123", "pc") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // SHA256 hex string length is 64 + if len(token) != 64 { + t.Errorf("expected hash length 64, got %d", len(token)) + } +} + +func TestGenerate_Timestamp(t *testing.T) { + g := NewGenerator(3600, adapter.TokenStyleTimestamp, "") + + token, err := g.Generate("user123", "pc") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Format: timestamp_loginID_random + parts := strings.Split(token, "_") + if len(parts) != 3 { + t.Errorf("expected 3 parts separated by underscore, got %d", len(parts)) + } + if parts[1] != "user123" { + t.Errorf("expected loginID user123 in token, got %s", parts[1]) + } +} + +func TestGenerate_Tik(t *testing.T) { + g := NewGenerator(3600, adapter.TokenStyleTik, "") + + token, err := g.Generate("user123", "pc") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(token) != TikTokenLength { + t.Errorf("expected tik length %d, got %d", TikTokenLength, len(token)) + } + + // Check all characters are in charset + for _, c := range token { + if !strings.ContainsRune(TikCharset, c) { + t.Errorf("unexpected character %c in tik token", c) + } + } +} + +func TestGenerate_DefaultStyle(t *testing.T) { + g := &Generator{ + timeout: 3600, + tokenStyle: "invalid_style", + } + + token, err := g.Generate("user123", "pc") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Should fallback to UUID + if len(token) != 36 { + t.Errorf("expected UUID fallback, got length %d", len(token)) + } +} + +// ============ JWT Helper Tests | JWT辅助方法测试 ============ + +func TestParseJWT(t *testing.T) { + g := NewGenerator(3600, adapter.TokenStyleJWT, "test-secret") + + token, err := g.Generate("user123", "mobile") + if err != nil { + t.Fatalf("failed to generate JWT: %v", err) + } + + claims, err := g.ParseJWT(token) + if err != nil { + t.Fatalf("failed to parse JWT: %v", err) + } + + if claims["loginId"] != "user123" { + t.Errorf("expected loginId user123, got %v", claims["loginId"]) + } + if claims["device"] != "mobile" { + t.Errorf("expected device mobile, got %v", claims["device"]) + } +} + +func TestParseJWT_EmptyToken(t *testing.T) { + g := NewDefaultGenerator() + + _, err := g.ParseJWT("") + if err == nil { + t.Error("expected error for empty token") + } +} + +func TestParseJWT_InvalidToken(t *testing.T) { + g := NewDefaultGenerator() + + _, err := g.ParseJWT("invalid.token.string") + if err == nil { + t.Error("expected error for invalid token") + } +} + +func TestParseJWT_WrongSecret(t *testing.T) { + g1 := NewGenerator(3600, adapter.TokenStyleJWT, "secret1") + g2 := NewGenerator(3600, adapter.TokenStyleJWT, "secret2") + + token, _ := g1.Generate("user123", "pc") + + _, err := g2.ParseJWT(token) + if err == nil { + t.Error("expected error for wrong secret") + } +} + +func TestValidateJWT(t *testing.T) { + g := NewGenerator(3600, adapter.TokenStyleJWT, "test-secret") + + token, _ := g.Generate("user123", "pc") + + err := g.ValidateJWT(token) + if err != nil { + t.Errorf("expected valid JWT, got error: %v", err) + } +} + +func TestValidateJWT_Invalid(t *testing.T) { + g := NewDefaultGenerator() + + err := g.ValidateJWT("invalid.token") + if err == nil { + t.Error("expected error for invalid JWT") + } +} + +func TestGetLoginIDFromJWT(t *testing.T) { + g := NewGenerator(3600, adapter.TokenStyleJWT, "test-secret") + + token, _ := g.Generate("user456", "pc") + + loginID, err := g.GetLoginIDFromJWT(token) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if loginID != "user456" { + t.Errorf("expected loginID user456, got %s", loginID) + } +} + +func TestGetLoginIDFromJWT_InvalidToken(t *testing.T) { + g := NewDefaultGenerator() + + _, err := g.GetLoginIDFromJWT("invalid.token") + if err == nil { + t.Error("expected error for invalid token") + } +} + +// ============ Uniqueness Tests | 唯一性测试 ============ + +func TestGenerate_Uniqueness(t *testing.T) { + styles := []adapter.TokenStyle{ + adapter.TokenStyleUUID, + adapter.TokenStyleSimple, + adapter.TokenStyleRandom32, + adapter.TokenStyleRandom64, + adapter.TokenStyleHash, + adapter.TokenStyleTimestamp, + adapter.TokenStyleTik, + } + + for _, style := range styles { + t.Run(string(style), func(t *testing.T) { + g := NewGenerator(3600, style, "test-secret") + tokens := make(map[string]bool) + count := 100 + + for i := 0; i < count; i++ { + token, err := g.Generate("user123", "pc") + if err != nil { + t.Fatalf("failed to generate token: %v", err) + } + + if tokens[token] { + t.Errorf("duplicate token generated: %s", token) + } + tokens[token] = true + } + }) + } +} + +// ============ JWT Expiration Tests | JWT过期测试 ============ + +func TestGenerate_JWT_WithExpiration(t *testing.T) { + g := NewGenerator(3600, adapter.TokenStyleJWT, "test-secret") + + token, _ := g.Generate("user123", "pc") + claims, _ := g.ParseJWT(token) + + if _, ok := claims["exp"]; !ok { + t.Error("expected exp claim in JWT") + } + if _, ok := claims["iat"]; !ok { + t.Error("expected iat claim in JWT") + } +} + +func TestGenerate_JWT_NoExpiration(t *testing.T) { + g := NewGenerator(0, adapter.TokenStyleJWT, "test-secret") + + token, _ := g.Generate("user123", "pc") + claims, _ := g.ParseJWT(token) + + if _, ok := claims["exp"]; ok { + t.Error("expected no exp claim when timeout is 0") + } +} + +// ============ Benchmark Tests | 基准测试 ============ + +func BenchmarkGenerate_UUID(b *testing.B) { + g := NewGenerator(3600, adapter.TokenStyleUUID, "") + for i := 0; i < b.N; i++ { + g.Generate("user123", "pc") + } +} + +func BenchmarkGenerate_Simple(b *testing.B) { + g := NewGenerator(3600, adapter.TokenStyleSimple, "") + for i := 0; i < b.N; i++ { + g.Generate("user123", "pc") + } +} + +func BenchmarkGenerate_JWT(b *testing.B) { + g := NewGenerator(3600, adapter.TokenStyleJWT, "test-secret") + for i := 0; i < b.N; i++ { + g.Generate("user123", "pc") + } +} + +func BenchmarkGenerate_Hash(b *testing.B) { + g := NewGenerator(3600, adapter.TokenStyleHash, "") + for i := 0; i < b.N; i++ { + g.Generate("user123", "pc") + } +} + +func BenchmarkGenerate_Tik(b *testing.B) { + g := NewGenerator(3600, adapter.TokenStyleTik, "") + for i := 0; i < b.N; i++ { + g.Generate("user123", "pc") + } +} + +func BenchmarkParseJWT(b *testing.B) { + g := NewGenerator(3600, adapter.TokenStyleJWT, "test-secret") + token, _ := g.Generate("user123", "pc") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + g.ParseJWT(token) + } +} diff --git a/generator/sgenerator/go.mod b/generator/sgenerator/go.mod new file mode 100644 index 0000000..198bd2d --- /dev/null +++ b/generator/sgenerator/go.mod @@ -0,0 +1,8 @@ +module github.com/click33/sa-token-go/generator/sgenerator + +go 1.25.0 + +require ( + github.com/golang-jwt/jwt/v5 v5.3.0 + github.com/google/uuid v1.6.0 +) diff --git a/generator/sgenerator/go.sum b/generator/sgenerator/go.sum new file mode 100644 index 0000000..17b1315 --- /dev/null +++ b/generator/sgenerator/go.sum @@ -0,0 +1,2 @@ +github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= diff --git a/go.work b/go.work index 5dd5f54..1741896 100644 --- a/go.work +++ b/go.work @@ -1,14 +1,56 @@ -go 1.25.3 +go 1.25.0 use ( + // codec + ./codec/json + ./codec/msgpack + + // core ./core + + // examples - integrations + ./examples/annotation/annotation-example + ./examples/chi/chi-example + ./examples/echo/echo-example + ./examples/fiber/fiber-example + ./examples/gf/gf-example + ./examples/gin/gin-example + ./examples/gin/gin-simple ./examples/kratos/kratos-example + + // examples - manager-example + ./examples/manager-example/jwt-example + ./examples/manager-example/listener-example + ./examples/manager-example/oauth2-example + ./examples/manager-example/security-example + ./examples/manager-example/session-demo + ./examples/manager-example/token-styles + + // examples - quick-start + ./examples/quick-start/simple-example + ./examples/quick-start/complex-example + + // generator + ./generator/sgenerator + + // integrations ./integrations/chi ./integrations/echo ./integrations/fiber ./integrations/gf ./integrations/gin ./integrations/kratos + ./integrations/zero + + // log + ./log/gf + ./log/nop + ./log/slog + + // pool + ./pool/ants + + // storage ./storage/memory ./storage/redis ./stputil diff --git a/go.work.sum b/go.work.sum index ff4fc5a..906cf9b 100644 --- a/go.work.sum +++ b/go.work.sum @@ -1,28 +1,60 @@ cel.dev/expr v0.19.1 h1:NciYrtDRIR0lNCnH1LFJegdjspNx9fI59O7TWcua/W4= cel.dev/expr v0.19.1/go.mod h1:MrpN08Q+lEBs+bGYdLxxHkZoUSsCp0nSKTs0nTymJgw= +cloud.google.com/go v0.110.10 h1:LXy9GEO+timppncPIAZoOj3l58LIU9k+kn48AN7IO3Y= +cloud.google.com/go v0.110.10/go.mod h1:v1OoFqYxiBkUrruItNM3eT4lLByNjxmJSV/xDKJNnic= cloud.google.com/go/compute v1.23.3 h1:6sVlXXBmbd7jNX0Ipq0trII3e4n1/MsADLK6a+aiVlk= cloud.google.com/go/compute v1.23.3/go.mod h1:VCgBUoMnIVIR0CscqQiPJLAG25E3ZRZMzcFZeQ+h8CI= cloud.google.com/go/compute/metadata v0.2.3 h1:mg4jlk7mCAj6xXp9UJ4fjI9VUI5rubuGBW5aJ7UnBMY= cloud.google.com/go/compute/metadata v0.2.3/go.mod h1:VAV5nSsACxMJvgaAuX6Pk2AawlZn8kiOGuCv6gTkwuA= cloud.google.com/go/compute/metadata v0.6.0 h1:A6hENjEsCDtC1k8byVsgwvVcioamEHvZ4j01OwKxG9I= cloud.google.com/go/compute/metadata v0.6.0/go.mod h1:FjyFAW1MW0C203CEOMDTu3Dk1FlqW3Rga40jzHL4hfg= +cloud.google.com/go/firestore v1.14.0 h1:8aLcKnMPoldYU3YHgu4t2exrKhLQkqaXAGqT0ljrFVw= +cloud.google.com/go/firestore v1.14.0/go.mod h1:96MVaHLsEhbvkBEdZgfN+AS/GIkco1LRpH9Xp9YZfzQ= +cloud.google.com/go/iam v1.1.5 h1:1jTsCu4bcsNsE4iiqNT5SHwrDRCfRmIaaaVFhRveTJI= +cloud.google.com/go/iam v1.1.5/go.mod h1:rB6P/Ic3mykPbFio+vo7403drjlgvoWfYpJhMXEbzv8= +cloud.google.com/go/longrunning v0.5.4 h1:w8xEcbZodnA2BbW6sVirkkoC+1gP8wS57EUUgGS0GVg= +cloud.google.com/go/longrunning v0.5.4/go.mod h1:zqNVncI0BOP8ST6XQD1+VcvuShMmq7+xFSzOL++V0dI= +cloud.google.com/go/storage v1.35.1 h1:B59ahL//eDfx2IIKFBeT5Atm9wnNmj3+8xG/W4WB//w= +cloud.google.com/go/storage v1.35.1/go.mod h1:M6M/3V/D3KpzMTJyPOR/HU6n2Si5QdaXYEsng2xgOs8= dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk= dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk= +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= +github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= +github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.25.0 h1:3c8yed4lgqTt+oTQ+JNMDo+F4xprBf+O/il4ZC0nRLw= github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.25.0/go.mod h1:obipzmGjfSjam60XLwGfqUkJsfiheAl+TUjG+4yzyPM= +github.com/alicebob/miniredis/v2 v2.35.0 h1:QwLphYqCEAo1eu1TqPRN2jgVMPBweeQcR21jeqDCONI= +github.com/alicebob/miniredis/v2 v2.35.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM= github.com/andybalholm/brotli v1.0.5/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= +github.com/armon/go-metrics v0.4.1 h1:hR91U9KYmb6bLBYLQjyM+3j+rcd/UhE+G78SFnF8gJA= +github.com/armon/go-metrics v0.4.1/go.mod h1:E6amYzXo6aW1tqzoZGT755KkbgrJsSdpwZ+3JqfkOG4= +github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= +github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= +github.com/bufbuild/protocompile v0.14.1 h1:iA73zAf/fyljNjQKwYzUHD6AD4R8KMasmwa/FBatYVw= +github.com/bufbuild/protocompile v0.14.1/go.mod h1:ppVdAIhbr2H8asPk6k4pY7t9zB1OU5DoEw9xY/FUi1c= github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4= github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= +github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= +github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= github.com/census-instrumentation/opencensus-proto v0.4.1 h1:iKLQ0xPNFxR/2hzXZMrBo8f1j86j5WHzznCCQxV/b8g= github.com/census-instrumentation/opencensus-proto v0.4.1/go.mod h1:4T9NM4+4Vw91VeyqjLS6ao50K5bOcLKN6Q42XnYaRYw= github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/clbanning/mxj/v2 v2.7.0/go.mod h1:hNiWqW14h+kc+MdF9C6/YoRfjEJoR3ou6tn/Qo+ve2s= +github.com/click33/sa-token-go/core v0.1.7/go.mod h1:mb3AQAJIXqx9WdULyn5qjufK1j/u+kgB0q+tafHVhgk= +github.com/click33/sa-token-go/integrations/gf v0.1.6/go.mod h1:cwAY3VJ/KNLNcuYDNjMt1O4IERCep1JWCfvS9LyoA5w= +github.com/click33/sa-token-go/integrations/gf v0.1.7/go.mod h1:0GDMoJ4Vv5+gPo8tnutpX34V/USd4+TLjZVObb91wFA= +github.com/click33/sa-token-go/integrations/gin v0.1.6/go.mod h1:A8Ds3bUalQQcH2vYBaGPAqZ6fzZY9m7AamikDcXyRKE= +github.com/click33/sa-token-go/integrations/gin v0.1.7/go.mod h1:1STqW8wTUJ/tpfw/E22HsPzQVFnc7dE/AW0w44XQ2co= github.com/click33/sa-token-go/storage/memory v0.1.4/go.mod h1:nqyuEh23mNjcuG3aI/BqJFz71zkpsgjdStW1BC5lkB0= -github.com/click33/sa-token-go/storage/memory v0.1.5/go.mod h1:HxN2NVLq7lx+sOmq5RmV0h8xJjEUJLm4Xt1Mq+9PV2s= +github.com/click33/sa-token-go/storage/memory v0.1.7/go.mod h1:wnQVAHnFKWs6CzmM0DSRZSE2ADPXQQ9AIrWxdmzjz3Q= +github.com/click33/sa-token-go/storage/redis v0.1.7/go.mod h1:pPxB/qFRNc/TUv0mMjwPwRVs1ZZB22YCknR1H/rvQ8I= +github.com/click33/sa-token-go/stputil v0.1.6/go.mod h1:G4vYhljpN1SeGLYHWRslYncHJvn52CksHGjBzUNITFA= +github.com/click33/sa-token-go/stputil v0.1.7/go.mod h1:YY4NzfwVMwPUQLDBk9C5eVLQ08oI3vNSFQhBuZBPtgY= github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= github.com/cncf/udpa/go v0.0.0-20220112060539-c52dc94e7fbe h1:QQ3GSy+MqSHxm/d8nCtnAiZdYFd45cYZPs8vOOIYKfk= @@ -31,11 +63,20 @@ github.com/cncf/xds/go v0.0.0-20231109132714-523115ebc101 h1:7To3pQ+pZo0i3dsWEbi github.com/cncf/xds/go v0.0.0-20231109132714-523115ebc101/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= github.com/cncf/xds/go v0.0.0-20241223141626-cff3c89139a3 h1:boJj011Hh+874zpIySeApCX4GeOjPl9qhRF3QuIZq+Q= github.com/cncf/xds/go v0.0.0-20241223141626-cff3c89139a3/go.mod h1:W+zGtBO5Y1IgJhy4+A9GOqVhqLpfZi+vwmdNXUehLA8= +github.com/coreos/go-semver v0.3.0 h1:wkHLiw0WNATZnSG7epLsujiMCgPAc9xhjJ4tgnAxmfM= +github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= +github.com/coreos/go-semver v0.3.1 h1:yi21YpKnrx1gt5R+la8n5WgS0kCrsPp33dmEyHReZr4= +github.com/coreos/go-semver v0.3.1/go.mod h1:irMmmIw/7yzSRPWryHsK7EYSg09caPQL03VsM8rvUec= +github.com/coreos/go-systemd/v22 v22.3.2 h1:D9/bQk5vlXQFZ6Kwuu6zaiXJ9oTPe68++AzAJc1DzSI= +github.com/coreos/go-systemd/v22 v22.3.2/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs= +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/creack/pty v1.1.9 h1:uDmaGzcdjhF4i/plgjmEsriH11Y0o7RKapEf/LDaM3w= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/emicklei/go-restful/v3 v3.11.0 h1:rAQeMHw1c7zTmncogyy8VvRZwtkmkZ4FxERmMY4rD+g= +github.com/emicklei/go-restful/v3 v3.11.0/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc= github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ= github.com/envoyproxy/go-control-plane v0.11.2-0.20230627204322-7d0032219fcb h1:kxNVXsNro/lpR5WD+P1FI/yUHn2G03Glber3k8cQL2Y= github.com/envoyproxy/go-control-plane v0.11.2-0.20230627204322-7d0032219fcb/go.mod h1:GxGqnjWzl1Gz8WfAfMJSfhvsi4EPZayRb25nLHDWXyA= @@ -50,26 +91,44 @@ github.com/envoyproxy/protoc-gen-validate v1.0.2/go.mod h1:GpiZQP3dDbg4JouG/NNS7 github.com/envoyproxy/protoc-gen-validate v1.2.1 h1:DEo3O99U8j4hBFwbJfrz9VtgcDfUKS7KJ7spH3d86P8= github.com/envoyproxy/protoc-gen-validate v1.2.1/go.mod h1:d/C80l/jxXLdfEIhX1W2TmLfsJ31lvEjwamM4DxlWXU= github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= +github.com/fullstorydev/grpcurl v1.9.3 h1:PC1Xi3w+JAvEE2Tg2Gf2RfVgPbf9+tbuQr1ZkyVU3jk= +github.com/fullstorydev/grpcurl v1.9.3/go.mod h1:/b4Wxe8bG6ndAjlfSUjwseQReUDUvBJiFEB7UllOlUE= github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= github.com/gin-gonic/gin v1.10.0/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y= +github.com/go-chi/chi/v5 v5.0.11/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= +github.com/go-openapi/jsonpointer v0.19.6 h1:eCs3fxoIi3Wh6vtgmLTOjdhSpiqphQ+DaPn38N2ZdrE= +github.com/go-openapi/jsonpointer v0.19.6/go.mod h1:osyAmYz/mB/C3I+WsTTSgw1ONzaLJoLCyoi6/zppojs= +github.com/go-openapi/jsonreference v0.20.2 h1:3sVjiK66+uXK/6oQ8xgcRKcFgQ5KXa2KvnJRumpMGbE= +github.com/go-openapi/jsonreference v0.20.2/go.mod h1:Bl1zwGIM8/wsvqjsOQLJ/SH+En5Ap4rVB5KVcIDZG2k= +github.com/go-openapi/swag v0.22.4 h1:QLMzNJnMGPRNDCbySlcj1x01tzU8/9LTTL9hZZZogBU= +github.com/go-openapi/swag v0.22.4/go.mod h1:UzaqsxGiab7freDnrUUra0MwWfN/q7tE4j+VcZ0yl14= github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= +github.com/go-sql-driver/mysql v1.9.0 h1:Y0zIbQXhQKmQgTp44Y1dp3wTXcn804QoTptLZT1vtvo= +github.com/go-sql-driver/mysql v1.9.0/go.mod h1:pDetrLJeA3oMujJuvXc8RJoasr589B6A9fwzD3QMrqw= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/gofiber/fiber/v2 v2.52.0/go.mod h1:KEOE+cXMhXG0zHc9d8+E38hoX+ZN7bhOtgeF2oT6jrQ= github.com/gogf/gf/v2 v2.9.4/go.mod h1:Ukl+5HUH9S7puBmNLR4L1zUqeRwi0nrW4OigOknEztU= -github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= +github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= +github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= +github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI= +github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/golang-jwt/jwt/v5 v5.2.1/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/golang/glog v1.1.2 h1:DVjP2PbBOzHyzA+dn3WhHIq4NdVu3Q+pvivFICf/7fo= github.com/golang/glog v1.1.2/go.mod h1:zR+okUeTbrL6EL3xHUDxZuEtGv04p5shwip1+mL/rLQ= github.com/golang/glog v1.2.4 h1:CNNw5U8lSiiBk7druxtSHHTsRWcxKoac6kZKm2peBBc= @@ -78,24 +137,79 @@ github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/protobuf v1.5.0 h1:LUVKkCeviFUMKqHa4tXIIij/lbhnMbP7Fn5wKdKkRh4= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/golang/snappy v1.0.0 h1:Oy607GVXHs7RtbggtPBnr2RmDArIsAefDwvrdWvRhGs= +github.com/golang/snappy v1.0.0/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/google/gnostic-models v0.6.8 h1:yo/ABAfM5IMRsS1VnXjTBvUb61tFIHozhlYvRgGre9I= +github.com/google/gnostic-models v0.6.8/go.mod h1:5n7qKqH0f5wFt+aWF8CW6pZLLNOfYuF5OpfBSENuI8U= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/gofuzz v1.0.0 h1:A8PeW59pxE9IoFRqBp37U+mSNaQoZ46F1f0f863XSXw= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= +github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/s2a-go v0.1.7 h1:60BLSyTrOV4/haCDW4zb1guZItoSq8foHCXrAnjBo/o= +github.com/google/s2a-go v0.1.7/go.mod h1:50CgR4k1jNlWBu4UfS4AcfhVe1r6pdZPygJ3R8F0Qdw= +github.com/googleapis/enterprise-certificate-proxy v0.3.2 h1:Vie5ybvEvT75RniqhfFxPRy3Bf7vr3h0cechB90XaQs= +github.com/googleapis/enterprise-certificate-proxy v0.3.2/go.mod h1:VLSiSSBs/ksPL8kq3OBOQ6WRI2QnaFynd1DCjZ62+V0= +github.com/googleapis/gax-go/v2 v2.12.0 h1:A+gCJKdRfqXkr+BIRGtZLibNXf0m1f9E4HG56etFpas= +github.com/googleapis/gax-go/v2 v2.12.0/go.mod h1:y+aIqrI5eb1YGMVJfuV3185Ts/D7qKpsEkdD5+I6QGU= +github.com/googleapis/google-cloud-go-testing v0.0.0-20210719221736-1c9a4c676720 h1:zC34cGQu69FG7qzJ3WiKW244WfhDC3xxYMeNOX2gtUQ= +github.com/googleapis/google-cloud-go-testing v0.0.0-20210719221736-1c9a4c676720/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g= github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/grafana/pyroscope-go v1.2.7 h1:VWBBlqxjyR0Cwk2W6UrE8CdcdD80GOFNutj0Kb1T8ac= +github.com/grafana/pyroscope-go v1.2.7/go.mod h1:o/bpSLiJYYP6HQtvcoVKiE9s5RiNgjYTj1DhiddP2Pc= +github.com/grafana/pyroscope-go/godeltaprof v0.1.9 h1:c1Us8i6eSmkW+Ez05d3co8kasnuOY813tbMN8i/a3Og= +github.com/grafana/pyroscope-go/godeltaprof v0.1.9/go.mod h1:2+l7K7twW49Ct4wFluZD3tZ6e0SjanjcUUBPVD/UuGU= github.com/grokify/html-strip-tags-go v0.1.0/go.mod h1:ZdzgfHEzAfz9X6Xe5eBLVblWIxXfYSQ40S/VKrAOGpc= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0 h1:bkypFPDjIYGfCYD5mRBvpqxfYX1YCS1PXdKYWi8FsN0= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.20.0/go.mod h1:P+Lt/0by1T8bfcF3z737NnSbmxQAppXMRziHUxPOC8k= +github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslCrtky5vbi9dd7HrQPQIx6wqiw= +github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542/go.mod h1:Ow0tF8D4Kplbc8s8sSb3V2oUCygFHVp8gC3Dn6U4MNI= +github.com/hashicorp/consul/api v1.25.1 h1:CqrdhYzc8XZuPnhIYZWH45toM0LB9ZeYr/gvpLVI3PE= +github.com/hashicorp/consul/api v1.25.1/go.mod h1:iiLVwR/htV7mas/sy0O+XSuEnrdBUUydemjxcUrAt4g= +github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= +github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= +github.com/hashicorp/go-hclog v1.5.0 h1:bI2ocEMgcVlz55Oj1xZNBsVi900c7II+fWDyV9o+13c= +github.com/hashicorp/go-hclog v1.5.0/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M= +github.com/hashicorp/go-immutable-radix v1.3.1 h1:DKHmCUm2hRBK510BaiZlwvpD40f8bJFeZnpfm2KLowc= +github.com/hashicorp/go-immutable-radix v1.3.1/go.mod h1:0y9vanUI8NX6FsYoO3zeMjhV/C5i9g4Q3DwcSNZ4P60= +github.com/hashicorp/go-rootcerts v1.0.2 h1:jzhAVGtqPKbwpyCPELlgNWhE1znq+qwJtW5Oi2viEzc= +github.com/hashicorp/go-rootcerts v1.0.2/go.mod h1:pqUvnprVnM5bf7AOirdbb01K4ccR319Vf4pU3K5EGc8= +github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc= +github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= +github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= +github.com/hashicorp/serf v0.10.1 h1:Z1H2J60yRKvfDYAOZLd2MU0ND4AH/WDz7xYHDWQsIPY= +github.com/hashicorp/serf v0.10.1/go.mod h1:yL2t6BqATOLGc5HF7qbFkTfXoPIY0WZdWHfEvMqbG+4= github.com/iancoleman/strcase v0.3.0 h1:nTXanmYxhfFAMjZL34Ov6gkzEsSJZ5DbhxWjvSASxEI= github.com/iancoleman/strcase v0.3.0/go.mod h1:iwCmte+B7n89clKwxIoIXy/HfoL7AsD47ZCWhYzw7ho= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.7.4 h1:9wKznZrhWa2QiHL+NjTSPP6yjl3451BX3imWDnokYlg= +github.com/jackc/pgx/v5 v5.7.4/go.mod h1:ncY89UGWxg82EykZUwSpUKEfccBGGYq1xjrOpsbsfGQ= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/jhump/protoreflect v1.17.0 h1:qOEr613fac2lOuTgWN4tPAtLL7fUSbuJL5X5XumQh94= +github.com/jhump/protoreflect v1.17.0/go.mod h1:h9+vUUL38jiBzck8ck+6G/aeMX8Z4QUY/NiJPwPNi+8= +github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= +github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/klauspost/compress v1.17.0/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= +github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc= +github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= github.com/knz/go-libedit v1.10.1 h1:0pHpWtx9vcvC0xGZqEQlQdfSQs7WRlAjuPvk3fOZDCo= github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M= +github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8= +github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pty v1.1.1 h1:VkoXIwSboBpnk99O/KFauAEILuNHv5DVFKZMBN/gUgw= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/labstack/echo/v4 v4.11.4/go.mod h1:noh7EvLwqDsmh/X/HWKPUl1AjzJrhyptRyEbQJfxen8= github.com/labstack/gommon v0.4.2/go.mod h1:QlUFxVM+SNXhDL/Z7YhocGIBYOiwB0mXm1+1bAPHPyU= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= @@ -103,44 +217,81 @@ github.com/lufia/plan9stats v0.0.0-20230326075908-cb1d2100619a h1:N9zuLhTvBSRt0g github.com/lufia/plan9stats v0.0.0-20230326075908-cb1d2100619a/go.mod h1:JKx41uQRwqlTZabZc+kILPrO/3jlKnQ2Z8b7YiVw5cE= github.com/lyft/protoc-gen-star/v2 v2.0.4-0.20230330145011-496ad1ac90a4 h1:sIXJOMrYnQZJu7OB7ANSF4MYri2fTEGIsRLz6LwI4xE= github.com/lyft/protoc-gen-star/v2 v2.0.4-0.20230330145011-496ad1ac90a4/go.mod h1:amey7yeodaJhXSbf/TlLvWiqQfLOSpEk//mLlc+axEk= +github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= +github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0= +github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y= +github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= +github.com/nats-io/nats.go v1.31.0 h1:/WFBHEc/dOKBF6qf1TZhrdEfTmOZ5JzdJ+Y3m6Y/p7E= +github.com/nats-io/nats.go v1.31.0/go.mod h1:di3Bm5MLsoB4Bx61CBTsxuarI36WbhAwOm8QrW39+i8= +github.com/nats-io/nkeys v0.4.6 h1:IzVe95ru2CT6ta874rt9saQRkWfe2nFj1NtvYSLqMzY= +github.com/nats-io/nkeys v0.4.6/go.mod h1:4DxZNzenSVd1cYQoAa8948QY3QDjrHfcfVADymtkpts= +github.com/nats-io/nuid v1.0.1 h1:5iA8DT8V7q8WK2EScv2padNa/rTESc1KdnPw4TC2paw= +github.com/nats-io/nuid v1.0.1/go.mod h1:19wcPz3Ph3q0Jbyiqsd0kePYG7A95tJPxeL+1OSON2c= github.com/olekukonko/errors v1.1.0/go.mod h1:ppzxA5jBKcO1vIpCXQ9ZqgDh8iwODz6OXIGKU8r5m4Y= github.com/olekukonko/ll v0.0.9/go.mod h1:En+sEW0JNETl26+K8eZ6/W4UQ7CYSrrgg/EdIYT2H8g= github.com/olekukonko/tablewriter v1.1.0/go.mod h1:5c+EBPeSqvXnLLgkm9isDdzR3wjfBkHR9Nhfp3NWrzo= github.com/olekukonko/ts v0.0.0-20171002115256-78ecb04241c0 h1:LiZB1h0GIcudcDci2bxbqI6DXV8bF8POAnArqvRrIyw= github.com/olekukonko/ts v0.0.0-20171002115256-78ecb04241c0/go.mod h1:F/7q8/HZz+TXjlsoZQQKVYvXTZaFH4QRa3y+j1p7MS0= +github.com/openzipkin/zipkin-go v0.4.3 h1:9EGwpqkgnwdEIJ+Od7QVSEIH+ocmm5nPat0G7sjsSdg= +github.com/openzipkin/zipkin-go v0.4.3/go.mod h1:M9wCJZFWCo2RiY+o1eBCEMe0Dp2S5LDHcMZmk3RmK7c= github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs= github.com/philhofer/fwd v1.1.2 h1:bnDivRJ1EWPjUIRXV5KfORO897HTbpFAQddBdE8t7Gw= github.com/philhofer/fwd v1.1.2/go.mod h1:qkPdfjR2SIEbspLqpe1tO4n5yICnr2DY7mqEx2tUTP0= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e h1:aoZm08cpOy4WuID//EZDgcC4zIxODThtZNPirFr42+A= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pkg/sftp v1.13.6 h1:JFZT4XbOU7l77xGSpOdW+pwIMqP044IyjXX6FGyEKFo= +github.com/pkg/sftp v1.13.6/go.mod h1:tz1ryNURKu77RL+GuCzmoJYxQczL3wLNNpPWagdg4Qk= github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 h1:GFCKgmp0tecUJ0sJuv4pzYCqS9+RGSn52M3FUwPs+uo= github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/power-devops/perfstat v0.0.0-20221212215047-62379fc7944b h1:0LFwY6Q3gMACTjAbMZBjXAqTOzOwFaj2Ld6cjeQ7Rig= github.com/power-devops/perfstat v0.0.0-20221212215047-62379fc7944b/go.mod h1:OmDBASR4679mdNQnz2pUhc2G8CO2JrUAVFDRBDP/hJE= +github.com/prometheus/client_golang v1.21.1 h1:DOvXXTqVzvkIewV/CDPFdejpMCGeMcbGCQ8YOmu+Ibk= +github.com/prometheus/client_golang v1.21.1/go.mod h1:U9NM32ykUErtVBxdvD3zfi+EuFkkaBvMb09mIfe0Zgg= github.com/prometheus/client_model v0.6.1 h1:ZKSh/rekM+n3CeS952MLRAdFwIKqeY8b62p8ais2e9E= github.com/prometheus/client_model v0.6.1/go.mod h1:OrxVMOVHjw3lKMa8+x6HeMGkHMQyHDk9E3jmP2AmGiY= +github.com/prometheus/common v0.62.0 h1:xasJaQlnWAeyHdUBeGjXmutelfJHWMRr+Fg4QszZ2Io= +github.com/prometheus/common v0.62.0/go.mod h1:vyBcEuLSvWos9B1+CyL7JZ2up+uFzXhkqml0W5zIY1I= +github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0learggepc= +github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= github.com/redis/go-redis/v9 v9.5.1/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M= +github.com/redis/go-redis/v9 v9.16.0 h1:OotgqgLSRCmzfqChbQyG1PHC3tLNR89DG4jdOERSEP4= +github.com/redis/go-redis/v9 v9.16.0/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= +github.com/sagikazarmark/crypt v0.17.0 h1:ZA/7pXyjkHoK4bW4mIdnCLvL8hd+Nrbiw7Dqk7D4qUk= +github.com/sagikazarmark/crypt v0.17.0/go.mod h1:SMtHTvdmsZMuY/bpZoqokSoChIrcJ/epOxZN58PbZDg= +github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4= +github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ= github.com/shirou/gopsutil/v3 v3.23.6 h1:5y46WPI9QBKBbK7EEccUPNXpJpNrvPuTD0O2zHEHT08= github.com/shirou/gopsutil/v3 v3.23.6/go.mod h1:j7QX50DrXYggrpN30W0Mo+I4/8U2UUIQrnrhqUeWrAU= github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM= github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ= +github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= github.com/spf13/afero v1.10.0 h1:EaGW2JJh15aKOejeuJ+wpFSHnbd7GE6Wvp3TsNhb6LY= github.com/spf13/afero v1.10.0/go.mod h1:UBogFpq8E9Hx+xc5CNTTEpTnuHVmXDwZcZcE1eb/UhQ= +github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= +github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/viper v1.18.2/go.mod h1:EKmWIqdnk5lOcmR72yw6hS+8OPYcwD0jteitLMVB+yk= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= @@ -153,7 +304,7 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= github.com/tinylib/msgp v1.1.8 h1:FCXC1xanKO4I8plpHGH2P7koL/RzZs12l/+r7vakfm0= github.com/tinylib/msgp v1.1.8/go.mod h1:qkpG+2ldGg4xRFmx+jfTvZPxfGFhi64BcnL9vkCm/Tw= github.com/tklauser/go-sysconf v0.3.11 h1:89WgdJhk5SNwJfu+GKyYveZ4IaJ7xAkecBo+KdJV0CM= @@ -168,15 +319,57 @@ github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyC github.com/valyala/fasthttp v1.51.0/go.mod h1:oI2XroL+lI7vdXyYoQk03bXBThfFl2cVdIA3Xl7cH8g= github.com/valyala/fasttemplate v1.2.2/go.mod h1:KHLXt3tVN2HBp8eijSv/kGJopbvo7S+qRAEEKiv+SiQ= github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc= +github.com/vmihailenco/msgpack/v5 v5.4.1/go.mod h1:GaZTsDaehaPpQVyxrf5mtQlH+pc21PIudVV/E3rRQok= +github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds= +github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= +github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= +github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY= +github.com/xdg-go/scram v1.1.2/go.mod h1:RT/sEzTbU5y00aCK8UOx6R7YryM0iF1N2MOmC3kKLN4= +github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8= +github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM= +github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78 h1:ilQV1hzziu+LLM3zUTJ0trRztfwgjqKnBWNtSRkbmwM= +github.com/youmark/pkcs8 v0.0.0-20240726163527-a2c0da244d78/go.mod h1:aL8wCCfTfSfmXjznFBSZNN13rSJjlIOI1fUNAtF7rmI= github.com/yuin/goldmark v1.4.13 h1:fVcFKWvrslecOb/tg+Cc05dkeYx540o0FuFt3nUVDoE= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M= +github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= github.com/yusufpapurcu/wmi v1.2.3 h1:E1ctvB7uKFMOJw3fdOW32DwGE9I7t++CRUEMKvFoFiw= github.com/yusufpapurcu/wmi v1.2.3/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0= +go.etcd.io/etcd/api/v3 v3.5.10 h1:szRajuUUbLyppkhs9K6BRtjY37l66XQQmw7oZRANE4k= +go.etcd.io/etcd/api/v3 v3.5.10/go.mod h1:TidfmT4Uycad3NM/o25fG3J07odo4GBB9hoxaodFCtI= +go.etcd.io/etcd/api/v3 v3.5.15 h1:3KpLJir1ZEBrYuV2v+Twaa/e2MdDCEZ/70H+lzEiwsk= +go.etcd.io/etcd/api/v3 v3.5.15/go.mod h1:N9EhGzXq58WuMllgH9ZvnEr7SI9pS0k0+DHZezGp7jM= +go.etcd.io/etcd/client/pkg/v3 v3.5.10 h1:kfYIdQftBnbAq8pUWFXfpuuxFSKzlmM5cSn76JByiT0= +go.etcd.io/etcd/client/pkg/v3 v3.5.10/go.mod h1:DYivfIviIuQ8+/lCq4vcxuseg2P2XbHygkKwFo9fc8U= +go.etcd.io/etcd/client/pkg/v3 v3.5.15 h1:fo0HpWz/KlHGMCC+YejpiCmyWDEuIpnTDzpJLB5fWlA= +go.etcd.io/etcd/client/pkg/v3 v3.5.15/go.mod h1:mXDI4NAOwEiszrHCb0aqfAYNCrZP4e9hRca3d1YK8EU= +go.etcd.io/etcd/client/v2 v2.305.10 h1:MrmRktzv/XF8CvtQt+P6wLUlURaNpSDJHFZhe//2QE4= +go.etcd.io/etcd/client/v2 v2.305.10/go.mod h1:m3CKZi69HzilhVqtPDcjhSGp+kA1OmbNn0qamH80xjA= +go.etcd.io/etcd/client/v3 v3.5.10 h1:W9TXNZ+oB3MCd/8UjxHTWK5J9Nquw9fQBLJd5ne5/Ao= +go.etcd.io/etcd/client/v3 v3.5.10/go.mod h1:RVeBnDz2PUEZqTpgqwAtUd8nAPf5kjyFyND7P1VkOKc= +go.etcd.io/etcd/client/v3 v3.5.15 h1:23M0eY4Fd/inNv1ZfU3AxrbbOdW79r9V9Rl62Nm6ip4= +go.etcd.io/etcd/client/v3 v3.5.15/go.mod h1:CLSJxrYjvLtHsrPKsy7LmZEE+DK2ktfd2bN4RhBMwlU= +go.mongodb.org/mongo-driver/v2 v2.4.0 h1:Oq6BmUAAFTzMeh6AonuDlgZMuAuEiUxoAD1koK5MuFo= +go.mongodb.org/mongo-driver/v2 v2.4.0/go.mod h1:jHeEDJHJq7tm6ZF45Issun9dbogjfnPySb1vXA7EeAI= +go.opencensus.io v0.24.0 h1:y73uSU6J157QMP2kn2r30vwW1A2W2WFwSCGnAVxeaD0= +go.opencensus.io v0.24.0/go.mod h1:vNK8G9p7aAivkbmorf4v+7Hgx+Zs0yY+0fOtgBfjQKo= go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= go.opentelemetry.io/contrib/detectors/gcp v1.34.0 h1:JRxssobiPg23otYU5SbWtQC//snGVIM3Tx6QRzlQBao= go.opentelemetry.io/contrib/detectors/gcp v1.34.0/go.mod h1:cV4BMFcscUR/ckqLkbfQmF0PRsq8w/lMGzdbCSveBHo= go.opentelemetry.io/otel v1.34.0/go.mod h1:OWFPOQ+h4G8xpyjgqo4SxJYdDQ/qmRH+wivy7zzx9oI= go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM= +go.opentelemetry.io/otel/exporters/jaeger v1.17.0 h1:D7UpUy2Xc2wsi1Ras6V40q806WM07rqoCWzXu7Sqy+4= +go.opentelemetry.io/otel/exporters/jaeger v1.17.0/go.mod h1:nPCqOnEH9rNLKqH/+rrUjiMzHJdV1BlpKcTwRTyKkKI= +go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.24.0 h1:t6wl9SPayj+c7lEIFgm4ooDBZVb01IhLB4InpomhRw8= +go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.24.0/go.mod h1:iSDOcsnSA5INXzZtwaBPrKp/lWu/V14Dd+llD0oI2EA= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.24.0 h1:Mw5xcxMwlqoJd97vwPxA8isEaIoxsta9/Q51+TTJLGE= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.24.0/go.mod h1:CQNu9bj7o7mC6U7+CA/schKEYakYXWr79ucDHTMGhCM= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.24.0 h1:Xw8U6u2f8DK2XAkGRFV7BBLENgnTGX9i4rQRxJf+/vs= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.24.0/go.mod h1:6KW1Fm6R/s6Z3PGXwSJN2K4eT6wQB3vXX6CVnYX9NmM= +go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.24.0 h1:s0PHtIkN+3xrbDOpt2M8OTG92cWqUESvzh2MxiR5xY8= +go.opentelemetry.io/otel/exporters/stdout/stdouttrace v1.24.0/go.mod h1:hZlFbDbRt++MMPCCfSJfmhkGIWnX1h3XjkfxZUjLrIA= +go.opentelemetry.io/otel/exporters/zipkin v1.24.0 h1:3evrL5poBuh1KF51D9gO/S+N/1msnm4DaBqs/rpXUqY= +go.opentelemetry.io/otel/exporters/zipkin v1.24.0/go.mod h1:0EHgD8R0+8yRhUYJOGR8Hfg2dpiJQxDOszd5smVO9wM= go.opentelemetry.io/otel/metric v1.34.0/go.mod h1:CEDrp0fy2D0MvkXE+dPV7cMi8tWZwX3dmaIhwPOaqHE= go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI= go.opentelemetry.io/otel/sdk v1.34.0/go.mod h1:0e/pNiaMAqaykJGKbi+tSjWfNNHMTxoC9qANsCzbyxU= @@ -187,14 +380,27 @@ go.opentelemetry.io/otel/trace v1.34.0/go.mod h1:Svm7lSjQD7kG7KJ/MUHPVXSDGz2OX4h go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs= go.opentelemetry.io/proto/otlp v1.0.0 h1:T0TX0tmXU8a3CbNXzEKGeU5mIVOdf0oykP+u2lIVU/I= go.opentelemetry.io/proto/otlp v1.0.0/go.mod h1:Sy6pihPLfYHkr3NkUbEhGHFhINUSI/v80hjKIs5JXpM= +go.opentelemetry.io/proto/otlp v1.3.1 h1:TrMUixzpM0yuc/znrFTP9MMRh8trP93mkCiDVeXrui0= +go.opentelemetry.io/proto/otlp v1.3.1/go.mod h1:0X1WI4de4ZsLrrJNLAQbFeLCm3T7yBkR0XqQ7niQU+8= +go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/atomic v1.10.0 h1:9qC72Qh0+3MqyJbAn8YU5xVq1frD8bn3JtD2oXtafVQ= +go.uber.org/atomic v1.10.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +go.uber.org/mock v0.4.0 h1:VcM4ZOtdbR4f6VXfiOpwpVJDL6lCReaZ6mw31wqh7KU= +go.uber.org/mock v0.4.0/go.mod h1:a6FSlNadKUHUa9IP5Vyt1zh4fC7uAwxMutEAscFbkZc= +go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= +go.uber.org/zap v1.21.0 h1:WefMeulhovoZ2sYXz7st6K0sLj7bBhpiFaud4r4zST8= +go.uber.org/zap v1.21.0/go.mod h1:wjWOCqI0f2ZZrJF/UufIOkiC8ii6tm1iqIsLo76RfJw= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys= +golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc= +golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k= golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg= golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ= +golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= golang.org/x/net v0.38.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8= golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8= @@ -204,9 +410,12 @@ golang.org/x/oauth2 v0.14.0/go.mod h1:lAtNWgaWfL4cm7j2OV8TxGi9Qb7ECORx8DktCY74Ow golang.org/x/oauth2 v0.25.0 h1:CY4y7XT9v0cRI9oupztF8AgiIu99L/ksR/Xp/6jrZ70= golang.org/x/oauth2 v0.25.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= @@ -216,16 +425,24 @@ golang.org/x/term v0.20.0 h1:VnkxpohqXaOBYJtBmEppKUG6mXpi+4O6purfc2+sMhw= golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= golang.org/x/term v0.34.0 h1:O/2T7POpk0ZZ7MAzMeWFSg6S5IpWd/RXDlM9hgM3DR4= golang.org/x/term v0.34.0/go.mod h1:5jC53AEywhIVebHgPVeg0mj8OD3VO9OzclacVrqpaAw= +golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.23.0/go.mod h1:/BLNzu4aZCJ1+kcD0DNRotWKage4q2rGVAg4o22unh4= golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= -golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +golang.org/x/time v0.10.0 h1:3usCWA8tQn0L8+hFJQNgzpWbd89begxN66o1Ojdn5L4= +golang.org/x/time v0.10.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.6.0 h1:BOw41kyTf3PuCW1pVQf8+Cyg8pMlkYB1oo9iJ6D/lKM= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0= golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 h1:H2TDz8ibqkAF6YGhCdN3jS9O0/s90v0rJh3X/OLHEUk= +golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8= +google.golang.org/api v0.153.0 h1:N1AwGhielyKFaUqH07/ZSIQR3uNPcV7NVw0vj+j4iR4= +google.golang.org/api v0.153.0/go.mod h1:3qNJX5eOmhiWYc67jRA/3GsDw97UFb5ivv7Y2PrriAY= +google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c= +google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= google.golang.org/appengine v1.6.8 h1:IhEN5q69dyKagZPYMSdIjS2HqprW324FRQZJcGqPAsM= google.golang.org/appengine v1.6.8/go.mod h1:1jJ3jBArFh5pcgW8gCtRJnepW8FzD1V44FJffLiz/Ds= google.golang.org/genproto v0.0.0-20231212172506-995d672761c0 h1:YJ5pD9rF8o9Qtta0Cmy9rdBwkSjrTCT6XTiUQVOtIos= @@ -238,11 +455,39 @@ google.golang.org/genproto/googleapis/rpc v0.0.0-20250115164207-1a7da9e5054f/go. google.golang.org/genproto/googleapis/rpc v0.0.0-20251103181224-f26f9409b101/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk= google.golang.org/grpc v1.71.0/go.mod h1:H0GRtasmQOh9LkFoCPDu3ZrwUtD1YGE+b2vYBYd/8Ec= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.33.0/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= google.golang.org/protobuf v1.36.4/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= +gopkg.in/cheggaaa/pb.v1 v1.0.28 h1:n1tBJnnK2r7g9OW2btFH91V92STTUevLXYFb8gy9EMk= +gopkg.in/cheggaaa/pb.v1 v1.0.28/go.mod h1:V/YB90LKu/1FcN3WVnfiiE5oMCibMjukxqG/qStrOgw= +gopkg.in/h2non/gock.v1 v1.1.2 h1:jBbHXgGBK/AoPVfJh5x4r/WxIrElvbLel8TCZkkZJoY= +gopkg.in/h2non/gock.v1 v1.1.2/go.mod h1:n7UGz/ckNChHiK05rDoiC4MYSunEC/lyaUm2WWaDva0= +gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc= +gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= +gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +k8s.io/api v0.29.3 h1:2ORfZ7+bGC3YJqGpV0KSDDEVf8hdGQ6A03/50vj8pmw= +k8s.io/api v0.29.3/go.mod h1:y2yg2NTyHUUkIoTC+phinTnEa3KFM6RZ3szxt014a80= +k8s.io/apimachinery v0.29.4 h1:RaFdJiDmuKs/8cm1M6Dh1Kvyh59YQFDcFuFTSmXes6Q= +k8s.io/apimachinery v0.29.4/go.mod h1:i3FJVwhvSp/6n8Fl4K97PJEP8C+MM+aoDq4+ZJBf70Y= +k8s.io/client-go v0.29.3 h1:R/zaZbEAxqComZ9FHeQwOh3Y1ZUs7FaHKZdQtIc2WZg= +k8s.io/client-go v0.29.3/go.mod h1:tkDisCvgPfiRpxGnOORfkljmS+UrW+WtXAy2fTvXJB0= +k8s.io/klog/v2 v2.110.1 h1:U/Af64HJf7FcwMcXyKm2RPM22WZzyR7OSpYj5tg3cL0= +k8s.io/klog/v2 v2.110.1/go.mod h1:YGtd1984u+GgbuZ7e08/yBuAfKLSO0+uR1Fhi6ExXjo= +k8s.io/kube-openapi v0.0.0-20231010175941-2dd684a91f00 h1:aVUu9fTY98ivBPKR9Y5w/AuzbMm96cd3YHRTU83I780= +k8s.io/kube-openapi v0.0.0-20231010175941-2dd684a91f00/go.mod h1:AsvuZPBlUDVuCdzJ87iajxtXuR9oktsTctW/R9wwouA= +k8s.io/utils v0.0.0-20240711033017-18e509b52bc8 h1:pUdcCO1Lk/tbT5ztQWOBi5HBgbBP1J8+AsQnQCKsi8A= +k8s.io/utils v0.0.0-20240711033017-18e509b52bc8/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0= nullprogram.com/x/optparse v1.0.0 h1:xGFgVi5ZaWOnYdac2foDT3vg0ZZC9ErXFV57mr4OHrI= nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50= rsc.io/pdf v0.1.1 h1:k1MczvYDUvJBe93bYd7wrZLLUEcLZAuF824/I4e5Xr4= rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= +sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd h1:EDPBXCAspyGV4jQlpZSudPeMmr1bNJefnuqLsRAsHZo= +sigs.k8s.io/json v0.0.0-20221116044647-bc3834ca7abd/go.mod h1:B8JuhiUyNFVKdsE8h686QcCxMaH6HrOAZj4vswFpcB0= +sigs.k8s.io/structured-merge-diff/v4 v4.4.1 h1:150L+0vs/8DA78h1u02ooW1/fFq/Lwr+sGiqlzvrtq4= +sigs.k8s.io/structured-merge-diff/v4 v4.4.1/go.mod h1:N8hJocpFajUSSeSJ9bOZ77VzejKZaXsTtZo4/u7Io08= +sigs.k8s.io/yaml v1.3.0 h1:a2VclLzOGrwOHDiV8EfBGhvjHvP46CtW5j6POvhYGGo= +sigs.k8s.io/yaml v1.3.0/go.mod h1:GeOyir5tyXNByN85N/dRIT9es5UQNerPYEKK56eTBm8= diff --git a/integrations/chi/annotation.go b/integrations/chi/annotation.go index b8ec1d9..cb0fc35 100644 --- a/integrations/chi/annotation.go +++ b/integrations/chi/annotation.go @@ -1,6 +1,8 @@ +// @Author daixk 2025/12/28 package chi import ( + "context" "net/http" "strings" @@ -10,122 +12,246 @@ import ( // Annotation annotation structure | 注解结构体 type Annotation struct { - CheckLogin bool `json:"checkLogin"` - CheckRole []string `json:"checkRole"` - CheckPermission []string `json:"checkPermission"` - CheckDisable bool `json:"checkDisable"` - Ignore bool `json:"ignore"` + AuthType string `json:"authType"` // Optional: specify auth type | 可选:指定认证类型 + CheckLogin bool `json:"checkLogin"` // Check login | 检查登录 + CheckRole []string `json:"checkRole"` // Check roles | 检查角色 + CheckPermission []string `json:"checkPermission"` // Check permissions | 检查权限 + CheckDisable bool `json:"checkDisable"` // Check disable status | 检查封禁状态 + Ignore bool `json:"ignore"` // Ignore authentication | 忽略认证 + LogicType LogicType `json:"logicType"` // OR or AND logic (default: OR) | OR 或 AND 逻辑(默认: OR) } // GetHandler gets handler with annotations | 获取带注解的处理器 -func GetHandler(handler http.Handler, annotations ...*Annotation) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Check if authentication should be ignored | 检查是否忽略认证 +func GetHandler(handler http.HandlerFunc, annotations ...*Annotation) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + // Ignore authentication | 忽略认证直接放行 if len(annotations) > 0 && annotations[0].Ignore { if handler != nil { - handler.ServeHTTP(w, r) + handler(w, r) } return } - // Get token from context using configured TokenName | 从上下文获取Token(使用配置的TokenName) - ctx := NewChiContext(w, r) - saCtx := core.NewContext(ctx, stputil.GetManager()) + // Check if any authentication is needed | 检查是否需要任何认证 + ann := &Annotation{} + if len(annotations) > 0 { + ann = annotations[0] + } + + // No authentication required | 无需任何认证 + needAuth := ann.CheckLogin || ann.CheckDisable || len(ann.CheckPermission) > 0 || len(ann.CheckRole) > 0 + if !needAuth { + if handler != nil { + handler(w, r) + } + return + } + + ctx := r.Context() + + // Get manager | 获取 Manager + mgr, err := stputil.GetManager(ann.AuthType) + if err != nil { + writeErrorResponse(w, err) + return + } + + // Get SaTokenContext (reuse cached context) | 获取 SaTokenContext(复用缓存上下文) + chiCtx := NewChiContext(w, r) + saCtx := getSaContext(chiCtx.(*ChiContext), r, mgr) token := saCtx.GetTokenValue() + if token == "" { writeErrorResponse(w, core.NewNotLoginError()) return } // Check login | 检查登录 - if !stputil.IsLogin(token) { - writeErrorResponse(w, core.NewNotLoginError()) + if err := mgr.CheckLogin(ctx, token); err != nil { + writeErrorResponse(w, err) return } - // Get login ID | 获取登录ID - loginID, err := stputil.GetLoginID(token) - if err != nil { - writeErrorResponse(w, err) - return + // Get loginID for further checks | 获取 loginID 用于后续检查 + var loginID string + if ann.CheckDisable || len(ann.CheckPermission) > 0 || len(ann.CheckRole) > 0 { + loginID, err = mgr.GetLoginIDNotCheck(ctx, token) + if err != nil { + writeErrorResponse(w, err) + return + } } // Check if account is disabled | 检查是否被封禁 - if len(annotations) > 0 && annotations[0].CheckDisable { - if stputil.IsDisable(loginID) { + if ann.CheckDisable { + if mgr.IsDisable(ctx, loginID) { writeErrorResponse(w, core.NewAccountDisabledError(loginID)) return } } // Check permission | 检查权限 - if len(annotations) > 0 && len(annotations[0].CheckPermission) > 0 { - hasPermission := false - for _, perm := range annotations[0].CheckPermission { - if stputil.HasPermission(loginID, strings.TrimSpace(perm)) { - hasPermission = true - break - } + if len(ann.CheckPermission) > 0 { + var ok bool + if ann.LogicType == LogicAnd { + ok = mgr.HasPermissionsAnd(ctx, loginID, ann.CheckPermission) + } else { + ok = mgr.HasPermissionsOr(ctx, loginID, ann.CheckPermission) } - if !hasPermission { - writeErrorResponse(w, core.NewPermissionDeniedError(strings.Join(annotations[0].CheckPermission, ","))) + if !ok { + writeErrorResponse(w, core.NewPermissionDeniedError(strings.Join(ann.CheckPermission, ","))) return } } // Check role | 检查角色 - if len(annotations) > 0 && len(annotations[0].CheckRole) > 0 { - hasRole := false - for _, role := range annotations[0].CheckRole { - if stputil.HasRole(loginID, strings.TrimSpace(role)) { - hasRole = true - break - } + if len(ann.CheckRole) > 0 { + var ok bool + if ann.LogicType == LogicAnd { + ok = mgr.HasRolesAnd(ctx, loginID, ann.CheckRole) + } else { + ok = mgr.HasRolesOr(ctx, loginID, ann.CheckRole) } - if !hasRole { - writeErrorResponse(w, core.NewRoleDeniedError(strings.Join(annotations[0].CheckRole, ","))) + if !ok { + writeErrorResponse(w, core.NewRoleDeniedError(strings.Join(ann.CheckRole, ","))) return } } // All checks passed, execute original handler | 所有检查通过,执行原函数 if handler != nil { - handler.ServeHTTP(w, r) + handler(w, r) } - }) + } } -// CheckLoginMiddleware decorator for login checking | 检查登录装饰器 -func CheckLoginMiddleware() func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return GetHandler(next, &Annotation{CheckLogin: true}) +// CheckLoginHandler decorator for login checking | 检查登录装饰器 +func CheckLoginHandler(authType ...string) http.HandlerFunc { + ann := &Annotation{CheckLogin: true} + if len(authType) > 0 { + ann.AuthType = authType[0] } + return GetHandler(nil, ann) +} + +// CheckRoleHandler decorator for role checking | 检查角色装饰器 +func CheckRoleHandler(roles ...string) http.HandlerFunc { + return GetHandler(nil, &Annotation{CheckRole: roles}) } -// CheckRoleMiddleware decorator for role checking | 检查角色装饰器 -func CheckRoleMiddleware(roles ...string) func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return GetHandler(next, &Annotation{CheckRole: roles}) +// CheckRoleHandlerWithAuthType decorator for role checking with auth type | 检查角色装饰器(带认证类型) +func CheckRoleHandlerWithAuthType(authType string, roles ...string) http.HandlerFunc { + return GetHandler(nil, &Annotation{CheckRole: roles, AuthType: authType}) +} + +// CheckPermissionHandler decorator for permission checking | 检查权限装饰器 +func CheckPermissionHandler(perms ...string) http.HandlerFunc { + return GetHandler(nil, &Annotation{CheckPermission: perms}) +} + +// CheckPermissionHandlerWithAuthType decorator for permission checking with auth type | 检查权限装饰器(带认证类型) +func CheckPermissionHandlerWithAuthType(authType string, perms ...string) http.HandlerFunc { + return GetHandler(nil, &Annotation{CheckPermission: perms, AuthType: authType}) +} + +// CheckDisableHandler decorator for checking if account is disabled | 检查是否被封禁装饰器 +func CheckDisableHandler(authType ...string) http.HandlerFunc { + ann := &Annotation{CheckDisable: true} + if len(authType) > 0 { + ann.AuthType = authType[0] } + return GetHandler(nil, ann) +} + +// IgnoreHandler decorator to ignore authentication | 忽略认证装饰器 +func IgnoreHandler() http.HandlerFunc { + return GetHandler(nil, &Annotation{Ignore: true}) +} + +// ============ Combined Handler | 组合处理器 ============ + +// CheckLoginAndRoleHandler checks login and role | 检查登录和角色 +func CheckLoginAndRoleHandler(roles ...string) http.HandlerFunc { + return GetHandler(nil, &Annotation{CheckLogin: true, CheckRole: roles}) +} + +// CheckLoginAndPermissionHandler checks login and permission | 检查登录和权限 +func CheckLoginAndPermissionHandler(perms ...string) http.HandlerFunc { + return GetHandler(nil, &Annotation{CheckLogin: true, CheckPermission: perms}) } -// CheckPermissionMiddleware decorator for permission checking | 检查权限装饰器 -func CheckPermissionMiddleware(perms ...string) func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return GetHandler(next, &Annotation{CheckPermission: perms}) +// CheckAllHandler checks login, role, permission and disable status | 全面检查 +func CheckAllHandler(roles []string, perms []string) http.HandlerFunc { + return GetHandler(nil, &Annotation{ + CheckLogin: true, + CheckRole: roles, + CheckPermission: perms, + CheckDisable: true, + }) +} + +// ============ Context Helper | 上下文辅助函数 ============ + +// GetLoginIDFromRequest gets login ID from request context | 从请求上下文获取登录 ID +func GetLoginIDFromRequest(w http.ResponseWriter, r *http.Request, authType ...string) (string, error) { + var at string + if len(authType) > 0 { + at = authType[0] } + + mgr, err := stputil.GetManager(at) + if err != nil { + return "", err + } + + chiCtx := NewChiContext(w, r) + saCtx := getSaContext(chiCtx.(*ChiContext), r, mgr) + token := saCtx.GetTokenValue() + if token == "" { + return "", core.ErrNotLogin + } + return mgr.GetLoginID(r.Context(), token) } -// CheckDisableMiddleware decorator for checking if account is disabled | 检查是否被封禁装饰器 -func CheckDisableMiddleware() func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return GetHandler(next, &Annotation{CheckDisable: true}) +// IsLoginFromRequest checks if user is logged in from request | 从请求检查用户是否已登录 +func IsLoginFromRequest(w http.ResponseWriter, r *http.Request, authType ...string) bool { + var at string + if len(authType) > 0 { + at = authType[0] } + + mgr, err := stputil.GetManager(at) + if err != nil { + return false + } + + chiCtx := NewChiContext(w, r) + saCtx := getSaContext(chiCtx.(*ChiContext), r, mgr) + token := saCtx.GetTokenValue() + if token == "" { + return false + } + return mgr.IsLogin(r.Context(), token) } -// IgnoreMiddleware decorator to ignore authentication | 忽略认证装饰器 -func IgnoreMiddleware() func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return GetHandler(next, &Annotation{Ignore: true}) +// GetTokenFromRequest gets token from request (exported) | 从请求获取 Token(导出) +func GetTokenFromRequest(w http.ResponseWriter, r *http.Request, authType ...string) string { + var at string + if len(authType) > 0 { + at = authType[0] + } + + mgr, err := stputil.GetManager(at) + if err != nil { + return "" } + + chiCtx := NewChiContext(w, r) + saCtx := getSaContext(chiCtx.(*ChiContext), r, mgr) + return saCtx.GetTokenValue() +} + +// WithContext creates a new context with sa-token context | 创建带 sa-token 上下文的新上下文 +func WithContext(r *http.Request, authType ...string) context.Context { + return r.Context() } diff --git a/integrations/chi/export.go b/integrations/chi/export.go index 7b756c9..aaf7e50 100644 --- a/integrations/chi/export.go +++ b/integrations/chi/export.go @@ -1,364 +1,950 @@ package chi import ( + "context" "time" + "github.com/click33/sa-token-go/codec/json" + "github.com/click33/sa-token-go/codec/msgpack" "github.com/click33/sa-token-go/core" + "github.com/click33/sa-token-go/core/adapter" + "github.com/click33/sa-token-go/core/builder" + "github.com/click33/sa-token-go/core/config" + "github.com/click33/sa-token-go/core/listener" + "github.com/click33/sa-token-go/core/manager" + "github.com/click33/sa-token-go/core/oauth2" + "github.com/click33/sa-token-go/core/security" + "github.com/click33/sa-token-go/core/session" + "github.com/click33/sa-token-go/generator/sgenerator" + "github.com/click33/sa-token-go/log/nop" + "github.com/click33/sa-token-go/log/slog" + "github.com/click33/sa-token-go/pool/ants" + "github.com/click33/sa-token-go/storage/memory" + "github.com/click33/sa-token-go/storage/redis" "github.com/click33/sa-token-go/stputil" ) -// ============ Re-export core types | 重新导出核心类型 ============ +// ============ Type Aliases | 类型别名 ============ -// Configuration related types | 配置相关类型 type ( - Config = core.Config - CookieConfig = core.CookieConfig - TokenStyle = core.TokenStyle + // Config 配置 + Config = config.Config + // Manager 管理器 + Manager = manager.Manager + // Session 会话 + Session = session.Session + // TokenInfo Token信息 + TokenInfo = manager.TokenInfo + // DisableInfo 封禁信息 + DisableInfo = manager.DisableInfo + // Builder 构建器 + Builder = builder.Builder + // SaTokenError 错误类型 + SaTokenError = core.SaTokenError + // Event 事件类型 + Event = listener.Event + // EventData 事件数据 + EventData = listener.EventData + // Listener 事件监听器 + Listener = listener.Listener + // ListenerConfig 监听器配置 + ListenerConfig = listener.ListenerConfig + // RefreshTokenInfo 刷新令牌信息 + RefreshTokenInfo = security.RefreshTokenInfo + // AccessTokenInfo 访问令牌信息 + AccessTokenInfo = security.AccessTokenInfo + // OAuth2Client OAuth2客户端 + OAuth2Client = oauth2.Client + // OAuth2AccessToken OAuth2访问令牌 + OAuth2AccessToken = oauth2.AccessToken + // AuthorizationCode 授权码 + AuthorizationCode = oauth2.AuthorizationCode + // OAuth2TokenRequest OAuth2令牌请求 + OAuth2TokenRequest = oauth2.TokenRequest + // OAuth2GrantType OAuth2授权类型 + OAuth2GrantType = oauth2.GrantType + // OAuth2UserValidator OAuth2用户验证器 + OAuth2UserValidator = oauth2.UserValidator + // Storage 存储接口 + Storage = adapter.Storage + // Codec 编解码接口 + Codec = adapter.Codec + // Log 日志接口 + Log = adapter.Log + // Pool 协程池接口 + Pool = adapter.Pool + // Generator 生成器接口 + Generator = adapter.Generator + + // ============ Codec Types | 编解码器类型 ============ + + // JSONSerializer JSON编解码器 + JSONSerializer = json.JSONSerializer + // MsgPackSerializer MsgPack编解码器 + MsgPackSerializer = msgpack.MsgPackSerializer + + // ============ Storage Types | 存储类型 ============ + + // MemoryStorage 内存存储 + MemoryStorage = memory.Storage + // RedisStorage Redis存储 + RedisStorage = redis.Storage + // RedisConfig Redis配置 + RedisConfig = redis.Config + // RedisBuilder Redis构建器 + RedisBuilder = redis.Builder + + // ============ Logger Types | 日志类型 ============ + + // SlogLogger 标准日志实现 + SlogLogger = slog.Logger + // SlogLoggerConfig 标准日志配置 + SlogLoggerConfig = slog.LoggerConfig + // SlogLogLevel 日志级别 + SlogLogLevel = slog.LogLevel + // NopLogger 空日志实现 + NopLogger = nop.NopLogger + + // ============ Generator Types | 生成器类型 ============ + + // TokenGenerator Token生成器 + TokenGenerator = sgenerator.Generator + // TokenStyle Token风格 + TokenStyle = adapter.TokenStyle + + // ============ Pool Types | 协程池类型 ============ + + // RenewPoolManager 续期池管理器 + RenewPoolManager = ants.RenewPoolManager + // RenewPoolConfig 续期池配置 + RenewPoolConfig = ants.RenewPoolConfig ) -// Token style constants | Token风格常量 +// ============ Error Codes | 错误码 ============ + const ( - TokenStyleUUID = core.TokenStyleUUID - TokenStyleSimple = core.TokenStyleSimple - TokenStyleRandom32 = core.TokenStyleRandom32 - TokenStyleRandom64 = core.TokenStyleRandom64 - TokenStyleRandom128 = core.TokenStyleRandom128 - TokenStyleJWT = core.TokenStyleJWT - TokenStyleHash = core.TokenStyleHash - TokenStyleTimestamp = core.TokenStyleTimestamp - TokenStyleTik = core.TokenStyleTik + CodeSuccess = core.CodeSuccess + CodeBadRequest = core.CodeBadRequest + CodeNotLogin = core.CodeNotLogin + CodePermissionDenied = core.CodePermissionDenied + CodeNotFound = core.CodeNotFound + CodeServerError = core.CodeServerError + CodeTokenInvalid = core.CodeTokenInvalid + CodeTokenExpired = core.CodeTokenExpired + CodeAccountDisabled = core.CodeAccountDisabled + CodeKickedOut = core.CodeKickedOut + CodeActiveTimeout = core.CodeActiveTimeout + CodeMaxLoginCount = core.CodeMaxLoginCount + CodeStorageError = core.CodeStorageError + CodeInvalidParameter = core.CodeInvalidParameter + CodeSessionError = core.CodeSessionError ) -// Core types | 核心类型 -type ( - Manager = core.Manager - TokenInfo = core.TokenInfo - Session = core.Session - TokenGenerator = core.TokenGenerator - SaTokenContext = core.SaTokenContext - Builder = core.Builder - NonceManager = core.NonceManager - RefreshTokenInfo = core.RefreshTokenInfo - RefreshTokenManager = core.RefreshTokenManager - OAuth2Server = core.OAuth2Server - OAuth2Client = core.OAuth2Client - OAuth2AccessToken = core.OAuth2AccessToken - OAuth2GrantType = core.OAuth2GrantType -) +// ============ Errors | 错误变量 ============ -// Adapter interfaces | 适配器接口 -type ( - Storage = core.Storage - RequestContext = core.RequestContext +var ( + // Authentication Errors | 认证错误 + ErrNotLogin = core.ErrNotLogin + ErrTokenInvalid = core.ErrTokenInvalid + ErrTokenExpired = core.ErrTokenExpired + ErrTokenKickout = core.ErrTokenKickout + ErrTokenReplaced = core.ErrTokenReplaced + ErrInvalidLoginID = core.ErrInvalidLoginID + ErrInvalidDevice = core.ErrInvalidDevice + ErrTokenNotFound = core.ErrTokenNotFound + + // Authorization Errors | 授权错误 + ErrPermissionDenied = core.ErrPermissionDenied + ErrRoleDenied = core.ErrRoleDenied + + // Account Errors | 账号错误 + ErrAccountDisabled = core.ErrAccountDisabled + ErrAccountNotFound = core.ErrAccountNotFound + ErrLoginLimitExceeded = core.ErrLoginLimitExceeded + + // Session Errors | 会话错误 + ErrSessionNotFound = core.ErrSessionNotFound + ErrActiveTimeout = core.ErrActiveTimeout + ErrSessionInvalidDataKey = core.ErrSessionInvalidDataKey + ErrSessionIDEmpty = core.ErrSessionIDEmpty + + // Security Errors | 安全错误 + ErrInvalidNonce = core.ErrInvalidNonce + ErrRefreshTokenExpired = core.ErrRefreshTokenExpired + ErrNonceInvalidRefreshToken = core.ErrNonceInvalidRefreshToken + ErrInvalidLoginIDEmpty = core.ErrInvalidLoginIDEmpty + + // OAuth2 Errors | OAuth2错误 + ErrClientOrClientIDEmpty = core.ErrClientOrClientIDEmpty + ErrClientNotFound = core.ErrClientNotFound + ErrUserIDEmpty = core.ErrUserIDEmpty + ErrInvalidRedirectURI = core.ErrInvalidRedirectURI + ErrInvalidClientCredentials = core.ErrInvalidClientCredentials + ErrInvalidAuthCode = core.ErrInvalidAuthCode + ErrAuthCodeUsed = core.ErrAuthCodeUsed + ErrAuthCodeExpired = core.ErrAuthCodeExpired + ErrClientMismatch = core.ErrClientMismatch + ErrRedirectURIMismatch = core.ErrRedirectURIMismatch + ErrInvalidAccessToken = core.ErrInvalidAccessToken + ErrInvalidRefreshToken = core.ErrInvalidRefreshToken + ErrInvalidScope = core.ErrInvalidScope + + // System Errors | 系统错误 + ErrStorageUnavailable = core.ErrStorageUnavailable + ErrSerializeFailed = core.ErrSerializeFailed + ErrDeserializeFailed = core.ErrDeserializeFailed + ErrTypeConvert = core.ErrTypeConvert ) -// Event related types | 事件相关类型 -type ( - EventListener = core.EventListener - EventManager = core.EventManager - EventData = core.EventData - Event = core.Event - ListenerFunc = core.ListenerFunc - ListenerConfig = core.ListenerConfig -) +// ============ Error Constructors | 错误构造函数 ============ -// Event constants | 事件常量 -const ( - EventLogin = core.EventLogin - EventLogout = core.EventLogout - EventKickout = core.EventKickout - EventDisable = core.EventDisable - EventUntie = core.EventUntie - EventRenew = core.EventRenew - EventCreateSession = core.EventCreateSession - EventDestroySession = core.EventDestroySession - EventPermissionCheck = core.EventPermissionCheck - EventRoleCheck = core.EventRoleCheck - EventAll = core.EventAll +var ( + NewError = core.NewError + NewErrorWithContext = core.NewErrorWithContext + NewNotLoginError = core.NewNotLoginError + NewPermissionDeniedError = core.NewPermissionDeniedError + NewRoleDeniedError = core.NewRoleDeniedError + NewAccountDisabledError = core.NewAccountDisabledError ) -// OAuth2 grant type constants | OAuth2授权类型常量 -const ( - GrantTypeAuthorizationCode = core.GrantTypeAuthorizationCode - GrantTypeRefreshToken = core.GrantTypeRefreshToken - GrantTypeClientCredentials = core.GrantTypeClientCredentials - GrantTypePassword = core.GrantTypePassword -) +// ============ Error Checking Helpers | 错误检查辅助函数 ============ -// Utility functions | 工具函数 var ( - RandomString = core.RandomString - IsEmpty = core.IsEmpty - IsNotEmpty = core.IsNotEmpty - DefaultString = core.DefaultString - ContainsString = core.ContainsString - RemoveString = core.RemoveString - UniqueStrings = core.UniqueStrings - MergeStrings = core.MergeStrings - MatchPattern = core.MatchPattern + IsNotLoginError = core.IsNotLoginError + IsPermissionDeniedError = core.IsPermissionDeniedError + IsAccountDisabledError = core.IsAccountDisabledError + IsTokenError = core.IsTokenError + GetErrorCode = core.GetErrorCode ) -// ============ Core constructor functions | 核心构造函数 ============ +// ============ Manager Management | Manager 管理 ============ + +// SetManager stores the manager-example in the global map using the specified autoType | 使用指定的 autoType 将管理器存储在全局 map 中 +func SetManager(mgr *manager.Manager) { + stputil.SetManager(mgr) +} + +// GetManager retrieves the manager-example from the global map using the specified autoType | 使用指定的 autoType 从全局 map 中获取管理器 +func GetManager(autoType ...string) (*manager.Manager, error) { + return stputil.GetManager(autoType...) +} -// DefaultConfig returns default configuration | 返回默认配置 -func DefaultConfig() *Config { - return core.DefaultConfig() +// DeleteManager delete the specific manager-example for the given autoType and releases resources | 删除指定的管理器并释放资源 +func DeleteManager(autoType ...string) error { + return stputil.DeleteManager(autoType...) } -// NewManager creates a new authentication manager | 创建新的认证管理器 -func NewManager(storage Storage, cfg *Config) *Manager { - return core.NewManager(storage, cfg) +// DeleteAllManager delete all managers in the global map and releases resources | 关闭所有管理器并释放资源 +func DeleteAllManager() { + stputil.DeleteAllManager() } -// NewContext creates a new Sa-Token context | 创建新的Sa-Token上下文 -func NewContext(ctx RequestContext, mgr *Manager) *SaTokenContext { - return core.NewContext(ctx, mgr) +// ============ Builder & Config | 构建器和配置 ============ + +// NewDefaultBuild creates a new default builder | 创建默认构建器 +func NewDefaultBuild() *builder.Builder { + return builder.NewBuilder() } -// NewSession creates a new session | 创建新的Session -func NewSession(id string, storage Storage, prefix string) *Session { - return core.NewSession(id, storage, prefix) +// NewDefaultConfig creates a new default config | 创建默认配置 +func NewDefaultConfig() *config.Config { + return config.DefaultConfig() } -// LoadSession loads an existing session | 加载已存在的Session -func LoadSession(id string, storage Storage, prefix string) (*Session, error) { - return core.LoadSession(id, storage, prefix) +// DefaultLoggerConfig returns the default logger config | 返回默认日志配置 +func DefaultLoggerConfig() *slog.LoggerConfig { + return slog.DefaultLoggerConfig() } -// NewTokenGenerator creates a new token generator | 创建新的Token生成器 -func NewTokenGenerator(cfg *Config) *TokenGenerator { - return core.NewTokenGenerator(cfg) +// DefaultRenewPoolConfig returns the default renew pool config | 返回默认续期池配置 +func DefaultRenewPoolConfig() *ants.RenewPoolConfig { + return ants.DefaultRenewPoolConfig() } -// NewEventManager creates a new event manager | 创建新的事件管理器 -func NewEventManager() *EventManager { - return core.NewEventManager() +// ============ Codec Constructors | 编解码器构造函数 ============ + +// NewJSONSerializer creates a new JSON serializer | 创建JSON序列化器 +func NewJSONSerializer() *json.JSONSerializer { + return json.NewJSONSerializer() } -// NewBuilder creates a new builder for fluent configuration | 创建新的Builder构建器(用于流式配置) -func NewBuilder() *Builder { - return core.NewBuilder() +// NewMsgPackSerializer creates a new MsgPack serializer | 创建MsgPack序列化器 +func NewMsgPackSerializer() *msgpack.MsgPackSerializer { + return msgpack.NewMsgPackSerializer() } -// NewNonceManager creates a new nonce manager | 创建新的Nonce管理器 -func NewNonceManager(storage Storage, prefix string, ttl ...int64) *NonceManager { - return core.NewNonceManager(storage, prefix, ttl...) +// ============ Storage Constructors | 存储构造函数 ============ + +// NewMemoryStorage creates a new memory storage | 创建内存存储 +func NewMemoryStorage() *memory.Storage { + return memory.NewStorage() } -// NewRefreshTokenManager creates a new refresh token manager | 创建新的刷新令牌管理器 -func NewRefreshTokenManager(storage Storage, prefix string, cfg *Config) *RefreshTokenManager { - return core.NewRefreshTokenManager(storage, prefix, cfg) +// NewMemoryStorageWithCleanupInterval creates a new memory storage with cleanup interval | 创建带清理间隔的内存存储 +func NewMemoryStorageWithCleanupInterval(interval time.Duration) *memory.Storage { + return memory.NewStorageWithCleanupInterval(interval) } -// NewOAuth2Server creates a new OAuth2 server | 创建新的OAuth2服务器 -func NewOAuth2Server(storage Storage, prefix string) *OAuth2Server { - return core.NewOAuth2Server(storage, prefix) +// NewRedisStorage creates a new Redis storage from URL | 通过URL创建Redis存储 +func NewRedisStorage(url string) (*redis.Storage, error) { + return redis.NewStorage(url) } -// ============ Global StpUtil functions | 全局StpUtil函数 ============ +// NewRedisStorageFromConfig creates a new Redis storage from config | 通过配置创建Redis存储 +func NewRedisStorageFromConfig(cfg *redis.Config) (*redis.Storage, error) { + return redis.NewStorageFromConfig(cfg) +} -// SetManager sets the global Manager (must be called first) | 设置全局Manager(必须先调用此方法) -func SetManager(mgr *Manager) { - stputil.SetManager(mgr) +// NewRedisBuilder creates a new Redis builder | 创建Redis构建器 +func NewRedisBuilder() *redis.Builder { + return redis.NewBuilder() +} + +// ============ Logger Constructors | 日志构造函数 ============ + +// NewSlogLogger creates a new slog logger with config | 使用配置创建标准日志器 +func NewSlogLogger(cfg *slog.LoggerConfig) (*slog.Logger, error) { + return slog.NewLoggerWithConfig(cfg) } -// GetManager gets the global Manager | 获取全局Manager -func GetManager() *Manager { - return stputil.GetManager() +// NewNopLogger creates a new no-op logger | 创建空日志器 +func NewNopLogger() *nop.NopLogger { + return nop.NewNopLogger() } +// ============ Generator Constructors | 生成器构造函数 ============ + +// NewTokenGenerator creates a new token generator | 创建Token生成器 +func NewTokenGenerator(timeout int64, tokenStyle adapter.TokenStyle, jwtSecretKey string) *sgenerator.Generator { + return sgenerator.NewGenerator(timeout, tokenStyle, jwtSecretKey) +} + +// NewDefaultTokenGenerator creates a new default token generator | 创建默认Token生成器 +func NewDefaultTokenGenerator() *sgenerator.Generator { + return sgenerator.NewDefaultGenerator() +} + +// ============ Pool Constructors | 协程池构造函数 ============ + +// NewRenewPoolManager creates a new renew pool manager-example with default config | 使用默认配置创建续期池管理器 +func NewRenewPoolManager() *ants.RenewPoolManager { + return ants.NewRenewPoolManagerWithDefaultConfig() +} + +// NewRenewPoolManagerWithConfig creates a new renew pool manager-example with config | 使用配置创建续期池管理器 +func NewRenewPoolManagerWithConfig(cfg *ants.RenewPoolConfig) (*ants.RenewPoolManager, error) { + return ants.NewRenewPoolManagerWithConfig(cfg) +} + +// ============ Token Style Constants | Token风格常量 ============ + +const ( + // TokenStyleUUID UUID style | UUID风格 + TokenStyleUUID = adapter.TokenStyleUUID + // TokenStyleSimple Simple random string | 简单随机字符串 + TokenStyleSimple = adapter.TokenStyleSimple + // TokenStyleRandom32 32-bit random string | 32位随机字符串 + TokenStyleRandom32 = adapter.TokenStyleRandom32 + // TokenStyleRandom64 64-bit random string | 64位随机字符串 + TokenStyleRandom64 = adapter.TokenStyleRandom64 + // TokenStyleRandom128 128-bit random string | 128位随机字符串 + TokenStyleRandom128 = adapter.TokenStyleRandom128 + // TokenStyleJWT JWT style | JWT风格 + TokenStyleJWT = adapter.TokenStyleJWT + // TokenStyleHash SHA256 hash-based style | SHA256哈希风格 + TokenStyleHash = adapter.TokenStyleHash + // TokenStyleTimestamp Timestamp-based style | 时间戳风格 + TokenStyleTimestamp = adapter.TokenStyleTimestamp + // TokenStyleTik Short ID style (like TikTok) | Tik风格短ID + TokenStyleTik = adapter.TokenStyleTik +) + +// ============ Log Level Constants | 日志级别常量 ============ + +const ( + // LogLevelDebug Debug level | 调试级别 + LogLevelDebug = adapter.LogLevelDebug + // LogLevelInfo Info level | 信息级别 + LogLevelInfo = adapter.LogLevelInfo + // LogLevelWarn Warn level | 警告级别 + LogLevelWarn = adapter.LogLevelWarn + // LogLevelError Error level | 错误级别 + LogLevelError = adapter.LogLevelError +) + // ============ Authentication | 登录认证 ============ // Login performs user login | 用户登录 -func Login(loginID interface{}, device ...string) (string, error) { - return stputil.Login(loginID, device...) +func Login(ctx context.Context, loginID interface{}, deviceOrAutoType ...string) (string, error) { + return stputil.Login(ctx, loginID, deviceOrAutoType...) } // LoginByToken performs login with specified token | 使用指定Token登录 -func LoginByToken(loginID interface{}, tokenValue string, device ...string) error { - return stputil.LoginByToken(loginID, tokenValue, device...) +func LoginByToken(ctx context.Context, tokenValue string, authType ...string) error { + return stputil.LoginByToken(ctx, tokenValue, authType...) } // Logout performs user logout | 用户登出 -func Logout(loginID interface{}, device ...string) error { - return stputil.Logout(loginID, device...) +func Logout(ctx context.Context, loginID interface{}, deviceOrAutoType ...string) error { + return stputil.Logout(ctx, loginID, deviceOrAutoType...) } // LogoutByToken performs logout by token | 根据Token登出 -func LogoutByToken(tokenValue string) error { - return stputil.LogoutByToken(tokenValue) +func LogoutByToken(ctx context.Context, tokenValue string, authType ...string) error { + return stputil.LogoutByToken(ctx, tokenValue, authType...) +} + +// Kickout kicks out a user session | 踢人下线 +func Kickout(ctx context.Context, loginID interface{}, deviceOrAutoType ...string) error { + return stputil.Kickout(ctx, loginID, deviceOrAutoType...) +} + +// KickoutByToken Kick user offline | 根据Token踢人下线 +func KickoutByToken(ctx context.Context, tokenValue string, authType ...string) error { + return stputil.KickoutByToken(ctx, tokenValue, authType...) +} + +// Replace user offline by login ID and device | 根据账号和设备顶人下线 +func Replace(ctx context.Context, loginID interface{}, deviceOrAutoType ...string) error { + return stputil.Replace(ctx, loginID, deviceOrAutoType...) +} + +// ReplaceByToken Replace user offline by token | 根据Token顶人下线 +func ReplaceByToken(ctx context.Context, tokenValue string, authType ...string) error { + return stputil.ReplaceByToken(ctx, tokenValue, authType...) } +// ============ Token Validation | Token验证 ============ + // IsLogin checks if the user is logged in | 检查用户是否已登录 -func IsLogin(tokenValue string) bool { - return stputil.IsLogin(tokenValue) +func IsLogin(ctx context.Context, tokenValue string, authType ...string) bool { + return stputil.IsLogin(ctx, tokenValue, authType...) } // CheckLogin checks login status (throws error if not logged in) | 检查登录状态(未登录抛出错误) -func CheckLogin(tokenValue string) error { - return stputil.CheckLogin(tokenValue) +func CheckLogin(ctx context.Context, tokenValue string, authType ...string) error { + return stputil.CheckLogin(ctx, tokenValue, authType...) +} + +// CheckLoginWithState checks the login status (returns error to determine the reason if not logged in) | 检查登录状态(未登录时根据错误确定原因) +func CheckLoginWithState(ctx context.Context, tokenValue string, authType ...string) (bool, error) { + return stputil.CheckLoginWithState(ctx, tokenValue, authType...) } // GetLoginID gets the login ID from token | 从Token获取登录ID -func GetLoginID(tokenValue string) (string, error) { - return stputil.GetLoginID(tokenValue) +func GetLoginID(ctx context.Context, tokenValue string, authType ...string) (string, error) { + return stputil.GetLoginID(ctx, tokenValue, authType...) } -// GetLoginIDNotCheck gets login ID without checking | 获取登录ID(不检查) -func GetLoginIDNotCheck(tokenValue string) (string, error) { - return stputil.GetLoginIDNotCheck(tokenValue) +// GetLoginIDNotCheck gets login ID without checking | 获取登录ID(不检查登录状态) +func GetLoginIDNotCheck(ctx context.Context, tokenValue string, authType ...string) (string, error) { + return stputil.GetLoginIDNotCheck(ctx, tokenValue, authType...) } // GetTokenValue gets the token value for a login ID | 获取登录ID对应的Token值 -func GetTokenValue(loginID interface{}, device ...string) (string, error) { - return stputil.GetTokenValue(loginID, device...) +func GetTokenValue(ctx context.Context, loginID interface{}, deviceOrAutoType ...string) (string, error) { + return stputil.GetTokenValue(ctx, loginID, deviceOrAutoType...) } // GetTokenInfo gets token information | 获取Token信息 -func GetTokenInfo(tokenValue string) (*TokenInfo, error) { - return stputil.GetTokenInfo(tokenValue) +func GetTokenInfo(ctx context.Context, tokenValue string, authType ...string) (*manager.TokenInfo, error) { + return stputil.GetTokenInfo(ctx, tokenValue, authType...) } -// ============ Kickout | 踢人下线 ============ +// ============ Account Disable | 账号封禁 ============ -// Kickout kicks out a user session | 踢人下线 -func Kickout(loginID interface{}, device ...string) error { - return stputil.Kickout(loginID, device...) +// Disable disables an account for specified duration | 封禁账号(指定时长) +func Disable(ctx context.Context, loginID interface{}, duration time.Duration, authType ...string) error { + return stputil.Disable(ctx, loginID, duration, authType...) } -// ============ Account Disable | 账号封禁 ============ +// DisableByToken disables the account associated with the given token for a duration | 根据指定 Token 封禁其对应的账号 +func DisableByToken(ctx context.Context, tokenValue string, duration time.Duration, authType ...string) error { + return stputil.DisableByToken(ctx, tokenValue, duration, authType...) +} -// Disable disables an account for specified duration | 封禁账号(指定时长) -func Disable(loginID interface{}, duration time.Duration) error { - return stputil.Disable(loginID, duration) +// Untie re-enables a disabled account | 解封账号 +func Untie(ctx context.Context, loginID interface{}, authType ...string) error { + return stputil.Untie(ctx, loginID, authType...) +} + +// UntieByToken re-enables a disabled account by token | 根据Token解封账号 +func UntieByToken(ctx context.Context, tokenValue string, authType ...string) error { + return stputil.UntieByToken(ctx, tokenValue, authType...) } // IsDisable checks if an account is disabled | 检查账号是否被封禁 -func IsDisable(loginID interface{}) bool { - return stputil.IsDisable(loginID) +func IsDisable(ctx context.Context, loginID interface{}, authType ...string) bool { + return stputil.IsDisable(ctx, loginID, authType...) } -// CheckDisable checks if account is disabled (throws error if disabled) | 检查账号是否被封禁(被封禁则抛出错误) -func CheckDisableByToken(tokenValue string) error { - return stputil.CheckDisable(tokenValue) +// IsDisableByToken checks if an account is disabled by token | 根据Token检查账号是否被封禁 +func IsDisableByToken(ctx context.Context, tokenValue string, authType ...string) bool { + return stputil.IsDisableByToken(ctx, tokenValue, authType...) } -// GetDisableTime gets remaining disabled time | 获取账号剩余封禁时间 -func GetDisableTime(loginID interface{}) (int64, error) { - return stputil.GetDisableTime(loginID) +// GetDisableTime gets remaining disable time in seconds | 获取剩余封禁时间(秒) +func GetDisableTime(ctx context.Context, loginID interface{}, authType ...string) (int64, error) { + return stputil.GetDisableTime(ctx, loginID, authType...) } -// Untie unties/unlocks an account | 解除账号封禁 -func Untie(loginID interface{}) error { - return stputil.Untie(loginID) +// GetDisableTimeByToken gets remaining disable time by token | 根据Token获取剩余封禁时间(秒) +func GetDisableTimeByToken(ctx context.Context, tokenValue string, authType ...string) (int64, error) { + return stputil.GetDisableTimeByToken(ctx, tokenValue, authType...) } -// ============ Permission Check | 权限验证 ============ +// CheckDisableWithInfo gets disable info | 获取封禁信息 +func CheckDisableWithInfo(ctx context.Context, loginID interface{}, authType ...string) (*manager.DisableInfo, error) { + return stputil.CheckDisableWithInfo(ctx, loginID, authType...) +} -// CheckPermission checks if the account has specified permission | 检查账号是否拥有指定权限 -func CheckPermissionByToken(tokenValue string, permission string) error { - return stputil.CheckPermission(tokenValue, permission) +// CheckDisableWithInfoByToken gets disable info by token | 根据Token获取封禁信息 +func CheckDisableWithInfoByToken(ctx context.Context, tokenValue string, authType ...string) (*manager.DisableInfo, error) { + return stputil.CheckDisableWithInfoByToken(ctx, tokenValue, authType...) } -// HasPermission checks if the account has specified permission (returns bool) | 检查账号是否拥有指定权限(返回布尔值) -func HasPermission(loginID interface{}, permission string) bool { - return stputil.HasPermission(loginID, permission) +// ============ Session Management | Session管理 ============ + +// GetSession gets session by login ID | 根据登录ID获取Session +func GetSession(ctx context.Context, loginID interface{}, authType ...string) (*session.Session, error) { + return stputil.GetSession(ctx, loginID, authType...) +} + +// GetSessionByToken gets session by token | 根据Token获取Session +func GetSessionByToken(ctx context.Context, tokenValue string, authType ...string) (*session.Session, error) { + return stputil.GetSessionByToken(ctx, tokenValue, authType...) } -// CheckPermissionAnd checks if the account has all specified permissions (AND logic) | 检查账号是否拥有所有指定权限(AND逻辑) -func CheckPermissionAndByToken(tokenValue string, permissions []string) error { - return stputil.CheckPermissionAnd(tokenValue, permissions) +// DeleteSession deletes a session | 删除Session +func DeleteSession(ctx context.Context, loginID interface{}, authType ...string) error { + return stputil.DeleteSession(ctx, loginID, authType...) } -// CheckPermissionOr checks if the account has any of the specified permissions (OR logic) | 检查账号是否拥有指定权限中的任意一个(OR逻辑) -func CheckPermissionOrByToken(tokenValue string, permissions []string) error { - return stputil.CheckPermissionOr(tokenValue, permissions) +// DeleteSessionByToken Deletes session by token | 根据Token删除Session +func DeleteSessionByToken(ctx context.Context, tokenValue string, authType ...string) error { + return stputil.DeleteSessionByToken(ctx, tokenValue, authType...) } -// GetPermissionList gets the permission list for an account | 获取账号的权限列表 -func GetPermissionListByToken(tokenValue string) ([]string, error) { - return stputil.GetPermissionList(tokenValue) +// HasSession checks if session exists | 检查Session是否存在 +func HasSession(ctx context.Context, loginID interface{}, authType ...string) bool { + return stputil.HasSession(ctx, loginID, authType...) } -// ============ Role Check | 角色验证 ============ +// RenewSession renews session TTL | 续期Session +func RenewSession(ctx context.Context, loginID interface{}, ttl time.Duration, authType ...string) error { + return stputil.RenewSession(ctx, loginID, ttl, authType...) +} -// CheckRole checks if the account has specified role | 检查账号是否拥有指定角色 -func CheckRoleByToken(tokenValue string, role string) error { - return stputil.CheckRole(tokenValue, role) +// RenewSessionByToken renews session TTL by token | 根据Token续期Session +func RenewSessionByToken(ctx context.Context, tokenValue string, ttl time.Duration, authType ...string) error { + return stputil.RenewSessionByToken(ctx, tokenValue, ttl, authType...) } -// HasRole checks if the account has specified role (returns bool) | 检查账号是否拥有指定角色(返回布尔值) -func HasRole(loginID interface{}, role string) bool { - return stputil.HasRole(loginID, role) +// ============ Permission Verification | 权限验证 ============ + +// SetPermissions sets permissions for a login ID | 设置用户权限 +func SetPermissions(ctx context.Context, loginID interface{}, permissions []string, authType ...string) error { + return stputil.SetPermissions(ctx, loginID, permissions, authType...) } -// CheckRoleAnd checks if the account has all specified roles (AND logic) | 检查账号是否拥有所有指定角色(AND逻辑) -func CheckRoleAndByToken(tokenValue string, roles []string) error { - return stputil.CheckRoleAnd(tokenValue, roles) +// SetPermissionsByToken sets permissions by token | 根据 Token 设置对应账号的权限 +func SetPermissionsByToken(ctx context.Context, tokenValue string, permissions []string, authType ...string) error { + return stputil.SetPermissionsByToken(ctx, tokenValue, permissions, authType...) } -// CheckRoleOr checks if the account has any of the specified roles (OR logic) | 检查账号是否拥有指定角色中的任意一个(OR逻辑) -func CheckRoleOrByToken(tokenValue string, roles []string) error { - return stputil.CheckRoleOr(tokenValue, roles) +// RemovePermissions removes specified permissions for a login ID | 删除用户指定权限 +func RemovePermissions(ctx context.Context, loginID interface{}, permissions []string, authType ...string) error { + return stputil.RemovePermissions(ctx, loginID, permissions, authType...) } -// GetRoleList gets the role list for an account | 获取账号的角色列表 -func GetRoleListByToken(tokenValue string) ([]string, error) { - return stputil.GetRoleList(tokenValue) +// RemovePermissionsByToken removes specified permissions by token | 根据 Token 删除对应账号的指定权限 +func RemovePermissionsByToken(ctx context.Context, tokenValue string, permissions []string, authType ...string) error { + return stputil.RemovePermissionsByToken(ctx, tokenValue, permissions, authType...) } -// ============ Session Management | Session管理 ============ +// GetPermissions gets permission list | 获取权限列表 +func GetPermissions(ctx context.Context, loginID interface{}, authType ...string) ([]string, error) { + return stputil.GetPermissions(ctx, loginID, authType...) +} + +// GetPermissionsByToken gets permission list by token | 根据 Token 获取对应账号的权限列表 +func GetPermissionsByToken(ctx context.Context, tokenValue string, authType ...string) ([]string, error) { + return stputil.GetPermissionsByToken(ctx, tokenValue, authType...) +} + +// HasPermission checks if has specified permission | 检查是否拥有指定权限 +func HasPermission(ctx context.Context, loginID interface{}, permission string, authType ...string) bool { + return stputil.HasPermission(ctx, loginID, permission, authType...) +} + +// HasPermissionByToken checks if the token has the specified permission | 检查Token是否拥有指定权限 +func HasPermissionByToken(ctx context.Context, tokenValue string, permission string, authType ...string) bool { + return stputil.HasPermissionByToken(ctx, tokenValue, permission, authType...) +} + +// HasPermissionsAnd checks if has all permissions (AND logic) | 检查是否拥有所有权限(AND逻辑) +func HasPermissionsAnd(ctx context.Context, loginID interface{}, permissions []string, authType ...string) bool { + return stputil.HasPermissionsAnd(ctx, loginID, permissions, authType...) +} + +// HasPermissionsAndByToken checks if the token has all specified permissions | 检查Token是否拥有所有指定权限 +func HasPermissionsAndByToken(ctx context.Context, tokenValue string, permissions []string, authType ...string) bool { + return stputil.HasPermissionsAndByToken(ctx, tokenValue, permissions, authType...) +} + +// HasPermissionsOr checks if has any permission (OR logic) | 检查是否拥有任一权限(OR逻辑) +func HasPermissionsOr(ctx context.Context, loginID interface{}, permissions []string, authType ...string) bool { + return stputil.HasPermissionsOr(ctx, loginID, permissions, authType...) +} + +// HasPermissionsOrByToken checks if the token has any of the specified permissions | 检查Token是否拥有任一指定权限 +func HasPermissionsOrByToken(ctx context.Context, tokenValue string, permissions []string, authType ...string) bool { + return stputil.HasPermissionsOrByToken(ctx, tokenValue, permissions, authType...) +} + +// ============ Role Management | 角色管理 ============ + +// SetRoles sets roles for a login ID | 设置用户角色 +func SetRoles(ctx context.Context, loginID interface{}, roles []string, authType ...string) error { + return stputil.SetRoles(ctx, loginID, roles, authType...) +} + +// SetRolesByToken sets roles by token | 根据 Token 设置对应账号的角色 +func SetRolesByToken(ctx context.Context, tokenValue string, roles []string, authType ...string) error { + return stputil.SetRolesByToken(ctx, tokenValue, roles, authType...) +} + +// RemoveRoles removes specified roles for a login ID | 删除用户指定角色 +func RemoveRoles(ctx context.Context, loginID interface{}, roles []string, authType ...string) error { + return stputil.RemoveRoles(ctx, loginID, roles, authType...) +} + +// RemoveRolesByToken removes specified roles by token | 根据 Token 删除对应账号的指定角色 +func RemoveRolesByToken(ctx context.Context, tokenValue string, roles []string, authType ...string) error { + return stputil.RemoveRolesByToken(ctx, tokenValue, roles, authType...) +} + +// GetRoles gets role list | 获取角色列表 +func GetRoles(ctx context.Context, loginID interface{}, authType ...string) ([]string, error) { + return stputil.GetRoles(ctx, loginID, authType...) +} + +// GetRolesByToken gets role list by token | 根据 Token 获取对应账号的角色列表 +func GetRolesByToken(ctx context.Context, tokenValue string, authType ...string) ([]string, error) { + return stputil.GetRolesByToken(ctx, tokenValue, authType...) +} + +// HasRole checks if has specified role | 检查是否拥有指定角色 +func HasRole(ctx context.Context, loginID interface{}, role string, authType ...string) bool { + return stputil.HasRole(ctx, loginID, role, authType...) +} + +// HasRoleByToken checks if the token has the specified role | 检查 Token 是否拥有指定角色 +func HasRoleByToken(ctx context.Context, tokenValue string, role string, authType ...string) bool { + return stputil.HasRoleByToken(ctx, tokenValue, role, authType...) +} + +// HasRolesAnd checks if has all roles (AND logic) | 检查是否拥有所有角色(AND逻辑) +func HasRolesAnd(ctx context.Context, loginID interface{}, roles []string, authType ...string) bool { + return stputil.HasRolesAnd(ctx, loginID, roles, authType...) +} -// GetSession gets the session for a login ID | 获取登录ID的Session -func GetSession(loginID interface{}) (*Session, error) { - return stputil.GetSession(loginID) +// HasRolesAndByToken checks if the token has all specified roles | 检查 Token 是否拥有所有指定角色 +func HasRolesAndByToken(ctx context.Context, tokenValue string, roles []string, authType ...string) bool { + return stputil.HasRolesAndByToken(ctx, tokenValue, roles, authType...) } -// GetSessionByToken gets the session by token | 根据Token获取Session -func GetSessionByToken(tokenValue string) (*Session, error) { - return stputil.GetSessionByToken(tokenValue) +// HasRolesOr checks if has any role (OR logic) | 检查是否拥有任一角色(OR逻辑) +func HasRolesOr(ctx context.Context, loginID interface{}, roles []string, authType ...string) bool { + return stputil.HasRolesOr(ctx, loginID, roles, authType...) } -// GetTokenSession gets the token session | 获取Token的Session -func GetTokenSession(tokenValue string) (*Session, error) { - return stputil.GetTokenSession(tokenValue) +// HasRolesOrByToken checks if the token has any of the specified roles | 检查 Token 是否拥有任一指定角色 +func HasRolesOrByToken(ctx context.Context, tokenValue string, roles []string, authType ...string) bool { + return stputil.HasRolesOrByToken(ctx, tokenValue, roles, authType...) } -// ============ Token Renewal | Token续期 ============ +// ============ Token Tag | Token标签 ============ + +// SetTokenTag sets token tag | 设置Token标签 +func SetTokenTag(ctx context.Context, tokenValue, tag string, authType ...string) error { + return stputil.SetTokenTag(ctx, tokenValue, tag, authType...) +} + +// GetTokenTag gets token tag | 获取Token标签 +func GetTokenTag(ctx context.Context, tokenValue string, authType ...string) (string, error) { + return stputil.GetTokenTag(ctx, tokenValue, authType...) +} -// RenewTimeout renews token timeout | 续期Token超时时间 +// ============ Session Query | 会话查询 ============ + +// GetTokenValueListByLoginID gets all tokens for a login ID | 获取指定账号的所有Token +func GetTokenValueListByLoginID(ctx context.Context, loginID interface{}, authType ...string) ([]string, error) { + return stputil.GetTokenValueListByLoginID(ctx, loginID, authType...) +} + +// GetSessionCountByLoginID gets session count for a login ID | 获取指定账号的Session数量 +func GetSessionCountByLoginID(ctx context.Context, loginID interface{}, authType ...string) (int, error) { + return stputil.GetSessionCountByLoginID(ctx, loginID, authType...) +} // ============ Security Features | 安全特性 ============ -// GenerateNonce generates a new nonce token | 生成新的Nonce令牌 -func GenerateNonce() (string, error) { - return stputil.GenerateNonce() +// Generate Generates a one-time nonce | 生成一次性随机数 +func Generate(ctx context.Context, authType ...string) (string, error) { + return stputil.Generate(ctx, authType...) +} + +// Verify Verifies a nonce | 验证随机数 +func Verify(ctx context.Context, nonce string, authType ...string) bool { + return stputil.Verify(ctx, nonce, authType...) } -// VerifyNonce verifies a nonce token | 验证Nonce令牌 -func VerifyNonce(nonce string) bool { - return stputil.VerifyNonce(nonce) +// VerifyAndConsume Verifies and consumes nonce, returns error if invalid | 验证并消费nonce,无效时返回错误 +func VerifyAndConsume(ctx context.Context, nonce string, authType ...string) error { + return stputil.VerifyAndConsume(ctx, nonce, authType...) } -// LoginWithRefreshToken performs login and returns both access token and refresh token | 登录并返回访问令牌和刷新令牌 -func LoginWithRefreshToken(loginID interface{}, device ...string) (*RefreshTokenInfo, error) { - return stputil.LoginWithRefreshToken(loginID, device...) +// IsValidNonce Checks if nonce is valid without consuming it | 检查nonce是否有效(不消费) +func IsValidNonce(ctx context.Context, nonce string, authType ...string) bool { + return stputil.IsValidNonce(ctx, nonce, authType...) } -// RefreshAccessToken refreshes the access token using a refresh token | 使用刷新令牌刷新访问令牌 -func RefreshAccessToken(refreshToken string) (*RefreshTokenInfo, error) { - return stputil.RefreshAccessToken(refreshToken) +// GenerateTokenPair Create access + refresh token | 生成访问令牌和刷新令牌 +func GenerateTokenPair(ctx context.Context, loginID interface{}, deviceOrAutoType ...string) (*security.RefreshTokenInfo, error) { + return stputil.GenerateTokenPair(ctx, loginID, deviceOrAutoType...) } -// RevokeRefreshToken revokes a refresh token | 撤销刷新令牌 -func RevokeRefreshToken(refreshToken string) error { - return stputil.RevokeRefreshToken(refreshToken) +// VerifyAccessToken verifies access token validity | 验证访问令牌是否有效 +func VerifyAccessToken(ctx context.Context, accessToken string, authType ...string) bool { + return stputil.VerifyAccessToken(ctx, accessToken, authType...) +} + +// VerifyAccessTokenAndGetInfo verifies access token and returns token info | 验证访问令牌并返回Token信息 +func VerifyAccessTokenAndGetInfo(ctx context.Context, accessToken string, authType ...string) (*security.AccessTokenInfo, bool) { + return stputil.VerifyAccessTokenAndGetInfo(ctx, accessToken, authType...) +} + +// GetRefreshTokenInfo gets refresh token information | 获取刷新令牌信息 +func GetRefreshTokenInfo(ctx context.Context, refreshToken string, authType ...string) (*security.RefreshTokenInfo, error) { + return stputil.GetRefreshTokenInfo(ctx, refreshToken, authType...) +} + +// RefreshAccessToken refreshes access token using refresh token | 使用刷新令牌刷新访问令牌 +func RefreshAccessToken(ctx context.Context, refreshToken string, authType ...string) (*security.RefreshTokenInfo, error) { + return stputil.RefreshAccessToken(ctx, refreshToken, authType...) +} + +// RevokeRefreshToken Revokes refresh token | 撤销刷新令牌 +func RevokeRefreshToken(ctx context.Context, refreshToken string, authType ...string) error { + return stputil.RevokeRefreshToken(ctx, refreshToken, authType...) +} + +// IsValid checks whether token is valid | 检查Token是否有效 +func IsValid(ctx context.Context, refreshToken string, authType ...string) bool { + return stputil.IsValid(ctx, refreshToken, authType...) +} + +// ============ OAuth2 Features | OAuth2 功能 ============ + +// RegisterClient Registers an OAuth2 client | 注册OAuth2客户端 +func RegisterClient(ctx context.Context, client *oauth2.Client, authType ...string) error { + return stputil.RegisterClient(ctx, client, authType...) } -// GetOAuth2Server gets the OAuth2 server instance | 获取OAuth2服务器实例 -func GetOAuth2Server() *OAuth2Server { - return stputil.GetOAuth2Server() +// UnregisterClient unregisters an OAuth2 client | 注销OAuth2客户端 +func UnregisterClient(ctx context.Context, clientID string, authType ...string) error { + return stputil.UnregisterClient(ctx, clientID, authType...) } -// Version Sa-Token-Go version | Sa-Token-Go版本 -const Version = core.Version +// GetClient gets OAuth2 client information | 获取OAuth2客户端信息 +func GetClient(ctx context.Context, clientID string, authType ...string) (*oauth2.Client, error) { + return stputil.GetClient(ctx, clientID, authType...) +} + +// GenerateAuthorizationCode creates an authorization code | 创建授权码 +func GenerateAuthorizationCode(ctx context.Context, clientID, loginID, redirectURI string, scope []string, authType ...string) (*oauth2.AuthorizationCode, error) { + return stputil.GenerateAuthorizationCode(ctx, clientID, loginID, redirectURI, scope, authType...) +} + +// ExchangeCodeForToken exchanges authorization code for token | 使用授权码换取令牌 +func ExchangeCodeForToken(ctx context.Context, code, clientID, clientSecret, redirectURI string, authType ...string) (*oauth2.AccessToken, error) { + return stputil.ExchangeCodeForToken(ctx, code, clientID, clientSecret, redirectURI, authType...) +} + +// ValidateAccessToken verifies OAuth2 access token | 验证OAuth2访问令牌 +func ValidateAccessToken(ctx context.Context, accessToken string, authType ...string) bool { + return stputil.ValidateAccessToken(ctx, accessToken, authType...) +} + +// ValidateAccessTokenAndGetInfo verifies OAuth2 access token and get info | 验证OAuth2访问令牌并获取信息 +func ValidateAccessTokenAndGetInfo(ctx context.Context, accessToken string, authType ...string) (*oauth2.AccessToken, error) { + return stputil.ValidateAccessTokenAndGetInfo(ctx, accessToken, authType...) +} + +// OAuth2RefreshAccessToken Refreshes access token using refresh token | 使用刷新令牌刷新访问令牌(OAuth2) +func OAuth2RefreshAccessToken(ctx context.Context, clientID, refreshToken, clientSecret string, authType ...string) (*oauth2.AccessToken, error) { + return stputil.OAuth2RefreshAccessToken(ctx, clientID, refreshToken, clientSecret, authType...) +} + +// RevokeToken Revokes access token and its refresh token | 撤销访问令牌及其刷新令牌 +func RevokeToken(ctx context.Context, accessToken string, authType ...string) error { + return stputil.RevokeToken(ctx, accessToken, authType...) +} + +// OAuth2Token Unified token endpoint that dispatches to appropriate handler based on grant type | 统一的令牌端点 +func OAuth2Token(ctx context.Context, req *oauth2.TokenRequest, validateUser oauth2.UserValidator, authType ...string) (*oauth2.AccessToken, error) { + return stputil.OAuth2Token(ctx, req, validateUser, authType...) +} + +// OAuth2ClientCredentialsToken Gets access token using client credentials grant | 使用客户端凭证模式获取访问令牌 +func OAuth2ClientCredentialsToken(ctx context.Context, clientID, clientSecret string, scopes []string, authType ...string) (*oauth2.AccessToken, error) { + return stputil.OAuth2ClientCredentialsToken(ctx, clientID, clientSecret, scopes, authType...) +} + +// OAuth2PasswordGrantToken Gets access token using resource owner password credentials grant | 使用密码模式获取访问令牌 +func OAuth2PasswordGrantToken(ctx context.Context, clientID, clientSecret, username, password string, scopes []string, validateUser oauth2.UserValidator, authType ...string) (*oauth2.AccessToken, error) { + return stputil.OAuth2PasswordGrantToken(ctx, clientID, clientSecret, username, password, scopes, validateUser, authType...) +} + +// ============ OAuth2 Grant Type Constants | OAuth2授权类型常量 ============ + +const ( + // GrantTypeAuthorizationCode Authorization code grant type | 授权码模式 + GrantTypeAuthorizationCode = oauth2.GrantTypeAuthorizationCode + // GrantTypeClientCredentials Client credentials grant type | 客户端凭证模式 + GrantTypeClientCredentials = oauth2.GrantTypeClientCredentials + // GrantTypePassword Password grant type | 密码模式 + GrantTypePassword = oauth2.GrantTypePassword + // GrantTypeRefreshToken Refresh token grant type | 刷新令牌模式 + GrantTypeRefreshToken = oauth2.GrantTypeRefreshToken +) + +// ============ Public Getters | 公共获取器 ============ + +// GetConfig returns the manager-example configuration | 获取 Manager 当前使用的配置 +func GetConfig(ctx context.Context, authType ...string) *config.Config { + return stputil.GetConfig(ctx, authType...) +} + +// GetStorage returns the storage adapter | 获取 Manager 使用的存储适配器 +func GetStorage(ctx context.Context, authType ...string) adapter.Storage { + return stputil.GetStorage(ctx, authType...) +} + +// GetCodec returns the codec (serializer) | 获取 Manager 使用的编解码器 +func GetCodec(ctx context.Context, authType ...string) adapter.Codec { + return stputil.GetCodec(ctx, authType...) +} + +// GetLog returns the logger adapter | 获取 Manager 使用的日志适配器 +func GetLog(ctx context.Context, authType ...string) adapter.Log { + return stputil.GetLog(ctx, authType...) +} + +// GetPool returns the goroutine pool | 获取 Manager 使用的协程池 +func GetPool(ctx context.Context, authType ...string) adapter.Pool { + return stputil.GetPool(ctx, authType...) +} + +// GetGenerator returns the token generator | 获取 Token 生成器 +func GetGenerator(ctx context.Context, authType ...string) adapter.Generator { + return stputil.GetGenerator(ctx, authType...) +} + +// GetNonceManager returns the nonce manager-example | 获取随机串管理器 +func GetNonceManager(ctx context.Context, authType ...string) *security.NonceManager { + return stputil.GetNonceManager(ctx, authType...) +} + +// GetRefreshManager returns the refresh token manager-example | 获取刷新令牌管理器 +func GetRefreshManager(ctx context.Context, authType ...string) *security.RefreshTokenManager { + return stputil.GetRefreshManager(ctx, authType...) +} + +// GetEventManager returns the event manager-example | 获取事件管理器 +func GetEventManager(ctx context.Context, authType ...string) *listener.Manager { + return stputil.GetEventManager(ctx, authType...) +} + +// GetOAuth2Server Gets OAuth2 server instance | 获取OAuth2服务器实例 +func GetOAuth2Server(ctx context.Context, authType ...string) *oauth2.OAuth2Server { + return stputil.GetOAuth2Server(ctx, authType...) +} + +// ============ Event Management | 事件管理 ============ + +// RegisterFunc registers a function as an event listener | 注册函数作为事件监听器 +func RegisterFunc(event listener.Event, fn func(*listener.EventData), authType ...string) { + stputil.RegisterFunc(event, fn, authType...) +} + +// Register registers an event listener | 注册事件监听器 +func Register(event listener.Event, l listener.Listener, authType ...string) string { + return stputil.Register(event, l, authType...) +} + +// RegisterWithConfig registers an event listener with config | 注册带配置的事件监听器 +func RegisterWithConfig(event listener.Event, l listener.Listener, cfg listener.ListenerConfig, authType ...string) string { + return stputil.RegisterWithConfig(event, l, cfg, authType...) +} + +// Unregister removes an event listener by ID | 根据ID移除事件监听器 +func Unregister(id string, authType ...string) bool { + return stputil.Unregister(id, authType...) +} + +// TriggerEvent manually triggers an event | 手动触发事件 +func TriggerEvent(data *listener.EventData, authType ...string) { + stputil.TriggerEvent(data, authType...) +} + +// TriggerEventAsync triggers an event asynchronously and returns immediately | 异步触发事件并立即返回 +func TriggerEventAsync(data *listener.EventData, authType ...string) { + stputil.TriggerEventAsync(data, authType...) +} + +// TriggerEventSync triggers an event synchronously and waits for all listeners | 同步触发事件并等待所有监听器完成 +func TriggerEventSync(data *listener.EventData, authType ...string) { + stputil.TriggerEventSync(data, authType...) +} + +// WaitEvents waits for all async event listeners to complete | 等待所有异步事件监听器完成 +func WaitEvents(authType ...string) { + stputil.WaitEvents(authType...) +} + +// ClearEventListeners removes all listeners for a specific event | 清除指定事件的所有监听器 +func ClearEventListeners(event listener.Event, authType ...string) { + stputil.ClearEventListeners(event, authType...) +} + +// ClearAllEventListeners removes all listeners | 清除所有事件监听器 +func ClearAllEventListeners(authType ...string) { + stputil.ClearAllEventListeners(authType...) +} + +// CountEventListeners returns the number of listeners for a specific event | 获取指定事件监听器数量 +func CountEventListeners(event listener.Event, authType ...string) int { + return stputil.CountEventListeners(event, authType...) +} + +// CountAllListeners returns the total number of registered listeners | 获取已注册监听器总数 +func CountAllListeners(authType ...string) int { + return stputil.CountAllListeners(authType...) +} + +// GetEventListenerIDs returns all listener IDs for a specific event | 获取指定事件的所有监听器ID +func GetEventListenerIDs(event listener.Event, authType ...string) []string { + return stputil.GetEventListenerIDs(event, authType...) +} + +// GetAllRegisteredEvents returns all events that have registered listeners | 获取所有已注册事件 +func GetAllRegisteredEvents(authType ...string) []listener.Event { + return stputil.GetAllRegisteredEvents(authType...) +} + +// HasEventListeners checks if there are any listeners for a specific event | 检查指定事件是否有监听器 +func HasEventListeners(event listener.Event, authType ...string) bool { + return stputil.HasEventListeners(event, authType...) +} diff --git a/integrations/chi/go.mod b/integrations/chi/go.mod index 7ef3a25..02929d0 100644 --- a/integrations/chi/go.mod +++ b/integrations/chi/go.mod @@ -1,20 +1,15 @@ module github.com/click33/sa-token-go/integrations/chi -go 1.23.0 +go 1.25.0 require ( - github.com/click33/sa-token-go/core v0.1.5 - github.com/click33/sa-token-go/stputil v0.1.5 + github.com/click33/sa-token-go/core v0.1.7 + github.com/click33/sa-token-go/stputil v0.1.7 ) require ( - github.com/golang-jwt/jwt/v5 v5.2.2 // indirect + github.com/golang-jwt/jwt/v5 v5.3.0 // indirect github.com/google/uuid v1.6.0 // indirect github.com/panjf2000/ants/v2 v2.11.3 // indirect - golang.org/x/sync v0.16.0 // indirect -) - -replace ( - github.com/click33/sa-token-go/core => ../../core - github.com/click33/sa-token-go/stputil => ../../stputil + golang.org/x/sync v0.19.0 // indirect ) diff --git a/integrations/chi/go.sum b/integrations/chi/go.sum index dda2c2d..24defe4 100644 --- a/integrations/chi/go.sum +++ b/integrations/chi/go.sum @@ -1,8 +1,10 @@ +github.com/click33/sa-token-go/core v0.1.6 h1:ELOe0qSH1b3LRsQD3DIBg0e1VgYANKFg5H7z57Lkt/8= +github.com/click33/sa-token-go/stputil v0.1.6 h1:S+V64jQzppE9c1wXcmHppCRlrSsU2iTfvdPGlMbs2WI= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= -github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= +github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/panjf2000/ants/v2 v2.11.3 h1:AfI0ngBoXJmYOpDh9m516vjqoUu2sLrIVgppI9TZVpg= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= -golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/integrations/chi/middleware.go b/integrations/chi/middleware.go new file mode 100644 index 0000000..af7caf0 --- /dev/null +++ b/integrations/chi/middleware.go @@ -0,0 +1,342 @@ +package chi + +import ( + "encoding/json" + "errors" + "net/http" + + "github.com/click33/sa-token-go/core" + "github.com/click33/sa-token-go/core/manager" + + saContext "github.com/click33/sa-token-go/core/context" + "github.com/click33/sa-token-go/stputil" +) + +// LogicType permission/role logic type | 权限/角色判断逻辑 +type LogicType string + +const ( + SaTokenCtxKey = "saCtx" + + LogicOr LogicType = "OR" // Logical OR | 任一满足 + LogicAnd LogicType = "AND" // Logical AND | 全部满足 +) + +type AuthOption func(*AuthOptions) + +type AuthOptions struct { + AuthType string + LogicType LogicType + FailFunc func(w http.ResponseWriter, r *http.Request, err error) +} + +func defaultAuthOptions() *AuthOptions { + return &AuthOptions{LogicType: LogicAnd} // 默认 AND +} + +// WithAuthType sets auth type | 设置认证类型 +func WithAuthType(authType string) AuthOption { + return func(o *AuthOptions) { + o.AuthType = authType + } +} + +// WithLogicType sets LogicType option | 设置逻辑类型 +func WithLogicType(logicType LogicType) AuthOption { + return func(o *AuthOptions) { + o.LogicType = logicType + } +} + +// WithFailFunc sets auth failure callback | 设置认证失败回调 +func WithFailFunc(fn func(w http.ResponseWriter, r *http.Request, err error)) AuthOption { + return func(o *AuthOptions) { + o.FailFunc = fn + } +} + +// ========== Middlewares ========== + +// AuthMiddleware authentication middleware | 认证中间件 +func AuthMiddleware(opts ...AuthOption) func(http.Handler) http.Handler { + options := defaultAuthOptions() + for _, opt := range opts { + opt(options) + } + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mgr, err := stputil.GetManager(options.AuthType) + if err != nil { + if options.FailFunc != nil { + options.FailFunc(w, r, err) + } else { + writeErrorResponse(w, err) + } + return + } + + // 获取 token | Get token + ctx := NewChiContext(w, r) + saCtx := getSaContext(ctx.(*ChiContext), r, mgr) + tokenValue := saCtx.GetTokenValue() + + // 检查登录 | Check login + err = mgr.CheckLogin(r.Context(), tokenValue) + if err != nil { + if options.FailFunc != nil { + options.FailFunc(w, r, err) + } else { + writeErrorResponse(w, err) + } + return + } + + next.ServeHTTP(w, r) + }) + } +} + +// AuthWithStateMiddleware with state authentication middleware | 带状态返回的认证中间件 +func AuthWithStateMiddleware(opts ...AuthOption) func(http.Handler) http.Handler { + options := defaultAuthOptions() + for _, opt := range opts { + opt(options) + } + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // 获取 Manager | Get Manager + mgr, err := stputil.GetManager(options.AuthType) + if err != nil { + if options.FailFunc != nil { + options.FailFunc(w, r, err) + } else { + writeErrorResponse(w, err) + } + return + } + + // 构建 Sa-Token 上下文 | Build Sa-Token context + ctx := NewChiContext(w, r) + saCtx := getSaContext(ctx.(*ChiContext), r, mgr) + tokenValue := saCtx.GetTokenValue() + + // 检查登录并返回状态 | Check login with state + _, err = mgr.CheckLoginWithState(r.Context(), tokenValue) + + if err != nil { + // 用户自定义回调优先 + if options.FailFunc != nil { + options.FailFunc(w, r, err) + } else { + writeErrorResponse(w, err) + } + + return + } + + next.ServeHTTP(w, r) + }) + } +} + +// PermissionMiddleware permission check middleware | 权限校验中间件 +func PermissionMiddleware( + permissions []string, + opts ...AuthOption, +) func(http.Handler) http.Handler { + + options := defaultAuthOptions() + for _, opt := range opts { + opt(options) + } + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // No permission required | 无需权限直接放行 + if len(permissions) == 0 { + next.ServeHTTP(w, r) + return + } + + // Get Manager | 获取 Manager + mgr, err := stputil.GetManager(options.AuthType) + if err != nil { + if options.FailFunc != nil { + options.FailFunc(w, r, err) + } else { + writeErrorResponse(w, err) + } + return + } + + // 构建 Sa-Token 上下文 | Build Sa-Token context + ctx := NewChiContext(w, r) + saCtx := getSaContext(ctx.(*ChiContext), r, mgr) + tokenValue := saCtx.GetTokenValue() + reqCtx := r.Context() + + // Permission check | 权限校验 + var ok bool + if options.LogicType == LogicAnd { + ok = mgr.HasPermissionsAndByToken(reqCtx, tokenValue, permissions) + } else { + ok = mgr.HasPermissionsOrByToken(reqCtx, tokenValue, permissions) + } + + if !ok { + if options.FailFunc != nil { + options.FailFunc(w, r, core.ErrPermissionDenied) + } else { + writeErrorResponse(w, core.ErrPermissionDenied) + } + return + } + + next.ServeHTTP(w, r) + }) + } +} + +// RoleMiddleware role check middleware | 角色校验中间件 +func RoleMiddleware( + roles []string, + opts ...AuthOption, +) func(http.Handler) http.Handler { + + options := defaultAuthOptions() + for _, opt := range opts { + opt(options) + } + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // No role required | 无需角色直接放行 + if len(roles) == 0 { + next.ServeHTTP(w, r) + return + } + + // Get Manager | 获取 Manager + mgr, err := stputil.GetManager(options.AuthType) + if err != nil { + if options.FailFunc != nil { + options.FailFunc(w, r, err) + } else { + writeErrorResponse(w, err) + } + return + } + + // 构建 Sa-Token 上下文 | Build Sa-Token context + ctx := NewChiContext(w, r) + saCtx := getSaContext(ctx.(*ChiContext), r, mgr) + tokenValue := saCtx.GetTokenValue() + reqCtx := r.Context() + + // Role check | 角色校验 + var ok bool + if options.LogicType == LogicAnd { + ok = mgr.HasRolesAndByToken(reqCtx, tokenValue, roles) + } else { + ok = mgr.HasRolesOrByToken(reqCtx, tokenValue, roles) + } + + if !ok { + if options.FailFunc != nil { + options.FailFunc(w, r, core.ErrRoleDenied) + } else { + writeErrorResponse(w, core.ErrRoleDenied) + } + return + } + + next.ServeHTTP(w, r) + }) + } +} + +// GetSaTokenContext gets Sa-Token context from request | 获取 Sa-Token 上下文 +func GetSaTokenContext(r *http.Request) (*saContext.SaTokenContext, bool) { + v := r.Context().Value(SaTokenCtxKey) + if v == nil { + return nil, false + } + + ctx, ok := v.(*saContext.SaTokenContext) + return ctx, ok +} + +func getSaContext(chiCtx *ChiContext, r *http.Request, mgr *manager.Manager) *saContext.SaTokenContext { + // Try get from context | 尝试从 ctx 取值 + if v := r.Context().Value(SaTokenCtxKey); v != nil { + if saCtx, ok := v.(*saContext.SaTokenContext); ok { + return saCtx + } + } + + // Create new context | 创建并缓存 SaTokenContext + saCtx := saContext.NewContext(chiCtx, mgr) + chiCtx.Set(SaTokenCtxKey, saCtx) + + return saCtx +} + +// ============ Error Handling Helpers | 错误处理辅助函数 ============ + +// writeErrorResponse writes a standardized error response | 写入标准化的错误响应 +func writeErrorResponse(w http.ResponseWriter, err error) { + var saErr *core.SaTokenError + var code int + var message string + var httpStatus int + + // Check if it's a SaTokenError | 检查是否为SaTokenError + if errors.As(err, &saErr) { + code = saErr.Code + message = saErr.Message + httpStatus = getHTTPStatusFromCode(code) + } else { + // Handle standard errors | 处理标准错误 + code = core.CodeServerError + message = err.Error() + httpStatus = http.StatusInternalServerError + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(httpStatus) + json.NewEncoder(w).Encode(map[string]interface{}{ + "code": code, + "message": message, + "data": err.Error(), + }) +} + +// writeSuccessResponse writes a standardized success response | 写入标准化的成功响应 +func writeSuccessResponse(w http.ResponseWriter, data interface{}) { + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(map[string]interface{}{ + "code": core.CodeSuccess, + "message": "success", + "data": data, + }) +} + +// getHTTPStatusFromCode converts Sa-Token error code to HTTP status | 将Sa-Token错误码转换为HTTP状态码 +func getHTTPStatusFromCode(code int) int { + switch code { + case core.CodeNotLogin: + return http.StatusUnauthorized + case core.CodePermissionDenied: + return http.StatusForbidden + case core.CodeBadRequest: + return http.StatusBadRequest + case core.CodeNotFound: + return http.StatusNotFound + case core.CodeServerError: + return http.StatusInternalServerError + default: + return http.StatusInternalServerError + } +} diff --git a/integrations/chi/plugin.go b/integrations/chi/plugin.go deleted file mode 100644 index 7fdab88..0000000 --- a/integrations/chi/plugin.go +++ /dev/null @@ -1,183 +0,0 @@ -package chi - -import ( - "encoding/json" - "errors" - "net/http" - - "github.com/click33/sa-token-go/core" -) - -// Plugin Chi plugin for Sa-Token | Chi插件 -type Plugin struct { - manager *core.Manager -} - -// NewPlugin creates a Chi plugin | 创建Chi插件 -func NewPlugin(manager *core.Manager) *Plugin { - return &Plugin{ - manager: manager, - } -} - -// AuthMiddleware authentication middleware | 认证中间件 -func (p *Plugin) AuthMiddleware() func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx := NewChiContext(w, r) - saCtx := core.NewContext(ctx, p.manager) - - if err := saCtx.CheckLogin(); err != nil { - writeErrorResponse(w, err) - return - } - - // Store Sa-Token context | 存储Sa-Token上下文 - ctx.Set("satoken", saCtx) - next.ServeHTTP(w, r) - }) - } -} - -// PermissionRequired permission validation middleware | 权限验证中间件 -func (p *Plugin) PermissionRequired(permission string) func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx := NewChiContext(w, r) - saCtx := core.NewContext(ctx, p.manager) - - if err := saCtx.CheckLogin(); err != nil { - writeErrorResponse(w, err) - return - } - - if !saCtx.HasPermission(permission) { - writeErrorResponse(w, core.NewPermissionDeniedError(permission)) - return - } - - ctx.Set("satoken", saCtx) - next.ServeHTTP(w, r) - }) - } -} - -// RoleRequired role validation middleware | 角色验证中间件 -func (p *Plugin) RoleRequired(role string) func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx := NewChiContext(w, r) - saCtx := core.NewContext(ctx, p.manager) - - if err := saCtx.CheckLogin(); err != nil { - writeErrorResponse(w, err) - return - } - - if !saCtx.HasRole(role) { - writeErrorResponse(w, core.NewRoleDeniedError(role)) - return - } - - ctx.Set("satoken", saCtx) - next.ServeHTTP(w, r) - }) - } -} - -// LoginHandler 登录处理器 -func (p *Plugin) LoginHandler(w http.ResponseWriter, r *http.Request) { - var req struct { - Username string `json:"username"` - Password string `json:"password"` - Device string `json:"device"` - } - - if err := json.NewDecoder(r.Body).Decode(&req); err != nil { - writeErrorResponse(w, core.NewError(core.CodeBadRequest, "invalid request parameters", err)) - return - } - - device := req.Device - if device == "" { - device = "default" - } - - token, err := p.manager.Login(req.Username, device) - if err != nil { - writeErrorResponse(w, core.NewError(core.CodeServerError, "login failed", err)) - return - } - - writeSuccessResponse(w, map[string]interface{}{ - "token": token, - }) -} - -// GetSaToken 从请求上下文获取Sa-Token上下文 -func GetSaToken(r *http.Request) (*core.SaTokenContext, bool) { - satoken := r.Context().Value("satoken") - if satoken == nil { - return nil, false - } - ctx, ok := satoken.(*core.SaTokenContext) - return ctx, ok -} - -// ============ Error Handling Helpers | 错误处理辅助函数 ============ - -// writeErrorResponse writes a standardized error response | 写入标准化的错误响应 -func writeErrorResponse(w http.ResponseWriter, err error) { - var saErr *core.SaTokenError - var code int - var message string - var httpStatus int - - // Check if it's a SaTokenError | 检查是否为SaTokenError - if errors.As(err, &saErr) { - code = saErr.Code - message = saErr.Message - httpStatus = getHTTPStatusFromCode(code) - } else { - // Handle standard errors | 处理标准错误 - code = core.CodeServerError - message = err.Error() - httpStatus = http.StatusInternalServerError - } - - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(httpStatus) - json.NewEncoder(w).Encode(map[string]interface{}{ - "code": code, - "message": message, - "error": err.Error(), - }) -} - -// writeSuccessResponse writes a standardized success response | 写入标准化的成功响应 -func writeSuccessResponse(w http.ResponseWriter, data interface{}) { - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(map[string]interface{}{ - "code": core.CodeSuccess, - "message": "success", - "data": data, - }) -} - -// getHTTPStatusFromCode converts Sa-Token error code to HTTP status | 将Sa-Token错误码转换为HTTP状态码 -func getHTTPStatusFromCode(code int) int { - switch code { - case core.CodeNotLogin: - return http.StatusUnauthorized - case core.CodePermissionDenied: - return http.StatusForbidden - case core.CodeBadRequest: - return http.StatusBadRequest - case core.CodeNotFound: - return http.StatusNotFound - case core.CodeServerError: - return http.StatusInternalServerError - default: - return http.StatusInternalServerError - } -} diff --git a/integrations/echo/annotation.go b/integrations/echo/annotation.go index a8c9804..2eac67a 100644 --- a/integrations/echo/annotation.go +++ b/integrations/echo/annotation.go @@ -1,6 +1,8 @@ +// @Author daixk 2025/12/28 package echo import ( + "context" "strings" "github.com/click33/sa-token-go/core" @@ -10,90 +12,112 @@ import ( // Annotation annotation structure | 注解结构体 type Annotation struct { - CheckLogin bool `json:"checkLogin"` - CheckRole []string `json:"checkRole"` - CheckPermission []string `json:"checkPermission"` - CheckDisable bool `json:"checkDisable"` - Ignore bool `json:"ignore"` + AuthType string `json:"authType"` // Optional: specify auth type | 可选:指定认证类型 + CheckLogin bool `json:"checkLogin"` // Check login | 检查登录 + CheckRole []string `json:"checkRole"` // Check roles | 检查角色 + CheckPermission []string `json:"checkPermission"` // Check permissions | 检查权限 + CheckDisable bool `json:"checkDisable"` // Check disable status | 检查封禁状态 + Ignore bool `json:"ignore"` // Ignore authentication | 忽略认证 + LogicType LogicType `json:"logicType"` // OR or AND logic (default: OR) | OR 或 AND 逻辑(默认: OR) } // GetHandler gets handler with annotations | 获取带注解的处理器 +// Note: handler must not be nil, use middleware pattern instead | 注意: handler不能为nil,请使用中间件模式 func GetHandler(handler echo.HandlerFunc, annotations ...*Annotation) echo.HandlerFunc { return func(c echo.Context) error { - // Check if authentication should be ignored | 检查是否忽略认证 + // Ignore authentication | 忽略认证直接放行 if len(annotations) > 0 && annotations[0].Ignore { - if handler != nil { - return handler(c) - } - return nil + return handler(c) + } + + // Check if any authentication is needed | 检查是否需要任何认证 + ann := &Annotation{} + if len(annotations) > 0 { + ann = annotations[0] + } + + // No authentication required | 无需任何认证 + needAuth := ann.CheckLogin || ann.CheckDisable || len(ann.CheckPermission) > 0 || len(ann.CheckRole) > 0 + if !needAuth { + return handler(c) } - // Get token from context using configured TokenName | 从上下文获取Token(使用配置的TokenName) - ctx := NewEchoContext(c) - saCtx := core.NewContext(ctx, stputil.GetManager()) + ctx := c.Request().Context() + + // Get manager | 获取 Manager + mgr, err := stputil.GetManager(ann.AuthType) + if err != nil { + return writeErrorResponse(c, err) + } + + // Get SaTokenContext (reuse cached context) | 获取 SaTokenContext(复用缓存上下文) + saCtx := getSaContext(c, mgr) token := saCtx.GetTokenValue() + if token == "" { return writeErrorResponse(c, core.NewNotLoginError()) } // Check login | 检查登录 - if !stputil.IsLogin(token) { - return writeErrorResponse(c, core.NewNotLoginError()) + if err := mgr.CheckLogin(ctx, token); err != nil { + return writeErrorResponse(c, err) } - // Get login ID | 获取登录ID - loginID, err := stputil.GetLoginID(token) - if err != nil { - return writeErrorResponse(c, err) + // Get loginID for further checks | 获取 loginID 用于后续检查 + var loginID string + if ann.CheckDisable || len(ann.CheckPermission) > 0 || len(ann.CheckRole) > 0 { + loginID, err = mgr.GetLoginIDNotCheck(ctx, token) + if err != nil { + return writeErrorResponse(c, err) + } } // Check if account is disabled | 检查是否被封禁 - if len(annotations) > 0 && annotations[0].CheckDisable { - if stputil.IsDisable(loginID) { + if ann.CheckDisable { + if mgr.IsDisable(ctx, loginID) { return writeErrorResponse(c, core.NewAccountDisabledError(loginID)) } } // Check permission | 检查权限 - if len(annotations) > 0 && len(annotations[0].CheckPermission) > 0 { - hasPermission := false - for _, perm := range annotations[0].CheckPermission { - if stputil.HasPermission(loginID, strings.TrimSpace(perm)) { - hasPermission = true - break - } + if len(ann.CheckPermission) > 0 { + var ok bool + if ann.LogicType == LogicAnd { + ok = mgr.HasPermissionsAnd(ctx, loginID, ann.CheckPermission) + } else { + ok = mgr.HasPermissionsOr(ctx, loginID, ann.CheckPermission) } - if !hasPermission { - return writeErrorResponse(c, core.NewPermissionDeniedError(strings.Join(annotations[0].CheckPermission, ","))) + if !ok { + return writeErrorResponse(c, core.NewPermissionDeniedError(strings.Join(ann.CheckPermission, ","))) } } // Check role | 检查角色 - if len(annotations) > 0 && len(annotations[0].CheckRole) > 0 { - hasRole := false - for _, role := range annotations[0].CheckRole { - if stputil.HasRole(loginID, strings.TrimSpace(role)) { - hasRole = true - break - } + if len(ann.CheckRole) > 0 { + var ok bool + if ann.LogicType == LogicAnd { + ok = mgr.HasRolesAnd(ctx, loginID, ann.CheckRole) + } else { + ok = mgr.HasRolesOr(ctx, loginID, ann.CheckRole) } - if !hasRole { - return writeErrorResponse(c, core.NewRoleDeniedError(strings.Join(annotations[0].CheckRole, ","))) + if !ok { + return writeErrorResponse(c, core.NewRoleDeniedError(strings.Join(ann.CheckRole, ","))) } } // All checks passed, execute original handler | 所有检查通过,执行原函数 - if handler != nil { - return handler(c) - } - return nil + return handler(c) } } // CheckLoginMiddleware decorator for login checking | 检查登录装饰器 -func CheckLoginMiddleware() echo.MiddlewareFunc { +func CheckLoginMiddleware(authType ...string) echo.MiddlewareFunc { + ann := &Annotation{CheckLogin: true} + if len(authType) > 0 { + ann.AuthType = authType[0] + } return func(next echo.HandlerFunc) echo.HandlerFunc { - return GetHandler(next, &Annotation{CheckLogin: true}) + return GetHandler(next, ann) } } @@ -104,6 +128,13 @@ func CheckRoleMiddleware(roles ...string) echo.MiddlewareFunc { } } +// CheckRoleMiddlewareWithAuthType decorator for role checking with auth type | 检查角色装饰器(带认证类型) +func CheckRoleMiddlewareWithAuthType(authType string, roles ...string) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return GetHandler(next, &Annotation{CheckRole: roles, AuthType: authType}) + } +} + // CheckPermissionMiddleware decorator for permission checking | 检查权限装饰器 func CheckPermissionMiddleware(perms ...string) echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { @@ -111,10 +142,21 @@ func CheckPermissionMiddleware(perms ...string) echo.MiddlewareFunc { } } +// CheckPermissionMiddlewareWithAuthType decorator for permission checking with auth type | 检查权限装饰器(带认证类型) +func CheckPermissionMiddlewareWithAuthType(authType string, perms ...string) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return GetHandler(next, &Annotation{CheckPermission: perms, AuthType: authType}) + } +} + // CheckDisableMiddleware decorator for checking if account is disabled | 检查是否被封禁装饰器 -func CheckDisableMiddleware() echo.MiddlewareFunc { +func CheckDisableMiddleware(authType ...string) echo.MiddlewareFunc { + ann := &Annotation{CheckDisable: true} + if len(authType) > 0 { + ann.AuthType = authType[0] + } return func(next echo.HandlerFunc) echo.HandlerFunc { - return GetHandler(next, &Annotation{CheckDisable: true}) + return GetHandler(next, ann) } } @@ -124,3 +166,114 @@ func IgnoreMiddleware() echo.MiddlewareFunc { return GetHandler(next, &Annotation{Ignore: true}) } } + +// ============ Combined Middleware | 组合中间件 ============ + +// CheckLoginAndRoleMiddleware checks login and role | 检查登录和角色 +func CheckLoginAndRoleMiddleware(roles ...string) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return GetHandler(next, &Annotation{CheckLogin: true, CheckRole: roles}) + } +} + +// CheckLoginAndPermissionMiddleware checks login and permission | 检查登录和权限 +func CheckLoginAndPermissionMiddleware(perms ...string) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return GetHandler(next, &Annotation{CheckLogin: true, CheckPermission: perms}) + } +} + +// CheckAllMiddleware checks login, role, permission and disable status | 全面检查 +func CheckAllMiddleware(roles []string, perms []string) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return GetHandler(next, &Annotation{ + CheckLogin: true, + CheckRole: roles, + CheckPermission: perms, + CheckDisable: true, + }) + } +} + +// ============ Route Group Helper | 路由组辅助函数 ============ + +// AuthGroup creates a route group with authentication | 创建带认证的路由组 +func AuthGroup(group *echo.Group, authType ...string) *echo.Group { + group.Use(CheckLoginMiddleware(authType...)) + return group +} + +// RoleGroup creates a route group with role checking | 创建带角色检查的路由组 +func RoleGroup(group *echo.Group, roles ...string) *echo.Group { + group.Use(CheckLoginAndRoleMiddleware(roles...)) + return group +} + +// PermissionGroup creates a route group with permission checking | 创建带权限检查的路由组 +func PermissionGroup(group *echo.Group, perms ...string) *echo.Group { + group.Use(CheckLoginAndPermissionMiddleware(perms...)) + return group +} + +// ============ Context Helper | 上下文辅助函数 ============ + +// GetLoginIDFromRequest gets login ID from request context | 从请求上下文获取登录 ID +func GetLoginIDFromRequest(c echo.Context, authType ...string) (string, error) { + var at string + if len(authType) > 0 { + at = authType[0] + } + + mgr, err := stputil.GetManager(at) + if err != nil { + return "", err + } + + saCtx := getSaContext(c, mgr) + token := saCtx.GetTokenValue() + if token == "" { + return "", core.ErrNotLogin + } + return mgr.GetLoginID(c.Request().Context(), token) +} + +// IsLoginFromRequest checks if user is logged in from request | 从请求检查用户是否已登录 +func IsLoginFromRequest(c echo.Context, authType ...string) bool { + var at string + if len(authType) > 0 { + at = authType[0] + } + + mgr, err := stputil.GetManager(at) + if err != nil { + return false + } + + saCtx := getSaContext(c, mgr) + token := saCtx.GetTokenValue() + if token == "" { + return false + } + return mgr.IsLogin(c.Request().Context(), token) +} + +// GetTokenFromRequest gets token from request (exported) | 从请求获取 Token(导出) +func GetTokenFromRequest(c echo.Context, authType ...string) string { + var at string + if len(authType) > 0 { + at = authType[0] + } + + mgr, err := stputil.GetManager(at) + if err != nil { + return "" + } + + saCtx := getSaContext(c, mgr) + return saCtx.GetTokenValue() +} + +// WithContext creates a new context with sa-token context | 创建带 sa-token 上下文的新上下文 +func WithContext(c echo.Context, authType ...string) context.Context { + return c.Request().Context() +} diff --git a/integrations/echo/export.go b/integrations/echo/export.go index 2d62528..b823f33 100644 --- a/integrations/echo/export.go +++ b/integrations/echo/export.go @@ -1,364 +1,950 @@ package echo import ( + "context" "time" + "github.com/click33/sa-token-go/codec/json" + "github.com/click33/sa-token-go/codec/msgpack" "github.com/click33/sa-token-go/core" + "github.com/click33/sa-token-go/core/adapter" + "github.com/click33/sa-token-go/core/builder" + "github.com/click33/sa-token-go/core/config" + "github.com/click33/sa-token-go/core/listener" + "github.com/click33/sa-token-go/core/manager" + "github.com/click33/sa-token-go/core/oauth2" + "github.com/click33/sa-token-go/core/security" + "github.com/click33/sa-token-go/core/session" + "github.com/click33/sa-token-go/generator/sgenerator" + "github.com/click33/sa-token-go/log/nop" + "github.com/click33/sa-token-go/log/slog" + "github.com/click33/sa-token-go/pool/ants" + "github.com/click33/sa-token-go/storage/memory" + "github.com/click33/sa-token-go/storage/redis" "github.com/click33/sa-token-go/stputil" ) -// ============ Re-export core types | 重新导出核心类型 ============ +// ============ Type Aliases | 类型别名 ============ -// Configuration related types | 配置相关类型 type ( - Config = core.Config - CookieConfig = core.CookieConfig - TokenStyle = core.TokenStyle + // Config 配置 + Config = config.Config + // Manager 管理器 + Manager = manager.Manager + // Session 会话 + Session = session.Session + // TokenInfo Token信息 + TokenInfo = manager.TokenInfo + // DisableInfo 封禁信息 + DisableInfo = manager.DisableInfo + // Builder 构建器 + Builder = builder.Builder + // SaTokenError 错误类型 + SaTokenError = core.SaTokenError + // Event 事件类型 + Event = listener.Event + // EventData 事件数据 + EventData = listener.EventData + // Listener 事件监听器 + Listener = listener.Listener + // ListenerConfig 监听器配置 + ListenerConfig = listener.ListenerConfig + // RefreshTokenInfo 刷新令牌信息 + RefreshTokenInfo = security.RefreshTokenInfo + // AccessTokenInfo 访问令牌信息 + AccessTokenInfo = security.AccessTokenInfo + // OAuth2Client OAuth2客户端 + OAuth2Client = oauth2.Client + // OAuth2AccessToken OAuth2访问令牌 + OAuth2AccessToken = oauth2.AccessToken + // AuthorizationCode 授权码 + AuthorizationCode = oauth2.AuthorizationCode + // OAuth2TokenRequest OAuth2令牌请求 + OAuth2TokenRequest = oauth2.TokenRequest + // OAuth2GrantType OAuth2授权类型 + OAuth2GrantType = oauth2.GrantType + // OAuth2UserValidator OAuth2用户验证器 + OAuth2UserValidator = oauth2.UserValidator + // Storage 存储接口 + Storage = adapter.Storage + // Codec 编解码接口 + Codec = adapter.Codec + // Log 日志接口 + Log = adapter.Log + // Pool 协程池接口 + Pool = adapter.Pool + // Generator 生成器接口 + Generator = adapter.Generator + + // ============ Codec Types | 编解码器类型 ============ + + // JSONSerializer JSON编解码器 + JSONSerializer = json.JSONSerializer + // MsgPackSerializer MsgPack编解码器 + MsgPackSerializer = msgpack.MsgPackSerializer + + // ============ Storage Types | 存储类型 ============ + + // MemoryStorage 内存存储 + MemoryStorage = memory.Storage + // RedisStorage Redis存储 + RedisStorage = redis.Storage + // RedisConfig Redis配置 + RedisConfig = redis.Config + // RedisBuilder Redis构建器 + RedisBuilder = redis.Builder + + // ============ Logger Types | 日志类型 ============ + + // SlogLogger 标准日志实现 + SlogLogger = slog.Logger + // SlogLoggerConfig 标准日志配置 + SlogLoggerConfig = slog.LoggerConfig + // SlogLogLevel 日志级别 + SlogLogLevel = slog.LogLevel + // NopLogger 空日志实现 + NopLogger = nop.NopLogger + + // ============ Generator Types | 生成器类型 ============ + + // TokenGenerator Token生成器 + TokenGenerator = sgenerator.Generator + // TokenStyle Token风格 + TokenStyle = adapter.TokenStyle + + // ============ Pool Types | 协程池类型 ============ + + // RenewPoolManager 续期池管理器 + RenewPoolManager = ants.RenewPoolManager + // RenewPoolConfig 续期池配置 + RenewPoolConfig = ants.RenewPoolConfig ) -// Token style constants | Token风格常量 +// ============ Error Codes | 错误码 ============ + const ( - TokenStyleUUID = core.TokenStyleUUID - TokenStyleSimple = core.TokenStyleSimple - TokenStyleRandom32 = core.TokenStyleRandom32 - TokenStyleRandom64 = core.TokenStyleRandom64 - TokenStyleRandom128 = core.TokenStyleRandom128 - TokenStyleJWT = core.TokenStyleJWT - TokenStyleHash = core.TokenStyleHash - TokenStyleTimestamp = core.TokenStyleTimestamp - TokenStyleTik = core.TokenStyleTik + CodeSuccess = core.CodeSuccess + CodeBadRequest = core.CodeBadRequest + CodeNotLogin = core.CodeNotLogin + CodePermissionDenied = core.CodePermissionDenied + CodeNotFound = core.CodeNotFound + CodeServerError = core.CodeServerError + CodeTokenInvalid = core.CodeTokenInvalid + CodeTokenExpired = core.CodeTokenExpired + CodeAccountDisabled = core.CodeAccountDisabled + CodeKickedOut = core.CodeKickedOut + CodeActiveTimeout = core.CodeActiveTimeout + CodeMaxLoginCount = core.CodeMaxLoginCount + CodeStorageError = core.CodeStorageError + CodeInvalidParameter = core.CodeInvalidParameter + CodeSessionError = core.CodeSessionError ) -// Core types | 核心类型 -type ( - Manager = core.Manager - TokenInfo = core.TokenInfo - Session = core.Session - TokenGenerator = core.TokenGenerator - SaTokenContext = core.SaTokenContext - Builder = core.Builder - NonceManager = core.NonceManager - RefreshTokenInfo = core.RefreshTokenInfo - RefreshTokenManager = core.RefreshTokenManager - OAuth2Server = core.OAuth2Server - OAuth2Client = core.OAuth2Client - OAuth2AccessToken = core.OAuth2AccessToken - OAuth2GrantType = core.OAuth2GrantType -) +// ============ Errors | 错误变量 ============ -// Adapter interfaces | 适配器接口 -type ( - Storage = core.Storage - RequestContext = core.RequestContext +var ( + // Authentication Errors | 认证错误 + ErrNotLogin = core.ErrNotLogin + ErrTokenInvalid = core.ErrTokenInvalid + ErrTokenExpired = core.ErrTokenExpired + ErrTokenKickout = core.ErrTokenKickout + ErrTokenReplaced = core.ErrTokenReplaced + ErrInvalidLoginID = core.ErrInvalidLoginID + ErrInvalidDevice = core.ErrInvalidDevice + ErrTokenNotFound = core.ErrTokenNotFound + + // Authorization Errors | 授权错误 + ErrPermissionDenied = core.ErrPermissionDenied + ErrRoleDenied = core.ErrRoleDenied + + // Account Errors | 账号错误 + ErrAccountDisabled = core.ErrAccountDisabled + ErrAccountNotFound = core.ErrAccountNotFound + ErrLoginLimitExceeded = core.ErrLoginLimitExceeded + + // Session Errors | 会话错误 + ErrSessionNotFound = core.ErrSessionNotFound + ErrActiveTimeout = core.ErrActiveTimeout + ErrSessionInvalidDataKey = core.ErrSessionInvalidDataKey + ErrSessionIDEmpty = core.ErrSessionIDEmpty + + // Security Errors | 安全错误 + ErrInvalidNonce = core.ErrInvalidNonce + ErrRefreshTokenExpired = core.ErrRefreshTokenExpired + ErrNonceInvalidRefreshToken = core.ErrNonceInvalidRefreshToken + ErrInvalidLoginIDEmpty = core.ErrInvalidLoginIDEmpty + + // OAuth2 Errors | OAuth2错误 + ErrClientOrClientIDEmpty = core.ErrClientOrClientIDEmpty + ErrClientNotFound = core.ErrClientNotFound + ErrUserIDEmpty = core.ErrUserIDEmpty + ErrInvalidRedirectURI = core.ErrInvalidRedirectURI + ErrInvalidClientCredentials = core.ErrInvalidClientCredentials + ErrInvalidAuthCode = core.ErrInvalidAuthCode + ErrAuthCodeUsed = core.ErrAuthCodeUsed + ErrAuthCodeExpired = core.ErrAuthCodeExpired + ErrClientMismatch = core.ErrClientMismatch + ErrRedirectURIMismatch = core.ErrRedirectURIMismatch + ErrInvalidAccessToken = core.ErrInvalidAccessToken + ErrInvalidRefreshToken = core.ErrInvalidRefreshToken + ErrInvalidScope = core.ErrInvalidScope + + // System Errors | 系统错误 + ErrStorageUnavailable = core.ErrStorageUnavailable + ErrSerializeFailed = core.ErrSerializeFailed + ErrDeserializeFailed = core.ErrDeserializeFailed + ErrTypeConvert = core.ErrTypeConvert ) -// Event related types | 事件相关类型 -type ( - EventListener = core.EventListener - EventManager = core.EventManager - EventData = core.EventData - Event = core.Event - ListenerFunc = core.ListenerFunc - ListenerConfig = core.ListenerConfig -) +// ============ Error Constructors | 错误构造函数 ============ -// Event constants | 事件常量 -const ( - EventLogin = core.EventLogin - EventLogout = core.EventLogout - EventKickout = core.EventKickout - EventDisable = core.EventDisable - EventUntie = core.EventUntie - EventRenew = core.EventRenew - EventCreateSession = core.EventCreateSession - EventDestroySession = core.EventDestroySession - EventPermissionCheck = core.EventPermissionCheck - EventRoleCheck = core.EventRoleCheck - EventAll = core.EventAll +var ( + NewError = core.NewError + NewErrorWithContext = core.NewErrorWithContext + NewNotLoginError = core.NewNotLoginError + NewPermissionDeniedError = core.NewPermissionDeniedError + NewRoleDeniedError = core.NewRoleDeniedError + NewAccountDisabledError = core.NewAccountDisabledError ) -// OAuth2 grant type constants | OAuth2授权类型常量 -const ( - GrantTypeAuthorizationCode = core.GrantTypeAuthorizationCode - GrantTypeRefreshToken = core.GrantTypeRefreshToken - GrantTypeClientCredentials = core.GrantTypeClientCredentials - GrantTypePassword = core.GrantTypePassword -) +// ============ Error Checking Helpers | 错误检查辅助函数 ============ -// Utility functions | 工具函数 var ( - RandomString = core.RandomString - IsEmpty = core.IsEmpty - IsNotEmpty = core.IsNotEmpty - DefaultString = core.DefaultString - ContainsString = core.ContainsString - RemoveString = core.RemoveString - UniqueStrings = core.UniqueStrings - MergeStrings = core.MergeStrings - MatchPattern = core.MatchPattern + IsNotLoginError = core.IsNotLoginError + IsPermissionDeniedError = core.IsPermissionDeniedError + IsAccountDisabledError = core.IsAccountDisabledError + IsTokenError = core.IsTokenError + GetErrorCode = core.GetErrorCode ) -// ============ Core constructor functions | 核心构造函数 ============ +// ============ Manager Management | Manager 管理 ============ + +// SetManager stores the manager-example in the global map using the specified autoType | 使用指定的 autoType 将管理器存储在全局 map 中 +func SetManager(mgr *manager.Manager) { + stputil.SetManager(mgr) +} + +// GetManager retrieves the manager-example from the global map using the specified autoType | 使用指定的 autoType 从全局 map 中获取管理器 +func GetManager(autoType ...string) (*manager.Manager, error) { + return stputil.GetManager(autoType...) +} -// DefaultConfig returns default configuration | 返回默认配置 -func DefaultConfig() *Config { - return core.DefaultConfig() +// DeleteManager delete the specific manager-example for the given autoType and releases resources | 删除指定的管理器并释放资源 +func DeleteManager(autoType ...string) error { + return stputil.DeleteManager(autoType...) } -// NewManager creates a new authentication manager | 创建新的认证管理器 -func NewManager(storage Storage, cfg *Config) *Manager { - return core.NewManager(storage, cfg) +// DeleteAllManager delete all managers in the global map and releases resources | 关闭所有管理器并释放资源 +func DeleteAllManager() { + stputil.DeleteAllManager() } -// NewContext creates a new Sa-Token context | 创建新的Sa-Token上下文 -func NewContext(ctx RequestContext, mgr *Manager) *SaTokenContext { - return core.NewContext(ctx, mgr) +// ============ Builder & Config | 构建器和配置 ============ + +// NewDefaultBuild creates a new default builder | 创建默认构建器 +func NewDefaultBuild() *builder.Builder { + return builder.NewBuilder() } -// NewSession creates a new session | 创建新的Session -func NewSession(id string, storage Storage, prefix string) *Session { - return core.NewSession(id, storage, prefix) +// NewDefaultConfig creates a new default config | 创建默认配置 +func NewDefaultConfig() *config.Config { + return config.DefaultConfig() } -// LoadSession loads an existing session | 加载已存在的Session -func LoadSession(id string, storage Storage, prefix string) (*Session, error) { - return core.LoadSession(id, storage, prefix) +// DefaultLoggerConfig returns the default logger config | 返回默认日志配置 +func DefaultLoggerConfig() *slog.LoggerConfig { + return slog.DefaultLoggerConfig() } -// NewTokenGenerator creates a new token generator | 创建新的Token生成器 -func NewTokenGenerator(cfg *Config) *TokenGenerator { - return core.NewTokenGenerator(cfg) +// DefaultRenewPoolConfig returns the default renew pool config | 返回默认续期池配置 +func DefaultRenewPoolConfig() *ants.RenewPoolConfig { + return ants.DefaultRenewPoolConfig() } -// NewEventManager creates a new event manager | 创建新的事件管理器 -func NewEventManager() *EventManager { - return core.NewEventManager() +// ============ Codec Constructors | 编解码器构造函数 ============ + +// NewJSONSerializer creates a new JSON serializer | 创建JSON序列化器 +func NewJSONSerializer() *json.JSONSerializer { + return json.NewJSONSerializer() } -// NewBuilder creates a new builder for fluent configuration | 创建新的Builder构建器(用于流式配置) -func NewBuilder() *Builder { - return core.NewBuilder() +// NewMsgPackSerializer creates a new MsgPack serializer | 创建MsgPack序列化器 +func NewMsgPackSerializer() *msgpack.MsgPackSerializer { + return msgpack.NewMsgPackSerializer() } -// NewNonceManager creates a new nonce manager | 创建新的Nonce管理器 -func NewNonceManager(storage Storage, prefix string, ttl ...int64) *NonceManager { - return core.NewNonceManager(storage, prefix, ttl...) +// ============ Storage Constructors | 存储构造函数 ============ + +// NewMemoryStorage creates a new memory storage | 创建内存存储 +func NewMemoryStorage() *memory.Storage { + return memory.NewStorage() } -// NewRefreshTokenManager creates a new refresh token manager | 创建新的刷新令牌管理器 -func NewRefreshTokenManager(storage Storage, prefix string, cfg *Config) *RefreshTokenManager { - return core.NewRefreshTokenManager(storage, prefix, cfg) +// NewMemoryStorageWithCleanupInterval creates a new memory storage with cleanup interval | 创建带清理间隔的内存存储 +func NewMemoryStorageWithCleanupInterval(interval time.Duration) *memory.Storage { + return memory.NewStorageWithCleanupInterval(interval) } -// NewOAuth2Server creates a new OAuth2 server | 创建新的OAuth2服务器 -func NewOAuth2Server(storage Storage, prefix string) *OAuth2Server { - return core.NewOAuth2Server(storage, prefix) +// NewRedisStorage creates a new Redis storage from URL | 通过URL创建Redis存储 +func NewRedisStorage(url string) (*redis.Storage, error) { + return redis.NewStorage(url) } -// ============ Global StpUtil functions | 全局StpUtil函数 ============ +// NewRedisStorageFromConfig creates a new Redis storage from config | 通过配置创建Redis存储 +func NewRedisStorageFromConfig(cfg *redis.Config) (*redis.Storage, error) { + return redis.NewStorageFromConfig(cfg) +} -// SetManager sets the global Manager (must be called first) | 设置全局Manager(必须先调用此方法) -func SetManager(mgr *Manager) { - stputil.SetManager(mgr) +// NewRedisBuilder creates a new Redis builder | 创建Redis构建器 +func NewRedisBuilder() *redis.Builder { + return redis.NewBuilder() +} + +// ============ Logger Constructors | 日志构造函数 ============ + +// NewSlogLogger creates a new slog logger with config | 使用配置创建标准日志器 +func NewSlogLogger(cfg *slog.LoggerConfig) (*slog.Logger, error) { + return slog.NewLoggerWithConfig(cfg) } -// GetManager gets the global Manager | 获取全局Manager -func GetManager() *Manager { - return stputil.GetManager() +// NewNopLogger creates a new no-op logger | 创建空日志器 +func NewNopLogger() *nop.NopLogger { + return nop.NewNopLogger() } +// ============ Generator Constructors | 生成器构造函数 ============ + +// NewTokenGenerator creates a new token generator | 创建Token生成器 +func NewTokenGenerator(timeout int64, tokenStyle adapter.TokenStyle, jwtSecretKey string) *sgenerator.Generator { + return sgenerator.NewGenerator(timeout, tokenStyle, jwtSecretKey) +} + +// NewDefaultTokenGenerator creates a new default token generator | 创建默认Token生成器 +func NewDefaultTokenGenerator() *sgenerator.Generator { + return sgenerator.NewDefaultGenerator() +} + +// ============ Pool Constructors | 协程池构造函数 ============ + +// NewRenewPoolManager creates a new renew pool manager-example with default config | 使用默认配置创建续期池管理器 +func NewRenewPoolManager() *ants.RenewPoolManager { + return ants.NewRenewPoolManagerWithDefaultConfig() +} + +// NewRenewPoolManagerWithConfig creates a new renew pool manager-example with config | 使用配置创建续期池管理器 +func NewRenewPoolManagerWithConfig(cfg *ants.RenewPoolConfig) (*ants.RenewPoolManager, error) { + return ants.NewRenewPoolManagerWithConfig(cfg) +} + +// ============ Token Style Constants | Token风格常量 ============ + +const ( + // TokenStyleUUID UUID style | UUID风格 + TokenStyleUUID = adapter.TokenStyleUUID + // TokenStyleSimple Simple random string | 简单随机字符串 + TokenStyleSimple = adapter.TokenStyleSimple + // TokenStyleRandom32 32-bit random string | 32位随机字符串 + TokenStyleRandom32 = adapter.TokenStyleRandom32 + // TokenStyleRandom64 64-bit random string | 64位随机字符串 + TokenStyleRandom64 = adapter.TokenStyleRandom64 + // TokenStyleRandom128 128-bit random string | 128位随机字符串 + TokenStyleRandom128 = adapter.TokenStyleRandom128 + // TokenStyleJWT JWT style | JWT风格 + TokenStyleJWT = adapter.TokenStyleJWT + // TokenStyleHash SHA256 hash-based style | SHA256哈希风格 + TokenStyleHash = adapter.TokenStyleHash + // TokenStyleTimestamp Timestamp-based style | 时间戳风格 + TokenStyleTimestamp = adapter.TokenStyleTimestamp + // TokenStyleTik Short ID style (like TikTok) | Tik风格短ID + TokenStyleTik = adapter.TokenStyleTik +) + +// ============ Log Level Constants | 日志级别常量 ============ + +const ( + // LogLevelDebug Debug level | 调试级别 + LogLevelDebug = adapter.LogLevelDebug + // LogLevelInfo Info level | 信息级别 + LogLevelInfo = adapter.LogLevelInfo + // LogLevelWarn Warn level | 警告级别 + LogLevelWarn = adapter.LogLevelWarn + // LogLevelError Error level | 错误级别 + LogLevelError = adapter.LogLevelError +) + // ============ Authentication | 登录认证 ============ // Login performs user login | 用户登录 -func Login(loginID interface{}, device ...string) (string, error) { - return stputil.Login(loginID, device...) +func Login(ctx context.Context, loginID interface{}, deviceOrAutoType ...string) (string, error) { + return stputil.Login(ctx, loginID, deviceOrAutoType...) } // LoginByToken performs login with specified token | 使用指定Token登录 -func LoginByToken(loginID interface{}, tokenValue string, device ...string) error { - return stputil.LoginByToken(loginID, tokenValue, device...) +func LoginByToken(ctx context.Context, tokenValue string, authType ...string) error { + return stputil.LoginByToken(ctx, tokenValue, authType...) } // Logout performs user logout | 用户登出 -func Logout(loginID interface{}, device ...string) error { - return stputil.Logout(loginID, device...) +func Logout(ctx context.Context, loginID interface{}, deviceOrAutoType ...string) error { + return stputil.Logout(ctx, loginID, deviceOrAutoType...) } // LogoutByToken performs logout by token | 根据Token登出 -func LogoutByToken(tokenValue string) error { - return stputil.LogoutByToken(tokenValue) +func LogoutByToken(ctx context.Context, tokenValue string, authType ...string) error { + return stputil.LogoutByToken(ctx, tokenValue, authType...) +} + +// Kickout kicks out a user session | 踢人下线 +func Kickout(ctx context.Context, loginID interface{}, deviceOrAutoType ...string) error { + return stputil.Kickout(ctx, loginID, deviceOrAutoType...) +} + +// KickoutByToken Kick user offline | 根据Token踢人下线 +func KickoutByToken(ctx context.Context, tokenValue string, authType ...string) error { + return stputil.KickoutByToken(ctx, tokenValue, authType...) +} + +// Replace user offline by login ID and device | 根据账号和设备顶人下线 +func Replace(ctx context.Context, loginID interface{}, deviceOrAutoType ...string) error { + return stputil.Replace(ctx, loginID, deviceOrAutoType...) +} + +// ReplaceByToken Replace user offline by token | 根据Token顶人下线 +func ReplaceByToken(ctx context.Context, tokenValue string, authType ...string) error { + return stputil.ReplaceByToken(ctx, tokenValue, authType...) } +// ============ Token Validation | Token验证 ============ + // IsLogin checks if the user is logged in | 检查用户是否已登录 -func IsLogin(tokenValue string) bool { - return stputil.IsLogin(tokenValue) +func IsLogin(ctx context.Context, tokenValue string, authType ...string) bool { + return stputil.IsLogin(ctx, tokenValue, authType...) } // CheckLogin checks login status (throws error if not logged in) | 检查登录状态(未登录抛出错误) -func CheckLogin(tokenValue string) error { - return stputil.CheckLogin(tokenValue) +func CheckLogin(ctx context.Context, tokenValue string, authType ...string) error { + return stputil.CheckLogin(ctx, tokenValue, authType...) +} + +// CheckLoginWithState checks the login status (returns error to determine the reason if not logged in) | 检查登录状态(未登录时根据错误确定原因) +func CheckLoginWithState(ctx context.Context, tokenValue string, authType ...string) (bool, error) { + return stputil.CheckLoginWithState(ctx, tokenValue, authType...) } // GetLoginID gets the login ID from token | 从Token获取登录ID -func GetLoginID(tokenValue string) (string, error) { - return stputil.GetLoginID(tokenValue) +func GetLoginID(ctx context.Context, tokenValue string, authType ...string) (string, error) { + return stputil.GetLoginID(ctx, tokenValue, authType...) } -// GetLoginIDNotCheck gets login ID without checking | 获取登录ID(不检查) -func GetLoginIDNotCheck(tokenValue string) (string, error) { - return stputil.GetLoginIDNotCheck(tokenValue) +// GetLoginIDNotCheck gets login ID without checking | 获取登录ID(不检查登录状态) +func GetLoginIDNotCheck(ctx context.Context, tokenValue string, authType ...string) (string, error) { + return stputil.GetLoginIDNotCheck(ctx, tokenValue, authType...) } // GetTokenValue gets the token value for a login ID | 获取登录ID对应的Token值 -func GetTokenValue(loginID interface{}, device ...string) (string, error) { - return stputil.GetTokenValue(loginID, device...) +func GetTokenValue(ctx context.Context, loginID interface{}, deviceOrAutoType ...string) (string, error) { + return stputil.GetTokenValue(ctx, loginID, deviceOrAutoType...) } // GetTokenInfo gets token information | 获取Token信息 -func GetTokenInfo(tokenValue string) (*TokenInfo, error) { - return stputil.GetTokenInfo(tokenValue) +func GetTokenInfo(ctx context.Context, tokenValue string, authType ...string) (*manager.TokenInfo, error) { + return stputil.GetTokenInfo(ctx, tokenValue, authType...) } -// ============ Kickout | 踢人下线 ============ +// ============ Account Disable | 账号封禁 ============ -// Kickout kicks out a user session | 踢人下线 -func Kickout(loginID interface{}, device ...string) error { - return stputil.Kickout(loginID, device...) +// Disable disables an account for specified duration | 封禁账号(指定时长) +func Disable(ctx context.Context, loginID interface{}, duration time.Duration, authType ...string) error { + return stputil.Disable(ctx, loginID, duration, authType...) } -// ============ Account Disable | 账号封禁 ============ +// DisableByToken disables the account associated with the given token for a duration | 根据指定 Token 封禁其对应的账号 +func DisableByToken(ctx context.Context, tokenValue string, duration time.Duration, authType ...string) error { + return stputil.DisableByToken(ctx, tokenValue, duration, authType...) +} -// Disable disables an account for specified duration | 封禁账号(指定时长) -func Disable(loginID interface{}, duration time.Duration) error { - return stputil.Disable(loginID, duration) +// Untie re-enables a disabled account | 解封账号 +func Untie(ctx context.Context, loginID interface{}, authType ...string) error { + return stputil.Untie(ctx, loginID, authType...) +} + +// UntieByToken re-enables a disabled account by token | 根据Token解封账号 +func UntieByToken(ctx context.Context, tokenValue string, authType ...string) error { + return stputil.UntieByToken(ctx, tokenValue, authType...) } // IsDisable checks if an account is disabled | 检查账号是否被封禁 -func IsDisable(loginID interface{}) bool { - return stputil.IsDisable(loginID) +func IsDisable(ctx context.Context, loginID interface{}, authType ...string) bool { + return stputil.IsDisable(ctx, loginID, authType...) } -// CheckDisable checks if account is disabled (throws error if disabled) | 检查账号是否被封禁(被封禁则抛出错误) -func CheckDisableByToken(tokenValue string) error { - return stputil.CheckDisable(tokenValue) +// IsDisableByToken checks if an account is disabled by token | 根据Token检查账号是否被封禁 +func IsDisableByToken(ctx context.Context, tokenValue string, authType ...string) bool { + return stputil.IsDisableByToken(ctx, tokenValue, authType...) } -// GetDisableTime gets remaining disabled time | 获取账号剩余封禁时间 -func GetDisableTime(loginID interface{}) (int64, error) { - return stputil.GetDisableTime(loginID) +// GetDisableTime gets remaining disable time in seconds | 获取剩余封禁时间(秒) +func GetDisableTime(ctx context.Context, loginID interface{}, authType ...string) (int64, error) { + return stputil.GetDisableTime(ctx, loginID, authType...) } -// Untie unties/unlocks an account | 解除账号封禁 -func Untie(loginID interface{}) error { - return stputil.Untie(loginID) +// GetDisableTimeByToken gets remaining disable time by token | 根据Token获取剩余封禁时间(秒) +func GetDisableTimeByToken(ctx context.Context, tokenValue string, authType ...string) (int64, error) { + return stputil.GetDisableTimeByToken(ctx, tokenValue, authType...) } -// ============ Permission Check | 权限验证 ============ +// CheckDisableWithInfo gets disable info | 获取封禁信息 +func CheckDisableWithInfo(ctx context.Context, loginID interface{}, authType ...string) (*manager.DisableInfo, error) { + return stputil.CheckDisableWithInfo(ctx, loginID, authType...) +} -// CheckPermission checks if the account has specified permission | 检查账号是否拥有指定权限 -func CheckPermissionByToken(tokenValue string, permission string) error { - return stputil.CheckPermission(tokenValue, permission) +// CheckDisableWithInfoByToken gets disable info by token | 根据Token获取封禁信息 +func CheckDisableWithInfoByToken(ctx context.Context, tokenValue string, authType ...string) (*manager.DisableInfo, error) { + return stputil.CheckDisableWithInfoByToken(ctx, tokenValue, authType...) } -// HasPermission checks if the account has specified permission (returns bool) | 检查账号是否拥有指定权限(返回布尔值) -func HasPermission(loginID interface{}, permission string) bool { - return stputil.HasPermission(loginID, permission) +// ============ Session Management | Session管理 ============ + +// GetSession gets session by login ID | 根据登录ID获取Session +func GetSession(ctx context.Context, loginID interface{}, authType ...string) (*session.Session, error) { + return stputil.GetSession(ctx, loginID, authType...) +} + +// GetSessionByToken gets session by token | 根据Token获取Session +func GetSessionByToken(ctx context.Context, tokenValue string, authType ...string) (*session.Session, error) { + return stputil.GetSessionByToken(ctx, tokenValue, authType...) } -// CheckPermissionAnd checks if the account has all specified permissions (AND logic) | 检查账号是否拥有所有指定权限(AND逻辑) -func CheckPermissionAndByToken(tokenValue string, permissions []string) error { - return stputil.CheckPermissionAnd(tokenValue, permissions) +// DeleteSession deletes a session | 删除Session +func DeleteSession(ctx context.Context, loginID interface{}, authType ...string) error { + return stputil.DeleteSession(ctx, loginID, authType...) } -// CheckPermissionOr checks if the account has any of the specified permissions (OR logic) | 检查账号是否拥有指定权限中的任意一个(OR逻辑) -func CheckPermissionOrByToken(tokenValue string, permissions []string) error { - return stputil.CheckPermissionOr(tokenValue, permissions) +// DeleteSessionByToken Deletes session by token | 根据Token删除Session +func DeleteSessionByToken(ctx context.Context, tokenValue string, authType ...string) error { + return stputil.DeleteSessionByToken(ctx, tokenValue, authType...) } -// GetPermissionList gets the permission list for an account | 获取账号的权限列表 -func GetPermissionListByToken(tokenValue string) ([]string, error) { - return stputil.GetPermissionList(tokenValue) +// HasSession checks if session exists | 检查Session是否存在 +func HasSession(ctx context.Context, loginID interface{}, authType ...string) bool { + return stputil.HasSession(ctx, loginID, authType...) } -// ============ Role Check | 角色验证 ============ +// RenewSession renews session TTL | 续期Session +func RenewSession(ctx context.Context, loginID interface{}, ttl time.Duration, authType ...string) error { + return stputil.RenewSession(ctx, loginID, ttl, authType...) +} -// CheckRole checks if the account has specified role | 检查账号是否拥有指定角色 -func CheckRoleByToken(tokenValue string, role string) error { - return stputil.CheckRole(tokenValue, role) +// RenewSessionByToken renews session TTL by token | 根据Token续期Session +func RenewSessionByToken(ctx context.Context, tokenValue string, ttl time.Duration, authType ...string) error { + return stputil.RenewSessionByToken(ctx, tokenValue, ttl, authType...) } -// HasRole checks if the account has specified role (returns bool) | 检查账号是否拥有指定角色(返回布尔值) -func HasRole(loginID interface{}, role string) bool { - return stputil.HasRole(loginID, role) +// ============ Permission Verification | 权限验证 ============ + +// SetPermissions sets permissions for a login ID | 设置用户权限 +func SetPermissions(ctx context.Context, loginID interface{}, permissions []string, authType ...string) error { + return stputil.SetPermissions(ctx, loginID, permissions, authType...) } -// CheckRoleAnd checks if the account has all specified roles (AND logic) | 检查账号是否拥有所有指定角色(AND逻辑) -func CheckRoleAndByToken(tokenValue string, roles []string) error { - return stputil.CheckRoleAnd(tokenValue, roles) +// SetPermissionsByToken sets permissions by token | 根据 Token 设置对应账号的权限 +func SetPermissionsByToken(ctx context.Context, tokenValue string, permissions []string, authType ...string) error { + return stputil.SetPermissionsByToken(ctx, tokenValue, permissions, authType...) } -// CheckRoleOr checks if the account has any of the specified roles (OR logic) | 检查账号是否拥有指定角色中的任意一个(OR逻辑) -func CheckRoleOrByToken(tokenValue string, roles []string) error { - return stputil.CheckRoleOr(tokenValue, roles) +// RemovePermissions removes specified permissions for a login ID | 删除用户指定权限 +func RemovePermissions(ctx context.Context, loginID interface{}, permissions []string, authType ...string) error { + return stputil.RemovePermissions(ctx, loginID, permissions, authType...) } -// GetRoleList gets the role list for an account | 获取账号的角色列表 -func GetRoleListByToken(tokenValue string) ([]string, error) { - return stputil.GetRoleList(tokenValue) +// RemovePermissionsByToken removes specified permissions by token | 根据 Token 删除对应账号的指定权限 +func RemovePermissionsByToken(ctx context.Context, tokenValue string, permissions []string, authType ...string) error { + return stputil.RemovePermissionsByToken(ctx, tokenValue, permissions, authType...) } -// ============ Session Management | Session管理 ============ +// GetPermissions gets permission list | 获取权限列表 +func GetPermissions(ctx context.Context, loginID interface{}, authType ...string) ([]string, error) { + return stputil.GetPermissions(ctx, loginID, authType...) +} + +// GetPermissionsByToken gets permission list by token | 根据 Token 获取对应账号的权限列表 +func GetPermissionsByToken(ctx context.Context, tokenValue string, authType ...string) ([]string, error) { + return stputil.GetPermissionsByToken(ctx, tokenValue, authType...) +} + +// HasPermission checks if has specified permission | 检查是否拥有指定权限 +func HasPermission(ctx context.Context, loginID interface{}, permission string, authType ...string) bool { + return stputil.HasPermission(ctx, loginID, permission, authType...) +} + +// HasPermissionByToken checks if the token has the specified permission | 检查Token是否拥有指定权限 +func HasPermissionByToken(ctx context.Context, tokenValue string, permission string, authType ...string) bool { + return stputil.HasPermissionByToken(ctx, tokenValue, permission, authType...) +} + +// HasPermissionsAnd checks if has all permissions (AND logic) | 检查是否拥有所有权限(AND逻辑) +func HasPermissionsAnd(ctx context.Context, loginID interface{}, permissions []string, authType ...string) bool { + return stputil.HasPermissionsAnd(ctx, loginID, permissions, authType...) +} + +// HasPermissionsAndByToken checks if the token has all specified permissions | 检查Token是否拥有所有指定权限 +func HasPermissionsAndByToken(ctx context.Context, tokenValue string, permissions []string, authType ...string) bool { + return stputil.HasPermissionsAndByToken(ctx, tokenValue, permissions, authType...) +} + +// HasPermissionsOr checks if has any permission (OR logic) | 检查是否拥有任一权限(OR逻辑) +func HasPermissionsOr(ctx context.Context, loginID interface{}, permissions []string, authType ...string) bool { + return stputil.HasPermissionsOr(ctx, loginID, permissions, authType...) +} + +// HasPermissionsOrByToken checks if the token has any of the specified permissions | 检查Token是否拥有任一指定权限 +func HasPermissionsOrByToken(ctx context.Context, tokenValue string, permissions []string, authType ...string) bool { + return stputil.HasPermissionsOrByToken(ctx, tokenValue, permissions, authType...) +} + +// ============ Role Management | 角色管理 ============ + +// SetRoles sets roles for a login ID | 设置用户角色 +func SetRoles(ctx context.Context, loginID interface{}, roles []string, authType ...string) error { + return stputil.SetRoles(ctx, loginID, roles, authType...) +} + +// SetRolesByToken sets roles by token | 根据 Token 设置对应账号的角色 +func SetRolesByToken(ctx context.Context, tokenValue string, roles []string, authType ...string) error { + return stputil.SetRolesByToken(ctx, tokenValue, roles, authType...) +} + +// RemoveRoles removes specified roles for a login ID | 删除用户指定角色 +func RemoveRoles(ctx context.Context, loginID interface{}, roles []string, authType ...string) error { + return stputil.RemoveRoles(ctx, loginID, roles, authType...) +} + +// RemoveRolesByToken removes specified roles by token | 根据 Token 删除对应账号的指定角色 +func RemoveRolesByToken(ctx context.Context, tokenValue string, roles []string, authType ...string) error { + return stputil.RemoveRolesByToken(ctx, tokenValue, roles, authType...) +} + +// GetRoles gets role list | 获取角色列表 +func GetRoles(ctx context.Context, loginID interface{}, authType ...string) ([]string, error) { + return stputil.GetRoles(ctx, loginID, authType...) +} + +// GetRolesByToken gets role list by token | 根据 Token 获取对应账号的角色列表 +func GetRolesByToken(ctx context.Context, tokenValue string, authType ...string) ([]string, error) { + return stputil.GetRolesByToken(ctx, tokenValue, authType...) +} + +// HasRole checks if has specified role | 检查是否拥有指定角色 +func HasRole(ctx context.Context, loginID interface{}, role string, authType ...string) bool { + return stputil.HasRole(ctx, loginID, role, authType...) +} + +// HasRoleByToken checks if the token has the specified role | 检查 Token 是否拥有指定角色 +func HasRoleByToken(ctx context.Context, tokenValue string, role string, authType ...string) bool { + return stputil.HasRoleByToken(ctx, tokenValue, role, authType...) +} + +// HasRolesAnd checks if has all roles (AND logic) | 检查是否拥有所有角色(AND逻辑) +func HasRolesAnd(ctx context.Context, loginID interface{}, roles []string, authType ...string) bool { + return stputil.HasRolesAnd(ctx, loginID, roles, authType...) +} -// GetSession gets the session for a login ID | 获取登录ID的Session -func GetSession(loginID interface{}) (*Session, error) { - return stputil.GetSession(loginID) +// HasRolesAndByToken checks if the token has all specified roles | 检查 Token 是否拥有所有指定角色 +func HasRolesAndByToken(ctx context.Context, tokenValue string, roles []string, authType ...string) bool { + return stputil.HasRolesAndByToken(ctx, tokenValue, roles, authType...) } -// GetSessionByToken gets the session by token | 根据Token获取Session -func GetSessionByToken(tokenValue string) (*Session, error) { - return stputil.GetSessionByToken(tokenValue) +// HasRolesOr checks if has any role (OR logic) | 检查是否拥有任一角色(OR逻辑) +func HasRolesOr(ctx context.Context, loginID interface{}, roles []string, authType ...string) bool { + return stputil.HasRolesOr(ctx, loginID, roles, authType...) } -// GetTokenSession gets the token session | 获取Token的Session -func GetTokenSession(tokenValue string) (*Session, error) { - return stputil.GetTokenSession(tokenValue) +// HasRolesOrByToken checks if the token has any of the specified roles | 检查 Token 是否拥有任一指定角色 +func HasRolesOrByToken(ctx context.Context, tokenValue string, roles []string, authType ...string) bool { + return stputil.HasRolesOrByToken(ctx, tokenValue, roles, authType...) } -// ============ Token Renewal | Token续期 ============ +// ============ Token Tag | Token标签 ============ + +// SetTokenTag sets token tag | 设置Token标签 +func SetTokenTag(ctx context.Context, tokenValue, tag string, authType ...string) error { + return stputil.SetTokenTag(ctx, tokenValue, tag, authType...) +} + +// GetTokenTag gets token tag | 获取Token标签 +func GetTokenTag(ctx context.Context, tokenValue string, authType ...string) (string, error) { + return stputil.GetTokenTag(ctx, tokenValue, authType...) +} -// RenewTimeout renews token timeout | 续期Token超时时间 +// ============ Session Query | 会话查询 ============ + +// GetTokenValueListByLoginID gets all tokens for a login ID | 获取指定账号的所有Token +func GetTokenValueListByLoginID(ctx context.Context, loginID interface{}, authType ...string) ([]string, error) { + return stputil.GetTokenValueListByLoginID(ctx, loginID, authType...) +} + +// GetSessionCountByLoginID gets session count for a login ID | 获取指定账号的Session数量 +func GetSessionCountByLoginID(ctx context.Context, loginID interface{}, authType ...string) (int, error) { + return stputil.GetSessionCountByLoginID(ctx, loginID, authType...) +} // ============ Security Features | 安全特性 ============ -// GenerateNonce generates a new nonce token | 生成新的Nonce令牌 -func GenerateNonce() (string, error) { - return stputil.GenerateNonce() +// Generate Generates a one-time nonce | 生成一次性随机数 +func Generate(ctx context.Context, authType ...string) (string, error) { + return stputil.Generate(ctx, authType...) +} + +// Verify Verifies a nonce | 验证随机数 +func Verify(ctx context.Context, nonce string, authType ...string) bool { + return stputil.Verify(ctx, nonce, authType...) } -// VerifyNonce verifies a nonce token | 验证Nonce令牌 -func VerifyNonce(nonce string) bool { - return stputil.VerifyNonce(nonce) +// VerifyAndConsume Verifies and consumes nonce, returns error if invalid | 验证并消费nonce,无效时返回错误 +func VerifyAndConsume(ctx context.Context, nonce string, authType ...string) error { + return stputil.VerifyAndConsume(ctx, nonce, authType...) } -// LoginWithRefreshToken performs login and returns both access token and refresh token | 登录并返回访问令牌和刷新令牌 -func LoginWithRefreshToken(loginID interface{}, device ...string) (*RefreshTokenInfo, error) { - return stputil.LoginWithRefreshToken(loginID, device...) +// IsValidNonce Checks if nonce is valid without consuming it | 检查nonce是否有效(不消费) +func IsValidNonce(ctx context.Context, nonce string, authType ...string) bool { + return stputil.IsValidNonce(ctx, nonce, authType...) } -// RefreshAccessToken refreshes the access token using a refresh token | 使用刷新令牌刷新访问令牌 -func RefreshAccessToken(refreshToken string) (*RefreshTokenInfo, error) { - return stputil.RefreshAccessToken(refreshToken) +// GenerateTokenPair Create access + refresh token | 生成访问令牌和刷新令牌 +func GenerateTokenPair(ctx context.Context, loginID interface{}, deviceOrAutoType ...string) (*security.RefreshTokenInfo, error) { + return stputil.GenerateTokenPair(ctx, loginID, deviceOrAutoType...) } -// RevokeRefreshToken revokes a refresh token | 撤销刷新令牌 -func RevokeRefreshToken(refreshToken string) error { - return stputil.RevokeRefreshToken(refreshToken) +// VerifyAccessToken verifies access token validity | 验证访问令牌是否有效 +func VerifyAccessToken(ctx context.Context, accessToken string, authType ...string) bool { + return stputil.VerifyAccessToken(ctx, accessToken, authType...) +} + +// VerifyAccessTokenAndGetInfo verifies access token and returns token info | 验证访问令牌并返回Token信息 +func VerifyAccessTokenAndGetInfo(ctx context.Context, accessToken string, authType ...string) (*security.AccessTokenInfo, bool) { + return stputil.VerifyAccessTokenAndGetInfo(ctx, accessToken, authType...) +} + +// GetRefreshTokenInfo gets refresh token information | 获取刷新令牌信息 +func GetRefreshTokenInfo(ctx context.Context, refreshToken string, authType ...string) (*security.RefreshTokenInfo, error) { + return stputil.GetRefreshTokenInfo(ctx, refreshToken, authType...) +} + +// RefreshAccessToken refreshes access token using refresh token | 使用刷新令牌刷新访问令牌 +func RefreshAccessToken(ctx context.Context, refreshToken string, authType ...string) (*security.RefreshTokenInfo, error) { + return stputil.RefreshAccessToken(ctx, refreshToken, authType...) +} + +// RevokeRefreshToken Revokes refresh token | 撤销刷新令牌 +func RevokeRefreshToken(ctx context.Context, refreshToken string, authType ...string) error { + return stputil.RevokeRefreshToken(ctx, refreshToken, authType...) +} + +// IsValid checks whether token is valid | 检查Token是否有效 +func IsValid(ctx context.Context, refreshToken string, authType ...string) bool { + return stputil.IsValid(ctx, refreshToken, authType...) +} + +// ============ OAuth2 Features | OAuth2 功能 ============ + +// RegisterClient Registers an OAuth2 client | 注册OAuth2客户端 +func RegisterClient(ctx context.Context, client *oauth2.Client, authType ...string) error { + return stputil.RegisterClient(ctx, client, authType...) } -// GetOAuth2Server gets the OAuth2 server instance | 获取OAuth2服务器实例 -func GetOAuth2Server() *OAuth2Server { - return stputil.GetOAuth2Server() +// UnregisterClient unregisters an OAuth2 client | 注销OAuth2客户端 +func UnregisterClient(ctx context.Context, clientID string, authType ...string) error { + return stputil.UnregisterClient(ctx, clientID, authType...) } -// Version Sa-Token-Go version | Sa-Token-Go版本 -const Version = core.Version +// GetClient gets OAuth2 client information | 获取OAuth2客户端信息 +func GetClient(ctx context.Context, clientID string, authType ...string) (*oauth2.Client, error) { + return stputil.GetClient(ctx, clientID, authType...) +} + +// GenerateAuthorizationCode creates an authorization code | 创建授权码 +func GenerateAuthorizationCode(ctx context.Context, clientID, loginID, redirectURI string, scope []string, authType ...string) (*oauth2.AuthorizationCode, error) { + return stputil.GenerateAuthorizationCode(ctx, clientID, loginID, redirectURI, scope, authType...) +} + +// ExchangeCodeForToken exchanges authorization code for token | 使用授权码换取令牌 +func ExchangeCodeForToken(ctx context.Context, code, clientID, clientSecret, redirectURI string, authType ...string) (*oauth2.AccessToken, error) { + return stputil.ExchangeCodeForToken(ctx, code, clientID, clientSecret, redirectURI, authType...) +} + +// ValidateAccessToken verifies OAuth2 access token | 验证OAuth2访问令牌 +func ValidateAccessToken(ctx context.Context, accessToken string, authType ...string) bool { + return stputil.ValidateAccessToken(ctx, accessToken, authType...) +} + +// ValidateAccessTokenAndGetInfo verifies OAuth2 access token and get info | 验证OAuth2访问令牌并获取信息 +func ValidateAccessTokenAndGetInfo(ctx context.Context, accessToken string, authType ...string) (*oauth2.AccessToken, error) { + return stputil.ValidateAccessTokenAndGetInfo(ctx, accessToken, authType...) +} + +// OAuth2RefreshAccessToken Refreshes access token using refresh token | 使用刷新令牌刷新访问令牌(OAuth2) +func OAuth2RefreshAccessToken(ctx context.Context, clientID, refreshToken, clientSecret string, authType ...string) (*oauth2.AccessToken, error) { + return stputil.OAuth2RefreshAccessToken(ctx, clientID, refreshToken, clientSecret, authType...) +} + +// RevokeToken Revokes access token and its refresh token | 撤销访问令牌及其刷新令牌 +func RevokeToken(ctx context.Context, accessToken string, authType ...string) error { + return stputil.RevokeToken(ctx, accessToken, authType...) +} + +// OAuth2Token Unified token endpoint that dispatches to appropriate handler based on grant type | 统一的令牌端点 +func OAuth2Token(ctx context.Context, req *oauth2.TokenRequest, validateUser oauth2.UserValidator, authType ...string) (*oauth2.AccessToken, error) { + return stputil.OAuth2Token(ctx, req, validateUser, authType...) +} + +// OAuth2ClientCredentialsToken Gets access token using client credentials grant | 使用客户端凭证模式获取访问令牌 +func OAuth2ClientCredentialsToken(ctx context.Context, clientID, clientSecret string, scopes []string, authType ...string) (*oauth2.AccessToken, error) { + return stputil.OAuth2ClientCredentialsToken(ctx, clientID, clientSecret, scopes, authType...) +} + +// OAuth2PasswordGrantToken Gets access token using resource owner password credentials grant | 使用密码模式获取访问令牌 +func OAuth2PasswordGrantToken(ctx context.Context, clientID, clientSecret, username, password string, scopes []string, validateUser oauth2.UserValidator, authType ...string) (*oauth2.AccessToken, error) { + return stputil.OAuth2PasswordGrantToken(ctx, clientID, clientSecret, username, password, scopes, validateUser, authType...) +} + +// ============ OAuth2 Grant Type Constants | OAuth2授权类型常量 ============ + +const ( + // GrantTypeAuthorizationCode Authorization code grant type | 授权码模式 + GrantTypeAuthorizationCode = oauth2.GrantTypeAuthorizationCode + // GrantTypeClientCredentials Client credentials grant type | 客户端凭证模式 + GrantTypeClientCredentials = oauth2.GrantTypeClientCredentials + // GrantTypePassword Password grant type | 密码模式 + GrantTypePassword = oauth2.GrantTypePassword + // GrantTypeRefreshToken Refresh token grant type | 刷新令牌模式 + GrantTypeRefreshToken = oauth2.GrantTypeRefreshToken +) + +// ============ Public Getters | 公共获取器 ============ + +// GetConfig returns the manager-example configuration | 获取 Manager 当前使用的配置 +func GetConfig(ctx context.Context, authType ...string) *config.Config { + return stputil.GetConfig(ctx, authType...) +} + +// GetStorage returns the storage adapter | 获取 Manager 使用的存储适配器 +func GetStorage(ctx context.Context, authType ...string) adapter.Storage { + return stputil.GetStorage(ctx, authType...) +} + +// GetCodec returns the codec (serializer) | 获取 Manager 使用的编解码器 +func GetCodec(ctx context.Context, authType ...string) adapter.Codec { + return stputil.GetCodec(ctx, authType...) +} + +// GetLog returns the logger adapter | 获取 Manager 使用的日志适配器 +func GetLog(ctx context.Context, authType ...string) adapter.Log { + return stputil.GetLog(ctx, authType...) +} + +// GetPool returns the goroutine pool | 获取 Manager 使用的协程池 +func GetPool(ctx context.Context, authType ...string) adapter.Pool { + return stputil.GetPool(ctx, authType...) +} + +// GetGenerator returns the token generator | 获取 Token 生成器 +func GetGenerator(ctx context.Context, authType ...string) adapter.Generator { + return stputil.GetGenerator(ctx, authType...) +} + +// GetNonceManager returns the nonce manager-example | 获取随机串管理器 +func GetNonceManager(ctx context.Context, authType ...string) *security.NonceManager { + return stputil.GetNonceManager(ctx, authType...) +} + +// GetRefreshManager returns the refresh token manager-example | 获取刷新令牌管理器 +func GetRefreshManager(ctx context.Context, authType ...string) *security.RefreshTokenManager { + return stputil.GetRefreshManager(ctx, authType...) +} + +// GetEventManager returns the event manager-example | 获取事件管理器 +func GetEventManager(ctx context.Context, authType ...string) *listener.Manager { + return stputil.GetEventManager(ctx, authType...) +} + +// GetOAuth2Server Gets OAuth2 server instance | 获取OAuth2服务器实例 +func GetOAuth2Server(ctx context.Context, authType ...string) *oauth2.OAuth2Server { + return stputil.GetOAuth2Server(ctx, authType...) +} + +// ============ Event Management | 事件管理 ============ + +// RegisterFunc registers a function as an event listener | 注册函数作为事件监听器 +func RegisterFunc(event listener.Event, fn func(*listener.EventData), authType ...string) { + stputil.RegisterFunc(event, fn, authType...) +} + +// Register registers an event listener | 注册事件监听器 +func Register(event listener.Event, l listener.Listener, authType ...string) string { + return stputil.Register(event, l, authType...) +} + +// RegisterWithConfig registers an event listener with config | 注册带配置的事件监听器 +func RegisterWithConfig(event listener.Event, l listener.Listener, cfg listener.ListenerConfig, authType ...string) string { + return stputil.RegisterWithConfig(event, l, cfg, authType...) +} + +// Unregister removes an event listener by ID | 根据ID移除事件监听器 +func Unregister(id string, authType ...string) bool { + return stputil.Unregister(id, authType...) +} + +// TriggerEvent manually triggers an event | 手动触发事件 +func TriggerEvent(data *listener.EventData, authType ...string) { + stputil.TriggerEvent(data, authType...) +} + +// TriggerEventAsync triggers an event asynchronously and returns immediately | 异步触发事件并立即返回 +func TriggerEventAsync(data *listener.EventData, authType ...string) { + stputil.TriggerEventAsync(data, authType...) +} + +// TriggerEventSync triggers an event synchronously and waits for all listeners | 同步触发事件并等待所有监听器完成 +func TriggerEventSync(data *listener.EventData, authType ...string) { + stputil.TriggerEventSync(data, authType...) +} + +// WaitEvents waits for all async event listeners to complete | 等待所有异步事件监听器完成 +func WaitEvents(authType ...string) { + stputil.WaitEvents(authType...) +} + +// ClearEventListeners removes all listeners for a specific event | 清除指定事件的所有监听器 +func ClearEventListeners(event listener.Event, authType ...string) { + stputil.ClearEventListeners(event, authType...) +} + +// ClearAllEventListeners removes all listeners | 清除所有事件监听器 +func ClearAllEventListeners(authType ...string) { + stputil.ClearAllEventListeners(authType...) +} + +// CountEventListeners returns the number of listeners for a specific event | 获取指定事件监听器数量 +func CountEventListeners(event listener.Event, authType ...string) int { + return stputil.CountEventListeners(event, authType...) +} + +// CountAllListeners returns the total number of registered listeners | 获取已注册监听器总数 +func CountAllListeners(authType ...string) int { + return stputil.CountAllListeners(authType...) +} + +// GetEventListenerIDs returns all listener IDs for a specific event | 获取指定事件的所有监听器ID +func GetEventListenerIDs(event listener.Event, authType ...string) []string { + return stputil.GetEventListenerIDs(event, authType...) +} + +// GetAllRegisteredEvents returns all events that have registered listeners | 获取所有已注册事件 +func GetAllRegisteredEvents(authType ...string) []listener.Event { + return stputil.GetAllRegisteredEvents(authType...) +} + +// HasEventListeners checks if there are any listeners for a specific event | 检查指定事件是否有监听器 +func HasEventListeners(event listener.Event, authType ...string) bool { + return stputil.HasEventListeners(event, authType...) +} diff --git a/integrations/echo/go.mod b/integrations/echo/go.mod index ae7ee59..96b6828 100644 --- a/integrations/echo/go.mod +++ b/integrations/echo/go.mod @@ -1,17 +1,15 @@ module github.com/click33/sa-token-go/integrations/echo -go 1.23.0 - -toolchain go1.24.1 +go 1.25.0 require ( - github.com/click33/sa-token-go/core v0.1.5 - github.com/click33/sa-token-go/stputil v0.1.5 + github.com/click33/sa-token-go/core v0.1.7 + github.com/click33/sa-token-go/stputil v0.1.7 github.com/labstack/echo/v4 v4.11.4 ) require ( - github.com/golang-jwt/jwt/v5 v5.2.2 // indirect + github.com/golang-jwt/jwt/v5 v5.3.0 // indirect github.com/google/uuid v1.6.0 // indirect github.com/labstack/gommon v0.4.2 // indirect github.com/mattn/go-colorable v0.1.14 // indirect @@ -21,9 +19,7 @@ require ( github.com/valyala/fasttemplate v1.2.2 // indirect golang.org/x/crypto v0.41.0 // indirect golang.org/x/net v0.43.0 // indirect - golang.org/x/sync v0.16.0 // indirect + golang.org/x/sync v0.19.0 // indirect golang.org/x/sys v0.35.0 // indirect golang.org/x/text v0.28.0 // indirect ) - -replace github.com/click33/sa-token-go/core => ../../core diff --git a/integrations/echo/go.sum b/integrations/echo/go.sum index 7077012..b57347a 100644 --- a/integrations/echo/go.sum +++ b/integrations/echo/go.sum @@ -1,7 +1,7 @@ -github.com/click33/sa-token-go/stputil v0.1.4 h1:YvMEwPfAfTunQn+AePudO3Esp0CvLoc2o5kmg/uZf/c= -github.com/click33/sa-token-go/stputil v0.1.4/go.mod h1:NiFR1mUb43QRcybueAjyMeXxx92RngvpP4eLbglxEb0= +github.com/click33/sa-token-go/core v0.1.6 h1:ELOe0qSH1b3LRsQD3DIBg0e1VgYANKFg5H7z57Lkt/8= +github.com/click33/sa-token-go/stputil v0.1.6 h1:S+V64jQzppE9c1wXcmHppCRlrSsU2iTfvdPGlMbs2WI= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= -github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= +github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/labstack/echo/v4 v4.11.4 h1:vDZmA+qNeh1pd/cCkEicDMrjtrnMGQ1QFI9gWN1zGq8= github.com/labstack/gommon v0.4.2 h1:F8qTUNXgG1+6WQmqoUWnz8WiEU60mXVVw0P4ht1WRA0= @@ -14,7 +14,7 @@ github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6Kllzaw github.com/valyala/fasttemplate v1.2.2 h1:lxLXG0uE3Qnshl9QyaK6XJxMXlQZELvChBOCmQD0Loo= golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4= golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= -golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/integrations/echo/middleware.go b/integrations/echo/middleware.go new file mode 100644 index 0000000..2f6017d --- /dev/null +++ b/integrations/echo/middleware.go @@ -0,0 +1,316 @@ +package echo + +import ( + "errors" + "net/http" + + "github.com/click33/sa-token-go/core" + "github.com/click33/sa-token-go/core/manager" + + saContext "github.com/click33/sa-token-go/core/context" + "github.com/click33/sa-token-go/stputil" + "github.com/labstack/echo/v4" +) + +// LogicType permission/role logic type | 权限/角色判断逻辑 +type LogicType string + +const ( + SaTokenCtxKey = "saCtx" + + LogicOr LogicType = "OR" // Logical OR | 任一满足 + LogicAnd LogicType = "AND" // Logical AND | 全部满足 +) + +type AuthOption func(*AuthOptions) + +type AuthOptions struct { + AuthType string + LogicType LogicType + FailFunc func(c echo.Context, err error) error +} + +func defaultAuthOptions() *AuthOptions { + return &AuthOptions{LogicType: LogicAnd} // 默认 AND +} + +// WithAuthType sets auth type | 设置认证类型 +func WithAuthType(authType string) AuthOption { + return func(o *AuthOptions) { + o.AuthType = authType + } +} + +// WithLogicType sets LogicType option | 设置逻辑类型 +func WithLogicType(logicType LogicType) AuthOption { + return func(o *AuthOptions) { + o.LogicType = logicType + } +} + +// WithFailFunc sets auth failure callback | 设置认证失败回调 +func WithFailFunc(fn func(c echo.Context, err error) error) AuthOption { + return func(o *AuthOptions) { + o.FailFunc = fn + } +} + +// ========== Middlewares ========== + +// AuthMiddleware authentication middleware | 认证中间件 +func AuthMiddleware(opts ...AuthOption) echo.MiddlewareFunc { + options := defaultAuthOptions() + for _, opt := range opts { + opt(options) + } + + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + mgr, err := stputil.GetManager(options.AuthType) + if err != nil { + if options.FailFunc != nil { + return options.FailFunc(c, err) + } + return writeErrorResponse(c, err) + } + + // 获取 token | Get token + saCtx := getSaContext(c, mgr) + tokenValue := saCtx.GetTokenValue() + + // 检查登录 | Check login + err = mgr.CheckLogin(c.Request().Context(), tokenValue) + if err != nil { + if options.FailFunc != nil { + return options.FailFunc(c, err) + } + return writeErrorResponse(c, err) + } + + return next(c) + } + } +} + +// AuthWithStateMiddleware with state authentication middleware | 带状态返回的认证中间件 +func AuthWithStateMiddleware(opts ...AuthOption) echo.MiddlewareFunc { + options := defaultAuthOptions() + for _, opt := range opts { + opt(options) + } + + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + // 获取 Manager | Get Manager + mgr, err := stputil.GetManager(options.AuthType) + if err != nil { + if options.FailFunc != nil { + return options.FailFunc(c, err) + } + return writeErrorResponse(c, err) + } + + // 构建 Sa-Token 上下文 | Build Sa-Token context + saCtx := getSaContext(c, mgr) + tokenValue := saCtx.GetTokenValue() + + // 检查登录并返回状态 | Check login with state + _, err = mgr.CheckLoginWithState(c.Request().Context(), tokenValue) + + if err != nil { + // 用户自定义回调优先 + if options.FailFunc != nil { + return options.FailFunc(c, err) + } + return writeErrorResponse(c, err) + } + + return next(c) + } + } +} + +// PermissionMiddleware permission check middleware | 权限校验中间件 +func PermissionMiddleware( + permissions []string, + opts ...AuthOption, +) echo.MiddlewareFunc { + + options := defaultAuthOptions() + for _, opt := range opts { + opt(options) + } + + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + // No permission required | 无需权限直接放行 + if len(permissions) == 0 { + return next(c) + } + + // Get Manager | 获取 Manager + mgr, err := stputil.GetManager(options.AuthType) + if err != nil { + if options.FailFunc != nil { + return options.FailFunc(c, err) + } + return writeErrorResponse(c, err) + } + + // 构建 Sa-Token 上下文 | Build Sa-Token context + saCtx := getSaContext(c, mgr) + tokenValue := saCtx.GetTokenValue() + ctx := c.Request().Context() + + // Permission check | 权限校验 + var ok bool + if options.LogicType == LogicAnd { + ok = mgr.HasPermissionsAndByToken(ctx, tokenValue, permissions) + } else { + ok = mgr.HasPermissionsOrByToken(ctx, tokenValue, permissions) + } + + if !ok { + if options.FailFunc != nil { + return options.FailFunc(c, core.ErrPermissionDenied) + } + return writeErrorResponse(c, core.ErrPermissionDenied) + } + + return next(c) + } + } +} + +// RoleMiddleware role check middleware | 角色校验中间件 +func RoleMiddleware( + roles []string, + opts ...AuthOption, +) echo.MiddlewareFunc { + + options := defaultAuthOptions() + for _, opt := range opts { + opt(options) + } + + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + // No role required | 无需角色直接放行 + if len(roles) == 0 { + return next(c) + } + + // Get Manager | 获取 Manager + mgr, err := stputil.GetManager(options.AuthType) + if err != nil { + if options.FailFunc != nil { + return options.FailFunc(c, err) + } + return writeErrorResponse(c, err) + } + + // 构建 Sa-Token 上下文 | Build Sa-Token context + saCtx := getSaContext(c, mgr) + tokenValue := saCtx.GetTokenValue() + ctx := c.Request().Context() + + // Role check | 角色校验 + var ok bool + if options.LogicType == LogicAnd { + ok = mgr.HasRolesAndByToken(ctx, tokenValue, roles) + } else { + ok = mgr.HasRolesOrByToken(ctx, tokenValue, roles) + } + + if !ok { + if options.FailFunc != nil { + return options.FailFunc(c, core.ErrRoleDenied) + } + return writeErrorResponse(c, core.ErrRoleDenied) + } + + return next(c) + } + } +} + +// GetSaTokenContext gets Sa-Token context from Echo context | 获取 Sa-Token 上下文 +func GetSaTokenContext(c echo.Context) (*saContext.SaTokenContext, bool) { + v := c.Get(SaTokenCtxKey) + if v == nil { + return nil, false + } + + ctx, ok := v.(*saContext.SaTokenContext) + return ctx, ok +} + +func getSaContext(c echo.Context, mgr *manager.Manager) *saContext.SaTokenContext { + // Try get from context | 尝试从 ctx 取值 + if v := c.Get(SaTokenCtxKey); v != nil { + if saCtx, ok := v.(*saContext.SaTokenContext); ok { + return saCtx + } + } + + // Create new context | 创建并缓存 SaTokenContext + saCtx := saContext.NewContext(NewEchoContext(c), mgr) + c.Set(SaTokenCtxKey, saCtx) + + return saCtx +} + +// ============ Error Handling Helpers | 错误处理辅助函数 ============ + +// writeErrorResponse writes a standardized error response | 写入标准化的错误响应 +func writeErrorResponse(c echo.Context, err error) error { + var saErr *core.SaTokenError + var code int + var message string + var httpStatus int + + // Check if it's a SaTokenError | 检查是否为SaTokenError + if errors.As(err, &saErr) { + code = saErr.Code + message = saErr.Message + httpStatus = getHTTPStatusFromCode(code) + } else { + // Handle standard errors | 处理标准错误 + code = core.CodeServerError + message = err.Error() + httpStatus = http.StatusInternalServerError + } + + return c.JSON(httpStatus, map[string]interface{}{ + "code": code, + "message": message, + "data": err.Error(), + }) +} + +// writeSuccessResponse writes a standardized success response | 写入标准化的成功响应 +func writeSuccessResponse(c echo.Context, data interface{}) error { + return c.JSON(http.StatusOK, map[string]interface{}{ + "code": core.CodeSuccess, + "message": "success", + "data": data, + }) +} + +// getHTTPStatusFromCode converts Sa-Token error code to HTTP status | 将Sa-Token错误码转换为HTTP状态码 +func getHTTPStatusFromCode(code int) int { + switch code { + case core.CodeNotLogin: + return http.StatusUnauthorized + case core.CodePermissionDenied: + return http.StatusForbidden + case core.CodeBadRequest: + return http.StatusBadRequest + case core.CodeNotFound: + return http.StatusNotFound + case core.CodeServerError: + return http.StatusInternalServerError + default: + return http.StatusInternalServerError + } +} diff --git a/integrations/echo/plugin.go b/integrations/echo/plugin.go deleted file mode 100644 index 50299c8..0000000 --- a/integrations/echo/plugin.go +++ /dev/null @@ -1,172 +0,0 @@ -package echo - -import ( - "errors" - "net/http" - - "github.com/click33/sa-token-go/core" - "github.com/labstack/echo/v4" -) - -// Plugin Echo plugin for Sa-Token | Echo插件 -type Plugin struct { - manager *core.Manager -} - -// NewPlugin creates an Echo plugin | 创建Echo插件 -func NewPlugin(manager *core.Manager) *Plugin { - return &Plugin{ - manager: manager, - } -} - -// AuthMiddleware authentication middleware | 认证中间件 -func (p *Plugin) AuthMiddleware() echo.MiddlewareFunc { - return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { - ctx := NewEchoContext(c) - saCtx := core.NewContext(ctx, p.manager) - - if err := saCtx.CheckLogin(); err != nil { - return writeErrorResponse(c, err) - } - - c.Set("satoken", saCtx) - return next(c) - } - } -} - -// PermissionRequired permission validation middleware | 权限验证中间件 -func (p *Plugin) PermissionRequired(permission string) echo.MiddlewareFunc { - return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { - ctx := NewEchoContext(c) - saCtx := core.NewContext(ctx, p.manager) - - if err := saCtx.CheckLogin(); err != nil { - return writeErrorResponse(c, err) - } - - if !saCtx.HasPermission(permission) { - return writeErrorResponse(c, core.NewPermissionDeniedError(permission)) - } - - c.Set("satoken", saCtx) - return next(c) - } - } -} - -// RoleRequired role validation middleware | 角色验证中间件 -func (p *Plugin) RoleRequired(role string) echo.MiddlewareFunc { - return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { - ctx := NewEchoContext(c) - saCtx := core.NewContext(ctx, p.manager) - - if err := saCtx.CheckLogin(); err != nil { - return writeErrorResponse(c, err) - } - - if !saCtx.HasRole(role) { - return writeErrorResponse(c, core.NewRoleDeniedError(role)) - } - - c.Set("satoken", saCtx) - return next(c) - } - } -} - -// LoginHandler 登录处理器 -func (p *Plugin) LoginHandler(c echo.Context) error { - var req struct { - Username string `json:"username"` - Password string `json:"password"` - Device string `json:"device"` - } - - if err := c.Bind(&req); err != nil { - return writeErrorResponse(c, core.NewError(core.CodeBadRequest, "invalid request parameters", err)) - } - - device := req.Device - if device == "" { - device = "default" - } - - token, err := p.manager.Login(req.Username, device) - if err != nil { - return writeErrorResponse(c, core.NewError(core.CodeServerError, "login failed", err)) - } - - return writeSuccessResponse(c, map[string]interface{}{ - "token": token, - }) -} - -// GetSaToken 从Echo上下文获取Sa-Token上下文 -func GetSaToken(c echo.Context) (*core.SaTokenContext, bool) { - satoken := c.Get("satoken") - if satoken == nil { - return nil, false - } - ctx, ok := satoken.(*core.SaTokenContext) - return ctx, ok -} - -// ============ Error Handling Helpers | 错误处理辅助函数 ============ - -// writeErrorResponse writes a standardized error response | 写入标准化的错误响应 -func writeErrorResponse(c echo.Context, err error) error { - var saErr *core.SaTokenError - var code int - var message string - var httpStatus int - - // Check if it's a SaTokenError | 检查是否为SaTokenError - if errors.As(err, &saErr) { - code = saErr.Code - message = saErr.Message - httpStatus = getHTTPStatusFromCode(code) - } else { - // Handle standard errors | 处理标准错误 - code = core.CodeServerError - message = err.Error() - httpStatus = http.StatusInternalServerError - } - - return c.JSON(httpStatus, map[string]interface{}{ - "code": code, - "message": message, - "error": err.Error(), - }) -} - -// writeSuccessResponse writes a standardized success response | 写入标准化的成功响应 -func writeSuccessResponse(c echo.Context, data interface{}) error { - return c.JSON(http.StatusOK, map[string]interface{}{ - "code": core.CodeSuccess, - "message": "success", - "data": data, - }) -} - -// getHTTPStatusFromCode converts Sa-Token error code to HTTP status | 将Sa-Token错误码转换为HTTP状态码 -func getHTTPStatusFromCode(code int) int { - switch code { - case core.CodeNotLogin: - return http.StatusUnauthorized - case core.CodePermissionDenied: - return http.StatusForbidden - case core.CodeBadRequest: - return http.StatusBadRequest - case core.CodeNotFound: - return http.StatusNotFound - case core.CodeServerError: - return http.StatusInternalServerError - default: - return http.StatusInternalServerError - } -} diff --git a/integrations/fiber/annotation.go b/integrations/fiber/annotation.go index 0182710..bc6c03d 100644 --- a/integrations/fiber/annotation.go +++ b/integrations/fiber/annotation.go @@ -1,6 +1,8 @@ +// @Author daixk 2025/12/28 package fiber import ( + "context" "strings" "github.com/click33/sa-token-go/core" @@ -10,17 +12,19 @@ import ( // Annotation annotation structure | 注解结构体 type Annotation struct { - CheckLogin bool `json:"checkLogin"` - CheckRole []string `json:"checkRole"` - CheckPermission []string `json:"checkPermission"` - CheckDisable bool `json:"checkDisable"` - Ignore bool `json:"ignore"` + AuthType string `json:"authType"` // Optional: specify auth type | 可选:指定认证类型 + CheckLogin bool `json:"checkLogin"` // Check login | 检查登录 + CheckRole []string `json:"checkRole"` // Check roles | 检查角色 + CheckPermission []string `json:"checkPermission"` // Check permissions | 检查权限 + CheckDisable bool `json:"checkDisable"` // Check disable status | 检查封禁状态 + Ignore bool `json:"ignore"` // Ignore authentication | 忽略认证 + LogicType LogicType `json:"logicType"` // OR or AND logic (default: OR) | OR 或 AND 逻辑(默认: OR) } // GetHandler gets handler with annotations | 获取带注解的处理器 func GetHandler(handler fiber.Handler, annotations ...*Annotation) fiber.Handler { return func(c *fiber.Ctx) error { - // Check if authentication should be ignored | 检查是否忽略认证 + // Ignore authentication | 忽略认证直接放行 if len(annotations) > 0 && annotations[0].Ignore { if handler != nil { return handler(c) @@ -28,57 +32,81 @@ func GetHandler(handler fiber.Handler, annotations ...*Annotation) fiber.Handler return c.Next() } - // Get token from context using configured TokenName | 从上下文获取Token(使用配置的TokenName) - ctx := NewFiberContext(c) - saCtx := core.NewContext(ctx, stputil.GetManager()) + // Check if any authentication is needed | 检查是否需要任何认证 + ann := &Annotation{} + if len(annotations) > 0 { + ann = annotations[0] + } + + // No authentication required | 无需任何认证 + needAuth := ann.CheckLogin || ann.CheckDisable || len(ann.CheckPermission) > 0 || len(ann.CheckRole) > 0 + if !needAuth { + if handler != nil { + return handler(c) + } + return c.Next() + } + + ctx := c.UserContext() + + // Get manager-example | 获取 Manager + mgr, err := stputil.GetManager(ann.AuthType) + if err != nil { + return writeErrorResponse(c, err) + } + + // Get SaTokenContext (reuse cached context) | 获取 SaTokenContext(复用缓存上下文) + saCtx := getSaContext(c, mgr) token := saCtx.GetTokenValue() + if token == "" { return writeErrorResponse(c, core.NewNotLoginError()) } // Check login | 检查登录 - if !stputil.IsLogin(token) { - return writeErrorResponse(c, core.NewNotLoginError()) + if err := mgr.CheckLogin(ctx, token); err != nil { + return writeErrorResponse(c, err) } - // Get login ID | 获取登录ID - loginID, err := stputil.GetLoginID(token) - if err != nil { - return writeErrorResponse(c, err) + // Get loginID for further checks | 获取 loginID 用于后续检查 + var loginID string + if ann.CheckDisable || len(ann.CheckPermission) > 0 || len(ann.CheckRole) > 0 { + loginID, err = mgr.GetLoginIDNotCheck(ctx, token) + if err != nil { + return writeErrorResponse(c, err) + } } // Check if account is disabled | 检查是否被封禁 - if len(annotations) > 0 && annotations[0].CheckDisable { - if stputil.IsDisable(loginID) { + if ann.CheckDisable { + if mgr.IsDisable(ctx, loginID) { return writeErrorResponse(c, core.NewAccountDisabledError(loginID)) } } // Check permission | 检查权限 - if len(annotations) > 0 && len(annotations[0].CheckPermission) > 0 { - hasPermission := false - for _, perm := range annotations[0].CheckPermission { - if stputil.HasPermission(loginID, strings.TrimSpace(perm)) { - hasPermission = true - break - } + if len(ann.CheckPermission) > 0 { + var ok bool + if ann.LogicType == LogicAnd { + ok = mgr.HasPermissionsAnd(ctx, loginID, ann.CheckPermission) + } else { + ok = mgr.HasPermissionsOr(ctx, loginID, ann.CheckPermission) } - if !hasPermission { - return writeErrorResponse(c, core.NewPermissionDeniedError(strings.Join(annotations[0].CheckPermission, ","))) + if !ok { + return writeErrorResponse(c, core.NewPermissionDeniedError(strings.Join(ann.CheckPermission, ","))) } } // Check role | 检查角色 - if len(annotations) > 0 && len(annotations[0].CheckRole) > 0 { - hasRole := false - for _, role := range annotations[0].CheckRole { - if stputil.HasRole(loginID, strings.TrimSpace(role)) { - hasRole = true - break - } + if len(ann.CheckRole) > 0 { + var ok bool + if ann.LogicType == LogicAnd { + ok = mgr.HasRolesAnd(ctx, loginID, ann.CheckRole) + } else { + ok = mgr.HasRolesOr(ctx, loginID, ann.CheckRole) } - if !hasRole { - return writeErrorResponse(c, core.NewRoleDeniedError(strings.Join(annotations[0].CheckRole, ","))) + if !ok { + return writeErrorResponse(c, core.NewRoleDeniedError(strings.Join(ann.CheckRole, ","))) } } @@ -91,8 +119,12 @@ func GetHandler(handler fiber.Handler, annotations ...*Annotation) fiber.Handler } // CheckLoginMiddleware decorator for login checking | 检查登录装饰器 -func CheckLoginMiddleware() fiber.Handler { - return GetHandler(nil, &Annotation{CheckLogin: true}) +func CheckLoginMiddleware(authType ...string) fiber.Handler { + ann := &Annotation{CheckLogin: true} + if len(authType) > 0 { + ann.AuthType = authType[0] + } + return GetHandler(nil, ann) } // CheckRoleMiddleware decorator for role checking | 检查角色装饰器 @@ -100,17 +132,136 @@ func CheckRoleMiddleware(roles ...string) fiber.Handler { return GetHandler(nil, &Annotation{CheckRole: roles}) } +// CheckRoleMiddlewareWithAuthType decorator for role checking with auth type | 检查角色装饰器(带认证类型) +func CheckRoleMiddlewareWithAuthType(authType string, roles ...string) fiber.Handler { + return GetHandler(nil, &Annotation{CheckRole: roles, AuthType: authType}) +} + // CheckPermissionMiddleware decorator for permission checking | 检查权限装饰器 func CheckPermissionMiddleware(perms ...string) fiber.Handler { return GetHandler(nil, &Annotation{CheckPermission: perms}) } +// CheckPermissionMiddlewareWithAuthType decorator for permission checking with auth type | 检查权限装饰器(带认证类型) +func CheckPermissionMiddlewareWithAuthType(authType string, perms ...string) fiber.Handler { + return GetHandler(nil, &Annotation{CheckPermission: perms, AuthType: authType}) +} + // CheckDisableMiddleware decorator for checking if account is disabled | 检查是否被封禁装饰器 -func CheckDisableMiddleware() fiber.Handler { - return GetHandler(nil, &Annotation{CheckDisable: true}) +func CheckDisableMiddleware(authType ...string) fiber.Handler { + ann := &Annotation{CheckDisable: true} + if len(authType) > 0 { + ann.AuthType = authType[0] + } + return GetHandler(nil, ann) } // IgnoreMiddleware decorator to ignore authentication | 忽略认证装饰器 func IgnoreMiddleware() fiber.Handler { return GetHandler(nil, &Annotation{Ignore: true}) } + +// ============ Combined Middleware | 组合中间件 ============ + +// CheckLoginAndRoleMiddleware checks login and role | 检查登录和角色 +func CheckLoginAndRoleMiddleware(roles ...string) fiber.Handler { + return GetHandler(nil, &Annotation{CheckLogin: true, CheckRole: roles}) +} + +// CheckLoginAndPermissionMiddleware checks login and permission | 检查登录和权限 +func CheckLoginAndPermissionMiddleware(perms ...string) fiber.Handler { + return GetHandler(nil, &Annotation{CheckLogin: true, CheckPermission: perms}) +} + +// CheckAllMiddleware checks login, role, permission and disable status | 全面检查 +func CheckAllMiddleware(roles []string, perms []string) fiber.Handler { + return GetHandler(nil, &Annotation{ + CheckLogin: true, + CheckRole: roles, + CheckPermission: perms, + CheckDisable: true, + }) +} + +// ============ Route Group Helper | 路由组辅助函数 ============ + +// AuthGroup creates a route group with authentication | 创建带认证的路由组 +func AuthGroup(group fiber.Router, authType ...string) fiber.Router { + group.Use(CheckLoginMiddleware(authType...)) + return group +} + +// RoleGroup creates a route group with role checking | 创建带角色检查的路由组 +func RoleGroup(group fiber.Router, roles ...string) fiber.Router { + group.Use(CheckLoginAndRoleMiddleware(roles...)) + return group +} + +// PermissionGroup creates a route group with permission checking | 创建带权限检查的路由组 +func PermissionGroup(group fiber.Router, perms ...string) fiber.Router { + group.Use(CheckLoginAndPermissionMiddleware(perms...)) + return group +} + +// ============ Context Helper | 上下文辅助函数 ============ + +// GetLoginIDFromRequest gets login ID from request context | 从请求上下文获取登录 ID +func GetLoginIDFromRequest(c *fiber.Ctx, authType ...string) (string, error) { + var at string + if len(authType) > 0 { + at = authType[0] + } + + mgr, err := stputil.GetManager(at) + if err != nil { + return "", err + } + + saCtx := getSaContext(c, mgr) + token := saCtx.GetTokenValue() + if token == "" { + return "", core.ErrNotLogin + } + return mgr.GetLoginID(c.UserContext(), token) +} + +// IsLoginFromRequest checks if user is logged in from request | 从请求检查用户是否已登录 +func IsLoginFromRequest(c *fiber.Ctx, authType ...string) bool { + var at string + if len(authType) > 0 { + at = authType[0] + } + + mgr, err := stputil.GetManager(at) + if err != nil { + return false + } + + saCtx := getSaContext(c, mgr) + token := saCtx.GetTokenValue() + if token == "" { + return false + } + return mgr.IsLogin(c.UserContext(), token) +} + +// GetTokenFromRequest gets token from request (exported) | 从请求获取 Token(导出) +func GetTokenFromRequest(c *fiber.Ctx, authType ...string) string { + var at string + if len(authType) > 0 { + at = authType[0] + } + + mgr, err := stputil.GetManager(at) + if err != nil { + return "" + } + + saCtx := getSaContext(c, mgr) + return saCtx.GetTokenValue() +} + +// WithContext creates a new context with sa-token context | 创建带 sa-token 上下文的新上下文 +func WithContext(c *fiber.Ctx, authType ...string) context.Context { + return c.UserContext() +} diff --git a/integrations/fiber/export.go b/integrations/fiber/export.go index 038d0d3..21266ed 100644 --- a/integrations/fiber/export.go +++ b/integrations/fiber/export.go @@ -1,364 +1,950 @@ package fiber import ( + "context" "time" + "github.com/click33/sa-token-go/codec/json" + "github.com/click33/sa-token-go/codec/msgpack" "github.com/click33/sa-token-go/core" + "github.com/click33/sa-token-go/core/adapter" + "github.com/click33/sa-token-go/core/builder" + "github.com/click33/sa-token-go/core/config" + "github.com/click33/sa-token-go/core/listener" + "github.com/click33/sa-token-go/core/manager" + "github.com/click33/sa-token-go/core/oauth2" + "github.com/click33/sa-token-go/core/security" + "github.com/click33/sa-token-go/core/session" + "github.com/click33/sa-token-go/generator/sgenerator" + "github.com/click33/sa-token-go/log/nop" + "github.com/click33/sa-token-go/log/slog" + "github.com/click33/sa-token-go/pool/ants" + "github.com/click33/sa-token-go/storage/memory" + "github.com/click33/sa-token-go/storage/redis" "github.com/click33/sa-token-go/stputil" ) -// ============ Re-export core types | 重新导出核心类型 ============ +// ============ Type Aliases | 类型别名 ============ -// Configuration related types | 配置相关类型 type ( - Config = core.Config - CookieConfig = core.CookieConfig - TokenStyle = core.TokenStyle + // Config 配置 + Config = config.Config + // Manager 管理器 + Manager = manager.Manager + // Session 会话 + Session = session.Session + // TokenInfo Token信息 + TokenInfo = manager.TokenInfo + // DisableInfo 封禁信息 + DisableInfo = manager.DisableInfo + // Builder 构建器 + Builder = builder.Builder + // SaTokenError 错误类型 + SaTokenError = core.SaTokenError + // Event 事件类型 + Event = listener.Event + // EventData 事件数据 + EventData = listener.EventData + // Listener 事件监听器 + Listener = listener.Listener + // ListenerConfig 监听器配置 + ListenerConfig = listener.ListenerConfig + // RefreshTokenInfo 刷新令牌信息 + RefreshTokenInfo = security.RefreshTokenInfo + // AccessTokenInfo 访问令牌信息 + AccessTokenInfo = security.AccessTokenInfo + // OAuth2Client OAuth2客户端 + OAuth2Client = oauth2.Client + // OAuth2AccessToken OAuth2访问令牌 + OAuth2AccessToken = oauth2.AccessToken + // AuthorizationCode 授权码 + AuthorizationCode = oauth2.AuthorizationCode + // OAuth2TokenRequest OAuth2令牌请求 + OAuth2TokenRequest = oauth2.TokenRequest + // OAuth2GrantType OAuth2授权类型 + OAuth2GrantType = oauth2.GrantType + // OAuth2UserValidator OAuth2用户验证器 + OAuth2UserValidator = oauth2.UserValidator + // Storage 存储接口 + Storage = adapter.Storage + // Codec 编解码接口 + Codec = adapter.Codec + // Log 日志接口 + Log = adapter.Log + // Pool 协程池接口 + Pool = adapter.Pool + // Generator 生成器接口 + Generator = adapter.Generator + + // ============ Codec Types | 编解码器类型 ============ + + // JSONSerializer JSON编解码器 + JSONSerializer = json.JSONSerializer + // MsgPackSerializer MsgPack编解码器 + MsgPackSerializer = msgpack.MsgPackSerializer + + // ============ Storage Types | 存储类型 ============ + + // MemoryStorage 内存存储 + MemoryStorage = memory.Storage + // RedisStorage Redis存储 + RedisStorage = redis.Storage + // RedisConfig Redis配置 + RedisConfig = redis.Config + // RedisBuilder Redis构建器 + RedisBuilder = redis.Builder + + // ============ Logger Types | 日志类型 ============ + + // SlogLogger 标准日志实现 + SlogLogger = slog.Logger + // SlogLoggerConfig 标准日志配置 + SlogLoggerConfig = slog.LoggerConfig + // SlogLogLevel 日志级别 + SlogLogLevel = slog.LogLevel + // NopLogger 空日志实现 + NopLogger = nop.NopLogger + + // ============ Generator Types | 生成器类型 ============ + + // TokenGenerator Token生成器 + TokenGenerator = sgenerator.Generator + // TokenStyle Token风格 + TokenStyle = adapter.TokenStyle + + // ============ Pool Types | 协程池类型 ============ + + // RenewPoolManager 续期池管理器 + RenewPoolManager = ants.RenewPoolManager + // RenewPoolConfig 续期池配置 + RenewPoolConfig = ants.RenewPoolConfig ) -// Token style constants | Token风格常量 +// ============ Error Codes | 错误码 ============ + const ( - TokenStyleUUID = core.TokenStyleUUID - TokenStyleSimple = core.TokenStyleSimple - TokenStyleRandom32 = core.TokenStyleRandom32 - TokenStyleRandom64 = core.TokenStyleRandom64 - TokenStyleRandom128 = core.TokenStyleRandom128 - TokenStyleJWT = core.TokenStyleJWT - TokenStyleHash = core.TokenStyleHash - TokenStyleTimestamp = core.TokenStyleTimestamp - TokenStyleTik = core.TokenStyleTik + CodeSuccess = core.CodeSuccess + CodeBadRequest = core.CodeBadRequest + CodeNotLogin = core.CodeNotLogin + CodePermissionDenied = core.CodePermissionDenied + CodeNotFound = core.CodeNotFound + CodeServerError = core.CodeServerError + CodeTokenInvalid = core.CodeTokenInvalid + CodeTokenExpired = core.CodeTokenExpired + CodeAccountDisabled = core.CodeAccountDisabled + CodeKickedOut = core.CodeKickedOut + CodeActiveTimeout = core.CodeActiveTimeout + CodeMaxLoginCount = core.CodeMaxLoginCount + CodeStorageError = core.CodeStorageError + CodeInvalidParameter = core.CodeInvalidParameter + CodeSessionError = core.CodeSessionError ) -// Core types | 核心类型 -type ( - Manager = core.Manager - TokenInfo = core.TokenInfo - Session = core.Session - TokenGenerator = core.TokenGenerator - SaTokenContext = core.SaTokenContext - Builder = core.Builder - NonceManager = core.NonceManager - RefreshTokenInfo = core.RefreshTokenInfo - RefreshTokenManager = core.RefreshTokenManager - OAuth2Server = core.OAuth2Server - OAuth2Client = core.OAuth2Client - OAuth2AccessToken = core.OAuth2AccessToken - OAuth2GrantType = core.OAuth2GrantType -) +// ============ Errors | 错误变量 ============ -// Adapter interfaces | 适配器接口 -type ( - Storage = core.Storage - RequestContext = core.RequestContext +var ( + // Authentication Errors | 认证错误 + ErrNotLogin = core.ErrNotLogin + ErrTokenInvalid = core.ErrTokenInvalid + ErrTokenExpired = core.ErrTokenExpired + ErrTokenKickout = core.ErrTokenKickout + ErrTokenReplaced = core.ErrTokenReplaced + ErrInvalidLoginID = core.ErrInvalidLoginID + ErrInvalidDevice = core.ErrInvalidDevice + ErrTokenNotFound = core.ErrTokenNotFound + + // Authorization Errors | 授权错误 + ErrPermissionDenied = core.ErrPermissionDenied + ErrRoleDenied = core.ErrRoleDenied + + // Account Errors | 账号错误 + ErrAccountDisabled = core.ErrAccountDisabled + ErrAccountNotFound = core.ErrAccountNotFound + ErrLoginLimitExceeded = core.ErrLoginLimitExceeded + + // Session Errors | 会话错误 + ErrSessionNotFound = core.ErrSessionNotFound + ErrActiveTimeout = core.ErrActiveTimeout + ErrSessionInvalidDataKey = core.ErrSessionInvalidDataKey + ErrSessionIDEmpty = core.ErrSessionIDEmpty + + // Security Errors | 安全错误 + ErrInvalidNonce = core.ErrInvalidNonce + ErrRefreshTokenExpired = core.ErrRefreshTokenExpired + ErrNonceInvalidRefreshToken = core.ErrNonceInvalidRefreshToken + ErrInvalidLoginIDEmpty = core.ErrInvalidLoginIDEmpty + + // OAuth2 Errors | OAuth2错误 + ErrClientOrClientIDEmpty = core.ErrClientOrClientIDEmpty + ErrClientNotFound = core.ErrClientNotFound + ErrUserIDEmpty = core.ErrUserIDEmpty + ErrInvalidRedirectURI = core.ErrInvalidRedirectURI + ErrInvalidClientCredentials = core.ErrInvalidClientCredentials + ErrInvalidAuthCode = core.ErrInvalidAuthCode + ErrAuthCodeUsed = core.ErrAuthCodeUsed + ErrAuthCodeExpired = core.ErrAuthCodeExpired + ErrClientMismatch = core.ErrClientMismatch + ErrRedirectURIMismatch = core.ErrRedirectURIMismatch + ErrInvalidAccessToken = core.ErrInvalidAccessToken + ErrInvalidRefreshToken = core.ErrInvalidRefreshToken + ErrInvalidScope = core.ErrInvalidScope + + // System Errors | 系统错误 + ErrStorageUnavailable = core.ErrStorageUnavailable + ErrSerializeFailed = core.ErrSerializeFailed + ErrDeserializeFailed = core.ErrDeserializeFailed + ErrTypeConvert = core.ErrTypeConvert ) -// Event related types | 事件相关类型 -type ( - EventListener = core.EventListener - EventManager = core.EventManager - EventData = core.EventData - Event = core.Event - ListenerFunc = core.ListenerFunc - ListenerConfig = core.ListenerConfig -) +// ============ Error Constructors | 错误构造函数 ============ -// Event constants | 事件常量 -const ( - EventLogin = core.EventLogin - EventLogout = core.EventLogout - EventKickout = core.EventKickout - EventDisable = core.EventDisable - EventUntie = core.EventUntie - EventRenew = core.EventRenew - EventCreateSession = core.EventCreateSession - EventDestroySession = core.EventDestroySession - EventPermissionCheck = core.EventPermissionCheck - EventRoleCheck = core.EventRoleCheck - EventAll = core.EventAll +var ( + NewError = core.NewError + NewErrorWithContext = core.NewErrorWithContext + NewNotLoginError = core.NewNotLoginError + NewPermissionDeniedError = core.NewPermissionDeniedError + NewRoleDeniedError = core.NewRoleDeniedError + NewAccountDisabledError = core.NewAccountDisabledError ) -// OAuth2 grant type constants | OAuth2授权类型常量 -const ( - GrantTypeAuthorizationCode = core.GrantTypeAuthorizationCode - GrantTypeRefreshToken = core.GrantTypeRefreshToken - GrantTypeClientCredentials = core.GrantTypeClientCredentials - GrantTypePassword = core.GrantTypePassword -) +// ============ Error Checking Helpers | 错误检查辅助函数 ============ -// Utility functions | 工具函数 var ( - RandomString = core.RandomString - IsEmpty = core.IsEmpty - IsNotEmpty = core.IsNotEmpty - DefaultString = core.DefaultString - ContainsString = core.ContainsString - RemoveString = core.RemoveString - UniqueStrings = core.UniqueStrings - MergeStrings = core.MergeStrings - MatchPattern = core.MatchPattern + IsNotLoginError = core.IsNotLoginError + IsPermissionDeniedError = core.IsPermissionDeniedError + IsAccountDisabledError = core.IsAccountDisabledError + IsTokenError = core.IsTokenError + GetErrorCode = core.GetErrorCode ) -// ============ Core constructor functions | 核心构造函数 ============ +// ============ Manager Management | Manager 管理 ============ + +// SetManager stores the manager-example in the global map using the specified autoType | 使用指定的 autoType 将管理器存储在全局 map 中 +func SetManager(mgr *manager.Manager) { + stputil.SetManager(mgr) +} + +// GetManager retrieves the manager-example from the global map using the specified autoType | 使用指定的 autoType 从全局 map 中获取管理器 +func GetManager(autoType ...string) (*manager.Manager, error) { + return stputil.GetManager(autoType...) +} -// DefaultConfig returns default configuration | 返回默认配置 -func DefaultConfig() *Config { - return core.DefaultConfig() +// DeleteManager delete the specific manager-example for the given autoType and releases resources | 删除指定的管理器并释放资源 +func DeleteManager(autoType ...string) error { + return stputil.DeleteManager(autoType...) } -// NewManager creates a new authentication manager | 创建新的认证管理器 -func NewManager(storage Storage, cfg *Config) *Manager { - return core.NewManager(storage, cfg) +// DeleteAllManager delete all managers in the global map and releases resources | 关闭所有管理器并释放资源 +func DeleteAllManager() { + stputil.DeleteAllManager() } -// NewContext creates a new Sa-Token context | 创建新的Sa-Token上下文 -func NewContext(ctx RequestContext, mgr *Manager) *SaTokenContext { - return core.NewContext(ctx, mgr) +// ============ Builder & Config | 构建器和配置 ============ + +// NewDefaultBuild creates a new default builder | 创建默认构建器 +func NewDefaultBuild() *builder.Builder { + return builder.NewBuilder() } -// NewSession creates a new session | 创建新的Session -func NewSession(id string, storage Storage, prefix string) *Session { - return core.NewSession(id, storage, prefix) +// NewDefaultConfig creates a new default config | 创建默认配置 +func NewDefaultConfig() *config.Config { + return config.DefaultConfig() } -// LoadSession loads an existing session | 加载已存在的Session -func LoadSession(id string, storage Storage, prefix string) (*Session, error) { - return core.LoadSession(id, storage, prefix) +// DefaultLoggerConfig returns the default logger config | 返回默认日志配置 +func DefaultLoggerConfig() *slog.LoggerConfig { + return slog.DefaultLoggerConfig() } -// NewTokenGenerator creates a new token generator | 创建新的Token生成器 -func NewTokenGenerator(cfg *Config) *TokenGenerator { - return core.NewTokenGenerator(cfg) +// DefaultRenewPoolConfig returns the default renew pool config | 返回默认续期池配置 +func DefaultRenewPoolConfig() *ants.RenewPoolConfig { + return ants.DefaultRenewPoolConfig() } -// NewEventManager creates a new event manager | 创建新的事件管理器 -func NewEventManager() *EventManager { - return core.NewEventManager() +// ============ Codec Constructors | 编解码器构造函数 ============ + +// NewJSONSerializer creates a new JSON serializer | 创建JSON序列化器 +func NewJSONSerializer() *json.JSONSerializer { + return json.NewJSONSerializer() } -// NewBuilder creates a new builder for fluent configuration | 创建新的Builder构建器(用于流式配置) -func NewBuilder() *Builder { - return core.NewBuilder() +// NewMsgPackSerializer creates a new MsgPack serializer | 创建MsgPack序列化器 +func NewMsgPackSerializer() *msgpack.MsgPackSerializer { + return msgpack.NewMsgPackSerializer() } -// NewNonceManager creates a new nonce manager | 创建新的Nonce管理器 -func NewNonceManager(storage Storage, prefix string, ttl ...int64) *NonceManager { - return core.NewNonceManager(storage, prefix, ttl...) +// ============ Storage Constructors | 存储构造函数 ============ + +// NewMemoryStorage creates a new memory storage | 创建内存存储 +func NewMemoryStorage() *memory.Storage { + return memory.NewStorage() } -// NewRefreshTokenManager creates a new refresh token manager | 创建新的刷新令牌管理器 -func NewRefreshTokenManager(storage Storage, prefix string, cfg *Config) *RefreshTokenManager { - return core.NewRefreshTokenManager(storage, prefix, cfg) +// NewMemoryStorageWithCleanupInterval creates a new memory storage with cleanup interval | 创建带清理间隔的内存存储 +func NewMemoryStorageWithCleanupInterval(interval time.Duration) *memory.Storage { + return memory.NewStorageWithCleanupInterval(interval) } -// NewOAuth2Server creates a new OAuth2 server | 创建新的OAuth2服务器 -func NewOAuth2Server(storage Storage, prefix string) *OAuth2Server { - return core.NewOAuth2Server(storage, prefix) +// NewRedisStorage creates a new Redis storage from URL | 通过URL创建Redis存储 +func NewRedisStorage(url string) (*redis.Storage, error) { + return redis.NewStorage(url) } -// ============ Global StpUtil functions | 全局StpUtil函数 ============ +// NewRedisStorageFromConfig creates a new Redis storage from config | 通过配置创建Redis存储 +func NewRedisStorageFromConfig(cfg *redis.Config) (*redis.Storage, error) { + return redis.NewStorageFromConfig(cfg) +} -// SetManager sets the global Manager (must be called first) | 设置全局Manager(必须先调用此方法) -func SetManager(mgr *Manager) { - stputil.SetManager(mgr) +// NewRedisBuilder creates a new Redis builder | 创建Redis构建器 +func NewRedisBuilder() *redis.Builder { + return redis.NewBuilder() +} + +// ============ Logger Constructors | 日志构造函数 ============ + +// NewSlogLogger creates a new slog logger with config | 使用配置创建标准日志器 +func NewSlogLogger(cfg *slog.LoggerConfig) (*slog.Logger, error) { + return slog.NewLoggerWithConfig(cfg) } -// GetManager gets the global Manager | 获取全局Manager -func GetManager() *Manager { - return stputil.GetManager() +// NewNopLogger creates a new no-op logger | 创建空日志器 +func NewNopLogger() *nop.NopLogger { + return nop.NewNopLogger() } +// ============ Generator Constructors | 生成器构造函数 ============ + +// NewTokenGenerator creates a new token generator | 创建Token生成器 +func NewTokenGenerator(timeout int64, tokenStyle adapter.TokenStyle, jwtSecretKey string) *sgenerator.Generator { + return sgenerator.NewGenerator(timeout, tokenStyle, jwtSecretKey) +} + +// NewDefaultTokenGenerator creates a new default token generator | 创建默认Token生成器 +func NewDefaultTokenGenerator() *sgenerator.Generator { + return sgenerator.NewDefaultGenerator() +} + +// ============ Pool Constructors | 协程池构造函数 ============ + +// NewRenewPoolManager creates a new renew pool manager-example with default config | 使用默认配置创建续期池管理器 +func NewRenewPoolManager() *ants.RenewPoolManager { + return ants.NewRenewPoolManagerWithDefaultConfig() +} + +// NewRenewPoolManagerWithConfig creates a new renew pool manager-example with config | 使用配置创建续期池管理器 +func NewRenewPoolManagerWithConfig(cfg *ants.RenewPoolConfig) (*ants.RenewPoolManager, error) { + return ants.NewRenewPoolManagerWithConfig(cfg) +} + +// ============ Token Style Constants | Token风格常量 ============ + +const ( + // TokenStyleUUID UUID style | UUID风格 + TokenStyleUUID = adapter.TokenStyleUUID + // TokenStyleSimple Simple random string | 简单随机字符串 + TokenStyleSimple = adapter.TokenStyleSimple + // TokenStyleRandom32 32-bit random string | 32位随机字符串 + TokenStyleRandom32 = adapter.TokenStyleRandom32 + // TokenStyleRandom64 64-bit random string | 64位随机字符串 + TokenStyleRandom64 = adapter.TokenStyleRandom64 + // TokenStyleRandom128 128-bit random string | 128位随机字符串 + TokenStyleRandom128 = adapter.TokenStyleRandom128 + // TokenStyleJWT JWT style | JWT风格 + TokenStyleJWT = adapter.TokenStyleJWT + // TokenStyleHash SHA256 hash-based style | SHA256哈希风格 + TokenStyleHash = adapter.TokenStyleHash + // TokenStyleTimestamp Timestamp-based style | 时间戳风格 + TokenStyleTimestamp = adapter.TokenStyleTimestamp + // TokenStyleTik Short ID style (like TikTok) | Tik风格短ID + TokenStyleTik = adapter.TokenStyleTik +) + +// ============ Log Level Constants | 日志级别常量 ============ + +const ( + // LogLevelDebug Debug level | 调试级别 + LogLevelDebug = adapter.LogLevelDebug + // LogLevelInfo Info level | 信息级别 + LogLevelInfo = adapter.LogLevelInfo + // LogLevelWarn Warn level | 警告级别 + LogLevelWarn = adapter.LogLevelWarn + // LogLevelError Error level | 错误级别 + LogLevelError = adapter.LogLevelError +) + // ============ Authentication | 登录认证 ============ // Login performs user login | 用户登录 -func Login(loginID interface{}, device ...string) (string, error) { - return stputil.Login(loginID, device...) +func Login(ctx context.Context, loginID interface{}, deviceOrAutoType ...string) (string, error) { + return stputil.Login(ctx, loginID, deviceOrAutoType...) } // LoginByToken performs login with specified token | 使用指定Token登录 -func LoginByToken(loginID interface{}, tokenValue string, device ...string) error { - return stputil.LoginByToken(loginID, tokenValue, device...) +func LoginByToken(ctx context.Context, tokenValue string, authType ...string) error { + return stputil.LoginByToken(ctx, tokenValue, authType...) } // Logout performs user logout | 用户登出 -func Logout(loginID interface{}, device ...string) error { - return stputil.Logout(loginID, device...) +func Logout(ctx context.Context, loginID interface{}, deviceOrAutoType ...string) error { + return stputil.Logout(ctx, loginID, deviceOrAutoType...) } // LogoutByToken performs logout by token | 根据Token登出 -func LogoutByToken(tokenValue string) error { - return stputil.LogoutByToken(tokenValue) +func LogoutByToken(ctx context.Context, tokenValue string, authType ...string) error { + return stputil.LogoutByToken(ctx, tokenValue, authType...) +} + +// Kickout kicks out a user session | 踢人下线 +func Kickout(ctx context.Context, loginID interface{}, deviceOrAutoType ...string) error { + return stputil.Kickout(ctx, loginID, deviceOrAutoType...) +} + +// KickoutByToken Kick user offline | 根据Token踢人下线 +func KickoutByToken(ctx context.Context, tokenValue string, authType ...string) error { + return stputil.KickoutByToken(ctx, tokenValue, authType...) +} + +// Replace user offline by login ID and device | 根据账号和设备顶人下线 +func Replace(ctx context.Context, loginID interface{}, deviceOrAutoType ...string) error { + return stputil.Replace(ctx, loginID, deviceOrAutoType...) +} + +// ReplaceByToken Replace user offline by token | 根据Token顶人下线 +func ReplaceByToken(ctx context.Context, tokenValue string, authType ...string) error { + return stputil.ReplaceByToken(ctx, tokenValue, authType...) } +// ============ Token Validation | Token验证 ============ + // IsLogin checks if the user is logged in | 检查用户是否已登录 -func IsLogin(tokenValue string) bool { - return stputil.IsLogin(tokenValue) +func IsLogin(ctx context.Context, tokenValue string, authType ...string) bool { + return stputil.IsLogin(ctx, tokenValue, authType...) } // CheckLogin checks login status (throws error if not logged in) | 检查登录状态(未登录抛出错误) -func CheckLogin(tokenValue string) error { - return stputil.CheckLogin(tokenValue) +func CheckLogin(ctx context.Context, tokenValue string, authType ...string) error { + return stputil.CheckLogin(ctx, tokenValue, authType...) +} + +// CheckLoginWithState checks the login status (returns error to determine the reason if not logged in) | 检查登录状态(未登录时根据错误确定原因) +func CheckLoginWithState(ctx context.Context, tokenValue string, authType ...string) (bool, error) { + return stputil.CheckLoginWithState(ctx, tokenValue, authType...) } // GetLoginID gets the login ID from token | 从Token获取登录ID -func GetLoginID(tokenValue string) (string, error) { - return stputil.GetLoginID(tokenValue) +func GetLoginID(ctx context.Context, tokenValue string, authType ...string) (string, error) { + return stputil.GetLoginID(ctx, tokenValue, authType...) } -// GetLoginIDNotCheck gets login ID without checking | 获取登录ID(不检查) -func GetLoginIDNotCheck(tokenValue string) (string, error) { - return stputil.GetLoginIDNotCheck(tokenValue) +// GetLoginIDNotCheck gets login ID without checking | 获取登录ID(不检查登录状态) +func GetLoginIDNotCheck(ctx context.Context, tokenValue string, authType ...string) (string, error) { + return stputil.GetLoginIDNotCheck(ctx, tokenValue, authType...) } // GetTokenValue gets the token value for a login ID | 获取登录ID对应的Token值 -func GetTokenValue(loginID interface{}, device ...string) (string, error) { - return stputil.GetTokenValue(loginID, device...) +func GetTokenValue(ctx context.Context, loginID interface{}, deviceOrAutoType ...string) (string, error) { + return stputil.GetTokenValue(ctx, loginID, deviceOrAutoType...) } // GetTokenInfo gets token information | 获取Token信息 -func GetTokenInfo(tokenValue string) (*TokenInfo, error) { - return stputil.GetTokenInfo(tokenValue) +func GetTokenInfo(ctx context.Context, tokenValue string, authType ...string) (*manager.TokenInfo, error) { + return stputil.GetTokenInfo(ctx, tokenValue, authType...) } -// ============ Kickout | 踢人下线 ============ +// ============ Account Disable | 账号封禁 ============ -// Kickout kicks out a user session | 踢人下线 -func Kickout(loginID interface{}, device ...string) error { - return stputil.Kickout(loginID, device...) +// Disable disables an account for specified duration | 封禁账号(指定时长) +func Disable(ctx context.Context, loginID interface{}, duration time.Duration, authType ...string) error { + return stputil.Disable(ctx, loginID, duration, authType...) } -// ============ Account Disable | 账号封禁 ============ +// DisableByToken disables the account associated with the given token for a duration | 根据指定 Token 封禁其对应的账号 +func DisableByToken(ctx context.Context, tokenValue string, duration time.Duration, authType ...string) error { + return stputil.DisableByToken(ctx, tokenValue, duration, authType...) +} -// Disable disables an account for specified duration | 封禁账号(指定时长) -func Disable(loginID interface{}, duration time.Duration) error { - return stputil.Disable(loginID, duration) +// Untie re-enables a disabled account | 解封账号 +func Untie(ctx context.Context, loginID interface{}, authType ...string) error { + return stputil.Untie(ctx, loginID, authType...) +} + +// UntieByToken re-enables a disabled account by token | 根据Token解封账号 +func UntieByToken(ctx context.Context, tokenValue string, authType ...string) error { + return stputil.UntieByToken(ctx, tokenValue, authType...) } // IsDisable checks if an account is disabled | 检查账号是否被封禁 -func IsDisable(loginID interface{}) bool { - return stputil.IsDisable(loginID) +func IsDisable(ctx context.Context, loginID interface{}, authType ...string) bool { + return stputil.IsDisable(ctx, loginID, authType...) } -// CheckDisable checks if account is disabled (throws error if disabled) | 检查账号是否被封禁(被封禁则抛出错误) -func CheckDisableByToken(tokenValue string) error { - return stputil.CheckDisable(tokenValue) +// IsDisableByToken checks if an account is disabled by token | 根据Token检查账号是否被封禁 +func IsDisableByToken(ctx context.Context, tokenValue string, authType ...string) bool { + return stputil.IsDisableByToken(ctx, tokenValue, authType...) } -// GetDisableTime gets remaining disabled time | 获取账号剩余封禁时间 -func GetDisableTime(loginID interface{}) (int64, error) { - return stputil.GetDisableTime(loginID) +// GetDisableTime gets remaining disable time in seconds | 获取剩余封禁时间(秒) +func GetDisableTime(ctx context.Context, loginID interface{}, authType ...string) (int64, error) { + return stputil.GetDisableTime(ctx, loginID, authType...) } -// Untie unties/unlocks an account | 解除账号封禁 -func Untie(loginID interface{}) error { - return stputil.Untie(loginID) +// GetDisableTimeByToken gets remaining disable time by token | 根据Token获取剩余封禁时间(秒) +func GetDisableTimeByToken(ctx context.Context, tokenValue string, authType ...string) (int64, error) { + return stputil.GetDisableTimeByToken(ctx, tokenValue, authType...) } -// ============ Permission Check | 权限验证 ============ +// CheckDisableWithInfo gets disable info | 获取封禁信息 +func CheckDisableWithInfo(ctx context.Context, loginID interface{}, authType ...string) (*manager.DisableInfo, error) { + return stputil.CheckDisableWithInfo(ctx, loginID, authType...) +} -// CheckPermission checks if the account has specified permission | 检查账号是否拥有指定权限 -func CheckPermissionByToken(tokenValue string, permission string) error { - return stputil.CheckPermission(tokenValue, permission) +// CheckDisableWithInfoByToken gets disable info by token | 根据Token获取封禁信息 +func CheckDisableWithInfoByToken(ctx context.Context, tokenValue string, authType ...string) (*manager.DisableInfo, error) { + return stputil.CheckDisableWithInfoByToken(ctx, tokenValue, authType...) } -// HasPermission checks if the account has specified permission (returns bool) | 检查账号是否拥有指定权限(返回布尔值) -func HasPermission(loginID interface{}, permission string) bool { - return stputil.HasPermission(loginID, permission) +// ============ Session Management | Session管理 ============ + +// GetSession gets session by login ID | 根据登录ID获取Session +func GetSession(ctx context.Context, loginID interface{}, authType ...string) (*session.Session, error) { + return stputil.GetSession(ctx, loginID, authType...) +} + +// GetSessionByToken gets session by token | 根据Token获取Session +func GetSessionByToken(ctx context.Context, tokenValue string, authType ...string) (*session.Session, error) { + return stputil.GetSessionByToken(ctx, tokenValue, authType...) } -// CheckPermissionAnd checks if the account has all specified permissions (AND logic) | 检查账号是否拥有所有指定权限(AND逻辑) -func CheckPermissionAndByToken(tokenValue string, permissions []string) error { - return stputil.CheckPermissionAnd(tokenValue, permissions) +// DeleteSession deletes a session | 删除Session +func DeleteSession(ctx context.Context, loginID interface{}, authType ...string) error { + return stputil.DeleteSession(ctx, loginID, authType...) } -// CheckPermissionOr checks if the account has any of the specified permissions (OR logic) | 检查账号是否拥有指定权限中的任意一个(OR逻辑) -func CheckPermissionOrByToken(tokenValue string, permissions []string) error { - return stputil.CheckPermissionOr(tokenValue, permissions) +// DeleteSessionByToken Deletes session by token | 根据Token删除Session +func DeleteSessionByToken(ctx context.Context, tokenValue string, authType ...string) error { + return stputil.DeleteSessionByToken(ctx, tokenValue, authType...) } -// GetPermissionList gets the permission list for an account | 获取账号的权限列表 -func GetPermissionListByToken(tokenValue string) ([]string, error) { - return stputil.GetPermissionList(tokenValue) +// HasSession checks if session exists | 检查Session是否存在 +func HasSession(ctx context.Context, loginID interface{}, authType ...string) bool { + return stputil.HasSession(ctx, loginID, authType...) } -// ============ Role Check | 角色验证 ============ +// RenewSession renews session TTL | 续期Session +func RenewSession(ctx context.Context, loginID interface{}, ttl time.Duration, authType ...string) error { + return stputil.RenewSession(ctx, loginID, ttl, authType...) +} -// CheckRole checks if the account has specified role | 检查账号是否拥有指定角色 -func CheckRoleByToken(tokenValue string, role string) error { - return stputil.CheckRole(tokenValue, role) +// RenewSessionByToken renews session TTL by token | 根据Token续期Session +func RenewSessionByToken(ctx context.Context, tokenValue string, ttl time.Duration, authType ...string) error { + return stputil.RenewSessionByToken(ctx, tokenValue, ttl, authType...) } -// HasRole checks if the account has specified role (returns bool) | 检查账号是否拥有指定角色(返回布尔值) -func HasRole(loginID interface{}, role string) bool { - return stputil.HasRole(loginID, role) +// ============ Permission Verification | 权限验证 ============ + +// SetPermissions sets permissions for a login ID | 设置用户权限 +func SetPermissions(ctx context.Context, loginID interface{}, permissions []string, authType ...string) error { + return stputil.SetPermissions(ctx, loginID, permissions, authType...) } -// CheckRoleAnd checks if the account has all specified roles (AND logic) | 检查账号是否拥有所有指定角色(AND逻辑) -func CheckRoleAndByToken(tokenValue string, roles []string) error { - return stputil.CheckRoleAnd(tokenValue, roles) +// SetPermissionsByToken sets permissions by token | 根据 Token 设置对应账号的权限 +func SetPermissionsByToken(ctx context.Context, tokenValue string, permissions []string, authType ...string) error { + return stputil.SetPermissionsByToken(ctx, tokenValue, permissions, authType...) } -// CheckRoleOr checks if the account has any of the specified roles (OR logic) | 检查账号是否拥有指定角色中的任意一个(OR逻辑) -func CheckRoleOrByToken(tokenValue string, roles []string) error { - return stputil.CheckRoleOr(tokenValue, roles) +// RemovePermissions removes specified permissions for a login ID | 删除用户指定权限 +func RemovePermissions(ctx context.Context, loginID interface{}, permissions []string, authType ...string) error { + return stputil.RemovePermissions(ctx, loginID, permissions, authType...) } -// GetRoleList gets the role list for an account | 获取账号的角色列表 -func GetRoleListByToken(tokenValue string) ([]string, error) { - return stputil.GetRoleList(tokenValue) +// RemovePermissionsByToken removes specified permissions by token | 根据 Token 删除对应账号的指定权限 +func RemovePermissionsByToken(ctx context.Context, tokenValue string, permissions []string, authType ...string) error { + return stputil.RemovePermissionsByToken(ctx, tokenValue, permissions, authType...) } -// ============ Session Management | Session管理 ============ +// GetPermissions gets permission list | 获取权限列表 +func GetPermissions(ctx context.Context, loginID interface{}, authType ...string) ([]string, error) { + return stputil.GetPermissions(ctx, loginID, authType...) +} + +// GetPermissionsByToken gets permission list by token | 根据 Token 获取对应账号的权限列表 +func GetPermissionsByToken(ctx context.Context, tokenValue string, authType ...string) ([]string, error) { + return stputil.GetPermissionsByToken(ctx, tokenValue, authType...) +} + +// HasPermission checks if has specified permission | 检查是否拥有指定权限 +func HasPermission(ctx context.Context, loginID interface{}, permission string, authType ...string) bool { + return stputil.HasPermission(ctx, loginID, permission, authType...) +} + +// HasPermissionByToken checks if the token has the specified permission | 检查Token是否拥有指定权限 +func HasPermissionByToken(ctx context.Context, tokenValue string, permission string, authType ...string) bool { + return stputil.HasPermissionByToken(ctx, tokenValue, permission, authType...) +} + +// HasPermissionsAnd checks if has all permissions (AND logic) | 检查是否拥有所有权限(AND逻辑) +func HasPermissionsAnd(ctx context.Context, loginID interface{}, permissions []string, authType ...string) bool { + return stputil.HasPermissionsAnd(ctx, loginID, permissions, authType...) +} + +// HasPermissionsAndByToken checks if the token has all specified permissions | 检查Token是否拥有所有指定权限 +func HasPermissionsAndByToken(ctx context.Context, tokenValue string, permissions []string, authType ...string) bool { + return stputil.HasPermissionsAndByToken(ctx, tokenValue, permissions, authType...) +} + +// HasPermissionsOr checks if has any permission (OR logic) | 检查是否拥有任一权限(OR逻辑) +func HasPermissionsOr(ctx context.Context, loginID interface{}, permissions []string, authType ...string) bool { + return stputil.HasPermissionsOr(ctx, loginID, permissions, authType...) +} + +// HasPermissionsOrByToken checks if the token has any of the specified permissions | 检查Token是否拥有任一指定权限 +func HasPermissionsOrByToken(ctx context.Context, tokenValue string, permissions []string, authType ...string) bool { + return stputil.HasPermissionsOrByToken(ctx, tokenValue, permissions, authType...) +} + +// ============ Role Management | 角色管理 ============ + +// SetRoles sets roles for a login ID | 设置用户角色 +func SetRoles(ctx context.Context, loginID interface{}, roles []string, authType ...string) error { + return stputil.SetRoles(ctx, loginID, roles, authType...) +} + +// SetRolesByToken sets roles by token | 根据 Token 设置对应账号的角色 +func SetRolesByToken(ctx context.Context, tokenValue string, roles []string, authType ...string) error { + return stputil.SetRolesByToken(ctx, tokenValue, roles, authType...) +} + +// RemoveRoles removes specified roles for a login ID | 删除用户指定角色 +func RemoveRoles(ctx context.Context, loginID interface{}, roles []string, authType ...string) error { + return stputil.RemoveRoles(ctx, loginID, roles, authType...) +} + +// RemoveRolesByToken removes specified roles by token | 根据 Token 删除对应账号的指定角色 +func RemoveRolesByToken(ctx context.Context, tokenValue string, roles []string, authType ...string) error { + return stputil.RemoveRolesByToken(ctx, tokenValue, roles, authType...) +} + +// GetRoles gets role list | 获取角色列表 +func GetRoles(ctx context.Context, loginID interface{}, authType ...string) ([]string, error) { + return stputil.GetRoles(ctx, loginID, authType...) +} + +// GetRolesByToken gets role list by token | 根据 Token 获取对应账号的角色列表 +func GetRolesByToken(ctx context.Context, tokenValue string, authType ...string) ([]string, error) { + return stputil.GetRolesByToken(ctx, tokenValue, authType...) +} + +// HasRole checks if has specified role | 检查是否拥有指定角色 +func HasRole(ctx context.Context, loginID interface{}, role string, authType ...string) bool { + return stputil.HasRole(ctx, loginID, role, authType...) +} + +// HasRoleByToken checks if the token has the specified role | 检查 Token 是否拥有指定角色 +func HasRoleByToken(ctx context.Context, tokenValue string, role string, authType ...string) bool { + return stputil.HasRoleByToken(ctx, tokenValue, role, authType...) +} + +// HasRolesAnd checks if has all roles (AND logic) | 检查是否拥有所有角色(AND逻辑) +func HasRolesAnd(ctx context.Context, loginID interface{}, roles []string, authType ...string) bool { + return stputil.HasRolesAnd(ctx, loginID, roles, authType...) +} -// GetSession gets the session for a login ID | 获取登录ID的Session -func GetSession(loginID interface{}) (*Session, error) { - return stputil.GetSession(loginID) +// HasRolesAndByToken checks if the token has all specified roles | 检查 Token 是否拥有所有指定角色 +func HasRolesAndByToken(ctx context.Context, tokenValue string, roles []string, authType ...string) bool { + return stputil.HasRolesAndByToken(ctx, tokenValue, roles, authType...) } -// GetSessionByToken gets the session by token | 根据Token获取Session -func GetSessionByToken(tokenValue string) (*Session, error) { - return stputil.GetSessionByToken(tokenValue) +// HasRolesOr checks if has any role (OR logic) | 检查是否拥有任一角色(OR逻辑) +func HasRolesOr(ctx context.Context, loginID interface{}, roles []string, authType ...string) bool { + return stputil.HasRolesOr(ctx, loginID, roles, authType...) } -// GetTokenSession gets the token session | 获取Token的Session -func GetTokenSession(tokenValue string) (*Session, error) { - return stputil.GetTokenSession(tokenValue) +// HasRolesOrByToken checks if the token has any of the specified roles | 检查 Token 是否拥有任一指定角色 +func HasRolesOrByToken(ctx context.Context, tokenValue string, roles []string, authType ...string) bool { + return stputil.HasRolesOrByToken(ctx, tokenValue, roles, authType...) } -// ============ Token Renewal | Token续期 ============ +// ============ Token Tag | Token标签 ============ + +// SetTokenTag sets token tag | 设置Token标签 +func SetTokenTag(ctx context.Context, tokenValue, tag string, authType ...string) error { + return stputil.SetTokenTag(ctx, tokenValue, tag, authType...) +} + +// GetTokenTag gets token tag | 获取Token标签 +func GetTokenTag(ctx context.Context, tokenValue string, authType ...string) (string, error) { + return stputil.GetTokenTag(ctx, tokenValue, authType...) +} -// RenewTimeout renews token timeout | 续期Token超时时间 +// ============ Session Query | 会话查询 ============ + +// GetTokenValueListByLoginID gets all tokens for a login ID | 获取指定账号的所有Token +func GetTokenValueListByLoginID(ctx context.Context, loginID interface{}, authType ...string) ([]string, error) { + return stputil.GetTokenValueListByLoginID(ctx, loginID, authType...) +} + +// GetSessionCountByLoginID gets session count for a login ID | 获取指定账号的Session数量 +func GetSessionCountByLoginID(ctx context.Context, loginID interface{}, authType ...string) (int, error) { + return stputil.GetSessionCountByLoginID(ctx, loginID, authType...) +} // ============ Security Features | 安全特性 ============ -// GenerateNonce generates a new nonce token | 生成新的Nonce令牌 -func GenerateNonce() (string, error) { - return stputil.GenerateNonce() +// Generate Generates a one-time nonce | 生成一次性随机数 +func Generate(ctx context.Context, authType ...string) (string, error) { + return stputil.Generate(ctx, authType...) +} + +// Verify Verifies a nonce | 验证随机数 +func Verify(ctx context.Context, nonce string, authType ...string) bool { + return stputil.Verify(ctx, nonce, authType...) } -// VerifyNonce verifies a nonce token | 验证Nonce令牌 -func VerifyNonce(nonce string) bool { - return stputil.VerifyNonce(nonce) +// VerifyAndConsume Verifies and consumes nonce, returns error if invalid | 验证并消费nonce,无效时返回错误 +func VerifyAndConsume(ctx context.Context, nonce string, authType ...string) error { + return stputil.VerifyAndConsume(ctx, nonce, authType...) } -// LoginWithRefreshToken performs login and returns both access token and refresh token | 登录并返回访问令牌和刷新令牌 -func LoginWithRefreshToken(loginID interface{}, device ...string) (*RefreshTokenInfo, error) { - return stputil.LoginWithRefreshToken(loginID, device...) +// IsValidNonce Checks if nonce is valid without consuming it | 检查nonce是否有效(不消费) +func IsValidNonce(ctx context.Context, nonce string, authType ...string) bool { + return stputil.IsValidNonce(ctx, nonce, authType...) } -// RefreshAccessToken refreshes the access token using a refresh token | 使用刷新令牌刷新访问令牌 -func RefreshAccessToken(refreshToken string) (*RefreshTokenInfo, error) { - return stputil.RefreshAccessToken(refreshToken) +// GenerateTokenPair Create access + refresh token | 生成访问令牌和刷新令牌 +func GenerateTokenPair(ctx context.Context, loginID interface{}, deviceOrAutoType ...string) (*security.RefreshTokenInfo, error) { + return stputil.GenerateTokenPair(ctx, loginID, deviceOrAutoType...) } -// RevokeRefreshToken revokes a refresh token | 撤销刷新令牌 -func RevokeRefreshToken(refreshToken string) error { - return stputil.RevokeRefreshToken(refreshToken) +// VerifyAccessToken verifies access token validity | 验证访问令牌是否有效 +func VerifyAccessToken(ctx context.Context, accessToken string, authType ...string) bool { + return stputil.VerifyAccessToken(ctx, accessToken, authType...) +} + +// VerifyAccessTokenAndGetInfo verifies access token and returns token info | 验证访问令牌并返回Token信息 +func VerifyAccessTokenAndGetInfo(ctx context.Context, accessToken string, authType ...string) (*security.AccessTokenInfo, bool) { + return stputil.VerifyAccessTokenAndGetInfo(ctx, accessToken, authType...) +} + +// GetRefreshTokenInfo gets refresh token information | 获取刷新令牌信息 +func GetRefreshTokenInfo(ctx context.Context, refreshToken string, authType ...string) (*security.RefreshTokenInfo, error) { + return stputil.GetRefreshTokenInfo(ctx, refreshToken, authType...) +} + +// RefreshAccessToken refreshes access token using refresh token | 使用刷新令牌刷新访问令牌 +func RefreshAccessToken(ctx context.Context, refreshToken string, authType ...string) (*security.RefreshTokenInfo, error) { + return stputil.RefreshAccessToken(ctx, refreshToken, authType...) +} + +// RevokeRefreshToken Revokes refresh token | 撤销刷新令牌 +func RevokeRefreshToken(ctx context.Context, refreshToken string, authType ...string) error { + return stputil.RevokeRefreshToken(ctx, refreshToken, authType...) +} + +// IsValid checks whether token is valid | 检查Token是否有效 +func IsValid(ctx context.Context, refreshToken string, authType ...string) bool { + return stputil.IsValid(ctx, refreshToken, authType...) +} + +// ============ OAuth2 Features | OAuth2 功能 ============ + +// RegisterClient Registers an OAuth2 client | 注册OAuth2客户端 +func RegisterClient(ctx context.Context, client *oauth2.Client, authType ...string) error { + return stputil.RegisterClient(ctx, client, authType...) } -// GetOAuth2Server gets the OAuth2 server instance | 获取OAuth2服务器实例 -func GetOAuth2Server() *OAuth2Server { - return stputil.GetOAuth2Server() +// UnregisterClient unregisters an OAuth2 client | 注销OAuth2客户端 +func UnregisterClient(ctx context.Context, clientID string, authType ...string) error { + return stputil.UnregisterClient(ctx, clientID, authType...) } -// Version Sa-Token-Go version | Sa-Token-Go版本 -const Version = core.Version +// GetClient gets OAuth2 client information | 获取OAuth2客户端信息 +func GetClient(ctx context.Context, clientID string, authType ...string) (*oauth2.Client, error) { + return stputil.GetClient(ctx, clientID, authType...) +} + +// GenerateAuthorizationCode creates an authorization code | 创建授权码 +func GenerateAuthorizationCode(ctx context.Context, clientID, loginID, redirectURI string, scope []string, authType ...string) (*oauth2.AuthorizationCode, error) { + return stputil.GenerateAuthorizationCode(ctx, clientID, loginID, redirectURI, scope, authType...) +} + +// ExchangeCodeForToken exchanges authorization code for token | 使用授权码换取令牌 +func ExchangeCodeForToken(ctx context.Context, code, clientID, clientSecret, redirectURI string, authType ...string) (*oauth2.AccessToken, error) { + return stputil.ExchangeCodeForToken(ctx, code, clientID, clientSecret, redirectURI, authType...) +} + +// ValidateAccessToken verifies OAuth2 access token | 验证OAuth2访问令牌 +func ValidateAccessToken(ctx context.Context, accessToken string, authType ...string) bool { + return stputil.ValidateAccessToken(ctx, accessToken, authType...) +} + +// ValidateAccessTokenAndGetInfo verifies OAuth2 access token and get info | 验证OAuth2访问令牌并获取信息 +func ValidateAccessTokenAndGetInfo(ctx context.Context, accessToken string, authType ...string) (*oauth2.AccessToken, error) { + return stputil.ValidateAccessTokenAndGetInfo(ctx, accessToken, authType...) +} + +// OAuth2RefreshAccessToken Refreshes access token using refresh token | 使用刷新令牌刷新访问令牌(OAuth2) +func OAuth2RefreshAccessToken(ctx context.Context, clientID, refreshToken, clientSecret string, authType ...string) (*oauth2.AccessToken, error) { + return stputil.OAuth2RefreshAccessToken(ctx, clientID, refreshToken, clientSecret, authType...) +} + +// RevokeToken Revokes access token and its refresh token | 撤销访问令牌及其刷新令牌 +func RevokeToken(ctx context.Context, accessToken string, authType ...string) error { + return stputil.RevokeToken(ctx, accessToken, authType...) +} + +// OAuth2Token Unified token endpoint that dispatches to appropriate handler based on grant type | 统一的令牌端点 +func OAuth2Token(ctx context.Context, req *oauth2.TokenRequest, validateUser oauth2.UserValidator, authType ...string) (*oauth2.AccessToken, error) { + return stputil.OAuth2Token(ctx, req, validateUser, authType...) +} + +// OAuth2ClientCredentialsToken Gets access token using client credentials grant | 使用客户端凭证模式获取访问令牌 +func OAuth2ClientCredentialsToken(ctx context.Context, clientID, clientSecret string, scopes []string, authType ...string) (*oauth2.AccessToken, error) { + return stputil.OAuth2ClientCredentialsToken(ctx, clientID, clientSecret, scopes, authType...) +} + +// OAuth2PasswordGrantToken Gets access token using resource owner password credentials grant | 使用密码模式获取访问令牌 +func OAuth2PasswordGrantToken(ctx context.Context, clientID, clientSecret, username, password string, scopes []string, validateUser oauth2.UserValidator, authType ...string) (*oauth2.AccessToken, error) { + return stputil.OAuth2PasswordGrantToken(ctx, clientID, clientSecret, username, password, scopes, validateUser, authType...) +} + +// ============ OAuth2 Grant Type Constants | OAuth2授权类型常量 ============ + +const ( + // GrantTypeAuthorizationCode Authorization code grant type | 授权码模式 + GrantTypeAuthorizationCode = oauth2.GrantTypeAuthorizationCode + // GrantTypeClientCredentials Client credentials grant type | 客户端凭证模式 + GrantTypeClientCredentials = oauth2.GrantTypeClientCredentials + // GrantTypePassword Password grant type | 密码模式 + GrantTypePassword = oauth2.GrantTypePassword + // GrantTypeRefreshToken Refresh token grant type | 刷新令牌模式 + GrantTypeRefreshToken = oauth2.GrantTypeRefreshToken +) + +// ============ Public Getters | 公共获取器 ============ + +// GetConfig returns the manager-example configuration | 获取 Manager 当前使用的配置 +func GetConfig(ctx context.Context, authType ...string) *config.Config { + return stputil.GetConfig(ctx, authType...) +} + +// GetStorage returns the storage adapter | 获取 Manager 使用的存储适配器 +func GetStorage(ctx context.Context, authType ...string) adapter.Storage { + return stputil.GetStorage(ctx, authType...) +} + +// GetCodec returns the codec (serializer) | 获取 Manager 使用的编解码器 +func GetCodec(ctx context.Context, authType ...string) adapter.Codec { + return stputil.GetCodec(ctx, authType...) +} + +// GetLog returns the logger adapter | 获取 Manager 使用的日志适配器 +func GetLog(ctx context.Context, authType ...string) adapter.Log { + return stputil.GetLog(ctx, authType...) +} + +// GetPool returns the goroutine pool | 获取 Manager 使用的协程池 +func GetPool(ctx context.Context, authType ...string) adapter.Pool { + return stputil.GetPool(ctx, authType...) +} + +// GetGenerator returns the token generator | 获取 Token 生成器 +func GetGenerator(ctx context.Context, authType ...string) adapter.Generator { + return stputil.GetGenerator(ctx, authType...) +} + +// GetNonceManager returns the nonce manager-example | 获取随机串管理器 +func GetNonceManager(ctx context.Context, authType ...string) *security.NonceManager { + return stputil.GetNonceManager(ctx, authType...) +} + +// GetRefreshManager returns the refresh token manager-example | 获取刷新令牌管理器 +func GetRefreshManager(ctx context.Context, authType ...string) *security.RefreshTokenManager { + return stputil.GetRefreshManager(ctx, authType...) +} + +// GetEventManager returns the event manager-example | 获取事件管理器 +func GetEventManager(ctx context.Context, authType ...string) *listener.Manager { + return stputil.GetEventManager(ctx, authType...) +} + +// GetOAuth2Server Gets OAuth2 server instance | 获取OAuth2服务器实例 +func GetOAuth2Server(ctx context.Context, authType ...string) *oauth2.OAuth2Server { + return stputil.GetOAuth2Server(ctx, authType...) +} + +// ============ Event Management | 事件管理 ============ + +// RegisterFunc registers a function as an event listener | 注册函数作为事件监听器 +func RegisterFunc(event listener.Event, fn func(*listener.EventData), authType ...string) { + stputil.RegisterFunc(event, fn, authType...) +} + +// Register registers an event listener | 注册事件监听器 +func Register(event listener.Event, l listener.Listener, authType ...string) string { + return stputil.Register(event, l, authType...) +} + +// RegisterWithConfig registers an event listener with config | 注册带配置的事件监听器 +func RegisterWithConfig(event listener.Event, l listener.Listener, cfg listener.ListenerConfig, authType ...string) string { + return stputil.RegisterWithConfig(event, l, cfg, authType...) +} + +// Unregister removes an event listener by ID | 根据ID移除事件监听器 +func Unregister(id string, authType ...string) bool { + return stputil.Unregister(id, authType...) +} + +// TriggerEvent manually triggers an event | 手动触发事件 +func TriggerEvent(data *listener.EventData, authType ...string) { + stputil.TriggerEvent(data, authType...) +} + +// TriggerEventAsync triggers an event asynchronously and returns immediately | 异步触发事件并立即返回 +func TriggerEventAsync(data *listener.EventData, authType ...string) { + stputil.TriggerEventAsync(data, authType...) +} + +// TriggerEventSync triggers an event synchronously and waits for all listeners | 同步触发事件并等待所有监听器完成 +func TriggerEventSync(data *listener.EventData, authType ...string) { + stputil.TriggerEventSync(data, authType...) +} + +// WaitEvents waits for all async event listeners to complete | 等待所有异步事件监听器完成 +func WaitEvents(authType ...string) { + stputil.WaitEvents(authType...) +} + +// ClearEventListeners removes all listeners for a specific event | 清除指定事件的所有监听器 +func ClearEventListeners(event listener.Event, authType ...string) { + stputil.ClearEventListeners(event, authType...) +} + +// ClearAllEventListeners removes all listeners | 清除所有事件监听器 +func ClearAllEventListeners(authType ...string) { + stputil.ClearAllEventListeners(authType...) +} + +// CountEventListeners returns the number of listeners for a specific event | 获取指定事件监听器数量 +func CountEventListeners(event listener.Event, authType ...string) int { + return stputil.CountEventListeners(event, authType...) +} + +// CountAllListeners returns the total number of registered listeners | 获取已注册监听器总数 +func CountAllListeners(authType ...string) int { + return stputil.CountAllListeners(authType...) +} + +// GetEventListenerIDs returns all listener IDs for a specific event | 获取指定事件的所有监听器ID +func GetEventListenerIDs(event listener.Event, authType ...string) []string { + return stputil.GetEventListenerIDs(event, authType...) +} + +// GetAllRegisteredEvents returns all events that have registered listeners | 获取所有已注册事件 +func GetAllRegisteredEvents(authType ...string) []listener.Event { + return stputil.GetAllRegisteredEvents(authType...) +} + +// HasEventListeners checks if there are any listeners for a specific event | 检查指定事件是否有监听器 +func HasEventListeners(event listener.Event, authType ...string) bool { + return stputil.HasEventListeners(event, authType...) +} diff --git a/integrations/fiber/go.mod b/integrations/fiber/go.mod index abb4115..47a0c17 100644 --- a/integrations/fiber/go.mod +++ b/integrations/fiber/go.mod @@ -1,18 +1,16 @@ module github.com/click33/sa-token-go/integrations/fiber -go 1.23.0 - -toolchain go1.24.1 +go 1.25.0 require ( - github.com/click33/sa-token-go/core v0.1.5 - github.com/click33/sa-token-go/stputil v0.1.5 + github.com/click33/sa-token-go/core v0.1.7 + github.com/click33/sa-token-go/stputil v0.1.7 github.com/gofiber/fiber/v2 v2.52.0 ) require ( github.com/andybalholm/brotli v1.0.5 // indirect - github.com/golang-jwt/jwt/v5 v5.2.2 // indirect + github.com/golang-jwt/jwt/v5 v5.3.0 // indirect github.com/google/uuid v1.6.0 // indirect github.com/klauspost/compress v1.17.0 // indirect github.com/mattn/go-colorable v0.1.14 // indirect @@ -23,8 +21,6 @@ require ( github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/valyala/fasthttp v1.51.0 // indirect github.com/valyala/tcplisten v1.0.0 // indirect - golang.org/x/sync v0.16.0 // indirect + golang.org/x/sync v0.19.0 // indirect golang.org/x/sys v0.35.0 // indirect ) - -replace github.com/click33/sa-token-go/core => ../../core diff --git a/integrations/fiber/go.sum b/integrations/fiber/go.sum index a04b709..714f472 100644 --- a/integrations/fiber/go.sum +++ b/integrations/fiber/go.sum @@ -1,8 +1,9 @@ github.com/andybalholm/brotli v1.0.5 h1:8uQZIdzKmjc/iuPu7O2ioW48L81FgatrcpfFmiq/cCs= -github.com/click33/sa-token-go/stputil v0.1.4 h1:YvMEwPfAfTunQn+AePudO3Esp0CvLoc2o5kmg/uZf/c= +github.com/click33/sa-token-go/core v0.1.6 h1:ELOe0qSH1b3LRsQD3DIBg0e1VgYANKFg5H7z57Lkt/8= +github.com/click33/sa-token-go/stputil v0.1.6 h1:S+V64jQzppE9c1wXcmHppCRlrSsU2iTfvdPGlMbs2WI= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/gofiber/fiber/v2 v2.52.0 h1:S+qXi7y+/Pgvqq4DrSmREGiFwtB7Bu6+QFLuIHYw/UE= -github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= +github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/klauspost/compress v1.17.0 h1:Rnbp4K9EjcDuVuHtd0dgA4qNuv9yKDYKK1ulpJwgrqM= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= @@ -15,6 +16,6 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/fasthttp v1.51.0 h1:8b30A5JlZ6C7AS81RsWjYMQmrZG6feChmgAolCl1SqA= github.com/valyala/tcplisten v1.0.0 h1:rBHj/Xf+E1tRGZyWIWwJDiRY0zc1Js+CV5DqwacVSA8= -golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/integrations/fiber/middleware.go b/integrations/fiber/middleware.go new file mode 100644 index 0000000..d020ecb --- /dev/null +++ b/integrations/fiber/middleware.go @@ -0,0 +1,307 @@ +package fiber + +import ( + "errors" + + "github.com/click33/sa-token-go/core" + "github.com/click33/sa-token-go/core/manager" + + saContext "github.com/click33/sa-token-go/core/context" + "github.com/click33/sa-token-go/stputil" + "github.com/gofiber/fiber/v2" +) + +// LogicType permission/role logic type | 权限/角色判断逻辑 +type LogicType string + +const ( + SaTokenCtxKey = "saCtx" + + LogicOr LogicType = "OR" // Logical OR | 任一满足 + LogicAnd LogicType = "AND" // Logical AND | 全部满足 +) + +type AuthOption func(*AuthOptions) + +type AuthOptions struct { + AuthType string + LogicType LogicType + FailFunc func(c *fiber.Ctx, err error) error +} + +func defaultAuthOptions() *AuthOptions { + return &AuthOptions{LogicType: LogicAnd} // 默认 AND +} + +// WithAuthType sets auth type | 设置认证类型 +func WithAuthType(authType string) AuthOption { + return func(o *AuthOptions) { + o.AuthType = authType + } +} + +// WithLogicType sets LogicType option | 设置逻辑类型 +func WithLogicType(logicType LogicType) AuthOption { + return func(o *AuthOptions) { + o.LogicType = logicType + } +} + +// WithFailFunc sets auth failure callback | 设置认证失败回调 +func WithFailFunc(fn func(c *fiber.Ctx, err error) error) AuthOption { + return func(o *AuthOptions) { + o.FailFunc = fn + } +} + +// ========== Middlewares ========== + +// AuthMiddleware authentication middleware | 认证中间件 +func AuthMiddleware(opts ...AuthOption) fiber.Handler { + options := defaultAuthOptions() + for _, opt := range opts { + opt(options) + } + + return func(c *fiber.Ctx) error { + mgr, err := stputil.GetManager(options.AuthType) + if err != nil { + if options.FailFunc != nil { + return options.FailFunc(c, err) + } + return writeErrorResponse(c, err) + } + + // 获取 token | Get token + saCtx := getSaContext(c, mgr) + tokenValue := saCtx.GetTokenValue() + + // 检查登录 | Check login + err = mgr.CheckLogin(c.UserContext(), tokenValue) + if err != nil { + if options.FailFunc != nil { + return options.FailFunc(c, err) + } + return writeErrorResponse(c, err) + } + + return c.Next() + } +} + +// AuthWithStateMiddleware with state authentication middleware | 带状态返回的认证中间件 +func AuthWithStateMiddleware(opts ...AuthOption) fiber.Handler { + options := defaultAuthOptions() + for _, opt := range opts { + opt(options) + } + + return func(c *fiber.Ctx) error { + // 获取 Manager | Get Manager + mgr, err := stputil.GetManager(options.AuthType) + if err != nil { + if options.FailFunc != nil { + return options.FailFunc(c, err) + } + return writeErrorResponse(c, err) + } + + // 构建 Sa-Token 上下文 | Build Sa-Token context + saCtx := getSaContext(c, mgr) + tokenValue := saCtx.GetTokenValue() + + // 检查登录并返回状态 | Check login with state + _, err = mgr.CheckLoginWithState(c.UserContext(), tokenValue) + + if err != nil { + // 用户自定义回调优先 + if options.FailFunc != nil { + return options.FailFunc(c, err) + } + return writeErrorResponse(c, err) + } + + return c.Next() + } +} + +// PermissionMiddleware permission check middleware | 权限校验中间件 +func PermissionMiddleware( + permissions []string, + opts ...AuthOption, +) fiber.Handler { + + options := defaultAuthOptions() + for _, opt := range opts { + opt(options) + } + + return func(c *fiber.Ctx) error { + // No permission required | 无需权限直接放行 + if len(permissions) == 0 { + return c.Next() + } + + // Get Manager | 获取 Manager + mgr, err := stputil.GetManager(options.AuthType) + if err != nil { + if options.FailFunc != nil { + return options.FailFunc(c, err) + } + return writeErrorResponse(c, err) + } + + // 构建 Sa-Token 上下文 | Build Sa-Token context + saCtx := getSaContext(c, mgr) + tokenValue := saCtx.GetTokenValue() + ctx := c.UserContext() + + // Permission check | 权限校验 + var ok bool + if options.LogicType == LogicAnd { + ok = mgr.HasPermissionsAndByToken(ctx, tokenValue, permissions) + } else { + ok = mgr.HasPermissionsOrByToken(ctx, tokenValue, permissions) + } + + if !ok { + if options.FailFunc != nil { + return options.FailFunc(c, core.ErrPermissionDenied) + } + return writeErrorResponse(c, core.ErrPermissionDenied) + } + + return c.Next() + } +} + +// RoleMiddleware role check middleware | 角色校验中间件 +func RoleMiddleware( + roles []string, + opts ...AuthOption, +) fiber.Handler { + + options := defaultAuthOptions() + for _, opt := range opts { + opt(options) + } + + return func(c *fiber.Ctx) error { + // No role required | 无需角色直接放行 + if len(roles) == 0 { + return c.Next() + } + + // Get Manager | 获取 Manager + mgr, err := stputil.GetManager(options.AuthType) + if err != nil { + if options.FailFunc != nil { + return options.FailFunc(c, err) + } + return writeErrorResponse(c, err) + } + + // 构建 Sa-Token 上下文 | Build Sa-Token context + saCtx := getSaContext(c, mgr) + tokenValue := saCtx.GetTokenValue() + ctx := c.UserContext() + + // Role check | 角色校验 + var ok bool + if options.LogicType == LogicAnd { + ok = mgr.HasRolesAndByToken(ctx, tokenValue, roles) + } else { + ok = mgr.HasRolesOrByToken(ctx, tokenValue, roles) + } + + if !ok { + if options.FailFunc != nil { + return options.FailFunc(c, core.ErrRoleDenied) + } + return writeErrorResponse(c, core.ErrRoleDenied) + } + + return c.Next() + } +} + +// GetSaTokenContext gets Sa-Token context from Fiber context | 获取 Sa-Token 上下文 +func GetSaTokenContext(c *fiber.Ctx) (*saContext.SaTokenContext, bool) { + v := c.Locals(SaTokenCtxKey) + if v == nil { + return nil, false + } + + ctx, ok := v.(*saContext.SaTokenContext) + return ctx, ok +} + +func getSaContext(c *fiber.Ctx, mgr *manager.Manager) *saContext.SaTokenContext { + // Try get from context | 尝试从 ctx 取值 + if v := c.Locals(SaTokenCtxKey); v != nil { + if saCtx, ok := v.(*saContext.SaTokenContext); ok { + return saCtx + } + } + + // Create new context | 创建并缓存 SaTokenContext + saCtx := saContext.NewContext(NewFiberContext(c), mgr) + c.Locals(SaTokenCtxKey, saCtx) + + return saCtx +} + +// ============ Error Handling Helpers | 错误处理辅助函数 ============ + +// writeErrorResponse writes a standardized error response | 写入标准化的错误响应 +func writeErrorResponse(c *fiber.Ctx, err error) error { + var saErr *core.SaTokenError + var code int + var message string + var httpStatus int + + // Check if it's a SaTokenError | 检查是否为SaTokenError + if errors.As(err, &saErr) { + code = saErr.Code + message = saErr.Message + httpStatus = getHTTPStatusFromCode(code) + } else { + // Handle standard errors | 处理标准错误 + code = core.CodeServerError + message = err.Error() + httpStatus = fiber.StatusInternalServerError + } + + return c.Status(httpStatus).JSON(fiber.Map{ + "code": code, + "message": message, + "data": err.Error(), + }) +} + +// writeSuccessResponse writes a standardized success response | 写入标准化的成功响应 +func writeSuccessResponse(c *fiber.Ctx, data interface{}) error { + return c.JSON(fiber.Map{ + "code": core.CodeSuccess, + "message": "success", + "data": data, + }) +} + +// getHTTPStatusFromCode converts Sa-Token error code to HTTP status | 将Sa-Token错误码转换为HTTP状态码 +func getHTTPStatusFromCode(code int) int { + switch code { + case core.CodeNotLogin: + return fiber.StatusUnauthorized + case core.CodePermissionDenied: + return fiber.StatusForbidden + case core.CodeBadRequest: + return fiber.StatusBadRequest + case core.CodeNotFound: + return fiber.StatusNotFound + case core.CodeServerError: + return fiber.StatusInternalServerError + default: + return fiber.StatusInternalServerError + } +} diff --git a/integrations/fiber/plugin.go b/integrations/fiber/plugin.go deleted file mode 100644 index 7aa193b..0000000 --- a/integrations/fiber/plugin.go +++ /dev/null @@ -1,165 +0,0 @@ -package fiber - -import ( - "errors" - - "github.com/click33/sa-token-go/core" - "github.com/gofiber/fiber/v2" -) - -// Plugin Fiber plugin for Sa-Token | Fiber插件 -type Plugin struct { - manager *core.Manager -} - -// NewPlugin creates a Fiber plugin | 创建Fiber插件 -func NewPlugin(manager *core.Manager) *Plugin { - return &Plugin{ - manager: manager, - } -} - -// AuthMiddleware authentication middleware | 认证中间件 -func (p *Plugin) AuthMiddleware() fiber.Handler { - return func(c *fiber.Ctx) error { - ctx := NewFiberContext(c) - saCtx := core.NewContext(ctx, p.manager) - - if err := saCtx.CheckLogin(); err != nil { - return writeErrorResponse(c, err) - } - - c.Locals("satoken", saCtx) - return c.Next() - } -} - -// PermissionRequired permission validation middleware | 权限验证中间件 -func (p *Plugin) PermissionRequired(permission string) fiber.Handler { - return func(c *fiber.Ctx) error { - ctx := NewFiberContext(c) - saCtx := core.NewContext(ctx, p.manager) - - if err := saCtx.CheckLogin(); err != nil { - return writeErrorResponse(c, err) - } - - if !saCtx.HasPermission(permission) { - return writeErrorResponse(c, core.NewPermissionDeniedError(permission)) - } - - c.Locals("satoken", saCtx) - return c.Next() - } -} - -// RoleRequired role validation middleware | 角色验证中间件 -func (p *Plugin) RoleRequired(role string) fiber.Handler { - return func(c *fiber.Ctx) error { - ctx := NewFiberContext(c) - saCtx := core.NewContext(ctx, p.manager) - - if err := saCtx.CheckLogin(); err != nil { - return writeErrorResponse(c, err) - } - - if !saCtx.HasRole(role) { - return writeErrorResponse(c, core.NewRoleDeniedError(role)) - } - - c.Locals("satoken", saCtx) - return c.Next() - } -} - -// LoginHandler 登录处理器 -func (p *Plugin) LoginHandler(c *fiber.Ctx) error { - var req struct { - Username string `json:"username"` - Password string `json:"password"` - Device string `json:"device"` - } - - if err := c.BodyParser(&req); err != nil { - return writeErrorResponse(c, core.NewError(core.CodeBadRequest, "invalid request parameters", err)) - } - - device := req.Device - if device == "" { - device = "default" - } - - token, err := p.manager.Login(req.Username, device) - if err != nil { - return writeErrorResponse(c, core.NewError(core.CodeServerError, "login failed", err)) - } - - return writeSuccessResponse(c, fiber.Map{ - "token": token, - }) -} - -// GetSaToken 从Fiber上下文获取Sa-Token上下文 -func GetSaToken(c *fiber.Ctx) (*core.SaTokenContext, bool) { - satoken := c.Locals("satoken") - if satoken == nil { - return nil, false - } - ctx, ok := satoken.(*core.SaTokenContext) - return ctx, ok -} - -// ============ Error Handling Helpers | 错误处理辅助函数 ============ - -// writeErrorResponse writes a standardized error response | 写入标准化的错误响应 -func writeErrorResponse(c *fiber.Ctx, err error) error { - var saErr *core.SaTokenError - var code int - var message string - var httpStatus int - - // Check if it's a SaTokenError | 检查是否为SaTokenError - if errors.As(err, &saErr) { - code = saErr.Code - message = saErr.Message - httpStatus = getHTTPStatusFromCode(code) - } else { - // Handle standard errors | 处理标准错误 - code = core.CodeServerError - message = err.Error() - httpStatus = fiber.StatusInternalServerError - } - - return c.Status(httpStatus).JSON(fiber.Map{ - "code": code, - "message": message, - "error": err.Error(), - }) -} - -// writeSuccessResponse writes a standardized success response | 写入标准化的成功响应 -func writeSuccessResponse(c *fiber.Ctx, data interface{}) error { - return c.JSON(fiber.Map{ - "code": core.CodeSuccess, - "message": "success", - "data": data, - }) -} - -// getHTTPStatusFromCode converts Sa-Token error code to HTTP status | 将Sa-Token错误码转换为HTTP状态码 -func getHTTPStatusFromCode(code int) int { - switch code { - case core.CodeNotLogin: - return fiber.StatusUnauthorized - case core.CodePermissionDenied: - return fiber.StatusForbidden - case core.CodeBadRequest: - return fiber.StatusBadRequest - case core.CodeNotFound: - return fiber.StatusNotFound - case core.CodeServerError: - return fiber.StatusInternalServerError - default: - return fiber.StatusInternalServerError - } -} diff --git a/integrations/gf/README.md b/integrations/gf/README.md deleted file mode 100644 index 109a767..0000000 --- a/integrations/gf/README.md +++ /dev/null @@ -1 +0,0 @@ -# sa-token-go 集成 goframe 框架 \ No newline at end of file diff --git a/integrations/gf/annotation.go b/integrations/gf/annotation.go index d1e0630..c7a25fe 100644 --- a/integrations/gf/annotation.go +++ b/integrations/gf/annotation.go @@ -1,6 +1,8 @@ +// @Author daixk 2025/12/28 1:27:00 package gf import ( + "context" "strings" "github.com/click33/sa-token-go/core" @@ -10,17 +12,19 @@ import ( // Annotation annotation structure | 注解结构体 type Annotation struct { - CheckLogin bool `json:"checkLogin"` - CheckRole []string `json:"checkRole"` - CheckPermission []string `json:"checkPermission"` - CheckDisable bool `json:"checkDisable"` - Ignore bool `json:"ignore"` + AuthType string `json:"authType"` // Optional: specify auth type | 可选:指定认证类型 + CheckLogin bool `json:"checkLogin"` // Check login | 检查登录 + CheckRole []string `json:"checkRole"` // Check roles | 检查角色 + CheckPermission []string `json:"checkPermission"` // Check permissions | 检查权限 + CheckDisable bool `json:"checkDisable"` // Check disable status | 检查封禁状态 + Ignore bool `json:"ignore"` // Ignore authentication | 忽略认证 + LogicType LogicType `json:"logicType"` // OR or AND logic (default: OR) | OR 或 AND 逻辑(默认: OR) } // GetHandler gets handler with annotations | 获取带注解的处理器 -func GetHandler(handler ghttp.HandlerFunc, annotations ...*Annotation) ghttp.HandlerFunc { +func GetHandler(ctx context.Context, handler ghttp.HandlerFunc, failFunc func(r *ghttp.Request, err error), annotations ...*Annotation) ghttp.HandlerFunc { return func(r *ghttp.Request) { - // Check if authentication should be ignored | 检查是否忽略认证 + // Ignore authentication | 忽略认证直接放行 if len(annotations) > 0 && annotations[0].Ignore { if handler != nil { handler(r) @@ -30,62 +34,115 @@ func GetHandler(handler ghttp.HandlerFunc, annotations ...*Annotation) ghttp.Han return } - // Get token from context using configured TokenName | 从上下文获取Token(使用配置的TokenName) - ctx := NewGFContext(r) - saCtx := core.NewContext(ctx, stputil.GetManager()) - token := saCtx.GetTokenValue() - if token == "" { - writeErrorResponse(r, core.NewNotLoginError()) + // Check if any authentication is needed | 检查是否需要任何认证 + ann := &Annotation{} + if len(annotations) > 0 { + ann = annotations[0] + } + + // No authentication required | 无需任何认证 + needAuth := ann.CheckLogin || ann.CheckDisable || len(ann.CheckPermission) > 0 || len(ann.CheckRole) > 0 + if !needAuth { + if handler != nil { + handler(r) + } else { + r.Middleware.Next() + } return } - // Check login | 检查登录 - if !stputil.IsLogin(token) { - writeErrorResponse(r, core.NewNotLoginError()) + // Get manager-example | 获取 Manager + mgr, err := stputil.GetManager(ann.AuthType) + if err != nil { + if failFunc != nil { + failFunc(r, err) + } else { + writeErrorResponse(r, err) + } return } - // Get login ID | 获取登录ID - loginID, err := stputil.GetLoginID(token) + // Get SaTokenContext (reuse cached context) | 获取 SaTokenContext(复用缓存上下文) + saCtx := getSaContext(r, mgr) + token := saCtx.GetTokenValue() + + // Check login | 检查登录 + isLogin, err := mgr.IsLogin(ctx, token) if err != nil { - writeErrorResponse(r, err) + if failFunc != nil { + failFunc(r, err) + } else { + writeErrorResponse(r, err) + } + return + } + if !isLogin { + if failFunc != nil { + failFunc(r, core.NewNotLoginError()) + } else { + writeErrorResponse(r, core.NewNotLoginError()) + } return } + // Get loginID for further checks | 获取 loginID 用于后续检查 + var loginID string + if ann.CheckDisable || len(ann.CheckPermission) > 0 || len(ann.CheckRole) > 0 { + loginID, err = mgr.GetLoginIDNotCheck(ctx, token) + if err != nil { + if failFunc != nil { + failFunc(r, err) + } else { + writeErrorResponse(r, err) + } + return + } + } + // Check if account is disabled | 检查是否被封禁 - if len(annotations) > 0 && annotations[0].CheckDisable { - if stputil.IsDisable(loginID) { - writeErrorResponse(r, core.NewAccountDisabledError(loginID)) + if ann.CheckDisable { + if mgr.IsDisable(ctx, loginID) { + if failFunc != nil { + failFunc(r, core.NewAccountDisabledError(loginID)) + } else { + writeErrorResponse(r, core.NewAccountDisabledError(loginID)) + } return } } // Check permission | 检查权限 - if len(annotations) > 0 && len(annotations[0].CheckPermission) > 0 { - hasPermission := false - for _, perm := range annotations[0].CheckPermission { - if stputil.HasPermission(loginID, strings.TrimSpace(perm)) { - hasPermission = true - break - } + if len(ann.CheckPermission) > 0 { + var ok bool + if ann.LogicType == LogicAnd { + ok = mgr.HasPermissionsAnd(ctx, loginID, ann.CheckPermission) + } else { + ok = mgr.HasPermissionsOr(ctx, loginID, ann.CheckPermission) } - if !hasPermission { - writeErrorResponse(r, core.NewPermissionDeniedError(strings.Join(annotations[0].CheckPermission, ","))) + if !ok { + if failFunc != nil { + failFunc(r, core.NewPermissionDeniedError(strings.Join(ann.CheckPermission, ","))) + } else { + writeErrorResponse(r, core.NewPermissionDeniedError(strings.Join(ann.CheckPermission, ","))) + } return } } // Check role | 检查角色 - if len(annotations) > 0 && len(annotations[0].CheckRole) > 0 { - hasRole := false - for _, role := range annotations[0].CheckRole { - if stputil.HasRole(loginID, strings.TrimSpace(role)) { - hasRole = true - break - } + if len(ann.CheckRole) > 0 { + var ok bool + if ann.LogicType == LogicAnd { + ok = mgr.HasRolesAnd(ctx, loginID, ann.CheckRole) + } else { + ok = mgr.HasRolesOr(ctx, loginID, ann.CheckRole) } - if !hasRole { - writeErrorResponse(r, core.NewRoleDeniedError(strings.Join(annotations[0].CheckRole, ","))) + if !ok { + if failFunc != nil { + failFunc(r, core.NewRoleDeniedError(strings.Join(ann.CheckRole, ","))) + } else { + writeErrorResponse(r, core.NewRoleDeniedError(strings.Join(ann.CheckRole, ","))) + } return } } @@ -100,26 +157,171 @@ func GetHandler(handler ghttp.HandlerFunc, annotations ...*Annotation) ghttp.Han } // CheckLoginMiddleware decorator for login checking | 检查登录装饰器 -func CheckLoginMiddleware() ghttp.HandlerFunc { - return GetHandler(nil, &Annotation{CheckLogin: true}) +func CheckLoginMiddleware( + ctx context.Context, + handler ghttp.HandlerFunc, + failFunc func(r *ghttp.Request, err error), + authType ...string, +) ghttp.HandlerFunc { + ann := &Annotation{CheckLogin: true} + if len(authType) > 0 { + ann.AuthType = authType[0] + } + return GetHandler(ctx, handler, failFunc, ann) } // CheckRoleMiddleware decorator for role checking | 检查角色装饰器 -func CheckRoleMiddleware(roles ...string) ghttp.HandlerFunc { - return GetHandler(nil, &Annotation{CheckRole: roles}) +func CheckRoleMiddleware( + ctx context.Context, + roles []string, + handler ghttp.HandlerFunc, + failFunc func(r *ghttp.Request, err error), + authType ...string, +) ghttp.HandlerFunc { + ann := &Annotation{CheckRole: roles} + if len(authType) > 0 { + ann.AuthType = authType[0] + } + return GetHandler(ctx, handler, failFunc, ann) } // CheckPermissionMiddleware decorator for permission checking | 检查权限装饰器 -func CheckPermissionMiddleware(perms ...string) ghttp.HandlerFunc { - return GetHandler(nil, &Annotation{CheckPermission: perms}) +func CheckPermissionMiddleware( + ctx context.Context, + perms []string, + handler ghttp.HandlerFunc, + failFunc func(r *ghttp.Request, err error), + authType ...string, +) ghttp.HandlerFunc { + ann := &Annotation{CheckPermission: perms} + if len(authType) > 0 { + ann.AuthType = authType[0] + } + return GetHandler(ctx, handler, failFunc, ann) } // CheckDisableMiddleware decorator for checking if account is disabled | 检查是否被封禁装饰器 -func CheckDisableMiddleware() ghttp.HandlerFunc { - return GetHandler(nil, &Annotation{CheckDisable: true}) +func CheckDisableMiddleware( + ctx context.Context, + handler ghttp.HandlerFunc, + failFunc func(r *ghttp.Request, err error), + authType ...string, +) ghttp.HandlerFunc { + ann := &Annotation{CheckDisable: true} + if len(authType) > 0 { + ann.AuthType = authType[0] + } + return GetHandler(ctx, handler, failFunc, ann) } // IgnoreMiddleware decorator to ignore authentication | 忽略认证装饰器 -func IgnoreMiddleware() ghttp.HandlerFunc { - return GetHandler(nil, &Annotation{Ignore: true}) +func IgnoreMiddleware( + ctx context.Context, + handler ghttp.HandlerFunc, + failFunc func(r *ghttp.Request, err error), +) ghttp.HandlerFunc { + ann := &Annotation{Ignore: true} + return GetHandler(ctx, handler, failFunc, ann) +} + +// ============ Combined Middleware | 组合中间件 ============ + +// CheckLoginAndRoleMiddleware checks login and role | 检查登录和角色 +func CheckLoginAndRoleMiddleware( + ctx context.Context, + roles []string, + handler ghttp.HandlerFunc, + failFunc func(r *ghttp.Request, err error), + authType ...string, +) ghttp.HandlerFunc { + ann := &Annotation{CheckLogin: true, CheckRole: roles} + if len(authType) > 0 { + ann.AuthType = authType[0] + } + return GetHandler(ctx, handler, failFunc, ann) +} + +// CheckLoginAndPermissionMiddleware checks login and permission | 检查登录和权限 +func CheckLoginAndPermissionMiddleware( + ctx context.Context, + perms []string, + handler ghttp.HandlerFunc, + failFunc func(r *ghttp.Request, err error), + authType ...string, +) ghttp.HandlerFunc { + ann := &Annotation{CheckLogin: true, CheckPermission: perms} + if len(authType) > 0 { + ann.AuthType = authType[0] + } + return GetHandler(ctx, handler, failFunc, ann) +} + +// CheckAllMiddleware checks login, role, permission and disable status | 全面检查 +func CheckAllMiddleware( + ctx context.Context, + roles []string, + perms []string, + handler ghttp.HandlerFunc, + failFunc func(r *ghttp.Request, err error), + authType ...string, +) ghttp.HandlerFunc { + ann := &Annotation{CheckLogin: true, CheckRole: roles, CheckPermission: perms} + if len(authType) > 0 { + ann.AuthType = authType[0] + } + return GetHandler(ctx, handler, failFunc, ann) +} + +// ============ Route Group Helper | 路由组辅助函数 ============ + +// AuthGroup creates a route group with authentication | 创建带认证的路由组 +func AuthGroup( + ctx context.Context, + group *ghttp.RouterGroup, + handler ghttp.HandlerFunc, + failFunc func(r *ghttp.Request, err error), + authType ...string, +) *ghttp.RouterGroup { + group.Middleware(CheckLoginMiddleware(ctx, handler, failFunc, authType...)) + return group +} + +// RoleGroup creates a route group with role checking | 创建带角色检查的路由组 +func RoleGroup( + ctx context.Context, + group *ghttp.RouterGroup, + roles []string, + handler ghttp.HandlerFunc, + failFunc func(r *ghttp.Request, err error), + authType ...string, +) *ghttp.RouterGroup { + group.Middleware(CheckLoginAndRoleMiddleware(ctx, roles, handler, failFunc, authType...)) + return group +} + +// PermissionGroup creates a route group with permission checking | 创建带权限检查的路由组 +func PermissionGroup( + ctx context.Context, + group *ghttp.RouterGroup, + perms []string, + handler ghttp.HandlerFunc, + failFunc func(r *ghttp.Request, err error), + authType ...string, +) *ghttp.RouterGroup { + group.Middleware(CheckLoginAndPermissionMiddleware(ctx, perms, handler, failFunc, authType...)) + return group +} + +// RoleAndPermissionGroup creates a route group with role and permission checking | 创建带角色和权限检查的路由组 +func RoleAndPermissionGroup( + ctx context.Context, + group *ghttp.RouterGroup, + roles []string, + perms []string, + handler ghttp.HandlerFunc, + failFunc func(r *ghttp.Request, err error), + authType ...string, +) *ghttp.RouterGroup { + group.Middleware(CheckAllMiddleware(ctx, roles, perms, handler, failFunc, authType...)) + return group } diff --git a/integrations/gf/context.go b/integrations/gf/context.go index c28eae0..6ec3b2c 100644 --- a/integrations/gf/context.go +++ b/integrations/gf/context.go @@ -12,6 +12,13 @@ type GFContext struct { aborted bool } +// NewGFContext creates a GF context adapter | 创建GF上下文适配器 +func NewGFContext(c *ghttp.Request) adapter.RequestContext { + return &GFContext{ + c: c, + } +} + // Get implements adapter.RequestContext. func (g *GFContext) Get(key string) (interface{}, bool) { v := g.c.Get(key) @@ -71,13 +78,6 @@ func (g *GFContext) SetHeader(key string, value string) { g.c.Header.Set(key, value) } -// NewGFContext creates a GF context adapter | 创建GF上下文适配器 -func NewGFContext(c *ghttp.Request) adapter.RequestContext { - return &GFContext{ - c: c, - } -} - // ============ Additional Required Methods | 额外必需的方法 ============ // GetHeaders implements adapter.RequestContext. @@ -123,7 +123,7 @@ func (g *GFContext) SetCookieWithOptions(options *adapter.CookieOptions) { HttpOnly: options.HttpOnly, SameSite: http.SameSite(0), // Default to SameSiteNone } - + // Set SameSite attribute switch options.SameSite { case "Strict": @@ -133,7 +133,7 @@ func (g *GFContext) SetCookieWithOptions(options *adapter.CookieOptions) { case "None": cookie.SameSite = http.SameSiteNoneMode } - + g.c.Cookie.SetHttpCookie(cookie) } diff --git a/integrations/gf/export.go b/integrations/gf/export.go index 1c73b61..7a95523 100644 --- a/integrations/gf/export.go +++ b/integrations/gf/export.go @@ -1,364 +1,952 @@ package gf import ( + "context" + "github.com/gogf/gf/v2/os/glog" "time" + "github.com/click33/sa-token-go/codec/json" + "github.com/click33/sa-token-go/codec/msgpack" "github.com/click33/sa-token-go/core" + "github.com/click33/sa-token-go/core/adapter" + "github.com/click33/sa-token-go/core/builder" + "github.com/click33/sa-token-go/core/config" + "github.com/click33/sa-token-go/core/listener" + "github.com/click33/sa-token-go/core/manager" + "github.com/click33/sa-token-go/core/oauth2" + "github.com/click33/sa-token-go/core/security" + "github.com/click33/sa-token-go/core/session" + "github.com/click33/sa-token-go/generator/sgenerator" + "github.com/click33/sa-token-go/log/gf" + "github.com/click33/sa-token-go/log/nop" + "github.com/click33/sa-token-go/log/slog" + "github.com/click33/sa-token-go/pool/ants" + "github.com/click33/sa-token-go/storage/memory" + "github.com/click33/sa-token-go/storage/redis" "github.com/click33/sa-token-go/stputil" ) -// ============ Re-export core types | 重新导出核心类型 ============ +// ============ Type Aliases | 类型别名 ============ -// Configuration related types | 配置相关类型 type ( - Config = core.Config - CookieConfig = core.CookieConfig - TokenStyle = core.TokenStyle + // Config 配置 + Config = config.Config + // Manager 管理器 + Manager = manager.Manager + // Session 会话 + Session = session.Session + // TokenInfo Token信息 + TokenInfo = manager.TokenInfo + // DisableInfo 封禁信息 + DisableInfo = manager.DisableInfo + // Builder 构建器 + Builder = builder.Builder + // SaTokenError 错误类型 + SaTokenError = core.SaTokenError + // Event 事件类型 + Event = listener.Event + // EventData 事件数据 + EventData = listener.EventData + // Listener 事件监听器 + Listener = listener.Listener + // ListenerConfig 监听器配置 + ListenerConfig = listener.ListenerConfig + // RefreshTokenInfo 刷新令牌信息 + RefreshTokenInfo = security.RefreshTokenInfo + // AccessTokenInfo 访问令牌信息 + AccessTokenInfo = security.AccessTokenInfo + // OAuth2Client OAuth2客户端 + OAuth2Client = oauth2.Client + // OAuth2AccessToken OAuth2访问令牌 + OAuth2AccessToken = oauth2.AccessToken + // AuthorizationCode 授权码 + AuthorizationCode = oauth2.AuthorizationCode + // OAuth2TokenRequest OAuth2令牌请求 + OAuth2TokenRequest = oauth2.TokenRequest + // OAuth2GrantType OAuth2授权类型 + OAuth2GrantType = oauth2.GrantType + // OAuth2UserValidator OAuth2用户验证器 + OAuth2UserValidator = oauth2.UserValidator + // Storage 存储接口 + Storage = adapter.Storage + // Codec 编解码接口 + Codec = adapter.Codec + // Log 日志接口 + Log = adapter.Log + // Pool 协程池接口 + Pool = adapter.Pool + // Generator 生成器接口 + Generator = adapter.Generator + + // ============ Codec Types | 编解码器类型 ============ + + // JSONSerializer JSON编解码器 + JSONSerializer = json.JSONSerializer + // MsgPackSerializer MsgPack编解码器 + MsgPackSerializer = msgpack.MsgPackSerializer + + // ============ Storage Types | 存储类型 ============ + + // MemoryStorage 内存存储 + MemoryStorage = memory.Storage + // RedisStorage Redis存储 + RedisStorage = redis.Storage + // RedisConfig Redis配置 + RedisConfig = redis.Config + // RedisBuilder Redis构建器 + RedisBuilder = redis.Builder + + // ============ Logger Types | 日志类型 ============ + + // SlogLogger 标准日志实现 + SlogLogger = slog.Logger + // SlogLoggerConfig 标准日志配置 + SlogLoggerConfig = slog.LoggerConfig + // SlogLogLevel 日志级别 + SlogLogLevel = slog.LogLevel + // NopLogger 空日志实现 + NopLogger = nop.NopLogger + + // ============ Generator Types | 生成器类型 ============ + + // TokenGenerator Token生成器 + TokenGenerator = sgenerator.Generator + // TokenStyle Token风格 + TokenStyle = adapter.TokenStyle + + // ============ Pool Types | 协程池类型 ============ + + // RenewPoolManager 续期池管理器 + RenewPoolManager = ants.RenewPoolManager + // RenewPoolConfig 续期池配置 + RenewPoolConfig = ants.RenewPoolConfig ) -// Token style constants | Token风格常量 +// ============ Error Codes | 错误码 ============ + const ( - TokenStyleUUID = core.TokenStyleUUID - TokenStyleSimple = core.TokenStyleSimple - TokenStyleRandom32 = core.TokenStyleRandom32 - TokenStyleRandom64 = core.TokenStyleRandom64 - TokenStyleRandom128 = core.TokenStyleRandom128 - TokenStyleJWT = core.TokenStyleJWT - TokenStyleHash = core.TokenStyleHash - TokenStyleTimestamp = core.TokenStyleTimestamp - TokenStyleTik = core.TokenStyleTik + CodeSuccess = core.CodeSuccess + CodeBadRequest = core.CodeBadRequest + CodeNotLogin = core.CodeNotLogin + CodePermissionDenied = core.CodePermissionDenied + CodeNotFound = core.CodeNotFound + CodeServerError = core.CodeServerError + CodeTokenInvalid = core.CodeTokenInvalid + CodeTokenExpired = core.CodeTokenExpired + CodeAccountDisabled = core.CodeAccountDisabled + CodeKickedOut = core.CodeKickedOut + CodeActiveTimeout = core.CodeActiveTimeout + CodeMaxLoginCount = core.CodeMaxLoginCount + CodeStorageError = core.CodeStorageError + CodeInvalidParameter = core.CodeInvalidParameter + CodeSessionError = core.CodeSessionError ) -// Core types | 核心类型 -type ( - Manager = core.Manager - TokenInfo = core.TokenInfo - Session = core.Session - TokenGenerator = core.TokenGenerator - SaTokenContext = core.SaTokenContext - Builder = core.Builder - NonceManager = core.NonceManager - RefreshTokenInfo = core.RefreshTokenInfo - RefreshTokenManager = core.RefreshTokenManager - OAuth2Server = core.OAuth2Server - OAuth2Client = core.OAuth2Client - OAuth2AccessToken = core.OAuth2AccessToken - OAuth2GrantType = core.OAuth2GrantType -) +// ============ Errors | 错误变量 ============ -// Adapter interfaces | 适配器接口 -type ( - Storage = core.Storage - RequestContext = core.RequestContext +var ( + // Authentication Errors | 认证错误 + ErrNotLogin = core.ErrNotLogin + ErrTokenInvalid = core.ErrTokenInvalid + ErrTokenExpired = core.ErrTokenExpired + ErrTokenKickout = core.ErrTokenKickout + ErrTokenReplaced = core.ErrTokenReplaced + ErrInvalidLoginID = core.ErrInvalidLoginID + ErrInvalidDevice = core.ErrInvalidDevice + ErrTokenNotFound = core.ErrTokenNotFound + + // Authorization Errors | 授权错误 + ErrPermissionDenied = core.ErrPermissionDenied + ErrRoleDenied = core.ErrRoleDenied + + // Account Errors | 账号错误 + ErrAccountDisabled = core.ErrAccountDisabled + ErrAccountNotFound = core.ErrAccountNotFound + ErrLoginLimitExceeded = core.ErrLoginLimitExceeded + + // Session Errors | 会话错误 + ErrSessionNotFound = core.ErrSessionNotFound + ErrActiveTimeout = core.ErrActiveTimeout + ErrSessionInvalidDataKey = core.ErrSessionInvalidDataKey + ErrSessionIDEmpty = core.ErrSessionIDEmpty + + // Security Errors | 安全错误 + ErrInvalidNonce = core.ErrInvalidNonce + ErrRefreshTokenExpired = core.ErrRefreshTokenExpired + ErrNonceInvalidRefreshToken = core.ErrNonceInvalidRefreshToken + ErrInvalidLoginIDEmpty = core.ErrInvalidLoginIDEmpty + + // OAuth2 Errors | OAuth2错误 + ErrClientOrClientIDEmpty = core.ErrClientOrClientIDEmpty + ErrClientNotFound = core.ErrClientNotFound + ErrUserIDEmpty = core.ErrUserIDEmpty + ErrInvalidRedirectURI = core.ErrInvalidRedirectURI + ErrInvalidClientCredentials = core.ErrInvalidClientCredentials + ErrInvalidAuthCode = core.ErrInvalidAuthCode + ErrAuthCodeUsed = core.ErrAuthCodeUsed + ErrAuthCodeExpired = core.ErrAuthCodeExpired + ErrClientMismatch = core.ErrClientMismatch + ErrRedirectURIMismatch = core.ErrRedirectURIMismatch + ErrInvalidAccessToken = core.ErrInvalidAccessToken + ErrInvalidRefreshToken = core.ErrInvalidRefreshToken + ErrInvalidScope = core.ErrInvalidScope + + // System Errors | 系统错误 + ErrStorageUnavailable = core.ErrStorageUnavailable + ErrSerializeFailed = core.ErrSerializeFailed + ErrDeserializeFailed = core.ErrDeserializeFailed + ErrTypeConvert = core.ErrTypeConvert ) -// Event related types | 事件相关类型 -type ( - EventListener = core.EventListener - EventManager = core.EventManager - EventData = core.EventData - Event = core.Event - ListenerFunc = core.ListenerFunc - ListenerConfig = core.ListenerConfig -) +// ============ Error Constructors | 错误构造函数 ============ -// Event constants | 事件常量 -const ( - EventLogin = core.EventLogin - EventLogout = core.EventLogout - EventKickout = core.EventKickout - EventDisable = core.EventDisable - EventUntie = core.EventUntie - EventRenew = core.EventRenew - EventCreateSession = core.EventCreateSession - EventDestroySession = core.EventDestroySession - EventPermissionCheck = core.EventPermissionCheck - EventRoleCheck = core.EventRoleCheck - EventAll = core.EventAll +var ( + NewError = core.NewError + NewErrorWithContext = core.NewErrorWithContext + NewNotLoginError = core.NewNotLoginError + NewPermissionDeniedError = core.NewPermissionDeniedError + NewRoleDeniedError = core.NewRoleDeniedError + NewAccountDisabledError = core.NewAccountDisabledError ) -// OAuth2 grant type constants | OAuth2授权类型常量 -const ( - GrantTypeAuthorizationCode = core.GrantTypeAuthorizationCode - GrantTypeRefreshToken = core.GrantTypeRefreshToken - GrantTypeClientCredentials = core.GrantTypeClientCredentials - GrantTypePassword = core.GrantTypePassword -) +// ============ Error Checking Helpers | 错误检查辅助函数 ============ -// Utility functions | 工具函数 var ( - RandomString = core.RandomString - IsEmpty = core.IsEmpty - IsNotEmpty = core.IsNotEmpty - DefaultString = core.DefaultString - ContainsString = core.ContainsString - RemoveString = core.RemoveString - UniqueStrings = core.UniqueStrings - MergeStrings = core.MergeStrings - MatchPattern = core.MatchPattern + IsNotLoginError = core.IsNotLoginError + IsPermissionDeniedError = core.IsPermissionDeniedError + IsAccountDisabledError = core.IsAccountDisabledError + IsTokenError = core.IsTokenError + GetErrorCode = core.GetErrorCode ) -// ============ Core constructor functions | 核心构造函数 ============ +// ============ Manager Management | Manager 管理 ============ + +// SetManager stores the manager-example in the global map using the specified autoType | 使用指定的 autoType 将管理器存储在全局 map 中 +func SetManager(mgr *manager.Manager) { + stputil.SetManager(mgr) +} + +// GetManager retrieves the manager-example from the global map using the specified autoType | 使用指定的 autoType 从全局 map 中获取管理器 +func GetManager(autoType ...string) (*manager.Manager, error) { + return stputil.GetManager(autoType...) +} -// DefaultConfig returns default configuration | 返回默认配置 -func DefaultConfig() *Config { - return core.DefaultConfig() +// DeleteManager delete the specific manager-example for the given autoType and releases resources | 删除指定的管理器并释放资源 +func DeleteManager(autoType ...string) error { + return stputil.DeleteManager(autoType...) } -// NewManager creates a new authentication manager | 创建新的认证管理器 -func NewManager(storage Storage, cfg *Config) *Manager { - return core.NewManager(storage, cfg) +// DeleteAllManager delete all managers in the global map and releases resources | 关闭所有管理器并释放资源 +func DeleteAllManager() { + stputil.DeleteAllManager() } -// NewContext creates a new Sa-Token context | 创建新的Sa-Token上下文 -func NewContext(ctx RequestContext, mgr *Manager) *SaTokenContext { - return core.NewContext(ctx, mgr) +// ============ Builder & Config | 构建器和配置 ============ + +// NewDefaultBuild creates a new default builder | 创建默认构建器 +func NewDefaultBuild() *builder.Builder { + return builder.NewBuilder() } -// NewSession creates a new session | 创建新的Session -func NewSession(id string, storage Storage, prefix string) *Session { - return core.NewSession(id, storage, prefix) +// NewDefaultConfig creates a new default config | 创建默认配置 +func NewDefaultConfig() *config.Config { + return config.DefaultConfig() } -// LoadSession loads an existing session | 加载已存在的Session -func LoadSession(id string, storage Storage, prefix string) (*Session, error) { - return core.LoadSession(id, storage, prefix) +// DefaultLoggerConfig returns the default logger config | 返回默认日志配置 +func DefaultLoggerConfig() *slog.LoggerConfig { + return slog.DefaultLoggerConfig() } -// NewTokenGenerator creates a new token generator | 创建新的Token生成器 -func NewTokenGenerator(cfg *Config) *TokenGenerator { - return core.NewTokenGenerator(cfg) +// DefaultRenewPoolConfig returns the default renew pool config | 返回默认续期池配置 +func DefaultRenewPoolConfig() *ants.RenewPoolConfig { + return ants.DefaultRenewPoolConfig() } -// NewEventManager creates a new event manager | 创建新的事件管理器 -func NewEventManager() *EventManager { - return core.NewEventManager() +// ============ Codec Constructors | 编解码器构造函数 ============ + +// NewJSONSerializer creates a new JSON serializer | 创建JSON序列化器 +func NewJSONSerializer() *json.JSONSerializer { + return json.NewJSONSerializer() } -// NewBuilder creates a new builder for fluent configuration | 创建新的Builder构建器(用于流式配置) -func NewBuilder() *Builder { - return core.NewBuilder() +// NewMsgPackSerializer creates a new MsgPack serializer | 创建MsgPack序列化器 +func NewMsgPackSerializer() *msgpack.MsgPackSerializer { + return msgpack.NewMsgPackSerializer() } -// NewNonceManager creates a new nonce manager | 创建新的Nonce管理器 -func NewNonceManager(storage Storage, prefix string, ttl ...int64) *NonceManager { - return core.NewNonceManager(storage, prefix, ttl...) +// ============ Storage Constructors | 存储构造函数 ============ + +// NewMemoryStorage creates a new memory storage | 创建内存存储 +func NewMemoryStorage() *memory.Storage { + return memory.NewStorage() } -// NewRefreshTokenManager creates a new refresh token manager | 创建新的刷新令牌管理器 -func NewRefreshTokenManager(storage Storage, prefix string, cfg *Config) *RefreshTokenManager { - return core.NewRefreshTokenManager(storage, prefix, cfg) +// NewMemoryStorageWithCleanupInterval creates a new memory storage with cleanup interval | 创建带清理间隔的内存存储 +func NewMemoryStorageWithCleanupInterval(interval time.Duration) *memory.Storage { + return memory.NewStorageWithCleanupInterval(interval) } -// NewOAuth2Server creates a new OAuth2 server | 创建新的OAuth2服务器 -func NewOAuth2Server(storage Storage, prefix string) *OAuth2Server { - return core.NewOAuth2Server(storage, prefix) +// NewRedisStorage creates a new Redis storage from URL | 通过URL创建Redis存储 +func NewRedisStorage(url string) (*redis.Storage, error) { + return redis.NewStorage(url) } -// ============ Global StpUtil functions | 全局StpUtil函数 ============ +// NewRedisStorageFromConfig creates a new Redis storage from config | 通过配置创建Redis存储 +func NewRedisStorageFromConfig(cfg *redis.Config) (*redis.Storage, error) { + return redis.NewStorageFromConfig(cfg) +} -// SetManager sets the global Manager (must be called first) | 设置全局Manager(必须先调用此方法) -func SetManager(mgr *Manager) { - stputil.SetManager(mgr) +// NewRedisBuilder creates a new Redis builder | 创建Redis构建器 +func NewRedisBuilder() *redis.Builder { + return redis.NewBuilder() +} + +// ============ Logger Constructors | 日志构造函数 ============ + +// NewSlogLogger creates a new slog logger with config | 使用配置创建标准日志器 +func NewSlogLogger(cfg *slog.LoggerConfig) (*slog.Logger, error) { + return slog.NewLoggerWithConfig(cfg) +} + +// NewNopLogger creates a new no-op logger | 创建空日志器 +func NewNopLogger() *nop.NopLogger { + return nop.NewNopLogger() } -// GetManager gets the global Manager | 获取全局Manager -func GetManager() *Manager { - return stputil.GetManager() +// NewGfLogger creates a GF logger adapter | 创建GoFrame日志适配器 +func NewGfLogger(ctx context.Context, log *glog.Logger) *gf.GFLogger { + return gf.NewGFLogger(ctx, log) } +// ============ Generator Constructors | 生成器构造函数 ============ + +// NewTokenGenerator creates a new token generator | 创建Token生成器 +func NewTokenGenerator(timeout int64, tokenStyle adapter.TokenStyle, jwtSecretKey string) *sgenerator.Generator { + return sgenerator.NewGenerator(timeout, tokenStyle, jwtSecretKey) +} + +// NewDefaultTokenGenerator creates a new default token generator | 创建默认Token生成器 +func NewDefaultTokenGenerator() *sgenerator.Generator { + return sgenerator.NewDefaultGenerator() +} + +// ============ Pool Constructors | 协程池构造函数 ============ + +// NewRenewPoolManager creates a new renew pool manager-example with default config | 使用默认配置创建续期池管理器 +func NewRenewPoolManager() *ants.RenewPoolManager { + return ants.NewRenewPoolManagerWithDefaultConfig() +} + +// NewRenewPoolManagerWithConfig creates a new renew pool manager-example with config | 使用配置创建续期池管理器 +func NewRenewPoolManagerWithConfig(cfg *ants.RenewPoolConfig) (*ants.RenewPoolManager, error) { + return ants.NewRenewPoolManagerWithConfig(cfg) +} + +// ============ Token Style Constants | Token风格常量 ============ + +const ( + // TokenStyleUUID UUID style | UUID风格 + TokenStyleUUID = adapter.TokenStyleUUID + // TokenStyleSimple Simple random string | 简单随机字符串 + TokenStyleSimple = adapter.TokenStyleSimple + // TokenStyleRandom32 32-bit random string | 32位随机字符串 + TokenStyleRandom32 = adapter.TokenStyleRandom32 + // TokenStyleRandom64 64-bit random string | 64位随机字符串 + TokenStyleRandom64 = adapter.TokenStyleRandom64 + // TokenStyleRandom128 128-bit random string | 128位随机字符串 + TokenStyleRandom128 = adapter.TokenStyleRandom128 + // TokenStyleJWT JWT style | JWT风格 + TokenStyleJWT = adapter.TokenStyleJWT + // TokenStyleHash SHA256 hash-based style | SHA256哈希风格 + TokenStyleHash = adapter.TokenStyleHash + // TokenStyleTimestamp Timestamp-based style | 时间戳风格 + TokenStyleTimestamp = adapter.TokenStyleTimestamp + // TokenStyleTik Short ID style (like TikTok) | Tik风格短ID + TokenStyleTik = adapter.TokenStyleTik +) + +// ============ Log Level Constants | 日志级别常量 ============ + +const ( + // LogLevelDebug Debug level | 调试级别 + LogLevelDebug = adapter.LogLevelDebug + // LogLevelInfo Info level | 信息级别 + LogLevelInfo = adapter.LogLevelInfo + // LogLevelWarn Warn level | 警告级别 + LogLevelWarn = adapter.LogLevelWarn + // LogLevelError Error level | 错误级别 + LogLevelError = adapter.LogLevelError +) + // ============ Authentication | 登录认证 ============ // Login performs user login | 用户登录 -func Login(loginID interface{}, device ...string) (string, error) { - return stputil.Login(loginID, device...) +func Login(ctx context.Context, loginID interface{}, deviceOrAutoType ...string) (string, error) { + return stputil.Login(ctx, loginID, deviceOrAutoType...) } // LoginByToken performs login with specified token | 使用指定Token登录 -func LoginByToken(loginID interface{}, tokenValue string, device ...string) error { - return stputil.LoginByToken(loginID, tokenValue, device...) +func LoginByToken(ctx context.Context, tokenValue string, authType ...string) error { + return stputil.LoginByToken(ctx, tokenValue, authType...) } // Logout performs user logout | 用户登出 -func Logout(loginID interface{}, device ...string) error { - return stputil.Logout(loginID, device...) +func Logout(ctx context.Context, loginID interface{}, deviceOrAutoType ...string) error { + return stputil.Logout(ctx, loginID, deviceOrAutoType...) } // LogoutByToken performs logout by token | 根据Token登出 -func LogoutByToken(tokenValue string) error { - return stputil.LogoutByToken(tokenValue) +func LogoutByToken(ctx context.Context, tokenValue string, authType ...string) error { + return stputil.LogoutByToken(ctx, tokenValue, authType...) +} + +// Kickout kicks out a user session | 踢人下线 +func Kickout(ctx context.Context, loginID interface{}, deviceOrAutoType ...string) error { + return stputil.Kickout(ctx, loginID, deviceOrAutoType...) +} + +// KickoutByToken Kick user offline | 根据Token踢人下线 +func KickoutByToken(ctx context.Context, tokenValue string, authType ...string) error { + return stputil.KickoutByToken(ctx, tokenValue, authType...) +} + +// Replace user offline by login ID and device | 根据账号和设备顶人下线 +func Replace(ctx context.Context, loginID interface{}, deviceOrAutoType ...string) error { + return stputil.Replace(ctx, loginID, deviceOrAutoType...) } +// ReplaceByToken Replace user offline by token | 根据Token顶人下线 +func ReplaceByToken(ctx context.Context, tokenValue string, authType ...string) error { + return stputil.ReplaceByToken(ctx, tokenValue, authType...) +} + +// ============ Token Validation | Token验证 ============ + // IsLogin checks if the user is logged in | 检查用户是否已登录 -func IsLogin(tokenValue string) bool { - return stputil.IsLogin(tokenValue) +func IsLogin(ctx context.Context, tokenValue string, authType ...string) (bool, error) { + return stputil.IsLogin(ctx, tokenValue, authType...) } // CheckLogin checks login status (throws error if not logged in) | 检查登录状态(未登录抛出错误) -func CheckLogin(tokenValue string) error { - return stputil.CheckLogin(tokenValue) +func CheckLogin(ctx context.Context, tokenValue string, authType ...string) error { + return stputil.CheckLogin(ctx, tokenValue, authType...) } // GetLoginID gets the login ID from token | 从Token获取登录ID -func GetLoginID(tokenValue string) (string, error) { - return stputil.GetLoginID(tokenValue) +func GetLoginID(ctx context.Context, tokenValue string, authType ...string) (string, error) { + return stputil.GetLoginID(ctx, tokenValue, authType...) } -// GetLoginIDNotCheck gets login ID without checking | 获取登录ID(不检查) -func GetLoginIDNotCheck(tokenValue string) (string, error) { - return stputil.GetLoginIDNotCheck(tokenValue) +// GetLoginIDNotCheck gets login ID without checking | 获取登录ID(不检查登录状态) +func GetLoginIDNotCheck(ctx context.Context, tokenValue string, authType ...string) (string, error) { + return stputil.GetLoginIDNotCheck(ctx, tokenValue, authType...) } // GetTokenValue gets the token value for a login ID | 获取登录ID对应的Token值 -func GetTokenValue(loginID interface{}, device ...string) (string, error) { - return stputil.GetTokenValue(loginID, device...) +func GetTokenValue(ctx context.Context, loginID interface{}, deviceOrAutoType ...string) (string, error) { + return stputil.GetTokenValue(ctx, loginID, deviceOrAutoType...) } // GetTokenInfo gets token information | 获取Token信息 -func GetTokenInfo(tokenValue string) (*TokenInfo, error) { - return stputil.GetTokenInfo(tokenValue) +func GetTokenInfo(ctx context.Context, tokenValue string, authType ...string) (*manager.TokenInfo, error) { + return stputil.GetTokenInfo(ctx, tokenValue, authType...) } -// ============ Kickout | 踢人下线 ============ +// ============ Account Disable | 账号封禁 ============ -// Kickout kicks out a user session | 踢人下线 -func Kickout(loginID interface{}, device ...string) error { - return stputil.Kickout(loginID, device...) +// Disable disables an account for specified duration | 封禁账号(指定时长) +func Disable(ctx context.Context, loginID interface{}, duration time.Duration, authType ...string) error { + return stputil.Disable(ctx, loginID, duration, authType...) } -// ============ Account Disable | 账号封禁 ============ +// DisableByToken disables the account associated with the given token for a duration | 根据指定 Token 封禁其对应的账号 +func DisableByToken(ctx context.Context, tokenValue string, duration time.Duration, authType ...string) error { + return stputil.DisableByToken(ctx, tokenValue, duration, authType...) +} -// Disable disables an account for specified duration | 封禁账号(指定时长) -func Disable(loginID interface{}, duration time.Duration) error { - return stputil.Disable(loginID, duration) +// Untie re-enables a disabled account | 解封账号 +func Untie(ctx context.Context, loginID interface{}, authType ...string) error { + return stputil.Untie(ctx, loginID, authType...) +} + +// UntieByToken re-enables a disabled account by token | 根据Token解封账号 +func UntieByToken(ctx context.Context, tokenValue string, authType ...string) error { + return stputil.UntieByToken(ctx, tokenValue, authType...) } // IsDisable checks if an account is disabled | 检查账号是否被封禁 -func IsDisable(loginID interface{}) bool { - return stputil.IsDisable(loginID) +func IsDisable(ctx context.Context, loginID interface{}, authType ...string) bool { + return stputil.IsDisable(ctx, loginID, authType...) } -// CheckDisable checks if account is disabled (throws error if disabled) | 检查账号是否被封禁(被封禁则抛出错误) -func CheckDisableByToken(tokenValue string) error { - return stputil.CheckDisable(tokenValue) +// IsDisableByToken checks if an account is disabled by token | 根据Token检查账号是否被封禁 +func IsDisableByToken(ctx context.Context, tokenValue string, authType ...string) bool { + return stputil.IsDisableByToken(ctx, tokenValue, authType...) } -// GetDisableTime gets remaining disabled time | 获取账号剩余封禁时间 -func GetDisableTime(loginID interface{}) (int64, error) { - return stputil.GetDisableTime(loginID) +// GetDisableTime gets remaining disable time in seconds | 获取剩余封禁时间(秒) +func GetDisableTime(ctx context.Context, loginID interface{}, authType ...string) (int64, error) { + return stputil.GetDisableTime(ctx, loginID, authType...) } -// Untie unties/unlocks an account | 解除账号封禁 -func Untie(loginID interface{}) error { - return stputil.Untie(loginID) +// GetDisableTimeByToken gets remaining disable time by token | 根据Token获取剩余封禁时间(秒) +func GetDisableTimeByToken(ctx context.Context, tokenValue string, authType ...string) (int64, error) { + return stputil.GetDisableTimeByToken(ctx, tokenValue, authType...) } -// ============ Permission Check | 权限验证 ============ +// GetDisableInfo gets disable info | 获取封禁信息 +func GetDisableInfo(ctx context.Context, loginID interface{}, authType ...string) (*manager.DisableInfo, error) { + return stputil.GetDisableInfo(ctx, loginID, authType...) +} -// CheckPermission checks if the account has specified permission | 检查账号是否拥有指定权限 -func CheckPermissionByToken(tokenValue string, permission string) error { - return stputil.CheckPermission(tokenValue, permission) +// GetDisableInfoByToken gets disable info by token | 根据Token获取封禁信息 +func GetDisableInfoByToken(ctx context.Context, tokenValue string, authType ...string) (*manager.DisableInfo, error) { + return stputil.GetDisableInfoByToken(ctx, tokenValue, authType...) } -// HasPermission checks if the account has specified permission (returns bool) | 检查账号是否拥有指定权限(返回布尔值) -func HasPermission(loginID interface{}, permission string) bool { - return stputil.HasPermission(loginID, permission) +// ============ Session Management | Session管理 ============ + +// GetSession gets session by login ID | 根据登录ID获取Session +func GetSession(ctx context.Context, loginID interface{}, authType ...string) (*session.Session, error) { + return stputil.GetSession(ctx, loginID, authType...) +} + +// GetSessionByToken gets session by token | 根据Token获取Session +func GetSessionByToken(ctx context.Context, tokenValue string, authType ...string) (*session.Session, error) { + return stputil.GetSessionByToken(ctx, tokenValue, authType...) } -// CheckPermissionAnd checks if the account has all specified permissions (AND logic) | 检查账号是否拥有所有指定权限(AND逻辑) -func CheckPermissionAndByToken(tokenValue string, permissions []string) error { - return stputil.CheckPermissionAnd(tokenValue, permissions) +// DeleteSession deletes a session | 删除Session +func DeleteSession(ctx context.Context, loginID interface{}, authType ...string) error { + return stputil.DeleteSession(ctx, loginID, authType...) } -// CheckPermissionOr checks if the account has any of the specified permissions (OR logic) | 检查账号是否拥有指定权限中的任意一个(OR逻辑) -func CheckPermissionOrByToken(tokenValue string, permissions []string) error { - return stputil.CheckPermissionOr(tokenValue, permissions) +// DeleteSessionByToken Deletes session by token | 根据Token删除Session +func DeleteSessionByToken(ctx context.Context, tokenValue string, authType ...string) error { + return stputil.DeleteSessionByToken(ctx, tokenValue, authType...) } -// GetPermissionList gets the permission list for an account | 获取账号的权限列表 -func GetPermissionListByToken(tokenValue string) ([]string, error) { - return stputil.GetPermissionList(tokenValue) +// HasSession checks if session exists | 检查Session是否存在 +func HasSession(ctx context.Context, loginID interface{}, authType ...string) bool { + return stputil.HasSession(ctx, loginID, authType...) } -// ============ Role Check | 角色验证 ============ +// RenewSession renews session TTL | 续期Session +func RenewSession(ctx context.Context, loginID interface{}, ttl time.Duration, authType ...string) error { + return stputil.RenewSession(ctx, loginID, ttl, authType...) +} -// CheckRole checks if the account has specified role | 检查账号是否拥有指定角色 -func CheckRoleByToken(tokenValue string, role string) error { - return stputil.CheckRole(tokenValue, role) +// RenewSessionByToken renews session TTL by token | 根据Token续期Session +func RenewSessionByToken(ctx context.Context, tokenValue string, ttl time.Duration, authType ...string) error { + return stputil.RenewSessionByToken(ctx, tokenValue, ttl, authType...) } -// HasRole checks if the account has specified role (returns bool) | 检查账号是否拥有指定角色(返回布尔值) -func HasRole(loginID interface{}, role string) bool { - return stputil.HasRole(loginID, role) +// ============ Permission Verification | 权限验证 ============ + +// SetPermissions sets permissions for a login ID | 设置用户权限 +func SetPermissions(ctx context.Context, loginID interface{}, permissions []string, authType ...string) error { + return stputil.SetPermissions(ctx, loginID, permissions, authType...) } -// CheckRoleAnd checks if the account has all specified roles (AND logic) | 检查账号是否拥有所有指定角色(AND逻辑) -func CheckRoleAndByToken(tokenValue string, roles []string) error { - return stputil.CheckRoleAnd(tokenValue, roles) +// SetPermissionsByToken sets permissions by token | 根据 Token 设置对应账号的权限 +func SetPermissionsByToken(ctx context.Context, tokenValue string, permissions []string, authType ...string) error { + return stputil.SetPermissionsByToken(ctx, tokenValue, permissions, authType...) } -// CheckRoleOr checks if the account has any of the specified roles (OR logic) | 检查账号是否拥有指定角色中的任意一个(OR逻辑) -func CheckRoleOrByToken(tokenValue string, roles []string) error { - return stputil.CheckRoleOr(tokenValue, roles) +// RemovePermissions removes specified permissions for a login ID | 删除用户指定权限 +func RemovePermissions(ctx context.Context, loginID interface{}, permissions []string, authType ...string) error { + return stputil.RemovePermissions(ctx, loginID, permissions, authType...) } -// GetRoleList gets the role list for an account | 获取账号的角色列表 -func GetRoleListByToken(tokenValue string) ([]string, error) { - return stputil.GetRoleList(tokenValue) +// RemovePermissionsByToken removes specified permissions by token | 根据 Token 删除对应账号的指定权限 +func RemovePermissionsByToken(ctx context.Context, tokenValue string, permissions []string, authType ...string) error { + return stputil.RemovePermissionsByToken(ctx, tokenValue, permissions, authType...) } -// ============ Session Management | Session管理 ============ +// GetPermissions gets permission list | 获取权限列表 +func GetPermissions(ctx context.Context, loginID interface{}, authType ...string) ([]string, error) { + return stputil.GetPermissions(ctx, loginID, authType...) +} + +// GetPermissionsByToken gets permission list by token | 根据 Token 获取对应账号的权限列表 +func GetPermissionsByToken(ctx context.Context, tokenValue string, authType ...string) ([]string, error) { + return stputil.GetPermissionsByToken(ctx, tokenValue, authType...) +} + +// HasPermission checks if has specified permission | 检查是否拥有指定权限 +func HasPermission(ctx context.Context, loginID interface{}, permission string, authType ...string) bool { + return stputil.HasPermission(ctx, loginID, permission, authType...) +} + +// HasPermissionByToken checks if the token has the specified permission | 检查Token是否拥有指定权限 +func HasPermissionByToken(ctx context.Context, tokenValue string, permission string, authType ...string) bool { + return stputil.HasPermissionByToken(ctx, tokenValue, permission, authType...) +} + +// HasPermissionsAnd checks if has all permissions (AND logic) | 检查是否拥有所有权限(AND逻辑) +func HasPermissionsAnd(ctx context.Context, loginID interface{}, permissions []string, authType ...string) bool { + return stputil.HasPermissionsAnd(ctx, loginID, permissions, authType...) +} + +// HasPermissionsAndByToken checks if the token has all specified permissions | 检查Token是否拥有所有指定权限 +func HasPermissionsAndByToken(ctx context.Context, tokenValue string, permissions []string, authType ...string) bool { + return stputil.HasPermissionsAndByToken(ctx, tokenValue, permissions, authType...) +} + +// HasPermissionsOr checks if has any permission (OR logic) | 检查是否拥有任一权限(OR逻辑) +func HasPermissionsOr(ctx context.Context, loginID interface{}, permissions []string, authType ...string) bool { + return stputil.HasPermissionsOr(ctx, loginID, permissions, authType...) +} + +// HasPermissionsOrByToken checks if the token has any of the specified permissions | 检查Token是否拥有任一指定权限 +func HasPermissionsOrByToken(ctx context.Context, tokenValue string, permissions []string, authType ...string) bool { + return stputil.HasPermissionsOrByToken(ctx, tokenValue, permissions, authType...) +} + +// ============ Role Management | 角色管理 ============ + +// SetRoles sets roles for a login ID | 设置用户角色 +func SetRoles(ctx context.Context, loginID interface{}, roles []string, authType ...string) error { + return stputil.SetRoles(ctx, loginID, roles, authType...) +} + +// SetRolesByToken sets roles by token | 根据 Token 设置对应账号的角色 +func SetRolesByToken(ctx context.Context, tokenValue string, roles []string, authType ...string) error { + return stputil.SetRolesByToken(ctx, tokenValue, roles, authType...) +} + +// RemoveRoles removes specified roles for a login ID | 删除用户指定角色 +func RemoveRoles(ctx context.Context, loginID interface{}, roles []string, authType ...string) error { + return stputil.RemoveRoles(ctx, loginID, roles, authType...) +} + +// RemoveRolesByToken removes specified roles by token | 根据 Token 删除对应账号的指定角色 +func RemoveRolesByToken(ctx context.Context, tokenValue string, roles []string, authType ...string) error { + return stputil.RemoveRolesByToken(ctx, tokenValue, roles, authType...) +} + +// GetRoles gets role list | 获取角色列表 +func GetRoles(ctx context.Context, loginID interface{}, authType ...string) ([]string, error) { + return stputil.GetRoles(ctx, loginID, authType...) +} + +// GetRolesByToken gets role list by token | 根据 Token 获取对应账号的角色列表 +func GetRolesByToken(ctx context.Context, tokenValue string, authType ...string) ([]string, error) { + return stputil.GetRolesByToken(ctx, tokenValue, authType...) +} + +// HasRole checks if has specified role | 检查是否拥有指定角色 +func HasRole(ctx context.Context, loginID interface{}, role string, authType ...string) bool { + return stputil.HasRole(ctx, loginID, role, authType...) +} + +// HasRoleByToken checks if the token has the specified role | 检查 Token 是否拥有指定角色 +func HasRoleByToken(ctx context.Context, tokenValue string, role string, authType ...string) bool { + return stputil.HasRoleByToken(ctx, tokenValue, role, authType...) +} + +// HasRolesAnd checks if has all roles (AND logic) | 检查是否拥有所有角色(AND逻辑) +func HasRolesAnd(ctx context.Context, loginID interface{}, roles []string, authType ...string) bool { + return stputil.HasRolesAnd(ctx, loginID, roles, authType...) +} -// GetSession gets the session for a login ID | 获取登录ID的Session -func GetSession(loginID interface{}) (*Session, error) { - return stputil.GetSession(loginID) +// HasRolesAndByToken checks if the token has all specified roles | 检查 Token 是否拥有所有指定角色 +func HasRolesAndByToken(ctx context.Context, tokenValue string, roles []string, authType ...string) bool { + return stputil.HasRolesAndByToken(ctx, tokenValue, roles, authType...) } -// GetSessionByToken gets the session by token | 根据Token获取Session -func GetSessionByToken(tokenValue string) (*Session, error) { - return stputil.GetSessionByToken(tokenValue) +// HasRolesOr checks if has any role (OR logic) | 检查是否拥有任一角色(OR逻辑) +func HasRolesOr(ctx context.Context, loginID interface{}, roles []string, authType ...string) bool { + return stputil.HasRolesOr(ctx, loginID, roles, authType...) } -// GetTokenSession gets the token session | 获取Token的Session -func GetTokenSession(tokenValue string) (*Session, error) { - return stputil.GetTokenSession(tokenValue) +// HasRolesOrByToken checks if the token has any of the specified roles | 检查 Token 是否拥有任一指定角色 +func HasRolesOrByToken(ctx context.Context, tokenValue string, roles []string, authType ...string) bool { + return stputil.HasRolesOrByToken(ctx, tokenValue, roles, authType...) } -// ============ Token Renewal | Token续期 ============ +// ============ Token Tag | Token标签 ============ + +// SetTokenTag sets token tag | 设置Token标签 +func SetTokenTag(ctx context.Context, tokenValue, tag string, authType ...string) error { + return stputil.SetTokenTag(ctx, tokenValue, tag, authType...) +} + +// GetTokenTag gets token tag | 获取Token标签 +func GetTokenTag(ctx context.Context, tokenValue string, authType ...string) (string, error) { + return stputil.GetTokenTag(ctx, tokenValue, authType...) +} -// RenewTimeout renews token timeout | 续期Token超时时间 +// ============ Session Query | 会话查询 ============ + +// GetTokenValueListByLoginID gets all tokens for a login ID | 获取指定账号的所有Token +func GetTokenValueListByLoginID(ctx context.Context, loginID interface{}, authType ...string) ([]string, error) { + return stputil.GetTokenValueListByLoginID(ctx, loginID, authType...) +} + +// GetSessionCountByLoginID gets session count for a login ID | 获取指定账号的Session数量 +func GetSessionCountByLoginID(ctx context.Context, loginID interface{}, authType ...string) (int, error) { + return stputil.GetSessionCountByLoginID(ctx, loginID, authType...) +} // ============ Security Features | 安全特性 ============ -// GenerateNonce generates a new nonce token | 生成新的Nonce令牌 -func GenerateNonce() (string, error) { - return stputil.GenerateNonce() +// Generate Generates a one-time nonce | 生成一次性随机数 +func Generate(ctx context.Context, authType ...string) (string, error) { + return stputil.Generate(ctx, authType...) +} + +// Verify Verifies a nonce | 验证随机数 +func Verify(ctx context.Context, nonce string, authType ...string) bool { + return stputil.Verify(ctx, nonce, authType...) } -// VerifyNonce verifies a nonce token | 验证Nonce令牌 -func VerifyNonce(nonce string) bool { - return stputil.VerifyNonce(nonce) +// VerifyAndConsume Verifies and consumes nonce, returns error if invalid | 验证并消费nonce,无效时返回错误 +func VerifyAndConsume(ctx context.Context, nonce string, authType ...string) error { + return stputil.VerifyAndConsume(ctx, nonce, authType...) } -// LoginWithRefreshToken performs login and returns both access token and refresh token | 登录并返回访问令牌和刷新令牌 -func LoginWithRefreshToken(loginID interface{}, device ...string) (*RefreshTokenInfo, error) { - return stputil.LoginWithRefreshToken(loginID, device...) +// IsValidNonce Checks if nonce is valid without consuming it | 检查nonce是否有效(不消费) +func IsValidNonce(ctx context.Context, nonce string, authType ...string) bool { + return stputil.IsValidNonce(ctx, nonce, authType...) } -// RefreshAccessToken refreshes the access token using a refresh token | 使用刷新令牌刷新访问令牌 -func RefreshAccessToken(refreshToken string) (*RefreshTokenInfo, error) { - return stputil.RefreshAccessToken(refreshToken) +// GenerateTokenPair Create access + refresh token | 生成访问令牌和刷新令牌 +func GenerateTokenPair(ctx context.Context, loginID interface{}, deviceOrAutoType ...string) (*security.RefreshTokenInfo, error) { + return stputil.GenerateTokenPair(ctx, loginID, deviceOrAutoType...) } -// RevokeRefreshToken revokes a refresh token | 撤销刷新令牌 -func RevokeRefreshToken(refreshToken string) error { - return stputil.RevokeRefreshToken(refreshToken) +// VerifyAccessToken verifies access token validity | 验证访问令牌是否有效 +func VerifyAccessToken(ctx context.Context, accessToken string, authType ...string) bool { + return stputil.VerifyAccessToken(ctx, accessToken, authType...) +} + +// VerifyAccessTokenAndGetInfo verifies access token and returns token info | 验证访问令牌并返回Token信息 +func VerifyAccessTokenAndGetInfo(ctx context.Context, accessToken string, authType ...string) (*security.AccessTokenInfo, bool) { + return stputil.VerifyAccessTokenAndGetInfo(ctx, accessToken, authType...) +} + +// GetRefreshTokenInfo gets refresh token information | 获取刷新令牌信息 +func GetRefreshTokenInfo(ctx context.Context, refreshToken string, authType ...string) (*security.RefreshTokenInfo, error) { + return stputil.GetRefreshTokenInfo(ctx, refreshToken, authType...) +} + +// RefreshAccessToken refreshes access token using refresh token | 使用刷新令牌刷新访问令牌 +func RefreshAccessToken(ctx context.Context, refreshToken string, authType ...string) (*security.RefreshTokenInfo, error) { + return stputil.RefreshAccessToken(ctx, refreshToken, authType...) +} + +// RevokeRefreshToken Revokes refresh token | 撤销刷新令牌 +func RevokeRefreshToken(ctx context.Context, refreshToken string, authType ...string) error { + return stputil.RevokeRefreshToken(ctx, refreshToken, authType...) +} + +// IsValid checks whether token is valid | 检查Token是否有效 +func IsValid(ctx context.Context, refreshToken string, authType ...string) bool { + return stputil.IsValid(ctx, refreshToken, authType...) +} + +// ============ OAuth2 Features | OAuth2 功能 ============ + +// RegisterClient Registers an OAuth2 client | 注册OAuth2客户端 +func RegisterClient(ctx context.Context, client *oauth2.Client, authType ...string) error { + return stputil.RegisterClient(ctx, client, authType...) } -// GetOAuth2Server gets the OAuth2 server instance | 获取OAuth2服务器实例 -func GetOAuth2Server() *OAuth2Server { - return stputil.GetOAuth2Server() +// UnregisterClient unregisters an OAuth2 client | 注销OAuth2客户端 +func UnregisterClient(ctx context.Context, clientID string, authType ...string) error { + return stputil.UnregisterClient(ctx, clientID, authType...) } -// Version Sa-Token-Go version | Sa-Token-Go版本 -const Version = core.Version +// GetClient gets OAuth2 client information | 获取OAuth2客户端信息 +func GetClient(ctx context.Context, clientID string, authType ...string) (*oauth2.Client, error) { + return stputil.GetClient(ctx, clientID, authType...) +} + +// GenerateAuthorizationCode creates an authorization code | 创建授权码 +func GenerateAuthorizationCode(ctx context.Context, clientID, loginID, redirectURI string, scope []string, authType ...string) (*oauth2.AuthorizationCode, error) { + return stputil.GenerateAuthorizationCode(ctx, clientID, loginID, redirectURI, scope, authType...) +} + +// ExchangeCodeForToken exchanges authorization code for token | 使用授权码换取令牌 +func ExchangeCodeForToken(ctx context.Context, code, clientID, clientSecret, redirectURI string, authType ...string) (*oauth2.AccessToken, error) { + return stputil.ExchangeCodeForToken(ctx, code, clientID, clientSecret, redirectURI, authType...) +} + +// ValidateAccessToken verifies OAuth2 access token | 验证OAuth2访问令牌 +func ValidateAccessToken(ctx context.Context, accessToken string, authType ...string) bool { + return stputil.ValidateAccessToken(ctx, accessToken, authType...) +} + +// ValidateAccessTokenAndGetInfo verifies OAuth2 access token and get info | 验证OAuth2访问令牌并获取信息 +func ValidateAccessTokenAndGetInfo(ctx context.Context, accessToken string, authType ...string) (*oauth2.AccessToken, error) { + return stputil.ValidateAccessTokenAndGetInfo(ctx, accessToken, authType...) +} + +// OAuth2RefreshAccessToken Refreshes access token using refresh token | 使用刷新令牌刷新访问令牌(OAuth2) +func OAuth2RefreshAccessToken(ctx context.Context, clientID, refreshToken, clientSecret string, authType ...string) (*oauth2.AccessToken, error) { + return stputil.OAuth2RefreshAccessToken(ctx, clientID, refreshToken, clientSecret, authType...) +} + +// RevokeToken Revokes access token and its refresh token | 撤销访问令牌及其刷新令牌 +func RevokeToken(ctx context.Context, accessToken string, authType ...string) error { + return stputil.RevokeToken(ctx, accessToken, authType...) +} + +// OAuth2Token Unified token endpoint that dispatches to appropriate handler based on grant type | 统一的令牌端点 +func OAuth2Token(ctx context.Context, req *oauth2.TokenRequest, validateUser oauth2.UserValidator, authType ...string) (*oauth2.AccessToken, error) { + return stputil.OAuth2Token(ctx, req, validateUser, authType...) +} + +// OAuth2ClientCredentialsToken Gets access token using client credentials grant | 使用客户端凭证模式获取访问令牌 +func OAuth2ClientCredentialsToken(ctx context.Context, clientID, clientSecret string, scopes []string, authType ...string) (*oauth2.AccessToken, error) { + return stputil.OAuth2ClientCredentialsToken(ctx, clientID, clientSecret, scopes, authType...) +} + +// OAuth2PasswordGrantToken Gets access token using resource owner password credentials grant | 使用密码模式获取访问令牌 +func OAuth2PasswordGrantToken(ctx context.Context, clientID, clientSecret, username, password string, scopes []string, validateUser oauth2.UserValidator, authType ...string) (*oauth2.AccessToken, error) { + return stputil.OAuth2PasswordGrantToken(ctx, clientID, clientSecret, username, password, scopes, validateUser, authType...) +} + +// ============ OAuth2 Grant Type Constants | OAuth2授权类型常量 ============ + +const ( + // GrantTypeAuthorizationCode Authorization code grant type | 授权码模式 + GrantTypeAuthorizationCode = oauth2.GrantTypeAuthorizationCode + // GrantTypeClientCredentials Client credentials grant type | 客户端凭证模式 + GrantTypeClientCredentials = oauth2.GrantTypeClientCredentials + // GrantTypePassword Password grant type | 密码模式 + GrantTypePassword = oauth2.GrantTypePassword + // GrantTypeRefreshToken Refresh token grant type | 刷新令牌模式 + GrantTypeRefreshToken = oauth2.GrantTypeRefreshToken +) + +// ============ Public Getters | 公共获取器 ============ + +// GetConfig returns the manager-example configuration | 获取 Manager 当前使用的配置 +func GetConfig(ctx context.Context, authType ...string) *config.Config { + return stputil.GetConfig(ctx, authType...) +} + +// GetStorage returns the storage adapter | 获取 Manager 使用的存储适配器 +func GetStorage(ctx context.Context, authType ...string) adapter.Storage { + return stputil.GetStorage(ctx, authType...) +} + +// GetCodec returns the codec (serializer) | 获取 Manager 使用的编解码器 +func GetCodec(ctx context.Context, authType ...string) adapter.Codec { + return stputil.GetCodec(ctx, authType...) +} + +// GetLog returns the logger adapter | 获取 Manager 使用的日志适配器 +func GetLog(ctx context.Context, authType ...string) adapter.Log { + return stputil.GetLog(ctx, authType...) +} + +// GetPool returns the goroutine pool | 获取 Manager 使用的协程池 +func GetPool(ctx context.Context, authType ...string) adapter.Pool { + return stputil.GetPool(ctx, authType...) +} + +// GetGenerator returns the token generator | 获取 Token 生成器 +func GetGenerator(ctx context.Context, authType ...string) adapter.Generator { + return stputil.GetGenerator(ctx, authType...) +} + +// GetNonceManager returns the nonce manager-example | 获取随机串管理器 +func GetNonceManager(ctx context.Context, authType ...string) *security.NonceManager { + return stputil.GetNonceManager(ctx, authType...) +} + +// GetRefreshManager returns the refresh token manager-example | 获取刷新令牌管理器 +func GetRefreshManager(ctx context.Context, authType ...string) *security.RefreshTokenManager { + return stputil.GetRefreshManager(ctx, authType...) +} + +// GetEventManager returns the event manager-example | 获取事件管理器 +func GetEventManager(ctx context.Context, authType ...string) *listener.Manager { + return stputil.GetEventManager(ctx, authType...) +} + +// GetOAuth2Server Gets OAuth2 server instance | 获取OAuth2服务器实例 +func GetOAuth2Server(ctx context.Context, authType ...string) *oauth2.OAuth2Server { + return stputil.GetOAuth2Server(ctx, authType...) +} + +// ============ Event Management | 事件管理 ============ + +// RegisterFunc registers a function as an event listener | 注册函数作为事件监听器 +func RegisterFunc(event listener.Event, fn func(*listener.EventData), authType ...string) { + stputil.RegisterFunc(event, fn, authType...) +} + +// Register registers an event listener | 注册事件监听器 +func Register(event listener.Event, l listener.Listener, authType ...string) string { + return stputil.Register(event, l, authType...) +} + +// RegisterWithConfig registers an event listener with config | 注册带配置的事件监听器 +func RegisterWithConfig(event listener.Event, l listener.Listener, cfg listener.ListenerConfig, authType ...string) string { + return stputil.RegisterWithConfig(event, l, cfg, authType...) +} + +// Unregister removes an event listener by ID | 根据ID移除事件监听器 +func Unregister(id string, authType ...string) bool { + return stputil.Unregister(id, authType...) +} + +// TriggerEvent manually triggers an event | 手动触发事件 +func TriggerEvent(data *listener.EventData, authType ...string) { + stputil.TriggerEvent(data, authType...) +} + +// TriggerEventAsync triggers an event asynchronously and returns immediately | 异步触发事件并立即返回 +func TriggerEventAsync(data *listener.EventData, authType ...string) { + stputil.TriggerEventAsync(data, authType...) +} + +// TriggerEventSync triggers an event synchronously and waits for all listeners | 同步触发事件并等待所有监听器完成 +func TriggerEventSync(data *listener.EventData, authType ...string) { + stputil.TriggerEventSync(data, authType...) +} + +// WaitEvents waits for all async event listeners to complete | 等待所有异步事件监听器完成 +func WaitEvents(authType ...string) { + stputil.WaitEvents(authType...) +} + +// ClearEventListeners removes all listeners for a specific event | 清除指定事件的所有监听器 +func ClearEventListeners(event listener.Event, authType ...string) { + stputil.ClearEventListeners(event, authType...) +} + +// ClearAllEventListeners removes all listeners | 清除所有事件监听器 +func ClearAllEventListeners(authType ...string) { + stputil.ClearAllEventListeners(authType...) +} + +// CountEventListeners returns the number of listeners for a specific event | 获取指定事件监听器数量 +func CountEventListeners(event listener.Event, authType ...string) int { + return stputil.CountEventListeners(event, authType...) +} + +// CountAllListeners returns the total number of registered listeners | 获取已注册监听器总数 +func CountAllListeners(authType ...string) int { + return stputil.CountAllListeners(authType...) +} + +// GetEventListenerIDs returns all listener IDs for a specific event | 获取指定事件的所有监听器ID +func GetEventListenerIDs(event listener.Event, authType ...string) []string { + return stputil.GetEventListenerIDs(event, authType...) +} + +// GetAllRegisteredEvents returns all events that have registered listeners | 获取所有已注册事件 +func GetAllRegisteredEvents(authType ...string) []listener.Event { + return stputil.GetAllRegisteredEvents(authType...) +} + +// HasEventListeners checks if there are any listeners for a specific event | 检查指定事件是否有监听器 +func HasEventListeners(event listener.Event, authType ...string) bool { + return stputil.HasEventListeners(event, authType...) +} diff --git a/integrations/gf/go.mod b/integrations/gf/go.mod index 340f21f..6392805 100644 --- a/integrations/gf/go.mod +++ b/integrations/gf/go.mod @@ -1,23 +1,22 @@ module github.com/click33/sa-token-go/integrations/gf -go 1.24.1 +go 1.25.0 require ( - github.com/click33/sa-token-go/core v0.1.5 - github.com/click33/sa-token-go/stputil v0.1.5 + github.com/click33/sa-token-go/core v0.1.7 + github.com/click33/sa-token-go/stputil v0.1.7 github.com/gogf/gf/v2 v2.9.4 ) require ( github.com/BurntSushi/toml v1.5.0 // indirect github.com/clbanning/mxj/v2 v2.7.0 // indirect - github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/emirpasic/gods v1.18.1 // indirect github.com/fatih/color v1.18.0 // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect - github.com/golang-jwt/jwt/v5 v5.2.2 // indirect + github.com/golang-jwt/jwt/v5 v5.3.0 // indirect github.com/google/uuid v1.6.0 // indirect github.com/gorilla/websocket v1.5.3 // indirect github.com/grokify/html-strip-tags-go v0.1.0 // indirect @@ -29,7 +28,6 @@ require ( github.com/olekukonko/ll v0.0.9 // indirect github.com/olekukonko/tablewriter v1.1.0 // indirect github.com/panjf2000/ants/v2 v2.11.3 // indirect - github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/rivo/uniseg v0.4.7 // indirect go.opentelemetry.io/auto/sdk v1.1.0 // indirect go.opentelemetry.io/otel v1.38.0 // indirect @@ -37,7 +35,7 @@ require ( go.opentelemetry.io/otel/sdk v1.38.0 // indirect go.opentelemetry.io/otel/trace v1.38.0 // indirect golang.org/x/net v0.43.0 // indirect - golang.org/x/sync v0.16.0 // indirect + golang.org/x/sync v0.19.0 // indirect golang.org/x/sys v0.35.0 // indirect golang.org/x/text v0.28.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/integrations/gf/go.sum b/integrations/gf/go.sum index 18b899f..b11f761 100644 --- a/integrations/gf/go.sum +++ b/integrations/gf/go.sum @@ -1,8 +1,8 @@ github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg= github.com/clbanning/mxj/v2 v2.7.0 h1:WA/La7UGCanFe5NpHF0Q3DNtnCsVoxbPKuyBNHWRyME= -github.com/click33/sa-token-go/core v0.1.4 h1:mODeJ0WKSusQmiO5b/uK9UDD0OFqLrkJ/j7W/2e3Ios= -github.com/click33/sa-token-go/core v0.1.4/go.mod h1:LK9zMyf3L8adSAYkAQj8ypwxKicS1q0qxdWV5uLDD6E= -github.com/click33/sa-token-go/stputil v0.1.4 h1:YvMEwPfAfTunQn+AePudO3Esp0CvLoc2o5kmg/uZf/c= +github.com/click33/sa-token-go/core v0.1.6 h1:ELOe0qSH1b3LRsQD3DIBg0e1VgYANKFg5H7z57Lkt/8= +github.com/click33/sa-token-go/core v0.1.6/go.mod h1:mb3AQAJIXqx9WdULyn5qjufK1j/u+kgB0q+tafHVhgk= +github.com/click33/sa-token-go/stputil v0.1.5 h1:603tbI4JkBTg3MnfTj+lCMDxJOKSCOqsMyC2zyuvEco= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc= github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= @@ -10,7 +10,7 @@ github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/gogf/gf/v2 v2.9.4 h1:6vleEWypot9WBPncP2GjbpgAUeG6Mzb1YESb9nPMkjY= -github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= +github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= @@ -37,7 +37,7 @@ go.opentelemetry.io/otel/sdk/metric v1.38.0 h1:aSH66iL0aZqo//xXzQLYozmWrXxyFkBJ6 go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= -golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= diff --git a/integrations/gf/middleware.go b/integrations/gf/middleware.go new file mode 100644 index 0000000..ebaf95b --- /dev/null +++ b/integrations/gf/middleware.go @@ -0,0 +1,405 @@ +package gf + +import ( + "context" + "errors" + "net/http" + + "github.com/click33/sa-token-go/core" + "github.com/click33/sa-token-go/core/manager" + + saContext "github.com/click33/sa-token-go/core/context" + "github.com/click33/sa-token-go/stputil" + "github.com/gogf/gf/v2/frame/g" + "github.com/gogf/gf/v2/net/ghttp" +) + +// LogicType permission/role logic type | 权限/角色判断逻辑 +type LogicType string + +const ( + SaTokenCtxKey = "saCtx" + + LogicOr LogicType = "OR" // Logical OR | 任一满足 + LogicAnd LogicType = "AND" // Logical AND | 全部满足 +) + +type AuthOption func(*AuthOptions) + +type AuthOptions struct { + AuthType string + LogicType LogicType + FailFunc func(r *ghttp.Request, err error) +} + +func defaultAuthOptions() *AuthOptions { + return &AuthOptions{LogicType: LogicAnd} // 默认 AND +} + +// WithAuthType sets auth type | 设置认证类型 +func WithAuthType(authType string) AuthOption { + return func(o *AuthOptions) { + o.AuthType = authType + } +} + +func WithLogicType(logicType LogicType) AuthOption { + return func(o *AuthOptions) { + o.LogicType = logicType + } +} + +// WithFailFunc sets auth failure callback | 设置认证失败回调 +func WithFailFunc(fn func(r *ghttp.Request, err error)) AuthOption { + return func(o *AuthOptions) { + o.FailFunc = fn + } +} + +// ============ Middlewares | 中间件 ============ + +// RegisterSaTokenContextMiddleware initializes Sa-Token context for each request | 初始化每次请求的 Sa-Token 上下文的中间件 +func RegisterSaTokenContextMiddleware(ctx context.Context, opts ...AuthOption) ghttp.HandlerFunc { + options := defaultAuthOptions() + for _, opt := range opts { + opt(options) + } + + return func(r *ghttp.Request) { + mgr, err := stputil.GetManager(options.AuthType) + if err != nil { + if options.FailFunc != nil { + options.FailFunc(r, err) + } else { + writeErrorResponse(r, err) + } + return + } + + _ = getSaContext(r, mgr) + } +} + +// AuthMiddleware authentication middleware | 认证中间件 +func AuthMiddleware(ctx context.Context, opts ...AuthOption) ghttp.HandlerFunc { + options := defaultAuthOptions() + for _, opt := range opts { + opt(options) + } + + return func(r *ghttp.Request) { + mgr, err := stputil.GetManager(options.AuthType) + if err != nil { + if options.FailFunc != nil { + options.FailFunc(r, err) + } else { + writeErrorResponse(r, err) + } + return + } + + saCtx := getSaContext(r, mgr) + tokenValue := saCtx.GetTokenValue() + + // 检查登录 | Check login + isLogin, err := mgr.IsLogin(ctx, tokenValue) + if err != nil { + if options.FailFunc != nil { + options.FailFunc(r, err) + } else { + writeErrorResponse(r, err) + } + return + } + if !isLogin { + if options.FailFunc != nil { + options.FailFunc(r, core.ErrTokenExpired) + } else { + writeErrorResponse(r, core.ErrTokenExpired) + } + return + } + + r.Middleware.Next() + } +} + +// PermissionMiddleware permission check middleware | 权限校验中间件 +func PermissionMiddleware( + ctx context.Context, + permissions []string, + opts ...AuthOption, +) ghttp.HandlerFunc { + + options := defaultAuthOptions() + for _, opt := range opts { + opt(options) + } + + return func(r *ghttp.Request) { + // No permission required | 无需权限直接放行 + if len(permissions) == 0 { + r.Middleware.Next() + return + } + + // Get Manager | 获取 Manager + mgr, err := stputil.GetManager(options.AuthType) + if err != nil { + if options.FailFunc != nil { + options.FailFunc(r, err) + } else { + writeErrorResponse(r, err) + } + return + } + + saCtx := getSaContext(r, mgr) + tokenValue := saCtx.GetTokenValue() + + // Permission check | 权限校验 + var ok bool + if options.LogicType == LogicAnd { + ok = mgr.HasPermissionsAndByToken(ctx, tokenValue, permissions) + } else { + ok = mgr.HasPermissionsOrByToken(ctx, tokenValue, permissions) + } + + if !ok { + if options.FailFunc != nil { + options.FailFunc(r, core.ErrPermissionDenied) + } else { + writeErrorResponse(r, core.ErrPermissionDenied) + } + return + } + + r.Middleware.Next() + } +} + +// PermissionPathMiddleware permission check middleware | 基于路径的权限校验中间件 +func PermissionPathMiddleware( + ctx context.Context, + permissions []string, + opts ...AuthOption, +) ghttp.HandlerFunc { + + options := defaultAuthOptions() + for _, opt := range opts { + opt(options) + } + + return func(r *ghttp.Request) { + // Create a per-request copy of permissions and append current path | 每次请求创建权限副本并追加当前路径 + reqPermissions := append([]string{}, permissions...) + reqPermissions = append(reqPermissions, r.URL.Path) + + if len(reqPermissions) == 0 { + r.Middleware.Next() + return + } + + // Get Manager | 获取 Manager + mgr, err := stputil.GetManager(options.AuthType) + if err != nil { + if options.FailFunc != nil { + options.FailFunc(r, err) + } else { + writeErrorResponse(r, err) + } + return + } + + saCtx := getSaContext(r, mgr) + tokenValue := saCtx.GetTokenValue() + + // Permission check | 权限校验 + var ok bool + if options.LogicType == LogicAnd { + ok = mgr.HasPermissionsAndByToken(ctx, tokenValue, reqPermissions) + } else { + ok = mgr.HasPermissionsOrByToken(ctx, tokenValue, reqPermissions) + } + + if !ok { + if options.FailFunc != nil { + options.FailFunc(r, core.ErrPermissionDenied) + } else { + writeErrorResponse(r, core.ErrPermissionDenied) + } + return + } + + r.Middleware.Next() + } +} + +// RoleMiddleware role check middleware | 角色校验中间件 +func RoleMiddleware( + ctx context.Context, + roles []string, + opts ...AuthOption, +) ghttp.HandlerFunc { + + options := defaultAuthOptions() + for _, opt := range opts { + opt(options) + } + + return func(r *ghttp.Request) { + // No role required | 无需角色直接放行 + if len(roles) == 0 { + r.Middleware.Next() + return + } + + // Get Manager | 获取 Manager + mgr, err := stputil.GetManager(options.AuthType) + if err != nil { + if options.FailFunc != nil { + options.FailFunc(r, err) + } else { + writeErrorResponse(r, err) + } + return + } + + saCtx := getSaContext(r, mgr) + tokenValue := saCtx.GetTokenValue() + + // Role check | 角色校验 + var ok bool + if options.LogicType == LogicAnd { + ok = mgr.HasRolesAndByToken(ctx, tokenValue, roles) + } else { + ok = mgr.HasRolesOrByToken(ctx, tokenValue, roles) + } + + if !ok { + if options.FailFunc != nil { + options.FailFunc(r, core.ErrRoleDenied) + } else { + writeErrorResponse(r, core.ErrRoleDenied) + } + return + } + + r.Middleware.Next() + } +} + +// GetSaTokenContext gets Sa-Token context from GoFrame context | 获取 Sa-Token 上下文 +func GetSaTokenContext(r *ghttp.Request) (*saContext.SaTokenContext, bool) { + v := r.GetCtxVar(SaTokenCtxKey) + if v == nil { + return nil, false + } + + ctx, ok := v.Val().(*saContext.SaTokenContext) + return ctx, ok +} + +// GetSaTokenContextByCtx gets Sa-Token context from GoFrame context | 获取 Sa-Token 上下文 +func GetSaTokenContextByCtx(ctx context.Context) (*saContext.SaTokenContext, bool) { + request := g.RequestFromCtx(ctx) + ctxVar := request.GetCtxVar(SaTokenCtxKey) + if ctxVar == nil { + return nil, false + } + + tokenContext, ok := ctxVar.Val().(*saContext.SaTokenContext) + return tokenContext, ok +} + +// GetLoginIDByCtx gets the login ID from the context | 从上下文获取登录ID +func GetLoginIDByCtx(ctx context.Context, authType ...string) (string, error) { + mgr, err := stputil.GetManager(authType...) + if err != nil { + return "", err + } + + return mgr.GetLoginIDNotCheck(ctx, getSaContext(g.RequestFromCtx(ctx), mgr).GetTokenValue()) +} + +// GetTokenInfoByCtx gets the token information from the context | 从上下文获取Token信息 +func GetTokenInfoByCtx(ctx context.Context, authType ...string) (*manager.TokenInfo, error) { + mgr, err := stputil.GetManager(authType...) + if err != nil { + return nil, err + } + + return mgr.GetTokenInfoByToken(ctx, getSaContext(g.RequestFromCtx(ctx), mgr).GetTokenValue()) +} + +// getSaContext returns or creates the Sa-Token context for the request | 获取或创建当前请求的 Sa-Token 上下文 +func getSaContext(r *ghttp.Request, mgr *manager.Manager) *saContext.SaTokenContext { + // Try get from context | 尝试从 ctx 取值 + if v := r.GetCtxVar(SaTokenCtxKey); v != nil { + // gvar.Var -> interface{} -> *SaTokenContext + if saCtx, ok := v.Val().(*saContext.SaTokenContext); ok { + return saCtx + } + } + + // Create new context | 创建并缓存 SaTokenContext + saCtx := saContext.NewContext(NewGFContext(r), mgr) + r.SetCtxVar(SaTokenCtxKey, saCtx) + + return saCtx +} + +// ============ Error Handling Helpers | 错误处理辅助函数 ============ + +// writeErrorResponse writes a standardized error response | 写入标准化的错误响应 +func writeErrorResponse(r *ghttp.Request, err error) { + var saErr *core.SaTokenError + var code int + var message string + var httpStatus int + + // Check if it's a SaTokenError | 检查是否为SaTokenError + if errors.As(err, &saErr) { + code = saErr.Code + message = saErr.Message + httpStatus = getHTTPStatusFromCode(code) + } else { + // Handle standard errors | 处理标准错误 + code = core.CodeServerError + message = err.Error() + httpStatus = http.StatusInternalServerError + } + + r.Response.WriteStatusExit(httpStatus, g.Map{ + "code": code, + "message": message, + "data": err.Error(), + }) +} + +// writeSuccessResponse writes a standardized success response | 写入标准化的成功响应 +func writeSuccessResponse(r *ghttp.Request, data interface{}) { + r.Response.WriteStatusExit(http.StatusOK, g.Map{ + "code": core.CodeSuccess, + "message": "success", + "data": data, + }) +} + +// getHTTPStatusFromCode converts Sa-Token error code to HTTP status | 将Sa-Token错误码转换为HTTP状态码 +func getHTTPStatusFromCode(code int) int { + switch code { + case core.CodeNotLogin: + return http.StatusUnauthorized + case core.CodePermissionDenied: + return http.StatusForbidden + case core.CodeBadRequest: + return http.StatusBadRequest + case core.CodeNotFound: + return http.StatusNotFound + case core.CodeServerError: + return http.StatusInternalServerError + default: + return http.StatusInternalServerError + } +} diff --git a/integrations/gf/plugin.go b/integrations/gf/plugin.go deleted file mode 100644 index d9d18db..0000000 --- a/integrations/gf/plugin.go +++ /dev/null @@ -1,313 +0,0 @@ -package gf - -import ( - "errors" - "net/http" - "strings" - - "github.com/click33/sa-token-go/core" - "github.com/gogf/gf/v2/frame/g" - "github.com/gogf/gf/v2/net/ghttp" -) - -type MiddlewareType string - -var ( - MiddlewareTypeOr MiddlewareType = "MiddlewareTypeOr" // Logical OR permission mode | “或” 逻辑权限模式 - MiddlewareTypeAnd MiddlewareType = "MiddlewareTypeAnd" // Logical AND permission mode | “与” 逻辑权限模式 -) - -// Plugin GoFrame plugin for Sa-Token | GoFrame插件 -type Plugin struct { - manager *core.Manager -} - -// NewPlugin creates an GoFrame plugin | 创建GoFrame插件 -func NewPlugin(manager *core.Manager) *Plugin { - return &Plugin{ - manager: manager, - } -} - -// AuthMiddleware authentication middleware | 认证中间件 -func (p *Plugin) AuthMiddleware() ghttp.HandlerFunc { - return func(r *ghttp.Request) { - ctx := NewGFContext(r) - saCtx := core.NewContext(ctx, p.manager) - // Check login | 检查登录 - if err := saCtx.CheckLogin(); err != nil { - writeErrorResponse(r, err) - return - } - // Store Sa-Token context in GoFrame context | 将Sa-Token上下文存储到GoFrame上下文 - r.SetCtxVar("satoken", saCtx) - - r.Middleware.Next() - } - -} - -// PermissionRequired permission validation middleware | 权限验证中间件 -func (p *Plugin) PermissionRequired(permission string) ghttp.HandlerFunc { - return func(r *ghttp.Request) { - ctx := NewGFContext(r) - saCtx := core.NewContext(ctx, p.manager) - - if err := saCtx.CheckLogin(); err != nil { - writeErrorResponse(r, err) - return - } - if !saCtx.HasPermission(permission) { - writeErrorResponse(r, core.NewPermissionDeniedError(permission)) - return - } - r.SetCtxVar("satoken", saCtx) - r.Middleware.Next() - } - -} - -// RoleRequired role validation middleware | 角色验证中间件 -func (p *Plugin) RoleRequired(role string) ghttp.HandlerFunc { - return func(r *ghttp.Request) { - ctx := NewGFContext(r) - saCtx := core.NewContext(ctx, p.manager) - - if err := saCtx.CheckLogin(); err != nil { - writeErrorResponse(r, err) - return - } - - if !saCtx.HasRole(role) { - writeErrorResponse(r, core.NewRoleDeniedError(role)) - return - } - - r.SetCtxVar("satoken", saCtx) - r.Middleware.Next() - } -} - -// HandlerAuthMiddleware — Authentication check middleware | 认证校验中间件 -func (p *Plugin) HandlerAuthMiddleware(authFailedFunc ...func(r *ghttp.Request)) ghttp.HandlerFunc { - return func(r *ghttp.Request) { - ctx := NewGFContext(r) - saCtx := core.NewContext(ctx, p.manager) - // Check login | 检查登录 - if err := saCtx.CheckLogin(); err != nil { - if len(authFailedFunc) > 0 && authFailedFunc[0] != nil { - authFailedFunc[0](r) - return - } - writeErrorResponse(r, err) - return - } - - // Store Sa-Token context in GoFrame context | 将Sa-Token上下文存储到GoFrame上下文 - r.SetCtxVar("satoken", saCtx) - - r.Middleware.Next() - } -} - -// HandlerPermissionRequiredMiddleware — Permission check middleware | 权限校验中间件 -func (p *Plugin) HandlerPermissionRequiredMiddleware(middlewareType MiddlewareType, permissions []string, permFailedFunc ...func(r *ghttp.Request)) ghttp.HandlerFunc { - return func(r *ghttp.Request) { - if len(permissions) == 0 { // Skip if no permission required | 无需权限则跳过 - r.Middleware.Next() - return - } - - ctx := NewGFContext(r) - saCtx := core.NewContext(ctx, p.manager) - loginID, err := saCtx.GetLoginID() - if err != nil { - if len(permFailedFunc) > 0 && permFailedFunc[0] != nil { - permFailedFunc[0](r) - return - } - writeErrorResponse(r, err) - return - } - - var hasPerm bool - switch middlewareType { - case MiddlewareTypeOr: - hasPerm = saCtx.GetManager().HasPermissionsOr(loginID, permissions) // OR check | 任一权限满足即可 - case MiddlewareTypeAnd: - hasPerm = saCtx.GetManager().HasPermissionsAnd(loginID, permissions) // AND check | 所有权限都需满足 - default: - hasPerm = false - } - - if !hasPerm { // No permission | 权限不足 - if len(permFailedFunc) > 0 && permFailedFunc[0] != nil { - permFailedFunc[0](r) - return - } - writeErrorResponse(r, core.NewPermissionDeniedError(strings.Join(permissions, ","))) - return - } - - r.Middleware.Next() // Continue | 继续执行 - } -} - -// HandlerRoleRequiredMiddleware — Role check middleware | 角色校验中间件 -func (p *Plugin) HandlerRoleRequiredMiddleware(middlewareType MiddlewareType, roles []string, roleFailedFunc ...func(r *ghttp.Request)) ghttp.HandlerFunc { - return func(r *ghttp.Request) { - if len(roles) == 0 { // Skip if no role required | 无需角色则跳过 - r.Middleware.Next() - return - } - - ctx := NewGFContext(r) - saCtx := core.NewContext(ctx, p.manager) - loginID, err := saCtx.GetLoginID() - if err != nil { - if len(roleFailedFunc) > 0 && roleFailedFunc[0] != nil { - roleFailedFunc[0](r) - return - } - writeErrorResponse(r, err) - return - } - - var hasRole bool - switch middlewareType { - case MiddlewareTypeOr: - hasRole = saCtx.GetManager().HasRolesOr(loginID, roles) // OR mode | 任一角色满足即可 - case MiddlewareTypeAnd: - hasRole = saCtx.GetManager().HasRolesAnd(loginID, roles) // AND mode | 所有角色都需满足 - default: - hasRole = false - } - - if !hasRole { // No required role | 无权限角色 - if len(roleFailedFunc) > 0 && roleFailedFunc[0] != nil { - roleFailedFunc[0](r) - return - } - writeErrorResponse(r, core.NewRoleDeniedError(strings.Join(roles, ","))) - return - } - - r.Middleware.Next() // Continue | 继续执行 - } -} - -// LoginHandler 登录处理器 -func (p *Plugin) LoginHandler(r *ghttp.Request) { - var req struct { - Username string `json:"username"` - Password string `json:"password"` - Device string `json:"device"` - } - - if err := r.Parse(&req); err != nil { - writeErrorResponse(r, core.NewError(core.CodeBadRequest, "invalid request parameters", err)) - return - } - - device := req.Device - if device == "" { - device = "default" - } - - token, err := p.manager.Login(req.Username, device) - if err != nil { - writeErrorResponse(r, core.NewError(core.CodeServerError, "login failed", err)) - return - } - - writeSuccessResponse(r, g.Map{ - "token": token, - }) -} - -// UserInfoHandler user info handler example | 获取用户信息处理器示例 -func (p *Plugin) UserInfoHandler(r *ghttp.Request) { - ctx := NewGFContext(r) - saCtx := core.NewContext(ctx, p.manager) - - loginID, err := saCtx.GetLoginID() - if err != nil { - writeErrorResponse(r, err) - return - } - - // Get user permissions and roles | 获取用户权限和角色 - permissions, _ := p.manager.GetPermissions(loginID) - roles, _ := p.manager.GetRoles(loginID) - - writeSuccessResponse(r, g.Map{ - "loginId": loginID, - "permissions": permissions, - "roles": roles, - }) -} - -// GetSaToken 从GoFrame上下文获取Sa-Token上下文 -func GetSaToken(r *ghttp.Request) (*core.SaTokenContext, bool) { - satoken := r.GetCtx().Value("satoken") - if satoken == nil { - return nil, false - } - ctx, ok := satoken.(*core.SaTokenContext) - return ctx, ok -} - -// ============ Error Handling Helpers | 错误处理辅助函数 ============ - -// writeErrorResponse writes a standardized error response | 写入标准化的错误响应 -func writeErrorResponse(r *ghttp.Request, err error) { - var saErr *core.SaTokenError - var code int - var message string - var httpStatus int - - // Check if it's a SaTokenError | 检查是否为SaTokenError - if errors.As(err, &saErr) { - code = saErr.Code - message = saErr.Message - httpStatus = getHTTPStatusFromCode(code) - } else { - // Handle standard errors | 处理标准错误 - code = core.CodeServerError - message = err.Error() - httpStatus = http.StatusInternalServerError - } - - r.Response.WriteStatusExit(httpStatus, g.Map{ - "code": code, - "message": message, - "error": err.Error(), - }) -} - -// writeSuccessResponse writes a standardized success response | 写入标准化的成功响应 -func writeSuccessResponse(r *ghttp.Request, data interface{}) { - r.Response.WriteStatusExit(http.StatusOK, g.Map{ - "code": core.CodeSuccess, - "message": "success", - "data": data, - }) -} - -// getHTTPStatusFromCode converts Sa-Token error code to HTTP status | 将Sa-Token错误码转换为HTTP状态码 -func getHTTPStatusFromCode(code int) int { - switch code { - case core.CodeNotLogin: - return http.StatusUnauthorized - case core.CodePermissionDenied: - return http.StatusForbidden - case core.CodeBadRequest: - return http.StatusBadRequest - case core.CodeNotFound: - return http.StatusNotFound - case core.CodeServerError: - return http.StatusInternalServerError - default: - return http.StatusInternalServerError - } -} diff --git a/integrations/gin/annotation.go b/integrations/gin/annotation.go index cc7a1dc..d23ab32 100644 --- a/integrations/gin/annotation.go +++ b/integrations/gin/annotation.go @@ -1,368 +1,330 @@ +// @Author daixk 2025/12/28 package gin import ( - "reflect" + "context" "strings" "github.com/click33/sa-token-go/core" "github.com/click33/sa-token-go/stputil" - ginfw "github.com/gin-gonic/gin" -) - -// Annotation constants | 注解常量 -const ( - TagSaCheckLogin = "sa_check_login" - TagSaCheckRole = "sa_check_role" - TagSaCheckPermission = "sa_check_permission" - TagSaCheckDisable = "sa_check_disable" - TagSaIgnore = "sa_ignore" + "github.com/gin-gonic/gin" ) // Annotation annotation structure | 注解结构体 type Annotation struct { - CheckLogin bool `json:"checkLogin"` - CheckRole []string `json:"checkRole"` - CheckPermission []string `json:"checkPermission"` - CheckDisable bool `json:"checkDisable"` - Ignore bool `json:"ignore"` + AuthType string `json:"authType"` // Optional: specify auth type | 可选:指定认证类型 + CheckLogin bool `json:"checkLogin"` // Check login | 检查登录 + CheckRole []string `json:"checkRole"` // Check roles | 检查角色 + CheckPermission []string `json:"checkPermission"` // Check permissions | 检查权限 + CheckDisable bool `json:"checkDisable"` // Check disable status | 检查封禁状态 + Ignore bool `json:"ignore"` // Ignore authentication | 忽略认证 + LogicType LogicType `json:"logicType"` // OR or AND logic (default: OR) | OR 或 AND 逻辑(默认: OR) } -// ParseTag parses struct tags | 解析结构体标签 -func ParseTag(tag string) *Annotation { - ann := &Annotation{} - - if tag == "" { - return ann - } - - parts := strings.Split(tag, ",") - for _, part := range parts { - part = strings.TrimSpace(part) - switch { - case part == TagSaCheckLogin || part == "login": - ann.CheckLogin = true - case strings.HasPrefix(part, TagSaCheckRole+"=") || strings.HasPrefix(part, "role="): - roles := strings.TrimPrefix(part, TagSaCheckRole+"=") - roles = strings.TrimPrefix(roles, "role=") - if roles != "" { - ann.CheckRole = strings.Split(roles, "|") - } - case strings.HasPrefix(part, TagSaCheckPermission+"=") || strings.HasPrefix(part, "permission="): - perms := strings.TrimPrefix(part, TagSaCheckPermission+"=") - perms = strings.TrimPrefix(perms, "permission=") - if perms != "" { - ann.CheckPermission = strings.Split(perms, "|") +// GetHandler gets handler with annotations | 获取带注解的处理器 +func GetHandler(ctx context.Context, handler gin.HandlerFunc, failFunc func(c *gin.Context, err error), annotations ...*Annotation) gin.HandlerFunc { + return func(c *gin.Context) { + // Ignore authentication | 忽略认证直接放行 + if len(annotations) > 0 && annotations[0].Ignore { + if handler != nil { + handler(c) + } else { + c.Next() } - case part == TagSaCheckDisable || part == "disable": - ann.CheckDisable = true - case part == TagSaIgnore || part == "ignore": - ann.Ignore = true + return } - } - - return ann -} - -// Validate validates if annotation is valid | 验证注解是否有效 -func (a *Annotation) Validate() bool { - if a.Ignore { - return true // When ignore is true, other checks are invalid | 忽略认证时,其他检查无效 - } - count := 0 - if a.CheckLogin { - count++ - } - if len(a.CheckRole) > 0 { - count++ - } - if len(a.CheckPermission) > 0 { - count++ - } - if a.CheckDisable { - count++ - } - - // At most one check type allowed | 最多只能有一个检查类型 - return count <= 1 -} + // Check if any authentication is needed | 检查是否需要任何认证 + ann := &Annotation{} + if len(annotations) > 0 { + ann = annotations[0] + } -// GetHandler gets handler with annotations | 获取带注解的处理器 -func GetHandler(handler interface{}, annotations ...*Annotation) ginfw.HandlerFunc { - return func(c *ginfw.Context) { - // Check if authentication should be ignored | 检查是否忽略认证 - if len(annotations) > 0 && annotations[0].Ignore { - if callHandler(handler, c) { - return + // No authentication required | 无需任何认证 + needAuth := ann.CheckLogin || ann.CheckDisable || len(ann.CheckPermission) > 0 || len(ann.CheckRole) > 0 + if !needAuth { + if handler != nil { + handler(c) + } else { + c.Next() } - c.Next() return } - // Get token from context using configured TokenName | 从上下文获取Token(使用配置的TokenName) - ctx := NewGinContext(c) - saCtx := core.NewContext(ctx, stputil.GetManager()) - token := saCtx.GetTokenValue() - if token == "" { - writeErrorResponse(c, core.NewNotLoginError()) + // Get manager-example | 获取 Manager + mgr, err := stputil.GetManager(ann.AuthType) + if err != nil { + if failFunc != nil { + failFunc(c, err) + } else { + writeErrorResponse(c, err) + } c.Abort() return } + // Get SaTokenContext (reuse cached context) | 获取 SaTokenContext(复用缓存上下文) + saCtx := getSaContext(c, mgr) + token := saCtx.GetTokenValue() + // Check login | 检查登录 - if !stputil.IsLogin(token) { - writeErrorResponse(c, core.NewNotLoginError()) + isLogin, err := mgr.IsLogin(ctx, token) + if err != nil { + if failFunc != nil { + failFunc(c, err) + } else { + writeErrorResponse(c, err) + } c.Abort() return } - - // Get login ID | 获取登录ID - loginID, err := stputil.GetLoginID(token) - if err != nil { - writeErrorResponse(c, err) + if !isLogin { + if failFunc != nil { + failFunc(c, core.NewNotLoginError()) + } else { + writeErrorResponse(c, core.NewNotLoginError()) + } c.Abort() return } + // Get loginID for further checks | 获取 loginID 用于后续检查 + var loginID string + if ann.CheckDisable || len(ann.CheckPermission) > 0 || len(ann.CheckRole) > 0 { + loginID, err = mgr.GetLoginIDNotCheck(ctx, token) + if err != nil { + writeErrorResponse(c, err) + c.Abort() + return + } + } + // Check if account is disabled | 检查是否被封禁 - if len(annotations) > 0 && annotations[0].CheckDisable { - if stputil.IsDisable(loginID) { - writeErrorResponse(c, core.NewAccountDisabledError(loginID)) + if ann.CheckDisable { + if mgr.IsDisable(ctx, loginID) { + if failFunc != nil { + failFunc(c, core.NewAccountDisabledError(loginID)) + } else { + writeErrorResponse(c, core.NewAccountDisabledError(loginID)) + } c.Abort() return } } // Check permission | 检查权限 - if len(annotations) > 0 && len(annotations[0].CheckPermission) > 0 { - hasPermission := false - for _, perm := range annotations[0].CheckPermission { - if stputil.HasPermission(loginID, strings.TrimSpace(perm)) { - hasPermission = true - break - } + if len(ann.CheckPermission) > 0 { + var ok bool + if ann.LogicType == LogicAnd { + ok = mgr.HasPermissionsAnd(ctx, loginID, ann.CheckPermission) + } else { + ok = mgr.HasPermissionsOr(ctx, loginID, ann.CheckPermission) } - if !hasPermission { - writeErrorResponse(c, core.NewPermissionDeniedError(strings.Join(annotations[0].CheckPermission, ","))) + if !ok { + if failFunc != nil { + failFunc(c, core.NewPermissionDeniedError(strings.Join(ann.CheckPermission, ","))) + } else { + writeErrorResponse(c, core.NewPermissionDeniedError(strings.Join(ann.CheckPermission, ","))) + } c.Abort() return } } // Check role | 检查角色 - if len(annotations) > 0 && len(annotations[0].CheckRole) > 0 { - hasRole := false - for _, role := range annotations[0].CheckRole { - if stputil.HasRole(loginID, strings.TrimSpace(role)) { - hasRole = true - break - } + if len(ann.CheckRole) > 0 { + var ok bool + if ann.LogicType == LogicAnd { + ok = mgr.HasRolesAnd(ctx, loginID, ann.CheckRole) + } else { + ok = mgr.HasRolesOr(ctx, loginID, ann.CheckRole) } - if !hasRole { - writeErrorResponse(c, core.NewRoleDeniedError(strings.Join(annotations[0].CheckRole, ","))) + if !ok { + if failFunc != nil { + failFunc(c, core.NewRoleDeniedError(strings.Join(ann.CheckRole, ","))) + } else { + writeErrorResponse(c, core.NewRoleDeniedError(strings.Join(ann.CheckRole, ","))) + } c.Abort() return } } - // All checks passed, execute original handler or continue | 所有检查通过,执行原函数或继续 - if callHandler(handler, c) { - return + // All checks passed, execute original handler | 所有检查通过,执行原函数 + if handler != nil { + handler(c) + } else { + c.Next() } - c.Next() } } -func callHandler(handler interface{}, c *ginfw.Context) bool { - if handler == nil { - return false - } - - switch h := handler.(type) { - case func(*ginfw.Context): - if h == nil { - return false - } - h(c) - return true - case ginfw.HandlerFunc: - if h == nil { - return false - } - h(c) - return true - } - - hv := reflect.ValueOf(handler) - if hv.Kind() != reflect.Func || hv.IsNil() || hv.Type().NumIn() != 1 { - return false - } - - argType := hv.Type().In(0) - if !argType.AssignableTo(reflect.TypeOf(c)) { - return false +// CheckLoginMiddleware decorator for login checking | 检查登录装饰器 +func CheckLoginMiddleware( + ctx context.Context, + handler gin.HandlerFunc, + failFunc func(c *gin.Context, err error), + authType ...string, +) gin.HandlerFunc { + ann := &Annotation{CheckLogin: true} + if len(authType) > 0 { + ann.AuthType = authType[0] } - - hv.Call([]reflect.Value{reflect.ValueOf(c)}) - return true -} - -// Decorator functions | 装饰器函数 - -// CheckLogin decorator for login checking | 检查登录装饰器 -func CheckLogin() ginfw.HandlerFunc { - return GetHandler(nil, &Annotation{CheckLogin: true}) + return GetHandler(ctx, handler, failFunc, ann) } -// CheckRole decorator for role checking | 检查角色装饰器 -func CheckRole(roles ...string) ginfw.HandlerFunc { - return GetHandler(nil, &Annotation{CheckRole: roles}) +// CheckRoleMiddleware decorator for role checking | 检查角色装饰器 +func CheckRoleMiddleware( + ctx context.Context, + roles []string, + handler gin.HandlerFunc, + failFunc func(c *gin.Context, err error), + authType ...string, +) gin.HandlerFunc { + ann := &Annotation{CheckRole: roles} + if len(authType) > 0 { + ann.AuthType = authType[0] + } + return GetHandler(ctx, handler, failFunc, ann) } -// CheckPermission decorator for permission checking | 检查权限装饰器 -func CheckPermission(perms ...string) ginfw.HandlerFunc { - return GetHandler(nil, &Annotation{CheckPermission: perms}) +// CheckPermissionMiddleware decorator for permission checking | 检查权限装饰器 +func CheckPermissionMiddleware( + ctx context.Context, + perms []string, + handler gin.HandlerFunc, + failFunc func(c *gin.Context, err error), + authType ...string, +) gin.HandlerFunc { + ann := &Annotation{CheckPermission: perms} + if len(authType) > 0 { + ann.AuthType = authType[0] + } + return GetHandler(ctx, handler, failFunc, ann) } -// CheckDisable decorator for checking if account is disabled | 检查是否被封禁装饰器 -func CheckDisable() ginfw.HandlerFunc { - return GetHandler(nil, &Annotation{CheckDisable: true}) +// CheckDisableMiddleware decorator for checking if account is disabled | 检查是否被封禁装饰器 +func CheckDisableMiddleware( + ctx context.Context, + handler gin.HandlerFunc, + failFunc func(c *gin.Context, err error), + authType ...string, +) gin.HandlerFunc { + ann := &Annotation{CheckDisable: true} + if len(authType) > 0 { + ann.AuthType = authType[0] + } + return GetHandler(ctx, handler, failFunc, ann) } -// Ignore decorator to ignore authentication | 忽略认证装饰器 -func Ignore() ginfw.HandlerFunc { - return GetHandler(nil, &Annotation{Ignore: true}) +// IgnoreMiddleware decorator to ignore authentication | 忽略认证装饰器 +func IgnoreMiddleware( + ctx context.Context, + handler gin.HandlerFunc, + failFunc func(c *gin.Context, err error), +) gin.HandlerFunc { + ann := &Annotation{Ignore: true} + return GetHandler(ctx, handler, failFunc, ann) } -// WithAnnotation decorator with custom annotation | 使用自定义注解装饰器 -func WithAnnotation(ann *Annotation) ginfw.HandlerFunc { - return GetHandler(nil, ann) +// ============ Combined Middleware | 组合中间件 ============ + +// CheckLoginAndRoleMiddleware checks login and role | 检查登录和角色 +func CheckLoginAndRoleMiddleware( + ctx context.Context, + roles []string, + handler gin.HandlerFunc, + failFunc func(c *gin.Context, err error), + authType ...string, +) gin.HandlerFunc { + ann := &Annotation{CheckLogin: true, CheckRole: roles} + if len(authType) > 0 { + ann.AuthType = authType[0] + } + return GetHandler(ctx, handler, failFunc, ann) } -// ProcessStructAnnotations processes annotations on struct tags | 处理结构体上的注解标签 -func ProcessStructAnnotations(handler interface{}) ginfw.HandlerFunc { - handlerValue := reflect.ValueOf(handler) - handlerType := reflect.TypeOf(handler) - - // Find method name, usually the last path segment | 查找方法名,通常是最后一个路径段 - methodName := "unknown" - if handlerType.Kind() == reflect.Ptr { - handlerType = handlerType.Elem() +// CheckLoginAndPermissionMiddleware checks login and permission | 检查登录和权限 +func CheckLoginAndPermissionMiddleware( + ctx context.Context, + perms []string, + handler gin.HandlerFunc, + failFunc func(c *gin.Context, err error), + authType ...string, +) gin.HandlerFunc { + ann := &Annotation{CheckLogin: true, CheckPermission: perms} + if len(authType) > 0 { + ann.AuthType = authType[0] } - if handlerType.Kind() == reflect.Struct { - methodName = handlerType.Name() - } - - // Parse method annotations | 解析方法上的注解标签 - ann := parseMethodAnnotation(handlerType, methodName) - - return GetHandler(func(c *ginfw.Context) { - handlerValue.MethodByName("ServeHTTP").Call([]reflect.Value{reflect.ValueOf(c)}) - }, ann) + return GetHandler(ctx, handler, failFunc, ann) } -// parseMethodAnnotation parses method annotations | 解析方法注解 -func parseMethodAnnotation(t reflect.Type, methodName string) *Annotation { - // Simplified implementation, returns empty annotation | 简化实现,直接返回空注解 - return &Annotation{} +// CheckAllMiddleware checks login, role, permission and disable status | 全面检查 +func CheckAllMiddleware( + ctx context.Context, + roles []string, + perms []string, + handler gin.HandlerFunc, + failFunc func(c *gin.Context, err error), + authType ...string, +) gin.HandlerFunc { + ann := &Annotation{CheckLogin: true, CheckRole: roles, CheckPermission: perms} + if len(authType) > 0 { + ann.AuthType = authType[0] + } + return GetHandler(ctx, handler, failFunc, ann) } -// HandlerWithAnnotations 带注解的处理器包装器 -type HandlerWithAnnotations struct { - Handler interface{} - Annotations []*Annotation +// ============ Route Group Helper | 路由组辅助函数 ============ + +// AuthGroup creates a route group with authentication | 创建带认证的路由组 +func AuthGroup( + ctx context.Context, + group *gin.RouterGroup, + handler gin.HandlerFunc, + failFunc func(c *gin.Context, err error), + authType ...string, +) *gin.RouterGroup { + group.Use(CheckLoginMiddleware(ctx, handler, failFunc, authType...)) + return group } -// NewHandlerWithAnnotations 创建带注解的处理器 -func NewHandlerWithAnnotations(handler interface{}, annotations ...*Annotation) *HandlerWithAnnotations { - return &HandlerWithAnnotations{ - Handler: handler, - Annotations: annotations, - } +// RoleGroup creates a route group with role checking | 创建带角色检查的路由组 +func RoleGroup( + ctx context.Context, + group *gin.RouterGroup, + roles []string, + handler gin.HandlerFunc, + failFunc func(c *gin.Context, err error), + authType ...string, +) *gin.RouterGroup { + group.Use(CheckLoginAndRoleMiddleware(ctx, roles, handler, failFunc, authType...)) + return group } -// ToGinHandler 转换为Gin处理器 -func (h *HandlerWithAnnotations) ToGinHandler() ginfw.HandlerFunc { - return GetHandler(h.Handler, h.Annotations...) +// PermissionGroup creates a route group with permission checking | 创建带权限检查的路由组 +func PermissionGroup( + ctx context.Context, + group *gin.RouterGroup, + perms []string, + handler gin.HandlerFunc, + failFunc func(c *gin.Context, err error), + authType ...string, +) *gin.RouterGroup { + group.Use(CheckLoginAndPermissionMiddleware(ctx, perms, handler, failFunc, authType...)) + return group } -// Middleware 创建中间件版本 -func Middleware(annotations ...*Annotation) ginfw.HandlerFunc { - return func(c *ginfw.Context) { - - // 检查是否忽略认证 - if len(annotations) > 0 && annotations[0].Ignore { - c.Next() - return - } - - // 获取Token(使用配置的TokenName) - ctx := NewGinContext(c) - saCtx := core.NewContext(ctx, stputil.GetManager()) - token := saCtx.GetTokenValue() - if token == "" { - writeErrorResponse(c, core.NewNotLoginError()) - c.Abort() - return - } - - // 检查登录 - if !stputil.IsLogin(token) { - writeErrorResponse(c, core.NewNotLoginError()) - c.Abort() - return - } - - // 获取登录ID - loginID, err := stputil.GetLoginID(token) - if err != nil { - writeErrorResponse(c, err) - c.Abort() - return - } - - // 检查是否被封禁 - if len(annotations) > 0 && annotations[0].CheckDisable { - if stputil.IsDisable(loginID) { - writeErrorResponse(c, core.NewAccountDisabledError(loginID)) - c.Abort() - return - } - } - - // 检查权限 - if len(annotations) > 0 && len(annotations[0].CheckPermission) > 0 { - hasPermission := false - for _, perm := range annotations[0].CheckPermission { - if stputil.HasPermission(loginID, strings.TrimSpace(perm)) { - hasPermission = true - break - } - } - if !hasPermission { - writeErrorResponse(c, core.NewPermissionDeniedError(strings.Join(annotations[0].CheckPermission, ","))) - c.Abort() - return - } - } - - // 检查角色 - if len(annotations) > 0 && len(annotations[0].CheckRole) > 0 { - hasRole := false - for _, role := range annotations[0].CheckRole { - if stputil.HasRole(loginID, strings.TrimSpace(role)) { - hasRole = true - break - } - } - if !hasRole { - writeErrorResponse(c, core.NewRoleDeniedError(strings.Join(annotations[0].CheckRole, ","))) - c.Abort() - return - } - } - - // 所有检查通过,继续下一个处理器 - c.Next() - } +// RoleAndPermissionGroup creates a route group with role and permission checking | 创建带角色和权限检查的路由组 +func RoleAndPermissionGroup( + ctx context.Context, + group *gin.RouterGroup, + roles []string, + perms []string, + handler gin.HandlerFunc, + failFunc func(c *gin.Context, err error), + authType ...string, +) *gin.RouterGroup { + group.Use(CheckAllMiddleware(ctx, roles, perms, handler, failFunc, authType...)) + return group } diff --git a/integrations/gin/annotation_test.go b/integrations/gin/annotation_test.go deleted file mode 100644 index 4d00279..0000000 --- a/integrations/gin/annotation_test.go +++ /dev/null @@ -1,486 +0,0 @@ -package gin - -import ( - "net/http" - "net/http/httptest" - "testing" - - "github.com/click33/sa-token-go/core/config" - "github.com/click33/sa-token-go/core/manager" - "github.com/click33/sa-token-go/storage/memory" - "github.com/click33/sa-token-go/stputil" - ginfw "github.com/gin-gonic/gin" - "github.com/stretchr/testify/assert" -) - -// setupTestRouter 创建测试路由器和初始化 sa-token -func setupTestRouter() *ginfw.Engine { - ginfw.SetMode(ginfw.TestMode) - router := ginfw.New() - - // 创建内存存储 - storage := memory.NewStorage() - - // 创建配置 - cfg := &config.Config{ - TokenName: "satoken", - Timeout: 2592000, // 30 天(秒) - IsConcurrent: true, - IsShare: true, - MaxLoginCount: -1, - } - - // 创建并设置全局 Manager - mgr := manager.NewManager(storage, cfg) - stputil.SetManager(mgr) - - return router -} - -// mockLogin 模拟用户登录并返回 token -func mockLogin(loginID interface{}) string { - token, _ := stputil.Login(loginID) - return token -} - -// mockLoginWithRole 模拟用户登录并设置角色 -func mockLoginWithRole(loginID interface{}, roles []string) string { - token, _ := stputil.Login(loginID) - stputil.SetRoles(loginID, roles) - return token -} - -// mockLoginWithPermission 模拟用户登录并设置权限 -func mockLoginWithPermission(loginID interface{}, permissions []string) string { - token, _ := stputil.Login(loginID) - stputil.SetPermissions(loginID, permissions) - return token -} - -// TestCheckRole_WithValidRole 测试具有有效角色的用户访问 -func TestCheckRole_WithValidRole(t *testing.T) { - router := setupTestRouter() - - // 设置路由 - 使用 CheckRole 作为中间件 - router.GET("/admin", CheckRole("Admin"), func(c *ginfw.Context) { - c.JSON(http.StatusOK, ginfw.H{"message": "success"}) - }) - - // 创建一个具有 Admin 角色的用户 - token := mockLoginWithRole("user123", []string{"Admin"}) - - // 发送请求 - w := httptest.NewRecorder() - req, _ := http.NewRequest("GET", "/admin", nil) - req.Header.Set("Authorization", token) - router.ServeHTTP(w, req) - - // 断言 - assert.Equal(t, http.StatusOK, w.Code) - assert.Contains(t, w.Body.String(), "success") -} - -// TestCheckRole_WithInvalidRole 测试没有所需角色的用户访问 -func TestCheckRole_WithInvalidRole(t *testing.T) { - router := setupTestRouter() - - // 设置路由 - router.GET("/admin", CheckRole("Admin"), func(c *ginfw.Context) { - c.JSON(http.StatusOK, ginfw.H{"message": "success"}) - }) - - // 创建一个只有 User 角色的用户 - token := mockLoginWithRole("user456", []string{"User"}) - - // 发送请求 - w := httptest.NewRecorder() - req, _ := http.NewRequest("GET", "/admin", nil) - req.Header.Set("Authorization", token) - router.ServeHTTP(w, req) - - // 断言 - assert.Equal(t, http.StatusForbidden, w.Code) - assert.Contains(t, w.Body.String(), "角色不足") -} - -// TestCheckRole_MultipleRoles 测试多个角色的情况(OR 逻辑) -func TestCheckRole_MultipleRoles(t *testing.T) { - router := setupTestRouter() - - // 设置路由 - 需要 Admin 或 SuperAdmin 角色 - router.GET("/manage", CheckRole("Admin", "SuperAdmin"), func(c *ginfw.Context) { - c.JSON(http.StatusOK, ginfw.H{"message": "success"}) - }) - - // 测试具有 SuperAdmin 角色的用户 - token := mockLoginWithRole("superuser", []string{"SuperAdmin"}) - - w := httptest.NewRecorder() - req, _ := http.NewRequest("GET", "/manage", nil) - req.Header.Set("Authorization", token) - router.ServeHTTP(w, req) - - assert.Equal(t, http.StatusOK, w.Code) - assert.Contains(t, w.Body.String(), "success") -} - -// TestCheckRole_NoToken 测试未提供 token 的情况 -func TestCheckRole_NoToken(t *testing.T) { - router := setupTestRouter() - - router.GET("/admin", CheckRole("Admin"), func(c *ginfw.Context) { - c.JSON(http.StatusOK, ginfw.H{"message": "success"}) - }) - - w := httptest.NewRecorder() - req, _ := http.NewRequest("GET", "/admin", nil) - router.ServeHTTP(w, req) - - assert.Equal(t, http.StatusUnauthorized, w.Code) - assert.Contains(t, w.Body.String(), "未登录") -} - -// TestCheckRole_InvalidToken 测试无效 token 的情况 -func TestCheckRole_InvalidToken(t *testing.T) { - router := setupTestRouter() - - router.GET("/admin", CheckRole("Admin"), func(c *ginfw.Context) { - c.JSON(http.StatusOK, ginfw.H{"message": "success"}) - }) - - w := httptest.NewRecorder() - req, _ := http.NewRequest("GET", "/admin", nil) - req.Header.Set("Authorization", "invalid-token-12345") - router.ServeHTTP(w, req) - - assert.Equal(t, http.StatusUnauthorized, w.Code) - assert.Contains(t, w.Body.String(), "未登录") -} - -// TestCheckPermission_WithValidPermission 测试具有有效权限的用户访问 -func TestCheckPermission_WithValidPermission(t *testing.T) { - router := setupTestRouter() - - router.GET("/users", CheckPermission("user.read"), func(c *ginfw.Context) { - c.JSON(http.StatusOK, ginfw.H{"message": "success"}) - }) - - token := mockLoginWithPermission("user789", []string{"user.read"}) - - w := httptest.NewRecorder() - req, _ := http.NewRequest("GET", "/users", nil) - req.Header.Set("Authorization", token) - router.ServeHTTP(w, req) - - assert.Equal(t, http.StatusOK, w.Code) - assert.Contains(t, w.Body.String(), "success") -} - -// TestCheckPermission_WithInvalidPermission 测试没有所需权限的用户访问 -func TestCheckPermission_WithInvalidPermission(t *testing.T) { - router := setupTestRouter() - - router.GET("/users", CheckPermission("user.delete"), func(c *ginfw.Context) { - c.JSON(http.StatusOK, ginfw.H{"message": "success"}) - }) - - token := mockLoginWithPermission("user789", []string{"user.read"}) - - w := httptest.NewRecorder() - req, _ := http.NewRequest("GET", "/users", nil) - req.Header.Set("Authorization", token) - router.ServeHTTP(w, req) - - assert.Equal(t, http.StatusForbidden, w.Code) - assert.Contains(t, w.Body.String(), "权限不足") -} - -// TestCheckLogin_Success 测试登录检查成功 -func TestCheckLogin_Success(t *testing.T) { - router := setupTestRouter() - - router.GET("/profile", CheckLogin(), func(c *ginfw.Context) { - c.JSON(http.StatusOK, ginfw.H{"message": "profile data"}) - }) - - token := mockLogin("user999") - - w := httptest.NewRecorder() - req, _ := http.NewRequest("GET", "/profile", nil) - req.Header.Set("Authorization", token) - router.ServeHTTP(w, req) - - assert.Equal(t, http.StatusOK, w.Code) - assert.Contains(t, w.Body.String(), "profile data") -} - -// TestCheckLogin_Failed 测试登录检查失败 -func TestCheckLogin_Failed(t *testing.T) { - router := setupTestRouter() - - router.GET("/profile", CheckLogin(), func(c *ginfw.Context) { - c.JSON(http.StatusOK, ginfw.H{"message": "profile data"}) - }) - - w := httptest.NewRecorder() - req, _ := http.NewRequest("GET", "/profile", nil) - router.ServeHTTP(w, req) - - assert.Equal(t, http.StatusUnauthorized, w.Code) - assert.Contains(t, w.Body.String(), "未登录") -} - -// TestCheckDisable_NotDisabled 测试账号未被封禁的情况 -func TestCheckDisable_NotDisabled(t *testing.T) { - router := setupTestRouter() - - router.GET("/resource", CheckDisable(), func(c *ginfw.Context) { - c.JSON(http.StatusOK, ginfw.H{"message": "resource data"}) - }) - - token := mockLogin("user101") - - w := httptest.NewRecorder() - req, _ := http.NewRequest("GET", "/resource", nil) - req.Header.Set("Authorization", token) - router.ServeHTTP(w, req) - - assert.Equal(t, http.StatusOK, w.Code) - assert.Contains(t, w.Body.String(), "resource data") -} - -// TestCheckDisable_IsDisabled 测试账号被封禁的情况 -func TestCheckDisable_IsDisabled(t *testing.T) { - router := setupTestRouter() - - router.GET("/resource", CheckDisable(), func(c *ginfw.Context) { - c.JSON(http.StatusOK, ginfw.H{"message": "resource data"}) - }) - - loginID := "user102" - token := mockLogin(loginID) - - // 封禁账号 - stputil.Disable(loginID, 3600) // 封禁 1 小时 - - w := httptest.NewRecorder() - req, _ := http.NewRequest("GET", "/resource", nil) - req.Header.Set("Authorization", token) - router.ServeHTTP(w, req) - - assert.Equal(t, http.StatusForbidden, w.Code) - assert.Contains(t, w.Body.String(), "账号已被封禁") -} - -// TestIgnore_SkipsAuthentication 测试忽略认证装饰器 -func TestIgnore_SkipsAuthentication(t *testing.T) { - router := setupTestRouter() - - router.GET("/public", Ignore(), func(c *ginfw.Context) { - c.JSON(http.StatusOK, ginfw.H{"message": "public data"}) - }) - - w := httptest.NewRecorder() - req, _ := http.NewRequest("GET", "/public", nil) - // 不提供任何 token - router.ServeHTTP(w, req) - - assert.Equal(t, http.StatusOK, w.Code) - assert.Contains(t, w.Body.String(), "public data") -} - -// TestChainedMiddleware_CheckRoleAndHandler 测试链式中间件:CheckRole + 实际处理器 -func TestChainedMiddleware_CheckRoleAndHandler(t *testing.T) { - router := setupTestRouter() - - // 模拟用户示例代码的使用方式 - safeGroup := router.Group("/safe") - { - safeGroup.GET("", CheckRole("SuperAdmin"), func(c *ginfw.Context) { - c.JSON(http.StatusOK, ginfw.H{"message": "safe settings"}) - }) - } - - // 测试具有 SuperAdmin 角色的用户 - token := mockLoginWithRole("admin123", []string{"SuperAdmin"}) - - w := httptest.NewRecorder() - req, _ := http.NewRequest("GET", "/safe", nil) - req.Header.Set("Authorization", token) - router.ServeHTTP(w, req) - - assert.Equal(t, http.StatusOK, w.Code) - assert.Contains(t, w.Body.String(), "safe settings") -} - -// TestChainedMiddleware_CheckRoleAndHandler_NoRole 测试链式中间件:无角色访问 -func TestChainedMiddleware_CheckRoleAndHandler_NoRole(t *testing.T) { - router := setupTestRouter() - - safeGroup := router.Group("/safe") - { - safeGroup.GET("", CheckRole("SuperAdmin"), func(c *ginfw.Context) { - c.JSON(http.StatusOK, ginfw.H{"message": "safe settings"}) - }) - } - - // 测试具有普通用户角色 - token := mockLoginWithRole("user123", []string{"User"}) - - w := httptest.NewRecorder() - req, _ := http.NewRequest("GET", "/safe", nil) - req.Header.Set("Authorization", token) - router.ServeHTTP(w, req) - - assert.Equal(t, http.StatusForbidden, w.Code) - assert.Contains(t, w.Body.String(), "角色不足") -} - -// TestGetHandler_WithNilHandler 测试 GetHandler 在 handler 为 nil 时的行为 -func TestGetHandler_WithNilHandler(t *testing.T) { - router := setupTestRouter() - - // 直接使用 GetHandler 创建中间件 - middleware := GetHandler(nil, &Annotation{CheckRole: []string{"Admin"}}) - - router.GET("/test", middleware, func(c *ginfw.Context) { - c.JSON(http.StatusOK, ginfw.H{"message": "test passed"}) - }) - - token := mockLoginWithRole("testuser", []string{"Admin"}) - - w := httptest.NewRecorder() - req, _ := http.NewRequest("GET", "/test", nil) - req.Header.Set("Authorization", token) - router.ServeHTTP(w, req) - - // 应该能够正常执行,不会 panic - assert.Equal(t, http.StatusOK, w.Code) - assert.Contains(t, w.Body.String(), "test passed") -} - -// TestMiddleware_CheckRole 测试 Middleware 函数的角色检查 -func TestMiddleware_CheckRole(t *testing.T) { - router := setupTestRouter() - - // 使用 Middleware 函数 - router.GET("/api/data", Middleware(&Annotation{CheckRole: []string{"Admin"}}), func(c *ginfw.Context) { - c.JSON(http.StatusOK, ginfw.H{"data": "sensitive data"}) - }) - - token := mockLoginWithRole("admin999", []string{"Admin"}) - - w := httptest.NewRecorder() - req, _ := http.NewRequest("GET", "/api/data", nil) - req.Header.Set("Authorization", token) - router.ServeHTTP(w, req) - - assert.Equal(t, http.StatusOK, w.Code) - assert.Contains(t, w.Body.String(), "sensitive data") -} - -// TestParseTag 测试标签解析功能 -func TestParseTag(t *testing.T) { - tests := []struct { - name string - tag string - expected *Annotation - }{ - { - name: "解析登录检查标签", - tag: "sa_check_login", - expected: &Annotation{ - CheckLogin: true, - }, - }, - { - name: "解析角色检查标签", - tag: "sa_check_role=Admin|SuperAdmin", - expected: &Annotation{ - CheckRole: []string{"Admin", "SuperAdmin"}, - }, - }, - { - name: "解析权限检查标签", - tag: "sa_check_permission=user.read|user.write", - expected: &Annotation{ - CheckPermission: []string{"user.read", "user.write"}, - }, - }, - { - name: "解析忽略认证标签", - tag: "sa_ignore", - expected: &Annotation{ - Ignore: true, - }, - }, - { - name: "解析封禁检查标签", - tag: "sa_check_disable", - expected: &Annotation{ - CheckDisable: true, - }, - }, - { - name: "空标签", - tag: "", - expected: &Annotation{}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := ParseTag(tt.tag) - assert.Equal(t, tt.expected.CheckLogin, result.CheckLogin) - assert.Equal(t, tt.expected.CheckRole, result.CheckRole) - assert.Equal(t, tt.expected.CheckPermission, result.CheckPermission) - assert.Equal(t, tt.expected.CheckDisable, result.CheckDisable) - assert.Equal(t, tt.expected.Ignore, result.Ignore) - }) - } -} - -// TestAnnotationValidate 测试注解验证功能 -func TestAnnotationValidate(t *testing.T) { - tests := []struct { - name string - annotation *Annotation - valid bool - }{ - { - name: "有效的单一检查", - annotation: &Annotation{ - CheckLogin: true, - }, - valid: true, - }, - { - name: "有效的忽略标记", - annotation: &Annotation{ - Ignore: true, - CheckLogin: true, // 即使有其他标记,忽略时仍然有效 - }, - valid: true, - }, - { - name: "有效的空注解", - annotation: &Annotation{}, - valid: true, - }, - { - name: "无效的多重检查", - annotation: &Annotation{ - CheckLogin: true, - CheckRole: []string{"Admin"}, - }, - valid: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := tt.annotation.Validate() - assert.Equal(t, tt.valid, result) - }) - } -} diff --git a/integrations/gin/context.go b/integrations/gin/context.go index 55dffed..042bfd6 100644 --- a/integrations/gin/context.go +++ b/integrations/gin/context.go @@ -7,7 +7,6 @@ import ( "github.com/gin-gonic/gin" ) -// GinContext Gin request context adapter | Gin请求上下文适配器 type GinContext struct { c *gin.Context aborted bool @@ -15,59 +14,61 @@ type GinContext struct { // NewGinContext creates a Gin context adapter | 创建Gin上下文适配器 func NewGinContext(c *gin.Context) adapter.RequestContext { - return &GinContext{c: c} + return &GinContext{ + c: c, + } } -// GetHeader gets request header | 获取请求头 -func (g *GinContext) GetHeader(key string) string { - return g.c.GetHeader(key) +// Get implements adapter.RequestContext. +func (g *GinContext) Get(key string) (interface{}, bool) { + return g.c.Get(key) } -// GetQuery gets query parameter | 获取查询参数 -func (g *GinContext) GetQuery(key string) string { - return g.c.Query(key) +// GetClientIP implements adapter.RequestContext. +func (g *GinContext) GetClientIP() string { + return g.c.ClientIP() } -// GetCookie gets cookie | 获取Cookie +// GetCookie implements adapter.RequestContext. func (g *GinContext) GetCookie(key string) string { cookie, _ := g.c.Cookie(key) return cookie } -// SetHeader sets response header | 设置响应头 -func (g *GinContext) SetHeader(key, value string) { - g.c.Header(key, value) -} - -// SetCookie sets cookie | 设置Cookie -func (g *GinContext) SetCookie(name, value string, maxAge int, path, domain string, secure, httpOnly bool) { - g.c.SetCookie(name, value, maxAge, path, domain, secure, httpOnly) - g.c.SetSameSite(http.SameSiteLaxMode) -} - -// GetClientIP gets client IP address | 获取客户端IP地址 -func (g *GinContext) GetClientIP() string { - return g.c.ClientIP() +// GetHeader implements adapter.RequestContext. +func (g *GinContext) GetHeader(key string) string { + return g.c.GetHeader(key) } -// GetMethod gets request method | 获取请求方法 +// GetMethod implements adapter.RequestContext. func (g *GinContext) GetMethod() string { return g.c.Request.Method } -// GetPath gets request path | 获取请求路径 +// GetPath implements adapter.RequestContext. func (g *GinContext) GetPath() string { return g.c.Request.URL.Path } -// Set sets context value | 设置上下文值 +// GetQuery implements adapter.RequestContext. +func (g *GinContext) GetQuery(key string) string { + return g.c.Query(key) +} + +// Set implements adapter.RequestContext. func (g *GinContext) Set(key string, value interface{}) { g.c.Set(key, value) } -// Get gets context value | 获取上下文值 -func (g *GinContext) Get(key string) (interface{}, bool) { - return g.c.Get(key) +// SetCookie implements adapter.RequestContext. +func (g *GinContext) SetCookie(name string, value string, maxAge int, path string, domain string, secure bool, httpOnly bool) { + g.c.SetCookie(name, value, maxAge, path, domain, secure, httpOnly) + g.c.SetSameSite(http.SameSiteLaxMode) +} + +// SetHeader implements adapter.RequestContext. +func (g *GinContext) SetHeader(key string, value string) { + g.c.Header(key, value) } // ============ Additional Required Methods | 额外必需的方法 ============ @@ -113,7 +114,7 @@ func (g *GinContext) SetCookieWithOptions(options *adapter.CookieOptions) { options.Secure, options.HttpOnly, ) - + // Set SameSite attribute switch options.SameSite { case "Strict": @@ -127,23 +128,17 @@ func (g *GinContext) SetCookieWithOptions(options *adapter.CookieOptions) { // GetString implements adapter.RequestContext. func (g *GinContext) GetString(key string) string { - value, exists := g.c.Get(key) - if !exists { - return "" - } - if str, ok := value.(string); ok { - return str - } - return "" + v := g.c.GetString(key) + return v } // MustGet implements adapter.RequestContext. func (g *GinContext) MustGet(key string) any { - value, exists := g.c.Get(key) + v, exists := g.c.Get(key) if !exists { panic("key not found: " + key) } - return value + return v } // Abort implements adapter.RequestContext. diff --git a/integrations/gin/export.go b/integrations/gin/export.go index 56177a2..06a60d0 100644 --- a/integrations/gin/export.go +++ b/integrations/gin/export.go @@ -1,364 +1,945 @@ package gin import ( + "context" "time" + "github.com/click33/sa-token-go/codec/json" + "github.com/click33/sa-token-go/codec/msgpack" "github.com/click33/sa-token-go/core" + "github.com/click33/sa-token-go/core/adapter" + "github.com/click33/sa-token-go/core/builder" + "github.com/click33/sa-token-go/core/config" + "github.com/click33/sa-token-go/core/listener" + "github.com/click33/sa-token-go/core/manager" + "github.com/click33/sa-token-go/core/oauth2" + "github.com/click33/sa-token-go/core/security" + "github.com/click33/sa-token-go/core/session" + "github.com/click33/sa-token-go/generator/sgenerator" + "github.com/click33/sa-token-go/log/nop" + "github.com/click33/sa-token-go/log/slog" + "github.com/click33/sa-token-go/pool/ants" + "github.com/click33/sa-token-go/storage/memory" + "github.com/click33/sa-token-go/storage/redis" "github.com/click33/sa-token-go/stputil" ) -// ============ Re-export core types | 重新导出核心类型 ============ +// ============ Type Aliases | 类型别名 ============ -// Configuration related types | 配置相关类型 type ( - Config = core.Config - CookieConfig = core.CookieConfig - TokenStyle = core.TokenStyle + // Config 配置 + Config = config.Config + // Manager 管理器 + Manager = manager.Manager + // Session 会话 + Session = session.Session + // TokenInfo Token信息 + TokenInfo = manager.TokenInfo + // DisableInfo 封禁信息 + DisableInfo = manager.DisableInfo + // Builder 构建器 + Builder = builder.Builder + // SaTokenError 错误类型 + SaTokenError = core.SaTokenError + // Event 事件类型 + Event = listener.Event + // EventData 事件数据 + EventData = listener.EventData + // Listener 事件监听器 + Listener = listener.Listener + // ListenerConfig 监听器配置 + ListenerConfig = listener.ListenerConfig + // RefreshTokenInfo 刷新令牌信息 + RefreshTokenInfo = security.RefreshTokenInfo + // AccessTokenInfo 访问令牌信息 + AccessTokenInfo = security.AccessTokenInfo + // OAuth2Client OAuth2客户端 + OAuth2Client = oauth2.Client + // OAuth2AccessToken OAuth2访问令牌 + OAuth2AccessToken = oauth2.AccessToken + // AuthorizationCode 授权码 + AuthorizationCode = oauth2.AuthorizationCode + // OAuth2TokenRequest OAuth2令牌请求 + OAuth2TokenRequest = oauth2.TokenRequest + // OAuth2GrantType OAuth2授权类型 + OAuth2GrantType = oauth2.GrantType + // OAuth2UserValidator OAuth2用户验证器 + OAuth2UserValidator = oauth2.UserValidator + // Storage 存储接口 + Storage = adapter.Storage + // Codec 编解码接口 + Codec = adapter.Codec + // Log 日志接口 + Log = adapter.Log + // Pool 协程池接口 + Pool = adapter.Pool + // Generator 生成器接口 + Generator = adapter.Generator + + // ============ Codec Types | 编解码器类型 ============ + + // JSONSerializer JSON编解码器 + JSONSerializer = json.JSONSerializer + // MsgPackSerializer MsgPack编解码器 + MsgPackSerializer = msgpack.MsgPackSerializer + + // ============ Storage Types | 存储类型 ============ + + // MemoryStorage 内存存储 + MemoryStorage = memory.Storage + // RedisStorage Redis存储 + RedisStorage = redis.Storage + // RedisConfig Redis配置 + RedisConfig = redis.Config + // RedisBuilder Redis构建器 + RedisBuilder = redis.Builder + + // ============ Logger Types | 日志类型 ============ + + // SlogLogger 标准日志实现 + SlogLogger = slog.Logger + // SlogLoggerConfig 标准日志配置 + SlogLoggerConfig = slog.LoggerConfig + // SlogLogLevel 日志级别 + SlogLogLevel = slog.LogLevel + // NopLogger 空日志实现 + NopLogger = nop.NopLogger + + // ============ Generator Types | 生成器类型 ============ + + // TokenGenerator Token生成器 + TokenGenerator = sgenerator.Generator + // TokenStyle Token风格 + TokenStyle = adapter.TokenStyle + + // ============ Pool Types | 协程池类型 ============ + + // RenewPoolManager 续期池管理器 + RenewPoolManager = ants.RenewPoolManager + // RenewPoolConfig 续期池配置 + RenewPoolConfig = ants.RenewPoolConfig ) -// Token style constants | Token风格常量 +// ============ Error Codes | 错误码 ============ + const ( - TokenStyleUUID = core.TokenStyleUUID - TokenStyleSimple = core.TokenStyleSimple - TokenStyleRandom32 = core.TokenStyleRandom32 - TokenStyleRandom64 = core.TokenStyleRandom64 - TokenStyleRandom128 = core.TokenStyleRandom128 - TokenStyleJWT = core.TokenStyleJWT - TokenStyleHash = core.TokenStyleHash - TokenStyleTimestamp = core.TokenStyleTimestamp - TokenStyleTik = core.TokenStyleTik + CodeSuccess = core.CodeSuccess + CodeBadRequest = core.CodeBadRequest + CodeNotLogin = core.CodeNotLogin + CodePermissionDenied = core.CodePermissionDenied + CodeNotFound = core.CodeNotFound + CodeServerError = core.CodeServerError + CodeTokenInvalid = core.CodeTokenInvalid + CodeTokenExpired = core.CodeTokenExpired + CodeAccountDisabled = core.CodeAccountDisabled + CodeKickedOut = core.CodeKickedOut + CodeActiveTimeout = core.CodeActiveTimeout + CodeMaxLoginCount = core.CodeMaxLoginCount + CodeStorageError = core.CodeStorageError + CodeInvalidParameter = core.CodeInvalidParameter + CodeSessionError = core.CodeSessionError ) -// Core types | 核心类型 -type ( - Manager = core.Manager - TokenInfo = core.TokenInfo - Session = core.Session - TokenGenerator = core.TokenGenerator - SaTokenContext = core.SaTokenContext - Builder = core.Builder - NonceManager = core.NonceManager - RefreshTokenInfo = core.RefreshTokenInfo - RefreshTokenManager = core.RefreshTokenManager - OAuth2Server = core.OAuth2Server - OAuth2Client = core.OAuth2Client - OAuth2AccessToken = core.OAuth2AccessToken - OAuth2GrantType = core.OAuth2GrantType -) +// ============ Errors | 错误变量 ============ -// Adapter interfaces | 适配器接口 -type ( - Storage = core.Storage - RequestContext = core.RequestContext +var ( + // Authentication Errors | 认证错误 + ErrNotLogin = core.ErrNotLogin + ErrTokenInvalid = core.ErrTokenInvalid + ErrTokenExpired = core.ErrTokenExpired + ErrTokenKickout = core.ErrTokenKickout + ErrTokenReplaced = core.ErrTokenReplaced + ErrInvalidLoginID = core.ErrInvalidLoginID + ErrInvalidDevice = core.ErrInvalidDevice + ErrTokenNotFound = core.ErrTokenNotFound + + // Authorization Errors | 授权错误 + ErrPermissionDenied = core.ErrPermissionDenied + ErrRoleDenied = core.ErrRoleDenied + + // Account Errors | 账号错误 + ErrAccountDisabled = core.ErrAccountDisabled + ErrAccountNotFound = core.ErrAccountNotFound + ErrLoginLimitExceeded = core.ErrLoginLimitExceeded + + // Session Errors | 会话错误 + ErrSessionNotFound = core.ErrSessionNotFound + ErrActiveTimeout = core.ErrActiveTimeout + ErrSessionInvalidDataKey = core.ErrSessionInvalidDataKey + ErrSessionIDEmpty = core.ErrSessionIDEmpty + + // Security Errors | 安全错误 + ErrInvalidNonce = core.ErrInvalidNonce + ErrRefreshTokenExpired = core.ErrRefreshTokenExpired + ErrNonceInvalidRefreshToken = core.ErrNonceInvalidRefreshToken + ErrInvalidLoginIDEmpty = core.ErrInvalidLoginIDEmpty + + // OAuth2 Errors | OAuth2错误 + ErrClientOrClientIDEmpty = core.ErrClientOrClientIDEmpty + ErrClientNotFound = core.ErrClientNotFound + ErrUserIDEmpty = core.ErrUserIDEmpty + ErrInvalidRedirectURI = core.ErrInvalidRedirectURI + ErrInvalidClientCredentials = core.ErrInvalidClientCredentials + ErrInvalidAuthCode = core.ErrInvalidAuthCode + ErrAuthCodeUsed = core.ErrAuthCodeUsed + ErrAuthCodeExpired = core.ErrAuthCodeExpired + ErrClientMismatch = core.ErrClientMismatch + ErrRedirectURIMismatch = core.ErrRedirectURIMismatch + ErrInvalidAccessToken = core.ErrInvalidAccessToken + ErrInvalidRefreshToken = core.ErrInvalidRefreshToken + ErrInvalidScope = core.ErrInvalidScope + + // System Errors | 系统错误 + ErrStorageUnavailable = core.ErrStorageUnavailable + ErrSerializeFailed = core.ErrSerializeFailed + ErrDeserializeFailed = core.ErrDeserializeFailed + ErrTypeConvert = core.ErrTypeConvert ) -// Event related types | 事件相关类型 -type ( - EventListener = core.EventListener - EventManager = core.EventManager - EventData = core.EventData - Event = core.Event - ListenerFunc = core.ListenerFunc - ListenerConfig = core.ListenerConfig -) +// ============ Error Constructors | 错误构造函数 ============ -// Event constants | 事件常量 -const ( - EventLogin = core.EventLogin - EventLogout = core.EventLogout - EventKickout = core.EventKickout - EventDisable = core.EventDisable - EventUntie = core.EventUntie - EventRenew = core.EventRenew - EventCreateSession = core.EventCreateSession - EventDestroySession = core.EventDestroySession - EventPermissionCheck = core.EventPermissionCheck - EventRoleCheck = core.EventRoleCheck - EventAll = core.EventAll +var ( + NewError = core.NewError + NewErrorWithContext = core.NewErrorWithContext + NewNotLoginError = core.NewNotLoginError + NewPermissionDeniedError = core.NewPermissionDeniedError + NewRoleDeniedError = core.NewRoleDeniedError + NewAccountDisabledError = core.NewAccountDisabledError ) -// OAuth2 grant type constants | OAuth2授权类型常量 -const ( - GrantTypeAuthorizationCode = core.GrantTypeAuthorizationCode - GrantTypeRefreshToken = core.GrantTypeRefreshToken - GrantTypeClientCredentials = core.GrantTypeClientCredentials - GrantTypePassword = core.GrantTypePassword -) +// ============ Error Checking Helpers | 错误检查辅助函数 ============ -// Utility functions | 工具函数 var ( - RandomString = core.RandomString - IsEmpty = core.IsEmpty - IsNotEmpty = core.IsNotEmpty - DefaultString = core.DefaultString - ContainsString = core.ContainsString - RemoveString = core.RemoveString - UniqueStrings = core.UniqueStrings - MergeStrings = core.MergeStrings - MatchPattern = core.MatchPattern + IsNotLoginError = core.IsNotLoginError + IsPermissionDeniedError = core.IsPermissionDeniedError + IsAccountDisabledError = core.IsAccountDisabledError + IsTokenError = core.IsTokenError + GetErrorCode = core.GetErrorCode ) -// ============ Core constructor functions | 核心构造函数 ============ +// ============ Manager Management | Manager 管理 ============ + +// SetManager stores the manager-example in the global map using the specified autoType | 使用指定的 autoType 将管理器存储在全局 map 中 +func SetManager(mgr *manager.Manager) { + stputil.SetManager(mgr) +} + +// GetManager retrieves the manager-example from the global map using the specified autoType | 使用指定的 autoType 从全局 map 中获取管理器 +func GetManager(autoType ...string) (*manager.Manager, error) { + return stputil.GetManager(autoType...) +} -// DefaultConfig returns default configuration | 返回默认配置 -func DefaultConfig() *Config { - return core.DefaultConfig() +// DeleteManager delete the specific manager-example for the given autoType and releases resources | 删除指定的管理器并释放资源 +func DeleteManager(autoType ...string) error { + return stputil.DeleteManager(autoType...) } -// NewManager creates a new authentication manager | 创建新的认证管理器 -func NewManager(storage Storage, cfg *Config) *Manager { - return core.NewManager(storage, cfg) +// DeleteAllManager delete all managers in the global map and releases resources | 关闭所有管理器并释放资源 +func DeleteAllManager() { + stputil.DeleteAllManager() } -// NewContext creates a new Sa-Token context | 创建新的Sa-Token上下文 -func NewContext(ctx RequestContext, mgr *Manager) *SaTokenContext { - return core.NewContext(ctx, mgr) +// ============ Builder & Config | 构建器和配置 ============ + +// NewDefaultBuild creates a new default builder | 创建默认构建器 +func NewDefaultBuild() *builder.Builder { + return builder.NewBuilder() } -// NewSession creates a new session | 创建新的Session -func NewSession(id string, storage Storage, prefix string) *Session { - return core.NewSession(id, storage, prefix) +// NewDefaultConfig creates a new default config | 创建默认配置 +func NewDefaultConfig() *config.Config { + return config.DefaultConfig() } -// LoadSession loads an existing session | 加载已存在的Session -func LoadSession(id string, storage Storage, prefix string) (*Session, error) { - return core.LoadSession(id, storage, prefix) +// DefaultLoggerConfig returns the default logger config | 返回默认日志配置 +func DefaultLoggerConfig() *slog.LoggerConfig { + return slog.DefaultLoggerConfig() } -// NewTokenGenerator creates a new token generator | 创建新的Token生成器 -func NewTokenGenerator(cfg *Config) *TokenGenerator { - return core.NewTokenGenerator(cfg) +// DefaultRenewPoolConfig returns the default renew pool config | 返回默认续期池配置 +func DefaultRenewPoolConfig() *ants.RenewPoolConfig { + return ants.DefaultRenewPoolConfig() } -// NewEventManager creates a new event manager | 创建新的事件管理器 -func NewEventManager() *EventManager { - return core.NewEventManager() +// ============ Codec Constructors | 编解码器构造函数 ============ + +// NewJSONSerializer creates a new JSON serializer | 创建JSON序列化器 +func NewJSONSerializer() *json.JSONSerializer { + return json.NewJSONSerializer() } -// NewBuilder creates a new builder for fluent configuration | 创建新的Builder构建器(用于流式配置) -func NewBuilder() *Builder { - return core.NewBuilder() +// NewMsgPackSerializer creates a new MsgPack serializer | 创建MsgPack序列化器 +func NewMsgPackSerializer() *msgpack.MsgPackSerializer { + return msgpack.NewMsgPackSerializer() } -// NewNonceManager creates a new nonce manager | 创建新的Nonce管理器 -func NewNonceManager(storage Storage, prefix string, ttl ...int64) *NonceManager { - return core.NewNonceManager(storage, prefix, ttl...) +// ============ Storage Constructors | 存储构造函数 ============ + +// NewMemoryStorage creates a new memory storage | 创建内存存储 +func NewMemoryStorage() *memory.Storage { + return memory.NewStorage() } -// NewRefreshTokenManager creates a new refresh token manager | 创建新的刷新令牌管理器 -func NewRefreshTokenManager(storage Storage, prefix string, cfg *Config) *RefreshTokenManager { - return core.NewRefreshTokenManager(storage, prefix, cfg) +// NewMemoryStorageWithCleanupInterval creates a new memory storage with cleanup interval | 创建带清理间隔的内存存储 +func NewMemoryStorageWithCleanupInterval(interval time.Duration) *memory.Storage { + return memory.NewStorageWithCleanupInterval(interval) } -// NewOAuth2Server creates a new OAuth2 server | 创建新的OAuth2服务器 -func NewOAuth2Server(storage Storage, prefix string) *OAuth2Server { - return core.NewOAuth2Server(storage, prefix) +// NewRedisStorage creates a new Redis storage from URL | 通过URL创建Redis存储 +func NewRedisStorage(url string) (*redis.Storage, error) { + return redis.NewStorage(url) } -// ============ Global StpUtil functions | 全局StpUtil函数 ============ +// NewRedisStorageFromConfig creates a new Redis storage from config | 通过配置创建Redis存储 +func NewRedisStorageFromConfig(cfg *redis.Config) (*redis.Storage, error) { + return redis.NewStorageFromConfig(cfg) +} -// SetManager sets the global Manager (must be called first) | 设置全局Manager(必须先调用此方法) -func SetManager(mgr *Manager) { - stputil.SetManager(mgr) +// NewRedisBuilder creates a new Redis builder | 创建Redis构建器 +func NewRedisBuilder() *redis.Builder { + return redis.NewBuilder() +} + +// ============ Logger Constructors | 日志构造函数 ============ + +// NewSlogLogger creates a new slog logger with config | 使用配置创建标准日志器 +func NewSlogLogger(cfg *slog.LoggerConfig) (*slog.Logger, error) { + return slog.NewLoggerWithConfig(cfg) } -// GetManager gets the global Manager | 获取全局Manager -func GetManager() *Manager { - return stputil.GetManager() +// NewNopLogger creates a new no-op logger | 创建空日志器 +func NewNopLogger() *nop.NopLogger { + return nop.NewNopLogger() } +// ============ Generator Constructors | 生成器构造函数 ============ + +// NewTokenGenerator creates a new token generator | 创建Token生成器 +func NewTokenGenerator(timeout int64, tokenStyle adapter.TokenStyle, jwtSecretKey string) *sgenerator.Generator { + return sgenerator.NewGenerator(timeout, tokenStyle, jwtSecretKey) +} + +// NewDefaultTokenGenerator creates a new default token generator | 创建默认Token生成器 +func NewDefaultTokenGenerator() *sgenerator.Generator { + return sgenerator.NewDefaultGenerator() +} + +// ============ Pool Constructors | 协程池构造函数 ============ + +// NewRenewPoolManager creates a new renew pool manager-example with default config | 使用默认配置创建续期池管理器 +func NewRenewPoolManager() *ants.RenewPoolManager { + return ants.NewRenewPoolManagerWithDefaultConfig() +} + +// NewRenewPoolManagerWithConfig creates a new renew pool manager-example with config | 使用配置创建续期池管理器 +func NewRenewPoolManagerWithConfig(cfg *ants.RenewPoolConfig) (*ants.RenewPoolManager, error) { + return ants.NewRenewPoolManagerWithConfig(cfg) +} + +// ============ Token Style Constants | Token风格常量 ============ + +const ( + // TokenStyleUUID UUID style | UUID风格 + TokenStyleUUID = adapter.TokenStyleUUID + // TokenStyleSimple Simple random string | 简单随机字符串 + TokenStyleSimple = adapter.TokenStyleSimple + // TokenStyleRandom32 32-bit random string | 32位随机字符串 + TokenStyleRandom32 = adapter.TokenStyleRandom32 + // TokenStyleRandom64 64-bit random string | 64位随机字符串 + TokenStyleRandom64 = adapter.TokenStyleRandom64 + // TokenStyleRandom128 128-bit random string | 128位随机字符串 + TokenStyleRandom128 = adapter.TokenStyleRandom128 + // TokenStyleJWT JWT style | JWT风格 + TokenStyleJWT = adapter.TokenStyleJWT + // TokenStyleHash SHA256 hash-based style | SHA256哈希风格 + TokenStyleHash = adapter.TokenStyleHash + // TokenStyleTimestamp Timestamp-based style | 时间戳风格 + TokenStyleTimestamp = adapter.TokenStyleTimestamp + // TokenStyleTik Short ID style (like TikTok) | Tik风格短ID + TokenStyleTik = adapter.TokenStyleTik +) + +// ============ Log Level Constants | 日志级别常量 ============ + +const ( + // LogLevelDebug Debug level | 调试级别 + LogLevelDebug = adapter.LogLevelDebug + // LogLevelInfo Info level | 信息级别 + LogLevelInfo = adapter.LogLevelInfo + // LogLevelWarn Warn level | 警告级别 + LogLevelWarn = adapter.LogLevelWarn + // LogLevelError Error level | 错误级别 + LogLevelError = adapter.LogLevelError +) + // ============ Authentication | 登录认证 ============ // Login performs user login | 用户登录 -func Login(loginID interface{}, device ...string) (string, error) { - return stputil.Login(loginID, device...) +func Login(ctx context.Context, loginID interface{}, deviceOrAutoType ...string) (string, error) { + return stputil.Login(ctx, loginID, deviceOrAutoType...) } // LoginByToken performs login with specified token | 使用指定Token登录 -func LoginByToken(loginID interface{}, tokenValue string, device ...string) error { - return stputil.LoginByToken(loginID, tokenValue, device...) +func LoginByToken(ctx context.Context, tokenValue string, authType ...string) error { + return stputil.LoginByToken(ctx, tokenValue, authType...) } // Logout performs user logout | 用户登出 -func Logout(loginID interface{}, device ...string) error { - return stputil.Logout(loginID, device...) +func Logout(ctx context.Context, loginID interface{}, deviceOrAutoType ...string) error { + return stputil.Logout(ctx, loginID, deviceOrAutoType...) } // LogoutByToken performs logout by token | 根据Token登出 -func LogoutByToken(tokenValue string) error { - return stputil.LogoutByToken(tokenValue) +func LogoutByToken(ctx context.Context, tokenValue string, authType ...string) error { + return stputil.LogoutByToken(ctx, tokenValue, authType...) +} + +// Kickout kicks out a user session | 踢人下线 +func Kickout(ctx context.Context, loginID interface{}, deviceOrAutoType ...string) error { + return stputil.Kickout(ctx, loginID, deviceOrAutoType...) +} + +// KickoutByToken Kick user offline | 根据Token踢人下线 +func KickoutByToken(ctx context.Context, tokenValue string, authType ...string) error { + return stputil.KickoutByToken(ctx, tokenValue, authType...) +} + +// Replace user offline by login ID and device | 根据账号和设备顶人下线 +func Replace(ctx context.Context, loginID interface{}, deviceOrAutoType ...string) error { + return stputil.Replace(ctx, loginID, deviceOrAutoType...) +} + +// ReplaceByToken Replace user offline by token | 根据Token顶人下线 +func ReplaceByToken(ctx context.Context, tokenValue string, authType ...string) error { + return stputil.ReplaceByToken(ctx, tokenValue, authType...) } +// ============ Token Validation | Token验证 ============ + // IsLogin checks if the user is logged in | 检查用户是否已登录 -func IsLogin(tokenValue string) bool { - return stputil.IsLogin(tokenValue) +func IsLogin(ctx context.Context, tokenValue string, authType ...string) (bool, error) { + return stputil.IsLogin(ctx, tokenValue, authType...) } -// CheckLoginByToken checks login status (throws error if not logged in) | 检查登录状态(未登录抛出错误) -func CheckLoginByToken(tokenValue string) error { - return stputil.CheckLogin(tokenValue) +// CheckLogin checks login status (throws error if not logged in) | 检查登录状态(未登录抛出错误) +func CheckLogin(ctx context.Context, tokenValue string, authType ...string) error { + return stputil.CheckLogin(ctx, tokenValue, authType...) } // GetLoginID gets the login ID from token | 从Token获取登录ID -func GetLoginID(tokenValue string) (string, error) { - return stputil.GetLoginID(tokenValue) +func GetLoginID(ctx context.Context, tokenValue string, authType ...string) (string, error) { + return stputil.GetLoginID(ctx, tokenValue, authType...) } -// GetLoginIDNotCheck gets login ID without checking | 获取登录ID(不检查) -func GetLoginIDNotCheck(tokenValue string) (string, error) { - return stputil.GetLoginIDNotCheck(tokenValue) +// GetLoginIDNotCheck gets login ID without checking | 获取登录ID(不检查登录状态) +func GetLoginIDNotCheck(ctx context.Context, tokenValue string, authType ...string) (string, error) { + return stputil.GetLoginIDNotCheck(ctx, tokenValue, authType...) } // GetTokenValue gets the token value for a login ID | 获取登录ID对应的Token值 -func GetTokenValue(loginID interface{}, device ...string) (string, error) { - return stputil.GetTokenValue(loginID, device...) +func GetTokenValue(ctx context.Context, loginID interface{}, deviceOrAutoType ...string) (string, error) { + return stputil.GetTokenValue(ctx, loginID, deviceOrAutoType...) } // GetTokenInfo gets token information | 获取Token信息 -func GetTokenInfo(tokenValue string) (*TokenInfo, error) { - return stputil.GetTokenInfo(tokenValue) +func GetTokenInfo(ctx context.Context, tokenValue string, authType ...string) (*manager.TokenInfo, error) { + return stputil.GetTokenInfo(ctx, tokenValue, authType...) } -// ============ Kickout | 踢人下线 ============ +// ============ Account Disable | 账号封禁 ============ -// Kickout kicks out a user session | 踢人下线 -func Kickout(loginID interface{}, device ...string) error { - return stputil.Kickout(loginID, device...) +// Disable disables an account for specified duration | 封禁账号(指定时长) +func Disable(ctx context.Context, loginID interface{}, duration time.Duration, authType ...string) error { + return stputil.Disable(ctx, loginID, duration, authType...) } -// ============ Account Disable | 账号封禁 ============ +// DisableByToken disables the account associated with the given token for a duration | 根据指定 Token 封禁其对应的账号 +func DisableByToken(ctx context.Context, tokenValue string, duration time.Duration, authType ...string) error { + return stputil.DisableByToken(ctx, tokenValue, duration, authType...) +} -// Disable disables an account for specified duration | 封禁账号(指定时长) -func Disable(loginID interface{}, duration time.Duration) error { - return stputil.Disable(loginID, duration) +// Untie re-enables a disabled account | 解封账号 +func Untie(ctx context.Context, loginID interface{}, authType ...string) error { + return stputil.Untie(ctx, loginID, authType...) +} + +// UntieByToken re-enables a disabled account by token | 根据Token解封账号 +func UntieByToken(ctx context.Context, tokenValue string, authType ...string) error { + return stputil.UntieByToken(ctx, tokenValue, authType...) } // IsDisable checks if an account is disabled | 检查账号是否被封禁 -func IsDisable(loginID interface{}) bool { - return stputil.IsDisable(loginID) +func IsDisable(ctx context.Context, loginID interface{}, authType ...string) bool { + return stputil.IsDisable(ctx, loginID, authType...) } -// CheckDisableByToken checks if account is disabled (throws error if disabled) | 检查Token对应账号是否被封禁(被封禁则抛出错误) -func CheckDisableByToken(tokenValue string) error { - return stputil.CheckDisable(tokenValue) +// IsDisableByToken checks if an account is disabled by token | 根据Token检查账号是否被封禁 +func IsDisableByToken(ctx context.Context, tokenValue string, authType ...string) bool { + return stputil.IsDisableByToken(ctx, tokenValue, authType...) } -// GetDisableTime gets remaining disabled time | 获取账号剩余封禁时间 -func GetDisableTime(loginID interface{}) (int64, error) { - return stputil.GetDisableTime(loginID) +// GetDisableTime gets remaining disable time in seconds | 获取剩余封禁时间(秒) +func GetDisableTime(ctx context.Context, loginID interface{}, authType ...string) (int64, error) { + return stputil.GetDisableTime(ctx, loginID, authType...) } -// Untie unties/unlocks an account | 解除账号封禁 -func Untie(loginID interface{}) error { - return stputil.Untie(loginID) +// GetDisableTimeByToken gets remaining disable time by token | 根据Token获取剩余封禁时间(秒) +func GetDisableTimeByToken(ctx context.Context, tokenValue string, authType ...string) (int64, error) { + return stputil.GetDisableTimeByToken(ctx, tokenValue, authType...) } -// ============ Permission Check | 权限验证 ============ +// GetDisableInfo gets disable info | 获取封禁信息 +func GetDisableInfo(ctx context.Context, loginID interface{}, authType ...string) (*manager.DisableInfo, error) { + return stputil.GetDisableInfo(ctx, loginID, authType...) +} -// CheckPermissionByToken checks if the token has specified permission | 检查Token是否拥有指定权限 -func CheckPermissionByToken(tokenValue string, permission string) error { - return stputil.CheckPermission(tokenValue, permission) +// GetDisableInfoByToken gets disable info by token | 根据Token获取封禁信息 +func GetDisableInfoByToken(ctx context.Context, tokenValue string, authType ...string) (*manager.DisableInfo, error) { + return stputil.GetDisableInfoByToken(ctx, tokenValue, authType...) } -// HasPermission checks if the account has specified permission (returns bool) | 检查账号是否拥有指定权限(返回布尔值) -func HasPermission(loginID interface{}, permission string) bool { - return stputil.HasPermission(loginID, permission) +// ============ Session Management | Session管理 ============ + +// GetSession gets session by login ID | 根据登录ID获取Session +func GetSession(ctx context.Context, loginID interface{}, authType ...string) (*session.Session, error) { + return stputil.GetSession(ctx, loginID, authType...) } -// CheckPermissionAndByToken checks if the token has all specified permissions (AND logic) | 检查Token是否拥有所有指定权限(AND逻辑) -func CheckPermissionAndByToken(tokenValue string, permissions []string) error { - return stputil.CheckPermissionAnd(tokenValue, permissions) +// GetSessionByToken gets session by token | 根据Token获取Session +func GetSessionByToken(ctx context.Context, tokenValue string, authType ...string) (*session.Session, error) { + return stputil.GetSessionByToken(ctx, tokenValue, authType...) } -// CheckPermissionOrByToken checks if the token has any of the specified permissions (OR logic) | 检查Token是否拥有指定权限中的任意一个(OR逻辑) -func CheckPermissionOrByToken(tokenValue string, permissions []string) error { - return stputil.CheckPermissionOr(tokenValue, permissions) +// DeleteSession deletes a session | 删除Session +func DeleteSession(ctx context.Context, loginID interface{}, authType ...string) error { + return stputil.DeleteSession(ctx, loginID, authType...) } -// GetPermissionListByToken gets the permission list for a token | 获取Token的权限列表 -func GetPermissionListByToken(tokenValue string) ([]string, error) { - return stputil.GetPermissionList(tokenValue) +// DeleteSessionByToken Deletes session by token | 根据Token删除Session +func DeleteSessionByToken(ctx context.Context, tokenValue string, authType ...string) error { + return stputil.DeleteSessionByToken(ctx, tokenValue, authType...) } -// ============ Role Check | 角色验证 ============ +// HasSession checks if session exists | 检查Session是否存在 +func HasSession(ctx context.Context, loginID interface{}, authType ...string) bool { + return stputil.HasSession(ctx, loginID, authType...) +} -// CheckRoleByToken checks if the token has specified role | 检查Token是否拥有指定角色 -func CheckRoleByToken(tokenValue string, role string) error { - return stputil.CheckRole(tokenValue, role) +// RenewSession renews session TTL | 续期Session +func RenewSession(ctx context.Context, loginID interface{}, ttl time.Duration, authType ...string) error { + return stputil.RenewSession(ctx, loginID, ttl, authType...) } -// HasRole checks if the account has specified role (returns bool) | 检查账号是否拥有指定角色(返回布尔值) -func HasRole(loginID interface{}, role string) bool { - return stputil.HasRole(loginID, role) +// RenewSessionByToken renews session TTL by token | 根据Token续期Session +func RenewSessionByToken(ctx context.Context, tokenValue string, ttl time.Duration, authType ...string) error { + return stputil.RenewSessionByToken(ctx, tokenValue, ttl, authType...) } -// CheckRoleAndByToken checks if the token has all specified roles (AND logic) | 检查Token是否拥有所有指定角色(AND逻辑) -func CheckRoleAndByToken(tokenValue string, roles []string) error { - return stputil.CheckRoleAnd(tokenValue, roles) +// ============ Permission Verification | 权限验证 ============ + +// SetPermissions sets permissions for a login ID | 设置用户权限 +func SetPermissions(ctx context.Context, loginID interface{}, permissions []string, authType ...string) error { + return stputil.SetPermissions(ctx, loginID, permissions, authType...) } -// CheckRoleOrByToken checks if the token has any of the specified roles (OR logic) | 检查Token是否拥有指定角色中的任意一个(OR逻辑) -func CheckRoleOrByToken(tokenValue string, roles []string) error { - return stputil.CheckRoleOr(tokenValue, roles) +// SetPermissionsByToken sets permissions by token | 根据 Token 设置对应账号的权限 +func SetPermissionsByToken(ctx context.Context, tokenValue string, permissions []string, authType ...string) error { + return stputil.SetPermissionsByToken(ctx, tokenValue, permissions, authType...) } -// GetRoleListByToken gets the role list for a token | 获取Token的角色列表 -func GetRoleListByToken(tokenValue string) ([]string, error) { - return stputil.GetRoleList(tokenValue) +// RemovePermissions removes specified permissions for a login ID | 删除用户指定权限 +func RemovePermissions(ctx context.Context, loginID interface{}, permissions []string, authType ...string) error { + return stputil.RemovePermissions(ctx, loginID, permissions, authType...) } -// ============ Session Management | Session管理 ============ +// RemovePermissionsByToken removes specified permissions by token | 根据 Token 删除对应账号的指定权限 +func RemovePermissionsByToken(ctx context.Context, tokenValue string, permissions []string, authType ...string) error { + return stputil.RemovePermissionsByToken(ctx, tokenValue, permissions, authType...) +} + +// GetPermissions gets permission list | 获取权限列表 +func GetPermissions(ctx context.Context, loginID interface{}, authType ...string) ([]string, error) { + return stputil.GetPermissions(ctx, loginID, authType...) +} + +// GetPermissionsByToken gets permission list by token | 根据 Token 获取对应账号的权限列表 +func GetPermissionsByToken(ctx context.Context, tokenValue string, authType ...string) ([]string, error) { + return stputil.GetPermissionsByToken(ctx, tokenValue, authType...) +} + +// HasPermission checks if has specified permission | 检查是否拥有指定权限 +func HasPermission(ctx context.Context, loginID interface{}, permission string, authType ...string) bool { + return stputil.HasPermission(ctx, loginID, permission, authType...) +} + +// HasPermissionByToken checks if the token has the specified permission | 检查Token是否拥有指定权限 +func HasPermissionByToken(ctx context.Context, tokenValue string, permission string, authType ...string) bool { + return stputil.HasPermissionByToken(ctx, tokenValue, permission, authType...) +} + +// HasPermissionsAnd checks if has all permissions (AND logic) | 检查是否拥有所有权限(AND逻辑) +func HasPermissionsAnd(ctx context.Context, loginID interface{}, permissions []string, authType ...string) bool { + return stputil.HasPermissionsAnd(ctx, loginID, permissions, authType...) +} + +// HasPermissionsAndByToken checks if the token has all specified permissions | 检查Token是否拥有所有指定权限 +func HasPermissionsAndByToken(ctx context.Context, tokenValue string, permissions []string, authType ...string) bool { + return stputil.HasPermissionsAndByToken(ctx, tokenValue, permissions, authType...) +} -// GetSession gets the session for a login ID | 获取登录ID的Session -func GetSession(loginID interface{}) (*Session, error) { - return stputil.GetSession(loginID) +// HasPermissionsOr checks if has any permission (OR logic) | 检查是否拥有任一权限(OR逻辑) +func HasPermissionsOr(ctx context.Context, loginID interface{}, permissions []string, authType ...string) bool { + return stputil.HasPermissionsOr(ctx, loginID, permissions, authType...) } -// GetSessionByToken gets the session by token | 根据Token获取Session -func GetSessionByToken(tokenValue string) (*Session, error) { - return stputil.GetSessionByToken(tokenValue) +// HasPermissionsOrByToken checks if the token has any of the specified permissions | 检查Token是否拥有任一指定权限 +func HasPermissionsOrByToken(ctx context.Context, tokenValue string, permissions []string, authType ...string) bool { + return stputil.HasPermissionsOrByToken(ctx, tokenValue, permissions, authType...) } -// GetTokenSession gets the token session | 获取Token的Session -func GetTokenSession(tokenValue string) (*Session, error) { - return stputil.GetTokenSession(tokenValue) +// ============ Role Management | 角色管理 ============ + +// SetRoles sets roles for a login ID | 设置用户角色 +func SetRoles(ctx context.Context, loginID interface{}, roles []string, authType ...string) error { + return stputil.SetRoles(ctx, loginID, roles, authType...) } -// ============ Token Renewal | Token续期 ============ -// Note: Token auto-renewal is handled automatically by the manager -// 注意:Token自动续期由管理器自动处理 +// SetRolesByToken sets roles by token | 根据 Token 设置对应账号的角色 +func SetRolesByToken(ctx context.Context, tokenValue string, roles []string, authType ...string) error { + return stputil.SetRolesByToken(ctx, tokenValue, roles, authType...) +} + +// RemoveRoles removes specified roles for a login ID | 删除用户指定角色 +func RemoveRoles(ctx context.Context, loginID interface{}, roles []string, authType ...string) error { + return stputil.RemoveRoles(ctx, loginID, roles, authType...) +} + +// RemoveRolesByToken removes specified roles by token | 根据 Token 删除对应账号的指定角色 +func RemoveRolesByToken(ctx context.Context, tokenValue string, roles []string, authType ...string) error { + return stputil.RemoveRolesByToken(ctx, tokenValue, roles, authType...) +} + +// GetRoles gets role list | 获取角色列表 +func GetRoles(ctx context.Context, loginID interface{}, authType ...string) ([]string, error) { + return stputil.GetRoles(ctx, loginID, authType...) +} + +// GetRolesByToken gets role list by token | 根据 Token 获取对应账号的角色列表 +func GetRolesByToken(ctx context.Context, tokenValue string, authType ...string) ([]string, error) { + return stputil.GetRolesByToken(ctx, tokenValue, authType...) +} + +// HasRole checks if has specified role | 检查是否拥有指定角色 +func HasRole(ctx context.Context, loginID interface{}, role string, authType ...string) bool { + return stputil.HasRole(ctx, loginID, role, authType...) +} + +// HasRoleByToken checks if the token has the specified role | 检查 Token 是否拥有指定角色 +func HasRoleByToken(ctx context.Context, tokenValue string, role string, authType ...string) bool { + return stputil.HasRoleByToken(ctx, tokenValue, role, authType...) +} + +// HasRolesAnd checks if has all roles (AND logic) | 检查是否拥有所有角色(AND逻辑) +func HasRolesAnd(ctx context.Context, loginID interface{}, roles []string, authType ...string) bool { + return stputil.HasRolesAnd(ctx, loginID, roles, authType...) +} + +// HasRolesAndByToken checks if the token has all specified roles | 检查 Token 是否拥有所有指定角色 +func HasRolesAndByToken(ctx context.Context, tokenValue string, roles []string, authType ...string) bool { + return stputil.HasRolesAndByToken(ctx, tokenValue, roles, authType...) +} + +// HasRolesOr checks if has any role (OR logic) | 检查是否拥有任一角色(OR逻辑) +func HasRolesOr(ctx context.Context, loginID interface{}, roles []string, authType ...string) bool { + return stputil.HasRolesOr(ctx, loginID, roles, authType...) +} + +// HasRolesOrByToken checks if the token has any of the specified roles | 检查 Token 是否拥有任一指定角色 +func HasRolesOrByToken(ctx context.Context, tokenValue string, roles []string, authType ...string) bool { + return stputil.HasRolesOrByToken(ctx, tokenValue, roles, authType...) +} + +// ============ Token Tag | Token标签 ============ + +// SetTokenTag sets token tag | 设置Token标签 +func SetTokenTag(ctx context.Context, tokenValue, tag string, authType ...string) error { + return stputil.SetTokenTag(ctx, tokenValue, tag, authType...) +} + +// GetTokenTag gets token tag | 获取Token标签 +func GetTokenTag(ctx context.Context, tokenValue string, authType ...string) (string, error) { + return stputil.GetTokenTag(ctx, tokenValue, authType...) +} + +// ============ Session Query | 会话查询 ============ + +// GetTokenValueListByLoginID gets all tokens for a login ID | 获取指定账号的所有Token +func GetTokenValueListByLoginID(ctx context.Context, loginID interface{}, authType ...string) ([]string, error) { + return stputil.GetTokenValueListByLoginID(ctx, loginID, authType...) +} + +// GetSessionCountByLoginID gets session count for a login ID | 获取指定账号的Session数量 +func GetSessionCountByLoginID(ctx context.Context, loginID interface{}, authType ...string) (int, error) { + return stputil.GetSessionCountByLoginID(ctx, loginID, authType...) +} // ============ Security Features | 安全特性 ============ -// GenerateNonce generates a new nonce token | 生成新的Nonce令牌 -func GenerateNonce() (string, error) { - return stputil.GenerateNonce() +// Generate Generates a one-time nonce | 生成一次性随机数 +func Generate(ctx context.Context, authType ...string) (string, error) { + return stputil.Generate(ctx, authType...) +} + +// Verify Verifies a nonce | 验证随机数 +func Verify(ctx context.Context, nonce string, authType ...string) bool { + return stputil.Verify(ctx, nonce, authType...) +} + +// VerifyAndConsume Verifies and consumes nonce, returns error if invalid | 验证并消费nonce,无效时返回错误 +func VerifyAndConsume(ctx context.Context, nonce string, authType ...string) error { + return stputil.VerifyAndConsume(ctx, nonce, authType...) +} + +// IsValidNonce Checks if nonce is valid without consuming it | 检查nonce是否有效(不消费) +func IsValidNonce(ctx context.Context, nonce string, authType ...string) bool { + return stputil.IsValidNonce(ctx, nonce, authType...) +} + +// GenerateTokenPair Create access + refresh token | 生成访问令牌和刷新令牌 +func GenerateTokenPair(ctx context.Context, loginID interface{}, deviceOrAutoType ...string) (*security.RefreshTokenInfo, error) { + return stputil.GenerateTokenPair(ctx, loginID, deviceOrAutoType...) +} + +// VerifyAccessToken verifies access token validity | 验证访问令牌是否有效 +func VerifyAccessToken(ctx context.Context, accessToken string, authType ...string) bool { + return stputil.VerifyAccessToken(ctx, accessToken, authType...) +} + +// VerifyAccessTokenAndGetInfo verifies access token and returns token info | 验证访问令牌并返回Token信息 +func VerifyAccessTokenAndGetInfo(ctx context.Context, accessToken string, authType ...string) (*security.AccessTokenInfo, bool) { + return stputil.VerifyAccessTokenAndGetInfo(ctx, accessToken, authType...) +} + +// GetRefreshTokenInfo gets refresh token information | 获取刷新令牌信息 +func GetRefreshTokenInfo(ctx context.Context, refreshToken string, authType ...string) (*security.RefreshTokenInfo, error) { + return stputil.GetRefreshTokenInfo(ctx, refreshToken, authType...) +} + +// RefreshAccessToken refreshes access token using refresh token | 使用刷新令牌刷新访问令牌 +func RefreshAccessToken(ctx context.Context, refreshToken string, authType ...string) (*security.RefreshTokenInfo, error) { + return stputil.RefreshAccessToken(ctx, refreshToken, authType...) +} + +// RevokeRefreshToken Revokes refresh token | 撤销刷新令牌 +func RevokeRefreshToken(ctx context.Context, refreshToken string, authType ...string) error { + return stputil.RevokeRefreshToken(ctx, refreshToken, authType...) +} + +// IsValid checks whether token is valid | 检查Token是否有效 +func IsValid(ctx context.Context, refreshToken string, authType ...string) bool { + return stputil.IsValid(ctx, refreshToken, authType...) +} + +// ============ OAuth2 Features | OAuth2 功能 ============ + +// RegisterClient Registers an OAuth2 client | 注册OAuth2客户端 +func RegisterClient(ctx context.Context, client *oauth2.Client, authType ...string) error { + return stputil.RegisterClient(ctx, client, authType...) +} + +// UnregisterClient unregisters an OAuth2 client | 注销OAuth2客户端 +func UnregisterClient(ctx context.Context, clientID string, authType ...string) error { + return stputil.UnregisterClient(ctx, clientID, authType...) +} + +// GetClient gets OAuth2 client information | 获取OAuth2客户端信息 +func GetClient(ctx context.Context, clientID string, authType ...string) (*oauth2.Client, error) { + return stputil.GetClient(ctx, clientID, authType...) +} + +// GenerateAuthorizationCode creates an authorization code | 创建授权码 +func GenerateAuthorizationCode(ctx context.Context, clientID, loginID, redirectURI string, scope []string, authType ...string) (*oauth2.AuthorizationCode, error) { + return stputil.GenerateAuthorizationCode(ctx, clientID, loginID, redirectURI, scope, authType...) +} + +// ExchangeCodeForToken exchanges authorization code for token | 使用授权码换取令牌 +func ExchangeCodeForToken(ctx context.Context, code, clientID, clientSecret, redirectURI string, authType ...string) (*oauth2.AccessToken, error) { + return stputil.ExchangeCodeForToken(ctx, code, clientID, clientSecret, redirectURI, authType...) +} + +// ValidateAccessToken verifies OAuth2 access token | 验证OAuth2访问令牌 +func ValidateAccessToken(ctx context.Context, accessToken string, authType ...string) bool { + return stputil.ValidateAccessToken(ctx, accessToken, authType...) +} + +// ValidateAccessTokenAndGetInfo verifies OAuth2 access token and get info | 验证OAuth2访问令牌并获取信息 +func ValidateAccessTokenAndGetInfo(ctx context.Context, accessToken string, authType ...string) (*oauth2.AccessToken, error) { + return stputil.ValidateAccessTokenAndGetInfo(ctx, accessToken, authType...) +} + +// OAuth2RefreshAccessToken Refreshes access token using refresh token | 使用刷新令牌刷新访问令牌(OAuth2) +func OAuth2RefreshAccessToken(ctx context.Context, clientID, refreshToken, clientSecret string, authType ...string) (*oauth2.AccessToken, error) { + return stputil.OAuth2RefreshAccessToken(ctx, clientID, refreshToken, clientSecret, authType...) +} + +// RevokeToken Revokes access token and its refresh token | 撤销访问令牌及其刷新令牌 +func RevokeToken(ctx context.Context, accessToken string, authType ...string) error { + return stputil.RevokeToken(ctx, accessToken, authType...) +} + +// OAuth2Token Unified token endpoint that dispatches to appropriate handler based on grant type | 统一的令牌端点 +func OAuth2Token(ctx context.Context, req *oauth2.TokenRequest, validateUser oauth2.UserValidator, authType ...string) (*oauth2.AccessToken, error) { + return stputil.OAuth2Token(ctx, req, validateUser, authType...) +} + +// OAuth2ClientCredentialsToken Gets access token using client credentials grant | 使用客户端凭证模式获取访问令牌 +func OAuth2ClientCredentialsToken(ctx context.Context, clientID, clientSecret string, scopes []string, authType ...string) (*oauth2.AccessToken, error) { + return stputil.OAuth2ClientCredentialsToken(ctx, clientID, clientSecret, scopes, authType...) +} + +// OAuth2PasswordGrantToken Gets access token using resource owner password credentials grant | 使用密码模式获取访问令牌 +func OAuth2PasswordGrantToken(ctx context.Context, clientID, clientSecret, username, password string, scopes []string, validateUser oauth2.UserValidator, authType ...string) (*oauth2.AccessToken, error) { + return stputil.OAuth2PasswordGrantToken(ctx, clientID, clientSecret, username, password, scopes, validateUser, authType...) +} + +// ============ OAuth2 Grant Type Constants | OAuth2授权类型常量 ============ + +const ( + // GrantTypeAuthorizationCode Authorization code grant type | 授权码模式 + GrantTypeAuthorizationCode = oauth2.GrantTypeAuthorizationCode + // GrantTypeClientCredentials Client credentials grant type | 客户端凭证模式 + GrantTypeClientCredentials = oauth2.GrantTypeClientCredentials + // GrantTypePassword Password grant type | 密码模式 + GrantTypePassword = oauth2.GrantTypePassword + // GrantTypeRefreshToken Refresh token grant type | 刷新令牌模式 + GrantTypeRefreshToken = oauth2.GrantTypeRefreshToken +) + +// ============ Public Getters | 公共获取器 ============ + +// GetConfig returns the manager-example configuration | 获取 Manager 当前使用的配置 +func GetConfig(ctx context.Context, authType ...string) *config.Config { + return stputil.GetConfig(ctx, authType...) +} + +// GetStorage returns the storage adapter | 获取 Manager 使用的存储适配器 +func GetStorage(ctx context.Context, authType ...string) adapter.Storage { + return stputil.GetStorage(ctx, authType...) +} + +// GetCodec returns the codec (serializer) | 获取 Manager 使用的编解码器 +func GetCodec(ctx context.Context, authType ...string) adapter.Codec { + return stputil.GetCodec(ctx, authType...) +} + +// GetLog returns the logger adapter | 获取 Manager 使用的日志适配器 +func GetLog(ctx context.Context, authType ...string) adapter.Log { + return stputil.GetLog(ctx, authType...) +} + +// GetPool returns the goroutine pool | 获取 Manager 使用的协程池 +func GetPool(ctx context.Context, authType ...string) adapter.Pool { + return stputil.GetPool(ctx, authType...) +} + +// GetGenerator returns the token generator | 获取 Token 生成器 +func GetGenerator(ctx context.Context, authType ...string) adapter.Generator { + return stputil.GetGenerator(ctx, authType...) +} + +// GetNonceManager returns the nonce manager-example | 获取随机串管理器 +func GetNonceManager(ctx context.Context, authType ...string) *security.NonceManager { + return stputil.GetNonceManager(ctx, authType...) +} + +// GetRefreshManager returns the refresh token manager-example | 获取刷新令牌管理器 +func GetRefreshManager(ctx context.Context, authType ...string) *security.RefreshTokenManager { + return stputil.GetRefreshManager(ctx, authType...) +} + +// GetEventManager returns the event manager-example | 获取事件管理器 +func GetEventManager(ctx context.Context, authType ...string) *listener.Manager { + return stputil.GetEventManager(ctx, authType...) } -// VerifyNonce verifies a nonce token | 验证Nonce令牌 -func VerifyNonce(nonce string) bool { - return stputil.VerifyNonce(nonce) +// GetOAuth2Server Gets OAuth2 server instance | 获取OAuth2服务器实例 +func GetOAuth2Server(ctx context.Context, authType ...string) *oauth2.OAuth2Server { + return stputil.GetOAuth2Server(ctx, authType...) } -// LoginWithRefreshToken performs login and returns refresh token info | 登录并返回刷新令牌信息 -func LoginWithRefreshToken(loginID interface{}, device ...string) (*RefreshTokenInfo, error) { - return stputil.LoginWithRefreshToken(loginID, device...) +// ============ Event Management | 事件管理 ============ + +// RegisterFunc registers a function as an event listener | 注册函数作为事件监听器 +func RegisterFunc(event listener.Event, fn func(*listener.EventData), authType ...string) { + stputil.RegisterFunc(event, fn, authType...) } -// RefreshAccessToken refreshes the access token using a refresh token | 使用刷新令牌刷新访问令牌 -func RefreshAccessToken(refreshToken string) (*RefreshTokenInfo, error) { - return stputil.RefreshAccessToken(refreshToken) +// Register registers an event listener | 注册事件监听器 +func Register(event listener.Event, l listener.Listener, authType ...string) string { + return stputil.Register(event, l, authType...) } -// RevokeRefreshToken revokes a refresh token | 撤销刷新令牌 -func RevokeRefreshToken(refreshToken string) error { - return stputil.RevokeRefreshToken(refreshToken) +// RegisterWithConfig registers an event listener with config | 注册带配置的事件监听器 +func RegisterWithConfig(event listener.Event, l listener.Listener, cfg listener.ListenerConfig, authType ...string) string { + return stputil.RegisterWithConfig(event, l, cfg, authType...) } -// GetOAuth2Server gets the OAuth2 server instance | 获取OAuth2服务器实例 -func GetOAuth2Server() *OAuth2Server { - return stputil.GetOAuth2Server() +// Unregister removes an event listener by ID | 根据ID移除事件监听器 +func Unregister(id string, authType ...string) bool { + return stputil.Unregister(id, authType...) } -// Version Sa-Token-Go version | Sa-Token-Go版本 -const Version = core.Version +// TriggerEvent manually triggers an event | 手动触发事件 +func TriggerEvent(data *listener.EventData, authType ...string) { + stputil.TriggerEvent(data, authType...) +} + +// TriggerEventAsync triggers an event asynchronously and returns immediately | 异步触发事件并立即返回 +func TriggerEventAsync(data *listener.EventData, authType ...string) { + stputil.TriggerEventAsync(data, authType...) +} + +// TriggerEventSync triggers an event synchronously and waits for all listeners | 同步触发事件并等待所有监听器完成 +func TriggerEventSync(data *listener.EventData, authType ...string) { + stputil.TriggerEventSync(data, authType...) +} + +// WaitEvents waits for all async event listeners to complete | 等待所有异步事件监听器完成 +func WaitEvents(authType ...string) { + stputil.WaitEvents(authType...) +} + +// ClearEventListeners removes all listeners for a specific event | 清除指定事件的所有监听器 +func ClearEventListeners(event listener.Event, authType ...string) { + stputil.ClearEventListeners(event, authType...) +} + +// ClearAllEventListeners removes all listeners | 清除所有事件监听器 +func ClearAllEventListeners(authType ...string) { + stputil.ClearAllEventListeners(authType...) +} + +// CountEventListeners returns the number of listeners for a specific event | 获取指定事件监听器数量 +func CountEventListeners(event listener.Event, authType ...string) int { + return stputil.CountEventListeners(event, authType...) +} + +// CountAllListeners returns the total number of registered listeners | 获取已注册监听器总数 +func CountAllListeners(authType ...string) int { + return stputil.CountAllListeners(authType...) +} + +// GetEventListenerIDs returns all listener IDs for a specific event | 获取指定事件的所有监听器ID +func GetEventListenerIDs(event listener.Event, authType ...string) []string { + return stputil.GetEventListenerIDs(event, authType...) +} + +// GetAllRegisteredEvents returns all events that have registered listeners | 获取所有已注册事件 +func GetAllRegisteredEvents(authType ...string) []listener.Event { + return stputil.GetAllRegisteredEvents(authType...) +} + +// HasEventListeners checks if there are any listeners for a specific event | 检查指定事件是否有监听器 +func HasEventListeners(event listener.Event, authType ...string) bool { + return stputil.HasEventListeners(event, authType...) +} diff --git a/integrations/gin/go.mod b/integrations/gin/go.mod index 3b8d3d3..9350f94 100644 --- a/integrations/gin/go.mod +++ b/integrations/gin/go.mod @@ -1,14 +1,11 @@ module github.com/click33/sa-token-go/integrations/gin -go 1.23.0 - -toolchain go1.24.1 +go 1.25.0 require ( - github.com/click33/sa-token-go/core v0.1.5 - github.com/click33/sa-token-go/stputil v0.1.5 + github.com/click33/sa-token-go/core v0.1.7 + github.com/click33/sa-token-go/stputil v0.1.7 github.com/gin-gonic/gin v1.10.0 - github.com/stretchr/testify v1.11.1 ) require ( @@ -16,14 +13,13 @@ require ( github.com/bytedance/sonic/loader v0.1.1 // indirect github.com/cloudwego/base64x v0.1.4 // indirect github.com/cloudwego/iasm v0.2.0 // indirect - github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/gabriel-vasile/mimetype v1.4.3 // indirect github.com/gin-contrib/sse v0.1.0 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.20.0 // indirect github.com/goccy/go-json v0.10.2 // indirect - github.com/golang-jwt/jwt/v5 v5.2.2 // indirect + github.com/golang-jwt/jwt/v5 v5.3.0 // indirect github.com/google/uuid v1.6.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/cpuid/v2 v2.2.7 // indirect @@ -34,22 +30,16 @@ require ( github.com/modern-go/reflect2 v1.0.2 // indirect github.com/panjf2000/ants/v2 v2.11.3 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect - github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/rogpeppe/go-internal v1.13.1 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/ugorji/go/codec v1.2.12 // indirect golang.org/x/arch v0.8.0 // indirect golang.org/x/crypto v0.41.0 // indirect golang.org/x/net v0.43.0 // indirect - golang.org/x/sync v0.16.0 // indirect + golang.org/x/sync v0.19.0 // indirect golang.org/x/sys v0.35.0 // indirect golang.org/x/text v0.28.0 // indirect google.golang.org/protobuf v1.36.10 // indirect gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) - -replace ( - github.com/click33/sa-token-go/core => ../../core - github.com/click33/sa-token-go/stputil => ../../stputil -) diff --git a/integrations/gin/go.sum b/integrations/gin/go.sum index 7d6271a..3343982 100644 --- a/integrations/gin/go.sum +++ b/integrations/gin/go.sum @@ -1,5 +1,7 @@ github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0= github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM= +github.com/click33/sa-token-go/core v0.1.6 h1:ELOe0qSH1b3LRsQD3DIBg0e1VgYANKFg5H7z57Lkt/8= +github.com/click33/sa-token-go/stputil v0.1.6 h1:S+V64jQzppE9c1wXcmHppCRlrSsU2iTfvdPGlMbs2WI= github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y= github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= @@ -11,7 +13,7 @@ github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/o github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBExVwjEviJTixqxL8= github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= -github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= +github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= @@ -32,7 +34,7 @@ github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65E golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc= golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4= golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= -golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE= diff --git a/integrations/gin/middleware.go b/integrations/gin/middleware.go new file mode 100644 index 0000000..03525e7 --- /dev/null +++ b/integrations/gin/middleware.go @@ -0,0 +1,321 @@ +package gin + +import ( + "context" + "errors" + "net/http" + + "github.com/click33/sa-token-go/core" + "github.com/click33/sa-token-go/core/manager" + + saContext "github.com/click33/sa-token-go/core/context" + "github.com/click33/sa-token-go/stputil" + "github.com/gin-gonic/gin" +) + +// LogicType permission/role logic type | 权限/角色判断逻辑 +type LogicType string + +const ( + SaTokenCtxKey = "saCtx" + + LogicOr LogicType = "OR" // Logical OR | 任一满足 + LogicAnd LogicType = "AND" // Logical AND | 全部满足 +) + +type AuthOption func(*AuthOptions) + +type AuthOptions struct { + AuthType string + LogicType LogicType + FailFunc func(c *gin.Context, err error) +} + +func defaultAuthOptions() *AuthOptions { + return &AuthOptions{LogicType: LogicAnd} // 默认 AND +} + +// WithAuthType sets auth type | 设置认证类型 +func WithAuthType(authType string) AuthOption { + return func(o *AuthOptions) { + o.AuthType = authType + } +} + +// WithLogicType sets LogicType option | 设置逻辑类型 +func WithLogicType(logicType LogicType) AuthOption { + return func(o *AuthOptions) { + o.LogicType = logicType + } +} + +// WithFailFunc sets auth failure callback | 设置认证失败回调 +func WithFailFunc(fn func(c *gin.Context, err error)) AuthOption { + return func(o *AuthOptions) { + o.FailFunc = fn + } +} + +// ============ Middlewares | 中间件 ============ + +// RegisterSaTokenContextMiddleware initializes Sa-Token context for each request | 初始化每次请求的 Sa-Token 上下文的中间件 +func RegisterSaTokenContextMiddleware(ctx context.Context, opts ...AuthOption) gin.HandlerFunc { + options := defaultAuthOptions() + for _, opt := range opts { + opt(options) + } + + return func(c *gin.Context) { + mgr, err := stputil.GetManager(options.AuthType) + if err != nil { + if options.FailFunc != nil { + options.FailFunc(c, err) + } else { + writeErrorResponse(c, err) + } + return + } + + _ = getSaContext(c, mgr) + } +} + +// AuthMiddleware authentication middleware | 认证中间件 +func AuthMiddleware(ctx context.Context, opts ...AuthOption) gin.HandlerFunc { + options := defaultAuthOptions() + for _, opt := range opts { + opt(options) + } + + return func(c *gin.Context) { + mgr, err := stputil.GetManager(options.AuthType) + if err != nil { + if options.FailFunc != nil { + options.FailFunc(c, err) + } else { + writeErrorResponse(c, err) + } + c.Abort() + return + } + + saCtx := getSaContext(c, mgr) + tokenValue := saCtx.GetTokenValue() + + // 检查登录 | Check login + isLogin, err := mgr.IsLogin(ctx, tokenValue) + if err != nil { + if options.FailFunc != nil { + options.FailFunc(c, err) + } else { + writeErrorResponse(c, err) + } + c.Abort() + return + } + if !isLogin { + if options.FailFunc != nil { + options.FailFunc(c, core.ErrTokenExpired) + } else { + writeErrorResponse(c, core.ErrTokenExpired) + } + c.Abort() + return + } + + c.Next() + } +} + +// PermissionMiddleware permission check middleware | 权限校验中间件 +func PermissionMiddleware( + ctx context.Context, + permissions []string, + opts ...AuthOption, +) gin.HandlerFunc { + + options := defaultAuthOptions() + for _, opt := range opts { + opt(options) + } + + return func(c *gin.Context) { + // No permission required | 无需权限直接放行 + if len(permissions) == 0 { + c.Next() + return + } + + // Get Manager | 获取 Manager + mgr, err := stputil.GetManager(options.AuthType) + if err != nil { + if options.FailFunc != nil { + options.FailFunc(c, err) + } else { + writeErrorResponse(c, err) + } + c.Abort() + return + } + + saCtx := getSaContext(c, mgr) + tokenValue := saCtx.GetTokenValue() + + // Permission check | 权限校验 + var ok bool + if options.LogicType == LogicAnd { + ok = mgr.HasPermissionsAndByToken(ctx, tokenValue, permissions) + } else { + ok = mgr.HasPermissionsOrByToken(ctx, tokenValue, permissions) + } + + if !ok { + if options.FailFunc != nil { + options.FailFunc(c, core.ErrPermissionDenied) + } else { + writeErrorResponse(c, core.ErrPermissionDenied) + } + c.Abort() + return + } + + c.Next() + } +} + +// RoleMiddleware role check middleware | 角色校验中间件 +func RoleMiddleware( + ctx context.Context, + roles []string, + opts ...AuthOption, +) gin.HandlerFunc { + + options := defaultAuthOptions() + for _, opt := range opts { + opt(options) + } + + return func(c *gin.Context) { + // No role required | 无需角色直接放行 + if len(roles) == 0 { + c.Next() + return + } + + // Get Manager | 获取 Manager + mgr, err := stputil.GetManager(options.AuthType) + if err != nil { + if options.FailFunc != nil { + options.FailFunc(c, err) + } else { + writeErrorResponse(c, err) + } + c.Abort() + return + } + + saCtx := getSaContext(c, mgr) + tokenValue := saCtx.GetTokenValue() + + // Role check | 角色校验 + var ok bool + if options.LogicType == LogicAnd { + ok = mgr.HasRolesAndByToken(ctx, tokenValue, roles) + } else { + ok = mgr.HasRolesOrByToken(ctx, tokenValue, roles) + } + + if !ok { + if options.FailFunc != nil { + options.FailFunc(c, core.ErrRoleDenied) + } else { + writeErrorResponse(c, core.ErrRoleDenied) + } + c.Abort() + return + } + + c.Next() + } +} + +// GetSaTokenContext gets Sa-Token context from Gin context | 获取 Sa-Token 上下文 +func GetSaTokenContext(c *gin.Context) (*saContext.SaTokenContext, bool) { + v, exists := c.Get(SaTokenCtxKey) + if !exists { + return nil, false + } + + ctx, ok := v.(*saContext.SaTokenContext) + return ctx, ok +} + +func getSaContext(c *gin.Context, mgr *manager.Manager) *saContext.SaTokenContext { + // Try get from context | 尝试从 ctx 取值 + if v, exists := c.Get(SaTokenCtxKey); exists { + if saCtx, ok := v.(*saContext.SaTokenContext); ok { + return saCtx + } + } + + // Create new context | 创建并缓存 SaTokenContext + saCtx := saContext.NewContext(NewGinContext(c), mgr) + c.Set(SaTokenCtxKey, saCtx) + + return saCtx +} + +// ============ Error Handling Helpers | 错误处理辅助函数 ============ + +// writeErrorResponse writes a standardized error response | 写入标准化的错误响应 +func writeErrorResponse(c *gin.Context, err error) { + var saErr *core.SaTokenError + var code int + var message string + var httpStatus int + + // Check if it's a SaTokenError | 检查是否为SaTokenError + if errors.As(err, &saErr) { + code = saErr.Code + message = saErr.Message + httpStatus = getHTTPStatusFromCode(code) + } else { + // Handle standard errors | 处理标准错误 + code = core.CodeServerError + message = err.Error() + httpStatus = http.StatusInternalServerError + } + + c.JSON(httpStatus, gin.H{ + "code": code, + "message": message, + "data": err.Error(), + }) +} + +// writeSuccessResponse writes a standardized success response | 写入标准化的成功响应 +func writeSuccessResponse(c *gin.Context, data interface{}) { + c.JSON(http.StatusOK, gin.H{ + "code": core.CodeSuccess, + "message": "success", + "data": data, + }) +} + +// getHTTPStatusFromCode converts Sa-Token error code to HTTP status | 将Sa-Token错误码转换为HTTP状态码 +func getHTTPStatusFromCode(code int) int { + switch code { + case core.CodeNotLogin: + return http.StatusUnauthorized + case core.CodePermissionDenied: + return http.StatusForbidden + case core.CodeBadRequest: + return http.StatusBadRequest + case core.CodeNotFound: + return http.StatusNotFound + case core.CodeServerError: + return http.StatusInternalServerError + default: + return http.StatusInternalServerError + } +} diff --git a/integrations/gin/plugin.go b/integrations/gin/plugin.go deleted file mode 100644 index d5e3f29..0000000 --- a/integrations/gin/plugin.go +++ /dev/null @@ -1,249 +0,0 @@ -package gin - -import ( - "errors" - "net/http" - - "github.com/click33/sa-token-go/core" - "github.com/gin-gonic/gin" -) - -// Plugin Gin plugin for Sa-Token | Gin插件 -type Plugin struct { - manager *core.Manager -} - -// NewPlugin creates a Gin plugin | 创建Gin插件 -func NewPlugin(manager *core.Manager) *Plugin { - return &Plugin{ - manager: manager, - } -} - -// AuthMiddleware authentication middleware | 认证中间件 -func (p *Plugin) AuthMiddleware() gin.HandlerFunc { - return func(c *gin.Context) { - ctx := NewGinContext(c) - saCtx := core.NewContext(ctx, p.manager) - - // Check login | 检查登录 - if err := saCtx.CheckLogin(); err != nil { - writeErrorResponse(c, err) - c.Abort() - return - } - - // Store Sa-Token context in Gin context | 将Sa-Token上下文存储到Gin上下文 - c.Set("satoken", saCtx) - c.Next() - } -} - -// PermissionRequired permission validation middleware | 权限验证中间件 -func (p *Plugin) PermissionRequired(permission string) gin.HandlerFunc { - return func(c *gin.Context) { - ctx := NewGinContext(c) - saCtx := core.NewContext(ctx, p.manager) - - // Check login | 检查登录 - if err := saCtx.CheckLogin(); err != nil { - writeErrorResponse(c, err) - c.Abort() - return - } - - // Check permission | 检查权限 - if !saCtx.HasPermission(permission) { - writeErrorResponse(c, core.NewPermissionDeniedError(permission)) - c.Abort() - return - } - - c.Set("satoken", saCtx) - c.Next() - } -} - -// RoleRequired role validation middleware | 角色验证中间件 -func (p *Plugin) RoleRequired(role string) gin.HandlerFunc { - return func(c *gin.Context) { - ctx := NewGinContext(c) - saCtx := core.NewContext(ctx, p.manager) - - // Check login | 检查登录 - if err := saCtx.CheckLogin(); err != nil { - writeErrorResponse(c, err) - c.Abort() - return - } - - // Check role | 检查角色 - if !saCtx.HasRole(role) { - writeErrorResponse(c, core.NewRoleDeniedError(role)) - c.Abort() - return - } - - c.Set("satoken", saCtx) - c.Next() - } -} - -// LoginHandler login handler example | 登录处理器示例 -func (p *Plugin) LoginHandler(c *gin.Context) { - var req struct { - Username string `json:"username" binding:"required"` - Password string `json:"password" binding:"required"` - Device string `json:"device"` - } - - if err := c.ShouldBindJSON(&req); err != nil { - writeErrorResponse(c, core.NewError(core.CodeBadRequest, "invalid request parameters", err)) - return - } - - // TODO: Validate username and password (should call your user service) | 验证用户名密码(这里应该调用你的用户服务) - // if !validateUser(req.Username, req.Password) { ... } - - // Login | 登录 - device := req.Device - if device == "" { - device = "default" - } - - token, err := p.manager.Login(req.Username, device) - if err != nil { - writeErrorResponse(c, core.NewError(core.CodeServerError, "login failed", err)) - return - } - - // Set cookie (optional) | 设置Cookie(可选) - cfg := p.manager.GetConfig() - if cfg.IsReadCookie { - maxAge := int(cfg.Timeout) - if maxAge < 0 { - maxAge = 0 - } - c.SetCookie( - cfg.TokenName, - token, - maxAge, - cfg.CookieConfig.Path, - cfg.CookieConfig.Domain, - cfg.CookieConfig.Secure, - cfg.CookieConfig.HttpOnly, - ) - } - - writeSuccessResponse(c, gin.H{ - "token": token, - }) -} - -// LogoutHandler logout handler | 登出处理器 -func (p *Plugin) LogoutHandler(c *gin.Context) { - ctx := NewGinContext(c) - saCtx := core.NewContext(ctx, p.manager) - - loginID, err := saCtx.GetLoginID() - if err != nil { - writeErrorResponse(c, err) - return - } - - if err := p.manager.Logout(loginID); err != nil { - writeErrorResponse(c, core.NewError(core.CodeServerError, "logout failed", err)) - return - } - - writeSuccessResponse(c, gin.H{ - "message": "logout successful", - }) -} - -// UserInfoHandler user info handler example | 获取用户信息处理器示例 -func (p *Plugin) UserInfoHandler(c *gin.Context) { - ctx := NewGinContext(c) - saCtx := core.NewContext(ctx, p.manager) - - loginID, err := saCtx.GetLoginID() - if err != nil { - writeErrorResponse(c, err) - return - } - - // Get user permissions and roles | 获取用户权限和角色 - permissions, _ := p.manager.GetPermissions(loginID) - roles, _ := p.manager.GetRoles(loginID) - - writeSuccessResponse(c, gin.H{ - "loginId": loginID, - "permissions": permissions, - "roles": roles, - }) -} - -// GetSaToken gets Sa-Token context from Gin context | 从Gin上下文获取Sa-Token上下文 -func GetSaToken(c *gin.Context) (*core.SaTokenContext, bool) { - satoken, exists := c.Get("satoken") - if !exists { - return nil, false - } - ctx, ok := satoken.(*core.SaTokenContext) - return ctx, ok -} - -// ============ Error Handling Helpers | 错误处理辅助函数 ============ - -// writeErrorResponse writes a standardized error response | 写入标准化的错误响应 -func writeErrorResponse(c *gin.Context, err error) { - var saErr *core.SaTokenError - var code int - var message string - var httpStatus int - - // Check if it's a SaTokenError | 检查是否为SaTokenError - if errors.As(err, &saErr) { - code = saErr.Code - message = saErr.Message - httpStatus = getHTTPStatusFromCode(code) - } else { - // Handle standard errors | 处理标准错误 - code = core.CodeServerError - message = err.Error() - httpStatus = http.StatusInternalServerError - } - - c.JSON(httpStatus, gin.H{ - "code": code, - "message": message, - "error": err.Error(), - }) -} - -// writeSuccessResponse writes a standardized success response | 写入标准化的成功响应 -func writeSuccessResponse(c *gin.Context, data interface{}) { - c.JSON(http.StatusOK, gin.H{ - "code": core.CodeSuccess, - "message": "success", - "data": data, - }) -} - -// getHTTPStatusFromCode converts Sa-Token error code to HTTP status | 将Sa-Token错误码转换为HTTP状态码 -func getHTTPStatusFromCode(code int) int { - switch code { - case core.CodeNotLogin: - return http.StatusUnauthorized - case core.CodePermissionDenied: - return http.StatusForbidden - case core.CodeBadRequest: - return http.StatusBadRequest - case core.CodeNotFound: - return http.StatusNotFound - case core.CodeServerError: - return http.StatusInternalServerError - default: - return http.StatusInternalServerError - } -} diff --git a/integrations/kratos/export.go b/integrations/kratos/export.go index 595df16..2000c96 100644 --- a/integrations/kratos/export.go +++ b/integrations/kratos/export.go @@ -100,12 +100,12 @@ var ( // ============ Core constructor functions | 核心构造函数 ============ -// DefaultConfig returns default configuration | 返回默认配置 +// DefaultConfig returns log configuration | 返回默认配置 func DefaultConfig() *Config { return core.DefaultConfig() } -// NewManager creates a new authentication manager | 创建新的认证管理器 +// NewManager creates a new authentication manager-example | 创建新的认证管理器 func NewManager(storage Storage, cfg *Config) *Manager { return core.NewManager(storage, cfg) } @@ -130,7 +130,7 @@ func NewTokenGenerator(cfg *Config) *TokenGenerator { return core.NewTokenGenerator(cfg) } -// NewEventManager creates a new event manager | 创建新的事件管理器 +// NewEventManager creates a new event manager-example | 创建新的事件管理器 func NewEventManager() *EventManager { return core.NewEventManager() } @@ -140,12 +140,12 @@ func NewBuilder() *Builder { return core.NewBuilder() } -// NewNonceManager creates a new nonce manager | 创建新的Nonce管理器 +// NewNonceManager creates a new nonce manager-example | 创建新的Nonce管理器 func NewNonceManager(storage Storage, prefix string, ttl ...int64) *NonceManager { return core.NewNonceManager(storage, prefix, ttl...) } -// NewRefreshTokenManager creates a new refresh token manager | 创建新的刷新令牌管理器 +// NewRefreshTokenManager creates a new refresh token manager-example | 创建新的刷新令牌管理器 func NewRefreshTokenManager(storage Storage, prefix string, cfg *Config) *RefreshTokenManager { return core.NewRefreshTokenManager(storage, prefix, cfg) } diff --git a/integrations/kratos/go.mod b/integrations/kratos/go.mod index ee423cc..70530cc 100644 --- a/integrations/kratos/go.mod +++ b/integrations/kratos/go.mod @@ -1,31 +1,26 @@ module github.com/click33/sa-token-go/integrations/kratos -go 1.24.0 - -toolchain go1.24.1 +go 1.25.0 require ( - github.com/click33/sa-token-go/core v0.1.5 - github.com/click33/sa-token-go/storage/memory v0.1.5 + github.com/click33/sa-token-go/core v0.1.6 + github.com/click33/sa-token-go/storage/memory v0.1.6 github.com/click33/sa-token-go/stputil v0.1.5 github.com/go-kratos/kratos/v2 v2.9.1 ) require ( - github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/go-kratos/aegis v0.2.0 // indirect github.com/go-playground/assert/v2 v2.2.0 // indirect github.com/go-playground/form/v4 v4.2.0 // indirect - github.com/golang-jwt/jwt/v5 v5.2.2 // indirect + github.com/golang-jwt/jwt/v5 v5.3.0 // indirect github.com/google/uuid v1.6.0 // indirect github.com/gorilla/mux v1.8.1 // indirect github.com/kr/text v0.2.0 // indirect github.com/panjf2000/ants/v2 v2.11.3 // indirect - github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/rogpeppe/go-internal v1.13.1 // indirect - github.com/stretchr/testify v1.11.1 // indirect golang.org/x/net v0.43.0 // indirect - golang.org/x/sync v0.16.0 // indirect + golang.org/x/sync v0.19.0 // indirect golang.org/x/sys v0.35.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20251103181224-f26f9409b101 // indirect google.golang.org/grpc v1.71.0 // indirect diff --git a/integrations/kratos/go.sum b/integrations/kratos/go.sum index 17c483d..e6a7fbc 100644 --- a/integrations/kratos/go.sum +++ b/integrations/kratos/go.sum @@ -1,6 +1,6 @@ -github.com/click33/sa-token-go/core v0.1.4 h1:mODeJ0WKSusQmiO5b/uK9UDD0OFqLrkJ/j7W/2e3Ios= -github.com/click33/sa-token-go/storage/memory v0.1.4 h1:gA2HT42Q84+qaOotHOOrVQOHTRybNxdafllELV5yw4o= -github.com/click33/sa-token-go/stputil v0.1.4 h1:YvMEwPfAfTunQn+AePudO3Esp0CvLoc2o5kmg/uZf/c= +github.com/click33/sa-token-go/core v0.1.6 h1:ELOe0qSH1b3LRsQD3DIBg0e1VgYANKFg5H7z57Lkt/8= +github.com/click33/sa-token-go/storage/memory v0.1.6 h1:iGFEy+HtTJLOpKnbIMbgpXyKotsKpPQu6wWTZVOXQis= +github.com/click33/sa-token-go/stputil v0.1.5 h1:603tbI4JkBTg3MnfTj+lCMDxJOKSCOqsMyC2zyuvEco= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/go-kratos/aegis v0.2.0 h1:dObzCDWn3XVjUkgxyBp6ZeWtx/do0DPZ7LY3yNSJLUQ= @@ -11,8 +11,7 @@ github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvSc github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= github.com/go-playground/form/v4 v4.2.0 h1:N1wh+Goz61e6w66vo8vJkQt+uwZSoLz50kZPJWR8eic= github.com/go-playground/form/v4 v4.2.0/go.mod h1:q1a2BY+AQUUzhl6xA/6hBetay6dEIhMHjgvJiGo6K7U= -github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= -github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= @@ -29,7 +28,7 @@ github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRI github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= -golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= google.golang.org/genproto/googleapis/rpc v0.0.0-20251103181224-f26f9409b101 h1:tRPGkdGHuewF4UisLzzHHr1spKw92qLM98nIzxbC0wY= diff --git a/integrations/zero/go.mod b/integrations/zero/go.mod new file mode 100644 index 0000000..0e7068e --- /dev/null +++ b/integrations/zero/go.mod @@ -0,0 +1,3 @@ +module github.com/click33/sa-token-go/integrations/zero + +go 1.25.0 diff --git a/log/gf/go.mod b/log/gf/go.mod new file mode 100644 index 0000000..d4888f6 --- /dev/null +++ b/log/gf/go.mod @@ -0,0 +1,25 @@ +module github.com/click33/sa-token-go/log/gf + +go 1.25.0 + +require github.com/gogf/gf/v2 v2.9.4 + +require ( + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/emirpasic/gods v1.18.1 // indirect + github.com/fatih/color v1.18.0 // indirect + github.com/fsnotify/fsnotify v1.9.0 // indirect + github.com/go-logr/logr v1.4.3 // indirect + github.com/go-logr/stdr v1.2.2 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/mattn/go-colorable v0.1.14 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + go.opentelemetry.io/auto/sdk v1.1.0 // indirect + go.opentelemetry.io/otel v1.38.0 // indirect + go.opentelemetry.io/otel/metric v1.38.0 // indirect + go.opentelemetry.io/otel/sdk v1.38.0 // indirect + go.opentelemetry.io/otel/trace v1.38.0 // indirect + golang.org/x/sys v0.35.0 // indirect + golang.org/x/text v0.28.0 // indirect +) diff --git a/log/gf/go.sum b/log/gf/go.sum new file mode 100644 index 0000000..84e26e9 --- /dev/null +++ b/log/gf/go.sum @@ -0,0 +1,34 @@ +github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg= +github.com/clbanning/mxj/v2 v2.7.0 h1:WA/La7UGCanFe5NpHF0Q3DNtnCsVoxbPKuyBNHWRyME= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc= +github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= +github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/gogf/gf/v2 v2.9.4 h1:6vleEWypot9WBPncP2GjbpgAUeG6Mzb1YESb9nPMkjY= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/grokify/html-strip-tags-go v0.1.0 h1:03UrQLjAny8xci+R+qjCce/MYnpNXCtgzltlQbOBae4= +github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= +github.com/olekukonko/errors v1.1.0 h1:RNuGIh15QdDenh+hNvKrJkmxxjV4hcS50Db478Ou5sM= +github.com/olekukonko/ll v0.0.9 h1:Y+1YqDfVkqMWuEQMclsF9HUR5+a82+dxJuL1HHSRpxI= +github.com/olekukonko/tablewriter v1.1.0 h1:N0LHrshF4T39KvI96fn6GT8HEjXRXYNDrDjKFDB7RIY= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= +go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8= +go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA= +go.opentelemetry.io/otel/sdk v1.38.0 h1:l48sr5YbNf2hpCUj/FoGhW9yDkl+Ma+LrVl8qaM5b+E= +go.opentelemetry.io/otel/sdk/metric v1.38.0 h1:aSH66iL0aZqo//xXzQLYozmWrXxyFkBJ6qT5wthqPoM= +go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= +golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= +golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/log/gf/log_adaper_gf.go b/log/gf/log_adaper_gf.go new file mode 100644 index 0000000..a547177 --- /dev/null +++ b/log/gf/log_adaper_gf.go @@ -0,0 +1,62 @@ +// @Author daixk 2025/11/27 22:58:00 +package gf + +import ( + "context" + "github.com/gogf/gf/v2/os/glog" +) + +// GFLogger adapts GoFrame v2 glog.Logger to Sa-Token logger interface | GoFrame v2 glog 适配器 +type GFLogger struct { + ctx context.Context + l *glog.Logger +} + +func NewGFLogger(ctx context.Context, l *glog.Logger) *GFLogger { + return &GFLogger{ + ctx: ctx, + l: l, + } +} + +// ---- Implement Adapter Interface | 实现 Adapter 接口 ---- + +func (g *GFLogger) Print(v ...any) { + g.l.Print(g.ctx, v...) +} + +func (g *GFLogger) Printf(format string, v ...any) { + g.l.Printf(g.ctx, format, v...) +} + +func (g *GFLogger) Debug(v ...any) { + g.l.Debug(g.ctx, v...) +} + +func (g *GFLogger) Debugf(format string, v ...any) { + g.l.Debugf(g.ctx, format, v...) +} + +func (g *GFLogger) Info(v ...any) { + g.l.Info(g.ctx, v...) +} + +func (g *GFLogger) Infof(format string, v ...any) { + g.l.Infof(g.ctx, format, v...) +} + +func (g *GFLogger) Warn(v ...any) { + g.l.Warning(g.ctx, v...) +} + +func (g *GFLogger) Warnf(format string, v ...any) { + g.l.Warningf(g.ctx, format, v...) +} + +func (g *GFLogger) Error(v ...any) { + g.l.Error(g.ctx, v...) +} + +func (g *GFLogger) Errorf(format string, v ...any) { + g.l.Errorf(g.ctx, format, v...) +} diff --git a/log/nop/go.mod b/log/nop/go.mod new file mode 100644 index 0000000..4553975 --- /dev/null +++ b/log/nop/go.mod @@ -0,0 +1,3 @@ +module github.com/click33/sa-token-go/log/nop + +go 1.25.0 \ No newline at end of file diff --git a/log/nop/log_adaper_nop.go b/log/nop/log_adaper_nop.go new file mode 100644 index 0000000..a05557d --- /dev/null +++ b/log/nop/log_adaper_nop.go @@ -0,0 +1,26 @@ +// @Author daixk 2025/11/27 21:08:00 +package nop + +// NopLogger is a logger implementation that performs no operations | 用于禁用所有日志输出的空日志器 +type NopLogger struct{} + +func NewNopLogger() *NopLogger { + return &NopLogger{} +} + +// ---- Implement Adapter Interface | 实现 Adapter 接口 ---- + +func (n *NopLogger) Print(v ...any) {} +func (n *NopLogger) Printf(format string, v ...any) {} + +func (n *NopLogger) Debug(v ...any) {} +func (n *NopLogger) Debugf(format string, v ...any) {} + +func (n *NopLogger) Info(v ...any) {} +func (n *NopLogger) Infof(format string, v ...any) {} + +func (n *NopLogger) Warn(v ...any) {} +func (n *NopLogger) Warnf(format string, v ...any) {} + +func (n *NopLogger) Error(v ...any) {} +func (n *NopLogger) Errorf(format string, v ...any) {} diff --git a/log/slog/config.go b/log/slog/config.go new file mode 100644 index 0000000..789af12 --- /dev/null +++ b/log/slog/config.go @@ -0,0 +1,123 @@ +// @Author daixk 2025/12/22 15:55:00 +package slog + +import ( + "time" +) + +// LoggerConfig defines configuration for the logger | 日志配置项,定义日志输出的行为、格式和文件管理策略 +type LoggerConfig struct { + Path string // Log directory path | 日志文件目录 + FileFormat string // Log file naming format | 日志文件命名格式 + Prefix string // Log line prefix | 日志前缀 + Level LogLevel // Minimum output level | 最低输出级别 + TimeFormat string // Timestamp format | 时间戳格式 + Stdout bool // Print logs to console | 是否输出到控制台 + StdoutOnly bool // Only print to console, skip file output | 仅输出到控制台,不写入文件 + QueueSize int // Async write queue size | 异步写入队列大小 + RotateSize int64 // File size threshold before rotation (bytes) | 文件滚动大小阈值(字节) + RotateExpire time.Duration // Rotation interval by time duration | 文件时间滚动间隔 + RotateBackupLimit int // Maximum number of rotated backup files | 最大备份文件数量 + RotateBackupDays int // Retention days for old log files | 备份文件保留天数 +} + +// DefaultLoggerConfig returns default logger configuration | 返回默认日志配置 +func DefaultLoggerConfig() *LoggerConfig { + return &LoggerConfig{ + TimeFormat: DefaultTimeFormat, + FileFormat: DefaultFileFormat, + Prefix: DefaultPrefix, + Level: LevelInfo, + Stdout: true, + StdoutOnly: false, + QueueSize: DefaultQueueSize, + RotateSize: DefaultRotateSize, + RotateExpire: DefaultRotateExpire, + RotateBackupLimit: DefaultRotateBackupLimit, + RotateBackupDays: DefaultRotateBackupDays, + } +} + +// SetPath sets the log output directory | 设置日志输出目录 +func (c *LoggerConfig) SetPath(path string) *LoggerConfig { + c.Path = path + return c +} + +// SetFileFormat sets the log file naming format | 设置日志文件命名格式 +func (c *LoggerConfig) SetFileFormat(format string) *LoggerConfig { + c.FileFormat = format + return c +} + +// SetPrefix sets the log line prefix | 设置日志输出前缀 +func (c *LoggerConfig) SetPrefix(prefix string) *LoggerConfig { + c.Prefix = prefix + return c +} + +// SetLevel sets the minimum output log level | 设置日志最低输出级别 +func (c *LoggerConfig) SetLevel(level LogLevel) *LoggerConfig { + c.Level = level + return c +} + +// SetTimeFormat sets the timestamp format in log lines | 设置日志时间戳格式 +func (c *LoggerConfig) SetTimeFormat(format string) *LoggerConfig { + c.TimeFormat = format + return c +} + +// SetStdout enables or disables console output | 设置是否输出到控制台 +func (c *LoggerConfig) SetStdout(enable bool) *LoggerConfig { + c.Stdout = enable + return c +} + +// SetStdoutOnly enables console-only mode (no file output) | 设置仅输出到控制台模式 +func (c *LoggerConfig) SetStdoutOnly(enable bool) *LoggerConfig { + c.StdoutOnly = enable + if enable { + c.Stdout = true + } + return c +} + +// SetQueueSize sets the async write queue size | 设置异步写入队列大小 +func (c *LoggerConfig) SetQueueSize(size int) *LoggerConfig { + c.QueueSize = size + return c +} + +// SetRotateSize sets the file size threshold for log rotation | 设置日志文件大小滚动阈值 +func (c *LoggerConfig) SetRotateSize(size int64) *LoggerConfig { + c.RotateSize = size + return c +} + +// SetRotateExpire sets the time-based rotation interval | 设置时间滚动间隔 +func (c *LoggerConfig) SetRotateExpire(d time.Duration) *LoggerConfig { + c.RotateExpire = d + return c +} + +// SetRotateBackupLimit sets the maximum number of backup log files retained | 设置最大备份文件数量 +func (c *LoggerConfig) SetRotateBackupLimit(limit int) *LoggerConfig { + c.RotateBackupLimit = limit + return c +} + +// SetRotateBackupDays sets the retention days for backup log files | 设置日志备份保留天数 +func (c *LoggerConfig) SetRotateBackupDays(days int) *LoggerConfig { + c.RotateBackupDays = days + return c +} + +// Clone returns a copy of the current logger configuration | 返回当前日志配置的副本 +func (c *LoggerConfig) Clone() *LoggerConfig { + if c == nil { + return &LoggerConfig{} + } + copyCfg := *c + return ©Cfg +} diff --git a/log/slog/consts.go b/log/slog/consts.go new file mode 100644 index 0000000..2ca3d17 --- /dev/null +++ b/log/slog/consts.go @@ -0,0 +1,32 @@ +// @Author daixk 2025/12/22 15:56:00 +package slog + +import ( + "time" + + "github.com/click33/sa-token-go/core/adapter" +) + +// LogLevel is an alias for adapter.LogLevel | 日志级别别名 +type LogLevel = adapter.LogLevel + +// Log level constants | 日志级别常量 +const ( + LevelDebug = adapter.LogLevelDebug // Debug level | 调试级别 + LevelInfo = adapter.LogLevelInfo // Info level | 信息级别 + LevelWarn = adapter.LogLevelWarn // Warn level | 警告级别 + LevelError = adapter.LogLevelError // Error level | 错误级别(最高) +) + +const ( + DefaultPrefix = "[SA-TOKEN-GO] " // Default log prefix | 默认日志前缀 + DefaultFileFormat = "SA-TOKEN-GO_{Y}-{m}-{d}.log" // Default log filename format | 默认文件命名格式 + DefaultTimeFormat = "2006-01-02 15:04:05" // Default time format | 默认时间格式 + DefaultDirName = "sa_token_go_logs" // Default log directory name | 默认日志目录名 + DefaultBaseName = "SA-TOKEN-GO" // Default log filename prefix | 默认日志文件基础前缀 + DefaultQueueSize = 4096 // Default async queue size | 默认异步队列大小 + DefaultRotateSize = 10 * 1024 * 1024 // Rotate threshold (10MB) | 文件滚动大小阈值 + DefaultRotateExpire = 24 * time.Hour // Rotate by time interval (1 day) | 时间滚动间隔 + DefaultRotateBackupLimit = 10 // Max number of backups | 最大备份数量 + DefaultRotateBackupDays = 7 // Retain logs for 7 days | 备份保留天数 +) diff --git a/log/slog/doc.go b/log/slog/doc.go new file mode 100644 index 0000000..1f8d972 --- /dev/null +++ b/log/slog/doc.go @@ -0,0 +1,18 @@ +// @Author daixk 2025/12/26 15:17:00 +package slog + +// Package slog provides an async logging implementation for sa-token-go. +// +// Features: +// - Async write with buffered queue (non-blocking) +// - Log rotation by size and time +// - Auto cleanup of expired backup files +// - Runtime config modification (level, prefix, stdout) +// - Thread-safe design with proper locking +// +// TODO: Future enhancements | 未来增强计划: +// - [ ] Structured logging with JSON format output | 结构化日志(JSON 格式输出) +// - [ ] Sampling and rate limiting mechanism | 日志采样与限流机制 +// - [ ] Trace/Span ID support for distributed tracing | 分布式链路追踪 trace/span ID 支持 +// - [ ] Log aggregation hooks (e.g., send to ELK, Loki) | 日志聚合钩子(如发送到 ELK、Loki) +// - [ ] Context-aware logging with context.Context | 支持 context.Context 的上下文日志 diff --git a/log/slog/go.mod b/log/slog/go.mod new file mode 100644 index 0000000..42dd01e --- /dev/null +++ b/log/slog/go.mod @@ -0,0 +1,4 @@ +module github.com/click33/sa-token-go/log/slog + +go 1.25.0 + diff --git a/log/slog/go.sum b/log/slog/go.sum new file mode 100644 index 0000000..2fb8f28 --- /dev/null +++ b/log/slog/go.sum @@ -0,0 +1,6 @@ +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/panjf2000/ants/v2 v2.11.3 h1:AfI0ngBoXJmYOpDh9m516vjqoUu2sLrIVgppI9TZVpg= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/log/slog/log_adaper_slog.go b/log/slog/log_adaper_slog.go new file mode 100644 index 0000000..9ad3f04 --- /dev/null +++ b/log/slog/log_adaper_slog.go @@ -0,0 +1,725 @@ +// @Author daixk 2025-12-26 14:14:15 +package slog + +import ( + "bytes" + "crypto/rand" + "fmt" + "math/big" + "os" + "path/filepath" + "sort" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" +) + +// Logger implements ILogger + LoggerControl | 日志核心实现 +type Logger struct { + // ---- Config & State ---- + cfg *LoggerConfig // Logger configuration | 日志配置 + cfgMu sync.RWMutex // Config lock | 配置锁 + + // ---- File IO ---- + fileMu sync.Mutex // File write lock | 文件写锁 + curFile *os.File // Current log file | 当前日志文件 + curName string // Current file name | 当前日志文件名 + curSize int64 // Current log size | 当前文件大小 + lastRotate time.Time // Last rotation time | 上次切分时间 + + // ---- Async Write ---- + queue chan []byte // Async write queue | 异步写队列 + quit chan struct{} // Stop signal | 停止信号 + wg sync.WaitGroup + + // ---- Time Cache ---- + timeCache atomic.Value // Cached time info | 缓存的时间信息 + + // ---- State ---- + closed uint32 // Closed flag | 关闭标记 + dropCount uint64 // Dropped log counter | 队列满时丢弃日志计数 + + closeOnce sync.Once // Ensure Close only executes once | 确保 Close 只执行一次 +} + +// timeCacheEntry stores cached timestamp | 时间缓存条目 +type timeCacheEntry struct { + sec int64 // Unix seconds | Unix 秒 + str string // Formatted string | 格式化字符串 +} + +// NewLoggerWithConfig creates a logger instance | 使用配置创建日志器 +func NewLoggerWithConfig(cfg *LoggerConfig) (*Logger, error) { + newCfg, err := prepareConfig(cfg) + if err != nil { + return nil, err + } + + queueSize := newCfg.QueueSize + if queueSize <= 0 { + queueSize = DefaultQueueSize + } + + l := &Logger{ + cfg: newCfg, + queue: make(chan []byte, queueSize), + quit: make(chan struct{}), + lastRotate: time.Now(), + } + + // Initialize time cache | 初始化时间缓存 + now := time.Now() + l.timeCache.Store(&timeCacheEntry{ + sec: now.Unix(), + str: now.Format(newCfg.TimeFormat), + }) + + l.wg.Add(1) + go func() { + defer l.wg.Done() + l.writerLoop() + }() + + return l, nil +} + +// write handles simple log output | 输出普通日志 +func (l *Logger) write(level LogLevel, args ...any) { + if atomic.LoadUint32(&l.closed) != 0 { + return + } + cfg := l.currentCfg() + if level < cfg.Level { + return + } + l.enqueue(l.buildLine(level, cfg, args...)) +} + +// writef handles formatted log output | 输出格式化日志 +func (l *Logger) writef(level LogLevel, format string, args ...any) { + if atomic.LoadUint32(&l.closed) != 0 { + return + } + cfg := l.currentCfg() + if level < cfg.Level { + return + } + buf := getBuf() + _, _ = fmt.Fprintf(buf, format, args...) + line := l.buildLine(level, cfg, buf.String()) + putBuf(buf) + l.enqueue(line) +} + +// enqueue pushes logs to async queue | 将日志推入异步队列 +func (l *Logger) enqueue(b []byte) { + if atomic.LoadUint32(&l.closed) != 0 { + return + } + select { + case l.queue <- b: + default: + // queue full, drop | 队列满丢弃 + atomic.AddUint64(&l.dropCount, 1) + } +} + +// ---- Build Log Line ---- + +// buildLine builds complete log line | 构建完整日志行 +func (l *Logger) buildLine(level LogLevel, cfg LoggerConfig, args ...any) []byte { + buf := getBuf() + + // Get cached timestamp or format new one | 获取缓存时间戳或格式化新的 + now := time.Now() + sec := now.Unix() + + ts := l.getTimeString(now, sec, cfg.TimeFormat) + buf.WriteString(ts) + + buf.WriteString(" [") + buf.WriteString(levelString(level)) + buf.WriteString("] ") + + buf.WriteString(cfg.Prefix) + + for i, arg := range args { + if i > 0 { + buf.WriteByte(' ') + } + appendValue(buf, arg) + } + + buf.WriteByte('\n') + + // copy to new slice to avoid buffer reuse | 拷贝到新切片避免复用冲突 + out := append([]byte(nil), buf.Bytes()...) + putBuf(buf) + return out +} + +// getTimeString returns cached or formatted time string | 返回缓存或格式化的时间字符串 +func (l *Logger) getTimeString(now time.Time, sec int64, format string) string { + // Try to load from cache | 尝试从缓存加载 + if cached, ok := l.timeCache.Load().(*timeCacheEntry); ok && cached.sec == sec { + return cached.str + } + + // Format new string and update cache (atomic, no race) | 格式化新字符串并更新缓存 + str := now.Format(format) + l.timeCache.Store(&timeCacheEntry{sec: sec, str: str}) + return str +} + +// appendValue writes a single value with optimized type handling | 写入单个参数(优化类型处理) +func appendValue(buf *bytes.Buffer, v any) { + if v == nil { + buf.WriteString("") + return + } + + switch val := v.(type) { + case string: + buf.WriteString(val) + case []byte: + buf.Write(val) + case error: + if val != nil { + buf.WriteString(val.Error()) + } else { + buf.WriteString("") + } + + // Optimized integer handling | 优化整数处理 + case int: + buf.WriteString(strconv.FormatInt(int64(val), 10)) + case int8: + buf.WriteString(strconv.FormatInt(int64(val), 10)) + case int16: + buf.WriteString(strconv.FormatInt(int64(val), 10)) + case int32: + buf.WriteString(strconv.FormatInt(int64(val), 10)) + case int64: + buf.WriteString(strconv.FormatInt(val, 10)) + case uint: + buf.WriteString(strconv.FormatUint(uint64(val), 10)) + case uint8: + buf.WriteString(strconv.FormatUint(uint64(val), 10)) + case uint16: + buf.WriteString(strconv.FormatUint(uint64(val), 10)) + case uint32: + buf.WriteString(strconv.FormatUint(uint64(val), 10)) + case uint64: + buf.WriteString(strconv.FormatUint(val, 10)) + + case float32: + buf.WriteString(strconv.FormatFloat(float64(val), 'g', -1, 32)) + case float64: + buf.WriteString(strconv.FormatFloat(val, 'g', -1, 64)) + + case bool: + if val { + buf.WriteString("true") + } else { + buf.WriteString("false") + } + + case time.Time: + buf.WriteString(val.Format(DefaultTimeFormat)) + + default: + _, _ = fmt.Fprint(buf, val) + } +} + +// ---- Async Writer ---- + +// writerLoop processes all file IO | 异步写线程处理文件操作 +func (l *Logger) writerLoop() { + defer func() { + l.Flush() + }() + + for { + select { + case b, ok := <-l.queue: + if !ok { + return + } + l.writeToOutput(b) + + case <-l.quit: + // drain queue | 退出前清空队列 + for { + select { + case b := <-l.queue: + l.writeToOutput(b) + default: + return + } + } + } + } +} + +// writeToOutput writes to file and/or stdout | 写入文件和/或控制台 +func (l *Logger) writeToOutput(b []byte) { + cfg := l.currentCfg() + + // StdoutOnly mode: only print to console | 仅控制台模式 + if cfg.StdoutOnly { + if cfg.Stdout { + _, _ = os.Stdout.Write(b) + } + return + } + + now := time.Now() + + l.fileMu.Lock() + defer l.fileMu.Unlock() + + // open file if needed | 无文件则打开 + if err := l.ensureLogFile(now, cfg); err != nil { + // File open failed, fallback to stdout | 文件打开失败,回退到控制台 + if cfg.Stdout { + _, _ = os.Stdout.Write(b) + } + return + } + + if l.curFile != nil { + n, err := l.curFile.Write(b) + if err != nil { + _ = l.curFile.Close() + l.curFile = nil + // Retry once with new file | 重试一次新文件 + if retryErr := l.openNewFile(now, cfg); retryErr == nil && l.curFile != nil { + n, _ = l.curFile.Write(b) + l.curSize += int64(n) + } + } else { + l.curSize += int64(n) + } + } + + if cfg.Stdout { + _, _ = os.Stdout.Write(b) + } + + // check rotate | 检测切分 + if l.shouldRotate(now, cfg) { + _ = l.rotate(cfg) + } +} + +// ---- File Handling ---- + +// ensureLogFile ensures a log file is open | 确保日志文件存在 +func (l *Logger) ensureLogFile(now time.Time, cfg LoggerConfig) error { + if l.curFile == nil { + return l.openNewFile(now, cfg) + } + if cfg.RotateExpire > 0 && now.Sub(l.lastRotate) >= cfg.RotateExpire { + return l.rotate(cfg) + } + return nil +} + +// openNewFile opens a new log file | 打开新日志文件 +func (l *Logger) openNewFile(now time.Time, cfg LoggerConfig) error { + name := l.formatFileName(now, cfg) + path := filepath.Join(cfg.Path, name) + + f, err := os.OpenFile(path, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0666) + if err != nil { + return err + } + + l.curFile = f + l.curName = name + l.curSize = getFileSize(f) + l.lastRotate = now + return nil +} + +// shouldRotate checks rotation conditions | 检查是否需要切分 +func (l *Logger) shouldRotate(now time.Time, cfg LoggerConfig) bool { + if cfg.RotateSize > 0 && l.curSize >= cfg.RotateSize { + return true + } + if cfg.RotateExpire > 0 && now.Sub(l.lastRotate) >= cfg.RotateExpire { + return true + } + return false +} + +// rotate rotates the current log file | 日志切分逻辑 +func (l *Logger) rotate(cfg LoggerConfig) error { + if l.curFile == nil { + return nil + } + + old := filepath.Join(cfg.Path, l.curName) + _ = l.curFile.Sync() + _ = l.curFile.Close() + l.curFile = nil + + now := time.Now() + ts := fmt.Sprintf("%s_%03d", now.Format("20060102_150405"), now.Nanosecond()/1e6) + + base := strings.TrimSuffix(l.curName, ".log") + newName := fmt.Sprintf("%s_%s.log", base, ts) + newPath := filepath.Join(cfg.Path, newName) + + if err := os.Rename(old, newPath); err != nil { + // Use crypto/rand for secure random number | 使用加密安全的随机数 + randNum := secureRandomInt(1_000_000) + _ = os.Rename(old, filepath.Join(cfg.Path, base+fmt.Sprintf("_%06d.log", randNum))) + } + + l.curSize = 0 + l.curName = "" + l.lastRotate = now + + // Async cleanup to avoid blocking writes | 异步清理避免阻塞写入 + go l.cleanup(cfg) + + return l.openNewFile(now, cfg) +} + +// cleanup removes expired logs | 清理过期/多余日志文件 +func (l *Logger) cleanup(cfg LoggerConfig) { + // Recover from panic to avoid crashing the program | 捕获 panic 避免程序崩溃 + defer func() { + if r := recover(); r != nil { + // Silently ignore cleanup errors | 静默忽略清理错误 + } + }() + + // base is the fixed prefix of log files for this logger | base 为该 Logger 对应日志文件的固定前缀 + base := normalizeBaseName(cfg.FileFormat) + if base == "" { + base = DefaultBaseName + } + + files, _ := filepath.Glob(filepath.Join(cfg.Path, "*.log")) + if len(files) == 0 { + return + } + + var keep []struct { + path string + t time.Time + } + + now := time.Now() + expire := time.Time{} + if cfg.RotateBackupDays > 0 { + expire = now.AddDate(0, 0, -cfg.RotateBackupDays) + } + + for _, f := range files { + info, err := os.Stat(f) + if err != nil { + continue + } + + filename := filepath.Base(f) + + // 只处理以 base 开头的文件 | only handle files with the same base prefix + if !strings.HasPrefix(filename, base) { + continue + } + + // 清理过期文件 | remove expired files + if !expire.IsZero() && info.ModTime().Before(expire) { + _ = os.Remove(f) + continue + } + + // 当前正在写入的文件此时尚未创建(在 rotate 之后), + // 这里收集到的全是"备份文件",后续按数量进行裁剪 + keep = append(keep, struct { + path string + t time.Time + }{f, info.ModTime()}) + } + + // 根据 RotateBackupLimit 限制保留的备份文件数量(不包含当前正在写的那个文件)| + // keep only the newest RotateBackupLimit backup files (current file is not included here) + if cfg.RotateBackupLimit > 0 && len(keep) > cfg.RotateBackupLimit { + // 按修改时间排序,最旧的在前 | sort by time ascending + sort.Slice(keep, func(i, j int) bool { return keep[i].t.Before(keep[j].t) }) + + // 删除多余的,只保留最新的 cfg.RotateBackupLimit 个 | remove oldest extras + for _, f := range keep[:len(keep)-cfg.RotateBackupLimit] { + _ = os.Remove(f.path) + } + } +} + +// formatFileName generates filename | 生成日志文件名 +func (l *Logger) formatFileName(t time.Time, cfg LoggerConfig) string { + name := cfg.FileFormat + if name == "" { + return fmt.Sprintf("%s_%s.log", DefaultBaseName, t.Format("2006-01-02")) + } + + r := strings.NewReplacer( + "{Y}", t.Format("2006"), + "{m}", t.Format("01"), + "{d}", t.Format("02"), + ) + + name = r.Replace(name) + if !strings.HasSuffix(name, ".log") { + name += ".log" + } + return name +} + +// ---- Runtime Control ---- + +// SetLevel updates minimum level | 动态更新日志级别 +func (l *Logger) SetLevel(level LogLevel) { + l.cfgMu.Lock() + defer l.cfgMu.Unlock() + if l.cfg != nil { + l.cfg.Level = level + } +} + +// SetPrefix updates prefix | 动态更新日志前缀 +func (l *Logger) SetPrefix(prefix string) { + l.cfgMu.Lock() + defer l.cfgMu.Unlock() + if l.cfg != nil { + l.cfg.Prefix = prefix + } +} + +// SetStdout enables/disables stdout | 开关控制台输出 +func (l *Logger) SetStdout(enable bool) { + l.cfgMu.Lock() + defer l.cfgMu.Unlock() + if l.cfg != nil { + l.cfg.Stdout = enable + } +} + +// SetConfig replaces config and reopens log file | 动态替换配置并重新创建日志文件 +func (l *Logger) SetConfig(cfg *LoggerConfig) { + newCfg, err := prepareConfig(cfg) + if err != nil { + return + } + + // Lock in consistent order: fileMu first, then cfgMu | 统一锁顺序:先 fileMu,再 cfgMu + l.fileMu.Lock() + defer l.fileMu.Unlock() + + l.cfgMu.Lock() + defer l.cfgMu.Unlock() + + l.cfg = newCfg + + if l.curFile != nil { + _ = l.curFile.Sync() + _ = l.curFile.Close() + l.curFile = nil + } + + l.curName = "" + l.curSize = 0 + l.lastRotate = time.Now() +} + +// Close stops logger | 关闭日志系统 +func (l *Logger) Close() { + l.closeOnce.Do(func() { + atomic.StoreUint32(&l.closed, 1) + close(l.quit) + + l.wg.Wait() + + l.fileMu.Lock() + defer l.fileMu.Unlock() + + if l.curFile != nil { + _ = l.curFile.Sync() + _ = l.curFile.Close() + } + }) +} + +// Flush flushes file buffer | 强制刷新文件缓冲区 +func (l *Logger) Flush() { + l.fileMu.Lock() + defer l.fileMu.Unlock() + if l.curFile != nil { + _ = l.curFile.Sync() + } +} + +// LogPath returns directory | 返回日志目录 +func (l *Logger) LogPath() string { + l.cfgMu.RLock() + defer l.cfgMu.RUnlock() + if l.cfg == nil { + return "" + } + return l.cfg.Path +} + +// DropCount returns dropped logs | 返回丢弃日志数量 +func (l *Logger) DropCount() uint64 { + return atomic.LoadUint64(&l.dropCount) +} + +// ---- Buffer Pool ---- + +var bufPool = sync.Pool{ + New: func() any { return new(bytes.Buffer) }, +} + +func getBuf() *bytes.Buffer { + b := bufPool.Get().(*bytes.Buffer) + b.Reset() + return b +} + +func putBuf(b *bytes.Buffer) { + bufPool.Put(b) +} + +// ---- Helpers ---- + +// getFileSize returns file size | 获取文件大小 +func getFileSize(f *os.File) int64 { + info, err := f.Stat() + if err != nil { + return 0 + } + return info.Size() +} + +// prepareConfig applies defaults and ensures directory | 应用默认配置并确保目录存在 +func prepareConfig(cfg *LoggerConfig) (*LoggerConfig, error) { + if cfg == nil { + cfg = &LoggerConfig{} + } + + c := *cfg // copy + + if c.TimeFormat == "" { + c.TimeFormat = DefaultTimeFormat + } + if c.Prefix == "" { + c.Prefix = DefaultPrefix + } + if c.QueueSize <= 0 { + c.QueueSize = DefaultQueueSize + } + + // StdoutOnly mode doesn't need file config | 仅控制台模式不需要文件配置 + if c.StdoutOnly { + c.Stdout = true + return &c, nil + } + + // File mode: apply file-related defaults | 文件模式:应用文件相关默认值 + if c.FileFormat == "" { + c.FileFormat = DefaultFileFormat + } + if c.RotateSize <= 0 { + c.RotateSize = DefaultRotateSize + } + if c.RotateExpire < 0 { + c.RotateExpire = 0 + } + if c.RotateBackupLimit <= 0 { + c.RotateBackupLimit = DefaultRotateBackupLimit + } + if c.RotateBackupDays < 0 { + c.RotateBackupDays = 0 + } + + // Ensure path exists | 确保路径存在 + if c.Path == "" { + wd, err := os.Getwd() + if err != nil { + wd = "." + } + c.Path = filepath.Join(wd, DefaultDirName) + } + + if err := os.MkdirAll(c.Path, 0755); err != nil { + return nil, fmt.Errorf("failed to create log directory: %w", err) + } + + return &c, nil +} + +// currentCfg returns a config snapshot | 返回当前配置快照 +func (l *Logger) currentCfg() LoggerConfig { + l.cfgMu.RLock() + defer l.cfgMu.RUnlock() + + if l.cfg == nil { + return LoggerConfig{} + } + return *l.cfg +} + +// levelString converts log level to string | 将日志级别转换为字符串 +func levelString(level LogLevel) string { + return level.String() +} + +// normalizeBaseName extracts static name | 提取基础日志文件名前缀 +func normalizeBaseName(format string) string { + if format == "" { + return DefaultBaseName + } + + // 去掉 .log 后缀 | strip ".log" suffix + name := strings.TrimSuffix(format, ".log") + + // 如果包含占位符,则取第一个占位符之前的固定前缀 | if contains "{...}", take prefix before first placeholder + if idx := strings.Index(name, "{"); idx >= 0 { + name = name[:idx] + // 去掉末尾的连接符(常见为 "_" 或 "-")| trim trailing separators like "_" or "-" + name = strings.TrimRight(name, "_- ") + } + + name = strings.TrimSpace(name) + if name == "" { + return DefaultBaseName + } + return name +} + +// secureRandomInt returns a cryptographically secure random int | 返回加密安全的随机整数 +func secureRandomInt(max int) int { + n, err := rand.Int(rand.Reader, big.NewInt(int64(max))) + if err != nil { + return 0 + } + return int(n.Int64()) +} + +// ---- Logging API ---- + +func (l *Logger) Print(v ...any) { l.write(LevelInfo, v...) } +func (l *Logger) Printf(f string, v ...any) { l.writef(LevelInfo, f, v...) } +func (l *Logger) Debug(v ...any) { l.write(LevelDebug, v...) } +func (l *Logger) Debugf(f string, v ...any) { l.writef(LevelDebug, f, v...) } +func (l *Logger) Info(v ...any) { l.write(LevelInfo, v...) } +func (l *Logger) Infof(f string, v ...any) { l.writef(LevelInfo, f, v...) } +func (l *Logger) Warn(v ...any) { l.write(LevelWarn, v...) } +func (l *Logger) Warnf(f string, v ...any) { l.writef(LevelWarn, f, v...) } +func (l *Logger) Error(v ...any) { l.write(LevelError, v...) } +func (l *Logger) Errorf(f string, v ...any) { l.writef(LevelError, f, v...) } diff --git a/log/slog/slog_test.go b/log/slog/slog_test.go new file mode 100644 index 0000000..5ce6b57 --- /dev/null +++ b/log/slog/slog_test.go @@ -0,0 +1,964 @@ +// @Author daixk +package slog + +import ( + "bytes" + "errors" + "os" + "path/filepath" + "strings" + "sync" + "testing" + "time" +) + +// ============ LogLevel Tests | 日志级别测试 ============ + +func TestLogLevel_String(t *testing.T) { + tests := []struct { + level LogLevel + expected string + }{ + {LevelDebug, "DEBUG"}, + {LevelInfo, "INFO"}, + {LevelWarn, "WARN"}, + {LevelError, "ERROR"}, + {LogLevel(0), "UNKNOWN"}, + {LogLevel(100), "UNKNOWN"}, + } + + for _, tt := range tests { + t.Run(tt.expected, func(t *testing.T) { + if got := tt.level.String(); got != tt.expected { + t.Errorf("LogLevel.String() = %v, want %v", got, tt.expected) + } + }) + } +} + +// ============ LoggerConfig Tests | 配置测试 ============ + +func TestDefaultLoggerConfig(t *testing.T) { + cfg := DefaultLoggerConfig() + + if cfg.TimeFormat != DefaultTimeFormat { + t.Errorf("TimeFormat = %v, want %v", cfg.TimeFormat, DefaultTimeFormat) + } + if cfg.FileFormat != DefaultFileFormat { + t.Errorf("FileFormat = %v, want %v", cfg.FileFormat, DefaultFileFormat) + } + if cfg.Prefix != DefaultPrefix { + t.Errorf("Prefix = %v, want %v", cfg.Prefix, DefaultPrefix) + } + if cfg.Level != LevelInfo { + t.Errorf("Level = %v, want %v", cfg.Level, LevelInfo) + } + if !cfg.Stdout { + t.Error("Stdout should be true by default") + } + if cfg.StdoutOnly { + t.Error("StdoutOnly should be false by default") + } + if cfg.QueueSize != DefaultQueueSize { + t.Errorf("QueueSize = %v, want %v", cfg.QueueSize, DefaultQueueSize) + } + if cfg.RotateSize != DefaultRotateSize { + t.Errorf("RotateSize = %v, want %v", cfg.RotateSize, DefaultRotateSize) + } +} + +func TestLoggerConfig_Setters(t *testing.T) { + cfg := &LoggerConfig{} + testPath := "test_logs_path" + + cfg.SetPath(testPath). + SetFileFormat("test_{Y}-{m}-{d}.log"). + SetPrefix("[TEST] "). + SetLevel(LevelDebug). + SetTimeFormat("2006-01-02"). + SetStdout(true). + SetStdoutOnly(false). + SetQueueSize(1024). + SetRotateSize(1024 * 1024). + SetRotateExpire(time.Hour). + SetRotateBackupLimit(5). + SetRotateBackupDays(3) + + if cfg.Path != testPath { + t.Errorf("Path = %v, want %v", cfg.Path, testPath) + } + if cfg.FileFormat != "test_{Y}-{m}-{d}.log" { + t.Errorf("FileFormat = %v", cfg.FileFormat) + } + if cfg.Prefix != "[TEST] " { + t.Errorf("Prefix = %v", cfg.Prefix) + } + if cfg.Level != LevelDebug { + t.Errorf("Level = %v", cfg.Level) + } + if cfg.QueueSize != 1024 { + t.Errorf("QueueSize = %v", cfg.QueueSize) + } + if cfg.RotateBackupLimit != 5 { + t.Errorf("RotateBackupLimit = %v", cfg.RotateBackupLimit) + } +} + +func TestLoggerConfig_SetStdoutOnly(t *testing.T) { + cfg := &LoggerConfig{Stdout: false} + + cfg.SetStdoutOnly(true) + + if !cfg.Stdout { + t.Error("SetStdoutOnly(true) should also set Stdout to true") + } + if !cfg.StdoutOnly { + t.Error("StdoutOnly should be true") + } +} + +func TestLoggerConfig_Clone(t *testing.T) { + original := DefaultLoggerConfig() + original.Path = "original_path" + + cloned := original.Clone() + + if cloned.Path != original.Path { + t.Errorf("Clone().Path = %v, want %v", cloned.Path, original.Path) + } + + // Modify clone should not affect original + cloned.Path = "cloned_path" + if original.Path == cloned.Path { + t.Error("Modifying clone should not affect original") + } +} + +func TestLoggerConfig_Clone_Nil(t *testing.T) { + var cfg *LoggerConfig + cloned := cfg.Clone() + + if cloned == nil { + t.Error("Clone of nil should return empty config, not nil") + } +} + +// ============ Logger Creation Tests | 日志器创建测试 ============ + +func TestNewLoggerWithConfig_StdoutOnly(t *testing.T) { + cfg := &LoggerConfig{ + Stdout: true, + StdoutOnly: true, + QueueSize: DefaultQueueSize, + } + + logger, err := NewLoggerWithConfig(cfg) + if err != nil { + t.Fatalf("NewLoggerWithConfig() error = %v", err) + } + defer logger.Close() + + // StdoutOnly mode should not create directory + if logger.LogPath() != "" { + t.Errorf("StdoutOnly mode should have empty path, got %v", logger.LogPath()) + } +} + +func TestNewLoggerWithConfig_WithPath(t *testing.T) { + tmpDir := t.TempDir() + //tmpDir, _ := os.Getwd() + + cfg := DefaultLoggerConfig() + cfg.Path = tmpDir + cfg.Stdout = false + + logger, err := NewLoggerWithConfig(cfg) + if err != nil { + t.Fatalf("NewLoggerWithConfig() error = %v", err) + } + defer logger.Close() + + if logger.LogPath() != tmpDir { + t.Errorf("LogPath() = %v, want %v", logger.LogPath(), tmpDir) + } +} + +func TestNewLoggerWithConfig_NilConfig(t *testing.T) { + // Should create logger with default config + logger, err := NewLoggerWithConfig(nil) + if err != nil { + t.Fatalf("NewLoggerWithConfig(nil) error = %v", err) + } + defer logger.Close() + + // Clean up default directory + defer os.RemoveAll(logger.LogPath()) +} + +func TestNewLoggerWithConfig_InvalidPath(t *testing.T) { + cfg := DefaultLoggerConfig() + // Use invalid path (NUL is invalid on Windows, /dev/null/invalid on Unix) + cfg.Path = string([]byte{0}) // Null byte is invalid in paths + + _, err := NewLoggerWithConfig(cfg) + if err == nil { + t.Error("Expected error for invalid path") + } +} + +// ============ Logging Tests | 日志记录测试 ============ + +func TestLogger_AllLevels(t *testing.T) { + tmpDir := t.TempDir() + //tmpDir, _ := os.Getwd() + + cfg := &LoggerConfig{ + Path: tmpDir, + Level: LevelDebug, + Stdout: false, + FileFormat: "test.log", + QueueSize: 100, + } + + logger, err := NewLoggerWithConfig(cfg) + if err != nil { + t.Fatalf("NewLoggerWithConfig() error = %v", err) + } + + // Log at all levels + logger.Debug("debug message") + logger.Debugf("debug formatted %d", 1) + logger.Info("info message") + logger.Infof("info formatted %d", 2) + logger.Warn("warn message") + logger.Warnf("warn formatted %d", 3) + logger.Error("error message") + logger.Errorf("error formatted %d", 4) + logger.Print("print message") + logger.Printf("print formatted %d", 5) + + logger.Close() + + // Read log file and verify + content, err := os.ReadFile(filepath.Join(tmpDir, "test.log")) + if err != nil { + t.Fatalf("Failed to read log file: %v", err) + } + + lines := strings.Split(string(content), "\n") + expectedCount := 10 // 5 pairs of log calls + + actualCount := 0 + for _, line := range lines { + if strings.TrimSpace(line) != "" { + actualCount++ + } + } + + if actualCount != expectedCount { + t.Errorf("Expected %d log lines, got %d", expectedCount, actualCount) + } + + // Verify level filtering works + if !strings.Contains(string(content), "[DEBUG]") { + t.Error("Should contain DEBUG logs") + } + if !strings.Contains(string(content), "[INFO]") { + t.Error("Should contain INFO logs") + } + if !strings.Contains(string(content), "[WARN]") { + t.Error("Should contain WARN logs") + } + if !strings.Contains(string(content), "[ERROR]") { + t.Error("Should contain ERROR logs") + } +} + +func TestLogger_LevelFiltering(t *testing.T) { + tmpDir := t.TempDir() + //tmpDir, _ := os.Getwd() + + cfg := &LoggerConfig{ + Path: tmpDir, + Level: LevelWarn, // Only WARN and ERROR + Stdout: false, + FileFormat: "test.log", + QueueSize: 100, + } + + logger, err := NewLoggerWithConfig(cfg) + if err != nil { + t.Fatalf("NewLoggerWithConfig() error = %v", err) + } + + logger.Debug("should not appear") + logger.Info("should not appear") + logger.Warn("should appear") + logger.Error("should appear") + + logger.Close() + + content, err := os.ReadFile(filepath.Join(tmpDir, "test.log")) + if err != nil { + t.Fatalf("Failed to read log file: %v", err) + } + + if strings.Contains(string(content), "[DEBUG]") { + t.Error("Should NOT contain DEBUG logs") + } + if strings.Contains(string(content), "[INFO]") { + t.Error("Should NOT contain INFO logs") + } + if !strings.Contains(string(content), "[WARN]") { + t.Error("Should contain WARN logs") + } + if !strings.Contains(string(content), "[ERROR]") { + t.Error("Should contain ERROR logs") + } +} + +// ============ appendValue Tests | 值追加测试 ============ + +func TestAppendValue_AllTypes(t *testing.T) { + tests := []struct { + name string + value any + expected string + }{ + {"nil", nil, ""}, + {"string", "hello", "hello"}, + {"bytes", []byte("world"), "world"}, + {"error", errors.New("test error"), "test error"}, + {"nil error", error(nil), ""}, + {"int", 42, "42"}, + {"int8", int8(-8), "-8"}, + {"int16", int16(-16), "-16"}, + {"int32", int32(-32), "-32"}, + {"int64", int64(-64), "-64"}, + {"uint", uint(42), "42"}, + {"uint8", uint8(8), "8"}, + {"uint16", uint16(16), "16"}, + {"uint32", uint32(32), "32"}, + {"uint64", uint64(64), "64"}, + {"float32", float32(3.14), "3.14"}, + {"float64", float64(3.14159), "3.14159"}, + {"bool true", true, "true"}, + {"bool false", false, "false"}, + {"time", time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC), "2025-01-01 12:00:00"}, + {"struct", struct{ Name string }{"test"}, "{test}"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + buf := &bytes.Buffer{} + appendValue(buf, tt.value) + + if got := buf.String(); got != tt.expected { + t.Errorf("appendValue() = %v, want %v", got, tt.expected) + } + }) + } +} + +// ============ Runtime Control Tests | 运行时控制测试 ============ + +func TestLogger_SetLevel(t *testing.T) { + cfg := &LoggerConfig{ + Stdout: true, + StdoutOnly: true, + QueueSize: DefaultQueueSize, + } + logger, _ := NewLoggerWithConfig(cfg) + defer logger.Close() + + logger.SetLevel(LevelError) + + currentCfg := logger.currentCfg() + if currentCfg.Level != LevelError { + t.Errorf("Level = %v, want %v", currentCfg.Level, LevelError) + } +} + +func TestLogger_SetPrefix(t *testing.T) { + cfg := &LoggerConfig{ + Stdout: true, + StdoutOnly: true, + QueueSize: DefaultQueueSize, + } + logger, _ := NewLoggerWithConfig(cfg) + defer logger.Close() + + logger.SetPrefix("[NEW] ") + + currentCfg := logger.currentCfg() + if currentCfg.Prefix != "[NEW] " { + t.Errorf("Prefix = %v, want [NEW] ", currentCfg.Prefix) + } +} + +func TestLogger_SetStdout(t *testing.T) { + cfg := &LoggerConfig{ + Stdout: true, + StdoutOnly: true, + QueueSize: DefaultQueueSize, + } + logger, _ := NewLoggerWithConfig(cfg) + defer logger.Close() + + logger.SetStdout(false) + + currentCfg := logger.currentCfg() + if currentCfg.Stdout { + t.Error("Stdout should be false") + } +} + +func TestLogger_SetConfig(t *testing.T) { + tmpDir := t.TempDir() + //tmpDir, _ := os.Getwd() + + cfg := &LoggerConfig{ + Path: tmpDir, + Stdout: false, + FileFormat: "old.log", + QueueSize: 100, + } + + logger, _ := NewLoggerWithConfig(cfg) + defer logger.Close() + + logger.Info("old log") + + // Change config + newCfg := &LoggerConfig{ + Path: tmpDir, + Stdout: false, + FileFormat: "new.log", + Prefix: "[NEW] ", + QueueSize: 100, + } + + logger.SetConfig(newCfg) + logger.Info("new log") + + // Wait for async write + time.Sleep(100 * time.Millisecond) + logger.Flush() +} + +// ============ Concurrent Tests | 并发测试 ============ + +func TestLogger_ConcurrentWrite(t *testing.T) { + tmpDir := t.TempDir() + //tmpDir, _ := os.Getwd() + + cfg := &LoggerConfig{ + Path: tmpDir, + Level: LevelDebug, + Stdout: false, + FileFormat: "concurrent.log", + QueueSize: 1000, + } + + logger, err := NewLoggerWithConfig(cfg) + if err != nil { + t.Fatalf("NewLoggerWithConfig() error = %v", err) + } + + var wg sync.WaitGroup + goroutines := 10 + logsPerGoroutine := 100 + + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for j := 0; j < logsPerGoroutine; j++ { + logger.Infof("goroutine %d, log %d", id, j) + } + }(i) + } + + wg.Wait() + logger.Close() + + content, err := os.ReadFile(filepath.Join(tmpDir, "concurrent.log")) + if err != nil { + t.Fatalf("Failed to read log file: %v", err) + } + + lines := strings.Split(strings.TrimSpace(string(content)), "\n") + expectedLines := goroutines * logsPerGoroutine + + if len(lines) != expectedLines { + t.Errorf("Expected %d lines, got %d", expectedLines, len(lines)) + } +} + +// ============ Close Tests | 关闭测试 ============ + +func TestLogger_DoubleClose(t *testing.T) { + cfg := &LoggerConfig{ + Stdout: true, + StdoutOnly: true, + QueueSize: DefaultQueueSize, + } + logger, _ := NewLoggerWithConfig(cfg) + + // Should not panic on double close + logger.Close() + logger.Close() +} + +func TestLogger_WriteAfterClose(t *testing.T) { + cfg := &LoggerConfig{ + Stdout: true, + StdoutOnly: true, + QueueSize: DefaultQueueSize, + } + logger, _ := NewLoggerWithConfig(cfg) + logger.Close() + + // Should not panic + logger.Info("after close") + logger.Infof("after close %d", 1) +} + +// ============ DropCount Tests | 丢弃计数测试 ============ + +func TestLogger_DropCount(t *testing.T) { + cfg := &LoggerConfig{ + Stdout: true, + StdoutOnly: true, + QueueSize: 1, // Very small queue + } + + logger, _ := NewLoggerWithConfig(cfg) + + initial := logger.DropCount() + if initial != 0 { + t.Errorf("Initial DropCount = %v, want 0", initial) + } + + // Flood the logger to potentially cause drops + for i := 0; i < 100; i++ { + logger.Info("flood message") + } + + logger.Close() + + // DropCount should be accessible after close + _ = logger.DropCount() +} + +// ============ Time Cache Tests | 时间缓存测试 ============ + +func TestLogger_TimeCache(t *testing.T) { + tmpDir := t.TempDir() + //tmpDir, _ := os.Getwd() + + cfg := &LoggerConfig{ + Path: tmpDir, + Level: LevelDebug, + Stdout: false, + FileFormat: "timecache.log", + QueueSize: 100, + } + + logger, _ := NewLoggerWithConfig(cfg) + + // Write multiple logs in same second - should use cache + for i := 0; i < 10; i++ { + logger.Info("same second log") + } + + logger.Close() + + content, err := os.ReadFile(filepath.Join(tmpDir, "timecache.log")) + if err != nil { + t.Fatalf("Failed to read log file: %v", err) + } + + lines := strings.Split(strings.TrimSpace(string(content)), "\n") + if len(lines) != 10 { + t.Errorf("Expected 10 lines, got %d", len(lines)) + } +} + +// ============ File Rotation Tests | 文件轮转测试 ============ + +func TestLogger_RotateBySize(t *testing.T) { + tmpDir := t.TempDir() + //tmpDir, _ := os.Getwd() + + cfg := &LoggerConfig{ + Path: tmpDir, + Level: LevelDebug, + Stdout: false, + FileFormat: "rotate.log", + RotateSize: 500, // Very small for testing + RotateBackupLimit: 3, + QueueSize: 100, + } + + logger, _ := NewLoggerWithConfig(cfg) + + // Write enough to trigger rotation + for i := 0; i < 100; i++ { + logger.Infof("rotation test message number %d with some padding text", i) + } + + logger.Close() + + // Wait for async cleanup + time.Sleep(200 * time.Millisecond) + + // Check for rotated files + files, _ := filepath.Glob(filepath.Join(tmpDir, "*.log")) + if len(files) == 0 { + t.Error("Expected at least one log file") + } + + t.Logf("Found %d log files after rotation", len(files)) +} + +// ============ Format Tests | 格式测试 ============ + +func TestFormatFileName(t *testing.T) { + logger := &Logger{} + + tests := []struct { + format string + time time.Time + expected string + }{ + { + format: "app_{Y}-{m}-{d}.log", + time: time.Date(2025, 6, 15, 0, 0, 0, 0, time.UTC), + expected: "app_2025-06-15.log", + }, + { + format: "log_{Y}{m}{d}", + time: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), + expected: "log_20250101.log", + }, + { + format: "", + time: time.Date(2025, 12, 31, 0, 0, 0, 0, time.UTC), + expected: "SA-TOKEN-GO_2025-12-31.log", + }, + } + + for _, tt := range tests { + t.Run(tt.format, func(t *testing.T) { + cfg := LoggerConfig{FileFormat: tt.format} + got := logger.formatFileName(tt.time, cfg) + if got != tt.expected { + t.Errorf("formatFileName() = %v, want %v", got, tt.expected) + } + }) + } +} + +func TestNormalizeBaseName(t *testing.T) { + tests := []struct { + format string + expected string + }{ + {"SA-TOKEN-GO_{Y}-{m}-{d}.log", "SA-TOKEN-GO"}, + {"app_{Y}{m}{d}.log", "app"}, + {"mylog-{Y}-{m}-{d}.log", "mylog"}, + {"simple.log", "simple"}, + {"{Y}-{m}-{d}.log", DefaultBaseName}, + {"", DefaultBaseName}, + } + + for _, tt := range tests { + t.Run(tt.format, func(t *testing.T) { + got := normalizeBaseName(tt.format) + if got != tt.expected { + t.Errorf("normalizeBaseName(%q) = %v, want %v", tt.format, got, tt.expected) + } + }) + } +} + +// ============ Benchmark Tests | 性能测试 ============ + +func BenchmarkLogger_Info(b *testing.B) { + cfg := &LoggerConfig{ + Stdout: false, + StdoutOnly: true, + QueueSize: DefaultQueueSize, + } + + logger, _ := NewLoggerWithConfig(cfg) + defer logger.Close() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + logger.Info("benchmark message") + } +} + +func BenchmarkLogger_Infof(b *testing.B) { + cfg := &LoggerConfig{ + Stdout: false, + StdoutOnly: true, + QueueSize: DefaultQueueSize, + } + + logger, _ := NewLoggerWithConfig(cfg) + defer logger.Close() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + logger.Infof("benchmark message %d", i) + } +} + +func BenchmarkLogger_Concurrent(b *testing.B) { + cfg := &LoggerConfig{ + Stdout: false, + StdoutOnly: true, + QueueSize: DefaultQueueSize, + } + + logger, _ := NewLoggerWithConfig(cfg) + defer logger.Close() + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + logger.Info("concurrent benchmark") + } + }) +} + +func BenchmarkAppendValue_String(b *testing.B) { + buf := &bytes.Buffer{} + b.ResetTimer() + for i := 0; i < b.N; i++ { + buf.Reset() + appendValue(buf, "test string") + } +} + +func BenchmarkAppendValue_Int(b *testing.B) { + buf := &bytes.Buffer{} + b.ResetTimer() + for i := 0; i < b.N; i++ { + buf.Reset() + appendValue(buf, 12345678) + } +} + +func BenchmarkAppendValue_Float(b *testing.B) { + buf := &bytes.Buffer{} + b.ResetTimer() + for i := 0; i < b.N; i++ { + buf.Reset() + appendValue(buf, 3.14159265359) + } +} + +// ============ Edge Cases | 边界情况 ============ + +func TestLogger_EmptyMessage(t *testing.T) { + tmpDir := t.TempDir() + //tmpDir, _ := os.Getwd() + + cfg := &LoggerConfig{ + Path: tmpDir, + Level: LevelDebug, + Stdout: false, + FileFormat: "empty.log", + QueueSize: 100, + } + + logger, _ := NewLoggerWithConfig(cfg) + logger.Info("") + logger.Info() // No args + logger.Close() + + content, _ := os.ReadFile(filepath.Join(tmpDir, "empty.log")) + if len(content) == 0 { + t.Error("Log file should not be empty") + } +} + +func TestLogger_SpecialCharacters(t *testing.T) { + tmpDir := t.TempDir() + //tmpDir, _ := os.Getwd() + + cfg := &LoggerConfig{ + Path: tmpDir, + Level: LevelDebug, + Stdout: false, + FileFormat: "special.log", + QueueSize: 100, + } + + logger, _ := NewLoggerWithConfig(cfg) + logger.Info("hello\nworld") + logger.Info("tab\there") + logger.Info("中文日志") + logger.Info("emoji 🎉") + logger.Close() + + content, _ := os.ReadFile(filepath.Join(tmpDir, "special.log")) + if !strings.Contains(string(content), "中文日志") { + t.Error("Should contain Chinese characters") + } +} + +func TestLogger_LargeMessage(t *testing.T) { + tmpDir := t.TempDir() + //tmpDir, _ := os.Getwd() + + cfg := &LoggerConfig{ + Path: tmpDir, + Level: LevelDebug, + Stdout: false, + FileFormat: "large.log", + QueueSize: 100, + } + + logger, _ := NewLoggerWithConfig(cfg) + + // 1MB message + largeMsg := strings.Repeat("x", 1024*1024) + logger.Info(largeMsg) + logger.Close() + + info, err := os.Stat(filepath.Join(tmpDir, "large.log")) + if err != nil { + t.Fatalf("Failed to stat log file: %v", err) + } + + if info.Size() < 1024*1024 { + t.Errorf("Log file too small: %d bytes", info.Size()) + } +} + +// ============ Secure Random Tests | 安全随机数测试 ============ + +func TestSecureRandomInt(t *testing.T) { + seen := make(map[int]bool) + + for i := 0; i < 100; i++ { + n := secureRandomInt(1000000) + if n < 0 || n >= 1000000 { + t.Errorf("secureRandomInt returned out of range: %d", n) + } + seen[n] = true + } + + // Should have some variety (very unlikely to get < 50 unique in 100 tries) + if len(seen) < 50 { + t.Errorf("secureRandomInt seems not random enough: only %d unique values", len(seen)) + } +} + +// ============ Buffer Pool Tests | 缓冲池测试 ============ + +func TestBufferPool(t *testing.T) { + buf1 := getBuf() + buf1.WriteString("test") + putBuf(buf1) + + buf2 := getBuf() + if buf2.Len() != 0 { + t.Error("Buffer from pool should be reset") + } + putBuf(buf2) +} + +func BenchmarkBufferPool(b *testing.B) { + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + buf := getBuf() + buf.WriteString("test message for buffer pool") + putBuf(buf) + } + }) +} + +// ============ LogControl Interface Tests | 日志控制接口测试 ============ + +func TestLogger_ImplementsLogControl(t *testing.T) { + cfg := &LoggerConfig{ + Stdout: true, + StdoutOnly: true, + QueueSize: DefaultQueueSize, + } + + logger, _ := NewLoggerWithConfig(cfg) + defer logger.Close() + + // Test all LogControl methods + logger.SetLevel(LevelDebug) + logger.SetPrefix("[TEST] ") + logger.SetStdout(false) + logger.Flush() + + path := logger.LogPath() + if path != "" { + t.Errorf("StdoutOnly mode should have empty LogPath, got %v", path) + } + + dropCount := logger.DropCount() + if dropCount != 0 { + t.Errorf("Initial DropCount should be 0, got %v", dropCount) + } +} + +// ============ StdoutOnly Mode Tests | 仅控制台模式测试 ============ + +func TestLogger_StdoutOnlyMode(t *testing.T) { + cfg := &LoggerConfig{ + Stdout: true, + StdoutOnly: true, + Level: LevelDebug, + Prefix: "[STDOUT-ONLY] ", + QueueSize: 100, + } + + logger, err := NewLoggerWithConfig(cfg) + if err != nil { + t.Fatalf("NewLoggerWithConfig() error = %v", err) + } + defer logger.Close() + + // These should not panic and should output to stdout + logger.Debug("stdout only debug") + logger.Info("stdout only info") + logger.Warn("stdout only warn") + logger.Error("stdout only error") + + // Verify no file was created + if logger.LogPath() != "" { + t.Errorf("StdoutOnly mode should not have a log path") + } +} + +func TestLogger_StdoutOnlyWithStdoutDisabled(t *testing.T) { + // Edge case: StdoutOnly=true but Stdout=false + // prepareConfig should force Stdout=true + cfg := &LoggerConfig{ + Stdout: false, // Will be overridden + StdoutOnly: true, + QueueSize: 100, + } + + logger, err := NewLoggerWithConfig(cfg) + if err != nil { + t.Fatalf("NewLoggerWithConfig() error = %v", err) + } + defer logger.Close() + + currentCfg := logger.currentCfg() + if !currentCfg.Stdout { + t.Error("StdoutOnly mode should force Stdout=true") + } +} diff --git a/pool/ants/config.go b/pool/ants/config.go new file mode 100644 index 0000000..1d7a881 --- /dev/null +++ b/pool/ants/config.go @@ -0,0 +1,127 @@ +// @Author daixk 2025/12/22 15:56:00 +package ants + +import ( + "fmt" + "time" +) + +// RenewPoolConfig configuration for the renewal pool manager-example | 续期池配置 +type RenewPoolConfig struct { + MinSize int // Minimum pool size | 最小协程数 + MaxSize int // Maximum pool size | 最大协程数 + ScaleUpRate float64 // Scale-up threshold | 扩容阈值 + ScaleDownRate float64 // Scale-down threshold | 缩容阈值 + CheckInterval time.Duration // Auto-scale check interval | 检查间隔 + Expiry time.Duration // Idle worker expiry duration | 空闲协程过期时间 + PrintStatusInterval time.Duration // Interval for periodic status printing | 定时打印池状态的间隔 + PreAlloc bool // Whether to pre-allocate memory | 是否预分配内存 + NonBlocking bool // Whether to use non-blocking mode | 是否为非阻塞模式 +} + +// DefaultRenewPoolConfig returns default renew pool config | 返回默认续期池配置 +func DefaultRenewPoolConfig() *RenewPoolConfig { + return &RenewPoolConfig{ + MinSize: DefaultMinSize, + MaxSize: DefaultMaxSize, + ScaleUpRate: DefaultScaleUpRate, + ScaleDownRate: DefaultScaleDownRate, + CheckInterval: DefaultCheckInterval, + Expiry: DefaultExpiry, + PreAlloc: false, + NonBlocking: true, + } +} + +// Validate validates renew pool configuration | 验证续期池配置合法性 +func (c *RenewPoolConfig) Validate() error { + if c == nil { + return nil // Nil config is allowed | 允许未配置续期池 + } + + if c.MinSize <= 0 { + return fmt.Errorf("RenewPoolConfig.MinSize must be > 0") + } + if c.MaxSize < c.MinSize { + return fmt.Errorf("RenewPoolConfig.MaxSize must be >= RenewPoolConfig.MinSize") + } + + if c.ScaleUpRate <= 0 || c.ScaleUpRate > 1 { + return fmt.Errorf("RenewPoolConfig.ScaleUpRate must be between 0 and 1") + } + if c.ScaleDownRate < 0 || c.ScaleDownRate > 1 { + return fmt.Errorf("RenewPoolConfig.ScaleDownRate must be between 0 and 1") + } + + if c.CheckInterval <= 0 { + return fmt.Errorf("RenewPoolConfig.CheckInterval must be a positive duration") + } + if c.Expiry <= 0 { + return fmt.Errorf("RenewPoolConfig.Expiry must be a positive duration") + } + + return nil +} + +// Clone returns a deep copy of the renew pool config | 克隆续期池配置 +func (c *RenewPoolConfig) Clone() *RenewPoolConfig { + if c == nil { + return nil + } + copyCfg := *c + return ©Cfg +} + +// SetMinSize sets minimum pool size | 设置最小协程数 +func (c *RenewPoolConfig) SetMinSize(size int) *RenewPoolConfig { + c.MinSize = size + return c +} + +// SetMaxSize sets maximum pool size | 设置最大协程数 +func (c *RenewPoolConfig) SetMaxSize(size int) *RenewPoolConfig { + c.MaxSize = size + return c +} + +// SetScaleUpRate sets scale-up threshold | 设置扩容阈值 +func (c *RenewPoolConfig) SetScaleUpRate(up float64) *RenewPoolConfig { + c.ScaleUpRate = up + return c +} + +// SetScaleDownRate sets scale-down threshold | 设置缩容阈值 +func (c *RenewPoolConfig) SetScaleDownRate(down float64) *RenewPoolConfig { + c.ScaleDownRate = down + return c +} + +// SetCheckInterval sets auto-scaling check interval | 设置检查间隔 +func (c *RenewPoolConfig) SetCheckInterval(interval time.Duration) *RenewPoolConfig { + c.CheckInterval = interval + return c +} + +// SetExpiry sets worker expiry duration | 设置空闲协程过期时间 +func (c *RenewPoolConfig) SetExpiry(expiry time.Duration) *RenewPoolConfig { + c.Expiry = expiry + return c +} + +// SetPrintStatusInterval sets status print interval | 设置打印状态的间隔 +func (c *RenewPoolConfig) SetPrintStatusInterval(interval time.Duration) *RenewPoolConfig { + c.PrintStatusInterval = interval + return c +} + +// SetPreAlloc sets pre-allocation flag | 设置是否预分配内存 +func (c *RenewPoolConfig) SetPreAlloc(prealloc bool) *RenewPoolConfig { + c.PreAlloc = prealloc + return c +} + +// SetNonBlocking sets non-blocking mode | 设置是否非阻塞模式 +func (c *RenewPoolConfig) SetNonBlocking(nonblocking bool) *RenewPoolConfig { + c.NonBlocking = nonblocking + return c +} diff --git a/pool/ants/consts.go b/pool/ants/consts.go new file mode 100644 index 0000000..9a7e82b --- /dev/null +++ b/pool/ants/consts.go @@ -0,0 +1,14 @@ +// @Author daixk 2025/12/22 15:56:00 +package ants + +import "time" + +// Default configuration constants | 默认配置常量 +const ( + DefaultMinSize = 100 // Minimum pool size | 最小协程数 + DefaultMaxSize = 2000 // Maximum pool size | 最大协程数 + DefaultScaleUpRate = 0.8 // Scale-up threshold | 扩容阈值 + DefaultScaleDownRate = 0.3 // Scale-down threshold | 缩容阈值 + DefaultCheckInterval = time.Minute // Interval for auto-scaling checks | 检查间隔 + DefaultExpiry = 10 * time.Second // Idle worker expiry duration | 空闲协程过期时间 +) diff --git a/pool/ants/go.mod b/pool/ants/go.mod new file mode 100644 index 0000000..13040d8 --- /dev/null +++ b/pool/ants/go.mod @@ -0,0 +1,12 @@ +module github.com/click33/sa-token-go/pool/ants + +go 1.25.0 + +require github.com/panjf2000/ants/v2 v2.11.3 + +require ( + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect + github.com/stretchr/testify v1.11.1 // indirect + golang.org/x/sync v0.19.0 // indirect +) diff --git a/pool/ants/go.sum b/pool/ants/go.sum new file mode 100644 index 0000000..a11d3f6 --- /dev/null +++ b/pool/ants/go.sum @@ -0,0 +1,8 @@ +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/panjf2000/ants/v2 v2.11.3 h1:AfI0ngBoXJmYOpDh9m516vjqoUu2sLrIVgppI9TZVpg= +github.com/panjf2000/ants/v2 v2.11.3/go.mod h1:8u92CYMUc6gyvTIw8Ru7Mt7+/ESnJahz5EVtqfrilek= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/pool/ants/pool_adapter_ants.go b/pool/ants/pool_adapter_ants.go new file mode 100644 index 0000000..6a293f2 --- /dev/null +++ b/pool/ants/pool_adapter_ants.go @@ -0,0 +1,171 @@ +// @Author daixk 2025/12/12 11:55:00 +package ants + +import ( + "fmt" + "sync" + "time" + + "github.com/panjf2000/ants/v2" +) + +// RenewPoolManager manages a dynamic scaling goroutine pool for token renewal tasks | 续期任务协程池管理器 +type RenewPoolManager struct { + pool *ants.Pool // ants pool instance | ants 协程池实例 + config *RenewPoolConfig // Configuration object | 池配置对象 + mu sync.Mutex // Synchronization lock | 互斥锁 + stopCh chan struct{} // Stop signal channel | 停止信号通道 + started bool // Indicates if pool manager-example is running | 是否已启动 + closeOnce sync.Once // Ensure Stop only executes once | 确保 Stop 只执行一次 +} + +// NewRenewPoolManagerWithDefaultConfig creates manager-example with default config | 使用默认配置创建续期池管理器 +func NewRenewPoolManagerWithDefaultConfig() *RenewPoolManager { + mgr := &RenewPoolManager{ + config: DefaultRenewPoolConfig(), + stopCh: make(chan struct{}), + started: true, + } + + _ = mgr.initPool() + + // Start auto-scaling routine | 启动自动扩缩容协程 + go mgr.autoScale() + + return mgr +} + +// NewRenewPoolManagerWithConfig creates manager-example with config | 使用配置创建续期池管理器 +func NewRenewPoolManagerWithConfig(cfg *RenewPoolConfig) (*RenewPoolManager, error) { + if cfg == nil { + cfg = DefaultRenewPoolConfig() + } + if cfg.MinSize <= 0 { + cfg.MinSize = DefaultMinSize + } + if cfg.MaxSize < cfg.MinSize { + cfg.MaxSize = cfg.MinSize + } + + mgr := &RenewPoolManager{ + config: cfg, + stopCh: make(chan struct{}), + started: true, + } + + if err := mgr.initPool(); err != nil { + return nil, err + } + + // Start auto-scaling routine | 启动自动扩缩容协程 + go mgr.autoScale() + + return mgr, nil +} + +// initPool initializes the ants pool | 初始化 ants 协程池 +func (m *RenewPoolManager) initPool() error { + p, err := ants.NewPool( + m.config.MinSize, + ants.WithExpiryDuration(m.config.Expiry), + ants.WithPreAlloc(m.config.PreAlloc), + ants.WithNonblocking(m.config.NonBlocking), + ) + if err != nil { + return err + } + + m.pool = p + return nil +} + +// Submit submits a renewal task | 提交续期任务 +func (m *RenewPoolManager) Submit(task func()) error { + if !m.started { + return fmt.Errorf("renew pool not started") + } + return m.pool.Submit(task) +} + +// Stop stops the auto-scaling process | 停止自动扩缩容 +func (m *RenewPoolManager) Stop() { + m.closeOnce.Do(func() { + if !m.started { + return + } + close(m.stopCh) + m.started = false + + if m.pool != nil && !m.pool.IsClosed() { + _ = m.pool.ReleaseTimeout(3 * time.Second) + } + }) +} + +// Stats returns current pool statistics | 返回当前池状态 +func (m *RenewPoolManager) Stats() (running, capacity int, usage float64) { + m.mu.Lock() + defer m.mu.Unlock() + + running = m.pool.Running() // Active tasks | 当前运行任务数 + capacity = m.pool.Cap() // Pool capacity | 当前池容量 + if capacity > 0 { + usage = float64(running) / float64(capacity) // Usage ratio | 当前使用率 + // Cap usage at 1.0 to handle race condition between Running() and Cap() calls + // 限制使用率最大为 1.0,处理 Running() 和 Cap() 调用之间的竞态条件 + if usage > 1.0 { + usage = 1.0 + } + } + + return +} + +// autoScale automatic pool scale-up/down logic | 自动扩缩容逻辑 +func (m *RenewPoolManager) autoScale() { + ticker := time.NewTicker(m.config.CheckInterval) // Ticker for periodic usage checks | 定时器,用于定期检测使用率 + defer ticker.Stop() // Stop ticker on exit | 函数退出时停止定时器 + + for { + select { + case <-ticker.C: + m.mu.Lock() // Protect concurrent access | 加锁防止并发冲突 + + // Get current pool stats | 获取当前运行状态 + running := m.pool.Running() // Number of active goroutines | 当前正在执行的任务数 + capacity := m.pool.Cap() // Current pool capacity | 当前协程池容量 + + // Skip if capacity is 0 to avoid division by zero | 容量为0时跳过,避免除零 + if capacity <= 0 { + m.mu.Unlock() + continue + } + + usage := float64(running) / float64(capacity) // Current usage ratio | 当前使用率(运行数 ÷ 总容量) + + switch { + // Expand if usage exceeds threshold and capacity < MaxSize | 当使用率超过扩容阈值且容量小于最大值时扩容 + case usage > m.config.ScaleUpRate && capacity < m.config.MaxSize: + newCap := int(float64(capacity) * 1.5) // Increase capacity by 1.5x | 扩容为当前的 1.5 倍 + if newCap > m.config.MaxSize { // Cap to maximum size | 限制最大值 + newCap = m.config.MaxSize + } + m.pool.Tune(newCap) // Apply new pool capacity | 调整 ants 池容量 + + // Reduce if usage below threshold and capacity > MinSize | 当使用率低于缩容阈值且容量大于最小值时缩容 + case usage < m.config.ScaleDownRate && capacity > m.config.MinSize: + newCap := int(float64(capacity) * 0.7) // Reduce capacity to 70% | 缩容为当前的 70% + if newCap < m.config.MinSize { // Ensure not below MinSize | 限制最小值 + newCap = m.config.MinSize + } + m.pool.Tune(newCap) // Apply new pool capacity | 调整 ants 池容量 + } + + m.mu.Unlock() // Unlock after adjustment | 解锁 + + case <-m.stopCh: + // Stop signal received, exit loop | 收到停止信号,终止扩缩容协程 + return + } + } +} diff --git a/pool/ants/pool_adapter_ants_test.go b/pool/ants/pool_adapter_ants_test.go new file mode 100644 index 0000000..746ecfa --- /dev/null +++ b/pool/ants/pool_adapter_ants_test.go @@ -0,0 +1,795 @@ +// @Author daixk 2025/12/26 10:00:00 +package ants + +import ( + "sync" + "sync/atomic" + "testing" + "time" +) + +// ============ RenewPoolManager Tests | 续期池管理器测试 ============ + +func TestNewRenewPoolManagerWithDefaultConfig(t *testing.T) { + mgr := NewRenewPoolManagerWithDefaultConfig() + if mgr == nil { + t.Fatal("NewRenewPoolManagerWithDefaultConfig returned nil") + } + defer mgr.Stop() + + if mgr.pool == nil { + t.Error("pool should not be nil") + } + if mgr.config == nil { + t.Error("config should not be nil") + } + if !mgr.started { + t.Error("manager-example should be started") + } +} + +func TestNewRenewPoolManagerWithConfig(t *testing.T) { + tests := []struct { + name string + config *RenewPoolConfig + wantErr bool + }{ + { + name: "nil config uses default", + config: nil, + wantErr: false, + }, + { + name: "valid config", + config: &RenewPoolConfig{ + MinSize: 50, + MaxSize: 500, + ScaleUpRate: 0.7, + ScaleDownRate: 0.2, + CheckInterval: 30 * time.Second, + Expiry: 5 * time.Second, + PreAlloc: false, + NonBlocking: true, + }, + wantErr: false, + }, + { + name: "MinSize <= 0 uses default", + config: &RenewPoolConfig{ + MinSize: 0, + MaxSize: 500, + ScaleUpRate: 0.7, + ScaleDownRate: 0.2, + CheckInterval: 30 * time.Second, + Expiry: 5 * time.Second, + }, + wantErr: false, + }, + { + name: "MaxSize < MinSize adjusts", + config: &RenewPoolConfig{ + MinSize: 100, + MaxSize: 50, + ScaleUpRate: 0.7, + ScaleDownRate: 0.2, + CheckInterval: 30 * time.Second, + Expiry: 5 * time.Second, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mgr, err := NewRenewPoolManagerWithConfig(tt.config) + if (err != nil) != tt.wantErr { + t.Errorf("NewRenewPoolManagerWithConfig() error = %v, wantErr %v", err, tt.wantErr) + return + } + if mgr != nil { + defer mgr.Stop() + if mgr.pool == nil { + t.Error("pool should not be nil") + } + } + }) + } +} + +func TestRenewPoolManager_Submit(t *testing.T) { + mgr := NewRenewPoolManagerWithDefaultConfig() + defer mgr.Stop() + + var counter int32 + var wg sync.WaitGroup + + taskCount := 10 + wg.Add(taskCount) + + for i := 0; i < taskCount; i++ { + err := mgr.Submit(func() { + atomic.AddInt32(&counter, 1) + wg.Done() + }) + if err != nil { + t.Errorf("Submit() error = %v", err) + } + } + + wg.Wait() + + if atomic.LoadInt32(&counter) != int32(taskCount) { + t.Errorf("expected counter = %d, got %d", taskCount, counter) + } +} + +func TestRenewPoolManager_Submit_AfterStop(t *testing.T) { + mgr := NewRenewPoolManagerWithDefaultConfig() + mgr.Stop() + + err := mgr.Submit(func() {}) + if err == nil { + t.Error("Submit() should return error after Stop()") + } +} + +func TestRenewPoolManager_Stop(t *testing.T) { + mgr := NewRenewPoolManagerWithDefaultConfig() + + // Stop should be idempotent | Stop 应该是幂等的 + mgr.Stop() + mgr.Stop() // Should not panic | 不应该 panic + + if mgr.started { + t.Error("manager-example should not be started after Stop()") + } +} + +func TestRenewPoolManager_Stats(t *testing.T) { + mgr := NewRenewPoolManagerWithDefaultConfig() + defer mgr.Stop() + + running, capacity, usage := mgr.Stats() + + if running < 0 { + t.Errorf("running should be >= 0, got %d", running) + } + if capacity <= 0 { + t.Errorf("capacity should be > 0, got %d", capacity) + } + if usage < 0 || usage > 1 { + t.Errorf("usage should be between 0 and 1, got %f", usage) + } +} + +func TestRenewPoolManager_Stats_WithTasks(t *testing.T) { + mgr := NewRenewPoolManagerWithDefaultConfig() + defer mgr.Stop() + + // Submit some long-running tasks | 提交一些长时间运行的任务 + taskCount := 5 + doneCh := make(chan struct{}) + + for i := 0; i < taskCount; i++ { + _ = mgr.Submit(func() { + <-doneCh // Wait for signal | 等待信号 + }) + } + + // Give some time for tasks to start | 等待任务启动 + time.Sleep(50 * time.Millisecond) + + running, capacity, usage := mgr.Stats() + + if running < taskCount { + t.Errorf("expected running >= %d, got %d", taskCount, running) + } + if capacity < running { + t.Errorf("capacity should be >= running, got capacity=%d, running=%d", capacity, running) + } + if usage <= 0 { + t.Errorf("usage should be > 0 when tasks are running, got %f", usage) + } + + close(doneCh) // Release tasks | 释放任务 +} + +func TestRenewPoolManager_AutoScale_ScaleUp(t *testing.T) { + cfg := &RenewPoolConfig{ + MinSize: 10, + MaxSize: 100, + ScaleUpRate: 0.5, // Scale up when usage > 50% | 使用率超过 50% 时扩容 + ScaleDownRate: 0.1, + CheckInterval: 50 * time.Millisecond, // Fast check for testing | 快速检查用于测试 + Expiry: 5 * time.Second, + NonBlocking: false, // Blocking mode to ensure tasks queue | 阻塞模式确保任务排队 + } + + mgr, err := NewRenewPoolManagerWithConfig(cfg) + if err != nil { + t.Fatalf("NewRenewPoolManagerWithConfig() error = %v", err) + } + defer mgr.Stop() + + _, initialCap, _ := mgr.Stats() + + // Submit many long-running tasks to trigger scale-up | 提交多个长时间运行的任务触发扩容 + doneCh := make(chan struct{}) + taskCount := initialCap + 5 // More than capacity | 超过容量 + + for i := 0; i < taskCount; i++ { + go func() { + _ = mgr.Submit(func() { + <-doneCh + }) + }() + } + + // Wait for auto-scale to trigger | 等待自动扩容触发 + time.Sleep(200 * time.Millisecond) + + _, newCap, _ := mgr.Stats() + + // Capacity should have increased or stayed at max | 容量应该增加或保持最大值 + if newCap < initialCap { + t.Errorf("expected capacity to increase or stay same, initial=%d, new=%d", initialCap, newCap) + } + + close(doneCh) +} + +func TestRenewPoolManager_ConcurrentSubmit(t *testing.T) { + mgr := NewRenewPoolManagerWithDefaultConfig() + defer mgr.Stop() + + var counter int32 + var wg sync.WaitGroup + + goroutines := 100 + tasksPerGoroutine := 10 + + wg.Add(goroutines * tasksPerGoroutine) + + for i := 0; i < goroutines; i++ { + go func() { + for j := 0; j < tasksPerGoroutine; j++ { + err := mgr.Submit(func() { + atomic.AddInt32(&counter, 1) + wg.Done() + }) + if err != nil { + t.Errorf("Submit() error = %v", err) + wg.Done() + } + } + }() + } + + wg.Wait() + + expected := int32(goroutines * tasksPerGoroutine) + if atomic.LoadInt32(&counter) != expected { + t.Errorf("expected counter = %d, got %d", expected, counter) + } +} + +// ============ RenewPoolConfig Tests | 续期池配置测试 ============ + +func TestDefaultRenewPoolConfig(t *testing.T) { + cfg := DefaultRenewPoolConfig() + + if cfg.MinSize != DefaultMinSize { + t.Errorf("MinSize = %d, want %d", cfg.MinSize, DefaultMinSize) + } + if cfg.MaxSize != DefaultMaxSize { + t.Errorf("MaxSize = %d, want %d", cfg.MaxSize, DefaultMaxSize) + } + if cfg.ScaleUpRate != DefaultScaleUpRate { + t.Errorf("ScaleUpRate = %f, want %f", cfg.ScaleUpRate, DefaultScaleUpRate) + } + if cfg.ScaleDownRate != DefaultScaleDownRate { + t.Errorf("ScaleDownRate = %f, want %f", cfg.ScaleDownRate, DefaultScaleDownRate) + } + if cfg.CheckInterval != DefaultCheckInterval { + t.Errorf("CheckInterval = %v, want %v", cfg.CheckInterval, DefaultCheckInterval) + } + if cfg.Expiry != DefaultExpiry { + t.Errorf("Expiry = %v, want %v", cfg.Expiry, DefaultExpiry) + } + if cfg.PreAlloc != false { + t.Errorf("PreAlloc = %v, want false", cfg.PreAlloc) + } + if cfg.NonBlocking != true { + t.Errorf("NonBlocking = %v, want true", cfg.NonBlocking) + } +} + +func TestRenewPoolConfig_Validate(t *testing.T) { + tests := []struct { + name string + config *RenewPoolConfig + wantErr bool + }{ + { + name: "nil config is valid", + config: nil, + wantErr: false, + }, + { + name: "default config is valid", + config: DefaultRenewPoolConfig(), + wantErr: false, + }, + { + name: "MinSize <= 0 is invalid", + config: &RenewPoolConfig{ + MinSize: 0, + MaxSize: 100, + ScaleUpRate: 0.8, + ScaleDownRate: 0.3, + CheckInterval: time.Minute, + Expiry: 10 * time.Second, + }, + wantErr: true, + }, + { + name: "MaxSize < MinSize is invalid", + config: &RenewPoolConfig{ + MinSize: 100, + MaxSize: 50, + ScaleUpRate: 0.8, + ScaleDownRate: 0.3, + CheckInterval: time.Minute, + Expiry: 10 * time.Second, + }, + wantErr: true, + }, + { + name: "ScaleUpRate <= 0 is invalid", + config: &RenewPoolConfig{ + MinSize: 100, + MaxSize: 200, + ScaleUpRate: 0, + ScaleDownRate: 0.3, + CheckInterval: time.Minute, + Expiry: 10 * time.Second, + }, + wantErr: true, + }, + { + name: "ScaleUpRate > 1 is invalid", + config: &RenewPoolConfig{ + MinSize: 100, + MaxSize: 200, + ScaleUpRate: 1.5, + ScaleDownRate: 0.3, + CheckInterval: time.Minute, + Expiry: 10 * time.Second, + }, + wantErr: true, + }, + { + name: "ScaleDownRate < 0 is invalid", + config: &RenewPoolConfig{ + MinSize: 100, + MaxSize: 200, + ScaleUpRate: 0.8, + ScaleDownRate: -0.1, + CheckInterval: time.Minute, + Expiry: 10 * time.Second, + }, + wantErr: true, + }, + { + name: "ScaleDownRate > 1 is invalid", + config: &RenewPoolConfig{ + MinSize: 100, + MaxSize: 200, + ScaleUpRate: 0.8, + ScaleDownRate: 1.5, + CheckInterval: time.Minute, + Expiry: 10 * time.Second, + }, + wantErr: true, + }, + { + name: "CheckInterval <= 0 is invalid", + config: &RenewPoolConfig{ + MinSize: 100, + MaxSize: 200, + ScaleUpRate: 0.8, + ScaleDownRate: 0.3, + CheckInterval: 0, + Expiry: 10 * time.Second, + }, + wantErr: true, + }, + { + name: "Expiry <= 0 is invalid", + config: &RenewPoolConfig{ + MinSize: 100, + MaxSize: 200, + ScaleUpRate: 0.8, + ScaleDownRate: 0.3, + CheckInterval: time.Minute, + Expiry: 0, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.config.Validate() + if (err != nil) != tt.wantErr { + t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestRenewPoolConfig_Clone(t *testing.T) { + original := &RenewPoolConfig{ + MinSize: 50, + MaxSize: 500, + ScaleUpRate: 0.7, + ScaleDownRate: 0.2, + CheckInterval: 30 * time.Second, + Expiry: 5 * time.Second, + PreAlloc: true, + NonBlocking: false, + } + + cloned := original.Clone() + + if cloned == original { + t.Error("Clone() should return a different pointer") + } + if cloned.MinSize != original.MinSize { + t.Errorf("MinSize mismatch: got %d, want %d", cloned.MinSize, original.MinSize) + } + if cloned.MaxSize != original.MaxSize { + t.Errorf("MaxSize mismatch: got %d, want %d", cloned.MaxSize, original.MaxSize) + } + if cloned.ScaleUpRate != original.ScaleUpRate { + t.Errorf("ScaleUpRate mismatch: got %f, want %f", cloned.ScaleUpRate, original.ScaleUpRate) + } + if cloned.ScaleDownRate != original.ScaleDownRate { + t.Errorf("ScaleDownRate mismatch: got %f, want %f", cloned.ScaleDownRate, original.ScaleDownRate) + } + if cloned.CheckInterval != original.CheckInterval { + t.Errorf("CheckInterval mismatch: got %v, want %v", cloned.CheckInterval, original.CheckInterval) + } + if cloned.Expiry != original.Expiry { + t.Errorf("Expiry mismatch: got %v, want %v", cloned.Expiry, original.Expiry) + } + if cloned.PreAlloc != original.PreAlloc { + t.Errorf("PreAlloc mismatch: got %v, want %v", cloned.PreAlloc, original.PreAlloc) + } + if cloned.NonBlocking != original.NonBlocking { + t.Errorf("NonBlocking mismatch: got %v, want %v", cloned.NonBlocking, original.NonBlocking) + } + + // Modify clone should not affect original | 修改克隆不应影响原始 + cloned.MinSize = 999 + if original.MinSize == 999 { + t.Error("Modifying clone affected original") + } +} + +func TestRenewPoolConfig_Clone_Nil(t *testing.T) { + var cfg *RenewPoolConfig + cloned := cfg.Clone() + if cloned != nil { + t.Error("Clone() of nil should return nil") + } +} + +func TestRenewPoolConfig_Setters(t *testing.T) { + cfg := &RenewPoolConfig{} + + // Test chaining | 测试链式调用 + result := cfg. + SetMinSize(50). + SetMaxSize(500). + SetScaleUpRate(0.75). + SetScaleDownRate(0.25). + SetCheckInterval(45 * time.Second). + SetExpiry(15 * time.Second). + SetPrintStatusInterval(5 * time.Minute). + SetPreAlloc(true). + SetNonBlocking(false) + + if result != cfg { + t.Error("Setters should return the same config pointer for chaining") + } + + if cfg.MinSize != 50 { + t.Errorf("MinSize = %d, want 50", cfg.MinSize) + } + if cfg.MaxSize != 500 { + t.Errorf("MaxSize = %d, want 500", cfg.MaxSize) + } + if cfg.ScaleUpRate != 0.75 { + t.Errorf("ScaleUpRate = %f, want 0.75", cfg.ScaleUpRate) + } + if cfg.ScaleDownRate != 0.25 { + t.Errorf("ScaleDownRate = %f, want 0.25", cfg.ScaleDownRate) + } + if cfg.CheckInterval != 45*time.Second { + t.Errorf("CheckInterval = %v, want 45s", cfg.CheckInterval) + } + if cfg.Expiry != 15*time.Second { + t.Errorf("Expiry = %v, want 15s", cfg.Expiry) + } + if cfg.PrintStatusInterval != 5*time.Minute { + t.Errorf("PrintStatusInterval = %v, want 5m", cfg.PrintStatusInterval) + } + if cfg.PreAlloc != true { + t.Errorf("PreAlloc = %v, want true", cfg.PreAlloc) + } + if cfg.NonBlocking != false { + t.Errorf("NonBlocking = %v, want false", cfg.NonBlocking) + } +} + +// ============ Benchmark Tests | 基准测试 ============ + +func BenchmarkRenewPoolManager_Submit(b *testing.B) { + mgr := NewRenewPoolManagerWithDefaultConfig() + defer mgr.Stop() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = mgr.Submit(func() { + // Empty task | 空任务 + }) + } +} + +func BenchmarkRenewPoolManager_Submit_WithWork(b *testing.B) { + mgr := NewRenewPoolManagerWithDefaultConfig() + defer mgr.Stop() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = mgr.Submit(func() { + // Simulate some work | 模拟一些工作 + time.Sleep(time.Microsecond) + }) + } +} + +func BenchmarkRenewPoolManager_Stats(b *testing.B) { + mgr := NewRenewPoolManagerWithDefaultConfig() + defer mgr.Stop() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + mgr.Stats() + } +} + +func BenchmarkRenewPoolManager_ConcurrentSubmit(b *testing.B) { + mgr := NewRenewPoolManagerWithDefaultConfig() + defer mgr.Stop() + + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _ = mgr.Submit(func() { + // Empty task | 空任务 + }) + } + }) +} + +// ============ Auto-Scale Demo Test | 自动扩缩容演示测试 ============ + +// TestRenewPoolManager_AutoScale_Demo demonstrates the auto-scaling behavior with status printing +// 演示自动扩缩容行为并打印状态 +// Run with: go test -v -run TestRenewPoolManager_AutoScale_Demo -timeout 60s +func TestRenewPoolManager_AutoScale_Demo(t *testing.T) { + cfg := &RenewPoolConfig{ + MinSize: 5, // Minimum pool size | 最小池大小 + MaxSize: 50, // Maximum pool size | 最大池大小 + ScaleUpRate: 0.6, // Scale up when usage > 60% | 使用率超过 60% 时扩容 + ScaleDownRate: 0.2, // Scale down when usage < 20% | 使用率低于 20% 时缩容 + CheckInterval: 200 * time.Millisecond, // Check every 200ms | 每 200ms 检查一次 + Expiry: 2 * time.Second, // Worker expiry | Worker 过期时间 + NonBlocking: false, // Blocking mode | 阻塞模式 + } + + mgr, err := NewRenewPoolManagerWithConfig(cfg) + if err != nil { + t.Fatalf("Failed to create pool manager-example: %v", err) + } + defer mgr.Stop() + + // Status printer | 状态打印器 + stopPrinter := make(chan struct{}) + go func() { + ticker := time.NewTicker(100 * time.Millisecond) + defer ticker.Stop() + + lastCap := 0 + for { + select { + case <-ticker.C: + running, capacity, usage := mgr.Stats() + action := " STABLE " + if capacity > lastCap { + action = "⬆️ SCALE UP" + } else if capacity < lastCap { + action = "⬇️ SCALE DN" + } + lastCap = capacity + + // Print status bar | 打印状态条 + usageBar := generateUsageBar(usage, 20) + t.Logf("[%s] Cap: %3d | Running: %3d | Usage: %5.1f%% |%s|", + action, capacity, running, usage*100, usageBar) + + case <-stopPrinter: + return + } + } + }() + + t.Log("========== Phase 1: Initial State (2s) | 初始状态 ==========") + time.Sleep(2 * time.Second) + + t.Log("========== Phase 2: High Load - Triggering Scale Up | 高负载 - 触发扩容 ==========") + // Submit many long-running tasks | 提交大量长时间运行的任务 + phase2Done := make(chan struct{}) + taskCount := 30 + for i := 0; i < taskCount; i++ { + go func(id int) { + _ = mgr.Submit(func() { + <-phase2Done // Wait for signal | 等待信号 + }) + }(i) + } + time.Sleep(3 * time.Second) // Wait for scale up | 等待扩容 + + t.Log("========== Phase 3: Releasing Tasks - Observe Scale Down | 释放任务 - 观察缩容 ==========") + close(phase2Done) // Release all tasks | 释放所有任务 + time.Sleep(4 * time.Second) // Wait for scale down | 等待缩容 + + t.Log("========== Phase 4: Burst Load Again | 再次突发负载 ==========") + phase4Done := make(chan struct{}) + for i := 0; i < 40; i++ { + go func(id int) { + _ = mgr.Submit(func() { + <-phase4Done + }) + }(i) + } + time.Sleep(3 * time.Second) + + t.Log("========== Phase 5: Gradual Release | 逐步释放 ==========") + close(phase4Done) + time.Sleep(4 * time.Second) + + t.Log("========== Phase 6: Final State | 最终状态 ==========") + time.Sleep(2 * time.Second) + + close(stopPrinter) + + // Final stats | 最终统计 + running, capacity, usage := mgr.Stats() + t.Logf("Final Stats - Capacity: %d, Running: %d, Usage: %.1f%%", capacity, running, usage*100) +} + +// generateUsageBar creates a visual usage bar | 生成可视化使用率条 +func generateUsageBar(usage float64, width int) string { + filled := int(usage * float64(width)) + if filled > width { + filled = width + } + + bar := make([]byte, width) + for i := 0; i < width; i++ { + if i < filled { + bar[i] = '#' + } else { + bar[i] = '-' + } + } + return string(bar) +} + +// TestRenewPoolManager_AutoScale_StressTest stress test for auto-scaling +// 自动扩缩容压力测试 +func TestRenewPoolManager_AutoScale_StressTest(t *testing.T) { + cfg := &RenewPoolConfig{ + MinSize: 10, + MaxSize: 100, + ScaleUpRate: 0.7, + ScaleDownRate: 0.2, + CheckInterval: 100 * time.Millisecond, + Expiry: 1 * time.Second, + NonBlocking: false, + } + + mgr, err := NewRenewPoolManagerWithConfig(cfg) + if err != nil { + t.Fatalf("Failed to create pool manager-example: %v", err) + } + defer mgr.Stop() + + // Track capacity changes | 记录容量变化 + var ( + maxCapSeen = 0 + minCapSeen = 1000 + scaleUps = 0 + scaleDowns = 0 + lastCap = cfg.MinSize + ) + + stopMonitor := make(chan struct{}) + go func() { + ticker := time.NewTicker(50 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + _, capacity, _ := mgr.Stats() + if capacity > maxCapSeen { + maxCapSeen = capacity + } + if capacity < minCapSeen { + minCapSeen = capacity + } + if capacity > lastCap { + scaleUps++ + } else if capacity < lastCap { + scaleDowns++ + } + lastCap = capacity + case <-stopMonitor: + return + } + } + }() + + // Wave pattern load | 波浪模式负载 + for wave := 0; wave < 3; wave++ { + t.Logf("Wave %d: Increasing load...", wave+1) + doneCh := make(chan struct{}) + + // Increase load | 增加负载 + for i := 0; i < 50+wave*20; i++ { + go func() { + _ = mgr.Submit(func() { + <-doneCh + }) + }() + } + time.Sleep(1500 * time.Millisecond) + + t.Logf("Wave %d: Releasing load...", wave+1) + close(doneCh) + time.Sleep(2 * time.Second) + } + + close(stopMonitor) + + t.Logf("Stress Test Results:") + t.Logf(" - Max capacity seen: %d", maxCapSeen) + t.Logf(" - Min capacity seen: %d", minCapSeen) + t.Logf(" - Scale up events: %d", scaleUps) + t.Logf(" - Scale down events: %d", scaleDowns) + + // Verify scaling occurred | 验证发生了扩缩容 + if scaleUps == 0 { + t.Error("Expected at least one scale up event") + } + if scaleDowns == 0 { + t.Error("Expected at least one scale down event") + } + if maxCapSeen <= cfg.MinSize { + t.Errorf("Expected max capacity > MinSize(%d), got %d", cfg.MinSize, maxCapSeen) + } +} diff --git a/storage/memory/go.mod b/storage/memory/go.mod index ef4bdfc..2cef58b 100644 --- a/storage/memory/go.mod +++ b/storage/memory/go.mod @@ -1,7 +1,3 @@ module github.com/click33/sa-token-go/storage/memory -go 1.23.0 - -require github.com/click33/sa-token-go/core v0.1.5 - -replace github.com/click33/sa-token-go/core => ../../core +go 1.25.0 diff --git a/storage/memory/memory.go b/storage/memory/memory.go index 6fc9797..7d6b837 100644 --- a/storage/memory/memory.go +++ b/storage/memory/memory.go @@ -6,8 +6,6 @@ import ( "strings" "sync" "time" - - "github.com/click33/sa-token-go/core/adapter" ) var ( @@ -37,12 +35,12 @@ type Storage struct { } // NewStorage 创建内存存储 -func NewStorage() adapter.Storage { +func NewStorage() *Storage { return NewStorageWithCleanupInterval(time.Minute) } // NewStorageWithCleanupInterval 创建内存存储 -func NewStorageWithCleanupInterval(interval time.Duration) adapter.Storage { +func NewStorageWithCleanupInterval(interval time.Duration) *Storage { ctx, cancel := context.WithCancel(context.Background()) s := &Storage{ data: make(map[string]*item), @@ -54,7 +52,7 @@ func NewStorageWithCleanupInterval(interval time.Duration) adapter.Storage { } // Set 设置键值对 -func (s *Storage) Set(key string, value any, expiration time.Duration) error { +func (s *Storage) Set(_ context.Context, key string, value any, expiration time.Duration) error { s.mu.Lock() defer s.mu.Unlock() @@ -72,7 +70,7 @@ func (s *Storage) Set(key string, value any, expiration time.Duration) error { } // SetKeepTTL Sets value without modifying TTL | 设置键值但保持原有TTL不变 -func (s *Storage) SetKeepTTL(key string, value any) error { +func (s *Storage) SetKeepTTL(_ context.Context, key string, value any) error { now := time.Now().Unix() s.mu.Lock() @@ -97,7 +95,7 @@ func (s *Storage) SetKeepTTL(key string, value any) error { } // Get 获取值 -func (s *Storage) Get(key string) (any, error) { +func (s *Storage) Get(_ context.Context, key string) (any, error) { now := time.Now().Unix() s.mu.RLock() @@ -110,15 +108,38 @@ func (s *Storage) Get(key string) (any, error) { if item.isExpired(now) { // 异步删除过期项 - go s.Delete(key) + go s.Delete(context.Background(), key) return nil, ErrKeyExpired } return item.value, nil } +// GetAndDelete atomically gets the value and deletes the key | 原子获取并删除键 +func (s *Storage) GetAndDelete(_ context.Context, key string) (any, error) { + now := time.Now().Unix() + + s.mu.Lock() + defer s.mu.Unlock() + + item, exists := s.data[key] + if !exists { + return nil, ErrKeyNotFound + } + + if item.isExpired(now) { + delete(s.data, key) + return nil, ErrKeyExpired + } + + val := item.value + delete(s.data, key) + + return val, nil +} + // Delete 删除键 -func (s *Storage) Delete(keys ...string) error { +func (s *Storage) Delete(_ context.Context, keys ...string) error { s.mu.Lock() defer s.mu.Unlock() @@ -129,7 +150,7 @@ func (s *Storage) Delete(keys ...string) error { } // Exists 检查键是否存在 -func (s *Storage) Exists(key string) bool { +func (s *Storage) Exists(_ context.Context, key string) bool { now := time.Now().Unix() s.mu.RLock() @@ -142,7 +163,7 @@ func (s *Storage) Exists(key string) bool { if item.isExpired(now) { // 异步删除过期项 - go s.Delete(key) + go s.Delete(context.Background(), key) return false } @@ -150,7 +171,7 @@ func (s *Storage) Exists(key string) bool { } // Keys 获取匹配模式的所有键 -func (s *Storage) Keys(pattern string) ([]string, error) { +func (s *Storage) Keys(_ context.Context, pattern string) ([]string, error) { now := time.Now().Unix() s.mu.RLock() @@ -170,7 +191,7 @@ func (s *Storage) Keys(pattern string) ([]string, error) { } // Expire 设置键的过期时间 -func (s *Storage) Expire(key string, expiration time.Duration) error { +func (s *Storage) Expire(_ context.Context, key string, expiration time.Duration) error { s.mu.Lock() defer s.mu.Unlock() @@ -179,17 +200,19 @@ func (s *Storage) Expire(key string, expiration time.Duration) error { return ErrKeyNotFound } + // expiration>0 设置 TTL if expiration > 0 { item.expiration = time.Now().Add(expiration).Unix() - } else { - item.expiration = 0 // 永不过期 + return nil } + // 兼容redis的过期时间设置 + delete(s.data, key) return nil } // TTL 获取键的剩余生存时间 -func (s *Storage) TTL(key string) (time.Duration, error) { +func (s *Storage) TTL(_ context.Context, key string) (time.Duration, error) { now := time.Now().Unix() s.mu.RLock() @@ -213,7 +236,7 @@ func (s *Storage) TTL(key string) (time.Duration, error) { } // Clear 清空所有数据 -func (s *Storage) Clear() error { +func (s *Storage) Clear(_ context.Context) error { s.mu.Lock() defer s.mu.Unlock() @@ -222,7 +245,7 @@ func (s *Storage) Clear() error { } // Ping 检查存储可用性 -func (s *Storage) Ping() error { +func (s *Storage) Ping(_ context.Context) error { s.mu.RLock() defer s.mu.RUnlock() diff --git a/storage/memory/memory_test.go b/storage/memory/memory_test.go index 6eef018..485e10d 100644 --- a/storage/memory/memory_test.go +++ b/storage/memory/memory_test.go @@ -1,67 +1,668 @@ package memory import ( + "context" + "strings" "testing" "time" ) -func TestSetKeepTTL(t *testing.T) { +func TestMemoryStorage_SetAndGet(t *testing.T) { storage := NewStorage() + defer storage.Close() - // 测试场景1: 键不存在的情况 - err := storage.SetKeepTTL("non_existent_key", "value") - if err == nil { - t.Errorf("Expected error for non-existent key, got nil") + ctx := context.Background() + + t.Run("Set and Get basic value", func(t *testing.T) { + key := "test_key" + value := "test_value" + + err := storage.Set(ctx, key, value, 0) + if err != nil { + t.Fatalf("Failed to set key: %v", err) + } + + got, err := storage.Get(ctx, key) + if err != nil { + t.Fatalf("Failed to get key: %v", err) + } + + if got != value { + t.Errorf("Expected value %q, got %q", value, got) + } + }) + + t.Run("Set with expiration", func(t *testing.T) { + key := "expire_key" + value := "expire_value" + + err := storage.Set(ctx, key, value, 2*time.Second) + if err != nil { + t.Fatalf("Failed to set key: %v", err) + } + + // 立即获取应该成功 + got, err := storage.Get(ctx, key) + if err != nil { + t.Fatalf("Failed to get key: %v", err) + } + if got != value { + t.Errorf("Expected value %q, got %q", value, got) + } + + // 等待过期 + time.Sleep(3 * time.Second) + + // 过期后获取应该失败 + _, err = storage.Get(ctx, key) + if err == nil { + t.Error("Expected error for expired key, got nil") + } + }) + + t.Run("Get non-existent key", func(t *testing.T) { + _, err := storage.Get(ctx, "non_existent") + if err == nil { + t.Error("Expected error for non-existent key, got nil") + } + }) +} + +func TestMemoryStorage_SetKeepTTL(t *testing.T) { + storage := NewStorage() + defer storage.Close() + + ctx := context.Background() + + t.Run("SetKeepTTL for non-existent key", func(t *testing.T) { + err := storage.SetKeepTTL(ctx, "non_existent_key", "value") + if err == nil { + t.Error("Expected error for non-existent key, got nil") + } + }) + + t.Run("SetKeepTTL preserves TTL", func(t *testing.T) { + key := "test_key" + originalValue := "original_value" + newValue := "new_value" + ttl := 10 * time.Second + + // 设置初始值和TTL + err := storage.Set(ctx, key, originalValue, ttl) + if err != nil { + t.Fatalf("Failed to set key: %v", err) + } + + // 获取原始TTL + originalTTL, err := storage.TTL(ctx, key) + if err != nil { + t.Fatalf("Failed to get TTL: %v", err) + } + + // 等待1秒 + time.Sleep(1 * time.Second) + + // 使用SetKeepTTL更新值 + err = storage.SetKeepTTL(ctx, key, newValue) + if err != nil { + t.Fatalf("SetKeepTTL failed: %v", err) + } + + // 验证值已更新 + value, err := storage.Get(ctx, key) + if err != nil { + t.Fatalf("Failed to get value: %v", err) + } + if value != newValue { + t.Errorf("Expected value %q, got %q", newValue, value) + } + + // 验证TTL保持相对不变(允许误差) + newTTL, err := storage.TTL(ctx, key) + if err != nil { + t.Fatalf("Failed to get TTL after update: %v", err) + } + + ttlDiff := originalTTL - newTTL + if ttlDiff < 0 { + ttlDiff = -ttlDiff + } + if ttlDiff > 2*time.Second { + t.Errorf("TTL changed significantly. Original: %v, New: %v, Diff: %v", originalTTL, newTTL, ttlDiff) + } + }) +} + +func TestMemoryStorage_Delete(t *testing.T) { + storage := NewStorage() + defer storage.Close() + + ctx := context.Background() + + t.Run("Delete single key", func(t *testing.T) { + key := "delete_key" + value := "delete_value" + + storage.Set(ctx, key, value, 0) + + err := storage.Delete(ctx, key) + if err != nil { + t.Fatalf("Failed to delete key: %v", err) + } + + exists := storage.Exists(ctx, key) + if exists { + t.Error("Key should not exist after deletion") + } + }) + + t.Run("Delete multiple keys", func(t *testing.T) { + keys := []string{"key1", "key2", "key3"} + for _, key := range keys { + storage.Set(ctx, key, "value", 0) + } + + err := storage.Delete(ctx, keys...) + if err != nil { + t.Fatalf("Failed to delete keys: %v", err) + } + + for _, key := range keys { + if storage.Exists(ctx, key) { + t.Errorf("Key %s should not exist after deletion", key) + } + } + }) +} + +func TestMemoryStorage_GetAndDelete(t *testing.T) { + storage := NewStorage() + defer storage.Close() + + ctx := context.Background() + + t.Run("GetAndDelete existing key", func(t *testing.T) { + key := "getdel_key" + value := "getdel_value" + + storage.Set(ctx, key, value, 0) + + got, err := storage.GetAndDelete(ctx, key) + if err != nil { + t.Fatalf("GetAndDelete failed: %v", err) + } + + if got != value { + t.Errorf("Expected value %q, got %q", value, got) + } + + // 键应该已被删除 + if storage.Exists(ctx, key) { + t.Error("Key should not exist after GetAndDelete") + } + }) + + t.Run("GetAndDelete non-existent key", func(t *testing.T) { + _, err := storage.GetAndDelete(ctx, "non_existent") + if err == nil { + t.Error("Expected error for non-existent key, got nil") + } + }) +} + +func TestMemoryStorage_Exists(t *testing.T) { + storage := NewStorage() + defer storage.Close() + + ctx := context.Background() + + t.Run("Exists for existing key", func(t *testing.T) { + key := "exists_key" + storage.Set(ctx, key, "value", 0) + + if !storage.Exists(ctx, key) { + t.Error("Key should exist") + } + }) + + t.Run("Exists for non-existent key", func(t *testing.T) { + if storage.Exists(ctx, "non_existent") { + t.Error("Key should not exist") + } + }) + + t.Run("Exists for expired key", func(t *testing.T) { + key := "expire_exists_key" + storage.Set(ctx, key, "value", 1*time.Second) + + if !storage.Exists(ctx, key) { + t.Error("Key should exist before expiration") + } + + time.Sleep(2 * time.Second) + + if storage.Exists(ctx, key) { + t.Error("Key should not exist after expiration") + } + }) +} + +func TestMemoryStorage_Keys(t *testing.T) { + storage := NewStorage() + defer storage.Close() + + ctx := context.Background() + + // 设置测试数据 + testData := map[string]string{ + "user:1:token": "token1", + "user:2:token": "token2", + "user:1:role": "admin", + "session:abc": "data1", + "session:xyz": "data2", + "product:100": "item", + "product:200": "item", + "product:300": "item", + "expired:key": "value", + } + + for key, value := range testData { + storage.Set(ctx, key, value, 0) + } + + // 设置一个过期的键 + storage.Set(ctx, "expired:test", "value", 1*time.Second) + time.Sleep(2 * time.Second) + + t.Run("Match all keys with *", func(t *testing.T) { + keys, err := storage.Keys(ctx, "*") + if err != nil { + t.Fatalf("Failed to get keys: %v", err) + } + // 应该至少有9个键(不包括过期的) + if len(keys) < len(testData) { + t.Errorf("Expected at least %d keys, got %d", len(testData), len(keys)) + } + }) + + t.Run("Match prefix pattern user:*", func(t *testing.T) { + keys, err := storage.Keys(ctx, "user:*") + if err != nil { + t.Fatalf("Failed to get keys: %v", err) + } + if len(keys) != 3 { + t.Errorf("Expected 3 keys, got %d", len(keys)) + } + }) + + t.Run("Match pattern user:*:token", func(t *testing.T) { + keys, err := storage.Keys(ctx, "user:*:token") + if err != nil { + t.Fatalf("Failed to get keys: %v", err) + } + if len(keys) != 2 { + t.Errorf("Expected 2 keys, got %d", len(keys)) + } + }) + + t.Run("Match suffix pattern *:token", func(t *testing.T) { + keys, err := storage.Keys(ctx, "*:token") + if err != nil { + t.Fatalf("Failed to get keys: %v", err) + } + if len(keys) != 2 { + t.Errorf("Expected 2 keys, got %d", len(keys)) + } + }) + + t.Run("Match exact key", func(t *testing.T) { + keys, err := storage.Keys(ctx, "user:1:token") + if err != nil { + t.Fatalf("Failed to get keys: %v", err) + } + if len(keys) != 1 { + t.Errorf("Expected 1 key, got %d", len(keys)) + } + }) + + t.Run("Match product:* pattern", func(t *testing.T) { + keys, err := storage.Keys(ctx, "product:*") + if err != nil { + t.Fatalf("Failed to get keys: %v", err) + } + if len(keys) != 3 { + t.Errorf("Expected 3 keys, got %d", len(keys)) + } + }) +} + +func TestMemoryStorage_Expire(t *testing.T) { + storage := NewStorage() + defer storage.Close() + + ctx := context.Background() + + t.Run("Set expiration on existing key", func(t *testing.T) { + key := "expire_test" + storage.Set(ctx, key, "value", 0) + + err := storage.Expire(ctx, key, 2*time.Second) + if err != nil { + t.Fatalf("Failed to set expiration: %v", err) + } + + // 立即检查应该存在 + if !storage.Exists(ctx, key) { + t.Error("Key should exist") + } + + // 等待过期 + time.Sleep(3 * time.Second) + + // 过期后应该不存在 + if storage.Exists(ctx, key) { + t.Error("Key should not exist after expiration") + } + }) + + t.Run("Expire non-existent key", func(t *testing.T) { + err := storage.Expire(ctx, "non_existent", 1*time.Second) + if err == nil { + t.Error("Expected error for non-existent key, got nil") + } + }) + + t.Run("Expire with negative duration deletes key", func(t *testing.T) { + key := "delete_via_expire" + storage.Set(ctx, key, "value", 0) + + err := storage.Expire(ctx, key, -1*time.Second) + if err != nil { + t.Fatalf("Failed to expire key: %v", err) + } + + if storage.Exists(ctx, key) { + t.Error("Key should be deleted") + } + }) +} + +func TestMemoryStorage_TTL(t *testing.T) { + storage := NewStorage() + defer storage.Close() + + ctx := context.Background() + + t.Run("TTL for key with expiration", func(t *testing.T) { + key := "ttl_key" + storage.Set(ctx, key, "value", 10*time.Second) + + ttl, err := storage.TTL(ctx, key) + if err != nil { + t.Fatalf("Failed to get TTL: %v", err) + } + + if ttl <= 0 || ttl > 10*time.Second { + t.Errorf("Expected TTL between 0 and 10s, got %v", ttl) + } + }) + + t.Run("TTL for key without expiration", func(t *testing.T) { + key := "no_ttl_key" + storage.Set(ctx, key, "value", 0) + + ttl, err := storage.TTL(ctx, key) + if err != nil { + t.Fatalf("Failed to get TTL: %v", err) + } + + if ttl != -1*time.Second { + t.Errorf("Expected TTL -1s (no expiration), got %v", ttl) + } + }) + + t.Run("TTL for non-existent key", func(t *testing.T) { + ttl, err := storage.TTL(ctx, "non_existent") + if err == nil { + t.Error("Expected error for non-existent key, got nil") + } + if ttl != -2*time.Second { + t.Errorf("Expected TTL -2s (not found), got %v", ttl) + } + }) + + t.Run("TTL for expired key", func(t *testing.T) { + key := "expired_ttl_key" + storage.Set(ctx, key, "value", 1*time.Second) + + time.Sleep(2 * time.Second) + + ttl, err := storage.TTL(ctx, key) + if err != nil { + // 过期键可能已被清理,这是正常的 + if ttl != -2*time.Second { + t.Errorf("Expected TTL -2s for expired key, got %v", ttl) + } + } + }) +} + +func TestMemoryStorage_Clear(t *testing.T) { + storage := NewStorage() + defer storage.Close() + + ctx := context.Background() + + // 设置多个键 + storage.Set(ctx, "key1", "value1", 0) + storage.Set(ctx, "key2", "value2", 0) + storage.Set(ctx, "key3", "value3", 0) + + err := storage.Clear(ctx) + if err != nil { + t.Fatalf("Failed to clear storage: %v", err) } - // 测试场景2: 键存在且未过期的情况 - key := "test_key" - originalValue := "original_value" - newValue := "new_value" - ttl := 10 * time.Second + // 验证所有键都被删除 + if storage.Exists(ctx, "key1") || storage.Exists(ctx, "key2") || storage.Exists(ctx, "key3") { + t.Error("All keys should be deleted after Clear") + } + + keys, _ := storage.Keys(ctx, "*") + if len(keys) != 0 { + t.Errorf("Expected 0 keys after Clear, got %d", len(keys)) + } +} + +func TestMemoryStorage_Ping(t *testing.T) { + storage := NewStorage() + defer storage.Close() + + ctx := context.Background() - // 先设置一个键值对 - err = storage.Set(key, originalValue, ttl) + err := storage.Ping(ctx) if err != nil { - t.Fatalf("Failed to set key: %v", err) + t.Fatalf("Ping should succeed: %v", err) } +} + +func TestMemoryStorage_Close(t *testing.T) { + storage := NewStorage() - // 获取原始TTL - originalTTL, err := storage.TTL(key) + ctx := context.Background() + + // 关闭前应该正常工作 + err := storage.Ping(ctx) if err != nil { - t.Fatalf("Failed to get TTL: %v", err) + t.Fatalf("Ping should succeed before close: %v", err) } - // 使用SetKeepTTL更新值 - err = storage.SetKeepTTL(key, newValue) + // 关闭存储 + err = storage.Close() if err != nil { - t.Fatalf("SetKeepTTL failed: %v", err) + t.Fatalf("Failed to close storage: %v", err) } - // 验证值已更新 - value, err := storage.Get(key) + // 关闭后 Ping 应该失败 + err = storage.Ping(ctx) + if err == nil { + t.Error("Ping should fail after close") + } + + // 重复关闭应该不报错 + err = storage.Close() if err != nil { - t.Fatalf("Failed to get value: %v", err) + t.Errorf("Second close should not return error: %v", err) + } +} + +func TestMemoryStorage_Cleanup(t *testing.T) { + // 使用较短的清理间隔创建存储 + storage := NewStorageWithCleanupInterval(500 * time.Millisecond) + defer storage.Close() + + ctx := context.Background() + + // 设置多个短期过期的键 + for i := 0; i < 10; i++ { + key := "cleanup_key_" + string(rune(i)) + storage.Set(ctx, key, "value", 1*time.Second) + } + + // 验证键存在 + keys, _ := storage.Keys(ctx, "cleanup_key_*") + if len(keys) == 0 { + t.Error("Keys should exist before expiration") + } + + // 等待过期和清理 + time.Sleep(2 * time.Second) + + // 验证过期键被清理 + keys, _ = storage.Keys(ctx, "cleanup_key_*") + if len(keys) != 0 { + t.Errorf("Expected 0 keys after cleanup, got %d", len(keys)) + } +} + +func TestMemoryStorage_ConcurrentAccess(t *testing.T) { + storage := NewStorage() + defer storage.Close() + + ctx := context.Background() + + // 并发写入 + done := make(chan bool) + for i := 0; i < 10; i++ { + go func(n int) { + key := "concurrent_key_" + string(rune(n)) + storage.Set(ctx, key, n, 0) + done <- true + }(i) + } + + // 等待所有写入完成 + for i := 0; i < 10; i++ { + <-done + } + + // 并发读取 + for i := 0; i < 10; i++ { + go func(n int) { + key := "concurrent_key_" + string(rune(n)) + storage.Get(ctx, key) + done <- true + }(i) + } + + // 等待所有读取完成 + for i := 0; i < 10; i++ { + <-done + } + + // 并发删除 + for i := 0; i < 10; i++ { + go func(n int) { + key := "concurrent_key_" + string(rune(n)) + storage.Delete(ctx, key) + done <- true + }(i) + } + + // 等待所有删除完成 + for i := 0; i < 10; i++ { + <-done + } +} +func TestMemoryStorage_ConcurrentDeviceAndTokenCountEnhanced(t *testing.T) { + storage := NewStorage() + defer storage.Close() + + ctx := context.Background() + loginId := "user1" + + // 清理历史数据 + if err := storage.Clear(ctx); err != nil { + t.Fatalf("failed to clear storage: %v", err) } - if value != newValue { - t.Errorf("Expected value %q, got %q", newValue, value) + + // 模拟同账号不同设备的登录 + keys := []string{ + // pc 设备下多个 token + "satoken:auth:" + loginId + ":pc:tokenA", + "satoken:auth:" + loginId + ":pc:tokenB", + "satoken:auth:" + loginId + ":pc:tokenC", + "satoken:auth:" + loginId + ":pc:tokenD", + "satoken:auth:" + loginId + ":pc:tokenE", + + // 其他设备 + "satoken:auth:" + loginId + ":mobile:token123", + "satoken:auth:" + loginId + ":ipad:token456", + "satoken:auth:" + loginId + ":tv:token789", } - // 验证TTL保持不变 - newTTL, err := storage.TTL(key) + for _, key := range keys { + if err := storage.Set(ctx, key, "dummy", 0); err != nil { + t.Fatalf("failed to set key %s: %v", key, err) + } + } + + // ---------- 1. 测试同账号不同设备数 ---------- + devicePattern := "satoken:auth:" + loginId + ":*:*" + allKeys, err := storage.Keys(ctx, devicePattern) if err != nil { - t.Fatalf("Failed to get TTL after update: %v", err) + t.Fatalf("failed to scan keys: %v", err) + } + + deviceSet := map[string]struct{}{} + for _, key := range allKeys { + parts := strings.Split(key, ":") + if len(parts) >= 4 { + deviceSet[parts[3]] = struct{}{} + } } - // 允许有轻微误差(不超过1秒) - ttlDiff := originalTTL - newTTL - if ttlDiff < 0 { - ttlDiff = -ttlDiff + expectedDeviceCount := 4 // pc, mobile, ipad, tv + if len(deviceSet) != expectedDeviceCount { + t.Errorf("Expected %d devices, got %d", expectedDeviceCount, len(deviceSet)) + } else { + t.Logf("Device count correct: %d", len(deviceSet)) } - if ttlDiff > time.Second { - t.Errorf("TTL changed significantly. Original: %v, New: %v", originalTTL, newTTL) + + // ---------- 2. 测试同账号同设备下 token 数 ---------- + device := "pc" + tokenPattern := "satoken:auth:" + loginId + ":" + device + ":*" + deviceKeys, err := storage.Keys(ctx, tokenPattern) + if err != nil { + t.Fatalf("failed to scan keys for device %s: %v", device, err) } - // 注意:Memory实现中,过期检查是在访问时进行的,而不是通过后台任务 - // 因此我们无法可靠地测试已过期键的情况,这里只测试键不存在的情况 + t.Logf("Token keys for device %s: %v", device, deviceKeys) + + expectedTokenCount := 5 // tokenA ~ tokenE + if len(deviceKeys) != expectedTokenCount { + t.Errorf("Expected %d tokens for device %s, got %d", expectedTokenCount, device, len(deviceKeys)) + } else { + t.Logf("Token count for device %s correct: %d", device, len(deviceKeys)) + } } diff --git a/storage/redis/go.mod b/storage/redis/go.mod index 4901b81..9282a5b 100644 --- a/storage/redis/go.mod +++ b/storage/redis/go.mod @@ -1,9 +1,8 @@ module github.com/click33/sa-token-go/storage/redis -go 1.23.0 +go 1.25.0 require ( - github.com/click33/sa-token-go/core v0.1.5 github.com/redis/go-redis/v9 v9.5.1 ) @@ -11,5 +10,3 @@ require ( github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect ) - -replace github.com/click33/sa-token-go/core => ../../core diff --git a/storage/redis/redis.go b/storage/redis/redis.go index c2a6fee..0aada78 100644 --- a/storage/redis/redis.go +++ b/storage/redis/redis.go @@ -5,14 +5,12 @@ import ( "fmt" "time" - "github.com/click33/sa-token-go/core/adapter" "github.com/redis/go-redis/v9" ) // Storage Redis存储实现 type Storage struct { client *redis.Client - ctx context.Context opTimeout time.Duration } @@ -33,7 +31,7 @@ type Config struct { } // NewStorage 通过Redis URL创建存储 -func NewStorage(url string) (adapter.Storage, error) { +func NewStorage(url string) (*Storage, error) { opts, err := redis.ParseURL(url) if err != nil { return nil, fmt.Errorf("failed to parse redis url: %w", err) @@ -49,13 +47,12 @@ func NewStorage(url string) (adapter.Storage, error) { return &Storage{ client: client, - ctx: ctx, opTimeout: 3 * time.Second, }, nil } // NewStorageFromConfig 通过配置创建存储 -func NewStorageFromConfig(cfg *Config) (adapter.Storage, error) { +func NewStorageFromConfig(cfg *Config) (*Storage, error) { client := redis.NewClient(&redis.Options{ Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port), Password: cfg.Password, @@ -80,16 +77,14 @@ func NewStorageFromConfig(cfg *Config) (adapter.Storage, error) { return &Storage{ client: client, - ctx: ctx, opTimeout: opTimeout, }, nil } // NewStorageFromClient 从已有的Redis客户端创建存储 -func NewStorageFromClient(client *redis.Client) adapter.Storage { +func NewStorageFromClient(client *redis.Client) *Storage { return &Storage{ client: client, - ctx: context.Background(), opTimeout: 3 * time.Second, } } @@ -100,15 +95,15 @@ func (s *Storage) getKey(key string) string { } // Set 设置键值对 -func (s *Storage) Set(key string, value any, expiration time.Duration) error { - ctx, cancel := s.withTimeout() +func (s *Storage) Set(ctx context.Context, key string, value any, expiration time.Duration) error { + ctx, cancel := s.withTimeoutCtx(ctx) defer cancel() return s.client.Set(ctx, s.getKey(key), value, expiration).Err() } // SetKeepTTL Sets value without modifying TTL | 设置键值但保持原有TTL不变 -func (s *Storage) SetKeepTTL(key string, value any) error { - ctx, cancel := s.withTimeout() +func (s *Storage) SetKeepTTL(ctx context.Context, key string, value any) error { + ctx, cancel := s.withTimeoutCtx(ctx) defer cancel() // 先检查键是否存在,不存在则返回错误(与Memory实现保持一致) @@ -127,8 +122,8 @@ func (s *Storage) SetKeepTTL(key string, value any) error { } // Get 获取值 -func (s *Storage) Get(key string) (any, error) { - ctx, cancel := s.withTimeout() +func (s *Storage) Get(ctx context.Context, key string) (any, error) { + ctx, cancel := s.withTimeoutCtx(ctx) defer cancel() val, err := s.client.Get(ctx, s.getKey(key)).Result() if err == redis.Nil { @@ -140,13 +135,31 @@ func (s *Storage) Get(key string) (any, error) { return val, nil } +// GetAndDelete atomically gets the value and deletes the key | 原子获取并删除键 +func (s *Storage) GetAndDelete(ctx context.Context, key string) (any, error) { + ctx, cancel := s.withTimeoutCtx(ctx) + defer cancel() + + val, err := s.client.Get(ctx, s.getKey(key)).Result() + if err == redis.Nil { + return nil, fmt.Errorf("key not found: %s", key) + } + if err != nil { + return nil, err + } + + _, _ = s.client.Del(ctx, s.getKey(key)).Result() + + return val, nil +} + // Delete 删除键 -func (s *Storage) Delete(keys ...string) error { +func (s *Storage) Delete(ctx context.Context, keys ...string) error { if len(keys) == 0 { return nil } - ctx, cancel := s.withTimeout() + ctx, cancel := s.withTimeoutCtx(ctx) defer cancel() fullKeys := make([]string, len(keys)) @@ -157,8 +170,8 @@ func (s *Storage) Delete(keys ...string) error { } // Exists 检查键是否存在 -func (s *Storage) Exists(key string) bool { - ctx, cancel := s.withTimeout() +func (s *Storage) Exists(ctx context.Context, key string) bool { + ctx, cancel := s.withTimeoutCtx(ctx) defer cancel() result, err := s.client.Exists(ctx, s.getKey(key)).Result() if err != nil { @@ -168,8 +181,8 @@ func (s *Storage) Exists(key string) bool { } // Keys 获取匹配模式的所有键 -func (s *Storage) Keys(pattern string) ([]string, error) { - ctx, cancel := s.withTimeout() +func (s *Storage) Keys(ctx context.Context, pattern string) ([]string, error) { + ctx, cancel := s.withTimeoutCtx(ctx) defer cancel() var ( @@ -194,22 +207,22 @@ func (s *Storage) Keys(pattern string) ([]string, error) { } // Expire 设置键的过期时间 -func (s *Storage) Expire(key string, expiration time.Duration) error { - ctx, cancel := s.withTimeout() +func (s *Storage) Expire(ctx context.Context, key string, expiration time.Duration) error { + ctx, cancel := s.withTimeoutCtx(ctx) defer cancel() return s.client.Expire(ctx, s.getKey(key), expiration).Err() } // TTL 获取键的剩余生存时间 -func (s *Storage) TTL(key string) (time.Duration, error) { - ctx, cancel := s.withTimeout() +func (s *Storage) TTL(ctx context.Context, key string) (time.Duration, error) { + ctx, cancel := s.withTimeoutCtx(ctx) defer cancel() return s.client.TTL(ctx, s.getKey(key)).Result() } // Clear 清空所有数据(警告:会清空整个 Redis,谨慎使用!应由 Manager 层控制) -func (s *Storage) Clear() error { - ctx, cancel := s.withTimeout() +func (s *Storage) Clear(ctx context.Context) error { + ctx, cancel := s.withTimeoutCtx(ctx) defer cancel() var cursor uint64 @@ -233,8 +246,8 @@ func (s *Storage) Clear() error { } // Ping 检查连接 -func (s *Storage) Ping() error { - ctx, cancel := s.withTimeout() +func (s *Storage) Ping(ctx context.Context) error { + ctx, cancel := s.withTimeoutCtx(ctx) defer cancel() return s.client.Ping(ctx).Err() } @@ -249,12 +262,13 @@ func (s *Storage) GetClient() *redis.Client { return s.client } -// withTimeout returns a context with the configured per-operation timeout. -func (s *Storage) withTimeout() (context.Context, context.CancelFunc) { +// withTimeoutCtx returns a context with the configured per-operation timeout. +// If the parent context has a shorter deadline, it will be respected. +func (s *Storage) withTimeoutCtx(parent context.Context) (context.Context, context.CancelFunc) { if s.opTimeout > 0 { - return context.WithTimeout(s.ctx, s.opTimeout) + return context.WithTimeout(parent, s.opTimeout) } - return context.WithCancel(s.ctx) + return context.WithCancel(parent) } // Builder Redis存储构建器 @@ -308,7 +322,7 @@ func (b *Builder) PoolSize(poolSize int) *Builder { } // Build 构建存储 -func (b *Builder) Build() (adapter.Storage, error) { +func (b *Builder) Build() (*Storage, error) { return NewStorageFromConfig(&Config{ Host: b.host, Port: b.port, diff --git a/storage/redis/redis_test.go b/storage/redis/redis_test.go index 4ca09b6..8678810 100644 --- a/storage/redis/redis_test.go +++ b/storage/redis/redis_test.go @@ -1,86 +1,834 @@ package redis import ( + "context" + "fmt" + "os" + "strings" "testing" + "time" + + "github.com/redis/go-redis/v9" ) -// 如果需要在本地运行测试,请取消下面注释并配置Redis连接信息 -/* -func TestSetKeepTTL(t *testing.T) { - // 创建Redis客户端 +// getTestRedisClient 获取测试用的Redis客户端 +// 优先使用环境变量 REDIS_URL,否则使用默认地址 +func getTestRedisClient(t *testing.T) *redis.Client { + addr := os.Getenv("REDIS_URL") + if addr == "" { + addr = "192.168.19.104:6379" + } + client := redis.NewClient(&redis.Options{ - Addr: "localhost:6379", // Redis地址 - Password: "", // 无密码 - DB: 0, // 默认DB + Addr: addr, + Password: "root", + DB: 0, // 使用独立的测试DB }) - // 创建存储实例 - storage := &Storage{ - client: client, - prefix: "test:", - timeout: 5 * time.Second, + // 测试连接 + ctx := context.Background() + if err := client.Ping(ctx).Err(); err != nil { + t.Skipf("Redis not available at %s, skipping test: %v", addr, err) } - // 测试场景1: 键不存在的情况 - err := storage.SetKeepTTL("non_existent_key", "value") - if err == nil { - t.Errorf("Expected error for non-existent key, got nil") + return client +} + +// cleanupTestData 清理测试数据 +func cleanupTestData(t *testing.T, storage *Storage) { + ctx := context.Background() + if err := storage.Clear(ctx); err != nil { + t.Logf("Warning: failed to cleanup test data: %v", err) + } +} + +func TestRedisStorage_SetAndGet(t *testing.T) { + client := getTestRedisClient(t) + storage := NewStorageFromClient(client) + defer storage.Close() + defer cleanupTestData(t, storage) + + ctx := context.Background() + + t.Run("Set and Get basic value", func(t *testing.T) { + key := "test_key" + value := "test_value" + + err := storage.Set(ctx, key, value, 0) + if err != nil { + t.Fatalf("Failed to set key: %v", err) + } + + got, err := storage.Get(ctx, key) + if err != nil { + t.Fatalf("Failed to get key: %v", err) + } + + if got != value { + t.Errorf("Expected value %q, got %q", value, got) + } + + // 清理 + storage.Delete(ctx, key) + }) + + t.Run("Set with expiration", func(t *testing.T) { + key := "expire_key" + value := "expire_value" + + err := storage.Set(ctx, key, value, 2*time.Second) + if err != nil { + t.Fatalf("Failed to set key: %v", err) + } + + // 立即获取应该成功 + got, err := storage.Get(ctx, key) + if err != nil { + t.Fatalf("Failed to get key: %v", err) + } + if got != value { + t.Errorf("Expected value %q, got %q", value, got) + } + + // 等待过期 + time.Sleep(3 * time.Second) + + // 过期后获取应该失败 + _, err = storage.Get(ctx, key) + if err == nil { + t.Error("Expected error for expired key, got nil") + } + }) + + t.Run("Get non-existent key", func(t *testing.T) { + _, err := storage.Get(ctx, "non_existent_key_12345") + if err == nil { + t.Error("Expected error for non-existent key, got nil") + } + }) +} + +func TestRedisStorage_SetKeepTTL(t *testing.T) { + client := getTestRedisClient(t) + storage := NewStorageFromClient(client) + defer storage.Close() + defer cleanupTestData(t, storage) + + ctx := context.Background() + + t.Run("SetKeepTTL for non-existent key", func(t *testing.T) { + err := storage.SetKeepTTL(ctx, "non_existent_key_999", "value") + if err == nil { + t.Error("Expected error for non-existent key, got nil") + } + }) + + t.Run("SetKeepTTL preserves TTL", func(t *testing.T) { + key := "test_key_keepttl" + originalValue := "original_value" + newValue := "new_value" + ttl := 10 * time.Second + + // 设置初始值和TTL + err := storage.Set(ctx, key, originalValue, ttl) + if err != nil { + t.Fatalf("Failed to set key: %v", err) + } + + // 获取原始TTL + originalTTL, err := storage.TTL(ctx, key) + if err != nil { + t.Fatalf("Failed to get TTL: %v", err) + } + + // 等待1秒 + time.Sleep(1 * time.Second) + + // 使用SetKeepTTL更新值 + err = storage.SetKeepTTL(ctx, key, newValue) + if err != nil { + t.Fatalf("SetKeepTTL failed: %v", err) + } + + // 验证值已更新 + value, err := storage.Get(ctx, key) + if err != nil { + t.Fatalf("Failed to get value: %v", err) + } + if value != newValue { + t.Errorf("Expected value %q, got %q", newValue, value) + } + + // 验证TTL保持相对不变(允许误差) + newTTL, err := storage.TTL(ctx, key) + if err != nil { + t.Fatalf("Failed to get TTL after update: %v", err) + } + + ttlDiff := originalTTL - newTTL + if ttlDiff < 0 { + ttlDiff = -ttlDiff + } + if ttlDiff > 2*time.Second { + t.Errorf("TTL changed significantly. Original: %v, New: %v, Diff: %v", originalTTL, newTTL, ttlDiff) + } + + // 清理 + storage.Delete(ctx, key) + }) +} + +func TestRedisStorage_Delete(t *testing.T) { + client := getTestRedisClient(t) + storage := NewStorageFromClient(client) + defer storage.Close() + defer cleanupTestData(t, storage) + + ctx := context.Background() + + t.Run("Delete single key", func(t *testing.T) { + key := "delete_key" + value := "delete_value" + + storage.Set(ctx, key, value, 0) + + err := storage.Delete(ctx, key) + if err != nil { + t.Fatalf("Failed to delete key: %v", err) + } + + exists := storage.Exists(ctx, key) + if exists { + t.Error("Key should not exist after deletion") + } + }) + + t.Run("Delete multiple keys", func(t *testing.T) { + keys := []string{"del_key1", "del_key2", "del_key3"} + for _, key := range keys { + storage.Set(ctx, key, "value", 0) + } + + err := storage.Delete(ctx, keys...) + if err != nil { + t.Fatalf("Failed to delete keys: %v", err) + } + + for _, key := range keys { + if storage.Exists(ctx, key) { + t.Errorf("Key %s should not exist after deletion", key) + } + } + }) + + t.Run("Delete empty keys", func(t *testing.T) { + err := storage.Delete(ctx) + if err != nil { + t.Errorf("Delete with no keys should not return error: %v", err) + } + }) +} + +func TestRedisStorage_GetAndDelete(t *testing.T) { + client := getTestRedisClient(t) + storage := NewStorageFromClient(client) + defer storage.Close() + defer cleanupTestData(t, storage) + + ctx := context.Background() + + t.Run("GetAndDelete existing key", func(t *testing.T) { + key := "getdel_key" + value := "getdel_value" + + storage.Set(ctx, key, value, 0) + + got, err := storage.GetAndDelete(ctx, key) + if err != nil { + t.Fatalf("GetAndDelete failed: %v", err) + } + + if got != value { + t.Errorf("Expected value %q, got %q", value, got) + } + + // 键应该已被删除 + if storage.Exists(ctx, key) { + t.Error("Key should not exist after GetAndDelete") + } + }) + + t.Run("GetAndDelete non-existent key", func(t *testing.T) { + _, err := storage.GetAndDelete(ctx, "non_existent_getdel") + if err == nil { + t.Error("Expected error for non-existent key, got nil") + } + }) +} + +func TestRedisStorage_Exists(t *testing.T) { + client := getTestRedisClient(t) + storage := NewStorageFromClient(client) + defer storage.Close() + defer cleanupTestData(t, storage) + + ctx := context.Background() + + t.Run("Exists for existing key", func(t *testing.T) { + key := "exists_key" + storage.Set(ctx, key, "value", 0) + + if !storage.Exists(ctx, key) { + t.Error("Key should exist") + } + + // 清理 + storage.Delete(ctx, key) + }) + + t.Run("Exists for non-existent key", func(t *testing.T) { + if storage.Exists(ctx, "non_existent_exists") { + t.Error("Key should not exist") + } + }) + + t.Run("Exists for expired key", func(t *testing.T) { + key := "expire_exists_key" + storage.Set(ctx, key, "value", 1*time.Second) + + if !storage.Exists(ctx, key) { + t.Error("Key should exist before expiration") + } + + time.Sleep(2 * time.Second) + + if storage.Exists(ctx, key) { + t.Error("Key should not exist after expiration") + } + }) +} + +func TestRedisStorage_Keys(t *testing.T) { + client := getTestRedisClient(t) + storage := NewStorageFromClient(client) + defer storage.Close() + defer cleanupTestData(t, storage) + + ctx := context.Background() + + // 设置测试数据 + testData := map[string]string{ + "test:user:1:token": "token1", + "test:user:2:token": "token2", + "test:user:1:role": "admin", + "test:session:abc": "data1", + "test:session:xyz": "data2", + "test:product:100": "item", + "test:product:200": "item", + "test:product:300": "item", + } + + for key, value := range testData { + storage.Set(ctx, key, value, 0) } - // 测试场景2: 键存在且未过期的情况 - key := "test_key" - originalValue := "original_value" - newValue := "new_value" - ttl := 10 * time.Second + t.Run("Match all test keys with test:*", func(t *testing.T) { + keys, err := storage.Keys(ctx, "test:*") + if err != nil { + t.Fatalf("Failed to get keys: %v", err) + } + if len(keys) != len(testData) { + t.Errorf("Expected %d keys, got %d", len(testData), len(keys)) + } + }) + + t.Run("Match prefix pattern test:user:*", func(t *testing.T) { + keys, err := storage.Keys(ctx, "test:user:*") + if err != nil { + t.Fatalf("Failed to get keys: %v", err) + } + if len(keys) != 3 { + t.Errorf("Expected 3 keys, got %d", len(keys)) + } + }) + + t.Run("Match pattern test:user:*:token", func(t *testing.T) { + keys, err := storage.Keys(ctx, "test:user:*:token") + if err != nil { + t.Fatalf("Failed to get keys: %v", err) + } + if len(keys) != 2 { + t.Errorf("Expected 2 keys, got %d", len(keys)) + } + }) + + t.Run("Match exact key", func(t *testing.T) { + keys, err := storage.Keys(ctx, "test:user:1:token") + if err != nil { + t.Fatalf("Failed to get keys: %v", err) + } + if len(keys) != 1 { + t.Errorf("Expected 1 key, got %d", len(keys)) + } + }) + + t.Run("Match product:* pattern", func(t *testing.T) { + keys, err := storage.Keys(ctx, "test:product:*") + if err != nil { + t.Fatalf("Failed to get keys: %v", err) + } + if len(keys) != 3 { + t.Errorf("Expected 3 keys, got %d", len(keys)) + } + }) +} + +func TestRedisStorage_Expire(t *testing.T) { + client := getTestRedisClient(t) + storage := NewStorageFromClient(client) + defer storage.Close() + defer cleanupTestData(t, storage) + + ctx := context.Background() + + t.Run("Set expiration on existing key", func(t *testing.T) { + key := "expire_test" + storage.Set(ctx, key, "value", 0) + + err := storage.Expire(ctx, key, 2*time.Second) + if err != nil { + t.Fatalf("Failed to set expiration: %v", err) + } + + // 立即检查应该存在 + if !storage.Exists(ctx, key) { + t.Error("Key should exist") + } + + // 等待过期 + time.Sleep(3 * time.Second) + + // 过期后应该不存在 + if storage.Exists(ctx, key) { + t.Error("Key should not exist after expiration") + } + }) - // 先设置一个键值对 - err = storage.Set(key, originalValue, ttl) + t.Run("Expire non-existent key", func(t *testing.T) { + // Redis的EXPIRE命令对不存在的键会返回0,但不会报错 + // 这里只是确保不会崩溃 + err := storage.Expire(ctx, "non_existent_expire", 1*time.Second) + if err != nil { + t.Logf("Expire on non-existent key returned error (expected): %v", err) + } + }) +} + +func TestRedisStorage_TTL(t *testing.T) { + client := getTestRedisClient(t) + storage := NewStorageFromClient(client) + defer storage.Close() + defer cleanupTestData(t, storage) + + ctx := context.Background() + + t.Run("TTL for key with expiration", func(t *testing.T) { + key := "ttl_key" + storage.Set(ctx, key, "value", 10*time.Second) + + ttl, err := storage.TTL(ctx, key) + if err != nil { + t.Fatalf("Failed to get TTL: %v", err) + } + + if ttl <= 0 || ttl > 10*time.Second { + t.Errorf("Expected TTL between 0 and 10s, got %v", ttl) + } + + // 清理 + storage.Delete(ctx, key) + }) + + t.Run("TTL for key without expiration", func(t *testing.T) { + key := "no_ttl_key" + storage.Set(ctx, key, "value", 0) + + ttl, err := storage.TTL(ctx, key) + if err != nil { + t.Fatalf("Failed to get TTL: %v", err) + } + + // Redis返回-1表示永不过期 + if ttl != -1*time.Second { + t.Errorf("Expected TTL -1s (no expiration), got %v", ttl) + } + + // 清理 + storage.Delete(ctx, key) + }) + + t.Run("TTL for non-existent key", func(t *testing.T) { + ttl, err := storage.TTL(ctx, "non_existent_ttl") + if err != nil { + t.Fatalf("Failed to get TTL: %v", err) + } + + // Redis返回-2表示键不存在 + if ttl != -2*time.Second { + t.Errorf("Expected TTL -2s (not found), got %v", ttl) + } + }) + + t.Run("TTL for expired key", func(t *testing.T) { + key := "expired_ttl_key" + storage.Set(ctx, key, "value", 1*time.Second) + + time.Sleep(2 * time.Second) + + ttl, err := storage.TTL(ctx, key) + if err != nil { + t.Fatalf("Failed to get TTL: %v", err) + } + + // 已过期的键应该返回-2 + if ttl != -2*time.Second { + t.Errorf("Expected TTL -2s for expired key, got %v", ttl) + } + }) +} + +func TestRedisStorage_Clear(t *testing.T) { + client := getTestRedisClient(t) + storage := NewStorageFromClient(client) + defer storage.Close() + + ctx := context.Background() + + // 设置多个键 + storage.Set(ctx, "clear_key1", "value1", 0) + storage.Set(ctx, "clear_key2", "value2", 0) + storage.Set(ctx, "clear_key3", "value3", 0) + + err := storage.Clear(ctx) if err != nil { - t.Fatalf("Failed to set key: %v", err) + t.Fatalf("Failed to clear storage: %v", err) + } + + // 验证所有键都被删除 + if storage.Exists(ctx, "clear_key1") || storage.Exists(ctx, "clear_key2") || storage.Exists(ctx, "clear_key3") { + t.Error("All keys should be deleted after Clear") } +} - // 获取原始TTL - originalTTL, err := storage.TTL(key) +func TestRedisStorage_Ping(t *testing.T) { + client := getTestRedisClient(t) + storage := NewStorageFromClient(client) + defer storage.Close() + + ctx := context.Background() + + err := storage.Ping(ctx) if err != nil { - t.Fatalf("Failed to get TTL: %v", err) + t.Fatalf("Ping should succeed: %v", err) } +} - // 使用SetKeepTTL更新值 - err = storage.SetKeepTTL(key, newValue) +func TestRedisStorage_Close(t *testing.T) { + client := getTestRedisClient(t) + storage := NewStorageFromClient(client) + + ctx := context.Background() + + // 关闭前应该正常工作 + err := storage.Ping(ctx) if err != nil { - t.Fatalf("SetKeepTTL failed: %v", err) + t.Fatalf("Ping should succeed before close: %v", err) } - // 验证值已更新 - value, err := storage.Get(key) + // 关闭存储 + err = storage.Close() if err != nil { - t.Fatalf("Failed to get value: %v", err) + t.Fatalf("Failed to close storage: %v", err) } - if value != newValue { - t.Errorf("Expected value %q, got %q", newValue, value) + + // 关闭后操作应该失败 + err = storage.Ping(ctx) + if err == nil { + t.Error("Ping should fail after close") } +} + +func TestRedisStorage_NewStorage(t *testing.T) { + t.Run("NewStorage with valid URL", func(t *testing.T) { + url := "redis://localhost:6379/15" + storage, err := NewStorage(url) + if err != nil { + t.Skipf("Redis not available, skipping test: %v", err) + } + defer storage.Close() - // 验证TTL保持不变 - newTTL, err := storage.TTL(key) + ctx := context.Background() + err = storage.Ping(ctx) + if err != nil { + t.Fatalf("Ping should succeed: %v", err) + } + }) + + t.Run("NewStorage with invalid URL", func(t *testing.T) { + url := "invalid://url" + _, err := NewStorage(url) + if err == nil { + t.Error("Expected error for invalid URL, got nil") + } + }) +} + +func TestRedisStorage_NewStorageFromConfig(t *testing.T) { + t.Run("NewStorageFromConfig with valid config", func(t *testing.T) { + cfg := &Config{ + Host: "localhost", + Port: 6379, + Password: "", + Database: 15, + PoolSize: 10, + OperationTimeout: 3 * time.Second, + } + + storage, err := NewStorageFromConfig(cfg) + if err != nil { + t.Skipf("Redis not available, skipping test: %v", err) + } + defer storage.Close() + + ctx := context.Background() + err = storage.Ping(ctx) + if err != nil { + t.Fatalf("Ping should succeed: %v", err) + } + }) + + t.Run("NewStorageFromConfig with invalid config", func(t *testing.T) { + cfg := &Config{ + Host: "invalid-host-12345", + Port: 9999, + Database: 0, + PoolSize: 10, + } + + _, err := NewStorageFromConfig(cfg) + if err == nil { + t.Error("Expected error for invalid config, got nil") + } + }) +} + +func TestRedisStorage_Builder(t *testing.T) { + t.Run("Builder pattern", func(t *testing.T) { + storage, err := NewBuilder(). + Host("localhost"). + Port(6379). + Database(15). + PoolSize(10). + Build() + + if err != nil { + t.Skipf("Redis not available, skipping test: %v", err) + } + defer storage.Close() + + ctx := context.Background() + err = storage.Ping(ctx) + if err != nil { + t.Fatalf("Ping should succeed: %v", err) + } + }) + + t.Run("Builder with password", func(t *testing.T) { + storage, err := NewBuilder(). + Host("localhost"). + Port(6379). + Password(""). // 测试环境通常没有密码 + Database(15). + Build() + + if err != nil { + t.Skipf("Redis not available, skipping test: %v", err) + } + defer storage.Close() + + ctx := context.Background() + err = storage.Ping(ctx) + if err != nil { + t.Fatalf("Ping should succeed: %v", err) + } + }) +} + +func TestRedisStorage_GetClient(t *testing.T) { + client := getTestRedisClient(t) + storage := NewStorageFromClient(client) + defer storage.Close() + + redisClient := storage.GetClient() + if redisClient == nil { + t.Error("GetClient should return a valid client") + } + + // 测试使用获取的客户端 + ctx := context.Background() + err := redisClient.Ping(ctx).Err() + if err != nil { + t.Fatalf("Client from GetClient should work: %v", err) + } +} + +func TestRedisStorage_ConcurrentAccess(t *testing.T) { + client := getTestRedisClient(t) + storage := NewStorageFromClient(client) + defer storage.Close() + defer cleanupTestData(t, storage) + + ctx := context.Background() + + // 并发写入 + done := make(chan bool) + for i := 0; i < 10; i++ { + go func(n int) { + key := "concurrent_key_" + string(rune(n)) + storage.Set(ctx, key, n, 0) + done <- true + }(i) + } + + // 等待所有写入完成 + for i := 0; i < 10; i++ { + <-done + } + + // 并发读取 + for i := 0; i < 10; i++ { + go func(n int) { + key := "concurrent_key_" + string(rune(n)) + storage.Get(ctx, key) + done <- true + }(i) + } + + // 等待所有读取完成 + for i := 0; i < 10; i++ { + <-done + } + + // 并发删除 + for i := 0; i < 10; i++ { + go func(n int) { + key := "concurrent_key_" + string(rune(n)) + storage.Delete(ctx, key) + done <- true + }(i) + } + + // 等待所有删除完成 + for i := 0; i < 10; i++ { + <-done + } +} + +func TestConcurrentDeviceAndTokenCountEnhanced(t *testing.T) { + client := getTestRedisClient(t) + storage := NewStorageFromClient(client) + defer storage.Close() + defer cleanupTestData(t, storage) + + ctx := context.Background() + loginId := "1" + + // 清理历史数据 + if err := storage.Clear(ctx); err != nil { + t.Fatalf("failed to clear storage: %v", err) + } + + // 模拟同账号不同设备的登录 + keys := []string{ + // pc 设备下多个 token + fmt.Sprintf("satoken:auth:%s:pc:tokenA", loginId), + fmt.Sprintf("satoken:auth:%s:pc:tokenB", loginId), + fmt.Sprintf("satoken:auth:%s:pc:tokenC", loginId), + fmt.Sprintf("satoken:auth:%s:pc:tokenD", loginId), + fmt.Sprintf("satoken:auth:%s:pc:tokenE", loginId), + + // 其他设备 + fmt.Sprintf("satoken:auth:%s:mobile:token123", loginId), + fmt.Sprintf("satoken:auth:%s:ipad:token456", loginId), + fmt.Sprintf("satoken:auth:%s:tv:token789", loginId), + } + + for _, key := range keys { + if err := storage.Set(ctx, key, "dummy", 0); err != nil { + t.Fatalf("failed to set key %s: %v", key, err) + } + } + + // ---------- 1. 测试同账号不同设备数 ---------- + devicePattern := fmt.Sprintf("satoken:auth:%s:*:*", loginId) + allKeys, err := storage.Keys(ctx, devicePattern) if err != nil { - t.Fatalf("Failed to get TTL after update: %v", err) + t.Fatalf("failed to scan keys: %v", err) } - // 允许有轻微误差(不超过1秒) - ttlDiff := originalTTL - newTTL - if ttlDiff < 0 { - ttlDiff = -ttlDiff + deviceSet := map[string]struct{}{} + for _, key := range allKeys { + parts := strings.Split(key, ":") + if len(parts) >= 4 { + deviceSet[parts[3]] = struct{}{} + } } - if ttlDiff > time.Second { - t.Errorf("TTL changed significantly. Original: %v, New: %v", originalTTL, newTTL) + + expectedDeviceCount := 4 // pc, mobile, ipad, tv + if len(deviceSet) != expectedDeviceCount { + t.Errorf("Expected %d devices, got %d", expectedDeviceCount, len(deviceSet)) + } else { + t.Logf("Device count correct: %d", len(deviceSet)) + } + + // ---------- 2. 测试同账号同设备下 token 数 ---------- + device := "pc" + tokenPattern := fmt.Sprintf("satoken:auth:%s:%s:*", loginId, device) + deviceKeys, err := storage.Keys(ctx, tokenPattern) + if err != nil { + t.Fatalf("failed to scan keys for device %s: %v", device, err) } - // 清理测试数据 - storage.Delete(key) + fmt.Println(len(deviceKeys)) + fmt.Println(deviceKeys) + + expectedTokenCount := 5 // tokenA ~ tokenE + if len(deviceKeys) != expectedTokenCount { + t.Errorf("Expected %d tokens for device %s, got %d", expectedTokenCount, device, len(deviceKeys)) + } else { + t.Logf("Token count for device %s correct: %d", device, len(deviceKeys)) + } } -*/ -// 占位测试,确保测试文件能够编译通过 -func TestDummy(t *testing.T) { - // 这是一个空测试,仅用于确保测试文件能够编译通过 +func TestDaixk(t *testing.T) { + client := getTestRedisClient(t) + storage := NewStorageFromClient(client) + defer storage.Close() + + ctx := context.Background() + //marshal, _ := json.Marshal(manager.TokenInfo{ + // AuthType: "dsfdsf", + // LoginID: "dsfsdf", + // Device: "dsfsdf", + // CreateTime: 11, + // ActiveTime: 3423423, + // Tag: "", + //}) + _ = storage.Set(ctx, "11111111", "KICK_OUT", 10000) + + tokenInfo, _ := storage.Get(ctx, "11111111") + s, ok := tokenInfo.(string) + if ok { + fmt.Println(s) + } } diff --git a/stputil/go.mod b/stputil/go.mod index 9f31aa9..d0d0635 100644 --- a/stputil/go.mod +++ b/stputil/go.mod @@ -1,14 +1,12 @@ module github.com/click33/sa-token-go/stputil -go 1.23.0 +go 1.25.0 -require github.com/click33/sa-token-go/core v0.1.5 +require github.com/click33/sa-token-go/core v0.1.7 require ( - github.com/golang-jwt/jwt/v5 v5.2.2 // indirect + github.com/golang-jwt/jwt/v5 v5.3.0 // indirect github.com/google/uuid v1.6.0 // indirect github.com/panjf2000/ants/v2 v2.11.3 // indirect - golang.org/x/sync v0.16.0 // indirect + golang.org/x/sync v0.19.0 // indirect ) - -replace github.com/click33/sa-token-go/core => ../core diff --git a/stputil/go.sum b/stputil/go.sum index dda2c2d..e5a3180 100644 --- a/stputil/go.sum +++ b/stputil/go.sum @@ -1,8 +1,9 @@ +github.com/click33/sa-token-go/core v0.1.6 h1:ELOe0qSH1b3LRsQD3DIBg0e1VgYANKFg5H7z57Lkt/8= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= -github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= +github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/panjf2000/ants/v2 v2.11.3 h1:AfI0ngBoXJmYOpDh9m516vjqoUu2sLrIVgppI9TZVpg= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= -golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/stputil/stputil.go b/stputil/stputil.go index 4a84240..efd8114 100644 --- a/stputil/stputil.go +++ b/stputil/stputil.go @@ -1,7 +1,13 @@ package stputil import ( + "context" "fmt" + "github.com/click33/sa-token-go/core" + "github.com/click33/sa-token-go/core/adapter" + "github.com/click33/sa-token-go/core/config" + "github.com/click33/sa-token-go/core/listener" + "strings" "sync" "time" @@ -11,433 +17,1460 @@ import ( "github.com/click33/sa-token-go/core/session" ) -// Global Manager instance | 全局Manager实例 var ( - globalManager *manager.Manager - once sync.Once - mu sync.RWMutex + globalManagerMap sync.Map ) -// SetManager sets the global Manager (must be called first) | 设置全局Manager(必须先调用此方法) -func SetManager(mgr *manager.Manager) { - mu.Lock() - defer mu.Unlock() - globalManager = mgr -} +// ============ Authentication | 登录认证 ============ -// GetManager gets the global Manager | 获取全局Manager -func GetManager() *manager.Manager { - mu.RLock() - defer mu.RUnlock() - if globalManager == nil { - panic("StpUtil not initialized, please call SetManager() first or use builder.NewBuilder().Build()") +// Login performs user login | 用户登录 +func Login(ctx context.Context, loginID interface{}, deviceOrAutoType ...string) (string, error) { + mgr, err := GetManager(deviceOrAutoType...) + if err != nil { + return "", err } - return globalManager -} -// CloseManager closes global Manager and releases resources | 关闭全局 Manager 并释放资源 -func CloseManager() { - mu.Lock() - defer mu.Unlock() - if globalManager != nil { - globalManager.CloseManager() - globalManager = nil // 置 nil 避免后续误用 + if id, err := toString(loginID); err != nil { + return "", err + } else { + return mgr.Login(ctx, id, deviceOrAutoType...) } } -// ============ Authentication | 登录认证 ============ - -// Login performs user login | 用户登录 -func Login(loginID interface{}, device ...string) (string, error) { - return GetManager().Login(toString(loginID), device...) -} - // LoginByToken performs login with specified token | 使用指定Token登录 -func LoginByToken(loginID interface{}, tokenValue string, device ...string) error { - return GetManager().LoginByToken(toString(loginID), tokenValue, device...) +func LoginByToken(ctx context.Context, tokenValue string, authType ...string) error { + mgr, err := GetManager(authType...) + if err != nil { + return err + } + + return mgr.LoginByToken(ctx, tokenValue) } // Logout performs user logout | 用户登出 -func Logout(loginID interface{}, device ...string) error { - return GetManager().Logout(toString(loginID), device...) +func Logout(ctx context.Context, loginID interface{}, deviceOrAutoType ...string) error { + mgr, err := GetManager(deviceOrAutoType...) + if err != nil { + return err + } + + if id, err := toString(loginID); err != nil { + return err + } else { + return mgr.Logout(ctx, id, deviceOrAutoType...) + } } // LogoutByToken performs logout by token | 根据Token登出 -func LogoutByToken(tokenValue string) error { - return GetManager().LogoutByToken(tokenValue) -} +func LogoutByToken(ctx context.Context, tokenValue string, authType ...string) error { + mgr, err := GetManager(authType...) + if err != nil { + return err + } -// IsLogin checks if the user is logged in | 检查用户是否已登录 -func IsLogin(tokenValue string) bool { - return GetManager().IsLogin(tokenValue) + return mgr.LogoutByToken(ctx, tokenValue) } -// CheckLogin checks login status (throws error if not logged in) | 检查登录状态(未登录抛出错误) -func CheckLogin(tokenValue string) error { - return GetManager().CheckLogin(tokenValue) -} +// ============ Online Status Management | 在线状态管理 ============ -// GetLoginID gets the login ID from token | 从Token获取登录ID -func GetLoginID(tokenValue string) (string, error) { - return GetManager().GetLoginID(tokenValue) -} +// Kickout kicks out a user session | 踢人下线 +func Kickout(ctx context.Context, loginID interface{}, deviceOrAutoType ...string) error { + mgr, err := GetManager(deviceOrAutoType...) + if err != nil { + return err + } -// GetLoginIDNotCheck gets login ID without checking | 获取登录ID(不检查) -func GetLoginIDNotCheck(tokenValue string) (string, error) { - return GetManager().GetLoginIDNotCheck(tokenValue) + if id, err := toString(loginID); err != nil { + return err + } else { + return mgr.Kickout(ctx, id, deviceOrAutoType...) + } } -// GetTokenValue gets the token value for a login ID | 获取登录ID对应的Token值 -func GetTokenValue(loginID interface{}, device ...string) (string, error) { - return GetManager().GetTokenValue(toString(loginID), device...) -} +// KickoutByToken Kick user offline | 根据Token踢人下线 +func KickoutByToken(ctx context.Context, tokenValue string, authType ...string) error { + mgr, err := GetManager(authType...) + if err != nil { + return err + } -// GetTokenInfo gets token information | 获取Token信息 -func GetTokenInfo(tokenValue string) (*manager.TokenInfo, error) { - return GetManager().GetTokenInfo(tokenValue) + return mgr.KickoutByToken(ctx, tokenValue) } -// ============ Kickout | 踢人下线 ============ +// Replace user offline by login ID and device | 根据账号和设备顶人下线 +func Replace(ctx context.Context, loginID interface{}, deviceOrAutoType ...string) error { + mgr, err := GetManager(deviceOrAutoType...) + if err != nil { + return err + } -// Kickout kicks out a user session | 踢人下线 -func Kickout(loginID interface{}, device ...string) error { - return GetManager().Kickout(toString(loginID), device...) + if id, err := toString(loginID); err != nil { + return err + } else { + return mgr.Replace(ctx, id, deviceOrAutoType...) + } } -// ============ Account Disable | 账号封禁 ============ +// ReplaceByToken Replace user offline by token | 根据Token顶人下线 +func ReplaceByToken(ctx context.Context, tokenValue string, authType ...string) error { + mgr, err := GetManager(authType...) + if err != nil { + return err + } -// Disable disables an account for specified duration | 封禁账号(指定时长) -func Disable(loginID interface{}, duration time.Duration) error { - return GetManager().Disable(toString(loginID), duration) + return mgr.ReplaceByToken(ctx, tokenValue) } -// Untie re-enables a disabled account | 解封账号 -func Untie(loginID interface{}) error { - return GetManager().Untie(toString(loginID)) -} +// ============ Token Validation | Token验证 ============ -// IsDisable checks if an account is disabled | 检查账号是否被封禁 -func IsDisable(loginID interface{}) bool { - return GetManager().IsDisable(toString(loginID)) -} +// IsLogin checks if the user is logged in | 检查用户是否已登录 +func IsLogin(ctx context.Context, tokenValue string, authType ...string) (bool, error) { + mgr, err := GetManager(authType...) + if err != nil { + return false, core.ErrManagerNotFound + } -// GetDisableTime gets remaining disable time in seconds | 获取剩余封禁时间(秒) -func GetDisableTime(loginID interface{}) (int64, error) { - return GetManager().GetDisableTime(toString(loginID)) + return mgr.IsLogin(ctx, tokenValue) } -// ============ Session Management | Session管理 ============ +// CheckLogin checks login status (throws error if not logged in) | 检查登录状态(未登录抛出错误) +func CheckLogin(ctx context.Context, tokenValue string, authType ...string) error { + mgr, err := GetManager(authType...) + if err != nil { + return err + } -// GetSession gets session by login ID | 根据登录ID获取Session -func GetSession(loginID interface{}) (*session.Session, error) { - return GetManager().GetSession(toString(loginID)) + return mgr.CheckLogin(ctx, tokenValue) } -// GetSessionByToken gets session by token | 根据Token获取Session -func GetSessionByToken(tokenValue string) (*session.Session, error) { - return GetManager().GetSessionByToken(tokenValue) -} +// ============ Token Information | Token信息与解析 ============ -// DeleteSession deletes a session | 删除Session -func DeleteSession(loginID interface{}) error { - return GetManager().DeleteSession(toString(loginID)) +// GetLoginID gets the login ID from token | 从Token获取登录ID +func GetLoginID(ctx context.Context, tokenValue string, authType ...string) (string, error) { + mgr, err := GetManager(authType...) + if err != nil { + return "", err + } + + return mgr.GetLoginID(ctx, tokenValue) } -// ============ Permission Verification | 权限验证 ============ +// GetLoginIDNotCheck gets login ID without checking | 获取登录ID(不检查登录状态) +func GetLoginIDNotCheck(ctx context.Context, tokenValue string, authType ...string) (string, error) { + mgr, err := GetManager(authType...) + if err != nil { + return "", err + } -// SetPermissions sets permissions for a login ID | 设置用户权限 -func SetPermissions(loginID interface{}, permissions []string) error { - return GetManager().SetPermissions(toString(loginID), permissions) + return mgr.GetLoginIDNotCheck(ctx, tokenValue) } -// GetPermissions gets permission list | 获取权限列表 -func GetPermissions(loginID interface{}) ([]string, error) { - return GetManager().GetPermissions(toString(loginID)) -} +// GetTokenValue gets the token value for a login ID | 获取登录ID对应的Token值 +func GetTokenValue(ctx context.Context, loginID interface{}, deviceOrAutoType ...string) (string, error) { + mgr, err := GetManager(deviceOrAutoType...) + if err != nil { + return "", err + } -// HasPermission checks if has specified permission | 检查是否拥有指定权限 -func HasPermission(loginID interface{}, permission string) bool { - return GetManager().HasPermission(toString(loginID), permission) + if id, err := toString(loginID); err != nil { + return "", err + } else { + return mgr.GetTokenValue(ctx, id, deviceOrAutoType...) + } } -// HasPermissionsAnd checks if has all permissions (AND logic) | 检查是否拥有所有权限(AND逻辑) -func HasPermissionsAnd(loginID interface{}, permissions []string) bool { - return GetManager().HasPermissionsAnd(toString(loginID), permissions) -} +// GetTokenInfo gets token information | 获取Token信息 +func GetTokenInfo(ctx context.Context, tokenValue string, authType ...string) (*manager.TokenInfo, error) { + mgr, err := GetManager(authType...) + if err != nil { + return nil, err + } -// HasPermissionsOr checks if has any permission (OR logic) | 检查是否拥有任一权限(OR逻辑) -func HasPermissionsOr(loginID interface{}, permissions []string) bool { - return GetManager().HasPermissionsOr(toString(loginID), permissions) + return mgr.GetTokenInfoByToken(ctx, tokenValue) } -// ============ Role Management | 角色管理 ============ +// ============ Account Disable | 账号封禁 ============ -// SetRoles sets roles for a login ID | 设置用户角色 -func SetRoles(loginID interface{}, roles []string) error { - return GetManager().SetRoles(toString(loginID), roles) -} +// Disable disables an account for specified duration | 封禁账号(指定时长) +func Disable(ctx context.Context, loginID interface{}, duration time.Duration, authType ...string) error { + mgr, err := GetManager(authType...) + if err != nil { + return err + } -// GetRoles gets role list | 获取角色列表 -func GetRoles(loginID interface{}) ([]string, error) { - return GetManager().GetRoles(toString(loginID)) + if id, err := toString(loginID); err != nil { + return err + } else { + return mgr.Disable(ctx, id, duration) + } } -// HasRole checks if has specified role | 检查是否拥有指定角色 -func HasRole(loginID interface{}, role string) bool { - return GetManager().HasRole(toString(loginID), role) -} +// DisableByToken disables the account associated with the given token for a duration | 根据指定 Token 封禁其对应的账号 +func DisableByToken(ctx context.Context, tokenValue string, duration time.Duration, authType ...string) error { + mgr, err := GetManager(authType...) + if err != nil { + return err + } -// HasRolesAnd checks if has all roles (AND logic) | 检查是否拥有所有角色(AND逻辑) -func HasRolesAnd(loginID interface{}, roles []string) bool { - return GetManager().HasRolesAnd(toString(loginID), roles) -} + loginID, err := mgr.GetLoginIDNotCheck(ctx, tokenValue) + if err != nil { + return err + } -// HasRolesOr 检查是否拥有任一角色(OR) -func HasRolesOr(loginID interface{}, roles []string) bool { - return GetManager().HasRolesOr(toString(loginID), roles) + return mgr.Disable(ctx, loginID, duration) } -// ============ Token标签 ============ +// Untie re-enables a disabled account | 解封账号 +func Untie(ctx context.Context, loginID interface{}, authType ...string) error { + mgr, err := GetManager(authType...) + if err != nil { + return err + } -// SetTokenTag 设置Token标签 -func SetTokenTag(tokenValue, tag string) error { - return GetManager().SetTokenTag(tokenValue, tag) + if id, err := toString(loginID); err != nil { + return err + } else { + return mgr.Untie(ctx, id) + } } -// GetTokenTag 获取Token标签 -func GetTokenTag(tokenValue string) (string, error) { - return GetManager().GetTokenTag(tokenValue) -} +// UntieByToken re-enables the account associated with the given token | 根据指定 Token 解封其对应的账号 +func UntieByToken(ctx context.Context, tokenValue string, authType ...string) error { + mgr, err := GetManager(authType...) + if err != nil { + return err + } -// ============ 会话查询 ============ + loginID, err := mgr.GetLoginIDNotCheck(ctx, tokenValue) + if err != nil { + return err + } -// GetTokenValueList 获取指定账号的所有Token -func GetTokenValueList(loginID interface{}) ([]string, error) { - return GetManager().GetTokenValueListByLoginID(toString(loginID)) + return mgr.Untie(ctx, loginID) } -// GetSessionCount 获取指定账号的Session数量 -func GetSessionCount(loginID interface{}) (int, error) { - return GetManager().GetSessionCountByLoginID(toString(loginID)) +// IsDisable checks if an account is disabled | 检查账号是否被封禁 +func IsDisable(ctx context.Context, loginID interface{}, authType ...string) bool { + mgr, err := GetManager(authType...) + if err != nil { + return false + } + + if id, err := toString(loginID); err != nil { + return false + } else { + return mgr.IsDisable(ctx, id) + } } -// ============ 辅助方法 ============ +// IsDisableByToken checks if an account associated with the token is disabled | 根据指定 Token 检查账号是否被封禁 +func IsDisableByToken(ctx context.Context, tokenValue string, authType ...string) bool { + mgr, err := GetManager(authType...) + if err != nil { + return false + } -// toString 将interface{}转换为string -func toString(v interface{}) string { - switch val := v.(type) { - case string: - return val - case int: - return intToString(val) - case int64: - return int64ToString(val) - case uint: - return uintToString(val) - case uint64: - return uint64ToString(val) - default: - return "" + loginID, err := mgr.GetLoginIDNotCheck(ctx, tokenValue) + if err != nil { + return false } -} -func intToString(i int) string { - return int64ToString(int64(i)) + return mgr.IsDisable(ctx, loginID) } -func int64ToString(i int64) string { - // 简单实现,可以用 strconv.FormatInt(i, 10) 但为了减少依赖 - if i == 0 { - return "0" +// GetDisableTime gets remaining disable time in seconds | 获取剩余封禁时间(秒) +func GetDisableTime(ctx context.Context, loginID interface{}, authType ...string) (int64, error) { + mgr, err := GetManager(authType...) + if err != nil { + return 0, err } - negative := i < 0 - if negative { - i = -i + if id, err := toString(loginID); err != nil { + return 0, err + } else { + return mgr.GetDisableTTL(ctx, id) } +} - var result []byte - for i > 0 { - result = append([]byte{byte('0' + i%10)}, result...) - i /= 10 +// GetDisableTimeByToken gets remaining disable time by token (seconds) | 根据 Token 获取剩余封禁时间(秒) +func GetDisableTimeByToken(ctx context.Context, tokenValue string, authType ...string) (int64, error) { + mgr, err := GetManager(authType...) + if err != nil { + return 0, err } - if negative { - result = append([]byte{'-'}, result...) + loginID, err := mgr.GetLoginIDNotCheck(ctx, tokenValue) + if err != nil { + return 0, err } - return string(result) + return mgr.GetDisableTTL(ctx, loginID) } -func uintToString(u uint) string { - return uint64ToString(uint64(u)) +// GetDisableInfo gets disable info | 获取封禁信息 +func GetDisableInfo(ctx context.Context, loginID interface{}, authType ...string) (*manager.DisableInfo, error) { + mgr, err := GetManager(authType...) + if err != nil { + return nil, err + } + + if id, err := toString(loginID); err != nil { + return nil, err + } else { + return mgr.GetDisableInfo(ctx, id) + } } -func uint64ToString(u uint64) string { - if u == 0 { - return "0" +// GetDisableInfoByToken gets disable info by token | 根据Token获取封禁信息 +func GetDisableInfoByToken(ctx context.Context, tokenValue string, authType ...string) (*manager.DisableInfo, error) { + mgr, err := GetManager(authType...) + if err != nil { + return nil, err } - var result []byte - for u > 0 { - result = append([]byte{byte('0' + u%10)}, result...) - u /= 10 + loginID, err := mgr.GetLoginIDNotCheck(ctx, tokenValue) + if err != nil { + return nil, err } - return string(result) + return mgr.GetDisableInfo(ctx, loginID) } -func GenerateNonce() (string, error) { - if globalManager == nil { - panic("Manager not initialized. Call stputil.SetManager() first") +// ============ Session Management | Session管理 ============ + +// GetSession gets session by login ID | 根据登录ID获取Session +func GetSession(ctx context.Context, loginID interface{}, authType ...string) (*session.Session, error) { + mgr, err := GetManager(authType...) + if err != nil { + return nil, err + } + + if id, err := toString(loginID); err != nil { + return nil, err + } else { + return mgr.GetSession(ctx, id) } - return globalManager.GenerateNonce() } -func VerifyNonce(nonce string) bool { - if globalManager == nil { - panic("Manager not initialized. Call stputil.SetManager() first") +// GetSessionByToken gets session by token | 根据Token获取Session +func GetSessionByToken(ctx context.Context, tokenValue string, authType ...string) (*session.Session, error) { + mgr, err := GetManager(authType...) + if err != nil { + return nil, err } - return globalManager.VerifyNonce(nonce) + + return mgr.GetSessionByToken(ctx, tokenValue) } -func LoginWithRefreshToken(loginID interface{}, device ...string) (*security.RefreshTokenInfo, error) { - if globalManager == nil { - panic("Manager not initialized. Call stputil.SetManager() first") +// DeleteSession deletes a session | 删除Session +func DeleteSession(ctx context.Context, loginID interface{}, authType ...string) error { + mgr, err := GetManager(authType...) + if err != nil { + return err } - deviceType := "default" - if len(device) > 0 { - deviceType = device[0] + + if id, err := toString(loginID); err != nil { + return err + } else { + return mgr.DeleteSession(ctx, id) } - return globalManager.LoginWithRefreshToken(fmt.Sprintf("%v", loginID), deviceType) } -func RefreshAccessToken(refreshToken string) (*security.RefreshTokenInfo, error) { - if globalManager == nil { - panic("Manager not initialized. Call stputil.SetManager() first") +// DeleteSessionByToken Deletes session by token | 根据Token删除Session +func DeleteSessionByToken(ctx context.Context, tokenValue string, authType ...string) error { + mgr, err := GetManager(authType...) + if err != nil { + return err } - return globalManager.RefreshAccessToken(refreshToken) + + return mgr.DeleteSessionByToken(ctx, tokenValue) } -func RevokeRefreshToken(refreshToken string) error { - if globalManager == nil { - panic("Manager not initialized. Call stputil.SetManager() first") +// HasSession checks if session exists | 检查Session是否存在 +func HasSession(ctx context.Context, loginID interface{}, authType ...string) bool { + mgr, err := GetManager(authType...) + if err != nil { + return false } - return globalManager.RevokeRefreshToken(refreshToken) -} -func GetOAuth2Server() *oauth2.OAuth2Server { - if globalManager == nil { - panic("Manager not initialized. Call stputil.SetManager() first") + if id, err := toString(loginID); err != nil { + return false + } else { + return mgr.HasSession(ctx, id) } - return globalManager.GetOAuth2Server() } -// ============ Check Functions for Token-based operations | 基于Token的检查函数 ============ - -// CheckDisable checks if the account associated with the token is disabled | 检查Token对应账号是否被封禁 -func CheckDisable(tokenValue string) error { - loginID, err := GetLoginID(tokenValue) +// RenewSession renews session TTL | 续期Session +func RenewSession(ctx context.Context, loginID interface{}, ttl time.Duration, authType ...string) error { + mgr, err := GetManager(authType...) if err != nil { return err } - if IsDisable(loginID) { - return fmt.Errorf("account is disabled") + + if id, err := toString(loginID); err != nil { + return err + } else { + return mgr.RenewSession(ctx, id, ttl) } - return nil } -// CheckPermission checks if the token has the specified permission | 检查Token是否拥有指定权限 -func CheckPermission(tokenValue string, permission string) error { - loginID, err := GetLoginID(tokenValue) +// RenewSessionByToken renews session TTL by token | 根据Token续期Session +func RenewSessionByToken(ctx context.Context, tokenValue string, ttl time.Duration, authType ...string) error { + mgr, err := GetManager(authType...) if err != nil { return err } - if !HasPermission(loginID, permission) { - return fmt.Errorf("permission denied: %s", permission) + + loginID, err := mgr.GetLoginIDNotCheck(ctx, tokenValue) + if err != nil { + return err } - return nil + + return mgr.RenewSession(ctx, loginID, ttl) } -// CheckPermissionAnd checks if the token has all specified permissions | 检查Token是否拥有所有指定权限 -func CheckPermissionAnd(tokenValue string, permissions []string) error { - loginID, err := GetLoginID(tokenValue) +// ============ Permission Verification | 权限验证 ============ + +// SetPermissions sets permissions for a login ID | 设置用户权限 +func SetPermissions(ctx context.Context, loginID interface{}, permissions []string, authType ...string) error { + mgr, err := GetManager(authType...) if err != nil { return err } - if !HasPermissionsAnd(loginID, permissions) { - return fmt.Errorf("permission denied: %v", permissions) + + if id, err := toString(loginID); err != nil { + return err + } else { + return mgr.SetPermissions(ctx, id, permissions) } - return nil } -// CheckPermissionOr checks if the token has any of the specified permissions | 检查Token是否拥有任一指定权限 -func CheckPermissionOr(tokenValue string, permissions []string) error { - loginID, err := GetLoginID(tokenValue) +// SetPermissionsByToken sets permissions by token | 根据 Token 设置对应账号的权限 +func SetPermissionsByToken(ctx context.Context, tokenValue string, permissions []string, authType ...string) error { + mgr, err := GetManager(authType...) if err != nil { return err } - if !HasPermissionsOr(loginID, permissions) { - return fmt.Errorf("permission denied: %v", permissions) - } - return nil -} -// GetPermissionList gets permission list for the token | 获取Token对应的权限列表 -func GetPermissionList(tokenValue string) ([]string, error) { - loginID, err := GetLoginID(tokenValue) + loginID, err := mgr.GetLoginIDNotCheck( + ctx, tokenValue, + ) if err != nil { - return nil, err + return err } - return GetPermissions(loginID) + + return mgr.SetPermissions(ctx, loginID, permissions) } -// CheckRole checks if the token has the specified role | 检查Token是否拥有指定角色 -func CheckRole(tokenValue string, role string) error { - loginID, err := GetLoginID(tokenValue) +// RemovePermissions removes specified permissions for a login ID | 删除用户指定权限 +func RemovePermissions(ctx context.Context, loginID interface{}, permissions []string, authType ...string) error { + mgr, err := GetManager(authType...) if err != nil { return err } - if !HasRole(loginID, role) { - return fmt.Errorf("role denied: %s", role) + + if id, err := toString(loginID); err != nil { + return err + } else { + return mgr.RemovePermissions(ctx, id, permissions) } - return nil } -// CheckRoleAnd checks if the token has all specified roles | 检查Token是否拥有所有指定角色 -func CheckRoleAnd(tokenValue string, roles []string) error { - loginID, err := GetLoginID(tokenValue) +// RemovePermissionsByToken removes specified permissions by token | 根据 Token 删除对应账号的指定权限 +func RemovePermissionsByToken(ctx context.Context, tokenValue string, permissions []string, authType ...string) error { + mgr, err := GetManager(authType...) if err != nil { return err } - if !HasRolesAnd(loginID, roles) { - return fmt.Errorf("role denied: %v", roles) + + loginID, err := mgr.GetLoginIDNotCheck( + ctx, tokenValue, + ) + if err != nil { + return err } - return nil + + return mgr.RemovePermissions(ctx, loginID, permissions) } -// CheckRoleOr checks if the token has any of the specified roles | 检查Token是否拥有任一指定角色 -func CheckRoleOr(tokenValue string, roles []string) error { - loginID, err := GetLoginID(tokenValue) +// GetPermissions gets permission list | 获取权限列表 +func GetPermissions(ctx context.Context, loginID interface{}, authType ...string) ([]string, error) { + mgr, err := GetManager(authType...) if err != nil { - return err + return nil, err } - if !HasRolesOr(loginID, roles) { - return fmt.Errorf("role denied: %v", roles) + + if id, err := toString(loginID); err != nil { + return nil, err + } else { + return mgr.GetPermissions(ctx, id) } - return nil } -// GetRoleList gets role list for the token | 获取Token对应的角色列表 -func GetRoleList(tokenValue string) ([]string, error) { - loginID, err := GetLoginID(tokenValue) +// GetPermissionsByToken gets permission list by token | 根据 Token 获取对应账号的权限列表 +func GetPermissionsByToken(ctx context.Context, tokenValue string, authType ...string) ([]string, error) { + mgr, err := GetManager(authType...) + if err != nil { + return nil, err + } + + loginID, err := mgr.GetLoginIDNotCheck( + ctx, tokenValue, + ) if err != nil { return nil, err } - return GetRoles(loginID) + + return mgr.GetPermissions(ctx, loginID) +} + +// HasPermission checks if has specified permission | 检查是否拥有指定权限 +func HasPermission(ctx context.Context, loginID interface{}, permissions string, authType ...string) bool { + mgr, err := GetManager(authType...) + if err != nil { + return false + } + + if id, err := toString(loginID); err != nil { + return false + } else { + return mgr.HasPermission(ctx, id, permissions) + } } -// GetTokenSession gets session for the token | 获取Token对应的Session -func GetTokenSession(tokenValue string) (*session.Session, error) { - return GetSessionByToken(tokenValue) +// HasPermissionByToken checks if the token has the specified permission | 检查Token是否拥有指定权限 +func HasPermissionByToken(ctx context.Context, tokenValue string, permission string, authType ...string) bool { + mgr, err := GetManager(authType...) + if err != nil { + return false + } + + loginID, err := mgr.GetLoginIDNotCheck(ctx, tokenValue) + if err != nil { + return false + } + + return mgr.HasPermission(ctx, loginID, permission) +} + +// HasPermissionsAnd checks if has all permissions (AND logic) | 检查是否拥有所有权限(AND逻辑) +func HasPermissionsAnd(ctx context.Context, loginID interface{}, permissions []string, authType ...string) bool { + mgr, err := GetManager(authType...) + if err != nil { + return false + } + + if id, err := toString(loginID); err != nil { + return false + } else { + return mgr.HasPermissionsAnd(ctx, id, permissions) + } +} + +// HasPermissionsAndByToken checks if the token has all specified permissions | 检查Token是否拥有所有指定权限 +func HasPermissionsAndByToken(ctx context.Context, tokenValue string, permissions []string, authType ...string) bool { + mgr, err := GetManager(authType...) + if err != nil { + return false + } + + loginID, err := mgr.GetLoginIDNotCheck(ctx, tokenValue) + if err != nil { + return false + } + + return mgr.HasPermissionsAnd(ctx, loginID, permissions) +} + +// HasPermissionsOr checks if has any permission (OR logic) | 检查是否拥有任一权限(OR逻辑) +func HasPermissionsOr(ctx context.Context, loginID interface{}, permissions []string, authType ...string) bool { + mgr, err := GetManager(authType...) + if err != nil { + return false + } + + if id, err := toString(loginID); err != nil { + return false + } else { + return mgr.HasPermissionsOr(ctx, id, permissions) + } +} + +// HasPermissionsOrByToken checks if the token has any of the specified permissions | 检查Token是否拥有任一指定权限 +func HasPermissionsOrByToken(ctx context.Context, tokenValue string, permissions []string, authType ...string) bool { + mgr, err := GetManager(authType...) + if err != nil { + return false + } + + loginID, err := mgr.GetLoginIDNotCheck(ctx, tokenValue) + if err != nil { + return false + } + + return mgr.HasPermissionsOr(ctx, loginID, permissions) +} + +// ============ Role Verification | 角色验证 ============ + +// SetRoles sets roles for a login ID | 设置用户角色 +func SetRoles(ctx context.Context, loginID interface{}, roles []string, authType ...string) error { + mgr, err := GetManager(authType...) + if err != nil { + return err + } + + if id, err := toString(loginID); err != nil { + return err + } else { + return mgr.SetRoles(ctx, id, roles) + } +} + +// SetRolesByToken sets roles by token | 根据 Token 设置对应账号的角色 +func SetRolesByToken(ctx context.Context, tokenValue string, roles []string, authType ...string) error { + mgr, err := GetManager(authType...) + if err != nil { + return err + } + + loginID, err := mgr.GetLoginIDNotCheck( + ctx, tokenValue, + ) + if err != nil { + return err + } + + return mgr.SetRoles(ctx, loginID, roles) +} + +// RemoveRoles removes specified roles for a login ID | 删除用户指定角色 +func RemoveRoles(ctx context.Context, loginID interface{}, roles []string, authType ...string) error { + mgr, err := GetManager(authType...) + if err != nil { + return err + } + + if id, err := toString(loginID); err != nil { + return err + } else { + return mgr.RemoveRoles(ctx, id, roles) + } +} + +// RemoveRolesByToken removes specified roles by token | 根据 Token 删除对应账号的指定角色 +func RemoveRolesByToken(ctx context.Context, tokenValue string, roles []string, authType ...string) error { + mgr, err := GetManager(authType...) + if err != nil { + return err + } + + loginID, err := mgr.GetLoginIDNotCheck( + ctx, tokenValue, + ) + if err != nil { + return err + } + + return mgr.RemoveRoles(ctx, loginID, roles) +} + +// GetRoles gets role list | 获取角色列表 +func GetRoles(ctx context.Context, loginID interface{}, authType ...string) ([]string, error) { + mgr, err := GetManager(authType...) + if err != nil { + return nil, err + } + + if id, err := toString(loginID); err != nil { + return nil, err + } else { + return mgr.GetRoles(ctx, id) + } +} + +// GetRolesByToken gets role list by token | 根据 Token 获取对应账号的角色列表 +func GetRolesByToken(ctx context.Context, tokenValue string, authType ...string) ([]string, error) { + mgr, err := GetManager(authType...) + if err != nil { + return nil, err + } + + loginID, err := mgr.GetLoginIDNotCheck( + ctx, tokenValue, + ) + if err != nil { + return nil, err + } + + return mgr.GetRoles(ctx, loginID) +} + +// HasRole checks if has specified role | 检查是否拥有指定角色 +func HasRole(ctx context.Context, loginID interface{}, role string, authType ...string) bool { + mgr, err := GetManager(authType...) + if err != nil { + return false + } + + if id, err := toString(loginID); err != nil { + return false + } else { + return mgr.HasRole(ctx, id, role) + } +} + +// HasRoleByToken checks if the token has the specified role | 检查 Token 是否拥有指定角色 +func HasRoleByToken(ctx context.Context, tokenValue string, role string, authType ...string) bool { + mgr, err := GetManager(authType...) + if err != nil { + return false + } + + loginID, err := mgr.GetLoginIDNotCheck( + ctx, tokenValue, + ) + if err != nil { + return false + } + + return mgr.HasRole(ctx, loginID, role) +} + +// HasRolesAnd checks if has all roles (AND logic) | 检查是否拥有所有角色(AND逻辑) +func HasRolesAnd(ctx context.Context, loginID interface{}, roles []string, authType ...string) bool { + mgr, err := GetManager(authType...) + if err != nil { + return false + } + + if id, err := toString(loginID); err != nil { + return false + } else { + return mgr.HasRolesAnd(ctx, id, roles) + } +} + +// HasRolesAndByToken checks if the token has all specified roles | 检查 Token 是否拥有所有指定角色 +func HasRolesAndByToken(ctx context.Context, tokenValue string, roles []string, authType ...string) bool { + mgr, err := GetManager(authType...) + if err != nil { + return false + } + + loginID, err := mgr.GetLoginIDNotCheck( + ctx, tokenValue, + ) + if err != nil { + return false + } + + return mgr.HasRolesAnd(ctx, loginID, roles) +} + +// HasRolesOr 检查是否拥有任一角色(OR) +func HasRolesOr(ctx context.Context, loginID interface{}, roles []string, authType ...string) bool { + mgr, err := GetManager(authType...) + if err != nil { + return false + } + + if id, err := toString(loginID); err != nil { + return false + } else { + return mgr.HasRolesOr(ctx, id, roles) + } +} + +// HasRolesOrByToken checks if the token has any of the specified roles | 检查 Token 是否拥有任一指定角色 +func HasRolesOrByToken(ctx context.Context, tokenValue string, roles []string, authType ...string) bool { + mgr, err := GetManager(authType...) + if err != nil { + return false + } + + loginID, err := mgr.GetLoginIDNotCheck( + ctx, tokenValue, + ) + if err != nil { + return false + } + + return mgr.HasRolesOr(ctx, loginID, roles) +} + +// ============ Token Tag | Token 标签 ============ + +// SetTokenTag 设置Token标签 +func SetTokenTag(ctx context.Context, tokenValue, tag string, authType ...string) error { + mgr, err := GetManager(authType...) + if err != nil { + return err + } + + return mgr.SetTokenTag(tag) +} + +// GetTokenTag 获取Token标签 +func GetTokenTag(ctx context.Context, tokenValue string, authType ...string) (string, error) { + mgr, err := GetManager(authType...) + if err != nil { + return "", err + } + + return mgr.GetTokenTag(ctx) +} + +// ============ Token & Session Info | Token 与会话信息查询 ============ + +// GetTokenValueListByLoginID 获取指定账号的所有Token +func GetTokenValueListByLoginID(ctx context.Context, loginID interface{}, authType ...string) ([]string, error) { + mgr, err := GetManager(authType...) + if err != nil { + return nil, err + } + + if id, err := toString(loginID); err != nil { + return nil, err + } else { + return mgr.GetTokenValueListByLoginID(ctx, id) + } +} + +// GetSessionCountByLoginID 获取指定账号的Session数量 +func GetSessionCountByLoginID(ctx context.Context, loginID interface{}, authType ...string) (int, error) { + mgr, err := GetManager(authType...) + if err != nil { + return 0, err + } + + if id, err := toString(loginID); err != nil { + return 0, err + } else { + return mgr.GetSessionCountByLoginID(ctx, id) + } +} + +// ============ Security Features | 安全特性 ============ + +// Generate Generates a one-time nonce | 生成一次性随机数 +func Generate(ctx context.Context, authType ...string) (string, error) { + mgr, err := GetManager(authType...) + if err != nil { + return "", err + } + + return mgr.SecurityGenerateNonce(ctx) +} + +// Verify Verifies a nonce | 验证随机数 +func Verify(ctx context.Context, nonce string, authType ...string) bool { + mgr, err := GetManager(authType...) + if err != nil { + return false + } + + return mgr.SecurityVerifyNonce(ctx, nonce) +} + +// VerifyAndConsume Verifies and consumes nonce, returns error if invalid | 验证并消费nonce,无效时返回错误 +func VerifyAndConsume(ctx context.Context, nonce string, authType ...string) error { + mgr, err := GetManager(authType...) + if err != nil { + return err + } + + return mgr.SecurityVerifyAndConsumeNonce(ctx, nonce) +} + +// IsValidNonce Checks if nonce is valid without consuming it | 检查nonce是否有效(不消费) +func IsValidNonce(ctx context.Context, nonce string, authType ...string) bool { + mgr, err := GetManager(authType...) + if err != nil { + return false + } + + return mgr.SecurityIsValidNonce(ctx, nonce) +} + +// GenerateTokenPair Create access + refresh token | 生成访问令牌和刷新令牌 +func GenerateTokenPair(ctx context.Context, loginID interface{}, deviceOrAutoType ...string) (*security.RefreshTokenInfo, error) { + mgr, err := GetManager(deviceOrAutoType...) + if err != nil { + return nil, err + } + + if id, err := toString(loginID); err != nil { + return nil, err + } else { + return mgr.SecurityGenerateTokenPair(ctx, id, mgr.GetDevice(deviceOrAutoType)) + } +} + +// VerifyAccessToken verifies access token validity | 验证访问令牌是否有效 +func VerifyAccessToken(ctx context.Context, accessToken string, authType ...string) bool { + mgr, err := GetManager(authType...) + if err != nil { + return false + } + + return mgr.SecurityVerifyAccessToken(ctx, accessToken) +} + +// VerifyAccessTokenAndGetInfo verifies access token and returns token info | 验证访问令牌并返回Token信息 +func VerifyAccessTokenAndGetInfo(ctx context.Context, accessToken string, authType ...string) (*security.AccessTokenInfo, bool) { + mgr, err := GetManager(authType...) + if err != nil { + return nil, false + } + + return mgr.SecurityVerifyAccessTokenAndGetInfo(ctx, accessToken) +} + +// GetRefreshTokenInfo gets refresh token information | 获取刷新令牌信息 +func GetRefreshTokenInfo(ctx context.Context, refreshToken string, authType ...string) (*security.RefreshTokenInfo, error) { + mgr, err := GetManager(authType...) + if err != nil { + return nil, err + } + + return mgr.SecurityGetRefreshTokenInfo(ctx, refreshToken) +} + +// RefreshAccessToken refreshes access token using refresh token | 使用刷新令牌刷新访问令牌 +func RefreshAccessToken(ctx context.Context, refreshToken string, authType ...string) (*security.RefreshTokenInfo, error) { + mgr, err := GetManager(authType...) + if err != nil { + return nil, err + } + + return mgr.SecurityRefreshAccessToken(ctx, refreshToken) +} + +// RevokeRefreshToken Revokes refresh token | 撤销刷新令牌 +func RevokeRefreshToken(ctx context.Context, refreshToken string, authType ...string) error { + mgr, err := GetManager(authType...) + if err != nil { + return err + } + + return mgr.SecurityRevokeRefreshToken(ctx, refreshToken) +} + +// IsValid checks whether token is valid | 检查Token是否有效 +func IsValid(ctx context.Context, refreshToken string, authType ...string) bool { + mgr, err := GetManager(authType...) + if err != nil { + return false + } + + return mgr.SecurityIsRefreshTokenValid(ctx, refreshToken) +} + +// ============ OAuth2 Features | Oauth2特性 ============ + +// RegisterClient Registers an OAuth2 client | 注册OAuth2客户端 +func RegisterClient(ctx context.Context, client *oauth2.Client, authType ...string) error { + mgr, err := GetManager(authType...) + if err != nil { + return err + } + + return mgr.OAuth2RegisterClient(client) +} + +// UnregisterClient unregisters an OAuth2 client | 注销OAuth2客户端 +func UnregisterClient(ctx context.Context, clientID string, authType ...string) error { + mgr, err := GetManager(authType...) + if err != nil { + return err + } + + mgr.OAuth2UnregisterClient(clientID) + + return nil +} + +// GetClient gets OAuth2 client information | 获取OAuth2客户端信息 +func GetClient(ctx context.Context, clientID string, authType ...string) (*oauth2.Client, error) { + mgr, err := GetManager(authType...) + if err != nil { + return nil, err + } + + return mgr.OAuth2GetClient(clientID) +} + +// GenerateAuthorizationCode creates an authorization code | 创建授权码 +func GenerateAuthorizationCode(ctx context.Context, clientID, loginID, redirectURI string, scope []string, authType ...string) (*oauth2.AuthorizationCode, error) { + mgr, err := GetManager(authType...) + if err != nil { + return nil, err + } + + return mgr.OAuth2GenerateAuthorizationCode(ctx, clientID, loginID, redirectURI, scope) +} + +// ExchangeCodeForToken exchanges authorization code for token | 使用授权码换取令牌 +func ExchangeCodeForToken(ctx context.Context, code, clientID, clientSecret, redirectURI string, authType ...string) (*oauth2.AccessToken, error) { + mgr, err := GetManager(authType...) + if err != nil { + return nil, err + } + + return mgr.OAuth2ExchangeCodeForToken(ctx, code, clientID, clientSecret, redirectURI) +} + +// ValidateAccessToken verifies OAuth2 access token | 验证OAuth2访问令牌 +func ValidateAccessToken(ctx context.Context, accessToken string, authType ...string) bool { + mgr, err := GetManager(authType...) + if err != nil { + return false + } + + return mgr.OAuth2ValidateAccessToken(ctx, accessToken) +} + +// ValidateAccessTokenAndGetInfo verifies OAuth2 access token and get info | 验证OAuth2访问令牌并获取信息 +func ValidateAccessTokenAndGetInfo(ctx context.Context, accessToken string, authType ...string) (*oauth2.AccessToken, error) { + mgr, err := GetManager(authType...) + if err != nil { + return nil, err + } + + return mgr.OAuth2ValidateAccessTokenAndGetInfo(ctx, accessToken) +} + +// OAuth2RefreshAccessToken Refreshes access token using refresh token | 使用刷新令牌刷新访问令牌(OAuth2) +func OAuth2RefreshAccessToken(ctx context.Context, clientID, refreshToken, clientSecret string, authType ...string) (*oauth2.AccessToken, error) { + mgr, err := GetManager(authType...) + if err != nil { + return nil, err + } + + return mgr.OAuth2RefreshAccessToken(ctx, clientID, refreshToken, clientSecret) +} + +// RevokeToken Revokes access token and its refresh token | 撤销访问令牌及其刷新令牌 +func RevokeToken(ctx context.Context, accessToken string, authType ...string) error { + mgr, err := GetManager(authType...) + if err != nil { + return err + } + + return mgr.OAuth2RevokeToken(ctx, accessToken) +} + +// OAuth2Token Unified token endpoint that dispatches to appropriate handler based on grant type | 统一的令牌端点,根据授权类型分发到相应的处理逻辑 +func OAuth2Token(ctx context.Context, req *oauth2.TokenRequest, validateUser oauth2.UserValidator, authType ...string) (*oauth2.AccessToken, error) { + mgr, err := GetManager(authType...) + if err != nil { + return nil, err + } + + return mgr.OAuth2Token(ctx, req, validateUser) +} + +// OAuth2ClientCredentialsToken Gets access token using client credentials grant | 使用客户端凭证模式获取访问令牌 +func OAuth2ClientCredentialsToken(ctx context.Context, clientID, clientSecret string, scopes []string, authType ...string) (*oauth2.AccessToken, error) { + mgr, err := GetManager(authType...) + if err != nil { + return nil, err + } + + return mgr.OAuth2ClientCredentialsToken(ctx, clientID, clientSecret, scopes) +} + +// OAuth2PasswordGrantToken Gets access token using resource owner password credentials grant | 使用密码模式获取访问令牌 +func OAuth2PasswordGrantToken(ctx context.Context, clientID, clientSecret, username, password string, scopes []string, validateUser oauth2.UserValidator, authType ...string) (*oauth2.AccessToken, error) { + mgr, err := GetManager(authType...) + if err != nil { + return nil, err + } + + return mgr.OAuth2PasswordGrantToken(ctx, clientID, clientSecret, username, password, scopes, validateUser) +} + +// ============ Public Getters | 公共获取器 ============ + +// GetConfig returns the manager-example configuration | 获取 Manager 当前使用的配置 +func GetConfig(ctx context.Context, authType ...string) *config.Config { + mgr, err := GetManager(authType...) + if err != nil { + return nil + } + return mgr.GetConfig() +} + +// GetStorage returns the storage adapter | 获取 Manager 使用的存储适配器 +func GetStorage(ctx context.Context, authType ...string) adapter.Storage { + mgr, err := GetManager(authType...) + if err != nil { + return nil + } + return mgr.GetStorage() +} + +// GetCodec returns the codec (serializer) | 获取 Manager 使用的编解码器 +func GetCodec(ctx context.Context, authType ...string) adapter.Codec { + mgr, err := GetManager(authType...) + if err != nil { + return nil + } + return mgr.GetCodec() +} + +// GetLog returns the logger adapter | 获取 Manager 使用的日志适配器 +func GetLog(ctx context.Context, authType ...string) adapter.Log { + mgr, err := GetManager(authType...) + if err != nil { + return nil + } + return mgr.GetLog() +} + +// GetPool returns the goroutine pool | 获取 Manager 使用的协程池 +func GetPool(ctx context.Context, authType ...string) adapter.Pool { + mgr, err := GetManager(authType...) + if err != nil { + return nil + } + return mgr.GetPool() +} + +// GetGenerator returns the token generator | 获取 Token 生成器 +func GetGenerator(ctx context.Context, authType ...string) adapter.Generator { + mgr, err := GetManager(authType...) + if err != nil { + return nil + } + return mgr.GetGenerator() +} + +// GetNonceManager returns the nonce manager-example | 获取随机串管理器 +func GetNonceManager(ctx context.Context, authType ...string) *security.NonceManager { + mgr, err := GetManager(authType...) + if err != nil { + return nil + } + return mgr.GetNonceManager() +} + +// GetRefreshManager returns the refresh token manager-example | 获取刷新令牌管理器 +func GetRefreshManager(ctx context.Context, authType ...string) *security.RefreshTokenManager { + mgr, err := GetManager(authType...) + if err != nil { + return nil + } + return mgr.GetRefreshManager() +} + +// GetEventManager returns the event manager-example | 获取事件管理器 +func GetEventManager(ctx context.Context, authType ...string) *listener.Manager { + mgr, err := GetManager(authType...) + if err != nil { + return nil + } + return mgr.GetEventManager() +} + +// GetOAuth2Server Gets OAuth2 server instance | 获取OAuth2服务器实例 +func GetOAuth2Server(ctx context.Context, authType ...string) *oauth2.OAuth2Server { + mgr, err := GetManager(authType...) + if err != nil { + return nil + } + return mgr.GetOAuth2Server() +} + +// ============ Event Management | 事件管理 ============ + +// RegisterFunc registers a function as an event listener | 注册函数作为事件监听器 +func RegisterFunc(event listener.Event, fn func(*listener.EventData), authType ...string) { + mgr, err := GetManager(authType...) + if err != nil { + return + } + mgr.RegisterFunc(event, fn) +} + +// Register registers an event listener | 注册事件监听器 +func Register(event listener.Event, l listener.Listener, authType ...string) string { + mgr, err := GetManager(authType...) + if err != nil { + return "" + } + return mgr.Register(event, l) +} + +// RegisterWithConfig registers an event listener with config | 注册带配置的事件监听器 +func RegisterWithConfig(event listener.Event, l listener.Listener, config listener.ListenerConfig, authType ...string) string { + mgr, err := GetManager(authType...) + if err != nil { + return "" + } + return mgr.RegisterWithConfig(event, l, config) +} + +// Unregister removes an event listener by ID | 根据ID移除事件监听器 +func Unregister(id string, authType ...string) bool { + mgr, err := GetManager(authType...) + if err != nil { + return false + } + return mgr.Unregister(id) +} + +// TriggerEvent manually triggers an event | 手动触发事件 +func TriggerEvent(data *listener.EventData, authType ...string) { + mgr, err := GetManager(authType...) + if err != nil { + return + } + mgr.TriggerEvent(data) +} + +// TriggerEventAsync triggers an event asynchronously and returns immediately | 异步触发事件并立即返回 +func TriggerEventAsync(data *listener.EventData, authType ...string) { + mgr, err := GetManager(authType...) + if err != nil { + return + } + mgr.TriggerEventAsync(data) +} + +// TriggerEventSync triggers an event synchronously and waits for all listeners | 同步触发事件并等待所有监听器完成 +func TriggerEventSync(data *listener.EventData, authType ...string) { + mgr, err := GetManager(authType...) + if err != nil { + return + } + mgr.TriggerEventSync(data) +} + +// WaitEvents waits for all async event listeners to complete | 等待所有异步事件监听器完成 +func WaitEvents(authType ...string) { + mgr, err := GetManager(authType...) + if err != nil { + return + } + mgr.WaitEvents() +} + +// ClearEventListeners removes all listeners for a specific event | 清除指定事件的所有监听器 +func ClearEventListeners(event listener.Event, authType ...string) { + mgr, err := GetManager(authType...) + if err != nil { + return + } + mgr.ClearEventListeners(event) +} + +// ClearAllEventListeners removes all listeners | 清除所有事件监听器 +func ClearAllEventListeners(authType ...string) { + mgr, err := GetManager(authType...) + if err != nil { + return + } + mgr.ClearAllEventListeners() +} + +// CountEventListeners returns the number of listeners for a specific event | 获取指定事件监听器数量 +func CountEventListeners(event listener.Event, authType ...string) int { + mgr, err := GetManager(authType...) + if err != nil { + return 0 + } + return mgr.CountEventListeners(event) +} + +// CountAllListeners returns the total number of registered listeners | 获取已注册监听器总数 +func CountAllListeners(authType ...string) int { + mgr, err := GetManager(authType...) + if err != nil { + return 0 + } + return mgr.CountAllListeners() +} + +// GetEventListenerIDs returns all listener IDs for a specific event | 获取指定事件的所有监听器ID +func GetEventListenerIDs(event listener.Event, authType ...string) []string { + mgr, err := GetManager(authType...) + if err != nil { + return nil + } + return mgr.GetEventListenerIDs(event) +} + +// GetAllRegisteredEvents returns all events that have registered listeners | 获取所有已注册事件 +func GetAllRegisteredEvents(authType ...string) []listener.Event { + mgr, err := GetManager(authType...) + if err != nil { + return nil + } + return mgr.GetAllRegisteredEvents() +} + +// HasEventListeners checks if there are any listeners for a specific event | 检查指定事件是否有监听器 +func HasEventListeners(event listener.Event, authType ...string) bool { + mgr, err := GetManager(authType...) + if err != nil { + return false + } + return mgr.HasEventListeners(event) +} + +// ============ Internal Helper Methods | 内部辅助方法 ============ + +// SetManager stores the manager-example in the global map using the specified autoType | 使用指定的 autoType 将管理器存储在全局 map 中 +func SetManager(mgr *manager.Manager) { + // Validate and get the autoType value | 验证并获取 autoType 值 + validAutoType := getAutoType(mgr.GetConfig().AuthType) // 获取 autoType,默认为 config.DefaultAuthType + // Store the manager-example in the global map with the valid autoType | 使用有效的 autoType 将管理器存储在全局 map 中 + globalManagerMap.Store(validAutoType, mgr) +} + +// GetManager retrieves the manager-example from the global map using the specified autoType | 使用指定的 autoType 从全局 map 中获取管理器 +func GetManager(autoType ...string) (*manager.Manager, error) { + // Validate and get the autoType value | 验证并获取 autoType 值 + validAutoType := getAutoType(autoType...) // 获取 autoType,默认为 config.DefaultAuthType + // Use LoadManager to retrieve the manager-example | 使用 LoadManager 方法来获取管理器 + return loadManager(validAutoType) +} + +// DeleteManager delete the specific manager-example for the given autoType and releases resources | 删除指定的管理器并释放资源 +func DeleteManager(autoType ...string) error { + // Validate and get the autoType value | 验证并获取 autoType 值 + validAutoType := getAutoType(autoType...) // 获取 autoType,默认为 config.DefaultAuthType + // Load the manager-example from global map | 从全局 map 中加载管理器 + mgr, err := loadManager(validAutoType) + if err != nil { + return err + } + // Close the manager-example and release resources | 关闭管理器并释放资源 + mgr.CloseManager() + // Remove the manager-example from the global map | 从全局 map 中移除该管理器 + globalManagerMap.Delete(validAutoType) + return nil +} + +// DeleteAllManager delete all managers in the global map and releases resources | 关闭所有管理器并释放资源 +func DeleteAllManager() { + // Iterate over all managers in the global map and close them | 遍历全局 map 中的所有管理器并关闭它们 + globalManagerMap.Range(func(key, value interface{}) bool { + // Assert the value to the correct type | 将值断言为正确的类型 + mgr, ok := value.(*manager.Manager) + if ok { + // Close each manager-example | 关闭每个管理器 + mgr.CloseManager() + } + // Continue iterating | 继续遍历 + return true + }) + // Clear the global map after closing all managers | 关闭所有管理器后清空全局 map + globalManagerMap = sync.Map{} +} + +// getAutoType checks if a valid autoType is provided, ensures it's trimmed, appends ":" if missing, and returns the value | 检查是否提供有效的 autoType,修剪空格,如果缺少 ":" 则添加,并返回值 +func getAutoType(autoType ...string) string { + // Check if autoType is provided and not empty, trim it and append ":" if missing | 检查是否提供了有效的 autoType,修剪空格,如果缺少 ":" 则添加 + if len(autoType) > 1 && strings.TrimSpace(autoType[1]) != "" { + trimmed := strings.TrimSpace(autoType[1]) + // If it doesn't end with ":", append ":" | 如果 autoType 的值不以 ":" 结尾,则添加 ":" + if !strings.HasSuffix(trimmed, ":") { + trimmed = trimmed + ":" + } + return trimmed + } + // Return log autoType if autoType is empty or invalid | 如果 autoType 为空或无效,返回默认值 + return config.DefaultAuthType +} + +// loadManager retrieves the manager-example from the global map using the valid autoType | 使用有效的 autoType 从全局 map 中加载管理器 +func loadManager(autoType string) (*manager.Manager, error) { + // Load the manager-example from the global map using the valid autoType | 使用有效的 autoType 从全局 map 中加载管理器 + value, ok := globalManagerMap.Load(autoType) + if !ok { + return nil, core.ErrManagerNotFound + } + // Assert the loaded value to the correct type | 将加载的值断言为正确的类型 + mgr, ok := value.(*manager.Manager) + if !ok { + return nil, core.ErrManagerInvalidType + } + return mgr, nil +} + +// toString Converts interface{} to string | 将interface{}转换为string +func toString(v interface{}) (string, error) { + // Check the type and convert to string | 判断类型并转换为字符串 + switch val := v.(type) { + case string: + return val, nil // If it's a string, return it directly | 如果是字符串,直接返回 + case int: + return intToString(val), nil // If it's int, convert to string | 如果是int,转换为string + case int64: + return int64ToString(val), nil // If it's int64, convert to string | 如果是int64,转换为string + case uint: + return uintToString(val), nil // If it's uint, convert to string | 如果是uint,转换为string + case uint64: + return uint64ToString(val), nil // If it's uint64, convert to string | 如果是uint64,转换为string + default: + return "", fmt.Errorf("Invalid type") // For other types, return error | 对于其他类型,返回错误 + } +} + +// intToString Converts int to string | 将int转换为string +func intToString(i int) string { + return int64ToString(int64(i)) // Call int64ToString to convert | 调用int64ToString进行转换 +} + +// int64ToString Converts int64 to string | 将int64转换为string +func int64ToString(i int64) string { + // If it's zero, return "0" | 如果是零,返回 "0" + if i == 0 { + return "0" + } + + // Check if it's negative and handle it | 判断是否为负数并处理 + negative := i < 0 + if negative { + i = -i // Take the absolute value | 取绝对值 + } + + var result []byte + // Process each digit and prepend to the result array | 将每一位数字依次处理并添加到结果数组 + for i > 0 { + result = append([]byte{byte('0' + i%10)}, result...) + i /= 10 + } + + // If it's negative, add the '-' sign | 如果是负数,添加负号 + if negative { + result = append([]byte{'-'}, result...) + } + + return string(result) +} + +// uintToString Converts uint to string | 将uint转换为string +func uintToString(u uint) string { + return uint64ToString(uint64(u)) // Call uint64ToString to convert | 调用uint64ToString进行转换 +} + +// uint64ToString Converts uint64 to string | 将uint64转换为string +func uint64ToString(u uint64) string { + // If it's zero, return "0" | 如果是零,返回 "0" + if u == 0 { + return "0" + } + + var result []byte + // Process each digit and prepend to the result array | 将每一位数字依次处理并添加到结果数组 + for u > 0 { + result = append([]byte{byte('0' + u%10)}, result...) + u /= 10 + } + + return string(result) }