diff --git a/breach.go b/breach.go index b5bfeb2..280a601 100644 --- a/breach.go +++ b/breach.go @@ -3,7 +3,6 @@ package hibp import ( "encoding/json" "fmt" - "io" "net/http" "strings" "time" @@ -100,7 +99,7 @@ func (b *BreachApi) Breaches(options ...BreachOption) ([]*Breach, *http.Response queryParams := b.setBreachOpts(options...) apiUrl := fmt.Sprintf("%s/breaches", BaseUrl) - hb, hr, err := b.apiCall(http.MethodGet, apiUrl, queryParams) + hb, hr, err := b.hibp.HttpReqBody(http.MethodGet, apiUrl, queryParams) if err != nil { return nil, nil, err } @@ -122,7 +121,7 @@ func (b *BreachApi) BreachByName(n string, options ...BreachOption) (*Breach, *h } apiUrl := fmt.Sprintf("%s/breach/%s", BaseUrl, n) - hb, hr, err := b.apiCall(http.MethodGet, apiUrl, queryParams) + hb, hr, err := b.hibp.HttpReqBody(http.MethodGet, apiUrl, queryParams) if err != nil { return nil, nil, err } @@ -139,7 +138,7 @@ func (b *BreachApi) BreachByName(n string, options ...BreachOption) (*Breach, *h // with all registered data classes known to HIBP func (b *BreachApi) DataClasses() ([]string, *http.Response, error) { apiUrl := fmt.Sprintf("%s/dataclasses", BaseUrl) - hb, hr, err := b.apiCall(http.MethodGet, apiUrl, nil) + hb, hr, err := b.hibp.HttpReqBody(http.MethodGet, apiUrl, nil) if err != nil { return nil, nil, err } @@ -161,7 +160,7 @@ func (b *BreachApi) BreachedAccount(a string, options ...BreachOption) ([]*Breac } apiUrl := fmt.Sprintf("%s/breachedaccount/%s", BaseUrl, a) - hb, hr, err := b.apiCall(http.MethodGet, apiUrl, queryParams) + hb, hr, err := b.hibp.HttpReqBody(http.MethodGet, apiUrl, queryParams) if err != nil { return nil, nil, err } @@ -246,30 +245,3 @@ func (b *BreachApi) setBreachOpts(options ...BreachOption) map[string]string { return queryParams } - -// apiCall performs the API call to the breaches API and returns the HTTP response body JSON as -// byte array -func (b *BreachApi) apiCall(m string, p string, q map[string]string) ([]byte, *http.Response, error) { - hreq, err := b.hibp.HttpReq(m, p, q) - if err != nil { - return nil, nil, err - } - hr, err := b.hibp.hc.Do(hreq) - if err != nil { - return nil, hr, err - } - defer func() { - _ = hr.Body.Close() - }() - - hb, err := io.ReadAll(hr.Body) - if err != nil { - return nil, hr, err - } - - if hr.StatusCode != 200 { - return nil, hr, fmt.Errorf("API responded with non HTTP-200: %s - %s", hr.Status, hb) - } - - return hb, hr, nil -} diff --git a/breach_test.go b/breach_test.go index 4ce8e07..a3e186d 100644 --- a/breach_test.go +++ b/breach_test.go @@ -218,3 +218,51 @@ func TestBreachedAccount(t *testing.T) { }) } } + +// TestBreachedAccountWithoutTruncate tests the BreachedAccount() method of the breaches API with the +// truncateResponse option set to false +func TestBreachedAccountWithoutTruncate(t *testing.T) { + testTable := []struct { + testName string + accountName string + breachName string + breachDomain string + shouldFail bool + }{ + {"account-exists is breached once", "account-exists", "Adobe", + "adobe.com", false}, + {"multiple-breaches is breached multiple times", "multiple-breaches", "Adobe", + "adobe.com", false}, + {"opt-out is not breached", "opt-out", "", "", true}, + } + + hc := New(WithApiKey(os.Getenv("HIBP_API_KEY")), WithRateLimitNoFail()) + if hc == nil { + t.Error("failed to create HIBP client") + return + } + + for _, tc := range testTable { + t.Run(tc.testName, func(t *testing.T) { + breachDetails, _, err := hc.BreachApi.BreachedAccount( + fmt.Sprintf("%s@hibp-integration-tests.com", tc.accountName), + WithoutTruncate()) + if err != nil && !tc.shouldFail { + t.Error(err) + return + } + + for _, b := range breachDetails { + if tc.breachName != b.Name { + t.Errorf("breach name for the account %q does not match. expected: %q, got: %q", + tc.accountName, tc.breachName, b.Name) + } + if tc.breachDomain != b.Domain { + t.Errorf("breach domain for the account %q does not match. expected: %q, got: %q", + tc.accountName, tc.breachDomain, b.Domain) + } + break + } + }) + } +} diff --git a/hibp.go b/hibp.go index 2c92609..89fec59 100644 --- a/hibp.go +++ b/hibp.go @@ -3,14 +3,16 @@ package hibp import ( "bytes" "crypto/tls" + "fmt" "io" + "log" "net/http" "net/url" "time" ) // Version represents the version of this package -const Version = "0.1.3" +const Version = "0.1.4" // BaseUrl is the base URL for the majority of API calls const BaseUrl = "https://haveibeenpwned.com/api/v3" @@ -22,10 +24,11 @@ const DefaultUserAgent = `go-hibp v` + Version // + ` - https://github.com/wnees // Client is the HIBP client object type Client struct { - hc *http.Client // HTTP client to perform the API requests - to time.Duration // HTTP client timeout - ak string // HIBP API key - ua string // User agent string for the HTTP client + hc *http.Client // HTTP client to perform the API requests + to time.Duration // HTTP client timeout + ak string // HIBP API key + ua string // User agent string for the HTTP client + rlNoFail bool // Controls wether the HTTP client should fail or sleep in case the rate limiting hits PwnedPassApi *PwnedPassApi // Reference to the PwnedPassApi API PwnedPassApiOpts *PwnedPasswordOptions // Additional options for the PwnedPassApi API @@ -94,6 +97,13 @@ func WithUserAgent(a string) Option { } } +// WithRateLimitNoFail let's the HTTP client sleep in case the API rate limiting hits (Defaults to fail) +func WithRateLimitNoFail() Option { + return func(c *Client) { + c.rlNoFail = true + } +} + // HttpReq performs an HTTP request to the corresponding API func (c *Client) HttpReq(m, p string, q map[string]string) (*http.Request, error) { u, err := url.Parse(p) @@ -136,6 +146,43 @@ func (c *Client) HttpReq(m, p string, q map[string]string) (*http.Request, error return hr, nil } +// HttpReqBody performs the API call to the given path and returns the response body as byte array +func (c *Client) HttpReqBody(m string, p string, q map[string]string) ([]byte, *http.Response, error) { + hreq, err := c.HttpReq(m, p, q) + if err != nil { + return nil, nil, err + } + hr, err := c.hc.Do(hreq) + if err != nil { + return nil, hr, err + } + defer func() { + _ = hr.Body.Close() + }() + + hb, err := io.ReadAll(hr.Body) + if err != nil { + return nil, hr, err + } + + if hr.StatusCode == 429 && c.rlNoFail { + headerDelay := hr.Header.Get("Retry-After") + delayTime, err := time.ParseDuration(headerDelay + "s") + if err != nil { + return nil, hr, err + } + log.Printf("API rate limit hit. Retrying request in %s", delayTime.String()) + time.Sleep(delayTime) + return c.HttpReqBody(m, p, q) + } + + if hr.StatusCode != 200 { + return nil, hr, fmt.Errorf("API responded with non HTTP-200: %s - %s", hr.Status, hb) + } + + return hb, hr, nil +} + // httpClient returns a custom http client for the HIBP Client object func httpClient(to time.Duration) *http.Client { tlsConfig := &tls.Config{