diff --git a/hcloud/exp/mockutil/http.go b/hcloud/exp/mockutil/http.go index 15b3959f..03aac1da 100644 --- a/hcloud/exp/mockutil/http.go +++ b/hcloud/exp/mockutil/http.go @@ -27,70 +27,97 @@ type Request struct { JSONRaw string } -// Handler is used with a [httptest.Server] to mock http requests provided by the user. -// -// Request matching is based on the request count, and the user provided request will be -// iterated over. +// Handler is using a [Server] to mock http requests provided by the user. func Handler(t *testing.T, requests []Request) http.HandlerFunc { t.Helper() - index := 0 + server := NewServer(t, requests) + t.Cleanup(server.close) - t.Cleanup(func() { - assert.EqualValues(t, len(requests), index, "expected more calls") - }) + return server.handler +} - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if testing.Verbose() { - t.Logf("call %d: %s %s\n", index, r.Method, r.RequestURI) - } +// NewServer returns a new mock server that closes itself at the end of the test. +func NewServer(t *testing.T, requests []Request) *Server { + t.Helper() - if index >= len(requests) { - t.Fatalf("received unknown call %d", index) - } + o := &Server{t: t} + o.Server = httptest.NewServer(http.HandlerFunc(o.handler)) + t.Cleanup(o.close) - expected := requests[index] + o.Expect(requests) - expectedCall := expected.Method - foundCall := r.Method - if expected.Path != "" { - expectedCall += " " + expected.Path - foundCall += " " + r.RequestURI - } - require.Equal(t, expectedCall, foundCall) + return o +} - if expected.Want != nil { - expected.Want(t, r) - } +// Server embeds a [httptest.Server] that answers HTTP calls with a list of expected [Request]. +// +// Request matching is based on the request count, and the user provided request will be +// iterated over. +// +// A Server must be created using the [NewServer] function. +type Server struct { + *httptest.Server - switch { - case expected.JSON != nil: - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(expected.Status) - if err := json.NewEncoder(w).Encode(expected.JSON); err != nil { - t.Fatal(err) - } - case expected.JSONRaw != "": - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(expected.Status) - _, err := w.Write([]byte(expected.JSONRaw)) - if err != nil { - t.Fatal(err) - } - default: - w.WriteHeader(expected.Status) - } + t *testing.T - index++ - }) + requests []Request + index int } -// Server is a [httptest.Server] wrapping a [Handler] that closes itself at the end of the test. -func Server(t *testing.T, requests []Request) *httptest.Server { - t.Helper() +// Expect adds requests to the list of requests expected by the [Server]. +func (m *Server) Expect(requests []Request) { + m.requests = append(m.requests, requests...) +} + +func (m *Server) close() { + m.t.Helper() - server := httptest.NewServer(Handler(t, requests)) - t.Cleanup(server.Close) + m.Server.Close() + + assert.EqualValues(m.t, len(m.requests), m.index, "expected more calls") +} + +func (m *Server) handler(w http.ResponseWriter, r *http.Request) { + if testing.Verbose() { + m.t.Logf("call %d: %s %s\n", m.index, r.Method, r.RequestURI) + } + + if m.index >= len(m.requests) { + m.t.Fatalf("received unknown call %d", m.index) + } + + expected := m.requests[m.index] + + expectedCall := expected.Method + foundCall := r.Method + if expected.Path != "" { + expectedCall += " " + expected.Path + foundCall += " " + r.RequestURI + } + require.Equal(m.t, expectedCall, foundCall) + + if expected.Want != nil { + expected.Want(m.t, r) + } + + switch { + case expected.JSON != nil: + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(expected.Status) + if err := json.NewEncoder(w).Encode(expected.JSON); err != nil { + m.t.Fatal(err) + } + case expected.JSONRaw != "": + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(expected.Status) + _, err := w.Write([]byte(expected.JSONRaw)) + if err != nil { + m.t.Fatal(err) + } + default: + w.WriteHeader(expected.Status) + } - return server + m.index++ } diff --git a/hcloud/exp/mockutil/http_test.go b/hcloud/exp/mockutil/http_test.go index 2e027e6c..c6bce052 100644 --- a/hcloud/exp/mockutil/http_test.go +++ b/hcloud/exp/mockutil/http_test.go @@ -13,7 +13,7 @@ import ( ) func TestHandler(t *testing.T) { - server := Server(t, []Request{ + server := NewServer(t, []Request{ { Method: "GET", Path: "/", Status: 200, @@ -68,6 +68,17 @@ func TestHandler(t *testing.T) { assert.Equal(t, 200, resp.StatusCode) assert.Equal(t, "", resp.Header.Get("Content-Type")) assert.Equal(t, "", readBody(t, resp)) + + // Extra request 5 + server.Expect([]Request{ + {Method: "GET", Path: "/", Status: 200}, + }) + + resp, err = http.Get(server.URL) + require.NoError(t, err) + assert.Equal(t, 200, resp.StatusCode) + assert.Equal(t, "", resp.Header.Get("Content-Type")) + assert.Equal(t, "", readBody(t, resp)) } func readBody(t *testing.T, resp *http.Response) string {