Skip to content

Commit

Permalink
Merge pull request #3 from nmaltais/master
Browse files Browse the repository at this point in the history
Add --jwt-claims flag to add custom claims to the generated JWTs via CLI
  • Loading branch information
domsolutions authored Jun 29, 2023
2 parents c17c4aa + 30fa3ba commit 8941653
Show file tree
Hide file tree
Showing 7 changed files with 202 additions and 148 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ Flags:
--jwt-key string JWT signing private key path
--jwt-kid string JWT KID
--jwt-sub string JWT subject (sub) claim
--jwt-claims string JWT custom claims as a JSON string, ex: {"iat": 1719410063, "browser": "chrome"}
-m, --method string request method (default "GET")
--mtls-cert string mTLS cert path
--mtls-key string mTLS cert private key path
Expand Down
48 changes: 26 additions & 22 deletions cmd/payloader/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,28 +9,29 @@ import (
)

const (
argMethod = "method"
argConnections = "connections"
argRequests = "requests"
argKeepAlive = "disable-keep-alive"
argVerifySigner = "skip-verify"
argTime = "time"
argMTLSKey = "mtls-key"
argMTLSCert = "mtls-cert"
argReadTimeout = "read-timeout"
argWriteTimeout = "write-timeout"
argVerbose = "verbose"
argTicker = "ticker"
argJWTKey = "jwt-key"
argJWTSUb = "jwt-sub"
argJWTIss = "jwt-iss"
argJWTAud = "jwt-aud"
argJWTHeader = "jwt-header"
argJWTKid = "jwt-kid"
argHeaders = "headers"
argBody = "body"
argBodyFile = "body-file"
argClient = "client"
argMethod = "method"
argConnections = "connections"
argRequests = "requests"
argKeepAlive = "disable-keep-alive"
argVerifySigner = "skip-verify"
argTime = "time"
argMTLSKey = "mtls-key"
argMTLSCert = "mtls-cert"
argReadTimeout = "read-timeout"
argWriteTimeout = "write-timeout"
argVerbose = "verbose"
argTicker = "ticker"
argJWTKey = "jwt-key"
argJWTSUb = "jwt-sub"
argJWTCustomClaims = "jwt-claims"
argJWTIss = "jwt-iss"
argJWTAud = "jwt-aud"
argJWTHeader = "jwt-header"
argJWTKid = "jwt-kid"
argHeaders = "headers"
argBody = "body"
argBodyFile = "body-file"
argClient = "client"
)

