diff --git a/v2/bk.go b/v2/bk.go new file mode 100644 index 0000000..82383c7 --- /dev/null +++ b/v2/bk.go @@ -0,0 +1,240 @@ +package bk + +import ( + "context" + "errors" + "io" + "net/http" + "strconv" + "time" +) + +type Keeper struct { + ctx context.Context + cancel context.CancelFunc + fn RoundTrip + retries int + ticker *time.Ticker + delay time.Duration + concurrency int + concurrencyticker chan bool + requests chan<- *requestWrapper + requestTimeout time.Duration +} + +// Receive creates a new listening channel for http request wrappers. After +// creating the request channel it then monitors the delay timer (aka ticker) +// for each tick then checks for an available concurrency entry on the +// concurrency channel to process work. Once it's cleared the ticker and +// concurrency channel it then pulls an available request from the request +// channel and executes the http request against the endpoint and returns the +// response across the response channel of the request along with any errors +// that occurred when making the request +func (k *Keeper) receive() chan<- *requestWrapper { + reqs := make(chan *requestWrapper) + + go func(reqs chan *requestWrapper) { + defer func() { + if r := recover(); r != nil { + k.cancel() + } + }() + + k.reqHandler(reqs) + }(reqs) + + return reqs +} + +func (k *Keeper) reqHandler(reqs chan *requestWrapper) { + for { + select { + case <-k.ctx.Done(): + return + case _, ok := <-k.ticker.C: + if !ok { + return + } + + select { + case <-k.ctx.Done(): + return + case _, ok = <-k.concurrencyticker: + if !ok { + return + } + + k.process(reqs) + } + } + } +} + +func (k *Keeper) process(requests chan *requestWrapper) { + defer func() { + _ = recover() + }() + + select { + case <-k.ctx.Done(): + case req, ok := <-requests: + if !ok { + return + } + + go k.handleRequest(req) + } +} + +func (k *Keeper) handleRequest(req *requestWrapper) { + defer func() { + if r := recover(); r != nil { + return + } + + // Must have the context done in the select here otherwise if the ctx is + // closed then this will cause a panic because of sending on a closed + // channel + select { + case <-k.ctx.Done(): + case k.concurrencyticker <- true: + } + }() + + // Execute a call against the endpoint handling any potential panics from + // the http client + resp, err := k.execute(req) + if resp == nil { + select { + case <-req.ctx.Done(): + case req.response <- responseWrapper{resp, err}: + } + + return + } + + if (err != nil || resp.StatusCode >= 400) && req.attempts < k.retries { + // Read and close the body of the response + readAndClose(resp.Body) + + go k.resend(req, timer(resp.Header.Get("Retry-After"))) + return + } + + select { + case <-req.ctx.Done(): + case req.response <- responseWrapper{resp, err}: + } +} + +/* +https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html +14.37 Retry-After +The Retry-After response-header field can be used with a +503 (Service Unavailable) response to indicate how long the +service is expected to be unavailable to the requesting client. +This field MAY also be used with any 3xx (Redirection) response +to indicate the minimum time the user-agent is asked wait before +issuing the redirected request. The value of this field can be either +an HTTP-date or an integer number of seconds (in decimal) after the +time of the response. + + Retry-After = "Retry-After" ":" ( HTTP-date | delta-seconds ) + +Two examples of its use are + + Retry-After: Fri, 31 Dec 1999 23:59:59 GMT + Retry-After: 120 + +In the latter example, the delay is 2 minutes. 
+*/ +func timer(retryHeader string) *time.Timer { + if retryHeader == "" { + return time.NewTimer(0) + } + + delay, err := strconv.Atoi(retryHeader) + if err == nil { + return time.NewTimer(time.Second * time.Duration(delay)) + } + + t, err := time.Parse(time.RFC1123, retryHeader) + if err == nil { + return time.NewTimer(time.Until(t)) + } + + return time.NewTimer(0) +} + +func (k *Keeper) resend(req *requestWrapper, timer *time.Timer) { + defer timer.Stop() + + select { + case <-req.ctx.Done(): + case <-timer.C: + + // Send the request back on the channel + select { + case <-req.ctx.Done(): + case k.requests <- req: + } + } +} + +func (k *Keeper) execute(req *requestWrapper) (resp *http.Response, err error) { + defer func(req *requestWrapper) { + if r := recover(); r != nil { + err = errors.New("panic occurred while executing http request") + return + } + + // increment the attempt counter + req.attempts++ + }(req) + + // TODO: Not sure this is necessary? Causes the TCP connection to not be + // reused... + // req.request.Close = true + + // Execute the http request and return the response to the requester + return k.fn(req.request) +} + +// readAndClose reads all of the contents of the readcloser and closes +// +// As per the Go stdlib documentation for net/http.Response +// +// The default HTTP client's Transport may not reuse HTTP/1.x +// "keep-alive" TCP connections if the Body is not read to +// completion and closed. +func readAndClose(rc io.ReadCloser) { + if rc != nil { + defer func() { + if r := recover(); r != nil { + return + } + }() + + _, _ = io.ReadAll(rc) + _ = rc.Close() + } +} + +// wrapper for transporting requests along a channel along with a response +// channel for returning the response from the endpoint as well as an attempt +// counter for tracking the number of times a request has been attempted in the +// event that it continues to fail +type requestWrapper struct { + ctx context.Context + request *http.Request + response chan<- responseWrapper + attempts int +} + +// wrapper for tracking the response of executing a client.Do against an +// http request. 
This returns any errors from the bridgekeeper attempting to execute +// the request as well as the http response in the event of a response +type responseWrapper struct { + response *http.Response + err error +} diff --git a/v2/bk_mocks_test.go b/v2/bk_mocks_test.go new file mode 100644 index 0000000..5096071 --- /dev/null +++ b/v2/bk_mocks_test.go @@ -0,0 +1,131 @@ +package bk + +import ( + "context" + "errors" + "fmt" + "io" + "net/http" + "strings" + "time" +) + +type httpclient struct { + delay time.Duration + requests int + status int + retries int + attempts int + concurrency int + cancel bool +} + +func (client *httpclient) RoundTrip(r *http.Request) (*http.Response, error) { + return client.Do(r) +} + +func (client *httpclient) Do(r *http.Request) (*http.Response, error) { + if client.delay > 0 { + time.Sleep(client.delay) + } + + status := client.status + if client.retries > 0 { + client.attempts++ + + if client.attempts < client.retries { + status = http.StatusBadRequest + } + } + + return &http.Response{ + StatusCode: status, + Body: &fakeReadCloser{}, + }, nil +} + +type fakeReadCloser struct{} + +func (rc *fakeReadCloser) Read(p []byte) (n int, err error) { return 0, io.EOF } +func (rc *fakeReadCloser) Close() error { return nil } + +type badclient struct { + panic bool + delay time.Duration + requests int + status int + retries int + attempts int + concurrency int +} + +func (client *badclient) RoundTrip(r *http.Request) (*http.Response, error) { + return client.Do(r) +} + +func (client *badclient) Do(r *http.Request) (*http.Response, error) { + if client.panic { + panic("panic") + } + + return nil, errors.New("error") +} + +type tstruct struct { + error bool +} + +func (t *tstruct) correct(err error, paniced bool) error { + if paniced { + return errors.New("unexpected panic") + } + + if t.error && err == nil { + return errors.New("expected error but success instead") + } + + if !t.error && err != nil { + return fmt.Errorf("expected success but errored instead | %s", err) + } + + return nil +} + +func newGetReqWOutCtx() *http.Request { + r, _ := http.NewRequest( + http.MethodGet, + "", + strings.NewReader(""), + ) + + return r +} + +func newGetReqWCtx() *http.Request { + r, _ := http.NewRequestWithContext( + context.TODO(), + http.MethodGet, + "", + strings.NewReader(""), + ) + + return r +} + +type passthrough struct { + ctx context.Context + out chan *http.Request +} + +func (pass *passthrough) RoundTrip(r *http.Request) (*http.Response, error) { + return pass.Do(r) +} + +func (pass *passthrough) Do(r *http.Request) (*http.Response, error) { + select { + case <-pass.ctx.Done(): + case pass.out <- r: + } + + return nil, nil +} diff --git a/v2/bk_test.go b/v2/bk_test.go new file mode 100644 index 0000000..59449c0 --- /dev/null +++ b/v2/bk_test.go @@ -0,0 +1,567 @@ +package bk + +import ( + "context" + "net/http" + "reflect" + "testing" + "time" +) + +type tcase struct { + client *httpclient + request *http.Request + success tstruct +} + +func cases(t *testing.T, req func() *http.Request) map[string]tcase { + return map[string]tcase{ + "ValidWValidClient": { + &httpclient{ + requests: 1, + status: http.StatusOK, + }, + req(), + tstruct{false}, + }, + "ValidWValidClientConcurrency": { + &httpclient{ + requests: 1, + status: http.StatusOK, + concurrency: 10, + }, + req(), + tstruct{false}, + }, + "ValidWValidClientConcurrency_0_Default": { + &httpclient{ + requests: 1, + status: http.StatusOK, + concurrency: 0, + }, + req(), + tstruct{false}, + }, + "ValidWValidClientDelay": { + 
&httpclient{ + delay: time.Millisecond * 25, + requests: 1, + status: http.StatusOK, + }, + req(), + tstruct{false}, + }, + "ValidWValidClientDelay_0": { + &httpclient{ + delay: 0, + requests: 1, + status: http.StatusOK, + }, + req(), + tstruct{false}, + }, + "ValidWValidClientDelay_neg": { + &httpclient{ + delay: -1, + requests: 1, + status: http.StatusOK, + }, + req(), + tstruct{false}, + }, + "ValidWValidClientDelayAndConcurrency": { + &httpclient{ + delay: time.Millisecond * 25, + requests: 1, + status: http.StatusOK, + concurrency: 10, + }, + req(), + tstruct{false}, + }, + "ValidWValidClientW5Retries": { + &httpclient{ + requests: 1, + status: http.StatusOK, + retries: 5, + }, + req(), + tstruct{false}, + }, + "ValidWValidClientW0Retries": { + &httpclient{ + requests: 1, + status: http.StatusOK, + retries: 0, + }, + req(), + tstruct{false}, + }, + "FailWValidClientW5Retries4Attempts": { + &httpclient{ + requests: 1, + status: http.StatusOK, + retries: 5, + attempts: -1, + }, + req(), + tstruct{true}, + }, + "FailWBadStatus": { + &httpclient{ + requests: 1, + status: http.StatusBadRequest, + }, + req(), + tstruct{true}, + }, + "FailByCancellation": { + &httpclient{ + requests: 1, + status: http.StatusOK, + cancel: true, + }, + req(), + tstruct{true}, + }, + "FailByCancellation_Millisecond": { + &httpclient{ + delay: time.Millisecond, + requests: 1, + status: http.StatusOK, + cancel: true, + }, + req(), + tstruct{true}, + }, + "FailByNilRequest": { + &httpclient{ + requests: 1, + status: http.StatusOK, + cancel: true, + }, + nil, + tstruct{true}, + }, + } +} + +func Test_Do(t *testing.T) { + alltests := map[string]map[string]tcase{ + "w/ctx-": cases(t, newGetReqWCtx), + "wout/ctx-": cases(t, newGetReqWOutCtx), + } + + for key, tests := range alltests { + for name, test := range tests { + t.Run(key+name, func(t *testing.T) { + defer func() { + if r := recover(); r != nil { + t.Fatalf("test [%s] had a panic | %s", name, r) + } + }() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + client := New( + ctx, + test.client.Do, + test.client.delay, + test.client.retries, + test.client.concurrency, + time.Minute, + ) + + // Cancellation test + if test.client.cancel { + cancel() + } + + resp, err := client.Do(test.request) + if err != nil { + if test.client.retries > 0 && + test.client.retries != test.client.attempts && + !test.success.error { + t.Fatalf("[%s] failed; number of attempts doesn't match the expected retries [%v:%v]", name, test.client.attempts, test.client.retries) + } else { + testErr := test.success.correct(err, false) + if testErr != nil { + t.Fatalf("[%s] failed; %s", name, testErr.Error()) + } + } + } + + if resp == nil { + testErr := test.success.correct(err, false) + if testErr != nil { + t.Fatalf("[%s] failed; %s", name, testErr.Error()) + } + } + }) + } + } +} + +func Test_DoBadClient(t *testing.T) { + tests := map[string]struct { + client *badclient + request *http.Request + success tstruct + }{ + "PanicyClient": { + &badclient{ + panic: true, + requests: 1, + status: http.StatusOK, + }, + newGetReqWCtx(), + tstruct{true}, + }, + "ErroringClient": { + &badclient{ + requests: 1, + status: http.StatusOK, + }, + newGetReqWCtx(), + tstruct{true}, + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + defer func() { + if r := recover(); r != nil { + t.Fatalf("test [%s] had a panic | %s", name, r) + } + }() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + client := New( + ctx, + 
test.client.Do, + test.client.delay, + test.client.retries, + test.client.concurrency, + time.Minute, + ) + + resp, err := client.Do(test.request) + if err != nil { + if test.client.retries > 0 && test.client.retries != test.client.attempts && !test.success.error { + t.Fatalf("[%s] failed; number of attempts doesn't match the expected retries [%v:%v]", name, test.client.attempts, test.client.retries) + } else { + testErr := test.success.correct(err, false) + if testErr != nil { + t.Fatalf("[%s] failed; %s", name, testErr.Error()) + } + } + } + + if resp == nil { + testErr := test.success.correct(err, false) + if testErr != nil { + t.Fatalf("[%s] failed; %s", name, testErr.Error()) + } + } + }) + } +} + +func Test_Do_FailOpen(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + code := http.StatusContinue + client := &httpclient{ + status: code, + } + + wrapper := &Keeper{ + ctx: ctx, + fn: client.Do, + cancel: cancel, + concurrencyticker: make(chan bool), + requests: make(chan *requestWrapper), + requestTimeout: time.Minute, + } + + // cancel the context to trigger passthrough + cancel() + + resp, err := wrapper.Do(newGetReqWCtx()) + if err != nil { + t.Fatalf("error %s", err) + } + + if resp.StatusCode != code { + t.Fatalf("Expected status code %v got %v", code, resp.StatusCode) + } +} + +func Test_Do_Request_Timeout_ReqWOutCtx(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + c := &httpclient{ + delay: time.Minute, + } + client := New( + ctx, + c.Do, + 0, + 0, + 0, + time.Second, + ) + + _, err := client.Do(newGetReqWOutCtx()) + if err != nil { + if err != context.DeadlineExceeded { + t.Fatalf("expected context.DeadlineExceeded; got %T", err) + } + } else { + t.Fatal("Expected timeout error") + } +} + +func Test_Do_Request_Timeout_ReqWCtx(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + c := &httpclient{ + delay: time.Minute, + } + client := New( + ctx, + c.Do, + 0, + 0, + 0, + time.Second, + ) + + _, err := client.Do(newGetReqWCtx()) + if err != nil { + if err != context.DeadlineExceeded { + t.Fatalf("expected context.DeadlineExceeded; got %T", err) + } + } else { + t.Fatal("Expected timeout error") + } +} + +func Test_New_Defaults(t *testing.T) { + // Setting the value for default http timeout + http.DefaultClient.Timeout = time.Millisecond + + k := New(nil, nil, -1, -1, -1, -1) + + if k.requestTimeout != http.DefaultClient.Timeout { + t.Fatal("Expected request timeout to default to http.DefaultClient.Timeout") + } + + if k.retries != 0 { + t.Fatalf("Expected retries to be 0; got %v", k.retries) + } + + if k.delay != time.Nanosecond { + t.Fatalf("Expected delay to be %v; got %v", time.Nanosecond, k.delay) + } + + if k.concurrency != 1 { + t.Fatalf("Expected concurrency 1; got %v", k.concurrency) + } + + if k.ctx == nil || k.cancel == nil { + t.Fatalf( + "Invalid context fallback ctx: %v cancelfunc: %v", + k.ctx, + k.cancel, + ) + } + + if k.ticker == nil { + t.Fatal("Nil keeper ticker") + } + + if k.concurrencyticker == nil { + t.Fatal("Nil keeper concurrency ticker") + } +} + +func Test_Do_Throughput(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + p := &passthrough{ + ctx, + make(chan *http.Request), + } + + client := New(ctx, p.Do, 0, 0, 0, time.Minute) + r := newGetReqWCtx() + + go func(r *http.Request) { + client.Do(r) + }(r) + + select { + case <-ctx.Done(): + t.Fatal("context closed prematurely") + case rout, ok := 
<-p.out: + if !ok { + t.Fatal("passthrough closed prematurely") + } + + if !reflect.DeepEqual(r, rout) { + t.Fatal("requests do not match") + } + } +} + +// NOTE: The timer tests use for the most part a 1 second tolerance +// due to the fact that the timer has to be scheduled and by the time +// the current time is checked it actually may have elapsed an entire second +// this is meant to be as close as possible to real-world + +func Test_timer_intHeader(t *testing.T) { + testdata := map[string]struct { + retryHeader string + expected time.Duration + tolerance time.Duration + }{ + "integer 1s": { + "1", + time.Second, + time.Second, + }, + "integer 2s": { + "2", + time.Second * 2, + time.Second, + }, + "integer 5s": { + "5", + time.Second * 5, + time.Second, + }, + "integer 10s": { + "10", + time.Second * 10, + time.Second, + }, + "invalid parse": { + "not a valid time to parse", + 0, + time.Microsecond * 500, + }, + "empty header": { + "", + 0, + time.Microsecond * 500, + }, + } + + for name, test := range testdata { + t.Run(name, func(t *testing.T) { + timer := timer(test.retryHeader) + defer timer.Stop() + + tstart := time.Now() + + <-timer.C + diff := time.Since(tstart) + expPos := test.expected + test.tolerance + expNeg := test.expected - test.tolerance + if diff < expNeg || diff > expPos { + t.Fatalf("timer exceeded tolerance %s < %s < %s", expNeg, diff, expPos) + } + }) + } +} + +func timedelay(format string, delay time.Duration) string { + return time.Now().Add(delay).Format(format) +} + +func Test_timer_timeHeader(t *testing.T) { + testdata := map[string]struct { + format string + expected time.Duration + tolerance time.Duration + }{ + "RFC1123 - 1s": { + time.RFC1123, + time.Second, + time.Second, + }, + "RFC1123 - 2s": { + time.RFC1123, + time.Second * 2, + time.Second, + }, + "RFC1123 - 5s": { + time.RFC1123, + time.Second * 5, + time.Second, + }, + "RFC1123 - 10s": { + time.RFC1123, + time.Second * 5, + time.Second, + }, + "invalid parse": { + time.RFC3339, + 0, + time.Microsecond * 500, + }, + } + + for name, test := range testdata { + t.Run(name, func(t *testing.T) { + timer := timer(timedelay(test.format, test.expected)) + defer timer.Stop() + + tstart := time.Now() + + <-timer.C + diff := time.Since(tstart) + expPos := test.expected + test.tolerance + expNeg := test.expected - test.tolerance + if diff < expNeg || diff > expPos { + t.Fatalf("timer exceeded tolerance %s < %s < %s", expNeg, diff, expPos) + } + }) + } +} + +func Benchmark_Do_ZeroConcurrency(b *testing.B) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + p := &passthrough{ + ctx, + make(chan *http.Request), + } + + client := New(ctx, p.Do, 0, 0, 0, time.Minute) + r := newGetReqWCtx() + + b.ResetTimer() + + for n := 0; n < b.N; n++ { + go func(r *http.Request) { + client.Do(r) + }(r) + + select { + case <-ctx.Done(): + b.Fatal("context closed prematurely") + case _, ok := <-p.out: + if !ok { + b.Fatal("passthrough closed prematurely") + } + } + } +} diff --git a/v2/bk_test_rt.go b/v2/bk_test_rt.go new file mode 100644 index 0000000..799e780 --- /dev/null +++ b/v2/bk_test_rt.go @@ -0,0 +1,274 @@ +package bk + +import ( + "context" + "net/http" + "reflect" + "testing" + "time" +) + +func Test_RoundTrip(t *testing.T) { + alltests := map[string]map[string]tcase{ + "w/ctx-": cases(t, newGetReqWCtx), + "wout/ctx-": cases(t, newGetReqWOutCtx), + } + + for key, tests := range alltests { + for name, test := range tests { + t.Run(key+name, func(t *testing.T) { + defer func() { + if r := 
recover(); r != nil { + t.Fatalf("test [%s] had a panic | %s", name, r) + } + }() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + client := New( + ctx, + test.client.RoundTrip, + test.client.delay, + test.client.retries, + test.client.concurrency, + time.Minute, + ) + + // Cancellation test + if test.client.cancel { + cancel() + } + + resp, err := client.RoundTrip(test.request) + if err != nil { + if test.client.retries > 0 && + test.client.retries != test.client.attempts && + !test.success.error { + t.Fatalf("[%s] failed; number of attempts doesn't match the expected retries [%v:%v]", name, test.client.attempts, test.client.retries) + } else { + testErr := test.success.correct(err, false) + if testErr != nil { + t.Fatalf("[%s] failed; %s", name, testErr.Error()) + } + } + } + + if resp == nil { + testErr := test.success.correct(err, false) + if testErr != nil { + t.Fatalf("[%s] failed; %s", name, testErr.Error()) + } + } + }) + } + } +} + +func Test_RoundTrip_BadClient(t *testing.T) { + tests := map[string]struct { + client *badclient + request *http.Request + success tstruct + }{ + "PanicyClient": { + &badclient{ + panic: true, + requests: 1, + status: http.StatusOK, + }, + newGetReqWCtx(), + tstruct{true}, + }, + "ErroringClient": { + &badclient{ + requests: 1, + status: http.StatusOK, + }, + newGetReqWCtx(), + tstruct{true}, + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + defer func() { + if r := recover(); r != nil { + t.Fatalf("test [%s] had a panic | %s", name, r) + } + }() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + client := New( + ctx, + test.client.RoundTrip, + test.client.delay, + test.client.retries, + test.client.concurrency, + time.Minute, + ) + + resp, err := client.RoundTrip(test.request) + if err != nil { + if test.client.retries > 0 && test.client.retries != test.client.attempts && !test.success.error { + t.Fatalf("[%s] failed; number of attempts doesn't match the expected retries [%v:%v]", name, test.client.attempts, test.client.retries) + } else { + testErr := test.success.correct(err, false) + if testErr != nil { + t.Fatalf("[%s] failed; %s", name, testErr.Error()) + } + } + } + + if resp == nil { + testErr := test.success.correct(err, false) + if testErr != nil { + t.Fatalf("[%s] failed; %s", name, testErr.Error()) + } + } + }) + } +} + +func Test_RoundTrip_FailOpen(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + code := http.StatusContinue + client := &httpclient{ + status: code, + } + + wrapper := &Keeper{ + ctx: ctx, + fn: client.RoundTrip, + cancel: cancel, + concurrencyticker: make(chan bool), + requests: make(chan *requestWrapper), + requestTimeout: time.Minute, + } + + // cancel the context to trigger passthrough + cancel() + + resp, err := wrapper.RoundTrip(newGetReqWCtx()) + if err != nil { + t.Fatalf("error %s", err) + } + + if resp.StatusCode != code { + t.Fatalf("Expected status code %v got %v", code, resp.StatusCode) + } +} + +func Test_RoundTrip_Request_Timeout_ReqWOutCtx(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + c := &httpclient{ + delay: time.Minute, + } + client := New( + ctx, + c.RoundTrip, + 0, + 0, + 0, + time.Second, + ) + + _, err := client.RoundTrip(newGetReqWOutCtx()) + if err != nil { + if err != context.DeadlineExceeded { + t.Fatalf("expected context.DeadlineExceeded; got %T", err) + } + } else { + t.Fatal("Expected timeout error") + } +} + +func 
Test_RoundTrip_Request_Timeout_ReqWCtx(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + c := &httpclient{ + delay: time.Minute, + } + client := New( + ctx, + c.RoundTrip, + 0, + 0, + 0, + time.Second, + ) + + _, err := client.RoundTrip(newGetReqWCtx()) + if err != nil { + if err != context.DeadlineExceeded { + t.Fatalf("expected context.DeadlineExceeded; got %T", err) + } + } else { + t.Fatal("Expected timeout error") + } +} + +func Test_RoundTrip_Throughput(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + p := &passthrough{ + ctx, + make(chan *http.Request), + } + + client := New(ctx, p.RoundTrip, 0, 0, 0, time.Minute) + r := newGetReqWCtx() + + go func(r *http.Request) { + client.RoundTrip(r) + }(r) + + select { + case <-ctx.Done(): + t.Fatal("context closed prematurely") + case rout, ok := <-p.out: + if !ok { + t.Fatal("passthrough closed prematurely") + } + + if !reflect.DeepEqual(r, rout) { + t.Fatal("requests do not match") + } + } +} + +func Benchmark_RoundTrip_ZeroConcurrency(b *testing.B) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + p := &passthrough{ + ctx, + make(chan *http.Request), + } + + client := New(ctx, p.RoundTrip, 0, 0, 0, time.Minute) + r := newGetReqWCtx() + + b.ResetTimer() + + for n := 0; n < b.N; n++ { + go func(r *http.Request) { + client.RoundTrip(r) + }(r) + + select { + case <-ctx.Done(): + b.Fatal("context closed prematurely") + case _, ok := <-p.out: + if !ok { + b.Fatal("passthrough closed prematurely") + } + } + } +} diff --git a/v2/exported.go b/v2/exported.go new file mode 100644 index 0000000..9eaed21 --- /dev/null +++ b/v2/exported.go @@ -0,0 +1,172 @@ +// package bk is intended to create a client side load balancer or +// rate limiter for API integrations. This library is specifically designed to +// wrap the `Do` method of the http.Client but since it uses an interface +// abstraction it can wrap any interface and limit requests. +// +// Controls +// - Delay between requests +// - Number of retries per request +// - Concurrency limit for the client +package bk + +import ( + "context" + "errors" + "fmt" + "net/http" + "time" +) + +type RoundTrip func(*http.Request) (*http.Response, error) + +// New creates a new instance of the bridgekeeper for use with an api. New +// returns an interface implementation of Client which replaces the +// implementation of an http.Client interface so that it looks like an +// http.Client and can perform the same functions but it limits the requests +// using the parameters defined when created. 
NOTE: If a request timeout is not +// set at creation then the default HTTP client request timeout will be used +func New( + ctx context.Context, + fn RoundTrip, + delay time.Duration, + retries int, + concurrency int, + requestTimeout time.Duration, +) *Keeper { + if requestTimeout < time.Nanosecond { + requestTimeout = http.DefaultClient.Timeout + } + + if retries < 0 { + retries = 0 + } + + if delay <= 0 { + delay = time.Nanosecond + } + + // ensure the concurrency is setup above zero + if concurrency < 1 { + concurrency = 1 + } + + // Setup a background context if no context is passed + if ctx == nil { + ctx = context.Background() + } + + ctx, cancel := context.WithCancel(ctx) + + // If a nil client is passed to the bridgekeeper then initialize using the + // default http client + if fn == nil { + fn = http.DefaultClient.Do + } + + k := &Keeper{ + ctx: ctx, + cancel: cancel, + fn: fn, + retries: retries, + delay: delay, + ticker: time.NewTicker(delay), + concurrency: concurrency, + concurrencyticker: make(chan bool, concurrency), + requestTimeout: requestTimeout, + } + + // Initialize the concurrency channel for managing concurrent calls + for i := 0; i < k.concurrency; i++ { + select { + case <-k.ctx.Done(): + case k.concurrencyticker <- true: + } + } + + // Setup requests channel + k.requests = k.receive() + + go k.cleanup() + + return k +} + +// cleanup deals with cleaning any struct values for the keeper +func (k *Keeper) cleanup() { + <-k.ctx.Done() + + if k.ticker != nil { + k.ticker.Stop() + } +} + +// RoundTrip is a wrapper for the Do method of the bridgekeeper. This is +// necessary for the bridgekeeper to implement the http.RoundTripper +func (k *Keeper) RoundTrip(req *http.Request) (*http.Response, error) { + return k.Do(req) +} + +// Do sends the http request through the bridgekeeper to be executed against the +// endpoint when there are available threads to do so. 
This returns an http +// response which is returned from the execution of the http request as well +// as an error +// +// XXX: Possibly add in defer here that determines if the response is nil +// and executes the wrapped `Do` method directly +func (k *Keeper) Do(request *http.Request) (*http.Response, error) { + if request == nil { + return nil, errors.New("request cannot be nil") + } + + // Fail open if the bridgekeeper request was canceled + select { + case <-k.ctx.Done(): + return k.fn(request) + default: + + // If the request has a context then use it + ctx := request.Context() + if ctx == nil || ctx == context.Background() { + ctx = k.ctx + + // Add the context to the request if it didn't already + // have one assigned + request = request.WithContext(ctx) + } + + // Enforce request specific timeout + ctx, cancel := context.WithTimeout(ctx, k.requestTimeout) + defer cancel() + + var responsechan = make(chan responseWrapper) + defer close(responsechan) + + // Create the request wrapper to send to receive + req := &requestWrapper{ + ctx: ctx, + request: request, + response: responsechan, + } + + // Send the request to the processing channel of the bridgekeeper + go func() { + select { + case <-ctx.Done(): + return + case k.requests <- req: + } + }() + + // Wait for the response from the request + select { + case <-ctx.Done(): + return nil, ctx.Err() + case resp, ok := <-responsechan: + if !ok { + return nil, fmt.Errorf("response channel closed prematurely") + } + + return resp.response, resp.err + } + } +} diff --git a/v2/go.mod b/v2/go.mod new file mode 100644 index 0000000..827e11f --- /dev/null +++ b/v2/go.mod @@ -0,0 +1,3 @@ +module go.devnw.com/bk/v2 + +go 1.23 diff --git a/v2/go.sum b/v2/go.sum new file mode 100644 index 0000000..e69de29