diff --git a/cmd/picoclaw/cmd_agent.go b/cmd/picoclaw/cmd_agent.go index 8658c9d32..1c92e0b6c 100644 --- a/cmd/picoclaw/cmd_agent.go +++ b/cmd/picoclaw/cmd_agent.go @@ -70,6 +70,7 @@ func agentCmd() { } msgBus := bus.NewMessageBus() + defer msgBus.Close() agentLoop := agent.NewAgentLoop(cfg, msgBus, provider) // Print agent startup info (only for interactive mode) diff --git a/cmd/picoclaw/cmd_gateway.go b/cmd/picoclaw/cmd_gateway.go index 28ef76ad3..798ad2813 100644 --- a/cmd/picoclaw/cmd_gateway.go +++ b/cmd/picoclaw/cmd_gateway.go @@ -6,26 +6,36 @@ package main import ( "context" "fmt" - "net/http" "os" "os/signal" "path/filepath" - "strings" "time" "github.com/sipeed/picoclaw/pkg/agent" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" + _ "github.com/sipeed/picoclaw/pkg/channels/dingtalk" + _ "github.com/sipeed/picoclaw/pkg/channels/discord" + _ "github.com/sipeed/picoclaw/pkg/channels/feishu" + _ "github.com/sipeed/picoclaw/pkg/channels/line" + _ "github.com/sipeed/picoclaw/pkg/channels/maixcam" + _ "github.com/sipeed/picoclaw/pkg/channels/onebot" + _ "github.com/sipeed/picoclaw/pkg/channels/pico" + _ "github.com/sipeed/picoclaw/pkg/channels/qq" + _ "github.com/sipeed/picoclaw/pkg/channels/slack" + _ "github.com/sipeed/picoclaw/pkg/channels/telegram" + _ "github.com/sipeed/picoclaw/pkg/channels/wecom" + _ "github.com/sipeed/picoclaw/pkg/channels/whatsapp" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/cron" "github.com/sipeed/picoclaw/pkg/devices" "github.com/sipeed/picoclaw/pkg/health" "github.com/sipeed/picoclaw/pkg/heartbeat" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/state" "github.com/sipeed/picoclaw/pkg/tools" - "github.com/sipeed/picoclaw/pkg/voice" ) func gatewayCmd() { @@ -112,50 +122,18 @@ func gatewayCmd() { return tools.SilentResult(response) }) - channelManager, err := channels.NewManager(cfg, msgBus) + // Create media store for file lifecycle management + mediaStore := media.NewFileMediaStore() + + channelManager, err := channels.NewManager(cfg, msgBus, mediaStore) if err != nil { fmt.Printf("Error creating channel manager: %v\n", err) os.Exit(1) } - // Inject channel manager into agent loop for command handling + // Inject channel manager and media store into agent loop agentLoop.SetChannelManager(channelManager) - - var transcriber *voice.GroqTranscriber - groqAPIKey := cfg.Providers.Groq.APIKey - if groqAPIKey == "" { - for _, mc := range cfg.ModelList { - if strings.HasPrefix(mc.Model, "groq/") && mc.APIKey != "" { - groqAPIKey = mc.APIKey - break - } - } - } - if groqAPIKey != "" { - transcriber = voice.NewGroqTranscriber(groqAPIKey) - logger.InfoC("voice", "Groq voice transcription enabled") - } - - if transcriber != nil { - if telegramChannel, ok := channelManager.GetChannel("telegram"); ok { - if tc, ok := telegramChannel.(*channels.TelegramChannel); ok { - tc.SetTranscriber(transcriber) - logger.InfoC("voice", "Groq transcription attached to Telegram channel") - } - } - if discordChannel, ok := channelManager.GetChannel("discord"); ok { - if dc, ok := discordChannel.(*channels.DiscordChannel); ok { - dc.SetTranscriber(transcriber) - logger.InfoC("voice", "Groq transcription attached to Discord channel") - } - } - if slackChannel, ok := channelManager.GetChannel("slack"); ok { - if sc, ok := slackChannel.(*channels.SlackChannel); ok { - sc.SetTranscriber(transcriber) - logger.InfoC("voice", "Groq transcription attached to Slack channel") - } - } - } + agentLoop.SetMediaStore(mediaStore) enabledChannels := channelManager.GetEnabledChannels() if len(enabledChannels) > 0 { @@ -192,16 +170,15 @@ func gatewayCmd() { fmt.Println("✓ Device event service started") } + // Setup shared HTTP server with health endpoints and webhook handlers + healthServer := health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port) + addr := fmt.Sprintf("%s:%d", cfg.Gateway.Host, cfg.Gateway.Port) + channelManager.SetupHTTPServer(addr, healthServer) + if err := channelManager.StartAll(ctx); err != nil { fmt.Printf("Error starting channels: %v\n", err) } - healthServer := health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port) - go func() { - if err := healthServer.Start(); err != nil && err != http.ErrServerClosed { - logger.ErrorCF("health", "Health server error", map[string]any{"error": err.Error()}) - } - }() fmt.Printf("✓ Health endpoints available at http://%s:%d/health and /ready\n", cfg.Gateway.Host, cfg.Gateway.Port) go agentLoop.Run(ctx) @@ -212,12 +189,18 @@ func gatewayCmd() { fmt.Println("\nShutting down...") cancel() - healthServer.Stop(context.Background()) + msgBus.Close() + + // Use a fresh context with timeout for graceful shutdown, + // since the original ctx is already cancelled. + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 15*time.Second) + defer shutdownCancel() + + channelManager.StopAll(shutdownCtx) deviceService.Stop() heartbeatService.Stop() cronService.Stop() agentLoop.Stop() - channelManager.StopAll(ctx) fmt.Println("✓ Gateway stopped") } diff --git a/config/config.example.json b/config/config.example.json index 555509732..149247526 100644 --- a/config/config.example.json +++ b/config/config.example.json @@ -52,30 +52,35 @@ "proxy": "", "allow_from": [ "YOUR_USER_ID" - ] + ], + "reasoning_channel_id": "" }, "discord": { "enabled": false, "token": "YOUR_DISCORD_BOT_TOKEN", "allow_from": [], - "mention_only": false + "mention_only": false, + "reasoning_channel_id": "" }, "qq": { "enabled": false, "app_id": "YOUR_QQ_APP_ID", "app_secret": "YOUR_QQ_APP_SECRET", - "allow_from": [] + "allow_from": [], + "reasoning_channel_id": "" }, "maixcam": { "enabled": false, "host": "0.0.0.0", "port": 18790, - "allow_from": [] + "allow_from": [], + "reasoning_channel_id": "" }, "whatsapp": { "enabled": false, "bridge_url": "ws://localhost:3001", - "allow_from": [] + "allow_from": [], + "reasoning_channel_id": "" }, "feishu": { "enabled": false, @@ -83,19 +88,22 @@ "app_secret": "", "encrypt_key": "", "verification_token": "", - "allow_from": [] + "allow_from": [], + "reasoning_channel_id": "" }, "dingtalk": { "enabled": false, "client_id": "YOUR_CLIENT_ID", "client_secret": "YOUR_CLIENT_SECRET", - "allow_from": [] + "allow_from": [], + "reasoning_channel_id": "" }, "slack": { "enabled": false, "bot_token": "xoxb-YOUR-BOT-TOKEN", "app_token": "xapp-YOUR-APP-TOKEN", - "allow_from": [] + "allow_from": [], + "reasoning_channel_id": "" }, "line": { "enabled": false, @@ -104,7 +112,8 @@ "webhook_host": "0.0.0.0", "webhook_port": 18791, "webhook_path": "/webhook/line", - "allow_from": [] + "allow_from": [], + "reasoning_channel_id": "" }, "onebot": { "enabled": false, @@ -112,7 +121,8 @@ "access_token": "", "reconnect_interval": 5, "group_trigger_prefix": [], - "allow_from": [] + "allow_from": [], + "reasoning_channel_id": "" }, "wecom": { "_comment": "WeCom Bot (智能机器人) - Easier setup, supports group chats", @@ -124,7 +134,8 @@ "webhook_port": 18793, "webhook_path": "/webhook/wecom", "allow_from": [], - "reply_timeout": 5 + "reply_timeout": 5, + "reasoning_channel_id": "" }, "wecom_app": { "_comment": "WeCom App (自建应用) - More features, proactive messaging, private chat only. See docs/wecom-app-configuration.md", @@ -138,7 +149,8 @@ "webhook_port": 18792, "webhook_path": "/webhook/wecom-app", "allow_from": [], - "reply_timeout": 5 + "reply_timeout": 5, + "reasoning_channel_id": "" } }, "providers": { @@ -250,4 +262,4 @@ "host": "127.0.0.1", "port": 18790 } -} +} \ No newline at end of file diff --git a/go.mod b/go.mod index 1f88639c8..32436ce53 100644 --- a/go.mod +++ b/go.mod @@ -23,6 +23,7 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/time v0.14.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 0e95bf5cd..0e2d37cab 100644 --- a/go.sum +++ b/go.sum @@ -226,6 +226,8 @@ golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= +golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= +golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index bf229ad74..e2bd222b9 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -10,6 +10,7 @@ import ( "context" "encoding/json" "fmt" + "path/filepath" "strings" "sync" "sync/atomic" @@ -21,6 +22,7 @@ import ( "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/constants" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/routing" "github.com/sipeed/picoclaw/pkg/skills" @@ -38,6 +40,7 @@ type AgentLoop struct { summarizing sync.Map fallback *providers.FallbackChain channelManager *channels.Manager + mediaStore media.MediaStore } // processOptions configures how a message is processed @@ -118,12 +121,13 @@ func registerSharedTools( // Message tool messageTool := tools.NewMessageTool() messageTool.SetSendCallback(func(channel, chatID, content string) error { - msgBus.PublishOutbound(bus.OutboundMessage{ + pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer pubCancel() + return msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{ Channel: channel, ChatID: chatID, Content: content, }) - return nil }) agent.Tools.Register(messageTool) @@ -167,33 +171,49 @@ func (al *AgentLoop) Run(ctx context.Context) error { continue } - response, err := al.processMessage(ctx, msg) - if err != nil { - response = fmt.Sprintf("Error processing message: %v", err) - } + // Process message + func() { + // TODO: Re-enable media cleanup after inbound media is properly consumed by the agent. + // Currently disabled because files are deleted before the LLM can access their content. + // defer func() { + // if al.mediaStore != nil && msg.MediaScope != "" { + // if releaseErr := al.mediaStore.ReleaseAll(msg.MediaScope); releaseErr != nil { + // logger.WarnCF("agent", "Failed to release media", map[string]any{ + // "scope": msg.MediaScope, + // "error": releaseErr.Error(), + // }) + // } + // } + // }() + + response, err := al.processMessage(ctx, msg) + if err != nil { + response = fmt.Sprintf("Error processing message: %v", err) + } - if response != "" { - // Check if the message tool already sent a response during this round. - // If so, skip publishing to avoid duplicate messages to the user. - // Use default agent's tools to check (message tool is shared). - alreadySent := false - defaultAgent := al.registry.GetDefaultAgent() - if defaultAgent != nil { - if tool, ok := defaultAgent.Tools.Get("message"); ok { - if mt, ok := tool.(*tools.MessageTool); ok { - alreadySent = mt.HasSentInRound() + if response != "" { + // Check if the message tool already sent a response during this round. + // If so, skip publishing to avoid duplicate messages to the user. + // Use default agent's tools to check (message tool is shared). + alreadySent := false + defaultAgent := al.registry.GetDefaultAgent() + if defaultAgent != nil { + if tool, ok := defaultAgent.Tools.Get("message"); ok { + if mt, ok := tool.(*tools.MessageTool); ok { + alreadySent = mt.HasSentInRound() + } } } - } - if !alreadySent { - al.bus.PublishOutbound(bus.OutboundMessage{ - Channel: msg.Channel, - ChatID: msg.ChatID, - Content: response, - }) + if !alreadySent { + al.bus.PublishOutbound(ctx, bus.OutboundMessage{ + Channel: msg.Channel, + ChatID: msg.ChatID, + Content: response, + }) + } } - } + }() } } @@ -216,6 +236,41 @@ func (al *AgentLoop) SetChannelManager(cm *channels.Manager) { al.channelManager = cm } +// SetMediaStore injects a MediaStore for media lifecycle management. +func (al *AgentLoop) SetMediaStore(s media.MediaStore) { + al.mediaStore = s +} + +// inferMediaType determines the media type ("image", "audio", "video", "file") +// from a filename and MIME content type. +func inferMediaType(filename, contentType string) string { + ct := strings.ToLower(contentType) + fn := strings.ToLower(filename) + + if strings.HasPrefix(ct, "image/") { + return "image" + } + if strings.HasPrefix(ct, "audio/") || ct == "application/ogg" { + return "audio" + } + if strings.HasPrefix(ct, "video/") { + return "video" + } + + // Fallback: infer from extension + ext := filepath.Ext(fn) + switch ext { + case ".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".svg": + return "image" + case ".mp3", ".wav", ".ogg", ".m4a", ".flac", ".aac", ".wma", ".opus": + return "audio" + case ".mp4", ".avi", ".mov", ".webm", ".mkv": + return "video" + } + + return "file" +} + // RecordLastChannel records the last active channel for this workspace. // This uses the atomic state save mechanism to prevent data loss on crash. func (al *AgentLoop) RecordLastChannel(channel string) error { @@ -450,7 +505,7 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opt // 8. Optional: send response via bus if opts.SendResponse { - al.bus.PublishOutbound(bus.OutboundMessage{ + al.bus.PublishOutbound(ctx, bus.OutboundMessage{ Channel: opts.Channel, ChatID: opts.ChatID, Content: finalContent, @@ -470,6 +525,28 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opt return finalContent, nil } +func (al *AgentLoop) targetReasoningChannelID(channelName string) (chatID string) { + if al.channelManager == nil { + return "" + } + if ch, ok := al.channelManager.GetChannel(channelName); ok { + return ch.ReasoningChannelID() + } + return "" +} + +func (al *AgentLoop) handleReasoning(ctx context.Context, reasoningContent, channelName, channelID string) { + if reasoningContent == "" || channelName == "" || channelID == "" { + return + } + + al.bus.PublishOutbound(ctx, bus.OutboundMessage{ + Channel: channelName, + ChatID: channelID, + Content: reasoningContent, + }) +} + // runLLMIteration executes the LLM call loop with tool handling. func (al *AgentLoop) runLLMIteration( ctx context.Context, @@ -565,7 +642,7 @@ func (al *AgentLoop) runLLMIteration( }) if retry == 0 && !constants.IsInternalChannel(opts.Channel) { - al.bus.PublishOutbound(bus.OutboundMessage{ + al.bus.PublishOutbound(ctx, bus.OutboundMessage{ Channel: opts.Channel, ChatID: opts.ChatID, Content: "Context window exceeded. Compressing history and retrying...", @@ -594,6 +671,18 @@ func (al *AgentLoop) runLLMIteration( return "", iteration, fmt.Errorf("LLM call failed after retries: %w", err) } + go al.handleReasoning(ctx, response.Reasoning, opts.Channel, al.targetReasoningChannelID(opts.Channel)) + + logger.DebugCF("agent", "LLM response", + map[string]any{ + "agent_id": agent.ID, + "iteration": iteration, + "content_chars": len(response.Content), + "tool_calls": len(response.ToolCalls), + "reasoning": response.Reasoning, + "target_channel": al.targetReasoningChannelID(opts.Channel), + "channel": opts.Channel, + }) // Check if no tool calls - we're done if len(response.ToolCalls) == 0 { finalContent = response.Content @@ -694,7 +783,7 @@ func (al *AgentLoop) runLLMIteration( // Send ForUser content to user immediately if not Silent if !toolResult.Silent && toolResult.ForUser != "" && opts.SendResponse { - al.bus.PublishOutbound(bus.OutboundMessage{ + al.bus.PublishOutbound(ctx, bus.OutboundMessage{ Channel: opts.Channel, ChatID: opts.ChatID, Content: toolResult.ForUser, @@ -706,6 +795,28 @@ func (al *AgentLoop) runLLMIteration( }) } + // If tool returned media refs, publish them as outbound media + if len(toolResult.Media) > 0 && opts.SendResponse { + parts := make([]bus.MediaPart, 0, len(toolResult.Media)) + for _, ref := range toolResult.Media { + part := bus.MediaPart{Ref: ref} + // Populate metadata from MediaStore when available + if al.mediaStore != nil { + if _, meta, err := al.mediaStore.ResolveWithMeta(ref); err == nil { + part.Filename = meta.Filename + part.ContentType = meta.ContentType + part.Type = inferMediaType(meta.Filename, meta.ContentType) + } + } + parts = append(parts, part) + } + al.bus.PublishOutboundMedia(ctx, bus.OutboundMediaMessage{ + Channel: opts.Channel, + ChatID: opts.ChatID, + Parts: parts, + }) + } + // Determine content for LLM based on tool result contentForLLM := toolResult.ForLLM if contentForLLM == "" && toolResult.Err != nil { @@ -759,7 +870,9 @@ func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, c go func() { defer al.summarizing.Delete(summarizeKey) if !constants.IsInternalChannel(channel) { - al.bus.PublishOutbound(bus.OutboundMessage{ + pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer pubCancel() + al.bus.PublishOutbound(pubCtx, bus.OutboundMessage{ Channel: channel, ChatID: chatID, Content: "Memory threshold reached. Optimizing conversation history...", @@ -1122,21 +1235,20 @@ func (al *AgentLoop) handleCommand(ctx context.Context, msg bus.InboundMessage) return "", false } -// extractPeer extracts the routing peer from inbound message metadata. +// extractPeer extracts the routing peer from the inbound message's structured Peer field. func extractPeer(msg bus.InboundMessage) *routing.RoutePeer { - peerKind := msg.Metadata["peer_kind"] - if peerKind == "" { + if msg.Peer.Kind == "" { return nil } - peerID := msg.Metadata["peer_id"] + peerID := msg.Peer.ID if peerID == "" { - if peerKind == "direct" { + if msg.Peer.Kind == "direct" { peerID = msg.SenderID } else { peerID = msg.ChatID } } - return &routing.RoutePeer{Kind: peerKind, ID: peerID} + return &routing.RoutePeer{Kind: msg.Peer.Kind, ID: peerID} } // extractParentPeer extracts the parent peer (reply-to) from inbound message metadata. diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index 4414398b1..6dfc7ef3e 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -9,11 +9,23 @@ import ( "time" "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/tools" ) +type fakeChannel struct{ id string } + +func (f *fakeChannel) Name() string { return "fake" } +func (f *fakeChannel) Start(ctx context.Context) error { return nil } +func (f *fakeChannel) Stop(ctx context.Context) error { return nil } +func (f *fakeChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { return nil } +func (f *fakeChannel) IsRunning() bool { return true } +func (f *fakeChannel) IsAllowed(string) bool { return true } +func (f *fakeChannel) IsAllowedSender(sender bus.SenderInfo) bool { return true } +func (f *fakeChannel) ReasoningChannelID() string { return f.id } + func TestRecordLastChannel(t *testing.T) { // Create temp workspace tmpDir, err := os.MkdirTemp("", "agent-test-*") @@ -631,3 +643,158 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) { t.Errorf("Expected history to be compressed (len < 8), got %d", len(finalHistory)) } } + +func TestTargetReasoningChannelID_AllChannels(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + al := NewAgentLoop(cfg, bus.NewMessageBus(), &mockProvider{}) + chManager, err := channels.NewManager(&config.Config{}, bus.NewMessageBus(), nil) + if err != nil { + t.Fatalf("Failed to create channel manager: %v", err) + } + for name, id := range map[string]string{ + "whatsapp": "rid-whatsapp", + "telegram": "rid-telegram", + "feishu": "rid-feishu", + "discord": "rid-discord", + "maixcam": "rid-maixcam", + "qq": "rid-qq", + "dingtalk": "rid-dingtalk", + "slack": "rid-slack", + "line": "rid-line", + "onebot": "rid-onebot", + "wecom": "rid-wecom", + "wecom_app": "rid-wecom-app", + } { + chManager.RegisterChannel(name, &fakeChannel{id: id}) + } + al.SetChannelManager(chManager) + tests := []struct { + channel string + wantID string + }{ + {channel: "whatsapp", wantID: "rid-whatsapp"}, + {channel: "telegram", wantID: "rid-telegram"}, + {channel: "feishu", wantID: "rid-feishu"}, + {channel: "discord", wantID: "rid-discord"}, + {channel: "maixcam", wantID: "rid-maixcam"}, + {channel: "qq", wantID: "rid-qq"}, + {channel: "dingtalk", wantID: "rid-dingtalk"}, + {channel: "slack", wantID: "rid-slack"}, + {channel: "line", wantID: "rid-line"}, + {channel: "onebot", wantID: "rid-onebot"}, + {channel: "wecom", wantID: "rid-wecom"}, + {channel: "wecom_app", wantID: "rid-wecom-app"}, + {channel: "unknown", wantID: ""}, + } + + for _, tt := range tests { + t.Run(tt.channel, func(t *testing.T) { + got := al.targetReasoningChannelID(tt.channel) + if got != tt.wantID { + t.Fatalf("targetReasoningChannelID(%q) = %q, want %q", tt.channel, got, tt.wantID) + } + }) + } +} + +func TestHandleReasoning(t *testing.T) { + newLoop := func(t *testing.T) (*AgentLoop, *bus.MessageBus) { + t.Helper() + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + t.Cleanup(func() { _ = os.RemoveAll(tmpDir) }) + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + msgBus := bus.NewMessageBus() + return NewAgentLoop(cfg, msgBus, &mockProvider{}), msgBus + } + + t.Run("skips when any required field is empty", func(t *testing.T) { + al, msgBus := newLoop(t) + al.handleReasoning(context.Background(), "reasoning", "telegram", "") + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) + defer cancel() + if msg, ok := msgBus.SubscribeOutbound(ctx); ok { + t.Fatalf("expected no outbound message, got %+v", msg) + } + }) + + t.Run("publishes one message for non telegram", func(t *testing.T) { + al, msgBus := newLoop(t) + al.handleReasoning(context.Background(), "hello reasoning", "slack", "channel-1") + + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + msg, ok := msgBus.SubscribeOutbound(ctx) + if !ok { + t.Fatal("expected an outbound message") + } + if msg.Channel != "slack" || msg.ChatID != "channel-1" || msg.Content != "hello reasoning" { + t.Fatalf("unexpected outbound message: %+v", msg) + } + }) + + t.Run("publishes one message for telegram", func(t *testing.T) { + al, msgBus := newLoop(t) + reasoning := "hello telegram reasoning" + al.handleReasoning(context.Background(), reasoning, "telegram", "tg-chat") + + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + msg, ok := msgBus.SubscribeOutbound(ctx) + if !ok { + t.Fatal("expected outbound message") + } + + if msg.Channel != "telegram" { + t.Fatalf("expected telegram channel message, got %+v", msg) + } + if msg.ChatID != "tg-chat" { + t.Fatalf("expected chatID tg-chat, got %+v", msg) + } + if msg.Content != reasoning { + t.Fatalf("content mismatch: got %q want %q", msg.Content, reasoning) + } + }) + t.Run("expired ctx", func(t *testing.T) { + al, msgBus := newLoop(t) + reasoning := "hello telegram reasoning" + ctx, cancel := context.WithCancel(context.Background()) + cancel() + al.handleReasoning(ctx, reasoning, "telegram", "tg-chat") + + ctx, cancel = context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + msg, ok := msgBus.SubscribeOutbound(ctx) + if ok { + t.Fatalf("expected no outbound message, got %+v", msg) + } + }) +} diff --git a/pkg/bus/bus.go b/pkg/bus/bus.go index 58c0a25d5..c749b6535 100644 --- a/pkg/bus/bus.go +++ b/pkg/bus/bus.go @@ -2,81 +2,154 @@ package bus import ( "context" - "sync" + "errors" + "sync/atomic" + + "github.com/sipeed/picoclaw/pkg/logger" ) +// ErrBusClosed is returned when publishing to a closed MessageBus. +var ErrBusClosed = errors.New("message bus closed") + type MessageBus struct { - inbound chan InboundMessage - outbound chan OutboundMessage - handlers map[string]MessageHandler - closed bool - mu sync.RWMutex + inbound chan InboundMessage + outbound chan OutboundMessage + outboundMedia chan OutboundMediaMessage + done chan struct{} + closed atomic.Bool } func NewMessageBus() *MessageBus { return &MessageBus{ - inbound: make(chan InboundMessage, 100), - outbound: make(chan OutboundMessage, 100), - handlers: make(map[string]MessageHandler), + inbound: make(chan InboundMessage, 100), + outbound: make(chan OutboundMessage, 100), + outboundMedia: make(chan OutboundMediaMessage, 100), + done: make(chan struct{}), } } -func (mb *MessageBus) PublishInbound(msg InboundMessage) { - mb.mu.RLock() - defer mb.mu.RUnlock() - if mb.closed { - return +func (mb *MessageBus) PublishInbound(ctx context.Context, msg InboundMessage) error { + if mb.closed.Load() { + return ErrBusClosed + } + if err := ctx.Err(); err != nil { + return err + } + select { + case mb.inbound <- msg: + return nil + case <-mb.done: + return ErrBusClosed + case <-ctx.Done(): + return ctx.Err() } - mb.inbound <- msg } func (mb *MessageBus) ConsumeInbound(ctx context.Context) (InboundMessage, bool) { select { - case msg := <-mb.inbound: - return msg, true + case msg, ok := <-mb.inbound: + return msg, ok + case <-mb.done: + return InboundMessage{}, false case <-ctx.Done(): return InboundMessage{}, false } } -func (mb *MessageBus) PublishOutbound(msg OutboundMessage) { - mb.mu.RLock() - defer mb.mu.RUnlock() - if mb.closed { - return +func (mb *MessageBus) PublishOutbound(ctx context.Context, msg OutboundMessage) error { + if mb.closed.Load() { + return ErrBusClosed + } + if err := ctx.Err(); err != nil { + return err + } + select { + case mb.outbound <- msg: + return nil + case <-mb.done: + return ErrBusClosed + case <-ctx.Done(): + return ctx.Err() } - mb.outbound <- msg } func (mb *MessageBus) SubscribeOutbound(ctx context.Context) (OutboundMessage, bool) { select { - case msg := <-mb.outbound: - return msg, true + case msg, ok := <-mb.outbound: + return msg, ok + case <-mb.done: + return OutboundMessage{}, false case <-ctx.Done(): return OutboundMessage{}, false } } -func (mb *MessageBus) RegisterHandler(channel string, handler MessageHandler) { - mb.mu.Lock() - defer mb.mu.Unlock() - mb.handlers[channel] = handler +func (mb *MessageBus) PublishOutboundMedia(ctx context.Context, msg OutboundMediaMessage) error { + if mb.closed.Load() { + return ErrBusClosed + } + if err := ctx.Err(); err != nil { + return err + } + select { + case mb.outboundMedia <- msg: + return nil + case <-mb.done: + return ErrBusClosed + case <-ctx.Done(): + return ctx.Err() + } } -func (mb *MessageBus) GetHandler(channel string) (MessageHandler, bool) { - mb.mu.RLock() - defer mb.mu.RUnlock() - handler, ok := mb.handlers[channel] - return handler, ok +func (mb *MessageBus) SubscribeOutboundMedia(ctx context.Context) (OutboundMediaMessage, bool) { + select { + case msg, ok := <-mb.outboundMedia: + return msg, ok + case <-mb.done: + return OutboundMediaMessage{}, false + case <-ctx.Done(): + return OutboundMediaMessage{}, false + } } func (mb *MessageBus) Close() { - mb.mu.Lock() - defer mb.mu.Unlock() - if mb.closed { - return + if mb.closed.CompareAndSwap(false, true) { + close(mb.done) + + // Drain buffered channels so messages aren't silently lost. + // Channels are NOT closed to avoid send-on-closed panics from concurrent publishers. + drained := 0 + for { + select { + case <-mb.inbound: + drained++ + default: + goto doneInbound + } + } + doneInbound: + for { + select { + case <-mb.outbound: + drained++ + default: + goto doneOutbound + } + } + doneOutbound: + for { + select { + case <-mb.outboundMedia: + drained++ + default: + goto doneMedia + } + } + doneMedia: + if drained > 0 { + logger.DebugCF("bus", "Drained buffered messages during close", map[string]any{ + "count": drained, + }) + } } - mb.closed = true - close(mb.inbound) - close(mb.outbound) } diff --git a/pkg/bus/bus_test.go b/pkg/bus/bus_test.go new file mode 100644 index 000000000..47826824e --- /dev/null +++ b/pkg/bus/bus_test.go @@ -0,0 +1,229 @@ +package bus + +import ( + "context" + "sync" + "testing" + "time" +) + +func TestPublishConsume(t *testing.T) { + mb := NewMessageBus() + defer mb.Close() + + ctx := context.Background() + + msg := InboundMessage{ + Channel: "test", + SenderID: "user1", + ChatID: "chat1", + Content: "hello", + } + + if err := mb.PublishInbound(ctx, msg); err != nil { + t.Fatalf("PublishInbound failed: %v", err) + } + + got, ok := mb.ConsumeInbound(ctx) + if !ok { + t.Fatal("ConsumeInbound returned ok=false") + } + if got.Content != "hello" { + t.Fatalf("expected content 'hello', got %q", got.Content) + } + if got.Channel != "test" { + t.Fatalf("expected channel 'test', got %q", got.Channel) + } +} + +func TestPublishOutboundSubscribe(t *testing.T) { + mb := NewMessageBus() + defer mb.Close() + + ctx := context.Background() + + msg := OutboundMessage{ + Channel: "telegram", + ChatID: "123", + Content: "world", + } + + if err := mb.PublishOutbound(ctx, msg); err != nil { + t.Fatalf("PublishOutbound failed: %v", err) + } + + got, ok := mb.SubscribeOutbound(ctx) + if !ok { + t.Fatal("SubscribeOutbound returned ok=false") + } + if got.Content != "world" { + t.Fatalf("expected content 'world', got %q", got.Content) + } +} + +func TestPublishInbound_ContextCancel(t *testing.T) { + mb := NewMessageBus() + defer mb.Close() + + // Fill the buffer + ctx := context.Background() + for i := 0; i < 100; i++ { + if err := mb.PublishInbound(ctx, InboundMessage{Content: "fill"}); err != nil { + t.Fatalf("fill failed at %d: %v", i, err) + } + } + + // Now buffer is full; publish with a cancelled context + cancelCtx, cancel := context.WithCancel(context.Background()) + cancel() + + err := mb.PublishInbound(cancelCtx, InboundMessage{Content: "overflow"}) + if err == nil { + t.Fatal("expected error from cancelled context, got nil") + } + if err != context.Canceled { + t.Fatalf("expected context.Canceled, got %v", err) + } +} + +func TestPublishInbound_BusClosed(t *testing.T) { + mb := NewMessageBus() + mb.Close() + + err := mb.PublishInbound(context.Background(), InboundMessage{Content: "test"}) + if err != ErrBusClosed { + t.Fatalf("expected ErrBusClosed, got %v", err) + } +} + +func TestPublishOutbound_BusClosed(t *testing.T) { + mb := NewMessageBus() + mb.Close() + + err := mb.PublishOutbound(context.Background(), OutboundMessage{Content: "test"}) + if err != ErrBusClosed { + t.Fatalf("expected ErrBusClosed, got %v", err) + } +} + +func TestConsumeInbound_ContextCancel(t *testing.T) { + mb := NewMessageBus() + defer mb.Close() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, ok := mb.ConsumeInbound(ctx) + if ok { + t.Fatal("expected ok=false when context is cancelled") + } +} + +func TestConsumeInbound_BusClosed(t *testing.T) { + mb := NewMessageBus() + mb.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + _, ok := mb.ConsumeInbound(ctx) + if ok { + t.Fatal("expected ok=false when bus is closed") + } +} + +func TestSubscribeOutbound_BusClosed(t *testing.T) { + mb := NewMessageBus() + mb.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + _, ok := mb.SubscribeOutbound(ctx) + if ok { + t.Fatal("expected ok=false when bus is closed") + } +} + +func TestConcurrentPublishClose(t *testing.T) { + mb := NewMessageBus() + ctx := context.Background() + + const numGoroutines = 100 + var wg sync.WaitGroup + wg.Add(numGoroutines + 1) + + // Spawn many goroutines trying to publish + for i := 0; i < numGoroutines; i++ { + go func() { + defer wg.Done() + // Use a short timeout context so we don't block forever after close + publishCtx, cancel := context.WithTimeout(ctx, 50*time.Millisecond) + defer cancel() + // Errors are expected; we just must not panic or deadlock + _ = mb.PublishInbound(publishCtx, InboundMessage{Content: "concurrent"}) + }() + } + + // Close from another goroutine + go func() { + defer wg.Done() + time.Sleep(5 * time.Millisecond) + mb.Close() + }() + + // Must complete without deadlock + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + // success + case <-time.After(5 * time.Second): + t.Fatal("test timed out - possible deadlock") + } +} + +func TestPublishInbound_FullBuffer(t *testing.T) { + mb := NewMessageBus() + defer mb.Close() + + ctx := context.Background() + + // Fill the buffer + for i := 0; i < 100; i++ { + if err := mb.PublishInbound(ctx, InboundMessage{Content: "fill"}); err != nil { + t.Fatalf("fill failed at %d: %v", i, err) + } + } + + // Buffer is full; publish with short timeout + timeoutCtx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + + err := mb.PublishInbound(timeoutCtx, InboundMessage{Content: "overflow"}) + if err == nil { + t.Fatal("expected error when buffer is full and context times out") + } + if err != context.DeadlineExceeded { + t.Fatalf("expected context.DeadlineExceeded, got %v", err) + } +} + +func TestCloseIdempotent(t *testing.T) { + mb := NewMessageBus() + + // Multiple Close calls must not panic + mb.Close() + mb.Close() + mb.Close() + + // After close, publish should return ErrBusClosed + err := mb.PublishInbound(context.Background(), InboundMessage{Content: "test"}) + if err != ErrBusClosed { + t.Fatalf("expected ErrBusClosed after multiple closes, got %v", err) + } +} diff --git a/pkg/bus/types.go b/pkg/bus/types.go index 44f9181a5..7ad8f0417 100644 --- a/pkg/bus/types.go +++ b/pkg/bus/types.go @@ -1,11 +1,30 @@ package bus +// Peer identifies the routing peer for a message (direct, group, channel, etc.) +type Peer struct { + Kind string `json:"kind"` // "direct" | "group" | "channel" | "" + ID string `json:"id"` +} + +// SenderInfo provides structured sender identity information. +type SenderInfo struct { + Platform string `json:"platform,omitempty"` // "telegram", "discord", "slack", ... + PlatformID string `json:"platform_id,omitempty"` // raw platform ID, e.g. "123456" + CanonicalID string `json:"canonical_id,omitempty"` // "platform:id" format + Username string `json:"username,omitempty"` // username (e.g. @alice) + DisplayName string `json:"display_name,omitempty"` // display name +} + type InboundMessage struct { Channel string `json:"channel"` SenderID string `json:"sender_id"` + Sender SenderInfo `json:"sender"` ChatID string `json:"chat_id"` Content string `json:"content"` Media []string `json:"media,omitempty"` + Peer Peer `json:"peer"` // routing peer + MessageID string `json:"message_id,omitempty"` // platform message ID + MediaScope string `json:"media_scope,omitempty"` // media lifecycle scope SessionKey string `json:"session_key"` Metadata map[string]string `json:"metadata,omitempty"` } @@ -16,4 +35,18 @@ type OutboundMessage struct { Content string `json:"content"` } -type MessageHandler func(InboundMessage) error +// MediaPart describes a single media attachment to send. +type MediaPart struct { + Type string `json:"type"` // "image" | "audio" | "video" | "file" + Ref string `json:"ref"` // media store ref, e.g. "media://abc123" + Caption string `json:"caption,omitempty"` // optional caption text + Filename string `json:"filename,omitempty"` // original filename hint + ContentType string `json:"content_type,omitempty"` // MIME type hint +} + +// OutboundMediaMessage carries media attachments from Agent to channels via the bus. +type OutboundMediaMessage struct { + Channel string `json:"channel"` + ChatID string `json:"chat_id"` + Parts []MediaPart `json:"parts"` +} diff --git a/pkg/channels/base.go b/pkg/channels/base.go index cd6419ebb..c8c721341 100644 --- a/pkg/channels/base.go +++ b/pkg/channels/base.go @@ -3,8 +3,15 @@ package channels import ( "context" "strings" + "sync/atomic" + + "github.com/google/uuid" "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/media" ) type Channel interface { @@ -14,32 +21,125 @@ type Channel interface { Send(ctx context.Context, msg bus.OutboundMessage) error IsRunning() bool IsAllowed(senderID string) bool + IsAllowedSender(sender bus.SenderInfo) bool + ReasoningChannelID() string +} + +// BaseChannelOption is a functional option for configuring a BaseChannel. +type BaseChannelOption func(*BaseChannel) + +// WithMaxMessageLength sets the maximum message length (in runes) for a channel. +// Messages exceeding this limit will be automatically split by the Manager. +// A value of 0 means no limit. +func WithMaxMessageLength(n int) BaseChannelOption { + return func(c *BaseChannel) { c.maxMessageLength = n } +} + +// WithGroupTrigger sets the group trigger configuration for a channel. +func WithGroupTrigger(gt config.GroupTriggerConfig) BaseChannelOption { + return func(c *BaseChannel) { c.groupTrigger = gt } +} + +// WithReasoningChannelID sets the reasoning channel ID where thoughts should be sent. +func WithReasoningChannelID(id string) BaseChannelOption { + return func(c *BaseChannel) { c.reasoningChannelID = id } +} + +// MessageLengthProvider is an opt-in interface that channels implement +// to advertise their maximum message length. The Manager uses this via +// type assertion to decide whether to split outbound messages. +type MessageLengthProvider interface { + MaxMessageLength() int } type BaseChannel struct { - config any - bus *bus.MessageBus - running bool - name string - allowList []string + config any + bus *bus.MessageBus + running atomic.Bool + name string + allowList []string + maxMessageLength int + groupTrigger config.GroupTriggerConfig + mediaStore media.MediaStore + placeholderRecorder PlaceholderRecorder + reasoningChannelID string } -func NewBaseChannel(name string, config any, bus *bus.MessageBus, allowList []string) *BaseChannel { - return &BaseChannel{ +func NewBaseChannel( + name string, + config any, + bus *bus.MessageBus, + allowList []string, + opts ...BaseChannelOption, +) *BaseChannel { + bc := &BaseChannel{ config: config, bus: bus, name: name, allowList: allowList, - running: false, } + for _, opt := range opts { + opt(bc) + } + return bc +} + +// MaxMessageLength returns the maximum message length (in runes) for this channel. +// A value of 0 means no limit. +func (c *BaseChannel) MaxMessageLength() int { + return c.maxMessageLength +} + +// ShouldRespondInGroup determines whether the bot should respond in a group chat. +// Each channel is responsible for: +// 1. Detecting isMentioned (platform-specific) +// 2. Stripping bot mention from content (platform-specific) +// 3. Calling this method to get the group response decision +// +// Logic: +// - If isMentioned → always respond +// - If mention_only configured and not mentioned → ignore +// - If prefixes configured → respond if content starts with any prefix (strip it) +// - If prefixes configured but no match and not mentioned → ignore +// - Otherwise (no group_trigger configured) → respond to all (permissive default) +func (c *BaseChannel) ShouldRespondInGroup(isMentioned bool, content string) (bool, string) { + gt := c.groupTrigger + + // Mentioned → always respond + if isMentioned { + return true, strings.TrimSpace(content) + } + + // mention_only → require mention + if gt.MentionOnly { + return false, content + } + + // Prefix matching + if len(gt.Prefixes) > 0 { + for _, prefix := range gt.Prefixes { + if prefix != "" && strings.HasPrefix(content, prefix) { + return true, strings.TrimSpace(strings.TrimPrefix(content, prefix)) + } + } + // Prefixes configured but none matched and not mentioned → ignore + return false, content + } + + // No group_trigger configured → permissive (respond to all) + return true, strings.TrimSpace(content) } func (c *BaseChannel) Name() string { return c.name } +func (c *BaseChannel) ReasoningChannelID() string { + return c.reasoningChannelID +} + func (c *BaseChannel) IsRunning() bool { - return c.running + return c.running.Load() } func (c *BaseChannel) IsAllowed(senderID string) bool { @@ -81,23 +181,101 @@ func (c *BaseChannel) IsAllowed(senderID string) bool { return false } -func (c *BaseChannel) HandleMessage(senderID, chatID, content string, media []string, metadata map[string]string) { - if !c.IsAllowed(senderID) { - return +// IsAllowedSender checks whether a structured SenderInfo is permitted by the allow-list. +// It delegates to identity.MatchAllowed for each entry, providing unified matching +// across all legacy formats and the new canonical "platform:id" format. +func (c *BaseChannel) IsAllowedSender(sender bus.SenderInfo) bool { + if len(c.allowList) == 0 { + return true } + for _, allowed := range c.allowList { + if identity.MatchAllowed(sender, allowed) { + return true + } + } + + return false +} + +func (c *BaseChannel) HandleMessage( + ctx context.Context, + peer bus.Peer, + messageID, senderID, chatID, content string, + media []string, + metadata map[string]string, + senderOpts ...bus.SenderInfo, +) { + // Use SenderInfo-based allow check when available, else fall back to string + var sender bus.SenderInfo + if len(senderOpts) > 0 { + sender = senderOpts[0] + } + if sender.CanonicalID != "" || sender.PlatformID != "" { + if !c.IsAllowedSender(sender) { + return + } + } else { + if !c.IsAllowed(senderID) { + return + } + } + + // Set SenderID to canonical if available, otherwise keep the raw senderID + resolvedSenderID := senderID + if sender.CanonicalID != "" { + resolvedSenderID = sender.CanonicalID + } + + scope := BuildMediaScope(c.name, chatID, messageID) + msg := bus.InboundMessage{ - Channel: c.name, - SenderID: senderID, - ChatID: chatID, - Content: content, - Media: media, - Metadata: metadata, + Channel: c.name, + SenderID: resolvedSenderID, + Sender: sender, + ChatID: chatID, + Content: content, + Media: media, + Peer: peer, + MessageID: messageID, + MediaScope: scope, + Metadata: metadata, } - c.bus.PublishInbound(msg) + if err := c.bus.PublishInbound(ctx, msg); err != nil { + logger.ErrorCF("channels", "Failed to publish inbound message", map[string]any{ + "channel": c.name, + "chat_id": chatID, + "error": err.Error(), + }) + } } -func (c *BaseChannel) setRunning(running bool) { - c.running = running +func (c *BaseChannel) SetRunning(running bool) { + c.running.Store(running) +} + +// SetMediaStore injects a MediaStore into the channel. +func (c *BaseChannel) SetMediaStore(s media.MediaStore) { c.mediaStore = s } + +// GetMediaStore returns the injected MediaStore (may be nil). +func (c *BaseChannel) GetMediaStore() media.MediaStore { return c.mediaStore } + +// SetPlaceholderRecorder injects a PlaceholderRecorder into the channel. +func (c *BaseChannel) SetPlaceholderRecorder(r PlaceholderRecorder) { + c.placeholderRecorder = r +} + +// GetPlaceholderRecorder returns the injected PlaceholderRecorder (may be nil). +func (c *BaseChannel) GetPlaceholderRecorder() PlaceholderRecorder { + return c.placeholderRecorder +} + +// BuildMediaScope constructs a scope key for media lifecycle tracking. +func BuildMediaScope(channel, chatID, messageID string) string { + id := messageID + if id == "" { + id = uuid.New().String() + } + return channel + ":" + chatID + ":" + id } diff --git a/pkg/channels/base_test.go b/pkg/channels/base_test.go index 78c6d1d66..6132b8bf9 100644 --- a/pkg/channels/base_test.go +++ b/pkg/channels/base_test.go @@ -1,6 +1,11 @@ package channels -import "testing" +import ( + "testing" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" +) func TestBaseChannelIsAllowed(t *testing.T) { tests := []struct { @@ -50,3 +55,211 @@ func TestBaseChannelIsAllowed(t *testing.T) { }) } } + +func TestShouldRespondInGroup(t *testing.T) { + tests := []struct { + name string + gt config.GroupTriggerConfig + isMentioned bool + content string + wantRespond bool + wantContent string + }{ + { + name: "no config - permissive default", + gt: config.GroupTriggerConfig{}, + isMentioned: false, + content: "hello world", + wantRespond: true, + wantContent: "hello world", + }, + { + name: "no config - mentioned", + gt: config.GroupTriggerConfig{}, + isMentioned: true, + content: "hello world", + wantRespond: true, + wantContent: "hello world", + }, + { + name: "mention_only - not mentioned", + gt: config.GroupTriggerConfig{MentionOnly: true}, + isMentioned: false, + content: "hello world", + wantRespond: false, + wantContent: "hello world", + }, + { + name: "mention_only - mentioned", + gt: config.GroupTriggerConfig{MentionOnly: true}, + isMentioned: true, + content: "hello world", + wantRespond: true, + wantContent: "hello world", + }, + { + name: "prefix match", + gt: config.GroupTriggerConfig{Prefixes: []string{"/ask"}}, + isMentioned: false, + content: "/ask hello", + wantRespond: true, + wantContent: "hello", + }, + { + name: "prefix no match - not mentioned", + gt: config.GroupTriggerConfig{Prefixes: []string{"/ask"}}, + isMentioned: false, + content: "hello world", + wantRespond: false, + wantContent: "hello world", + }, + { + name: "prefix no match - but mentioned", + gt: config.GroupTriggerConfig{Prefixes: []string{"/ask"}}, + isMentioned: true, + content: "hello world", + wantRespond: true, + wantContent: "hello world", + }, + { + name: "multiple prefixes - second matches", + gt: config.GroupTriggerConfig{Prefixes: []string{"/ask", "/bot"}}, + isMentioned: false, + content: "/bot help me", + wantRespond: true, + wantContent: "help me", + }, + { + name: "mention_only with prefixes - mentioned overrides", + gt: config.GroupTriggerConfig{MentionOnly: true, Prefixes: []string{"/ask"}}, + isMentioned: true, + content: "hello", + wantRespond: true, + wantContent: "hello", + }, + { + name: "mention_only with prefixes - not mentioned, no prefix", + gt: config.GroupTriggerConfig{MentionOnly: true, Prefixes: []string{"/ask"}}, + isMentioned: false, + content: "hello", + wantRespond: false, + wantContent: "hello", + }, + { + name: "empty prefix in list is skipped", + gt: config.GroupTriggerConfig{Prefixes: []string{"", "/ask"}}, + isMentioned: false, + content: "/ask test", + wantRespond: true, + wantContent: "test", + }, + { + name: "prefix strips leading whitespace after prefix", + gt: config.GroupTriggerConfig{Prefixes: []string{"/ask "}}, + isMentioned: false, + content: "/ask hello", + wantRespond: true, + wantContent: "hello", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ch := NewBaseChannel("test", nil, nil, nil, WithGroupTrigger(tt.gt)) + gotRespond, gotContent := ch.ShouldRespondInGroup(tt.isMentioned, tt.content) + if gotRespond != tt.wantRespond { + t.Errorf("ShouldRespondInGroup() respond = %v, want %v", gotRespond, tt.wantRespond) + } + if gotContent != tt.wantContent { + t.Errorf("ShouldRespondInGroup() content = %q, want %q", gotContent, tt.wantContent) + } + }) + } +} + +func TestIsAllowedSender(t *testing.T) { + tests := []struct { + name string + allowList []string + sender bus.SenderInfo + want bool + }{ + { + name: "empty allowlist allows all", + allowList: nil, + sender: bus.SenderInfo{PlatformID: "anyone"}, + want: true, + }, + { + name: "numeric ID matches PlatformID", + allowList: []string{"123456"}, + sender: bus.SenderInfo{ + Platform: "telegram", + PlatformID: "123456", + CanonicalID: "telegram:123456", + }, + want: true, + }, + { + name: "canonical format matches", + allowList: []string{"telegram:123456"}, + sender: bus.SenderInfo{ + Platform: "telegram", + PlatformID: "123456", + CanonicalID: "telegram:123456", + }, + want: true, + }, + { + name: "canonical format wrong platform", + allowList: []string{"discord:123456"}, + sender: bus.SenderInfo{ + Platform: "telegram", + PlatformID: "123456", + CanonicalID: "telegram:123456", + }, + want: false, + }, + { + name: "@username matches", + allowList: []string{"@alice"}, + sender: bus.SenderInfo{ + Platform: "telegram", + PlatformID: "123456", + CanonicalID: "telegram:123456", + Username: "alice", + }, + want: true, + }, + { + name: "compound id|username matches by ID", + allowList: []string{"123456|alice"}, + sender: bus.SenderInfo{ + Platform: "telegram", + PlatformID: "123456", + CanonicalID: "telegram:123456", + Username: "alice", + }, + want: true, + }, + { + name: "non matching sender denied", + allowList: []string{"654321"}, + sender: bus.SenderInfo{ + Platform: "telegram", + PlatformID: "123456", + CanonicalID: "telegram:123456", + }, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ch := NewBaseChannel("test", nil, nil, tt.allowList) + if got := ch.IsAllowedSender(tt.sender); got != tt.want { + t.Fatalf("IsAllowedSender(%+v) = %v, want %v", tt.sender, got, tt.want) + } + }) + } +} diff --git a/pkg/channels/dingtalk.go b/pkg/channels/dingtalk/dingtalk.go similarity index 83% rename from pkg/channels/dingtalk.go rename to pkg/channels/dingtalk/dingtalk.go index 662fba3b7..8642ad362 100644 --- a/pkg/channels/dingtalk.go +++ b/pkg/channels/dingtalk/dingtalk.go @@ -1,7 +1,7 @@ // PicoClaw - Ultra-lightweight personal AI agent // DingTalk channel implementation using Stream Mode -package channels +package dingtalk import ( "context" @@ -12,7 +12,9 @@ import ( "github.com/open-dingtalk/dingtalk-stream-sdk-go/client" "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/utils" ) @@ -20,7 +22,7 @@ import ( // DingTalkChannel implements the Channel interface for DingTalk (钉钉) // It uses WebSocket for receiving messages via stream mode and API for sending type DingTalkChannel struct { - *BaseChannel + *channels.BaseChannel config config.DingTalkConfig clientID string clientSecret string @@ -37,7 +39,11 @@ func NewDingTalkChannel(cfg config.DingTalkConfig, messageBus *bus.MessageBus) ( return nil, fmt.Errorf("dingtalk client_id and client_secret are required") } - base := NewBaseChannel("dingtalk", cfg, messageBus, cfg.AllowFrom) + base := channels.NewBaseChannel("dingtalk", cfg, messageBus, cfg.AllowFrom, + channels.WithMaxMessageLength(20000), + channels.WithGroupTrigger(cfg.GroupTrigger), + channels.WithReasoningChannelID(cfg.ReasoningChannelID), + ) return &DingTalkChannel{ BaseChannel: base, @@ -70,7 +76,7 @@ func (c *DingTalkChannel) Start(ctx context.Context) error { return fmt.Errorf("failed to start stream client: %w", err) } - c.setRunning(true) + c.SetRunning(true) logger.InfoC("dingtalk", "DingTalk channel started (Stream Mode)") return nil } @@ -87,7 +93,7 @@ func (c *DingTalkChannel) Stop(ctx context.Context) error { c.streamClient.Close() } - c.setRunning(false) + c.SetRunning(false) logger.InfoC("dingtalk", "DingTalk channel stopped") return nil } @@ -95,7 +101,7 @@ func (c *DingTalkChannel) Stop(ctx context.Context) error { // Send sends a message to DingTalk via the chatbot reply API func (c *DingTalkChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { if !c.IsRunning() { - return fmt.Errorf("dingtalk channel not running") + return channels.ErrNotRunning } // Get session webhook from storage @@ -159,12 +165,17 @@ func (c *DingTalkChannel) onChatBotMessageReceived( "session_webhook": data.SessionWebhook, } + var peer bus.Peer if data.ConversationType == "1" { - metadata["peer_kind"] = "direct" - metadata["peer_id"] = senderID + peer = bus.Peer{Kind: "direct", ID: senderID} } else { - metadata["peer_kind"] = "group" - metadata["peer_id"] = data.ConversationId + peer = bus.Peer{Kind: "group", ID: data.ConversationId} + // In group chats, apply unified group trigger filtering + respond, cleaned := c.ShouldRespondInGroup(false, content) + if !respond { + return nil, nil + } + content = cleaned } logger.DebugCF("dingtalk", "Received message", map[string]any{ @@ -173,8 +184,20 @@ func (c *DingTalkChannel) onChatBotMessageReceived( "preview": utils.Truncate(content, 50), }) + // Build sender info + sender := bus.SenderInfo{ + Platform: "dingtalk", + PlatformID: senderID, + CanonicalID: identity.BuildCanonicalID("dingtalk", senderID), + DisplayName: senderNick, + } + + if !c.IsAllowedSender(sender) { + return nil, nil + } + // Handle the message through the base channel - c.HandleMessage(senderID, chatID, content, nil, metadata) + c.HandleMessage(ctx, peer, "", senderID, chatID, content, nil, metadata, sender) // Return nil to indicate we've handled the message asynchronously // The response will be sent through the message bus @@ -197,7 +220,7 @@ func (c *DingTalkChannel) SendDirectReply(ctx context.Context, sessionWebhook, c contentBytes, ) if err != nil { - return fmt.Errorf("failed to send reply: %w", err) + return fmt.Errorf("dingtalk send: %w", channels.ErrTemporary) } return nil diff --git a/pkg/channels/dingtalk/init.go b/pkg/channels/dingtalk/init.go new file mode 100644 index 000000000..5f49bce8c --- /dev/null +++ b/pkg/channels/dingtalk/init.go @@ -0,0 +1,13 @@ +package dingtalk + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("dingtalk", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewDingTalkChannel(cfg.Channels.DingTalk, b) + }) +} diff --git a/pkg/channels/discord.go b/pkg/channels/discord/discord.go similarity index 55% rename from pkg/channels/discord.go rename to pkg/channels/discord/discord.go index 20f3b267c..fe0f8e82c 100644 --- a/pkg/channels/discord.go +++ b/pkg/channels/discord/discord.go @@ -1,4 +1,4 @@ -package channels +package discord import ( "context" @@ -11,26 +11,27 @@ import ( "github.com/bwmarrin/discordgo" "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/utils" - "github.com/sipeed/picoclaw/pkg/voice" ) const ( - transcriptionTimeout = 30 * time.Second - sendTimeout = 10 * time.Second + sendTimeout = 10 * time.Second ) type DiscordChannel struct { - *BaseChannel - session *discordgo.Session - config config.DiscordConfig - transcriber *voice.GroqTranscriber - ctx context.Context - typingMu sync.Mutex - typingStop map[string]chan struct{} // chatID → stop signal - botUserID string // stored for mention checking + *channels.BaseChannel + session *discordgo.Session + config config.DiscordConfig + ctx context.Context + cancel context.CancelFunc + typingMu sync.Mutex + typingStop map[string]chan struct{} // chatID → stop signal + botUserID string // stored for mention checking } func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordChannel, error) { @@ -39,33 +40,25 @@ func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordC return nil, fmt.Errorf("failed to create discord session: %w", err) } - base := NewBaseChannel("discord", cfg, bus, cfg.AllowFrom) + base := channels.NewBaseChannel("discord", cfg, bus, cfg.AllowFrom, + channels.WithMaxMessageLength(2000), + channels.WithGroupTrigger(cfg.GroupTrigger), + channels.WithReasoningChannelID(cfg.ReasoningChannelID), + ) return &DiscordChannel{ BaseChannel: base, session: session, config: cfg, - transcriber: nil, ctx: context.Background(), typingStop: make(map[string]chan struct{}), }, nil } -func (c *DiscordChannel) SetTranscriber(transcriber *voice.GroqTranscriber) { - c.transcriber = transcriber -} - -func (c *DiscordChannel) getContext() context.Context { - if c.ctx == nil { - return context.Background() - } - return c.ctx -} - func (c *DiscordChannel) Start(ctx context.Context) error { logger.InfoC("discord", "Starting Discord bot") - c.ctx = ctx + c.ctx, c.cancel = context.WithCancel(ctx) // Get bot user ID before opening session to avoid race condition botUser, err := c.session.User("@me") @@ -80,7 +73,7 @@ func (c *DiscordChannel) Start(ctx context.Context) error { return fmt.Errorf("failed to open discord session: %w", err) } - c.setRunning(true) + c.SetRunning(true) logger.InfoCF("discord", "Discord bot connected", map[string]any{ "username": botUser.Username, @@ -92,7 +85,7 @@ func (c *DiscordChannel) Start(ctx context.Context) error { func (c *DiscordChannel) Stop(ctx context.Context) error { logger.InfoC("discord", "Stopping Discord bot") - c.setRunning(false) + c.SetRunning(false) // Stop all typing goroutines before closing session c.typingMu.Lock() @@ -102,6 +95,11 @@ func (c *DiscordChannel) Stop(ctx context.Context) error { } c.typingMu.Unlock() + // Cancel our context so typing goroutines using c.ctx.Done() exit + if c.cancel != nil { + c.cancel() + } + if err := c.session.Close(); err != nil { return fmt.Errorf("failed to close discord session: %w", err) } @@ -110,10 +108,26 @@ func (c *DiscordChannel) Stop(ctx context.Context) error { } func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { - c.stopTyping(msg.ChatID) + if !c.IsRunning() { + return channels.ErrNotRunning + } + + channelID := msg.ChatID + if channelID == "" { + return fmt.Errorf("channel ID is empty") + } + + if len([]rune(msg.Content)) == 0 { + return nil + } + + return c.sendChunk(ctx, channelID, msg.Content) +} +// SendMedia implements the channels.MediaSender interface. +func (c *DiscordChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error { if !c.IsRunning() { - return fmt.Errorf("discord bot not running") + return channels.ErrNotRunning } channelID := msg.ChatID @@ -121,20 +135,94 @@ func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro return fmt.Errorf("channel ID is empty") } - runes := []rune(msg.Content) - if len(runes) == 0 { + store := c.GetMediaStore() + if store == nil { + return fmt.Errorf("no media store available: %w", channels.ErrSendFailed) + } + + // Collect all files into a single ChannelMessageSendComplex call + files := make([]*discordgo.File, 0, len(msg.Parts)) + var caption string + + for _, part := range msg.Parts { + localPath, err := store.Resolve(part.Ref) + if err != nil { + logger.ErrorCF("discord", "Failed to resolve media ref", map[string]any{ + "ref": part.Ref, + "error": err.Error(), + }) + continue + } + + file, err := os.Open(localPath) + if err != nil { + logger.ErrorCF("discord", "Failed to open media file", map[string]any{ + "path": localPath, + "error": err.Error(), + }) + continue + } + // Note: discordgo reads from the Reader and we can't close it before send + + filename := part.Filename + if filename == "" { + filename = "file" + } + + files = append(files, &discordgo.File{ + Name: filename, + ContentType: part.ContentType, + Reader: file, + }) + + if part.Caption != "" && caption == "" { + caption = part.Caption + } + } + + if len(files) == 0 { return nil } - chunks := utils.SplitMessage(msg.Content, 2000) // Split messages into chunks, Discord length limit: 2000 chars + sendCtx, cancel := context.WithTimeout(ctx, sendTimeout) + defer cancel() - for _, chunk := range chunks { - if err := c.sendChunk(ctx, channelID, chunk); err != nil { - return err + done := make(chan error, 1) + go func() { + _, err := c.session.ChannelMessageSendComplex(channelID, &discordgo.MessageSend{ + Content: caption, + Files: files, + }) + done <- err + }() + + select { + case err := <-done: + // Close all file readers + for _, f := range files { + if closer, ok := f.Reader.(*os.File); ok { + closer.Close() + } } + if err != nil { + return fmt.Errorf("discord send media: %w", channels.ErrTemporary) + } + return nil + case <-sendCtx.Done(): + // Close all file readers + for _, f := range files { + if closer, ok := f.Reader.(*os.File); ok { + closer.Close() + } + } + return sendCtx.Err() } +} - return nil +// EditMessage implements channels.MessageEditor. +func (c *DiscordChannel) EditMessage(ctx context.Context, chatID string, messageID string, content string) error { + _, err := c.session.ChannelMessageEdit(chatID, messageID, content) + return err } func (c *DiscordChannel) sendChunk(ctx context.Context, channelID, content string) error { @@ -151,11 +239,11 @@ func (c *DiscordChannel) sendChunk(ctx context.Context, channelID, content strin select { case err := <-done: if err != nil { - return fmt.Errorf("failed to send discord message: %w", err) + return fmt.Errorf("discord send: %w", channels.ErrTemporary) } return nil case <-sendCtx.Done(): - return fmt.Errorf("send message timeout: %w", sendCtx.Err()) + return sendCtx.Err() } } @@ -176,17 +264,32 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag return } - // Check allowlist first to avoid downloading attachments and transcribing for rejected users - if !c.IsAllowed(m.Author.ID) { + // Check allowlist first to avoid downloading attachments for rejected users + sender := bus.SenderInfo{ + Platform: "discord", + PlatformID: m.Author.ID, + CanonicalID: identity.BuildCanonicalID("discord", m.Author.ID), + Username: m.Author.Username, + } + // Build display name + displayName := m.Author.Username + if m.Author.Discriminator != "" && m.Author.Discriminator != "0" { + displayName += "#" + m.Author.Discriminator + } + sender.DisplayName = displayName + + if !c.IsAllowedSender(sender) { logger.DebugCF("discord", "Message rejected by allowlist", map[string]any{ "user_id": m.Author.ID, }) return } - // If configured to only respond to mentions, check if bot is mentioned - // Skip this check for DMs (GuildID is empty) - DMs should always be responded to - if c.config.MentionOnly && m.GuildID != "" { + content := m.Content + + // In guild (group) channels, apply unified group trigger filtering + // DMs (GuildID is empty) always get a response + if m.GuildID != "" { isMentioned := false for _, mention := range m.Mentions { if mention.ID == c.botUserID { @@ -194,36 +297,39 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag break } } - if !isMentioned { - logger.DebugCF("discord", "Message ignored - bot not mentioned", map[string]any{ + content = c.stripBotMention(content) + respond, cleaned := c.ShouldRespondInGroup(isMentioned, content) + if !respond { + logger.DebugCF("discord", "Group message ignored by group trigger", map[string]any{ "user_id": m.Author.ID, }) return } + content = cleaned + } else { + // DMs: just strip bot mention without filtering + content = c.stripBotMention(content) } senderID := m.Author.ID - senderName := m.Author.Username - if m.Author.Discriminator != "" && m.Author.Discriminator != "0" { - senderName += "#" + m.Author.Discriminator - } - content := m.Content - content = c.stripBotMention(content) mediaPaths := make([]string, 0, len(m.Attachments)) - localFiles := make([]string, 0, len(m.Attachments)) - - // Ensure temp files are cleaned up when function returns - defer func() { - for _, file := range localFiles { - if err := os.Remove(file); err != nil { - logger.DebugCF("discord", "Failed to cleanup temp file", map[string]any{ - "file": file, - "error": err.Error(), - }) + + scope := channels.BuildMediaScope("discord", m.ChannelID, m.ID) + + // Helper to register a local file with the media store + storeMedia := func(localPath, filename string) string { + if store := c.GetMediaStore(); store != nil { + ref, err := store.Store(localPath, media.MediaMeta{ + Filename: filename, + Source: "discord", + }, scope) + if err == nil { + return ref } } - }() + return localPath // fallback + } for _, attachment := range m.Attachments { isAudio := utils.IsAudioFile(attachment.Filename, attachment.ContentType) @@ -231,30 +337,8 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag if isAudio { localPath := c.downloadAttachment(attachment.URL, attachment.Filename) if localPath != "" { - localFiles = append(localFiles, localPath) - - transcribedText := "" - if c.transcriber != nil && c.transcriber.IsAvailable() { - ctx, cancel := context.WithTimeout(c.getContext(), transcriptionTimeout) - result, err := c.transcriber.Transcribe(ctx, localPath) - cancel() // Release context resources immediately to avoid leaks in for loop - - if err != nil { - logger.ErrorCF("discord", "Voice transcription failed", map[string]any{ - "error": err.Error(), - }) - transcribedText = fmt.Sprintf("[audio: %s (transcription failed)]", attachment.Filename) - } else { - transcribedText = fmt.Sprintf("[audio transcription: %s]", result.Text) - logger.DebugCF("discord", "Audio transcribed successfully", map[string]any{ - "text": result.Text, - }) - } - } else { - transcribedText = fmt.Sprintf("[audio: %s]", attachment.Filename) - } - - content = appendContent(content, transcribedText) + mediaPaths = append(mediaPaths, storeMedia(localPath, attachment.Filename)) + content = appendContent(content, fmt.Sprintf("[audio: %s]", attachment.Filename)) } else { logger.WarnCF("discord", "Failed to download audio attachment", map[string]any{ "url": attachment.URL, @@ -279,9 +363,13 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag // Start typing after all early returns — guaranteed to have a matching Send() c.startTyping(m.ChannelID) + // Register typing stop with Manager for outbound orchestration + if rec := c.GetPlaceholderRecorder(); rec != nil { + rec.RecordTypingStop("discord", m.ChannelID, func() { c.stopTyping(m.ChannelID) }) + } logger.DebugCF("discord", "Received message", map[string]any{ - "sender_name": senderName, + "sender_name": sender.DisplayName, "sender_id": senderID, "preview": utils.Truncate(content, 50), }) @@ -293,19 +381,18 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag peerID = senderID } + peer := bus.Peer{Kind: peerKind, ID: peerID} + metadata := map[string]string{ - "message_id": m.ID, "user_id": senderID, "username": m.Author.Username, - "display_name": senderName, + "display_name": sender.DisplayName, "guild_id": m.GuildID, "channel_id": m.ChannelID, "is_dm": fmt.Sprintf("%t", m.GuildID == ""), - "peer_kind": peerKind, - "peer_id": peerID, } - c.HandleMessage(senderID, m.ChannelID, content, mediaPaths, metadata) + c.HandleMessage(c.ctx, peer, m.ID, senderID, m.ChannelID, content, mediaPaths, metadata, sender) } // startTyping starts a continuous typing indicator loop for the given chatID. diff --git a/pkg/channels/discord/init.go b/pkg/channels/discord/init.go new file mode 100644 index 000000000..15a539804 --- /dev/null +++ b/pkg/channels/discord/init.go @@ -0,0 +1,13 @@ +package discord + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("discord", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewDiscordChannel(cfg.Channels.Discord, b) + }) +} diff --git a/pkg/channels/errors.go b/pkg/channels/errors.go new file mode 100644 index 000000000..09ee88b3f --- /dev/null +++ b/pkg/channels/errors.go @@ -0,0 +1,21 @@ +package channels + +import "errors" + +var ( + // ErrNotRunning indicates the channel is not running. + // Manager will not retry. + ErrNotRunning = errors.New("channel not running") + + // ErrRateLimit indicates the platform returned a rate-limit response (e.g. HTTP 429). + // Manager will wait a fixed delay and retry. + ErrRateLimit = errors.New("rate limited") + + // ErrTemporary indicates a transient failure (e.g. network timeout, 5xx). + // Manager will use exponential backoff and retry. + ErrTemporary = errors.New("temporary failure") + + // ErrSendFailed indicates a permanent failure (e.g. invalid chat ID, 4xx non-429). + // Manager will not retry. + ErrSendFailed = errors.New("send failed") +) diff --git a/pkg/channels/errors_test.go b/pkg/channels/errors_test.go new file mode 100644 index 000000000..e5592345a --- /dev/null +++ b/pkg/channels/errors_test.go @@ -0,0 +1,56 @@ +package channels + +import ( + "errors" + "fmt" + "testing" +) + +func TestErrorsIs(t *testing.T) { + wrapped := fmt.Errorf("telegram API: %w", ErrRateLimit) + if !errors.Is(wrapped, ErrRateLimit) { + t.Error("wrapped ErrRateLimit should match") + } + if errors.Is(wrapped, ErrTemporary) { + t.Error("wrapped ErrRateLimit should not match ErrTemporary") + } +} + +func TestErrorsIsAllTypes(t *testing.T) { + sentinels := []error{ErrNotRunning, ErrRateLimit, ErrTemporary, ErrSendFailed} + + for _, sentinel := range sentinels { + wrapped := fmt.Errorf("context: %w", sentinel) + if !errors.Is(wrapped, sentinel) { + t.Errorf("wrapped %v should match itself", sentinel) + } + + // Verify it doesn't match other sentinel errors + for _, other := range sentinels { + if other == sentinel { + continue + } + if errors.Is(wrapped, other) { + t.Errorf("wrapped %v should not match %v", sentinel, other) + } + } + } +} + +func TestErrorMessages(t *testing.T) { + tests := []struct { + err error + want string + }{ + {ErrNotRunning, "channel not running"}, + {ErrRateLimit, "rate limited"}, + {ErrTemporary, "temporary failure"}, + {ErrSendFailed, "send failed"}, + } + + for _, tt := range tests { + if got := tt.err.Error(); got != tt.want { + t.Errorf("error message = %q, want %q", got, tt.want) + } + } +} diff --git a/pkg/channels/errutil.go b/pkg/channels/errutil.go new file mode 100644 index 000000000..319e3c980 --- /dev/null +++ b/pkg/channels/errutil.go @@ -0,0 +1,30 @@ +package channels + +import ( + "fmt" + "net/http" +) + +// ClassifySendError wraps a raw error with the appropriate sentinel based on +// an HTTP status code. Channels that perform HTTP API calls should use this +// in their Send path. +func ClassifySendError(statusCode int, rawErr error) error { + switch { + case statusCode == http.StatusTooManyRequests: + return fmt.Errorf("%w: %v", ErrRateLimit, rawErr) + case statusCode >= 500: + return fmt.Errorf("%w: %v", ErrTemporary, rawErr) + case statusCode >= 400: + return fmt.Errorf("%w: %v", ErrSendFailed, rawErr) + default: + return rawErr + } +} + +// ClassifyNetError wraps a network/timeout error as ErrTemporary. +func ClassifyNetError(err error) error { + if err == nil { + return nil + } + return fmt.Errorf("%w: %v", ErrTemporary, err) +} diff --git a/pkg/channels/errutil_test.go b/pkg/channels/errutil_test.go new file mode 100644 index 000000000..e3d35f65b --- /dev/null +++ b/pkg/channels/errutil_test.go @@ -0,0 +1,97 @@ +package channels + +import ( + "errors" + "fmt" + "testing" +) + +func TestClassifySendError(t *testing.T) { + raw := fmt.Errorf("some API error") + + tests := []struct { + name string + statusCode int + wantIs error + wantNil bool + }{ + {"429 -> ErrRateLimit", 429, ErrRateLimit, false}, + {"500 -> ErrTemporary", 500, ErrTemporary, false}, + {"502 -> ErrTemporary", 502, ErrTemporary, false}, + {"503 -> ErrTemporary", 503, ErrTemporary, false}, + {"400 -> ErrSendFailed", 400, ErrSendFailed, false}, + {"403 -> ErrSendFailed", 403, ErrSendFailed, false}, + {"404 -> ErrSendFailed", 404, ErrSendFailed, false}, + {"200 -> raw error", 200, nil, false}, + {"201 -> raw error", 201, nil, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ClassifySendError(tt.statusCode, raw) + if err == nil { + t.Fatal("expected non-nil error") + } + if tt.wantIs != nil { + if !errors.Is(err, tt.wantIs) { + t.Errorf("errors.Is(err, %v) = false, want true; err = %v", tt.wantIs, err) + } + } else { + // Should return the raw error unchanged + if err != raw { + t.Errorf("expected raw error to be returned unchanged for status %d, got %v", tt.statusCode, err) + } + } + }) + } +} + +func TestClassifySendErrorNoFalsePositive(t *testing.T) { + raw := fmt.Errorf("some error") + + // 429 should NOT match ErrTemporary or ErrSendFailed + err := ClassifySendError(429, raw) + if errors.Is(err, ErrTemporary) { + t.Error("429 should not match ErrTemporary") + } + if errors.Is(err, ErrSendFailed) { + t.Error("429 should not match ErrSendFailed") + } + + // 500 should NOT match ErrRateLimit or ErrSendFailed + err = ClassifySendError(500, raw) + if errors.Is(err, ErrRateLimit) { + t.Error("500 should not match ErrRateLimit") + } + if errors.Is(err, ErrSendFailed) { + t.Error("500 should not match ErrSendFailed") + } + + // 400 should NOT match ErrRateLimit or ErrTemporary + err = ClassifySendError(400, raw) + if errors.Is(err, ErrRateLimit) { + t.Error("400 should not match ErrRateLimit") + } + if errors.Is(err, ErrTemporary) { + t.Error("400 should not match ErrTemporary") + } +} + +func TestClassifyNetError(t *testing.T) { + t.Run("nil error returns nil", func(t *testing.T) { + if err := ClassifyNetError(nil); err != nil { + t.Errorf("expected nil, got %v", err) + } + }) + + t.Run("non-nil error wraps as ErrTemporary", func(t *testing.T) { + raw := fmt.Errorf("connection refused") + err := ClassifyNetError(raw) + if err == nil { + t.Fatal("expected non-nil error") + } + if !errors.Is(err, ErrTemporary) { + t.Errorf("errors.Is(err, ErrTemporary) = false, want true; err = %v", err) + } + }) +} diff --git a/pkg/channels/feishu/common.go b/pkg/channels/feishu/common.go new file mode 100644 index 000000000..e8a057741 --- /dev/null +++ b/pkg/channels/feishu/common.go @@ -0,0 +1,9 @@ +package feishu + +// stringValue safely dereferences a *string pointer. +func stringValue(v *string) string { + if v == nil { + return "" + } + return *v +} diff --git a/pkg/channels/feishu_32.go b/pkg/channels/feishu/feishu_32.go similarity index 93% rename from pkg/channels/feishu_32.go rename to pkg/channels/feishu/feishu_32.go index 5109b8195..d0ec758c6 100644 --- a/pkg/channels/feishu_32.go +++ b/pkg/channels/feishu/feishu_32.go @@ -1,18 +1,19 @@ //go:build !amd64 && !arm64 && !riscv64 && !mips64 && !ppc64 -package channels +package feishu import ( "context" "errors" "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" ) // FeishuChannel is a stub implementation for 32-bit architectures type FeishuChannel struct { - *BaseChannel + *channels.BaseChannel } // NewFeishuChannel returns an error on 32-bit architectures where the Feishu SDK is not supported diff --git a/pkg/channels/feishu_64.go b/pkg/channels/feishu/feishu_64.go similarity index 78% rename from pkg/channels/feishu_64.go rename to pkg/channels/feishu/feishu_64.go index 42e74980f..1db1bf669 100644 --- a/pkg/channels/feishu_64.go +++ b/pkg/channels/feishu/feishu_64.go @@ -1,6 +1,6 @@ //go:build amd64 || arm64 || riscv64 || mips64 || ppc64 -package channels +package feishu import ( "context" @@ -15,13 +15,15 @@ import ( larkws "github.com/larksuite/oapi-sdk-go/v3/ws" "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/utils" ) type FeishuChannel struct { - *BaseChannel + *channels.BaseChannel config config.FeishuConfig client *lark.Client wsClient *larkws.Client @@ -31,7 +33,10 @@ type FeishuChannel struct { } func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChannel, error) { - base := NewBaseChannel("feishu", cfg, bus, cfg.AllowFrom) + base := channels.NewBaseChannel("feishu", cfg, bus, cfg.AllowFrom, + channels.WithGroupTrigger(cfg.GroupTrigger), + channels.WithReasoningChannelID(cfg.ReasoningChannelID), + ) return &FeishuChannel{ BaseChannel: base, @@ -60,7 +65,7 @@ func (c *FeishuChannel) Start(ctx context.Context) error { wsClient := c.wsClient c.mu.Unlock() - c.setRunning(true) + c.SetRunning(true) logger.InfoC("feishu", "Feishu channel started (websocket mode)") go func() { @@ -83,14 +88,14 @@ func (c *FeishuChannel) Stop(ctx context.Context) error { c.wsClient = nil c.mu.Unlock() - c.setRunning(false) + c.SetRunning(false) logger.InfoC("feishu", "Feishu channel stopped") return nil } func (c *FeishuChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { if !c.IsRunning() { - return fmt.Errorf("feishu channel not running") + return channels.ErrNotRunning } if msg.ChatID == "" { @@ -114,11 +119,11 @@ func (c *FeishuChannel) Send(ctx context.Context, msg bus.OutboundMessage) error resp, err := c.client.Im.V1.Message.Create(ctx, req) if err != nil { - return fmt.Errorf("failed to send feishu message: %w", err) + return fmt.Errorf("feishu send: %w", channels.ErrTemporary) } if !resp.Success() { - return fmt.Errorf("feishu api error: code=%d msg=%s", resp.Code, resp.Msg) + return fmt.Errorf("feishu api error (code=%d msg=%s): %w", resp.Code, resp.Msg, channels.ErrTemporary) } logger.DebugCF("feishu", "Feishu message sent", map[string]any{ @@ -128,7 +133,7 @@ func (c *FeishuChannel) Send(ctx context.Context, msg bus.OutboundMessage) error return nil } -func (c *FeishuChannel) handleMessageReceive(_ context.Context, event *larkim.P2MessageReceiveV1) error { +func (c *FeishuChannel) handleMessageReceive(ctx context.Context, event *larkim.P2MessageReceiveV1) error { if event == nil || event.Event == nil || event.Event.Message == nil { return nil } @@ -152,8 +157,9 @@ func (c *FeishuChannel) handleMessageReceive(_ context.Context, event *larkim.P2 } metadata := map[string]string{} - if messageID := stringValue(message.MessageId); messageID != "" { - metadata["message_id"] = messageID + messageID := "" + if mid := stringValue(message.MessageId); mid != "" { + messageID = mid } if messageType := stringValue(message.MessageType); messageType != "" { metadata["message_type"] = messageType @@ -166,12 +172,17 @@ func (c *FeishuChannel) handleMessageReceive(_ context.Context, event *larkim.P2 } chatType := stringValue(message.ChatType) + var peer bus.Peer if chatType == "p2p" { - metadata["peer_kind"] = "direct" - metadata["peer_id"] = senderID + peer = bus.Peer{Kind: "direct", ID: senderID} } else { - metadata["peer_kind"] = "group" - metadata["peer_id"] = chatID + peer = bus.Peer{Kind: "group", ID: chatID} + // In group chats, apply unified group trigger filtering + respond, cleaned := c.ShouldRespondInGroup(false, content) + if !respond { + return nil + } + content = cleaned } logger.InfoCF("feishu", "Feishu message received", map[string]any{ @@ -180,7 +191,17 @@ func (c *FeishuChannel) handleMessageReceive(_ context.Context, event *larkim.P2 "preview": utils.Truncate(content, 80), }) - c.HandleMessage(senderID, chatID, content, nil, metadata) + senderInfo := bus.SenderInfo{ + Platform: "feishu", + PlatformID: senderID, + CanonicalID: identity.BuildCanonicalID("feishu", senderID), + } + + if !c.IsAllowedSender(senderInfo) { + return nil + } + + c.HandleMessage(ctx, peer, messageID, senderID, chatID, content, nil, metadata, senderInfo) return nil } @@ -218,10 +239,3 @@ func extractFeishuMessageContent(message *larkim.EventMessage) string { return *message.Content } - -func stringValue(v *string) string { - if v == nil { - return "" - } - return *v -} diff --git a/pkg/channels/feishu/init.go b/pkg/channels/feishu/init.go new file mode 100644 index 000000000..7e5a62dae --- /dev/null +++ b/pkg/channels/feishu/init.go @@ -0,0 +1,13 @@ +package feishu + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("feishu", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewFeishuChannel(cfg.Channels.Feishu, b) + }) +} diff --git a/pkg/channels/interfaces.go b/pkg/channels/interfaces.go new file mode 100644 index 000000000..32bfe95f8 --- /dev/null +++ b/pkg/channels/interfaces.go @@ -0,0 +1,24 @@ +package channels + +import "context" + +// TypingCapable — channels that can show a typing/thinking indicator. +// StartTyping begins the indicator and returns a stop function. +// The stop function MUST be idempotent and safe to call multiple times. +type TypingCapable interface { + StartTyping(ctx context.Context, chatID string) (stop func(), err error) +} + +// MessageEditor — channels that can edit an existing message. +// messageID is always string; channels convert platform-specific types internally. +type MessageEditor interface { + EditMessage(ctx context.Context, chatID string, messageID string, content string) error +} + +// PlaceholderRecorder is injected into channels by Manager. +// Channels call these methods on inbound to register typing/placeholder state. +// Manager uses the registered state on outbound to stop typing and edit placeholders. +type PlaceholderRecorder interface { + RecordPlaceholder(channel, chatID, placeholderID string) + RecordTypingStop(channel, chatID string, stop func()) +} diff --git a/pkg/channels/line/init.go b/pkg/channels/line/init.go new file mode 100644 index 000000000..9265575cc --- /dev/null +++ b/pkg/channels/line/init.go @@ -0,0 +1,13 @@ +package line + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("line", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewLINEChannel(cfg.Channels.LINE, b) + }) +} diff --git a/pkg/channels/line.go b/pkg/channels/line/line.go similarity index 81% rename from pkg/channels/line.go rename to pkg/channels/line/line.go index 44134996f..21eb4cb67 100644 --- a/pkg/channels/line.go +++ b/pkg/channels/line/line.go @@ -1,4 +1,4 @@ -package channels +package line import ( "bytes" @@ -10,14 +10,16 @@ import ( "fmt" "io" "net/http" - "os" "strings" "sync" "time" "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/utils" ) @@ -41,9 +43,8 @@ type replyTokenEntry struct { // using the LINE Messaging API with HTTP webhook for receiving messages // and REST API for sending messages. type LINEChannel struct { - *BaseChannel + *channels.BaseChannel config config.LINEConfig - httpServer *http.Server botUserID string // Bot's user ID botBasicID string // Bot's basic ID (e.g. @216ru...) botDisplayName string // Bot's display name for text-based mention detection @@ -59,7 +60,11 @@ func NewLINEChannel(cfg config.LINEConfig, messageBus *bus.MessageBus) (*LINECha return nil, fmt.Errorf("line channel_secret and channel_access_token are required") } - base := NewBaseChannel("line", cfg, messageBus, cfg.AllowFrom) + base := channels.NewBaseChannel("line", cfg, messageBus, cfg.AllowFrom, + channels.WithMaxMessageLength(5000), + channels.WithGroupTrigger(cfg.GroupTrigger), + channels.WithReasoningChannelID(cfg.ReasoningChannelID), + ) return &LINEChannel{ BaseChannel: base, @@ -67,7 +72,7 @@ func NewLINEChannel(cfg config.LINEConfig, messageBus *bus.MessageBus) (*LINECha }, nil } -// Start launches the HTTP webhook server. +// Start initializes the LINE channel. func (c *LINEChannel) Start(ctx context.Context) error { logger.InfoC("line", "Starting LINE channel (Webhook Mode)") @@ -86,32 +91,7 @@ func (c *LINEChannel) Start(ctx context.Context) error { }) } - mux := http.NewServeMux() - path := c.config.WebhookPath - if path == "" { - path = "/webhook/line" - } - mux.HandleFunc(path, c.webhookHandler) - - addr := fmt.Sprintf("%s:%d", c.config.WebhookHost, c.config.WebhookPort) - c.httpServer = &http.Server{ - Addr: addr, - Handler: mux, - } - - go func() { - logger.InfoCF("line", "LINE webhook server listening", map[string]any{ - "addr": addr, - "path": path, - }) - if err := c.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { - logger.ErrorCF("line", "Webhook server error", map[string]any{ - "error": err.Error(), - }) - } - }() - - c.setRunning(true) + c.SetRunning(true) logger.InfoC("line", "LINE channel started (Webhook Mode)") return nil } @@ -150,7 +130,7 @@ func (c *LINEChannel) fetchBotInfo() error { return nil } -// Stop gracefully shuts down the HTTP server. +// Stop gracefully stops the LINE channel. func (c *LINEChannel) Stop(ctx context.Context) error { logger.InfoC("line", "Stopping LINE channel") @@ -158,21 +138,24 @@ func (c *LINEChannel) Stop(ctx context.Context) error { c.cancel() } - if c.httpServer != nil { - shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second) - defer cancel() - if err := c.httpServer.Shutdown(shutdownCtx); err != nil { - logger.ErrorCF("line", "Webhook server shutdown error", map[string]any{ - "error": err.Error(), - }) - } - } - - c.setRunning(false) + c.SetRunning(false) logger.InfoC("line", "LINE channel stopped") return nil } +// WebhookPath returns the path for registering on the shared HTTP server. +func (c *LINEChannel) WebhookPath() string { + if c.config.WebhookPath != "" { + return c.config.WebhookPath + } + return "/webhook/line" +} + +// ServeHTTP implements http.Handler for the shared HTTP server. +func (c *LINEChannel) ServeHTTP(w http.ResponseWriter, r *http.Request) { + c.webhookHandler(w, r) +} + // webhookHandler handles incoming LINE webhook requests. func (c *LINEChannel) webhookHandler(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { @@ -284,14 +267,6 @@ func (c *LINEChannel) processEvent(event lineEvent) { return } - // In group chats, only respond when the bot is mentioned - if isGroup && !c.isBotMentioned(msg) { - logger.DebugCF("line", "Ignoring group message without mention", map[string]any{ - "chat_id": chatID, - }) - return - } - // Store reply token for later use if event.ReplyToken != "" { c.replyTokens.Store(chatID, replyTokenEntry{ @@ -307,18 +282,22 @@ func (c *LINEChannel) processEvent(event lineEvent) { var content string var mediaPaths []string - localFiles := []string{} - - defer func() { - for _, file := range localFiles { - if err := os.Remove(file); err != nil { - logger.DebugCF("line", "Failed to cleanup temp file", map[string]any{ - "file": file, - "error": err.Error(), - }) + + scope := channels.BuildMediaScope("line", chatID, msg.ID) + + // Helper to register a local file with the media store + storeMedia := func(localPath, filename string) string { + if store := c.GetMediaStore(); store != nil { + ref, err := store.Store(localPath, media.MediaMeta{ + Filename: filename, + Source: "line", + }, scope) + if err == nil { + return ref } } - }() + return localPath // fallback + } switch msg.Type { case "text": @@ -330,22 +309,19 @@ func (c *LINEChannel) processEvent(event lineEvent) { case "image": localPath := c.downloadContent(msg.ID, "image.jpg") if localPath != "" { - localFiles = append(localFiles, localPath) - mediaPaths = append(mediaPaths, localPath) + mediaPaths = append(mediaPaths, storeMedia(localPath, "image.jpg")) content = "[image]" } case "audio": localPath := c.downloadContent(msg.ID, "audio.m4a") if localPath != "" { - localFiles = append(localFiles, localPath) - mediaPaths = append(mediaPaths, localPath) + mediaPaths = append(mediaPaths, storeMedia(localPath, "audio.m4a")) content = "[audio]" } case "video": localPath := c.downloadContent(msg.ID, "video.mp4") if localPath != "" { - localFiles = append(localFiles, localPath) - mediaPaths = append(mediaPaths, localPath) + mediaPaths = append(mediaPaths, storeMedia(localPath, "video.mp4")) content = "[video]" } case "file": @@ -360,18 +336,29 @@ func (c *LINEChannel) processEvent(event lineEvent) { return } + // In group chats, apply unified group trigger filtering + if isGroup { + isMentioned := c.isBotMentioned(msg) + respond, cleaned := c.ShouldRespondInGroup(isMentioned, content) + if !respond { + logger.DebugCF("line", "Ignoring group message by group trigger", map[string]any{ + "chat_id": chatID, + }) + return + } + content = cleaned + } + metadata := map[string]string{ "platform": "line", "source_type": event.Source.Type, - "message_id": msg.ID, } + var peer bus.Peer if isGroup { - metadata["peer_kind"] = "group" - metadata["peer_id"] = chatID + peer = bus.Peer{Kind: "group", ID: chatID} } else { - metadata["peer_kind"] = "direct" - metadata["peer_id"] = senderID + peer = bus.Peer{Kind: "direct", ID: senderID} } logger.DebugCF("line", "Received message", map[string]any{ @@ -385,7 +372,17 @@ func (c *LINEChannel) processEvent(event lineEvent) { // Show typing/loading indicator (requires user ID, not group ID) c.sendLoading(senderID) - c.HandleMessage(senderID, chatID, content, mediaPaths, metadata) + sender := bus.SenderInfo{ + Platform: "line", + PlatformID: senderID, + CanonicalID: identity.BuildCanonicalID("line", senderID), + } + + if !c.IsAllowedSender(sender) { + return + } + + c.HandleMessage(c.ctx, peer, msg.ID, senderID, chatID, content, mediaPaths, metadata, sender) } // isBotMentioned checks if the bot is mentioned in the message. @@ -491,7 +488,7 @@ func (c *LINEChannel) resolveChatID(source lineSource) string { // using a cached reply token, then falls back to the Push API. func (c *LINEChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { if !c.IsRunning() { - return fmt.Errorf("line channel not running") + return channels.ErrNotRunning } // Load and consume quote token for this chat @@ -519,6 +516,36 @@ func (c *LINEChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { return c.sendPush(ctx, msg.ChatID, msg.Content, quoteToken) } +// SendMedia implements the channels.MediaSender interface. +// LINE requires media to be accessible via public URL; since we only have local files, +// we fall back to sending a text message with the filename/caption. +// For full support, an external file hosting service would be needed. +func (c *LINEChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error { + if !c.IsRunning() { + return channels.ErrNotRunning + } + + store := c.GetMediaStore() + if store == nil { + return fmt.Errorf("no media store available: %w", channels.ErrSendFailed) + } + + // LINE Messaging API requires publicly accessible URLs for media messages. + // Since we only have local file paths, send caption text as fallback. + for _, part := range msg.Parts { + caption := part.Caption + if caption == "" { + caption = fmt.Sprintf("[%s: %s]", part.Type, part.Filename) + } + + if err := c.sendPush(ctx, msg.ChatID, caption, ""); err != nil { + return err + } + } + + return nil +} + // buildTextMessage creates a text message object, optionally with quoteToken. func buildTextMessage(content, quoteToken string) map[string]string { msg := map[string]string{ @@ -582,13 +609,13 @@ func (c *LINEChannel) callAPI(ctx context.Context, endpoint string, payload any) client := &http.Client{Timeout: 30 * time.Second} resp, err := client.Do(req) if err != nil { - return fmt.Errorf("API request failed: %w", err) + return channels.ClassifyNetError(err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { respBody, _ := io.ReadAll(resp.Body) - return fmt.Errorf("LINE API error (status %d): %s", resp.StatusCode, string(respBody)) + return channels.ClassifySendError(resp.StatusCode, fmt.Errorf("LINE API error: %s", string(respBody))) } return nil diff --git a/pkg/channels/maixcam/init.go b/pkg/channels/maixcam/init.go new file mode 100644 index 000000000..5a269b22b --- /dev/null +++ b/pkg/channels/maixcam/init.go @@ -0,0 +1,13 @@ +package maixcam + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("maixcam", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewMaixCamChannel(cfg.Channels.MaixCam, b) + }) +} diff --git a/pkg/channels/maixcam.go b/pkg/channels/maixcam/maixcam.go similarity index 78% rename from pkg/channels/maixcam.go rename to pkg/channels/maixcam/maixcam.go index 34ce62b20..ff9a3ed1a 100644 --- a/pkg/channels/maixcam.go +++ b/pkg/channels/maixcam/maixcam.go @@ -1,4 +1,4 @@ -package channels +package maixcam import ( "context" @@ -6,16 +6,21 @@ import ( "fmt" "net" "sync" + "time" "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" "github.com/sipeed/picoclaw/pkg/logger" ) type MaixCamChannel struct { - *BaseChannel + *channels.BaseChannel config config.MaixCamConfig listener net.Listener + ctx context.Context + cancel context.CancelFunc clients map[net.Conn]bool clientsMux sync.RWMutex } @@ -28,7 +33,13 @@ type MaixCamMessage struct { } func NewMaixCamChannel(cfg config.MaixCamConfig, bus *bus.MessageBus) (*MaixCamChannel, error) { - base := NewBaseChannel("maixcam", cfg, bus, cfg.AllowFrom) + base := channels.NewBaseChannel( + "maixcam", + cfg, + bus, + cfg.AllowFrom, + channels.WithReasoningChannelID(cfg.ReasoningChannelID), + ) return &MaixCamChannel{ BaseChannel: base, @@ -40,37 +51,40 @@ func NewMaixCamChannel(cfg config.MaixCamConfig, bus *bus.MessageBus) (*MaixCamC func (c *MaixCamChannel) Start(ctx context.Context) error { logger.InfoC("maixcam", "Starting MaixCam channel server") + c.ctx, c.cancel = context.WithCancel(ctx) + addr := fmt.Sprintf("%s:%d", c.config.Host, c.config.Port) listener, err := net.Listen("tcp", addr) if err != nil { + c.cancel() return fmt.Errorf("failed to listen on %s: %w", addr, err) } c.listener = listener - c.setRunning(true) + c.SetRunning(true) logger.InfoCF("maixcam", "MaixCam server listening", map[string]any{ "host": c.config.Host, "port": c.config.Port, }) - go c.acceptConnections(ctx) + go c.acceptConnections() return nil } -func (c *MaixCamChannel) acceptConnections(ctx context.Context) { +func (c *MaixCamChannel) acceptConnections() { logger.DebugC("maixcam", "Starting connection acceptor") for { select { - case <-ctx.Done(): + case <-c.ctx.Done(): logger.InfoC("maixcam", "Stopping connection acceptor") return default: conn, err := c.listener.Accept() if err != nil { - if c.running { + if c.IsRunning() { logger.ErrorCF("maixcam", "Failed to accept connection", map[string]any{ "error": err.Error(), }) @@ -86,12 +100,12 @@ func (c *MaixCamChannel) acceptConnections(ctx context.Context) { c.clients[conn] = true c.clientsMux.Unlock() - go c.handleConnection(conn, ctx) + go c.handleConnection(conn) } } } -func (c *MaixCamChannel) handleConnection(conn net.Conn, ctx context.Context) { +func (c *MaixCamChannel) handleConnection(conn net.Conn) { logger.DebugC("maixcam", "Handling MaixCam connection") defer func() { @@ -106,7 +120,7 @@ func (c *MaixCamChannel) handleConnection(conn net.Conn, ctx context.Context) { for { select { - case <-ctx.Done(): + case <-c.ctx.Done(): return default: var msg MaixCamMessage @@ -170,11 +184,29 @@ func (c *MaixCamChannel) handlePersonDetection(msg MaixCamMessage) { "y": fmt.Sprintf("%.0f", y), "w": fmt.Sprintf("%.0f", w), "h": fmt.Sprintf("%.0f", h), - "peer_kind": "channel", - "peer_id": "default", } - c.HandleMessage(senderID, chatID, content, []string{}, metadata) + sender := bus.SenderInfo{ + Platform: "maixcam", + PlatformID: "maixcam", + CanonicalID: identity.BuildCanonicalID("maixcam", "maixcam"), + } + + if !c.IsAllowedSender(sender) { + return + } + + c.HandleMessage( + c.ctx, + bus.Peer{Kind: "channel", ID: "default"}, + "", + senderID, + chatID, + content, + []string{}, + metadata, + sender, + ) } func (c *MaixCamChannel) handleStatusUpdate(msg MaixCamMessage) { @@ -185,7 +217,12 @@ func (c *MaixCamChannel) handleStatusUpdate(msg MaixCamMessage) { func (c *MaixCamChannel) Stop(ctx context.Context) error { logger.InfoC("maixcam", "Stopping MaixCam channel") - c.setRunning(false) + c.SetRunning(false) + + // Cancel context first to signal goroutines to exit + if c.cancel != nil { + c.cancel() + } if c.listener != nil { c.listener.Close() @@ -205,7 +242,14 @@ func (c *MaixCamChannel) Stop(ctx context.Context) error { func (c *MaixCamChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { if !c.IsRunning() { - return fmt.Errorf("maixcam channel not running") + return channels.ErrNotRunning + } + + // Check ctx before entering write path + select { + case <-ctx.Done(): + return ctx.Err() + default: } c.clientsMux.RLock() @@ -230,13 +274,15 @@ func (c *MaixCamChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro var sendErr error for conn := range c.clients { + _ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) if _, err := conn.Write(data); err != nil { logger.ErrorCF("maixcam", "Failed to send to client", map[string]any{ "client": conn.RemoteAddr().String(), "error": err.Error(), }) - sendErr = err + sendErr = fmt.Errorf("maixcam send: %w", channels.ErrTemporary) } + _ = conn.SetWriteDeadline(time.Time{}) } return sendErr diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go index 75edaf49e..07c2ce1e2 100644 --- a/pkg/channels/manager.go +++ b/pkg/channels/manager.go @@ -8,32 +8,115 @@ package channels import ( "context" + "errors" "fmt" + "math" + "net/http" "sync" + "time" + + "golang.org/x/time/rate" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/constants" + "github.com/sipeed/picoclaw/pkg/health" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/media" +) + +const ( + defaultChannelQueueSize = 100 + defaultRateLimit = 10 // default 10 msg/s + maxRetries = 3 + rateLimitDelay = 1 * time.Second + baseBackoff = 500 * time.Millisecond + maxBackoff = 8 * time.Second ) +// channelRateConfig maps channel name to per-second rate limit. +var channelRateConfig = map[string]float64{ + "telegram": 20, + "discord": 1, + "slack": 1, + "line": 10, +} + +type channelWorker struct { + ch Channel + queue chan bus.OutboundMessage + mediaQueue chan bus.OutboundMediaMessage + done chan struct{} + mediaDone chan struct{} + limiter *rate.Limiter +} + type Manager struct { channels map[string]Channel + workers map[string]*channelWorker bus *bus.MessageBus config *config.Config + mediaStore media.MediaStore dispatchTask *asyncTask + mux *http.ServeMux + httpServer *http.Server mu sync.RWMutex + placeholders sync.Map // "channel:chatID" → placeholderID (string) + typingStops sync.Map // "channel:chatID" → func() } type asyncTask struct { cancel context.CancelFunc } -func NewManager(cfg *config.Config, messageBus *bus.MessageBus) (*Manager, error) { +// RecordPlaceholder registers a placeholder message for later editing. +// Implements PlaceholderRecorder. +func (m *Manager) RecordPlaceholder(channel, chatID, placeholderID string) { + key := channel + ":" + chatID + m.placeholders.Store(key, placeholderID) +} + +// RecordTypingStop registers a typing stop function for later invocation. +// Implements PlaceholderRecorder. +func (m *Manager) RecordTypingStop(channel, chatID string, stop func()) { + key := channel + ":" + chatID + m.typingStops.Store(key, stop) +} + +// preSend handles typing stop and placeholder editing before sending a message. +// Returns true if the message was edited into a placeholder (skip Send). +func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMessage, ch Channel) bool { + key := name + ":" + msg.ChatID + + // 1. Stop typing + if v, loaded := m.typingStops.LoadAndDelete(key); loaded { + if stop, ok := v.(func()); ok { + stop() // idempotent, safe + } + } + + // 2. Try editing placeholder + if v, loaded := m.placeholders.LoadAndDelete(key); loaded { + if placeholderID, ok := v.(string); ok && placeholderID != "" { + if editor, ok := ch.(MessageEditor); ok { + if err := editor.EditMessage(ctx, msg.ChatID, placeholderID, msg.Content); err == nil { + return true // edited successfully, skip Send + } + // edit failed → fall through to normal Send + } + } + } + + return false +} + +func NewManager(cfg *config.Config, messageBus *bus.MessageBus, store media.MediaStore) (*Manager, error) { m := &Manager{ - channels: make(map[string]Channel), - bus: messageBus, - config: cfg, + channels: make(map[string]Channel), + workers: make(map[string]*channelWorker), + bus: messageBus, + config: cfg, + mediaStore: store, } if err := m.initChannels(); err != nil { @@ -43,163 +126,96 @@ func NewManager(cfg *config.Config, messageBus *bus.MessageBus) (*Manager, error return m, nil } +// initChannel is a helper that looks up a factory by name and creates the channel. +func (m *Manager) initChannel(name, displayName string) { + f, ok := getFactory(name) + if !ok { + logger.WarnCF("channels", "Factory not registered", map[string]any{ + "channel": displayName, + }) + return + } + logger.DebugCF("channels", "Attempting to initialize channel", map[string]any{ + "channel": displayName, + }) + ch, err := f(m.config, m.bus) + if err != nil { + logger.ErrorCF("channels", "Failed to initialize channel", map[string]any{ + "channel": displayName, + "error": err.Error(), + }) + } else { + // Inject MediaStore if channel supports it + if m.mediaStore != nil { + if setter, ok := ch.(interface{ SetMediaStore(s media.MediaStore) }); ok { + setter.SetMediaStore(m.mediaStore) + } + } + // Inject PlaceholderRecorder if channel supports it + if setter, ok := ch.(interface{ SetPlaceholderRecorder(r PlaceholderRecorder) }); ok { + setter.SetPlaceholderRecorder(m) + } + m.channels[name] = ch + m.workers[name] = newChannelWorker(name, ch) + logger.InfoCF("channels", "Channel enabled successfully", map[string]any{ + "channel": displayName, + }) + } +} + func (m *Manager) initChannels() error { logger.InfoC("channels", "Initializing channel manager") if m.config.Channels.Telegram.Enabled && m.config.Channels.Telegram.Token != "" { - logger.DebugC("channels", "Attempting to initialize Telegram channel") - telegram, err := NewTelegramChannel(m.config, m.bus) - if err != nil { - logger.ErrorCF("channels", "Failed to initialize Telegram channel", map[string]any{ - "error": err.Error(), - }) - } else { - m.channels["telegram"] = telegram - logger.InfoC("channels", "Telegram channel enabled successfully") - } + m.initChannel("telegram", "Telegram") } if m.config.Channels.WhatsApp.Enabled && m.config.Channels.WhatsApp.BridgeURL != "" { - logger.DebugC("channels", "Attempting to initialize WhatsApp channel") - whatsapp, err := NewWhatsAppChannel(m.config.Channels.WhatsApp, m.bus) - if err != nil { - logger.ErrorCF("channels", "Failed to initialize WhatsApp channel", map[string]any{ - "error": err.Error(), - }) - } else { - m.channels["whatsapp"] = whatsapp - logger.InfoC("channels", "WhatsApp channel enabled successfully") - } + m.initChannel("whatsapp", "WhatsApp") } if m.config.Channels.Feishu.Enabled { - logger.DebugC("channels", "Attempting to initialize Feishu channel") - feishu, err := NewFeishuChannel(m.config.Channels.Feishu, m.bus) - if err != nil { - logger.ErrorCF("channels", "Failed to initialize Feishu channel", map[string]any{ - "error": err.Error(), - }) - } else { - m.channels["feishu"] = feishu - logger.InfoC("channels", "Feishu channel enabled successfully") - } + m.initChannel("feishu", "Feishu") } if m.config.Channels.Discord.Enabled && m.config.Channels.Discord.Token != "" { - logger.DebugC("channels", "Attempting to initialize Discord channel") - discord, err := NewDiscordChannel(m.config.Channels.Discord, m.bus) - if err != nil { - logger.ErrorCF("channels", "Failed to initialize Discord channel", map[string]any{ - "error": err.Error(), - }) - } else { - m.channels["discord"] = discord - logger.InfoC("channels", "Discord channel enabled successfully") - } + m.initChannel("discord", "Discord") } if m.config.Channels.MaixCam.Enabled { - logger.DebugC("channels", "Attempting to initialize MaixCam channel") - maixcam, err := NewMaixCamChannel(m.config.Channels.MaixCam, m.bus) - if err != nil { - logger.ErrorCF("channels", "Failed to initialize MaixCam channel", map[string]any{ - "error": err.Error(), - }) - } else { - m.channels["maixcam"] = maixcam - logger.InfoC("channels", "MaixCam channel enabled successfully") - } + m.initChannel("maixcam", "MaixCam") } if m.config.Channels.QQ.Enabled { - logger.DebugC("channels", "Attempting to initialize QQ channel") - qq, err := NewQQChannel(m.config.Channels.QQ, m.bus) - if err != nil { - logger.ErrorCF("channels", "Failed to initialize QQ channel", map[string]any{ - "error": err.Error(), - }) - } else { - m.channels["qq"] = qq - logger.InfoC("channels", "QQ channel enabled successfully") - } + m.initChannel("qq", "QQ") } if m.config.Channels.DingTalk.Enabled && m.config.Channels.DingTalk.ClientID != "" { - logger.DebugC("channels", "Attempting to initialize DingTalk channel") - dingtalk, err := NewDingTalkChannel(m.config.Channels.DingTalk, m.bus) - if err != nil { - logger.ErrorCF("channels", "Failed to initialize DingTalk channel", map[string]any{ - "error": err.Error(), - }) - } else { - m.channels["dingtalk"] = dingtalk - logger.InfoC("channels", "DingTalk channel enabled successfully") - } + m.initChannel("dingtalk", "DingTalk") } if m.config.Channels.Slack.Enabled && m.config.Channels.Slack.BotToken != "" { - logger.DebugC("channels", "Attempting to initialize Slack channel") - slackCh, err := NewSlackChannel(m.config.Channels.Slack, m.bus) - if err != nil { - logger.ErrorCF("channels", "Failed to initialize Slack channel", map[string]any{ - "error": err.Error(), - }) - } else { - m.channels["slack"] = slackCh - logger.InfoC("channels", "Slack channel enabled successfully") - } + m.initChannel("slack", "Slack") } if m.config.Channels.LINE.Enabled && m.config.Channels.LINE.ChannelAccessToken != "" { - logger.DebugC("channels", "Attempting to initialize LINE channel") - line, err := NewLINEChannel(m.config.Channels.LINE, m.bus) - if err != nil { - logger.ErrorCF("channels", "Failed to initialize LINE channel", map[string]any{ - "error": err.Error(), - }) - } else { - m.channels["line"] = line - logger.InfoC("channels", "LINE channel enabled successfully") - } + m.initChannel("line", "LINE") } if m.config.Channels.OneBot.Enabled && m.config.Channels.OneBot.WSUrl != "" { - logger.DebugC("channels", "Attempting to initialize OneBot channel") - onebot, err := NewOneBotChannel(m.config.Channels.OneBot, m.bus) - if err != nil { - logger.ErrorCF("channels", "Failed to initialize OneBot channel", map[string]any{ - "error": err.Error(), - }) - } else { - m.channels["onebot"] = onebot - logger.InfoC("channels", "OneBot channel enabled successfully") - } + m.initChannel("onebot", "OneBot") } if m.config.Channels.WeCom.Enabled && m.config.Channels.WeCom.Token != "" { - logger.DebugC("channels", "Attempting to initialize WeCom channel") - wecom, err := NewWeComBotChannel(m.config.Channels.WeCom, m.bus) - if err != nil { - logger.ErrorCF("channels", "Failed to initialize WeCom channel", map[string]any{ - "error": err.Error(), - }) - } else { - m.channels["wecom"] = wecom - logger.InfoC("channels", "WeCom channel enabled successfully") - } + m.initChannel("wecom", "WeCom") } if m.config.Channels.WeComApp.Enabled && m.config.Channels.WeComApp.CorpID != "" { - logger.DebugC("channels", "Attempting to initialize WeCom App channel") - wecomApp, err := NewWeComAppChannel(m.config.Channels.WeComApp, m.bus) - if err != nil { - logger.ErrorCF("channels", "Failed to initialize WeCom App channel", map[string]any{ - "error": err.Error(), - }) - } else { - m.channels["wecom_app"] = wecomApp - logger.InfoC("channels", "WeCom App channel enabled successfully") - } + m.initChannel("wecom_app", "WeCom App") + } + + if m.config.Channels.Pico.Enabled && m.config.Channels.Pico.Token != "" { + m.initChannel("pico", "Pico") } logger.InfoCF("channels", "Channel initialization completed", map[string]any{ @@ -209,6 +225,43 @@ func (m *Manager) initChannels() error { return nil } +// SetupHTTPServer creates a shared HTTP server with the given listen address. +// It registers health endpoints from the health server and discovers channels +// that implement WebhookHandler and/or HealthChecker to register their handlers. +func (m *Manager) SetupHTTPServer(addr string, healthServer *health.Server) { + m.mux = http.NewServeMux() + + // Register health endpoints + if healthServer != nil { + healthServer.RegisterOnMux(m.mux) + } + + // Discover and register webhook handlers and health checkers + for name, ch := range m.channels { + if wh, ok := ch.(WebhookHandler); ok { + m.mux.Handle(wh.WebhookPath(), wh) + logger.InfoCF("channels", "Webhook handler registered", map[string]any{ + "channel": name, + "path": wh.WebhookPath(), + }) + } + if hc, ok := ch.(HealthChecker); ok { + m.mux.HandleFunc(hc.HealthPath(), hc.HealthHandler) + logger.InfoCF("channels", "Health endpoint registered", map[string]any{ + "channel": name, + "path": hc.HealthPath(), + }) + } + } + + m.httpServer = &http.Server{ + Addr: addr, + Handler: m.mux, + ReadTimeout: 30 * time.Second, + WriteTimeout: 30 * time.Second, + } +} + func (m *Manager) StartAll(ctx context.Context) error { m.mu.Lock() defer m.mu.Unlock() @@ -223,8 +276,6 @@ func (m *Manager) StartAll(ctx context.Context) error { dispatchCtx, cancel := context.WithCancel(ctx) m.dispatchTask = &asyncTask{cancel: cancel} - go m.dispatchOutbound(dispatchCtx) - for name, channel := range m.channels { logger.InfoCF("channels", "Starting channel", map[string]any{ "channel": name, @@ -237,6 +288,30 @@ func (m *Manager) StartAll(ctx context.Context) error { } } + // Start per-channel workers + for name, w := range m.workers { + go m.runWorker(dispatchCtx, name, w) + go m.runMediaWorker(dispatchCtx, name, w) + } + + // Start the dispatcher that reads from the bus and routes to workers + go m.dispatchOutbound(dispatchCtx) + go m.dispatchOutboundMedia(dispatchCtx) + + // Start shared HTTP server if configured + if m.httpServer != nil { + go func() { + logger.InfoCF("channels", "Shared HTTP server listening", map[string]any{ + "addr": m.httpServer.Addr, + }) + if err := m.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { + logger.ErrorCF("channels", "Shared HTTP server error", map[string]any{ + "error": err.Error(), + }) + } + }() + } + logger.InfoC("channels", "All channels started") return nil } @@ -247,11 +322,40 @@ func (m *Manager) StopAll(ctx context.Context) error { logger.InfoC("channels", "Stopping all channels") + // Shutdown shared HTTP server first + if m.httpServer != nil { + shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + if err := m.httpServer.Shutdown(shutdownCtx); err != nil { + logger.ErrorCF("channels", "Shared HTTP server shutdown error", map[string]any{ + "error": err.Error(), + }) + } + m.httpServer = nil + } + + // Cancel dispatcher if m.dispatchTask != nil { m.dispatchTask.cancel() m.dispatchTask = nil } + // Close all worker queues and wait for them to drain + for _, w := range m.workers { + close(w.queue) + } + for _, w := range m.workers { + <-w.done + } + // Close all media worker queues and wait for them to drain + for _, w := range m.workers { + close(w.mediaQueue) + } + for _, w := range m.workers { + <-w.mediaDone + } + + // Stop all channels for name, channel := range m.channels { logger.InfoCF("channels", "Stopping channel", map[string]any{ "channel": name, @@ -268,6 +372,117 @@ func (m *Manager) StopAll(ctx context.Context) error { return nil } +// newChannelWorker creates a channelWorker with a rate limiter configured +// for the given channel name. +func newChannelWorker(name string, ch Channel) *channelWorker { + rateVal := float64(defaultRateLimit) + if r, ok := channelRateConfig[name]; ok { + rateVal = r + } + burst := int(math.Max(1, math.Ceil(rateVal/2))) + + return &channelWorker{ + ch: ch, + queue: make(chan bus.OutboundMessage, defaultChannelQueueSize), + mediaQueue: make(chan bus.OutboundMediaMessage, defaultChannelQueueSize), + done: make(chan struct{}), + mediaDone: make(chan struct{}), + limiter: rate.NewLimiter(rate.Limit(rateVal), burst), + } +} + +// runWorker processes outbound messages for a single channel, splitting +// messages that exceed the channel's maximum message length. +func (m *Manager) runWorker(ctx context.Context, name string, w *channelWorker) { + defer close(w.done) + for { + select { + case msg, ok := <-w.queue: + if !ok { + return + } + maxLen := 0 + if mlp, ok := w.ch.(MessageLengthProvider); ok { + maxLen = mlp.MaxMessageLength() + } + if maxLen > 0 && len([]rune(msg.Content)) > maxLen { + chunks := SplitMessage(msg.Content, maxLen) + for _, chunk := range chunks { + chunkMsg := msg + chunkMsg.Content = chunk + m.sendWithRetry(ctx, name, w, chunkMsg) + } + } else { + m.sendWithRetry(ctx, name, w, msg) + } + case <-ctx.Done(): + return + } + } +} + +// sendWithRetry sends a message through the channel with rate limiting and +// retry logic. It classifies errors to determine the retry strategy: +// - ErrNotRunning / ErrSendFailed: permanent, no retry +// - ErrRateLimit: fixed delay retry +// - ErrTemporary / unknown: exponential backoff retry +func (m *Manager) sendWithRetry(ctx context.Context, name string, w *channelWorker, msg bus.OutboundMessage) { + // Rate limit: wait for token + if err := w.limiter.Wait(ctx); err != nil { + // ctx cancelled, shutting down + return + } + + // Pre-send: stop typing and try to edit placeholder + if m.preSend(ctx, name, msg, w.ch) { + return // placeholder was edited successfully, skip Send + } + + var lastErr error + for attempt := 0; attempt <= maxRetries; attempt++ { + lastErr = w.ch.Send(ctx, msg) + if lastErr == nil { + return + } + + // Permanent failures — don't retry + if errors.Is(lastErr, ErrNotRunning) || errors.Is(lastErr, ErrSendFailed) { + break + } + + // Last attempt exhausted — don't sleep + if attempt == maxRetries { + break + } + + // Rate limit error — fixed delay + if errors.Is(lastErr, ErrRateLimit) { + select { + case <-time.After(rateLimitDelay): + continue + case <-ctx.Done(): + return + } + } + + // ErrTemporary or unknown error — exponential backoff + backoff := min(time.Duration(float64(baseBackoff)*math.Pow(2, float64(attempt))), maxBackoff) + select { + case <-time.After(backoff): + case <-ctx.Done(): + return + } + } + + // All retries exhausted or permanent failure + logger.ErrorCF("channels", "Send failed", map[string]any{ + "channel": name, + "chat_id": msg.ChatID, + "error": lastErr.Error(), + "retries": maxRetries, + }) +} + func (m *Manager) dispatchOutbound(ctx context.Context) { logger.InfoC("channels", "Outbound dispatcher started") @@ -288,7 +503,8 @@ func (m *Manager) dispatchOutbound(ctx context.Context) { } m.mu.RLock() - channel, exists := m.channels[msg.Channel] + _, exists := m.channels[msg.Channel] + w, wExists := m.workers[msg.Channel] m.mu.RUnlock() if !exists { @@ -298,16 +514,136 @@ func (m *Manager) dispatchOutbound(ctx context.Context) { continue } - if err := channel.Send(ctx, msg); err != nil { - logger.ErrorCF("channels", "Error sending message to channel", map[string]any{ + if wExists { + select { + case w.queue <- msg: + case <-ctx.Done(): + return + } + } + } + } +} + +func (m *Manager) dispatchOutboundMedia(ctx context.Context) { + logger.InfoC("channels", "Outbound media dispatcher started") + + for { + select { + case <-ctx.Done(): + logger.InfoC("channels", "Outbound media dispatcher stopped") + return + default: + msg, ok := m.bus.SubscribeOutboundMedia(ctx) + if !ok { + continue + } + + // Silently skip internal channels + if constants.IsInternalChannel(msg.Channel) { + continue + } + + m.mu.RLock() + _, exists := m.channels[msg.Channel] + w, wExists := m.workers[msg.Channel] + m.mu.RUnlock() + + if !exists { + logger.WarnCF("channels", "Unknown channel for outbound media message", map[string]any{ "channel": msg.Channel, - "error": err.Error(), }) + continue + } + + if wExists { + select { + case w.mediaQueue <- msg: + case <-ctx.Done(): + return + } } } } } +// runMediaWorker processes outbound media messages for a single channel. +func (m *Manager) runMediaWorker(ctx context.Context, name string, w *channelWorker) { + defer close(w.mediaDone) + for { + select { + case msg, ok := <-w.mediaQueue: + if !ok { + return + } + m.sendMediaWithRetry(ctx, name, w, msg) + case <-ctx.Done(): + return + } + } +} + +// sendMediaWithRetry sends a media message through the channel with rate limiting and +// retry logic. If the channel does not implement MediaSender, it silently skips. +func (m *Manager) sendMediaWithRetry(ctx context.Context, name string, w *channelWorker, msg bus.OutboundMediaMessage) { + ms, ok := w.ch.(MediaSender) + if !ok { + logger.DebugCF("channels", "Channel does not support MediaSender, skipping media", map[string]any{ + "channel": name, + }) + return + } + + // Rate limit: wait for token + if err := w.limiter.Wait(ctx); err != nil { + return + } + + var lastErr error + for attempt := 0; attempt <= maxRetries; attempt++ { + lastErr = ms.SendMedia(ctx, msg) + if lastErr == nil { + return + } + + // Permanent failures — don't retry + if errors.Is(lastErr, ErrNotRunning) || errors.Is(lastErr, ErrSendFailed) { + break + } + + // Last attempt exhausted — don't sleep + if attempt == maxRetries { + break + } + + // Rate limit error — fixed delay + if errors.Is(lastErr, ErrRateLimit) { + select { + case <-time.After(rateLimitDelay): + continue + case <-ctx.Done(): + return + } + } + + // ErrTemporary or unknown error — exponential backoff + backoff := min(time.Duration(float64(baseBackoff)*math.Pow(2, float64(attempt))), maxBackoff) + select { + case <-time.After(backoff): + case <-ctx.Done(): + return + } + } + + // All retries exhausted or permanent failure + logger.ErrorCF("channels", "SendMedia failed", map[string]any{ + "channel": name, + "chat_id": msg.ChatID, + "error": lastErr.Error(), + "retries": maxRetries, + }) +} + func (m *Manager) GetChannel(name string) (Channel, bool) { m.mu.RLock() defer m.mu.RUnlock() @@ -344,17 +680,26 @@ func (m *Manager) RegisterChannel(name string, channel Channel) { m.mu.Lock() defer m.mu.Unlock() m.channels[name] = channel + m.workers[name] = newChannelWorker(name, channel) } func (m *Manager) UnregisterChannel(name string) { m.mu.Lock() defer m.mu.Unlock() + if w, ok := m.workers[name]; ok { + close(w.queue) + <-w.done + close(w.mediaQueue) + <-w.mediaDone + } + delete(m.workers, name) delete(m.channels, name) } func (m *Manager) SendToChannel(ctx context.Context, channelName, chatID, content string) error { m.mu.RLock() - channel, exists := m.channels[channelName] + _, exists := m.channels[channelName] + w, wExists := m.workers[channelName] m.mu.RUnlock() if !exists { @@ -367,5 +712,16 @@ func (m *Manager) SendToChannel(ctx context.Context, channelName, chatID, conten Content: content, } + if wExists { + select { + case w.queue <- msg: + return nil + case <-ctx.Done(): + return ctx.Err() + } + } + + // Fallback: direct send (should not happen) + channel, _ := m.channels[channelName] return channel.Send(ctx, msg) } diff --git a/pkg/channels/manager_test.go b/pkg/channels/manager_test.go new file mode 100644 index 000000000..0573c0a8e --- /dev/null +++ b/pkg/channels/manager_test.go @@ -0,0 +1,634 @@ +package channels + +import ( + "context" + "errors" + "fmt" + "sync" + "sync/atomic" + "testing" + "time" + + "golang.org/x/time/rate" + + "github.com/sipeed/picoclaw/pkg/bus" +) + +// mockChannel is a test double that delegates Send to a configurable function. +type mockChannel struct { + BaseChannel + sendFn func(ctx context.Context, msg bus.OutboundMessage) error +} + +func (m *mockChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + return m.sendFn(ctx, msg) +} + +func (m *mockChannel) Start(ctx context.Context) error { return nil } +func (m *mockChannel) Stop(ctx context.Context) error { return nil } + +// newTestManager creates a minimal Manager suitable for unit tests. +func newTestManager() *Manager { + return &Manager{ + channels: make(map[string]Channel), + workers: make(map[string]*channelWorker), + } +} + +func TestSendWithRetry_Success(t *testing.T) { + m := newTestManager() + var callCount int + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + callCount++ + return nil + }, + } + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + + ctx := context.Background() + msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + + m.sendWithRetry(ctx, "test", w, msg) + + if callCount != 1 { + t.Fatalf("expected 1 Send call, got %d", callCount) + } +} + +func TestSendWithRetry_TemporaryThenSuccess(t *testing.T) { + m := newTestManager() + var callCount int + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + callCount++ + if callCount <= 2 { + return fmt.Errorf("network error: %w", ErrTemporary) + } + return nil + }, + } + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + + ctx := context.Background() + msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + + m.sendWithRetry(ctx, "test", w, msg) + + if callCount != 3 { + t.Fatalf("expected 3 Send calls (2 failures + 1 success), got %d", callCount) + } +} + +func TestSendWithRetry_PermanentFailure(t *testing.T) { + m := newTestManager() + var callCount int + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + callCount++ + return fmt.Errorf("bad chat ID: %w", ErrSendFailed) + }, + } + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + + ctx := context.Background() + msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + + m.sendWithRetry(ctx, "test", w, msg) + + if callCount != 1 { + t.Fatalf("expected 1 Send call (no retry for permanent failure), got %d", callCount) + } +} + +func TestSendWithRetry_NotRunning(t *testing.T) { + m := newTestManager() + var callCount int + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + callCount++ + return ErrNotRunning + }, + } + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + + ctx := context.Background() + msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + + m.sendWithRetry(ctx, "test", w, msg) + + if callCount != 1 { + t.Fatalf("expected 1 Send call (no retry for ErrNotRunning), got %d", callCount) + } +} + +func TestSendWithRetry_RateLimitRetry(t *testing.T) { + m := newTestManager() + var callCount int + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + callCount++ + if callCount == 1 { + return fmt.Errorf("429: %w", ErrRateLimit) + } + return nil + }, + } + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + + ctx := context.Background() + msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + + start := time.Now() + m.sendWithRetry(ctx, "test", w, msg) + elapsed := time.Since(start) + + if callCount != 2 { + t.Fatalf("expected 2 Send calls (1 rate limit + 1 success), got %d", callCount) + } + // Should have waited at least rateLimitDelay (1s) but allow some slack + if elapsed < 900*time.Millisecond { + t.Fatalf("expected at least ~1s delay for rate limit retry, got %v", elapsed) + } +} + +func TestSendWithRetry_MaxRetriesExhausted(t *testing.T) { + m := newTestManager() + var callCount int + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + callCount++ + return fmt.Errorf("timeout: %w", ErrTemporary) + }, + } + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + + ctx := context.Background() + msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + + m.sendWithRetry(ctx, "test", w, msg) + + expected := maxRetries + 1 // initial attempt + maxRetries retries + if callCount != expected { + t.Fatalf("expected %d Send calls, got %d", expected, callCount) + } +} + +func TestSendWithRetry_UnknownError(t *testing.T) { + m := newTestManager() + var callCount int + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + callCount++ + if callCount == 1 { + return errors.New("random unexpected error") + } + return nil + }, + } + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + + ctx := context.Background() + msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + + m.sendWithRetry(ctx, "test", w, msg) + + if callCount != 2 { + t.Fatalf("expected 2 Send calls (unknown error treated as temporary), got %d", callCount) + } +} + +func TestSendWithRetry_ContextCancelled(t *testing.T) { + m := newTestManager() + var callCount int + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + callCount++ + return fmt.Errorf("timeout: %w", ErrTemporary) + }, + } + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + + ctx, cancel := context.WithCancel(context.Background()) + msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + + // Cancel context after first Send attempt returns + ch.sendFn = func(_ context.Context, _ bus.OutboundMessage) error { + callCount++ + cancel() + return fmt.Errorf("timeout: %w", ErrTemporary) + } + + m.sendWithRetry(ctx, "test", w, msg) + + // Should have called Send once, then noticed ctx cancelled during backoff + if callCount != 1 { + t.Fatalf("expected 1 Send call before context cancellation, got %d", callCount) + } +} + +func TestWorkerRateLimiter(t *testing.T) { + m := newTestManager() + + var mu sync.Mutex + var sendTimes []time.Time + + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + mu.Lock() + sendTimes = append(sendTimes, time.Now()) + mu.Unlock() + return nil + }, + } + + // Create a worker with a low rate: 2 msg/s, burst 1 + w := &channelWorker{ + ch: ch, + queue: make(chan bus.OutboundMessage, 10), + done: make(chan struct{}), + limiter: rate.NewLimiter(2, 1), + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go m.runWorker(ctx, "test", w) + + // Enqueue 4 messages + for i := 0; i < 4; i++ { + w.queue <- bus.OutboundMessage{Channel: "test", ChatID: "1", Content: fmt.Sprintf("msg%d", i)} + } + + // Wait enough time for all messages to be sent (4 msgs at 2/s = ~2s, give extra margin) + time.Sleep(3 * time.Second) + + mu.Lock() + times := make([]time.Time, len(sendTimes)) + copy(times, sendTimes) + mu.Unlock() + + if len(times) != 4 { + t.Fatalf("expected 4 sends, got %d", len(times)) + } + + // Verify rate limiting: total duration should be at least 1s + // (first message immediate, then ~500ms between each subsequent one at 2/s) + totalDuration := times[len(times)-1].Sub(times[0]) + if totalDuration < 1*time.Second { + t.Fatalf("expected total duration >= 1s for 4 msgs at 2/s rate, got %v", totalDuration) + } +} + +func TestNewChannelWorker_DefaultRate(t *testing.T) { + ch := &mockChannel{} + w := newChannelWorker("unknown_channel", ch) + + if w.limiter == nil { + t.Fatal("expected limiter to be non-nil") + } + if w.limiter.Limit() != rate.Limit(defaultRateLimit) { + t.Fatalf("expected rate limit %v, got %v", rate.Limit(defaultRateLimit), w.limiter.Limit()) + } +} + +func TestNewChannelWorker_ConfiguredRate(t *testing.T) { + ch := &mockChannel{} + + for name, expectedRate := range channelRateConfig { + w := newChannelWorker(name, ch) + if w.limiter.Limit() != rate.Limit(expectedRate) { + t.Fatalf("channel %s: expected rate %v, got %v", name, expectedRate, w.limiter.Limit()) + } + } +} + +func TestRunWorker_MessageSplitting(t *testing.T) { + m := newTestManager() + + var mu sync.Mutex + var received []string + + ch := &mockChannelWithLength{ + mockChannel: mockChannel{ + sendFn: func(_ context.Context, msg bus.OutboundMessage) error { + mu.Lock() + received = append(received, msg.Content) + mu.Unlock() + return nil + }, + }, + maxLen: 5, + } + + w := &channelWorker{ + ch: ch, + queue: make(chan bus.OutboundMessage, 10), + done: make(chan struct{}), + limiter: rate.NewLimiter(rate.Inf, 1), + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go m.runWorker(ctx, "test", w) + + // Send a message that should be split + w.queue <- bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello world"} + + time.Sleep(100 * time.Millisecond) + + mu.Lock() + count := len(received) + mu.Unlock() + + if count < 2 { + t.Fatalf("expected message to be split into at least 2 chunks, got %d", count) + } +} + +// mockChannelWithLength implements MessageLengthProvider. +type mockChannelWithLength struct { + mockChannel + maxLen int +} + +func (m *mockChannelWithLength) MaxMessageLength() int { + return m.maxLen +} + +func TestSendWithRetry_ExponentialBackoff(t *testing.T) { + m := newTestManager() + + var callTimes []time.Time + var callCount atomic.Int32 + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + callTimes = append(callTimes, time.Now()) + callCount.Add(1) + return fmt.Errorf("timeout: %w", ErrTemporary) + }, + } + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + + ctx := context.Background() + msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + + start := time.Now() + m.sendWithRetry(ctx, "test", w, msg) + totalElapsed := time.Since(start) + + // With maxRetries=3: attempts at 0, ~500ms, ~1.5s, ~3.5s + // Total backoff: 500ms + 1s + 2s = 3.5s + // Allow some margin + if totalElapsed < 3*time.Second { + t.Fatalf("expected total elapsed >= 3s for exponential backoff, got %v", totalElapsed) + } + + if int(callCount.Load()) != maxRetries+1 { + t.Fatalf("expected %d calls, got %d", maxRetries+1, callCount.Load()) + } +} + +// --- Phase 10: preSend orchestration tests --- + +// mockMessageEditor is a channel that supports MessageEditor. +type mockMessageEditor struct { + mockChannel + editFn func(ctx context.Context, chatID, messageID, content string) error +} + +func (m *mockMessageEditor) EditMessage(ctx context.Context, chatID, messageID, content string) error { + return m.editFn(ctx, chatID, messageID, content) +} + +func TestPreSend_PlaceholderEditSuccess(t *testing.T) { + m := newTestManager() + var sendCalled bool + var editCalled bool + + ch := &mockMessageEditor{ + mockChannel: mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + sendCalled = true + return nil + }, + }, + editFn: func(_ context.Context, chatID, messageID, content string) error { + editCalled = true + if chatID != "123" { + t.Fatalf("expected chatID 123, got %s", chatID) + } + if messageID != "456" { + t.Fatalf("expected messageID 456, got %s", messageID) + } + if content != "hello" { + t.Fatalf("expected content 'hello', got %s", content) + } + return nil + }, + } + + // Register placeholder + m.RecordPlaceholder("test", "123", "456") + + msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"} + edited := m.preSend(context.Background(), "test", msg, ch) + + if !edited { + t.Fatal("expected preSend to return true (placeholder edited)") + } + if !editCalled { + t.Fatal("expected EditMessage to be called") + } + if sendCalled { + t.Fatal("expected Send to NOT be called when placeholder edited") + } +} + +func TestPreSend_PlaceholderEditFails_FallsThrough(t *testing.T) { + m := newTestManager() + + ch := &mockMessageEditor{ + mockChannel: mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + return nil + }, + }, + editFn: func(_ context.Context, _, _, _ string) error { + return fmt.Errorf("edit failed") + }, + } + + m.RecordPlaceholder("test", "123", "456") + + msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"} + edited := m.preSend(context.Background(), "test", msg, ch) + + if edited { + t.Fatal("expected preSend to return false when edit fails") + } +} + +func TestPreSend_TypingStopCalled(t *testing.T) { + m := newTestManager() + var stopCalled bool + + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + return nil + }, + } + + m.RecordTypingStop("test", "123", func() { + stopCalled = true + }) + + msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"} + m.preSend(context.Background(), "test", msg, ch) + + if !stopCalled { + t.Fatal("expected typing stop func to be called") + } +} + +func TestPreSend_NoRegisteredState(t *testing.T) { + m := newTestManager() + + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + return nil + }, + } + + msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"} + edited := m.preSend(context.Background(), "test", msg, ch) + + if edited { + t.Fatal("expected preSend to return false with no registered state") + } +} + +func TestPreSend_TypingAndPlaceholder(t *testing.T) { + m := newTestManager() + var stopCalled bool + var editCalled bool + + ch := &mockMessageEditor{ + mockChannel: mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + return nil + }, + }, + editFn: func(_ context.Context, _, _, _ string) error { + editCalled = true + return nil + }, + } + + m.RecordTypingStop("test", "123", func() { + stopCalled = true + }) + m.RecordPlaceholder("test", "123", "456") + + msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"} + edited := m.preSend(context.Background(), "test", msg, ch) + + if !stopCalled { + t.Fatal("expected typing stop to be called") + } + if !editCalled { + t.Fatal("expected EditMessage to be called") + } + if !edited { + t.Fatal("expected preSend to return true") + } +} + +func TestRecordPlaceholder_ConcurrentSafe(t *testing.T) { + m := newTestManager() + + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + chatID := fmt.Sprintf("chat_%d", i%10) + m.RecordPlaceholder("test", chatID, fmt.Sprintf("msg_%d", i)) + }(i) + } + wg.Wait() +} + +func TestRecordTypingStop_ConcurrentSafe(t *testing.T) { + m := newTestManager() + + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + chatID := fmt.Sprintf("chat_%d", i%10) + m.RecordTypingStop("test", chatID, func() {}) + }(i) + } + wg.Wait() +} + +func TestSendWithRetry_PreSendEditsPlaceholder(t *testing.T) { + m := newTestManager() + var sendCalled bool + + ch := &mockMessageEditor{ + mockChannel: mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + sendCalled = true + return nil + }, + }, + editFn: func(_ context.Context, _, _, _ string) error { + return nil // edit succeeds + }, + } + + m.RecordPlaceholder("test", "123", "456") + + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + + msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"} + m.sendWithRetry(context.Background(), "test", w, msg) + + if sendCalled { + t.Fatal("expected Send to NOT be called when placeholder was edited") + } +} diff --git a/pkg/channels/media.go b/pkg/channels/media.go new file mode 100644 index 000000000..c645a6180 --- /dev/null +++ b/pkg/channels/media.go @@ -0,0 +1,15 @@ +package channels + +import ( + "context" + + "github.com/sipeed/picoclaw/pkg/bus" +) + +// MediaSender is an optional interface for channels that can send +// media attachments (images, files, audio, video). +// Manager discovers channels implementing this interface via type +// assertion and routes OutboundMediaMessage to them. +type MediaSender interface { + SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error +} diff --git a/pkg/channels/onebot/init.go b/pkg/channels/onebot/init.go new file mode 100644 index 000000000..84c06dfd6 --- /dev/null +++ b/pkg/channels/onebot/init.go @@ -0,0 +1,13 @@ +package onebot + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("onebot", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewOneBotChannel(cfg.Channels.OneBot, b) + }) +} diff --git a/pkg/channels/onebot.go b/pkg/channels/onebot/onebot.go similarity index 77% rename from pkg/channels/onebot.go rename to pkg/channels/onebot/onebot.go index cee8ad9d3..e0be58fa0 100644 --- a/pkg/channels/onebot.go +++ b/pkg/channels/onebot/onebot.go @@ -1,10 +1,9 @@ -package channels +package onebot import ( "context" "encoding/json" "fmt" - "os" "strconv" "strings" "sync" @@ -14,14 +13,16 @@ import ( "github.com/gorilla/websocket" "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/utils" - "github.com/sipeed/picoclaw/pkg/voice" ) type OneBotChannel struct { - *BaseChannel + *channels.BaseChannel config config.OneBotConfig conn *websocket.Conn ctx context.Context @@ -35,7 +36,6 @@ type OneBotChannel struct { selfID int64 pending map[string]chan json.RawMessage pendingMu sync.Mutex - transcriber *voice.GroqTranscriber lastMessageID sync.Map pendingEmojiMsg sync.Map } @@ -98,7 +98,10 @@ type oneBotMessageSegment struct { } func NewOneBotChannel(cfg config.OneBotConfig, messageBus *bus.MessageBus) (*OneBotChannel, error) { - base := NewBaseChannel("onebot", cfg, messageBus, cfg.AllowFrom) + base := channels.NewBaseChannel("onebot", cfg, messageBus, cfg.AllowFrom, + channels.WithGroupTrigger(cfg.GroupTrigger), + channels.WithReasoningChannelID(cfg.ReasoningChannelID), + ) const dedupSize = 1024 return &OneBotChannel{ @@ -111,10 +114,6 @@ func NewOneBotChannel(cfg config.OneBotConfig, messageBus *bus.MessageBus) (*One }, nil } -func (c *OneBotChannel) SetTranscriber(transcriber *voice.GroqTranscriber) { - c.transcriber = transcriber -} - func (c *OneBotChannel) setMsgEmojiLike(messageID string, emojiID int, set bool) { go func() { _, err := c.sendAPIRequest("set_msg_emoji_like", map[string]any{ @@ -159,7 +158,7 @@ func (c *OneBotChannel) Start(ctx context.Context) error { } } - c.setRunning(true) + c.SetRunning(true) logger.InfoC("onebot", "OneBot channel started successfully") return nil @@ -297,7 +296,9 @@ func (c *OneBotChannel) sendAPIRequest(action string, params any, timeout time.D } c.writeMu.Lock() + _ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) err = conn.WriteMessage(websocket.TextMessage, data) + _ = conn.SetWriteDeadline(time.Time{}) c.writeMu.Unlock() if err != nil { @@ -306,6 +307,9 @@ func (c *OneBotChannel) sendAPIRequest(action string, params any, timeout time.D select { case resp := <-ch: + if resp == nil { + return nil, fmt.Errorf("API request %s: channel stopped", action) + } return resp, nil case <-time.After(timeout): return nil, fmt.Errorf("API request %s timed out after %v", action, timeout) @@ -346,7 +350,7 @@ func (c *OneBotChannel) reconnectLoop() { func (c *OneBotChannel) Stop(ctx context.Context) error { logger.InfoC("onebot", "Stopping OneBot channel") - c.setRunning(false) + c.SetRunning(false) if c.cancel != nil { c.cancel() @@ -354,7 +358,10 @@ func (c *OneBotChannel) Stop(ctx context.Context) error { c.pendingMu.Lock() for echo, ch := range c.pending { - close(ch) + select { + case ch <- nil: // non-blocking wake for blocked sendAPIRequest goroutines + default: + } delete(c.pending, echo) } c.pendingMu.Unlock() @@ -371,7 +378,14 @@ func (c *OneBotChannel) Stop(ctx context.Context) error { func (c *OneBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { if !c.IsRunning() { - return fmt.Errorf("OneBot channel not running") + return channels.ErrNotRunning + } + + // Check ctx before entering write path + select { + case <-ctx.Done(): + return ctx.Err() + default: } c.mu.Lock() @@ -401,20 +415,127 @@ func (c *OneBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error } c.writeMu.Lock() + _ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) err = conn.WriteMessage(websocket.TextMessage, data) + _ = conn.SetWriteDeadline(time.Time{}) c.writeMu.Unlock() if err != nil { logger.ErrorCF("onebot", "Failed to send message", map[string]any{ "error": err.Error(), }) - return err + return fmt.Errorf("onebot send: %w", channels.ErrTemporary) + } + + return nil +} + +// SendMedia implements the channels.MediaSender interface. +func (c *OneBotChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error { + if !c.IsRunning() { + return channels.ErrNotRunning + } + + select { + case <-ctx.Done(): + return ctx.Err() + default: } - if msgID, ok := c.pendingEmojiMsg.LoadAndDelete(msg.ChatID); ok { - if mid, ok := msgID.(string); ok && mid != "" { - c.setMsgEmojiLike(mid, 289, false) + c.mu.Lock() + conn := c.conn + c.mu.Unlock() + + if conn == nil { + return fmt.Errorf("OneBot WebSocket not connected") + } + + store := c.GetMediaStore() + if store == nil { + return fmt.Errorf("no media store available: %w", channels.ErrSendFailed) + } + + // Build media segments + var segments []oneBotMessageSegment + for _, part := range msg.Parts { + localPath, err := store.Resolve(part.Ref) + if err != nil { + logger.ErrorCF("onebot", "Failed to resolve media ref", map[string]any{ + "ref": part.Ref, + "error": err.Error(), + }) + continue + } + + segType := "image" + switch part.Type { + case "image": + segType = "image" + case "video": + segType = "video" + case "audio": + segType = "record" + default: + segType = "file" } + + segments = append(segments, oneBotMessageSegment{ + Type: segType, + Data: map[string]any{"file": "file://" + localPath}, + }) + + if part.Caption != "" { + segments = append(segments, oneBotMessageSegment{ + Type: "text", + Data: map[string]any{"text": part.Caption}, + }) + } + } + + if len(segments) == 0 { + return nil + } + + chatID := msg.ChatID + var action, idKey string + var rawID string + if rest, ok := strings.CutPrefix(chatID, "group:"); ok { + action, idKey, rawID = "send_group_msg", "group_id", rest + } else if rest, ok := strings.CutPrefix(chatID, "private:"); ok { + action, idKey, rawID = "send_private_msg", "user_id", rest + } else { + action, idKey, rawID = "send_private_msg", "user_id", chatID + } + + id, err := strconv.ParseInt(rawID, 10, 64) + if err != nil { + return fmt.Errorf("invalid %s in chatID: %s: %w", idKey, chatID, channels.ErrSendFailed) + } + + echo := fmt.Sprintf("send_%d", atomic.AddInt64(&c.echoCounter, 1)) + + req := oneBotAPIRequest{ + Action: action, + Params: map[string]any{idKey: id, "message": segments}, + Echo: echo, + } + + data, err := json.Marshal(req) + if err != nil { + return fmt.Errorf("failed to marshal OneBot request: %w", err) + } + + c.writeMu.Lock() + _ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) + err = conn.WriteMessage(websocket.TextMessage, data) + _ = conn.SetWriteDeadline(time.Time{}) + c.writeMu.Unlock() + + if err != nil { + logger.ErrorCF("onebot", "Failed to send media message", map[string]any{ + "error": err.Error(), + }) + return fmt.Errorf("onebot send media: %w", channels.ErrTemporary) } return nil @@ -571,11 +692,15 @@ type parseMessageResult struct { Text string IsBotMentioned bool Media []string - LocalFiles []string ReplyTo string } -func (c *OneBotChannel) parseMessageSegments(raw json.RawMessage, selfID int64) parseMessageResult { +func (c *OneBotChannel) parseMessageSegments( + raw json.RawMessage, + selfID int64, + store media.MediaStore, + scope string, +) parseMessageResult { if len(raw) == 0 { return parseMessageResult{} } @@ -602,10 +727,23 @@ func (c *OneBotChannel) parseMessageSegments(raw json.RawMessage, selfID int64) var textParts []string mentioned := false selfIDStr := strconv.FormatInt(selfID, 10) - var media []string - var localFiles []string + var mediaRefs []string var replyTo string + // Helper to register a local file with the media store + storeFile := func(localPath, filename string) string { + if store != nil { + ref, err := store.Store(localPath, media.MediaMeta{ + Filename: filename, + Source: "onebot", + }, scope) + if err == nil { + return ref + } + } + return localPath // fallback + } + for _, seg := range segments { segType, _ := seg["type"].(string) data, _ := seg["data"].(map[string]any) @@ -641,8 +779,7 @@ func (c *OneBotChannel) parseMessageSegments(raw json.RawMessage, selfID int64) LoggerPrefix: "onebot", }) if localPath != "" { - media = append(media, localPath) - localFiles = append(localFiles, localPath) + mediaRefs = append(mediaRefs, storeFile(localPath, filename)) textParts = append(textParts, fmt.Sprintf("[%s]", segType)) } } @@ -656,24 +793,8 @@ func (c *OneBotChannel) parseMessageSegments(raw json.RawMessage, selfID int64) LoggerPrefix: "onebot", }) if localPath != "" { - localFiles = append(localFiles, localPath) - if c.transcriber != nil && c.transcriber.IsAvailable() { - tctx, tcancel := context.WithTimeout(c.ctx, 30*time.Second) - result, err := c.transcriber.Transcribe(tctx, localPath) - tcancel() - if err != nil { - logger.WarnCF("onebot", "Voice transcription failed", map[string]any{ - "error": err.Error(), - }) - textParts = append(textParts, "[voice (transcription failed)]") - media = append(media, localPath) - } else { - textParts = append(textParts, fmt.Sprintf("[voice transcription: %s]", result.Text)) - } - } else { - textParts = append(textParts, "[voice]") - media = append(media, localPath) - } + textParts = append(textParts, "[voice]") + mediaRefs = append(mediaRefs, storeFile(localPath, "voice.amr")) } } } @@ -702,8 +823,7 @@ func (c *OneBotChannel) parseMessageSegments(raw json.RawMessage, selfID int64) return parseMessageResult{ Text: strings.TrimSpace(strings.Join(textParts, "")), IsBotMentioned: mentioned, - Media: media, - LocalFiles: localFiles, + Media: mediaRefs, ReplyTo: replyTo, } } @@ -712,7 +832,13 @@ func (c *OneBotChannel) handleRawEvent(raw *oneBotRawEvent) { switch raw.PostType { case "message": if userID, err := parseJSONInt64(raw.UserID); err == nil && userID > 0 { - if !c.IsAllowed(strconv.FormatInt(userID, 10)) { + // Build minimal sender for allowlist check + sender := bus.SenderInfo{ + Platform: "onebot", + PlatformID: strconv.FormatInt(userID, 10), + CanonicalID: identity.BuildCanonicalID("onebot", strconv.FormatInt(userID, 10)), + } + if !c.IsAllowedSender(sender) { logger.DebugCF("onebot", "Message rejected by allowlist", map[string]any{ "user_id": userID, }) @@ -795,7 +921,17 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) { selfID = atomic.LoadInt64(&c.selfID) } - parsed := c.parseMessageSegments(raw.Message, selfID) + // Compute scope for media store before parsing (parsing may download files) + var chatIDForScope string + switch raw.MessageType { + case "group": + chatIDForScope = "group:" + strconv.FormatInt(groupID, 10) + default: + chatIDForScope = "private:" + strconv.FormatInt(userID, 10) + } + scope := channels.BuildMediaScope("onebot", chatIDForScope, messageID) + + parsed := c.parseMessageSegments(raw.Message, selfID, c.GetMediaStore(), scope) isBotMentioned := parsed.IsBotMentioned content := raw.RawMessage @@ -824,20 +960,6 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) { } } - // Clean up temp files when done - if len(parsed.LocalFiles) > 0 { - defer func() { - for _, f := range parsed.LocalFiles { - if err := os.Remove(f); err != nil { - logger.DebugCF("onebot", "Failed to remove temp file", map[string]any{ - "path": f, - "error": err.Error(), - }) - } - } - }() - } - if c.isDuplicate(messageID) { logger.DebugCF("onebot", "Duplicate message, skipping", map[string]any{ "message_id": messageID, @@ -855,9 +977,9 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) { senderID := strconv.FormatInt(userID, 10) var chatID string - metadata := map[string]string{ - "message_id": messageID, - } + var peer bus.Peer + + metadata := map[string]string{} if parsed.ReplyTo != "" { metadata["reply_to_message_id"] = parsed.ReplyTo @@ -866,14 +988,12 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) { switch raw.MessageType { case "private": chatID = "private:" + senderID - metadata["peer_kind"] = "direct" - metadata["peer_id"] = senderID + peer = bus.Peer{Kind: "direct", ID: senderID} case "group": groupIDStr := strconv.FormatInt(groupID, 10) chatID = "group:" + groupIDStr - metadata["peer_kind"] = "group" - metadata["peer_id"] = groupIDStr + peer = bus.Peer{Kind: "group", ID: groupIDStr} metadata["group_id"] = groupIDStr senderUserID, _ := parseJSONInt64(sender.UserID) @@ -887,8 +1007,8 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) { metadata["sender_name"] = sender.Nickname } - triggered, strippedContent := c.checkGroupTrigger(content, isBotMentioned) - if !triggered { + respond, strippedContent := c.ShouldRespondInGroup(isBotMentioned, content) + if !respond { logger.DebugCF("onebot", "Group message ignored (no trigger)", map[string]any{ "sender": senderID, "group": groupIDStr, @@ -926,9 +1046,30 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) { if raw.MessageType == "group" && messageID != "" && messageID != "0" { c.setMsgEmojiLike(messageID, 289, true) c.pendingEmojiMsg.Store(chatID, messageID) + // Register emoji stop with Manager for outbound orchestration + if rec := c.GetPlaceholderRecorder(); rec != nil { + capturedMsgID := messageID + rec.RecordTypingStop("onebot", chatID, func() { + c.setMsgEmojiLike(capturedMsgID, 289, false) + }) + } } - c.HandleMessage(senderID, chatID, content, parsed.Media, metadata) + senderInfo := bus.SenderInfo{ + Platform: "onebot", + PlatformID: senderID, + CanonicalID: identity.BuildCanonicalID("onebot", senderID), + DisplayName: sender.Nickname, + } + + if !c.IsAllowedSender(senderInfo) { + logger.DebugCF("onebot", "Message rejected by allowlist (senderInfo)", map[string]any{ + "sender": senderID, + }) + return + } + + c.HandleMessage(c.ctx, peer, messageID, senderID, chatID, content, parsed.Media, metadata, senderInfo) } func (c *OneBotChannel) isDuplicate(messageID string) bool { @@ -960,23 +1101,3 @@ func truncate(s string, n int) string { } return string(runes[:n]) + "..." } - -func (c *OneBotChannel) checkGroupTrigger( - content string, - isBotMentioned bool, -) (triggered bool, strippedContent string) { - if isBotMentioned { - return true, strings.TrimSpace(content) - } - - for _, prefix := range c.config.GroupTriggerPrefix { - if prefix == "" { - continue - } - if strings.HasPrefix(content, prefix) { - return true, strings.TrimSpace(strings.TrimPrefix(content, prefix)) - } - } - - return false, content -} diff --git a/pkg/channels/pico/init.go b/pkg/channels/pico/init.go new file mode 100644 index 000000000..96d764418 --- /dev/null +++ b/pkg/channels/pico/init.go @@ -0,0 +1,13 @@ +package pico + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("pico", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewPicoChannel(cfg.Channels.Pico, b) + }) +} diff --git a/pkg/channels/pico/pico.go b/pkg/channels/pico/pico.go new file mode 100644 index 000000000..c646a3b0b --- /dev/null +++ b/pkg/channels/pico/pico.go @@ -0,0 +1,444 @@ +package pico + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/google/uuid" + "github.com/gorilla/websocket" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" + "github.com/sipeed/picoclaw/pkg/logger" +) + +// picoConn represents a single WebSocket connection. +type picoConn struct { + id string + conn *websocket.Conn + sessionID string + writeMu sync.Mutex + closed atomic.Bool +} + +// writeJSON sends a JSON message to the connection with write locking. +func (pc *picoConn) writeJSON(v any) error { + if pc.closed.Load() { + return fmt.Errorf("connection closed") + } + pc.writeMu.Lock() + defer pc.writeMu.Unlock() + return pc.conn.WriteJSON(v) +} + +// close closes the connection. +func (pc *picoConn) close() { + if pc.closed.CompareAndSwap(false, true) { + pc.conn.Close() + } +} + +// PicoChannel implements the native Pico Protocol WebSocket channel. +// It serves as the reference implementation for all optional capability interfaces. +type PicoChannel struct { + *channels.BaseChannel + config config.PicoConfig + upgrader websocket.Upgrader + connections sync.Map // connID → *picoConn + connCount atomic.Int32 + ctx context.Context + cancel context.CancelFunc +} + +// NewPicoChannel creates a new Pico Protocol channel. +func NewPicoChannel(cfg config.PicoConfig, messageBus *bus.MessageBus) (*PicoChannel, error) { + if cfg.Token == "" { + return nil, fmt.Errorf("pico token is required") + } + + base := channels.NewBaseChannel("pico", cfg, messageBus, cfg.AllowFrom) + + allowOrigins := cfg.AllowOrigins + checkOrigin := func(r *http.Request) bool { + if len(allowOrigins) == 0 { + return true // allow all if not configured + } + origin := r.Header.Get("Origin") + for _, allowed := range allowOrigins { + if allowed == "*" || allowed == origin { + return true + } + } + return false + } + + return &PicoChannel{ + BaseChannel: base, + config: cfg, + upgrader: websocket.Upgrader{ + CheckOrigin: checkOrigin, + ReadBufferSize: 1024, + WriteBufferSize: 1024, + }, + }, nil +} + +// Start implements Channel. +func (c *PicoChannel) Start(ctx context.Context) error { + logger.InfoC("pico", "Starting Pico Protocol channel") + c.ctx, c.cancel = context.WithCancel(ctx) + c.SetRunning(true) + logger.InfoC("pico", "Pico Protocol channel started") + return nil +} + +// Stop implements Channel. +func (c *PicoChannel) Stop(ctx context.Context) error { + logger.InfoC("pico", "Stopping Pico Protocol channel") + c.SetRunning(false) + + // Close all connections + c.connections.Range(func(key, value any) bool { + if pc, ok := value.(*picoConn); ok { + pc.close() + } + c.connections.Delete(key) + return true + }) + + if c.cancel != nil { + c.cancel() + } + + logger.InfoC("pico", "Pico Protocol channel stopped") + return nil +} + +// WebhookPath implements channels.WebhookHandler. +func (c *PicoChannel) WebhookPath() string { return "/pico/" } + +// ServeHTTP implements http.Handler for the shared HTTP server. +func (c *PicoChannel) ServeHTTP(w http.ResponseWriter, r *http.Request) { + path := strings.TrimPrefix(r.URL.Path, "/pico") + + switch { + case path == "/ws" || path == "/ws/": + c.handleWebSocket(w, r) + default: + http.NotFound(w, r) + } +} + +// Send implements Channel — sends a message to the appropriate WebSocket connection. +func (c *PicoChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + if !c.IsRunning() { + return channels.ErrNotRunning + } + + outMsg := newMessage(TypeMessageCreate, map[string]any{ + "content": msg.Content, + }) + + return c.broadcastToSession(msg.ChatID, outMsg) +} + +// EditMessage implements channels.MessageEditor. +func (c *PicoChannel) EditMessage(ctx context.Context, chatID string, messageID string, content string) error { + outMsg := newMessage(TypeMessageUpdate, map[string]any{ + "message_id": messageID, + "content": content, + }) + return c.broadcastToSession(chatID, outMsg) +} + +// StartTyping implements channels.TypingCapable. +func (c *PicoChannel) StartTyping(ctx context.Context, chatID string) (func(), error) { + startMsg := newMessage(TypeTypingStart, nil) + if err := c.broadcastToSession(chatID, startMsg); err != nil { + return func() {}, err + } + return func() { + stopMsg := newMessage(TypeTypingStop, nil) + c.broadcastToSession(chatID, stopMsg) + }, nil +} + +// broadcastToSession sends a message to all connections with a matching session. +func (c *PicoChannel) broadcastToSession(chatID string, msg PicoMessage) error { + // chatID format: "pico:" + sessionID := strings.TrimPrefix(chatID, "pico:") + msg.SessionID = sessionID + + var sent bool + c.connections.Range(func(key, value any) bool { + pc, ok := value.(*picoConn) + if !ok { + return true + } + if pc.sessionID == sessionID { + if err := pc.writeJSON(msg); err != nil { + logger.DebugCF("pico", "Write to connection failed", map[string]any{ + "conn_id": pc.id, + "error": err.Error(), + }) + } else { + sent = true + } + } + return true + }) + + if !sent { + return fmt.Errorf("no active connections for session %s: %w", sessionID, channels.ErrSendFailed) + } + return nil +} + +// handleWebSocket upgrades the HTTP connection and manages the WebSocket lifecycle. +func (c *PicoChannel) handleWebSocket(w http.ResponseWriter, r *http.Request) { + if !c.IsRunning() { + http.Error(w, "channel not running", http.StatusServiceUnavailable) + return + } + + // Authenticate + if !c.authenticate(r) { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + + // Check connection limit + maxConns := c.config.MaxConnections + if maxConns <= 0 { + maxConns = 100 + } + if int(c.connCount.Load()) >= maxConns { + http.Error(w, "too many connections", http.StatusServiceUnavailable) + return + } + + conn, err := c.upgrader.Upgrade(w, r, nil) + if err != nil { + logger.ErrorCF("pico", "WebSocket upgrade failed", map[string]any{ + "error": err.Error(), + }) + return + } + + // Determine session ID from query param or generate one + sessionID := r.URL.Query().Get("session_id") + if sessionID == "" { + sessionID = uuid.New().String() + } + + pc := &picoConn{ + id: uuid.New().String(), + conn: conn, + sessionID: sessionID, + } + + c.connections.Store(pc.id, pc) + c.connCount.Add(1) + + logger.InfoCF("pico", "WebSocket client connected", map[string]any{ + "conn_id": pc.id, + "session_id": sessionID, + }) + + go c.readLoop(pc) +} + +// authenticate checks the Bearer token from the Authorization header. +// Query parameter authentication is only allowed when AllowTokenQuery is explicitly enabled. +func (c *PicoChannel) authenticate(r *http.Request) bool { + token := c.config.Token + if token == "" { + return false + } + + // Check Authorization header + auth := r.Header.Get("Authorization") + if strings.HasPrefix(auth, "Bearer ") { + if strings.TrimPrefix(auth, "Bearer ") == token { + return true + } + } + + // Check query parameter only when explicitly allowed + if c.config.AllowTokenQuery { + if r.URL.Query().Get("token") == token { + return true + } + } + + return false +} + +// readLoop reads messages from a WebSocket connection. +func (c *PicoChannel) readLoop(pc *picoConn) { + defer func() { + pc.close() + c.connections.Delete(pc.id) + c.connCount.Add(-1) + logger.InfoCF("pico", "WebSocket client disconnected", map[string]any{ + "conn_id": pc.id, + "session_id": pc.sessionID, + }) + }() + + readTimeout := time.Duration(c.config.ReadTimeout) * time.Second + if readTimeout <= 0 { + readTimeout = 60 * time.Second + } + + _ = pc.conn.SetReadDeadline(time.Now().Add(readTimeout)) + pc.conn.SetPongHandler(func(appData string) error { + _ = pc.conn.SetReadDeadline(time.Now().Add(readTimeout)) + return nil + }) + + // Start ping ticker + pingInterval := time.Duration(c.config.PingInterval) * time.Second + if pingInterval <= 0 { + pingInterval = 30 * time.Second + } + go c.pingLoop(pc, pingInterval) + + for { + select { + case <-c.ctx.Done(): + return + default: + } + + _, rawMsg, err := pc.conn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) { + logger.DebugCF("pico", "WebSocket read error", map[string]any{ + "conn_id": pc.id, + "error": err.Error(), + }) + } + return + } + + _ = pc.conn.SetReadDeadline(time.Now().Add(readTimeout)) + + var msg PicoMessage + if err := json.Unmarshal(rawMsg, &msg); err != nil { + errMsg := newError("invalid_message", "failed to parse message") + pc.writeJSON(errMsg) + continue + } + + c.handleMessage(pc, msg) + } +} + +// pingLoop sends periodic ping frames to keep the connection alive. +func (c *PicoChannel) pingLoop(pc *picoConn, interval time.Duration) { + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-c.ctx.Done(): + return + case <-ticker.C: + if pc.closed.Load() { + return + } + pc.writeMu.Lock() + err := pc.conn.WriteMessage(websocket.PingMessage, nil) + pc.writeMu.Unlock() + if err != nil { + return + } + } + } +} + +// handleMessage processes an inbound Pico Protocol message. +func (c *PicoChannel) handleMessage(pc *picoConn, msg PicoMessage) { + switch msg.Type { + case TypePing: + pong := newMessage(TypePong, nil) + pong.ID = msg.ID + pc.writeJSON(pong) + + case TypeMessageSend: + c.handleMessageSend(pc, msg) + + default: + errMsg := newError("unknown_type", fmt.Sprintf("unknown message type: %s", msg.Type)) + pc.writeJSON(errMsg) + } +} + +// handleMessageSend processes an inbound message.send from a client. +func (c *PicoChannel) handleMessageSend(pc *picoConn, msg PicoMessage) { + content, _ := msg.Payload["content"].(string) + if strings.TrimSpace(content) == "" { + errMsg := newError("empty_content", "message content is empty") + pc.writeJSON(errMsg) + return + } + + sessionID := msg.SessionID + if sessionID == "" { + sessionID = pc.sessionID + } + + chatID := "pico:" + sessionID + senderID := "pico-user" + + peer := bus.Peer{Kind: "direct", ID: "pico:" + sessionID} + + metadata := map[string]string{ + "platform": "pico", + "session_id": sessionID, + "conn_id": pc.id, + } + + logger.DebugCF("pico", "Received message", map[string]any{ + "session_id": sessionID, + "preview": truncate(content, 50), + }) + + // Register typing with Manager + if rec := c.GetPlaceholderRecorder(); rec != nil { + stop, err := c.StartTyping(c.ctx, chatID) + if err == nil { + rec.RecordTypingStop("pico", chatID, stop) + } + } + + sender := bus.SenderInfo{ + Platform: "pico", + PlatformID: senderID, + CanonicalID: identity.BuildCanonicalID("pico", senderID), + } + + if !c.IsAllowedSender(sender) { + return + } + + c.HandleMessage(c.ctx, peer, msg.ID, senderID, chatID, content, nil, metadata, sender) +} + +// truncate truncates a string to maxLen runes. +func truncate(s string, maxLen int) string { + runes := []rune(s) + if len(runes) <= maxLen { + return s + } + return string(runes[:maxLen]) + "..." +} diff --git a/pkg/channels/pico/protocol.go b/pkg/channels/pico/protocol.go new file mode 100644 index 000000000..0a630e193 --- /dev/null +++ b/pkg/channels/pico/protocol.go @@ -0,0 +1,46 @@ +package pico + +import "time" + +// Protocol message types. +const ( + // TypeMessageSend is sent from client to server. + TypeMessageSend = "message.send" + TypeMediaSend = "media.send" + TypePing = "ping" + + // TypeMessageCreate is sent from server to client. + TypeMessageCreate = "message.create" + TypeMessageUpdate = "message.update" + TypeMediaCreate = "media.create" + TypeTypingStart = "typing.start" + TypeTypingStop = "typing.stop" + TypeError = "error" + TypePong = "pong" +) + +// PicoMessage is the wire format for all Pico Protocol messages. +type PicoMessage struct { + Type string `json:"type"` + ID string `json:"id,omitempty"` + SessionID string `json:"session_id,omitempty"` + Timestamp int64 `json:"timestamp,omitempty"` + Payload map[string]any `json:"payload,omitempty"` +} + +// newMessage creates a PicoMessage with the given type and payload. +func newMessage(msgType string, payload map[string]any) PicoMessage { + return PicoMessage{ + Type: msgType, + Timestamp: time.Now().UnixMilli(), + Payload: payload, + } +} + +// newError creates an error PicoMessage. +func newError(code, message string) PicoMessage { + return newMessage(TypeError, map[string]any{ + "code": code, + "message": message, + }) +} diff --git a/pkg/channels/qq/init.go b/pkg/channels/qq/init.go new file mode 100644 index 000000000..15b955089 --- /dev/null +++ b/pkg/channels/qq/init.go @@ -0,0 +1,13 @@ +package qq + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("qq", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewQQChannel(cfg.Channels.QQ, b) + }) +} diff --git a/pkg/channels/qq.go b/pkg/channels/qq/qq.go similarity index 74% rename from pkg/channels/qq.go rename to pkg/channels/qq/qq.go index b10776db6..112964143 100644 --- a/pkg/channels/qq.go +++ b/pkg/channels/qq/qq.go @@ -1,4 +1,4 @@ -package channels +package qq import ( "context" @@ -14,12 +14,14 @@ import ( "golang.org/x/oauth2" "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" "github.com/sipeed/picoclaw/pkg/logger" ) type QQChannel struct { - *BaseChannel + *channels.BaseChannel config config.QQConfig api openapi.OpenAPI tokenSource oauth2.TokenSource @@ -31,7 +33,10 @@ type QQChannel struct { } func NewQQChannel(cfg config.QQConfig, messageBus *bus.MessageBus) (*QQChannel, error) { - base := NewBaseChannel("qq", cfg, messageBus, cfg.AllowFrom) + base := channels.NewBaseChannel("qq", cfg, messageBus, cfg.AllowFrom, + channels.WithGroupTrigger(cfg.GroupTrigger), + channels.WithReasoningChannelID(cfg.ReasoningChannelID), + ) return &QQChannel{ BaseChannel: base, @@ -90,11 +95,11 @@ func (c *QQChannel) Start(ctx context.Context) error { logger.ErrorCF("qq", "WebSocket session error", map[string]any{ "error": err.Error(), }) - c.setRunning(false) + c.SetRunning(false) } }() - c.setRunning(true) + c.SetRunning(true) logger.InfoC("qq", "QQ bot started successfully") return nil @@ -102,7 +107,7 @@ func (c *QQChannel) Start(ctx context.Context) error { func (c *QQChannel) Stop(ctx context.Context) error { logger.InfoC("qq", "Stopping QQ bot") - c.setRunning(false) + c.SetRunning(false) if c.cancel != nil { c.cancel() @@ -113,7 +118,7 @@ func (c *QQChannel) Stop(ctx context.Context) error { func (c *QQChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { if !c.IsRunning() { - return fmt.Errorf("QQ bot not running") + return channels.ErrNotRunning } // construct message @@ -127,7 +132,7 @@ func (c *QQChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { logger.ErrorCF("qq", "Failed to send C2C message", map[string]any{ "error": err.Error(), }) - return err + return fmt.Errorf("qq send: %w", channels.ErrTemporary) } return nil @@ -162,20 +167,35 @@ func (c *QQChannel) handleC2CMessage() event.C2CMessageEventHandler { "length": len(content), }) - // forward to message bus - metadata := map[string]string{ - "message_id": data.ID, - "peer_kind": "direct", - "peer_id": senderID, + // 转发到消息总线 + metadata := map[string]string{} + + sender := bus.SenderInfo{ + Platform: "qq", + PlatformID: data.Author.ID, + CanonicalID: identity.BuildCanonicalID("qq", data.Author.ID), } - c.HandleMessage(senderID, senderID, content, []string{}, metadata) + if !c.IsAllowedSender(sender) { + return nil + } + + c.HandleMessage(c.ctx, + bus.Peer{Kind: "direct", ID: senderID}, + data.ID, + senderID, + senderID, + content, + []string{}, + metadata, + sender, + ) return nil } } -// handleGroupATMessage handles group @messages +// handleGroupATMessage handles QQ group @ messages func (c *QQChannel) handleGroupATMessage() event.GroupATMessageEventHandler { return func(event *dto.WSPayload, data *dto.WSGroupATMessageData) error { // deduplication check @@ -192,34 +212,57 @@ func (c *QQChannel) handleGroupATMessage() event.GroupATMessageEventHandler { return nil } - // extract message content (remove @bot part) + // extract message content (remove @ bot part) content := data.Content if content == "" { logger.DebugC("qq", "Received empty group message, ignoring") return nil } + // GroupAT event means bot is always mentioned; apply group trigger filtering + respond, cleaned := c.ShouldRespondInGroup(true, content) + if !respond { + return nil + } + content = cleaned + logger.InfoCF("qq", "Received group AT message", map[string]any{ "sender": senderID, "group": data.GroupID, "length": len(content), }) - // forward to message bus (use GroupID as ChatID) + // 转发到消息总线(使用 GroupID 作为 ChatID) metadata := map[string]string{ - "message_id": data.ID, - "group_id": data.GroupID, - "peer_kind": "group", - "peer_id": data.GroupID, + "group_id": data.GroupID, + } + + sender := bus.SenderInfo{ + Platform: "qq", + PlatformID: data.Author.ID, + CanonicalID: identity.BuildCanonicalID("qq", data.Author.ID), + } + + if !c.IsAllowedSender(sender) { + return nil } - c.HandleMessage(senderID, data.GroupID, content, []string{}, metadata) + c.HandleMessage(c.ctx, + bus.Peer{Kind: "group", ID: data.GroupID}, + data.ID, + senderID, + data.GroupID, + content, + []string{}, + metadata, + sender, + ) return nil } } -// isDuplicate checks if message is duplicate +// isDuplicate 检查消息是否重复 func (c *QQChannel) isDuplicate(messageID string) bool { c.mu.Lock() defer c.mu.Unlock() @@ -230,9 +273,9 @@ func (c *QQChannel) isDuplicate(messageID string) bool { c.processedIDs[messageID] = true - // simple cleanup: limit map size + // 简单清理:限制 map 大小 if len(c.processedIDs) > 10000 { - // clear half + // 清空一半 count := 0 for id := range c.processedIDs { if count >= 5000 { diff --git a/pkg/channels/registry.go b/pkg/channels/registry.go new file mode 100644 index 000000000..36a05bf3e --- /dev/null +++ b/pkg/channels/registry.go @@ -0,0 +1,32 @@ +package channels + +import ( + "sync" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" +) + +// ChannelFactory is a constructor function that creates a Channel from config and message bus. +// Each channel subpackage registers one or more factories via init(). +type ChannelFactory func(cfg *config.Config, bus *bus.MessageBus) (Channel, error) + +var ( + factoriesMu sync.RWMutex + factories = map[string]ChannelFactory{} +) + +// RegisterFactory registers a named channel factory. Called from subpackage init() functions. +func RegisterFactory(name string, f ChannelFactory) { + factoriesMu.Lock() + defer factoriesMu.Unlock() + factories[name] = f +} + +// getFactory looks up a channel factory by name. +func getFactory(name string) (ChannelFactory, bool) { + factoriesMu.RLock() + defer factoriesMu.RUnlock() + f, ok := factories[name] + return f, ok +} diff --git a/pkg/channels/slack/init.go b/pkg/channels/slack/init.go new file mode 100644 index 000000000..c131bb291 --- /dev/null +++ b/pkg/channels/slack/init.go @@ -0,0 +1,13 @@ +package slack + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("slack", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewSlackChannel(cfg.Channels.Slack, b) + }) +} diff --git a/pkg/channels/slack.go b/pkg/channels/slack/slack.go similarity index 65% rename from pkg/channels/slack.go rename to pkg/channels/slack/slack.go index f087aa8da..5e2d5dc4b 100644 --- a/pkg/channels/slack.go +++ b/pkg/channels/slack/slack.go @@ -1,32 +1,31 @@ -package channels +package slack import ( "context" "fmt" - "os" "strings" "sync" - "time" "github.com/slack-go/slack" "github.com/slack-go/slack/slackevents" "github.com/slack-go/slack/socketmode" "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/utils" - "github.com/sipeed/picoclaw/pkg/voice" ) type SlackChannel struct { - *BaseChannel + *channels.BaseChannel config config.SlackConfig api *slack.Client socketClient *socketmode.Client botUserID string teamID string - transcriber *voice.GroqTranscriber ctx context.Context cancel context.CancelFunc pendingAcks sync.Map @@ -49,7 +48,11 @@ func NewSlackChannel(cfg config.SlackConfig, messageBus *bus.MessageBus) (*Slack socketClient := socketmode.New(api) - base := NewBaseChannel("slack", cfg, messageBus, cfg.AllowFrom) + base := channels.NewBaseChannel("slack", cfg, messageBus, cfg.AllowFrom, + channels.WithMaxMessageLength(40000), + channels.WithGroupTrigger(cfg.GroupTrigger), + channels.WithReasoningChannelID(cfg.ReasoningChannelID), + ) return &SlackChannel{ BaseChannel: base, @@ -59,10 +62,6 @@ func NewSlackChannel(cfg config.SlackConfig, messageBus *bus.MessageBus) (*Slack }, nil } -func (c *SlackChannel) SetTranscriber(transcriber *voice.GroqTranscriber) { - c.transcriber = transcriber -} - func (c *SlackChannel) Start(ctx context.Context) error { logger.InfoC("slack", "Starting Slack channel (Socket Mode)") @@ -92,7 +91,7 @@ func (c *SlackChannel) Start(ctx context.Context) error { } }() - c.setRunning(true) + c.SetRunning(true) logger.InfoC("slack", "Slack channel started (Socket Mode)") return nil } @@ -104,14 +103,14 @@ func (c *SlackChannel) Stop(ctx context.Context) error { c.cancel() } - c.setRunning(false) + c.SetRunning(false) logger.InfoC("slack", "Slack channel stopped") return nil } func (c *SlackChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { if !c.IsRunning() { - return fmt.Errorf("slack channel not running") + return channels.ErrNotRunning } channelID, threadTS := parseSlackChatID(msg.ChatID) @@ -129,7 +128,7 @@ func (c *SlackChannel) Send(ctx context.Context, msg bus.OutboundMessage) error _, _, err := c.api.PostMessageContext(ctx, channelID, opts...) if err != nil { - return fmt.Errorf("failed to send slack message: %w", err) + return fmt.Errorf("slack send: %w", channels.ErrTemporary) } if ref, ok := c.pendingAcks.LoadAndDelete(msg.ChatID); ok { @@ -148,6 +147,60 @@ func (c *SlackChannel) Send(ctx context.Context, msg bus.OutboundMessage) error return nil } +// SendMedia implements the channels.MediaSender interface. +func (c *SlackChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error { + if !c.IsRunning() { + return channels.ErrNotRunning + } + + channelID, _ := parseSlackChatID(msg.ChatID) + if channelID == "" { + return fmt.Errorf("invalid slack chat ID: %s", msg.ChatID) + } + + store := c.GetMediaStore() + if store == nil { + return fmt.Errorf("no media store available: %w", channels.ErrSendFailed) + } + + for _, part := range msg.Parts { + localPath, err := store.Resolve(part.Ref) + if err != nil { + logger.ErrorCF("slack", "Failed to resolve media ref", map[string]any{ + "ref": part.Ref, + "error": err.Error(), + }) + continue + } + + filename := part.Filename + if filename == "" { + filename = "file" + } + + title := part.Caption + if title == "" { + title = filename + } + + _, err = c.api.UploadFileV2Context(ctx, slack.UploadFileV2Parameters{ + Channel: channelID, + File: localPath, + Filename: filename, + Title: title, + }) + if err != nil { + logger.ErrorCF("slack", "Failed to upload media", map[string]any{ + "filename": filename, + "error": err.Error(), + }) + return fmt.Errorf("slack send media: %w", channels.ErrTemporary) + } + } + + return nil +} + func (c *SlackChannel) eventLoop() { for { select { @@ -201,7 +254,12 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { } // check allowlist to avoid downloading attachments for rejected users - if !c.IsAllowed(ev.User) { + sender := bus.SenderInfo{ + Platform: "slack", + PlatformID: ev.User, + CanonicalID: identity.BuildCanonicalID("slack", ev.User), + } + if !c.IsAllowedSender(sender) { logger.DebugCF("slack", "Message rejected by allowlist", map[string]any{ "user_id": ev.User, }) @@ -223,6 +281,18 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { Timestamp: messageTS, }) + // Register typing stop (remove "eyes" reaction) with Manager + if rec := c.GetPlaceholderRecorder(); rec != nil { + capturedChannelID := channelID + capturedMessageTS := messageTS + rec.RecordTypingStop("slack", chatID, func() { + c.api.RemoveReaction("eyes", slack.ItemRef{ + Channel: capturedChannelID, + Timestamp: capturedMessageTS, + }) + }) + } + c.pendingAcks.Store(chatID, slackMessageRef{ ChannelID: channelID, Timestamp: messageTS, @@ -231,20 +301,32 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { content := ev.Text content = c.stripBotMention(content) + // In non-DM channels, apply group trigger filtering + if !strings.HasPrefix(channelID, "D") { + respond, cleaned := c.ShouldRespondInGroup(false, content) + if !respond { + return + } + content = cleaned + } + var mediaPaths []string - localFiles := []string{} // track local files that need cleanup - - // ensure temp files are cleaned up when function returns - defer func() { - for _, file := range localFiles { - if err := os.Remove(file); err != nil { - logger.DebugCF("slack", "Failed to cleanup temp file", map[string]any{ - "file": file, - "error": err.Error(), - }) + + scope := channels.BuildMediaScope("slack", chatID, messageTS) + + // Helper to register a local file with the media store + storeMedia := func(localPath, filename string) string { + if store := c.GetMediaStore(); store != nil { + ref, err := store.Store(localPath, media.MediaMeta{ + Filename: filename, + Source: "slack", + }, scope) + if err == nil { + return ref } } - }() + return localPath // fallback + } if ev.Message != nil && len(ev.Message.Files) > 0 { for _, file := range ev.Message.Files { @@ -252,23 +334,8 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { if localPath == "" { continue } - localFiles = append(localFiles, localPath) - mediaPaths = append(mediaPaths, localPath) - - if utils.IsAudioFile(file.Name, file.Mimetype) && c.transcriber != nil && c.transcriber.IsAvailable() { - ctx, cancel := context.WithTimeout(c.ctx, 30*time.Second) - defer cancel() - result, err := c.transcriber.Transcribe(ctx, localPath) - - if err != nil { - logger.ErrorCF("slack", "Voice transcription failed", map[string]any{"error": err.Error()}) - content += fmt.Sprintf("\n[audio: %s (transcription failed)]", file.Name) - } else { - content += fmt.Sprintf("\n[voice transcription: %s]", result.Text) - } - } else { - content += fmt.Sprintf("\n[file: %s]", file.Name) - } + mediaPaths = append(mediaPaths, storeMedia(localPath, file.Name)) + content += fmt.Sprintf("\n[file: %s]", file.Name) } } @@ -283,13 +350,13 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { peerID = senderID } + peer := bus.Peer{Kind: peerKind, ID: peerID} + metadata := map[string]string{ "message_ts": messageTS, "channel_id": channelID, "thread_ts": threadTS, "platform": "slack", - "peer_kind": peerKind, - "peer_id": peerID, "team_id": c.teamID, } @@ -300,7 +367,7 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { "has_thread": threadTS != "", }) - c.HandleMessage(senderID, chatID, content, mediaPaths, metadata) + c.HandleMessage(c.ctx, peer, messageTS, senderID, chatID, content, mediaPaths, metadata, sender) } func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) { @@ -308,7 +375,11 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) { return } - if !c.IsAllowed(ev.User) { + if !c.IsAllowedSender(bus.SenderInfo{ + Platform: "slack", + PlatformID: ev.User, + CanonicalID: identity.BuildCanonicalID("slack", ev.User), + }) { logger.DebugCF("slack", "Mention rejected by allowlist", map[string]any{ "user_id": ev.User, }) @@ -316,6 +387,11 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) { } senderID := ev.User + mentionSender := bus.SenderInfo{ + Platform: "slack", + PlatformID: senderID, + CanonicalID: identity.BuildCanonicalID("slack", senderID), + } channelID := ev.Channel threadTS := ev.ThreadTimeStamp messageTS := ev.TimeStamp @@ -332,6 +408,18 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) { Timestamp: messageTS, }) + // Register typing stop (remove "eyes" reaction) with Manager + if rec := c.GetPlaceholderRecorder(); rec != nil { + capturedChannelID := channelID + capturedMessageTS := messageTS + rec.RecordTypingStop("slack", chatID, func() { + c.api.RemoveReaction("eyes", slack.ItemRef{ + Channel: capturedChannelID, + Timestamp: capturedMessageTS, + }) + }) + } + c.pendingAcks.Store(chatID, slackMessageRef{ ChannelID: channelID, Timestamp: messageTS, @@ -350,18 +438,18 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) { mentionPeerID = senderID } + mentionPeer := bus.Peer{Kind: mentionPeerKind, ID: mentionPeerID} + metadata := map[string]string{ "message_ts": messageTS, "channel_id": channelID, "thread_ts": threadTS, "platform": "slack", "is_mention": "true", - "peer_kind": mentionPeerKind, - "peer_id": mentionPeerID, "team_id": c.teamID, } - c.HandleMessage(senderID, chatID, content, nil, metadata) + c.HandleMessage(c.ctx, mentionPeer, messageTS, senderID, chatID, content, nil, metadata, mentionSender) } func (c *SlackChannel) handleSlashCommand(event socketmode.Event) { @@ -374,7 +462,12 @@ func (c *SlackChannel) handleSlashCommand(event socketmode.Event) { c.socketClient.Ack(*event.Request) } - if !c.IsAllowed(cmd.UserID) { + cmdSender := bus.SenderInfo{ + Platform: "slack", + PlatformID: cmd.UserID, + CanonicalID: identity.BuildCanonicalID("slack", cmd.UserID), + } + if !c.IsAllowedSender(cmdSender) { logger.DebugCF("slack", "Slash command rejected by allowlist", map[string]any{ "user_id": cmd.UserID, }) @@ -395,8 +488,6 @@ func (c *SlackChannel) handleSlashCommand(event socketmode.Event) { "platform": "slack", "is_command": "true", "trigger_id": cmd.TriggerID, - "peer_kind": "channel", - "peer_id": channelID, "team_id": c.teamID, } @@ -406,7 +497,17 @@ func (c *SlackChannel) handleSlashCommand(event socketmode.Event) { "text": utils.Truncate(content, 50), }) - c.HandleMessage(senderID, chatID, content, nil, metadata) + c.HandleMessage( + c.ctx, + bus.Peer{Kind: "channel", ID: channelID}, + "", + senderID, + chatID, + content, + nil, + metadata, + cmdSender, + ) } func (c *SlackChannel) downloadSlackFile(file slack.File) string { diff --git a/pkg/channels/slack_test.go b/pkg/channels/slack/slack_test.go similarity index 99% rename from pkg/channels/slack_test.go rename to pkg/channels/slack/slack_test.go index 3707c2703..30e0d2d73 100644 --- a/pkg/channels/slack_test.go +++ b/pkg/channels/slack/slack_test.go @@ -1,4 +1,4 @@ -package channels +package slack import ( "testing" diff --git a/pkg/channels/split.go b/pkg/channels/split.go new file mode 100644 index 000000000..27d76df1b --- /dev/null +++ b/pkg/channels/split.go @@ -0,0 +1,209 @@ +package channels + +import ( + "strings" +) + +// SplitMessage splits long messages into chunks, preserving code block integrity. +// The maxLen parameter is measured in runes (Unicode characters), not bytes. +// The function reserves a buffer (10% of maxLen, min 50) to leave room for closing code blocks, +// but may extend to maxLen when needed. +// Call SplitMessage with the full text content and the maximum allowed length of a single message; +// it returns a slice of message chunks that each respect maxLen and avoid splitting fenced code blocks. +func SplitMessage(content string, maxLen int) []string { + if maxLen <= 0 { + if content == "" { + return nil + } + return []string{content} + } + + runes := []rune(content) + var messages []string + + // Dynamic buffer: 10% of maxLen, but at least 50 chars if possible + codeBlockBuffer := maxLen / 10 + if codeBlockBuffer < 50 { + codeBlockBuffer = 50 + } + if codeBlockBuffer > maxLen/2 { + codeBlockBuffer = maxLen / 2 + } + + for len(runes) > 0 { + if len(runes) <= maxLen { + messages = append(messages, string(runes)) + break + } + + // Effective split point: maxLen minus buffer, to leave room for code blocks + effectiveLimit := maxLen - codeBlockBuffer + if effectiveLimit < maxLen/2 { + effectiveLimit = maxLen / 2 + } + + // Find natural split point within the effective limit + msgEnd := findLastNewlineRunes(runes[:effectiveLimit], 200) + if msgEnd <= 0 { + msgEnd = findLastSpaceRunes(runes[:effectiveLimit], 100) + } + if msgEnd <= 0 { + msgEnd = effectiveLimit + } + + // Check if this would end with an incomplete code block + candidate := runes[:msgEnd] + unclosedIdx := findLastUnclosedCodeBlockRunes(candidate) + + if unclosedIdx >= 0 { + // Message would end with incomplete code block + // Try to extend up to maxLen to include the closing ``` + if len(runes) > msgEnd { + closingIdx := findNextClosingCodeBlockRunes(runes, msgEnd) + if closingIdx > 0 && closingIdx <= maxLen { + // Extend to include the closing ``` + msgEnd = closingIdx + } else { + // Code block is too long to fit in one chunk or missing closing fence. + // Try to split inside by injecting closing and reopening fences. + fenceRunes := runes[unclosedIdx:] + headerEnd := findNewlineInRunes(fenceRunes) + var header string + if headerEnd == -1 { + header = strings.TrimSpace(string(runes[unclosedIdx : unclosedIdx+3])) + } else { + header = strings.TrimSpace(string(runes[unclosedIdx : unclosedIdx+headerEnd])) + } + headerEndIdx := unclosedIdx + len([]rune(header)) + if headerEnd != -1 { + headerEndIdx = unclosedIdx + headerEnd + } + + // If we have a reasonable amount of content after the header, split inside + if msgEnd > headerEndIdx+20 { + // Find a better split point closer to maxLen + innerLimit := maxLen - 5 // Leave room for "\n```" + betterEnd := findLastNewlineRunes(runes[:innerLimit], 200) + if betterEnd > headerEndIdx { + msgEnd = betterEnd + } else { + msgEnd = innerLimit + } + chunk := strings.TrimRight(string(runes[:msgEnd]), " \t\n\r") + "\n```" + messages = append(messages, chunk) + remaining := strings.TrimSpace(header + "\n" + string(runes[msgEnd:])) + runes = []rune(remaining) + continue + } + + // Otherwise, try to split before the code block starts + newEnd := findLastNewlineRunes(runes[:unclosedIdx], 200) + if newEnd <= 0 { + newEnd = findLastSpaceRunes(runes[:unclosedIdx], 100) + } + if newEnd > 0 { + msgEnd = newEnd + } else { + // If we can't split before, we MUST split inside (last resort) + if unclosedIdx > 20 { + msgEnd = unclosedIdx + } else { + msgEnd = maxLen - 5 + chunk := strings.TrimRight(string(runes[:msgEnd]), " \t\n\r") + "\n```" + messages = append(messages, chunk) + remaining := strings.TrimSpace(header + "\n" + string(runes[msgEnd:])) + runes = []rune(remaining) + continue + } + } + } + } + } + + if msgEnd <= 0 { + msgEnd = effectiveLimit + } + + messages = append(messages, string(runes[:msgEnd])) + remaining := strings.TrimSpace(string(runes[msgEnd:])) + runes = []rune(remaining) + } + + return messages +} + +// findLastUnclosedCodeBlockRunes finds the last opening ``` that doesn't have a closing ``` +// Returns the rune position of the opening ``` or -1 if all code blocks are complete +func findLastUnclosedCodeBlockRunes(runes []rune) int { + inCodeBlock := false + lastOpenIdx := -1 + + for i := 0; i < len(runes); i++ { + if i+2 < len(runes) && runes[i] == '`' && runes[i+1] == '`' && runes[i+2] == '`' { + // Toggle code block state on each fence + if !inCodeBlock { + // Entering a code block: record this opening fence + lastOpenIdx = i + } + inCodeBlock = !inCodeBlock + i += 2 + } + } + + if inCodeBlock { + return lastOpenIdx + } + return -1 +} + +// findNextClosingCodeBlockRunes finds the next closing ``` starting from a rune position +// Returns the rune position after the closing ``` or -1 if not found +func findNextClosingCodeBlockRunes(runes []rune, startIdx int) int { + for i := startIdx; i < len(runes); i++ { + if i+2 < len(runes) && runes[i] == '`' && runes[i+1] == '`' && runes[i+2] == '`' { + return i + 3 + } + } + return -1 +} + +// findNewlineInRunes finds the first newline character in a rune slice. +// Returns the rune index of the newline or -1 if not found. +func findNewlineInRunes(runes []rune) int { + for i, r := range runes { + if r == '\n' { + return i + } + } + return -1 +} + +// findLastNewlineRunes finds the last newline character within the last N runes +// Returns the rune position of the newline or -1 if not found +func findLastNewlineRunes(runes []rune, searchWindow int) int { + searchStart := len(runes) - searchWindow + if searchStart < 0 { + searchStart = 0 + } + for i := len(runes) - 1; i >= searchStart; i-- { + if runes[i] == '\n' { + return i + } + } + return -1 +} + +// findLastSpaceRunes finds the last space character within the last N runes +// Returns the rune position of the space or -1 if not found +func findLastSpaceRunes(runes []rune, searchWindow int) int { + searchStart := len(runes) - searchWindow + if searchStart < 0 { + searchStart = 0 + } + for i := len(runes) - 1; i >= searchStart; i-- { + if runes[i] == ' ' || runes[i] == '\t' { + return i + } + } + return -1 +} diff --git a/pkg/utils/message_test.go b/pkg/channels/split_test.go similarity index 67% rename from pkg/utils/message_test.go rename to pkg/channels/split_test.go index 338509437..d6356bdb9 100644 --- a/pkg/utils/message_test.go +++ b/pkg/channels/split_test.go @@ -1,4 +1,4 @@ -package utils +package channels import ( "strings" @@ -34,11 +34,15 @@ func TestSplitMessage(t *testing.T) { maxLen: 2000, expectChunks: 2, checkContent: func(t *testing.T, chunks []string) { - if len(chunks[0]) > 2000 { - t.Errorf("Chunk 0 too large: %d", len(chunks[0])) + if len([]rune(chunks[0])) > 2000 { + t.Errorf("Chunk 0 too large: %d runes", len([]rune(chunks[0]))) } - if len(chunks[0])+len(chunks[1]) != len(longText) { - t.Errorf("Total length mismatch. Got %d, want %d", len(chunks[0])+len(chunks[1]), len(longText)) + if len([]rune(chunks[0]))+len([]rune(chunks[1])) != len([]rune(longText)) { + t.Errorf( + "Total rune length mismatch. Got %d, want %d", + len([]rune(chunks[0]))+len([]rune(chunks[1])), + len([]rune(longText)), + ) } }, }, @@ -53,11 +57,11 @@ func TestSplitMessage(t *testing.T) { maxLen: 2000, expectChunks: 2, checkContent: func(t *testing.T, chunks []string) { - if len(chunks[0]) != 1750 { - t.Errorf("Expected chunk 0 to be 1750 length (split at newline), got %d", len(chunks[0])) + if len([]rune(chunks[0])) != 1750 { + t.Errorf("Expected chunk 0 to be 1750 runes (split at newline), got %d", len([]rune(chunks[0]))) } if chunks[1] != strings.Repeat("b", 300) { - t.Errorf("Chunk 1 content mismatch. Len: %d", len(chunks[1])) + t.Errorf("Chunk 1 content mismatch. Len: %d", len([]rune(chunks[1]))) } }, }, @@ -78,17 +82,39 @@ func TestSplitMessage(t *testing.T) { }, }, { - name: "Preserve Unicode characters", - content: strings.Repeat("\u4e16", 1000), // 3000 bytes + name: "Preserve Unicode characters (rune-aware)", + content: strings.Repeat("\u4e16", 2500), // 2500 runes, 7500 bytes maxLen: 2000, expectChunks: 2, checkContent: func(t *testing.T, chunks []string) { - // Just verify we didn't panic and got valid strings. - // Go strings are UTF-8, if we split mid-rune it would be bad, - // but standard slicing might do that. - // Let's assume standard behavior is acceptable or check if it produces invalid rune? - if !strings.Contains(chunks[0], "\u4e16") { - t.Error("Chunk should contain unicode characters") + // Verify chunks contain valid unicode and don't split mid-rune + for i, chunk := range chunks { + runeCount := len([]rune(chunk)) + if runeCount > 2000 { + t.Errorf("Chunk %d has %d runes, exceeds maxLen 2000", i, runeCount) + } + if !strings.Contains(chunk, "\u4e16") { + t.Errorf("Chunk %d should contain unicode characters", i) + } + } + // Verify total rune count is preserved + totalRunes := 0 + for _, chunk := range chunks { + totalRunes += len([]rune(chunk)) + } + if totalRunes != 2500 { + t.Errorf("Total rune count mismatch. Got %d, want 2500", totalRunes) + } + }, + }, + { + name: "Zero maxLen returns single chunk", + content: "Hello world", + maxLen: 0, + expectChunks: 1, + checkContent: func(t *testing.T, chunks []string) { + if chunks[0] != "Hello world" { + t.Errorf("Expected original content, got %q", chunks[0]) } }, }, @@ -145,7 +171,7 @@ func TestSplitMessage_CodeBlockIntegrity(t *testing.T) { } // First chunk should contain meaningful content - if len(chunks[0]) > 40 { - t.Errorf("First chunk exceeded maxLen: length %d", len(chunks[0])) + if len([]rune(chunks[0])) > 40 { + t.Errorf("First chunk exceeded maxLen: length %d runes", len([]rune(chunks[0]))) } } diff --git a/pkg/channels/telegram/init.go b/pkg/channels/telegram/init.go new file mode 100644 index 000000000..ac87bb805 --- /dev/null +++ b/pkg/channels/telegram/init.go @@ -0,0 +1,13 @@ +package telegram + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("telegram", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewTelegramChannel(cfg, b) + }) +} diff --git a/pkg/channels/telegram.go b/pkg/channels/telegram/telegram.go similarity index 57% rename from pkg/channels/telegram.go rename to pkg/channels/telegram/telegram.go index 5cd51e8bc..86bfc89f8 100644 --- a/pkg/channels/telegram.go +++ b/pkg/channels/telegram/telegram.go @@ -1,4 +1,4 @@ -package channels +package telegram import ( "context" @@ -7,8 +7,8 @@ import ( "net/url" "os" "regexp" + "strconv" "strings" - "sync" "time" "github.com/mymmrac/telego" @@ -17,31 +17,23 @@ import ( tu "github.com/mymmrac/telego/telegoutil" "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/utils" - "github.com/sipeed/picoclaw/pkg/voice" ) type TelegramChannel struct { - *BaseChannel - bot *telego.Bot - commands TelegramCommander - config *config.Config - chatIDs map[string]int64 - transcriber *voice.GroqTranscriber - placeholders sync.Map // chatID -> messageID - stopThinking sync.Map // chatID -> thinkingCancel -} - -type thinkingCancel struct { - fn context.CancelFunc -} - -func (c *thinkingCancel) Cancel() { - if c != nil && c.fn != nil { - c.fn() - } + *channels.BaseChannel + bot *telego.Bot + bh *telegohandler.BotHandler + commands TelegramCommander + config *config.Config + chatIDs map[string]int64 + ctx context.Context + cancel context.CancelFunc } func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChannel, error) { @@ -72,38 +64,44 @@ func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChann return nil, fmt.Errorf("failed to create telegram bot: %w", err) } - base := NewBaseChannel("telegram", telegramCfg, bus, telegramCfg.AllowFrom) + base := channels.NewBaseChannel( + "telegram", + telegramCfg, + bus, + telegramCfg.AllowFrom, + channels.WithMaxMessageLength(4096), + channels.WithGroupTrigger(telegramCfg.GroupTrigger), + channels.WithReasoningChannelID(telegramCfg.ReasoningChannelID), + ) return &TelegramChannel{ - BaseChannel: base, - commands: NewTelegramCommands(bot, cfg), - bot: bot, - config: cfg, - chatIDs: make(map[string]int64), - transcriber: nil, - placeholders: sync.Map{}, - stopThinking: sync.Map{}, + BaseChannel: base, + commands: NewTelegramCommands(bot, cfg), + bot: bot, + config: cfg, + chatIDs: make(map[string]int64), }, nil } -func (c *TelegramChannel) SetTranscriber(transcriber *voice.GroqTranscriber) { - c.transcriber = transcriber -} - func (c *TelegramChannel) Start(ctx context.Context) error { logger.InfoC("telegram", "Starting Telegram bot (polling mode)...") - updates, err := c.bot.UpdatesViaLongPolling(ctx, &telego.GetUpdatesParams{ + c.ctx, c.cancel = context.WithCancel(ctx) + + updates, err := c.bot.UpdatesViaLongPolling(c.ctx, &telego.GetUpdatesParams{ Timeout: 30, }) if err != nil { + c.cancel() return fmt.Errorf("failed to start long polling: %w", err) } bh, err := telegohandler.NewBotHandler(c.bot, updates) if err != nil { + c.cancel() return fmt.Errorf("failed to create bot handler: %w", err) } + c.bh = bh bh.HandleMessage(func(ctx *th.Context, message telego.Message) error { c.commands.Help(ctx, message) @@ -125,59 +123,46 @@ func (c *TelegramChannel) Start(ctx context.Context) error { return c.handleMessage(ctx, &message) }, th.AnyMessage()) - c.setRunning(true) + c.SetRunning(true) logger.InfoCF("telegram", "Telegram bot connected", map[string]any{ "username": c.bot.Username(), }) go bh.Start() - go func() { - <-ctx.Done() - bh.Stop() - }() - return nil } func (c *TelegramChannel) Stop(ctx context.Context) error { logger.InfoC("telegram", "Stopping Telegram bot...") - c.setRunning(false) + c.SetRunning(false) + + // Stop the bot handler + if c.bh != nil { + c.bh.Stop() + } + + // Cancel our context (stops long polling) + if c.cancel != nil { + c.cancel() + } + return nil } func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { if !c.IsRunning() { - return fmt.Errorf("telegram bot not running") + return channels.ErrNotRunning } chatID, err := parseChatID(msg.ChatID) if err != nil { - return fmt.Errorf("invalid chat ID: %w", err) - } - - // Stop thinking animation - if stop, ok := c.stopThinking.Load(msg.ChatID); ok { - if cf, ok := stop.(*thinkingCancel); ok && cf != nil { - cf.Cancel() - } - c.stopThinking.Delete(msg.ChatID) + return fmt.Errorf("invalid chat ID %s: %w", msg.ChatID, channels.ErrSendFailed) } htmlContent := markdownToTelegramHTML(msg.Content) - // Try to edit placeholder - if pID, ok := c.placeholders.Load(msg.ChatID); ok { - c.placeholders.Delete(msg.ChatID) - editMsg := tu.EditMessageText(tu.ID(chatID), pID.(int), htmlContent) - editMsg.ParseMode = telego.ModeHTML - - if _, err = c.bot.EditMessageText(ctx, editMsg); err == nil { - return nil - } - // Fallback to new message if edit fails - } - + // Typing/placeholder handled by Manager.preSend — just send the message tgMsg := tu.Message(tu.ID(chatID), htmlContent) tgMsg.ParseMode = telego.ModeHTML @@ -186,9 +171,112 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err "error": err.Error(), }) tgMsg.ParseMode = "" - _, err = c.bot.SendMessage(ctx, tgMsg) + if _, err = c.bot.SendMessage(ctx, tgMsg); err != nil { + return fmt.Errorf("telegram send: %w", channels.ErrTemporary) + } + } + + return nil +} + +// EditMessage implements channels.MessageEditor. +func (c *TelegramChannel) EditMessage(ctx context.Context, chatID string, messageID string, content string) error { + cid, err := parseChatID(chatID) + if err != nil { return err } + mid, err := strconv.Atoi(messageID) + if err != nil { + return err + } + htmlContent := markdownToTelegramHTML(content) + editMsg := tu.EditMessageText(tu.ID(cid), mid, htmlContent) + editMsg.ParseMode = telego.ModeHTML + _, err = c.bot.EditMessageText(ctx, editMsg) + return err +} + +// SendMedia implements the channels.MediaSender interface. +func (c *TelegramChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error { + if !c.IsRunning() { + return channels.ErrNotRunning + } + + chatID, err := parseChatID(msg.ChatID) + if err != nil { + return fmt.Errorf("invalid chat ID %s: %w", msg.ChatID, channels.ErrSendFailed) + } + + store := c.GetMediaStore() + if store == nil { + return fmt.Errorf("no media store available: %w", channels.ErrSendFailed) + } + + for _, part := range msg.Parts { + localPath, err := store.Resolve(part.Ref) + if err != nil { + logger.ErrorCF("telegram", "Failed to resolve media ref", map[string]any{ + "ref": part.Ref, + "error": err.Error(), + }) + continue + } + + file, err := os.Open(localPath) + if err != nil { + logger.ErrorCF("telegram", "Failed to open media file", map[string]any{ + "path": localPath, + "error": err.Error(), + }) + continue + } + + filename := part.Filename + if filename == "" { + filename = "file" + } + + switch part.Type { + case "image": + params := &telego.SendPhotoParams{ + ChatID: tu.ID(chatID), + Photo: telego.InputFile{File: file}, + Caption: part.Caption, + } + _, err = c.bot.SendPhoto(ctx, params) + case "audio": + params := &telego.SendAudioParams{ + ChatID: tu.ID(chatID), + Audio: telego.InputFile{File: file}, + Caption: part.Caption, + } + _, err = c.bot.SendAudio(ctx, params) + case "video": + params := &telego.SendVideoParams{ + ChatID: tu.ID(chatID), + Video: telego.InputFile{File: file}, + Caption: part.Caption, + } + _, err = c.bot.SendVideo(ctx, params) + default: // "file" or unknown types + params := &telego.SendDocumentParams{ + ChatID: tu.ID(chatID), + Document: telego.InputFile{File: file}, + Caption: part.Caption, + } + _, err = c.bot.SendDocument(ctx, params) + } + + file.Close() + + if err != nil { + logger.ErrorCF("telegram", "Failed to send media", map[string]any{ + "type": part.Type, + "error": err.Error(), + }) + return fmt.Errorf("telegram send media: %w", channels.ErrTemporary) + } + } return nil } @@ -203,37 +291,46 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes return fmt.Errorf("message sender (user) is nil") } - senderID := fmt.Sprintf("%d", user.ID) - if user.Username != "" { - senderID = fmt.Sprintf("%d|%s", user.ID, user.Username) + platformID := fmt.Sprintf("%d", user.ID) + sender := bus.SenderInfo{ + Platform: "telegram", + PlatformID: platformID, + CanonicalID: identity.BuildCanonicalID("telegram", platformID), + Username: user.Username, + DisplayName: user.FirstName, } // check allowlist to avoid downloading attachments for rejected users - if !c.IsAllowed(senderID) { + if !c.IsAllowedSender(sender) { logger.DebugCF("telegram", "Message rejected by allowlist", map[string]any{ - "user_id": senderID, + "user_id": platformID, }) return nil } chatID := message.Chat.ID - c.chatIDs[senderID] = chatID + c.chatIDs[platformID] = chatID content := "" mediaPaths := []string{} - localFiles := []string{} // track local files that need cleanup - - // ensure temp files are cleaned up when function returns - defer func() { - for _, file := range localFiles { - if err := os.Remove(file); err != nil { - logger.DebugCF("telegram", "Failed to cleanup temp file", map[string]any{ - "file": file, - "error": err.Error(), - }) + + chatIDStr := fmt.Sprintf("%d", chatID) + messageIDStr := fmt.Sprintf("%d", message.MessageID) + scope := channels.BuildMediaScope("telegram", chatIDStr, messageIDStr) + + // Helper to register a local file with the media store + storeMedia := func(localPath, filename string) string { + if store := c.GetMediaStore(); store != nil { + ref, err := store.Store(localPath, media.MediaMeta{ + Filename: filename, + Source: "telegram", + }, scope) + if err == nil { + return ref } } - }() + return localPath // fallback: use raw path + } if message.Text != "" { content += message.Text @@ -250,8 +347,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes photo := message.Photo[len(message.Photo)-1] photoPath := c.downloadPhoto(ctx, photo.FileID) if photoPath != "" { - localFiles = append(localFiles, photoPath) - mediaPaths = append(mediaPaths, photoPath) + mediaPaths = append(mediaPaths, storeMedia(photoPath, "photo.jpg")) if content != "" { content += "\n" } @@ -262,43 +358,19 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes if message.Voice != nil { voicePath := c.downloadFile(ctx, message.Voice.FileID, ".ogg") if voicePath != "" { - localFiles = append(localFiles, voicePath) - mediaPaths = append(mediaPaths, voicePath) - - transcribedText := "" - if c.transcriber != nil && c.transcriber.IsAvailable() { - transcriberCtx, cancel := context.WithTimeout(ctx, 30*time.Second) - defer cancel() - - result, err := c.transcriber.Transcribe(transcriberCtx, voicePath) - if err != nil { - logger.ErrorCF("telegram", "Voice transcription failed", map[string]any{ - "error": err.Error(), - "path": voicePath, - }) - transcribedText = "[voice (transcription failed)]" - } else { - transcribedText = fmt.Sprintf("[voice transcription: %s]", result.Text) - logger.InfoCF("telegram", "Voice transcribed successfully", map[string]any{ - "text": result.Text, - }) - } - } else { - transcribedText = "[voice]" - } + mediaPaths = append(mediaPaths, storeMedia(voicePath, "voice.ogg")) if content != "" { content += "\n" } - content += transcribedText + content += "[voice]" } } if message.Audio != nil { audioPath := c.downloadFile(ctx, message.Audio.FileID, ".mp3") if audioPath != "" { - localFiles = append(localFiles, audioPath) - mediaPaths = append(mediaPaths, audioPath) + mediaPaths = append(mediaPaths, storeMedia(audioPath, "audio.mp3")) if content != "" { content += "\n" } @@ -309,8 +381,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes if message.Document != nil { docPath := c.downloadFile(ctx, message.Document.FileID, "") if docPath != "" { - localFiles = append(localFiles, docPath) - mediaPaths = append(mediaPaths, docPath) + mediaPaths = append(mediaPaths, storeMedia(docPath, "document")) if content != "" { content += "\n" } @@ -322,8 +393,21 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes content = "[empty message]" } + // In group chats, apply unified group trigger filtering + if message.Chat.Type != "private" { + isMentioned := c.isBotMentioned(message) + if isMentioned { + content = c.stripBotMention(content) + } + respond, cleaned := c.ShouldRespondInGroup(isMentioned, content) + if !respond { + return nil + } + content = cleaned + } + logger.DebugCF("telegram", "Received message", map[string]any{ - "sender_id": senderID, + "sender_id": sender.CanonicalID, "chat_id": fmt.Sprintf("%d", chatID), "preview": utils.Truncate(content, 50), }) @@ -336,22 +420,21 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes }) } - // Stop any previous thinking animation - chatIDStr := fmt.Sprintf("%d", chatID) - if prevStop, ok := c.stopThinking.Load(chatIDStr); ok { - if cf, ok := prevStop.(*thinkingCancel); ok && cf != nil { - cf.Cancel() - } - } - - // Create cancel function for thinking state + // Create cancel function for thinking state and register with Manager _, thinkCancel := context.WithTimeout(ctx, 5*time.Minute) - c.stopThinking.Store(chatIDStr, &thinkingCancel{fn: thinkCancel}) + if rec := c.GetPlaceholderRecorder(); rec != nil { + rec.RecordTypingStop("telegram", chatIDStr, thinkCancel) + } else { + // No recorder — cancel immediately to avoid context leak + thinkCancel() + } pMsg, err := c.bot.SendMessage(ctx, tu.Message(tu.ID(chatID), "Thinking... 💭")) if err == nil { pID := pMsg.MessageID - c.placeholders.Store(chatIDStr, pID) + if rec := c.GetPlaceholderRecorder(); rec != nil { + rec.RecordPlaceholder("telegram", chatIDStr, fmt.Sprintf("%d", pID)) + } } peerKind := "direct" @@ -361,17 +444,26 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes peerID = fmt.Sprintf("%d", chatID) } + peer := bus.Peer{Kind: peerKind, ID: peerID} + messageID := fmt.Sprintf("%d", message.MessageID) + metadata := map[string]string{ - "message_id": fmt.Sprintf("%d", message.MessageID), "user_id": fmt.Sprintf("%d", user.ID), "username": user.Username, "first_name": user.FirstName, "is_group": fmt.Sprintf("%t", message.Chat.Type != "private"), - "peer_kind": peerKind, - "peer_id": peerID, } - c.HandleMessage(fmt.Sprintf("%d", user.ID), fmt.Sprintf("%d", chatID), content, mediaPaths, metadata) + c.HandleMessage(c.ctx, + peer, + messageID, + platformID, + fmt.Sprintf("%d", chatID), + content, + mediaPaths, + metadata, + sender, + ) return nil } @@ -527,3 +619,52 @@ func escapeHTML(text string) string { text = strings.ReplaceAll(text, ">", ">") return text } + +// isBotMentioned checks if the bot is mentioned in the message via entities. +func (c *TelegramChannel) isBotMentioned(message *telego.Message) bool { + botUsername := c.bot.Username() + if botUsername == "" { + return false + } + + entities := message.Entities + if entities == nil { + entities = message.CaptionEntities + } + + for _, entity := range entities { + if entity.Type == "mention" { + // Extract the mention text from the message + text := message.Text + if text == "" { + text = message.Caption + } + runes := []rune(text) + end := entity.Offset + entity.Length + if end <= len(runes) { + mention := string(runes[entity.Offset:end]) + if strings.EqualFold(mention, "@"+botUsername) { + return true + } + } + } + if entity.Type == "text_mention" && entity.User != nil { + if entity.User.Username == botUsername { + return true + } + } + } + return false +} + +// stripBotMention removes the @bot mention from the content. +func (c *TelegramChannel) stripBotMention(content string) string { + botUsername := c.bot.Username() + if botUsername == "" { + return content + } + // Case-insensitive replacement + re := regexp.MustCompile(`(?i)@` + regexp.QuoteMeta(botUsername)) + content = re.ReplaceAllString(content, "") + return strings.TrimSpace(content) +} diff --git a/pkg/channels/telegram_commands.go b/pkg/channels/telegram/telegram_commands.go similarity index 99% rename from pkg/channels/telegram_commands.go rename to pkg/channels/telegram/telegram_commands.go index a084b641b..f17912260 100644 --- a/pkg/channels/telegram_commands.go +++ b/pkg/channels/telegram/telegram_commands.go @@ -1,4 +1,4 @@ -package channels +package telegram import ( "context" diff --git a/pkg/channels/webhook.go b/pkg/channels/webhook.go new file mode 100644 index 000000000..3cf27baf6 --- /dev/null +++ b/pkg/channels/webhook.go @@ -0,0 +1,20 @@ +package channels + +import "net/http" + +// WebhookHandler is an optional interface for channels that receive messages +// via HTTP webhooks. Manager discovers channels implementing this interface +// and registers them on the shared HTTP server. +type WebhookHandler interface { + // WebhookPath returns the path to mount this handler on the shared server. + // Examples: "/webhook/line", "/webhook/wecom" + WebhookPath() string + http.Handler // ServeHTTP(w http.ResponseWriter, r *http.Request) +} + +// HealthChecker is an optional interface for channels that expose +// a health check endpoint on the shared HTTP server. +type HealthChecker interface { + HealthPath() string + HealthHandler(w http.ResponseWriter, r *http.Request) +} diff --git a/pkg/channels/wecom_app.go b/pkg/channels/wecom/app.go similarity index 68% rename from pkg/channels/wecom_app.go rename to pkg/channels/wecom/app.go index 715c48707..409aa8e96 100644 --- a/pkg/channels/wecom_app.go +++ b/pkg/channels/wecom/app.go @@ -1,8 +1,4 @@ -// PicoClaw - Ultra-lightweight personal AI agent -// WeCom App (企业微信自建应用) channel implementation -// Supports receiving messages via webhook callback and sending messages proactively - -package channels +package wecom import ( "bytes" @@ -11,14 +7,19 @@ import ( "encoding/xml" "fmt" "io" + "mime/multipart" "net/http" "net/url" + "os" + "path/filepath" "strings" "sync" "time" "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/utils" ) @@ -29,9 +30,8 @@ const ( // WeComAppChannel implements the Channel interface for WeCom App (企业微信自建应用) type WeComAppChannel struct { - *BaseChannel + *channels.BaseChannel config config.WeComAppConfig - server *http.Server accessToken string tokenExpiry time.Time tokenMu sync.RWMutex @@ -123,7 +123,11 @@ func NewWeComAppChannel(cfg config.WeComAppConfig, messageBus *bus.MessageBus) ( return nil, fmt.Errorf("wecom_app corp_id, corp_secret and agent_id are required") } - base := NewBaseChannel("wecom_app", cfg, messageBus, cfg.AllowFrom) + base := channels.NewBaseChannel("wecom_app", cfg, messageBus, cfg.AllowFrom, + channels.WithMaxMessageLength(2048), + channels.WithGroupTrigger(cfg.GroupTrigger), + channels.WithReasoningChannelID(cfg.ReasoningChannelID), + ) return &WeComAppChannel{ BaseChannel: base, @@ -137,7 +141,7 @@ func (c *WeComAppChannel) Name() string { return "wecom_app" } -// Start initializes the WeCom App channel with HTTP webhook server +// Start initializes the WeCom App channel func (c *WeComAppChannel) Start(ctx context.Context) error { logger.InfoC("wecom_app", "Starting WeCom App channel...") @@ -153,37 +157,8 @@ func (c *WeComAppChannel) Start(ctx context.Context) error { // Start token refresh goroutine go c.tokenRefreshLoop() - // Setup HTTP server for webhook - mux := http.NewServeMux() - webhookPath := c.config.WebhookPath - if webhookPath == "" { - webhookPath = "/webhook/wecom-app" - } - mux.HandleFunc(webhookPath, c.handleWebhook) - - // Health check endpoint - mux.HandleFunc("/health/wecom-app", c.handleHealth) - - addr := fmt.Sprintf("%s:%d", c.config.WebhookHost, c.config.WebhookPort) - c.server = &http.Server{ - Addr: addr, - Handler: mux, - } - - c.setRunning(true) - logger.InfoCF("wecom_app", "WeCom App channel started", map[string]any{ - "address": addr, - "path": webhookPath, - }) - - // Start server in goroutine - go func() { - if err := c.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { - logger.ErrorCF("wecom_app", "HTTP server error", map[string]any{ - "error": err.Error(), - }) - } - }() + c.SetRunning(true) + logger.InfoC("wecom_app", "WeCom App channel started") return nil } @@ -196,13 +171,7 @@ func (c *WeComAppChannel) Stop(ctx context.Context) error { c.cancel() } - if c.server != nil { - shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second) - defer cancel() - c.server.Shutdown(shutdownCtx) - } - - c.setRunning(false) + c.SetRunning(false) logger.InfoC("wecom_app", "WeCom App channel stopped") return nil } @@ -210,7 +179,7 @@ func (c *WeComAppChannel) Stop(ctx context.Context) error { // Send sends a message to WeCom user proactively using access token func (c *WeComAppChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { if !c.IsRunning() { - return fmt.Errorf("wecom_app channel not running") + return channels.ErrNotRunning } accessToken := c.getAccessToken() @@ -226,6 +195,220 @@ func (c *WeComAppChannel) Send(ctx context.Context, msg bus.OutboundMessage) err return c.sendTextMessage(ctx, accessToken, msg.ChatID, msg.Content) } +// SendMedia implements the channels.MediaSender interface. +func (c *WeComAppChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error { + if !c.IsRunning() { + return channels.ErrNotRunning + } + + accessToken := c.getAccessToken() + if accessToken == "" { + return fmt.Errorf("no valid access token available: %w", channels.ErrTemporary) + } + + store := c.GetMediaStore() + if store == nil { + return fmt.Errorf("no media store available: %w", channels.ErrSendFailed) + } + + for _, part := range msg.Parts { + localPath, err := store.Resolve(part.Ref) + if err != nil { + logger.ErrorCF("wecom_app", "Failed to resolve media ref", map[string]any{ + "ref": part.Ref, + "error": err.Error(), + }) + continue + } + + // Map part type to WeCom media type + mediaType := "file" + switch part.Type { + case "image": + mediaType = "image" + case "audio": + mediaType = "voice" + case "video": + mediaType = "video" + default: + mediaType = "file" + } + + // Upload media to get media_id + mediaID, err := c.uploadMedia(ctx, accessToken, mediaType, localPath) + if err != nil { + logger.ErrorCF("wecom_app", "Failed to upload media", map[string]any{ + "type": mediaType, + "error": err.Error(), + }) + // Fallback: send caption as text + if part.Caption != "" { + _ = c.sendTextMessage(ctx, accessToken, msg.ChatID, part.Caption) + } + continue + } + + // Send media message using the media_id + if mediaType == "image" { + err = c.sendImageMessage(ctx, accessToken, msg.ChatID, mediaID) + } else { + // For non-image types, send as text fallback with caption + caption := part.Caption + if caption == "" { + caption = fmt.Sprintf("[%s: %s]", part.Type, part.Filename) + } + err = c.sendTextMessage(ctx, accessToken, msg.ChatID, caption) + } + + if err != nil { + return err + } + } + + return nil +} + +// uploadMedia uploads a local file to WeCom temporary media storage. +func (c *WeComAppChannel) uploadMedia(ctx context.Context, accessToken, mediaType, localPath string) (string, error) { + apiURL := fmt.Sprintf("%s/cgi-bin/media/upload?access_token=%s&type=%s", + wecomAPIBase, url.QueryEscape(accessToken), url.QueryEscape(mediaType)) + + file, err := os.Open(localPath) + if err != nil { + return "", fmt.Errorf("failed to open file: %w", err) + } + defer file.Close() + + body := &bytes.Buffer{} + writer := multipart.NewWriter(body) + + filename := filepath.Base(localPath) + formFile, err := writer.CreateFormFile("media", filename) + if err != nil { + return "", fmt.Errorf("failed to create form file: %w", err) + } + + if _, err = io.Copy(formFile, file); err != nil { + return "", fmt.Errorf("failed to copy file content: %w", err) + } + writer.Close() + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, body) + if err != nil { + return "", fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", writer.FormDataContentType()) + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return "", channels.ClassifyNetError(err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + return "", channels.ClassifySendError(resp.StatusCode, fmt.Errorf("wecom upload error: %s", string(respBody))) + } + + var result struct { + ErrCode int `json:"errcode"` + ErrMsg string `json:"errmsg"` + MediaID string `json:"media_id"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return "", fmt.Errorf("failed to parse upload response: %w", err) + } + + if result.ErrCode != 0 { + return "", fmt.Errorf("upload API error: %s (code: %d)", result.ErrMsg, result.ErrCode) + } + + return result.MediaID, nil +} + +// sendImageMessage sends an image message using a media_id. +func (c *WeComAppChannel) sendImageMessage(ctx context.Context, accessToken, userID, mediaID string) error { + apiURL := fmt.Sprintf("%s/cgi-bin/message/send?access_token=%s", wecomAPIBase, accessToken) + + msg := WeComImageMessage{ + ToUser: userID, + MsgType: "image", + AgentID: c.config.AgentID, + } + msg.Image.MediaID = mediaID + + jsonData, err := json.Marshal(msg) + if err != nil { + return fmt.Errorf("failed to marshal message: %w", err) + } + + timeout := c.config.ReplyTimeout + if timeout <= 0 { + timeout = 5 + } + + reqCtx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, apiURL, bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{Timeout: time.Duration(timeout) * time.Second} + resp, err := client.Do(req) + if err != nil { + return channels.ClassifyNetError(err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + return channels.ClassifySendError(resp.StatusCode, fmt.Errorf("wecom_app API error: %s", string(respBody))) + } + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("failed to read response: %w", err) + } + + var sendResp WeComSendMessageResponse + if err := json.Unmarshal(respBody, &sendResp); err != nil { + return fmt.Errorf("failed to parse response: %w", err) + } + + if sendResp.ErrCode != 0 { + return fmt.Errorf("API error: %s (code: %d)", sendResp.ErrMsg, sendResp.ErrCode) + } + + return nil +} + +// WebhookPath returns the path for registering on the shared HTTP server. +func (c *WeComAppChannel) WebhookPath() string { + if c.config.WebhookPath != "" { + return c.config.WebhookPath + } + return "/webhook/wecom-app" +} + +// ServeHTTP implements http.Handler for the shared HTTP server. +func (c *WeComAppChannel) ServeHTTP(w http.ResponseWriter, r *http.Request) { + c.handleWebhook(w, r) +} + +// HealthPath returns the health check endpoint path. +func (c *WeComAppChannel) HealthPath() string { + return "/health/wecom-app" +} + +// HealthHandler handles health check requests. +func (c *WeComAppChannel) HealthHandler(w http.ResponseWriter, r *http.Request) { + c.handleHealth(w, r) +} + // handleWebhook handles incoming webhook requests from WeCom func (c *WeComAppChannel) handleWebhook(w http.ResponseWriter, r *http.Request) { ctx := r.Context() @@ -279,7 +462,7 @@ func (c *WeComAppChannel) handleVerification(ctx context.Context, w http.Respons } // Verify signature - if !WeComVerifySignature(c.config.Token, msgSignature, timestamp, nonce, echostr) { + if !verifySignature(c.config.Token, msgSignature, timestamp, nonce, echostr) { logger.WarnCF("wecom_app", "Signature verification failed", map[string]any{ "token": c.config.Token, "msg_signature": msgSignature, @@ -298,7 +481,7 @@ func (c *WeComAppChannel) handleVerification(ctx context.Context, w http.Respons "encoding_aes_key": c.config.EncodingAESKey, "corp_id": c.config.CorpID, }) - decryptedEchoStr, err := WeComDecryptMessageWithVerify(echostr, c.config.EncodingAESKey, c.config.CorpID) + decryptedEchoStr, err := decryptMessageWithVerify(echostr, c.config.EncodingAESKey, c.config.CorpID) if err != nil { logger.ErrorCF("wecom_app", "Failed to decrypt echostr", map[string]any{ "error": err.Error(), @@ -357,7 +540,7 @@ func (c *WeComAppChannel) handleMessageCallback(ctx context.Context, w http.Resp } // Verify signature - if !WeComVerifySignature(c.config.Token, msgSignature, timestamp, nonce, encryptedMsg.Encrypt) { + if !verifySignature(c.config.Token, msgSignature, timestamp, nonce, encryptedMsg.Encrypt) { logger.WarnC("wecom_app", "Message signature verification failed") http.Error(w, "Invalid signature", http.StatusForbidden) return @@ -365,7 +548,7 @@ func (c *WeComAppChannel) handleMessageCallback(ctx context.Context, w http.Resp // Decrypt message with CorpID verification // For WeCom App (自建应用), receiveid should be corp_id - decryptedMsg, err := WeComDecryptMessageWithVerify(encryptedMsg.Encrypt, c.config.EncodingAESKey, c.config.CorpID) + decryptedMsg, err := decryptMessageWithVerify(encryptedMsg.Encrypt, c.config.EncodingAESKey, c.config.CorpID) if err != nil { logger.ErrorCF("wecom_app", "Failed to decrypt message", map[string]any{ "error": err.Error(), @@ -428,6 +611,9 @@ func (c *WeComAppChannel) processMessage(ctx context.Context, msg WeComXMLMessag // Build metadata // WeCom App only supports direct messages (private chat) + peer := bus.Peer{Kind: "direct", ID: senderID} + messageID := fmt.Sprintf("%d", msg.MsgId) + metadata := map[string]string{ "msg_type": msg.MsgType, "msg_id": fmt.Sprintf("%d", msg.MsgId), @@ -435,8 +621,6 @@ func (c *WeComAppChannel) processMessage(ctx context.Context, msg WeComXMLMessag "platform": "wecom_app", "media_id": msg.MediaId, "create_time": fmt.Sprintf("%d", msg.CreateTime), - "peer_kind": "direct", - "peer_id": senderID, } content := msg.Content @@ -447,8 +631,15 @@ func (c *WeComAppChannel) processMessage(ctx context.Context, msg WeComXMLMessag "preview": utils.Truncate(content, 50), }) + // Build sender info + appSender := bus.SenderInfo{ + Platform: "wecom", + PlatformID: senderID, + CanonicalID: identity.BuildCanonicalID("wecom", senderID), + } + // Handle the message through the base channel - c.HandleMessage(senderID, chatID, content, nil, metadata) + c.HandleMessage(ctx, peer, messageID, senderID, chatID, content, nil, metadata, appSender) } // tokenRefreshLoop periodically refreshes the access token @@ -550,10 +741,15 @@ func (c *WeComAppChannel) sendTextMessage(ctx context.Context, accessToken, user client := &http.Client{Timeout: time.Duration(timeout) * time.Second} resp, err := client.Do(req) if err != nil { - return fmt.Errorf("failed to send message: %w", err) + return channels.ClassifyNetError(err) } defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return channels.ClassifySendError(resp.StatusCode, fmt.Errorf("wecom_app API error: %s", string(body))) + } + body, err := io.ReadAll(resp.Body) if err != nil { return fmt.Errorf("failed to read response: %w", err) @@ -605,10 +801,15 @@ func (c *WeComAppChannel) sendMarkdownMessage(ctx context.Context, accessToken, client := &http.Client{Timeout: time.Duration(timeout) * time.Second} resp, err := client.Do(req) if err != nil { - return fmt.Errorf("failed to send message: %w", err) + return channels.ClassifyNetError(err) } defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return channels.ClassifySendError(resp.StatusCode, fmt.Errorf("wecom_app API error: %s", string(body))) + } + body, err := io.ReadAll(resp.Body) if err != nil { return fmt.Errorf("failed to read response: %w", err) diff --git a/pkg/channels/wecom_app_test.go b/pkg/channels/wecom/app_test.go similarity index 95% rename from pkg/channels/wecom_app_test.go rename to pkg/channels/wecom/app_test.go index abf15c52b..5420949de 100644 --- a/pkg/channels/wecom_app_test.go +++ b/pkg/channels/wecom/app_test.go @@ -1,7 +1,4 @@ -// PicoClaw - Ultra-lightweight personal AI agent -// WeCom App (企业微信自建应用) channel tests - -package channels +package wecom import ( "bytes" @@ -197,7 +194,7 @@ func TestWeComAppVerifySignature(t *testing.T) { msgEncrypt := "test_message" expectedSig := generateSignatureApp("test_token", timestamp, nonce, msgEncrypt) - if !WeComVerifySignature(ch.config.Token, expectedSig, timestamp, nonce, msgEncrypt) { + if !verifySignature(ch.config.Token, expectedSig, timestamp, nonce, msgEncrypt) { t.Error("valid signature should pass verification") } }) @@ -207,7 +204,7 @@ func TestWeComAppVerifySignature(t *testing.T) { nonce := "test_nonce" msgEncrypt := "test_message" - if WeComVerifySignature(ch.config.Token, "invalid_sig", timestamp, nonce, msgEncrypt) { + if verifySignature(ch.config.Token, "invalid_sig", timestamp, nonce, msgEncrypt) { t.Error("invalid signature should fail verification") } }) @@ -221,7 +218,7 @@ func TestWeComAppVerifySignature(t *testing.T) { } chEmpty, _ := NewWeComAppChannel(cfgEmpty, msgBus) - if !WeComVerifySignature(chEmpty.config.Token, "any_sig", "any_ts", "any_nonce", "any_msg") { + if !verifySignature(chEmpty.config.Token, "any_sig", "any_ts", "any_nonce", "any_msg") { t.Error("empty token should skip verification and return true") } }) @@ -243,7 +240,7 @@ func TestWeComAppDecryptMessage(t *testing.T) { plainText := "hello world" encoded := base64.StdEncoding.EncodeToString([]byte(plainText)) - result, err := WeComDecryptMessage(encoded, ch.config.EncodingAESKey) + result, err := decryptMessage(encoded, ch.config.EncodingAESKey) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -268,7 +265,7 @@ func TestWeComAppDecryptMessage(t *testing.T) { t.Fatalf("failed to encrypt test message: %v", err) } - result, err := WeComDecryptMessage(encrypted, ch.config.EncodingAESKey) + result, err := decryptMessage(encrypted, ch.config.EncodingAESKey) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -286,7 +283,7 @@ func TestWeComAppDecryptMessage(t *testing.T) { } ch, _ := NewWeComAppChannel(cfg, msgBus) - _, err := WeComDecryptMessage("invalid_base64!!!", ch.config.EncodingAESKey) + _, err := decryptMessage("invalid_base64!!!", ch.config.EncodingAESKey) if err == nil { t.Error("expected error for invalid base64, got nil") } @@ -301,7 +298,7 @@ func TestWeComAppDecryptMessage(t *testing.T) { } ch, _ := NewWeComAppChannel(cfg, msgBus) - _, err := WeComDecryptMessage(base64.StdEncoding.EncodeToString([]byte("test")), ch.config.EncodingAESKey) + _, err := decryptMessage(base64.StdEncoding.EncodeToString([]byte("test")), ch.config.EncodingAESKey) if err == nil { t.Error("expected error for invalid AES key, got nil") } @@ -319,7 +316,7 @@ func TestWeComAppDecryptMessage(t *testing.T) { // Encrypt a very short message that results in ciphertext less than block size shortData := make([]byte, 8) - _, err := WeComDecryptMessage(base64.StdEncoding.EncodeToString(shortData), ch.config.EncodingAESKey) + _, err := decryptMessage(base64.StdEncoding.EncodeToString(shortData), ch.config.EncodingAESKey) if err == nil { t.Error("expected error for short ciphertext, got nil") } @@ -361,7 +358,7 @@ func TestWeComAppPKCS7Unpad(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result, err := pkcs7UnpadWeCom(tt.input) + result, err := pkcs7Unpad(tt.input) if tt.expected == nil { // This case should return an error if err == nil { @@ -852,6 +849,28 @@ func TestWeComAppMessageStructures(t *testing.T) { } }) + t.Run("WeComImageMessage structure", func(t *testing.T) { + msg := WeComImageMessage{ + ToUser: "user123", + MsgType: "image", + AgentID: 1000002, + } + msg.Image.MediaID = "media_123456" + + if msg.Image.MediaID != "media_123456" { + t.Errorf("Image.MediaID = %q, want %q", msg.Image.MediaID, "media_123456") + } + if msg.ToUser != "user123" { + t.Errorf("ToUser = %q, want %q", msg.ToUser, "user123") + } + if msg.MsgType != "image" { + t.Errorf("MsgType = %q, want %q", msg.MsgType, "image") + } + if msg.AgentID != 1000002 { + t.Errorf("AgentID = %d, want %d", msg.AgentID, 1000002) + } + }) + t.Run("WeComAccessTokenResponse structure", func(t *testing.T) { jsonData := `{ "errcode": 0, diff --git a/pkg/channels/wecom.go b/pkg/channels/wecom/bot.go similarity index 66% rename from pkg/channels/wecom.go rename to pkg/channels/wecom/bot.go index f8daf89de..4c576b84b 100644 --- a/pkg/channels/wecom.go +++ b/pkg/channels/wecom/bot.go @@ -1,29 +1,21 @@ -// PicoClaw - Ultra-lightweight personal AI agent -// WeCom Bot (企业微信智能机器人) channel implementation -// Uses webhook callback mode for receiving messages and webhook API for sending replies - -package channels +package wecom import ( "bytes" "context" - "crypto/aes" - "crypto/cipher" - "crypto/sha1" - "encoding/base64" - "encoding/binary" "encoding/json" "encoding/xml" "fmt" "io" "net/http" - "sort" "strings" "sync" "time" "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/utils" ) @@ -31,9 +23,8 @@ import ( // WeComBotChannel implements the Channel interface for WeCom Bot (企业微信智能机器人) // Uses webhook callback mode - simpler than WeCom App but only supports passive replies type WeComBotChannel struct { - *BaseChannel + *channels.BaseChannel config config.WeComConfig - server *http.Server ctx context.Context cancel context.CancelFunc processedMsgs map[string]bool // Message deduplication: msg_id -> processed @@ -96,7 +87,11 @@ func NewWeComBotChannel(cfg config.WeComConfig, messageBus *bus.MessageBus) (*We return nil, fmt.Errorf("wecom token and webhook_url are required") } - base := NewBaseChannel("wecom", cfg, messageBus, cfg.AllowFrom) + base := channels.NewBaseChannel("wecom", cfg, messageBus, cfg.AllowFrom, + channels.WithMaxMessageLength(2048), + channels.WithGroupTrigger(cfg.GroupTrigger), + channels.WithReasoningChannelID(cfg.ReasoningChannelID), + ) return &WeComBotChannel{ BaseChannel: base, @@ -110,43 +105,14 @@ func (c *WeComBotChannel) Name() string { return "wecom" } -// Start initializes the WeCom Bot channel with HTTP webhook server +// Start initializes the WeCom Bot channel func (c *WeComBotChannel) Start(ctx context.Context) error { logger.InfoC("wecom", "Starting WeCom Bot channel...") c.ctx, c.cancel = context.WithCancel(ctx) - // Setup HTTP server for webhook - mux := http.NewServeMux() - webhookPath := c.config.WebhookPath - if webhookPath == "" { - webhookPath = "/webhook/wecom" - } - mux.HandleFunc(webhookPath, c.handleWebhook) - - // Health check endpoint - mux.HandleFunc("/health/wecom", c.handleHealth) - - addr := fmt.Sprintf("%s:%d", c.config.WebhookHost, c.config.WebhookPort) - c.server = &http.Server{ - Addr: addr, - Handler: mux, - } - - c.setRunning(true) - logger.InfoCF("wecom", "WeCom Bot channel started", map[string]any{ - "address": addr, - "path": webhookPath, - }) - - // Start server in goroutine - go func() { - if err := c.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { - logger.ErrorCF("wecom", "HTTP server error", map[string]any{ - "error": err.Error(), - }) - } - }() + c.SetRunning(true) + logger.InfoC("wecom", "WeCom Bot channel started") return nil } @@ -159,13 +125,7 @@ func (c *WeComBotChannel) Stop(ctx context.Context) error { c.cancel() } - if c.server != nil { - shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second) - defer cancel() - c.server.Shutdown(shutdownCtx) - } - - c.setRunning(false) + c.SetRunning(false) logger.InfoC("wecom", "WeCom Bot channel stopped") return nil } @@ -175,7 +135,7 @@ func (c *WeComBotChannel) Stop(ctx context.Context) error { // For delayed responses, we use the webhook URL func (c *WeComBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { if !c.IsRunning() { - return fmt.Errorf("wecom channel not running") + return channels.ErrNotRunning } logger.DebugCF("wecom", "Sending message via webhook", map[string]any{ @@ -186,6 +146,29 @@ func (c *WeComBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) err return c.sendWebhookReply(ctx, msg.ChatID, msg.Content) } +// WebhookPath returns the path for registering on the shared HTTP server. +func (c *WeComBotChannel) WebhookPath() string { + if c.config.WebhookPath != "" { + return c.config.WebhookPath + } + return "/webhook/wecom" +} + +// ServeHTTP implements http.Handler for the shared HTTP server. +func (c *WeComBotChannel) ServeHTTP(w http.ResponseWriter, r *http.Request) { + c.handleWebhook(w, r) +} + +// HealthPath returns the health check endpoint path. +func (c *WeComBotChannel) HealthPath() string { + return "/health/wecom" +} + +// HealthHandler handles health check requests. +func (c *WeComBotChannel) HealthHandler(w http.ResponseWriter, r *http.Request) { + c.handleHealth(w, r) +} + // handleWebhook handles incoming webhook requests from WeCom func (c *WeComBotChannel) handleWebhook(w http.ResponseWriter, r *http.Request) { ctx := r.Context() @@ -219,7 +202,7 @@ func (c *WeComBotChannel) handleVerification(ctx context.Context, w http.Respons } // Verify signature - if !WeComVerifySignature(c.config.Token, msgSignature, timestamp, nonce, echostr) { + if !verifySignature(c.config.Token, msgSignature, timestamp, nonce, echostr) { logger.WarnC("wecom", "Signature verification failed") http.Error(w, "Invalid signature", http.StatusForbidden) return @@ -228,7 +211,7 @@ func (c *WeComBotChannel) handleVerification(ctx context.Context, w http.Respons // Decrypt echostr // For AIBOT (智能机器人), receiveid should be empty string "" // Reference: https://developer.work.weixin.qq.com/document/path/101033 - decryptedEchoStr, err := WeComDecryptMessageWithVerify(echostr, c.config.EncodingAESKey, "") + decryptedEchoStr, err := decryptMessageWithVerify(echostr, c.config.EncodingAESKey, "") if err != nil { logger.ErrorCF("wecom", "Failed to decrypt echostr", map[string]any{ "error": err.Error(), @@ -281,7 +264,7 @@ func (c *WeComBotChannel) handleMessageCallback(ctx context.Context, w http.Resp } // Verify signature - if !WeComVerifySignature(c.config.Token, msgSignature, timestamp, nonce, encryptedMsg.Encrypt) { + if !verifySignature(c.config.Token, msgSignature, timestamp, nonce, encryptedMsg.Encrypt) { logger.WarnC("wecom", "Message signature verification failed") http.Error(w, "Invalid signature", http.StatusForbidden) return @@ -290,7 +273,7 @@ func (c *WeComBotChannel) handleMessageCallback(ctx context.Context, w http.Resp // Decrypt message // For AIBOT (智能机器人), receiveid should be empty string "" // Reference: https://developer.work.weixin.qq.com/document/path/101033 - decryptedMsg, err := WeComDecryptMessageWithVerify(encryptedMsg.Encrypt, c.config.EncodingAESKey, "") + decryptedMsg, err := decryptMessageWithVerify(encryptedMsg.Encrypt, c.config.EncodingAESKey, "") if err != nil { logger.ErrorCF("wecom", "Failed to decrypt message", map[string]any{ "error": err.Error(), @@ -387,12 +370,21 @@ func (c *WeComBotChannel) processMessage(ctx context.Context, msg WeComBotMessag } // Build metadata + peer := bus.Peer{Kind: peerKind, ID: peerID} + + // In group chats, apply unified group trigger filtering + if isGroupChat { + respond, cleaned := c.ShouldRespondInGroup(false, content) + if !respond { + return + } + content = cleaned + } + metadata := map[string]string{ "msg_type": msg.MsgType, "msg_id": msg.MsgID, "platform": "wecom", - "peer_kind": peerKind, - "peer_id": peerID, "response_url": msg.ResponseURL, } if isGroupChat { @@ -408,8 +400,19 @@ func (c *WeComBotChannel) processMessage(ctx context.Context, msg WeComBotMessag "preview": utils.Truncate(content, 50), }) + // Build sender info + sender := bus.SenderInfo{ + Platform: "wecom", + PlatformID: senderID, + CanonicalID: identity.BuildCanonicalID("wecom", senderID), + } + + if !c.IsAllowedSender(sender) { + return + } + // Handle the message through the base channel - c.HandleMessage(senderID, chatID, content, nil, metadata) + c.HandleMessage(ctx, peer, msg.MsgID, senderID, chatID, content, nil, metadata, sender) } // sendWebhookReply sends a reply using the webhook URL @@ -442,10 +445,15 @@ func (c *WeComBotChannel) sendWebhookReply(ctx context.Context, userID, content client := &http.Client{Timeout: time.Duration(timeout) * time.Second} resp, err := client.Do(req) if err != nil { - return fmt.Errorf("failed to send webhook reply: %w", err) + return channels.ClassifyNetError(err) } defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return channels.ClassifySendError(resp.StatusCode, fmt.Errorf("webhook API error: %s", string(body))) + } + body, err := io.ReadAll(resp.Body) if err != nil { return fmt.Errorf("failed to read response: %w", err) @@ -477,129 +485,3 @@ func (c *WeComBotChannel) handleHealth(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(status) } - -// WeCom common utilities for both WeCom Bot and WeCom App -// The following functions were moved from wecom_common.go - -// WeComVerifySignature verifies the message signature for WeCom -// This is a common function used by both WeCom Bot and WeCom App -func WeComVerifySignature(token, msgSignature, timestamp, nonce, msgEncrypt string) bool { - if token == "" { - return true // Skip verification if token is not set - } - - // Sort parameters - params := []string{token, timestamp, nonce, msgEncrypt} - sort.Strings(params) - - // Concatenate - str := strings.Join(params, "") - - // SHA1 hash - hash := sha1.Sum([]byte(str)) - expectedSignature := fmt.Sprintf("%x", hash) - - return expectedSignature == msgSignature -} - -// WeComDecryptMessage decrypts the encrypted message using AES -// This is a common function used by both WeCom Bot and WeCom App -// For AIBOT, receiveid should be the aibotid; for other apps, it should be corp_id -func WeComDecryptMessage(encryptedMsg, encodingAESKey string) (string, error) { - return WeComDecryptMessageWithVerify(encryptedMsg, encodingAESKey, "") -} - -// WeComDecryptMessageWithVerify decrypts the encrypted message and optionally verifies receiveid -// receiveid: for AIBOT use aibotid, for WeCom App use corp_id. If empty, skip verification. -func WeComDecryptMessageWithVerify(encryptedMsg, encodingAESKey, receiveid string) (string, error) { - if encodingAESKey == "" { - // No encryption, return as is (base64 decode) - decoded, err := base64.StdEncoding.DecodeString(encryptedMsg) - if err != nil { - return "", err - } - return string(decoded), nil - } - - // Decode AES key (base64) - aesKey, err := base64.StdEncoding.DecodeString(encodingAESKey + "=") - if err != nil { - return "", fmt.Errorf("failed to decode AES key: %w", err) - } - - // Decode encrypted message - cipherText, err := base64.StdEncoding.DecodeString(encryptedMsg) - if err != nil { - return "", fmt.Errorf("failed to decode message: %w", err) - } - - // AES decrypt - block, err := aes.NewCipher(aesKey) - if err != nil { - return "", fmt.Errorf("failed to create cipher: %w", err) - } - - if len(cipherText) < aes.BlockSize { - return "", fmt.Errorf("ciphertext too short") - } - - // IV is the first 16 bytes of AESKey - iv := aesKey[:aes.BlockSize] - mode := cipher.NewCBCDecrypter(block, iv) - plainText := make([]byte, len(cipherText)) - mode.CryptBlocks(plainText, cipherText) - - // Remove PKCS7 padding - plainText, err = pkcs7UnpadWeCom(plainText) - if err != nil { - return "", fmt.Errorf("failed to unpad: %w", err) - } - - // Parse message structure - // Format: random(16) + msg_len(4) + msg + receiveid - if len(plainText) < 20 { - return "", fmt.Errorf("decrypted message too short") - } - - msgLen := binary.BigEndian.Uint32(plainText[16:20]) - if int(msgLen) > len(plainText)-20 { - return "", fmt.Errorf("invalid message length") - } - - msg := plainText[20 : 20+msgLen] - - // Verify receiveid if provided - if receiveid != "" && len(plainText) > 20+int(msgLen) { - actualReceiveID := string(plainText[20+msgLen:]) - if actualReceiveID != receiveid { - return "", fmt.Errorf("receiveid mismatch: expected %s, got %s", receiveid, actualReceiveID) - } - } - - return string(msg), nil -} - -// pkcs7UnpadWeCom removes PKCS7 padding with validation -// WeCom uses block size of 32 (not standard AES block size of 16) -const wecomBlockSize = 32 - -func pkcs7UnpadWeCom(data []byte) ([]byte, error) { - if len(data) == 0 { - return data, nil - } - padding := int(data[len(data)-1]) - // WeCom uses 32-byte block size for PKCS7 padding - if padding == 0 || padding > wecomBlockSize { - return nil, fmt.Errorf("invalid padding size: %d", padding) - } - if padding > len(data) { - return nil, fmt.Errorf("padding size larger than data") - } - // Verify all padding bytes - for i := 0; i < padding; i++ { - if data[len(data)-1-i] != byte(padding) { - return nil, fmt.Errorf("invalid padding byte at position %d", i) - } - } - return data[:len(data)-padding], nil -} diff --git a/pkg/channels/wecom_test.go b/pkg/channels/wecom/bot_test.go similarity index 95% rename from pkg/channels/wecom_test.go rename to pkg/channels/wecom/bot_test.go index 8afa7e8c3..328b145c2 100644 --- a/pkg/channels/wecom_test.go +++ b/pkg/channels/wecom/bot_test.go @@ -1,7 +1,4 @@ -// PicoClaw - Ultra-lightweight personal AI agent -// WeCom Bot (企业微信智能机器人) channel tests - -package channels +package wecom import ( "bytes" @@ -177,7 +174,7 @@ func TestWeComBotVerifySignature(t *testing.T) { msgEncrypt := "test_message" expectedSig := generateSignature("test_token", timestamp, nonce, msgEncrypt) - if !WeComVerifySignature(ch.config.Token, expectedSig, timestamp, nonce, msgEncrypt) { + if !verifySignature(ch.config.Token, expectedSig, timestamp, nonce, msgEncrypt) { t.Error("valid signature should pass verification") } }) @@ -187,7 +184,7 @@ func TestWeComBotVerifySignature(t *testing.T) { nonce := "test_nonce" msgEncrypt := "test_message" - if WeComVerifySignature(ch.config.Token, "invalid_sig", timestamp, nonce, msgEncrypt) { + if verifySignature(ch.config.Token, "invalid_sig", timestamp, nonce, msgEncrypt) { t.Error("invalid signature should fail verification") } }) @@ -202,7 +199,7 @@ func TestWeComBotVerifySignature(t *testing.T) { config: cfgEmpty, } - if !WeComVerifySignature(chEmpty.config.Token, "any_sig", "any_ts", "any_nonce", "any_msg") { + if !verifySignature(chEmpty.config.Token, "any_sig", "any_ts", "any_nonce", "any_msg") { t.Error("empty token should skip verification and return true") } }) @@ -223,7 +220,7 @@ func TestWeComBotDecryptMessage(t *testing.T) { plainText := "hello world" encoded := base64.StdEncoding.EncodeToString([]byte(plainText)) - result, err := WeComDecryptMessage(encoded, ch.config.EncodingAESKey) + result, err := decryptMessage(encoded, ch.config.EncodingAESKey) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -247,7 +244,7 @@ func TestWeComBotDecryptMessage(t *testing.T) { t.Fatalf("failed to encrypt test message: %v", err) } - result, err := WeComDecryptMessage(encrypted, ch.config.EncodingAESKey) + result, err := decryptMessage(encrypted, ch.config.EncodingAESKey) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -264,7 +261,7 @@ func TestWeComBotDecryptMessage(t *testing.T) { } ch, _ := NewWeComBotChannel(cfg, msgBus) - _, err := WeComDecryptMessage("invalid_base64!!!", ch.config.EncodingAESKey) + _, err := decryptMessage("invalid_base64!!!", ch.config.EncodingAESKey) if err == nil { t.Error("expected error for invalid base64, got nil") } @@ -278,7 +275,7 @@ func TestWeComBotDecryptMessage(t *testing.T) { } ch, _ := NewWeComBotChannel(cfg, msgBus) - _, err := WeComDecryptMessage(base64.StdEncoding.EncodeToString([]byte("test")), ch.config.EncodingAESKey) + _, err := decryptMessage(base64.StdEncoding.EncodeToString([]byte("test")), ch.config.EncodingAESKey) if err == nil { t.Error("expected error for invalid AES key, got nil") } @@ -320,20 +317,20 @@ func TestWeComBotPKCS7Unpad(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result, err := pkcs7UnpadWeCom(tt.input) + result, err := pkcs7Unpad(tt.input) if tt.expected == nil { // This case should return an error if err == nil { - t.Errorf("pkcs7UnpadWeCom() expected error for invalid padding, got result: %v", result) + t.Errorf("pkcs7Unpad() expected error for invalid padding, got result: %v", result) } return } if err != nil { - t.Errorf("pkcs7UnpadWeCom() unexpected error: %v", err) + t.Errorf("pkcs7Unpad() unexpected error: %v", err) return } if !bytes.Equal(result, tt.expected) { - t.Errorf("pkcs7UnpadWeCom() = %v, want %v", result, tt.expected) + t.Errorf("pkcs7Unpad() = %v, want %v", result, tt.expected) } }) } diff --git a/pkg/channels/wecom/common.go b/pkg/channels/wecom/common.go new file mode 100644 index 000000000..3c1629577 --- /dev/null +++ b/pkg/channels/wecom/common.go @@ -0,0 +1,134 @@ +package wecom + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/sha1" + "encoding/base64" + "encoding/binary" + "fmt" + "sort" + "strings" +) + +// blockSize is the PKCS7 block size used by WeCom (32) +const blockSize = 32 + +// verifySignature verifies the message signature for WeCom +// This is a common function used by both WeCom Bot and WeCom App +func verifySignature(token, msgSignature, timestamp, nonce, msgEncrypt string) bool { + if token == "" { + return true // Skip verification if token is not set + } + + // Sort parameters + params := []string{token, timestamp, nonce, msgEncrypt} + sort.Strings(params) + + // Concatenate + str := strings.Join(params, "") + + // SHA1 hash + hash := sha1.Sum([]byte(str)) + expectedSignature := fmt.Sprintf("%x", hash) + + return expectedSignature == msgSignature +} + +// decryptMessage decrypts the encrypted message using AES +// For AIBOT, receiveid should be the aibotid; for other apps, it should be corp_id +func decryptMessage(encryptedMsg, encodingAESKey string) (string, error) { + return decryptMessageWithVerify(encryptedMsg, encodingAESKey, "") +} + +// decryptMessageWithVerify decrypts the encrypted message and optionally verifies receiveid +// receiveid: for AIBOT use aibotid, for WeCom App use corp_id. If empty, skip verification. +func decryptMessageWithVerify(encryptedMsg, encodingAESKey, receiveid string) (string, error) { + if encodingAESKey == "" { + // No encryption, return as is (base64 decode) + decoded, err := base64.StdEncoding.DecodeString(encryptedMsg) + if err != nil { + return "", err + } + return string(decoded), nil + } + + // Decode AES key (base64) + aesKey, err := base64.StdEncoding.DecodeString(encodingAESKey + "=") + if err != nil { + return "", fmt.Errorf("failed to decode AES key: %w", err) + } + + // Decode encrypted message + cipherText, err := base64.StdEncoding.DecodeString(encryptedMsg) + if err != nil { + return "", fmt.Errorf("failed to decode message: %w", err) + } + + // AES decrypt + block, err := aes.NewCipher(aesKey) + if err != nil { + return "", fmt.Errorf("failed to create cipher: %w", err) + } + + if len(cipherText) < aes.BlockSize { + return "", fmt.Errorf("ciphertext too short") + } + + // IV is the first 16 bytes of AESKey + iv := aesKey[:aes.BlockSize] + mode := cipher.NewCBCDecrypter(block, iv) + plainText := make([]byte, len(cipherText)) + mode.CryptBlocks(plainText, cipherText) + + // Remove PKCS7 padding + plainText, err = pkcs7Unpad(plainText) + if err != nil { + return "", fmt.Errorf("failed to unpad: %w", err) + } + + // Parse message structure + // Format: random(16) + msg_len(4) + msg + receiveid + if len(plainText) < 20 { + return "", fmt.Errorf("decrypted message too short") + } + + msgLen := binary.BigEndian.Uint32(plainText[16:20]) + if int(msgLen) > len(plainText)-20 { + return "", fmt.Errorf("invalid message length") + } + + msg := plainText[20 : 20+msgLen] + + // Verify receiveid if provided + if receiveid != "" && len(plainText) > 20+int(msgLen) { + actualReceiveID := string(plainText[20+msgLen:]) + if actualReceiveID != receiveid { + return "", fmt.Errorf("receiveid mismatch: expected %s, got %s", receiveid, actualReceiveID) + } + } + + return string(msg), nil +} + +// pkcs7Unpad removes PKCS7 padding with validation +func pkcs7Unpad(data []byte) ([]byte, error) { + if len(data) == 0 { + return data, nil + } + padding := int(data[len(data)-1]) + // WeCom uses 32-byte block size for PKCS7 padding + if padding == 0 || padding > blockSize { + return nil, fmt.Errorf("invalid padding size: %d", padding) + } + if padding > len(data) { + return nil, fmt.Errorf("padding size larger than data") + } + // Verify all padding bytes + for i := 0; i < padding; i++ { + if data[len(data)-1-i] != byte(padding) { + return nil, fmt.Errorf("invalid padding byte at position %d", i) + } + } + return data[:len(data)-padding], nil +} diff --git a/pkg/channels/wecom/init.go b/pkg/channels/wecom/init.go new file mode 100644 index 000000000..3ef1ecdf3 --- /dev/null +++ b/pkg/channels/wecom/init.go @@ -0,0 +1,16 @@ +package wecom + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("wecom", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewWeComBotChannel(cfg.Channels.WeCom, b) + }) + channels.RegisterFactory("wecom_app", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewWeComAppChannel(cfg.Channels.WeComApp, b) + }) +} diff --git a/pkg/channels/whatsapp/init.go b/pkg/channels/whatsapp/init.go new file mode 100644 index 000000000..d9c2669c3 --- /dev/null +++ b/pkg/channels/whatsapp/init.go @@ -0,0 +1,13 @@ +package whatsapp + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("whatsapp", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewWhatsAppChannel(cfg.Channels.WhatsApp, b) + }) +} diff --git a/pkg/channels/whatsapp.go b/pkg/channels/whatsapp/whatsapp.go similarity index 54% rename from pkg/channels/whatsapp.go rename to pkg/channels/whatsapp/whatsapp.go index 958d850bb..5c1b639b3 100644 --- a/pkg/channels/whatsapp.go +++ b/pkg/channels/whatsapp/whatsapp.go @@ -1,31 +1,42 @@ -package channels +package whatsapp import ( "context" "encoding/json" "fmt" - "log" "sync" "time" "github.com/gorilla/websocket" "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" + "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/utils" ) type WhatsAppChannel struct { - *BaseChannel + *channels.BaseChannel conn *websocket.Conn config config.WhatsAppConfig url string + ctx context.Context + cancel context.CancelFunc mu sync.Mutex connected bool } func NewWhatsAppChannel(cfg config.WhatsAppConfig, bus *bus.MessageBus) (*WhatsAppChannel, error) { - base := NewBaseChannel("whatsapp", cfg, bus, cfg.AllowFrom) + base := channels.NewBaseChannel( + "whatsapp", + cfg, + bus, + cfg.AllowFrom, + channels.WithMaxMessageLength(65536), + channels.WithReasoningChannelID(cfg.ReasoningChannelID), + ) return &WhatsAppChannel{ BaseChannel: base, @@ -36,13 +47,18 @@ func NewWhatsAppChannel(cfg config.WhatsAppConfig, bus *bus.MessageBus) (*WhatsA } func (c *WhatsAppChannel) Start(ctx context.Context) error { - log.Printf("Starting WhatsApp channel connecting to %s...", c.url) + logger.InfoCF("whatsapp", "Starting WhatsApp channel", map[string]any{ + "bridge_url": c.url, + }) + + c.ctx, c.cancel = context.WithCancel(ctx) dialer := websocket.DefaultDialer dialer.HandshakeTimeout = 10 * time.Second conn, _, err := dialer.Dial(c.url, nil) if err != nil { + c.cancel() return fmt.Errorf("failed to connect to WhatsApp bridge: %w", err) } @@ -51,39 +67,57 @@ func (c *WhatsAppChannel) Start(ctx context.Context) error { c.connected = true c.mu.Unlock() - c.setRunning(true) - log.Println("WhatsApp channel connected") + c.SetRunning(true) + logger.InfoC("whatsapp", "WhatsApp channel connected") - go c.listen(ctx) + go c.listen() return nil } func (c *WhatsAppChannel) Stop(ctx context.Context) error { - log.Println("Stopping WhatsApp channel...") + logger.InfoC("whatsapp", "Stopping WhatsApp channel...") + + // Cancel context first to signal listen goroutine to exit + if c.cancel != nil { + c.cancel() + } c.mu.Lock() defer c.mu.Unlock() if c.conn != nil { if err := c.conn.Close(); err != nil { - log.Printf("Error closing WhatsApp connection: %v", err) + logger.ErrorCF("whatsapp", "Error closing WhatsApp connection", map[string]any{ + "error": err.Error(), + }) } c.conn = nil } c.connected = false - c.setRunning(false) + c.SetRunning(false) return nil } func (c *WhatsAppChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + if !c.IsRunning() { + return channels.ErrNotRunning + } + + // Check ctx before acquiring lock + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + c.mu.Lock() defer c.mu.Unlock() if c.conn == nil { - return fmt.Errorf("whatsapp connection not established") + return fmt.Errorf("whatsapp connection not established: %w", channels.ErrTemporary) } payload := map[string]any{ @@ -97,17 +131,20 @@ func (c *WhatsAppChannel) Send(ctx context.Context, msg bus.OutboundMessage) err return fmt.Errorf("failed to marshal message: %w", err) } + _ = c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) if err := c.conn.WriteMessage(websocket.TextMessage, data); err != nil { - return fmt.Errorf("failed to send message: %w", err) + _ = c.conn.SetWriteDeadline(time.Time{}) + return fmt.Errorf("whatsapp send: %w", channels.ErrTemporary) } + _ = c.conn.SetWriteDeadline(time.Time{}) return nil } -func (c *WhatsAppChannel) listen(ctx context.Context) { +func (c *WhatsAppChannel) listen() { for { select { - case <-ctx.Done(): + case <-c.ctx.Done(): return default: c.mu.Lock() @@ -121,14 +158,18 @@ func (c *WhatsAppChannel) listen(ctx context.Context) { _, message, err := conn.ReadMessage() if err != nil { - log.Printf("WhatsApp read error: %v", err) + logger.ErrorCF("whatsapp", "WhatsApp read error", map[string]any{ + "error": err.Error(), + }) time.Sleep(2 * time.Second) continue } var msg map[string]any if err := json.Unmarshal(message, &msg); err != nil { - log.Printf("Failed to unmarshal WhatsApp message: %v", err) + logger.ErrorCF("whatsapp", "Failed to unmarshal WhatsApp message", map[string]any{ + "error": err.Error(), + }) continue } @@ -171,22 +212,38 @@ func (c *WhatsAppChannel) handleIncomingMessage(msg map[string]any) { } metadata := make(map[string]string) - if messageID, ok := msg["id"].(string); ok { - metadata["message_id"] = messageID + var messageID string + if mid, ok := msg["id"].(string); ok { + messageID = mid } if userName, ok := msg["from_name"].(string); ok { metadata["user_name"] = userName } + var peer bus.Peer if chatID == senderID { - metadata["peer_kind"] = "direct" - metadata["peer_id"] = senderID + peer = bus.Peer{Kind: "direct", ID: senderID} } else { - metadata["peer_kind"] = "group" - metadata["peer_id"] = chatID + peer = bus.Peer{Kind: "group", ID: chatID} } - log.Printf("WhatsApp message from %s: %s...", senderID, utils.Truncate(content, 50)) + logger.InfoCF("whatsapp", "WhatsApp message received", map[string]any{ + "sender": senderID, + "preview": utils.Truncate(content, 50), + }) + + sender := bus.SenderInfo{ + Platform: "whatsapp", + PlatformID: senderID, + CanonicalID: identity.BuildCanonicalID("whatsapp", senderID), + } + if display, ok := metadata["user_name"]; ok { + sender.DisplayName = display + } + + if !c.IsAllowedSender(sender) { + return + } - c.HandleMessage(senderID, chatID, content, mediaPaths, metadata) + c.HandleMessage(c.ctx, peer, messageID, senderID, chatID, content, mediaPaths, metadata, sender) } diff --git a/pkg/config/config.go b/pkg/config/config.go index 2595398c7..b2385ede5 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -192,108 +192,170 @@ type ChannelsConfig struct { OneBot OneBotConfig `json:"onebot"` WeCom WeComConfig `json:"wecom"` WeComApp WeComAppConfig `json:"wecom_app"` + Pico PicoConfig `json:"pico"` +} + +// GroupTriggerConfig controls when the bot responds in group chats. +type GroupTriggerConfig struct { + MentionOnly bool `json:"mention_only,omitempty"` + Prefixes []string `json:"prefixes,omitempty"` +} + +// TypingConfig controls typing indicator behavior (Phase 10). +type TypingConfig struct { + Enabled bool `json:"enabled,omitempty"` +} + +// PlaceholderConfig controls placeholder message behavior (Phase 10). +type PlaceholderConfig struct { + Enabled bool `json:"enabled,omitempty"` + Text string `json:"text,omitempty"` } type WhatsAppConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WHATSAPP_ENABLED"` - BridgeURL string `json:"bridge_url" env:"PICOCLAW_CHANNELS_WHATSAPP_BRIDGE_URL"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WHATSAPP_ALLOW_FROM"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WHATSAPP_ENABLED"` + BridgeURL string `json:"bridge_url" env:"PICOCLAW_CHANNELS_WHATSAPP_BRIDGE_URL"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WHATSAPP_ALLOW_FROM"` + ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WHATSAPP_REASONING_CHANNEL_ID"` } type TelegramConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_TELEGRAM_ENABLED"` - Token string `json:"token" env:"PICOCLAW_CHANNELS_TELEGRAM_TOKEN"` - Proxy string `json:"proxy" env:"PICOCLAW_CHANNELS_TELEGRAM_PROXY"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_TELEGRAM_ALLOW_FROM"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_TELEGRAM_ENABLED"` + Token string `json:"token" env:"PICOCLAW_CHANNELS_TELEGRAM_TOKEN"` + Proxy string `json:"proxy" env:"PICOCLAW_CHANNELS_TELEGRAM_PROXY"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_TELEGRAM_ALLOW_FROM"` + GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` + Typing TypingConfig `json:"typing,omitempty"` + Placeholder PlaceholderConfig `json:"placeholder,omitempty"` + ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_TELEGRAM_REASONING_CHANNEL_ID"` } type FeishuConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_FEISHU_ENABLED"` - AppID string `json:"app_id" env:"PICOCLAW_CHANNELS_FEISHU_APP_ID"` - AppSecret string `json:"app_secret" env:"PICOCLAW_CHANNELS_FEISHU_APP_SECRET"` - EncryptKey string `json:"encrypt_key" env:"PICOCLAW_CHANNELS_FEISHU_ENCRYPT_KEY"` - VerificationToken string `json:"verification_token" env:"PICOCLAW_CHANNELS_FEISHU_VERIFICATION_TOKEN"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_FEISHU_ALLOW_FROM"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_FEISHU_ENABLED"` + AppID string `json:"app_id" env:"PICOCLAW_CHANNELS_FEISHU_APP_ID"` + AppSecret string `json:"app_secret" env:"PICOCLAW_CHANNELS_FEISHU_APP_SECRET"` + EncryptKey string `json:"encrypt_key" env:"PICOCLAW_CHANNELS_FEISHU_ENCRYPT_KEY"` + VerificationToken string `json:"verification_token" env:"PICOCLAW_CHANNELS_FEISHU_VERIFICATION_TOKEN"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_FEISHU_ALLOW_FROM"` + GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` + ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_FEISHU_REASONING_CHANNEL_ID"` } type DiscordConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DISCORD_ENABLED"` - Token string `json:"token" env:"PICOCLAW_CHANNELS_DISCORD_TOKEN"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_DISCORD_ALLOW_FROM"` - MentionOnly bool `json:"mention_only" env:"PICOCLAW_CHANNELS_DISCORD_MENTION_ONLY"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DISCORD_ENABLED"` + Token string `json:"token" env:"PICOCLAW_CHANNELS_DISCORD_TOKEN"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_DISCORD_ALLOW_FROM"` + MentionOnly bool `json:"mention_only" env:"PICOCLAW_CHANNELS_DISCORD_MENTION_ONLY"` + GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` + Typing TypingConfig `json:"typing,omitempty"` + Placeholder PlaceholderConfig `json:"placeholder,omitempty"` + ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_DISCORD_REASONING_CHANNEL_ID"` } type MaixCamConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_MAIXCAM_ENABLED"` - Host string `json:"host" env:"PICOCLAW_CHANNELS_MAIXCAM_HOST"` - Port int `json:"port" env:"PICOCLAW_CHANNELS_MAIXCAM_PORT"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_MAIXCAM_ALLOW_FROM"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_MAIXCAM_ENABLED"` + Host string `json:"host" env:"PICOCLAW_CHANNELS_MAIXCAM_HOST"` + Port int `json:"port" env:"PICOCLAW_CHANNELS_MAIXCAM_PORT"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_MAIXCAM_ALLOW_FROM"` + ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_MAIXCAM_REASONING_CHANNEL_ID"` } type QQConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_QQ_ENABLED"` - AppID string `json:"app_id" env:"PICOCLAW_CHANNELS_QQ_APP_ID"` - AppSecret string `json:"app_secret" env:"PICOCLAW_CHANNELS_QQ_APP_SECRET"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_QQ_ALLOW_FROM"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_QQ_ENABLED"` + AppID string `json:"app_id" env:"PICOCLAW_CHANNELS_QQ_APP_ID"` + AppSecret string `json:"app_secret" env:"PICOCLAW_CHANNELS_QQ_APP_SECRET"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_QQ_ALLOW_FROM"` + GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` + ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_QQ_REASONING_CHANNEL_ID"` } type DingTalkConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DINGTALK_ENABLED"` - ClientID string `json:"client_id" env:"PICOCLAW_CHANNELS_DINGTALK_CLIENT_ID"` - ClientSecret string `json:"client_secret" env:"PICOCLAW_CHANNELS_DINGTALK_CLIENT_SECRET"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_DINGTALK_ALLOW_FROM"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DINGTALK_ENABLED"` + ClientID string `json:"client_id" env:"PICOCLAW_CHANNELS_DINGTALK_CLIENT_ID"` + ClientSecret string `json:"client_secret" env:"PICOCLAW_CHANNELS_DINGTALK_CLIENT_SECRET"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_DINGTALK_ALLOW_FROM"` + GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` + ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_DINGTALK_REASONING_CHANNEL_ID"` } type SlackConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_SLACK_ENABLED"` - BotToken string `json:"bot_token" env:"PICOCLAW_CHANNELS_SLACK_BOT_TOKEN"` - AppToken string `json:"app_token" env:"PICOCLAW_CHANNELS_SLACK_APP_TOKEN"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_SLACK_ALLOW_FROM"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_SLACK_ENABLED"` + BotToken string `json:"bot_token" env:"PICOCLAW_CHANNELS_SLACK_BOT_TOKEN"` + AppToken string `json:"app_token" env:"PICOCLAW_CHANNELS_SLACK_APP_TOKEN"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_SLACK_ALLOW_FROM"` + GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` + Typing TypingConfig `json:"typing,omitempty"` + Placeholder PlaceholderConfig `json:"placeholder,omitempty"` + ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_SLACK_REASONING_CHANNEL_ID"` } type LINEConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_LINE_ENABLED"` - ChannelSecret string `json:"channel_secret" env:"PICOCLAW_CHANNELS_LINE_CHANNEL_SECRET"` - ChannelAccessToken string `json:"channel_access_token" env:"PICOCLAW_CHANNELS_LINE_CHANNEL_ACCESS_TOKEN"` - WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_LINE_WEBHOOK_HOST"` - WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_LINE_WEBHOOK_PORT"` - WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_LINE_WEBHOOK_PATH"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_LINE_ALLOW_FROM"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_LINE_ENABLED"` + ChannelSecret string `json:"channel_secret" env:"PICOCLAW_CHANNELS_LINE_CHANNEL_SECRET"` + ChannelAccessToken string `json:"channel_access_token" env:"PICOCLAW_CHANNELS_LINE_CHANNEL_ACCESS_TOKEN"` + WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_LINE_WEBHOOK_HOST"` + WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_LINE_WEBHOOK_PORT"` + WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_LINE_WEBHOOK_PATH"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_LINE_ALLOW_FROM"` + GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` + Typing TypingConfig `json:"typing,omitempty"` + Placeholder PlaceholderConfig `json:"placeholder,omitempty"` + ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_LINE_REASONING_CHANNEL_ID"` } type OneBotConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_ONEBOT_ENABLED"` - WSUrl string `json:"ws_url" env:"PICOCLAW_CHANNELS_ONEBOT_WS_URL"` - AccessToken string `json:"access_token" env:"PICOCLAW_CHANNELS_ONEBOT_ACCESS_TOKEN"` - ReconnectInterval int `json:"reconnect_interval" env:"PICOCLAW_CHANNELS_ONEBOT_RECONNECT_INTERVAL"` - GroupTriggerPrefix []string `json:"group_trigger_prefix" env:"PICOCLAW_CHANNELS_ONEBOT_GROUP_TRIGGER_PREFIX"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_ONEBOT_ALLOW_FROM"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_ONEBOT_ENABLED"` + WSUrl string `json:"ws_url" env:"PICOCLAW_CHANNELS_ONEBOT_WS_URL"` + AccessToken string `json:"access_token" env:"PICOCLAW_CHANNELS_ONEBOT_ACCESS_TOKEN"` + ReconnectInterval int `json:"reconnect_interval" env:"PICOCLAW_CHANNELS_ONEBOT_RECONNECT_INTERVAL"` + GroupTriggerPrefix []string `json:"group_trigger_prefix" env:"PICOCLAW_CHANNELS_ONEBOT_GROUP_TRIGGER_PREFIX"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_ONEBOT_ALLOW_FROM"` + GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` + Typing TypingConfig `json:"typing,omitempty"` + Placeholder PlaceholderConfig `json:"placeholder,omitempty"` + ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_ONEBOT_REASONING_CHANNEL_ID"` } type WeComConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_ENABLED"` - Token string `json:"token" env:"PICOCLAW_CHANNELS_WECOM_TOKEN"` - EncodingAESKey string `json:"encoding_aes_key" env:"PICOCLAW_CHANNELS_WECOM_ENCODING_AES_KEY"` - WebhookURL string `json:"webhook_url" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_URL"` - WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_HOST"` - WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_PORT"` - WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_PATH"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_ALLOW_FROM"` - ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_REPLY_TIMEOUT"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_ENABLED"` + Token string `json:"token" env:"PICOCLAW_CHANNELS_WECOM_TOKEN"` + EncodingAESKey string `json:"encoding_aes_key" env:"PICOCLAW_CHANNELS_WECOM_ENCODING_AES_KEY"` + WebhookURL string `json:"webhook_url" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_URL"` + WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_HOST"` + WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_PORT"` + WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_PATH"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_ALLOW_FROM"` + ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_REPLY_TIMEOUT"` + GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` + ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WECOM_REASONING_CHANNEL_ID"` } type WeComAppConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_APP_ENABLED"` - CorpID string `json:"corp_id" env:"PICOCLAW_CHANNELS_WECOM_APP_CORP_ID"` - CorpSecret string `json:"corp_secret" env:"PICOCLAW_CHANNELS_WECOM_APP_CORP_SECRET"` - AgentID int64 `json:"agent_id" env:"PICOCLAW_CHANNELS_WECOM_APP_AGENT_ID"` - Token string `json:"token" env:"PICOCLAW_CHANNELS_WECOM_APP_TOKEN"` - EncodingAESKey string `json:"encoding_aes_key" env:"PICOCLAW_CHANNELS_WECOM_APP_ENCODING_AES_KEY"` - WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_HOST"` - WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_PORT"` - WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_PATH"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_APP_ALLOW_FROM"` - ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_APP_REPLY_TIMEOUT"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_APP_ENABLED"` + CorpID string `json:"corp_id" env:"PICOCLAW_CHANNELS_WECOM_APP_CORP_ID"` + CorpSecret string `json:"corp_secret" env:"PICOCLAW_CHANNELS_WECOM_APP_CORP_SECRET"` + AgentID int64 `json:"agent_id" env:"PICOCLAW_CHANNELS_WECOM_APP_AGENT_ID"` + Token string `json:"token" env:"PICOCLAW_CHANNELS_WECOM_APP_TOKEN"` + EncodingAESKey string `json:"encoding_aes_key" env:"PICOCLAW_CHANNELS_WECOM_APP_ENCODING_AES_KEY"` + WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_HOST"` + WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_PORT"` + WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_PATH"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_APP_ALLOW_FROM"` + ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_APP_REPLY_TIMEOUT"` + GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` + ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WECOM_APP_REASONING_CHANNEL_ID"` +} + +type PicoConfig struct { + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_PICO_ENABLED"` + Token string `json:"token" env:"PICOCLAW_CHANNELS_PICO_TOKEN"` + AllowTokenQuery bool `json:"allow_token_query,omitempty"` + AllowOrigins []string `json:"allow_origins,omitempty"` + PingInterval int `json:"ping_interval,omitempty"` + ReadTimeout int `json:"read_timeout,omitempty"` + WriteTimeout int `json:"write_timeout,omitempty"` + MaxConnections int `json:"max_connections,omitempty"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_PICO_ALLOW_FROM"` } type HeartbeatConfig struct { @@ -507,6 +569,9 @@ func LoadConfig(path string) (*Config, error) { return nil, err } + // Migrate legacy channel config fields to new unified structures + cfg.migrateChannelConfigs() + // Auto-migrate: if only legacy providers config exists, convert to model_list if len(cfg.ModelList) == 0 && cfg.HasProvidersConfig() { cfg.ModelList = ConvertProvidersToModelList(cfg) @@ -520,6 +585,18 @@ func LoadConfig(path string) (*Config, error) { return cfg, nil } +func (c *Config) migrateChannelConfigs() { + // Discord: mention_only -> group_trigger.mention_only + if c.Channels.Discord.MentionOnly && !c.Channels.Discord.GroupTrigger.MentionOnly { + c.Channels.Discord.GroupTrigger.MentionOnly = true + } + + // OneBot: group_trigger_prefix -> group_trigger.prefixes + if len(c.Channels.OneBot.GroupTriggerPrefix) > 0 && len(c.Channels.OneBot.GroupTrigger.Prefixes) == 0 { + c.Channels.OneBot.GroupTrigger.Prefixes = c.Channels.OneBot.GroupTriggerPrefix + } +} + func SaveConfig(path string, cfg *Config) error { data, err := json.MarshalIndent(cfg, "", " ") if err != nil { diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go index b96ee4d89..604b53e24 100644 --- a/pkg/config/defaults.go +++ b/pkg/config/defaults.go @@ -33,6 +33,11 @@ func DefaultConfig() *Config { Enabled: false, Token: "", AllowFrom: FlexibleStringSlice{}, + Typing: TypingConfig{Enabled: true}, + Placeholder: PlaceholderConfig{ + Enabled: true, + Text: "Thinking... 💭", + }, }, Feishu: FeishuConfig{ Enabled: false, @@ -80,6 +85,7 @@ func DefaultConfig() *Config { WebhookPort: 18791, WebhookPath: "/webhook/line", AllowFrom: FlexibleStringSlice{}, + GroupTrigger: GroupTriggerConfig{MentionOnly: true}, }, OneBot: OneBotConfig{ Enabled: false, @@ -113,6 +119,15 @@ func DefaultConfig() *Config { AllowFrom: FlexibleStringSlice{}, ReplyTimeout: 5, }, + Pico: PicoConfig{ + Enabled: false, + Token: "", + PingInterval: 30, + ReadTimeout: 60, + WriteTimeout: 10, + MaxConnections: 100, + AllowFrom: FlexibleStringSlice{}, + }, }, Providers: ProvidersConfig{ OpenAI: OpenAIProviderConfig{WebSearch: true}, diff --git a/pkg/devices/service.go b/pkg/devices/service.go index 1541d3c57..1bafe6085 100644 --- a/pkg/devices/service.go +++ b/pkg/devices/service.go @@ -4,6 +4,7 @@ import ( "context" "strings" "sync" + "time" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/constants" @@ -127,7 +128,9 @@ func (s *Service) sendNotification(ev *events.DeviceEvent) { } msg := ev.FormatMessage() - msgBus.PublishOutbound(bus.OutboundMessage{ + pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer pubCancel() + msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{ Channel: platform, ChatID: userID, Content: msg, diff --git a/pkg/health/server.go b/pkg/health/server.go index 77b36034d..de1ff60fe 100644 --- a/pkg/health/server.go +++ b/pkg/health/server.go @@ -156,6 +156,13 @@ func (s *Server) readyHandler(w http.ResponseWriter, r *http.Request) { }) } +// RegisterOnMux registers /health and /ready handlers onto the given mux. +// This allows the health endpoints to be served by a shared HTTP server. +func (s *Server) RegisterOnMux(mux *http.ServeMux) { + mux.HandleFunc("/health", s.healthHandler) + mux.HandleFunc("/ready", s.readyHandler) +} + func statusString(ok bool) string { if ok { return "ok" diff --git a/pkg/heartbeat/service.go b/pkg/heartbeat/service.go index 75d6248b9..475f10509 100644 --- a/pkg/heartbeat/service.go +++ b/pkg/heartbeat/service.go @@ -7,6 +7,7 @@ package heartbeat import ( + "context" "fmt" "os" "path/filepath" @@ -307,7 +308,9 @@ func (hs *HeartbeatService) sendResponse(response string) { return } - msgBus.PublishOutbound(bus.OutboundMessage{ + pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer pubCancel() + msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{ Channel: platform, ChatID: userID, Content: response, diff --git a/pkg/identity/identity.go b/pkg/identity/identity.go new file mode 100644 index 000000000..6bc09c210 --- /dev/null +++ b/pkg/identity/identity.go @@ -0,0 +1,107 @@ +// Package identity provides unified user identity utilities for PicoClaw. +// It introduces a canonical "platform:id" format and matching logic +// that is backward-compatible with all legacy allow-list formats. +package identity + +import ( + "strings" + + "github.com/sipeed/picoclaw/pkg/bus" +) + +// BuildCanonicalID constructs a canonical "platform:id" identifier. +// Both platform and platformID are lowercased and trimmed. +func BuildCanonicalID(platform, platformID string) string { + p := strings.ToLower(strings.TrimSpace(platform)) + id := strings.TrimSpace(platformID) + if p == "" || id == "" { + return "" + } + return p + ":" + id +} + +// ParseCanonicalID splits a canonical ID ("platform:id") into its parts. +// Returns ok=false if the input does not contain a colon separator. +func ParseCanonicalID(canonical string) (platform, id string, ok bool) { + canonical = strings.TrimSpace(canonical) + idx := strings.Index(canonical, ":") + if idx <= 0 || idx == len(canonical)-1 { + return "", "", false + } + return canonical[:idx], canonical[idx+1:], true +} + +// MatchAllowed checks whether the given sender matches a single allow-list entry. +// It is backward-compatible with all legacy formats: +// +// - "123456" → matches sender.PlatformID +// - "@alice" → matches sender.Username +// - "123456|alice" → matches PlatformID or Username +// - "telegram:123456" → exact match on sender.CanonicalID +func MatchAllowed(sender bus.SenderInfo, allowed string) bool { + allowed = strings.TrimSpace(allowed) + if allowed == "" { + return false + } + + // Try canonical match first: "platform:id" format + if platform, id, ok := ParseCanonicalID(allowed); ok { + // Only treat as canonical if the platform portion looks like a known platform name + // (not a pure-numeric string, which could be a compound ID) + if !isNumeric(platform) { + candidate := BuildCanonicalID(platform, id) + if candidate != "" && sender.CanonicalID != "" { + return strings.EqualFold(sender.CanonicalID, candidate) + } + // If sender has no canonical ID, try matching platform + platformID + return strings.EqualFold(platform, sender.Platform) && + sender.PlatformID == id + } + } + + // Strip leading "@" for username matching + trimmed := strings.TrimPrefix(allowed, "@") + + // Split compound "id|username" format + allowedID := trimmed + allowedUser := "" + if idx := strings.Index(trimmed, "|"); idx > 0 { + allowedID = trimmed[:idx] + allowedUser = trimmed[idx+1:] + } + + // Match against PlatformID + if sender.PlatformID != "" && sender.PlatformID == allowedID { + return true + } + + // Match against Username + if sender.Username != "" { + if sender.Username == trimmed || sender.Username == allowedUser { + return true + } + } + + // Match compound sender format against allowed parts + if allowedUser != "" && sender.PlatformID != "" && sender.PlatformID == allowedID { + return true + } + if allowedUser != "" && sender.Username != "" && sender.Username == allowedUser { + return true + } + + return false +} + +// isNumeric returns true if s consists entirely of digits. +func isNumeric(s string) bool { + if s == "" { + return false + } + for _, r := range s { + if r < '0' || r > '9' { + return false + } + } + return true +} diff --git a/pkg/identity/identity_test.go b/pkg/identity/identity_test.go new file mode 100644 index 000000000..3d24bd794 --- /dev/null +++ b/pkg/identity/identity_test.go @@ -0,0 +1,229 @@ +package identity + +import ( + "testing" + + "github.com/sipeed/picoclaw/pkg/bus" +) + +func TestBuildCanonicalID(t *testing.T) { + tests := []struct { + platform string + platformID string + want string + }{ + {"telegram", "123456", "telegram:123456"}, + {"Discord", "98765432", "discord:98765432"}, + {"SLACK", "U123ABC", "slack:U123ABC"}, + {"", "123", ""}, + {"telegram", "", ""}, + {" telegram ", " 123 ", "telegram:123"}, + } + + for _, tt := range tests { + got := BuildCanonicalID(tt.platform, tt.platformID) + if got != tt.want { + t.Errorf("BuildCanonicalID(%q, %q) = %q, want %q", + tt.platform, tt.platformID, got, tt.want) + } + } +} + +func TestParseCanonicalID(t *testing.T) { + tests := []struct { + input string + wantPlatform string + wantID string + wantOk bool + }{ + {"telegram:123456", "telegram", "123456", true}, + {"discord:98765432", "discord", "98765432", true}, + {"slack:U123ABC", "slack", "U123ABC", true}, + {"nocolon", "", "", false}, + {"", "", "", false}, + {":missing", "", "", false}, + {"missing:", "", "", false}, + } + + for _, tt := range tests { + platform, id, ok := ParseCanonicalID(tt.input) + if ok != tt.wantOk || platform != tt.wantPlatform || id != tt.wantID { + t.Errorf("ParseCanonicalID(%q) = (%q, %q, %v), want (%q, %q, %v)", + tt.input, platform, id, ok, + tt.wantPlatform, tt.wantID, tt.wantOk) + } + } +} + +func TestMatchAllowed(t *testing.T) { + telegramSender := bus.SenderInfo{ + Platform: "telegram", + PlatformID: "123456", + CanonicalID: "telegram:123456", + Username: "alice", + DisplayName: "Alice Smith", + } + + discordSender := bus.SenderInfo{ + Platform: "discord", + PlatformID: "98765432", + CanonicalID: "discord:98765432", + Username: "bob", + DisplayName: "bob#1234", + } + + noCanonicalSender := bus.SenderInfo{ + Platform: "telegram", + PlatformID: "999", + Username: "carol", + } + + tests := []struct { + name string + sender bus.SenderInfo + allowed string + want bool + }{ + // Pure numeric ID matching + { + name: "numeric ID matches PlatformID", + sender: telegramSender, + allowed: "123456", + want: true, + }, + { + name: "numeric ID does not match", + sender: telegramSender, + allowed: "654321", + want: false, + }, + // Username matching + { + name: "@username matches Username", + sender: telegramSender, + allowed: "@alice", + want: true, + }, + { + name: "@username does not match", + sender: telegramSender, + allowed: "@bob", + want: false, + }, + // Compound format "id|username" + { + name: "compound matches by ID", + sender: telegramSender, + allowed: "123456|alice", + want: true, + }, + { + name: "compound matches by username", + sender: telegramSender, + allowed: "999|alice", + want: true, + }, + { + name: "compound does not match", + sender: telegramSender, + allowed: "654321|bob", + want: false, + }, + // Canonical format "platform:id" + { + name: "canonical matches exactly", + sender: telegramSender, + allowed: "telegram:123456", + want: true, + }, + { + name: "canonical case-insensitive platform", + sender: telegramSender, + allowed: "Telegram:123456", + want: true, + }, + { + name: "canonical wrong platform", + sender: telegramSender, + allowed: "discord:123456", + want: false, + }, + { + name: "canonical wrong ID", + sender: telegramSender, + allowed: "telegram:654321", + want: false, + }, + // Cross-platform canonical + { + name: "discord canonical match", + sender: discordSender, + allowed: "discord:98765432", + want: true, + }, + { + name: "telegram canonical does not match discord sender", + sender: discordSender, + allowed: "telegram:98765432", + want: false, + }, + // Sender without canonical ID + { + name: "canonical match falls back to platform+platformID", + sender: noCanonicalSender, + allowed: "telegram:999", + want: true, + }, + { + name: "platform mismatch on fallback", + sender: noCanonicalSender, + allowed: "discord:999", + want: false, + }, + // Empty allowed string + { + name: "empty allowed never matches", + sender: telegramSender, + allowed: "", + want: false, + }, + // Whitespace handling + { + name: "trimmed allowed matches", + sender: telegramSender, + allowed: " 123456 ", + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := MatchAllowed(tt.sender, tt.allowed) + if got != tt.want { + t.Errorf("MatchAllowed(%+v, %q) = %v, want %v", + tt.sender, tt.allowed, got, tt.want) + } + }) + } +} + +func TestIsNumeric(t *testing.T) { + tests := []struct { + input string + want bool + }{ + {"123456", true}, + {"0", true}, + {"", false}, + {"abc", false}, + {"12a34", false}, + {"telegram", false}, + } + + for _, tt := range tests { + got := isNumeric(tt.input) + if got != tt.want { + t.Errorf("isNumeric(%q) = %v, want %v", tt.input, got, tt.want) + } + } +} diff --git a/pkg/media/store.go b/pkg/media/store.go new file mode 100644 index 000000000..2df4420e9 --- /dev/null +++ b/pkg/media/store.go @@ -0,0 +1,123 @@ +package media + +import ( + "fmt" + "os" + "sync" + + "github.com/google/uuid" +) + +// MediaMeta holds metadata about a stored media file. +type MediaMeta struct { + Filename string + ContentType string + Source string // "telegram", "discord", "tool:image-gen", etc. +} + +// MediaStore manages the lifecycle of media files associated with processing scopes. +type MediaStore interface { + // Store registers an existing local file under the given scope. + // Returns a ref identifier (e.g. "media://"). + // Store does not move or copy the file; it only records the mapping. + Store(localPath string, meta MediaMeta, scope string) (ref string, err error) + + // Resolve returns the local file path for a given ref. + Resolve(ref string) (localPath string, err error) + + // ResolveWithMeta returns the local file path and metadata for a given ref. + ResolveWithMeta(ref string) (localPath string, meta MediaMeta, err error) + + // ReleaseAll deletes all files registered under the given scope + // and removes the mapping entries. File-not-exist errors are ignored. + ReleaseAll(scope string) error +} + +// mediaEntry holds the path and metadata for a stored media file. +type mediaEntry struct { + path string + meta MediaMeta +} + +// FileMediaStore is a pure in-memory implementation of MediaStore. +// Files are expected to already exist on disk (e.g. in /tmp/picoclaw_media/). +type FileMediaStore struct { + mu sync.RWMutex + refs map[string]mediaEntry + scopeToRefs map[string]map[string]struct{} +} + +// NewFileMediaStore creates a new FileMediaStore. +func NewFileMediaStore() *FileMediaStore { + return &FileMediaStore{ + refs: make(map[string]mediaEntry), + scopeToRefs: make(map[string]map[string]struct{}), + } +} + +// Store registers a local file under the given scope. The file must exist. +func (s *FileMediaStore) Store(localPath string, meta MediaMeta, scope string) (string, error) { + if _, err := os.Stat(localPath); err != nil { + return "", fmt.Errorf("media store: %s: %w", localPath, err) + } + + ref := "media://" + uuid.New().String() + + s.mu.Lock() + defer s.mu.Unlock() + + s.refs[ref] = mediaEntry{path: localPath, meta: meta} + if s.scopeToRefs[scope] == nil { + s.scopeToRefs[scope] = make(map[string]struct{}) + } + s.scopeToRefs[scope][ref] = struct{}{} + + return ref, nil +} + +// Resolve returns the local path for the given ref. +func (s *FileMediaStore) Resolve(ref string) (string, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + entry, ok := s.refs[ref] + if !ok { + return "", fmt.Errorf("media store: unknown ref: %s", ref) + } + return entry.path, nil +} + +// ResolveWithMeta returns the local path and metadata for the given ref. +func (s *FileMediaStore) ResolveWithMeta(ref string) (string, MediaMeta, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + entry, ok := s.refs[ref] + if !ok { + return "", MediaMeta{}, fmt.Errorf("media store: unknown ref: %s", ref) + } + return entry.path, entry.meta, nil +} + +// ReleaseAll removes all files under the given scope and cleans up mappings. +func (s *FileMediaStore) ReleaseAll(scope string) error { + s.mu.Lock() + defer s.mu.Unlock() + + refs, ok := s.scopeToRefs[scope] + if !ok { + return nil + } + + for ref := range refs { + if entry, exists := s.refs[ref]; exists { + if err := os.Remove(entry.path); err != nil && !os.IsNotExist(err) { + // Log but continue — best effort cleanup + } + delete(s.refs, ref) + } + } + + delete(s.scopeToRefs, scope) + return nil +} diff --git a/pkg/media/store_test.go b/pkg/media/store_test.go new file mode 100644 index 000000000..95bd1eb7a --- /dev/null +++ b/pkg/media/store_test.go @@ -0,0 +1,223 @@ +package media + +import ( + "os" + "path/filepath" + "strings" + "sync" + "testing" +) + +func createTempFile(t *testing.T, dir, name string) string { + t.Helper() + path := filepath.Join(dir, name) + if err := os.WriteFile(path, []byte("test content"), 0o644); err != nil { + t.Fatalf("failed to create temp file: %v", err) + } + return path +} + +func TestStoreAndResolve(t *testing.T) { + dir := t.TempDir() + store := NewFileMediaStore() + + path := createTempFile(t, dir, "photo.jpg") + + ref, err := store.Store(path, MediaMeta{Filename: "photo.jpg", Source: "telegram"}, "scope1") + if err != nil { + t.Fatalf("Store failed: %v", err) + } + + if !strings.HasPrefix(ref, "media://") { + t.Errorf("ref should start with media://, got %q", ref) + } + + resolved, err := store.Resolve(ref) + if err != nil { + t.Fatalf("Resolve failed: %v", err) + } + if resolved != path { + t.Errorf("Resolve returned %q, want %q", resolved, path) + } +} + +func TestReleaseAll(t *testing.T) { + dir := t.TempDir() + store := NewFileMediaStore() + + paths := make([]string, 3) + refs := make([]string, 3) + for i := 0; i < 3; i++ { + paths[i] = createTempFile(t, dir, strings.Repeat("a", i+1)+".jpg") + var err error + refs[i], err = store.Store(paths[i], MediaMeta{Source: "test"}, "scope1") + if err != nil { + t.Fatalf("Store failed: %v", err) + } + } + + if err := store.ReleaseAll("scope1"); err != nil { + t.Fatalf("ReleaseAll failed: %v", err) + } + + // Files should be deleted + for _, p := range paths { + if _, err := os.Stat(p); !os.IsNotExist(err) { + t.Errorf("file %q should have been deleted", p) + } + } + + // Refs should be unresolvable + for _, ref := range refs { + if _, err := store.Resolve(ref); err == nil { + t.Errorf("Resolve(%q) should fail after ReleaseAll", ref) + } + } +} + +func TestMultiScopeIsolation(t *testing.T) { + dir := t.TempDir() + store := NewFileMediaStore() + + pathA := createTempFile(t, dir, "fileA.jpg") + pathB := createTempFile(t, dir, "fileB.jpg") + + refA, _ := store.Store(pathA, MediaMeta{Source: "test"}, "scopeA") + refB, _ := store.Store(pathB, MediaMeta{Source: "test"}, "scopeB") + + // Release only scopeA + if err := store.ReleaseAll("scopeA"); err != nil { + t.Fatalf("ReleaseAll(scopeA) failed: %v", err) + } + + // scopeA file should be gone + if _, err := os.Stat(pathA); !os.IsNotExist(err) { + t.Error("file A should have been deleted") + } + if _, err := store.Resolve(refA); err == nil { + t.Error("refA should be unresolvable after release") + } + + // scopeB file should still exist + if _, err := os.Stat(pathB); err != nil { + t.Error("file B should still exist") + } + resolved, err := store.Resolve(refB) + if err != nil { + t.Fatalf("refB should still resolve: %v", err) + } + if resolved != pathB { + t.Errorf("resolved %q, want %q", resolved, pathB) + } +} + +func TestReleaseAllIdempotent(t *testing.T) { + store := NewFileMediaStore() + + // ReleaseAll on non-existent scope should not error + if err := store.ReleaseAll("nonexistent"); err != nil { + t.Fatalf("ReleaseAll on empty scope should not error: %v", err) + } + + // Create and release, then release again + dir := t.TempDir() + path := createTempFile(t, dir, "file.jpg") + _, _ = store.Store(path, MediaMeta{Source: "test"}, "scope1") + + if err := store.ReleaseAll("scope1"); err != nil { + t.Fatalf("first ReleaseAll failed: %v", err) + } + if err := store.ReleaseAll("scope1"); err != nil { + t.Fatalf("second ReleaseAll should not error: %v", err) + } +} + +func TestStoreNonexistentFile(t *testing.T) { + store := NewFileMediaStore() + + _, err := store.Store("/nonexistent/path/file.jpg", MediaMeta{Source: "test"}, "scope1") + if err == nil { + t.Error("Store should fail for nonexistent file") + } + // Error message should include the underlying os error, not just "file does not exist" + if !strings.Contains(err.Error(), "no such file or directory") { + t.Errorf("Error should contain OS error detail, got: %v", err) + } +} + +func TestResolveWithMeta(t *testing.T) { + dir := t.TempDir() + store := NewFileMediaStore() + + path := createTempFile(t, dir, "image.png") + meta := MediaMeta{ + Filename: "image.png", + ContentType: "image/png", + Source: "telegram", + } + + ref, err := store.Store(path, meta, "scope1") + if err != nil { + t.Fatalf("Store failed: %v", err) + } + + resolvedPath, resolvedMeta, err := store.ResolveWithMeta(ref) + if err != nil { + t.Fatalf("ResolveWithMeta failed: %v", err) + } + if resolvedPath != path { + t.Errorf("ResolveWithMeta path = %q, want %q", resolvedPath, path) + } + if resolvedMeta.Filename != meta.Filename { + t.Errorf("ResolveWithMeta Filename = %q, want %q", resolvedMeta.Filename, meta.Filename) + } + if resolvedMeta.ContentType != meta.ContentType { + t.Errorf("ResolveWithMeta ContentType = %q, want %q", resolvedMeta.ContentType, meta.ContentType) + } + if resolvedMeta.Source != meta.Source { + t.Errorf("ResolveWithMeta Source = %q, want %q", resolvedMeta.Source, meta.Source) + } + + // Unknown ref should fail + _, _, err = store.ResolveWithMeta("media://nonexistent") + if err == nil { + t.Error("ResolveWithMeta should fail for unknown ref") + } +} + +func TestConcurrentSafety(t *testing.T) { + dir := t.TempDir() + store := NewFileMediaStore() + + const goroutines = 20 + const filesPerGoroutine = 5 + + var wg sync.WaitGroup + wg.Add(goroutines) + + for g := 0; g < goroutines; g++ { + go func(gIdx int) { + defer wg.Done() + scope := strings.Repeat("s", gIdx+1) + + for i := 0; i < filesPerGoroutine; i++ { + path := createTempFile(t, dir, strings.Repeat("f", gIdx*filesPerGoroutine+i+1)+".tmp") + ref, err := store.Store(path, MediaMeta{Source: "test"}, scope) + if err != nil { + t.Errorf("Store failed: %v", err) + return + } + + if _, err := store.Resolve(ref); err != nil { + t.Errorf("Resolve failed: %v", err) + } + } + + if err := store.ReleaseAll(scope); err != nil { + t.Errorf("ReleaseAll failed: %v", err) + } + }(g) + } + + wg.Wait() +} diff --git a/pkg/providers/openai_compat/provider.go b/pkg/providers/openai_compat/provider.go index 236a048c4..e8bd21ebd 100644 --- a/pkg/providers/openai_compat/provider.go +++ b/pkg/providers/openai_compat/provider.go @@ -25,6 +25,7 @@ type ( ToolFunctionDefinition = protocoltypes.ToolFunctionDefinition ExtraContent = protocoltypes.ExtraContent GoogleExtra = protocoltypes.GoogleExtra + ReasoningDetail = protocoltypes.ReasoningDetail ) type Provider struct { @@ -148,8 +149,10 @@ func parseResponse(body []byte) (*LLMResponse, error) { var apiResponse struct { Choices []struct { Message struct { - Content string `json:"content"` - ToolCalls []struct { + Content string `json:"content"` + Reasoning string `json:"reasoning"` + ReasoningDetails []ReasoningDetail `json:"reasoning_details"` + ToolCalls []struct { ID string `json:"id"` Type string `json:"type"` Function *struct { @@ -221,10 +224,12 @@ func parseResponse(body []byte) (*LLMResponse, error) { } return &LLMResponse{ - Content: choice.Message.Content, - ToolCalls: toolCalls, - FinishReason: choice.FinishReason, - Usage: apiResponse.Usage, + Reasoning: choice.Message.Reasoning, + ReasoningDetails: choice.Message.ReasoningDetails, + Content: choice.Message.Content, + ToolCalls: toolCalls, + FinishReason: choice.FinishReason, + Usage: apiResponse.Usage, }, nil } diff --git a/pkg/providers/protocoltypes/types.go b/pkg/providers/protocoltypes/types.go index 5e1c6d397..3e18fca43 100644 --- a/pkg/providers/protocoltypes/types.go +++ b/pkg/providers/protocoltypes/types.go @@ -25,12 +25,19 @@ type FunctionCall struct { } type LLMResponse struct { - Content string `json:"content"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - FinishReason string `json:"finish_reason"` - Usage *UsageInfo `json:"usage,omitempty"` + Content string `json:"content"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + FinishReason string `json:"finish_reason"` + Usage *UsageInfo `json:"usage,omitempty"` + Reasoning string `json:"reasoning"` + ReasoningDetails []ReasoningDetail `json:"reasoning_details"` +} +type ReasoningDetail struct { + Format string `json:"format"` + Index int `json:"index"` + Type string `json:"type"` + Text string `json:"text"` } - type UsageInfo struct { PromptTokens int `json:"prompt_tokens"` CompletionTokens int `json:"completion_tokens"` diff --git a/pkg/routing/session_key.go b/pkg/routing/session_key.go index e12f0d1d8..eab592bec 100644 --- a/pkg/routing/session_key.go +++ b/pkg/routing/session_key.go @@ -163,6 +163,15 @@ func resolveLinkedPeerID(identityLinks map[string][]string, channel, peerID stri scopedCandidate := fmt.Sprintf("%s:%s", channel, strings.ToLower(peerID)) candidates[scopedCandidate] = true } + + // If peerID is already in canonical "platform:id" format, also add the + // bare ID part as a candidate for backward compatibility with identity_links + // that use raw IDs (e.g. "123" instead of "telegram:123"). + if idx := strings.Index(rawCandidate, ":"); idx > 0 && idx < len(rawCandidate)-1 { + bareID := rawCandidate[idx+1:] + candidates[bareID] = true + } + if len(candidates) == 0 { return "" } diff --git a/pkg/routing/session_key_test.go b/pkg/routing/session_key_test.go index 81e4ce018..ad7a1ca02 100644 --- a/pkg/routing/session_key_test.go +++ b/pkg/routing/session_key_test.go @@ -115,6 +115,51 @@ func TestBuildAgentPeerSessionKey_IdentityLink(t *testing.T) { } } +func TestResolveLinkedPeerID_CanonicalPeerID(t *testing.T) { + // When peerID is already in canonical "platform:id" format, + // it should match identity_links that use the bare ID. + links := map[string][]string{ + "john": {"123"}, + } + got := resolveLinkedPeerID(links, "telegram", "telegram:123") + if got != "john" { + t.Errorf("resolveLinkedPeerID with canonical peerID = %q, want %q", got, "john") + } +} + +func TestResolveLinkedPeerID_CanonicalInLinks(t *testing.T) { + // When identity_links contain canonical IDs and peerID is canonical too + links := map[string][]string{ + "john": {"telegram:123", "discord:456"}, + } + got := resolveLinkedPeerID(links, "telegram", "telegram:123") + if got != "john" { + t.Errorf("resolveLinkedPeerID canonical in links = %q, want %q", got, "john") + } +} + +func TestResolveLinkedPeerID_BarePeerIDMatchesCanonicalLink(t *testing.T) { + // When peerID is bare "123" and links have "telegram:123", + // the scoped candidate "telegram:123" should match. + links := map[string][]string{ + "john": {"telegram:123"}, + } + got := resolveLinkedPeerID(links, "telegram", "123") + if got != "john" { + t.Errorf("resolveLinkedPeerID bare peer matches canonical link = %q, want %q", got, "john") + } +} + +func TestResolveLinkedPeerID_NoMatch(t *testing.T) { + links := map[string][]string{ + "john": {"telegram:123"}, + } + got := resolveLinkedPeerID(links, "discord", "999") + if got != "" { + t.Errorf("resolveLinkedPeerID no match = %q, want empty", got) + } +} + func TestParseAgentSessionKey_Valid(t *testing.T) { parsed := ParseAgentSessionKey("agent:sales:telegram:direct:user123") if parsed == nil { diff --git a/pkg/tools/cron.go b/pkg/tools/cron.go index 562fffc84..52f914622 100644 --- a/pkg/tools/cron.go +++ b/pkg/tools/cron.go @@ -294,7 +294,9 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string { output = fmt.Sprintf("Scheduled command '%s' executed:\n%s", job.Payload.Command, result.ForLLM) } - t.msgBus.PublishOutbound(bus.OutboundMessage{ + pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer pubCancel() + t.msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{ Channel: channel, ChatID: chatID, Content: output, @@ -304,7 +306,9 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string { // If deliver=true, send message directly without agent processing if job.Payload.Deliver { - t.msgBus.PublishOutbound(bus.OutboundMessage{ + pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer pubCancel() + t.msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{ Channel: channel, ChatID: chatID, Content: job.Payload.Message, diff --git a/pkg/tools/result.go b/pkg/tools/result.go index b13055b1c..cab833284 100644 --- a/pkg/tools/result.go +++ b/pkg/tools/result.go @@ -30,6 +30,10 @@ type ToolResult struct { // Err is the underlying error (not JSON serialized). // Used for internal error handling and logging. Err error `json:"-"` + + // Media contains media store refs produced by this tool. + // When non-empty, the agent will publish these as OutboundMediaMessage. + Media []string `json:"media,omitempty"` } // NewToolResult creates a basic ToolResult with content for the LLM. @@ -120,6 +124,19 @@ func UserResult(content string) *ToolResult { } } +// MediaResult creates a ToolResult with media refs for the user. +// The agent will publish these refs as OutboundMediaMessage. +// +// Example: +// +// result := MediaResult("Image generated successfully", []string{"media://abc123"}) +func MediaResult(forLLM string, mediaRefs []string) *ToolResult { + return &ToolResult{ + ForLLM: forLLM, + Media: mediaRefs, + } +} + // MarshalJSON implements custom JSON serialization. // The Err field is excluded from JSON output via the json:"-" tag. func (tr *ToolResult) MarshalJSON() ([]byte, error) { diff --git a/pkg/tools/subagent.go b/pkg/tools/subagent.go index 91ebff636..fee53fc28 100644 --- a/pkg/tools/subagent.go +++ b/pkg/tools/subagent.go @@ -218,7 +218,9 @@ After completing the task, provide a clear summary of what was done.` // Send announce message back to main agent if sm.bus != nil { announceContent := fmt.Sprintf("Task '%s' completed.\n\nResult:\n%s", task.Label, task.Result) - sm.bus.PublishInbound(bus.InboundMessage{ + pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer pubCancel() + sm.bus.PublishInbound(pubCtx, bus.InboundMessage{ Channel: "system", SenderID: fmt.Sprintf("subagent:%s", task.ID), // Format: "original_channel:original_chat_id" for routing back diff --git a/pkg/utils/message.go b/pkg/utils/message.go deleted file mode 100644 index 1d05950d9..000000000 --- a/pkg/utils/message.go +++ /dev/null @@ -1,179 +0,0 @@ -package utils - -import ( - "strings" -) - -// SplitMessage splits long messages into chunks, preserving code block integrity. -// The function reserves a buffer (10% of maxLen, min 50) to leave room for closing code blocks, -// but may extend to maxLen when needed. -// Call SplitMessage with the full text content and the maximum allowed length of a single message; -// it returns a slice of message chunks that each respect maxLen and avoid splitting fenced code blocks. -func SplitMessage(content string, maxLen int) []string { - var messages []string - - // Dynamic buffer: 10% of maxLen, but at least 50 chars if possible - codeBlockBuffer := maxLen / 10 - if codeBlockBuffer < 50 { - codeBlockBuffer = 50 - } - if codeBlockBuffer > maxLen/2 { - codeBlockBuffer = maxLen / 2 - } - - for len(content) > 0 { - if len(content) <= maxLen { - messages = append(messages, content) - break - } - - // Effective split point: maxLen minus buffer, to leave room for code blocks - effectiveLimit := maxLen - codeBlockBuffer - if effectiveLimit < maxLen/2 { - effectiveLimit = maxLen / 2 - } - - // Find natural split point within the effective limit - msgEnd := findLastNewline(content[:effectiveLimit], 200) - if msgEnd <= 0 { - msgEnd = findLastSpace(content[:effectiveLimit], 100) - } - if msgEnd <= 0 { - msgEnd = effectiveLimit - } - - // Check if this would end with an incomplete code block - candidate := content[:msgEnd] - unclosedIdx := findLastUnclosedCodeBlock(candidate) - - if unclosedIdx >= 0 { - // Message would end with incomplete code block - // Try to extend up to maxLen to include the closing ``` - if len(content) > msgEnd { - closingIdx := findNextClosingCodeBlock(content, msgEnd) - if closingIdx > 0 && closingIdx <= maxLen { - // Extend to include the closing ``` - msgEnd = closingIdx - } else { - // Code block is too long to fit in one chunk or missing closing fence. - // Try to split inside by injecting closing and reopening fences. - headerEnd := strings.Index(content[unclosedIdx:], "\n") - if headerEnd == -1 { - headerEnd = unclosedIdx + 3 - } else { - headerEnd += unclosedIdx - } - header := strings.TrimSpace(content[unclosedIdx:headerEnd]) - - // If we have a reasonable amount of content after the header, split inside - if msgEnd > headerEnd+20 { - // Find a better split point closer to maxLen - innerLimit := maxLen - 5 // Leave room for "\n```" - betterEnd := findLastNewline(content[:innerLimit], 200) - if betterEnd > headerEnd { - msgEnd = betterEnd - } else { - msgEnd = innerLimit - } - messages = append(messages, strings.TrimRight(content[:msgEnd], " \t\n\r")+"\n```") - content = strings.TrimSpace(header + "\n" + content[msgEnd:]) - continue - } - - // Otherwise, try to split before the code block starts - newEnd := findLastNewline(content[:unclosedIdx], 200) - if newEnd <= 0 { - newEnd = findLastSpace(content[:unclosedIdx], 100) - } - if newEnd > 0 { - msgEnd = newEnd - } else { - // If we can't split before, we MUST split inside (last resort) - if unclosedIdx > 20 { - msgEnd = unclosedIdx - } else { - msgEnd = maxLen - 5 - messages = append(messages, strings.TrimRight(content[:msgEnd], " \t\n\r")+"\n```") - content = strings.TrimSpace(header + "\n" + content[msgEnd:]) - continue - } - } - } - } - } - - if msgEnd <= 0 { - msgEnd = effectiveLimit - } - - messages = append(messages, content[:msgEnd]) - content = strings.TrimSpace(content[msgEnd:]) - } - - return messages -} - -// findLastUnclosedCodeBlock finds the last opening ``` that doesn't have a closing ``` -// Returns the position of the opening ``` or -1 if all code blocks are complete -func findLastUnclosedCodeBlock(text string) int { - inCodeBlock := false - lastOpenIdx := -1 - - for i := 0; i < len(text); i++ { - if i+2 < len(text) && text[i] == '`' && text[i+1] == '`' && text[i+2] == '`' { - // Toggle code block state on each fence - if !inCodeBlock { - // Entering a code block: record this opening fence - lastOpenIdx = i - } - inCodeBlock = !inCodeBlock - i += 2 - } - } - - if inCodeBlock { - return lastOpenIdx - } - return -1 -} - -// findNextClosingCodeBlock finds the next closing ``` starting from a position -// Returns the position after the closing ``` or -1 if not found -func findNextClosingCodeBlock(text string, startIdx int) int { - for i := startIdx; i < len(text); i++ { - if i+2 < len(text) && text[i] == '`' && text[i+1] == '`' && text[i+2] == '`' { - return i + 3 - } - } - return -1 -} - -// findLastNewline finds the last newline character within the last N characters -// Returns the position of the newline or -1 if not found -func findLastNewline(s string, searchWindow int) int { - searchStart := len(s) - searchWindow - if searchStart < 0 { - searchStart = 0 - } - for i := len(s) - 1; i >= searchStart; i-- { - if s[i] == '\n' { - return i - } - } - return -1 -} - -// findLastSpace finds the last space character within the last N characters -// Returns the position of the space or -1 if not found -func findLastSpace(s string, searchWindow int) int { - searchStart := len(s) - searchWindow - if searchStart < 0 { - searchStart = 0 - } - for i := len(s) - 1; i >= searchStart; i-- { - if s[i] == ' ' || s[i] == '\t' { - return i - } - } - return -1 -}