Skip to content

Commit

Permalink
Merge pull request #17 from domsolutions/fix-user-supplied-jwts
Browse files Browse the repository at this point in the history
Fix user supplied jwts
  • Loading branch information
domsolutions authored Aug 21, 2023
2 parents 2c94e56 + c2d5af6 commit bbf64f4
Show file tree
Hide file tree
Showing 13 changed files with 276 additions and 89 deletions.
10 changes: 9 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,13 @@ Flags:
-H, --headers strings headers to send in request, can have multiple i.e -H 'content-type:application/json' -H' connection:close'
-h, --help help for run
--jwt-aud string JWT audience (aud) claim
--jwt-claims string JWT custom claims
--jwt-header string JWT header field name
--jwt-iss string JWT issuer (iss) claim
--jwt-key string JWT signing private key path
--jwt-kid string JWT KID
--jwt-sub string JWT subject (sub) claim
--jwt-claims string JWT custom claims as a JSON string, ex: {"iat": 1719410063, "browser": "chrome"}
-f, --jwts-filename string File path for pre-generated JWTs, separated by new lines
-m, --method string request method (default "GET")
--mtls-cert string mTLS cert path
--mtls-key string mTLS cert private key path
Expand Down Expand Up @@ -221,6 +222,13 @@ https://github.com/domsolutions/gopayloader
+-----------------------+-------------------------------+
```
If you have your own JWTs you want to test, you can supply a file to send the JWTs i.e. `./my-jwts.txt` where each jwt is separated by a new line.
```shell
./gopayloader run http://localhost:8081 -c 1 -r 1000000 --jwt-header "my-jwt" -f ./my-jwts.txt
```
To remove all generated jwts;
```shell
Expand Down
56 changes: 32 additions & 24 deletions cmd/payloader/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,29 +9,30 @@ import (
)

const (
argMethod = "method"
argConnections = "connections"
argRequests = "requests"
argKeepAlive = "disable-keep-alive"
argVerifySigner = "skip-verify"
argTime = "time"
argMTLSKey = "mtls-key"
argMTLSCert = "mtls-cert"
argReadTimeout = "read-timeout"
argWriteTimeout = "write-timeout"
argVerbose = "verbose"
argTicker = "ticker"
argJWTKey = "jwt-key"
argJWTSUb = "jwt-sub"
argMethod = "method"
argConnections = "connections"
argRequests = "requests"
argKeepAlive = "disable-keep-alive"
argVerifySigner = "skip-verify"
argTime = "time"
argMTLSKey = "mtls-key"
argMTLSCert = "mtls-cert"
argReadTimeout = "read-timeout"
argWriteTimeout = "write-timeout"
argVerbose = "verbose"
argTicker = "ticker"
argJWTKey = "jwt-key"
argJWTSUb = "jwt-sub"
argJWTCustomClaims = "jwt-claims"
argJWTIss = "jwt-iss"
argJWTAud = "jwt-aud"
argJWTHeader = "jwt-header"
argJWTKid = "jwt-kid"
argHeaders = "headers"
argBody = "body"
argBodyFile = "body-file"
argClient = "client"
argJWTIss = "jwt-iss"
argJWTAud = "jwt-aud"
argJWTHeader = "jwt-header"
argJWTKid = "jwt-kid"
argJWTsFilename = "jwts-filename"
argHeaders = "headers"
argBody = "body"
argBodyFile = "body-file"
argClient = "client"
)

