From 2778f648b5081a7898b2a9aac6a8769f51bfaa1e Mon Sep 17 00:00:00 2001 From: developerWIT Date: Thu, 21 Mar 2024 13:37:59 +0700 Subject: [PATCH] feature: jwt middleware for fasthttp framework --- go.mod | 15 +++++++ go.sum | 12 ++++++ helper.go | 43 +++++++++++++++++++ jwt.go | 110 ++++++++++++++++++++++++++++++++++++++++++++++++ jwtwrapper.go | 42 ++++++++++++++++++ model.go | 39 +++++++++++++++++ service.go | 71 +++++++++++++++++++++++++++++++ service_test.go | 70 ++++++++++++++++++++++++++++++ 8 files changed, 402 insertions(+) create mode 100644 go.mod create mode 100644 go.sum create mode 100644 helper.go create mode 100644 jwt.go create mode 100644 jwtwrapper.go create mode 100644 model.go create mode 100644 service.go create mode 100644 service_test.go diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..80d2096 --- /dev/null +++ b/go.mod @@ -0,0 +1,15 @@ +module github.com/erajayatech/go-fasthttp-keycloak-middleware + +go 1.22.1 + +require ( + github.com/cristalhq/jwt/v3 v3.1.0 + github.com/joho/godotenv v1.5.1 + github.com/valyala/fasthttp v1.52.0 +) + +require ( + github.com/andybalholm/brotli v1.1.0 // indirect + github.com/klauspost/compress v1.17.7 // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..5f96add --- /dev/null +++ b/go.sum @@ -0,0 +1,12 @@ +github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M= +github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY= +github.com/cristalhq/jwt/v3 v3.1.0 h1:iLeL9VzB0SCtjCy9Kg53rMwTcrNm+GHyVcz2eUujz6s= +github.com/cristalhq/jwt/v3 v3.1.0/go.mod h1:XOnIXst8ozq/esy5N1XOlSyQqBd+84fxJ99FK+1jgL8= +github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/klauspost/compress v1.17.7 h1:ehO88t2UGzQK66LMdE8tibEd1ErmzZjNEqWkjLAKQQg= +github.com/klauspost/compress v1.17.7/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.52.0 h1:wqBQpxH71XW0e2g+Og4dzQM8pk34aFYlA1Ga8db7gU0= +github.com/valyala/fasthttp v1.52.0/go.mod h1:hf5C4QnVMkNXMspnsUlfM3WitlgYflyhHYoKol/szxQ= diff --git a/helper.go b/helper.go new file mode 100644 index 0000000..48afe04 --- /dev/null +++ b/helper.go @@ -0,0 +1,43 @@ +package keycloakmiddleware + +import ( + "encoding/base64" + "encoding/json" + "github.com/joho/godotenv" + "log" + "math/big" + "os" +) + +func getEnvOrDefault(key string, defaultValue interface{}) interface{} { + value := os.Getenv(key) + if len(value) == 0 { + return defaultValue + } + return value +} + +func getEnv(key string) string { + err := godotenv.Load() + if err != nil { + log.Println("Cannot load file .env: ", err) + panic(err) + } + + value := getEnvOrDefault(key, "").(string) + return value +} + +func decodeBase64BigInt(s string) *big.Int { + buffer, _ := base64.URLEncoding.WithPadding(base64.NoPadding).DecodeString(s) + return big.NewInt(0).SetBytes(buffer) +} + +func prettyPrint(data interface{}) string { + JSON, err := json.MarshalIndent(data, "", " ") + if err != nil { + log.Fatalf(err.Error()) + } + + return string(JSON) +} diff --git a/jwt.go b/jwt.go new file mode 100644 index 0000000..923ba85 --- /dev/null +++ b/jwt.go @@ -0,0 +1,110 @@ +package keycloakmiddleware + +import ( + "encoding/json" + "fmt" + "github.com/cristalhq/jwt/v3" + "github.com/valyala/fasthttp" + "net/http" + "strings" + "time" +) + +type middleware struct { + wrapperCode int // 0: default, 1:standard, 2:traceable +} + +func Construct(wrapperCode int) middleware { + return middleware{wrapperCode: wrapperCode} +} + +func (middleware *middleware) Validate(scopes []string, next fasthttp.RequestHandler) fasthttp.RequestHandler { + return fasthttp.RequestHandler(func(ctx *fasthttp.RequestCtx) { + var isEnabled = getEnvOrDefault("KEYCLOAK_JWT_ENABLED", "false").(string) + if strings.ToLower(isEnabled) == "false" || isEnabled == "0" { + return + } + + authHeader := string(ctx.Request.Header.Peek("Authorization")) + s := strings.SplitN(authHeader, " ", 2) + if len(s) != 2 { + msg := "Authorization token is not found" + middleware.abort(http.StatusUnauthorized, ctx, msg) + return + } + + headerToken := s[1] + unverifiedToken, err := jwt.Parse([]byte(headerToken)) + if err != nil { + msg := err.Error() + middleware.abort(http.StatusUnauthorized, ctx, msg) + return + } + + kid := unverifiedToken.Header().KeyID + key, err := getPublicKey(kid) + if err != nil { + msg := err.Error() + middleware.abort(http.StatusUnauthorized, ctx, msg) + return + } + + verifier, err := jwt.NewVerifierRS(jwt.RS256, key) + if err != nil { + msg := err.Error() + middleware.abort(http.StatusUnauthorized, ctx, msg) + return + } + + token, err := jwt.ParseAndVerifyString(headerToken, verifier) + if err != nil { + msg := err.Error() + middleware.abort(http.StatusUnauthorized, ctx, msg) + return + } + + var claims claims + errClaims := json.Unmarshal(token.RawClaims(), &claims) + if errClaims != nil { + msg := errClaims.Error() + middleware.abort(http.StatusUnauthorized, ctx, msg) + return + } + + var iss = getEnv("KEYCLOAK_JWT_ISS") + if claims.Issuer != iss { + msg := "Token issuer is not valid" + middleware.abort(http.StatusUnauthorized, ctx, msg) + return + } + + if claims.ExpiresAt.Unix() < time.Now().Unix() { + msg := "Token expired" + middleware.abort(http.StatusUnauthorized, ctx, msg) + return + } + + if !isScopesValid(claims, scopes) { + msg := "Access to this endpoint is not allowed" + middleware.abort(http.StatusForbidden, ctx, msg) + return + } + + ctx.SetUserValue("keycloak_username", claims.Username) + ctx.SetUserValue("keycloak_name", claims.Name) + ctx.SetUserValue("keycloak_email", claims.Email) + + next(ctx) + }) +} + +func (middleware *middleware) abort(status int, ctx *fasthttp.RequestCtx, message interface{}) { + httpStatus := http.StatusOK + if middleware.wrapperCode != 0 { + httpStatus = status + } + ctx.SetStatusCode(httpStatus) + ctx.SetContentType("application/json") + response := middleware.wrapper(httpStatus, ctx, message, nil) + fmt.Fprintf(ctx, prettyPrint(response)) +} diff --git a/jwtwrapper.go b/jwtwrapper.go new file mode 100644 index 0000000..4562022 --- /dev/null +++ b/jwtwrapper.go @@ -0,0 +1,42 @@ +package keycloakmiddleware + +import ( + "github.com/valyala/fasthttp" +) + +func (middleware *middleware) wrapper(status int, context *fasthttp.RequestCtx, message interface{}, data interface{}) map[string]interface{} { + if middleware.wrapperCode == 2 { + return middleware.traceableWrapper(context, message, data) + } else if middleware.wrapperCode == 1 { + return middleware.standardWrapper(message, data) + } else { + return middleware.defaultWrapper(status, message, data) + } +} + +func (middleware *middleware) defaultWrapper(status int, message interface{}, data interface{}) map[string]interface{} { + return map[string]interface{}{ + "status": status, + "error_message": message, + "data": data, + } +} + +func (middleware *middleware) standardWrapper(message interface{}, data interface{}) map[string]interface{} { + return map[string]interface{}{ + "message": message, + "data": data, + } +} + +func (middleware *middleware) traceableWrapper(context *fasthttp.RequestCtx, message interface{}, data interface{}) map[string]interface{} { + var traceID = context.Value("X-Trace-Id") + return map[string]interface{}{ + "id": traceID, + "appName": getEnvOrDefault("APP_NAME", nil), + "version": getEnvOrDefault("APP_VERSION", nil), + "build": getEnvOrDefault("BUILD", nil), + "message": message, + "data": data, + } +} diff --git a/model.go b/model.go new file mode 100644 index 0000000..2bb33b6 --- /dev/null +++ b/model.go @@ -0,0 +1,39 @@ +package keycloakmiddleware + +import "github.com/cristalhq/jwt/v3" + +// Set all model to private + +type claims struct { + jwt.StandardClaims + Authorization authorization `json:"authorization,omitempty"` + Username string `json:"preferred_username,omitempty"` + Name string `json:"name,omitempty"` + Email string `json:"email,omitempty"` +} + +type authorization struct { + Permissions []permission `json:"permissions,omitempty"` +} + +type permission struct { + RsID string `json:"rsid,omitempty"` + RsName string `json:"rsname,omitempty"` + Scopes []string `json:"scopes,omitempty"` +} + +type keycloakJWKDetail struct { + Key string `json:"kty"` + Kid string `json:"kid"` + Use string `json:"sig"` + Alg string `json:"alg"` + N string `json:"n"` + E string `json:"e"` + X5c []string `json:"x5c"` + X5t string `json:"x5t"` + X5tS256 string `json:"x5t#S256"` +} + +type keycloakJWK struct { + Keys []keycloakJWKDetail `json:"keys"` +} diff --git a/service.go b/service.go new file mode 100644 index 0000000..b418f56 --- /dev/null +++ b/service.go @@ -0,0 +1,71 @@ +package keycloakmiddleware + +import ( + "crypto/rsa" + "encoding/json" + "io/ioutil" + "math/big" + "net/http" +) + +func getPublicKey(kid string) (*rsa.PublicKey, error) { + var keysUrl = getEnv("KEYCLOAK_JWT_JWK_ENDPOINT") + keysRequest, err := http.NewRequest("GET", keysUrl, nil) + if err != nil { + return nil, err + } + + keysResponse, err := http.DefaultClient.Do(keysRequest) + if err != nil { + return nil, err + } + + keysResponseBody, err := ioutil.ReadAll(keysResponse.Body) + if err != nil { + return nil, err + } + + var jwk keycloakJWK + err = json.Unmarshal([]byte(keysResponseBody), &jwk) + if err != nil { + return nil, err + } + + var n *big.Int + var e int + for _, key := range jwk.Keys { + if key.Kid == kid { + n = decodeBase64BigInt(key.N) + e = int(decodeBase64BigInt(key.E).Int64()) + break + } + } + + if n == nil || e == 0 { + return nil, err + } + + jwtKey := &rsa.PublicKey{ + N: n, + E: e, + } + return jwtKey, nil +} + +func isScopesValid(claims claims, scopes []string) bool { + scopeMap := make(map[string]struct{}) + + for _, search := range scopes { + scopeMap[search] = struct{}{} + } + + for _, permission := range claims.Authorization.Permissions { + for _, scope := range permission.Scopes { + if _, exists := scopeMap[scope]; exists { + return true + } + } + } + + return false +} diff --git a/service_test.go b/service_test.go new file mode 100644 index 0000000..c17b63e --- /dev/null +++ b/service_test.go @@ -0,0 +1,70 @@ +package keycloakmiddleware + +import ( + "testing" +) + +func Test_isScopesValid(t *testing.T) { + type args struct { + claims claims + scopes []string + } + tests := []struct { + name string + args args + want bool + }{ + { + name: "scope valid", + args: args{ + claims: claims{ + Authorization: authorization{ + Permissions: []permission{ + {Scopes: []string{"foo", "bar", "baz"}}, + {Scopes: []string{"qux", "fred"}}, + }, + }, + }, + scopes: []string{"bar", "fred"}, + }, + want: true, + }, + { + name: "scope valid 2", + args: args{ + claims: claims{ + Authorization: authorization{ + Permissions: []permission{ + {Scopes: []string{"foo", "bar", "baz"}}, + {Scopes: []string{"qux", "fred"}}, + }, + }, + }, + scopes: []string{"qux", "thud"}, + }, + want: true, + }, + { + name: "scope valid 2", + args: args{ + claims: claims{ + Authorization: authorization{ + Permissions: []permission{ + {Scopes: []string{"foo", "bar", "baz"}}, + {Scopes: []string{"qux", "fred"}}, + }, + }, + }, + scopes: []string{"thud", "chips"}, + }, + want: false, + }, + } + for i := range tests { + t.Run(tests[i].name, func(t *testing.T) { + if got := isScopesValid(tests[i].args.claims, tests[i].args.scopes); got != tests[i].want { + t.Errorf("isScopesValid() = %v, want %v", got, tests[i].want) + } + }) + } +}