var (
Expand All @@ -49,6 +50,7 @@ var (
ticker time.Duration
jwtKey string
jwtSub string
jwtCustomClaims string
jwtIss string
jwtAud string
jwtHeader string
Expand Down Expand Up @@ -86,6 +88,7 @@ var runCmd = &cobra.Command{
jwtKID,
jwtKey,
jwtSub,
jwtCustomClaims,
jwtIss,
jwtAud,
jwtHeader,
Expand Down Expand Up @@ -124,6 +127,7 @@ func init() {
runCmd.Flags().StringVar(&jwtAud, argJWTAud, "", "JWT audience (aud) claim")
runCmd.Flags().StringVar(&jwtIss, argJWTIss, "", "JWT issuer (iss) claim")
runCmd.Flags().StringVar(&jwtSub, argJWTSUb, "", "JWT subject (sub) claim")
runCmd.Flags().StringVar(&jwtCustomClaims, argJWTCustomClaims, "", "JWT custom claims")
runCmd.Flags().StringVar(&jwtHeader, argJWTHeader, "", "JWT header field name")

runCmd.MarkFlagsRequiredTogether(argMTLSCert, argMTLSKey)
Expand Down
127 changes: 77 additions & 50 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,62 +9,65 @@ import (
"regexp"
"strings"
"time"
"encoding/json"
)

type Config struct {
Ctx context.Context
ReqURI string
DisableKeepAlive bool
ReqTarget int64
Conns uint
Duration time.Duration
MTLSKey string
MTLSCert string
SkipVerify bool
ReadTimeout time.Duration
WriteTimeout time.Duration
Method string
Verbose bool
VerboseTicker time.Duration
JwtKID string
JwtKey string
JwtSub string
JwtIss string
JwtAud string
JwtHeader string
SendJWT bool
Headers []string
Body string
BodyFile string
Client string
Ctx context.Context
ReqURI string
DisableKeepAlive bool
ReqTarget int64
Conns uint
Duration time.Duration
MTLSKey string
MTLSCert string
SkipVerify bool
ReadTimeout time.Duration
WriteTimeout time.Duration
Method string
Verbose bool
VerboseTicker time.Duration
JwtKID string
JwtKey string
JwtSub string
JwtCustomClaimsJSON string
JwtIss string
JwtAud string
JwtHeader string
SendJWT bool
Headers []string
Body string
BodyFile string
Client string
}

func NewConfig(ctx context.Context, reqURI, mTLScert, mTLSKey string, disableKeepAlive bool, reqs int64, conns uint, totalTime time.Duration, skipVerify bool, readTimeout, writeTimeout time.Duration, method string, verbose bool, ticker time.Duration, jwtKID, jwtKey, jwtSub, jwtIss, jwtAud, jwtHeader string, headers []string, body, bodyFile string, client string) *Config {
func NewConfig(ctx context.Context, reqURI, mTLScert, mTLSKey string, disableKeepAlive bool, reqs int64, conns uint, totalTime time.Duration, skipVerify bool, readTimeout, writeTimeout time.Duration, method string, verbose bool, ticker time.Duration, jwtKID, jwtKey, jwtSub, jwtCustomClaimsJSON, jwtIss, jwtAud, jwtHeader string, headers []string, body, bodyFile string, client string) *Config {
return &Config{
Ctx: ctx,
ReqURI: reqURI,
MTLSKey: mTLSKey,
MTLSCert: mTLScert,
DisableKeepAlive: disableKeepAlive,
ReqTarget: reqs,
Conns: conns,
Duration: totalTime,
SkipVerify: skipVerify,
ReadTimeout: readTimeout,
WriteTimeout: writeTimeout,
Method: method,
Verbose: verbose,
VerboseTicker: ticker,
JwtKID: jwtKID,
JwtKey: jwtKey,
JwtSub: jwtSub,
JwtIss: jwtIss,
JwtAud: jwtAud,
JwtHeader: jwtHeader,
Headers: headers,
Body: body,
BodyFile: bodyFile,
Client: client,
Ctx: ctx,
ReqURI: reqURI,
MTLSKey: mTLSKey,
MTLSCert: mTLScert,
DisableKeepAlive: disableKeepAlive,
ReqTarget: reqs,
Conns: conns,
Duration: totalTime,
SkipVerify: skipVerify,
ReadTimeout: readTimeout,
WriteTimeout: writeTimeout,
Method: method,
Verbose: verbose,
VerboseTicker: ticker,
JwtKID: jwtKID,
JwtKey: jwtKey,
JwtSub: jwtSub,
JwtCustomClaimsJSON: jwtCustomClaimsJSON,
JwtIss: jwtIss,
JwtAud: jwtAud,
JwtHeader: jwtHeader,
Headers: headers,
Body: body,
BodyFile: bodyFile,
Client: client,
}
}

Expand All @@ -83,6 +86,22 @@ var allowedMethods = [4]string{
"DELETE",
}

// Converts jwtCustomClaimsJSON from string to map[string]interface{}
func JwtCustomClaimsJSONStringToMap(jwtCustomClaimsJSON string) (map[string]interface{}, error) {
if jwtCustomClaimsJSON == "" {
return nil, nil
}

jwtCustomClaimsMap := map[string]interface{}{}

err := json.Unmarshal([]byte(jwtCustomClaimsJSON), &jwtCustomClaimsMap)
if err != nil {
return nil, err
}

return jwtCustomClaimsMap, nil
}

func (c *Config) Validate() error {
if _, err := url.ParseRequestURI(c.ReqURI); err != nil {
return fmt.Errorf("config: invalid request uri, got error %v", err)
Expand Down Expand Up @@ -180,6 +199,14 @@ func (c *Config) Validate() error {
if c.ReqTarget == 0 && c.Duration == 0 {
return errors.New("config: ReqTarget 0 and Duration 0")
}

if c.JwtCustomClaimsJSON != "" {
_, err := JwtCustomClaimsJSONStringToMap(c.JwtCustomClaimsJSON)
if err != nil {
return fmt.Errorf("config: failed to parse custom json in --jwt-claims, got error; %v", err)
}
}

return nil
}

Expand Down
40 changes: 29 additions & 11 deletions pkgs/jwt-generator/jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ import (
"encoding/hex"
"errors"
"fmt"
"strings"
jwt_signer "github.com/domsolutions/gopayloader/pkgs/jwt-signer"
"github.com/domsolutions/gopayloader/pkgs/jwt-signer/definition"
config "github.com/domsolutions/gopayloader/config"
"github.com/golang-jwt/jwt"
"github.com/google/uuid"
"github.com/pterm/pterm"
Expand All @@ -22,15 +24,16 @@ const (
)

type Config struct {
Ctx context.Context
Kid string
JwtKeyPath string
jwtKeyBlob []byte
JwtSub string
JwtIss string
JwtAud string
signer definition.Signer
store *cache
Ctx context.Context
Kid string
JwtKeyPath string
jwtKeyBlob []byte
JwtSub string
JwtCustomClaimsJSON string
JwtIss string
JwtAud string
signer definition.Signer
store *cache
}

type JWTGenerator struct {
Expand Down Expand Up @@ -60,7 +63,9 @@ func (j *JWTGenerator) getFileName(dir string) string {
hash.Write([]byte(j.config.JwtAud))
hash.Write([]byte(j.config.JwtIss))
hash.Write([]byte(j.config.JwtSub))
hash.Write(j.config.jwtKeyBlob)
hash.Write([]byte(j.config.JwtCustomClaimsJSON))
strippedKey := strings.ReplaceAll(strings.ReplaceAll(string(j.config.jwtKeyBlob), "\r", ""), "\n", "") // Replace \r and \n to have the same value in Windows and Linux
hash.Write([]byte(strippedKey))
hash.Write([]byte(j.config.Kid))
return filepath.Join(dir, "gopayloader-jwtstore-"+hex.EncodeToString(hash.Sum(nil))+".txt")
}
Expand Down Expand Up @@ -163,8 +168,8 @@ func (j *JWTGenerator) generate(limit int64, errs chan<- error, response chan<-
var err error
var i int64 = 0

claims := j.commonClaims() // Claims common to all JWTs, computed only once
for i = 0; i < limit; i++ {
claims := j.commonClaims()
claims["jti"] = uuid.New().String()
tokens[i], err = j.config.signer.Generate(claims)
if err != nil {
Expand All @@ -187,5 +192,18 @@ func (j *JWTGenerator) commonClaims() jwt.MapClaims {
claims["iss"] = j.config.JwtIss
}
claims["exp"] = time.Now().Add(24 * time.Hour * 365).Unix()

if j.config.JwtCustomClaimsJSON != "" {
// At this point the JSON in JwtCustomClaimsJSON has already been validated, but checking for errors again in case the workflow changes in the future
jwtCustomClaimsMap, err := config.JwtCustomClaimsJSONStringToMap(j.config.JwtCustomClaimsJSON)
if err != nil {
return claims // Return claims if there's an error
}
for key, value := range jwtCustomClaimsMap {
if key != "" {
claims[key] = value
}
}
}
return claims
}
13 changes: 7 additions & 6 deletions pkgs/payloader/payloader.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,13 @@ func (p *PayLoader) handleReqs() (*GoPayloaderResults, error) {
pterm.Info.Printf("Sending jwts with requests, checking for jwts in cache\n")

jwt := jwt_generator.NewJWTGenerator(&jwt_generator.Config{
Ctx: p.config.Ctx,
Kid: p.config.JwtKID,
JwtKeyPath: p.config.JwtKey,
JwtSub: p.config.JwtSub,
JwtIss: p.config.JwtIss,
JwtAud: p.config.JwtAud,
Ctx: p.config.Ctx,
Kid: p.config.JwtKID,
JwtKeyPath: p.config.JwtKey,
JwtSub: p.config.JwtSub,
JwtCustomClaimsJSON: p.config.JwtCustomClaimsJSON,
JwtIss: p.config.JwtIss,
JwtAud: p.config.JwtAud,
})

if err := os.MkdirAll(JwtCacheDir, 0755); err != nil {
Expand Down
Loading

0 comments on commit 8941653

Please sign in to comment.