Skip to content

Commit

Permalink
cmd serve: support admin api
Browse files Browse the repository at this point in the history
  • Loading branch information
rkonfj committed May 17, 2023
1 parent e75ee0e commit 298a228
Show file tree
Hide file tree
Showing 6 changed files with 191 additions and 33 deletions.
5 changes: 5 additions & 0 deletions cmd/serve/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ func init() {
RunE: startAction,
}
Cmd.Flags().String("acl", "acl.json", "file containing access control rules")
Cmd.Flags().String("admin", "", "admin key (leave blank to disable admin api)")
Cmd.Flags().String("copy-buf", "16Ki", "buffer size for copying network data")
Cmd.Flags().StringP("listen", "l", "localhost:9986", "http server listen address")
}
Expand All @@ -42,6 +43,10 @@ func processServerOptions(cmd *cobra.Command) (options server.Options, err error
if err != nil {
return
}
options.Admin, err = cmd.Flags().GetString("admin")
if err != nil {
return
}
copyBuf, err := cmd.Flags().GetString("copy-buf")
if err != nil {
return
Expand Down
116 changes: 89 additions & 27 deletions server/acl.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ var (
)

type ACL struct {
adminKey string
keys map[string]*key
stoFilePath string
sto *ACLStorage
Expand All @@ -36,7 +37,7 @@ type ACLStorage struct {
}

type Key struct {
Name string `json:"name"`
Name string `json:"name,omitempty"`
Key string `json:"key"`
Limit *Limit `json:"limit,omitempty"`
BytesUsage *api.BytesUsage `json:"bytesUsage,omitempty"`
Expand Down Expand Up @@ -82,10 +83,14 @@ func (k *key) outBytesLimited() bool {
return k.bytesUsage.In+k.bytesUsage.Out >= k.bytesLimit
}

func NewACL(aclPath string) (*ACL, error) {
func NewACL(aclPath, adminKey string) (*ACL, error) {
if len(adminKey) < 16 {
return nil, errors.New("the minimum admin key is 16 characters")
}
acl := &ACL{
keys: make(map[string]*key),
stoFilePath: aclPath,
adminKey: adminKey,
}

var sto ACLStorage
Expand Down Expand Up @@ -123,37 +128,17 @@ func NewACL(aclPath string) (*ACL, error) {
ke.bytesUsage = k.BytesUsage
}
acl.keys[k.Key] = ke
if k.Limit != nil {
if k.Limit.Bytes != "" {
b, err := humanize.ParseBytes(k.Limit.Bytes)
if err != nil {
return nil, err
}
ke.bytesLimit = b
}
if k.Limit.InBytes != "" {
b, err := humanize.ParseBytes(k.Limit.InBytes)
if err != nil {
return nil, err
}
ke.inBytesLimit = b
}
if k.Limit.OutBytes != "" {
b, err := humanize.ParseBytes(k.Limit.OutBytes)
if err != nil {
return nil, err
}
ke.outBytesLimit = b
}
ke.blacklist = k.Limit.Blacklist
ke.whitelist = k.Limit.Whitelist
}
acl.applyACLKeyLimit(ke, k.Limit)
}
logrus.Infof("acl: load %d keys", len(acl.keys))
go acl.aclPersistLoop()
return acl, nil
}

func (a *ACL) IsAdminAccess(key string) bool {
return a.adminKey != "" && a.adminKey == key
}

func (a *ACL) CheckKey(key string) error {
if k, ok := a.keys[key]; ok {
if k.inBytesLimited() {
Expand Down Expand Up @@ -210,6 +195,83 @@ func (a *ACL) UpdateBytesUsage(key string, in, out uint64) {
}
}

func (a *ACL) NewKey(name string) string {
k := uuid.NewString()

ke := &Key{
Name: name,
Key: k,
}
a.keys[k] = &key{
bytesUsage: &api.BytesUsage{},
}
a.stoUpdatePendingCountLock.Lock()
defer a.stoUpdatePendingCountLock.Unlock()
a.sto.Keys = append(a.sto.Keys, ke)
a.stoUpdatePendingCount++
return k
}

func (a *ACL) DelKey(key string) {
a.stoUpdatePendingCountLock.Lock()
defer a.stoUpdatePendingCountLock.Unlock()
delete(a.keys, key)
a.stoUpdatePendingCount++
for i, v := range a.sto.Keys {
if v.Key == key {
a.sto.Keys = append(a.sto.Keys[:i], a.sto.Keys[i+1:]...)
}
}
}

// Limit replace key's limit
func (a *ACL) Limit(key string, l *Limit) error {
a.stoUpdatePendingCountLock.Lock()
defer a.stoUpdatePendingCountLock.Unlock()
a.stoUpdatePendingCount++
if k, ok := a.keys[key]; ok {
err := a.applyACLKeyLimit(k, l)
if err != nil {
return err
}
for _, ke := range a.sto.Keys {
if ke.Key == key {
ke.Limit = l
}
}
}
return nil
}

func (a *ACL) applyACLKeyLimit(ke *key, l *Limit) error {
if l != nil {
if l.Bytes != "" {
b, err := humanize.ParseBytes(l.Bytes)
if err != nil {
return err
}
ke.bytesLimit = b
}
if l.InBytes != "" {
b, err := humanize.ParseBytes(l.InBytes)
if err != nil {
return err
}
ke.inBytesLimit = b
}
if l.OutBytes != "" {
b, err := humanize.ParseBytes(l.OutBytes)
if err != nil {
return err
}
ke.outBytesLimit = b
}
ke.blacklist = l.Blacklist
ke.whitelist = l.Whitelist
}
return nil
}

func (a *ACL) persist() error {
aclF, err := os.OpenFile(a.stoFilePath, os.O_WRONLY, 0644)
if err != nil {
Expand Down
76 changes: 76 additions & 0 deletions server/admin.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package server

import (
"encoding/json"
"net/http"

"github.com/rkonfj/toh/spec"
"github.com/sirupsen/logrus"
)

func (s *TohServer) registerAdminAPIIfEnabled() {
if s.options.Admin == "" {
return
}
http.HandleFunc("/admin/key", s.HandleAdminKey)
http.HandleFunc("/admin/limit", s.HandleAdminLimit)
logrus.Info("admin api(/admin/**) is enabled")
}

func (s *TohServer) HandleAdminKey(w http.ResponseWriter, r *http.Request) {
if !s.acl.IsAdminAccess(r.Header.Get(spec.HeaderHandshakeKey)) {
w.WriteHeader(http.StatusUnauthorized)
return
}
switch r.Method {
case http.MethodPost:
name := r.URL.Query().Get("name")
key := s.acl.NewKey(name)
w.Write([]byte(key))
case http.MethodDelete:
key := r.URL.Query().Get("key")
if len(key) == 0 {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("required key parameter not found in url"))
return
}
s.acl.DelKey(key)
default:
w.WriteHeader(http.StatusMethodNotAllowed)
}
}

func (s *TohServer) HandleAdminLimit(w http.ResponseWriter, r *http.Request) {
if !s.acl.IsAdminAccess(r.Header.Get(spec.HeaderHandshakeKey)) {
w.WriteHeader(http.StatusUnauthorized)
return
}
switch r.Method {
case http.MethodPatch:
if r.Body == nil {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("missing body"))
return
}
var l Limit
err := json.NewDecoder(r.Body).Decode(&l)
if err != nil {
w.WriteHeader(http.StatusInternalServerError)
w.Write([]byte(err.Error()))
return
}
key := r.URL.Query().Get("key")
if key == "" {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("required key parameter not found in url"))
return
}
err = s.acl.Limit(key, &l)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte(err.Error()))
}
default:
w.WriteHeader(http.StatusMethodNotAllowed)
}
}
4 changes: 4 additions & 0 deletions server/api/stats.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,7 @@ type BytesUsage struct {
In uint64 `json:"in"`
Out uint64 `json:"out"`
}

type AdminStats struct {
ConnCount int64 `json:"connCount"`
}
8 changes: 6 additions & 2 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@ type Options struct {
Listen string
ACL string
Buf uint64
Admin string
}

func NewTohServer(options Options) (*TohServer, error) {
acl, err := NewACL(options.ACL)
acl, err := NewACL(options.ACL, options.Admin)
if err != nil {
return nil, err
}
Expand All @@ -43,9 +44,12 @@ func NewTohServer(options Options) (*TohServer, error) {
}

func (s *TohServer) Run() {
s.startTrafficEventConsumeDaemon()
s.registerAdminAPIIfEnabled()

http.HandleFunc("/stats", s.HandleShowStats)
http.HandleFunc("/", s.HandleUpgradeWebSocket)
s.startTrafficEventConsumeDaemon()

logrus.Infof("server listen on %s now", s.options.Listen)
err := http.ListenAndServe(s.options.Listen, nil)
if err != nil {
Expand Down
15 changes: 11 additions & 4 deletions socks5/socks5.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,14 +191,21 @@ func (s *Socks5Server) handshake(ctx context.Context, conn net.Conn) (
respHostUnreachable(conn)
return
}
conn.Write([]byte{0x05, 0x00, 0x00, 0x01})

if netConn.LocalAddr() != nil {
addrPort := netip.MustParseAddrPort(netConn.LocalAddr().String())
ip := addrPort.Addr().As4()
conn.Write(ip[:])
if addrPort.Addr().Is6() {
ip := addrPort.Addr().As16()
conn.Write([]byte{5, 0, 0, 4})
conn.Write(ip[:])
} else {
ip := addrPort.Addr().As4()
conn.Write([]byte{5, 0, 0, 1})
conn.Write(ip[:])
}
conn.Write(spec.Uint16ToBytes(addrPort.Port()))
} else {
conn.Write([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00})
conn.Write([]byte{5, 0, 0, 1, 0, 0, 0, 0, 0, 0})
}
closeConn = false
return
Expand Down

0 comments on commit 298a228

Please sign in to comment.