Skip to content

Commit e2c4fd2

Browse files
authored
feat: rate limiting (#36)
* feat(ratelimit): add rate limiting module * feat(ratelimit): setup rate limiter * fixup! feat(ratelimit): setup rate limiter * fixup! fixup! feat(ratelimit): setup rate limiter * fixup! feat(ratelimit): setup rate limiter * fix(e2e): set ratelimiter config in tests so they would pass * refactor(ratelimit): remove unused code * fix(ratelimit): now the middleware shouldn't panic * fix(ratelimit): actually handle if user isnt found * refactor(ratelimit): use rw mutex * chore: update .env.example
1 parent 47d33af commit e2c4fd2

File tree

7 files changed

+178
-8
lines changed

7 files changed

+178
-8
lines changed

.env.example

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,7 @@ MAILGUN_FROM=onasty@mail.com
2424
MAILGUN_DOMAI='<domain>'
2525
MAILGUN_API_KEY='<token>'
2626
VERIFICATION_TOKEN_TTL=48h
27+
28+
RATELIMITER_RPS=100
29+
RATELIMITER_BURST=10
30+
RATELIMITER_TTL=3m

cmd/server/main.go

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525
"github.com/olexsmir/onasty/internal/store/psqlutil"
2626
httptransport "github.com/olexsmir/onasty/internal/transport/http"
2727
"github.com/olexsmir/onasty/internal/transport/http/httpserver"
28+
"github.com/olexsmir/onasty/internal/transport/http/ratelimit"
2829
)
2930

3031
func main() {
@@ -83,7 +84,17 @@ func run(ctx context.Context) error {
8384
noterepo := noterepo.New(psqlDB)
8485
notesrv := notesrv.New(noterepo)
8586

86-
handler := httptransport.NewTransport(usersrv, notesrv)
87+
rateLimiterConfig := ratelimit.Config{
88+
RPS: cfg.RateLimiterRPS,
89+
TTL: cfg.RateLimiterTTL,
90+
Burst: cfg.RateLimiterBurst,
91+
}
92+
93+
handler := httptransport.NewTransport(
94+
usersrv,
95+
notesrv,
96+
rateLimiterConfig,
97+
)
8798

8899
// http server
89100
srv := httpserver.NewServer(cfg.ServerPort, handler.Handler())

e2e/e2e_test.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import (
2626
"github.com/olexsmir/onasty/internal/store/psql/vertokrepo"
2727
"github.com/olexsmir/onasty/internal/store/psqlutil"
2828
httptransport "github.com/olexsmir/onasty/internal/transport/http"
29+
"github.com/olexsmir/onasty/internal/transport/http/ratelimit"
2930
"github.com/stretchr/testify/require"
3031
"github.com/stretchr/testify/suite"
3132
"github.com/testcontainers/testcontainers-go"
@@ -117,7 +118,14 @@ func (e *AppTestSuite) initDeps() {
117118
noterepo := noterepo.New(e.postgresDB)
118119
notesrv := notesrv.New(noterepo)
119120

120-
handler := httptransport.NewTransport(usersrv, notesrv)
121+
// for testing purposes, it's ok to have high values ig
122+
ratelimitCfg := ratelimit.Config{
123+
RPS: 1000,
124+
TTL: time.Millisecond,
125+
Burst: 1000,
126+
}
127+
128+
handler := httptransport.NewTransport(usersrv, notesrv, ratelimitCfg)
121129
e.router = handler.Handler()
122130
}
123131

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ require (
1515
github.com/stretchr/testify v1.9.0
1616
github.com/testcontainers/testcontainers-go v0.33.0
1717
github.com/testcontainers/testcontainers-go/modules/postgres v0.33.0
18+
golang.org/x/time v0.5.0
1819
)
1920

2021
require (

internal/config/config.go

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package config
33
import (
44
"errors"
55
"os"
6+
"strconv"
67
"time"
78
)
89

@@ -28,6 +29,10 @@ type Config struct {
2829
LogLevel string
2930
LogFormat string
3031
LogShowLine bool
32+
33+
RateLimiterRPS int
34+
RateLimiterBurst int
35+
RateLimiterTTL time.Duration
3136
}
3237

3338
func NewConfig() *Config {
@@ -39,17 +44,17 @@ func NewConfig() *Config {
3944
PasswordSalt: getenvOrDefault("PASSWORD_SALT", ""),
4045

4146
JwtSigningKey: getenvOrDefault("JWT_SIGNING_KEY", ""),
42-
JwtAccessTokenTTL: mustParseDurationOrPanic(
47+
JwtAccessTokenTTL: mustParseDuration(
4348
getenvOrDefault("JWT_ACCESS_TOKEN_TTL", "15m"),
4449
),
45-
JwtRefreshTokenTTL: mustParseDurationOrPanic(
50+
JwtRefreshTokenTTL: mustParseDuration(
4651
getenvOrDefault("JWT_REFRESH_TOKEN_TTL", "24h"),
4752
),
4853

4954
MailgunFrom: getenvOrDefault("MAILGUN_FROM", ""),
5055
MailgunDomain: getenvOrDefault("MAILGUN_DOMAIN", ""),
5156
MailgunAPIKey: getenvOrDefault("MAILGUN_API_KEY", ""),
52-
VerificationTokenTTL: mustParseDurationOrPanic(
57+
VerificationTokenTTL: mustParseDuration(
5358
getenvOrDefault("VERIFICATION_TOKEN_TTL", "24h"),
5459
),
5560

@@ -59,6 +64,10 @@ func NewConfig() *Config {
5964
LogLevel: getenvOrDefault("LOG_LEVEL", "debug"),
6065
LogFormat: getenvOrDefault("LOG_FORMAT", "json"),
6166
LogShowLine: getenvOrDefault("LOG_SHOW_LINE", "true") == "true",
67+
68+
RateLimiterRPS: mustGetenvOrDefaultInt("RATELIMITER_RPS", 100),
69+
RateLimiterBurst: mustGetenvOrDefaultInt("RATELIMITER_BURST", 10),
70+
RateLimiterTTL: mustParseDuration(getenvOrDefault("RATELIMITER_TTL", "1m")),
6271
}
6372
}
6473

@@ -73,7 +82,18 @@ func getenvOrDefault(key, def string) string {
7382
return def
7483
}
7584

76-
func mustParseDurationOrPanic(dur string) time.Duration {
85+
func mustGetenvOrDefaultInt(key string, def int) int {
86+
if v, ok := os.LookupEnv(key); ok {
87+
r, err := strconv.Atoi(v)
88+
if err != nil {
89+
panic(err)
90+
}
91+
return r
92+
}
93+
return def
94+
}
95+
96+
func mustParseDuration(dur string) time.Duration {
7797
d, err := time.ParseDuration(dur)
7898
if err != nil {
7999
panic(errors.Join(errors.New("cannot time.ParseDuration"), err)) //nolint:err113

internal/transport/http/http.go

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,26 @@ import (
77
"github.com/olexsmir/onasty/internal/service/notesrv"
88
"github.com/olexsmir/onasty/internal/service/usersrv"
99
"github.com/olexsmir/onasty/internal/transport/http/apiv1"
10+
"github.com/olexsmir/onasty/internal/transport/http/ratelimit"
1011
"github.com/olexsmir/onasty/internal/transport/http/reqid"
1112
)
1213

1314
type Transport struct {
1415
usersrv usersrv.UserServicer
1516
notesrv notesrv.NoteServicer
17+
18+
ratelimitCfg ratelimit.Config
1619
}
1720

1821
func NewTransport(
1922
us usersrv.UserServicer,
2023
ns notesrv.NoteServicer,
24+
ratelimitCfg ratelimit.Config,
2125
) *Transport {
2226
return &Transport{
23-
usersrv: us,
24-
notesrv: ns,
27+
usersrv: us,
28+
notesrv: ns,
29+
ratelimitCfg: ratelimitCfg,
2530
}
2631
}
2732

@@ -31,6 +36,7 @@ func (t *Transport) Handler() http.Handler {
3136
gin.Recovery(),
3237
reqid.Middleware(),
3338
t.logger(),
39+
ratelimit.MiddlewareWithConfig(t.ratelimitCfg),
3440
)
3541

3642
api := r.Group("/api")
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
// thanks to https://www.alexedwards.net/blog/how-to-rate-limit-http-requests
2+
3+
package ratelimit
4+
5+
import (
6+
"net/http"
7+
"sync"
8+
"time"
9+
10+
"github.com/gin-gonic/gin"
11+
"golang.org/x/time/rate"
12+
)
13+
14+
type (
15+
rateLimiter struct {
16+
mu sync.RWMutex
17+
18+
visitors map[visitorIP]*visitor
19+
20+
// limit is the maximum number of requests per second
21+
limit rate.Limit
22+
23+
// ttl is the time after which a visitor is forgotten
24+
ttl time.Duration
25+
26+
// burst is the maximum number of requests that can be made in a short amount of time
27+
burst int
28+
}
29+
30+
visitorIP string
31+
visitor struct {
32+
limiter *rate.Limiter
33+
lastSeen time.Time
34+
}
35+
)
36+
37+
func newLimiter(rps, burst int, ttl time.Duration) *rateLimiter {
38+
return &rateLimiter{ //nolint:exhaustruct
39+
visitors: make(map[visitorIP]*visitor),
40+
limit: rate.Limit(rps),
41+
burst: burst,
42+
ttl: ttl,
43+
}
44+
}
45+
46+
// Retrieve and return the rate limiter for the current visitor if it
47+
// already exists. Otherwise create a new rate limiter and add it to
48+
// the visitors map, using the IP address as the key.
49+
func (r *rateLimiter) getVisitor(ip visitorIP) *rate.Limiter {
50+
r.mu.RLock()
51+
v, exists := r.visitors[ip]
52+
r.mu.RUnlock()
53+
54+
if !exists {
55+
limit := rate.NewLimiter(r.limit, r.burst)
56+
57+
r.mu.Lock()
58+
r.visitors[ip] = &visitor{
59+
limiter: limit,
60+
lastSeen: time.Now(),
61+
}
62+
r.mu.Unlock()
63+
64+
return limit
65+
}
66+
67+
r.mu.Lock()
68+
v.lastSeen = time.Now()
69+
r.mu.Unlock()
70+
71+
return v.limiter
72+
}
73+
74+
// Every minute check the map for visitors that haven't been seen for
75+
// more than 3 minutes and delete the entries.
76+
func (r *rateLimiter) cleanupVisitors() {
77+
for {
78+
time.Sleep(time.Minute)
79+
80+
r.mu.Lock()
81+
for ip, v := range r.visitors {
82+
if time.Since(v.lastSeen) > r.ttl {
83+
delete(r.visitors, ip)
84+
}
85+
}
86+
r.mu.Unlock()
87+
}
88+
}
89+
90+
type Config struct {
91+
// RPS is the maximum number of requests per second
92+
RPS int
93+
94+
// TTL is the time after which a visitor is forgotten
95+
TTL time.Duration
96+
97+
// Burst is the maximum number of requests that can be made in a short amount of time
98+
Burst int
99+
}
100+
101+
// MiddlewareWithConfig returns a new rate limiting middleware with the given config
102+
func MiddlewareWithConfig(c Config) gin.HandlerFunc {
103+
lmt := newLimiter(c.RPS, c.Burst, c.TTL)
104+
go lmt.cleanupVisitors()
105+
106+
return func(c *gin.Context) {
107+
visitor := lmt.getVisitor(visitorIP(c.ClientIP()))
108+
if visitor == nil {
109+
c.AbortWithStatus(http.StatusInternalServerError)
110+
return
111+
}
112+
113+
if !visitor.Allow() {
114+
c.AbortWithStatus(http.StatusTooManyRequests)
115+
return
116+
}
117+
118+
c.Next()
119+
}
120+
}

0 commit comments

Comments
 (0)