var (
Expand All @@ -55,6 +56,7 @@ var (
jwtAud string
jwtHeader string
jwtKID string
jwtsFilename string
headers *[]string
body string
bodyFile string
Expand Down Expand Up @@ -92,6 +94,7 @@ var runCmd = &cobra.Command{
jwtIss,
jwtAud,
jwtHeader,
jwtsFilename,
*headers,
body,
bodyFile,
Expand Down Expand Up @@ -128,11 +131,16 @@ func init() {
runCmd.Flags().StringVar(&jwtIss, argJWTIss, "", "JWT issuer (iss) claim")
runCmd.Flags().StringVar(&jwtSub, argJWTSUb, "", "JWT subject (sub) claim")
runCmd.Flags().StringVar(&jwtCustomClaims, argJWTCustomClaims, "", "JWT custom claims")
runCmd.Flags().StringVarP(&jwtsFilename, argJWTsFilename, "f", "", "File path for pre-generated JWTs, separated by new lines")
runCmd.Flags().StringVar(&jwtHeader, argJWTHeader, "", "JWT header field name")

runCmd.MarkFlagsRequiredTogether(argMTLSCert, argMTLSKey)
runCmd.MarkFlagsRequiredTogether(argJWTKey, argJWTHeader)
runCmd.MarkFlagsMutuallyExclusive(argBody, argBodyFile)

runCmd.MarkFlagsMutuallyExclusive(argJWTsFilename, argJWTKid)
runCmd.MarkFlagsMutuallyExclusive(argJWTsFilename, argJWTAud)
runCmd.MarkFlagsMutuallyExclusive(argJWTsFilename, argJWTIss)
runCmd.MarkFlagsMutuallyExclusive(argJWTsFilename, argJWTCustomClaims)
runCmd.MarkFlagsMutuallyExclusive(argJWTsFilename, argJWTSUb)
runCmd.MarkFlagsMutuallyExclusive(argJWTsFilename, argJWTKey)
rootCmd.AddCommand(runCmd)
}
22 changes: 20 additions & 2 deletions cmd/payloader/test-server.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@ import (
"log"
"net/http"
"os"
"os/signal"
"path/filepath"
"strconv"
"strings"
"syscall"
"time"
)

Expand Down Expand Up @@ -87,9 +89,25 @@ var runServerCmd = &cobra.Command{
},
}

if err := server.ListenAndServe(addr); err != nil {
return err
errs := make(chan error)
go func() {
if err := server.ListenAndServe(addr); err != nil {
log.Println(err)
errs <- err
}
}()

c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt, syscall.SIGTERM)

select {
case <-c:
log.Println("User cancelled, shutting down")
server.Shutdown()
case err := <-errs:
log.Printf("Got error from server; %v \n", err)
}

return nil
}

Expand Down
32 changes: 24 additions & 8 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,15 @@ type Config struct {
JwtIss string
JwtAud string
JwtHeader string
JwtsFilename string
SendJWT bool
Headers []string
Body string
BodyFile string
Client string
}

func NewConfig(ctx context.Context, reqURI, mTLScert, mTLSKey string, disableKeepAlive bool, reqs int64, conns uint, totalTime time.Duration, skipVerify bool, readTimeout, writeTimeout time.Duration, method string, verbose bool, ticker time.Duration, jwtKID, jwtKey, jwtSub, jwtCustomClaimsJSON, jwtIss, jwtAud, jwtHeader string, headers []string, body, bodyFile string, client string) *Config {
func NewConfig(ctx context.Context, reqURI, mTLScert, mTLSKey string, disableKeepAlive bool, reqs int64, conns uint, totalTime time.Duration, skipVerify bool, readTimeout, writeTimeout time.Duration, method string, verbose bool, ticker time.Duration, jwtKID, jwtKey, jwtSub, jwtCustomClaimsJSON, jwtIss, jwtAud, jwtHeader, jwtsFilename string, headers []string, body, bodyFile string, client string) *Config {
return &Config{
Ctx: ctx,
ReqURI: reqURI,
Expand All @@ -64,6 +65,7 @@ func NewConfig(ctx context.Context, reqURI, mTLScert, mTLSKey string, disableKee
JwtIss: jwtIss,
JwtAud: jwtAud,
JwtHeader: jwtHeader,
JwtsFilename: jwtsFilename,
Headers: headers,
Body: body,
BodyFile: bodyFile,
Expand Down Expand Up @@ -139,14 +141,14 @@ func (c *Config) Validate() error {
}
}

if (c.JwtHeader == "") != (c.JwtKey == "") {
if c.JwtHeader == "" {
return errors.New("config: empty jwt header")
}
// Require JwtHeader if JwtKey or JwtsFilename is present
if (c.JwtsFilename != "" || c.JwtKey != "") && c.JwtHeader == "" {
return errors.New("config: empty jwt header")
}

if c.JwtKey == "" {
return errors.New("empty jwt key")
}
// Require JwtKey or JwtsFilename if JwtHeader is present
if c.JwtHeader != "" && c.JwtsFilename == "" && c.JwtKey == "" {
return errors.New("config: empty jwt filename and jwt key, one of those is needed to send requests with JWTs")
}

if c.JwtKey != "" {
Expand All @@ -163,6 +165,20 @@ func (c *Config) Validate() error {
c.SendJWT = true
}

if c.JwtsFilename != "" {
_, err := os.OpenFile(c.JwtsFilename, os.O_RDONLY, os.ModePerm)
if err != nil {
if os.IsNotExist(err) {
return errors.New("config: jwt file does not exist: " + c.JwtsFilename)
}
return fmt.Errorf("config: jwt file error checking file exists; %v", err)
}
if c.ReqTarget == 0 {
return errors.New("can only send jwts when request number is specified")
}
c.SendJWT = true
}

if len(c.Headers) > 0 {
for _, h := range c.Headers {
if !strings.Contains(h, ":") {
Expand Down
1 change: 0 additions & 1 deletion pkgs/http-clients/definitions.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ type Config struct {
Method string
Verbose bool
JwtStreamReceiver <-chan string
JwtStreamErr <-chan error
JWTHeader string
Headers []string
Body string
Expand Down
73 changes: 60 additions & 13 deletions pkgs/jwt-generator/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,17 @@ package jwt_generator

import (
"bufio"
"encoding/binary"
"errors"
"fmt"
"github.com/pterm/pterm"
"os"
"strconv"
"strings"
"time"
)

const byteSizeCounter = 20

type cache struct {
f *os.File
count int64
Expand All @@ -20,14 +23,22 @@ func newCache(f *os.File) (*cache, error) {
c := cache{f: f}

c.scanner = bufio.NewScanner(c.f)
// Get count found on first line of the file
c.scanner.Split(bufio.ScanLines)
if c.scanner.Scan() {
bb := make([]byte, 8)
bb := make([]byte, byteSizeCounter)
_, err := f.ReadAt(bb, 0)
if err != nil {
return nil, err
}
c.count = int64(binary.LittleEndian.Uint64(bb))

count, err := getCount(bb)
if err != nil {
pterm.Error.Printf("Got error reading jwt count from cache; %v", err)
return nil, err
}

c.count = count
return &c, nil
}
return &c, nil
Expand All @@ -37,6 +48,23 @@ func (c *cache) getJwtCount() int64 {
return c.count
}

func getCount(bb []byte) (int64, error) {
num := make([]byte, 0)
for _, m := range bb {
if m == 0 {
break
}
num = append(num, m)
}

s := string(num)
i, err := strconv.ParseInt(s, 10, 64)
if err != nil {
return 0, err
}
return i, nil
}

func (c *cache) get(count int64) (<-chan string, <-chan error) {
recv := make(chan string, 1000000)
errs := make(chan error, 1)
Expand All @@ -61,14 +89,22 @@ func (c *cache) get(count int64) (<-chan string, <-chan error) {
}

meta := c.scanner.Bytes()
if len(meta) < 8 {
if len(meta) < byteSizeCounter {
errs <- fmt.Errorf("jwt_generator: retrieving; corrupt jwt cache, wanted 8 bytes got %d", len(meta))
close(errs)
close(recv)
return recv, errs
}

if count > int64(binary.LittleEndian.Uint64(meta[0:8])) {
i, err := getCount(meta)
if err != nil {
errs <- fmt.Errorf("failed to get jwt count; %v", err)
close(errs)
close(recv)
return recv, errs
}

if count > i {
errs <- errors.New("jwt_generator: retrieving; not enough jwts stored in cache")
close(errs)
close(recv)
Expand All @@ -83,20 +119,25 @@ func (c *cache) get(count int64) (<-chan string, <-chan error) {

func (c *cache) retrieve(count int64, recv chan<- string, errs chan<- error) {
var i int64 = 0
defer func() {
close(errs)
close(recv)
}()

for i = 0; i < count; i++ {
if c.scanner.Scan() {
recv <- string(c.scanner.Bytes())
continue
}
// reached EOF or err

if err := c.scanner.Err(); err != nil {
errs <- err
close(errs)
return
}
break

errs <- errors.New("unable to read anymore jwts from file")
return
}
close(recv)
}

func (c *cache) save(tokens []string) error {
Expand All @@ -110,19 +151,25 @@ func (c *cache) save(tokens []string) error {
if stat.Size() > 0 {
pos = stat.Size()
}

if _, err := c.f.WriteAt([]byte(strings.Join(tokens, "\n")+"\n"), pos); err != nil {
return err
}

b := make([]byte, 8)
newCount := uint64(int64(add) + c.count)
binary.LittleEndian.PutUint64(b, newCount)
newCount := int64(add) + c.count
s := strconv.FormatInt(newCount, 10)

b := make([]byte, byteSizeCounter)
for i, ss := range s {
b[i] = byte(ss)
}

_, err = c.f.WriteAt(b, 0)
if err != nil {
return err
}

_, err = c.f.WriteAt([]byte{byte('\n')}, 9)
_, err = c.f.WriteAt([]byte{byte('\n')}, byteSizeCounter)
if err != nil {
return err
}
Expand Down
Loading

0 comments on commit bbf64f4

Please sign in to comment.