-
Notifications
You must be signed in to change notification settings - Fork 0
/
middleware.go
161 lines (134 loc) · 3.99 KB
/
middleware.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
package httpsling
import (
"context"
"fmt"
"io"
"net/http"
"net/http/httputil"
"os"
"github.com/theopenlane/utils/rout"
)
// Middleware is a function that decorates a Doer, returning a new Doer that
// adds behavior around the wrapped one.
type Middleware func(Doer) Doer

// Apply implements Option. It registers m by appending it to the Requester's
// Middleware slice and never fails.
func (m Middleware) Apply(r *Requester) error {
	// Registration order is preserved: m goes after any middleware already set.
	r.Middleware = append(r.Middleware, m)

	return nil
}
// Wrap decorates d with every middleware in m and returns the result.
//
// m[0] becomes the outermost layer and m[len(m)-1] the innermost, so when the
// returned Doer is invoked the middleware run in the order they were passed.
// With no middleware, d is returned as-is.
func Wrap(d Doer, m ...Middleware) Doer {
	wrapped := d

	// Apply from the last middleware to the first so the first ends up outermost.
	for i := range m {
		wrapped = m[len(m)-1-i](wrapped)
	}

	return wrapped
}
// Dump dumps requests and responses to a writer. Just intended for debugging.
func Dump(w io.Writer) Middleware {
return func(next Doer) Doer {
return DoerFunc(func(req *http.Request) (*http.Response, error) {
dump, dumperr := httputil.DumpRequestOut(req, true)
if dumperr != nil {
io.WriteString(w, "Error dumping request: "+dumperr.Error()+"\n") // nolint: errcheck
} else {
io.WriteString(w, string(dump)+"\n") // nolint: errcheck
}
resp, err := next.Do(req)
if resp != nil {
dump, dumperr = httputil.DumpResponse(resp, true)
if dumperr != nil {
io.WriteString(w, "Error dumping response: "+dumperr.Error()+"\n") // nolint: errcheck
} else {
io.WriteString(w, string(dump)+"\n") // nolint: errcheck
}
}
return resp, err
})
}
}
// DumpToStout dumps requests and responses to os.Stdout
func DumpToStout() Middleware {
return Dump(os.Stdout)
}
// DumpToStderr dumps requests and responses to os.Stderr
func DumpToStderr() Middleware {
return Dump(os.Stderr)
}
// logFunc adapts a Print-style logging function (variadic interface{}
// arguments) to io.Writer so it can serve as Dump's destination.
type logFunc func(a ...interface{})

// Write calls f exactly once, passing all of p as a single string argument.
// It always reports that len(p) bytes were written and never returns an error.
func (f logFunc) Write(p []byte) (int, error) {
	msg := string(p)
	f(msg)

	return len(p), nil
}
// DumpToLog returns Dump middleware that sends request and response dumps to
// a logging function. logf has the same shape as fmt.Print, testing.T.Log,
// and the log.Print family, so any of those can be passed directly.
//
// logf is invoked once for the request and once for the response. Each call
// receives exactly one argument: the whole dump as a single string value.
func DumpToLog(logf func(a ...interface{})) Middleware {
	sink := logFunc(logf)

	return Dump(sink)
}
// ExpectCode returns middleware that reports an error when the response's
// status code is not exactly code. The response itself, including its body,
// is still returned alongside that error.
//
// The expectation lives in a codeChecker carried on the request context. When
// several ExpectCode/ExpectSuccessCode middlewares wrap the same request, they
// share that checker. The innermost one sets its code last, just before the
// request is sent, so every layer checks against that value. Outer layers
// pass an inner layer's error through unchanged.
//
// NOTE(review): code == -1 coincides with the internal expectSuccessCode
// sentinel and is treated as "any 2xx status".
func ExpectCode(code int) Middleware {
	return func(next Doer) Doer {
		return DoerFunc(func(req *http.Request) (*http.Response, error) {
			tagged, checker := getCodeChecker(req)
			checker.code = code

			// Forward next.Do's (response, error) pair straight into the check.
			return checker.checkCode(next.Do(tagged))
		})
	}
}
// ExpectSuccessCode returns middleware that reports an error when the
// response's status code is outside 200-299 (inclusive). The response itself,
// including its body, is still returned alongside that error.
//
// It shares ExpectCode's implementation and context-stored checker. Nesting it
// with ExpectCode therefore follows the same rule: the innermost middleware's
// expectation is the one applied.
func ExpectSuccessCode() Middleware {
	// The expectSuccessCode sentinel tells checkCode to accept any 2xx status,
	// so delegating avoids keeping a second copy of ExpectCode's body in sync.
	return ExpectCode(expectSuccessCode)
}
// ctxKey is the unexported type for this package's context keys. A distinct
// type keeps them from colliding with keys defined by other packages.
type ctxKey int

const (
	// expectCodeCtxKey is the context key under which the request's
	// *codeChecker is stored.
	expectCodeCtxKey ctxKey = iota

	// expectSuccessCode is a sentinel value for codeChecker.code meaning
	// "accept any 2xx status" instead of one exact status code.
	expectSuccessCode = -1
)

// codeChecker holds the status-code expectation checked by checkCode.
type codeChecker struct {
	// code is the exact expected status, or expectSuccessCode for any 2xx.
	code int
}
// checkCode validates resp against the expectation recorded in c and returns
// resp together with the resulting error.
//
// When err is already non-nil or resp is nil, both are returned untouched.
// This means an error produced further down the chain, including one from an
// inner checker, is never replaced. Otherwise:
//   - if c.code is expectSuccessCode, any status outside 200-299 is an error;
//   - otherwise any status different from c.code is an error.
//
// A mismatch is reported as ErrUnsuccessfulResponse (wrapped with %w) and
// passed through rout.HTTPErrorResponse. The response is returned in every
// case so callers can still inspect it.
func (c *codeChecker) checkCode(resp *http.Response, err error) (*http.Response, error) {
	if err != nil || resp == nil {
		return resp, err
	}

	if c.code == expectSuccessCode {
		if resp.StatusCode >= 200 && resp.StatusCode <= 299 {
			return resp, nil
		}

		return resp, rout.HTTPErrorResponse(
			fmt.Errorf("%w: server returned unsuccessful status code: %d",
				ErrUnsuccessfulResponse,
				resp.StatusCode,
			))
	}

	if c.code == resp.StatusCode {
		return resp, nil
	}

	return resp, rout.HTTPErrorResponse(
		fmt.Errorf("%w: server returned unexpected status code. expected: %d, received: %d",
			ErrUnsuccessfulResponse,
			c.code,
			resp.StatusCode,
		))
}
// getCodeChecker returns the *codeChecker stored in req's context, creating
// one when none (or a nil one) is present.
//
// If a checker already exists, req is returned unchanged with that checker.
// Otherwise a fresh checker is attached to a new context, and the returned
// request is req.WithContext(...), a shallow copy carrying that context.
// Sharing the checker through the context lets nested ExpectCode and
// ExpectSuccessCode middlewares operate on the same expectation.
func getCodeChecker(req *http.Request) (*http.Request, *codeChecker) {
	ctx := req.Context()

	if existing, ok := ctx.Value(expectCodeCtxKey).(*codeChecker); ok && existing != nil {
		return req, existing
	}

	fresh := &codeChecker{}

	return req.WithContext(context.WithValue(ctx, expectCodeCtxKey, fresh)), fresh
}