diff --git a/.golangci.yml b/.golangci.yml index 81c9d90..499d9c0 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -49,7 +49,7 @@ linters: - gosec # Inspects source code for security problems - misspell # Finds commonly misspelled English words in comments - nakedret # Finds naked returns in functions greater than a specified function length - - nestif # Reports deeply nested if statements + #- nestif # Reports deeply nested if statements - nilerr # Finds the code that returns nil even if it checks that the error is not nil. - noctx # noctx finds sending http request without context.Context - nolintlint # Reports ill-formed or insufficient nolint directives diff --git a/api/response.go b/api/response.go new file mode 100644 index 0000000..33dbcf0 --- /dev/null +++ b/api/response.go @@ -0,0 +1,179 @@ +package api + +import ( + "encoding/json" + "net/http" +) + +const ( + okStatus = "ok" + errStatus = "error" +) + +// Error is a generic error structure that is used to send error responses to the client. +type Error struct { + Code string `json:"code"` + Message string `json:"message"` + Extra interface{} `json:"extra,omitempty"` +} + +// Response is a generic response structure that is used to send responses to the client. +type Response struct { + Status string `json:"status"` + Data interface{} `json:"data,omitempty"` + Error *Error `json:"error,omitempty"` +} + +// NewResponse creates a new response object. +func NewResponse() *Response { + return &Response{ + Status: okStatus, + } +} + +// Error message +func (e *Error) Error() string { + return e.Message +} + +// Set data to response +func (rsp *Response) SetData(data interface{}) *Response { + rsp.Status = okStatus + rsp.Error = nil + rsp.Data = data + + return rsp +} + +// Set error to response +func (rsp *Response) SetError(code string, message string, extra ...interface{}) *Response { + rsp.Status = errStatus + rsp.Data = nil + + var extraData interface{} + if len(extra) > 0 { + extraData = extra[0] // Берем первый переданный аргумент, если он есть + } else { + extraData = nil // Если аргумент не был передан, оставляем nil + } + + rsp.Error = &Error{ + Code: code, + Message: message, + Extra: extraData, + } + + return rsp +} + +// Send success response to client +func (rsp *Response) Ok(w http.ResponseWriter) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + + rsp.Status = "ok" + + _ = json.NewEncoder(w).Encode(rsp) +} + +// Send error response to client +func (rsp *Response) BadRequest(w http.ResponseWriter) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + + rsp.Status = errStatus + + if rsp.Error == nil { + rsp.Error = &Error{ + Code: "bad_request", + Message: "Bad request", + } + } + + _ = json.NewEncoder(w).Encode(rsp) +} + +// Send error response to client +func (rsp *Response) InternalServerError(w http.ResponseWriter) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusInternalServerError) + + rsp.Status = errStatus + + if rsp.Error == nil { + rsp.Error = &Error{ + Code: "internal_server_error", + Message: "Internal server error", + } + } + + _ = json.NewEncoder(w).Encode(rsp) +} + +// Send error response to client +func (rsp *Response) NotFound(w http.ResponseWriter) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusNotFound) + + rsp.Status = errStatus + + if rsp.Error == nil { + rsp.Error = &Error{ + Code: "not_found", + Message: "Not found", + } + } + + _ = json.NewEncoder(w).Encode(rsp) +} + +// Send error response to client +func (rsp *Response) Unauthorized(w http.ResponseWriter) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusUnauthorized) + + rsp.Status = errStatus + + if rsp.Error == nil { + rsp.Error = &Error{ + Code: "unauthorized", + Message: "Unauthorized", + } + } + + _ = json.NewEncoder(w).Encode(rsp) +} + +// Send error response to client +func (rsp *Response) Forbidden(w http.ResponseWriter) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusForbidden) + + rsp.Status = errStatus + + if rsp.Error == nil { + rsp.Error = &Error{ + Code: "forbidden", + Message: "Forbidden", + } + } + + _ = json.NewEncoder(w).Encode(rsp) +} + +// Send error response to client +func (rsp *Response) MethodNotAllowed(w http.ResponseWriter) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusMethodNotAllowed) + + rsp.Status = errStatus + + if rsp.Error == nil { + rsp.Error = &Error{ + Code: "method_not_allowed", + Message: "Method not allowed", + } + } + + _ = json.NewEncoder(w).Encode(rsp) +} diff --git a/cmd/main.go b/cmd/main.go index 40406ed..f178d89 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -2,16 +2,24 @@ package main import ( "context" + "errors" "fmt" logByDefault "log" "log/slog" + "net/http" "os" + "os/signal" + "sync" + "syscall" "time" config "github.com/plugfox/foxy-gram-server/internal/config" + "github.com/plugfox/foxy-gram-server/internal/err" + "github.com/plugfox/foxy-gram-server/internal/global" "github.com/plugfox/foxy-gram-server/internal/httpclient" log "github.com/plugfox/foxy-gram-server/internal/log" "github.com/plugfox/foxy-gram-server/internal/model" + "github.com/plugfox/foxy-gram-server/internal/server" storage "github.com/plugfox/foxy-gram-server/internal/storage" "github.com/plugfox/foxy-gram-server/internal/telegram" @@ -37,7 +45,11 @@ func main() { log.WithSource(), ) - if err := run(config, logger); err != nil { + global.Config = config + global.Logger = logger + + // Run the server + if err := run(); err != nil { logger.ErrorContext(context.Background(), "an error occurred", slog.String("error", err.Error())) os.Exit(1) } @@ -45,11 +57,78 @@ func main() { os.Exit(0) } -func run(config *config.Config, logger *slog.Logger) error { - ctx := context.Background() +// waitExitSignal waits for the SIGINT or SIGTERM signal to shutdown the centrifuge node. +// It creates a channel to receive signals and a channel to indicate when the shutdown is complete. +// Then it notifies the channel for SIGINT and SIGTERM signals and starts a goroutine to wait for the signal. +// Once the signal is received, it shuts down the centrifuge node and indicates that the shutdown is complete. +func waitExitSignal(sigCh chan os.Signal, t *telegram.Telegram, s *server.Server /* n *centrifuge.Node */) { + wg := sync.WaitGroup{} + + // Notify the channel for SIGINT and SIGTERM signals. + signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) + + const timeout = 10 * time.Second + + // Start a goroutine to wait for the signal and handle graceful shutdown. + wg.Add(1) + + go func() { + defer wg.Done() + + // Wait for the signal. + <-sigCh + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + + defer cancel() + + // _ = n.Shutdown(ctx) + + _ = s.Shutdown(ctx) + }() + + // Handle Telegram bot shutdown. + wg.Add(1) + + go func() { + defer wg.Done() + + // Wait for the signal. + <-sigCh + // Create a channel to indicate when the shutdown is complete. + done := make(chan struct{}) + + // Stop the Telegram bot + go func() { + defer close(done) + t.Stop() + }() + + // Ensure the shutdown happens within 10 seconds. + select { + case <-done: // Done + case <-time.After(timeout): // Timeout + } + }() + + // Wait for both goroutines to complete before exiting. + wg.Wait() +} + +// Starts the server and waits for the SIGINT or SIGTERM signal to shutdown the server. +func run() error { + if global.Config == nil || global.Logger == nil { + return err.ErrorGlobalVariablesNotInitialized + } + + ctx, cancel := context.WithCancel(context.Background()) + + defer cancel() + + // Set the maxprocs environment variable in container runtimes. _, err := maxprocs.Set(maxprocs.Logger(func(s string, i ...interface{}) { - logger.DebugContext(ctx, fmt.Sprintf(s, i...)) + global.Logger.DebugContext(ctx, fmt.Sprintf(s, i...)) })) if err != nil { return fmt.Errorf("setting max procs: %w", err) @@ -59,36 +138,114 @@ func run(config *config.Config, logger *slog.Logger) error { model.InitHashFunction() // Setup database connection - db, err := storage.New(config, logger) - if err != nil { - return fmt.Errorf("database connection error: %w", err) - } + db := initStorage() // Create a http client - httpClient, err := httpclient.NewHTTPClient(&config.Proxy) - if err != nil { - return fmt.Errorf("database connection error: %w", err) - } + httpClient := initHTTPClient() // Setup Telegram bot - telegram, err := telegram.New(db, httpClient, config, logger) - if err != nil { - return fmt.Errorf("telegram bot setup error: %w", err) - } + telegram := initTelegram(db, httpClient) + // Update the bot user information if err := db.UpsertUser(telegram.Me().Seen()); err != nil { return fmt.Errorf("upserting user error: %w", err) } - // TODO: Setup API server + // Setup API server + server := initServer() + server.AddHealthCheck( + func() (bool, map[string]string) { + dbStatus, dbErr := db.Status() + srvStatus, srvErr := server.Status() + tgStatus, tgErr := telegram.Status() + + isHealthy := dbErr == nil && srvErr == nil && tgErr == nil + + return isHealthy, map[string]string{ + "database": dbStatus, + "server": srvStatus, + "telegram": tgStatus, + } + }, + ) // Add health check endpoint + server.AddVerifyUsers(db) // Add verify users endpoint [POST] /admin/verify // TODO: Setup Centrifuge server // TODO: Setup InfluxDB metrics (if any) - telegram.Start() + // Create a channel to shutdown the server. + sigCh := make(chan os.Signal, 1) + + // Create a function to stop the server. + // Call this function when the server needs to be closed. + /* stop := func(sigCh chan os.Signal) func() { + return func() { + sigCh <- syscall.SIGTERM // Close server. + } + } */ + + // Log the server start + global.Logger.InfoContext( + ctx, + "Server started", + slog.String("host", global.Config.API.Host), + slog.Int("port", global.Config.API.Port), + ) - logger.InfoContext(ctx, "Server started", slog.String("host", config.API.Host), slog.Int("port", config.API.Port)) + // Wait for the SIGINT or SIGTERM signal to shutdown the server. + waitExitSignal(sigCh, telegram, server) + close(sigCh) return nil } + +// initStorage initializes the database connection. +func initStorage() *storage.Storage { + db, err := storage.New() + if err != nil { + panic(fmt.Sprintf("database connection error: %v", err)) + } + + return db +} + +// Create a new HTTP client +func initHTTPClient() *http.Client { + httpClient, err := httpclient.NewHTTPClient(&global.Config.Proxy) + if err != nil { + panic(fmt.Sprintf("http client error: %v", err)) + } + + return httpClient +} + +// Initialize the Telegram bot +func initTelegram(db *storage.Storage, httpClient *http.Client) *telegram.Telegram { + tg, err := telegram.New(db, httpClient) + if err != nil { + panic(fmt.Sprintf("telegram bot setup error: %v", err)) + } + + // Start the Telegram bot polling + go func() { + tg.Start() + }() + + return tg +} + +// Initialize the API server +func initServer() *server.Server { + srv := server.New() + + // Start the server + go func() { + if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { + global.Logger.Error("Server error", slog.String("error", err.Error())) + os.Exit(1) // Exit the program if the server fails to start. + } + }() + + return srv +} diff --git a/go.mod b/go.mod index ce6cd0f..8b96dbf 100644 --- a/go.mod +++ b/go.mod @@ -16,8 +16,11 @@ require ( require ( filippo.io/edwards25519 v1.1.0 // indirect + github.com/ajg/form v1.5.1 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/glebarez/go-sqlite v1.22.0 // indirect + github.com/go-chi/chi/v5 v5.1.0 // indirect + github.com/go-chi/render v1.0.3 // indirect github.com/go-sql-driver/mysql v1.8.1 // indirect github.com/golang/glog v1.2.2 // indirect github.com/google/uuid v1.6.0 // indirect diff --git a/go.sum b/go.sum index 60069c3..1add6d6 100644 --- a/go.sum +++ b/go.sum @@ -87,6 +87,8 @@ github.com/Masterminds/semver/v3 v3.3.0/go.mod h1:4V+yj/TJE1HU9XfppCwVMZq3I84lpr github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/OpenPeeDeeP/depguard/v2 v2.2.0 h1:vDfG60vDtIuf0MEOhmLlLLSzqaRM8EMcgJPdp74zmpA= github.com/OpenPeeDeeP/depguard/v2 v2.2.0/go.mod h1:CIzddKRvLBC4Au5aYP/i3nyaWQ+ClszLIuVocRiCYFQ= +github.com/ajg/form v1.5.1 h1:t9c7v8JUKu/XxOGBU0yjNpaMloxGEJhUkqFRq0ibGeU= +github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY= github.com/alecthomas/assert/v2 v2.2.2 h1:Z/iVC0xZfWTaFNE6bA3z07T86hd45Xe2eLt6WVy2bbk= github.com/alecthomas/assert/v2 v2.2.2/go.mod h1:pXcQ2Asjp247dahGEmsZ6ru0UVwnkhktn7S0bBDLxvQ= github.com/alecthomas/go-check-sumtype v0.1.4 h1:WCvlB3l5Vq5dZQTFmodqL2g68uHiSwwlWcT5a2FGK0c= @@ -229,6 +231,10 @@ github.com/glebarez/go-sqlite v1.22.0 h1:uAcMJhaA6r3LHMTFgP0SifzgXg46yJkgxqyuyec github.com/glebarez/go-sqlite v1.22.0/go.mod h1:PlBIdHe0+aUEFn+r2/uthrWq4FxbzugL0L8Li6yQJbc= github.com/glebarez/sqlite v1.11.0 h1:wSG0irqzP6VurnMEpFGer5Li19RpIRi2qvQz++w0GMw= github.com/glebarez/sqlite v1.11.0/go.mod h1:h8/o8j5wiAsqSPoWELDUdJXhjAhsVliSn7bWZjOhrgQ= +github.com/go-chi/chi/v5 v5.1.0 h1:acVI1TYaD+hhedDJ3r54HyA6sExp3HfXq7QWEEY/xMw= +github.com/go-chi/chi/v5 v5.1.0/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= +github.com/go-chi/render v1.0.3 h1:AsXqd2a1/INaIfUSKq3G5uA8weYx20FOsM7uSoCyyt4= +github.com/go-chi/render v1.0.3/go.mod h1:/gr3hVkmYR0YlEy3LxCuVRFzEu9Ruok+gFqbIofjao0= github.com/go-critic/go-critic v0.11.4 h1:O7kGOCx0NDIni4czrkRIXTnit0mkyKOCePh3My6OyEU= github.com/go-critic/go-critic v0.11.4/go.mod h1:2QAdo4iuLik5S9YG0rT4wcZ8QxwHYkrr6/2MWAiv/vc= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= diff --git a/internal/config/config.go b/internal/config/config.go index 96a559e..5dd8459 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -51,11 +51,12 @@ type CaptchaConfig struct { // API config. type APIConfig struct { - Host string `env:"API_HOST" env-default:"localhost" env-description:"API host address to bind to" yaml:"host"` - Port int `env:"API_PORT" env-default:"8080" env-description:"API port to bind to" yaml:"port"` - ReadTimeout time.Duration `env:"API_READ_TIMEOUT" env-default:"10s" yaml:"read_timeout"` - WriteTimeout time.Duration `env:"API_WRITE_TIMEOUT" env-default:"10s" yaml:"write_timeout"` - IdleTimeout time.Duration `env:"API_IDLE_TIMEOUT" env-default:"15s" yaml:"idle_timeout"` + Host string `env:"API_HOST" env-default:"" env-description:"API host address to bind to" yaml:"host"` + Port int `env:"API_PORT" env-default:"8080" env-description:"API port to bind to" yaml:"port"` + Timeout time.Duration `env:"API_TIMEOUT" env-default:"15s" yaml:"timeout"` + ReadTimeout time.Duration `env:"API_READ_TIMEOUT" env-default:"10s" yaml:"read_timeout"` + WriteTimeout time.Duration `env:"API_WRITE_TIMEOUT" env-default:"10s" yaml:"write_timeout"` + IdleTimeout time.Duration `env:"API_IDLE_TIMEOUT" env-default:"15s" yaml:"idle_timeout"` } // SQLite / PostgreSQL / MySQL config for GORM dialector. diff --git a/internal/err/err.go b/internal/err/err.go new file mode 100644 index 0000000..133f481 --- /dev/null +++ b/internal/err/err.go @@ -0,0 +1,16 @@ +package err + +import ( + "errors" + "fmt" +) + +var ( + ErrorUnexpectedType = errors.New("unexpected type") // Static error for unexpected type. + ErrorGlobalVariablesNotInitialized = errors.New("global variables not initialized") // Static error for global variables not initialized. +) + +// WrapUnexpectedType wraps the error for unexpected type. +func WrapUnexpectedType(expected string, actual interface{}) error { + return fmt.Errorf("%w: expected %s, got %T", ErrorUnexpectedType, expected, actual) +} diff --git a/internal/errors/errors.go b/internal/errors/errors.go deleted file mode 100644 index 156ef96..0000000 --- a/internal/errors/errors.go +++ /dev/null @@ -1,14 +0,0 @@ -package errors - -import ( - "errors" - "fmt" -) - -// Static error for unexpected type. -var ErrorUnexpectedType = errors.New("unexpected type") - -// WrapUnexpectedType wraps the error for unexpected type. -func WrapUnexpectedType(expected string, actual interface{}) error { - return fmt.Errorf("%w: expected %s, got %T", ErrorUnexpectedType, expected, actual) -} diff --git a/internal/global/global.go b/internal/global/global.go new file mode 100644 index 0000000..96f2a80 --- /dev/null +++ b/internal/global/global.go @@ -0,0 +1,12 @@ +package global + +import ( + slog "log/slog" + + conf "github.com/plugfox/foxy-gram-server/internal/config" +) + +var ( + Logger *slog.Logger //nolint:gochecknoglobals + Config *conf.Config //nolint:gochecknoglobals +) diff --git a/internal/log/log_adapter.go b/internal/log/log_adapter.go index ec450c8..734d88c 100644 --- a/internal/log/log_adapter.go +++ b/internal/log/log_adapter.go @@ -18,3 +18,9 @@ func (a *logAdapter) Write(p []byte) (n int, err error) { return len(p), nil } + +func (a *logAdapter) Print(p []byte) (n int, err error) { + a.slog.Info(string(p)) + + return len(p), nil +} diff --git a/internal/server/routes.go b/internal/server/routes.go new file mode 100644 index 0000000..c3d1fd9 --- /dev/null +++ b/internal/server/routes.go @@ -0,0 +1,47 @@ +package server + +import ( + "fmt" + "net/http" + "strings" + + "github.com/go-chi/render" + "github.com/plugfox/foxy-gram-server/api" +) + +// echo route for testing purposes +func echoRoute(w http.ResponseWriter, r *http.Request) { + // Create a map to hold the request data + var data map[string]any + + // Decode the request body into the data map + if r.ContentLength != 0 { + if strings.Contains(r.Header.Get("Content-Type"), "application/json") { + if err := render.Decode(r, &data); err != nil { + api.NewResponse().SetError("bad_request", err.Error()).BadRequest(w) + + return + } + } else { + msg := fmt.Sprintf("Content-Type: %s", r.Header.Get("Content-Type")) + + api.NewResponse().SetError("bad_request", "Content-Type must be application/json", msg).BadRequest(w) + + return + } + } + + api.NewResponse().SetData(struct { + URL string `json:"url"` + Remote string `json:"remote"` + Method string `json:"method"` + Headers http.Header `json:"headers"` + Body map[string]any `json:"body"` + }{ + URL: r.URL.String(), + Remote: r.RemoteAddr, + Method: r.Method, + Headers: r.Header, + Body: data, + }).Ok(w) +} diff --git a/internal/server/server.go b/internal/server/server.go new file mode 100644 index 0000000..62e4fa0 --- /dev/null +++ b/internal/server/server.go @@ -0,0 +1,265 @@ +package server + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log/slog" + "net/http" + "runtime" + "runtime/debug" + "strings" + "time" + + "github.com/go-chi/chi/v5" + "github.com/go-chi/chi/v5/middleware" + "github.com/plugfox/foxy-gram-server/api" + "github.com/plugfox/foxy-gram-server/internal/global" + "github.com/plugfox/foxy-gram-server/internal/log" + "github.com/plugfox/foxy-gram-server/internal/storage" +) + +type Server struct { + router *chi.Mux + public chi.Router + admin chi.Router + server *http.Server +} + +func New() *Server { // Router for HTTP API and Websocket centrifuge protocol. + middleware.DefaultLogger = middleware.RequestLogger(&middleware.DefaultLogFormatter{Logger: log.NewLogAdapter(global.Logger)}) + router := chi.NewRouter() + /* router.Use(middleware.Recoverer) */ + router.Use(middlewareErrorRecoverer(global.Logger)) + router.Use(middleware.Logger) + router.Use(middleware.RequestID) + router.Use(middleware.RealIP) + router.Use(middleware.URLFormat) + router.Use(middleware.StripSlashes) + router.Use(middleware.RedirectSlashes) + router.Use(middleware.Timeout(global.Config.API.Timeout)) + router.Use(middleware.Heartbeat("/ping")) + + /* + r.Use(middleware.StripSlashes) + r.Use(middleware.Compress(5)) + r.Use(middleware.RedirectSlashes) + r.Use(middleware.RequestLogger(&middleware.DefaultLogFormatter{Logger: log})) + r.Use(middleware.Throttle(100)) + */ + + // Public API group + public := router.Group(func(r chi.Router) { + // Middleware + r.Use(middleware.NoCache) + + // Routes + r.HandleFunc("/echo", echoRoute) + r.HandleFunc("/echo/*", echoRoute) + }) + + // Admin API group + const compressionLevel = 5 + + fs := http.FileServer(http.Dir("./")) // File server + + admin := router.Group(func(r chi.Router) { + // Middleware + r.Use(middlewareAuthorization(global.Config.Secret)) + + // File server + r.Route("/admin", func(r chi.Router) { + r.Route("/files", func(r chi.Router) { + r.Use(middleware.NoCache) + r.Use(middleware.Compress(compressionLevel)) + r.Handle("/*", http.StripPrefix("/admin/files", fs)) + }) + }) + }) + + // Create a new HTTP server + server := &http.Server{ + Addr: fmt.Sprintf("%s:%d", global.Config.API.Host, global.Config.API.Port), + Handler: router, + WriteTimeout: global.Config.API.WriteTimeout, + ReadTimeout: global.Config.API.ReadTimeout, + IdleTimeout: global.Config.API.IdleTimeout, + ErrorLog: log.NewLogAdapter(global.Logger), + } + + return &Server{ + router: router, + public: public, + admin: admin, + server: server, + } +} + +// AddHealthCheck adds a health check endpoint to the server. +// The statusFunc function should return a map of status information. +// The map keys will be used as the status names in the response. +// The map values will be used as the status values in the response. +func (srv *Server) AddHealthCheck(statusFunc func() (bool, map[string]string)) { + const bytesInMb = 1024 * 1024 + + startedAt := time.Now() // Start time + + handler := func(w http.ResponseWriter, _ *http.Request) { + ok, status := statusFunc() + + var memStats runtime.MemStats + + runtime.ReadMemStats(&memStats) + + data := map[string]any{ + "status": status, + "uptime": time.Since(startedAt).String(), + // Allocated memory / Reserved program memory + "memory": fmt.Sprintf("%v Mb / %v Mb", memStats.Alloc/bytesInMb, memStats.Sys/bytesInMb), + "cpu": runtime.NumCPU(), + "goroutines": runtime.NumGoroutine(), + } + + if ok { + api.NewResponse().SetData(data).Ok(w) + } else { + api.NewResponse().SetError("status_error", "One or more services are not healthy", data).InternalServerError(w) + } + } + + srv.public.Get("/health", handler) + srv.public.Get("/status", handler) + srv.public.Get("/healthz", handler) + srv.public.Get("/statusz", handler) + srv.public.Get("/metrics", handler) + srv.public.Get("/info", handler) +} + +func (srv *Server) AddVerifyUsers(db *storage.Storage) { + handler := func(w http.ResponseWriter, r *http.Request) { + var requestBody struct { + IDs []int `json:"ids"` + Reason string `json:"reason,omitempty"` + } + + if err := json.NewDecoder(r.Body).Decode(&requestBody); err != nil { + api.NewResponse().SetError("bad_request", err.Error()).BadRequest(w) + } else if (requestBody.IDs == nil) || (len(requestBody.IDs) == 0) { + api.NewResponse().SetError("bad_request", "IDs are required").BadRequest(w) + } else { + if requestBody.Reason == "" { + requestBody.Reason = "Verified from API" + } + + if err := db.VerifyUsers(requestBody.Reason, requestBody.IDs); err != nil { + api.NewResponse().SetError("internal_server_error", err.Error()).InternalServerError(w) + } else { + api.NewResponse().Ok(w) + } + } + } + + srv.admin.Post("/admin/verify", handler) +} + +// AddPublicRoute adds a public route to the server. +func (srv *Server) AddPublicRoute(method string, path string, handler http.HandlerFunc) { + srv.public.Method(method, path, handler) +} + +// AddAdminRoute adds an admin route to the server. +func (srv *Server) AddAdminRoute(method string, path string, handler http.HandlerFunc) { + srv.admin.Method(method, path, handler) +} + +// Status returns the server status. +func (srv *Server) Status() (string, error) { + return "ok", nil +} + +// ListenAndServe starts the server and listens for incoming requests. +func (srv *Server) ListenAndServe() error { + return srv.server.ListenAndServe() +} + +// Shutdown gracefully shuts down the server without interrupting any active connections. +func (srv *Server) Shutdown(ctx context.Context) error { + return srv.server.Shutdown(ctx) +} + +// Close closes the server immediately. +func (srv *Server) Close() error { + return srv.server.Close() +} + +// middlewareAuthorization is a middleware function that checks the Authorization header for a Bearer token. +func middlewareAuthorization(secret string) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + authHeader := r.Header.Get("Authorization") + + // Check if the Authorization header is missing + if authHeader == "" { + api.NewResponse().SetError("unauthorized", "Authorization header is required").Unauthorized(w) + + return + } + + // Check if the Authorization header is not a Bearer token + token := strings.TrimPrefix(authHeader, "Bearer ") + if token == authHeader { // If the Authorization header is not a Bearer token + api.NewResponse().SetError("unauthorized", "Bearer token is required").Unauthorized(w) + + return + } + + // Check if the Bearer token is invalid + if token != secret { + api.NewResponse().SetError("unauthorized", "Invalid Bearer token").Unauthorized(w) + + return + } + + // Call the next handler + next.ServeHTTP(w, r) + }) + } +} + +// middlewareErrorRecoverer is a middleware function that recovers from panics and returns an error response. +func middlewareErrorRecoverer(logger *slog.Logger) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { + if err := recover(); err != nil { + if e, ok := err.(error); ok { + if errors.Is(e, http.ErrAbortHandler) { + // we don't recover http.ErrAbortHandler so the response + // to the client is aborted, this should not be logged + panic(err) + } + } + + if r.Header.Get("Connection") == "Upgrade" { + return + } + + // Log the error + logger.ErrorContext(context.Background(), "Recovered from panic", slog.String("error", fmt.Sprintf("%v", err))) + + api.NewResponse().SetError("internal_server_error", + "Internal Server Error", + map[string]any{ + "error": fmt.Sprintf("%v", err), + "stack": string(debug.Stack()), + }, + ).InternalServerError(w) + } + }() + + // Call the next handler + next.ServeHTTP(w, r) + }) + } +} diff --git a/internal/storage/storage.go b/internal/storage/storage.go index b83a520..3f7b5a7 100644 --- a/internal/storage/storage.go +++ b/internal/storage/storage.go @@ -11,11 +11,10 @@ import ( "encoding/gob" "errors" "fmt" - "log/slog" "time" "github.com/dgraph-io/ristretto" - config "github.com/plugfox/foxy-gram-server/internal/config" + "github.com/plugfox/foxy-gram-server/internal/global" "github.com/plugfox/foxy-gram-server/internal/model" storage_logger "github.com/plugfox/foxy-gram-server/internal/storage/storagelogger" "gorm.io/gorm" @@ -31,7 +30,7 @@ type Storage struct { db *gorm.DB } -func New(config *config.Config, log *slog.Logger) (*Storage, error) { +func New() (*Storage, error) { // Cache const ( numCounters = 1e7 // number of keys to track frequency of (10M). @@ -63,14 +62,14 @@ func New(config *config.Config, log *slog.Logger) (*Storage, error) { } // SQL database connection - dialector, err := createDialector(&config.Database) + dialector, err := createDialector(&global.Config.Database) if err != nil { return nil, err } // Log SQL queries if enabled - dbLogger := storage_logger.NewGormSlogLogger(log) - if config.Database.Logging { + dbLogger := storage_logger.NewGormSlogLogger(global.Logger) + if global.Config.Database.Logging { dbLogger.LogMode(logger.Info) } else { dbLogger.LogMode(logger.Silent) @@ -117,6 +116,18 @@ func New(config *config.Config, log *slog.Logger) (*Storage, error) { }, nil } +// Status - get the status of the database connection. +func (s *Storage) Status() (string, error) { + var result int + if err := s.db.Raw("SELECT 1").Scan(&result).Error; err != nil { + return "error", err + } + + s.cache.Get("status") // Just to show that the cache is working + + return "ok", nil +} + // Close - close the database connection. func (s *Storage) Close() error { s.cache.Close() @@ -481,6 +492,40 @@ func (s *Storage) VerifyUser(verifiedUser *model.VerifiedUser) error { return nil } +// VerifyUsers - verify the multiple users. +func (s *Storage) VerifyUsers(reason string, userIDs []int) error { + if err := s.db.Transaction(func(tx *gorm.DB) error { + if err := tx.Delete(&model.BannedUser{}, "id IN ?", userIDs).Error; err != nil { + return err + } + + users := make([]model.VerifiedUser, 0, len(userIDs)) + for _, userID := range userIDs { + users = append(users, model.VerifiedUser{ + ID: model.UserID(userID), + VerifiedAt: time.Now(), + Reason: reason, + }) + } + + const batchSize = 1000 + + if err := tx.Clauses(clause.OnConflict{ + /* UpdateAll: true, */ + Columns: []clause.Column{{Name: "id"}}, + DoUpdates: clause.AssignmentColumns([]string{"verified_at", "reason"}), + }).CreateInBatches(users, batchSize).Error; err != nil { + return err + } + + return nil + }); err != nil { + return err + } + + return nil +} + // Ban the user. func (s *Storage) BanUser(bannedUser *model.BannedUser) error { if err := s.db.Transaction(func(tx *gorm.DB) error { diff --git a/internal/telegram/middlewares.go b/internal/telegram/middlewares.go index 035c9ea..b5a4702 100644 --- a/internal/telegram/middlewares.go +++ b/internal/telegram/middlewares.go @@ -8,8 +8,8 @@ import ( "net/http" "time" - config "github.com/plugfox/foxy-gram-server/internal/config" "github.com/plugfox/foxy-gram-server/internal/converters" + "github.com/plugfox/foxy-gram-server/internal/global" "github.com/plugfox/foxy-gram-server/internal/model" "github.com/plugfox/foxy-gram-server/internal/storage" "github.com/plugfox/foxy-gram-server/internal/utility" @@ -27,14 +27,14 @@ type captchaMessage struct { } // Check if the chat is allowed. -func allowedChats(config *config.Config, chatID int64) bool { - for _, id := range config.Telegram.Chats { +func allowedChats(chatID int64) bool { + for _, id := range global.Config.Telegram.Chats { if id == chatID { return true } } - return len(config.Telegram.Chats) == 0 + return len(global.Config.Telegram.Chats) == 0 } // Restrict user rights @@ -123,7 +123,6 @@ func isUserBanned(db *storage.Storage, httpClient *http.Client, user *tele.User) func verifyUserMiddleware( db *storage.Storage, httpClient *http.Client, - config *config.Config, onError func(error), ) tele.MiddlewareFunc { // Centralized error handling @@ -148,7 +147,7 @@ func verifyUserMiddleware( } // If it not allowed chat - skip it - if !allowedChats(config, chat.ID) { + if !allowedChats(chat.ID) { return nil // Skip the current message, if it is not allowed chat } diff --git a/internal/telegram/telegram.go b/internal/telegram/telegram.go index b1b8c8e..af0b82d 100644 --- a/internal/telegram/telegram.go +++ b/internal/telegram/telegram.go @@ -6,8 +6,8 @@ import ( "log/slog" "net/http" - config "github.com/plugfox/foxy-gram-server/internal/config" "github.com/plugfox/foxy-gram-server/internal/converters" + "github.com/plugfox/foxy-gram-server/internal/global" log "github.com/plugfox/foxy-gram-server/internal/log" "github.com/plugfox/foxy-gram-server/internal/model" "github.com/plugfox/foxy-gram-server/internal/storage" @@ -21,15 +21,15 @@ type Telegram struct { bot *tele.Bot } -func New(db *storage.Storage, httpClient *http.Client, config *config.Config, logger *slog.Logger) (*Telegram, error) { +func New(db *storage.Storage, httpClient *http.Client) (*Telegram, error) { pref := tele.Settings{ - Token: config.Telegram.Token, + Token: global.Config.Telegram.Token, Client: httpClient, Poller: &tele.LongPoller{ - Timeout: config.Telegram.Timeout, + Timeout: global.Config.Telegram.Timeout, }, OnError: func(err error, _ tele.Context) { - logger.Error("telegram error", slog.String("error", err.Error())) + global.Logger.Error("telegram error", slog.String("error", err.Error())) }, } @@ -41,27 +41,27 @@ func New(db *storage.Storage, httpClient *http.Client, config *config.Config, lo // Global-scoped middleware: bot.Use(mw.Recover()) bot.Use(mw.AutoRespond()) - bot.Use(mw.Logger(log.NewLogAdapter(logger))) + bot.Use(mw.Logger(log.NewLogAdapter(global.Logger))) - if config.Telegram.IgnoreVia { + if global.Config.Telegram.IgnoreVia { bot.Use(mw.IgnoreVia()) } - bot.Use(verifyUserMiddleware(db, httpClient, config, func(err error) { - logger.Error("verify user error", slog.String("error", err.Error())) + bot.Use(verifyUserMiddleware(db, httpClient, func(err error) { + global.Logger.Error("verify user error", slog.String("error", err.Error())) })) - if len(config.Telegram.Whitelist) > 0 { - bot.Use(mw.Whitelist(config.Telegram.Whitelist...)) + if len(global.Config.Telegram.Whitelist) > 0 { + bot.Use(mw.Whitelist(global.Config.Telegram.Whitelist...)) } - if len(config.Telegram.Blacklist) > 0 { - bot.Use(mw.Blacklist(config.Telegram.Blacklist...)) + if len(global.Config.Telegram.Blacklist) > 0 { + bot.Use(mw.Blacklist(global.Config.Telegram.Blacklist...)) } // Store messages in the database bot.Use(storeMessagesMiddleware(db, func(err error) { - logger.Error("store message error", slog.String("error", err.Error())) + global.Logger.Error("store message error", slog.String("error", err.Error())) })) /* bot.Use(mw.Restrict(mw.RestrictConfig{ @@ -79,11 +79,11 @@ func New(db *storage.Storage, httpClient *http.Client, config *config.Config, lo }) */ // Group-scoped middleware: - if len(config.Telegram.Admins) > 0 { + if len(global.Config.Telegram.Admins) > 0 { adminOnly := bot.Group() /* adminOnly.Handle("/ban", onBan) adminOnly.Handle("/kick", onKick) */ - adminOnly.Use(middleware.Whitelist(config.Telegram.Admins...)) + adminOnly.Use(middleware.Whitelist(global.Config.Telegram.Admins...)) } // TODO: add more handlers @@ -109,11 +109,18 @@ func New(db *storage.Storage, httpClient *http.Client, config *config.Config, lo return nil }) + // TODO: handle captcha methods, get information about captcha directly from the database + return &Telegram{ bot: bot, }, nil } +// Status returns the telegram bot status. +func (t *Telegram) Status() (string, error) { + return "ok", nil +} + // Start the bot. func (t *Telegram) Start() { t.bot.Start()