-
Notifications
You must be signed in to change notification settings - Fork 0
/
ratelimit.go
150 lines (122 loc) · 3.46 KB
/
ratelimit.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
package ratelimit
import (
	"net"
	"net/http"
	"strconv"
	"strings"
	"sync"
	"time"

	bucket "github.com/DavidCai1993/token-bucket"
	"github.com/go-http-utils/headers"
)
// Version is the version number of this package.
const Version = "0.3.0"
// GetIDFunc maps an incoming request to the ID of the source it comes from.
// Requests that map to the same ID are treated as coming from one source and
// share a single rate limit.
type GetIDFunc func(req *http.Request) string
// defaultGetIDFunc identifies the source of req by IP address. It prefers the
// first entry of the X-Forwarded-For header, then X-Real-IP, then the host
// part of req.RemoteAddr. A header value that is not a valid IP is skipped
// rather than trusted.
//
// NOTE(review): X-Forwarded-For and X-Real-IP are set by the client unless a
// trusted proxy overwrites them. Without such a proxy a client can pick its
// own ID and evade the limit; supply Options.GetID in that deployment.
func defaultGetIDFunc(req *http.Request) string {
	// X-Forwarded-For may be a list "client, proxy1, proxy2"; the originating
	// client is the first entry. Parsing the whole list always failed before,
	// collapsing every proxied client into the single ID "<nil>".
	if xff := req.Header.Get(headers.XForwardedFor); xff != "" {
		first := strings.TrimSpace(strings.SplitN(xff, ",", 2)[0])
		if ip := net.ParseIP(first); ip != nil {
			return ip.String()
		}
	}
	if xri := req.Header.Get(headers.XRealIP); xri != "" {
		if ip := net.ParseIP(strings.TrimSpace(xri)); ip != nil {
			return ip.String()
		}
	}
	host, _, err := net.SplitHostPort(req.RemoteAddr)
	if err != nil {
		// RemoteAddr carries no port; use it as-is instead of an empty host.
		host = req.RemoteAddr
	}
	if ip := net.ParseIP(host); ip != nil {
		return ip.String()
	}
	// Not an IP address: use it verbatim rather than mapping every such
	// source to the shared string "<nil>".
	return host
}
// expireMap records, per request ID, the instant after which that ID's entry
// is considered expired.
type expireMap map[string]time.Time

// checkIfExpired reports whether id has a recorded expiry that is already in
// the past. An id with no recorded expiry is never reported as expired.
func (em expireMap) checkIfExpired(id string) bool {
	deadline, found := em[id]
	return found && time.Now().After(deadline)
}

// getOneExpiredID returns some ID whose expiry is in the past and true, or
// "" and false if no entry has expired. Which expired ID is returned follows
// map iteration order and is therefore unspecified.
func (em expireMap) getOneExpiredID() (string, bool) {
	current := time.Now()
	for key, deadline := range em {
		if !current.After(deadline) {
			continue
		}
		return key, true
	}
	return "", false
}
// Options configures the rate-limiting middleware returned by Handler.
type Options struct {
	// GetID returns the ID of the source a request comes from. All requests
	// sharing an ID are rate-limited together. If nil, Handler uses a default
	// that keys on the X-Forwarded-For header, then X-Real-IP, then the host
	// part of the request's RemoteAddr.
	GetID GetIDFunc
	// Duration is the time window of the limit: only Count requests are meant
	// to pass through per Duration. If zero, it defaults to 1 minute.
	Duration time.Duration
	// Count is the number of requests allowed per Duration. If zero, it
	// defaults to 1000.
	Count int64
}
// Handler wraps h with rate limiting: only opts.Count requests per
// opts.Duration may pass through for each ID returned by opts.GetID.
//
// Each ID gets a token bucket of capacity Count that is refilled with one
// token every Duration/Count. Allowed requests carry X-RateLimit-Limit and
// X-RateLimit-Remaining headers. Rejected requests receive
// 429 Too Many Requests with a Retry-After header (whole seconds, rounded up).
//
// Each call starts a background goroutine that evicts buckets idle for
// longer than Duration. That goroutine never exits, so Handler is meant to be
// called once per wrapped handler, not per request.
func Handler(h http.Handler, opts Options) http.Handler {
	if opts.GetID == nil {
		opts.GetID = defaultGetIDFunc
	}
	// Non-positive values used to panic or disable expiry; treat them as unset.
	if opts.Duration <= 0 {
		opts.Duration = time.Minute
	}
	if opts.Count <= 0 {
		opts.Count = 1000
	}

	// One token per Duration/Count sustains Count requests per Duration.
	// (The previous formula, Count/seconds*1e9, was inverted, divided by zero
	// for sub-second durations and produced a zero interval that bucket.New
	// cannot use.) Clamp to 1ns when Count exceeds Duration in nanoseconds.
	interval := opts.Duration / time.Duration(opts.Count)
	if interval <= 0 {
		interval = 1
	}
	// Retry-After is in whole seconds; round up so a pending token is never
	// advertised as "retry after 0 seconds".
	retryAfter := strconv.FormatInt(int64((interval+time.Second-1)/time.Second), 10)

	var mutex sync.Mutex
	buckets := map[string]*bucket.TokenBucket{}
	expires := expireMap{}

	go sweepExpiredBuckets(&mutex, buckets, expires, opts.Duration)

	return http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
		id := opts.GetID(req)

		mutex.Lock()
		b, ok := buckets[id]
		if ok && expires.checkIfExpired(id) {
			// Idle bucket the sweeper has not reached yet: stop its refill
			// goroutine before replacing it, otherwise it leaks.
			b.Destory()
			ok = false
		}
		if !ok {
			b = bucket.New(interval, opts.Count)
			buckets[id] = b
		}
		allowed := b.TryTake(1)
		avail := b.Availible()
		capacity := b.Capability()
		expires[id] = time.Now().Add(opts.Duration)
		mutex.Unlock()

		if !allowed {
			res.Header().Set(headers.RetryAfter, retryAfter)
			res.WriteHeader(http.StatusTooManyRequests)
			res.Write([]byte(http.StatusText(http.StatusTooManyRequests)))
			return
		}

		resHeader := res.Header()
		resHeader.Set(headers.XRatelimitLimit, strconv.FormatInt(capacity, 10))
		resHeader.Set(headers.XRatelimitRemaining, strconv.FormatInt(avail, 10))
		h.ServeHTTP(res, req)
	})
}

// sweepExpiredBuckets runs forever. Every d it destroys and forgets the
// buckets whose expiry in expires has passed. It holds mu while sweeping,
// which blocks request handling. Each pass therefore handles at most
// maxPerSweep buckets and stops once maxSweepTime has elapsed.
func sweepExpiredBuckets(mu *sync.Mutex, buckets map[string]*bucket.TokenBucket, expires expireMap, d time.Duration) {
	const (
		maxPerSweep  = 1000
		maxSweepTime = time.Second
	)
	for range time.Tick(d) {
		mu.Lock()
		// Measure the budget from lock acquisition: it bounds how long
		// requests wait on mu. The old check compared the fixed tick time
		// with its own deadline and so never fired.
		deadline := time.Now().Add(maxSweepTime)
		for n := 0; n < maxPerSweep && time.Now().Before(deadline); n++ {
			id, ok := expires.getOneExpiredID()
			if !ok {
				break
			}
			delete(expires, id)
			if b, found := buckets[id]; found {
				b.Destory()
				delete(buckets, id)
			}
		}
		mu.Unlock()
	}
}