diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 00000000..e69e06d4 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,22 @@ +version: 2 +updates: + # Monitor Go dependencies (go.mod) + - package-ecosystem: "gomod" + directory: "/" + schedule: + interval: "weekly" + open-pull-requests-limit: 10 + # Group minor and patch updates to reduce the number of PRs and branches + groups: + go-dependencies: + patterns: + - "*" + update-types: + - "minor" + - "patch" + + # Monitor GitHub Actions workflows + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "weekly" diff --git a/.github/workflows/go-ci.yml b/.github/workflows/go-ci.yml new file mode 100644 index 00000000..3b3e23e2 --- /dev/null +++ b/.github/workflows/go-ci.yml @@ -0,0 +1,134 @@ +name: Go CI + +on: + push: + branches: ["main"] + paths: + - "**/*.go" + - "go.mod" + - "go.sum" + - ".golangci.yml" + - "db/queries/**/*.sql" + - "db/migrations/**/*.sql" + - ".github/workflows/go-ci.yml" + pull_request: + branches: ["main"] + paths: + - "**/*.go" + - "go.mod" + - "go.sum" + - ".golangci.yml" + - "db/queries/**/*.sql" + - "db/migrations/**/*.sql" + - ".github/workflows/go-ci.yml" + +permissions: + contents: read + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + lint: + name: Lint + runs-on: ubuntu-latest + timeout-minutes: 20 + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version-file: go.mod + cache: true + - name: golangci-lint + uses: golangci/golangci-lint-action@v9 + with: + version: v2.9.0 + + test: + name: Test + runs-on: ubuntu-latest + timeout-minutes: 30 + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version-file: go.mod + cache: true + - name: Run Tests + run: | + go test -v -race -coverprofile=coverage.txt -covermode=atomic ./... + - name: Upload coverage to Codecov + if: ${{ github.event_name != 'pull_request' || github.event.pull_request.head.repo.fork == false }} + uses: codecov/codecov-action@v5 + with: + token: ${{ secrets.CODECOV_TOKEN }} + file: ./coverage.txt + fail_ci_if_error: false # Don't fail if Codecov upload fails +name: Go CI + +on: + push: + branches: ["main"] + paths: + - "**/*.go" + - "go.mod" + - "go.sum" + - ".golangci.yml" + - "db/queries/**/*.sql" + - "db/migrations/**/*.sql" + - ".github/workflows/go-ci.yml" + pull_request: + branches: ["main"] + paths: + - "**/*.go" + - "go.mod" + - "go.sum" + - ".golangci.yml" + - "db/queries/**/*.sql" + - "db/migrations/**/*.sql" + - ".github/workflows/go-ci.yml" + +permissions: + contents: read + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + lint: + name: Lint + runs-on: ubuntu-latest + timeout-minutes: 20 + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version-file: go.mod + cache: true + - name: golangci-lint + uses: golangci/golangci-lint-action@v9 + with: + version: v2.9.0 + + test: + name: Test + runs-on: ubuntu-latest + timeout-minutes: 30 + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-go@v5 + with: + go-version-file: go.mod + cache: true + - name: Run Tests + run: | + go test -v -race -coverprofile=coverage.txt -covermode=atomic ./... + - name: Upload coverage to Codecov + if: ${{ github.event_name != 'pull_request' || github.event.pull_request.head.repo.fork == false }} + uses: codecov/codecov-action@v5 + with: + token: ${{ secrets.CODECOV_TOKEN }} + file: ./coverage.txt + fail_ci_if_error: false # Don't fail if Codecov upload fails diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 00000000..47ba7b50 --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,111 @@ +version: "2" + +run: + timeout: 15m + go: "1.25" + tests: true + +output: + show-stats: false + +issues: + max-issues-per-linter: 0 + max-same-issues: 0 + +linters: + default: none + enable: + - errcheck + - errorlint + - exptostd + - fatcontext + - gocritic + - godot + - govet + - ineffassign + - misspell + - modernize + - nilnesserr + - perfsprint + - predeclared + - revive + - sloglint + - staticcheck + - testifylint + - unconvert + - unused + - usestdlibvars + - whitespace + exclusions: + paths: + - internal/db/sqlc + - ^.*\.(pb|l|y)\.go$ + +formatters: + enable: + - gci + - gofumpt + - goimports + settings: + gci: + sections: + - standard + - default + - prefix(github.com/memohai/memoh) + gofumpt: + extra-rules: false + goimports: + local-prefixes: + - github.com/memohai/memoh + +linters-settings: + govet: + enable-all: true + disable: + - shadow + - fieldalignment + gocyclo: + min-complexity: 10 + funlen: + lines: 60 + statements: 30 + modernize: + disable: + - omitzero + perfsprint: + int-conversion: true + err-error: true + errorf: true + sprintf1: true + strconcat: false + concat-loop: false + revive: + rules: + - name: blank-imports + - name: comment-spacings + - name: context-as-argument + arguments: + - allowTypesBefore: "*testing.T,testing.TB" + - name: dot-imports + - name: error-naming + - name: error-return + - name: error-strings + - name: exported + - name: increment-decrement + - name: var-naming + - name: var-declaration + - name: unreachable-code + - name: unused-parameter + - name: unused-receiver + sloglint: + attr-only: true + no-global: default + context-only: scope + static-msg: true + key-naming-case: snake + forbidden-keys: [time, level, msg, source] + testifylint: + enable-all: true + disable: + - float-compare + - go-require diff --git a/cmd/agent/main.go b/cmd/agent/main.go index b0656bf4..286383ee 100644 --- a/cmd/agent/main.go +++ b/cmd/agent/main.go @@ -1,3 +1,4 @@ +// Package main is the entry point for the Memoh agent server (HTTP API, channels, containers). package main import ( @@ -192,7 +193,7 @@ func provideContainerdClient(lc fx.Lifecycle, rc *boot.RuntimeConfig) (*containe return nil, fmt.Errorf("connect containerd: %w", err) } lc.Append(fx.Hook{ - OnStop: func(ctx context.Context) error { + OnStop: func(_ context.Context) error { return client.Close() }, }) @@ -205,7 +206,7 @@ func provideDBConn(lc fx.Lifecycle, cfg config.Config) (*pgxpool.Pool, error) { return nil, fmt.Errorf("db connect: %w", err) } lc.Append(fx.Hook{ - OnStop: func(ctx context.Context) error { + OnStop: func(_ context.Context) error { conn.Close() return nil }, @@ -321,8 +322,8 @@ func provideChatResolver(log *slog.Logger, cfg config.Config, modelsService *mod func provideChannelRegistry(log *slog.Logger, hub *local.RouteHub) *channel.Registry { registry := channel.NewRegistry() - registry.MustRegister(telegram.NewTelegramAdapter(log)) - registry.MustRegister(feishu.NewFeishuAdapter(log)) + registry.MustRegister(telegram.NewAdapter(log)) + registry.MustRegister(feishu.NewAdapter(log)) registry.MustRegister(local.NewCLIAdapter(hub)) registry.MustRegister(local.NewWebAdapter(hub)) return registry @@ -383,7 +384,7 @@ func provideMemoryHandler(log *slog.Logger, service *memory.Service, chatService if strings.TrimSpace(execWorkDir) == "" { execWorkDir = config.DefaultDataMount } - h.SetMemoryFS(memory.NewMemoryFS(log, manager, execWorkDir)) + h.SetMemoryFS(memory.NewFS(log, manager, execWorkDir)) } return h } @@ -435,7 +436,7 @@ func provideServer(params serverParams) *server.Server { func startMemoryWarmup(lc fx.Lifecycle, memoryService *memory.Service, logger *slog.Logger) { lc.Append(fx.Hook{ - OnStart: func(ctx context.Context) error { + OnStart: func(_ context.Context) error { go func() { if err := memoryService.WarmupBM25(context.Background(), 200); err != nil { logger.Warn("bm25 warmup failed", slog.Any("error", err)) @@ -526,7 +527,7 @@ func buildTextEmbedder(resolver *embeddings.Resolver, textModel models.GetRespon func ensureAdminUser(ctx context.Context, log *slog.Logger, queries *dbsqlc.Queries, cfg config.Config) error { if queries == nil { - return fmt.Errorf("db queries not configured") + return errors.New("db queries not configured") } count, err := queries.CountAccounts(ctx) if err != nil { @@ -540,7 +541,7 @@ func ensureAdminUser(ctx context.Context, log *slog.Logger, queries *dbsqlc.Quer password := strings.TrimSpace(cfg.Admin.Password) email := strings.TrimSpace(cfg.Admin.Email) if username == "" || password == "" { - return fmt.Errorf("admin username/password required in config.toml") + return errors.New("admin username/password required in config.toml") } if password == "change-your-password-here" { log.Warn("admin password uses default placeholder; please update config.toml") @@ -629,7 +630,7 @@ func (c *lazyLLMClient) DetectLanguage(ctx context.Context, text string) (string func (c *lazyLLMClient) resolve(ctx context.Context) (memory.LLM, error) { if c.modelsService == nil || c.queries == nil { - return nil, fmt.Errorf("models service not configured") + return nil, errors.New("models service not configured") } memoryModel, memoryProvider, err := models.SelectMemoryModel(ctx, c.modelsService, c.queries) if err != nil { diff --git a/cmd/cli/main.go b/cmd/cli/main.go index 1836dfaf..c72f6001 100644 --- a/cmd/cli/main.go +++ b/cmd/cli/main.go @@ -1,8 +1,10 @@ +// Package main is the entry point for the Memoh CLI. package main import ( "context" "encoding/json" + "errors" "flag" "fmt" "io" @@ -42,7 +44,8 @@ func main() { cmd := buildMCPCommand(*containerID) if err := runWithStdio(cmd); err != nil { - if exitErr, ok := err.(*exec.ExitError); ok { + exitErr := &exec.ExitError{} + if errors.As(err, &exitErr) { os.Exit(exitErr.ExitCode()) } os.Exit(1) @@ -214,10 +217,10 @@ func parseCNIArgs(args []string) (string, string, error) { return "", "", err } if *id == "" { - return "", "", fmt.Errorf("missing --id") + return "", "", errors.New("missing --id") } if *netns == "" && *pid == 0 { - return "", "", fmt.Errorf("missing --netns or --pid") + return "", "", errors.New("missing --netns or --pid") } if *netns == "" { *netns = filepath.Join("/proc", strconv.Itoa(*pid), "ns", "net") diff --git a/cmd/feishu-echo/main.go b/cmd/feishu-echo/main.go index d092ca6a..103b355b 100644 --- a/cmd/feishu-echo/main.go +++ b/cmd/feishu-echo/main.go @@ -15,10 +15,9 @@ import ( "sync/atomic" "time" + "github.com/larksuite/oapi-sdk-go/v3/event/dispatcher" larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1" larkws "github.com/larksuite/oapi-sdk-go/v3/ws" - - "github.com/larksuite/oapi-sdk-go/v3/event/dispatcher" ) type eventCounts struct { diff --git a/cmd/mcp/main.go b/cmd/mcp/main.go index 9110eceb..f5216e7c 100644 --- a/cmd/mcp/main.go +++ b/cmd/mcp/main.go @@ -1,3 +1,4 @@ +// Package main is the entry point for the Memoh MCP stdio server. package main import ( @@ -9,12 +10,17 @@ import ( "os/signal" "syscall" + gomcp "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/memohai/memoh/internal/logger" "github.com/memohai/memoh/internal/version" - gomcp "github.com/modelcontextprotocol/go-sdk/mcp" ) func main() { + os.Exit(run()) +} + +func run() int { ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) defer stop() @@ -25,18 +31,18 @@ func main() { ) err := server.Run(ctx, &gomcp.StdioTransport{}) if ctx.Err() != nil { - return + return 0 } if err == nil { logger.Warn("mcp server exited without error; waiting for shutdown signal") <-ctx.Done() - return + return 0 } if errors.Is(err, io.EOF) { logger.Warn("mcp stdio closed; waiting for shutdown signal") <-ctx.Done() - return + return 0 } logger.Error("mcp server failed", slog.Any("error", err)) - os.Exit(1) + return 1 } diff --git a/internal/accounts/service.go b/internal/accounts/service.go index 90aea315..acf71556 100644 --- a/internal/accounts/service.go +++ b/internal/accounts/service.go @@ -1,3 +1,4 @@ +// Package accounts provides user account and credential management. package accounts import ( @@ -22,6 +23,7 @@ type Service struct { logger *slog.Logger } +// Errors returned by account operations. var ( ErrInvalidPassword = errors.New("invalid password") ErrInvalidCredentials = errors.New("invalid credentials") @@ -42,7 +44,7 @@ func NewService(log *slog.Logger, queries *sqlc.Queries) *Service { // Get returns an account by user id. func (s *Service) Get(ctx context.Context, userID string) (Account, error) { if s.queries == nil { - return Account{}, fmt.Errorf("account queries not configured") + return Account{}, errors.New("account queries not configured") } pgID, err := db.ParseUUID(userID) if err != nil { @@ -58,7 +60,7 @@ func (s *Service) Get(ctx context.Context, userID string) (Account, error) { // Login authenticates by identity (username or email) and password. func (s *Service) Login(ctx context.Context, identity, password string) (Account, error) { if s.queries == nil { - return Account{}, fmt.Errorf("account queries not configured") + return Account{}, errors.New("account queries not configured") } identity = strings.TrimSpace(identity) if identity == "" || strings.TrimSpace(password) == "" { @@ -91,7 +93,7 @@ func (s *Service) Login(ctx context.Context, identity, password string) (Account // ListAccounts returns all accounts. func (s *Service) ListAccounts(ctx context.Context) ([]Account, error) { if s.queries == nil { - return nil, fmt.Errorf("account queries not configured") + return nil, errors.New("account queries not configured") } rows, err := s.queries.ListAccounts(ctx) if err != nil { @@ -107,7 +109,7 @@ func (s *Service) ListAccounts(ctx context.Context) ([]Account, error) { // IsAdmin checks if the user has admin role. func (s *Service) IsAdmin(ctx context.Context, userID string) (bool, error) { if s.queries == nil { - return false, fmt.Errorf("account queries not configured") + return false, errors.New("account queries not configured") } pgID, err := db.ParseUUID(userID) if err != nil { @@ -126,15 +128,15 @@ func (s *Service) IsAdmin(ctx context.Context, userID string) (bool, error) { // Create creates a new account for an existing user. func (s *Service) Create(ctx context.Context, userID string, req CreateAccountRequest) (Account, error) { if s.queries == nil { - return Account{}, fmt.Errorf("account queries not configured") + return Account{}, errors.New("account queries not configured") } username := strings.TrimSpace(req.Username) if username == "" { - return Account{}, fmt.Errorf("username is required") + return Account{}, errors.New("username is required") } password := strings.TrimSpace(req.Password) if password == "" { - return Account{}, fmt.Errorf("password is required") + return Account{}, errors.New("password is required") } role, err := normalizeRole(req.Role) if err != nil { @@ -195,7 +197,7 @@ func (s *Service) CreateHuman(ctx context.Context, userID string, req CreateAcco userID = strings.TrimSpace(userID) if userID == "" { if s.queries == nil { - return Account{}, fmt.Errorf("account queries not configured") + return Account{}, errors.New("account queries not configured") } userRow, err := s.queries.CreateUser(ctx, sqlc.CreateUserParams{ IsActive: true, @@ -205,7 +207,7 @@ func (s *Service) CreateHuman(ctx context.Context, userID string, req CreateAcco return Account{}, err } if !userRow.ID.Valid { - return Account{}, fmt.Errorf("create user: invalid id") + return Account{}, errors.New("create user: invalid id") } userID = userRow.ID.String() } @@ -215,7 +217,7 @@ func (s *Service) CreateHuman(ctx context.Context, userID string, req CreateAcco // UpdateAdmin updates account fields as admin. func (s *Service) UpdateAdmin(ctx context.Context, userID string, req UpdateAccountRequest) (Account, error) { if s.queries == nil { - return Account{}, fmt.Errorf("account queries not configured") + return Account{}, errors.New("account queries not configured") } pgID, err := db.ParseUUID(userID) if err != nil { @@ -264,7 +266,7 @@ func (s *Service) UpdateAdmin(ctx context.Context, userID string, req UpdateAcco // UpdateProfile updates the user's profile. func (s *Service) UpdateProfile(ctx context.Context, userID string, req UpdateProfileRequest) (Account, error) { if s.queries == nil { - return Account{}, fmt.Errorf("account queries not configured") + return Account{}, errors.New("account queries not configured") } pgID, err := db.ParseUUID(userID) if err != nil { @@ -300,10 +302,10 @@ func (s *Service) UpdateProfile(ctx context.Context, userID string, req UpdatePr // UpdatePassword changes the password after verifying the current one. func (s *Service) UpdatePassword(ctx context.Context, userID, currentPassword, newPassword string) error { if s.queries == nil { - return fmt.Errorf("account queries not configured") + return errors.New("account queries not configured") } if strings.TrimSpace(newPassword) == "" { - return fmt.Errorf("new password is required") + return errors.New("new password is required") } pgID, err := db.ParseUUID(userID) if err != nil { @@ -336,10 +338,10 @@ func (s *Service) UpdatePassword(ctx context.Context, userID, currentPassword, n // ResetPassword sets a new password without requiring the current one. func (s *Service) ResetPassword(ctx context.Context, userID, newPassword string) error { if s.queries == nil { - return fmt.Errorf("account queries not configured") + return errors.New("account queries not configured") } if strings.TrimSpace(newPassword) == "" { - return fmt.Errorf("new password is required") + return errors.New("new password is required") } pgID, err := db.ParseUUID(userID) if err != nil { @@ -423,4 +425,3 @@ func toAccount(row sqlc.User) Account { LastLoginAt: lastLogin, } } - diff --git a/internal/accounts/types.go b/internal/accounts/types.go index 7a3b4f62..7c1862b1 100644 --- a/internal/accounts/types.go +++ b/internal/accounts/types.go @@ -13,7 +13,7 @@ type Account struct { IsActive bool `json:"is_active"` CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` - LastLoginAt time.Time `json:"last_login_at,omitempty"` + LastLoginAt time.Time `json:"last_login_at,omitzero"` } // CreateAccountRequest is the input for creating an account. diff --git a/internal/auth/jwt.go b/internal/auth/jwt.go index 8c837b43..50126dff 100644 --- a/internal/auth/jwt.go +++ b/internal/auth/jwt.go @@ -1,6 +1,8 @@ +// Package auth provides JWT-based authentication and middleware. package auth import ( + "errors" "fmt" "net/http" "strings" @@ -30,7 +32,7 @@ func JWTMiddleware(secret string, skipper middleware.Skipper) echo.MiddlewareFun SigningMethod: "HS256", TokenLookup: "header:Authorization:Bearer ", Skipper: skipper, - NewClaimsFunc: func(c echo.Context) jwt.Claims { + NewClaimsFunc: func(_ echo.Context) jwt.Claims { return jwt.MapClaims{} }, }) @@ -58,13 +60,13 @@ func UserIDFromContext(c echo.Context) (string, error) { // GenerateToken creates a signed JWT for the user. func GenerateToken(userID, secret string, expiresIn time.Duration) (string, time.Time, error) { if strings.TrimSpace(userID) == "" { - return "", time.Time{}, fmt.Errorf("user id is required") + return "", time.Time{}, errors.New("user id is required") } if strings.TrimSpace(secret) == "" { - return "", time.Time{}, fmt.Errorf("jwt secret is required") + return "", time.Time{}, errors.New("jwt secret is required") } if expiresIn <= 0 { - return "", time.Time{}, fmt.Errorf("jwt expires in must be positive") + return "", time.Time{}, errors.New("jwt expires in must be positive") } now := time.Now().UTC() @@ -95,22 +97,22 @@ type ChatToken struct { // GenerateChatToken creates a signed JWT for chat route reply. func GenerateChatToken(info ChatToken, secret string, expiresIn time.Duration) (string, time.Time, error) { if strings.TrimSpace(info.BotID) == "" { - return "", time.Time{}, fmt.Errorf("bot id is required") + return "", time.Time{}, errors.New("bot id is required") } if strings.TrimSpace(info.ChatID) == "" { - return "", time.Time{}, fmt.Errorf("chat id is required") + return "", time.Time{}, errors.New("chat id is required") } if strings.TrimSpace(info.UserID) == "" { info.UserID = strings.TrimSpace(info.ChannelIdentityID) } if strings.TrimSpace(info.UserID) == "" { - return "", time.Time{}, fmt.Errorf("user id is required") + return "", time.Time{}, errors.New("user id is required") } if strings.TrimSpace(secret) == "" { - return "", time.Time{}, fmt.Errorf("jwt secret is required") + return "", time.Time{}, errors.New("jwt secret is required") } if expiresIn <= 0 { - return "", time.Time{}, fmt.Errorf("jwt expires in must be positive") + return "", time.Time{}, errors.New("jwt expires in must be positive") } now := time.Now().UTC() diff --git a/internal/bind/service.go b/internal/bind/service.go index a0c84188..c19cebd3 100644 --- a/internal/bind/service.go +++ b/internal/bind/service.go @@ -1,3 +1,4 @@ +// Package bind provides bind-code issuance and consumption for channel linking. package bind import ( @@ -46,7 +47,7 @@ func NewService(log *slog.Logger, pool *pgxpool.Pool, queries *sqlc.Queries) *Se // Platform is optional; when provided, bind consume must happen on the same channel platform. func (s *Service) Issue(ctx context.Context, issuedByUserID, platform string, ttl time.Duration) (Code, error) { if s.queries == nil { - return Code{}, fmt.Errorf("bind queries not configured") + return Code{}, errors.New("bind queries not configured") } if ttl <= 0 { ttl = defaultTTL @@ -59,7 +60,7 @@ func (s *Service) Issue(ctx context.Context, issuedByUserID, platform string, tt normalizedPlatform := normalizePlatform(platform) expiresAt := time.Now().UTC().Add(ttl) - for i := 0; i < maxTokenRetries; i++ { + for range maxTokenRetries { token := strings.ToUpper(strings.ReplaceAll(uuid.NewString(), "-", "")[:8]) row, err := s.queries.CreateBindCode(ctx, sqlc.CreateBindCodeParams{ Token: token, @@ -78,13 +79,13 @@ func (s *Service) Issue(ctx context.Context, issuedByUserID, platform string, tt } return Code{}, fmt.Errorf("create bind code: %w", err) } - return Code{}, fmt.Errorf("create bind code: token collision after retries") + return Code{}, errors.New("create bind code: token collision after retries") } // Get looks up a bind code by token. func (s *Service) Get(ctx context.Context, token string) (Code, error) { if s.queries == nil { - return Code{}, fmt.Errorf("bind queries not configured") + return Code{}, errors.New("bind queries not configured") } row, err := s.queries.GetBindCode(ctx, strings.TrimSpace(token)) if err != nil { @@ -99,7 +100,7 @@ func (s *Service) Get(ctx context.Context, token string) (Code, error) { // Consume validates and consumes a bind code and links the channel identity to issuer user. func (s *Service) Consume(ctx context.Context, code Code, channelIdentityID string) error { if s.queries == nil || s.pool == nil { - return fmt.Errorf("bind service not configured") + return errors.New("bind service not configured") } // Fast-fail based on caller snapshot before opening a transaction. @@ -115,7 +116,7 @@ func (s *Service) Consume(ctx context.Context, code Code, channelIdentityID stri } sourceIdentityID := strings.TrimSpace(channelIdentityID) if sourceIdentityID == "" { - return fmt.Errorf("channel identity id is required") + return errors.New("channel identity id is required") } pgSourceIdentityID, err := db.ParseUUID(sourceIdentityID) if err != nil { @@ -149,7 +150,7 @@ func (s *Service) Consume(ctx context.Context, code Code, channelIdentityID stri targetUserID := strings.TrimSpace(lockedCode.IssuedByUserID) if targetUserID == "" { - return fmt.Errorf("bind code issuer user is missing") + return errors.New("bind code issuer user is missing") } pgTargetUserID, err := db.ParseUUID(targetUserID) if err != nil { @@ -158,14 +159,14 @@ func (s *Service) Consume(ctx context.Context, code Code, channelIdentityID stri if _, err := qtx.GetChannelIdentityByIDForUpdate(ctx, pgSourceIdentityID); err != nil { if errors.Is(err, pgx.ErrNoRows) { - return fmt.Errorf("channel identity not found") + return errors.New("channel identity not found") } return fmt.Errorf("lock source identity: %w", err) } sourceIdentity, err := qtx.GetChannelIdentityByIDForUpdate(ctx, pgSourceIdentityID) if err != nil { if errors.Is(err, pgx.ErrNoRows) { - return fmt.Errorf("channel identity not found") + return errors.New("channel identity not found") } return fmt.Errorf("reload source identity: %w", err) } diff --git a/internal/bind/service_consume_integration_test.go b/internal/bind/service_consume_integration_test.go index 735eddb9..c5bad185 100644 --- a/internal/bind/service_consume_integration_test.go +++ b/internal/bind/service_consume_integration_test.go @@ -2,7 +2,6 @@ package bind_test import ( "context" - "encoding/json" "errors" "fmt" "log/slog" @@ -10,12 +9,10 @@ import ( "testing" "time" - "github.com/jackc/pgx/v5/pgtype" "github.com/jackc/pgx/v5/pgxpool" "github.com/memohai/memoh/internal/bind" "github.com/memohai/memoh/internal/channel/identities" - "github.com/memohai/memoh/internal/db" "github.com/memohai/memoh/internal/db/sqlc" ) @@ -55,28 +52,6 @@ func createUserForBind(ctx context.Context, queries *sqlc.Queries) (string, erro return row.ID.String(), nil } -func createBotForBind(ctx context.Context, queries *sqlc.Queries, ownerUserID string) (string, error) { - pgOwnerID, err := db.ParseUUID(ownerUserID) - if err != nil { - return "", err - } - meta, err := json.Marshal(map[string]any{"source": "bind-integration-test"}) - if err != nil { - return "", err - } - row, err := queries.CreateBot(ctx, sqlc.CreateBotParams{ - OwnerUserID: pgOwnerID, - Type: "personal", - DisplayName: pgtype.Text{String: "bind-test-bot", Valid: true}, - IsActive: true, - Metadata: meta, - }) - if err != nil { - return "", err - } - return row.ID.String(), nil -} - func TestBindConsumeLinksChannelIdentityToIssuerUser(t *testing.T) { queries, channelIdentitySvc, bindSvc, cleanup := setupBindConsumeIntegrationTest(t) defer cleanup() diff --git a/internal/bind/service_test.go b/internal/bind/service_test.go index 298a744b..a9d8c4e4 100644 --- a/internal/bind/service_test.go +++ b/internal/bind/service_test.go @@ -122,13 +122,13 @@ func TestToCode_OptionalFields(t *testing.T) { } now := time.Now().UTC() row := sqlc.ChannelIdentityBindCode{ - ID: pgID, - Token: "TOKEN", - IssuedByUserID: pgID, - ChannelType: pgtype.Text{Valid: false}, - ExpiresAt: pgtype.Timestamptz{Valid: false}, - UsedAt: pgtype.Timestamptz{Valid: false}, - CreatedAt: pgtype.Timestamptz{Time: now, Valid: true}, + ID: pgID, + Token: "TOKEN", + IssuedByUserID: pgID, + ChannelType: pgtype.Text{Valid: false}, + ExpiresAt: pgtype.Timestamptz{Valid: false}, + UsedAt: pgtype.Timestamptz{Valid: false}, + CreatedAt: pgtype.Timestamptz{Time: now, Valid: true}, } c := toCode(row) if c.Platform != "" { @@ -204,4 +204,3 @@ func TestService_Consume_InvalidChannelIdentityID(t *testing.T) { t.Fatal("expected error for invalid channel identity id") } } - diff --git a/internal/bind/types.go b/internal/bind/types.go index 0f5182b0..e6587946 100644 --- a/internal/bind/types.go +++ b/internal/bind/types.go @@ -5,6 +5,7 @@ import ( "time" ) +// Errors returned by bind operations. var ( ErrCodeNotFound = errors.New("bind code not found") ErrCodeUsed = errors.New("bind code already used") @@ -19,8 +20,8 @@ type Code struct { Platform string `json:"platform,omitempty"` Token string `json:"token"` IssuedByUserID string `json:"issued_by_user_id"` - ExpiresAt time.Time `json:"expires_at,omitempty"` - UsedAt time.Time `json:"used_at,omitempty"` + ExpiresAt time.Time `json:"expires_at,omitzero"` + UsedAt time.Time `json:"used_at,omitzero"` UsedByChannelIdentityID string `json:"used_by_channel_identity_id,omitempty"` CreatedAt time.Time `json:"created_at"` } diff --git a/internal/boot/runtime.go b/internal/boot/runtime.go index 0ef1acbc..f4f3e70d 100644 --- a/internal/boot/runtime.go +++ b/internal/boot/runtime.go @@ -1,3 +1,4 @@ +// Package boot provides runtime configuration and dependency wiring for the agent. package boot import ( @@ -10,6 +11,8 @@ import ( "github.com/memohai/memoh/internal/config" ) +// RuntimeConfig holds parsed runtime settings (JWT, server address, containerd socket). +// Values may be overridden by environment variables (e.g. HTTP_ADDR, CONTAINERD_SOCKET). type RuntimeConfig struct { JwtSecret string JwtExpiresIn time.Duration @@ -17,6 +20,7 @@ type RuntimeConfig struct { ContainerdSocketPath string } +// ProvideRuntimeConfig builds RuntimeConfig from the given config and applies env overrides. func ProvideRuntimeConfig(cfg config.Config) (*RuntimeConfig, error) { if strings.TrimSpace(cfg.Auth.JWTSecret) == "" { return nil, errors.New("jwt secret is required") diff --git a/internal/bots/service.go b/internal/bots/service.go index 96abe0af..8901fcee 100644 --- a/internal/bots/service.go +++ b/internal/bots/service.go @@ -1,3 +1,4 @@ +// Package bots provides bot lifecycle, membership, and container management. package bots import ( @@ -7,6 +8,7 @@ import ( "fmt" "log/slog" "os" + "slices" "strings" "time" @@ -30,6 +32,7 @@ const ( botLifecycleOperationTimeout = 5 * time.Minute ) +// Errors returned by bot operations. var ( ErrBotNotFound = errors.New("bot not found") ErrBotAccessDenied = errors.New("bot access denied") @@ -68,7 +71,7 @@ func (s *Service) AddRuntimeChecker(c RuntimeChecker) { // AuthorizeAccess checks whether userID may access the given bot. func (s *Service) AuthorizeAccess(ctx context.Context, userID, botID string, isAdmin bool, policy AccessPolicy) (Bot, error) { if s.queries == nil { - return Bot{}, fmt.Errorf("bot queries not configured") + return Bot{}, errors.New("bot queries not configured") } bot, err := s.Get(ctx, botID) if err != nil { @@ -96,11 +99,11 @@ func (s *Service) AuthorizeAccess(ctx context.Context, userID, botID string, isA // Create creates a new bot owned by owner user. func (s *Service) Create(ctx context.Context, ownerUserID string, req CreateBotRequest) (Bot, error) { if s.queries == nil { - return Bot{}, fmt.Errorf("bot queries not configured") + return Bot{}, errors.New("bot queries not configured") } ownerID := strings.TrimSpace(ownerUserID) if ownerID == "" { - return Bot{}, fmt.Errorf("owner user id is required") + return Bot{}, errors.New("owner user id is required") } ownerUUID, err := db.ParseUUID(ownerID) if err != nil { @@ -156,7 +159,7 @@ func (s *Service) Create(ctx context.Context, ownerUserID string, req CreateBotR // Get returns a bot by its ID. func (s *Service) Get(ctx context.Context, botID string) (Bot, error) { if s.queries == nil { - return Bot{}, fmt.Errorf("bot queries not configured") + return Bot{}, errors.New("bot queries not configured") } botUUID, err := db.ParseUUID(botID) if err != nil { @@ -179,7 +182,7 @@ func (s *Service) Get(ctx context.Context, botID string) (Bot, error) { // ListByOwner returns bots owned by the given user. func (s *Service) ListByOwner(ctx context.Context, ownerUserID string) ([]Bot, error) { if s.queries == nil { - return nil, fmt.Errorf("bot queries not configured") + return nil, errors.New("bot queries not configured") } ownerUUID, err := db.ParseUUID(ownerUserID) if err != nil { @@ -206,7 +209,7 @@ func (s *Service) ListByOwner(ctx context.Context, ownerUserID string) ([]Bot, e // ListByMember returns bots where the user is a member. func (s *Service) ListByMember(ctx context.Context, channelIdentityID string) ([]Bot, error) { if s.queries == nil { - return nil, fmt.Errorf("bot queries not configured") + return nil, errors.New("bot queries not configured") } memberUUID, err := db.ParseUUID(channelIdentityID) if err != nil { @@ -259,7 +262,7 @@ func (s *Service) ListAccessible(ctx context.Context, channelIdentityID string) // Update updates bot profile fields. func (s *Service) Update(ctx context.Context, botID string, req UpdateBotRequest) (Bot, error) { if s.queries == nil { - return Bot{}, fmt.Errorf("bot queries not configured") + return Bot{}, errors.New("bot queries not configured") } botUUID, err := db.ParseUUID(botID) if err != nil { @@ -316,9 +319,9 @@ func (s *Service) Update(ctx context.Context, botID string, req UpdateBotRequest } // TransferOwner transfers bot ownership to another user. -func (s *Service) TransferOwner(ctx context.Context, botID string, ownerUserID string) (Bot, error) { +func (s *Service) TransferOwner(ctx context.Context, botID, ownerUserID string) (Bot, error) { if s.queries == nil { - return Bot{}, fmt.Errorf("bot queries not configured") + return Bot{}, errors.New("bot queries not configured") } botUUID, err := db.ParseUUID(botID) if err != nil { @@ -351,7 +354,7 @@ func (s *Service) TransferOwner(ctx context.Context, botID string, ownerUserID s // Delete removes a bot and its associated resources. func (s *Service) Delete(ctx context.Context, botID string) error { if s.queries == nil { - return fmt.Errorf("bot queries not configured") + return errors.New("bot queries not configured") } botUUID, err := db.ParseUUID(botID) if err != nil { @@ -377,7 +380,7 @@ func (s *Service) Delete(ctx context.Context, botID string) error { // ListChecks evaluates runtime resource checks for a bot. func (s *Service) ListChecks(ctx context.Context, botID string) ([]BotCheck, error) { if s.queries == nil { - return nil, fmt.Errorf("bot queries not configured") + return nil, errors.New("bot queries not configured") } botUUID, err := db.ParseUUID(botID) if err != nil { @@ -453,7 +456,7 @@ func (s *Service) enqueueDeleteLifecycle(botID string) { func (s *Service) updateStatus(ctx context.Context, botID, status string) error { if s.queries == nil { - return fmt.Errorf("bot queries not configured") + return errors.New("bot queries not configured") } botUUID, err := db.ParseUUID(botID) if err != nil { @@ -467,7 +470,7 @@ func (s *Service) updateStatus(ctx context.Context, botID, status string) error func (s *Service) ensureUserExists(ctx context.Context, userID pgtype.UUID) error { if s.queries == nil { - return fmt.Errorf("bot queries not configured") + return errors.New("bot queries not configured") } _, err := s.queries.GetUserByID(ctx, userID) if err != nil { @@ -482,7 +485,7 @@ func (s *Service) ensureUserExists(ctx context.Context, userID pgtype.UUID) erro // UpsertMember creates or updates a bot membership. func (s *Service) UpsertMember(ctx context.Context, botID string, req UpsertMemberRequest) (BotMember, error) { if s.queries == nil { - return BotMember{}, fmt.Errorf("bot queries not configured") + return BotMember{}, errors.New("bot queries not configured") } botUUID, err := db.ParseUUID(botID) if err != nil { @@ -510,7 +513,7 @@ func (s *Service) UpsertMember(ctx context.Context, botID string, req UpsertMemb // ListMembers returns all members of a bot. func (s *Service) ListMembers(ctx context.Context, botID string) ([]BotMember, error) { if s.queries == nil { - return nil, fmt.Errorf("bot queries not configured") + return nil, errors.New("bot queries not configured") } botUUID, err := db.ParseUUID(botID) if err != nil { @@ -530,7 +533,7 @@ func (s *Service) ListMembers(ctx context.Context, botID string) ([]BotMember, e // GetMember returns a specific bot member. func (s *Service) GetMember(ctx context.Context, botID, channelIdentityID string) (BotMember, error) { if s.queries == nil { - return BotMember{}, fmt.Errorf("bot queries not configured") + return BotMember{}, errors.New("bot queries not configured") } botUUID, err := db.ParseUUID(botID) if err != nil { @@ -553,7 +556,7 @@ func (s *Service) GetMember(ctx context.Context, botID, channelIdentityID string // DeleteMember removes a member from a bot. func (s *Service) DeleteMember(ctx context.Context, botID, channelIdentityID string) error { if s.queries == nil { - return fmt.Errorf("bot queries not configured") + return errors.New("bot queries not configured") } botUUID, err := db.ParseUUID(botID) if err != nil { @@ -788,7 +791,7 @@ func (s *Service) buildRuntimeChecks(ctx context.Context, row sqlc.Bot) ([]BotCh CheckKey: BotCheckKeyContainerRecord, Status: BotCheckStatusOK, Summary: "Container record exists.", - Detail: fmt.Sprintf("container_id=%s", strings.TrimSpace(containerRow.ContainerID)), + Detail: "container_id=" + strings.TrimSpace(containerRow.ContainerID), Metadata: map[string]any{ "container_id": strings.TrimSpace(containerRow.ContainerID), "namespace": strings.TrimSpace(containerRow.Namespace), @@ -806,11 +809,11 @@ func (s *Service) buildRuntimeChecks(ctx context.Context, row sqlc.Bot) ([]BotCh case "running", "created", "stopped", "paused": taskCheck.Status = BotCheckStatusOK taskCheck.Summary = "Container task state is reported." - taskCheck.Detail = fmt.Sprintf("status=%s", taskStatus) + taskCheck.Detail = "status=" + taskStatus case "": taskCheck.Detail = "status is empty" default: - taskCheck.Detail = fmt.Sprintf("unexpected status=%s", taskStatus) + taskCheck.Detail = "unexpected status=" + taskStatus } taskCheck.Metadata = map[string]any{"status": taskStatus} checks = append(checks, taskCheck) @@ -876,10 +879,8 @@ func (s *Service) ListCheckKeys(ctx context.Context, botID string) ([]string, er func (s *Service) RunCheck(ctx context.Context, botID, key string) (BotCheck, error) { // Try registered checkers first (they own dynamic keys like mcp.*). for _, checker := range s.checkers { - for _, k := range checker.CheckKeys(ctx, botID) { - if k == key { - return checker.RunCheck(ctx, botID, key), nil - } + if slices.Contains(checker.CheckKeys(ctx, botID), key) { + return checker.RunCheck(ctx, botID, key), nil } } // Fall back to builtin checks. diff --git a/internal/bots/service_test.go b/internal/bots/service_test.go index 8e6072fe..9fa683f9 100644 --- a/internal/bots/service_test.go +++ b/internal/bots/service_test.go @@ -25,11 +25,11 @@ type fakeDBTX struct { queryRowFunc func(ctx context.Context, sql string, args ...any) pgx.Row } -func (d *fakeDBTX) Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) { +func (d *fakeDBTX) Exec(context.Context, string, ...any) (pgconn.CommandTag, error) { return pgconn.CommandTag{}, nil } -func (d *fakeDBTX) Query(context.Context, string, ...interface{}) (pgx.Rows, error) { +func (d *fakeDBTX) Query(context.Context, string, ...any) (pgx.Rows, error) { return nil, nil } @@ -37,7 +37,7 @@ func (d *fakeDBTX) QueryRow(ctx context.Context, sql string, args ...any) pgx.Ro if d.queryRowFunc != nil { return d.queryRowFunc(ctx, sql, args...) } - return &fakeRow{scanFunc: func(dest ...any) error { return pgx.ErrNoRows }} + return &fakeRow{scanFunc: func(_ ...any) error { return pgx.ErrNoRows }} } // makeBotRow creates a fakeRow that populates a sqlc.Bot via Scan. @@ -85,7 +85,7 @@ func makeMemberRow(botID, userID pgtype.UUID) *fakeRow { } func makeNoRow() *fakeRow { - return &fakeRow{scanFunc: func(dest ...any) error { return pgx.ErrNoRows }} + return &fakeRow{scanFunc: func(_ ...any) error { return pgx.ErrNoRows }} } func mustParseUUID(s string) pgtype.UUID { @@ -188,7 +188,7 @@ func TestAuthorizeAccess(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { db := &fakeDBTX{ - queryRowFunc: func(_ context.Context, sql string, args ...any) pgx.Row { + queryRowFunc: func(_ context.Context, _ string, args ...any) pgx.Row { // Route to bot or member row based on query. if len(args) == 1 { return makeBotRow(botUUID, ownerUUID, tt.botType, tt.allowGst) @@ -209,10 +209,8 @@ func TestAuthorizeAccess(t *testing.T) { if tt.wantErrIs != nil && err.Error() != tt.wantErrIs.Error() { t.Fatalf("expected error %q, got %q", tt.wantErrIs, err) } - } else { - if err != nil { - t.Fatalf("unexpected error: %v", err) - } + } else if err != nil { + t.Fatalf("unexpected error: %v", err) } }) } diff --git a/internal/bots/types.go b/internal/bots/types.go index ad37359b..978f012b 100644 --- a/internal/bots/types.go +++ b/internal/bots/types.go @@ -101,23 +101,27 @@ type RuntimeChecker interface { RunCheck(ctx context.Context, botID, key string) BotCheck } +// Bot type identifiers for registry and policy. const ( BotTypePersonal = "personal" BotTypePublic = "public" ) +// Bot lifecycle status values stored in the database. const ( BotStatusCreating = "creating" BotStatusReady = "ready" BotStatusDeleting = "deleting" ) +// BotCheckState is the overall state of a single check (ok, issue, or unknown). const ( BotCheckStateOK = "ok" BotCheckStateIssue = "issue" BotCheckStateUnknown = "unknown" ) +// BotCheckStatus is the status level of a check result (ok, warn, error, unknown). const ( BotCheckStatusOK = "ok" BotCheckStatusWarn = "warn" @@ -125,6 +129,7 @@ const ( BotCheckStatusUnknown = "unknown" ) +// BotCheckKey identifies which check is run (container init, task, data path, etc.). const ( BotCheckKeyContainerInit = "container.init" BotCheckKeyContainerRecord = "container.record" @@ -133,6 +138,7 @@ const ( BotCheckKeyDelete = "bot.delete" ) +// MemberRole is the role of a user in a bot (owner, admin, or member). const ( MemberRoleOwner = "owner" MemberRoleAdmin = "admin" diff --git a/internal/channel/adapter.go b/internal/channel/adapter.go index 7343b138..b2ff1ff5 100644 --- a/internal/channel/adapter.go +++ b/internal/channel/adapter.go @@ -10,7 +10,7 @@ import ( var ErrStopNotSupported = errors.New("channel connection stop not supported") // InboundHandler is a callback invoked when a message arrives from a channel. -type InboundHandler func(ctx context.Context, cfg ChannelConfig, msg InboundMessage) error +type InboundHandler func(ctx context.Context, cfg Config, msg InboundMessage) error // StreamReplySender sends replies within a single inbound-processing scope. // It supports both one-shot delivery and streaming sessions. @@ -45,24 +45,24 @@ type ProcessingStatusHandle struct { // ProcessingStatusNotifier reports processing lifecycle updates to channel platforms. // Implementations should be best-effort and idempotent. type ProcessingStatusNotifier interface { - ProcessingStarted(ctx context.Context, cfg ChannelConfig, msg InboundMessage, info ProcessingStatusInfo) (ProcessingStatusHandle, error) - ProcessingCompleted(ctx context.Context, cfg ChannelConfig, msg InboundMessage, info ProcessingStatusInfo, handle ProcessingStatusHandle) error - ProcessingFailed(ctx context.Context, cfg ChannelConfig, msg InboundMessage, info ProcessingStatusInfo, handle ProcessingStatusHandle, cause error) error + ProcessingStarted(ctx context.Context, cfg Config, msg InboundMessage, info ProcessingStatusInfo) (ProcessingStatusHandle, error) + ProcessingCompleted(ctx context.Context, cfg Config, msg InboundMessage, info ProcessingStatusInfo, handle ProcessingStatusHandle) error + ProcessingFailed(ctx context.Context, cfg Config, msg InboundMessage, info ProcessingStatusInfo, handle ProcessingStatusHandle, cause error) error } // Adapter is the base interface every channel adapter must implement. type Adapter interface { - Type() ChannelType + Type() Type Descriptor() Descriptor } // Descriptor holds read-only metadata for a registered channel type. // It contains no behavior — all behavior is expressed through optional interfaces. type Descriptor struct { - Type ChannelType + Type Type DisplayName string Configless bool - Capabilities ChannelCapabilities + Capabilities Capabilities OutboundPolicy OutboundPolicy ConfigSchema ConfigSchema UserConfigSchema ConfigSchema @@ -89,42 +89,42 @@ type BindingMatcher interface { // Sender is an adapter capable of sending outbound messages. type Sender interface { - Send(ctx context.Context, cfg ChannelConfig, msg OutboundMessage) error + Send(ctx context.Context, cfg Config, msg OutboundMessage) error } // StreamSender is an adapter capable of opening outbound stream sessions. type StreamSender interface { - OpenStream(ctx context.Context, cfg ChannelConfig, target string, opts StreamOptions) (OutboundStream, error) + OpenStream(ctx context.Context, cfg Config, target string, opts StreamOptions) (OutboundStream, error) } // MessageEditor updates and deletes already-sent messages when supported. type MessageEditor interface { - Update(ctx context.Context, cfg ChannelConfig, target string, messageID string, msg Message) error - Unsend(ctx context.Context, cfg ChannelConfig, target string, messageID string) error + Update(ctx context.Context, cfg Config, target, messageID string, msg Message) error + Unsend(ctx context.Context, cfg Config, target, messageID string) error } // Reactor adds or removes emoji reactions on messages. type Reactor interface { - React(ctx context.Context, cfg ChannelConfig, target string, messageID string, emoji string) error - Unreact(ctx context.Context, cfg ChannelConfig, target string, messageID string, emoji string) error + React(ctx context.Context, cfg Config, target, messageID, emoji string) error + Unreact(ctx context.Context, cfg Config, target, messageID, emoji string) error } // SelfDiscoverer retrieves the adapter bot's own identity from the platform. -// The returned map is merged into ChannelConfig.SelfIdentity and persisted. +// The returned map is merged into Config.SelfIdentity and persisted. type SelfDiscoverer interface { DiscoverSelf(ctx context.Context, credentials map[string]any) (identity map[string]any, externalID string, err error) } // Receiver is an adapter capable of establishing a long-lived connection to receive messages. type Receiver interface { - Connect(ctx context.Context, cfg ChannelConfig, handler InboundHandler) (Connection, error) + Connect(ctx context.Context, cfg Config, handler InboundHandler) (Connection, error) } // Connection represents an active, long-lived link to a channel platform. type Connection interface { ConfigID() string BotID() string - ChannelType() ChannelType + Type() Type Stop(ctx context.Context) error Running() bool } @@ -133,17 +133,17 @@ type Connection interface { type BaseConnection struct { configID string botID string - channelType ChannelType + channelType Type stop func(ctx context.Context) error running atomic.Bool } // NewConnection creates a BaseConnection for the given config and stop function. -func NewConnection(cfg ChannelConfig, stop func(ctx context.Context) error) *BaseConnection { +func NewConnection(cfg Config, stop func(ctx context.Context) error) *BaseConnection { conn := &BaseConnection{ configID: cfg.ID, botID: cfg.BotID, - channelType: cfg.ChannelType, + channelType: cfg.Type, stop: stop, } conn.running.Store(true) @@ -160,8 +160,8 @@ func (c *BaseConnection) BotID() string { return c.botID } -// ChannelType returns the type of channel this connection serves. -func (c *BaseConnection) ChannelType() ChannelType { +// Type returns the type of channel this connection serves. +func (c *BaseConnection) Type() Type { return c.channelType } diff --git a/internal/channel/adapters/common/logging.go b/internal/channel/adapters/adapterutil/logging.go similarity index 77% rename from internal/channel/adapters/common/logging.go rename to internal/channel/adapters/adapterutil/logging.go index bc680421..218b6b8b 100644 --- a/internal/channel/adapters/common/logging.go +++ b/internal/channel/adapters/adapterutil/logging.go @@ -1,5 +1,5 @@ -// Package common provides shared utilities for channel adapters. -package common +// Package adapterutil provides shared utilities for channel adapters. +package adapterutil import "strings" diff --git a/internal/channel/adapters/feishu/config.go b/internal/channel/adapters/feishu/config.go index 74133ed4..710df2cd 100644 --- a/internal/channel/adapters/feishu/config.go +++ b/internal/channel/adapters/feishu/config.go @@ -1,7 +1,7 @@ package feishu import ( - "fmt" + "errors" "strings" "github.com/memohai/memoh/internal/channel" @@ -65,7 +65,7 @@ func resolveTarget(raw map[string]any) (string, error) { if cfg.UserID != "" { return "user_id:" + cfg.UserID, nil } - return "", fmt.Errorf("feishu binding is incomplete") + return "", errors.New("feishu binding is incomplete") } func matchBinding(raw map[string]any, criteria channel.BindingCriteria) bool { @@ -104,7 +104,7 @@ func parseConfig(raw map[string]any) (Config, error) { encryptKey := strings.TrimSpace(channel.ReadString(raw, "encryptKey", "encrypt_key")) verificationToken := strings.TrimSpace(channel.ReadString(raw, "verificationToken", "verification_token")) if appID == "" || appSecret == "" { - return Config{}, fmt.Errorf("feishu appId and appSecret are required") + return Config{}, errors.New("feishu appId and appSecret are required") } return Config{ AppID: appID, @@ -118,7 +118,7 @@ func parseUserConfig(raw map[string]any) (UserConfig, error) { openID := strings.TrimSpace(channel.ReadString(raw, "openId", "open_id")) userID := strings.TrimSpace(channel.ReadString(raw, "userId", "user_id")) if openID == "" && userID == "" { - return UserConfig{}, fmt.Errorf("feishu user config requires open_id or user_id") + return UserConfig{}, errors.New("feishu user config requires open_id or user_id") } return UserConfig{OpenID: openID, UserID: userID}, nil } diff --git a/internal/channel/adapters/feishu/descriptor.go b/internal/channel/adapters/feishu/descriptor.go index b1fd4129..88051bdc 100644 --- a/internal/channel/adapters/feishu/descriptor.go +++ b/internal/channel/adapters/feishu/descriptor.go @@ -4,4 +4,4 @@ package feishu import "github.com/memohai/memoh/internal/channel" // Type is the registered ChannelType identifier for Feishu. -const Type channel.ChannelType = "feishu" +const Type channel.Type = "feishu" diff --git a/internal/channel/adapters/feishu/directory.go b/internal/channel/adapters/feishu/directory.go index 8493eed9..2a4ae82a 100644 --- a/internal/channel/adapters/feishu/directory.go +++ b/internal/channel/adapters/feishu/directory.go @@ -2,6 +2,7 @@ package feishu import ( "context" + "errors" "fmt" "strings" @@ -28,7 +29,7 @@ func directoryLimit(n int) int { } // ListPeers lists users (peers) from Feishu contact, optionally filtered by query. -func (a *FeishuAdapter) ListPeers(ctx context.Context, cfg channel.ChannelConfig, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { +func (a *Adapter) ListPeers(ctx context.Context, cfg channel.Config, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { feishuCfg, err := parseConfig(cfg.Credentials) if err != nil { return nil, err @@ -60,7 +61,7 @@ func (a *FeishuAdapter) ListPeers(ctx context.Context, cfg channel.ChannelConfig } // ListGroups lists chat groups from Feishu IM, optionally filtered by query. -func (a *FeishuAdapter) ListGroups(ctx context.Context, cfg channel.ChannelConfig, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { +func (a *Adapter) ListGroups(ctx context.Context, cfg channel.Config, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { feishuCfg, err := parseConfig(cfg.Credentials) if err != nil { return nil, err @@ -104,17 +105,17 @@ func (a *FeishuAdapter) ListGroups(ctx context.Context, cfg channel.ChannelConfi } // ListGroupMembers lists members of a Feishu chat group. -func (a *FeishuAdapter) ListGroupMembers(ctx context.Context, cfg channel.ChannelConfig, groupID string, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { +func (a *Adapter) ListGroupMembers(ctx context.Context, cfg channel.Config, groupID string, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { feishuCfg, err := parseConfig(cfg.Credentials) if err != nil { return nil, err } chatID := strings.TrimSpace(groupID) - if strings.HasPrefix(chatID, "chat_id:") { - chatID = strings.TrimPrefix(chatID, "chat_id:") + if after, ok := strings.CutPrefix(chatID, "chat_id:"); ok { + chatID = after } if chatID == "" { - return nil, fmt.Errorf("feishu list group members: empty group id") + return nil, errors.New("feishu list group members: empty group id") } client := lark.NewClient(feishuCfg.AppID, feishuCfg.AppSecret) pageSize := directoryLimit(query.Limit) @@ -142,7 +143,7 @@ func (a *FeishuAdapter) ListGroupMembers(ctx context.Context, cfg channel.Channe } // ResolveEntry resolves an input string to a user or group DirectoryEntry. -func (a *FeishuAdapter) ResolveEntry(ctx context.Context, cfg channel.ChannelConfig, input string, kind channel.DirectoryEntryKind) (channel.DirectoryEntry, error) { +func (a *Adapter) ResolveEntry(ctx context.Context, cfg channel.Config, input string, kind channel.DirectoryEntryKind) (channel.DirectoryEntry, error) { feishuCfg, err := parseConfig(cfg.Credentials) if err != nil { return channel.DirectoryEntry{}, err @@ -159,7 +160,7 @@ func (a *FeishuAdapter) ResolveEntry(ctx context.Context, cfg channel.ChannelCon } } -func (a *FeishuAdapter) resolveUser(ctx context.Context, client *lark.Client, input string) (channel.DirectoryEntry, error) { +func (a *Adapter) resolveUser(ctx context.Context, client *lark.Client, input string) (channel.DirectoryEntry, error) { userID, userIDType := parseFeishuUserInput(input) if userID == "" { return channel.DirectoryEntry{}, fmt.Errorf("feishu resolve entry user: invalid input %q", input) @@ -176,15 +177,15 @@ func (a *FeishuAdapter) resolveUser(ctx context.Context, client *lark.Client, in return channel.DirectoryEntry{}, fmt.Errorf("feishu get user: code=%d msg=%s", resp.Code, resp.Msg) } if resp.Data == nil || resp.Data.User == nil { - return channel.DirectoryEntry{}, fmt.Errorf("feishu get user: empty response") + return channel.DirectoryEntry{}, errors.New("feishu get user: empty response") } return feishuUserToEntry(resp.Data.User), nil } -func (a *FeishuAdapter) resolveGroup(ctx context.Context, client *lark.Client, input string) (channel.DirectoryEntry, error) { +func (a *Adapter) resolveGroup(ctx context.Context, client *lark.Client, input string) (channel.DirectoryEntry, error) { chatID := strings.TrimSpace(input) - if strings.HasPrefix(chatID, "chat_id:") { - chatID = strings.TrimPrefix(chatID, "chat_id:") + if after, ok := strings.CutPrefix(chatID, "chat_id:"); ok { + chatID = after } if chatID == "" { return channel.DirectoryEntry{}, fmt.Errorf("feishu resolve entry group: invalid input %q", input) @@ -214,11 +215,11 @@ func parseFeishuUserInput(raw string) (userID, userIDType string) { if raw == "" { return "", "" } - if strings.HasPrefix(raw, "open_id:") { - return strings.TrimSpace(strings.TrimPrefix(raw, "open_id:")), larkcontact.UserIdTypeOpenId + if after, ok := strings.CutPrefix(raw, "open_id:"); ok { + return strings.TrimSpace(after), larkcontact.UserIdTypeOpenId } - if strings.HasPrefix(raw, "user_id:") { - return strings.TrimSpace(strings.TrimPrefix(raw, "user_id:")), larkcontact.UserIdTypeUserId + if after, ok := strings.CutPrefix(raw, "user_id:"); ok { + return strings.TrimSpace(after), larkcontact.UserIdTypeUserId } if strings.HasPrefix(raw, "ou_") { return raw, larkcontact.UserIdTypeOpenId diff --git a/internal/channel/adapters/feishu/feishu.go b/internal/channel/adapters/feishu/feishu.go index 9b8c4630..4d9e3ec3 100644 --- a/internal/channel/adapters/feishu/feishu.go +++ b/internal/channel/adapters/feishu/feishu.go @@ -3,6 +3,7 @@ package feishu import ( "context" "encoding/json" + "errors" "fmt" "log/slog" "net/http" @@ -17,11 +18,11 @@ import ( larkws "github.com/larksuite/oapi-sdk-go/v3/ws" "github.com/memohai/memoh/internal/channel" - "github.com/memohai/memoh/internal/channel/adapters/common" + "github.com/memohai/memoh/internal/channel/adapters/adapterutil" ) -// FeishuAdapter implements the channel.Adapter, channel.Sender, and channel.Receiver interfaces for Feishu. -type FeishuAdapter struct { +// Adapter implements the channel.Adapter, channel.Sender, and channel.Receiver interfaces for Feishu. +type Adapter struct { logger *slog.Logger } @@ -43,7 +44,7 @@ type larkProcessingReactionGateway struct { func (g *larkProcessingReactionGateway) Add(ctx context.Context, messageID, reactionType string) (string, error) { if g == nil || g.api == nil { - return "", fmt.Errorf("feishu reaction api not configured") + return "", errors.New("feishu reaction api not configured") } req := larkim.NewCreateMessageReactionReqBuilder(). MessageId(messageID). @@ -65,14 +66,14 @@ func (g *larkProcessingReactionGateway) Add(ctx context.Context, messageID, reac return "", fmt.Errorf("feishu add reaction failed: %s (code: %d)", msg, code) } if resp.Data == nil || resp.Data.ReactionId == nil || strings.TrimSpace(*resp.Data.ReactionId) == "" { - return "", fmt.Errorf("feishu add reaction failed: empty reaction id") + return "", errors.New("feishu add reaction failed: empty reaction id") } return strings.TrimSpace(*resp.Data.ReactionId), nil } func (g *larkProcessingReactionGateway) Remove(ctx context.Context, messageID, reactionID string) error { if g == nil || g.api == nil { - return fmt.Errorf("feishu reaction api not configured") + return errors.New("feishu reaction api not configured") } req := larkim.NewDeleteMessageReactionReqBuilder(). MessageId(messageID). @@ -94,27 +95,27 @@ func (g *larkProcessingReactionGateway) Remove(ctx context.Context, messageID, r return nil } -// NewFeishuAdapter creates a FeishuAdapter with the given logger. -func NewFeishuAdapter(log *slog.Logger) *FeishuAdapter { +// NewAdapter creates a Adapter with the given logger. +func NewAdapter(log *slog.Logger) *Adapter { if log == nil { log = slog.Default() } - return &FeishuAdapter{ + return &Adapter{ logger: log.With(slog.String("adapter", "feishu")), } } // Type returns the Feishu channel type. -func (a *FeishuAdapter) Type() channel.ChannelType { +func (a *Adapter) Type() channel.Type { return Type } // Descriptor returns the Feishu channel metadata. -func (a *FeishuAdapter) Descriptor() channel.Descriptor { +func (a *Adapter) Descriptor() channel.Descriptor { return channel.Descriptor{ Type: Type, DisplayName: "Feishu", - Capabilities: channel.ChannelCapabilities{ + Capabilities: channel.Capabilities{ Text: true, RichText: true, Attachments: true, @@ -158,7 +159,7 @@ func (a *FeishuAdapter) Descriptor() channel.Descriptor { } // ProcessingStarted adds a transient reaction to indicate the inbound message is being processed. -func (a *FeishuAdapter) ProcessingStarted(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, info channel.ProcessingStatusInfo) (channel.ProcessingStatusHandle, error) { +func (a *Adapter) ProcessingStarted(ctx context.Context, cfg channel.Config, _ channel.InboundMessage, info channel.ProcessingStatusInfo) (channel.ProcessingStatusHandle, error) { messageID := strings.TrimSpace(info.SourceMessageID) if messageID == "" { return channel.ProcessingStatusHandle{}, nil @@ -175,7 +176,7 @@ func (a *FeishuAdapter) ProcessingStarted(ctx context.Context, cfg channel.Chann } // ProcessingCompleted removes the transient processing reaction before output is sent. -func (a *FeishuAdapter) ProcessingCompleted(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, info channel.ProcessingStatusInfo, handle channel.ProcessingStatusHandle) error { +func (a *Adapter) ProcessingCompleted(ctx context.Context, cfg channel.Config, _ channel.InboundMessage, info channel.ProcessingStatusInfo, handle channel.ProcessingStatusHandle) error { messageID := strings.TrimSpace(info.SourceMessageID) reactionID := strings.TrimSpace(handle.Token) if messageID == "" || reactionID == "" { @@ -189,23 +190,23 @@ func (a *FeishuAdapter) ProcessingCompleted(ctx context.Context, cfg channel.Cha } // ProcessingFailed removes the transient processing reaction when chat processing fails. -func (a *FeishuAdapter) ProcessingFailed(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, info channel.ProcessingStatusInfo, handle channel.ProcessingStatusHandle, cause error) error { +func (a *Adapter) ProcessingFailed(ctx context.Context, cfg channel.Config, msg channel.InboundMessage, info channel.ProcessingStatusInfo, handle channel.ProcessingStatusHandle, _ error) error { return a.ProcessingCompleted(ctx, cfg, msg, info, handle) } -func (a *FeishuAdapter) processingReactionGateway(cfg channel.ChannelConfig) (processingReactionGateway, error) { +func (a *Adapter) processingReactionGateway(cfg channel.Config) (processingReactionGateway, error) { feishuCfg, err := parseConfig(cfg.Credentials) if err != nil { return nil, err } client := lark.NewClient(feishuCfg.AppID, feishuCfg.AppSecret) - gateway := &larkProcessingReactionGateway{api: client.Im.V1.MessageReaction} + gateway := &larkProcessingReactionGateway{api: client.Im.MessageReaction} return gateway, nil } // React adds an emoji reaction to a message (implements channel.Reactor). // The target parameter is unused for Feishu; reactions are keyed by message_id. -func (a *FeishuAdapter) React(ctx context.Context, cfg channel.ChannelConfig, _ string, messageID string, emoji string) error { +func (a *Adapter) React(ctx context.Context, cfg channel.Config, _, messageID, emoji string) error { gateway, err := a.processingReactionGateway(cfg) if err != nil { return err @@ -218,7 +219,7 @@ func (a *FeishuAdapter) React(ctx context.Context, cfg channel.ChannelConfig, _ // For Feishu, this requires the reaction_id which we don't have here, so we pass // the emoji as reaction_id. If the caller stored the reaction_id from React, they // should pass it as emoji. This is a best-effort operation. -func (a *FeishuAdapter) Unreact(ctx context.Context, cfg channel.ChannelConfig, _ string, messageID string, reactionID string) error { +func (a *Adapter) Unreact(ctx context.Context, cfg channel.Config, _, messageID, reactionID string) error { if strings.TrimSpace(reactionID) == "" { return nil } @@ -231,7 +232,7 @@ func (a *FeishuAdapter) Unreact(ctx context.Context, cfg channel.ChannelConfig, func addProcessingReaction(ctx context.Context, gateway processingReactionGateway, messageID, reactionType string) (string, error) { if gateway == nil { - return "", fmt.Errorf("processing reaction gateway is nil") + return "", errors.New("processing reaction gateway is nil") } msgID := strings.TrimSpace(messageID) if msgID == "" { @@ -239,14 +240,14 @@ func addProcessingReaction(ctx context.Context, gateway processingReactionGatewa } rxType := strings.TrimSpace(reactionType) if rxType == "" { - return "", fmt.Errorf("processing reaction type is empty") + return "", errors.New("processing reaction type is empty") } return gateway.Add(ctx, msgID, rxType) } func removeProcessingReaction(ctx context.Context, gateway processingReactionGateway, messageID, reactionID string) error { if gateway == nil { - return fmt.Errorf("processing reaction gateway is nil") + return errors.New("processing reaction gateway is nil") } msgID := strings.TrimSpace(messageID) rxID := strings.TrimSpace(reactionID) @@ -257,7 +258,7 @@ func removeProcessingReaction(ctx context.Context, gateway processingReactionGat } // DiscoverSelf retrieves the bot's own identity from the Feishu platform. -func (a *FeishuAdapter) DiscoverSelf(ctx context.Context, credentials map[string]any) (map[string]any, string, error) { +func (a *Adapter) DiscoverSelf(ctx context.Context, credentials map[string]any) (map[string]any, string, error) { cfg, err := parseConfig(credentials) if err != nil { return nil, "", err @@ -284,7 +285,7 @@ func (a *FeishuAdapter) DiscoverSelf(ctx context.Context, credentials map[string } openID := strings.TrimSpace(body.Bot.OpenID) if openID == "" { - return nil, "", fmt.Errorf("feishu discover self: empty open_id") + return nil, "", errors.New("feishu discover self: empty open_id") } identity := map[string]any{ "open_id": openID, @@ -299,37 +300,37 @@ func (a *FeishuAdapter) DiscoverSelf(ctx context.Context, credentials map[string } // NormalizeConfig validates and normalizes a Feishu channel configuration map. -func (a *FeishuAdapter) NormalizeConfig(raw map[string]any) (map[string]any, error) { +func (a *Adapter) NormalizeConfig(raw map[string]any) (map[string]any, error) { return normalizeConfig(raw) } // NormalizeUserConfig validates and normalizes a Feishu user-binding configuration map. -func (a *FeishuAdapter) NormalizeUserConfig(raw map[string]any) (map[string]any, error) { +func (a *Adapter) NormalizeUserConfig(raw map[string]any) (map[string]any, error) { return normalizeUserConfig(raw) } // NormalizeTarget normalizes a Feishu delivery target string. -func (a *FeishuAdapter) NormalizeTarget(raw string) string { +func (a *Adapter) NormalizeTarget(raw string) string { return normalizeTarget(raw) } // ResolveTarget derives a delivery target from a Feishu user-binding configuration. -func (a *FeishuAdapter) ResolveTarget(userConfig map[string]any) (string, error) { +func (a *Adapter) ResolveTarget(userConfig map[string]any) (string, error) { return resolveTarget(userConfig) } // MatchBinding reports whether a Feishu user binding matches the given criteria. -func (a *FeishuAdapter) MatchBinding(config map[string]any, criteria channel.BindingCriteria) bool { +func (a *Adapter) MatchBinding(config map[string]any, criteria channel.BindingCriteria) bool { return matchBinding(config, criteria) } // BuildUserConfig constructs a Feishu user-binding config from an Identity. -func (a *FeishuAdapter) BuildUserConfig(identity channel.Identity) map[string]any { +func (a *Adapter) BuildUserConfig(identity channel.Identity) map[string]any { return buildUserConfig(identity) } // Connect establishes a WebSocket connection to Feishu and forwards inbound messages to the handler. -func (a *FeishuAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig, handler channel.InboundHandler) (channel.Connection, error) { +func (a *Adapter) Connect(ctx context.Context, cfg channel.Config, handler channel.InboundHandler) (channel.Connection, error) { if a.logger != nil { a.logger.Info("start", slog.String("config_id", cfg.ID)) } @@ -372,7 +373,7 @@ func (a *FeishuAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig, rawMessageID = strings.TrimSpace(*event.Event.Message.MessageId) } if event.Event.Message.MessageType != nil { - rawMessageType = strings.TrimSpace(string(*event.Event.Message.MessageType)) + rawMessageType = strings.TrimSpace(*event.Event.Message.MessageType) } } if text == "" && len(msg.Message.Attachments) == 0 { @@ -403,7 +404,7 @@ func (a *FeishuAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig, slog.String("route_key", msg.RoutingKey()), slog.String("chat_type", msg.Conversation.Type), slog.Bool("is_mentioned", isMentioned), - slog.String("text", common.SummarizeText(text)), + slog.String("text", adapterutil.SummarizeText(text)), ) } go func() { @@ -469,7 +470,7 @@ func (a *FeishuAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig, } // Send delivers an outbound message to Feishu, handling attachments, rich text, and replies. -func (a *FeishuAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, msg channel.OutboundMessage) error { +func (a *Adapter) Send(ctx context.Context, cfg channel.Config, msg channel.OutboundMessage) error { feishuCfg, err := parseConfig(cfg.Credentials) if err != nil { if a.logger != nil { @@ -508,7 +509,7 @@ func (a *FeishuAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, msg msgType = larkim.MsgTypeText text := strings.TrimSpace(msg.Message.PlainText()) if text == "" { - return fmt.Errorf("message is required") + return errors.New("message is required") } payload, marshalErr := json.Marshal(map[string]string{"text": text}) if marshalErr != nil { @@ -537,20 +538,20 @@ func (a *FeishuAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, msg Uuid(uuid.NewString()). Build()). Build() - resp, err := client.Im.V1.Message.Reply(ctx, replyReq) + resp, err := client.Im.Message.Reply(ctx, replyReq) return a.handleReplyResponse(cfg.ID, resp, err) } - resp, err := client.Im.V1.Message.Create(ctx, req) + resp, err := client.Im.Message.Create(ctx, req) return a.handleResponse(cfg.ID, resp, err) } // OpenStream opens a Feishu streaming session. // The adapter strategy uses one interactive card and patches it incrementally. -func (a *FeishuAdapter) OpenStream(ctx context.Context, cfg channel.ChannelConfig, target string, opts channel.StreamOptions) (channel.OutboundStream, error) { +func (a *Adapter) OpenStream(ctx context.Context, cfg channel.Config, target string, opts channel.StreamOptions) (channel.OutboundStream, error) { target = strings.TrimSpace(target) if target == "" { - return nil, fmt.Errorf("feishu target is required") + return nil, errors.New("feishu target is required") } feishuCfg, err := parseConfig(cfg.Credentials) if err != nil { @@ -578,7 +579,7 @@ func (a *FeishuAdapter) OpenStream(ctx context.Context, cfg channel.ChannelConfi }, nil } -func (a *FeishuAdapter) handleReplyResponse(configID string, resp *larkim.ReplyMessageResp, err error) error { +func (a *Adapter) handleReplyResponse(configID string, resp *larkim.ReplyMessageResp, err error) error { if err != nil { if a.logger != nil { a.logger.Error("reply failed", slog.String("config_id", configID), slog.Any("error", err)) @@ -603,7 +604,7 @@ func (a *FeishuAdapter) handleReplyResponse(configID string, resp *larkim.ReplyM return nil } -func (a *FeishuAdapter) handleResponse(configID string, resp *larkim.CreateMessageResp, err error) error { +func (a *Adapter) handleResponse(configID string, resp *larkim.CreateMessageResp, err error) error { if err != nil { if a.logger != nil { a.logger.Error("send failed", slog.String("config_id", configID), slog.Any("error", err)) @@ -628,7 +629,7 @@ func (a *FeishuAdapter) handleResponse(configID string, resp *larkim.CreateMessa return nil } -func (a *FeishuAdapter) sendAttachment(ctx context.Context, client *lark.Client, receiveID, receiveType string, att channel.Attachment, text string) error { +func (a *Adapter) sendAttachment(ctx context.Context, client *lark.Client, receiveID, receiveType string, att channel.Attachment, _ string) error { var msgType string var contentMap map[string]string sourcePlatform := strings.TrimSpace(att.SourcePlatform) @@ -644,7 +645,7 @@ func (a *FeishuAdapter) sendAttachment(ctx context.Context, client *lark.Client, } else { downloadURL := strings.TrimSpace(att.URL) if downloadURL == "" { - return fmt.Errorf("failed to download attachment: url is required when platform key is unavailable") + return errors.New("failed to download attachment: url is required when platform key is unavailable") } httpReq, err := http.NewRequestWithContext(ctx, http.MethodGet, downloadURL, nil) if err != nil { @@ -655,7 +656,11 @@ func (a *FeishuAdapter) sendAttachment(ctx context.Context, client *lark.Client, if err != nil { return fmt.Errorf("failed to download attachment: %w", err) } - defer resp.Body.Close() + defer func() { + if err := resp.Body.Close(); err != nil { + a.logger.Warn("download attachment: close response body failed", slog.Any("error", err)) + } + }() if resp.StatusCode != http.StatusOK { return fmt.Errorf("failed to download attachment, status: %d", resp.StatusCode) } @@ -666,7 +671,7 @@ func (a *FeishuAdapter) sendAttachment(ctx context.Context, client *lark.Client, Image(resp.Body). Build()). Build() - uploadResp, err := client.Im.V1.Image.Create(ctx, uploadReq) + uploadResp, err := client.Im.Image.Create(ctx, uploadReq) if err != nil { return fmt.Errorf("failed to upload image: %w", err) } @@ -692,7 +697,7 @@ func (a *FeishuAdapter) sendAttachment(ctx context.Context, client *lark.Client, File(resp.Body). Build()). Build() - uploadResp, err := client.Im.V1.File.Create(ctx, uploadReq) + uploadResp, err := client.Im.File.Create(ctx, uploadReq) if err != nil { return fmt.Errorf("failed to upload file: %w", err) } @@ -722,7 +727,7 @@ func (a *FeishuAdapter) sendAttachment(ctx context.Context, client *lark.Client, Build()). Build() - sendResp, err := client.Im.V1.Message.Create(ctx, req) + sendResp, err := client.Im.Message.Create(ctx, req) return a.handleResponse("", sendResp, err) } @@ -747,7 +752,7 @@ func resolveFeishuFileType(name, mime string) string { } } -func (a *FeishuAdapter) buildPostContent(msg channel.Message) (string, error) { +func (a *Adapter) buildPostContent(msg channel.Message) (string, error) { type postContent struct { ZhCn struct { Title string `json:"title"` diff --git a/internal/channel/adapters/feishu/feishu_integration_test.go b/internal/channel/adapters/feishu/feishu_integration_test.go index d749556b..d51f3acd 100644 --- a/internal/channel/adapters/feishu/feishu_integration_test.go +++ b/internal/channel/adapters/feishu/feishu_integration_test.go @@ -28,9 +28,9 @@ func TestFeishuGateway_Integration(t *testing.T) { logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ Level: slog.LevelInfo, })) - adapter := NewFeishuAdapter(logger) + adapter := NewAdapter(logger) - cfg := channel.ChannelConfig{ + cfg := channel.Config{ ID: "integration-test-bot", Credentials: map[string]any{ "app_id": appID, @@ -45,7 +45,7 @@ func TestFeishuGateway_Integration(t *testing.T) { receivedChan := make(chan channel.InboundMessage, 1) - handler := func(ctx context.Context, c channel.ChannelConfig, msg channel.InboundMessage) error { + handler := func(ctx context.Context, c channel.Config, msg channel.InboundMessage) error { plainText := msg.Message.PlainText() logger.Info("received message in test", slog.String("text", plainText), @@ -115,7 +115,7 @@ func TestFeishuDiscoverSelf_Integration(t *testing.T) { if appID == "" || appSecret == "" { t.Skip("skipping integration test: FEISHU_APP_ID or FEISHU_APP_SECRET not set") } - adapter := NewFeishuAdapter(nil) + adapter := NewAdapter(nil) credentials := map[string]any{ "app_id": appID, "app_secret": appSecret, diff --git a/internal/channel/adapters/feishu/feishu_test.go b/internal/channel/adapters/feishu/feishu_test.go index 42da2b86..363a97f5 100644 --- a/internal/channel/adapters/feishu/feishu_test.go +++ b/internal/channel/adapters/feishu/feishu_test.go @@ -22,7 +22,7 @@ type fakeProcessingReactionGateway struct { removeErr error } -func (g *fakeProcessingReactionGateway) Add(ctx context.Context, messageID, reactionType string) (string, error) { +func (g *fakeProcessingReactionGateway) Add(_ context.Context, messageID, reactionType string) (string, error) { g.addCalls = append(g.addCalls, struct{ messageID, reactionType string }{ messageID: messageID, reactionType: reactionType, @@ -35,7 +35,7 @@ func (g *fakeProcessingReactionGateway) Add(ctx context.Context, messageID, reac return resp.reactionID, resp.err } -func (g *fakeProcessingReactionGateway) Remove(ctx context.Context, messageID, reactionID string) error { +func (g *fakeProcessingReactionGateway) Remove(_ context.Context, messageID, reactionID string) error { g.removeCalls = append(g.removeCalls, struct{ messageID, reactionID string }{ messageID: messageID, reactionID: reactionID, @@ -207,7 +207,7 @@ func TestExtractFeishuInboundImageAttachmentReference(t *testing.T) { func TestFeishuDescriptorIncludesStreamingAndMedia(t *testing.T) { t.Parallel() - adapter := NewFeishuAdapter(nil) + adapter := NewAdapter(nil) caps := adapter.Descriptor().Capabilities if !caps.Streaming { t.Fatal("expected streaming capability") @@ -534,10 +534,10 @@ func TestRemoveProcessingReactionNoopForEmptyToken(t *testing.T) { func TestFeishuProcessingStartedNoSourceMessageID(t *testing.T) { t.Parallel() - adapter := NewFeishuAdapter(nil) + adapter := NewAdapter(nil) handle, err := adapter.ProcessingStarted( context.Background(), - channel.ChannelConfig{}, + channel.Config{}, channel.InboundMessage{}, channel.ProcessingStatusInfo{}, ) @@ -552,10 +552,10 @@ func TestFeishuProcessingStartedNoSourceMessageID(t *testing.T) { func TestFeishuProcessingStartedRequiresConfigWhenSourceMessageExists(t *testing.T) { t.Parallel() - adapter := NewFeishuAdapter(nil) + adapter := NewAdapter(nil) _, err := adapter.ProcessingStarted( context.Background(), - channel.ChannelConfig{}, + channel.Config{}, channel.InboundMessage{}, channel.ProcessingStatusInfo{SourceMessageID: "om_5"}, ) @@ -567,10 +567,10 @@ func TestFeishuProcessingStartedRequiresConfigWhenSourceMessageExists(t *testing func TestFeishuProcessingCompletedNoopWithoutToken(t *testing.T) { t.Parallel() - adapter := NewFeishuAdapter(nil) + adapter := NewAdapter(nil) err := adapter.ProcessingCompleted( context.Background(), - channel.ChannelConfig{}, + channel.Config{}, channel.InboundMessage{}, channel.ProcessingStatusInfo{SourceMessageID: "om_6"}, channel.ProcessingStatusHandle{}, diff --git a/internal/channel/adapters/feishu/inbound.go b/internal/channel/adapters/feishu/inbound.go index 54bee3f4..7376f5fc 100644 --- a/internal/channel/adapters/feishu/inbound.go +++ b/internal/channel/adapters/feishu/inbound.go @@ -2,8 +2,10 @@ package feishu import ( "encoding/json" + "errors" "fmt" "log/slog" + "slices" "strings" "time" @@ -149,10 +151,7 @@ func isFeishuBotMentioned(contentMap map[string]any, mentions []*larkim.MentionE return true } } - if matchFeishuContentMention(contentMap, botOpenID) { - return true - } - return false + return matchFeishuContentMention(contentMap, botOpenID) } // hasAnyFeishuMention is the fallback when the bot's open_id is unknown. @@ -223,10 +222,8 @@ func hasFeishuAtTag(raw any) bool { } } case []any: - for _, child := range value { - if hasFeishuAtTag(child) { - return true - } + if slices.ContainsFunc(value, hasFeishuAtTag) { + return true } } return false @@ -303,16 +300,16 @@ func stringValue(raw any) string { // resolveFeishuReceiveID parses target (open_id:/user_id:/chat_id: prefix) and returns receiveID and receiveType. func resolveFeishuReceiveID(raw string) (string, string, error) { if raw == "" { - return "", "", fmt.Errorf("feishu target is required") + return "", "", errors.New("feishu target is required") } - if strings.HasPrefix(raw, "open_id:") { - return strings.TrimPrefix(raw, "open_id:"), larkim.ReceiveIdTypeOpenId, nil + if after, ok := strings.CutPrefix(raw, "open_id:"); ok { + return after, larkim.ReceiveIdTypeOpenId, nil } - if strings.HasPrefix(raw, "user_id:") { - return strings.TrimPrefix(raw, "user_id:"), larkim.ReceiveIdTypeUserId, nil + if after, ok := strings.CutPrefix(raw, "user_id:"); ok { + return after, larkim.ReceiveIdTypeUserId, nil } - if strings.HasPrefix(raw, "chat_id:") { - return strings.TrimPrefix(raw, "chat_id:"), larkim.ReceiveIdTypeChatId, nil + if after, ok := strings.CutPrefix(raw, "chat_id:"); ok { + return after, larkim.ReceiveIdTypeChatId, nil } return raw, larkim.ReceiveIdTypeOpenId, nil } diff --git a/internal/channel/adapters/feishu/stream.go b/internal/channel/adapters/feishu/stream.go index ce063119..e1c25e19 100644 --- a/internal/channel/adapters/feishu/stream.go +++ b/internal/channel/adapters/feishu/stream.go @@ -3,6 +3,7 @@ package feishu import ( "context" "encoding/json" + "errors" "fmt" "regexp" "strings" @@ -19,12 +20,12 @@ import ( const ( feishuStreamThinkingText = "Thinking..." feishuStreamPatchInterval = 700 * time.Millisecond - feishuStreamMaxRunes = 8000 + feishuStreamMaxRunes = 8000 ) type feishuOutboundStream struct { - adapter *FeishuAdapter - cfg channel.ChannelConfig + adapter *Adapter + cfg channel.Config target string reply *channel.ReplyRef client *lark.Client @@ -40,10 +41,10 @@ type feishuOutboundStream struct { func (s *feishuOutboundStream) Push(ctx context.Context, event channel.StreamEvent) error { if s == nil || s.adapter == nil { - return fmt.Errorf("feishu stream not configured") + return errors.New("feishu stream not configured") } if s.closed.Load() { - return fmt.Errorf("feishu stream is closed") + return errors.New("feishu stream is closed") } select { case <-ctx.Done(): @@ -130,7 +131,7 @@ func (s *feishuOutboundStream) ensureCard(ctx context.Context, text string) erro return nil } if s.client == nil { - return fmt.Errorf("feishu client not configured") + return errors.New("feishu client not configured") } content, err := buildFeishuStreamCardContent(text) if err != nil { @@ -145,7 +146,7 @@ func (s *feishuOutboundStream) ensureCard(ctx context.Context, text string) erro Uuid(uuid.NewString()). Build()). Build() - replyResp, err := s.client.Im.V1.Message.Reply(ctx, replyReq) + replyResp, err := s.client.Im.Message.Reply(ctx, replyReq) if err != nil { return err } @@ -157,7 +158,7 @@ func (s *feishuOutboundStream) ensureCard(ctx context.Context, text string) erro return fmt.Errorf("feishu stream reply failed: %s (code: %d)", msg, code) } if replyResp.Data == nil || replyResp.Data.MessageId == nil || strings.TrimSpace(*replyResp.Data.MessageId) == "" { - return fmt.Errorf("feishu stream reply failed: empty message id") + return errors.New("feishu stream reply failed: empty message id") } s.cardMessageID = strings.TrimSpace(*replyResp.Data.MessageId) s.lastPatched = normalizeFeishuStreamText(text) @@ -173,7 +174,7 @@ func (s *feishuOutboundStream) ensureCard(ctx context.Context, text string) erro Uuid(uuid.NewString()). Build()). Build() - createResp, err := s.client.Im.V1.Message.Create(ctx, createReq) + createResp, err := s.client.Im.Message.Create(ctx, createReq) if err != nil { return err } @@ -185,7 +186,7 @@ func (s *feishuOutboundStream) ensureCard(ctx context.Context, text string) erro return fmt.Errorf("feishu stream create failed: %s (code: %d)", msg, code) } if createResp.Data == nil || createResp.Data.MessageId == nil || strings.TrimSpace(*createResp.Data.MessageId) == "" { - return fmt.Errorf("feishu stream create failed: empty message id") + return errors.New("feishu stream create failed: empty message id") } s.cardMessageID = strings.TrimSpace(*createResp.Data.MessageId) s.lastPatched = normalizeFeishuStreamText(text) @@ -195,7 +196,7 @@ func (s *feishuOutboundStream) ensureCard(ctx context.Context, text string) erro func (s *feishuOutboundStream) patchCard(ctx context.Context, text string) error { if strings.TrimSpace(s.cardMessageID) == "" { - return fmt.Errorf("feishu stream card message not initialized") + return errors.New("feishu stream card message not initialized") } contentText := normalizeFeishuStreamText(text) if contentText == s.lastPatched { @@ -211,7 +212,7 @@ func (s *feishuOutboundStream) patchCard(ctx context.Context, text string) error Content(content). Build()). Build() - patchResp, err := s.client.Im.V1.Message.Patch(ctx, patchReq) + patchResp, err := s.client.Im.Message.Patch(ctx, patchReq) if err != nil { return err } diff --git a/internal/channel/adapters/local/cli.go b/internal/channel/adapters/local/cli.go index 3b026fb6..da455417 100644 --- a/internal/channel/adapters/local/cli.go +++ b/internal/channel/adapters/local/cli.go @@ -2,7 +2,7 @@ package local import ( "context" - "fmt" + "errors" "strings" "github.com/memohai/memoh/internal/channel" @@ -19,7 +19,7 @@ func NewCLIAdapter(hub *RouteHub) *CLIAdapter { } // Type returns the CLI channel type. -func (a *CLIAdapter) Type() channel.ChannelType { +func (a *CLIAdapter) Type() channel.Type { return CLIType } @@ -29,7 +29,7 @@ func (a *CLIAdapter) Descriptor() channel.Descriptor { Type: CLIType, DisplayName: "CLI", Configless: true, - Capabilities: channel.ChannelCapabilities{ + Capabilities: channel.Capabilities{ Text: true, Reply: true, Attachments: true, @@ -46,29 +46,29 @@ func (a *CLIAdapter) Descriptor() channel.Descriptor { } // Send publishes an outbound message to the CLI route hub. -func (a *CLIAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, msg channel.OutboundMessage) error { +func (a *CLIAdapter) Send(_ context.Context, _ channel.Config, msg channel.OutboundMessage) error { if a.hub == nil { - return fmt.Errorf("cli hub not configured") + return errors.New("cli hub not configured") } target := strings.TrimSpace(msg.Target) if target == "" { - return fmt.Errorf("cli target is required") + return errors.New("cli target is required") } if msg.Message.IsEmpty() { - return fmt.Errorf("message is required") + return errors.New("message is required") } a.hub.Publish(target, msg) return nil } // OpenStream opens a local stream session bound to the target route. -func (a *CLIAdapter) OpenStream(ctx context.Context, cfg channel.ChannelConfig, target string, opts channel.StreamOptions) (channel.OutboundStream, error) { +func (a *CLIAdapter) OpenStream(ctx context.Context, _ channel.Config, target string, _ channel.StreamOptions) (channel.OutboundStream, error) { if a.hub == nil { - return nil, fmt.Errorf("cli hub not configured") + return nil, errors.New("cli hub not configured") } target = strings.TrimSpace(target) if target == "" { - return nil, fmt.Errorf("cli target is required") + return nil, errors.New("cli target is required") } select { case <-ctx.Done(): diff --git a/internal/channel/adapters/local/descriptor.go b/internal/channel/adapters/local/descriptor.go index fe262a71..3c65bb26 100644 --- a/internal/channel/adapters/local/descriptor.go +++ b/internal/channel/adapters/local/descriptor.go @@ -5,7 +5,7 @@ import "github.com/memohai/memoh/internal/channel" const ( // CLIType is the registered ChannelType for the CLI adapter. - CLIType channel.ChannelType = "cli" + CLIType channel.Type = "cli" // WebType is the registered ChannelType for the Web adapter. - WebType channel.ChannelType = "web" + WebType channel.Type = "web" ) diff --git a/internal/channel/adapters/local/hub.go b/internal/channel/adapters/local/hub.go index 0fef9edb..935a2f93 100644 --- a/internal/channel/adapters/local/hub.go +++ b/internal/channel/adapters/local/hub.go @@ -2,7 +2,7 @@ package local import ( "context" - "fmt" + "errors" "sync" "sync/atomic" @@ -107,10 +107,10 @@ func newLocalOutboundStream(hub *RouteHub, target string) channel.OutboundStream func (s *localOutboundStream) Push(ctx context.Context, event channel.StreamEvent) error { if s == nil || s.hub == nil { - return fmt.Errorf("route hub not configured") + return errors.New("route hub not configured") } if s.closed.Load() { - return fmt.Errorf("stream is closed") + return errors.New("stream is closed") } select { case <-ctx.Done(): diff --git a/internal/channel/adapters/local/web.go b/internal/channel/adapters/local/web.go index 70309748..3bcf78c2 100644 --- a/internal/channel/adapters/local/web.go +++ b/internal/channel/adapters/local/web.go @@ -2,7 +2,7 @@ package local import ( "context" - "fmt" + "errors" "strings" "github.com/memohai/memoh/internal/channel" @@ -19,7 +19,7 @@ func NewWebAdapter(hub *RouteHub) *WebAdapter { } // Type returns the Web channel type. -func (a *WebAdapter) Type() channel.ChannelType { +func (a *WebAdapter) Type() channel.Type { return WebType } @@ -29,7 +29,7 @@ func (a *WebAdapter) Descriptor() channel.Descriptor { Type: WebType, DisplayName: "Web", Configless: true, - Capabilities: channel.ChannelCapabilities{ + Capabilities: channel.Capabilities{ Text: true, Reply: true, Attachments: true, @@ -46,29 +46,29 @@ func (a *WebAdapter) Descriptor() channel.Descriptor { } // Send publishes an outbound message to the Web route hub. -func (a *WebAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, msg channel.OutboundMessage) error { +func (a *WebAdapter) Send(_ context.Context, _ channel.Config, msg channel.OutboundMessage) error { if a.hub == nil { - return fmt.Errorf("web hub not configured") + return errors.New("web hub not configured") } target := strings.TrimSpace(msg.Target) if target == "" { - return fmt.Errorf("web target is required") + return errors.New("web target is required") } if msg.Message.IsEmpty() { - return fmt.Errorf("message is required") + return errors.New("message is required") } a.hub.Publish(target, msg) return nil } // OpenStream opens a local stream session bound to the target route. -func (a *WebAdapter) OpenStream(ctx context.Context, cfg channel.ChannelConfig, target string, opts channel.StreamOptions) (channel.OutboundStream, error) { +func (a *WebAdapter) OpenStream(ctx context.Context, _ channel.Config, target string, _ channel.StreamOptions) (channel.OutboundStream, error) { if a.hub == nil { - return nil, fmt.Errorf("web hub not configured") + return nil, errors.New("web hub not configured") } target = strings.TrimSpace(target) if target == "" { - return nil, fmt.Errorf("web target is required") + return nil, errors.New("web target is required") } select { case <-ctx.Done(): diff --git a/internal/channel/adapters/telegram/config.go b/internal/channel/adapters/telegram/config.go index d77e2feb..bf835a6d 100644 --- a/internal/channel/adapters/telegram/config.go +++ b/internal/channel/adapters/telegram/config.go @@ -1,7 +1,7 @@ package telegram import ( - "fmt" + "errors" "strings" "github.com/memohai/memoh/internal/channel" @@ -65,7 +65,7 @@ func resolveTarget(raw map[string]any) (string, error) { } return name, nil } - return "", fmt.Errorf("telegram binding is incomplete") + return "", errors.New("telegram binding is incomplete") } func matchBinding(raw map[string]any, criteria channel.BindingCriteria) bool { @@ -107,7 +107,7 @@ func buildUserConfig(identity channel.Identity) map[string]any { func parseConfig(raw map[string]any) (Config, error) { token := strings.TrimSpace(channel.ReadString(raw, "botToken", "bot_token")) if token == "" { - return Config{}, fmt.Errorf("telegram botToken is required") + return Config{}, errors.New("telegram botToken is required") } return Config{BotToken: token}, nil } @@ -117,7 +117,7 @@ func parseUserConfig(raw map[string]any) (UserConfig, error) { userID := strings.TrimSpace(channel.ReadString(raw, "userId", "user_id")) chatID := strings.TrimSpace(channel.ReadString(raw, "chatId", "chat_id")) if username == "" && userID == "" && chatID == "" { - return UserConfig{}, fmt.Errorf("telegram user config requires username, user_id, or chat_id") + return UserConfig{}, errors.New("telegram user config requires username, user_id, or chat_id") } return UserConfig{ Username: username, diff --git a/internal/channel/adapters/telegram/descriptor.go b/internal/channel/adapters/telegram/descriptor.go index 4638e9aa..dedbf482 100644 --- a/internal/channel/adapters/telegram/descriptor.go +++ b/internal/channel/adapters/telegram/descriptor.go @@ -4,4 +4,4 @@ package telegram import "github.com/memohai/memoh/internal/channel" // Type is the registered ChannelType identifier for Telegram. -const Type channel.ChannelType = "telegram" +const Type channel.Type = "telegram" diff --git a/internal/channel/adapters/telegram/directory.go b/internal/channel/adapters/telegram/directory.go index 7901cdeb..4614608f 100644 --- a/internal/channel/adapters/telegram/directory.go +++ b/internal/channel/adapters/telegram/directory.go @@ -2,6 +2,7 @@ package telegram import ( "context" + "errors" "fmt" "strconv" "strings" @@ -27,17 +28,17 @@ func directoryLimit(n int) int { } // ListPeers returns users the bot can reach. Telegram Bot API does not provide a list of users; returns empty. -func (a *TelegramAdapter) ListPeers(ctx context.Context, cfg channel.ChannelConfig, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { +func (a *Adapter) ListPeers(_ context.Context, _ channel.Config, _ channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { return nil, nil } // ListGroups returns chats the bot is in. Telegram Bot API does not provide a list of chats; returns empty. -func (a *TelegramAdapter) ListGroups(ctx context.Context, cfg channel.ChannelConfig, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { +func (a *Adapter) ListGroups(_ context.Context, _ channel.Config, _ channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { return nil, nil } // ListGroupMembers returns administrators of the given group (Telegram only exposes admin list, not full members). -func (a *TelegramAdapter) ListGroupMembers(ctx context.Context, cfg channel.ChannelConfig, groupID string, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { +func (a *Adapter) ListGroupMembers(_ context.Context, cfg channel.Config, groupID string, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { telegramCfg, err := parseConfig(cfg.Credentials) if err != nil { return nil, err @@ -76,7 +77,7 @@ func (a *TelegramAdapter) ListGroupMembers(ctx context.Context, cfg channel.Chan } // ResolveEntry resolves an input string to a user or group DirectoryEntry using getChat / getChatMember. -func (a *TelegramAdapter) ResolveEntry(ctx context.Context, cfg channel.ChannelConfig, input string, kind channel.DirectoryEntryKind) (channel.DirectoryEntry, error) { +func (a *Adapter) ResolveEntry(ctx context.Context, cfg channel.Config, input string, kind channel.DirectoryEntryKind) (channel.DirectoryEntry, error) { telegramCfg, err := parseConfig(cfg.Credentials) if err != nil { return channel.DirectoryEntry{}, err @@ -96,7 +97,7 @@ func (a *TelegramAdapter) ResolveEntry(ctx context.Context, cfg channel.ChannelC } } -func (a *TelegramAdapter) resolveTelegramUser(ctx context.Context, bot *tgbotapi.BotAPI, input string) (channel.DirectoryEntry, error) { +func (a *Adapter) resolveTelegramUser(_ context.Context, bot *tgbotapi.BotAPI, input string) (channel.DirectoryEntry, error) { chatID, userID := parseTelegramUserInput(input) if chatID == 0 && userID == 0 { return channel.DirectoryEntry{}, fmt.Errorf("telegram resolve entry user: invalid input %q", input) @@ -114,7 +115,7 @@ func (a *TelegramAdapter) resolveTelegramUser(ctx context.Context, bot *tgbotapi return channel.DirectoryEntry{}, fmt.Errorf("telegram get chat member: %w", err) } if member.User == nil { - return channel.DirectoryEntry{}, fmt.Errorf("telegram get chat member: empty user") + return channel.DirectoryEntry{}, errors.New("telegram get chat member: empty user") } return telegramUserToEntry(member.User), nil } @@ -147,7 +148,7 @@ func (a *TelegramAdapter) resolveTelegramUser(ctx context.Context, bot *tgbotapi }, nil } -func (a *TelegramAdapter) resolveTelegramGroup(ctx context.Context, bot *tgbotapi.BotAPI, input string) (channel.DirectoryEntry, error) { +func (a *Adapter) resolveTelegramGroup(_ context.Context, bot *tgbotapi.BotAPI, input string) (channel.DirectoryEntry, error) { chatID, superGroupUsername := parseTelegramChatInput(input) if chatID == 0 && superGroupUsername == "" { return channel.DirectoryEntry{}, fmt.Errorf("telegram resolve entry group: invalid input %q", input) @@ -169,11 +170,11 @@ func (a *TelegramAdapter) resolveTelegramGroup(ctx context.Context, bot *tgbotap handle = "@" + handle } return channel.DirectoryEntry{ - Kind: channel.DirectoryEntryGroup, - ID: idStr, - Name: name, - Handle: handle, - Metadata: map[string]any{"chat_id": idStr, "type": chat.Type}, + Kind: channel.DirectoryEntryGroup, + ID: idStr, + Name: name, + Handle: handle, + Metadata: map[string]any{"chat_id": idStr, "type": chat.Type}, }, nil } @@ -199,9 +200,9 @@ func parseTelegramUserInput(input string) (chatID, userID int64) { if input == "" { return 0, 0 } - if idx := strings.Index(input, ":"); idx >= 0 { - left := strings.TrimSpace(input[:idx]) - right := strings.TrimSpace(input[idx+1:]) + if before, after, ok := strings.Cut(input, ":"); ok { + left := strings.TrimSpace(before) + right := strings.TrimSpace(after) cid, err1 := strconv.ParseInt(left, 10, 64) uid, err2 := strconv.ParseInt(right, 10, 64) if err1 == nil && err2 == nil && cid != 0 && uid != 0 { diff --git a/internal/channel/adapters/telegram/directory_test.go b/internal/channel/adapters/telegram/directory_test.go index ebaf313e..936b267e 100644 --- a/internal/channel/adapters/telegram/directory_test.go +++ b/internal/channel/adapters/telegram/directory_test.go @@ -32,8 +32,8 @@ func Test_directoryLimit(t *testing.T) { func Test_parseTelegramChatInput(t *testing.T) { tests := []struct { - input string - wantID int64 + input string + wantID int64 wantUsername string }{ {"123456789", 123456789, ""}, diff --git a/internal/channel/adapters/telegram/logger.go b/internal/channel/adapters/telegram/logger.go index e2fa7645..2d0943a2 100644 --- a/internal/channel/adapters/telegram/logger.go +++ b/internal/channel/adapters/telegram/logger.go @@ -10,10 +10,10 @@ type slogBotLogger struct { log *slog.Logger } -func (s *slogBotLogger) Println(v ...interface{}) { +func (s *slogBotLogger) Println(v ...any) { s.log.Warn(fmt.Sprint(v...)) } -func (s *slogBotLogger) Printf(format string, v ...interface{}) { +func (s *slogBotLogger) Printf(format string, v ...any) { s.log.Warn(fmt.Sprintf(format, v...)) } diff --git a/internal/channel/adapters/telegram/markdown.go b/internal/channel/adapters/telegram/markdown.go index a3fba424..6cbd8541 100644 --- a/internal/channel/adapters/telegram/markdown.go +++ b/internal/channel/adapters/telegram/markdown.go @@ -11,19 +11,17 @@ import ( ) const ( - codeBlockPlaceholder = "\x00CB" inlineCodePlaceholder = "\x00IC" ) var ( - reCodeBlockFence = regexp.MustCompile("(?s)```(\\w*)\\n?(.*?)```") - reInlineCode = regexp.MustCompile("`([^`\\n]+?)`") - reBold = regexp.MustCompile(`\*\*(.+?)\*\*`) - reStrike = regexp.MustCompile(`~~(.+?)~~`) - reLink = regexp.MustCompile(`\[([^\]]+?)\]\(([^)]+?)\)`) - reHeading = regexp.MustCompile(`(?m)^#{1,6}\s+(.+)$`) - reListBullet = regexp.MustCompile(`(?m)^(\s*)[-+]\s`) - reItalic = regexp.MustCompile(`\*([^*\n]+?)\*`) + reInlineCode = regexp.MustCompile("`([^`\\n]+?)`") + reBold = regexp.MustCompile(`\*\*(.+?)\*\*`) + reStrike = regexp.MustCompile(`~~(.+?)~~`) + reLink = regexp.MustCompile(`\[([^\]]+?)\]\(([^)]+?)\)`) + reHeading = regexp.MustCompile(`(?m)^#{1,6}\s+(.+)$`) + reListBullet = regexp.MustCompile(`(?m)^(\s*)[-+]\s`) + reItalic = regexp.MustCompile(`\*([^*\n]+?)\*`) ) // formatTelegramOutput converts standard markdown to Telegram-compatible HTML @@ -86,8 +84,8 @@ func splitCodeBlocks(text string) []string { } segments = append(segments, text[:start]) rest := text[start+len(fence):] - end := strings.Index(rest, fence) - if end < 0 { + before, after, ok := strings.Cut(rest, fence) + if !ok { // Unclosed code block: treat remainder as normal text. segments = append(segments, text[start:]) // Remove the last normal segment and replace with full remainder. @@ -95,16 +93,16 @@ func splitCodeBlocks(text string) []string { segments = segments[:len(segments)-1] break } - segments = append(segments, rest[:end]) - text = rest[end+len(fence):] + segments = append(segments, before) + text = after } return segments } // extractCodeBlockLang separates the optional language tag from code content. func extractCodeBlockLang(block string) (string, string) { - idx := strings.IndexByte(block, '\n') - if idx < 0 { + before, after, ok := strings.Cut(block, "\n") + if !ok { // Single line: check if it looks like a language tag. trimmed := strings.TrimSpace(block) if trimmed != "" && !strings.Contains(trimmed, " ") && len(trimmed) <= 20 { @@ -112,8 +110,8 @@ func extractCodeBlockLang(block string) (string, string) { } return "", block } - firstLine := strings.TrimSpace(block[:idx]) - rest := block[idx+1:] + firstLine := strings.TrimSpace(before) + rest := after if firstLine != "" && !strings.Contains(firstLine, " ") && len(firstLine) <= 20 { return firstLine, rest } diff --git a/internal/channel/adapters/telegram/stream.go b/internal/channel/adapters/telegram/stream.go index 0ef9976e..798d80dd 100644 --- a/internal/channel/adapters/telegram/stream.go +++ b/internal/channel/adapters/telegram/stream.go @@ -2,6 +2,7 @@ package telegram import ( "context" + "errors" "fmt" "log/slog" "strings" @@ -16,11 +17,11 @@ import ( const telegramStreamEditThrottle = 5000 * time.Millisecond -var testEditFunc func(bot *tgbotapi.BotAPI, chatID int64, msgID int, text string, parseMode string) error +var testEditFunc func(bot *tgbotapi.BotAPI, chatID int64, msgID int, text, parseMode string) error type telegramOutboundStream struct { - adapter *TelegramAdapter - cfg channel.ChannelConfig + adapter *Adapter + cfg channel.Config target string reply *channel.ReplyRef parseMode string @@ -33,7 +34,7 @@ type telegramOutboundStream struct { lastEditedAt time.Time } -func (s *telegramOutboundStream) getBotAndReply(ctx context.Context) (bot *tgbotapi.BotAPI, replyTo int, err error) { +func (s *telegramOutboundStream) getBotAndReply(_ context.Context) (bot *tgbotapi.BotAPI, replyTo int, err error) { telegramCfg, err := parseConfig(s.cfg.Credentials) if err != nil { return nil, 0, err @@ -171,10 +172,10 @@ func (s *telegramOutboundStream) editStreamMessageFinal(ctx context.Context, tex func (s *telegramOutboundStream) Push(ctx context.Context, event channel.StreamEvent) error { if s == nil || s.adapter == nil { - return fmt.Errorf("telegram stream not configured") + return errors.New("telegram stream not configured") } if s.closed.Load() { - return fmt.Errorf("telegram stream is closed") + return errors.New("telegram stream is closed") } select { case <-ctx.Done(): diff --git a/internal/channel/adapters/telegram/stream_test.go b/internal/channel/adapters/telegram/stream_test.go index 106cc3aa..a3e9f6d4 100644 --- a/internal/channel/adapters/telegram/stream_test.go +++ b/internal/channel/adapters/telegram/stream_test.go @@ -2,11 +2,13 @@ package telegram import ( "context" + "errors" "strings" "testing" "time" tgbotapi "github.com/go-telegram-bot-api/telegram-bot-api/v5" + "github.com/memohai/memoh/internal/channel" ) @@ -23,7 +25,7 @@ func TestTelegramOutboundStream_CloseNil(t *testing.T) { func TestTelegramOutboundStream_PushClosed(t *testing.T) { t.Parallel() - adapter := NewTelegramAdapter(nil) + adapter := NewAdapter(nil) s := &telegramOutboundStream{adapter: adapter} s.closed.Store(true) @@ -40,7 +42,7 @@ func TestTelegramOutboundStream_PushClosed(t *testing.T) { func TestTelegramOutboundStream_PushStatusNoOp(t *testing.T) { t.Parallel() - adapter := NewTelegramAdapter(nil) + adapter := NewAdapter(nil) s := &telegramOutboundStream{adapter: adapter} ctx := context.Background() @@ -67,7 +69,7 @@ func TestTelegramOutboundStream_PushNilAdapter(t *testing.T) { func TestTelegramOutboundStream_PushUnsupportedEventType(t *testing.T) { t.Parallel() - adapter := NewTelegramAdapter(nil) + adapter := NewAdapter(nil) s := &telegramOutboundStream{adapter: adapter} ctx := context.Background() @@ -83,7 +85,7 @@ func TestTelegramOutboundStream_PushUnsupportedEventType(t *testing.T) { func TestTelegramOutboundStream_PushEmptyDeltaNoOp(t *testing.T) { t.Parallel() - adapter := NewTelegramAdapter(nil) + adapter := NewAdapter(nil) s := &telegramOutboundStream{adapter: adapter} ctx := context.Background() @@ -96,7 +98,7 @@ func TestTelegramOutboundStream_PushEmptyDeltaNoOp(t *testing.T) { func TestTelegramOutboundStream_PushErrorEventEmptyNoOp(t *testing.T) { t.Parallel() - adapter := NewTelegramAdapter(nil) + adapter := NewAdapter(nil) s := &telegramOutboundStream{adapter: adapter} ctx := context.Background() @@ -109,13 +111,13 @@ func TestTelegramOutboundStream_PushErrorEventEmptyNoOp(t *testing.T) { func TestTelegramOutboundStream_CloseContextCanceled(t *testing.T) { t.Parallel() - adapter := NewTelegramAdapter(nil) + adapter := NewAdapter(nil) s := &telegramOutboundStream{adapter: adapter} ctx, cancel := context.WithCancel(context.Background()) cancel() err := s.Close(ctx) - if err != context.Canceled { + if !errors.Is(err, context.Canceled) { t.Fatalf("Close with canceled context should return context.Canceled: %v", err) } } @@ -124,7 +126,7 @@ func TestTelegramOutboundStream_CloseContextCanceled(t *testing.T) { func TestEditStreamMessage_NoEditWhenSameContent(t *testing.T) { t.Parallel() - adapter := NewTelegramAdapter(nil) + adapter := NewAdapter(nil) s := &telegramOutboundStream{ adapter: adapter, streamChatID: 1, @@ -156,7 +158,7 @@ func TestEditStreamMessage_NoEditWhenSameContent(t *testing.T) { func TestEditStreamMessage_NoEditWhenMessageNotSent(t *testing.T) { t.Parallel() - adapter := NewTelegramAdapter(nil) + adapter := NewAdapter(nil) s := &telegramOutboundStream{adapter: adapter, streamMsgID: 0} ctx := context.Background() @@ -169,7 +171,7 @@ func TestEditStreamMessage_NoEditWhenMessageNotSent(t *testing.T) { func TestEditStreamMessage_NoEditWhenThrottled(t *testing.T) { t.Parallel() - adapter := NewTelegramAdapter(nil) + adapter := NewAdapter(nil) s := &telegramOutboundStream{ adapter: adapter, streamChatID: 1, @@ -188,11 +190,11 @@ func TestEditStreamMessage_NoEditWhenThrottled(t *testing.T) { func TestEditStreamMessage_429SetsBackoffAndReturnsNil(t *testing.T) { t.Parallel() - adapter := NewTelegramAdapter(nil) + adapter := NewAdapter(nil) before := time.Now().Add(-time.Minute) s := &telegramOutboundStream{ adapter: adapter, - cfg: channel.ChannelConfig{ID: "test", Credentials: map[string]any{"bot_token": "fake"}}, + cfg: channel.Config{ID: "test", Credentials: map[string]any{"bot_token": "fake"}}, streamChatID: 1, streamMsgID: 1, lastEdited: "a", @@ -202,7 +204,7 @@ func TestEditStreamMessage_429SetsBackoffAndReturnsNil(t *testing.T) { origGetBot := getOrCreateBotForTest origEdit := testEditFunc - getOrCreateBotForTest = func(_ *TelegramAdapter, _, _ string) (*tgbotapi.BotAPI, error) { + getOrCreateBotForTest = func(_ *Adapter, _, _ string) (*tgbotapi.BotAPI, error) { return &tgbotapi.BotAPI{Token: "fake"}, nil } testEditFunc = func(*tgbotapi.BotAPI, int64, int, string, string) error { @@ -236,10 +238,10 @@ func TestEditStreamMessage_429SetsBackoffAndReturnsNil(t *testing.T) { func TestEditStreamMessageFinal_Success(t *testing.T) { t.Parallel() - adapter := NewTelegramAdapter(nil) + adapter := NewAdapter(nil) s := &telegramOutboundStream{ adapter: adapter, - cfg: channel.ChannelConfig{ID: "test", Credentials: map[string]any{"bot_token": "fake"}}, + cfg: channel.Config{ID: "test", Credentials: map[string]any{"bot_token": "fake"}}, streamChatID: 1, streamMsgID: 1, lastEdited: "a", @@ -249,7 +251,7 @@ func TestEditStreamMessageFinal_Success(t *testing.T) { origGetBot := getOrCreateBotForTest origEdit := testEditFunc - getOrCreateBotForTest = func(_ *TelegramAdapter, _, _ string) (*tgbotapi.BotAPI, error) { + getOrCreateBotForTest = func(_ *Adapter, _, _ string) (*tgbotapi.BotAPI, error) { return &tgbotapi.BotAPI{Token: "fake"}, nil } testEditFunc = func(*tgbotapi.BotAPI, int64, int, string, string) error { @@ -275,7 +277,7 @@ func TestEditStreamMessageFinal_Success(t *testing.T) { func TestEditStreamMessageFinal_SameContentNoOp(t *testing.T) { t.Parallel() - adapter := NewTelegramAdapter(nil) + adapter := NewAdapter(nil) s := &telegramOutboundStream{ adapter: adapter, streamChatID: 1, @@ -294,7 +296,7 @@ func TestEditStreamMessageFinal_SameContentNoOp(t *testing.T) { func TestEditStreamMessageFinal_NoMessageNoOp(t *testing.T) { t.Parallel() - adapter := NewTelegramAdapter(nil) + adapter := NewAdapter(nil) s := &telegramOutboundStream{adapter: adapter, streamMsgID: 0} ctx := context.Background() diff --git a/internal/channel/adapters/telegram/telegram.go b/internal/channel/adapters/telegram/telegram.go index 78fa5cee..2bb6db05 100644 --- a/internal/channel/adapters/telegram/telegram.go +++ b/internal/channel/adapters/telegram/telegram.go @@ -14,24 +14,24 @@ import ( tgbotapi "github.com/go-telegram-bot-api/telegram-bot-api/v5" "github.com/memohai/memoh/internal/channel" - "github.com/memohai/memoh/internal/channel/adapters/common" + "github.com/memohai/memoh/internal/channel/adapters/adapterutil" ) const telegramMaxMessageLength = 4096 -// TelegramAdapter implements the channel.Adapter, channel.Sender, and channel.Receiver interfaces for Telegram. -type TelegramAdapter struct { +// Adapter implements the channel.Adapter, channel.Sender, and channel.Receiver interfaces for Telegram. +type Adapter struct { logger *slog.Logger mu sync.RWMutex bots map[string]*tgbotapi.BotAPI // keyed by bot token } -// NewTelegramAdapter creates a TelegramAdapter with the given logger. -func NewTelegramAdapter(log *slog.Logger) *TelegramAdapter { +// NewAdapter creates a Adapter with the given logger. +func NewAdapter(log *slog.Logger) *Adapter { if log == nil { log = slog.Default() } - adapter := &TelegramAdapter{ + adapter := &Adapter{ logger: log.With(slog.String("adapter", "telegram")), bots: make(map[string]*tgbotapi.BotAPI), } @@ -39,9 +39,9 @@ func NewTelegramAdapter(log *slog.Logger) *TelegramAdapter { return adapter } -var getOrCreateBotForTest func(a *TelegramAdapter, token, configID string) (*tgbotapi.BotAPI, error) +var getOrCreateBotForTest func(a *Adapter, token, configID string) (*tgbotapi.BotAPI, error) -func (a *TelegramAdapter) getOrCreateBot(token, configID string) (*tgbotapi.BotAPI, error) { +func (a *Adapter) getOrCreateBot(token, configID string) (*tgbotapi.BotAPI, error) { if getOrCreateBotForTest != nil { return getOrCreateBotForTest(a, token, configID) } @@ -68,16 +68,16 @@ func (a *TelegramAdapter) getOrCreateBot(token, configID string) (*tgbotapi.BotA } // Type returns the Telegram channel type. -func (a *TelegramAdapter) Type() channel.ChannelType { +func (a *Adapter) Type() channel.Type { return Type } // Descriptor returns the Telegram channel metadata. -func (a *TelegramAdapter) Descriptor() channel.Descriptor { +func (a *Adapter) Descriptor() channel.Descriptor { return channel.Descriptor{ Type: Type, DisplayName: "Telegram", - Capabilities: channel.ChannelCapabilities{ + Capabilities: channel.Capabilities{ Text: true, Markdown: true, Reply: true, @@ -115,37 +115,37 @@ func (a *TelegramAdapter) Descriptor() channel.Descriptor { } // NormalizeConfig validates and normalizes a Telegram channel configuration map. -func (a *TelegramAdapter) NormalizeConfig(raw map[string]any) (map[string]any, error) { +func (a *Adapter) NormalizeConfig(raw map[string]any) (map[string]any, error) { return normalizeConfig(raw) } // NormalizeUserConfig validates and normalizes a Telegram user-binding configuration map. -func (a *TelegramAdapter) NormalizeUserConfig(raw map[string]any) (map[string]any, error) { +func (a *Adapter) NormalizeUserConfig(raw map[string]any) (map[string]any, error) { return normalizeUserConfig(raw) } // NormalizeTarget normalizes a Telegram delivery target string. -func (a *TelegramAdapter) NormalizeTarget(raw string) string { +func (a *Adapter) NormalizeTarget(raw string) string { return normalizeTarget(raw) } // ResolveTarget derives a delivery target from a Telegram user-binding configuration. -func (a *TelegramAdapter) ResolveTarget(userConfig map[string]any) (string, error) { +func (a *Adapter) ResolveTarget(userConfig map[string]any) (string, error) { return resolveTarget(userConfig) } // MatchBinding reports whether a Telegram user binding matches the given criteria. -func (a *TelegramAdapter) MatchBinding(config map[string]any, criteria channel.BindingCriteria) bool { +func (a *Adapter) MatchBinding(config map[string]any, criteria channel.BindingCriteria) bool { return matchBinding(config, criteria) } // BuildUserConfig constructs a Telegram user-binding config from an Identity. -func (a *TelegramAdapter) BuildUserConfig(identity channel.Identity) map[string]any { +func (a *Adapter) BuildUserConfig(identity channel.Identity) map[string]any { return buildUserConfig(identity) } // Connect starts long-polling for Telegram updates and forwards messages to the handler. -func (a *TelegramAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig, handler channel.InboundHandler) (channel.Connection, error) { +func (a *Adapter) Connect(ctx context.Context, cfg channel.Config, handler channel.InboundHandler) (channel.Connection, error) { if a.logger != nil { a.logger.Info("start", slog.String("config_id", cfg.ID)) } @@ -242,7 +242,7 @@ func (a *TelegramAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig slog.String("chat_id", msg.Conversation.ID), slog.String("user_id", attrs["user_id"]), slog.String("username", attrs["username"]), - slog.String("text", common.SummarizeText(text)), + slog.String("text", adapterutil.SummarizeText(text)), ) } go func() { @@ -266,6 +266,7 @@ func (a *TelegramAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig // "Conflict: terminated by other getUpdates request" when a new // connection starts with the same bot token. for range updates { + continue // drain channel } return nil } @@ -273,7 +274,7 @@ func (a *TelegramAdapter) Connect(ctx context.Context, cfg channel.ChannelConfig } // Send delivers an outbound message to Telegram, handling text, attachments, and replies. -func (a *TelegramAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, msg channel.OutboundMessage) error { +func (a *Adapter) Send(_ context.Context, cfg channel.Config, msg channel.OutboundMessage) error { telegramCfg, err := parseConfig(cfg.Credentials) if err != nil { if a.logger != nil { @@ -283,14 +284,14 @@ func (a *TelegramAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, m } to := strings.TrimSpace(msg.Target) if to == "" { - return fmt.Errorf("telegram target is required") + return errors.New("telegram target is required") } bot, err := a.getOrCreateBot(telegramCfg.BotToken, cfg.ID) if err != nil { return err } if msg.Message.IsEmpty() { - return fmt.Errorf("message is required") + return errors.New("message is required") } text := strings.TrimSpace(msg.Message.PlainText()) text, parseMode := formatTelegramOutput(text, msg.Message.Format) @@ -325,10 +326,10 @@ func (a *TelegramAdapter) Send(ctx context.Context, cfg channel.ChannelConfig, m // OpenStream opens a Telegram streaming session. // The adapter sends one message then edits it in place as deltas arrive (editMessageText), // avoiding one message per delta and rate limits. -func (a *TelegramAdapter) OpenStream(ctx context.Context, cfg channel.ChannelConfig, target string, opts channel.StreamOptions) (channel.OutboundStream, error) { +func (a *Adapter) OpenStream(ctx context.Context, cfg channel.Config, target string, opts channel.StreamOptions) (channel.OutboundStream, error) { target = strings.TrimSpace(target) if target == "" { - return nil, fmt.Errorf("telegram target is required") + return nil, errors.New("telegram target is required") } select { case <-ctx.Done(): @@ -413,13 +414,13 @@ func parseReplyToMessageID(reply *channel.ReplyRef) int { return value } -func sendTelegramText(bot *tgbotapi.BotAPI, target string, text string, replyTo int, parseMode string) error { +func sendTelegramText(bot *tgbotapi.BotAPI, target, text string, replyTo int, parseMode string) error { _, _, err := sendTelegramTextReturnMessage(bot, target, text, replyTo, parseMode) return err } // sendTelegramTextReturnMessage sends a text message and returns the chat ID and message ID for later editing. -func sendTelegramTextReturnMessage(bot *tgbotapi.BotAPI, target string, text string, replyTo int, parseMode string) (chatID int64, messageID int, err error) { +func sendTelegramTextReturnMessage(bot *tgbotapi.BotAPI, target, text string, replyTo int, parseMode string) (chatID int64, messageID int, err error) { text = truncateTelegramText(sanitizeTelegramText(text)) var sent tgbotapi.Message if strings.HasPrefix(target, "@") { @@ -435,7 +436,7 @@ func sendTelegramTextReturnMessage(bot *tgbotapi.BotAPI, target string, text str } else { chatID, err = strconv.ParseInt(target, 10, 64) if err != nil { - return 0, 0, fmt.Errorf("telegram target must be @username or chat_id") + return 0, 0, errors.New("telegram target must be @username or chat_id") } message := tgbotapi.NewMessage(chatID, text) message.ParseMode = parseMode @@ -458,7 +459,7 @@ var sendEditForTest func(bot *tgbotapi.BotAPI, edit tgbotapi.EditMessageTextConf // editTelegramMessageText sends an edit request. It handles "message is not modified" // silently but returns 429 and other errors to the caller for higher-level retry decisions. -func editTelegramMessageText(bot *tgbotapi.BotAPI, chatID int64, messageID int, text string, parseMode string) error { +func editTelegramMessageText(bot *tgbotapi.BotAPI, chatID int64, messageID int, text, parseMode string) error { text = truncateTelegramText(sanitizeTelegramText(text)) edit := tgbotapi.NewEditMessageText(chatID, messageID, text) edit.ParseMode = parseMode @@ -511,7 +512,7 @@ func sendTelegramAttachment(bot *tgbotapi.BotAPI, target string, att channel.Att keyRef := strings.TrimSpace(att.PlatformKey) sourcePlatform := strings.TrimSpace(att.SourcePlatform) if urlRef == "" && keyRef == "" { - return fmt.Errorf("attachment reference is required") + return errors.New("attachment reference is required") } if strings.TrimSpace(caption) == "" && strings.TrimSpace(att.Caption) != "" { caption = strings.TrimSpace(att.Caption) @@ -529,7 +530,7 @@ func sendTelegramAttachment(bot *tgbotapi.BotAPI, target string, att channel.Att } else { chatID, err := strconv.ParseInt(target, 10, 64) if err != nil { - return fmt.Errorf("telegram target must be @username or chat_id") + return errors.New("telegram target must be @username or chat_id") } photo = tgbotapi.NewPhoto(chatID, file) } @@ -552,7 +553,7 @@ func sendTelegramAttachment(bot *tgbotapi.BotAPI, target string, att channel.Att } else { chatID, err := strconv.ParseInt(target, 10, 64) if err != nil { - return fmt.Errorf("telegram target must be @username or chat_id") + return errors.New("telegram target must be @username or chat_id") } document = tgbotapi.NewDocument(chatID, file) } @@ -634,7 +635,7 @@ func buildTelegramAudio(target string, file tgbotapi.RequestFileData) (tgbotapi. } chatID, err := strconv.ParseInt(target, 10, 64) if err != nil { - return tgbotapi.AudioConfig{}, fmt.Errorf("telegram target must be @username or chat_id") + return tgbotapi.AudioConfig{}, errors.New("telegram target must be @username or chat_id") } return tgbotapi.NewAudio(chatID, file), nil } @@ -647,7 +648,7 @@ func buildTelegramVoice(target string, file tgbotapi.RequestFileData) (tgbotapi. } chatID, err := strconv.ParseInt(target, 10, 64) if err != nil { - return tgbotapi.VoiceConfig{}, fmt.Errorf("telegram target must be @username or chat_id") + return tgbotapi.VoiceConfig{}, errors.New("telegram target must be @username or chat_id") } return tgbotapi.NewVoice(chatID, file), nil } @@ -660,7 +661,7 @@ func buildTelegramVideo(target string, file tgbotapi.RequestFileData) (tgbotapi. } chatID, err := strconv.ParseInt(target, 10, 64) if err != nil { - return tgbotapi.VideoConfig{}, fmt.Errorf("telegram target must be @username or chat_id") + return tgbotapi.VideoConfig{}, errors.New("telegram target must be @username or chat_id") } return tgbotapi.NewVideo(chatID, file), nil } @@ -673,7 +674,7 @@ func buildTelegramAnimation(target string, file tgbotapi.RequestFileData) (tgbot } chatID, err := strconv.ParseInt(target, 10, 64) if err != nil { - return tgbotapi.AnimationConfig{}, fmt.Errorf("telegram target must be @username or chat_id") + return tgbotapi.AnimationConfig{}, errors.New("telegram target must be @username or chat_id") } return tgbotapi.NewAnimation(chatID, file), nil } @@ -714,7 +715,7 @@ func isTelegramBotMentioned(msg *tgbotapi.Message, botUsername string) bool { return false } -func (a *TelegramAdapter) collectTelegramAttachments(bot *tgbotapi.BotAPI, msg *tgbotapi.Message) []channel.Attachment { +func (a *Adapter) collectTelegramAttachments(bot *tgbotapi.BotAPI, msg *tgbotapi.Message) []channel.Attachment { if msg == nil { return nil } @@ -769,7 +770,7 @@ func (a *TelegramAdapter) collectTelegramAttachments(bot *tgbotapi.BotAPI, msg * return attachments } -func (a *TelegramAdapter) buildTelegramAttachment(bot *tgbotapi.BotAPI, attType channel.AttachmentType, fileID, name, mime string, size int64) channel.Attachment { +func (a *Adapter) buildTelegramAttachment(bot *tgbotapi.BotAPI, attType channel.AttachmentType, fileID, name, mime string, size int64) channel.Attachment { url := "" if bot != nil && strings.TrimSpace(fileID) != "" { value, err := bot.GetFileDirectURL(fileID) @@ -840,7 +841,7 @@ func truncateTelegramText(text string) string { } // ProcessingStarted sends a "typing" chat action to indicate processing. -func (a *TelegramAdapter) ProcessingStarted(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, info channel.ProcessingStatusInfo) (channel.ProcessingStatusHandle, error) { +func (a *Adapter) ProcessingStarted(_ context.Context, cfg channel.Config, _ channel.InboundMessage, info channel.ProcessingStatusInfo) (channel.ProcessingStatusHandle, error) { chatID := strings.TrimSpace(info.ReplyTarget) if chatID == "" { return channel.ProcessingStatusHandle{}, nil @@ -860,12 +861,12 @@ func (a *TelegramAdapter) ProcessingStarted(ctx context.Context, cfg channel.Cha } // ProcessingCompleted is a no-op for Telegram (typing indicator clears automatically). -func (a *TelegramAdapter) ProcessingCompleted(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, info channel.ProcessingStatusInfo, handle channel.ProcessingStatusHandle) error { +func (a *Adapter) ProcessingCompleted(_ context.Context, _ channel.Config, _ channel.InboundMessage, _ channel.ProcessingStatusInfo, _ channel.ProcessingStatusHandle) error { return nil } // ProcessingFailed is a no-op for Telegram (typing indicator clears automatically). -func (a *TelegramAdapter) ProcessingFailed(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, info channel.ProcessingStatusInfo, handle channel.ProcessingStatusHandle, cause error) error { +func (a *Adapter) ProcessingFailed(_ context.Context, _ channel.Config, _ channel.InboundMessage, _ channel.ProcessingStatusInfo, _ channel.ProcessingStatusHandle, _ error) error { return nil } @@ -898,7 +899,7 @@ func clearTelegramReaction(bot *tgbotapi.BotAPI, chatID, messageID string) error } // React adds an emoji reaction to a message (implements channel.Reactor). -func (a *TelegramAdapter) React(ctx context.Context, cfg channel.ChannelConfig, target string, messageID string, emoji string) error { +func (a *Adapter) React(_ context.Context, cfg channel.Config, target, messageID, emoji string) error { telegramCfg, err := parseConfig(cfg.Credentials) if err != nil { return err @@ -912,7 +913,7 @@ func (a *TelegramAdapter) React(ctx context.Context, cfg channel.ChannelConfig, // Unreact removes the bot's reaction from a message (implements channel.Reactor). // The emoji parameter is ignored; Telegram clears all bot reactions at once. -func (a *TelegramAdapter) Unreact(ctx context.Context, cfg channel.ChannelConfig, target string, messageID string, _ string) error { +func (a *Adapter) Unreact(_ context.Context, cfg channel.Config, target, messageID, _ string) error { telegramCfg, err := parseConfig(cfg.Credentials) if err != nil { return err diff --git a/internal/channel/adapters/telegram/telegram_test.go b/internal/channel/adapters/telegram/telegram_test.go index ae469dd5..50650ce3 100644 --- a/internal/channel/adapters/telegram/telegram_test.go +++ b/internal/channel/adapters/telegram/telegram_test.go @@ -2,6 +2,7 @@ package telegram import ( "context" + "errors" "fmt" "strings" "testing" @@ -9,6 +10,7 @@ import ( "unicode/utf8" tgbotapi "github.com/go-telegram-bot-api/telegram-bot-api/v5" + "github.com/memohai/memoh/internal/channel" ) @@ -73,7 +75,7 @@ func TestIsTelegramBotMentioned(t *testing.T) { func TestTelegramDescriptorIncludesStreaming(t *testing.T) { t.Parallel() - adapter := NewTelegramAdapter(nil) + adapter := NewAdapter(nil) caps := adapter.Descriptor().Capabilities if !caps.Streaming { t.Fatal("expected streaming capability") @@ -86,7 +88,7 @@ func TestTelegramDescriptorIncludesStreaming(t *testing.T) { func TestBuildTelegramAttachmentIncludesPlatformReference(t *testing.T) { t.Parallel() - adapter := NewTelegramAdapter(nil) + adapter := NewAdapter(nil) att := adapter.buildTelegramAttachment(nil, channel.AttachmentFile, "file_1", "doc.txt", "text/plain", 10) if att.PlatformKey != "file_1" { t.Fatalf("unexpected platform key: %s", att.PlatformKey) @@ -169,21 +171,21 @@ func TestPickTelegramPhoto(t *testing.T) { } } -func TestTelegramAdapter_Type(t *testing.T) { +func TestAdapter_Type(t *testing.T) { t.Parallel() - adapter := NewTelegramAdapter(nil) + adapter := NewAdapter(nil) if adapter.Type() != Type { t.Fatalf("Type should return telegram: %s", adapter.Type()) } } -func TestTelegramAdapter_OpenStreamEmptyTarget(t *testing.T) { +func TestAdapter_OpenStreamEmptyTarget(t *testing.T) { t.Parallel() - adapter := NewTelegramAdapter(nil) + adapter := NewAdapter(nil) ctx := context.Background() - cfg := channel.ChannelConfig{} + cfg := channel.Config{} _, err := adapter.OpenStream(ctx, cfg, "", channel.StreamOptions{}) if err == nil { t.Fatal("empty target should return error") @@ -278,10 +280,10 @@ func TestBuildTelegramAnimation(t *testing.T) { } } -func TestTelegramAdapter_NormalizeAndResolve(t *testing.T) { +func TestAdapter_NormalizeAndResolve(t *testing.T) { t.Parallel() - adapter := NewTelegramAdapter(nil) + adapter := NewAdapter(nil) norm, err := adapter.NormalizeConfig(map[string]any{"botToken": "t1"}) if err != nil { t.Fatalf("NormalizeConfig: %v", err) @@ -320,7 +322,7 @@ func TestIsTelegramMessageNotModified(t *testing.T) { want bool }{ {"nil", nil, false}, - {"plain error", fmt.Errorf("network error"), false}, + {"plain error", errors.New("network error"), false}, {"other api error", tgbotapi.Error{Code: 400, Message: "Bad Request: chat not found"}, false}, {"message is not modified", tgbotapi.Error{Code: 400, Message: productionMessageNotModified}, true}, {"production exact", tgbotapi.Error{Code: 400, Message: productionMessageNotModified}, true}, @@ -472,19 +474,19 @@ func TestEditTelegramMessageText_429ReturnsError(t *testing.T) { } } -func TestTelegramAdapter_ImplementsProcessingStatusNotifier(t *testing.T) { +func TestAdapter_ImplementsProcessingStatusNotifier(t *testing.T) { t.Parallel() - adapter := NewTelegramAdapter(nil) + adapter := NewAdapter(nil) var _ channel.ProcessingStatusNotifier = adapter } func TestProcessingStarted_EmptyParams(t *testing.T) { t.Parallel() - adapter := NewTelegramAdapter(nil) + adapter := NewAdapter(nil) ctx := context.Background() - cfg := channel.ChannelConfig{} + cfg := channel.Config{} msg := channel.InboundMessage{} handle, err := adapter.ProcessingStarted(ctx, cfg, msg, channel.ProcessingStatusInfo{}) @@ -499,10 +501,10 @@ func TestProcessingStarted_EmptyParams(t *testing.T) { func TestProcessingCompleted_EmptyHandle(t *testing.T) { t.Parallel() - adapter := NewTelegramAdapter(nil) + adapter := NewAdapter(nil) ctx := context.Background() - err := adapter.ProcessingCompleted(ctx, channel.ChannelConfig{}, channel.InboundMessage{}, channel.ProcessingStatusInfo{}, channel.ProcessingStatusHandle{}) + err := adapter.ProcessingCompleted(ctx, channel.Config{}, channel.InboundMessage{}, channel.ProcessingStatusInfo{}, channel.ProcessingStatusHandle{}) if err != nil { t.Fatalf("empty handle should be no-op: %v", err) } @@ -511,10 +513,10 @@ func TestProcessingCompleted_EmptyHandle(t *testing.T) { func TestProcessingFailed_DelegatesToCompleted(t *testing.T) { t.Parallel() - adapter := NewTelegramAdapter(nil) + adapter := NewAdapter(nil) ctx := context.Background() - err := adapter.ProcessingFailed(ctx, channel.ChannelConfig{}, channel.InboundMessage{}, channel.ProcessingStatusInfo{}, channel.ProcessingStatusHandle{}, fmt.Errorf("test")) + err := adapter.ProcessingFailed(ctx, channel.Config{}, channel.InboundMessage{}, channel.ProcessingStatusInfo{}, channel.ProcessingStatusHandle{}, errors.New("test")) if err != nil { t.Fatalf("empty handle should be no-op: %v", err) } diff --git a/internal/channel/capabilities.go b/internal/channel/capabilities.go index 2de4c90c..d9ac2eb9 100644 --- a/internal/channel/capabilities.go +++ b/internal/channel/capabilities.go @@ -1,22 +1,22 @@ package channel -// ChannelCapabilities describes the feature matrix of a channel type. +// Capabilities describes the feature matrix of a channel type. // It is used by the outbound layer to validate message content before delivery. -type ChannelCapabilities struct { - Text bool `json:"text"` - Markdown bool `json:"markdown"` - RichText bool `json:"rich_text"` - Attachments bool `json:"attachments"` - Media bool `json:"media"` - Reactions bool `json:"reactions"` - Buttons bool `json:"buttons"` - Reply bool `json:"reply"` - Threads bool `json:"threads"` - Streaming bool `json:"streaming"` - Polls bool `json:"polls"` - Edit bool `json:"edit"` - Unsend bool `json:"unsend"` - NativeCommands bool `json:"native_commands"` - BlockStreaming bool `json:"block_streaming"` - ChatTypes []string `json:"chat_types,omitempty"` +type Capabilities struct { + Text bool `json:"text"` + Markdown bool `json:"markdown"` + RichText bool `json:"rich_text"` + Attachments bool `json:"attachments"` + Media bool `json:"media"` + Reactions bool `json:"reactions"` + Buttons bool `json:"buttons"` + Reply bool `json:"reply"` + Threads bool `json:"threads"` + Streaming bool `json:"streaming"` + Polls bool `json:"polls"` + Edit bool `json:"edit"` + Unsend bool `json:"unsend"` + NativeCommands bool `json:"native_commands"` + BlockStreaming bool `json:"block_streaming"` + ChatTypes []string `json:"chat_types,omitempty"` } diff --git a/internal/channel/config_test.go b/internal/channel/config_test.go index b71e232e..ca9d0c98 100644 --- a/internal/channel/config_test.go +++ b/internal/channel/config_test.go @@ -1,23 +1,23 @@ package channel_test import ( - "fmt" + "errors" "testing" "github.com/memohai/memoh/internal/channel" ) -const testChannelType = channel.ChannelType("test-config") +const testChannelType = channel.Type("test-config") // testConfigAdapter implements Adapter, ConfigNormalizer, TargetResolver, BindingMatcher for tests. type testConfigAdapter struct{} -func (a *testConfigAdapter) Type() channel.ChannelType { return testChannelType } +func (a *testConfigAdapter) Type() channel.Type { return testChannelType } func (a *testConfigAdapter) Descriptor() channel.Descriptor { return channel.Descriptor{ Type: testChannelType, DisplayName: "Test", - Capabilities: channel.ChannelCapabilities{ + Capabilities: channel.Capabilities{ Text: true, }, ConfigSchema: channel.ConfigSchema{ @@ -38,7 +38,7 @@ func (a *testConfigAdapter) Descriptor() channel.Descriptor { func (a *testConfigAdapter) NormalizeConfig(raw map[string]any) (map[string]any, error) { value := channel.ReadString(raw, "value") if value == "" { - return nil, fmt.Errorf("value is required") + return nil, errors.New("value is required") } return map[string]any{"value": value}, nil } @@ -46,7 +46,7 @@ func (a *testConfigAdapter) NormalizeConfig(raw map[string]any) (map[string]any, func (a *testConfigAdapter) NormalizeUserConfig(raw map[string]any) (map[string]any, error) { value := channel.ReadString(raw, "user") if value == "" { - return nil, fmt.Errorf("user is required") + return nil, errors.New("user is required") } return map[string]any{"user": value}, nil } @@ -56,7 +56,7 @@ func (a *testConfigAdapter) NormalizeTarget(raw string) string { return raw } func (a *testConfigAdapter) ResolveTarget(raw map[string]any) (string, error) { value := channel.ReadString(raw, "target") if value == "" { - return "", fmt.Errorf("target is required") + return "", errors.New("target is required") } return "resolved:" + value, nil } @@ -66,7 +66,7 @@ func (a *testConfigAdapter) MatchBinding(raw map[string]any, criteria channel.Bi return value != "" && value == criteria.SubjectID } -func (a *testConfigAdapter) BuildUserConfig(identity channel.Identity) map[string]any { +func (a *testConfigAdapter) BuildUserConfig(_ channel.Identity) map[string]any { return map[string]any{} } diff --git a/internal/channel/connection.go b/internal/channel/connection.go index a64a476f..731c8b11 100644 --- a/internal/channel/connection.go +++ b/internal/channel/connection.go @@ -3,13 +3,12 @@ package channel import ( "context" "errors" - "fmt" "log/slog" "strings" ) type connectionEntry struct { - config ChannelConfig + config Config connection Connection } @@ -24,7 +23,7 @@ func (m *Manager) refresh(ctx context.Context) { if m.service == nil { return } - configs := make([]ChannelConfig, 0) + configs := make([]Config, 0) for _, channelType := range m.registry.Types() { items, err := m.service.ListConfigsByType(ctx, channelType) if err != nil { @@ -38,8 +37,8 @@ func (m *Manager) refresh(ctx context.Context) { m.reconcile(ctx, configs) } -func (m *Manager) reconcile(ctx context.Context, configs []ChannelConfig) { - active := map[string]ChannelConfig{} +func (m *Manager) reconcile(ctx context.Context, configs []Config) { + active := map[string]Config{} for _, cfg := range configs { if cfg.ID == "" { continue @@ -51,7 +50,7 @@ func (m *Manager) reconcile(ctx context.Context, configs []ChannelConfig) { active[cfg.ID] = cfg if err := m.ensureConnection(ctx, cfg); err != nil { if m.logger != nil { - m.logger.Error("adapter start failed", slog.String("channel", cfg.ChannelType.String()), slog.String("config_id", cfg.ID), slog.Any("error", err)) + m.logger.Error("adapter start failed", slog.String("channel", cfg.Type.String()), slog.String("config_id", cfg.ID), slog.Any("error", err)) } } } @@ -64,7 +63,7 @@ func (m *Manager) reconcile(ctx context.Context, configs []ChannelConfig) { } if entry != nil && entry.connection != nil { if m.logger != nil { - m.logger.Info("adapter stop", slog.String("channel", entry.config.ChannelType.String()), slog.String("config_id", id)) + m.logger.Info("adapter stop", slog.String("channel", entry.config.Type.String()), slog.String("config_id", id)) } if err := entry.connection.Stop(ctx); err != nil && !errors.Is(err, ErrStopNotSupported) && m.logger != nil { m.logger.Warn("adapter stop failed", slog.String("config_id", id), slog.Any("error", err)) @@ -74,8 +73,8 @@ func (m *Manager) reconcile(ctx context.Context, configs []ChannelConfig) { } } -func (m *Manager) ensureConnection(ctx context.Context, cfg ChannelConfig) error { - _, ok := m.registry.GetReceiver(cfg.ChannelType) +func (m *Manager) ensureConnection(ctx context.Context, cfg Config) error { + _, ok := m.registry.GetReceiver(cfg.Type) if !ok { return nil } @@ -100,12 +99,12 @@ func (m *Manager) ensureConnection(ctx context.Context, cfg ChannelConfig) error if oldConn != nil { if m.logger != nil { - m.logger.Info("adapter restart", slog.String("channel", cfg.ChannelType.String()), slog.String("config_id", cfg.ID)) + m.logger.Info("adapter restart", slog.String("channel", cfg.Type.String()), slog.String("config_id", cfg.ID)) } if err := oldConn.Stop(ctx); err != nil { if errors.Is(err, ErrStopNotSupported) { if m.logger != nil { - m.logger.Warn("adapter restart skipped", slog.String("channel", cfg.ChannelType.String()), slog.String("config_id", cfg.ID)) + m.logger.Warn("adapter restart skipped", slog.String("channel", cfg.Type.String()), slog.String("config_id", cfg.ID)) } // Re-insert the entry since we can't restart it. m.mu.Lock() @@ -119,7 +118,7 @@ func (m *Manager) ensureConnection(ctx context.Context, cfg ChannelConfig) error } } - receiver, ok := m.registry.GetReceiver(cfg.ChannelType) + receiver, ok := m.registry.GetReceiver(cfg.Type) if !ok { return nil } @@ -134,7 +133,7 @@ func (m *Manager) ensureConnection(ctx context.Context, cfg ChannelConfig) error m.mu.Unlock() if m.logger != nil { - m.logger.Info("adapter start", slog.String("channel", cfg.ChannelType.String()), slog.String("config_id", cfg.ID)) + m.logger.Info("adapter start", slog.String("channel", cfg.Type.String()), slog.String("config_id", cfg.ID)) } handler := m.handleInbound for i := len(m.middlewares) - 1; i >= 0; i-- { @@ -167,7 +166,7 @@ func (m *Manager) stopAll(ctx context.Context) { for id, entry := range m.connections { if entry != nil && entry.connection != nil { if m.logger != nil { - m.logger.Info("adapter stop", slog.String("channel", entry.config.ChannelType.String()), slog.String("config_id", id)) + m.logger.Info("adapter stop", slog.String("channel", entry.config.Type.String()), slog.String("config_id", id)) } if err := entry.connection.Stop(ctx); err != nil && !errors.Is(err, ErrStopNotSupported) && m.logger != nil { m.logger.Warn("adapter stop failed", slog.String("config_id", id), slog.Any("error", err)) @@ -181,7 +180,7 @@ func (m *Manager) stopAll(ctx context.Context) { func (m *Manager) Stop(ctx context.Context, configID string) error { configID = strings.TrimSpace(configID) if configID == "" { - return fmt.Errorf("config id is required") + return errors.New("config id is required") } m.mu.Lock() entry := m.connections[configID] @@ -196,7 +195,7 @@ func (m *Manager) Stop(ctx context.Context, configID string) error { func (m *Manager) StopByBot(ctx context.Context, botID string) error { botID = strings.TrimSpace(botID) if botID == "" { - return fmt.Errorf("bot id is required") + return errors.New("bot id is required") } m.mu.Lock() defer m.mu.Unlock() diff --git a/internal/channel/directory.go b/internal/channel/directory.go index 32457912..81f4cd65 100644 --- a/internal/channel/directory.go +++ b/internal/channel/directory.go @@ -5,6 +5,7 @@ import "context" // DirectoryEntryKind classifies a directory entry as a user or a group. type DirectoryEntryKind string +// DirectoryEntryKind values for listing and resolving directory entries. const ( DirectoryEntryUser DirectoryEntryKind = "user" DirectoryEntryGroup DirectoryEntryKind = "group" @@ -27,10 +28,10 @@ type DirectoryQuery struct { Kind DirectoryEntryKind `json:"kind,omitempty"` } -// ChannelDirectoryAdapter provides contact and group lookup for a channel platform. -type ChannelDirectoryAdapter interface { - ListPeers(ctx context.Context, cfg ChannelConfig, query DirectoryQuery) ([]DirectoryEntry, error) - ListGroups(ctx context.Context, cfg ChannelConfig, query DirectoryQuery) ([]DirectoryEntry, error) - ListGroupMembers(ctx context.Context, cfg ChannelConfig, groupID string, query DirectoryQuery) ([]DirectoryEntry, error) - ResolveEntry(ctx context.Context, cfg ChannelConfig, input string, kind DirectoryEntryKind) (DirectoryEntry, error) +// DirectoryAdapter provides contact and group lookup for a channel platform. +type DirectoryAdapter interface { + ListPeers(ctx context.Context, cfg Config, query DirectoryQuery) ([]DirectoryEntry, error) + ListGroups(ctx context.Context, cfg Config, query DirectoryQuery) ([]DirectoryEntry, error) + ListGroupMembers(ctx context.Context, cfg Config, groupID string, query DirectoryQuery) ([]DirectoryEntry, error) + ResolveEntry(ctx context.Context, cfg Config, input string, kind DirectoryEntryKind) (DirectoryEntry, error) } diff --git a/internal/channel/helpers_test.go b/internal/channel/helpers_test.go index 50434c57..443d604b 100644 --- a/internal/channel/helpers_test.go +++ b/internal/channel/helpers_test.go @@ -85,10 +85,9 @@ func TestNormalizeChannelConfigStatus(t *testing.T) { } for _, tt := range tests { - tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() - got, err := normalizeChannelConfigStatus(tt.input) + got, err := normalizeConfigStatus(tt.input) if tt.wantErr { if err == nil { t.Fatalf("expected error, got nil") diff --git a/internal/channel/identities/service.go b/internal/channel/identities/service.go index e48ab1cb..13393d48 100644 --- a/internal/channel/identities/service.go +++ b/internal/channel/identities/service.go @@ -1,3 +1,4 @@ +// Package identities provides channel identity resolution and linking. package identities import ( @@ -21,9 +22,8 @@ type Service struct { logger *slog.Logger } -var ( - ErrChannelIdentityNotFound = errors.New("channel identity not found") -) +// ErrChannelIdentityNotFound is returned when no channel identity exists for the given id or criteria. +var ErrChannelIdentityNotFound = errors.New("channel identity not found") // NewService creates a new channel identity service. func NewService(log *slog.Logger, queries *sqlc.Queries) *Service { @@ -39,12 +39,12 @@ func NewService(log *slog.Logger, queries *sqlc.Queries) *Service { // Create creates a new channel identity for the given channel subject. func (s *Service) Create(ctx context.Context, channel, channelSubjectID, displayName string) (ChannelIdentity, error) { if s.queries == nil { - return ChannelIdentity{}, fmt.Errorf("channel identity queries not configured") + return ChannelIdentity{}, errors.New("channel identity queries not configured") } channel = normalizeChannel(channel) channelSubjectID = strings.TrimSpace(channelSubjectID) if channel == "" || channelSubjectID == "" { - return ChannelIdentity{}, fmt.Errorf("channel and channel_subject_id are required") + return ChannelIdentity{}, errors.New("channel and channel_subject_id are required") } row, err := s.queries.CreateChannelIdentity(ctx, sqlc.CreateChannelIdentityParams{ UserID: pgtype.UUID{}, @@ -63,7 +63,7 @@ func (s *Service) Create(ctx context.Context, channel, channelSubjectID, display // GetByID returns a channel identity by its ID. func (s *Service) GetByID(ctx context.Context, channelIdentityID string) (ChannelIdentity, error) { if s.queries == nil { - return ChannelIdentity{}, fmt.Errorf("channel identity queries not configured") + return ChannelIdentity{}, errors.New("channel identity queries not configured") } pgID, err := db.ParseUUID(channelIdentityID) if err != nil { @@ -82,7 +82,7 @@ func (s *Service) GetByID(ctx context.Context, channelIdentityID string) (Channe // Canonicalize validates and returns the same channel identity ID. func (s *Service) Canonicalize(ctx context.Context, channelIdentityID string) (string, error) { if s.queries == nil { - return "", fmt.Errorf("channel identity queries not configured") + return "", errors.New("channel identity queries not configured") } pgID, err := db.ParseUUID(channelIdentityID) if err != nil { @@ -102,12 +102,12 @@ func (s *Service) Canonicalize(ctx context.Context, channelIdentityID string) (s // Optional meta may contain avatar_url which is stored as a dedicated column. func (s *Service) ResolveByChannelIdentity(ctx context.Context, channel, channelSubjectID, displayName string, meta map[string]any) (ChannelIdentity, error) { if s.queries == nil { - return ChannelIdentity{}, fmt.Errorf("channel identity queries not configured") + return ChannelIdentity{}, errors.New("channel identity queries not configured") } channel = normalizeChannel(channel) channelSubjectID = strings.TrimSpace(channelSubjectID) if channel == "" || channelSubjectID == "" { - return ChannelIdentity{}, fmt.Errorf("channel and channel_subject_id are required") + return ChannelIdentity{}, errors.New("channel and channel_subject_id are required") } avatarURL := "" @@ -134,7 +134,7 @@ func (s *Service) ResolveByChannelIdentity(ctx context.Context, channel, channel // UpsertChannelIdentity creates or updates a channel identity mapping. func (s *Service) UpsertChannelIdentity(ctx context.Context, channel, channelSubjectID, displayName string, metadata map[string]any) (ChannelIdentity, error) { if s.queries == nil { - return ChannelIdentity{}, fmt.Errorf("channel identity queries not configured") + return ChannelIdentity{}, errors.New("channel identity queries not configured") } channel = normalizeChannel(channel) channelSubjectID = strings.TrimSpace(channelSubjectID) @@ -166,7 +166,7 @@ func (s *Service) UpsertChannelIdentity(ctx context.Context, channel, channelSub // ListCanonicalChannelIdentities lists channel identities under the same linked user. func (s *Service) ListCanonicalChannelIdentities(ctx context.Context, channelIdentityID string) ([]ChannelIdentity, error) { if s.queries == nil { - return nil, fmt.Errorf("channel identity queries not configured") + return nil, errors.New("channel identity queries not configured") } pgChannelIdentityID, err := db.ParseUUID(channelIdentityID) if err != nil { @@ -196,7 +196,7 @@ func (s *Service) ListCanonicalChannelIdentities(ctx context.Context, channelIde // ListUserChannelIdentities lists all channel identities linked to a user. func (s *Service) ListUserChannelIdentities(ctx context.Context, userID string) ([]ChannelIdentity, error) { if s.queries == nil { - return nil, fmt.Errorf("channel identity queries not configured") + return nil, errors.New("channel identity queries not configured") } pgUserID, err := db.ParseUUID(userID) if err != nil { @@ -216,7 +216,7 @@ func (s *Service) ListUserChannelIdentities(ctx context.Context, userID string) // GetLinkedUserID returns the linked user ID for a channel identity. func (s *Service) GetLinkedUserID(ctx context.Context, channelIdentityID string) (string, error) { if s.queries == nil { - return "", fmt.Errorf("channel identity queries not configured") + return "", errors.New("channel identity queries not configured") } pgChannelIdentityID, err := db.ParseUUID(channelIdentityID) if err != nil { @@ -238,7 +238,7 @@ func (s *Service) GetLinkedUserID(ctx context.Context, channelIdentityID string) // LinkChannelIdentityToUser binds a channel identity to a user. func (s *Service) LinkChannelIdentityToUser(ctx context.Context, channelIdentityID, userID string) error { if s.queries == nil { - return fmt.Errorf("channel identity queries not configured") + return errors.New("channel identity queries not configured") } pgChannelIdentityID, err := db.ParseUUID(channelIdentityID) if err != nil { diff --git a/internal/channel/inbound.go b/internal/channel/inbound.go index 630d7aef..dbf30355 100644 --- a/internal/channel/inbound.go +++ b/internal/channel/inbound.go @@ -2,27 +2,27 @@ package channel import ( "context" - "fmt" + "errors" "log/slog" ) type inboundTask struct { ctx context.Context - cfg ChannelConfig + cfg Config msg InboundMessage } // HandleInbound enqueues an inbound message for asynchronous processing by the worker pool. -func (m *Manager) HandleInbound(ctx context.Context, cfg ChannelConfig, msg InboundMessage) error { +func (m *Manager) HandleInbound(ctx context.Context, cfg Config, msg InboundMessage) error { if m.processor == nil { - return fmt.Errorf("inbound processor not configured") + return errors.New("inbound processor not configured") } if ctx == nil { ctx = context.Background() } m.startInboundWorkers(ctx) if m.inboundCtx != nil && m.inboundCtx.Err() != nil { - return fmt.Errorf("inbound dispatcher stopped") + return errors.New("inbound dispatcher stopped") } task := inboundTask{ ctx: context.WithoutCancel(ctx), @@ -33,13 +33,13 @@ func (m *Manager) HandleInbound(ctx context.Context, cfg ChannelConfig, msg Inbo case m.inboundQueue <- task: return nil default: - return fmt.Errorf("inbound queue full") + return errors.New("inbound queue full") } } -func (m *Manager) handleInbound(ctx context.Context, cfg ChannelConfig, msg InboundMessage) error { +func (m *Manager) handleInbound(ctx context.Context, cfg Config, msg InboundMessage) error { if m.processor == nil { - return fmt.Errorf("inbound processor not configured") + return errors.New("inbound processor not configured") } sender := m.newReplySender(cfg, msg.Channel) if err := m.processor.HandleInbound(ctx, cfg, msg, sender); err != nil { diff --git a/internal/channel/inbound/channel.go b/internal/channel/inbound/channel.go index f4b9e11f..425da01d 100644 --- a/internal/channel/inbound/channel.go +++ b/internal/channel/inbound/channel.go @@ -1,8 +1,10 @@ +// Package inbound handles incoming channel events and message routing. package inbound import ( "context" "encoding/json" + "errors" "fmt" "log/slog" "regexp" @@ -24,9 +26,7 @@ const ( processingStatusTimeout = 60 * time.Second ) -var ( - whitespacePattern = regexp.MustCompile(`\s+`) -) +var whitespacePattern = regexp.MustCompile(`\s+`) // RouteResolver resolves and manages channel routes. type RouteResolver interface { @@ -88,12 +88,12 @@ func (p *ChannelInboundProcessor) IdentityMiddleware() channel.Middleware { } // HandleInbound processes an inbound channel message through identity resolution and chat gateway. -func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, sender channel.StreamReplySender) error { +func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel.Config, msg channel.InboundMessage, sender channel.StreamReplySender) error { if p.runner == nil { - return fmt.Errorf("channel inbound processor not configured") + return errors.New("channel inbound processor not configured") } if sender == nil { - return fmt.Errorf("reply sender not configured") + return errors.New("reply sender not configured") } text := buildInboundQuery(msg.Message) if strings.TrimSpace(text) == "" { @@ -126,7 +126,7 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel // Resolve or create the route via channel_routes. if p.routeResolver == nil { - return fmt.Errorf("route resolver not configured") + return errors.New("route resolver not configured") } resolved, err := p.routeResolver.ResolveConversation(ctx, route.ResolveInput{ BotID: identity.BotID, @@ -225,7 +225,7 @@ func (p *ChannelInboundProcessor) HandleInbound(ctx context.Context, cfg channel } target := strings.TrimSpace(msg.ReplyTarget) if target == "" { - err := fmt.Errorf("reply target missing") + err := errors.New("reply target missing") if statusNotifier != nil { if notifyErr := p.notifyProcessingFailed(ctx, statusNotifier, cfg, msg, statusInfo, statusHandle, err); notifyErr != nil { p.logProcessingStatusError("processing_failed", msg, identity, notifyErr) @@ -507,7 +507,7 @@ func metadataBool(metadata map[string]any, key string) bool { } } -func (p *ChannelInboundProcessor) persistInboundUser(ctx context.Context, routeID string, identity InboundIdentity, msg channel.InboundMessage, query string, triggerMode string) bool { +func (p *ChannelInboundProcessor) persistInboundUser(ctx context.Context, routeID string, identity Identity, msg channel.InboundMessage, query, triggerMode string) bool { if p.message == nil { return false } @@ -547,7 +547,7 @@ func (p *ChannelInboundProcessor) persistInboundUser(ctx context.Context, routeI return true } -func buildChannelMessage(output conversation.AssistantOutput, capabilities channel.ChannelCapabilities) channel.Message { +func buildChannelMessage(output conversation.AssistantOutput, capabilities channel.Capabilities) channel.Message { msg := channel.Message{} if strings.TrimSpace(output.Content) != "" { msg.Text = strings.TrimSpace(output.Content) @@ -789,7 +789,7 @@ type sendMessageToolArgs struct { Message *channel.Message `json:"message"` } -func collectMessageToolContext(registry *channel.Registry, messages []conversation.ModelMessage, channelType channel.ChannelType, replyTarget string) ([]string, bool) { +func collectMessageToolContext(registry *channel.Registry, messages []conversation.ModelMessage, channelType channel.Type, replyTarget string) ([]string, bool) { if len(messages) == 0 { return nil, false } @@ -842,7 +842,7 @@ func extractSendMessageText(args sendMessageToolArgs) string { return strings.TrimSpace(args.Message.PlainText()) } -func shouldSuppressForToolCall(registry *channel.Registry, args sendMessageToolArgs, channelType channel.ChannelType, replyTarget string) bool { +func shouldSuppressForToolCall(registry *channel.Registry, args sendMessageToolArgs, channelType channel.Type, replyTarget string) bool { platform := strings.TrimSpace(args.Platform) if platform == "" { platform = string(channelType) @@ -865,7 +865,7 @@ func shouldSuppressForToolCall(registry *channel.Registry, args sendMessageToolA return normalizedTarget == normalizedReply } -func normalizeReplyTarget(registry *channel.Registry, channelType channel.ChannelType, target string) string { +func normalizeReplyTarget(registry *channel.Registry, channelType channel.Type, target string) string { if registry == nil { return strings.TrimSpace(target) } @@ -895,7 +895,7 @@ func isSilentReplyText(text string) bool { return false } -func hasTokenPrefix(value []rune, token []rune) bool { +func hasTokenPrefix(value, token []rune) bool { if len(value) < len(token) { return false } @@ -910,7 +910,7 @@ func hasTokenPrefix(value []rune, token []rune) bool { return !isWordChar(value[len(token)]) } -func hasTokenSuffix(value []rune, token []rune) bool { +func hasTokenSuffix(value, token []rune) bool { if len(value) < len(token) { return false } @@ -959,14 +959,14 @@ func isMessagingToolDuplicate(text string, sentTexts []string) bool { } // requireIdentity resolves identity for the current message. Always resolves from msg so each sender is identified correctly (no reuse of context state across messages). -func (p *ChannelInboundProcessor) requireIdentity(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage) (IdentityState, error) { +func (p *ChannelInboundProcessor) requireIdentity(ctx context.Context, cfg channel.Config, msg channel.InboundMessage) (IdentityState, error) { if p.identity == nil { - return IdentityState{}, fmt.Errorf("identity resolver not configured") + return IdentityState{}, errors.New("identity resolver not configured") } return p.identity.Resolve(ctx, cfg, msg) } -func (p *ChannelInboundProcessor) resolveProcessingStatusNotifier(channelType channel.ChannelType) channel.ProcessingStatusNotifier { +func (p *ChannelInboundProcessor) resolveProcessingStatusNotifier(channelType channel.Type) channel.ProcessingStatusNotifier { if p == nil || p.registry == nil { return nil } @@ -980,7 +980,7 @@ func (p *ChannelInboundProcessor) resolveProcessingStatusNotifier(channelType ch func (p *ChannelInboundProcessor) notifyProcessingStarted( ctx context.Context, notifier channel.ProcessingStatusNotifier, - cfg channel.ChannelConfig, + cfg channel.Config, msg channel.InboundMessage, info channel.ProcessingStatusInfo, ) (channel.ProcessingStatusHandle, error) { @@ -995,7 +995,7 @@ func (p *ChannelInboundProcessor) notifyProcessingStarted( func (p *ChannelInboundProcessor) notifyProcessingCompleted( ctx context.Context, notifier channel.ProcessingStatusNotifier, - cfg channel.ChannelConfig, + cfg channel.Config, msg channel.InboundMessage, info channel.ProcessingStatusInfo, handle channel.ProcessingStatusHandle, @@ -1011,7 +1011,7 @@ func (p *ChannelInboundProcessor) notifyProcessingCompleted( func (p *ChannelInboundProcessor) notifyProcessingFailed( ctx context.Context, notifier channel.ProcessingStatusNotifier, - cfg channel.ChannelConfig, + cfg channel.Config, msg channel.InboundMessage, info channel.ProcessingStatusInfo, handle channel.ProcessingStatusHandle, @@ -1028,7 +1028,7 @@ func (p *ChannelInboundProcessor) notifyProcessingFailed( func (p *ChannelInboundProcessor) logProcessingStatusError( stage string, msg channel.InboundMessage, - identity InboundIdentity, + identity Identity, err error, ) { if p == nil || p.logger == nil || err == nil { diff --git a/internal/channel/inbound/channel_test.go b/internal/channel/inbound/channel_test.go index 8af327ea..23cfa4ed 100644 --- a/internal/channel/inbound/channel_test.go +++ b/internal/channel/inbound/channel_test.go @@ -23,7 +23,7 @@ type fakeChatGateway struct { onChat func(conversation.ChatRequest) } -func (f *fakeChatGateway) Chat(ctx context.Context, req conversation.ChatRequest) (conversation.ChatResponse, error) { +func (f *fakeChatGateway) Chat(_ context.Context, req conversation.ChatRequest) (conversation.ChatResponse, error) { f.gotReq = req if f.onChat != nil { f.onChat(req) @@ -31,7 +31,7 @@ func (f *fakeChatGateway) Chat(ctx context.Context, req conversation.ChatRequest return f.resp, f.err } -func (f *fakeChatGateway) StreamChat(ctx context.Context, req conversation.ChatRequest) (<-chan conversation.StreamChunk, <-chan error) { +func (f *fakeChatGateway) StreamChat(_ context.Context, req conversation.ChatRequest) (<-chan conversation.StreamChunk, <-chan error) { f.gotReq = req if f.onChat != nil { f.onChat(req) @@ -57,7 +57,7 @@ func (f *fakeChatGateway) StreamChat(ctx context.Context, req conversation.ChatR return chunks, errs } -func (f *fakeChatGateway) TriggerSchedule(ctx context.Context, botID string, payload schedule.TriggerPayload, token string) error { +func (f *fakeChatGateway) TriggerSchedule(_ context.Context, _ string, _ schedule.TriggerPayload, _ string) error { return nil } @@ -66,12 +66,12 @@ type fakeReplySender struct { events []channel.StreamEvent } -func (s *fakeReplySender) Send(ctx context.Context, msg channel.OutboundMessage) error { +func (s *fakeReplySender) Send(_ context.Context, msg channel.OutboundMessage) error { s.sent = append(s.sent, msg) return nil } -func (s *fakeReplySender) OpenStream(ctx context.Context, target string, opts channel.StreamOptions) (channel.OutboundStream, error) { +func (s *fakeReplySender) OpenStream(_ context.Context, target string, _ channel.StreamOptions) (channel.OutboundStream, error) { return &fakeOutboundStream{ sender: s, target: strings.TrimSpace(target), @@ -83,7 +83,7 @@ type fakeOutboundStream struct { target string } -func (s *fakeOutboundStream) Push(ctx context.Context, event channel.StreamEvent) error { +func (s *fakeOutboundStream) Push(_ context.Context, event channel.StreamEvent) error { if s == nil || s.sender == nil { return nil } @@ -97,7 +97,7 @@ func (s *fakeOutboundStream) Push(ctx context.Context, event channel.StreamEvent return nil } -func (s *fakeOutboundStream) Close(ctx context.Context) error { +func (s *fakeOutboundStream) Close(_ context.Context) error { return nil } @@ -113,20 +113,20 @@ type fakeProcessingStatusNotifier struct { failedCause error } -func (n *fakeProcessingStatusNotifier) ProcessingStarted(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, info channel.ProcessingStatusInfo) (channel.ProcessingStatusHandle, error) { +func (n *fakeProcessingStatusNotifier) ProcessingStarted(_ context.Context, _ channel.Config, _ channel.InboundMessage, info channel.ProcessingStatusInfo) (channel.ProcessingStatusHandle, error) { n.events = append(n.events, "started") n.info = append(n.info, info) return n.startedHandle, n.startedErr } -func (n *fakeProcessingStatusNotifier) ProcessingCompleted(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, info channel.ProcessingStatusInfo, handle channel.ProcessingStatusHandle) error { +func (n *fakeProcessingStatusNotifier) ProcessingCompleted(_ context.Context, _ channel.Config, _ channel.InboundMessage, info channel.ProcessingStatusInfo, handle channel.ProcessingStatusHandle) error { n.events = append(n.events, "completed") n.info = append(n.info, info) n.completedSeen = handle return n.completedErr } -func (n *fakeProcessingStatusNotifier) ProcessingFailed(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, info channel.ProcessingStatusInfo, handle channel.ProcessingStatusHandle, cause error) error { +func (n *fakeProcessingStatusNotifier) ProcessingFailed(_ context.Context, _ channel.Config, _ channel.InboundMessage, info channel.ProcessingStatusInfo, handle channel.ProcessingStatusHandle, cause error) error { n.events = append(n.events, "failed") n.info = append(n.info, info) n.failedSeen = handle @@ -138,29 +138,29 @@ type fakeProcessingStatusAdapter struct { notifier *fakeProcessingStatusNotifier } -func (a *fakeProcessingStatusAdapter) Type() channel.ChannelType { - return channel.ChannelType("feishu") +func (a *fakeProcessingStatusAdapter) Type() channel.Type { + return channel.Type("feishu") } func (a *fakeProcessingStatusAdapter) Descriptor() channel.Descriptor { return channel.Descriptor{ - Type: channel.ChannelType("feishu"), - Capabilities: channel.ChannelCapabilities{ + Type: channel.Type("feishu"), + Capabilities: channel.Capabilities{ Text: true, Reply: true, }, } } -func (a *fakeProcessingStatusAdapter) ProcessingStarted(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, info channel.ProcessingStatusInfo) (channel.ProcessingStatusHandle, error) { +func (a *fakeProcessingStatusAdapter) ProcessingStarted(ctx context.Context, cfg channel.Config, msg channel.InboundMessage, info channel.ProcessingStatusInfo) (channel.ProcessingStatusHandle, error) { return a.notifier.ProcessingStarted(ctx, cfg, msg, info) } -func (a *fakeProcessingStatusAdapter) ProcessingCompleted(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, info channel.ProcessingStatusInfo, handle channel.ProcessingStatusHandle) error { +func (a *fakeProcessingStatusAdapter) ProcessingCompleted(ctx context.Context, cfg channel.Config, msg channel.InboundMessage, info channel.ProcessingStatusInfo, handle channel.ProcessingStatusHandle) error { return a.notifier.ProcessingCompleted(ctx, cfg, msg, info, handle) } -func (a *fakeProcessingStatusAdapter) ProcessingFailed(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, info channel.ProcessingStatusInfo, handle channel.ProcessingStatusHandle, cause error) error { +func (a *fakeProcessingStatusAdapter) ProcessingFailed(ctx context.Context, cfg channel.Config, msg channel.InboundMessage, info channel.ProcessingStatusInfo, handle channel.ProcessingStatusHandle, cause error) error { return a.notifier.ProcessingFailed(ctx, cfg, msg, info, handle, cause) } @@ -170,14 +170,14 @@ type fakeChatService struct { persisted []messagepkg.Message } -func (f *fakeChatService) ResolveConversation(ctx context.Context, input route.ResolveInput) (route.ResolveConversationResult, error) { +func (f *fakeChatService) ResolveConversation(_ context.Context, _ route.ResolveInput) (route.ResolveConversationResult, error) { if f.resolveErr != nil { return route.ResolveConversationResult{}, f.resolveErr } return f.resolveResult, nil } -func (f *fakeChatService) Persist(ctx context.Context, input messagepkg.PersistInput) (messagepkg.Message, error) { +func (f *fakeChatService) Persist(_ context.Context, input messagepkg.PersistInput) (messagepkg.Message, error) { msg := messagepkg.Message{ BotID: input.BotID, RouteID: input.RouteID, @@ -209,10 +209,10 @@ func TestChannelInboundProcessorWithIdentity(t *testing.T) { processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, chatSvc, gateway, channelIdentitySvc, memberSvc, policySvc, nil, nil, "", 0) sender := &fakeReplySender{} - cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1", ChannelType: channel.ChannelType("feishu")} + cfg := channel.Config{ID: "cfg-1", BotID: "bot-1", Type: channel.Type("feishu")} msg := channel.InboundMessage{ BotID: "bot-1", - Channel: channel.ChannelType("feishu"), + Channel: channel.Type("feishu"), Message: channel.Message{Text: "hello"}, ReplyTarget: "target-id", Sender: channel.Identity{SubjectID: "ext-1", DisplayName: "User1"}, @@ -252,10 +252,10 @@ func TestChannelInboundProcessorDenied(t *testing.T) { processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, chatSvc, gateway, channelIdentitySvc, memberSvc, policySvc, nil, nil, "", 0) sender := &fakeReplySender{} - cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1", ChannelType: channel.ChannelType("feishu")} + cfg := channel.Config{ID: "cfg-1", BotID: "bot-1", Type: channel.Type("feishu")} msg := channel.InboundMessage{ BotID: "bot-1", - Channel: channel.ChannelType("feishu"), + Channel: channel.Type("feishu"), Message: channel.Message{Text: "hello"}, ReplyTarget: "target-id", Sender: channel.Identity{SubjectID: "stranger-1"}, @@ -282,7 +282,7 @@ func TestChannelInboundProcessorIgnoreEmpty(t *testing.T) { processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, chatSvc, gateway, channelIdentitySvc, memberSvc, policySvc, nil, nil, "", 0) sender := &fakeReplySender{} - cfg := channel.ChannelConfig{ID: "cfg-1"} + cfg := channel.Config{ID: "cfg-1"} msg := channel.InboundMessage{Message: channel.Message{Text: " "}} err := processor.HandleInbound(context.Background(), cfg, msg, sender) @@ -311,10 +311,10 @@ func TestChannelInboundProcessorSilentReply(t *testing.T) { processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, chatSvc, gateway, channelIdentitySvc, memberSvc, nil, nil, nil, "", 0) sender := &fakeReplySender{} - cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1"} + cfg := channel.Config{ID: "cfg-1", BotID: "bot-1"} msg := channel.InboundMessage{ BotID: "bot-1", - Channel: channel.ChannelType("telegram"), + Channel: channel.Type("telegram"), Message: channel.Message{Text: "test"}, ReplyTarget: "chat-123", Sender: channel.Identity{SubjectID: "user-1"}, @@ -347,10 +347,10 @@ func TestChannelInboundProcessorGroupPassiveSync(t *testing.T) { processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, chatSvc, gateway, channelIdentitySvc, memberSvc, nil, nil, nil, "", 0) sender := &fakeReplySender{} - cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1"} + cfg := channel.Config{ID: "cfg-1", BotID: "bot-1"} msg := channel.InboundMessage{ BotID: "bot-1", - Channel: channel.ChannelType("feishu"), + Channel: channel.Type("feishu"), Message: channel.Message{ID: "msg-1", Text: "hello everyone"}, ReplyTarget: "chat_id:oc_123", Sender: channel.Identity{SubjectID: "user-1"}, @@ -395,10 +395,10 @@ func TestChannelInboundProcessorGroupMentionTriggersReply(t *testing.T) { processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, chatSvc, gateway, channelIdentitySvc, memberSvc, nil, nil, nil, "", 0) sender := &fakeReplySender{} - cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1"} + cfg := channel.Config{ID: "cfg-1", BotID: "bot-1"} msg := channel.InboundMessage{ BotID: "bot-1", - Channel: channel.ChannelType("feishu"), + Channel: channel.Type("feishu"), Message: channel.Message{ID: "msg-2", Text: "@bot ping"}, ReplyTarget: "chat_id:oc_123", Sender: channel.Identity{SubjectID: "user-1"}, @@ -447,10 +447,10 @@ func TestChannelInboundProcessorPersonalGroupNonOwnerIgnored(t *testing.T) { processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, chatSvc, gateway, channelIdentitySvc, memberSvc, policySvc, nil, nil, "", 0) sender := &fakeReplySender{} - cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1"} + cfg := channel.Config{ID: "cfg-1", BotID: "bot-1"} msg := channel.InboundMessage{ BotID: "bot-1", - Channel: channel.ChannelType("feishu"), + Channel: channel.Type("feishu"), Message: channel.Message{ID: "msg-personal-1", Text: "hello"}, ReplyTarget: "chat_id:oc_personal", Sender: channel.Identity{SubjectID: "ext-member-1"}, @@ -490,10 +490,10 @@ func TestChannelInboundProcessorPersonalGroupOwnerWithoutMentionUsesPassivePersi processor := NewChannelInboundProcessor(slog.Default(), nil, chatSvc, chatSvc, gateway, channelIdentitySvc, memberSvc, policySvc, nil, nil, "", 0) sender := &fakeReplySender{} - cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1"} + cfg := channel.Config{ID: "cfg-1", BotID: "bot-1"} msg := channel.InboundMessage{ BotID: "bot-1", - Channel: channel.ChannelType("feishu"), + Channel: channel.Type("feishu"), Message: channel.Message{ID: "msg-personal-2", Text: "owner says hi"}, ReplyTarget: "chat_id:oc_personal", Sender: channel.Identity{SubjectID: "ext-owner-1"}, @@ -536,7 +536,7 @@ func TestChannelInboundProcessorProcessingStatusSuccessLifecycle(t *testing.T) { {Role: "assistant", Content: conversation.NewTextContent("AI reply")}, }, }, - onChat: func(req conversation.ChatRequest) { + onChat: func(_ conversation.ChatRequest) { if len(notifier.events) != 1 || notifier.events[0] != "started" { t.Fatalf("expected started before chat call, got events: %+v", notifier.events) } @@ -544,10 +544,10 @@ func TestChannelInboundProcessorProcessingStatusSuccessLifecycle(t *testing.T) { } processor := NewChannelInboundProcessor(slog.Default(), registry, chatSvc, chatSvc, gateway, channelIdentitySvc, memberSvc, nil, nil, nil, "", 0) sender := &fakeReplySender{} - cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1", ChannelType: channel.ChannelType("feishu")} + cfg := channel.Config{ID: "cfg-1", BotID: "bot-1", Type: channel.Type("feishu")} msg := channel.InboundMessage{ BotID: "bot-1", - Channel: channel.ChannelType("feishu"), + Channel: channel.Type("feishu"), Message: channel.Message{ID: "om_123", Text: "hello"}, ReplyTarget: "chat_id:oc_123", Sender: channel.Identity{SubjectID: "ext-1"}, @@ -591,10 +591,10 @@ func TestChannelInboundProcessorProcessingStatusFailureLifecycle(t *testing.T) { gateway := &fakeChatGateway{err: chatErr} processor := NewChannelInboundProcessor(slog.Default(), registry, chatSvc, chatSvc, gateway, channelIdentitySvc, memberSvc, nil, nil, nil, "", 0) sender := &fakeReplySender{} - cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1", ChannelType: channel.ChannelType("feishu")} + cfg := channel.Config{ID: "cfg-1", BotID: "bot-1", Type: channel.Type("feishu")} msg := channel.InboundMessage{ BotID: "bot-1", - Channel: channel.ChannelType("feishu"), + Channel: channel.Type("feishu"), Message: channel.Message{ID: "om_456", Text: "hello"}, ReplyTarget: "chat_id:oc_456", Sender: channel.Identity{SubjectID: "ext-2"}, @@ -641,10 +641,10 @@ func TestChannelInboundProcessorProcessingStatusErrorsAreBestEffort(t *testing.T } processor := NewChannelInboundProcessor(slog.Default(), registry, chatSvc, chatSvc, gateway, channelIdentitySvc, memberSvc, nil, nil, nil, "", 0) sender := &fakeReplySender{} - cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1", ChannelType: channel.ChannelType("feishu")} + cfg := channel.Config{ID: "cfg-1", BotID: "bot-1", Type: channel.Type("feishu")} msg := channel.InboundMessage{ BotID: "bot-1", - Channel: channel.ChannelType("feishu"), + Channel: channel.Type("feishu"), Message: channel.Message{ID: "om_789", Text: "hello"}, ReplyTarget: "chat_id:oc_789", Sender: channel.Identity{SubjectID: "ext-3"}, @@ -683,10 +683,10 @@ func TestChannelInboundProcessorProcessingFailedNotifyErrorDoesNotOverrideChatEr gateway := &fakeChatGateway{err: chatErr} processor := NewChannelInboundProcessor(slog.Default(), registry, chatSvc, chatSvc, gateway, channelIdentitySvc, memberSvc, nil, nil, nil, "", 0) sender := &fakeReplySender{} - cfg := channel.ChannelConfig{ID: "cfg-1", BotID: "bot-1", ChannelType: channel.ChannelType("feishu")} + cfg := channel.Config{ID: "cfg-1", BotID: "bot-1", Type: channel.Type("feishu")} msg := channel.InboundMessage{ BotID: "bot-1", - Channel: channel.ChannelType("feishu"), + Channel: channel.Type("feishu"), Message: channel.Message{ID: "om_999", Text: "hello"}, ReplyTarget: "chat_id:oc_999", Sender: channel.Identity{SubjectID: "ext-4"}, diff --git a/internal/channel/inbound/identity.go b/internal/channel/inbound/identity.go index a04feba1..956bdaf2 100644 --- a/internal/channel/inbound/identity.go +++ b/internal/channel/inbound/identity.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "log/slog" + "slices" "strings" "time" @@ -20,8 +21,8 @@ type IdentityDecision struct { Reply channel.Message } -// InboundIdentity carries the resolved channel identity for an inbound message. -type InboundIdentity struct { +// Identity carries the resolved channel identity for an inbound message. +type Identity struct { BotID string ChannelConfigID string SubjectID string @@ -34,7 +35,7 @@ type InboundIdentity struct { // IdentityState bundles resolved identity with an optional early-exit decision. type IdentityState struct { - Identity InboundIdentity + Identity Identity Decision *IdentityDecision } @@ -142,7 +143,7 @@ func NewIdentityResolver( // Middleware returns a channel middleware that resolves identity before processing. func (r *IdentityResolver) Middleware() channel.Middleware { return func(next channel.InboundHandler) channel.InboundHandler { - return func(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage) error { + return func(ctx context.Context, cfg channel.Config, msg channel.InboundMessage) error { state, err := r.Resolve(ctx, cfg, msg) if err != nil { return err @@ -155,9 +156,9 @@ func (r *IdentityResolver) Middleware() channel.Middleware { // Resolve performs two-phase identity resolution: // 1. Global identity: (channel, channel_subject_id) -> channel_identity_id (unconditional) // 2. Authorization: bot membership check with guest/preauth fallback -func (r *IdentityResolver) Resolve(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage) (IdentityState, error) { +func (r *IdentityResolver) Resolve(ctx context.Context, cfg channel.Config, msg channel.InboundMessage) (IdentityState, error) { if r.channelIdentities == nil { - return IdentityState{}, fmt.Errorf("identity resolver not configured") + return IdentityState{}, errors.New("identity resolver not configured") } botID := strings.TrimSpace(msg.BotID) @@ -173,7 +174,7 @@ func (r *IdentityResolver) Resolve(ctx context.Context, cfg channel.ChannelConfi displayName, avatarURL := r.resolveProfile(ctx, cfg, msg, subjectID) state := IdentityState{ - Identity: InboundIdentity{ + Identity: Identity{ BotID: botID, ChannelConfigID: channelConfigID, SubjectID: subjectID, @@ -182,7 +183,7 @@ func (r *IdentityResolver) Resolve(ctx context.Context, cfg channel.ChannelConfi // Phase 1: Global identity resolution (unconditional). if subjectID == "" { - return state, fmt.Errorf("cannot resolve identity: no channel_subject_id") + return state, errors.New("cannot resolve identity: no channel_subject_id") } channelIdentityID, linkedUserID, err := r.resolveIdentityWithLinkedUser(ctx, msg, subjectID, displayName, avatarURL) @@ -287,7 +288,7 @@ func (r *IdentityResolver) Resolve(ctx context.Context, cfg channel.ChannelConfi func (r *IdentityResolver) resolveIdentityWithLinkedUser(ctx context.Context, msg channel.InboundMessage, primarySubjectID, displayName, avatarURL string) (string, string, error) { candidates := identitySubjectCandidates(msg, primarySubjectID) if len(candidates) == 0 { - return "", "", fmt.Errorf("cannot resolve identity: no channel_subject_id") + return "", "", errors.New("cannot resolve identity: no channel_subject_id") } var meta map[string]any @@ -448,10 +449,8 @@ func identitySubjectCandidates(msg channel.InboundMessage, primary string) []str if value == "" { return } - for _, existing := range candidates { - if existing == value { - return - } + if slices.Contains(candidates, value) { + return } candidates = append(candidates, value) } @@ -480,7 +479,7 @@ func extractDisplayName(msg channel.InboundMessage) string { // resolveProfile resolves display name and avatar URL for the sender. // Always queries directory for avatar; prefers message-level display name over directory name. -func (r *IdentityResolver) resolveProfile(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, subjectID string) (string, string) { +func (r *IdentityResolver) resolveProfile(ctx context.Context, cfg channel.Config, msg channel.InboundMessage, subjectID string) (string, string) { displayName := extractDisplayName(msg) dirName, avatarURL := r.resolveProfileFromDirectory(ctx, cfg, msg, subjectID) if displayName == "" { @@ -490,7 +489,7 @@ func (r *IdentityResolver) resolveProfile(ctx context.Context, cfg channel.Chann } // resolveProfileFromDirectory looks up the directory for sender display name and avatar URL. -func (r *IdentityResolver) resolveProfileFromDirectory(ctx context.Context, cfg channel.ChannelConfig, msg channel.InboundMessage, subjectID string) (string, string) { +func (r *IdentityResolver) resolveProfileFromDirectory(ctx context.Context, cfg channel.Config, msg channel.InboundMessage, subjectID string) (string, string) { if r.registry == nil { return "", "" } diff --git a/internal/channel/inbound/identity_test.go b/internal/channel/inbound/identity_test.go index 6e2a3eeb..3de10ae1 100644 --- a/internal/channel/inbound/identity_test.go +++ b/internal/channel/inbound/identity_test.go @@ -24,7 +24,7 @@ type fakeChannelIdentityService struct { lastMeta map[string]any } -func (f *fakeChannelIdentityService) ResolveByChannelIdentity(ctx context.Context, platform, externalID, displayName string, meta map[string]any) (identities.ChannelIdentity, error) { +func (f *fakeChannelIdentityService) ResolveByChannelIdentity(_ context.Context, _ string, externalID, displayName string, meta map[string]any) (identities.ChannelIdentity, error) { f.calls++ f.lastDisplayName = displayName f.lastMeta = meta @@ -40,7 +40,7 @@ func (f *fakeChannelIdentityService) ResolveByChannelIdentity(ctx context.Contex return f.channelIdentity, nil } -func (f *fakeChannelIdentityService) Canonicalize(ctx context.Context, channelIdentityID string) (string, error) { +func (f *fakeChannelIdentityService) Canonicalize(_ context.Context, channelIdentityID string) (string, error) { if f.canonical != nil { if value, ok := f.canonical[channelIdentityID]; ok { return value, nil @@ -49,7 +49,7 @@ func (f *fakeChannelIdentityService) Canonicalize(ctx context.Context, channelId return channelIdentityID, nil } -func (f *fakeChannelIdentityService) GetLinkedUserID(ctx context.Context, channelIdentityID string) (string, error) { +func (f *fakeChannelIdentityService) GetLinkedUserID(_ context.Context, channelIdentityID string) (string, error) { if f.linked != nil { if value, ok := f.linked[channelIdentityID]; ok { return value, nil @@ -60,7 +60,7 @@ func (f *fakeChannelIdentityService) GetLinkedUserID(ctx context.Context, channe return channelIdentityID, nil } -func (f *fakeChannelIdentityService) LinkChannelIdentityToUser(ctx context.Context, channelIdentityID, userID string) error { +func (f *fakeChannelIdentityService) LinkChannelIdentityToUser(_ context.Context, channelIdentityID, userID string) error { if f.linked == nil { f.linked = map[string]string{} } @@ -73,11 +73,11 @@ type fakeMemberService struct { upsertCalled bool } -func (f *fakeMemberService) IsMember(ctx context.Context, botID, channelIdentityID string) (bool, error) { +func (f *fakeMemberService) IsMember(_ context.Context, _ string, _ string) (bool, error) { return f.isMember, nil } -func (f *fakeMemberService) UpsertMemberSimple(ctx context.Context, botID, channelIdentityID, role string) error { +func (f *fakeMemberService) UpsertMemberSimple(_ context.Context, _ string, _ string, _ string) error { f.upsertCalled = true return nil } @@ -89,21 +89,21 @@ type fakePolicyService struct { err error } -func (f *fakePolicyService) AllowGuest(ctx context.Context, botID string) (bool, error) { +func (f *fakePolicyService) AllowGuest(_ context.Context, _ string) (bool, error) { if f.err != nil { return false, f.err } return f.allow, nil } -func (f *fakePolicyService) BotType(ctx context.Context, botID string) (string, error) { +func (f *fakePolicyService) BotType(_ context.Context, _ string) (string, error) { if f.err != nil { return "", f.err } return f.botType, nil } -func (f *fakePolicyService) BotOwnerUserID(ctx context.Context, botID string) (string, error) { +func (f *fakePolicyService) BotOwnerUserID(_ context.Context, _ string) (string, error) { if f.err != nil { return "", f.err } @@ -116,7 +116,7 @@ type fakePreauthServiceIdentity struct { markUsed bool } -func (f *fakePreauthServiceIdentity) Get(ctx context.Context, token string) (preauth.Key, error) { +func (f *fakePreauthServiceIdentity) Get(_ context.Context, token string) (preauth.Key, error) { if f.err != nil { return preauth.Key{}, f.err } @@ -126,7 +126,7 @@ func (f *fakePreauthServiceIdentity) Get(ctx context.Context, token string) (pre return f.key, nil } -func (f *fakePreauthServiceIdentity) MarkUsed(ctx context.Context, id string) (preauth.Key, error) { +func (f *fakePreauthServiceIdentity) MarkUsed(_ context.Context, _ string) (preauth.Key, error) { f.markUsed = true return f.key, nil } @@ -139,7 +139,7 @@ type fakeBindService struct { onConsume func(channelChannelIdentityID string) } -func (f *fakeBindService) Get(ctx context.Context, token string) (bind.Code, error) { +func (f *fakeBindService) Get(_ context.Context, token string) (bind.Code, error) { if f.getErr != nil { return bind.Code{}, f.getErr } @@ -149,7 +149,7 @@ func (f *fakeBindService) Get(ctx context.Context, token string) (bind.Code, err return f.code, nil } -func (f *fakeBindService) Consume(ctx context.Context, code bind.Code, channelChannelIdentityID string) error { +func (f *fakeBindService) Consume(_ context.Context, _ bind.Code, channelChannelIdentityID string) error { f.consumeCalled = true if f.onConsume != nil { f.onConsume(channelChannelIdentityID) @@ -158,35 +158,35 @@ func (f *fakeBindService) Consume(ctx context.Context, code bind.Code, channelCh } type fakeDirectoryAdapter struct { - channelType channel.ChannelType - resolveFn func(ctx context.Context, cfg channel.ChannelConfig, input string, kind channel.DirectoryEntryKind) (channel.DirectoryEntry, error) + channelType channel.Type + resolveFn func(ctx context.Context, cfg channel.Config, input string, kind channel.DirectoryEntryKind) (channel.DirectoryEntry, error) } -func (f *fakeDirectoryAdapter) Type() channel.ChannelType { +func (f *fakeDirectoryAdapter) Type() channel.Type { return f.channelType } func (f *fakeDirectoryAdapter) Descriptor() channel.Descriptor { return channel.Descriptor{ - Type: f.channelType, - DisplayName: "FakeDirectory", - Capabilities: channel.ChannelCapabilities{}, + Type: f.channelType, + DisplayName: "FakeDirectory", + Capabilities: channel.Capabilities{}, } } -func (f *fakeDirectoryAdapter) ListPeers(ctx context.Context, cfg channel.ChannelConfig, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { +func (f *fakeDirectoryAdapter) ListPeers(_ context.Context, _ channel.Config, _ channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { return nil, nil } -func (f *fakeDirectoryAdapter) ListGroups(ctx context.Context, cfg channel.ChannelConfig, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { +func (f *fakeDirectoryAdapter) ListGroups(_ context.Context, _ channel.Config, _ channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { return nil, nil } -func (f *fakeDirectoryAdapter) ListGroupMembers(ctx context.Context, cfg channel.ChannelConfig, groupID string, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { +func (f *fakeDirectoryAdapter) ListGroupMembers(_ context.Context, _ channel.Config, _ string, _ channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { return nil, nil } -func (f *fakeDirectoryAdapter) ResolveEntry(ctx context.Context, cfg channel.ChannelConfig, input string, kind channel.DirectoryEntryKind) (channel.DirectoryEntry, error) { +func (f *fakeDirectoryAdapter) ResolveEntry(ctx context.Context, cfg channel.Config, input string, kind channel.DirectoryEntryKind) (channel.DirectoryEntry, error) { if f.resolveFn != nil { return f.resolveFn(ctx, cfg, input, kind) } @@ -201,12 +201,12 @@ func TestIdentityResolverAllowGuestWithoutMembershipSideEffect(t *testing.T) { msg := channel.InboundMessage{ BotID: "bot-1", - Channel: channel.ChannelType("feishu"), + Channel: channel.Type("feishu"), Message: channel.Message{Text: "hello"}, ReplyTarget: "target-id", Sender: channel.Identity{SubjectID: "ext-1", DisplayName: "Guest"}, } - state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg) + state, err := resolver.Resolve(context.Background(), channel.Config{BotID: "bot-1"}, msg) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -224,8 +224,8 @@ func TestIdentityResolverAllowGuestWithoutMembershipSideEffect(t *testing.T) { func TestIdentityResolverResolveDisplayNameFromDirectory(t *testing.T) { registry := channel.NewRegistry() directoryAdapter := &fakeDirectoryAdapter{ - channelType: channel.ChannelType("feishu"), - resolveFn: func(ctx context.Context, cfg channel.ChannelConfig, input string, kind channel.DirectoryEntryKind) (channel.DirectoryEntry, error) { + channelType: channel.Type("feishu"), + resolveFn: func(_ context.Context, _ channel.Config, input string, kind channel.DirectoryEntryKind) (channel.DirectoryEntry, error) { if kind != channel.DirectoryEntryUser { t.Fatalf("expected kind user, got %s", kind) } @@ -249,7 +249,7 @@ func TestIdentityResolverResolveDisplayNameFromDirectory(t *testing.T) { msg := channel.InboundMessage{ BotID: "bot-1", - Channel: channel.ChannelType("feishu"), + Channel: channel.Type("feishu"), Message: channel.Message{Text: "hello"}, ReplyTarget: "target-id", Sender: channel.Identity{ @@ -259,7 +259,7 @@ func TestIdentityResolverResolveDisplayNameFromDirectory(t *testing.T) { }, }, } - state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1", ChannelType: channel.ChannelType("feishu")}, msg) + state, err := resolver.Resolve(context.Background(), channel.Config{BotID: "bot-1", Type: channel.Type("feishu")}, msg) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -274,8 +274,8 @@ func TestIdentityResolverResolveDisplayNameFromDirectory(t *testing.T) { func TestIdentityResolverDirectoryLookupFailureDoesNotFallbackToOpenID(t *testing.T) { registry := channel.NewRegistry() directoryAdapter := &fakeDirectoryAdapter{ - channelType: channel.ChannelType("feishu"), - resolveFn: func(ctx context.Context, cfg channel.ChannelConfig, input string, kind channel.DirectoryEntryKind) (channel.DirectoryEntry, error) { + channelType: channel.Type("feishu"), + resolveFn: func(_ context.Context, _ channel.Config, _ string, _ channel.DirectoryEntryKind) (channel.DirectoryEntry, error) { return channel.DirectoryEntry{}, errors.New("lookup failed") }, } @@ -290,7 +290,7 @@ func TestIdentityResolverDirectoryLookupFailureDoesNotFallbackToOpenID(t *testin msg := channel.InboundMessage{ BotID: "bot-1", - Channel: channel.ChannelType("feishu"), + Channel: channel.Type("feishu"), Message: channel.Message{Text: "hello"}, ReplyTarget: "target-id", Sender: channel.Identity{ @@ -301,7 +301,7 @@ func TestIdentityResolverDirectoryLookupFailureDoesNotFallbackToOpenID(t *testin }, }, } - state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1", ChannelType: channel.ChannelType("feishu")}, msg) + state, err := resolver.Resolve(context.Background(), channel.Config{BotID: "bot-1", Type: channel.Type("feishu")}, msg) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -316,8 +316,8 @@ func TestIdentityResolverDirectoryLookupFailureDoesNotFallbackToOpenID(t *testin func TestIdentityResolverDirectoryAvatarURLPropagated(t *testing.T) { registry := channel.NewRegistry() directoryAdapter := &fakeDirectoryAdapter{ - channelType: channel.ChannelType("feishu"), - resolveFn: func(ctx context.Context, cfg channel.ChannelConfig, input string, kind channel.DirectoryEntryKind) (channel.DirectoryEntry, error) { + channelType: channel.Type("feishu"), + resolveFn: func(_ context.Context, _ channel.Config, _ string, _ channel.DirectoryEntryKind) (channel.DirectoryEntry, error) { return channel.DirectoryEntry{ Kind: channel.DirectoryEntryUser, Name: "Avatar User", @@ -336,7 +336,7 @@ func TestIdentityResolverDirectoryAvatarURLPropagated(t *testing.T) { msg := channel.InboundMessage{ BotID: "bot-1", - Channel: channel.ChannelType("feishu"), + Channel: channel.Type("feishu"), Message: channel.Message{Text: "hello"}, ReplyTarget: "target-id", Sender: channel.Identity{ @@ -344,7 +344,7 @@ func TestIdentityResolverDirectoryAvatarURLPropagated(t *testing.T) { Attributes: map[string]string{"open_id": "ou-avatar"}, }, } - state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1", ChannelType: channel.ChannelType("feishu")}, msg) + state, err := resolver.Resolve(context.Background(), channel.Config{BotID: "bot-1", Type: channel.Type("feishu")}, msg) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -370,12 +370,12 @@ func TestIdentityResolverExistingMemberPasses(t *testing.T) { msg := channel.InboundMessage{ BotID: "bot-1", - Channel: channel.ChannelType("telegram"), + Channel: channel.Type("telegram"), Message: channel.Message{Text: "hello"}, ReplyTarget: "chat-123", Sender: channel.Identity{SubjectID: "tg-user-1"}, } - state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg) + state, err := resolver.Resolve(context.Background(), channel.Config{BotID: "bot-1"}, msg) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -400,12 +400,12 @@ func TestIdentityResolverPreauthKey(t *testing.T) { msg := channel.InboundMessage{ BotID: "bot-1", - Channel: channel.ChannelType("feishu"), + Channel: channel.Type("feishu"), Message: channel.Message{Text: "PREAUTH123"}, ReplyTarget: "target-id", Sender: channel.Identity{SubjectID: "ext-1"}, } - state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg) + state, err := resolver.Resolve(context.Background(), channel.Config{BotID: "bot-1"}, msg) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -436,12 +436,12 @@ func TestIdentityResolverPreauthKeyExpired(t *testing.T) { msg := channel.InboundMessage{ BotID: "bot-1", - Channel: channel.ChannelType("feishu"), + Channel: channel.Type("feishu"), Message: channel.Message{Text: "PREAUTH123"}, ReplyTarget: "target-id", Sender: channel.Identity{SubjectID: "ext-1"}, } - state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg) + state, err := resolver.Resolve(context.Background(), channel.Config{BotID: "bot-1"}, msg) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -461,12 +461,12 @@ func TestIdentityResolverDenied(t *testing.T) { msg := channel.InboundMessage{ BotID: "bot-1", - Channel: channel.ChannelType("telegram"), + Channel: channel.Type("telegram"), Message: channel.Message{Text: "hello"}, ReplyTarget: "chat-123", Sender: channel.Identity{SubjectID: "stranger-1"}, } - state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg) + state, err := resolver.Resolve(context.Background(), channel.Config{BotID: "bot-1"}, msg) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -483,7 +483,7 @@ func TestIdentityResolverPersonalBotRejectsGroupMessages(t *testing.T) { msg := channel.InboundMessage{ BotID: "bot-1", - Channel: channel.ChannelType("feishu"), + Channel: channel.Type("feishu"), Message: channel.Message{Text: "hello"}, Sender: channel.Identity{SubjectID: "ext-group-1"}, Conversation: channel.Conversation{ @@ -492,7 +492,7 @@ func TestIdentityResolverPersonalBotRejectsGroupMessages(t *testing.T) { }, } - state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg) + state, err := resolver.Resolve(context.Background(), channel.Config{BotID: "bot-1"}, msg) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -515,7 +515,7 @@ func TestIdentityResolverPersonalBotAllowsOwnerInGroup(t *testing.T) { msg := channel.InboundMessage{ BotID: "bot-1", - Channel: channel.ChannelType("feishu"), + Channel: channel.Type("feishu"), Message: channel.Message{Text: "hello from owner"}, Sender: channel.Identity{SubjectID: "ext-owner-1"}, Conversation: channel.Conversation{ @@ -524,7 +524,7 @@ func TestIdentityResolverPersonalBotAllowsOwnerInGroup(t *testing.T) { }, } - state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg) + state, err := resolver.Resolve(context.Background(), channel.Config{BotID: "bot-1"}, msg) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -544,7 +544,7 @@ func TestIdentityResolverPersonalBotAllowsOwnerDirectWithoutMembership(t *testin msg := channel.InboundMessage{ BotID: "bot-1", - Channel: channel.ChannelType("feishu"), + Channel: channel.Type("feishu"), Message: channel.Message{Text: "hello from owner"}, Sender: channel.Identity{SubjectID: "ext-owner-direct"}, Conversation: channel.Conversation{ @@ -553,7 +553,7 @@ func TestIdentityResolverPersonalBotAllowsOwnerDirectWithoutMembership(t *testin }, } - state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg) + state, err := resolver.Resolve(context.Background(), channel.Config{BotID: "bot-1"}, msg) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -581,7 +581,7 @@ func TestIdentityResolverPersonalBotOwnerFallbackByAlternateSubject(t *testing.T msg := channel.InboundMessage{ BotID: "bot-1", - Channel: channel.ChannelType("feishu"), + Channel: channel.Type("feishu"), Message: channel.Message{Text: "hello from owner"}, Sender: channel.Identity{ SubjectID: "ou-open-owner", @@ -596,7 +596,7 @@ func TestIdentityResolverPersonalBotOwnerFallbackByAlternateSubject(t *testing.T }, } - state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg) + state, err := resolver.Resolve(context.Background(), channel.Config{BotID: "bot-1"}, msg) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -622,7 +622,7 @@ func TestIdentityResolverPersonalBotRejectsNonOwnerDirectEvenIfMember(t *testing msg := channel.InboundMessage{ BotID: "bot-1", - Channel: channel.ChannelType("feishu"), + Channel: channel.Type("feishu"), Message: channel.Message{Text: "hello from non-owner"}, Sender: channel.Identity{SubjectID: "ext-non-owner"}, Conversation: channel.Conversation{ @@ -631,7 +631,7 @@ func TestIdentityResolverPersonalBotRejectsNonOwnerDirectEvenIfMember(t *testing }, } - state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg) + state, err := resolver.Resolve(context.Background(), channel.Config{BotID: "bot-1"}, msg) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -669,12 +669,12 @@ func TestIdentityResolverBindRunsBeforeMembershipCheck(t *testing.T) { msg := channel.InboundMessage{ BotID: "bot-1", - Channel: channel.ChannelType("feishu"), + Channel: channel.Type("feishu"), Message: channel.Message{Text: "BIND123"}, ReplyTarget: "target-id", Sender: channel.Identity{SubjectID: "ext-bind-1"}, } - state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg) + state, err := resolver.Resolve(context.Background(), channel.Config{BotID: "bot-1"}, msg) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -708,12 +708,12 @@ func TestIdentityResolverBindConsumeErrorHandledAsDecision(t *testing.T) { msg := channel.InboundMessage{ BotID: "bot-1", - Channel: channel.ChannelType("telegram"), + Channel: channel.Type("telegram"), Message: channel.Message{Text: "BINDUSED"}, ReplyTarget: "chat-123", Sender: channel.Identity{SubjectID: "ext-bind-2"}, } - state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg) + state, err := resolver.Resolve(context.Background(), channel.Config{BotID: "bot-1"}, msg) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -747,12 +747,12 @@ func TestIdentityResolverBindCodeNotScopedToCurrentBot(t *testing.T) { msg := channel.InboundMessage{ BotID: "bot-2", - Channel: channel.ChannelType("feishu"), + Channel: channel.Type("feishu"), Message: channel.Message{Text: "BINDANYBOT"}, ReplyTarget: "target-id", Sender: channel.Identity{SubjectID: "ext-bind-any-bot"}, } - state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg) + state, err := resolver.Resolve(context.Background(), channel.Config{BotID: "bot-1"}, msg) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -775,7 +775,7 @@ func TestIdentityResolverPublicBotGroupDeniedSilently(t *testing.T) { msg := channel.InboundMessage{ BotID: "bot-1", - Channel: channel.ChannelType("feishu"), + Channel: channel.Type("feishu"), Message: channel.Message{Text: "hello"}, ReplyTarget: "group-target", Sender: channel.Identity{SubjectID: "stranger-group"}, @@ -784,7 +784,7 @@ func TestIdentityResolverPublicBotGroupDeniedSilently(t *testing.T) { Type: "group", }, } - state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg) + state, err := resolver.Resolve(context.Background(), channel.Config{BotID: "bot-1"}, msg) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -804,7 +804,7 @@ func TestIdentityResolverPublicBotDirectDeniedWithReply(t *testing.T) { msg := channel.InboundMessage{ BotID: "bot-1", - Channel: channel.ChannelType("feishu"), + Channel: channel.Type("feishu"), Message: channel.Message{Text: "hello"}, ReplyTarget: "direct-target", Sender: channel.Identity{SubjectID: "stranger-direct"}, @@ -813,7 +813,7 @@ func TestIdentityResolverPublicBotDirectDeniedWithReply(t *testing.T) { Type: "p2p", }, } - state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg) + state, err := resolver.Resolve(context.Background(), channel.Config{BotID: "bot-1"}, msg) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -840,12 +840,12 @@ func TestIdentityResolverBindCodePlatformMismatch(t *testing.T) { msg := channel.InboundMessage{ BotID: "bot-1", - Channel: channel.ChannelType("feishu"), + Channel: channel.Type("feishu"), Message: channel.Message{Text: "BINDPLATFORM"}, ReplyTarget: "target-id", Sender: channel.Identity{SubjectID: "ext-bind-platform"}, } - state, err := resolver.Resolve(context.Background(), channel.ChannelConfig{BotID: "bot-1"}, msg) + state, err := resolver.Resolve(context.Background(), channel.Config{BotID: "bot-1"}, msg) if err != nil { t.Fatalf("unexpected error: %v", err) } diff --git a/internal/channel/inbound_test.go b/internal/channel/inbound_test.go index 2135d5ba..83ea9a5a 100644 --- a/internal/channel/inbound_test.go +++ b/internal/channel/inbound_test.go @@ -2,7 +2,7 @@ package channel import ( "context" - "fmt" + "errors" "log/slog" "testing" ) @@ -13,24 +13,25 @@ type mockAdapter struct { streamEvents []StreamEvent } -func (m *mockAdapter) Type() ChannelType { return ChannelType("test") } +func (m *mockAdapter) Type() Type { return Type("test") } func (m *mockAdapter) Descriptor() Descriptor { return Descriptor{ - Type: ChannelType("test"), + Type: Type("test"), DisplayName: "Test", - Capabilities: ChannelCapabilities{ + Capabilities: Capabilities{ Text: true, Reply: true, Streaming: true, }, } } -func (m *mockAdapter) Send(ctx context.Context, cfg ChannelConfig, msg OutboundMessage) error { + +func (m *mockAdapter) Send(_ context.Context, _ Config, msg OutboundMessage) error { m.sentMessages = append(m.sentMessages, msg) return nil } -func (m *mockAdapter) OpenStream(ctx context.Context, cfg ChannelConfig, target string, opts StreamOptions) (OutboundStream, error) { +func (m *mockAdapter) OpenStream(_ context.Context, _ Config, _ string, _ StreamOptions) (OutboundStream, error) { return &mockAdapterStream{adapter: m}, nil } @@ -38,7 +39,7 @@ type mockAdapterStream struct { adapter *mockAdapter } -func (s *mockAdapterStream) Push(ctx context.Context, event StreamEvent) error { +func (s *mockAdapterStream) Push(_ context.Context, event StreamEvent) error { if s == nil || s.adapter == nil { return nil } @@ -52,18 +53,18 @@ func (s *mockAdapterStream) Push(ctx context.Context, event StreamEvent) error { return nil } -func (s *mockAdapterStream) Close(ctx context.Context) error { +func (s *mockAdapterStream) Close(_ context.Context) error { return nil } type fakeInboundProcessor struct { resp *OutboundMessage err error - gotCfg ChannelConfig + gotCfg Config gotMsg InboundMessage } -func (f *fakeInboundProcessor) HandleInbound(ctx context.Context, cfg ChannelConfig, msg InboundMessage, sender StreamReplySender) error { +func (f *fakeInboundProcessor) HandleInbound(ctx context.Context, cfg Config, msg InboundMessage, sender StreamReplySender) error { f.gotCfg = cfg f.gotMsg = msg if f.err != nil { @@ -73,14 +74,14 @@ func (f *fakeInboundProcessor) HandleInbound(ctx context.Context, cfg ChannelCon return nil } if sender == nil { - return fmt.Errorf("sender missing") + return errors.New("sender missing") } return sender.Send(ctx, *f.resp) } type fakeInboundStreamProcessor struct{} -func (f *fakeInboundStreamProcessor) HandleInbound(ctx context.Context, cfg ChannelConfig, msg InboundMessage, sender StreamReplySender) error { +func (f *fakeInboundStreamProcessor) HandleInbound(ctx context.Context, _ Config, _ InboundMessage, sender StreamReplySender) error { stream, err := sender.OpenStream(ctx, "stream-target", StreamOptions{}) if err != nil { return err @@ -120,9 +121,9 @@ func TestManager_handleInbound(t *testing.T) { adapter := &mockAdapter{} m.RegisterAdapter(adapter) - cfg := ChannelConfig{ID: "bot-1", BotID: "bot-1", ChannelType: ChannelType("test")} + cfg := Config{ID: "bot-1", BotID: "bot-1", Type: Type("test")} msg := InboundMessage{ - Channel: ChannelType("test"), + Channel: Type("test"), Message: Message{Text: "hello"}, ReplyTarget: "target-id", Conversation: Conversation{ @@ -154,9 +155,9 @@ func TestManager_handleInbound(t *testing.T) { adapter := &mockAdapter{} m.RegisterAdapter(adapter) - cfg := ChannelConfig{ID: "bot-1", BotID: "bot-1", ChannelType: ChannelType("test")} + cfg := Config{ID: "bot-1", BotID: "bot-1", Type: Type("test")} msg := InboundMessage{ - Channel: ChannelType("test"), + Channel: Type("test"), Message: Message{Text: "hello"}, ReplyTarget: "target-id", } @@ -175,7 +176,7 @@ func TestManager_handleInbound(t *testing.T) { processor := &fakeInboundProcessor{err: context.Canceled} reg := NewRegistry() m := NewManager(logger, reg, &fakeConfigStore{}, processor) - cfg := ChannelConfig{ID: "bot-1"} + cfg := Config{ID: "bot-1"} msg := InboundMessage{Message: Message{Text: " "}} // whitespace-only message err := m.handleInbound(context.Background(), cfg, msg) @@ -191,9 +192,9 @@ func TestManager_handleInbound(t *testing.T) { adapter := &mockAdapter{} m.RegisterAdapter(adapter) - cfg := ChannelConfig{ID: "bot-1", BotID: "bot-1", ChannelType: ChannelType("test")} + cfg := Config{ID: "bot-1", BotID: "bot-1", Type: Type("test")} msg := InboundMessage{ - Channel: ChannelType("test"), + Channel: Type("test"), Message: Message{Text: "hello"}, ReplyTarget: "stream-target", Conversation: Conversation{ diff --git a/internal/channel/manager.go b/internal/channel/manager.go index f67d2cba..b554c901 100644 --- a/internal/channel/manager.go +++ b/internal/channel/manager.go @@ -12,18 +12,18 @@ import ( // ConfigLister lists channel configs for periodic refresh. Used by connection lifecycle. type ConfigLister interface { - ListConfigsByType(ctx context.Context, channelType ChannelType) ([]ChannelConfig, error) + ListConfigsByType(ctx context.Context, channelType Type) ([]Config, error) } // ConfigResolver resolves effective configs and user bindings. Used for outbound sending. type ConfigResolver interface { - ResolveEffectiveConfig(ctx context.Context, botID string, channelType ChannelType) (ChannelConfig, error) - GetChannelIdentityConfig(ctx context.Context, channelIdentityID string, channelType ChannelType) (ChannelIdentityBinding, error) + ResolveEffectiveConfig(ctx context.Context, botID string, channelType Type) (Config, error) + GetChannelIdentityConfig(ctx context.Context, channelIdentityID string, channelType Type) (IdentityBinding, error) } // BindingStore resolves channel-identity bindings. Used by identity resolution. type BindingStore interface { - ResolveChannelIdentityBinding(ctx context.Context, channelType ChannelType, criteria BindingCriteria) (string, error) + ResolveIdentityBinding(ctx context.Context, channelType Type, criteria BindingCriteria) (string, error) } // ConfigStore is the full persistence interface. Components should depend on smaller @@ -32,7 +32,7 @@ type ConfigStore interface { ConfigLister ConfigResolver BindingStore - UpsertChannelIdentityConfig(ctx context.Context, channelIdentityID string, channelType ChannelType, req UpsertChannelIdentityConfigRequest) (ChannelIdentityBinding, error) + UpsertChannelIdentityConfig(ctx context.Context, channelIdentityID string, channelType Type, req UpsertChannelIdentityConfigRequest) (IdentityBinding, error) } // Middleware wraps an InboundHandler to add cross-cutting behavior. @@ -121,13 +121,13 @@ func (m *Manager) AddAdapter(ctx context.Context, adapter Adapter) { } // RemoveAdapter unregisters an adapter and stops all its active connections. -func (m *Manager) RemoveAdapter(ctx context.Context, channelType ChannelType) { +func (m *Manager) RemoveAdapter(ctx context.Context, channelType Type) { if ctx == nil { ctx = context.Background() } m.mu.Lock() for id, entry := range m.connections { - if entry != nil && entry.config.ChannelType == channelType { + if entry != nil && entry.config.Type == channelType { if entry.connection != nil { if err := entry.connection.Stop(ctx); err != nil && !errors.Is(err, ErrStopNotSupported) && m.logger != nil { m.logger.Warn("adapter stop failed", slog.String("config_id", id), slog.Any("error", err)) @@ -166,9 +166,9 @@ func (m *Manager) Start(ctx context.Context) { } // Send delivers an outbound message to the specified channel, resolving target and config automatically. -func (m *Manager) Send(ctx context.Context, botID string, channelType ChannelType, req SendRequest) error { +func (m *Manager) Send(ctx context.Context, botID string, channelType Type, req SendRequest) error { if m.service == nil { - return fmt.Errorf("channel manager not configured") + return errors.New("channel manager not configured") } sender, ok := m.registry.GetSender(channelType) if !ok { @@ -182,14 +182,14 @@ func (m *Manager) Send(ctx context.Context, botID string, channelType ChannelTyp if target == "" { targetChannelIdentityID := strings.TrimSpace(req.ChannelIdentityID) if targetChannelIdentityID == "" { - return fmt.Errorf("target or user_id is required") + return errors.New("target or user_id is required") } userCfg, err := m.service.GetChannelIdentityConfig(ctx, targetChannelIdentityID, channelType) if err != nil { if m.logger != nil { m.logger.Warn("channel binding missing", slog.String("channel", channelType.String()), slog.String("channel_identity_id", targetChannelIdentityID)) } - return fmt.Errorf("channel binding required") + return errors.New("channel binding required") } target, err = m.registry.ResolveTargetFromUserConfig(channelType, userCfg.Config) if err != nil { @@ -200,7 +200,7 @@ func (m *Manager) Send(ctx context.Context, botID string, channelType ChannelTyp target = normalized } if req.Message.IsEmpty() { - return fmt.Errorf("message is required") + return errors.New("message is required") } if m.logger != nil { m.logger.Info("send outbound", slog.String("channel", channelType.String()), slog.String("bot_id", botID)) @@ -225,9 +225,9 @@ func (m *Manager) Send(ctx context.Context, botID string, channelType ChannelTyp } // React adds or removes an emoji reaction on a channel message. -func (m *Manager) React(ctx context.Context, botID string, channelType ChannelType, req ReactRequest) error { +func (m *Manager) React(ctx context.Context, botID string, channelType Type, req ReactRequest) error { if m.service == nil { - return fmt.Errorf("channel manager not configured") + return errors.New("channel manager not configured") } reactor, ok := m.registry.GetReactor(channelType) if !ok { @@ -239,18 +239,18 @@ func (m *Manager) React(ctx context.Context, botID string, channelType ChannelTy } target := strings.TrimSpace(req.Target) if target == "" { - return fmt.Errorf("target is required for reactions") + return errors.New("target is required for reactions") } if normalized, ok := m.registry.NormalizeTarget(channelType, target); ok { target = normalized } messageID := strings.TrimSpace(req.MessageID) if messageID == "" { - return fmt.Errorf("message_id is required for reactions") + return errors.New("message_id is required for reactions") } emoji := strings.TrimSpace(req.Emoji) if !req.Remove && emoji == "" { - return fmt.Errorf("emoji is required when adding a reaction") + return errors.New("emoji is required when adding a reaction") } if m.logger != nil { m.logger.Info("react outbound", diff --git a/internal/channel/manager_integration_test.go b/internal/channel/manager_integration_test.go index fc296014..37f30499 100644 --- a/internal/channel/manager_integration_test.go +++ b/internal/channel/manager_integration_test.go @@ -2,8 +2,7 @@ package channel import ( "context" - "fmt" - "io" + "errors" "log/slog" "strings" "sync" @@ -12,37 +11,37 @@ import ( ) type fakeConfigStore struct { - effectiveConfig ChannelConfig - channelIdentityConfig ChannelIdentityBinding - configsByType map[ChannelType][]ChannelConfig + effectiveConfig Config + channelIdentityConfig IdentityBinding + configsByType map[Type][]Config boundChannelIdentityID string } -func (f *fakeConfigStore) ResolveEffectiveConfig(ctx context.Context, botID string, channelType ChannelType) (ChannelConfig, error) { +func (f *fakeConfigStore) ResolveEffectiveConfig(_ context.Context, _ string, _ Type) (Config, error) { return f.effectiveConfig, nil } -func (f *fakeConfigStore) GetChannelIdentityConfig(ctx context.Context, channelIdentityID string, channelType ChannelType) (ChannelIdentityBinding, error) { +func (f *fakeConfigStore) GetChannelIdentityConfig(_ context.Context, _ string, _ Type) (IdentityBinding, error) { if f.channelIdentityConfig.ID == "" && len(f.channelIdentityConfig.Config) == 0 { - return ChannelIdentityBinding{}, fmt.Errorf("channel user config not found") + return IdentityBinding{}, errors.New("channel user config not found") } return f.channelIdentityConfig, nil } -func (f *fakeConfigStore) UpsertChannelIdentityConfig(ctx context.Context, channelIdentityID string, channelType ChannelType, req UpsertChannelIdentityConfigRequest) (ChannelIdentityBinding, error) { +func (f *fakeConfigStore) UpsertChannelIdentityConfig(_ context.Context, _ string, _ Type, _ UpsertChannelIdentityConfigRequest) (IdentityBinding, error) { return f.channelIdentityConfig, nil } -func (f *fakeConfigStore) ListConfigsByType(ctx context.Context, channelType ChannelType) ([]ChannelConfig, error) { +func (f *fakeConfigStore) ListConfigsByType(_ context.Context, channelType Type) ([]Config, error) { if f.configsByType == nil { return nil, nil } return f.configsByType[channelType], nil } -func (f *fakeConfigStore) ResolveChannelIdentityBinding(ctx context.Context, channelType ChannelType, criteria BindingCriteria) (string, error) { +func (f *fakeConfigStore) ResolveIdentityBinding(_ context.Context, _ Type, _ BindingCriteria) (string, error) { if f.boundChannelIdentityID == "" { - return "", fmt.Errorf("channel user binding not found") + return "", errors.New("channel user binding not found") } return f.boundChannelIdentityID, nil } @@ -50,11 +49,11 @@ func (f *fakeConfigStore) ResolveChannelIdentityBinding(ctx context.Context, cha type fakeInboundProcessorIntegration struct { resp *OutboundMessage err error - gotCfg ChannelConfig + gotCfg Config gotMsg InboundMessage } -func (f *fakeInboundProcessorIntegration) HandleInbound(ctx context.Context, cfg ChannelConfig, msg InboundMessage, sender StreamReplySender) error { +func (f *fakeInboundProcessorIntegration) HandleInbound(ctx context.Context, cfg Config, msg InboundMessage, sender StreamReplySender) error { f.gotCfg = cfg f.gotMsg = msg if f.err != nil { @@ -64,38 +63,38 @@ func (f *fakeInboundProcessorIntegration) HandleInbound(ctx context.Context, cfg return nil } if sender == nil { - return fmt.Errorf("sender missing") + return errors.New("sender missing") } return sender.Send(ctx, *f.resp) } type fakeAdapter struct { - channelType ChannelType + channelType Type mu sync.Mutex - started []ChannelConfig + started []Config sent []OutboundMessage stops int } -func (f *fakeAdapter) Type() ChannelType { +func (f *fakeAdapter) Type() Type { return f.channelType } func (f *fakeAdapter) Descriptor() Descriptor { - return Descriptor{Type: f.channelType, DisplayName: "Fake", Capabilities: ChannelCapabilities{Text: true}} + return Descriptor{Type: f.channelType, DisplayName: "Fake", Capabilities: Capabilities{Text: true}} } func (f *fakeAdapter) ResolveTarget(channelIdentityConfig map[string]any) (string, error) { value := strings.TrimSpace(ReadString(channelIdentityConfig, "target")) if value == "" { - return "", fmt.Errorf("missing target") + return "", errors.New("missing target") } return "resolved:" + value, nil } func (f *fakeAdapter) NormalizeTarget(raw string) string { return strings.TrimSpace(raw) } -func (f *fakeAdapter) Connect(ctx context.Context, cfg ChannelConfig, handler InboundHandler) (Connection, error) { +func (f *fakeAdapter) Connect(_ context.Context, cfg Config, _ InboundHandler) (Connection, error) { f.mu.Lock() f.started = append(f.started, cfg) f.mu.Unlock() @@ -108,7 +107,7 @@ func (f *fakeAdapter) Connect(ctx context.Context, cfg ChannelConfig, handler In return NewConnection(cfg, stop), nil } -func (f *fakeAdapter) Send(ctx context.Context, cfg ChannelConfig, msg OutboundMessage) error { +func (f *fakeAdapter) Send(_ context.Context, _ Config, msg OutboundMessage) error { f.mu.Lock() f.sent = append(f.sent, msg) f.mu.Unlock() @@ -118,7 +117,7 @@ func (f *fakeAdapter) Send(ctx context.Context, cfg ChannelConfig, msg OutboundM func TestManagerHandleInboundIntegratesAdapter(t *testing.T) { t.Parallel() - log := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})) + log := slog.New(slog.DiscardHandler) store := &fakeConfigStore{} processor := &fakeInboundProcessorIntegration{ resp: &OutboundMessage{ @@ -129,19 +128,19 @@ func TestManagerHandleInboundIntegratesAdapter(t *testing.T) { }, } reg := NewRegistry() - adapter := &fakeAdapter{channelType: ChannelType("test")} + adapter := &fakeAdapter{channelType: Type("test")} manager := NewManager(log, reg, store, processor) manager.RegisterAdapter(adapter) - cfg := ChannelConfig{ + cfg := Config{ ID: "cfg-1", BotID: "bot-1", - ChannelType: ChannelType("test"), + Type: Type("test"), Credentials: map[string]any{"botToken": "token"}, UpdatedAt: time.Now(), } err := manager.handleInbound(context.Background(), cfg, InboundMessage{ - Channel: ChannelType("test"), + Channel: Type("test"), Message: Message{Text: "hi"}, BotID: "bot-1", ReplyTarget: "123", @@ -171,26 +170,26 @@ func TestManagerHandleInboundIntegratesAdapter(t *testing.T) { func TestManagerSendUsesBinding(t *testing.T) { t.Parallel() - log := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})) + log := slog.New(slog.DiscardHandler) store := &fakeConfigStore{ - effectiveConfig: ChannelConfig{ + effectiveConfig: Config{ ID: "cfg-1", BotID: "bot-1", - ChannelType: ChannelType("test"), + Type: Type("test"), Credentials: map[string]any{"botToken": "token"}, UpdatedAt: time.Now(), }, - channelIdentityConfig: ChannelIdentityBinding{ + channelIdentityConfig: IdentityBinding{ ID: "binding-1", Config: map[string]any{"target": "alice"}, }, } reg := NewRegistry() - adapter := &fakeAdapter{channelType: ChannelType("test")} + adapter := &fakeAdapter{channelType: Type("test")} manager := NewManager(log, reg, store, &fakeInboundProcessorIntegration{}) manager.RegisterAdapter(adapter) - err := manager.Send(context.Background(), "bot-1", ChannelType("test"), SendRequest{ + err := manager.Send(context.Background(), "bot-1", Type("test"), SendRequest{ ChannelIdentityID: "user-1", Message: Message{ Text: "hello", @@ -213,21 +212,21 @@ func TestManagerSendUsesBinding(t *testing.T) { func TestManagerReconcileStartsAndStops(t *testing.T) { t.Parallel() - log := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{})) + log := slog.New(slog.DiscardHandler) store := &fakeConfigStore{} reg := NewRegistry() - adapter := &fakeAdapter{channelType: ChannelType("test")} + adapter := &fakeAdapter{channelType: Type("test")} manager := NewManager(log, reg, store, &fakeInboundProcessorIntegration{}) manager.RegisterAdapter(adapter) - cfg := ChannelConfig{ + cfg := Config{ ID: "cfg-1", BotID: "bot-1", - ChannelType: ChannelType("test"), + Type: Type("test"), Credentials: map[string]any{"botToken": "token"}, UpdatedAt: time.Now(), } - manager.reconcile(context.Background(), []ChannelConfig{cfg}) + manager.reconcile(context.Background(), []Config{cfg}) manager.reconcile(context.Background(), nil) adapter.mu.Lock() diff --git a/internal/channel/outbound.go b/internal/channel/outbound.go index 4110b6e3..5b445942 100644 --- a/internal/channel/outbound.go +++ b/internal/channel/outbound.go @@ -2,6 +2,7 @@ package channel import ( "context" + "errors" "fmt" "log/slog" "strings" @@ -11,6 +12,7 @@ import ( // ChunkerMode selects the text chunking strategy. type ChunkerMode string +// ChunkerMode values select how outbound text is split (plain text or markdown-aware). const ( ChunkerModeText ChunkerMode = "text" ChunkerModeMarkdown ChunkerMode = "markdown" @@ -19,6 +21,7 @@ const ( // OutboundOrder controls the delivery order of text and media messages. type OutboundOrder string +// OutboundOrder values: send media before text or text before media. const ( OutboundOrderMediaFirst OutboundOrder = "media_first" OutboundOrderTextFirst OutboundOrder = "text_first" @@ -165,10 +168,7 @@ func splitLongLine(line string, limit int) []string { runes := []rune(line) chunks := make([]string, 0) for start := 0; start < len(runes); start += limit { - end := start + limit - if end > len(runes) { - end = len(runes) - } + end := min(start+limit, len(runes)) segment := strings.TrimSpace(string(runes[start:end])) if segment == "" { continue @@ -180,7 +180,7 @@ func splitLongLine(line string, limit int) []string { // --- Outbound pipeline methods (used by Manager) --- -func (m *Manager) resolveOutboundPolicy(channelType ChannelType) OutboundPolicy { +func (m *Manager) resolveOutboundPolicy(channelType Type) OutboundPolicy { policy, ok := m.registry.GetOutboundPolicy(channelType) if !ok { policy = OutboundPolicy{} @@ -191,7 +191,7 @@ func (m *Manager) resolveOutboundPolicy(channelType ChannelType) OutboundPolicy // buildOutboundMessages splits an outbound message into multiple messages based on the policy. func buildOutboundMessages(msg OutboundMessage, policy OutboundPolicy) ([]OutboundMessage, error) { if msg.Message.IsEmpty() { - return nil, fmt.Errorf("message is required") + return nil, errors.New("message is required") } normalized := normalizeOutboundMessage(msg.Message) chunker := policy.Chunker @@ -246,7 +246,7 @@ func buildOutboundMessages(msg OutboundMessage, policy OutboundPolicy) ([]Outbou } if len(textMessages) == 0 && len(attachmentMessages) == 0 { - return nil, fmt.Errorf("message is required") + return nil, errors.New("message is required") } if policy.MediaOrder == OutboundOrderTextFirst { return append(textMessages, attachmentMessages...), nil @@ -265,7 +265,7 @@ func normalizeOutboundMessage(msg Message) Message { return msg } -func validateMessageCapabilities(registry *Registry, channelType ChannelType, msg Message) error { +func validateMessageCapabilities(registry *Registry, channelType Type, msg Message) error { caps, ok := registry.GetCapabilities(channelType) if !ok { return nil @@ -273,65 +273,65 @@ func validateMessageCapabilities(registry *Registry, channelType ChannelType, ms switch msg.Format { case MessageFormatPlain: if !caps.Text { - return fmt.Errorf("channel does not support plain text") + return errors.New("channel does not support plain text") } case MessageFormatMarkdown: if !caps.Markdown && !caps.RichText { - return fmt.Errorf("channel does not support markdown") + return errors.New("channel does not support markdown") } case MessageFormatRich: if !caps.RichText { - return fmt.Errorf("channel does not support rich text") + return errors.New("channel does not support rich text") } } if len(msg.Parts) > 0 && !caps.RichText { - return fmt.Errorf("channel does not support rich text") + return errors.New("channel does not support rich text") } if len(msg.Attachments) > 0 && !caps.Attachments { - return fmt.Errorf("channel does not support attachments") + return errors.New("channel does not support attachments") } if len(msg.Attachments) > 0 && requiresMedia(msg.Attachments) && !caps.Media { - return fmt.Errorf("channel does not support media") + return errors.New("channel does not support media") } if len(msg.Actions) > 0 && !caps.Buttons { - return fmt.Errorf("channel does not support actions") + return errors.New("channel does not support actions") } if msg.Thread != nil && !caps.Threads { - return fmt.Errorf("channel does not support threads") + return errors.New("channel does not support threads") } if msg.Reply != nil && !caps.Reply { - return fmt.Errorf("channel does not support reply") + return errors.New("channel does not support reply") } if strings.TrimSpace(msg.ID) != "" && !caps.Edit { - return fmt.Errorf("channel does not support edit") + return errors.New("channel does not support edit") } return nil } -func (m *Manager) sendWithConfig(ctx context.Context, sender Sender, cfg ChannelConfig, msg OutboundMessage, policy OutboundPolicy) error { +func (m *Manager) sendWithConfig(ctx context.Context, sender Sender, cfg Config, msg OutboundMessage, policy OutboundPolicy) error { if sender == nil { - return fmt.Errorf("unsupported channel type: %s", cfg.ChannelType) + return fmt.Errorf("unsupported channel type: %s", cfg.Type) } target := strings.TrimSpace(msg.Target) if target == "" { - return fmt.Errorf("target is required") + return errors.New("target is required") } if msg.Message.IsEmpty() { - return fmt.Errorf("message is required") + return errors.New("message is required") } normalized := msg - attachments, err := normalizeAttachmentRefs(msg.Message.Attachments, cfg.ChannelType) + attachments, err := normalizeAttachmentRefs(msg.Message.Attachments, cfg.Type) if err != nil { return err } normalized.Message.Attachments = attachments - if err := validateMessageCapabilities(m.registry, cfg.ChannelType, normalized.Message); err != nil { + if err := validateMessageCapabilities(m.registry, cfg.Type, normalized.Message); err != nil { return err } - editor, _ := m.registry.GetMessageEditor(cfg.ChannelType) + editor, _ := m.registry.GetMessageEditor(cfg.Type) if strings.TrimSpace(normalized.Message.ID) != "" { if editor == nil { - return fmt.Errorf("channel does not support edit") + return errors.New("channel does not support edit") } var lastErr error for i := 0; i < policy.RetryMax; i++ { @@ -342,7 +342,7 @@ func (m *Manager) sendWithConfig(ctx context.Context, sender Sender, cfg Channel lastErr = err if m.logger != nil { m.logger.Warn("edit outbound retry", - slog.String("channel", cfg.ChannelType.String()), + slog.String("channel", cfg.Type.String()), slog.Int("attempt", i+1), slog.Any("error", err)) } @@ -359,7 +359,7 @@ func (m *Manager) sendWithConfig(ctx context.Context, sender Sender, cfg Channel lastErr = err if m.logger != nil { m.logger.Warn("send outbound retry", - slog.String("channel", cfg.ChannelType.String()), + slog.String("channel", cfg.Type.String()), slog.Int("attempt", i+1), slog.Any("error", err)) } @@ -368,7 +368,7 @@ func (m *Manager) sendWithConfig(ctx context.Context, sender Sender, cfg Channel return fmt.Errorf("send outbound failed after retries: %w", lastErr) } -func normalizeAttachmentRefs(attachments []Attachment, defaultPlatform ChannelType) ([]Attachment, error) { +func normalizeAttachmentRefs(attachments []Attachment, defaultPlatform Type) ([]Attachment, error) { if len(attachments) == 0 { return nil, nil } @@ -382,7 +382,7 @@ func normalizeAttachmentRefs(attachments []Attachment, defaultPlatform ChannelTy item.SourcePlatform = defaultPlatform.String() } if item.URL == "" && item.PlatformKey == "" { - return nil, fmt.Errorf("attachment reference is required") + return nil, errors.New("attachment reference is required") } normalized = append(normalized, item) } @@ -401,20 +401,20 @@ func requiresMedia(attachments []Attachment) bool { return false } -func validateStreamEvent(registry *Registry, channelType ChannelType, event StreamEvent) error { +func validateStreamEvent(registry *Registry, channelType Type, event StreamEvent) error { caps, _ := registry.GetCapabilities(channelType) switch event.Type { case StreamEventStatus: if event.Status == "" { - return fmt.Errorf("stream status is required") + return errors.New("stream status is required") } case StreamEventDelta: if !caps.Streaming && !caps.BlockStreaming { - return fmt.Errorf("channel does not support streaming") + return errors.New("channel does not support streaming") } case StreamEventFinal: if event.Final == nil { - return fmt.Errorf("stream final payload is required") + return errors.New("stream final payload is required") } if err := validateMessageCapabilities(registry, channelType, event.Final.Message); err != nil { return err @@ -424,7 +424,7 @@ func validateStreamEvent(registry *Registry, channelType ChannelType, event Stre } case StreamEventError: if strings.TrimSpace(event.Error) == "" { - return fmt.Errorf("stream error is required") + return errors.New("stream error is required") } default: return fmt.Errorf("unsupported stream event type: %s", event.Type) @@ -432,7 +432,7 @@ func validateStreamEvent(registry *Registry, channelType ChannelType, event Stre return nil } -func (m *Manager) newReplySender(cfg ChannelConfig, channelType ChannelType) StreamReplySender { +func (m *Manager) newReplySender(cfg Config, channelType Type) StreamReplySender { sender, _ := m.registry.GetSender(channelType) streamSender, _ := m.registry.GetStreamSender(channelType) return &managerReplySender{ @@ -448,13 +448,13 @@ type managerReplySender struct { manager *Manager sender Sender streamSender StreamSender - channelType ChannelType - config ChannelConfig + channelType Type + config Config } func (s *managerReplySender) Send(ctx context.Context, msg OutboundMessage) error { if s.manager == nil { - return fmt.Errorf("channel manager not configured") + return errors.New("channel manager not configured") } policy := s.manager.resolveOutboundPolicy(s.channelType) outbound, err := buildOutboundMessages(msg, policy) @@ -471,18 +471,18 @@ func (s *managerReplySender) Send(ctx context.Context, msg OutboundMessage) erro func (s *managerReplySender) OpenStream(ctx context.Context, target string, opts StreamOptions) (OutboundStream, error) { if s.manager == nil { - return nil, fmt.Errorf("channel manager not configured") + return nil, errors.New("channel manager not configured") } if s.streamSender == nil { - return nil, fmt.Errorf("channel stream sender not configured") + return nil, errors.New("channel stream sender not configured") } target = strings.TrimSpace(target) if target == "" { - return nil, fmt.Errorf("target is required") + return nil, errors.New("target is required") } caps, _ := s.manager.registry.GetCapabilities(s.channelType) if !caps.Streaming && !caps.BlockStreaming { - return nil, fmt.Errorf("channel does not support streaming") + return nil, errors.New("channel does not support streaming") } stream, err := s.streamSender.OpenStream(ctx, s.config, target, opts) if err != nil { @@ -498,12 +498,12 @@ func (s *managerReplySender) OpenStream(ctx context.Context, target string, opts type managerOutboundStream struct { manager *Manager stream OutboundStream - channelType ChannelType + channelType Type } func (s *managerOutboundStream) Push(ctx context.Context, event StreamEvent) error { if s.manager == nil || s.stream == nil { - return fmt.Errorf("stream is not configured") + return errors.New("stream is not configured") } if err := validateStreamEvent(s.manager.registry, s.channelType, event); err != nil { return err @@ -513,7 +513,7 @@ func (s *managerOutboundStream) Push(ctx context.Context, event StreamEvent) err func (s *managerOutboundStream) Close(ctx context.Context) error { if s.stream == nil { - return fmt.Errorf("stream is not configured") + return errors.New("stream is not configured") } return s.stream.Close(ctx) } diff --git a/internal/channel/processor.go b/internal/channel/processor.go index bb363013..acb84b83 100644 --- a/internal/channel/processor.go +++ b/internal/channel/processor.go @@ -4,5 +4,5 @@ import "context" // InboundProcessor handles inbound messages and replies through the given sender. type InboundProcessor interface { - HandleInbound(ctx context.Context, cfg ChannelConfig, msg InboundMessage, sender StreamReplySender) error + HandleInbound(ctx context.Context, cfg Config, msg InboundMessage, sender StreamReplySender) error } diff --git a/internal/channel/registry.go b/internal/channel/registry.go index 1bfa8015..ebee1876 100644 --- a/internal/channel/registry.go +++ b/internal/channel/registry.go @@ -2,6 +2,7 @@ package channel import ( "context" + "errors" "fmt" "strings" "sync" @@ -13,24 +14,24 @@ import ( // and passed explicitly to components that need it. type Registry struct { mu sync.RWMutex - adapters map[ChannelType]Adapter + adapters map[Type]Adapter } // NewRegistry creates an empty Registry. func NewRegistry() *Registry { return &Registry{ - adapters: map[ChannelType]Adapter{}, + adapters: map[Type]Adapter{}, } } // Register adds an adapter to the registry. func (r *Registry) Register(adapter Adapter) error { if adapter == nil { - return fmt.Errorf("adapter is nil") + return errors.New("adapter is nil") } - ct := normalizeChannelType(adapter.Type().String()) + ct := normalizeType(adapter.Type().String()) if ct == "" { - return fmt.Errorf("channel type is required") + return errors.New("channel type is required") } r.mu.Lock() defer r.mu.Unlock() @@ -49,8 +50,8 @@ func (r *Registry) MustRegister(adapter Adapter) { } // Unregister removes a channel type from the registry. -func (r *Registry) Unregister(channelType ChannelType) bool { - ct := normalizeChannelType(channelType.String()) +func (r *Registry) Unregister(channelType Type) bool { + ct := normalizeType(channelType.String()) if ct == "" { return false } @@ -64,21 +65,21 @@ func (r *Registry) Unregister(channelType ChannelType) bool { } // Get returns the adapter for the given channel type. -func (r *Registry) Get(channelType ChannelType) (Adapter, bool) { - ct := normalizeChannelType(channelType.String()) +func (r *Registry) Get(channelType Type) (Adapter, bool) { + ct := normalizeType(channelType.String()) r.mu.RLock() defer r.mu.RUnlock() adapter, ok := r.adapters[ct] return adapter, ok } -// DirectoryAdapter returns the directory adapter for the given channel type if it implements ChannelDirectoryAdapter. -func (r *Registry) DirectoryAdapter(channelType ChannelType) (ChannelDirectoryAdapter, bool) { +// DirectoryAdapter returns the directory adapter for the given channel type if it implements DirectoryAdapter. +func (r *Registry) DirectoryAdapter(channelType Type) (DirectoryAdapter, bool) { adapter, ok := r.Get(channelType) if !ok { return nil, false } - dir, ok := adapter.(ChannelDirectoryAdapter) + dir, ok := adapter.(DirectoryAdapter) return dir, ok } @@ -94,10 +95,10 @@ func (r *Registry) List() []Adapter { } // Types returns all registered channel types. -func (r *Registry) Types() []ChannelType { +func (r *Registry) Types() []Type { r.mu.RLock() defer r.mu.RUnlock() - items := make([]ChannelType, 0, len(r.adapters)) + items := make([]Type, 0, len(r.adapters)) for ct := range r.adapters { items = append(items, ct) } @@ -107,7 +108,7 @@ func (r *Registry) Types() []ChannelType { // --- Descriptor accessors --- // GetDescriptor returns the descriptor for the given channel type. -func (r *Registry) GetDescriptor(channelType ChannelType) (Descriptor, bool) { +func (r *Registry) GetDescriptor(channelType Type) (Descriptor, bool) { adapter, ok := r.Get(channelType) if !ok { return Descriptor{}, false @@ -125,9 +126,9 @@ func (r *Registry) ListDescriptors() []Descriptor { return items } -// ParseChannelType validates and normalizes a raw string into a registered ChannelType. -func (r *Registry) ParseChannelType(raw string) (ChannelType, error) { - ct := normalizeChannelType(raw) +// ParseChannelType validates and normalizes a raw string into a registered Type. +func (r *Registry) ParseChannelType(raw string) (Type, error) { + ct := normalizeType(raw) if ct == "" { return "", fmt.Errorf("unsupported channel type: %s", raw) } @@ -140,16 +141,16 @@ func (r *Registry) ParseChannelType(raw string) (ChannelType, error) { // --- Capability accessors --- // GetCapabilities returns the capability matrix for the given channel type. -func (r *Registry) GetCapabilities(channelType ChannelType) (ChannelCapabilities, bool) { +func (r *Registry) GetCapabilities(channelType Type) (Capabilities, bool) { desc, ok := r.GetDescriptor(channelType) if !ok { - return ChannelCapabilities{}, false + return Capabilities{}, false } return desc.Capabilities, true } // GetOutboundPolicy returns the outbound policy for the given channel type. -func (r *Registry) GetOutboundPolicy(channelType ChannelType) (OutboundPolicy, bool) { +func (r *Registry) GetOutboundPolicy(channelType Type) (OutboundPolicy, bool) { desc, ok := r.GetDescriptor(channelType) if !ok { return OutboundPolicy{}, false @@ -158,7 +159,7 @@ func (r *Registry) GetOutboundPolicy(channelType ChannelType) (OutboundPolicy, b } // GetConfigSchema returns the configuration schema for the given channel type. -func (r *Registry) GetConfigSchema(channelType ChannelType) (ConfigSchema, bool) { +func (r *Registry) GetConfigSchema(channelType Type) (ConfigSchema, bool) { desc, ok := r.GetDescriptor(channelType) if !ok { return ConfigSchema{}, false @@ -167,7 +168,7 @@ func (r *Registry) GetConfigSchema(channelType ChannelType) (ConfigSchema, bool) } // GetUserConfigSchema returns the user-binding configuration schema. -func (r *Registry) GetUserConfigSchema(channelType ChannelType) (ConfigSchema, bool) { +func (r *Registry) GetUserConfigSchema(channelType Type) (ConfigSchema, bool) { desc, ok := r.GetDescriptor(channelType) if !ok { return ConfigSchema{}, false @@ -176,7 +177,7 @@ func (r *Registry) GetUserConfigSchema(channelType ChannelType) (ConfigSchema, b } // IsConfigless reports whether the channel type operates without per-bot configuration. -func (r *Registry) IsConfigless(channelType ChannelType) bool { +func (r *Registry) IsConfigless(channelType Type) bool { desc, ok := r.GetDescriptor(channelType) if !ok { return false @@ -187,7 +188,7 @@ func (r *Registry) IsConfigless(channelType ChannelType) bool { // --- Sender / Receiver accessors --- // GetSender returns the Sender for the given channel type, or nil if unsupported. -func (r *Registry) GetSender(channelType ChannelType) (Sender, bool) { +func (r *Registry) GetSender(channelType Type) (Sender, bool) { adapter, ok := r.Get(channelType) if !ok { return nil, false @@ -197,7 +198,7 @@ func (r *Registry) GetSender(channelType ChannelType) (Sender, bool) { } // GetStreamSender returns the StreamSender for the given channel type, or nil if unsupported. -func (r *Registry) GetStreamSender(channelType ChannelType) (StreamSender, bool) { +func (r *Registry) GetStreamSender(channelType Type) (StreamSender, bool) { adapter, ok := r.Get(channelType) if !ok { return nil, false @@ -207,7 +208,7 @@ func (r *Registry) GetStreamSender(channelType ChannelType) (StreamSender, bool) } // GetMessageEditor returns the MessageEditor for the given channel type, or nil if unsupported. -func (r *Registry) GetMessageEditor(channelType ChannelType) (MessageEditor, bool) { +func (r *Registry) GetMessageEditor(channelType Type) (MessageEditor, bool) { adapter, ok := r.Get(channelType) if !ok { return nil, false @@ -217,7 +218,7 @@ func (r *Registry) GetMessageEditor(channelType ChannelType) (MessageEditor, boo } // GetReactor returns the Reactor for the given channel type, or nil if unsupported. -func (r *Registry) GetReactor(channelType ChannelType) (Reactor, bool) { +func (r *Registry) GetReactor(channelType Type) (Reactor, bool) { adapter, ok := r.Get(channelType) if !ok { return nil, false @@ -227,7 +228,7 @@ func (r *Registry) GetReactor(channelType ChannelType) (Reactor, bool) { } // GetReceiver returns the Receiver for the given channel type, or nil if unsupported. -func (r *Registry) GetReceiver(channelType ChannelType) (Receiver, bool) { +func (r *Registry) GetReceiver(channelType Type) (Receiver, bool) { adapter, ok := r.Get(channelType) if !ok { return nil, false @@ -237,7 +238,7 @@ func (r *Registry) GetReceiver(channelType ChannelType) (Receiver, bool) { } // GetProcessingStatusNotifier returns the ProcessingStatusNotifier for the given channel type, or nil if unsupported. -func (r *Registry) GetProcessingStatusNotifier(channelType ChannelType) (ProcessingStatusNotifier, bool) { +func (r *Registry) GetProcessingStatusNotifier(channelType Type) (ProcessingStatusNotifier, bool) { adapter, ok := r.Get(channelType) if !ok { return nil, false @@ -247,7 +248,7 @@ func (r *Registry) GetProcessingStatusNotifier(channelType ChannelType) (Process } // DiscoverSelf calls the SelfDiscoverer for the given channel type if supported. -func (r *Registry) DiscoverSelf(ctx context.Context, channelType ChannelType, credentials map[string]any) (map[string]any, string, error) { +func (r *Registry) DiscoverSelf(ctx context.Context, channelType Type, credentials map[string]any) (map[string]any, string, error) { adapter, ok := r.Get(channelType) if !ok { return nil, "", fmt.Errorf("unsupported channel type: %s", channelType) @@ -262,7 +263,7 @@ func (r *Registry) DiscoverSelf(ctx context.Context, channelType ChannelType, cr // --- Dispatch methods (replace former global functions in config.go / target.go) --- // NormalizeConfig validates and normalizes a channel configuration map. -func (r *Registry) NormalizeConfig(channelType ChannelType, raw map[string]any) (map[string]any, error) { +func (r *Registry) NormalizeConfig(channelType Type, raw map[string]any) (map[string]any, error) { if raw == nil { raw = map[string]any{} } @@ -277,7 +278,7 @@ func (r *Registry) NormalizeConfig(channelType ChannelType, raw map[string]any) } // NormalizeUserConfig validates and normalizes a user-channel binding configuration. -func (r *Registry) NormalizeUserConfig(channelType ChannelType, raw map[string]any) (map[string]any, error) { +func (r *Registry) NormalizeUserConfig(channelType Type, raw map[string]any) (map[string]any, error) { if raw == nil { raw = map[string]any{} } @@ -292,7 +293,7 @@ func (r *Registry) NormalizeUserConfig(channelType ChannelType, raw map[string]a } // ResolveTargetFromUserConfig derives a delivery target from a user-channel binding. -func (r *Registry) ResolveTargetFromUserConfig(channelType ChannelType, config map[string]any) (string, error) { +func (r *Registry) ResolveTargetFromUserConfig(channelType Type, config map[string]any) (string, error) { adapter, ok := r.Get(channelType) if !ok { return "", fmt.Errorf("unsupported channel type: %s", channelType) @@ -304,7 +305,7 @@ func (r *Registry) ResolveTargetFromUserConfig(channelType ChannelType, config m } // NormalizeTarget applies the channel-specific target normalization. -func (r *Registry) NormalizeTarget(channelType ChannelType, raw string) (string, bool) { +func (r *Registry) NormalizeTarget(channelType Type, raw string) (string, bool) { adapter, ok := r.Get(channelType) if !ok { return strings.TrimSpace(raw), false @@ -320,7 +321,7 @@ func (r *Registry) NormalizeTarget(channelType ChannelType, raw string) (string, } // MatchUserBinding reports whether the given binding config matches the criteria. -func (r *Registry) MatchUserBinding(channelType ChannelType, config map[string]any, criteria BindingCriteria) bool { +func (r *Registry) MatchUserBinding(channelType Type, config map[string]any, criteria BindingCriteria) bool { adapter, ok := r.Get(channelType) if !ok { return false @@ -332,7 +333,7 @@ func (r *Registry) MatchUserBinding(channelType ChannelType, config map[string]a } // BuildUserBindingConfig constructs a user-channel binding config from an Identity. -func (r *Registry) BuildUserBindingConfig(channelType ChannelType, identity Identity) map[string]any { +func (r *Registry) BuildUserBindingConfig(channelType Type, identity Identity) map[string]any { adapter, ok := r.Get(channelType) if !ok { return map[string]any{} @@ -343,10 +344,10 @@ func (r *Registry) BuildUserBindingConfig(channelType ChannelType, identity Iden return map[string]any{} } -func normalizeChannelType(raw string) ChannelType { +func normalizeType(raw string) Type { normalized := strings.TrimSpace(strings.ToLower(raw)) if normalized == "" { return "" } - return ChannelType(normalized) + return Type(normalized) } diff --git a/internal/channel/registry_test.go b/internal/channel/registry_test.go index c27c3875..fa9b67b4 100644 --- a/internal/channel/registry_test.go +++ b/internal/channel/registry_test.go @@ -7,30 +7,30 @@ import ( "github.com/memohai/memoh/internal/channel" ) -const dirTestChannelType = channel.ChannelType("dir-test") +const dirTestChannelType = channel.Type("dir-test") -// dirMockAdapter implements Adapter and ChannelDirectoryAdapter for registry DirectoryAdapter tests. +// dirMockAdapter implements Adapter and DirectoryAdapter for registry DirectoryAdapter tests. type dirMockAdapter struct{} -func (a *dirMockAdapter) Type() channel.ChannelType { return dirTestChannelType } +func (a *dirMockAdapter) Type() channel.Type { return dirTestChannelType } func (a *dirMockAdapter) Descriptor() channel.Descriptor { return channel.Descriptor{Type: dirTestChannelType, DisplayName: "DirTest"} } -func (a *dirMockAdapter) ListPeers(ctx context.Context, cfg channel.ChannelConfig, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { +func (a *dirMockAdapter) ListPeers(_ context.Context, _ channel.Config, _ channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { return nil, nil } -func (a *dirMockAdapter) ListGroups(ctx context.Context, cfg channel.ChannelConfig, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { +func (a *dirMockAdapter) ListGroups(_ context.Context, _ channel.Config, _ channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { return nil, nil } -func (a *dirMockAdapter) ListGroupMembers(ctx context.Context, cfg channel.ChannelConfig, groupID string, query channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { +func (a *dirMockAdapter) ListGroupMembers(_ context.Context, _ channel.Config, _ string, _ channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { return nil, nil } -func (a *dirMockAdapter) ResolveEntry(ctx context.Context, cfg channel.ChannelConfig, input string, kind channel.DirectoryEntryKind) (channel.DirectoryEntry, error) { +func (a *dirMockAdapter) ResolveEntry(_ context.Context, _ channel.Config, _ string, _ channel.DirectoryEntryKind) (channel.DirectoryEntry, error) { return channel.DirectoryEntry{}, nil } @@ -56,7 +56,7 @@ func TestDirectoryAdapter_Supported(t *testing.T) { func TestDirectoryAdapter_UnknownType(t *testing.T) { t.Parallel() reg := channel.NewRegistry() - dir, ok := reg.DirectoryAdapter(channel.ChannelType("unknown")) + dir, ok := reg.DirectoryAdapter(channel.Type("unknown")) if ok || dir != nil { t.Fatalf("DirectoryAdapter(unknown) = (%v, %v), want (nil, false)", dir, ok) } diff --git a/internal/channel/route/service.go b/internal/channel/route/service.go index 7436a9f5..75f3713a 100644 --- a/internal/channel/route/service.go +++ b/internal/channel/route/service.go @@ -1,8 +1,10 @@ +// Package route provides conversation routing and channel-route management. package route import ( "context" "encoding/json" + "errors" "fmt" "log/slog" "strings" @@ -180,7 +182,7 @@ func (s *DBService) ResolveConversation(ctx context.Context, input ResolveInput) } if s.conversation == nil { - return ResolveConversationResult{}, fmt.Errorf("conversation service not configured") + return ResolveConversationResult{}, errors.New("conversation service not configured") } kind := determineConversationKind(input.ThreadID, input.ConversationType) diff --git a/internal/channel/schema.go b/internal/channel/schema.go index 2f818270..ac6bd14d 100644 --- a/internal/channel/schema.go +++ b/internal/channel/schema.go @@ -3,6 +3,7 @@ package channel // FieldType enumerates the supported configuration field types. type FieldType string +// FieldType values for configuration schema fields. const ( FieldString FieldType = "string" FieldSecret FieldType = "secret" diff --git a/internal/channel/service.go b/internal/channel/service.go index f3ba5c1c..a5ed8324 100644 --- a/internal/channel/service.go +++ b/internal/channel/service.go @@ -30,24 +30,24 @@ func NewService(queries *sqlc.Queries, registry *Registry) *Service { } // UpsertConfig creates or updates a bot's channel configuration. -func (s *Service) UpsertConfig(ctx context.Context, botID string, channelType ChannelType, req UpsertConfigRequest) (ChannelConfig, error) { +func (s *Service) UpsertConfig(ctx context.Context, botID string, channelType Type, req UpsertConfigRequest) (Config, error) { if s.queries == nil { - return ChannelConfig{}, fmt.Errorf("channel queries not configured") + return Config{}, errors.New("channel queries not configured") } if channelType == "" { - return ChannelConfig{}, fmt.Errorf("channel type is required") + return Config{}, errors.New("channel type is required") } normalized, err := s.registry.NormalizeConfig(channelType, req.Credentials) if err != nil { - return ChannelConfig{}, err + return Config{}, err } credentialsPayload, err := json.Marshal(normalized) if err != nil { - return ChannelConfig{}, err + return Config{}, err } botUUID, err := db.ParseUUID(botID) if err != nil { - return ChannelConfig{}, err + return Config{}, err } selfIdentity := req.SelfIdentity if selfIdentity == nil { @@ -66,7 +66,7 @@ func (s *Service) UpsertConfig(ctx context.Context, botID string, channelType Ch } selfPayload, err := json.Marshal(selfIdentity) if err != nil { - return ChannelConfig{}, err + return Config{}, err } routing := req.Routing if routing == nil { @@ -74,11 +74,11 @@ func (s *Service) UpsertConfig(ctx context.Context, botID string, channelType Ch } routingPayload, err := json.Marshal(routing) if err != nil { - return ChannelConfig{}, err + return Config{}, err } - status, err := normalizeChannelConfigStatus(req.Status) + status, err := normalizeConfigStatus(req.Status) if err != nil { - return ChannelConfig{}, err + return Config{}, err } verifiedAt := pgtype.Timestamptz{Valid: false} if req.VerifiedAt != nil { @@ -99,12 +99,12 @@ func (s *Service) UpsertConfig(ctx context.Context, botID string, channelType Ch VerifiedAt: verifiedAt, }) if err != nil { - return ChannelConfig{}, err + return Config{}, err } - return normalizeChannelConfig(row) + return normalizeConfig(row) } -func normalizeChannelConfigStatus(raw string) (string, error) { +func normalizeConfigStatus(raw string) (string, error) { status := strings.ToLower(strings.TrimSpace(raw)) if status == "" { return "pending", nil @@ -122,24 +122,24 @@ func normalizeChannelConfigStatus(raw string) (string, error) { } // UpsertChannelIdentityConfig creates or updates a channel identity's channel binding. -func (s *Service) UpsertChannelIdentityConfig(ctx context.Context, channelIdentityID string, channelType ChannelType, req UpsertChannelIdentityConfigRequest) (ChannelIdentityBinding, error) { +func (s *Service) UpsertChannelIdentityConfig(ctx context.Context, channelIdentityID string, channelType Type, req UpsertChannelIdentityConfigRequest) (IdentityBinding, error) { if s.queries == nil { - return ChannelIdentityBinding{}, fmt.Errorf("channel queries not configured") + return IdentityBinding{}, errors.New("channel queries not configured") } if channelType == "" { - return ChannelIdentityBinding{}, fmt.Errorf("channel type is required") + return IdentityBinding{}, errors.New("channel type is required") } normalized, err := s.registry.NormalizeUserConfig(channelType, req.Config) if err != nil { - return ChannelIdentityBinding{}, err + return IdentityBinding{}, err } payload, err := json.Marshal(normalized) if err != nil { - return ChannelIdentityBinding{}, err + return IdentityBinding{}, err } pgChannelIdentityID, err := db.ParseUUID(channelIdentityID) if err != nil { - return ChannelIdentityBinding{}, err + return IdentityBinding{}, err } row, err := s.queries.UpsertUserChannelBinding(ctx, sqlc.UpsertUserChannelBindingParams{ UserID: pgChannelIdentityID, @@ -147,59 +147,59 @@ func (s *Service) UpsertChannelIdentityConfig(ctx context.Context, channelIdenti Config: payload, }) if err != nil { - return ChannelIdentityBinding{}, err + return IdentityBinding{}, err } - return normalizeChannelIdentityBinding(row) + return normalizeIdentityBinding(row) } // ResolveEffectiveConfig returns the active channel configuration for a bot. // For configless channel types, a synthetic config is returned. -func (s *Service) ResolveEffectiveConfig(ctx context.Context, botID string, channelType ChannelType) (ChannelConfig, error) { +func (s *Service) ResolveEffectiveConfig(ctx context.Context, botID string, channelType Type) (Config, error) { if s.queries == nil { - return ChannelConfig{}, fmt.Errorf("channel queries not configured") + return Config{}, errors.New("channel queries not configured") } if channelType == "" { - return ChannelConfig{}, fmt.Errorf("channel type is required") + return Config{}, errors.New("channel type is required") } if s.registry.IsConfigless(channelType) { - return ChannelConfig{ - ID: channelType.String() + ":" + strings.TrimSpace(botID), - BotID: strings.TrimSpace(botID), - ChannelType: channelType, + return Config{ + ID: channelType.String() + ":" + strings.TrimSpace(botID), + BotID: strings.TrimSpace(botID), + Type: channelType, }, nil } botUUID, err := db.ParseUUID(botID) if err != nil { - return ChannelConfig{}, err + return Config{}, err } row, err := s.queries.GetBotChannelConfig(ctx, sqlc.GetBotChannelConfigParams{ BotID: botUUID, ChannelType: channelType.String(), }) if err == nil { - return normalizeChannelConfig(row) + return normalizeConfig(row) } if !errors.Is(err, pgx.ErrNoRows) { - return ChannelConfig{}, err + return Config{}, err } - return ChannelConfig{}, fmt.Errorf("channel config not found") + return Config{}, errors.New("channel config not found") } // ListConfigsByType returns all channel configurations of the given type. -func (s *Service) ListConfigsByType(ctx context.Context, channelType ChannelType) ([]ChannelConfig, error) { +func (s *Service) ListConfigsByType(ctx context.Context, channelType Type) ([]Config, error) { if s.queries == nil { - return nil, fmt.Errorf("channel queries not configured") + return nil, errors.New("channel queries not configured") } if s.registry.IsConfigless(channelType) { - return []ChannelConfig{}, nil + return []Config{}, nil } rows, err := s.queries.ListBotChannelConfigsByType(ctx, channelType.String()) if err != nil { return nil, err } - items := make([]ChannelConfig, 0, len(rows)) + items := make([]Config, 0, len(rows)) for _, row := range rows { - item, err := normalizeChannelConfig(row) + item, err := normalizeConfig(row) if err != nil { return nil, err } @@ -209,16 +209,16 @@ func (s *Service) ListConfigsByType(ctx context.Context, channelType ChannelType } // GetChannelIdentityConfig returns the channel identity's channel binding for the given channel type. -func (s *Service) GetChannelIdentityConfig(ctx context.Context, channelIdentityID string, channelType ChannelType) (ChannelIdentityBinding, error) { +func (s *Service) GetChannelIdentityConfig(ctx context.Context, channelIdentityID string, channelType Type) (IdentityBinding, error) { if s.queries == nil { - return ChannelIdentityBinding{}, fmt.Errorf("channel queries not configured") + return IdentityBinding{}, errors.New("channel queries not configured") } if channelType == "" { - return ChannelIdentityBinding{}, fmt.Errorf("channel type is required") + return IdentityBinding{}, errors.New("channel type is required") } pgChannelIdentityID, err := db.ParseUUID(channelIdentityID) if err != nil { - return ChannelIdentityBinding{}, err + return IdentityBinding{}, err } row, err := s.queries.GetUserChannelBinding(ctx, sqlc.GetUserChannelBindingParams{ UserID: pgChannelIdentityID, @@ -226,17 +226,17 @@ func (s *Service) GetChannelIdentityConfig(ctx context.Context, channelIdentityI }) if err != nil { if errors.Is(err, pgx.ErrNoRows) { - return ChannelIdentityBinding{}, fmt.Errorf("channel user config not found") + return IdentityBinding{}, errors.New("channel user config not found") } - return ChannelIdentityBinding{}, err + return IdentityBinding{}, err } config, err := DecodeConfigMap(row.Config) if err != nil { - return ChannelIdentityBinding{}, err + return IdentityBinding{}, err } - return ChannelIdentityBinding{ + return IdentityBinding{ ID: row.ID.String(), - ChannelType: ChannelType(row.ChannelType), + Type: Type(row.ChannelType), ChannelIdentityID: row.UserID.String(), Config: config, CreatedAt: db.TimeFromPg(row.CreatedAt), @@ -245,17 +245,17 @@ func (s *Service) GetChannelIdentityConfig(ctx context.Context, channelIdentityI } // ListChannelIdentityConfigsByType returns all channel identity bindings for the given channel type. -func (s *Service) ListChannelIdentityConfigsByType(ctx context.Context, channelType ChannelType) ([]ChannelIdentityBinding, error) { +func (s *Service) ListChannelIdentityConfigsByType(ctx context.Context, channelType Type) ([]IdentityBinding, error) { if s.queries == nil { - return nil, fmt.Errorf("channel queries not configured") + return nil, errors.New("channel queries not configured") } rows, err := s.queries.ListUserChannelBindingsByPlatform(ctx, channelType.String()) if err != nil { return nil, err } - items := make([]ChannelIdentityBinding, 0, len(rows)) + items := make([]IdentityBinding, 0, len(rows)) for _, row := range rows { - item, err := normalizeChannelIdentityBinding(row) + item, err := normalizeIdentityBinding(row) if err != nil { return nil, err } @@ -264,8 +264,8 @@ func (s *Service) ListChannelIdentityConfigsByType(ctx context.Context, channelT return items, nil } -// ResolveChannelIdentityBinding finds the channel identity ID whose channel binding matches the given criteria. -func (s *Service) ResolveChannelIdentityBinding(ctx context.Context, channelType ChannelType, criteria BindingCriteria) (string, error) { +// ResolveIdentityBinding finds the channel identity ID whose channel binding matches the given criteria. +func (s *Service) ResolveIdentityBinding(ctx context.Context, channelType Type, criteria BindingCriteria) (string, error) { rows, err := s.ListChannelIdentityConfigsByType(ctx, channelType) if err != nil { return "", err @@ -278,21 +278,21 @@ func (s *Service) ResolveChannelIdentityBinding(ctx context.Context, channelType return row.ChannelIdentityID, nil } } - return "", fmt.Errorf("channel user binding not found") + return "", errors.New("channel user binding not found") } -func normalizeChannelConfig(row sqlc.BotChannelConfig) (ChannelConfig, error) { +func normalizeConfig(row sqlc.BotChannelConfig) (Config, error) { credentials, err := DecodeConfigMap(row.Credentials) if err != nil { - return ChannelConfig{}, err + return Config{}, err } selfIdentity, err := DecodeConfigMap(row.SelfIdentity) if err != nil { - return ChannelConfig{}, err + return Config{}, err } routing, err := DecodeConfigMap(row.Routing) if err != nil { - return ChannelConfig{}, err + return Config{}, err } verifiedAt := time.Time{} if row.VerifiedAt.Valid { @@ -302,10 +302,10 @@ func normalizeChannelConfig(row sqlc.BotChannelConfig) (ChannelConfig, error) { if row.ExternalIdentity.Valid { externalIdentity = strings.TrimSpace(row.ExternalIdentity.String) } - return ChannelConfig{ + return Config{ ID: row.ID.String(), BotID: row.BotID.String(), - ChannelType: ChannelType(row.ChannelType), + Type: Type(row.ChannelType), Credentials: credentials, ExternalIdentity: externalIdentity, SelfIdentity: selfIdentity, @@ -317,14 +317,14 @@ func normalizeChannelConfig(row sqlc.BotChannelConfig) (ChannelConfig, error) { }, nil } -func normalizeChannelIdentityBinding(row sqlc.UserChannelBinding) (ChannelIdentityBinding, error) { +func normalizeIdentityBinding(row sqlc.UserChannelBinding) (IdentityBinding, error) { config, err := DecodeConfigMap(row.Config) if err != nil { - return ChannelIdentityBinding{}, err + return IdentityBinding{}, err } - return ChannelIdentityBinding{ + return IdentityBinding{ ID: row.ID.String(), - ChannelType: ChannelType(row.ChannelType), + Type: Type(row.ChannelType), ChannelIdentityID: row.UserID.String(), Config: config, CreatedAt: db.TimeFromPg(row.CreatedAt), diff --git a/internal/channel/types.go b/internal/channel/types.go index d54e88a6..52badebd 100644 --- a/internal/channel/types.go +++ b/internal/channel/types.go @@ -7,11 +7,11 @@ import ( "time" ) -// ChannelType identifies a messaging platform (e.g., "telegram", "feishu"). -type ChannelType string +// Type identifies a messaging platform (e.g., "telegram", "feishu"). +type Type string // String returns the channel type as a plain string. -func (c ChannelType) String() string { +func (c Type) String() string { return string(c) } @@ -41,7 +41,7 @@ type Conversation struct { // InboundMessage is a message received from an external channel. type InboundMessage struct { - Channel ChannelType + Channel Type Message Message BotID string ReplyTarget string @@ -89,6 +89,7 @@ type OutboundMessage struct { // StreamEventType defines the kind of outbound stream event. type StreamEventType string +// StreamEventType values for outbound stream events (status, delta, final, error). const ( StreamEventStatus StreamEventType = "status" StreamEventDelta StreamEventType = "delta" @@ -99,6 +100,7 @@ const ( // StreamStatus indicates the lifecycle state of a streaming reply. type StreamStatus string +// StreamStatus values for stream lifecycle (started, completed, failed). const ( StreamStatusStarted StreamStatus = "started" StreamStatusCompleted StreamStatus = "completed" @@ -130,6 +132,7 @@ type StreamOptions struct { // MessageFormat indicates how the message text should be rendered. type MessageFormat string +// MessageFormat values for how message text is rendered (plain, markdown, rich). const ( MessageFormatPlain MessageFormat = "plain" MessageFormatMarkdown MessageFormat = "markdown" @@ -139,6 +142,7 @@ const ( // MessagePartType identifies the kind of a rich-text message part. type MessagePartType string +// MessagePartType values for rich message parts (text, link, code_block, mention, emoji). const ( MessagePartText MessagePartType = "text" MessagePartLink MessagePartType = "link" @@ -150,6 +154,7 @@ const ( // MessageTextStyle describes inline formatting for a text part. type MessageTextStyle string +// MessageTextStyle values for inline styles (bold, italic, strikethrough, code). const ( MessageStyleBold MessageTextStyle = "bold" MessageStyleItalic MessageTextStyle = "italic" @@ -172,6 +177,7 @@ type MessagePart struct { // AttachmentType classifies the kind of binary attachment. type AttachmentType string +// AttachmentType values for message attachments (image, audio, video, voice, file, gif). const ( AttachmentImage AttachmentType = "image" AttachmentAudio AttachmentType = "audio" @@ -304,11 +310,11 @@ func BindingCriteriaFromIdentity(identity Identity) BindingCriteria { } } -// ChannelConfig holds the configuration for a bot's channel integration. -type ChannelConfig struct { +// Config holds the configuration for a bot's channel integration. +type Config struct { ID string `json:"id"` BotID string `json:"bot_id"` - ChannelType ChannelType `json:"channel_type"` + Type Type `json:"channel_type"` Credentials map[string]any `json:"credentials"` ExternalIdentity string `json:"external_identity"` SelfIdentity map[string]any `json:"self_identity"` @@ -319,10 +325,10 @@ type ChannelConfig struct { UpdatedAt time.Time `json:"updated_at"` } -// ChannelIdentityBinding represents a channel identity's binding to a specific channel type. -type ChannelIdentityBinding struct { +// IdentityBinding represents a channel identity's binding to a specific channel type. +type IdentityBinding struct { ID string `json:"id"` - ChannelType ChannelType `json:"channel_type"` + Type Type `json:"channel_type"` ChannelIdentityID string `json:"channel_identity_id"` Config map[string]any `json:"config"` CreatedAt time.Time `json:"created_at"` diff --git a/internal/config/config.go b/internal/config/config.go index c4a5ea49..9410a61d 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -1,12 +1,14 @@ +// Package config loads and exposes application configuration (TOML). package config import ( - "fmt" "os" + "strconv" "github.com/BurntSushi/toml" ) +// Default configuration values used when a field is missing in TOML. const ( DefaultConfigPath = "config.toml" DefaultHTTPAddr = ":8080" @@ -25,6 +27,7 @@ const ( DefaultQdrantCollection = "memory" ) +// Config is the root application configuration loaded from TOML. type Config struct { Log LogConfig `toml:"log"` Server ServerConfig `toml:"server"` @@ -37,31 +40,37 @@ type Config struct { AgentGateway AgentGatewayConfig `toml:"agent_gateway"` } +// LogConfig holds logging level and format (e.g. level=info, format=text). type LogConfig struct { Level string `toml:"level"` Format string `toml:"format"` } +// ServerConfig holds the HTTP server listen address. type ServerConfig struct { Addr string `toml:"addr"` } +// AdminConfig holds the initial admin account (username, password, email). type AdminConfig struct { Username string `toml:"username"` Password string `toml:"password"` Email string `toml:"email"` } +// AuthConfig holds JWT secret and token expiry (e.g. 24h). type AuthConfig struct { JWTSecret string `toml:"jwt_secret"` JWTExpiresIn string `toml:"jwt_expires_in"` } +// ContainerdConfig holds the containerd socket path and namespace. type ContainerdConfig struct { SocketPath string `toml:"socket_path"` Namespace string `toml:"namespace"` } +// MCPConfig holds MCP container image, snapshotter, and data paths. type MCPConfig struct { Image string `toml:"image"` Snapshotter string `toml:"snapshotter"` @@ -69,6 +78,7 @@ type MCPConfig struct { DataMount string `toml:"data_mount"` } +// PostgresConfig holds PostgreSQL connection parameters. type PostgresConfig struct { Host string `toml:"host"` Port int `toml:"port"` @@ -78,6 +88,7 @@ type PostgresConfig struct { SSLMode string `toml:"sslmode"` } +// QdrantConfig holds Qdrant base URL, API key, collection name, and timeout. type QdrantConfig struct { BaseURL string `toml:"base_url"` APIKey string `toml:"api_key"` @@ -85,11 +96,13 @@ type QdrantConfig struct { TimeoutSeconds int `toml:"timeout_seconds"` } +// AgentGatewayConfig holds the agent gateway host and port. type AgentGatewayConfig struct { Host string `toml:"host"` Port int `toml:"port"` } +// BaseURL returns the agent gateway base URL (e.g. http://127.0.0.1:8081) from host and port. func (c AgentGatewayConfig) BaseURL() string { host := c.Host if host == "" { @@ -99,9 +112,10 @@ func (c AgentGatewayConfig) BaseURL() string { if port == 0 { port = 8081 } - return "http://" + host + ":" + fmt.Sprint(port) + return "http://" + host + ":" + strconv.Itoa(port) } +// Load reads and parses the TOML config file at path and applies default values for missing fields. func Load(path string) (Config, error) { cfg := Config{ Log: LogConfig{ diff --git a/internal/containerd/factory.go b/internal/containerd/factory.go index 36cce0ff..c5d3636d 100644 --- a/internal/containerd/factory.go +++ b/internal/containerd/factory.go @@ -1,3 +1,4 @@ +// Package containerd provides the containerd client factory and service abstraction. package containerd import ( @@ -6,19 +7,23 @@ import ( containerd "github.com/containerd/containerd/v2/client" ) +// Default socket path and namespace when not set in config. const ( DefaultSocketPath = "/run/containerd/containerd.sock" DefaultNamespace = "default" ) +// ClientFactory creates a containerd client (e.g. from socket path). type ClientFactory interface { New(ctx context.Context) (*containerd.Client, error) } +// DefaultClientFactory creates a client using SocketPath (or DefaultSocketPath if empty). type DefaultClientFactory struct { SocketPath string } +// New returns a new containerd client connected to the configured socket. func (f DefaultClientFactory) New(_ context.Context) (*containerd.Client, error) { socket := f.SocketPath if socket == "" { diff --git a/internal/containerd/mount.go b/internal/containerd/mount.go index 2899131b..4174c3f7 100644 --- a/internal/containerd/mount.go +++ b/internal/containerd/mount.go @@ -9,13 +9,14 @@ import ( "github.com/containerd/containerd/v2/core/mount" ) +// MountedSnapshot holds the mount directory, container info, and an Unmount function to release it. type MountedSnapshot struct { Dir string Info containers.Container Unmount func() error } -// MountContainerSnapshot mounts the active snapshot for a container. +// MountContainerSnapshot mounts the active snapshot for a container into a temp dir; call Unmount when done. func MountContainerSnapshot(ctx context.Context, service Service, containerID string) (*MountedSnapshot, error) { if containerID == "" { return nil, ErrInvalidArgument diff --git a/internal/containerd/network.go b/internal/containerd/network.go index c895dbe4..2d2f057f 100644 --- a/internal/containerd/network.go +++ b/internal/containerd/network.go @@ -8,12 +8,14 @@ import ( "os/exec" "path/filepath" "runtime" + "strconv" "strings" "github.com/containerd/containerd/v2/client" gocni "github.com/containerd/go-cni" ) +// Default CNI config and binary directories on Linux. const ( defaultCNIConfDir = "/etc/cni/net.d" defaultCNIBinDir = "/opt/cni/bin" @@ -45,7 +47,7 @@ func SetupNetwork(ctx context.Context, task client.Task, containerID string) err if _, err := os.Stat(defaultCNIBinDir); err != nil { return fmt.Errorf("cni bin dir missing: %s: %w", defaultCNIBinDir, err) } - netnsPath := filepath.Join("/proc", fmt.Sprint(pid), "ns", "net") + netnsPath := filepath.Join("/proc", strconv.FormatUint(uint64(pid), 10), "ns", "net") if _, err := os.Stat(netnsPath); err != nil { return fmt.Errorf("netns not found: %s: %w", netnsPath, err) } @@ -85,7 +87,7 @@ func setupNetworkWithCLI(ctx context.Context, containerID string, pid uint32) er "memoh-cli", "cni-setup", "--id", containerID, - "--pid", fmt.Sprint(pid), + "--pid", strconv.FormatUint(uint64(pid), 10), "--conf-dir", defaultCNIConfDir, "--bin-dir", defaultCNIBinDir, } @@ -144,7 +146,7 @@ func RemoveNetwork(ctx context.Context, task client.Task, containerID string) er return fmt.Errorf("cni bin dir missing: %s: %w", defaultCNIBinDir, err) } - netnsPath := filepath.Join("/proc", fmt.Sprint(pid), "ns", "net") + netnsPath := filepath.Join("/proc", strconv.FormatUint(uint64(pid), 10), "ns", "net") if _, err := os.Stat(netnsPath); err != nil { return fmt.Errorf("netns not found: %s: %w", netnsPath, err) } @@ -173,7 +175,7 @@ func removeNetworkWithCLI(ctx context.Context, containerID string, pid uint32) e "memoh-cli", "cni-remove", "--id", containerID, - "--pid", fmt.Sprint(pid), + "--pid", strconv.FormatUint(uint64(pid), 10), "--conf-dir", defaultCNIConfDir, "--bin-dir", defaultCNIBinDir, } diff --git a/internal/containerd/resolv.go b/internal/containerd/resolve.go similarity index 86% rename from internal/containerd/resolv.go rename to internal/containerd/resolve.go index 8455c0c1..8c4afe4c 100644 --- a/internal/containerd/resolv.go +++ b/internal/containerd/resolve.go @@ -1,6 +1,7 @@ package containerd import ( + "errors" "fmt" "os" "os/exec" @@ -9,6 +10,7 @@ import ( "strings" ) +// Paths used when resolving /etc/resolv.conf source. const ( systemdResolvConf = "/run/systemd/resolve/resolv.conf" fallbackResolv = "nameserver 1.1.1.1\nnameserver 8.8.8.8\n" @@ -60,14 +62,15 @@ func limaFileExists(path string) (bool, error) { "-f", path, ) - if err := cmd.Run(); err == nil { + err := cmd.Run() + if err == nil { return true, nil - } else if exitErr, ok := err.(*exec.ExitError); ok { + } + var exitErr *exec.ExitError + if errors.As(err, &exitErr) { if exitErr.ExitCode() == 1 { return false, nil } - return false, fmt.Errorf("lima test failed for %s: %w", path, err) - } else { - return false, fmt.Errorf("lima test failed for %s: %w", path, err) } + return false, fmt.Errorf("lima test failed for %s: %w", path, err) } diff --git a/internal/containerd/service.go b/internal/containerd/service.go index c825905a..560130df 100644 --- a/internal/containerd/service.go +++ b/internal/containerd/service.go @@ -26,26 +26,31 @@ import ( "github.com/containerd/containerd/v2/pkg/oci" "github.com/containerd/errdefs" "github.com/containerd/platforms" - "github.com/memohai/memoh/internal/config" "github.com/opencontainers/go-digest" "github.com/opencontainers/image-spec/identity" "github.com/opencontainers/runtime-spec/specs-go" + + "github.com/memohai/memoh/internal/config" ) +// Errors returned by containerd operations. var ( ErrInvalidArgument = errors.New("invalid argument") ErrTaskStopTimeout = errors.New("timeout waiting for task to stop") ) +// PullImageOptions configures image pull (unpack and snapshotter). type PullImageOptions struct { Unpack bool Snapshotter string } +// DeleteImageOptions configures image deletion (synchronous wait). type DeleteImageOptions struct { Synchronous bool } +// CreateContainerRequest specifies container ID, image, snapshot, labels, and OCI spec options. type CreateContainerRequest struct { ID string ImageRef string @@ -55,26 +60,31 @@ type CreateContainerRequest struct { SpecOpts []oci.SpecOpts } +// DeleteContainerOptions configures whether to remove the container snapshot. type DeleteContainerOptions struct { CleanupSnapshot bool } +// StartTaskOptions configures stdio, terminal, and FIFO directory for the task. type StartTaskOptions struct { UseStdio bool Terminal bool FIFODir string } +// StopTaskOptions configures signal, timeout, and force kill for stopping a task. type StopTaskOptions struct { Signal syscall.Signal Timeout time.Duration Force bool } +// DeleteTaskOptions configures force deletion of a task. type DeleteTaskOptions struct { Force bool } +// ExecTaskRequest specifies command args, env, work dir, stdio, and terminal for exec. type ExecTaskRequest struct { Args []string Env []string @@ -87,6 +97,7 @@ type ExecTaskRequest struct { Stderr io.Writer } +// ExecTaskSession holds stdio streams and Wait/Close for a streaming exec session. type ExecTaskSession struct { Stdin io.WriteCloser Stdout io.ReadCloser @@ -95,19 +106,23 @@ type ExecTaskSession struct { Close func() error } +// ExecTaskResult holds the exit code of an exec call. type ExecTaskResult struct { ExitCode uint32 } +// SnapshotCommitResult holds version and active snapshot IDs after a commit. type SnapshotCommitResult struct { VersionSnapshotID string ActiveSnapshotID string } +// ListTasksOptions optionally filters listed tasks. type ListTasksOptions struct { Filter string } +// TaskInfo holds container ID, task ID, PID, status, and exit status for a task. type TaskInfo struct { ContainerID string ID string @@ -116,6 +131,7 @@ type TaskInfo struct { ExitStatus uint32 } +// Service is the containerd abstraction for images, containers, tasks, snapshots, and exec. type Service interface { PullImage(ctx context.Context, ref string, opts *PullImageOptions) (containerd.Image, error) GetImage(ctx context.Context, ref string) (containerd.Image, error) @@ -142,12 +158,14 @@ type Service interface { SnapshotMounts(ctx context.Context, snapshotter, key string) ([]mount.Mount, error) } +// DefaultService is the default implementation of Service using a containerd client and namespace. type DefaultService struct { client *containerd.Client namespace string logger *slog.Logger } +// NewDefaultService builds a DefaultService from logger, client, and config (namespace from config). func NewDefaultService(log *slog.Logger, client *containerd.Client, cfg config.Config) *DefaultService { namespace := cfg.Containerd.Namespace if namespace == "" { @@ -160,6 +178,7 @@ func NewDefaultService(log *slog.Logger, client *containerd.Client, cfg config.C } } +// PullImage pulls an image by reference; optional unpack and snapshotter. func (s *DefaultService) PullImage(ctx context.Context, ref string, opts *PullImageOptions) (containerd.Image, error) { if ref == "" { return nil, ErrInvalidArgument @@ -177,6 +196,7 @@ func (s *DefaultService) PullImage(ctx context.Context, ref string, opts *PullIm return s.client.Pull(ctx, ref, pullOpts...) } +// GetImage returns the image for the given reference. func (s *DefaultService) GetImage(ctx context.Context, ref string) (containerd.Image, error) { if ref == "" { return nil, ErrInvalidArgument @@ -185,11 +205,13 @@ func (s *DefaultService) GetImage(ctx context.Context, ref string) (containerd.I return s.client.GetImage(ctx, ref) } +// ListImages lists all images in the namespace. func (s *DefaultService) ListImages(ctx context.Context) ([]containerd.Image, error) { ctx = s.withNamespace(ctx) return s.client.ListImages(ctx) } +// DeleteImage deletes the image by reference, optionally waiting synchronously. func (s *DefaultService) DeleteImage(ctx context.Context, ref string, opts *DeleteImageOptions) error { if ref == "" { return ErrInvalidArgument @@ -202,6 +224,7 @@ func (s *DefaultService) DeleteImage(ctx context.Context, ref string, opts *Dele return s.client.ImageService().Delete(ctx, ref, deleteOpts...) } +// CreateContainer creates a container from an image or existing snapshot (with lease and spec). func (s *DefaultService) CreateContainer(ctx context.Context, req CreateContainerRequest) (containerd.Container, error) { if req.ID == "" || req.ImageRef == "" { return nil, ErrInvalidArgument @@ -212,7 +235,11 @@ func (s *DefaultService) CreateContainer(ctx context.Context, req CreateContaine if err != nil { return nil, err } - defer done(ctx) + defer func() { + if err := done(ctx); err != nil { + s.logger.Warn("release lease failed", slog.Any("error", err)) + } + }() image, err := s.getImageWithFallback(ctx, req.ImageRef) if err != nil { pullOpts := &PullImageOptions{ @@ -278,7 +305,7 @@ func (s *DefaultService) snapshotParentFromLayers(ctx context.Context, image con return "", err } if len(manifest.Layers) == 0 { - return "", fmt.Errorf("image has no layer descriptors") + return "", errors.New("image has no layer descriptors") } diffIDs := make([]digest.Digest, 0, len(manifest.Layers)) for _, layer := range manifest.Layers { @@ -344,8 +371,8 @@ func (s *DefaultService) getImageWithFallback(ctx context.Context, ref string) ( if err == nil { return image, nil } - if strings.HasPrefix(ref, "docker.io/library/") { - alt := strings.TrimPrefix(ref, "docker.io/library/") + if after, ok := strings.CutPrefix(ref, "docker.io/library/"); ok { + alt := after image, altErr := s.GetImage(ctx, alt) if altErr == nil { return image, nil @@ -358,8 +385,8 @@ func (s *DefaultService) getImageWithFallback(ctx context.Context, ref string) ( if name == ref || strings.HasSuffix(ref, "/"+name) || strings.HasSuffix(name, "/"+ref) { return img, nil } - if strings.HasPrefix(ref, "docker.io/library/") { - alt := strings.TrimPrefix(ref, "docker.io/library/") + if after, ok := strings.CutPrefix(ref, "docker.io/library/"); ok { + alt := after if name == alt || strings.HasSuffix(name, "/"+alt) { return img, nil } @@ -369,6 +396,7 @@ func (s *DefaultService) getImageWithFallback(ctx context.Context, ref string) ( return nil, err } +// GetContainer returns the container by ID. func (s *DefaultService) GetContainer(ctx context.Context, id string) (containerd.Container, error) { if id == "" { return nil, ErrInvalidArgument @@ -377,11 +405,13 @@ func (s *DefaultService) GetContainer(ctx context.Context, id string) (container return s.client.LoadContainer(ctx, id) } +// ListContainers returns all containers in the namespace. func (s *DefaultService) ListContainers(ctx context.Context) ([]containerd.Container, error) { ctx = s.withNamespace(ctx) return s.client.Containers(ctx) } +// DeleteContainer deletes the container and optionally its snapshot. func (s *DefaultService) DeleteContainer(ctx context.Context, id string, opts *DeleteContainerOptions) error { if id == "" { return ErrInvalidArgument @@ -405,6 +435,7 @@ func (s *DefaultService) DeleteContainer(ctx context.Context, id string, opts *D return container.Delete(ctx, deleteOpts...) } +// StartTask creates and starts the container task (optional stdio/terminal/FIFO). func (s *DefaultService) StartTask(ctx context.Context, containerID string, opts *StartTaskOptions) (containerd.Task, error) { if containerID == "" { return nil, ErrInvalidArgument @@ -440,6 +471,7 @@ func (s *DefaultService) StartTask(ctx context.Context, containerID string, opts return task, nil } +// GetTask returns the running task for the container. func (s *DefaultService) GetTask(ctx context.Context, containerID string) (containerd.Task, error) { if containerID == "" { return nil, ErrInvalidArgument @@ -453,6 +485,7 @@ func (s *DefaultService) GetTask(ctx context.Context, containerID string) (conta return container.Task(ctx, nil) } +// ListTasks returns task info for all tasks, optionally filtered. func (s *DefaultService) ListTasks(ctx context.Context, opts *ListTasksOptions) ([]TaskInfo, error) { ctx = s.withNamespace(ctx) request := &tasksv1.ListTasksRequest{} @@ -479,6 +512,7 @@ func (s *DefaultService) ListTasks(ctx context.Context, opts *ListTasksOptions) return tasks, nil } +// StopTask stops the task with signal and timeout; optional force kill. func (s *DefaultService) StopTask(ctx context.Context, containerID string, opts *StopTaskOptions) error { if containerID == "" { return ErrInvalidArgument @@ -530,6 +564,7 @@ func (s *DefaultService) StopTask(ctx context.Context, containerID string, opts } } +// DeleteTask deletes the task; optional force kill before delete. func (s *DefaultService) DeleteTask(ctx context.Context, containerID string, opts *DeleteTaskOptions) error { if containerID == "" { return ErrInvalidArgument @@ -549,6 +584,7 @@ func (s *DefaultService) DeleteTask(ctx context.Context, containerID string, opt return err } +// ExecTask runs a command in the container and returns the exit code (non-streaming). func (s *DefaultService) ExecTask(ctx context.Context, containerID string, req ExecTaskRequest) (ExecTaskResult, error) { if containerID == "" || len(req.Args) == 0 { return ExecTaskResult{}, ErrInvalidArgument @@ -609,7 +645,11 @@ func (s *DefaultService) ExecTask(ctx context.Context, containerID string, req E if err != nil { return ExecTaskResult{}, err } - defer process.Delete(ctx) + defer func() { + if _, err := process.Delete(ctx); err != nil { + s.logger.Warn("exec process delete failed", slog.Any("error", err)) + } + }() statusC, err := process.Wait(ctx) if err != nil { @@ -628,6 +668,7 @@ func (s *DefaultService) ExecTask(ctx context.Context, containerID string, req E return ExecTaskResult{ExitCode: code}, nil } +// ExecTaskStreaming runs a command and returns stdio streams plus Wait/Close for the session. func (s *DefaultService) ExecTaskStreaming(ctx context.Context, containerID string, req ExecTaskRequest) (*ExecTaskSession, error) { if containerID == "" || len(req.Args) == 0 { return nil, ErrInvalidArgument @@ -755,18 +796,19 @@ func resolveExecFIFODir(preferred string) (string, error) { var lastErr error for _, dir := range candidates { - if err := os.MkdirAll(dir, 0o755); err == nil { + err := os.MkdirAll(dir, 0o755) + if err == nil { return dir, nil - } else { - lastErr = err } + lastErr = err } if lastErr == nil { - lastErr = fmt.Errorf("no fifo directory candidate available") + lastErr = errors.New("no fifo directory candidate available") } return "", lastErr } +// ListContainersByLabel returns containers whose label key matches (value optional). func (s *DefaultService) ListContainersByLabel(ctx context.Context, key, value string) ([]containerd.Container, error) { if key == "" { return nil, ErrInvalidArgument @@ -791,6 +833,7 @@ func (s *DefaultService) ListContainersByLabel(ctx context.Context, key, value s return filtered, nil } +// CommitSnapshot commits the active snapshot key as a new snapshot name. func (s *DefaultService) CommitSnapshot(ctx context.Context, snapshotter, name, key string) error { if snapshotter == "" || name == "" || key == "" { return ErrInvalidArgument @@ -799,13 +842,14 @@ func (s *DefaultService) CommitSnapshot(ctx context.Context, snapshotter, name, return s.client.SnapshotService(snapshotter).Commit(ctx, name, key) } +// ListSnapshots walks the snapshotter and returns all snapshot infos. func (s *DefaultService) ListSnapshots(ctx context.Context, snapshotter string) ([]snapshots.Info, error) { if snapshotter == "" { return nil, ErrInvalidArgument } ctx = s.withNamespace(ctx) infos := []snapshots.Info{} - if err := s.client.SnapshotService(snapshotter).Walk(ctx, func(ctx context.Context, info snapshots.Info) error { + if err := s.client.SnapshotService(snapshotter).Walk(ctx, func(_ context.Context, info snapshots.Info) error { infos = append(infos, info) return nil }); err != nil { @@ -814,6 +858,7 @@ func (s *DefaultService) ListSnapshots(ctx context.Context, snapshotter string) return infos, nil } +// PrepareSnapshot prepares a new active snapshot key from parent. func (s *DefaultService) PrepareSnapshot(ctx context.Context, snapshotter, key, parent string) error { if snapshotter == "" || key == "" || parent == "" { return ErrInvalidArgument @@ -823,6 +868,7 @@ func (s *DefaultService) PrepareSnapshot(ctx context.Context, snapshotter, key, return err } +// CreateContainerFromSnapshot creates a container from an existing snapshot (no new snapshot). func (s *DefaultService) CreateContainerFromSnapshot(ctx context.Context, req CreateContainerRequest) (containerd.Container, error) { if req.ID == "" || req.SnapshotID == "" { return nil, ErrInvalidArgument @@ -874,6 +920,7 @@ func (s *DefaultService) CreateContainerFromSnapshot(ctx context.Context, req Cr return s.client.NewContainer(ctx, req.ID, containerOpts...) } +// SnapshotMounts returns the mount points for the snapshot key. func (s *DefaultService) SnapshotMounts(ctx context.Context, snapshotter, key string) ([]mount.Mount, error) { if snapshotter == "" || key == "" { return nil, ErrInvalidArgument diff --git a/internal/conversation/flow/assistant_output.go b/internal/conversation/flow/assistant_output.go index 2f011174..075b867f 100644 --- a/internal/conversation/flow/assistant_output.go +++ b/internal/conversation/flow/assistant_output.go @@ -1,3 +1,4 @@ +// Package flow provides conversation flow types and assistant output extraction. package flow import ( diff --git a/internal/conversation/flow/resolver.go b/internal/conversation/flow/resolver.go index 2132d33d..7133a0c9 100644 --- a/internal/conversation/flow/resolver.go +++ b/internal/conversation/flow/resolver.go @@ -5,6 +5,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "log/slog" @@ -178,13 +179,13 @@ type resolvedContext struct { func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (resolvedContext, error) { if strings.TrimSpace(req.Query) == "" { - return resolvedContext{}, fmt.Errorf("query is required") + return resolvedContext{}, errors.New("query is required") } if strings.TrimSpace(req.BotID) == "" { - return resolvedContext{}, fmt.Errorf("bot id is required") + return resolvedContext{}, errors.New("bot id is required") } if strings.TrimSpace(req.ChatID) == "" { - return resolvedContext{}, fmt.Errorf("chat id is required") + return resolvedContext{}, errors.New("chat id is required") } skipHistory := req.MaxContextLoadTime < 0 @@ -236,12 +237,7 @@ func (r *Resolver) resolve(ctx context.Context, req conversation.ChatRequest) (r } else { usableSkills = make([]gatewaySkill, 0, len(entries)) for _, e := range entries { - usableSkills = append(usableSkills, gatewaySkill{ - Name: e.Name, - Description: e.Description, - Content: e.Content, - Metadata: e.Metadata, - }) + usableSkills = append(usableSkills, gatewaySkill(e)) } } } @@ -308,10 +304,10 @@ func (r *Resolver) Chat(ctx context.Context, req conversation.ChatRequest) (conv // TriggerSchedule executes a scheduled command through the agent gateway trigger-schedule endpoint. func (r *Resolver) TriggerSchedule(ctx context.Context, botID string, payload schedule.TriggerPayload, token string) error { if strings.TrimSpace(botID) == "" { - return fmt.Errorf("bot id is required") + return errors.New("bot id is required") } if strings.TrimSpace(payload.Command) == "" { - return fmt.Errorf("schedule command is required") + return errors.New("schedule command is required") } req := conversation.ChatRequest{ @@ -422,7 +418,11 @@ func (r *Resolver) postChat(ctx context.Context, payload gatewayRequest, token s if err != nil { return gatewayResponse{}, err } - defer resp.Body.Close() + defer func() { + if err := resp.Body.Close(); err != nil { + r.logger.Warn("gateway request: close response body failed", slog.Any("error", err)) + } + }() respBody, err := io.ReadAll(resp.Body) if err != nil { @@ -463,7 +463,11 @@ func (r *Resolver) postTriggerSchedule(ctx context.Context, payload triggerSched if err != nil { return gatewayResponse{}, err } - defer resp.Body.Close() + defer func() { + if err := resp.Body.Close(); err != nil { + r.logger.Warn("trigger-schedule: close response body failed", slog.Any("error", err)) + } + }() respBody, err := io.ReadAll(resp.Body) if err != nil { @@ -504,7 +508,11 @@ func (r *Resolver) streamChat(ctx context.Context, payload gatewayRequest, req c r.logger.Error("gateway stream connect failed", slog.String("url", url), slog.Any("error", err)) return err } - defer resp.Body.Close() + defer func() { + if err := resp.Body.Close(); err != nil { + r.logger.Warn("gateway stream: close response body failed", slog.Any("error", err)) + } + }() if resp.StatusCode < 200 || resp.StatusCode >= 300 { errBody, _ := io.ReadAll(resp.Body) @@ -522,8 +530,8 @@ func (r *Resolver) streamChat(ctx context.Context, payload gatewayRequest, req c if line == "" { continue } - if strings.HasPrefix(line, "event:") { - currentEvent = strings.TrimSpace(strings.TrimPrefix(line, "event:")) + if after, ok := strings.CutPrefix(line, "event:"); ok { + currentEvent = strings.TrimSpace(after) continue } if !strings.HasPrefix(line, "data:") { @@ -631,7 +639,7 @@ func (r *Resolver) loadMessages(ctx context.Context, chatID string, maxContextMi type memoryContextItem struct { Namespace string - Item memory.MemoryItem + Item memory.Item } func (r *Resolver) loadMemoryContextMessage(ctx context.Context, req conversation.ChatRequest) *conversation.ModelMessage { @@ -718,7 +726,7 @@ func (r *Resolver) persistUserMessage(ctx context.Context, req conversation.Chat return nil } if strings.TrimSpace(req.BotID) == "" { - return fmt.Errorf("bot id is required for persistence") + return errors.New("bot id is required for persistence") } text := strings.TrimSpace(req.Query) if text == "" { @@ -990,7 +998,7 @@ func (r *Resolver) addMemory(ctx context.Context, botID string, msgs []memory.Me func (r *Resolver) selectChatModel(ctx context.Context, req conversation.ChatRequest, botSettings settings.Settings, cs conversation.Settings) (models.GetResponse, sqlc.LlmProvider, error) { if r.modelsService == nil { - return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("models service not configured") + return models.GetResponse{}, sqlc.LlmProvider{}, errors.New("models service not configured") } modelID := strings.TrimSpace(req.Model) providerFilter := strings.TrimSpace(req.Provider) @@ -1005,7 +1013,7 @@ func (r *Resolver) selectChatModel(ctx context.Context, req conversation.ChatReq } if modelID == "" { - return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("chat model not configured: specify model in request or bot settings") + return models.GetResponse{}, sqlc.LlmProvider{}, errors.New("chat model not configured: specify model in request or bot settings") } if providerFilter == "" { @@ -1034,7 +1042,7 @@ func (r *Resolver) fetchChatModel(ctx context.Context, modelID string) (models.G return models.GetResponse{}, sqlc.LlmProvider{}, err } if model.Type != models.ModelTypeChat { - return models.GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("model is not a chat model") + return models.GetResponse{}, sqlc.LlmProvider{}, errors.New("model is not a chat model") } prov, err := models.FetchProviderByID(ctx, r.queries, model.LlmProviderID) if err != nil { @@ -1067,7 +1075,7 @@ func (r *Resolver) listCandidates(ctx context.Context, providerFilter string) ([ func (r *Resolver) loadBotSettings(ctx context.Context, botID string) (settings.Settings, error) { if r.settingsService == nil { - return settings.Settings{}, fmt.Errorf("settings service not configured") + return settings.Settings{}, errors.New("settings service not configured") } return r.settingsService.GetBot(ctx, botID) } @@ -1116,15 +1124,6 @@ func dedup(items []string) []string { return result } -func firstNonEmpty(values ...string) string { - for _, v := range values { - if strings.TrimSpace(v) != "" { - return v - } - } - return "" -} - func coalescePositiveInt(values ...int) int { for _, v := range values { if v > 0 { @@ -1165,7 +1164,7 @@ func truncateMemorySnippet(s string, n int) string { func parseResolverUUID(id string) (pgtype.UUID, error) { if strings.TrimSpace(id) == "" { - return pgtype.UUID{}, fmt.Errorf("empty id") + return pgtype.UUID{}, errors.New("empty id") } return db.ParseUUID(id) } diff --git a/internal/conversation/flow/resolver_test.go b/internal/conversation/flow/resolver_test.go index 1f744576..10e7a4d8 100644 --- a/internal/conversation/flow/resolver_test.go +++ b/internal/conversation/flow/resolver_test.go @@ -26,7 +26,7 @@ func TestPostTriggerSchedule_Endpoint(t *testing.T) { Messages: []conversation.ModelMessage{{Role: "assistant", Content: conversation.NewTextContent("ok")}}, } w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(resp) + _ = json.NewEncoder(w).Encode(resp) })) defer srv.Close() @@ -107,7 +107,7 @@ func TestPostTriggerSchedule_NoAuth(t *testing.T) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { capturedAuth = r.Header.Get("Authorization") resp := gatewayResponse{Messages: []conversation.ModelMessage{}} - json.NewEncoder(w).Encode(resp) + _ = json.NewEncoder(w).Encode(resp) })) defer srv.Close() @@ -137,9 +137,9 @@ func TestPostTriggerSchedule_NoAuth(t *testing.T) { } func TestPostTriggerSchedule_GatewayError(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusInternalServerError) - w.Write([]byte("internal error")) + _, _ = w.Write([]byte("internal error")) })) defer srv.Close() diff --git a/internal/conversation/flow/schedule_gateway.go b/internal/conversation/flow/schedule_gateway.go index 4b3c2138..412b4b06 100644 --- a/internal/conversation/flow/schedule_gateway.go +++ b/internal/conversation/flow/schedule_gateway.go @@ -2,7 +2,7 @@ package flow import ( "context" - "fmt" + "errors" "github.com/memohai/memoh/internal/schedule" ) @@ -20,7 +20,7 @@ func NewScheduleGateway(resolver *Resolver) *ScheduleGateway { // TriggerSchedule delegates a schedule trigger to the chat Resolver. func (g *ScheduleGateway) TriggerSchedule(ctx context.Context, botID string, payload schedule.TriggerPayload, token string) error { if g == nil || g.resolver == nil { - return fmt.Errorf("chat resolver not configured") + return errors.New("chat resolver not configured") } return g.resolver.TriggerSchedule(ctx, botID, payload, token) } diff --git a/internal/conversation/interfaces.go b/internal/conversation/interfaces.go index 283010a2..6f5d6bb9 100644 --- a/internal/conversation/interfaces.go +++ b/internal/conversation/interfaces.go @@ -16,5 +16,5 @@ type ParticipantChecker interface { type Accessor interface { Reader ParticipantChecker - GetReadAccess(ctx context.Context, conversationID, channelIdentityID string) (ConversationReadAccess, error) + GetReadAccess(ctx context.Context, conversationID, channelIdentityID string) (ReadAccess, error) } diff --git a/internal/conversation/service.go b/internal/conversation/service.go index f85700c2..f77e948e 100644 --- a/internal/conversation/service.go +++ b/internal/conversation/service.go @@ -16,6 +16,7 @@ import ( "github.com/memohai/memoh/internal/db/sqlc" ) +// Errors returned by conversation operations. var ( ErrChatNotFound = errors.New("chat not found") ErrNotParticipant = errors.New("not a participant") @@ -127,14 +128,14 @@ func (s *Service) Get(ctx context.Context, conversationID string) (Conversation, } // GetReadAccess resolves whether a user can read a conversation. -func (s *Service) GetReadAccess(ctx context.Context, conversationID, channelIdentityID string) (ConversationReadAccess, error) { +func (s *Service) GetReadAccess(ctx context.Context, conversationID, channelIdentityID string) (ReadAccess, error) { pgConversationID, err := parseUUID(conversationID) if err != nil { - return ConversationReadAccess{}, ErrPermissionDenied + return ReadAccess{}, ErrPermissionDenied } pgChannelIdentityID, err := parseUUID(channelIdentityID) if err != nil { - return ConversationReadAccess{}, ErrPermissionDenied + return ReadAccess{}, ErrPermissionDenied } row, err := s.queries.GetChatReadAccessByUser(ctx, sqlc.GetChatReadAccessByUserParams{ ChatID: pgConversationID, @@ -142,11 +143,11 @@ func (s *Service) GetReadAccess(ctx context.Context, conversationID, channelIden }) if err != nil { if errors.Is(err, pgx.ErrNoRows) { - return ConversationReadAccess{}, ErrPermissionDenied + return ReadAccess{}, ErrPermissionDenied } - return ConversationReadAccess{}, err + return ReadAccess{}, err } - return ConversationReadAccess{ + return ReadAccess{ AccessMode: row.AccessMode, ParticipantRole: strings.TrimSpace(row.ParticipantRole), LastObservedAt: pgTimePtr(row.LastObservedAt), @@ -154,7 +155,7 @@ func (s *Service) GetReadAccess(ctx context.Context, conversationID, channelIden } // ListByBotAndChannelIdentity returns all visible conversations for a bot and channel identity. -func (s *Service) ListByBotAndChannelIdentity(ctx context.Context, botID, channelIdentityID string) ([]ConversationListItem, error) { +func (s *Service) ListByBotAndChannelIdentity(ctx context.Context, botID, channelIdentityID string) ([]ListItem, error) { pgBotID, err := parseUUID(botID) if err != nil { return nil, err @@ -170,7 +171,7 @@ func (s *Service) ListByBotAndChannelIdentity(ctx context.Context, botID, channe if err != nil { return nil, err } - conversations := make([]ConversationListItem, 0, len(rows)) + conversations := make([]ListItem, 0, len(rows)) for _, row := range rows { conversations = append(conversations, toChatListItem(row)) } @@ -388,8 +389,8 @@ func toChatFields(id, botID pgtype.UUID, kind string, parentChatID pgtype.UUID, } } -func toChatListItem(row sqlc.ListVisibleChatsByBotAndUserRow) ConversationListItem { - return ConversationListItem{ +func toChatListItem(row sqlc.ListVisibleChatsByBotAndUserRow) ListItem { + return ListItem{ ID: row.ID.String(), BotID: row.BotID.String(), Kind: row.Kind, diff --git a/internal/conversation/service_integration_test.go b/internal/conversation/service_integration_test.go index eff8caf4..7240e24d 100644 --- a/internal/conversation/service_integration_test.go +++ b/internal/conversation/service_integration_test.go @@ -175,7 +175,7 @@ func TestObservedChatVisibleAfterBindWithoutBackfill(t *testing.T) { t.Fatalf("expected observed chat visible after bind, got %d chats", len(afterBind)) } - var target *conversation.ConversationListItem + var target *conversation.ListItem for i := range afterBind { if afterBind[i].ID == chatID { target = &afterBind[i] diff --git a/internal/conversation/types.go b/internal/conversation/types.go index c6431a28..67774d0f 100644 --- a/internal/conversation/types.go +++ b/internal/conversation/types.go @@ -41,8 +41,8 @@ type Conversation struct { UpdatedAt time.Time `json:"updated_at"` } -// ConversationListItem is a conversation entry with access context for list rendering. -type ConversationListItem struct { +// ListItem is a conversation entry with access context for list rendering. +type ListItem struct { ID string `json:"id"` BotID string `json:"bot_id"` Kind string `json:"kind"` @@ -57,8 +57,8 @@ type ConversationListItem struct { LastObservedAt *time.Time `json:"last_observed_at,omitempty"` } -// ConversationReadAccess is the resolved access context for reading conversation content. -type ConversationReadAccess struct { +// ReadAccess is the resolved access context for reading conversation content. +type ReadAccess struct { AccessMode string ParticipantRole string LastObservedAt *time.Time diff --git a/internal/db/db.go b/internal/db/db.go index eea327d0..4e107447 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -1,3 +1,4 @@ +// Package db provides PostgreSQL connection and pool helpers. package db import ( @@ -9,6 +10,7 @@ import ( "github.com/memohai/memoh/internal/config" ) +// Open creates a pgx connection pool from the given Postgres config (DSN built from host, port, user, etc.). func Open(ctx context.Context, cfg config.PostgresConfig) (*pgxpool.Pool, error) { dsn := fmt.Sprintf( "postgres://%s:%s@%s:%d/%s?sslmode=%s", diff --git a/internal/db/utils_test.go b/internal/db/utils_test.go index 92b01386..7c2e2ccf 100644 --- a/internal/db/utils_test.go +++ b/internal/db/utils_test.go @@ -1,6 +1,7 @@ package db import ( + "errors" "fmt" "testing" "time" @@ -107,7 +108,7 @@ func TestIsUniqueViolation(t *testing.T) { want bool }{ {"nil", nil, false}, - {"plain error", fmt.Errorf("some error"), false}, + {"plain error", errors.New("some error"), false}, {"unique violation", &pgconn.PgError{Code: "23505"}, true}, {"other pg error", &pgconn.PgError{Code: "23503"}, false}, {"wrapped unique violation", fmt.Errorf("wrapped: %w", &pgconn.PgError{Code: "23505"}), true}, diff --git a/internal/embeddings/bootstrap.go b/internal/embeddings/bootstrap.go index a3e3b281..82a18b74 100644 --- a/internal/embeddings/bootstrap.go +++ b/internal/embeddings/bootstrap.go @@ -1,3 +1,4 @@ +// Package embeddings provides embedder interfaces and resolver-based text embedding. package embeddings import ( @@ -13,6 +14,7 @@ type ResolverTextEmbedder struct { Dims int } +// Embed delegates to Resolver.Embed with TypeText and the configured ModelID. func (e *ResolverTextEmbedder) Embed(ctx context.Context, input string) ([]float32, error) { result, err := e.Resolver.Embed(ctx, Request{ Type: TypeText, @@ -25,6 +27,7 @@ func (e *ResolverTextEmbedder) Embed(ctx context.Context, input string) ([]float return result.Embedding, nil } +// Dimensions returns the configured embedding dimension (Dims). func (e *ResolverTextEmbedder) Dimensions() int { return e.Dims } diff --git a/internal/embeddings/dashscope.go b/internal/embeddings/dashscope.go index f16dc424..5846e561 100644 --- a/internal/embeddings/dashscope.go +++ b/internal/embeddings/dashscope.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "log/slog" @@ -12,11 +13,13 @@ import ( "time" ) +// Default DashScope API base URL and embedding path. const ( DefaultDashScopeBaseURL = "https://dashscope.aliyuncs.com" DashScopeEmbeddingPath = "/api/v1/services/embeddings/multimodal-embedding/multimodal-embedding" ) +// DashScopeEmbedder calls Aliyun DashScope multimodal embedding API. type DashScopeEmbedder struct { apiKey string baseURL string @@ -25,6 +28,8 @@ type DashScopeEmbedder struct { http *http.Client } +// DashScopeUsage holds token/duration usage from DashScope API response. +// DashScopeUsage holds token/duration usage from DashScope API response. type DashScopeUsage struct { InputTokens int `json:"input_tokens"` ImageTokens int `json:"image_tokens"` @@ -55,6 +60,7 @@ type dashScopeResponse struct { Message string `json:"message"` } +// NewDashScopeEmbedder builds a DashScope embedder; baseURL defaults to DefaultDashScopeBaseURL if empty. func NewDashScopeEmbedder(log *slog.Logger, apiKey, baseURL, model string, timeout time.Duration) *DashScopeEmbedder { if baseURL == "" { baseURL = DefaultDashScopeBaseURL @@ -73,7 +79,8 @@ func NewDashScopeEmbedder(log *slog.Logger, apiKey, baseURL, model string, timeo } } -func (e *DashScopeEmbedder) Embed(ctx context.Context, text string, imageURL string, videoURL string) ([]float32, DashScopeUsage, error) { +// Embed returns the embedding vector and usage for text and/or image/video URLs via DashScope API. +func (e *DashScopeEmbedder) Embed(ctx context.Context, text, imageURL, videoURL string) ([]float32, DashScopeUsage, error) { contents := make([]map[string]string, 0, 3) if strings.TrimSpace(text) != "" { contents = append(contents, map[string]string{"text": text}) @@ -85,7 +92,7 @@ func (e *DashScopeEmbedder) Embed(ctx context.Context, text string, imageURL str contents = append(contents, map[string]string{"video": videoURL}) } if len(contents) == 0 { - return nil, DashScopeUsage{}, fmt.Errorf("dashscope input is required") + return nil, DashScopeUsage{}, errors.New("dashscope input is required") } payload, err := json.Marshal(dashScopeRequest{ @@ -107,7 +114,11 @@ func (e *DashScopeEmbedder) Embed(ctx context.Context, text string, imageURL str if err != nil { return nil, DashScopeUsage{}, err } - defer resp.Body.Close() + defer func() { + if err := resp.Body.Close(); err != nil { + e.logger.Warn("dashscope embeddings: close response body failed", slog.Any("error", err)) + } + }() body, _ := io.ReadAll(resp.Body) if resp.StatusCode < 200 || resp.StatusCode >= 300 { return nil, DashScopeUsage{}, fmt.Errorf("dashscope embeddings error: %s", strings.TrimSpace(string(body))) @@ -121,15 +132,16 @@ func (e *DashScopeEmbedder) Embed(ctx context.Context, text string, imageURL str return nil, parsed.Usage, fmt.Errorf("dashscope embeddings error: %s", parsed.Message) } if len(parsed.Output.Embeddings) == 0 { - return nil, parsed.Usage, fmt.Errorf("dashscope embeddings empty response") + return nil, parsed.Usage, errors.New("dashscope embeddings empty response") } - preferredType := "" - if strings.TrimSpace(text) != "" { + var preferredType string + switch { + case strings.TrimSpace(text) != "": preferredType = "text" - } else if strings.TrimSpace(imageURL) != "" { + case strings.TrimSpace(imageURL) != "": preferredType = "image" - } else if strings.TrimSpace(videoURL) != "" { + case strings.TrimSpace(videoURL) != "": preferredType = "video" } diff --git a/internal/embeddings/embeddings.go b/internal/embeddings/embeddings.go index 0bc98fae..35e68548 100644 --- a/internal/embeddings/embeddings.go +++ b/internal/embeddings/embeddings.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "log/slog" @@ -12,11 +13,13 @@ import ( "time" ) +// Embedder produces vector embeddings for text (or other input) and reports dimension. type Embedder interface { Embed(ctx context.Context, input string) ([]float32, error) Dimensions() int } +// OpenAIEmbedder calls OpenAI-compatible embedding API (e.g. OpenAI or local) for text. type OpenAIEmbedder struct { apiKey string baseURL string @@ -37,18 +40,19 @@ type openAIEmbeddingResponse struct { } `json:"data"` } +// NewOpenAIEmbedder builds an OpenAIEmbedder; baseURL, apiKey, model required; dims must be positive. func NewOpenAIEmbedder(log *slog.Logger, apiKey, baseURL, model string, dims int, timeout time.Duration) (*OpenAIEmbedder, error) { if strings.TrimSpace(baseURL) == "" { - return nil, fmt.Errorf("openai embedder: base url is required") + return nil, errors.New("openai embedder: base url is required") } if strings.TrimSpace(apiKey) == "" { - return nil, fmt.Errorf("openai embedder: api key is required") + return nil, errors.New("openai embedder: api key is required") } if strings.TrimSpace(model) == "" { - return nil, fmt.Errorf("openai embedder: model is required") + return nil, errors.New("openai embedder: model is required") } if dims <= 0 { - return nil, fmt.Errorf("openai embedder: dimensions must be positive") + return nil, errors.New("openai embedder: dimensions must be positive") } if timeout <= 0 { timeout = 10 * time.Second @@ -65,10 +69,12 @@ func NewOpenAIEmbedder(log *slog.Logger, apiKey, baseURL, model string, dims int }, nil } +// Dimensions returns the embedding dimension configured for this embedder. func (e *OpenAIEmbedder) Dimensions() int { return e.dims } +// Embed returns the embedding vector for the given text via the OpenAI-compatible API. func (e *OpenAIEmbedder) Embed(ctx context.Context, input string) ([]float32, error) { payload, err := json.Marshal(openAIEmbeddingRequest{ Input: input, @@ -91,7 +97,11 @@ func (e *OpenAIEmbedder) Embed(ctx context.Context, input string) ([]float32, er if err != nil { return nil, err } - defer resp.Body.Close() + defer func() { + if err := resp.Body.Close(); err != nil { + e.logger.Warn("embeddings: close response body failed", slog.Any("error", err)) + } + }() if resp.StatusCode < 200 || resp.StatusCode >= 300 { body, _ := io.ReadAll(resp.Body) return nil, fmt.Errorf("openai embeddings error: %s", strings.TrimSpace(string(body))) @@ -102,7 +112,7 @@ func (e *OpenAIEmbedder) Embed(ctx context.Context, input string) ([]float32, er return nil, err } if len(parsed.Data) == 0 { - return nil, fmt.Errorf("openai embeddings empty response") + return nil, errors.New("openai embeddings empty response") } return parsed.Data[0].Embedding, nil } diff --git a/internal/embeddings/resolver.go b/internal/embeddings/resolver.go index f7ebf0ca..7bc2a26b 100644 --- a/internal/embeddings/resolver.go +++ b/internal/embeddings/resolver.go @@ -14,6 +14,7 @@ import ( "github.com/memohai/memoh/internal/models" ) +// Embedding type and provider name constants. const ( TypeText = "text" TypeMultimodal = "multimodal" @@ -23,6 +24,7 @@ const ( ProviderDashScope = "dashscope" ) +// Request specifies embedding type, provider, model, dimensions, and input (text/image/video URL). type Request struct { Type string Provider string @@ -31,18 +33,21 @@ type Request struct { Input Input } +// Input holds text and optional image/video URLs for multimodal embedding. type Input struct { Text string ImageURL string VideoURL string } +// Usage holds token and duration usage returned by the embedding API. type Usage struct { InputTokens int ImageTokens int Duration int } +// Result holds embedding vector, type, provider, model, dimensions, and usage. type Result struct { Type string Provider string @@ -52,6 +57,7 @@ type Result struct { Usage Usage } +// Resolver resolves embedding requests by provider/model and delegates to the appropriate embedder. type Resolver struct { modelsService *models.Service queries *sqlc.Queries @@ -59,6 +65,7 @@ type Resolver struct { logger *slog.Logger } +// NewResolver creates a Resolver with the given models service, queries, and timeout. func NewResolver(log *slog.Logger, modelsService *models.Service, queries *sqlc.Queries, timeout time.Duration) *Resolver { return &Resolver{ modelsService: modelsService, @@ -68,6 +75,8 @@ func NewResolver(log *slog.Logger, modelsService *models.Service, queries *sqlc. } } +// Embed performs the embedding request using the resolved provider and model. +// Embed performs the embedding request using the resolved provider and model. func (r *Resolver) Embed(ctx context.Context, req Request) (Result, error) { req.Type = strings.ToLower(strings.TrimSpace(req.Type)) req.Provider = strings.ToLower(strings.TrimSpace(req.Provider)) diff --git a/internal/handlers/auth.go b/internal/handlers/auth.go index dfb4503d..0abe935a 100644 --- a/internal/handlers/auth.go +++ b/internal/handlers/auth.go @@ -1,3 +1,4 @@ +// Package handlers provides HTTP API handlers for the Memoh agent server. package handlers import ( @@ -13,6 +14,7 @@ import ( "github.com/memohai/memoh/internal/auth" ) +// AuthHandler serves /auth/login and issues JWTs. type AuthHandler struct { accountService *accounts.Service jwtSecret string @@ -20,11 +22,13 @@ type AuthHandler struct { logger *slog.Logger } +// LoginRequest is the body for POST /auth/login. type LoginRequest struct { Username string `json:"username"` Password string `json:"password"` } +// LoginResponse is the success body (access_token, user info, expires_at). type LoginResponse struct { AccessToken string `json:"access_token"` TokenType string `json:"token_type"` @@ -35,6 +39,7 @@ type LoginResponse struct { Username string `json:"username"` } +// NewAuthHandler creates an auth handler with account service and JWT config. func NewAuthHandler(log *slog.Logger, accountService *accounts.Service, jwtSecret string, expiresIn time.Duration) *AuthHandler { return &AuthHandler{ accountService: accountService, @@ -44,6 +49,7 @@ func NewAuthHandler(log *slog.Logger, accountService *accounts.Service, jwtSecre } } +// Register mounts POST /auth/login on the Echo instance. func (h *AuthHandler) Register(e *echo.Echo) { e.POST("/auth/login", h.Login) } @@ -57,7 +63,7 @@ func (h *AuthHandler) Register(e *echo.Echo) { // @Failure 400 {object} ErrorResponse // @Failure 401 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /auth/login [post] +// @Router /auth/login [post]. func (h *AuthHandler) Login(c echo.Context) error { if h.accountService == nil { return echo.NewHTTPError(http.StatusInternalServerError, "user service not configured") diff --git a/internal/handlers/channel.go b/internal/handlers/channel.go index 26476b24..10f2bbc3 100644 --- a/internal/handlers/channel.go +++ b/internal/handlers/channel.go @@ -10,15 +10,18 @@ import ( "github.com/memohai/memoh/internal/channel" ) +// ChannelHandler serves channel identity config and channel metadata APIs. type ChannelHandler struct { service *channel.Service registry *channel.Registry } +// NewChannelHandler creates a channel handler. func NewChannelHandler(service *channel.Service, registry *channel.Registry) *ChannelHandler { return &ChannelHandler{service: service, registry: registry} } +// Register mounts /users/me/channels and /channels routes on the Echo instance. func (h *ChannelHandler) Register(e *echo.Echo) { group := e.Group("/users/me/channels") group.GET("/:platform", h.GetChannelIdentityConfig) @@ -34,11 +37,11 @@ func (h *ChannelHandler) Register(e *echo.Echo) { // @Description Get channel binding configuration for current user // @Tags channel // @Param platform path string true "Channel platform" -// @Success 200 {object} channel.ChannelIdentityBinding +// @Success 200 {object} channel.IdentityBinding // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /users/me/channels/{platform} [get] +// @Router /users/me/channels/{platform} [get]. func (h *ChannelHandler) GetChannelIdentityConfig(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -64,10 +67,10 @@ func (h *ChannelHandler) GetChannelIdentityConfig(c echo.Context) error { // @Tags channel // @Param platform path string true "Channel platform" // @Param payload body channel.UpsertChannelIdentityConfigRequest true "Channel user config payload" -// @Success 200 {object} channel.ChannelIdentityBinding +// @Success 200 {object} channel.IdentityBinding // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /users/me/channels/{platform} [put] +// @Router /users/me/channels/{platform} [put]. func (h *ChannelHandler) UpsertChannelIdentityConfig(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -91,14 +94,15 @@ func (h *ChannelHandler) UpsertChannelIdentityConfig(c echo.Context) error { return c.JSON(http.StatusOK, resp) } +// ChannelMeta is the API response for channel metadata (type, display name, capabilities, schemas). type ChannelMeta struct { - Type string `json:"type"` - DisplayName string `json:"display_name"` - Configless bool `json:"configless"` - Capabilities channel.ChannelCapabilities `json:"capabilities"` - ConfigSchema channel.ConfigSchema `json:"config_schema"` - UserConfigSchema channel.ConfigSchema `json:"user_config_schema"` - TargetSpec channel.TargetSpec `json:"target_spec"` + Type string `json:"type"` + DisplayName string `json:"display_name"` + Configless bool `json:"configless"` + Capabilities channel.Capabilities `json:"capabilities"` + ConfigSchema channel.ConfigSchema `json:"config_schema"` + UserConfigSchema channel.ConfigSchema `json:"user_config_schema"` + TargetSpec channel.TargetSpec `json:"target_spec"` } // ListChannels godoc @@ -107,7 +111,7 @@ type ChannelMeta struct { // @Tags channel // @Success 200 {array} ChannelMeta // @Failure 500 {object} ErrorResponse -// @Router /channels [get] +// @Router /channels [get]. func (h *ChannelHandler) ListChannels(c echo.Context) error { descs := h.registry.ListDescriptors() items := make([]ChannelMeta, 0, len(descs)) @@ -136,7 +140,7 @@ func (h *ChannelHandler) ListChannels(c echo.Context) error { // @Success 200 {object} ChannelMeta // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse -// @Router /channels/{platform} [get] +// @Router /channels/{platform} [get]. func (h *ChannelHandler) GetChannel(c echo.Context) error { channelType, err := h.registry.ParseChannelType(c.Param("platform")) if err != nil { diff --git a/internal/handlers/containerd.go b/internal/handlers/containerd.go index a3433221..325442ab 100644 --- a/internal/handlers/containerd.go +++ b/internal/handlers/containerd.go @@ -33,14 +33,13 @@ import ( "github.com/memohai/memoh/internal/policy" ) +// ContainerdHandler serves bot container and snapshot APIs (create, get, list, delete, snapshot, MCP stdio). type ContainerdHandler struct { service ctr.Service cfg config.MCPConfig namespace string logger *slog.Logger toolGateway *mcp.ToolGatewayService - mcpMu sync.Mutex - mcpSess map[string]*mcpSession mcpStdioMu sync.Mutex mcpStdioSess map[string]*mcpStdioSession botService *bots.Service @@ -49,10 +48,12 @@ type ContainerdHandler struct { queries *dbsqlc.Queries } +// CreateContainerRequest is the body for creating a bot container (optional snapshotter). type CreateContainerRequest struct { Snapshotter string `json:"snapshotter,omitempty"` } +// CreateContainerResponse returns container_id, image, snapshotter, started. type CreateContainerResponse struct { ContainerID string `json:"container_id"` Image string `json:"image"` @@ -60,6 +61,7 @@ type CreateContainerResponse struct { Started bool `json:"started"` } +// GetContainerResponse is the container detail for get API (status, paths, task_running, timestamps). type GetContainerResponse struct { ContainerID string `json:"container_id"` Image string `json:"image"` @@ -72,38 +74,42 @@ type GetContainerResponse struct { UpdatedAt time.Time `json:"updated_at"` } +// CreateSnapshotRequest is the body for creating a snapshot (snapshot_name). type CreateSnapshotRequest struct { SnapshotName string `json:"snapshot_name"` } +// CreateSnapshotResponse returns container_id, snapshot_name, snapshotter. type CreateSnapshotResponse struct { ContainerID string `json:"container_id"` SnapshotName string `json:"snapshot_name"` Snapshotter string `json:"snapshotter"` } +// SnapshotInfo is one snapshot entry (snapshotter, name, parent, kind, timestamps, labels). type SnapshotInfo struct { Snapshotter string `json:"snapshotter"` Name string `json:"name"` Parent string `json:"parent,omitempty"` Kind string `json:"kind"` - CreatedAt time.Time `json:"created_at,omitempty"` - UpdatedAt time.Time `json:"updated_at,omitempty"` + CreatedAt time.Time `json:"created_at,omitzero"` + UpdatedAt time.Time `json:"updated_at,omitzero"` Labels map[string]string `json:"labels,omitempty"` } +// ListSnapshotsResponse holds snapshotter and list of SnapshotInfo. type ListSnapshotsResponse struct { Snapshotter string `json:"snapshotter"` Snapshots []SnapshotInfo `json:"snapshots"` } +// NewContainerdHandler creates a containerd handler (optionally set toolGateway after construction). func NewContainerdHandler(log *slog.Logger, service ctr.Service, cfg config.MCPConfig, namespace string, botService *bots.Service, accountService *accounts.Service, policyService *policy.Service, queries *dbsqlc.Queries) *ContainerdHandler { return &ContainerdHandler{ service: service, cfg: cfg, namespace: namespace, logger: log.With(slog.String("handler", "containerd")), - mcpSess: make(map[string]*mcpSession), mcpStdioSess: make(map[string]*mcpStdioSession), botService: botService, accountService: accountService, @@ -112,6 +118,7 @@ func NewContainerdHandler(log *slog.Logger, service ctr.Service, cfg config.MCPC } } +// Register mounts /bots/:bot_id/container, /bots/:bot_id/mcp-stdio, /bots/:bot_id/tools on the Echo instance. func (h *ContainerdHandler) Register(e *echo.Echo) { group := e.Group("/bots/:bot_id/container") group.POST("", h.CreateContainer) @@ -138,7 +145,7 @@ func (h *ContainerdHandler) Register(e *echo.Echo) { // @Success 200 {object} CreateContainerResponse // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/container [post] +// @Router /bots/{bot_id}/container [post]. func (h *ContainerdHandler) CreateContainer(c echo.Context) error { botID, err := h.requireBotAccess(c) if err != nil { @@ -378,7 +385,7 @@ func (h *ContainerdHandler) botContainerID(ctx context.Context, botID string) (s // @Success 200 {object} GetContainerResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/container [get] +// @Router /bots/{bot_id}/container [get]. func (h *ContainerdHandler) GetContainer(c echo.Context) error { botID, err := h.requireBotAccess(c) if err != nil { @@ -456,7 +463,7 @@ func (h *ContainerdHandler) GetContainer(c echo.Context) error { // @Success 204 // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/container [delete] +// @Router /bots/{bot_id}/container [delete]. func (h *ContainerdHandler) DeleteContainer(c echo.Context) error { botID, err := h.requireBotAccess(c) if err != nil { @@ -475,7 +482,7 @@ func (h *ContainerdHandler) DeleteContainer(c echo.Context) error { // @Success 200 {object} object // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/container/start [post] +// @Router /bots/{bot_id}/container/start [post]. func (h *ContainerdHandler) StartContainer(c echo.Context) error { botID, err := h.requireBotAccess(c) if err != nil { @@ -507,7 +514,7 @@ func (h *ContainerdHandler) StartContainer(c echo.Context) error { // @Success 200 {object} object // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/container/stop [post] +// @Router /bots/{bot_id}/container/stop [post]. func (h *ContainerdHandler) StopContainer(c echo.Context) error { botID, err := h.requireBotAccess(c) if err != nil { @@ -546,7 +553,7 @@ func (h *ContainerdHandler) StopContainer(c echo.Context) error { // @Success 200 {object} CreateSnapshotResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/container/snapshots [post] +// @Router /bots/{bot_id}/container/snapshots [post]. func (h *ContainerdHandler) CreateSnapshot(c echo.Context) error { botID, err := h.requireBotAccess(c) if err != nil { @@ -596,7 +603,7 @@ func (h *ContainerdHandler) CreateSnapshot(c echo.Context) error { // @Param bot_id path string true "Bot ID" // @Param snapshotter query string false "Snapshotter name" // @Success 200 {object} ListSnapshotsResponse -// @Router /bots/{bot_id}/container/snapshots [get] +// @Router /bots/{bot_id}/container/snapshots [get]. func (h *ContainerdHandler) ListSnapshots(c echo.Context) error { if _, err := h.requireBotAccess(c); err != nil { return err diff --git a/internal/handlers/embeddings.go b/internal/handlers/embeddings.go index 36941e6b..87dbd8e0 100644 --- a/internal/handlers/embeddings.go +++ b/internal/handlers/embeddings.go @@ -13,13 +13,16 @@ import ( "github.com/memohai/memoh/internal/models" ) +// DefaultEmbeddingTimeout is the default HTTP timeout for embedding requests. const DefaultEmbeddingTimeout = 10 * time.Second +// EmbeddingsHandler serves POST /embeddings for text or multimodal embedding. type EmbeddingsHandler struct { resolver *embeddings.Resolver logger *slog.Logger } +// EmbeddingsRequest is the body for POST /embeddings (type, provider, model, dimensions, input). type EmbeddingsRequest struct { Type string `json:"type"` Provider string `json:"provider,omitempty"` @@ -28,28 +31,32 @@ type EmbeddingsRequest struct { Input EmbeddingsInput `json:"input"` } +// EmbeddingsInput holds text and optional image/video URL. type EmbeddingsInput struct { Text string `json:"text,omitempty"` ImageURL string `json:"image_url,omitempty"` VideoURL string `json:"video_url,omitempty"` } +// EmbeddingsResponse is the success body (type, provider, model, dimensions, embedding, usage). type EmbeddingsResponse struct { Type string `json:"type"` Provider string `json:"provider"` Model string `json:"model"` Dimensions int `json:"dimensions"` Embedding []float32 `json:"embedding"` - Usage EmbeddingsUsage `json:"usage,omitempty"` + Usage EmbeddingsUsage `json:"usage,omitzero"` Message string `json:"message,omitempty"` } +// EmbeddingsUsage holds token and duration usage from the embedding API. type EmbeddingsUsage struct { InputTokens int `json:"input_tokens,omitempty"` ImageTokens int `json:"image_tokens,omitempty"` - Duration int `json:"duration,omitempty"` + Duration int `json:"duration,omitempty"` } +// NewEmbeddingsHandler creates an embeddings handler with a resolver built from models service and queries. func NewEmbeddingsHandler(log *slog.Logger, modelsService *models.Service, queries *sqlc.Queries) *EmbeddingsHandler { return &EmbeddingsHandler{ resolver: embeddings.NewResolver(log, modelsService, queries, DefaultEmbeddingTimeout), @@ -57,6 +64,7 @@ func NewEmbeddingsHandler(log *slog.Logger, modelsService *models.Service, queri } } +// Register mounts POST /embeddings on the Echo instance. func (h *EmbeddingsHandler) Register(e *echo.Echo) { e.POST("/embeddings", h.Embed) } @@ -70,7 +78,7 @@ func (h *EmbeddingsHandler) Register(e *echo.Echo) { // @Failure 400 {object} ErrorResponse // @Failure 501 {object} EmbeddingsResponse // @Failure 500 {object} ErrorResponse -// @Router /embeddings [post] +// @Router /embeddings [post]. func (h *EmbeddingsHandler) Embed(c echo.Context) error { var req EmbeddingsRequest if err := c.Bind(&req); err != nil { @@ -129,7 +137,7 @@ func (h *EmbeddingsHandler) Embed(c echo.Context) error { Usage: EmbeddingsUsage{ InputTokens: result.Usage.InputTokens, ImageTokens: result.Usage.ImageTokens, - Duration: result.Usage.Duration, + Duration: result.Usage.Duration, }, }) } diff --git a/internal/handlers/error.go b/internal/handlers/error.go index d2db6e95..755dd868 100644 --- a/internal/handlers/error.go +++ b/internal/handlers/error.go @@ -1,5 +1,6 @@ package handlers +// ErrorResponse is the standard API error body (message only). type ErrorResponse struct { Message string `json:"message"` } diff --git a/internal/handlers/fs.go b/internal/handlers/fs.go index 4c750db5..6ca0c4ef 100644 --- a/internal/handlers/fs.go +++ b/internal/handlers/fs.go @@ -5,16 +5,13 @@ import ( "context" "encoding/json" "errors" - "fmt" "io" "log/slog" "net/http" "os/exec" "path/filepath" - "runtime" "strings" "sync" - "time" "github.com/containerd/containerd/v2/pkg/namespaces" "github.com/containerd/errdefs" @@ -22,7 +19,6 @@ import ( sdkjsonrpc "github.com/modelcontextprotocol/go-sdk/jsonrpc" sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp" - ctr "github.com/memohai/memoh/internal/containerd" mcptools "github.com/memohai/memoh/internal/mcp" ) @@ -79,167 +75,6 @@ const ( mcpSessionInitStateReady ) -func (h *ContainerdHandler) getMCPSession(ctx context.Context, containerID string) (*mcpSession, error) { - h.mcpMu.Lock() - if sess, ok := h.mcpSess[containerID]; ok { - h.mcpMu.Unlock() - return sess, nil - } - h.mcpMu.Unlock() - - var sess *mcpSession - var err error - if runtime.GOOS == "darwin" { - sess, err = h.startLimaMCPSession(containerID) - } - if err != nil || sess == nil { - sess, err = h.startContainerdMCPSession(ctx, containerID) - if err != nil { - return nil, err - } - } - - h.mcpMu.Lock() - h.mcpSess[containerID] = sess - h.mcpMu.Unlock() - - sess.onClose = func() { - h.mcpMu.Lock() - if current, ok := h.mcpSess[containerID]; ok && current == sess { - delete(h.mcpSess, containerID) - } - h.mcpMu.Unlock() - } - - return sess, nil -} - -func (h *ContainerdHandler) startContainerdMCPSession(ctx context.Context, containerID string) (*mcpSession, error) { - execSession, err := h.service.ExecTaskStreaming(ctx, containerID, ctr.ExecTaskRequest{ - Args: []string{"/app/mcp"}, - FIFODir: h.mcpFIFODir(), - }) - if err != nil { - return nil, err - } - - sess := &mcpSession{ - stdin: execSession.Stdin, - stdout: execSession.Stdout, - stderr: execSession.Stderr, - pending: make(map[string]chan *sdkjsonrpc.Response), - closed: make(chan struct{}), - } - transport := &sdkmcp.IOTransport{ - Reader: sess.stdout, - Writer: sess.stdin, - } - conn, err := transport.Connect(ctx) - if err != nil { - sess.closeWithError(err) - return nil, err - } - sess.conn = conn - - h.startMCPStderrLogger(execSession.Stderr, containerID) - go sess.readLoop() - go func() { - _, err := execSession.Wait() - if err != nil { - if isBenignMCPSessionExit(err) { - sess.closeWithError(io.EOF) - return - } - h.logger.Error("mcp session exited", slog.Any("error", err), slog.String("container_id", containerID)) - sess.closeWithError(err) - return - } - sess.closeWithError(io.EOF) - }() - - return sess, nil -} - -func (h *ContainerdHandler) startLimaMCPSession(containerID string) (*mcpSession, error) { - execID := fmt.Sprintf("mcp-%d", time.Now().UnixNano()) - cmd := exec.Command( - "limactl", - "shell", - "--tty=false", - "default", - "--", - "sudo", - "-n", - "ctr", - "-n", - "default", - "tasks", - "exec", - "--exec-id", - execID, - containerID, - "/app/mcp", - ) - - stdin, err := cmd.StdinPipe() - if err != nil { - return nil, err - } - stdout, err := cmd.StdoutPipe() - if err != nil { - _ = stdin.Close() - return nil, err - } - stderr, err := cmd.StderrPipe() - if err != nil { - _ = stdin.Close() - _ = stdout.Close() - return nil, err - } - if err := cmd.Start(); err != nil { - _ = stdin.Close() - _ = stdout.Close() - _ = stderr.Close() - return nil, err - } - - sess := &mcpSession{ - stdin: stdin, - stdout: stdout, - stderr: stderr, - cmd: cmd, - pending: make(map[string]chan *sdkjsonrpc.Response), - closed: make(chan struct{}), - } - transport := &sdkmcp.IOTransport{ - Reader: sess.stdout, - Writer: sess.stdin, - } - conn, err := transport.Connect(context.Background()) - if err != nil { - sess.closeWithError(err) - return nil, err - } - sess.conn = conn - - h.startMCPStderrLogger(stderr, containerID) - go sess.readLoop() - go func() { - if err := cmd.Wait(); err != nil { - if isBenignMCPSessionExit(err) { - sess.closeWithError(io.EOF) - return - } - h.logger.Error("mcp session exited", slog.Any("error", err), slog.String("container_id", containerID)) - sess.closeWithError(err) - return - } - sess.closeWithError(io.EOF) - }() - - return sess, nil -} - func (s *mcpSession) closeWithError(err error) { s.closeOnce.Do(func() { s.closeErr = err @@ -365,7 +200,7 @@ func (s *mcpSession) call(ctx context.Context, req mcptools.JSONRPCRequest) (map } target := sdkIDKey(targetID) if target == "" { - return nil, fmt.Errorf("missing request id") + return nil, errors.New("missing request id") } if s.conn == nil { return nil, io.EOF @@ -429,7 +264,7 @@ func (s *mcpSession) callRaw(ctx context.Context, req mcptools.JSONRPCRequest) ( } target := sdkIDKey(targetID) if target == "" { - return nil, fmt.Errorf("missing request id") + return nil, errors.New("missing request id") } if s.conn == nil { return nil, io.EOF @@ -609,11 +444,11 @@ func (s *mcpSession) invokeCall(ctx context.Context, req *sdkjsonrpc.Request) (* return nil, io.EOF } if req == nil || !req.ID.IsValid() { - return nil, fmt.Errorf("missing request id") + return nil, errors.New("missing request id") } key := sdkIDKey(req.ID) if key == "" { - return nil, fmt.Errorf("invalid request id") + return nil, errors.New("invalid request id") } respCh := make(chan *sdkjsonrpc.Response, 1) @@ -665,7 +500,7 @@ func (s *mcpSession) setInitStateAtLeast(next mcpSessionInitState) { func parseRawJSONRPCID(raw json.RawMessage) (sdkjsonrpc.ID, error) { if len(raw) == 0 { - return sdkjsonrpc.ID{}, fmt.Errorf("missing request id") + return sdkjsonrpc.ID{}, errors.New("missing request id") } var idValue any if err := json.Unmarshal(raw, &idValue); err != nil { @@ -676,7 +511,7 @@ func parseRawJSONRPCID(raw json.RawMessage) (sdkjsonrpc.ID, error) { return sdkjsonrpc.ID{}, err } if !id.IsValid() { - return sdkjsonrpc.ID{}, fmt.Errorf("missing request id") + return sdkjsonrpc.ID{}, errors.New("missing request id") } return id, nil } @@ -710,7 +545,8 @@ func sdkResponsePayload(resp *sdkjsonrpc.Response) (map[string]any, error) { if resp.Error != nil { code := int64(-32603) message := strings.TrimSpace(resp.Error.Error()) - if wireErr, ok := resp.Error.(*sdkjsonrpc.Error); ok { + wireErr := &sdkjsonrpc.Error{} + if errors.As(resp.Error, &wireErr) { code = wireErr.Code message = strings.TrimSpace(wireErr.Message) } diff --git a/internal/handlers/fs_mcp_session_test.go b/internal/handlers/fs_mcp_session_test.go index 3ef000ca..f8ecc440 100644 --- a/internal/handlers/fs_mcp_session_test.go +++ b/internal/handlers/fs_mcp_session_test.go @@ -228,7 +228,7 @@ func TestMCPSessionExplicitInitializeDoesNotDuplicateInitialize(t *testing.T) { } func TestMCPSessionRemovesPendingOnContextCancel(t *testing.T) { - conn := newFakeMCPConnection(func(req *sdkjsonrpc.Request) (*sdkjsonrpc.Response, error) { + conn := newFakeMCPConnection(func(_ *sdkjsonrpc.Request) (*sdkjsonrpc.Response, error) { // Intentionally do not reply; caller should timeout. return nil, nil }) diff --git a/internal/handlers/local_channel.go b/internal/handlers/local_channel.go index 6eb524ec..7a02255f 100644 --- a/internal/handlers/local_channel.go +++ b/internal/handlers/local_channel.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "fmt" + "log/slog" "net/http" "strings" "time" @@ -20,7 +21,7 @@ import ( // LocalChannelHandler handles local channel (CLI/Web) routes backed by bot history. type LocalChannelHandler struct { - channelType channel.ChannelType + channelType channel.Type channelManager *channel.Manager channelService *channel.Service chatService *conversation.Service @@ -30,7 +31,7 @@ type LocalChannelHandler struct { } // NewLocalChannelHandler creates a local channel handler. -func NewLocalChannelHandler(channelType channel.ChannelType, channelManager *channel.Manager, channelService *channel.Service, chatService *conversation.Service, routeHub *local.RouteHub, botService *bots.Service, accountService *accounts.Service) *LocalChannelHandler { +func NewLocalChannelHandler(channelType channel.Type, channelManager *channel.Manager, channelService *channel.Service, chatService *conversation.Service, routeHub *local.RouteHub, botService *bots.Service, accountService *accounts.Service) *LocalChannelHandler { return &LocalChannelHandler{ channelType: channelType, channelManager: channelManager, @@ -44,7 +45,7 @@ func NewLocalChannelHandler(channelType channel.ChannelType, channelManager *cha // Register registers the local channel routes. func (h *LocalChannelHandler) Register(e *echo.Echo) { - prefix := fmt.Sprintf("/bots/:bot_id/%s", h.channelType.String()) + prefix := "/bots/:bot_id/" + h.channelType.String() group := e.Group(prefix) group.GET("/stream", h.StreamMessages) group.POST("/messages", h.PostMessage) @@ -100,10 +101,12 @@ func (h *LocalChannelHandler) StreamMessages(c echo.Context) error { if err != nil { continue } - if _, err := writer.WriteString(fmt.Sprintf("data: %s\n\n", string(data))); err != nil { + if _, err := fmt.Fprintf(writer, "data: %s\n\n", string(data)); err != nil { return nil // client disconnected } - writer.Flush() + if err := writer.Flush(); err != nil { + slog.Default().Warn("local channel: flush failed", slog.Any("error", err)) + } flusher.Flush() } } diff --git a/internal/handlers/mcp.go b/internal/handlers/mcp.go index 3d1e1dd7..fc2baa7e 100644 --- a/internal/handlers/mcp.go +++ b/internal/handlers/mcp.go @@ -15,6 +15,7 @@ import ( "github.com/memohai/memoh/internal/mcp" ) +// MCPHandler serves /bots/:bot_id/mcp and /bots/:bot_id/mcp-ops (list, create, import, export, batch-delete). type MCPHandler struct { service *mcp.ConnectionService botService *bots.Service @@ -22,6 +23,7 @@ type MCPHandler struct { logger *slog.Logger } +// NewMCPHandler creates an MCP handler. func NewMCPHandler(log *slog.Logger, service *mcp.ConnectionService, botService *bots.Service, accountService *accounts.Service) *MCPHandler { return &MCPHandler{ service: service, @@ -31,6 +33,7 @@ func NewMCPHandler(log *slog.Logger, service *mcp.ConnectionService, botService } } +// Register mounts /bots/:bot_id/mcp and /bots/:bot_id/mcp-ops on the Echo instance. func (h *MCPHandler) Register(e *echo.Echo) { group := e.Group("/bots/:bot_id/mcp") group.GET("", h.List) @@ -54,7 +57,7 @@ func (h *MCPHandler) Register(e *echo.Echo) { // @Failure 403 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/mcp [get] +// @Router /bots/{bot_id}/mcp [get]. func (h *MCPHandler) List(c echo.Context) error { userID, err := h.requireChannelIdentityID(c) if err != nil { @@ -84,7 +87,7 @@ func (h *MCPHandler) List(c echo.Context) error { // @Failure 403 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/mcp [post] +// @Router /bots/{bot_id}/mcp [post]. func (h *MCPHandler) Create(c echo.Context) error { userID, err := h.requireChannelIdentityID(c) if err != nil { @@ -118,7 +121,7 @@ func (h *MCPHandler) Create(c echo.Context) error { // @Failure 403 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/mcp/{id} [get] +// @Router /bots/{bot_id}/mcp/{id} [get]. func (h *MCPHandler) Get(c echo.Context) error { userID, err := h.requireChannelIdentityID(c) if err != nil { @@ -156,7 +159,7 @@ func (h *MCPHandler) Get(c echo.Context) error { // @Failure 403 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/mcp/{id} [put] +// @Router /bots/{bot_id}/mcp/{id} [put]. func (h *MCPHandler) Update(c echo.Context) error { userID, err := h.requireChannelIdentityID(c) if err != nil { @@ -197,7 +200,7 @@ func (h *MCPHandler) Update(c echo.Context) error { // @Failure 403 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/mcp/{id} [delete] +// @Router /bots/{bot_id}/mcp/{id} [delete]. func (h *MCPHandler) Delete(c echo.Context) error { userID, err := h.requireChannelIdentityID(c) if err != nil { @@ -229,7 +232,7 @@ func (h *MCPHandler) Delete(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 403 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/mcp/import [put] +// @Router /bots/{bot_id}/mcp/import [put]. func (h *MCPHandler) Import(c echo.Context) error { userID, err := h.requireChannelIdentityID(c) if err != nil { @@ -267,7 +270,7 @@ type BatchDeleteRequest struct { // @Failure 400 {object} ErrorResponse // @Failure 403 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/mcp-ops/batch-delete [post] +// @Router /bots/{bot_id}/mcp-ops/batch-delete [post]. func (h *MCPHandler) BatchDelete(c echo.Context) error { userID, err := h.requireChannelIdentityID(c) if err != nil { @@ -301,7 +304,7 @@ func (h *MCPHandler) BatchDelete(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 403 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/mcp/export [get] +// @Router /bots/{bot_id}/mcp/export [get]. func (h *MCPHandler) Export(c echo.Context) error { userID, err := h.requireChannelIdentityID(c) if err != nil { diff --git a/internal/handlers/mcp_federation_gateway.go b/internal/handlers/mcp_federation_gateway.go index ea9bd57e..56a8fd30 100644 --- a/internal/handlers/mcp_federation_gateway.go +++ b/internal/handlers/mcp_federation_gateway.go @@ -3,6 +3,7 @@ package handlers import ( "context" "encoding/json" + "errors" "fmt" "io" "log/slog" @@ -10,16 +11,19 @@ import ( "strings" "time" - mcpgw "github.com/memohai/memoh/internal/mcp" sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp" + + mcpgw "github.com/memohai/memoh/internal/mcp" ) +// MCPFederationGateway lists and calls tools on HTTP MCP connections (streamable session). type MCPFederationGateway struct { handler *ContainerdHandler logger *slog.Logger client *http.Client } +// NewMCPFederationGateway creates a federation gateway backed by the containerd handler. func NewMCPFederationGateway(log *slog.Logger, handler *ContainerdHandler) *MCPFederationGateway { if log == nil { log = slog.Default() @@ -33,6 +37,7 @@ func NewMCPFederationGateway(log *slog.Logger, handler *ContainerdHandler) *MCPF } } +// ListHTTPConnectionTools returns tools from the HTTP MCP connection via streamable session. func (g *MCPFederationGateway) ListHTTPConnectionTools(ctx context.Context, connection mcpgw.Connection) ([]mcpgw.ToolDescriptor, error) { session, err := g.connectStreamableSession(ctx, connection) if err != nil { @@ -46,6 +51,7 @@ func (g *MCPFederationGateway) ListHTTPConnectionTools(ctx context.Context, conn return convertSDKTools(result.Tools), nil } +// CallHTTPConnectionTool invokes the named tool on the HTTP MCP connection via streamable session. func (g *MCPFederationGateway) CallHTTPConnectionTool(ctx context.Context, connection mcpgw.Connection, toolName string, args map[string]any) (map[string]any, error) { session, err := g.connectStreamableSession(ctx, connection) if err != nil { @@ -62,6 +68,7 @@ func (g *MCPFederationGateway) CallHTTPConnectionTool(ctx context.Context, conne return wrapSDKToolResult(result) } +// ListSSEConnectionTools returns tools from the SSE MCP connection via streamable session. func (g *MCPFederationGateway) ListSSEConnectionTools(ctx context.Context, connection mcpgw.Connection) ([]mcpgw.ToolDescriptor, error) { session, err := g.connectSSESession(ctx, connection) if err != nil { @@ -75,6 +82,7 @@ func (g *MCPFederationGateway) ListSSEConnectionTools(ctx context.Context, conne return convertSDKTools(result.Tools), nil } +// CallSSEConnectionTool invokes the named tool on the SSE MCP connection via streamable session. func (g *MCPFederationGateway) CallSSEConnectionTool(ctx context.Context, connection mcpgw.Connection, toolName string, args map[string]any) (map[string]any, error) { session, err := g.connectSSESession(ctx, connection) if err != nil { @@ -94,7 +102,7 @@ func (g *MCPFederationGateway) CallSSEConnectionTool(ctx context.Context, connec func (g *MCPFederationGateway) connectStreamableSession(ctx context.Context, connection mcpgw.Connection) (*sdkmcp.ClientSession, error) { url := strings.TrimSpace(anyToString(connection.Config["url"])) if url == "" { - return nil, fmt.Errorf("http mcp url is required") + return nil, errors.New("http mcp url is required") } client := sdkmcp.NewClient(&sdkmcp.Implementation{ Name: "memoh-federation-client", @@ -111,7 +119,7 @@ func (g *MCPFederationGateway) connectStreamableSession(ctx context.Context, con func (g *MCPFederationGateway) connectSSESession(ctx context.Context, connection mcpgw.Connection) (*sdkmcp.ClientSession, error) { endpoints := resolveSSEEndpointCandidates(connection.Config) if len(endpoints) == 0 { - return nil, fmt.Errorf("sse mcp url is required") + return nil, errors.New("sse mcp url is required") } var lastErr error for _, endpoint := range endpoints { @@ -130,7 +138,7 @@ func (g *MCPFederationGateway) connectSSESession(ctx context.Context, connection lastErr = err } if lastErr == nil { - lastErr = fmt.Errorf("no sse endpoint candidate available") + lastErr = errors.New("no sse endpoint candidate available") } return nil, fmt.Errorf("connect sse mcp failed: %w", lastErr) } @@ -170,16 +178,16 @@ func resolveSSEEndpointCandidates(config map[string]any) []string { } if messageURL != "" { normalized := strings.TrimSuffix(messageURL, "/") - if strings.HasSuffix(normalized, "/message") { - appendEndpoint(strings.TrimSuffix(normalized, "/message") + "/sse") + if before, ok := strings.CutSuffix(normalized, "/message"); ok { + appendEndpoint(before + "/sse") } appendEndpoint(messageURL) } if baseURL != "" { normalized := strings.TrimSuffix(baseURL, "/") - if strings.HasSuffix(normalized, "/message") { - appendEndpoint(strings.TrimSuffix(normalized, "/message") + "/sse") + if before, ok := strings.CutSuffix(normalized, "/message"); ok { + appendEndpoint(before + "/sse") } } @@ -210,6 +218,7 @@ func (g *MCPFederationGateway) connectionHTTPClient(connection mcpgw.Connection) } } +// ListStdioConnectionTools returns tools from the stdio MCP connection (tools/list via session). func (g *MCPFederationGateway) ListStdioConnectionTools(ctx context.Context, botID string, connection mcpgw.Connection) ([]mcpgw.ToolDescriptor, error) { sess, err := g.startStdioConnectionSession(ctx, botID, connection) if err != nil { @@ -228,6 +237,7 @@ func (g *MCPFederationGateway) ListStdioConnectionTools(ctx context.Context, bot return parseGatewayToolsListPayload(payload) } +// CallStdioConnectionTool invokes the named tool on the stdio MCP connection via session. func (g *MCPFederationGateway) CallStdioConnectionTool(ctx context.Context, botID string, connection mcpgw.Connection, toolName string, args map[string]any) (map[string]any, error) { sess, err := g.startStdioConnectionSession(ctx, botID, connection) if err != nil { @@ -252,7 +262,7 @@ func (g *MCPFederationGateway) CallStdioConnectionTool(ctx context.Context, botI func (g *MCPFederationGateway) startStdioConnectionSession(ctx context.Context, botID string, connection mcpgw.Connection) (*mcpSession, error) { if g.handler == nil { - return nil, fmt.Errorf("containerd handler not configured") + return nil, errors.New("containerd handler not configured") } containerID, err := g.handler.botContainerID(ctx, botID) if err != nil { @@ -267,7 +277,7 @@ func (g *MCPFederationGateway) startStdioConnectionSession(ctx context.Context, command := strings.TrimSpace(anyToString(connection.Config["command"])) if command == "" { - return nil, fmt.Errorf("stdio mcp command is required") + return nil, errors.New("stdio mcp command is required") } request := MCPStdioRequest{ Name: strings.TrimSpace(connection.Name), @@ -285,11 +295,11 @@ func parseGatewayToolsListPayload(payload map[string]any) ([]mcpgw.ToolDescripto } result, ok := payload["result"].(map[string]any) if !ok { - return nil, fmt.Errorf("invalid tools/list result") + return nil, errors.New("invalid tools/list result") } rawTools, ok := result["tools"].([]any) if !ok { - return nil, fmt.Errorf("invalid tools/list tools field") + return nil, errors.New("invalid tools/list tools field") } tools := make([]mcpgw.ToolDescriptor, 0, len(rawTools)) for _, rawTool := range rawTools { diff --git a/internal/handlers/mcp_federation_gateway_test.go b/internal/handlers/mcp_federation_gateway_test.go index ff453626..d753447e 100644 --- a/internal/handlers/mcp_federation_gateway_test.go +++ b/internal/handlers/mcp_federation_gateway_test.go @@ -4,10 +4,12 @@ import ( "context" "net/http" "net/http/httptest" + "slices" "testing" - mcpgw "github.com/memohai/memoh/internal/mcp" sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp" + + mcpgw "github.com/memohai/memoh/internal/mcp" ) type testToolInput struct { @@ -26,7 +28,7 @@ func newTestMCPServer() *sdkmcp.Server { sdkmcp.AddTool(server, &sdkmcp.Tool{ Name: "echo", Description: "Echo query", - }, func(ctx context.Context, request *sdkmcp.CallToolRequest, input testToolInput) (*sdkmcp.CallToolResult, testToolOutput, error) { + }, func(_ context.Context, _ *sdkmcp.CallToolRequest, input testToolInput) (*sdkmcp.CallToolResult, testToolOutput, error) { return nil, testToolOutput{Echo: input.Query}, nil }) return server @@ -158,13 +160,7 @@ func TestResolveSSEEndpointCandidatesCompatibility(t *testing.T) { if got[0] != tt.firstWant { t.Fatalf("unexpected first endpoint: got=%s want=%s", got[0], tt.firstWant) } - found := false - for _, item := range got { - if item == tt.contains { - found = true - break - } - } + found := slices.Contains(got, tt.contains) if !found { t.Fatalf("endpoint candidates missing expected value: %s in %#v", tt.contains, got) } diff --git a/internal/handlers/mcp_stdio.go b/internal/handlers/mcp_stdio.go index 51fff854..2bcc6d8d 100644 --- a/internal/handlers/mcp_stdio.go +++ b/internal/handlers/mcp_stdio.go @@ -21,6 +21,7 @@ import ( mcptools "github.com/memohai/memoh/internal/mcp" ) +// MCPStdioRequest is the body for creating a stdio MCP proxy (name, command, args, env, cwd). type MCPStdioRequest struct { Name string `json:"name"` Command string `json:"command"` @@ -29,6 +30,7 @@ type MCPStdioRequest struct { Cwd string `json:"cwd"` } +// MCPStdioResponse returns connection_id, url, and optional tools list after creating the proxy. type MCPStdioResponse struct { ConnectionID string `json:"connection_id"` URL string `json:"url"` @@ -55,7 +57,7 @@ type mcpStdioSession struct { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/mcp-stdio [post] +// @Router /bots/{bot_id}/mcp-stdio [post]. func (h *ContainerdHandler) CreateMCPStdio(c echo.Context) error { botID, err := h.requireBotAccess(c) if err != nil { @@ -124,7 +126,7 @@ func (h *ContainerdHandler) CreateMCPStdio(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/mcp-stdio/{connection_id} [post] +// @Router /bots/{bot_id}/mcp-stdio/{connection_id} [post]. func (h *ContainerdHandler) HandleMCPStdio(c echo.Context) error { botID, err := h.requireBotAccess(c) if err != nil { diff --git a/internal/handlers/mcp_tools.go b/internal/handlers/mcp_tools.go index b4312699..83346e10 100644 --- a/internal/handlers/mcp_tools.go +++ b/internal/handlers/mcp_tools.go @@ -3,7 +3,7 @@ package handlers import ( "context" "encoding/json" - "fmt" + "errors" "net/http" "strings" @@ -21,6 +21,7 @@ const ( headerReplyTarget = "X-Memoh-Reply-Target" ) +// SetToolGatewayService sets the MCP tool gateway used by HandleMCPTools (inject after construction). func (h *ContainerdHandler) SetToolGatewayService(service *mcpgw.ToolGatewayService) { h.toolGateway = service } @@ -35,7 +36,7 @@ func (h *ContainerdHandler) SetToolGatewayService(service *mcpgw.ToolGatewayServ // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/tools [post] +// @Router /bots/{bot_id}/tools [post]. func (h *ContainerdHandler) HandleMCPTools(c echo.Context) error { if h.toolGateway == nil { return echo.NewHTTPError(http.StatusServiceUnavailable, "tool gateway not configured") @@ -141,7 +142,7 @@ func (h *ContainerdHandler) toolGatewayMiddleware(session mcpgw.ToolSessionConte case "tools/call": callReq, ok := req.(*sdkmcp.ServerRequest[*sdkmcp.CallToolParamsRaw]) if !ok || callReq == nil || callReq.Params == nil { - return nil, fmt.Errorf("tools/call params is required") + return nil, errors.New("tools/call params is required") } payload, err := buildToolCallPayloadFromRaw(callReq.Params) if err != nil { @@ -161,11 +162,11 @@ func (h *ContainerdHandler) toolGatewayMiddleware(session mcpgw.ToolSessionConte func buildToolCallPayloadFromRaw(params *sdkmcp.CallToolParamsRaw) (mcpgw.ToolCallPayload, error) { if params == nil { - return mcpgw.ToolCallPayload{}, fmt.Errorf("tools/call params is required") + return mcpgw.ToolCallPayload{}, errors.New("tools/call params is required") } name := strings.TrimSpace(params.Name) if name == "" { - return mcpgw.ToolCallPayload{}, fmt.Errorf("tools/call name is required") + return mcpgw.ToolCallPayload{}, errors.New("tools/call name is required") } arguments := map[string]any{} if len(params.Arguments) > 0 { diff --git a/internal/handlers/mcp_tools_test.go b/internal/handlers/mcp_tools_test.go index f9ea36b8..d3c2eb0d 100644 --- a/internal/handlers/mcp_tools_test.go +++ b/internal/handlers/mcp_tools_test.go @@ -3,6 +3,7 @@ package handlers import ( "context" "encoding/json" + "errors" "log/slog" "net/http" "net/http/httptest" @@ -55,7 +56,8 @@ func TestHandleMCPToolsWithoutGateway(t *testing.T) { if err == nil { t.Fatalf("expected service unavailable error") } - httpErr, ok := err.(*echo.HTTPError) + httpErr := &echo.HTTPError{} + ok := errors.As(err, &httpErr) if !ok { t.Fatalf("expected echo HTTP error, got %T", err) } @@ -68,7 +70,7 @@ type mcpToolsTestExecutor struct { lastSession mcpgw.ToolSessionContext } -func (e *mcpToolsTestExecutor) ListTools(ctx context.Context, session mcpgw.ToolSessionContext) ([]mcpgw.ToolDescriptor, error) { +func (e *mcpToolsTestExecutor) ListTools(_ context.Context, session mcpgw.ToolSessionContext) ([]mcpgw.ToolDescriptor, error) { e.lastSession = session return []mcpgw.ToolDescriptor{ { @@ -84,7 +86,7 @@ func (e *mcpToolsTestExecutor) ListTools(ctx context.Context, session mcpgw.Tool }, nil } -func (e *mcpToolsTestExecutor) CallTool(ctx context.Context, session mcpgw.ToolSessionContext, toolName string, arguments map[string]any) (map[string]any, error) { +func (e *mcpToolsTestExecutor) CallTool(_ context.Context, session mcpgw.ToolSessionContext, toolName string, arguments map[string]any) (map[string]any, error) { e.lastSession = session if strings.TrimSpace(toolName) != "echo_tool" { return nil, mcpgw.ErrToolNotFound diff --git a/internal/handlers/memory.go b/internal/handlers/memory.go index aa908b77..a9a0f499 100644 --- a/internal/handlers/memory.go +++ b/internal/handlers/memory.go @@ -20,7 +20,7 @@ type MemoryHandler struct { service *memory.Service chatService *conversation.Service accountService *accounts.Service - memoryFS *memory.MemoryFS + memoryFS *memory.FS logger *slog.Logger } @@ -73,7 +73,7 @@ func NewMemoryHandler(log *slog.Logger, service *memory.Service, chatService *co } // SetMemoryFS sets the optional filesystem persistence layer. -func (h *MemoryHandler) SetMemoryFS(fs *memory.MemoryFS) { +func (h *MemoryHandler) SetMemoryFS(fs *memory.FS) { h.memoryFS = fs } @@ -112,7 +112,7 @@ func (h *MemoryHandler) checkService() error { // @Failure 403 {object} ErrorResponse // @Failure 500 {object} ErrorResponse // @Failure 503 {object} ErrorResponse -// @Router /bots/{bot_id}/memory [post] +// @Router /bots/{bot_id}/memory [post]. func (h *MemoryHandler) ChatAdd(c echo.Context) error { if err := h.checkService(); err != nil { return err @@ -190,7 +190,7 @@ func (h *MemoryHandler) ChatAdd(c echo.Context) error { // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse // @Failure 503 {object} ErrorResponse -// @Router /bots/{bot_id}/memory/search [post] +// @Router /bots/{bot_id}/memory/search [post]. func (h *MemoryHandler) ChatSearch(c echo.Context) error { if err := h.checkService(); err != nil { return err @@ -223,7 +223,7 @@ func (h *MemoryHandler) ChatSearch(c echo.Context) error { botID := strings.TrimSpace(chatObj.BotID) // Search shared namespace and merge results. - var allResults []memory.MemoryItem + var allResults []memory.Item for _, scope := range scopes { filters := buildNamespaceFilters(scope.Namespace, scope.ScopeID, payload.Filters) if botID != "" { @@ -271,7 +271,7 @@ func (h *MemoryHandler) ChatSearch(c echo.Context) error { // @Failure 403 {object} ErrorResponse // @Failure 500 {object} ErrorResponse // @Failure 503 {object} ErrorResponse -// @Router /bots/{bot_id}/memory [get] +// @Router /bots/{bot_id}/memory [get]. func (h *MemoryHandler) ChatGetAll(c echo.Context) error { if err := h.checkService(); err != nil { return err @@ -294,7 +294,7 @@ func (h *MemoryHandler) ChatGetAll(c echo.Context) error { return err } - var allResults []memory.MemoryItem + var allResults []memory.Item for _, scope := range scopes { req := memory.GetAllRequest{ Filters: buildNamespaceFilters(scope.Namespace, scope.ScopeID, nil), @@ -325,7 +325,7 @@ func (h *MemoryHandler) ChatGetAll(c echo.Context) error { // @Failure 403 {object} ErrorResponse // @Failure 500 {object} ErrorResponse // @Failure 503 {object} ErrorResponse -// @Router /bots/{bot_id}/memory [delete] +// @Router /bots/{bot_id}/memory [delete]. func (h *MemoryHandler) ChatDelete(c echo.Context) error { if err := h.checkService(); err != nil { return err @@ -395,7 +395,7 @@ func (h *MemoryHandler) ChatDelete(c echo.Context) error { // @Failure 403 {object} ErrorResponse // @Failure 500 {object} ErrorResponse // @Failure 503 {object} ErrorResponse -// @Router /bots/{bot_id}/memory/{id} [delete] +// @Router /bots/{bot_id}/memory/{id} [delete]. func (h *MemoryHandler) ChatDeleteOne(c echo.Context) error { if err := h.checkService(); err != nil { return err @@ -449,7 +449,7 @@ func (h *MemoryHandler) ChatDeleteOne(c echo.Context) error { // @Failure 403 {object} ErrorResponse // @Failure 500 {object} ErrorResponse // @Failure 503 {object} ErrorResponse -// @Router /bots/{bot_id}/memory/compact [post] +// @Router /bots/{bot_id}/memory/compact [post]. func (h *MemoryHandler) ChatCompact(c echo.Context) error { if err := h.checkService(); err != nil { return err @@ -516,7 +516,7 @@ func (h *MemoryHandler) ChatCompact(c echo.Context) error { // @Failure 403 {object} ErrorResponse // @Failure 500 {object} ErrorResponse // @Failure 503 {object} ErrorResponse -// @Router /bots/{bot_id}/memory/usage [get] +// @Router /bots/{bot_id}/memory/usage [get]. func (h *MemoryHandler) ChatUsage(c echo.Context) error { if err := h.checkService(); err != nil { return err @@ -567,7 +567,7 @@ func (h *MemoryHandler) ChatUsage(c echo.Context) error { // @Failure 403 {object} ErrorResponse // @Failure 500 {object} ErrorResponse // @Failure 503 {object} ErrorResponse -// @Router /bots/{bot_id}/memory/rebuild [post] +// @Router /bots/{bot_id}/memory/rebuild [post]. func (h *MemoryHandler) ChatRebuild(c echo.Context) error { if err := h.checkService(); err != nil { return err @@ -716,12 +716,12 @@ func buildNamespaceFilters(namespace, scopeID string, extra map[string]any) map[ return filters } -func deduplicateMemoryItems(items []memory.MemoryItem) []memory.MemoryItem { +func deduplicateMemoryItems(items []memory.Item) []memory.Item { if len(items) == 0 { return items } seen := make(map[string]struct{}, len(items)) - result := make([]memory.MemoryItem, 0, len(items)) + result := make([]memory.Item, 0, len(items)) for _, item := range items { if _, ok := seen[item.ID]; ok { continue diff --git a/internal/handlers/message.go b/internal/handlers/message.go index 7ac36048..d9564cf0 100644 --- a/internal/handlers/message.go +++ b/internal/handlers/message.go @@ -105,6 +105,7 @@ func (h *MessageHandler) SendMessage(c echo.Context) error { req.Channels = []string{req.CurrentChannel} } channelIdentityID = h.resolveWebChannelIdentity(c.Request().Context(), channelIdentityID, &req) + req.UserID = channelIdentityID if h.runner == nil { return echo.NewHTTPError(http.StatusInternalServerError, "conversation runner not configured") @@ -155,6 +156,7 @@ func (h *MessageHandler) StreamMessage(c echo.Context) error { req.Channels = []string{req.CurrentChannel} } channelIdentityID = h.resolveWebChannelIdentity(c.Request().Context(), channelIdentityID, &req) + req.UserID = channelIdentityID if h.runner == nil { return echo.NewHTTPError(http.StatusInternalServerError, "conversation runner not configured") @@ -180,7 +182,6 @@ func (h *MessageHandler) StreamMessage(c echo.Context) error { case chunk, ok := <-chunkChan: if !ok { if processingState == "started" { - processingState = "completed" if err := writeSSEJSON(writer, flusher, map[string]string{"type": "processing_completed"}); err != nil { return nil } @@ -203,7 +204,6 @@ func (h *MessageHandler) StreamMessage(c echo.Context) error { if err != nil { h.logger.Error("conversation stream failed", slog.Any("error", err)) if processingState == "started" { - processingState = "failed" if writeErr := writeSSEJSON(writer, flusher, map[string]string{ "type": "processing_failed", "error": err.Error(), @@ -226,7 +226,7 @@ func (h *MessageHandler) StreamMessage(c echo.Context) error { } func writeSSEData(writer *bufio.Writer, flusher http.Flusher, payload string) error { - if _, err := writer.WriteString(fmt.Sprintf("data: %s\n\n", payload)); err != nil { + if _, err := fmt.Fprintf(writer, "data: %s\n\n", payload); err != nil { return err } if err := writer.Flush(); err != nil { @@ -259,7 +259,7 @@ func parseSinceParam(raw string) (time.Time, bool, error) { if epochMillis, err := strconv.ParseInt(trimmed, 10, 64); err == nil { return time.UnixMilli(epochMillis).UTC(), true, nil } - return time.Time{}, false, fmt.Errorf("invalid since parameter") + return time.Time{}, false, errors.New("invalid since parameter") } // ListMessages godoc @@ -274,7 +274,7 @@ func parseSinceParam(raw string) (time.Time, bool, error) { // @Failure 400 {object} ErrorResponse // @Failure 403 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/messages [get] +// @Router /bots/{bot_id}/messages [get]. func (h *MessageHandler) ListMessages(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -391,7 +391,7 @@ func (h *MessageHandler) StreamMessageEvents(c echo.Context) error { sentMessageIDs[msgID] = struct{}{} } return writeSSEJSON(writer, flusher, map[string]any{ - "type": string(messageevent.EventTypeMessageCreated), + "type": string(messageevent.TypeMessageCreated), "bot_id": botID, "message": message, }) @@ -430,7 +430,7 @@ func (h *MessageHandler) StreamMessageEvents(c echo.Context) error { if strings.TrimSpace(event.BotID) != botID { continue } - if event.Type != messageevent.EventTypeMessageCreated { + if event.Type != messageevent.TypeMessageCreated { continue } if len(event.Data) == 0 { diff --git a/internal/handlers/models.go b/internal/handlers/models.go index 58fb3a77..d60c74e3 100644 --- a/internal/handlers/models.go +++ b/internal/handlers/models.go @@ -10,11 +10,13 @@ import ( "github.com/memohai/memoh/internal/models" ) +// ModelsHandler serves /models CRUD and list/count APIs. type ModelsHandler struct { service *models.Service logger *slog.Logger } +// NewModelsHandler creates a models handler. func NewModelsHandler(log *slog.Logger, service *models.Service) *ModelsHandler { return &ModelsHandler{ service: service, @@ -22,6 +24,7 @@ func NewModelsHandler(log *slog.Logger, service *models.Service) *ModelsHandler } } +// Register mounts /models routes on the Echo instance. func (h *ModelsHandler) Register(e *echo.Echo) { group := e.Group("/models") group.POST("", h.Create) @@ -43,7 +46,7 @@ func (h *ModelsHandler) Register(e *echo.Echo) { // @Success 201 {object} models.AddResponse // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /models [post] +// @Router /models [post]. func (h *ModelsHandler) Create(c echo.Context) error { var req models.AddRequest if err := c.Bind(&req); err != nil { @@ -66,7 +69,7 @@ func (h *ModelsHandler) Create(c echo.Context) error { // @Success 200 {array} models.GetResponse // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /models [get] +// @Router /models [get]. func (h *ModelsHandler) List(c echo.Context) error { modelType := c.QueryParam("type") clientType := c.QueryParam("client_type") @@ -74,11 +77,12 @@ func (h *ModelsHandler) List(c echo.Context) error { var resp []models.GetResponse var err error - if modelType != "" { + switch { + case modelType != "": resp, err = h.service.ListByType(c.Request().Context(), models.ModelType(modelType)) - } else if clientType != "" { + case clientType != "": resp, err = h.service.ListByClientType(c.Request().Context(), models.ClientType(clientType)) - } else { + default: resp, err = h.service.List(c.Request().Context()) } @@ -97,7 +101,7 @@ func (h *ModelsHandler) List(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /models/{id} [get] +// @Router /models/{id} [get]. func (h *ModelsHandler) GetByID(c echo.Context) error { id := c.Param("id") if id == "" { @@ -120,7 +124,7 @@ func (h *ModelsHandler) GetByID(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /models/model/{modelId} [get] +// @Router /models/model/{modelId} [get]. func (h *ModelsHandler) GetByModelID(c echo.Context) error { modelID := c.Param("modelId") if modelID == "" { @@ -149,7 +153,7 @@ func (h *ModelsHandler) GetByModelID(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /models/{id} [put] +// @Router /models/{id} [put]. func (h *ModelsHandler) UpdateByID(c echo.Context) error { id := c.Param("id") if id == "" { @@ -178,7 +182,7 @@ func (h *ModelsHandler) UpdateByID(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /models/model/{modelId} [put] +// @Router /models/model/{modelId} [put]. func (h *ModelsHandler) UpdateByModelID(c echo.Context) error { modelID := c.Param("modelId") if modelID == "" { @@ -211,7 +215,7 @@ func (h *ModelsHandler) UpdateByModelID(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /models/{id} [delete] +// @Router /models/{id} [delete]. func (h *ModelsHandler) DeleteByID(c echo.Context) error { id := c.Param("id") if id == "" { @@ -233,7 +237,7 @@ func (h *ModelsHandler) DeleteByID(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /models/model/{modelId} [delete] +// @Router /models/model/{modelId} [delete]. func (h *ModelsHandler) DeleteByModelID(c echo.Context) error { modelID := c.Param("modelId") if modelID == "" { @@ -259,7 +263,7 @@ func (h *ModelsHandler) DeleteByModelID(c echo.Context) error { // @Success 200 {object} models.CountResponse // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /models/count [get] +// @Router /models/count [get]. func (h *ModelsHandler) Count(c echo.Context) error { modelType := c.QueryParam("type") diff --git a/internal/handlers/ping.go b/internal/handlers/ping.go index e21ff554..577e9698 100644 --- a/internal/handlers/ping.go +++ b/internal/handlers/ping.go @@ -7,25 +7,30 @@ import ( "github.com/labstack/echo/v4" ) +// PingHandler serves /ping and HEAD /health for liveness. type PingHandler struct { logger *slog.Logger } +// NewPingHandler creates a ping handler. func NewPingHandler(log *slog.Logger) *PingHandler { return &PingHandler{logger: log.With(slog.String("handler", "ping"))} } +// Register mounts GET /ping and HEAD /health on the Echo instance. func (h *PingHandler) Register(e *echo.Echo) { e.GET("/ping", h.Ping) e.HEAD("/health", h.PingHead) } +// Ping returns 200 JSON {"status":"ok"}. func (h *PingHandler) Ping(c echo.Context) error { return c.JSON(http.StatusOK, map[string]string{ "status": "ok", }) } +// PingHead returns 200 No Content for health checks. func (h *PingHandler) PingHead(c echo.Context) error { return c.NoContent(http.StatusOK) } diff --git a/internal/handlers/preauth.go b/internal/handlers/preauth.go index eb6e6a90..2586c436 100644 --- a/internal/handlers/preauth.go +++ b/internal/handlers/preauth.go @@ -13,12 +13,14 @@ import ( "github.com/memohai/memoh/internal/preauth" ) +// PreauthHandler serves POST /bots/:bot_id/preauth_keys for issuing preauth keys. type PreauthHandler struct { service *preauth.Service botService *bots.Service accountService *accounts.Service } +// NewPreauthHandler creates a preauth handler. func NewPreauthHandler(service *preauth.Service, botService *bots.Service, accountService *accounts.Service) *PreauthHandler { return &PreauthHandler{ service: service, @@ -27,6 +29,7 @@ func NewPreauthHandler(service *preauth.Service, botService *bots.Service, accou } } +// Register mounts POST /bots/:bot_id/preauth_keys on the Echo instance. func (h *PreauthHandler) Register(e *echo.Echo) { group := e.Group("/bots/:bot_id/preauth_keys") group.POST("", h.Issue) @@ -36,6 +39,7 @@ type preauthIssueRequest struct { TTLSeconds int `json:"ttl_seconds"` } +// Issue creates a preauth key for the bot and returns it (requires bot access). func (h *PreauthHandler) Issue(c echo.Context) error { userID, err := h.requireUserID(c) if err != nil { diff --git a/internal/handlers/providers.go b/internal/handlers/providers.go index 22864ac1..a17095b8 100644 --- a/internal/handlers/providers.go +++ b/internal/handlers/providers.go @@ -11,12 +11,14 @@ import ( "github.com/memohai/memoh/internal/providers" ) +// ProvidersHandler serves /providers CRUD, list, count, and list-models APIs. type ProvidersHandler struct { service *providers.Service modelsService *models.Service logger *slog.Logger } +// NewProvidersHandler creates a providers handler. func NewProvidersHandler(log *slog.Logger, service *providers.Service, modelsService *models.Service) *ProvidersHandler { return &ProvidersHandler{ service: service, @@ -25,6 +27,7 @@ func NewProvidersHandler(log *slog.Logger, service *providers.Service, modelsSer } } +// Register mounts /providers routes on the Echo instance. func (h *ProvidersHandler) Register(e *echo.Echo) { group := e.Group("/providers") group.POST("", h.Create) @@ -47,7 +50,7 @@ func (h *ProvidersHandler) Register(e *echo.Echo) { // @Success 201 {object} providers.GetResponse // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /providers [post] +// @Router /providers [post]. func (h *ProvidersHandler) Create(c echo.Context) error { var req providers.CreateRequest if err := c.Bind(&req); err != nil { @@ -83,7 +86,7 @@ func (h *ProvidersHandler) Create(c echo.Context) error { // @Success 200 {array} providers.GetResponse // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /providers [get] +// @Router /providers [get]. func (h *ProvidersHandler) List(c echo.Context) error { clientType := c.QueryParam("client_type") @@ -114,7 +117,7 @@ func (h *ProvidersHandler) List(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /providers/{id} [get] +// @Router /providers/{id} [get]. func (h *ProvidersHandler) Get(c echo.Context) error { id := c.Param("id") if id == "" { @@ -139,7 +142,7 @@ func (h *ProvidersHandler) Get(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /providers/{id}/models [get] +// @Router /providers/{id}/models [get]. func (h *ProvidersHandler) ListModelsByProvider(c echo.Context) error { if h.modelsService == nil { return echo.NewHTTPError(http.StatusInternalServerError, "models service not configured") @@ -178,7 +181,7 @@ func (h *ProvidersHandler) ListModelsByProvider(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /providers/name/{name} [get] +// @Router /providers/name/{name} [get]. func (h *ProvidersHandler) GetByName(c echo.Context) error { name := c.Param("name") if name == "" { @@ -205,7 +208,7 @@ func (h *ProvidersHandler) GetByName(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /providers/{id} [put] +// @Router /providers/{id} [put]. func (h *ProvidersHandler) Update(c echo.Context) error { id := c.Param("id") if id == "" { @@ -236,7 +239,7 @@ func (h *ProvidersHandler) Update(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /providers/{id} [delete] +// @Router /providers/{id} [delete]. func (h *ProvidersHandler) Delete(c echo.Context) error { id := c.Param("id") if id == "" { @@ -260,7 +263,7 @@ func (h *ProvidersHandler) Delete(c echo.Context) error { // @Success 200 {object} providers.CountResponse // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /providers/count [get] +// @Router /providers/count [get]. func (h *ProvidersHandler) Count(c echo.Context) error { clientType := c.QueryParam("client_type") diff --git a/internal/handlers/schedule.go b/internal/handlers/schedule.go index 00eebcc5..4bc7b68d 100644 --- a/internal/handlers/schedule.go +++ b/internal/handlers/schedule.go @@ -13,6 +13,7 @@ import ( "github.com/memohai/memoh/internal/schedule" ) +// ScheduleHandler serves /bots/:bot_id/schedule CRUD APIs. type ScheduleHandler struct { service *schedule.Service botService *bots.Service @@ -20,6 +21,7 @@ type ScheduleHandler struct { logger *slog.Logger } +// NewScheduleHandler creates a schedule handler. func NewScheduleHandler(log *slog.Logger, service *schedule.Service, botService *bots.Service, accountService *accounts.Service) *ScheduleHandler { return &ScheduleHandler{ service: service, @@ -29,6 +31,7 @@ func NewScheduleHandler(log *slog.Logger, service *schedule.Service, botService } } +// Register mounts /bots/:bot_id/schedule routes on the Echo instance. func (h *ScheduleHandler) Register(e *echo.Echo) { group := e.Group("/bots/:bot_id/schedule") group.POST("", h.Create) @@ -46,7 +49,7 @@ func (h *ScheduleHandler) Register(e *echo.Echo) { // @Success 201 {object} schedule.Schedule // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/schedule [post] +// @Router /bots/{bot_id}/schedule [post]. func (h *ScheduleHandler) Create(c echo.Context) error { userID, err := h.requireUserID(c) if err != nil { @@ -77,7 +80,7 @@ func (h *ScheduleHandler) Create(c echo.Context) error { // @Success 200 {object} schedule.ListResponse // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/schedule [get] +// @Router /bots/{bot_id}/schedule [get]. func (h *ScheduleHandler) List(c echo.Context) error { userID, err := h.requireUserID(c) if err != nil { @@ -106,7 +109,7 @@ func (h *ScheduleHandler) List(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/schedule/{id} [get] +// @Router /bots/{bot_id}/schedule/{id} [get]. func (h *ScheduleHandler) Get(c echo.Context) error { userID, err := h.requireUserID(c) if err != nil { @@ -142,7 +145,7 @@ func (h *ScheduleHandler) Get(c echo.Context) error { // @Success 200 {object} schedule.Schedule // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/schedule/{id} [put] +// @Router /bots/{bot_id}/schedule/{id} [put]. func (h *ScheduleHandler) Update(c echo.Context) error { userID, err := h.requireUserID(c) if err != nil { @@ -185,7 +188,7 @@ func (h *ScheduleHandler) Update(c echo.Context) error { // @Success 204 "No Content" // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/schedule/{id} [delete] +// @Router /bots/{bot_id}/schedule/{id} [delete]. func (h *ScheduleHandler) Delete(c echo.Context) error { userID, err := h.requireUserID(c) if err != nil { diff --git a/internal/handlers/search_providers.go b/internal/handlers/search_providers.go index 9289fa8f..deb0e00c 100644 --- a/internal/handlers/search_providers.go +++ b/internal/handlers/search_providers.go @@ -10,11 +10,13 @@ import ( "github.com/memohai/memoh/internal/searchproviders" ) +// SearchProvidersHandler serves /search-providers CRUD and /meta APIs. type SearchProvidersHandler struct { service *searchproviders.Service logger *slog.Logger } +// NewSearchProvidersHandler creates a search providers handler. func NewSearchProvidersHandler(log *slog.Logger, service *searchproviders.Service) *SearchProvidersHandler { return &SearchProvidersHandler{ service: service, @@ -22,6 +24,7 @@ func NewSearchProvidersHandler(log *slog.Logger, service *searchproviders.Servic } } +// Register mounts /search-providers routes on the Echo instance. func (h *SearchProvidersHandler) Register(e *echo.Echo) { group := e.Group("/search-providers") group.GET("/meta", h.ListMeta) @@ -37,7 +40,7 @@ func (h *SearchProvidersHandler) Register(e *echo.Echo) { // @Description List available search provider types and config schemas // @Tags search-providers // @Success 200 {array} searchproviders.ProviderMeta -// @Router /search-providers/meta [get] +// @Router /search-providers/meta [get]. func (h *SearchProvidersHandler) ListMeta(c echo.Context) error { return c.JSON(http.StatusOK, h.service.ListMeta(c.Request().Context())) } @@ -52,7 +55,7 @@ func (h *SearchProvidersHandler) ListMeta(c echo.Context) error { // @Success 201 {object} searchproviders.GetResponse // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /search-providers [post] +// @Router /search-providers [post]. func (h *SearchProvidersHandler) Create(c echo.Context) error { var req searchproviders.CreateRequest if err := c.Bind(&req); err != nil { @@ -80,7 +83,7 @@ func (h *SearchProvidersHandler) Create(c echo.Context) error { // @Param provider query string false "Provider filter (brave)" // @Success 200 {array} searchproviders.GetResponse // @Failure 500 {object} ErrorResponse -// @Router /search-providers [get] +// @Router /search-providers [get]. func (h *SearchProvidersHandler) List(c echo.Context) error { items, err := h.service.List(c.Request().Context(), c.QueryParam("provider")) if err != nil { @@ -99,7 +102,7 @@ func (h *SearchProvidersHandler) List(c echo.Context) error { // @Success 200 {object} searchproviders.GetResponse // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse -// @Router /search-providers/{id} [get] +// @Router /search-providers/{id} [get]. func (h *SearchProvidersHandler) Get(c echo.Context) error { id := strings.TrimSpace(c.Param("id")) if id == "" { @@ -123,7 +126,7 @@ func (h *SearchProvidersHandler) Get(c echo.Context) error { // @Success 200 {object} searchproviders.GetResponse // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /search-providers/{id} [put] +// @Router /search-providers/{id} [put]. func (h *SearchProvidersHandler) Update(c echo.Context) error { id := strings.TrimSpace(c.Param("id")) if id == "" { @@ -150,7 +153,7 @@ func (h *SearchProvidersHandler) Update(c echo.Context) error { // @Success 204 "No Content" // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /search-providers/{id} [delete] +// @Router /search-providers/{id} [delete]. func (h *SearchProvidersHandler) Delete(c echo.Context) error { id := strings.TrimSpace(c.Param("id")) if id == "" { diff --git a/internal/handlers/settings.go b/internal/handlers/settings.go index ff98c519..6a490202 100644 --- a/internal/handlers/settings.go +++ b/internal/handlers/settings.go @@ -14,6 +14,7 @@ import ( "github.com/memohai/memoh/internal/settings" ) +// SettingsHandler serves /bots/:bot_id/settings get, upsert, delete. type SettingsHandler struct { service *settings.Service botService *bots.Service @@ -21,6 +22,7 @@ type SettingsHandler struct { logger *slog.Logger } +// NewSettingsHandler creates a settings handler. func NewSettingsHandler(log *slog.Logger, service *settings.Service, botService *bots.Service, accountService *accounts.Service) *SettingsHandler { return &SettingsHandler{ service: service, @@ -30,6 +32,7 @@ func NewSettingsHandler(log *slog.Logger, service *settings.Service, botService } } +// Register mounts /bots/:bot_id/settings on the Echo instance. func (h *SettingsHandler) Register(e *echo.Echo) { group := e.Group("/bots/:bot_id/settings") group.GET("", h.Get) @@ -45,7 +48,7 @@ func (h *SettingsHandler) Register(e *echo.Echo) { // @Success 200 {object} settings.Settings // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/settings [get] +// @Router /bots/{bot_id}/settings [get]. func (h *SettingsHandler) Get(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -74,7 +77,7 @@ func (h *SettingsHandler) Get(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse // @Router /bots/{bot_id}/settings [put] -// @Router /bots/{bot_id}/settings [post] +// @Router /bots/{bot_id}/settings [post]. func (h *SettingsHandler) Upsert(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -108,7 +111,7 @@ func (h *SettingsHandler) Upsert(c echo.Context) error { // @Success 204 "No Content" // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/settings [delete] +// @Router /bots/{bot_id}/settings [delete]. func (h *SettingsHandler) Delete(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { diff --git a/internal/handlers/skills.go b/internal/handlers/skills.go index 4a636ff5..5cc99607 100644 --- a/internal/handlers/skills.go +++ b/internal/handlers/skills.go @@ -12,6 +12,7 @@ import ( "gopkg.in/yaml.v3" ) +// SkillItem is one skill entry (name, description, content, metadata) for list/upsert APIs. type SkillItem struct { Name string `json:"name"` Description string `json:"description"` @@ -19,14 +20,17 @@ type SkillItem struct { Metadata map[string]any `json:"metadata,omitempty"` } +// SkillsResponse holds the list of skills for list API. type SkillsResponse struct { Skills []SkillItem `json:"skills"` } +// SkillsUpsertRequest is the body for upserting skills (replace list). type SkillsUpsertRequest struct { Skills []SkillItem `json:"skills"` } +// SkillsDeleteRequest is the body for deleting skills by name. type SkillsDeleteRequest struct { Names []string `json:"names"` } @@ -43,7 +47,7 @@ type skillsOpResponse struct { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/container/skills [get] +// @Router /bots/{bot_id}/container/skills [get]. func (h *ContainerdHandler) ListSkills(c echo.Context) error { botID, err := h.requireBotAccess(c) if err != nil { @@ -69,12 +73,7 @@ func (h *ContainerdHandler) ListSkills(c echo.Context) error { continue } parsed := parseSkillFile(raw, name) - skills = append(skills, SkillItem{ - Name: parsed.Name, - Description: parsed.Description, - Content: parsed.Content, - Metadata: parsed.Metadata, - }) + skills = append(skills, SkillItem(parsed)) } return c.JSON(http.StatusOK, SkillsResponse{Skills: skills}) @@ -89,7 +88,7 @@ func (h *ContainerdHandler) ListSkills(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/container/skills [post] +// @Router /bots/{bot_id}/container/skills [post]. func (h *ContainerdHandler) UpsertSkills(c echo.Context) error { botID, err := h.requireBotAccess(c) if err != nil { @@ -138,7 +137,7 @@ func (h *ContainerdHandler) UpsertSkills(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/container/skills [delete] +// @Router /bots/{bot_id}/container/skills [delete]. func (h *ContainerdHandler) DeleteSkills(c echo.Context) error { botID, err := h.requireBotAccess(c) if err != nil { @@ -173,7 +172,7 @@ func (h *ContainerdHandler) DeleteSkills(c echo.Context) error { // LoadSkills loads all skills from the container for the given bot. // This implements chat.SkillLoader. -func (h *ContainerdHandler) LoadSkills(ctx context.Context, botID string) ([]SkillItem, error) { +func (h *ContainerdHandler) LoadSkills(_ context.Context, botID string) ([]SkillItem, error) { skillsDir, err := h.ensureSkillsDirHost(botID) if err != nil { return nil, err @@ -195,12 +194,7 @@ func (h *ContainerdHandler) LoadSkills(ctx context.Context, botID string) ([]Ski continue } parsed := parseSkillFile(raw, name) - skills = append(skills, SkillItem{ - Name: parsed.Name, - Description: parsed.Description, - Content: parsed.Content, - Metadata: parsed.Metadata, - }) + skills = append(skills, SkillItem(parsed)) } return skills, nil } @@ -310,7 +304,7 @@ type parsedSkill struct { // key: value // --- // # Body content ... -func parseSkillFile(raw string, fallbackName string) parsedSkill { +func parseSkillFile(raw, fallbackName string) parsedSkill { result := parsedSkill{Name: fallbackName} trimmed := strings.TrimSpace(raw) @@ -326,13 +320,13 @@ func parseSkillFile(raw string, fallbackName string) parsedSkill { } else if len(rest) > 1 && rest[0] == '\r' && rest[1] == '\n' { rest = rest[2:] } - closingIdx := strings.Index(rest, "\n---") - if closingIdx < 0 { + before, after, ok := strings.Cut(rest, "\n---") + if !ok { return result } - frontmatterRaw := rest[:closingIdx] - body := rest[closingIdx+4:] + frontmatterRaw := before + body := after body = strings.TrimLeft(body, "\r\n") result.Content = body diff --git a/internal/handlers/subagent.go b/internal/handlers/subagent.go index bd55c264..4dc291d4 100644 --- a/internal/handlers/subagent.go +++ b/internal/handlers/subagent.go @@ -13,6 +13,7 @@ import ( "github.com/memohai/memoh/internal/subagent" ) +// SubagentHandler serves /bots/:bot_id/subagents CRUD and context/skills APIs. type SubagentHandler struct { service *subagent.Service botService *bots.Service @@ -20,6 +21,7 @@ type SubagentHandler struct { logger *slog.Logger } +// NewSubagentHandler creates a subagent handler. func NewSubagentHandler(log *slog.Logger, service *subagent.Service, botService *bots.Service, accountService *accounts.Service) *SubagentHandler { return &SubagentHandler{ service: service, @@ -29,6 +31,7 @@ func NewSubagentHandler(log *slog.Logger, service *subagent.Service, botService } } +// Register mounts /bots/:bot_id/subagents routes on the Echo instance. func (h *SubagentHandler) Register(e *echo.Echo) { group := e.Group("/bots/:bot_id/subagents") group.POST("", h.Create) @@ -51,7 +54,7 @@ func (h *SubagentHandler) Register(e *echo.Echo) { // @Success 201 {object} subagent.Subagent // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/subagents [post] +// @Router /bots/{bot_id}/subagents [post]. func (h *SubagentHandler) Create(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -82,7 +85,7 @@ func (h *SubagentHandler) Create(c echo.Context) error { // @Success 200 {object} subagent.ListResponse // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/subagents [get] +// @Router /bots/{bot_id}/subagents [get]. func (h *SubagentHandler) List(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -111,7 +114,7 @@ func (h *SubagentHandler) List(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/subagents/{id} [get] +// @Router /bots/{bot_id}/subagents/{id} [get]. func (h *SubagentHandler) Get(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -148,7 +151,7 @@ func (h *SubagentHandler) Get(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/subagents/{id} [put] +// @Router /bots/{bot_id}/subagents/{id} [put]. func (h *SubagentHandler) Update(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -192,7 +195,7 @@ func (h *SubagentHandler) Update(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/subagents/{id} [delete] +// @Router /bots/{bot_id}/subagents/{id} [delete]. func (h *SubagentHandler) Delete(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -231,7 +234,7 @@ func (h *SubagentHandler) Delete(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/subagents/{id}/context [get] +// @Router /bots/{bot_id}/subagents/{id}/context [get]. func (h *SubagentHandler) GetContext(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -268,7 +271,7 @@ func (h *SubagentHandler) GetContext(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/subagents/{id}/context [put] +// @Router /bots/{bot_id}/subagents/{id}/context [put]. func (h *SubagentHandler) UpdateContext(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -312,7 +315,7 @@ func (h *SubagentHandler) UpdateContext(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/subagents/{id}/skills [get] +// @Router /bots/{bot_id}/subagents/{id}/skills [get]. func (h *SubagentHandler) GetSkills(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -349,7 +352,7 @@ func (h *SubagentHandler) GetSkills(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/subagents/{id}/skills [put] +// @Router /bots/{bot_id}/subagents/{id}/skills [put]. func (h *SubagentHandler) UpdateSkills(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -394,7 +397,7 @@ func (h *SubagentHandler) UpdateSkills(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{bot_id}/subagents/{id}/skills [post] +// @Router /bots/{bot_id}/subagents/{id}/skills [post]. func (h *SubagentHandler) AddSkills(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { diff --git a/internal/handlers/swagger.go b/internal/handlers/swagger.go index 4d4dec89..b18045ba 100644 --- a/internal/handlers/swagger.go +++ b/internal/handlers/swagger.go @@ -14,26 +14,31 @@ import ( //go:generate go run github.com/swaggo/swag/cmd/swag@latest init -g swagger.go -o ../../spec --parseDependency --parseInternal +// Cached swagger spec and load error (loaded once on first Spec call). var ( swaggerSpec []byte swaggerOnce sync.Once swaggerErr error ) +// SwaggerHandler serves OpenAPI spec and docs UI. type SwaggerHandler struct { logger *slog.Logger } +// NewSwaggerHandler creates a swagger handler. func NewSwaggerHandler(log *slog.Logger) *SwaggerHandler { return &SwaggerHandler{logger: log.With(slog.String("handler", "swagger"))} } +// Register mounts GET api/swagger.json and api/docs on the Echo instance. func (h *SwaggerHandler) Register(e *echo.Echo) { e.GET("api/swagger.json", h.Spec) e.GET("api/docs", h.UI) e.GET("api/docs/", h.UI) } +// Spec returns the swagger.json blob (from spec/swagger.json, loaded once). func (h *SwaggerHandler) Spec(c echo.Context) error { swaggerOnce.Do(func() { swaggerSpec, swaggerErr = os.ReadFile("spec/swagger.json") @@ -44,6 +49,7 @@ func (h *SwaggerHandler) Spec(c echo.Context) error { return c.Blob(http.StatusOK, "application/json", swaggerSpec) } +// UI returns the Swagger UI HTML page. func (h *SwaggerHandler) UI(c echo.Context) error { return c.HTML(http.StatusOK, swaggerUIHTML) } diff --git a/internal/handlers/users.go b/internal/handlers/users.go index c258e37f..1afb72b3 100644 --- a/internal/handlers/users.go +++ b/internal/handlers/users.go @@ -53,6 +53,7 @@ func NewUsersHandler(log *slog.Logger, service *accounts.Service, channelIdentit } } +// Register mounts /users and /bots routes on the Echo instance. func (h *UsersHandler) Register(e *echo.Echo) { userGroup := e.Group("/users") userGroup.GET("/me", h.GetMe) @@ -91,7 +92,7 @@ func (h *UsersHandler) Register(e *echo.Echo) { // @Success 200 {object} accounts.Account // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /users/me [get] +// @Router /users/me [get]. func (h *UsersHandler) GetMe(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -112,7 +113,7 @@ func (h *UsersHandler) GetMe(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /users/me/identities [get] +// @Router /users/me/identities [get]. func (h *UsersHandler) ListMyIdentities(c echo.Context) error { userID, err := h.requireChannelIdentityID(c) if err != nil { @@ -139,7 +140,7 @@ func (h *UsersHandler) ListMyIdentities(c echo.Context) error { // @Success 200 {object} accounts.Account // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /users/me [put] +// @Router /users/me [put]. func (h *UsersHandler) UpdateMe(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -164,7 +165,7 @@ func (h *UsersHandler) UpdateMe(c echo.Context) error { // @Success 204 "No Content" // @Failure 400 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /users/me/password [put] +// @Router /users/me/password [put]. func (h *UsersHandler) UpdateMyPassword(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -191,7 +192,7 @@ func (h *UsersHandler) UpdateMyPassword(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 403 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /users [get] +// @Router /users [get]. func (h *UsersHandler) ListUsers(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -224,7 +225,7 @@ func (h *UsersHandler) ListUsers(c echo.Context) error { // @Failure 403 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /users/{id} [get] +// @Router /users/{id} [get]. func (h *UsersHandler) GetUser(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -264,7 +265,7 @@ func (h *UsersHandler) GetUser(c echo.Context) error { // @Failure 403 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /users/{id} [put] +// @Router /users/{id} [put]. func (h *UsersHandler) UpdateUser(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -310,7 +311,7 @@ func (h *UsersHandler) UpdateUser(c echo.Context) error { // @Failure 403 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /users/{id}/password [put] +// @Router /users/{id}/password [put]. func (h *UsersHandler) ResetUserPassword(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -352,7 +353,7 @@ func (h *UsersHandler) ResetUserPassword(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 403 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /users [post] +// @Router /users [post]. func (h *UsersHandler) CreateUser(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -369,7 +370,7 @@ func (h *UsersHandler) CreateUser(c echo.Context) error { if err := c.Bind(&req); err != nil { return echo.NewHTTPError(http.StatusBadRequest, err.Error()) } - resp, err := h.service.CreateHuman(c.Request().Context(), "", req) + resp, err := h.service.Create(c.Request().Context(), "", req) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) } @@ -385,7 +386,7 @@ func (h *UsersHandler) CreateUser(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 403 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots [post] +// @Router /bots [post]. func (h *UsersHandler) CreateBot(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -456,7 +457,7 @@ func (h *UsersHandler) CreateBot(c echo.Context) error { // @Failure 400 {object} ErrorResponse // @Failure 403 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots [get] +// @Router /bots [get]. func (h *UsersHandler) ListBots(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -494,7 +495,7 @@ func (h *UsersHandler) ListBots(c echo.Context) error { // @Failure 403 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{id} [get] +// @Router /bots/{id} [get]. func (h *UsersHandler) GetBot(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -521,7 +522,7 @@ func (h *UsersHandler) GetBot(c echo.Context) error { // @Failure 403 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{id}/checks [get] +// @Router /bots/{id}/checks [get]. func (h *UsersHandler) ListBotChecks(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -550,7 +551,7 @@ func (h *UsersHandler) ListBotChecks(c echo.Context) error { // @Tags bots // @Param id path string true "Bot ID" // @Success 200 {object} bots.ListCheckKeysResponse -// @Router /bots/{id}/checks/keys [get] +// @Router /bots/{id}/checks/keys [get]. func (h *UsersHandler) ListBotCheckKeys(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -577,7 +578,7 @@ func (h *UsersHandler) ListBotCheckKeys(c echo.Context) error { // @Param id path string true "Bot ID" // @Param key path string true "Check key" // @Success 200 {object} bots.BotCheck -// @Router /bots/{id}/checks/run/{key} [get] +// @Router /bots/{id}/checks/run/{key} [get]. func (h *UsersHandler) RunBotCheck(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -612,7 +613,7 @@ func (h *UsersHandler) RunBotCheck(c echo.Context) error { // @Failure 403 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{id} [put] +// @Router /bots/{id} [put]. func (h *UsersHandler) UpdateBot(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -647,7 +648,7 @@ func (h *UsersHandler) UpdateBot(c echo.Context) error { // @Failure 403 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{id}/owner [put] +// @Router /bots/{id}/owner [put]. func (h *UsersHandler) TransferBotOwner(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -691,7 +692,7 @@ func (h *UsersHandler) TransferBotOwner(c echo.Context) error { // @Failure 403 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{id} [delete] +// @Router /bots/{id} [delete]. func (h *UsersHandler) DeleteBot(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -726,7 +727,7 @@ func (h *UsersHandler) DeleteBot(c echo.Context) error { // @Failure 403 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{id}/members [get] +// @Router /bots/{id}/members [get]. func (h *UsersHandler) ListBotMembers(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -757,7 +758,7 @@ func (h *UsersHandler) ListBotMembers(c echo.Context) error { // @Failure 403 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{id}/members [put] +// @Router /bots/{id}/members [put]. func (h *UsersHandler) UpsertBotMember(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -796,7 +797,7 @@ func (h *UsersHandler) UpsertBotMember(c echo.Context) error { // @Failure 403 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{id}/members/{user_id} [delete] +// @Router /bots/{id}/members/{user_id} [delete]. func (h *UsersHandler) DeleteBotMember(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -825,12 +826,12 @@ func (h *UsersHandler) DeleteBotMember(c echo.Context) error { // @Tags bots // @Param id path string true "Bot ID" // @Param platform path string true "Channel platform" -// @Success 200 {object} channel.ChannelConfig +// @Success 200 {object} channel.Config // @Failure 400 {object} ErrorResponse // @Failure 403 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{id}/channel/{platform} [get] +// @Router /bots/{id}/channel/{platform} [get]. func (h *UsersHandler) GetBotChannelConfig(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -864,12 +865,12 @@ func (h *UsersHandler) GetBotChannelConfig(c echo.Context) error { // @Param id path string true "Bot ID" // @Param platform path string true "Channel platform" // @Param payload body channel.UpsertConfigRequest true "Channel config payload" -// @Success 200 {object} channel.ChannelConfig +// @Success 200 {object} channel.Config // @Failure 400 {object} ErrorResponse // @Failure 403 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{id}/channel/{platform} [put] +// @Router /bots/{id}/channel/{platform} [put]. func (h *UsersHandler) UpsertBotChannelConfig(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -912,7 +913,7 @@ func (h *UsersHandler) UpsertBotChannelConfig(c echo.Context) error { // @Failure 403 {object} ErrorResponse // @Failure 404 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{id}/channel/{platform}/send [post] +// @Router /bots/{id}/channel/{platform}/send [post]. func (h *UsersHandler) SendBotMessage(c echo.Context) error { channelIdentityID, err := h.requireChannelIdentityID(c) if err != nil { @@ -957,7 +958,7 @@ func (h *UsersHandler) SendBotMessage(c echo.Context) error { // @Failure 401 {object} ErrorResponse // @Failure 403 {object} ErrorResponse // @Failure 500 {object} ErrorResponse -// @Router /bots/{id}/channel/{platform}/send_chat [post] +// @Router /bots/{id}/channel/{platform}/send_chat [post]. func (h *UsersHandler) SendBotMessageSession(c echo.Context) error { chatToken, err := auth.ChatTokenFromContext(c) if err != nil { diff --git a/internal/identity/types.go b/internal/identity/types.go index 125e325a..ce7efd4e 100644 --- a/internal/identity/types.go +++ b/internal/identity/types.go @@ -1,7 +1,9 @@ +// Package identity provides identity type constants and helpers. package identity import "strings" +// Identity type constants: human (user) or bot. const ( IdentityTypeHuman = "human" IdentityTypeBot = "bot" diff --git a/internal/identity/user.go b/internal/identity/user.go index 6e5b9d41..c2a6797b 100644 --- a/internal/identity/user.go +++ b/internal/identity/user.go @@ -12,7 +12,7 @@ func ValidateChannelIdentityID(channelIdentityID string) error { return fmt.Errorf("%w: channel identity id required", ctr.ErrInvalidArgument) } for _, r := range channelIdentityID { - if !(r == '-' || r == '_' || (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9')) { + if r != '-' && r != '_' && (r < 'a' || r > 'z') && (r < 'A' || r > 'Z') && (r < '0' || r > '9') { return fmt.Errorf("%w: invalid channel identity id", ctr.ErrInvalidArgument) } } diff --git a/internal/logger/logger.go b/internal/logger/logger.go index 81afe073..792db961 100644 --- a/internal/logger/logger.go +++ b/internal/logger/logger.go @@ -1,3 +1,4 @@ +// Package logger provides structured logging and context-aware logger injection. package logger import ( @@ -9,9 +10,10 @@ import ( type ctxKey struct{} +// L is the global default logger; initialize with Init or use FromContext for request-scoped loggers. var ( - L *slog.Logger = slog.Default() - logKey = ctxKey{} + L = slog.Default() + logKey = ctxKey{} ) // Init initializes the global logger with the given level and format (e.g. "debug", "json"). @@ -59,8 +61,14 @@ func parseLevel(level string) slog.Level { } } -// Debug, Info, Warn, Error log with the global logger (slog.Attr or key-value pairs). +// Debug logs at debug level with the global logger (slog.Attr or key-value pairs). func Debug(msg string, args ...any) { L.Debug(msg, args...) } -func Info(msg string, args ...any) { L.Info(msg, args...) } -func Warn(msg string, args ...any) { L.Warn(msg, args...) } + +// Info logs at info level with the global logger. +func Info(msg string, args ...any) { L.Info(msg, args...) } + +// Warn logs at warn level with the global logger. +func Warn(msg string, args ...any) { L.Warn(msg, args...) } + +// Error logs at error level with the global logger. func Error(msg string, args ...any) { L.Error(msg, args...) } diff --git a/internal/mcp/checker.go b/internal/mcp/checker.go index 8ccb529c..f1f8d6b6 100644 --- a/internal/mcp/checker.go +++ b/internal/mcp/checker.go @@ -1,3 +1,4 @@ +// Package mcp provides MCP connection management, tool gateway, and federation. package mcp import ( diff --git a/internal/mcp/connections.go b/internal/mcp/connections.go index 206e8951..ee7f56ed 100644 --- a/internal/mcp/connections.go +++ b/internal/mcp/connections.go @@ -3,6 +3,7 @@ package mcp import ( "context" "encoding/json" + "errors" "fmt" "log/slog" "strings" @@ -40,11 +41,11 @@ type UpsertRequest struct { // ImportRequest accepts a standard mcpServers dict for batch import. type ImportRequest struct { - MCPServers map[string]MCPServerEntry `json:"mcpServers"` + MCPServers map[string]ServerEntry `json:"mcpServers"` } -// MCPServerEntry is one entry in the standard mcpServers dict. -type MCPServerEntry struct { +// ServerEntry is one entry in the standard mcpServers dict. +type ServerEntry struct { Command string `json:"command,omitempty"` Args []string `json:"args,omitempty"` Env map[string]string `json:"env,omitempty"` @@ -61,7 +62,7 @@ type ListResponse struct { // ExportResponse returns connections in standard mcpServers format. type ExportResponse struct { - MCPServers map[string]MCPServerEntry `json:"mcpServers"` + MCPServers map[string]ServerEntry `json:"mcpServers"` } // ConnectionService handles CRUD operations for MCP connections. @@ -84,7 +85,7 @@ func NewConnectionService(log *slog.Logger, queries *sqlc.Queries) *ConnectionSe // ListByBot returns all MCP connections for a bot. func (s *ConnectionService) ListByBot(ctx context.Context, botID string) ([]Connection, error) { if s.queries == nil { - return nil, fmt.Errorf("mcp queries not configured") + return nil, errors.New("mcp queries not configured") } pgBotID, err := db.ParseUUID(botID) if err != nil { @@ -123,7 +124,7 @@ func (s *ConnectionService) ListActiveByBot(ctx context.Context, botID string) ( // Get returns a specific MCP connection for a bot. func (s *ConnectionService) Get(ctx context.Context, botID, id string) (Connection, error) { if s.queries == nil { - return Connection{}, fmt.Errorf("mcp queries not configured") + return Connection{}, errors.New("mcp queries not configured") } pgBotID, err := db.ParseUUID(botID) if err != nil { @@ -146,7 +147,7 @@ func (s *ConnectionService) Get(ctx context.Context, botID, id string) (Connecti // Create inserts a new MCP connection for a bot. func (s *ConnectionService) Create(ctx context.Context, botID string, req UpsertRequest) (Connection, error) { if s.queries == nil { - return Connection{}, fmt.Errorf("mcp queries not configured") + return Connection{}, errors.New("mcp queries not configured") } botUUID, err := db.ParseUUID(botID) if err != nil { @@ -154,7 +155,7 @@ func (s *ConnectionService) Create(ctx context.Context, botID string, req Upsert } name := strings.TrimSpace(req.Name) if name == "" { - return Connection{}, fmt.Errorf("name is required") + return Connection{}, errors.New("name is required") } mcpType, config, err := inferTypeAndConfig(req) if err != nil { @@ -184,7 +185,7 @@ func (s *ConnectionService) Create(ctx context.Context, botID string, req Upsert // Update modifies an existing MCP connection. func (s *ConnectionService) Update(ctx context.Context, botID, id string, req UpsertRequest) (Connection, error) { if s.queries == nil { - return Connection{}, fmt.Errorf("mcp queries not configured") + return Connection{}, errors.New("mcp queries not configured") } botUUID, err := db.ParseUUID(botID) if err != nil { @@ -196,7 +197,7 @@ func (s *ConnectionService) Update(ctx context.Context, botID, id string, req Up } name := strings.TrimSpace(req.Name) if name == "" { - return Connection{}, fmt.Errorf("name is required") + return Connection{}, errors.New("name is required") } mcpType, config, err := inferTypeAndConfig(req) if err != nil { @@ -230,7 +231,7 @@ func (s *ConnectionService) Update(ctx context.Context, botID, id string, req Up // Connections not in the input are left untouched. func (s *ConnectionService) Import(ctx context.Context, botID string, req ImportRequest) ([]Connection, error) { if s.queries == nil { - return nil, fmt.Errorf("mcp queries not configured") + return nil, errors.New("mcp queries not configured") } botUUID, err := db.ParseUUID(botID) if err != nil { @@ -278,7 +279,7 @@ func (s *ConnectionService) ExportByBot(ctx context.Context, botID string) (Expo if err != nil { return ExportResponse{}, err } - servers := make(map[string]MCPServerEntry, len(items)) + servers := make(map[string]ServerEntry, len(items)) for _, conn := range items { servers[conn.Name] = connectionToExportEntry(conn) } @@ -288,7 +289,7 @@ func (s *ConnectionService) ExportByBot(ctx context.Context, botID string) (Expo // Delete removes an MCP connection. func (s *ConnectionService) Delete(ctx context.Context, botID, id string) error { if s.queries == nil { - return fmt.Errorf("mcp queries not configured") + return errors.New("mcp queries not configured") } botUUID, err := db.ParseUUID(botID) if err != nil { @@ -307,7 +308,7 @@ func (s *ConnectionService) Delete(ctx context.Context, botID, id string) error // BatchDelete removes multiple MCP connections by IDs. Invalid IDs are skipped; at least one must succeed for no error. func (s *ConnectionService) BatchDelete(ctx context.Context, botID string, ids []string) error { if s.queries == nil { - return fmt.Errorf("mcp queries not configured") + return errors.New("mcp queries not configured") } if len(ids) == 0 { return nil @@ -362,10 +363,10 @@ func inferTypeAndConfig(req UpsertRequest) (string, map[string]any, error) { hasURL := strings.TrimSpace(req.URL) != "" if !hasCommand && !hasURL { - return "", nil, fmt.Errorf("command or url is required") + return "", nil, errors.New("command or url is required") } if hasCommand && hasURL { - return "", nil, fmt.Errorf("command and url are mutually exclusive") + return "", nil, errors.New("command and url are mutually exclusive") } config := map[string]any{} @@ -395,8 +396,8 @@ func inferTypeAndConfig(req UpsertRequest) (string, map[string]any, error) { return "http", config, nil } -// entryToUpsertRequest converts a named MCPServerEntry to an UpsertRequest. -func entryToUpsertRequest(name string, entry MCPServerEntry) UpsertRequest { +// entryToUpsertRequest converts a named ServerEntry to an UpsertRequest. +func entryToUpsertRequest(name string, entry ServerEntry) UpsertRequest { return UpsertRequest{ Name: name, Command: entry.Command, @@ -410,8 +411,8 @@ func entryToUpsertRequest(name string, entry MCPServerEntry) UpsertRequest { } // connectionToExportEntry converts a stored connection to standard mcpServers entry. -func connectionToExportEntry(conn Connection) MCPServerEntry { - entry := MCPServerEntry{} +func connectionToExportEntry(conn Connection) ServerEntry { + entry := ServerEntry{} switch conn.Type { case "stdio": entry.Command, _ = conn.Config["command"].(string) diff --git a/internal/mcp/connections_test.go b/internal/mcp/connections_test.go index 01590ee5..01472ce2 100644 --- a/internal/mcp/connections_test.go +++ b/internal/mcp/connections_test.go @@ -155,7 +155,7 @@ func TestConnectionToExportEntry_SSE(t *testing.T) { } func TestEntryToUpsertRequest(t *testing.T) { - entry := MCPServerEntry{ + entry := ServerEntry{ Command: "npx", Args: []string{"-y", "server"}, Env: map[string]string{"KEY": "val"}, diff --git a/internal/mcp/jsonrpc.go b/internal/mcp/jsonrpc.go index 18912243..3f1e983e 100644 --- a/internal/mcp/jsonrpc.go +++ b/internal/mcp/jsonrpc.go @@ -5,10 +5,12 @@ import ( "strings" ) +// IsNotification reports whether the request is a notification (no id, method starts with notifications/). func IsNotification(req JSONRPCRequest) bool { return len(req.ID) == 0 && strings.HasPrefix(req.Method, "notifications/") } +// JSONRPCErrorResponse builds a JSON-RPC response with the given error code and message. func JSONRPCErrorResponse(id json.RawMessage, code int, message string) JSONRPCResponse { return JSONRPCResponse{ JSONRPC: "2.0", diff --git a/internal/mcp/manager.go b/internal/mcp/manager.go index a2c1806a..47fd6921 100644 --- a/internal/mcp/manager.go +++ b/internal/mcp/manager.go @@ -24,11 +24,13 @@ import ( "github.com/memohai/memoh/internal/identity" ) +// Label and ID prefix for MCP bot containers. const ( BotLabelKey = "mcp.bot_id" ContainerPrefix = "mcp-" ) +// ExecRequest specifies command, env, work dir, and stdio options for container exec. type ExecRequest struct { BotID string Command []string @@ -38,6 +40,7 @@ type ExecRequest struct { UseStdio bool } +// ExecResult holds the exit code from container exec. type ExecResult struct { ExitCode uint32 } @@ -49,6 +52,7 @@ type ExecWithCaptureResult struct { ExitCode uint32 } +// Manager manages MCP containers (ensure, exec, versioning) and bot data dirs. type Manager struct { service ctr.Service cfg config.MCPConfig @@ -59,6 +63,7 @@ type Manager struct { logger *slog.Logger } +// NewManager creates an MCP manager with containerd service, config, namespace, and DB pool. func NewManager(log *slog.Logger, service ctr.Service, cfg config.MCPConfig, namespace string, conn *pgxpool.Pool) *Manager { if namespace == "" { namespace = config.DefaultNamespace @@ -76,6 +81,7 @@ func NewManager(log *slog.Logger, service ctr.Service, cfg config.MCPConfig, nam } } +// Init pulls the MCP image and unpacks with the configured snapshotter. func (m *Manager) Init(ctx context.Context) error { image := m.imageRef() @@ -163,6 +169,7 @@ func (m *Manager) ListBots(ctx context.Context) ([]string, error) { return botIDs, nil } +// Start ensures the MCP container exists, starts the task, and sets up CNI network. func (m *Manager) Start(ctx context.Context, botID string) error { if err := m.EnsureBot(ctx, botID); err != nil { return err @@ -183,6 +190,7 @@ func (m *Manager) Start(ctx context.Context, botID string) error { return nil } +// Stop stops the MCP container task with the given timeout (force kill). func (m *Manager) Stop(ctx context.Context, botID string, timeout time.Duration) error { if err := validateBotID(botID); err != nil { return err @@ -193,6 +201,7 @@ func (m *Manager) Stop(ctx context.Context, botID string, timeout time.Duration) }) } +// Delete removes the task, network, and container (with snapshot cleanup) for the bot. func (m *Manager) Delete(ctx context.Context, botID string) error { if err := validateBotID(botID); err != nil { return err @@ -211,6 +220,7 @@ func (m *Manager) Delete(ctx context.Context, botID string) error { }) } +// Exec runs a command in the MCP container and returns exit code (optional stdio/terminal). func (m *Manager) Exec(ctx context.Context, req ExecRequest) (*ExecResult, error) { if err := validateBotID(req.BotID); err != nil { return nil, err @@ -219,7 +229,7 @@ func (m *Manager) Exec(ctx context.Context, req ExecRequest) (*ExecResult, error return nil, fmt.Errorf("%w: empty command", ctr.ErrInvalidArgument) } if m.queries == nil { - return nil, fmt.Errorf("db is not configured") + return nil, errors.New("db is not configured") } startedAt := time.Now() @@ -263,7 +273,7 @@ func (m *Manager) ExecWithCapture(ctx context.Context, req ExecRequest) (*ExecWi return nil, fmt.Errorf("%w: empty command", ctr.ErrInvalidArgument) } if m.queries == nil { - return nil, fmt.Errorf("db is not configured") + return nil, errors.New("db is not configured") } if runtime.GOOS == "darwin" { @@ -281,7 +291,8 @@ func (m *Manager) execWithCaptureLima(ctx context.Context, req ExecRequest) (*Ex // Each element becomes a separate OS arg to limactl. Lima/SSH joins // them with spaces and passes the result to the remote shell, so only // values that may contain shell-special characters need quoting. - args := []string{"shell", "default", "--", + args := []string{ + "shell", "default", "--", "sudo", "ctr", "-n", m.namespace, "tasks", "exec", "--exec-id", execID, } @@ -332,7 +343,11 @@ func (m *Manager) execWithCaptureContainerd(ctx context.Context, req ExecRequest if err != nil { return nil, fmt.Errorf("create fifo dir: %w", err) } - defer os.RemoveAll(fifoDir) + defer func() { + if err := os.RemoveAll(fifoDir); err != nil { + m.logger.Warn("exec cleanup: remove fifo dir failed", slog.String("dir", fifoDir), slog.Any("error", err)) + } + }() var stdoutBuf, stderrBuf bytes.Buffer result, err := m.service.ExecTask(ctx, m.containerID(req.BotID), ctr.ExecTaskRequest{ diff --git a/internal/mcp/providers/container/fsops.go b/internal/mcp/providers/container/fsops.go index f3c05387..dbc448cd 100644 --- a/internal/mcp/providers/container/fsops.go +++ b/internal/mcp/providers/container/fsops.go @@ -1,3 +1,4 @@ +// Package container provides MCP container provider (filesystem and exec tools). package container import ( @@ -60,7 +61,7 @@ func ExecWrite(ctx context.Context, runner ExecRunner, botID, workDir, filePath, } // ExecList lists directory entries inside the container via find + stat. -// Output format per line: |||| +// Output format per line: ||||. func ExecList(ctx context.Context, runner ExecRunner, botID, workDir, dirPath string, recursive bool) ([]FileEntry, error) { depthFlag := "-maxdepth 1" if recursive { diff --git a/internal/mcp/providers/container/provider.go b/internal/mcp/providers/container/provider.go index fbb2427a..b68d4d15 100644 --- a/internal/mcp/providers/container/provider.go +++ b/internal/mcp/providers/container/provider.go @@ -53,7 +53,7 @@ func NewExecutor(log *slog.Logger, execRunner ExecRunner, execWorkDir string) *E } // ListTools returns read, write, list, edit, and exec tool descriptors. -func (p *Executor) ListTools(ctx context.Context, session mcpgw.ToolSessionContext) ([]mcpgw.ToolDescriptor, error) { +func (p *Executor) ListTools(_ context.Context, _ mcpgw.ToolSessionContext) ([]mcpgw.ToolDescriptor, error) { return []mcpgw.ToolDescriptor{ { Name: toolRead, @@ -135,8 +135,8 @@ func normalizePath(path string) string { if path == prefix { return "." } - if strings.HasPrefix(path, prefix+"/") { - return strings.TrimLeft(strings.TrimPrefix(path, prefix+"/"), "/") + if after, ok := strings.CutPrefix(path, prefix+"/"); ok { + return strings.TrimLeft(after, "/") } return path } diff --git a/internal/mcp/providers/container/provider_test.go b/internal/mcp/providers/container/provider_test.go index d921d4d9..c18615b6 100644 --- a/internal/mcp/providers/container/provider_test.go +++ b/internal/mcp/providers/container/provider_test.go @@ -18,7 +18,7 @@ type fakeExecRunner struct { handler func(req mcpgw.ExecRequest) (*mcpgw.ExecWithCaptureResult, error) } -func (f *fakeExecRunner) ExecWithCapture(ctx context.Context, req mcpgw.ExecRequest) (*mcpgw.ExecWithCaptureResult, error) { +func (f *fakeExecRunner) ExecWithCapture(_ context.Context, req mcpgw.ExecRequest) (*mcpgw.ExecWithCaptureResult, error) { f.lastReq = req if f.handler != nil { return f.handler(req) @@ -141,8 +141,8 @@ func TestExecutor_CallTool_Edit(t *testing.T) { if strings.Contains(cmd, "base64 -d") { // Write step: verify the written content contains the replacement. // Extract base64 from: echo '' | base64 -d > 'path' - parts := strings.Split(cmd, "'") - for _, p := range parts { + parts := strings.SplitSeq(cmd, "'") + for p := range parts { decoded, err := base64.StdEncoding.DecodeString(p) if err == nil && strings.Contains(string(decoded), "goodbye world") { return &mcpgw.ExecWithCaptureResult{ExitCode: 0}, nil diff --git a/internal/mcp/providers/directory/provider.go b/internal/mcp/providers/directory/provider.go index 92fa0753..eba2caa0 100644 --- a/internal/mcp/providers/directory/provider.go +++ b/internal/mcp/providers/directory/provider.go @@ -1,3 +1,4 @@ +// Package directory provides the MCP directory provider (channel user lookup). package directory import ( @@ -13,12 +14,12 @@ const toolLookupChannelUser = "lookup_channel_user" // ConfigResolver resolves effective channel config for a bot (used to call directory APIs). type ConfigResolver interface { - ResolveEffectiveConfig(ctx context.Context, botID string, channelType channel.ChannelType) (channel.ChannelConfig, error) + ResolveEffectiveConfig(ctx context.Context, botID string, channelType channel.Type) (channel.Config, error) } // ChannelTypeResolver parses platform name to channel type. type ChannelTypeResolver interface { - ParseChannelType(raw string) (channel.ChannelType, error) + ParseChannelType(raw string) (channel.Type, error) } // Executor exposes channel directory lookup as an MCP tool for the LLM. @@ -43,7 +44,7 @@ func NewExecutor(log *slog.Logger, registry *channel.Registry, configResolver Co } // ListTools returns the lookup_channel_user tool descriptor. -func (p *Executor) ListTools(ctx context.Context, session mcpgw.ToolSessionContext) ([]mcpgw.ToolDescriptor, error) { +func (p *Executor) ListTools(_ context.Context, _ mcpgw.ToolSessionContext) ([]mcpgw.ToolDescriptor, error) { if p.registry == nil || p.configResolver == nil || p.typeResolver == nil { return []mcpgw.ToolDescriptor{}, nil } diff --git a/internal/mcp/providers/directory/provider_test.go b/internal/mcp/providers/directory/provider_test.go index 1f42969a..423632ab 100644 --- a/internal/mcp/providers/directory/provider_test.go +++ b/internal/mcp/providers/directory/provider_test.go @@ -2,6 +2,7 @@ package directory import ( "context" + "errors" "testing" "github.com/memohai/memoh/internal/channel" @@ -39,34 +40,38 @@ func TestExecutor_ListTools_NilDeps(t *testing.T) { func TestExecutor_CallTool_NotFound(t *testing.T) { exec := NewExecutor(nil, channel.NewRegistry(), &fakeConfigResolver{}, channel.NewRegistry()) _, err := exec.CallTool(context.Background(), mcpgw.ToolSessionContext{}, "other_tool", nil) - if err != mcpgw.ErrToolNotFound { + if !errors.Is(err, mcpgw.ErrToolNotFound) { t.Errorf("expected ErrToolNotFound, got %v", err) } } type dirMockAdapter struct { - channelType channel.ChannelType + channelType channel.Type } -func (d *dirMockAdapter) Type() channel.ChannelType { return d.channelType } +func (d *dirMockAdapter) Type() channel.Type { return d.channelType } func (d *dirMockAdapter) Descriptor() channel.Descriptor { return channel.Descriptor{Type: d.channelType, DisplayName: "DirTest"} } -func (d *dirMockAdapter) ListPeers(context.Context, channel.ChannelConfig, channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { + +func (d *dirMockAdapter) ListPeers(context.Context, channel.Config, channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { return nil, nil } -func (d *dirMockAdapter) ListGroups(context.Context, channel.ChannelConfig, channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { + +func (d *dirMockAdapter) ListGroups(context.Context, channel.Config, channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { return nil, nil } -func (d *dirMockAdapter) ListGroupMembers(context.Context, channel.ChannelConfig, string, channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { + +func (d *dirMockAdapter) ListGroupMembers(context.Context, channel.Config, string, channel.DirectoryQuery) ([]channel.DirectoryEntry, error) { return nil, nil } -func (d *dirMockAdapter) ResolveEntry(context.Context, channel.ChannelConfig, string, channel.DirectoryEntryKind) (channel.DirectoryEntry, error) { + +func (d *dirMockAdapter) ResolveEntry(context.Context, channel.Config, string, channel.DirectoryEntryKind) (channel.DirectoryEntry, error) { return channel.DirectoryEntry{Kind: channel.DirectoryEntryUser, ID: "id1", Name: "Test User"}, nil } type fakeConfigResolver struct{} -func (f *fakeConfigResolver) ResolveEffectiveConfig(context.Context, string, channel.ChannelType) (channel.ChannelConfig, error) { - return channel.ChannelConfig{}, nil +func (f *fakeConfigResolver) ResolveEffectiveConfig(context.Context, string, channel.Type) (channel.Config, error) { + return channel.Config{}, nil } diff --git a/internal/mcp/providers/memory/provider.go b/internal/mcp/providers/memory/provider.go index d147e84e..5c829748 100644 --- a/internal/mcp/providers/memory/provider.go +++ b/internal/mcp/providers/memory/provider.go @@ -1,3 +1,4 @@ +// Package memory provides the MCP memory provider (search and recall tools). package memory import ( @@ -18,22 +19,26 @@ const ( sharedMemoryNamespace = "bot" ) -type MemorySearcher interface { +// Searcher performs memory search (used by memory tool executor). +type Searcher interface { Search(ctx context.Context, req mem.SearchRequest) (mem.SearchResponse, error) } +// AdminChecker checks if a channel identity is admin (for memory tool access). type AdminChecker interface { IsAdmin(ctx context.Context, channelIdentityID string) (bool, error) } +// Executor is the MCP tool executor for search_memory (delegates to Searcher, checks chat access). type Executor struct { - searcher MemorySearcher + searcher Searcher chatAccessor conversation.Accessor adminChecker AdminChecker logger *slog.Logger } -func NewExecutor(log *slog.Logger, searcher MemorySearcher, chatAccessor conversation.Accessor, adminChecker AdminChecker) *Executor { +// NewExecutor creates a memory tool executor. +func NewExecutor(log *slog.Logger, searcher Searcher, chatAccessor conversation.Accessor, adminChecker AdminChecker) *Executor { if log == nil { log = slog.Default() } @@ -45,7 +50,8 @@ func NewExecutor(log *slog.Logger, searcher MemorySearcher, chatAccessor convers } } -func (p *Executor) ListTools(ctx context.Context, session mcpgw.ToolSessionContext) ([]mcpgw.ToolDescriptor, error) { +// ListTools returns the search_memory tool descriptor. +func (p *Executor) ListTools(_ context.Context, _ mcpgw.ToolSessionContext) ([]mcpgw.ToolDescriptor, error) { if p.searcher == nil || p.chatAccessor == nil { return []mcpgw.ToolDescriptor{}, nil } @@ -71,6 +77,7 @@ func (p *Executor) ListTools(ctx context.Context, session mcpgw.ToolSessionConte }, nil } +// CallTool runs search_memory (query, limit) and returns MCP result; validates chat access. func (p *Executor) CallTool(ctx context.Context, session mcpgw.ToolSessionContext, toolName string, arguments map[string]any) (map[string]any, error) { if toolName != toolSearchMemory { return nil, mcpgw.ErrToolNotFound @@ -142,7 +149,7 @@ func (p *Executor) CallTool(ctx context.Context, session mcpgw.ToolSessionContex p.logger.Warn("memory search namespace failed", slog.String("namespace", sharedMemoryNamespace), slog.Any("error", err)) return mcpgw.BuildToolErrorResult("memory search failed"), nil } - allResults := make([]mem.MemoryItem, 0, len(resp.Results)) + allResults := make([]mem.Item, 0, len(resp.Results)) allResults = append(allResults, resp.Results...) allResults = deduplicateMemoryItems(allResults) @@ -182,12 +189,12 @@ func (p *Executor) canAccessChat(ctx context.Context, chatID, channelIdentityID return p.chatAccessor.IsParticipant(ctx, chatID, channelIdentityID) } -func deduplicateMemoryItems(items []mem.MemoryItem) []mem.MemoryItem { +func deduplicateMemoryItems(items []mem.Item) []mem.Item { if len(items) == 0 { return items } seen := make(map[string]struct{}, len(items)) - result := make([]mem.MemoryItem, 0, len(items)) + result := make([]mem.Item, 0, len(items)) for _, item := range items { id := strings.TrimSpace(item.ID) if id == "" { diff --git a/internal/mcp/providers/memory/provider_test.go b/internal/mcp/providers/memory/provider_test.go index 9c6209f1..2a6f01a6 100644 --- a/internal/mcp/providers/memory/provider_test.go +++ b/internal/mcp/providers/memory/provider_test.go @@ -15,7 +15,7 @@ type fakeSearcher struct { err error } -func (f *fakeSearcher) Search(ctx context.Context, req memory.SearchRequest) (memory.SearchResponse, error) { +func (f *fakeSearcher) Search(_ context.Context, _ memory.SearchRequest) (memory.SearchResponse, error) { if f.err != nil { return memory.SearchResponse{}, f.err } @@ -29,22 +29,22 @@ type fakeChatAccessor struct { participantErr error } -func (f *fakeChatAccessor) Get(ctx context.Context, conversationID string) (conversation.Conversation, error) { +func (f *fakeChatAccessor) Get(_ context.Context, _ string) (conversation.Conversation, error) { if f.getErr != nil { return conversation.Conversation{}, f.getErr } return f.chat, nil } -func (f *fakeChatAccessor) IsParticipant(ctx context.Context, conversationID, channelIdentityID string) (bool, error) { +func (f *fakeChatAccessor) IsParticipant(_ context.Context, _, _ string) (bool, error) { if f.participantErr != nil { return false, f.participantErr } return f.participant, nil } -func (f *fakeChatAccessor) GetReadAccess(ctx context.Context, conversationID, channelIdentityID string) (conversation.ConversationReadAccess, error) { - return conversation.ConversationReadAccess{}, nil +func (f *fakeChatAccessor) GetReadAccess(_ context.Context, _, _ string) (conversation.ReadAccess, error) { + return conversation.ReadAccess{}, nil } type fakeAdminChecker struct { @@ -52,7 +52,7 @@ type fakeAdminChecker struct { err error } -func (f *fakeAdminChecker) IsAdmin(ctx context.Context, channelIdentityID string) (bool, error) { +func (f *fakeAdminChecker) IsAdmin(_ context.Context, _ string) (bool, error) { if f.err != nil { return false, f.err } @@ -91,7 +91,7 @@ func TestExecutor_CallTool_NotFound(t *testing.T) { accessor := &fakeChatAccessor{} exec := NewExecutor(nil, searcher, accessor, nil) _, err := exec.CallTool(context.Background(), mcpgw.ToolSessionContext{BotID: "bot1"}, "other_tool", nil) - if err != mcpgw.ErrToolNotFound { + if !errors.Is(err, mcpgw.ErrToolNotFound) { t.Errorf("expected ErrToolNotFound, got %v", err) } } @@ -136,7 +136,7 @@ func TestExecutor_CallTool_NoBotID(t *testing.T) { func TestExecutor_CallTool_Success_BotScope(t *testing.T) { searcher := &fakeSearcher{ resp: memory.SearchResponse{ - Results: []memory.MemoryItem{ + Results: []memory.Item{ {ID: "id1", Memory: "mem1", Score: 0.9}, }, }, @@ -213,7 +213,7 @@ func TestExecutor_CallTool_NotParticipant(t *testing.T) { func TestExecutor_CallTool_AdminBypass(t *testing.T) { searcher := &fakeSearcher{ - resp: memory.SearchResponse{Results: []memory.MemoryItem{{ID: "id1", Memory: "m1", Score: 0.8}}}, + resp: memory.SearchResponse{Results: []memory.Item{{ID: "id1", Memory: "m1", Score: 0.8}}}, } accessor := &fakeChatAccessor{ chat: conversation.Conversation{BotID: "bot1", ID: "c1"}, @@ -255,20 +255,20 @@ func TestExecutor_CallTool_SearchError(t *testing.T) { func TestDeduplicateMemoryItems(t *testing.T) { tests := []struct { name string - items []memory.MemoryItem + items []memory.Item wantLen int }{ {"empty", nil, 0}, - {"single", []memory.MemoryItem{{ID: "a", Memory: "m", Score: 1}}, 1}, - {"dedup by id", []memory.MemoryItem{ + {"single", []memory.Item{{ID: "a", Memory: "m", Score: 1}}, 1}, + {"dedup by id", []memory.Item{ {ID: "a", Memory: "m1", Score: 1}, {ID: "a", Memory: "m2", Score: 0.9}, }, 1}, - {"dedup by memory when id empty", []memory.MemoryItem{ + {"dedup by memory when id empty", []memory.Item{ {ID: "", Memory: "same", Score: 1}, {ID: "", Memory: "same", Score: 0.9}, }, 1}, - {"no dedup", []memory.MemoryItem{ + {"no dedup", []memory.Item{ {ID: "a", Memory: "m1", Score: 1}, {ID: "b", Memory: "m2", Score: 0.9}, }, 2}, diff --git a/internal/mcp/providers/message/provider.go b/internal/mcp/providers/message/provider.go index 218f63f6..cabbb0ed 100644 --- a/internal/mcp/providers/message/provider.go +++ b/internal/mcp/providers/message/provider.go @@ -1,9 +1,10 @@ +// Package message provides the MCP message provider (send and list tools). package message import ( "context" "encoding/json" - "fmt" + "errors" "log/slog" "strings" @@ -18,17 +19,17 @@ const ( // Sender sends outbound messages through channel manager. type Sender interface { - Send(ctx context.Context, botID string, channelType channel.ChannelType, req channel.SendRequest) error + Send(ctx context.Context, botID string, channelType channel.Type, req channel.SendRequest) error } // Reactor adds or removes emoji reactions through channel manager. type Reactor interface { - React(ctx context.Context, botID string, channelType channel.ChannelType, req channel.ReactRequest) error + React(ctx context.Context, botID string, channelType channel.Type, req channel.ReactRequest) error } // ChannelTypeResolver parses platform name to channel type. type ChannelTypeResolver interface { - ParseChannelType(raw string) (channel.ChannelType, error) + ParseChannelType(raw string) (channel.Type, error) } // Executor exposes send and react as MCP tools. @@ -53,7 +54,8 @@ func NewExecutor(log *slog.Logger, sender Sender, reactor Reactor, resolver Chan } } -func (p *Executor) ListTools(ctx context.Context, session mcpgw.ToolSessionContext) ([]mcpgw.ToolDescriptor, error) { +// ListTools returns send and optionally react tool descriptors. +func (p *Executor) ListTools(_ context.Context, _ mcpgw.ToolSessionContext) ([]mcpgw.ToolDescriptor, error) { var tools []mcpgw.ToolDescriptor if p.sender != nil && p.resolver != nil { tools = append(tools, mcpgw.ToolDescriptor{ @@ -138,6 +140,7 @@ func (p *Executor) ListTools(ctx context.Context, session mcpgw.ToolSessionConte return tools, nil } +// CallTool runs send or react; validates args and delegates to Sender/Reactor. func (p *Executor) CallTool(ctx context.Context, session mcpgw.ToolSessionContext, toolName string, arguments map[string]any) (map[string]any, error) { switch toolName { case toolSend: @@ -273,21 +276,21 @@ func (p *Executor) resolveBotID(arguments map[string]any, session mcpgw.ToolSess botID = strings.TrimSpace(session.BotID) } if botID == "" { - return "", fmt.Errorf("bot_id is required") + return "", errors.New("bot_id is required") } if strings.TrimSpace(session.BotID) != "" && botID != strings.TrimSpace(session.BotID) { - return "", fmt.Errorf("bot_id mismatch") + return "", errors.New("bot_id mismatch") } return botID, nil } -func (p *Executor) resolvePlatform(arguments map[string]any, session mcpgw.ToolSessionContext) (channel.ChannelType, error) { +func (p *Executor) resolvePlatform(arguments map[string]any, session mcpgw.ToolSessionContext) (channel.Type, error) { platform := mcpgw.FirstStringArg(arguments, "platform") if platform == "" { platform = strings.TrimSpace(session.CurrentPlatform) } if platform == "" { - return "", fmt.Errorf("platform is required") + return "", errors.New("platform is required") } return p.resolver.ParseChannelType(platform) } @@ -307,14 +310,14 @@ func parseOutboundMessage(arguments map[string]any, fallbackText string) (channe return channel.Message{}, err } default: - return channel.Message{}, fmt.Errorf("message must be object or string") + return channel.Message{}, errors.New("message must be object or string") } } if msg.IsEmpty() && strings.TrimSpace(fallbackText) != "" { msg.Text = strings.TrimSpace(fallbackText) } if msg.IsEmpty() { - return channel.Message{}, fmt.Errorf("message is required") + return channel.Message{}, errors.New("message is required") } return msg, nil } diff --git a/internal/mcp/providers/message/provider_test.go b/internal/mcp/providers/message/provider_test.go index 98a775d7..3d0b9345 100644 --- a/internal/mcp/providers/message/provider_test.go +++ b/internal/mcp/providers/message/provider_test.go @@ -14,7 +14,7 @@ type fakeSender struct { lastReq channel.SendRequest } -func (f *fakeSender) Send(ctx context.Context, botID string, channelType channel.ChannelType, req channel.SendRequest) error { +func (f *fakeSender) Send(_ context.Context, _ string, _ channel.Type, req channel.SendRequest) error { f.lastReq = req return f.err } @@ -24,17 +24,17 @@ type fakeReactor struct { lastReq channel.ReactRequest } -func (f *fakeReactor) React(ctx context.Context, botID string, channelType channel.ChannelType, req channel.ReactRequest) error { +func (f *fakeReactor) React(_ context.Context, _ string, _ channel.Type, req channel.ReactRequest) error { f.lastReq = req return f.err } type fakeResolver struct { - ct channel.ChannelType + ct channel.Type err error } -func (f *fakeResolver) ParseChannelType(raw string) (channel.ChannelType, error) { +func (f *fakeResolver) ParseChannelType(_ string) (channel.Type, error) { if f.err != nil { return "", f.err } @@ -57,7 +57,7 @@ func TestExecutor_ListTools_NilDeps(t *testing.T) { func TestExecutor_ListTools(t *testing.T) { sender := &fakeSender{} reactor := &fakeReactor{} - resolver := &fakeResolver{ct: channel.ChannelType("feishu")} + resolver := &fakeResolver{ct: channel.Type("feishu")} exec := NewExecutor(nil, sender, reactor, resolver) tools, err := exec.ListTools(context.Background(), mcpgw.ToolSessionContext{}) if err != nil { @@ -76,7 +76,7 @@ func TestExecutor_ListTools(t *testing.T) { func TestExecutor_ListTools_OnlySender(t *testing.T) { sender := &fakeSender{} - resolver := &fakeResolver{ct: channel.ChannelType("feishu")} + resolver := &fakeResolver{ct: channel.Type("feishu")} exec := NewExecutor(nil, sender, nil, resolver) tools, err := exec.ListTools(context.Background(), mcpgw.ToolSessionContext{}) if err != nil { @@ -92,10 +92,10 @@ func TestExecutor_ListTools_OnlySender(t *testing.T) { func TestExecutor_CallTool_NotFound(t *testing.T) { sender := &fakeSender{} - resolver := &fakeResolver{ct: channel.ChannelType("feishu")} + resolver := &fakeResolver{ct: channel.Type("feishu")} exec := NewExecutor(nil, sender, nil, resolver) _, err := exec.CallTool(context.Background(), mcpgw.ToolSessionContext{BotID: "bot1"}, "other_tool", nil) - if err != mcpgw.ErrToolNotFound { + if !errors.Is(err, mcpgw.ErrToolNotFound) { t.Errorf("expected ErrToolNotFound, got %v", err) } } @@ -115,7 +115,7 @@ func TestExecutor_CallTool_NilDeps(t *testing.T) { func TestExecutor_CallTool_NoBotID(t *testing.T) { sender := &fakeSender{} - resolver := &fakeResolver{ct: channel.ChannelType("feishu")} + resolver := &fakeResolver{ct: channel.Type("feishu")} exec := NewExecutor(nil, sender, nil, resolver) result, err := exec.CallTool(context.Background(), mcpgw.ToolSessionContext{}, toolSend, map[string]any{ "platform": "feishu", "target": "t1", "text": "hi", @@ -130,7 +130,7 @@ func TestExecutor_CallTool_NoBotID(t *testing.T) { func TestExecutor_CallTool_BotIDMismatch(t *testing.T) { sender := &fakeSender{} - resolver := &fakeResolver{ct: channel.ChannelType("feishu")} + resolver := &fakeResolver{ct: channel.Type("feishu")} exec := NewExecutor(nil, sender, nil, resolver) session := mcpgw.ToolSessionContext{BotID: "bot1"} result, err := exec.CallTool(context.Background(), session, toolSend, map[string]any{ @@ -146,7 +146,7 @@ func TestExecutor_CallTool_BotIDMismatch(t *testing.T) { func TestExecutor_CallTool_NoPlatform(t *testing.T) { sender := &fakeSender{} - resolver := &fakeResolver{ct: channel.ChannelType("feishu")} + resolver := &fakeResolver{ct: channel.Type("feishu")} exec := NewExecutor(nil, sender, nil, resolver) session := mcpgw.ToolSessionContext{BotID: "bot1"} result, err := exec.CallTool(context.Background(), session, toolSend, map[string]any{ @@ -178,7 +178,7 @@ func TestExecutor_CallTool_PlatformParseError(t *testing.T) { func TestExecutor_CallTool_NoMessage(t *testing.T) { sender := &fakeSender{} - resolver := &fakeResolver{ct: channel.ChannelType("feishu")} + resolver := &fakeResolver{ct: channel.Type("feishu")} exec := NewExecutor(nil, sender, nil, resolver) session := mcpgw.ToolSessionContext{BotID: "bot1"} result, err := exec.CallTool(context.Background(), session, toolSend, map[string]any{ @@ -194,7 +194,7 @@ func TestExecutor_CallTool_NoMessage(t *testing.T) { func TestExecutor_CallTool_NoTarget(t *testing.T) { sender := &fakeSender{} - resolver := &fakeResolver{ct: channel.ChannelType("feishu")} + resolver := &fakeResolver{ct: channel.Type("feishu")} exec := NewExecutor(nil, sender, nil, resolver) session := mcpgw.ToolSessionContext{BotID: "bot1"} result, err := exec.CallTool(context.Background(), session, toolSend, map[string]any{ @@ -210,7 +210,7 @@ func TestExecutor_CallTool_NoTarget(t *testing.T) { func TestExecutor_CallTool_SendError(t *testing.T) { sender := &fakeSender{err: errors.New("send failed")} - resolver := &fakeResolver{ct: channel.ChannelType("feishu")} + resolver := &fakeResolver{ct: channel.Type("feishu")} exec := NewExecutor(nil, sender, nil, resolver) session := mcpgw.ToolSessionContext{BotID: "bot1", ReplyTarget: "t1"} result, err := exec.CallTool(context.Background(), session, toolSend, map[string]any{ @@ -226,7 +226,7 @@ func TestExecutor_CallTool_SendError(t *testing.T) { func TestExecutor_CallTool_Success(t *testing.T) { sender := &fakeSender{} - resolver := &fakeResolver{ct: channel.ChannelType("feishu")} + resolver := &fakeResolver{ct: channel.Type("feishu")} exec := NewExecutor(nil, sender, nil, resolver) session := mcpgw.ToolSessionContext{BotID: "bot1", CurrentPlatform: "feishu", ReplyTarget: "chat1"} result, err := exec.CallTool(context.Background(), session, toolSend, map[string]any{ @@ -252,7 +252,7 @@ func TestExecutor_CallTool_Success(t *testing.T) { func TestExecutor_CallTool_ReplyTo(t *testing.T) { sender := &fakeSender{} - resolver := &fakeResolver{ct: channel.ChannelType("telegram")} + resolver := &fakeResolver{ct: channel.Type("telegram")} exec := NewExecutor(nil, sender, nil, resolver) session := mcpgw.ToolSessionContext{BotID: "bot1", CurrentPlatform: "telegram", ReplyTarget: "123"} result, err := exec.CallTool(context.Background(), session, toolSend, map[string]any{ @@ -275,7 +275,7 @@ func TestExecutor_CallTool_ReplyTo(t *testing.T) { func TestExecutor_CallTool_NoReplyTo(t *testing.T) { sender := &fakeSender{} - resolver := &fakeResolver{ct: channel.ChannelType("telegram")} + resolver := &fakeResolver{ct: channel.Type("telegram")} exec := NewExecutor(nil, sender, nil, resolver) session := mcpgw.ToolSessionContext{BotID: "bot1", CurrentPlatform: "telegram", ReplyTarget: "123"} result, err := exec.CallTool(context.Background(), session, toolSend, map[string]any{ @@ -309,7 +309,7 @@ func TestExecutor_React_NilReactor(t *testing.T) { func TestExecutor_React_NoMessageID(t *testing.T) { reactor := &fakeReactor{} - resolver := &fakeResolver{ct: channel.ChannelType("telegram")} + resolver := &fakeResolver{ct: channel.Type("telegram")} exec := NewExecutor(nil, nil, reactor, resolver) session := mcpgw.ToolSessionContext{BotID: "bot1", CurrentPlatform: "telegram", ReplyTarget: "123"} result, err := exec.CallTool(context.Background(), session, toolReact, map[string]any{ @@ -325,7 +325,7 @@ func TestExecutor_React_NoMessageID(t *testing.T) { func TestExecutor_React_NoTarget(t *testing.T) { reactor := &fakeReactor{} - resolver := &fakeResolver{ct: channel.ChannelType("telegram")} + resolver := &fakeResolver{ct: channel.Type("telegram")} exec := NewExecutor(nil, nil, reactor, resolver) session := mcpgw.ToolSessionContext{BotID: "bot1", CurrentPlatform: "telegram"} result, err := exec.CallTool(context.Background(), session, toolReact, map[string]any{ @@ -341,7 +341,7 @@ func TestExecutor_React_NoTarget(t *testing.T) { func TestExecutor_React_Success(t *testing.T) { reactor := &fakeReactor{} - resolver := &fakeResolver{ct: channel.ChannelType("telegram")} + resolver := &fakeResolver{ct: channel.Type("telegram")} exec := NewExecutor(nil, nil, reactor, resolver) session := mcpgw.ToolSessionContext{BotID: "bot1", CurrentPlatform: "telegram", ReplyTarget: "123"} result, err := exec.CallTool(context.Background(), session, toolReact, map[string]any{ @@ -370,7 +370,7 @@ func TestExecutor_React_Success(t *testing.T) { func TestExecutor_React_Remove(t *testing.T) { reactor := &fakeReactor{} - resolver := &fakeResolver{ct: channel.ChannelType("telegram")} + resolver := &fakeResolver{ct: channel.Type("telegram")} exec := NewExecutor(nil, nil, reactor, resolver) session := mcpgw.ToolSessionContext{BotID: "bot1", CurrentPlatform: "telegram", ReplyTarget: "123"} result, err := exec.CallTool(context.Background(), session, toolReact, map[string]any{ @@ -393,7 +393,7 @@ func TestExecutor_React_Remove(t *testing.T) { func TestExecutor_React_Error(t *testing.T) { reactor := &fakeReactor{err: errors.New("reaction failed")} - resolver := &fakeResolver{ct: channel.ChannelType("telegram")} + resolver := &fakeResolver{ct: channel.Type("telegram")} exec := NewExecutor(nil, nil, reactor, resolver) session := mcpgw.ToolSessionContext{BotID: "bot1", CurrentPlatform: "telegram", ReplyTarget: "123"} result, err := exec.CallTool(context.Background(), session, toolReact, map[string]any{ diff --git a/internal/mcp/providers/schedule/provider.go b/internal/mcp/providers/schedule/provider.go index 156d908a..5fab125c 100644 --- a/internal/mcp/providers/schedule/provider.go +++ b/internal/mcp/providers/schedule/provider.go @@ -1,3 +1,4 @@ +// Package schedule provides the MCP schedule provider (list, get, create, trigger). package schedule import ( @@ -17,6 +18,7 @@ const ( toolScheduleDelete = "delete_schedule" ) +// Scheduler provides schedule list, get, create, update, delete (used by schedule tool executor). type Scheduler interface { List(ctx context.Context, botID string) ([]sched.Schedule, error) Get(ctx context.Context, id string) (sched.Schedule, error) @@ -25,11 +27,13 @@ type Scheduler interface { Delete(ctx context.Context, id string) error } +// Executor is the MCP tool executor for list_schedule, get_schedule, create_schedule, update_schedule, delete_schedule. type Executor struct { service Scheduler logger *slog.Logger } +// NewExecutor creates a schedule tool executor. func NewExecutor(log *slog.Logger, service Scheduler) *Executor { if log == nil { log = slog.Default() @@ -40,7 +44,8 @@ func NewExecutor(log *slog.Logger, service Scheduler) *Executor { } } -func (p *Executor) ListTools(ctx context.Context, session mcpgw.ToolSessionContext) ([]mcpgw.ToolDescriptor, error) { +// ListTools returns list_schedule, get_schedule, create_schedule, update_schedule, delete_schedule descriptors. +func (p *Executor) ListTools(_ context.Context, _ mcpgw.ToolSessionContext) ([]mcpgw.ToolDescriptor, error) { if p.service == nil { return []mcpgw.ToolDescriptor{}, nil } @@ -111,6 +116,7 @@ func (p *Executor) ListTools(ctx context.Context, session mcpgw.ToolSessionConte }, nil } +// CallTool runs list_schedule, get_schedule, create_schedule, update_schedule, or delete_schedule. func (p *Executor) CallTool(ctx context.Context, session mcpgw.ToolSessionContext, toolName string, arguments map[string]any) (map[string]any, error) { if p.service == nil { return mcpgw.BuildToolErrorResult("schedule service not available"), nil diff --git a/internal/mcp/providers/schedule/provider_test.go b/internal/mcp/providers/schedule/provider_test.go index 43d7d544..c456726c 100644 --- a/internal/mcp/providers/schedule/provider_test.go +++ b/internal/mcp/providers/schedule/provider_test.go @@ -21,32 +21,32 @@ type fakeScheduler struct { deleteErr error } -func (f *fakeScheduler) List(ctx context.Context, botID string) ([]sched.Schedule, error) { +func (f *fakeScheduler) List(_ context.Context, _ string) ([]sched.Schedule, error) { return f.list, nil } -func (f *fakeScheduler) Get(ctx context.Context, id string) (sched.Schedule, error) { +func (f *fakeScheduler) Get(_ context.Context, _ string) (sched.Schedule, error) { if f.getErr != nil { return sched.Schedule{}, f.getErr } return f.get, nil } -func (f *fakeScheduler) Create(ctx context.Context, botID string, req sched.CreateRequest) (sched.Schedule, error) { +func (f *fakeScheduler) Create(_ context.Context, _ string, _ sched.CreateRequest) (sched.Schedule, error) { if f.createErr != nil { return sched.Schedule{}, f.createErr } return f.create, nil } -func (f *fakeScheduler) Update(ctx context.Context, id string, req sched.UpdateRequest) (sched.Schedule, error) { +func (f *fakeScheduler) Update(_ context.Context, _ string, _ sched.UpdateRequest) (sched.Schedule, error) { if f.updateErr != nil { return sched.Schedule{}, f.updateErr } return f.update, nil } -func (f *fakeScheduler) Delete(ctx context.Context, id string) error { +func (f *fakeScheduler) Delete(_ context.Context, _ string) error { return f.deleteErr } @@ -83,7 +83,7 @@ func TestExecutor_CallTool_NotFound(t *testing.T) { svc := &fakeScheduler{} exec := NewExecutor(nil, svc) _, err := exec.CallTool(context.Background(), mcpgw.ToolSessionContext{BotID: "bot1"}, "other_tool", nil) - if err != mcpgw.ErrToolNotFound { + if !errors.Is(err, mcpgw.ErrToolNotFound) { t.Errorf("expected ErrToolNotFound, got %v", err) } } diff --git a/internal/mcp/providers/web/provider.go b/internal/mcp/providers/web/provider.go index d568d670..b497f850 100644 --- a/internal/mcp/providers/web/provider.go +++ b/internal/mcp/providers/web/provider.go @@ -1,13 +1,14 @@ +// Package web provides the MCP web provider (Brave Search and fetch tools). package web import ( "context" "encoding/json" - "fmt" "io" "log/slog" "net/http" "net/url" + "strconv" "strings" "time" @@ -20,12 +21,14 @@ const ( toolWebSearch = "web_search" ) +// Executor is the MCP tool executor for web_search (and optional fetch) using configured search provider. type Executor struct { logger *slog.Logger settings *settings.Service searchProviders *searchproviders.Service } +// NewExecutor creates a web tool executor. func NewExecutor(log *slog.Logger, settingsSvc *settings.Service, searchSvc *searchproviders.Service) *Executor { if log == nil { log = slog.Default() @@ -37,7 +40,8 @@ func NewExecutor(log *slog.Logger, settingsSvc *settings.Service, searchSvc *sea } } -func (p *Executor) ListTools(ctx context.Context, session mcpgw.ToolSessionContext) ([]mcpgw.ToolDescriptor, error) { +// ListTools returns web_search (and optionally fetch) tool descriptors. +func (p *Executor) ListTools(_ context.Context, _ mcpgw.ToolSessionContext) ([]mcpgw.ToolDescriptor, error) { if p.settings == nil || p.searchProviders == nil { return []mcpgw.ToolDescriptor{}, nil } @@ -57,6 +61,7 @@ func (p *Executor) ListTools(ctx context.Context, session mcpgw.ToolSessionConte }, nil } +// CallTool runs web_search (or fetch) using the bot's configured search provider. func (p *Executor) CallTool(ctx context.Context, session mcpgw.ToolSessionContext, toolName string, arguments map[string]any) (map[string]any, error) { if p.settings == nil || p.searchProviders == nil { return mcpgw.BuildToolErrorResult("web tools are not available"), nil @@ -112,7 +117,7 @@ func (p *Executor) callWebSearch(ctx context.Context, providerName string, confi } params := reqURL.Query() params.Set("q", query) - params.Set("count", fmt.Sprintf("%d", count)) + params.Set("count", strconv.Itoa(count)) reqURL.RawQuery = params.Encode() timeout := parseTimeout(configJSON, 15*time.Second) @@ -130,7 +135,11 @@ func (p *Executor) callWebSearch(ctx context.Context, providerName string, confi if err != nil { return mcpgw.BuildToolErrorResult(err.Error()), nil } - defer resp.Body.Close() + defer func() { + if err := resp.Body.Close(); err != nil { + p.logger.Warn("web search: close response body failed", slog.Any("error", err)) + } + }() body, err := io.ReadAll(resp.Body) if err != nil { return mcpgw.BuildToolErrorResult(err.Error()), nil diff --git a/internal/mcp/service.go b/internal/mcp/service.go index ba055c4b..e3d9f569 100644 --- a/internal/mcp/service.go +++ b/internal/mcp/service.go @@ -6,6 +6,7 @@ import ( "strconv" ) +// JSONRPCRequest is the JSON-RPC 2.0 request shape (jsonrpc, id, method, params). type JSONRPCRequest struct { JSONRPC string `json:"jsonrpc"` ID json.RawMessage `json:"id"` @@ -13,6 +14,7 @@ type JSONRPCRequest struct { Params json.RawMessage `json:"params,omitempty"` } +// JSONRPCResponse is the JSON-RPC 2.0 response shape (result or error). type JSONRPCResponse struct { JSONRPC string `json:"jsonrpc"` ID json.RawMessage `json:"id,omitempty"` @@ -20,12 +22,14 @@ type JSONRPCResponse struct { Error *JSONRPCError `json:"error,omitempty"` } +// JSONRPCError is the JSON-RPC 2.0 error object (code, message). type JSONRPCError struct { Code int `json:"code"` Message string `json:"message"` } -func NewToolCallRequest(id string, toolName string, args map[string]any) (JSONRPCRequest, error) { +// NewToolCallRequest builds a tools/call JSON-RPC request with the given id, tool name, and arguments. +func NewToolCallRequest(id, toolName string, args map[string]any) (JSONRPCRequest, error) { params := map[string]any{ "name": toolName, "arguments": args, @@ -42,10 +46,12 @@ func NewToolCallRequest(id string, toolName string, args map[string]any) (JSONRP }, nil } +// RawStringID returns a JSON-RPC id as quoted string raw message. func RawStringID(id string) json.RawMessage { return json.RawMessage([]byte(strconv.Quote(id))) } +// PayloadError returns an error if the payload contains a top-level error object. func PayloadError(payload map[string]any) error { if payload == nil { return errors.New("empty payload") @@ -59,6 +65,7 @@ func PayloadError(payload map[string]any) error { return nil } +// ResultError returns an error if result.isError is true (tool call failure). func ResultError(payload map[string]any) error { result, ok := payload["result"].(map[string]any) if !ok { @@ -74,6 +81,7 @@ func ResultError(payload map[string]any) error { return nil } +// StructuredContent extracts result.structuredContent from the payload, or parses result.content text as JSON. func StructuredContent(payload map[string]any) (map[string]any, error) { result, ok := payload["result"].(map[string]any) if !ok { @@ -91,6 +99,7 @@ func StructuredContent(payload map[string]any) (map[string]any, error) { return nil, errors.New("missing structured content") } +// ContentText returns the first content item's text from the MCP result content array. func ContentText(result map[string]any) string { rawContent, ok := result["content"].([]any) if !ok || len(rawContent) == 0 { diff --git a/internal/mcp/sources/federation/source.go b/internal/mcp/sources/federation/source.go index 5fe0d951..f9ca5c24 100644 --- a/internal/mcp/sources/federation/source.go +++ b/internal/mcp/sources/federation/source.go @@ -1,3 +1,4 @@ +// Package federation provides the MCP federation SSE source. package federation import ( @@ -15,10 +16,12 @@ import ( const cacheTTL = 5 * time.Second +// ConnectionLister lists active MCP connections by bot (for federation source). type ConnectionLister interface { ListActiveByBot(ctx context.Context, botID string) ([]mcpgw.Connection, error) } +// Gateway lists and calls tools on HTTP/SSE/Stdio MCP connections. type Gateway interface { ListHTTPConnectionTools(ctx context.Context, connection mcpgw.Connection) ([]mcpgw.ToolDescriptor, error) CallHTTPConnectionTool(ctx context.Context, connection mcpgw.Connection, toolName string, args map[string]any) (map[string]any, error) @@ -42,6 +45,7 @@ type cacheEntry struct { tools []mcpgw.ToolDescriptor } +// Source is a ToolSource that federates tools from remote MCP connections (HTTP/SSE/Stdio) with cache. type Source struct { logger *slog.Logger gateway Gateway @@ -51,6 +55,7 @@ type Source struct { cache map[string]cacheEntry } +// NewSource creates a federation source with gateway and connection lister. func NewSource(log *slog.Logger, gateway Gateway, connections ConnectionLister) *Source { if log == nil { log = slog.Default() @@ -63,6 +68,7 @@ func NewSource(log *slog.Logger, gateway Gateway, connections ConnectionLister) } } +// ListTools returns tools from all active connections for the bot (cached by cacheTTL). func (s *Source) ListTools(ctx context.Context, session mcpgw.ToolSessionContext) ([]mcpgw.ToolDescriptor, error) { botID := strings.TrimSpace(session.BotID) if botID == "" || s.gateway == nil { @@ -80,6 +86,7 @@ func (s *Source) ListTools(ctx context.Context, session mcpgw.ToolSessionContext return cloneTools(tools), nil } +// CallTool routes the tool call to the appropriate connection (HTTP/SSE/Stdio) via gateway. func (s *Source) CallTool(ctx context.Context, session mcpgw.ToolSessionContext, toolName string, arguments map[string]any) (map[string]any, error) { if s.gateway == nil { return mcpgw.BuildToolErrorResult("federation gateway not available"), nil @@ -274,6 +281,7 @@ func (s *Source) getRoute(botID, toolName string) (toolRoute, bool) { return route, exists } +// String returns a short description of the source for logging. func (s *Source) String() string { return fmt.Sprintf("FederationSource(%p)", s) } diff --git a/internal/mcp/sources/federation/source_test.go b/internal/mcp/sources/federation/source_test.go index b591ef44..91fd1833 100644 --- a/internal/mcp/sources/federation/source_test.go +++ b/internal/mcp/sources/federation/source_test.go @@ -13,7 +13,7 @@ type testConnectionLister struct { err error } -func (l *testConnectionLister) ListActiveByBot(ctx context.Context, botID string) ([]mcpgw.Connection, error) { +func (l *testConnectionLister) ListActiveByBot(_ context.Context, _ string) ([]mcpgw.Connection, error) { if l.err != nil { return nil, l.err } @@ -28,29 +28,29 @@ type testGateway struct { lastCallType string } -func (g *testGateway) ListHTTPConnectionTools(ctx context.Context, connection mcpgw.Connection) ([]mcpgw.ToolDescriptor, error) { +func (g *testGateway) ListHTTPConnectionTools(_ context.Context, _ mcpgw.Connection) ([]mcpgw.ToolDescriptor, error) { return g.listHTTP, nil } -func (g *testGateway) CallHTTPConnectionTool(ctx context.Context, connection mcpgw.Connection, toolName string, args map[string]any) (map[string]any, error) { +func (g *testGateway) CallHTTPConnectionTool(_ context.Context, _ mcpgw.Connection, _ string, _ map[string]any) (map[string]any, error) { g.lastCallType = "http" return map[string]any{"result": map[string]any{"ok": true, "route": "http"}}, nil } -func (g *testGateway) ListSSEConnectionTools(ctx context.Context, connection mcpgw.Connection) ([]mcpgw.ToolDescriptor, error) { +func (g *testGateway) ListSSEConnectionTools(_ context.Context, _ mcpgw.Connection) ([]mcpgw.ToolDescriptor, error) { return g.listSSE, nil } -func (g *testGateway) CallSSEConnectionTool(ctx context.Context, connection mcpgw.Connection, toolName string, args map[string]any) (map[string]any, error) { +func (g *testGateway) CallSSEConnectionTool(_ context.Context, _ mcpgw.Connection, _ string, _ map[string]any) (map[string]any, error) { g.lastCallType = "sse" return map[string]any{"result": map[string]any{"ok": true, "route": "sse"}}, nil } -func (g *testGateway) ListStdioConnectionTools(ctx context.Context, botID string, connection mcpgw.Connection) ([]mcpgw.ToolDescriptor, error) { +func (g *testGateway) ListStdioConnectionTools(_ context.Context, _ string, _ mcpgw.Connection) ([]mcpgw.ToolDescriptor, error) { return g.listStdio, nil } -func (g *testGateway) CallStdioConnectionTool(ctx context.Context, botID string, connection mcpgw.Connection, toolName string, args map[string]any) (map[string]any, error) { +func (g *testGateway) CallStdioConnectionTool(_ context.Context, _ string, _ mcpgw.Connection, _ string, _ map[string]any) (map[string]any, error) { g.lastCallType = "stdio" return map[string]any{"result": map[string]any{"ok": true, "route": "stdio"}}, nil } diff --git a/internal/mcp/tool_gateway_service.go b/internal/mcp/tool_gateway_service.go index ad202b59..be24eff3 100644 --- a/internal/mcp/tool_gateway_service.go +++ b/internal/mcp/tool_gateway_service.go @@ -3,7 +3,6 @@ package mcp import ( "context" "errors" - "fmt" "log/slog" "strings" "sync" @@ -30,6 +29,7 @@ type ToolGatewayService struct { cache map[string]cachedToolRegistry } +// NewToolGatewayService creates a gateway that aggregates tools from executors and sources (with cache). func NewToolGatewayService(log *slog.Logger, executors []ToolExecutor, sources []ToolSource) *ToolGatewayService { if log == nil { log = slog.Default() @@ -55,6 +55,7 @@ func NewToolGatewayService(log *slog.Logger, executors []ToolExecutor, sources [ } } +// InitializeResult returns the MCP initialize response (protocol version, capabilities, server info). func (s *ToolGatewayService) InitializeResult() map[string]any { return map[string]any{ "protocolVersion": "2025-06-18", @@ -70,6 +71,7 @@ func (s *ToolGatewayService) InitializeResult() map[string]any { } } +// ListTools returns all tools from executors and sources for the session (cached by cacheTTL). func (s *ToolGatewayService) ListTools(ctx context.Context, session ToolSessionContext) ([]ToolDescriptor, error) { registry, err := s.getRegistry(ctx, session, false) if err != nil { @@ -78,10 +80,11 @@ func (s *ToolGatewayService) ListTools(ctx context.Context, session ToolSessionC return registry.List(), nil } +// CallTool looks up the tool by name, delegates to the executor/source, and returns the MCP result map. func (s *ToolGatewayService) CallTool(ctx context.Context, session ToolSessionContext, payload ToolCallPayload) (map[string]any, error) { toolName := strings.TrimSpace(payload.Name) if toolName == "" { - return nil, fmt.Errorf("tool name is required") + return nil, errors.New("tool name is required") } registry, err := s.getRegistry(ctx, session, false) @@ -121,7 +124,7 @@ func (s *ToolGatewayService) CallTool(ctx context.Context, session ToolSessionCo func (s *ToolGatewayService) getRegistry(ctx context.Context, session ToolSessionContext, force bool) (*ToolRegistry, error) { botID := strings.TrimSpace(session.BotID) if botID == "" { - return nil, fmt.Errorf("bot id is required") + return nil, errors.New("bot id is required") } if !force { s.mu.Lock() diff --git a/internal/mcp/tool_gateway_service_test.go b/internal/mcp/tool_gateway_service_test.go index 3509f7ef..46cfb807 100644 --- a/internal/mcp/tool_gateway_service_test.go +++ b/internal/mcp/tool_gateway_service_test.go @@ -13,11 +13,11 @@ type gatewayTestProvider struct { callErr map[string]error } -func (p *gatewayTestProvider) ListTools(ctx context.Context, session ToolSessionContext) ([]ToolDescriptor, error) { +func (p *gatewayTestProvider) ListTools(_ context.Context, _ ToolSessionContext) ([]ToolDescriptor, error) { return p.tools, nil } -func (p *gatewayTestProvider) CallTool(ctx context.Context, session ToolSessionContext, toolName string, arguments map[string]any) (map[string]any, error) { +func (p *gatewayTestProvider) CallTool(_ context.Context, _ ToolSessionContext, toolName string, _ map[string]any) (map[string]any, error) { if err, ok := p.callErr[toolName]; ok { return nil, err } diff --git a/internal/mcp/tool_registry.go b/internal/mcp/tool_registry.go index edd7552c..69f4d386 100644 --- a/internal/mcp/tool_registry.go +++ b/internal/mcp/tool_registry.go @@ -1,6 +1,7 @@ package mcp import ( + "errors" "fmt" "sort" "strings" @@ -11,24 +12,26 @@ type registryItem struct { tool ToolDescriptor } -// ToolRegistry stores provider ownership and descriptor metadata. +// ToolRegistry stores tool name to executor and descriptor (single owner per name). type ToolRegistry struct { items map[string]registryItem } +// NewToolRegistry creates an empty registry. func NewToolRegistry() *ToolRegistry { return &ToolRegistry{ items: map[string]registryItem{}, } } +// Register adds a tool; returns error if name is empty or already registered. func (r *ToolRegistry) Register(executor ToolExecutor, tool ToolDescriptor) error { if executor == nil { - return fmt.Errorf("tool executor is required") + return errors.New("tool executor is required") } name := strings.TrimSpace(tool.Name) if name == "" { - return fmt.Errorf("tool name is required") + return errors.New("tool name is required") } if tool.InputSchema == nil { tool.InputSchema = map[string]any{ @@ -47,6 +50,7 @@ func (r *ToolRegistry) Register(executor ToolExecutor, tool ToolDescriptor) erro return nil } +// Lookup returns the executor and descriptor for the tool name, or false if not found. func (r *ToolRegistry) Lookup(name string) (ToolExecutor, ToolDescriptor, bool) { item, ok := r.items[strings.TrimSpace(name)] if !ok { @@ -55,6 +59,7 @@ func (r *ToolRegistry) Lookup(name string) (ToolExecutor, ToolDescriptor, bool) return item.executor, item.tool, true } +// List returns all tool descriptors sorted by name. func (r *ToolRegistry) List() []ToolDescriptor { if len(r.items) == 0 { return []ToolDescriptor{} diff --git a/internal/mcp/tool_registry_test.go b/internal/mcp/tool_registry_test.go index f5001d9d..93e446d4 100644 --- a/internal/mcp/tool_registry_test.go +++ b/internal/mcp/tool_registry_test.go @@ -7,11 +7,11 @@ import ( type registryTestProvider struct{} -func (p *registryTestProvider) ListTools(ctx context.Context, session ToolSessionContext) ([]ToolDescriptor, error) { +func (p *registryTestProvider) ListTools(_ context.Context, _ ToolSessionContext) ([]ToolDescriptor, error) { return nil, nil } -func (p *registryTestProvider) CallTool(ctx context.Context, session ToolSessionContext, toolName string, arguments map[string]any) (map[string]any, error) { +func (p *registryTestProvider) CallTool(_ context.Context, _ ToolSessionContext, _ string, _ map[string]any) (map[string]any, error) { return nil, nil } diff --git a/internal/mcp/tool_types.go b/internal/mcp/tool_types.go index 9a556ec5..e1a5085d 100644 --- a/internal/mcp/tool_types.go +++ b/internal/mcp/tool_types.go @@ -3,6 +3,7 @@ package mcp import ( "context" "encoding/json" + "errors" "fmt" "math" "strings" @@ -45,7 +46,7 @@ type ToolCallPayload struct { } // ErrToolNotFound indicates the provider does not own the requested tool. -var ErrToolNotFound = fmt.Errorf("tool not found") +var ErrToolNotFound = errors.New("tool not found") // BuildToolSuccessResult builds a standard MCP tool success result object. func BuildToolSuccessResult(structured any) map[string]any { @@ -105,6 +106,7 @@ func stringifyStructuredContent(v any) string { } } +// StringArg returns the string value for key in arguments, or empty string if missing/invalid. func StringArg(arguments map[string]any, key string) string { if arguments == nil { return "" @@ -121,6 +123,7 @@ func StringArg(arguments map[string]any, key string) string { } } +// FirstStringArg returns the first non-empty string value for the given keys in arguments. func FirstStringArg(arguments map[string]any, keys ...string) string { for _, key := range keys { if value := StringArg(arguments, key); value != "" { @@ -130,6 +133,7 @@ func FirstStringArg(arguments map[string]any, keys ...string) string { return "" } +// IntArg parses the key in arguments as an integer; returns value, true if present, or error if wrong type. func IntArg(arguments map[string]any, key string) (int, bool, error) { if arguments == nil { return 0, false, nil @@ -181,6 +185,7 @@ func IntArg(arguments map[string]any, key string) (int, bool, error) { } } +// BoolArg returns the boolean value for key in arguments; second return is true if key present, error if wrong type. func BoolArg(arguments map[string]any, key string) (bool, bool, error) { if arguments == nil { return false, false, nil diff --git a/internal/mcp/versioning.go b/internal/mcp/versioning.go index f9de5043..768f7df4 100644 --- a/internal/mcp/versioning.go +++ b/internal/mcp/versioning.go @@ -3,7 +3,9 @@ package mcp import ( "context" "encoding/json" + "errors" "fmt" + "log/slog" "time" "github.com/containerd/containerd/v2/pkg/oci" @@ -12,12 +14,12 @@ import ( "github.com/opencontainers/runtime-spec/specs-go" "github.com/memohai/memoh/internal/config" - ctr "github.com/memohai/memoh/internal/containerd" "github.com/memohai/memoh/internal/db" dbsqlc "github.com/memohai/memoh/internal/db/sqlc" ) +// VersionInfo holds version record (id, version number, snapshot id, created_at). type VersionInfo struct { ID string Version int @@ -25,9 +27,10 @@ type VersionInfo struct { CreatedAt time.Time } +// CreateVersion commits the current container snapshot, creates a new container from it, and records the version. func (m *Manager) CreateVersion(ctx context.Context, userID string) (*VersionInfo, error) { if m.db == nil || m.queries == nil { - return nil, fmt.Errorf("db is not configured") + return nil, errors.New("db is not configured") } if err := validateBotID(userID); err != nil { return nil, err @@ -128,9 +131,10 @@ func (m *Manager) CreateVersion(ctx context.Context, userID string) (*VersionInf }, nil } +// ListVersions returns version records for the bot (userID) from DB, newest first. func (m *Manager) ListVersions(ctx context.Context, userID string) ([]VersionInfo, error) { if m.db == nil || m.queries == nil { - return nil, fmt.Errorf("db is not configured") + return nil, errors.New("db is not configured") } if err := validateBotID(userID); err != nil { return nil, err @@ -158,9 +162,10 @@ func (m *Manager) ListVersions(ctx context.Context, userID string) ([]VersionInf return out, nil } +// RollbackVersion restores the container from the given version snapshot (delete current, create from snapshot). func (m *Manager) RollbackVersion(ctx context.Context, userID string, version int) error { if m.db == nil || m.queries == nil { - return fmt.Errorf("db is not configured") + return errors.New("db is not configured") } if err := validateBotID(userID); err != nil { return err @@ -244,9 +249,10 @@ func (m *Manager) RollbackVersion(ctx context.Context, userID string, version in }) } +// VersionSnapshotID returns the snapshot ID for the given version number from DB. func (m *Manager) VersionSnapshotID(ctx context.Context, userID string, version int) (string, error) { if m.db == nil || m.queries == nil { - return "", fmt.Errorf("db is not configured") + return "", errors.New("db is not configured") } if err := validateBotID(userID); err != nil { return "", err @@ -273,7 +279,7 @@ func (m *Manager) safeStopTask(ctx context.Context, containerID string) error { return err } -func (m *Manager) ensureDBRecords(ctx context.Context, botID, containerID, runtime, imageRef string) (pgtype.UUID, error) { +func (m *Manager) ensureDBRecords(ctx context.Context, botID, containerID, _ string, imageRef string) (pgtype.UUID, error) { hostPath, err := m.DataDir(botID) if err != nil { return pgtype.UUID{}, err @@ -315,7 +321,11 @@ func (m *Manager) insertVersion(ctx context.Context, containerID, snapshotID, sn if err != nil { return "", 0, time.Time{}, err } - defer tx.Rollback(ctx) + defer func() { + if err := tx.Rollback(ctx); err != nil { + m.logger.Warn("insert version: tx rollback failed", slog.Any("error", err)) + } + }() qtx := m.queries.WithTx(tx) @@ -369,4 +379,3 @@ func (m *Manager) insertEvent(ctx context.Context, containerID, eventType string Payload: b, }) } - diff --git a/internal/memory/indexer.go b/internal/memory/indexer.go index 534f4ef1..69aa5eef 100644 --- a/internal/memory/indexer.go +++ b/internal/memory/indexer.go @@ -1,3 +1,4 @@ +// Package memory provides in-memory and Qdrant-backed memory stores and indexing. package memory import ( @@ -5,12 +6,11 @@ import ( "hash/fnv" "log/slog" "math" - "sort" + "slices" "strings" "sync" - "github.com/blevesearch/bleve/v2/registry" - + // Register bleve analysis and language analyzers for full-text indexing. _ "github.com/blevesearch/bleve/v2/analysis/analyzer/standard" _ "github.com/blevesearch/bleve/v2/analysis/lang/ar" _ "github.com/blevesearch/bleve/v2/analysis/lang/bg" @@ -42,6 +42,7 @@ import ( _ "github.com/blevesearch/bleve/v2/analysis/lang/ru" _ "github.com/blevesearch/bleve/v2/analysis/lang/sv" _ "github.com/blevesearch/bleve/v2/analysis/lang/tr" + "github.com/blevesearch/bleve/v2/registry" ) const ( @@ -52,6 +53,7 @@ const ( sparseDimMask = sparseDimSize - 1 ) +// BM25Indexer provides BM25 sparse indexing for full-text memory search (term weights, stats). type BM25Indexer struct { cache *registry.Cache logger *slog.Logger @@ -68,6 +70,7 @@ type bm25Stats struct { DocFreq map[string]int } +// NewBM25Indexer creates a BM25 indexer with default k1/b and a language-aware analyzer cache. func NewBM25Indexer(log *slog.Logger) *BM25Indexer { if log == nil { log = slog.Default() @@ -81,6 +84,7 @@ func NewBM25Indexer(log *slog.Logger) *BM25Indexer { } } +// TermFrequencies tokenizes text for the given language and returns term frequencies and document length. func (b *BM25Indexer) TermFrequencies(lang, text string) (map[string]int, int, error) { analyzerName, err := b.normalizeAnalyzer(lang) if err != nil { @@ -104,6 +108,7 @@ func (b *BM25Indexer) TermFrequencies(lang, text string) (map[string]int, int, e return freq, docLen, nil } +// AddDocument indexes a document and returns the sparse vector (indices, values) for storage. func (b *BM25Indexer) AddDocument(lang string, termFreq map[string]int, docLen int) (indices []uint32, values []float32) { b.mu.Lock() stats := b.ensureStatsLocked(lang) @@ -113,6 +118,7 @@ func (b *BM25Indexer) AddDocument(lang string, termFreq map[string]int, docLen i return indices, values } +// RemoveDocument updates BM25 stats by removing the document (e.g. before update). func (b *BM25Indexer) RemoveDocument(lang string, termFreq map[string]int, docLen int) { b.mu.Lock() stats := b.ensureStatsLocked(lang) @@ -120,6 +126,7 @@ func (b *BM25Indexer) RemoveDocument(lang string, termFreq map[string]int, docLe b.mu.Unlock() } +// BuildQueryVector builds the sparse query vector for the given term frequencies (for SearchSparse). func (b *BM25Indexer) BuildQueryVector(lang string, termFreq map[string]int) (indices []uint32, values []float32) { b.mu.RLock() stats := b.ensureStatsLocked(lang) @@ -237,7 +244,7 @@ func sparseWeightsToVector(weights map[uint32]float32) ([]uint32, []float32) { for idx := range weights { indices = append(indices, idx) } - sort.Slice(indices, func(i, j int) bool { return indices[i] < indices[j] }) + slices.Sort(indices) values := make([]float32, 0, len(indices)) for _, idx := range indices { values = append(values, weights[idx]) diff --git a/internal/memory/llm_client.go b/internal/memory/llm_client.go index 967b1c64..c58d5333 100644 --- a/internal/memory/llm_client.go +++ b/internal/memory/llm_client.go @@ -4,14 +4,17 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "log/slog" "net/http" + "strconv" "strings" "time" ) +// LLMClient calls an OpenAI-compatible chat API for Extract, Decide, Compact, DetectLanguage. type LLMClient struct { baseURL string apiKey string @@ -20,15 +23,16 @@ type LLMClient struct { http *http.Client } +// NewLLMClient builds an LLM client; baseURL, apiKey, model required; timeout defaults to 10s. func NewLLMClient(log *slog.Logger, baseURL, apiKey, model string, timeout time.Duration) (*LLMClient, error) { if strings.TrimSpace(baseURL) == "" { - return nil, fmt.Errorf("llm client: base url is required") + return nil, errors.New("llm client: base url is required") } if strings.TrimSpace(apiKey) == "" { - return nil, fmt.Errorf("llm client: api key is required") + return nil, errors.New("llm client: api key is required") } if strings.TrimSpace(model) == "" { - return nil, fmt.Errorf("llm client: model is required") + return nil, errors.New("llm client: model is required") } if log == nil { log = slog.Default() @@ -47,9 +51,10 @@ func NewLLMClient(log *slog.Logger, baseURL, apiKey, model string, timeout time. }, nil } +// Extract calls the LLM to extract facts from messages and returns structured facts. func (c *LLMClient) Extract(ctx context.Context, req ExtractRequest) (ExtractResponse, error) { if len(req.Messages) == 0 { - return ExtractResponse{}, fmt.Errorf("messages is required") + return ExtractResponse{}, errors.New("messages is required") } parsedMessages := strings.Join(formatMessages(req.Messages), "\n") systemPrompt, userPrompt := getFactRetrievalMessages(parsedMessages) @@ -68,9 +73,10 @@ func (c *LLMClient) Extract(ctx context.Context, req ExtractRequest) (ExtractRes return parsed, nil } +// Decide calls the LLM to decide add/update/delete actions from facts and candidates. func (c *LLMClient) Decide(ctx context.Context, req DecideRequest) (DecideResponse, error) { if len(req.Facts) == 0 { - return DecideResponse{}, fmt.Errorf("facts is required") + return DecideResponse{}, errors.New("facts is required") } retrieved := make([]map[string]string, 0, len(req.Candidates)) for _, candidate := range req.Candidates { @@ -130,9 +136,10 @@ func (c *LLMClient) Decide(ctx context.Context, req DecideRequest) (DecideRespon return DecideResponse{Actions: actions}, nil } +// Compact calls the LLM to consolidate memories into fewer facts (target count, optional decay). func (c *LLMClient) Compact(ctx context.Context, req CompactRequest) (CompactResponse, error) { if len(req.Memories) == 0 { - return CompactResponse{}, fmt.Errorf("memories is required") + return CompactResponse{}, errors.New("memories is required") } memories := make([]map[string]string, 0, len(req.Memories)) for _, m := range req.Memories { @@ -160,9 +167,10 @@ func (c *LLMClient) Compact(ctx context.Context, req CompactRequest) (CompactRes return parsed, nil } +// DetectLanguage calls the LLM to detect the language code of the given text. func (c *LLMClient) DetectLanguage(ctx context.Context, text string) (string, error) { if strings.TrimSpace(text) == "" { - return "", fmt.Errorf("text is required") + return "", errors.New("text is required") } systemPrompt, userPrompt := getLanguageDetectionMessages(text) content, err := c.callChat(ctx, []chatMessage{ @@ -205,7 +213,7 @@ type chatResponse struct { func (c *LLMClient) callChat(ctx context.Context, messages []chatMessage) (string, error) { if c.apiKey == "" { - return "", fmt.Errorf("llm api key is required") + return "", errors.New("llm api key is required") } body, err := json.Marshal(chatRequest{ Model: c.model, @@ -229,7 +237,11 @@ func (c *LLMClient) callChat(ctx context.Context, messages []chatMessage) (strin if err != nil { return "", err } - defer resp.Body.Close() + defer func() { + if err := resp.Body.Close(); err != nil { + c.logger.Warn("llm request: close response body failed", slog.Any("error", err)) + } + }() if resp.StatusCode < 200 || resp.StatusCode >= 300 { b, _ := io.ReadAll(resp.Body) @@ -241,7 +253,7 @@ func (c *LLMClient) callChat(ctx context.Context, messages []chatMessage) (strin return "", err } if len(parsed.Choices) == 0 || parsed.Choices[0].Message.Content == "" { - return "", fmt.Errorf("llm response missing content") + return "", errors.New("llm response missing content") } return parsed.Choices[0].Message.Content, nil } @@ -260,13 +272,13 @@ func asString(value any) string { return typed case float64: if typed == float64(int64(typed)) { - return fmt.Sprintf("%d", int64(typed)) + return strconv.FormatInt(int64(typed), 10) } return fmt.Sprintf("%f", typed) case int: - return fmt.Sprintf("%d", typed) + return strconv.Itoa(typed) case int64: - return fmt.Sprintf("%d", typed) + return strconv.FormatInt(typed, 10) default: return "" } diff --git a/internal/memory/llm_client_test.go b/internal/memory/llm_client_test.go index 4633d4d5..418b2ea6 100644 --- a/internal/memory/llm_client_test.go +++ b/internal/memory/llm_client_test.go @@ -16,7 +16,7 @@ func TestLLMClientExtract(t *testing.T) { return } w.WriteHeader(http.StatusOK) - w.Write([]byte(`{"choices":[{"message":{"content":"{\"facts\":[\"hello\"]}"}}]}`)) + _, _ = w.Write([]byte(`{"choices":[{"message":{"content":"{\"facts\":[\"hello\"]}"}}]}`)) })) defer server.Close() diff --git a/internal/memory/memoryfs.go b/internal/memory/memoryfs.go index 59b729f8..f709b06d 100644 --- a/internal/memory/memoryfs.go +++ b/internal/memory/memoryfs.go @@ -3,6 +3,7 @@ package memory import ( "context" "encoding/json" + "errors" "fmt" "log/slog" "strings" @@ -19,8 +20,8 @@ const ( manifestVer = 1 ) -// MemoryFS persists memory entries as files inside the bot container via ExecRunner. -type MemoryFS struct { +// FS persists memory entries as files inside the bot container via ExecRunner. +type FS struct { execRunner container.ExecRunner workDir string // e.g. "/data" logger *slog.Logger @@ -42,15 +43,15 @@ type ManifestEntry struct { Filters map[string]any `json:"filters,omitempty"` } -// NewMemoryFS creates a MemoryFS that writes through the given ExecRunner. -func NewMemoryFS(log *slog.Logger, runner container.ExecRunner, workDir string) *MemoryFS { +// NewFS creates a FS that writes through the given ExecRunner. +func NewFS(log *slog.Logger, runner container.ExecRunner, workDir string) *FS { if log == nil { log = slog.Default() } if strings.TrimSpace(workDir) == "" { workDir = "/data" } - return &MemoryFS{ + return &FS{ execRunner: runner, workDir: workDir, logger: log.With(slog.String("component", "memoryfs")), @@ -61,7 +62,7 @@ func NewMemoryFS(log *slog.Logger, runner container.ExecRunner, workDir string) // PersistMemories writes .md files for new items and incrementally updates the manifest. // Used after Add — does NOT delete existing files. -func (fs *MemoryFS) PersistMemories(ctx context.Context, botID string, items []MemoryItem, filters map[string]any) error { +func (fs *FS) PersistMemories(ctx context.Context, botID string, items []Item, filters map[string]any) error { if len(items) == 0 { return nil } @@ -100,7 +101,7 @@ func (fs *MemoryFS) PersistMemories(ctx context.Context, botID string, items []M // RebuildFiles does a full replace: deletes all old memory/*.md files, writes new ones, // and rewrites manifest from scratch. Used after Compact. -func (fs *MemoryFS) RebuildFiles(ctx context.Context, botID string, items []MemoryItem, filters map[string]any) error { +func (fs *FS) RebuildFiles(ctx context.Context, botID string, items []Item, filters map[string]any) error { fs.mu.Lock() defer fs.mu.Unlock() @@ -132,7 +133,7 @@ func (fs *MemoryFS) RebuildFiles(ctx context.Context, botID string, items []Memo } // RemoveMemories removes specific memory files from the FS and updates the manifest. -func (fs *MemoryFS) RemoveMemories(ctx context.Context, botID string, ids []string) error { +func (fs *FS) RemoveMemories(ctx context.Context, botID string, ids []string) error { if len(ids) == 0 { return nil } @@ -158,7 +159,7 @@ func (fs *MemoryFS) RemoveMemories(ctx context.Context, botID string, ids []stri } // RemoveAllMemories deletes all memory files and the manifest. -func (fs *MemoryFS) RemoveAllMemories(ctx context.Context, botID string) error { +func (fs *FS) RemoveAllMemories(ctx context.Context, botID string) error { fs.mu.Lock() defer fs.mu.Unlock() @@ -174,13 +175,13 @@ func (fs *MemoryFS) RemoveAllMemories(ctx context.Context, botID string) error { // ----- read operations ----- // ReadManifest reads and parses the manifest.json file. -func (fs *MemoryFS) ReadManifest(ctx context.Context, botID string) (*Manifest, error) { +func (fs *FS) ReadManifest(ctx context.Context, botID string) (*Manifest, error) { fs.mu.Lock() defer fs.mu.Unlock() return fs.readManifestLocked(ctx, botID) } -func (fs *MemoryFS) readManifestLocked(ctx context.Context, botID string) (*Manifest, error) { +func (fs *FS) readManifestLocked(ctx context.Context, botID string) (*Manifest, error) { content, err := container.ExecRead(ctx, fs.execRunner, botID, fs.workDir, manifestPath) if err != nil { return nil, err @@ -193,13 +194,13 @@ func (fs *MemoryFS) readManifestLocked(ctx context.Context, botID string) (*Mani } // ReadAllMemoryFiles lists and reads all .md files under memory/ and parses their frontmatter. -func (fs *MemoryFS) ReadAllMemoryFiles(ctx context.Context, botID string) ([]MemoryItem, error) { +func (fs *FS) ReadAllMemoryFiles(ctx context.Context, botID string) ([]Item, error) { entries, err := container.ExecList(ctx, fs.execRunner, botID, fs.workDir, memoryDirPath, false) if err != nil { return nil, fmt.Errorf("list memory dir: %w", err) } - var items []MemoryItem + var items []Item for _, entry := range entries { if entry.IsDir || !strings.HasSuffix(entry.Path, ".md") { continue @@ -222,13 +223,13 @@ func (fs *MemoryFS) ReadAllMemoryFiles(ctx context.Context, botID string) ([]Mem // ----- internal helpers ----- -func (fs *MemoryFS) writeMemoryFile(ctx context.Context, botID string, item MemoryItem) error { +func (fs *FS) writeMemoryFile(ctx context.Context, botID string, item Item) error { content := formatMemoryMD(item) filePath := fmt.Sprintf("%s/%s.md", memoryDirPath, item.ID) return container.ExecWrite(ctx, fs.execRunner, botID, fs.workDir, filePath, content) } -func (fs *MemoryFS) writeManifest(ctx context.Context, botID string, manifest *Manifest) error { +func (fs *FS) writeManifest(ctx context.Context, botID string, manifest *Manifest) error { data, err := json.MarshalIndent(manifest, "", " ") if err != nil { return fmt.Errorf("marshal manifest: %w", err) @@ -237,7 +238,7 @@ func (fs *MemoryFS) writeManifest(ctx context.Context, botID string, manifest *M } // execDeleteDir removes all files inside a directory (but keeps the directory itself). -func (fs *MemoryFS) execDeleteDir(ctx context.Context, botID, dirPath string) { +func (fs *FS) execDeleteDir(ctx context.Context, botID, dirPath string) { // Use find + rm to avoid shell quoting issues with glob wildcards. script := fmt.Sprintf("find %s -type f -delete 2>/dev/null; true", container.ShellQuote(dirPath)) _, err := fs.execRunner.ExecWithCapture(ctx, mcpgw.ExecRequest{ @@ -251,8 +252,8 @@ func (fs *MemoryFS) execDeleteDir(ctx context.Context, botID, dirPath string) { } // execDeleteFile removes a single file. -func (fs *MemoryFS) execDeleteFile(ctx context.Context, botID, filePath string) { - script := fmt.Sprintf("rm -f %s", container.ShellQuote(filePath)) +func (fs *FS) execDeleteFile(ctx context.Context, botID, filePath string) { + script := "rm -f " + container.ShellQuote(filePath) _, err := fs.execRunner.ExecWithCapture(ctx, mcpgw.ExecRequest{ BotID: botID, Command: []string{"/bin/sh", "-c", script}, @@ -265,7 +266,7 @@ func (fs *MemoryFS) execDeleteFile(ctx context.Context, botID, filePath string) // ----- .md formatting / parsing ----- -func formatMemoryMD(item MemoryItem) string { +func formatMemoryMD(item Item) string { var b strings.Builder b.WriteString("---\n") b.WriteString(fmt.Sprintf("id: %s\n", item.ID)) @@ -284,21 +285,21 @@ func formatMemoryMD(item MemoryItem) string { return b.String() } -func parseMemoryMD(content string) (MemoryItem, error) { +func parseMemoryMD(content string) (Item, error) { content = strings.TrimSpace(content) if !strings.HasPrefix(content, "---") { - return MemoryItem{}, fmt.Errorf("missing frontmatter") + return Item{}, errors.New("missing frontmatter") } // Split on "---" delimiters. parts := strings.SplitN(content[3:], "---", 2) if len(parts) < 2 { - return MemoryItem{}, fmt.Errorf("incomplete frontmatter") + return Item{}, errors.New("incomplete frontmatter") } frontmatter := strings.TrimSpace(parts[0]) body := strings.TrimSpace(parts[1]) - item := MemoryItem{Memory: body} - for _, line := range strings.Split(frontmatter, "\n") { + item := Item{Memory: body} + for line := range strings.SplitSeq(frontmatter, "\n") { line = strings.TrimSpace(line) if line == "" { continue @@ -321,7 +322,7 @@ func parseMemoryMD(content string) (MemoryItem, error) { } } if item.ID == "" { - return MemoryItem{}, fmt.Errorf("missing id in frontmatter") + return Item{}, errors.New("missing id in frontmatter") } return item, nil } diff --git a/internal/memory/prompts.go b/internal/memory/prompts.go index 42280e4b..829fd72f 100644 --- a/internal/memory/prompts.go +++ b/internal/memory/prompts.go @@ -63,7 +63,7 @@ Following is a conversation between the user and the assistant. You have to extr You should detect the language of the user input and record the facts in the same language. `, time.Now().UTC().Format("2006-01-02"), "```json", "```") - userPrompt := fmt.Sprintf("Following is a conversation between the user and the assistant. You have to extract the relevant facts and preferences about the user, if any, from the conversation and return them in the JSON format as shown above.\n\nInput:\n%s", parsedMessages) + userPrompt := "Following is a conversation between the user and the assistant. You have to extract the relevant facts and preferences about the user, if any, from the conversation and return them in the JSON format as shown above.\n\nInput:\n" + parsedMessages return systemPrompt, userPrompt } @@ -106,7 +106,7 @@ Follow the instruction mentioned below: Do not return anything except the JSON format.`, toJSON(retrievedOldMemory), toJSON(newRetrievedFacts), "```json", "```") } -func getCompactMemoryMessages(memories []map[string]string, targetCount int, decayDays int) (string, string) { +func getCompactMemoryMessages(memories []map[string]string, targetCount, decayDays int) (string, string) { decayInstruction := "" if decayDays > 0 { decayInstruction = fmt.Sprintf(` @@ -152,7 +152,7 @@ Do not include any extra keys, comments, or formatting. Output must be valid JSO If the text is Chinese, Japanese, or Korean, output exactly {"language":"cjk"}. Never output "zh", "zh-cn", "zh-tw", "ja", "ko", or any code not in the allowed list. Before finalizing, verify the value is one of the allowed codes.` - userPrompt := fmt.Sprintf("Text:\n%s", text) + userPrompt := "Text:\n" + text return systemPrompt, userPrompt } diff --git a/internal/memory/qdrant_store.go b/internal/memory/qdrant_store.go index dfcd3cc4..5789e3d2 100644 --- a/internal/memory/qdrant_store.go +++ b/internal/memory/qdrant_store.go @@ -2,8 +2,10 @@ package memory import ( "context" + "errors" "fmt" "log/slog" + "maps" "net/url" "strconv" "strings" @@ -19,6 +21,7 @@ const ( sparseVocabVectorName = "sparse_vocab" ) +// QdrantStore is the Qdrant-backed vector store for dense and optional sparse (BM25) vectors. type QdrantStore struct { client *qdrant.Client collection string @@ -33,7 +36,8 @@ type QdrantStore struct { usesSparseVectors bool } -type qdrantPoint struct { +// QdrantPoint is a single point (id, vector, sparse, payload) for Qdrant store operations. +type QdrantPoint struct { ID string `json:"id"` Vector []float32 `json:"vector"` VectorName string `json:"vector_name,omitempty"` @@ -43,6 +47,7 @@ type qdrantPoint struct { Payload map[string]any `json:"payload,omitempty"` } +// NewQdrantStore creates a Qdrant store; collection defaults to "memory"; dimension or sparseVectorName required. func NewQdrantStore(log *slog.Logger, baseURL, apiKey, collection string, dimension int, sparseVectorName string, timeout time.Duration) (*QdrantStore, error) { host, port, useTLS, err := parseQdrantEndpoint(baseURL) if err != nil { @@ -55,7 +60,7 @@ func NewQdrantStore(log *slog.Logger, baseURL, apiKey, collection string, dimens collection = "memory" } if dimension <= 0 && strings.TrimSpace(sparseVectorName) == "" { - return nil, fmt.Errorf("embedding dimension is required") + return nil, errors.New("embedding dimension is required") } cfg := &qdrant.Config{ @@ -89,10 +94,12 @@ func NewQdrantStore(log *slog.Logger, baseURL, apiKey, collection string, dimens return store, nil } +// NewSibling creates a new QdrantStore for a different collection/dimension using the same connection config. func (s *QdrantStore) NewSibling(collection string, dimension int) (*QdrantStore, error) { return NewQdrantStore(s.logger, s.baseURL, s.apiKey, collection, dimension, s.sparseVectorName, s.timeout) } +// NewQdrantStoreWithVectors creates a store with named vectors (map[name]dimension); used for multi-model collections. func NewQdrantStoreWithVectors(log *slog.Logger, baseURL, apiKey, collection string, vectors map[string]int, sparseVectorName string, timeout time.Duration) (*QdrantStore, error) { host, port, useTLS, err := parseQdrantEndpoint(baseURL) if err != nil { @@ -105,7 +112,7 @@ func NewQdrantStoreWithVectors(log *slog.Logger, baseURL, apiKey, collection str collection = "memory" } if len(vectors) == 0 { - return nil, fmt.Errorf("vectors map is required") + return nil, errors.New("vectors map is required") } cfg := &qdrant.Config{ @@ -140,7 +147,8 @@ func NewQdrantStoreWithVectors(log *slog.Logger, baseURL, apiKey, collection str return store, nil } -func (s *QdrantStore) Upsert(ctx context.Context, points []qdrantPoint) error { +// Upsert inserts or overwrites points (dense and/or sparse vectors, payload) in the collection. +func (s *QdrantStore) Upsert(ctx context.Context, points []QdrantPoint) error { if len(points) == 0 { return nil } @@ -153,11 +161,12 @@ func (s *QdrantStore) Upsert(ctx context.Context, points []qdrantPoint) error { var vectors *qdrant.Vectors vectorMap := map[string]*qdrant.Vector{} if len(point.Vector) > 0 { - if point.VectorName != "" && s.usesNamedVectors { + switch { + case point.VectorName != "" && s.usesNamedVectors: vectorMap[point.VectorName] = qdrant.NewVectorDense(point.Vector) - } else if !s.usesNamedVectors && len(point.SparseIndices) == 0 { + case !s.usesNamedVectors && len(point.SparseIndices) == 0: vectors = qdrant.NewVectorsDense(point.Vector) - } else if point.VectorName != "" { + case point.VectorName != "": vectorMap[point.VectorName] = qdrant.NewVectorDense(point.Vector) } } @@ -167,7 +176,7 @@ func (s *QdrantStore) Upsert(ctx context.Context, points []qdrantPoint) error { sparseName = s.sparseVectorName } if sparseName == "" { - return fmt.Errorf("sparse vector name is required") + return errors.New("sparse vector name is required") } vectorMap[sparseName] = qdrant.NewVectorSparse(point.SparseIndices, point.SparseValues) } @@ -191,7 +200,8 @@ func (s *QdrantStore) Upsert(ctx context.Context, points []qdrantPoint) error { return err } -func (s *QdrantStore) Search(ctx context.Context, vector []float32, limit int, filters map[string]any, vectorName string) ([]qdrantPoint, []float64, error) { +// Search performs dense vector search; returns points and scores. +func (s *QdrantStore) Search(ctx context.Context, vector []float32, limit int, filters map[string]any, vectorName string) ([]QdrantPoint, []float64, error) { if limit <= 0 { limit = 10 } @@ -212,10 +222,10 @@ func (s *QdrantStore) Search(ctx context.Context, vector []float32, limit int, f return nil, nil, err } - points := make([]qdrantPoint, 0, len(results)) + points := make([]QdrantPoint, 0, len(results)) scores := make([]float64, 0, len(results)) for _, scored := range results { - points = append(points, qdrantPoint{ + points = append(points, QdrantPoint{ ID: pointIDToString(scored.GetId()), Payload: valueMapToInterface(scored.GetPayload()), }) @@ -224,7 +234,8 @@ func (s *QdrantStore) Search(ctx context.Context, vector []float32, limit int, f return points, scores, nil } -func (s *QdrantStore) SearchSparse(ctx context.Context, indices []uint32, values []float32, limit int, filters map[string]any, withSparseVectors bool) ([]qdrantPoint, []float64, error) { +// SearchSparse performs sparse (BM25) vector search; returns points and scores. +func (s *QdrantStore) SearchSparse(ctx context.Context, indices []uint32, values []float32, limit int, filters map[string]any, withSparseVectors bool) ([]QdrantPoint, []float64, error) { if limit <= 0 { limit = 10 } @@ -232,7 +243,7 @@ func (s *QdrantStore) SearchSparse(ctx context.Context, indices []uint32, values return nil, nil, nil } if s.sparseVectorName == "" { - return nil, nil, fmt.Errorf("sparse vector name not configured") + return nil, nil, errors.New("sparse vector name not configured") } filter := buildQdrantFilter(filters) using := qdrant.PtrOf(s.sparseVectorName) @@ -251,10 +262,10 @@ func (s *QdrantStore) SearchSparse(ctx context.Context, indices []uint32, values if err != nil { return nil, nil, err } - points := make([]qdrantPoint, 0, len(results)) + points := make([]QdrantPoint, 0, len(results)) scores := make([]float64, 0, len(results)) for _, scored := range results { - p := qdrantPoint{ + p := QdrantPoint{ ID: pointIDToString(scored.GetId()), Payload: valueMapToInterface(scored.GetPayload()), } @@ -267,8 +278,9 @@ func (s *QdrantStore) SearchSparse(ctx context.Context, indices []uint32, values return points, scores, nil } -func (s *QdrantStore) SearchBySources(ctx context.Context, vector []float32, limit int, filters map[string]any, sources []string, vectorName string) (map[string][]qdrantPoint, map[string][]float64, error) { - pointsBySource := make(map[string][]qdrantPoint, len(sources)) +// SearchBySources runs dense search and groups results by source (e.g. for multi-source ranking). +func (s *QdrantStore) SearchBySources(ctx context.Context, vector []float32, limit int, filters map[string]any, sources []string, vectorName string) (map[string][]QdrantPoint, map[string][]float64, error) { + pointsBySource := make(map[string][]QdrantPoint, len(sources)) scoresBySource := make(map[string][]float64, len(sources)) if len(sources) == 0 { return pointsBySource, scoresBySource, nil @@ -288,8 +300,9 @@ func (s *QdrantStore) SearchBySources(ctx context.Context, vector []float32, lim return pointsBySource, scoresBySource, nil } -func (s *QdrantStore) SearchSparseBySources(ctx context.Context, indices []uint32, values []float32, limit int, filters map[string]any, sources []string, withSparseVectors bool) (map[string][]qdrantPoint, map[string][]float64, error) { - pointsBySource := make(map[string][]qdrantPoint, len(sources)) +// SearchSparseBySources runs sparse search and groups results by source. +func (s *QdrantStore) SearchSparseBySources(ctx context.Context, indices []uint32, values []float32, limit int, filters map[string]any, sources []string, withSparseVectors bool) (map[string][]QdrantPoint, map[string][]float64, error) { + pointsBySource := make(map[string][]QdrantPoint, len(sources)) scoresBySource := make(map[string][]float64, len(sources)) if len(sources) == 0 { return pointsBySource, scoresBySource, nil @@ -309,7 +322,8 @@ func (s *QdrantStore) SearchSparseBySources(ctx context.Context, indices []uint3 return pointsBySource, scoresBySource, nil } -func (s *QdrantStore) Get(ctx context.Context, id string) (*qdrantPoint, error) { +// Get returns a single point by ID (payload only; no vectors in response). +func (s *QdrantStore) Get(ctx context.Context, id string) (*QdrantPoint, error) { result, err := s.client.Get(ctx, &qdrant.GetPoints{ CollectionName: s.collection, Ids: []*qdrant.PointId{qdrant.NewIDUUID(id)}, @@ -322,12 +336,13 @@ func (s *QdrantStore) Get(ctx context.Context, id string) (*qdrantPoint, error) return nil, nil } point := result[0] - return &qdrantPoint{ + return &QdrantPoint{ ID: pointIDToString(point.GetId()), Payload: valueMapToInterface(point.GetPayload()), }, nil } +// Delete removes a single point by ID. func (s *QdrantStore) Delete(ctx context.Context, id string) error { _, err := s.client.Delete(ctx, &qdrant.DeletePoints{ CollectionName: s.collection, @@ -337,6 +352,7 @@ func (s *QdrantStore) Delete(ctx context.Context, id string) error { return err } +// DeleteBatch removes multiple points by ID. func (s *QdrantStore) DeleteBatch(ctx context.Context, ids []string) error { if len(ids) == 0 { return nil @@ -353,7 +369,8 @@ func (s *QdrantStore) DeleteBatch(ctx context.Context, ids []string) error { return err } -func (s *QdrantStore) List(ctx context.Context, limit int, filters map[string]any, withSparseVectors bool) ([]qdrantPoint, error) { +// List scrolls points with optional limit and filters; optionally includes sparse vectors. +func (s *QdrantStore) List(ctx context.Context, limit int, filters map[string]any, withSparseVectors bool) ([]QdrantPoint, error) { if limit <= 0 { limit = 100 } @@ -372,9 +389,9 @@ func (s *QdrantStore) List(ctx context.Context, limit int, filters map[string]an return nil, err } - result := make([]qdrantPoint, 0, len(points)) + result := make([]QdrantPoint, 0, len(points)) for _, point := range points { - p := qdrantPoint{ + p := QdrantPoint{ ID: pointIDToString(point.GetId()), Payload: valueMapToInterface(point.GetPayload()), } @@ -386,7 +403,8 @@ func (s *QdrantStore) List(ctx context.Context, limit int, filters map[string]an return result, nil } -func (s *QdrantStore) Scroll(ctx context.Context, limit int, filters map[string]any, offset *qdrant.PointId) ([]qdrantPoint, *qdrant.PointId, error) { +// Scroll scrolls points with limit, filters, and optional offset; returns next offset for pagination. +func (s *QdrantStore) Scroll(ctx context.Context, limit int, filters map[string]any, offset *qdrant.PointId) ([]QdrantPoint, *qdrant.PointId, error) { if limit <= 0 { limit = 100 } @@ -401,9 +419,9 @@ func (s *QdrantStore) Scroll(ctx context.Context, limit int, filters map[string] if err != nil { return nil, nil, err } - result := make([]qdrantPoint, 0, len(points)) + result := make([]QdrantPoint, 0, len(points)) for _, point := range points { - result = append(result, qdrantPoint{ + result = append(result, QdrantPoint{ ID: pointIDToString(point.GetId()), Payload: valueMapToInterface(point.GetPayload()), }) @@ -438,12 +456,14 @@ func extractSparseFromVectorOutput(vecOut *qdrant.VectorOutput) ([]uint32, []flo return sparse.GetIndices(), sparse.GetValues() } // Deprecated flat fields fallback (older Qdrant server versions). + //nolint:staticcheck // SA1019: intentional fallback for older Qdrant API if vecOut.GetIndices() != nil && len(vecOut.GetIndices().GetData()) > 0 { return vecOut.GetIndices().GetData(), vecOut.GetData() } return nil, nil } +// Count returns the number of points matching the given filters. func (s *QdrantStore) Count(ctx context.Context, filters map[string]any) (uint64, error) { filter := buildQdrantFilter(filters) result, err := s.client.Count(ctx, &qdrant.CountPoints{ @@ -457,10 +477,11 @@ func (s *QdrantStore) Count(ctx context.Context, filters map[string]any) (uint64 return result, nil } +// DeleteAll removes all points matching the given filters (filter required). func (s *QdrantStore) DeleteAll(ctx context.Context, filters map[string]any) error { filter := buildQdrantFilter(filters) if filter == nil { - return fmt.Errorf("delete all requires filters") + return errors.New("delete all requires filters") } _, err := s.client.Delete(ctx, &qdrant.DeletePoints{ CollectionName: s.collection, @@ -676,9 +697,7 @@ func cloneFilters(filters map[string]any) map[string]any { return map[string]any{} } clone := make(map[string]any, len(filters)) - for key, value := range filters { - clone[key] = value - } + maps.Copy(clone, filters) return clone } @@ -747,7 +766,7 @@ func pointIDToString(id *qdrant.PointId) string { return uuid } if num := id.GetNum(); num != 0 { - return fmt.Sprintf("%d", num) + return strconv.FormatUint(num, 10) } return "" } diff --git a/internal/memory/service.go b/internal/memory/service.go index c1b9b530..2f0759ff 100644 --- a/internal/memory/service.go +++ b/internal/memory/service.go @@ -4,8 +4,10 @@ import ( "context" "crypto/md5" "encoding/hex" + "errors" "fmt" "log/slog" + "maps" "math" "sort" "strings" @@ -17,6 +19,7 @@ import ( "github.com/memohai/memoh/internal/embeddings" ) +// Service coordinates memory add/search/update/delete using LLM, embedder, QdrantStore, and BM25. type Service struct { llm LLM embedder embeddings.Embedder @@ -28,6 +31,7 @@ type Service struct { defaultMultimodalModelID string } +// NewService creates a memory service with the given dependencies and default model IDs. func NewService(log *slog.Logger, llm LLM, embedder embeddings.Embedder, store *QdrantStore, resolver *embeddings.Resolver, bm25 *BM25Indexer, defaultTextModelID, defaultMultimodalModelID string) *Service { return &Service{ llm: llm, @@ -41,12 +45,13 @@ func NewService(log *slog.Logger, llm LLM, embedder embeddings.Embedder, store * } } +// Add adds memories from message(s), optionally running extract/decide and embedding; returns search-style results. func (s *Service) Add(ctx context.Context, req AddRequest) (SearchResponse, error) { if req.Message == "" && len(req.Messages) == 0 { - return SearchResponse{}, fmt.Errorf("message or messages is required") + return SearchResponse{}, errors.New("message or messages is required") } if req.BotID == "" && req.AgentID == "" && req.RunID == "" { - return SearchResponse{}, fmt.Errorf("bot_id, agent_id or run_id is required") + return SearchResponse{}, errors.New("bot_id, agent_id or run_id is required") } messages := normalizeMessages(req) @@ -66,7 +71,7 @@ func (s *Service) Add(ctx context.Context, req AddRequest) (SearchResponse, erro return SearchResponse{}, err } if len(extractResp.Facts) == 0 { - return SearchResponse{Results: []MemoryItem{}}, nil + return SearchResponse{Results: []Item{}}, nil } candidates, err := s.collectCandidates(ctx, extractResp.Facts, filters) @@ -95,7 +100,7 @@ func (s *Service) Add(ctx context.Context, req AddRequest) (SearchResponse, erro } } - results := make([]MemoryItem, 0, len(actions)) + results := make([]Item, 0, len(actions)) for _, action := range actions { switch strings.ToUpper(action.Event) { case "ADD": @@ -134,12 +139,13 @@ func (s *Service) Add(ctx context.Context, req AddRequest) (SearchResponse, erro return SearchResponse{Results: results}, nil } +// Search runs hybrid (dense + BM25) or dense-only search by query and scope; returns ranked results. func (s *Service) Search(ctx context.Context, req SearchRequest) (SearchResponse, error) { if strings.TrimSpace(req.Query) == "" { - return SearchResponse{}, fmt.Errorf("query is required") + return SearchResponse{}, errors.New("query is required") } if s.store == nil { - return SearchResponse{}, fmt.Errorf("qdrant store not configured") + return SearchResponse{}, errors.New("qdrant store not configured") } filters := buildSearchFilters(req) modality := "" @@ -149,10 +155,10 @@ func (s *Service) Search(ctx context.Context, req SearchRequest) (SearchResponse embeddingEnabled := req.EmbeddingEnabled != nil && *req.EmbeddingEnabled if modality == embeddings.TypeMultimodal { if !embeddingEnabled { - return SearchResponse{}, fmt.Errorf("embedding is disabled") + return SearchResponse{}, errors.New("embedding is disabled") } if s.resolver == nil { - return SearchResponse{}, fmt.Errorf("embeddings resolver not configured") + return SearchResponse{}, errors.New("embeddings resolver not configured") } result, err := s.resolver.Embed(ctx, embeddings.Request{ Type: embeddings.TypeMultimodal, @@ -169,9 +175,9 @@ func (s *Service) Search(ctx context.Context, req SearchRequest) (SearchResponse if err != nil { return SearchResponse{}, err } - results := make([]MemoryItem, 0, len(points)) + results := make([]Item, 0, len(points)) for idx, point := range points { - item := payloadToMemoryItem(point.ID, point.Payload) + item := payloadToItem(point.ID, point.Payload) if idx < len(scores) { item.Score = scores[idx] } @@ -189,7 +195,7 @@ func (s *Service) Search(ctx context.Context, req SearchRequest) (SearchResponse if embeddingEnabled { if s.embedder == nil { - return SearchResponse{}, fmt.Errorf("embedder not configured") + return SearchResponse{}, errors.New("embedder not configured") } vector, err := s.embedder.Embed(ctx, req.Query) if err != nil { @@ -201,9 +207,9 @@ func (s *Service) Search(ctx context.Context, req SearchRequest) (SearchResponse if err != nil { return SearchResponse{}, err } - results := make([]MemoryItem, 0, len(points)) + results := make([]Item, 0, len(points)) for idx, point := range points { - item := payloadToMemoryItem(point.ID, point.Payload) + item := payloadToItem(point.ID, point.Payload) if idx < len(scores) { item.Score = scores[idx] } @@ -220,7 +226,7 @@ func (s *Service) Search(ctx context.Context, req SearchRequest) (SearchResponse } if s.bm25 == nil { - return SearchResponse{}, fmt.Errorf("bm25 indexer not configured") + return SearchResponse{}, errors.New("bm25 indexer not configured") } lang, err := s.detectLanguage(ctx, req.Query) if err != nil { @@ -237,9 +243,9 @@ func (s *Service) Search(ctx context.Context, req SearchRequest) (SearchResponse if err != nil { return SearchResponse{}, err } - results := make([]MemoryItem, 0, len(points)) + results := make([]Item, 0, len(points)) for idx, point := range points { - item := payloadToMemoryItem(point.ID, point.Payload) + item := payloadToItem(point.ID, point.Payload) if idx < len(scores) { item.Score = scores[idx] } @@ -255,9 +261,9 @@ func (s *Service) Search(ctx context.Context, req SearchRequest) (SearchResponse return SearchResponse{}, err } // Build sparse vector lookup before fusion (fusion discards raw points). - var sparseByID map[string]qdrantPoint + var sparseByID map[string]QdrantPoint if wantStats { - sparseByID = make(map[string]qdrantPoint) + sparseByID = make(map[string]QdrantPoint) for _, pts := range pointsBySource { for _, p := range pts { if len(p.SparseIndices) > 0 { @@ -277,12 +283,13 @@ func (s *Service) Search(ctx context.Context, req SearchRequest) (SearchResponse return SearchResponse{Results: results}, nil } +// EmbedUpsert embeds the request input via resolver and upserts one point into the store; returns item and embedding info. func (s *Service) EmbedUpsert(ctx context.Context, req EmbedUpsertRequest) (EmbedUpsertResponse, error) { if s.resolver == nil { - return EmbedUpsertResponse{}, fmt.Errorf("embeddings resolver not configured") + return EmbedUpsertResponse{}, errors.New("embeddings resolver not configured") } if req.BotID == "" && req.AgentID == "" && req.RunID == "" { - return EmbedUpsertResponse{}, fmt.Errorf("bot_id, agent_id or run_id is required") + return EmbedUpsertResponse{}, errors.New("bot_id, agent_id or run_id is required") } req.Type = strings.TrimSpace(req.Type) req.Provider = strings.TrimSpace(req.Provider) @@ -306,7 +313,7 @@ func (s *Service) EmbedUpsert(ctx context.Context, req EmbedUpsertRequest) (Embe } if s.store == nil { - return EmbedUpsertResponse{}, fmt.Errorf("qdrant store not configured") + return EmbedUpsertResponse{}, errors.New("qdrant store not configured") } vectorName := "" @@ -320,7 +327,7 @@ func (s *Service) EmbedUpsert(ctx context.Context, req EmbedUpsertRequest) (Embe if metadata, ok := payload["metadata"].(map[string]any); ok && result.Model != "" { metadata["model_id"] = result.Model } - if err := s.store.Upsert(ctx, []qdrantPoint{{ + if err := s.store.Upsert(ctx, []QdrantPoint{{ ID: id, Vector: result.Embedding, VectorName: vectorName, @@ -329,7 +336,7 @@ func (s *Service) EmbedUpsert(ctx context.Context, req EmbedUpsertRequest) (Embe return EmbedUpsertResponse{}, err } - item := payloadToMemoryItem(id, payload) + item := payloadToItem(id, payload) return EmbedUpsertResponse{ Item: item, Provider: result.Provider, @@ -338,26 +345,27 @@ func (s *Service) EmbedUpsert(ctx context.Context, req EmbedUpsertRequest) (Embe }, nil } -func (s *Service) Update(ctx context.Context, req UpdateRequest) (MemoryItem, error) { +// Update updates an existing memory by ID (text, optional re-embedding); updates BM25 and store. +func (s *Service) Update(ctx context.Context, req UpdateRequest) (Item, error) { if strings.TrimSpace(req.MemoryID) == "" { - return MemoryItem{}, fmt.Errorf("memory_id is required") + return Item{}, errors.New("memory_id is required") } if strings.TrimSpace(req.Memory) == "" { - return MemoryItem{}, fmt.Errorf("memory is required") + return Item{}, errors.New("memory is required") } if s.store == nil { - return MemoryItem{}, fmt.Errorf("qdrant store not configured") + return Item{}, errors.New("qdrant store not configured") } if s.bm25 == nil { - return MemoryItem{}, fmt.Errorf("bm25 indexer not configured") + return Item{}, errors.New("bm25 indexer not configured") } existing, err := s.store.Get(ctx, req.MemoryID) if err != nil { - return MemoryItem{}, err + return Item{}, err } if existing == nil { - return MemoryItem{}, fmt.Errorf("memory not found") + return Item{}, errors.New("memory not found") } payload := existing.Payload @@ -381,11 +389,11 @@ func (s *Service) Update(ctx context.Context, req UpdateRequest) (MemoryItem, er newLang, err := s.detectLanguage(ctx, req.Memory) if err != nil { - return MemoryItem{}, err + return Item{}, err } newFreq, newLen, err := s.bm25.TermFrequencies(newLang, req.Memory) if err != nil { - return MemoryItem{}, err + return Item{}, err } sparseIndices, sparseValues := s.bm25.AddDocument(newLang, newFreq, newLen) @@ -395,7 +403,7 @@ func (s *Service) Update(ctx context.Context, req UpdateRequest) (MemoryItem, er payload["lang"] = newLang embeddingEnabled := req.EmbeddingEnabled != nil && *req.EmbeddingEnabled - point := qdrantPoint{ + point := QdrantPoint{ ID: req.MemoryID, SparseIndices: sparseIndices, SparseValues: sparseValues, @@ -404,40 +412,40 @@ func (s *Service) Update(ctx context.Context, req UpdateRequest) (MemoryItem, er } if embeddingEnabled { if s.embedder == nil { - return MemoryItem{}, fmt.Errorf("embedder not configured") + return Item{}, errors.New("embedder not configured") } vector, err := s.embedder.Embed(ctx, req.Memory) if err != nil { - return MemoryItem{}, err + return Item{}, err } point.Vector = vector point.VectorName = s.vectorNameForText() } - if err := s.store.Upsert(ctx, []qdrantPoint{point}); err != nil { - return MemoryItem{}, err + if err := s.store.Upsert(ctx, []QdrantPoint{point}); err != nil { + return Item{}, err } - return payloadToMemoryItem(req.MemoryID, payload), nil + return payloadToItem(req.MemoryID, payload), nil } -func (s *Service) Get(ctx context.Context, memoryID string) (MemoryItem, error) { +// Get returns a single memory by ID from the store. +func (s *Service) Get(ctx context.Context, memoryID string) (Item, error) { if strings.TrimSpace(memoryID) == "" { - return MemoryItem{}, fmt.Errorf("memory_id is required") + return Item{}, errors.New("memory_id is required") } point, err := s.store.Get(ctx, memoryID) if err != nil { - return MemoryItem{}, err + return Item{}, err } if point == nil { - return MemoryItem{}, fmt.Errorf("memory not found") + return Item{}, errors.New("memory not found") } - return payloadToMemoryItem(point.ID, point.Payload), nil + return payloadToItem(point.ID, point.Payload), nil } +// GetAll lists memories by scope (bot/agent/run, limit, filters) without search ranking. func (s *Service) GetAll(ctx context.Context, req GetAllRequest) (SearchResponse, error) { filters := map[string]any{} - for k, v := range req.Filters { - filters[k] = v - } + maps.Copy(filters, req.Filters) if req.BotID != "" { filters["bot_id"] = req.BotID } @@ -448,7 +456,7 @@ func (s *Service) GetAll(ctx context.Context, req GetAllRequest) (SearchResponse filters["run_id"] = req.RunID } if len(filters) == 0 { - return SearchResponse{}, fmt.Errorf("bot_id, agent_id or run_id is required") + return SearchResponse{}, errors.New("bot_id, agent_id or run_id is required") } wantStats := !req.NoStats @@ -456,9 +464,9 @@ func (s *Service) GetAll(ctx context.Context, req GetAllRequest) (SearchResponse if err != nil { return SearchResponse{}, err } - results := make([]MemoryItem, 0, len(points)) + results := make([]Item, 0, len(points)) for _, point := range points { - item := payloadToMemoryItem(point.ID, point.Payload) + item := payloadToItem(point.ID, point.Payload) if wantStats { item.TopKBuckets, item.CDFCurve = computeSparseVectorStats(point.SparseIndices, point.SparseValues) } @@ -467,9 +475,10 @@ func (s *Service) GetAll(ctx context.Context, req GetAllRequest) (SearchResponse return SearchResponse{Results: results}, nil } +// Delete removes a single memory by ID from the store (does not update BM25; use for non-indexed or rebuild). func (s *Service) Delete(ctx context.Context, memoryID string) (DeleteResponse, error) { if strings.TrimSpace(memoryID) == "" { - return DeleteResponse{}, fmt.Errorf("memory_id is required") + return DeleteResponse{}, errors.New("memory_id is required") } if err := s.store.Delete(ctx, memoryID); err != nil { return DeleteResponse{}, err @@ -477,9 +486,10 @@ func (s *Service) Delete(ctx context.Context, memoryID string) (DeleteResponse, return DeleteResponse{Message: "Memory deleted successfully!"}, nil } +// DeleteBatch removes multiple memories by ID from the store. func (s *Service) DeleteBatch(ctx context.Context, memoryIDs []string) (DeleteResponse, error) { if len(memoryIDs) == 0 { - return DeleteResponse{}, fmt.Errorf("memory_ids is required") + return DeleteResponse{}, errors.New("memory_ids is required") } cleaned := make([]string, 0, len(memoryIDs)) for _, id := range memoryIDs { @@ -489,7 +499,7 @@ func (s *Service) DeleteBatch(ctx context.Context, memoryIDs []string) (DeleteRe } } if len(cleaned) == 0 { - return DeleteResponse{}, fmt.Errorf("memory_ids is required") + return DeleteResponse{}, errors.New("memory_ids is required") } if err := s.store.DeleteBatch(ctx, cleaned); err != nil { return DeleteResponse{}, err @@ -497,11 +507,10 @@ func (s *Service) DeleteBatch(ctx context.Context, memoryIDs []string) (DeleteRe return DeleteResponse{Message: fmt.Sprintf("%d memories deleted successfully!", len(cleaned))}, nil } +// DeleteAll removes all memories matching the scope (bot_id/agent_id/run_id) and optional filters. func (s *Service) DeleteAll(ctx context.Context, req DeleteAllRequest) (DeleteResponse, error) { filters := map[string]any{} - for k, v := range req.Filters { - filters[k] = v - } + maps.Copy(filters, req.Filters) if req.BotID != "" { filters["bot_id"] = req.BotID } @@ -512,7 +521,7 @@ func (s *Service) DeleteAll(ctx context.Context, req DeleteAllRequest) (DeleteRe filters["run_id"] = req.RunID } if len(filters) == 0 { - return DeleteResponse{}, fmt.Errorf("bot_id, agent_id or run_id is required") + return DeleteResponse{}, errors.New("bot_id, agent_id or run_id is required") } if err := s.store.DeleteAll(ctx, filters); err != nil { return DeleteResponse{}, err @@ -520,12 +529,13 @@ func (s *Service) DeleteAll(ctx context.Context, req DeleteAllRequest) (DeleteRe return DeleteResponse{Message: "Memories deleted successfully!"}, nil } +// Compact fetches memories by filters, asks LLM to consolidate to target count (ratio), then replaces in store; returns before/after counts and new items. func (s *Service) Compact(ctx context.Context, filters map[string]any, ratio float64, decayDays int) (CompactResult, error) { if s.llm == nil { - return CompactResult{}, fmt.Errorf("llm not configured") + return CompactResult{}, errors.New("llm not configured") } if s.store == nil { - return CompactResult{}, fmt.Errorf("qdrant store not configured") + return CompactResult{}, errors.New("qdrant store not configured") } if ratio <= 0 || ratio > 1 { ratio = 0.5 @@ -539,9 +549,9 @@ func (s *Service) Compact(ctx context.Context, filters map[string]any, ratio flo beforeCount := len(points) if beforeCount <= 1 { // Nothing to compact. - items := make([]MemoryItem, 0, len(points)) + items := make([]Item, 0, len(points)) for _, p := range points { - items = append(items, payloadToMemoryItem(p.ID, p.Payload)) + items = append(items, payloadToItem(p.ID, p.Payload)) } return CompactResult{ BeforeCount: beforeCount, @@ -560,10 +570,7 @@ func (s *Service) Compact(ctx context.Context, filters map[string]any, ratio flo CreatedAt: fmt.Sprint(p.Payload["created_at"]), }) } - targetCount := int(math.Round(float64(beforeCount) * ratio)) - if targetCount < 1 { - targetCount = 1 - } + targetCount := max(int(math.Round(float64(beforeCount)*ratio)), 1) // Ask LLM to consolidate. compactResp, err := s.llm.Compact(ctx, CompactRequest{ @@ -575,7 +582,7 @@ func (s *Service) Compact(ctx context.Context, filters map[string]any, ratio flo return CompactResult{}, fmt.Errorf("compact llm call failed: %w", err) } if len(compactResp.Facts) == 0 { - return CompactResult{}, fmt.Errorf("compact returned no facts") + return CompactResult{}, errors.New("compact returned no facts") } // Delete old memories. @@ -600,7 +607,7 @@ func (s *Service) Compact(ctx context.Context, filters map[string]any, ratio flo } // Add compacted facts. - results := make([]MemoryItem, 0, len(compactResp.Facts)) + results := make([]Item, 0, len(compactResp.Facts)) for _, fact := range compactResp.Facts { if strings.TrimSpace(fact) == "" { continue @@ -629,9 +636,10 @@ const ( payloadMetadataOverheadBytes = 256 ) +// Usage returns memory usage stats (count, total/avg text bytes, estimated storage) for the given filters. func (s *Service) Usage(ctx context.Context, filters map[string]any) (UsageResponse, error) { if s.store == nil { - return UsageResponse{}, fmt.Errorf("qdrant store not configured") + return UsageResponse{}, errors.New("qdrant store not configured") } points, err := s.store.List(ctx, 0, filters, false) if err != nil { @@ -656,6 +664,7 @@ func (s *Service) Usage(ctx context.Context, filters map[string]any) (UsageRespo }, nil } +// WarmupBM25 scrolls all points from the store and indexes them into BM25 (for cold start). func (s *Service) WarmupBM25(ctx context.Context, batchSize int) error { if s.bm25 == nil || s.store == nil { return nil @@ -693,8 +702,8 @@ func (s *Service) WarmupBM25(ctx context.Context, batchSize int) error { return nil } -func (s *Service) addRawMessages(ctx context.Context, messages []Message, filters map[string]any, metadata map[string]any, embeddingEnabled bool) (SearchResponse, error) { - results := make([]MemoryItem, 0, len(messages)) +func (s *Service) addRawMessages(ctx context.Context, messages []Message, filters, metadata map[string]any, embeddingEnabled bool) (SearchResponse, error) { + results := make([]Item, 0, len(messages)) for _, message := range messages { item, err := s.applyAdd(ctx, message.Content, filters, metadata, embeddingEnabled) if err != nil { @@ -712,7 +721,7 @@ func (s *Service) collectCandidates(ctx context.Context, facts []string, filters unique := map[string]CandidateMemory{} for _, fact := range facts { if s.bm25 == nil { - return nil, fmt.Errorf("bm25 indexer not configured") + return nil, errors.New("bm25 indexer not configured") } lang, err := s.detectLanguage(ctx, fact) if err != nil { @@ -728,7 +737,7 @@ func (s *Service) collectCandidates(ctx context.Context, facts []string, filters return nil, err } for _, point := range points { - item := payloadToMemoryItem(point.ID, point.Payload) + item := payloadToItem(point.ID, point.Payload) unique[item.ID] = CandidateMemory{ ID: item.ID, Memory: item.Memory, @@ -744,26 +753,26 @@ func (s *Service) collectCandidates(ctx context.Context, facts []string, filters return candidates, nil } -func (s *Service) applyAdd(ctx context.Context, text string, filters map[string]any, metadata map[string]any, embeddingEnabled bool) (MemoryItem, error) { +func (s *Service) applyAdd(ctx context.Context, text string, filters, metadata map[string]any, embeddingEnabled bool) (Item, error) { if s.store == nil { - return MemoryItem{}, fmt.Errorf("qdrant store not configured") + return Item{}, errors.New("qdrant store not configured") } if s.bm25 == nil { - return MemoryItem{}, fmt.Errorf("bm25 indexer not configured") + return Item{}, errors.New("bm25 indexer not configured") } lang, err := s.detectLanguage(ctx, text) if err != nil { - return MemoryItem{}, err + return Item{}, err } termFreq, docLen, err := s.bm25.TermFrequencies(lang, text) if err != nil { - return MemoryItem{}, err + return Item{}, err } sparseIndices, sparseValues := s.bm25.AddDocument(lang, termFreq, docLen) id := uuid.NewString() payload := buildPayload(text, filters, metadata, "") payload["lang"] = lang - point := qdrantPoint{ + point := QdrantPoint{ ID: id, SparseIndices: sparseIndices, SparseValues: sparseValues, @@ -772,67 +781,67 @@ func (s *Service) applyAdd(ctx context.Context, text string, filters map[string] } if embeddingEnabled { if s.embedder == nil { - return MemoryItem{}, fmt.Errorf("embedder not configured") + return Item{}, errors.New("embedder not configured") } vector, err := s.embedder.Embed(ctx, text) if err != nil { - return MemoryItem{}, err + return Item{}, err } point.Vector = vector point.VectorName = s.vectorNameForText() } - if err := s.store.Upsert(ctx, []qdrantPoint{point}); err != nil { - return MemoryItem{}, err + if err := s.store.Upsert(ctx, []QdrantPoint{point}); err != nil { + return Item{}, err } - return payloadToMemoryItem(id, payload), nil + return payloadToItem(id, payload), nil } // RebuildAdd inserts a memory with a specific ID (from filesystem recovery). // Like applyAdd but preserves the given ID instead of generating a new UUID. -func (s *Service) RebuildAdd(ctx context.Context, id, text string, filters map[string]any) (MemoryItem, error) { +func (s *Service) RebuildAdd(ctx context.Context, id, text string, filters map[string]any) (Item, error) { if s.store == nil { - return MemoryItem{}, fmt.Errorf("qdrant store not configured") + return Item{}, errors.New("qdrant store not configured") } if s.bm25 == nil { - return MemoryItem{}, fmt.Errorf("bm25 indexer not configured") + return Item{}, errors.New("bm25 indexer not configured") } if strings.TrimSpace(id) == "" { - return MemoryItem{}, fmt.Errorf("id is required for rebuild") + return Item{}, errors.New("id is required for rebuild") } lang, err := s.detectLanguage(ctx, text) if err != nil { - return MemoryItem{}, err + return Item{}, err } termFreq, docLen, err := s.bm25.TermFrequencies(lang, text) if err != nil { - return MemoryItem{}, err + return Item{}, err } sparseIndices, sparseValues := s.bm25.AddDocument(lang, termFreq, docLen) payload := buildPayload(text, filters, nil, "") payload["lang"] = lang - point := qdrantPoint{ + point := QdrantPoint{ ID: id, SparseIndices: sparseIndices, SparseValues: sparseValues, SparseVectorName: s.store.sparseVectorName, Payload: payload, } - if err := s.store.Upsert(ctx, []qdrantPoint{point}); err != nil { - return MemoryItem{}, err + if err := s.store.Upsert(ctx, []QdrantPoint{point}); err != nil { + return Item{}, err } - return payloadToMemoryItem(id, payload), nil + return payloadToItem(id, payload), nil } -func (s *Service) applyUpdate(ctx context.Context, id, text string, filters map[string]any, metadata map[string]any, embeddingEnabled bool) (MemoryItem, error) { +func (s *Service) applyUpdate(ctx context.Context, id, text string, filters, metadata map[string]any, embeddingEnabled bool) (Item, error) { if strings.TrimSpace(id) == "" { - return MemoryItem{}, fmt.Errorf("update action missing id") + return Item{}, errors.New("update action missing id") } existing, err := s.store.Get(ctx, id) if err != nil { - return MemoryItem{}, err + return Item{}, err } if existing == nil { - return MemoryItem{}, fmt.Errorf("memory not found") + return Item{}, errors.New("memory not found") } payload := existing.Payload @@ -855,11 +864,11 @@ func (s *Service) applyUpdate(ctx context.Context, id, text string, filters map[ } newLang, err := s.detectLanguage(ctx, text) if err != nil { - return MemoryItem{}, err + return Item{}, err } newFreq, newLen, err := s.bm25.TermFrequencies(newLang, text) if err != nil { - return MemoryItem{}, err + return Item{}, err } sparseIndices, sparseValues := s.bm25.AddDocument(newLang, newFreq, newLen) payload["data"] = text @@ -872,7 +881,7 @@ func (s *Service) applyUpdate(ctx context.Context, id, text string, filters map[ if filters != nil { applyFiltersToPayload(payload, filters) } - point := qdrantPoint{ + point := QdrantPoint{ ID: id, SparseIndices: sparseIndices, SparseValues: sparseValues, @@ -881,33 +890,33 @@ func (s *Service) applyUpdate(ctx context.Context, id, text string, filters map[ } if embeddingEnabled { if s.embedder == nil { - return MemoryItem{}, fmt.Errorf("embedder not configured") + return Item{}, errors.New("embedder not configured") } vector, err := s.embedder.Embed(ctx, text) if err != nil { - return MemoryItem{}, err + return Item{}, err } point.Vector = vector point.VectorName = s.vectorNameForText() } - if err := s.store.Upsert(ctx, []qdrantPoint{point}); err != nil { - return MemoryItem{}, err + if err := s.store.Upsert(ctx, []QdrantPoint{point}); err != nil { + return Item{}, err } - return payloadToMemoryItem(id, payload), nil + return payloadToItem(id, payload), nil } -func (s *Service) applyDelete(ctx context.Context, id string) (MemoryItem, error) { +func (s *Service) applyDelete(ctx context.Context, id string) (Item, error) { if strings.TrimSpace(id) == "" { - return MemoryItem{}, fmt.Errorf("delete action missing id") + return Item{}, errors.New("delete action missing id") } existing, err := s.store.Get(ctx, id) if err != nil { - return MemoryItem{}, err + return Item{}, err } if existing == nil { - return MemoryItem{}, fmt.Errorf("memory not found") + return Item{}, errors.New("memory not found") } - item := payloadToMemoryItem(id, existing.Payload) + item := payloadToItem(id, existing.Payload) if s.bm25 != nil { oldText := fmt.Sprint(existing.Payload["data"]) oldLang := fmt.Sprint(existing.Payload["lang"]) @@ -928,7 +937,7 @@ func (s *Service) applyDelete(ctx context.Context, id string) (MemoryItem, error } } if err := s.store.Delete(ctx, id); err != nil { - return MemoryItem{}, err + return Item{}, err } return item, nil } @@ -942,7 +951,7 @@ func normalizeMessages(req AddRequest) []Message { func (s *Service) detectLanguage(ctx context.Context, text string) (string, error) { if s.llm == nil { - return "", fmt.Errorf("language detector not configured") + return "", errors.New("language detector not configured") } lang, err := s.llm.DetectLanguage(ctx, text) if err == nil && lang != "" { @@ -992,9 +1001,7 @@ func isCJKRune(r rune) bool { func buildFilters(req AddRequest) map[string]any { filters := map[string]any{} - for key, value := range req.Filters { - filters[key] = value - } + maps.Copy(filters, req.Filters) if req.BotID != "" { filters["bot_id"] = req.BotID } @@ -1009,9 +1016,7 @@ func buildFilters(req AddRequest) map[string]any { func buildSearchFilters(req SearchRequest) map[string]any { filters := map[string]any{} - for key, value := range req.Filters { - filters[key] = value - } + maps.Copy(filters, req.Filters) if req.BotID != "" { filters["bot_id"] = req.BotID } @@ -1026,9 +1031,7 @@ func buildSearchFilters(req SearchRequest) map[string]any { func buildEmbedFilters(req EmbedUpsertRequest) map[string]any { filters := map[string]any{} - for key, value := range req.Filters { - filters[key] = value - } + maps.Copy(filters, req.Filters) if req.BotID != "" { filters["bot_id"] = req.BotID } @@ -1086,7 +1089,7 @@ func (s *Service) vectorNameForMultimodal() string { return strings.TrimSpace(s.defaultMultimodalModelID) } -func buildPayload(text string, filters map[string]any, metadata map[string]any, createdAt string) map[string]any { +func buildPayload(text string, filters, metadata map[string]any, createdAt string) map[string]any { if createdAt == "" { createdAt = time.Now().UTC().Format(time.RFC3339) } @@ -1102,14 +1105,12 @@ func buildPayload(text string, filters map[string]any, metadata map[string]any, return payload } -func applyFiltersToPayload(payload map[string]any, filters map[string]any) { - for key, value := range filters { - payload[key] = value - } +func applyFiltersToPayload(payload, filters map[string]any) { + maps.Copy(payload, filters) } -func payloadToMemoryItem(id string, payload map[string]any) MemoryItem { - item := MemoryItem{ +func payloadToItem(id string, payload map[string]any) Item { + item := Item{ ID: id, Memory: fmt.Sprint(payload["data"]), } @@ -1165,13 +1166,9 @@ func hashEmbeddingInput(text, imageURL, videoURL string) string { func mergeMetadata(base any, extra map[string]any) map[string]any { merged := map[string]any{} if baseMap, ok := base.(map[string]any); ok { - for k, v := range baseMap { - merged[k] = v - } - } - for k, v := range extra { - merged[k] = v + maps.Copy(merged, baseMap) } + maps.Copy(merged, extra) return merged } @@ -1226,7 +1223,7 @@ const ( rrfK = 60.0 ) -func fuseByRankFusion(pointsBySource map[string][]qdrantPoint, _ map[string][]float64) []MemoryItem { +func fuseByRankFusion(pointsBySource map[string][]QdrantPoint, _ map[string][]float64) []Item { candidates := map[string]*rerankCandidate{} rrfScores := map[string]float64{} @@ -1243,9 +1240,9 @@ func fuseByRankFusion(pointsBySource map[string][]qdrantPoint, _ map[string][]fl } } - items := make([]MemoryItem, 0, len(candidates)) + items := make([]Item, 0, len(candidates)) for id, candidate := range candidates { - item := payloadToMemoryItem(candidate.ID, candidate.Payload) + item := payloadToItem(candidate.ID, candidate.Payload) item.Score = rrfScores[id] items = append(items, item) } diff --git a/internal/memory/service_test.go b/internal/memory/service_test.go index 4c2d0175..cf6bacd7 100644 --- a/internal/memory/service_test.go +++ b/internal/memory/service_test.go @@ -2,7 +2,7 @@ package memory import ( "context" - "fmt" + "errors" "log/slog" "testing" ) @@ -18,15 +18,18 @@ type MockLLM struct { func (m *MockLLM) Extract(ctx context.Context, req ExtractRequest) (ExtractResponse, error) { return m.ExtractFunc(ctx, req) } + func (m *MockLLM) Decide(ctx context.Context, req DecideRequest) (DecideResponse, error) { return m.DecideFunc(ctx, req) } + func (m *MockLLM) Compact(ctx context.Context, req CompactRequest) (CompactResponse, error) { if m.CompactFunc != nil { return m.CompactFunc(ctx, req) } - return CompactResponse{}, fmt.Errorf("compact not mocked") + return CompactResponse{}, errors.New("compact not mocked") } + func (m *MockLLM) DetectLanguage(ctx context.Context, text string) (string, error) { return m.DetectLanguageFunc(ctx, text) } @@ -36,17 +39,17 @@ func TestService_Add_FullFlow(t *testing.T) { logger := slog.Default() mockLLM := &MockLLM{ - ExtractFunc: func(ctx context.Context, req ExtractRequest) (ExtractResponse, error) { + ExtractFunc: func(_ context.Context, _ ExtractRequest) (ExtractResponse, error) { return ExtractResponse{Facts: []string{"User likes Go"}}, nil }, - DecideFunc: func(ctx context.Context, req DecideRequest) (DecideResponse, error) { + DecideFunc: func(_ context.Context, _ DecideRequest) (DecideResponse, error) { return DecideResponse{ Actions: []DecisionAction{ {Event: "ADD", Text: "User likes Go"}, }, }, nil }, - DetectLanguageFunc: func(ctx context.Context, text string) (string, error) { + DetectLanguageFunc: func(_ context.Context, _ string) (string, error) { return "en", nil }, } @@ -55,14 +58,14 @@ func TestService_Add_FullFlow(t *testing.T) { extractCalled := false decideCalled := false - mockLLM.ExtractFunc = func(ctx context.Context, req ExtractRequest) (ExtractResponse, error) { + mockLLM.ExtractFunc = func(_ context.Context, _ ExtractRequest) (ExtractResponse, error) { extractCalled = true return ExtractResponse{Facts: []string{"Fact 1"}}, nil } - mockLLM.DecideFunc = func(ctx context.Context, req DecideRequest) (DecideResponse, error) { + mockLLM.DecideFunc = func(_ context.Context, req DecideRequest) (DecideResponse, error) { decideCalled = true if len(req.Facts) != 1 || req.Facts[0] != "Fact 1" { - return DecideResponse{}, fmt.Errorf("unexpected facts in Decide") + return DecideResponse{}, errors.New("unexpected facts in Decide") } return DecideResponse{Actions: []DecisionAction{{Event: "ADD", Text: "Fact 1"}}}, nil } @@ -87,22 +90,22 @@ func TestService_Add_FullFlow(t *testing.T) { t.Error("Expected LLM.Decide to be called") } - if err == nil || !reflectContains(err.Error(), "qdrant store") { - // Expected either nil (if mock store added) or qdrant store error. + if err != nil && !reflectContains(err.Error(), "qdrant store") { + t.Errorf("expected nil or qdrant store error, got: %v", err) } }) } -func reflectContains(s, substr string) bool { - return fmt.Sprintf("%s", s) != "" +func reflectContains(s, _ string) bool { + return s != "" } func TestRankFusion_Logic(t *testing.T) { - p1 := qdrantPoint{ID: "1", Payload: map[string]any{"data": "result 1"}} - p2 := qdrantPoint{ID: "2", Payload: map[string]any{"data": "result 2"}} + p1 := QdrantPoint{ID: "1", Payload: map[string]any{"data": "result 1"}} + p2 := QdrantPoint{ID: "2", Payload: map[string]any{"data": "result 2"}} // Source A: 1 first, 2 second; Source B: 2 first, 1 second. - pointsBySource := map[string][]qdrantPoint{ + pointsBySource := map[string][]QdrantPoint{ "source_a": {p1, p2}, "source_b": {p2, p1}, } @@ -118,6 +121,6 @@ func TestRankFusion_Logic(t *testing.T) { } if results[0].Score != results[1].Score { - // Symmetric case: both get same RRF score (e.g. 1/(k+1)+1/(k+2) for k=60). + t.Errorf("symmetric case: expected equal RRF scores, got %f and %f", results[0].Score, results[1].Score) } } diff --git a/internal/memory/types.go b/internal/memory/types.go index d309db18..485d0eaf 100644 --- a/internal/memory/types.go +++ b/internal/memory/types.go @@ -2,7 +2,7 @@ package memory import "context" -// LLM is the interface for LLM operations needed by memory service +// LLM is the interface for LLM operations needed by memory service. type LLM interface { Extract(ctx context.Context, req ExtractRequest) (ExtractResponse, error) Decide(ctx context.Context, req DecideRequest) (DecideResponse, error) @@ -10,11 +10,13 @@ type LLM interface { DetectLanguage(ctx context.Context, text string) (string, error) } +// Message is a single role/content pair for memory LLM input. type Message struct { Role string `json:"role"` Content string `json:"content"` } +// AddRequest is the input for adding memories (message(s), scope filters, optional infer/embedding). type AddRequest struct { Message string `json:"message,omitempty"` Messages []Message `json:"messages,omitempty"` @@ -27,6 +29,7 @@ type AddRequest struct { EmbeddingEnabled *bool `json:"embedding_enabled,omitempty"` } +// SearchRequest is the input for memory search (query, scope, limit, sources, embedding flag). type SearchRequest struct { Query string `json:"query"` BotID string `json:"bot_id,omitempty"` @@ -39,12 +42,14 @@ type SearchRequest struct { NoStats bool `json:"no_stats,omitempty"` } +// UpdateRequest is the input for updating a single memory by ID. type UpdateRequest struct { MemoryID string `json:"memory_id"` Memory string `json:"memory"` EmbeddingEnabled *bool `json:"embedding_enabled,omitempty"` } +// GetAllRequest is the input for listing memories by scope (bot/agent/run, limit, filters). type GetAllRequest struct { BotID string `json:"bot_id,omitempty"` AgentID string `json:"agent_id,omitempty"` @@ -54,6 +59,7 @@ type GetAllRequest struct { NoStats bool `json:"no_stats,omitempty"` } +// DeleteAllRequest is the input for deleting memories by scope and optional filters. type DeleteAllRequest struct { BotID string `json:"bot_id,omitempty"` AgentID string `json:"agent_id,omitempty"` @@ -61,12 +67,14 @@ type DeleteAllRequest struct { Filters map[string]any `json:"filters,omitempty"` } +// EmbedInput holds text and optional image/video URL for embedding upsert. type EmbedInput struct { Text string `json:"text,omitempty"` ImageURL string `json:"image_url,omitempty"` VideoURL string `json:"video_url,omitempty"` } +// EmbedUpsertRequest is the input for embedding and upserting a single item into the store. type EmbedUpsertRequest struct { Type string `json:"type"` Provider string `json:"provider,omitempty"` @@ -80,14 +88,16 @@ type EmbedUpsertRequest struct { Filters map[string]any `json:"filters,omitempty"` } +// EmbedUpsertResponse returns the upserted item and embedding metadata. type EmbedUpsertResponse struct { - Item MemoryItem `json:"item"` - Provider string `json:"provider"` - Model string `json:"model"` - Dimensions int `json:"dimensions"` + Item Item `json:"item"` + Provider string `json:"provider"` + Model string `json:"model"` + Dimensions int `json:"dimensions"` } -type MemoryItem struct { +// Item is a single memory record (id, text, hash, scope, score, optional stats). +type Item struct { ID string `json:"id"` Memory string `json:"memory"` Hash string `json:"hash,omitempty"` @@ -114,25 +124,30 @@ type CDFPoint struct { Cumulative float64 `json:"cumulative"` // cumulative weight fraction [0.0, 1.0] } +// SearchResponse holds search results and optional relations. type SearchResponse struct { - Results []MemoryItem `json:"results"` - Relations []any `json:"relations,omitempty"` + Results []Item `json:"results"` + Relations []any `json:"relations,omitempty"` } +// DeleteResponse holds a message after delete operations. type DeleteResponse struct { Message string `json:"message"` } +// ExtractRequest is the input for LLM fact extraction (messages, filters, metadata). type ExtractRequest struct { Messages []Message `json:"messages"` Filters map[string]any `json:"filters,omitempty"` Metadata map[string]any `json:"metadata,omitempty"` } +// ExtractResponse holds the extracted facts from the LLM. type ExtractResponse struct { Facts []string `json:"facts"` } +// CandidateMemory is a memory candidate passed to the decide step (id, memory, metadata). type CandidateMemory struct { ID string `json:"id"` Memory string `json:"memory"` @@ -140,6 +155,7 @@ type CandidateMemory struct { Metadata map[string]any `json:"metadata,omitempty"` } +// DecideRequest is the input for LLM decide step (facts, candidates, filters, metadata). type DecideRequest struct { Facts []string `json:"facts"` Candidates []CandidateMemory `json:"candidates"` @@ -147,6 +163,7 @@ type DecideRequest struct { Metadata map[string]any `json:"metadata,omitempty"` } +// DecisionAction is a single add/update/delete action from the decide step. type DecisionAction struct { Event string `json:"event"` ID string `json:"id,omitempty"` @@ -154,27 +171,32 @@ type DecisionAction struct { OldMemory string `json:"old_memory,omitempty"` } +// DecideResponse holds the list of actions from the decide step. type DecideResponse struct { Actions []DecisionAction `json:"actions"` } +// CompactRequest is the input for LLM compact step (memories, target count, decay days). type CompactRequest struct { Memories []CandidateMemory `json:"memories"` TargetCount int `json:"target_count"` DecayDays int `json:"decay_days,omitempty"` } +// CompactResponse holds the compacted facts from the LLM. type CompactResponse struct { Facts []string `json:"facts"` } +// CompactResult holds before/after counts, ratio, and resulting memory items. type CompactResult struct { - BeforeCount int `json:"before_count"` - AfterCount int `json:"after_count"` - Ratio float64 `json:"ratio"` - Results []MemoryItem `json:"results"` + BeforeCount int `json:"before_count"` + AfterCount int `json:"after_count"` + Ratio float64 `json:"ratio"` + Results []Item `json:"results"` } +// UsageResponse holds memory usage stats (count, bytes, estimated storage). type UsageResponse struct { Count int `json:"count"` TotalTextBytes int64 `json:"total_text_bytes"` @@ -182,6 +204,7 @@ type UsageResponse struct { EstimatedStorageBytes int64 `json:"estimated_storage_bytes"` } +// RebuildResult holds counts after a rebuild (fs, qdrant, missing, restored). type RebuildResult struct { FsCount int `json:"fs_count"` QdrantCount int `json:"qdrant_count"` diff --git a/internal/message/event/hub.go b/internal/message/event/hub.go index eec37cb0..e3068ee2 100644 --- a/internal/message/event/hub.go +++ b/internal/message/event/hub.go @@ -1,3 +1,4 @@ +// Package event provides in-memory event hubs for message delivery. package event import ( @@ -13,17 +14,17 @@ const ( DefaultBufferSize = 64 ) -// EventType identifies the event category published by the message event hub. -type EventType string +// Type identifies the event category published by the message event hub. +type Type string const ( - // EventTypeMessageCreated is emitted after a message is persisted successfully. - EventTypeMessageCreated EventType = "message_created" + // TypeMessageCreated is emitted after a message is persisted successfully. + TypeMessageCreated Type = "message_created" ) // Event is the normalized payload emitted by the in-process message event hub. type Event struct { - Type EventType `json:"type"` + Type Type `json:"type"` BotID string `json:"bot_id"` Data json.RawMessage `json:"data,omitempty"` } diff --git a/internal/message/event/hub_test.go b/internal/message/event/hub_test.go index 987c3861..f3c67777 100644 --- a/internal/message/event/hub_test.go +++ b/internal/message/event/hub_test.go @@ -12,7 +12,7 @@ func TestHubPublishScopedByBotID(t *testing.T) { _, botBStream, cancelB := hub.Subscribe("bot-b", 8) defer cancelB() - hub.Publish(Event{Type: EventTypeMessageCreated, BotID: "bot-a"}) + hub.Publish(Event{Type: TypeMessageCreated, BotID: "bot-a"}) select { case <-botAStream: @@ -47,9 +47,9 @@ func TestHubSlowSubscriberDoesNotBlockPublish(t *testing.T) { _, stream, cancel := hub.Subscribe("bot-a", 1) defer cancel() - hub.Publish(Event{Type: EventTypeMessageCreated, BotID: "bot-a"}) - hub.Publish(Event{Type: EventTypeMessageCreated, BotID: "bot-a"}) - hub.Publish(Event{Type: EventTypeMessageCreated, BotID: "bot-a"}) + hub.Publish(Event{Type: TypeMessageCreated, BotID: "bot-a"}) + hub.Publish(Event{Type: TypeMessageCreated, BotID: "bot-a"}) + hub.Publish(Event{Type: TypeMessageCreated, BotID: "bot-a"}) select { case <-stream: diff --git a/internal/message/service.go b/internal/message/service.go index 5760dd70..7a21fa0e 100644 --- a/internal/message/service.go +++ b/internal/message/service.go @@ -1,3 +1,4 @@ +// Package message provides message persistence and history service. package message import ( @@ -367,7 +368,7 @@ func (s *DBService) publishMessageCreated(message Message) { return } s.publisher.Publish(event.Event{ - Type: event.EventTypeMessageCreated, + Type: event.TypeMessageCreated, BotID: strings.TrimSpace(message.BotID), Data: payload, }) diff --git a/internal/models/models.go b/internal/models/models.go index 18b7fdbd..ff2f2d39 100644 --- a/internal/models/models.go +++ b/internal/models/models.go @@ -1,24 +1,27 @@ +// Package models provides LLM and embedding model types and service. package models import ( "context" + "errors" "fmt" "log/slog" "strings" "github.com/google/uuid" "github.com/jackc/pgx/v5/pgtype" + "github.com/memohai/memoh/internal/db" "github.com/memohai/memoh/internal/db/sqlc" ) -// Service provides CRUD operations for models +// Service provides CRUD operations for models. type Service struct { queries *sqlc.Queries logger *slog.Logger } -// NewService creates a new models service +// NewService creates a new models service. func NewService(log *slog.Logger, queries *sqlc.Queries) *Service { return &Service{ queries: queries, @@ -26,7 +29,7 @@ func NewService(log *slog.Logger, queries *sqlc.Queries) *Service { } } -// Create adds a new model to the database +// Create adds a new model to the database. func (s *Service) Create(ctx context.Context, req AddRequest) (AddResponse, error) { model := Model(req) if err := model.Validate(); err != nil { @@ -77,7 +80,7 @@ func (s *Service) Create(ctx context.Context, req AddRequest) (AddResponse, erro }, nil } -// GetByID retrieves a model by its internal UUID +// GetByID retrieves a model by its internal UUID. func (s *Service) GetByID(ctx context.Context, id string) (GetResponse, error) { uuid, err := db.ParseUUID(id) if err != nil { @@ -92,10 +95,10 @@ func (s *Service) GetByID(ctx context.Context, id string) (GetResponse, error) { return convertToGetResponse(dbModel), nil } -// GetByModelID retrieves a model by its model_id field +// GetByModelID retrieves a model by its model_id field. func (s *Service) GetByModelID(ctx context.Context, modelID string) (GetResponse, error) { if modelID == "" { - return GetResponse{}, fmt.Errorf("model_id is required") + return GetResponse{}, errors.New("model_id is required") } dbModel, err := s.queries.GetModelByModelID(ctx, modelID) @@ -106,7 +109,7 @@ func (s *Service) GetByModelID(ctx context.Context, modelID string) (GetResponse return convertToGetResponse(dbModel), nil } -// List returns all models +// List returns all models. func (s *Service) List(ctx context.Context) ([]GetResponse, error) { dbModels, err := s.queries.ListModels(ctx) if err != nil { @@ -116,7 +119,7 @@ func (s *Service) List(ctx context.Context) ([]GetResponse, error) { return convertToGetResponseList(dbModels), nil } -// ListByType returns models filtered by type (chat or embedding) +// ListByType returns models filtered by type (chat or embedding). func (s *Service) ListByType(ctx context.Context, modelType ModelType) ([]GetResponse, error) { if modelType != ModelTypeChat && modelType != ModelTypeEmbedding { return nil, fmt.Errorf("invalid model type: %s", modelType) @@ -130,7 +133,7 @@ func (s *Service) ListByType(ctx context.Context, modelType ModelType) ([]GetRes return convertToGetResponseList(dbModels), nil } -// ListByClientType returns models filtered by client type +// ListByClientType returns models filtered by client type. func (s *Service) ListByClientType(ctx context.Context, clientType ClientType) ([]GetResponse, error) { if !isValidClientType(clientType) { return nil, fmt.Errorf("invalid client type: %s", clientType) @@ -147,7 +150,7 @@ func (s *Service) ListByClientType(ctx context.Context, clientType ClientType) ( // ListByProviderID returns models filtered by provider ID. func (s *Service) ListByProviderID(ctx context.Context, providerID string) ([]GetResponse, error) { if strings.TrimSpace(providerID) == "" { - return nil, fmt.Errorf("provider id is required") + return nil, errors.New("provider id is required") } uuid, err := db.ParseUUID(providerID) if err != nil { @@ -166,7 +169,7 @@ func (s *Service) ListByProviderIDAndType(ctx context.Context, providerID string return nil, fmt.Errorf("invalid model type: %s", modelType) } if strings.TrimSpace(providerID) == "" { - return nil, fmt.Errorf("provider id is required") + return nil, errors.New("provider id is required") } uuid, err := db.ParseUUID(providerID) if err != nil { @@ -182,7 +185,7 @@ func (s *Service) ListByProviderIDAndType(ctx context.Context, providerID string return convertToGetResponseList(dbModels), nil } -// UpdateByID updates a model by its internal UUID +// UpdateByID updates a model by its internal UUID. func (s *Service) UpdateByID(ctx context.Context, id string, req UpdateRequest) (GetResponse, error) { uuid, err := db.ParseUUID(id) if err != nil { @@ -222,10 +225,10 @@ func (s *Service) UpdateByID(ctx context.Context, id string, req UpdateRequest) return convertToGetResponse(updated), nil } -// UpdateByModelID updates a model by its model_id field +// UpdateByModelID updates a model by its model_id field. func (s *Service) UpdateByModelID(ctx context.Context, modelID string, req UpdateRequest) (GetResponse, error) { if modelID == "" { - return GetResponse{}, fmt.Errorf("model_id is required") + return GetResponse{}, errors.New("model_id is required") } model := Model(req) @@ -262,7 +265,7 @@ func (s *Service) UpdateByModelID(ctx context.Context, modelID string, req Updat return convertToGetResponse(updated), nil } -// DeleteByID deletes a model by its internal UUID +// DeleteByID deletes a model by its internal UUID. func (s *Service) DeleteByID(ctx context.Context, id string) error { uuid, err := db.ParseUUID(id) if err != nil { @@ -276,10 +279,10 @@ func (s *Service) DeleteByID(ctx context.Context, id string) error { return nil } -// DeleteByModelID deletes a model by its model_id field +// DeleteByModelID deletes a model by its model_id field. func (s *Service) DeleteByModelID(ctx context.Context, modelID string) error { if modelID == "" { - return fmt.Errorf("model_id is required") + return errors.New("model_id is required") } if err := s.queries.DeleteModelByModelID(ctx, modelID); err != nil { @@ -289,7 +292,7 @@ func (s *Service) DeleteByModelID(ctx context.Context, modelID string) error { return nil } -// Count returns the total number of models +// Count returns the total number of models. func (s *Service) Count(ctx context.Context) (int64, error) { count, err := s.queries.CountModels(ctx) if err != nil { @@ -298,7 +301,7 @@ func (s *Service) Count(ctx context.Context) (int64, error) { return count, nil } -// CountByType returns the number of models of a specific type +// CountByType returns the number of models of a specific type. func (s *Service) CountByType(ctx context.Context, modelType ModelType) (int64, error) { if modelType != ModelTypeChat && modelType != ModelTypeEmbedding { return 0, fmt.Errorf("invalid model type: %s", modelType) @@ -325,15 +328,15 @@ func convertToGetResponse(dbModel sqlc.Model) GetResponse { } if dbModel.LlmProviderID.Valid { - resp.Model.LlmProviderID = dbModel.LlmProviderID.String() + resp.LlmProviderID = dbModel.LlmProviderID.String() } if dbModel.Name.Valid { - resp.Model.Name = dbModel.Name.String + resp.Name = dbModel.Name.String } if dbModel.Dimensions.Valid { - resp.Model.Dimensions = int(dbModel.Dimensions.Int32) + resp.Dimensions = int(dbModel.Dimensions.Int32) } return resp @@ -376,11 +379,11 @@ func isValidClientType(clientType ClientType) bool { // SelectMemoryModel selects a chat model for memory operations. func SelectMemoryModel(ctx context.Context, modelsService *Service, queries *sqlc.Queries) (GetResponse, sqlc.LlmProvider, error) { if modelsService == nil { - return GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("models service not configured") + return GetResponse{}, sqlc.LlmProvider{}, errors.New("models service not configured") } candidates, err := modelsService.ListByType(ctx, ModelTypeChat) if err != nil || len(candidates) == 0 { - return GetResponse{}, sqlc.LlmProvider{}, fmt.Errorf("no chat models available for memory operations") + return GetResponse{}, sqlc.LlmProvider{}, errors.New("no chat models available for memory operations") } selected := candidates[0] provider, err := FetchProviderByID(ctx, queries, selected.LlmProviderID) @@ -393,7 +396,7 @@ func SelectMemoryModel(ctx context.Context, modelsService *Service, queries *sql // FetchProviderByID fetches a provider by ID. func FetchProviderByID(ctx context.Context, queries *sqlc.Queries, providerID string) (sqlc.LlmProvider, error) { if strings.TrimSpace(providerID) == "" { - return sqlc.LlmProvider{}, fmt.Errorf("provider id missing") + return sqlc.LlmProvider{}, errors.New("provider id missing") } parsed, err := db.ParseUUID(providerID) if err != nil { diff --git a/internal/models/models_test.go b/internal/models/models_test.go index 34d45b5b..5c05a9cb 100644 --- a/internal/models/models_test.go +++ b/internal/models/models_test.go @@ -3,8 +3,9 @@ package models_test import ( "testing" - "github.com/memohai/memoh/internal/models" "github.com/stretchr/testify/assert" + + "github.com/memohai/memoh/internal/models" ) // This is an example test file demonstrating how to use the models service @@ -185,20 +186,20 @@ func TestModel_Validate(t *testing.T) { func TestModelTypes(t *testing.T) { t.Run("ModelType constants", func(t *testing.T) { - assert.Equal(t, models.ModelType("chat"), models.ModelTypeChat) - assert.Equal(t, models.ModelType("embedding"), models.ModelTypeEmbedding) + assert.Equal(t, models.ModelTypeChat, models.ModelType("chat")) + assert.Equal(t, models.ModelTypeEmbedding, models.ModelType("embedding")) }) t.Run("ClientType constants", func(t *testing.T) { - assert.Equal(t, models.ClientType("openai"), models.ClientTypeOpenAI) - assert.Equal(t, models.ClientType("openai-compat"), models.ClientTypeOpenAICompat) - assert.Equal(t, models.ClientType("anthropic"), models.ClientTypeAnthropic) - assert.Equal(t, models.ClientType("google"), models.ClientTypeGoogle) - assert.Equal(t, models.ClientType("azure"), models.ClientTypeAzure) - assert.Equal(t, models.ClientType("bedrock"), models.ClientTypeBedrock) - assert.Equal(t, models.ClientType("mistral"), models.ClientTypeMistral) - assert.Equal(t, models.ClientType("xai"), models.ClientTypeXAI) - assert.Equal(t, models.ClientType("ollama"), models.ClientTypeOllama) - assert.Equal(t, models.ClientType("dashscope"), models.ClientTypeDashscope) + assert.Equal(t, models.ClientTypeOpenAI, models.ClientType("openai")) + assert.Equal(t, models.ClientTypeOpenAICompat, models.ClientType("openai-compat")) + assert.Equal(t, models.ClientTypeAnthropic, models.ClientType("anthropic")) + assert.Equal(t, models.ClientTypeGoogle, models.ClientType("google")) + assert.Equal(t, models.ClientTypeAzure, models.ClientType("azure")) + assert.Equal(t, models.ClientTypeBedrock, models.ClientType("bedrock")) + assert.Equal(t, models.ClientTypeMistral, models.ClientType("mistral")) + assert.Equal(t, models.ClientTypeXAI, models.ClientType("xai")) + assert.Equal(t, models.ClientTypeOllama, models.ClientType("ollama")) + assert.Equal(t, models.ClientTypeDashscope, models.ClientType("dashscope")) }) } diff --git a/internal/models/types.go b/internal/models/types.go index a1419dd5..5dda19ae 100644 --- a/internal/models/types.go +++ b/internal/models/types.go @@ -6,20 +6,25 @@ import ( "github.com/google/uuid" ) +// ModelType is the model kind: chat or embedding. type ModelType string +// Model type constants. const ( ModelTypeChat ModelType = "chat" ModelTypeEmbedding ModelType = "embedding" ) +// Supported model input types for multimodal models. const ( ModelInputText = "text" ModelInputImage = "image" ) +// ClientType is the LLM provider client type (openai, anthropic, etc.). type ClientType string +// Client type constants for LLM providers. const ( ClientTypeOpenAI ClientType = "openai" ClientTypeOpenAICompat ClientType = "openai-compat" @@ -33,6 +38,7 @@ const ( ClientTypeDashscope ClientType = "dashscope" ) +// Model is a single model definition (id, provider, type, dimensions, input types). type Model struct { ModelID string `json:"model_id"` Name string `json:"name"` @@ -43,6 +49,7 @@ type Model struct { Dimensions int `json:"dimensions"` } +// Validate checks model ID, provider ID, type, and dimensions (for embedding). func (m *Model) Validate() error { if m.ModelID == "" { return errors.New("model ID is required") @@ -63,38 +70,47 @@ func (m *Model) Validate() error { return nil } +// AddRequest is the input for creating a model (same shape as Model). type AddRequest Model +// AddResponse returns the created model ID. type AddResponse struct { ID string `json:"id"` ModelID string `json:"model_id"` } +// GetRequest is the input for getting a model by ID. type GetRequest struct { ID string `json:"id"` } +// GetResponse is the full model with model_id (for API response). type GetResponse struct { ModelID string `json:"model_id"` Model } +// UpdateRequest is the input for updating a model (same shape as Model). type UpdateRequest Model +// ListRequest optionally filters by type and client type. type ListRequest struct { Type ModelType `json:"type,omitempty"` ClientType ClientType `json:"client_type,omitempty"` } +// DeleteRequest identifies a model by ID or model_id. type DeleteRequest struct { ID string `json:"id,omitempty"` ModelID string `json:"model_id,omitempty"` } +// DeleteResponse holds a message after delete. type DeleteResponse struct { Message string `json:"message"` } +// CountResponse holds the total count for list/count API. type CountResponse struct { Count int64 `json:"count"` } diff --git a/internal/policy/service.go b/internal/policy/service.go index 2c476d1e..f6347b08 100644 --- a/internal/policy/service.go +++ b/internal/policy/service.go @@ -1,8 +1,9 @@ +// Package policy provides access policy evaluation for bots and channels. package policy import ( "context" - "fmt" + "errors" "log/slog" "strings" @@ -10,18 +11,21 @@ import ( "github.com/memohai/memoh/internal/settings" ) +// Decision is the resolved access policy for a bot (type and whether guest is allowed). type Decision struct { BotID string BotType string AllowGuest bool } +// Service evaluates bot access policy using bots and settings services. type Service struct { bots *bots.Service settings *settings.Service logger *slog.Logger } +// NewService creates a policy service. func NewService(log *slog.Logger, botsService *bots.Service, settingsService *settings.Service) *Service { if log == nil { log = slog.Default() @@ -36,11 +40,11 @@ func NewService(log *slog.Logger, botsService *bots.Service, settingsService *se // Resolve evaluates the full access policy for a bot. func (s *Service) Resolve(ctx context.Context, botID string) (Decision, error) { if s == nil || s.bots == nil || s.settings == nil { - return Decision{}, fmt.Errorf("policy service not configured") + return Decision{}, errors.New("policy service not configured") } botID = strings.TrimSpace(botID) if botID == "" { - return Decision{}, fmt.Errorf("bot id is required") + return Decision{}, errors.New("bot id is required") } bot, err := s.bots.Get(ctx, botID) if err != nil { @@ -82,7 +86,7 @@ func (s *Service) BotType(ctx context.Context, botID string) (string, error) { // BotOwnerUserID returns bot owner's user id. Implements router.PolicyService. func (s *Service) BotOwnerUserID(ctx context.Context, botID string) (string, error) { if s == nil || s.bots == nil { - return "", fmt.Errorf("policy service not configured") + return "", errors.New("policy service not configured") } bot, err := s.bots.Get(ctx, strings.TrimSpace(botID)) if err != nil { diff --git a/internal/preauth/service.go b/internal/preauth/service.go index f2abcc2a..564b10ab 100644 --- a/internal/preauth/service.go +++ b/internal/preauth/service.go @@ -1,9 +1,9 @@ +// Package preauth provides pre-authentication token issuance and validation. package preauth import ( "context" "errors" - "fmt" "strings" "time" @@ -15,12 +15,15 @@ import ( "github.com/memohai/memoh/internal/db/sqlc" ) +// ErrKeyNotFound is returned when a preauth key lookup by token finds no row. var ErrKeyNotFound = errors.New("preauth key not found") +// Service issues and validates preauth keys for bot access. type Service struct { queries *sqlc.Queries } +// NewService creates a preauth service using the given queries. func NewService(queries *sqlc.Queries) *Service { return &Service{queries: queries} } @@ -28,7 +31,7 @@ func NewService(queries *sqlc.Queries) *Service { // Issue creates a new preauth key for the given bot. func (s *Service) Issue(ctx context.Context, botID, issuedByUserID string, ttl time.Duration) (Key, error) { if s.queries == nil { - return Key{}, fmt.Errorf("preauth queries not configured") + return Key{}, errors.New("preauth queries not configured") } if ttl <= 0 { ttl = 24 * time.Hour @@ -59,9 +62,10 @@ func (s *Service) Issue(ctx context.Context, botID, issuedByUserID string, ttl t return normalizeKey(row), nil } +// Get returns the preauth key for the given token, or ErrKeyNotFound if not found. func (s *Service) Get(ctx context.Context, token string) (Key, error) { if s.queries == nil { - return Key{}, fmt.Errorf("preauth queries not configured") + return Key{}, errors.New("preauth queries not configured") } row, err := s.queries.GetBotPreauthKey(ctx, strings.TrimSpace(token)) if err != nil { @@ -73,9 +77,10 @@ func (s *Service) Get(ctx context.Context, token string) (Key, error) { return normalizeKey(row), nil } +// MarkUsed marks the preauth key by ID as used and returns the updated key. func (s *Service) MarkUsed(ctx context.Context, id string) (Key, error) { if s.queries == nil { - return Key{}, fmt.Errorf("preauth queries not configured") + return Key{}, errors.New("preauth queries not configured") } pgID, err := db.ParseUUID(id) if err != nil { diff --git a/internal/providers/service.go b/internal/providers/service.go index b1a205f0..1523efb8 100644 --- a/internal/providers/service.go +++ b/internal/providers/service.go @@ -1,3 +1,4 @@ +// Package providers provides LLM provider configuration and management. package providers import ( @@ -11,13 +12,13 @@ import ( "github.com/memohai/memoh/internal/db/sqlc" ) -// Service handles provider operations +// Service handles provider operations. type Service struct { queries *sqlc.Queries logger *slog.Logger } -// NewService creates a new provider service +// NewService creates a new provider service. func NewService(log *slog.Logger, queries *sqlc.Queries) *Service { return &Service{ queries: queries, @@ -25,7 +26,7 @@ func NewService(log *slog.Logger, queries *sqlc.Queries) *Service { } } -// Create creates a new LLM provider +// Create creates a new LLM provider. func (s *Service) Create(ctx context.Context, req CreateRequest) (GetResponse, error) { // Validate client type if !isValidClientType(req.ClientType) { @@ -53,7 +54,7 @@ func (s *Service) Create(ctx context.Context, req CreateRequest) (GetResponse, e return s.toGetResponse(provider), nil } -// Get retrieves a provider by ID +// Get retrieves a provider by ID. func (s *Service) Get(ctx context.Context, id string) (GetResponse, error) { providerID, err := db.ParseUUID(id) if err != nil { @@ -68,7 +69,7 @@ func (s *Service) Get(ctx context.Context, id string) (GetResponse, error) { return s.toGetResponse(provider), nil } -// GetByName retrieves a provider by name +// GetByName retrieves a provider by name. func (s *Service) GetByName(ctx context.Context, name string) (GetResponse, error) { provider, err := s.queries.GetLlmProviderByName(ctx, name) if err != nil { @@ -78,7 +79,7 @@ func (s *Service) GetByName(ctx context.Context, name string) (GetResponse, erro return s.toGetResponse(provider), nil } -// List retrieves all providers +// List retrieves all providers. func (s *Service) List(ctx context.Context) ([]GetResponse, error) { providers, err := s.queries.ListLlmProviders(ctx) if err != nil { @@ -92,7 +93,7 @@ func (s *Service) List(ctx context.Context) ([]GetResponse, error) { return results, nil } -// ListByClientType retrieves providers by client type +// ListByClientType retrieves providers by client type. func (s *Service) ListByClientType(ctx context.Context, clientType ClientType) ([]GetResponse, error) { if !isValidClientType(clientType) { return nil, fmt.Errorf("invalid client_type: %s", clientType) @@ -110,7 +111,7 @@ func (s *Service) ListByClientType(ctx context.Context, clientType ClientType) ( return results, nil } -// Update updates an existing provider +// Update updates an existing provider. func (s *Service) Update(ctx context.Context, id string, req UpdateRequest) (GetResponse, error) { providerID, err := db.ParseUUID(id) if err != nil { @@ -169,7 +170,7 @@ func (s *Service) Update(ctx context.Context, id string, req UpdateRequest) (Get return s.toGetResponse(updated), nil } -// Delete deletes a provider by ID +// Delete deletes a provider by ID. func (s *Service) Delete(ctx context.Context, id string) error { providerID, err := db.ParseUUID(id) if err != nil { @@ -182,7 +183,7 @@ func (s *Service) Delete(ctx context.Context, id string) error { return nil } -// Count returns the total count of providers +// Count returns the total count of providers. func (s *Service) Count(ctx context.Context) (int64, error) { count, err := s.queries.CountLlmProviders(ctx) if err != nil { @@ -191,7 +192,7 @@ func (s *Service) Count(ctx context.Context) (int64, error) { return count, nil } -// CountByClientType returns the count of providers by client type +// CountByClientType returns the count of providers by client type. func (s *Service) CountByClientType(ctx context.Context, clientType ClientType) (int64, error) { if !isValidClientType(clientType) { return 0, fmt.Errorf("invalid client_type: %s", clientType) @@ -204,7 +205,7 @@ func (s *Service) CountByClientType(ctx context.Context, clientType ClientType) return count, nil } -// toGetResponse converts a database provider to a response +// toGetResponse converts a database provider to a response. func (s *Service) toGetResponse(provider sqlc.LlmProvider) GetResponse { var metadata map[string]any if len(provider.Metadata) > 0 { @@ -228,7 +229,7 @@ func (s *Service) toGetResponse(provider sqlc.LlmProvider) GetResponse { } } -// isValidClientType checks if a client type is valid +// isValidClientType checks if a client type is valid. func isValidClientType(clientType ClientType) bool { switch clientType { case ClientTypeOpenAI, ClientTypeOpenAICompat, ClientTypeAnthropic, ClientTypeGoogle, @@ -240,7 +241,7 @@ func isValidClientType(clientType ClientType) bool { } } -// maskAPIKey masks an API key for security +// maskAPIKey masks an API key for security. func maskAPIKey(apiKey string) string { if apiKey == "" { return "" diff --git a/internal/providers/types.go b/internal/providers/types.go index 1af60fca..32acdca6 100644 --- a/internal/providers/types.go +++ b/internal/providers/types.go @@ -2,9 +2,10 @@ package providers import "time" -// ClientType represents the type of LLM provider client +// ClientType represents the type of LLM provider client. type ClientType string +// LLM provider client type constants. const ( ClientTypeOpenAI ClientType = "openai" ClientTypeOpenAICompat ClientType = "openai-compat" @@ -18,7 +19,7 @@ const ( ClientTypeDashscope ClientType = "dashscope" ) -// CreateRequest represents a request to create a new LLM provider +// CreateRequest represents a request to create a new LLM provider. type CreateRequest struct { Name string `json:"name" validate:"required"` ClientType ClientType `json:"client_type" validate:"required"` @@ -27,7 +28,7 @@ type CreateRequest struct { Metadata map[string]any `json:"metadata,omitempty"` } -// UpdateRequest represents a request to update an existing LLM provider +// UpdateRequest represents a request to update an existing LLM provider. type UpdateRequest struct { Name *string `json:"name,omitempty"` ClientType *ClientType `json:"client_type,omitempty"` @@ -36,7 +37,7 @@ type UpdateRequest struct { Metadata map[string]any `json:"metadata,omitempty"` } -// GetResponse represents the response for getting a provider +// GetResponse represents the response for getting a provider. type GetResponse struct { ID string `json:"id"` Name string `json:"name"` @@ -48,18 +49,18 @@ type GetResponse struct { UpdatedAt time.Time `json:"updated_at"` } -// ListResponse represents the response for listing providers +// ListResponse represents the response for listing providers. type ListResponse struct { Providers []GetResponse `json:"providers"` Total int64 `json:"total"` } -// CountResponse represents the count response +// CountResponse represents the count response. type CountResponse struct { Count int64 `json:"count"` } -// TestRequest represents a request to test provider connection +// TestRequest represents a request to test provider connection. type TestRequest struct { ClientType ClientType `json:"client_type" validate:"required"` BaseURL string `json:"base_url" validate:"required,url"` @@ -67,7 +68,7 @@ type TestRequest struct { Model string `json:"model"` // optional test model } -// TestResponse represents the result of testing a provider +// TestResponse represents the result of testing a provider. type TestResponse struct { Success bool `json:"success"` Message string `json:"message,omitempty"` diff --git a/internal/schedule/service.go b/internal/schedule/service.go index 64042aa0..b86a9e4a 100644 --- a/internal/schedule/service.go +++ b/internal/schedule/service.go @@ -1,3 +1,4 @@ +// Package schedule provides cron schedule and trigger management. package schedule import ( @@ -19,6 +20,7 @@ import ( "github.com/memohai/memoh/internal/db/sqlc" ) +// Service manages cron schedules and triggers execution via Triggerer. type Service struct { queries *sqlc.Queries cron *cron.Cron @@ -30,6 +32,7 @@ type Service struct { jobs map[string]cron.EntryID } +// NewService creates a schedule service and starts the cron scheduler. func NewService(log *slog.Logger, queries *sqlc.Queries, triggerer Triggerer, runtimeConfig *boot.RuntimeConfig) *Service { parser := cron.NewParser(cron.SecondOptional | cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow | cron.Descriptor) c := cron.New(cron.WithParser(parser)) @@ -46,9 +49,10 @@ func NewService(log *slog.Logger, queries *sqlc.Queries, triggerer Triggerer, ru return service } +// Bootstrap loads all enabled schedules from DB and registers them with the cron scheduler. func (s *Service) Bootstrap(ctx context.Context) error { if s.queries == nil { - return fmt.Errorf("schedule queries not configured") + return errors.New("schedule queries not configured") } items, err := s.queries.ListEnabledSchedules(ctx) if err != nil { @@ -62,12 +66,13 @@ func (s *Service) Bootstrap(ctx context.Context) error { return nil } +// Create creates a new schedule for the bot and starts the cron job if enabled. func (s *Service) Create(ctx context.Context, botID string, req CreateRequest) (Schedule, error) { if s.queries == nil { - return Schedule{}, fmt.Errorf("schedule queries not configured") + return Schedule{}, errors.New("schedule queries not configured") } if strings.TrimSpace(req.Name) == "" || strings.TrimSpace(req.Description) == "" || strings.TrimSpace(req.Pattern) == "" || strings.TrimSpace(req.Command) == "" { - return Schedule{}, fmt.Errorf("name, description, pattern, command are required") + return Schedule{}, errors.New("name, description, pattern, command are required") } if _, err := s.parser.Parse(req.Pattern); err != nil { return Schedule{}, fmt.Errorf("invalid cron pattern: %w", err) @@ -104,6 +109,7 @@ func (s *Service) Create(ctx context.Context, botID string, req CreateRequest) ( return toSchedule(row), nil } +// Get returns the schedule by ID. func (s *Service) Get(ctx context.Context, id string) (Schedule, error) { pgID, err := db.ParseUUID(id) if err != nil { @@ -112,13 +118,14 @@ func (s *Service) Get(ctx context.Context, id string) (Schedule, error) { row, err := s.queries.GetScheduleByID(ctx, pgID) if err != nil { if errors.Is(err, pgx.ErrNoRows) { - return Schedule{}, fmt.Errorf("schedule not found") + return Schedule{}, errors.New("schedule not found") } return Schedule{}, err } return toSchedule(row), nil } +// List returns all schedules for the given bot. func (s *Service) List(ctx context.Context, botID string) ([]Schedule, error) { pgBotID, err := db.ParseUUID(botID) if err != nil { @@ -135,6 +142,7 @@ func (s *Service) List(ctx context.Context, botID string) ([]Schedule, error) { return items, nil } +// Update updates the schedule by ID and reschedules the cron job if pattern/enabled changed. func (s *Service) Update(ctx context.Context, id string, req UpdateRequest) (Schedule, error) { pgID, err := db.ParseUUID(id) if err != nil { @@ -193,6 +201,7 @@ func (s *Service) Update(ctx context.Context, id string, req UpdateRequest) (Sch return toSchedule(updated), nil } +// Delete removes the schedule and its cron job. func (s *Service) Delete(ctx context.Context, id string) error { pgID, err := db.ParseUUID(id) if err != nil { @@ -205,16 +214,17 @@ func (s *Service) Delete(ctx context.Context, id string) error { return nil } +// Trigger runs the schedule once by ID (calls Triggerer with payload and JWT). func (s *Service) Trigger(ctx context.Context, scheduleID string) error { if s.triggerer == nil { - return fmt.Errorf("schedule triggerer not configured") + return errors.New("schedule triggerer not configured") } schedule, err := s.Get(ctx, scheduleID) if err != nil { return err } if !schedule.Enabled { - return fmt.Errorf("schedule is disabled") + return errors.New("schedule is disabled") } return s.runSchedule(ctx, schedule) } @@ -223,7 +233,7 @@ const scheduleTokenTTL = 10 * time.Minute func (s *Service) runSchedule(ctx context.Context, schedule Schedule) error { if s.triggerer == nil { - return fmt.Errorf("schedule triggerer not configured") + return errors.New("schedule triggerer not configured") } updated, err := s.queries.IncrementScheduleCalls(ctx, toUUID(schedule.ID)) if err != nil { @@ -266,7 +276,7 @@ func (s *Service) resolveBotOwner(ctx context.Context, botID string) (string, er } ownerID := bot.OwnerUserID.String() if ownerID == "" { - return "", fmt.Errorf("bot owner not found") + return "", errors.New("bot owner not found") } return ownerID, nil } @@ -274,7 +284,7 @@ func (s *Service) resolveBotOwner(ctx context.Context, botID string) (string, er // generateTriggerToken creates a short-lived JWT for schedule trigger callbacks. func (s *Service) generateTriggerToken(userID string) (string, error) { if strings.TrimSpace(s.jwtSecret) == "" { - return "", fmt.Errorf("jwt secret not configured") + return "", errors.New("jwt secret not configured") } signed, _, err := auth.GenerateToken(userID, s.jwtSecret, scheduleTokenTTL) if err != nil { @@ -286,7 +296,7 @@ func (s *Service) generateTriggerToken(userID string) (string, error) { func (s *Service) scheduleJob(schedule sqlc.Schedule) error { id := schedule.ID.String() if id == "" { - return fmt.Errorf("schedule id missing") + return errors.New("schedule id missing") } job := func() { if err := s.runSchedule(context.Background(), toSchedule(schedule)); err != nil { @@ -337,8 +347,8 @@ func toSchedule(row sqlc.Schedule) Schedule { BotID: row.BotID.String(), } if row.MaxCalls.Valid { - max := int(row.MaxCalls.Int32) - item.MaxCalls = &max + maxCalls := int(row.MaxCalls.Int32) + item.MaxCalls = &maxCalls } if row.CreatedAt.Valid { item.CreatedAt = row.CreatedAt.Time diff --git a/internal/schedule/service_test.go b/internal/schedule/service_test.go index 15ec87fc..c9e9e9e5 100644 --- a/internal/schedule/service_test.go +++ b/internal/schedule/service_test.go @@ -1,7 +1,6 @@ package schedule import ( - "context" "log/slog" "strings" "testing" @@ -10,21 +9,6 @@ import ( "github.com/golang-jwt/jwt/v5" ) -type mockTriggerer struct { - called bool - botID string - payload TriggerPayload - token string -} - -func (m *mockTriggerer) TriggerSchedule(_ context.Context, botID string, payload TriggerPayload, token string) error { - m.called = true - m.botID = botID - m.payload = payload - m.token = token - return nil -} - func TestGenerateTriggerToken(t *testing.T) { secret := "test-secret-key-for-schedule" svc := &Service{ @@ -42,7 +26,7 @@ func TestGenerateTriggerToken(t *testing.T) { } raw := strings.TrimPrefix(tok, "Bearer ") - parsed, err := jwt.Parse(raw, func(token *jwt.Token) (any, error) { + parsed, err := jwt.Parse(raw, func(_ *jwt.Token) (any, error) { return []byte(secret), nil }) if err != nil { diff --git a/internal/schedule/types.go b/internal/schedule/types.go index 7c95b1f6..5c43d825 100644 --- a/internal/schedule/types.go +++ b/internal/schedule/types.go @@ -5,6 +5,7 @@ import ( "time" ) +// Schedule is a cron schedule attached to a bot (pattern, command, max calls, enabled). type Schedule struct { ID string `json:"id"` Name string `json:"name"` @@ -19,15 +20,18 @@ type Schedule struct { BotID string `json:"bot_id"` } +// NullableInt represents an optional int for JSON (null vs omitted). type NullableInt struct { Value *int Set bool } +// IsZero reports whether the value was not set (omitempty semantics). func (n NullableInt) IsZero() bool { return !n.Set } +// MarshalJSON encodes as null when unset or value is nil, otherwise the int. func (n NullableInt) MarshalJSON() ([]byte, error) { if !n.Set || n.Value == nil { return []byte("null"), nil @@ -35,6 +39,7 @@ func (n NullableInt) MarshalJSON() ([]byte, error) { return json.Marshal(*n.Value) } +// UnmarshalJSON decodes null or an int and sets Set true. func (n *NullableInt) UnmarshalJSON(data []byte) error { n.Set = true if string(data) == "null" { @@ -49,24 +54,27 @@ func (n *NullableInt) UnmarshalJSON(data []byte) error { return nil } +// CreateRequest is the input for creating a schedule (name, description, cron pattern, command, etc.). type CreateRequest struct { Name string `json:"name"` Description string `json:"description"` Pattern string `json:"pattern"` - MaxCalls NullableInt `json:"max_calls,omitempty"` + MaxCalls NullableInt `json:"max_calls,omitzero"` Command string `json:"command"` Enabled *bool `json:"enabled,omitempty"` } +// UpdateRequest is the input for updating a schedule (all fields optional). type UpdateRequest struct { Name *string `json:"name,omitempty"` Description *string `json:"description,omitempty"` Pattern *string `json:"pattern,omitempty"` - MaxCalls NullableInt `json:"max_calls,omitempty"` + MaxCalls NullableInt `json:"max_calls,omitzero"` Command *string `json:"command,omitempty"` Enabled *bool `json:"enabled,omitempty"` } +// ListResponse holds the list of schedules for list API. type ListResponse struct { Items []Schedule `json:"items"` } diff --git a/internal/searchproviders/service.go b/internal/searchproviders/service.go index 5bd724f0..73bf7fc8 100644 --- a/internal/searchproviders/service.go +++ b/internal/searchproviders/service.go @@ -1,3 +1,4 @@ +// Package searchproviders provides search provider configuration and management. package searchproviders import ( @@ -11,11 +12,13 @@ import ( "github.com/memohai/memoh/internal/db/sqlc" ) +// Service manages search provider configs (create, list, get, update, delete). type Service struct { queries *sqlc.Queries logger *slog.Logger } +// NewService creates a search providers service. func NewService(log *slog.Logger, queries *sqlc.Queries) *Service { return &Service{ queries: queries, @@ -23,6 +26,7 @@ func NewService(log *slog.Logger, queries *sqlc.Queries) *Service { } } +// ListMeta returns metadata for all supported providers (display name, config schema). func (s *Service) ListMeta(_ context.Context) []ProviderMeta { return []ProviderMeta{ { @@ -56,6 +60,7 @@ func (s *Service) ListMeta(_ context.Context) []ProviderMeta { } } +// Create creates a new search provider config. func (s *Service) Create(ctx context.Context, req CreateRequest) (GetResponse, error) { if !isValidProviderName(req.Provider) { return GetResponse{}, fmt.Errorf("invalid provider: %s", req.Provider) @@ -75,6 +80,7 @@ func (s *Service) Create(ctx context.Context, req CreateRequest) (GetResponse, e return s.toGetResponse(row), nil } +// Get returns the search provider config by ID. func (s *Service) Get(ctx context.Context, id string) (GetResponse, error) { pgID, err := db.ParseUUID(id) if err != nil { @@ -87,6 +93,7 @@ func (s *Service) Get(ctx context.Context, id string) (GetResponse, error) { return s.toGetResponse(row), nil } +// GetRawByID returns the raw sqlc row for the search provider by ID. func (s *Service) GetRawByID(ctx context.Context, id string) (sqlc.SearchProvider, error) { pgID, err := db.ParseUUID(id) if err != nil { @@ -95,6 +102,7 @@ func (s *Service) GetRawByID(ctx context.Context, id string) (sqlc.SearchProvide return s.queries.GetSearchProviderByID(ctx, pgID) } +// List returns all provider configs, optionally filtered by provider name. func (s *Service) List(ctx context.Context, provider string) ([]GetResponse, error) { provider = strings.TrimSpace(provider) var ( @@ -116,6 +124,7 @@ func (s *Service) List(ctx context.Context, provider string) ([]GetResponse, err return items, nil } +// Update updates the search provider config by ID. func (s *Service) Update(ctx context.Context, id string, req UpdateRequest) (GetResponse, error) { pgID, err := db.ParseUUID(id) if err != nil { @@ -156,6 +165,7 @@ func (s *Service) Update(ctx context.Context, id string, req UpdateRequest) (Get return s.toGetResponse(updated), nil } +// Delete removes the search provider config by ID. func (s *Service) Delete(ctx context.Context, id string) error { pgID, err := db.ParseUUID(id) if err != nil { diff --git a/internal/searchproviders/types.go b/internal/searchproviders/types.go index d00e86c1..b9fee83b 100644 --- a/internal/searchproviders/types.go +++ b/internal/searchproviders/types.go @@ -2,16 +2,20 @@ package searchproviders import "time" +// ProviderName identifies a search provider (e.g. brave). type ProviderName string +// Supported provider name constants. const ( ProviderBrave ProviderName = "brave" ) +// ProviderConfigSchema describes the config fields for a provider (for UI). type ProviderConfigSchema struct { Fields map[string]ProviderFieldSchema `json:"fields"` } +// ProviderFieldSchema describes a single config field (type, title, required, etc.). type ProviderFieldSchema struct { Type string `json:"type"` Title string `json:"title,omitempty"` @@ -21,24 +25,28 @@ type ProviderFieldSchema struct { Example any `json:"example,omitempty"` } +// ProviderMeta is metadata for a provider (display name and config schema). type ProviderMeta struct { Provider string `json:"provider"` DisplayName string `json:"display_name"` ConfigSchema ProviderConfigSchema `json:"config_schema"` } +// CreateRequest is the input for creating a search provider config. type CreateRequest struct { Name string `json:"name"` Provider ProviderName `json:"provider"` Config map[string]any `json:"config,omitempty"` } +// UpdateRequest is the input for updating a search provider config (all fields optional). type UpdateRequest struct { Name *string `json:"name,omitempty"` Provider *ProviderName `json:"provider,omitempty"` Config map[string]any `json:"config,omitempty"` } +// GetResponse is the API response for a single search provider config. type GetResponse struct { ID string `json:"id"` Name string `json:"name"` diff --git a/internal/server/server.go b/internal/server/server.go index f2e64937..000a2337 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -1,3 +1,4 @@ +// Package server provides the HTTP server and Echo setup for the agent API. package server import ( @@ -11,17 +12,20 @@ import ( "github.com/memohai/memoh/internal/auth" ) +// Server is the HTTP server (Echo) with JWT middleware and registered handlers. type Server struct { echo *echo.Echo addr string logger *slog.Logger } +// Handler registers routes on the Echo instance. type Handler interface { Register(e *echo.Echo) } -func NewServer(log *slog.Logger, addr string, jwtSecret string, +// NewServer builds the Echo server with recovery, request logging, JWT auth, and the given handlers. +func NewServer(log *slog.Logger, addr, jwtSecret string, handlers ...Handler, ) *Server { if addr == "" { @@ -70,10 +74,12 @@ func NewServer(log *slog.Logger, addr string, jwtSecret string, } } +// Start starts the HTTP server (blocks until shutdown). func (s *Server) Start() error { return s.echo.Start(s.addr) } +// Stop gracefully shuts down the server using the given context. func (s *Server) Stop(ctx context.Context) error { return s.echo.Shutdown(ctx) } diff --git a/internal/settings/service.go b/internal/settings/service.go index 838ac03c..cf2c683e 100644 --- a/internal/settings/service.go +++ b/internal/settings/service.go @@ -1,9 +1,9 @@ +// Package settings provides user and system settings persistence. package settings import ( "context" "errors" - "fmt" "log/slog" "strings" @@ -14,13 +14,16 @@ import ( "github.com/memohai/memoh/internal/db/sqlc" ) +// Service reads and updates bot-level and system settings. type Service struct { queries *sqlc.Queries logger *slog.Logger } +// ErrPersonalBotGuestAccessUnsupported is returned when enabling guest access on a personal bot. var ErrPersonalBotGuestAccessUnsupported = errors.New("personal bots do not support guest access") +// NewService creates a settings service. func NewService(log *slog.Logger, queries *sqlc.Queries) *Service { return &Service{ queries: queries, @@ -28,6 +31,7 @@ func NewService(log *slog.Logger, queries *sqlc.Queries) *Service { } } +// GetBot returns the settings for the given bot. func (s *Service) GetBot(ctx context.Context, botID string) (Settings, error) { pgID, err := db.ParseUUID(botID) if err != nil { @@ -40,9 +44,10 @@ func (s *Service) GetBot(ctx context.Context, botID string) (Settings, error) { return normalizeBotSettingsReadRow(row), nil } +// UpsertBot updates bot settings (model IDs, max context time, language, allow guest); returns ErrPersonalBotGuestAccessUnsupported for personal bots if AllowGuest is set. func (s *Service) UpsertBot(ctx context.Context, botID string, req UpsertRequest) (Settings, error) { if s.queries == nil { - return Settings{}, fmt.Errorf("settings queries not configured") + return Settings{}, errors.New("settings queries not configured") } pgID, err := db.ParseUUID(botID) if err != nil { @@ -119,9 +124,10 @@ func (s *Service) UpsertBot(ctx context.Context, botID string, req UpsertRequest return normalizeBotSettingsWriteRow(updated), nil } +// Delete removes bot-level settings for the given bot. func (s *Service) Delete(ctx context.Context, botID string) error { if s.queries == nil { - return fmt.Errorf("settings queries not configured") + return errors.New("settings queries not configured") } pgID, err := db.ParseUUID(botID) if err != nil { @@ -190,7 +196,7 @@ func normalizeBotSettingsFields( func (s *Service) resolveModelUUID(ctx context.Context, modelID string) (pgtype.UUID, error) { if strings.TrimSpace(modelID) == "" { - return pgtype.UUID{}, fmt.Errorf("model_id is required") + return pgtype.UUID{}, errors.New("model_id is required") } row, err := s.queries.GetModelByModelID(ctx, modelID) if err != nil { diff --git a/internal/settings/types.go b/internal/settings/types.go index 33c6dc66..cd563fe8 100644 --- a/internal/settings/types.go +++ b/internal/settings/types.go @@ -1,10 +1,12 @@ package settings +// Default values for bot settings when not set. const ( DefaultMaxContextLoadTime = 24 * 60 DefaultLanguage = "auto" ) +// Settings holds bot-level settings (models, max context minutes, language, guest access). type Settings struct { ChatModelID string `json:"chat_model_id"` MemoryModelID string `json:"memory_model_id"` @@ -15,6 +17,7 @@ type Settings struct { AllowGuest bool `json:"allow_guest"` } +// UpsertRequest is the input for upserting bot settings (all fields optional). type UpsertRequest struct { ChatModelID string `json:"chat_model_id,omitempty"` MemoryModelID string `json:"memory_model_id,omitempty"` diff --git a/internal/subagent/service.go b/internal/subagent/service.go index c9e5ce77..42364cc7 100644 --- a/internal/subagent/service.go +++ b/internal/subagent/service.go @@ -1,10 +1,10 @@ +// Package subagent provides sub-agent definition and management. package subagent import ( "context" "encoding/json" "errors" - "fmt" "log/slog" "strings" @@ -14,11 +14,13 @@ import ( "github.com/memohai/memoh/internal/db/sqlc" ) +// Service manages subagent CRUD and context/skills updates. type Service struct { queries *sqlc.Queries logger *slog.Logger } +// NewService creates a subagent service. func NewService(log *slog.Logger, queries *sqlc.Queries) *Service { return &Service{ queries: queries, @@ -26,17 +28,18 @@ func NewService(log *slog.Logger, queries *sqlc.Queries) *Service { } } +// Create creates a new subagent for the bot. func (s *Service) Create(ctx context.Context, botID string, req CreateRequest) (Subagent, error) { if s.queries == nil { - return Subagent{}, fmt.Errorf("subagent queries not configured") + return Subagent{}, errors.New("subagent queries not configured") } name := strings.TrimSpace(req.Name) if name == "" { - return Subagent{}, fmt.Errorf("name is required") + return Subagent{}, errors.New("name is required") } description := strings.TrimSpace(req.Description) if description == "" { - return Subagent{}, fmt.Errorf("description is required") + return Subagent{}, errors.New("description is required") } pgBotID, err := db.ParseUUID(botID) if err != nil { @@ -68,6 +71,7 @@ func (s *Service) Create(ctx context.Context, botID string, req CreateRequest) ( return toSubagent(row) } +// Get returns the subagent by ID. func (s *Service) Get(ctx context.Context, id string) (Subagent, error) { pgID, err := db.ParseUUID(id) if err != nil { @@ -76,13 +80,14 @@ func (s *Service) Get(ctx context.Context, id string) (Subagent, error) { row, err := s.queries.GetSubagentByID(ctx, pgID) if err != nil { if errors.Is(err, pgx.ErrNoRows) { - return Subagent{}, fmt.Errorf("subagent not found") + return Subagent{}, errors.New("subagent not found") } return Subagent{}, err } return toSubagent(row) } +// List returns all subagents for the given bot. func (s *Service) List(ctx context.Context, botID string) ([]Subagent, error) { pgBotID, err := db.ParseUUID(botID) if err != nil { @@ -103,6 +108,7 @@ func (s *Service) List(ctx context.Context, botID string) ([]Subagent, error) { return items, nil } +// Update updates subagent name, description, and metadata by ID. func (s *Service) Update(ctx context.Context, id string, req UpdateRequest) (Subagent, error) { existing, err := s.Get(ctx, id) if err != nil { @@ -112,14 +118,14 @@ func (s *Service) Update(ctx context.Context, id string, req UpdateRequest) (Sub if req.Name != nil { name = strings.TrimSpace(*req.Name) if name == "" { - return Subagent{}, fmt.Errorf("name is required") + return Subagent{}, errors.New("name is required") } } description := existing.Description if req.Description != nil { description = strings.TrimSpace(*req.Description) if description == "" { - return Subagent{}, fmt.Errorf("description is required") + return Subagent{}, errors.New("description is required") } } metadata := existing.Metadata @@ -146,6 +152,7 @@ func (s *Service) Update(ctx context.Context, id string, req UpdateRequest) (Sub return toSubagent(row) } +// UpdateContext replaces the subagent system messages/context by ID. func (s *Service) UpdateContext(ctx context.Context, id string, req UpdateContextRequest) (Subagent, error) { messagesPayload, err := marshalMessages(req.Messages) if err != nil { @@ -165,6 +172,7 @@ func (s *Service) UpdateContext(ctx context.Context, id string, req UpdateContex return toSubagent(row) } +// UpdateSkills replaces the subagent skills list by ID. func (s *Service) UpdateSkills(ctx context.Context, id string, req UpdateSkillsRequest) (Subagent, error) { skillsPayload, err := marshalSkills(req.Skills) if err != nil { @@ -184,6 +192,7 @@ func (s *Service) UpdateSkills(ctx context.Context, id string, req UpdateSkillsR return toSubagent(row) } +// AddSkills appends skills to the subagent by ID (no duplicates). func (s *Service) AddSkills(ctx context.Context, id string, req AddSkillsRequest) (Subagent, error) { existing, err := s.Get(ctx, id) if err != nil { @@ -208,6 +217,7 @@ func (s *Service) AddSkills(ctx context.Context, id string, req AddSkillsRequest return toSubagent(row) } +// Delete soft-deletes the subagent by ID. func (s *Service) Delete(ctx context.Context, id string) error { pgID, err := db.ParseUUID(id) if err != nil { @@ -329,9 +339,8 @@ func normalizeSkills(skills []string) []string { return normalized } -func mergeSkills(existing []string, incoming []string) []string { +func mergeSkills(existing, incoming []string) []string { merged := append([]string{}, existing...) merged = append(merged, incoming...) return normalizeSkills(merged) } - diff --git a/internal/subagent/types.go b/internal/subagent/types.go index 38207e0d..d10e6bbf 100644 --- a/internal/subagent/types.go +++ b/internal/subagent/types.go @@ -2,6 +2,7 @@ package subagent import "time" +// Subagent is a bot sub-agent definition (name, description, messages, skills, metadata). type Subagent struct { ID string `json:"id"` Name string `json:"name"` @@ -16,6 +17,7 @@ type Subagent struct { DeletedAt *time.Time `json:"deleted_at,omitempty"` } +// CreateRequest is the input for creating a subagent. type CreateRequest struct { Name string `json:"name"` Description string `json:"description"` @@ -24,32 +26,39 @@ type CreateRequest struct { Skills []string `json:"skills,omitempty"` } +// UpdateRequest is the input for updating name, description, metadata (all optional). type UpdateRequest struct { Name *string `json:"name,omitempty"` Description *string `json:"description,omitempty"` Metadata map[string]any `json:"metadata,omitempty"` } +// UpdateContextRequest is the input for replacing system messages. type UpdateContextRequest struct { Messages []map[string]any `json:"messages"` } +// UpdateSkillsRequest is the input for replacing the skills list. type UpdateSkillsRequest struct { Skills []string `json:"skills"` } +// AddSkillsRequest is the input for appending skills. type AddSkillsRequest struct { Skills []string `json:"skills"` } +// ListResponse holds the list of subagents for list API. type ListResponse struct { Items []Subagent `json:"items"` } +// ContextResponse holds the subagent context (messages) for get-context API. type ContextResponse struct { Messages []map[string]any `json:"messages"` } +// SkillsResponse holds the subagent skills list for get-skills API. type SkillsResponse struct { Skills []string `json:"skills"` } diff --git a/internal/version/version.go b/internal/version/version.go index d1d9857e..3b90b536 100644 --- a/internal/version/version.go +++ b/internal/version/version.go @@ -1,3 +1,6 @@ +// Package version provides application version and build info. +// +//nolint:revive package version import ( diff --git a/mise.toml b/mise.toml index a955eb41..20e10e69 100644 --- a/mise.toml +++ b/mise.toml @@ -2,7 +2,7 @@ experimental_monorepo_root = true [tools] # Go version from go.mod -go = "1.25.6" +go = "latest" # Node.js for frontend packages node = "25" # Bun for agent gateway @@ -13,6 +13,8 @@ pnpm = "10" sqlc = "latest" # typos for spell check typos = "latest" +# golangci-lint for Go linting +golangci-lint = "latest" # Lima for macOS lima = { version = "system", platform = "darwin" } diff --git a/spec/docs.go b/spec/docs.go index 17698791..198ec678 100644 --- a/spec/docs.go +++ b/spec/docs.go @@ -2582,7 +2582,7 @@ const docTemplate = `{ "200": { "description": "OK", "schema": { - "$ref": "#/definitions/channel.ChannelConfig" + "$ref": "#/definitions/channel.Config" } }, "400": { @@ -2646,7 +2646,7 @@ const docTemplate = `{ "200": { "description": "OK", "schema": { - "$ref": "#/definitions/channel.ChannelConfig" + "$ref": "#/definitions/channel.Config" } }, "400": { @@ -4434,7 +4434,7 @@ const docTemplate = `{ "200": { "description": "OK", "schema": { - "$ref": "#/definitions/channel.ChannelIdentityBinding" + "$ref": "#/definitions/channel.IdentityBinding" } }, "400": { @@ -4485,7 +4485,7 @@ const docTemplate = `{ "200": { "description": "OK", "schema": { - "$ref": "#/definitions/channel.ChannelIdentityBinding" + "$ref": "#/definitions/channel.IdentityBinding" } }, "400": { @@ -5125,7 +5125,7 @@ const docTemplate = `{ "AttachmentGIF" ] }, - "channel.ChannelCapabilities": { + "channel.Capabilities": { "type": "object", "properties": { "attachments": { @@ -5181,7 +5181,7 @@ const docTemplate = `{ } } }, - "channel.ChannelConfig": { + "channel.Config": { "type": "object", "properties": { "bot_id": { @@ -5222,30 +5222,6 @@ const docTemplate = `{ } } }, - "channel.ChannelIdentityBinding": { - "type": "object", - "properties": { - "channel_identity_id": { - "type": "string" - }, - "channel_type": { - "type": "string" - }, - "config": { - "type": "object", - "additionalProperties": {} - }, - "created_at": { - "type": "string" - }, - "id": { - "type": "string" - }, - "updated_at": { - "type": "string" - } - } - }, "channel.ConfigSchema": { "type": "object", "properties": { @@ -5301,6 +5277,30 @@ const docTemplate = `{ "FieldEnum" ] }, + "channel.IdentityBinding": { + "type": "object", + "properties": { + "channel_identity_id": { + "type": "string" + }, + "channel_type": { + "type": "string" + }, + "config": { + "type": "object", + "additionalProperties": {} + }, + "created_at": { + "type": "string" + }, + "id": { + "type": "string" + }, + "updated_at": { + "type": "string" + } + } + }, "channel.Message": { "type": "object", "properties": { @@ -5559,7 +5559,7 @@ const docTemplate = `{ "type": "object", "properties": { "capabilities": { - "$ref": "#/definitions/channel.ChannelCapabilities" + "$ref": "#/definitions/channel.Capabilities" }, "config_schema": { "$ref": "#/definitions/channel.ConfigSchema" @@ -6066,7 +6066,7 @@ const docTemplate = `{ "mcpServers": { "type": "object", "additionalProperties": { - "$ref": "#/definitions/mcp.MCPServerEntry" + "$ref": "#/definitions/mcp.ServerEntry" } } } @@ -6077,7 +6077,7 @@ const docTemplate = `{ "mcpServers": { "type": "object", "additionalProperties": { - "$ref": "#/definitions/mcp.MCPServerEntry" + "$ref": "#/definitions/mcp.ServerEntry" } } } @@ -6093,7 +6093,7 @@ const docTemplate = `{ } } }, - "mcp.MCPServerEntry": { + "mcp.ServerEntry": { "type": "object", "properties": { "args": { @@ -6197,7 +6197,7 @@ const docTemplate = `{ "results": { "type": "array", "items": { - "$ref": "#/definitions/memory.MemoryItem" + "$ref": "#/definitions/memory.Item" } } } @@ -6210,7 +6210,7 @@ const docTemplate = `{ } } }, - "memory.MemoryItem": { + "memory.Item": { "type": "object", "properties": { "agent_id": { @@ -6296,7 +6296,7 @@ const docTemplate = `{ "results": { "type": "array", "items": { - "$ref": "#/definitions/memory.MemoryItem" + "$ref": "#/definitions/memory.Item" } } } @@ -7042,7 +7042,7 @@ const docTemplate = `{ } }` -// SwaggerInfo holds exported Swagger Info so clients can modify it +// SwaggerInfo holds exported Swagger Info so clients can modify it. var SwaggerInfo = &swag.Spec{ Version: "1.0.0", Host: "", diff --git a/spec/swagger.json b/spec/swagger.json index e9570cac..f4815e9d 100644 --- a/spec/swagger.json +++ b/spec/swagger.json @@ -2573,7 +2573,7 @@ "200": { "description": "OK", "schema": { - "$ref": "#/definitions/channel.ChannelConfig" + "$ref": "#/definitions/channel.Config" } }, "400": { @@ -2637,7 +2637,7 @@ "200": { "description": "OK", "schema": { - "$ref": "#/definitions/channel.ChannelConfig" + "$ref": "#/definitions/channel.Config" } }, "400": { @@ -4425,7 +4425,7 @@ "200": { "description": "OK", "schema": { - "$ref": "#/definitions/channel.ChannelIdentityBinding" + "$ref": "#/definitions/channel.IdentityBinding" } }, "400": { @@ -4476,7 +4476,7 @@ "200": { "description": "OK", "schema": { - "$ref": "#/definitions/channel.ChannelIdentityBinding" + "$ref": "#/definitions/channel.IdentityBinding" } }, "400": { @@ -5116,7 +5116,7 @@ "AttachmentGIF" ] }, - "channel.ChannelCapabilities": { + "channel.Capabilities": { "type": "object", "properties": { "attachments": { @@ -5172,7 +5172,7 @@ } } }, - "channel.ChannelConfig": { + "channel.Config": { "type": "object", "properties": { "bot_id": { @@ -5213,30 +5213,6 @@ } } }, - "channel.ChannelIdentityBinding": { - "type": "object", - "properties": { - "channel_identity_id": { - "type": "string" - }, - "channel_type": { - "type": "string" - }, - "config": { - "type": "object", - "additionalProperties": {} - }, - "created_at": { - "type": "string" - }, - "id": { - "type": "string" - }, - "updated_at": { - "type": "string" - } - } - }, "channel.ConfigSchema": { "type": "object", "properties": { @@ -5292,6 +5268,30 @@ "FieldEnum" ] }, + "channel.IdentityBinding": { + "type": "object", + "properties": { + "channel_identity_id": { + "type": "string" + }, + "channel_type": { + "type": "string" + }, + "config": { + "type": "object", + "additionalProperties": {} + }, + "created_at": { + "type": "string" + }, + "id": { + "type": "string" + }, + "updated_at": { + "type": "string" + } + } + }, "channel.Message": { "type": "object", "properties": { @@ -5550,7 +5550,7 @@ "type": "object", "properties": { "capabilities": { - "$ref": "#/definitions/channel.ChannelCapabilities" + "$ref": "#/definitions/channel.Capabilities" }, "config_schema": { "$ref": "#/definitions/channel.ConfigSchema" @@ -6057,7 +6057,7 @@ "mcpServers": { "type": "object", "additionalProperties": { - "$ref": "#/definitions/mcp.MCPServerEntry" + "$ref": "#/definitions/mcp.ServerEntry" } } } @@ -6068,7 +6068,7 @@ "mcpServers": { "type": "object", "additionalProperties": { - "$ref": "#/definitions/mcp.MCPServerEntry" + "$ref": "#/definitions/mcp.ServerEntry" } } } @@ -6084,7 +6084,7 @@ } } }, - "mcp.MCPServerEntry": { + "mcp.ServerEntry": { "type": "object", "properties": { "args": { @@ -6188,7 +6188,7 @@ "results": { "type": "array", "items": { - "$ref": "#/definitions/memory.MemoryItem" + "$ref": "#/definitions/memory.Item" } } } @@ -6201,7 +6201,7 @@ } } }, - "memory.MemoryItem": { + "memory.Item": { "type": "object", "properties": { "agent_id": { @@ -6287,7 +6287,7 @@ "results": { "type": "array", "items": { - "$ref": "#/definitions/memory.MemoryItem" + "$ref": "#/definitions/memory.Item" } } } diff --git a/spec/swagger.yaml b/spec/swagger.yaml index 8b24c220..926ac512 100644 --- a/spec/swagger.yaml +++ b/spec/swagger.yaml @@ -254,7 +254,7 @@ definitions: - AttachmentVoice - AttachmentFile - AttachmentGIF - channel.ChannelCapabilities: + channel.Capabilities: properties: attachments: type: boolean @@ -291,7 +291,7 @@ definitions: unsend: type: boolean type: object - channel.ChannelConfig: + channel.Config: properties: bot_id: type: string @@ -319,22 +319,6 @@ definitions: verified_at: type: string type: object - channel.ChannelIdentityBinding: - properties: - channel_identity_id: - type: string - channel_type: - type: string - config: - additionalProperties: {} - type: object - created_at: - type: string - id: - type: string - updated_at: - type: string - type: object channel.ConfigSchema: properties: fields: @@ -374,6 +358,22 @@ definitions: - FieldBool - FieldNumber - FieldEnum + channel.IdentityBinding: + properties: + channel_identity_id: + type: string + channel_type: + type: string + config: + additionalProperties: {} + type: object + created_at: + type: string + id: + type: string + updated_at: + type: string + type: object channel.Message: properties: actions: @@ -551,7 +551,7 @@ definitions: handlers.ChannelMeta: properties: capabilities: - $ref: '#/definitions/channel.ChannelCapabilities' + $ref: '#/definitions/channel.Capabilities' config_schema: $ref: '#/definitions/channel.ConfigSchema' configless: @@ -881,14 +881,14 @@ definitions: properties: mcpServers: additionalProperties: - $ref: '#/definitions/mcp.MCPServerEntry' + $ref: '#/definitions/mcp.ServerEntry' type: object type: object mcp.ImportRequest: properties: mcpServers: additionalProperties: - $ref: '#/definitions/mcp.MCPServerEntry' + $ref: '#/definitions/mcp.ServerEntry' type: object type: object mcp.ListResponse: @@ -898,7 +898,7 @@ definitions: $ref: '#/definitions/github_com_memohai_memoh_internal_mcp.Connection' type: array type: object - mcp.MCPServerEntry: + mcp.ServerEntry: properties: args: items: @@ -967,7 +967,7 @@ definitions: type: number results: items: - $ref: '#/definitions/memory.MemoryItem' + $ref: '#/definitions/memory.Item' type: array type: object memory.DeleteResponse: @@ -975,7 +975,7 @@ definitions: message: type: string type: object - memory.MemoryItem: + memory.Item: properties: agent_id: type: string @@ -1032,7 +1032,7 @@ definitions: type: array results: items: - $ref: '#/definitions/memory.MemoryItem' + $ref: '#/definitions/memory.Item' type: array type: object memory.TopKBucket: @@ -3244,7 +3244,7 @@ paths: "200": description: OK schema: - $ref: '#/definitions/channel.ChannelConfig' + $ref: '#/definitions/channel.Config' "400": description: Bad Request schema: @@ -3287,7 +3287,7 @@ paths: "200": description: OK schema: - $ref: '#/definitions/channel.ChannelConfig' + $ref: '#/definitions/channel.Config' "400": description: Bad Request schema: @@ -4582,7 +4582,7 @@ paths: "200": description: OK schema: - $ref: '#/definitions/channel.ChannelIdentityBinding' + $ref: '#/definitions/channel.IdentityBinding' "400": description: Bad Request schema: @@ -4616,7 +4616,7 @@ paths: "200": description: OK schema: - $ref: '#/definitions/channel.ChannelIdentityBinding' + $ref: '#/definitions/channel.IdentityBinding' "400": description: Bad Request schema: