diff --git a/api_client_iter.go b/api_client_iter.go index 20fc8ed..b4e749d 100644 --- a/api_client_iter.go +++ b/api_client_iter.go @@ -17,72 +17,61 @@ type IterResult struct { Next string } -func createPagingIterator( +func getIterResult( fetchPage func(from string) (*http.Response, error), -) iter.Seq2[*IterResult, error] { + cursor string, +) (*IterResult, error) { + response, err := fetchPage(cursor) + if err != nil { + return nil, fmt.Errorf("failed to fetch next page: %w", err) + } + + defer response.Body.Close() + + body, err := io.ReadAll(response.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + type ResponseWithNext struct { Next string `json:"next"` } + var responseWithNext ResponseWithNext + if err := json.Unmarshal(body, &responseWithNext); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } - cursor := "" + response.Body = io.NopCloser(bytes.NewReader(body)) + + iterResult := &IterResult{ + Response: response, + Next: responseWithNext.Next, + } + + return iterResult, nil +} +func createPagingIterator( + fetchPage func(from string) (*http.Response, error), +) iter.Seq2[*IterResult, error] { + cursor := "" return func(yield func(*IterResult, error) bool) { for { - // Fire the request - response, err := fetchPage(cursor) + iterResult, err := getIterResult( + fetchPage, + cursor, + ) if err != nil { yield(nil, err) return } - - // Read the body - body, err := io.ReadAll(response.Body) - if err != nil { - yield( - nil, - fmt.Errorf("failed to read response: %w", err), - ) - return - } - - // Close the body - if err := response.Body.Close(); err != nil { - yield( - nil, - fmt.Errorf("failed to close http response: %w", err), - ) + cursor = iterResult.Next + if !yield(iterResult, err) { return } - - // Look for the next token - var responseWithNext ResponseWithNext - if err := json.Unmarshal(body, &responseWithNext); err != nil { - yield( - nil, - fmt.Errorf("failed to unmarshal response: %w", err), - ) + if cursor == "" { return } - - // Replace the body and return the response - response.Body = io.NopCloser(bytes.NewReader(body)) - if !yield( - &IterResult{ - Response: response, - Next: responseWithNext.Next, - }, - nil, - ) { - return - } - - // Was this the last page? - if responseWithNext.Next == "" { - return - } - - // Set the cursor for next page - cursor = responseWithNext.Next } } } diff --git a/api_client_iter_test.go b/api_client_iter_test.go index 1c35b26..78c203a 100644 --- a/api_client_iter_test.go +++ b/api_client_iter_test.go @@ -55,6 +55,33 @@ func TestIterGet(t *testing.T) { } +func TestIterGetBadResponse(t *testing.T) { + ct := newClientTest( + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/leaksdb/sources", r.URL.Path) + w.Write([]byte(`{"next": 11, "items": []}`)) + }), + ) + defer ct.Close() + + lastPageIndex := 0 + + for result, err := range ct.apiClient.IterGet( + "/leaksdb/sources", + nil, + ) { + lastPageIndex = lastPageIndex + 1 + if lastPageIndex > 2 { + // We are going crazy here... + break + } + assert.ErrorContains(t, err, "failed to unmarshal", "Bad next token should trigger an error") + assert.Nil(t, result, "bad response should not contain a result") + } + + assert.Equal(t, 1, lastPageIndex, "Didn't get the expected number of pages") +} + func TestIterPostJson(t *testing.T) { ct := newClientTest( http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {