diff --git a/.gitignore b/.gitignore index 2016845b..28e7880c 100644 --- a/.gitignore +++ b/.gitignore @@ -33,3 +33,4 @@ _testmain.go .vscode/ cmd/gost/gost +.idea \ No newline at end of file diff --git a/cmd/gost/cfg.go b/cmd/gost/cfg.go index df0c53ce..8b2912a8 100644 --- a/cmd/gost/cfg.go +++ b/cmd/gost/cfg.go @@ -148,6 +148,26 @@ func parseAuthenticator(s string) (gost.Authenticator, error) { return au, nil } +func parseLimiter(s string) (gost.Limiter, error) { + if s == "" { + return nil, nil + } + f, err := os.Open(s) + if err != nil { + return nil, err + } + defer f.Close() + + l, _ := gost.NewLocalLimiter("", "") + err = l.Reload(f) + if err != nil { + return nil, err + } + go gost.PeriodReload(l, s) + + return l, nil +} + func parseIP(s string, port string) (ips []string) { if s == "" { return diff --git a/cmd/gost/route.go b/cmd/gost/route.go index 360bc2d1..9470041e 100644 --- a/cmd/gost/route.go +++ b/cmd/gost/route.go @@ -386,6 +386,19 @@ func (r *route) GenRouters() ([]router, error) { node.User = users[0] } } + + //init rate limiter + limiterHandler, err := parseLimiter(node.Get("secrets")) + if err != nil { + return nil, err + } + if limiterHandler == nil && strings.TrimSpace(node.Get("limiter")) != "" && node.User != nil { + limiterHandler, err = gost.NewLocalLimiter(node.User.Username(), strings.TrimSpace(node.Get("limiter"))) + if err != nil { + return nil, err + } + } + certFile, keyFile := node.Get("cert"), node.Get("key") tlsCfg, err := tlsConfig(certFile, keyFile, node.Get("ca")) if err != nil && certFile != "" && keyFile != "" { @@ -671,6 +684,7 @@ func (r *route) GenRouters() ([]router, error) { gost.IPRoutesHandlerOption(tunRoutes...), gost.ProxyAgentHandlerOption(node.Get("proxyAgent")), gost.HTTPTunnelHandlerOption(node.GetBool("httpTunnel")), + gost.LimiterHandlerOption(limiterHandler), ) rt := router{ diff --git a/handler.go b/handler.go index ee82cea2..105d0300 100644 --- a/handler.go +++ b/handler.go @@ -44,6 +44,7 @@ type HandlerOptions struct { IPRoutes []IPRoute ProxyAgent string HTTPTunnel bool + Limiter Limiter } // HandlerOption allows a common way to set handler options. @@ -87,6 +88,13 @@ func AuthenticatorHandlerOption(au Authenticator) HandlerOption { } } +// LimiterHandlerOption sets the Rate limiter option of HandlerOptions +func LimiterHandlerOption(l Limiter) HandlerOption { + return func(opts *HandlerOptions) { + opts.Limiter = l + } +} + // TLSConfigHandlerOption sets the TLSConfig option of HandlerOptions. func TLSConfigHandlerOption(config *tls.Config) HandlerOption { return func(opts *HandlerOptions) { diff --git a/http.go b/http.go index 8f9e3fde..846f7ce2 100644 --- a/http.go +++ b/http.go @@ -212,7 +212,23 @@ func (h *httpHandler) handleRequest(conn net.Conn, req *http.Request) { if !h.authenticate(conn, req, resp) { return } + user, _, _ := basicProxyAuth(req.Header.Get("Proxy-Authorization")) + if h.options.Limiter != nil { + done, ok := h.options.Limiter.CheckRate(user, true) + if !ok { + resp.StatusCode = http.StatusTooManyRequests + if Debug { + dump, _ := httputil.DumpResponse(resp, false) + log.Logf("[http] %s <- %s rate limiter \n%s", conn.RemoteAddr(), conn.LocalAddr(), string(dump)) + } + + resp.Write(conn) + return + } else { + defer done() + } + } if req.Method == "PRI" || (req.Method != http.MethodConnect && req.URL.Scheme != "http") { resp.StatusCode = http.StatusBadRequest diff --git a/http2.go b/http2.go index de152eae..8ba40e0a 100644 --- a/http2.go +++ b/http2.go @@ -394,7 +394,18 @@ func (h *http2Handler) roundTrip(w http.ResponseWriter, r *http.Request) { if !h.authenticate(w, r, resp) { return } - + user, _, _ := basicProxyAuth(r.Header.Get("Proxy-Authorization")) + if h.options.Limiter != nil { + done, ok := h.options.Limiter.CheckRate(user, true) + if !ok { + log.Logf("[http2] %s - %s rate limiter %s, user is %s", + r.RemoteAddr, laddr, host, user) + w.WriteHeader(http.StatusTooManyRequests) + return + } else { + defer done() + } + } // delete the proxy related headers. r.Header.Del("Proxy-Authorization") r.Header.Del("Proxy-Connection") diff --git a/limiter.go b/limiter.go new file mode 100644 index 00000000..754aa051 --- /dev/null +++ b/limiter.go @@ -0,0 +1,259 @@ +package gost + +import ( + "bufio" + "errors" + "io" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" +) + +type Limiter interface { + CheckRate(key string, checkConcurrent bool) (func(), bool) +} + +func NewLocalLimiter(user string, cfg string) (*LocalLimiter, error) { + limiter := LocalLimiter{ + buckets: map[string]*limiterBucket{}, + concurrent: map[string]chan bool{}, + stopped: make(chan struct{}), + } + if cfg == "" || user == "" { + return &limiter, nil + } + if err := limiter.AddRule(user, cfg); err != nil { + return nil, err + } + return &limiter, nil +} + +// Token Bucket +type limiterBucket struct { + max int64 + cur int64 + duration int64 + batch int64 +} + +type LocalLimiter struct { + buckets map[string]*limiterBucket + concurrent map[string]chan bool + mux sync.RWMutex + stopped chan struct{} + period time.Duration +} + +func (l *LocalLimiter) CheckRate(key string, checkConcurrent bool) (func(), bool) { + if checkConcurrent { + done, ok := l.checkConcurrent(key) + if !ok { + return nil, false + } + if t := l.getToken(key); !t { + done() + return nil, false + } + return done, true + } else { + if t := l.getToken(key); !t { + return nil, false + } + return nil, true + } +} + +func (l *LocalLimiter) AddRule(user string, cfg string) error { + if user == "" { + return nil + } + if cfg == "" { + //reload need check old limit exists + if _, ok := l.buckets[user]; ok { + delete(l.buckets, user) + } + if _, ok := l.concurrent[user]; ok { + delete(l.concurrent, user) + } + return nil + } + args := strings.Split(cfg, ",") + if len(args) < 2 || len(args) > 3 { + return errors.New("parse limiter fail:" + cfg) + } + if len(args) == 2 { + args = append(args, "0") + } + + duration, e1 := strconv.ParseInt(strings.TrimSpace(args[0]), 10, 64) + count, e2 := strconv.ParseInt(strings.TrimSpace(args[1]), 10, 64) + cur, e3 := strconv.ParseInt(strings.TrimSpace(args[2]), 10, 64) + if e1 != nil || e2 != nil || e3 != nil { + return errors.New("parse limiter fail:" + cfg) + } + // 0 means not limit + if duration > 0 && count > 0 { + bu := &limiterBucket{ + cur: count * 10, + max: count * 10, + duration: duration * 100, + batch: count, + } + go func() { + for { + time.Sleep(time.Millisecond * time.Duration(bu.duration)) + if bu.cur+bu.batch > bu.max { + bu.cur = bu.max + } else { + atomic.AddInt64(&bu.cur, bu.batch) + } + } + }() + l.buckets[user] = bu + } else { + if _, ok := l.buckets[user]; ok { + delete(l.buckets, user) + } + } + // zero means not limit + if cur > 0 { + l.concurrent[user] = make(chan bool, cur) + } else { + if _, ok := l.concurrent[user]; ok { + delete(l.concurrent, user) + } + } + return nil +} + +// Reload parses config from r, then live reloads the LocalLimiter. +func (l *LocalLimiter) Reload(r io.Reader) error { + var period time.Duration + kvs := make(map[string]string) + + if r == nil || l.Stopped() { + return nil + } + + // splitLine splits a line text by white space. + // A line started with '#' will be ignored, otherwise it is valid. + split := func(line string) []string { + if line == "" { + return nil + } + line = strings.Replace(line, "\t", " ", -1) + line = strings.TrimSpace(line) + + if strings.IndexByte(line, '#') == 0 { + return nil + } + + var ss []string + for _, s := range strings.Split(line, " ") { + if s = strings.TrimSpace(s); s != "" { + ss = append(ss, s) + } + } + return ss + } + + scanner := bufio.NewScanner(r) + for scanner.Scan() { + line := scanner.Text() + ss := split(line) + if len(ss) == 0 { + continue + } + + switch ss[0] { + case "reload": // reload option + if len(ss) > 1 { + period, _ = time.ParseDuration(ss[1]) + } + default: + var k, v string + k = ss[0] + if len(ss) > 2 { + v = ss[2] + } + kvs[k] = v + } + } + + if err := scanner.Err(); err != nil { + return err + } + + l.mux.Lock() + defer l.mux.Unlock() + + l.period = period + for user, args := range kvs { + err := l.AddRule(user, args) + if err != nil { + return err + } + } + + return nil +} + +// Period returns the reload period. +func (l *LocalLimiter) Period() time.Duration { + if l.Stopped() { + return -1 + } + + l.mux.RLock() + defer l.mux.RUnlock() + + return l.period +} + +// Stop stops reloading. +func (l *LocalLimiter) Stop() { + select { + case <-l.stopped: + default: + close(l.stopped) + } +} + +// Stopped checks whether the reloader is stopped. +func (l *LocalLimiter) Stopped() bool { + select { + case <-l.stopped: + return true + default: + return false + } +} + +func (l *LocalLimiter) getToken(key string) bool { + b, ok := l.buckets[key] + if !ok || b == nil { + return true + } + if b.cur <= 0 { + return false + } + atomic.AddInt64(&b.cur, -10) + return true +} + +func (l *LocalLimiter) checkConcurrent(key string) (func(), bool) { + c, ok := l.concurrent[key] + if !ok || c == nil { + return func() {}, true + } + select { + case c <- true: + return func() { + <-c + }, true + default: + return nil, false + } +} diff --git a/limiter_test.go b/limiter_test.go new file mode 100644 index 00000000..d491352f --- /dev/null +++ b/limiter_test.go @@ -0,0 +1,69 @@ +package gost + +import ( + "fmt" + "testing" +) + +func TestNewLocalLimiter(t *testing.T) { + items := []struct { + user string + args string + success bool + }{ + {"admin", "10,1", true}, + {"admin", "", true}, + {"admin", "10,1,1", true}, + {"admin", "10", false}, + {"admin", "0,1", true}, + {"admin", "0,1,1", true}, + {"admin", "a,b", false}, + {"", "", true}, + {"", "1,2", true}, + } + for i, item := range items { + item := item + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + _, err := NewLocalLimiter(item.user, item.args) + if (err == nil) != item.success { + t.Error("test NewLocalLimiter fail", item.user, item.args) + } + }) + } +} + +func TestCheckRate(t *testing.T) { + items := []struct { + user string + args string + testUser string + checkCount int + shouldSuccessCount int + }{ + {"admin", "10,3", "admin", 10, 3}, + {"admin", "10,3,0", "admin", 10, 3}, + {"admin", "10,3,2", "admin", 10, 2}, + {"admin", "0,0", "admin", 10, 10}, + {"admin", "10,3,5", "admin", 10, 3}, + {"admin", "10,3,5", "admin22", 10, 10}, + {"admin", "0,0,5", "admin", 10, 5}, + } + for i, item := range items { + item := item + t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { + l, err := NewLocalLimiter(item.user, item.args) + if err != nil { + t.Error("test NewLocalLimiter fail", item.user, item.args) + } + successCount := 0 + for j := 0; j < item.checkCount; j++ { + if _, ok := l.CheckRate(item.testUser, true); ok { + successCount++ + } + } + if successCount != item.shouldSuccessCount { + t.Error("test localLimiter fail", item) + } + }) + } +} diff --git a/relay.go b/relay.go index 74423f45..103cac3c 100644 --- a/relay.go +++ b/relay.go @@ -171,6 +171,17 @@ func (h *relayHandler) Handle(conn net.Conn) { log.Logf("[relay] %s -> %s : %s unauthorized", conn.RemoteAddr(), conn.LocalAddr(), user) return } + if h.options.Limiter != nil { + done, ok := h.options.Limiter.CheckRate(user, true) + if !ok { + resp.Status = relay.StatusForbidden + resp.WriteTo(conn) + log.Logf("[relay] %s -> %s : %s rate limiter", conn.RemoteAddr(), conn.LocalAddr(), user) + return + } else { + defer done() + } + } if raddr != "" { if len(h.group.Nodes()) > 0 { diff --git a/socks.go b/socks.go index d59dd89e..5554a724 100644 --- a/socks.go +++ b/socks.go @@ -112,6 +112,7 @@ type serverSelector struct { // Users []*url.Userinfo Authenticator Authenticator TLSConfig *tls.Config + Limiter Limiter } func (selector *serverSelector) Methods() []uint8 { @@ -181,7 +182,14 @@ func (selector *serverSelector) OnSelected(method uint8, conn net.Conn) (net.Con log.Logf("[socks5] %s - %s: proxy authentication required", conn.RemoteAddr(), conn.LocalAddr()) return nil, gosocks5.ErrAuthFailure } - + if req.Username != "" && selector.Limiter != nil { + if _, ok := selector.Limiter.CheckRate(req.Username, false); !ok { + if Debug { + log.Logf("[http] %s <- %s rate limiter \n%s", conn.RemoteAddr(), conn.LocalAddr(), req.Username) + } + return nil, errors.New("rate limiter check fail") + } + } resp := gosocks5.NewUserPassResponse(gosocks5.UserPassVer, gosocks5.Succeeded) if err := resp.Write(conn); err != nil { log.Logf("[socks5] %s - %s: %s", conn.RemoteAddr(), conn.LocalAddr(), err) @@ -836,6 +844,7 @@ func (h *socks5Handler) Init(options ...HandlerOption) { // Users: h.options.Users, Authenticator: h.options.Authenticator, TLSConfig: tlsConfig, + Limiter: h.options.Limiter, } // methods that socks5 server supported h.selector.AddMethod(