From 72e01184ca81bd0e64e81d2b45effaf9cf1d80e4 Mon Sep 17 00:00:00 2001 From: zyxkad Date: Mon, 10 Jun 2024 19:52:37 -0600 Subject: [PATCH 01/36] start refactor --- Dockerfile | 14 + hijacker.go => api/bmclapi/hijacker.go | 0 api.go => api/v0/api.go | 0 api_token.go => api/v0/api_token.go | 0 cluster/cluster.go | 76 +++++ cluster/handler.go | 20 ++ cluster/keepalive.go | 45 +++ cluster/status.go | 61 ++++ handler.go | 2 - http_listener.go | 286 ------------------ internal/build/version.go | 1 + storage/manager.go | 114 +++++++ .../cmd_compress.go | 0 cmd_webdav.go => sub_commands/cmd_webdav.go | 0 util.go | 16 - bar.go => utils/bar.go | 18 +- utils/http.go | 260 ++++++++++++++++ exitcodes.go => utils/rand.go | 32 +- utils/util.go | 10 +- 19 files changed, 623 insertions(+), 332 deletions(-) rename hijacker.go => api/bmclapi/hijacker.go (100%) rename api.go => api/v0/api.go (100%) rename api_token.go => api/v0/api_token.go (100%) create mode 100644 cluster/cluster.go create mode 100644 cluster/handler.go create mode 100644 cluster/keepalive.go create mode 100644 cluster/status.go delete mode 100644 http_listener.go create mode 100644 storage/manager.go rename cmd_compress.go => sub_commands/cmd_compress.go (100%) rename cmd_webdav.go => sub_commands/cmd_webdav.go (100%) rename bar.go => utils/bar.go (79%) rename exitcodes.go => utils/rand.go (61%) diff --git a/Dockerfile b/Dockerfile index 6b0e9683..c00e1cc6 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,18 @@ # syntax=docker/dockerfile:1 +# Copyright (C) 2023 Kevin Z +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . ARG GO_VERSION=1.21 ARG REPO=github.com/LiterMC/go-openbmclapi diff --git a/hijacker.go b/api/bmclapi/hijacker.go similarity index 100% rename from hijacker.go rename to api/bmclapi/hijacker.go diff --git a/api.go b/api/v0/api.go similarity index 100% rename from api.go rename to api/v0/api.go diff --git a/api_token.go b/api/v0/api_token.go similarity index 100% rename from api_token.go rename to api/v0/api_token.go diff --git a/cluster/cluster.go b/cluster/cluster.go new file mode 100644 index 00000000..398d32a3 --- /dev/null +++ b/cluster/cluster.go @@ -0,0 +1,76 @@ +/** + * OpenBmclAPI (Golang Edition) + * Copyright (C) 2024 Kevin Z + * All rights reserved + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published + * by the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package cluster + +import ( + "context" + "sync/atomic" +) + +type Cluster struct { + id string + secret string + host string + port uint16 + + storageManager *storage.Manager + storages []int // the index of storages in the storage manager + + status atomic.Int32 +} + +// ID returns the cluster id +func (cr *Cluster) ID() string { + return cr.id +} + +// Host returns the cluster public host +func (cr *Cluster) Host() string { + return cr.host +} + +// Port returns the cluster public port +func (cr *Cluster) Port() string { + return cr.port +} + +// Init do setup on the cluster +// Init should only be called once during the cluster's whole life +// The context passed in only affect the logical of Init method +func (cr *Cluster) Init(ctx context.Context) error { + return +} + +// Enable send enable packet to central server +// The context passed in only affect the logical of Enable method +func (cr *Cluster) Enable(ctx context.Context) error { + return +} + +// Disable send disable packet to central server +// The context passed in only affect the logical of Disable method +func (cr *Cluster) Disable(ctx context.Context) error { + return +} + +// setDisabled marked the cluster as disabled or kicked +func (cr *Cluster) setDisabled(kicked bool) { + return +} diff --git a/cluster/handler.go b/cluster/handler.go new file mode 100644 index 00000000..c79366ca --- /dev/null +++ b/cluster/handler.go @@ -0,0 +1,20 @@ +/** + * OpenBmclAPI (Golang Edition) + * Copyright (C) 2024 Kevin Z + * All rights reserved + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published + * by the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package cluster diff --git a/cluster/keepalive.go b/cluster/keepalive.go new file mode 100644 index 00000000..c57030eb --- /dev/null +++ b/cluster/keepalive.go @@ -0,0 +1,45 @@ +/** + * OpenBmclAPI (Golang Edition) + * Copyright (C) 2024 Kevin Z + * All rights reserved + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published + * by the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package cluster + +import ( + "context" +) + +type KeepAliveRes int + +// Succeed returns true when KeepAlive actions is succeed as well as the cluster is not kicked by the controller +func (r KeepAliveRes) Succeed() bool { + return r == 0 +} + +// Failed returns true when KeepAlive action is succeed but the cluster is forced kick by the controller +func (r KeepAliveRes) Kicked() bool { + return r == 1 +} + +// Failed returns true when KeepAlive is interrupted by unexpected reason +func (r KeepAliveRes) Failed() bool { + return r == 2 +} + +func (cr *Cluster) KeepAlive(ctx context.Context) KeepAliveRes { + // +} diff --git a/cluster/status.go b/cluster/status.go new file mode 100644 index 00000000..922d96ff --- /dev/null +++ b/cluster/status.go @@ -0,0 +1,61 @@ +/** + * OpenBmclAPI (Golang Edition) + * Copyright (C) 2024 Kevin Z + * All rights reserved + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published + * by the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package cluster + +const ( + clusterDisabled = 0 + clusterEnabled = 1 + clusterEnabling = 2 + clusterKicked = 4 +) + +// Enabled returns true if the cluster is enabled or enabling +func (cr *Cluster) Enabled() bool { + s := cr.status.Load() + return s == clusterEnabled || s == clusterEnabling +} + +// Running returns true if the cluster is completely enabled +func (cr *Cluster) Running() bool { + return cr.status.Load() == clusterEnabled +} + +// Disabled returns true if the cluster is disabled manually +func (cr *Cluster) Disabled() bool { + return cr.status.Load() == clusterDisabled +} + +// IsKicked returns true if the cluster is kicked by the central server +func (cr *Cluster) IsKicked() bool { + return cr.status.Load() == clusterKicked +} + +// WaitForEnable returns a channel which receives true when cluster enabled succeed, or receives false when it failed to enable +// If the cluster is already enable, the channel always returns true +// The channel should not be used multiple times +func (cr *Cluster) WaitForEnable() <-chan bool { + ch := make(chan bool, 1) + if cr.Running() { + ch <- true + } else { + cr.enableSignals = append(cr.enableSignals, ch) + } + return ch +} diff --git a/handler.go b/handler.go index baf45d26..7aafccab 100644 --- a/handler.go +++ b/handler.go @@ -360,8 +360,6 @@ var emptyHashes = func() (hashes map[string]struct{}) { return }() -var HeaderXPoweredBy = fmt.Sprintf("go-openbmclapi/%s; url=https://github.com/LiterMC/go-openbmclapi", build.BuildVersion) - //go:embed robots.txt var robotTxtContent string diff --git a/http_listener.go b/http_listener.go deleted file mode 100644 index 452583e6..00000000 --- a/http_listener.go +++ /dev/null @@ -1,286 +0,0 @@ -/** - * OpenBmclAPI (Golang Edition) - * Copyright (C) 2024 Kevin Z - * All rights reserved - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published - * by the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -package main - -import ( - "bufio" - "bytes" - "crypto/tls" - "io" - "net" - "net/http" - "net/url" - "strconv" - "strings" - "sync" - "sync/atomic" - "time" -) - -// httpTLSListener will serve a http or a tls connection -// When Accept was called, if a pure http request is received, -// it will response and redirect the client to the https protocol. -// Else it will just return the tls connection -type httpTLSListener struct { - net.Listener - TLSConfig *tls.Config - mux sync.RWMutex - hosts []string - port string - - accepting atomic.Bool - acceptedCh chan net.Conn - errCh chan error -} - -var _ net.Listener = (*httpTLSListener)(nil) - -func newHttpTLSListener(l net.Listener, cfg *tls.Config, publicHosts []string, port uint16) net.Listener { - return &httpTLSListener{ - Listener: l, - TLSConfig: cfg, - hosts: publicHosts, - port: strconv.Itoa((int)(port)), - acceptedCh: make(chan net.Conn, 1), - errCh: make(chan error, 1), - } -} - -func (s *httpTLSListener) Close() (err error) { - err = s.Listener.Close() - select { - case conn := <-s.acceptedCh: - conn.Close() - default: - } - select { - case <-s.errCh: - default: - } - return -} - -func (s *httpTLSListener) SetPublicPort(port string) { - s.mux.Lock() - defer s.mux.Unlock() - s.port = port -} - -func (s *httpTLSListener) getPublicPort() string { - s.mux.RLock() - defer s.mux.RUnlock() - return s.port -} - -func (s *httpTLSListener) maybeHTTPConn(c *connHeadReader) (ishttp bool) { - if len(s.hosts) == 0 { - return false - } - var buf [4096]byte - i, n := 0, 0 -READ_HEAD: - for { - m, err := c.ReadForHead(buf[i:]) - if err != nil { - return false - } - n += m - for ; i < n; i++ { - b := buf[i] - switch { - case b == '\r': // first line of HTTP request end - break READ_HEAD - case b < 0x20 || 0x7e < b: // not in ascii printable range - return false - } - } - } - // check if it's actually a HTTP request, not something else - method, rest, _ := bytes.Cut(buf[:i], ([]byte)(" ")) - uurl, proto, _ := bytes.Cut(rest, ([]byte)(" ")) - if len(method) == 0 || len(uurl) == 0 || len(proto) == 0 { - return false - } - _, _, ok := http.ParseHTTPVersion((string)(proto)) - if !ok { - return false - } - _, err := url.ParseRequestURI((string)(uurl)) - if err != nil { - return false - } - return true -} - -func (s *httpTLSListener) accepter() { - for s.accepting.CompareAndSwap(false, true) { - conn, err := s.Listener.Accept() - s.accepting.Store(false) - if err != nil { - s.errCh <- err - return - } - go s.accepter() - hr := &connHeadReader{Conn: conn} - hr.SetReadDeadline(time.Now().Add(time.Second * 5)) - ishttp := s.maybeHTTPConn(hr) - hr.SetReadDeadline(time.Time{}) - if !ishttp { - // if it's not a http connection, it must be a tls connection - s.acceptedCh <- tls.Server(hr, s.TLSConfig) - return - } - go s.serveHTTP(hr) - } -} - -func (s *httpTLSListener) serveHTTP(conn net.Conn) { - defer conn.Close() - - conn.SetReadDeadline(time.Now().Add(time.Second * 15)) - req, err := http.ReadRequest(bufio.NewReader(conn)) - if err != nil { - return - } - conn.SetReadDeadline(time.Time{}) - host, _, err := net.SplitHostPort(req.Host) - if err != nil { - host = req.Host - } - inhosts := false - if host != "" { - host = strings.ToLower(host) - for _, h := range s.hosts { - if h, ok := strings.CutPrefix(h, "*."); ok { - if strings.HasSuffix(host, h) { - inhosts = true - break - } - } else if h == host { - inhosts = true - break - } - } - } - u := *req.URL - u.Scheme = "https" - if !inhosts { - for _, h := range s.hosts { - if !strings.HasSuffix(h, "*.") { - host = h - break - } - } - } - if host == "" { - // we have nowhere to redirect - body := strings.NewReader("Sent http request on https server") - resp := &http.Response{ - StatusCode: http.StatusBadRequest, - ProtoMajor: req.ProtoMajor, - ProtoMinor: req.ProtoMinor, - Request: req, - Header: http.Header{ - "Content-Type": {"text/plain"}, - "X-Powered-By": {HeaderXPoweredBy}, - }, - ContentLength: (int64)(body.Len()), - } - conn.SetWriteDeadline(time.Now().Add(time.Second * 10)) - resp.Write(conn) - io.Copy(conn, body) - return - } - u.Host = net.JoinHostPort(host, s.getPublicPort()) - resp := &http.Response{ - StatusCode: http.StatusPermanentRedirect, - ProtoMajor: req.ProtoMajor, - ProtoMinor: req.ProtoMinor, - Request: req, - Header: http.Header{ - "Location": {u.String()}, - "X-Powered-By": {HeaderXPoweredBy}, - }, - } - conn.SetWriteDeadline(time.Now().Add(time.Second * 10)) - resp.Write(conn) -} - -func (s *httpTLSListener) Accept() (conn net.Conn, err error) { - select { - case conn = <-s.acceptedCh: - return - case err = <-s.errCh: - return - default: - } - go s.accepter() - select { - case conn = <-s.acceptedCh: - case err = <-s.errCh: - } - return -} - -// connHeadReader is used by httpTLSListener -// it wraps a net.Conn, and the first few bytes can be read multiple times -// the head buf will be discard when the main content starts to be read -type connHeadReader struct { - net.Conn - head []byte - headi int - headDone bool // the main content had start been read -} - -func (c *connHeadReader) Head() []byte { - return c.head -} - -// ReadForHead will read the underlying net.Conn, -// and append the data to its internal head buffer -func (c *connHeadReader) ReadForHead(buf []byte) (n int, err error) { - if c.headDone { - panic("connHeadReader: Content is already started to read") - } - n, err = c.Conn.Read(buf) - c.head = append(c.head, buf[:n]...) - return -} - -type connReaderForHead struct { - c *connHeadReader -} - -func (c *connReaderForHead) Read(buf []byte) (n int, err error) { - return c.c.ReadForHead(buf) -} - -func (c *connHeadReader) Read(buf []byte) (n int, err error) { - if c.headi < len(c.head) { - n = copy(buf, c.head[c.headi:]) - c.headi += n - return - } - if !c.headDone { - c.head = nil - c.headDone = true - } - return c.Conn.Read(buf) -} diff --git a/internal/build/version.go b/internal/build/version.go index f95a7d04..2b8613ad 100644 --- a/internal/build/version.go +++ b/internal/build/version.go @@ -29,3 +29,4 @@ var BuildVersion string = "dev" var ClusterUserAgent string = fmt.Sprintf("openbmclapi-cluster/%s", ClusterVersion) var ClusterUserAgentFull string = fmt.Sprintf("%s go-openbmclapi-cluster/%s", ClusterUserAgent, BuildVersion) +var HeaderXPoweredBy = fmt.Sprintf("go-openbmclapi/%s; url=https://github.com/LiterMC/go-openbmclapi", BuildVersion) diff --git a/storage/manager.go b/storage/manager.go new file mode 100644 index 00000000..d4d4d554 --- /dev/null +++ b/storage/manager.go @@ -0,0 +1,114 @@ +/** + * OpenBmclAPI (Golang Edition) + * Copyright (C) 2024 Kevin Z + * All rights reserved + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published + * by the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package storage + +import ( + "github.com/LiterMC/go-openbmclapi/utils" +) + +// Manager manages a list of storages +type Manager struct { + Options []StorageOption + Storages []Storage + weights []uint + totalWeight uint + totalWeightsCache utils.SyncMap[[]string, *weightCache] +} + +func NewManager(opts []StorageOption, storages []Storage) (m *Manager) { + m = new(Manager) + m.Options = opts + m.Storages = storages + m.weights = make([]uint, len(opts)) + m.totalWeight = 0 + m.totalWeightsCache = utils.NewSyncMap[[]string, *weightCache]() + for i, s := range opts { + m.weights[i] = s.Weight + m.totalWeight += s.Weight + } + return +} + +type weightCache struct{ + weights []uint + total uint +} + +func (m *Manager) ForEachFromRandom(storages []int, cb func(s Storage) (done bool)) (done bool) { + data, _ := m.totalWeightsCache.GetOrSet(storages, func() (c *weightCache) { + c = new(weightCache) + c.weights = make([]int, len(storages)) + for i, j := range storages { + w := m.weights[j] + c.weights[i] = w + c.total += w + } + return + }) + return forEachFromRandomIndexWithPossibility(data.weights, data.total, cb) +} + +func forEachFromRandomIndex(leng int, cb func(i int) (done bool)) (done bool) { + if leng <= 0 { + return false + } + start := utils.RandIntn(leng) + for i := start; i < leng; i++ { + if cb(i) { + return true + } + } + for i := 0; i < start; i++ { + if cb(i) { + return true + } + } + return false +} + +func forEachFromRandomIndexWithPossibility(poss []uint, total uint, cb func(i int) (done bool)) (done bool) { + leng := len(poss) + if leng == 0 { + return false + } + if total == 0 { + return forEachFromRandomIndex(leng, cb) + } + n := (uint)(utils.RandIntn((int)(total))) + start := 0 + for i, p := range poss { + if n < p { + start = i + break + } + n -= p + } + for i := start; i < leng; i++ { + if cb(i) { + return true + } + } + for i := 0; i < start; i++ { + if cb(i) { + return true + } + } + return false +} diff --git a/cmd_compress.go b/sub_commands/cmd_compress.go similarity index 100% rename from cmd_compress.go rename to sub_commands/cmd_compress.go diff --git a/cmd_webdav.go b/sub_commands/cmd_webdav.go similarity index 100% rename from cmd_webdav.go rename to sub_commands/cmd_webdav.go diff --git a/util.go b/util.go index 6d6d0c17..539cf9a3 100644 --- a/util.go +++ b/util.go @@ -83,22 +83,6 @@ func parseCertCommonName(body []byte) (string, error) { return cert.Subject.CommonName, nil } -var rd = func() chan int32 { - ch := make(chan int32, 64) - r := rand.New(rand.NewSource(time.Now().Unix())) - go func() { - for { - ch <- r.Int31() - } - }() - return ch -}() - -func randIntn(n int) int { - rn := <-rd - return (int)(rn) % n -} - func forEachFromRandomIndex(leng int, cb func(i int) (done bool)) (done bool) { if leng <= 0 { return false diff --git a/bar.go b/utils/bar.go similarity index 79% rename from bar.go rename to utils/bar.go index a89c29b6..7ca96bff 100644 --- a/bar.go +++ b/utils/bar.go @@ -17,7 +17,7 @@ * along with this program. If not, see . */ -package main +package utils import ( "io" @@ -27,15 +27,15 @@ import ( "github.com/vbauerster/mpb/v8" ) -type ProxiedReader struct { +type ProxiedPBReader struct { io.Reader bar, total *mpb.Bar lastRead time.Time lastInc *atomic.Int64 } -func ProxyReader(r io.Reader, bar, total *mpb.Bar, lastInc *atomic.Int64) *ProxiedReader { - return &ProxiedReader{ +func ProxyPBReader(r io.Reader, bar, total *mpb.Bar, lastInc *atomic.Int64) *ProxiedPBReader { + return &ProxiedPBReader{ Reader: r, bar: bar, total: total, @@ -43,7 +43,7 @@ func ProxyReader(r io.Reader, bar, total *mpb.Bar, lastInc *atomic.Int64) *Proxi } } -func (p *ProxiedReader) Read(buf []byte) (n int, err error) { +func (p *ProxiedPBReader) Read(buf []byte) (n int, err error) { start := p.lastRead if start.IsZero() { start = time.Now() @@ -60,15 +60,15 @@ func (p *ProxiedReader) Read(buf []byte) (n int, err error) { return } -type ProxiedReadSeeker struct { +type ProxiedPBReadSeeker struct { io.ReadSeeker bar, total *mpb.Bar lastRead time.Time lastInc *atomic.Int64 } -func ProxyReadSeeker(r io.ReadSeeker, bar, total *mpb.Bar, lastInc *atomic.Int64) *ProxiedReadSeeker { - return &ProxiedReadSeeker{ +func ProxyPBReadSeeker(r io.ReadSeeker, bar, total *mpb.Bar, lastInc *atomic.Int64) *ProxiedPBReadSeeker { + return &ProxiedPBReadSeeker{ ReadSeeker: r, bar: bar, total: total, @@ -76,7 +76,7 @@ func ProxyReadSeeker(r io.ReadSeeker, bar, total *mpb.Bar, lastInc *atomic.Int64 } } -func (p *ProxiedReadSeeker) Read(buf []byte) (n int, err error) { +func (p *ProxiedPBReadSeeker) Read(buf []byte) (n int, err error) { start := p.lastRead if start.IsZero() { start = time.Now() diff --git a/utils/http.go b/utils/http.go index ad2f4193..65809812 100644 --- a/utils/http.go +++ b/utils/http.go @@ -21,14 +21,23 @@ package utils import ( "bufio" + "bytes" + "crypto/tls" "errors" "io" "net" "net/http" + "net/url" "path" "runtime" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" "github.com/LiterMC/go-openbmclapi/log" + "github.com/LiterMC/go-openbmclapi/internal/build" ) type StatusResponseWriter struct { @@ -158,3 +167,254 @@ func (m *HttpMiddleWareHandler) UseFunc(fns ...MiddleWareFunc) { m.middles = append(m.middles, fn) } } + +// HTTPTLSListener will serve a http or a tls connection +// When Accept was called, if a pure http request is received, +// it will response and redirect the client to the https protocol. +// Else it will just return the tls connection +type HTTPTLSListener struct { + net.Listener + TLSConfig *tls.Config + mux sync.RWMutex + hosts []string + port string + + accepting atomic.Bool + acceptedCh chan net.Conn + errCh chan error +} + +var _ net.Listener = (*HTTPTLSListener)(nil) + +func NewHttpTLSListener(l net.Listener, cfg *tls.Config, publicHosts []string, port uint16) net.Listener { + return &HTTPTLSListener{ + Listener: l, + TLSConfig: cfg, + hosts: publicHosts, + port: strconv.Itoa((int)(port)), + acceptedCh: make(chan net.Conn, 1), + errCh: make(chan error, 1), + } +} + +func (s *HTTPTLSListener) Close() (err error) { + err = s.Listener.Close() + select { + case conn := <-s.acceptedCh: + conn.Close() + default: + } + select { + case <-s.errCh: + default: + } + return +} + +func (s *HTTPTLSListener) SetPublicPort(port string) { + s.mux.Lock() + defer s.mux.Unlock() + s.port = port +} + +func (s *HTTPTLSListener) GetPublicPort() string { + s.mux.RLock() + defer s.mux.RUnlock() + return s.port +} + +func (s *HTTPTLSListener) maybeHTTPConn(c *connHeadReader) (ishttp bool) { + if len(s.hosts) == 0 { + return false + } + var buf [4096]byte + i, n := 0, 0 +READ_HEAD: + for { + m, err := c.ReadForHead(buf[i:]) + if err != nil { + return false + } + n += m + for ; i < n; i++ { + b := buf[i] + switch { + case b == '\r': // first line of HTTP request end + break READ_HEAD + case b < 0x20 || 0x7e < b: // not in ascii printable range + return false + } + } + } + // check if it's actually a HTTP request, not something else + method, rest, _ := bytes.Cut(buf[:i], ([]byte)(" ")) + uurl, proto, _ := bytes.Cut(rest, ([]byte)(" ")) + if len(method) == 0 || len(uurl) == 0 || len(proto) == 0 { + return false + } + _, _, ok := http.ParseHTTPVersion((string)(proto)) + if !ok { + return false + } + _, err := url.ParseRequestURI((string)(uurl)) + if err != nil { + return false + } + return true +} + +func (s *HTTPTLSListener) accepter() { + for s.accepting.CompareAndSwap(false, true) { + conn, err := s.Listener.Accept() + s.accepting.Store(false) + if err != nil { + s.errCh <- err + return + } + go s.accepter() + hr := &connHeadReader{Conn: conn} + hr.SetReadDeadline(time.Now().Add(time.Second * 5)) + ishttp := s.maybeHTTPConn(hr) + hr.SetReadDeadline(time.Time{}) + if !ishttp { + // if it's not a http connection, it must be a tls connection + s.acceptedCh <- tls.Server(hr, s.TLSConfig) + return + } + go s.serveHTTP(hr) + } +} + +func (s *HTTPTLSListener) serveHTTP(conn net.Conn) { + defer conn.Close() + + conn.SetReadDeadline(time.Now().Add(time.Second * 15)) + req, err := http.ReadRequest(bufio.NewReader(conn)) + if err != nil { + return + } + conn.SetReadDeadline(time.Time{}) + host, _, err := net.SplitHostPort(req.Host) + if err != nil { + host = req.Host + } + inhosts := false + if host != "" { + host = strings.ToLower(host) + for _, h := range s.hosts { + if h, ok := strings.CutPrefix(h, "*."); ok { + if strings.HasSuffix(host, h) { + inhosts = true + break + } + } else if h == host { + inhosts = true + break + } + } + } + u := *req.URL + u.Scheme = "https" + if !inhosts { + for _, h := range s.hosts { + if !strings.HasSuffix(h, "*.") { + host = h + break + } + } + } + if host == "" { + // we have nowhere to redirect + body := strings.NewReader("Sent http request on https server") + resp := &http.Response{ + StatusCode: http.StatusBadRequest, + ProtoMajor: req.ProtoMajor, + ProtoMinor: req.ProtoMinor, + Request: req, + Header: http.Header{ + "Content-Type": {"text/plain"}, + "X-Powered-By": {build.HeaderXPoweredBy}, + }, + ContentLength: (int64)(body.Len()), + } + conn.SetWriteDeadline(time.Now().Add(time.Second * 10)) + resp.Write(conn) + io.Copy(conn, body) + return + } + u.Host = net.JoinHostPort(host, s.GetPublicPort()) + resp := &http.Response{ + StatusCode: http.StatusPermanentRedirect, + ProtoMajor: req.ProtoMajor, + ProtoMinor: req.ProtoMinor, + Request: req, + Header: http.Header{ + "Location": {u.String()}, + "X-Powered-By": {build.HeaderXPoweredBy}, + }, + } + conn.SetWriteDeadline(time.Now().Add(time.Second * 10)) + resp.Write(conn) +} + +func (s *HTTPTLSListener) Accept() (conn net.Conn, err error) { + select { + case conn = <-s.acceptedCh: + return + case err = <-s.errCh: + return + default: + } + go s.accepter() + select { + case conn = <-s.acceptedCh: + case err = <-s.errCh: + } + return +} + +// connHeadReader is used by HTTPTLSListener +// it wraps a net.Conn, and the first few bytes can be read multiple times +// the head buf will be discard when the main content starts to be read +type connHeadReader struct { + net.Conn + head []byte + headi int + headDone bool // the main content had start been read +} + +func (c *connHeadReader) Head() []byte { + return c.head +} + +// ReadForHead will read the underlying net.Conn, +// and append the data to its internal head buffer +func (c *connHeadReader) ReadForHead(buf []byte) (n int, err error) { + if c.headDone { + panic("connHeadReader: Content is already started to read") + } + n, err = c.Conn.Read(buf) + c.head = append(c.head, buf[:n]...) + return +} + +type connReaderForHead struct { + c *connHeadReader +} + +func (c *connReaderForHead) Read(buf []byte) (n int, err error) { + return c.c.ReadForHead(buf) +} + +func (c *connHeadReader) Read(buf []byte) (n int, err error) { + if c.headi < len(c.head) { + n = copy(buf, c.head[c.headi:]) + c.headi += n + return + } + if !c.headDone { + c.head = nil + c.headDone = true + } + return c.Conn.Read(buf) +} diff --git a/exitcodes.go b/utils/rand.go similarity index 61% rename from exitcodes.go rename to utils/rand.go index 9d540825..37497eb3 100644 --- a/exitcodes.go +++ b/utils/rand.go @@ -17,19 +17,23 @@ * along with this program. If not, see . */ -package main - -const ( - CodeClientError = 0x01 - CodeServerError = 0x02 - CodeEnvironmentError = 0x04 - CodeUnexpectedError = 0x08 +import ( + "math/rand" + "time" ) -const ( - CodeClientOrServerError = CodeClientError | CodeServerError - CodeClientOrEnvionmentError = CodeClientError | CodeEnvironmentError - CodeClientUnexpectedError = CodeUnexpectedError | CodeClientError - CodeServerOrEnvionmentError = CodeServerError | CodeEnvironmentError - CodeServerUnexpectedError = CodeUnexpectedError | CodeServerError -) +var randInt32Ch = func() chan int32 { + ch := make(chan int32, 64) + r := rand.New(rand.NewSource(time.Now().Unix())) + go func() { + for { + ch <- r.Int31() + } + }() + return ch +}() + +func RandIntn(n int) int { + rn := <-randInt32Ch + return (int)(rn) % n +} diff --git a/utils/util.go b/utils/util.go index a286117d..78955cb2 100644 --- a/utils/util.go +++ b/utils/util.go @@ -66,17 +66,17 @@ func (m *SyncMap[K, V]) Has(k K) bool { return ok } -func (m *SyncMap[K, V]) GetOrSet(k K, setter func() V) (v V, has bool) { +func (m *SyncMap[K, V]) GetOrSet(k K, setter func() V) (v V, had bool) { m.l.RLock() - v, has = m.m[k] + v, had = m.m[k] m.l.RUnlock() - if has { + if had { return } m.l.Lock() defer m.l.Unlock() - v, has = m.m[k] - if !has { + v, had = m.m[k] + if !had { v = setter() m.m[k] = v } From e0815997026e9b164c19b2e745fceab0dee397b1 Mon Sep 17 00:00:00 2001 From: zyxkad Date: Tue, 18 Jun 2024 10:06:42 -0600 Subject: [PATCH 02/36] refactoring cluster --- cluster.go | 309 +--------------------------------- cluster/cluster.go | 215 +++++++++++++++++++++-- token.go => cluster/config.go | 68 +++++++- cluster/handler.go | 29 ++++ cluster/keepalive.go | 69 ++++++-- cluster/socket.go | 33 ++++ cluster/status.go | 8 + config.go | 9 - lang/en/us.go | 1 + lang/zh/cn.go | 1 + log/tr.go | 44 +++++ main.go | 22 +-- storage/manager.go | 68 ++++++-- storage/storage.go | 12 +- storage/storage_local.go | 18 +- storage/storage_mount.go | 18 +- storage/storage_webdav.go | 18 +- utils/rand.go | 2 + 18 files changed, 546 insertions(+), 398 deletions(-) rename token.go => cluster/config.go (76%) create mode 100644 cluster/socket.go create mode 100644 log/tr.go diff --git a/cluster.go b/cluster.go index ea2a476b..592b7ed4 100644 --- a/cluster.go +++ b/cluster.go @@ -29,7 +29,6 @@ import ( "net/http" "os" "path/filepath" - "regexp" "runtime" "sync" "sync/atomic" @@ -52,10 +51,6 @@ import ( "github.com/LiterMC/go-openbmclapi/utils" ) -var ( - reFileHashMismatchError = regexp.MustCompile(` hash mismatch, expected ([0-9a-f]+), got ([0-9a-f]+)`) -) - type Cluster struct { host string // not the public access host, but maybe a public IP, or a host that will be resolved to the IP publicHosts []string // should not contains port, can be nil @@ -252,11 +247,11 @@ func (cr *Cluster) Init(ctx context.Context) (err error) { defer ticker.Stop() if err := cr.checkUpdate(); err != nil { - log.Errorf(Tr("error.update.check.failed"), err) + log.TrErrorf("error.update.check.failed", err) } for range ticker.C { if err := cr.checkUpdate(); err != nil { - log.Errorf(Tr("error.update.check.failed"), err) + log.TrErrorf("error.update.check.failed", err) } } }(cr.updateChecker) @@ -288,7 +283,7 @@ func (cr *Cluster) Connect(ctx context.Context) bool { _, err := cr.GetAuthToken(ctx) if err != nil { - log.Errorf(Tr("error.cluster.auth.failed"), err) + log.TrErrorf("error.cluster.auth.failed", err) osExit(CodeClientOrServerError) } @@ -338,7 +333,7 @@ func (cr *Cluster) Connect(ctx context.Context) bool { cr.reconnectCount++ if config.MaxReconnectCount > 0 && cr.reconnectCount >= config.MaxReconnectCount { if cr.shouldEnable.Load() { - log.Error(Tr("error.cluster.connect.failed.toomuch")) + log.TrErrorf("error.cluster.connect.failed.toomuch") osExit(CodeServerOrEnvionmentError) } } @@ -348,10 +343,10 @@ func (cr *Cluster) Connect(ctx context.Context) bool { }) engio.OnDialError(func(_ *engine.Socket, err error) { cr.reconnectCount++ - log.Errorf(Tr("error.cluster.connect.failed"), cr.reconnectCount, config.MaxReconnectCount, err) + log.TrErrorf("error.cluster.connect.failed", cr.reconnectCount, config.MaxReconnectCount, err) if config.MaxReconnectCount >= 0 && cr.reconnectCount >= config.MaxReconnectCount { if cr.shouldEnable.Load() { - log.Error(Tr("error.cluster.connect.failed.toomuch")) + log.TrErrorf("error.cluster.connect.failed.toomuch") osExit(CodeServerOrEnvionmentError) } } @@ -360,7 +355,7 @@ func (cr *Cluster) Connect(ctx context.Context) bool { cr.socket = socket.NewSocket(engio, socket.WithAuthTokenFn(func() string { token, err := cr.GetAuthToken(ctx) if err != nil { - log.Errorf(Tr("error.cluster.auth.failed"), err) + log.TrErrorf("error.cluster.auth.failed", err) osExit(CodeServerOrEnvionmentError) } return token @@ -373,7 +368,7 @@ func (cr *Cluster) Connect(ctx context.Context) bool { log.Debugf("shouldEnable is %v", cr.shouldEnable.Load()) if cr.shouldEnable.Load() { if err := cr.Enable(ctx); err != nil { - log.Errorf(Tr("error.cluster.enable.failed"), err) + log.TrErrorf("error.cluster.enable.failed", err) osExit(CodeClientOrEnvionmentError) } } @@ -406,229 +401,6 @@ func (cr *Cluster) Connect(ctx context.Context) bool { return true } -func (cr *Cluster) WaitForEnable() <-chan struct{} { - if cr.enabled.Load() { - return closedCh - } - - cr.mux.Lock() - defer cr.mux.Unlock() - - if cr.enabled.Load() { - return closedCh - } - ch := make(chan struct{}, 0) - cr.waitEnable = append(cr.waitEnable, ch) - return ch -} - -type EnableData struct { - Host string `json:"host"` - Port uint16 `json:"port"` - Version string `json:"version"` - Byoc bool `json:"byoc"` - NoFastEnable bool `json:"noFastEnable"` - Flavor ConfigFlavor `json:"flavor"` -} - -type ConfigFlavor struct { - Runtime string `json:"runtime"` - Storage string `json:"storage"` -} - -func (cr *Cluster) Enable(ctx context.Context) (err error) { - cr.mux.Lock() - defer cr.mux.Unlock() - - if cr.enabled.Load() { - log.Debug("Extra enable") - return - } - - if cr.socket != nil && !cr.socket.IO().Connected() && config.MaxReconnectCount == 0 { - log.Error(Tr("error.cluster.disconnected")) - osExit(CodeServerOrEnvionmentError) - return - } - - cr.shouldEnable.Store(true) - - storagesCount := make(map[string]int, 2) - for _, s := range cr.storageOpts { - switch s.Type { - case storage.StorageLocal: - storagesCount["file"]++ - case storage.StorageMount, storage.StorageWebdav: - storagesCount["alist"]++ - default: - log.Errorf("Unknown storage type %q", s.Type) - } - } - storageStr := "" - for s, _ := range storagesCount { - if len(storageStr) > 0 { - storageStr += "+" - } - storageStr += s - } - - log.Info(Tr("info.cluster.enable.sending")) - resCh, err := cr.socket.EmitWithAck("enable", EnableData{ - Host: cr.host, - Port: cr.publicPort, - Version: build.ClusterVersion, - Byoc: cr.byoc, - NoFastEnable: config.Advanced.NoFastEnable, - Flavor: ConfigFlavor{ - Runtime: "golang/" + runtime.GOOS + "-" + runtime.GOARCH, - Storage: storageStr, - }, - }) - if err != nil { - return - } - var data []any - tctx, cancel := context.WithTimeout(ctx, time.Minute*6) - select { - case <-tctx.Done(): - cancel() - return tctx.Err() - case data = <-resCh: - cancel() - } - log.Debug("got enable ack:", data) - if ero := data[0]; ero != nil { - if ero, ok := ero.(map[string]any); ok { - if msg, ok := ero["message"].(string); ok { - if hashMismatch := reFileHashMismatchError.FindStringSubmatch(msg); hashMismatch != nil { - hash := hashMismatch[1] - log.Warnf("Detected hash mismatch error, removing bad file %s", hash) - for _, s := range cr.storages { - s.Remove(hash) - } - } - return fmt.Errorf("Enable failed: %v", msg) - } - } - return fmt.Errorf("Enable failed: %v", ero) - } - if !data[1].(bool) { - return errors.New("Enable ack non true value") - } - log.Info(Tr("info.cluster.enabled")) - cr.reconnectCount = 0 - cr.disabled = make(chan struct{}, 0) - cr.enabled.Store(true) - for _, ch := range cr.waitEnable { - close(ch) - } - cr.waitEnable = cr.waitEnable[:0] - go cr.notifyManager.OnEnabled() - - const maxFailCount = 3 - var ( - keepaliveCtx context.Context - failedCount = 0 - ) - keepaliveCtx, cr.cancelKeepalive = context.WithCancel(ctx) - createInterval(keepaliveCtx, func() { - tctx, cancel := context.WithTimeout(keepaliveCtx, KeepAliveInterval/2) - status := cr.KeepAlive(tctx) - cancel() - if status == 0 { - failedCount = 0 - return - } - if status == -1 { - log.Errorf("Kicked by remote server!!!") - osExit(CodeEnvironmentError) - return - } - if keepaliveCtx.Err() == nil { - if tctx.Err() != nil { - failedCount++ - log.Warnf("keep-alive failed (%d/%d)", failedCount, maxFailCount) - if failedCount < maxFailCount { - return - } - } - log.Info(Tr("info.cluster.reconnect.keepalive")) - cr.disable(ctx) - log.Info(Tr("info.cluster.reconnecting")) - if !cr.Connect(ctx) { - log.Error(Tr("error.cluster.reconnect.failed")) - if ctx.Err() != nil { - return - } - osExit(CodeServerOrEnvionmentError) - } - if err := cr.Enable(ctx); err != nil { - log.Errorf(Tr("error.cluster.enable.failed"), err) - if ctx.Err() != nil { - return - } - osExit(CodeClientOrEnvionmentError) - } - } - }, KeepAliveInterval) - return -} - -// KeepAlive will fresh hits & hit bytes data and send the keep-alive packet -func (cr *Cluster) KeepAlive(ctx context.Context) (status int) { - hits, hbts := cr.stats.GetTmpHits() - lhits, lhbts := cr.lastHits.Load(), cr.lastHbts.Load() - hits2, hbts2 := cr.statOnlyHits.Load(), cr.statOnlyHbts.Load() - ahits, ahbts := hits-lhits-hits2, hbts-lhbts-hbts2 - resCh, err := cr.socket.EmitWithAck("keep-alive", Map{ - "time": time.Now().UTC().Format("2006-01-02T15:04:05Z"), - "hits": ahits, - "bytes": ahbts, - }) - go cr.notifyManager.OnReportStatus(&cr.stats) - - if e := cr.stats.Save(cr.dataDir); e != nil { - log.Errorf(Tr("error.cluster.stat.save.failed"), e) - } - if err != nil { - log.Errorf(Tr("error.cluster.keepalive.send.failed"), err) - return 1 - } - var data []any - select { - case <-ctx.Done(): - return 1 - case data = <-resCh: - } - log.Debugf("Keep-alive response: %v", data) - if ero := data[0]; len(data) <= 1 || ero != nil { - if ero, ok := ero.(map[string]any); ok { - if msg, ok := ero["message"].(string); ok { - if hashMismatch := reFileHashMismatchError.FindStringSubmatch(msg); hashMismatch != nil { - hash := hashMismatch[1] - log.Warnf("Detected hash mismatch error, removing bad file %s", hash) - for _, s := range cr.storages { - s.Remove(hash) - } - } - log.Errorf(Tr("error.cluster.keepalive.failed"), msg) - return 1 - } - } - log.Errorf(Tr("error.cluster.keepalive.failed"), ero) - return 1 - } - log.Infof(Tr("info.cluster.keepalive.success"), ahits, utils.BytesToUnit((float64)(ahbts)), data[1]) - cr.lastHits.Store(hits) - cr.lastHbts.Store(hbts) - cr.statOnlyHits.Add(-hits2) - cr.statOnlyHbts.Add(-hbts2) - if data[1] == false { - return -1 - } - return 0 -} - func (cr *Cluster) disconnected() bool { cr.mux.Lock() defer cr.mux.Unlock() @@ -644,11 +416,6 @@ func (cr *Cluster) disconnected() bool { return true } -func (cr *Cluster) Disable(ctx context.Context) (ok bool) { - cr.shouldEnable.Store(false) - return cr.disable(ctx) -} - func (cr *Cluster) disable(ctx context.Context) (ok bool) { cr.mux.Lock() defer cr.mux.Unlock() @@ -698,63 +465,3 @@ func (cr *Cluster) disable(ctx context.Context) (ok bool) { log.Warn(Tr("warn.cluster.disabled")) return } - -func (cr *Cluster) Enabled() bool { - return cr.enabled.Load() -} - -func (cr *Cluster) Disabled() <-chan struct{} { - cr.mux.RLock() - defer cr.mux.RUnlock() - return cr.disabled -} - -type CertKeyPair struct { - Cert string `json:"cert"` - Key string `json:"key"` -} - -func (cr *Cluster) RequestCert(ctx context.Context) (ckp *CertKeyPair, err error) { - resCh, err := cr.socket.EmitWithAck("request-cert") - if err != nil { - return - } - var data []any - select { - case <-ctx.Done(): - return nil, ctx.Err() - case data = <-resCh: - } - if ero := data[0]; ero != nil { - err = fmt.Errorf("socket.io remote error: %v", ero) - return - } - pair := data[1].(map[string]any) - ckp = &CertKeyPair{ - Cert: pair["cert"].(string), - Key: pair["key"].(string), - } - return -} - -func (cr *Cluster) GetConfig(ctx context.Context) (cfg *OpenbmclapiAgentConfig, err error) { - req, err := cr.makeReqWithAuth(ctx, http.MethodGet, "/openbmclapi/configuration", nil) - if err != nil { - return - } - res, err := cr.cachedCli.Do(req) - if err != nil { - return - } - defer res.Body.Close() - if res.StatusCode != http.StatusOK { - err = utils.NewHTTPStatusErrorFromResponse(res) - return - } - cfg = new(OpenbmclapiAgentConfig) - if err = json.NewDecoder(res.Body).Decode(cfg); err != nil { - cfg = nil - return - } - return -} diff --git a/cluster/cluster.go b/cluster/cluster.go index 398d32a3..f7198d57 100644 --- a/cluster/cluster.go +++ b/cluster/cluster.go @@ -21,56 +21,243 @@ package cluster import ( "context" + "fmt" + "regexp" + "runtime" + "sync" "sync/atomic" + "time" + + "github.com/LiterMC/socket.io" + + "github.com/LiterMC/go-openbmclapi/internal/build" + "github.com/LiterMC/go-openbmclapi/log" + "github.com/LiterMC/go-openbmclapi/storage" +) + +var ( + reFileHashMismatchError = regexp.MustCompile(` hash mismatch, expected ([0-9a-f]+), got ([0-9a-f]+)`) ) +type ClusterOptions struct { + Id string `json:"id" yaml:"id"` + Secret string `json:"secret" yaml:"secret"` + PublicHosts []string `json:"public-hosts" yaml:"public-hosts"` + Prefix string `json:"prefix" yaml:"prefix"` +} + +type ClusterGeneralConfig struct { + Host string `json:"host" yaml:"host"` + Port uint16 `json:"port" yaml:"port"` + Byoc bool `json:"byoc" yaml:"byoc"` + NoFastEnable bool `json:"no-fast-enable" yaml:"no-fast-enable"` +} + type Cluster struct { - id string - secret string - host string - port uint16 + opts ClusterOptions + gcfg ClusterGeneralConfig storageManager *storage.Manager storages []int // the index of storages in the storage manager + enableSignals []chan bool + disableSignal chan struct{} + + mux sync.RWMutex status atomic.Int32 + socket *socket.Socket +} + +func NewCluster( + opts ClusterOptions, gcfg ClusterGeneralConfig, + storageManager *storage.Manager, storages []int, +) (cr *Cluster) { + cr = &Cluster{ + opts: opts, + gcfg: gcfg, + + storageManager: storageManager, + storages: storages, + } + return } // ID returns the cluster id func (cr *Cluster) ID() string { - return cr.id + return cr.opts.Id } // Host returns the cluster public host func (cr *Cluster) Host() string { - return cr.host + return cr.gcfg.Host } // Port returns the cluster public port -func (cr *Cluster) Port() string { - return cr.port +func (cr *Cluster) Port() uint16 { + return cr.gcfg.Port +} + +// PublicHosts returns the cluster public hosts +func (cr *Cluster) PublicHosts() []string { + return cr.opts.PublicHosts } // Init do setup on the cluster // Init should only be called once during the cluster's whole life // The context passed in only affect the logical of Init method func (cr *Cluster) Init(ctx context.Context) error { - return + return nil +} + +type EnableData struct { + Host string `json:"host"` + Port uint16 `json:"port"` + Version string `json:"version"` + Byoc bool `json:"byoc"` + NoFastEnable bool `json:"noFastEnable"` + Flavor ConfigFlavor `json:"flavor"` +} + +type ConfigFlavor struct { + Runtime string `json:"runtime"` + Storage string `json:"storage"` } // Enable send enable packet to central server // The context passed in only affect the logical of Enable method func (cr *Cluster) Enable(ctx context.Context) error { - return + if cr.status.Load() == clusterEnabled { + return nil + } + cr.mux.Lock() + defer cr.mux.Unlock() + if cr.status.Load() == clusterEnabled { + return nil + } + defer func() { + enabled := cr.Running() + for _, ch := range cr.enableSignals { + ch <- enabled + } + cr.enableSignals = cr.enableSignals[:0] + }() + oldStatus := cr.status.Swap(clusterEnabling) + defer cr.status.CompareAndSwap(clusterEnabling, oldStatus) + + storageStr := cr.storageManager.GetFlavorString(cr.storages) + + log.TrInfof("info.cluster.enable.sending") + resCh, err := cr.socket.EmitWithAck("enable", EnableData{ + Host: cr.gcfg.Host, + Port: cr.gcfg.Port, + Version: build.ClusterVersion, + Byoc: cr.gcfg.Byoc, + NoFastEnable: cr.gcfg.NoFastEnable, + Flavor: ConfigFlavor{ + Runtime: "golang/" + runtime.GOOS + "-" + runtime.GOARCH, + Storage: storageStr, + }, + }) + if err != nil { + return err + } + var data []any + { + tctx, cancel := context.WithTimeout(ctx, time.Minute*6) + select { + case data = <-resCh: + cancel() + case <-tctx.Done(): + cancel() + return tctx.Err() + } + } + log.Debug("got enable ack:", data) + if ero := data[0]; ero != nil { + if ero, ok := ero.(map[string]any); ok { + if msg, ok := ero["message"].(string); ok { + if hashMismatch := reFileHashMismatchError.FindStringSubmatch(msg); hashMismatch != nil { + hash := hashMismatch[1] + log.Warnf(Tr("warn.cluster.detected.hash.mismatch"), hash) + cr.storageManager.RemoveForAll(hash) + } + return fmt.Errorf("Enable failed: %v", msg) + } + } + return fmt.Errorf("Enable failed: %v", ero) + } + if v := data[1]; !v.(bool) { + return fmt.Errorf("FATAL: Enable ack non true value, got (%T) %#v", v, v) + } + cr.disableSignal = make(chan struct{}, 0) + log.TrInfof("info.cluster.enabled") + cr.status.Store(clusterEnabled) + return nil } // Disable send disable packet to central server // The context passed in only affect the logical of Disable method +// Disable method is thread-safe, and it will wait until the first invoke exited func (cr *Cluster) Disable(ctx context.Context) error { - return + if cr.Enabled() { + cr.mux.Lock() + defer cr.mux.Unlock() + if cr.Enabled() { + defer close(cr.disableSignal) + defer cr.status.Store(clusterDisabled) + return cr.disable(ctx) + } + } + cr.mux.RLock() + disableCh := cr.disableSignal + cr.mux.RUnlock() + select { + case <-disableCh: + case <-ctx.Done(): + return ctx.Err() + } + return nil } -// setDisabled marked the cluster as disabled or kicked -func (cr *Cluster) setDisabled(kicked bool) { - return +// disable send disable packet to central server +// The context passed in only affect the logical of disable method +func (cr *Cluster) disable(ctx context.Context) error { + log.TrInfof("info.cluster.disabling") + resCh, err := cr.socket.EmitWithAck("disable", nil) + if err != nil { + return err + } + select { + case <-ctx.Done(): + return ctx.Err() + case data := <-resCh: + log.Debug("disable ack:", data) + if ero := data[0]; ero != nil { + return fmt.Errorf("Disable failed: %v", ero) + } else if !data[1].(bool) { + return errors.New("Disable acked non true value") + } + } + return nil +} + +// markDisconnected marked the cluster as error or kicked +func (cr *Cluster) markDisconnected(kicked bool) { + if !cr.Enabled() { + return + } + cr.mux.Lock() + defer cr.mux.Unlock() + if cr.Enabled() { + return + } + defer close(cr.disableSignal) + + var nextStatus int32 + if kicked { + nextStatus = clusterKicked + } else { + nextStatus = clusterError + } + cr.status.Store(nextStatus) } diff --git a/token.go b/cluster/config.go similarity index 76% rename from token.go rename to cluster/config.go index a86f1b91..7613cb1e 100644 --- a/token.go +++ b/cluster/config.go @@ -17,7 +17,7 @@ * along with this program. If not, see . */ -package main +package cluster import ( "bytes" @@ -26,6 +26,7 @@ import ( "crypto/hmac" "encoding/hex" "encoding/json" + "fmt" "net/http" "net/url" "time" @@ -193,3 +194,68 @@ func (cr *Cluster) refreshToken(ctx context.Context, oldToken string) (token *Cl ExpireAt: time.Now().Add((time.Duration)(res.TTL)*time.Millisecond - 10*time.Second), }, nil } + +type OpenbmclapiAgentConfig struct { + Sync OpenbmclapiAgentSyncConfig `json:"sync"` +} + +type OpenbmclapiAgentSyncConfig struct { + Source string `json:"source"` + Concurrency int `json:"concurrency"` +} + +func (cr *Cluster) GetConfig(ctx context.Context) (cfg *OpenbmclapiAgentConfig, err error) { + req, err := cr.makeReqWithAuth(ctx, http.MethodGet, "/openbmclapi/configuration", nil) + if err != nil { + return + } + res, err := cr.cachedCli.Do(req) + if err != nil { + return + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + err = utils.NewHTTPStatusErrorFromResponse(res) + return + } + cfg = new(OpenbmclapiAgentConfig) + if err = json.NewDecoder(res.Body).Decode(cfg); err != nil { + cfg = nil + return + } + return +} + +type CertKeyPair struct { + Cert string `json:"cert"` + Key string `json:"key"` +} + +func (cr *Cluster) RequestCert(ctx context.Context) (ckp *CertKeyPair, err error) { + resCh, err := cr.socket.EmitWithAck("request-cert") + if err != nil { + return + } + var data []any + select { + case <-ctx.Done(): + return nil, ctx.Err() + case data = <-resCh: + } + if ero := data[0]; ero != nil { + err = fmt.Errorf("socket.io remote error: %v", ero) + return + } + pair := data[1].(map[string]any) + ckp = new(CertKeyPair) + var ok bool + if ckp.Cert, ok = pair["cert"].(string); !ok { + err = fmt.Errorf(`"cert" is not a string, got %T`, pair["cert"]) + return + } + if ckp.Key, ok = pair["key"].(string); !ok { + err = fmt.Errorf(`"key" is not a string, got %T`, pair["key"]) + return + } + return +} diff --git a/cluster/handler.go b/cluster/handler.go index c79366ca..9f47f319 100644 --- a/cluster/handler.go +++ b/cluster/handler.go @@ -18,3 +18,32 @@ */ package cluster + +import ( + "net/http" +) + +func (cr *Cluster) HandleFile(req *http.Request, rw http.ResponseWriter, hash string) { + if cr.storageManager.ForEachFromRandom(cr.storages, func(s storage.Storage) bool { + log.Debugf("[handler]: Checking %s on storage [%d] %s ...", hash, i, sto.String()) + + sz, er := sto.ServeDownload(rw, req, hash, size) + if er != nil { + log.Debugf("[handler]: File %s failed on storage [%d] %s: %v", hash, i, sto.String(), er) + err = er + return false + } + if sz >= 0 { + opts := cr.storageOpts[i] + cr.AddHits(1, sz, s.Options().Id) + if !keepaliveRec { + cr.statOnlyHits.Add(1) + cr.statOnlyHbts.Add(sz) + } + } + return true + }) { + return + } + http.Error(http.StatusInternation) +} diff --git a/cluster/keepalive.go b/cluster/keepalive.go index c57030eb..8d6362e9 100644 --- a/cluster/keepalive.go +++ b/cluster/keepalive.go @@ -21,25 +21,68 @@ package cluster import ( "context" + "time" ) type KeepAliveRes int -// Succeed returns true when KeepAlive actions is succeed as well as the cluster is not kicked by the controller -func (r KeepAliveRes) Succeed() bool { - return r == 0 -} - -// Failed returns true when KeepAlive action is succeed but the cluster is forced kick by the controller -func (r KeepAliveRes) Kicked() bool { - return r == 1 -} +const ( + KeepAliveSucceed KeepAliveRes = iota + KeepAliveFailed + KeepAliveKicked +) -// Failed returns true when KeepAlive is interrupted by unexpected reason -func (r KeepAliveRes) Failed() bool { - return r == 2 +type keepAliveReq struct { + Time string `json:"time"` + Hits int32 `json:"hits"` + Bytes int64 `json:"bytes"` } +// KeepAlive will send the keep-alive packet and fresh hits & hit bytes data func (cr *Cluster) KeepAlive(ctx context.Context) KeepAliveRes { - // + hits, hbts := cr.hits.Load(), cr.hbts.Load() + resCh, err := cr.socket.EmitWithAck("keep-alive", keepAliveReq{ + Time: time.Now().UTC().Format("2006-01-02T15:04:05Z"), + Hits: hits, + Bytes: hbts, + }) + + if e := cr.stats.Save(cr.dataDir); e != nil { + log.Errorf(Tr("error.cluster.stat.save.failed"), e) + } + if err != nil { + log.Errorf(Tr("error.cluster.keepalive.send.failed"), err) + return KeepAliveFailed + } + var data []any + select { + case <-ctx.Done(): + return KeepAliveFailed + case data = <-resCh: + } + log.Debugf("Keep-alive response: %v", data) + if ero := data[0]; len(data) <= 1 || ero != nil { + if ero, ok := ero.(map[string]any); ok { + if msg, ok := ero["message"].(string); ok { + log.Errorf(Tr("error.cluster.keepalive.failed"), msg) + if hashMismatch := reFileHashMismatchError.FindStringSubmatch(msg); hashMismatch != nil { + hash := hashMismatch[1] + log.Warnf("Detected hash mismatch error, removing bad file %s", hash) + for _, s := range cr.storages { + go s.Remove(hash) + } + } + return KeepAliveFailed + } + } + log.Errorf(Tr("error.cluster.keepalive.failed"), ero) + return KeepAliveFailed + } + log.Infof(Tr("info.cluster.keepalive.success"), ahits, utils.BytesToUnit((float64)(ahbts)), data[1]) + cr.hits.Add(-hits2) + cr.hbts.Add(-hbts2) + if data[1] == false { + return KeepAliveKicked + } + return KeepAliveSucceed } diff --git a/cluster/socket.go b/cluster/socket.go new file mode 100644 index 00000000..fcb3e2b2 --- /dev/null +++ b/cluster/socket.go @@ -0,0 +1,33 @@ +/** + * OpenBmclAPI (Golang Edition) + * Copyright (C) 2024 Kevin Z + * All rights reserved + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published + * by the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package cluster + +import ( + "context" + + "github.com/LiterMC/socket.io" + "github.com/LiterMC/socket.io/engine.io" +) + +// Connect connects to the central server +// The context passed in only affect the logical of Connect method +func (cr *Cluster) Connect(ctx context.Context) error { + return +} diff --git a/cluster/status.go b/cluster/status.go index 922d96ff..f08b9586 100644 --- a/cluster/status.go +++ b/cluster/status.go @@ -24,6 +24,7 @@ const ( clusterEnabled = 1 clusterEnabling = 2 clusterKicked = 4 + clusterError = 5 ) // Enabled returns true if the cluster is enabled or enabling @@ -47,10 +48,17 @@ func (cr *Cluster) IsKicked() bool { return cr.status.Load() == clusterKicked } +// IsError returns true if the cluster is disabled since connection error +func (cr *Cluster) IsError() bool { + return cr.status.Load() == clusterError +} + // WaitForEnable returns a channel which receives true when cluster enabled succeed, or receives false when it failed to enable // If the cluster is already enable, the channel always returns true // The channel should not be used multiple times func (cr *Cluster) WaitForEnable() <-chan bool { + cr.mux.Lock() + defer cr.mux.Unlock() ch := make(chan bool, 1) if cr.Running() { ch <- true diff --git a/config.go b/config.go index e3ddb66c..6baa51d5 100644 --- a/config.go +++ b/config.go @@ -478,12 +478,3 @@ func readConfig() (config Config) { } return } - -type OpenbmclapiAgentSyncConfig struct { - Source string `json:"source"` - Concurrency int `json:"concurrency"` -} - -type OpenbmclapiAgentConfig struct { - Sync OpenbmclapiAgentSyncConfig `json:"sync"` -} diff --git a/lang/en/us.go b/lang/en/us.go index b836f763..7f0df686 100644 --- a/lang/en/us.go +++ b/lang/en/us.go @@ -8,6 +8,7 @@ var areaUS = map[string]string{ "program.exited": "Program exiting with code %d", "error.exit.please.read.faq": "Please read https://github.com/LiterMC/go-openbmclapi?tab=readme-ov-file#faq before report your issue", "warn.exit.detected.windows.open.browser": "Detected that you are in windows environment, we are helping you to open the browser", + "warn.cluster.detected.hash.mismatch": "Detected hash mismatch error, removing bad file %s", "info.filelist.fetching": "Fetching file list", "error.filelist.fetch.failed": "Cannot fetch cluster file list: %v", diff --git a/lang/zh/cn.go b/lang/zh/cn.go index ed16e39a..35c359e0 100644 --- a/lang/zh/cn.go +++ b/lang/zh/cn.go @@ -8,6 +8,7 @@ var areaCN = map[string]string{ "program.exited": "节点正在退出, 代码 %d", "error.exit.please.read.faq": "请在提交问题前阅读 https://github.com/LiterMC/go-openbmclapi?tab=readme-ov-file#faq", "warn.exit.detected.windows.open.browser": "检测到您是新手 Windows 用户. 我们正在帮助您打开浏览器 ...", + "warn.cluster.detected.hash.mismatch": "检测到文件哈希值不匹配, 正在删除 %s", "info.filelist.fetching": "获取文件列表中", "error.filelist.fetch.failed": "文件列表获取失败: %v", diff --git a/log/tr.go b/log/tr.go new file mode 100644 index 00000000..0470801c --- /dev/null +++ b/log/tr.go @@ -0,0 +1,44 @@ +/** + * OpenBmclAPI (Golang Edition) + * Copyright (C) 2024 Kevin Z + * All rights reserved + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published + * by the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package log + +import ( + "github.com/LiterMC/go-openbmclapi/lang" +) + +func TrDebugf(key string, vals ...any) { + Debugf(lang.Tr(key), vals...) +} + +func TrInfof(key string, vals ...any) { + Infof(lang.Tr(key), vals...) +} + +func TrWarnf(key string, vals ...any) { + Warnf(lang.Tr(key), vals...) +} + +func TrErrorf(key string, vals ...any) { + Errorf(lang.Tr(key), vals...) +} + +func TrPanicf(key string, vals ...any) { + Panicf(lang.Tr(key), vals...) +} diff --git a/main.go b/main.go index 97496571..663550d6 100644 --- a/main.go +++ b/main.go @@ -131,10 +131,10 @@ func main() { } } if code != 0 { - log.Errorf(Tr("program.exited"), code) - log.Error(Tr("error.exit.please.read.faq")) + log.TrErrorf("program.exited", code) + log.TrErrorf("error.exit.please.read.faq") if runtime.GOOS == "windows" && !config.Advanced.DoNotOpenFAQOnWindows { - log.Warn(Tr("warn.exit.detected.windows.open.browser")) + log.TrWarnf("warn.exit.detected.windows.open.browser") cmd := exec.Command("cmd", "/C", "start", "https://cdn.crashmc.com/https://github.com/LiterMC/go-openbmclapi?tab=readme-ov-file#faq") cmd.Start() time.Sleep(time.Hour) @@ -164,10 +164,10 @@ START: config.applyWebManifest(dsbManifest) - log.Infof(Tr("program.starting"), build.ClusterVersion, build.BuildVersion) + log.TrInfof("program.starting", build.ClusterVersion, build.BuildVersion) if config.ClusterId == defaultConfig.ClusterId || config.ClusterSecret == defaultConfig.ClusterSecret { - log.Error(Tr("error.set.cluster.id")) + log.TrErrorf("error.set.cluster.id") osExit(CodeClientError) } @@ -204,16 +204,16 @@ START: } if !config.Tunneler.Enable { strPort := strconv.Itoa((int)(r.getPublicPort())) - log.Infof(Tr("info.server.public.at"), net.JoinHostPort(publicHost, strPort), r.clusterSvr.Addr, r.getCertCount()) + log.TrInfof("info.server.public.at", net.JoinHostPort(publicHost, strPort), r.clusterSvr.Addr, r.getCertCount()) if len(r.publicHosts) > 1 { - log.Info(Tr("info.server.alternative.hosts")) + log.TrInfof("info.server.alternative.hosts") for _, h := range r.publicHosts[1:] { log.Infof("\t- https://%s", net.JoinHostPort(h, strPort)) } } } - log.Info(Tr("info.wait.first.sync")) + log.TrInfof("info.wait.first.sync") select { case <-firstSyncDone: case <-ctx.Done(): @@ -308,13 +308,13 @@ func (r *Runner) DoSignals(cancel context.CancelFunc) int { cancel() shutCtx, cancelShut := context.WithTimeout(context.Background(), time.Second*15) - log.Warn(Tr("warn.server.closing")) + log.TrWarnf("warn.server.closing") shutExit := make(chan struct{}, 0) go func() { defer close(shutExit) defer cancelShut() r.cluster.Disable(shutCtx) - log.Warn(Tr("warn.httpserver.closing")) + log.TrWarnf("warn.httpserver.closing") r.clusterSvr.Shutdown(shutCtx) }() select { @@ -324,7 +324,7 @@ func (r *Runner) DoSignals(cancel context.CancelFunc) int { log.Error("Second close signal received, exit") return CodeClientError } - log.Warn(Tr("warn.server.closed")) + log.TrWarnf("warn.server.closed") if s == syscall.SIGHUP { log.Info("Restarting server ...") r.restartFlag = true diff --git a/storage/manager.go b/storage/manager.go index d4d4d554..8747fef7 100644 --- a/storage/manager.go +++ b/storage/manager.go @@ -20,41 +20,73 @@ package storage import ( + "github.com/LiterMC/go-openbmclapi/log" "github.com/LiterMC/go-openbmclapi/utils" ) // Manager manages a list of storages type Manager struct { - Options []StorageOption - Storages []Storage - weights []uint - totalWeight uint - totalWeightsCache utils.SyncMap[[]string, *weightCache] + Storages []Storage + weights []uint + totalWeight uint + totalWeightsCache *utils.SyncMap[int, *weightCache] } -func NewManager(opts []StorageOption, storages []Storage) (m *Manager) { +func NewManager(storages []Storage) (m *Manager) { m = new(Manager) - m.Options = opts m.Storages = storages - m.weights = make([]uint, len(opts)) + m.weights = make([]uint, len(storages)) m.totalWeight = 0 - m.totalWeightsCache = utils.NewSyncMap[[]string, *weightCache]() - for i, s := range opts { - m.weights[i] = s.Weight - m.totalWeight += s.Weight + m.totalWeightsCache = utils.NewSyncMap[int, *weightCache]() + for i, s := range storages { + w := s.Options().Weight + m.weights[i] = w + m.totalWeight += w } return } -type weightCache struct{ +func (m *Manager) GetFlavorString(storages []int) string { + typeCount := make(map[string]int, 2) + for _, i := range storages { + t := m.Storages[i].Options().Type + switch t { + case StorageLocal: + typeCount["file"]++ + case StorageMount, StorageWebdav: + typeCount["alist"]++ + default: + log.Errorf("Unknown storage type %q", t) + } + } + flavor := "" + for s, _ := range typeCount { + if len(flavor) > 0 { + flavor += "+" + } + flavor += s + } + return flavor +} + +type weightCache struct { weights []uint - total uint + total uint +} + +func calcStoragesCacheKey(storages []int) int { + key := len(storages) + for _, v := range storages { + key = key*31 + v + } + return key } func (m *Manager) ForEachFromRandom(storages []int, cb func(s Storage) (done bool)) (done bool) { - data, _ := m.totalWeightsCache.GetOrSet(storages, func() (c *weightCache) { + cacheKey := calcStoragesCacheKey(storages) + data, _ := m.totalWeightsCache.GetOrSet(cacheKey, func() (c *weightCache) { c = new(weightCache) - c.weights = make([]int, len(storages)) + c.weights = make([]uint, len(storages)) for i, j := range storages { w := m.weights[j] c.weights[i] = w @@ -62,7 +94,9 @@ func (m *Manager) ForEachFromRandom(storages []int, cb func(s Storage) (done boo } return }) - return forEachFromRandomIndexWithPossibility(data.weights, data.total, cb) + return forEachFromRandomIndexWithPossibility(data.weights, data.total, func(i int) bool { + return cb(m.Storages[i]) + }) } func forEachFromRandomIndex(leng int, cb func(i int) (done bool)) (done bool) { diff --git a/storage/storage.go b/storage/storage.go index 5068857b..5c5c4a83 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -35,11 +35,8 @@ import ( type Storage interface { fmt.Stringer - // Options should return the pointer of the storage options - // which should be able to marshal/unmarshal with yaml format - Options() any - // SetOptions will be called with the same type of the Options() result - SetOptions(any) + // Options should return the pointer of the StorageOption that should not be modified. + Options() *StorageOption // Init will be called before start to use a storage Init(context.Context) error CheckUpload(context.Context) error @@ -61,7 +58,7 @@ const ( ) type StorageFactory struct { - New func() Storage + New func(StorageOption) Storage NewConfig func() any } @@ -78,8 +75,7 @@ func RegisterStorageFactory(typ string, inst StorageFactory) { } func NewStorage(opt StorageOption) Storage { - s := storageFactories[opt.Type].New() - s.SetOptions(opt.Data) + s := storageFactories[opt.Type].New(opt) return s } diff --git a/storage/storage_local.go b/storage/storage_local.go index f9b7a9f0..a4c4a8e0 100644 --- a/storage/storage_local.go +++ b/storage/storage_local.go @@ -42,14 +42,20 @@ type LocalStorageOption struct { } type LocalStorage struct { - opt LocalStorageOption + basicOpt StorageOption + opt LocalStorageOption } var _ Storage = (*LocalStorage)(nil) func init() { RegisterStorageFactory(StorageLocal, StorageFactory{ - New: func() Storage { return new(LocalStorage) }, + New: func(opt StorageOption) Storage { + return &LocalStorage{ + basicOpt: opt, + opt: *(opt.Data.(*LocalStorageOption)), + } + }, NewConfig: func() any { return new(LocalStorageOption) }, }) } @@ -58,12 +64,8 @@ func (s *LocalStorage) String() string { return fmt.Sprintf("", s.opt.CachePath) } -func (s *LocalStorage) Options() any { - return &s.opt -} - -func (s *LocalStorage) SetOptions(newOpts any) { - s.opt = *(newOpts.(*LocalStorageOption)) +func (s *LocalStorage) Options() *StorageOption { + return &s.basicOpt } func (s *LocalStorage) Init(context.Context) (err error) { diff --git a/storage/storage_mount.go b/storage/storage_mount.go index 5d1e312b..62a75e4a 100644 --- a/storage/storage_mount.go +++ b/storage/storage_mount.go @@ -53,7 +53,8 @@ func (opt *MountStorageOption) CachePath() string { } type MountStorage struct { - opt MountStorageOption + basicOpt StorageOption + opt MountStorageOption supportRange atomic.Bool working atomic.Int32 @@ -65,7 +66,12 @@ var _ Storage = (*MountStorage)(nil) func init() { RegisterStorageFactory(StorageMount, StorageFactory{ - New: func() Storage { return new(MountStorage) }, + New: func(opt StorageOption) Storage { + return &MountStorage{ + basicOpt: opt, + opt: *(opt.Data.(*MountStorageOption)), + } + }, NewConfig: func() any { return new(MountStorageOption) }, }) } @@ -74,12 +80,8 @@ func (s *MountStorage) String() string { return fmt.Sprintf("", s.opt.Path, s.opt.RedirectBase) } -func (s *MountStorage) Options() any { - return &s.opt -} - -func (s *MountStorage) SetOptions(newOpts any) { - s.opt = *(newOpts.(*MountStorageOption)) +func (s *MountStorage) Options() *StorageOption { + return &s.basicOpt } var checkerClient = &http.Client{ diff --git a/storage/storage_webdav.go b/storage/storage_webdav.go index 2cff835d..f9443598 100644 --- a/storage/storage_webdav.go +++ b/storage/storage_webdav.go @@ -121,7 +121,8 @@ func (o *WebDavStorageOption) GetPassword() string { } type WebDavStorage struct { - opt WebDavStorageOption + basicOpt StorageOption + opt WebDavStorageOption cache gocache.Cache cli *gowebdav.Client @@ -139,7 +140,12 @@ var _ Storage = (*WebDavStorage)(nil) func init() { RegisterStorageFactory(StorageWebdav, StorageFactory{ - New: func() Storage { return new(WebDavStorage) }, + New: func(opt StorageOption) Storage { + return &WebDavStorage{ + basicOpt: opt, + opt: *(opt.Data.(*WebDavStorageOption)), + } + }, NewConfig: func() any { return new(WebDavStorageOption) }, }) } @@ -148,12 +154,8 @@ func (s *WebDavStorage) String() string { return fmt.Sprintf("", s.opt.GetEndPoint(), s.opt.GetUsername()) } -func (s *WebDavStorage) Options() any { - return &s.opt -} - -func (s *WebDavStorage) SetOptions(newOpts any) { - s.opt = *(newOpts.(*WebDavStorageOption)) +func (s *WebDavStorage) Options() *StorageOption { + return &s.basicOpt } func webdavIsHTTPError(err error, code int) bool { diff --git a/utils/rand.go b/utils/rand.go index 37497eb3..eb6a86b2 100644 --- a/utils/rand.go +++ b/utils/rand.go @@ -17,6 +17,8 @@ * along with this program. If not, see . */ +package utils + import ( "math/rand" "time" From 5e89ede8280d987e8422acdb297fd32d8a973177 Mon Sep 17 00:00:00 2001 From: zyxkad Date: Thu, 20 Jun 2024 00:25:43 -0600 Subject: [PATCH 03/36] add socket logic --- cluster/cluster.go | 67 +++++++++++++++++++++++++----- cluster/keepalive.go | 2 + cluster/socket.go | 99 +++++++++++++++++++++++++++++++++++++++++++- cluster/status.go | 22 +++++++--- 4 files changed, 172 insertions(+), 18 deletions(-) diff --git a/cluster/cluster.go b/cluster/cluster.go index f7198d57..67f93bc9 100644 --- a/cluster/cluster.go +++ b/cluster/cluster.go @@ -143,7 +143,10 @@ func (cr *Cluster) Enable(ctx context.Context) error { }() oldStatus := cr.status.Swap(clusterEnabling) defer cr.status.CompareAndSwap(clusterEnabling, oldStatus) + return cr.enable(ctx) +} +func (cr *Cluster) enable(ctx context.Context) error { storageStr := cr.storageManager.GetFlavorString(cr.storages) log.TrInfof("info.cluster.enable.sending") @@ -189,15 +192,64 @@ func (cr *Cluster) Enable(ctx context.Context) error { if v := data[1]; !v.(bool) { return fmt.Errorf("FATAL: Enable ack non true value, got (%T) %#v", v, v) } - cr.disableSignal = make(chan struct{}, 0) + disableSignal := make(chan struct{}, 0) + cr.disableSignal = disableSignal log.TrInfof("info.cluster.enabled") cr.status.Store(clusterEnabled) + cr.socket.OnceConnect(func(_ *socket.Socket, ns string) { + if ns != "" { + return + } + if cr.status.Load() != clusterEnabled { + return + } + select { + case <-disableSignal: + return + default: + } + cr.status.Store(clusterEnabling) + go cr.reEnable(disableSignal) + }) return nil } +func (cr *Cluster) reEnable(disableSignal <-chan struct{}) { + tctx, cancel := context.WithTimeout(context.Background(), time.Minute*7) + go func() { + select { + case <-tctx.Done(): + case <-disableSignal: + cancel() + } + }() + err := cr.enable(tctx) + cancel() + if err != nil { + log.TrErrorf("error.cluster.enable.failed", err) + if cr.status.Load() == clusterEnabled { + ctx, cancel := context.WithCancel(context.Background()) + timer := time.AfterFunc(time.Minute, func() { + cancel() + if cr.status.CompareAndSwap(clusterEnabled, clusterEnabling) { + cr.reEnable(disableSignal) + } + }) + go func() { + select { + case <-ctx.Done(): + case <-disableSignal: + cancel() + } + }() + } + } +} + // Disable send disable packet to central server // The context passed in only affect the logical of Disable method // Disable method is thread-safe, and it will wait until the first invoke exited +// Connection will not be closed after disable func (cr *Cluster) Disable(ctx context.Context) error { if cr.Enabled() { cr.mux.Lock() @@ -241,8 +293,8 @@ func (cr *Cluster) disable(ctx context.Context) error { return nil } -// markDisconnected marked the cluster as error or kicked -func (cr *Cluster) markDisconnected(kicked bool) { +// markKicked marks the cluster as kicked +func (cr *Cluster) markKicked() { if !cr.Enabled() { return } @@ -252,12 +304,5 @@ func (cr *Cluster) markDisconnected(kicked bool) { return } defer close(cr.disableSignal) - - var nextStatus int32 - if kicked { - nextStatus = clusterKicked - } else { - nextStatus = clusterError - } - cr.status.Store(nextStatus) + cr.status.Store(clusterKicked) } diff --git a/cluster/keepalive.go b/cluster/keepalive.go index 8d6362e9..478c1763 100644 --- a/cluster/keepalive.go +++ b/cluster/keepalive.go @@ -39,6 +39,7 @@ type keepAliveReq struct { } // KeepAlive will send the keep-alive packet and fresh hits & hit bytes data +// If cluster is kicked by the central server, the cluster status will be mark as kicked func (cr *Cluster) KeepAlive(ctx context.Context) KeepAliveRes { hits, hbts := cr.hits.Load(), cr.hbts.Load() resCh, err := cr.socket.EmitWithAck("keep-alive", keepAliveReq{ @@ -82,6 +83,7 @@ func (cr *Cluster) KeepAlive(ctx context.Context) KeepAliveRes { cr.hits.Add(-hits2) cr.hbts.Add(-hbts2) if data[1] == false { + cr.markKicked() return KeepAliveKicked } return KeepAliveSucceed diff --git a/cluster/socket.go b/cluster/socket.go index fcb3e2b2..0eb1a8a7 100644 --- a/cluster/socket.go +++ b/cluster/socket.go @@ -21,6 +21,7 @@ package cluster import ( "context" + "fmt" "github.com/LiterMC/socket.io" "github.com/LiterMC/socket.io/engine.io" @@ -28,6 +29,102 @@ import ( // Connect connects to the central server // The context passed in only affect the logical of Connect method +// Connection will not be closed after disable +// +// See Disconnect func (cr *Cluster) Connect(ctx context.Context) error { - return + if !cr.Disconnected() { + return errors.New("Attempt to connect while connecting") + } + _, err := cr.GetAuthToken(ctx) + if err != nil { + return fmt.Errorf("Auth failed %w", err) + } + + engio, err := engine.NewSocket(engine.Options{ + Host: cr.prefix, + Path: "/socket.io/", + ExtraHeaders: http.Header{ + "Origin": {cr.prefix}, + "User-Agent": {build.ClusterUserAgent}, + }, + DialTimeout: time.Minute * 6, + }) + if err != nil { + return fmt.Errorf("Could not parse Engine.IO options: %w", err) + } + if ctx.Value("cluster.options.engine-io.debug") == true { + engio.OnRecv(func(s *engine.Socket, data []byte) { + log.Debugf("Engine.IO %s recv: %q", s.ID(), (string)(data)) + }) + engio.OnSend(func(s *engine.Socket, data []byte) { + log.Debugf("Engine.IO %s send: %q", s.ID(), (string)(data)) + }) + } + engio.OnConnect(func(s *engine.Socket) { + log.Info("Engine.IO %s connected for cluster %s", s.ID(), cr.Id()) + }) + engio.OnDisconnect(cr.onDisconnected) + engio.OnDialError(func(s *engine.Socket, err *DialErrorContext) { + if err.Count() < 0 { + return + } + log.TrErrorf("error.cluster.connect.failed", cr.Id(), err.Count(), config.MaxReconnectCount, err.Err()) + if config.MaxReconnectCount >= 0 && err.Count() >= config.MaxReconnectCount { + log.TrErrorf("error.cluster.connect.failed.toomuch", cr.Id()) + s.Close() + } + }) + log.Infof("Dialing %s for cluster %s", engio.URL().String(), cr.Id()) + if err := engio.Dial(ctx); err != nil { + log.Errorf("Dial error: %v", err) + return false + } + + cr.socket = socket.NewSocket(engio, socket.WithAuthTokenFn(func() (string, error) { + token, err := cr.GetAuthToken(ctx) + if err != nil { + log.TrErrorf("error.cluster.auth.failed", err) + return "", err + } + return token, nil + })) + cr.socket.OnError(func(_ *socket.Socket, err error) { + log.Errorf("Socket.IO error: %v", err) + }) + cr.socket.OnMessage(func(event string, data []any) { + if event == "message" { + log.Infof("[remote]: %v", data[0]) + } + }) + log.Info("Connecting to socket.io namespace") + if err := cr.socket.Connect(""); err != nil { + log.Errorf("Namespace connect error: %v", err) + return false + } + return true +} + +// Disconnect close the connection which connected to the central server +// Disconnect will not disable the cluster +// +// See Connect +func (cr *Cluster) Disconnect() error { + if cr.Disconnected() { + return + } + cr.mux.Lock() + defer cr.mux.Unlock() + err := cr.socket.Close() + cr.socketStatus.Store(socketDisconnected) + cr.socket = nil + return err +} + +func (cr *Cluster) onDisconnected(s *engine.Socket, err error) { + if err != nil { + log.Warnf("Engine.IO %s disconnected: %v", s.ID(), err) + } + cr.socketStatus.Store(socketDisconnected) + cr.socket = nil } diff --git a/cluster/status.go b/cluster/status.go index f08b9586..96e94d35 100644 --- a/cluster/status.go +++ b/cluster/status.go @@ -19,14 +19,29 @@ package cluster +const ( + socketDisconnected = 0 + socketConnected = 1 + socketConnecting = 2 +) + const ( clusterDisabled = 0 clusterEnabled = 1 clusterEnabling = 2 clusterKicked = 4 - clusterError = 5 ) +// Disconnected returns true if the cluster is disconnected from the central server +func (cr *Cluster) Disconnected() bool { + return cr.socketStatus.Load() == socketDisconnected +} + +// Connected returns true if the cluster is connected to the central server +func (cr *Cluster) Connected() bool { + return cr.socketStatus.Load() == socketConnected +} + // Enabled returns true if the cluster is enabled or enabling func (cr *Cluster) Enabled() bool { s := cr.status.Load() @@ -48,11 +63,6 @@ func (cr *Cluster) IsKicked() bool { return cr.status.Load() == clusterKicked } -// IsError returns true if the cluster is disabled since connection error -func (cr *Cluster) IsError() bool { - return cr.status.Load() == clusterError -} - // WaitForEnable returns a channel which receives true when cluster enabled succeed, or receives false when it failed to enable // If the cluster is already enable, the channel always returns true // The channel should not be used multiple times From 103c480e8bf9ad62507a0e93cb970f3728e906e5 Mon Sep 17 00:00:00 2001 From: zyxkad Date: Tue, 25 Jun 2024 19:45:32 -0600 Subject: [PATCH 04/36] migrated more cluster --- cluster.go | 467 --------------------------------------------- cluster/cluster.go | 16 +- cluster/config.go | 8 +- cluster/handler.go | 23 +-- cluster/http.go | 25 +++ go.mod | 2 +- log/tr.go | 4 - storage/manager.go | 16 ++ 8 files changed, 73 insertions(+), 488 deletions(-) delete mode 100644 cluster.go create mode 100644 cluster/http.go diff --git a/cluster.go b/cluster.go deleted file mode 100644 index 592b7ed4..00000000 --- a/cluster.go +++ /dev/null @@ -1,467 +0,0 @@ -/** - * OpenBmclAPI (Golang Edition) - * Copyright (C) 2023 Kevin Z - * All rights reserved - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published - * by the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -package main - -import ( - "context" - "encoding/base64" - "encoding/json" - "errors" - "fmt" - "net" - "net/http" - "os" - "path/filepath" - "runtime" - "sync" - "sync/atomic" - "time" - - "github.com/LiterMC/socket.io" - "github.com/LiterMC/socket.io/engine.io" - "github.com/gorilla/websocket" - "github.com/gregjones/httpcache" - - gocache "github.com/LiterMC/go-openbmclapi/cache" - "github.com/LiterMC/go-openbmclapi/database" - "github.com/LiterMC/go-openbmclapi/internal/build" - "github.com/LiterMC/go-openbmclapi/limited" - "github.com/LiterMC/go-openbmclapi/log" - "github.com/LiterMC/go-openbmclapi/notify" - "github.com/LiterMC/go-openbmclapi/notify/email" - "github.com/LiterMC/go-openbmclapi/notify/webpush" - "github.com/LiterMC/go-openbmclapi/storage" - "github.com/LiterMC/go-openbmclapi/utils" -) - -type Cluster struct { - host string // not the public access host, but maybe a public IP, or a host that will be resolved to the IP - publicHosts []string // should not contains port, can be nil - publicPort uint16 - clusterId string - clusterSecret string - prefix string - byoc bool - jwtIssuer string - - dataDir string - maxConn int - storageOpts []storage.StorageOption - storages []storage.Storage - storageWeights []uint - storageTotalWeight uint - cache gocache.Cache - apiHmacKey []byte - hijackProxy *HjProxy - - stats notify.Stats - lastHits, statOnlyHits atomic.Int32 - lastHbts, statOnlyHbts atomic.Int64 - issync atomic.Bool - syncProg atomic.Int64 - syncTotal atomic.Int64 - - mux sync.RWMutex - enabled atomic.Bool - disabled chan struct{} - waitEnable []chan struct{} - shouldEnable atomic.Bool - reconnectCount int - socket *socket.Socket - cancelKeepalive context.CancelFunc - downloadMux sync.RWMutex - downloading map[string]*downloadingItem - filesetMux sync.RWMutex - fileset map[string]int64 - authTokenMux sync.RWMutex - authToken *ClusterToken - - client *http.Client - cachedCli *http.Client - bufSlots *limited.BufSlots - database database.DB - notifyManager *notify.Manager - webpushKeyB64 string - updateChecker *time.Ticker - apiRateLimiter *limited.APIRateMiddleWare - - wsUpgrader *websocket.Upgrader - handlerAPIv0 http.Handler - handlerAPIv1 http.Handler - hijackHandler http.Handler -} - -func NewCluster( - ctx context.Context, - prefix string, - baseDir string, - host string, publicPort uint16, - clusterId string, clusterSecret string, - byoc bool, dialer *net.Dialer, - storageOpts []storage.StorageOption, - cache gocache.Cache, -) (cr *Cluster) { - transport := http.DefaultTransport - if dialer != nil { - transport = &http.Transport{ - DialContext: dialer.DialContext, - } - } - - cachedTransport := transport - if cache != gocache.NoCache { - cachedTransport = &httpcache.Transport{ - Transport: transport, - Cache: gocache.WrapToHTTPCache(gocache.NewCacheWithNamespace(cache, "http@")), - } - } - - cr = &Cluster{ - host: host, - publicPort: publicPort, - clusterId: clusterId, - clusterSecret: clusterSecret, - prefix: prefix, - byoc: byoc, - jwtIssuer: jwtIssuerPrefix + "#" + clusterId, - - dataDir: filepath.Join(baseDir, "data"), - maxConn: config.DownloadMaxConn, - storageOpts: storageOpts, - cache: cache, - - disabled: make(chan struct{}, 0), - fileset: make(map[string]int64, 0), - - downloading: make(map[string]*downloadingItem), - - client: &http.Client{ - Transport: transport, - }, - cachedCli: &http.Client{ - Transport: cachedTransport, - }, - - wsUpgrader: &websocket.Upgrader{ - HandshakeTimeout: time.Minute, - }, - } - close(cr.disabled) - - if cr.maxConn <= 0 { - panic("download-max-conn must be a positive integer") - } - cr.bufSlots = limited.NewBufSlots(cr.maxConn) - - { - var ( - n uint = 0 - wgs = make([]uint, len(storageOpts)) - sts = make([]storage.Storage, len(storageOpts)) - ) - for i, s := range storageOpts { - sts[i] = storage.NewStorage(s) - wgs[i] = s.Weight - n += s.Weight - } - cr.storages = sts - cr.storageWeights = wgs - cr.storageTotalWeight = n - } - return -} - -func (cr *Cluster) Init(ctx context.Context) (err error) { - // create data folder - os.MkdirAll(cr.dataDir, 0755) - - if config.Database.Driver == "memory" { - cr.database = database.NewMemoryDB() - } else if cr.database, err = database.NewSqlDB(config.Database.Driver, config.Database.DSN); err != nil { - return - } - - if config.Hijack.Enable { - cr.hijackProxy = NewHjProxy(cr.client, cr.database, cr.handleDownload) - if config.Hijack.EnableLocalCache { - os.MkdirAll(config.Hijack.LocalCachePath, 0755) - } - } - - // Init notification manager - cr.notifyManager = notify.NewManager(cr.dataDir, cr.database, cr.client, config.Dashboard.NotifySubject) - // Add notification plugins - webpushPlg := new(webpush.Plugin) - cr.notifyManager.AddPlugin(webpushPlg) - if config.Notification.EnableEmail { - emailPlg, err := email.NewSMTP( - config.Notification.EmailSMTP, config.Notification.EmailSMTPEncryption, - config.Notification.EmailSender, config.Notification.EmailSenderPassword, - ) - if err != nil { - return err - } - cr.notifyManager.AddPlugin(emailPlg) - } - - if err = cr.notifyManager.Init(ctx); err != nil { - return - } - cr.webpushKeyB64 = base64.RawURLEncoding.EncodeToString(webpushPlg.GetPublicKey()) - - // Init storages - vctx := context.WithValue(ctx, storage.ClusterCacheCtxKey, cr.cache) - for _, s := range cr.storages { - s.Init(vctx) - } - - // read old stats - if err := cr.stats.Load(cr.dataDir); err != nil { - log.Errorf("Could not load stats: %v", err) - } - if cr.apiHmacKey, err = utils.LoadOrCreateHmacKey(cr.dataDir); err != nil { - return fmt.Errorf("Cannot load hmac key: %w", err) - } - - cr.updateChecker = time.NewTicker(time.Hour) - - go func(ticker *time.Ticker) { - defer log.RecoverPanic(nil) - defer ticker.Stop() - - if err := cr.checkUpdate(); err != nil { - log.TrErrorf("error.update.check.failed", err) - } - for range ticker.C { - if err := cr.checkUpdate(); err != nil { - log.TrErrorf("error.update.check.failed", err) - } - } - }(cr.updateChecker) - return -} - -func (cr *Cluster) Destroy(ctx context.Context) { - if cr.database != nil { - cr.database.Cleanup() - } - cr.updateChecker.Stop() - if cr.apiRateLimiter != nil { - cr.apiRateLimiter.Destroy() - } -} - -func (cr *Cluster) allocBuf(ctx context.Context) (slotId int, buf []byte, free func()) { - return cr.bufSlots.Alloc(ctx) -} - -func (cr *Cluster) Connect(ctx context.Context) bool { - cr.mux.Lock() - defer cr.mux.Unlock() - - if cr.socket != nil { - log.Debug("Extra connect") - return true - } - - _, err := cr.GetAuthToken(ctx) - if err != nil { - log.TrErrorf("error.cluster.auth.failed", err) - osExit(CodeClientOrServerError) - } - - engio, err := engine.NewSocket(engine.Options{ - Host: cr.prefix, - Path: "/socket.io/", - ExtraHeaders: http.Header{ - "Origin": {cr.prefix}, - "User-Agent": {build.ClusterUserAgent}, - }, - DialTimeout: time.Minute * 6, - }) - if err != nil { - log.Errorf("Could not parse Engine.IO options: %v; exit.", err) - osExit(CodeClientUnexpectedError) - } - - cr.reconnectCount = 0 - connected := false - - if config.Advanced.SocketIOLog { - engio.OnRecv(func(_ *engine.Socket, data []byte) { - log.Debugf("Engine.IO recv: %q", (string)(data)) - }) - engio.OnSend(func(_ *engine.Socket, data []byte) { - log.Debugf("Engine.IO sending: %q", (string)(data)) - }) - } - engio.OnConnect(func(*engine.Socket) { - log.Info("Engine.IO connected") - }) - engio.OnDisconnect(func(_ *engine.Socket, err error) { - if ctx.Err() != nil { - // Ignore if the error is because context cancelled - return - } - if err != nil { - log.Warnf("Engine.IO disconnected: %v", err) - } - if config.MaxReconnectCount == 0 { - if cr.shouldEnable.Load() { - log.Errorf("Cluster disconnected from remote; exit.") - osExit(CodeServerOrEnvionmentError) - } - } - if !connected { - cr.reconnectCount++ - if config.MaxReconnectCount > 0 && cr.reconnectCount >= config.MaxReconnectCount { - if cr.shouldEnable.Load() { - log.TrErrorf("error.cluster.connect.failed.toomuch") - osExit(CodeServerOrEnvionmentError) - } - } - } - connected = false - go cr.disconnected() - }) - engio.OnDialError(func(_ *engine.Socket, err error) { - cr.reconnectCount++ - log.TrErrorf("error.cluster.connect.failed", cr.reconnectCount, config.MaxReconnectCount, err) - if config.MaxReconnectCount >= 0 && cr.reconnectCount >= config.MaxReconnectCount { - if cr.shouldEnable.Load() { - log.TrErrorf("error.cluster.connect.failed.toomuch") - osExit(CodeServerOrEnvionmentError) - } - } - }) - - cr.socket = socket.NewSocket(engio, socket.WithAuthTokenFn(func() string { - token, err := cr.GetAuthToken(ctx) - if err != nil { - log.TrErrorf("error.cluster.auth.failed", err) - osExit(CodeServerOrEnvionmentError) - } - return token - })) - cr.socket.OnBeforeConnect(func(*socket.Socket) { - log.Infof(Tr("info.cluster.connect.prepare"), cr.reconnectCount, config.MaxReconnectCount) - }) - cr.socket.OnConnect(func(*socket.Socket, string) { - connected = true - log.Debugf("shouldEnable is %v", cr.shouldEnable.Load()) - if cr.shouldEnable.Load() { - if err := cr.Enable(ctx); err != nil { - log.TrErrorf("error.cluster.enable.failed", err) - osExit(CodeClientOrEnvionmentError) - } - } - }) - cr.socket.OnDisconnect(func(*socket.Socket, string) { - go cr.disconnected() - }) - cr.socket.OnError(func(_ *socket.Socket, err error) { - if ctx.Err() != nil { - // Ignore if the error is because context cancelled - return - } - log.Errorf("Socket.IO error: %v", err) - }) - cr.socket.OnMessage(func(event string, data []any) { - if event == "message" { - log.Infof("[remote]: %v", data[0]) - } - }) - log.Infof("Dialing %s", engio.URL().String()) - if err := engio.Dial(ctx); err != nil { - log.Errorf("Dial error: %v", err) - return false - } - log.Info("Connecting to socket.io namespace") - if err := cr.socket.Connect(""); err != nil { - log.Errorf("Open namespace error: %v", err) - return false - } - return true -} - -func (cr *Cluster) disconnected() bool { - cr.mux.Lock() - defer cr.mux.Unlock() - - if cr.enabled.CompareAndSwap(true, false) { - return false - } - if cr.cancelKeepalive != nil { - cr.cancelKeepalive() - cr.cancelKeepalive = nil - } - cr.notifyManager.OnDisabled() - return true -} - -func (cr *Cluster) disable(ctx context.Context) (ok bool) { - cr.mux.Lock() - defer cr.mux.Unlock() - - if !cr.enabled.Load() { - log.Debug("Extra disable") - return false - } - - defer cr.notifyManager.OnDisabled() - - if cr.cancelKeepalive != nil { - cr.cancelKeepalive() - cr.cancelKeepalive = nil - } - if cr.socket == nil { - return false - } - log.Info(Tr("info.cluster.disabling")) - resCh, err := cr.socket.EmitWithAck("disable", nil) - if err == nil { - tctx, cancel := context.WithTimeout(ctx, time.Second*(time.Duration)(config.Advanced.KeepaliveTimeout)) - select { - case <-tctx.Done(): - cancel() - err = tctx.Err() - case data := <-resCh: - cancel() - log.Debug("disable ack:", data) - if ero := data[0]; ero != nil { - log.Errorf("Disable failed: %v", ero) - } else if !data[1].(bool) { - log.Error("Disable failed: acked non true value") - } else { - ok = true - } - } - } - if err != nil { - log.Errorf(Tr("error.cluster.disable.failed"), err) - } - - cr.enabled.Store(false) - go cr.socket.Close() - cr.socket = nil - close(cr.disabled) - log.Warn(Tr("warn.cluster.disabled")) - return -} diff --git a/cluster/cluster.go b/cluster/cluster.go index 67f93bc9..835b5ff4 100644 --- a/cluster/cluster.go +++ b/cluster/cluster.go @@ -21,7 +21,9 @@ package cluster import ( "context" + "errors" "fmt" + "net/http" "regexp" "runtime" "sync" @@ -66,6 +68,10 @@ type Cluster struct { mux sync.RWMutex status atomic.Int32 socket *socket.Socket + client *http.Client + + authTokenMux sync.RWMutex + authToken *ClusterToken } func NewCluster( @@ -78,6 +84,8 @@ func NewCluster( storageManager: storageManager, storages: storages, + + client: &http.Client{}, } return } @@ -87,6 +95,11 @@ func (cr *Cluster) ID() string { return cr.opts.Id } +// Secret returns the cluster secret +func (cr *Cluster) Secret() string { + return cr.opts.Secret +} + // Host returns the cluster public host func (cr *Cluster) Host() string { return cr.gcfg.Host @@ -181,7 +194,7 @@ func (cr *Cluster) enable(ctx context.Context) error { if msg, ok := ero["message"].(string); ok { if hashMismatch := reFileHashMismatchError.FindStringSubmatch(msg); hashMismatch != nil { hash := hashMismatch[1] - log.Warnf(Tr("warn.cluster.detected.hash.mismatch"), hash) + log.TrWarnf("warn.cluster.detected.hash.mismatch", hash) cr.storageManager.RemoveForAll(hash) } return fmt.Errorf("Enable failed: %v", msg) @@ -239,6 +252,7 @@ func (cr *Cluster) reEnable(disableSignal <-chan struct{}) { select { case <-ctx.Done(): case <-disableSignal: + timer.Stop() cancel() } }() diff --git a/cluster/config.go b/cluster/config.go index 7613cb1e..2eebc617 100644 --- a/cluster/config.go +++ b/cluster/config.go @@ -86,7 +86,7 @@ func (cr *Cluster) fetchToken(ctx context.Context) (token *ClusterToken, err err } }() req, err := cr.makeReq(ctx, http.MethodGet, "/openbmclapi-agent/challenge", url.Values{ - "clusterId": {cr.clusterId}, + "clusterId": {cr.ID()}, }) if err != nil { return @@ -110,7 +110,7 @@ func (cr *Cluster) fetchToken(ctx context.Context) (token *ClusterToken, err err } var buf [32]byte - hs := hmac.New(crypto.SHA256.New, ([]byte)(cr.clusterSecret)) + hs := hmac.New(crypto.SHA256.New, ([]byte)(cr.Secret())) hs.Write(([]byte)(res1.Challenge)) signature := hex.EncodeToString(hs.Sum(buf[:0])) @@ -119,7 +119,7 @@ func (cr *Cluster) fetchToken(ctx context.Context) (token *ClusterToken, err err Challenge string `json:"challenge"` Signature string `json:"signature"` }{ - ClusterId: cr.clusterId, + ClusterId: cr.ID(), Challenge: res1.Challenge, Signature: signature, }) @@ -159,7 +159,7 @@ func (cr *Cluster) refreshToken(ctx context.Context, oldToken string) (token *Cl ClusterId string `json:"clusterId"` Token string `json:"token"` }{ - ClusterId: cr.clusterId, + ClusterId: cr.ID(), Token: oldToken, }) if err != nil { diff --git a/cluster/handler.go b/cluster/handler.go index 9f47f319..52c1ba16 100644 --- a/cluster/handler.go +++ b/cluster/handler.go @@ -21,29 +21,30 @@ package cluster import ( "net/http" + + "github.com/LiterMC/go-openbmclapi/log" + "github.com/LiterMC/go-openbmclapi/storage" ) -func (cr *Cluster) HandleFile(req *http.Request, rw http.ResponseWriter, hash string) { +func (cr *Cluster) HandleFile(req *http.Request, rw http.ResponseWriter, hash string, size int64) { + defer log.RecoverPanic(nil) + var err error if cr.storageManager.ForEachFromRandom(cr.storages, func(s storage.Storage) bool { - log.Debugf("[handler]: Checking %s on storage [%d] %s ...", hash, i, sto.String()) + opts := s.Options() + log.Debugf("[handler]: Checking %s on storage %s ...", hash, opts.Id) - sz, er := sto.ServeDownload(rw, req, hash, size) + sz, er := s.ServeDownload(rw, req, hash, size) if er != nil { - log.Debugf("[handler]: File %s failed on storage [%d] %s: %v", hash, i, sto.String(), er) + log.Debugf("[handler]: File %s failed on storage %s: %v", hash, opts.Id, er) err = er return false } if sz >= 0 { - opts := cr.storageOpts[i] - cr.AddHits(1, sz, s.Options().Id) - if !keepaliveRec { - cr.statOnlyHits.Add(1) - cr.statOnlyHbts.Add(sz) - } + cr.AddHits(1, sz, opts.Id) } return true }) { return } - http.Error(http.StatusInternation) + http.Error(rw, err.Error(), http.StatusInternalServerError) } diff --git a/cluster/http.go b/cluster/http.go new file mode 100644 index 00000000..d83245fa --- /dev/null +++ b/cluster/http.go @@ -0,0 +1,25 @@ +/** + * OpenBmclAPI (Golang Edition) + * Copyright (C) 2024 Kevin Z + * All rights reserved + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published + * by the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package cluster + +import ( + "net/http" + "net/url" +) diff --git a/go.mod b/go.mod index 18d282b5..9d665f95 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/LiterMC/go-openbmclapi -go 1.21.6 +go 1.22.0 require ( github.com/LiterMC/socket.io v0.2.4 diff --git a/log/tr.go b/log/tr.go index 0470801c..13e15006 100644 --- a/log/tr.go +++ b/log/tr.go @@ -38,7 +38,3 @@ func TrWarnf(key string, vals ...any) { func TrErrorf(key string, vals ...any) { Errorf(lang.Tr(key), vals...) } - -func TrPanicf(key string, vals ...any) { - Panicf(lang.Tr(key), vals...) -} diff --git a/storage/manager.go b/storage/manager.go index 8747fef7..bc348d2a 100644 --- a/storage/manager.go +++ b/storage/manager.go @@ -20,6 +20,8 @@ package storage import ( + "errors" + "github.com/LiterMC/go-openbmclapi/log" "github.com/LiterMC/go-openbmclapi/utils" ) @@ -146,3 +148,17 @@ func forEachFromRandomIndexWithPossibility(poss []uint, total uint, cb func(i in } return false } + +func (m *Manager) RemoveForAll(hash string) error { + errCh := make(chan error, 0) + for _, s := range m.Storages { + go func(s Storage) { + errCh <- s.Remove(hash) + }(s) + } + errs := make([]error, len(m.Storages)) + for i := range len(m.Storages) { + errs[i] = <-errCh + } + return errors.Join(errs...) +} From 451a665b249aec3066e7a11722a465c9cc27b376 Mon Sep 17 00:00:00 2001 From: zyxkad Date: Thu, 27 Jun 2024 16:04:53 -0600 Subject: [PATCH 05/36] abstract subscription, token, and user api --- api/subscription.go | 298 ++++++++++++++++++ api/token.go | 35 +++ api/user.go | 62 ++++ api/v0/api.go | 607 +++++++++++------------------------- api/v0/api_token.go | 4 +- api/v0/configure_cluster.go | 24 ++ api/v0/subscription.go | 198 ++++++++++++ cluster/config.go | 64 ++++ cluster/http.go | 45 +++ cluster/keepalive.go | 17 +- database/db.go | 241 -------------- limited/api_rate.go | 10 +- sync.go | 100 ------ utils/crypto.go | 8 +- utils/http.go | 122 +++++++- 15 files changed, 1056 insertions(+), 779 deletions(-) create mode 100644 api/subscription.go create mode 100644 api/token.go create mode 100644 api/user.go create mode 100644 api/v0/configure_cluster.go create mode 100644 api/v0/subscription.go diff --git a/api/subscription.go b/api/subscription.go new file mode 100644 index 00000000..da5034c4 --- /dev/null +++ b/api/subscription.go @@ -0,0 +1,298 @@ +/** + * OpenBmclAPI (Golang Edition) + * Copyright (C) 2024 Kevin Z + * All rights reserved + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published + * by the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package api + +import ( + "database/sql" + "database/sql/driver" + "fmt" + "time" + + "github.com/google/uuid" + + "github.com/LiterMC/go-openbmclapi/utils" +) + +type SubscriptionManager interface { + GetWebPushKey() string + + GetSubscribe(user string, client string) (*SubscribeRecord, error) + SetSubscribe(SubscribeRecord) error + RemoveSubscribe(user string, client string) error + ForEachSubscribe(cb func(*SubscribeRecord) error) error + + GetEmailSubscription(user string, addr string) (*EmailSubscriptionRecord, error) + AddEmailSubscription(EmailSubscriptionRecord) error + UpdateEmailSubscription(EmailSubscriptionRecord) error + RemoveEmailSubscription(user string, addr string) error + ForEachEmailSubscription(cb func(*EmailSubscriptionRecord) error) error + ForEachUsersEmailSubscription(user string, cb func(*EmailSubscriptionRecord) error) error + ForEachEnabledEmailSubscription(cb func(*EmailSubscriptionRecord) error) error + + GetWebhook(user string, id uuid.UUID) (*WebhookRecord, error) + AddWebhook(WebhookRecord) error + UpdateWebhook(WebhookRecord) error + UpdateEnableWebhook(user string, id uuid.UUID, enabled bool) error + RemoveWebhook(user string, id uuid.UUID) error + ForEachWebhook(cb func(*WebhookRecord) error) error + ForEachUsersWebhook(user string, cb func(*WebhookRecord) error) error + ForEachEnabledWebhook(cb func(*WebhookRecord) error) error +} + +type SubscribeRecord struct { + User string `json:"user"` + Client string `json:"client"` + EndPoint string `json:"endpoint"` + Keys SubscribeRecordKeys `json:"keys"` + Scopes NotificationScopes `json:"scopes"` + ReportAt Schedule `json:"report_at"` + LastReport sql.NullTime `json:"-"` +} + +type SubscribeRecordKeys struct { + Auth string `json:"auth"` + P256dh string `json:"p256dh"` +} + +var ( + _ sql.Scanner = (*SubscribeRecordKeys)(nil) + _ driver.Valuer = (*SubscribeRecordKeys)(nil) +) + +func (sk *SubscribeRecordKeys) Scan(src any) error { + var data []byte + switch v := src.(type) { + case []byte: + data = v + case string: + data = ([]byte)(v) + default: + return errors.New("Source is not a string") + } + return json.Unmarshal(data, sk) +} + +func (sk SubscribeRecordKeys) Value() (driver.Value, error) { + return json.Marshal(sk) +} + +type NotificationScopes struct { + Disabled bool `json:"disabled"` + Enabled bool `json:"enabled"` + SyncBegin bool `json:"syncbegin"` + SyncDone bool `json:"syncdone"` + Updates bool `json:"updates"` + DailyReport bool `json:"dailyreport"` +} + +var ( + _ sql.Scanner = (*NotificationScopes)(nil) + _ driver.Valuer = (*NotificationScopes)(nil) +) + +//// !!WARN: Do not edit nsFlag's order //// + +const ( + nsFlagDisabled = 1 << iota + nsFlagEnabled + nsFlagSyncDone + nsFlagUpdates + nsFlagDailyReport + nsFlagSyncBegin +) + +func (ns NotificationScopes) ToInt64() (v int64) { + if ns.Disabled { + v |= nsFlagDisabled + } + if ns.Enabled { + v |= nsFlagEnabled + } + if ns.SyncBegin { + v |= nsFlagSyncBegin + } + if ns.SyncDone { + v |= nsFlagSyncDone + } + if ns.Updates { + v |= nsFlagUpdates + } + if ns.DailyReport { + v |= nsFlagDailyReport + } + return +} + +func (ns *NotificationScopes) FromInt64(v int64) { + ns.Disabled = v&nsFlagDisabled != 0 + ns.Enabled = v&nsFlagEnabled != 0 + ns.SyncBegin = v&nsFlagSyncBegin != 0 + ns.SyncDone = v&nsFlagSyncDone != 0 + ns.Updates = v&nsFlagUpdates != 0 + ns.DailyReport = v&nsFlagDailyReport != 0 +} + +func (ns *NotificationScopes) Scan(src any) error { + v, ok := src.(int64) + if !ok { + return errors.New("Source is not a integer") + } + ns.FromInt64(v) + return nil +} + +func (ns NotificationScopes) Value() (driver.Value, error) { + return ns.ToInt64(), nil +} + +func (ns *NotificationScopes) FromStrings(scopes []string) { + for _, s := range scopes { + switch s { + case "disabled": + ns.Disabled = true + case "enabled": + ns.Enabled = true + case "syncbegin": + ns.SyncBegin = true + case "syncdone": + ns.SyncDone = true + case "updates": + ns.Updates = true + case "dailyreport": + ns.DailyReport = true + } + } +} + +func (ns *NotificationScopes) UnmarshalJSON(data []byte) (err error) { + { + type T NotificationScopes + if err = json.Unmarshal(data, (*T)(ns)); err == nil { + return + } + } + var v []string + if err = json.Unmarshal(data, &v); err != nil { + return + } + ns.FromStrings(v) + return +} + +type Schedule struct { + Hour int + Minute int +} + +var ( + _ sql.Scanner = (*Schedule)(nil) + _ driver.Valuer = (*Schedule)(nil) +) + +func (s Schedule) String() string { + return fmt.Sprintf("%02d:%02d", s.Hour, s.Minute) +} + +func (s *Schedule) UnmarshalText(buf []byte) (err error) { + if _, err = fmt.Sscanf((string)(buf), "%02d:%02d", &s.Hour, &s.Minute); err != nil { + return + } + if s.Hour < 0 || s.Hour >= 24 { + return fmt.Errorf("Hour %d out of range [0, 24)", s.Hour) + } + if s.Minute < 0 || s.Minute >= 60 { + return fmt.Errorf("Minute %d out of range [0, 60)", s.Minute) + } + return +} + +func (s *Schedule) UnmarshalJSON(buf []byte) (err error) { + var v string + if err = json.Unmarshal(buf, &v); err != nil { + return + } + return s.UnmarshalText(([]byte)(v)) +} + +func (s *Schedule) MarshalJSON() (buf []byte, err error) { + return json.Marshal(s.String()) +} + +func (s *Schedule) Scan(src any) error { + var v []byte + switch w := src.(type) { + case []byte: + v = w + case string: + v = ([]byte)(w) + default: + return fmt.Errorf("Unexpected type %T", src) + } + return s.UnmarshalText(v) +} + +func (s Schedule) Value() (driver.Value, error) { + return s.String(), nil +} + +func (s Schedule) ReadySince(last, now time.Time) bool { + if last.IsZero() { + last = now.Add(-time.Hour*24 + 1) + } + mustAfter := last.Add(time.Hour * 12) + if now.Before(mustAfter) { + return false + } + if !now.Before(last.Add(time.Hour * 24)) { + return true + } + hour, min := now.Hour(), now.Minute() + if s.Hour < hour && s.Hour+3 > hour || s.Hour == hour && s.Minute <= min { + return true + } + return false +} + +type EmailSubscriptionRecord struct { + User string `json:"user"` + Addr string `json:"addr"` + Scopes NotificationScopes `json:"scopes"` + Enabled bool `json:"enabled"` +} + +type WebhookRecord struct { + User string `json:"user"` + Id uuid.UUID `json:"id"` + Name string `json:"name"` + EndPoint string `json:"endpoint"` + Auth *string `json:"auth,omitempty"` + AuthHash string `json:"authHash,omitempty"` + Scopes NotificationScopes `json:"scopes"` + Enabled bool `json:"enabled"` +} + +func (rec *WebhookRecord) CovertAuthHash() { + if rec.Auth == nil || *rec.Auth == "" { + rec.AuthHash = "" + } else { + rec.AuthHash = "sha256:" + utils.AsSha256Hex(*rec.Auth) + } + rec.Auth = nil +} diff --git a/api/token.go b/api/token.go new file mode 100644 index 00000000..3f2081f3 --- /dev/null +++ b/api/token.go @@ -0,0 +1,35 @@ +/** + * OpenBmclAPI (Golang Edition) + * Copyright (C) 2024 Kevin Z + * All rights reserved + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published + * by the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package api + +import ( + "net/url" +) + +type TokenVerifier interface { + VerifyAuthToken(clientId string, token string) (tokenId string, userId string, err error) + VerifyAPIToken(clientId string, token string, path string, query url.Values) (userId string, err error) +} + +type TokenManager interface { + TokenVerifier + GenerateAuthToken(clientId string, userId string) (token string, err error) + GenerateAPIToken(clientId string, userId string, path string, query map[string]string) (token string, err error) +} diff --git a/api/user.go b/api/user.go new file mode 100644 index 00000000..85ec6fd6 --- /dev/null +++ b/api/user.go @@ -0,0 +1,62 @@ +/** + * OpenBmclAPI (Golang Edition) + * Copyright (C) 2024 Kevin Z + * All rights reserved + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published + * by the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package api + +type UserManager interface { + GetUsers() []*User + GetUser(id string) *User + AddUser(*User) error + RemoveUser(id string) error + UpdateUserPassword(username string, password string) error + UpdateUserPermissions(username string, permissions PermissionFlag) error + + VerifyUserPassword(userId string, comparator func(password string) bool) error +} + +type PermissionFlag uint32 + +const ( + // BasicPerm includes majority client side actions, such as login, which do not have a significant impact on the server + BasicPerm PermissionFlag = 1 << iota + // SubscribePerm allows the user to subscribe server status & other posts + SubscribePerm + // LogPerm allows the user to view non-debug logs & download access logs + LogPerm + // DebugPerm allows the user to access debug settings and download debug logs + DebugPerm + // FullConfigPerm allows the user to access all config values + FullConfigPerm + // ClusterPerm allows the user to configure clusters' settings & stop/start clusters + ClusterPerm + // StoragePerm allows the user to configure storages' settings & decides to manually start storages' sync process + StoragePerm + // BypassLimitPerm allows the user to ignore API access limit + BypassLimitPerm + // RootPerm user can add/remove users, reset their password, and change their permission flags + RootPerm PermissionFlag = 1 << 31 + + AllPerm = ^(PermissionFlag)(0) +) + +type User struct { + Username string + Password string // as sha256 + Permissions PermissionFlag +} diff --git a/api/v0/api.go b/api/v0/api.go index 7d52ddf9..1cc4ee42 100644 --- a/api/v0/api.go +++ b/api/v0/api.go @@ -17,7 +17,7 @@ * along with this program. If not, see . */ -package main +package v0 import ( "compress/gzip" @@ -36,10 +36,11 @@ import ( "sync/atomic" "time" - "runtime/pprof" - // "github.com/gorilla/websocket" "github.com/google/uuid" + "github.com/gorilla/schema" + "runtime/pprof" + "github.com/LiterMC/go-openbmclapi/api" "github.com/LiterMC/go-openbmclapi/database" "github.com/LiterMC/go-openbmclapi/internal/build" "github.com/LiterMC/go-openbmclapi/limited" @@ -58,32 +59,108 @@ func apiGetClientId(req *http.Request) (id string) { return req.Context().Value(clientIdKey).(string) } -func (cr *Cluster) cliIdHandle(next http.Handler) http.Handler { - return (http.HandlerFunc)(func(rw http.ResponseWriter, req *http.Request) { - var id string - if cid, _ := req.Cookie(clientIdCookieName); cid != nil { - id = cid.Value - } else { - var err error - id, err = utils.GenRandB64(16) - if err != nil { - http.Error(rw, "cannot generate random number", http.StatusInternalServerError) - return - } - http.SetCookie(rw, &http.Cookie{ - Name: clientIdCookieName, - Value: id, - Expires: time.Now().Add(time.Hour * 24 * 365 * 16), - Secure: true, - HttpOnly: true, - }) - } - req = req.WithContext(context.WithValue(req.Context(), clientIdKey, utils.AsSha256(id))) - next.ServeHTTP(rw, req) +type Handler struct { + handler *utils.HttpMiddleWareHandler + router *http.ServeMux + userManager api.UserManager + tokenManager api.TokenManager + subManager api.SubscriptionManager +} + +var _ http.Handler = (*Handler)(nil) + +func NewHandler(verifier TokenVerifier, subManager api.SubscriptionManager) *Handler { + mux := http.NewServeMux() + h := &Handler{ + router: mux, + handler: utils.NewHttpMiddleWareHandler(mux), + verifier: verifier, + subManager: subManager, + } + h.buildRoute() + h.handler.Use(cliIdMiddleWare) + h.handler.Use(h.authMiddleWare) + return h +} + +func (h *Handler) Handler() *utils.HttpMiddleWareHandler { + return h.handler +} + +func (h *Handler) buildRoute() { + mux := h.router + + mux.HandleFunc("/", func(rw http.ResponseWriter, req *http.Request) { + writeJson(rw, http.StatusNotFound, Map{ + "error": "404 not found", + "path": req.URL.Path, + }) }) + + mux.HandleFunc("/ping", h.routePing) + mux.HandleFunc("/status", h.routeStatus) + mux.Handle("/stat/", http.StripPrefix("/stat/", (http.HandlerFunc)(h.routeStat))) + + mux.HandleFunc("/challenge", h.routeChallenge) + mux.HandleFunc("/login", h.routeLogin) + mux.Handle("/requestToken", authHandleFunc(h.routeRequestToken)) + mux.Handle("/logout", authHandleFunc(h.routeLogout)) + + mux.HandleFunc("/log.io", h.routeLogIO) + mux.Handle("/pprof", authHandleFunc(h.routePprof)) + mux.HandleFunc("/subscribeKey", h.routeSubscribeKey) + mux.Handle("/subscribe", authHandle(&utils.HttpMethodHandler{ + Get: h.routeSubscribeGET, + Post: h.routeSubscribePOST, + Delete: h.routeSubscribeDELETE, + })) + mux.Handle("/subscribe_email", authHandle(&utils.HttpMethodHandler{ + Get: h.routeSubscribeEmailGET, + Post: h.routeSubscribeEmailPOST, + Patch: h.routeSubscribeEmailPATCH, + Delete: h.routeSubscribeEmailDELETE, + })) + mux.Handle("/webhook", authHandle(&utils.HttpMethodHandler{ + Get: h.routeWebhookGET, + Post: h.routeWebhookPOST, + Patch: h.routeWebhookPATCH, + Delete: h.routeWebhookDELETE, + })) + + mux.Handle("/log_files", authHandleFunc(h.routeLogFiles)) + mux.Handle("/log_file/", authHandle(http.StripPrefix("/log_file/", (http.HandlerFunc)(h.routeLogFile)))) + + mux.Handle("/configure/cluster", authHandleFunc(h.routeConfigureCluster)) +} + +func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + h.handler.ServeHTTP(rw, req) +} + +func cliIdMiddleWare(rw http.ResponseWriter, req *http.Request, next http.Handler) { + var id string + if cid, _ := req.Cookie(clientIdCookieName); cid != nil { + id = cid.Value + } else { + var err error + id, err = utils.GenRandB64(16) + if err != nil { + http.Error(rw, "cannot generate random number", http.StatusInternalServerError) + return + } + http.SetCookie(rw, &http.Cookie{ + Name: clientIdCookieName, + Value: id, + Expires: time.Now().Add(time.Hour * 24 * 365 * 16), + Secure: true, + HttpOnly: true, + }) + } + req = req.WithContext(context.WithValue(req.Context(), clientIdKey, utils.AsSha256(id))) + next.ServeHTTP(rw, req) } -func (cr *Cluster) authMiddleware(rw http.ResponseWriter, req *http.Request, next http.Handler) { +func (h *Handler) authMiddleWare(rw http.ResponseWriter, req *http.Request, next http.Handler) { cli := apiGetClientId(req) ctx := req.Context() @@ -96,7 +173,7 @@ func (cr *Cluster) authMiddleware(rw http.ResponseWriter, req *http.Request, nex if req.Method == http.MethodGet { if tk := req.URL.Query().Get("_t"); tk != "" { path := GetRequestRealPath(req) - if id, uid, err = cr.verifyAPIToken(cli, tk, path, req.URL.Query()); err == nil { + if id, uid, err = h.verifier.verifyAPIToken(cli, tk, path, req.URL.Query()); err == nil { ctx = context.WithValue(ctx, tokenTypeKey, tokenTypeAPI) } } @@ -108,7 +185,7 @@ func (cr *Cluster) authMiddleware(rw http.ResponseWriter, req *http.Request, nex if err == nil { err = ErrUnsupportAuthType } - } else if id, uid, err = cr.verifyAuthToken(cli, tk); err != nil { + } else if id, uid, err = h.verifier.VerifyAuthToken(cli, tk); err != nil { id = "" } else { ctx = context.WithValue(ctx, tokenTypeKey, tokenTypeAuth) @@ -122,7 +199,7 @@ func (cr *Cluster) authMiddleware(rw http.ResponseWriter, req *http.Request, nex next.ServeHTTP(rw, req) } -func (cr *Cluster) apiAuthHandle(next http.Handler) http.Handler { +func authHandle(next http.Handler) http.Handler { return (http.HandlerFunc)(func(rw http.ResponseWriter, req *http.Request) { if req.Context().Value(tokenTypeKey) == nil { writeJson(rw, http.StatusUnauthorized, Map{ @@ -134,69 +211,13 @@ func (cr *Cluster) apiAuthHandle(next http.Handler) http.Handler { }) } -func (cr *Cluster) apiAuthHandleFunc(next http.HandlerFunc) http.Handler { - return cr.apiAuthHandle(next) -} - -func (cr *Cluster) initAPIv0() http.Handler { - mux := http.NewServeMux() - mux.HandleFunc("/", func(rw http.ResponseWriter, req *http.Request) { - writeJson(rw, http.StatusNotFound, Map{ - "error": "404 not found", - "path": req.URL.Path, - }) - }) - - mux.HandleFunc("/ping", cr.apiV1Ping) - mux.HandleFunc("/status", cr.apiV0Status) - mux.Handle("/stat/", http.StripPrefix("/stat/", (http.HandlerFunc)(cr.apiV0Stat))) - - mux.HandleFunc("/challenge", cr.apiV1Challenge) - mux.HandleFunc("/login", cr.apiV0Login) - mux.Handle("/requestToken", cr.apiAuthHandleFunc(cr.apiV0RequestToken)) - mux.Handle("/logout", cr.apiAuthHandleFunc(cr.apiV1Logout)) - - mux.HandleFunc("/log.io", cr.apiV1LogIO) - mux.Handle("/pprof", cr.apiAuthHandleFunc(cr.apiV1Pprof)) - mux.HandleFunc("/subscribeKey", cr.apiV0SubscribeKey) - mux.Handle("/subscribe", cr.apiAuthHandleFunc(cr.apiV0Subscribe)) - mux.Handle("/subscribe_email", cr.apiAuthHandleFunc(cr.apiV0SubscribeEmail)) - mux.Handle("/webhook", cr.apiAuthHandleFunc(cr.apiV0Webhook)) - - mux.Handle("/log_files", cr.apiAuthHandleFunc(cr.apiV0LogFiles)) - mux.Handle("/log_file/", cr.apiAuthHandle(http.StripPrefix("/log_file/", (http.HandlerFunc)(cr.apiV0LogFile)))) - - next := cr.apiRateLimiter.WrapHandler(mux) - return (http.HandlerFunc)(func(rw http.ResponseWriter, req *http.Request) { - cr.authMiddleware(rw, req, next) - }) -} - -func (cr *Cluster) initAPIv1() http.Handler { - mux := http.NewServeMux() - mux.HandleFunc("/", func(rw http.ResponseWriter, req *http.Request) { - writeJson(rw, http.StatusNotFound, Map{ - "error": "404 not found", - "path": req.URL.Path, - }) - }) - - mux.HandleFunc("/ping", cr.apiV1Ping) - - mux.HandleFunc("/challenge", cr.apiV1Challenge) - mux.Handle("/logout", cr.apiAuthHandleFunc(cr.apiV1Logout)) - - mux.HandleFunc("/log.io", cr.apiV1LogIO) - mux.Handle("/pprof", cr.apiAuthHandleFunc(cr.apiV1Pprof)) - - next := cr.apiRateLimiter.WrapHandler(mux) - return (http.HandlerFunc)(func(rw http.ResponseWriter, req *http.Request) { - cr.authMiddleware(rw, req, next) - }) +func authHandleFunc(next http.HandlerFunc) http.Handler { + return authHandle(next) } -func (cr *Cluster) apiV1Ping(rw http.ResponseWriter, req *http.Request) { - if checkRequestMethodOrRejectWithJson(rw, req, http.MethodGet) { +func (cr *Cluster) routePing(rw http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodGet { + errorMethodNotAllowed(rw, req, http.MethodGet) return } limited.SetSkipRateLimit(req) @@ -208,8 +229,9 @@ func (cr *Cluster) apiV1Ping(rw http.ResponseWriter, req *http.Request) { }) } -func (cr *Cluster) apiV0Status(rw http.ResponseWriter, req *http.Request) { - if checkRequestMethodOrRejectWithJson(rw, req, http.MethodGet) { +func (cr *Cluster) routeStatus(rw http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodGet { + errorMethodNotAllowed(rw, req, http.MethodGet) return } limited.SetSkipRateLimit(req) @@ -245,8 +267,9 @@ func (cr *Cluster) apiV0Status(rw http.ResponseWriter, req *http.Request) { writeJson(rw, http.StatusOK, &status) } -func (cr *Cluster) apiV0Stat(rw http.ResponseWriter, req *http.Request) { - if checkRequestMethodOrRejectWithJson(rw, req, http.MethodGet) { +func (cr *Cluster) routeStat(rw http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodGet { + errorMethodNotAllowed(rw, req, http.MethodGet) return } limited.SetSkipRateLimit(req) @@ -265,14 +288,15 @@ func (cr *Cluster) apiV0Stat(rw http.ResponseWriter, req *http.Request) { writeJson(rw, http.StatusOK, (json.RawMessage)(data)) } -func (cr *Cluster) apiV1Challenge(rw http.ResponseWriter, req *http.Request) { - if checkRequestMethodOrRejectWithJson(rw, req, http.MethodGet) { +func (h *Handler) routeChallenge(rw http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodGet { + errorMethodNotAllowed(rw, req, http.MethodGet) return } cli := apiGetClientId(req) query := req.URL.Query() action := query.Get("action") - token, err := cr.generateChallengeToken(cli, action) + token, err := h.generateChallengeToken(cli, action) if err != nil { writeJson(rw, http.StatusInternalServerError, Map{ "error": "Cannot generate token", @@ -285,8 +309,9 @@ func (cr *Cluster) apiV1Challenge(rw http.ResponseWriter, req *http.Request) { }) } -func (cr *Cluster) apiV0Login(rw http.ResponseWriter, req *http.Request) { - if checkRequestMethodOrRejectWithJson(rw, req, http.MethodPost) { +func (h *Handler) routeLogin(rw http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodPost { + errorMethodNotAllowed(rw, req, http.MethodPost) return } if !config.Dashboard.Enable { @@ -297,44 +322,25 @@ func (cr *Cluster) apiV0Login(rw http.ResponseWriter, req *http.Request) { } cli := apiGetClientId(req) - type T = struct { - User string `json:"username"` - Challenge string `json:"challenge"` - Signature string `json:"signature"` - } - data, ok := parseRequestBody(rw, req, func(rw http.ResponseWriter, req *http.Request, ct string, data *T) error { - switch ct { - case "application/x-www-form-urlencoded": - data.User = req.PostFormValue("username") - data.Challenge = req.PostFormValue("challenge") - data.Signature = req.PostFormValue("signature") - return nil - default: - return errUnknownContent - } - }) - if !ok { - return + var data struct { + User string `json:"username" schema:"username"` + Challenge string `json:"challenge" schema:"challenge"` + Signature string `json:"signature" schema:"signature"` } - - expectUsername, expectPassword := config.Dashboard.Username, config.Dashboard.Password - if expectUsername == "" || expectPassword == "" { - writeJson(rw, http.StatusUnauthorized, Map{ - "error": "The username or password was not set on the server", - }) + if !parseRequestBody(rw, req, &data) { return } - if err := cr.verifyChallengeToken(cli, "login", data.Challenge); err != nil { + if err := h.verifier.VerifyChallengeToken(cli, "login", data.Challenge); err != nil { writeJson(rw, http.StatusUnauthorized, Map{ "error": "Invalid challenge", }) return } - expectPassword = utils.AsSha256Hex(expectPassword) - expectSignature := utils.HMACSha256Hex(expectPassword, data.Challenge) - if subtle.ConstantTimeCompare(([]byte)(expectUsername), ([]byte)(data.User)) == 0 || - subtle.ConstantTimeCompare(([]byte)(expectSignature), ([]byte)(data.Signature)) == 0 { + if err := h.verifier.VerifyUserPassword(data.User, func(password string) bool { + expectSignature := utils.HMACSha256HexBytes(password, data.Challenge) + return subtle.ConstantTimeCompare(expectSignature, ([]byte)(data.Signature)) == 0 + }); err != nil { writeJson(rw, http.StatusUnauthorized, Map{ "error": "The username or password is incorrect", }) @@ -353,8 +359,9 @@ func (cr *Cluster) apiV0Login(rw http.ResponseWriter, req *http.Request) { }) } -func (cr *Cluster) apiV0RequestToken(rw http.ResponseWriter, req *http.Request) { - if checkRequestMethodOrRejectWithJson(rw, req, http.MethodPost) { +func (cr *Cluster) routeRequestToken(rw http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodPost { + errorMethodNotAllowed(rw, req, http.MethodPost) return } defer req.Body.Close() @@ -369,11 +376,8 @@ func (cr *Cluster) apiV0RequestToken(rw http.ResponseWriter, req *http.Request) Path string `json:"path"` Query map[string]string `json:"query,omitempty"` } - if err := json.NewDecoder(req.Body).Decode(&payload); err != nil { - writeJson(rw, http.StatusBadRequest, Map{ - "error": "cannot decode payload in json format", - "message": err.Error(), - }) + if !parseRequestBody(rw, req, &payload) { + return } log.Debugf("payload: %#v", payload) if payload.Path == "" || payload.Path[0] != '/' { @@ -398,8 +402,9 @@ func (cr *Cluster) apiV0RequestToken(rw http.ResponseWriter, req *http.Request) }) } -func (cr *Cluster) apiV1Logout(rw http.ResponseWriter, req *http.Request) { - if checkRequestMethodOrRejectWithJson(rw, req, http.MethodPost) { +func (cr *Cluster) routeLogout(rw http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodPost { + errorMethodNotAllowed(rw, req, http.MethodPost) return } limited.SetSkipRateLimit(req) @@ -408,7 +413,7 @@ func (cr *Cluster) apiV1Logout(rw http.ResponseWriter, req *http.Request) { rw.WriteHeader(http.StatusNoContent) } -func (cr *Cluster) apiV1LogIO(rw http.ResponseWriter, req *http.Request) { +func (cr *Cluster) routeLogIO(rw http.ResponseWriter, req *http.Request) { addr, _ := req.Context().Value(RealAddrCtxKey).(string) conn, err := cr.wsUpgrader.Upgrade(rw, req, nil) @@ -624,8 +629,9 @@ func (cr *Cluster) apiV1LogIO(rw http.ResponseWriter, req *http.Request) { } } -func (cr *Cluster) apiV1Pprof(rw http.ResponseWriter, req *http.Request) { - if checkRequestMethodOrRejectWithJson(rw, req, http.MethodGet) { +func (cr *Cluster) routePprof(rw http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodGet { + errorMethodNotAllowed(rw, req, http.MethodGet) return } query := req.URL.Query() @@ -661,244 +667,8 @@ func (cr *Cluster) apiV1Pprof(rw http.ResponseWriter, req *http.Request) { p.WriteTo(rw, debug) } -func (cr *Cluster) apiV0SubscribeKey(rw http.ResponseWriter, req *http.Request) { - if checkRequestMethodOrRejectWithJson(rw, req, http.MethodGet) { - return - } - key := cr.webpushKeyB64 - etag := `"` + utils.AsSha256(key) + `"` - rw.Header().Set("ETag", etag) - if cachedTag := req.Header.Get("If-None-Match"); cachedTag == etag { - rw.WriteHeader(http.StatusNotModified) - return - } - writeJson(rw, http.StatusOK, Map{ - "publicKey": key, - }) -} - -func (cr *Cluster) apiV0Subscribe(rw http.ResponseWriter, req *http.Request) { - if checkRequestMethodOrRejectWithJson(rw, req, http.MethodGet, http.MethodPost, http.MethodDelete) { - return - } - cliId := apiGetClientId(req) - user := getLoggedUser(req) - if user == "" { - writeJson(rw, http.StatusForbidden, Map{ - "error": "Unauthorized", - }) - return - } - switch req.Method { - case http.MethodGet: - cr.apiV0SubscribeGET(rw, req, user, cliId) - case http.MethodPost: - cr.apiV0SubscribePOST(rw, req, user, cliId) - case http.MethodDelete: - cr.apiV0SubscribeDELETE(rw, req, user, cliId) - default: - panic("unreachable") - } -} - -func (cr *Cluster) apiV0SubscribeGET(rw http.ResponseWriter, req *http.Request, user string, client string) { - record, err := cr.database.GetSubscribe(user, client) - if err != nil { - if err == database.ErrNotFound { - writeJson(rw, http.StatusNotFound, Map{ - "error": "no subscription was found", - }) - return - } - writeJson(rw, http.StatusInternalServerError, Map{ - "error": "database error", - "message": err.Error(), - }) - return - } - writeJson(rw, http.StatusOK, Map{ - "scopes": record.Scopes, - "reportAt": record.ReportAt, - }) -} - -func (cr *Cluster) apiV0SubscribePOST(rw http.ResponseWriter, req *http.Request, user string, client string) { - data, ok := parseRequestBody[database.SubscribeRecord](rw, req, nil) - if !ok { - return - } - data.User = user - data.Client = client - if err := cr.database.SetSubscribe(data); err != nil { - writeJson(rw, http.StatusInternalServerError, Map{ - "error": "Database update failed", - "message": err.Error(), - }) - return - } - rw.WriteHeader(http.StatusNoContent) -} - -func (cr *Cluster) apiV0SubscribeDELETE(rw http.ResponseWriter, req *http.Request, user string, client string) { - if err := cr.database.RemoveSubscribe(user, client); err != nil { - if err == database.ErrNotFound { - writeJson(rw, http.StatusNotFound, Map{ - "error": "no subscription was found", - }) - return - } - writeJson(rw, http.StatusInternalServerError, Map{ - "error": "database error", - "message": err.Error(), - }) - return - } - rw.WriteHeader(http.StatusNoContent) -} - -func (cr *Cluster) apiV0SubscribeEmail(rw http.ResponseWriter, req *http.Request) { - if checkRequestMethodOrRejectWithJson(rw, req, http.MethodGet, http.MethodPost, http.MethodPatch, http.MethodDelete) { - return - } - user := getLoggedUser(req) - if user == "" { - writeJson(rw, http.StatusForbidden, Map{ - "error": "Unauthorized", - }) - return - } - switch req.Method { - case http.MethodGet: - cr.apiV0SubscribeEmailGET(rw, req, user) - case http.MethodPost: - cr.apiV0SubscribeEmailPOST(rw, req, user) - case http.MethodPatch: - cr.apiV0SubscribeEmailPATCH(rw, req, user) - case http.MethodDelete: - cr.apiV0SubscribeEmailDELETE(rw, req, user) - default: - panic("unreachable") - } -} - -func (cr *Cluster) apiV0SubscribeEmailGET(rw http.ResponseWriter, req *http.Request, user string) { - if addr := req.URL.Query().Get("addr"); addr != "" { - record, err := cr.database.GetEmailSubscription(user, addr) - if err != nil { - if err == database.ErrNotFound { - writeJson(rw, http.StatusNotFound, Map{ - "error": "no email subscription was found", - }) - return - } - writeJson(rw, http.StatusInternalServerError, Map{ - "error": "database error", - "message": err.Error(), - }) - return - } - writeJson(rw, http.StatusOK, record) - return - } - records := make([]database.EmailSubscriptionRecord, 0, 4) - if err := cr.database.ForEachUsersEmailSubscription(user, func(rec *database.EmailSubscriptionRecord) error { - records = append(records, *rec) - return nil - }); err != nil { - writeJson(rw, http.StatusInternalServerError, Map{ - "error": "database error", - "message": err.Error(), - }) - return - } - writeJson(rw, http.StatusOK, records) -} - -func (cr *Cluster) apiV0SubscribeEmailPOST(rw http.ResponseWriter, req *http.Request, user string) { - data, ok := parseRequestBody[database.EmailSubscriptionRecord](rw, req, nil) - if !ok { - return - } - - data.User = user - if err := cr.database.AddEmailSubscription(data); err != nil { - writeJson(rw, http.StatusInternalServerError, Map{ - "error": "Database update failed", - "message": err.Error(), - }) - return - } - rw.WriteHeader(http.StatusCreated) -} - -func (cr *Cluster) apiV0SubscribeEmailPATCH(rw http.ResponseWriter, req *http.Request, user string) { - addr := req.URL.Query().Get("addr") - data, ok := parseRequestBody[database.EmailSubscriptionRecord](rw, req, nil) - if !ok { - return - } - data.User = user - data.Addr = addr - if err := cr.database.UpdateEmailSubscription(data); err != nil { - if err == database.ErrNotFound { - writeJson(rw, http.StatusNotFound, Map{ - "error": "no email subscription was found", - }) - return - } - writeJson(rw, http.StatusInternalServerError, Map{ - "error": "database error", - "message": err.Error(), - }) - return - } - rw.WriteHeader(http.StatusNoContent) -} - -func (cr *Cluster) apiV0SubscribeEmailDELETE(rw http.ResponseWriter, req *http.Request, user string) { - addr := req.URL.Query().Get("addr") - if err := cr.database.RemoveEmailSubscription(user, addr); err != nil { - if err == database.ErrNotFound { - writeJson(rw, http.StatusNotFound, Map{ - "error": "no email subscription was found", - }) - return - } - writeJson(rw, http.StatusInternalServerError, Map{ - "error": "database error", - "message": err.Error(), - }) - return - } - rw.WriteHeader(http.StatusNoContent) -} - -func (cr *Cluster) apiV0Webhook(rw http.ResponseWriter, req *http.Request) { - if checkRequestMethodOrRejectWithJson(rw, req, http.MethodGet, http.MethodPost, http.MethodPatch, http.MethodDelete) { - return - } +func (cr *Cluster) routeWebhookGET(rw http.ResponseWriter, req *http.Request) { user := getLoggedUser(req) - if user == "" { - writeJson(rw, http.StatusForbidden, Map{ - "error": "Unauthorized", - }) - return - } - switch req.Method { - case http.MethodGet: - cr.apiV0WebhookGET(rw, req, user) - case http.MethodPost: - cr.apiV0WebhookPOST(rw, req, user) - case http.MethodPatch: - cr.apiV0WebhookPATCH(rw, req, user) - case http.MethodDelete: - cr.apiV0WebhookDELETE(rw, req, user) - default: - panic("unreachable") - } -} - -func (cr *Cluster) apiV0WebhookGET(rw http.ResponseWriter, req *http.Request, user string) { if sid := req.URL.Query().Get("id"); sid != "" { id, err := uuid.Parse(sid) if err != nil { @@ -939,9 +709,10 @@ func (cr *Cluster) apiV0WebhookGET(rw http.ResponseWriter, req *http.Request, us writeJson(rw, http.StatusOK, records) } -func (cr *Cluster) apiV0WebhookPOST(rw http.ResponseWriter, req *http.Request, user string) { - data, ok := parseRequestBody[database.WebhookRecord](rw, req, nil) - if !ok { +func (cr *Cluster) routeWebhookPOST(rw http.ResponseWriter, req *http.Request) { + user := getLoggedUser(req) + var data database.WebhookRecord + if !parseRequestBody(rw, req, &data) { return } @@ -956,10 +727,11 @@ func (cr *Cluster) apiV0WebhookPOST(rw http.ResponseWriter, req *http.Request, u rw.WriteHeader(http.StatusCreated) } -func (cr *Cluster) apiV0WebhookPATCH(rw http.ResponseWriter, req *http.Request, user string) { +func (cr *Cluster) routeWebhookPATCH(rw http.ResponseWriter, req *http.Request) { + user := getLoggedUser(req) id := req.URL.Query().Get("id") - data, ok := parseRequestBody[database.WebhookRecord](rw, req, nil) - if !ok { + var data database.WebhookRecord + if !parseRequestBody(rw, req, &data) { return } data.User = user @@ -987,7 +759,8 @@ func (cr *Cluster) apiV0WebhookPATCH(rw http.ResponseWriter, req *http.Request, rw.WriteHeader(http.StatusNoContent) } -func (cr *Cluster) apiV0WebhookDELETE(rw http.ResponseWriter, req *http.Request, user string) { +func (cr *Cluster) routeWebhookDELETE(rw http.ResponseWriter, req *http.Request) { + user := getLoggedUser(req) id, err := uuid.Parse(req.URL.Query().Get("id")) if err != nil { writeJson(rw, http.StatusBadRequest, Map{ @@ -1012,8 +785,9 @@ func (cr *Cluster) apiV0WebhookDELETE(rw http.ResponseWriter, req *http.Request, rw.WriteHeader(http.StatusNoContent) } -func (cr *Cluster) apiV0LogFiles(rw http.ResponseWriter, req *http.Request) { - if checkRequestMethodOrRejectWithJson(rw, req, http.MethodGet) { +func (cr *Cluster) routeLogFiles(rw http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodGet { + errorMethodNotAllowed(rw, req, http.MethodGet) return } files := log.ListLogs() @@ -1035,8 +809,9 @@ func (cr *Cluster) apiV0LogFiles(rw http.ResponseWriter, req *http.Request) { }) } -func (cr *Cluster) apiV0LogFile(rw http.ResponseWriter, req *http.Request) { - if checkRequestMethodOrRejectWithJson(rw, req, http.MethodGet, http.MethodHead) { +func (cr *Cluster) routeLogFile(rw http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodGet && req.Method != http.MethodHead { + errorMethodNotAllowed(rw, req, http.MethodGet+", "+http.MethodHead) return } query := req.URL.Query() @@ -1078,11 +853,11 @@ func (cr *Cluster) apiV0LogFile(rw http.ResponseWriter, req *http.Request) { } rw.Header().Set("Content-Type", "application/octet-stream") rw.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", name+".encrypted")) - cr.apiV0LogFileEncrypted(rw, req, fd, !isGzip) + cr.routeLogFileEncrypted(rw, req, fd, !isGzip) } } -func (cr *Cluster) apiV0LogFileEncrypted(rw http.ResponseWriter, req *http.Request, r io.Reader, useGzip bool) { +func (cr *Cluster) routeLogFileEncrypted(rw http.ResponseWriter, req *http.Request, r io.Reader, useGzip bool) { rw.WriteHeader(http.StatusOK) if req.Method == http.MethodHead { return @@ -1112,10 +887,9 @@ func (cr *Cluster) apiV0LogFileEncrypted(rw http.ResponseWriter, req *http.Reque type Map = map[string]any var errUnknownContent = errors.New("unknown content-type") +var formDecoder = schema.NewDecoder() -type requestBodyParser[T any] func(rw http.ResponseWriter, req *http.Request, contentType string, data *T) error - -func parseRequestBody[T any](rw http.ResponseWriter, req *http.Request, fallback requestBodyParser[T]) (data T, parsed bool) { +func parseRequestBody(rw http.ResponseWriter, req *http.Request, ptr any) (parsed bool) { contentType, _, err := mime.ParseMediaType(req.Header.Get("Content-Type")) if err != nil { writeJson(rw, http.StatusBadRequest, Map{ @@ -1127,26 +901,31 @@ func parseRequestBody[T any](rw http.ResponseWriter, req *http.Request, fallback } switch contentType { case "application/json": - if err := json.NewDecoder(req.Body).Decode(&data); err != nil { + if err := json.NewDecoder(req.Body).Decode(ptr); err != nil { writeJson(rw, http.StatusBadRequest, Map{ "error": "Cannot decode request body", "message": err.Error(), }) return } - return data, true - default: - if fallback != nil { - if err := fallback(rw, req, contentType, &data); err == nil { - return data, true - } else if err != errUnknownContent { - writeJson(rw, http.StatusBadRequest, Map{ - "error": "Cannot decode request body", - "message": err.Error(), - }) - return - } + return true + case "application/x-www-form-urlencoded": + if err := req.ParseForm(); err != nil { + writeJson(rw, http.StatusBadRequest, Map{ + "error": "Cannot decode request body", + "message": err.Error(), + }) + return } + if err := formDecoder.Decode(ptr, req.PostForm); err != nil { + writeJson(rw, http.StatusBadRequest, Map{ + "error": "Cannot decode request body", + "message": err.Error(), + }) + return + } + return true + default: writeJson(rw, http.StatusBadRequest, Map{ "error": "Unexpected Content-Type", "content-type": contentType, @@ -1168,18 +947,8 @@ func writeJson(rw http.ResponseWriter, code int, data any) (err error) { return } -func checkRequestMethodOrRejectWithJson(rw http.ResponseWriter, req *http.Request, allows ...string) (rejected bool) { - m := req.Method - for _, a := range allows { - if m == a { - return false - } - } - rw.Header().Set("Allow", strings.Join(allows, ", ")) - writeJson(rw, http.StatusMethodNotAllowed, Map{ - "error": "405 method not allowed", - "method": m, - "allow": allows, - }) +func errorMethodNotAllowed(rw http.ResponseWriter, req *http.Request, allow string) { + rw.Header().Set("Allow", allow) + rw.WriteHeader(http.StatusMethodNotAllowed) return true } diff --git a/api/v0/api_token.go b/api/v0/api_token.go index 0568e452..bd4ba177 100644 --- a/api/v0/api_token.go +++ b/api/v0/api_token.go @@ -17,7 +17,7 @@ * along with this program. If not, see . */ -package main +package v0 import ( "errors" @@ -232,7 +232,7 @@ func (cr *Cluster) generateAPIToken(cliId string, userId string, path string, qu return tokenStr, nil } -func (cr *Cluster) verifyAPIToken(cliId string, token string, path string, query url.Values) (id string, user string, err error) { +func (h *Handler) verifyAPIToken(cliId string, token string, path string, query url.Values) (id string, user string, err error) { var claims apiTokenClaims _, err = jwt.ParseWithClaims( token, diff --git a/api/v0/configure_cluster.go b/api/v0/configure_cluster.go new file mode 100644 index 00000000..7bc143be --- /dev/null +++ b/api/v0/configure_cluster.go @@ -0,0 +1,24 @@ +/** + * OpenBmclAPI (Golang Edition) + * Copyright (C) 2024 Kevin Z + * All rights reserved + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published + * by the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package v0 + +func (h *Handler) apiConfigureCluster() { + // +} diff --git a/api/v0/subscription.go b/api/v0/subscription.go new file mode 100644 index 00000000..0a5f4709 --- /dev/null +++ b/api/v0/subscription.go @@ -0,0 +1,198 @@ +/** + * OpenBmclAPI (Golang Edition) + * Copyright (C) 2023 Kevin Z + * All rights reserved + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published + * by the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package v0 + +import ( + "net/http" +) + +func (h *Handler) routeSubscribeKey(rw http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodGet { + errorMethodNotAllowed(rw, req, http.MethodGet) + return + } + key := h.subManager.GetWebPushKey() + etag := `"` + utils.AsSha256(key) + `"` + rw.Header().Set("ETag", etag) + if cachedTag := req.Header.Get("If-None-Match"); cachedTag == etag { + rw.WriteHeader(http.StatusNotModified) + return + } + writeJson(rw, http.StatusOK, Map{ + "publicKey": key, + }) +} + +func (h *Handler) routeSubscribeGET(rw http.ResponseWriter, req *http.Request) { + client := apiGetClientId(req) + user := getLoggedUser(req) + record, err := h.subManager.GetSubscribe(user, client) + if err != nil { + if err == database.ErrNotFound { + writeJson(rw, http.StatusNotFound, Map{ + "error": "no subscription was found", + }) + return + } + writeJson(rw, http.StatusInternalServerError, Map{ + "error": "database error", + "message": err.Error(), + }) + return + } + writeJson(rw, http.StatusOK, Map{ + "scopes": record.Scopes, + "reportAt": record.ReportAt, + }) +} + +func (h *Handler) routeSubscribePOST(rw http.ResponseWriter, req *http.Request) { + client := apiGetClientId(req) + user := getLoggedUser(req) + data, ok := parseRequestBody[database.SubscribeRecord](rw, req, nil) + if !ok { + return + } + data.User = user + data.Client = client + if err := h.subManager.SetSubscribe(data); err != nil { + writeJson(rw, http.StatusInternalServerError, Map{ + "error": "Database update failed", + "message": err.Error(), + }) + return + } + rw.WriteHeader(http.StatusNoContent) +} + +func (h *Handler) routeSubscribeDELETE(rw http.ResponseWriter, req *http.Request) { + client := apiGetClientId(req) + user := getLoggedUser(req) + if err := h.subManager.RemoveSubscribe(user, client); err != nil { + if err == database.ErrNotFound { + writeJson(rw, http.StatusNotFound, Map{ + "error": "no subscription was found", + }) + return + } + writeJson(rw, http.StatusInternalServerError, Map{ + "error": "database error", + "message": err.Error(), + }) + return + } + rw.WriteHeader(http.StatusNoContent) +} + +func (h *Handler) routeSubscribeEmailGET(rw http.ResponseWriter, req *http.Request) { + user := getLoggedUser(req) + if addr := req.URL.Query().Get("addr"); addr != "" { + record, err := h.subManager.GetEmailSubscription(user, addr) + if err != nil { + if err == database.ErrNotFound { + writeJson(rw, http.StatusNotFound, Map{ + "error": "no email subscription was found", + }) + return + } + writeJson(rw, http.StatusInternalServerError, Map{ + "error": "database error", + "message": err.Error(), + }) + return + } + writeJson(rw, http.StatusOK, record) + return + } + records := make([]database.EmailSubscriptionRecord, 0, 4) + if err := h.subManager.ForEachUsersEmailSubscription(user, func(rec *database.EmailSubscriptionRecord) error { + records = append(records, *rec) + return nil + }); err != nil { + writeJson(rw, http.StatusInternalServerError, Map{ + "error": "database error", + "message": err.Error(), + }) + return + } + writeJson(rw, http.StatusOK, records) +} + +func (h *Handler) routeSubscribeEmailPOST(rw http.ResponseWriter, req *http.Request) { + user := getLoggedUser(req) + data, ok := parseRequestBody[database.EmailSubscriptionRecord](rw, req, nil) + if !ok { + return + } + + data.User = user + if err := h.subManager.AddEmailSubscription(data); err != nil { + writeJson(rw, http.StatusInternalServerError, Map{ + "error": "Database update failed", + "message": err.Error(), + }) + return + } + rw.WriteHeader(http.StatusCreated) +} + +func (h *Handler) routeSubscribeEmailPATCH(rw http.ResponseWriter, req *http.Request) { + user := getLoggedUser(req) + addr := req.URL.Query().Get("addr") + data, ok := parseRequestBody[database.EmailSubscriptionRecord](rw, req, nil) + if !ok { + return + } + data.User = user + data.Addr = addr + if err := h.subManager.UpdateEmailSubscription(data); err != nil { + if err == database.ErrNotFound { + writeJson(rw, http.StatusNotFound, Map{ + "error": "no email subscription was found", + }) + return + } + writeJson(rw, http.StatusInternalServerError, Map{ + "error": "database error", + "message": err.Error(), + }) + return + } + rw.WriteHeader(http.StatusNoContent) +} + +func (h *Handler) routeSubscribeEmailDELETE(rw http.ResponseWriter, req *http.Request) { + user := getLoggedUser(req) + addr := req.URL.Query().Get("addr") + if err := h.subManager.RemoveEmailSubscription(user, addr); err != nil { + if err == database.ErrNotFound { + writeJson(rw, http.StatusNotFound, Map{ + "error": "no email subscription was found", + }) + return + } + writeJson(rw, http.StatusInternalServerError, Map{ + "error": "database error", + "message": err.Error(), + }) + return + } + rw.WriteHeader(http.StatusNoContent) +} diff --git a/cluster/config.go b/cluster/config.go index 2eebc617..c8fac95d 100644 --- a/cluster/config.go +++ b/cluster/config.go @@ -29,8 +29,12 @@ import ( "fmt" "net/http" "net/url" + "strconv" "time" + "github.com/hamba/avro/v2" + "github.com/klauspost/compress/zstd" + "github.com/LiterMC/go-openbmclapi/log" "github.com/LiterMC/go-openbmclapi/utils" ) @@ -259,3 +263,63 @@ func (cr *Cluster) RequestCert(ctx context.Context) (ckp *CertKeyPair, err error } return } + +type FileInfo struct { + Path string `json:"path" avro:"path"` + Hash string `json:"hash" avro:"hash"` + Size int64 `json:"size" avro:"size"` + Mtime int64 `json:"mtime" avro:"mtime"` +} + +// from +var fileListSchema = avro.MustParse(`{ + "type": "array", + "items": { + "type": "record", + "name": "fileinfo", + "fields": [ + {"name": "path", "type": "string"}, + {"name": "hash", "type": "string"}, + {"name": "size", "type": "long"}, + {"name": "mtime", "type": "long"} + ] + } +}`) + +func (cr *Cluster) GetFileList(ctx context.Context, lastMod int64) (files []FileInfo, err error) { + var query url.Values + if lastMod > 0 { + query = url.Values{ + "lastModified": {strconv.FormatInt(lastMod, 10)}, + } + } + req, err := cr.makeReqWithAuth(ctx, http.MethodGet, "/openbmclapi/files", query) + if err != nil { + return + } + res, err := cr.cachedCli.Do(req) + if err != nil { + return + } + defer res.Body.Close() + switch res.StatusCode { + case http.StatusOK: + // + case http.StatusNoContent, http.StatusNotModified: + return + default: + err = utils.NewHTTPStatusErrorFromResponse(res) + return + } + log.Debug("Parsing filelist body ...") + zr, err := zstd.NewReader(res.Body) + if err != nil { + return + } + defer zr.Close() + if err = avro.NewDecoderForSchema(fileListSchema, zr).Decode(&files); err != nil { + return + } + log.Debugf("Filelist parsed, length = %d", len(files)) + return +} diff --git a/cluster/http.go b/cluster/http.go index d83245fa..20045c71 100644 --- a/cluster/http.go +++ b/cluster/http.go @@ -20,6 +20,51 @@ package cluster import ( + "context" + "io" "net/http" "net/url" + "path" + + "github.com/LiterMC/go-openbmclapi/internal/build" ) + +func (cr *Cluster) makeReq(ctx context.Context, method string, relpath string, query url.Values) (req *http.Request, err error) { + return cr.makeReqWithBody(ctx, method, relpath, query, nil) +} + +func (cr *Cluster) makeReqWithBody( + ctx context.Context, + method string, relpath string, + query url.Values, body io.Reader, +) (req *http.Request, err error) { + var u *url.URL + if u, err = url.Parse(cr.opts.Prefix); err != nil { + return + } + u.Path = path.Join(u.Path, relpath) + if query != nil { + u.RawQuery = query.Encode() + } + target := u.String() + + req, err = http.NewRequestWithContext(ctx, method, target, body) + if err != nil { + return + } + req.Header.Set("User-Agent", build.ClusterUserAgent) + return +} + +func (cr *Cluster) makeReqWithAuth(ctx context.Context, method string, relpath string, query url.Values) (req *http.Request, err error) { + req, err = cr.makeReq(ctx, method, relpath, query) + if err != nil { + return + } + token, err := cr.GetAuthToken(ctx) + if err != nil { + return + } + req.Header.Set("Authorization", "Bearer "+token) + return +} diff --git a/cluster/keepalive.go b/cluster/keepalive.go index 478c1763..2dee82a9 100644 --- a/cluster/keepalive.go +++ b/cluster/keepalive.go @@ -22,6 +22,9 @@ package cluster import ( "context" "time" + + "github.com/LiterMC/go-openbmclapi/log" + "github.com/LiterMC/go-openbmclapi/utils" ) type KeepAliveRes int @@ -49,10 +52,10 @@ func (cr *Cluster) KeepAlive(ctx context.Context) KeepAliveRes { }) if e := cr.stats.Save(cr.dataDir); e != nil { - log.Errorf(Tr("error.cluster.stat.save.failed"), e) + log.TrErrorf("error.cluster.stat.save.failed", e) } if err != nil { - log.Errorf(Tr("error.cluster.keepalive.send.failed"), err) + log.TrErrorf("error.cluster.keepalive.send.failed", err) return KeepAliveFailed } var data []any @@ -65,21 +68,19 @@ func (cr *Cluster) KeepAlive(ctx context.Context) KeepAliveRes { if ero := data[0]; len(data) <= 1 || ero != nil { if ero, ok := ero.(map[string]any); ok { if msg, ok := ero["message"].(string); ok { - log.Errorf(Tr("error.cluster.keepalive.failed"), msg) + log.TrErrorf("error.cluster.keepalive.failed", msg) if hashMismatch := reFileHashMismatchError.FindStringSubmatch(msg); hashMismatch != nil { hash := hashMismatch[1] log.Warnf("Detected hash mismatch error, removing bad file %s", hash) - for _, s := range cr.storages { - go s.Remove(hash) - } + cr.storageManager.RemoveForAll(hash) } return KeepAliveFailed } } - log.Errorf(Tr("error.cluster.keepalive.failed"), ero) + log.TrErrorf("error.cluster.keepalive.failed", ero) return KeepAliveFailed } - log.Infof(Tr("info.cluster.keepalive.success"), ahits, utils.BytesToUnit((float64)(ahbts)), data[1]) + log.TrInfof("info.cluster.keepalive.success", ahits, utils.BytesToUnit((float64)(ahbts)), data[1]) cr.hits.Add(-hits2) cr.hbts.Add(-hbts2) if data[1] == false { diff --git a/database/db.go b/database/db.go index 5429d9d7..2f42c46a 100644 --- a/database/db.go +++ b/database/db.go @@ -83,244 +83,3 @@ type FileRecord struct { Hash string Size int64 } - -type SubscribeRecord struct { - User string `json:"user"` - Client string `json:"client"` - EndPoint string `json:"endpoint"` - Keys SubscribeRecordKeys `json:"keys"` - Scopes NotificationScopes `json:"scopes"` - ReportAt Schedule `json:"report_at"` - LastReport sql.NullTime `json:"-"` -} - -type SubscribeRecordKeys struct { - Auth string `json:"auth"` - P256dh string `json:"p256dh"` -} - -var ( - _ sql.Scanner = (*SubscribeRecordKeys)(nil) - _ driver.Valuer = (*SubscribeRecordKeys)(nil) -) - -func (sk *SubscribeRecordKeys) Scan(src any) error { - var data []byte - switch v := src.(type) { - case []byte: - data = v - case string: - data = ([]byte)(v) - default: - return errors.New("Source is not a string") - } - return json.Unmarshal(data, sk) -} - -func (sk SubscribeRecordKeys) Value() (driver.Value, error) { - return json.Marshal(sk) -} - -type NotificationScopes struct { - Disabled bool `json:"disabled"` - Enabled bool `json:"enabled"` - SyncBegin bool `json:"syncbegin"` - SyncDone bool `json:"syncdone"` - Updates bool `json:"updates"` - DailyReport bool `json:"dailyreport"` -} - -var ( - _ sql.Scanner = (*NotificationScopes)(nil) - _ driver.Valuer = (*NotificationScopes)(nil) -) - -//// !!WARN: Do not edit nsFlag's order //// - -const ( - nsFlagDisabled = 1 << iota - nsFlagEnabled - nsFlagSyncDone - nsFlagUpdates - nsFlagDailyReport - nsFlagSyncBegin -) - -func (ns NotificationScopes) ToInt64() (v int64) { - if ns.Disabled { - v |= nsFlagDisabled - } - if ns.Enabled { - v |= nsFlagEnabled - } - if ns.SyncBegin { - v |= nsFlagSyncBegin - } - if ns.SyncDone { - v |= nsFlagSyncDone - } - if ns.Updates { - v |= nsFlagUpdates - } - if ns.DailyReport { - v |= nsFlagDailyReport - } - return -} - -func (ns *NotificationScopes) FromInt64(v int64) { - ns.Disabled = v&nsFlagDisabled != 0 - ns.Enabled = v&nsFlagEnabled != 0 - ns.SyncBegin = v&nsFlagSyncBegin != 0 - ns.SyncDone = v&nsFlagSyncDone != 0 - ns.Updates = v&nsFlagUpdates != 0 - ns.DailyReport = v&nsFlagDailyReport != 0 -} - -func (ns *NotificationScopes) Scan(src any) error { - v, ok := src.(int64) - if !ok { - return errors.New("Source is not a integer") - } - ns.FromInt64(v) - return nil -} - -func (ns NotificationScopes) Value() (driver.Value, error) { - return ns.ToInt64(), nil -} - -func (ns *NotificationScopes) FromStrings(scopes []string) { - for _, s := range scopes { - switch s { - case "disabled": - ns.Disabled = true - case "enabled": - ns.Enabled = true - case "syncbegin": - ns.SyncBegin = true - case "syncdone": - ns.SyncDone = true - case "updates": - ns.Updates = true - case "dailyreport": - ns.DailyReport = true - } - } -} - -func (ns *NotificationScopes) UnmarshalJSON(data []byte) (err error) { - { - type T NotificationScopes - if err = json.Unmarshal(data, (*T)(ns)); err == nil { - return - } - } - var v []string - if err = json.Unmarshal(data, &v); err != nil { - return - } - ns.FromStrings(v) - return -} - -type Schedule struct { - Hour int - Minute int -} - -var ( - _ sql.Scanner = (*Schedule)(nil) - _ driver.Valuer = (*Schedule)(nil) -) - -func (s Schedule) String() string { - return fmt.Sprintf("%02d:%02d", s.Hour, s.Minute) -} - -func (s *Schedule) UnmarshalText(buf []byte) (err error) { - if _, err = fmt.Sscanf((string)(buf), "%02d:%02d", &s.Hour, &s.Minute); err != nil { - return - } - if s.Hour < 0 || s.Hour >= 24 { - return fmt.Errorf("Hour %d out of range [0, 24)", s.Hour) - } - if s.Minute < 0 || s.Minute >= 60 { - return fmt.Errorf("Minute %d out of range [0, 60)", s.Minute) - } - return -} - -func (s *Schedule) UnmarshalJSON(buf []byte) (err error) { - var v string - if err = json.Unmarshal(buf, &v); err != nil { - return - } - return s.UnmarshalText(([]byte)(v)) -} - -func (s *Schedule) MarshalJSON() (buf []byte, err error) { - return json.Marshal(s.String()) -} - -func (s *Schedule) Scan(src any) error { - var v []byte - switch w := src.(type) { - case []byte: - v = w - case string: - v = ([]byte)(w) - default: - return fmt.Errorf("Unexpected type %T", src) - } - return s.UnmarshalText(v) -} - -func (s Schedule) Value() (driver.Value, error) { - return s.String(), nil -} - -func (s Schedule) ReadySince(last, now time.Time) bool { - if last.IsZero() { - last = now.Add(-time.Hour*24 + 1) - } - mustAfter := last.Add(time.Hour * 12) - if now.Before(mustAfter) { - return false - } - if !now.Before(last.Add(time.Hour * 24)) { - return true - } - hour, min := now.Hour(), now.Minute() - if s.Hour < hour && s.Hour+3 > hour || s.Hour == hour && s.Minute <= min { - return true - } - return false -} - -type EmailSubscriptionRecord struct { - User string `json:"user"` - Addr string `json:"addr"` - Scopes NotificationScopes `json:"scopes"` - Enabled bool `json:"enabled"` -} - -type WebhookRecord struct { - User string `json:"user"` - Id uuid.UUID `json:"id"` - Name string `json:"name"` - EndPoint string `json:"endpoint"` - Auth *string `json:"auth,omitempty"` - AuthHash string `json:"authHash,omitempty"` - Scopes NotificationScopes `json:"scopes"` - Enabled bool `json:"enabled"` -} - -func (rec *WebhookRecord) CovertAuthHash() { - if rec.Auth == nil || *rec.Auth == "" { - rec.AuthHash = "" - } else { - rec.AuthHash = "sha256:" + utils.AsSha256Hex(*rec.Auth) - } - rec.Auth = nil -} diff --git a/limited/api_rate.go b/limited/api_rate.go index 961501c7..33b7ea83 100644 --- a/limited/api_rate.go +++ b/limited/api_rate.go @@ -176,6 +176,8 @@ type APIRateMiddleWare struct { startAt time.Time } +var _ utils.MiddleWare = (*APIRateMiddleWare)(nil) + func NewAPIRateMiddleWare(realIPContextKey, loggedContextKey any) (a *APIRateMiddleWare) { a = &APIRateMiddleWare{ loggedContextKey: loggedContextKey, @@ -184,9 +186,9 @@ func NewAPIRateMiddleWare(realIPContextKey, loggedContextKey any) (a *APIRateMid cleanTicker: time.NewTicker(time.Minute), startAt: time.Now(), } - go func() { + go func(ticker *time.Ticker) { count := 0 - for range a.cleanTicker.C { + for range ticker.C { count++ ishour := count > 60 if ishour { @@ -195,12 +197,10 @@ func NewAPIRateMiddleWare(realIPContextKey, loggedContextKey any) (a *APIRateMid a.clean(ishour) } log.Debugf("cleaner exited") - }() + }(a.cleanTicker) return } -var _ utils.MiddleWare = (*APIRateMiddleWare)(nil) - const ( RateLimitOverrideContextKey = "go-openbmclapi.limited.rate.api.override" RateLimitSkipContextKey = "go-openbmclapi.limited.rate.api.skip" diff --git a/sync.go b/sync.go index ea4ed602..0053653f 100644 --- a/sync.go +++ b/sync.go @@ -75,106 +75,6 @@ func (cr *Cluster) CachedFileSize(hash string) (size int64, ok bool) { return } -func (cr *Cluster) makeReq(ctx context.Context, method string, relpath string, query url.Values) (req *http.Request, err error) { - return cr.makeReqWithBody(ctx, method, relpath, query, nil) -} - -func (cr *Cluster) makeReqWithBody( - ctx context.Context, - method string, relpath string, - query url.Values, body io.Reader, -) (req *http.Request, err error) { - var u *url.URL - if u, err = url.Parse(cr.prefix); err != nil { - return - } - u.Path = path.Join(u.Path, relpath) - if query != nil { - u.RawQuery = query.Encode() - } - target := u.String() - - req, err = http.NewRequestWithContext(ctx, method, target, body) - if err != nil { - return - } - req.Header.Set("User-Agent", build.ClusterUserAgent) - return -} - -func (cr *Cluster) makeReqWithAuth(ctx context.Context, method string, relpath string, query url.Values) (req *http.Request, err error) { - req, err = cr.makeReq(ctx, method, relpath, query) - if err != nil { - return - } - token, err := cr.GetAuthToken(ctx) - if err != nil { - return - } - req.Header.Set("Authorization", "Bearer "+token) - return -} - -type FileInfo struct { - Path string `json:"path" avro:"path"` - Hash string `json:"hash" avro:"hash"` - Size int64 `json:"size" avro:"size"` - Mtime int64 `json:"mtime" avro:"mtime"` -} - -// from -var fileListSchema = avro.MustParse(`{ - "type": "array", - "items": { - "type": "record", - "name": "fileinfo", - "fields": [ - {"name": "path", "type": "string"}, - {"name": "hash", "type": "string"}, - {"name": "size", "type": "long"}, - {"name": "mtime", "type": "long"} - ] - } -}`) - -func (cr *Cluster) GetFileList(ctx context.Context, lastMod int64) (files []FileInfo, err error) { - var query url.Values - if lastMod > 0 { - query = url.Values{ - "lastModified": {strconv.FormatInt(lastMod, 10)}, - } - } - req, err := cr.makeReqWithAuth(ctx, http.MethodGet, "/openbmclapi/files", query) - if err != nil { - return - } - res, err := cr.cachedCli.Do(req) - if err != nil { - return - } - defer res.Body.Close() - switch res.StatusCode { - case http.StatusOK: - // - case http.StatusNoContent, http.StatusNotModified: - return - default: - err = utils.NewHTTPStatusErrorFromResponse(res) - return - } - log.Debug("Parsing filelist body ...") - zr, err := zstd.NewReader(res.Body) - if err != nil { - return - } - defer zr.Close() - if err = avro.NewDecoderForSchema(fileListSchema, zr).Decode(&files); err != nil { - return - } - log.Debugf("Filelist parsed, length = %d", len(files)) - return -} - type syncStats struct { slots *limited.BufSlots noOpen bool diff --git a/utils/crypto.go b/utils/crypto.go index 55c90c0b..40f70846 100644 --- a/utils/crypto.go +++ b/utils/crypto.go @@ -60,10 +60,16 @@ func AsSha256Hex(s string) string { } func HMACSha256Hex(key, data string) string { + return (string)(HMACSha256HexBytes(key, data)) +} + +func HMACSha256HexBytes(key, data string) []byte { m := hmac.New(sha256.New, ([]byte)(key)) m.Write(([]byte)(data)) buf := m.Sum(nil) - return hex.EncodeToString(buf[:]) + value := make([]byte, hex.EncodedLen(len(buf))) + hex.Encode(value, buf[:]) + return value } func GenRandB64(n int) (s string, err error) { diff --git a/utils/http.go b/utils/http.go index 65809812..f0187e26 100644 --- a/utils/http.go +++ b/utils/http.go @@ -36,8 +36,8 @@ import ( "sync/atomic" "time" - "github.com/LiterMC/go-openbmclapi/log" "github.com/LiterMC/go-openbmclapi/internal/build" + "github.com/LiterMC/go-openbmclapi/log" ) type StatusResponseWriter struct { @@ -125,12 +125,22 @@ type HttpMiddleWareHandler struct { middles []MiddleWare } -func NewHttpMiddleWareHandler(final http.Handler) *HttpMiddleWareHandler { +var _ http.Handler = (*HttpMiddleWareHandler)(nil) + +func NewHttpMiddleWareHandler(final http.Handler, middles ...MiddleWare) *HttpMiddleWareHandler { return &HttpMiddleWareHandler{ - final: final, + final: final, + middles: middles, } } +// Handler returns the final http.Handler +func (m *HttpMiddleWareHandler) Handler() http.Handler { + return m.final +} + +// ServeHTTP implements http.Handler +// It will invoke the middlewares in order func (m *HttpMiddleWareHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { i := 0 var getNext func() http.Handler @@ -158,16 +168,122 @@ func (m *HttpMiddleWareHandler) ServeHTTP(rw http.ResponseWriter, req *http.Requ getNext().ServeHTTP(rw, req) } +// Use append MiddleWares to the middleware chain func (m *HttpMiddleWareHandler) Use(mids ...MiddleWare) { m.middles = append(m.middles, mids...) } +// UseFunc append MiddleWareFuncs to the middleware chain func (m *HttpMiddleWareHandler) UseFunc(fns ...MiddleWareFunc) { for _, fn := range fns { m.middles = append(m.middles, fn) } } +// HttpMethodHandler pass down http requests to different handler based on the request methods +// The HttpMethodHandler should not be modified after called ServeHTTP +type HttpMethodHandler struct { + Get http.Handler + Head bool + Post http.Handler + Put http.Handler + Patch http.Handler + Delete http.Handler + Connect http.Handler + Options http.Handler + Trace http.Handler + + allows string + allowsOnce sync.Once +} + +var _ http.Handler = (*HttpMethodHandler)(nil) + +// ServeHTTP implements http.Handler +// Once ServeHTTP is called the HttpMethodHandler should not be modified +func (m *HttpMethodHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { + switch req.Method { + case http.MethodHead: + if !m.Head { + break + } + fallthrough + case http.MethodGet: + if m.Get != nil { + m.Get.ServeHTTP(rw, req) + return + } + case http.MethodPost: + if m.Post != nil { + m.Post.ServeHTTP(rw, req) + return + } + case http.MethodPut: + if m.Put != nil { + m.Put.ServeHTTP(rw, req) + return + } + case http.MethodPatch: + if m.Patch != nil { + m.Patch.ServeHTTP(rw, req) + return + } + case http.MethodDelete: + if m.Delete != nil { + m.Delete.ServeHTTP(rw, req) + return + } + case http.MethodConnect: + if m.Connect != nil { + m.Connect.ServeHTTP(rw, req) + return + } + case http.MethodOptions: + if m.Options != nil { + m.Options.ServeHTTP(rw, req) + return + } + case http.MethodTrace: + if m.Trace != nil { + m.Trace.ServeHTTP(rw, req) + return + } + } + m.allowsOnce.Do(func() { + allows := make([]string, 0, 5) + if m.Get != nil { + allows = append(allows, http.MethodGet) + if m.Head { + allows = append(allows, http.MethodGet) + } + } + if m.Post != nil { + allows = append(allows, http.MethodPost) + } + if m.Put != nil { + allows = append(allows, http.MethodPut) + } + if m.Patch != nil { + allows = append(allows, http.MethodPatch) + } + if m.Delete != nil { + allows = append(allows, http.MethodDelete) + } + if m.Connect != nil { + allows = append(allows, http.MethodConnect) + } + if m.Options != nil { + allows = append(allows, http.MethodOptions) + } + if m.Trace != nil { + allows = append(allows, http.MethodTrace) + } + m.allows = strings.Join(allows, ", ") + }) + rw.Header().Set("Allow", m.allows) + rw.WriteHeader(http.StatusMethodNotAllowed) +} + // HTTPTLSListener will serve a http or a tls connection // When Accept was called, if a pure http request is received, // it will response and redirect the client to the https protocol. From 8cbd7c5f4278d6d2fca0626508fd832c27e9b7f3 Mon Sep 17 00:00:00 2001 From: zyxkad Date: Thu, 27 Jun 2024 16:22:21 -0600 Subject: [PATCH 06/36] fix wrong jwt subject usage --- api/v0/api_token.go | 48 ++++++++++++++++++++++----------------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/api/v0/api_token.go b/api/v0/api_token.go index bd4ba177..4d298f09 100644 --- a/api/v0/api_token.go +++ b/api/v0/api_token.go @@ -61,7 +61,7 @@ func getLoggedUser(req *http.Request) string { var ( ErrUnsupportAuthType = errors.New("unsupported authorization type") - ErrClientIdNotMatch = errors.New("client id not match") + ErrScopeNotMatch = errors.New("scope not match") ErrJTINotExists = errors.New("jti not exists") ErrStrictPathNotMatch = errors.New("strict path not match") @@ -76,15 +76,15 @@ func (cr *Cluster) getJWTKey(t *jwt.Token) (any, error) { } const ( - challengeTokenSubject = "GOBA-challenge" - authTokenSubject = "GOBA-auth" - apiTokenSubject = "GOBA-API" + challengeTokenScope = "GOBA-challenge" + authTokenScope = "GOBA-auth" + apiTokenScope = "GOBA-API" ) type challengeTokenClaims struct { jwt.RegisteredClaims - Client string `json:"cli"` + Scope string `json:"scope"` Action string `json:"act"` } @@ -93,12 +93,12 @@ func (cr *Cluster) generateChallengeToken(cliId string, action string) (string, exp := now.Add(time.Minute * 1) token := jwt.NewWithClaims(jwt.SigningMethodHS256, &challengeTokenClaims{ RegisteredClaims: jwt.RegisteredClaims{ - Subject: challengeTokenSubject, + Subject: cliId, Issuer: cr.jwtIssuer, IssuedAt: jwt.NewNumericDate(now), ExpiresAt: jwt.NewNumericDate(exp), }, - Client: cliId, + Scope: challengeTokenScope, Action: action, }) tokenStr, err := token.SignedString(cr.apiHmacKey) @@ -114,14 +114,14 @@ func (cr *Cluster) verifyChallengeToken(cliId string, action string, token strin token, &claims, cr.getJWTKey, - jwt.WithSubject(challengeTokenSubject), + jwt.WithSubject(cliId), jwt.WithIssuedAt(), jwt.WithIssuer(cr.jwtIssuer), ); err != nil { return } - if claims.Client != cliId { - return ErrClientIdNotMatch + if claims.Scope != challengeTokenScope { + return ErrScopeNotMatch } if claims.Action != action { return ErrJTINotExists @@ -132,8 +132,8 @@ func (cr *Cluster) verifyChallengeToken(cliId string, action string, token strin type authTokenClaims struct { jwt.RegisteredClaims - Client string `json:"cli"` - User string `json:"usr"` + Scope string `json:"scope"` + User string `json:"usr"` } func (cr *Cluster) generateAuthToken(cliId string, userId string) (string, error) { @@ -146,13 +146,13 @@ func (cr *Cluster) generateAuthToken(cliId string, userId string) (string, error token := jwt.NewWithClaims(jwt.SigningMethodHS256, &authTokenClaims{ RegisteredClaims: jwt.RegisteredClaims{ ID: jti, - Subject: authTokenSubject, + Subject: cliId, Issuer: cr.jwtIssuer, IssuedAt: jwt.NewNumericDate(now), ExpiresAt: jwt.NewNumericDate(exp), }, - Client: cliId, - User: userId, + Scope: authTokenScope, + User: userId, }) tokenStr, err := token.SignedString(cr.apiHmacKey) if err != nil { @@ -170,14 +170,14 @@ func (cr *Cluster) verifyAuthToken(cliId string, token string) (id string, user token, &claims, cr.getJWTKey, - jwt.WithSubject(authTokenSubject), + jwt.WithSubject(cliId), jwt.WithIssuedAt(), jwt.WithIssuer(cr.jwtIssuer), ); err != nil { return } - if claims.Client != cliId { - err = ErrClientIdNotMatch + if claims.Scope != authTokenScope { + err = ErrScopeNotMatch return } if user = claims.User; user == "" { @@ -196,7 +196,7 @@ func (cr *Cluster) verifyAuthToken(cliId string, token string) (id string, user type apiTokenClaims struct { jwt.RegisteredClaims - Client string `json:"cli"` + Scope string `json:"scope"` User string `json:"usr"` StrictPath string `json:"str-p"` StrictQuery map[string]string `json:"str-q,omitempty"` @@ -212,12 +212,12 @@ func (cr *Cluster) generateAPIToken(cliId string, userId string, path string, qu token := jwt.NewWithClaims(jwt.SigningMethodHS256, &apiTokenClaims{ RegisteredClaims: jwt.RegisteredClaims{ ID: jti, - Subject: apiTokenSubject, + Subject: cliId, Issuer: cr.jwtIssuer, IssuedAt: jwt.NewNumericDate(now), ExpiresAt: jwt.NewNumericDate(exp), }, - Client: cliId, + Scope: apiTokenScope, User: userId, StrictPath: path, StrictQuery: query, @@ -238,15 +238,15 @@ func (h *Handler) verifyAPIToken(cliId string, token string, path string, query token, &claims, cr.getJWTKey, - jwt.WithSubject(apiTokenSubject), + jwt.WithSubject(cliId), jwt.WithIssuedAt(), jwt.WithIssuer(cr.jwtIssuer), ) if err != nil { return } - if claims.Client != cliId { - err = ErrClientIdNotMatch + if claims.Scope != apiTokenScope { + err = ErrScopeNotMatch return } if user = claims.User; user == "" { From e17d8c710e10c96303b935084097641f45b710ea Mon Sep 17 00:00:00 2001 From: zyxkad Date: Thu, 27 Jun 2024 16:24:05 -0600 Subject: [PATCH 07/36] fix format --- api/user.go | 4 ++-- go.mod | 1 + go.sum | 2 ++ 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/api/user.go b/api/user.go index 85ec6fd6..e145c9a3 100644 --- a/api/user.go +++ b/api/user.go @@ -56,7 +56,7 @@ const ( ) type User struct { - Username string - Password string // as sha256 + Username string + Password string // as sha256 Permissions PermissionFlag } diff --git a/go.mod b/go.mod index 9d665f95..9fdf8233 100644 --- a/go.mod +++ b/go.mod @@ -27,6 +27,7 @@ require ( github.com/dustin/go-humanize v1.0.1 // indirect github.com/go-sql-driver/mysql v1.8.0 // indirect github.com/google/uuid v1.5.0 // indirect + github.com/gorilla/schema v1.4.0 // indirect github.com/ipfs/go-log/v2 v2.1.3 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/lib/pq v1.10.9 // indirect diff --git a/go.sum b/go.sum index a02d9225..3ed4d9ee 100644 --- a/go.sum +++ b/go.sum @@ -34,6 +34,8 @@ github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S3 github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/uuid v1.5.0 h1:1p67kYwdtXjb0gL0BPiP1Av9wiZPo5A8z2cWkTZ+eyU= github.com/google/uuid v1.5.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/schema v1.4.0 h1:l2N+lRTJtev9SUhBtj6NmSxd/6+8LhvN0kV+H2Y8R9k= +github.com/gorilla/schema v1.4.0/go.mod h1:Dg5SSm5PV60mhF2NFaTV1xuYYj8tV8NOPRo4FggUMnM= github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY= github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= github.com/gregjones/httpcache v0.0.0-20190611155906-901d90724c79 h1:+ngKgrYPPJrOjhax5N+uePQ0Fh1Z7PheYoUI/0nzkPA= From f2dbcbf2015888addc162407db78cc565fa69555 Mon Sep 17 00:00:00 2001 From: zyxkad Date: Fri, 28 Jun 2024 11:34:04 -0600 Subject: [PATCH 08/36] add permHandler --- api/user.go | 2 +- api/v0/api.go | 123 ++++-------------- api/v0/{api_token.go => auth.go} | 103 ++++++++++++++- api/v0/{configure_cluster.go => configure.go} | 10 +- 4 files changed, 135 insertions(+), 103 deletions(-) rename api/v0/{api_token.go => auth.go} (71%) rename api/v0/{configure_cluster.go => configure.go} (80%) diff --git a/api/user.go b/api/user.go index e145c9a3..f6de76c6 100644 --- a/api/user.go +++ b/api/user.go @@ -39,7 +39,7 @@ const ( SubscribePerm // LogPerm allows the user to view non-debug logs & download access logs LogPerm - // DebugPerm allows the user to access debug settings and download debug logs + // DebugPerm allows the user to access debug settings & download debug logs DebugPerm // FullConfigPerm allows the user to access all config values FullConfigPerm diff --git a/api/v0/api.go b/api/v0/api.go index 1cc4ee42..b721a2f8 100644 --- a/api/v0/api.go +++ b/api/v0/api.go @@ -60,22 +60,27 @@ func apiGetClientId(req *http.Request) (id string) { } type Handler struct { - handler *utils.HttpMiddleWareHandler - router *http.ServeMux - userManager api.UserManager - tokenManager api.TokenManager - subManager api.SubscriptionManager + handler *utils.HttpMiddleWareHandler + router *http.ServeMux + users api.UserManager + tokens api.TokenManager + subscriptions api.SubscriptionManager } var _ http.Handler = (*Handler)(nil) -func NewHandler(verifier TokenVerifier, subManager api.SubscriptionManager) *Handler { +func NewHandler( + users api.UserManager, + tokenManager api.TokenManager, + subManager api.SubscriptionManager, +) *Handler { mux := http.NewServeMux() h := &Handler{ - router: mux, - handler: utils.NewHttpMiddleWareHandler(mux), - verifier: verifier, - subManager: subManager, + router: mux, + handler: utils.NewHttpMiddleWareHandler(mux), + users: users, + tokens: tokenManager, + subscriptions: subManager, } h.buildRoute() h.handler.Use(cliIdMiddleWare) @@ -107,115 +112,37 @@ func (h *Handler) buildRoute() { mux.Handle("/logout", authHandleFunc(h.routeLogout)) mux.HandleFunc("/log.io", h.routeLogIO) - mux.Handle("/pprof", authHandleFunc(h.routePprof)) + mux.Handle("/pprof", permHandleFunc(api.DebugPerm, h.routePprof)) mux.HandleFunc("/subscribeKey", h.routeSubscribeKey) - mux.Handle("/subscribe", authHandle(&utils.HttpMethodHandler{ + mux.Handle("/subscribe", permHandleFunc(api.SubscribePerm, &utils.HttpMethodHandler{ Get: h.routeSubscribeGET, Post: h.routeSubscribePOST, Delete: h.routeSubscribeDELETE, })) - mux.Handle("/subscribe_email", authHandle(&utils.HttpMethodHandler{ + mux.Handle("/subscribe_email", permHandleFunc(api.SubscribePerm, &utils.HttpMethodHandler{ Get: h.routeSubscribeEmailGET, Post: h.routeSubscribeEmailPOST, Patch: h.routeSubscribeEmailPATCH, Delete: h.routeSubscribeEmailDELETE, })) - mux.Handle("/webhook", authHandle(&utils.HttpMethodHandler{ + mux.Handle("/webhook", permHandleFunc(api.SubscribePerm, &utils.HttpMethodHandler{ Get: h.routeWebhookGET, Post: h.routeWebhookPOST, Patch: h.routeWebhookPATCH, Delete: h.routeWebhookDELETE, })) - mux.Handle("/log_files", authHandleFunc(h.routeLogFiles)) - mux.Handle("/log_file/", authHandle(http.StripPrefix("/log_file/", (http.HandlerFunc)(h.routeLogFile)))) + mux.Handle("/log_files", permHandleFunc(api.LogPerm, h.routeLogFiles)) + mux.Handle("/log_file/", permHandle(api.LogPerm, http.StripPrefix("/log_file/", (http.HandlerFunc)(h.routeLogFile)))) - mux.Handle("/configure/cluster", authHandleFunc(h.routeConfigureCluster)) + mux.Handle("/configure/cluster", permHandleFunc(api.ClusterPerm, h.routeConfigureCluster)) } func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { h.handler.ServeHTTP(rw, req) } -func cliIdMiddleWare(rw http.ResponseWriter, req *http.Request, next http.Handler) { - var id string - if cid, _ := req.Cookie(clientIdCookieName); cid != nil { - id = cid.Value - } else { - var err error - id, err = utils.GenRandB64(16) - if err != nil { - http.Error(rw, "cannot generate random number", http.StatusInternalServerError) - return - } - http.SetCookie(rw, &http.Cookie{ - Name: clientIdCookieName, - Value: id, - Expires: time.Now().Add(time.Hour * 24 * 365 * 16), - Secure: true, - HttpOnly: true, - }) - } - req = req.WithContext(context.WithValue(req.Context(), clientIdKey, utils.AsSha256(id))) - next.ServeHTTP(rw, req) -} - -func (h *Handler) authMiddleWare(rw http.ResponseWriter, req *http.Request, next http.Handler) { - cli := apiGetClientId(req) - - ctx := req.Context() - - var ( - id string - uid string - err error - ) - if req.Method == http.MethodGet { - if tk := req.URL.Query().Get("_t"); tk != "" { - path := GetRequestRealPath(req) - if id, uid, err = h.verifier.verifyAPIToken(cli, tk, path, req.URL.Query()); err == nil { - ctx = context.WithValue(ctx, tokenTypeKey, tokenTypeAPI) - } - } - } - if id == "" { - auth := req.Header.Get("Authorization") - tk, ok := strings.CutPrefix(auth, "Bearer ") - if !ok { - if err == nil { - err = ErrUnsupportAuthType - } - } else if id, uid, err = h.verifier.VerifyAuthToken(cli, tk); err != nil { - id = "" - } else { - ctx = context.WithValue(ctx, tokenTypeKey, tokenTypeAuth) - } - } - if id != "" { - ctx = context.WithValue(ctx, loggedUserKey, uid) - ctx = context.WithValue(ctx, tokenIdKey, id) - req = req.WithContext(ctx) - } - next.ServeHTTP(rw, req) -} - -func authHandle(next http.Handler) http.Handler { - return (http.HandlerFunc)(func(rw http.ResponseWriter, req *http.Request) { - if req.Context().Value(tokenTypeKey) == nil { - writeJson(rw, http.StatusUnauthorized, Map{ - "error": "403 Unauthorized", - }) - return - } - next.ServeHTTP(rw, req) - }) -} - -func authHandleFunc(next http.HandlerFunc) http.Handler { - return authHandle(next) -} - -func (cr *Cluster) routePing(rw http.ResponseWriter, req *http.Request) { +func (h *Handler) routePing(rw http.ResponseWriter, req *http.Request) { if req.Method != http.MethodGet { errorMethodNotAllowed(rw, req, http.MethodGet) return @@ -331,13 +258,13 @@ func (h *Handler) routeLogin(rw http.ResponseWriter, req *http.Request) { return } - if err := h.verifier.VerifyChallengeToken(cli, "login", data.Challenge); err != nil { + if err := h.tokens.VerifyChallengeToken(cli, "login", data.Challenge); err != nil { writeJson(rw, http.StatusUnauthorized, Map{ "error": "Invalid challenge", }) return } - if err := h.verifier.VerifyUserPassword(data.User, func(password string) bool { + if err := h.tokens.VerifyUserPassword(data.User, func(password string) bool { expectSignature := utils.HMACSha256HexBytes(password, data.Challenge) return subtle.ConstantTimeCompare(expectSignature, ([]byte)(data.Signature)) == 0 }); err != nil { diff --git a/api/v0/api_token.go b/api/v0/auth.go similarity index 71% rename from api/v0/api_token.go rename to api/v0/auth.go index 4d298f09..217d6ab3 100644 --- a/api/v0/api_token.go +++ b/api/v0/auth.go @@ -52,11 +52,108 @@ func getRequestTokenType(req *http.Request) string { return "" } -func getLoggedUser(req *http.Request) string { - if user, ok := req.Context().Value(loggedUserKey).(string); ok { +func getLoggedUser(req *http.Request) *api.User { + if user, ok := req.Context().Value(loggedUserKey).(*api.User); ok { return user } - return "" + return nil +} + +func cliIdMiddleWare(rw http.ResponseWriter, req *http.Request, next http.Handler) { + var id string + if cid, _ := req.Cookie(clientIdCookieName); cid != nil { + id = cid.Value + } else { + var err error + id, err = utils.GenRandB64(16) + if err != nil { + http.Error(rw, "cannot generate random number", http.StatusInternalServerError) + return + } + http.SetCookie(rw, &http.Cookie{ + Name: clientIdCookieName, + Value: id, + Expires: time.Now().Add(time.Hour * 24 * 365 * 16), + Secure: true, + HttpOnly: true, + }) + } + req = req.WithContext(context.WithValue(req.Context(), clientIdKey, utils.AsSha256(id))) + next.ServeHTTP(rw, req) +} + +func (h *Handler) authMiddleWare(rw http.ResponseWriter, req *http.Request, next http.Handler) { + cli := apiGetClientId(req) + + ctx := req.Context() + + var ( + typ string + id string + uid string + err error + ) + if req.Method == http.MethodGet { + if tk := req.URL.Query().Get("_t"); tk != "" { + path := GetRequestRealPath(req) + if id, uid, err = h.tokens.VerifyAPIToken(cli, tk, path, req.URL.Query()); err == nil { + typ = tokenTypeAPI + } + } + } + if id == "" { + auth := req.Header.Get("Authorization") + tk, ok := strings.CutPrefix(auth, "Bearer ") + if !ok { + if err == nil { + err = ErrUnsupportAuthType + } + } else if id, uid, err = h.tokens.VerifyAuthToken(cli, tk); err == nil { + typ = tokenTypeAuth + } + } + if typ != "" { + user, err := h.users.GetUser(uid) + if err == nil { + ctx = context.WithValue(ctx, tokenTypeKey, typ) + ctx = context.WithValue(ctx, loggedUserKey, user) + ctx = context.WithValue(ctx, tokenIdKey, id) + req = req.WithContext(ctx) + } + } + next.ServeHTTP(rw, req) +} + +func authHandle(next http.Handler) http.Handler { + return permHandle(api.BasicPerm, next) +} + +func authHandleFunc(next http.HandlerFunc) http.Handler { + return authHandle(next) +} + +func permHandle(perm api.PermissionFlag, next http.Handler) http.Handler { + perm |= api.BasicPerm + return (http.HandlerFunc)(func(rw http.ResponseWriter, req *http.Request) { + user := getLoggedUser(req) + if user == nil { + writeJson(rw, http.StatusUnauthorized, Map{ + "error": "403 Unauthorized", + }) + return + } + if user.Permissions & perm != perm { + writeJson(rw, http.StatusForbidden, Map{ + "error": "Permission denied", + }) + return + } + next.ServeHTTP(rw, req) + }) +} + +func permHandleFunc(perm api.PermissionFlag, next http.HandlerFunc) http.Handler { + return permHandle(perm, next) } var ( diff --git a/api/v0/configure_cluster.go b/api/v0/configure.go similarity index 80% rename from api/v0/configure_cluster.go rename to api/v0/configure.go index 7bc143be..e1e16877 100644 --- a/api/v0/configure_cluster.go +++ b/api/v0/configure.go @@ -19,6 +19,14 @@ package v0 -func (h *Handler) apiConfigureCluster() { +import ( + "net/http" +) + +func (h *Handler) apiConfigureCluster(rw http.ResponseWriter, req *http.Request) { + // +} + +func (h *Handler) apiConfigureStorage(rw http.ResponseWriter, req *http.Request) { // } From b7b6cc6aedff3b682e712e62b83d5f22d6a8e200 Mon Sep 17 00:00:00 2001 From: zyxkad Date: Fri, 28 Jun 2024 21:45:11 -0600 Subject: [PATCH 09/36] complete api --- api/config.go | 35 ++ api/request.go | 45 +++ api/stats.go | 89 +++++ api/subscription.go | 6 + api/token.go | 3 + api/v0/api.go | 742 +---------------------------------------- api/v0/auth.go | 508 +++++++++++++++++----------- api/v0/configure.go | 153 ++++++++- api/v0/debug.go | 393 ++++++++++++++++++++++ api/v0/stat.go | 63 ++++ api/v0/subscription.go | 196 +++++++++-- config.go | 281 ++++++++-------- handler.go | 16 - 13 files changed, 1417 insertions(+), 1113 deletions(-) create mode 100644 api/config.go create mode 100644 api/request.go create mode 100644 api/stats.go create mode 100644 api/v0/debug.go create mode 100644 api/v0/stat.go diff --git a/api/config.go b/api/config.go new file mode 100644 index 00000000..e091b6bb --- /dev/null +++ b/api/config.go @@ -0,0 +1,35 @@ +/** + * OpenBmclAPI (Golang Edition) + * Copyright (C) 2024 Kevin Z + * All rights reserved + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published + * by the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package api + +import ( + "encoding/json" +) + +type ConfigHandler interface { + json.Marshaler + json.Unmarshaler + UnmarshalYAML(data []byte) error + MarshalJSONPath(path string) ([]byte, error) + UnmarshalJSONPath(path string, data []byte) error + + Fingerprint() string + DoLockedAction(fingerprint string, callback func(ConfigHandler) error) error +} diff --git a/api/request.go b/api/request.go new file mode 100644 index 00000000..74588c34 --- /dev/null +++ b/api/request.go @@ -0,0 +1,45 @@ +/** + * OpenBmclAPI (Golang Edition) + * Copyright (C) 2024 Kevin Z + * All rights reserved + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published + * by the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package api + +import ( + "net/http" +) + +const ( + RealAddrCtxKey = "handle.real.addr" + RealPathCtxKey = "handle.real.path" + AccessLogExtraCtxKey = "handle.access.extra" +) + +func GetRequestRealAddr(req *http.Request) string { + addr, _ := req.Context().Value(RealAddrCtxKey).(string) + return addr +} + +func GetRequestRealPath(req *http.Request) string { + return req.Context().Value(RealPathCtxKey).(string) +} + +func SetAccessInfo(req *http.Request, key string, value any) { + if info, ok := req.Context().Value(AccessLogExtraCtxKey).(map[string]any); ok { + info[key] = value + } +} diff --git a/api/stats.go b/api/stats.go new file mode 100644 index 00000000..6ddd0f7d --- /dev/null +++ b/api/stats.go @@ -0,0 +1,89 @@ +/** + * OpenBmclAPI (Golang Edition) + * Copyright (C) 2024 Kevin Z + * All rights reserved + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published + * by the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package api + +import ( + "time" +) + +type StatsManager interface { + GetStatus() StatusData + // if name is empty then gets the overall access data + GetAccessStat(name string) *AccessStatData +} + +type StatusData struct { + StartAt time.Time `json:"startAt"` + Clusters []string `json:"clusters"` + Storages []string `json:"storages"` +} + +type statInstData struct { + Hits int32 `json:"hits"` + Bytes int64 `json:"bytes"` +} + +func (d *statInstData) update(o *statInstData) { + d.Hits += o.Hits + d.Bytes += o.Bytes +} + +// statTime always save a UTC time +type statTime struct { + Hour int `json:"hour"` + Day int `json:"day"` + Month int `json:"month"` + Year int `json:"year"` +} + +func makeStatTime(t time.Time) (st statTime) { + t = t.UTC() + st.Hour = t.Hour() + y, m, d := t.Date() + st.Day = d - 1 + st.Month = (int)(m) - 1 + st.Year = y + return +} + +func (t statTime) IsLastDay() bool { + return time.Date(t.Year, (time.Month)(t.Month+1), t.Day+1+1, 0, 0, 0, 0, time.UTC).Day() == 1 +} + +type ( + statDataHours = [24]statInstData + statDataDays = [31]statInstData + statDataMonths = [12]statInstData +) + +type accessStatHistoryData struct { + Hours statDataHours `json:"hours"` + Days statDataDays `json:"days"` + Months statDataMonths `json:"months"` +} + +type AccessStatData struct { + Date statTime `json:"date"` + accessStatHistoryData + Prev accessStatHistoryData `json:"prev"` + Years map[string]statInstData `json:"years"` + + Accesses map[string]int `json:"accesses"` +} diff --git a/api/subscription.go b/api/subscription.go index da5034c4..7b88ad98 100644 --- a/api/subscription.go +++ b/api/subscription.go @@ -22,6 +22,8 @@ package api import ( "database/sql" "database/sql/driver" + "encoding/json" + "errors" "fmt" "time" @@ -30,6 +32,10 @@ import ( "github.com/LiterMC/go-openbmclapi/utils" ) +var ( + ErrNotFound = errors.New("Item not found") +) + type SubscriptionManager interface { GetWebPushKey() string diff --git a/api/token.go b/api/token.go index 3f2081f3..8dfcbdc9 100644 --- a/api/token.go +++ b/api/token.go @@ -24,12 +24,15 @@ import ( ) type TokenVerifier interface { + VerifyChallengeToken(clientId string, token string, action string) (err error) VerifyAuthToken(clientId string, token string) (tokenId string, userId string, err error) VerifyAPIToken(clientId string, token string, path string, query url.Values) (userId string, err error) } type TokenManager interface { TokenVerifier + GenerateChallengeToken(clientId string, action string) (token string, err error) GenerateAuthToken(clientId string, userId string) (token string, err error) GenerateAPIToken(clientId string, userId string, path string, query map[string]string) (token string, err error) + InvalidToken(tokenId string) error } diff --git a/api/v0/api.go b/api/v0/api.go index b721a2f8..65b288c1 100644 --- a/api/v0/api.go +++ b/api/v0/api.go @@ -20,56 +20,32 @@ package v0 import ( - "compress/gzip" - "context" - "crypto/subtle" "encoding/json" "errors" - "fmt" - "io" "mime" "net/http" - "os" - "path/filepath" "strconv" - "strings" - "sync/atomic" - "time" - "github.com/google/uuid" "github.com/gorilla/schema" - "runtime/pprof" "github.com/LiterMC/go-openbmclapi/api" - "github.com/LiterMC/go-openbmclapi/database" - "github.com/LiterMC/go-openbmclapi/internal/build" - "github.com/LiterMC/go-openbmclapi/limited" - "github.com/LiterMC/go-openbmclapi/log" - "github.com/LiterMC/go-openbmclapi/notify" "github.com/LiterMC/go-openbmclapi/utils" ) -const ( - clientIdCookieName = "_id" - - clientIdKey = "go-openbmclapi.cluster.client.id" -) - -func apiGetClientId(req *http.Request) (id string) { - return req.Context().Value(clientIdKey).(string) -} - type Handler struct { handler *utils.HttpMiddleWareHandler router *http.ServeMux + config api.ConfigHandler users api.UserManager tokens api.TokenManager subscriptions api.SubscriptionManager + stats api.StatsManager } var _ http.Handler = (*Handler)(nil) func NewHandler( + config api.ConfigHandler, users api.UserManager, tokenManager api.TokenManager, subManager api.SubscriptionManager, @@ -78,13 +54,13 @@ func NewHandler( h := &Handler{ router: mux, handler: utils.NewHttpMiddleWareHandler(mux), + config: config, users: users, tokens: tokenManager, subscriptions: subManager, } h.buildRoute() - h.handler.Use(cliIdMiddleWare) - h.handler.Use(h.authMiddleWare) + h.handler.UseFunc(cliIdMiddleWare, h.authMiddleWare) return h } @@ -102,715 +78,16 @@ func (h *Handler) buildRoute() { }) }) - mux.HandleFunc("/ping", h.routePing) - mux.HandleFunc("/status", h.routeStatus) - mux.Handle("/stat/", http.StripPrefix("/stat/", (http.HandlerFunc)(h.routeStat))) - - mux.HandleFunc("/challenge", h.routeChallenge) - mux.HandleFunc("/login", h.routeLogin) - mux.Handle("/requestToken", authHandleFunc(h.routeRequestToken)) - mux.Handle("/logout", authHandleFunc(h.routeLogout)) - - mux.HandleFunc("/log.io", h.routeLogIO) - mux.Handle("/pprof", permHandleFunc(api.DebugPerm, h.routePprof)) - mux.HandleFunc("/subscribeKey", h.routeSubscribeKey) - mux.Handle("/subscribe", permHandleFunc(api.SubscribePerm, &utils.HttpMethodHandler{ - Get: h.routeSubscribeGET, - Post: h.routeSubscribePOST, - Delete: h.routeSubscribeDELETE, - })) - mux.Handle("/subscribe_email", permHandleFunc(api.SubscribePerm, &utils.HttpMethodHandler{ - Get: h.routeSubscribeEmailGET, - Post: h.routeSubscribeEmailPOST, - Patch: h.routeSubscribeEmailPATCH, - Delete: h.routeSubscribeEmailDELETE, - })) - mux.Handle("/webhook", permHandleFunc(api.SubscribePerm, &utils.HttpMethodHandler{ - Get: h.routeWebhookGET, - Post: h.routeWebhookPOST, - Patch: h.routeWebhookPATCH, - Delete: h.routeWebhookDELETE, - })) - - mux.Handle("/log_files", permHandleFunc(api.LogPerm, h.routeLogFiles)) - mux.Handle("/log_file/", permHandle(api.LogPerm, http.StripPrefix("/log_file/", (http.HandlerFunc)(h.routeLogFile)))) - - mux.Handle("/configure/cluster", permHandleFunc(api.ClusterPerm, h.routeConfigureCluster)) + h.buildStatRoute(mux) + h.buildAuthRoute(mux) + h.buildSubscriptionRoute(mux) + h.buildConfigureRoute(mux) } func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { h.handler.ServeHTTP(rw, req) } -func (h *Handler) routePing(rw http.ResponseWriter, req *http.Request) { - if req.Method != http.MethodGet { - errorMethodNotAllowed(rw, req, http.MethodGet) - return - } - limited.SetSkipRateLimit(req) - authed := getRequestTokenType(req) == tokenTypeAuth - writeJson(rw, http.StatusOK, Map{ - "version": build.BuildVersion, - "time": time.Now().UnixMilli(), - "authed": authed, - }) -} - -func (cr *Cluster) routeStatus(rw http.ResponseWriter, req *http.Request) { - if req.Method != http.MethodGet { - errorMethodNotAllowed(rw, req, http.MethodGet) - return - } - limited.SetSkipRateLimit(req) - type syncData struct { - Prog int64 `json:"prog"` - Total int64 `json:"total"` - } - type statusData struct { - StartAt time.Time `json:"startAt"` - Stats *notify.Stats `json:"stats"` - Enabled bool `json:"enabled"` - IsSync bool `json:"isSync"` - Sync *syncData `json:"sync,omitempty"` - Storages []string `json:"storages"` - } - storages := make([]string, len(cr.storageOpts)) - for i, opt := range cr.storageOpts { - storages[i] = opt.Id - } - status := statusData{ - StartAt: startTime, - Stats: &cr.stats, - Enabled: cr.enabled.Load(), - IsSync: cr.issync.Load(), - Storages: storages, - } - if status.IsSync { - status.Sync = &syncData{ - Prog: cr.syncProg.Load(), - Total: cr.syncTotal.Load(), - } - } - writeJson(rw, http.StatusOK, &status) -} - -func (cr *Cluster) routeStat(rw http.ResponseWriter, req *http.Request) { - if req.Method != http.MethodGet { - errorMethodNotAllowed(rw, req, http.MethodGet) - return - } - limited.SetSkipRateLimit(req) - name := req.URL.Path - if name == "" { - rw.Header().Set("Cache-Control", "public, max-age=60") - writeJson(rw, http.StatusOK, &cr.stats) - return - } - data, err := cr.stats.MarshalSubStat(name) - if err != nil { - http.Error(rw, "Error when encoding response: "+err.Error(), http.StatusInternalServerError) - return - } - rw.Header().Set("Cache-Control", "public, max-age=30") - writeJson(rw, http.StatusOK, (json.RawMessage)(data)) -} - -func (h *Handler) routeChallenge(rw http.ResponseWriter, req *http.Request) { - if req.Method != http.MethodGet { - errorMethodNotAllowed(rw, req, http.MethodGet) - return - } - cli := apiGetClientId(req) - query := req.URL.Query() - action := query.Get("action") - token, err := h.generateChallengeToken(cli, action) - if err != nil { - writeJson(rw, http.StatusInternalServerError, Map{ - "error": "Cannot generate token", - "message": err.Error(), - }) - return - } - writeJson(rw, http.StatusOK, Map{ - "token": token, - }) -} - -func (h *Handler) routeLogin(rw http.ResponseWriter, req *http.Request) { - if req.Method != http.MethodPost { - errorMethodNotAllowed(rw, req, http.MethodPost) - return - } - if !config.Dashboard.Enable { - writeJson(rw, http.StatusServiceUnavailable, Map{ - "error": "dashboard is disabled in the config", - }) - return - } - cli := apiGetClientId(req) - - var data struct { - User string `json:"username" schema:"username"` - Challenge string `json:"challenge" schema:"challenge"` - Signature string `json:"signature" schema:"signature"` - } - if !parseRequestBody(rw, req, &data) { - return - } - - if err := h.tokens.VerifyChallengeToken(cli, "login", data.Challenge); err != nil { - writeJson(rw, http.StatusUnauthorized, Map{ - "error": "Invalid challenge", - }) - return - } - if err := h.tokens.VerifyUserPassword(data.User, func(password string) bool { - expectSignature := utils.HMACSha256HexBytes(password, data.Challenge) - return subtle.ConstantTimeCompare(expectSignature, ([]byte)(data.Signature)) == 0 - }); err != nil { - writeJson(rw, http.StatusUnauthorized, Map{ - "error": "The username or password is incorrect", - }) - return - } - token, err := cr.generateAuthToken(cli, data.User) - if err != nil { - writeJson(rw, http.StatusInternalServerError, Map{ - "error": "Cannot generate token", - "message": err.Error(), - }) - return - } - writeJson(rw, http.StatusOK, Map{ - "token": token, - }) -} - -func (cr *Cluster) routeRequestToken(rw http.ResponseWriter, req *http.Request) { - if req.Method != http.MethodPost { - errorMethodNotAllowed(rw, req, http.MethodPost) - return - } - defer req.Body.Close() - if getRequestTokenType(req) != tokenTypeAuth { - writeJson(rw, http.StatusUnauthorized, Map{ - "error": "invalid authorization type", - }) - return - } - - var payload struct { - Path string `json:"path"` - Query map[string]string `json:"query,omitempty"` - } - if !parseRequestBody(rw, req, &payload) { - return - } - log.Debugf("payload: %#v", payload) - if payload.Path == "" || payload.Path[0] != '/' { - writeJson(rw, http.StatusBadRequest, Map{ - "error": "path is invalid", - "message": "'path' must be a non empty string which starts with '/'", - }) - return - } - cli := apiGetClientId(req) - user := getLoggedUser(req) - token, err := cr.generateAPIToken(cli, user, payload.Path, payload.Query) - if err != nil { - writeJson(rw, http.StatusInternalServerError, Map{ - "error": "cannot generate token", - "message": err.Error(), - }) - return - } - writeJson(rw, http.StatusOK, Map{ - "token": token, - }) -} - -func (cr *Cluster) routeLogout(rw http.ResponseWriter, req *http.Request) { - if req.Method != http.MethodPost { - errorMethodNotAllowed(rw, req, http.MethodPost) - return - } - limited.SetSkipRateLimit(req) - tid := req.Context().Value(tokenIdKey).(string) - cr.database.RemoveJTI(tid) - rw.WriteHeader(http.StatusNoContent) -} - -func (cr *Cluster) routeLogIO(rw http.ResponseWriter, req *http.Request) { - addr, _ := req.Context().Value(RealAddrCtxKey).(string) - - conn, err := cr.wsUpgrader.Upgrade(rw, req, nil) - if err != nil { - log.Debugf("[log.io]: Websocket upgrade error: %v", err) - http.Error(rw, err.Error(), http.StatusInternalServerError) - return - } - defer conn.Close() - - cli := apiGetClientId(req) - - ctx, cancel := context.WithCancel(req.Context()) - defer cancel() - - conn.SetReadLimit(1024 * 4) - pongTimeoutTimer := time.NewTimer(time.Second * 75) - go func() { - defer conn.Close() - defer cancel() - defer pongTimeoutTimer.Stop() - select { - case _, ok := <-pongTimeoutTimer.C: - if !ok { - return - } - log.Error("[log.io]: Did not receive packet from client longer than 75s") - return - case <-ctx.Done(): - return - } - }() - - var authData struct { - Token string `json:"token"` - } - deadline := time.Now().Add(time.Second * 10) - conn.SetReadDeadline(deadline) - err = conn.ReadJSON(&authData) - conn.SetReadDeadline(time.Time{}) - if err != nil { - if time.Now().After(deadline) { - conn.WriteJSON(Map{ - "type": "error", - "message": "auth timeout", - }) - } else { - conn.WriteJSON(Map{ - "type": "error", - "message": "unexpected auth data: " + err.Error(), - }) - } - return - } - if _, _, err = cr.verifyAuthToken(cli, authData.Token); err != nil { - conn.WriteJSON(Map{ - "type": "error", - "message": "auth failed", - }) - return - } - if err := conn.WriteJSON(Map{ - "type": "ready", - }); err != nil { - return - } - - var level atomic.Int32 - level.Store((int32)(log.LevelInfo)) - - type logObj struct { - Type string `json:"type"` - Time int64 `json:"time"` // UnixMilli - Level string `json:"lvl"` - Log string `json:"log"` - } - c := make(chan *logObj, 64) - unregister := log.RegisterLogMonitor(log.LevelDebug, func(ts int64, l log.Level, msg string) { - if (log.Level)(level.Load()) > l&log.LevelMask { - return - } - select { - case c <- &logObj{ - Type: "log", - Time: ts, - Level: l.String(), - Log: msg, - }: - default: - } - }) - defer unregister() - - go func() { - defer log.RecoverPanic(nil) - defer conn.Close() - defer cancel() - var data map[string]any - for { - clear(data) - if err := conn.ReadJSON(&data); err != nil { - log.Errorf("[log.io]: Cannot read from peer: %v", err) - return - } - typ, ok := data["type"].(string) - if !ok { - continue - } - switch typ { - case "pong": - log.Debugf("[log.io]: received PONG from %s: %v", addr, data["data"]) - pongTimeoutTimer.Reset(time.Second * 75) - case "set-level": - l, ok := data["level"].(string) - if ok { - switch l { - case "DBUG": - level.Store((int32)(log.LevelDebug)) - case "INFO": - level.Store((int32)(log.LevelInfo)) - case "WARN": - level.Store((int32)(log.LevelWarn)) - case "ERRO": - level.Store((int32)(log.LevelError)) - default: - continue - } - select { - case c <- &logObj{ - Type: "log", - Time: time.Now().UnixMilli(), - Level: log.LevelInfo.String(), - Log: "[dashboard]: Set log level to " + l + " for this log.io", - }: - default: - } - } - } - } - }() - - sendMsgCh := make(chan any, 64) - go func() { - for { - select { - case v := <-c: - select { - case sendMsgCh <- v: - case <-ctx.Done(): - return - } - case <-ctx.Done(): - return - } - } - }() - - pingTicker := time.NewTicker(time.Second * 45) - defer pingTicker.Stop() - forceSendTimer := time.NewTimer(time.Second) - if !forceSendTimer.Stop() { - <-forceSendTimer.C - } - - batchMsg := make([]any, 0, 64) - for { - select { - case v := <-sendMsgCh: - batchMsg = append(batchMsg, v) - forceSendTimer.Reset(time.Second) - WAIT_MORE: - for { - select { - case v := <-sendMsgCh: - batchMsg = append(batchMsg, v) - case <-time.After(time.Millisecond * 20): - if !forceSendTimer.Stop() { - <-forceSendTimer.C - } - break WAIT_MORE - case <-forceSendTimer.C: - break WAIT_MORE - case <-ctx.Done(): - forceSendTimer.Stop() - return - } - } - if len(batchMsg) == 1 { - if err := conn.WriteJSON(batchMsg[0]); err != nil { - return - } - } else { - if err := conn.WriteJSON(batchMsg); err != nil { - return - } - } - // release objects - for i, _ := range batchMsg { - batchMsg[i] = nil - } - batchMsg = batchMsg[:0] - case <-pingTicker.C: - if err := conn.WriteJSON(Map{ - "type": "ping", - "data": time.Now().UnixMilli(), - }); err != nil { - log.Errorf("[log.io]: Error when sending ping packet: %v", err) - return - } - case <-ctx.Done(): - return - } - } -} - -func (cr *Cluster) routePprof(rw http.ResponseWriter, req *http.Request) { - if req.Method != http.MethodGet { - errorMethodNotAllowed(rw, req, http.MethodGet) - return - } - query := req.URL.Query() - lookup := query.Get("lookup") - p := pprof.Lookup(lookup) - if p == nil { - http.Error(rw, fmt.Sprintf("pprof.Lookup(%q) returned nil", lookup), http.StatusBadRequest) - return - } - view := query.Get("view") - debug, err := strconv.Atoi(query.Get("debug")) - if err != nil { - debug = 1 - } - if debug == 1 { - rw.Header().Set("Content-Type", "text/plain; charset=utf-8") - } else { - rw.Header().Set("Content-Type", "application/octet-stream") - } - if view != "1" { - name := fmt.Sprintf(time.Now().Format("dump-%s-20060102-150405"), lookup) - if debug == 1 { - name += ".txt" - } else { - name += ".dump" - } - rw.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", name)) - } - rw.WriteHeader(http.StatusOK) - if debug == 1 { - fmt.Fprintf(rw, "version: %s (%s)\n", build.BuildVersion, build.ClusterVersion) - } - p.WriteTo(rw, debug) -} - -func (cr *Cluster) routeWebhookGET(rw http.ResponseWriter, req *http.Request) { - user := getLoggedUser(req) - if sid := req.URL.Query().Get("id"); sid != "" { - id, err := uuid.Parse(sid) - if err != nil { - writeJson(rw, http.StatusBadRequest, Map{ - "error": "uuid format error", - "message": err.Error(), - }) - return - } - record, err := cr.database.GetWebhook(user, id) - if err != nil { - if err == database.ErrNotFound { - writeJson(rw, http.StatusNotFound, Map{ - "error": "no webhook was found", - }) - return - } - writeJson(rw, http.StatusInternalServerError, Map{ - "error": "database error", - "message": err.Error(), - }) - return - } - writeJson(rw, http.StatusOK, record) - return - } - records := make([]database.WebhookRecord, 0, 4) - if err := cr.database.ForEachUsersWebhook(user, func(rec *database.WebhookRecord) error { - records = append(records, *rec) - return nil - }); err != nil { - writeJson(rw, http.StatusInternalServerError, Map{ - "error": "database error", - "message": err.Error(), - }) - return - } - writeJson(rw, http.StatusOK, records) -} - -func (cr *Cluster) routeWebhookPOST(rw http.ResponseWriter, req *http.Request) { - user := getLoggedUser(req) - var data database.WebhookRecord - if !parseRequestBody(rw, req, &data) { - return - } - - data.User = user - if err := cr.database.AddWebhook(data); err != nil { - writeJson(rw, http.StatusInternalServerError, Map{ - "error": "Database update failed", - "message": err.Error(), - }) - return - } - rw.WriteHeader(http.StatusCreated) -} - -func (cr *Cluster) routeWebhookPATCH(rw http.ResponseWriter, req *http.Request) { - user := getLoggedUser(req) - id := req.URL.Query().Get("id") - var data database.WebhookRecord - if !parseRequestBody(rw, req, &data) { - return - } - data.User = user - var err error - if data.Id, err = uuid.Parse(id); err != nil { - writeJson(rw, http.StatusBadRequest, Map{ - "error": "uuid format error", - "message": err.Error(), - }) - return - } - if err := cr.database.UpdateWebhook(data); err != nil { - if err == database.ErrNotFound { - writeJson(rw, http.StatusNotFound, Map{ - "error": "no webhook was found", - }) - return - } - writeJson(rw, http.StatusInternalServerError, Map{ - "error": "database error", - "message": err.Error(), - }) - return - } - rw.WriteHeader(http.StatusNoContent) -} - -func (cr *Cluster) routeWebhookDELETE(rw http.ResponseWriter, req *http.Request) { - user := getLoggedUser(req) - id, err := uuid.Parse(req.URL.Query().Get("id")) - if err != nil { - writeJson(rw, http.StatusBadRequest, Map{ - "error": "uuid format error", - "message": err.Error(), - }) - return - } - if err := cr.database.RemoveWebhook(user, id); err != nil { - if err == database.ErrNotFound { - writeJson(rw, http.StatusNotFound, Map{ - "error": "no webhook was found", - }) - return - } - writeJson(rw, http.StatusInternalServerError, Map{ - "error": "database error", - "message": err.Error(), - }) - return - } - rw.WriteHeader(http.StatusNoContent) -} - -func (cr *Cluster) routeLogFiles(rw http.ResponseWriter, req *http.Request) { - if req.Method != http.MethodGet { - errorMethodNotAllowed(rw, req, http.MethodGet) - return - } - files := log.ListLogs() - type FileInfo struct { - Name string `json:"name"` - Size int64 `json:"size"` - } - data := make([]FileInfo, 0, len(files)) - for _, file := range files { - if s, err := os.Stat(filepath.Join(log.BaseDir(), file)); err == nil { - data = append(data, FileInfo{ - Name: file, - Size: s.Size(), - }) - } - } - writeJson(rw, http.StatusOK, Map{ - "files": data, - }) -} - -func (cr *Cluster) routeLogFile(rw http.ResponseWriter, req *http.Request) { - if req.Method != http.MethodGet && req.Method != http.MethodHead { - errorMethodNotAllowed(rw, req, http.MethodGet+", "+http.MethodHead) - return - } - query := req.URL.Query() - fd, err := os.Open(filepath.Join(log.BaseDir(), req.URL.Path)) - if err != nil { - if errors.Is(err, os.ErrNotExist) { - writeJson(rw, http.StatusNotFound, Map{ - "error": "file not exists", - "message": "Cannot find log file", - "path": req.URL.Path, - }) - return - } - writeJson(rw, http.StatusInternalServerError, Map{ - "error": "cannot open file", - "message": err.Error(), - }) - return - } - defer fd.Close() - name := filepath.Base(req.URL.Path) - isGzip := filepath.Ext(name) == ".gz" - if query.Get("no_encrypt") == "1" { - var modTime time.Time - if stat, err := fd.Stat(); err == nil { - modTime = stat.ModTime() - } - rw.Header().Set("Cache-Control", "public, max-age=60, stale-while-revalidate=600") - if isGzip { - rw.Header().Set("Content-Type", "application/octet-stream") - } else { - rw.Header().Set("Content-Type", "text/plain; charset=utf-8") - } - rw.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", name)) - http.ServeContent(rw, req, name, modTime, fd) - } else { - if !isGzip { - name += ".gz" - } - rw.Header().Set("Content-Type", "application/octet-stream") - rw.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", name+".encrypted")) - cr.routeLogFileEncrypted(rw, req, fd, !isGzip) - } -} - -func (cr *Cluster) routeLogFileEncrypted(rw http.ResponseWriter, req *http.Request, r io.Reader, useGzip bool) { - rw.WriteHeader(http.StatusOK) - if req.Method == http.MethodHead { - return - } - if useGzip { - pr, pw := io.Pipe() - defer pr.Close() - go func(r io.Reader) { - gw := gzip.NewWriter(pw) - if _, err := io.Copy(gw, r); err != nil { - pw.CloseWithError(err) - return - } - if err := gw.Close(); err != nil { - pw.CloseWithError(err) - return - } - pw.Close() - }(r) - r = pr - } - if err := utils.EncryptStream(rw, r, utils.DeveloporPublicKey); err != nil { - log.Errorf("Cannot write encrypted log stream: %v", err) - } -} - type Map = map[string]any var errUnknownContent = errors.New("unknown content-type") @@ -877,5 +154,4 @@ func writeJson(rw http.ResponseWriter, code int, data any) (err error) { func errorMethodNotAllowed(rw http.ResponseWriter, req *http.Request, allow string) { rw.Header().Set("Allow", allow) rw.WriteHeader(http.StatusMethodNotAllowed) - return true } diff --git a/api/v0/auth.go b/api/v0/auth.go index 217d6ab3..0af13030 100644 --- a/api/v0/auth.go +++ b/api/v0/auth.go @@ -20,17 +20,29 @@ package v0 import ( + "context" + "crypto/subtle" "errors" - "fmt" "net/http" - "net/url" + "strings" "time" - "github.com/golang-jwt/jwt/v5" - + "github.com/LiterMC/go-openbmclapi/api" + "github.com/LiterMC/go-openbmclapi/limited" + "github.com/LiterMC/go-openbmclapi/log" "github.com/LiterMC/go-openbmclapi/utils" ) +const ( + clientIdCookieName = "_id" + + clientIdKey = "go-openbmclapi.cluster.client.id" +) + +func apiGetClientId(req *http.Request) (id string) { + return req.Context().Value(clientIdKey).(string) +} + const jwtIssuerPrefix = "GOBA.dash.api" const ( @@ -95,26 +107,26 @@ func (h *Handler) authMiddleWare(rw http.ResponseWriter, req *http.Request, next ) if req.Method == http.MethodGet { if tk := req.URL.Query().Get("_t"); tk != "" { - path := GetRequestRealPath(req) - if id, uid, err = h.tokens.VerifyAPIToken(cli, tk, path, req.URL.Query()); err == nil { + path := api.GetRequestRealPath(req) + if uid, err = h.tokens.VerifyAPIToken(cli, tk, path, req.URL.Query()); err == nil { typ = tokenTypeAPI } } } - if id == "" { + if typ == "" { auth := req.Header.Get("Authorization") tk, ok := strings.CutPrefix(auth, "Bearer ") if !ok { if err == nil { - err = ErrUnsupportAuthType + err = errors.New("Unsupported authorization type") } } else if id, uid, err = h.tokens.VerifyAuthToken(cli, tk); err == nil { typ = tokenTypeAuth } } if typ != "" { - user, err := h.users.GetUser(uid) - if err == nil { + user := h.users.GetUser(uid) + if user != nil { ctx = context.WithValue(ctx, tokenTypeKey, typ) ctx = context.WithValue(ctx, loggedUserKey, user) ctx = context.WithValue(ctx, tokenIdKey, id) @@ -142,7 +154,7 @@ func permHandle(perm api.PermissionFlag, next http.Handler) http.Handler { }) return } - if user.Permissions & perm != perm { + if user.Permissions&perm != perm { writeJson(rw, http.StatusForbidden, Map{ "error": "Permission denied", }) @@ -156,214 +168,328 @@ func permHandleFunc(perm api.PermissionFlag, next http.HandlerFunc) http.Handler return permHandle(perm, next) } -var ( - ErrUnsupportAuthType = errors.New("unsupported authorization type") - ErrScopeNotMatch = errors.New("scope not match") - ErrJTINotExists = errors.New("jti not exists") - - ErrStrictPathNotMatch = errors.New("strict path not match") - ErrStrictQueryNotMatch = errors.New("strict query value not match") -) - -func (cr *Cluster) getJWTKey(t *jwt.Token) (any, error) { - if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { - return nil, fmt.Errorf("Unexpected signing method: %v", t.Header["alg"]) - } - return cr.apiHmacKey, nil -} - -const ( - challengeTokenScope = "GOBA-challenge" - authTokenScope = "GOBA-auth" - apiTokenScope = "GOBA-API" -) - -type challengeTokenClaims struct { - jwt.RegisteredClaims - - Scope string `json:"scope"` - Action string `json:"act"` -} - -func (cr *Cluster) generateChallengeToken(cliId string, action string) (string, error) { - now := time.Now() - exp := now.Add(time.Minute * 1) - token := jwt.NewWithClaims(jwt.SigningMethodHS256, &challengeTokenClaims{ - RegisteredClaims: jwt.RegisteredClaims{ - Subject: cliId, - Issuer: cr.jwtIssuer, - IssuedAt: jwt.NewNumericDate(now), - ExpiresAt: jwt.NewNumericDate(exp), - }, - Scope: challengeTokenScope, - Action: action, - }) - tokenStr, err := token.SignedString(cr.apiHmacKey) - if err != nil { - return "", err - } - return tokenStr, nil +func (h *Handler) buildAuthRoute(mux *http.ServeMux) { + mux.HandleFunc("/challenge", h.routeChallenge) + mux.HandleFunc("POST /login", h.routeLogin) + mux.Handle("POST /requestToken", authHandleFunc(h.routeRequestToken)) + mux.Handle("POST /logout", authHandleFunc(h.routeLogout)) } -func (cr *Cluster) verifyChallengeToken(cliId string, action string, token string) (err error) { - var claims challengeTokenClaims - if _, err = jwt.ParseWithClaims( - token, - &claims, - cr.getJWTKey, - jwt.WithSubject(cliId), - jwt.WithIssuedAt(), - jwt.WithIssuer(cr.jwtIssuer), - ); err != nil { +func (h *Handler) routeChallenge(rw http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodGet { + errorMethodNotAllowed(rw, req, http.MethodGet) return } - if claims.Scope != challengeTokenScope { - return ErrScopeNotMatch - } - if claims.Action != action { - return ErrJTINotExists - } - return -} - -type authTokenClaims struct { - jwt.RegisteredClaims - - Scope string `json:"scope"` - User string `json:"usr"` -} - -func (cr *Cluster) generateAuthToken(cliId string, userId string) (string, error) { - jti, err := utils.GenRandB64(16) + cli := apiGetClientId(req) + query := req.URL.Query() + action := query.Get("action") + token, err := h.tokens.GenerateChallengeToken(cli, action) if err != nil { - return "", err + writeJson(rw, http.StatusInternalServerError, Map{ + "error": "Cannot generate token", + "message": err.Error(), + }) + return } - now := time.Now() - exp := now.Add(time.Hour * 24) - token := jwt.NewWithClaims(jwt.SigningMethodHS256, &authTokenClaims{ - RegisteredClaims: jwt.RegisteredClaims{ - ID: jti, - Subject: cliId, - Issuer: cr.jwtIssuer, - IssuedAt: jwt.NewNumericDate(now), - ExpiresAt: jwt.NewNumericDate(exp), - }, - Scope: authTokenScope, - User: userId, + writeJson(rw, http.StatusOK, Map{ + "token": token, }) - tokenStr, err := token.SignedString(cr.apiHmacKey) - if err != nil { - return "", err - } - if err = cr.database.AddJTI(jti, exp); err != nil { - return "", err - } - return tokenStr, nil } -func (cr *Cluster) verifyAuthToken(cliId string, token string) (id string, user string, err error) { - var claims authTokenClaims - if _, err = jwt.ParseWithClaims( - token, - &claims, - cr.getJWTKey, - jwt.WithSubject(cliId), - jwt.WithIssuedAt(), - jwt.WithIssuer(cr.jwtIssuer), - ); err != nil { - return +func (h *Handler) routeLogin(rw http.ResponseWriter, req *http.Request) { + cli := apiGetClientId(req) + + var data struct { + User string `json:"username" schema:"username"` + Challenge string `json:"challenge" schema:"challenge"` + Signature string `json:"signature" schema:"signature"` } - if claims.Scope != authTokenScope { - err = ErrScopeNotMatch + if !parseRequestBody(rw, req, &data) { return } - if user = claims.User; user == "" { - // reject old token - err = ErrJTINotExists + + if err := h.tokens.VerifyChallengeToken(cli, data.Challenge, "login"); err != nil { + writeJson(rw, http.StatusUnauthorized, Map{ + "error": "Invalid challenge", + }) return } - id = claims.ID - if ok, _ := cr.database.ValidJTI(id); !ok { - err = ErrJTINotExists + if err := h.users.VerifyUserPassword(data.User, func(password string) bool { + expectSignature := utils.HMACSha256HexBytes(password, data.Challenge) + return subtle.ConstantTimeCompare(expectSignature, ([]byte)(data.Signature)) == 0 + }); err != nil { + writeJson(rw, http.StatusUnauthorized, Map{ + "error": "The username or password is incorrect", + }) return } - return -} - -type apiTokenClaims struct { - jwt.RegisteredClaims - - Scope string `json:"scope"` - User string `json:"usr"` - StrictPath string `json:"str-p"` - StrictQuery map[string]string `json:"str-q,omitempty"` -} - -func (cr *Cluster) generateAPIToken(cliId string, userId string, path string, query map[string]string) (string, error) { - jti, err := utils.GenRandB64(8) + token, err := h.tokens.GenerateAuthToken(cli, data.User) if err != nil { - return "", err + writeJson(rw, http.StatusInternalServerError, Map{ + "error": "Cannot generate token", + "message": err.Error(), + }) + return } - now := time.Now() - exp := now.Add(time.Minute * 10) - token := jwt.NewWithClaims(jwt.SigningMethodHS256, &apiTokenClaims{ - RegisteredClaims: jwt.RegisteredClaims{ - ID: jti, - Subject: cliId, - Issuer: cr.jwtIssuer, - IssuedAt: jwt.NewNumericDate(now), - ExpiresAt: jwt.NewNumericDate(exp), - }, - Scope: apiTokenScope, - User: userId, - StrictPath: path, - StrictQuery: query, + writeJson(rw, http.StatusOK, Map{ + "token": token, }) - tokenStr, err := token.SignedString(cr.apiHmacKey) - if err != nil { - return "", err - } - if err = cr.database.AddJTI(jti, exp); err != nil { - return "", err - } - return tokenStr, nil } -func (h *Handler) verifyAPIToken(cliId string, token string, path string, query url.Values) (id string, user string, err error) { - var claims apiTokenClaims - _, err = jwt.ParseWithClaims( - token, - &claims, - cr.getJWTKey, - jwt.WithSubject(cliId), - jwt.WithIssuedAt(), - jwt.WithIssuer(cr.jwtIssuer), - ) - if err != nil { +func (h *Handler) routeRequestToken(rw http.ResponseWriter, req *http.Request) { + defer req.Body.Close() + if getRequestTokenType(req) != tokenTypeAuth { + writeJson(rw, http.StatusUnauthorized, Map{ + "error": "invalid authorization type", + }) return } - if claims.Scope != apiTokenScope { - err = ErrScopeNotMatch - return + + var payload struct { + Path string `json:"path"` + Query map[string]string `json:"query,omitempty"` } - if user = claims.User; user == "" { - err = ErrJTINotExists + if !parseRequestBody(rw, req, &payload) { return } - id = claims.ID - if ok, _ := cr.database.ValidJTI(id); !ok { - err = ErrJTINotExists + log.Debugf("payload: %#v", payload) + if payload.Path == "" || payload.Path[0] != '/' { + writeJson(rw, http.StatusBadRequest, Map{ + "error": "path is invalid", + "message": "'path' must be a non empty string which starts with '/'", + }) return } - if claims.StrictPath != path { - err = ErrStrictPathNotMatch + cli := apiGetClientId(req) + user := getLoggedUser(req) + token, err := h.tokens.GenerateAPIToken(cli, user.Username, payload.Path, payload.Query) + if err != nil { + writeJson(rw, http.StatusInternalServerError, Map{ + "error": "cannot generate token", + "message": err.Error(), + }) return } - for k, v := range claims.StrictQuery { - if query.Get(k) != v { - err = ErrStrictQueryNotMatch - return - } - } - return + writeJson(rw, http.StatusOK, Map{ + "token": token, + }) +} + +func (h *Handler) routeLogout(rw http.ResponseWriter, req *http.Request) { + limited.SetSkipRateLimit(req) + tid := req.Context().Value(tokenIdKey).(string) + h.tokens.InvalidToken(tid) + rw.WriteHeader(http.StatusNoContent) } + +// var ( +// ErrUnsupportAuthType = errors.New("unsupported authorization type") +// ErrScopeNotMatch = errors.New("scope not match") +// ErrJTINotExists = errors.New("jti not exists") + +// ErrStrictPathNotMatch = errors.New("strict path not match") +// ErrStrictQueryNotMatch = errors.New("strict query value not match") +// ) + +// func (cr *Cluster) getJWTKey(t *jwt.Token) (any, error) { +// if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { +// return nil, fmt.Errorf("Unexpected signing method: %v", t.Header["alg"]) +// } +// return cr.apiHmacKey, nil +// } + +// const ( +// challengeTokenScope = "GOBA-challenge" +// authTokenScope = "GOBA-auth" +// apiTokenScope = "GOBA-API" +// ) + +// type challengeTokenClaims struct { +// jwt.RegisteredClaims + +// Scope string `json:"scope"` +// Action string `json:"act"` +// } + +// func (cr *Cluster) generateChallengeToken(cliId string, action string) (string, error) { +// now := time.Now() +// exp := now.Add(time.Minute * 1) +// token := jwt.NewWithClaims(jwt.SigningMethodHS256, &challengeTokenClaims{ +// RegisteredClaims: jwt.RegisteredClaims{ +// Subject: cliId, +// Issuer: cr.jwtIssuer, +// IssuedAt: jwt.NewNumericDate(now), +// ExpiresAt: jwt.NewNumericDate(exp), +// }, +// Scope: challengeTokenScope, +// Action: action, +// }) +// tokenStr, err := token.SignedString(cr.apiHmacKey) +// if err != nil { +// return "", err +// } +// return tokenStr, nil +// } + +// func (cr *Cluster) verifyChallengeToken(cliId string, action string, token string) (err error) { +// var claims challengeTokenClaims +// if _, err = jwt.ParseWithClaims( +// token, +// &claims, +// cr.getJWTKey, +// jwt.WithSubject(cliId), +// jwt.WithIssuedAt(), +// jwt.WithIssuer(cr.jwtIssuer), +// ); err != nil { +// return +// } +// if claims.Scope != challengeTokenScope { +// return ErrScopeNotMatch +// } +// if claims.Action != action { +// return ErrJTINotExists +// } +// return +// } + +// type authTokenClaims struct { +// jwt.RegisteredClaims + +// Scope string `json:"scope"` +// User string `json:"usr"` +// } + +// func (cr *Cluster) generateAuthToken(cliId string, userId string) (string, error) { +// jti, err := utils.GenRandB64(16) +// if err != nil { +// return "", err +// } +// now := time.Now() +// exp := now.Add(time.Hour * 24) +// token := jwt.NewWithClaims(jwt.SigningMethodHS256, &authTokenClaims{ +// RegisteredClaims: jwt.RegisteredClaims{ +// ID: jti, +// Subject: cliId, +// Issuer: cr.jwtIssuer, +// IssuedAt: jwt.NewNumericDate(now), +// ExpiresAt: jwt.NewNumericDate(exp), +// }, +// Scope: authTokenScope, +// User: userId, +// }) +// tokenStr, err := token.SignedString(cr.apiHmacKey) +// if err != nil { +// return "", err +// } +// if err = cr.database.AddJTI(jti, exp); err != nil { +// return "", err +// } +// return tokenStr, nil +// } + +// func (cr *Cluster) verifyAuthToken(cliId string, token string) (id string, user string, err error) { +// var claims authTokenClaims +// if _, err = jwt.ParseWithClaims( +// token, +// &claims, +// cr.getJWTKey, +// jwt.WithSubject(cliId), +// jwt.WithIssuedAt(), +// jwt.WithIssuer(cr.jwtIssuer), +// ); err != nil { +// return +// } +// if claims.Scope != authTokenScope { +// err = ErrScopeNotMatch +// return +// } +// if user = claims.User; user == "" { +// // reject old token +// err = ErrJTINotExists +// return +// } +// id = claims.ID +// if ok, _ := cr.database.ValidJTI(id); !ok { +// err = ErrJTINotExists +// return +// } +// return +// } + +// type apiTokenClaims struct { +// jwt.RegisteredClaims + +// Scope string `json:"scope"` +// User string `json:"usr"` +// StrictPath string `json:"str-p"` +// StrictQuery map[string]string `json:"str-q,omitempty"` +// } + +// func (cr *Cluster) generateAPIToken(cliId string, userId string, path string, query map[string]string) (string, error) { +// jti, err := utils.GenRandB64(8) +// if err != nil { +// return "", err +// } +// now := time.Now() +// exp := now.Add(time.Minute * 10) +// token := jwt.NewWithClaims(jwt.SigningMethodHS256, &apiTokenClaims{ +// RegisteredClaims: jwt.RegisteredClaims{ +// ID: jti, +// Subject: cliId, +// Issuer: cr.jwtIssuer, +// IssuedAt: jwt.NewNumericDate(now), +// ExpiresAt: jwt.NewNumericDate(exp), +// }, +// Scope: apiTokenScope, +// User: userId, +// StrictPath: path, +// StrictQuery: query, +// }) +// tokenStr, err := token.SignedString(cr.apiHmacKey) +// if err != nil { +// return "", err +// } +// if err = cr.database.AddJTI(jti, exp); err != nil { +// return "", err +// } +// return tokenStr, nil +// } + +// func (h *Handler) verifyAPIToken(cliId string, token string, path string, query url.Values) (id string, user string, err error) { +// var claims apiTokenClaims +// _, err = jwt.ParseWithClaims( +// token, +// &claims, +// cr.getJWTKey, +// jwt.WithSubject(cliId), +// jwt.WithIssuedAt(), +// jwt.WithIssuer(cr.jwtIssuer), +// ) +// if err != nil { +// return +// } +// if claims.Scope != apiTokenScope { +// err = ErrScopeNotMatch +// return +// } +// if user = claims.User; user == "" { +// err = ErrJTINotExists +// return +// } +// id = claims.ID +// if ok, _ := cr.database.ValidJTI(id); !ok { +// err = ErrJTINotExists +// return +// } +// if claims.StrictPath != path { +// err = ErrStrictPathNotMatch +// return +// } +// for k, v := range claims.StrictQuery { +// if query.Get(k) != v { +// err = ErrStrictQueryNotMatch +// return +// } +// } +// return +// } diff --git a/api/v0/configure.go b/api/v0/configure.go index e1e16877..276b62f0 100644 --- a/api/v0/configure.go +++ b/api/v0/configure.go @@ -20,13 +20,160 @@ package v0 import ( + "fmt" + "io" + "mime" "net/http" + + "github.com/LiterMC/go-openbmclapi/api" ) -func (h *Handler) apiConfigureCluster(rw http.ResponseWriter, req *http.Request) { - // +func (h *Handler) buildConfigureRoute(mux *http.ServeMux) { + mux.Handle("GET /config", permHandleFunc(api.FullConfigPerm, h.routeConfigGET)) + mux.Handle("GET /config/{path}", permHandleFunc(api.FullConfigPerm, h.routeConfigGETPath)) + mux.Handle("PUT /config", permHandleFunc(api.FullConfigPerm, h.routeConfigPUT)) + mux.Handle("PATCH /config/{path}", permHandleFunc(api.FullConfigPerm, h.routeConfigPATCH)) + mux.Handle("DELETE /config/{path}", permHandleFunc(api.FullConfigPerm, h.routeConfigDELETE)) + + mux.Handle("GET /configure/clusters", permHandleFunc(api.ClusterPerm, h.routeConfigureClustersGET)) + mux.Handle("GET /configure/cluster/{cluster_id}", permHandleFunc(api.ClusterPerm, h.routeConfigureClusterGET)) + mux.Handle("PUT /configure/cluster/{cluster_id}", permHandleFunc(api.ClusterPerm, h.routeConfigureClusterPUT)) + mux.Handle("PATCH /configure/cluster/{cluster_id}/{path}", permHandleFunc(api.ClusterPerm, h.routeConfigureClusterPATCH)) + mux.Handle("DELETE /configure/cluster/{cluster_id}", permHandleFunc(api.ClusterPerm, h.routeConfigureClusterDELETE)) + + mux.Handle("GET /configure/storages", permHandleFunc(api.StoragePerm, h.routeConfigureStoragesGET)) + mux.Handle("GET /configure/storage/{storage_index}", permHandleFunc(api.StoragePerm, h.routeConfigureStorageGET)) + mux.Handle("PUT /configure/storage/{storage_index}", permHandleFunc(api.StoragePerm, h.routeConfigureStoragePUT)) + mux.Handle("PATCH /configure/storage/{storage_index}/{path}", permHandleFunc(api.StoragePerm, h.routeConfigureStoragePATCH)) + mux.Handle("DELETE /configure/storage/{storage_index}", permHandleFunc(api.StoragePerm, h.routeConfigureStorageDELETE)) + mux.Handle("POST /configure/storage/{storage_index}/move", permHandleFunc(api.StoragePerm, h.routeConfigureStorageMove)) } -func (h *Handler) apiConfigureStorage(rw http.ResponseWriter, req *http.Request) { +func (h *Handler) routeConfigGET(rw http.ResponseWriter, req *http.Request) { + buf, err := h.config.MarshalJSON() + if err != nil { + writeJson(rw, http.StatusInternalServerError, Map{ + "error": "MarshalJSONError", + "message": err.Error(), + }) + return + } + rw.WriteHeader(http.StatusOK) + rw.Write(buf) +} + +func (h *Handler) routeConfigPUT(rw http.ResponseWriter, req *http.Request) { + contentType, _, err := mime.ParseMediaType(req.Header.Get("Content-Type")) + if err != nil { + writeJson(rw, http.StatusBadRequest, Map{ + "error": "Unexpected Content-Type", + "content-type": req.Header.Get("Content-Type"), + "message": err.Error(), + }) + return + } + etag := req.Header.Get("If-Match") + if len(etag) > 2 && etag[0] == '"' && etag[len(etag)-1] == '"' { + etag = etag[1 : len(etag)-1] + } else { + etag = "" + } + err = h.config.DoLockedAction(etag, func(config api.ConfigHandler) error { + switch contentType { + case "application/json": + buf, err := io.ReadAll(req.Body) + if err != nil { + return fmt.Errorf("Failed to read request body: %w", err) + } + return config.UnmarshalJSON(buf) + case "application/x-yaml": + buf, err := io.ReadAll(req.Body) + if err != nil { + return fmt.Errorf("Failed to read request body: %w", err) + } + return config.UnmarshalYAML(buf) + default: + return errUnknownContent + } + }) + if err != nil { + if err == errUnknownContent { + writeJson(rw, http.StatusBadRequest, Map{ + "error": "Unexpected Content-Type", + "content-type": req.Header.Get("Content-Type"), + "message": "Expected application/json, application/x-yaml", + }) + return + } + writeJson(rw, http.StatusBadRequest, Map{ + "error": "UnmarshalError", + "message": err.Error(), + }) + return + } +} + +func (h *Handler) routeConfigGETPath(rw http.ResponseWriter, req *http.Request) { + rw.WriteHeader(http.StatusOK) +} + +func (h *Handler) routeConfigPATCH(rw http.ResponseWriter, req *http.Request) { +} + +func (h *Handler) routeConfigDELETE(rw http.ResponseWriter, req *http.Request) { +} + +func (h *Handler) routeConfigureClustersGET(rw http.ResponseWriter, req *http.Request) { // } + +func (h *Handler) routeConfigureClusterGET(rw http.ResponseWriter, req *http.Request) { + clusterId := req.PathValue("cluster_id") + _ = clusterId +} + +func (h *Handler) routeConfigureClusterPUT(rw http.ResponseWriter, req *http.Request) { + clusterId := req.PathValue("cluster_id") + _ = clusterId +} + +func (h *Handler) routeConfigureClusterPATCH(rw http.ResponseWriter, req *http.Request) { + clusterId := req.PathValue("cluster_id") + path := req.PathValue("path") + _, _ = clusterId, path +} + +func (h *Handler) routeConfigureClusterDELETE(rw http.ResponseWriter, req *http.Request) { + clusterId := req.PathValue("cluster_id") + _ = clusterId +} + +func (h *Handler) routeConfigureStoragesGET(rw http.ResponseWriter, req *http.Request) { +} + +func (h *Handler) routeConfigureStorageGET(rw http.ResponseWriter, req *http.Request) { + storageIndex := req.PathValue("storage_index") + _ = storageIndex +} + +func (h *Handler) routeConfigureStoragePUT(rw http.ResponseWriter, req *http.Request) { + storageIndex := req.PathValue("storage_index") + _ = storageIndex +} + +func (h *Handler) routeConfigureStoragePATCH(rw http.ResponseWriter, req *http.Request) { + storageIndex := req.PathValue("storage_index") + path := req.PathValue("path") + _, _ = storageIndex, path +} + +func (h *Handler) routeConfigureStorageDELETE(rw http.ResponseWriter, req *http.Request) { + storageIndex := req.PathValue("storage_index") + _ = storageIndex +} + +func (h *Handler) routeConfigureStorageMove(rw http.ResponseWriter, req *http.Request) { + storageIndex := req.PathValue("storage_index") + storageIndexTo := req.URL.Query().Get("to") + _, _ = storageIndex, storageIndexTo +} diff --git a/api/v0/debug.go b/api/v0/debug.go new file mode 100644 index 00000000..9727d627 --- /dev/null +++ b/api/v0/debug.go @@ -0,0 +1,393 @@ +/** + * OpenBmclAPI (Golang Edition) + * Copyright (C) 2023 Kevin Z + * All rights reserved + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published + * by the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package v0 + +import ( + "compress/gzip" + "context" + "errors" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "runtime/pprof" + "strconv" + "sync/atomic" + "time" + + "github.com/LiterMC/go-openbmclapi/api" + "github.com/LiterMC/go-openbmclapi/internal/build" + "github.com/LiterMC/go-openbmclapi/log" + "github.com/LiterMC/go-openbmclapi/utils" +) + +func (h *Handler) buildDebugRoute(mux *http.ServeMux) { + mux.HandleFunc("/log.io", h.routeLogIO) + mux.Handle("/pprof", permHandleFunc(api.DebugPerm, h.routePprof)) + mux.Handle("GET /log_files", permHandleFunc(api.LogPerm, h.routeLogFiles)) + mux.Handle("GET /log_file/{file_name}", permHandleFunc(api.LogPerm, h.routeLogFile)) +} + +func (h *Handler) routeLogIO(rw http.ResponseWriter, req *http.Request) { + addr := api.GetRequestRealAddr(req) + + conn, err := h.wsUpgrader.Upgrade(rw, req, nil) + if err != nil { + log.Debugf("[log.io]: Websocket upgrade error: %v", err) + http.Error(rw, err.Error(), http.StatusInternalServerError) + return + } + defer conn.Close() + + cli := apiGetClientId(req) + + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + + conn.SetReadLimit(1024 * 4) + pongTimeoutTimer := time.NewTimer(time.Second * 75) + go func() { + defer conn.Close() + defer cancel() + defer pongTimeoutTimer.Stop() + select { + case _, ok := <-pongTimeoutTimer.C: + if !ok { + return + } + log.Error("[log.io]: Did not receive packet from client longer than 75s") + return + case <-ctx.Done(): + return + } + }() + + var authData struct { + Token string `json:"token"` + } + deadline := time.Now().Add(time.Second * 10) + conn.SetReadDeadline(deadline) + err = conn.ReadJSON(&authData) + conn.SetReadDeadline(time.Time{}) + if err != nil { + if time.Now().After(deadline) { + conn.WriteJSON(Map{ + "type": "error", + "message": "auth timeout", + }) + } else { + conn.WriteJSON(Map{ + "type": "error", + "message": "unexpected auth data: " + err.Error(), + }) + } + return + } + if _, _, err = h.tokens.VerifyAuthToken(cli, authData.Token); err != nil { + conn.WriteJSON(Map{ + "type": "error", + "message": "auth failed", + }) + return + } + if err := conn.WriteJSON(Map{ + "type": "ready", + }); err != nil { + return + } + + var level atomic.Int32 + level.Store((int32)(log.LevelInfo)) + + type logObj struct { + Type string `json:"type"` + Time int64 `json:"time"` // UnixMilli + Level string `json:"lvl"` + Log string `json:"log"` + } + c := make(chan *logObj, 64) + unregister := log.RegisterLogMonitor(log.LevelDebug, func(ts int64, l log.Level, msg string) { + if (log.Level)(level.Load()) > l&log.LevelMask { + return + } + select { + case c <- &logObj{ + Type: "log", + Time: ts, + Level: l.String(), + Log: msg, + }: + default: + } + }) + defer unregister() + + go func() { + defer log.RecoverPanic(nil) + defer conn.Close() + defer cancel() + var data map[string]any + for { + clear(data) + if err := conn.ReadJSON(&data); err != nil { + log.Errorf("[log.io]: Cannot read from peer: %v", err) + return + } + typ, ok := data["type"].(string) + if !ok { + continue + } + switch typ { + case "pong": + log.Debugf("[log.io]: received PONG from %s: %v", addr, data["data"]) + pongTimeoutTimer.Reset(time.Second * 75) + case "set-level": + l, ok := data["level"].(string) + if ok { + switch l { + case "DBUG": + level.Store((int32)(log.LevelDebug)) + case "INFO": + level.Store((int32)(log.LevelInfo)) + case "WARN": + level.Store((int32)(log.LevelWarn)) + case "ERRO": + level.Store((int32)(log.LevelError)) + default: + continue + } + select { + case c <- &logObj{ + Type: "log", + Time: time.Now().UnixMilli(), + Level: log.LevelInfo.String(), + Log: "[dashboard]: Set log level to " + l + " for this log.io", + }: + default: + } + } + } + } + }() + + sendMsgCh := make(chan any, 64) + go func() { + for { + select { + case v := <-c: + select { + case sendMsgCh <- v: + case <-ctx.Done(): + return + } + case <-ctx.Done(): + return + } + } + }() + + pingTicker := time.NewTicker(time.Second * 45) + defer pingTicker.Stop() + forceSendTimer := time.NewTimer(time.Second) + if !forceSendTimer.Stop() { + <-forceSendTimer.C + } + + batchMsg := make([]any, 0, 64) + for { + select { + case v := <-sendMsgCh: + batchMsg = append(batchMsg, v) + forceSendTimer.Reset(time.Second) + WAIT_MORE: + for { + select { + case v := <-sendMsgCh: + batchMsg = append(batchMsg, v) + case <-time.After(time.Millisecond * 20): + if !forceSendTimer.Stop() { + <-forceSendTimer.C + } + break WAIT_MORE + case <-forceSendTimer.C: + break WAIT_MORE + case <-ctx.Done(): + forceSendTimer.Stop() + return + } + } + if len(batchMsg) == 1 { + if err := conn.WriteJSON(batchMsg[0]); err != nil { + return + } + } else { + if err := conn.WriteJSON(batchMsg); err != nil { + return + } + } + // release objects + for i, _ := range batchMsg { + batchMsg[i] = nil + } + batchMsg = batchMsg[:0] + case <-pingTicker.C: + if err := conn.WriteJSON(Map{ + "type": "ping", + "data": time.Now().UnixMilli(), + }); err != nil { + log.Errorf("[log.io]: Error when sending ping packet: %v", err) + return + } + case <-ctx.Done(): + return + } + } +} + +func (h *Handler) routePprof(rw http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodGet { + errorMethodNotAllowed(rw, req, http.MethodGet) + return + } + query := req.URL.Query() + lookup := query.Get("lookup") + p := pprof.Lookup(lookup) + if p == nil { + http.Error(rw, fmt.Sprintf("pprof.Lookup(%q) returned nil", lookup), http.StatusBadRequest) + return + } + view := query.Get("view") + debug, err := strconv.Atoi(query.Get("debug")) + if err != nil { + debug = 1 + } + if debug == 1 { + rw.Header().Set("Content-Type", "text/plain; charset=utf-8") + } else { + rw.Header().Set("Content-Type", "application/octet-stream") + } + if view != "1" { + name := fmt.Sprintf(time.Now().Format("dump-%s-20060102-150405"), lookup) + if debug == 1 { + name += ".txt" + } else { + name += ".dump" + } + rw.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", name)) + } + rw.WriteHeader(http.StatusOK) + if debug == 1 { + fmt.Fprintf(rw, "version: %s (%s)\n", build.BuildVersion, build.ClusterVersion) + } + p.WriteTo(rw, debug) +} + +func (h *Handler) routeLogFiles(rw http.ResponseWriter, req *http.Request) { + files := log.ListLogs() + type FileInfo struct { + Name string `json:"name"` + Size int64 `json:"size"` + } + data := make([]FileInfo, 0, len(files)) + for _, file := range files { + if s, err := os.Stat(filepath.Join(log.BaseDir(), file)); err == nil { + data = append(data, FileInfo{ + Name: file, + Size: s.Size(), + }) + } + } + writeJson(rw, http.StatusOK, Map{ + "files": data, + }) +} + +func (h *Handler) routeLogFile(rw http.ResponseWriter, req *http.Request) { + fileName := req.PathValue("file_name") + query := req.URL.Query() + fd, err := os.Open(filepath.Join(log.BaseDir(), fileName)) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + writeJson(rw, http.StatusNotFound, Map{ + "error": "file not exists", + "message": "Cannot find log file", + "path": req.URL.Path, + }) + return + } + writeJson(rw, http.StatusInternalServerError, Map{ + "error": "cannot open file", + "message": err.Error(), + }) + return + } + defer fd.Close() + name := filepath.Base(req.URL.Path) + isGzip := filepath.Ext(name) == ".gz" + if query.Get("no_encrypt") == "1" { + var modTime time.Time + if stat, err := fd.Stat(); err == nil { + modTime = stat.ModTime() + } + rw.Header().Set("Cache-Control", "public, max-age=60, stale-while-revalidate=600") + if isGzip { + rw.Header().Set("Content-Type", "application/octet-stream") + } else { + rw.Header().Set("Content-Type", "text/plain; charset=utf-8") + } + rw.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", name)) + http.ServeContent(rw, req, name, modTime, fd) + } else { + if !isGzip { + name += ".gz" + } + rw.Header().Set("Content-Type", "application/octet-stream") + rw.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", name+".encrypted")) + h.routeLogFileEncrypted(rw, req, fd, !isGzip) + } +} + +func (h *Handler) routeLogFileEncrypted(rw http.ResponseWriter, req *http.Request, r io.Reader, useGzip bool) { + rw.WriteHeader(http.StatusOK) + if req.Method == http.MethodHead { + return + } + if useGzip { + pr, pw := io.Pipe() + defer pr.Close() + go func(r io.Reader) { + gw := gzip.NewWriter(pw) + if _, err := io.Copy(gw, r); err != nil { + pw.CloseWithError(err) + return + } + if err := gw.Close(); err != nil { + pw.CloseWithError(err) + return + } + pw.Close() + }(r) + r = pr + } + if err := utils.EncryptStream(rw, r, utils.DeveloporPublicKey); err != nil { + log.Errorf("Cannot write encrypted log stream: %v", err) + } +} diff --git a/api/v0/stat.go b/api/v0/stat.go new file mode 100644 index 00000000..125aaf2c --- /dev/null +++ b/api/v0/stat.go @@ -0,0 +1,63 @@ +/** + * OpenBmclAPI (Golang Edition) + * Copyright (C) 2023 Kevin Z + * All rights reserved + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published + * by the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package v0 + +import ( + "net/http" + "time" + + "github.com/LiterMC/go-openbmclapi/internal/build" + "github.com/LiterMC/go-openbmclapi/limited" +) + +func (h *Handler) buildStatRoute(mux *http.ServeMux) { + mux.HandleFunc("GET /ping", h.routePing) + mux.HandleFunc("GET /status", h.routeStatus) + mux.HandleFunc("GET /stat/{name}", h.routeStat) +} + +func (h *Handler) routePing(rw http.ResponseWriter, req *http.Request) { + limited.SetSkipRateLimit(req) + authed := getRequestTokenType(req) == tokenTypeAuth + writeJson(rw, http.StatusOK, Map{ + "version": build.BuildVersion, + "time": time.Now().UnixMilli(), + "authed": authed, + }) +} + +func (h *Handler) routeStatus(rw http.ResponseWriter, req *http.Request) { + limited.SetSkipRateLimit(req) + writeJson(rw, http.StatusOK, h.stats.GetStatus()) +} + +func (h *Handler) routeStat(rw http.ResponseWriter, req *http.Request) { + limited.SetSkipRateLimit(req) + name := req.PathValue("name") + data := h.stats.GetAccessStat(name) + if data == nil { + writeJson(rw, http.StatusNotFound, Map{ + "error": "AccessStatNotFoudn", + "name": name, + }) + return + } + writeJson(rw, http.StatusOK, data) +} diff --git a/api/v0/subscription.go b/api/v0/subscription.go index 0a5f4709..cdf30fc7 100644 --- a/api/v0/subscription.go +++ b/api/v0/subscription.go @@ -21,14 +21,36 @@ package v0 import ( "net/http" + + "github.com/google/uuid" + + "github.com/LiterMC/go-openbmclapi/api" + "github.com/LiterMC/go-openbmclapi/utils" ) +func (h *Handler) buildSubscriptionRoute(mux *http.ServeMux) { + mux.HandleFunc("GET /subscribeKey", h.routeSubscribeKey) + mux.Handle("/subscribe", permHandle(api.SubscribePerm, &utils.HttpMethodHandler{ + Get: (http.HandlerFunc)(h.routeSubscribeGET), + Post: (http.HandlerFunc)(h.routeSubscribePOST), + Delete: (http.HandlerFunc)(h.routeSubscribeDELETE), + })) + mux.Handle("/subscribe_email", permHandle(api.SubscribePerm, &utils.HttpMethodHandler{ + Get: (http.HandlerFunc)(h.routeSubscribeEmailGET), + Post: (http.HandlerFunc)(h.routeSubscribeEmailPOST), + Patch: (http.HandlerFunc)(h.routeSubscribeEmailPATCH), + Delete: (http.HandlerFunc)(h.routeSubscribeEmailDELETE), + })) + mux.Handle("/webhook", permHandle(api.SubscribePerm, &utils.HttpMethodHandler{ + Get: (http.HandlerFunc)(h.routeWebhookGET), + Post: (http.HandlerFunc)(h.routeWebhookPOST), + Patch: (http.HandlerFunc)(h.routeWebhookPATCH), + Delete: (http.HandlerFunc)(h.routeWebhookDELETE), + })) +} + func (h *Handler) routeSubscribeKey(rw http.ResponseWriter, req *http.Request) { - if req.Method != http.MethodGet { - errorMethodNotAllowed(rw, req, http.MethodGet) - return - } - key := h.subManager.GetWebPushKey() + key := h.subscriptions.GetWebPushKey() etag := `"` + utils.AsSha256(key) + `"` rw.Header().Set("ETag", etag) if cachedTag := req.Header.Get("If-None-Match"); cachedTag == etag { @@ -43,9 +65,9 @@ func (h *Handler) routeSubscribeKey(rw http.ResponseWriter, req *http.Request) { func (h *Handler) routeSubscribeGET(rw http.ResponseWriter, req *http.Request) { client := apiGetClientId(req) user := getLoggedUser(req) - record, err := h.subManager.GetSubscribe(user, client) + record, err := h.subscriptions.GetSubscribe(user.Username, client) if err != nil { - if err == database.ErrNotFound { + if err == api.ErrNotFound { writeJson(rw, http.StatusNotFound, Map{ "error": "no subscription was found", }) @@ -66,13 +88,13 @@ func (h *Handler) routeSubscribeGET(rw http.ResponseWriter, req *http.Request) { func (h *Handler) routeSubscribePOST(rw http.ResponseWriter, req *http.Request) { client := apiGetClientId(req) user := getLoggedUser(req) - data, ok := parseRequestBody[database.SubscribeRecord](rw, req, nil) - if !ok { + var data api.SubscribeRecord + if !parseRequestBody(rw, req, &data) { return } - data.User = user + data.User = user.Username data.Client = client - if err := h.subManager.SetSubscribe(data); err != nil { + if err := h.subscriptions.SetSubscribe(data); err != nil { writeJson(rw, http.StatusInternalServerError, Map{ "error": "Database update failed", "message": err.Error(), @@ -85,8 +107,8 @@ func (h *Handler) routeSubscribePOST(rw http.ResponseWriter, req *http.Request) func (h *Handler) routeSubscribeDELETE(rw http.ResponseWriter, req *http.Request) { client := apiGetClientId(req) user := getLoggedUser(req) - if err := h.subManager.RemoveSubscribe(user, client); err != nil { - if err == database.ErrNotFound { + if err := h.subscriptions.RemoveSubscribe(user.Username, client); err != nil { + if err == api.ErrNotFound { writeJson(rw, http.StatusNotFound, Map{ "error": "no subscription was found", }) @@ -104,9 +126,9 @@ func (h *Handler) routeSubscribeDELETE(rw http.ResponseWriter, req *http.Request func (h *Handler) routeSubscribeEmailGET(rw http.ResponseWriter, req *http.Request) { user := getLoggedUser(req) if addr := req.URL.Query().Get("addr"); addr != "" { - record, err := h.subManager.GetEmailSubscription(user, addr) + record, err := h.subscriptions.GetEmailSubscription(user.Username, addr) if err != nil { - if err == database.ErrNotFound { + if err == api.ErrNotFound { writeJson(rw, http.StatusNotFound, Map{ "error": "no email subscription was found", }) @@ -121,8 +143,8 @@ func (h *Handler) routeSubscribeEmailGET(rw http.ResponseWriter, req *http.Reque writeJson(rw, http.StatusOK, record) return } - records := make([]database.EmailSubscriptionRecord, 0, 4) - if err := h.subManager.ForEachUsersEmailSubscription(user, func(rec *database.EmailSubscriptionRecord) error { + records := make([]api.EmailSubscriptionRecord, 0, 4) + if err := h.subscriptions.ForEachUsersEmailSubscription(user.Username, func(rec *api.EmailSubscriptionRecord) error { records = append(records, *rec) return nil }); err != nil { @@ -137,13 +159,13 @@ func (h *Handler) routeSubscribeEmailGET(rw http.ResponseWriter, req *http.Reque func (h *Handler) routeSubscribeEmailPOST(rw http.ResponseWriter, req *http.Request) { user := getLoggedUser(req) - data, ok := parseRequestBody[database.EmailSubscriptionRecord](rw, req, nil) - if !ok { + var data api.EmailSubscriptionRecord + if !parseRequestBody(rw, req, &data) { return } - data.User = user - if err := h.subManager.AddEmailSubscription(data); err != nil { + data.User = user.Username + if err := h.subscriptions.AddEmailSubscription(data); err != nil { writeJson(rw, http.StatusInternalServerError, Map{ "error": "Database update failed", "message": err.Error(), @@ -156,14 +178,14 @@ func (h *Handler) routeSubscribeEmailPOST(rw http.ResponseWriter, req *http.Requ func (h *Handler) routeSubscribeEmailPATCH(rw http.ResponseWriter, req *http.Request) { user := getLoggedUser(req) addr := req.URL.Query().Get("addr") - data, ok := parseRequestBody[database.EmailSubscriptionRecord](rw, req, nil) - if !ok { + var data api.EmailSubscriptionRecord + if !parseRequestBody(rw, req, &data) { return } - data.User = user + data.User = user.Username data.Addr = addr - if err := h.subManager.UpdateEmailSubscription(data); err != nil { - if err == database.ErrNotFound { + if err := h.subscriptions.UpdateEmailSubscription(data); err != nil { + if err == api.ErrNotFound { writeJson(rw, http.StatusNotFound, Map{ "error": "no email subscription was found", }) @@ -181,8 +203,8 @@ func (h *Handler) routeSubscribeEmailPATCH(rw http.ResponseWriter, req *http.Req func (h *Handler) routeSubscribeEmailDELETE(rw http.ResponseWriter, req *http.Request) { user := getLoggedUser(req) addr := req.URL.Query().Get("addr") - if err := h.subManager.RemoveEmailSubscription(user, addr); err != nil { - if err == database.ErrNotFound { + if err := h.subscriptions.RemoveEmailSubscription(user.Username, addr); err != nil { + if err == api.ErrNotFound { writeJson(rw, http.StatusNotFound, Map{ "error": "no email subscription was found", }) @@ -196,3 +218,121 @@ func (h *Handler) routeSubscribeEmailDELETE(rw http.ResponseWriter, req *http.Re } rw.WriteHeader(http.StatusNoContent) } + +func (h *Handler) routeWebhookGET(rw http.ResponseWriter, req *http.Request) { + user := getLoggedUser(req) + if sid := req.URL.Query().Get("id"); sid != "" { + id, err := uuid.Parse(sid) + if err != nil { + writeJson(rw, http.StatusBadRequest, Map{ + "error": "uuid format error", + "message": err.Error(), + }) + return + } + record, err := h.subscriptions.GetWebhook(user.Username, id) + if err != nil { + if err == api.ErrNotFound { + writeJson(rw, http.StatusNotFound, Map{ + "error": "no webhook was found", + }) + return + } + writeJson(rw, http.StatusInternalServerError, Map{ + "error": "database error", + "message": err.Error(), + }) + return + } + writeJson(rw, http.StatusOK, record) + return + } + records := make([]api.WebhookRecord, 0, 4) + if err := h.subscriptions.ForEachUsersWebhook(user.Username, func(rec *api.WebhookRecord) error { + records = append(records, *rec) + return nil + }); err != nil { + writeJson(rw, http.StatusInternalServerError, Map{ + "error": "database error", + "message": err.Error(), + }) + return + } + writeJson(rw, http.StatusOK, records) +} + +func (h *Handler) routeWebhookPOST(rw http.ResponseWriter, req *http.Request) { + user := getLoggedUser(req) + var data api.WebhookRecord + if !parseRequestBody(rw, req, &data) { + return + } + + data.User = user.Username + if err := h.subscriptions.AddWebhook(data); err != nil { + writeJson(rw, http.StatusInternalServerError, Map{ + "error": "Database update failed", + "message": err.Error(), + }) + return + } + rw.WriteHeader(http.StatusCreated) +} + +func (h *Handler) routeWebhookPATCH(rw http.ResponseWriter, req *http.Request) { + user := getLoggedUser(req) + id := req.URL.Query().Get("id") + var data api.WebhookRecord + if !parseRequestBody(rw, req, &data) { + return + } + data.User = user.Username + var err error + if data.Id, err = uuid.Parse(id); err != nil { + writeJson(rw, http.StatusBadRequest, Map{ + "error": "uuid format error", + "message": err.Error(), + }) + return + } + if err := h.subscriptions.UpdateWebhook(data); err != nil { + if err == api.ErrNotFound { + writeJson(rw, http.StatusNotFound, Map{ + "error": "no webhook was found", + }) + return + } + writeJson(rw, http.StatusInternalServerError, Map{ + "error": "database error", + "message": err.Error(), + }) + return + } + rw.WriteHeader(http.StatusNoContent) +} + +func (h *Handler) routeWebhookDELETE(rw http.ResponseWriter, req *http.Request) { + user := getLoggedUser(req) + id, err := uuid.Parse(req.URL.Query().Get("id")) + if err != nil { + writeJson(rw, http.StatusBadRequest, Map{ + "error": "uuid format error", + "message": err.Error(), + }) + return + } + if err := h.subscriptions.RemoveWebhook(user.Username, id); err != nil { + if err == api.ErrNotFound { + writeJson(rw, http.StatusNotFound, Map{ + "error": "no webhook was found", + }) + return + } + writeJson(rw, http.StatusInternalServerError, Map{ + "error": "database error", + "message": err.Error(), + }) + return + } + rw.WriteHeader(http.StatusNoContent) +} diff --git a/config.go b/config.go index 6baa51d5..03e31745 100644 --- a/config.go +++ b/config.go @@ -52,7 +52,6 @@ type AdvancedConfig struct { NoGC bool `yaml:"no-gc"` HeavyCheckInterval int `yaml:"heavy-check-interval"` KeepaliveTimeout int `yaml:"keepalive-timeout"` - SkipFirstSync bool `yaml:"skip-first-sync"` SkipSignatureCheck bool `yaml:"skip-signature-check"` NoFastEnable bool `yaml:"no-fast-enable"` WaitBeforeEnable int `yaml:"wait-before-enable"` @@ -193,13 +192,12 @@ type Config struct { PublicHost string `yaml:"public-host"` PublicPort uint16 `yaml:"public-port"` Port uint16 `yaml:"port"` - ClusterId string `yaml:"cluster-id"` - ClusterSecret string `yaml:"cluster-secret"` SyncInterval int `yaml:"sync-interval"` OnlyGcWhenStart bool `yaml:"only-gc-when-start"` DownloadMaxConn int `yaml:"download-max-conn"` MaxReconnectCount int `yaml:"max-reconnect-count"` + Clusters map[string]ClusterItem `yaml:"clusters"` Certificates []CertificateConfig `yaml:"certificates"` Tunneler TunnelConfig `yaml:"tunneler"` Cache CacheConfig `yaml:"cache"` @@ -223,111 +221,107 @@ func (cfg *Config) applyWebManifest(manifest map[string]any) { } } -var defaultConfig = Config{ - LogSlots: 7, - NoAccessLog: false, - AccessLogSlots: 16, - Byoc: false, - TrustedXForwardedFor: false, - PublicHost: "", - PublicPort: 0, - Port: 4000, - ClusterId: "${CLUSTER_ID}", - ClusterSecret: "${CLUSTER_SECRET}", - SyncInterval: 10, - OnlyGcWhenStart: false, - DownloadMaxConn: 16, - MaxReconnectCount: 10, - - Certificates: []CertificateConfig{ - { - Cert: "/path/to/cert.pem", - Key: "/path/to/key.pem", +func getDefaultConfig() *Config { + return &Config{ + LogSlots: 7, + NoAccessLog: false, + AccessLogSlots: 16, + Byoc: false, + TrustedXForwardedFor: false, + PublicHost: "", + PublicPort: 0, + Port: 4000, + SyncInterval: 10, + OnlyGcWhenStart: false, + DownloadMaxConn: 16, + MaxReconnectCount: 10, + + Clusters: map[string]ClusterItem{}, + + Certificates: []CertificateConfig{}, + + Tunneler: TunnelConfig{ + Enable: false, + TunnelProg: "./path/to/tunnel/program", + OutputRegex: `\bNATedAddr\s+(?P[0-9.]+|\[[0-9a-f:]+\]):(?P\d+)$`, + TunnelTimeout: 0, }, - }, - - Tunneler: TunnelConfig{ - Enable: false, - TunnelProg: "./path/to/tunnel/program", - OutputRegex: `\bNATedAddr\s+(?P[0-9.]+|\[[0-9a-f:]+\]):(?P\d+)$`, - TunnelTimeout: 0, - }, - - Cache: CacheConfig{ - Type: "inmem", - newCache: func() cache.Cache { return cache.NewInMemCache() }, - }, - - ServeLimit: ServeLimitConfig{ - Enable: false, - MaxConn: 16384, - UploadRate: 1024 * 12, // 12MB - }, - - RateLimit: APIRateLimitConfig{ - Anonymous: limited.RateLimit{ - PerMin: 10, - PerHour: 120, + + Cache: CacheConfig{ + Type: "inmem", + newCache: func() cache.Cache { return cache.NewInMemCache() }, }, - Logged: limited.RateLimit{ - PerMin: 120, - PerHour: 6000, + + ServeLimit: ServeLimitConfig{ + Enable: false, + MaxConn: 16384, + UploadRate: 1024 * 12, // 12MB }, - }, - - Notification: NotificationConfig{ - EnableEmail: false, - EmailSMTP: "smtp.example.com:25", - EmailSMTPEncryption: "tls", - EmailSender: "noreply@example.com", - EmailSenderPassword: "example-password", - EnableWebhook: true, - }, - - Dashboard: DashboardConfig{ - Enable: true, - PwaName: "GoOpenBmclApi Dashboard", - PwaShortName: "GOBA Dash", - PwaDesc: "Go-Openbmclapi Internal Dashboard", - NotifySubject: "mailto:user@example.com", - }, - - GithubAPI: GithubAPIConfig{ - UpdateCheckInterval: (utils.YAMLDuration)(time.Hour), - }, - - Database: DatabaseConfig{ - Driver: "sqlite", - DSN: filepath.Join("data", "files.db"), - }, - - Hijack: HijackConfig{ - Enable: false, - RequireAuth: false, - EnableLocalCache: false, - LocalCachePath: "hijack_cache", - AuthUsers: []UserItem{ - { - Username: "example-username", - Password: "example-password", + + RateLimit: APIRateLimitConfig{ + Anonymous: limited.RateLimit{ + PerMin: 10, + PerHour: 120, + }, + Logged: limited.RateLimit{ + PerMin: 120, + PerHour: 6000, + }, + }, + + Notification: NotificationConfig{ + EnableEmail: false, + EmailSMTP: "smtp.example.com:25", + EmailSMTPEncryption: "tls", + EmailSender: "noreply@example.com", + EmailSenderPassword: "example-password", + EnableWebhook: true, + }, + + Dashboard: DashboardConfig{ + Enable: true, + PwaName: "GoOpenBmclApi Dashboard", + PwaShortName: "GOBA Dash", + PwaDesc: "Go-Openbmclapi Internal Dashboard", + NotifySubject: "mailto:user@example.com", + }, + + GithubAPI: GithubAPIConfig{ + UpdateCheckInterval: (utils.YAMLDuration)(time.Hour), + }, + + Database: DatabaseConfig{ + Driver: "sqlite", + DSN: filepath.Join("data", "files.db"), + }, + + Hijack: HijackConfig{ + Enable: false, + RequireAuth: false, + EnableLocalCache: false, + LocalCachePath: "hijack_cache", + AuthUsers: []UserItem{ + { + Username: "example-username", + Password: "example-password", + }, }, }, - }, - - Storages: nil, - - WebdavUsers: map[string]*storage.WebDavUser{}, - - Advanced: AdvancedConfig{ - DebugLog: false, - NoHeavyCheck: false, - NoGC: false, - HeavyCheckInterval: 120, - KeepaliveTimeout: 10, - SkipFirstSync: false, - NoFastEnable: false, - WaitBeforeEnable: 0, - }, + + Storages: nil, + + WebdavUsers: map[string]*storage.WebDavUser{}, + + Advanced: AdvancedConfig{ + DebugLog: false, + NoHeavyCheck: false, + NoGC: false, + HeavyCheckInterval: 120, + KeepaliveTimeout: 10, + NoFastEnable: false, + WaitBeforeEnable: 0, + }, + } } func migrateConfig(data []byte, config *Config) { @@ -345,12 +339,24 @@ func migrateConfig(data []byte, config *Config) { if v, ok := oldConfig["keepalive-timeout"].(int); ok { config.Advanced.KeepaliveTimeout = v } + if oldConfig["clusters"].(map[string]any) == nil { + id, ok1 := oldConfig["cluster-id"].(string) + secret, ok2 := oldConfig["cluster-secret"].(string) + if ok1 && ok2 { + config.Clusters = map[string]ClusterItem{ + "main": { + Id: id, + Secret: secret, + }, + } + } + } } -func readConfig() (config Config) { +func readConfig() (config Config, err error) { const configPath = "config.yaml" - config = defaultConfig + config = getDefaultConfig() data, err := os.ReadFile(configPath) notexists := false @@ -362,11 +368,27 @@ func readConfig() (config Config) { log.Error(Tr("error.config.not.exists")) notexists = true } else { - migrateConfig(data, &config) - if err = yaml.Unmarshal(data, &config); err != nil { + migrateConfig(data, config) + if err = yaml.Unmarshal(data, config); err != nil { log.Errorf(Tr("error.config.parse.failed"), err) osExit(CodeClientError) } + if len(config.Clusters) == 0 { + config.Clusters = map[string]ClusterItem{ + "main": { + Id: "${CLUSTER_ID}", + Secret: "${CLUSTER_SECRET}", + }, + } + } + if len(config.Certificates) == 0 { + config.Certificates = []CertificateConfig{ + { + Cert: "/path/to/cert.pem", + Key: "/path/to/key.pem", + }, + } + } if len(config.Storages) == 0 { config.Storages = []storage.StorageOption{ { @@ -396,9 +418,15 @@ func readConfig() (config Config) { } if j, ok := ids[s.Id]; ok { log.Errorf("Duplicated storage id %q at [%d] and [%d], please edit the config.", s.Id, i, j) - osExit(CodeClientError) + os.Exit(CodeClientError) } ids[s.Id] = i + if s.Cluster != "" && s.Cluster != "-" { + if _, ok := config.Clusters[s.Cluster]; !ok { + log.Errorf("Storage %q is trying to connect to a not exists cluster %q.", s.Id, s.Cluster) + os.Exit(CodeClientError) + } + } } } @@ -409,7 +437,7 @@ func readConfig() (config Config) { user, ok := config.WebdavUsers[alias] if !ok { log.Errorf(Tr("error.config.alias.user.not.exists"), alias) - osExit(CodeClientError) + os.Exit(CodeClientError) } opt.AliasUser = user var end *url.URL @@ -436,45 +464,14 @@ func readConfig() (config Config) { encoder.SetIndent(2) if err = encoder.Encode(config); err != nil { log.Errorf(Tr("error.config.encode.failed"), err) - osExit(CodeClientError) + os.Exit(CodeClientError) } if err = os.WriteFile(configPath, buf.Bytes(), 0600); err != nil { log.Errorf(Tr("error.config.write.failed"), err) - osExit(CodeClientError) + os.Exit(CodeClientError) } if notexists { log.Error(Tr("error.config.created")) - osExit(0xff) - } - - if os.Getenv("DEBUG") == "true" { - config.Advanced.DebugLog = true - } - if v := os.Getenv("CLUSTER_IP"); v != "" { - config.PublicHost = v - } - if v := os.Getenv("CLUSTER_PORT"); v != "" { - if n, err := strconv.Atoi(v); err != nil { - log.Errorf("Cannot parse CLUSTER_PORT %q: %v", v, err) - } else { - config.Port = (uint16)(n) - } - } - if v := os.Getenv("CLUSTER_PUBLIC_PORT"); v != "" { - if n, err := strconv.Atoi(v); err != nil { - log.Errorf("Cannot parse CLUSTER_PUBLIC_PORT %q: %v", v, err) - } else { - config.PublicPort = (uint16)(n) - } - } - if v := os.Getenv("CLUSTER_ID"); v != "" { - config.ClusterId = v - } - if v := os.Getenv("CLUSTER_SECRET"); v != "" { - config.ClusterSecret = v - } - if byoc := os.Getenv("CLUSTER_BYOC"); byoc != "" { - config.Byoc = byoc == "true" } return } diff --git a/handler.go b/handler.go index 7aafccab..1106d80d 100644 --- a/handler.go +++ b/handler.go @@ -52,22 +52,6 @@ func init() { }) } -const ( - RealAddrCtxKey = "handle.real.addr" - RealPathCtxKey = "handle.real.path" - AccessLogExtraCtxKey = "handle.access.extra" -) - -func GetRequestRealPath(req *http.Request) string { - return req.Context().Value(RealPathCtxKey).(string) -} - -func SetAccessInfo(req *http.Request, key string, value any) { - if info, ok := req.Context().Value(AccessLogExtraCtxKey).(map[string]any); ok { - info[key] = value - } -} - type preAccessRecord struct { Type string `json:"type"` Time time.Time `json:"time"` From a498e1171b6fd1f808020f3e75ac5631a3f1aab0 Mon Sep 17 00:00:00 2001 From: zyxkad Date: Mon, 1 Jul 2024 08:31:44 -0600 Subject: [PATCH 10/36] go fmt --- api/stats.go | 2 +- utils/http.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/api/stats.go b/api/stats.go index 6ddd0f7d..88a24f93 100644 --- a/api/stats.go +++ b/api/stats.go @@ -82,7 +82,7 @@ type accessStatHistoryData struct { type AccessStatData struct { Date statTime `json:"date"` accessStatHistoryData - Prev accessStatHistoryData `json:"prev"` + Prev accessStatHistoryData `json:"prev"` Years map[string]statInstData `json:"years"` Accesses map[string]int `json:"accesses"` diff --git a/utils/http.go b/utils/http.go index f0187e26..08437157 100644 --- a/utils/http.go +++ b/utils/http.go @@ -193,7 +193,7 @@ type HttpMethodHandler struct { Options http.Handler Trace http.Handler - allows string + allows string allowsOnce sync.Once } From 10eee068ade100a1f26485f61eb765120937cf39 Mon Sep 17 00:00:00 2001 From: zyxkad Date: Tue, 6 Aug 2024 11:23:36 -0700 Subject: [PATCH 11/36] update dockerfile --- Dockerfile | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/Dockerfile b/Dockerfile index c00e1cc6..5d014231 100644 --- a/Dockerfile +++ b/Dockerfile @@ -37,6 +37,7 @@ ARG NPM_DIR WORKDIR "/go/src/${REPO}/" +ENV CGO_ENABLED=0 COPY ./go.mod ./go.sum "/go/src/${REPO}/" RUN go mod download COPY . "/go/src/${REPO}" @@ -45,7 +46,7 @@ COPY --from=WEB_BUILD "/web/dist" "/go/src/${REPO}/${NPM_DIR}/dist" ENV ldflags="-X 'github.com/LiterMC/go-openbmclapi/internal/build.BuildVersion=$TAG'" RUN --mount=type=cache,target=/root/.cache/go-build \ - CGO_ENABLED=0 go build -v -o "/go/bin/go-openbmclapi" -ldflags="$ldflags" "." + go build -v -o "/go/bin/go-openbmclapi" -ldflags="$ldflags" "." FROM alpine:latest @@ -54,4 +55,4 @@ COPY ./config.yaml /opt/openbmclapi/config.yaml COPY --from=BUILD "/go/bin/go-openbmclapi" "/go-openbmclapi" -CMD ["/go-openbmclapi"] +ENTRYPOINT ["/go-openbmclapi"] From 8dca5cb2de7d9bcbde516c78d3efe44513805a16 Mon Sep 17 00:00:00 2001 From: zyxkad Date: Thu, 8 Aug 2024 08:04:27 -0700 Subject: [PATCH 12/36] seperate config --- api/bmclapi/hijacker.go | 16 +- .../db_test.go => api/subscription_test.go | 4 +- api/v0/api.go | 4 + cluster/cluster.go | 38 +- cluster/handler.go | 4 +- cluster/http.go | 51 +++ cluster/keepalive.go | 10 +- cluster/socket.go | 32 +- {notify => cluster}/stat.go | 246 ++++++------ config.go | 378 ++---------------- config/advanced.go | 35 ++ config/config.go | 182 +++++++++ config/dashboard.go | 49 +++ config/server.go | 149 +++++++ database/db.go | 36 +- database/memory.go | 39 +- database/sql.go | 51 +-- go.mod | 15 +- go.sum | 18 +- notify/email/email.go | 17 +- notify/webpush/webpush.go | 25 +- storage/manager.go | 9 + storage/storage_local.go | 14 +- storage/storage_webdav.go | 14 +- sub_commands/cmd_compress.go | 2 + sub_commands/cmd_webdav.go | 48 ++- 26 files changed, 858 insertions(+), 628 deletions(-) rename database/db_test.go => api/subscription_test.go (97%) rename {notify => cluster}/stat.go (70%) create mode 100644 config/advanced.go create mode 100644 config/config.go create mode 100644 config/dashboard.go create mode 100644 config/server.go diff --git a/api/bmclapi/hijacker.go b/api/bmclapi/hijacker.go index 6cc78049..a040b036 100644 --- a/api/bmclapi/hijacker.go +++ b/api/bmclapi/hijacker.go @@ -33,6 +33,7 @@ import ( "sync" "time" + "github.com/LiterMC/go-openbmclapi/config" "github.com/LiterMC/go-openbmclapi/database" "github.com/LiterMC/go-openbmclapi/utils" ) @@ -52,6 +53,9 @@ func getDialerWithDNS(dnsaddr string) *net.Dialer { type downloadHandlerFn = func(rw http.ResponseWriter, req *http.Request, hash string) type HjProxy struct { + RequireAuth bool + AuthUsers []config.UserItem + client *http.Client fileMap database.DB downloadHandler downloadHandlerFn @@ -76,11 +80,11 @@ func NewHjProxy(client *http.Client, fileMap database.DB, downloadHandler downlo return } -func hjResponseWithCache(rw http.ResponseWriter, req *http.Request, c *cacheStat, force bool) (ok bool) { +func hjResponseWithCache(rw http.ResponseWriter, req *http.Request, cachePath string, c *cacheStat, force bool) (ok bool) { if c == nil { return false } - cacheFileName := filepath.Join(config.Hijack.LocalCachePath, filepath.FromSlash(req.URL.Path)) + cacheFileName := filepath.Join(cachePath, filepath.FromSlash(req.URL.Path)) age := c.ExpiresAt - time.Now().Unix() if !force && age <= 0 { return false @@ -107,15 +111,11 @@ func hjResponseWithCache(rw http.ResponseWriter, req *http.Request, c *cacheStat const hijackingHost = "bmclapi2.bangbang93.com" func (h *HjProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { - if !config.Hijack.Enable { - http.Error(rw, "Hijack is disabled in the config", http.StatusServiceUnavailable) - return - } - if config.Hijack.RequireAuth { + if h.RequireAuth { needAuth := true user, passwd, ok := req.BasicAuth() if ok { - for _, u := range config.Hijack.AuthUsers { + for _, u := range h.AuthUsers { if u.Username == user && utils.ComparePasswd(u.Password, passwd) { needAuth = false return diff --git a/database/db_test.go b/api/subscription_test.go similarity index 97% rename from database/db_test.go rename to api/subscription_test.go index 4300f23d..146505a8 100644 --- a/database/db_test.go +++ b/api/subscription_test.go @@ -17,13 +17,13 @@ * along with this program. If not, see . */ -package database_test +package api_test import ( "encoding/json" "time" - . "github.com/LiterMC/go-openbmclapi/database" + . "github.com/LiterMC/go-openbmclapi/api" "testing" ) diff --git a/api/v0/api.go b/api/v0/api.go index 65b288c1..bd2aab78 100644 --- a/api/v0/api.go +++ b/api/v0/api.go @@ -27,6 +27,7 @@ import ( "strconv" "github.com/gorilla/schema" + "github.com/gorilla/websocket" "github.com/LiterMC/go-openbmclapi/api" "github.com/LiterMC/go-openbmclapi/utils" @@ -35,6 +36,7 @@ import ( type Handler struct { handler *utils.HttpMiddleWareHandler router *http.ServeMux + wsUpgrader *websocket.Upgrader config api.ConfigHandler users api.UserManager tokens api.TokenManager @@ -45,6 +47,7 @@ type Handler struct { var _ http.Handler = (*Handler)(nil) func NewHandler( + wsUpgrader *websocket.Upgrader, config api.ConfigHandler, users api.UserManager, tokenManager api.TokenManager, @@ -54,6 +57,7 @@ func NewHandler( h := &Handler{ router: mux, handler: utils.NewHttpMiddleWareHandler(mux), + wsUpgrader: wsUpgrader, config: config, users: users, tokens: tokenManager, diff --git a/cluster/cluster.go b/cluster/cluster.go index 835b5ff4..e437aebb 100644 --- a/cluster/cluster.go +++ b/cluster/cluster.go @@ -32,6 +32,7 @@ import ( "github.com/LiterMC/socket.io" + "github.com/LiterMC/go-openbmclapi/config" "github.com/LiterMC/go-openbmclapi/internal/build" "github.com/LiterMC/go-openbmclapi/log" "github.com/LiterMC/go-openbmclapi/storage" @@ -41,42 +42,34 @@ var ( reFileHashMismatchError = regexp.MustCompile(` hash mismatch, expected ([0-9a-f]+), got ([0-9a-f]+)`) ) -type ClusterOptions struct { - Id string `json:"id" yaml:"id"` - Secret string `json:"secret" yaml:"secret"` - PublicHosts []string `json:"public-hosts" yaml:"public-hosts"` - Prefix string `json:"prefix" yaml:"prefix"` -} - -type ClusterGeneralConfig struct { - Host string `json:"host" yaml:"host"` - Port uint16 `json:"port" yaml:"port"` - Byoc bool `json:"byoc" yaml:"byoc"` - NoFastEnable bool `json:"no-fast-enable" yaml:"no-fast-enable"` -} - type Cluster struct { - opts ClusterOptions - gcfg ClusterGeneralConfig + opts config.ClusterOptions + gcfg config.ClusterGeneralConfig storageManager *storage.Manager storages []int // the index of storages in the storage manager + statManager *StatManager enableSignals []chan bool disableSignal chan struct{} + hits atomic.Int32 + hbts atomic.Int64 - mux sync.RWMutex - status atomic.Int32 - socket *socket.Socket - client *http.Client + mux sync.RWMutex + status atomic.Int32 + socketStatus atomic.Int32 + socket *socket.Socket + client *http.Client + cachedCli *http.Client authTokenMux sync.RWMutex authToken *ClusterToken } func NewCluster( - opts ClusterOptions, gcfg ClusterGeneralConfig, + opts config.ClusterOptions, gcfg config.ClusterGeneralConfig, storageManager *storage.Manager, storages []int, + statManager *StatManager, ) (cr *Cluster) { cr = &Cluster{ opts: opts, @@ -84,8 +77,7 @@ func NewCluster( storageManager: storageManager, storages: storages, - - client: &http.Client{}, + statManager: statManager, } return } diff --git a/cluster/handler.go b/cluster/handler.go index 52c1ba16..40077429 100644 --- a/cluster/handler.go +++ b/cluster/handler.go @@ -40,7 +40,9 @@ func (cr *Cluster) HandleFile(req *http.Request, rw http.ResponseWriter, hash st return false } if sz >= 0 { - cr.AddHits(1, sz, opts.Id) + cr.hits.Add(1) + cr.hbts.Add(sz) + cr.statManager.AddHit(sz, cr.ID(), opts.Id) } return true }) { diff --git a/cluster/http.go b/cluster/http.go index 20045c71..f66f3f54 100644 --- a/cluster/http.go +++ b/cluster/http.go @@ -21,14 +21,65 @@ package cluster import ( "context" + "errors" "io" + "net" "net/http" "net/url" "path" + "github.com/gregjones/httpcache" + + gocache "github.com/LiterMC/go-openbmclapi/cache" "github.com/LiterMC/go-openbmclapi/internal/build" ) +type HTTPClient struct { + cli, cachedCli *http.Client +} + +func NewHTTPClient(dialer *net.Dialer, cache gocache.Cache) *HTTPClient { + transport := http.DefaultTransport + if dialer != nil { + transport = &http.Transport{ + DialContext: dialer.DialContext, + } + } + cachedTransport := transport + if cache != gocache.NoCache { + cachedTransport = &httpcache.Transport{ + Transport: transport, + Cache: gocache.WrapToHTTPCache(gocache.NewCacheWithNamespace(cache, "http@")), + } + } + return &HTTPClient{ + cli: &http.Client{ + Transport: transport, + CheckRedirect: redirectChecker, + }, + cachedCli: &http.Client{ + Transport: cachedTransport, + CheckRedirect: redirectChecker, + }, + } +} + +func (c *HTTPClient) Do(req *http.Request) (*http.Response, error) { + return c.cli.Do(req) +} + +func (c *HTTPClient) DoUseCache(req *http.Request) (*http.Response, error) { + return c.cachedCli.Do(req) +} + +func redirectChecker(req *http.Request, via []*http.Request) error { + req.Header.Del("Referer") + if len(via) > 10 { + return errors.New("More than 10 redirects detected") + } + return nil +} + func (cr *Cluster) makeReq(ctx context.Context, method string, relpath string, query url.Values) (req *http.Request, err error) { return cr.makeReqWithBody(ctx, method, relpath, query, nil) } diff --git a/cluster/keepalive.go b/cluster/keepalive.go index 2dee82a9..78fc2eef 100644 --- a/cluster/keepalive.go +++ b/cluster/keepalive.go @@ -50,10 +50,6 @@ func (cr *Cluster) KeepAlive(ctx context.Context) KeepAliveRes { Hits: hits, Bytes: hbts, }) - - if e := cr.stats.Save(cr.dataDir); e != nil { - log.TrErrorf("error.cluster.stat.save.failed", e) - } if err != nil { log.TrErrorf("error.cluster.keepalive.send.failed", err) return KeepAliveFailed @@ -80,9 +76,9 @@ func (cr *Cluster) KeepAlive(ctx context.Context) KeepAliveRes { log.TrErrorf("error.cluster.keepalive.failed", ero) return KeepAliveFailed } - log.TrInfof("info.cluster.keepalive.success", ahits, utils.BytesToUnit((float64)(ahbts)), data[1]) - cr.hits.Add(-hits2) - cr.hbts.Add(-hbts2) + log.TrInfof("info.cluster.keepalive.success", hits, utils.BytesToUnit((float64)(hbts)), data[1]) + cr.hits.Add(-hits) + cr.hbts.Add(-hbts) if data[1] == false { cr.markKicked() return KeepAliveKicked diff --git a/cluster/socket.go b/cluster/socket.go index 0eb1a8a7..7b2b78f1 100644 --- a/cluster/socket.go +++ b/cluster/socket.go @@ -21,10 +21,16 @@ package cluster import ( "context" + "errors" "fmt" + "net/http" + "time" "github.com/LiterMC/socket.io" "github.com/LiterMC/socket.io/engine.io" + + "github.com/LiterMC/go-openbmclapi/internal/build" + "github.com/LiterMC/go-openbmclapi/log" ) // Connect connects to the central server @@ -42,10 +48,10 @@ func (cr *Cluster) Connect(ctx context.Context) error { } engio, err := engine.NewSocket(engine.Options{ - Host: cr.prefix, + Host: cr.opts.Prefix, Path: "/socket.io/", ExtraHeaders: http.Header{ - "Origin": {cr.prefix}, + "Origin": {cr.opts.Prefix}, "User-Agent": {build.ClusterUserAgent}, }, DialTimeout: time.Minute * 6, @@ -62,23 +68,22 @@ func (cr *Cluster) Connect(ctx context.Context) error { }) } engio.OnConnect(func(s *engine.Socket) { - log.Info("Engine.IO %s connected for cluster %s", s.ID(), cr.Id()) + log.Info("Engine.IO %s connected for cluster %s", s.ID(), cr.ID()) }) engio.OnDisconnect(cr.onDisconnected) - engio.OnDialError(func(s *engine.Socket, err *DialErrorContext) { + engio.OnDialError(func(s *engine.Socket, err *engine.DialErrorContext) { if err.Count() < 0 { return } - log.TrErrorf("error.cluster.connect.failed", cr.Id(), err.Count(), config.MaxReconnectCount, err.Err()) - if config.MaxReconnectCount >= 0 && err.Count() >= config.MaxReconnectCount { - log.TrErrorf("error.cluster.connect.failed.toomuch", cr.Id()) + log.TrErrorf("error.cluster.connect.failed", cr.ID(), err.Count(), cr.gcfg.MaxReconnectCount, err.Err()) + if cr.gcfg.MaxReconnectCount >= 0 && err.Count() >= cr.gcfg.MaxReconnectCount { + log.TrErrorf("error.cluster.connect.failed.toomuch", cr.ID()) s.Close() } }) - log.Infof("Dialing %s for cluster %s", engio.URL().String(), cr.Id()) + log.Infof("Dialing %s for cluster %s", engio.URL().String(), cr.ID()) if err := engio.Dial(ctx); err != nil { - log.Errorf("Dial error: %v", err) - return false + return fmt.Errorf("Dial error: %w", err) } cr.socket = socket.NewSocket(engio, socket.WithAuthTokenFn(func() (string, error) { @@ -99,10 +104,9 @@ func (cr *Cluster) Connect(ctx context.Context) error { }) log.Info("Connecting to socket.io namespace") if err := cr.socket.Connect(""); err != nil { - log.Errorf("Namespace connect error: %v", err) - return false + return fmt.Errorf("Namespace connect error: %w", err) } - return true + return nil } // Disconnect close the connection which connected to the central server @@ -111,7 +115,7 @@ func (cr *Cluster) Connect(ctx context.Context) error { // See Connect func (cr *Cluster) Disconnect() error { if cr.Disconnected() { - return + return nil } cr.mux.Lock() defer cr.mux.Unlock() diff --git a/notify/stat.go b/cluster/stat.go similarity index 70% rename from notify/stat.go rename to cluster/stat.go index ae3a1d29..a3e50a2a 100644 --- a/notify/stat.go +++ b/cluster/stat.go @@ -17,7 +17,7 @@ * along with this program. If not, see . */ -package notify +package cluster import ( "encoding/json" @@ -27,10 +27,120 @@ import ( "strconv" "strings" "sync" - "sync/atomic" "time" ) +const statsOverallFileName = "stat.json" + +type StatManager struct { + mux sync.RWMutex + + Overall *StatData + Clusters map[string]*StatData + Storages map[string]*StatData +} + +func NewStatManager() *StatManager { + return &StatManager{ + Overall: new(StatData), + Clusters: make(map[string]*StatData), + Storages: make(map[string]*StatData), + } +} + +func (m *StatManager) AddHit(bytes int64, cluster, storage string) { + m.mux.Lock() + defer m.mux.Unlock() + + data := &statInstData{ + Hits: 1, + Bytes: bytes, + } + m.Overall.update(data) + if cluster != "" { + d := m.Clusters[cluster] + if d == nil { + d = NewStatData() + m.Clusters[cluster] = d + } + d.update(data) + } + if storage != "" { + d := m.Storages[storage] + if d == nil { + d = NewStatData() + m.Storages[storage] = d + } + d.update(data) + } +} + +func (m *StatManager) Load(dir string) error { + clustersDir, storagesDir := filepath.Join(dir, "clusters"), filepath.Join(dir, "storages") + + *m.Overall = StatData{} + clear(m.Clusters) + clear(m.Storages) + + if err := m.Overall.load(filepath.Join(dir, statsOverallFileName)); err != nil { + return err + } + if entries, err := os.ReadDir(clustersDir); err == nil { + for _, entry := range entries { + if entry.IsDir() { + continue + } + if name, ok := strings.CutSuffix(entry.Name(), ".json"); ok { + d := new(StatData) + if err := d.load(filepath.Join(clustersDir, entry.Name())); err != nil { + return err + } + m.Clusters[name] = d + } + } + } + if entries, err := os.ReadDir(storagesDir); err == nil { + for _, entry := range entries { + if entry.IsDir() { + continue + } + if name, ok := strings.CutSuffix(entry.Name(), ".json"); ok { + d := new(StatData) + if err := d.load(filepath.Join(storagesDir, entry.Name())); err != nil { + return err + } + m.Storages[name] = d + } + } + } + return nil +} + +func (m *StatManager) Save(dir string) error { + clustersDir, storagesDir := filepath.Join(dir, "clusters"), filepath.Join(dir, "storages") + + if err := m.Overall.save(filepath.Join(dir, statsOverallFileName)); err != nil { + return err + } + if err := os.Mkdir(clustersDir, 0755); err != nil && !errors.Is(err, os.ErrExist) { + return err + } + if err := os.Mkdir(storagesDir, 0755); err != nil && !errors.Is(err, os.ErrExist) { + return err + } + for name, data := range m.Clusters { + if err := data.save(filepath.Join(clustersDir, name+".json")); err != nil { + return err + } + } + for name, data := range m.Storages { + if err := data.save(filepath.Join(storagesDir, name+".json")); err != nil { + return err + } + } + return nil +} + type statInstData struct { Hits int32 `json:"hits"` Bytes int64 `json:"bytes"` @@ -84,6 +194,12 @@ type StatData struct { Accesses map[string]int `json:"accesses"` } +func NewStatData() *StatData { + return &StatData{ + Years: make(map[string]statInstData, 2), + Accesses: make(map[string]int, 5), + } +} func (d *StatData) Clone() *StatData { cloned := new(StatData) *cloned = *d @@ -216,44 +332,11 @@ func (d *StatData) update(newData *statInstData) { d.Date = now } -type Stats struct { - sync.RWMutex - StatData - - subStat map[string]*StatData - - hits atomic.Int32 - bts atomic.Int64 -} - -const statsDirName = "stats" -const statsFileName = "stat.json" - -func (s *Stats) Clone() *StatData { - s.RLock() - defer s.RUnlock() - return s.StatData.Clone() -} - -func (s *Stats) MarshalJSON() ([]byte, error) { - s.RLock() - defer s.RUnlock() - - return json.Marshal(&s.StatData) -} - -func (s *Stats) MarshalSubStat(name string) ([]byte, error) { - s.RLock() - defer s.RUnlock() - - return json.Marshal(s.subStat[name]) -} - -func (s *StatData) load(name string) (err error) { - if err = parseFileOrOld(name, func(buf []byte) error { +func (s *StatData) load(name string) error { + if err := parseFileOrOld(name, func(buf []byte) error { return json.Unmarshal(buf, s) }); err != nil { - return + return err } if s.Years == nil { @@ -262,91 +345,18 @@ func (s *StatData) load(name string) (err error) { if s.Accesses == nil { s.Accesses = make(map[string]int, 5) } - return -} - -func (s *Stats) Load(dir string) (err error) { - s.Lock() - defer s.Unlock() - - if err = s.StatData.load(filepath.Join(dir, statsFileName)); err != nil { - return - } - s.subStat = make(map[string]*StatData) - - if entries, err := os.ReadDir(filepath.Join(dir, statsDirName)); err == nil { - for _, entry := range entries { - if entry.IsDir() { - continue - } - if name, ok := strings.CutSuffix(entry.Name(), ".json"); ok { - data := new(StatData) - if err := data.load(filepath.Join(dir, statsDirName, entry.Name())); err != nil { - return err - } - s.subStat[name] = data - } - } - } - return + return nil } -// Save -func (s *Stats) Save(dir string) (err error) { - s.RLock() - defer s.RUnlock() - - var buf []byte - if buf, err = json.Marshal(&s.StatData); err != nil { - return - } - if err = writeFileWithOld(filepath.Join(dir, statsFileName), buf, 0644); err != nil { - return - } - - if err := os.Mkdir(filepath.Join(dir, statsDirName), 0755); err != nil && !errors.Is(err, os.ErrExist) { +func (s *StatData) save(name string) error { + buf, err := json.Marshal(s) + if err != nil { return err } - for name, data := range s.subStat { - if buf, err = json.Marshal(data); err != nil { - return - } - if err = writeFileWithOld(filepath.Join(dir, statsDirName, name+".json"), buf, 0644); err != nil { - return - } - } - return -} - -func (s *Stats) GetTmpHits() (hits int32, bts int64) { - return s.hits.Load(), s.bts.Load() -} - -func (s *Stats) AddHits(hits int32, bytes int64, name string) { - s.hits.Add(hits) - s.bts.Add(bytes) - - s.Lock() - defer s.Unlock() - - data := &statInstData{ - Hits: hits, - Bytes: bytes, - } - s.update(data) - if name != "" { - ss := s.subStat[name] - if ss == nil { - ss = new(StatData) - ss.Years = make(map[string]statInstData, 2) - ss.Accesses = make(map[string]int, 5) - if s.subStat == nil { - s.subStat = make(map[string]*StatData) - } - s.subStat[name] = ss - } - ss.update(data) + if err := writeFileWithOld(name, buf, 0644); err != nil { + return err } + return nil } func parseFileOrOld(path string, parser func(buf []byte) error) error { diff --git a/config.go b/config.go index 03e31745..43298f58 100644 --- a/config.go +++ b/config.go @@ -1,6 +1,6 @@ /** * OpenBmclAPI (Golang Edition) - * Copyright (C) 2023 Kevin Z + * Copyright (C) 2024 Kevin Z * All rights reserved * * This program is free software: you can redistribute it and/or modify @@ -21,329 +21,32 @@ package main import ( "bytes" - "errors" - "fmt" - "net/url" - "os" - "path/filepath" - "regexp" - "strconv" - "strings" - "time" "gopkg.in/yaml.v3" - "github.com/LiterMC/go-openbmclapi/cache" - "github.com/LiterMC/go-openbmclapi/limited" - "github.com/LiterMC/go-openbmclapi/log" - "github.com/LiterMC/go-openbmclapi/storage" - "github.com/LiterMC/go-openbmclapi/utils" + "github.com/LiterMC/go-openbmclapi/config" ) -type UserItem struct { - Username string `yaml:"username"` - Password string `yaml:"password"` -} - -type AdvancedConfig struct { - DebugLog bool `yaml:"debug-log"` - SocketIOLog bool `yaml:"socket-io-log"` - NoHeavyCheck bool `yaml:"no-heavy-check"` - NoGC bool `yaml:"no-gc"` - HeavyCheckInterval int `yaml:"heavy-check-interval"` - KeepaliveTimeout int `yaml:"keepalive-timeout"` - SkipSignatureCheck bool `yaml:"skip-signature-check"` - NoFastEnable bool `yaml:"no-fast-enable"` - WaitBeforeEnable int `yaml:"wait-before-enable"` - - DoNotRedirectHTTPSToSecureHostname bool `yaml:"do-NOT-redirect-https-to-SECURE-hostname"` - DoNotOpenFAQOnWindows bool `yaml:"do-not-open-faq-on-windows"` -} - -type CertificateConfig struct { - Cert string `yaml:"cert"` - Key string `yaml:"key"` -} - -type ServeLimitConfig struct { - Enable bool `yaml:"enable"` - MaxConn int `yaml:"max-conn"` - UploadRate int `yaml:"upload-rate"` -} - -type APIRateLimitConfig struct { - Anonymous limited.RateLimit `yaml:"anonymous"` - Logged limited.RateLimit `yaml:"logged"` -} - -type NotificationConfig struct { - EnableEmail bool `yaml:"enable-email"` - EmailSMTP string `yaml:"email-smtp"` - EmailSMTPEncryption string `yaml:"email-smtp-encryption"` - EmailSender string `yaml:"email-sender"` - EmailSenderPassword string `yaml:"email-sender-password"` - EnableWebhook bool `yaml:"enable-webhook"` -} - -type DatabaseConfig struct { - Driver string `yaml:"driver"` - DSN string `yaml:"data-source-name"` -} - -type HijackConfig struct { - Enable bool `yaml:"enable"` - EnableLocalCache bool `yaml:"enable-local-cache"` - LocalCachePath string `yaml:"local-cache-path"` - RequireAuth bool `yaml:"require-auth"` - AuthUsers []UserItem `yaml:"auth-users"` -} - -type CacheConfig struct { - Type string `yaml:"type"` - Data any `yaml:"data,omitempty"` - - newCache func() cache.Cache `yaml:"-"` -} - -func (c *CacheConfig) UnmarshalYAML(n *yaml.Node) (err error) { - var cfg struct { - Type string `yaml:"type"` - Data utils.RawYAML `yaml:"data,omitempty"` - } - if err = n.Decode(&cfg); err != nil { - return - } - c.Type = cfg.Type - c.Data = nil - switch strings.ToLower(c.Type) { - case "no", "off", "disabled", "nocache", "no-cache": - c.newCache = func() cache.Cache { return cache.NoCache } - case "mem", "memory", "inmem": - c.newCache = func() cache.Cache { return cache.NewInMemCache() } - case "redis": - opt := new(cache.RedisOptions) - if err = cfg.Data.Decode(opt); err != nil { - return - } - c.Data = opt - c.newCache = func() cache.Cache { return cache.NewRedisCache(opt.ToRedis()) } - default: - return fmt.Errorf("Unexpected cache type %q", c.Type) - } - return nil -} - -type GithubAPIConfig struct { - UpdateCheckInterval utils.YAMLDuration `yaml:"update-check-interval"` - Authorization string `yaml:"authorization"` -} - -type DashboardConfig struct { - Enable bool `yaml:"enable"` - Username string `yaml:"username"` - Password string `yaml:"password"` - PwaName string `yaml:"pwa-name"` - PwaShortName string `yaml:"pwa-short_name"` - PwaDesc string `yaml:"pwa-description"` - - NotifySubject string `yaml:"notification-subject"` -} - -type TunnelConfig struct { - Enable bool `yaml:"enable"` - TunnelProg string `yaml:"tunnel-program"` - OutputRegex string `yaml:"output-regex"` - TunnelTimeout int `yaml:"tunnel-timeout"` - - outputRegex *regexp.Regexp - hostOut int - portOut int -} - -func (c *TunnelConfig) UnmarshalYAML(n *yaml.Node) (err error) { - type T TunnelConfig - if err = n.Decode((*T)(c)); err != nil { - return - } - if !c.Enable { - return - } - if c.outputRegex, err = regexp.Compile(c.OutputRegex); err != nil { - return - } - c.hostOut = c.outputRegex.SubexpIndex("host") - c.portOut = c.outputRegex.SubexpIndex("port") - if c.hostOut <= 0 { - return errors.New("tunneler.output-regex: missing named `(?)` capture group") - } - if c.portOut <= 0 { - return errors.New("tunneler.output-regex: missing named `(?)` capture group") - } - return -} - -type Config struct { - LogSlots int `yaml:"log-slots"` - NoAccessLog bool `yaml:"no-access-log"` - AccessLogSlots int `yaml:"access-log-slots"` - Byoc bool `yaml:"byoc"` - UseCert bool `yaml:"use-cert"` - TrustedXForwardedFor bool `yaml:"trusted-x-forwarded-for"` - PublicHost string `yaml:"public-host"` - PublicPort uint16 `yaml:"public-port"` - Port uint16 `yaml:"port"` - SyncInterval int `yaml:"sync-interval"` - OnlyGcWhenStart bool `yaml:"only-gc-when-start"` - DownloadMaxConn int `yaml:"download-max-conn"` - MaxReconnectCount int `yaml:"max-reconnect-count"` - - Clusters map[string]ClusterItem `yaml:"clusters"` - Certificates []CertificateConfig `yaml:"certificates"` - Tunneler TunnelConfig `yaml:"tunneler"` - Cache CacheConfig `yaml:"cache"` - ServeLimit ServeLimitConfig `yaml:"serve-limit"` - RateLimit APIRateLimitConfig `yaml:"api-rate-limit"` - Notification NotificationConfig `yaml:"notification"` - Dashboard DashboardConfig `yaml:"dashboard"` - GithubAPI GithubAPIConfig `yaml:"github-api"` - Database DatabaseConfig `yaml:"database"` - Hijack HijackConfig `yaml:"hijack"` - Storages []storage.StorageOption `yaml:"storages"` - WebdavUsers map[string]*storage.WebDavUser `yaml:"webdav-users"` - Advanced AdvancedConfig `yaml:"advanced"` -} - -func (cfg *Config) applyWebManifest(manifest map[string]any) { - if cfg.Dashboard.Enable { - manifest["name"] = cfg.Dashboard.PwaName - manifest["short_name"] = cfg.Dashboard.PwaShortName - manifest["description"] = cfg.Dashboard.PwaDesc - } -} - -func getDefaultConfig() *Config { - return &Config{ - LogSlots: 7, - NoAccessLog: false, - AccessLogSlots: 16, - Byoc: false, - TrustedXForwardedFor: false, - PublicHost: "", - PublicPort: 0, - Port: 4000, - SyncInterval: 10, - OnlyGcWhenStart: false, - DownloadMaxConn: 16, - MaxReconnectCount: 10, - - Clusters: map[string]ClusterItem{}, - - Certificates: []CertificateConfig{}, - - Tunneler: TunnelConfig{ - Enable: false, - TunnelProg: "./path/to/tunnel/program", - OutputRegex: `\bNATedAddr\s+(?P[0-9.]+|\[[0-9a-f:]+\]):(?P\d+)$`, - TunnelTimeout: 0, - }, - - Cache: CacheConfig{ - Type: "inmem", - newCache: func() cache.Cache { return cache.NewInMemCache() }, - }, - - ServeLimit: ServeLimitConfig{ - Enable: false, - MaxConn: 16384, - UploadRate: 1024 * 12, // 12MB - }, - - RateLimit: APIRateLimitConfig{ - Anonymous: limited.RateLimit{ - PerMin: 10, - PerHour: 120, - }, - Logged: limited.RateLimit{ - PerMin: 120, - PerHour: 6000, - }, - }, - - Notification: NotificationConfig{ - EnableEmail: false, - EmailSMTP: "smtp.example.com:25", - EmailSMTPEncryption: "tls", - EmailSender: "noreply@example.com", - EmailSenderPassword: "example-password", - EnableWebhook: true, - }, - - Dashboard: DashboardConfig{ - Enable: true, - PwaName: "GoOpenBmclApi Dashboard", - PwaShortName: "GOBA Dash", - PwaDesc: "Go-Openbmclapi Internal Dashboard", - NotifySubject: "mailto:user@example.com", - }, - - GithubAPI: GithubAPIConfig{ - UpdateCheckInterval: (utils.YAMLDuration)(time.Hour), - }, - - Database: DatabaseConfig{ - Driver: "sqlite", - DSN: filepath.Join("data", "files.db"), - }, - - Hijack: HijackConfig{ - Enable: false, - RequireAuth: false, - EnableLocalCache: false, - LocalCachePath: "hijack_cache", - AuthUsers: []UserItem{ - { - Username: "example-username", - Password: "example-password", - }, - }, - }, - - Storages: nil, - - WebdavUsers: map[string]*storage.WebDavUser{}, - - Advanced: AdvancedConfig{ - DebugLog: false, - NoHeavyCheck: false, - NoGC: false, - HeavyCheckInterval: 120, - KeepaliveTimeout: 10, - NoFastEnable: false, - WaitBeforeEnable: 0, - }, - } -} - -func migrateConfig(data []byte, config *Config) { +func migrateConfig(data []byte, cfg *config.Config) { var oldConfig map[string]any if err := yaml.Unmarshal(data, &oldConfig); err != nil { return } if v, ok := oldConfig["debug"].(bool); ok { - config.Advanced.DebugLog = v + cfg.Advanced.DebugLog = v } if v, ok := oldConfig["no-heavy-check"].(bool); ok { - config.Advanced.NoHeavyCheck = v + cfg.Advanced.NoHeavyCheck = v } if v, ok := oldConfig["keepalive-timeout"].(int); ok { - config.Advanced.KeepaliveTimeout = v + cfg.Advanced.KeepaliveTimeout = v } if oldConfig["clusters"].(map[string]any) == nil { id, ok1 := oldConfig["cluster-id"].(string) secret, ok2 := oldConfig["cluster-secret"].(string) if ok1 && ok2 { - config.Clusters = map[string]ClusterItem{ + cfg.Clusters = map[string]ClusterItem{ "main": { Id: id, Secret: secret, @@ -353,44 +56,43 @@ func migrateConfig(data []byte, config *Config) { } } -func readConfig() (config Config, err error) { +func readAndRewriteConfig() (cfg *config.Config, err error) { const configPath = "config.yaml" - config = getDefaultConfig() - + cfg = config.NewDefaultConfig() data, err := os.ReadFile(configPath) notexists := false if err != nil { if !errors.Is(err, os.ErrNotExist) { - log.Errorf(Tr("error.config.read.failed"), err) - osExit(CodeClientError) + log.TrErrorf("error.config.read.failed", err) + os.Exit(1) } - log.Error(Tr("error.config.not.exists")) + log.TrError("error.config.not.exists") notexists = true } else { - migrateConfig(data, config) - if err = yaml.Unmarshal(data, config); err != nil { - log.Errorf(Tr("error.config.parse.failed"), err) - osExit(CodeClientError) + migrateConfig(data, cfg) + if err = cfg.UnmarshalText(data); err != nil { + log.TrErrorf("error.config.parse.failed", err) + os.Exit(1) } - if len(config.Clusters) == 0 { - config.Clusters = map[string]ClusterItem{ + if len(cfg.Clusters) == 0 { + cfg.Clusters = map[string]ClusterItem{ "main": { Id: "${CLUSTER_ID}", Secret: "${CLUSTER_SECRET}", }, } } - if len(config.Certificates) == 0 { - config.Certificates = []CertificateConfig{ + if len(cfg.Certificates) == 0 { + cfg.Certificates = []CertificateConfig{ { Cert: "/path/to/cert.pem", Key: "/path/to/key.pem", }, } } - if len(config.Storages) == 0 { - config.Storages = []storage.StorageOption{ + if len(cfg.Storages) == 0 { + cfg.Storages = []storage.StorageOption{ { BasicStorageOption: storage.BasicStorageOption{ Id: "local", @@ -403,41 +105,41 @@ func readConfig() (config Config, err error) { }, } } - if len(config.WebdavUsers) == 0 { - config.WebdavUsers["example-user"] = &storage.WebDavUser{ + if len(cfg.WebdavUsers) == 0 { + cfg.WebdavUsers["example-user"] = &storage.WebDavUser{ EndPoint: "https://webdav.example.com/path/to/endpoint/", Username: "example-username", Password: "example-password", } } - ids := make(map[string]int, len(config.Storages)) - for i, s := range config.Storages { + ids := make(map[string]int, len(cfg.Storages)) + for i, s := range cfg.Storages { if s.Id == "" { s.Id = fmt.Sprintf("storage-%d", i) - config.Storages[i].Id = s.Id + cfg.Storages[i].Id = s.Id } if j, ok := ids[s.Id]; ok { log.Errorf("Duplicated storage id %q at [%d] and [%d], please edit the config.", s.Id, i, j) - os.Exit(CodeClientError) + os.Exit(1) } ids[s.Id] = i if s.Cluster != "" && s.Cluster != "-" { - if _, ok := config.Clusters[s.Cluster]; !ok { + if _, ok := cfg.Clusters[s.Cluster]; !ok { log.Errorf("Storage %q is trying to connect to a not exists cluster %q.", s.Id, s.Cluster) - os.Exit(CodeClientError) + os.Exit(1) } } } } - for _, so := range config.Storages { + for _, so := range cfg.Storages { switch opt := so.Data.(type) { case *storage.WebDavStorageOption: if alias := opt.Alias; alias != "" { - user, ok := config.WebdavUsers[alias] + user, ok := cfg.WebdavUsers[alias] if !ok { - log.Errorf(Tr("error.config.alias.user.not.exists"), alias) - os.Exit(CodeClientError) + log.TrErrorf("error.config.alias.user.not.exists", alias) + os.Exit(1) } opt.AliasUser = user var end *url.URL @@ -462,16 +164,16 @@ func readConfig() (config Config, err error) { var buf bytes.Buffer encoder := yaml.NewEncoder(&buf) encoder.SetIndent(2) - if err = encoder.Encode(config); err != nil { - log.Errorf(Tr("error.config.encode.failed"), err) - os.Exit(CodeClientError) + if err = encoder.Encode(cfg); err != nil { + log.TrErrorf("error.config.encode.failed", err) + os.Exit(1) } if err = os.WriteFile(configPath, buf.Bytes(), 0600); err != nil { - log.Errorf(Tr("error.config.write.failed"), err) - os.Exit(CodeClientError) + log.TrErrorf("error.config.write.failed", err) + os.Exit(1) } if notexists { - log.Error(Tr("error.config.created")) + log.TrError("error.config.created") } return } diff --git a/config/advanced.go b/config/advanced.go new file mode 100644 index 00000000..8d3c9805 --- /dev/null +++ b/config/advanced.go @@ -0,0 +1,35 @@ +/** + * OpenBmclAPI (Golang Edition) + * Copyright (C) 2024 Kevin Z + * All rights reserved + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published + * by the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package config + +type AdvancedConfig struct { + DebugLog bool `yaml:"debug-log"` + SocketIOLog bool `yaml:"socket-io-log"` + NoHeavyCheck bool `yaml:"no-heavy-check"` + NoGC bool `yaml:"no-gc"` + HeavyCheckInterval int `yaml:"heavy-check-interval"` + KeepaliveTimeout int `yaml:"keepalive-timeout"` + SkipSignatureCheck bool `yaml:"skip-signature-check"` + NoFastEnable bool `yaml:"no-fast-enable"` + WaitBeforeEnable int `yaml:"wait-before-enable"` + + DoNotRedirectHTTPSToSecureHostname bool `yaml:"do-NOT-redirect-https-to-SECURE-hostname"` + DoNotOpenFAQOnWindows bool `yaml:"do-not-open-faq-on-windows"` +} diff --git a/config/config.go b/config/config.go new file mode 100644 index 00000000..747b961b --- /dev/null +++ b/config/config.go @@ -0,0 +1,182 @@ +/** + * OpenBmclAPI (Golang Edition) + * Copyright (C) 2024 Kevin Z + * All rights reserved + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published + * by the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package config + +import ( + "path/filepath" + "time" + + "gopkg.in/yaml.v3" + + "github.com/LiterMC/go-openbmclapi/cache" + "github.com/LiterMC/go-openbmclapi/limited" + "github.com/LiterMC/go-openbmclapi/storage" + "github.com/LiterMC/go-openbmclapi/utils" +) + +type Config struct { + PublicHost string `yaml:"public-host"` + PublicPort uint16 `yaml:"public-port"` + Port uint16 `yaml:"port"` + Byoc bool `yaml:"byoc"` + UseCert bool `yaml:"use-cert"` + TrustedXForwardedFor bool `yaml:"trusted-x-forwarded-for"` + + OnlyGcWhenStart bool `yaml:"only-gc-when-start"` + SyncInterval int `yaml:"sync-interval"` + DownloadMaxConn int `yaml:"download-max-conn"` + MaxReconnectCount int `yaml:"max-reconnect-count"` + + LogSlots int `yaml:"log-slots"` + NoAccessLog bool `yaml:"no-access-log"` + AccessLogSlots int `yaml:"access-log-slots"` + + Clusters map[string]ClusterOptions `yaml:"clusters"` + Certificates []CertificateConfig `yaml:"certificates"` + Tunneler TunnelConfig `yaml:"tunneler"` + Cache CacheConfig `yaml:"cache"` + ServeLimit ServeLimitConfig `yaml:"serve-limit"` + RateLimit APIRateLimitConfig `yaml:"api-rate-limit"` + Notification NotificationConfig `yaml:"notification"` + Dashboard DashboardConfig `yaml:"dashboard"` + GithubAPI GithubAPIConfig `yaml:"github-api"` + Database DatabaseConfig `yaml:"database"` + Hijack HijackConfig `yaml:"hijack"` + Storages []storage.StorageOption `yaml:"storages"` + WebdavUsers map[string]*storage.WebDavUser `yaml:"webdav-users"` + Advanced AdvancedConfig `yaml:"advanced"` +} + +func (cfg *Config) applyWebManifest(manifest map[string]any) { + if cfg.Dashboard.Enable { + manifest["name"] = cfg.Dashboard.PwaName + manifest["short_name"] = cfg.Dashboard.PwaShortName + manifest["description"] = cfg.Dashboard.PwaDesc + } +} + +func NewDefaultConfig() *Config { + return &Config{ + PublicHost: "", + PublicPort: 0, + Port: 4000, + Byoc: false, + TrustedXForwardedFor: false, + + OnlyGcWhenStart: false, + SyncInterval: 10, + DownloadMaxConn: 16, + MaxReconnectCount: 10, + + LogSlots: 7, + NoAccessLog: false, + AccessLogSlots: 16, + + Clusters: map[string]ClusterOptions{}, + + Certificates: []CertificateConfig{}, + + Tunneler: TunnelConfig{ + Enable: false, + TunnelProg: "./path/to/tunnel/program", + OutputRegex: `\bNATedAddr\s+(?P[0-9.]+|\[[0-9a-f:]+\]):(?P\d+)$`, + TunnelTimeout: 0, + }, + + Cache: CacheConfig{ + Type: "inmem", + newCache: func() cache.Cache { return cache.NewInMemCache() }, + }, + + ServeLimit: ServeLimitConfig{ + Enable: false, + MaxConn: 16384, + UploadRate: 1024 * 12, // 12MB + }, + + RateLimit: APIRateLimitConfig{ + Anonymous: limited.RateLimit{ + PerMin: 10, + PerHour: 120, + }, + Logged: limited.RateLimit{ + PerMin: 120, + PerHour: 6000, + }, + }, + + Notification: NotificationConfig{ + EnableEmail: false, + EmailSMTP: "smtp.example.com:25", + EmailSMTPEncryption: "tls", + EmailSender: "noreply@example.com", + EmailSenderPassword: "example-password", + EnableWebhook: true, + }, + + Dashboard: DashboardConfig{ + Enable: true, + PwaName: "GoOpenBmclApi Dashboard", + PwaShortName: "GOBA Dash", + PwaDesc: "Go-Openbmclapi Internal Dashboard", + NotifySubject: "mailto:user@example.com", + }, + + GithubAPI: GithubAPIConfig{ + UpdateCheckInterval: (utils.YAMLDuration)(time.Hour), + }, + + Database: DatabaseConfig{ + Driver: "sqlite", + DSN: filepath.Join("data", "files.db"), + }, + + Hijack: HijackConfig{ + Enable: false, + RequireAuth: false, + EnableLocalCache: false, + LocalCachePath: "hijack_cache", + AuthUsers: []UserItem{ + { + Username: "example-username", + Password: "example-password", + }, + }, + }, + + Storages: nil, + + WebdavUsers: map[string]*storage.WebDavUser{}, + + Advanced: AdvancedConfig{ + DebugLog: false, + NoHeavyCheck: false, + NoGC: false, + HeavyCheckInterval: 120, + KeepaliveTimeout: 10, + NoFastEnable: false, + WaitBeforeEnable: 0, + }, + } +} + +func (config *Config) UnmarshalText(data []byte) error { + return yaml.Unmarshal(data, config) +} diff --git a/config/dashboard.go b/config/dashboard.go new file mode 100644 index 00000000..3fb954b6 --- /dev/null +++ b/config/dashboard.go @@ -0,0 +1,49 @@ +/** + * OpenBmclAPI (Golang Edition) + * Copyright (C) 2024 Kevin Z + * All rights reserved + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published + * by the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package config + +import ( + "github.com/LiterMC/go-openbmclapi/limited" +) + +type APIRateLimitConfig struct { + Anonymous limited.RateLimit `yaml:"anonymous"` + Logged limited.RateLimit `yaml:"logged"` +} + +type NotificationConfig struct { + EnableEmail bool `yaml:"enable-email"` + EmailSMTP string `yaml:"email-smtp"` + EmailSMTPEncryption string `yaml:"email-smtp-encryption"` + EmailSender string `yaml:"email-sender"` + EmailSenderPassword string `yaml:"email-sender-password"` + EnableWebhook bool `yaml:"enable-webhook"` +} + +type DashboardConfig struct { + Enable bool `yaml:"enable"` + Username string `yaml:"username"` + Password string `yaml:"password"` + PwaName string `yaml:"pwa-name"` + PwaShortName string `yaml:"pwa-short_name"` + PwaDesc string `yaml:"pwa-description"` + + NotifySubject string `yaml:"notification-subject"` +} diff --git a/config/server.go b/config/server.go new file mode 100644 index 00000000..d1b43046 --- /dev/null +++ b/config/server.go @@ -0,0 +1,149 @@ +/** + * OpenBmclAPI (Golang Edition) + * Copyright (C) 2024 Kevin Z + * All rights reserved + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published + * by the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package config + +import ( + "errors" + "fmt" + "regexp" + "strings" + + "gopkg.in/yaml.v3" + + "github.com/LiterMC/go-openbmclapi/cache" + "github.com/LiterMC/go-openbmclapi/utils" +) + +type ClusterOptions struct { + Id string `json:"id" yaml:"id"` + Secret string `json:"secret" yaml:"secret"` + PublicHosts []string `json:"public-hosts" yaml:"public-hosts"` + Prefix string `json:"prefix" yaml:"prefix"` +} + +type ClusterGeneralConfig struct { + Host string `json:"host"` + Port uint16 `json:"port"` + Byoc bool `json:"byoc"` + NoFastEnable bool `json:"no-fast-enable"` + MaxReconnectCount int `json:"max-reconnect-count"` +} + +type UserItem struct { + Username string `yaml:"username"` + Password string `yaml:"password"` +} + +type CertificateConfig struct { + Cert string `yaml:"cert"` + Key string `yaml:"key"` +} + +type DatabaseConfig struct { + Driver string `yaml:"driver"` + DSN string `yaml:"data-source-name"` +} + +type HijackConfig struct { + Enable bool `yaml:"enable"` + EnableLocalCache bool `yaml:"enable-local-cache"` + LocalCachePath string `yaml:"local-cache-path"` + RequireAuth bool `yaml:"require-auth"` + AuthUsers []UserItem `yaml:"auth-users"` +} + +type CacheConfig struct { + Type string `yaml:"type"` + Data any `yaml:"data,omitempty"` + + newCache func() cache.Cache `yaml:"-"` +} + +func (c *CacheConfig) UnmarshalYAML(n *yaml.Node) (err error) { + var cfg struct { + Type string `yaml:"type"` + Data utils.RawYAML `yaml:"data,omitempty"` + } + if err = n.Decode(&cfg); err != nil { + return + } + c.Type = cfg.Type + c.Data = nil + switch strings.ToLower(c.Type) { + case "no", "off", "disabled", "nocache", "no-cache": + c.newCache = func() cache.Cache { return cache.NoCache } + case "mem", "memory", "inmem": + c.newCache = func() cache.Cache { return cache.NewInMemCache() } + case "redis": + opt := new(cache.RedisOptions) + if err = cfg.Data.Decode(opt); err != nil { + return + } + c.Data = opt + c.newCache = func() cache.Cache { return cache.NewRedisCache(opt.ToRedis()) } + default: + return fmt.Errorf("Unexpected cache type %q", c.Type) + } + return nil +} + +type ServeLimitConfig struct { + Enable bool `yaml:"enable"` + MaxConn int `yaml:"max-conn"` + UploadRate int `yaml:"upload-rate"` +} + +type GithubAPIConfig struct { + UpdateCheckInterval utils.YAMLDuration `yaml:"update-check-interval"` + Authorization string `yaml:"authorization"` +} + +type TunnelConfig struct { + Enable bool `yaml:"enable"` + TunnelProg string `yaml:"tunnel-program"` + OutputRegex string `yaml:"output-regex"` + TunnelTimeout int `yaml:"tunnel-timeout"` + + outputRegex *regexp.Regexp + hostOut int + portOut int +} + +func (c *TunnelConfig) UnmarshalYAML(n *yaml.Node) (err error) { + type T TunnelConfig + if err = n.Decode((*T)(c)); err != nil { + return + } + if !c.Enable { + return + } + if c.outputRegex, err = regexp.Compile(c.OutputRegex); err != nil { + return + } + c.hostOut = c.outputRegex.SubexpIndex("host") + c.portOut = c.outputRegex.SubexpIndex("port") + if c.hostOut <= 0 { + return errors.New("tunneler.output-regex: missing named `(?)` capture group") + } + if c.portOut <= 0 { + return errors.New("tunneler.output-regex: missing named `(?)` capture group") + } + return +} diff --git a/database/db.go b/database/db.go index 2f42c46a..594b7775 100644 --- a/database/db.go +++ b/database/db.go @@ -20,16 +20,12 @@ package database import ( - "database/sql" - "database/sql/driver" - "encoding/json" "errors" - "fmt" "time" "github.com/google/uuid" - "github.com/LiterMC/go-openbmclapi/utils" + "github.com/LiterMC/go-openbmclapi/api" ) var ( @@ -55,27 +51,27 @@ type DB interface { // the callback should not edit the record pointer ForEachFileRecord(cb func(*FileRecord) error) error - GetSubscribe(user string, client string) (*SubscribeRecord, error) - SetSubscribe(SubscribeRecord) error + GetSubscribe(user string, client string) (*api.SubscribeRecord, error) + SetSubscribe(api.SubscribeRecord) error RemoveSubscribe(user string, client string) error - ForEachSubscribe(cb func(*SubscribeRecord) error) error + ForEachSubscribe(cb func(*api.SubscribeRecord) error) error - GetEmailSubscription(user string, addr string) (*EmailSubscriptionRecord, error) - AddEmailSubscription(EmailSubscriptionRecord) error - UpdateEmailSubscription(EmailSubscriptionRecord) error + GetEmailSubscription(user string, addr string) (*api.EmailSubscriptionRecord, error) + AddEmailSubscription(api.EmailSubscriptionRecord) error + UpdateEmailSubscription(api.EmailSubscriptionRecord) error RemoveEmailSubscription(user string, addr string) error - ForEachEmailSubscription(cb func(*EmailSubscriptionRecord) error) error - ForEachUsersEmailSubscription(user string, cb func(*EmailSubscriptionRecord) error) error - ForEachEnabledEmailSubscription(cb func(*EmailSubscriptionRecord) error) error + ForEachEmailSubscription(cb func(*api.EmailSubscriptionRecord) error) error + ForEachUsersEmailSubscription(user string, cb func(*api.EmailSubscriptionRecord) error) error + ForEachEnabledEmailSubscription(cb func(*api.EmailSubscriptionRecord) error) error - GetWebhook(user string, id uuid.UUID) (*WebhookRecord, error) - AddWebhook(WebhookRecord) error - UpdateWebhook(WebhookRecord) error + GetWebhook(user string, id uuid.UUID) (*api.WebhookRecord, error) + AddWebhook(api.WebhookRecord) error + UpdateWebhook(api.WebhookRecord) error UpdateEnableWebhook(user string, id uuid.UUID, enabled bool) error RemoveWebhook(user string, id uuid.UUID) error - ForEachWebhook(cb func(*WebhookRecord) error) error - ForEachUsersWebhook(user string, cb func(*WebhookRecord) error) error - ForEachEnabledWebhook(cb func(*WebhookRecord) error) error + ForEachWebhook(cb func(*api.WebhookRecord) error) error + ForEachUsersWebhook(user string, cb func(*api.WebhookRecord) error) error + ForEachEnabledWebhook(cb func(*api.WebhookRecord) error) error } type FileRecord struct { diff --git a/database/memory.go b/database/memory.go index 3b7462c1..8169cdc6 100644 --- a/database/memory.go +++ b/database/memory.go @@ -25,6 +25,7 @@ import ( "github.com/google/uuid" + "github.com/LiterMC/go-openbmclapi/api" "github.com/LiterMC/go-openbmclapi/utils" ) @@ -41,13 +42,13 @@ type MemoryDB struct { tokens map[string]time.Time subscribeMux sync.RWMutex - subscribeRecords map[[2]string]*SubscribeRecord + subscribeRecords map[[2]string]*api.SubscribeRecord emailSubMux sync.RWMutex - emailSubRecords map[[2]string]*EmailSubscriptionRecord + emailSubRecords map[[2]string]*api.EmailSubscriptionRecord webhookMux sync.RWMutex - webhookRecords map[webhookMemKey]*WebhookRecord + webhookRecords map[webhookMemKey]*api.WebhookRecord } var _ DB = (*MemoryDB)(nil) @@ -56,7 +57,7 @@ func NewMemoryDB() *MemoryDB { return &MemoryDB{ fileRecords: make(map[string]*FileRecord), tokens: make(map[string]time.Time), - subscribeRecords: make(map[[2]string]*SubscribeRecord), + subscribeRecords: make(map[[2]string]*api.SubscribeRecord), } } @@ -156,7 +157,7 @@ func (m *MemoryDB) ForEachFileRecord(cb func(*FileRecord) error) error { return nil } -func (m *MemoryDB) GetSubscribe(user string, client string) (*SubscribeRecord, error) { +func (m *MemoryDB) GetSubscribe(user string, client string) (*api.SubscribeRecord, error) { m.subscribeMux.RLock() defer m.subscribeMux.RUnlock() @@ -167,7 +168,7 @@ func (m *MemoryDB) GetSubscribe(user string, client string) (*SubscribeRecord, e return record, nil } -func (m *MemoryDB) SetSubscribe(record SubscribeRecord) error { +func (m *MemoryDB) SetSubscribe(record api.SubscribeRecord) error { m.subscribeMux.Lock() defer m.subscribeMux.Unlock() @@ -196,7 +197,7 @@ func (m *MemoryDB) RemoveSubscribe(user string, client string) error { return nil } -func (m *MemoryDB) ForEachSubscribe(cb func(*SubscribeRecord) error) error { +func (m *MemoryDB) ForEachSubscribe(cb func(*api.SubscribeRecord) error) error { m.subscribeMux.RLock() defer m.subscribeMux.RUnlock() @@ -211,7 +212,7 @@ func (m *MemoryDB) ForEachSubscribe(cb func(*SubscribeRecord) error) error { return nil } -func (m *MemoryDB) GetEmailSubscription(user string, addr string) (*EmailSubscriptionRecord, error) { +func (m *MemoryDB) GetEmailSubscription(user string, addr string) (*api.EmailSubscriptionRecord, error) { m.emailSubMux.RLock() defer m.emailSubMux.RUnlock() @@ -222,7 +223,7 @@ func (m *MemoryDB) GetEmailSubscription(user string, addr string) (*EmailSubscri return record, nil } -func (m *MemoryDB) AddEmailSubscription(record EmailSubscriptionRecord) error { +func (m *MemoryDB) AddEmailSubscription(record api.EmailSubscriptionRecord) error { m.emailSubMux.Lock() defer m.emailSubMux.Unlock() @@ -234,7 +235,7 @@ func (m *MemoryDB) AddEmailSubscription(record EmailSubscriptionRecord) error { return nil } -func (m *MemoryDB) UpdateEmailSubscription(record EmailSubscriptionRecord) error { +func (m *MemoryDB) UpdateEmailSubscription(record api.EmailSubscriptionRecord) error { m.emailSubMux.Lock() defer m.emailSubMux.Unlock() @@ -260,7 +261,7 @@ func (m *MemoryDB) RemoveEmailSubscription(user string, addr string) error { return nil } -func (m *MemoryDB) ForEachEmailSubscription(cb func(*EmailSubscriptionRecord) error) error { +func (m *MemoryDB) ForEachEmailSubscription(cb func(*api.EmailSubscriptionRecord) error) error { m.emailSubMux.RLock() defer m.emailSubMux.RUnlock() @@ -275,7 +276,7 @@ func (m *MemoryDB) ForEachEmailSubscription(cb func(*EmailSubscriptionRecord) er return nil } -func (m *MemoryDB) ForEachUsersEmailSubscription(user string, cb func(*EmailSubscriptionRecord) error) error { +func (m *MemoryDB) ForEachUsersEmailSubscription(user string, cb func(*api.EmailSubscriptionRecord) error) error { m.emailSubMux.RLock() defer m.emailSubMux.RUnlock() @@ -293,7 +294,7 @@ func (m *MemoryDB) ForEachUsersEmailSubscription(user string, cb func(*EmailSubs return nil } -func (m *MemoryDB) ForEachEnabledEmailSubscription(cb func(*EmailSubscriptionRecord) error) error { +func (m *MemoryDB) ForEachEnabledEmailSubscription(cb func(*api.EmailSubscriptionRecord) error) error { m.emailSubMux.RLock() defer m.emailSubMux.RUnlock() @@ -311,7 +312,7 @@ func (m *MemoryDB) ForEachEnabledEmailSubscription(cb func(*EmailSubscriptionRec return nil } -func (m *MemoryDB) GetWebhook(user string, id uuid.UUID) (*WebhookRecord, error) { +func (m *MemoryDB) GetWebhook(user string, id uuid.UUID) (*api.WebhookRecord, error) { m.webhookMux.RLock() defer m.webhookMux.RUnlock() @@ -327,7 +328,7 @@ var ( emptyStrPtr = &emptyStr ) -func (m *MemoryDB) AddWebhook(record WebhookRecord) (err error) { +func (m *MemoryDB) AddWebhook(record api.WebhookRecord) (err error) { m.webhookMux.Lock() defer m.webhookMux.Unlock() @@ -349,7 +350,7 @@ func (m *MemoryDB) AddWebhook(record WebhookRecord) (err error) { return nil } -func (m *MemoryDB) UpdateWebhook(record WebhookRecord) error { +func (m *MemoryDB) UpdateWebhook(record api.WebhookRecord) error { m.webhookMux.Lock() defer m.webhookMux.Unlock() @@ -395,7 +396,7 @@ func (m *MemoryDB) RemoveWebhook(user string, id uuid.UUID) error { return nil } -func (m *MemoryDB) ForEachWebhook(cb func(*WebhookRecord) error) error { +func (m *MemoryDB) ForEachWebhook(cb func(*api.WebhookRecord) error) error { m.webhookMux.RLock() defer m.webhookMux.RUnlock() @@ -410,7 +411,7 @@ func (m *MemoryDB) ForEachWebhook(cb func(*WebhookRecord) error) error { return nil } -func (m *MemoryDB) ForEachUsersWebhook(user string, cb func(*WebhookRecord) error) error { +func (m *MemoryDB) ForEachUsersWebhook(user string, cb func(*api.WebhookRecord) error) error { m.webhookMux.RLock() defer m.webhookMux.RUnlock() @@ -428,7 +429,7 @@ func (m *MemoryDB) ForEachUsersWebhook(user string, cb func(*WebhookRecord) erro return nil } -func (m *MemoryDB) ForEachEnabledWebhook(cb func(*WebhookRecord) error) error { +func (m *MemoryDB) ForEachEnabledWebhook(cb func(*api.WebhookRecord) error) error { m.webhookMux.RLock() defer m.webhookMux.RUnlock() diff --git a/database/sql.go b/database/sql.go index bc4cc1fb..aa8ca576 100644 --- a/database/sql.go +++ b/database/sql.go @@ -28,6 +28,7 @@ import ( "github.com/google/uuid" + "github.com/LiterMC/go-openbmclapi/api" "github.com/LiterMC/go-openbmclapi/log" ) @@ -613,11 +614,11 @@ func (db *SqlDB) setupSubscribeDollarMark(ctx context.Context) (err error) { return err } -func (db *SqlDB) GetSubscribe(user string, client string) (rec *SubscribeRecord, err error) { +func (db *SqlDB) GetSubscribe(user string, client string) (rec *api.SubscribeRecord, err error) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - rec = new(SubscribeRecord) + rec = new(api.SubscribeRecord) rec.User = user rec.Client = client if err = db.subscribeStmts.get.QueryRowContext(ctx, user, client).Scan(&rec.EndPoint, &rec.Keys, &rec.Scopes, &rec.ReportAt); err != nil { @@ -629,7 +630,7 @@ func (db *SqlDB) GetSubscribe(user string, client string) (rec *SubscribeRecord, return } -func (db *SqlDB) SetSubscribe(rec SubscribeRecord) (err error) { +func (db *SqlDB) SetSubscribe(rec api.SubscribeRecord) (err error) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() @@ -691,7 +692,7 @@ func (db *SqlDB) RemoveSubscribe(user string, client string) (err error) { return } -func (db *SqlDB) ForEachSubscribe(cb func(*SubscribeRecord) error) (err error) { +func (db *SqlDB) ForEachSubscribe(cb func(*api.SubscribeRecord) error) (err error) { ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() @@ -700,7 +701,7 @@ func (db *SqlDB) ForEachSubscribe(cb func(*SubscribeRecord) error) (err error) { return } defer rows.Close() - var rec SubscribeRecord + var rec api.SubscribeRecord for rows.Next() { if err = rows.Scan(&rec.User, &rec.Client, &rec.EndPoint, &rec.Keys, &rec.Scopes, &rec.ReportAt, &rec.LastReport); err != nil { return @@ -856,11 +857,11 @@ func (db *SqlDB) setupEmailSubscriptionsDollarMark(ctx context.Context) (err err return err } -func (db *SqlDB) GetEmailSubscription(user string, addr string) (rec *EmailSubscriptionRecord, err error) { +func (db *SqlDB) GetEmailSubscription(user string, addr string) (rec *api.EmailSubscriptionRecord, err error) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - rec = new(EmailSubscriptionRecord) + rec = new(api.EmailSubscriptionRecord) rec.User = user rec.Addr = addr if err = db.emailSubscriptionStmts.get.QueryRowContext(ctx, user, addr).Scan(&rec.Scopes, &rec.Enabled); err != nil { @@ -872,7 +873,7 @@ func (db *SqlDB) GetEmailSubscription(user string, addr string) (rec *EmailSubsc return } -func (db *SqlDB) AddEmailSubscription(rec EmailSubscriptionRecord) (err error) { +func (db *SqlDB) AddEmailSubscription(rec api.EmailSubscriptionRecord) (err error) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() @@ -882,7 +883,7 @@ func (db *SqlDB) AddEmailSubscription(rec EmailSubscriptionRecord) (err error) { return } -func (db *SqlDB) UpdateEmailSubscription(rec EmailSubscriptionRecord) (err error) { +func (db *SqlDB) UpdateEmailSubscription(rec api.EmailSubscriptionRecord) (err error) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() @@ -915,7 +916,7 @@ func (db *SqlDB) RemoveEmailSubscription(user string, addr string) (err error) { return } -func (db *SqlDB) ForEachEmailSubscription(cb func(*EmailSubscriptionRecord) error) (err error) { +func (db *SqlDB) ForEachEmailSubscription(cb func(*api.EmailSubscriptionRecord) error) (err error) { ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() @@ -924,7 +925,7 @@ func (db *SqlDB) ForEachEmailSubscription(cb func(*EmailSubscriptionRecord) erro return } defer rows.Close() - var rec EmailSubscriptionRecord + var rec api.EmailSubscriptionRecord for rows.Next() { if err = rows.Scan(&rec.User, &rec.Addr, &rec.Scopes, &rec.Enabled); err != nil { return @@ -937,7 +938,7 @@ func (db *SqlDB) ForEachEmailSubscription(cb func(*EmailSubscriptionRecord) erro return } -func (db *SqlDB) ForEachUsersEmailSubscription(user string, cb func(*EmailSubscriptionRecord) error) (err error) { +func (db *SqlDB) ForEachUsersEmailSubscription(user string, cb func(*api.EmailSubscriptionRecord) error) (err error) { ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() @@ -946,7 +947,7 @@ func (db *SqlDB) ForEachUsersEmailSubscription(user string, cb func(*EmailSubscr return } defer rows.Close() - var rec EmailSubscriptionRecord + var rec api.EmailSubscriptionRecord rec.User = user for rows.Next() { if err = rows.Scan(&rec.Addr, &rec.Scopes, &rec.Enabled); err != nil { @@ -960,7 +961,7 @@ func (db *SqlDB) ForEachUsersEmailSubscription(user string, cb func(*EmailSubscr return } -func (db *SqlDB) ForEachEnabledEmailSubscription(cb func(*EmailSubscriptionRecord) error) (err error) { +func (db *SqlDB) ForEachEnabledEmailSubscription(cb func(*api.EmailSubscriptionRecord) error) (err error) { ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() @@ -969,7 +970,7 @@ func (db *SqlDB) ForEachEnabledEmailSubscription(cb func(*EmailSubscriptionRecor return } defer rows.Close() - var rec EmailSubscriptionRecord + var rec api.EmailSubscriptionRecord for rows.Next() { if err = rows.Scan(&rec.User, &rec.Addr, &rec.Scopes); err != nil { return @@ -1143,11 +1144,11 @@ func (db *SqlDB) setupWebhooksDollarMark(ctx context.Context) (err error) { return err } -func (db *SqlDB) GetWebhook(user string, id uuid.UUID) (rec *WebhookRecord, err error) { +func (db *SqlDB) GetWebhook(user string, id uuid.UUID) (rec *api.WebhookRecord, err error) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() - rec = new(WebhookRecord) + rec = new(api.WebhookRecord) rec.User = user rec.Id = id if err = db.webhookStmts.get.QueryRowContext(ctx, user, hex.EncodeToString(id[:])).Scan(&rec.Name, &rec.EndPoint, &rec.Auth, &rec.Scopes, &rec.Enabled); err != nil { @@ -1159,7 +1160,7 @@ func (db *SqlDB) GetWebhook(user string, id uuid.UUID) (rec *WebhookRecord, err return } -func (db *SqlDB) AddWebhook(rec WebhookRecord) (err error) { +func (db *SqlDB) AddWebhook(rec api.WebhookRecord) (err error) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() @@ -1172,7 +1173,7 @@ func (db *SqlDB) AddWebhook(rec WebhookRecord) (err error) { return } -func (db *SqlDB) UpdateWebhook(rec WebhookRecord) (err error) { +func (db *SqlDB) UpdateWebhook(rec api.WebhookRecord) (err error) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() @@ -1211,7 +1212,7 @@ func (db *SqlDB) RemoveWebhook(user string, id uuid.UUID) (err error) { return } -func (db *SqlDB) ForEachWebhook(cb func(*WebhookRecord) error) (err error) { +func (db *SqlDB) ForEachWebhook(cb func(*api.WebhookRecord) error) (err error) { ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() @@ -1220,7 +1221,7 @@ func (db *SqlDB) ForEachWebhook(cb func(*WebhookRecord) error) (err error) { return } defer rows.Close() - var rec WebhookRecord + var rec api.WebhookRecord for rows.Next() { if err = rows.Scan(&rec.User, &rec.Id, &rec.Name, &rec.EndPoint, &rec.Auth, &rec.Scopes, &rec.Enabled); err != nil { return @@ -1233,7 +1234,7 @@ func (db *SqlDB) ForEachWebhook(cb func(*WebhookRecord) error) (err error) { return } -func (db *SqlDB) ForEachUsersWebhook(user string, cb func(*WebhookRecord) error) (err error) { +func (db *SqlDB) ForEachUsersWebhook(user string, cb func(*api.WebhookRecord) error) (err error) { ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() @@ -1242,7 +1243,7 @@ func (db *SqlDB) ForEachUsersWebhook(user string, cb func(*WebhookRecord) error) return } defer rows.Close() - var rec WebhookRecord + var rec api.WebhookRecord rec.User = user for rows.Next() { if err = rows.Scan(&rec.Id, &rec.Name, &rec.EndPoint, &rec.Auth, &rec.Scopes, &rec.Enabled, &rec.User); err != nil { @@ -1256,7 +1257,7 @@ func (db *SqlDB) ForEachUsersWebhook(user string, cb func(*WebhookRecord) error) return } -func (db *SqlDB) ForEachEnabledWebhook(cb func(*WebhookRecord) error) (err error) { +func (db *SqlDB) ForEachEnabledWebhook(cb func(*api.WebhookRecord) error) (err error) { ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() @@ -1265,7 +1266,7 @@ func (db *SqlDB) ForEachEnabledWebhook(cb func(*WebhookRecord) error) (err error return } defer rows.Close() - var rec WebhookRecord + var rec api.WebhookRecord for rows.Next() { if err = rows.Scan(&rec.User, &rec.Id, &rec.Name, &rec.EndPoint, &rec.Auth, &rec.Scopes); err != nil { return diff --git a/go.mod b/go.mod index 9fdf8233..8d644429 100644 --- a/go.mod +++ b/go.mod @@ -3,18 +3,24 @@ module github.com/LiterMC/go-openbmclapi go 1.22.0 require ( - github.com/LiterMC/socket.io v0.2.4 + github.com/LiterMC/socket.io v0.2.5 github.com/crow-misia/http-ece v0.0.1 github.com/glebarez/go-sqlite v1.22.0 + github.com/go-sql-driver/mysql v1.8.0 github.com/golang-jwt/jwt/v5 v5.2.0 + github.com/google/uuid v1.5.0 + github.com/gorilla/schema v1.4.0 github.com/gorilla/websocket v1.5.1 github.com/gregjones/httpcache v0.0.0-20190611155906-901d90724c79 github.com/hamba/avro/v2 v2.18.0 github.com/klauspost/compress v1.17.4 + github.com/lib/pq v1.10.9 + github.com/libp2p/go-doh-resolver v0.4.0 github.com/patrickmn/go-cache v2.1.0+incompatible github.com/redis/go-redis/v9 v9.4.0 github.com/studio-b12/gowebdav v0.9.0 github.com/vbauerster/mpb/v8 v8.7.2 + github.com/xhit/go-simple-mail/v2 v2.16.0 gopkg.in/yaml.v3 v3.0.1 ) @@ -25,13 +31,9 @@ require ( github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dustin/go-humanize v1.0.1 // indirect - github.com/go-sql-driver/mysql v1.8.0 // indirect - github.com/google/uuid v1.5.0 // indirect - github.com/gorilla/schema v1.4.0 // indirect + github.com/go-test/deep v1.1.1 // indirect github.com/ipfs/go-log/v2 v2.1.3 // indirect github.com/json-iterator/go v1.1.12 // indirect - github.com/lib/pq v1.10.9 // indirect - github.com/libp2p/go-doh-resolver v0.4.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-runewidth v0.0.15 // indirect github.com/miekg/dns v1.1.41 // indirect @@ -49,7 +51,6 @@ require ( github.com/rivo/uniseg v0.4.4 // indirect github.com/spaolacci/murmur3 v1.1.0 // indirect github.com/toorop/go-dkim v0.0.0-20201103131630-e1cd1a0a5208 // indirect - github.com/xhit/go-simple-mail/v2 v2.16.0 // indirect go.uber.org/atomic v1.7.0 // indirect go.uber.org/multierr v1.6.0 // indirect go.uber.org/zap v1.16.0 // indirect diff --git a/go.sum b/go.sum index 3ed4d9ee..9e8d673b 100644 --- a/go.sum +++ b/go.sum @@ -1,8 +1,11 @@ filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= +github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/LiterMC/socket.io v0.2.4 h1:ycVw/soQQZDA57Lz029Sre/oeDbNMu/+eh5EIJxt/u4= github.com/LiterMC/socket.io v0.2.4/go.mod h1:MqUeyAZQgqD8PrRPIS3h+mV63xRa4rJw6uZohSvc8NY= +github.com/LiterMC/socket.io v0.2.5 h1:gCO8QhnRTPfYfqEw9exq1Qnl3AMZ9Jozw+qFZ+kxD8s= +github.com/LiterMC/socket.io v0.2.5/go.mod h1:MqUeyAZQgqD8PrRPIS3h+mV63xRa4rJw6uZohSvc8NY= github.com/VividCortex/ewma v1.2.0 h1:f58SaIzcDXrSy3kWaHNvuJgJ3Nmz59Zji6XoJR/q1ow= github.com/VividCortex/ewma v1.2.0/go.mod h1:nz4BbCtbLyFDeC9SUHbtcT5644juEuWfUAUnGx7j5l4= github.com/acarl005/stripansi v0.0.0-20180116102854-5a71ef0e047d h1:licZJFw2RwpHMqeKTCYkitsPqHNxTmd4SNR5r94FGM8= @@ -26,6 +29,8 @@ github.com/glebarez/go-sqlite v1.22.0 h1:uAcMJhaA6r3LHMTFgP0SifzgXg46yJkgxqyuyec github.com/glebarez/go-sqlite v1.22.0/go.mod h1:PlBIdHe0+aUEFn+r2/uthrWq4FxbzugL0L8Li6yQJbc= github.com/go-sql-driver/mysql v1.8.0 h1:UtktXaU2Nb64z/pLiGIxY4431SJ4/dR5cjMmlVHgnT4= github.com/go-sql-driver/mysql v1.8.0/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= +github.com/go-test/deep v1.1.1 h1:0r/53hagsehfO4bzD2Pgr/+RgHqhmf+k1Bpse2cTu1U= +github.com/go-test/deep v1.1.1/go.mod h1:5C2ZWiW0ErCdrYzpqxLbTX7MG14M9iiw8DgHncVwcsE= github.com/golang-jwt/jwt/v5 v5.2.0 h1:d/ix8ftRUorsN+5eMIlF4T6J8CAt9rch3My2winC1Jw= github.com/golang-jwt/jwt/v5 v5.2.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= @@ -49,8 +54,10 @@ github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHm github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4= github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= @@ -85,6 +92,7 @@ github.com/multiformats/go-varint v0.0.1 h1:TR/0rdQtnNxuN2IhiB639xC3tWM4IUi7DkTB github.com/multiformats/go-varint v0.0.1/go.mod h1:3Ls8CIEsrijN6+B7PbrXRPxHRPuXSrVKRY101jdMZYE= github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= +github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -128,9 +136,12 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA= golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs= +golang.org/x/lint v0.0.0-20190930215403-16217165b5de h1:5hukYrvBGR8/eNkX5mdUezrA6JiaEZDtJb9Ei+1LlBs= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/mod v0.14.0 h1:dGoOF9QVLYng8IHTm7BAyWqCqSheQ5pYWGhzW00YJr0= +golang.org/x/mod v0.14.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= @@ -140,6 +151,7 @@ golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -159,16 +171,20 @@ golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/tools v0.16.1 h1:TLyB3WofjdOEepBHAU20JdNC1Zbg87elYofWYAY5oZA= +golang.org/x/tools v0.16.1/go.mod h1:kYVVN6I1mBNoB1OX+noeBjbRk4IUEPa7JJ+TJMEooJ0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +honnef.co/go/tools v0.0.1-2019.2.3 h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= modernc.org/libc v1.37.6 h1:orZH3c5wmhIQFTXF+Nt+eeauyd+ZIt2BX6ARe+kD+aw= modernc.org/libc v1.37.6/go.mod h1:YAXkAZ8ktnkCKaN9sw/UDeUVkGYJ/YquGO4FTi5nmHE= diff --git a/notify/email/email.go b/notify/email/email.go index 5e5ff143..d014fb16 100644 --- a/notify/email/email.go +++ b/notify/email/email.go @@ -33,6 +33,7 @@ import ( mail "github.com/xhit/go-simple-mail/v2" + "github.com/LiterMC/go-openbmclapi/api" "github.com/LiterMC/go-openbmclapi/database" "github.com/LiterMC/go-openbmclapi/notify" ) @@ -131,9 +132,9 @@ func (p *Plugin) sendEmail(ctx context.Context, subject string, body []byte, to return m.Send(cli) } -func (p *Plugin) sendEmailIf(ctx context.Context, subject string, body []byte, filter func(*database.EmailSubscriptionRecord) bool) (err error) { +func (p *Plugin) sendEmailIf(ctx context.Context, subject string, body []byte, filter func(*api.EmailSubscriptionRecord) bool) (err error) { var recipients []string - p.db.ForEachEnabledEmailSubscription(func(record *database.EmailSubscriptionRecord) error { + p.db.ForEachEnabledEmailSubscription(func(record *api.EmailSubscriptionRecord) error { if filter(record) { recipients = append(recipients, record.Addr) } @@ -153,7 +154,7 @@ func (p *Plugin) OnEnabled(e *notify.EnabledEvent) error { tctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - return p.sendEmailIf(tctx, "Go-OpenBMCLAPI Enabled", buf.Bytes(), func(record *database.EmailSubscriptionRecord) bool { return record.Scopes.Enabled }) + return p.sendEmailIf(tctx, "Go-OpenBMCLAPI Enabled", buf.Bytes(), func(record *api.EmailSubscriptionRecord) bool { return record.Scopes.Enabled }) } func (p *Plugin) OnDisabled(e *notify.DisabledEvent) error { @@ -164,7 +165,7 @@ func (p *Plugin) OnDisabled(e *notify.DisabledEvent) error { tctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - return p.sendEmailIf(tctx, "Go-OpenBMCLAPI Disabled", buf.Bytes(), func(record *database.EmailSubscriptionRecord) bool { return record.Scopes.Disabled }) + return p.sendEmailIf(tctx, "Go-OpenBMCLAPI Disabled", buf.Bytes(), func(record *api.EmailSubscriptionRecord) bool { return record.Scopes.Disabled }) } func (p *Plugin) OnSyncBegin(e *notify.SyncBeginEvent) error { @@ -175,7 +176,7 @@ func (p *Plugin) OnSyncBegin(e *notify.SyncBeginEvent) error { tctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - return p.sendEmailIf(tctx, "Go-OpenBMCLAPI Sync Begin", buf.Bytes(), func(record *database.EmailSubscriptionRecord) bool { return record.Scopes.SyncBegin }) + return p.sendEmailIf(tctx, "Go-OpenBMCLAPI Sync Begin", buf.Bytes(), func(record *api.EmailSubscriptionRecord) bool { return record.Scopes.SyncBegin }) } func (p *Plugin) OnSyncDone(e *notify.SyncDoneEvent) error { @@ -186,7 +187,7 @@ func (p *Plugin) OnSyncDone(e *notify.SyncDoneEvent) error { tctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - return p.sendEmailIf(tctx, "Go-OpenBMCLAPI Sync Done", buf.Bytes(), func(record *database.EmailSubscriptionRecord) bool { return record.Scopes.SyncDone }) + return p.sendEmailIf(tctx, "Go-OpenBMCLAPI Sync Done", buf.Bytes(), func(record *api.EmailSubscriptionRecord) bool { return record.Scopes.SyncDone }) } func (p *Plugin) OnUpdateAvaliable(e *notify.UpdateAvaliableEvent) error { @@ -197,7 +198,7 @@ func (p *Plugin) OnUpdateAvaliable(e *notify.UpdateAvaliableEvent) error { tctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - return p.sendEmailIf(tctx, "Go-OpenBMCLAPI Update Avaliable", buf.Bytes(), func(record *database.EmailSubscriptionRecord) bool { return record.Scopes.Updates }) + return p.sendEmailIf(tctx, "Go-OpenBMCLAPI Update Avaliable", buf.Bytes(), func(record *api.EmailSubscriptionRecord) bool { return record.Scopes.Updates }) } func (p *Plugin) OnReportStatus(e *notify.ReportStatusEvent) error { @@ -213,5 +214,5 @@ func (p *Plugin) OnReportStatus(e *notify.ReportStatusEvent) error { tctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - return p.sendEmailIf(tctx, "Go-OpenBMCLAPI Daily Report", buf.Bytes(), func(record *database.EmailSubscriptionRecord) bool { return record.Scopes.DailyReport }) + return p.sendEmailIf(tctx, "Go-OpenBMCLAPI Daily Report", buf.Bytes(), func(record *api.EmailSubscriptionRecord) bool { return record.Scopes.DailyReport }) } diff --git a/notify/webpush/webpush.go b/notify/webpush/webpush.go index aaf3ca1a..cc5b2552 100644 --- a/notify/webpush/webpush.go +++ b/notify/webpush/webpush.go @@ -44,6 +44,7 @@ import ( "github.com/crow-misia/http-ece" "github.com/golang-jwt/jwt/v5" + "github.com/LiterMC/go-openbmclapi/api" "github.com/LiterMC/go-openbmclapi/database" "github.com/LiterMC/go-openbmclapi/internal/build" "github.com/LiterMC/go-openbmclapi/log" @@ -88,7 +89,7 @@ func (p *Plugin) Init(ctx context.Context, m *notify.Manager) (err error) { type Subscription struct { EndPoint string - Keys database.SubscribeRecordKeys + Keys api.SubscribeRecordKeys } type PushOptions struct { @@ -170,16 +171,16 @@ func (p *Plugin) sendNotification(ctx context.Context, message []byte, s *Subscr return } -func (p *Plugin) sendMessageIf(ctx context.Context, message []byte, opts *PushOptions, filter func(*database.SubscribeRecord) bool) (err error) { +func (p *Plugin) sendMessageIf(ctx context.Context, message []byte, opts *PushOptions, filter func(*api.SubscribeRecord) bool) (err error) { log.Debugf("Sending notification: %s", message) var wg sync.WaitGroup var mux sync.Mutex - var outdated []database.SubscribeRecord - err = p.db.ForEachSubscribe(func(record *database.SubscribeRecord) error { + var outdated []api.SubscribeRecord + err = p.db.ForEachSubscribe(func(record *api.SubscribeRecord) error { if filter(record) { log.Debugf("Sending notification to %s", record.EndPoint) wg.Add(1) - go func(record database.SubscribeRecord) { + go func(record api.SubscribeRecord) { defer wg.Done() subs := &Subscription{ EndPoint: record.EndPoint, @@ -279,7 +280,7 @@ func (p *Plugin) OnEnabled(e *notify.EnabledEvent) error { tctx, cancel := context.WithTimeout(context.Background(), time.Second*15) defer cancel() - return p.sendMessageIf(tctx, message, opts, func(record *database.SubscribeRecord) bool { return record.Scopes.Enabled }) + return p.sendMessageIf(tctx, message, opts, func(record *api.SubscribeRecord) bool { return record.Scopes.Enabled }) } func (p *Plugin) OnDisabled(e *notify.DisabledEvent) error { @@ -298,7 +299,7 @@ func (p *Plugin) OnDisabled(e *notify.DisabledEvent) error { tctx, cancel := context.WithTimeout(context.Background(), time.Second*15) defer cancel() - return p.sendMessageIf(tctx, message, opts, func(record *database.SubscribeRecord) bool { return record.Scopes.Disabled }) + return p.sendMessageIf(tctx, message, opts, func(record *api.SubscribeRecord) bool { return record.Scopes.Disabled }) } func (p *Plugin) OnSyncBegin(e *notify.SyncBeginEvent) error { @@ -319,7 +320,7 @@ func (p *Plugin) OnSyncBegin(e *notify.SyncBeginEvent) error { tctx, cancel := context.WithTimeout(context.Background(), time.Second*15) defer cancel() - return p.sendMessageIf(tctx, message, opts, func(record *database.SubscribeRecord) bool { return record.Scopes.SyncBegin }) + return p.sendMessageIf(tctx, message, opts, func(record *api.SubscribeRecord) bool { return record.Scopes.SyncBegin }) } func (p *Plugin) OnSyncDone(e *notify.SyncDoneEvent) error { @@ -338,7 +339,7 @@ func (p *Plugin) OnSyncDone(e *notify.SyncDoneEvent) error { tctx, cancel := context.WithTimeout(context.Background(), time.Second*15) defer cancel() - return p.sendMessageIf(tctx, message, opts, func(record *database.SubscribeRecord) bool { return record.Scopes.SyncDone }) + return p.sendMessageIf(tctx, message, opts, func(record *api.SubscribeRecord) bool { return record.Scopes.SyncDone }) } func (p *Plugin) OnUpdateAvaliable(e *notify.UpdateAvaliableEvent) error { @@ -357,7 +358,7 @@ func (p *Plugin) OnUpdateAvaliable(e *notify.UpdateAvaliableEvent) error { tctx, cancel := context.WithTimeout(context.Background(), time.Second*15) defer cancel() - return p.sendMessageIf(tctx, message, opts, func(record *database.SubscribeRecord) bool { return record.Scopes.Updates }) + return p.sendMessageIf(tctx, message, opts, func(record *api.SubscribeRecord) bool { return record.Scopes.Updates }) } func (p *Plugin) OnReportStatus(e *notify.ReportStatusEvent) (err error) { @@ -388,8 +389,8 @@ func (p *Plugin) OnReportStatus(e *notify.ReportStatusEvent) (err error) { tctx, cancel := context.WithTimeout(context.Background(), time.Second*15) defer cancel() now := e.At.UTC() - var sent []database.SubscribeRecord - err = p.sendMessageIf(tctx, message, opts, func(record *database.SubscribeRecord) bool { + var sent []api.SubscribeRecord + err = p.sendMessageIf(tctx, message, opts, func(record *api.SubscribeRecord) bool { if !record.Scopes.DailyReport { return false } diff --git a/storage/manager.go b/storage/manager.go index bc348d2a..36c43405 100644 --- a/storage/manager.go +++ b/storage/manager.go @@ -48,6 +48,15 @@ func NewManager(storages []Storage) (m *Manager) { return } +func (m *Manager) Get(id string) Storage { + for _, s := range m.Storages { + if s.Options().Id == id { + return s + } + } + return nil +} + func (m *Manager) GetFlavorString(storages []int) string { typeCount := make(map[string]int, 2) for _, i := range storages { diff --git a/storage/storage_local.go b/storage/storage_local.go index a4c4a8e0..b912e1d5 100644 --- a/storage/storage_local.go +++ b/storage/storage_local.go @@ -50,16 +50,18 @@ var _ Storage = (*LocalStorage)(nil) func init() { RegisterStorageFactory(StorageLocal, StorageFactory{ - New: func(opt StorageOption) Storage { - return &LocalStorage{ - basicOpt: opt, - opt: *(opt.Data.(*LocalStorageOption)), - } - }, + New: func(opt StorageOption) Storage { return NewLocalStorage(opt) }, NewConfig: func() any { return new(LocalStorageOption) }, }) } +func NewLocalStorage(opt StorageOption) *LocalStorage { + return &LocalStorage{ + basicOpt: opt, + opt: *(opt.Data.(*LocalStorageOption)), + } +} + func (s *LocalStorage) String() string { return fmt.Sprintf("", s.opt.CachePath) } diff --git a/storage/storage_webdav.go b/storage/storage_webdav.go index f9443598..5b472cee 100644 --- a/storage/storage_webdav.go +++ b/storage/storage_webdav.go @@ -140,16 +140,18 @@ var _ Storage = (*WebDavStorage)(nil) func init() { RegisterStorageFactory(StorageWebdav, StorageFactory{ - New: func(opt StorageOption) Storage { - return &WebDavStorage{ - basicOpt: opt, - opt: *(opt.Data.(*WebDavStorageOption)), - } - }, + New: func(opt StorageOption) Storage { return NewWebDavStorage(opt) }, NewConfig: func() any { return new(WebDavStorageOption) }, }) } +func NewWebDavStorage(opt StorageOption) *WebDavStorage { + return &WebDavStorage{ + basicOpt: opt, + opt: *(opt.Data.(*WebDavStorageOption)), + } +} + func (s *WebDavStorage) String() string { return fmt.Sprintf("", s.opt.GetEndPoint(), s.opt.GetUsername()) } diff --git a/sub_commands/cmd_compress.go b/sub_commands/cmd_compress.go index b8ee69ed..4f60637e 100644 --- a/sub_commands/cmd_compress.go +++ b/sub_commands/cmd_compress.go @@ -1,3 +1,5 @@ +//go:build ignore + /** * OpenBmclAPI (Golang Edition) * Copyright (C) 2024 Kevin Z diff --git a/sub_commands/cmd_webdav.go b/sub_commands/cmd_webdav.go index 18a86edb..33ea8932 100644 --- a/sub_commands/cmd_webdav.go +++ b/sub_commands/cmd_webdav.go @@ -23,6 +23,7 @@ import ( "context" "fmt" "os" + "errors" "runtime" "sync" "sync/atomic" @@ -31,19 +32,22 @@ import ( "github.com/vbauerster/mpb/v8" "github.com/vbauerster/mpb/v8/decor" + "github.com/LiterMC/go-openbmclapi/config" "github.com/LiterMC/go-openbmclapi/log" "github.com/LiterMC/go-openbmclapi/storage" ) func cmdUploadWebdav(args []string) { - config = readConfig() + cfg := readConfig() - var localOpt *storage.LocalStorageOption - webdavOpts := make([]*storage.WebDavStorageOption, 0, 4) - for _, s := range config.Storages { - switch s := s.Data.(type) { + var ( + localOpt storage.StorageOption + webdavOpts = make([]storage.StorageOption, 0, 4) + ) + for _, s := range cfg.Storages { + switch s.Data.(type) { case *storage.LocalStorageOption: - if localOpt == nil { + if localOpt.Data == nil { localOpt = s } case *storage.WebDavStorageOption: @@ -51,7 +55,7 @@ func cmdUploadWebdav(args []string) { } } - if localOpt == nil { + if localOpt.Data == nil { log.Error("At least one local storage is required") os.Exit(1) } @@ -62,8 +66,7 @@ func cmdUploadWebdav(args []string) { ctx := context.Background() - var local storage.LocalStorage - local.SetOptions(localOpt) + local := storage.NewLocalStorage(localOpt) if err := local.Init(ctx); err != nil { log.Errorf("Cannot initialize %s: %v", local.String(), err) os.Exit(1) @@ -73,11 +76,10 @@ func cmdUploadWebdav(args []string) { webdavs := make([]*storage.WebDavStorage, len(webdavOpts)) maxProc := 0 for i, opt := range webdavOpts { - if opt.MaxConn > maxProc { - maxProc = opt.MaxConn + if maxConn := opt.Data.(*storage.WebDavStorageOption).MaxConn; maxConn > maxProc { + maxProc = maxConn } - s := new(storage.WebDavStorage) - s.SetOptions(opt) + s := storage.NewWebDavStorage(opt) if err := s.Init(ctx); err != nil { log.Errorf("Cannot initialize %s: %v", s.String(), err) os.Exit(1) @@ -248,3 +250,23 @@ func cmdUploadWebdav(args []string) { pg.Wait() log.SetLogOutput(nil) } + +func readConfig() (cfg *config.Config) { + const configPath = "config.yaml" + + cfg = config.NewDefaultConfig() + data, err := os.ReadFile(configPath) + if err != nil { + if !errors.Is(err, os.ErrNotExist) { + log.TrErrorf("error.config.read.failed", err) + os.Exit(1) + } + log.TrErrorf("error.config.not.exists") + os.Exit(1) + } + if err = cfg.UnmarshalText(data); err != nil { + log.TrErrorf("error.config.parse.failed", err) + os.Exit(1) + } + return +} From 9d449013ba7ff5c2290ab4281af41e3a1394d587 Mon Sep 17 00:00:00 2001 From: zyxkad Date: Thu, 8 Aug 2024 08:12:15 -0700 Subject: [PATCH 13/36] fix notifier error --- api/bmclapi/hijacker.go | 22 ++++++++++++---------- main.go | 2 -- notify/event.go | 3 ++- notify/manager.go | 5 +++-- sub_commands/cmd_webdav.go | 5 +++-- 5 files changed, 20 insertions(+), 17 deletions(-) diff --git a/api/bmclapi/hijacker.go b/api/bmclapi/hijacker.go index a040b036..f45d0b24 100644 --- a/api/bmclapi/hijacker.go +++ b/api/bmclapi/hijacker.go @@ -53,8 +53,10 @@ func getDialerWithDNS(dnsaddr string) *net.Dialer { type downloadHandlerFn = func(rw http.ResponseWriter, req *http.Request, hash string) type HjProxy struct { - RequireAuth bool - AuthUsers []config.UserItem + RequireAuth bool + AuthUsers []config.UserItem + EnableLocalCache bool + LocalCachePath string client *http.Client fileMap database.DB @@ -80,11 +82,11 @@ func NewHjProxy(client *http.Client, fileMap database.DB, downloadHandler downlo return } -func hjResponseWithCache(rw http.ResponseWriter, req *http.Request, cachePath string, c *cacheStat, force bool) (ok bool) { +func (h *HjProxy) hjResponseWithCache(rw http.ResponseWriter, req *http.Request, c *cacheStat, force bool) (ok bool) { if c == nil { return false } - cacheFileName := filepath.Join(cachePath, filepath.FromSlash(req.URL.Path)) + cacheFileName := filepath.Join(h.LocalCachePath, filepath.FromSlash(req.URL.Path)) age := c.ExpiresAt - time.Now().Unix() if !force && age <= 0 { return false @@ -139,9 +141,9 @@ func (h *HjProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { nowUnix := time.Now().Unix() - cacheFileName := filepath.Join(config.Hijack.LocalCachePath, filepath.FromSlash(req.URL.Path)) + cacheFileName := filepath.Join(h.LocalCachePath, filepath.FromSlash(req.URL.Path)) cached := h.getCache(req.URL.Path) - if hjResponseWithCache(rw, req, cached, false) { + if h.hjResponseWithCache(rw, req, cached, false) { return } @@ -158,7 +160,7 @@ func (h *HjProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } res, err := h.client.Do(req2) if err != nil { - if hjResponseWithCache(rw, req, cached, true) { + if h.hjResponseWithCache(rw, req, cached, true) { return } http.Error(rw, "remote: "+err.Error(), http.StatusBadGateway) @@ -178,7 +180,7 @@ func (h *HjProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } rw.WriteHeader(res.StatusCode) var body io.Reader = res.Body - if config.Hijack.EnableLocalCache && res.StatusCode == http.StatusOK { + if h.EnableLocalCache && res.StatusCode == http.StatusOK { if exp, ok := utils.ParseCacheControl(res.Header.Get("Cache-Control")); ok { if exp > 0 { os.MkdirAll(filepath.Dir(cacheFileName), 0755) @@ -211,7 +213,7 @@ type cacheStat struct { func (h *HjProxy) loadCache() (err error) { h.cache = make(map[string]*cacheStat) - fd, err := os.Open(filepath.Join(config.Hijack.LocalCachePath, "__cache.json")) + fd, err := os.Open(filepath.Join(h.LocalCachePath, "__cache.json")) if err != nil { return } @@ -220,7 +222,7 @@ func (h *HjProxy) loadCache() (err error) { } func (h *HjProxy) saveCache() (err error) { - fd, err := os.Create(filepath.Join(config.Hijack.LocalCachePath, "__cache.json")) + fd, err := os.Create(filepath.Join(h.LocalCachePath, "__cache.json")) if err != nil { return } diff --git a/main.go b/main.go index 663550d6..766ef043 100644 --- a/main.go +++ b/main.go @@ -62,8 +62,6 @@ var ( var startTime = time.Now() -var config Config = defaultConfig - const baseDir = "." func parseArgs() { diff --git a/notify/event.go b/notify/event.go index ca36a686..b253aaa4 100644 --- a/notify/event.go +++ b/notify/event.go @@ -22,6 +22,7 @@ package notify import ( "time" + "github.com/LiterMC/go-openbmclapi/cluster" "github.com/LiterMC/go-openbmclapi/update" ) @@ -48,6 +49,6 @@ type ( ReportStatusEvent struct { TimestampEvent - Stats *StatData + Stats *cluster.StatData } ) diff --git a/notify/manager.go b/notify/manager.go index 7ba58a45..f6ca99af 100644 --- a/notify/manager.go +++ b/notify/manager.go @@ -25,6 +25,7 @@ import ( "sync" "time" + "github.com/LiterMC/go-openbmclapi/cluster" "github.com/LiterMC/go-openbmclapi/database" "github.com/LiterMC/go-openbmclapi/log" "github.com/LiterMC/go-openbmclapi/update" @@ -192,7 +193,7 @@ func (m *Manager) OnUpdateAvaliable(release *update.GithubRelease) { } } -func (m *Manager) OnReportStatus(stats *Stats) { +func (m *Manager) OnReportStatus(stats *cluster.StatManager) { if !m.reportMux.TryLock() { return } @@ -208,7 +209,7 @@ func (m *Manager) OnReportStatus(stats *Stats) { TimestampEvent: TimestampEvent{ At: now, }, - Stats: stats.Clone(), + Stats: stats.Overall.Clone(), } res := make(chan error, 0) for _, p := range m.plugins { diff --git a/sub_commands/cmd_webdav.go b/sub_commands/cmd_webdav.go index 33ea8932..5351df30 100644 --- a/sub_commands/cmd_webdav.go +++ b/sub_commands/cmd_webdav.go @@ -21,9 +21,9 @@ package main import ( "context" + "errors" "fmt" "os" - "errors" "runtime" "sync" "sync/atomic" @@ -35,6 +35,7 @@ import ( "github.com/LiterMC/go-openbmclapi/config" "github.com/LiterMC/go-openbmclapi/log" "github.com/LiterMC/go-openbmclapi/storage" + "github.com/LiterMC/go-openbmclapi/utils" ) func cmdUploadWebdav(args []string) { @@ -236,7 +237,7 @@ func cmdUploadWebdav(args []string) { bar.SetTotal(size, false) log.Debugf("Uploading %s/%s", s.String(), hash) - err := s.Create(hash, ProxyReadSeeker(fd, bar, totalBar, lastInc)) + err := s.Create(hash, utils.ProxyPBReadSeeker(fd, bar, totalBar, lastInc)) uploadedFiles.Add(1) if err != nil { log.Errorf("Cannot create %s at %s: %v", hash, s.String(), err) From caae4f755333f9924ee2f0d1f5360e1896d167ce Mon Sep 17 00:00:00 2001 From: zyxkad Date: Thu, 8 Aug 2024 09:06:17 -0700 Subject: [PATCH 14/36] add webhook --- cluster/cluster.go | 10 +- cluster/stat.go | 19 ++++ handler.go | 12 +- notify/event.go | 2 +- notify/manager.go | 2 +- notify/webhook/webhook.go | 223 ++++++++++++++++++++++++++++++++++++++ notify/webpush/webpush.go | 2 +- 7 files changed, 258 insertions(+), 12 deletions(-) create mode 100644 notify/webhook/webhook.go diff --git a/cluster/cluster.go b/cluster/cluster.go index e437aebb..848bbc58 100644 --- a/cluster/cluster.go +++ b/cluster/cluster.go @@ -55,12 +55,12 @@ type Cluster struct { hits atomic.Int32 hbts atomic.Int64 - mux sync.RWMutex - status atomic.Int32 + mux sync.RWMutex + status atomic.Int32 socketStatus atomic.Int32 - socket *socket.Socket - client *http.Client - cachedCli *http.Client + socket *socket.Socket + client *http.Client + cachedCli *http.Client authTokenMux sync.RWMutex authToken *ClusterToken diff --git a/cluster/stat.go b/cluster/stat.go index a3e50a2a..30b3d8b3 100644 --- a/cluster/stat.go +++ b/cluster/stat.go @@ -40,6 +40,8 @@ type StatManager struct { Storages map[string]*StatData } +var _ json.Marshaler = (*StatManager)(nil) + func NewStatManager() *StatManager { return &StatManager{ Overall: new(StatData), @@ -78,6 +80,9 @@ func (m *StatManager) AddHit(bytes int64, cluster, storage string) { func (m *StatManager) Load(dir string) error { clustersDir, storagesDir := filepath.Join(dir, "clusters"), filepath.Join(dir, "storages") + m.mux.Lock() + defer m.mux.Unlock() + *m.Overall = StatData{} clear(m.Clusters) clear(m.Storages) @@ -119,6 +124,9 @@ func (m *StatManager) Load(dir string) error { func (m *StatManager) Save(dir string) error { clustersDir, storagesDir := filepath.Join(dir, "clusters"), filepath.Join(dir, "storages") + m.mux.RLock() + defer m.mux.RUnlock() + if err := m.Overall.save(filepath.Join(dir, statsOverallFileName)); err != nil { return err } @@ -141,6 +149,17 @@ func (m *StatManager) Save(dir string) error { return nil } +func (m *StatManager) MarshalJSON() ([]byte, error) { + m.mux.RLock() + defer m.mux.RUnlock() + + return json.Marshal(map[string]any{ + "overall": m.Overall, + "clusters": m.Clusters, + "storages": m.Storages, + }) +} + type statInstData struct { Hits int32 `json:"hits"` Bytes int64 `json:"bytes"` diff --git a/handler.go b/handler.go index 1106d80d..5fad8ca5 100644 --- a/handler.go +++ b/handler.go @@ -405,6 +405,7 @@ func (cr *Cluster) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } if err := cr.storages[0].ServeMeasure(rw, req, n); err != nil { log.Errorf("Could not serve measure %d: %v", n, err) + SetAccessInfo(req, "error", err.Error()) http.Error(rw, err.Error(), http.StatusInternalServerError) } return @@ -472,18 +473,20 @@ func (cr *Cluster) handleDownload(rw http.ResponseWriter, req *http.Request, has if !ok { if err := cr.DownloadFile(req.Context(), hash); err != nil { // TODO: check if the file exists - http.Error(rw, "Cannot download file from center server: "+err.Error(), http.StatusInternalServerError) + estr := "Cannot download file from center server: " + err.Error() + SetAccessInfo(req, "error", estr) + http.Error(rw, estr, http.StatusInternalServerError) return } } var sto storage.Storage if forEachFromRandomIndexWithPossibility(cr.storageWeights, cr.storageTotalWeight, func(i int) bool { sto = cr.storages[i] - log.Debugf("[handler]: Checking %s on storage [%d] %s ...", hash, i, sto.String()) + log.Debugf("[handler]: Checking %s on storage [%d] %s ...", hash, i, sto.Options().Id) sz, er := sto.ServeDownload(rw, req, hash, size) if er != nil { - log.Debugf("[handler]: File %s failed on storage [%d] %s: %v", hash, i, sto.String(), er) + log.Debugf("[handler]: File %s failed on storage [%d] %s: %v", hash, i, sto.Options().Id, er) err = er return false } @@ -500,7 +503,7 @@ func (cr *Cluster) handleDownload(rw http.ResponseWriter, req *http.Request, has err = nil } if sto != nil { - SetAccessInfo(req, "storage", sto.String()) + SetAccessInfo(req, "storage", sto.Options().Id) } if err != nil { log.Debugf("[handler]: failed to serve download: %v", err) @@ -508,6 +511,7 @@ func (cr *Cluster) handleDownload(rw http.ResponseWriter, req *http.Request, has http.Error(rw, "404 Status Not Found", http.StatusNotFound) return } + SetAccessInfo(req, "error", err.Error()) if _, ok := err.(*utils.HTTPStatusError); ok { http.Error(rw, err.Error(), http.StatusBadGateway) } else { diff --git a/notify/event.go b/notify/event.go index b253aaa4..a8618c8c 100644 --- a/notify/event.go +++ b/notify/event.go @@ -49,6 +49,6 @@ type ( ReportStatusEvent struct { TimestampEvent - Stats *cluster.StatData + Stats *cluster.StatManager } ) diff --git a/notify/manager.go b/notify/manager.go index f6ca99af..d2661020 100644 --- a/notify/manager.go +++ b/notify/manager.go @@ -209,7 +209,7 @@ func (m *Manager) OnReportStatus(stats *cluster.StatManager) { TimestampEvent: TimestampEvent{ At: now, }, - Stats: stats.Overall.Clone(), + Stats: stats, } res := make(chan error, 0) for _, p := range m.plugins { diff --git a/notify/webhook/webhook.go b/notify/webhook/webhook.go new file mode 100644 index 00000000..44b8b020 --- /dev/null +++ b/notify/webhook/webhook.go @@ -0,0 +1,223 @@ +/** + * OpenBmclAPI (Golang Edition) + * Copyright (C) 2024 Kevin Z + * All rights reserved + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published + * by the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package webpush + +import ( + "bytes" + "context" + "encoding/json" + "net/http" + "sync" + "time" + + "github.com/LiterMC/go-openbmclapi/api" + "github.com/LiterMC/go-openbmclapi/cluster" + "github.com/LiterMC/go-openbmclapi/database" + "github.com/LiterMC/go-openbmclapi/internal/build" + "github.com/LiterMC/go-openbmclapi/log" + "github.com/LiterMC/go-openbmclapi/notify" + "github.com/LiterMC/go-openbmclapi/utils" +) + +type Plugin struct { + db database.DB + client *http.Client +} + +var _ notify.Plugin = (*Plugin)(nil) + +func (p *Plugin) ID() string { + return "webhook" +} + +func (p *Plugin) Init(ctx context.Context, m *notify.Manager) (err error) { + p.db = m.DB() + p.client = m.HTTPClient() + return nil +} + +func (p *Plugin) sendMessage(ctx context.Context, message []byte, r *api.WebhookRecord) (err error) { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, r.EndPoint, bytes.NewReader(message)) + if err != nil { + return + } + req.Header.Set("User-Agent", build.ClusterUserAgentFull) + req.Header.Set("Content-Type", "application/json") + if r.Auth != nil { + req.Header.Set("Authorization", *r.Auth) + } + + resp, err := p.client.Do(req) + if err != nil { + return + } + if resp.StatusCode/100 != 2 { + return utils.NewHTTPStatusErrorFromResponse(resp) + } + return +} + +func (p *Plugin) sendMessageIf(ctx context.Context, msg Message, filter func(*api.WebhookRecord) bool) (err error) { + message, err := json.Marshal(msg) + if err != nil { + return + } + log.Debugf("Triggering webhook: %s", message) + var wg sync.WaitGroup + err = p.db.ForEachWebhook(func(record *api.WebhookRecord) error { + if filter(record) { + log.Debugf("Triggering webhook at %s", record.EndPoint) + wg.Add(1) + go func(record api.WebhookRecord) { + defer wg.Done() + if err := p.sendMessage(ctx, message, &record); err != nil { + log.Warnf("Error when triggering webhook: %v", err) + } + }(*record) + } + return nil + }) + wg.Wait() + return +} + +type MessageType string + +const ( + TypeEnabled MessageType = "enabled" + TypeDisabled MessageType = "disabled" + TypeSyncBegin MessageType = "syncbegin" + TypeSyncDone MessageType = "syncdone" + TypeUpdates MessageType = "updates" + TypeDailyReport MessageType = "daily-report" +) + +type ( + Message struct { + Type MessageType `json:"type"` + Data any `json:"data"` + } + + EnabledData struct { + At time.Time `json:"at"` + } + + DisabledData struct { + At time.Time `json:"at"` + } + + SyncBeginData struct { + At time.Time `json:"at"` + Count int `json:"count"` + Size int64 `json:"size"` + } + + SyncDoneData struct { + At time.Time `json:"at"` + } + + UpdatesData struct { + Tag string `json:"tag"` + } + + DailyReportData struct { + Stats *cluster.StatManager `json:"stats"` + } +) + +func (p *Plugin) OnEnabled(e *notify.EnabledEvent) error { + message := Message{ + Type: TypeEnabled, + Data: EnabledData{ + At: e.At, + }, + } + + tctx, cancel := context.WithTimeout(context.Background(), time.Second*15) + defer cancel() + return p.sendMessageIf(tctx, message, func(record *api.WebhookRecord) bool { return record.Scopes.Enabled }) +} + +func (p *Plugin) OnDisabled(e *notify.DisabledEvent) error { + message := Message{ + Type: TypeDisabled, + Data: DisabledData{ + At: e.At, + }, + } + + tctx, cancel := context.WithTimeout(context.Background(), time.Second*15) + defer cancel() + return p.sendMessageIf(tctx, message, func(record *api.WebhookRecord) bool { return record.Scopes.Disabled }) +} + +func (p *Plugin) OnSyncBegin(e *notify.SyncBeginEvent) error { + message := Message{ + Type: TypeSyncBegin, + Data: SyncBeginData{ + At: e.At, + Count: e.Count, + Size: e.Size, + }, + } + + tctx, cancel := context.WithTimeout(context.Background(), time.Second*15) + defer cancel() + return p.sendMessageIf(tctx, message, func(record *api.WebhookRecord) bool { return record.Scopes.SyncBegin }) +} + +func (p *Plugin) OnSyncDone(e *notify.SyncDoneEvent) error { + message := Message{ + Type: TypeSyncDone, + Data: SyncDoneData{ + At: e.At, + }, + } + + tctx, cancel := context.WithTimeout(context.Background(), time.Second*15) + defer cancel() + return p.sendMessageIf(tctx, message, func(record *api.WebhookRecord) bool { return record.Scopes.SyncDone }) +} + +func (p *Plugin) OnUpdateAvaliable(e *notify.UpdateAvaliableEvent) error { + message := Message{ + Type: TypeUpdates, + Data: UpdatesData{ + Tag: e.Release.Tag.String(), + }, + } + + tctx, cancel := context.WithTimeout(context.Background(), time.Second*15) + defer cancel() + return p.sendMessageIf(tctx, message, func(record *api.WebhookRecord) bool { return record.Scopes.Updates }) +} + +func (p *Plugin) OnReportStatus(e *notify.ReportStatusEvent) (err error) { + message := Message{ + Type: TypeDailyReport, + Data: DailyReportData{ + Stats: e.Stats, + }, + } + + tctx, cancel := context.WithTimeout(context.Background(), time.Second*15) + defer cancel() + return p.sendMessageIf(tctx, message, func(record *api.WebhookRecord) bool { return record.Scopes.DailyReport }) +} diff --git a/notify/webpush/webpush.go b/notify/webpush/webpush.go index cc5b2552..36a4da78 100644 --- a/notify/webpush/webpush.go +++ b/notify/webpush/webpush.go @@ -362,7 +362,7 @@ func (p *Plugin) OnUpdateAvaliable(e *notify.UpdateAvaliableEvent) error { } func (p *Plugin) OnReportStatus(e *notify.ReportStatusEvent) (err error) { - stat, err := json.Marshal(e.Stats) + stat, err := json.Marshal(e.Stats.Overall) if err != nil { log.Errorf("Cannot marshal subscribe message: %v", err) return From 74c811d75eb0e5342d12d261d640ec03e087cabd Mon Sep 17 00:00:00 2001 From: zyxkad Date: Thu, 8 Aug 2024 23:06:23 -0700 Subject: [PATCH 15/36] add license header to installer.sh --- installer/service/installer.sh | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/installer/service/installer.sh b/installer/service/installer.sh index 1f8ee5b3..b875d6c0 100755 --- a/installer/service/installer.sh +++ b/installer/service/installer.sh @@ -1,4 +1,20 @@ #!/bin/bash +# Go-OpenBMCLAPI service installer +# Copyright (C) 2024 the Go-OpenBMCLAPI Authors +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published +# by the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . + # if [ $(id -u) -ne 0 ]; then # echo -e "\e[31mERROR: Not root user\e[0m" From 18ee9074be42308a8dac1c8103675feacbd2e233 Mon Sep 17 00:00:00 2001 From: zyxkad Date: Sat, 10 Aug 2024 10:11:20 -0700 Subject: [PATCH 16/36] update cluster handler --- cluster/cluster.go | 16 ++- cluster/handler.go | 167 +++++++++++++++++++++- cluster/stat.go | 11 +- config/advanced.go | 3 +- config/server.go | 9 +- dashboard.go | 2 +- handler.go | 349 ++++++--------------------------------------- main.go | 67 +++++---- 8 files changed, 266 insertions(+), 358 deletions(-) diff --git a/cluster/cluster.go b/cluster/cluster.go index 848bbc58..f48ae384 100644 --- a/cluster/cluster.go +++ b/cluster/cluster.go @@ -26,6 +26,7 @@ import ( "net/http" "regexp" "runtime" + "strings" "sync" "sync/atomic" "time" @@ -107,6 +108,17 @@ func (cr *Cluster) PublicHosts() []string { return cr.opts.PublicHosts } +// AcceptHost checks if the host is binded to the cluster +func (cr *Cluster) AcceptHost(host string) bool { + host = strings.ToUpper(host) + for _, h := range cr.opts.PublicHosts { + if h == "*" || strings.ToUpper(h) == host { + return true + } + } + return false +} + // Init do setup on the cluster // Init should only be called once during the cluster's whole life // The context passed in only affect the logical of Init method @@ -286,8 +298,6 @@ func (cr *Cluster) disable(ctx context.Context) error { return err } select { - case <-ctx.Done(): - return ctx.Err() case data := <-resCh: log.Debug("disable ack:", data) if ero := data[0]; ero != nil { @@ -295,6 +305,8 @@ func (cr *Cluster) disable(ctx context.Context) error { } else if !data[1].(bool) { return errors.New("Disable acked non true value") } + case <-ctx.Done(): + return ctx.Err() } return nil } diff --git a/cluster/handler.go b/cluster/handler.go index 40077429..a14ca248 100644 --- a/cluster/handler.go +++ b/cluster/handler.go @@ -20,16 +20,80 @@ package cluster import ( + "crypto" + "encoding/base64" + "encoding/hex" + "fmt" + "io" "net/http" + "net/textproto" + "strconv" + "strings" + "time" + "github.com/LiterMC/go-openbmclapi/api" "github.com/LiterMC/go-openbmclapi/log" "github.com/LiterMC/go-openbmclapi/storage" ) func (cr *Cluster) HandleFile(req *http.Request, rw http.ResponseWriter, hash string, size int64) { defer log.RecoverPanic(nil) - var err error - if cr.storageManager.ForEachFromRandom(cr.storages, func(s storage.Storage) bool { + + if !cr.Enabled() { + // do not serve file if cluster is not enabled yet + http.Error(rw, "Cluster is not enabled yet", http.StatusServiceUnavailable) + return + } + + if !cr.checkQuerySign(req, hash) { + http.Error(rw, "Cannot verify signature", http.StatusForbidden) + return + } + + log.Debugf("Handling download %s", hash) + + keepaliveRec := req.Context().Value("go-openbmclapi.handler.no.record.for.keepalive") != true + + countUA := true + if r := req.Header.Get("Range"); r != "" { + api.SetAccessInfo(req, "range", r) + if start, ok := parseRangeFirstStart(r); ok && start != 0 { + countUA = false + } + } + ua := "" + if countUA { + ua, _, _ = strings.Cut(req.UserAgent(), " ") + ua, _, _ = strings.Cut(ua, "/") + } + + rw.Header().Set("X-Bmclapi-Hash", hash) + + if _, ok := emptyHashes[hash]; ok { + name := req.URL.Query().Get("name") + rw.Header().Set("ETag", `"`+hash+`"`) + rw.Header().Set("Cache-Control", "public, max-age=31536000, immutable") // cache for a year + rw.Header().Set("Content-Type", "application/octet-stream") + rw.Header().Set("Content-Length", "0") + if name != "" { + rw.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", name)) + } + rw.WriteHeader(http.StatusOK) + cr.statManager.AddHit(0, cr.ID(), "", ua) + if keepaliveRec { + cr.hits.Add(1) + } + return + } + + api.SetAccessInfo(req, "cluster", cr.ID()) + + var ( + sto storage.Storage + err error + ) + ok := cr.storageManager.ForEachFromRandom(cr.storages, func(s storage.Storage) bool { + sto = s opts := s.Options() log.Debugf("[handler]: Checking %s on storage %s ...", hash, opts.Id) @@ -40,13 +104,104 @@ func (cr *Cluster) HandleFile(req *http.Request, rw http.ResponseWriter, hash st return false } if sz >= 0 { - cr.hits.Add(1) - cr.hbts.Add(sz) - cr.statManager.AddHit(sz, cr.ID(), opts.Id) + if keepaliveRec { + cr.hits.Add(1) + cr.hbts.Add(sz) + } + cr.statManager.AddHit(sz, cr.ID(), opts.Id, ua) } return true - }) { + }) + if sto != nil { + api.SetAccessInfo(req, "storage", sto.Options().Id) + } + if ok { return } http.Error(rw, err.Error(), http.StatusInternalServerError) } + +func (cr *Cluster) HandleMeasure(req *http.Request, rw http.ResponseWriter, size int) { + if !cr.Enabled() { + // do not serve file if cluster is not enabled yet + http.Error(rw, "Cluster is not enabled yet", http.StatusServiceUnavailable) + return + } + + if !cr.checkQuerySign(req, req.URL.Path) { + http.Error(rw, "Cannot verify signature", http.StatusForbidden) + return + } + + api.SetAccessInfo(req, "cluster", cr.ID()) + storage := cr.storageManager.Storages[cr.storages[0]] + api.SetAccessInfo(req, "storage", storage.Options().Id) + if err := storage.ServeMeasure(rw, req, size); err != nil { + log.Errorf("Could not serve measure %d: %v", size, err) + api.SetAccessInfo(req, "error", err.Error()) + http.Error(rw, err.Error(), http.StatusInternalServerError) + } +} + +func (cr *Cluster) checkQuerySign(req *http.Request, hash string) bool { + if cr.opts.SkipSignatureCheck { + return true + } + query := req.URL.Query() + sign, e := query.Get("s"), query.Get("e") + if len(sign) == 0 || len(e) == 0 { + return false + } + before, err := strconv.ParseInt(e, 36, 64) + if err != nil { + return false + } + if time.Now().UnixMilli() > before { + return false + } + hs := crypto.SHA1.New() + io.WriteString(hs, cr.Secret()) + io.WriteString(hs, hash) + io.WriteString(hs, e) + var ( + buf [20]byte + sbuf [27]byte + ) + base64.RawURLEncoding.Encode(sbuf[:], hs.Sum(buf[:0])) + if (string)(sbuf[:]) != sign { + return false + } + return true +} + +var emptyHashes = func() (hashes map[string]struct{}) { + hashMethods := []crypto.Hash{ + crypto.MD5, crypto.SHA1, + } + hashes = make(map[string]struct{}, len(hashMethods)) + for _, h := range hashMethods { + hs := hex.EncodeToString(h.New().Sum(nil)) + hashes[hs] = struct{}{} + } + return +}() + +// Note: this method is a fast parse, it does not deeply check if the range is valid or not +func parseRangeFirstStart(rg string) (start int64, ok bool) { + const b = "bytes=" + if rg, ok = strings.CutPrefix(rg, b); !ok { + return + } + rg, _, _ = strings.Cut(rg, ",") + if rg, _, ok = strings.Cut(rg, "-"); !ok { + return + } + if rg = textproto.TrimString(rg); rg == "" { + return -1, true + } + start, err := strconv.ParseInt(rg, 10, 64) + if err != nil { + return 0, false + } + return start, true +} diff --git a/cluster/stat.go b/cluster/stat.go index 30b3d8b3..8a772458 100644 --- a/cluster/stat.go +++ b/cluster/stat.go @@ -50,7 +50,7 @@ func NewStatManager() *StatManager { } } -func (m *StatManager) AddHit(bytes int64, cluster, storage string) { +func (m *StatManager) AddHit(bytes int64, cluster, storage string, userAgent string) { m.mux.Lock() defer m.mux.Unlock() @@ -59,6 +59,9 @@ func (m *StatManager) AddHit(bytes int64, cluster, storage string) { Bytes: bytes, } m.Overall.update(data) + if userAgent != "" { + m.Overall.Accesses[userAgent]++ + } if cluster != "" { d := m.Clusters[cluster] if d == nil { @@ -66,6 +69,9 @@ func (m *StatManager) AddHit(bytes int64, cluster, storage string) { m.Clusters[cluster] = d } d.update(data) + if userAgent != "" { + d.Accesses[userAgent]++ + } } if storage != "" { d := m.Storages[storage] @@ -74,6 +80,9 @@ func (m *StatManager) AddHit(bytes int64, cluster, storage string) { m.Storages[storage] = d } d.update(data) + if userAgent != "" { + d.Accesses[userAgent]++ + } } } diff --git a/config/advanced.go b/config/advanced.go index 8d3c9805..7b3fcfb8 100644 --- a/config/advanced.go +++ b/config/advanced.go @@ -26,10 +26,9 @@ type AdvancedConfig struct { NoGC bool `yaml:"no-gc"` HeavyCheckInterval int `yaml:"heavy-check-interval"` KeepaliveTimeout int `yaml:"keepalive-timeout"` - SkipSignatureCheck bool `yaml:"skip-signature-check"` NoFastEnable bool `yaml:"no-fast-enable"` WaitBeforeEnable int `yaml:"wait-before-enable"` - DoNotRedirectHTTPSToSecureHostname bool `yaml:"do-NOT-redirect-https-to-SECURE-hostname"` + // DoNotRedirectHTTPSToSecureHostname bool `yaml:"do-NOT-redirect-https-to-SECURE-hostname"` DoNotOpenFAQOnWindows bool `yaml:"do-not-open-faq-on-windows"` } diff --git a/config/server.go b/config/server.go index d1b43046..6b16e00e 100644 --- a/config/server.go +++ b/config/server.go @@ -32,10 +32,11 @@ import ( ) type ClusterOptions struct { - Id string `json:"id" yaml:"id"` - Secret string `json:"secret" yaml:"secret"` - PublicHosts []string `json:"public-hosts" yaml:"public-hosts"` - Prefix string `json:"prefix" yaml:"prefix"` + Id string `json:"id" yaml:"id"` + Secret string `json:"secret" yaml:"secret"` + PublicHosts []string `json:"public-hosts" yaml:"public-hosts"` + Prefix string `json:"prefix" yaml:"prefix"` + SkipSignatureCheck bool `json:"skip-signature-check" yaml:"skip-signature-check"` } type ClusterGeneralConfig struct { diff --git a/dashboard.go b/dashboard.go index b07b0e8f..f8fbcaef 100644 --- a/dashboard.go +++ b/dashboard.go @@ -60,7 +60,7 @@ var dsbManifest = func() (dsbManifest map[string]any) { return }() -func (cr *Cluster) serveDashboard(rw http.ResponseWriter, req *http.Request, pth string) { +func (r *Runner) serveDashboard(rw http.ResponseWriter, req *http.Request, pth string) { if req.Method != http.MethodGet && req.Method != http.MethodHead { rw.Header().Set("Allow", http.MethodGet+", "+http.MethodHead) http.Error(rw, "405 Method Not Allowed", http.StatusMethodNotAllowed) diff --git a/handler.go b/handler.go index 5fad8ca5..3c762647 100644 --- a/handler.go +++ b/handler.go @@ -38,6 +38,8 @@ import ( "strings" "time" + "github.com/LiterMC/go-openbmclapi/api" + "github.com/LiterMC/go-openbmclapi/api/v0" "github.com/LiterMC/go-openbmclapi/internal/build" "github.com/LiterMC/go-openbmclapi/limited" "github.com/LiterMC/go-openbmclapi/log" @@ -99,15 +101,18 @@ func (r *accessRecord) String() string { return buf.String() } -func (cr *Cluster) GetHandler() http.Handler { - cr.apiRateLimiter = limited.NewAPIRateMiddleWare(RealAddrCtxKey, loggedUserKey) - cr.apiRateLimiter.SetAnonymousRateLimit(config.RateLimit.Anonymous) - cr.apiRateLimiter.SetLoggedRateLimit(config.RateLimit.Logged) - cr.handlerAPIv0 = http.StripPrefix("/api/v0", cr.cliIdHandle(cr.initAPIv0())) - cr.handlerAPIv1 = http.StripPrefix("/api/v1", cr.cliIdHandle(cr.initAPIv1())) - cr.hijackHandler = http.StripPrefix("/bmclapi", cr.hijackProxy) +var wsUpgrader = &websocket.Upgrader{ + HandshakeTimeout: time.Second * 30, +} + +func (r *Runner) GetHandler() http.Handler { + r.apiRateLimiter = limited.NewAPIRateMiddleWare(RealAddrCtxKey, loggedUserKey) + r.apiRateLimiter.SetAnonymousRateLimit(r.RateLimit.Anonymous) + r.apiRateLimiter.SetLoggedRateLimit(r.RateLimit.Logged) + r.handlerAPIv0 = http.StripPrefix("/api/v0", v0.NewHandler(wsUpgrader)) + r.hijackHandler = http.StripPrefix("/bmclapi", r.hijackProxy) - handler := utils.NewHttpMiddleWareHandler(cr) + handler := utils.NewHttpMiddleWareHandler(r) // recover panic and log it handler.UseFunc(func(rw http.ResponseWriter, req *http.Request, next http.Handler) { defer log.RecoverPanic(func(any) { @@ -115,55 +120,9 @@ func (cr *Cluster) GetHandler() http.Handler { }) next.ServeHTTP(rw, req) }) + handler.Use(r.apiRateLimiter) - if !config.Advanced.DoNotRedirectHTTPSToSecureHostname { - // rediect the client to the first public host if it is connecting with a unsecure host - handler.UseFunc(func(rw http.ResponseWriter, req *http.Request, next http.Handler) { - host, _, err := net.SplitHostPort(req.Host) - if err != nil { - host = req.Host - } - if host != "" && len(cr.publicHosts) > 0 { - host = strings.ToLower(host) - needRed := true - for _, h := range cr.publicHosts { // cr.publicHosts are already lower case - if h, ok := strings.CutPrefix(h, "*."); ok { - if strings.HasSuffix(host, h) { - needRed = false - break - } - } else if host == h { - needRed = false - break - } - } - if needRed { - host := "" - for _, h := range cr.publicHosts { - if !strings.HasSuffix(h, "*.") { - host = h - break - } - } - if host != "" { - u := *req.URL - u.Scheme = "https" - u.Host = net.JoinHostPort(host, strconv.Itoa((int)(cr.publicPort))) - - log.Debugf("Redirecting from %s to %s", req.Host, u.String()) - - rw.Header().Set("Location", u.String()) - rw.Header().Set("Content-Length", "0") - rw.WriteHeader(http.StatusFound) - return - } - } - } - next.ServeHTTP(rw, req) - }) - } - - handler.Use(cr.getRecordMiddleWare()) + handler.Use(r.getRecordMiddleWare()) return handler } @@ -176,64 +135,6 @@ func (cr *Cluster) getRecordMiddleWare() utils.MiddleWareFunc { } recordCh := make(chan record, 1024) - go func() { - defer log.RecoverPanic(nil) - - <-cr.WaitForEnable() - disabled := cr.Disabled() - - updateTicker := time.NewTicker(time.Minute) - defer updateTicker.Stop() - - var ( - total int - totalUsed float64 - totalBytes float64 - uas = make(map[string]int, 10) - ) - for { - select { - case <-updateTicker.C: - cr.stats.Lock() - - log.Infof("Served %d requests, total responsed body = %s, total IO waiting time = %.2fs", - total, utils.BytesToUnit(totalBytes), totalUsed) - for ua, v := range uas { - if ua == "" { - ua = "[Unknown]" - } - cr.stats.Accesses[ua] += v - } - - total = 0 - totalUsed = 0 - totalBytes = 0 - clear(uas) - - cr.stats.Unlock() - case rec := <-recordCh: - total++ - totalUsed += rec.used - totalBytes += rec.bytes - if !rec.skipUA { - uas[rec.ua]++ - } - case <-disabled: - total = 0 - totalUsed = 0 - totalBytes = 0 - clear(uas) - - select { - case <-cr.WaitForEnable(): - disabled = cr.Disabled() - case <-time.After(time.Hour): - return - } - } - } - }() - return func(rw http.ResponseWriter, req *http.Request, next http.Handler) { ua := req.UserAgent() var addr string @@ -281,57 +182,9 @@ func (cr *Cluster) getRecordMiddleWare() utils.MiddleWareFunc { accRec.Extra = extraInfoMap } log.LogAccess(log.LevelInfo, accRec) - - if srw.Status < 200 || 400 <= srw.Status { - return - } - if !strings.HasPrefix(req.URL.Path, "/download/") { - return - } - var rec record - rec.used = used.Seconds() - rec.bytes = (float64)(srw.Wrote) - ua, _, _ = strings.Cut(ua, " ") - rec.ua, _, _ = strings.Cut(ua, "/") - rec.skipUA = extraInfoMap["skip-ua-count"] != nil - select { - case recordCh <- rec: - default: - } } } -func (cr *Cluster) checkQuerySign(req *http.Request, hash string, secret string) bool { - if config.Advanced.SkipSignatureCheck { - return true - } - query := req.URL.Query() - sign, e := query.Get("s"), query.Get("e") - if len(sign) == 0 || len(e) == 0 { - return false - } - before, err := strconv.ParseInt(e, 36, 64) - if err != nil { - return false - } - if time.Now().UnixMilli() > before { - return false - } - hs := crypto.SHA1.New() - io.WriteString(hs, secret) - io.WriteString(hs, hash) - io.WriteString(hs, e) - var ( - buf [20]byte - sbuf [27]byte - ) - base64.RawURLEncoding.Encode(sbuf[:], hs.Sum(buf[:0])) - if (string)(sbuf[:]) != sign { - return false - } - return true -} - var emptyHashes = func() (hashes map[string]struct{}) { hashMethods := []crypto.Hash{ crypto.MD5, crypto.SHA1, @@ -347,7 +200,7 @@ var emptyHashes = func() (hashes map[string]struct{}) { //go:embed robots.txt var robotTxtContent string -func (cr *Cluster) ServeHTTP(rw http.ResponseWriter, req *http.Request) { +func (r *Runner) ServeHTTP(rw http.ResponseWriter, req *http.Request) { method := req.Method u := req.URL @@ -368,19 +221,13 @@ func (cr *Cluster) ServeHTTP(rw http.ResponseWriter, req *http.Request) { return } - if !cr.checkQuerySign(req, hash, cr.clusterSecret) { - http.Error(rw, "Cannot verify signature", http.StatusForbidden) - return - } - - if !cr.shouldEnable.Load() { - // do not serve file if cluster is not enabled yet - http.Error(rw, "Cluster is not enabled yet", http.StatusServiceUnavailable) - return + for _, cr := range r.clusters { + if cr.AcceptHost(req.Host) { + cr.HandleFile(rw, req, hash) + return + } } - - log.Debugf("Handling download %s", hash) - cr.handleDownload(rw, req, hash) + http.Error(rw, "Host have not bind to a cluster", http.StatusNotFound) return case strings.HasPrefix(rawpath, "/measure/"): if method != http.MethodGet && method != http.MethodHead { @@ -389,162 +236,50 @@ func (cr *Cluster) ServeHTTP(rw http.ResponseWriter, req *http.Request) { return } - if !cr.checkQuerySign(req, u.Path, cr.clusterSecret) { - http.Error(rw, "Cannot verify signature", http.StatusForbidden) - return - } - - size := rawpath[len("/measure/"):] - n, e := strconv.Atoi(size) + size, e := strconv.Atoi(rawpath[len("/measure/"):]) if e != nil { http.Error(rw, e.Error(), http.StatusBadRequest) return - } else if n < 0 || n > 200 { - http.Error(rw, fmt.Sprintf("measure size %d out of range (0, 200]", n), http.StatusBadRequest) + } else if size < 0 || size > 200 { + http.Error(rw, fmt.Sprintf("measure size %d out of range (0, 200]", size), http.StatusBadRequest) return } - if err := cr.storages[0].ServeMeasure(rw, req, n); err != nil { - log.Errorf("Could not serve measure %d: %v", n, err) - SetAccessInfo(req, "error", err.Error()) - http.Error(rw, err.Error(), http.StatusInternalServerError) + + for _, cr := range r.clusters { + if cr.AcceptHost(req.Host) { + cr.HandleFile(rw, req, hash) + return + } } + http.Error(rw, "Host have not bind to a cluster", http.StatusNotFound) + return + case rawpath == "/robots.txt": + http.ServeContent(rw, req, "robots.txt", time.Time{}, strings.NewReader(robotTxtContent)) return case strings.HasPrefix(rawpath, "/api/"): version, _, _ := strings.Cut(rawpath[len("/api/"):], "/") switch version { case "v0": - cr.handlerAPIv0.ServeHTTP(rw, req) + r.handlerAPIv0.ServeHTTP(rw, req) return case "v1": - cr.handlerAPIv1.ServeHTTP(rw, req) + r.handlerAPIv1.ServeHTTP(rw, req) return } - case rawpath == "/robots.txt": - http.ServeContent(rw, req, "robots.txt", time.Time{}, strings.NewReader(robotTxtContent)) + case rawpath == "/" || rawpath == "/dashboard": + http.Redirect(rw, req, "/dashboard/", http.StatusFound) return case strings.HasPrefix(rawpath, "/dashboard/"): - if !config.Dashboard.Enable { + if !r.DashboardEnabled { http.NotFound(rw, req) return } pth := rawpath[len("/dashboard/"):] - cr.serveDashboard(rw, req, pth) - return - case rawpath == "/" || rawpath == "/dashboard": - http.Redirect(rw, req, "/dashboard/", http.StatusFound) + r.serveDashboard(rw, req, pth) return case strings.HasPrefix(rawpath, "/bmclapi/"): - cr.hijackHandler.ServeHTTP(rw, req) + r.hijackHandler.ServeHTTP(rw, req) return } http.NotFound(rw, req) } - -func (cr *Cluster) handleDownload(rw http.ResponseWriter, req *http.Request, hash string) { - keepaliveRec := req.Context().Value("go-openbmclapi.handler.no.record.for.keepalive") != true - rw.Header().Set("X-Bmclapi-Hash", hash) - - if _, ok := emptyHashes[hash]; ok { - name := req.URL.Query().Get("name") - rw.Header().Set("ETag", `"`+hash+`"`) - rw.Header().Set("Cache-Control", "public, max-age=31536000, immutable") // cache for a year - rw.Header().Set("Content-Type", "application/octet-stream") - rw.Header().Set("Content-Length", "0") - if name != "" { - rw.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", name)) - } - rw.WriteHeader(http.StatusOK) - cr.stats.AddHits(1, 0, "") - if !keepaliveRec { - cr.statOnlyHits.Add(1) - } - return - } - - if r := req.Header.Get("Range"); r != "" { - if start, ok := parseRangeFirstStart(r); ok && start != 0 { - SetAccessInfo(req, "skip-ua-count", "range") - } - } - - var err error - // check if file was indexed in the fileset - size, ok := cr.CachedFileSize(hash) - if !ok { - if err := cr.DownloadFile(req.Context(), hash); err != nil { - // TODO: check if the file exists - estr := "Cannot download file from center server: " + err.Error() - SetAccessInfo(req, "error", estr) - http.Error(rw, estr, http.StatusInternalServerError) - return - } - } - var sto storage.Storage - if forEachFromRandomIndexWithPossibility(cr.storageWeights, cr.storageTotalWeight, func(i int) bool { - sto = cr.storages[i] - log.Debugf("[handler]: Checking %s on storage [%d] %s ...", hash, i, sto.Options().Id) - - sz, er := sto.ServeDownload(rw, req, hash, size) - if er != nil { - log.Debugf("[handler]: File %s failed on storage [%d] %s: %v", hash, i, sto.Options().Id, er) - err = er - return false - } - if sz >= 0 { - opts := cr.storageOpts[i] - cr.stats.AddHits(1, sz, opts.Id) - if !keepaliveRec { - cr.statOnlyHits.Add(1) - cr.statOnlyHbts.Add(sz) - } - } - return true - }) { - err = nil - } - if sto != nil { - SetAccessInfo(req, "storage", sto.Options().Id) - } - if err != nil { - log.Debugf("[handler]: failed to serve download: %v", err) - if errors.Is(err, os.ErrNotExist) { - http.Error(rw, "404 Status Not Found", http.StatusNotFound) - return - } - SetAccessInfo(req, "error", err.Error()) - if _, ok := err.(*utils.HTTPStatusError); ok { - http.Error(rw, err.Error(), http.StatusBadGateway) - } else { - http.Error(rw, err.Error(), http.StatusInternalServerError) - } - if err == storage.ErrNotWorking { - log.Errorf("All storages are down, exit.") - tctx, cancel := context.WithTimeout(context.TODO(), time.Second*10) - cr.Disable(tctx) - cancel() - osExit(CodeClientOrEnvionmentError) - } - return - } - log.Debug("[handler]: download served successed") -} - -// Note: this method is a fast parse, it does not deeply check if the range is valid or not -func parseRangeFirstStart(rg string) (start int64, ok bool) { - const b = "bytes=" - if rg, ok = strings.CutPrefix(rg, b); !ok { - return - } - rg, _, _ = strings.Cut(rg, ",") - if rg, _, ok = strings.Cut(rg, "-"); !ok { - return - } - if rg = textproto.TrimString(rg); rg == "" { - return -1, true - } - start, err := strconv.ParseInt(rg, 10, 64) - if err != nil { - return 0, false - } - return start, true -} diff --git a/main.go b/main.go index 766ef043..0bb6cb0e 100644 --- a/main.go +++ b/main.go @@ -132,9 +132,9 @@ func main() { log.TrErrorf("program.exited", code) log.TrErrorf("error.exit.please.read.faq") if runtime.GOOS == "windows" && !config.Advanced.DoNotOpenFAQOnWindows { - log.TrWarnf("warn.exit.detected.windows.open.browser") - cmd := exec.Command("cmd", "/C", "start", "https://cdn.crashmc.com/https://github.com/LiterMC/go-openbmclapi?tab=readme-ov-file#faq") - cmd.Start() + // log.TrWarnf("warn.exit.detected.windows.open.browser") + // cmd := exec.Command("cmd", "/C", "start", "https://cdn.crashmc.com/https://github.com/LiterMC/go-openbmclapi?tab=readme-ov-file#faq") + // cmd.Start() time.Sleep(time.Hour) } } @@ -145,7 +145,6 @@ func main() { r := new(Runner) -START: ctx, cancel := context.WithCancel(context.Background()) config = readConfig() @@ -222,15 +221,10 @@ START: }(ctx) code := r.DoSignals(cancel) - if r.restartFlag { - goto START - } exitCode = code } type Runner struct { - restartFlag bool - cluster *Cluster clusterSvr *http.Server @@ -261,13 +255,13 @@ func (r *Runner) DoSignals(cancel context.CancelFunc) int { signal.Notify(signalCh, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) defer signal.Stop(signalCh) - r.restartFlag = false for { select { case code := <-exitCh: return code case s := <-signalCh: - if s == syscall.SIGQUIT { + switch s { + case syscall.SIGQUIT: // avaliable commands see dumpCommand := "heap" dumpFileName := filepath.Join(os.TempDir(), fmt.Sprintf("go-openbmclapi-dump-command.%d.in", os.Getpid())) @@ -302,37 +296,40 @@ func (r *Runner) DoSignals(cancel context.CancelFunc) int { } } continue + case syscall.SIGHUP: + r.ReloadConfig() + default: + cancel() + r.StopServer(signalCh) } - cancel() - shutCtx, cancelShut := context.WithTimeout(context.Background(), time.Second*15) - log.TrWarnf("warn.server.closing") - shutExit := make(chan struct{}, 0) - go func() { - defer close(shutExit) - defer cancelShut() - r.cluster.Disable(shutCtx) - log.TrWarnf("warn.httpserver.closing") - r.clusterSvr.Shutdown(shutCtx) - }() - select { - case <-shutExit: - case s := <-signalCh: - log.Warn("signal:", s) - log.Error("Second close signal received, exit") - return CodeClientError - } - log.TrWarnf("warn.server.closed") - if s == syscall.SIGHUP { - log.Info("Restarting server ...") - r.restartFlag = true - return 0 - } } return 0 } } +func (r *Runner) StopServer(sigCh <-chan os.Signal) { + shutCtx, cancelShut := context.WithTimeout(context.Background(), time.Second*15) + defer cancelShut() + log.TrWarnf("warn.server.closing") + shutDone := make(chan struct{}, 0) + go func() { + defer close(shutDone) + defer cancelShut() + r.cluster.Disable(shutCtx) + log.TrWarnf("warn.httpserver.closing") + r.clusterSvr.Shutdown(shutCtx) + }() + select { + case <-shutDone: + case s := <-sigCh: + log.Warn("signal:", s) + log.Error("Second close signal received, forcely exit") + return + } + log.TrWarnf("warn.server.closed") +} + func (r *Runner) InitCluster(ctx context.Context) { var ( dialer *net.Dialer From 4fe7a796907ffcbd0ed5638be2d525c10ff68615 Mon Sep 17 00:00:00 2001 From: zyxkad Date: Sat, 10 Aug 2024 20:40:27 -0700 Subject: [PATCH 17/36] start to reforge storage sync --- cluster/cluster.go | 13 +- cluster/config.go | 64 ------- cluster/handler.go | 4 +- cluster/storage.go | 354 +++++++++++++++++++++++++++++++++++ config/advanced.go | 2 +- config/config.go | 2 + config/server.go | 4 +- handler.go | 4 +- main.go | 334 ++++++++++++++++++--------------- storage/manager.go | 2 +- storage/storage.go | 1 + storage/storage_local.go | 4 + storage/storage_mount.go | 4 + storage/storage_webdav.go | 4 + sub_commands/cmd_compress.go | 2 +- sub_commands/cmd_webdav.go | 4 +- sync.go | 28 --- utils/http.go | 22 ++- 18 files changed, 589 insertions(+), 263 deletions(-) create mode 100644 cluster/storage.go diff --git a/cluster/cluster.go b/cluster/cluster.go index f48ae384..3f7943f2 100644 --- a/cluster/cluster.go +++ b/cluster/cluster.go @@ -63,8 +63,9 @@ type Cluster struct { client *http.Client cachedCli *http.Client - authTokenMux sync.RWMutex - authToken *ClusterToken + authTokenMux sync.RWMutex + authToken *ClusterToken + fileListLastMod int64 } func NewCluster( @@ -95,12 +96,12 @@ func (cr *Cluster) Secret() string { // Host returns the cluster public host func (cr *Cluster) Host() string { - return cr.gcfg.Host + return cr.gcfg.PublicHost } // Port returns the cluster public port func (cr *Cluster) Port() uint16 { - return cr.gcfg.Port + return cr.gcfg.PublicPort } // PublicHosts returns the cluster public hosts @@ -168,8 +169,8 @@ func (cr *Cluster) enable(ctx context.Context) error { log.TrInfof("info.cluster.enable.sending") resCh, err := cr.socket.EmitWithAck("enable", EnableData{ - Host: cr.gcfg.Host, - Port: cr.gcfg.Port, + Host: cr.gcfg.PublicHost, + Port: cr.gcfg.PublicPort, Version: build.ClusterVersion, Byoc: cr.gcfg.Byoc, NoFastEnable: cr.gcfg.NoFastEnable, diff --git a/cluster/config.go b/cluster/config.go index c8fac95d..2eebc617 100644 --- a/cluster/config.go +++ b/cluster/config.go @@ -29,12 +29,8 @@ import ( "fmt" "net/http" "net/url" - "strconv" "time" - "github.com/hamba/avro/v2" - "github.com/klauspost/compress/zstd" - "github.com/LiterMC/go-openbmclapi/log" "github.com/LiterMC/go-openbmclapi/utils" ) @@ -263,63 +259,3 @@ func (cr *Cluster) RequestCert(ctx context.Context) (ckp *CertKeyPair, err error } return } - -type FileInfo struct { - Path string `json:"path" avro:"path"` - Hash string `json:"hash" avro:"hash"` - Size int64 `json:"size" avro:"size"` - Mtime int64 `json:"mtime" avro:"mtime"` -} - -// from -var fileListSchema = avro.MustParse(`{ - "type": "array", - "items": { - "type": "record", - "name": "fileinfo", - "fields": [ - {"name": "path", "type": "string"}, - {"name": "hash", "type": "string"}, - {"name": "size", "type": "long"}, - {"name": "mtime", "type": "long"} - ] - } -}`) - -func (cr *Cluster) GetFileList(ctx context.Context, lastMod int64) (files []FileInfo, err error) { - var query url.Values - if lastMod > 0 { - query = url.Values{ - "lastModified": {strconv.FormatInt(lastMod, 10)}, - } - } - req, err := cr.makeReqWithAuth(ctx, http.MethodGet, "/openbmclapi/files", query) - if err != nil { - return - } - res, err := cr.cachedCli.Do(req) - if err != nil { - return - } - defer res.Body.Close() - switch res.StatusCode { - case http.StatusOK: - // - case http.StatusNoContent, http.StatusNotModified: - return - default: - err = utils.NewHTTPStatusErrorFromResponse(res) - return - } - log.Debug("Parsing filelist body ...") - zr, err := zstd.NewReader(res.Body) - if err != nil { - return - } - defer zr.Close() - if err = avro.NewDecoderForSchema(fileListSchema, zr).Decode(&files); err != nil { - return - } - log.Debugf("Filelist parsed, length = %d", len(files)) - return -} diff --git a/cluster/handler.go b/cluster/handler.go index a14ca248..dd2151a4 100644 --- a/cluster/handler.go +++ b/cluster/handler.go @@ -113,7 +113,7 @@ func (cr *Cluster) HandleFile(req *http.Request, rw http.ResponseWriter, hash st return true }) if sto != nil { - api.SetAccessInfo(req, "storage", sto.Options().Id) + api.SetAccessInfo(req, "storage", sto.Id()) } if ok { return @@ -135,7 +135,7 @@ func (cr *Cluster) HandleMeasure(req *http.Request, rw http.ResponseWriter, size api.SetAccessInfo(req, "cluster", cr.ID()) storage := cr.storageManager.Storages[cr.storages[0]] - api.SetAccessInfo(req, "storage", storage.Options().Id) + api.SetAccessInfo(req, "storage", storage.Id()) if err := storage.ServeMeasure(rw, req, size); err != nil { log.Errorf("Could not serve measure %d: %v", size, err) api.SetAccessInfo(req, "error", err.Error()) diff --git a/cluster/storage.go b/cluster/storage.go new file mode 100644 index 00000000..9b7a65c2 --- /dev/null +++ b/cluster/storage.go @@ -0,0 +1,354 @@ +/** + * OpenBmclAPI (Golang Edition) + * Copyright (C) 2024 Kevin Z + * All rights reserved + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published + * by the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package cluster + +import ( + "context" + "fmt" + "io" + "net/http" + "net/url" + "runtime" + "slices" + "strconv" + "sync" + "sync/atomic" + "crypto" + "time" + "encoding/hex" + + "github.com/hamba/avro/v2" + "github.com/klauspost/compress/zstd" + "github.com/vbauerster/mpb/v8" + "github.com/vbauerster/mpb/v8/decor" + + "github.com/LiterMC/go-openbmclapi/lang" + "github.com/LiterMC/go-openbmclapi/limited" + "github.com/LiterMC/go-openbmclapi/log" + "github.com/LiterMC/go-openbmclapi/storage" + "github.com/LiterMC/go-openbmclapi/utils" +) + +// from +var fileListSchema = avro.MustParse(`{ + "type": "array", + "items": { + "type": "record", + "name": "fileinfo", + "fields": [ + {"name": "path", "type": "string"}, + {"name": "hash", "type": "string"}, + {"name": "size", "type": "long"}, + {"name": "mtime", "type": "long"} + ] + } +}`) + +type FileInfo struct { + Path string `json:"path" avro:"path"` + Hash string `json:"hash" avro:"hash"` + Size int64 `json:"size" avro:"size"` + Mtime int64 `json:"mtime" avro:"mtime"` +} + +type StorageFileInfo struct { + FileInfo + Storages []storage.Storage +} + +func (cr *Cluster) GetFileList(ctx context.Context, fileMap map[string]*StorageFileInfo, forceAll bool) (err error) { + var query url.Values + lastMod := cr.fileListLastMod + if forceAll { + lastMod = 0 + } + if lastMod > 0 { + query = url.Values{ + "lastModified": {strconv.FormatInt(lastMod, 10)}, + } + } + req, err := cr.makeReqWithAuth(ctx, http.MethodGet, "/openbmclapi/files", query) + if err != nil { + return + } + res, err := cr.cachedCli.Do(req) + if err != nil { + return + } + defer res.Body.Close() + switch res.StatusCode { + case http.StatusOK: + // + case http.StatusNoContent, http.StatusNotModified: + return + default: + err = utils.NewHTTPStatusErrorFromResponse(res) + return + } + log.Debug("Parsing filelist body ...") + zr, err := zstd.NewReader(res.Body) + if err != nil { + return + } + defer zr.Close() + var files []FileInfo + if err = avro.NewDecoderForSchema(fileListSchema, zr).Decode(&files); err != nil { + return + } + + for _, f := range files { + if f.Mtime > lastMod { + lastMod = f.Mtime + } + if ff, ok := fileMap[f.Hash]; ok { + if ff.Size != f.Size { + log.Panicf("Hash conflict detected, hash of both %q (%dB) and %q (%dB) is %s", ff.Path, ff.Size, f.Path, f.Size, f.Hash) + } + for _, s := range cr.storages { + sto := cr.storageManager.Storages[s] + if i, ok := slices.BinarySearchFunc(ff.Storages, sto, storageIdSortFunc); !ok { + ff.Storages = slices.Insert(ff.Storages, i, sto) + } + } + } else { + ff := &StorageFileInfo{ + FileInfo: f, + Storages: make([]storage.Storage, len(cr.storages)), + } + for i, s := range cr.storages { + ff.Storages[i] = cr.storageManager.Storages[s] + } + slices.SortFunc(ff.Storages, storageIdSortFunc) + fileMap[f.Hash] = ff + } + } + cr.fileListLastMod = lastMod + log.Debugf("Filelist parsed, length = %d, lastMod = %d", len(files), lastMod) + return +} + +func storageIdSortFunc(a, b storage.Storage) int { + if a.Id() < b.Id() { + return -1 + } + return 1 +} + +// func SyncFiles(ctx context.Context, manager *storage.Manager, files map[string]*StorageFileInfo, heavyCheck bool) bool { +// log.TrInfof("info.sync.prepare", len(files)) + +// slices.SortFunc(files, func(a, b *StorageFileInfo) int { return a.Size - b.Size }) +// if cr.syncFiles(ctx, files, heavyCheck) != nil { +// return false +// } + +// cr.filesetMux.Lock() +// for _, f := range files { +// cr.fileset[f.Hash] = f.Size +// } +// cr.filesetMux.Unlock() + +// return true +// } + +var emptyStr string + +func checkFile( + ctx context.Context, + manager *storage.Manager, + files map[string]*StorageFileInfo, + heavy bool, + missing map[string]*StorageFileInfo, + pg *mpb.Progress, +) (err error) { + var missingCount atomic.Int32 + addMissing := func(f FileInfo, sto storage.Storage) { + missingCount.Add(1) + if info, ok := missing[f.Hash]; ok { + info.Storages = append(info.Storages, sto) + } else { + missing[f.Hash] = &StorageFileInfo{ + FileInfo: f, + Storages: []storage.Storage{sto}, + } + } + } + + log.TrInfof("info.check.start", heavy) + + var ( + checkingHash atomic.Pointer[string] + lastCheckingHash string + slots *limited.BufSlots + wg sync.WaitGroup + ) + checkingHash.Store(&emptyStr) + + if heavy { + slots = limited.NewBufSlots(runtime.GOMAXPROCS(0) * 2) + } + + bar := pg.AddBar(0, + mpb.BarRemoveOnComplete(), + mpb.PrependDecorators( + decor.Name(lang.Tr("hint.check.checking")), + decor.OnCondition( + decor.Any(func(decor.Statistics) string { + c, l := slots.Cap(), slots.Len() + return fmt.Sprintf(" (%d / %d)", c-l, c) + }), + heavy, + ), + ), + mpb.AppendDecorators( + decor.CountersNoUnit("%d / %d", decor.WCSyncSpaceR), + decor.NewPercentage("%d", decor.WCSyncSpaceR), + decor.EwmaETA(decor.ET_STYLE_GO, 60), + ), + mpb.BarExtender((mpb.BarFillerFunc)(func(w io.Writer, _ decor.Statistics) (err error) { + lastCheckingHash = *checkingHash.Load() + if lastCheckingHash != "" { + _, err = fmt.Fprintln(w, "\t", lastCheckingHash) + } + return + }), false), + ) + defer bar.Wait() + defer bar.Abort(true) + + bar.SetTotal(0x100, false) + + ssizeMap := make(map[storage.Storage]map[string]int64, len(manager.Storages)) + for _, sto := range manager.Storages { + sizeMap := make(map[string]int64, len(files)) + ssizeMap[sto] = sizeMap + wg.Add(1) + go func(sto storage.Storage, sizeMap map[string]int64) { + defer wg.Done() + start := time.Now() + var checkedMp [256]bool + if err := sto.WalkDir(func(hash string, size int64) error { + if n := utils.HexTo256(hash); !checkedMp[n] { + checkedMp[n] = true + now := time.Now() + bar.EwmaIncrement(now.Sub(start)) + start = now + } + sizeMap[hash] = size + return nil + }); err != nil { + log.Errorf("Cannot walk %s: %v", sto.Id(), err) + return + } + }(sto, sizeMap) + } + wg.Wait() + + bar.SetCurrent(0) + bar.SetTotal((int64)(len(files)), false) + for _, f := range files { + if err := ctx.Err(); err != nil { + return err + } + start := time.Now() + hash := f.Hash + checkingHash.Store(&hash) + if f.Size == 0 { + log.Debugf("Skipped empty file %s", hash) + bar.EwmaIncrement(time.Since(start)) + continue + } + for _, sto := range f.Storages { + name := sto.Id() + "/" + hash + size, ok := ssizeMap[sto][hash] + if !ok { + // log.Debugf("Could not found file %q", name) + addMissing(f.FileInfo, sto) + bar.EwmaIncrement(time.Since(start)) + continue + } + if size != f.Size { + log.TrWarnf("warn.check.modified.size", name, size, f.Size) + addMissing(f.FileInfo, sto) + bar.EwmaIncrement(time.Since(start)) + continue + } + if !heavy { + bar.EwmaIncrement(time.Since(start)) + continue + } + hashMethod, err := getHashMethod(len(hash)) + if err != nil { + log.TrErrorf("error.check.unknown.hash.method", hash) + bar.EwmaIncrement(time.Since(start)) + continue + } + _, buf, free := slots.Alloc(ctx) + if buf == nil { + return ctx.Err() + } + wg.Add(1) + go func(f FileInfo, buf []byte, free func()) { + defer log.RecoverPanic(nil) + defer wg.Done() + miss := true + r, err := sto.Open(hash) + if err != nil { + log.TrErrorf("error.check.open.failed", name, err) + } else { + hw := hashMethod.New() + _, err = io.CopyBuffer(hw, r, buf[:]) + r.Close() + if err != nil { + log.TrErrorf("error.check.hash.failed", name, err) + } else if hs := hex.EncodeToString(hw.Sum(buf[:0])); hs != hash { + log.TrWarnf("warn.check.modified.hash", name, hs, hash) + } else { + miss = false + } + } + bar.EwmaIncrement(time.Since(start)) + free() + if miss { + addMissing(f, sto) + } + }(f.FileInfo, buf, free) + } + } + wg.Wait() + + checkingHash.Store(&emptyStr) + + bar.SetTotal(-1, true) + log.TrInfof("info.check.done", missingCount.Load()) + return nil +} + +func getHashMethod(l int) (hashMethod crypto.Hash, err error) { + switch l { + case 32: + hashMethod = crypto.MD5 + case 40: + hashMethod = crypto.SHA1 + default: + err = fmt.Errorf("Unknown hash length %d", l) + } + return +} diff --git a/config/advanced.go b/config/advanced.go index 7b3fcfb8..99fc1320 100644 --- a/config/advanced.go +++ b/config/advanced.go @@ -30,5 +30,5 @@ type AdvancedConfig struct { WaitBeforeEnable int `yaml:"wait-before-enable"` // DoNotRedirectHTTPSToSecureHostname bool `yaml:"do-NOT-redirect-https-to-SECURE-hostname"` - DoNotOpenFAQOnWindows bool `yaml:"do-not-open-faq-on-windows"` + DoNotOpenFAQOnWindows bool `yaml:"do-not-open-faq-on-windows"` } diff --git a/config/config.go b/config/config.go index 747b961b..5c843c9d 100644 --- a/config/config.go +++ b/config/config.go @@ -34,6 +34,7 @@ import ( type Config struct { PublicHost string `yaml:"public-host"` PublicPort uint16 `yaml:"public-port"` + Host string `yaml:"host"` Port uint16 `yaml:"port"` Byoc bool `yaml:"byoc"` UseCert bool `yaml:"use-cert"` @@ -76,6 +77,7 @@ func NewDefaultConfig() *Config { return &Config{ PublicHost: "", PublicPort: 0, + Host: "0.0.0.0", Port: 4000, Byoc: false, TrustedXForwardedFor: false, diff --git a/config/server.go b/config/server.go index 6b16e00e..32963a58 100644 --- a/config/server.go +++ b/config/server.go @@ -40,8 +40,8 @@ type ClusterOptions struct { } type ClusterGeneralConfig struct { - Host string `json:"host"` - Port uint16 `json:"port"` + PublicHost string `json:"public-host"` + PublicPort uint16 `json:"public-port"` Byoc bool `json:"byoc"` NoFastEnable bool `json:"no-fast-enable"` MaxReconnectCount int `json:"max-reconnect-count"` diff --git a/handler.go b/handler.go index 3c762647..bedc3fb6 100644 --- a/handler.go +++ b/handler.go @@ -38,6 +38,8 @@ import ( "strings" "time" + "github.com/gorilla/websocket" + "github.com/LiterMC/go-openbmclapi/api" "github.com/LiterMC/go-openbmclapi/api/v0" "github.com/LiterMC/go-openbmclapi/internal/build" @@ -126,7 +128,7 @@ func (r *Runner) GetHandler() http.Handler { return handler } -func (cr *Cluster) getRecordMiddleWare() utils.MiddleWareFunc { +func (r *Runner) getRecordMiddleWare() utils.MiddleWareFunc { type record struct { used float64 bytes float64 diff --git a/main.go b/main.go index 0bb6cb0e..21fd45f2 100644 --- a/main.go +++ b/main.go @@ -44,11 +44,14 @@ import ( doh "github.com/libp2p/go-doh-resolver" + "github.com/LiterMC/go-openbmclapi/config" + "github.com/LiterMC/go-openbmclapi/cluster" "github.com/LiterMC/go-openbmclapi/database" "github.com/LiterMC/go-openbmclapi/internal/build" "github.com/LiterMC/go-openbmclapi/lang" "github.com/LiterMC/go-openbmclapi/limited" "github.com/LiterMC/go-openbmclapi/log" + subcmds "github.com/LiterMC/go-openbmclapi/sub_commands" _ "github.com/LiterMC/go-openbmclapi/lang/en" _ "github.com/LiterMC/go-openbmclapi/lang/zh" @@ -79,14 +82,14 @@ func parseArgs() { case "help", "--help": printHelp() os.Exit(0) - case "zip-cache": - cmdZipCache(os.Args[2:]) - os.Exit(0) - case "unzip-cache": - cmdUnzipCache(os.Args[2:]) - os.Exit(0) + // case "zip-cache": + // cmdZipCache(os.Args[2:]) + // os.Exit(0) + // case "unzip-cache": + // cmdUnzipCache(os.Args[2:]) + // os.Exit(0) case "upload-webdav": - cmdUploadWebdav(os.Args[2:]) + subcmds.CmdUploadWebdav(os.Args[2:]) os.Exit(0) default: fmt.Println("Unknown sub command:", subcmd) @@ -96,16 +99,6 @@ func parseArgs() { } } -var exitCh = make(chan int, 1) - -func osExit(n int) { - select { - case exitCh <- n: - default: - } - runtime.Goexit() -} - func main() { if runtime.GOOS == "windows" { lang.SetLang("zh-cn") @@ -118,28 +111,6 @@ func main() { printShortLicense() parseArgs() - exitCode := -1 - defer func() { - code := exitCode - if code == -1 { - select { - case code = <-exitCh: - default: - code = 0 - } - } - if code != 0 { - log.TrErrorf("program.exited", code) - log.TrErrorf("error.exit.please.read.faq") - if runtime.GOOS == "windows" && !config.Advanced.DoNotOpenFAQOnWindows { - // log.TrWarnf("warn.exit.detected.windows.open.browser") - // cmd := exec.Command("cmd", "/C", "start", "https://cdn.crashmc.com/https://github.com/LiterMC/go-openbmclapi?tab=readme-ov-file#faq") - // cmd.Start() - time.Sleep(time.Hour) - } - } - os.Exit(code) - }() defer log.RecordPanic() log.StartFlushLogFile() @@ -147,28 +118,32 @@ func main() { ctx, cancel := context.WithCancel(context.Background()) - config = readConfig() - if config.Advanced.DebugLog { + if config, err := readAndRewriteConfig(); err != nil { + log.Errorf("Config error: %s", err) + os.Exit(1) + } else { + r.Config = config + } + if r.Config.Advanced.DebugLog { log.SetLevel(log.LevelDebug) } else { log.SetLevel(log.LevelInfo) } - if config.NoAccessLog { + if r.Config.NoAccessLog { log.SetAccessLogSlots(-1) } else { - log.SetAccessLogSlots(config.AccessLogSlots) + log.SetAccessLogSlots(r.Config.AccessLogSlots) } - config.applyWebManifest(dsbManifest) + r.Config.applyWebManifest(dsbManifest) log.TrInfof("program.starting", build.ClusterVersion, build.BuildVersion) - if config.ClusterId == defaultConfig.ClusterId || config.ClusterSecret == defaultConfig.ClusterSecret { - log.TrErrorf("error.set.cluster.id") - osExit(CodeClientError) + if r.Config.Tunneler.Enable { + r.StartTunneler() } - - r.InitCluster(ctx) + r.InitServer() + r.InitClusters(ctx) go func(ctx context.Context) { defer log.RecordPanic() @@ -189,7 +164,7 @@ func main() { defer listener.Close() if err := r.clusterSvr.Serve(listener); !errors.Is(err, http.ErrServerClosed) { log.Error("Error when serving:", err) - osExit(CodeClientError) + os.Exit(1) } }(listener) @@ -221,25 +196,39 @@ func main() { }(ctx) code := r.DoSignals(cancel) - exitCode = code + if code != 0 { + log.TrErrorf("program.exited", code) + log.TrErrorf("error.exit.please.read.faq") + if runtime.GOOS == "windows" && !config.Advanced.DoNotOpenFAQOnWindows { + // log.TrWarnf("warn.exit.detected.windows.open.browser") + // cmd := exec.Command("cmd", "/C", "start", "https://cdn.crashmc.com/https://github.com/LiterMC/go-openbmclapi?tab=readme-ov-file#faq") + // cmd.Start() + time.Sleep(time.Hour) + } + } + os.Exit(code) } type Runner struct { - cluster *Cluster - clusterSvr *http.Server + Config *config.Config + + clusters map[string]*Cluster + server *http.Server tlsConfig *tls.Config listener net.Listener publicHosts []string - updating atomic.Bool + reloading atomic.Bool + updating atomic.Bool + tunnelCancel context.CancelFunc } func (r *Runner) getPublicPort() uint16 { - if config.PublicPort > 0 { - return config.PublicPort + if r.Config.PublicPort > 0 { + return r.Config.PublicPort } - return config.Port + return r.Config.Port } func (r *Runner) getCertCount() int { @@ -255,10 +244,15 @@ func (r *Runner) DoSignals(cancel context.CancelFunc) int { signal.Notify(signalCh, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) defer signal.Stop(signalCh) + var ( + forceStop context.CancelFunc + exited = make(chan struct{}, 0) + ) + for { select { - case code := <-exitCh: - return code + case <-exited: + return 0 case s := <-signalCh: switch s { case syscall.SIGQUIT: @@ -295,20 +289,46 @@ func (r *Runner) DoSignals(cancel context.CancelFunc) int { log.Info("Dump file created") } } - continue case syscall.SIGHUP: - r.ReloadConfig() + go r.ReloadConfig() default: cancel() - r.StopServer(signalCh) + if forceStop == nil { + ctx, cancel := context.WithCancel(context.Background()) + forceStop = cancel + go func() { + defer close(exited) + r.StopServer(ctx) + }() + } else { + log.Warn("signal:", s) + log.Error("Second close signal received, forcely shutting down") + forceStop() + } } } - return 0 } + return 0 } -func (r *Runner) StopServer(sigCh <-chan os.Signal) { +func (r *Runner) ReloadConfig() { + if r.reloading.CompareAndSwap(false, true) { + log.Error("Config is already reloading!") + return + } + defer r.reloading.Store(false) + + config, err := readAndRewriteConfig() + if err != nil { + log.Errorf("Config error: %s", err) + } else { + r.Config = config + } +} + +func (r *Runner) StopServer(ctx context.Context) { + r.tunnelCancel() shutCtx, cancelShut := context.WithTimeout(context.Background(), time.Second*15) defer cancelShut() log.TrWarnf("warn.server.closing") @@ -316,24 +336,39 @@ func (r *Runner) StopServer(sigCh <-chan os.Signal) { go func() { defer close(shutDone) defer cancelShut() - r.cluster.Disable(shutCtx) + var wg sync.WaitGroup + for _, cr := range r.clusters { + go func() { + defer wg.Done() + cr.Disable(shutCtx) + }() + } + wg.Wait() log.TrWarnf("warn.httpserver.closing") - r.clusterSvr.Shutdown(shutCtx) + r.server.Shutdown(shutCtx) }() select { case <-shutDone: - case s := <-sigCh: - log.Warn("signal:", s) - log.Error("Second close signal received, forcely exit") + case <-ctx.Done(): return } log.TrWarnf("warn.server.closed") } -func (r *Runner) InitCluster(ctx context.Context) { +func (r *Runner) InitServer() { + r.server = &http.Server{ + Addr: fmt.Sprintf("%s:%d", d.Config.Host, d.Config.Port), + ReadTimeout: 10 * time.Second, + IdleTimeout: 5 * time.Second, + Handler: r, + ErrorLog: log.ProxiedStdLog, + } +} + +func (r *Runner) InitClusters(ctx context.Context) { var ( dialer *net.Dialer - cache = config.Cache.newCache() + cache = r.Config.Cache.newCache() ) _ = doh.NewResolver // TODO: use doh resolver @@ -349,20 +384,12 @@ func (r *Runner) InitCluster(ctx context.Context) { ) if err := r.cluster.Init(ctx); err != nil { log.Errorf(Tr("error.init.failed"), err) - osExit(CodeClientError) - } - - r.clusterSvr = &http.Server{ - Addr: fmt.Sprintf("%s:%d", "0.0.0.0", config.Port), - ReadTimeout: 10 * time.Second, - IdleTimeout: 5 * time.Second, - Handler: r.cluster.GetHandler(), - ErrorLog: log.ProxiedStdLog, + os.Exit(1) } } -func (r *Runner) UpdateFileRecords(files []FileInfo, oldfileset map[string]int64) { - if !config.Hijack.Enable { +func (r *Runner) UpdateFileRecords(files map[string]*cluster.StorageFileInfo, oldfileset map[string]int64) { + if !r.hijacker.Enabled { return } if !r.updating.CompareAndSwap(false, true) { @@ -394,15 +421,14 @@ func (r *Runner) UpdateFileRecords(files []FileInfo, oldfileset map[string]int64 } func (r *Runner) InitSynchronizer(ctx context.Context) { - log.Info(Tr("info.filelist.fetching")) - fl, err := r.cluster.GetFileList(ctx, 0) - if err != nil { - log.Errorf(Tr("error.filelist.fetch.failed"), err) - if errors.Is(err, context.Canceled) { - return - } - if !config.Advanced.SkipFirstSync { - osExit(CodeClientOrServerError) + fileMap := make(map[string]*StorageFileInfo) + for _, cr := range r.clusters { + log.Info(Tr("info.filelist.fetching"), cr.ID()) + if err := cr.GetFileList(ctx, fileMap, true); err != nil { + log.Errorf(Tr("error.filelist.fetch.failed"), cr.ID(), err) + if errors.Is(err, context.Canceled) { + return + } } } @@ -414,10 +440,10 @@ func (r *Runner) InitSynchronizer(ctx context.Context) { } if !config.Advanced.SkipFirstSync { - if !r.cluster.SyncFiles(ctx, fl, false) { + if !r.cluster.SyncFiles(ctx, fileMap, false) { return } - go r.UpdateFileRecords(fl, nil) + go r.UpdateFileRecords(fileMap, nil) if !config.Advanced.NoGC { go r.cluster.Gc() @@ -428,29 +454,19 @@ func (r *Runner) InitSynchronizer(ctx context.Context) { } } - var lastMod int64 - for _, f := range fl { - if f.Mtime > lastMod { - lastMod = f.Mtime - } - } - createInterval(ctx, func() { - log.Info(Tr("info.filelist.fetching")) - fl, err := r.cluster.GetFileList(ctx, lastMod) - if err != nil { - log.Errorf(Tr("error.filelist.fetch.failed"), err) - return + fileMap := make(map[string]*StorageFileInfo) + for _, cr := range r.clusters { + log.Info(Tr("info.filelist.fetching"), cr.ID()) + if err := cr.GetFileList(ctx, fileMap, false); err != nil { + log.Errorf(Tr("error.filelist.fetch.failed"), cr.ID(), err) + return + } } - if len(fl) == 0 { - log.Infof("No file was updated since %s", time.UnixMilli(lastMod).Format(time.DateTime)) + if len(fileMap) == 0 { + log.Infof("No file was updated since last check") return } - for _, f := range fl { - if f.Mtime > lastMod { - lastMod = f.Mtime - } - } checkCount = (checkCount + 1) % heavyCheckInterval oldfileset := r.cluster.CloneFileset() @@ -464,12 +480,12 @@ func (r *Runner) InitSynchronizer(ctx context.Context) { } func (r *Runner) CreateHTTPServerListener(ctx context.Context) (listener net.Listener) { - listener, err := net.Listen("tcp", r.clusterSvr.Addr) + listener, err := net.Listen("tcp", r.Addr) if err != nil { - log.Errorf(Tr("error.address.listen.failed"), r.clusterSvr.Addr, err) + log.Errorf(Tr("error.address.listen.failed"), r.Addr, err) osExit(CodeEnvironmentError) } - if config.ServeLimit.Enable { + if r.Config.ServeLimit.Enable { limted := limited.NewLimitedListener(listener, config.ServeLimit.MaxConn, 0, config.ServeLimit.UploadRate*1024) limted.SetMinWriteRate(1024) listener = limted @@ -483,7 +499,7 @@ func (r *Runner) CreateHTTPServerListener(ctx context.Context) (listener net.Lis r.publicHosts = append(r.publicHosts, strings.ToLower(h)) } } - listener = newHttpTLSListener(listener, tlsConfig, r.publicHosts, r.getPublicPort()) + listener = utils.NewHttpTLSListener(listener, tlsConfig, r.publicHosts, r.getPublicPort()) } r.listener = listener return @@ -493,7 +509,7 @@ func (r *Runner) GenerateTLSConfig(ctx context.Context) (tlsConfig *tls.Config) if config.UseCert { if len(config.Certificates) == 0 { log.Error(Tr("error.cert.not.set")) - osExit(CodeClientError) + os.Exit(1) } tlsConfig = new(tls.Config) tlsConfig.Certificates = make([]tls.Certificate, len(config.Certificates)) @@ -502,31 +518,33 @@ func (r *Runner) GenerateTLSConfig(ctx context.Context) (tlsConfig *tls.Config) tlsConfig.Certificates[i], err = tls.LoadX509KeyPair(c.Cert, c.Key) if err != nil { log.Errorf(Tr("error.cert.parse.failed"), i, err) - osExit(CodeClientError) + os.Exit(1) } } } if !config.Byoc { - log.Info(Tr("info.cert.requesting")) - tctx, cancel := context.WithTimeout(ctx, time.Minute*10) - pair, err := r.cluster.RequestCert(tctx) - cancel() - if err != nil { - log.Errorf(Tr("error.cert.request.failed"), err) - osExit(CodeServerError) - } - if tlsConfig == nil { - tlsConfig = new(tls.Config) - } - var cert tls.Certificate - cert, err = tls.X509KeyPair(([]byte)(pair.Cert), ([]byte)(pair.Key)) - if err != nil { - log.Errorf(Tr("error.cert.requested.parse.failed"), err) - osExit(CodeServerUnexpectedError) + for _, cr := range r.clusters { + log.Info(Tr("info.cert.requesting"), cr.ID()) + tctx, cancel := context.WithTimeout(ctx, time.Minute*10) + pair, err := cr.RequestCert(tctx) + cancel() + if err != nil { + log.Errorf(Tr("error.cert.request.failed"), err) + os.Exit(2) + } + if tlsConfig == nil { + tlsConfig = new(tls.Config) + } + var cert tls.Certificate + cert, err = tls.X509KeyPair(([]byte)(pair.Cert), ([]byte)(pair.Key)) + if err != nil { + log.Errorf(Tr("error.cert.requested.parse.failed"), err) + os.Exit(2) + } + tlsConfig.Certificates = append(tlsConfig.Certificates, cert) + certHost, _ := parseCertCommonName(cert.Certificate[0]) + log.Infof(Tr("info.cert.requested"), certHost) } - tlsConfig.Certificates = append(tlsConfig.Certificates, cert) - certHost, _ := parseCertCommonName(cert.Certificate[0]) - log.Infof(Tr("info.cert.requested"), certHost) } r.tlsConfig = tlsConfig return @@ -554,7 +572,31 @@ func (r *Runner) EnableCluster(ctx context.Context) { } } -func (r *Runner) enableClusterByTunnel(ctx context.Context) { +func (r *Runner) StartTunneler() { + ctx, cancel := context.WithCancel(context.Background()) + r.tunnelCancel = cancel + go func() { + dur := time.Second + for { + start := time.Now() + r.RunTunneler(ctx) + used := time.Since(start) + // If the program runs no longer than 30s, then it fails too fast. + if used < time.Second*30 { + dur = min(dur*2, time.Minute*10) + } else { + dur = time.Second + } + select { + case <-time.After(dur): + case <-ctx.Done(): + return + } + } + }() +} + +func (r *Runner) RunTunneler(ctx context.Context) { cmd := exec.CommandContext(ctx, config.Tunneler.TunnelProg) log.Infof(Tr("info.tunnel.running"), cmd.String()) var ( @@ -562,18 +604,18 @@ func (r *Runner) enableClusterByTunnel(ctx context.Context) { err error ) cmd.Env = append(os.Environ(), - "CLUSTER_PORT="+strconv.Itoa((int)(config.Port))) + "CLUSTER_PORT="+strconv.Itoa((int)(r.Config.Port))) if cmdOut, err = cmd.StdoutPipe(); err != nil { log.Errorf(Tr("error.tunnel.command.prepare.failed"), err) - osExit(CodeClientUnexpectedError) + os.Exit(1) } if cmdErr, err = cmd.StderrPipe(); err != nil { log.Errorf(Tr("error.tunnel.command.prepare.failed"), err) - osExit(CodeClientUnexpectedError) + os.Exit(1) } if err = cmd.Start(); err != nil { log.Errorf(Tr("error.tunnel.command.prepare.failed"), err) - osExit(CodeClientError) + os.Exit(1) } type addrOut struct { host string @@ -651,7 +693,7 @@ func (r *Runner) enableClusterByTunnel(ctx context.Context) { if ctx.Err() != nil { return } - osExit(CodeServerOrEnvionmentError) + os.Exit(2) } case <-ctx.Done(): return @@ -663,11 +705,5 @@ func (r *Runner) enableClusterByTunnel(ctx context.Context) { return } log.Errorf("Tunnel program exited: %v", err) - osExit(CodeClientError) } - // TODO: maybe restart the tunnel program? -} - -func Tr(name string) string { - return lang.Tr(name) } diff --git a/storage/manager.go b/storage/manager.go index 36c43405..79353dd6 100644 --- a/storage/manager.go +++ b/storage/manager.go @@ -50,7 +50,7 @@ func NewManager(storages []Storage) (m *Manager) { func (m *Manager) Get(id string) Storage { for _, s := range m.Storages { - if s.Options().Id == id { + if s.Id() == id { return s } } diff --git a/storage/storage.go b/storage/storage.go index 5c5c4a83..91c99521 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -35,6 +35,7 @@ import ( type Storage interface { fmt.Stringer + Id() string // Options should return the pointer of the StorageOption that should not be modified. Options() *StorageOption // Init will be called before start to use a storage diff --git a/storage/storage_local.go b/storage/storage_local.go index b912e1d5..a0179b5d 100644 --- a/storage/storage_local.go +++ b/storage/storage_local.go @@ -66,6 +66,10 @@ func (s *LocalStorage) String() string { return fmt.Sprintf("", s.opt.CachePath) } +func (s *LocalStorage) Id() string { + return s.basicOpt.Id +} + func (s *LocalStorage) Options() *StorageOption { return &s.basicOpt } diff --git a/storage/storage_mount.go b/storage/storage_mount.go index 62a75e4a..ea2481d1 100644 --- a/storage/storage_mount.go +++ b/storage/storage_mount.go @@ -80,6 +80,10 @@ func (s *MountStorage) String() string { return fmt.Sprintf("", s.opt.Path, s.opt.RedirectBase) } +func (s *MountStorage) Id() string { + return s.basicOpt.Id +} + func (s *MountStorage) Options() *StorageOption { return &s.basicOpt } diff --git a/storage/storage_webdav.go b/storage/storage_webdav.go index 5b472cee..940067f4 100644 --- a/storage/storage_webdav.go +++ b/storage/storage_webdav.go @@ -156,6 +156,10 @@ func (s *WebDavStorage) String() string { return fmt.Sprintf("", s.opt.GetEndPoint(), s.opt.GetUsername()) } +func (s *WebDavStorage) Id() string { + return s.basicOpt.Id +} + func (s *WebDavStorage) Options() *StorageOption { return &s.basicOpt } diff --git a/sub_commands/cmd_compress.go b/sub_commands/cmd_compress.go index 4f60637e..680a3527 100644 --- a/sub_commands/cmd_compress.go +++ b/sub_commands/cmd_compress.go @@ -19,7 +19,7 @@ * along with this program. If not, see . */ -package main +package sub_commands import ( "compress/gzip" diff --git a/sub_commands/cmd_webdav.go b/sub_commands/cmd_webdav.go index 5351df30..2826bd47 100644 --- a/sub_commands/cmd_webdav.go +++ b/sub_commands/cmd_webdav.go @@ -17,7 +17,7 @@ * along with this program. If not, see . */ -package main +package sub_commands import ( "context" @@ -38,7 +38,7 @@ import ( "github.com/LiterMC/go-openbmclapi/utils" ) -func cmdUploadWebdav(args []string) { +func CmdUploadWebdav(args []string) { cfg := readConfig() var ( diff --git a/sync.go b/sync.go index 0053653f..bded4ac1 100644 --- a/sync.go +++ b/sync.go @@ -88,34 +88,6 @@ type syncStats struct { lastInc atomic.Int64 } -func (cr *Cluster) SyncFiles(ctx context.Context, files []FileInfo, heavyCheck bool) bool { - log.Infof(Tr("info.sync.prepare"), len(files)) - if !cr.issync.CompareAndSwap(false, true) { - log.Warn("Another sync task is running!") - return false - } - defer cr.issync.Store(false) - - sort.Slice(files, func(i, j int) bool { return files[i].Hash < files[j].Hash }) - if cr.syncFiles(ctx, files, heavyCheck) != nil { - return false - } - - cr.filesetMux.Lock() - for _, f := range files { - cr.fileset[f.Hash] = f.Size - } - cr.filesetMux.Unlock() - - return true -} - -type fileInfoWithTargets struct { - FileInfo - tgMux sync.Mutex - targets []storage.Storage -} - func (cr *Cluster) checkFileFor( ctx context.Context, sto storage.Storage, files []FileInfo, diff --git a/utils/http.go b/utils/http.go index 08437157..d83fbdbf 100644 --- a/utils/http.go +++ b/utils/http.go @@ -290,10 +290,12 @@ func (m *HttpMethodHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) // Else it will just return the tls connection type HTTPTLSListener struct { net.Listener - TLSConfig *tls.Config - mux sync.RWMutex - hosts []string - port string + TLSConfig *tls.Config + AllowUnsecure bool + + mux sync.RWMutex + hosts []string + port string accepting atomic.Bool acceptedCh chan net.Conn @@ -397,6 +399,10 @@ func (s *HTTPTLSListener) accepter() { s.acceptedCh <- tls.Server(hr, s.TLSConfig) return } + if s.AllowUnsecure { + s.acceptedCh <- hr + return + } go s.serveHTTP(hr) } } @@ -418,6 +424,10 @@ func (s *HTTPTLSListener) serveHTTP(conn net.Conn) { if host != "" { host = strings.ToLower(host) for _, h := range s.hosts { + if h == "*" { + inhosts = true + break + } if h, ok := strings.CutPrefix(h, "*."); ok { if strings.HasSuffix(host, h) { inhosts = true @@ -432,15 +442,15 @@ func (s *HTTPTLSListener) serveHTTP(conn net.Conn) { u := *req.URL u.Scheme = "https" if !inhosts { + host = "" for _, h := range s.hosts { - if !strings.HasSuffix(h, "*.") { + if h != "*" && !strings.HasSuffix(h, "*.") { host = h break } } } if host == "" { - // we have nowhere to redirect body := strings.NewReader("Sent http request on https server") resp := &http.Response{ StatusCode: http.StatusBadRequest, From 3c20fd8edb808b000dcf708b259aa9de990c6083 Mon Sep 17 00:00:00 2001 From: zyxkad Date: Sat, 10 Aug 2024 20:46:14 -0700 Subject: [PATCH 18/36] run go fmt --- cluster/storage.go | 4 ++-- main.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cluster/storage.go b/cluster/storage.go index 9b7a65c2..77173d6f 100644 --- a/cluster/storage.go +++ b/cluster/storage.go @@ -21,6 +21,8 @@ package cluster import ( "context" + "crypto" + "encoding/hex" "fmt" "io" "net/http" @@ -30,9 +32,7 @@ import ( "strconv" "sync" "sync/atomic" - "crypto" "time" - "encoding/hex" "github.com/hamba/avro/v2" "github.com/klauspost/compress/zstd" diff --git a/main.go b/main.go index 21fd45f2..fff4384d 100644 --- a/main.go +++ b/main.go @@ -44,8 +44,8 @@ import ( doh "github.com/libp2p/go-doh-resolver" - "github.com/LiterMC/go-openbmclapi/config" "github.com/LiterMC/go-openbmclapi/cluster" + "github.com/LiterMC/go-openbmclapi/config" "github.com/LiterMC/go-openbmclapi/database" "github.com/LiterMC/go-openbmclapi/internal/build" "github.com/LiterMC/go-openbmclapi/lang" From 173aced2290d30ea8328e35e9f8d702af38834d8 Mon Sep 17 00:00:00 2001 From: zyxkad Date: Sun, 11 Aug 2024 15:10:50 -0700 Subject: [PATCH 19/36] refactored most stuffs :o --- api/bmclapi/hijacker.go | 2 +- cluster/cluster.go | 21 +- cluster/handler.go | 6 +- cluster/http.go | 2 +- cluster/socket.go | 4 +- config.go | 40 +-- config/config.go | 4 +- config/server.go | 17 +- dashboard.go | 3 +- handler.go | 134 +++++----- internal/gosrc/httpstrip.go | 26 ++ main.go | 512 +++++++++++++++++++++--------------- storage/manager.go | 9 + sync.go | 2 + util.go | 49 ---- utils/http.go | 84 ++---- 16 files changed, 489 insertions(+), 426 deletions(-) create mode 100644 internal/gosrc/httpstrip.go diff --git a/api/bmclapi/hijacker.go b/api/bmclapi/hijacker.go index f45d0b24..3218c2b0 100644 --- a/api/bmclapi/hijacker.go +++ b/api/bmclapi/hijacker.go @@ -17,7 +17,7 @@ * along with this program. If not, see . */ -package main +package bmclapi import ( "context" diff --git a/cluster/cluster.go b/cluster/cluster.go index 3f7943f2..51d2cf6d 100644 --- a/cluster/cluster.go +++ b/cluster/cluster.go @@ -70,9 +70,13 @@ type Cluster struct { func NewCluster( opts config.ClusterOptions, gcfg config.ClusterGeneralConfig, - storageManager *storage.Manager, storages []int, + storageManager *storage.Manager, statManager *StatManager, ) (cr *Cluster) { + storages := make([]int, len(opts.Storages)) + for i, name := range opts.Storages { + storages[i] = storageManager.GetIndex(name) + } cr = &Cluster{ opts: opts, gcfg: gcfg, @@ -120,10 +124,23 @@ func (cr *Cluster) AcceptHost(host string) bool { return false } +func (cr *Cluster) Options() *config.ClusterOptions { + return &cr.opts +} + +func (cr *Cluster) GeneralConfig() *config.ClusterGeneralConfig { + return &cr.gcfg +} + // Init do setup on the cluster // Init should only be called once during the cluster's whole life // The context passed in only affect the logical of Init method func (cr *Cluster) Init(ctx context.Context) error { + for i, ind := range cr.storages { + if ind == -1 { + return fmt.Errorf("Storage %q does not exists", cr.opts.Storages[i]) + } + } return nil } @@ -172,7 +189,7 @@ func (cr *Cluster) enable(ctx context.Context) error { Host: cr.gcfg.PublicHost, Port: cr.gcfg.PublicPort, Version: build.ClusterVersion, - Byoc: cr.gcfg.Byoc, + Byoc: cr.opts.Byoc, NoFastEnable: cr.gcfg.NoFastEnable, Flavor: ConfigFlavor{ Runtime: "golang/" + runtime.GOOS + "-" + runtime.GOARCH, diff --git a/cluster/handler.go b/cluster/handler.go index dd2151a4..4dbd3a18 100644 --- a/cluster/handler.go +++ b/cluster/handler.go @@ -36,7 +36,7 @@ import ( "github.com/LiterMC/go-openbmclapi/storage" ) -func (cr *Cluster) HandleFile(req *http.Request, rw http.ResponseWriter, hash string, size int64) { +func (cr *Cluster) HandleFile(rw http.ResponseWriter, req *http.Request, hash string) { defer log.RecoverPanic(nil) if !cr.Enabled() { @@ -88,6 +88,8 @@ func (cr *Cluster) HandleFile(req *http.Request, rw http.ResponseWriter, hash st api.SetAccessInfo(req, "cluster", cr.ID()) + var size int64 = -1 // TODO: get the size + var ( sto storage.Storage err error @@ -121,7 +123,7 @@ func (cr *Cluster) HandleFile(req *http.Request, rw http.ResponseWriter, hash st http.Error(rw, err.Error(), http.StatusInternalServerError) } -func (cr *Cluster) HandleMeasure(req *http.Request, rw http.ResponseWriter, size int) { +func (cr *Cluster) HandleMeasure(rw http.ResponseWriter, req *http.Request, size int) { if !cr.Enabled() { // do not serve file if cluster is not enabled yet http.Error(rw, "Cluster is not enabled yet", http.StatusServiceUnavailable) diff --git a/cluster/http.go b/cluster/http.go index f66f3f54..4910444c 100644 --- a/cluster/http.go +++ b/cluster/http.go @@ -90,7 +90,7 @@ func (cr *Cluster) makeReqWithBody( query url.Values, body io.Reader, ) (req *http.Request, err error) { var u *url.URL - if u, err = url.Parse(cr.opts.Prefix); err != nil { + if u, err = url.Parse(cr.opts.Server); err != nil { return } u.Path = path.Join(u.Path, relpath) diff --git a/cluster/socket.go b/cluster/socket.go index 7b2b78f1..db238397 100644 --- a/cluster/socket.go +++ b/cluster/socket.go @@ -48,10 +48,10 @@ func (cr *Cluster) Connect(ctx context.Context) error { } engio, err := engine.NewSocket(engine.Options{ - Host: cr.opts.Prefix, + Host: cr.opts.Server, Path: "/socket.io/", ExtraHeaders: http.Header{ - "Origin": {cr.opts.Prefix}, + "Origin": {cr.opts.Server}, "User-Agent": {build.ClusterUserAgent}, }, DialTimeout: time.Minute * 6, diff --git a/config.go b/config.go index 43298f58..b5259111 100644 --- a/config.go +++ b/config.go @@ -21,12 +21,20 @@ package main import ( "bytes" + "errors" + "fmt" + "net/url" + "os" "gopkg.in/yaml.v3" "github.com/LiterMC/go-openbmclapi/config" + "github.com/LiterMC/go-openbmclapi/log" + "github.com/LiterMC/go-openbmclapi/storage" ) +const DefaultBMCLAPIServer = "https://openbmclapi.bangbang93.com" + func migrateConfig(data []byte, cfg *config.Config) { var oldConfig map[string]any if err := yaml.Unmarshal(data, &oldConfig); err != nil { @@ -45,11 +53,13 @@ func migrateConfig(data []byte, cfg *config.Config) { if oldConfig["clusters"].(map[string]any) == nil { id, ok1 := oldConfig["cluster-id"].(string) secret, ok2 := oldConfig["cluster-secret"].(string) - if ok1 && ok2 { - cfg.Clusters = map[string]ClusterItem{ + publicHost, ok3 := oldConfig["public-host"].(string) + if ok1 && ok2 && ok3 { + cfg.Clusters = map[string]config.ClusterOptions{ "main": { - Id: id, - Secret: secret, + Id: id, + Secret: secret, + PublicHosts: []string{publicHost}, }, } } @@ -67,7 +77,7 @@ func readAndRewriteConfig() (cfg *config.Config, err error) { log.TrErrorf("error.config.read.failed", err) os.Exit(1) } - log.TrError("error.config.not.exists") + log.TrErrorf("error.config.not.exists") notexists = true } else { migrateConfig(data, cfg) @@ -76,15 +86,18 @@ func readAndRewriteConfig() (cfg *config.Config, err error) { os.Exit(1) } if len(cfg.Clusters) == 0 { - cfg.Clusters = map[string]ClusterItem{ + cfg.Clusters = map[string]config.ClusterOptions{ "main": { - Id: "${CLUSTER_ID}", - Secret: "${CLUSTER_SECRET}", + Id: "${CLUSTER_ID}", + Secret: "${CLUSTER_SECRET}", + PublicHosts: []string{}, + Server: DefaultBMCLAPIServer, + SkipSignatureCheck: false, }, } } if len(cfg.Certificates) == 0 { - cfg.Certificates = []CertificateConfig{ + cfg.Certificates = []config.CertificateConfig{ { Cert: "/path/to/cert.pem", Key: "/path/to/key.pem", @@ -123,12 +136,6 @@ func readAndRewriteConfig() (cfg *config.Config, err error) { os.Exit(1) } ids[s.Id] = i - if s.Cluster != "" && s.Cluster != "-" { - if _, ok := cfg.Clusters[s.Cluster]; !ok { - log.Errorf("Storage %q is trying to connect to a not exists cluster %q.", s.Id, s.Cluster) - os.Exit(1) - } - } } } @@ -173,7 +180,8 @@ func readAndRewriteConfig() (cfg *config.Config, err error) { os.Exit(1) } if notexists { - log.TrError("error.config.created") + log.TrErrorf("error.config.created") + return nil, errors.New("Please edit the config before continue!") } return } diff --git a/config/config.go b/config/config.go index 5c843c9d..ec18ca2f 100644 --- a/config/config.go +++ b/config/config.go @@ -36,7 +36,6 @@ type Config struct { PublicPort uint16 `yaml:"public-port"` Host string `yaml:"host"` Port uint16 `yaml:"port"` - Byoc bool `yaml:"byoc"` UseCert bool `yaml:"use-cert"` TrustedXForwardedFor bool `yaml:"trusted-x-forwarded-for"` @@ -65,7 +64,7 @@ type Config struct { Advanced AdvancedConfig `yaml:"advanced"` } -func (cfg *Config) applyWebManifest(manifest map[string]any) { +func (cfg *Config) ApplyWebManifest(manifest map[string]any) { if cfg.Dashboard.Enable { manifest["name"] = cfg.Dashboard.PwaName manifest["short_name"] = cfg.Dashboard.PwaShortName @@ -79,7 +78,6 @@ func NewDefaultConfig() *Config { PublicPort: 0, Host: "0.0.0.0", Port: 4000, - Byoc: false, TrustedXForwardedFor: false, OnlyGcWhenStart: false, diff --git a/config/server.go b/config/server.go index 32963a58..a1763a8f 100644 --- a/config/server.go +++ b/config/server.go @@ -34,15 +34,16 @@ import ( type ClusterOptions struct { Id string `json:"id" yaml:"id"` Secret string `json:"secret" yaml:"secret"` + Byoc bool `json:"byoc"` PublicHosts []string `json:"public-hosts" yaml:"public-hosts"` - Prefix string `json:"prefix" yaml:"prefix"` + Server string `json:"server" yaml:"server"` SkipSignatureCheck bool `json:"skip-signature-check" yaml:"skip-signature-check"` + Storages []string `json:"storages" yaml:"storages"` } type ClusterGeneralConfig struct { PublicHost string `json:"public-host"` PublicPort uint16 `json:"public-port"` - Byoc bool `json:"byoc"` NoFastEnable bool `json:"no-fast-enable"` MaxReconnectCount int `json:"max-reconnect-count"` } @@ -77,6 +78,10 @@ type CacheConfig struct { newCache func() cache.Cache `yaml:"-"` } +func (c *CacheConfig) NewCache() cache.Cache { + return c.newCache() +} + func (c *CacheConfig) UnmarshalYAML(n *yaml.Node) (err error) { var cfg struct { Type string `yaml:"type"` @@ -148,3 +153,11 @@ func (c *TunnelConfig) UnmarshalYAML(n *yaml.Node) (err error) { } return } + +func (c *TunnelConfig) MatchTunnelOutput(line []byte) (host, port []byte, ok bool) { + res := c.outputRegex.FindSubmatch(line) + if res == nil { + return + } + return res[c.hostOut], res[c.portOut], true +} diff --git a/dashboard.go b/dashboard.go index f8fbcaef..3bcd689d 100644 --- a/dashboard.go +++ b/dashboard.go @@ -60,13 +60,14 @@ var dsbManifest = func() (dsbManifest map[string]any) { return }() -func (r *Runner) serveDashboard(rw http.ResponseWriter, req *http.Request, pth string) { +func (r *Runner) serveDashboard(rw http.ResponseWriter, req *http.Request) { if req.Method != http.MethodGet && req.Method != http.MethodHead { rw.Header().Set("Allow", http.MethodGet+", "+http.MethodHead) http.Error(rw, "405 Method Not Allowed", http.StatusMethodNotAllowed) return } acceptEncoding := utils.SplitCSV(req.Header.Get("Accept-Encoding")) + pth := strings.TrimPrefix(req.URL.Path, "/") switch pth { case "": break diff --git a/handler.go b/handler.go index bedc3fb6..97e4e85a 100644 --- a/handler.go +++ b/handler.go @@ -24,16 +24,11 @@ import ( "context" "crypto" _ "embed" - "encoding/base64" "encoding/hex" "encoding/json" - "errors" "fmt" - "io" "net" "net/http" - "net/textproto" - "os" "strconv" "strings" "time" @@ -43,9 +38,9 @@ import ( "github.com/LiterMC/go-openbmclapi/api" "github.com/LiterMC/go-openbmclapi/api/v0" "github.com/LiterMC/go-openbmclapi/internal/build" + "github.com/LiterMC/go-openbmclapi/internal/gosrc" "github.com/LiterMC/go-openbmclapi/limited" "github.com/LiterMC/go-openbmclapi/log" - "github.com/LiterMC/go-openbmclapi/storage" "github.com/LiterMC/go-openbmclapi/utils" ) @@ -108,13 +103,13 @@ var wsUpgrader = &websocket.Upgrader{ } func (r *Runner) GetHandler() http.Handler { - r.apiRateLimiter = limited.NewAPIRateMiddleWare(RealAddrCtxKey, loggedUserKey) - r.apiRateLimiter.SetAnonymousRateLimit(r.RateLimit.Anonymous) - r.apiRateLimiter.SetLoggedRateLimit(r.RateLimit.Logged) + r.apiRateLimiter = limited.NewAPIRateMiddleWare(api.RealAddrCtxKey, "go-openbmclapi.cluster.logged.user" /* api/v0.loggedUserKey */) + r.apiRateLimiter.SetAnonymousRateLimit(r.Config.RateLimit.Anonymous) + r.apiRateLimiter.SetLoggedRateLimit(r.Config.RateLimit.Logged) r.handlerAPIv0 = http.StripPrefix("/api/v0", v0.NewHandler(wsUpgrader)) - r.hijackHandler = http.StripPrefix("/bmclapi", r.hijackProxy) + r.hijackHandler = http.StripPrefix("/bmclapi", r.hijacker) - handler := utils.NewHttpMiddleWareHandler(r) + handler := utils.NewHttpMiddleWareHandler((http.HandlerFunc)(r.serveHTTP)) // recover panic and log it handler.UseFunc(func(rw http.ResponseWriter, req *http.Request, next http.Handler) { defer log.RecoverPanic(func(any) { @@ -124,67 +119,57 @@ func (r *Runner) GetHandler() http.Handler { }) handler.Use(r.apiRateLimiter) - handler.Use(r.getRecordMiddleWare()) + handler.UseFunc(r.recordMiddleWare) return handler } -func (r *Runner) getRecordMiddleWare() utils.MiddleWareFunc { - type record struct { - used float64 - bytes float64 - ua string - skipUA bool +func (r *Runner) recordMiddleWare(rw http.ResponseWriter, req *http.Request, next http.Handler) { + ua := req.UserAgent() + var addr string + if r.Config.TrustedXForwardedFor { + // X-Forwarded-For: , , + adr, _, _ := strings.Cut(req.Header.Get("X-Forwarded-For"), ",") + addr = strings.TrimSpace(adr) } - recordCh := make(chan record, 1024) - - return func(rw http.ResponseWriter, req *http.Request, next http.Handler) { - ua := req.UserAgent() - var addr string - if config.TrustedXForwardedFor { - // X-Forwarded-For: , , - adr, _, _ := strings.Cut(req.Header.Get("X-Forwarded-For"), ",") - addr = strings.TrimSpace(adr) - } - if addr == "" { - addr, _, _ = net.SplitHostPort(req.RemoteAddr) - } - srw := utils.WrapAsStatusResponseWriter(rw) - start := time.Now() + if addr == "" { + addr, _, _ = net.SplitHostPort(req.RemoteAddr) + } + srw := utils.WrapAsStatusResponseWriter(rw) + start := time.Now() - log.LogAccess(log.LevelDebug, &preAccessRecord{ - Type: "pre-access", - Time: start, - Addr: addr, - Method: req.Method, - URI: req.RequestURI, - UA: ua, - }) + log.LogAccess(log.LevelDebug, &preAccessRecord{ + Type: "pre-access", + Time: start, + Addr: addr, + Method: req.Method, + URI: req.RequestURI, + UA: ua, + }) - extraInfoMap := make(map[string]any) - ctx := req.Context() - ctx = context.WithValue(ctx, RealAddrCtxKey, addr) - ctx = context.WithValue(ctx, RealPathCtxKey, req.URL.Path) - ctx = context.WithValue(ctx, AccessLogExtraCtxKey, extraInfoMap) - req = req.WithContext(ctx) - next.ServeHTTP(srw, req) + extraInfoMap := make(map[string]any) + ctx := req.Context() + ctx = context.WithValue(ctx, api.RealAddrCtxKey, addr) + ctx = context.WithValue(ctx, api.RealPathCtxKey, req.URL.Path) + ctx = context.WithValue(ctx, api.AccessLogExtraCtxKey, extraInfoMap) + req = req.WithContext(ctx) + next.ServeHTTP(srw, req) - used := time.Since(start) - accRec := &accessRecord{ - Type: "access", - Status: srw.Status, - Used: used, - Content: srw.Wrote, - Addr: addr, - Proto: req.Proto, - Method: req.Method, - URI: req.RequestURI, - UA: ua, - } - if len(extraInfoMap) > 0 { - accRec.Extra = extraInfoMap - } - log.LogAccess(log.LevelInfo, accRec) + used := time.Since(start) + accRec := &accessRecord{ + Type: "access", + Status: srw.Status, + Used: used, + Content: srw.Wrote, + Addr: addr, + Proto: req.Proto, + Method: req.Method, + URI: req.RequestURI, + UA: ua, + } + if len(extraInfoMap) > 0 { + accRec.Extra = extraInfoMap } + log.LogAccess(log.LevelInfo, accRec) } var emptyHashes = func() (hashes map[string]struct{}) { @@ -202,11 +187,11 @@ var emptyHashes = func() (hashes map[string]struct{}) { //go:embed robots.txt var robotTxtContent string -func (r *Runner) ServeHTTP(rw http.ResponseWriter, req *http.Request) { +func (r *Runner) serveHTTP(rw http.ResponseWriter, req *http.Request) { method := req.Method u := req.URL - rw.Header().Set("X-Powered-By", HeaderXPoweredBy) + rw.Header().Set("X-Powered-By", build.HeaderXPoweredBy) rawpath := u.EscapedPath() switch { @@ -249,7 +234,7 @@ func (r *Runner) ServeHTTP(rw http.ResponseWriter, req *http.Request) { for _, cr := range r.clusters { if cr.AcceptHost(req.Host) { - cr.HandleFile(rw, req, hash) + cr.HandleMeasure(rw, req, size) return } } @@ -264,23 +249,24 @@ func (r *Runner) ServeHTTP(rw http.ResponseWriter, req *http.Request) { case "v0": r.handlerAPIv0.ServeHTTP(rw, req) return - case "v1": - r.handlerAPIv1.ServeHTTP(rw, req) - return + // case "v1": + // r.handlerAPIv1.ServeHTTP(rw, req) + // return } case rawpath == "/" || rawpath == "/dashboard": http.Redirect(rw, req, "/dashboard/", http.StatusFound) return case strings.HasPrefix(rawpath, "/dashboard/"): - if !r.DashboardEnabled { + if !r.Config.Dashboard.Enable { http.NotFound(rw, req) return } - pth := rawpath[len("/dashboard/"):] - r.serveDashboard(rw, req, pth) + req2 := gosrc.RequestStripPrefix(req, "/dashboard") + r.serveDashboard(rw, req2) return case strings.HasPrefix(rawpath, "/bmclapi/"): - r.hijackHandler.ServeHTTP(rw, req) + req2 := gosrc.RequestStripPrefix(req, "/bmclapi") + r.hijackHandler.ServeHTTP(rw, req2) return } http.NotFound(rw, req) diff --git a/internal/gosrc/httpstrip.go b/internal/gosrc/httpstrip.go new file mode 100644 index 00000000..8dcc129a --- /dev/null +++ b/internal/gosrc/httpstrip.go @@ -0,0 +1,26 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package gosrc + +import ( + "net/http" + "net/url" + "strings" +) + +func RequestStripPrefix(r *http.Request, prefix string) *http.Request { + p, ok := strings.CutPrefix(r.URL.Path, prefix) + rp, ok2 := strings.CutPrefix(r.URL.RawPath, prefix) + if ok && (ok2 || r.URL.RawPath == "") { + r2 := new(http.Request) + *r2 = *r + r2.URL = new(url.URL) + *r2.URL = *r.URL + r2.URL.Path = p + r2.URL.RawPath = rp + return r2 + } + return nil +} diff --git a/main.go b/main.go index fff4384d..55f42dc4 100644 --- a/main.go +++ b/main.go @@ -36,6 +36,7 @@ import ( "runtime" "strconv" "strings" + "sync" "sync/atomic" "syscall" "time" @@ -44,6 +45,7 @@ import ( doh "github.com/libp2p/go-doh-resolver" + "github.com/LiterMC/go-openbmclapi/api/bmclapi" "github.com/LiterMC/go-openbmclapi/cluster" "github.com/LiterMC/go-openbmclapi/config" "github.com/LiterMC/go-openbmclapi/database" @@ -51,7 +53,9 @@ import ( "github.com/LiterMC/go-openbmclapi/lang" "github.com/LiterMC/go-openbmclapi/limited" "github.com/LiterMC/go-openbmclapi/log" + "github.com/LiterMC/go-openbmclapi/storage" subcmds "github.com/LiterMC/go-openbmclapi/sub_commands" + "github.com/LiterMC/go-openbmclapi/utils" _ "github.com/LiterMC/go-openbmclapi/lang/en" _ "github.com/LiterMC/go-openbmclapi/lang/zh" @@ -124,18 +128,8 @@ func main() { } else { r.Config = config } - if r.Config.Advanced.DebugLog { - log.SetLevel(log.LevelDebug) - } else { - log.SetLevel(log.LevelInfo) - } - if r.Config.NoAccessLog { - log.SetAccessLogSlots(-1) - } else { - log.SetAccessLogSlots(r.Config.AccessLogSlots) - } - r.Config.applyWebManifest(dsbManifest) + r.SetupLogger(ctx) log.TrInfof("program.starting", build.ClusterVersion, build.BuildVersion) @@ -143,13 +137,37 @@ func main() { r.StartTunneler() } r.InitServer() + r.StartServer(ctx) + r.InitClusters(ctx) go func(ctx context.Context) { defer log.RecordPanic() - if !r.cluster.Connect(ctx) { - osExit(CodeClientOrServerError) + var wg sync.WaitGroup + errs := make([]error, len(r.clusters)) + { + i := 0 + for _, cr := range r.clusters { + i++ + go func(i int, cr *cluster.Cluster) { + defer wg.Done() + errs[i] = cr.Connect(ctx) + }(i, cr) + } + } + wg.Wait() + if ctx.Err() != nil { + return + } + + { + var err error + r.tlsConfig, err = r.PatchTLSWithClusterCert(ctx, r.tlsConfig) + if err != nil { + return + } + r.listener.TLSConfig.Store(r.tlsConfig) } firstSyncDone := make(chan struct{}, 0) @@ -159,30 +177,11 @@ func main() { r.InitSynchronizer(ctx) }() - listener := r.CreateHTTPServerListener(ctx) - go func(listener net.Listener) { - defer listener.Close() - if err := r.clusterSvr.Serve(listener); !errors.Is(err, http.ErrServerClosed) { - log.Error("Error when serving:", err) - os.Exit(1) - } - }(listener) - - var publicHost string - if len(r.publicHosts) == 0 { - publicHost = config.PublicHost - } else { - publicHost = r.publicHosts[0] - } - if !config.Tunneler.Enable { + if !r.Config.Tunneler.Enable { strPort := strconv.Itoa((int)(r.getPublicPort())) - log.TrInfof("info.server.public.at", net.JoinHostPort(publicHost, strPort), r.clusterSvr.Addr, r.getCertCount()) - if len(r.publicHosts) > 1 { - log.TrInfof("info.server.alternative.hosts") - for _, h := range r.publicHosts[1:] { - log.Infof("\t- https://%s", net.JoinHostPort(h, strPort)) - } - } + pubAddr := net.JoinHostPort(r.Config.PublicHost, strPort) + localAddr := net.JoinHostPort(r.Config.Host, strconv.Itoa((int)(r.Config.Port))) + log.TrInfof("info.server.public.at", pubAddr, localAddr, r.getCertCount()) } log.TrInfof("info.wait.first.sync") @@ -192,14 +191,14 @@ func main() { return } - r.EnableCluster(ctx) + // r.EnableCluster(ctx) }(ctx) - code := r.DoSignals(cancel) + code := r.ListenSignals(ctx, cancel) if code != 0 { log.TrErrorf("program.exited", code) log.TrErrorf("error.exit.please.read.faq") - if runtime.GOOS == "windows" && !config.Advanced.DoNotOpenFAQOnWindows { + if runtime.GOOS == "windows" && !r.Config.Advanced.DoNotOpenFAQOnWindows { // log.TrWarnf("warn.exit.detected.windows.open.browser") // cmd := exec.Command("cmd", "/C", "start", "https://cdn.crashmc.com/https://github.com/LiterMC/go-openbmclapi?tab=readme-ov-file#faq") // cmd.Start() @@ -212,12 +211,21 @@ func main() { type Runner struct { Config *config.Config - clusters map[string]*Cluster - server *http.Server + clusters map[string]*cluster.Cluster + apiRateLimiter *limited.APIRateMiddleWare + storageManager *storage.Manager + statManager *cluster.StatManager + hijacker *bmclapi.HjProxy + database database.DB + + server *http.Server + handlerAPIv0 http.Handler + hijackHandler http.Handler tlsConfig *tls.Config - listener net.Listener - publicHosts []string + publicHost string + publicPort uint16 + listener *utils.HTTPTLSListener reloading atomic.Bool updating atomic.Bool @@ -225,8 +233,8 @@ type Runner struct { } func (r *Runner) getPublicPort() uint16 { - if r.Config.PublicPort > 0 { - return r.Config.PublicPort + if r.publicPort > 0 { + return r.publicPort } return r.Config.Port } @@ -238,7 +246,80 @@ func (r *Runner) getCertCount() int { return len(r.tlsConfig.Certificates) } -func (r *Runner) DoSignals(cancel context.CancelFunc) int { +func (r *Runner) InitServer() { + r.server = &http.Server{ + ReadTimeout: 10 * time.Second, + IdleTimeout: 5 * time.Second, + Handler: r.GetHandler(), + ErrorLog: log.ProxiedStdLog, + } +} + +// StartServer will start the HTTP server +// If a server is already running on an old listener, the listener will be closed. +func (r *Runner) StartServer(ctx context.Context) error { + htListener, err := r.CreateHTTPListener(ctx) + if err != nil { + return err + } + if r.listener != nil { + r.listener.Close() + } + r.listener = htListener + go func() { + defer htListener.Close() + if err := r.server.Serve(htListener); !errors.Is(err, http.ErrServerClosed) && !errors.Is(err, net.ErrClosed) { + log.Error("Error when serving:", err) + os.Exit(1) + } + }() + return nil +} + +func (r *Runner) GetClusterGeneralConfig() config.ClusterGeneralConfig { + return config.ClusterGeneralConfig{ + PublicHost: r.publicHost, + PublicPort: r.getPublicPort(), + NoFastEnable: r.Config.Advanced.NoFastEnable, + MaxReconnectCount: r.Config.MaxReconnectCount, + } +} + +func (r *Runner) InitClusters(ctx context.Context) { + // var ( + // dialer *net.Dialer + // cache = r.Config.Cache.NewCache() + // ) + + _ = doh.NewResolver // TODO: use doh resolver + + r.clusters = make(map[string]*cluster.Cluster) + gcfg := r.GetClusterGeneralConfig() + for name, opts := range r.Config.Clusters { + cr := cluster.NewCluster(opts, gcfg, r.storageManager, r.statManager) + if err := cr.Init(ctx); err != nil { + log.TrErrorf("error.init.failed", err) + } else { + r.clusters[name] = cr + } + } + + // r.cluster = NewCluster(ctx, + // ClusterServerURL, + // baseDir, + // config.PublicHost, r.getPublicPort(), + // config.ClusterId, config.ClusterSecret, + // config.Byoc, dialer, + // config.Storages, + // cache, + // ) + // if err := r.cluster.Init(ctx); err != nil { + // log.TrErrorf("error.init.failed"), err) + // os.Exit(1) + // } +} + +func (r *Runner) ListenSignals(ctx context.Context, cancel context.CancelFunc) int { signalCh := make(chan os.Signal, 1) log.Debugf("Receiving signals") signal.Notify(signalCh, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) @@ -290,7 +371,7 @@ func (r *Runner) DoSignals(cancel context.CancelFunc) int { } } case syscall.SIGHUP: - go r.ReloadConfig() + go r.ReloadConfig(ctx) default: cancel() if forceStop == nil { @@ -312,7 +393,7 @@ func (r *Runner) DoSignals(cancel context.CancelFunc) int { return 0 } -func (r *Runner) ReloadConfig() { +func (r *Runner) ReloadConfig(ctx context.Context) { if r.reloading.CompareAndSwap(false, true) { log.Error("Config is already reloading!") return @@ -321,10 +402,54 @@ func (r *Runner) ReloadConfig() { config, err := readAndRewriteConfig() if err != nil { - log.Errorf("Config error: %s", err) + log.Errorf("Config error: %v", err) } else { - r.Config = config + if err := r.updateConfig(ctx, config); err != nil { + log.Errorf("Error when reloading config: %v", err) + } + } +} + +func (r *Runner) updateConfig(ctx context.Context, newConfig *config.Config) error { + oldConfig := r.Config + reloadProcesses := make([]func(context.Context) error, 0, 8) + + if newConfig.LogSlots != oldConfig.LogSlots || newConfig.NoAccessLog != oldConfig.NoAccessLog || newConfig.AccessLogSlots != oldConfig.AccessLogSlots || newConfig.Advanced.DebugLog != oldConfig.Advanced.DebugLog { + reloadProcesses = append(reloadProcesses, r.SetupLogger) + } + if newConfig.Host != oldConfig.Host || newConfig.Port != oldConfig.Port { + reloadProcesses = append(reloadProcesses, r.StartServer) + } + if newConfig.PublicHost != oldConfig.PublicHost || newConfig.PublicPort != oldConfig.PublicPort || newConfig.Advanced.NoFastEnable != oldConfig.Advanced.NoFastEnable || newConfig.MaxReconnectCount != oldConfig.MaxReconnectCount { + reloadProcesses = append(reloadProcesses, r.updateClustersWithGeneralConfig) + } + + r.Config = newConfig + r.publicHost = r.Config.PublicHost + r.publicPort = r.Config.PublicPort + for _, proc := range reloadProcesses { + if err := proc(ctx); err != nil { + return err + } + } + return nil +} + +func (r *Runner) SetupLogger(ctx context.Context) error { + if r.Config.Advanced.DebugLog { + log.SetLevel(log.LevelDebug) + } else { + log.SetLevel(log.LevelInfo) + } + log.SetLogSlots(r.Config.LogSlots) + if r.Config.NoAccessLog { + log.SetAccessLogSlots(-1) + } else { + log.SetAccessLogSlots(r.Config.AccessLogSlots) } + + r.Config.ApplyWebManifest(dsbManifest) + return nil } func (r *Runner) StopServer(ctx context.Context) { @@ -346,6 +471,8 @@ func (r *Runner) StopServer(ctx context.Context) { wg.Wait() log.TrWarnf("warn.httpserver.closing") r.server.Shutdown(shutCtx) + r.listener.Close() + r.listener = nil }() select { case <-shutDone: @@ -355,41 +482,8 @@ func (r *Runner) StopServer(ctx context.Context) { log.TrWarnf("warn.server.closed") } -func (r *Runner) InitServer() { - r.server = &http.Server{ - Addr: fmt.Sprintf("%s:%d", d.Config.Host, d.Config.Port), - ReadTimeout: 10 * time.Second, - IdleTimeout: 5 * time.Second, - Handler: r, - ErrorLog: log.ProxiedStdLog, - } -} - -func (r *Runner) InitClusters(ctx context.Context) { - var ( - dialer *net.Dialer - cache = r.Config.Cache.newCache() - ) - - _ = doh.NewResolver // TODO: use doh resolver - - r.cluster = NewCluster(ctx, - ClusterServerURL, - baseDir, - config.PublicHost, r.getPublicPort(), - config.ClusterId, config.ClusterSecret, - config.Byoc, dialer, - config.Storages, - cache, - ) - if err := r.cluster.Init(ctx); err != nil { - log.Errorf(Tr("error.init.failed"), err) - os.Exit(1) - } -} - func (r *Runner) UpdateFileRecords(files map[string]*cluster.StorageFileInfo, oldfileset map[string]int64) { - if !r.hijacker.Enabled { + if !r.Config.Hijack.Enable { return } if !r.updating.CompareAndSwap(false, true) { @@ -409,7 +503,7 @@ func (r *Runner) UpdateFileRecords(files map[string]*cluster.StorageFileInfo, ol sem.Acquire() go func(rec database.FileRecord) { defer sem.Release() - r.cluster.database.SetFileRecord(rec) + r.database.SetFileRecord(rec) }(database.FileRecord{ Path: f.Path, Hash: f.Hash, @@ -421,11 +515,11 @@ func (r *Runner) UpdateFileRecords(files map[string]*cluster.StorageFileInfo, ol } func (r *Runner) InitSynchronizer(ctx context.Context) { - fileMap := make(map[string]*StorageFileInfo) + fileMap := make(map[string]*cluster.StorageFileInfo) for _, cr := range r.clusters { - log.Info(Tr("info.filelist.fetching"), cr.ID()) + log.TrInfof("info.filelist.fetching", cr.ID()) if err := cr.GetFileList(ctx, fileMap, true); err != nil { - log.Errorf(Tr("error.filelist.fetch.failed"), cr.ID(), err) + log.TrErrorf("error.filelist.fetch.failed", cr.ID(), err) if errors.Is(err, context.Canceled) { return } @@ -433,33 +527,34 @@ func (r *Runner) InitSynchronizer(ctx context.Context) { } checkCount := -1 - heavyCheck := !config.Advanced.NoHeavyCheck - heavyCheckInterval := config.Advanced.HeavyCheckInterval + heavyCheck := !r.Config.Advanced.NoHeavyCheck + heavyCheckInterval := r.Config.Advanced.HeavyCheckInterval if heavyCheckInterval <= 0 { heavyCheck = false } - if !config.Advanced.SkipFirstSync { - if !r.cluster.SyncFiles(ctx, fileMap, false) { - return - } - go r.UpdateFileRecords(fileMap, nil) + // if !r.Config.Advanced.SkipFirstSync { + // if !r.cluster.SyncFiles(ctx, fileMap, false) { + // return + // } + // go r.UpdateFileRecords(fileMap, nil) - if !config.Advanced.NoGC { - go r.cluster.Gc() - } - } else if fl != nil { - if err := r.cluster.SetFilesetByExists(ctx, fl); err != nil { - return - } - } + // if !r.Config.Advanced.NoGC { + // go r.cluster.Gc() + // } + // } else + // if fl != nil { + // if err := r.cluster.SetFilesetByExists(ctx, fl); err != nil { + // return + // } + // } createInterval(ctx, func() { - fileMap := make(map[string]*StorageFileInfo) + fileMap := make(map[string]*cluster.StorageFileInfo) for _, cr := range r.clusters { - log.Info(Tr("info.filelist.fetching"), cr.ID()) + log.TrInfof("info.filelist.fetching", cr.ID()) if err := cr.GetFileList(ctx, fileMap, false); err != nil { - log.Errorf(Tr("error.filelist.fetch.failed"), cr.ID(), err) + log.TrErrorf("error.filelist.fetch.failed", cr.ID(), err) return } } @@ -472,104 +567,122 @@ func (r *Runner) InitSynchronizer(ctx context.Context) { oldfileset := r.cluster.CloneFileset() if r.cluster.SyncFiles(ctx, fl, heavyCheck && checkCount == 0) { go r.UpdateFileRecords(fl, oldfileset) - if !config.Advanced.NoGC && !config.OnlyGcWhenStart { + if !r.Config.Advanced.NoGC && !r.Config.OnlyGcWhenStart { go r.cluster.Gc() } } - }, (time.Duration)(config.SyncInterval)*time.Minute) + }, (time.Duration)(r.Config.SyncInterval)*time.Minute) } -func (r *Runner) CreateHTTPServerListener(ctx context.Context) (listener net.Listener) { - listener, err := net.Listen("tcp", r.Addr) +func (r *Runner) CreateHTTPListener(ctx context.Context) (*utils.HTTPTLSListener, error) { + addr := net.JoinHostPort(r.Config.Host, strconv.Itoa((int)(r.Config.Port))) + listener, err := net.Listen("tcp", addr) if err != nil { - log.Errorf(Tr("error.address.listen.failed"), r.Addr, err) - osExit(CodeEnvironmentError) + log.TrErrorf("error.address.listen.failed", addr, err) + return nil, err } if r.Config.ServeLimit.Enable { - limted := limited.NewLimitedListener(listener, config.ServeLimit.MaxConn, 0, config.ServeLimit.UploadRate*1024) + limted := limited.NewLimitedListener(listener, r.Config.ServeLimit.MaxConn, 0, r.Config.ServeLimit.UploadRate*1024) limted.SetMinWriteRate(1024) listener = limted } - tlsConfig := r.GenerateTLSConfig(ctx) - r.publicHosts = make([]string, 0, 2) - if tlsConfig != nil { - for _, cert := range tlsConfig.Certificates { - if h, err := parseCertCommonName(cert.Certificate[0]); err == nil { - r.publicHosts = append(r.publicHosts, strings.ToLower(h)) - } + if r.Config.UseCert { + var err error + r.tlsConfig, err = r.GenerateTLSConfig() + if err != nil { + log.Errorf("Failed to generate TLS config: %v", err) + return nil, err } - listener = utils.NewHttpTLSListener(listener, tlsConfig, r.publicHosts, r.getPublicPort()) } - r.listener = listener - return + return utils.NewHttpTLSListener(listener, r.tlsConfig), nil } -func (r *Runner) GenerateTLSConfig(ctx context.Context) (tlsConfig *tls.Config) { - if config.UseCert { - if len(config.Certificates) == 0 { - log.Error(Tr("error.cert.not.set")) - os.Exit(1) +func (r *Runner) GenerateTLSConfig() (*tls.Config, error) { + if len(r.Config.Certificates) == 0 { + log.TrErrorf("error.cert.not.set") + return nil, errors.New("No certificate is defined") + } + tlsConfig := new(tls.Config) + tlsConfig.Certificates = make([]tls.Certificate, len(r.Config.Certificates)) + for i, c := range r.Config.Certificates { + var err error + tlsConfig.Certificates[i], err = tls.LoadX509KeyPair(c.Cert, c.Key) + if err != nil { + log.TrErrorf("error.cert.parse.failed", i, err) + return nil, err } - tlsConfig = new(tls.Config) - tlsConfig.Certificates = make([]tls.Certificate, len(config.Certificates)) - for i, c := range config.Certificates { - var err error - tlsConfig.Certificates[i], err = tls.LoadX509KeyPair(c.Cert, c.Key) - if err != nil { - log.Errorf(Tr("error.cert.parse.failed"), i, err) - os.Exit(1) - } + } + return tlsConfig, nil +} + +func (r *Runner) PatchTLSWithClusterCert(ctx context.Context, tlsConfig *tls.Config) (*tls.Config, error) { + certs := make([]tls.Certificate, 0) + for _, cr := range r.clusters { + if cr.Options().Byoc { + continue + } + log.TrInfof("info.cert.requesting", cr.ID()) + tctx, cancel := context.WithTimeout(ctx, time.Minute*10) + pair, err := cr.RequestCert(tctx) + cancel() + if err != nil { + log.TrErrorf("error.cert.request.failed", err) + continue } + cert, err := tls.X509KeyPair(([]byte)(pair.Cert), ([]byte)(pair.Key)) + if err != nil { + log.TrErrorf("error.cert.requested.parse.failed", err) + continue + } + certs = append(certs, cert) + certHost, _ := parseCertCommonName(cert.Certificate[0]) + log.TrInfof("info.cert.requested", certHost) } - if !config.Byoc { - for _, cr := range r.clusters { - log.Info(Tr("info.cert.requesting"), cr.ID()) - tctx, cancel := context.WithTimeout(ctx, time.Minute*10) - pair, err := cr.RequestCert(tctx) - cancel() - if err != nil { - log.Errorf(Tr("error.cert.request.failed"), err) - os.Exit(2) - } - if tlsConfig == nil { - tlsConfig = new(tls.Config) - } - var cert tls.Certificate - cert, err = tls.X509KeyPair(([]byte)(pair.Cert), ([]byte)(pair.Key)) - if err != nil { - log.Errorf(Tr("error.cert.requested.parse.failed"), err) - os.Exit(2) - } - tlsConfig.Certificates = append(tlsConfig.Certificates, cert) - certHost, _ := parseCertCommonName(cert.Certificate[0]) - log.Infof(Tr("info.cert.requested"), certHost) + if len(certs) == 0 { + if tlsConfig == nil { + tlsConfig = new(tls.Config) + } else { + tlsConfig = tlsConfig.Clone() } + tlsConfig.Certificates = append(tlsConfig.Certificates, certs...) } - r.tlsConfig = tlsConfig - return + return tlsConfig, nil } -func (r *Runner) EnableCluster(ctx context.Context) { - if config.Advanced.WaitBeforeEnable > 0 { - select { - case <-time.After(time.Second * (time.Duration)(config.Advanced.WaitBeforeEnable)): - case <-ctx.Done(): - return - } +// updateClustersWithGeneralConfig will re-enable all clusters with latest general config +func (r *Runner) updateClustersWithGeneralConfig(ctx context.Context) error { + gcfg := r.GetClusterGeneralConfig() + var wg sync.WaitGroup + for _, cr := range r.clusters { + wg.Add(1) + go func(cr *cluster.Cluster) { + defer wg.Done() + cr.Disable(ctx) + *cr.GeneralConfig() = gcfg + if err := cr.Enable(ctx); err != nil { + log.TrErrorf("error.cluster.enable.failed", cr.ID(), err) + return + } + }(cr) } + wg.Wait() + return nil +} - if config.Tunneler.Enable { - r.enableClusterByTunnel(ctx) - } else { - if err := r.cluster.Enable(ctx); err != nil { - log.Errorf(Tr("error.cluster.enable.failed"), err) - if ctx.Err() != nil { +func (r *Runner) EnableClusterAll(ctx context.Context) { + var wg sync.WaitGroup + for _, cr := range r.clusters { + wg.Add(1) + go func(cr *cluster.Cluster) { + defer wg.Done() + if err := cr.Enable(ctx); err != nil { + log.TrErrorf("error.cluster.enable.failed", cr.ID(), err) return } - osExit(CodeServerOrEnvionmentError) - } + }(cr) } + wg.Wait() } func (r *Runner) StartTunneler() { @@ -597,8 +710,8 @@ func (r *Runner) StartTunneler() { } func (r *Runner) RunTunneler(ctx context.Context) { - cmd := exec.CommandContext(ctx, config.Tunneler.TunnelProg) - log.Infof(Tr("info.tunnel.running"), cmd.String()) + cmd := exec.CommandContext(ctx, r.Config.Tunneler.TunnelProg) + log.TrInfof("info.tunnel.running", cmd.String()) var ( cmdOut, cmdErr io.ReadCloser err error @@ -606,15 +719,15 @@ func (r *Runner) RunTunneler(ctx context.Context) { cmd.Env = append(os.Environ(), "CLUSTER_PORT="+strconv.Itoa((int)(r.Config.Port))) if cmdOut, err = cmd.StdoutPipe(); err != nil { - log.Errorf(Tr("error.tunnel.command.prepare.failed"), err) + log.TrErrorf("error.tunnel.command.prepare.failed", err) os.Exit(1) } if cmdErr, err = cmd.StderrPipe(); err != nil { - log.Errorf(Tr("error.tunnel.command.prepare.failed"), err) + log.TrErrorf("error.tunnel.command.prepare.failed", err) os.Exit(1) } if err = cmd.Start(); err != nil { - log.Errorf(Tr("error.tunnel.command.prepare.failed"), err) + log.TrErrorf("error.tunnel.command.prepare.failed", err) os.Exit(1) } type addrOut struct { @@ -623,11 +736,10 @@ func (r *Runner) RunTunneler(ctx context.Context) { } detectedCh := make(chan addrOut, 1) onLog := func(line []byte) { - res := config.Tunneler.outputRegex.FindSubmatch(line) - if res == nil { + tunnelHost, tunnelPort, ok := r.Config.Tunneler.MatchTunnelOutput(line) + if !ok { return } - tunnelHost, tunnelPort := res[config.Tunneler.hostOut], res[config.Tunneler.portOut] if len(tunnelHost) > 0 && tunnelHost[0] == '[' && tunnelHost[len(tunnelHost)-1] == ']' { // a IPv6 with port []: tunnelHost = tunnelHost[1 : len(tunnelHost)-1] } @@ -667,33 +779,11 @@ func (r *Runner) RunTunneler(ctx context.Context) { for { select { case addr := <-detectedCh: - log.Infof(Tr("info.tunnel.detected"), addr.host, addr.port) - r.cluster.publicPort = addr.port - if !r.cluster.byoc { - r.cluster.host = addr.host - } - strPort := strconv.Itoa((int)(r.getPublicPort())) - if spp, ok := r.listener.(interface{ SetPublicPort(port string) }); ok { - spp.SetPublicPort(strPort) - } - log.Infof(Tr("info.server.public.at"), net.JoinHostPort(addr.host, strPort), r.clusterSvr.Addr, r.getCertCount()) - if len(r.publicHosts) > 1 { - log.Info(Tr("info.server.alternative.hosts")) - for _, h := range r.publicHosts[1:] { - log.Infof("\t- https://%s", net.JoinHostPort(h, strPort)) - } - } - if !r.cluster.Enabled() { - shutCtx, cancel := context.WithTimeout(ctx, time.Minute) - r.cluster.Disable(shutCtx) - cancel() - } - if err := r.cluster.Enable(ctx); err != nil { - log.Errorf(Tr("error.cluster.enable.failed"), err) - if ctx.Err() != nil { - return - } - os.Exit(2) + log.TrInfof("info.tunnel.detected", addr.host, addr.port) + r.publicHost, r.publicPort = addr.host, addr.port + r.updateClustersWithGeneralConfig(ctx) + if ctx.Err() != nil { + return } case <-ctx.Done(): return diff --git a/storage/manager.go b/storage/manager.go index 79353dd6..164a34a7 100644 --- a/storage/manager.go +++ b/storage/manager.go @@ -57,6 +57,15 @@ func (m *Manager) Get(id string) Storage { return nil } +func (m *Manager) GetIndex(id string) int { + for i, s := range m.Storages { + if s.Id() == id { + return i + } + } + return -1 +} + func (m *Manager) GetFlavorString(storages []int) string { typeCount := make(map[string]int, 2) for _, i := range storages { diff --git a/sync.go b/sync.go index bded4ac1..12430b40 100644 --- a/sync.go +++ b/sync.go @@ -1,3 +1,5 @@ +//go:build ignore + /** * OpenBmclAPI (Golang Edition) * Copyright (C) 2024 Kevin Z diff --git a/util.go b/util.go index 539cf9a3..3bd2ba97 100644 --- a/util.go +++ b/util.go @@ -25,7 +25,6 @@ import ( "crypto/x509" "fmt" "io" - "math/rand" "net/http" "net/url" "os" @@ -83,54 +82,6 @@ func parseCertCommonName(body []byte) (string, error) { return cert.Subject.CommonName, nil } -func forEachFromRandomIndex(leng int, cb func(i int) (done bool)) (done bool) { - if leng <= 0 { - return false - } - start := randIntn(leng) - for i := start; i < leng; i++ { - if cb(i) { - return true - } - } - for i := 0; i < start; i++ { - if cb(i) { - return true - } - } - return false -} - -func forEachFromRandomIndexWithPossibility(poss []uint, total uint, cb func(i int) (done bool)) (done bool) { - leng := len(poss) - if leng == 0 { - return false - } - if total == 0 { - return forEachFromRandomIndex(leng, cb) - } - n := (uint)(randIntn((int)(total))) - start := 0 - for i, p := range poss { - if n < p { - start = i - break - } - n -= p - } - for i := start; i < leng; i++ { - if cb(i) { - return true - } - } - for i := 0; i < start; i++ { - if cb(i) { - return true - } - } - return false -} - func copyFile(src, dst string, mode os.FileMode) (err error) { var srcFd, dstFd *os.File if srcFd, err = os.Open(src); err != nil { diff --git a/utils/http.go b/utils/http.go index d83fbdbf..0412bf27 100644 --- a/utils/http.go +++ b/utils/http.go @@ -30,7 +30,6 @@ import ( "net/url" "path" "runtime" - "strconv" "strings" "sync" "sync/atomic" @@ -290,13 +289,10 @@ func (m *HttpMethodHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) // Else it will just return the tls connection type HTTPTLSListener struct { net.Listener - TLSConfig *tls.Config + TLSConfig atomic.Pointer[tls.Config] + DoRedirect bool AllowUnsecure bool - mux sync.RWMutex - hosts []string - port string - accepting atomic.Bool acceptedCh chan net.Conn errCh chan error @@ -304,15 +300,17 @@ type HTTPTLSListener struct { var _ net.Listener = (*HTTPTLSListener)(nil) -func NewHttpTLSListener(l net.Listener, cfg *tls.Config, publicHosts []string, port uint16) net.Listener { - return &HTTPTLSListener{ - Listener: l, - TLSConfig: cfg, - hosts: publicHosts, - port: strconv.Itoa((int)(port)), +func NewHttpTLSListener(l net.Listener, cfg *tls.Config) *HTTPTLSListener { + h := &HTTPTLSListener{ + Listener: l, + DoRedirect: true, + AllowUnsecure: false, + acceptedCh: make(chan net.Conn, 1), errCh: make(chan error, 1), } + h.TLSConfig.Store(cfg) + return h } func (s *HTTPTLSListener) Close() (err error) { @@ -329,22 +327,7 @@ func (s *HTTPTLSListener) Close() (err error) { return } -func (s *HTTPTLSListener) SetPublicPort(port string) { - s.mux.Lock() - defer s.mux.Unlock() - s.port = port -} - -func (s *HTTPTLSListener) GetPublicPort() string { - s.mux.RLock() - defer s.mux.RUnlock() - return s.port -} - func (s *HTTPTLSListener) maybeHTTPConn(c *connHeadReader) (ishttp bool) { - if len(s.hosts) == 0 { - return false - } var buf [4096]byte i, n := 0, 0 READ_HEAD: @@ -389,6 +372,11 @@ func (s *HTTPTLSListener) accepter() { s.errCh <- err return } + tlsCfg := s.TLSConfig.Load() + if tlsCfg == nil { + s.acceptedCh <- conn + return + } go s.accepter() hr := &connHeadReader{Conn: conn} hr.SetReadDeadline(time.Now().Add(time.Second * 5)) @@ -396,7 +384,7 @@ func (s *HTTPTLSListener) accepter() { hr.SetReadDeadline(time.Time{}) if !ishttp { // if it's not a http connection, it must be a tls connection - s.acceptedCh <- tls.Server(hr, s.TLSConfig) + s.acceptedCh <- tls.Server(hr, tlsCfg) return } if s.AllowUnsecure { @@ -416,41 +404,13 @@ func (s *HTTPTLSListener) serveHTTP(conn net.Conn) { return } conn.SetReadDeadline(time.Time{}) - host, _, err := net.SplitHostPort(req.Host) - if err != nil { - host = req.Host - } - inhosts := false - if host != "" { - host = strings.ToLower(host) - for _, h := range s.hosts { - if h == "*" { - inhosts = true - break - } - if h, ok := strings.CutPrefix(h, "*."); ok { - if strings.HasSuffix(host, h) { - inhosts = true - break - } - } else if h == host { - inhosts = true - break - } - } - } + // host, _, err := net.SplitHostPort(req.Host) + // if err != nil { + // host = req.Host + // } u := *req.URL u.Scheme = "https" - if !inhosts { - host = "" - for _, h := range s.hosts { - if h != "*" && !strings.HasSuffix(h, "*.") { - host = h - break - } - } - } - if host == "" { + if !s.DoRedirect { body := strings.NewReader("Sent http request on https server") resp := &http.Response{ StatusCode: http.StatusBadRequest, @@ -468,7 +428,7 @@ func (s *HTTPTLSListener) serveHTTP(conn net.Conn) { io.Copy(conn, body) return } - u.Host = net.JoinHostPort(host, s.GetPublicPort()) + // u.Host = net.JoinHostPort(host, s.GetPublicPort()) resp := &http.Response{ StatusCode: http.StatusPermanentRedirect, ProtoMajor: req.ProtoMajor, From 650b0c36982f2313d66576f76c5f6582e18302a1 Mon Sep 17 00:00:00 2001 From: zyxkad Date: Mon, 12 Aug 2024 13:12:09 -0700 Subject: [PATCH 20/36] refactored all errors --- api/config.go | 4 + cluster/http.go | 13 +- cluster/storage.go | 413 +++++++++++++++++- cluster/tempfile_test.go | 180 ++++++++ config.go | 152 +++++++ config/config.go | 33 +- config/server.go | 7 +- handler.go | 10 +- limited/api_rate.go | 52 ++- main.go | 124 +++--- storage/storage_webdav.go | 2 +- sync.go | 862 -------------------------------------- util.go | 46 -- utils/http.go | 43 ++ utils/util.go | 59 ++- 15 files changed, 972 insertions(+), 1028 deletions(-) create mode 100644 cluster/tempfile_test.go delete mode 100644 sync.go diff --git a/api/config.go b/api/config.go index e091b6bb..a32ffcd1 100644 --- a/api/config.go +++ b/api/config.go @@ -21,8 +21,11 @@ package api import ( "encoding/json" + "errors" ) +var ErrPreconditionFailed = errors.New("Precondition Failed") + type ConfigHandler interface { json.Marshaler json.Unmarshaler @@ -31,5 +34,6 @@ type ConfigHandler interface { UnmarshalJSONPath(path string, data []byte) error Fingerprint() string + // DoLockedAction will execute callback if the fingerprint matches, or return ErrPreconditionFailed DoLockedAction(fingerprint string, callback func(ConfigHandler) error) error } diff --git a/cluster/http.go b/cluster/http.go index 4910444c..07480e33 100644 --- a/cluster/http.go +++ b/cluster/http.go @@ -80,6 +80,14 @@ func redirectChecker(req *http.Request, via []*http.Request) error { return nil } +func (cr *Cluster) getFullURL(relpath string) (u *url.URL, err error) { + if u, err = url.Parse(cr.opts.Server); err != nil { + return + } + u.Path = path.Join(u.Path, relpath) + return +} + func (cr *Cluster) makeReq(ctx context.Context, method string, relpath string, query url.Values) (req *http.Request, err error) { return cr.makeReqWithBody(ctx, method, relpath, query, nil) } @@ -89,11 +97,10 @@ func (cr *Cluster) makeReqWithBody( method string, relpath string, query url.Values, body io.Reader, ) (req *http.Request, err error) { - var u *url.URL - if u, err = url.Parse(cr.opts.Server); err != nil { + u, err := cr.getFullURL(relpath) + if err != nil { return } - u.Path = path.Join(u.Path, relpath) if query != nil { u.RawQuery = query.Encode() } diff --git a/cluster/storage.go b/cluster/storage.go index 77173d6f..1d48521e 100644 --- a/cluster/storage.go +++ b/cluster/storage.go @@ -20,16 +20,21 @@ package cluster import ( + "compress/gzip" + "compress/zlib" "context" "crypto" "encoding/hex" + "errors" "fmt" "io" "net/http" "net/url" + "os" "runtime" "slices" "strconv" + "strings" "sync" "sync/atomic" "time" @@ -68,12 +73,19 @@ type FileInfo struct { Mtime int64 `json:"mtime" avro:"mtime"` } +type RequestPath struct { + *http.Request + Path string +} + type StorageFileInfo struct { - FileInfo + Hash string + Size int64 Storages []storage.Storage + URLs map[string]RequestPath } -func (cr *Cluster) GetFileList(ctx context.Context, fileMap map[string]*StorageFileInfo, forceAll bool) (err error) { +func (cr *Cluster) GetFileList(ctx context.Context, fileMap map[string]*StorageFileInfo, forceAll bool) error { var query url.Values lastMod := cr.fileListLastMod if forceAll { @@ -86,31 +98,30 @@ func (cr *Cluster) GetFileList(ctx context.Context, fileMap map[string]*StorageF } req, err := cr.makeReqWithAuth(ctx, http.MethodGet, "/openbmclapi/files", query) if err != nil { - return + return err } res, err := cr.cachedCli.Do(req) if err != nil { - return + return err } defer res.Body.Close() switch res.StatusCode { case http.StatusOK: // case http.StatusNoContent, http.StatusNotModified: - return + return nil default: - err = utils.NewHTTPStatusErrorFromResponse(res) - return + return utils.NewHTTPStatusErrorFromResponse(res) } log.Debug("Parsing filelist body ...") zr, err := zstd.NewReader(res.Body) if err != nil { - return + return err } defer zr.Close() var files []FileInfo - if err = avro.NewDecoderForSchema(fileListSchema, zr).Decode(&files); err != nil { - return + if err := avro.NewDecoderForSchema(fileListSchema, zr).Decode(&files); err != nil { + return err } for _, f := range files { @@ -119,7 +130,7 @@ func (cr *Cluster) GetFileList(ctx context.Context, fileMap map[string]*StorageF } if ff, ok := fileMap[f.Hash]; ok { if ff.Size != f.Size { - log.Panicf("Hash conflict detected, hash of both %q (%dB) and %q (%dB) is %s", ff.Path, ff.Size, f.Path, f.Size, f.Hash) + log.Panicf("Hash conflict detected, hash of both %q (%dB) and %v (%dB) is %s", f.Path, f.Size, ff.URLs, ff.Size, f.Hash) } for _, s := range cr.storages { sto := cr.storageManager.Storages[s] @@ -129,19 +140,26 @@ func (cr *Cluster) GetFileList(ctx context.Context, fileMap map[string]*StorageF } } else { ff := &StorageFileInfo{ - FileInfo: f, + Hash: f.Hash, + Size: f.Size, Storages: make([]storage.Storage, len(cr.storages)), + URLs: make(map[string]RequestPath), } for i, s := range cr.storages { ff.Storages[i] = cr.storageManager.Storages[s] } slices.SortFunc(ff.Storages, storageIdSortFunc) + req, err := cr.makeReqWithAuth(context.Background(), http.MethodGet, f.Path, nil) + if err != nil { + return err + } + ff.URLs[req.URL.String()] = RequestPath{Request: req, Path: f.Path} fileMap[f.Hash] = ff } } cr.fileListLastMod = lastMod log.Debugf("Filelist parsed, length = %d, lastMod = %d", len(files), lastMod) - return + return nil } func storageIdSortFunc(a, b storage.Storage) int { @@ -179,15 +197,15 @@ func checkFile( pg *mpb.Progress, ) (err error) { var missingCount atomic.Int32 - addMissing := func(f FileInfo, sto storage.Storage) { + addMissing := func(f *StorageFileInfo, sto storage.Storage) { missingCount.Add(1) if info, ok := missing[f.Hash]; ok { info.Storages = append(info.Storages, sto) } else { - missing[f.Hash] = &StorageFileInfo{ - FileInfo: f, - Storages: []storage.Storage{sto}, - } + info := new(StorageFileInfo) + *info = *f + info.Storages = []storage.Storage{sto} + missing[f.Hash] = info } } @@ -280,13 +298,13 @@ func checkFile( size, ok := ssizeMap[sto][hash] if !ok { // log.Debugf("Could not found file %q", name) - addMissing(f.FileInfo, sto) + addMissing(f, sto) bar.EwmaIncrement(time.Since(start)) continue } if size != f.Size { log.TrWarnf("warn.check.modified.size", name, size, f.Size) - addMissing(f.FileInfo, sto) + addMissing(f, sto) bar.EwmaIncrement(time.Since(start)) continue } @@ -305,7 +323,7 @@ func checkFile( return ctx.Err() } wg.Add(1) - go func(f FileInfo, buf []byte, free func()) { + go func(f *StorageFileInfo, buf []byte, free func()) { defer log.RecoverPanic(nil) defer wg.Done() miss := true @@ -329,7 +347,7 @@ func checkFile( if miss { addMissing(f, sto) } - }(f.FileInfo, buf, free) + }(f, buf, free) } } wg.Wait() @@ -341,6 +359,349 @@ func checkFile( return nil } +type syncStats struct { + slots *limited.BufSlots + + totalSize int64 + okCount, failCount atomic.Int32 + totalFiles int + + pg *mpb.Progress + totalBar *mpb.Bar + lastInc atomic.Int64 +} + +func (c *HTTPClient) SyncFiles( + ctx context.Context, + manager *storage.Manager, + files map[string]*StorageFileInfo, + heavy bool, + slots int, +) error { + pg := mpb.New(mpb.WithRefreshRate(time.Second/2), mpb.WithAutoRefresh(), mpb.WithWidth(140)) + defer pg.Shutdown() + log.SetLogOutput(pg) + defer log.SetLogOutput(nil) + + missingMap := make(map[string]*StorageFileInfo) + if err := checkFile(ctx, manager, files, heavy, missingMap, pg); err != nil { + return err + } + + totalFiles := len(files) + + var stats syncStats + stats.pg = pg + stats.slots = limited.NewBufSlots(slots) + stats.totalFiles = totalFiles + + var barUnit decor.SizeB1024 + stats.lastInc.Store(time.Now().UnixNano()) + stats.totalBar = pg.AddBar(stats.totalSize, + mpb.BarRemoveOnComplete(), + mpb.BarPriority(stats.slots.Cap()), + mpb.PrependDecorators( + decor.Name(lang.Tr("hint.sync.total")), + decor.NewPercentage("%.2f"), + ), + mpb.AppendDecorators( + decor.Any(func(decor.Statistics) string { + return fmt.Sprintf("(%d + %d / %d) ", stats.okCount.Load(), stats.failCount.Load(), stats.totalFiles) + }), + decor.Counters(barUnit, "(%.1f/%.1f) "), + decor.EwmaSpeed(barUnit, "%.1f ", 30), + decor.OnComplete( + decor.EwmaETA(decor.ET_STYLE_GO, 30), "done", + ), + ), + ) + + log.TrInfof("hint.sync.start", totalFiles, utils.BytesToUnit((float64)(stats.totalSize))) + start := time.Now() + + done := make(chan []storage.Storage, 1) + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + stLen := len(manager.Storages) + aliveStorages := make(map[storage.Storage]struct{}, stLen) + for _, s := range manager.Storages { + tctx, cancel := context.WithTimeout(ctx, time.Second*10) + err := s.CheckUpload(tctx) + cancel() + if err != nil { + if err := ctx.Err(); err != nil { + return err + } + log.Errorf("Storage %s does not work: %v", s.String(), err) + } else { + aliveStorages[s] = struct{}{} + } + } + if len(aliveStorages) == 0 { + err := errors.New("All storages are broken") + log.TrErrorf("error.sync.failed", err) + return err + } + if len(aliveStorages) < stLen { + log.TrErrorf("error.sync.part.working", len(aliveStorages), stLen) + select { + case <-time.After(time.Minute): + case <-ctx.Done(): + return ctx.Err() + } + } + + for _, info := range missingMap { + log.Debugf("File %s is for %s", info.Hash, joinStorageIDs(info.Storages)) + pathRes, err := c.fetchFile(ctx, &stats, info) + if err != nil { + log.TrWarnf("warn.sync.interrupted") + return err + } + go func(info *StorageFileInfo, pathRes <-chan string) { + defer log.RecordPanic() + select { + case path := <-pathRes: + // cr.syncProg.Add(1) + if path == "" { + select { + case done <- nil: // TODO: or all storage? + case <-ctx.Done(): + } + return + } + defer os.Remove(path) + // acquire slot here + slotId, buf, free := stats.slots.Alloc(ctx) + if buf == nil { + return + } + defer free() + _ = slotId + var srcFd *os.File + if srcFd, err = os.Open(path); err != nil { + return + } + defer srcFd.Close() + var failed []storage.Storage + for _, target := range info.Storages { + if _, err = srcFd.Seek(0, io.SeekStart); err != nil { + log.Errorf("Cannot seek file %q to start: %v", path, err) + continue + } + if err = target.Create(info.Hash, srcFd); err != nil { + failed = append(failed, target) + log.TrErrorf("error.sync.create.failed", target.String(), info.Hash, err) + continue + } + } + free() + srcFd.Close() + os.Remove(path) + select { + case done <- failed: + case <-ctx.Done(): + } + case <-ctx.Done(): + return + } + }(info, pathRes) + } + + for i := len(missingMap); i > 0; i-- { + select { + case failed := <-done: + for _, s := range failed { + if _, ok := aliveStorages[s]; ok { + delete(aliveStorages, s) + log.Debugf("Broken storage %d / %d", stLen-len(aliveStorages), stLen) + if len(aliveStorages) == 0 { + cancel() + err := errors.New("All storages are broken") + log.TrErrorf("error.sync.failed", err) + return err + } + } + } + case <-ctx.Done(): + log.TrWarnf("warn.sync.interrupted") + return ctx.Err() + } + } + + use := time.Since(start) + stats.totalBar.Abort(true) + pg.Wait() + + log.TrInfof("hint.sync.done", use, utils.BytesToUnit((float64)(stats.totalSize)/use.Seconds())) + return nil +} + +func (c *HTTPClient) fetchFile(ctx context.Context, stats *syncStats, f *StorageFileInfo) (<-chan string, error) { + const maxRetryCount = 10 + + slotId, buf, free := stats.slots.Alloc(ctx) + if buf == nil { + return nil, ctx.Err() + } + + pathRes := make(chan string, 1) + go func() { + defer log.RecordPanic() + defer free() + defer close(pathRes) + + var barUnit decor.SizeB1024 + var tried atomic.Int32 + tried.Store(1) + + fPath := f.Hash // TODO: show downloading URL instead? Will it be too long? + + bar := stats.pg.AddBar(f.Size, + mpb.BarRemoveOnComplete(), + mpb.BarPriority(slotId), + mpb.PrependDecorators( + decor.Name(lang.Tr("hint.sync.downloading")), + decor.Any(func(decor.Statistics) string { + tc := tried.Load() + if tc <= 1 { + return "" + } + return fmt.Sprintf("(%d/%d) ", tc, maxRetryCount) + }), + decor.Name(fPath, decor.WCSyncSpaceR), + ), + mpb.AppendDecorators( + decor.NewPercentage("%d", decor.WCSyncSpace), + decor.Counters(barUnit, "[%.1f / %.1f]", decor.WCSyncSpace), + decor.EwmaSpeed(barUnit, "%.1f", 30, decor.WCSyncSpace), + decor.OnComplete( + decor.EwmaETA(decor.ET_STYLE_GO, 30, decor.WCSyncSpace), "done", + ), + ), + ) + defer bar.Abort(true) + + interval := time.Second + for { + bar.SetCurrent(0) + hashMethod, err := getHashMethod(len(f.Hash)) + if err == nil { + var path string + if path, err = c.fetchFileWithBuf(ctx, f, hashMethod, buf, func(r io.Reader) io.Reader { + return utils.ProxyPBReader(r, bar, stats.totalBar, &stats.lastInc) + }); err == nil { + pathRes <- path + stats.okCount.Add(1) + log.Infof(lang.Tr("info.sync.downloaded"), fPath, + utils.BytesToUnit((float64)(f.Size)), + (float64)(stats.totalBar.Current())/(float64)(stats.totalSize)*100) + return + } + } + bar.SetRefill(bar.Current()) + + c := tried.Add(1) + if c > maxRetryCount { + log.TrErrorf("error.sync.download.failed", fPath, err) + break + } + log.TrErrorf("error.sync.download.failed.retry", fPath, interval, err) + select { + case <-time.After(interval): + interval *= 2 + case <-ctx.Done(): + return + } + } + stats.failCount.Add(1) + }() + return pathRes, nil +} + +func (c *HTTPClient) fetchFileWithBuf( + ctx context.Context, f *StorageFileInfo, + hashMethod crypto.Hash, buf []byte, + wrapper func(io.Reader) io.Reader, +) (path string, err error) { + var ( + req *http.Request + res *http.Response + fd *os.File + r io.Reader + ) + for _, rq := range f.URLs { + req = rq.Request + break + } + req = req.Clone(ctx) + req.Header.Set("Accept-Encoding", "gzip, deflate") + if res, err = c.Do(req); err != nil { + return + } + defer res.Body.Close() + if err = ctx.Err(); err != nil { + return + } + if res.StatusCode != http.StatusOK { + err = utils.ErrorFromRedirect(utils.NewHTTPStatusErrorFromResponse(res), res) + return + } + switch ce := strings.ToLower(res.Header.Get("Content-Encoding")); ce { + case "": + r = res.Body + case "gzip": + if r, err = gzip.NewReader(res.Body); err != nil { + err = utils.ErrorFromRedirect(err, res) + return + } + case "deflate": + if r, err = zlib.NewReader(res.Body); err != nil { + err = utils.ErrorFromRedirect(err, res) + return + } + default: + err = utils.ErrorFromRedirect(fmt.Errorf("Unexpected Content-Encoding %q", ce), res) + return + } + if wrapper != nil { + r = wrapper(r) + } + + hw := hashMethod.New() + + if fd, err = os.CreateTemp("", "*.downloading"); err != nil { + return + } + path = fd.Name() + defer func(path string) { + if err != nil { + os.Remove(path) + } + }(path) + + _, err = io.CopyBuffer(io.MultiWriter(hw, fd), r, buf) + stat, err2 := fd.Stat() + fd.Close() + if err != nil { + err = utils.ErrorFromRedirect(err, res) + return + } + if err2 != nil { + err = err2 + return + } + if t := stat.Size(); f.Size >= 0 && t != f.Size { + err = utils.ErrorFromRedirect(fmt.Errorf("File size wrong, got %d, expect %d", t, f.Size), res) + return + } else if hs := hex.EncodeToString(hw.Sum(buf[:0])); hs != f.Hash { + err = utils.ErrorFromRedirect(fmt.Errorf("File hash not match, got %s, expect %s", hs, f.Hash), res) + return + } + return +} + func getHashMethod(l int) (hashMethod crypto.Hash, err error) { switch l { case 32: @@ -352,3 +713,11 @@ func getHashMethod(l int) (hashMethod crypto.Hash, err error) { } return } + +func joinStorageIDs(storages []storage.Storage) string { + ss := make([]string, len(storages)) + for i, s := range storages { + ss[i] = s.Id() + } + return "[" + strings.Join(ss, ", ") + "]" +} diff --git a/cluster/tempfile_test.go b/cluster/tempfile_test.go new file mode 100644 index 00000000..d3d216b6 --- /dev/null +++ b/cluster/tempfile_test.go @@ -0,0 +1,180 @@ +/** + * OpenBmclAPI (Golang Edition) + * Copyright (C) 2024 Kevin Z + * All rights reserved + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published + * by the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package cluster_test + +import ( + "testing" + + "io" + "os" +) + +var datas = func() [][]byte { + datas := make([][]byte, 0x7) + for i := range len(datas) { + b := make([]byte, 0xff00+i) + for j := range len(b) { + b[j] = (byte)(i + j) + } + datas[i] = b + } + return datas +}() + +func BenchmarkCreateAndRemoveFile(t *testing.B) { + t.ReportAllocs() + buf := make([]byte, 1024) + _ = buf + for i := 0; i < t.N; i++ { + d := datas[i%len(datas)] + fd, err := os.CreateTemp("", "*.downloading") + if err != nil { + t.Fatalf("Cannot create temp file: %v", err) + } + if _, err = fd.Write(d); err != nil { + t.Errorf("Cannot write file: %v", err) + } else if err = fd.Sync(); err != nil { + t.Errorf("Cannot write file: %v", err) + } + fd.Close() + os.Remove(fd.Name()) + if err != nil { + t.FailNow() + } + } +} + +func BenchmarkWriteAndTruncateFile(t *testing.B) { + t.ReportAllocs() + buf := make([]byte, 1024) + _ = buf + fd, err := os.CreateTemp("", "*.downloading") + if err != nil { + t.Fatalf("Cannot create temp file: %v", err) + } + defer os.Remove(fd.Name()) + for i := 0; i < t.N; i++ { + d := datas[i%len(datas)] + if _, err := fd.Write(d); err != nil { + t.Fatalf("Cannot write file: %v", err) + } else if err := fd.Sync(); err != nil { + t.Fatalf("Cannot write file: %v", err) + } else if err := fd.Truncate(0); err != nil { + t.Fatalf("Cannot truncate file: %v", err) + } + } +} + +func BenchmarkWriteAndSeekFile(t *testing.B) { + t.ReportAllocs() + buf := make([]byte, 1024) + _ = buf + fd, err := os.CreateTemp("", "*.downloading") + if err != nil { + t.Fatalf("Cannot create temp file: %v", err) + } + defer os.Remove(fd.Name()) + for i := 0; i < t.N; i++ { + d := datas[i%len(datas)] + if _, err := fd.Write(d); err != nil { + t.Fatalf("Cannot write file: %v", err) + } else if err := fd.Sync(); err != nil { + t.Fatalf("Cannot write file: %v", err) + } else if _, err := fd.Seek(io.SeekStart, 0); err != nil { + t.Fatalf("Cannot seek file: %v", err) + } + } +} + +func BenchmarkParallelCreateAndRemoveFile(t *testing.B) { + t.ReportAllocs() + t.SetParallelism(4) + buf := make([]byte, 1024) + _ = buf + t.RunParallel(func(pb *testing.PB) { + for i := 0; pb.Next(); i++ { + d := datas[i%len(datas)] + fd, err := os.CreateTemp("", "*.downloading") + if err != nil { + t.Fatalf("Cannot create temp file: %v", err) + } + if _, err = fd.Write(d); err != nil { + t.Errorf("Cannot write file: %v", err) + } else if err = fd.Sync(); err != nil { + t.Errorf("Cannot write file: %v", err) + } + fd.Close() + if err := os.Remove(fd.Name()); err != nil { + t.Fatalf("Cannot remove file: %v", err) + } + if err != nil { + t.FailNow() + } + } + }) +} + +func BenchmarkParallelWriteAndTruncateFile(t *testing.B) { + t.ReportAllocs() + t.SetParallelism(4) + buf := make([]byte, 1024) + _ = buf + t.RunParallel(func(pb *testing.PB) { + fd, err := os.CreateTemp("", "*.downloading") + if err != nil { + t.Fatalf("Cannot create temp file: %v", err) + } + defer os.Remove(fd.Name()) + for i := 0; pb.Next(); i++ { + d := datas[i%len(datas)] + if _, err := fd.Write(d); err != nil { + t.Fatalf("Cannot write file: %v", err) + } else if err := fd.Sync(); err != nil { + t.Fatalf("Cannot write file: %v", err) + } else if err := fd.Truncate(0); err != nil { + t.Fatalf("Cannot truncate file: %v", err) + } + } + }) +} + +func BenchmarkParallelWriteAndSeekFile(t *testing.B) { + t.ReportAllocs() + t.SetParallelism(4) + buf := make([]byte, 1024) + _ = buf + t.RunParallel(func(pb *testing.PB) { + fd, err := os.CreateTemp("", "*.downloading") + if err != nil { + t.Fatalf("Cannot create temp file: %v", err) + } + defer os.Remove(fd.Name()) + for i := 0; pb.Next(); i++ { + d := datas[i%len(datas)] + if _, err := fd.Write(d); err != nil { + t.Fatalf("Cannot write file: %v", err) + } else if err := fd.Sync(); err != nil { + t.Fatalf("Cannot write file: %v", err) + } else if _, err := fd.Seek(io.SeekStart, 0); err != nil { + t.Fatalf("Cannot seel file: %v", err) + } + } + }) +} diff --git a/config.go b/config.go index b5259111..1c4540af 100644 --- a/config.go +++ b/config.go @@ -21,16 +21,22 @@ package main import ( "bytes" + "context" + "encoding/json" "errors" "fmt" "net/url" "os" + "strings" + "sync" "gopkg.in/yaml.v3" + "github.com/LiterMC/go-openbmclapi/api" "github.com/LiterMC/go-openbmclapi/config" "github.com/LiterMC/go-openbmclapi/log" "github.com/LiterMC/go-openbmclapi/storage" + "github.com/LiterMC/go-openbmclapi/utils" ) const DefaultBMCLAPIServer = "https://openbmclapi.bangbang93.com" @@ -185,3 +191,149 @@ func readAndRewriteConfig() (cfg *config.Config, err error) { } return } + +type ConfigHandler struct { + mux sync.RWMutex + r *Runner + + updateProcess []func(context.Context) error +} + +var _ api.ConfigHandler = (*ConfigHandler)(nil) + +func (c *ConfigHandler) update(newConfig *config.Config) error { + r := c.r + oldConfig := r.Config + c.updateProcess = c.updateProcess[:0] + + if newConfig.LogSlots != oldConfig.LogSlots || newConfig.NoAccessLog != oldConfig.NoAccessLog || newConfig.AccessLogSlots != oldConfig.AccessLogSlots || newConfig.Advanced.DebugLog != oldConfig.Advanced.DebugLog { + c.updateProcess = append(c.updateProcess, r.SetupLogger) + } + if newConfig.Host != oldConfig.Host || newConfig.Port != oldConfig.Port { + c.updateProcess = append(c.updateProcess, r.StartServer) + } + if newConfig.PublicHost != oldConfig.PublicHost || newConfig.PublicPort != oldConfig.PublicPort || newConfig.Advanced.NoFastEnable != oldConfig.Advanced.NoFastEnable || newConfig.MaxReconnectCount != oldConfig.MaxReconnectCount { + c.updateProcess = append(c.updateProcess, r.updateClustersWithGeneralConfig) + } + if newConfig.RateLimit != oldConfig.RateLimit { + c.updateProcess = append(c.updateProcess, r.updateRateLimit) + } + if newConfig.Notification != oldConfig.Notification { + // c.updateProcess = append(c.updateProcess, ) + } + + r.Config = newConfig + r.publicHost = r.Config.PublicHost + r.publicPort = r.Config.PublicPort + return nil +} + +func (c *ConfigHandler) doUpdateProcesses(ctx context.Context) error { + for _, proc := range c.updateProcess { + if err := proc(ctx); err != nil { + return err + } + } + c.updateProcess = c.updateProcess[:0] + return nil +} + +func (c *ConfigHandler) MarshalJSON() ([]byte, error) { + return c.r.Config.MarshalJSON() +} + +func (c *ConfigHandler) UnmarshalJSON(data []byte) error { + c2 := c.r.Config.Clone() + if err := c2.UnmarshalJSON(data); err != nil { + return err + } + c.update(c2) + return nil +} + +func (c *ConfigHandler) UnmarshalYAML(data []byte) error { + c2 := c.r.Config.Clone() + if err := c2.UnmarshalText(data); err != nil { + return err + } + c.update(c2) + return nil +} + +func (c *ConfigHandler) MarshalJSONPath(path string) ([]byte, error) { + names := strings.Split(path, ".") + data, err := c.r.Config.MarshalJSON() + if err != nil { + return nil, err + } + var m map[string]any + if err := json.Unmarshal(data, &m); err != nil { + return nil, err + } + accessed := "" + var x any = m + for _, n := range names { + mc, ok := x.(map[string]any) + if !ok { + return nil, fmt.Errorf("Unexpected type %T on path %q, expect map[string]any", x, accessed) + } + accessed += n + "." + x = mc[n] + } + return json.Marshal(x) +} + +func (c *ConfigHandler) UnmarshalJSONPath(path string, data []byte) error { + names := strings.Split(path, ".") + var d any + if err := json.Unmarshal(data, &d); err != nil { + return err + } + accessed := "" + var m map[string]any + { + b, err := c.MarshalJSON() + if err != nil { + return err + } + if err := json.Unmarshal(b, &m); err != nil { + return err + } + } + x := m + for _, p := range names[:len(names)-1] { + accessed += p + "." + var ok bool + x, ok = x[p].(map[string]any) + if !ok { + return fmt.Errorf("Unexpected type %T on path %q, expect map[string]any", x, accessed) + } + } + x[names[len(names)-1]] = d + dt, err := json.Marshal(m) + if err != nil { + return err + } + return c.UnmarshalJSON(dt) +} + +func (c *ConfigHandler) Fingerprint() string { + c.mux.RLock() + defer c.mux.RUnlock() + return c.fingerprintLocked() +} + +func (c *ConfigHandler) fingerprintLocked() string { + data, err := c.MarshalJSON() + if err != nil { + log.Panicf("ConfigHandler.Fingerprint: MarshalJSON: %v", err) + } + return utils.BytesAsSha256(data) +} + +func (c *ConfigHandler) DoLockedAction(fingerprint string, callback func(api.ConfigHandler) error) error { + if c.fingerprintLocked() != fingerprint { + return api.ErrPreconditionFailed + } + return callback(c) +} diff --git a/config/config.go b/config/config.go index ec18ca2f..3654b6cd 100644 --- a/config/config.go +++ b/config/config.go @@ -20,6 +20,7 @@ package config import ( + "encoding/json" "path/filepath" "time" @@ -94,10 +95,9 @@ func NewDefaultConfig() *Config { Certificates: []CertificateConfig{}, Tunneler: TunnelConfig{ - Enable: false, - TunnelProg: "./path/to/tunnel/program", - OutputRegex: `\bNATedAddr\s+(?P[0-9.]+|\[[0-9a-f:]+\]):(?P\d+)$`, - TunnelTimeout: 0, + Enable: false, + TunnelProg: "./path/to/tunnel/program", + OutputRegex: `\bNATedAddr\s+(?P[0-9.]+|\[[0-9a-f:]+\]):(?P\d+)$`, }, Cache: CacheConfig{ @@ -133,6 +133,8 @@ func NewDefaultConfig() *Config { Dashboard: DashboardConfig{ Enable: true, + Username: "", + Password: "", PwaName: "GoOpenBmclApi Dashboard", PwaShortName: "GOBA Dash", PwaDesc: "Go-Openbmclapi Internal Dashboard", @@ -141,6 +143,7 @@ func NewDefaultConfig() *Config { GithubAPI: GithubAPIConfig{ UpdateCheckInterval: (utils.YAMLDuration)(time.Hour), + Authorization: "", }, Database: DatabaseConfig{ @@ -177,6 +180,28 @@ func NewDefaultConfig() *Config { } } +func (config *Config) MarshalJSON() ([]byte, error) { + type T Config + return json.Marshal((*T)(config)) +} + +func (config *Config) UnmarshalJSON(data []byte) error { + type T Config + return json.Unmarshal(data, (*T)(config)) +} + func (config *Config) UnmarshalText(data []byte) error { return yaml.Unmarshal(data, config) } + +func (config *Config) Clone() *Config { + data, err := config.MarshalJSON() + if err != nil { + panic(err) + } + cloned := new(Config) + if err := cloned.UnmarshalJSON(data); err != nil { + panic(err) + } + return cloned +} diff --git a/config/server.go b/config/server.go index a1763a8f..e6514d75 100644 --- a/config/server.go +++ b/config/server.go @@ -122,10 +122,9 @@ type GithubAPIConfig struct { } type TunnelConfig struct { - Enable bool `yaml:"enable"` - TunnelProg string `yaml:"tunnel-program"` - OutputRegex string `yaml:"output-regex"` - TunnelTimeout int `yaml:"tunnel-timeout"` + Enable bool `yaml:"enable"` + TunnelProg string `yaml:"tunnel-program"` + OutputRegex string `yaml:"output-regex"` outputRegex *regexp.Regexp hostOut int diff --git a/handler.go b/handler.go index 97e4e85a..5fecf915 100644 --- a/handler.go +++ b/handler.go @@ -39,7 +39,6 @@ import ( "github.com/LiterMC/go-openbmclapi/api/v0" "github.com/LiterMC/go-openbmclapi/internal/build" "github.com/LiterMC/go-openbmclapi/internal/gosrc" - "github.com/LiterMC/go-openbmclapi/limited" "github.com/LiterMC/go-openbmclapi/log" "github.com/LiterMC/go-openbmclapi/utils" ) @@ -102,11 +101,14 @@ var wsUpgrader = &websocket.Upgrader{ HandshakeTimeout: time.Second * 30, } -func (r *Runner) GetHandler() http.Handler { - r.apiRateLimiter = limited.NewAPIRateMiddleWare(api.RealAddrCtxKey, "go-openbmclapi.cluster.logged.user" /* api/v0.loggedUserKey */) +func (r *Runner) updateRateLimit(ctx context.Context) error { r.apiRateLimiter.SetAnonymousRateLimit(r.Config.RateLimit.Anonymous) r.apiRateLimiter.SetLoggedRateLimit(r.Config.RateLimit.Logged) - r.handlerAPIv0 = http.StripPrefix("/api/v0", v0.NewHandler(wsUpgrader)) + return nil +} + +func (r *Runner) GetHandler() http.Handler { + r.handlerAPIv0 = http.StripPrefix("/api/v0", v0.NewHandler(wsUpgrader, r.configHandler, r.userManager, r.tokenManager, r.subManager)) r.hijackHandler = http.StripPrefix("/bmclapi", r.hijacker) handler := utils.NewHttpMiddleWareHandler((http.HandlerFunc)(r.serveHTTP)) diff --git a/limited/api_rate.go b/limited/api_rate.go index 33b7ea83..d6646309 100644 --- a/limited/api_rate.go +++ b/limited/api_rate.go @@ -38,9 +38,8 @@ type RateLimit struct { } type limitSet struct { - Limit RateLimit - mux sync.RWMutex + limit RateLimit cleanCount int // min clean mask: 0xffff; hour clean mask: 0xff0000 accessMin map[string]*atomic.Int64 accessHour map[string]*atomic.Int64 @@ -54,8 +53,26 @@ func makeLimitSet() limitSet { } } +func (s *limitSet) GetLimit() RateLimit { + s.mux.RLock() + defer s.mux.RUnlock() + return s.limit +} + +func (s *limitSet) SetLimit(limit RateLimit) { + s.mux.Lock() + defer s.mux.Unlock() + s.limit = limit +} + func (s *limitSet) try(id string) (leftHour, leftMin int64, cleanId int) { - checkHour, checkMin := s.Limit.PerHour > 0, s.Limit.PerMin > 0 + var ( + hour, min *atomic.Int64 + ok1, ok2 bool + ) + + s.mux.RLock() + checkHour, checkMin := s.limit.PerHour > 0, s.limit.PerMin > 0 if !checkHour { leftHour = -1 } @@ -67,12 +84,6 @@ func (s *limitSet) try(id string) (leftHour, leftMin int64, cleanId int) { return } - var ( - hour, min *atomic.Int64 - ok1, ok2 bool - ) - - s.mux.RLock() cleanId = s.cleanCount if checkHour { hour, ok1 = s.accessHour[id] @@ -99,7 +110,7 @@ func (s *limitSet) try(id string) (leftHour, leftMin int64, cleanId int) { } s.mux.Unlock() } - leftHour = s.Limit.PerHour - hour.Add(1) + leftHour = s.limit.PerHour - hour.Add(1) if leftHour < 0 { hour.Add(-1) leftHour = 0 @@ -118,7 +129,7 @@ func (s *limitSet) try(id string) (leftHour, leftMin int64, cleanId int) { } s.mux.Unlock() } - leftMin = s.Limit.PerMin - min.Add(1) + leftMin = s.limit.PerMin - min.Add(1) if leftMin < 0 { hour.Add(-1) min.Add(-1) @@ -134,12 +145,12 @@ func (s *limitSet) release(id string, cleanId int) { if cleanId <= 0 { return } - checkHour, checkMin := s.Limit.PerHour > 0, s.Limit.PerMin > 0 + s.mux.Lock() + defer s.mux.Unlock() + checkHour, checkMin := s.limit.PerHour > 0, s.limit.PerMin > 0 if !checkHour && !checkMin { return } - s.mux.Lock() - defer s.mux.Unlock() releaseHour := checkHour && cleanId&0xff0000 == s.cleanCount&0xff0000 releaseMin := checkMin && cleanId&0xffff == s.cleanCount&0xffff if releaseHour { @@ -219,19 +230,19 @@ func SetSkipRateLimit(req *http.Request) *http.Request { } func (a *APIRateMiddleWare) AnonymousRateLimit() RateLimit { - return a.annoySet.Limit + return a.annoySet.GetLimit() } func (a *APIRateMiddleWare) SetAnonymousRateLimit(v RateLimit) { - a.annoySet.Limit = v + a.annoySet.SetLimit(v) } func (a *APIRateMiddleWare) LoggedRateLimit() RateLimit { - return a.loggedSet.Limit + return a.loggedSet.GetLimit() } func (a *APIRateMiddleWare) SetLoggedRateLimit(v RateLimit) { - a.loggedSet.Limit = v + a.loggedSet.SetLimit(v) } func (a *APIRateMiddleWare) Destroy() { @@ -265,6 +276,7 @@ func (a *APIRateMiddleWare) ServeMiddle(rw http.ResponseWriter, req *http.Reques } set = &a.annoySet } + limit := set.GetLimit() hourLeft, minLeft, cleanId := set.try(id) now := time.Now() var retryAfter int @@ -274,8 +286,8 @@ func (a *APIRateMiddleWare) ServeMiddle(rw http.ResponseWriter, req *http.Reques retryAfter = 60 - (int)(now.Sub(a.startAt)/time.Second%60) } resetAfter := now.Add((time.Duration)(retryAfter) * time.Second).Unix() - rw.Header().Set("X-Ratelimit-Limit-Minute", strconv.FormatInt(set.Limit.PerMin, 10)) - rw.Header().Set("X-Ratelimit-Limit-Hour", strconv.FormatInt(set.Limit.PerHour, 10)) + rw.Header().Set("X-Ratelimit-Limit-Minute", strconv.FormatInt(limit.PerMin, 10)) + rw.Header().Set("X-Ratelimit-Limit-Hour", strconv.FormatInt(limit.PerHour, 10)) rw.Header().Set("X-Ratelimit-Remaining-Minute", strconv.FormatInt(minLeft, 10)) rw.Header().Set("X-Ratelimit-Remaining-Hour", strconv.FormatInt(hourLeft, 10)) rw.Header().Set("X-Ratelimit-Reset-After", strconv.FormatInt(resetAfter, 10)) diff --git a/main.go b/main.go index 55f42dc4..3dd36c89 100644 --- a/main.go +++ b/main.go @@ -45,6 +45,7 @@ import ( doh "github.com/libp2p/go-doh-resolver" + "github.com/LiterMC/go-openbmclapi/api" "github.com/LiterMC/go-openbmclapi/api/bmclapi" "github.com/LiterMC/go-openbmclapi/cluster" "github.com/LiterMC/go-openbmclapi/config" @@ -118,7 +119,7 @@ func main() { defer log.RecordPanic() log.StartFlushLogFile() - r := new(Runner) + r := NewRunner() ctx, cancel := context.WithCancel(context.Background()) @@ -211,27 +212,40 @@ func main() { type Runner struct { Config *config.Config + configHandler *ConfigHandler + client *cluster.HTTPClient clusters map[string]*cluster.Cluster - apiRateLimiter *limited.APIRateMiddleWare + userManager api.UserManager + tokenManager api.TokenManager + subManager api.SubscriptionManager storageManager *storage.Manager statManager *cluster.StatManager hijacker *bmclapi.HjProxy database database.DB - server *http.Server - handlerAPIv0 http.Handler - hijackHandler http.Handler + server *http.Server + apiRateLimiter *limited.APIRateMiddleWare + handler http.Handler + handlerAPIv0 http.Handler + hijackHandler http.Handler - tlsConfig *tls.Config + tlsConfig *tls.Config publicHost string publicPort uint16 - listener *utils.HTTPTLSListener + listener *utils.HTTPTLSListener reloading atomic.Bool updating atomic.Bool tunnelCancel context.CancelFunc } +func NewRunner() *Runner { + r := new(Runner) + r.configHandler = &ConfigHandler{r: r} + r.apiRateLimiter = limited.NewAPIRateMiddleWare(api.RealAddrCtxKey, "go-openbmclapi.cluster.logged.user" /* api/v0.loggedUserKey */) + return r +} + func (r *Runner) getPublicPort() uint16 { if r.publicPort > 0 { return r.publicPort @@ -250,9 +264,13 @@ func (r *Runner) InitServer() { r.server = &http.Server{ ReadTimeout: 10 * time.Second, IdleTimeout: 5 * time.Second, - Handler: r.GetHandler(), - ErrorLog: log.ProxiedStdLog, + Handler: (http.HandlerFunc)(func(rw http.ResponseWriter, req *http.Request) { + r.handler.ServeHTTP(rw, req) + }), + ErrorLog: log.ProxiedStdLog, } + r.updateRateLimit(context.TODO()) + r.handler = r.GetHandler() } // StartServer will start the HTTP server @@ -390,7 +408,6 @@ func (r *Runner) ListenSignals(ctx context.Context, cancel context.CancelFunc) i } } - return 0 } func (r *Runner) ReloadConfig(ctx context.Context) { @@ -411,28 +428,10 @@ func (r *Runner) ReloadConfig(ctx context.Context) { } func (r *Runner) updateConfig(ctx context.Context, newConfig *config.Config) error { - oldConfig := r.Config - reloadProcesses := make([]func(context.Context) error, 0, 8) - - if newConfig.LogSlots != oldConfig.LogSlots || newConfig.NoAccessLog != oldConfig.NoAccessLog || newConfig.AccessLogSlots != oldConfig.AccessLogSlots || newConfig.Advanced.DebugLog != oldConfig.Advanced.DebugLog { - reloadProcesses = append(reloadProcesses, r.SetupLogger) - } - if newConfig.Host != oldConfig.Host || newConfig.Port != oldConfig.Port { - reloadProcesses = append(reloadProcesses, r.StartServer) - } - if newConfig.PublicHost != oldConfig.PublicHost || newConfig.PublicPort != oldConfig.PublicPort || newConfig.Advanced.NoFastEnable != oldConfig.Advanced.NoFastEnable || newConfig.MaxReconnectCount != oldConfig.MaxReconnectCount { - reloadProcesses = append(reloadProcesses, r.updateClustersWithGeneralConfig) - } - - r.Config = newConfig - r.publicHost = r.Config.PublicHost - r.publicPort = r.Config.PublicPort - for _, proc := range reloadProcesses { - if err := proc(ctx); err != nil { - return err - } + if err := r.configHandler.update(newConfig); err != nil { + return err } - return nil + return r.configHandler.doUpdateProcesses(ctx) } func (r *Runner) SetupLogger(ctx context.Context) error { @@ -494,21 +493,23 @@ func (r *Runner) UpdateFileRecords(files map[string]*cluster.StorageFileInfo, ol sem := limited.NewSemaphore(12) log.Info("Begin to update file records") for _, f := range files { - if strings.HasPrefix(f.Path, "/openbmclapi/download/") { - continue - } - if oldfileset[f.Hash] > 0 { - continue + for _, u := range f.URLs { + if strings.HasPrefix(u.Path, "/openbmclapi/download/") { + continue + } + if oldfileset[f.Hash] > 0 { + continue + } + sem.Acquire() + go func(rec database.FileRecord) { + defer sem.Release() + r.database.SetFileRecord(rec) + }(database.FileRecord{ + Path: u.Path, + Hash: f.Hash, + Size: f.Size, + }) } - sem.Acquire() - go func(rec database.FileRecord) { - defer sem.Release() - r.database.SetFileRecord(rec) - }(database.FileRecord{ - Path: f.Path, - Hash: f.Hash, - Size: f.Size, - }) } sem.Wait() log.Info("All file records are updated") @@ -533,17 +534,20 @@ func (r *Runner) InitSynchronizer(ctx context.Context) { heavyCheck = false } - // if !r.Config.Advanced.SkipFirstSync { - // if !r.cluster.SyncFiles(ctx, fileMap, false) { - // return - // } - // go r.UpdateFileRecords(fileMap, nil) + // if !r.Config.Advanced.SkipFirstSync + { + slots := 10 + if err := r.client.SyncFiles(ctx, r.storageManager, fileMap, false, slots); err != nil { + log.Errorf("Sync failed: %v", err) + return + } + go r.UpdateFileRecords(fileMap, nil) - // if !r.Config.Advanced.NoGC { - // go r.cluster.Gc() - // } - // } else - // if fl != nil { + // if !r.Config.Advanced.NoGC { + // go r.cluster.Gc() + // } + } + // else if fl != nil { // if err := r.cluster.SetFilesetByExists(ctx, fl); err != nil { // return // } @@ -564,12 +568,10 @@ func (r *Runner) InitSynchronizer(ctx context.Context) { } checkCount = (checkCount + 1) % heavyCheckInterval - oldfileset := r.cluster.CloneFileset() - if r.cluster.SyncFiles(ctx, fl, heavyCheck && checkCount == 0) { - go r.UpdateFileRecords(fl, oldfileset) - if !r.Config.Advanced.NoGC && !r.Config.OnlyGcWhenStart { - go r.cluster.Gc() - } + slots := 10 + if err := r.client.SyncFiles(ctx, r.storageManager, fileMap, heavyCheck && (checkCount == 0), slots); err != nil { + log.Errorf("Sync failed: %v", err) + return } }, (time.Duration)(r.Config.SyncInterval)*time.Minute) } diff --git a/storage/storage_webdav.go b/storage/storage_webdav.go index 940067f4..188fc223 100644 --- a/storage/storage_webdav.go +++ b/storage/storage_webdav.go @@ -551,7 +551,7 @@ func (s *WebDavStorage) ServeMeasure(rw http.ResponseWriter, req *http.Request, } func (s *WebDavStorage) createMeasureFile(ctx context.Context, size int) error { - if s.measures.Has(size) { + if s.measures.Contains(size) { // TODO: is this safe? return nil } diff --git a/sync.go b/sync.go deleted file mode 100644 index 12430b40..00000000 --- a/sync.go +++ /dev/null @@ -1,862 +0,0 @@ -//go:build ignore - -/** - * OpenBmclAPI (Golang Edition) - * Copyright (C) 2024 Kevin Z - * All rights reserved - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published - * by the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - */ - -package main - -import ( - "compress/gzip" - "compress/zlib" - "context" - "crypto" - "encoding/hex" - "errors" - "fmt" - "io" - "net/http" - "net/url" - "os" - "path" - "runtime" - "sort" - "strconv" - "strings" - "sync" - "sync/atomic" - "time" - - "github.com/hamba/avro/v2" - "github.com/klauspost/compress/zstd" - "github.com/vbauerster/mpb/v8" - "github.com/vbauerster/mpb/v8/decor" - - "github.com/LiterMC/go-openbmclapi/internal/build" - "github.com/LiterMC/go-openbmclapi/limited" - "github.com/LiterMC/go-openbmclapi/log" - "github.com/LiterMC/go-openbmclapi/storage" - "github.com/LiterMC/go-openbmclapi/update" - "github.com/LiterMC/go-openbmclapi/utils" -) - -func (cr *Cluster) CloneFileset() map[string]int64 { - cr.filesetMux.RLock() - defer cr.filesetMux.RUnlock() - fileset := make(map[string]int64, len(cr.fileset)) - for k, v := range cr.fileset { - fileset[k] = v - } - return fileset -} - -func (cr *Cluster) CachedFileSize(hash string) (size int64, ok bool) { - cr.filesetMux.RLock() - defer cr.filesetMux.RUnlock() - if size, ok = cr.fileset[hash]; !ok { - return - } - if size < 0 { - size = -size - } - return -} - -type syncStats struct { - slots *limited.BufSlots - noOpen bool - - totalSize int64 - okCount, failCount atomic.Int32 - totalFiles int - - pg *mpb.Progress - totalBar *mpb.Bar - lastInc atomic.Int64 -} - -func (cr *Cluster) checkFileFor( - ctx context.Context, - sto storage.Storage, files []FileInfo, - heavy bool, - missing *utils.SyncMap[string, *fileInfoWithTargets], - pg *mpb.Progress, -) (err error) { - var missingCount atomic.Int32 - addMissing := func(f FileInfo) { - missingCount.Add(1) - if info, has := missing.GetOrSet(f.Hash, func() *fileInfoWithTargets { - return &fileInfoWithTargets{ - FileInfo: f, - targets: []storage.Storage{sto}, - } - }); has { - info.tgMux.Lock() - info.targets = append(info.targets, sto) - info.tgMux.Unlock() - } - } - - log.Infof(Tr("info.check.start"), sto.String(), heavy) - - var ( - checkingHashMux sync.Mutex - checkingHash string - lastCheckingHash string - slots *limited.BufSlots - ) - - if heavy { - slots = limited.NewBufSlots(runtime.GOMAXPROCS(0) * 2) - } - - bar := pg.AddBar(0, - mpb.BarRemoveOnComplete(), - mpb.PrependDecorators( - decor.Name(Tr("hint.check.checking")), - decor.Name(sto.String()), - decor.OnCondition( - decor.Any(func(decor.Statistics) string { - c, l := slots.Cap(), slots.Len() - return fmt.Sprintf(" (%d / %d)", c-l, c) - }), - heavy, - ), - ), - mpb.AppendDecorators( - decor.CountersNoUnit("%d / %d", decor.WCSyncSpaceR), - decor.NewPercentage("%d", decor.WCSyncSpaceR), - decor.EwmaETA(decor.ET_STYLE_GO, 60), - ), - mpb.BarExtender((mpb.BarFillerFunc)(func(w io.Writer, _ decor.Statistics) (err error) { - if checkingHashMux.TryLock() { - lastCheckingHash = checkingHash - checkingHashMux.Unlock() - } - if lastCheckingHash != "" { - _, err = fmt.Fprintln(w, "\t", lastCheckingHash) - } - return - }), false), - ) - defer bar.Wait() - defer bar.Abort(true) - - bar.SetTotal(0x100, false) - - sizeMap := make(map[string]int64, len(files)) - { - start := time.Now() - var checkedMp [256]bool - if err = sto.WalkDir(func(hash string, size int64) error { - if n := utils.HexTo256(hash); !checkedMp[n] { - checkedMp[n] = true - now := time.Now() - bar.EwmaIncrement(now.Sub(start)) - start = now - } - sizeMap[hash] = size - return nil - }); err != nil { - return - } - } - - bar.SetCurrent(0) - bar.SetTotal((int64)(len(files)), false) - for _, f := range files { - if err = ctx.Err(); err != nil { - return - } - start := time.Now() - hash := f.Hash - if checkingHashMux.TryLock() { - checkingHash = hash - checkingHashMux.Unlock() - } - name := sto.String() + "/" + hash - if f.Size == 0 { - log.Debugf("Skipped empty file %s", name) - } else if size, ok := sizeMap[hash]; ok { - if size != f.Size { - log.Warnf(Tr("warn.check.modified.size"), name, size, f.Size) - addMissing(f) - } else if heavy { - hashMethod, err := getHashMethod(len(hash)) - if err != nil { - log.Errorf(Tr("error.check.unknown.hash.method"), hash) - } else { - _, buf, free := slots.Alloc(ctx) - if buf == nil { - return ctx.Err() - } - go func(f FileInfo, buf []byte, free func()) { - defer log.RecoverPanic(nil) - defer free() - miss := true - r, err := sto.Open(hash) - if err != nil { - log.Errorf(Tr("error.check.open.failed"), name, err) - } else { - hw := hashMethod.New() - _, err = io.CopyBuffer(hw, r, buf[:]) - r.Close() - if err != nil { - log.Errorf(Tr("error.check.hash.failed"), name, err) - } else if hs := hex.EncodeToString(hw.Sum(buf[:0])); hs != hash { - log.Warnf(Tr("warn.check.modified.hash"), name, hs, hash) - } else { - miss = false - } - } - free() - if miss { - addMissing(f) - } - bar.EwmaIncrement(time.Since(start)) - }(f, buf, free) - continue - } - } - } else { - // log.Debugf("Could not found file %q", name) - addMissing(f) - } - bar.EwmaIncrement(time.Since(start)) - } - - checkingHashMux.Lock() - checkingHash = "" - checkingHashMux.Unlock() - - bar.SetTotal(-1, true) - log.Infof(Tr("info.check.done"), sto.String(), missingCount.Load()) - return -} - -func (cr *Cluster) CheckFiles( - ctx context.Context, - files []FileInfo, - heavyCheck bool, - pg *mpb.Progress, -) (map[string]*fileInfoWithTargets, error) { - missingMap := utils.NewSyncMap[string, *fileInfoWithTargets]() - done := make(chan bool, 0) - - for _, s := range cr.storages { - go func(s storage.Storage) { - defer log.RecordPanic() - err := cr.checkFileFor(ctx, s, files, heavyCheck, missingMap, pg) - if ctx.Err() != nil { - return - } - if err != nil { - log.Errorf(Tr("error.check.failed"), s, err) - } - select { - case done <- err == nil: - case <-ctx.Done(): - } - }(s) - } - goodCount := 0 - for i := len(cr.storages); i > 0; i-- { - select { - case ok := <-done: - if ok { - goodCount++ - } - case <-ctx.Done(): - log.Warn(Tr("warn.sync.interrupted")) - return nil, ctx.Err() - } - } - if err := ctx.Err(); err != nil { - return nil, err - } - if goodCount == 0 { - return nil, errors.New("All storages are failed") - } - return missingMap.RawMap(), nil -} - -func (cr *Cluster) SetFilesetByExists(ctx context.Context, files []FileInfo) error { - pg := mpb.New(mpb.WithRefreshRate(time.Second/2), mpb.WithAutoRefresh(), mpb.WithWidth(140)) - defer pg.Shutdown() - log.SetLogOutput(pg) - defer log.SetLogOutput(nil) - - missingMap, err := cr.CheckFiles(ctx, files, false, pg) - if err != nil { - return err - } - fileset := make(map[string]int64, len(files)) - stoCount := len(cr.storages) - for _, f := range files { - if t, ok := missingMap[f.Hash]; !ok || len(t.targets) < stoCount { - fileset[f.Hash] = f.Size - } - } - - cr.mux.Lock() - cr.fileset = fileset - cr.mux.Unlock() - return nil -} - -func (cr *Cluster) syncFiles(ctx context.Context, files []FileInfo, heavyCheck bool) error { - pg := mpb.New(mpb.WithRefreshRate(time.Second/2), mpb.WithAutoRefresh(), mpb.WithWidth(140)) - defer pg.Shutdown() - log.SetLogOutput(pg) - defer log.SetLogOutput(nil) - - cr.syncProg.Store(0) - cr.syncTotal.Store(-1) - - missingMap, err := cr.CheckFiles(ctx, files, heavyCheck, pg) - if err != nil { - return err - } - var ( - missing = make([]*fileInfoWithTargets, 0, len(missingMap)) - missingSize int64 = 0 - ) - for _, f := range missingMap { - missing = append(missing, f) - missingSize += f.Size - } - totalFiles := len(missing) - if totalFiles == 0 { - log.Info(Tr("info.sync.none")) - return nil - } - - go cr.notifyManager.OnSyncBegin(len(missing), missingSize) - defer func() { - go cr.notifyManager.OnSyncDone() - }() - - cr.syncTotal.Store((int64)(totalFiles)) - - ccfg, err := cr.GetConfig(ctx) - if err != nil { - return err - } - syncCfg := ccfg.Sync - log.Infof(Tr("info.sync.config"), syncCfg) - - var stats syncStats - stats.pg = pg - stats.noOpen = syncCfg.Source == "center" - stats.slots = limited.NewBufSlots(syncCfg.Concurrency + 1) - stats.totalFiles = totalFiles - for _, f := range missing { - stats.totalSize += f.Size - } - - var barUnit decor.SizeB1024 - stats.lastInc.Store(time.Now().UnixNano()) - stats.totalBar = pg.AddBar(stats.totalSize, - mpb.BarRemoveOnComplete(), - mpb.BarPriority(stats.slots.Cap()), - mpb.PrependDecorators( - decor.Name(Tr("hint.sync.total")), - decor.NewPercentage("%.2f"), - ), - mpb.AppendDecorators( - decor.Any(func(decor.Statistics) string { - return fmt.Sprintf("(%d + %d / %d) ", stats.okCount.Load(), stats.failCount.Load(), stats.totalFiles) - }), - decor.Counters(barUnit, "(%.1f/%.1f) "), - decor.EwmaSpeed(barUnit, "%.1f ", 30), - decor.OnComplete( - decor.EwmaETA(decor.ET_STYLE_GO, 30), "done", - ), - ), - ) - - log.Infof(Tr("hint.sync.start"), totalFiles, utils.BytesToUnit((float64)(stats.totalSize))) - start := time.Now() - - done := make(chan []storage.Storage, 1) - ctx, cancel := context.WithCancel(ctx) - defer cancel() - - aliveStorages := len(cr.storages) - for _, s := range cr.storages { - tctx, cancel := context.WithTimeout(ctx, time.Second*10) - err := s.CheckUpload(tctx) - cancel() - if err != nil { - if err := ctx.Err(); err != nil { - return err - } - aliveStorages-- - log.Errorf("Storage %s does not work: %v", s.String(), err) - } - } - if aliveStorages == 0 { - err := errors.New("All storages are broken") - log.Errorf(Tr("error.sync.failed"), err) - return err - } - if aliveStorages < len(cr.storages) { - log.Errorf(Tr("error.sync.part.working"), aliveStorages < len(cr.storages)) - select { - case <-time.After(time.Minute): - case <-ctx.Done(): - return ctx.Err() - } - } - - for _, f := range missing { - log.Debugf("File %s is for %v", f.Hash, f.targets) - pathRes, err := cr.fetchFile(ctx, &stats, f.FileInfo) - if err != nil { - log.Warn(Tr("warn.sync.interrupted")) - return err - } - go func(f *fileInfoWithTargets, pathRes <-chan string) { - defer log.RecordPanic() - select { - case path := <-pathRes: - cr.syncProg.Add(1) - if path == "" { - select { - case done <- nil: // TODO: or all storage? - case <-ctx.Done(): - } - return - } - defer os.Remove(path) - // acquire slot here - slotId, buf, free := stats.slots.Alloc(ctx) - if buf == nil { - return - } - defer free() - _ = slotId - var srcFd *os.File - if srcFd, err = os.Open(path); err != nil { - return - } - defer srcFd.Close() - var failed []storage.Storage - for _, target := range f.targets { - if _, err = srcFd.Seek(0, io.SeekStart); err != nil { - log.Errorf("Cannot seek file %q to start: %v", path, err) - continue - } - if err = target.Create(f.Hash, srcFd); err != nil { - failed = append(failed, target) - log.Errorf(Tr("error.sync.create.failed"), target.String(), f.Hash, err) - continue - } - } - free() - srcFd.Close() - os.Remove(path) - select { - case done <- failed: - case <-ctx.Done(): - } - case <-ctx.Done(): - return - } - }(f, pathRes) - } - - stLen := len(cr.storages) - broken := make(map[storage.Storage]bool, stLen) - - for i := len(missing); i > 0; i-- { - select { - case failed := <-done: - for _, s := range failed { - if !broken[s] { - broken[s] = true - log.Debugf("Broken storage %d / %d", len(broken), stLen) - if len(broken) >= stLen { - cancel() - err := errors.New("All storages are broken") - log.Errorf(Tr("error.sync.failed"), err) - return err - } - } - } - case <-ctx.Done(): - log.Warn(Tr("warn.sync.interrupted")) - return ctx.Err() - } - } - - use := time.Since(start) - stats.totalBar.Abort(true) - pg.Wait() - - log.Infof(Tr("hint.sync.done"), use, utils.BytesToUnit((float64)(stats.totalSize)/use.Seconds())) - return nil -} - -func (cr *Cluster) Gc() { - for _, s := range cr.storages { - cr.gcFor(s) - } -} - -func (cr *Cluster) gcFor(s storage.Storage) { - log.Infof(Tr("info.gc.start"), s.String()) - err := s.WalkDir(func(hash string, _ int64) error { - if cr.issync.Load() { - return context.Canceled - } - if _, ok := cr.CachedFileSize(hash); !ok { - log.Infof(Tr("info.gc.found"), s.String()+"/"+hash) - s.Remove(hash) - } - return nil - }) - if err != nil { - if err == context.Canceled { - log.Warnf(Tr("warn.gc.interrupted"), s.String()) - } else { - log.Errorf(Tr("error.gc.error"), err) - } - return - } - log.Infof(Tr("info.gc.done"), s.String()) -} - -func (cr *Cluster) fetchFile(ctx context.Context, stats *syncStats, f FileInfo) (<-chan string, error) { - const ( - maxRetryCount = 5 - maxTryWithOpen = 3 - ) - - slotId, buf, free := stats.slots.Alloc(ctx) - if buf == nil { - return nil, ctx.Err() - } - - pathRes := make(chan string, 1) - go func() { - defer log.RecordPanic() - defer free() - defer close(pathRes) - - var barUnit decor.SizeB1024 - var tried atomic.Int32 - tried.Store(1) - bar := stats.pg.AddBar(f.Size, - mpb.BarRemoveOnComplete(), - mpb.BarPriority(slotId), - mpb.PrependDecorators( - decor.Name(Tr("hint.sync.downloading")), - decor.Any(func(decor.Statistics) string { - tc := tried.Load() - if tc <= 1 { - return "" - } - return fmt.Sprintf("(%d/%d) ", tc, maxRetryCount) - }), - decor.Name(f.Path, decor.WCSyncSpaceR), - ), - mpb.AppendDecorators( - decor.NewPercentage("%d", decor.WCSyncSpace), - decor.Counters(barUnit, "[%.1f / %.1f]", decor.WCSyncSpace), - decor.EwmaSpeed(barUnit, "%.1f", 30, decor.WCSyncSpace), - decor.OnComplete( - decor.EwmaETA(decor.ET_STYLE_GO, 30, decor.WCSyncSpace), "done", - ), - ), - ) - defer bar.Abort(true) - - noOpen := stats.noOpen - badOpen := false - interval := time.Second - for { - bar.SetCurrent(0) - hashMethod, err := getHashMethod(len(f.Hash)) - if err == nil { - var path string - if path, err = cr.fetchFileWithBuf(ctx, f, hashMethod, buf, noOpen, badOpen, func(r io.Reader) io.Reader { - return ProxyReader(r, bar, stats.totalBar, &stats.lastInc) - }); err == nil { - pathRes <- path - stats.okCount.Add(1) - log.Infof(Tr("info.sync.downloaded"), f.Path, - utils.BytesToUnit((float64)(f.Size)), - (float64)(stats.totalBar.Current())/(float64)(stats.totalSize)*100) - return - } - } - bar.SetRefill(bar.Current()) - - c := tried.Add(1) - if c > maxRetryCount { - log.Errorf(Tr("error.sync.download.failed"), f.Path, err) - break - } - if c > maxTryWithOpen { - badOpen = true - } - log.Errorf(Tr("error.sync.download.failed.retry"), f.Path, interval, err) - select { - case <-time.After(interval): - interval *= 2 - case <-ctx.Done(): - return - } - } - stats.failCount.Add(1) - }() - return pathRes, nil -} - -func (cr *Cluster) fetchFileWithBuf( - ctx context.Context, f FileInfo, - hashMethod crypto.Hash, buf []byte, - noOpen bool, badOpen bool, - wrapper func(io.Reader) io.Reader, -) (path string, err error) { - var ( - reqPath = f.Path - query url.Values - req *http.Request - res *http.Response - fd *os.File - r io.Reader - ) - if badOpen { - reqPath = "/openbmclapi/download/" + f.Hash - } else if noOpen { - query = url.Values{ - "noopen": {"1"}, - } - } - if req, err = cr.makeReqWithAuth(ctx, http.MethodGet, reqPath, query); err != nil { - return - } - req.Header.Set("Accept-Encoding", "gzip, deflate") - if res, err = cr.client.Do(req); err != nil { - return - } - defer res.Body.Close() - if err = ctx.Err(); err != nil { - return - } - if res.StatusCode != http.StatusOK { - err = ErrorFromRedirect(utils.NewHTTPStatusErrorFromResponse(res), res) - return - } - switch ce := strings.ToLower(res.Header.Get("Content-Encoding")); ce { - case "": - r = res.Body - case "gzip": - if r, err = gzip.NewReader(res.Body); err != nil { - err = ErrorFromRedirect(err, res) - return - } - case "deflate": - if r, err = zlib.NewReader(res.Body); err != nil { - err = ErrorFromRedirect(err, res) - return - } - default: - err = ErrorFromRedirect(fmt.Errorf("Unexpected Content-Encoding %q", ce), res) - return - } - if wrapper != nil { - r = wrapper(r) - } - - hw := hashMethod.New() - - if fd, err = os.CreateTemp("", "*.downloading"); err != nil { - return - } - path = fd.Name() - defer func(path string) { - if err != nil { - os.Remove(path) - } - }(path) - - _, err = io.CopyBuffer(io.MultiWriter(hw, fd), r, buf) - stat, err2 := fd.Stat() - fd.Close() - if err != nil { - err = ErrorFromRedirect(err, res) - return - } - if err2 != nil { - err = err2 - return - } - if t := stat.Size(); f.Size >= 0 && t != f.Size { - err = ErrorFromRedirect(fmt.Errorf("File size wrong, got %d, expect %d", t, f.Size), res) - return - } else if hs := hex.EncodeToString(hw.Sum(buf[:0])); hs != f.Hash { - err = ErrorFromRedirect(fmt.Errorf("File hash not match, got %s, expect %s", hs, f.Hash), res) - return - } - return -} - -type downloadingItem struct { - err error - done chan struct{} -} - -func (cr *Cluster) lockDownloading(target string) (*downloadingItem, bool) { - cr.downloadMux.RLock() - item := cr.downloading[target] - cr.downloadMux.RUnlock() - if item != nil { - return item, true - } - - cr.downloadMux.Lock() - defer cr.downloadMux.Unlock() - - if item = cr.downloading[target]; item != nil { - return item, true - } - item = &downloadingItem{ - done: make(chan struct{}, 0), - } - cr.downloading[target] = item - return item, false -} - -func (cr *Cluster) DownloadFile(ctx context.Context, hash string) (err error) { - hashMethod, err := getHashMethod(len(hash)) - if err != nil { - return - } - - f := FileInfo{ - Path: "/openbmclapi/download/" + hash, - Hash: hash, - Size: -1, - Mtime: 0, - } - item, ok := cr.lockDownloading(hash) - if !ok { - go func() { - defer log.RecoverPanic(nil) - var err error - defer func() { - if err != nil { - log.Errorf(Tr("error.sync.download.failed"), hash, err) - } - item.err = err - close(item.done) - - cr.downloadMux.Lock() - defer cr.downloadMux.Unlock() - delete(cr.downloading, hash) - }() - - log.Infof(Tr("hint.sync.downloading.handler"), hash) - - ctx, cancel := context.WithCancel(context.Background()) - go func() { - if cr.enabled.Load() { - select { - case <-cr.Disabled(): - cancel() - case <-ctx.Done(): - } - } else { - select { - case <-cr.WaitForEnable(): - cancel() - case <-ctx.Done(): - } - } - }() - defer cancel() - - var buf []byte - _, buf, free := cr.allocBuf(ctx) - if buf == nil { - err = ctx.Err() - return - } - defer free() - - path, err := cr.fetchFileWithBuf(ctx, f, hashMethod, buf, true, true, nil) - if err != nil { - return - } - defer os.Remove(path) - var srcFd *os.File - if srcFd, err = os.Open(path); err != nil { - return - } - defer srcFd.Close() - var stat os.FileInfo - if stat, err = srcFd.Stat(); err != nil { - return - } - size := stat.Size() - - for _, target := range cr.storages { - if _, err = srcFd.Seek(0, io.SeekStart); err != nil { - log.Errorf("Cannot seek file %q: %v", path, err) - return - } - if err := target.Create(hash, srcFd); err != nil { - log.Errorf(Tr("error.sync.create.failed"), target.String(), hash, err) - continue - } - } - - cr.filesetMux.Lock() - cr.fileset[hash] = -size // negative means that the file was not stored into the database yet - cr.filesetMux.Unlock() - }() - } - select { - case <-item.done: - err = item.err - case <-ctx.Done(): - err = ctx.Err() - case <-cr.Disabled(): - err = context.Canceled - } - return -} - -func (cr *Cluster) checkUpdate() (err error) { - if update.CurrentBuildTag == nil { - return - } - log.Info(Tr("info.update.checking")) - release, err := update.Check(cr.cachedCli, config.GithubAPI.Authorization) - if err != nil || release == nil { - return - } - // TODO: print all middle change logs - log.Infof(Tr("info.update.detected"), release.Tag, update.CurrentBuildTag) - log.Infof(Tr("info.update.changelog"), update.CurrentBuildTag, release.Tag, release.Body) - cr.notifyManager.OnUpdateAvaliable(release) - return -} diff --git a/util.go b/util.go index 3bd2ba97..28105071 100644 --- a/util.go +++ b/util.go @@ -25,11 +25,7 @@ import ( "crypto/x509" "fmt" "io" - "net/http" - "net/url" "os" - "slices" - "strings" "time" "github.com/LiterMC/go-openbmclapi/log" @@ -95,45 +91,3 @@ func copyFile(src, dst string, mode os.FileMode) (err error) { _, err = io.Copy(dstFd, srcFd) return } - -type RedirectError struct { - Redirects []*url.URL - Err error -} - -func ErrorFromRedirect(err error, resp *http.Response) *RedirectError { - redirects := make([]*url.URL, 0, 4) - for resp != nil && resp.Request != nil { - redirects = append(redirects, resp.Request.URL) - resp = resp.Request.Response - } - if len(redirects) > 1 { - slices.Reverse(redirects) - } else { - redirects = nil - } - return &RedirectError{ - Redirects: redirects, - Err: err, - } -} - -func (e *RedirectError) Error() string { - if len(e.Redirects) == 0 { - return e.Err.Error() - } - - var b strings.Builder - b.WriteString("Redirect from:\n\t") - for _, r := range e.Redirects { - b.WriteString("- ") - b.WriteString(r.String()) - b.WriteString("\n\t") - } - b.WriteString(e.Err.Error()) - return b.String() -} - -func (e *RedirectError) Unwrap() error { - return e.Err -} diff --git a/utils/http.go b/utils/http.go index 0412bf27..4c6e5e66 100644 --- a/utils/http.go +++ b/utils/http.go @@ -30,6 +30,7 @@ import ( "net/url" "path" "runtime" + "slices" "strings" "sync" "sync/atomic" @@ -504,3 +505,45 @@ func (c *connHeadReader) Read(buf []byte) (n int, err error) { } return c.Conn.Read(buf) } + +type RedirectError struct { + Redirects []*url.URL + Err error +} + +func ErrorFromRedirect(err error, resp *http.Response) *RedirectError { + redirects := make([]*url.URL, 0, 4) + for resp != nil && resp.Request != nil { + redirects = append(redirects, resp.Request.URL) + resp = resp.Request.Response + } + if len(redirects) > 1 { + slices.Reverse(redirects) + } else { + redirects = nil + } + return &RedirectError{ + Redirects: redirects, + Err: err, + } +} + +func (e *RedirectError) Error() string { + if len(e.Redirects) == 0 { + return e.Err.Error() + } + + var b strings.Builder + b.WriteString("Redirect from:\n\t") + for _, r := range e.Redirects { + b.WriteString("- ") + b.WriteString(r.String()) + b.WriteString("\n\t") + } + b.WriteString(e.Err.Error()) + return b.String() +} + +func (e *RedirectError) Unwrap() error { + return e.Err +} diff --git a/utils/util.go b/utils/util.go index 78955cb2..a5a6572f 100644 --- a/utils/util.go +++ b/utils/util.go @@ -21,8 +21,10 @@ package utils import ( "errors" + "fmt" "os" "path/filepath" + "strings" "sync" ) @@ -47,6 +49,12 @@ func (m *SyncMap[K, V]) RawMap() map[K]V { return m.m } +func (m *SyncMap[K, V]) Clear() { + m.l.Lock() + defer m.l.Unlock() + clear(m.m) +} + func (m *SyncMap[K, V]) Set(k K, v V) { m.l.Lock() defer m.l.Unlock() @@ -59,7 +67,7 @@ func (m *SyncMap[K, V]) Get(k K) V { return m.m[k] } -func (m *SyncMap[K, V]) Has(k K) bool { +func (m *SyncMap[K, V]) Contains(k K) bool { m.l.RLock() defer m.l.RUnlock() _, ok := m.m[k] @@ -83,6 +91,55 @@ func (m *SyncMap[K, V]) GetOrSet(k K, setter func() V) (v V, had bool) { return } +type Set[T comparable] map[T]struct{} + +func NewSet[T comparable]() Set[T] { + return make(Set[T]) +} + +func (s Set[T]) Clear() { + clear(s) +} + +func (s Set[T]) Put(v T) { + s[v] = struct{}{} +} + +func (s Set[T]) Contains(v T) bool { + _, ok := s[v] + return ok +} + +func (s Set[T]) Remove(v T) bool { + _, ok := s[v] + if ok { + delete(s, v) + } + return ok +} + +func (s Set[T]) ToSlice(arr []T) []T { + for v, _ := range s { + arr = append(arr, v) + } + return arr +} + +func (s Set[T]) String() string { + var b strings.Builder + b.WriteString("Set{") + first := true + for v := range s { + if first { + first = false + b.WriteByte(' ') + } + fmt.Fprintf(&b, "%v", v) + } + b.WriteByte('}') + return b.String() +} + func WalkCacheDir(cacheDir string, walker func(hash string, size int64) (err error)) (err error) { for _, dir := range Hex256 { files, err := os.ReadDir(filepath.Join(cacheDir, dir)) From 1d399d97f988da6f494f0778407cf24eca776cca Mon Sep 17 00:00:00 2001 From: zyxkad Date: Mon, 12 Aug 2024 19:27:10 -0600 Subject: [PATCH 21/36] add gc --- cluster/storage.go | 25 +++++++++++++++++++++++++ main.go | 6 +++--- 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/cluster/storage.go b/cluster/storage.go index 1d48521e..05d45101 100644 --- a/cluster/storage.go +++ b/cluster/storage.go @@ -702,6 +702,31 @@ func (c *HTTPClient) fetchFileWithBuf( return } +func (c *HTTPClient) Gc( + ctx context.Context, + manager *storage.Manager, + files map[string]*StorageFileInfo, +) error { + errs := make([]error, len(manager.Storages)) + var wg sync.WaitGroup + for i, s := range manager.Storages { + wg.Add(1) + go func(i int, s storage.Storage) { + defer wg.Done() + errs[i] = s.WalkDir(func(hash string, size int64) error { + info, ok := files[hash] + ok = ok && slices.Contains(info.Storages, s) + if !ok { + s.Remove(hash) + } + return nil + }) + }(i, s) + } + wg.Wait() + return errors.Join(errs...) +} + func getHashMethod(l int) (hashMethod crypto.Hash, err error) { switch l { case 32: diff --git a/main.go b/main.go index 3dd36c89..0d46a91c 100644 --- a/main.go +++ b/main.go @@ -543,9 +543,9 @@ func (r *Runner) InitSynchronizer(ctx context.Context) { } go r.UpdateFileRecords(fileMap, nil) - // if !r.Config.Advanced.NoGC { - // go r.cluster.Gc() - // } + if !r.Config.Advanced.NoGC { + go r.client.Gc(context.TODO(), r.storageManager, fileMap) + } } // else if fl != nil { // if err := r.cluster.SetFilesetByExists(ctx, fl); err != nil { From e72384719d9b6d1c70002e7d5c2bbf7691887d81 Mon Sep 17 00:00:00 2001 From: zyxkad Date: Mon, 12 Aug 2024 23:04:24 -0600 Subject: [PATCH 22/36] fix certificate request logic --- cluster/cluster.go | 10 +++- main.go | 124 ++++++++++++++++++++++++--------------------- 2 files changed, 74 insertions(+), 60 deletions(-) diff --git a/cluster/cluster.go b/cluster/cluster.go index 51d2cf6d..08a850e6 100644 --- a/cluster/cluster.go +++ b/cluster/cluster.go @@ -44,6 +44,7 @@ var ( ) type Cluster struct { + name string opts config.ClusterOptions gcfg config.ClusterGeneralConfig @@ -69,7 +70,7 @@ type Cluster struct { } func NewCluster( - opts config.ClusterOptions, gcfg config.ClusterGeneralConfig, + name string, opts config.ClusterOptions, gcfg config.ClusterGeneralConfig, storageManager *storage.Manager, statManager *StatManager, ) (cr *Cluster) { @@ -89,10 +90,17 @@ func NewCluster( } // ID returns the cluster id +// The ID may not be unique in the openbmclapi cluster runtime func (cr *Cluster) ID() string { return cr.opts.Id } +// Name returns the cluster's alias name +// The name must be unique in the openbmclapi cluster runtime +func (cr *Cluster) Name() string { + return cr.name +} + // Secret returns the cluster secret func (cr *Cluster) Secret() string { return cr.opts.Secret diff --git a/main.go b/main.go index 0d46a91c..0dd90072 100644 --- a/main.go +++ b/main.go @@ -145,38 +145,41 @@ func main() { go func(ctx context.Context) { defer log.RecordPanic() - var wg sync.WaitGroup - errs := make([]error, len(r.clusters)) { - i := 0 + type clusterSetupRes struct { + cluster *cluster.Cluster + err error + cert *tls.Certificate + } + resCh := make(chan clusterSetupRes) for _, cr := range r.clusters { - i++ - go func(i int, cr *cluster.Cluster) { - defer wg.Done() - errs[i] = cr.Connect(ctx) - }(i, cr) + go func(cr *cluster.Cluster) { + defer log.RecordPanic() + if err := cr.Connect(ctx); err != nil { + log.Errorf("Failed to connect cluster %s to server %q: %v", cr.ID(), cr.Options().Server, err) + resCh <- clusterSetupRes{cluster: cr, err: err} + return + } + cert, err := r.RequestClusterCert(ctx, cr) + if err != nil { + log.Errorf("Failed to request certificate for cluster %s: %v", cr.ID(), err) + resCh <- clusterSetupRes{cluster: cr, err: err} + return + } + resCh <- clusterSetupRes{cluster: cr, cert: cert} + }(cr) } - } - wg.Wait() - if ctx.Err() != nil { - return - } - - { - var err error - r.tlsConfig, err = r.PatchTLSWithClusterCert(ctx, r.tlsConfig) - if err != nil { - return + for range len(r.clusters) { + select { + case res := <-resCh: + r.certificates[res.cluster.Name()] = res.cert + case <-ctx.Done(): + return + } } - r.listener.TLSConfig.Store(r.tlsConfig) } - firstSyncDone := make(chan struct{}, 0) - go func() { - defer log.RecordPanic() - defer close(firstSyncDone) - r.InitSynchronizer(ctx) - }() + r.listener.TLSConfig.Store(r.PatchTLSWithClusterCertificates(r.tlsConfig)) if !r.Config.Tunneler.Enable { strPort := strconv.Itoa((int)(r.getPublicPort())) @@ -186,13 +189,9 @@ func main() { } log.TrInfof("info.wait.first.sync") - select { - case <-firstSyncDone: - case <-ctx.Done(): - return - } + r.InitSynchronizer(ctx) - // r.EnableCluster(ctx) + r.EnableClusterAll(ctx) }(ctx) code := r.ListenSignals(ctx, cancel) @@ -229,10 +228,11 @@ type Runner struct { handlerAPIv0 http.Handler hijackHandler http.Handler - tlsConfig *tls.Config - publicHost string - publicPort uint16 - listener *utils.HTTPTLSListener + tlsConfig *tls.Config + certificates map[string]*tls.Certificate + publicHost string + publicPort uint16 + listener *utils.HTTPTLSListener reloading atomic.Bool updating atomic.Bool @@ -314,7 +314,7 @@ func (r *Runner) InitClusters(ctx context.Context) { r.clusters = make(map[string]*cluster.Cluster) gcfg := r.GetClusterGeneralConfig() for name, opts := range r.Config.Clusters { - cr := cluster.NewCluster(opts, gcfg, r.storageManager, r.statManager) + cr := cluster.NewCluster(name, opts, gcfg, r.storageManager, r.statManager) if err := cr.Init(ctx); err != nil { log.TrErrorf("error.init.failed", err) } else { @@ -618,28 +618,34 @@ func (r *Runner) GenerateTLSConfig() (*tls.Config, error) { return tlsConfig, nil } -func (r *Runner) PatchTLSWithClusterCert(ctx context.Context, tlsConfig *tls.Config) (*tls.Config, error) { - certs := make([]tls.Certificate, 0) - for _, cr := range r.clusters { - if cr.Options().Byoc { - continue - } - log.TrInfof("info.cert.requesting", cr.ID()) - tctx, cancel := context.WithTimeout(ctx, time.Minute*10) - pair, err := cr.RequestCert(tctx) - cancel() - if err != nil { - log.TrErrorf("error.cert.request.failed", err) - continue - } - cert, err := tls.X509KeyPair(([]byte)(pair.Cert), ([]byte)(pair.Key)) - if err != nil { - log.TrErrorf("error.cert.requested.parse.failed", err) - continue +func (r *Runner) RequestClusterCert(ctx context.Context, cr *cluster.Cluster) (*tls.Certificate, error) { + if cr.Options().Byoc { + return nil, nil + } + log.TrInfof("info.cert.requesting", cr.ID()) + tctx, cancel := context.WithTimeout(ctx, time.Minute*10) + pair, err := cr.RequestCert(tctx) + cancel() + if err != nil { + log.TrErrorf("error.cert.request.failed", err) + return nil, err + } + cert, err := tls.X509KeyPair(([]byte)(pair.Cert), ([]byte)(pair.Key)) + if err != nil { + log.TrErrorf("error.cert.requested.parse.failed", err) + return nil, err + } + certHost, _ := parseCertCommonName(cert.Certificate[0]) + log.TrInfof("info.cert.requested", certHost) + return &cert, nil +} + +func (r *Runner) PatchTLSWithClusterCertificates(tlsConfig *tls.Config) *tls.Config { + certs := make([]tls.Certificate, 0, len(r.certificates)) + for _, c := range r.certificates { + if c != nil { + certs = append(certs, *c) } - certs = append(certs, cert) - certHost, _ := parseCertCommonName(cert.Certificate[0]) - log.TrInfof("info.cert.requested", certHost) } if len(certs) == 0 { if tlsConfig == nil { @@ -649,7 +655,7 @@ func (r *Runner) PatchTLSWithClusterCert(ctx context.Context, tlsConfig *tls.Con } tlsConfig.Certificates = append(tlsConfig.Certificates, certs...) } - return tlsConfig, nil + return tlsConfig } // updateClustersWithGeneralConfig will re-enable all clusters with latest general config From 546576657cb31a0441dd86e2340ffcc295da6112 Mon Sep 17 00:00:00 2001 From: zyxkad Date: Tue, 13 Aug 2024 21:58:40 -0600 Subject: [PATCH 23/36] add report API --- cluster/cluster.go | 5 ++- cluster/http.go | 10 ++++- cluster/{config.go => requests.go} | 30 +++++++++++++ cluster/storage.go | 35 ++++++--------- utils/encoding.go | 26 +++++++++++ utils/encoding_test.go | 69 ++++++++++++++++++++++++++++++ utils/http.go | 31 ++++++++------ 7 files changed, 170 insertions(+), 36 deletions(-) rename cluster/{config.go => requests.go} (88%) create mode 100644 utils/encoding_test.go diff --git a/cluster/cluster.go b/cluster/cluster.go index 08a850e6..f20416bd 100644 --- a/cluster/cluster.go +++ b/cluster/cluster.go @@ -90,13 +90,14 @@ func NewCluster( } // ID returns the cluster id -// The ID may not be unique in the openbmclapi cluster runtime +// The ID may not be unique in the OpenBMCLAPI cluster runtime. +// To identify the cluster instance for analyzing, use Name instead. func (cr *Cluster) ID() string { return cr.opts.Id } // Name returns the cluster's alias name -// The name must be unique in the openbmclapi cluster runtime +// The name must be unique in the OpenBMCLAPI cluster runtime. func (cr *Cluster) Name() string { return cr.name } diff --git a/cluster/http.go b/cluster/http.go index 07480e33..8c747c3a 100644 --- a/cluster/http.go +++ b/cluster/http.go @@ -115,7 +115,15 @@ func (cr *Cluster) makeReqWithBody( } func (cr *Cluster) makeReqWithAuth(ctx context.Context, method string, relpath string, query url.Values) (req *http.Request, err error) { - req, err = cr.makeReq(ctx, method, relpath, query) + return cr.makeReqWithAuthBody(ctx, method, relpath, query, nil) +} + +func (cr *Cluster) makeReqWithAuthBody( + ctx context.Context, + method string, relpath string, + query url.Values, body io.Reader, +) (req *http.Request, err error) { + req, err = cr.makeReqWithBody(ctx, method, relpath, query, nil) if err != nil { return } diff --git a/cluster/config.go b/cluster/requests.go similarity index 88% rename from cluster/config.go rename to cluster/requests.go index 2eebc617..51209b0f 100644 --- a/cluster/config.go +++ b/cluster/requests.go @@ -259,3 +259,33 @@ func (cr *Cluster) RequestCert(ctx context.Context) (ckp *CertKeyPair, err error } return } + +func (cr *Cluster) ReportDownload(ctx context.Context, request *http.Request, err error) error { + type ReportPayload struct { + Urls []string `json:"urls"` + Error utils.EmbedJSON[struct{ Message string }] `json:"error"` + } + var payload ReportPayload + redirects := utils.GetRedirects(request) + payload.Urls = make([]string, len(redirects)) + for i, u := range redirects { + payload.Urls[i] = u.String() + } + payload.Error.V.Message = err.Error() + data, err := json.Marshal(payload) + if err != nil { + return err + } + req, err := cr.makeReqWithAuthBody(ctx, http.MethodPost, "/openbmclapi/report", nil, bytes.NewReader(data)) + if err != nil { + return err + } + resp, err := cr.client.Do(req) + if err != nil { + return err + } + if resp.StatusCode/100 != 2 { + return utils.NewHTTPStatusErrorFromResponse(resp) + } + return nil +} diff --git a/cluster/storage.go b/cluster/storage.go index 05d45101..35dda611 100644 --- a/cluster/storage.go +++ b/cluster/storage.go @@ -641,29 +641,22 @@ func (c *HTTPClient) fetchFileWithBuf( return } defer res.Body.Close() - if err = ctx.Err(); err != nil { - return - } if res.StatusCode != http.StatusOK { - err = utils.ErrorFromRedirect(utils.NewHTTPStatusErrorFromResponse(res), res) - return - } - switch ce := strings.ToLower(res.Header.Get("Content-Encoding")); ce { - case "": - r = res.Body - case "gzip": - if r, err = gzip.NewReader(res.Body); err != nil { - err = utils.ErrorFromRedirect(err, res) - return + err = utils.NewHTTPStatusErrorFromResponse(res) + }else { + switch ce := strings.ToLower(res.Header.Get("Content-Encoding")); ce { + case "": + r = res.Body + case "gzip": + r, err = gzip.NewReader(res.Body) + case "deflate": + r, err = zlib.NewReader(res.Body) + default: + err = fmt.Errorf("Unexpected Content-Encoding %q", ce) } - case "deflate": - if r, err = zlib.NewReader(res.Body); err != nil { - err = utils.ErrorFromRedirect(err, res) - return - } - default: - err = utils.ErrorFromRedirect(fmt.Errorf("Unexpected Content-Encoding %q", ce), res) - return + } + if err != nil { + return "", utils.ErrorFromRedirect(err, res) } if wrapper != nil { r = wrapper(r) diff --git a/utils/encoding.go b/utils/encoding.go index 2db3b4ca..620c2e50 100644 --- a/utils/encoding.go +++ b/utils/encoding.go @@ -20,6 +20,7 @@ package utils import ( + "encoding/json" "time" "gopkg.in/yaml.v3" @@ -107,3 +108,28 @@ func (d *YAMLDuration) UnmarshalYAML(n *yaml.Node) (err error) { *d = (YAMLDuration)(td) return nil } + +type EmbedJSON[T any] struct { + V T +} + +var ( + _ json.Marshaler = EmbedJSON[any]{} + _ json.Unmarshaler = (*EmbedJSON[any])(nil) +) + +func (e EmbedJSON[T]) MarshalJSON() ([]byte, error) { + data, err := json.Marshal(e.V) + if err != nil { + return nil, err + } + return json.Marshal((string)(data)) +} + +func (e *EmbedJSON[T]) UnmarshalJSON(data []byte) error { + var str string + if err := json.Unmarshal(data, &str); err != nil { + return err + } + return json.Unmarshal(([]byte)(str), &e.V) +} diff --git a/utils/encoding_test.go b/utils/encoding_test.go new file mode 100644 index 00000000..50b3b9e4 --- /dev/null +++ b/utils/encoding_test.go @@ -0,0 +1,69 @@ +/** + * OpenBmclAPI (Golang Edition) + * Copyright (C) 2024 Kevin Z + * All rights reserved + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published + * by the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package utils_test + +import ( + "testing" + + "encoding/json" + "reflect" + + "github.com/LiterMC/go-openbmclapi/utils" +) + +func TestEmbedJSON(t *testing.T) { + type testPayload struct { + A int + B utils.EmbedJSON[struct { + C string + D float64 + E utils.EmbedJSON[*string] + F *utils.EmbedJSON[string] + }] + G utils.EmbedJSON[*string] + H *utils.EmbedJSON[string] + I utils.EmbedJSON[*string] `json:",omitempty"` + J *utils.EmbedJSON[string] `json:",omitempty"` + } + var v testPayload + v.A = 1 + v.B.V.C = "2\"" + v.B.V.D = 3.4 + v.B.V.E.V = new(string) + *v.B.V.E.V = `{5"6"7}` + v.B.V.F = &utils.EmbedJSON[string]{ + V: "f", + } + data, err := json.Marshal(v) + if err != nil { + t.Fatalf("Marshal error: %v", err) + } + dataStr := (string)(data) + if want := `{"A":1,"B":"{\"C\":\"2\\\"\",\"D\":3.4,\"E\":\"\\\"{5\\\\\\\"6\\\\\\\"7}\\\"\",\"F\":\"\\\"f\\\"\"}","G":"null","H":null,"I":"null"}`; dataStr != want { + t.Fatalf("Marshal error, got %s, want %s", dataStr, want) + } + var w testPayload + if err := json.Unmarshal(data, &w); err != nil { + t.Fatalf("Unmarshal error: %v", err) + } + if !reflect.DeepEqual(w, v) { + t.Fatalf("Unmarshal error, got %#v, want %#v", w, v) + } +} diff --git a/utils/http.go b/utils/http.go index 4c6e5e66..25ddfe15 100644 --- a/utils/http.go +++ b/utils/http.go @@ -506,30 +506,37 @@ func (c *connHeadReader) Read(buf []byte) (n int, err error) { return c.Conn.Read(buf) } +func GetRedirects(req *http.Request) []*url.URL { + redirects := make([]*url.URL, 0, 5) + for req != nil { + redirects = append(redirects, req.URL) + resp := req.Response + if resp == nil { + break + } + req = resp.Request + } + if len(redirects) == 0 { + return nil + } + slices.Reverse(redirects) + return redirects +} + type RedirectError struct { Redirects []*url.URL Err error } func ErrorFromRedirect(err error, resp *http.Response) *RedirectError { - redirects := make([]*url.URL, 0, 4) - for resp != nil && resp.Request != nil { - redirects = append(redirects, resp.Request.URL) - resp = resp.Request.Response - } - if len(redirects) > 1 { - slices.Reverse(redirects) - } else { - redirects = nil - } return &RedirectError{ - Redirects: redirects, + Redirects: GetRedirects(resp.Request), Err: err, } } func (e *RedirectError) Error() string { - if len(e.Redirects) == 0 { + if len(e.Redirects) <= 1 { return e.Err.Error() } From 16e9db431a68f5c25361a80df9dc32eb234fa88b Mon Sep 17 00:00:00 2001 From: zyxkad Date: Sat, 17 Aug 2024 15:55:10 -0600 Subject: [PATCH 24/36] seperate runner --- main.go | 621 +----------------------------------------------- runner.go | 690 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 691 insertions(+), 620 deletions(-) create mode 100644 runner.go diff --git a/main.go b/main.go index 0dd90072..e8ac129c 100644 --- a/main.go +++ b/main.go @@ -20,43 +20,21 @@ package main import ( - "bufio" - "bytes" "context" "crypto/tls" - "errors" "fmt" - "io" "net" - "net/http" "os" - "os/exec" - "os/signal" - "path/filepath" "runtime" "strconv" "strings" - "sync" - "sync/atomic" - "syscall" "time" - "runtime/pprof" - - doh "github.com/libp2p/go-doh-resolver" - - "github.com/LiterMC/go-openbmclapi/api" - "github.com/LiterMC/go-openbmclapi/api/bmclapi" "github.com/LiterMC/go-openbmclapi/cluster" - "github.com/LiterMC/go-openbmclapi/config" - "github.com/LiterMC/go-openbmclapi/database" "github.com/LiterMC/go-openbmclapi/internal/build" "github.com/LiterMC/go-openbmclapi/lang" - "github.com/LiterMC/go-openbmclapi/limited" "github.com/LiterMC/go-openbmclapi/log" - "github.com/LiterMC/go-openbmclapi/storage" subcmds "github.com/LiterMC/go-openbmclapi/sub_commands" - "github.com/LiterMC/go-openbmclapi/utils" _ "github.com/LiterMC/go-openbmclapi/lang/en" _ "github.com/LiterMC/go-openbmclapi/lang/zh" @@ -71,6 +49,7 @@ var ( var startTime = time.Now() const baseDir = "." +const dataDir = "data" func parseArgs() { if len(os.Args) > 1 { @@ -207,601 +186,3 @@ func main() { } os.Exit(code) } - -type Runner struct { - Config *config.Config - - configHandler *ConfigHandler - client *cluster.HTTPClient - clusters map[string]*cluster.Cluster - userManager api.UserManager - tokenManager api.TokenManager - subManager api.SubscriptionManager - storageManager *storage.Manager - statManager *cluster.StatManager - hijacker *bmclapi.HjProxy - database database.DB - - server *http.Server - apiRateLimiter *limited.APIRateMiddleWare - handler http.Handler - handlerAPIv0 http.Handler - hijackHandler http.Handler - - tlsConfig *tls.Config - certificates map[string]*tls.Certificate - publicHost string - publicPort uint16 - listener *utils.HTTPTLSListener - - reloading atomic.Bool - updating atomic.Bool - tunnelCancel context.CancelFunc -} - -func NewRunner() *Runner { - r := new(Runner) - r.configHandler = &ConfigHandler{r: r} - r.apiRateLimiter = limited.NewAPIRateMiddleWare(api.RealAddrCtxKey, "go-openbmclapi.cluster.logged.user" /* api/v0.loggedUserKey */) - return r -} - -func (r *Runner) getPublicPort() uint16 { - if r.publicPort > 0 { - return r.publicPort - } - return r.Config.Port -} - -func (r *Runner) getCertCount() int { - if r.tlsConfig == nil { - return 0 - } - return len(r.tlsConfig.Certificates) -} - -func (r *Runner) InitServer() { - r.server = &http.Server{ - ReadTimeout: 10 * time.Second, - IdleTimeout: 5 * time.Second, - Handler: (http.HandlerFunc)(func(rw http.ResponseWriter, req *http.Request) { - r.handler.ServeHTTP(rw, req) - }), - ErrorLog: log.ProxiedStdLog, - } - r.updateRateLimit(context.TODO()) - r.handler = r.GetHandler() -} - -// StartServer will start the HTTP server -// If a server is already running on an old listener, the listener will be closed. -func (r *Runner) StartServer(ctx context.Context) error { - htListener, err := r.CreateHTTPListener(ctx) - if err != nil { - return err - } - if r.listener != nil { - r.listener.Close() - } - r.listener = htListener - go func() { - defer htListener.Close() - if err := r.server.Serve(htListener); !errors.Is(err, http.ErrServerClosed) && !errors.Is(err, net.ErrClosed) { - log.Error("Error when serving:", err) - os.Exit(1) - } - }() - return nil -} - -func (r *Runner) GetClusterGeneralConfig() config.ClusterGeneralConfig { - return config.ClusterGeneralConfig{ - PublicHost: r.publicHost, - PublicPort: r.getPublicPort(), - NoFastEnable: r.Config.Advanced.NoFastEnable, - MaxReconnectCount: r.Config.MaxReconnectCount, - } -} - -func (r *Runner) InitClusters(ctx context.Context) { - // var ( - // dialer *net.Dialer - // cache = r.Config.Cache.NewCache() - // ) - - _ = doh.NewResolver // TODO: use doh resolver - - r.clusters = make(map[string]*cluster.Cluster) - gcfg := r.GetClusterGeneralConfig() - for name, opts := range r.Config.Clusters { - cr := cluster.NewCluster(name, opts, gcfg, r.storageManager, r.statManager) - if err := cr.Init(ctx); err != nil { - log.TrErrorf("error.init.failed", err) - } else { - r.clusters[name] = cr - } - } - - // r.cluster = NewCluster(ctx, - // ClusterServerURL, - // baseDir, - // config.PublicHost, r.getPublicPort(), - // config.ClusterId, config.ClusterSecret, - // config.Byoc, dialer, - // config.Storages, - // cache, - // ) - // if err := r.cluster.Init(ctx); err != nil { - // log.TrErrorf("error.init.failed"), err) - // os.Exit(1) - // } -} - -func (r *Runner) ListenSignals(ctx context.Context, cancel context.CancelFunc) int { - signalCh := make(chan os.Signal, 1) - log.Debugf("Receiving signals") - signal.Notify(signalCh, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) - defer signal.Stop(signalCh) - - var ( - forceStop context.CancelFunc - exited = make(chan struct{}, 0) - ) - - for { - select { - case <-exited: - return 0 - case s := <-signalCh: - switch s { - case syscall.SIGQUIT: - // avaliable commands see - dumpCommand := "heap" - dumpFileName := filepath.Join(os.TempDir(), fmt.Sprintf("go-openbmclapi-dump-command.%d.in", os.Getpid())) - log.Infof("Reading dump command file at %q", dumpFileName) - var buf [128]byte - if dumpFile, err := os.Open(dumpFileName); err != nil { - log.Errorf("Cannot open dump command file: %v", err) - } else if n, err := dumpFile.Read(buf[:]); err != nil { - dumpFile.Close() - log.Errorf("Cannot read dump command file: %v", err) - } else { - dumpFile.Truncate(0) - dumpFile.Close() - dumpCommand = (string)(bytes.TrimSpace(buf[:n])) - } - pcmd := pprof.Lookup(dumpCommand) - if pcmd == nil { - log.Errorf("No pprof command is named %q", dumpCommand) - continue - } - name := fmt.Sprintf(time.Now().Format("dump-%s-20060102-150405.txt"), dumpCommand) - log.Infof("Creating goroutine dump file at %s", name) - if fd, err := os.Create(name); err != nil { - log.Infof("Cannot create dump file: %v", err) - } else { - err := pcmd.WriteTo(fd, 1) - fd.Close() - if err != nil { - log.Infof("Cannot write dump file: %v", err) - } else { - log.Info("Dump file created") - } - } - case syscall.SIGHUP: - go r.ReloadConfig(ctx) - default: - cancel() - if forceStop == nil { - ctx, cancel := context.WithCancel(context.Background()) - forceStop = cancel - go func() { - defer close(exited) - r.StopServer(ctx) - }() - } else { - log.Warn("signal:", s) - log.Error("Second close signal received, forcely shutting down") - forceStop() - } - } - - } - } -} - -func (r *Runner) ReloadConfig(ctx context.Context) { - if r.reloading.CompareAndSwap(false, true) { - log.Error("Config is already reloading!") - return - } - defer r.reloading.Store(false) - - config, err := readAndRewriteConfig() - if err != nil { - log.Errorf("Config error: %v", err) - } else { - if err := r.updateConfig(ctx, config); err != nil { - log.Errorf("Error when reloading config: %v", err) - } - } -} - -func (r *Runner) updateConfig(ctx context.Context, newConfig *config.Config) error { - if err := r.configHandler.update(newConfig); err != nil { - return err - } - return r.configHandler.doUpdateProcesses(ctx) -} - -func (r *Runner) SetupLogger(ctx context.Context) error { - if r.Config.Advanced.DebugLog { - log.SetLevel(log.LevelDebug) - } else { - log.SetLevel(log.LevelInfo) - } - log.SetLogSlots(r.Config.LogSlots) - if r.Config.NoAccessLog { - log.SetAccessLogSlots(-1) - } else { - log.SetAccessLogSlots(r.Config.AccessLogSlots) - } - - r.Config.ApplyWebManifest(dsbManifest) - return nil -} - -func (r *Runner) StopServer(ctx context.Context) { - r.tunnelCancel() - shutCtx, cancelShut := context.WithTimeout(context.Background(), time.Second*15) - defer cancelShut() - log.TrWarnf("warn.server.closing") - shutDone := make(chan struct{}, 0) - go func() { - defer close(shutDone) - defer cancelShut() - var wg sync.WaitGroup - for _, cr := range r.clusters { - go func() { - defer wg.Done() - cr.Disable(shutCtx) - }() - } - wg.Wait() - log.TrWarnf("warn.httpserver.closing") - r.server.Shutdown(shutCtx) - r.listener.Close() - r.listener = nil - }() - select { - case <-shutDone: - case <-ctx.Done(): - return - } - log.TrWarnf("warn.server.closed") -} - -func (r *Runner) UpdateFileRecords(files map[string]*cluster.StorageFileInfo, oldfileset map[string]int64) { - if !r.Config.Hijack.Enable { - return - } - if !r.updating.CompareAndSwap(false, true) { - return - } - defer r.updating.Store(false) - - sem := limited.NewSemaphore(12) - log.Info("Begin to update file records") - for _, f := range files { - for _, u := range f.URLs { - if strings.HasPrefix(u.Path, "/openbmclapi/download/") { - continue - } - if oldfileset[f.Hash] > 0 { - continue - } - sem.Acquire() - go func(rec database.FileRecord) { - defer sem.Release() - r.database.SetFileRecord(rec) - }(database.FileRecord{ - Path: u.Path, - Hash: f.Hash, - Size: f.Size, - }) - } - } - sem.Wait() - log.Info("All file records are updated") -} - -func (r *Runner) InitSynchronizer(ctx context.Context) { - fileMap := make(map[string]*cluster.StorageFileInfo) - for _, cr := range r.clusters { - log.TrInfof("info.filelist.fetching", cr.ID()) - if err := cr.GetFileList(ctx, fileMap, true); err != nil { - log.TrErrorf("error.filelist.fetch.failed", cr.ID(), err) - if errors.Is(err, context.Canceled) { - return - } - } - } - - checkCount := -1 - heavyCheck := !r.Config.Advanced.NoHeavyCheck - heavyCheckInterval := r.Config.Advanced.HeavyCheckInterval - if heavyCheckInterval <= 0 { - heavyCheck = false - } - - // if !r.Config.Advanced.SkipFirstSync - { - slots := 10 - if err := r.client.SyncFiles(ctx, r.storageManager, fileMap, false, slots); err != nil { - log.Errorf("Sync failed: %v", err) - return - } - go r.UpdateFileRecords(fileMap, nil) - - if !r.Config.Advanced.NoGC { - go r.client.Gc(context.TODO(), r.storageManager, fileMap) - } - } - // else if fl != nil { - // if err := r.cluster.SetFilesetByExists(ctx, fl); err != nil { - // return - // } - // } - - createInterval(ctx, func() { - fileMap := make(map[string]*cluster.StorageFileInfo) - for _, cr := range r.clusters { - log.TrInfof("info.filelist.fetching", cr.ID()) - if err := cr.GetFileList(ctx, fileMap, false); err != nil { - log.TrErrorf("error.filelist.fetch.failed", cr.ID(), err) - return - } - } - if len(fileMap) == 0 { - log.Infof("No file was updated since last check") - return - } - - checkCount = (checkCount + 1) % heavyCheckInterval - slots := 10 - if err := r.client.SyncFiles(ctx, r.storageManager, fileMap, heavyCheck && (checkCount == 0), slots); err != nil { - log.Errorf("Sync failed: %v", err) - return - } - }, (time.Duration)(r.Config.SyncInterval)*time.Minute) -} - -func (r *Runner) CreateHTTPListener(ctx context.Context) (*utils.HTTPTLSListener, error) { - addr := net.JoinHostPort(r.Config.Host, strconv.Itoa((int)(r.Config.Port))) - listener, err := net.Listen("tcp", addr) - if err != nil { - log.TrErrorf("error.address.listen.failed", addr, err) - return nil, err - } - if r.Config.ServeLimit.Enable { - limted := limited.NewLimitedListener(listener, r.Config.ServeLimit.MaxConn, 0, r.Config.ServeLimit.UploadRate*1024) - limted.SetMinWriteRate(1024) - listener = limted - } - - if r.Config.UseCert { - var err error - r.tlsConfig, err = r.GenerateTLSConfig() - if err != nil { - log.Errorf("Failed to generate TLS config: %v", err) - return nil, err - } - } - return utils.NewHttpTLSListener(listener, r.tlsConfig), nil -} - -func (r *Runner) GenerateTLSConfig() (*tls.Config, error) { - if len(r.Config.Certificates) == 0 { - log.TrErrorf("error.cert.not.set") - return nil, errors.New("No certificate is defined") - } - tlsConfig := new(tls.Config) - tlsConfig.Certificates = make([]tls.Certificate, len(r.Config.Certificates)) - for i, c := range r.Config.Certificates { - var err error - tlsConfig.Certificates[i], err = tls.LoadX509KeyPair(c.Cert, c.Key) - if err != nil { - log.TrErrorf("error.cert.parse.failed", i, err) - return nil, err - } - } - return tlsConfig, nil -} - -func (r *Runner) RequestClusterCert(ctx context.Context, cr *cluster.Cluster) (*tls.Certificate, error) { - if cr.Options().Byoc { - return nil, nil - } - log.TrInfof("info.cert.requesting", cr.ID()) - tctx, cancel := context.WithTimeout(ctx, time.Minute*10) - pair, err := cr.RequestCert(tctx) - cancel() - if err != nil { - log.TrErrorf("error.cert.request.failed", err) - return nil, err - } - cert, err := tls.X509KeyPair(([]byte)(pair.Cert), ([]byte)(pair.Key)) - if err != nil { - log.TrErrorf("error.cert.requested.parse.failed", err) - return nil, err - } - certHost, _ := parseCertCommonName(cert.Certificate[0]) - log.TrInfof("info.cert.requested", certHost) - return &cert, nil -} - -func (r *Runner) PatchTLSWithClusterCertificates(tlsConfig *tls.Config) *tls.Config { - certs := make([]tls.Certificate, 0, len(r.certificates)) - for _, c := range r.certificates { - if c != nil { - certs = append(certs, *c) - } - } - if len(certs) == 0 { - if tlsConfig == nil { - tlsConfig = new(tls.Config) - } else { - tlsConfig = tlsConfig.Clone() - } - tlsConfig.Certificates = append(tlsConfig.Certificates, certs...) - } - return tlsConfig -} - -// updateClustersWithGeneralConfig will re-enable all clusters with latest general config -func (r *Runner) updateClustersWithGeneralConfig(ctx context.Context) error { - gcfg := r.GetClusterGeneralConfig() - var wg sync.WaitGroup - for _, cr := range r.clusters { - wg.Add(1) - go func(cr *cluster.Cluster) { - defer wg.Done() - cr.Disable(ctx) - *cr.GeneralConfig() = gcfg - if err := cr.Enable(ctx); err != nil { - log.TrErrorf("error.cluster.enable.failed", cr.ID(), err) - return - } - }(cr) - } - wg.Wait() - return nil -} - -func (r *Runner) EnableClusterAll(ctx context.Context) { - var wg sync.WaitGroup - for _, cr := range r.clusters { - wg.Add(1) - go func(cr *cluster.Cluster) { - defer wg.Done() - if err := cr.Enable(ctx); err != nil { - log.TrErrorf("error.cluster.enable.failed", cr.ID(), err) - return - } - }(cr) - } - wg.Wait() -} - -func (r *Runner) StartTunneler() { - ctx, cancel := context.WithCancel(context.Background()) - r.tunnelCancel = cancel - go func() { - dur := time.Second - for { - start := time.Now() - r.RunTunneler(ctx) - used := time.Since(start) - // If the program runs no longer than 30s, then it fails too fast. - if used < time.Second*30 { - dur = min(dur*2, time.Minute*10) - } else { - dur = time.Second - } - select { - case <-time.After(dur): - case <-ctx.Done(): - return - } - } - }() -} - -func (r *Runner) RunTunneler(ctx context.Context) { - cmd := exec.CommandContext(ctx, r.Config.Tunneler.TunnelProg) - log.TrInfof("info.tunnel.running", cmd.String()) - var ( - cmdOut, cmdErr io.ReadCloser - err error - ) - cmd.Env = append(os.Environ(), - "CLUSTER_PORT="+strconv.Itoa((int)(r.Config.Port))) - if cmdOut, err = cmd.StdoutPipe(); err != nil { - log.TrErrorf("error.tunnel.command.prepare.failed", err) - os.Exit(1) - } - if cmdErr, err = cmd.StderrPipe(); err != nil { - log.TrErrorf("error.tunnel.command.prepare.failed", err) - os.Exit(1) - } - if err = cmd.Start(); err != nil { - log.TrErrorf("error.tunnel.command.prepare.failed", err) - os.Exit(1) - } - type addrOut struct { - host string - port uint16 - } - detectedCh := make(chan addrOut, 1) - onLog := func(line []byte) { - tunnelHost, tunnelPort, ok := r.Config.Tunneler.MatchTunnelOutput(line) - if !ok { - return - } - if len(tunnelHost) > 0 && tunnelHost[0] == '[' && tunnelHost[len(tunnelHost)-1] == ']' { // a IPv6 with port []: - tunnelHost = tunnelHost[1 : len(tunnelHost)-1] - } - port, err := strconv.Atoi((string)(tunnelPort)) - if err != nil { - log.Panic(err) - } - select { - case detectedCh <- addrOut{ - host: (string)(tunnelHost), - port: (uint16)(port), - }: - default: - } - } - go func() { - defer cmdOut.Close() - defer cmd.Process.Kill() - sc := bufio.NewScanner(cmdOut) - for sc.Scan() { - log.Info("[tunneler/stdout]:", sc.Text()) - onLog(sc.Bytes()) - } - }() - go func() { - defer cmdErr.Close() - defer cmd.Process.Kill() - sc := bufio.NewScanner(cmdErr) - for sc.Scan() { - log.Info("[tunneler/stderr]:", sc.Text()) - onLog(sc.Bytes()) - } - }() - go func() { - defer log.RecordPanic() - defer cmd.Process.Kill() - for { - select { - case addr := <-detectedCh: - log.TrInfof("info.tunnel.detected", addr.host, addr.port) - r.publicHost, r.publicPort = addr.host, addr.port - r.updateClustersWithGeneralConfig(ctx) - if ctx.Err() != nil { - return - } - case <-ctx.Done(): - return - } - } - }() - if _, err := cmd.Process.Wait(); err != nil { - if ctx.Err() != nil { - return - } - log.Errorf("Tunnel program exited: %v", err) - } -} diff --git a/runner.go b/runner.go new file mode 100644 index 00000000..7c92a5fe --- /dev/null +++ b/runner.go @@ -0,0 +1,690 @@ +/** + * OpenBmclAPI (Golang Edition) + * Copyright (C) 2024 Kevin Z + * All rights reserved + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published + * by the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package main + +import ( + "bufio" + "bytes" + "context" + "crypto/tls" + "encoding/base64" + "errors" + "fmt" + "io" + "net" + "net/http" + "os" + "os/exec" + "os/signal" + "path/filepath" + "strconv" + "strings" + "sync" + "sync/atomic" + "syscall" + "time" + + "runtime/pprof" + + doh "github.com/libp2p/go-doh-resolver" + + "github.com/LiterMC/go-openbmclapi/api" + "github.com/LiterMC/go-openbmclapi/api/bmclapi" + "github.com/LiterMC/go-openbmclapi/cluster" + "github.com/LiterMC/go-openbmclapi/config" + "github.com/LiterMC/go-openbmclapi/database" + "github.com/LiterMC/go-openbmclapi/limited" + "github.com/LiterMC/go-openbmclapi/log" + "github.com/LiterMC/go-openbmclapi/notify" + "github.com/LiterMC/go-openbmclapi/notify/webpush" + "github.com/LiterMC/go-openbmclapi/storage" + "github.com/LiterMC/go-openbmclapi/utils" +) + +type Runner struct { + Config *config.Config + + configHandler *ConfigHandler + client *cluster.HTTPClient + database database.DB + clusters map[string]*cluster.Cluster + userManager api.UserManager + tokenManager api.TokenManager + subManager api.SubscriptionManager + notifyManager *notify.Manager + storageManager *storage.Manager + statManager *cluster.StatManager + hijacker *bmclapi.HjProxy + + server *http.Server + apiRateLimiter *limited.APIRateMiddleWare + handler http.Handler + handlerAPIv0 http.Handler + hijackHandler http.Handler + + tlsConfig *tls.Config + certificates map[string]*tls.Certificate + publicHost string + publicPort uint16 + listener *utils.HTTPTLSListener + + reloading atomic.Bool + updating atomic.Bool + tunnelCancel context.CancelFunc +} + +func NewRunner() *Runner { + r := new(Runner) + + r.configHandler = &ConfigHandler{r: r} + + var dialer *net.Dialer + r.client = cluster.NewHTTPClient(dialer, r.Config.Cache.NewCache()) + + { + var err error + if r.Config.Database.Driver == "memory" { + r.database = database.NewMemoryDB() + } else if r.database, err = database.NewSqlDB(r.Config.Database.Driver, r.Config.Database.DSN); err != nil { + log.Errorf("Cannot connect to database: %v", err) + } + } + + // r.userManager = + // r.tokenManager = + webpushPlg := new(webpush.Plugin) + r.subManager = &subscriptionManager{ + webpushPlg: webpushPlg, + DB: r.database, + } + r.notifyManager = notify.NewManager(dataDir, r.database, r.client.CachedClient(), "go-openbmclapi") + r.storageManager = storage.NewManager(storages) + r.statManager = cluster.NewStatManager() + if err := r.statManager.Load(dataDir); err != nil { + log.Errorf("Stat load failed:", err) + } + r.apiRateLimiter = limited.NewAPIRateMiddleWare(api.RealAddrCtxKey, "go-openbmclapi.cluster.logged.user" /* api/v0.loggedUserKey */) + return r +} + +func (r *Runner) getPublicPort() uint16 { + if r.publicPort > 0 { + return r.publicPort + } + return r.Config.Port +} + +func (r *Runner) getCertCount() int { + if r.tlsConfig == nil { + return 0 + } + return len(r.tlsConfig.Certificates) +} + +func (r *Runner) InitServer() { + r.server = &http.Server{ + ReadTimeout: 10 * time.Second, + IdleTimeout: 5 * time.Second, + Handler: (http.HandlerFunc)(func(rw http.ResponseWriter, req *http.Request) { + r.handler.ServeHTTP(rw, req) + }), + ErrorLog: log.ProxiedStdLog, + } + r.updateRateLimit(context.TODO()) + r.handler = r.GetHandler() +} + +// StartServer will start the HTTP server +// If a server is already running on an old listener, the listener will be closed. +func (r *Runner) StartServer(ctx context.Context) error { + htListener, err := r.CreateHTTPListener(ctx) + if err != nil { + return err + } + if r.listener != nil { + r.listener.Close() + } + r.listener = htListener + go func() { + defer htListener.Close() + if err := r.server.Serve(htListener); !errors.Is(err, http.ErrServerClosed) && !errors.Is(err, net.ErrClosed) { + log.Error("Error when serving:", err) + os.Exit(1) + } + }() + return nil +} + +func (r *Runner) GetClusterGeneralConfig() config.ClusterGeneralConfig { + return config.ClusterGeneralConfig{ + PublicHost: r.publicHost, + PublicPort: r.getPublicPort(), + NoFastEnable: r.Config.Advanced.NoFastEnable, + MaxReconnectCount: r.Config.MaxReconnectCount, + } +} + +func (r *Runner) InitClusters(ctx context.Context) { + + _ = doh.NewResolver // TODO: use doh resolver + + r.clusters = make(map[string]*cluster.Cluster) + gcfg := r.GetClusterGeneralConfig() + for name, opts := range r.Config.Clusters { + cr := cluster.NewCluster(name, opts, gcfg, r.storageManager, r.statManager) + if err := cr.Init(ctx); err != nil { + log.TrErrorf("error.init.failed", err) + } else { + r.clusters[name] = cr + } + } + + // r.cluster = NewCluster(ctx, + // ClusterServerURL, + // baseDir, + // config.PublicHost, r.getPublicPort(), + // config.ClusterId, config.ClusterSecret, + // config.Byoc, dialer, + // config.Storages, + // cache, + // ) + // if err := r.cluster.Init(ctx); err != nil { + // log.TrErrorf("error.init.failed"), err) + // os.Exit(1) + // } +} + +func (r *Runner) ListenSignals(ctx context.Context, cancel context.CancelFunc) int { + signalCh := make(chan os.Signal, 1) + log.Debugf("Receiving signals") + signal.Notify(signalCh, syscall.SIGHUP, syscall.SIGINT, syscall.SIGTERM, syscall.SIGQUIT) + defer signal.Stop(signalCh) + + var ( + forceStop context.CancelFunc + exited = make(chan struct{}, 0) + ) + + for { + select { + case <-exited: + return 0 + case s := <-signalCh: + switch s { + case syscall.SIGQUIT: + // avaliable commands see + dumpCommand := "heap" + dumpFileName := filepath.Join(os.TempDir(), fmt.Sprintf("go-openbmclapi-dump-command.%d.in", os.Getpid())) + log.Infof("Reading dump command file at %q", dumpFileName) + var buf [128]byte + if dumpFile, err := os.Open(dumpFileName); err != nil { + log.Errorf("Cannot open dump command file: %v", err) + } else if n, err := dumpFile.Read(buf[:]); err != nil { + dumpFile.Close() + log.Errorf("Cannot read dump command file: %v", err) + } else { + dumpFile.Truncate(0) + dumpFile.Close() + dumpCommand = (string)(bytes.TrimSpace(buf[:n])) + } + pcmd := pprof.Lookup(dumpCommand) + if pcmd == nil { + log.Errorf("No pprof command is named %q", dumpCommand) + continue + } + name := fmt.Sprintf(time.Now().Format("dump-%s-20060102-150405.txt"), dumpCommand) + log.Infof("Creating goroutine dump file at %s", name) + if fd, err := os.Create(name); err != nil { + log.Infof("Cannot create dump file: %v", err) + } else { + err := pcmd.WriteTo(fd, 1) + fd.Close() + if err != nil { + log.Infof("Cannot write dump file: %v", err) + } else { + log.Info("Dump file created") + } + } + case syscall.SIGHUP: + go r.ReloadConfig(ctx) + default: + cancel() + if forceStop == nil { + ctx, cancel := context.WithCancel(context.Background()) + forceStop = cancel + go func() { + defer close(exited) + r.StopServer(ctx) + }() + } else { + log.Warn("signal:", s) + log.Error("Second close signal received, forcely shutting down") + forceStop() + } + } + + } + } +} + +func (r *Runner) ReloadConfig(ctx context.Context) { + if r.reloading.CompareAndSwap(false, true) { + log.Error("Config is already reloading!") + return + } + defer r.reloading.Store(false) + + config, err := readAndRewriteConfig() + if err != nil { + log.Errorf("Config error: %v", err) + } else { + if err := r.updateConfig(ctx, config); err != nil { + log.Errorf("Error when reloading config: %v", err) + } + } +} + +func (r *Runner) updateConfig(ctx context.Context, newConfig *config.Config) error { + if err := r.configHandler.update(newConfig); err != nil { + return err + } + return r.configHandler.doUpdateProcesses(ctx) +} + +func (r *Runner) SetupLogger(ctx context.Context) error { + if r.Config.Advanced.DebugLog { + log.SetLevel(log.LevelDebug) + } else { + log.SetLevel(log.LevelInfo) + } + log.SetLogSlots(r.Config.LogSlots) + if r.Config.NoAccessLog { + log.SetAccessLogSlots(-1) + } else { + log.SetAccessLogSlots(r.Config.AccessLogSlots) + } + + r.Config.ApplyWebManifest(dsbManifest) + return nil +} + +func (r *Runner) StopServer(ctx context.Context) { + r.tunnelCancel() + shutCtx, cancelShut := context.WithTimeout(context.Background(), time.Second*15) + defer cancelShut() + log.TrWarnf("warn.server.closing") + shutDone := make(chan struct{}, 0) + go func() { + defer close(shutDone) + defer cancelShut() + var wg sync.WaitGroup + for _, cr := range r.clusters { + go func() { + defer wg.Done() + cr.Disable(shutCtx) + }() + } + wg.Wait() + log.TrWarnf("warn.httpserver.closing") + r.server.Shutdown(shutCtx) + r.listener.Close() + r.listener = nil + }() + select { + case <-shutDone: + case <-ctx.Done(): + return + } + log.TrWarnf("warn.server.closed") +} + +func (r *Runner) UpdateFileRecords(files map[string]*cluster.StorageFileInfo, oldfileset map[string]int64) { + if !r.Config.Hijack.Enable { + return + } + if !r.updating.CompareAndSwap(false, true) { + return + } + defer r.updating.Store(false) + + sem := limited.NewSemaphore(12) + log.Info("Begin to update file records") + for _, f := range files { + for _, u := range f.URLs { + if strings.HasPrefix(u.Path, "/openbmclapi/download/") { + continue + } + if oldfileset[f.Hash] > 0 { + continue + } + sem.Acquire() + go func(rec database.FileRecord) { + defer sem.Release() + r.database.SetFileRecord(rec) + }(database.FileRecord{ + Path: u.Path, + Hash: f.Hash, + Size: f.Size, + }) + } + } + sem.Wait() + log.Info("All file records are updated") +} + +func (r *Runner) InitSynchronizer(ctx context.Context) { + fileMap := make(map[string]*cluster.StorageFileInfo) + for _, cr := range r.clusters { + log.TrInfof("info.filelist.fetching", cr.ID()) + if err := cr.GetFileList(ctx, fileMap, true); err != nil { + log.TrErrorf("error.filelist.fetch.failed", cr.ID(), err) + if errors.Is(err, context.Canceled) { + return + } + } + } + + checkCount := -1 + heavyCheck := !r.Config.Advanced.NoHeavyCheck + heavyCheckInterval := r.Config.Advanced.HeavyCheckInterval + if heavyCheckInterval <= 0 { + heavyCheck = false + } + + // if !r.Config.Advanced.SkipFirstSync + { + slots := 10 + if err := r.client.SyncFiles(ctx, r.storageManager, fileMap, false, slots); err != nil { + log.Errorf("Sync failed: %v", err) + return + } + go r.UpdateFileRecords(fileMap, nil) + + if !r.Config.Advanced.NoGC { + go r.client.Gc(context.TODO(), r.storageManager, fileMap) + } + } + // else if fl != nil { + // if err := r.cluster.SetFilesetByExists(ctx, fl); err != nil { + // return + // } + // } + + createInterval(ctx, func() { + fileMap := make(map[string]*cluster.StorageFileInfo) + for _, cr := range r.clusters { + log.TrInfof("info.filelist.fetching", cr.ID()) + if err := cr.GetFileList(ctx, fileMap, false); err != nil { + log.TrErrorf("error.filelist.fetch.failed", cr.ID(), err) + return + } + } + if len(fileMap) == 0 { + log.Infof("No file was updated since last check") + return + } + + checkCount = (checkCount + 1) % heavyCheckInterval + slots := 10 + if err := r.client.SyncFiles(ctx, r.storageManager, fileMap, heavyCheck && (checkCount == 0), slots); err != nil { + log.Errorf("Sync failed: %v", err) + return + } + }, (time.Duration)(r.Config.SyncInterval)*time.Minute) +} + +func (r *Runner) CreateHTTPListener(ctx context.Context) (*utils.HTTPTLSListener, error) { + addr := net.JoinHostPort(r.Config.Host, strconv.Itoa((int)(r.Config.Port))) + listener, err := net.Listen("tcp", addr) + if err != nil { + log.TrErrorf("error.address.listen.failed", addr, err) + return nil, err + } + if r.Config.ServeLimit.Enable { + limted := limited.NewLimitedListener(listener, r.Config.ServeLimit.MaxConn, 0, r.Config.ServeLimit.UploadRate*1024) + limted.SetMinWriteRate(1024) + listener = limted + } + + if r.Config.UseCert { + var err error + r.tlsConfig, err = r.GenerateTLSConfig() + if err != nil { + log.Errorf("Failed to generate TLS config: %v", err) + return nil, err + } + } + return utils.NewHttpTLSListener(listener, r.tlsConfig), nil +} + +func (r *Runner) GenerateTLSConfig() (*tls.Config, error) { + if len(r.Config.Certificates) == 0 { + log.TrErrorf("error.cert.not.set") + return nil, errors.New("No certificate is defined") + } + tlsConfig := new(tls.Config) + tlsConfig.Certificates = make([]tls.Certificate, len(r.Config.Certificates)) + for i, c := range r.Config.Certificates { + var err error + tlsConfig.Certificates[i], err = tls.LoadX509KeyPair(c.Cert, c.Key) + if err != nil { + log.TrErrorf("error.cert.parse.failed", i, err) + return nil, err + } + } + return tlsConfig, nil +} + +func (r *Runner) RequestClusterCert(ctx context.Context, cr *cluster.Cluster) (*tls.Certificate, error) { + if cr.Options().Byoc { + return nil, nil + } + log.TrInfof("info.cert.requesting", cr.ID()) + tctx, cancel := context.WithTimeout(ctx, time.Minute*10) + pair, err := cr.RequestCert(tctx) + cancel() + if err != nil { + log.TrErrorf("error.cert.request.failed", err) + return nil, err + } + cert, err := tls.X509KeyPair(([]byte)(pair.Cert), ([]byte)(pair.Key)) + if err != nil { + log.TrErrorf("error.cert.requested.parse.failed", err) + return nil, err + } + certHost, _ := parseCertCommonName(cert.Certificate[0]) + log.TrInfof("info.cert.requested", certHost) + return &cert, nil +} + +func (r *Runner) PatchTLSWithClusterCertificates(tlsConfig *tls.Config) *tls.Config { + certs := make([]tls.Certificate, 0, len(r.certificates)) + for _, c := range r.certificates { + if c != nil { + certs = append(certs, *c) + } + } + if len(certs) == 0 { + if tlsConfig == nil { + tlsConfig = new(tls.Config) + } else { + tlsConfig = tlsConfig.Clone() + } + tlsConfig.Certificates = append(tlsConfig.Certificates, certs...) + } + return tlsConfig +} + +// updateClustersWithGeneralConfig will re-enable all clusters with latest general config +func (r *Runner) updateClustersWithGeneralConfig(ctx context.Context) error { + gcfg := r.GetClusterGeneralConfig() + var wg sync.WaitGroup + for _, cr := range r.clusters { + wg.Add(1) + go func(cr *cluster.Cluster) { + defer wg.Done() + cr.Disable(ctx) + *cr.GeneralConfig() = gcfg + if err := cr.Enable(ctx); err != nil { + log.TrErrorf("error.cluster.enable.failed", cr.ID(), err) + return + } + }(cr) + } + wg.Wait() + return nil +} + +func (r *Runner) EnableClusterAll(ctx context.Context) { + var wg sync.WaitGroup + for _, cr := range r.clusters { + wg.Add(1) + go func(cr *cluster.Cluster) { + defer wg.Done() + if err := cr.Enable(ctx); err != nil { + log.TrErrorf("error.cluster.enable.failed", cr.ID(), err) + return + } + }(cr) + } + wg.Wait() +} + +func (r *Runner) StartTunneler() { + ctx, cancel := context.WithCancel(context.Background()) + r.tunnelCancel = cancel + go func() { + dur := time.Second + for { + start := time.Now() + r.RunTunneler(ctx) + used := time.Since(start) + // If the program runs no longer than 30s, then it fails too fast. + if used < time.Second*30 { + dur = min(dur*2, time.Minute*10) + } else { + dur = time.Second + } + select { + case <-time.After(dur): + case <-ctx.Done(): + return + } + } + }() +} + +func (r *Runner) RunTunneler(ctx context.Context) { + cmd := exec.CommandContext(ctx, r.Config.Tunneler.TunnelProg) + log.TrInfof("info.tunnel.running", cmd.String()) + var ( + cmdOut, cmdErr io.ReadCloser + err error + ) + cmd.Env = append(os.Environ(), + "CLUSTER_PORT="+strconv.Itoa((int)(r.Config.Port))) + if cmdOut, err = cmd.StdoutPipe(); err != nil { + log.TrErrorf("error.tunnel.command.prepare.failed", err) + os.Exit(1) + } + if cmdErr, err = cmd.StderrPipe(); err != nil { + log.TrErrorf("error.tunnel.command.prepare.failed", err) + os.Exit(1) + } + if err = cmd.Start(); err != nil { + log.TrErrorf("error.tunnel.command.prepare.failed", err) + os.Exit(1) + } + type addrOut struct { + host string + port uint16 + } + detectedCh := make(chan addrOut, 1) + onLog := func(line []byte) { + tunnelHost, tunnelPort, ok := r.Config.Tunneler.MatchTunnelOutput(line) + if !ok { + return + } + if len(tunnelHost) > 0 && tunnelHost[0] == '[' && tunnelHost[len(tunnelHost)-1] == ']' { // a IPv6 with port []: + tunnelHost = tunnelHost[1 : len(tunnelHost)-1] + } + port, err := strconv.Atoi((string)(tunnelPort)) + if err != nil { + log.Panic(err) + } + select { + case detectedCh <- addrOut{ + host: (string)(tunnelHost), + port: (uint16)(port), + }: + default: + } + } + go func() { + defer cmdOut.Close() + defer cmd.Process.Kill() + sc := bufio.NewScanner(cmdOut) + for sc.Scan() { + log.Info("[tunneler/stdout]:", sc.Text()) + onLog(sc.Bytes()) + } + }() + go func() { + defer cmdErr.Close() + defer cmd.Process.Kill() + sc := bufio.NewScanner(cmdErr) + for sc.Scan() { + log.Info("[tunneler/stderr]:", sc.Text()) + onLog(sc.Bytes()) + } + }() + go func() { + defer log.RecordPanic() + defer cmd.Process.Kill() + for { + select { + case addr := <-detectedCh: + log.TrInfof("info.tunnel.detected", addr.host, addr.port) + r.publicHost, r.publicPort = addr.host, addr.port + r.updateClustersWithGeneralConfig(ctx) + if ctx.Err() != nil { + return + } + case <-ctx.Done(): + return + } + } + }() + if _, err := cmd.Process.Wait(); err != nil { + if ctx.Err() != nil { + return + } + log.Errorf("Tunnel program exited: %v", err) + } +} + +type subscriptionManager struct { + webpushPlg *webpush.Plugin + database.DB +} + +func (s *subscriptionManager) GetWebPushKey() string { + return base64.RawURLEncoding.EncodeToString(s.webpushPlg.GetPublicKey()) +} From 88b902b573ffe1cb8b780cda54a68d8a60c36bf1 Mon Sep 17 00:00:00 2001 From: zyxkad Date: Sat, 17 Aug 2024 15:55:57 -0600 Subject: [PATCH 25/36] refactor file download now we download first then calculate the hash, which support zero copy --- cluster/http.go | 10 +++ cluster/storage.go | 166 +++++++++++++++++++++------------------ cluster/storage_test.go | 136 ++++++++++++++++++++++++++++++++ cluster/tempfile_test.go | 92 +++++++++++----------- utils/bar.go | 112 ++++++++++++++++++-------- utils/http.go | 19 +++++ 6 files changed, 381 insertions(+), 154 deletions(-) create mode 100644 cluster/storage_test.go diff --git a/cluster/http.go b/cluster/http.go index 8c747c3a..30add80d 100644 --- a/cluster/http.go +++ b/cluster/http.go @@ -32,6 +32,7 @@ import ( gocache "github.com/LiterMC/go-openbmclapi/cache" "github.com/LiterMC/go-openbmclapi/internal/build" + "github.com/LiterMC/go-openbmclapi/utils" ) type HTTPClient struct { @@ -45,6 +46,7 @@ func NewHTTPClient(dialer *net.Dialer, cache gocache.Cache) *HTTPClient { DialContext: dialer.DialContext, } } + transport = utils.NewRoundTripRedirectErrorWrapper(transport) cachedTransport := transport if cache != gocache.NoCache { cachedTransport = &httpcache.Transport{ @@ -72,6 +74,14 @@ func (c *HTTPClient) DoUseCache(req *http.Request) (*http.Response, error) { return c.cachedCli.Do(req) } +func (c *HTTPClient) Client() *http.Client { + return c.cli +} + +func (c *HTTPClient) CachedClient() *http.Client { + return c.cachedCli +} + func redirectChecker(req *http.Request, via []*http.Request) error { req.Header.Del("Referer") if len(via) > 10 { diff --git a/cluster/storage.go b/cluster/storage.go index 35dda611..1f1d7697 100644 --- a/cluster/storage.go +++ b/cluster/storage.go @@ -454,24 +454,25 @@ func (c *HTTPClient) SyncFiles( for _, info := range missingMap { log.Debugf("File %s is for %s", info.Hash, joinStorageIDs(info.Storages)) - pathRes, err := c.fetchFile(ctx, &stats, info) + fileRes, err := c.fetchFile(ctx, &stats, info) if err != nil { log.TrWarnf("warn.sync.interrupted") return err } - go func(info *StorageFileInfo, pathRes <-chan string) { + go func(info *StorageFileInfo, fileRes <-chan *os.File) { defer log.RecordPanic() select { - case path := <-pathRes: + case srcFd := <-fileRes: // cr.syncProg.Add(1) - if path == "" { + if srcFd == nil { select { case done <- nil: // TODO: or all storage? case <-ctx.Done(): } return } - defer os.Remove(path) + defer os.Remove(srcFd.Name()) + defer srcFd.Close() // acquire slot here slotId, buf, free := stats.slots.Alloc(ctx) if buf == nil { @@ -479,15 +480,10 @@ func (c *HTTPClient) SyncFiles( } defer free() _ = slotId - var srcFd *os.File - if srcFd, err = os.Open(path); err != nil { - return - } - defer srcFd.Close() var failed []storage.Storage for _, target := range info.Storages { if _, err = srcFd.Seek(0, io.SeekStart); err != nil { - log.Errorf("Cannot seek file %q to start: %v", path, err) + log.Errorf("Cannot seek file %q to start: %v", srcFd.Name(), err) continue } if err = target.Create(info.Hash, srcFd); err != nil { @@ -498,7 +494,7 @@ func (c *HTTPClient) SyncFiles( } free() srcFd.Close() - os.Remove(path) + os.Remove(srcFd.Name()) select { case done <- failed: case <-ctx.Done(): @@ -506,10 +502,10 @@ func (c *HTTPClient) SyncFiles( case <-ctx.Done(): return } - }(info, pathRes) + }(info, fileRes) } - for i := len(missingMap); i > 0; i-- { + for range len(missingMap) { select { case failed := <-done: for _, s := range failed { @@ -538,7 +534,7 @@ func (c *HTTPClient) SyncFiles( return nil } -func (c *HTTPClient) fetchFile(ctx context.Context, stats *syncStats, f *StorageFileInfo) (<-chan string, error) { +func (c *HTTPClient) fetchFile(ctx context.Context, stats *syncStats, f *StorageFileInfo) (<-chan *os.File, error) { const maxRetryCount = 10 slotId, buf, free := stats.slots.Alloc(ctx) @@ -546,18 +542,27 @@ func (c *HTTPClient) fetchFile(ctx context.Context, stats *syncStats, f *Storage return nil, ctx.Err() } - pathRes := make(chan string, 1) + hashMethod, err := getHashMethod(len(f.Hash)) + if err != nil { + return nil, err + } + + reqInd := 0 + reqs := make([]*http.Request, 0, len(f.URLs)) + for _, rq := range f.URLs { + reqs = append(reqs, rq.Request) + } + + fileRes := make(chan *os.File, 1) go func() { defer log.RecordPanic() defer free() - defer close(pathRes) + defer close(fileRes) var barUnit decor.SizeB1024 var tried atomic.Int32 tried.Store(1) - fPath := f.Hash // TODO: show downloading URL instead? Will it be too long? - bar := stats.pg.AddBar(f.Size, mpb.BarRemoveOnComplete(), mpb.BarPriority(slotId), @@ -570,7 +575,7 @@ func (c *HTTPClient) fetchFile(ctx context.Context, stats *syncStats, f *Storage } return fmt.Sprintf("(%d/%d) ", tc, maxRetryCount) }), - decor.Name(fPath, decor.WCSyncSpaceR), + decor.Name(f.Hash, decor.WCSyncSpaceR), ), mpb.AppendDecorators( decor.NewPercentage("%d", decor.WCSyncSpace), @@ -583,58 +588,81 @@ func (c *HTTPClient) fetchFile(ctx context.Context, stats *syncStats, f *Storage ) defer bar.Abort(true) + fd, err := os.CreateTemp("", "*.downloading") + if err != nil { + log.Errorf("Cannot create temporary file: %s", err) + stats.failCount.Add(1) + return + } + successed := false + defer func(fd *os.File) { + if !successed { + fd.Close() + os.Remove(fd.Name()) + } + }(fd) + // prealloc space + if err := fd.Truncate(f.Size); err != nil { + log.Warnf("File space pre-alloc failed: %v", err) + } + + downloadOnce := func() error { + if _, err := fd.Seek(io.SeekStart, 0); err != nil { + return err + } + if err := c.fetchFileWithBuf(ctx, reqs[reqInd], f.Size, hashMethod, f.Hash, fd, buf, func(r io.Reader) io.Reader { + return utils.ProxyPBReader(r, bar, stats.totalBar, &stats.lastInc) + }); err != nil { + reqInd = (reqInd + 1) % len(reqs) + return err + } + return nil + } + interval := time.Second for { bar.SetCurrent(0) - hashMethod, err := getHashMethod(len(f.Hash)) + err := downloadOnce() if err == nil { - var path string - if path, err = c.fetchFileWithBuf(ctx, f, hashMethod, buf, func(r io.Reader) io.Reader { - return utils.ProxyPBReader(r, bar, stats.totalBar, &stats.lastInc) - }); err == nil { - pathRes <- path - stats.okCount.Add(1) - log.Infof(lang.Tr("info.sync.downloaded"), fPath, - utils.BytesToUnit((float64)(f.Size)), - (float64)(stats.totalBar.Current())/(float64)(stats.totalSize)*100) - return - } + break } bar.SetRefill(bar.Current()) c := tried.Add(1) if c > maxRetryCount { - log.TrErrorf("error.sync.download.failed", fPath, err) - break + log.TrErrorf("error.sync.download.failed", f.Hash, err) + stats.failCount.Add(1) + return } - log.TrErrorf("error.sync.download.failed.retry", fPath, interval, err) + log.TrErrorf("error.sync.download.failed.retry", f.Hash, interval, err) select { case <-time.After(interval): - interval *= 2 + interval = min(interval*2, time.Minute*10) case <-ctx.Done(): + stats.failCount.Add(1) return } } - stats.failCount.Add(1) + successed = true + fileRes <- fd + stats.okCount.Add(1) + log.Infof(lang.Tr("info.sync.downloaded"), f.Hash, + utils.BytesToUnit((float64)(f.Size)), + (float64)(stats.totalBar.Current())/(float64)(stats.totalSize)*100) }() - return pathRes, nil + return fileRes, nil } func (c *HTTPClient) fetchFileWithBuf( - ctx context.Context, f *StorageFileInfo, - hashMethod crypto.Hash, buf []byte, + ctx context.Context, req *http.Request, + size int64, hashMethod crypto.Hash, hash string, + rw io.ReadWriteSeeker, buf []byte, wrapper func(io.Reader) io.Reader, -) (path string, err error) { +) (err error) { var ( - req *http.Request res *http.Response - fd *os.File r io.Reader ) - for _, rq := range f.URLs { - req = rq.Request - break - } req = req.Clone(ctx) req.Header.Set("Accept-Encoding", "gzip, deflate") if res, err = c.Do(req); err != nil { @@ -643,10 +671,13 @@ func (c *HTTPClient) fetchFileWithBuf( defer res.Body.Close() if res.StatusCode != http.StatusOK { err = utils.NewHTTPStatusErrorFromResponse(res) - }else { + } else { switch ce := strings.ToLower(res.Header.Get("Content-Encoding")); ce { case "": r = res.Body + if res.ContentLength >= 0 && res.ContentLength != size { + err = fmt.Errorf("File size wrong, got %d, expect %d", res.ContentLength, size) + } case "gzip": r, err = gzip.NewReader(res.Body) case "deflate": @@ -656,41 +687,26 @@ func (c *HTTPClient) fetchFileWithBuf( } } if err != nil { - return "", utils.ErrorFromRedirect(err, res) + return utils.ErrorFromRedirect(err, res) } if wrapper != nil { r = wrapper(r) } - hw := hashMethod.New() - - if fd, err = os.CreateTemp("", "*.downloading"); err != nil { - return + if n, err := io.CopyBuffer(rw, r, buf); err != nil { + return utils.ErrorFromRedirect(err, res) + } else if n != size { + return utils.ErrorFromRedirect(fmt.Errorf("File size wrong, got %d, expect %d", n, size), res) } - path = fd.Name() - defer func(path string) { - if err != nil { - os.Remove(path) - } - }(path) - - _, err = io.CopyBuffer(io.MultiWriter(hw, fd), r, buf) - stat, err2 := fd.Stat() - fd.Close() - if err != nil { - err = utils.ErrorFromRedirect(err, res) - return + if _, err := rw.Seek(io.SeekStart, 0); err != nil { + return err } - if err2 != nil { - err = err2 - return + hw := hashMethod.New() + if _, err := io.CopyBuffer(hw, rw, buf); err != nil { + return err } - if t := stat.Size(); f.Size >= 0 && t != f.Size { - err = utils.ErrorFromRedirect(fmt.Errorf("File size wrong, got %d, expect %d", t, f.Size), res) - return - } else if hs := hex.EncodeToString(hw.Sum(buf[:0])); hs != f.Hash { - err = utils.ErrorFromRedirect(fmt.Errorf("File hash not match, got %s, expect %s", hs, f.Hash), res) - return + if hs := hex.EncodeToString(hw.Sum(buf[:0])); hs != hash { + return utils.ErrorFromRedirect(fmt.Errorf("File hash not match, got %s, expect %s", hs, hash), res) } return } diff --git a/cluster/storage_test.go b/cluster/storage_test.go new file mode 100644 index 00000000..0f805744 --- /dev/null +++ b/cluster/storage_test.go @@ -0,0 +1,136 @@ +/** + * OpenBmclAPI (Golang Edition) + * Copyright (C) 2024 Kevin Z + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published + * by the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package cluster_test + +import ( + "testing" + + "bytes" + "crypto" + "io" + "net" + "net/http" + "os" + "strconv" +) + +var emptyBytes = make([]byte, 1024) + +func startServer() string { + listener, err := net.ListenTCP("tcp4", &net.TCPAddr{ + IP: net.IPv4(127, 0, 0, 1), + }) + if err != nil { + panic(err) + } + server := &http.Server{ + Handler: (http.HandlerFunc)(func(rw http.ResponseWriter, req *http.Request) { + size := 128 + rw.Header().Set("Content-Length", strconv.Itoa(size*len(emptyBytes))) + rw.WriteHeader(http.StatusOK) + for range size { + rw.Write(emptyBytes) + } + }), + } + go server.Serve(listener) + return "http://" + listener.Addr().String() +} + +var expectedDownloadHash = []byte{0xfa, 0x43, 0x23, 0x9b, 0xce, 0xe7, 0xb9, 0x7c, 0xa6, 0x2f, 0x0, 0x7c, 0xc6, 0x84, 0x87, 0x56, 0xa, 0x39, 0xe1, 0x9f, 0x74, 0xf3, 0xdd, 0xe7, 0x48, 0x6d, 0xb3, 0xf9, 0x8d, 0xf8, 0xe4, 0x71} + +func BenchmarkDownlaodWhileVerify(b *testing.B) { + url := startServer() + fd, err := os.CreateTemp("", "gotest-") + if err != nil { + b.Fatalf("Cannot create temporary file: %v", err) + } + defer fd.Close() + defer os.Remove(fd.Name()) + req, err := http.NewRequest("GET", url, nil) + if err != nil { + b.Fatalf("Cannot form new request: %v", err) + } + + hashMethod := crypto.SHA256 + buf := make([]byte, 1024) + client := &http.Client{} + + b.ResetTimer() + for range b.N { + if _, err := fd.Seek(io.SeekStart, 0); err != nil { + b.Fatalf("Seek error: %v", err) + } + resp, err := client.Do(req) + if err != nil { + b.Fatalf("Request error: %v", err) + } + hw := hashMethod.New() + if _, err := io.CopyBuffer(io.MultiWriter(fd, hw), resp.Body, buf); err != nil { + b.Fatalf("Copy error: %v", err) + } + resp.Body.Close() + if hs := hw.Sum(buf[:0]); !bytes.Equal(hs, expectedDownloadHash) { + b.Fatalf("Hash mismatch: %#v", hs) + } + } +} + +func BenchmarkDownlaodThenVerify(b *testing.B) { + url := startServer() + fd, err := os.CreateTemp("", "gotest-") + if err != nil { + b.Fatalf("Cannot create temporary file: %v", err) + } + defer fd.Close() + defer os.Remove(fd.Name()) + req, err := http.NewRequest("GET", url, nil) + if err != nil { + b.Fatalf("Cannot form new request: %v", err) + } + + hashMethod := crypto.SHA256 + buf := make([]byte, 1024) + client := &http.Client{} + + b.ResetTimer() + for range b.N { + if _, err := fd.Seek(io.SeekStart, 0); err != nil { + b.Fatalf("Seek error: %v", err) + } + resp, err := client.Do(req) + if err != nil { + b.Fatalf("Request error: %v", err) + } + if _, err := io.CopyBuffer(fd, resp.Body, buf); err != nil { + b.Fatalf("Copy error: %v", err) + } + resp.Body.Close() + hw := hashMethod.New() + if _, err := fd.Seek(io.SeekStart, 0); err != nil { + b.Fatalf("Seek error: %v", err) + } + if _, err := io.CopyBuffer(hw, fd, buf); err != nil { + b.Fatalf("Copy error: %v", err) + } + if hs := hw.Sum(buf[:0]); !bytes.Equal(hs, expectedDownloadHash) { + b.Fatalf("Hash mismatch: %#v", hs) + } + } +} diff --git a/cluster/tempfile_test.go b/cluster/tempfile_test.go index d3d216b6..a2ad6ab1 100644 --- a/cluster/tempfile_test.go +++ b/cluster/tempfile_test.go @@ -38,142 +38,142 @@ var datas = func() [][]byte { return datas }() -func BenchmarkCreateAndRemoveFile(t *testing.B) { - t.ReportAllocs() +func BenchmarkCreateAndRemoveFile(b *testing.B) { + b.ReportAllocs() buf := make([]byte, 1024) _ = buf - for i := 0; i < t.N; i++ { + for i := 0; i < b.N; i++ { d := datas[i%len(datas)] fd, err := os.CreateTemp("", "*.downloading") if err != nil { - t.Fatalf("Cannot create temp file: %v", err) + b.Fatalf("Cannot create temp file: %v", err) } if _, err = fd.Write(d); err != nil { - t.Errorf("Cannot write file: %v", err) + b.Errorf("Cannot write file: %v", err) } else if err = fd.Sync(); err != nil { - t.Errorf("Cannot write file: %v", err) + b.Errorf("Cannot write file: %v", err) } fd.Close() os.Remove(fd.Name()) if err != nil { - t.FailNow() + b.FailNow() } } } -func BenchmarkWriteAndTruncateFile(t *testing.B) { - t.ReportAllocs() +func BenchmarkWriteAndTruncateFile(b *testing.B) { + b.ReportAllocs() buf := make([]byte, 1024) _ = buf fd, err := os.CreateTemp("", "*.downloading") if err != nil { - t.Fatalf("Cannot create temp file: %v", err) + b.Fatalf("Cannot create temp file: %v", err) } defer os.Remove(fd.Name()) - for i := 0; i < t.N; i++ { + for i := 0; i < b.N; i++ { d := datas[i%len(datas)] if _, err := fd.Write(d); err != nil { - t.Fatalf("Cannot write file: %v", err) + b.Fatalf("Cannot write file: %v", err) } else if err := fd.Sync(); err != nil { - t.Fatalf("Cannot write file: %v", err) + b.Fatalf("Cannot write file: %v", err) } else if err := fd.Truncate(0); err != nil { - t.Fatalf("Cannot truncate file: %v", err) + b.Fatalf("Cannot truncate file: %v", err) } } } -func BenchmarkWriteAndSeekFile(t *testing.B) { - t.ReportAllocs() +func BenchmarkWriteAndSeekFile(b *testing.B) { + b.ReportAllocs() buf := make([]byte, 1024) _ = buf fd, err := os.CreateTemp("", "*.downloading") if err != nil { - t.Fatalf("Cannot create temp file: %v", err) + b.Fatalf("Cannot create temp file: %v", err) } defer os.Remove(fd.Name()) - for i := 0; i < t.N; i++ { + for i := 0; i < b.N; i++ { d := datas[i%len(datas)] if _, err := fd.Write(d); err != nil { - t.Fatalf("Cannot write file: %v", err) + b.Fatalf("Cannot write file: %v", err) } else if err := fd.Sync(); err != nil { - t.Fatalf("Cannot write file: %v", err) + b.Fatalf("Cannot write file: %v", err) } else if _, err := fd.Seek(io.SeekStart, 0); err != nil { - t.Fatalf("Cannot seek file: %v", err) + b.Fatalf("Cannot seek file: %v", err) } } } -func BenchmarkParallelCreateAndRemoveFile(t *testing.B) { - t.ReportAllocs() - t.SetParallelism(4) +func BenchmarkParallelCreateAndRemoveFile(b *testing.B) { + b.ReportAllocs() + b.SetParallelism(4) buf := make([]byte, 1024) _ = buf - t.RunParallel(func(pb *testing.PB) { + b.RunParallel(func(pb *testing.PB) { for i := 0; pb.Next(); i++ { d := datas[i%len(datas)] fd, err := os.CreateTemp("", "*.downloading") if err != nil { - t.Fatalf("Cannot create temp file: %v", err) + b.Fatalf("Cannot create temp file: %v", err) } if _, err = fd.Write(d); err != nil { - t.Errorf("Cannot write file: %v", err) + b.Errorf("Cannot write file: %v", err) } else if err = fd.Sync(); err != nil { - t.Errorf("Cannot write file: %v", err) + b.Errorf("Cannot write file: %v", err) } fd.Close() if err := os.Remove(fd.Name()); err != nil { - t.Fatalf("Cannot remove file: %v", err) + b.Fatalf("Cannot remove file: %v", err) } if err != nil { - t.FailNow() + b.FailNow() } } }) } -func BenchmarkParallelWriteAndTruncateFile(t *testing.B) { - t.ReportAllocs() - t.SetParallelism(4) +func BenchmarkParallelWriteAndTruncateFile(b *testing.B) { + b.ReportAllocs() + b.SetParallelism(4) buf := make([]byte, 1024) _ = buf - t.RunParallel(func(pb *testing.PB) { + b.RunParallel(func(pb *testing.PB) { fd, err := os.CreateTemp("", "*.downloading") if err != nil { - t.Fatalf("Cannot create temp file: %v", err) + b.Fatalf("Cannot create temp file: %v", err) } defer os.Remove(fd.Name()) for i := 0; pb.Next(); i++ { d := datas[i%len(datas)] if _, err := fd.Write(d); err != nil { - t.Fatalf("Cannot write file: %v", err) + b.Fatalf("Cannot write file: %v", err) } else if err := fd.Sync(); err != nil { - t.Fatalf("Cannot write file: %v", err) + b.Fatalf("Cannot write file: %v", err) } else if err := fd.Truncate(0); err != nil { - t.Fatalf("Cannot truncate file: %v", err) + b.Fatalf("Cannot truncate file: %v", err) } } }) } -func BenchmarkParallelWriteAndSeekFile(t *testing.B) { - t.ReportAllocs() - t.SetParallelism(4) +func BenchmarkParallelWriteAndSeekFile(b *testing.B) { + b.ReportAllocs() + b.SetParallelism(4) buf := make([]byte, 1024) _ = buf - t.RunParallel(func(pb *testing.PB) { + b.RunParallel(func(pb *testing.PB) { fd, err := os.CreateTemp("", "*.downloading") if err != nil { - t.Fatalf("Cannot create temp file: %v", err) + b.Fatalf("Cannot create temp file: %v", err) } defer os.Remove(fd.Name()) for i := 0; pb.Next(); i++ { d := datas[i%len(datas)] if _, err := fd.Write(d); err != nil { - t.Fatalf("Cannot write file: %v", err) + b.Fatalf("Cannot write file: %v", err) } else if err := fd.Sync(); err != nil { - t.Fatalf("Cannot write file: %v", err) + b.Fatalf("Cannot write file: %v", err) } else if _, err := fd.Seek(io.SeekStart, 0); err != nil { - t.Fatalf("Cannot seel file: %v", err) + b.Fatalf("Cannot seel file: %v", err) } } }) diff --git a/utils/bar.go b/utils/bar.go index 7ca96bff..c66308e4 100644 --- a/utils/bar.go +++ b/utils/bar.go @@ -27,28 +27,21 @@ import ( "github.com/vbauerster/mpb/v8" ) -type ProxiedPBReader struct { - io.Reader +type pbReader struct { bar, total *mpb.Bar lastRead time.Time lastInc *atomic.Int64 } -func ProxyPBReader(r io.Reader, bar, total *mpb.Bar, lastInc *atomic.Int64) *ProxiedPBReader { - return &ProxiedPBReader{ - Reader: r, - bar: bar, - total: total, - lastInc: lastInc, - } -} - -func (p *ProxiedPBReader) Read(buf []byte) (n int, err error) { +func (p *pbReader) beforeRead() time.Time { start := p.lastRead if start.IsZero() { start = time.Now() } - n, err = p.Reader.Read(buf) + return start +} + +func (p *pbReader) afterRead(n int, start time.Time) { end := time.Now() p.lastRead = end used := end.Sub(start) @@ -57,38 +50,91 @@ func (p *ProxiedPBReader) Read(buf []byte) (n int, err error) { nowSt := end.UnixNano() last := p.lastInc.Swap(nowSt) p.total.EwmaIncrBy(n, (time.Duration)(nowSt-last)*time.Nanosecond) +} + +func (p *pbReader) read(r io.Reader, buf []byte) (n int, err error) { + start := p.beforeRead() + n, err = r.Read(buf) + p.afterRead(n, start) return } +type readerDeadline interface { + SetReadDeadline(time.Time) error +} + +func (p *pbReader) writeTo(r io.Reader, w io.Writer) (int64, error) { + const maxChunkSize = 1024 * 16 + const maxUpdateInterval = time.Second + lr := &io.LimitedReader{ + R: r, + N: 0, + } + rd, deadOk := r.(readerDeadline) + if deadOk { + defer rd.SetReadDeadline(time.Time{}) + } + var n int64 + for lr.N == 0 { + lr.N = maxChunkSize + start := p.beforeRead() + if deadOk { + rd.SetReadDeadline(time.Now().Add(maxUpdateInterval)) + } + n0, err := io.Copy(w, lr) + n += n0 + p.afterRead((int)(n0), start) + if err != nil { + return n, err + } + } + return n, nil +} + +type ProxiedPBReader struct { + io.Reader + pbr pbReader +} + +func ProxyPBReader(r io.Reader, bar, total *mpb.Bar, lastInc *atomic.Int64) *ProxiedPBReader { + return &ProxiedPBReader{ + Reader: r, + pbr: pbReader{ + bar: bar, + total: total, + lastInc: lastInc, + }, + } +} + +func (p *ProxiedPBReader) Read(buf []byte) (int, error) { + return p.pbr.read(p.Reader, buf) +} + +func (p *ProxiedPBReader) WriteTo(w io.Writer) (int64, error) { + return p.pbr.writeTo(p.Reader, w) +} + type ProxiedPBReadSeeker struct { io.ReadSeeker - bar, total *mpb.Bar - lastRead time.Time - lastInc *atomic.Int64 + pbr pbReader } func ProxyPBReadSeeker(r io.ReadSeeker, bar, total *mpb.Bar, lastInc *atomic.Int64) *ProxiedPBReadSeeker { return &ProxiedPBReadSeeker{ ReadSeeker: r, - bar: bar, - total: total, - lastInc: lastInc, + pbr: pbReader{ + bar: bar, + total: total, + lastInc: lastInc, + }, } } -func (p *ProxiedPBReadSeeker) Read(buf []byte) (n int, err error) { - start := p.lastRead - if start.IsZero() { - start = time.Now() - } - n, err = p.ReadSeeker.Read(buf) - end := time.Now() - p.lastRead = end - used := end.Sub(start) +func (p *ProxiedPBReadSeeker) Read(buf []byte) (int, error) { + return p.pbr.read(p.ReadSeeker, buf) +} - p.bar.EwmaIncrBy(n, used) - nowSt := end.UnixNano() - last := p.lastInc.Swap(nowSt) - p.total.EwmaIncrBy(n, (time.Duration)(nowSt-last)*time.Nanosecond) - return +func (p *ProxiedPBReadSeeker) WriteTo(w io.Writer) (int64, error) { + return p.pbr.writeTo(p.ReadSeeker, w) } diff --git a/utils/http.go b/utils/http.go index 25ddfe15..600d1866 100644 --- a/utils/http.go +++ b/utils/http.go @@ -554,3 +554,22 @@ func (e *RedirectError) Error() string { func (e *RedirectError) Unwrap() error { return e.Err } + +type redirectErrorWrapper struct { + rt http.RoundTripper +} + +func (w *redirectErrorWrapper) RoundTrip(req *http.Request) (*http.Response, error) { + resp, err := w.rt.RoundTrip(req) + if err != nil { + if req.Response != nil { + return nil, ErrorFromRedirect(err, req.Response) + } + return nil, err + } + return resp, nil +} + +func NewRoundTripRedirectErrorWrapper(rt http.RoundTripper) http.RoundTripper { + return &redirectErrorWrapper{rt: rt} +} From 8cf633b9177319e94db6098a44e203e4e5f3483e Mon Sep 17 00:00:00 2001 From: zyxkad Date: Sat, 17 Aug 2024 16:06:39 -0600 Subject: [PATCH 26/36] fill more manager --- api/user.go | 1 + api/v0/auth.go | 212 ------------------------------- database/db.go | 9 ++ notify/webhook/webhook.go | 2 +- runner.go | 48 +++++-- token/package.go | 256 ++++++++++++++++++++++++++++++++++++++ token/token_db.go | 67 ++++++++++ utils/crypto.go | 4 +- 8 files changed, 375 insertions(+), 224 deletions(-) create mode 100644 token/package.go create mode 100644 token/token_db.go diff --git a/api/user.go b/api/user.go index f6de76c6..f426a542 100644 --- a/api/user.go +++ b/api/user.go @@ -24,6 +24,7 @@ type UserManager interface { GetUser(id string) *User AddUser(*User) error RemoveUser(id string) error + ForEachUser(cb func(*User) error) error UpdateUserPassword(username string, password string) error UpdateUserPermissions(username string, permissions PermissionFlag) error diff --git a/api/v0/auth.go b/api/v0/auth.go index 0af13030..f16dce14 100644 --- a/api/v0/auth.go +++ b/api/v0/auth.go @@ -281,215 +281,3 @@ func (h *Handler) routeLogout(rw http.ResponseWriter, req *http.Request) { h.tokens.InvalidToken(tid) rw.WriteHeader(http.StatusNoContent) } - -// var ( -// ErrUnsupportAuthType = errors.New("unsupported authorization type") -// ErrScopeNotMatch = errors.New("scope not match") -// ErrJTINotExists = errors.New("jti not exists") - -// ErrStrictPathNotMatch = errors.New("strict path not match") -// ErrStrictQueryNotMatch = errors.New("strict query value not match") -// ) - -// func (cr *Cluster) getJWTKey(t *jwt.Token) (any, error) { -// if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { -// return nil, fmt.Errorf("Unexpected signing method: %v", t.Header["alg"]) -// } -// return cr.apiHmacKey, nil -// } - -// const ( -// challengeTokenScope = "GOBA-challenge" -// authTokenScope = "GOBA-auth" -// apiTokenScope = "GOBA-API" -// ) - -// type challengeTokenClaims struct { -// jwt.RegisteredClaims - -// Scope string `json:"scope"` -// Action string `json:"act"` -// } - -// func (cr *Cluster) generateChallengeToken(cliId string, action string) (string, error) { -// now := time.Now() -// exp := now.Add(time.Minute * 1) -// token := jwt.NewWithClaims(jwt.SigningMethodHS256, &challengeTokenClaims{ -// RegisteredClaims: jwt.RegisteredClaims{ -// Subject: cliId, -// Issuer: cr.jwtIssuer, -// IssuedAt: jwt.NewNumericDate(now), -// ExpiresAt: jwt.NewNumericDate(exp), -// }, -// Scope: challengeTokenScope, -// Action: action, -// }) -// tokenStr, err := token.SignedString(cr.apiHmacKey) -// if err != nil { -// return "", err -// } -// return tokenStr, nil -// } - -// func (cr *Cluster) verifyChallengeToken(cliId string, action string, token string) (err error) { -// var claims challengeTokenClaims -// if _, err = jwt.ParseWithClaims( -// token, -// &claims, -// cr.getJWTKey, -// jwt.WithSubject(cliId), -// jwt.WithIssuedAt(), -// jwt.WithIssuer(cr.jwtIssuer), -// ); err != nil { -// return -// } -// if claims.Scope != challengeTokenScope { -// return ErrScopeNotMatch -// } -// if claims.Action != action { -// return ErrJTINotExists -// } -// return -// } - -// type authTokenClaims struct { -// jwt.RegisteredClaims - -// Scope string `json:"scope"` -// User string `json:"usr"` -// } - -// func (cr *Cluster) generateAuthToken(cliId string, userId string) (string, error) { -// jti, err := utils.GenRandB64(16) -// if err != nil { -// return "", err -// } -// now := time.Now() -// exp := now.Add(time.Hour * 24) -// token := jwt.NewWithClaims(jwt.SigningMethodHS256, &authTokenClaims{ -// RegisteredClaims: jwt.RegisteredClaims{ -// ID: jti, -// Subject: cliId, -// Issuer: cr.jwtIssuer, -// IssuedAt: jwt.NewNumericDate(now), -// ExpiresAt: jwt.NewNumericDate(exp), -// }, -// Scope: authTokenScope, -// User: userId, -// }) -// tokenStr, err := token.SignedString(cr.apiHmacKey) -// if err != nil { -// return "", err -// } -// if err = cr.database.AddJTI(jti, exp); err != nil { -// return "", err -// } -// return tokenStr, nil -// } - -// func (cr *Cluster) verifyAuthToken(cliId string, token string) (id string, user string, err error) { -// var claims authTokenClaims -// if _, err = jwt.ParseWithClaims( -// token, -// &claims, -// cr.getJWTKey, -// jwt.WithSubject(cliId), -// jwt.WithIssuedAt(), -// jwt.WithIssuer(cr.jwtIssuer), -// ); err != nil { -// return -// } -// if claims.Scope != authTokenScope { -// err = ErrScopeNotMatch -// return -// } -// if user = claims.User; user == "" { -// // reject old token -// err = ErrJTINotExists -// return -// } -// id = claims.ID -// if ok, _ := cr.database.ValidJTI(id); !ok { -// err = ErrJTINotExists -// return -// } -// return -// } - -// type apiTokenClaims struct { -// jwt.RegisteredClaims - -// Scope string `json:"scope"` -// User string `json:"usr"` -// StrictPath string `json:"str-p"` -// StrictQuery map[string]string `json:"str-q,omitempty"` -// } - -// func (cr *Cluster) generateAPIToken(cliId string, userId string, path string, query map[string]string) (string, error) { -// jti, err := utils.GenRandB64(8) -// if err != nil { -// return "", err -// } -// now := time.Now() -// exp := now.Add(time.Minute * 10) -// token := jwt.NewWithClaims(jwt.SigningMethodHS256, &apiTokenClaims{ -// RegisteredClaims: jwt.RegisteredClaims{ -// ID: jti, -// Subject: cliId, -// Issuer: cr.jwtIssuer, -// IssuedAt: jwt.NewNumericDate(now), -// ExpiresAt: jwt.NewNumericDate(exp), -// }, -// Scope: apiTokenScope, -// User: userId, -// StrictPath: path, -// StrictQuery: query, -// }) -// tokenStr, err := token.SignedString(cr.apiHmacKey) -// if err != nil { -// return "", err -// } -// if err = cr.database.AddJTI(jti, exp); err != nil { -// return "", err -// } -// return tokenStr, nil -// } - -// func (h *Handler) verifyAPIToken(cliId string, token string, path string, query url.Values) (id string, user string, err error) { -// var claims apiTokenClaims -// _, err = jwt.ParseWithClaims( -// token, -// &claims, -// cr.getJWTKey, -// jwt.WithSubject(cliId), -// jwt.WithIssuedAt(), -// jwt.WithIssuer(cr.jwtIssuer), -// ) -// if err != nil { -// return -// } -// if claims.Scope != apiTokenScope { -// err = ErrScopeNotMatch -// return -// } -// if user = claims.User; user == "" { -// err = ErrJTINotExists -// return -// } -// id = claims.ID -// if ok, _ := cr.database.ValidJTI(id); !ok { -// err = ErrJTINotExists -// return -// } -// if claims.StrictPath != path { -// err = ErrStrictPathNotMatch -// return -// } -// for k, v := range claims.StrictQuery { -// if query.Get(k) != v { -// err = ErrStrictQueryNotMatch -// return -// } -// } -// return -// } diff --git a/database/db.go b/database/db.go index 594b7775..8bca3464 100644 --- a/database/db.go +++ b/database/db.go @@ -51,6 +51,15 @@ type DB interface { // the callback should not edit the record pointer ForEachFileRecord(cb func(*FileRecord) error) error + // GetUsers() []*api.User + // GetUser(id string) *api.User + // AddUser(*api.User) error + // RemoveUser(id string) error + // ForEachUser(cb func(*api.User) error) error + // UpdateUserPassword(username string, password string) error + // UpdateUserPermissions(username string, permissions api.PermissionFlag) error + // VerifyUserPassword(userId string, comparator func(password string) bool) error + GetSubscribe(user string, client string) (*api.SubscribeRecord, error) SetSubscribe(api.SubscribeRecord) error RemoveSubscribe(user string, client string) error diff --git a/notify/webhook/webhook.go b/notify/webhook/webhook.go index 44b8b020..1fa637e3 100644 --- a/notify/webhook/webhook.go +++ b/notify/webhook/webhook.go @@ -17,7 +17,7 @@ * along with this program. If not, see . */ -package webpush +package webhook import ( "bytes" diff --git a/runner.go b/runner.go index 7c92a5fe..41251c2a 100644 --- a/runner.go +++ b/runner.go @@ -53,8 +53,11 @@ import ( "github.com/LiterMC/go-openbmclapi/limited" "github.com/LiterMC/go-openbmclapi/log" "github.com/LiterMC/go-openbmclapi/notify" + "github.com/LiterMC/go-openbmclapi/notify/email" + "github.com/LiterMC/go-openbmclapi/notify/webhook" "github.com/LiterMC/go-openbmclapi/notify/webpush" "github.com/LiterMC/go-openbmclapi/storage" + "github.com/LiterMC/go-openbmclapi/token" "github.com/LiterMC/go-openbmclapi/utils" ) @@ -104,21 +107,48 @@ func NewRunner() *Runner { r.database = database.NewMemoryDB() } else if r.database, err = database.NewSqlDB(r.Config.Database.Driver, r.Config.Database.DSN); err != nil { log.Errorf("Cannot connect to database: %v", err) + os.Exit(1) } } // r.userManager = - // r.tokenManager = - webpushPlg := new(webpush.Plugin) - r.subManager = &subscriptionManager{ - webpushPlg: webpushPlg, - DB: r.database, - } - r.notifyManager = notify.NewManager(dataDir, r.database, r.client.CachedClient(), "go-openbmclapi") - r.storageManager = storage.NewManager(storages) + if apiHMACKey, err := utils.LoadOrCreateHmacKey(dataDir, "server"); err != nil { + log.Errorf("Cannot load HMAC key: %v", err) + os.Exit(1) + } else { + r.tokenManager = token.NewDBManager("go-openbmclapi", apiHMACKey, r.database) + } + { + r.notifyManager = notify.NewManager(dataDir, r.database, r.client.CachedClient(), "go-openbmclapi") + r.notifyManager.AddPlugin(new(webhook.Plugin)) + if r.Config.Notification.EnableEmail { + emailPlg, err := email.NewSMTP(r.Config.Notification.EmailSMTP, r.Config.Notification.EmailSMTPEncryption, + r.Config.Notification.EmailSender, r.Config.Notification.EmailSenderPassword) + if err != nil { + log.Errorf("Cannot init SMTP client: %v", err) + os.Exit(1) + } + r.notifyManager.AddPlugin(emailPlg) + } + r.notifyManager.AddPlugin(new(email.Plugin)) + webpushPlg := new(webpush.Plugin) + r.notifyManager.AddPlugin(webpushPlg) + + r.subManager = &subscriptionManager{ + webpushPlg: webpushPlg, + DB: r.database, + } + } + { + storages := make([]storage.Storage, len(r.Config.Storages)) + for i, s := range r.Config.Storages { + storages[i] = storage.NewStorage(s) + } + r.storageManager = storage.NewManager(storages) + } r.statManager = cluster.NewStatManager() if err := r.statManager.Load(dataDir); err != nil { - log.Errorf("Stat load failed:", err) + log.Errorf("Stat load failed: %v", err) } r.apiRateLimiter = limited.NewAPIRateMiddleWare(api.RealAddrCtxKey, "go-openbmclapi.cluster.logged.user" /* api/v0.loggedUserKey */) return r diff --git a/token/package.go b/token/package.go new file mode 100644 index 00000000..470830e0 --- /dev/null +++ b/token/package.go @@ -0,0 +1,256 @@ +/** + * OpenBmclAPI (Golang Edition) + * Copyright (C) 2023 Kevin Z + * All rights reserved + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published + * by the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package token + +import ( + "errors" + "fmt" + "net/url" + "time" + + "github.com/golang-jwt/jwt/v5" + + "github.com/LiterMC/go-openbmclapi/utils" +) + +var ( + ErrUnsupportAuthType = errors.New("unsupported authorization type") + ErrScopeNotMatch = errors.New("scope not match") + ErrJTINotExists = errors.New("jti not exists") + + ErrStrictPathNotMatch = errors.New("strict path not match") + ErrStrictQueryNotMatch = errors.New("strict query value not match") +) + +const ( + challengeTokenScope = "GOBA-challenge" + authTokenScope = "GOBA-auth" + apiTokenScope = "GOBA-API" +) + +type ( + basicTokenManager struct { + impl basicTokenManagerImpl + } + basicTokenManagerImpl interface { + Issuer() string + HmacKey() []byte + AddJTI(string, time.Time) error + ValidJTI(string) bool + } +) + +type ( + challengeTokenClaims struct { + jwt.RegisteredClaims + + Scope string `json:"scope"` + Action string `json:"act"` + } + + authTokenClaims struct { + jwt.RegisteredClaims + + Scope string `json:"scope"` + User string `json:"usr"` + } + + apiTokenClaims struct { + jwt.RegisteredClaims + + Scope string `json:"scope"` + User string `json:"usr"` + StrictPath string `json:"str-p"` + StrictQuery map[string]string `json:"str-q,omitempty"` + } +) + +func (m *basicTokenManager) getJWTKey(t *jwt.Token) (any, error) { + if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("Unexpected signing method: %v", t.Header["alg"]) + } + return m.impl.HmacKey(), nil +} + +func (m *basicTokenManager) GenerateChallengeToken(cliId string, action string) (string, error) { + now := time.Now() + exp := now.Add(time.Minute * 1) + token := jwt.NewWithClaims(jwt.SigningMethodHS256, &challengeTokenClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + Subject: cliId, + Issuer: m.impl.Issuer(), + IssuedAt: jwt.NewNumericDate(now), + ExpiresAt: jwt.NewNumericDate(exp), + }, + Scope: challengeTokenScope, + Action: action, + }) + tokenStr, err := token.SignedString(m.impl.HmacKey()) + if err != nil { + return "", err + } + return tokenStr, nil +} + +func (m *basicTokenManager) VerifyChallengeToken(cliId string, action string, token string) (err error) { + var claims challengeTokenClaims + if _, err = jwt.ParseWithClaims( + token, + &claims, + m.getJWTKey, + jwt.WithSubject(cliId), + jwt.WithIssuedAt(), + jwt.WithIssuer(m.impl.Issuer()), + ); err != nil { + return + } + if claims.Scope != challengeTokenScope { + return ErrScopeNotMatch + } + if claims.Action != action { + return ErrJTINotExists + } + return +} + +func (m *basicTokenManager) GenerateAuthToken(cliId string, userId string) (string, error) { + jti, err := utils.GenRandB64(16) + if err != nil { + return "", err + } + now := time.Now() + exp := now.Add(time.Hour * 24) + token := jwt.NewWithClaims(jwt.SigningMethodHS256, &authTokenClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + ID: jti, + Subject: cliId, + Issuer: m.impl.Issuer(), + IssuedAt: jwt.NewNumericDate(now), + ExpiresAt: jwt.NewNumericDate(exp), + }, + Scope: authTokenScope, + User: userId, + }) + tokenStr, err := token.SignedString(m.impl.HmacKey()) + if err != nil { + return "", err + } + if err = m.impl.AddJTI(jti, exp); err != nil { + return "", err + } + return tokenStr, nil +} + +func (m *basicTokenManager) VerifyAuthToken(cliId string, token string) (id string, user string, err error) { + var claims authTokenClaims + if _, err = jwt.ParseWithClaims( + token, + &claims, + m.getJWTKey, + jwt.WithSubject(cliId), + jwt.WithIssuedAt(), + jwt.WithIssuer(m.impl.Issuer()), + ); err != nil { + return + } + if claims.Scope != authTokenScope { + err = ErrScopeNotMatch + return + } + if user = claims.User; user == "" { + // reject old token + err = ErrJTINotExists + return + } + id = claims.ID + if ok := m.impl.ValidJTI(id); !ok { + err = ErrJTINotExists + return + } + return +} + +func (m *basicTokenManager) GenerateAPIToken(cliId string, userId string, path string, query map[string]string) (string, error) { + jti, err := utils.GenRandB64(8) + if err != nil { + return "", err + } + now := time.Now() + exp := now.Add(time.Minute * 10) + token := jwt.NewWithClaims(jwt.SigningMethodHS256, &apiTokenClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + ID: jti, + Subject: cliId, + Issuer: m.impl.Issuer(), + IssuedAt: jwt.NewNumericDate(now), + ExpiresAt: jwt.NewNumericDate(exp), + }, + Scope: apiTokenScope, + User: userId, + StrictPath: path, + StrictQuery: query, + }) + tokenStr, err := token.SignedString(m.impl.HmacKey()) + if err != nil { + return "", err + } + if err = m.impl.AddJTI(jti, exp); err != nil { + return "", err + } + return tokenStr, nil +} + +func (m *basicTokenManager) VerifyAPIToken(cliId string, token string, path string, query url.Values) (user string, err error) { + var claims apiTokenClaims + _, err = jwt.ParseWithClaims( + token, + &claims, + m.getJWTKey, + jwt.WithSubject(cliId), + jwt.WithIssuedAt(), + jwt.WithIssuer(m.impl.Issuer()), + ) + if err != nil { + return + } + if claims.Scope != apiTokenScope { + err = ErrScopeNotMatch + return + } + if user = claims.User; user == "" { + err = ErrJTINotExists + return + } + if ok := m.impl.ValidJTI(claims.ID); !ok { + err = ErrJTINotExists + return + } + if claims.StrictPath != path { + err = ErrStrictPathNotMatch + return + } + for k, v := range claims.StrictQuery { + if query.Get(k) != v { + err = ErrStrictQueryNotMatch + return + } + } + return +} diff --git a/token/token_db.go b/token/token_db.go new file mode 100644 index 00000000..f9e507cf --- /dev/null +++ b/token/token_db.go @@ -0,0 +1,67 @@ +/** + * OpenBmclAPI (Golang Edition) + * Copyright (C) 2023 Kevin Z + * All rights reserved + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published + * by the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package token + +import ( + "time" + + "github.com/LiterMC/go-openbmclapi/api" + "github.com/LiterMC/go-openbmclapi/database" +) + +type DBManager struct { + basicTokenManager + db database.DB + issuer string + apiHmacKey []byte +} + +var _ api.TokenManager = (*DBManager)(nil) + +func NewDBManager(issuer string, apiHmacKey []byte, db database.DB) *DBManager { + m := &DBManager{ + db: db, + issuer: issuer, + apiHmacKey: apiHmacKey, + } + m.basicTokenManager.impl = m + return m +} + +func (m *DBManager) Issuer() string { + return m.issuer +} + +func (m *DBManager) HmacKey() []byte { + return m.apiHmacKey +} + +func (m *DBManager) AddJTI(id string, expire time.Time) error { + return m.db.AddJTI(id, expire) +} + +func (m *DBManager) ValidJTI(id string) bool { + ok, _ := m.db.ValidJTI(id) + return ok +} + +func (m *DBManager) InvalidToken(id string) error { + return m.db.RemoveJTI(id) +} diff --git a/utils/crypto.go b/utils/crypto.go index 40f70846..d1fe2bcf 100644 --- a/utils/crypto.go +++ b/utils/crypto.go @@ -81,8 +81,8 @@ func GenRandB64(n int) (s string, err error) { return } -func LoadOrCreateHmacKey(dataDir string) (key []byte, err error) { - path := filepath.Join(dataDir, "server.hmac.private_key") +func LoadOrCreateHmacKey(dataDir string, name string) (key []byte, err error) { + path := filepath.Join(dataDir, name + ".hmac.private_key") buf, err := os.ReadFile(path) if err != nil { if !errors.Is(err, os.ErrNotExist) { From f2a883d572c6e17517df14a508a0d587fa772b57 Mon Sep 17 00:00:00 2001 From: zyxkad Date: Sat, 17 Aug 2024 16:20:59 -0600 Subject: [PATCH 27/36] implemented singleUserManager --- api/errors.go | 30 +++++++++++++++++ api/subscription.go | 4 --- database/db.go | 7 ---- database/memory.go | 52 +++++++++++++++--------------- database/sql.go | 78 ++++++++++++++++++++++++++++++++++----------- runner.go | 71 ++++++++++++++++++++++++++++++++++++++++- 6 files changed, 185 insertions(+), 57 deletions(-) create mode 100644 api/errors.go diff --git a/api/errors.go b/api/errors.go new file mode 100644 index 00000000..3b0214d9 --- /dev/null +++ b/api/errors.go @@ -0,0 +1,30 @@ +/** + * OpenBmclAPI (Golang Edition) + * Copyright (C) 2024 Kevin Z + * All rights reserved + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published + * by the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package api + +import ( + "errors" +) + +var ( + ErrStopIter = errors.New("stop iteration") + ErrNotFound = errors.New("Item not found") + ErrExist = errors.New("Item is already exist") +) diff --git a/api/subscription.go b/api/subscription.go index 7b88ad98..2749c895 100644 --- a/api/subscription.go +++ b/api/subscription.go @@ -32,10 +32,6 @@ import ( "github.com/LiterMC/go-openbmclapi/utils" ) -var ( - ErrNotFound = errors.New("Item not found") -) - type SubscriptionManager interface { GetWebPushKey() string diff --git a/database/db.go b/database/db.go index 8bca3464..b2fd0039 100644 --- a/database/db.go +++ b/database/db.go @@ -20,7 +20,6 @@ package database import ( - "errors" "time" "github.com/google/uuid" @@ -28,12 +27,6 @@ import ( "github.com/LiterMC/go-openbmclapi/api" ) -var ( - ErrStopIter = errors.New("stop iteration") - ErrNotFound = errors.New("no record was found") - ErrExists = errors.New("record's key was already exists") -) - type DB interface { // Cleanup will release any release that the database created // No operation should be executed during or after cleanup diff --git a/database/memory.go b/database/memory.go index 8169cdc6..394e9f1c 100644 --- a/database/memory.go +++ b/database/memory.go @@ -73,7 +73,7 @@ func (m *MemoryDB) ValidJTI(jti string) (bool, error) { expire, ok := m.tokens[jti] if !ok { - return false, ErrNotFound + return false, api.ErrNotFound } if time.Now().After(expire) { return false, nil @@ -85,7 +85,7 @@ func (m *MemoryDB) AddJTI(jti string, expire time.Time) error { m.tokenMux.Lock() defer m.tokenMux.Unlock() if _, ok := m.tokens[jti]; ok { - return ErrExists + return api.ErrExist } m.tokens[jti] = expire return nil @@ -96,13 +96,13 @@ func (m *MemoryDB) RemoveJTI(jti string) error { _, ok := m.tokens[jti] m.tokenMux.RUnlock() if !ok { - return ErrNotFound + return api.ErrNotFound } m.tokenMux.Lock() defer m.tokenMux.Unlock() if _, ok := m.tokens[jti]; !ok { - return ErrNotFound + return api.ErrNotFound } delete(m.tokens, jti) return nil @@ -114,7 +114,7 @@ func (m *MemoryDB) GetFileRecord(path string) (*FileRecord, error) { record, ok := m.fileRecords[path] if !ok { - return nil, ErrNotFound + return nil, api.ErrNotFound } return record, nil } @@ -136,7 +136,7 @@ func (m *MemoryDB) RemoveFileRecord(path string) error { defer m.fileRecMux.Unlock() if _, ok := m.fileRecords[path]; !ok { - return ErrNotFound + return api.ErrNotFound } delete(m.fileRecords, path) return nil @@ -148,7 +148,7 @@ func (m *MemoryDB) ForEachFileRecord(cb func(*FileRecord) error) error { for _, v := range m.fileRecords { if err := cb(v); err != nil { - if err == ErrStopIter { + if err == api.ErrStopIter { break } return err @@ -163,7 +163,7 @@ func (m *MemoryDB) GetSubscribe(user string, client string) (*api.SubscribeRecor record, ok := m.subscribeRecords[[2]string{user, client}] if !ok { - return nil, ErrNotFound + return nil, api.ErrNotFound } return record, nil } @@ -176,7 +176,7 @@ func (m *MemoryDB) SetSubscribe(record api.SubscribeRecord) error { if record.EndPoint == "" { old, ok := m.subscribeRecords[key] if !ok { - return ErrNotFound + return api.ErrNotFound } record.EndPoint = old.EndPoint } @@ -191,7 +191,7 @@ func (m *MemoryDB) RemoveSubscribe(user string, client string) error { key := [2]string{user, client} _, ok := m.subscribeRecords[key] if !ok { - return ErrNotFound + return api.ErrNotFound } delete(m.subscribeRecords, key) return nil @@ -203,7 +203,7 @@ func (m *MemoryDB) ForEachSubscribe(cb func(*api.SubscribeRecord) error) error { for _, v := range m.subscribeRecords { if err := cb(v); err != nil { - if err == ErrStopIter { + if err == api.ErrStopIter { break } return err @@ -218,7 +218,7 @@ func (m *MemoryDB) GetEmailSubscription(user string, addr string) (*api.EmailSub record, ok := m.emailSubRecords[[2]string{user, addr}] if !ok { - return nil, ErrNotFound + return nil, api.ErrNotFound } return record, nil } @@ -229,7 +229,7 @@ func (m *MemoryDB) AddEmailSubscription(record api.EmailSubscriptionRecord) erro key := [2]string{record.User, record.Addr} if _, ok := m.emailSubRecords[key]; ok { - return ErrExists + return api.ErrExist } m.emailSubRecords[key] = &record return nil @@ -242,7 +242,7 @@ func (m *MemoryDB) UpdateEmailSubscription(record api.EmailSubscriptionRecord) e key := [2]string{record.User, record.Addr} old, ok := m.emailSubRecords[key] if ok { - return ErrNotFound + return api.ErrNotFound } _ = old m.emailSubRecords[key] = &record @@ -255,7 +255,7 @@ func (m *MemoryDB) RemoveEmailSubscription(user string, addr string) error { key := [2]string{user, addr} if _, ok := m.emailSubRecords[key]; ok { - return ErrNotFound + return api.ErrNotFound } delete(m.emailSubRecords, key) return nil @@ -267,7 +267,7 @@ func (m *MemoryDB) ForEachEmailSubscription(cb func(*api.EmailSubscriptionRecord for _, v := range m.emailSubRecords { if err := cb(v); err != nil { - if err == ErrStopIter { + if err == api.ErrStopIter { break } return err @@ -285,7 +285,7 @@ func (m *MemoryDB) ForEachUsersEmailSubscription(user string, cb func(*api.Email continue } if err := cb(v); err != nil { - if err == ErrStopIter { + if err == api.ErrStopIter { break } return err @@ -303,7 +303,7 @@ func (m *MemoryDB) ForEachEnabledEmailSubscription(cb func(*api.EmailSubscriptio continue } if err := cb(v); err != nil { - if err == ErrStopIter { + if err == api.ErrStopIter { break } return err @@ -318,7 +318,7 @@ func (m *MemoryDB) GetWebhook(user string, id uuid.UUID) (*api.WebhookRecord, er record, ok := m.webhookRecords[webhookMemKey{user, id}] if !ok { - return nil, ErrNotFound + return nil, api.ErrNotFound } return record, nil } @@ -338,7 +338,7 @@ func (m *MemoryDB) AddWebhook(record api.WebhookRecord) (err error) { key := webhookMemKey{record.User, record.Id} if _, ok := m.webhookRecords[key]; ok { - return ErrExists + return api.ErrExist } if record.Auth == nil { record.Auth = emptyStrPtr @@ -357,7 +357,7 @@ func (m *MemoryDB) UpdateWebhook(record api.WebhookRecord) error { key := webhookMemKey{record.User, record.Id} old, ok := m.webhookRecords[key] if ok { - return ErrNotFound + return api.ErrNotFound } if record.Auth == nil { record.Auth = old.Auth @@ -376,7 +376,7 @@ func (m *MemoryDB) UpdateEnableWebhook(user string, id uuid.UUID, enabled bool) key := webhookMemKey{user, id} old, ok := m.webhookRecords[key] if ok { - return ErrNotFound + return api.ErrNotFound } record := *old record.Enabled = enabled @@ -390,7 +390,7 @@ func (m *MemoryDB) RemoveWebhook(user string, id uuid.UUID) error { key := webhookMemKey{user, id} if _, ok := m.webhookRecords[key]; ok { - return ErrNotFound + return api.ErrNotFound } delete(m.webhookRecords, key) return nil @@ -402,7 +402,7 @@ func (m *MemoryDB) ForEachWebhook(cb func(*api.WebhookRecord) error) error { for _, v := range m.webhookRecords { if err := cb(v); err != nil { - if err == ErrStopIter { + if err == api.ErrStopIter { break } return err @@ -420,7 +420,7 @@ func (m *MemoryDB) ForEachUsersWebhook(user string, cb func(*api.WebhookRecord) continue } if err := cb(v); err != nil { - if err == ErrStopIter { + if err == api.ErrStopIter { break } return err @@ -438,7 +438,7 @@ func (m *MemoryDB) ForEachEnabledWebhook(cb func(*api.WebhookRecord) error) erro continue } if err := cb(v); err != nil { - if err == ErrStopIter { + if err == api.ErrStopIter { break } return err diff --git a/database/sql.go b/database/sql.go index aa8ca576..a7df6d55 100644 --- a/database/sql.go +++ b/database/sql.go @@ -283,7 +283,7 @@ func (db *SqlDB) RemoveJTI(jti string) (err error) { if _, err = db.jtiStmts.remove.ExecContext(ctx, jti); err != nil { if err == sql.ErrNoRows { - err = ErrNotFound + err = api.ErrNotFound } return } @@ -411,7 +411,7 @@ func (db *SqlDB) GetFileRecord(path string) (rec *FileRecord, err error) { rec.Path = path if err = db.fileRecordStmts.get.QueryRowContext(ctx, &rec.Path).Scan(&rec.Hash, &rec.Size); err != nil { if err == sql.ErrNoRows { - err = ErrNotFound + err = api.ErrNotFound } return } @@ -436,7 +436,7 @@ func (db *SqlDB) RemoveFileRecord(path string) (err error) { if _, err = db.fileRecordStmts.remove.ExecContext(ctx, path); err != nil { if err == sql.ErrNoRows { - err = ErrNotFound + err = api.ErrNotFound } return } @@ -457,7 +457,12 @@ func (db *SqlDB) ForEachFileRecord(cb func(*FileRecord) error) (err error) { if err = rows.Scan(&rec.Path, &rec.Hash, &rec.Size); err != nil { return } - cb(&rec) + if err = cb(&rec); err != nil { + if err == api.ErrStopIter { + return nil + } + return + } } if err = rows.Err(); err != nil { return @@ -623,7 +628,7 @@ func (db *SqlDB) GetSubscribe(user string, client string) (rec *api.SubscribeRec rec.Client = client if err = db.subscribeStmts.get.QueryRowContext(ctx, user, client).Scan(&rec.EndPoint, &rec.Keys, &rec.Scopes, &rec.ReportAt); err != nil { if err == sql.ErrNoRows { - err = ErrNotFound + err = api.ErrNotFound } return } @@ -661,14 +666,14 @@ func (db *SqlDB) SetSubscribe(rec api.SubscribeRecord) (err error) { } else if rec.LastReport.Valid { if _, err = tx.Stmt(db.subscribeStmts.setUpdateLastReportOnly).Exec(rec.LastReport, rec.User, rec.Client); err != nil { if err == sql.ErrNoRows { - err = ErrNotFound + err = api.ErrNotFound } return } } else { if _, err = tx.Stmt(db.subscribeStmts.setUpdateScopesOnly).Exec(rec.Scopes, rec.ReportAt, rec.User, rec.Client); err != nil { if err == sql.ErrNoRows { - err = ErrNotFound + err = api.ErrNotFound } return } @@ -685,7 +690,7 @@ func (db *SqlDB) RemoveSubscribe(user string, client string) (err error) { if _, err = db.subscribeStmts.remove.ExecContext(ctx, user, client); err != nil { if err == sql.ErrNoRows { - err = ErrNotFound + err = api.ErrNotFound } return } @@ -706,7 +711,12 @@ func (db *SqlDB) ForEachSubscribe(cb func(*api.SubscribeRecord) error) (err erro if err = rows.Scan(&rec.User, &rec.Client, &rec.EndPoint, &rec.Keys, &rec.Scopes, &rec.ReportAt, &rec.LastReport); err != nil { return } - cb(&rec) + if err = cb(&rec); err != nil { + if err == api.ErrStopIter { + return nil + } + return + } } if err = rows.Err(); err != nil { return @@ -866,7 +876,7 @@ func (db *SqlDB) GetEmailSubscription(user string, addr string) (rec *api.EmailS rec.Addr = addr if err = db.emailSubscriptionStmts.get.QueryRowContext(ctx, user, addr).Scan(&rec.Scopes, &rec.Enabled); err != nil { if err == sql.ErrNoRows { - err = ErrNotFound + err = api.ErrNotFound } return } @@ -909,7 +919,7 @@ func (db *SqlDB) RemoveEmailSubscription(user string, addr string) (err error) { if _, err = db.emailSubscriptionStmts.remove.ExecContext(ctx, user, addr); err != nil { if err == sql.ErrNoRows { - err = ErrNotFound + err = api.ErrNotFound } return } @@ -930,7 +940,12 @@ func (db *SqlDB) ForEachEmailSubscription(cb func(*api.EmailSubscriptionRecord) if err = rows.Scan(&rec.User, &rec.Addr, &rec.Scopes, &rec.Enabled); err != nil { return } - cb(&rec) + if err = cb(&rec); err != nil { + if err == api.ErrStopIter { + return nil + } + return + } } if err = rows.Err(); err != nil { return @@ -953,7 +968,12 @@ func (db *SqlDB) ForEachUsersEmailSubscription(user string, cb func(*api.EmailSu if err = rows.Scan(&rec.Addr, &rec.Scopes, &rec.Enabled); err != nil { return } - cb(&rec) + if err = cb(&rec); err != nil { + if err == api.ErrStopIter { + return nil + } + return + } } if err = rows.Err(); err != nil { return @@ -975,7 +995,12 @@ func (db *SqlDB) ForEachEnabledEmailSubscription(cb func(*api.EmailSubscriptionR if err = rows.Scan(&rec.User, &rec.Addr, &rec.Scopes); err != nil { return } - cb(&rec) + if err = cb(&rec); err != nil { + if err == api.ErrStopIter { + return nil + } + return + } } if err = rows.Err(); err != nil { return @@ -1153,7 +1178,7 @@ func (db *SqlDB) GetWebhook(user string, id uuid.UUID) (rec *api.WebhookRecord, rec.Id = id if err = db.webhookStmts.get.QueryRowContext(ctx, user, hex.EncodeToString(id[:])).Scan(&rec.Name, &rec.EndPoint, &rec.Auth, &rec.Scopes, &rec.Enabled); err != nil { if err == sql.ErrNoRows { - err = ErrNotFound + err = api.ErrNotFound } return } @@ -1205,7 +1230,7 @@ func (db *SqlDB) RemoveWebhook(user string, id uuid.UUID) (err error) { if _, err = db.webhookStmts.remove.ExecContext(ctx, user, hex.EncodeToString(id[:])); err != nil { if err == sql.ErrNoRows { - err = ErrNotFound + err = api.ErrNotFound } return } @@ -1226,7 +1251,12 @@ func (db *SqlDB) ForEachWebhook(cb func(*api.WebhookRecord) error) (err error) { if err = rows.Scan(&rec.User, &rec.Id, &rec.Name, &rec.EndPoint, &rec.Auth, &rec.Scopes, &rec.Enabled); err != nil { return } - cb(&rec) + if err = cb(&rec); err != nil { + if err == api.ErrStopIter { + return nil + } + return + } } if err = rows.Err(); err != nil { return @@ -1249,7 +1279,12 @@ func (db *SqlDB) ForEachUsersWebhook(user string, cb func(*api.WebhookRecord) er if err = rows.Scan(&rec.Id, &rec.Name, &rec.EndPoint, &rec.Auth, &rec.Scopes, &rec.Enabled, &rec.User); err != nil { return } - cb(&rec) + if err = cb(&rec); err != nil { + if err == api.ErrStopIter { + return nil + } + return + } } if err = rows.Err(); err != nil { return @@ -1271,7 +1306,12 @@ func (db *SqlDB) ForEachEnabledWebhook(cb func(*api.WebhookRecord) error) (err e if err = rows.Scan(&rec.User, &rec.Id, &rec.Name, &rec.EndPoint, &rec.Auth, &rec.Scopes); err != nil { return } - cb(&rec) + if err = cb(&rec); err != nil { + if err == api.ErrStopIter { + return nil + } + return + } } if err = rows.Err(); err != nil { return diff --git a/runner.go b/runner.go index 41251c2a..9d609799 100644 --- a/runner.go +++ b/runner.go @@ -111,7 +111,13 @@ func NewRunner() *Runner { } } - // r.userManager = + r.userManager = &singleUserManager{ + user: &api.User{ + Username: r.Config.Dashboard.Username, + Password: r.Config.Dashboard.Password, + Permissions: api.RootPerm, + }, + } if apiHMACKey, err := utils.LoadOrCreateHmacKey(dataDir, "server"); err != nil { log.Errorf("Cannot load HMAC key: %v", err) os.Exit(1) @@ -718,3 +724,66 @@ type subscriptionManager struct { func (s *subscriptionManager) GetWebPushKey() string { return base64.RawURLEncoding.EncodeToString(s.webpushPlg.GetPublicKey()) } + +type singleUserManager struct { + user *api.User +} + +func (m *singleUserManager) GetUsers() []*api.User { + return []*api.User{m.user} +} + +func (m *singleUserManager) GetUser(id string) *api.User { + if id == m.user.Username { + return m.user + } + return nil +} + +func (m *singleUserManager) AddUser(user *api.User) error { + if user.Username == m.user.Username { + return api.ErrExist + } + return errors.New("Not implemented") +} + +func (m *singleUserManager) RemoveUser(id string) error { + if id != m.user.Username { + return api.ErrNotFound + } + return errors.New("Not implemented") +} + +func (m *singleUserManager) ForEachUser(cb func(*api.User) error) error { + err := cb(m.user) + if err == api.ErrStopIter { + return nil + } + return err +} + +func (m *singleUserManager) UpdateUserPassword(username string, password string) error { + if username != m.user.Username { + return api.ErrNotFound + } + m.user.Password = password + return nil +} + +func (m *singleUserManager) UpdateUserPermissions(username string, permissions api.PermissionFlag) error { + if username != m.user.Username { + return api.ErrNotFound + } + m.user.Permissions = permissions + return nil +} + +func (m *singleUserManager) VerifyUserPassword(userId string, comparator func(password string) bool) error { + if userId != m.user.Username { + return errors.New("Username or password is incorrect") + } + if !comparator(m.user.Password) { + return errors.New("Username or password is incorrect") + } + return nil +} From a09e758adbc080b2ca02692cbce344561f404f18 Mon Sep 17 00:00:00 2001 From: zyxkad Date: Sat, 17 Aug 2024 16:27:44 -0600 Subject: [PATCH 28/36] run go fmt --- utils/crypto.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/crypto.go b/utils/crypto.go index d1fe2bcf..75895169 100644 --- a/utils/crypto.go +++ b/utils/crypto.go @@ -82,7 +82,7 @@ func GenRandB64(n int) (s string, err error) { } func LoadOrCreateHmacKey(dataDir string, name string) (key []byte, err error) { - path := filepath.Join(dataDir, name + ".hmac.private_key") + path := filepath.Join(dataDir, name+".hmac.private_key") buf, err := os.ReadFile(path) if err != nil { if !errors.Is(err, os.ErrNotExist) { From 29265d048fbe6238c21d0b3873d6d1a97c09a554 Mon Sep 17 00:00:00 2001 From: zyxkad Date: Sat, 17 Aug 2024 16:37:16 -0600 Subject: [PATCH 29/36] bump go version to 1.23 --- Dockerfile | 2 +- go.mod | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Dockerfile b/Dockerfile index 5d014231..e2ea2a9b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -14,7 +14,7 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . -ARG GO_VERSION=1.21 +ARG GO_VERSION=1.23 ARG REPO=github.com/LiterMC/go-openbmclapi ARG NPM_DIR=dashboard diff --git a/go.mod b/go.mod index 8d644429..0fe02797 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/LiterMC/go-openbmclapi -go 1.22.0 +go 1.23.0 require ( github.com/LiterMC/socket.io v0.2.5 From d436d88ac7ac735073aaf0699c1217af15122f0d Mon Sep 17 00:00:00 2001 From: zyxkad Date: Sat, 17 Aug 2024 16:47:16 -0600 Subject: [PATCH 30/36] use report API --- cluster/requests.go | 4 ++-- cluster/storage.go | 23 ++++++++++++++++++----- internal/build/version.go | 2 +- utils/http.go | 6 ++++++ 4 files changed, 27 insertions(+), 8 deletions(-) diff --git a/cluster/requests.go b/cluster/requests.go index 51209b0f..8ba81572 100644 --- a/cluster/requests.go +++ b/cluster/requests.go @@ -260,13 +260,13 @@ func (cr *Cluster) RequestCert(ctx context.Context) (ckp *CertKeyPair, err error return } -func (cr *Cluster) ReportDownload(ctx context.Context, request *http.Request, err error) error { +func (cr *Cluster) ReportDownload(ctx context.Context, response *http.Response, err error) error { type ReportPayload struct { Urls []string `json:"urls"` Error utils.EmbedJSON[struct{ Message string }] `json:"error"` } var payload ReportPayload - redirects := utils.GetRedirects(request) + redirects := utils.GetRedirects(response.Request) payload.Urls = make([]string, len(redirects)) for i, u := range redirects { payload.Urls[i] = u.String() diff --git a/cluster/storage.go b/cluster/storage.go index 1f1d7697..04c9cd87 100644 --- a/cluster/storage.go +++ b/cluster/storage.go @@ -75,7 +75,8 @@ type FileInfo struct { type RequestPath struct { *http.Request - Path string + Cluster *Cluster + Path string } type StorageFileInfo struct { @@ -153,7 +154,11 @@ func (cr *Cluster) GetFileList(ctx context.Context, fileMap map[string]*StorageF if err != nil { return err } - ff.URLs[req.URL.String()] = RequestPath{Request: req, Path: f.Path} + ff.URLs[req.URL.String()] = RequestPath{ + Request: req, + Cluster: cr, + Path: f.Path, + } fileMap[f.Hash] = ff } } @@ -548,9 +553,9 @@ func (c *HTTPClient) fetchFile(ctx context.Context, stats *syncStats, f *Storage } reqInd := 0 - reqs := make([]*http.Request, 0, len(f.URLs)) + reqs := make([]RequestPath, 0, len(f.URLs)) for _, rq := range f.URLs { - reqs = append(reqs, rq.Request) + reqs = append(reqs, rq) } fileRes := make(chan *os.File, 1) @@ -610,10 +615,18 @@ func (c *HTTPClient) fetchFile(ctx context.Context, stats *syncStats, f *Storage if _, err := fd.Seek(io.SeekStart, 0); err != nil { return err } - if err := c.fetchFileWithBuf(ctx, reqs[reqInd], f.Size, hashMethod, f.Hash, fd, buf, func(r io.Reader) io.Reader { + rp := reqs[reqInd] + if err := c.fetchFileWithBuf(ctx, rp.Request, f.Size, hashMethod, f.Hash, fd, buf, func(r io.Reader) io.Reader { return utils.ProxyPBReader(r, bar, stats.totalBar, &stats.lastInc) }); err != nil { reqInd = (reqInd + 1) % len(reqs) + if rerr, ok := err.(*utils.RedirectError); ok { + go func() { + if err := rp.Cluster.ReportDownload(context.WithoutCancel(ctx), rerr.GetResponse(), rerr.Unwrap()); err != nil { + log.Warnf("Report API error: %v", err) + } + }() + } return err } return nil diff --git a/internal/build/version.go b/internal/build/version.go index 2b8613ad..d0d06852 100644 --- a/internal/build/version.go +++ b/internal/build/version.go @@ -23,7 +23,7 @@ import ( "fmt" ) -const ClusterVersion = "1.10.9" +const ClusterVersion = "1.11.0" var BuildVersion string = "dev" diff --git a/utils/http.go b/utils/http.go index 600d1866..cd5b3863 100644 --- a/utils/http.go +++ b/utils/http.go @@ -525,12 +525,14 @@ func GetRedirects(req *http.Request) []*url.URL { type RedirectError struct { Redirects []*url.URL + Response *http.Response Err error } func ErrorFromRedirect(err error, resp *http.Response) *RedirectError { return &RedirectError{ Redirects: GetRedirects(resp.Request), + Response: resp, Err: err, } } @@ -551,6 +553,10 @@ func (e *RedirectError) Error() string { return b.String() } +func (e *RedirectError) GetResponse() *http.Response { + return e.Response +} + func (e *RedirectError) Unwrap() error { return e.Err } From 8698d11a3d05bec7bc3a6132d82d2970b0bb149d Mon Sep 17 00:00:00 2001 From: zyxkad Date: Sat, 17 Aug 2024 16:55:32 -0600 Subject: [PATCH 31/36] fix nil pointer config --- main.go | 8 +++----- runner.go | 3 ++- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/main.go b/main.go index e8ac129c..0fdae27d 100644 --- a/main.go +++ b/main.go @@ -98,17 +98,15 @@ func main() { defer log.RecordPanic() log.StartFlushLogFile() - r := NewRunner() - ctx, cancel := context.WithCancel(context.Background()) - if config, err := readAndRewriteConfig(); err != nil { + config, err := readAndRewriteConfig() + if err != nil { log.Errorf("Config error: %s", err) os.Exit(1) - } else { - r.Config = config } + r := NewRunner(cfg) r.SetupLogger(ctx) log.TrInfof("program.starting", build.ClusterVersion, build.BuildVersion) diff --git a/runner.go b/runner.go index 9d609799..58cfafe5 100644 --- a/runner.go +++ b/runner.go @@ -93,9 +93,10 @@ type Runner struct { tunnelCancel context.CancelFunc } -func NewRunner() *Runner { +func NewRunner(cfg *config.Config) *Runner { r := new(Runner) + r.Config = cfg r.configHandler = &ConfigHandler{r: r} var dialer *net.Dialer From 774a388a9ab8a559b33c05b3934753739479bc69 Mon Sep 17 00:00:00 2001 From: zyxkad Date: Sat, 17 Aug 2024 16:56:13 -0600 Subject: [PATCH 32/36] fix typo --- main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main.go b/main.go index 0fdae27d..14ce0549 100644 --- a/main.go +++ b/main.go @@ -100,7 +100,7 @@ func main() { ctx, cancel := context.WithCancel(context.Background()) - config, err := readAndRewriteConfig() + cfg, err := readAndRewriteConfig() if err != nil { log.Errorf("Config error: %s", err) os.Exit(1) From 9308d58cef9d6b2eee0b0383131c3bf32611b87f Mon Sep 17 00:00:00 2001 From: zyxkad Date: Sat, 17 Aug 2024 17:00:54 -0600 Subject: [PATCH 33/36] fix http client --- cluster/cluster.go | 6 +++--- cluster/requests.go | 2 +- cluster/storage.go | 2 +- config.go | 2 +- runner.go | 2 +- 5 files changed, 7 insertions(+), 7 deletions(-) diff --git a/cluster/cluster.go b/cluster/cluster.go index f20416bd..e2347935 100644 --- a/cluster/cluster.go +++ b/cluster/cluster.go @@ -23,7 +23,6 @@ import ( "context" "errors" "fmt" - "net/http" "regexp" "runtime" "strings" @@ -61,8 +60,7 @@ type Cluster struct { status atomic.Int32 socketStatus atomic.Int32 socket *socket.Socket - client *http.Client - cachedCli *http.Client + client *HTTPClient authTokenMux sync.RWMutex authToken *ClusterToken @@ -73,6 +71,7 @@ func NewCluster( name string, opts config.ClusterOptions, gcfg config.ClusterGeneralConfig, storageManager *storage.Manager, statManager *StatManager, + client *HTTPClient, ) (cr *Cluster) { storages := make([]int, len(opts.Storages)) for i, name := range opts.Storages { @@ -85,6 +84,7 @@ func NewCluster( storageManager: storageManager, storages: storages, statManager: statManager, + client: client, } return } diff --git a/cluster/requests.go b/cluster/requests.go index 8ba81572..83c6403f 100644 --- a/cluster/requests.go +++ b/cluster/requests.go @@ -209,7 +209,7 @@ func (cr *Cluster) GetConfig(ctx context.Context) (cfg *OpenbmclapiAgentConfig, if err != nil { return } - res, err := cr.cachedCli.Do(req) + res, err := cr.client.DoUseCache(req) if err != nil { return } diff --git a/cluster/storage.go b/cluster/storage.go index 04c9cd87..7a234b81 100644 --- a/cluster/storage.go +++ b/cluster/storage.go @@ -101,7 +101,7 @@ func (cr *Cluster) GetFileList(ctx context.Context, fileMap map[string]*StorageF if err != nil { return err } - res, err := cr.cachedCli.Do(req) + res, err := cr.client.DoUseCache(req) if err != nil { return err } diff --git a/config.go b/config.go index 1c4540af..29b5cd49 100644 --- a/config.go +++ b/config.go @@ -56,7 +56,7 @@ func migrateConfig(data []byte, cfg *config.Config) { if v, ok := oldConfig["keepalive-timeout"].(int); ok { cfg.Advanced.KeepaliveTimeout = v } - if oldConfig["clusters"].(map[string]any) == nil { + if oldConfig["clusters"] == nil { id, ok1 := oldConfig["cluster-id"].(string) secret, ok2 := oldConfig["cluster-secret"].(string) publicHost, ok3 := oldConfig["public-host"].(string) diff --git a/runner.go b/runner.go index 58cfafe5..d2d67bd9 100644 --- a/runner.go +++ b/runner.go @@ -225,7 +225,7 @@ func (r *Runner) InitClusters(ctx context.Context) { r.clusters = make(map[string]*cluster.Cluster) gcfg := r.GetClusterGeneralConfig() for name, opts := range r.Config.Clusters { - cr := cluster.NewCluster(name, opts, gcfg, r.storageManager, r.statManager) + cr := cluster.NewCluster(name, opts, gcfg, r.storageManager, r.statManager, r.client) if err := cr.Init(ctx); err != nil { log.TrErrorf("error.init.failed", err) } else { From 0ec518973c93805c5f87b8fd232c4a517958eaba Mon Sep 17 00:00:00 2001 From: zyxkad Date: Sun, 18 Aug 2024 08:18:36 -0600 Subject: [PATCH 34/36] fix report API, and a few translations --- cluster/cluster.go | 6 ++++++ cluster/requests.go | 14 +++++++++++--- cluster/socket.go | 4 ++-- cluster/storage.go | 27 +++++++-------------------- config.go | 5 ++--- lang/en/us.go | 10 +++++----- lang/zh/cn.go | 10 +++++----- main.go | 3 +++ runner.go | 26 ++++++++------------------ utils/error.go | 2 +- utils/http.go | 18 ++++++++++-------- 11 files changed, 60 insertions(+), 65 deletions(-) diff --git a/cluster/cluster.go b/cluster/cluster.go index e2347935..fa10f638 100644 --- a/cluster/cluster.go +++ b/cluster/cluster.go @@ -42,6 +42,8 @@ var ( reFileHashMismatchError = regexp.MustCompile(` hash mismatch, expected ([0-9a-f]+), got ([0-9a-f]+)`) ) +const DefaultBMCLAPIServer = "https://openbmclapi.bangbang93.com" + type Cluster struct { name string opts config.ClusterOptions @@ -77,7 +79,11 @@ func NewCluster( for i, name := range opts.Storages { storages[i] = storageManager.GetIndex(name) } + if opts.Server == "" { + opts.Server = DefaultBMCLAPIServer + } cr = &Cluster{ + name: name, opts: opts, gcfg: gcfg, diff --git a/cluster/requests.go b/cluster/requests.go index 83c6403f..0899c10e 100644 --- a/cluster/requests.go +++ b/cluster/requests.go @@ -26,6 +26,7 @@ import ( "crypto/hmac" "encoding/hex" "encoding/json" + "errors" "fmt" "net/http" "net/url" @@ -261,12 +262,18 @@ func (cr *Cluster) RequestCert(ctx context.Context) (ckp *CertKeyPair, err error } func (cr *Cluster) ReportDownload(ctx context.Context, response *http.Response, err error) error { + if errors.Is(err, context.Canceled) { + return nil + } + type ReportPayload struct { - Urls []string `json:"urls"` - Error utils.EmbedJSON[struct{ Message string }] `json:"error"` + Urls []string `json:"urls"` + Error utils.EmbedJSON[struct { + Message string `json:"message"` + }] `json:"error"` } var payload ReportPayload - redirects := utils.GetRedirects(response.Request) + redirects := utils.GetRedirects(response) payload.Urls = make([]string, len(redirects)) for i, u := range redirects { payload.Urls[i] = u.String() @@ -280,6 +287,7 @@ func (cr *Cluster) ReportDownload(ctx context.Context, response *http.Response, if err != nil { return err } + req.Header.Set("Content-Type", "application/json") resp, err := cr.client.Do(req) if err != nil { return err diff --git a/cluster/socket.go b/cluster/socket.go index db238397..ccce2bd5 100644 --- a/cluster/socket.go +++ b/cluster/socket.go @@ -68,7 +68,7 @@ func (cr *Cluster) Connect(ctx context.Context) error { }) } engio.OnConnect(func(s *engine.Socket) { - log.Info("Engine.IO %s connected for cluster %s", s.ID(), cr.ID()) + log.Infof("Engine.IO %s connected for cluster %s", s.ID(), cr.ID()) }) engio.OnDisconnect(cr.onDisconnected) engio.OnDialError(func(s *engine.Socket, err *engine.DialErrorContext) { @@ -102,7 +102,7 @@ func (cr *Cluster) Connect(ctx context.Context) error { log.Infof("[remote]: %v", data[0]) } }) - log.Info("Connecting to socket.io namespace") + log.Infof("Cluster %s is connecting to socket.io namespace", cr.Name()) if err := cr.socket.Connect(""); err != nil { return fmt.Errorf("Namespace connect error: %w", err) } diff --git a/cluster/storage.go b/cluster/storage.go index 7a234b81..2750b5d1 100644 --- a/cluster/storage.go +++ b/cluster/storage.go @@ -174,23 +174,6 @@ func storageIdSortFunc(a, b storage.Storage) int { return 1 } -// func SyncFiles(ctx context.Context, manager *storage.Manager, files map[string]*StorageFileInfo, heavyCheck bool) bool { -// log.TrInfof("info.sync.prepare", len(files)) - -// slices.SortFunc(files, func(a, b *StorageFileInfo) int { return a.Size - b.Size }) -// if cr.syncFiles(ctx, files, heavyCheck) != nil { -// return false -// } - -// cr.filesetMux.Lock() -// for _, f := range files { -// cr.fileset[f.Hash] = f.Size -// } -// cr.filesetMux.Unlock() - -// return true -// } - var emptyStr string func checkFile( @@ -393,12 +376,15 @@ func (c *HTTPClient) SyncFiles( return err } - totalFiles := len(files) + totalFiles := len(missingMap) var stats syncStats stats.pg = pg stats.slots = limited.NewBufSlots(slots) stats.totalFiles = totalFiles + for _, f := range missingMap { + stats.totalSize += f.Size + } var barUnit decor.SizeB1024 stats.lastInc.Store(time.Now().UnixNano()) @@ -620,7 +606,8 @@ func (c *HTTPClient) fetchFile(ctx context.Context, stats *syncStats, f *Storage return utils.ProxyPBReader(r, bar, stats.totalBar, &stats.lastInc) }); err != nil { reqInd = (reqInd + 1) % len(reqs) - if rerr, ok := err.(*utils.RedirectError); ok { + var rerr *utils.RedirectError + if errors.As(err, &rerr) { go func() { if err := rp.Cluster.ReportDownload(context.WithoutCancel(ctx), rerr.GetResponse(), rerr.Unwrap()); err != nil { log.Warnf("Report API error: %v", err) @@ -642,7 +629,7 @@ func (c *HTTPClient) fetchFile(ctx context.Context, stats *syncStats, f *Storage bar.SetRefill(bar.Current()) c := tried.Add(1) - if c > maxRetryCount { + if c > maxRetryCount || errors.Is(err, context.Canceled) { log.TrErrorf("error.sync.download.failed", f.Hash, err) stats.failCount.Add(1) return diff --git a/config.go b/config.go index 29b5cd49..02dd529c 100644 --- a/config.go +++ b/config.go @@ -33,14 +33,13 @@ import ( "gopkg.in/yaml.v3" "github.com/LiterMC/go-openbmclapi/api" + "github.com/LiterMC/go-openbmclapi/cluster" "github.com/LiterMC/go-openbmclapi/config" "github.com/LiterMC/go-openbmclapi/log" "github.com/LiterMC/go-openbmclapi/storage" "github.com/LiterMC/go-openbmclapi/utils" ) -const DefaultBMCLAPIServer = "https://openbmclapi.bangbang93.com" - func migrateConfig(data []byte, cfg *config.Config) { var oldConfig map[string]any if err := yaml.Unmarshal(data, &oldConfig); err != nil { @@ -97,7 +96,7 @@ func readAndRewriteConfig() (cfg *config.Config, err error) { Id: "${CLUSTER_ID}", Secret: "${CLUSTER_SECRET}", PublicHosts: []string{}, - Server: DefaultBMCLAPIServer, + Server: cluster.DefaultBMCLAPIServer, SkipSignatureCheck: false, }, } diff --git a/lang/en/us.go b/lang/en/us.go index 7f0df686..95f453ee 100644 --- a/lang/en/us.go +++ b/lang/en/us.go @@ -10,12 +10,12 @@ var areaUS = map[string]string{ "warn.exit.detected.windows.open.browser": "Detected that you are in windows environment, we are helping you to open the browser", "warn.cluster.detected.hash.mismatch": "Detected hash mismatch error, removing bad file %s", - "info.filelist.fetching": "Fetching file list", + "info.filelist.fetching": "Fetching file list for %s", "error.filelist.fetch.failed": "Cannot fetch cluster file list: %v", "error.address.listen.failed": "Cannot listen address %s: %v", - "info.cert.requesting": "Requesting certificates, please wait ...", + "info.cert.requesting": "Requesting certificates for %s, please wait ...", "info.cert.requested": "Requested certificate for %s", "error.cert.not.set": "No certificates was set in the config", "error.cert.parse.failed": "Cannot parse certificate key pair[%d]: %v", @@ -27,7 +27,7 @@ var areaUS = map[string]string{ "info.wait.first.sync": "Waiting for the first sync ...", "info.cluster.enable.sending": "Sending enable packet", "info.cluster.enabled": "Cluster enabled", - "error.cluster.enable.failed": "Cannot enable cluster: %v", + "error.cluster.enable.failed": "Cannot enable cluster %s: %v", "error.cluster.disconnected": "Cluster disconnected from remote. exit.", "info.cluster.reconnect.keepalive": "Reconnecting due to keepalive failed", "info.cluster.reconnecting": "Reconnecting ...", @@ -49,8 +49,8 @@ var areaUS = map[string]string{ "warn.cluster.disabled": "Cluster disabled", "warn.httpserver.closing": "Closing HTTP server ...", - "info.check.start": "Start checking files for %s, heavy = %v", - "info.check.done": "File check finished for %s, missing %d files", + "info.check.start": "Start checking files, heavy = %v", + "info.check.done": "File check finished, missing %d files", "error.check.failed": "Failed to check %s: %v", "hint.check.checking": "> Checking ", "warn.check.modified.size": "Found modified file: size of %q is %d, expect %d", diff --git a/lang/zh/cn.go b/lang/zh/cn.go index 35c359e0..2b62e4d2 100644 --- a/lang/zh/cn.go +++ b/lang/zh/cn.go @@ -10,12 +10,12 @@ var areaCN = map[string]string{ "warn.exit.detected.windows.open.browser": "检测到您是新手 Windows 用户. 我们正在帮助您打开浏览器 ...", "warn.cluster.detected.hash.mismatch": "检测到文件哈希值不匹配, 正在删除 %s", - "info.filelist.fetching": "获取文件列表中", + "info.filelist.fetching": "为 %s 获取文件列表中", "error.filelist.fetch.failed": "文件列表获取失败: %v", "error.address.listen.failed": "无法监听地址 %s: %v", - "info.cert.requesting": "请求证书中, 请稍候 ...", + "info.cert.requesting": "正在为 %s 请求证书, 请稍候 ...", "info.cert.requested": "证书请求完毕, 域名为 %s", "error.cert.not.set": "配置文件内没有提供证书", "error.cert.parse.failed": "无法解析证书密钥对[%d]: %v", @@ -27,7 +27,7 @@ var areaCN = map[string]string{ "info.wait.first.sync": "正在等待第一次同步 ...", "info.cluster.enable.sending": "正在发送启用数据包", "info.cluster.enabled": "节点已启用", - "error.cluster.enable.failed": "无法启用节点: %v", + "error.cluster.enable.failed": "无法启用节点 %s: %v", "error.cluster.disconnected": "节点从主控断开. exit.", "info.cluster.reconnect.keepalive": "保活失败, 重连中 ...", "info.cluster.reconnecting": "重连中 ...", @@ -49,8 +49,8 @@ var areaCN = map[string]string{ "warn.cluster.disabled": "节点已禁用", "warn.httpserver.closing": "正在关闭 HTTP 服务器 ...", - "info.check.start": "开始在 %s 检测文件. 强检查 = %v", - "info.check.done": "文件在 %s 检查完毕, 缺失 %d 个文件", + "info.check.start": "开始检测文件. 强检查 = %v", + "info.check.done": "文件检查完毕, 缺失 %d 个文件", "error.check.failed": "无法检查 %s: %v", "hint.check.checking": "> 检查中 ", "warn.check.modified.size": "找到修改过的文件: %q 的大小为 %d, 预期 %d", diff --git a/main.go b/main.go index 14ce0549..0a869dd3 100644 --- a/main.go +++ b/main.go @@ -167,6 +167,9 @@ func main() { log.TrInfof("info.wait.first.sync") r.InitSynchronizer(ctx) + if ctx.Err() != nil { + return + } r.EnableClusterAll(ctx) }(ctx) diff --git a/runner.go b/runner.go index d2d67bd9..4f89fa5e 100644 --- a/runner.go +++ b/runner.go @@ -158,6 +158,7 @@ func NewRunner(cfg *config.Config) *Runner { log.Errorf("Stat load failed: %v", err) } r.apiRateLimiter = limited.NewAPIRateMiddleWare(api.RealAddrCtxKey, "go-openbmclapi.cluster.logged.user" /* api/v0.loggedUserKey */) + r.certificates = make(map[string]*tls.Certificate) return r } @@ -227,25 +228,11 @@ func (r *Runner) InitClusters(ctx context.Context) { for name, opts := range r.Config.Clusters { cr := cluster.NewCluster(name, opts, gcfg, r.storageManager, r.statManager, r.client) if err := cr.Init(ctx); err != nil { - log.TrErrorf("error.init.failed", err) - } else { - r.clusters[name] = cr + log.TrErrorf("error.init.failed", cr.Name(), err) + continue } + r.clusters[name] = cr } - - // r.cluster = NewCluster(ctx, - // ClusterServerURL, - // baseDir, - // config.PublicHost, r.getPublicPort(), - // config.ClusterId, config.ClusterSecret, - // config.Byoc, dialer, - // config.Storages, - // cache, - // ) - // if err := r.cluster.Init(ctx); err != nil { - // log.TrErrorf("error.init.failed"), err) - // os.Exit(1) - // } } func (r *Runner) ListenSignals(ctx context.Context, cancel context.CancelFunc) int { @@ -363,7 +350,6 @@ func (r *Runner) SetupLogger(ctx context.Context) error { } func (r *Runner) StopServer(ctx context.Context) { - r.tunnelCancel() shutCtx, cancelShut := context.WithTimeout(context.Background(), time.Second*15) defer cancelShut() log.TrWarnf("warn.server.closing") @@ -373,6 +359,7 @@ func (r *Runner) StopServer(ctx context.Context) { defer cancelShut() var wg sync.WaitGroup for _, cr := range r.clusters { + wg.Add(1) go func() { defer wg.Done() cr.Disable(shutCtx) @@ -381,6 +368,9 @@ func (r *Runner) StopServer(ctx context.Context) { wg.Wait() log.TrWarnf("warn.httpserver.closing") r.server.Shutdown(shutCtx) + if r.tunnelCancel != nil { + r.tunnelCancel() + } r.listener.Close() r.listener = nil }() diff --git a/utils/error.go b/utils/error.go index 31d62f76..627450bb 100644 --- a/utils/error.go +++ b/utils/error.go @@ -38,7 +38,7 @@ func NewHTTPStatusErrorFromResponse(res *http.Response) (e *HTTPStatusError) { e.URL = res.Request.URL.String() } if res.Body != nil { - var buf [512]byte + var buf [1024]byte n, _ := res.Body.Read(buf[:]) msg := (string)(buf[:n]) for _, b := range msg { diff --git a/utils/http.go b/utils/http.go index cd5b3863..c4169da8 100644 --- a/utils/http.go +++ b/utils/http.go @@ -22,6 +22,7 @@ package utils import ( "bufio" "bytes" + "context" "crypto/tls" "errors" "io" @@ -506,15 +507,16 @@ func (c *connHeadReader) Read(buf []byte) (n int, err error) { return c.Conn.Read(buf) } -func GetRedirects(req *http.Request) []*url.URL { +func GetRedirects(resp *http.Response) []*url.URL { redirects := make([]*url.URL, 0, 5) - for req != nil { - redirects = append(redirects, req.URL) - resp := req.Response - if resp == nil { - break - } + if u, _ := resp.Location(); u != nil { + redirects = append(redirects, u) + } + var req *http.Request + for resp != nil { req = resp.Request + redirects = append(redirects, req.URL) + resp = req.Response } if len(redirects) == 0 { return nil @@ -531,7 +533,7 @@ type RedirectError struct { func ErrorFromRedirect(err error, resp *http.Response) *RedirectError { return &RedirectError{ - Redirects: GetRedirects(resp.Request), + Redirects: GetRedirects(resp), Response: resp, Err: err, } From 07bf29d86f7cecdfca2d4c381ab2a9fbe6be3e28 Mon Sep 17 00:00:00 2001 From: zyxkad Date: Sun, 18 Aug 2024 08:20:52 -0600 Subject: [PATCH 35/36] fix unused import --- utils/http.go | 1 - 1 file changed, 1 deletion(-) diff --git a/utils/http.go b/utils/http.go index c4169da8..389b8642 100644 --- a/utils/http.go +++ b/utils/http.go @@ -22,7 +22,6 @@ package utils import ( "bufio" "bytes" - "context" "crypto/tls" "errors" "io" From 62c50ae629256b0d03a8600e5ca4b6c2f31a43ea Mon Sep 17 00:00:00 2001 From: zyxkad Date: Wed, 9 Oct 2024 11:54:36 -0600 Subject: [PATCH 36/36] update config --- config/config.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/config/config.go b/config/config.go index 3654b6cd..c114cc34 100644 --- a/config/config.go +++ b/config/config.go @@ -50,6 +50,7 @@ type Config struct { AccessLogSlots int `yaml:"access-log-slots"` Clusters map[string]ClusterOptions `yaml:"clusters"` + Storages []storage.StorageOption `yaml:"storages"` Certificates []CertificateConfig `yaml:"certificates"` Tunneler TunnelConfig `yaml:"tunneler"` Cache CacheConfig `yaml:"cache"` @@ -60,7 +61,6 @@ type Config struct { GithubAPI GithubAPIConfig `yaml:"github-api"` Database DatabaseConfig `yaml:"database"` Hijack HijackConfig `yaml:"hijack"` - Storages []storage.StorageOption `yaml:"storages"` WebdavUsers map[string]*storage.WebDavUser `yaml:"webdav-users"` Advanced AdvancedConfig `yaml:"advanced"` } @@ -92,6 +92,8 @@ func NewDefaultConfig() *Config { Clusters: map[string]ClusterOptions{}, + Storages: nil, + Certificates: []CertificateConfig{}, Tunneler: TunnelConfig{ @@ -164,8 +166,6 @@ func NewDefaultConfig() *Config { }, }, - Storages: nil, - WebdavUsers: map[string]*storage.WebDavUser{}, Advanced: AdvancedConfig{