From 57c1d37c22b5c18ee0b1b055966d2673744b98ab Mon Sep 17 00:00:00 2001 From: Hoshina Date: Fri, 20 Feb 2026 23:18:46 +0800 Subject: [PATCH 01/28] refactor(channels): add factory registry and export SetRunning on BaseChannel --- pkg/channels/base.go | 4 ++++ pkg/channels/registry.go | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+) create mode 100644 pkg/channels/registry.go diff --git a/pkg/channels/base.go b/pkg/channels/base.go index cd6419ebb..3f0a766ea 100644 --- a/pkg/channels/base.go +++ b/pkg/channels/base.go @@ -101,3 +101,7 @@ func (c *BaseChannel) HandleMessage(senderID, chatID, content string, media []st func (c *BaseChannel) setRunning(running bool) { c.running = running } + +func (c *BaseChannel) SetRunning(running bool) { + c.running = running +} diff --git a/pkg/channels/registry.go b/pkg/channels/registry.go new file mode 100644 index 000000000..36a05bf3e --- /dev/null +++ b/pkg/channels/registry.go @@ -0,0 +1,32 @@ +package channels + +import ( + "sync" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" +) + +// ChannelFactory is a constructor function that creates a Channel from config and message bus. +// Each channel subpackage registers one or more factories via init(). +type ChannelFactory func(cfg *config.Config, bus *bus.MessageBus) (Channel, error) + +var ( + factoriesMu sync.RWMutex + factories = map[string]ChannelFactory{} +) + +// RegisterFactory registers a named channel factory. Called from subpackage init() functions. +func RegisterFactory(name string, f ChannelFactory) { + factoriesMu.Lock() + defer factoriesMu.Unlock() + factories[name] = f +} + +// getFactory looks up a channel factory by name. 
+func getFactory(name string) (ChannelFactory, bool) { + factoriesMu.RLock() + defer factoriesMu.RUnlock() + f, ok := factories[name] + return f, ok +} From 383687dc57e93c33b4cfea2f669ba5eb3f4349d4 Mon Sep 17 00:00:00 2001 From: Hoshina Date: Fri, 20 Feb 2026 23:19:40 +0800 Subject: [PATCH 02/28] refactor(channels): replace direct constructors with factory registry in manager --- pkg/channels/manager.go | 178 +++++++++++----------------------------- 1 file changed, 48 insertions(+), 130 deletions(-) diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go index 75edaf49e..091982282 100644 --- a/pkg/channels/manager.go +++ b/pkg/channels/manager.go @@ -43,166 +43,84 @@ func NewManager(cfg *config.Config, messageBus *bus.MessageBus) (*Manager, error return m, nil } +// initChannel is a helper that looks up a factory by name and creates the channel. +func (m *Manager) initChannel(name, displayName string) { + f, ok := getFactory(name) + if !ok { + logger.WarnCF("channels", "Factory not registered", map[string]interface{}{ + "channel": displayName, + }) + return + } + logger.DebugCF("channels", "Attempting to initialize channel", map[string]interface{}{ + "channel": displayName, + }) + ch, err := f(m.config, m.bus) + if err != nil { + logger.ErrorCF("channels", "Failed to initialize channel", map[string]interface{}{ + "channel": displayName, + "error": err.Error(), + }) + } else { + m.channels[name] = ch + logger.InfoCF("channels", "Channel enabled successfully", map[string]interface{}{ + "channel": displayName, + }) + } +} + func (m *Manager) initChannels() error { logger.InfoC("channels", "Initializing channel manager") if m.config.Channels.Telegram.Enabled && m.config.Channels.Telegram.Token != "" { - logger.DebugC("channels", "Attempting to initialize Telegram channel") - telegram, err := NewTelegramChannel(m.config, m.bus) - if err != nil { - logger.ErrorCF("channels", "Failed to initialize Telegram channel", map[string]any{ - "error": err.Error(), - }) - } 
else { - m.channels["telegram"] = telegram - logger.InfoC("channels", "Telegram channel enabled successfully") - } + m.initChannel("telegram", "Telegram") } if m.config.Channels.WhatsApp.Enabled && m.config.Channels.WhatsApp.BridgeURL != "" { - logger.DebugC("channels", "Attempting to initialize WhatsApp channel") - whatsapp, err := NewWhatsAppChannel(m.config.Channels.WhatsApp, m.bus) - if err != nil { - logger.ErrorCF("channels", "Failed to initialize WhatsApp channel", map[string]any{ - "error": err.Error(), - }) - } else { - m.channels["whatsapp"] = whatsapp - logger.InfoC("channels", "WhatsApp channel enabled successfully") - } + m.initChannel("whatsapp", "WhatsApp") } if m.config.Channels.Feishu.Enabled { - logger.DebugC("channels", "Attempting to initialize Feishu channel") - feishu, err := NewFeishuChannel(m.config.Channels.Feishu, m.bus) - if err != nil { - logger.ErrorCF("channels", "Failed to initialize Feishu channel", map[string]any{ - "error": err.Error(), - }) - } else { - m.channels["feishu"] = feishu - logger.InfoC("channels", "Feishu channel enabled successfully") - } + m.initChannel("feishu", "Feishu") } if m.config.Channels.Discord.Enabled && m.config.Channels.Discord.Token != "" { - logger.DebugC("channels", "Attempting to initialize Discord channel") - discord, err := NewDiscordChannel(m.config.Channels.Discord, m.bus) - if err != nil { - logger.ErrorCF("channels", "Failed to initialize Discord channel", map[string]any{ - "error": err.Error(), - }) - } else { - m.channels["discord"] = discord - logger.InfoC("channels", "Discord channel enabled successfully") - } + m.initChannel("discord", "Discord") } if m.config.Channels.MaixCam.Enabled { - logger.DebugC("channels", "Attempting to initialize MaixCam channel") - maixcam, err := NewMaixCamChannel(m.config.Channels.MaixCam, m.bus) - if err != nil { - logger.ErrorCF("channels", "Failed to initialize MaixCam channel", map[string]any{ - "error": err.Error(), - }) - } else { - m.channels["maixcam"] 
= maixcam - logger.InfoC("channels", "MaixCam channel enabled successfully") - } + m.initChannel("maixcam", "MaixCam") } if m.config.Channels.QQ.Enabled { - logger.DebugC("channels", "Attempting to initialize QQ channel") - qq, err := NewQQChannel(m.config.Channels.QQ, m.bus) - if err != nil { - logger.ErrorCF("channels", "Failed to initialize QQ channel", map[string]any{ - "error": err.Error(), - }) - } else { - m.channels["qq"] = qq - logger.InfoC("channels", "QQ channel enabled successfully") - } + m.initChannel("qq", "QQ") } if m.config.Channels.DingTalk.Enabled && m.config.Channels.DingTalk.ClientID != "" { - logger.DebugC("channels", "Attempting to initialize DingTalk channel") - dingtalk, err := NewDingTalkChannel(m.config.Channels.DingTalk, m.bus) - if err != nil { - logger.ErrorCF("channels", "Failed to initialize DingTalk channel", map[string]any{ - "error": err.Error(), - }) - } else { - m.channels["dingtalk"] = dingtalk - logger.InfoC("channels", "DingTalk channel enabled successfully") - } + m.initChannel("dingtalk", "DingTalk") } if m.config.Channels.Slack.Enabled && m.config.Channels.Slack.BotToken != "" { - logger.DebugC("channels", "Attempting to initialize Slack channel") - slackCh, err := NewSlackChannel(m.config.Channels.Slack, m.bus) - if err != nil { - logger.ErrorCF("channels", "Failed to initialize Slack channel", map[string]any{ - "error": err.Error(), - }) - } else { - m.channels["slack"] = slackCh - logger.InfoC("channels", "Slack channel enabled successfully") - } + m.initChannel("slack", "Slack") } if m.config.Channels.LINE.Enabled && m.config.Channels.LINE.ChannelAccessToken != "" { - logger.DebugC("channels", "Attempting to initialize LINE channel") - line, err := NewLINEChannel(m.config.Channels.LINE, m.bus) - if err != nil { - logger.ErrorCF("channels", "Failed to initialize LINE channel", map[string]any{ - "error": err.Error(), - }) - } else { - m.channels["line"] = line - logger.InfoC("channels", "LINE channel enabled 
successfully") - } + m.initChannel("line", "LINE") } if m.config.Channels.OneBot.Enabled && m.config.Channels.OneBot.WSUrl != "" { - logger.DebugC("channels", "Attempting to initialize OneBot channel") - onebot, err := NewOneBotChannel(m.config.Channels.OneBot, m.bus) - if err != nil { - logger.ErrorCF("channels", "Failed to initialize OneBot channel", map[string]any{ - "error": err.Error(), - }) - } else { - m.channels["onebot"] = onebot - logger.InfoC("channels", "OneBot channel enabled successfully") - } + m.initChannel("onebot", "OneBot") } if m.config.Channels.WeCom.Enabled && m.config.Channels.WeCom.Token != "" { - logger.DebugC("channels", "Attempting to initialize WeCom channel") - wecom, err := NewWeComBotChannel(m.config.Channels.WeCom, m.bus) - if err != nil { - logger.ErrorCF("channels", "Failed to initialize WeCom channel", map[string]any{ - "error": err.Error(), - }) - } else { - m.channels["wecom"] = wecom - logger.InfoC("channels", "WeCom channel enabled successfully") - } + m.initChannel("wecom", "WeCom") } if m.config.Channels.WeComApp.Enabled && m.config.Channels.WeComApp.CorpID != "" { - logger.DebugC("channels", "Attempting to initialize WeCom App channel") - wecomApp, err := NewWeComAppChannel(m.config.Channels.WeComApp, m.bus) - if err != nil { - logger.ErrorCF("channels", "Failed to initialize WeCom App channel", map[string]any{ - "error": err.Error(), - }) - } else { - m.channels["wecom_app"] = wecomApp - logger.InfoC("channels", "WeCom App channel enabled successfully") - } + m.initChannel("wecom_app", "WeCom App") } - logger.InfoCF("channels", "Channel initialization completed", map[string]any{ + logger.InfoCF("channels", "Channel initialization completed", map[string]interface{}{ "enabled_channels": len(m.channels), }) @@ -226,11 +144,11 @@ func (m *Manager) StartAll(ctx context.Context) error { go m.dispatchOutbound(dispatchCtx) for name, channel := range m.channels { - logger.InfoCF("channels", "Starting channel", map[string]any{ + 
logger.InfoCF("channels", "Starting channel", map[string]interface{}{ "channel": name, }) if err := channel.Start(ctx); err != nil { - logger.ErrorCF("channels", "Failed to start channel", map[string]any{ + logger.ErrorCF("channels", "Failed to start channel", map[string]interface{}{ "channel": name, "error": err.Error(), }) @@ -253,11 +171,11 @@ func (m *Manager) StopAll(ctx context.Context) error { } for name, channel := range m.channels { - logger.InfoCF("channels", "Stopping channel", map[string]any{ + logger.InfoCF("channels", "Stopping channel", map[string]interface{}{ "channel": name, }) if err := channel.Stop(ctx); err != nil { - logger.ErrorCF("channels", "Error stopping channel", map[string]any{ + logger.ErrorCF("channels", "Error stopping channel", map[string]interface{}{ "channel": name, "error": err.Error(), }) @@ -292,14 +210,14 @@ func (m *Manager) dispatchOutbound(ctx context.Context) { m.mu.RUnlock() if !exists { - logger.WarnCF("channels", "Unknown channel for outbound message", map[string]any{ + logger.WarnCF("channels", "Unknown channel for outbound message", map[string]interface{}{ "channel": msg.Channel, }) continue } if err := channel.Send(ctx, msg); err != nil { - logger.ErrorCF("channels", "Error sending message to channel", map[string]any{ + logger.ErrorCF("channels", "Error sending message to channel", map[string]interface{}{ "channel": msg.Channel, "error": err.Error(), }) @@ -315,13 +233,13 @@ func (m *Manager) GetChannel(name string) (Channel, bool) { return channel, ok } -func (m *Manager) GetStatus() map[string]any { +func (m *Manager) GetStatus() map[string]interface{} { m.mu.RLock() defer m.mu.RUnlock() - status := make(map[string]any) + status := make(map[string]interface{}) for name, channel := range m.channels { - status[name] = map[string]any{ + status[name] = map[string]interface{}{ "enabled": true, "running": channel.IsRunning(), } From 36eb68dd6784fe5b0815acd1d960e0980a3353f0 Mon Sep 17 00:00:00 2001 From: Hoshina Date: Fri, 
20 Feb 2026 23:25:44 +0800 Subject: [PATCH 03/28] refactor(channels): add channel subpackages and update gateway imports --- cmd/picoclaw/cmd_gateway.go | 19 +- pkg/channels/dingtalk/dingtalk.go | 202 ++++ pkg/channels/dingtalk/init.go | 13 + pkg/channels/discord/discord.go | 373 +++++++ pkg/channels/discord/init.go | 13 + pkg/channels/feishu/common.go | 9 + pkg/channels/feishu/feishu_32.go | 37 + pkg/channels/feishu/feishu_64.go | 221 ++++ pkg/channels/feishu/init.go | 13 + pkg/channels/line/init.go | 13 + pkg/channels/line/line.go | 607 +++++++++++ pkg/channels/maixcam/init.go | 13 + pkg/channels/maixcam/maixcam.go | 244 +++++ pkg/channels/onebot/init.go | 13 + pkg/channels/onebot/onebot.go | 980 ++++++++++++++++++ pkg/channels/qq/init.go | 13 + pkg/channels/qq/qq.go | 248 +++++ pkg/channels/slack/init.go | 13 + pkg/channels/slack/slack.go | 444 ++++++++ pkg/channels/slack/slack_test.go | 174 ++++ pkg/channels/telegram/init.go | 13 + pkg/channels/telegram/telegram.go | 526 ++++++++++ pkg/channels/telegram/telegram_commands.go | 153 +++ pkg/channels/wecom/app.go | 636 ++++++++++++ pkg/channels/wecom/app_test.go | 1086 ++++++++++++++++++++ pkg/channels/wecom/bot.go | 469 +++++++++ pkg/channels/wecom/bot_test.go | 753 ++++++++++++++ pkg/channels/wecom/common.go | 134 +++ pkg/channels/wecom/init.go | 16 + pkg/channels/whatsapp/init.go | 13 + pkg/channels/whatsapp/whatsapp.go | 193 ++++ 31 files changed, 7651 insertions(+), 3 deletions(-) create mode 100644 pkg/channels/dingtalk/dingtalk.go create mode 100644 pkg/channels/dingtalk/init.go create mode 100644 pkg/channels/discord/discord.go create mode 100644 pkg/channels/discord/init.go create mode 100644 pkg/channels/feishu/common.go create mode 100644 pkg/channels/feishu/feishu_32.go create mode 100644 pkg/channels/feishu/feishu_64.go create mode 100644 pkg/channels/feishu/init.go create mode 100644 pkg/channels/line/init.go create mode 100644 pkg/channels/line/line.go create mode 100644 pkg/channels/maixcam/init.go 
create mode 100644 pkg/channels/maixcam/maixcam.go create mode 100644 pkg/channels/onebot/init.go create mode 100644 pkg/channels/onebot/onebot.go create mode 100644 pkg/channels/qq/init.go create mode 100644 pkg/channels/qq/qq.go create mode 100644 pkg/channels/slack/init.go create mode 100644 pkg/channels/slack/slack.go create mode 100644 pkg/channels/slack/slack_test.go create mode 100644 pkg/channels/telegram/init.go create mode 100644 pkg/channels/telegram/telegram.go create mode 100644 pkg/channels/telegram/telegram_commands.go create mode 100644 pkg/channels/wecom/app.go create mode 100644 pkg/channels/wecom/app_test.go create mode 100644 pkg/channels/wecom/bot.go create mode 100644 pkg/channels/wecom/bot_test.go create mode 100644 pkg/channels/wecom/common.go create mode 100644 pkg/channels/wecom/init.go create mode 100644 pkg/channels/whatsapp/init.go create mode 100644 pkg/channels/whatsapp/whatsapp.go diff --git a/cmd/picoclaw/cmd_gateway.go b/cmd/picoclaw/cmd_gateway.go index 28ef76ad3..29b31e071 100644 --- a/cmd/picoclaw/cmd_gateway.go +++ b/cmd/picoclaw/cmd_gateway.go @@ -16,6 +16,9 @@ import ( "github.com/sipeed/picoclaw/pkg/agent" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" + dch "github.com/sipeed/picoclaw/pkg/channels/discord" + slackch "github.com/sipeed/picoclaw/pkg/channels/slack" + tgram "github.com/sipeed/picoclaw/pkg/channels/telegram" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/cron" "github.com/sipeed/picoclaw/pkg/devices" @@ -26,6 +29,16 @@ import ( "github.com/sipeed/picoclaw/pkg/state" "github.com/sipeed/picoclaw/pkg/tools" "github.com/sipeed/picoclaw/pkg/voice" + + // Channel factory registrations (blank imports trigger init()) + _ "github.com/sipeed/picoclaw/pkg/channels/dingtalk" + _ "github.com/sipeed/picoclaw/pkg/channels/feishu" + _ "github.com/sipeed/picoclaw/pkg/channels/line" + _ "github.com/sipeed/picoclaw/pkg/channels/maixcam" + _ 
"github.com/sipeed/picoclaw/pkg/channels/onebot" + _ "github.com/sipeed/picoclaw/pkg/channels/qq" + _ "github.com/sipeed/picoclaw/pkg/channels/wecom" + _ "github.com/sipeed/picoclaw/pkg/channels/whatsapp" ) func gatewayCmd() { @@ -138,19 +151,19 @@ func gatewayCmd() { if transcriber != nil { if telegramChannel, ok := channelManager.GetChannel("telegram"); ok { - if tc, ok := telegramChannel.(*channels.TelegramChannel); ok { + if tc, ok := telegramChannel.(*tgram.TelegramChannel); ok { tc.SetTranscriber(transcriber) logger.InfoC("voice", "Groq transcription attached to Telegram channel") } } if discordChannel, ok := channelManager.GetChannel("discord"); ok { - if dc, ok := discordChannel.(*channels.DiscordChannel); ok { + if dc, ok := discordChannel.(*dch.DiscordChannel); ok { dc.SetTranscriber(transcriber) logger.InfoC("voice", "Groq transcription attached to Discord channel") } } if slackChannel, ok := channelManager.GetChannel("slack"); ok { - if sc, ok := slackChannel.(*channels.SlackChannel); ok { + if sc, ok := slackChannel.(*slackch.SlackChannel); ok { sc.SetTranscriber(transcriber) logger.InfoC("voice", "Groq transcription attached to Slack channel") } diff --git a/pkg/channels/dingtalk/dingtalk.go b/pkg/channels/dingtalk/dingtalk.go new file mode 100644 index 000000000..0edb0023c --- /dev/null +++ b/pkg/channels/dingtalk/dingtalk.go @@ -0,0 +1,202 @@ +// PicoClaw - Ultra-lightweight personal AI agent +// DingTalk channel implementation using Stream Mode + +package dingtalk + +import ( + "context" + "fmt" + "sync" + + "github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot" + "github.com/open-dingtalk/dingtalk-stream-sdk-go/client" + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/utils" +) + +// DingTalkChannel implements the Channel interface for DingTalk (钉钉) +// It uses WebSocket for receiving messages 
via stream mode and API for sending +type DingTalkChannel struct { + *channels.BaseChannel + config config.DingTalkConfig + clientID string + clientSecret string + streamClient *client.StreamClient + ctx context.Context + cancel context.CancelFunc + // Map to store session webhooks for each chat + sessionWebhooks sync.Map // chatID -> sessionWebhook +} + +// NewDingTalkChannel creates a new DingTalk channel instance +func NewDingTalkChannel(cfg config.DingTalkConfig, messageBus *bus.MessageBus) (*DingTalkChannel, error) { + if cfg.ClientID == "" || cfg.ClientSecret == "" { + return nil, fmt.Errorf("dingtalk client_id and client_secret are required") + } + + base := channels.NewBaseChannel("dingtalk", cfg, messageBus, cfg.AllowFrom) + + return &DingTalkChannel{ + BaseChannel: base, + config: cfg, + clientID: cfg.ClientID, + clientSecret: cfg.ClientSecret, + }, nil +} + +// Start initializes the DingTalk channel with Stream Mode +func (c *DingTalkChannel) Start(ctx context.Context) error { + logger.InfoC("dingtalk", "Starting DingTalk channel (Stream Mode)...") + + c.ctx, c.cancel = context.WithCancel(ctx) + + // Create credential config + cred := client.NewAppCredentialConfig(c.clientID, c.clientSecret) + + // Create the stream client with options + c.streamClient = client.NewStreamClient( + client.WithAppCredential(cred), + client.WithAutoReconnect(true), + ) + + // Register chatbot callback handler (IChatBotMessageHandler is a function type) + c.streamClient.RegisterChatBotCallbackRouter(c.onChatBotMessageReceived) + + // Start the stream client + if err := c.streamClient.Start(c.ctx); err != nil { + return fmt.Errorf("failed to start stream client: %w", err) + } + + c.SetRunning(true) + logger.InfoC("dingtalk", "DingTalk channel started (Stream Mode)") + return nil +} + +// Stop gracefully stops the DingTalk channel +func (c *DingTalkChannel) Stop(ctx context.Context) error { + logger.InfoC("dingtalk", "Stopping DingTalk channel...") + + if c.cancel != nil { + 
c.cancel() + } + + if c.streamClient != nil { + c.streamClient.Close() + } + + c.SetRunning(false) + logger.InfoC("dingtalk", "DingTalk channel stopped") + return nil +} + +// Send sends a message to DingTalk via the chatbot reply API +func (c *DingTalkChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + if !c.IsRunning() { + return fmt.Errorf("dingtalk channel not running") + } + + // Get session webhook from storage + sessionWebhookRaw, ok := c.sessionWebhooks.Load(msg.ChatID) + if !ok { + return fmt.Errorf("no session_webhook found for chat %s, cannot send message", msg.ChatID) + } + + sessionWebhook, ok := sessionWebhookRaw.(string) + if !ok { + return fmt.Errorf("invalid session_webhook type for chat %s", msg.ChatID) + } + + logger.DebugCF("dingtalk", "Sending message", map[string]interface{}{ + "chat_id": msg.ChatID, + "preview": utils.Truncate(msg.Content, 100), + }) + + // Use the session webhook to send the reply + return c.SendDirectReply(ctx, sessionWebhook, msg.Content) +} + +// onChatBotMessageReceived implements the IChatBotMessageHandler function signature +// This is called by the Stream SDK when a new message arrives +// IChatBotMessageHandler is: func(c context.Context, data *chatbot.BotCallbackDataModel) ([]byte, error) +func (c *DingTalkChannel) onChatBotMessageReceived(ctx context.Context, data *chatbot.BotCallbackDataModel) ([]byte, error) { + // Extract message content from Text field + content := data.Text.Content + if content == "" { + // Try to extract from Content interface{} if Text is empty + if contentMap, ok := data.Content.(map[string]interface{}); ok { + if textContent, ok := contentMap["content"].(string); ok { + content = textContent + } + } + } + + if content == "" { + return nil, nil // Ignore empty messages + } + + senderID := data.SenderStaffId + senderNick := data.SenderNick + chatID := senderID + if data.ConversationType != "1" { + // For group chats + chatID = data.ConversationId + } + + // Store the 
session webhook for this chat so we can reply later + c.sessionWebhooks.Store(chatID, data.SessionWebhook) + + metadata := map[string]string{ + "sender_name": senderNick, + "conversation_id": data.ConversationId, + "conversation_type": data.ConversationType, + "platform": "dingtalk", + "session_webhook": data.SessionWebhook, + } + + if data.ConversationType == "1" { + metadata["peer_kind"] = "direct" + metadata["peer_id"] = senderID + } else { + metadata["peer_kind"] = "group" + metadata["peer_id"] = data.ConversationId + } + + logger.DebugCF("dingtalk", "Received message", map[string]interface{}{ + "sender_nick": senderNick, + "sender_id": senderID, + "preview": utils.Truncate(content, 50), + }) + + // Handle the message through the base channel + c.HandleMessage(senderID, chatID, content, nil, metadata) + + // Return nil to indicate we've handled the message asynchronously + // The response will be sent through the message bus + return nil, nil +} + +// SendDirectReply sends a direct reply using the session webhook +func (c *DingTalkChannel) SendDirectReply(ctx context.Context, sessionWebhook, content string) error { + replier := chatbot.NewChatbotReplier() + + // Convert string content to []byte for the API + contentBytes := []byte(content) + titleBytes := []byte("PicoClaw") + + // Send markdown formatted reply + err := replier.SimpleReplyMarkdown( + ctx, + sessionWebhook, + titleBytes, + contentBytes, + ) + + if err != nil { + return fmt.Errorf("failed to send reply: %w", err) + } + + return nil +} diff --git a/pkg/channels/dingtalk/init.go b/pkg/channels/dingtalk/init.go new file mode 100644 index 000000000..5f49bce8c --- /dev/null +++ b/pkg/channels/dingtalk/init.go @@ -0,0 +1,13 @@ +package dingtalk + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("dingtalk", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, 
error) { + return NewDingTalkChannel(cfg.Channels.DingTalk, b) + }) +} diff --git a/pkg/channels/discord/discord.go b/pkg/channels/discord/discord.go new file mode 100644 index 000000000..6c4efd87c --- /dev/null +++ b/pkg/channels/discord/discord.go @@ -0,0 +1,373 @@ +package discord + +import ( + "context" + "fmt" + "os" + "strings" + "sync" + "time" + + "github.com/bwmarrin/discordgo" + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/utils" + "github.com/sipeed/picoclaw/pkg/voice" +) + +const ( + transcriptionTimeout = 30 * time.Second + sendTimeout = 10 * time.Second +) + +type DiscordChannel struct { + *channels.BaseChannel + session *discordgo.Session + config config.DiscordConfig + transcriber *voice.GroqTranscriber + ctx context.Context + typingMu sync.Mutex + typingStop map[string]chan struct{} // chatID → stop signal + botUserID string // stored for mention checking +} + +func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordChannel, error) { + session, err := discordgo.New("Bot " + cfg.Token) + if err != nil { + return nil, fmt.Errorf("failed to create discord session: %w", err) + } + + base := channels.NewBaseChannel("discord", cfg, bus, cfg.AllowFrom) + + return &DiscordChannel{ + BaseChannel: base, + session: session, + config: cfg, + transcriber: nil, + ctx: context.Background(), + typingStop: make(map[string]chan struct{}), + }, nil +} + +func (c *DiscordChannel) SetTranscriber(transcriber *voice.GroqTranscriber) { + c.transcriber = transcriber +} + +func (c *DiscordChannel) getContext() context.Context { + if c.ctx == nil { + return context.Background() + } + return c.ctx +} + +func (c *DiscordChannel) Start(ctx context.Context) error { + logger.InfoC("discord", "Starting Discord bot") + + c.ctx = ctx + + // Get bot user ID before opening session to avoid race condition + 
botUser, err := c.session.User("@me") + if err != nil { + return fmt.Errorf("failed to get bot user: %w", err) + } + c.botUserID = botUser.ID + + c.session.AddHandler(c.handleMessage) + + if err := c.session.Open(); err != nil { + return fmt.Errorf("failed to open discord session: %w", err) + } + + c.SetRunning(true) + + logger.InfoCF("discord", "Discord bot connected", map[string]any{ + "username": botUser.Username, + "user_id": botUser.ID, + }) + + return nil +} + +func (c *DiscordChannel) Stop(ctx context.Context) error { + logger.InfoC("discord", "Stopping Discord bot") + c.SetRunning(false) + + // Stop all typing goroutines before closing session + c.typingMu.Lock() + for chatID, stop := range c.typingStop { + close(stop) + delete(c.typingStop, chatID) + } + c.typingMu.Unlock() + + if err := c.session.Close(); err != nil { + return fmt.Errorf("failed to close discord session: %w", err) + } + + return nil +} + +func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + c.stopTyping(msg.ChatID) + + if !c.IsRunning() { + return fmt.Errorf("discord bot not running") + } + + channelID := msg.ChatID + if channelID == "" { + return fmt.Errorf("channel ID is empty") + } + + runes := []rune(msg.Content) + if len(runes) == 0 { + return nil + } + + chunks := utils.SplitMessage(msg.Content, 2000) // Split messages into chunks, Discord length limit: 2000 chars + + for _, chunk := range chunks { + if err := c.sendChunk(ctx, channelID, chunk); err != nil { + return err + } + } + + return nil +} + +func (c *DiscordChannel) sendChunk(ctx context.Context, channelID, content string) error { + // Use the passed ctx for timeout control + sendCtx, cancel := context.WithTimeout(ctx, sendTimeout) + defer cancel() + + done := make(chan error, 1) + go func() { + _, err := c.session.ChannelMessageSend(channelID, content) + done <- err + }() + + select { + case err := <-done: + if err != nil { + return fmt.Errorf("failed to send discord message: %w", err) + } + 
return nil + case <-sendCtx.Done(): + return fmt.Errorf("send message timeout: %w", sendCtx.Err()) + } +} + +// appendContent safely appends content to existing text +func appendContent(content, suffix string) string { + if content == "" { + return suffix + } + return content + "\n" + suffix +} + +func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.MessageCreate) { + if m == nil || m.Author == nil { + return + } + + if m.Author.ID == s.State.User.ID { + return + } + + // Check allowlist first to avoid downloading attachments and transcribing for rejected users + if !c.IsAllowed(m.Author.ID) { + logger.DebugCF("discord", "Message rejected by allowlist", map[string]any{ + "user_id": m.Author.ID, + }) + return + } + + // If configured to only respond to mentions, check if bot is mentioned + // Skip this check for DMs (GuildID is empty) - DMs should always be responded to + if c.config.MentionOnly && m.GuildID != "" { + isMentioned := false + for _, mention := range m.Mentions { + if mention.ID == c.botUserID { + isMentioned = true + break + } + } + if !isMentioned { + logger.DebugCF("discord", "Message ignored - bot not mentioned", map[string]any{ + "user_id": m.Author.ID, + }) + return + } + } + + senderID := m.Author.ID + senderName := m.Author.Username + if m.Author.Discriminator != "" && m.Author.Discriminator != "0" { + senderName += "#" + m.Author.Discriminator + } + + content := m.Content + content = c.stripBotMention(content) + mediaPaths := make([]string, 0, len(m.Attachments)) + localFiles := make([]string, 0, len(m.Attachments)) + + // Ensure temp files are cleaned up when function returns + defer func() { + for _, file := range localFiles { + if err := os.Remove(file); err != nil { + logger.DebugCF("discord", "Failed to cleanup temp file", map[string]any{ + "file": file, + "error": err.Error(), + }) + } + } + }() + + for _, attachment := range m.Attachments { + isAudio := utils.IsAudioFile(attachment.Filename, attachment.ContentType) 
+ + if isAudio { + localPath := c.downloadAttachment(attachment.URL, attachment.Filename) + if localPath != "" { + localFiles = append(localFiles, localPath) + + transcribedText := "" + if c.transcriber != nil && c.transcriber.IsAvailable() { + ctx, cancel := context.WithTimeout(c.getContext(), transcriptionTimeout) + result, err := c.transcriber.Transcribe(ctx, localPath) + cancel() // Release context resources immediately to avoid leaks in for loop + + if err != nil { + logger.ErrorCF("discord", "Voice transcription failed", map[string]any{ + "error": err.Error(), + }) + transcribedText = fmt.Sprintf("[audio: %s (transcription failed)]", attachment.Filename) + } else { + transcribedText = fmt.Sprintf("[audio transcription: %s]", result.Text) + logger.DebugCF("discord", "Audio transcribed successfully", map[string]any{ + "text": result.Text, + }) + } + } else { + transcribedText = fmt.Sprintf("[audio: %s]", attachment.Filename) + } + + content = appendContent(content, transcribedText) + } else { + logger.WarnCF("discord", "Failed to download audio attachment", map[string]any{ + "url": attachment.URL, + "filename": attachment.Filename, + }) + mediaPaths = append(mediaPaths, attachment.URL) + content = appendContent(content, fmt.Sprintf("[attachment: %s]", attachment.URL)) + } + } else { + mediaPaths = append(mediaPaths, attachment.URL) + content = appendContent(content, fmt.Sprintf("[attachment: %s]", attachment.URL)) + } + } + + if content == "" && len(mediaPaths) == 0 { + return + } + + if content == "" { + content = "[media only]" + } + + // Start typing after all early returns — guaranteed to have a matching Send() + c.startTyping(m.ChannelID) + + logger.DebugCF("discord", "Received message", map[string]any{ + "sender_name": senderName, + "sender_id": senderID, + "preview": utils.Truncate(content, 50), + }) + + peerKind := "channel" + peerID := m.ChannelID + if m.GuildID == "" { + peerKind = "direct" + peerID = senderID + } + + metadata := map[string]string{ + 
"message_id": m.ID, + "user_id": senderID, + "username": m.Author.Username, + "display_name": senderName, + "guild_id": m.GuildID, + "channel_id": m.ChannelID, + "is_dm": fmt.Sprintf("%t", m.GuildID == ""), + "peer_kind": peerKind, + "peer_id": peerID, + } + + c.HandleMessage(senderID, m.ChannelID, content, mediaPaths, metadata) +} + +// startTyping starts a continuous typing indicator loop for the given chatID. +// It stops any existing typing loop for that chatID before starting a new one. +func (c *DiscordChannel) startTyping(chatID string) { + c.typingMu.Lock() + // Stop existing loop for this chatID if any + if stop, ok := c.typingStop[chatID]; ok { + close(stop) + } + stop := make(chan struct{}) + c.typingStop[chatID] = stop + c.typingMu.Unlock() + + go func() { + if err := c.session.ChannelTyping(chatID); err != nil { + logger.DebugCF("discord", "ChannelTyping error", map[string]interface{}{"chatID": chatID, "err": err}) + } + ticker := time.NewTicker(8 * time.Second) + defer ticker.Stop() + timeout := time.After(5 * time.Minute) + for { + select { + case <-stop: + return + case <-timeout: + return + case <-c.ctx.Done(): + return + case <-ticker.C: + if err := c.session.ChannelTyping(chatID); err != nil { + logger.DebugCF("discord", "ChannelTyping error", map[string]interface{}{"chatID": chatID, "err": err}) + } + } + } + }() +} + +// stopTyping stops the typing indicator loop for the given chatID. +func (c *DiscordChannel) stopTyping(chatID string) { + c.typingMu.Lock() + defer c.typingMu.Unlock() + if stop, ok := c.typingStop[chatID]; ok { + close(stop) + delete(c.typingStop, chatID) + } +} + +func (c *DiscordChannel) downloadAttachment(url, filename string) string { + return utils.DownloadFile(url, filename, utils.DownloadOptions{ + LoggerPrefix: "discord", + }) +} + +// stripBotMention removes the bot mention from the message content. +// Discord mentions have the format <@USER_ID> or <@!USER_ID> (with nickname). 
+func (c *DiscordChannel) stripBotMention(text string) string { + if c.botUserID == "" { + return text + } + // Remove both regular mention <@USER_ID> and nickname mention <@!USER_ID> + text = strings.ReplaceAll(text, fmt.Sprintf("<@%s>", c.botUserID), "") + text = strings.ReplaceAll(text, fmt.Sprintf("<@!%s>", c.botUserID), "") + return strings.TrimSpace(text) +} diff --git a/pkg/channels/discord/init.go b/pkg/channels/discord/init.go new file mode 100644 index 000000000..15a539804 --- /dev/null +++ b/pkg/channels/discord/init.go @@ -0,0 +1,13 @@ +package discord + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("discord", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewDiscordChannel(cfg.Channels.Discord, b) + }) +} diff --git a/pkg/channels/feishu/common.go b/pkg/channels/feishu/common.go new file mode 100644 index 000000000..e8a057741 --- /dev/null +++ b/pkg/channels/feishu/common.go @@ -0,0 +1,9 @@ +package feishu + +// stringValue safely dereferences a *string pointer. 
+func stringValue(v *string) string { + if v == nil { + return "" + } + return *v +} diff --git a/pkg/channels/feishu/feishu_32.go b/pkg/channels/feishu/feishu_32.go new file mode 100644 index 000000000..14711e49e --- /dev/null +++ b/pkg/channels/feishu/feishu_32.go @@ -0,0 +1,37 @@ +//go:build !amd64 && !arm64 && !riscv64 && !mips64 && !ppc64 + +package feishu + +import ( + "context" + "errors" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +// FeishuChannel is a stub implementation for 32-bit architectures +type FeishuChannel struct { + *channels.BaseChannel +} + +// NewFeishuChannel returns an error on 32-bit architectures where the Feishu SDK is not supported +func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChannel, error) { + return nil, errors.New("feishu channel is not supported on 32-bit architectures (armv7l, 386, etc.). Please use a 64-bit system or disable feishu in your config") +} + +// Start is a stub method to satisfy the Channel interface +func (c *FeishuChannel) Start(ctx context.Context) error { + return nil +} + +// Stop is a stub method to satisfy the Channel interface +func (c *FeishuChannel) Stop(ctx context.Context) error { + return nil +} + +// Send is a stub method to satisfy the Channel interface +func (c *FeishuChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + return errors.New("feishu channel is not supported on 32-bit architectures") +} diff --git a/pkg/channels/feishu/feishu_64.go b/pkg/channels/feishu/feishu_64.go new file mode 100644 index 000000000..a49ee34cb --- /dev/null +++ b/pkg/channels/feishu/feishu_64.go @@ -0,0 +1,221 @@ +//go:build amd64 || arm64 || riscv64 || mips64 || ppc64 + +package feishu + +import ( + "context" + "encoding/json" + "fmt" + "sync" + "time" + + lark "github.com/larksuite/oapi-sdk-go/v3" + larkdispatcher "github.com/larksuite/oapi-sdk-go/v3/event/dispatcher" + larkim 
"github.com/larksuite/oapi-sdk-go/v3/service/im/v1" + larkws "github.com/larksuite/oapi-sdk-go/v3/ws" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/utils" +) + +type FeishuChannel struct { + *channels.BaseChannel + config config.FeishuConfig + client *lark.Client + wsClient *larkws.Client + + mu sync.Mutex + cancel context.CancelFunc +} + +func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChannel, error) { + base := channels.NewBaseChannel("feishu", cfg, bus, cfg.AllowFrom) + + return &FeishuChannel{ + BaseChannel: base, + config: cfg, + client: lark.NewClient(cfg.AppID, cfg.AppSecret), + }, nil +} + +func (c *FeishuChannel) Start(ctx context.Context) error { + if c.config.AppID == "" || c.config.AppSecret == "" { + return fmt.Errorf("feishu app_id or app_secret is empty") + } + + dispatcher := larkdispatcher.NewEventDispatcher(c.config.VerificationToken, c.config.EncryptKey). 
+ OnP2MessageReceiveV1(c.handleMessageReceive) + + runCtx, cancel := context.WithCancel(ctx) + + c.mu.Lock() + c.cancel = cancel + c.wsClient = larkws.NewClient( + c.config.AppID, + c.config.AppSecret, + larkws.WithEventHandler(dispatcher), + ) + wsClient := c.wsClient + c.mu.Unlock() + + c.SetRunning(true) + logger.InfoC("feishu", "Feishu channel started (websocket mode)") + + go func() { + if err := wsClient.Start(runCtx); err != nil { + logger.ErrorCF("feishu", "Feishu websocket stopped with error", map[string]interface{}{ + "error": err.Error(), + }) + } + }() + + return nil +} + +func (c *FeishuChannel) Stop(ctx context.Context) error { + c.mu.Lock() + if c.cancel != nil { + c.cancel() + c.cancel = nil + } + c.wsClient = nil + c.mu.Unlock() + + c.SetRunning(false) + logger.InfoC("feishu", "Feishu channel stopped") + return nil +} + +func (c *FeishuChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + if !c.IsRunning() { + return fmt.Errorf("feishu channel not running") + } + + if msg.ChatID == "" { + return fmt.Errorf("chat ID is empty") + } + + payload, err := json.Marshal(map[string]string{"text": msg.Content}) + if err != nil { + return fmt.Errorf("failed to marshal feishu content: %w", err) + } + + req := larkim.NewCreateMessageReqBuilder(). + ReceiveIdType(larkim.ReceiveIdTypeChatId). + Body(larkim.NewCreateMessageReqBodyBuilder(). + ReceiveId(msg.ChatID). + MsgType(larkim.MsgTypeText). + Content(string(payload)). + Uuid(fmt.Sprintf("picoclaw-%d", time.Now().UnixNano())). + Build()). 
+ Build() + + resp, err := c.client.Im.V1.Message.Create(ctx, req) + if err != nil { + return fmt.Errorf("failed to send feishu message: %w", err) + } + + if !resp.Success() { + return fmt.Errorf("feishu api error: code=%d msg=%s", resp.Code, resp.Msg) + } + + logger.DebugCF("feishu", "Feishu message sent", map[string]interface{}{ + "chat_id": msg.ChatID, + }) + + return nil +} + +func (c *FeishuChannel) handleMessageReceive(_ context.Context, event *larkim.P2MessageReceiveV1) error { + if event == nil || event.Event == nil || event.Event.Message == nil { + return nil + } + + message := event.Event.Message + sender := event.Event.Sender + + chatID := stringValue(message.ChatId) + if chatID == "" { + return nil + } + + senderID := extractFeishuSenderID(sender) + if senderID == "" { + senderID = "unknown" + } + + content := extractFeishuMessageContent(message) + if content == "" { + content = "[empty message]" + } + + metadata := map[string]string{} + if messageID := stringValue(message.MessageId); messageID != "" { + metadata["message_id"] = messageID + } + if messageType := stringValue(message.MessageType); messageType != "" { + metadata["message_type"] = messageType + } + if chatType := stringValue(message.ChatType); chatType != "" { + metadata["chat_type"] = chatType + } + if sender != nil && sender.TenantKey != nil { + metadata["tenant_key"] = *sender.TenantKey + } + + chatType := stringValue(message.ChatType) + if chatType == "p2p" { + metadata["peer_kind"] = "direct" + metadata["peer_id"] = senderID + } else { + metadata["peer_kind"] = "group" + metadata["peer_id"] = chatID + } + + logger.InfoCF("feishu", "Feishu message received", map[string]interface{}{ + "sender_id": senderID, + "chat_id": chatID, + "preview": utils.Truncate(content, 80), + }) + + c.HandleMessage(senderID, chatID, content, nil, metadata) + return nil +} + +func extractFeishuSenderID(sender *larkim.EventSender) string { + if sender == nil || sender.SenderId == nil { + return "" + } + + if 
sender.SenderId.UserId != nil && *sender.SenderId.UserId != "" { + return *sender.SenderId.UserId + } + if sender.SenderId.OpenId != nil && *sender.SenderId.OpenId != "" { + return *sender.SenderId.OpenId + } + if sender.SenderId.UnionId != nil && *sender.SenderId.UnionId != "" { + return *sender.SenderId.UnionId + } + + return "" +} + +func extractFeishuMessageContent(message *larkim.EventMessage) string { + if message == nil || message.Content == nil || *message.Content == "" { + return "" + } + + if message.MessageType != nil && *message.MessageType == larkim.MsgTypeText { + var textPayload struct { + Text string `json:"text"` + } + if err := json.Unmarshal([]byte(*message.Content), &textPayload); err == nil { + return textPayload.Text + } + } + + return *message.Content +} diff --git a/pkg/channels/feishu/init.go b/pkg/channels/feishu/init.go new file mode 100644 index 000000000..7e5a62dae --- /dev/null +++ b/pkg/channels/feishu/init.go @@ -0,0 +1,13 @@ +package feishu + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("feishu", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewFeishuChannel(cfg.Channels.Feishu, b) + }) +} diff --git a/pkg/channels/line/init.go b/pkg/channels/line/init.go new file mode 100644 index 000000000..9265575cc --- /dev/null +++ b/pkg/channels/line/init.go @@ -0,0 +1,13 @@ +package line + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("line", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewLINEChannel(cfg.Channels.LINE, b) + }) +} diff --git a/pkg/channels/line/line.go b/pkg/channels/line/line.go new file mode 100644 index 000000000..7df0491d9 --- /dev/null +++ b/pkg/channels/line/line.go @@ -0,0 +1,607 
@@ +package line + +import ( + "bytes" + "context" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "strings" + "sync" + "time" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/utils" +) + +const ( + lineAPIBase = "https://api.line.me/v2/bot" + lineDataAPIBase = "https://api-data.line.me/v2/bot" + lineReplyEndpoint = lineAPIBase + "/message/reply" + linePushEndpoint = lineAPIBase + "/message/push" + lineContentEndpoint = lineDataAPIBase + "/message/%s/content" + lineBotInfoEndpoint = lineAPIBase + "/info" + lineLoadingEndpoint = lineAPIBase + "/chat/loading/start" + lineReplyTokenMaxAge = 25 * time.Second +) + +type replyTokenEntry struct { + token string + timestamp time.Time +} + +// LINEChannel implements the Channel interface for LINE Official Account +// using the LINE Messaging API with HTTP webhook for receiving messages +// and REST API for sending messages. +type LINEChannel struct { + *channels.BaseChannel + config config.LINEConfig + httpServer *http.Server + botUserID string // Bot's user ID + botBasicID string // Bot's basic ID (e.g. @216ru...) + botDisplayName string // Bot's display name for text-based mention detection + replyTokens sync.Map // chatID -> replyTokenEntry + quoteTokens sync.Map // chatID -> quoteToken (string) + ctx context.Context + cancel context.CancelFunc +} + +// NewLINEChannel creates a new LINE channel instance. 
+func NewLINEChannel(cfg config.LINEConfig, messageBus *bus.MessageBus) (*LINEChannel, error) { + if cfg.ChannelSecret == "" || cfg.ChannelAccessToken == "" { + return nil, fmt.Errorf("line channel_secret and channel_access_token are required") + } + + base := channels.NewBaseChannel("line", cfg, messageBus, cfg.AllowFrom) + + return &LINEChannel{ + BaseChannel: base, + config: cfg, + }, nil +} + +// Start launches the HTTP webhook server. +func (c *LINEChannel) Start(ctx context.Context) error { + logger.InfoC("line", "Starting LINE channel (Webhook Mode)") + + c.ctx, c.cancel = context.WithCancel(ctx) + + // Fetch bot profile to get bot's userId for mention detection + if err := c.fetchBotInfo(); err != nil { + logger.WarnCF("line", "Failed to fetch bot info (mention detection disabled)", map[string]interface{}{ + "error": err.Error(), + }) + } else { + logger.InfoCF("line", "Bot info fetched", map[string]interface{}{ + "bot_user_id": c.botUserID, + "basic_id": c.botBasicID, + "display_name": c.botDisplayName, + }) + } + + mux := http.NewServeMux() + path := c.config.WebhookPath + if path == "" { + path = "/webhook/line" + } + mux.HandleFunc(path, c.webhookHandler) + + addr := fmt.Sprintf("%s:%d", c.config.WebhookHost, c.config.WebhookPort) + c.httpServer = &http.Server{ + Addr: addr, + Handler: mux, + } + + go func() { + logger.InfoCF("line", "LINE webhook server listening", map[string]interface{}{ + "addr": addr, + "path": path, + }) + if err := c.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { + logger.ErrorCF("line", "Webhook server error", map[string]interface{}{ + "error": err.Error(), + }) + } + }() + + c.SetRunning(true) + logger.InfoC("line", "LINE channel started (Webhook Mode)") + return nil +} + +// fetchBotInfo retrieves the bot's userId, basicId, and displayName from the LINE API. 
+func (c *LINEChannel) fetchBotInfo() error { + req, err := http.NewRequest(http.MethodGet, lineBotInfoEndpoint, nil) + if err != nil { + return err + } + req.Header.Set("Authorization", "Bearer "+c.config.ChannelAccessToken) + + client := &http.Client{Timeout: 10 * time.Second} + resp, err := client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("bot info API returned status %d", resp.StatusCode) + } + + var info struct { + UserID string `json:"userId"` + BasicID string `json:"basicId"` + DisplayName string `json:"displayName"` + } + if err := json.NewDecoder(resp.Body).Decode(&info); err != nil { + return err + } + + c.botUserID = info.UserID + c.botBasicID = info.BasicID + c.botDisplayName = info.DisplayName + return nil +} + +// Stop gracefully shuts down the HTTP server. +func (c *LINEChannel) Stop(ctx context.Context) error { + logger.InfoC("line", "Stopping LINE channel") + + if c.cancel != nil { + c.cancel() + } + + if c.httpServer != nil { + shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + if err := c.httpServer.Shutdown(shutdownCtx); err != nil { + logger.ErrorCF("line", "Webhook server shutdown error", map[string]interface{}{ + "error": err.Error(), + }) + } + } + + c.SetRunning(false) + logger.InfoC("line", "LINE channel stopped") + return nil +} + +// webhookHandler handles incoming LINE webhook requests. 
+func (c *LINEChannel) webhookHandler(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) + return + } + + body, err := io.ReadAll(r.Body) + if err != nil { + logger.ErrorCF("line", "Failed to read request body", map[string]interface{}{ + "error": err.Error(), + }) + http.Error(w, "Bad request", http.StatusBadRequest) + return + } + + signature := r.Header.Get("X-Line-Signature") + if !c.verifySignature(body, signature) { + logger.WarnC("line", "Invalid webhook signature") + http.Error(w, "Forbidden", http.StatusForbidden) + return + } + + var payload struct { + Events []lineEvent `json:"events"` + } + if err := json.Unmarshal(body, &payload); err != nil { + logger.ErrorCF("line", "Failed to parse webhook payload", map[string]interface{}{ + "error": err.Error(), + }) + http.Error(w, "Bad request", http.StatusBadRequest) + return + } + + // Return 200 immediately, process events asynchronously + w.WriteHeader(http.StatusOK) + + for _, event := range payload.Events { + go c.processEvent(event) + } +} + +// verifySignature validates the X-Line-Signature using HMAC-SHA256. 
+func (c *LINEChannel) verifySignature(body []byte, signature string) bool { + if signature == "" { + return false + } + + mac := hmac.New(sha256.New, []byte(c.config.ChannelSecret)) + mac.Write(body) + expected := base64.StdEncoding.EncodeToString(mac.Sum(nil)) + + return hmac.Equal([]byte(expected), []byte(signature)) +} + +// LINE webhook event types +type lineEvent struct { + Type string `json:"type"` + ReplyToken string `json:"replyToken"` + Source lineSource `json:"source"` + Message json.RawMessage `json:"message"` + Timestamp int64 `json:"timestamp"` +} + +type lineSource struct { + Type string `json:"type"` // "user", "group", "room" + UserID string `json:"userId"` + GroupID string `json:"groupId"` + RoomID string `json:"roomId"` +} + +type lineMessage struct { + ID string `json:"id"` + Type string `json:"type"` // "text", "image", "video", "audio", "file", "sticker" + Text string `json:"text"` + QuoteToken string `json:"quoteToken"` + Mention *struct { + Mentionees []lineMentionee `json:"mentionees"` + } `json:"mention"` + ContentProvider struct { + Type string `json:"type"` + } `json:"contentProvider"` +} + +type lineMentionee struct { + Index int `json:"index"` + Length int `json:"length"` + Type string `json:"type"` // "user", "all" + UserID string `json:"userId"` +} + +func (c *LINEChannel) processEvent(event lineEvent) { + if event.Type != "message" { + logger.DebugCF("line", "Ignoring non-message event", map[string]interface{}{ + "type": event.Type, + }) + return + } + + senderID := event.Source.UserID + chatID := c.resolveChatID(event.Source) + isGroup := event.Source.Type == "group" || event.Source.Type == "room" + + var msg lineMessage + if err := json.Unmarshal(event.Message, &msg); err != nil { + logger.ErrorCF("line", "Failed to parse message", map[string]interface{}{ + "error": err.Error(), + }) + return + } + + // In group chats, only respond when the bot is mentioned + if isGroup && !c.isBotMentioned(msg) { + logger.DebugCF("line", 
"Ignoring group message without mention", map[string]interface{}{ + "chat_id": chatID, + }) + return + } + + // Store reply token for later use + if event.ReplyToken != "" { + c.replyTokens.Store(chatID, replyTokenEntry{ + token: event.ReplyToken, + timestamp: time.Now(), + }) + } + + // Store quote token for quoting the original message in reply + if msg.QuoteToken != "" { + c.quoteTokens.Store(chatID, msg.QuoteToken) + } + + var content string + var mediaPaths []string + localFiles := []string{} + + defer func() { + for _, file := range localFiles { + if err := os.Remove(file); err != nil { + logger.DebugCF("line", "Failed to cleanup temp file", map[string]interface{}{ + "file": file, + "error": err.Error(), + }) + } + } + }() + + switch msg.Type { + case "text": + content = msg.Text + // Strip bot mention from text in group chats + if isGroup { + content = c.stripBotMention(content, msg) + } + case "image": + localPath := c.downloadContent(msg.ID, "image.jpg") + if localPath != "" { + localFiles = append(localFiles, localPath) + mediaPaths = append(mediaPaths, localPath) + content = "[image]" + } + case "audio": + localPath := c.downloadContent(msg.ID, "audio.m4a") + if localPath != "" { + localFiles = append(localFiles, localPath) + mediaPaths = append(mediaPaths, localPath) + content = "[audio]" + } + case "video": + localPath := c.downloadContent(msg.ID, "video.mp4") + if localPath != "" { + localFiles = append(localFiles, localPath) + mediaPaths = append(mediaPaths, localPath) + content = "[video]" + } + case "file": + content = "[file]" + case "sticker": + content = "[sticker]" + default: + content = fmt.Sprintf("[%s]", msg.Type) + } + + if strings.TrimSpace(content) == "" { + return + } + + metadata := map[string]string{ + "platform": "line", + "source_type": event.Source.Type, + "message_id": msg.ID, + } + + if isGroup { + metadata["peer_kind"] = "group" + metadata["peer_id"] = chatID + } else { + metadata["peer_kind"] = "direct" + metadata["peer_id"] = 
senderID + } + + logger.DebugCF("line", "Received message", map[string]interface{}{ + "sender_id": senderID, + "chat_id": chatID, + "message_type": msg.Type, + "is_group": isGroup, + "preview": utils.Truncate(content, 50), + }) + + // Show typing/loading indicator (requires user ID, not group ID) + c.sendLoading(senderID) + + c.HandleMessage(senderID, chatID, content, mediaPaths, metadata) +} + +// isBotMentioned checks if the bot is mentioned in the message. +// It first checks the mention metadata (userId match), then falls back +// to text-based detection using the bot's display name, since LINE may +// not include userId in mentionees for Official Accounts. +func (c *LINEChannel) isBotMentioned(msg lineMessage) bool { + // Check mention metadata + if msg.Mention != nil { + for _, m := range msg.Mention.Mentionees { + if m.Type == "all" { + return true + } + if c.botUserID != "" && m.UserID == c.botUserID { + return true + } + } + // Mention metadata exists with mentionees but bot not matched by userId. + // The bot IS likely mentioned (LINE includes mention struct when bot is @-ed), + // so check if any mentionee overlaps with bot display name in text. + if c.botDisplayName != "" { + for _, m := range msg.Mention.Mentionees { + if m.Index >= 0 && m.Length > 0 { + runes := []rune(msg.Text) + end := m.Index + m.Length + if end <= len(runes) { + mentionText := string(runes[m.Index:end]) + if strings.Contains(mentionText, c.botDisplayName) { + return true + } + } + } + } + } + } + + // Fallback: text-based detection with display name + if c.botDisplayName != "" && strings.Contains(msg.Text, "@"+c.botDisplayName) { + return true + } + + return false +} + +// stripBotMention removes the @BotName mention text from the message. 
+func (c *LINEChannel) stripBotMention(text string, msg lineMessage) string { + stripped := false + + // Try to strip using mention metadata indices + if msg.Mention != nil { + runes := []rune(text) + for i := len(msg.Mention.Mentionees) - 1; i >= 0; i-- { + m := msg.Mention.Mentionees[i] + // Strip if userId matches OR if the mention text contains the bot display name + shouldStrip := false + if c.botUserID != "" && m.UserID == c.botUserID { + shouldStrip = true + } else if c.botDisplayName != "" && m.Index >= 0 && m.Length > 0 { + end := m.Index + m.Length + if end <= len(runes) { + mentionText := string(runes[m.Index:end]) + if strings.Contains(mentionText, c.botDisplayName) { + shouldStrip = true + } + } + } + if shouldStrip { + start := m.Index + end := m.Index + m.Length + if start >= 0 && end <= len(runes) { + runes = append(runes[:start], runes[end:]...) + stripped = true + } + } + } + if stripped { + return strings.TrimSpace(string(runes)) + } + } + + // Fallback: strip @DisplayName from text + if c.botDisplayName != "" { + text = strings.ReplaceAll(text, "@"+c.botDisplayName, "") + } + + return strings.TrimSpace(text) +} + +// resolveChatID determines the chat ID from the event source. +// For group/room messages, use the group/room ID; for 1:1, use the user ID. +func (c *LINEChannel) resolveChatID(source lineSource) string { + switch source.Type { + case "group": + return source.GroupID + case "room": + return source.RoomID + default: + return source.UserID + } +} + +// Send sends a message to LINE. It first tries the Reply API (free) +// using a cached reply token, then falls back to the Push API. 
+func (c *LINEChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + if !c.IsRunning() { + return fmt.Errorf("line channel not running") + } + + // Load and consume quote token for this chat + var quoteToken string + if qt, ok := c.quoteTokens.LoadAndDelete(msg.ChatID); ok { + quoteToken = qt.(string) + } + + // Try reply token first (free, valid for ~25 seconds) + if entry, ok := c.replyTokens.LoadAndDelete(msg.ChatID); ok { + tokenEntry := entry.(replyTokenEntry) + if time.Since(tokenEntry.timestamp) < lineReplyTokenMaxAge { + if err := c.sendReply(ctx, tokenEntry.token, msg.Content, quoteToken); err == nil { + logger.DebugCF("line", "Message sent via Reply API", map[string]interface{}{ + "chat_id": msg.ChatID, + "quoted": quoteToken != "", + }) + return nil + } + logger.DebugC("line", "Reply API failed, falling back to Push API") + } + } + + // Fall back to Push API + return c.sendPush(ctx, msg.ChatID, msg.Content, quoteToken) +} + +// buildTextMessage creates a text message object, optionally with quoteToken. +func buildTextMessage(content, quoteToken string) map[string]string { + msg := map[string]string{ + "type": "text", + "text": content, + } + if quoteToken != "" { + msg["quoteToken"] = quoteToken + } + return msg +} + +// sendReply sends a message using the LINE Reply API. +func (c *LINEChannel) sendReply(ctx context.Context, replyToken, content, quoteToken string) error { + payload := map[string]interface{}{ + "replyToken": replyToken, + "messages": []map[string]string{buildTextMessage(content, quoteToken)}, + } + + return c.callAPI(ctx, lineReplyEndpoint, payload) +} + +// sendPush sends a message using the LINE Push API. 
+func (c *LINEChannel) sendPush(ctx context.Context, to, content, quoteToken string) error { + payload := map[string]interface{}{ + "to": to, + "messages": []map[string]string{buildTextMessage(content, quoteToken)}, + } + + return c.callAPI(ctx, linePushEndpoint, payload) +} + +// sendLoading sends a loading animation indicator to the chat. +func (c *LINEChannel) sendLoading(chatID string) { + payload := map[string]interface{}{ + "chatId": chatID, + "loadingSeconds": 60, + } + if err := c.callAPI(c.ctx, lineLoadingEndpoint, payload); err != nil { + logger.DebugCF("line", "Failed to send loading indicator", map[string]interface{}{ + "error": err.Error(), + }) + } +} + +// callAPI makes an authenticated POST request to the LINE API. +func (c *LINEChannel) callAPI(ctx context.Context, endpoint string, payload interface{}) error { + body, err := json.Marshal(payload) + if err != nil { + return fmt.Errorf("failed to marshal payload: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+c.config.ChannelAccessToken) + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("API request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + return fmt.Errorf("LINE API error (status %d): %s", resp.StatusCode, string(respBody)) + } + + return nil +} + +// downloadContent downloads media content from the LINE API. 
+func (c *LINEChannel) downloadContent(messageID, filename string) string { + url := fmt.Sprintf(lineContentEndpoint, messageID) + return utils.DownloadFile(url, filename, utils.DownloadOptions{ + LoggerPrefix: "line", + ExtraHeaders: map[string]string{ + "Authorization": "Bearer " + c.config.ChannelAccessToken, + }, + }) +} diff --git a/pkg/channels/maixcam/init.go b/pkg/channels/maixcam/init.go new file mode 100644 index 000000000..5a269b22b --- /dev/null +++ b/pkg/channels/maixcam/init.go @@ -0,0 +1,13 @@ +package maixcam + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("maixcam", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewMaixCamChannel(cfg.Channels.MaixCam, b) + }) +} diff --git a/pkg/channels/maixcam/maixcam.go b/pkg/channels/maixcam/maixcam.go new file mode 100644 index 000000000..d3c6662d7 --- /dev/null +++ b/pkg/channels/maixcam/maixcam.go @@ -0,0 +1,244 @@ +package maixcam + +import ( + "context" + "encoding/json" + "fmt" + "net" + "sync" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" +) + +type MaixCamChannel struct { + *channels.BaseChannel + config config.MaixCamConfig + listener net.Listener + clients map[net.Conn]bool + clientsMux sync.RWMutex +} + +type MaixCamMessage struct { + Type string `json:"type"` + Tips string `json:"tips"` + Timestamp float64 `json:"timestamp"` + Data map[string]interface{} `json:"data"` +} + +func NewMaixCamChannel(cfg config.MaixCamConfig, bus *bus.MessageBus) (*MaixCamChannel, error) { + base := channels.NewBaseChannel("maixcam", cfg, bus, cfg.AllowFrom) + + return &MaixCamChannel{ + BaseChannel: base, + config: cfg, + clients: make(map[net.Conn]bool), + }, nil +} + +func (c *MaixCamChannel) Start(ctx context.Context) error { + 
logger.InfoC("maixcam", "Starting MaixCam channel server") + + addr := fmt.Sprintf("%s:%d", c.config.Host, c.config.Port) + listener, err := net.Listen("tcp", addr) + if err != nil { + return fmt.Errorf("failed to listen on %s: %w", addr, err) + } + + c.listener = listener + c.SetRunning(true) + + logger.InfoCF("maixcam", "MaixCam server listening", map[string]interface{}{ + "host": c.config.Host, + "port": c.config.Port, + }) + + go c.acceptConnections(ctx) + + return nil +} + +func (c *MaixCamChannel) acceptConnections(ctx context.Context) { + logger.DebugC("maixcam", "Starting connection acceptor") + + for { + select { + case <-ctx.Done(): + logger.InfoC("maixcam", "Stopping connection acceptor") + return + default: + conn, err := c.listener.Accept() + if err != nil { + if c.IsRunning() { + logger.ErrorCF("maixcam", "Failed to accept connection", map[string]interface{}{ + "error": err.Error(), + }) + } + return + } + + logger.InfoCF("maixcam", "New connection from MaixCam device", map[string]interface{}{ + "remote_addr": conn.RemoteAddr().String(), + }) + + c.clientsMux.Lock() + c.clients[conn] = true + c.clientsMux.Unlock() + + go c.handleConnection(conn, ctx) + } + } +} + +func (c *MaixCamChannel) handleConnection(conn net.Conn, ctx context.Context) { + logger.DebugC("maixcam", "Handling MaixCam connection") + + defer func() { + conn.Close() + c.clientsMux.Lock() + delete(c.clients, conn) + c.clientsMux.Unlock() + logger.DebugC("maixcam", "Connection closed") + }() + + decoder := json.NewDecoder(conn) + + for { + select { + case <-ctx.Done(): + return + default: + var msg MaixCamMessage + if err := decoder.Decode(&msg); err != nil { + if err.Error() != "EOF" { + logger.ErrorCF("maixcam", "Failed to decode message", map[string]interface{}{ + "error": err.Error(), + }) + } + return + } + + c.processMessage(msg, conn) + } + } +} + +func (c *MaixCamChannel) processMessage(msg MaixCamMessage, conn net.Conn) { + switch msg.Type { + case "person_detected": + 
c.handlePersonDetection(msg) + case "heartbeat": + logger.DebugC("maixcam", "Received heartbeat") + case "status": + c.handleStatusUpdate(msg) + default: + logger.WarnCF("maixcam", "Unknown message type", map[string]interface{}{ + "type": msg.Type, + }) + } +} + +func (c *MaixCamChannel) handlePersonDetection(msg MaixCamMessage) { + logger.InfoCF("maixcam", "", map[string]interface{}{ + "timestamp": msg.Timestamp, + "data": msg.Data, + }) + + senderID := "maixcam" + chatID := "default" + + classInfo, ok := msg.Data["class_name"].(string) + if !ok { + classInfo = "person" + } + + score, _ := msg.Data["score"].(float64) + x, _ := msg.Data["x"].(float64) + y, _ := msg.Data["y"].(float64) + w, _ := msg.Data["w"].(float64) + h, _ := msg.Data["h"].(float64) + + content := fmt.Sprintf("📷 Person detected!\nClass: %s\nConfidence: %.2f%%\nPosition: (%.0f, %.0f)\nSize: %.0fx%.0f", + classInfo, score*100, x, y, w, h) + + metadata := map[string]string{ + "timestamp": fmt.Sprintf("%.0f", msg.Timestamp), + "class_id": fmt.Sprintf("%.0f", msg.Data["class_id"]), + "score": fmt.Sprintf("%.2f", score), + "x": fmt.Sprintf("%.0f", x), + "y": fmt.Sprintf("%.0f", y), + "w": fmt.Sprintf("%.0f", w), + "h": fmt.Sprintf("%.0f", h), + "peer_kind": "channel", + "peer_id": "default", + } + + c.HandleMessage(senderID, chatID, content, []string{}, metadata) +} + +func (c *MaixCamChannel) handleStatusUpdate(msg MaixCamMessage) { + logger.InfoCF("maixcam", "Status update from MaixCam", map[string]interface{}{ + "status": msg.Data, + }) +} + +func (c *MaixCamChannel) Stop(ctx context.Context) error { + logger.InfoC("maixcam", "Stopping MaixCam channel") + c.SetRunning(false) + + if c.listener != nil { + c.listener.Close() + } + + c.clientsMux.Lock() + defer c.clientsMux.Unlock() + + for conn := range c.clients { + conn.Close() + } + c.clients = make(map[net.Conn]bool) + + logger.InfoC("maixcam", "MaixCam channel stopped") + return nil +} + +func (c *MaixCamChannel) Send(ctx context.Context, msg 
bus.OutboundMessage) error { + if !c.IsRunning() { + return fmt.Errorf("maixcam channel not running") + } + + c.clientsMux.RLock() + defer c.clientsMux.RUnlock() + + if len(c.clients) == 0 { + logger.WarnC("maixcam", "No MaixCam devices connected") + return fmt.Errorf("no connected MaixCam devices") + } + + response := map[string]interface{}{ + "type": "command", + "timestamp": float64(0), + "message": msg.Content, + "chat_id": msg.ChatID, + } + + data, err := json.Marshal(response) + if err != nil { + return fmt.Errorf("failed to marshal response: %w", err) + } + + var sendErr error + for conn := range c.clients { + if _, err := conn.Write(data); err != nil { + logger.ErrorCF("maixcam", "Failed to send to client", map[string]interface{}{ + "client": conn.RemoteAddr().String(), + "error": err.Error(), + }) + sendErr = err + } + } + + return sendErr +} diff --git a/pkg/channels/onebot/init.go b/pkg/channels/onebot/init.go new file mode 100644 index 000000000..84c06dfd6 --- /dev/null +++ b/pkg/channels/onebot/init.go @@ -0,0 +1,13 @@ +package onebot + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("onebot", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewOneBotChannel(cfg.Channels.OneBot, b) + }) +} diff --git a/pkg/channels/onebot/onebot.go b/pkg/channels/onebot/onebot.go new file mode 100644 index 000000000..209f2dc00 --- /dev/null +++ b/pkg/channels/onebot/onebot.go @@ -0,0 +1,980 @@ +package onebot + +import ( + "context" + "encoding/json" + "fmt" + "os" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/gorilla/websocket" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/utils" + "github.com/sipeed/picoclaw/pkg/voice" +) 
+ +type OneBotChannel struct { + *channels.BaseChannel + config config.OneBotConfig + conn *websocket.Conn + ctx context.Context + cancel context.CancelFunc + dedup map[string]struct{} + dedupRing []string + dedupIdx int + mu sync.Mutex + writeMu sync.Mutex + echoCounter int64 + selfID int64 + pending map[string]chan json.RawMessage + pendingMu sync.Mutex + transcriber *voice.GroqTranscriber + lastMessageID sync.Map + pendingEmojiMsg sync.Map +} + +type oneBotRawEvent struct { + PostType string `json:"post_type"` + MessageType string `json:"message_type"` + SubType string `json:"sub_type"` + MessageID json.RawMessage `json:"message_id"` + UserID json.RawMessage `json:"user_id"` + GroupID json.RawMessage `json:"group_id"` + RawMessage string `json:"raw_message"` + Message json.RawMessage `json:"message"` + Sender json.RawMessage `json:"sender"` + SelfID json.RawMessage `json:"self_id"` + Time json.RawMessage `json:"time"` + MetaEventType string `json:"meta_event_type"` + NoticeType string `json:"notice_type"` + Echo string `json:"echo"` + RetCode json.RawMessage `json:"retcode"` + Status json.RawMessage `json:"status"` + Data json.RawMessage `json:"data"` +} + +type BotStatus struct { + Online bool `json:"online"` + Good bool `json:"good"` +} + +func isAPIResponse(raw json.RawMessage) bool { + if len(raw) == 0 { + return false + } + var s string + if json.Unmarshal(raw, &s) == nil { + return s == "ok" || s == "failed" + } + var bs BotStatus + if json.Unmarshal(raw, &bs) == nil { + return bs.Online || bs.Good + } + return false +} + +type oneBotSender struct { + UserID json.RawMessage `json:"user_id"` + Nickname string `json:"nickname"` + Card string `json:"card"` +} + +type oneBotAPIRequest struct { + Action string `json:"action"` + Params interface{} `json:"params"` + Echo string `json:"echo,omitempty"` +} + +type oneBotMessageSegment struct { + Type string `json:"type"` + Data map[string]interface{} `json:"data"` +} + +func NewOneBotChannel(cfg 
config.OneBotConfig, messageBus *bus.MessageBus) (*OneBotChannel, error) { + base := channels.NewBaseChannel("onebot", cfg, messageBus, cfg.AllowFrom) + + const dedupSize = 1024 + return &OneBotChannel{ + BaseChannel: base, + config: cfg, + dedup: make(map[string]struct{}, dedupSize), + dedupRing: make([]string, dedupSize), + dedupIdx: 0, + pending: make(map[string]chan json.RawMessage), + }, nil +} + +func (c *OneBotChannel) SetTranscriber(transcriber *voice.GroqTranscriber) { + c.transcriber = transcriber +} + +func (c *OneBotChannel) setMsgEmojiLike(messageID string, emojiID int, set bool) { + go func() { + _, err := c.sendAPIRequest("set_msg_emoji_like", map[string]interface{}{ + "message_id": messageID, + "emoji_id": emojiID, + "set": set, + }, 5*time.Second) + if err != nil { + logger.DebugCF("onebot", "Failed to set emoji like", map[string]interface{}{ + "message_id": messageID, + "error": err.Error(), + }) + } + }() +} + +func (c *OneBotChannel) Start(ctx context.Context) error { + if c.config.WSUrl == "" { + return fmt.Errorf("OneBot ws_url not configured") + } + + logger.InfoCF("onebot", "Starting OneBot channel", map[string]interface{}{ + "ws_url": c.config.WSUrl, + }) + + c.ctx, c.cancel = context.WithCancel(ctx) + + if err := c.connect(); err != nil { + logger.WarnCF("onebot", "Initial connection failed, will retry in background", map[string]interface{}{ + "error": err.Error(), + }) + } else { + go c.listen() + c.fetchSelfID() + } + + if c.config.ReconnectInterval > 0 { + go c.reconnectLoop() + } else { + if c.conn == nil { + return fmt.Errorf("failed to connect to OneBot and reconnect is disabled") + } + } + + c.SetRunning(true) + logger.InfoC("onebot", "OneBot channel started successfully") + + return nil +} + +func (c *OneBotChannel) connect() error { + dialer := websocket.DefaultDialer + dialer.HandshakeTimeout = 10 * time.Second + + header := make(map[string][]string) + if c.config.AccessToken != "" { + header["Authorization"] = []string{"Bearer " 
+ c.config.AccessToken} + } + + conn, _, err := dialer.Dial(c.config.WSUrl, header) + if err != nil { + return err + } + + conn.SetPongHandler(func(appData string) error { + _ = conn.SetReadDeadline(time.Now().Add(60 * time.Second)) + return nil + }) + _ = conn.SetReadDeadline(time.Now().Add(60 * time.Second)) + + c.mu.Lock() + c.conn = conn + c.mu.Unlock() + + go c.pinger(conn) + + logger.InfoC("onebot", "WebSocket connected") + return nil +} + +func (c *OneBotChannel) pinger(conn *websocket.Conn) { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + + for { + select { + case <-c.ctx.Done(): + return + case <-ticker.C: + c.writeMu.Lock() + err := conn.WriteMessage(websocket.PingMessage, nil) + c.writeMu.Unlock() + if err != nil { + logger.DebugCF("onebot", "Ping write failed, stopping pinger", map[string]interface{}{ + "error": err.Error(), + }) + return + } + } + } +} + +func (c *OneBotChannel) fetchSelfID() { + resp, err := c.sendAPIRequest("get_login_info", nil, 5*time.Second) + if err != nil { + logger.WarnCF("onebot", "Failed to get_login_info", map[string]interface{}{ + "error": err.Error(), + }) + return + } + + type loginInfo struct { + UserID json.RawMessage `json:"user_id"` + Nickname string `json:"nickname"` + } + for _, extract := range []func() (*loginInfo, error){ + func() (*loginInfo, error) { + var w struct { + Data loginInfo `json:"data"` + } + err := json.Unmarshal(resp, &w) + return &w.Data, err + }, + func() (*loginInfo, error) { + var f loginInfo + err := json.Unmarshal(resp, &f) + return &f, err + }, + } { + info, err := extract() + if err != nil || len(info.UserID) == 0 { + continue + } + if uid, err := parseJSONInt64(info.UserID); err == nil && uid > 0 { + atomic.StoreInt64(&c.selfID, uid) + logger.InfoCF("onebot", "Bot self ID retrieved", map[string]interface{}{ + "self_id": uid, + "nickname": info.Nickname, + }) + return + } + } + + logger.WarnCF("onebot", "Could not parse self ID from get_login_info response", 
map[string]interface{}{ + "response": string(resp), + }) +} + +func (c *OneBotChannel) sendAPIRequest(action string, params interface{}, timeout time.Duration) (json.RawMessage, error) { + c.mu.Lock() + conn := c.conn + c.mu.Unlock() + + if conn == nil { + return nil, fmt.Errorf("WebSocket not connected") + } + + echo := fmt.Sprintf("api_%d_%d", time.Now().UnixNano(), atomic.AddInt64(&c.echoCounter, 1)) + + ch := make(chan json.RawMessage, 1) + c.pendingMu.Lock() + c.pending[echo] = ch + c.pendingMu.Unlock() + + defer func() { + c.pendingMu.Lock() + delete(c.pending, echo) + c.pendingMu.Unlock() + }() + + req := oneBotAPIRequest{ + Action: action, + Params: params, + Echo: echo, + } + + data, err := json.Marshal(req) + if err != nil { + return nil, fmt.Errorf("failed to marshal API request: %w", err) + } + + c.writeMu.Lock() + err = conn.WriteMessage(websocket.TextMessage, data) + c.writeMu.Unlock() + + if err != nil { + return nil, fmt.Errorf("failed to write API request: %w", err) + } + + select { + case resp := <-ch: + return resp, nil + case <-time.After(timeout): + return nil, fmt.Errorf("API request %s timed out after %v", action, timeout) + case <-c.ctx.Done(): + return nil, fmt.Errorf("context cancelled") + } +} + +func (c *OneBotChannel) reconnectLoop() { + interval := time.Duration(c.config.ReconnectInterval) * time.Second + if interval < 5*time.Second { + interval = 5 * time.Second + } + + for { + select { + case <-c.ctx.Done(): + return + case <-time.After(interval): + c.mu.Lock() + conn := c.conn + c.mu.Unlock() + + if conn == nil { + logger.InfoC("onebot", "Attempting to reconnect...") + if err := c.connect(); err != nil { + logger.ErrorCF("onebot", "Reconnect failed", map[string]interface{}{ + "error": err.Error(), + }) + } else { + go c.listen() + c.fetchSelfID() + } + } + } + } +} + +func (c *OneBotChannel) Stop(ctx context.Context) error { + logger.InfoC("onebot", "Stopping OneBot channel") + c.SetRunning(false) + + if c.cancel != nil { + 
c.cancel() + } + + c.pendingMu.Lock() + for echo, ch := range c.pending { + close(ch) + delete(c.pending, echo) + } + c.pendingMu.Unlock() + + c.mu.Lock() + if c.conn != nil { + c.conn.Close() + c.conn = nil + } + c.mu.Unlock() + + return nil +} + +func (c *OneBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + if !c.IsRunning() { + return fmt.Errorf("OneBot channel not running") + } + + c.mu.Lock() + conn := c.conn + c.mu.Unlock() + + if conn == nil { + return fmt.Errorf("OneBot WebSocket not connected") + } + + action, params, err := c.buildSendRequest(msg) + if err != nil { + return err + } + + echo := fmt.Sprintf("send_%d", atomic.AddInt64(&c.echoCounter, 1)) + + req := oneBotAPIRequest{ + Action: action, + Params: params, + Echo: echo, + } + + data, err := json.Marshal(req) + if err != nil { + return fmt.Errorf("failed to marshal OneBot request: %w", err) + } + + c.writeMu.Lock() + err = conn.WriteMessage(websocket.TextMessage, data) + c.writeMu.Unlock() + + if err != nil { + logger.ErrorCF("onebot", "Failed to send message", map[string]interface{}{ + "error": err.Error(), + }) + return err + } + + if msgID, ok := c.pendingEmojiMsg.LoadAndDelete(msg.ChatID); ok { + if mid, ok := msgID.(string); ok && mid != "" { + c.setMsgEmojiLike(mid, 289, false) + } + } + + return nil +} + +func (c *OneBotChannel) buildMessageSegments(chatID, content string) []oneBotMessageSegment { + var segments []oneBotMessageSegment + + if lastMsgID, ok := c.lastMessageID.Load(chatID); ok { + if msgID, ok := lastMsgID.(string); ok && msgID != "" { + segments = append(segments, oneBotMessageSegment{ + Type: "reply", + Data: map[string]interface{}{"id": msgID}, + }) + } + } + + segments = append(segments, oneBotMessageSegment{ + Type: "text", + Data: map[string]interface{}{"text": content}, + }) + + return segments +} + +func (c *OneBotChannel) buildSendRequest(msg bus.OutboundMessage) (string, interface{}, error) { + chatID := msg.ChatID + segments := 
c.buildMessageSegments(chatID, msg.Content) + + var action, idKey string + var rawID string + if rest, ok := strings.CutPrefix(chatID, "group:"); ok { + action, idKey, rawID = "send_group_msg", "group_id", rest + } else if rest, ok := strings.CutPrefix(chatID, "private:"); ok { + action, idKey, rawID = "send_private_msg", "user_id", rest + } else { + action, idKey, rawID = "send_private_msg", "user_id", chatID + } + + id, err := strconv.ParseInt(rawID, 10, 64) + if err != nil { + return "", nil, fmt.Errorf("invalid %s in chatID: %s", idKey, chatID) + } + return action, map[string]interface{}{idKey: id, "message": segments}, nil +} + +func (c *OneBotChannel) listen() { + c.mu.Lock() + conn := c.conn + c.mu.Unlock() + + if conn == nil { + logger.WarnC("onebot", "WebSocket connection is nil, listener exiting") + return + } + + for { + select { + case <-c.ctx.Done(): + return + default: + _, message, err := conn.ReadMessage() + if err != nil { + logger.ErrorCF("onebot", "WebSocket read error", map[string]interface{}{ + "error": err.Error(), + }) + c.mu.Lock() + if c.conn == conn { + c.conn.Close() + c.conn = nil + } + c.mu.Unlock() + return + } + + _ = conn.SetReadDeadline(time.Now().Add(60 * time.Second)) + + var raw oneBotRawEvent + if err := json.Unmarshal(message, &raw); err != nil { + logger.WarnCF("onebot", "Failed to unmarshal raw event", map[string]interface{}{ + "error": err.Error(), + "payload": string(message), + }) + continue + } + + logger.DebugCF("onebot", "WebSocket event", map[string]interface{}{ + "length": len(message), + "post_type": raw.PostType, + "sub_type": raw.SubType, + }) + + if raw.Echo != "" { + c.pendingMu.Lock() + ch, ok := c.pending[raw.Echo] + c.pendingMu.Unlock() + + if ok { + select { + case ch <- message: + default: + } + } else { + logger.DebugCF("onebot", "Received API response (no waiter)", map[string]interface{}{ + "echo": raw.Echo, + "status": string(raw.Status), + }) + } + continue + } + + if isAPIResponse(raw.Status) { + 
logger.DebugCF("onebot", "Received API response without echo, skipping", map[string]interface{}{ + "status": string(raw.Status), + }) + continue + } + + c.handleRawEvent(&raw) + } + } +} + +func parseJSONInt64(raw json.RawMessage) (int64, error) { + if len(raw) == 0 { + return 0, nil + } + + var n int64 + if err := json.Unmarshal(raw, &n); err == nil { + return n, nil + } + + var s string + if err := json.Unmarshal(raw, &s); err == nil { + return strconv.ParseInt(s, 10, 64) + } + return 0, fmt.Errorf("cannot parse as int64: %s", string(raw)) +} + +func parseJSONString(raw json.RawMessage) string { + if len(raw) == 0 { + return "" + } + var s string + if err := json.Unmarshal(raw, &s); err == nil { + return s + } + + return string(raw) +} + +type parseMessageResult struct { + Text string + IsBotMentioned bool + Media []string + LocalFiles []string + ReplyTo string +} + +func (c *OneBotChannel) parseMessageSegments(raw json.RawMessage, selfID int64) parseMessageResult { + if len(raw) == 0 { + return parseMessageResult{} + } + + var s string + if err := json.Unmarshal(raw, &s); err == nil { + mentioned := false + if selfID > 0 { + cqAt := fmt.Sprintf("[CQ:at,qq=%d]", selfID) + if strings.Contains(s, cqAt) { + mentioned = true + s = strings.ReplaceAll(s, cqAt, "") + s = strings.TrimSpace(s) + } + } + return parseMessageResult{Text: s, IsBotMentioned: mentioned} + } + + var segments []map[string]interface{} + if err := json.Unmarshal(raw, &segments); err != nil { + return parseMessageResult{} + } + + var textParts []string + mentioned := false + selfIDStr := strconv.FormatInt(selfID, 10) + var media []string + var localFiles []string + var replyTo string + + for _, seg := range segments { + segType, _ := seg["type"].(string) + data, _ := seg["data"].(map[string]interface{}) + + switch segType { + case "text": + if data != nil { + if t, ok := data["text"].(string); ok { + textParts = append(textParts, t) + } + } + + case "at": + if data != nil && selfID > 0 { + qqVal := 
fmt.Sprintf("%v", data["qq"]) + if qqVal == selfIDStr || qqVal == "all" { + mentioned = true + } + } + + case "image", "video", "file": + if data != nil { + url, _ := data["url"].(string) + if url != "" { + defaults := map[string]string{"image": "image.jpg", "video": "video.mp4", "file": "file"} + filename := defaults[segType] + if f, ok := data["file"].(string); ok && f != "" { + filename = f + } else if n, ok := data["name"].(string); ok && n != "" { + filename = n + } + localPath := utils.DownloadFile(url, filename, utils.DownloadOptions{ + LoggerPrefix: "onebot", + }) + if localPath != "" { + media = append(media, localPath) + localFiles = append(localFiles, localPath) + textParts = append(textParts, fmt.Sprintf("[%s]", segType)) + } + } + } + + case "record": + if data != nil { + url, _ := data["url"].(string) + if url != "" { + localPath := utils.DownloadFile(url, "voice.amr", utils.DownloadOptions{ + LoggerPrefix: "onebot", + }) + if localPath != "" { + localFiles = append(localFiles, localPath) + if c.transcriber != nil && c.transcriber.IsAvailable() { + tctx, tcancel := context.WithTimeout(c.ctx, 30*time.Second) + result, err := c.transcriber.Transcribe(tctx, localPath) + tcancel() + if err != nil { + logger.WarnCF("onebot", "Voice transcription failed", map[string]interface{}{ + "error": err.Error(), + }) + textParts = append(textParts, "[voice (transcription failed)]") + media = append(media, localPath) + } else { + textParts = append(textParts, fmt.Sprintf("[voice transcription: %s]", result.Text)) + } + } else { + textParts = append(textParts, "[voice]") + media = append(media, localPath) + } + } + } + } + + case "reply": + if data != nil { + if id, ok := data["id"]; ok { + replyTo = fmt.Sprintf("%v", id) + } + } + + case "face": + if data != nil { + faceID, _ := data["id"] + textParts = append(textParts, fmt.Sprintf("[face:%v]", faceID)) + } + + case "forward": + textParts = append(textParts, "[forward message]") + + default: + + } + } + + return 
parseMessageResult{ + Text: strings.TrimSpace(strings.Join(textParts, "")), + IsBotMentioned: mentioned, + Media: media, + LocalFiles: localFiles, + ReplyTo: replyTo, + } +} + +func (c *OneBotChannel) handleRawEvent(raw *oneBotRawEvent) { + switch raw.PostType { + case "message": + if userID, err := parseJSONInt64(raw.UserID); err == nil && userID > 0 { + if !c.IsAllowed(strconv.FormatInt(userID, 10)) { + logger.DebugCF("onebot", "Message rejected by allowlist", map[string]interface{}{ + "user_id": userID, + }) + return + } + } + c.handleMessage(raw) + + case "message_sent": + logger.DebugCF("onebot", "Bot sent message event", map[string]interface{}{ + "message_type": raw.MessageType, + "message_id": parseJSONString(raw.MessageID), + }) + + case "meta_event": + c.handleMetaEvent(raw) + + case "notice": + c.handleNoticeEvent(raw) + + case "request": + logger.DebugCF("onebot", "Request event received", map[string]interface{}{ + "sub_type": raw.SubType, + }) + + case "": + logger.DebugCF("onebot", "Event with empty post_type (possibly API response)", map[string]interface{}{ + "echo": raw.Echo, + "status": raw.Status, + }) + + default: + logger.DebugCF("onebot", "Unknown post_type", map[string]interface{}{ + "post_type": raw.PostType, + }) + } +} + +func (c *OneBotChannel) handleMetaEvent(raw *oneBotRawEvent) { + if raw.MetaEventType == "lifecycle" { + logger.InfoCF("onebot", "Lifecycle event", map[string]interface{}{"sub_type": raw.SubType}) + } else if raw.MetaEventType != "heartbeat" { + logger.DebugCF("onebot", "Meta event: "+raw.MetaEventType, nil) + } +} + +func (c *OneBotChannel) handleNoticeEvent(raw *oneBotRawEvent) { + fields := map[string]interface{}{ + "notice_type": raw.NoticeType, + "sub_type": raw.SubType, + "group_id": parseJSONString(raw.GroupID), + "user_id": parseJSONString(raw.UserID), + "message_id": parseJSONString(raw.MessageID), + } + switch raw.NoticeType { + case "group_recall", "group_increase", "group_decrease", + "friend_add", 
"group_admin", "group_ban": + logger.InfoCF("onebot", "Notice: "+raw.NoticeType, fields) + default: + logger.DebugCF("onebot", "Notice: "+raw.NoticeType, fields) + } +} + +func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) { + // Parse fields from raw event + userID, err := parseJSONInt64(raw.UserID) + if err != nil { + logger.WarnCF("onebot", "Failed to parse user_id", map[string]interface{}{ + "error": err.Error(), + "raw": string(raw.UserID), + }) + return + } + + groupID, _ := parseJSONInt64(raw.GroupID) + selfID, _ := parseJSONInt64(raw.SelfID) + messageID := parseJSONString(raw.MessageID) + + if selfID == 0 { + selfID = atomic.LoadInt64(&c.selfID) + } + + parsed := c.parseMessageSegments(raw.Message, selfID) + isBotMentioned := parsed.IsBotMentioned + + content := raw.RawMessage + if content == "" { + content = parsed.Text + } else if selfID > 0 { + cqAt := fmt.Sprintf("[CQ:at,qq=%d]", selfID) + if strings.Contains(content, cqAt) { + isBotMentioned = true + content = strings.ReplaceAll(content, cqAt, "") + content = strings.TrimSpace(content) + } + } + + if parsed.Text != "" && content != parsed.Text && (len(parsed.Media) > 0 || parsed.ReplyTo != "") { + content = parsed.Text + } + + var sender oneBotSender + if len(raw.Sender) > 0 { + if err := json.Unmarshal(raw.Sender, &sender); err != nil { + logger.WarnCF("onebot", "Failed to parse sender", map[string]interface{}{ + "error": err.Error(), + "sender": string(raw.Sender), + }) + } + } + + // Clean up temp files when done + if len(parsed.LocalFiles) > 0 { + defer func() { + for _, f := range parsed.LocalFiles { + if err := os.Remove(f); err != nil { + logger.DebugCF("onebot", "Failed to remove temp file", map[string]interface{}{ + "path": f, + "error": err.Error(), + }) + } + } + }() + } + + if c.isDuplicate(messageID) { + logger.DebugCF("onebot", "Duplicate message, skipping", map[string]interface{}{ + "message_id": messageID, + }) + return + } + + if content == "" { + logger.DebugCF("onebot", 
"Received empty message, ignoring", map[string]interface{}{ + "message_id": messageID, + }) + return + } + + senderID := strconv.FormatInt(userID, 10) + var chatID string + + metadata := map[string]string{ + "message_id": messageID, + } + + if parsed.ReplyTo != "" { + metadata["reply_to_message_id"] = parsed.ReplyTo + } + + switch raw.MessageType { + case "private": + chatID = "private:" + senderID + metadata["peer_kind"] = "direct" + metadata["peer_id"] = senderID + + case "group": + groupIDStr := strconv.FormatInt(groupID, 10) + chatID = "group:" + groupIDStr + metadata["peer_kind"] = "group" + metadata["peer_id"] = groupIDStr + metadata["group_id"] = groupIDStr + + senderUserID, _ := parseJSONInt64(sender.UserID) + if senderUserID > 0 { + metadata["sender_user_id"] = strconv.FormatInt(senderUserID, 10) + } + + if sender.Card != "" { + metadata["sender_name"] = sender.Card + } else if sender.Nickname != "" { + metadata["sender_name"] = sender.Nickname + } + + triggered, strippedContent := c.checkGroupTrigger(content, isBotMentioned) + if !triggered { + logger.DebugCF("onebot", "Group message ignored (no trigger)", map[string]interface{}{ + "sender": senderID, + "group": groupIDStr, + "is_mentioned": isBotMentioned, + "content": truncate(content, 100), + }) + return + } + content = strippedContent + + default: + logger.WarnCF("onebot", "Unknown message type, cannot route", map[string]interface{}{ + "type": raw.MessageType, + "message_id": messageID, + "user_id": userID, + }) + return + } + + logger.InfoCF("onebot", "Received "+raw.MessageType+" message", map[string]interface{}{ + "sender": senderID, + "chat_id": chatID, + "message_id": messageID, + "length": len(content), + "content": truncate(content, 100), + "media_count": len(parsed.Media), + }) + + if sender.Nickname != "" { + metadata["nickname"] = sender.Nickname + } + + c.lastMessageID.Store(chatID, messageID) + + if raw.MessageType == "group" && messageID != "" && messageID != "0" { + 
c.setMsgEmojiLike(messageID, 289, true) + c.pendingEmojiMsg.Store(chatID, messageID) + } + + c.HandleMessage(senderID, chatID, content, parsed.Media, metadata) +} + +func (c *OneBotChannel) isDuplicate(messageID string) bool { + if messageID == "" || messageID == "0" { + return false + } + + c.mu.Lock() + defer c.mu.Unlock() + + if _, exists := c.dedup[messageID]; exists { + return true + } + + if old := c.dedupRing[c.dedupIdx]; old != "" { + delete(c.dedup, old) + } + c.dedupRing[c.dedupIdx] = messageID + c.dedup[messageID] = struct{}{} + c.dedupIdx = (c.dedupIdx + 1) % len(c.dedupRing) + + return false +} + +func truncate(s string, n int) string { + runes := []rune(s) + if len(runes) <= n { + return s + } + return string(runes[:n]) + "..." +} + +func (c *OneBotChannel) checkGroupTrigger(content string, isBotMentioned bool) (triggered bool, strippedContent string) { + if isBotMentioned { + return true, strings.TrimSpace(content) + } + + for _, prefix := range c.config.GroupTriggerPrefix { + if prefix == "" { + continue + } + if strings.HasPrefix(content, prefix) { + return true, strings.TrimSpace(strings.TrimPrefix(content, prefix)) + } + } + + return false, content +} diff --git a/pkg/channels/qq/init.go b/pkg/channels/qq/init.go new file mode 100644 index 000000000..15b955089 --- /dev/null +++ b/pkg/channels/qq/init.go @@ -0,0 +1,13 @@ +package qq + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("qq", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewQQChannel(cfg.Channels.QQ, b) + }) +} diff --git a/pkg/channels/qq/qq.go b/pkg/channels/qq/qq.go new file mode 100644 index 000000000..9b07be0cc --- /dev/null +++ b/pkg/channels/qq/qq.go @@ -0,0 +1,248 @@ +package qq + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/tencent-connect/botgo" + "github.com/tencent-connect/botgo/dto" + 
"github.com/tencent-connect/botgo/event" + "github.com/tencent-connect/botgo/openapi" + "github.com/tencent-connect/botgo/token" + "golang.org/x/oauth2" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" +) + +type QQChannel struct { + *channels.BaseChannel + config config.QQConfig + api openapi.OpenAPI + tokenSource oauth2.TokenSource + ctx context.Context + cancel context.CancelFunc + sessionManager botgo.SessionManager + processedIDs map[string]bool + mu sync.RWMutex +} + +func NewQQChannel(cfg config.QQConfig, messageBus *bus.MessageBus) (*QQChannel, error) { + base := channels.NewBaseChannel("qq", cfg, messageBus, cfg.AllowFrom) + + return &QQChannel{ + BaseChannel: base, + config: cfg, + processedIDs: make(map[string]bool), + }, nil +} + +func (c *QQChannel) Start(ctx context.Context) error { + if c.config.AppID == "" || c.config.AppSecret == "" { + return fmt.Errorf("QQ app_id and app_secret not configured") + } + + logger.InfoC("qq", "Starting QQ bot (WebSocket mode)") + + // 创建 token source + credentials := &token.QQBotCredentials{ + AppID: c.config.AppID, + AppSecret: c.config.AppSecret, + } + c.tokenSource = token.NewQQBotTokenSource(credentials) + + // 创建子 context + c.ctx, c.cancel = context.WithCancel(ctx) + + // 启动自动刷新 token 协程 + if err := token.StartRefreshAccessToken(c.ctx, c.tokenSource); err != nil { + return fmt.Errorf("failed to start token refresh: %w", err) + } + + // 初始化 OpenAPI 客户端 + c.api = botgo.NewOpenAPI(c.config.AppID, c.tokenSource).WithTimeout(5 * time.Second) + + // 注册事件处理器 + intent := event.RegisterHandlers( + c.handleC2CMessage(), + c.handleGroupATMessage(), + ) + + // 获取 WebSocket 接入点 + wsInfo, err := c.api.WS(c.ctx, nil, "") + if err != nil { + return fmt.Errorf("failed to get websocket info: %w", err) + } + + logger.InfoCF("qq", "Got WebSocket info", map[string]interface{}{ + "shards": wsInfo.Shards, + }) + + // 
创建并保存 sessionManager + c.sessionManager = botgo.NewSessionManager() + + // 在 goroutine 中启动 WebSocket 连接,避免阻塞 + go func() { + if err := c.sessionManager.Start(wsInfo, c.tokenSource, &intent); err != nil { + logger.ErrorCF("qq", "WebSocket session error", map[string]interface{}{ + "error": err.Error(), + }) + c.SetRunning(false) + } + }() + + c.SetRunning(true) + logger.InfoC("qq", "QQ bot started successfully") + + return nil +} + +func (c *QQChannel) Stop(ctx context.Context) error { + logger.InfoC("qq", "Stopping QQ bot") + c.SetRunning(false) + + if c.cancel != nil { + c.cancel() + } + + return nil +} + +func (c *QQChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + if !c.IsRunning() { + return fmt.Errorf("QQ bot not running") + } + + // 构造消息 + msgToCreate := &dto.MessageToCreate{ + Content: msg.Content, + } + + // C2C 消息发送 + _, err := c.api.PostC2CMessage(ctx, msg.ChatID, msgToCreate) + if err != nil { + logger.ErrorCF("qq", "Failed to send C2C message", map[string]interface{}{ + "error": err.Error(), + }) + return err + } + + return nil +} + +// handleC2CMessage 处理 QQ 私聊消息 +func (c *QQChannel) handleC2CMessage() event.C2CMessageEventHandler { + return func(event *dto.WSPayload, data *dto.WSC2CMessageData) error { + // 去重检查 + if c.isDuplicate(data.ID) { + return nil + } + + // 提取用户信息 + var senderID string + if data.Author != nil && data.Author.ID != "" { + senderID = data.Author.ID + } else { + logger.WarnC("qq", "Received message with no sender ID") + return nil + } + + // 提取消息内容 + content := data.Content + if content == "" { + logger.DebugC("qq", "Received empty message, ignoring") + return nil + } + + logger.InfoCF("qq", "Received C2C message", map[string]interface{}{ + "sender": senderID, + "length": len(content), + }) + + // 转发到消息总线 + metadata := map[string]string{ + "message_id": data.ID, + "peer_kind": "direct", + "peer_id": senderID, + } + + c.HandleMessage(senderID, senderID, content, []string{}, metadata) + + return nil + } +} + +// 
handleGroupATMessage 处理群@消息 +func (c *QQChannel) handleGroupATMessage() event.GroupATMessageEventHandler { + return func(event *dto.WSPayload, data *dto.WSGroupATMessageData) error { + // 去重检查 + if c.isDuplicate(data.ID) { + return nil + } + + // 提取用户信息 + var senderID string + if data.Author != nil && data.Author.ID != "" { + senderID = data.Author.ID + } else { + logger.WarnC("qq", "Received group message with no sender ID") + return nil + } + + // 提取消息内容(去掉 @ 机器人部分) + content := data.Content + if content == "" { + logger.DebugC("qq", "Received empty group message, ignoring") + return nil + } + + logger.InfoCF("qq", "Received group AT message", map[string]interface{}{ + "sender": senderID, + "group": data.GroupID, + "length": len(content), + }) + + // 转发到消息总线(使用 GroupID 作为 ChatID) + metadata := map[string]string{ + "message_id": data.ID, + "group_id": data.GroupID, + "peer_kind": "group", + "peer_id": data.GroupID, + } + + c.HandleMessage(senderID, data.GroupID, content, []string{}, metadata) + + return nil + } +} + +// isDuplicate 检查消息是否重复 +func (c *QQChannel) isDuplicate(messageID string) bool { + c.mu.Lock() + defer c.mu.Unlock() + + if c.processedIDs[messageID] { + return true + } + + c.processedIDs[messageID] = true + + // 简单清理:限制 map 大小 + if len(c.processedIDs) > 10000 { + // 清空一半 + count := 0 + for id := range c.processedIDs { + if count >= 5000 { + break + } + delete(c.processedIDs, id) + count++ + } + } + + return false +} diff --git a/pkg/channels/slack/init.go b/pkg/channels/slack/init.go new file mode 100644 index 000000000..c131bb291 --- /dev/null +++ b/pkg/channels/slack/init.go @@ -0,0 +1,13 @@ +package slack + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("slack", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewSlackChannel(cfg.Channels.Slack, b) + }) +} diff --git 
a/pkg/channels/slack/slack.go b/pkg/channels/slack/slack.go new file mode 100644 index 000000000..dc5190fc9 --- /dev/null +++ b/pkg/channels/slack/slack.go @@ -0,0 +1,444 @@ +package slack + +import ( + "context" + "fmt" + "os" + "strings" + "sync" + "time" + + "github.com/slack-go/slack" + "github.com/slack-go/slack/slackevents" + "github.com/slack-go/slack/socketmode" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/utils" + "github.com/sipeed/picoclaw/pkg/voice" +) + +type SlackChannel struct { + *channels.BaseChannel + config config.SlackConfig + api *slack.Client + socketClient *socketmode.Client + botUserID string + teamID string + transcriber *voice.GroqTranscriber + ctx context.Context + cancel context.CancelFunc + pendingAcks sync.Map +} + +type slackMessageRef struct { + ChannelID string + Timestamp string +} + +func NewSlackChannel(cfg config.SlackConfig, messageBus *bus.MessageBus) (*SlackChannel, error) { + if cfg.BotToken == "" || cfg.AppToken == "" { + return nil, fmt.Errorf("slack bot_token and app_token are required") + } + + api := slack.New( + cfg.BotToken, + slack.OptionAppLevelToken(cfg.AppToken), + ) + + socketClient := socketmode.New(api) + + base := channels.NewBaseChannel("slack", cfg, messageBus, cfg.AllowFrom) + + return &SlackChannel{ + BaseChannel: base, + config: cfg, + api: api, + socketClient: socketClient, + }, nil +} + +func (c *SlackChannel) SetTranscriber(transcriber *voice.GroqTranscriber) { + c.transcriber = transcriber +} + +func (c *SlackChannel) Start(ctx context.Context) error { + logger.InfoC("slack", "Starting Slack channel (Socket Mode)") + + c.ctx, c.cancel = context.WithCancel(ctx) + + authResp, err := c.api.AuthTest() + if err != nil { + return fmt.Errorf("slack auth test failed: %w", err) + } + c.botUserID = authResp.UserID + c.teamID = authResp.TeamID + + 
logger.InfoCF("slack", "Slack bot connected", map[string]interface{}{ + "bot_user_id": c.botUserID, + "team": authResp.Team, + }) + + go c.eventLoop() + + go func() { + if err := c.socketClient.RunContext(c.ctx); err != nil { + if c.ctx.Err() == nil { + logger.ErrorCF("slack", "Socket Mode connection error", map[string]interface{}{ + "error": err.Error(), + }) + } + } + }() + + c.SetRunning(true) + logger.InfoC("slack", "Slack channel started (Socket Mode)") + return nil +} + +func (c *SlackChannel) Stop(ctx context.Context) error { + logger.InfoC("slack", "Stopping Slack channel") + + if c.cancel != nil { + c.cancel() + } + + c.SetRunning(false) + logger.InfoC("slack", "Slack channel stopped") + return nil +} + +func (c *SlackChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + if !c.IsRunning() { + return fmt.Errorf("slack channel not running") + } + + channelID, threadTS := parseSlackChatID(msg.ChatID) + if channelID == "" { + return fmt.Errorf("invalid slack chat ID: %s", msg.ChatID) + } + + opts := []slack.MsgOption{ + slack.MsgOptionText(msg.Content, false), + } + + if threadTS != "" { + opts = append(opts, slack.MsgOptionTS(threadTS)) + } + + _, _, err := c.api.PostMessageContext(ctx, channelID, opts...) 
+ if err != nil { + return fmt.Errorf("failed to send slack message: %w", err) + } + + if ref, ok := c.pendingAcks.LoadAndDelete(msg.ChatID); ok { + msgRef := ref.(slackMessageRef) + c.api.AddReaction("white_check_mark", slack.ItemRef{ + Channel: msgRef.ChannelID, + Timestamp: msgRef.Timestamp, + }) + } + + logger.DebugCF("slack", "Message sent", map[string]interface{}{ + "channel_id": channelID, + "thread_ts": threadTS, + }) + + return nil +} + +func (c *SlackChannel) eventLoop() { + for { + select { + case <-c.ctx.Done(): + return + case event, ok := <-c.socketClient.Events: + if !ok { + return + } + switch event.Type { + case socketmode.EventTypeEventsAPI: + c.handleEventsAPI(event) + case socketmode.EventTypeSlashCommand: + c.handleSlashCommand(event) + case socketmode.EventTypeInteractive: + if event.Request != nil { + c.socketClient.Ack(*event.Request) + } + } + } + } +} + +func (c *SlackChannel) handleEventsAPI(event socketmode.Event) { + if event.Request != nil { + c.socketClient.Ack(*event.Request) + } + + eventsAPIEvent, ok := event.Data.(slackevents.EventsAPIEvent) + if !ok { + return + } + + switch ev := eventsAPIEvent.InnerEvent.Data.(type) { + case *slackevents.MessageEvent: + c.handleMessageEvent(ev) + case *slackevents.AppMentionEvent: + c.handleAppMention(ev) + } +} + +func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { + if ev.User == c.botUserID || ev.User == "" { + return + } + if ev.BotID != "" { + return + } + if ev.SubType != "" && ev.SubType != "file_share" { + return + } + + // 检查白名单,避免为被拒绝的用户下载附件 + if !c.IsAllowed(ev.User) { + logger.DebugCF("slack", "Message rejected by allowlist", map[string]interface{}{ + "user_id": ev.User, + }) + return + } + + senderID := ev.User + channelID := ev.Channel + threadTS := ev.ThreadTimeStamp + messageTS := ev.TimeStamp + + chatID := channelID + if threadTS != "" { + chatID = channelID + "/" + threadTS + } + + c.api.AddReaction("eyes", slack.ItemRef{ + Channel: channelID, + 
Timestamp: messageTS, + }) + + c.pendingAcks.Store(chatID, slackMessageRef{ + ChannelID: channelID, + Timestamp: messageTS, + }) + + content := ev.Text + content = c.stripBotMention(content) + + var mediaPaths []string + localFiles := []string{} // 跟踪需要清理的本地文件 + + // 确保临时文件在函数返回时被清理 + defer func() { + for _, file := range localFiles { + if err := os.Remove(file); err != nil { + logger.DebugCF("slack", "Failed to cleanup temp file", map[string]interface{}{ + "file": file, + "error": err.Error(), + }) + } + } + }() + + if ev.Message != nil && len(ev.Message.Files) > 0 { + for _, file := range ev.Message.Files { + localPath := c.downloadSlackFile(file) + if localPath == "" { + continue + } + localFiles = append(localFiles, localPath) + mediaPaths = append(mediaPaths, localPath) + + if utils.IsAudioFile(file.Name, file.Mimetype) && c.transcriber != nil && c.transcriber.IsAvailable() { + ctx, cancel := context.WithTimeout(c.ctx, 30*time.Second) + defer cancel() + result, err := c.transcriber.Transcribe(ctx, localPath) + + if err != nil { + logger.ErrorCF("slack", "Voice transcription failed", map[string]interface{}{"error": err.Error()}) + content += fmt.Sprintf("\n[audio: %s (transcription failed)]", file.Name) + } else { + content += fmt.Sprintf("\n[voice transcription: %s]", result.Text) + } + } else { + content += fmt.Sprintf("\n[file: %s]", file.Name) + } + } + } + + if strings.TrimSpace(content) == "" { + return + } + + peerKind := "channel" + peerID := channelID + if strings.HasPrefix(channelID, "D") { + peerKind = "direct" + peerID = senderID + } + + metadata := map[string]string{ + "message_ts": messageTS, + "channel_id": channelID, + "thread_ts": threadTS, + "platform": "slack", + "peer_kind": peerKind, + "peer_id": peerID, + "team_id": c.teamID, + } + + logger.DebugCF("slack", "Received message", map[string]interface{}{ + "sender_id": senderID, + "chat_id": chatID, + "preview": utils.Truncate(content, 50), + "has_thread": threadTS != "", + }) + + 
c.HandleMessage(senderID, chatID, content, mediaPaths, metadata) +} + +func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) { + if ev.User == c.botUserID { + return + } + + if !c.IsAllowed(ev.User) { + logger.DebugCF("slack", "Mention rejected by allowlist", map[string]interface{}{ + "user_id": ev.User, + }) + return + } + + senderID := ev.User + channelID := ev.Channel + threadTS := ev.ThreadTimeStamp + messageTS := ev.TimeStamp + + var chatID string + if threadTS != "" { + chatID = channelID + "/" + threadTS + } else { + chatID = channelID + "/" + messageTS + } + + c.api.AddReaction("eyes", slack.ItemRef{ + Channel: channelID, + Timestamp: messageTS, + }) + + c.pendingAcks.Store(chatID, slackMessageRef{ + ChannelID: channelID, + Timestamp: messageTS, + }) + + content := c.stripBotMention(ev.Text) + + if strings.TrimSpace(content) == "" { + return + } + + mentionPeerKind := "channel" + mentionPeerID := channelID + if strings.HasPrefix(channelID, "D") { + mentionPeerKind = "direct" + mentionPeerID = senderID + } + + metadata := map[string]string{ + "message_ts": messageTS, + "channel_id": channelID, + "thread_ts": threadTS, + "platform": "slack", + "is_mention": "true", + "peer_kind": mentionPeerKind, + "peer_id": mentionPeerID, + "team_id": c.teamID, + } + + c.HandleMessage(senderID, chatID, content, nil, metadata) +} + +func (c *SlackChannel) handleSlashCommand(event socketmode.Event) { + cmd, ok := event.Data.(slack.SlashCommand) + if !ok { + return + } + + if event.Request != nil { + c.socketClient.Ack(*event.Request) + } + + if !c.IsAllowed(cmd.UserID) { + logger.DebugCF("slack", "Slash command rejected by allowlist", map[string]interface{}{ + "user_id": cmd.UserID, + }) + return + } + + senderID := cmd.UserID + channelID := cmd.ChannelID + chatID := channelID + content := cmd.Text + + if strings.TrimSpace(content) == "" { + content = "help" + } + + metadata := map[string]string{ + "channel_id": channelID, + "platform": "slack", + 
"is_command": "true", + "trigger_id": cmd.TriggerID, + "peer_kind": "channel", + "peer_id": channelID, + "team_id": c.teamID, + } + + logger.DebugCF("slack", "Slash command received", map[string]interface{}{ + "sender_id": senderID, + "command": cmd.Command, + "text": utils.Truncate(content, 50), + }) + + c.HandleMessage(senderID, chatID, content, nil, metadata) +} + +func (c *SlackChannel) downloadSlackFile(file slack.File) string { + downloadURL := file.URLPrivateDownload + if downloadURL == "" { + downloadURL = file.URLPrivate + } + if downloadURL == "" { + logger.ErrorCF("slack", "No download URL for file", map[string]interface{}{"file_id": file.ID}) + return "" + } + + return utils.DownloadFile(downloadURL, file.Name, utils.DownloadOptions{ + LoggerPrefix: "slack", + ExtraHeaders: map[string]string{ + "Authorization": "Bearer " + c.config.BotToken, + }, + }) +} + +func (c *SlackChannel) stripBotMention(text string) string { + mention := fmt.Sprintf("<@%s>", c.botUserID) + text = strings.ReplaceAll(text, mention, "") + return strings.TrimSpace(text) +} + +func parseSlackChatID(chatID string) (channelID, threadTS string) { + parts := strings.SplitN(chatID, "/", 2) + channelID = parts[0] + if len(parts) > 1 { + threadTS = parts[1] + } + return +} diff --git a/pkg/channels/slack/slack_test.go b/pkg/channels/slack/slack_test.go new file mode 100644 index 000000000..30e0d2d73 --- /dev/null +++ b/pkg/channels/slack/slack_test.go @@ -0,0 +1,174 @@ +package slack + +import ( + "testing" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" +) + +func TestParseSlackChatID(t *testing.T) { + tests := []struct { + name string + chatID string + wantChanID string + wantThread string + }{ + { + name: "channel only", + chatID: "C123456", + wantChanID: "C123456", + wantThread: "", + }, + { + name: "channel with thread", + chatID: "C123456/1234567890.123456", + wantChanID: "C123456", + wantThread: "1234567890.123456", + }, + { + name: "DM channel", + 
chatID: "D987654", + wantChanID: "D987654", + wantThread: "", + }, + { + name: "empty string", + chatID: "", + wantChanID: "", + wantThread: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + chanID, threadTS := parseSlackChatID(tt.chatID) + if chanID != tt.wantChanID { + t.Errorf("parseSlackChatID(%q) channelID = %q, want %q", tt.chatID, chanID, tt.wantChanID) + } + if threadTS != tt.wantThread { + t.Errorf("parseSlackChatID(%q) threadTS = %q, want %q", tt.chatID, threadTS, tt.wantThread) + } + }) + } +} + +func TestStripBotMention(t *testing.T) { + ch := &SlackChannel{botUserID: "U12345BOT"} + + tests := []struct { + name string + input string + want string + }{ + { + name: "mention at start", + input: "<@U12345BOT> hello there", + want: "hello there", + }, + { + name: "mention in middle", + input: "hey <@U12345BOT> can you help", + want: "hey can you help", + }, + { + name: "no mention", + input: "hello world", + want: "hello world", + }, + { + name: "empty string", + input: "", + want: "", + }, + { + name: "only mention", + input: "<@U12345BOT>", + want: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := ch.stripBotMention(tt.input) + if got != tt.want { + t.Errorf("stripBotMention(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestNewSlackChannel(t *testing.T) { + msgBus := bus.NewMessageBus() + + t.Run("missing bot token", func(t *testing.T) { + cfg := config.SlackConfig{ + BotToken: "", + AppToken: "xapp-test", + } + _, err := NewSlackChannel(cfg, msgBus) + if err == nil { + t.Error("expected error for missing bot_token, got nil") + } + }) + + t.Run("missing app token", func(t *testing.T) { + cfg := config.SlackConfig{ + BotToken: "xoxb-test", + AppToken: "", + } + _, err := NewSlackChannel(cfg, msgBus) + if err == nil { + t.Error("expected error for missing app_token, got nil") + } + }) + + t.Run("valid config", func(t *testing.T) { + cfg := config.SlackConfig{ + 
BotToken: "xoxb-test", + AppToken: "xapp-test", + AllowFrom: []string{"U123"}, + } + ch, err := NewSlackChannel(cfg, msgBus) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if ch.Name() != "slack" { + t.Errorf("Name() = %q, want %q", ch.Name(), "slack") + } + if ch.IsRunning() { + t.Error("new channel should not be running") + } + }) +} + +func TestSlackChannelIsAllowed(t *testing.T) { + msgBus := bus.NewMessageBus() + + t.Run("empty allowlist allows all", func(t *testing.T) { + cfg := config.SlackConfig{ + BotToken: "xoxb-test", + AppToken: "xapp-test", + AllowFrom: []string{}, + } + ch, _ := NewSlackChannel(cfg, msgBus) + if !ch.IsAllowed("U_ANYONE") { + t.Error("empty allowlist should allow all users") + } + }) + + t.Run("allowlist restricts users", func(t *testing.T) { + cfg := config.SlackConfig{ + BotToken: "xoxb-test", + AppToken: "xapp-test", + AllowFrom: []string{"U_ALLOWED"}, + } + ch, _ := NewSlackChannel(cfg, msgBus) + if !ch.IsAllowed("U_ALLOWED") { + t.Error("allowed user should pass allowlist check") + } + if ch.IsAllowed("U_BLOCKED") { + t.Error("non-allowed user should be blocked") + } + }) +} diff --git a/pkg/channels/telegram/init.go b/pkg/channels/telegram/init.go new file mode 100644 index 000000000..ac87bb805 --- /dev/null +++ b/pkg/channels/telegram/init.go @@ -0,0 +1,13 @@ +package telegram + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("telegram", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewTelegramChannel(cfg, b) + }) +} diff --git a/pkg/channels/telegram/telegram.go b/pkg/channels/telegram/telegram.go new file mode 100644 index 000000000..f4c5108df --- /dev/null +++ b/pkg/channels/telegram/telegram.go @@ -0,0 +1,526 @@ +package telegram + +import ( + "context" + "fmt" + "net/http" + "net/url" + "os" + "regexp" + "strings" + "sync" + "time" + + th 
"github.com/mymmrac/telego/telegohandler" + + "github.com/mymmrac/telego" + "github.com/mymmrac/telego/telegohandler" + tu "github.com/mymmrac/telego/telegoutil" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/utils" + "github.com/sipeed/picoclaw/pkg/voice" +) + +type TelegramChannel struct { + *channels.BaseChannel + bot *telego.Bot + commands TelegramCommander + config *config.Config + chatIDs map[string]int64 + transcriber *voice.GroqTranscriber + placeholders sync.Map // chatID -> messageID + stopThinking sync.Map // chatID -> thinkingCancel +} + +type thinkingCancel struct { + fn context.CancelFunc +} + +func (c *thinkingCancel) Cancel() { + if c != nil && c.fn != nil { + c.fn() + } +} + +func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChannel, error) { + var opts []telego.BotOption + telegramCfg := cfg.Channels.Telegram + + if telegramCfg.Proxy != "" { + proxyURL, parseErr := url.Parse(telegramCfg.Proxy) + if parseErr != nil { + return nil, fmt.Errorf("invalid proxy URL %q: %w", telegramCfg.Proxy, parseErr) + } + opts = append(opts, telego.WithHTTPClient(&http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyURL(proxyURL), + }, + })) + } else if os.Getenv("HTTP_PROXY") != "" || os.Getenv("HTTPS_PROXY") != "" { + // Use environment proxy if configured + opts = append(opts, telego.WithHTTPClient(&http.Client{ + Transport: &http.Transport{ + Proxy: http.ProxyFromEnvironment, + }, + })) + } + + bot, err := telego.NewBot(telegramCfg.Token, opts...) 
+ if err != nil { + return nil, fmt.Errorf("failed to create telegram bot: %w", err) + } + + base := channels.NewBaseChannel("telegram", telegramCfg, bus, telegramCfg.AllowFrom) + + return &TelegramChannel{ + BaseChannel: base, + commands: NewTelegramCommands(bot, cfg), + bot: bot, + config: cfg, + chatIDs: make(map[string]int64), + transcriber: nil, + placeholders: sync.Map{}, + stopThinking: sync.Map{}, + }, nil +} + +func (c *TelegramChannel) SetTranscriber(transcriber *voice.GroqTranscriber) { + c.transcriber = transcriber +} + +func (c *TelegramChannel) Start(ctx context.Context) error { + logger.InfoC("telegram", "Starting Telegram bot (polling mode)...") + + updates, err := c.bot.UpdatesViaLongPolling(ctx, &telego.GetUpdatesParams{ + Timeout: 30, + }) + if err != nil { + return fmt.Errorf("failed to start long polling: %w", err) + } + + bh, err := telegohandler.NewBotHandler(c.bot, updates) + if err != nil { + return fmt.Errorf("failed to create bot handler: %w", err) + } + + bh.HandleMessage(func(ctx *th.Context, message telego.Message) error { + c.commands.Help(ctx, message) + return nil + }, th.CommandEqual("help")) + bh.HandleMessage(func(ctx *th.Context, message telego.Message) error { + return c.commands.Start(ctx, message) + }, th.CommandEqual("start")) + + bh.HandleMessage(func(ctx *th.Context, message telego.Message) error { + return c.commands.Show(ctx, message) + }, th.CommandEqual("show")) + + bh.HandleMessage(func(ctx *th.Context, message telego.Message) error { + return c.commands.List(ctx, message) + }, th.CommandEqual("list")) + + bh.HandleMessage(func(ctx *th.Context, message telego.Message) error { + return c.handleMessage(ctx, &message) + }, th.AnyMessage()) + + c.SetRunning(true) + logger.InfoCF("telegram", "Telegram bot connected", map[string]interface{}{ + "username": c.bot.Username(), + }) + + go bh.Start() + + go func() { + <-ctx.Done() + bh.Stop() + }() + + return nil +} +func (c *TelegramChannel) Stop(ctx context.Context) error { + 
logger.InfoC("telegram", "Stopping Telegram bot...") + c.SetRunning(false) + return nil +} + +func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + if !c.IsRunning() { + return fmt.Errorf("telegram bot not running") + } + + chatID, err := parseChatID(msg.ChatID) + if err != nil { + return fmt.Errorf("invalid chat ID: %w", err) + } + + // Stop thinking animation + if stop, ok := c.stopThinking.Load(msg.ChatID); ok { + if cf, ok := stop.(*thinkingCancel); ok && cf != nil { + cf.Cancel() + } + c.stopThinking.Delete(msg.ChatID) + } + + htmlContent := markdownToTelegramHTML(msg.Content) + + // Try to edit placeholder + if pID, ok := c.placeholders.Load(msg.ChatID); ok { + c.placeholders.Delete(msg.ChatID) + editMsg := tu.EditMessageText(tu.ID(chatID), pID.(int), htmlContent) + editMsg.ParseMode = telego.ModeHTML + + if _, err = c.bot.EditMessageText(ctx, editMsg); err == nil { + return nil + } + // Fallback to new message if edit fails + } + + tgMsg := tu.Message(tu.ID(chatID), htmlContent) + tgMsg.ParseMode = telego.ModeHTML + + if _, err = c.bot.SendMessage(ctx, tgMsg); err != nil { + logger.ErrorCF("telegram", "HTML parse failed, falling back to plain text", map[string]interface{}{ + "error": err.Error(), + }) + tgMsg.ParseMode = "" + _, err = c.bot.SendMessage(ctx, tgMsg) + return err + } + + return nil +} + +func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Message) error { + if message == nil { + return fmt.Errorf("message is nil") + } + + user := message.From + if user == nil { + return fmt.Errorf("message sender (user) is nil") + } + + senderID := fmt.Sprintf("%d", user.ID) + if user.Username != "" { + senderID = fmt.Sprintf("%d|%s", user.ID, user.Username) + } + + // 检查白名单,避免为被拒绝的用户下载附件 + if !c.IsAllowed(senderID) { + logger.DebugCF("telegram", "Message rejected by allowlist", map[string]interface{}{ + "user_id": senderID, + }) + return nil + } + + chatID := message.Chat.ID + c.chatIDs[senderID] = 
chatID + + content := "" + mediaPaths := []string{} + localFiles := []string{} // 跟踪需要清理的本地文件 + + // 确保临时文件在函数返回时被清理 + defer func() { + for _, file := range localFiles { + if err := os.Remove(file); err != nil { + logger.DebugCF("telegram", "Failed to cleanup temp file", map[string]interface{}{ + "file": file, + "error": err.Error(), + }) + } + } + }() + + if message.Text != "" { + content += message.Text + } + + if message.Caption != "" { + if content != "" { + content += "\n" + } + content += message.Caption + } + + if len(message.Photo) > 0 { + photo := message.Photo[len(message.Photo)-1] + photoPath := c.downloadPhoto(ctx, photo.FileID) + if photoPath != "" { + localFiles = append(localFiles, photoPath) + mediaPaths = append(mediaPaths, photoPath) + if content != "" { + content += "\n" + } + content += "[image: photo]" + } + } + + if message.Voice != nil { + voicePath := c.downloadFile(ctx, message.Voice.FileID, ".ogg") + if voicePath != "" { + localFiles = append(localFiles, voicePath) + mediaPaths = append(mediaPaths, voicePath) + + transcribedText := "" + if c.transcriber != nil && c.transcriber.IsAvailable() { + ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + result, err := c.transcriber.Transcribe(ctx, voicePath) + if err != nil { + logger.ErrorCF("telegram", "Voice transcription failed", map[string]interface{}{ + "error": err.Error(), + "path": voicePath, + }) + transcribedText = "[voice (transcription failed)]" + } else { + transcribedText = fmt.Sprintf("[voice transcription: %s]", result.Text) + logger.InfoCF("telegram", "Voice transcribed successfully", map[string]interface{}{ + "text": result.Text, + }) + } + } else { + transcribedText = "[voice]" + } + + if content != "" { + content += "\n" + } + content += transcribedText + } + } + + if message.Audio != nil { + audioPath := c.downloadFile(ctx, message.Audio.FileID, ".mp3") + if audioPath != "" { + localFiles = append(localFiles, audioPath) + mediaPaths = 
append(mediaPaths, audioPath) + if content != "" { + content += "\n" + } + content += "[audio]" + } + } + + if message.Document != nil { + docPath := c.downloadFile(ctx, message.Document.FileID, "") + if docPath != "" { + localFiles = append(localFiles, docPath) + mediaPaths = append(mediaPaths, docPath) + if content != "" { + content += "\n" + } + content += "[file]" + } + } + + if content == "" { + content = "[empty message]" + } + + logger.DebugCF("telegram", "Received message", map[string]interface{}{ + "sender_id": senderID, + "chat_id": fmt.Sprintf("%d", chatID), + "preview": utils.Truncate(content, 50), + }) + + // Thinking indicator + err := c.bot.SendChatAction(ctx, tu.ChatAction(tu.ID(chatID), telego.ChatActionTyping)) + if err != nil { + logger.ErrorCF("telegram", "Failed to send chat action", map[string]interface{}{ + "error": err.Error(), + }) + } + + // Stop any previous thinking animation + chatIDStr := fmt.Sprintf("%d", chatID) + if prevStop, ok := c.stopThinking.Load(chatIDStr); ok { + if cf, ok := prevStop.(*thinkingCancel); ok && cf != nil { + cf.Cancel() + } + } + + // Create cancel function for thinking state + _, thinkCancel := context.WithTimeout(ctx, 5*time.Minute) + c.stopThinking.Store(chatIDStr, &thinkingCancel{fn: thinkCancel}) + + pMsg, err := c.bot.SendMessage(ctx, tu.Message(tu.ID(chatID), "Thinking... 
💭")) + if err == nil { + pID := pMsg.MessageID + c.placeholders.Store(chatIDStr, pID) + } + + peerKind := "direct" + peerID := fmt.Sprintf("%d", user.ID) + if message.Chat.Type != "private" { + peerKind = "group" + peerID = fmt.Sprintf("%d", chatID) + } + + metadata := map[string]string{ + "message_id": fmt.Sprintf("%d", message.MessageID), + "user_id": fmt.Sprintf("%d", user.ID), + "username": user.Username, + "first_name": user.FirstName, + "is_group": fmt.Sprintf("%t", message.Chat.Type != "private"), + "peer_kind": peerKind, + "peer_id": peerID, + } + + c.HandleMessage(fmt.Sprintf("%d", user.ID), fmt.Sprintf("%d", chatID), content, mediaPaths, metadata) + return nil +} + +func (c *TelegramChannel) downloadPhoto(ctx context.Context, fileID string) string { + file, err := c.bot.GetFile(ctx, &telego.GetFileParams{FileID: fileID}) + if err != nil { + logger.ErrorCF("telegram", "Failed to get photo file", map[string]interface{}{ + "error": err.Error(), + }) + return "" + } + + return c.downloadFileWithInfo(file, ".jpg") +} + +func (c *TelegramChannel) downloadFileWithInfo(file *telego.File, ext string) string { + if file.FilePath == "" { + return "" + } + + url := c.bot.FileDownloadURL(file.FilePath) + logger.DebugCF("telegram", "File URL", map[string]interface{}{"url": url}) + + // Use FilePath as filename for better identification + filename := file.FilePath + ext + return utils.DownloadFile(url, filename, utils.DownloadOptions{ + LoggerPrefix: "telegram", + }) +} + +func (c *TelegramChannel) downloadFile(ctx context.Context, fileID, ext string) string { + file, err := c.bot.GetFile(ctx, &telego.GetFileParams{FileID: fileID}) + if err != nil { + logger.ErrorCF("telegram", "Failed to get file", map[string]interface{}{ + "error": err.Error(), + }) + return "" + } + + return c.downloadFileWithInfo(file, ext) +} + +func parseChatID(chatIDStr string) (int64, error) { + var id int64 + _, err := fmt.Sscanf(chatIDStr, "%d", &id) + return id, err +} + +func 
markdownToTelegramHTML(text string) string { + if text == "" { + return "" + } + + codeBlocks := extractCodeBlocks(text) + text = codeBlocks.text + + inlineCodes := extractInlineCodes(text) + text = inlineCodes.text + + text = regexp.MustCompile(`^#{1,6}\s+(.+)$`).ReplaceAllString(text, "$1") + + text = regexp.MustCompile(`^>\s*(.*)$`).ReplaceAllString(text, "$1") + + text = escapeHTML(text) + + text = regexp.MustCompile(`\[([^\]]+)\]\(([^)]+)\)`).ReplaceAllString(text, `$1`) + + text = regexp.MustCompile(`\*\*(.+?)\*\*`).ReplaceAllString(text, "$1") + + text = regexp.MustCompile(`__(.+?)__`).ReplaceAllString(text, "$1") + + reItalic := regexp.MustCompile(`_([^_]+)_`) + text = reItalic.ReplaceAllStringFunc(text, func(s string) string { + match := reItalic.FindStringSubmatch(s) + if len(match) < 2 { + return s + } + return "" + match[1] + "" + }) + + text = regexp.MustCompile(`~~(.+?)~~`).ReplaceAllString(text, "$1") + + text = regexp.MustCompile(`^[-*]\s+`).ReplaceAllString(text, "• ") + + for i, code := range inlineCodes.codes { + escaped := escapeHTML(code) + text = strings.ReplaceAll(text, fmt.Sprintf("\x00IC%d\x00", i), fmt.Sprintf("%s", escaped)) + } + + for i, code := range codeBlocks.codes { + escaped := escapeHTML(code) + text = strings.ReplaceAll(text, fmt.Sprintf("\x00CB%d\x00", i), fmt.Sprintf("
%s
", escaped)) + } + + return text +} + +type codeBlockMatch struct { + text string + codes []string +} + +func extractCodeBlocks(text string) codeBlockMatch { + re := regexp.MustCompile("```[\\w]*\\n?([\\s\\S]*?)```") + matches := re.FindAllStringSubmatch(text, -1) + + codes := make([]string, 0, len(matches)) + for _, match := range matches { + codes = append(codes, match[1]) + } + + i := 0 + text = re.ReplaceAllStringFunc(text, func(m string) string { + placeholder := fmt.Sprintf("\x00CB%d\x00", i) + i++ + return placeholder + }) + + return codeBlockMatch{text: text, codes: codes} +} + +type inlineCodeMatch struct { + text string + codes []string +} + +func extractInlineCodes(text string) inlineCodeMatch { + re := regexp.MustCompile("`([^`]+)`") + matches := re.FindAllStringSubmatch(text, -1) + + codes := make([]string, 0, len(matches)) + for _, match := range matches { + codes = append(codes, match[1]) + } + + i := 0 + text = re.ReplaceAllStringFunc(text, func(m string) string { + placeholder := fmt.Sprintf("\x00IC%d\x00", i) + i++ + return placeholder + }) + + return inlineCodeMatch{text: text, codes: codes} +} + +func escapeHTML(text string) string { + text = strings.ReplaceAll(text, "&", "&") + text = strings.ReplaceAll(text, "<", "<") + text = strings.ReplaceAll(text, ">", ">") + return text +} diff --git a/pkg/channels/telegram/telegram_commands.go b/pkg/channels/telegram/telegram_commands.go new file mode 100644 index 000000000..4bf1b3aff --- /dev/null +++ b/pkg/channels/telegram/telegram_commands.go @@ -0,0 +1,153 @@ +package telegram + +import ( + "context" + "fmt" + "strings" + + "github.com/mymmrac/telego" + "github.com/sipeed/picoclaw/pkg/config" +) + +type TelegramCommander interface { + Help(ctx context.Context, message telego.Message) error + Start(ctx context.Context, message telego.Message) error + Show(ctx context.Context, message telego.Message) error + List(ctx context.Context, message telego.Message) error +} + +type cmd struct { + bot 
*telego.Bot + config *config.Config +} + +func NewTelegramCommands(bot *telego.Bot, cfg *config.Config) TelegramCommander { + return &cmd{ + bot: bot, + config: cfg, + } +} + +func commandArgs(text string) string { + parts := strings.SplitN(text, " ", 2) + if len(parts) < 2 { + return "" + } + return strings.TrimSpace(parts[1]) +} +func (c *cmd) Help(ctx context.Context, message telego.Message) error { + msg := `/start - Start the bot +/help - Show this help message +/show [model|channel] - Show current configuration +/list [models|channels] - List available options + ` + _, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{ + ChatID: telego.ChatID{ID: message.Chat.ID}, + Text: msg, + ReplyParameters: &telego.ReplyParameters{ + MessageID: message.MessageID, + }, + }) + return err +} + +func (c *cmd) Start(ctx context.Context, message telego.Message) error { + _, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{ + ChatID: telego.ChatID{ID: message.Chat.ID}, + Text: "Hello! I am PicoClaw 🦞", + ReplyParameters: &telego.ReplyParameters{ + MessageID: message.MessageID, + }, + }) + return err +} + +func (c *cmd) Show(ctx context.Context, message telego.Message) error { + args := commandArgs(message.Text) + if args == "" { + _, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{ + ChatID: telego.ChatID{ID: message.Chat.ID}, + Text: "Usage: /show [model|channel]", + ReplyParameters: &telego.ReplyParameters{ + MessageID: message.MessageID, + }, + }) + return err + } + + var response string + switch args { + case "model": + response = fmt.Sprintf("Current Model: %s (Provider: %s)", + c.config.Agents.Defaults.Model, + c.config.Agents.Defaults.Provider) + case "channel": + response = "Current Channel: telegram" + default: + response = fmt.Sprintf("Unknown parameter: %s. 
Try 'model' or 'channel'.", args) + } + + _, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{ + ChatID: telego.ChatID{ID: message.Chat.ID}, + Text: response, + ReplyParameters: &telego.ReplyParameters{ + MessageID: message.MessageID, + }, + }) + return err +} +func (c *cmd) List(ctx context.Context, message telego.Message) error { + args := commandArgs(message.Text) + if args == "" { + _, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{ + ChatID: telego.ChatID{ID: message.Chat.ID}, + Text: "Usage: /list [models|channels]", + ReplyParameters: &telego.ReplyParameters{ + MessageID: message.MessageID, + }, + }) + return err + } + + var response string + switch args { + case "models": + provider := c.config.Agents.Defaults.Provider + if provider == "" { + provider = "configured default" + } + response = fmt.Sprintf("Configured Model: %s\nProvider: %s\n\nTo change models, update config.yaml", + c.config.Agents.Defaults.Model, provider) + + case "channels": + var enabled []string + if c.config.Channels.Telegram.Enabled { + enabled = append(enabled, "telegram") + } + if c.config.Channels.WhatsApp.Enabled { + enabled = append(enabled, "whatsapp") + } + if c.config.Channels.Feishu.Enabled { + enabled = append(enabled, "feishu") + } + if c.config.Channels.Discord.Enabled { + enabled = append(enabled, "discord") + } + if c.config.Channels.Slack.Enabled { + enabled = append(enabled, "slack") + } + response = fmt.Sprintf("Enabled Channels:\n- %s", strings.Join(enabled, "\n- ")) + + default: + response = fmt.Sprintf("Unknown parameter: %s. 
Try 'models' or 'channels'.", args) + } + + _, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{ + ChatID: telego.ChatID{ID: message.Chat.ID}, + Text: response, + ReplyParameters: &telego.ReplyParameters{ + MessageID: message.MessageID, + }, + }) + return err +} diff --git a/pkg/channels/wecom/app.go b/pkg/channels/wecom/app.go new file mode 100644 index 000000000..85c017958 --- /dev/null +++ b/pkg/channels/wecom/app.go @@ -0,0 +1,636 @@ +package wecom + +import ( + "bytes" + "context" + "encoding/json" + "encoding/xml" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/utils" +) + +const ( + wecomAPIBase = "https://qyapi.weixin.qq.com" +) + +// WeComAppChannel implements the Channel interface for WeCom App (企业微信自建应用) +type WeComAppChannel struct { + *channels.BaseChannel + config config.WeComAppConfig + server *http.Server + accessToken string + tokenExpiry time.Time + tokenMu sync.RWMutex + ctx context.Context + cancel context.CancelFunc + processedMsgs map[string]bool // Message deduplication: msg_id -> processed + msgMu sync.RWMutex +} + +// WeComXMLMessage represents the XML message structure from WeCom +type WeComXMLMessage struct { + XMLName xml.Name `xml:"xml"` + ToUserName string `xml:"ToUserName"` + FromUserName string `xml:"FromUserName"` + CreateTime int64 `xml:"CreateTime"` + MsgType string `xml:"MsgType"` + Content string `xml:"Content"` + MsgId int64 `xml:"MsgId"` + AgentID int64 `xml:"AgentID"` + PicUrl string `xml:"PicUrl"` + MediaId string `xml:"MediaId"` + Format string `xml:"Format"` + ThumbMediaId string `xml:"ThumbMediaId"` + LocationX float64 `xml:"Location_X"` + LocationY float64 `xml:"Location_Y"` + Scale int `xml:"Scale"` + Label string `xml:"Label"` + Title string `xml:"Title"` + Description string 
// WeComTextMessage is the JSON request body for sending a plain-text message.
// sendTextMessage fills ToUser with the recipient's user ID, MsgType with
// "text", AgentID from the channel configuration and Text.Content with the
// message body. Safe is omitted from the JSON when zero; nothing visible in
// this file sets it.
type WeComTextMessage struct {
	ToUser  string `json:"touser"`
	MsgType string `json:"msgtype"`
	AgentID int64  `json:"agentid"`
	Text    struct {
		Content string `json:"content"`
	} `json:"text"`
	Safe int `json:"safe,omitempty"`
}

// WeComMarkdownMessage is the JSON request body for sending a markdown
// message. sendMarkdownMessage sets MsgType to "markdown" and places the
// markdown source in Markdown.Content.
type WeComMarkdownMessage struct {
	ToUser   string `json:"touser"`
	MsgType  string `json:"msgtype"`
	AgentID  int64  `json:"agentid"`
	Markdown struct {
		Content string `json:"content"`
	} `json:"markdown"`
}

// WeComImageMessage is the JSON request body for sending an image message;
// Image.MediaID names the image to send.
// NOTE(review): no sender for this type is visible in this part of the file —
// presumably MediaID refers to previously uploaded media; confirm.
type WeComImageMessage struct {
	ToUser  string `json:"touser"`
	MsgType string `json:"msgtype"`
	AgentID int64  `json:"agentid"`
	Image   struct {
		MediaID string `json:"media_id"`
	} `json:"image"`
}

// WeComAccessTokenResponse is the JSON reply decoded by refreshAccessToken.
// A non-zero ErrCode is treated as failure (ErrMsg carries the reason).
// ExpiresIn is interpreted as a lifetime in seconds.
type WeComAccessTokenResponse struct {
	ErrCode     int    `json:"errcode"`
	ErrMsg      string `json:"errmsg"`
	AccessToken string `json:"access_token"`
	ExpiresIn   int    `json:"expires_in"`
}

// WeComSendMessageResponse is the JSON reply decoded after a message-send
// request. A non-zero ErrCode is treated as failure. The Invalid* fields are
// decoded but not inspected by the senders in this file.
type WeComSendMessageResponse struct {
	ErrCode      int    `json:"errcode"`
	ErrMsg       string `json:"errmsg"`
	InvalidUser  string `json:"invaliduser"`
	InvalidParty string `json:"invalidparty"`
	InvalidTag   string `json:"invalidtag"`
}

// PKCS7Padding is an empty marker type.
// NOTE(review): it has no fields, and no methods on it are visible in this
// part of the file; the padding logic exercised by the tests is the
// package-level pkcs7Unpad. Confirm whether this type is still needed.
type PKCS7Padding struct{}
return &WeComAppChannel{ + BaseChannel: base, + config: cfg, + processedMsgs: make(map[string]bool), + }, nil +} + +// Name returns the channel name +func (c *WeComAppChannel) Name() string { + return "wecom_app" +} + +// Start initializes the WeCom App channel with HTTP webhook server +func (c *WeComAppChannel) Start(ctx context.Context) error { + logger.InfoC("wecom_app", "Starting WeCom App channel...") + + c.ctx, c.cancel = context.WithCancel(ctx) + + // Get initial access token + if err := c.refreshAccessToken(); err != nil { + logger.WarnCF("wecom_app", "Failed to get initial access token", map[string]interface{}{ + "error": err.Error(), + }) + } + + // Start token refresh goroutine + go c.tokenRefreshLoop() + + // Setup HTTP server for webhook + mux := http.NewServeMux() + webhookPath := c.config.WebhookPath + if webhookPath == "" { + webhookPath = "/webhook/wecom-app" + } + mux.HandleFunc(webhookPath, c.handleWebhook) + + // Health check endpoint + mux.HandleFunc("/health/wecom-app", c.handleHealth) + + addr := fmt.Sprintf("%s:%d", c.config.WebhookHost, c.config.WebhookPort) + c.server = &http.Server{ + Addr: addr, + Handler: mux, + } + + c.SetRunning(true) + logger.InfoCF("wecom_app", "WeCom App channel started", map[string]interface{}{ + "address": addr, + "path": webhookPath, + }) + + // Start server in goroutine + go func() { + if err := c.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + logger.ErrorCF("wecom_app", "HTTP server error", map[string]interface{}{ + "error": err.Error(), + }) + } + }() + + return nil +} + +// Stop gracefully stops the WeCom App channel +func (c *WeComAppChannel) Stop(ctx context.Context) error { + logger.InfoC("wecom_app", "Stopping WeCom App channel...") + + if c.cancel != nil { + c.cancel() + } + + if c.server != nil { + shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + c.server.Shutdown(shutdownCtx) + } + + c.SetRunning(false) + logger.InfoC("wecom_app", "WeCom App 
channel stopped") + return nil +} + +// Send sends a message to WeCom user proactively using access token +func (c *WeComAppChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + if !c.IsRunning() { + return fmt.Errorf("wecom_app channel not running") + } + + accessToken := c.getAccessToken() + if accessToken == "" { + return fmt.Errorf("no valid access token available") + } + + logger.DebugCF("wecom_app", "Sending message", map[string]interface{}{ + "chat_id": msg.ChatID, + "preview": utils.Truncate(msg.Content, 100), + }) + + return c.sendTextMessage(ctx, accessToken, msg.ChatID, msg.Content) +} + +// handleWebhook handles incoming webhook requests from WeCom +func (c *WeComAppChannel) handleWebhook(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + // Log all incoming requests for debugging + logger.DebugCF("wecom_app", "Received webhook request", map[string]interface{}{ + "method": r.Method, + "url": r.URL.String(), + "path": r.URL.Path, + "query": r.URL.RawQuery, + }) + + if r.Method == http.MethodGet { + // Handle verification request + c.handleVerification(ctx, w, r) + return + } + + if r.Method == http.MethodPost { + // Handle message callback + c.handleMessageCallback(ctx, w, r) + return + } + + logger.WarnCF("wecom_app", "Method not allowed", map[string]interface{}{ + "method": r.Method, + }) + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) +} + +// handleVerification handles the URL verification request from WeCom +func (c *WeComAppChannel) handleVerification(ctx context.Context, w http.ResponseWriter, r *http.Request) { + query := r.URL.Query() + msgSignature := query.Get("msg_signature") + timestamp := query.Get("timestamp") + nonce := query.Get("nonce") + echostr := query.Get("echostr") + + logger.DebugCF("wecom_app", "Handling verification request", map[string]interface{}{ + "msg_signature": msgSignature, + "timestamp": timestamp, + "nonce": nonce, + "echostr": echostr, + "corp_id": c.config.CorpID, + 
}) + + if msgSignature == "" || timestamp == "" || nonce == "" || echostr == "" { + logger.ErrorC("wecom_app", "Missing parameters in verification request") + http.Error(w, "Missing parameters", http.StatusBadRequest) + return + } + + // Verify signature + if !verifySignature(c.config.Token, msgSignature, timestamp, nonce, echostr) { + logger.WarnCF("wecom_app", "Signature verification failed", map[string]interface{}{ + "token": c.config.Token, + "msg_signature": msgSignature, + "timestamp": timestamp, + "nonce": nonce, + }) + http.Error(w, "Invalid signature", http.StatusForbidden) + return + } + + logger.DebugC("wecom_app", "Signature verification passed") + + // Decrypt echostr with CorpID verification + // For WeCom App (自建应用), receiveid should be corp_id + logger.DebugCF("wecom_app", "Attempting to decrypt echostr", map[string]interface{}{ + "encoding_aes_key": c.config.EncodingAESKey, + "corp_id": c.config.CorpID, + }) + decryptedEchoStr, err := decryptMessageWithVerify(echostr, c.config.EncodingAESKey, c.config.CorpID) + if err != nil { + logger.ErrorCF("wecom_app", "Failed to decrypt echostr", map[string]interface{}{ + "error": err.Error(), + "encoding_aes_key": c.config.EncodingAESKey, + "corp_id": c.config.CorpID, + }) + http.Error(w, "Decryption failed", http.StatusInternalServerError) + return + } + + logger.DebugCF("wecom_app", "Successfully decrypted echostr", map[string]interface{}{ + "decrypted": decryptedEchoStr, + }) + + // Remove BOM and whitespace as per WeCom documentation + // The response must be plain text without quotes, BOM, or newlines + decryptedEchoStr = strings.TrimSpace(decryptedEchoStr) + decryptedEchoStr = strings.TrimPrefix(decryptedEchoStr, "\xef\xbb\xbf") // Remove UTF-8 BOM + w.Write([]byte(decryptedEchoStr)) +} + +// handleMessageCallback handles incoming messages from WeCom +func (c *WeComAppChannel) handleMessageCallback(ctx context.Context, w http.ResponseWriter, r *http.Request) { + query := r.URL.Query() + msgSignature := 
query.Get("msg_signature") + timestamp := query.Get("timestamp") + nonce := query.Get("nonce") + + if msgSignature == "" || timestamp == "" || nonce == "" { + http.Error(w, "Missing parameters", http.StatusBadRequest) + return + } + + // Read request body + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "Failed to read body", http.StatusBadRequest) + return + } + defer r.Body.Close() + + // Parse XML to get encrypted message + var encryptedMsg struct { + XMLName xml.Name `xml:"xml"` + ToUserName string `xml:"ToUserName"` + Encrypt string `xml:"Encrypt"` + AgentID string `xml:"AgentID"` + } + + if err := xml.Unmarshal(body, &encryptedMsg); err != nil { + logger.ErrorCF("wecom_app", "Failed to parse XML", map[string]interface{}{ + "error": err.Error(), + }) + http.Error(w, "Invalid XML", http.StatusBadRequest) + return + } + + // Verify signature + if !verifySignature(c.config.Token, msgSignature, timestamp, nonce, encryptedMsg.Encrypt) { + logger.WarnC("wecom_app", "Message signature verification failed") + http.Error(w, "Invalid signature", http.StatusForbidden) + return + } + + // Decrypt message with CorpID verification + // For WeCom App (自建应用), receiveid should be corp_id + decryptedMsg, err := decryptMessageWithVerify(encryptedMsg.Encrypt, c.config.EncodingAESKey, c.config.CorpID) + if err != nil { + logger.ErrorCF("wecom_app", "Failed to decrypt message", map[string]interface{}{ + "error": err.Error(), + }) + http.Error(w, "Decryption failed", http.StatusInternalServerError) + return + } + + // Parse decrypted XML message + var msg WeComXMLMessage + if err := xml.Unmarshal([]byte(decryptedMsg), &msg); err != nil { + logger.ErrorCF("wecom_app", "Failed to parse decrypted message", map[string]interface{}{ + "error": err.Error(), + }) + http.Error(w, "Invalid message format", http.StatusBadRequest) + return + } + + // Process the message with context + go c.processMessage(ctx, msg) + + // Return success response immediately + // WeCom App 
requires response within configured timeout (default 5 seconds) + w.Write([]byte("success")) +} + +// processMessage processes the received message +func (c *WeComAppChannel) processMessage(ctx context.Context, msg WeComXMLMessage) { + // Skip non-text messages for now (can be extended) + if msg.MsgType != "text" && msg.MsgType != "image" && msg.MsgType != "voice" { + logger.DebugCF("wecom_app", "Skipping non-supported message type", map[string]interface{}{ + "msg_type": msg.MsgType, + }) + return + } + + // Message deduplication: Use msg_id to prevent duplicate processing + // As per WeCom documentation, use msg_id for deduplication + msgID := fmt.Sprintf("%d", msg.MsgId) + c.msgMu.Lock() + if c.processedMsgs[msgID] { + c.msgMu.Unlock() + logger.DebugCF("wecom_app", "Skipping duplicate message", map[string]interface{}{ + "msg_id": msgID, + }) + return + } + c.processedMsgs[msgID] = true + c.msgMu.Unlock() + + // Clean up old messages periodically (keep last 1000) + if len(c.processedMsgs) > 1000 { + c.msgMu.Lock() + c.processedMsgs = make(map[string]bool) + c.msgMu.Unlock() + } + + senderID := msg.FromUserName + chatID := senderID // WeCom App uses user ID as chat ID for direct messages + + // Build metadata + // WeCom App only supports direct messages (private chat) + metadata := map[string]string{ + "msg_type": msg.MsgType, + "msg_id": fmt.Sprintf("%d", msg.MsgId), + "agent_id": fmt.Sprintf("%d", msg.AgentID), + "platform": "wecom_app", + "media_id": msg.MediaId, + "create_time": fmt.Sprintf("%d", msg.CreateTime), + "peer_kind": "direct", + "peer_id": senderID, + } + + content := msg.Content + + logger.DebugCF("wecom_app", "Received message", map[string]interface{}{ + "sender_id": senderID, + "msg_type": msg.MsgType, + "preview": utils.Truncate(content, 50), + }) + + // Handle the message through the base channel + c.HandleMessage(senderID, chatID, content, nil, metadata) +} + +// tokenRefreshLoop periodically refreshes the access token +func (c 
*WeComAppChannel) tokenRefreshLoop() { + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-c.ctx.Done(): + return + case <-ticker.C: + if err := c.refreshAccessToken(); err != nil { + logger.ErrorCF("wecom_app", "Failed to refresh access token", map[string]interface{}{ + "error": err.Error(), + }) + } + } + } +} + +// refreshAccessToken gets a new access token from WeCom API +func (c *WeComAppChannel) refreshAccessToken() error { + apiURL := fmt.Sprintf("%s/cgi-bin/gettoken?corpid=%s&corpsecret=%s", + wecomAPIBase, url.QueryEscape(c.config.CorpID), url.QueryEscape(c.config.CorpSecret)) + + resp, err := http.Get(apiURL) + if err != nil { + return fmt.Errorf("failed to request access token: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("failed to read response: %w", err) + } + + var tokenResp WeComAccessTokenResponse + if err := json.Unmarshal(body, &tokenResp); err != nil { + return fmt.Errorf("failed to parse response: %w", err) + } + + if tokenResp.ErrCode != 0 { + return fmt.Errorf("API error: %s (code: %d)", tokenResp.ErrMsg, tokenResp.ErrCode) + } + + c.tokenMu.Lock() + c.accessToken = tokenResp.AccessToken + c.tokenExpiry = time.Now().Add(time.Duration(tokenResp.ExpiresIn-300) * time.Second) // Refresh 5 minutes early + c.tokenMu.Unlock() + + logger.DebugC("wecom_app", "Access token refreshed successfully") + return nil +} + +// getAccessToken returns the current valid access token +func (c *WeComAppChannel) getAccessToken() string { + c.tokenMu.RLock() + defer c.tokenMu.RUnlock() + + if time.Now().After(c.tokenExpiry) { + return "" + } + + return c.accessToken +} + +// sendTextMessage sends a text message to a user +func (c *WeComAppChannel) sendTextMessage(ctx context.Context, accessToken, userID, content string) error { + apiURL := fmt.Sprintf("%s/cgi-bin/message/send?access_token=%s", wecomAPIBase, accessToken) + + msg := WeComTextMessage{ + 
ToUser: userID, + MsgType: "text", + AgentID: c.config.AgentID, + } + msg.Text.Content = content + + jsonData, err := json.Marshal(msg) + if err != nil { + return fmt.Errorf("failed to marshal message: %w", err) + } + + // Use configurable timeout (default 5 seconds) + timeout := c.config.ReplyTimeout + if timeout <= 0 { + timeout = 5 + } + + reqCtx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, apiURL, bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{Timeout: time.Duration(timeout) * time.Second} + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("failed to send message: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("failed to read response: %w", err) + } + + var sendResp WeComSendMessageResponse + if err := json.Unmarshal(body, &sendResp); err != nil { + return fmt.Errorf("failed to parse response: %w", err) + } + + if sendResp.ErrCode != 0 { + return fmt.Errorf("API error: %s (code: %d)", sendResp.ErrMsg, sendResp.ErrCode) + } + + return nil +} + +// sendMarkdownMessage sends a markdown message to a user +func (c *WeComAppChannel) sendMarkdownMessage(ctx context.Context, accessToken, userID, content string) error { + apiURL := fmt.Sprintf("%s/cgi-bin/message/send?access_token=%s", wecomAPIBase, accessToken) + + msg := WeComMarkdownMessage{ + ToUser: userID, + MsgType: "markdown", + AgentID: c.config.AgentID, + } + msg.Markdown.Content = content + + jsonData, err := json.Marshal(msg) + if err != nil { + return fmt.Errorf("failed to marshal message: %w", err) + } + + // Use configurable timeout (default 5 seconds) + timeout := c.config.ReplyTimeout + if timeout <= 0 { + timeout = 5 + } + + reqCtx, cancel := context.WithTimeout(ctx, 
time.Duration(timeout)*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, apiURL, bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{Timeout: time.Duration(timeout) * time.Second} + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("failed to send message: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("failed to read response: %w", err) + } + + var sendResp WeComSendMessageResponse + if err := json.Unmarshal(body, &sendResp); err != nil { + return fmt.Errorf("failed to parse response: %w", err) + } + + if sendResp.ErrCode != 0 { + return fmt.Errorf("API error: %s (code: %d)", sendResp.ErrMsg, sendResp.ErrCode) + } + + return nil +} + +// handleHealth handles health check requests +func (c *WeComAppChannel) handleHealth(w http.ResponseWriter, r *http.Request) { + status := map[string]interface{}{ + "status": "ok", + "running": c.IsRunning(), + "has_token": c.getAccessToken() != "", + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(status) +} diff --git a/pkg/channels/wecom/app_test.go b/pkg/channels/wecom/app_test.go new file mode 100644 index 000000000..d9817fd49 --- /dev/null +++ b/pkg/channels/wecom/app_test.go @@ -0,0 +1,1086 @@ +package wecom + +import ( + "bytes" + "context" + "crypto/aes" + "crypto/cipher" + "crypto/sha1" + "encoding/base64" + "encoding/binary" + "encoding/json" + "encoding/xml" + "fmt" + "net/http" + "net/http/httptest" + "sort" + "strings" + "testing" + "time" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" +) + +// generateTestAESKeyApp generates a valid test AES key for WeCom App +func generateTestAESKeyApp() string { + // AES key needs to be 32 bytes (256 bits) for AES-256 + key := make([]byte, 32) + 
// encryptTestMessageApp builds a WeCom-style encrypted payload for tests.
//
// aesKey is the 43-character unpadded base64 key; "=" is appended before
// decoding. The plaintext frame is: 16 fixed bytes (1..16), the message
// length as a 4-byte big-endian integer, the message, then the literal
// receiver ID "test_corp_id", PKCS#7-padded to the AES block size. The frame
// is encrypted with AES-CBC using the first 16 key bytes as IV and returned
// base64-encoded. An error is returned when the key cannot be decoded or is
// not a valid AES key.
func encryptTestMessageApp(message, aesKey string) (string, error) {
	rawKey, err := base64.StdEncoding.DecodeString(aesKey + "=")
	if err != nil {
		return "", err
	}

	var frame bytes.Buffer

	// Deterministic stand-in for the 16 random prefix bytes.
	for b := byte(1); b <= 16; b++ {
		frame.WriteByte(b)
	}

	var size [4]byte
	binary.BigEndian.PutUint32(size[:], uint32(len(message)))
	frame.Write(size[:])
	frame.WriteString(message)
	frame.WriteString("test_corp_id")

	// PKCS#7: always pad, a full block when already aligned.
	padLen := aes.BlockSize - frame.Len()%aes.BlockSize
	frame.Write(bytes.Repeat([]byte{byte(padLen)}, padLen))

	blk, err := aes.NewCipher(rawKey)
	if err != nil {
		return "", err
	}

	plain := frame.Bytes()
	sealed := make([]byte, len(plain))
	cipher.NewCBCEncrypter(blk, rawKey[:aes.BlockSize]).CryptBlocks(sealed, plain)

	return base64.StdEncoding.EncodeToString(sealed), nil
}
t.Error("new channel should not be running") + } + }) +} + +func TestWeComAppChannelIsAllowed(t *testing.T) { + msgBus := bus.NewMessageBus() + + t.Run("empty allowlist allows all", func(t *testing.T) { + cfg := config.WeComAppConfig{ + CorpID: "test_corp_id", + CorpSecret: "test_secret", + AgentID: 1000002, + AllowFrom: []string{}, + } + ch, _ := NewWeComAppChannel(cfg, msgBus) + if !ch.IsAllowed("any_user") { + t.Error("empty allowlist should allow all users") + } + }) + + t.Run("allowlist restricts users", func(t *testing.T) { + cfg := config.WeComAppConfig{ + CorpID: "test_corp_id", + CorpSecret: "test_secret", + AgentID: 1000002, + AllowFrom: []string{"allowed_user"}, + } + ch, _ := NewWeComAppChannel(cfg, msgBus) + if !ch.IsAllowed("allowed_user") { + t.Error("allowed user should pass allowlist check") + } + if ch.IsAllowed("blocked_user") { + t.Error("non-allowed user should be blocked") + } + }) +} + +func TestWeComAppVerifySignature(t *testing.T) { + msgBus := bus.NewMessageBus() + cfg := config.WeComAppConfig{ + CorpID: "test_corp_id", + CorpSecret: "test_secret", + AgentID: 1000002, + Token: "test_token", + } + ch, _ := NewWeComAppChannel(cfg, msgBus) + + t.Run("valid signature", func(t *testing.T) { + timestamp := "1234567890" + nonce := "test_nonce" + msgEncrypt := "test_message" + expectedSig := generateSignatureApp("test_token", timestamp, nonce, msgEncrypt) + + if !verifySignature(ch.config.Token, expectedSig, timestamp, nonce, msgEncrypt) { + t.Error("valid signature should pass verification") + } + }) + + t.Run("invalid signature", func(t *testing.T) { + timestamp := "1234567890" + nonce := "test_nonce" + msgEncrypt := "test_message" + + if verifySignature(ch.config.Token, "invalid_sig", timestamp, nonce, msgEncrypt) { + t.Error("invalid signature should fail verification") + } + }) + + t.Run("empty token skips verification", func(t *testing.T) { + cfgEmpty := config.WeComAppConfig{ + CorpID: "test_corp_id", + CorpSecret: "test_secret", + AgentID: 
1000002, + Token: "", + } + chEmpty, _ := NewWeComAppChannel(cfgEmpty, msgBus) + + if !verifySignature(chEmpty.config.Token, "any_sig", "any_ts", "any_nonce", "any_msg") { + t.Error("empty token should skip verification and return true") + } + }) +} + +func TestWeComAppDecryptMessage(t *testing.T) { + msgBus := bus.NewMessageBus() + + t.Run("decrypt without AES key", func(t *testing.T) { + cfg := config.WeComAppConfig{ + CorpID: "test_corp_id", + CorpSecret: "test_secret", + AgentID: 1000002, + EncodingAESKey: "", + } + ch, _ := NewWeComAppChannel(cfg, msgBus) + + // Without AES key, message should be base64 decoded only + plainText := "hello world" + encoded := base64.StdEncoding.EncodeToString([]byte(plainText)) + + result, err := decryptMessage(encoded, ch.config.EncodingAESKey) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result != plainText { + t.Errorf("decryptMessage() = %q, want %q", result, plainText) + } + }) + + t.Run("decrypt with AES key", func(t *testing.T) { + aesKey := generateTestAESKeyApp() + cfg := config.WeComAppConfig{ + CorpID: "test_corp_id", + CorpSecret: "test_secret", + AgentID: 1000002, + EncodingAESKey: aesKey, + } + ch, _ := NewWeComAppChannel(cfg, msgBus) + + originalMsg := "Hello" + encrypted, err := encryptTestMessageApp(originalMsg, aesKey) + if err != nil { + t.Fatalf("failed to encrypt test message: %v", err) + } + + result, err := decryptMessage(encrypted, ch.config.EncodingAESKey) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result != originalMsg { + t.Errorf("WeComDecryptMessage() = %q, want %q", result, originalMsg) + } + }) + + t.Run("invalid base64", func(t *testing.T) { + cfg := config.WeComAppConfig{ + CorpID: "test_corp_id", + CorpSecret: "test_secret", + AgentID: 1000002, + EncodingAESKey: "", + } + ch, _ := NewWeComAppChannel(cfg, msgBus) + + _, err := decryptMessage("invalid_base64!!!", ch.config.EncodingAESKey) + if err == nil { + t.Error("expected error for invalid base64, 
got nil") + } + }) + + t.Run("invalid AES key", func(t *testing.T) { + cfg := config.WeComAppConfig{ + CorpID: "test_corp_id", + CorpSecret: "test_secret", + AgentID: 1000002, + EncodingAESKey: "invalid_key", + } + ch, _ := NewWeComAppChannel(cfg, msgBus) + + _, err := decryptMessage(base64.StdEncoding.EncodeToString([]byte("test")), ch.config.EncodingAESKey) + if err == nil { + t.Error("expected error for invalid AES key, got nil") + } + }) + + t.Run("ciphertext too short", func(t *testing.T) { + aesKey := generateTestAESKeyApp() + cfg := config.WeComAppConfig{ + CorpID: "test_corp_id", + CorpSecret: "test_secret", + AgentID: 1000002, + EncodingAESKey: aesKey, + } + ch, _ := NewWeComAppChannel(cfg, msgBus) + + // Encrypt a very short message that results in ciphertext less than block size + shortData := make([]byte, 8) + _, err := decryptMessage(base64.StdEncoding.EncodeToString(shortData), ch.config.EncodingAESKey) + if err == nil { + t.Error("expected error for short ciphertext, got nil") + } + }) +} + +func TestWeComAppPKCS7Unpad(t *testing.T) { + tests := []struct { + name string + input []byte + expected []byte + }{ + { + name: "empty input", + input: []byte{}, + expected: []byte{}, + }, + { + name: "valid padding 3 bytes", + input: append([]byte("hello"), bytes.Repeat([]byte{3}, 3)...), + expected: []byte("hello"), + }, + { + name: "valid padding 16 bytes (full block)", + input: append([]byte("123456789012345"), bytes.Repeat([]byte{16}, 16)...), + expected: []byte("123456789012345"), + }, + { + name: "invalid padding larger than data", + input: []byte{20}, + expected: nil, // should return error + }, + { + name: "invalid padding zero", + input: append([]byte("test"), byte(0)), + expected: nil, // should return error + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := pkcs7Unpad(tt.input) + if tt.expected == nil { + // This case should return an error + if err == nil { + t.Errorf("pkcs7Unpad() expected error for 
invalid padding, got result: %v", result) + } + return + } + if err != nil { + t.Errorf("pkcs7Unpad() unexpected error: %v", err) + return + } + if !bytes.Equal(result, tt.expected) { + t.Errorf("pkcs7Unpad() = %v, want %v", result, tt.expected) + } + }) + } +} + +func TestWeComAppHandleVerification(t *testing.T) { + msgBus := bus.NewMessageBus() + aesKey := generateTestAESKeyApp() + cfg := config.WeComAppConfig{ + CorpID: "test_corp_id", + CorpSecret: "test_secret", + AgentID: 1000002, + Token: "test_token", + EncodingAESKey: aesKey, + } + ch, _ := NewWeComAppChannel(cfg, msgBus) + + t.Run("valid verification request", func(t *testing.T) { + echostr := "test_echostr_123" + encryptedEchostr, _ := encryptTestMessageApp(echostr, aesKey) + timestamp := "1234567890" + nonce := "test_nonce" + signature := generateSignatureApp("test_token", timestamp, nonce, encryptedEchostr) + + req := httptest.NewRequest(http.MethodGet, "/webhook/wecom-app?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encryptedEchostr, nil) + w := httptest.NewRecorder() + + ch.handleVerification(context.Background(), w, req) + + if w.Code != http.StatusOK { + t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) + } + if w.Body.String() != echostr { + t.Errorf("response body = %q, want %q", w.Body.String(), echostr) + } + }) + + t.Run("missing parameters", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/webhook/wecom-app?msg_signature=sig×tamp=ts", nil) + w := httptest.NewRecorder() + + ch.handleVerification(context.Background(), w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest) + } + }) + + t.Run("invalid signature", func(t *testing.T) { + echostr := "test_echostr" + encryptedEchostr, _ := encryptTestMessageApp(echostr, aesKey) + timestamp := "1234567890" + nonce := "test_nonce" + + req := httptest.NewRequest(http.MethodGet, 
"/webhook/wecom-app?msg_signature=invalid_sig×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encryptedEchostr, nil) + w := httptest.NewRecorder() + + ch.handleVerification(context.Background(), w, req) + + if w.Code != http.StatusForbidden { + t.Errorf("status code = %d, want %d", w.Code, http.StatusForbidden) + } + }) +} + +func TestWeComAppHandleMessageCallback(t *testing.T) { + msgBus := bus.NewMessageBus() + aesKey := generateTestAESKeyApp() + cfg := config.WeComAppConfig{ + CorpID: "test_corp_id", + CorpSecret: "test_secret", + AgentID: 1000002, + Token: "test_token", + EncodingAESKey: aesKey, + } + ch, _ := NewWeComAppChannel(cfg, msgBus) + + t.Run("valid message callback", func(t *testing.T) { + // Create XML message + xmlMsg := WeComXMLMessage{ + ToUserName: "corp_id", + FromUserName: "user123", + CreateTime: 1234567890, + MsgType: "text", + Content: "Hello World", + MsgId: 123456, + AgentID: 1000002, + } + xmlData, _ := xml.Marshal(xmlMsg) + + // Encrypt message + encrypted, _ := encryptTestMessageApp(string(xmlData), aesKey) + + // Create encrypted XML wrapper + encryptedWrapper := struct { + XMLName xml.Name `xml:"xml"` + Encrypt string `xml:"Encrypt"` + }{ + Encrypt: encrypted, + } + wrapperData, _ := xml.Marshal(encryptedWrapper) + + timestamp := "1234567890" + nonce := "test_nonce" + signature := generateSignatureApp("test_token", timestamp, nonce, encrypted) + + req := httptest.NewRequest(http.MethodPost, "/webhook/wecom-app?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce, bytes.NewReader(wrapperData)) + w := httptest.NewRecorder() + + ch.handleMessageCallback(context.Background(), w, req) + + if w.Code != http.StatusOK { + t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) + } + if w.Body.String() != "success" { + t.Errorf("response body = %q, want %q", w.Body.String(), "success") + } + }) + + t.Run("missing parameters", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/webhook/wecom-app?msg_signature=sig", 
nil) + w := httptest.NewRecorder() + + ch.handleMessageCallback(context.Background(), w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest) + } + }) + + t.Run("invalid XML", func(t *testing.T) { + timestamp := "1234567890" + nonce := "test_nonce" + signature := generateSignatureApp("test_token", timestamp, nonce, "") + + req := httptest.NewRequest(http.MethodPost, "/webhook/wecom-app?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce, strings.NewReader("invalid xml")) + w := httptest.NewRecorder() + + ch.handleMessageCallback(context.Background(), w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest) + } + }) + + t.Run("invalid signature", func(t *testing.T) { + encryptedWrapper := struct { + XMLName xml.Name `xml:"xml"` + Encrypt string `xml:"Encrypt"` + }{ + Encrypt: "encrypted_data", + } + wrapperData, _ := xml.Marshal(encryptedWrapper) + + timestamp := "1234567890" + nonce := "test_nonce" + + req := httptest.NewRequest(http.MethodPost, "/webhook/wecom-app?msg_signature=invalid_sig×tamp="+timestamp+"&nonce="+nonce, bytes.NewReader(wrapperData)) + w := httptest.NewRecorder() + + ch.handleMessageCallback(context.Background(), w, req) + + if w.Code != http.StatusForbidden { + t.Errorf("status code = %d, want %d", w.Code, http.StatusForbidden) + } + }) +} + +func TestWeComAppProcessMessage(t *testing.T) { + msgBus := bus.NewMessageBus() + cfg := config.WeComAppConfig{ + CorpID: "test_corp_id", + CorpSecret: "test_secret", + AgentID: 1000002, + } + ch, _ := NewWeComAppChannel(cfg, msgBus) + + t.Run("process text message", func(t *testing.T) { + msg := WeComXMLMessage{ + ToUserName: "corp_id", + FromUserName: "user123", + CreateTime: 1234567890, + MsgType: "text", + Content: "Hello World", + MsgId: 123456, + AgentID: 1000002, + } + + // Should not panic + ch.processMessage(context.Background(), msg) + }) + + t.Run("process image 
message", func(t *testing.T) { + msg := WeComXMLMessage{ + ToUserName: "corp_id", + FromUserName: "user123", + CreateTime: 1234567890, + MsgType: "image", + PicUrl: "https://example.com/image.jpg", + MediaId: "media_123", + MsgId: 123456, + AgentID: 1000002, + } + + // Should not panic + ch.processMessage(context.Background(), msg) + }) + + t.Run("process voice message", func(t *testing.T) { + msg := WeComXMLMessage{ + ToUserName: "corp_id", + FromUserName: "user123", + CreateTime: 1234567890, + MsgType: "voice", + MediaId: "media_123", + Format: "amr", + MsgId: 123456, + AgentID: 1000002, + } + + // Should not panic + ch.processMessage(context.Background(), msg) + }) + + t.Run("skip unsupported message type", func(t *testing.T) { + msg := WeComXMLMessage{ + ToUserName: "corp_id", + FromUserName: "user123", + CreateTime: 1234567890, + MsgType: "video", + MsgId: 123456, + AgentID: 1000002, + } + + // Should not panic + ch.processMessage(context.Background(), msg) + }) + + t.Run("process event message", func(t *testing.T) { + msg := WeComXMLMessage{ + ToUserName: "corp_id", + FromUserName: "user123", + CreateTime: 1234567890, + MsgType: "event", + Event: "subscribe", + MsgId: 123456, + AgentID: 1000002, + } + + // Should not panic + ch.processMessage(context.Background(), msg) + }) +} + +func TestWeComAppHandleWebhook(t *testing.T) { + msgBus := bus.NewMessageBus() + cfg := config.WeComAppConfig{ + CorpID: "test_corp_id", + CorpSecret: "test_secret", + AgentID: 1000002, + Token: "test_token", + } + ch, _ := NewWeComAppChannel(cfg, msgBus) + + t.Run("GET request calls verification", func(t *testing.T) { + echostr := "test_echostr" + encoded := base64.StdEncoding.EncodeToString([]byte(echostr)) + timestamp := "1234567890" + nonce := "test_nonce" + signature := generateSignatureApp("test_token", timestamp, nonce, encoded) + + req := httptest.NewRequest(http.MethodGet, "/webhook/wecom-app?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encoded, 
nil) + w := httptest.NewRecorder() + + ch.handleWebhook(w, req) + + if w.Code != http.StatusOK { + t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) + } + }) + + t.Run("POST request calls message callback", func(t *testing.T) { + encryptedWrapper := struct { + XMLName xml.Name `xml:"xml"` + Encrypt string `xml:"Encrypt"` + }{ + Encrypt: base64.StdEncoding.EncodeToString([]byte("test")), + } + wrapperData, _ := xml.Marshal(encryptedWrapper) + + timestamp := "1234567890" + nonce := "test_nonce" + signature := generateSignatureApp("test_token", timestamp, nonce, encryptedWrapper.Encrypt) + + req := httptest.NewRequest(http.MethodPost, "/webhook/wecom-app?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce, bytes.NewReader(wrapperData)) + w := httptest.NewRecorder() + + ch.handleWebhook(w, req) + + // Should not be method not allowed + if w.Code == http.StatusMethodNotAllowed { + t.Error("POST request should not return Method Not Allowed") + } + }) + + t.Run("unsupported method", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPut, "/webhook/wecom-app", nil) + w := httptest.NewRecorder() + + ch.handleWebhook(w, req) + + if w.Code != http.StatusMethodNotAllowed { + t.Errorf("status code = %d, want %d", w.Code, http.StatusMethodNotAllowed) + } + }) +} + +func TestWeComAppHandleHealth(t *testing.T) { + msgBus := bus.NewMessageBus() + cfg := config.WeComAppConfig{ + CorpID: "test_corp_id", + CorpSecret: "test_secret", + AgentID: 1000002, + } + ch, _ := NewWeComAppChannel(cfg, msgBus) + + req := httptest.NewRequest(http.MethodGet, "/health/wecom-app", nil) + w := httptest.NewRecorder() + + ch.handleHealth(w, req) + + if w.Code != http.StatusOK { + t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) + } + + contentType := w.Header().Get("Content-Type") + if contentType != "application/json" { + t.Errorf("Content-Type = %q, want %q", contentType, "application/json") + } + + body := w.Body.String() + if !strings.Contains(body, 
"status") || !strings.Contains(body, "running") || !strings.Contains(body, "has_token") { + t.Errorf("response body should contain status, running, and has_token fields, got: %s", body) + } +} + +func TestWeComAppAccessToken(t *testing.T) { + msgBus := bus.NewMessageBus() + cfg := config.WeComAppConfig{ + CorpID: "test_corp_id", + CorpSecret: "test_secret", + AgentID: 1000002, + } + ch, _ := NewWeComAppChannel(cfg, msgBus) + + t.Run("get empty access token initially", func(t *testing.T) { + token := ch.getAccessToken() + if token != "" { + t.Errorf("getAccessToken() = %q, want empty string", token) + } + }) + + t.Run("set and get access token", func(t *testing.T) { + ch.tokenMu.Lock() + ch.accessToken = "test_token_123" + ch.tokenExpiry = time.Now().Add(1 * time.Hour) + ch.tokenMu.Unlock() + + token := ch.getAccessToken() + if token != "test_token_123" { + t.Errorf("getAccessToken() = %q, want %q", token, "test_token_123") + } + }) + + t.Run("expired token returns empty", func(t *testing.T) { + ch.tokenMu.Lock() + ch.accessToken = "expired_token" + ch.tokenExpiry = time.Now().Add(-1 * time.Hour) + ch.tokenMu.Unlock() + + token := ch.getAccessToken() + if token != "" { + t.Errorf("getAccessToken() = %q, want empty string for expired token", token) + } + }) +} + +func TestWeComAppMessageStructures(t *testing.T) { + t.Run("WeComTextMessage structure", func(t *testing.T) { + msg := WeComTextMessage{ + ToUser: "user123", + MsgType: "text", + AgentID: 1000002, + } + msg.Text.Content = "Hello World" + + if msg.ToUser != "user123" { + t.Errorf("ToUser = %q, want %q", msg.ToUser, "user123") + } + if msg.MsgType != "text" { + t.Errorf("MsgType = %q, want %q", msg.MsgType, "text") + } + if msg.AgentID != 1000002 { + t.Errorf("AgentID = %d, want %d", msg.AgentID, 1000002) + } + if msg.Text.Content != "Hello World" { + t.Errorf("Text.Content = %q, want %q", msg.Text.Content, "Hello World") + } + + // Test JSON marshaling + jsonData, err := json.Marshal(msg) + if err != nil { + 
t.Fatalf("failed to marshal JSON: %v", err) + } + + var unmarshaled WeComTextMessage + err = json.Unmarshal(jsonData, &unmarshaled) + if err != nil { + t.Fatalf("failed to unmarshal JSON: %v", err) + } + + if unmarshaled.ToUser != msg.ToUser { + t.Errorf("JSON round-trip failed for ToUser") + } + }) + + t.Run("WeComMarkdownMessage structure", func(t *testing.T) { + msg := WeComMarkdownMessage{ + ToUser: "user123", + MsgType: "markdown", + AgentID: 1000002, + } + msg.Markdown.Content = "# Hello\nWorld" + + if msg.Markdown.Content != "# Hello\nWorld" { + t.Errorf("Markdown.Content = %q, want %q", msg.Markdown.Content, "# Hello\nWorld") + } + + // Test JSON marshaling + jsonData, err := json.Marshal(msg) + if err != nil { + t.Fatalf("failed to marshal JSON: %v", err) + } + + if !bytes.Contains(jsonData, []byte("markdown")) { + t.Error("JSON should contain 'markdown' field") + } + }) + + t.Run("WeComImageMessage structure", func(t *testing.T) { + msg := WeComImageMessage{ + ToUser: "user123", + MsgType: "image", + AgentID: 1000002, + } + msg.Image.MediaID = "media_123456" + + if msg.Image.MediaID != "media_123456" { + t.Errorf("Image.MediaID = %q, want %q", msg.Image.MediaID, "media_123456") + } + }) + + t.Run("WeComAccessTokenResponse structure", func(t *testing.T) { + jsonData := `{ + "errcode": 0, + "errmsg": "ok", + "access_token": "test_access_token", + "expires_in": 7200 + }` + + var resp WeComAccessTokenResponse + err := json.Unmarshal([]byte(jsonData), &resp) + if err != nil { + t.Fatalf("failed to unmarshal JSON: %v", err) + } + + if resp.ErrCode != 0 { + t.Errorf("ErrCode = %d, want %d", resp.ErrCode, 0) + } + if resp.ErrMsg != "ok" { + t.Errorf("ErrMsg = %q, want %q", resp.ErrMsg, "ok") + } + if resp.AccessToken != "test_access_token" { + t.Errorf("AccessToken = %q, want %q", resp.AccessToken, "test_access_token") + } + if resp.ExpiresIn != 7200 { + t.Errorf("ExpiresIn = %d, want %d", resp.ExpiresIn, 7200) + } + }) + + t.Run("WeComSendMessageResponse 
structure", func(t *testing.T) { + jsonData := `{ + "errcode": 0, + "errmsg": "ok", + "invaliduser": "", + "invalidparty": "", + "invalidtag": "" + }` + + var resp WeComSendMessageResponse + err := json.Unmarshal([]byte(jsonData), &resp) + if err != nil { + t.Fatalf("failed to unmarshal JSON: %v", err) + } + + if resp.ErrCode != 0 { + t.Errorf("ErrCode = %d, want %d", resp.ErrCode, 0) + } + if resp.ErrMsg != "ok" { + t.Errorf("ErrMsg = %q, want %q", resp.ErrMsg, "ok") + } + }) +} + +func TestWeComAppXMLMessageStructure(t *testing.T) { + xmlData := ` + + + + 1234567890 + + + 1234567890123456 + 1000002 +` + + var msg WeComXMLMessage + err := xml.Unmarshal([]byte(xmlData), &msg) + if err != nil { + t.Fatalf("failed to unmarshal XML: %v", err) + } + + if msg.ToUserName != "corp_id" { + t.Errorf("ToUserName = %q, want %q", msg.ToUserName, "corp_id") + } + if msg.FromUserName != "user123" { + t.Errorf("FromUserName = %q, want %q", msg.FromUserName, "user123") + } + if msg.CreateTime != 1234567890 { + t.Errorf("CreateTime = %d, want %d", msg.CreateTime, 1234567890) + } + if msg.MsgType != "text" { + t.Errorf("MsgType = %q, want %q", msg.MsgType, "text") + } + if msg.Content != "Hello World" { + t.Errorf("Content = %q, want %q", msg.Content, "Hello World") + } + if msg.MsgId != 1234567890123456 { + t.Errorf("MsgId = %d, want %d", msg.MsgId, 1234567890123456) + } + if msg.AgentID != 1000002 { + t.Errorf("AgentID = %d, want %d", msg.AgentID, 1000002) + } +} + +func TestWeComAppXMLMessageImage(t *testing.T) { + xmlData := ` + + + + 1234567890 + + + + 1234567890123456 + 1000002 +` + + var msg WeComXMLMessage + err := xml.Unmarshal([]byte(xmlData), &msg) + if err != nil { + t.Fatalf("failed to unmarshal XML: %v", err) + } + + if msg.MsgType != "image" { + t.Errorf("MsgType = %q, want %q", msg.MsgType, "image") + } + if msg.PicUrl != "https://example.com/image.jpg" { + t.Errorf("PicUrl = %q, want %q", msg.PicUrl, "https://example.com/image.jpg") + } + if msg.MediaId != 
"media_123" { + t.Errorf("MediaId = %q, want %q", msg.MediaId, "media_123") + } +} + +func TestWeComAppXMLMessageVoice(t *testing.T) { + xmlData := ` + + + + 1234567890 + + + + 1234567890123456 + 1000002 +` + + var msg WeComXMLMessage + err := xml.Unmarshal([]byte(xmlData), &msg) + if err != nil { + t.Fatalf("failed to unmarshal XML: %v", err) + } + + if msg.MsgType != "voice" { + t.Errorf("MsgType = %q, want %q", msg.MsgType, "voice") + } + if msg.Format != "amr" { + t.Errorf("Format = %q, want %q", msg.Format, "amr") + } +} + +func TestWeComAppXMLMessageLocation(t *testing.T) { + xmlData := ` + + + + 1234567890 + + 39.9042 + 116.4074 + 16 + + 1234567890123456 + 1000002 +` + + var msg WeComXMLMessage + err := xml.Unmarshal([]byte(xmlData), &msg) + if err != nil { + t.Fatalf("failed to unmarshal XML: %v", err) + } + + if msg.MsgType != "location" { + t.Errorf("MsgType = %q, want %q", msg.MsgType, "location") + } + if msg.LocationX != 39.9042 { + t.Errorf("LocationX = %f, want %f", msg.LocationX, 39.9042) + } + if msg.LocationY != 116.4074 { + t.Errorf("LocationY = %f, want %f", msg.LocationY, 116.4074) + } + if msg.Scale != 16 { + t.Errorf("Scale = %d, want %d", msg.Scale, 16) + } + if msg.Label != "Beijing" { + t.Errorf("Label = %q, want %q", msg.Label, "Beijing") + } +} + +func TestWeComAppXMLMessageLink(t *testing.T) { + xmlData := ` + + + + 1234567890 + + <![CDATA[Link Title]]> + + + 1234567890123456 + 1000002 +` + + var msg WeComXMLMessage + err := xml.Unmarshal([]byte(xmlData), &msg) + if err != nil { + t.Fatalf("failed to unmarshal XML: %v", err) + } + + if msg.MsgType != "link" { + t.Errorf("MsgType = %q, want %q", msg.MsgType, "link") + } + if msg.Title != "Link Title" { + t.Errorf("Title = %q, want %q", msg.Title, "Link Title") + } + if msg.Description != "Link Description" { + t.Errorf("Description = %q, want %q", msg.Description, "Link Description") + } + if msg.Url != "https://example.com" { + t.Errorf("Url = %q, want %q", msg.Url, 
"https://example.com") + } +} + +func TestWeComAppXMLMessageEvent(t *testing.T) { + xmlData := ` + + + + 1234567890 + + + + 1000002 +` + + var msg WeComXMLMessage + err := xml.Unmarshal([]byte(xmlData), &msg) + if err != nil { + t.Fatalf("failed to unmarshal XML: %v", err) + } + + if msg.MsgType != "event" { + t.Errorf("MsgType = %q, want %q", msg.MsgType, "event") + } + if msg.Event != "subscribe" { + t.Errorf("Event = %q, want %q", msg.Event, "subscribe") + } + if msg.EventKey != "event_key_123" { + t.Errorf("EventKey = %q, want %q", msg.EventKey, "event_key_123") + } +} diff --git a/pkg/channels/wecom/bot.go b/pkg/channels/wecom/bot.go new file mode 100644 index 000000000..9683a308f --- /dev/null +++ b/pkg/channels/wecom/bot.go @@ -0,0 +1,469 @@ +package wecom + +import ( + "bytes" + "context" + "encoding/json" + "encoding/xml" + "fmt" + "io" + "net/http" + "strings" + "sync" + "time" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/utils" +) + +// WeComBotChannel implements the Channel interface for WeCom Bot (企业微信智能机器人) +// Uses webhook callback mode - simpler than WeCom App but only supports passive replies +type WeComBotChannel struct { + *channels.BaseChannel + config config.WeComConfig + server *http.Server + ctx context.Context + cancel context.CancelFunc + processedMsgs map[string]bool // Message deduplication: msg_id -> processed + msgMu sync.RWMutex +} + +// WeComBotMessage represents the JSON message structure from WeCom Bot (AIBOT) +type WeComBotMessage struct { + MsgID string `json:"msgid"` + AIBotID string `json:"aibotid"` + ChatID string `json:"chatid"` // Session ID, only present for group chats + ChatType string `json:"chattype"` // "single" for DM, "group" for group chat + From struct { + UserID string `json:"userid"` + } `json:"from"` + ResponseURL string `json:"response_url"` + MsgType 
string `json:"msgtype"` // text, image, voice, file, mixed + Text struct { + Content string `json:"content"` + } `json:"text"` + Image struct { + URL string `json:"url"` + } `json:"image"` + Voice struct { + Content string `json:"content"` // Voice to text content + } `json:"voice"` + File struct { + URL string `json:"url"` + } `json:"file"` + Mixed struct { + MsgItem []struct { + MsgType string `json:"msgtype"` + Text struct { + Content string `json:"content"` + } `json:"text"` + Image struct { + URL string `json:"url"` + } `json:"image"` + } `json:"msg_item"` + } `json:"mixed"` + Quote struct { + MsgType string `json:"msgtype"` + Text struct { + Content string `json:"content"` + } `json:"text"` + } `json:"quote"` +} + +// WeComBotReplyMessage represents the reply message structure +type WeComBotReplyMessage struct { + MsgType string `json:"msgtype"` + Text struct { + Content string `json:"content"` + } `json:"text,omitempty"` +} + +// NewWeComBotChannel creates a new WeCom Bot channel instance +func NewWeComBotChannel(cfg config.WeComConfig, messageBus *bus.MessageBus) (*WeComBotChannel, error) { + if cfg.Token == "" || cfg.WebhookURL == "" { + return nil, fmt.Errorf("wecom token and webhook_url are required") + } + + base := channels.NewBaseChannel("wecom", cfg, messageBus, cfg.AllowFrom) + + return &WeComBotChannel{ + BaseChannel: base, + config: cfg, + processedMsgs: make(map[string]bool), + }, nil +} + +// Name returns the channel name +func (c *WeComBotChannel) Name() string { + return "wecom" +} + +// Start initializes the WeCom Bot channel with HTTP webhook server +func (c *WeComBotChannel) Start(ctx context.Context) error { + logger.InfoC("wecom", "Starting WeCom Bot channel...") + + c.ctx, c.cancel = context.WithCancel(ctx) + + // Setup HTTP server for webhook + mux := http.NewServeMux() + webhookPath := c.config.WebhookPath + if webhookPath == "" { + webhookPath = "/webhook/wecom" + } + mux.HandleFunc(webhookPath, c.handleWebhook) + + // Health check 
endpoint + mux.HandleFunc("/health/wecom", c.handleHealth) + + addr := fmt.Sprintf("%s:%d", c.config.WebhookHost, c.config.WebhookPort) + c.server = &http.Server{ + Addr: addr, + Handler: mux, + } + + c.SetRunning(true) + logger.InfoCF("wecom", "WeCom Bot channel started", map[string]interface{}{ + "address": addr, + "path": webhookPath, + }) + + // Start server in goroutine + go func() { + if err := c.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { + logger.ErrorCF("wecom", "HTTP server error", map[string]interface{}{ + "error": err.Error(), + }) + } + }() + + return nil +} + +// Stop gracefully stops the WeCom Bot channel +func (c *WeComBotChannel) Stop(ctx context.Context) error { + logger.InfoC("wecom", "Stopping WeCom Bot channel...") + + if c.cancel != nil { + c.cancel() + } + + if c.server != nil { + shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + c.server.Shutdown(shutdownCtx) + } + + c.SetRunning(false) + logger.InfoC("wecom", "WeCom Bot channel stopped") + return nil +} + +// Send sends a message to WeCom user via webhook API +// Note: WeCom Bot can only reply within the configured timeout (default 5 seconds) of receiving a message +// For delayed responses, we use the webhook URL +func (c *WeComBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + if !c.IsRunning() { + return fmt.Errorf("wecom channel not running") + } + + logger.DebugCF("wecom", "Sending message via webhook", map[string]interface{}{ + "chat_id": msg.ChatID, + "preview": utils.Truncate(msg.Content, 100), + }) + + return c.sendWebhookReply(ctx, msg.ChatID, msg.Content) +} + +// handleWebhook handles incoming webhook requests from WeCom +func (c *WeComBotChannel) handleWebhook(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + if r.Method == http.MethodGet { + // Handle verification request + c.handleVerification(ctx, w, r) + return + } + + if r.Method == http.MethodPost { + // Handle message callback 
+ c.handleMessageCallback(ctx, w, r) + return + } + + http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) +} + +// handleVerification handles the URL verification request from WeCom +func (c *WeComBotChannel) handleVerification(ctx context.Context, w http.ResponseWriter, r *http.Request) { + query := r.URL.Query() + msgSignature := query.Get("msg_signature") + timestamp := query.Get("timestamp") + nonce := query.Get("nonce") + echostr := query.Get("echostr") + + if msgSignature == "" || timestamp == "" || nonce == "" || echostr == "" { + http.Error(w, "Missing parameters", http.StatusBadRequest) + return + } + + // Verify signature + if !verifySignature(c.config.Token, msgSignature, timestamp, nonce, echostr) { + logger.WarnC("wecom", "Signature verification failed") + http.Error(w, "Invalid signature", http.StatusForbidden) + return + } + + // Decrypt echostr + // For AIBOT (智能机器人), receiveid should be empty string "" + // Reference: https://developer.work.weixin.qq.com/document/path/101033 + decryptedEchoStr, err := decryptMessageWithVerify(echostr, c.config.EncodingAESKey, "") + if err != nil { + logger.ErrorCF("wecom", "Failed to decrypt echostr", map[string]interface{}{ + "error": err.Error(), + }) + http.Error(w, "Decryption failed", http.StatusInternalServerError) + return + } + + // Remove BOM and whitespace as per WeCom documentation + // The response must be plain text without quotes, BOM, or newlines + decryptedEchoStr = strings.TrimSpace(decryptedEchoStr) + decryptedEchoStr = strings.TrimPrefix(decryptedEchoStr, "\xef\xbb\xbf") // Remove UTF-8 BOM + w.Write([]byte(decryptedEchoStr)) +} + +// handleMessageCallback handles incoming messages from WeCom +func (c *WeComBotChannel) handleMessageCallback(ctx context.Context, w http.ResponseWriter, r *http.Request) { + query := r.URL.Query() + msgSignature := query.Get("msg_signature") + timestamp := query.Get("timestamp") + nonce := query.Get("nonce") + + if msgSignature == "" || timestamp == "" 
|| nonce == "" { + http.Error(w, "Missing parameters", http.StatusBadRequest) + return + } + + // Read request body + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "Failed to read body", http.StatusBadRequest) + return + } + defer r.Body.Close() + + // Parse XML to get encrypted message + var encryptedMsg struct { + XMLName xml.Name `xml:"xml"` + ToUserName string `xml:"ToUserName"` + Encrypt string `xml:"Encrypt"` + AgentID string `xml:"AgentID"` + } + + if err := xml.Unmarshal(body, &encryptedMsg); err != nil { + logger.ErrorCF("wecom", "Failed to parse XML", map[string]interface{}{ + "error": err.Error(), + }) + http.Error(w, "Invalid XML", http.StatusBadRequest) + return + } + + // Verify signature + if !verifySignature(c.config.Token, msgSignature, timestamp, nonce, encryptedMsg.Encrypt) { + logger.WarnC("wecom", "Message signature verification failed") + http.Error(w, "Invalid signature", http.StatusForbidden) + return + } + + // Decrypt message + // For AIBOT (智能机器人), receiveid should be empty string "" + // Reference: https://developer.work.weixin.qq.com/document/path/101033 + decryptedMsg, err := decryptMessageWithVerify(encryptedMsg.Encrypt, c.config.EncodingAESKey, "") + if err != nil { + logger.ErrorCF("wecom", "Failed to decrypt message", map[string]interface{}{ + "error": err.Error(), + }) + http.Error(w, "Decryption failed", http.StatusInternalServerError) + return + } + + // Parse decrypted JSON message (AIBOT uses JSON format) + var msg WeComBotMessage + if err := json.Unmarshal([]byte(decryptedMsg), &msg); err != nil { + logger.ErrorCF("wecom", "Failed to parse decrypted message", map[string]interface{}{ + "error": err.Error(), + }) + http.Error(w, "Invalid message format", http.StatusBadRequest) + return + } + + // Process the message asynchronously with context + go c.processMessage(ctx, msg) + + // Return success response immediately + // WeCom Bot requires response within configured timeout (default 5 seconds) + 
w.Write([]byte("success")) +} + +// processMessage processes the received message +func (c *WeComBotChannel) processMessage(ctx context.Context, msg WeComBotMessage) { + // Skip unsupported message types + if msg.MsgType != "text" && msg.MsgType != "image" && msg.MsgType != "voice" && msg.MsgType != "file" && msg.MsgType != "mixed" { + logger.DebugCF("wecom", "Skipping non-supported message type", map[string]interface{}{ + "msg_type": msg.MsgType, + }) + return + } + + // Message deduplication: Use msg_id to prevent duplicate processing + msgID := msg.MsgID + c.msgMu.Lock() + if c.processedMsgs[msgID] { + c.msgMu.Unlock() + logger.DebugCF("wecom", "Skipping duplicate message", map[string]interface{}{ + "msg_id": msgID, + }) + return + } + c.processedMsgs[msgID] = true + c.msgMu.Unlock() + + // Clean up old messages periodically (keep last 1000) + if len(c.processedMsgs) > 1000 { + c.msgMu.Lock() + c.processedMsgs = make(map[string]bool) + c.msgMu.Unlock() + } + + senderID := msg.From.UserID + + // Determine if this is a group chat or direct message + // ChatType: "single" for DM, "group" for group chat + isGroupChat := msg.ChatType == "group" + + var chatID, peerKind, peerID string + if isGroupChat { + // Group chat: use ChatID as chatID and peer_id + chatID = msg.ChatID + peerKind = "group" + peerID = msg.ChatID + } else { + // Direct message: use senderID as chatID and peer_id + chatID = senderID + peerKind = "direct" + peerID = senderID + } + + // Extract content based on message type + var content string + switch msg.MsgType { + case "text": + content = msg.Text.Content + case "voice": + content = msg.Voice.Content // Voice to text content + case "mixed": + // For mixed messages, concatenate text items + for _, item := range msg.Mixed.MsgItem { + if item.MsgType == "text" { + content += item.Text.Content + } + } + case "image", "file": + // For image and file, we don't have text content + content = "" + } + + // Build metadata + metadata := map[string]string{ + 
"msg_type": msg.MsgType, + "msg_id": msg.MsgID, + "platform": "wecom", + "peer_kind": peerKind, + "peer_id": peerID, + "response_url": msg.ResponseURL, + } + if isGroupChat { + metadata["chat_id"] = msg.ChatID + metadata["sender_id"] = senderID + } + + logger.DebugCF("wecom", "Received message", map[string]interface{}{ + "sender_id": senderID, + "msg_type": msg.MsgType, + "peer_kind": peerKind, + "is_group_chat": isGroupChat, + "preview": utils.Truncate(content, 50), + }) + + // Handle the message through the base channel + c.HandleMessage(senderID, chatID, content, nil, metadata) +} + +// sendWebhookReply sends a reply using the webhook URL +func (c *WeComBotChannel) sendWebhookReply(ctx context.Context, userID, content string) error { + reply := WeComBotReplyMessage{ + MsgType: "text", + } + reply.Text.Content = content + + jsonData, err := json.Marshal(reply) + if err != nil { + return fmt.Errorf("failed to marshal reply: %w", err) + } + + // Use configurable timeout (default 5 seconds) + timeout := c.config.ReplyTimeout + if timeout <= 0 { + timeout = 5 + } + + reqCtx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, c.config.WebhookURL, bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{Timeout: time.Duration(timeout) * time.Second} + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("failed to send webhook reply: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("failed to read response: %w", err) + } + + // Check response + var result struct { + ErrCode int `json:"errcode"` + ErrMsg string `json:"errmsg"` + } + if err := json.Unmarshal(body, &result); err != nil { + return fmt.Errorf("failed to parse response: %w", err) + } + + if 
// generateTestAESKey returns a deterministic 43-character EncodingAESKey for
// tests: the unpadded standard-base64 form of the 32 bytes 0x00..0x1f, i.e. an
// AES-256 key.
func generateTestAESKey() string {
	var raw [32]byte
	for idx := range raw {
		raw[idx] = byte(idx)
	}
	// 32 bytes encode to 44 characters ending in one '='; the unpadded
	// encoding is exactly that string without the trailing pad character.
	return base64.RawStdEncoding.EncodeToString(raw[:])
}
// generateSignature computes the signature the tests send as msg_signature:
// the four inputs are sorted lexicographically, concatenated without a
// separator, and the SHA-1 digest of the result is returned as lowercase hex.
func generateSignature(token, timestamp, nonce, msgEncrypt string) string {
	parts := sort.StringSlice{token, timestamp, nonce, msgEncrypt}
	parts.Sort()

	var joined strings.Builder
	for _, part := range parts {
		joined.WriteString(part)
	}

	digest := sha1.Sum([]byte(joined.String()))
	return fmt.Sprintf("%x", digest[:])
}
channel should not be running") + } + }) +} + +func TestWeComBotChannelIsAllowed(t *testing.T) { + msgBus := bus.NewMessageBus() + + t.Run("empty allowlist allows all", func(t *testing.T) { + cfg := config.WeComConfig{ + Token: "test_token", + WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + AllowFrom: []string{}, + } + ch, _ := NewWeComBotChannel(cfg, msgBus) + if !ch.IsAllowed("any_user") { + t.Error("empty allowlist should allow all users") + } + }) + + t.Run("allowlist restricts users", func(t *testing.T) { + cfg := config.WeComConfig{ + Token: "test_token", + WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + AllowFrom: []string{"allowed_user"}, + } + ch, _ := NewWeComBotChannel(cfg, msgBus) + if !ch.IsAllowed("allowed_user") { + t.Error("allowed user should pass allowlist check") + } + if ch.IsAllowed("blocked_user") { + t.Error("non-allowed user should be blocked") + } + }) +} + +func TestWeComBotVerifySignature(t *testing.T) { + msgBus := bus.NewMessageBus() + cfg := config.WeComConfig{ + Token: "test_token", + WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + } + ch, _ := NewWeComBotChannel(cfg, msgBus) + + t.Run("valid signature", func(t *testing.T) { + timestamp := "1234567890" + nonce := "test_nonce" + msgEncrypt := "test_message" + expectedSig := generateSignature("test_token", timestamp, nonce, msgEncrypt) + + if !verifySignature(ch.config.Token, expectedSig, timestamp, nonce, msgEncrypt) { + t.Error("valid signature should pass verification") + } + }) + + t.Run("invalid signature", func(t *testing.T) { + timestamp := "1234567890" + nonce := "test_nonce" + msgEncrypt := "test_message" + + if verifySignature(ch.config.Token, "invalid_sig", timestamp, nonce, msgEncrypt) { + t.Error("invalid signature should fail verification") + } + }) + + t.Run("empty token skips verification", func(t *testing.T) { + // Create a channel manually with empty token to test the behavior + cfgEmpty 
:= config.WeComConfig{ + Token: "", + WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + } + base := channels.NewBaseChannel("wecom", cfgEmpty, msgBus, cfgEmpty.AllowFrom) + chEmpty := &WeComBotChannel{ + BaseChannel: base, + config: cfgEmpty, + } + + if !verifySignature(chEmpty.config.Token, "any_sig", "any_ts", "any_nonce", "any_msg") { + t.Error("empty token should skip verification and return true") + } + }) +} + +func TestWeComBotDecryptMessage(t *testing.T) { + msgBus := bus.NewMessageBus() + + t.Run("decrypt without AES key", func(t *testing.T) { + cfg := config.WeComConfig{ + Token: "test_token", + WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + EncodingAESKey: "", + } + ch, _ := NewWeComBotChannel(cfg, msgBus) + + // Without AES key, message should be base64 decoded only + plainText := "hello world" + encoded := base64.StdEncoding.EncodeToString([]byte(plainText)) + + result, err := decryptMessage(encoded, ch.config.EncodingAESKey) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result != plainText { + t.Errorf("decryptMessage() = %q, want %q", result, plainText) + } + }) + + t.Run("decrypt with AES key", func(t *testing.T) { + aesKey := generateTestAESKey() + cfg := config.WeComConfig{ + Token: "test_token", + WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + EncodingAESKey: aesKey, + } + ch, _ := NewWeComBotChannel(cfg, msgBus) + + originalMsg := "Hello" + encrypted, err := encryptTestMessage(originalMsg, aesKey) + if err != nil { + t.Fatalf("failed to encrypt test message: %v", err) + } + + result, err := decryptMessage(encrypted, ch.config.EncodingAESKey) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if result != originalMsg { + t.Errorf("WeComDecryptMessage() = %q, want %q", result, originalMsg) + } + }) + + t.Run("invalid base64", func(t *testing.T) { + cfg := config.WeComConfig{ + Token: "test_token", + WebhookURL: 
"https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + EncodingAESKey: "", + } + ch, _ := NewWeComBotChannel(cfg, msgBus) + + _, err := decryptMessage("invalid_base64!!!", ch.config.EncodingAESKey) + if err == nil { + t.Error("expected error for invalid base64, got nil") + } + }) + + t.Run("invalid AES key", func(t *testing.T) { + cfg := config.WeComConfig{ + Token: "test_token", + WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + EncodingAESKey: "invalid_key", + } + ch, _ := NewWeComBotChannel(cfg, msgBus) + + _, err := decryptMessage(base64.StdEncoding.EncodeToString([]byte("test")), ch.config.EncodingAESKey) + if err == nil { + t.Error("expected error for invalid AES key, got nil") + } + }) +} + +func TestWeComBotPKCS7Unpad(t *testing.T) { + tests := []struct { + name string + input []byte + expected []byte + }{ + { + name: "empty input", + input: []byte{}, + expected: []byte{}, + }, + { + name: "valid padding 3 bytes", + input: append([]byte("hello"), bytes.Repeat([]byte{3}, 3)...), + expected: []byte("hello"), + }, + { + name: "valid padding 16 bytes (full block)", + input: append([]byte("123456789012345"), bytes.Repeat([]byte{16}, 16)...), + expected: []byte("123456789012345"), + }, + { + name: "invalid padding larger than data", + input: []byte{20}, + expected: nil, // should return error + }, + { + name: "invalid padding zero", + input: append([]byte("test"), byte(0)), + expected: nil, // should return error + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := pkcs7Unpad(tt.input) + if tt.expected == nil { + // This case should return an error + if err == nil { + t.Errorf("pkcs7Unpad() expected error for invalid padding, got result: %v", result) + } + return + } + if err != nil { + t.Errorf("pkcs7Unpad() unexpected error: %v", err) + return + } + if !bytes.Equal(result, tt.expected) { + t.Errorf("pkcs7Unpad() = %v, want %v", result, tt.expected) + } + }) + } +} + +func 
TestWeComBotHandleVerification(t *testing.T) { + msgBus := bus.NewMessageBus() + aesKey := generateTestAESKey() + cfg := config.WeComConfig{ + Token: "test_token", + EncodingAESKey: aesKey, + WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + } + ch, _ := NewWeComBotChannel(cfg, msgBus) + + t.Run("valid verification request", func(t *testing.T) { + echostr := "test_echostr_123" + encryptedEchostr, _ := encryptTestMessage(echostr, aesKey) + timestamp := "1234567890" + nonce := "test_nonce" + signature := generateSignature("test_token", timestamp, nonce, encryptedEchostr) + + req := httptest.NewRequest(http.MethodGet, "/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encryptedEchostr, nil) + w := httptest.NewRecorder() + + ch.handleVerification(context.Background(), w, req) + + if w.Code != http.StatusOK { + t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) + } + if w.Body.String() != echostr { + t.Errorf("response body = %q, want %q", w.Body.String(), echostr) + } + }) + + t.Run("missing parameters", func(t *testing.T) { + req := httptest.NewRequest(http.MethodGet, "/webhook/wecom?msg_signature=sig×tamp=ts", nil) + w := httptest.NewRecorder() + + ch.handleVerification(context.Background(), w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest) + } + }) + + t.Run("invalid signature", func(t *testing.T) { + echostr := "test_echostr" + encryptedEchostr, _ := encryptTestMessage(echostr, aesKey) + timestamp := "1234567890" + nonce := "test_nonce" + + req := httptest.NewRequest(http.MethodGet, "/webhook/wecom?msg_signature=invalid_sig×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encryptedEchostr, nil) + w := httptest.NewRecorder() + + ch.handleVerification(context.Background(), w, req) + + if w.Code != http.StatusForbidden { + t.Errorf("status code = %d, want %d", w.Code, http.StatusForbidden) + } + }) +} + +func 
TestWeComBotHandleMessageCallback(t *testing.T) { + msgBus := bus.NewMessageBus() + aesKey := generateTestAESKey() + cfg := config.WeComConfig{ + Token: "test_token", + EncodingAESKey: aesKey, + WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + } + ch, _ := NewWeComBotChannel(cfg, msgBus) + + t.Run("valid direct message callback", func(t *testing.T) { + // Create JSON message for direct chat (single) + jsonMsg := `{ + "msgid": "test_msg_id_123", + "aibotid": "test_aibot_id", + "chattype": "single", + "from": {"userid": "user123"}, + "response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + "msgtype": "text", + "text": {"content": "Hello World"} + }` + + // Encrypt message + encrypted, _ := encryptTestMessage(jsonMsg, aesKey) + + // Create encrypted XML wrapper + encryptedWrapper := struct { + XMLName xml.Name `xml:"xml"` + Encrypt string `xml:"Encrypt"` + }{ + Encrypt: encrypted, + } + wrapperData, _ := xml.Marshal(encryptedWrapper) + + timestamp := "1234567890" + nonce := "test_nonce" + signature := generateSignature("test_token", timestamp, nonce, encrypted) + + req := httptest.NewRequest(http.MethodPost, "/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce, bytes.NewReader(wrapperData)) + w := httptest.NewRecorder() + + ch.handleMessageCallback(context.Background(), w, req) + + if w.Code != http.StatusOK { + t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) + } + if w.Body.String() != "success" { + t.Errorf("response body = %q, want %q", w.Body.String(), "success") + } + }) + + t.Run("valid group message callback", func(t *testing.T) { + // Create JSON message for group chat + jsonMsg := `{ + "msgid": "test_msg_id_456", + "aibotid": "test_aibot_id", + "chatid": "group_chat_id_123", + "chattype": "group", + "from": {"userid": "user456"}, + "response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + "msgtype": "text", + "text": {"content": "Hello Group"} + }` + + // 
Encrypt message + encrypted, _ := encryptTestMessage(jsonMsg, aesKey) + + // Create encrypted XML wrapper + encryptedWrapper := struct { + XMLName xml.Name `xml:"xml"` + Encrypt string `xml:"Encrypt"` + }{ + Encrypt: encrypted, + } + wrapperData, _ := xml.Marshal(encryptedWrapper) + + timestamp := "1234567890" + nonce := "test_nonce" + signature := generateSignature("test_token", timestamp, nonce, encrypted) + + req := httptest.NewRequest(http.MethodPost, "/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce, bytes.NewReader(wrapperData)) + w := httptest.NewRecorder() + + ch.handleMessageCallback(context.Background(), w, req) + + if w.Code != http.StatusOK { + t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) + } + if w.Body.String() != "success" { + t.Errorf("response body = %q, want %q", w.Body.String(), "success") + } + }) + + t.Run("missing parameters", func(t *testing.T) { + req := httptest.NewRequest(http.MethodPost, "/webhook/wecom?msg_signature=sig", nil) + w := httptest.NewRecorder() + + ch.handleMessageCallback(context.Background(), w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest) + } + }) + + t.Run("invalid XML", func(t *testing.T) { + timestamp := "1234567890" + nonce := "test_nonce" + signature := generateSignature("test_token", timestamp, nonce, "") + + req := httptest.NewRequest(http.MethodPost, "/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce, strings.NewReader("invalid xml")) + w := httptest.NewRecorder() + + ch.handleMessageCallback(context.Background(), w, req) + + if w.Code != http.StatusBadRequest { + t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest) + } + }) + + t.Run("invalid signature", func(t *testing.T) { + encryptedWrapper := struct { + XMLName xml.Name `xml:"xml"` + Encrypt string `xml:"Encrypt"` + }{ + Encrypt: "encrypted_data", + } + wrapperData, _ := xml.Marshal(encryptedWrapper) + + 
timestamp := "1234567890" + nonce := "test_nonce" + + req := httptest.NewRequest(http.MethodPost, "/webhook/wecom?msg_signature=invalid_sig×tamp="+timestamp+"&nonce="+nonce, bytes.NewReader(wrapperData)) + w := httptest.NewRecorder() + + ch.handleMessageCallback(context.Background(), w, req) + + if w.Code != http.StatusForbidden { + t.Errorf("status code = %d, want %d", w.Code, http.StatusForbidden) + } + }) +} + +func TestWeComBotProcessMessage(t *testing.T) { + msgBus := bus.NewMessageBus() + cfg := config.WeComConfig{ + Token: "test_token", + WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + } + ch, _ := NewWeComBotChannel(cfg, msgBus) + + t.Run("process direct text message", func(t *testing.T) { + msg := WeComBotMessage{ + MsgID: "test_msg_id_123", + AIBotID: "test_aibot_id", + ChatType: "single", + ResponseURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + MsgType: "text", + } + msg.From.UserID = "user123" + msg.Text.Content = "Hello World" + + // Should not panic + ch.processMessage(context.Background(), msg) + }) + + t.Run("process group text message", func(t *testing.T) { + msg := WeComBotMessage{ + MsgID: "test_msg_id_456", + AIBotID: "test_aibot_id", + ChatID: "group_chat_id_123", + ChatType: "group", + ResponseURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + MsgType: "text", + } + msg.From.UserID = "user456" + msg.Text.Content = "Hello Group" + + // Should not panic + ch.processMessage(context.Background(), msg) + }) + + t.Run("process voice message", func(t *testing.T) { + msg := WeComBotMessage{ + MsgID: "test_msg_id_789", + AIBotID: "test_aibot_id", + ChatType: "single", + ResponseURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + MsgType: "voice", + } + msg.From.UserID = "user123" + msg.Voice.Content = "Voice message text" + + // Should not panic + ch.processMessage(context.Background(), msg) + }) + + t.Run("skip unsupported message type", func(t *testing.T) { + msg := 
WeComBotMessage{ + MsgID: "test_msg_id_000", + AIBotID: "test_aibot_id", + ChatType: "single", + ResponseURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + MsgType: "video", + } + msg.From.UserID = "user123" + + // Should not panic + ch.processMessage(context.Background(), msg) + }) +} + +func TestWeComBotHandleWebhook(t *testing.T) { + msgBus := bus.NewMessageBus() + cfg := config.WeComConfig{ + Token: "test_token", + WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + } + ch, _ := NewWeComBotChannel(cfg, msgBus) + + t.Run("GET request calls verification", func(t *testing.T) { + echostr := "test_echostr" + encoded := base64.StdEncoding.EncodeToString([]byte(echostr)) + timestamp := "1234567890" + nonce := "test_nonce" + signature := generateSignature("test_token", timestamp, nonce, encoded) + + req := httptest.NewRequest(http.MethodGet, "/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encoded, nil) + w := httptest.NewRecorder() + + ch.handleWebhook(w, req) + + if w.Code != http.StatusOK { + t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) + } + }) + + t.Run("POST request calls message callback", func(t *testing.T) { + encryptedWrapper := struct { + XMLName xml.Name `xml:"xml"` + Encrypt string `xml:"Encrypt"` + }{ + Encrypt: base64.StdEncoding.EncodeToString([]byte("test")), + } + wrapperData, _ := xml.Marshal(encryptedWrapper) + + timestamp := "1234567890" + nonce := "test_nonce" + signature := generateSignature("test_token", timestamp, nonce, encryptedWrapper.Encrypt) + + req := httptest.NewRequest(http.MethodPost, "/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce, bytes.NewReader(wrapperData)) + w := httptest.NewRecorder() + + ch.handleWebhook(w, req) + + // Should not be method not allowed + if w.Code == http.StatusMethodNotAllowed { + t.Error("POST request should not return Method Not Allowed") + } + }) + + t.Run("unsupported method", func(t 
*testing.T) { + req := httptest.NewRequest(http.MethodPut, "/webhook/wecom", nil) + w := httptest.NewRecorder() + + ch.handleWebhook(w, req) + + if w.Code != http.StatusMethodNotAllowed { + t.Errorf("status code = %d, want %d", w.Code, http.StatusMethodNotAllowed) + } + }) +} + +func TestWeComBotHandleHealth(t *testing.T) { + msgBus := bus.NewMessageBus() + cfg := config.WeComConfig{ + Token: "test_token", + WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + } + ch, _ := NewWeComBotChannel(cfg, msgBus) + + req := httptest.NewRequest(http.MethodGet, "/health/wecom", nil) + w := httptest.NewRecorder() + + ch.handleHealth(w, req) + + if w.Code != http.StatusOK { + t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) + } + + contentType := w.Header().Get("Content-Type") + if contentType != "application/json" { + t.Errorf("Content-Type = %q, want %q", contentType, "application/json") + } + + body := w.Body.String() + if !strings.Contains(body, "status") || !strings.Contains(body, "running") { + t.Errorf("response body should contain status and running fields, got: %s", body) + } +} + +func TestWeComBotReplyMessage(t *testing.T) { + msg := WeComBotReplyMessage{ + MsgType: "text", + } + msg.Text.Content = "Hello World" + + if msg.MsgType != "text" { + t.Errorf("MsgType = %q, want %q", msg.MsgType, "text") + } + if msg.Text.Content != "Hello World" { + t.Errorf("Text.Content = %q, want %q", msg.Text.Content, "Hello World") + } +} + +func TestWeComBotMessageStructure(t *testing.T) { + jsonData := `{ + "msgid": "test_msg_id_123", + "aibotid": "test_aibot_id", + "chatid": "group_chat_id_123", + "chattype": "group", + "from": {"userid": "user123"}, + "response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + "msgtype": "text", + "text": {"content": "Hello World"} + }` + + var msg WeComBotMessage + err := json.Unmarshal([]byte(jsonData), &msg) + if err != nil { + t.Fatalf("failed to unmarshal JSON: %v", err) + } + + if 
// blockSize is the PKCS7 block size used by WeCom (32)
const blockSize = 32

// verifySignature verifies the message signature for WeCom.
// This is a common function used by both WeCom Bot and WeCom App.
//
// The expected signature is the lowercase hex SHA1 of token, timestamp, nonce
// and msgEncrypt sorted lexicographically and concatenated. An empty token
// disables verification and always returns true.
func verifySignature(token, msgSignature, timestamp, nonce, msgEncrypt string) bool {
	if token == "" {
		return true // Skip verification if token is not set
	}

	// Sort parameters, concatenate, SHA1.
	params := []string{token, timestamp, nonce, msgEncrypt}
	sort.Strings(params)
	hash := sha1.Sum([]byte(strings.Join(params, "")))

	// NOTE(review): plain string comparison is not constant-time; consider
	// crypto/subtle if timing side channels matter for this endpoint.
	return fmt.Sprintf("%x", hash) == msgSignature
}

// decryptMessage decrypts the encrypted message using AES without checking
// the trailing receiveid. See decryptMessageWithVerify for details.
func decryptMessage(encryptedMsg, encodingAESKey string) (string, error) {
	return decryptMessageWithVerify(encryptedMsg, encodingAESKey, "")
}

// decryptMessageWithVerify decrypts the encrypted message and optionally verifies receiveid.
//
// encryptedMsg is the base64 payload. When encodingAESKey is empty the payload
// is only base64-decoded. Otherwise it is decrypted with AES-CBC, using the
// first 16 bytes of the key as IV, unpadded, and parsed as
// random(16) + msg_len(4, big endian) + msg + receiveid.
//
// receiveid: for AIBOT use aibotid, for WeCom App use corp_id. If empty, or if
// the plaintext carries no trailing receiveid, the check is skipped.
//
// An error is returned for malformed base64, an invalid key, ciphertext that
// is not a whole number of AES blocks, bad padding, an inconsistent length
// field, or a receiveid mismatch.
func decryptMessageWithVerify(encryptedMsg, encodingAESKey, receiveid string) (string, error) {
	if encodingAESKey == "" {
		// No encryption, return as is (base64 decode)
		decoded, err := base64.StdEncoding.DecodeString(encryptedMsg)
		if err != nil {
			return "", err
		}
		return string(decoded), nil
	}

	// Decode AES key (base64); the configured key omits the trailing '='.
	aesKey, err := base64.StdEncoding.DecodeString(encodingAESKey + "=")
	if err != nil {
		return "", fmt.Errorf("failed to decode AES key: %w", err)
	}

	// Decode encrypted message
	cipherText, err := base64.StdEncoding.DecodeString(encryptedMsg)
	if err != nil {
		return "", fmt.Errorf("failed to decode message: %w", err)
	}

	// AES decrypt
	block, err := aes.NewCipher(aesKey)
	if err != nil {
		return "", fmt.Errorf("failed to create cipher: %w", err)
	}

	if len(cipherText) < aes.BlockSize {
		return "", fmt.Errorf("ciphertext too short")
	}
	// CryptBlocks panics on partial blocks, and the payload comes from the
	// network, so reject it explicitly.
	if len(cipherText)%aes.BlockSize != 0 {
		return "", fmt.Errorf("ciphertext is not a multiple of the block size")
	}

	// IV is the first 16 bytes of AESKey
	iv := aesKey[:aes.BlockSize]
	mode := cipher.NewCBCDecrypter(block, iv)
	plainText := make([]byte, len(cipherText))
	mode.CryptBlocks(plainText, cipherText)

	// Remove PKCS7 padding
	plainText, err = pkcs7Unpad(plainText)
	if err != nil {
		return "", fmt.Errorf("failed to unpad: %w", err)
	}

	// Parse message structure
	// Format: random(16) + msg_len(4) + msg + receiveid
	if len(plainText) < 20 {
		return "", fmt.Errorf("decrypted message too short")
	}

	// Compare in uint64 so a huge length field cannot wrap to a negative int
	// on 32-bit platforms and slip past the bounds check.
	msgLen := binary.BigEndian.Uint32(plainText[16:20])
	if uint64(msgLen) > uint64(len(plainText)-20) {
		return "", fmt.Errorf("invalid message length")
	}
	msgEnd := 20 + int(msgLen)
	msg := plainText[20:msgEnd]

	// Verify receiveid if provided
	if receiveid != "" && len(plainText) > msgEnd {
		actualReceiveID := string(plainText[msgEnd:])
		if actualReceiveID != receiveid {
			return "", fmt.Errorf("receiveid mismatch: expected %s, got %s", receiveid, actualReceiveID)
		}
	}

	return string(msg), nil
}

// pkcs7Unpad removes PKCS7 padding with validation.
//
// Empty input is returned unchanged. The last byte gives the padding length,
// which must be in 1..blockSize, must not exceed len(data), and must be
// repeated in every padding byte; otherwise an error is returned.
func pkcs7Unpad(data []byte) ([]byte, error) {
	if len(data) == 0 {
		return data, nil
	}
	padding := int(data[len(data)-1])
	// WeCom uses 32-byte block size for PKCS7 padding
	if padding == 0 || padding > blockSize {
		return nil, fmt.Errorf("invalid padding size: %d", padding)
	}
	if padding > len(data) {
		return nil, fmt.Errorf("padding size larger than data")
	}
	// Verify all padding bytes
	for i := 0; i < padding; i++ {
		if data[len(data)-1-i] != byte(padding) {
			return nil, fmt.Errorf("invalid padding byte at position %d", i)
		}
	}
	return data[:len(data)-padding], nil
}
(channels.Channel, error) { + return NewWhatsAppChannel(cfg.Channels.WhatsApp, b) + }) +} diff --git a/pkg/channels/whatsapp/whatsapp.go b/pkg/channels/whatsapp/whatsapp.go new file mode 100644 index 000000000..1ac256766 --- /dev/null +++ b/pkg/channels/whatsapp/whatsapp.go @@ -0,0 +1,193 @@ +package whatsapp + +import ( + "context" + "encoding/json" + "fmt" + "log" + "sync" + "time" + + "github.com/gorilla/websocket" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/utils" +) + +type WhatsAppChannel struct { + *channels.BaseChannel + conn *websocket.Conn + config config.WhatsAppConfig + url string + mu sync.Mutex + connected bool +} + +func NewWhatsAppChannel(cfg config.WhatsAppConfig, bus *bus.MessageBus) (*WhatsAppChannel, error) { + base := channels.NewBaseChannel("whatsapp", cfg, bus, cfg.AllowFrom) + + return &WhatsAppChannel{ + BaseChannel: base, + config: cfg, + url: cfg.BridgeURL, + connected: false, + }, nil +} + +func (c *WhatsAppChannel) Start(ctx context.Context) error { + log.Printf("Starting WhatsApp channel connecting to %s...", c.url) + + dialer := websocket.DefaultDialer + dialer.HandshakeTimeout = 10 * time.Second + + conn, _, err := dialer.Dial(c.url, nil) + if err != nil { + return fmt.Errorf("failed to connect to WhatsApp bridge: %w", err) + } + + c.mu.Lock() + c.conn = conn + c.connected = true + c.mu.Unlock() + + c.SetRunning(true) + log.Println("WhatsApp channel connected") + + go c.listen(ctx) + + return nil +} + +func (c *WhatsAppChannel) Stop(ctx context.Context) error { + log.Println("Stopping WhatsApp channel...") + + c.mu.Lock() + defer c.mu.Unlock() + + if c.conn != nil { + if err := c.conn.Close(); err != nil { + log.Printf("Error closing WhatsApp connection: %v", err) + } + c.conn = nil + } + + c.connected = false + c.SetRunning(false) + + return nil +} + +func (c *WhatsAppChannel) Send(ctx context.Context, msg 
bus.OutboundMessage) error { + c.mu.Lock() + defer c.mu.Unlock() + + if c.conn == nil { + return fmt.Errorf("whatsapp connection not established") + } + + payload := map[string]interface{}{ + "type": "message", + "to": msg.ChatID, + "content": msg.Content, + } + + data, err := json.Marshal(payload) + if err != nil { + return fmt.Errorf("failed to marshal message: %w", err) + } + + if err := c.conn.WriteMessage(websocket.TextMessage, data); err != nil { + return fmt.Errorf("failed to send message: %w", err) + } + + return nil +} + +func (c *WhatsAppChannel) listen(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + default: + c.mu.Lock() + conn := c.conn + c.mu.Unlock() + + if conn == nil { + time.Sleep(1 * time.Second) + continue + } + + _, message, err := conn.ReadMessage() + if err != nil { + log.Printf("WhatsApp read error: %v", err) + time.Sleep(2 * time.Second) + continue + } + + var msg map[string]interface{} + if err := json.Unmarshal(message, &msg); err != nil { + log.Printf("Failed to unmarshal WhatsApp message: %v", err) + continue + } + + msgType, ok := msg["type"].(string) + if !ok { + continue + } + + if msgType == "message" { + c.handleIncomingMessage(msg) + } + } + } +} + +func (c *WhatsAppChannel) handleIncomingMessage(msg map[string]interface{}) { + senderID, ok := msg["from"].(string) + if !ok { + return + } + + chatID, ok := msg["chat"].(string) + if !ok { + chatID = senderID + } + + content, ok := msg["content"].(string) + if !ok { + content = "" + } + + var mediaPaths []string + if mediaData, ok := msg["media"].([]interface{}); ok { + mediaPaths = make([]string, 0, len(mediaData)) + for _, m := range mediaData { + if path, ok := m.(string); ok { + mediaPaths = append(mediaPaths, path) + } + } + } + + metadata := make(map[string]string) + if messageID, ok := msg["id"].(string); ok { + metadata["message_id"] = messageID + } + if userName, ok := msg["from_name"].(string); ok { + metadata["user_name"] = userName + } + + if 
chatID == senderID { + metadata["peer_kind"] = "direct" + metadata["peer_id"] = senderID + } else { + metadata["peer_kind"] = "group" + metadata["peer_id"] = chatID + } + + log.Printf("WhatsApp message from %s: %s...", senderID, utils.Truncate(content, 50)) + + c.HandleMessage(senderID, chatID, content, mediaPaths, metadata) +} From 952ae91501c34cfa290d555c41d5a3b90e6e862e Mon Sep 17 00:00:00 2001 From: Hoshina Date: Fri, 20 Feb 2026 23:26:33 +0800 Subject: [PATCH 04/28] refactor(channels): remove old channel files from parent package --- pkg/channels/dingtalk.go | 204 ------ pkg/channels/discord.go | 373 ---------- pkg/channels/feishu_32.go | 38 - pkg/channels/feishu_64.go | 227 ------ pkg/channels/line.go | 606 ---------------- pkg/channels/maixcam.go | 243 ------- pkg/channels/onebot.go | 982 ------------------------- pkg/channels/qq.go | 247 ------- pkg/channels/slack.go | 443 ------------ pkg/channels/slack_test.go | 174 ----- pkg/channels/telegram.go | 529 -------------- pkg/channels/telegram_commands.go | 156 ---- pkg/channels/wecom.go | 605 ---------------- pkg/channels/wecom_app.go | 639 ----------------- pkg/channels/wecom_app_test.go | 1104 ----------------------------- pkg/channels/wecom_test.go | 785 -------------------- pkg/channels/whatsapp.go | 192 ----- 17 files changed, 7547 deletions(-) delete mode 100644 pkg/channels/dingtalk.go delete mode 100644 pkg/channels/discord.go delete mode 100644 pkg/channels/feishu_32.go delete mode 100644 pkg/channels/feishu_64.go delete mode 100644 pkg/channels/line.go delete mode 100644 pkg/channels/maixcam.go delete mode 100644 pkg/channels/onebot.go delete mode 100644 pkg/channels/qq.go delete mode 100644 pkg/channels/slack.go delete mode 100644 pkg/channels/slack_test.go delete mode 100644 pkg/channels/telegram.go delete mode 100644 pkg/channels/telegram_commands.go delete mode 100644 pkg/channels/wecom.go delete mode 100644 pkg/channels/wecom_app.go delete mode 100644 pkg/channels/wecom_app_test.go delete mode 
100644 pkg/channels/wecom_test.go delete mode 100644 pkg/channels/whatsapp.go diff --git a/pkg/channels/dingtalk.go b/pkg/channels/dingtalk.go deleted file mode 100644 index 662fba3b7..000000000 --- a/pkg/channels/dingtalk.go +++ /dev/null @@ -1,204 +0,0 @@ -// PicoClaw - Ultra-lightweight personal AI agent -// DingTalk channel implementation using Stream Mode - -package channels - -import ( - "context" - "fmt" - "sync" - - "github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot" - "github.com/open-dingtalk/dingtalk-stream-sdk-go/client" - - "github.com/sipeed/picoclaw/pkg/bus" - "github.com/sipeed/picoclaw/pkg/config" - "github.com/sipeed/picoclaw/pkg/logger" - "github.com/sipeed/picoclaw/pkg/utils" -) - -// DingTalkChannel implements the Channel interface for DingTalk (钉钉) -// It uses WebSocket for receiving messages via stream mode and API for sending -type DingTalkChannel struct { - *BaseChannel - config config.DingTalkConfig - clientID string - clientSecret string - streamClient *client.StreamClient - ctx context.Context - cancel context.CancelFunc - // Map to store session webhooks for each chat - sessionWebhooks sync.Map // chatID -> sessionWebhook -} - -// NewDingTalkChannel creates a new DingTalk channel instance -func NewDingTalkChannel(cfg config.DingTalkConfig, messageBus *bus.MessageBus) (*DingTalkChannel, error) { - if cfg.ClientID == "" || cfg.ClientSecret == "" { - return nil, fmt.Errorf("dingtalk client_id and client_secret are required") - } - - base := NewBaseChannel("dingtalk", cfg, messageBus, cfg.AllowFrom) - - return &DingTalkChannel{ - BaseChannel: base, - config: cfg, - clientID: cfg.ClientID, - clientSecret: cfg.ClientSecret, - }, nil -} - -// Start initializes the DingTalk channel with Stream Mode -func (c *DingTalkChannel) Start(ctx context.Context) error { - logger.InfoC("dingtalk", "Starting DingTalk channel (Stream Mode)...") - - c.ctx, c.cancel = context.WithCancel(ctx) - - // Create credential config - cred := 
client.NewAppCredentialConfig(c.clientID, c.clientSecret) - - // Create the stream client with options - c.streamClient = client.NewStreamClient( - client.WithAppCredential(cred), - client.WithAutoReconnect(true), - ) - - // Register chatbot callback handler (IChatBotMessageHandler is a function type) - c.streamClient.RegisterChatBotCallbackRouter(c.onChatBotMessageReceived) - - // Start the stream client - if err := c.streamClient.Start(c.ctx); err != nil { - return fmt.Errorf("failed to start stream client: %w", err) - } - - c.setRunning(true) - logger.InfoC("dingtalk", "DingTalk channel started (Stream Mode)") - return nil -} - -// Stop gracefully stops the DingTalk channel -func (c *DingTalkChannel) Stop(ctx context.Context) error { - logger.InfoC("dingtalk", "Stopping DingTalk channel...") - - if c.cancel != nil { - c.cancel() - } - - if c.streamClient != nil { - c.streamClient.Close() - } - - c.setRunning(false) - logger.InfoC("dingtalk", "DingTalk channel stopped") - return nil -} - -// Send sends a message to DingTalk via the chatbot reply API -func (c *DingTalkChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { - if !c.IsRunning() { - return fmt.Errorf("dingtalk channel not running") - } - - // Get session webhook from storage - sessionWebhookRaw, ok := c.sessionWebhooks.Load(msg.ChatID) - if !ok { - return fmt.Errorf("no session_webhook found for chat %s, cannot send message", msg.ChatID) - } - - sessionWebhook, ok := sessionWebhookRaw.(string) - if !ok { - return fmt.Errorf("invalid session_webhook type for chat %s", msg.ChatID) - } - - logger.DebugCF("dingtalk", "Sending message", map[string]any{ - "chat_id": msg.ChatID, - "preview": utils.Truncate(msg.Content, 100), - }) - - // Use the session webhook to send the reply - return c.SendDirectReply(ctx, sessionWebhook, msg.Content) -} - -// onChatBotMessageReceived implements the IChatBotMessageHandler function signature -// This is called by the Stream SDK when a new message arrives -// 
IChatBotMessageHandler is: func(c context.Context, data *chatbot.BotCallbackDataModel) ([]byte, error) -func (c *DingTalkChannel) onChatBotMessageReceived( - ctx context.Context, - data *chatbot.BotCallbackDataModel, -) ([]byte, error) { - // Extract message content from Text field - content := data.Text.Content - if content == "" { - // Try to extract from Content interface{} if Text is empty - if contentMap, ok := data.Content.(map[string]any); ok { - if textContent, ok := contentMap["content"].(string); ok { - content = textContent - } - } - } - - if content == "" { - return nil, nil // Ignore empty messages - } - - senderID := data.SenderStaffId - senderNick := data.SenderNick - chatID := senderID - if data.ConversationType != "1" { - // For group chats - chatID = data.ConversationId - } - - // Store the session webhook for this chat so we can reply later - c.sessionWebhooks.Store(chatID, data.SessionWebhook) - - metadata := map[string]string{ - "sender_name": senderNick, - "conversation_id": data.ConversationId, - "conversation_type": data.ConversationType, - "platform": "dingtalk", - "session_webhook": data.SessionWebhook, - } - - if data.ConversationType == "1" { - metadata["peer_kind"] = "direct" - metadata["peer_id"] = senderID - } else { - metadata["peer_kind"] = "group" - metadata["peer_id"] = data.ConversationId - } - - logger.DebugCF("dingtalk", "Received message", map[string]any{ - "sender_nick": senderNick, - "sender_id": senderID, - "preview": utils.Truncate(content, 50), - }) - - // Handle the message through the base channel - c.HandleMessage(senderID, chatID, content, nil, metadata) - - // Return nil to indicate we've handled the message asynchronously - // The response will be sent through the message bus - return nil, nil -} - -// SendDirectReply sends a direct reply using the session webhook -func (c *DingTalkChannel) SendDirectReply(ctx context.Context, sessionWebhook, content string) error { - replier := chatbot.NewChatbotReplier() - - // 
Convert string content to []byte for the API - contentBytes := []byte(content) - titleBytes := []byte("PicoClaw") - - // Send markdown formatted reply - err := replier.SimpleReplyMarkdown( - ctx, - sessionWebhook, - titleBytes, - contentBytes, - ) - if err != nil { - return fmt.Errorf("failed to send reply: %w", err) - } - - return nil -} diff --git a/pkg/channels/discord.go b/pkg/channels/discord.go deleted file mode 100644 index 20f3b267c..000000000 --- a/pkg/channels/discord.go +++ /dev/null @@ -1,373 +0,0 @@ -package channels - -import ( - "context" - "fmt" - "os" - "strings" - "sync" - "time" - - "github.com/bwmarrin/discordgo" - - "github.com/sipeed/picoclaw/pkg/bus" - "github.com/sipeed/picoclaw/pkg/config" - "github.com/sipeed/picoclaw/pkg/logger" - "github.com/sipeed/picoclaw/pkg/utils" - "github.com/sipeed/picoclaw/pkg/voice" -) - -const ( - transcriptionTimeout = 30 * time.Second - sendTimeout = 10 * time.Second -) - -type DiscordChannel struct { - *BaseChannel - session *discordgo.Session - config config.DiscordConfig - transcriber *voice.GroqTranscriber - ctx context.Context - typingMu sync.Mutex - typingStop map[string]chan struct{} // chatID → stop signal - botUserID string // stored for mention checking -} - -func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordChannel, error) { - session, err := discordgo.New("Bot " + cfg.Token) - if err != nil { - return nil, fmt.Errorf("failed to create discord session: %w", err) - } - - base := NewBaseChannel("discord", cfg, bus, cfg.AllowFrom) - - return &DiscordChannel{ - BaseChannel: base, - session: session, - config: cfg, - transcriber: nil, - ctx: context.Background(), - typingStop: make(map[string]chan struct{}), - }, nil -} - -func (c *DiscordChannel) SetTranscriber(transcriber *voice.GroqTranscriber) { - c.transcriber = transcriber -} - -func (c *DiscordChannel) getContext() context.Context { - if c.ctx == nil { - return context.Background() - } - return c.ctx -} - -func (c 
*DiscordChannel) Start(ctx context.Context) error { - logger.InfoC("discord", "Starting Discord bot") - - c.ctx = ctx - - // Get bot user ID before opening session to avoid race condition - botUser, err := c.session.User("@me") - if err != nil { - return fmt.Errorf("failed to get bot user: %w", err) - } - c.botUserID = botUser.ID - - c.session.AddHandler(c.handleMessage) - - if err := c.session.Open(); err != nil { - return fmt.Errorf("failed to open discord session: %w", err) - } - - c.setRunning(true) - - logger.InfoCF("discord", "Discord bot connected", map[string]any{ - "username": botUser.Username, - "user_id": botUser.ID, - }) - - return nil -} - -func (c *DiscordChannel) Stop(ctx context.Context) error { - logger.InfoC("discord", "Stopping Discord bot") - c.setRunning(false) - - // Stop all typing goroutines before closing session - c.typingMu.Lock() - for chatID, stop := range c.typingStop { - close(stop) - delete(c.typingStop, chatID) - } - c.typingMu.Unlock() - - if err := c.session.Close(); err != nil { - return fmt.Errorf("failed to close discord session: %w", err) - } - - return nil -} - -func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { - c.stopTyping(msg.ChatID) - - if !c.IsRunning() { - return fmt.Errorf("discord bot not running") - } - - channelID := msg.ChatID - if channelID == "" { - return fmt.Errorf("channel ID is empty") - } - - runes := []rune(msg.Content) - if len(runes) == 0 { - return nil - } - - chunks := utils.SplitMessage(msg.Content, 2000) // Split messages into chunks, Discord length limit: 2000 chars - - for _, chunk := range chunks { - if err := c.sendChunk(ctx, channelID, chunk); err != nil { - return err - } - } - - return nil -} - -func (c *DiscordChannel) sendChunk(ctx context.Context, channelID, content string) error { - // Use the passed ctx for timeout control - sendCtx, cancel := context.WithTimeout(ctx, sendTimeout) - defer cancel() - - done := make(chan error, 1) - go func() { - _, err := 
c.session.ChannelMessageSend(channelID, content) - done <- err - }() - - select { - case err := <-done: - if err != nil { - return fmt.Errorf("failed to send discord message: %w", err) - } - return nil - case <-sendCtx.Done(): - return fmt.Errorf("send message timeout: %w", sendCtx.Err()) - } -} - -// appendContent safely appends content to existing text -func appendContent(content, suffix string) string { - if content == "" { - return suffix - } - return content + "\n" + suffix -} - -func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.MessageCreate) { - if m == nil || m.Author == nil { - return - } - - if m.Author.ID == s.State.User.ID { - return - } - - // Check allowlist first to avoid downloading attachments and transcribing for rejected users - if !c.IsAllowed(m.Author.ID) { - logger.DebugCF("discord", "Message rejected by allowlist", map[string]any{ - "user_id": m.Author.ID, - }) - return - } - - // If configured to only respond to mentions, check if bot is mentioned - // Skip this check for DMs (GuildID is empty) - DMs should always be responded to - if c.config.MentionOnly && m.GuildID != "" { - isMentioned := false - for _, mention := range m.Mentions { - if mention.ID == c.botUserID { - isMentioned = true - break - } - } - if !isMentioned { - logger.DebugCF("discord", "Message ignored - bot not mentioned", map[string]any{ - "user_id": m.Author.ID, - }) - return - } - } - - senderID := m.Author.ID - senderName := m.Author.Username - if m.Author.Discriminator != "" && m.Author.Discriminator != "0" { - senderName += "#" + m.Author.Discriminator - } - - content := m.Content - content = c.stripBotMention(content) - mediaPaths := make([]string, 0, len(m.Attachments)) - localFiles := make([]string, 0, len(m.Attachments)) - - // Ensure temp files are cleaned up when function returns - defer func() { - for _, file := range localFiles { - if err := os.Remove(file); err != nil { - logger.DebugCF("discord", "Failed to cleanup temp file", 
map[string]any{ - "file": file, - "error": err.Error(), - }) - } - } - }() - - for _, attachment := range m.Attachments { - isAudio := utils.IsAudioFile(attachment.Filename, attachment.ContentType) - - if isAudio { - localPath := c.downloadAttachment(attachment.URL, attachment.Filename) - if localPath != "" { - localFiles = append(localFiles, localPath) - - transcribedText := "" - if c.transcriber != nil && c.transcriber.IsAvailable() { - ctx, cancel := context.WithTimeout(c.getContext(), transcriptionTimeout) - result, err := c.transcriber.Transcribe(ctx, localPath) - cancel() // Release context resources immediately to avoid leaks in for loop - - if err != nil { - logger.ErrorCF("discord", "Voice transcription failed", map[string]any{ - "error": err.Error(), - }) - transcribedText = fmt.Sprintf("[audio: %s (transcription failed)]", attachment.Filename) - } else { - transcribedText = fmt.Sprintf("[audio transcription: %s]", result.Text) - logger.DebugCF("discord", "Audio transcribed successfully", map[string]any{ - "text": result.Text, - }) - } - } else { - transcribedText = fmt.Sprintf("[audio: %s]", attachment.Filename) - } - - content = appendContent(content, transcribedText) - } else { - logger.WarnCF("discord", "Failed to download audio attachment", map[string]any{ - "url": attachment.URL, - "filename": attachment.Filename, - }) - mediaPaths = append(mediaPaths, attachment.URL) - content = appendContent(content, fmt.Sprintf("[attachment: %s]", attachment.URL)) - } - } else { - mediaPaths = append(mediaPaths, attachment.URL) - content = appendContent(content, fmt.Sprintf("[attachment: %s]", attachment.URL)) - } - } - - if content == "" && len(mediaPaths) == 0 { - return - } - - if content == "" { - content = "[media only]" - } - - // Start typing after all early returns — guaranteed to have a matching Send() - c.startTyping(m.ChannelID) - - logger.DebugCF("discord", "Received message", map[string]any{ - "sender_name": senderName, - "sender_id": senderID, - 
"preview": utils.Truncate(content, 50), - }) - - peerKind := "channel" - peerID := m.ChannelID - if m.GuildID == "" { - peerKind = "direct" - peerID = senderID - } - - metadata := map[string]string{ - "message_id": m.ID, - "user_id": senderID, - "username": m.Author.Username, - "display_name": senderName, - "guild_id": m.GuildID, - "channel_id": m.ChannelID, - "is_dm": fmt.Sprintf("%t", m.GuildID == ""), - "peer_kind": peerKind, - "peer_id": peerID, - } - - c.HandleMessage(senderID, m.ChannelID, content, mediaPaths, metadata) -} - -// startTyping starts a continuous typing indicator loop for the given chatID. -// It stops any existing typing loop for that chatID before starting a new one. -func (c *DiscordChannel) startTyping(chatID string) { - c.typingMu.Lock() - // Stop existing loop for this chatID if any - if stop, ok := c.typingStop[chatID]; ok { - close(stop) - } - stop := make(chan struct{}) - c.typingStop[chatID] = stop - c.typingMu.Unlock() - - go func() { - if err := c.session.ChannelTyping(chatID); err != nil { - logger.DebugCF("discord", "ChannelTyping error", map[string]any{"chatID": chatID, "err": err}) - } - ticker := time.NewTicker(8 * time.Second) - defer ticker.Stop() - timeout := time.After(5 * time.Minute) - for { - select { - case <-stop: - return - case <-timeout: - return - case <-c.ctx.Done(): - return - case <-ticker.C: - if err := c.session.ChannelTyping(chatID); err != nil { - logger.DebugCF("discord", "ChannelTyping error", map[string]any{"chatID": chatID, "err": err}) - } - } - } - }() -} - -// stopTyping stops the typing indicator loop for the given chatID. 
-func (c *DiscordChannel) stopTyping(chatID string) { - c.typingMu.Lock() - defer c.typingMu.Unlock() - if stop, ok := c.typingStop[chatID]; ok { - close(stop) - delete(c.typingStop, chatID) - } -} - -func (c *DiscordChannel) downloadAttachment(url, filename string) string { - return utils.DownloadFile(url, filename, utils.DownloadOptions{ - LoggerPrefix: "discord", - }) -} - -// stripBotMention removes the bot mention from the message content. -// Discord mentions have the format <@USER_ID> or <@!USER_ID> (with nickname). -func (c *DiscordChannel) stripBotMention(text string) string { - if c.botUserID == "" { - return text - } - // Remove both regular mention <@USER_ID> and nickname mention <@!USER_ID> - text = strings.ReplaceAll(text, fmt.Sprintf("<@%s>", c.botUserID), "") - text = strings.ReplaceAll(text, fmt.Sprintf("<@!%s>", c.botUserID), "") - return strings.TrimSpace(text) -} diff --git a/pkg/channels/feishu_32.go b/pkg/channels/feishu_32.go deleted file mode 100644 index 5109b8195..000000000 --- a/pkg/channels/feishu_32.go +++ /dev/null @@ -1,38 +0,0 @@ -//go:build !amd64 && !arm64 && !riscv64 && !mips64 && !ppc64 - -package channels - -import ( - "context" - "errors" - - "github.com/sipeed/picoclaw/pkg/bus" - "github.com/sipeed/picoclaw/pkg/config" -) - -// FeishuChannel is a stub implementation for 32-bit architectures -type FeishuChannel struct { - *BaseChannel -} - -// NewFeishuChannel returns an error on 32-bit architectures where the Feishu SDK is not supported -func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChannel, error) { - return nil, errors.New( - "feishu channel is not supported on 32-bit architectures (armv7l, 386, etc.). 
Please use a 64-bit system or disable feishu in your config", - ) -} - -// Start is a stub method to satisfy the Channel interface -func (c *FeishuChannel) Start(ctx context.Context) error { - return nil -} - -// Stop is a stub method to satisfy the Channel interface -func (c *FeishuChannel) Stop(ctx context.Context) error { - return nil -} - -// Send is a stub method to satisfy the Channel interface -func (c *FeishuChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { - return errors.New("feishu channel is not supported on 32-bit architectures") -} diff --git a/pkg/channels/feishu_64.go b/pkg/channels/feishu_64.go deleted file mode 100644 index 42e74980f..000000000 --- a/pkg/channels/feishu_64.go +++ /dev/null @@ -1,227 +0,0 @@ -//go:build amd64 || arm64 || riscv64 || mips64 || ppc64 - -package channels - -import ( - "context" - "encoding/json" - "fmt" - "sync" - "time" - - lark "github.com/larksuite/oapi-sdk-go/v3" - larkdispatcher "github.com/larksuite/oapi-sdk-go/v3/event/dispatcher" - larkim "github.com/larksuite/oapi-sdk-go/v3/service/im/v1" - larkws "github.com/larksuite/oapi-sdk-go/v3/ws" - - "github.com/sipeed/picoclaw/pkg/bus" - "github.com/sipeed/picoclaw/pkg/config" - "github.com/sipeed/picoclaw/pkg/logger" - "github.com/sipeed/picoclaw/pkg/utils" -) - -type FeishuChannel struct { - *BaseChannel - config config.FeishuConfig - client *lark.Client - wsClient *larkws.Client - - mu sync.Mutex - cancel context.CancelFunc -} - -func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChannel, error) { - base := NewBaseChannel("feishu", cfg, bus, cfg.AllowFrom) - - return &FeishuChannel{ - BaseChannel: base, - config: cfg, - client: lark.NewClient(cfg.AppID, cfg.AppSecret), - }, nil -} - -func (c *FeishuChannel) Start(ctx context.Context) error { - if c.config.AppID == "" || c.config.AppSecret == "" { - return fmt.Errorf("feishu app_id or app_secret is empty") - } - - dispatcher := 
larkdispatcher.NewEventDispatcher(c.config.VerificationToken, c.config.EncryptKey). - OnP2MessageReceiveV1(c.handleMessageReceive) - - runCtx, cancel := context.WithCancel(ctx) - - c.mu.Lock() - c.cancel = cancel - c.wsClient = larkws.NewClient( - c.config.AppID, - c.config.AppSecret, - larkws.WithEventHandler(dispatcher), - ) - wsClient := c.wsClient - c.mu.Unlock() - - c.setRunning(true) - logger.InfoC("feishu", "Feishu channel started (websocket mode)") - - go func() { - if err := wsClient.Start(runCtx); err != nil { - logger.ErrorCF("feishu", "Feishu websocket stopped with error", map[string]any{ - "error": err.Error(), - }) - } - }() - - return nil -} - -func (c *FeishuChannel) Stop(ctx context.Context) error { - c.mu.Lock() - if c.cancel != nil { - c.cancel() - c.cancel = nil - } - c.wsClient = nil - c.mu.Unlock() - - c.setRunning(false) - logger.InfoC("feishu", "Feishu channel stopped") - return nil -} - -func (c *FeishuChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { - if !c.IsRunning() { - return fmt.Errorf("feishu channel not running") - } - - if msg.ChatID == "" { - return fmt.Errorf("chat ID is empty") - } - - payload, err := json.Marshal(map[string]string{"text": msg.Content}) - if err != nil { - return fmt.Errorf("failed to marshal feishu content: %w", err) - } - - req := larkim.NewCreateMessageReqBuilder(). - ReceiveIdType(larkim.ReceiveIdTypeChatId). - Body(larkim.NewCreateMessageReqBodyBuilder(). - ReceiveId(msg.ChatID). - MsgType(larkim.MsgTypeText). - Content(string(payload)). - Uuid(fmt.Sprintf("picoclaw-%d", time.Now().UnixNano())). - Build()). 
- Build() - - resp, err := c.client.Im.V1.Message.Create(ctx, req) - if err != nil { - return fmt.Errorf("failed to send feishu message: %w", err) - } - - if !resp.Success() { - return fmt.Errorf("feishu api error: code=%d msg=%s", resp.Code, resp.Msg) - } - - logger.DebugCF("feishu", "Feishu message sent", map[string]any{ - "chat_id": msg.ChatID, - }) - - return nil -} - -func (c *FeishuChannel) handleMessageReceive(_ context.Context, event *larkim.P2MessageReceiveV1) error { - if event == nil || event.Event == nil || event.Event.Message == nil { - return nil - } - - message := event.Event.Message - sender := event.Event.Sender - - chatID := stringValue(message.ChatId) - if chatID == "" { - return nil - } - - senderID := extractFeishuSenderID(sender) - if senderID == "" { - senderID = "unknown" - } - - content := extractFeishuMessageContent(message) - if content == "" { - content = "[empty message]" - } - - metadata := map[string]string{} - if messageID := stringValue(message.MessageId); messageID != "" { - metadata["message_id"] = messageID - } - if messageType := stringValue(message.MessageType); messageType != "" { - metadata["message_type"] = messageType - } - if chatType := stringValue(message.ChatType); chatType != "" { - metadata["chat_type"] = chatType - } - if sender != nil && sender.TenantKey != nil { - metadata["tenant_key"] = *sender.TenantKey - } - - chatType := stringValue(message.ChatType) - if chatType == "p2p" { - metadata["peer_kind"] = "direct" - metadata["peer_id"] = senderID - } else { - metadata["peer_kind"] = "group" - metadata["peer_id"] = chatID - } - - logger.InfoCF("feishu", "Feishu message received", map[string]any{ - "sender_id": senderID, - "chat_id": chatID, - "preview": utils.Truncate(content, 80), - }) - - c.HandleMessage(senderID, chatID, content, nil, metadata) - return nil -} - -func extractFeishuSenderID(sender *larkim.EventSender) string { - if sender == nil || sender.SenderId == nil { - return "" - } - - if 
sender.SenderId.UserId != nil && *sender.SenderId.UserId != "" { - return *sender.SenderId.UserId - } - if sender.SenderId.OpenId != nil && *sender.SenderId.OpenId != "" { - return *sender.SenderId.OpenId - } - if sender.SenderId.UnionId != nil && *sender.SenderId.UnionId != "" { - return *sender.SenderId.UnionId - } - - return "" -} - -func extractFeishuMessageContent(message *larkim.EventMessage) string { - if message == nil || message.Content == nil || *message.Content == "" { - return "" - } - - if message.MessageType != nil && *message.MessageType == larkim.MsgTypeText { - var textPayload struct { - Text string `json:"text"` - } - if err := json.Unmarshal([]byte(*message.Content), &textPayload); err == nil { - return textPayload.Text - } - } - - return *message.Content -} - -func stringValue(v *string) string { - if v == nil { - return "" - } - return *v -} diff --git a/pkg/channels/line.go b/pkg/channels/line.go deleted file mode 100644 index 44134996f..000000000 --- a/pkg/channels/line.go +++ /dev/null @@ -1,606 +0,0 @@ -package channels - -import ( - "bytes" - "context" - "crypto/hmac" - "crypto/sha256" - "encoding/base64" - "encoding/json" - "fmt" - "io" - "net/http" - "os" - "strings" - "sync" - "time" - - "github.com/sipeed/picoclaw/pkg/bus" - "github.com/sipeed/picoclaw/pkg/config" - "github.com/sipeed/picoclaw/pkg/logger" - "github.com/sipeed/picoclaw/pkg/utils" -) - -const ( - lineAPIBase = "https://api.line.me/v2/bot" - lineDataAPIBase = "https://api-data.line.me/v2/bot" - lineReplyEndpoint = lineAPIBase + "/message/reply" - linePushEndpoint = lineAPIBase + "/message/push" - lineContentEndpoint = lineDataAPIBase + "/message/%s/content" - lineBotInfoEndpoint = lineAPIBase + "/info" - lineLoadingEndpoint = lineAPIBase + "/chat/loading/start" - lineReplyTokenMaxAge = 25 * time.Second -) - -type replyTokenEntry struct { - token string - timestamp time.Time -} - -// LINEChannel implements the Channel interface for LINE Official Account -// using the LINE 
Messaging API with HTTP webhook for receiving messages -// and REST API for sending messages. -type LINEChannel struct { - *BaseChannel - config config.LINEConfig - httpServer *http.Server - botUserID string // Bot's user ID - botBasicID string // Bot's basic ID (e.g. @216ru...) - botDisplayName string // Bot's display name for text-based mention detection - replyTokens sync.Map // chatID -> replyTokenEntry - quoteTokens sync.Map // chatID -> quoteToken (string) - ctx context.Context - cancel context.CancelFunc -} - -// NewLINEChannel creates a new LINE channel instance. -func NewLINEChannel(cfg config.LINEConfig, messageBus *bus.MessageBus) (*LINEChannel, error) { - if cfg.ChannelSecret == "" || cfg.ChannelAccessToken == "" { - return nil, fmt.Errorf("line channel_secret and channel_access_token are required") - } - - base := NewBaseChannel("line", cfg, messageBus, cfg.AllowFrom) - - return &LINEChannel{ - BaseChannel: base, - config: cfg, - }, nil -} - -// Start launches the HTTP webhook server. 
-func (c *LINEChannel) Start(ctx context.Context) error { - logger.InfoC("line", "Starting LINE channel (Webhook Mode)") - - c.ctx, c.cancel = context.WithCancel(ctx) - - // Fetch bot profile to get bot's userId for mention detection - if err := c.fetchBotInfo(); err != nil { - logger.WarnCF("line", "Failed to fetch bot info (mention detection disabled)", map[string]any{ - "error": err.Error(), - }) - } else { - logger.InfoCF("line", "Bot info fetched", map[string]any{ - "bot_user_id": c.botUserID, - "basic_id": c.botBasicID, - "display_name": c.botDisplayName, - }) - } - - mux := http.NewServeMux() - path := c.config.WebhookPath - if path == "" { - path = "/webhook/line" - } - mux.HandleFunc(path, c.webhookHandler) - - addr := fmt.Sprintf("%s:%d", c.config.WebhookHost, c.config.WebhookPort) - c.httpServer = &http.Server{ - Addr: addr, - Handler: mux, - } - - go func() { - logger.InfoCF("line", "LINE webhook server listening", map[string]any{ - "addr": addr, - "path": path, - }) - if err := c.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { - logger.ErrorCF("line", "Webhook server error", map[string]any{ - "error": err.Error(), - }) - } - }() - - c.setRunning(true) - logger.InfoC("line", "LINE channel started (Webhook Mode)") - return nil -} - -// fetchBotInfo retrieves the bot's userId, basicId, and displayName from the LINE API. 
-func (c *LINEChannel) fetchBotInfo() error { - req, err := http.NewRequest(http.MethodGet, lineBotInfoEndpoint, nil) - if err != nil { - return err - } - req.Header.Set("Authorization", "Bearer "+c.config.ChannelAccessToken) - - client := &http.Client{Timeout: 10 * time.Second} - resp, err := client.Do(req) - if err != nil { - return err - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("bot info API returned status %d", resp.StatusCode) - } - - var info struct { - UserID string `json:"userId"` - BasicID string `json:"basicId"` - DisplayName string `json:"displayName"` - } - if err := json.NewDecoder(resp.Body).Decode(&info); err != nil { - return err - } - - c.botUserID = info.UserID - c.botBasicID = info.BasicID - c.botDisplayName = info.DisplayName - return nil -} - -// Stop gracefully shuts down the HTTP server. -func (c *LINEChannel) Stop(ctx context.Context) error { - logger.InfoC("line", "Stopping LINE channel") - - if c.cancel != nil { - c.cancel() - } - - if c.httpServer != nil { - shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second) - defer cancel() - if err := c.httpServer.Shutdown(shutdownCtx); err != nil { - logger.ErrorCF("line", "Webhook server shutdown error", map[string]any{ - "error": err.Error(), - }) - } - } - - c.setRunning(false) - logger.InfoC("line", "LINE channel stopped") - return nil -} - -// webhookHandler handles incoming LINE webhook requests. 
-func (c *LINEChannel) webhookHandler(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPost { - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) - return - } - - body, err := io.ReadAll(r.Body) - if err != nil { - logger.ErrorCF("line", "Failed to read request body", map[string]any{ - "error": err.Error(), - }) - http.Error(w, "Bad request", http.StatusBadRequest) - return - } - - signature := r.Header.Get("X-Line-Signature") - if !c.verifySignature(body, signature) { - logger.WarnC("line", "Invalid webhook signature") - http.Error(w, "Forbidden", http.StatusForbidden) - return - } - - var payload struct { - Events []lineEvent `json:"events"` - } - if err := json.Unmarshal(body, &payload); err != nil { - logger.ErrorCF("line", "Failed to parse webhook payload", map[string]any{ - "error": err.Error(), - }) - http.Error(w, "Bad request", http.StatusBadRequest) - return - } - - // Return 200 immediately, process events asynchronously - w.WriteHeader(http.StatusOK) - - for _, event := range payload.Events { - go c.processEvent(event) - } -} - -// verifySignature validates the X-Line-Signature using HMAC-SHA256. 
-func (c *LINEChannel) verifySignature(body []byte, signature string) bool { - if signature == "" { - return false - } - - mac := hmac.New(sha256.New, []byte(c.config.ChannelSecret)) - mac.Write(body) - expected := base64.StdEncoding.EncodeToString(mac.Sum(nil)) - - return hmac.Equal([]byte(expected), []byte(signature)) -} - -// LINE webhook event types -type lineEvent struct { - Type string `json:"type"` - ReplyToken string `json:"replyToken"` - Source lineSource `json:"source"` - Message json.RawMessage `json:"message"` - Timestamp int64 `json:"timestamp"` -} - -type lineSource struct { - Type string `json:"type"` // "user", "group", "room" - UserID string `json:"userId"` - GroupID string `json:"groupId"` - RoomID string `json:"roomId"` -} - -type lineMessage struct { - ID string `json:"id"` - Type string `json:"type"` // "text", "image", "video", "audio", "file", "sticker" - Text string `json:"text"` - QuoteToken string `json:"quoteToken"` - Mention *struct { - Mentionees []lineMentionee `json:"mentionees"` - } `json:"mention"` - ContentProvider struct { - Type string `json:"type"` - } `json:"contentProvider"` -} - -type lineMentionee struct { - Index int `json:"index"` - Length int `json:"length"` - Type string `json:"type"` // "user", "all" - UserID string `json:"userId"` -} - -func (c *LINEChannel) processEvent(event lineEvent) { - if event.Type != "message" { - logger.DebugCF("line", "Ignoring non-message event", map[string]any{ - "type": event.Type, - }) - return - } - - senderID := event.Source.UserID - chatID := c.resolveChatID(event.Source) - isGroup := event.Source.Type == "group" || event.Source.Type == "room" - - var msg lineMessage - if err := json.Unmarshal(event.Message, &msg); err != nil { - logger.ErrorCF("line", "Failed to parse message", map[string]any{ - "error": err.Error(), - }) - return - } - - // In group chats, only respond when the bot is mentioned - if isGroup && !c.isBotMentioned(msg) { - logger.DebugCF("line", "Ignoring group message 
without mention", map[string]any{ - "chat_id": chatID, - }) - return - } - - // Store reply token for later use - if event.ReplyToken != "" { - c.replyTokens.Store(chatID, replyTokenEntry{ - token: event.ReplyToken, - timestamp: time.Now(), - }) - } - - // Store quote token for quoting the original message in reply - if msg.QuoteToken != "" { - c.quoteTokens.Store(chatID, msg.QuoteToken) - } - - var content string - var mediaPaths []string - localFiles := []string{} - - defer func() { - for _, file := range localFiles { - if err := os.Remove(file); err != nil { - logger.DebugCF("line", "Failed to cleanup temp file", map[string]any{ - "file": file, - "error": err.Error(), - }) - } - } - }() - - switch msg.Type { - case "text": - content = msg.Text - // Strip bot mention from text in group chats - if isGroup { - content = c.stripBotMention(content, msg) - } - case "image": - localPath := c.downloadContent(msg.ID, "image.jpg") - if localPath != "" { - localFiles = append(localFiles, localPath) - mediaPaths = append(mediaPaths, localPath) - content = "[image]" - } - case "audio": - localPath := c.downloadContent(msg.ID, "audio.m4a") - if localPath != "" { - localFiles = append(localFiles, localPath) - mediaPaths = append(mediaPaths, localPath) - content = "[audio]" - } - case "video": - localPath := c.downloadContent(msg.ID, "video.mp4") - if localPath != "" { - localFiles = append(localFiles, localPath) - mediaPaths = append(mediaPaths, localPath) - content = "[video]" - } - case "file": - content = "[file]" - case "sticker": - content = "[sticker]" - default: - content = fmt.Sprintf("[%s]", msg.Type) - } - - if strings.TrimSpace(content) == "" { - return - } - - metadata := map[string]string{ - "platform": "line", - "source_type": event.Source.Type, - "message_id": msg.ID, - } - - if isGroup { - metadata["peer_kind"] = "group" - metadata["peer_id"] = chatID - } else { - metadata["peer_kind"] = "direct" - metadata["peer_id"] = senderID - } - - logger.DebugCF("line", 
"Received message", map[string]any{ - "sender_id": senderID, - "chat_id": chatID, - "message_type": msg.Type, - "is_group": isGroup, - "preview": utils.Truncate(content, 50), - }) - - // Show typing/loading indicator (requires user ID, not group ID) - c.sendLoading(senderID) - - c.HandleMessage(senderID, chatID, content, mediaPaths, metadata) -} - -// isBotMentioned checks if the bot is mentioned in the message. -// It first checks the mention metadata (userId match), then falls back -// to text-based detection using the bot's display name, since LINE may -// not include userId in mentionees for Official Accounts. -func (c *LINEChannel) isBotMentioned(msg lineMessage) bool { - // Check mention metadata - if msg.Mention != nil { - for _, m := range msg.Mention.Mentionees { - if m.Type == "all" { - return true - } - if c.botUserID != "" && m.UserID == c.botUserID { - return true - } - } - // Mention metadata exists with mentionees but bot not matched by userId. - // The bot IS likely mentioned (LINE includes mention struct when bot is @-ed), - // so check if any mentionee overlaps with bot display name in text. - if c.botDisplayName != "" { - for _, m := range msg.Mention.Mentionees { - if m.Index >= 0 && m.Length > 0 { - runes := []rune(msg.Text) - end := m.Index + m.Length - if end <= len(runes) { - mentionText := string(runes[m.Index:end]) - if strings.Contains(mentionText, c.botDisplayName) { - return true - } - } - } - } - } - } - - // Fallback: text-based detection with display name - if c.botDisplayName != "" && strings.Contains(msg.Text, "@"+c.botDisplayName) { - return true - } - - return false -} - -// stripBotMention removes the @BotName mention text from the message. 
-func (c *LINEChannel) stripBotMention(text string, msg lineMessage) string { - stripped := false - - // Try to strip using mention metadata indices - if msg.Mention != nil { - runes := []rune(text) - for i := len(msg.Mention.Mentionees) - 1; i >= 0; i-- { - m := msg.Mention.Mentionees[i] - // Strip if userId matches OR if the mention text contains the bot display name - shouldStrip := false - if c.botUserID != "" && m.UserID == c.botUserID { - shouldStrip = true - } else if c.botDisplayName != "" && m.Index >= 0 && m.Length > 0 { - end := m.Index + m.Length - if end <= len(runes) { - mentionText := string(runes[m.Index:end]) - if strings.Contains(mentionText, c.botDisplayName) { - shouldStrip = true - } - } - } - if shouldStrip { - start := m.Index - end := m.Index + m.Length - if start >= 0 && end <= len(runes) { - runes = append(runes[:start], runes[end:]...) - stripped = true - } - } - } - if stripped { - return strings.TrimSpace(string(runes)) - } - } - - // Fallback: strip @DisplayName from text - if c.botDisplayName != "" { - text = strings.ReplaceAll(text, "@"+c.botDisplayName, "") - } - - return strings.TrimSpace(text) -} - -// resolveChatID determines the chat ID from the event source. -// For group/room messages, use the group/room ID; for 1:1, use the user ID. -func (c *LINEChannel) resolveChatID(source lineSource) string { - switch source.Type { - case "group": - return source.GroupID - case "room": - return source.RoomID - default: - return source.UserID - } -} - -// Send sends a message to LINE. It first tries the Reply API (free) -// using a cached reply token, then falls back to the Push API. 
-func (c *LINEChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { - if !c.IsRunning() { - return fmt.Errorf("line channel not running") - } - - // Load and consume quote token for this chat - var quoteToken string - if qt, ok := c.quoteTokens.LoadAndDelete(msg.ChatID); ok { - quoteToken = qt.(string) - } - - // Try reply token first (free, valid for ~25 seconds) - if entry, ok := c.replyTokens.LoadAndDelete(msg.ChatID); ok { - tokenEntry := entry.(replyTokenEntry) - if time.Since(tokenEntry.timestamp) < lineReplyTokenMaxAge { - if err := c.sendReply(ctx, tokenEntry.token, msg.Content, quoteToken); err == nil { - logger.DebugCF("line", "Message sent via Reply API", map[string]any{ - "chat_id": msg.ChatID, - "quoted": quoteToken != "", - }) - return nil - } - logger.DebugC("line", "Reply API failed, falling back to Push API") - } - } - - // Fall back to Push API - return c.sendPush(ctx, msg.ChatID, msg.Content, quoteToken) -} - -// buildTextMessage creates a text message object, optionally with quoteToken. -func buildTextMessage(content, quoteToken string) map[string]string { - msg := map[string]string{ - "type": "text", - "text": content, - } - if quoteToken != "" { - msg["quoteToken"] = quoteToken - } - return msg -} - -// sendReply sends a message using the LINE Reply API. -func (c *LINEChannel) sendReply(ctx context.Context, replyToken, content, quoteToken string) error { - payload := map[string]any{ - "replyToken": replyToken, - "messages": []map[string]string{buildTextMessage(content, quoteToken)}, - } - - return c.callAPI(ctx, lineReplyEndpoint, payload) -} - -// sendPush sends a message using the LINE Push API. -func (c *LINEChannel) sendPush(ctx context.Context, to, content, quoteToken string) error { - payload := map[string]any{ - "to": to, - "messages": []map[string]string{buildTextMessage(content, quoteToken)}, - } - - return c.callAPI(ctx, linePushEndpoint, payload) -} - -// sendLoading sends a loading animation indicator to the chat. 
-func (c *LINEChannel) sendLoading(chatID string) { - payload := map[string]any{ - "chatId": chatID, - "loadingSeconds": 60, - } - if err := c.callAPI(c.ctx, lineLoadingEndpoint, payload); err != nil { - logger.DebugCF("line", "Failed to send loading indicator", map[string]any{ - "error": err.Error(), - }) - } -} - -// callAPI makes an authenticated POST request to the LINE API. -func (c *LINEChannel) callAPI(ctx context.Context, endpoint string, payload any) error { - body, err := json.Marshal(payload) - if err != nil { - return fmt.Errorf("failed to marshal payload: %w", err) - } - - req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body)) - if err != nil { - return fmt.Errorf("failed to create request: %w", err) - } - - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+c.config.ChannelAccessToken) - - client := &http.Client{Timeout: 30 * time.Second} - resp, err := client.Do(req) - if err != nil { - return fmt.Errorf("API request failed: %w", err) - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - respBody, _ := io.ReadAll(resp.Body) - return fmt.Errorf("LINE API error (status %d): %s", resp.StatusCode, string(respBody)) - } - - return nil -} - -// downloadContent downloads media content from the LINE API. 
-func (c *LINEChannel) downloadContent(messageID, filename string) string { - url := fmt.Sprintf(lineContentEndpoint, messageID) - return utils.DownloadFile(url, filename, utils.DownloadOptions{ - LoggerPrefix: "line", - ExtraHeaders: map[string]string{ - "Authorization": "Bearer " + c.config.ChannelAccessToken, - }, - }) -} diff --git a/pkg/channels/maixcam.go b/pkg/channels/maixcam.go deleted file mode 100644 index 34ce62b20..000000000 --- a/pkg/channels/maixcam.go +++ /dev/null @@ -1,243 +0,0 @@ -package channels - -import ( - "context" - "encoding/json" - "fmt" - "net" - "sync" - - "github.com/sipeed/picoclaw/pkg/bus" - "github.com/sipeed/picoclaw/pkg/config" - "github.com/sipeed/picoclaw/pkg/logger" -) - -type MaixCamChannel struct { - *BaseChannel - config config.MaixCamConfig - listener net.Listener - clients map[net.Conn]bool - clientsMux sync.RWMutex -} - -type MaixCamMessage struct { - Type string `json:"type"` - Tips string `json:"tips"` - Timestamp float64 `json:"timestamp"` - Data map[string]any `json:"data"` -} - -func NewMaixCamChannel(cfg config.MaixCamConfig, bus *bus.MessageBus) (*MaixCamChannel, error) { - base := NewBaseChannel("maixcam", cfg, bus, cfg.AllowFrom) - - return &MaixCamChannel{ - BaseChannel: base, - config: cfg, - clients: make(map[net.Conn]bool), - }, nil -} - -func (c *MaixCamChannel) Start(ctx context.Context) error { - logger.InfoC("maixcam", "Starting MaixCam channel server") - - addr := fmt.Sprintf("%s:%d", c.config.Host, c.config.Port) - listener, err := net.Listen("tcp", addr) - if err != nil { - return fmt.Errorf("failed to listen on %s: %w", addr, err) - } - - c.listener = listener - c.setRunning(true) - - logger.InfoCF("maixcam", "MaixCam server listening", map[string]any{ - "host": c.config.Host, - "port": c.config.Port, - }) - - go c.acceptConnections(ctx) - - return nil -} - -func (c *MaixCamChannel) acceptConnections(ctx context.Context) { - logger.DebugC("maixcam", "Starting connection acceptor") - - for { - select 
{ - case <-ctx.Done(): - logger.InfoC("maixcam", "Stopping connection acceptor") - return - default: - conn, err := c.listener.Accept() - if err != nil { - if c.running { - logger.ErrorCF("maixcam", "Failed to accept connection", map[string]any{ - "error": err.Error(), - }) - } - return - } - - logger.InfoCF("maixcam", "New connection from MaixCam device", map[string]any{ - "remote_addr": conn.RemoteAddr().String(), - }) - - c.clientsMux.Lock() - c.clients[conn] = true - c.clientsMux.Unlock() - - go c.handleConnection(conn, ctx) - } - } -} - -func (c *MaixCamChannel) handleConnection(conn net.Conn, ctx context.Context) { - logger.DebugC("maixcam", "Handling MaixCam connection") - - defer func() { - conn.Close() - c.clientsMux.Lock() - delete(c.clients, conn) - c.clientsMux.Unlock() - logger.DebugC("maixcam", "Connection closed") - }() - - decoder := json.NewDecoder(conn) - - for { - select { - case <-ctx.Done(): - return - default: - var msg MaixCamMessage - if err := decoder.Decode(&msg); err != nil { - if err.Error() != "EOF" { - logger.ErrorCF("maixcam", "Failed to decode message", map[string]any{ - "error": err.Error(), - }) - } - return - } - - c.processMessage(msg, conn) - } - } -} - -func (c *MaixCamChannel) processMessage(msg MaixCamMessage, conn net.Conn) { - switch msg.Type { - case "person_detected": - c.handlePersonDetection(msg) - case "heartbeat": - logger.DebugC("maixcam", "Received heartbeat") - case "status": - c.handleStatusUpdate(msg) - default: - logger.WarnCF("maixcam", "Unknown message type", map[string]any{ - "type": msg.Type, - }) - } -} - -func (c *MaixCamChannel) handlePersonDetection(msg MaixCamMessage) { - logger.InfoCF("maixcam", "", map[string]any{ - "timestamp": msg.Timestamp, - "data": msg.Data, - }) - - senderID := "maixcam" - chatID := "default" - - classInfo, ok := msg.Data["class_name"].(string) - if !ok { - classInfo = "person" - } - - score, _ := msg.Data["score"].(float64) - x, _ := msg.Data["x"].(float64) - y, _ := 
msg.Data["y"].(float64) - w, _ := msg.Data["w"].(float64) - h, _ := msg.Data["h"].(float64) - - content := fmt.Sprintf("📷 Person detected!\nClass: %s\nConfidence: %.2f%%\nPosition: (%.0f, %.0f)\nSize: %.0fx%.0f", - classInfo, score*100, x, y, w, h) - - metadata := map[string]string{ - "timestamp": fmt.Sprintf("%.0f", msg.Timestamp), - "class_id": fmt.Sprintf("%.0f", msg.Data["class_id"]), - "score": fmt.Sprintf("%.2f", score), - "x": fmt.Sprintf("%.0f", x), - "y": fmt.Sprintf("%.0f", y), - "w": fmt.Sprintf("%.0f", w), - "h": fmt.Sprintf("%.0f", h), - "peer_kind": "channel", - "peer_id": "default", - } - - c.HandleMessage(senderID, chatID, content, []string{}, metadata) -} - -func (c *MaixCamChannel) handleStatusUpdate(msg MaixCamMessage) { - logger.InfoCF("maixcam", "Status update from MaixCam", map[string]any{ - "status": msg.Data, - }) -} - -func (c *MaixCamChannel) Stop(ctx context.Context) error { - logger.InfoC("maixcam", "Stopping MaixCam channel") - c.setRunning(false) - - if c.listener != nil { - c.listener.Close() - } - - c.clientsMux.Lock() - defer c.clientsMux.Unlock() - - for conn := range c.clients { - conn.Close() - } - c.clients = make(map[net.Conn]bool) - - logger.InfoC("maixcam", "MaixCam channel stopped") - return nil -} - -func (c *MaixCamChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { - if !c.IsRunning() { - return fmt.Errorf("maixcam channel not running") - } - - c.clientsMux.RLock() - defer c.clientsMux.RUnlock() - - if len(c.clients) == 0 { - logger.WarnC("maixcam", "No MaixCam devices connected") - return fmt.Errorf("no connected MaixCam devices") - } - - response := map[string]any{ - "type": "command", - "timestamp": float64(0), - "message": msg.Content, - "chat_id": msg.ChatID, - } - - data, err := json.Marshal(response) - if err != nil { - return fmt.Errorf("failed to marshal response: %w", err) - } - - var sendErr error - for conn := range c.clients { - if _, err := conn.Write(data); err != nil { - 
logger.ErrorCF("maixcam", "Failed to send to client", map[string]any{ - "client": conn.RemoteAddr().String(), - "error": err.Error(), - }) - sendErr = err - } - } - - return sendErr -} diff --git a/pkg/channels/onebot.go b/pkg/channels/onebot.go deleted file mode 100644 index cee8ad9d3..000000000 --- a/pkg/channels/onebot.go +++ /dev/null @@ -1,982 +0,0 @@ -package channels - -import ( - "context" - "encoding/json" - "fmt" - "os" - "strconv" - "strings" - "sync" - "sync/atomic" - "time" - - "github.com/gorilla/websocket" - - "github.com/sipeed/picoclaw/pkg/bus" - "github.com/sipeed/picoclaw/pkg/config" - "github.com/sipeed/picoclaw/pkg/logger" - "github.com/sipeed/picoclaw/pkg/utils" - "github.com/sipeed/picoclaw/pkg/voice" -) - -type OneBotChannel struct { - *BaseChannel - config config.OneBotConfig - conn *websocket.Conn - ctx context.Context - cancel context.CancelFunc - dedup map[string]struct{} - dedupRing []string - dedupIdx int - mu sync.Mutex - writeMu sync.Mutex - echoCounter int64 - selfID int64 - pending map[string]chan json.RawMessage - pendingMu sync.Mutex - transcriber *voice.GroqTranscriber - lastMessageID sync.Map - pendingEmojiMsg sync.Map -} - -type oneBotRawEvent struct { - PostType string `json:"post_type"` - MessageType string `json:"message_type"` - SubType string `json:"sub_type"` - MessageID json.RawMessage `json:"message_id"` - UserID json.RawMessage `json:"user_id"` - GroupID json.RawMessage `json:"group_id"` - RawMessage string `json:"raw_message"` - Message json.RawMessage `json:"message"` - Sender json.RawMessage `json:"sender"` - SelfID json.RawMessage `json:"self_id"` - Time json.RawMessage `json:"time"` - MetaEventType string `json:"meta_event_type"` - NoticeType string `json:"notice_type"` - Echo string `json:"echo"` - RetCode json.RawMessage `json:"retcode"` - Status json.RawMessage `json:"status"` - Data json.RawMessage `json:"data"` -} - -type BotStatus struct { - Online bool `json:"online"` - Good bool `json:"good"` -} - -func 
isAPIResponse(raw json.RawMessage) bool { - if len(raw) == 0 { - return false - } - var s string - if json.Unmarshal(raw, &s) == nil { - return s == "ok" || s == "failed" - } - var bs BotStatus - if json.Unmarshal(raw, &bs) == nil { - return bs.Online || bs.Good - } - return false -} - -type oneBotSender struct { - UserID json.RawMessage `json:"user_id"` - Nickname string `json:"nickname"` - Card string `json:"card"` -} - -type oneBotAPIRequest struct { - Action string `json:"action"` - Params any `json:"params"` - Echo string `json:"echo,omitempty"` -} - -type oneBotMessageSegment struct { - Type string `json:"type"` - Data map[string]any `json:"data"` -} - -func NewOneBotChannel(cfg config.OneBotConfig, messageBus *bus.MessageBus) (*OneBotChannel, error) { - base := NewBaseChannel("onebot", cfg, messageBus, cfg.AllowFrom) - - const dedupSize = 1024 - return &OneBotChannel{ - BaseChannel: base, - config: cfg, - dedup: make(map[string]struct{}, dedupSize), - dedupRing: make([]string, dedupSize), - dedupIdx: 0, - pending: make(map[string]chan json.RawMessage), - }, nil -} - -func (c *OneBotChannel) SetTranscriber(transcriber *voice.GroqTranscriber) { - c.transcriber = transcriber -} - -func (c *OneBotChannel) setMsgEmojiLike(messageID string, emojiID int, set bool) { - go func() { - _, err := c.sendAPIRequest("set_msg_emoji_like", map[string]any{ - "message_id": messageID, - "emoji_id": emojiID, - "set": set, - }, 5*time.Second) - if err != nil { - logger.DebugCF("onebot", "Failed to set emoji like", map[string]any{ - "message_id": messageID, - "error": err.Error(), - }) - } - }() -} - -func (c *OneBotChannel) Start(ctx context.Context) error { - if c.config.WSUrl == "" { - return fmt.Errorf("OneBot ws_url not configured") - } - - logger.InfoCF("onebot", "Starting OneBot channel", map[string]any{ - "ws_url": c.config.WSUrl, - }) - - c.ctx, c.cancel = context.WithCancel(ctx) - - if err := c.connect(); err != nil { - logger.WarnCF("onebot", "Initial connection failed, 
will retry in background", map[string]any{ - "error": err.Error(), - }) - } else { - go c.listen() - c.fetchSelfID() - } - - if c.config.ReconnectInterval > 0 { - go c.reconnectLoop() - } else { - if c.conn == nil { - return fmt.Errorf("failed to connect to OneBot and reconnect is disabled") - } - } - - c.setRunning(true) - logger.InfoC("onebot", "OneBot channel started successfully") - - return nil -} - -func (c *OneBotChannel) connect() error { - dialer := websocket.DefaultDialer - dialer.HandshakeTimeout = 10 * time.Second - - header := make(map[string][]string) - if c.config.AccessToken != "" { - header["Authorization"] = []string{"Bearer " + c.config.AccessToken} - } - - conn, _, err := dialer.Dial(c.config.WSUrl, header) - if err != nil { - return err - } - - conn.SetPongHandler(func(appData string) error { - _ = conn.SetReadDeadline(time.Now().Add(60 * time.Second)) - return nil - }) - _ = conn.SetReadDeadline(time.Now().Add(60 * time.Second)) - - c.mu.Lock() - c.conn = conn - c.mu.Unlock() - - go c.pinger(conn) - - logger.InfoC("onebot", "WebSocket connected") - return nil -} - -func (c *OneBotChannel) pinger(conn *websocket.Conn) { - ticker := time.NewTicker(30 * time.Second) - defer ticker.Stop() - - for { - select { - case <-c.ctx.Done(): - return - case <-ticker.C: - c.writeMu.Lock() - err := conn.WriteMessage(websocket.PingMessage, nil) - c.writeMu.Unlock() - if err != nil { - logger.DebugCF("onebot", "Ping write failed, stopping pinger", map[string]any{ - "error": err.Error(), - }) - return - } - } - } -} - -func (c *OneBotChannel) fetchSelfID() { - resp, err := c.sendAPIRequest("get_login_info", nil, 5*time.Second) - if err != nil { - logger.WarnCF("onebot", "Failed to get_login_info", map[string]any{ - "error": err.Error(), - }) - return - } - - type loginInfo struct { - UserID json.RawMessage `json:"user_id"` - Nickname string `json:"nickname"` - } - for _, extract := range []func() (*loginInfo, error){ - func() (*loginInfo, error) { - var w struct 
{ - Data loginInfo `json:"data"` - } - err := json.Unmarshal(resp, &w) - return &w.Data, err - }, - func() (*loginInfo, error) { - var f loginInfo - err := json.Unmarshal(resp, &f) - return &f, err - }, - } { - info, err := extract() - if err != nil || len(info.UserID) == 0 { - continue - } - if uid, err := parseJSONInt64(info.UserID); err == nil && uid > 0 { - atomic.StoreInt64(&c.selfID, uid) - logger.InfoCF("onebot", "Bot self ID retrieved", map[string]any{ - "self_id": uid, - "nickname": info.Nickname, - }) - return - } - } - - logger.WarnCF("onebot", "Could not parse self ID from get_login_info response", map[string]any{ - "response": string(resp), - }) -} - -func (c *OneBotChannel) sendAPIRequest(action string, params any, timeout time.Duration) (json.RawMessage, error) { - c.mu.Lock() - conn := c.conn - c.mu.Unlock() - - if conn == nil { - return nil, fmt.Errorf("WebSocket not connected") - } - - echo := fmt.Sprintf("api_%d_%d", time.Now().UnixNano(), atomic.AddInt64(&c.echoCounter, 1)) - - ch := make(chan json.RawMessage, 1) - c.pendingMu.Lock() - c.pending[echo] = ch - c.pendingMu.Unlock() - - defer func() { - c.pendingMu.Lock() - delete(c.pending, echo) - c.pendingMu.Unlock() - }() - - req := oneBotAPIRequest{ - Action: action, - Params: params, - Echo: echo, - } - - data, err := json.Marshal(req) - if err != nil { - return nil, fmt.Errorf("failed to marshal API request: %w", err) - } - - c.writeMu.Lock() - err = conn.WriteMessage(websocket.TextMessage, data) - c.writeMu.Unlock() - - if err != nil { - return nil, fmt.Errorf("failed to write API request: %w", err) - } - - select { - case resp := <-ch: - return resp, nil - case <-time.After(timeout): - return nil, fmt.Errorf("API request %s timed out after %v", action, timeout) - case <-c.ctx.Done(): - return nil, fmt.Errorf("context cancelled") - } -} - -func (c *OneBotChannel) reconnectLoop() { - interval := time.Duration(c.config.ReconnectInterval) * time.Second - if interval < 5*time.Second { - interval 
= 5 * time.Second - } - - for { - select { - case <-c.ctx.Done(): - return - case <-time.After(interval): - c.mu.Lock() - conn := c.conn - c.mu.Unlock() - - if conn == nil { - logger.InfoC("onebot", "Attempting to reconnect...") - if err := c.connect(); err != nil { - logger.ErrorCF("onebot", "Reconnect failed", map[string]any{ - "error": err.Error(), - }) - } else { - go c.listen() - c.fetchSelfID() - } - } - } - } -} - -func (c *OneBotChannel) Stop(ctx context.Context) error { - logger.InfoC("onebot", "Stopping OneBot channel") - c.setRunning(false) - - if c.cancel != nil { - c.cancel() - } - - c.pendingMu.Lock() - for echo, ch := range c.pending { - close(ch) - delete(c.pending, echo) - } - c.pendingMu.Unlock() - - c.mu.Lock() - if c.conn != nil { - c.conn.Close() - c.conn = nil - } - c.mu.Unlock() - - return nil -} - -func (c *OneBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { - if !c.IsRunning() { - return fmt.Errorf("OneBot channel not running") - } - - c.mu.Lock() - conn := c.conn - c.mu.Unlock() - - if conn == nil { - return fmt.Errorf("OneBot WebSocket not connected") - } - - action, params, err := c.buildSendRequest(msg) - if err != nil { - return err - } - - echo := fmt.Sprintf("send_%d", atomic.AddInt64(&c.echoCounter, 1)) - - req := oneBotAPIRequest{ - Action: action, - Params: params, - Echo: echo, - } - - data, err := json.Marshal(req) - if err != nil { - return fmt.Errorf("failed to marshal OneBot request: %w", err) - } - - c.writeMu.Lock() - err = conn.WriteMessage(websocket.TextMessage, data) - c.writeMu.Unlock() - - if err != nil { - logger.ErrorCF("onebot", "Failed to send message", map[string]any{ - "error": err.Error(), - }) - return err - } - - if msgID, ok := c.pendingEmojiMsg.LoadAndDelete(msg.ChatID); ok { - if mid, ok := msgID.(string); ok && mid != "" { - c.setMsgEmojiLike(mid, 289, false) - } - } - - return nil -} - -func (c *OneBotChannel) buildMessageSegments(chatID, content string) []oneBotMessageSegment { - 
var segments []oneBotMessageSegment - - if lastMsgID, ok := c.lastMessageID.Load(chatID); ok { - if msgID, ok := lastMsgID.(string); ok && msgID != "" { - segments = append(segments, oneBotMessageSegment{ - Type: "reply", - Data: map[string]any{"id": msgID}, - }) - } - } - - segments = append(segments, oneBotMessageSegment{ - Type: "text", - Data: map[string]any{"text": content}, - }) - - return segments -} - -func (c *OneBotChannel) buildSendRequest(msg bus.OutboundMessage) (string, any, error) { - chatID := msg.ChatID - segments := c.buildMessageSegments(chatID, msg.Content) - - var action, idKey string - var rawID string - if rest, ok := strings.CutPrefix(chatID, "group:"); ok { - action, idKey, rawID = "send_group_msg", "group_id", rest - } else if rest, ok := strings.CutPrefix(chatID, "private:"); ok { - action, idKey, rawID = "send_private_msg", "user_id", rest - } else { - action, idKey, rawID = "send_private_msg", "user_id", chatID - } - - id, err := strconv.ParseInt(rawID, 10, 64) - if err != nil { - return "", nil, fmt.Errorf("invalid %s in chatID: %s", idKey, chatID) - } - return action, map[string]any{idKey: id, "message": segments}, nil -} - -func (c *OneBotChannel) listen() { - c.mu.Lock() - conn := c.conn - c.mu.Unlock() - - if conn == nil { - logger.WarnC("onebot", "WebSocket connection is nil, listener exiting") - return - } - - for { - select { - case <-c.ctx.Done(): - return - default: - _, message, err := conn.ReadMessage() - if err != nil { - logger.ErrorCF("onebot", "WebSocket read error", map[string]any{ - "error": err.Error(), - }) - c.mu.Lock() - if c.conn == conn { - c.conn.Close() - c.conn = nil - } - c.mu.Unlock() - return - } - - _ = conn.SetReadDeadline(time.Now().Add(60 * time.Second)) - - var raw oneBotRawEvent - if err := json.Unmarshal(message, &raw); err != nil { - logger.WarnCF("onebot", "Failed to unmarshal raw event", map[string]any{ - "error": err.Error(), - "payload": string(message), - }) - continue - } - - 
logger.DebugCF("onebot", "WebSocket event", map[string]any{ - "length": len(message), - "post_type": raw.PostType, - "sub_type": raw.SubType, - }) - - if raw.Echo != "" { - c.pendingMu.Lock() - ch, ok := c.pending[raw.Echo] - c.pendingMu.Unlock() - - if ok { - select { - case ch <- message: - default: - } - } else { - logger.DebugCF("onebot", "Received API response (no waiter)", map[string]any{ - "echo": raw.Echo, - "status": string(raw.Status), - }) - } - continue - } - - if isAPIResponse(raw.Status) { - logger.DebugCF("onebot", "Received API response without echo, skipping", map[string]any{ - "status": string(raw.Status), - }) - continue - } - - c.handleRawEvent(&raw) - } - } -} - -func parseJSONInt64(raw json.RawMessage) (int64, error) { - if len(raw) == 0 { - return 0, nil - } - - var n int64 - if err := json.Unmarshal(raw, &n); err == nil { - return n, nil - } - - var s string - if err := json.Unmarshal(raw, &s); err == nil { - return strconv.ParseInt(s, 10, 64) - } - return 0, fmt.Errorf("cannot parse as int64: %s", string(raw)) -} - -func parseJSONString(raw json.RawMessage) string { - if len(raw) == 0 { - return "" - } - var s string - if err := json.Unmarshal(raw, &s); err == nil { - return s - } - - return string(raw) -} - -type parseMessageResult struct { - Text string - IsBotMentioned bool - Media []string - LocalFiles []string - ReplyTo string -} - -func (c *OneBotChannel) parseMessageSegments(raw json.RawMessage, selfID int64) parseMessageResult { - if len(raw) == 0 { - return parseMessageResult{} - } - - var s string - if err := json.Unmarshal(raw, &s); err == nil { - mentioned := false - if selfID > 0 { - cqAt := fmt.Sprintf("[CQ:at,qq=%d]", selfID) - if strings.Contains(s, cqAt) { - mentioned = true - s = strings.ReplaceAll(s, cqAt, "") - s = strings.TrimSpace(s) - } - } - return parseMessageResult{Text: s, IsBotMentioned: mentioned} - } - - var segments []map[string]any - if err := json.Unmarshal(raw, &segments); err != nil { - return 
parseMessageResult{} - } - - var textParts []string - mentioned := false - selfIDStr := strconv.FormatInt(selfID, 10) - var media []string - var localFiles []string - var replyTo string - - for _, seg := range segments { - segType, _ := seg["type"].(string) - data, _ := seg["data"].(map[string]any) - - switch segType { - case "text": - if data != nil { - if t, ok := data["text"].(string); ok { - textParts = append(textParts, t) - } - } - - case "at": - if data != nil && selfID > 0 { - qqVal := fmt.Sprintf("%v", data["qq"]) - if qqVal == selfIDStr || qqVal == "all" { - mentioned = true - } - } - - case "image", "video", "file": - if data != nil { - url, _ := data["url"].(string) - if url != "" { - defaults := map[string]string{"image": "image.jpg", "video": "video.mp4", "file": "file"} - filename := defaults[segType] - if f, ok := data["file"].(string); ok && f != "" { - filename = f - } else if n, ok := data["name"].(string); ok && n != "" { - filename = n - } - localPath := utils.DownloadFile(url, filename, utils.DownloadOptions{ - LoggerPrefix: "onebot", - }) - if localPath != "" { - media = append(media, localPath) - localFiles = append(localFiles, localPath) - textParts = append(textParts, fmt.Sprintf("[%s]", segType)) - } - } - } - - case "record": - if data != nil { - url, _ := data["url"].(string) - if url != "" { - localPath := utils.DownloadFile(url, "voice.amr", utils.DownloadOptions{ - LoggerPrefix: "onebot", - }) - if localPath != "" { - localFiles = append(localFiles, localPath) - if c.transcriber != nil && c.transcriber.IsAvailable() { - tctx, tcancel := context.WithTimeout(c.ctx, 30*time.Second) - result, err := c.transcriber.Transcribe(tctx, localPath) - tcancel() - if err != nil { - logger.WarnCF("onebot", "Voice transcription failed", map[string]any{ - "error": err.Error(), - }) - textParts = append(textParts, "[voice (transcription failed)]") - media = append(media, localPath) - } else { - textParts = append(textParts, fmt.Sprintf("[voice 
transcription: %s]", result.Text)) - } - } else { - textParts = append(textParts, "[voice]") - media = append(media, localPath) - } - } - } - } - - case "reply": - if data != nil { - if id, ok := data["id"]; ok { - replyTo = fmt.Sprintf("%v", id) - } - } - - case "face": - if data != nil { - faceID, _ := data["id"] - textParts = append(textParts, fmt.Sprintf("[face:%v]", faceID)) - } - - case "forward": - textParts = append(textParts, "[forward message]") - - default: - - } - } - - return parseMessageResult{ - Text: strings.TrimSpace(strings.Join(textParts, "")), - IsBotMentioned: mentioned, - Media: media, - LocalFiles: localFiles, - ReplyTo: replyTo, - } -} - -func (c *OneBotChannel) handleRawEvent(raw *oneBotRawEvent) { - switch raw.PostType { - case "message": - if userID, err := parseJSONInt64(raw.UserID); err == nil && userID > 0 { - if !c.IsAllowed(strconv.FormatInt(userID, 10)) { - logger.DebugCF("onebot", "Message rejected by allowlist", map[string]any{ - "user_id": userID, - }) - return - } - } - c.handleMessage(raw) - - case "message_sent": - logger.DebugCF("onebot", "Bot sent message event", map[string]any{ - "message_type": raw.MessageType, - "message_id": parseJSONString(raw.MessageID), - }) - - case "meta_event": - c.handleMetaEvent(raw) - - case "notice": - c.handleNoticeEvent(raw) - - case "request": - logger.DebugCF("onebot", "Request event received", map[string]any{ - "sub_type": raw.SubType, - }) - - case "": - logger.DebugCF("onebot", "Event with empty post_type (possibly API response)", map[string]any{ - "echo": raw.Echo, - "status": raw.Status, - }) - - default: - logger.DebugCF("onebot", "Unknown post_type", map[string]any{ - "post_type": raw.PostType, - }) - } -} - -func (c *OneBotChannel) handleMetaEvent(raw *oneBotRawEvent) { - if raw.MetaEventType == "lifecycle" { - logger.InfoCF("onebot", "Lifecycle event", map[string]any{"sub_type": raw.SubType}) - } else if raw.MetaEventType != "heartbeat" { - logger.DebugCF("onebot", "Meta event: 
"+raw.MetaEventType, nil) - } -} - -func (c *OneBotChannel) handleNoticeEvent(raw *oneBotRawEvent) { - fields := map[string]any{ - "notice_type": raw.NoticeType, - "sub_type": raw.SubType, - "group_id": parseJSONString(raw.GroupID), - "user_id": parseJSONString(raw.UserID), - "message_id": parseJSONString(raw.MessageID), - } - switch raw.NoticeType { - case "group_recall", "group_increase", "group_decrease", - "friend_add", "group_admin", "group_ban": - logger.InfoCF("onebot", "Notice: "+raw.NoticeType, fields) - default: - logger.DebugCF("onebot", "Notice: "+raw.NoticeType, fields) - } -} - -func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) { - // Parse fields from raw event - userID, err := parseJSONInt64(raw.UserID) - if err != nil { - logger.WarnCF("onebot", "Failed to parse user_id", map[string]any{ - "error": err.Error(), - "raw": string(raw.UserID), - }) - return - } - - groupID, _ := parseJSONInt64(raw.GroupID) - selfID, _ := parseJSONInt64(raw.SelfID) - messageID := parseJSONString(raw.MessageID) - - if selfID == 0 { - selfID = atomic.LoadInt64(&c.selfID) - } - - parsed := c.parseMessageSegments(raw.Message, selfID) - isBotMentioned := parsed.IsBotMentioned - - content := raw.RawMessage - if content == "" { - content = parsed.Text - } else if selfID > 0 { - cqAt := fmt.Sprintf("[CQ:at,qq=%d]", selfID) - if strings.Contains(content, cqAt) { - isBotMentioned = true - content = strings.ReplaceAll(content, cqAt, "") - content = strings.TrimSpace(content) - } - } - - if parsed.Text != "" && content != parsed.Text && (len(parsed.Media) > 0 || parsed.ReplyTo != "") { - content = parsed.Text - } - - var sender oneBotSender - if len(raw.Sender) > 0 { - if err := json.Unmarshal(raw.Sender, &sender); err != nil { - logger.WarnCF("onebot", "Failed to parse sender", map[string]any{ - "error": err.Error(), - "sender": string(raw.Sender), - }) - } - } - - // Clean up temp files when done - if len(parsed.LocalFiles) > 0 { - defer func() { - for _, f := range 
parsed.LocalFiles { - if err := os.Remove(f); err != nil { - logger.DebugCF("onebot", "Failed to remove temp file", map[string]any{ - "path": f, - "error": err.Error(), - }) - } - } - }() - } - - if c.isDuplicate(messageID) { - logger.DebugCF("onebot", "Duplicate message, skipping", map[string]any{ - "message_id": messageID, - }) - return - } - - if content == "" { - logger.DebugCF("onebot", "Received empty message, ignoring", map[string]any{ - "message_id": messageID, - }) - return - } - - senderID := strconv.FormatInt(userID, 10) - var chatID string - - metadata := map[string]string{ - "message_id": messageID, - } - - if parsed.ReplyTo != "" { - metadata["reply_to_message_id"] = parsed.ReplyTo - } - - switch raw.MessageType { - case "private": - chatID = "private:" + senderID - metadata["peer_kind"] = "direct" - metadata["peer_id"] = senderID - - case "group": - groupIDStr := strconv.FormatInt(groupID, 10) - chatID = "group:" + groupIDStr - metadata["peer_kind"] = "group" - metadata["peer_id"] = groupIDStr - metadata["group_id"] = groupIDStr - - senderUserID, _ := parseJSONInt64(sender.UserID) - if senderUserID > 0 { - metadata["sender_user_id"] = strconv.FormatInt(senderUserID, 10) - } - - if sender.Card != "" { - metadata["sender_name"] = sender.Card - } else if sender.Nickname != "" { - metadata["sender_name"] = sender.Nickname - } - - triggered, strippedContent := c.checkGroupTrigger(content, isBotMentioned) - if !triggered { - logger.DebugCF("onebot", "Group message ignored (no trigger)", map[string]any{ - "sender": senderID, - "group": groupIDStr, - "is_mentioned": isBotMentioned, - "content": truncate(content, 100), - }) - return - } - content = strippedContent - - default: - logger.WarnCF("onebot", "Unknown message type, cannot route", map[string]any{ - "type": raw.MessageType, - "message_id": messageID, - "user_id": userID, - }) - return - } - - logger.InfoCF("onebot", "Received "+raw.MessageType+" message", map[string]any{ - "sender": senderID, - 
"chat_id": chatID, - "message_id": messageID, - "length": len(content), - "content": truncate(content, 100), - "media_count": len(parsed.Media), - }) - - if sender.Nickname != "" { - metadata["nickname"] = sender.Nickname - } - - c.lastMessageID.Store(chatID, messageID) - - if raw.MessageType == "group" && messageID != "" && messageID != "0" { - c.setMsgEmojiLike(messageID, 289, true) - c.pendingEmojiMsg.Store(chatID, messageID) - } - - c.HandleMessage(senderID, chatID, content, parsed.Media, metadata) -} - -func (c *OneBotChannel) isDuplicate(messageID string) bool { - if messageID == "" || messageID == "0" { - return false - } - - c.mu.Lock() - defer c.mu.Unlock() - - if _, exists := c.dedup[messageID]; exists { - return true - } - - if old := c.dedupRing[c.dedupIdx]; old != "" { - delete(c.dedup, old) - } - c.dedupRing[c.dedupIdx] = messageID - c.dedup[messageID] = struct{}{} - c.dedupIdx = (c.dedupIdx + 1) % len(c.dedupRing) - - return false -} - -func truncate(s string, n int) string { - runes := []rune(s) - if len(runes) <= n { - return s - } - return string(runes[:n]) + "..." 
-} - -func (c *OneBotChannel) checkGroupTrigger( - content string, - isBotMentioned bool, -) (triggered bool, strippedContent string) { - if isBotMentioned { - return true, strings.TrimSpace(content) - } - - for _, prefix := range c.config.GroupTriggerPrefix { - if prefix == "" { - continue - } - if strings.HasPrefix(content, prefix) { - return true, strings.TrimSpace(strings.TrimPrefix(content, prefix)) - } - } - - return false, content -} diff --git a/pkg/channels/qq.go b/pkg/channels/qq.go deleted file mode 100644 index b10776db6..000000000 --- a/pkg/channels/qq.go +++ /dev/null @@ -1,247 +0,0 @@ -package channels - -import ( - "context" - "fmt" - "sync" - "time" - - "github.com/tencent-connect/botgo" - "github.com/tencent-connect/botgo/dto" - "github.com/tencent-connect/botgo/event" - "github.com/tencent-connect/botgo/openapi" - "github.com/tencent-connect/botgo/token" - "golang.org/x/oauth2" - - "github.com/sipeed/picoclaw/pkg/bus" - "github.com/sipeed/picoclaw/pkg/config" - "github.com/sipeed/picoclaw/pkg/logger" -) - -type QQChannel struct { - *BaseChannel - config config.QQConfig - api openapi.OpenAPI - tokenSource oauth2.TokenSource - ctx context.Context - cancel context.CancelFunc - sessionManager botgo.SessionManager - processedIDs map[string]bool - mu sync.RWMutex -} - -func NewQQChannel(cfg config.QQConfig, messageBus *bus.MessageBus) (*QQChannel, error) { - base := NewBaseChannel("qq", cfg, messageBus, cfg.AllowFrom) - - return &QQChannel{ - BaseChannel: base, - config: cfg, - processedIDs: make(map[string]bool), - }, nil -} - -func (c *QQChannel) Start(ctx context.Context) error { - if c.config.AppID == "" || c.config.AppSecret == "" { - return fmt.Errorf("QQ app_id and app_secret not configured") - } - - logger.InfoC("qq", "Starting QQ bot (WebSocket mode)") - - // create token source - credentials := &token.QQBotCredentials{ - AppID: c.config.AppID, - AppSecret: c.config.AppSecret, - } - c.tokenSource = token.NewQQBotTokenSource(credentials) - - // 
create child context - c.ctx, c.cancel = context.WithCancel(ctx) - - // start auto-refresh token goroutine - if err := token.StartRefreshAccessToken(c.ctx, c.tokenSource); err != nil { - return fmt.Errorf("failed to start token refresh: %w", err) - } - - // initialize OpenAPI client - c.api = botgo.NewOpenAPI(c.config.AppID, c.tokenSource).WithTimeout(5 * time.Second) - - // register event handlers - intent := event.RegisterHandlers( - c.handleC2CMessage(), - c.handleGroupATMessage(), - ) - - // get WebSocket endpoint - wsInfo, err := c.api.WS(c.ctx, nil, "") - if err != nil { - return fmt.Errorf("failed to get websocket info: %w", err) - } - - logger.InfoCF("qq", "Got WebSocket info", map[string]any{ - "shards": wsInfo.Shards, - }) - - // create and save sessionManager - c.sessionManager = botgo.NewSessionManager() - - // start WebSocket connection in goroutine to avoid blocking - go func() { - if err := c.sessionManager.Start(wsInfo, c.tokenSource, &intent); err != nil { - logger.ErrorCF("qq", "WebSocket session error", map[string]any{ - "error": err.Error(), - }) - c.setRunning(false) - } - }() - - c.setRunning(true) - logger.InfoC("qq", "QQ bot started successfully") - - return nil -} - -func (c *QQChannel) Stop(ctx context.Context) error { - logger.InfoC("qq", "Stopping QQ bot") - c.setRunning(false) - - if c.cancel != nil { - c.cancel() - } - - return nil -} - -func (c *QQChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { - if !c.IsRunning() { - return fmt.Errorf("QQ bot not running") - } - - // construct message - msgToCreate := &dto.MessageToCreate{ - Content: msg.Content, - } - - // send C2C message - _, err := c.api.PostC2CMessage(ctx, msg.ChatID, msgToCreate) - if err != nil { - logger.ErrorCF("qq", "Failed to send C2C message", map[string]any{ - "error": err.Error(), - }) - return err - } - - return nil -} - -// handleC2CMessage handles QQ private messages -func (c *QQChannel) handleC2CMessage() event.C2CMessageEventHandler { - return 
func(event *dto.WSPayload, data *dto.WSC2CMessageData) error { - // deduplication check - if c.isDuplicate(data.ID) { - return nil - } - - // extract user info - var senderID string - if data.Author != nil && data.Author.ID != "" { - senderID = data.Author.ID - } else { - logger.WarnC("qq", "Received message with no sender ID") - return nil - } - - // extract message content - content := data.Content - if content == "" { - logger.DebugC("qq", "Received empty message, ignoring") - return nil - } - - logger.InfoCF("qq", "Received C2C message", map[string]any{ - "sender": senderID, - "length": len(content), - }) - - // forward to message bus - metadata := map[string]string{ - "message_id": data.ID, - "peer_kind": "direct", - "peer_id": senderID, - } - - c.HandleMessage(senderID, senderID, content, []string{}, metadata) - - return nil - } -} - -// handleGroupATMessage handles group @messages -func (c *QQChannel) handleGroupATMessage() event.GroupATMessageEventHandler { - return func(event *dto.WSPayload, data *dto.WSGroupATMessageData) error { - // deduplication check - if c.isDuplicate(data.ID) { - return nil - } - - // extract user info - var senderID string - if data.Author != nil && data.Author.ID != "" { - senderID = data.Author.ID - } else { - logger.WarnC("qq", "Received group message with no sender ID") - return nil - } - - // extract message content (remove @bot part) - content := data.Content - if content == "" { - logger.DebugC("qq", "Received empty group message, ignoring") - return nil - } - - logger.InfoCF("qq", "Received group AT message", map[string]any{ - "sender": senderID, - "group": data.GroupID, - "length": len(content), - }) - - // forward to message bus (use GroupID as ChatID) - metadata := map[string]string{ - "message_id": data.ID, - "group_id": data.GroupID, - "peer_kind": "group", - "peer_id": data.GroupID, - } - - c.HandleMessage(senderID, data.GroupID, content, []string{}, metadata) - - return nil - } -} - -// isDuplicate checks if message 
is duplicate -func (c *QQChannel) isDuplicate(messageID string) bool { - c.mu.Lock() - defer c.mu.Unlock() - - if c.processedIDs[messageID] { - return true - } - - c.processedIDs[messageID] = true - - // simple cleanup: limit map size - if len(c.processedIDs) > 10000 { - // clear half - count := 0 - for id := range c.processedIDs { - if count >= 5000 { - break - } - delete(c.processedIDs, id) - count++ - } - } - - return false -} diff --git a/pkg/channels/slack.go b/pkg/channels/slack.go deleted file mode 100644 index f087aa8da..000000000 --- a/pkg/channels/slack.go +++ /dev/null @@ -1,443 +0,0 @@ -package channels - -import ( - "context" - "fmt" - "os" - "strings" - "sync" - "time" - - "github.com/slack-go/slack" - "github.com/slack-go/slack/slackevents" - "github.com/slack-go/slack/socketmode" - - "github.com/sipeed/picoclaw/pkg/bus" - "github.com/sipeed/picoclaw/pkg/config" - "github.com/sipeed/picoclaw/pkg/logger" - "github.com/sipeed/picoclaw/pkg/utils" - "github.com/sipeed/picoclaw/pkg/voice" -) - -type SlackChannel struct { - *BaseChannel - config config.SlackConfig - api *slack.Client - socketClient *socketmode.Client - botUserID string - teamID string - transcriber *voice.GroqTranscriber - ctx context.Context - cancel context.CancelFunc - pendingAcks sync.Map -} - -type slackMessageRef struct { - ChannelID string - Timestamp string -} - -func NewSlackChannel(cfg config.SlackConfig, messageBus *bus.MessageBus) (*SlackChannel, error) { - if cfg.BotToken == "" || cfg.AppToken == "" { - return nil, fmt.Errorf("slack bot_token and app_token are required") - } - - api := slack.New( - cfg.BotToken, - slack.OptionAppLevelToken(cfg.AppToken), - ) - - socketClient := socketmode.New(api) - - base := NewBaseChannel("slack", cfg, messageBus, cfg.AllowFrom) - - return &SlackChannel{ - BaseChannel: base, - config: cfg, - api: api, - socketClient: socketClient, - }, nil -} - -func (c *SlackChannel) SetTranscriber(transcriber *voice.GroqTranscriber) { - c.transcriber = 
transcriber -} - -func (c *SlackChannel) Start(ctx context.Context) error { - logger.InfoC("slack", "Starting Slack channel (Socket Mode)") - - c.ctx, c.cancel = context.WithCancel(ctx) - - authResp, err := c.api.AuthTest() - if err != nil { - return fmt.Errorf("slack auth test failed: %w", err) - } - c.botUserID = authResp.UserID - c.teamID = authResp.TeamID - - logger.InfoCF("slack", "Slack bot connected", map[string]any{ - "bot_user_id": c.botUserID, - "team": authResp.Team, - }) - - go c.eventLoop() - - go func() { - if err := c.socketClient.RunContext(c.ctx); err != nil { - if c.ctx.Err() == nil { - logger.ErrorCF("slack", "Socket Mode connection error", map[string]any{ - "error": err.Error(), - }) - } - } - }() - - c.setRunning(true) - logger.InfoC("slack", "Slack channel started (Socket Mode)") - return nil -} - -func (c *SlackChannel) Stop(ctx context.Context) error { - logger.InfoC("slack", "Stopping Slack channel") - - if c.cancel != nil { - c.cancel() - } - - c.setRunning(false) - logger.InfoC("slack", "Slack channel stopped") - return nil -} - -func (c *SlackChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { - if !c.IsRunning() { - return fmt.Errorf("slack channel not running") - } - - channelID, threadTS := parseSlackChatID(msg.ChatID) - if channelID == "" { - return fmt.Errorf("invalid slack chat ID: %s", msg.ChatID) - } - - opts := []slack.MsgOption{ - slack.MsgOptionText(msg.Content, false), - } - - if threadTS != "" { - opts = append(opts, slack.MsgOptionTS(threadTS)) - } - - _, _, err := c.api.PostMessageContext(ctx, channelID, opts...) 
- if err != nil { - return fmt.Errorf("failed to send slack message: %w", err) - } - - if ref, ok := c.pendingAcks.LoadAndDelete(msg.ChatID); ok { - msgRef := ref.(slackMessageRef) - c.api.AddReaction("white_check_mark", slack.ItemRef{ - Channel: msgRef.ChannelID, - Timestamp: msgRef.Timestamp, - }) - } - - logger.DebugCF("slack", "Message sent", map[string]any{ - "channel_id": channelID, - "thread_ts": threadTS, - }) - - return nil -} - -func (c *SlackChannel) eventLoop() { - for { - select { - case <-c.ctx.Done(): - return - case event, ok := <-c.socketClient.Events: - if !ok { - return - } - switch event.Type { - case socketmode.EventTypeEventsAPI: - c.handleEventsAPI(event) - case socketmode.EventTypeSlashCommand: - c.handleSlashCommand(event) - case socketmode.EventTypeInteractive: - if event.Request != nil { - c.socketClient.Ack(*event.Request) - } - } - } - } -} - -func (c *SlackChannel) handleEventsAPI(event socketmode.Event) { - if event.Request != nil { - c.socketClient.Ack(*event.Request) - } - - eventsAPIEvent, ok := event.Data.(slackevents.EventsAPIEvent) - if !ok { - return - } - - switch ev := eventsAPIEvent.InnerEvent.Data.(type) { - case *slackevents.MessageEvent: - c.handleMessageEvent(ev) - case *slackevents.AppMentionEvent: - c.handleAppMention(ev) - } -} - -func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { - if ev.User == c.botUserID || ev.User == "" { - return - } - if ev.BotID != "" { - return - } - if ev.SubType != "" && ev.SubType != "file_share" { - return - } - - // check allowlist to avoid downloading attachments for rejected users - if !c.IsAllowed(ev.User) { - logger.DebugCF("slack", "Message rejected by allowlist", map[string]any{ - "user_id": ev.User, - }) - return - } - - senderID := ev.User - channelID := ev.Channel - threadTS := ev.ThreadTimeStamp - messageTS := ev.TimeStamp - - chatID := channelID - if threadTS != "" { - chatID = channelID + "/" + threadTS - } - - c.api.AddReaction("eyes", slack.ItemRef{ - 
Channel: channelID, - Timestamp: messageTS, - }) - - c.pendingAcks.Store(chatID, slackMessageRef{ - ChannelID: channelID, - Timestamp: messageTS, - }) - - content := ev.Text - content = c.stripBotMention(content) - - var mediaPaths []string - localFiles := []string{} // track local files that need cleanup - - // ensure temp files are cleaned up when function returns - defer func() { - for _, file := range localFiles { - if err := os.Remove(file); err != nil { - logger.DebugCF("slack", "Failed to cleanup temp file", map[string]any{ - "file": file, - "error": err.Error(), - }) - } - } - }() - - if ev.Message != nil && len(ev.Message.Files) > 0 { - for _, file := range ev.Message.Files { - localPath := c.downloadSlackFile(file) - if localPath == "" { - continue - } - localFiles = append(localFiles, localPath) - mediaPaths = append(mediaPaths, localPath) - - if utils.IsAudioFile(file.Name, file.Mimetype) && c.transcriber != nil && c.transcriber.IsAvailable() { - ctx, cancel := context.WithTimeout(c.ctx, 30*time.Second) - defer cancel() - result, err := c.transcriber.Transcribe(ctx, localPath) - - if err != nil { - logger.ErrorCF("slack", "Voice transcription failed", map[string]any{"error": err.Error()}) - content += fmt.Sprintf("\n[audio: %s (transcription failed)]", file.Name) - } else { - content += fmt.Sprintf("\n[voice transcription: %s]", result.Text) - } - } else { - content += fmt.Sprintf("\n[file: %s]", file.Name) - } - } - } - - if strings.TrimSpace(content) == "" { - return - } - - peerKind := "channel" - peerID := channelID - if strings.HasPrefix(channelID, "D") { - peerKind = "direct" - peerID = senderID - } - - metadata := map[string]string{ - "message_ts": messageTS, - "channel_id": channelID, - "thread_ts": threadTS, - "platform": "slack", - "peer_kind": peerKind, - "peer_id": peerID, - "team_id": c.teamID, - } - - logger.DebugCF("slack", "Received message", map[string]any{ - "sender_id": senderID, - "chat_id": chatID, - "preview": 
utils.Truncate(content, 50), - "has_thread": threadTS != "", - }) - - c.HandleMessage(senderID, chatID, content, mediaPaths, metadata) -} - -func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) { - if ev.User == c.botUserID { - return - } - - if !c.IsAllowed(ev.User) { - logger.DebugCF("slack", "Mention rejected by allowlist", map[string]any{ - "user_id": ev.User, - }) - return - } - - senderID := ev.User - channelID := ev.Channel - threadTS := ev.ThreadTimeStamp - messageTS := ev.TimeStamp - - var chatID string - if threadTS != "" { - chatID = channelID + "/" + threadTS - } else { - chatID = channelID + "/" + messageTS - } - - c.api.AddReaction("eyes", slack.ItemRef{ - Channel: channelID, - Timestamp: messageTS, - }) - - c.pendingAcks.Store(chatID, slackMessageRef{ - ChannelID: channelID, - Timestamp: messageTS, - }) - - content := c.stripBotMention(ev.Text) - - if strings.TrimSpace(content) == "" { - return - } - - mentionPeerKind := "channel" - mentionPeerID := channelID - if strings.HasPrefix(channelID, "D") { - mentionPeerKind = "direct" - mentionPeerID = senderID - } - - metadata := map[string]string{ - "message_ts": messageTS, - "channel_id": channelID, - "thread_ts": threadTS, - "platform": "slack", - "is_mention": "true", - "peer_kind": mentionPeerKind, - "peer_id": mentionPeerID, - "team_id": c.teamID, - } - - c.HandleMessage(senderID, chatID, content, nil, metadata) -} - -func (c *SlackChannel) handleSlashCommand(event socketmode.Event) { - cmd, ok := event.Data.(slack.SlashCommand) - if !ok { - return - } - - if event.Request != nil { - c.socketClient.Ack(*event.Request) - } - - if !c.IsAllowed(cmd.UserID) { - logger.DebugCF("slack", "Slash command rejected by allowlist", map[string]any{ - "user_id": cmd.UserID, - }) - return - } - - senderID := cmd.UserID - channelID := cmd.ChannelID - chatID := channelID - content := cmd.Text - - if strings.TrimSpace(content) == "" { - content = "help" - } - - metadata := map[string]string{ - 
"channel_id": channelID, - "platform": "slack", - "is_command": "true", - "trigger_id": cmd.TriggerID, - "peer_kind": "channel", - "peer_id": channelID, - "team_id": c.teamID, - } - - logger.DebugCF("slack", "Slash command received", map[string]any{ - "sender_id": senderID, - "command": cmd.Command, - "text": utils.Truncate(content, 50), - }) - - c.HandleMessage(senderID, chatID, content, nil, metadata) -} - -func (c *SlackChannel) downloadSlackFile(file slack.File) string { - downloadURL := file.URLPrivateDownload - if downloadURL == "" { - downloadURL = file.URLPrivate - } - if downloadURL == "" { - logger.ErrorCF("slack", "No download URL for file", map[string]any{"file_id": file.ID}) - return "" - } - - return utils.DownloadFile(downloadURL, file.Name, utils.DownloadOptions{ - LoggerPrefix: "slack", - ExtraHeaders: map[string]string{ - "Authorization": "Bearer " + c.config.BotToken, - }, - }) -} - -func (c *SlackChannel) stripBotMention(text string) string { - mention := fmt.Sprintf("<@%s>", c.botUserID) - text = strings.ReplaceAll(text, mention, "") - return strings.TrimSpace(text) -} - -func parseSlackChatID(chatID string) (channelID, threadTS string) { - parts := strings.SplitN(chatID, "/", 2) - channelID = parts[0] - if len(parts) > 1 { - threadTS = parts[1] - } - return -} diff --git a/pkg/channels/slack_test.go b/pkg/channels/slack_test.go deleted file mode 100644 index 3707c2703..000000000 --- a/pkg/channels/slack_test.go +++ /dev/null @@ -1,174 +0,0 @@ -package channels - -import ( - "testing" - - "github.com/sipeed/picoclaw/pkg/bus" - "github.com/sipeed/picoclaw/pkg/config" -) - -func TestParseSlackChatID(t *testing.T) { - tests := []struct { - name string - chatID string - wantChanID string - wantThread string - }{ - { - name: "channel only", - chatID: "C123456", - wantChanID: "C123456", - wantThread: "", - }, - { - name: "channel with thread", - chatID: "C123456/1234567890.123456", - wantChanID: "C123456", - wantThread: "1234567890.123456", - }, - { 
- name: "DM channel", - chatID: "D987654", - wantChanID: "D987654", - wantThread: "", - }, - { - name: "empty string", - chatID: "", - wantChanID: "", - wantThread: "", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - chanID, threadTS := parseSlackChatID(tt.chatID) - if chanID != tt.wantChanID { - t.Errorf("parseSlackChatID(%q) channelID = %q, want %q", tt.chatID, chanID, tt.wantChanID) - } - if threadTS != tt.wantThread { - t.Errorf("parseSlackChatID(%q) threadTS = %q, want %q", tt.chatID, threadTS, tt.wantThread) - } - }) - } -} - -func TestStripBotMention(t *testing.T) { - ch := &SlackChannel{botUserID: "U12345BOT"} - - tests := []struct { - name string - input string - want string - }{ - { - name: "mention at start", - input: "<@U12345BOT> hello there", - want: "hello there", - }, - { - name: "mention in middle", - input: "hey <@U12345BOT> can you help", - want: "hey can you help", - }, - { - name: "no mention", - input: "hello world", - want: "hello world", - }, - { - name: "empty string", - input: "", - want: "", - }, - { - name: "only mention", - input: "<@U12345BOT>", - want: "", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := ch.stripBotMention(tt.input) - if got != tt.want { - t.Errorf("stripBotMention(%q) = %q, want %q", tt.input, got, tt.want) - } - }) - } -} - -func TestNewSlackChannel(t *testing.T) { - msgBus := bus.NewMessageBus() - - t.Run("missing bot token", func(t *testing.T) { - cfg := config.SlackConfig{ - BotToken: "", - AppToken: "xapp-test", - } - _, err := NewSlackChannel(cfg, msgBus) - if err == nil { - t.Error("expected error for missing bot_token, got nil") - } - }) - - t.Run("missing app token", func(t *testing.T) { - cfg := config.SlackConfig{ - BotToken: "xoxb-test", - AppToken: "", - } - _, err := NewSlackChannel(cfg, msgBus) - if err == nil { - t.Error("expected error for missing app_token, got nil") - } - }) - - t.Run("valid config", func(t *testing.T) { - cfg := 
config.SlackConfig{ - BotToken: "xoxb-test", - AppToken: "xapp-test", - AllowFrom: []string{"U123"}, - } - ch, err := NewSlackChannel(cfg, msgBus) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if ch.Name() != "slack" { - t.Errorf("Name() = %q, want %q", ch.Name(), "slack") - } - if ch.IsRunning() { - t.Error("new channel should not be running") - } - }) -} - -func TestSlackChannelIsAllowed(t *testing.T) { - msgBus := bus.NewMessageBus() - - t.Run("empty allowlist allows all", func(t *testing.T) { - cfg := config.SlackConfig{ - BotToken: "xoxb-test", - AppToken: "xapp-test", - AllowFrom: []string{}, - } - ch, _ := NewSlackChannel(cfg, msgBus) - if !ch.IsAllowed("U_ANYONE") { - t.Error("empty allowlist should allow all users") - } - }) - - t.Run("allowlist restricts users", func(t *testing.T) { - cfg := config.SlackConfig{ - BotToken: "xoxb-test", - AppToken: "xapp-test", - AllowFrom: []string{"U_ALLOWED"}, - } - ch, _ := NewSlackChannel(cfg, msgBus) - if !ch.IsAllowed("U_ALLOWED") { - t.Error("allowed user should pass allowlist check") - } - if ch.IsAllowed("U_BLOCKED") { - t.Error("non-allowed user should be blocked") - } - }) -} diff --git a/pkg/channels/telegram.go b/pkg/channels/telegram.go deleted file mode 100644 index 5cd51e8bc..000000000 --- a/pkg/channels/telegram.go +++ /dev/null @@ -1,529 +0,0 @@ -package channels - -import ( - "context" - "fmt" - "net/http" - "net/url" - "os" - "regexp" - "strings" - "sync" - "time" - - "github.com/mymmrac/telego" - "github.com/mymmrac/telego/telegohandler" - th "github.com/mymmrac/telego/telegohandler" - tu "github.com/mymmrac/telego/telegoutil" - - "github.com/sipeed/picoclaw/pkg/bus" - "github.com/sipeed/picoclaw/pkg/config" - "github.com/sipeed/picoclaw/pkg/logger" - "github.com/sipeed/picoclaw/pkg/utils" - "github.com/sipeed/picoclaw/pkg/voice" -) - -type TelegramChannel struct { - *BaseChannel - bot *telego.Bot - commands TelegramCommander - config *config.Config - chatIDs map[string]int64 - 
transcriber *voice.GroqTranscriber - placeholders sync.Map // chatID -> messageID - stopThinking sync.Map // chatID -> thinkingCancel -} - -type thinkingCancel struct { - fn context.CancelFunc -} - -func (c *thinkingCancel) Cancel() { - if c != nil && c.fn != nil { - c.fn() - } -} - -func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChannel, error) { - var opts []telego.BotOption - telegramCfg := cfg.Channels.Telegram - - if telegramCfg.Proxy != "" { - proxyURL, parseErr := url.Parse(telegramCfg.Proxy) - if parseErr != nil { - return nil, fmt.Errorf("invalid proxy URL %q: %w", telegramCfg.Proxy, parseErr) - } - opts = append(opts, telego.WithHTTPClient(&http.Client{ - Transport: &http.Transport{ - Proxy: http.ProxyURL(proxyURL), - }, - })) - } else if os.Getenv("HTTP_PROXY") != "" || os.Getenv("HTTPS_PROXY") != "" { - // Use environment proxy if configured - opts = append(opts, telego.WithHTTPClient(&http.Client{ - Transport: &http.Transport{ - Proxy: http.ProxyFromEnvironment, - }, - })) - } - - bot, err := telego.NewBot(telegramCfg.Token, opts...) 
- if err != nil { - return nil, fmt.Errorf("failed to create telegram bot: %w", err) - } - - base := NewBaseChannel("telegram", telegramCfg, bus, telegramCfg.AllowFrom) - - return &TelegramChannel{ - BaseChannel: base, - commands: NewTelegramCommands(bot, cfg), - bot: bot, - config: cfg, - chatIDs: make(map[string]int64), - transcriber: nil, - placeholders: sync.Map{}, - stopThinking: sync.Map{}, - }, nil -} - -func (c *TelegramChannel) SetTranscriber(transcriber *voice.GroqTranscriber) { - c.transcriber = transcriber -} - -func (c *TelegramChannel) Start(ctx context.Context) error { - logger.InfoC("telegram", "Starting Telegram bot (polling mode)...") - - updates, err := c.bot.UpdatesViaLongPolling(ctx, &telego.GetUpdatesParams{ - Timeout: 30, - }) - if err != nil { - return fmt.Errorf("failed to start long polling: %w", err) - } - - bh, err := telegohandler.NewBotHandler(c.bot, updates) - if err != nil { - return fmt.Errorf("failed to create bot handler: %w", err) - } - - bh.HandleMessage(func(ctx *th.Context, message telego.Message) error { - c.commands.Help(ctx, message) - return nil - }, th.CommandEqual("help")) - bh.HandleMessage(func(ctx *th.Context, message telego.Message) error { - return c.commands.Start(ctx, message) - }, th.CommandEqual("start")) - - bh.HandleMessage(func(ctx *th.Context, message telego.Message) error { - return c.commands.Show(ctx, message) - }, th.CommandEqual("show")) - - bh.HandleMessage(func(ctx *th.Context, message telego.Message) error { - return c.commands.List(ctx, message) - }, th.CommandEqual("list")) - - bh.HandleMessage(func(ctx *th.Context, message telego.Message) error { - return c.handleMessage(ctx, &message) - }, th.AnyMessage()) - - c.setRunning(true) - logger.InfoCF("telegram", "Telegram bot connected", map[string]any{ - "username": c.bot.Username(), - }) - - go bh.Start() - - go func() { - <-ctx.Done() - bh.Stop() - }() - - return nil -} - -func (c *TelegramChannel) Stop(ctx context.Context) error { - 
logger.InfoC("telegram", "Stopping Telegram bot...") - c.setRunning(false) - return nil -} - -func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { - if !c.IsRunning() { - return fmt.Errorf("telegram bot not running") - } - - chatID, err := parseChatID(msg.ChatID) - if err != nil { - return fmt.Errorf("invalid chat ID: %w", err) - } - - // Stop thinking animation - if stop, ok := c.stopThinking.Load(msg.ChatID); ok { - if cf, ok := stop.(*thinkingCancel); ok && cf != nil { - cf.Cancel() - } - c.stopThinking.Delete(msg.ChatID) - } - - htmlContent := markdownToTelegramHTML(msg.Content) - - // Try to edit placeholder - if pID, ok := c.placeholders.Load(msg.ChatID); ok { - c.placeholders.Delete(msg.ChatID) - editMsg := tu.EditMessageText(tu.ID(chatID), pID.(int), htmlContent) - editMsg.ParseMode = telego.ModeHTML - - if _, err = c.bot.EditMessageText(ctx, editMsg); err == nil { - return nil - } - // Fallback to new message if edit fails - } - - tgMsg := tu.Message(tu.ID(chatID), htmlContent) - tgMsg.ParseMode = telego.ModeHTML - - if _, err = c.bot.SendMessage(ctx, tgMsg); err != nil { - logger.ErrorCF("telegram", "HTML parse failed, falling back to plain text", map[string]any{ - "error": err.Error(), - }) - tgMsg.ParseMode = "" - _, err = c.bot.SendMessage(ctx, tgMsg) - return err - } - - return nil -} - -func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Message) error { - if message == nil { - return fmt.Errorf("message is nil") - } - - user := message.From - if user == nil { - return fmt.Errorf("message sender (user) is nil") - } - - senderID := fmt.Sprintf("%d", user.ID) - if user.Username != "" { - senderID = fmt.Sprintf("%d|%s", user.ID, user.Username) - } - - // check allowlist to avoid downloading attachments for rejected users - if !c.IsAllowed(senderID) { - logger.DebugCF("telegram", "Message rejected by allowlist", map[string]any{ - "user_id": senderID, - }) - return nil - } - - chatID := 
message.Chat.ID - c.chatIDs[senderID] = chatID - - content := "" - mediaPaths := []string{} - localFiles := []string{} // track local files that need cleanup - - // ensure temp files are cleaned up when function returns - defer func() { - for _, file := range localFiles { - if err := os.Remove(file); err != nil { - logger.DebugCF("telegram", "Failed to cleanup temp file", map[string]any{ - "file": file, - "error": err.Error(), - }) - } - } - }() - - if message.Text != "" { - content += message.Text - } - - if message.Caption != "" { - if content != "" { - content += "\n" - } - content += message.Caption - } - - if len(message.Photo) > 0 { - photo := message.Photo[len(message.Photo)-1] - photoPath := c.downloadPhoto(ctx, photo.FileID) - if photoPath != "" { - localFiles = append(localFiles, photoPath) - mediaPaths = append(mediaPaths, photoPath) - if content != "" { - content += "\n" - } - content += "[image: photo]" - } - } - - if message.Voice != nil { - voicePath := c.downloadFile(ctx, message.Voice.FileID, ".ogg") - if voicePath != "" { - localFiles = append(localFiles, voicePath) - mediaPaths = append(mediaPaths, voicePath) - - transcribedText := "" - if c.transcriber != nil && c.transcriber.IsAvailable() { - transcriberCtx, cancel := context.WithTimeout(ctx, 30*time.Second) - defer cancel() - - result, err := c.transcriber.Transcribe(transcriberCtx, voicePath) - if err != nil { - logger.ErrorCF("telegram", "Voice transcription failed", map[string]any{ - "error": err.Error(), - "path": voicePath, - }) - transcribedText = "[voice (transcription failed)]" - } else { - transcribedText = fmt.Sprintf("[voice transcription: %s]", result.Text) - logger.InfoCF("telegram", "Voice transcribed successfully", map[string]any{ - "text": result.Text, - }) - } - } else { - transcribedText = "[voice]" - } - - if content != "" { - content += "\n" - } - content += transcribedText - } - } - - if message.Audio != nil { - audioPath := c.downloadFile(ctx, message.Audio.FileID, 
".mp3") - if audioPath != "" { - localFiles = append(localFiles, audioPath) - mediaPaths = append(mediaPaths, audioPath) - if content != "" { - content += "\n" - } - content += "[audio]" - } - } - - if message.Document != nil { - docPath := c.downloadFile(ctx, message.Document.FileID, "") - if docPath != "" { - localFiles = append(localFiles, docPath) - mediaPaths = append(mediaPaths, docPath) - if content != "" { - content += "\n" - } - content += "[file]" - } - } - - if content == "" { - content = "[empty message]" - } - - logger.DebugCF("telegram", "Received message", map[string]any{ - "sender_id": senderID, - "chat_id": fmt.Sprintf("%d", chatID), - "preview": utils.Truncate(content, 50), - }) - - // Thinking indicator - err := c.bot.SendChatAction(ctx, tu.ChatAction(tu.ID(chatID), telego.ChatActionTyping)) - if err != nil { - logger.ErrorCF("telegram", "Failed to send chat action", map[string]any{ - "error": err.Error(), - }) - } - - // Stop any previous thinking animation - chatIDStr := fmt.Sprintf("%d", chatID) - if prevStop, ok := c.stopThinking.Load(chatIDStr); ok { - if cf, ok := prevStop.(*thinkingCancel); ok && cf != nil { - cf.Cancel() - } - } - - // Create cancel function for thinking state - _, thinkCancel := context.WithTimeout(ctx, 5*time.Minute) - c.stopThinking.Store(chatIDStr, &thinkingCancel{fn: thinkCancel}) - - pMsg, err := c.bot.SendMessage(ctx, tu.Message(tu.ID(chatID), "Thinking... 
💭")) - if err == nil { - pID := pMsg.MessageID - c.placeholders.Store(chatIDStr, pID) - } - - peerKind := "direct" - peerID := fmt.Sprintf("%d", user.ID) - if message.Chat.Type != "private" { - peerKind = "group" - peerID = fmt.Sprintf("%d", chatID) - } - - metadata := map[string]string{ - "message_id": fmt.Sprintf("%d", message.MessageID), - "user_id": fmt.Sprintf("%d", user.ID), - "username": user.Username, - "first_name": user.FirstName, - "is_group": fmt.Sprintf("%t", message.Chat.Type != "private"), - "peer_kind": peerKind, - "peer_id": peerID, - } - - c.HandleMessage(fmt.Sprintf("%d", user.ID), fmt.Sprintf("%d", chatID), content, mediaPaths, metadata) - return nil -} - -func (c *TelegramChannel) downloadPhoto(ctx context.Context, fileID string) string { - file, err := c.bot.GetFile(ctx, &telego.GetFileParams{FileID: fileID}) - if err != nil { - logger.ErrorCF("telegram", "Failed to get photo file", map[string]any{ - "error": err.Error(), - }) - return "" - } - - return c.downloadFileWithInfo(file, ".jpg") -} - -func (c *TelegramChannel) downloadFileWithInfo(file *telego.File, ext string) string { - if file.FilePath == "" { - return "" - } - - url := c.bot.FileDownloadURL(file.FilePath) - logger.DebugCF("telegram", "File URL", map[string]any{"url": url}) - - // Use FilePath as filename for better identification - filename := file.FilePath + ext - return utils.DownloadFile(url, filename, utils.DownloadOptions{ - LoggerPrefix: "telegram", - }) -} - -func (c *TelegramChannel) downloadFile(ctx context.Context, fileID, ext string) string { - file, err := c.bot.GetFile(ctx, &telego.GetFileParams{FileID: fileID}) - if err != nil { - logger.ErrorCF("telegram", "Failed to get file", map[string]any{ - "error": err.Error(), - }) - return "" - } - - return c.downloadFileWithInfo(file, ext) -} - -func parseChatID(chatIDStr string) (int64, error) { - var id int64 - _, err := fmt.Sscanf(chatIDStr, "%d", &id) - return id, err -} - -func markdownToTelegramHTML(text string) 
string { - if text == "" { - return "" - } - - codeBlocks := extractCodeBlocks(text) - text = codeBlocks.text - - inlineCodes := extractInlineCodes(text) - text = inlineCodes.text - - text = regexp.MustCompile(`^#{1,6}\s+(.+)$`).ReplaceAllString(text, "$1") - - text = regexp.MustCompile(`^>\s*(.*)$`).ReplaceAllString(text, "$1") - - text = escapeHTML(text) - - text = regexp.MustCompile(`\[([^\]]+)\]\(([^)]+)\)`).ReplaceAllString(text, `$1`) - - text = regexp.MustCompile(`\*\*(.+?)\*\*`).ReplaceAllString(text, "$1") - - text = regexp.MustCompile(`__(.+?)__`).ReplaceAllString(text, "$1") - - reItalic := regexp.MustCompile(`_([^_]+)_`) - text = reItalic.ReplaceAllStringFunc(text, func(s string) string { - match := reItalic.FindStringSubmatch(s) - if len(match) < 2 { - return s - } - return "" + match[1] + "" - }) - - text = regexp.MustCompile(`~~(.+?)~~`).ReplaceAllString(text, "$1") - - text = regexp.MustCompile(`^[-*]\s+`).ReplaceAllString(text, "• ") - - for i, code := range inlineCodes.codes { - escaped := escapeHTML(code) - text = strings.ReplaceAll(text, fmt.Sprintf("\x00IC%d\x00", i), fmt.Sprintf("%s", escaped)) - } - - for i, code := range codeBlocks.codes { - escaped := escapeHTML(code) - text = strings.ReplaceAll( - text, - fmt.Sprintf("\x00CB%d\x00", i), - fmt.Sprintf("
%s
", escaped), - ) - } - - return text -} - -type codeBlockMatch struct { - text string - codes []string -} - -func extractCodeBlocks(text string) codeBlockMatch { - re := regexp.MustCompile("```[\\w]*\\n?([\\s\\S]*?)```") - matches := re.FindAllStringSubmatch(text, -1) - - codes := make([]string, 0, len(matches)) - for _, match := range matches { - codes = append(codes, match[1]) - } - - i := 0 - text = re.ReplaceAllStringFunc(text, func(m string) string { - placeholder := fmt.Sprintf("\x00CB%d\x00", i) - i++ - return placeholder - }) - - return codeBlockMatch{text: text, codes: codes} -} - -type inlineCodeMatch struct { - text string - codes []string -} - -func extractInlineCodes(text string) inlineCodeMatch { - re := regexp.MustCompile("`([^`]+)`") - matches := re.FindAllStringSubmatch(text, -1) - - codes := make([]string, 0, len(matches)) - for _, match := range matches { - codes = append(codes, match[1]) - } - - i := 0 - text = re.ReplaceAllStringFunc(text, func(m string) string { - placeholder := fmt.Sprintf("\x00IC%d\x00", i) - i++ - return placeholder - }) - - return inlineCodeMatch{text: text, codes: codes} -} - -func escapeHTML(text string) string { - text = strings.ReplaceAll(text, "&", "&") - text = strings.ReplaceAll(text, "<", "<") - text = strings.ReplaceAll(text, ">", ">") - return text -} diff --git a/pkg/channels/telegram_commands.go b/pkg/channels/telegram_commands.go deleted file mode 100644 index a084b641b..000000000 --- a/pkg/channels/telegram_commands.go +++ /dev/null @@ -1,156 +0,0 @@ -package channels - -import ( - "context" - "fmt" - "strings" - - "github.com/mymmrac/telego" - - "github.com/sipeed/picoclaw/pkg/config" -) - -type TelegramCommander interface { - Help(ctx context.Context, message telego.Message) error - Start(ctx context.Context, message telego.Message) error - Show(ctx context.Context, message telego.Message) error - List(ctx context.Context, message telego.Message) error -} - -type cmd struct { - bot *telego.Bot - config 
*config.Config -} - -func NewTelegramCommands(bot *telego.Bot, cfg *config.Config) TelegramCommander { - return &cmd{ - bot: bot, - config: cfg, - } -} - -func commandArgs(text string) string { - parts := strings.SplitN(text, " ", 2) - if len(parts) < 2 { - return "" - } - return strings.TrimSpace(parts[1]) -} - -func (c *cmd) Help(ctx context.Context, message telego.Message) error { - msg := `/start - Start the bot -/help - Show this help message -/show [model|channel] - Show current configuration -/list [models|channels] - List available options - ` - _, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{ - ChatID: telego.ChatID{ID: message.Chat.ID}, - Text: msg, - ReplyParameters: &telego.ReplyParameters{ - MessageID: message.MessageID, - }, - }) - return err -} - -func (c *cmd) Start(ctx context.Context, message telego.Message) error { - _, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{ - ChatID: telego.ChatID{ID: message.Chat.ID}, - Text: "Hello! I am PicoClaw 🦞", - ReplyParameters: &telego.ReplyParameters{ - MessageID: message.MessageID, - }, - }) - return err -} - -func (c *cmd) Show(ctx context.Context, message telego.Message) error { - args := commandArgs(message.Text) - if args == "" { - _, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{ - ChatID: telego.ChatID{ID: message.Chat.ID}, - Text: "Usage: /show [model|channel]", - ReplyParameters: &telego.ReplyParameters{ - MessageID: message.MessageID, - }, - }) - return err - } - - var response string - switch args { - case "model": - response = fmt.Sprintf("Current Model: %s (Provider: %s)", - c.config.Agents.Defaults.Model, - c.config.Agents.Defaults.Provider) - case "channel": - response = "Current Channel: telegram" - default: - response = fmt.Sprintf("Unknown parameter: %s. 
Try 'model' or 'channel'.", args) - } - - _, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{ - ChatID: telego.ChatID{ID: message.Chat.ID}, - Text: response, - ReplyParameters: &telego.ReplyParameters{ - MessageID: message.MessageID, - }, - }) - return err -} - -func (c *cmd) List(ctx context.Context, message telego.Message) error { - args := commandArgs(message.Text) - if args == "" { - _, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{ - ChatID: telego.ChatID{ID: message.Chat.ID}, - Text: "Usage: /list [models|channels]", - ReplyParameters: &telego.ReplyParameters{ - MessageID: message.MessageID, - }, - }) - return err - } - - var response string - switch args { - case "models": - provider := c.config.Agents.Defaults.Provider - if provider == "" { - provider = "configured default" - } - response = fmt.Sprintf("Configured Model: %s\nProvider: %s\n\nTo change models, update config.yaml", - c.config.Agents.Defaults.Model, provider) - - case "channels": - var enabled []string - if c.config.Channels.Telegram.Enabled { - enabled = append(enabled, "telegram") - } - if c.config.Channels.WhatsApp.Enabled { - enabled = append(enabled, "whatsapp") - } - if c.config.Channels.Feishu.Enabled { - enabled = append(enabled, "feishu") - } - if c.config.Channels.Discord.Enabled { - enabled = append(enabled, "discord") - } - if c.config.Channels.Slack.Enabled { - enabled = append(enabled, "slack") - } - response = fmt.Sprintf("Enabled Channels:\n- %s", strings.Join(enabled, "\n- ")) - - default: - response = fmt.Sprintf("Unknown parameter: %s. 
Try 'models' or 'channels'.", args) - } - - _, err := c.bot.SendMessage(ctx, &telego.SendMessageParams{ - ChatID: telego.ChatID{ID: message.Chat.ID}, - Text: response, - ReplyParameters: &telego.ReplyParameters{ - MessageID: message.MessageID, - }, - }) - return err -} diff --git a/pkg/channels/wecom.go b/pkg/channels/wecom.go deleted file mode 100644 index f8daf89de..000000000 --- a/pkg/channels/wecom.go +++ /dev/null @@ -1,605 +0,0 @@ -// PicoClaw - Ultra-lightweight personal AI agent -// WeCom Bot (企业微信智能机器人) channel implementation -// Uses webhook callback mode for receiving messages and webhook API for sending replies - -package channels - -import ( - "bytes" - "context" - "crypto/aes" - "crypto/cipher" - "crypto/sha1" - "encoding/base64" - "encoding/binary" - "encoding/json" - "encoding/xml" - "fmt" - "io" - "net/http" - "sort" - "strings" - "sync" - "time" - - "github.com/sipeed/picoclaw/pkg/bus" - "github.com/sipeed/picoclaw/pkg/config" - "github.com/sipeed/picoclaw/pkg/logger" - "github.com/sipeed/picoclaw/pkg/utils" -) - -// WeComBotChannel implements the Channel interface for WeCom Bot (企业微信智能机器人) -// Uses webhook callback mode - simpler than WeCom App but only supports passive replies -type WeComBotChannel struct { - *BaseChannel - config config.WeComConfig - server *http.Server - ctx context.Context - cancel context.CancelFunc - processedMsgs map[string]bool // Message deduplication: msg_id -> processed - msgMu sync.RWMutex -} - -// WeComBotMessage represents the JSON message structure from WeCom Bot (AIBOT) -type WeComBotMessage struct { - MsgID string `json:"msgid"` - AIBotID string `json:"aibotid"` - ChatID string `json:"chatid"` // Session ID, only present for group chats - ChatType string `json:"chattype"` // "single" for DM, "group" for group chat - From struct { - UserID string `json:"userid"` - } `json:"from"` - ResponseURL string `json:"response_url"` - MsgType string `json:"msgtype"` // text, image, voice, file, mixed - Text struct { - 
Content string `json:"content"` - } `json:"text"` - Image struct { - URL string `json:"url"` - } `json:"image"` - Voice struct { - Content string `json:"content"` // Voice to text content - } `json:"voice"` - File struct { - URL string `json:"url"` - } `json:"file"` - Mixed struct { - MsgItem []struct { - MsgType string `json:"msgtype"` - Text struct { - Content string `json:"content"` - } `json:"text"` - Image struct { - URL string `json:"url"` - } `json:"image"` - } `json:"msg_item"` - } `json:"mixed"` - Quote struct { - MsgType string `json:"msgtype"` - Text struct { - Content string `json:"content"` - } `json:"text"` - } `json:"quote"` -} - -// WeComBotReplyMessage represents the reply message structure -type WeComBotReplyMessage struct { - MsgType string `json:"msgtype"` - Text struct { - Content string `json:"content"` - } `json:"text,omitempty"` -} - -// NewWeComBotChannel creates a new WeCom Bot channel instance -func NewWeComBotChannel(cfg config.WeComConfig, messageBus *bus.MessageBus) (*WeComBotChannel, error) { - if cfg.Token == "" || cfg.WebhookURL == "" { - return nil, fmt.Errorf("wecom token and webhook_url are required") - } - - base := NewBaseChannel("wecom", cfg, messageBus, cfg.AllowFrom) - - return &WeComBotChannel{ - BaseChannel: base, - config: cfg, - processedMsgs: make(map[string]bool), - }, nil -} - -// Name returns the channel name -func (c *WeComBotChannel) Name() string { - return "wecom" -} - -// Start initializes the WeCom Bot channel with HTTP webhook server -func (c *WeComBotChannel) Start(ctx context.Context) error { - logger.InfoC("wecom", "Starting WeCom Bot channel...") - - c.ctx, c.cancel = context.WithCancel(ctx) - - // Setup HTTP server for webhook - mux := http.NewServeMux() - webhookPath := c.config.WebhookPath - if webhookPath == "" { - webhookPath = "/webhook/wecom" - } - mux.HandleFunc(webhookPath, c.handleWebhook) - - // Health check endpoint - mux.HandleFunc("/health/wecom", c.handleHealth) - - addr := 
fmt.Sprintf("%s:%d", c.config.WebhookHost, c.config.WebhookPort) - c.server = &http.Server{ - Addr: addr, - Handler: mux, - } - - c.setRunning(true) - logger.InfoCF("wecom", "WeCom Bot channel started", map[string]any{ - "address": addr, - "path": webhookPath, - }) - - // Start server in goroutine - go func() { - if err := c.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { - logger.ErrorCF("wecom", "HTTP server error", map[string]any{ - "error": err.Error(), - }) - } - }() - - return nil -} - -// Stop gracefully stops the WeCom Bot channel -func (c *WeComBotChannel) Stop(ctx context.Context) error { - logger.InfoC("wecom", "Stopping WeCom Bot channel...") - - if c.cancel != nil { - c.cancel() - } - - if c.server != nil { - shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second) - defer cancel() - c.server.Shutdown(shutdownCtx) - } - - c.setRunning(false) - logger.InfoC("wecom", "WeCom Bot channel stopped") - return nil -} - -// Send sends a message to WeCom user via webhook API -// Note: WeCom Bot can only reply within the configured timeout (default 5 seconds) of receiving a message -// For delayed responses, we use the webhook URL -func (c *WeComBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { - if !c.IsRunning() { - return fmt.Errorf("wecom channel not running") - } - - logger.DebugCF("wecom", "Sending message via webhook", map[string]any{ - "chat_id": msg.ChatID, - "preview": utils.Truncate(msg.Content, 100), - }) - - return c.sendWebhookReply(ctx, msg.ChatID, msg.Content) -} - -// handleWebhook handles incoming webhook requests from WeCom -func (c *WeComBotChannel) handleWebhook(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - - if r.Method == http.MethodGet { - // Handle verification request - c.handleVerification(ctx, w, r) - return - } - - if r.Method == http.MethodPost { - // Handle message callback - c.handleMessageCallback(ctx, w, r) - return - } - - http.Error(w, "Method not allowed", 
http.StatusMethodNotAllowed) -} - -// handleVerification handles the URL verification request from WeCom -func (c *WeComBotChannel) handleVerification(ctx context.Context, w http.ResponseWriter, r *http.Request) { - query := r.URL.Query() - msgSignature := query.Get("msg_signature") - timestamp := query.Get("timestamp") - nonce := query.Get("nonce") - echostr := query.Get("echostr") - - if msgSignature == "" || timestamp == "" || nonce == "" || echostr == "" { - http.Error(w, "Missing parameters", http.StatusBadRequest) - return - } - - // Verify signature - if !WeComVerifySignature(c.config.Token, msgSignature, timestamp, nonce, echostr) { - logger.WarnC("wecom", "Signature verification failed") - http.Error(w, "Invalid signature", http.StatusForbidden) - return - } - - // Decrypt echostr - // For AIBOT (智能机器人), receiveid should be empty string "" - // Reference: https://developer.work.weixin.qq.com/document/path/101033 - decryptedEchoStr, err := WeComDecryptMessageWithVerify(echostr, c.config.EncodingAESKey, "") - if err != nil { - logger.ErrorCF("wecom", "Failed to decrypt echostr", map[string]any{ - "error": err.Error(), - }) - http.Error(w, "Decryption failed", http.StatusInternalServerError) - return - } - - // Remove BOM and whitespace as per WeCom documentation - // The response must be plain text without quotes, BOM, or newlines - decryptedEchoStr = strings.TrimSpace(decryptedEchoStr) - decryptedEchoStr = strings.TrimPrefix(decryptedEchoStr, "\xef\xbb\xbf") // Remove UTF-8 BOM - w.Write([]byte(decryptedEchoStr)) -} - -// handleMessageCallback handles incoming messages from WeCom -func (c *WeComBotChannel) handleMessageCallback(ctx context.Context, w http.ResponseWriter, r *http.Request) { - query := r.URL.Query() - msgSignature := query.Get("msg_signature") - timestamp := query.Get("timestamp") - nonce := query.Get("nonce") - - if msgSignature == "" || timestamp == "" || nonce == "" { - http.Error(w, "Missing parameters", http.StatusBadRequest) - return - 
} - - // Read request body - body, err := io.ReadAll(r.Body) - if err != nil { - http.Error(w, "Failed to read body", http.StatusBadRequest) - return - } - defer r.Body.Close() - - // Parse XML to get encrypted message - var encryptedMsg struct { - XMLName xml.Name `xml:"xml"` - ToUserName string `xml:"ToUserName"` - Encrypt string `xml:"Encrypt"` - AgentID string `xml:"AgentID"` - } - - if err = xml.Unmarshal(body, &encryptedMsg); err != nil { - logger.ErrorCF("wecom", "Failed to parse XML", map[string]any{ - "error": err.Error(), - }) - http.Error(w, "Invalid XML", http.StatusBadRequest) - return - } - - // Verify signature - if !WeComVerifySignature(c.config.Token, msgSignature, timestamp, nonce, encryptedMsg.Encrypt) { - logger.WarnC("wecom", "Message signature verification failed") - http.Error(w, "Invalid signature", http.StatusForbidden) - return - } - - // Decrypt message - // For AIBOT (智能机器人), receiveid should be empty string "" - // Reference: https://developer.work.weixin.qq.com/document/path/101033 - decryptedMsg, err := WeComDecryptMessageWithVerify(encryptedMsg.Encrypt, c.config.EncodingAESKey, "") - if err != nil { - logger.ErrorCF("wecom", "Failed to decrypt message", map[string]any{ - "error": err.Error(), - }) - http.Error(w, "Decryption failed", http.StatusInternalServerError) - return - } - - // Parse decrypted JSON message (AIBOT uses JSON format) - var msg WeComBotMessage - if err := json.Unmarshal([]byte(decryptedMsg), &msg); err != nil { - logger.ErrorCF("wecom", "Failed to parse decrypted message", map[string]any{ - "error": err.Error(), - }) - http.Error(w, "Invalid message format", http.StatusBadRequest) - return - } - - // Process the message asynchronously with context - go c.processMessage(ctx, msg) - - // Return success response immediately - // WeCom Bot requires response within configured timeout (default 5 seconds) - w.Write([]byte("success")) -} - -// processMessage processes the received message -func (c *WeComBotChannel) 
processMessage(ctx context.Context, msg WeComBotMessage) { - // Skip unsupported message types - if msg.MsgType != "text" && msg.MsgType != "image" && msg.MsgType != "voice" && msg.MsgType != "file" && - msg.MsgType != "mixed" { - logger.DebugCF("wecom", "Skipping non-supported message type", map[string]any{ - "msg_type": msg.MsgType, - }) - return - } - - // Message deduplication: Use msg_id to prevent duplicate processing - msgID := msg.MsgID - c.msgMu.Lock() - if c.processedMsgs[msgID] { - c.msgMu.Unlock() - logger.DebugCF("wecom", "Skipping duplicate message", map[string]any{ - "msg_id": msgID, - }) - return - } - c.processedMsgs[msgID] = true - c.msgMu.Unlock() - - // Clean up old messages periodically (keep last 1000) - if len(c.processedMsgs) > 1000 { - c.msgMu.Lock() - c.processedMsgs = make(map[string]bool) - c.msgMu.Unlock() - } - - senderID := msg.From.UserID - - // Determine if this is a group chat or direct message - // ChatType: "single" for DM, "group" for group chat - isGroupChat := msg.ChatType == "group" - - var chatID, peerKind, peerID string - if isGroupChat { - // Group chat: use ChatID as chatID and peer_id - chatID = msg.ChatID - peerKind = "group" - peerID = msg.ChatID - } else { - // Direct message: use senderID as chatID and peer_id - chatID = senderID - peerKind = "direct" - peerID = senderID - } - - // Extract content based on message type - var content string - switch msg.MsgType { - case "text": - content = msg.Text.Content - case "voice": - content = msg.Voice.Content // Voice to text content - case "mixed": - // For mixed messages, concatenate text items - for _, item := range msg.Mixed.MsgItem { - if item.MsgType == "text" { - content += item.Text.Content - } - } - case "image", "file": - // For image and file, we don't have text content - content = "" - } - - // Build metadata - metadata := map[string]string{ - "msg_type": msg.MsgType, - "msg_id": msg.MsgID, - "platform": "wecom", - "peer_kind": peerKind, - "peer_id": peerID, - 
"response_url": msg.ResponseURL, - } - if isGroupChat { - metadata["chat_id"] = msg.ChatID - metadata["sender_id"] = senderID - } - - logger.DebugCF("wecom", "Received message", map[string]any{ - "sender_id": senderID, - "msg_type": msg.MsgType, - "peer_kind": peerKind, - "is_group_chat": isGroupChat, - "preview": utils.Truncate(content, 50), - }) - - // Handle the message through the base channel - c.HandleMessage(senderID, chatID, content, nil, metadata) -} - -// sendWebhookReply sends a reply using the webhook URL -func (c *WeComBotChannel) sendWebhookReply(ctx context.Context, userID, content string) error { - reply := WeComBotReplyMessage{ - MsgType: "text", - } - reply.Text.Content = content - - jsonData, err := json.Marshal(reply) - if err != nil { - return fmt.Errorf("failed to marshal reply: %w", err) - } - - // Use configurable timeout (default 5 seconds) - timeout := c.config.ReplyTimeout - if timeout <= 0 { - timeout = 5 - } - - reqCtx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second) - defer cancel() - - req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, c.config.WebhookURL, bytes.NewBuffer(jsonData)) - if err != nil { - return fmt.Errorf("failed to create request: %w", err) - } - req.Header.Set("Content-Type", "application/json") - - client := &http.Client{Timeout: time.Duration(timeout) * time.Second} - resp, err := client.Do(req) - if err != nil { - return fmt.Errorf("failed to send webhook reply: %w", err) - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("failed to read response: %w", err) - } - - // Check response - var result struct { - ErrCode int `json:"errcode"` - ErrMsg string `json:"errmsg"` - } - if err := json.Unmarshal(body, &result); err != nil { - return fmt.Errorf("failed to parse response: %w", err) - } - - if result.ErrCode != 0 { - return fmt.Errorf("webhook API error: %s (code: %d)", result.ErrMsg, result.ErrCode) - } - - return nil -} - 
-// handleHealth handles health check requests -func (c *WeComBotChannel) handleHealth(w http.ResponseWriter, r *http.Request) { - status := map[string]any{ - "status": "ok", - "running": c.IsRunning(), - } - - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(status) -} - -// WeCom common utilities for both WeCom Bot and WeCom App -// The following functions were moved from wecom_common.go - -// WeComVerifySignature verifies the message signature for WeCom -// This is a common function used by both WeCom Bot and WeCom App -func WeComVerifySignature(token, msgSignature, timestamp, nonce, msgEncrypt string) bool { - if token == "" { - return true // Skip verification if token is not set - } - - // Sort parameters - params := []string{token, timestamp, nonce, msgEncrypt} - sort.Strings(params) - - // Concatenate - str := strings.Join(params, "") - - // SHA1 hash - hash := sha1.Sum([]byte(str)) - expectedSignature := fmt.Sprintf("%x", hash) - - return expectedSignature == msgSignature -} - -// WeComDecryptMessage decrypts the encrypted message using AES -// This is a common function used by both WeCom Bot and WeCom App -// For AIBOT, receiveid should be the aibotid; for other apps, it should be corp_id -func WeComDecryptMessage(encryptedMsg, encodingAESKey string) (string, error) { - return WeComDecryptMessageWithVerify(encryptedMsg, encodingAESKey, "") -} - -// WeComDecryptMessageWithVerify decrypts the encrypted message and optionally verifies receiveid -// receiveid: for AIBOT use aibotid, for WeCom App use corp_id. If empty, skip verification. 
-func WeComDecryptMessageWithVerify(encryptedMsg, encodingAESKey, receiveid string) (string, error) { - if encodingAESKey == "" { - // No encryption, return as is (base64 decode) - decoded, err := base64.StdEncoding.DecodeString(encryptedMsg) - if err != nil { - return "", err - } - return string(decoded), nil - } - - // Decode AES key (base64) - aesKey, err := base64.StdEncoding.DecodeString(encodingAESKey + "=") - if err != nil { - return "", fmt.Errorf("failed to decode AES key: %w", err) - } - - // Decode encrypted message - cipherText, err := base64.StdEncoding.DecodeString(encryptedMsg) - if err != nil { - return "", fmt.Errorf("failed to decode message: %w", err) - } - - // AES decrypt - block, err := aes.NewCipher(aesKey) - if err != nil { - return "", fmt.Errorf("failed to create cipher: %w", err) - } - - if len(cipherText) < aes.BlockSize { - return "", fmt.Errorf("ciphertext too short") - } - - // IV is the first 16 bytes of AESKey - iv := aesKey[:aes.BlockSize] - mode := cipher.NewCBCDecrypter(block, iv) - plainText := make([]byte, len(cipherText)) - mode.CryptBlocks(plainText, cipherText) - - // Remove PKCS7 padding - plainText, err = pkcs7UnpadWeCom(plainText) - if err != nil { - return "", fmt.Errorf("failed to unpad: %w", err) - } - - // Parse message structure - // Format: random(16) + msg_len(4) + msg + receiveid - if len(plainText) < 20 { - return "", fmt.Errorf("decrypted message too short") - } - - msgLen := binary.BigEndian.Uint32(plainText[16:20]) - if int(msgLen) > len(plainText)-20 { - return "", fmt.Errorf("invalid message length") - } - - msg := plainText[20 : 20+msgLen] - - // Verify receiveid if provided - if receiveid != "" && len(plainText) > 20+int(msgLen) { - actualReceiveID := string(plainText[20+msgLen:]) - if actualReceiveID != receiveid { - return "", fmt.Errorf("receiveid mismatch: expected %s, got %s", receiveid, actualReceiveID) - } - } - - return string(msg), nil -} - -// pkcs7UnpadWeCom removes PKCS7 padding with validation 
-// WeCom uses block size of 32 (not standard AES block size of 16) -const wecomBlockSize = 32 - -func pkcs7UnpadWeCom(data []byte) ([]byte, error) { - if len(data) == 0 { - return data, nil - } - padding := int(data[len(data)-1]) - // WeCom uses 32-byte block size for PKCS7 padding - if padding == 0 || padding > wecomBlockSize { - return nil, fmt.Errorf("invalid padding size: %d", padding) - } - if padding > len(data) { - return nil, fmt.Errorf("padding size larger than data") - } - // Verify all padding bytes - for i := 0; i < padding; i++ { - if data[len(data)-1-i] != byte(padding) { - return nil, fmt.Errorf("invalid padding byte at position %d", i) - } - } - return data[:len(data)-padding], nil -} diff --git a/pkg/channels/wecom_app.go b/pkg/channels/wecom_app.go deleted file mode 100644 index 715c48707..000000000 --- a/pkg/channels/wecom_app.go +++ /dev/null @@ -1,639 +0,0 @@ -// PicoClaw - Ultra-lightweight personal AI agent -// WeCom App (企业微信自建应用) channel implementation -// Supports receiving messages via webhook callback and sending messages proactively - -package channels - -import ( - "bytes" - "context" - "encoding/json" - "encoding/xml" - "fmt" - "io" - "net/http" - "net/url" - "strings" - "sync" - "time" - - "github.com/sipeed/picoclaw/pkg/bus" - "github.com/sipeed/picoclaw/pkg/config" - "github.com/sipeed/picoclaw/pkg/logger" - "github.com/sipeed/picoclaw/pkg/utils" -) - -const ( - wecomAPIBase = "https://qyapi.weixin.qq.com" -) - -// WeComAppChannel implements the Channel interface for WeCom App (企业微信自建应用) -type WeComAppChannel struct { - *BaseChannel - config config.WeComAppConfig - server *http.Server - accessToken string - tokenExpiry time.Time - tokenMu sync.RWMutex - ctx context.Context - cancel context.CancelFunc - processedMsgs map[string]bool // Message deduplication: msg_id -> processed - msgMu sync.RWMutex -} - -// WeComXMLMessage represents the XML message structure from WeCom -type WeComXMLMessage struct { - XMLName xml.Name `xml:"xml"` 
- ToUserName string `xml:"ToUserName"` - FromUserName string `xml:"FromUserName"` - CreateTime int64 `xml:"CreateTime"` - MsgType string `xml:"MsgType"` - Content string `xml:"Content"` - MsgId int64 `xml:"MsgId"` - AgentID int64 `xml:"AgentID"` - PicUrl string `xml:"PicUrl"` - MediaId string `xml:"MediaId"` - Format string `xml:"Format"` - ThumbMediaId string `xml:"ThumbMediaId"` - LocationX float64 `xml:"Location_X"` - LocationY float64 `xml:"Location_Y"` - Scale int `xml:"Scale"` - Label string `xml:"Label"` - Title string `xml:"Title"` - Description string `xml:"Description"` - Url string `xml:"Url"` - Event string `xml:"Event"` - EventKey string `xml:"EventKey"` -} - -// WeComTextMessage represents text message for sending -type WeComTextMessage struct { - ToUser string `json:"touser"` - MsgType string `json:"msgtype"` - AgentID int64 `json:"agentid"` - Text struct { - Content string `json:"content"` - } `json:"text"` - Safe int `json:"safe,omitempty"` -} - -// WeComMarkdownMessage represents markdown message for sending -type WeComMarkdownMessage struct { - ToUser string `json:"touser"` - MsgType string `json:"msgtype"` - AgentID int64 `json:"agentid"` - Markdown struct { - Content string `json:"content"` - } `json:"markdown"` -} - -// WeComImageMessage represents image message for sending -type WeComImageMessage struct { - ToUser string `json:"touser"` - MsgType string `json:"msgtype"` - AgentID int64 `json:"agentid"` - Image struct { - MediaID string `json:"media_id"` - } `json:"image"` -} - -// WeComAccessTokenResponse represents the access token API response -type WeComAccessTokenResponse struct { - ErrCode int `json:"errcode"` - ErrMsg string `json:"errmsg"` - AccessToken string `json:"access_token"` - ExpiresIn int `json:"expires_in"` -} - -// WeComSendMessageResponse represents the send message API response -type WeComSendMessageResponse struct { - ErrCode int `json:"errcode"` - ErrMsg string `json:"errmsg"` - InvalidUser string `json:"invaliduser"` - 
InvalidParty string `json:"invalidparty"` - InvalidTag string `json:"invalidtag"` -} - -// PKCS7Padding adds PKCS7 padding -type PKCS7Padding struct{} - -// NewWeComAppChannel creates a new WeCom App channel instance -func NewWeComAppChannel(cfg config.WeComAppConfig, messageBus *bus.MessageBus) (*WeComAppChannel, error) { - if cfg.CorpID == "" || cfg.CorpSecret == "" || cfg.AgentID == 0 { - return nil, fmt.Errorf("wecom_app corp_id, corp_secret and agent_id are required") - } - - base := NewBaseChannel("wecom_app", cfg, messageBus, cfg.AllowFrom) - - return &WeComAppChannel{ - BaseChannel: base, - config: cfg, - processedMsgs: make(map[string]bool), - }, nil -} - -// Name returns the channel name -func (c *WeComAppChannel) Name() string { - return "wecom_app" -} - -// Start initializes the WeCom App channel with HTTP webhook server -func (c *WeComAppChannel) Start(ctx context.Context) error { - logger.InfoC("wecom_app", "Starting WeCom App channel...") - - c.ctx, c.cancel = context.WithCancel(ctx) - - // Get initial access token - if err := c.refreshAccessToken(); err != nil { - logger.WarnCF("wecom_app", "Failed to get initial access token", map[string]any{ - "error": err.Error(), - }) - } - - // Start token refresh goroutine - go c.tokenRefreshLoop() - - // Setup HTTP server for webhook - mux := http.NewServeMux() - webhookPath := c.config.WebhookPath - if webhookPath == "" { - webhookPath = "/webhook/wecom-app" - } - mux.HandleFunc(webhookPath, c.handleWebhook) - - // Health check endpoint - mux.HandleFunc("/health/wecom-app", c.handleHealth) - - addr := fmt.Sprintf("%s:%d", c.config.WebhookHost, c.config.WebhookPort) - c.server = &http.Server{ - Addr: addr, - Handler: mux, - } - - c.setRunning(true) - logger.InfoCF("wecom_app", "WeCom App channel started", map[string]any{ - "address": addr, - "path": webhookPath, - }) - - // Start server in goroutine - go func() { - if err := c.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { - 
logger.ErrorCF("wecom_app", "HTTP server error", map[string]any{ - "error": err.Error(), - }) - } - }() - - return nil -} - -// Stop gracefully stops the WeCom App channel -func (c *WeComAppChannel) Stop(ctx context.Context) error { - logger.InfoC("wecom_app", "Stopping WeCom App channel...") - - if c.cancel != nil { - c.cancel() - } - - if c.server != nil { - shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second) - defer cancel() - c.server.Shutdown(shutdownCtx) - } - - c.setRunning(false) - logger.InfoC("wecom_app", "WeCom App channel stopped") - return nil -} - -// Send sends a message to WeCom user proactively using access token -func (c *WeComAppChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { - if !c.IsRunning() { - return fmt.Errorf("wecom_app channel not running") - } - - accessToken := c.getAccessToken() - if accessToken == "" { - return fmt.Errorf("no valid access token available") - } - - logger.DebugCF("wecom_app", "Sending message", map[string]any{ - "chat_id": msg.ChatID, - "preview": utils.Truncate(msg.Content, 100), - }) - - return c.sendTextMessage(ctx, accessToken, msg.ChatID, msg.Content) -} - -// handleWebhook handles incoming webhook requests from WeCom -func (c *WeComAppChannel) handleWebhook(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - - // Log all incoming requests for debugging - logger.DebugCF("wecom_app", "Received webhook request", map[string]any{ - "method": r.Method, - "url": r.URL.String(), - "path": r.URL.Path, - "query": r.URL.RawQuery, - }) - - if r.Method == http.MethodGet { - // Handle verification request - c.handleVerification(ctx, w, r) - return - } - - if r.Method == http.MethodPost { - // Handle message callback - c.handleMessageCallback(ctx, w, r) - return - } - - logger.WarnCF("wecom_app", "Method not allowed", map[string]any{ - "method": r.Method, - }) - http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) -} - -// handleVerification handles the URL verification 
request from WeCom -func (c *WeComAppChannel) handleVerification(ctx context.Context, w http.ResponseWriter, r *http.Request) { - query := r.URL.Query() - msgSignature := query.Get("msg_signature") - timestamp := query.Get("timestamp") - nonce := query.Get("nonce") - echostr := query.Get("echostr") - - logger.DebugCF("wecom_app", "Handling verification request", map[string]any{ - "msg_signature": msgSignature, - "timestamp": timestamp, - "nonce": nonce, - "echostr": echostr, - "corp_id": c.config.CorpID, - }) - - if msgSignature == "" || timestamp == "" || nonce == "" || echostr == "" { - logger.ErrorC("wecom_app", "Missing parameters in verification request") - http.Error(w, "Missing parameters", http.StatusBadRequest) - return - } - - // Verify signature - if !WeComVerifySignature(c.config.Token, msgSignature, timestamp, nonce, echostr) { - logger.WarnCF("wecom_app", "Signature verification failed", map[string]any{ - "token": c.config.Token, - "msg_signature": msgSignature, - "timestamp": timestamp, - "nonce": nonce, - }) - http.Error(w, "Invalid signature", http.StatusForbidden) - return - } - - logger.DebugC("wecom_app", "Signature verification passed") - - // Decrypt echostr with CorpID verification - // For WeCom App (自建应用), receiveid should be corp_id - logger.DebugCF("wecom_app", "Attempting to decrypt echostr", map[string]any{ - "encoding_aes_key": c.config.EncodingAESKey, - "corp_id": c.config.CorpID, - }) - decryptedEchoStr, err := WeComDecryptMessageWithVerify(echostr, c.config.EncodingAESKey, c.config.CorpID) - if err != nil { - logger.ErrorCF("wecom_app", "Failed to decrypt echostr", map[string]any{ - "error": err.Error(), - "encoding_aes_key": c.config.EncodingAESKey, - "corp_id": c.config.CorpID, - }) - http.Error(w, "Decryption failed", http.StatusInternalServerError) - return - } - - logger.DebugCF("wecom_app", "Successfully decrypted echostr", map[string]any{ - "decrypted": decryptedEchoStr, - }) - - // Remove BOM and whitespace as per WeCom 
documentation - // The response must be plain text without quotes, BOM, or newlines - decryptedEchoStr = strings.TrimSpace(decryptedEchoStr) - decryptedEchoStr = strings.TrimPrefix(decryptedEchoStr, "\xef\xbb\xbf") // Remove UTF-8 BOM - w.Write([]byte(decryptedEchoStr)) -} - -// handleMessageCallback handles incoming messages from WeCom -func (c *WeComAppChannel) handleMessageCallback(ctx context.Context, w http.ResponseWriter, r *http.Request) { - query := r.URL.Query() - msgSignature := query.Get("msg_signature") - timestamp := query.Get("timestamp") - nonce := query.Get("nonce") - - if msgSignature == "" || timestamp == "" || nonce == "" { - http.Error(w, "Missing parameters", http.StatusBadRequest) - return - } - - // Read request body - body, err := io.ReadAll(r.Body) - if err != nil { - http.Error(w, "Failed to read body", http.StatusBadRequest) - return - } - defer r.Body.Close() - - // Parse XML to get encrypted message - var encryptedMsg struct { - XMLName xml.Name `xml:"xml"` - ToUserName string `xml:"ToUserName"` - Encrypt string `xml:"Encrypt"` - AgentID string `xml:"AgentID"` - } - - if err = xml.Unmarshal(body, &encryptedMsg); err != nil { - logger.ErrorCF("wecom_app", "Failed to parse XML", map[string]any{ - "error": err.Error(), - }) - http.Error(w, "Invalid XML", http.StatusBadRequest) - return - } - - // Verify signature - if !WeComVerifySignature(c.config.Token, msgSignature, timestamp, nonce, encryptedMsg.Encrypt) { - logger.WarnC("wecom_app", "Message signature verification failed") - http.Error(w, "Invalid signature", http.StatusForbidden) - return - } - - // Decrypt message with CorpID verification - // For WeCom App (自建应用), receiveid should be corp_id - decryptedMsg, err := WeComDecryptMessageWithVerify(encryptedMsg.Encrypt, c.config.EncodingAESKey, c.config.CorpID) - if err != nil { - logger.ErrorCF("wecom_app", "Failed to decrypt message", map[string]any{ - "error": err.Error(), - }) - http.Error(w, "Decryption failed", 
http.StatusInternalServerError) - return - } - - // Parse decrypted XML message - var msg WeComXMLMessage - if err := xml.Unmarshal([]byte(decryptedMsg), &msg); err != nil { - logger.ErrorCF("wecom_app", "Failed to parse decrypted message", map[string]any{ - "error": err.Error(), - }) - http.Error(w, "Invalid message format", http.StatusBadRequest) - return - } - - // Process the message with context - go c.processMessage(ctx, msg) - - // Return success response immediately - // WeCom App requires response within configured timeout (default 5 seconds) - w.Write([]byte("success")) -} - -// processMessage processes the received message -func (c *WeComAppChannel) processMessage(ctx context.Context, msg WeComXMLMessage) { - // Skip non-text messages for now (can be extended) - if msg.MsgType != "text" && msg.MsgType != "image" && msg.MsgType != "voice" { - logger.DebugCF("wecom_app", "Skipping non-supported message type", map[string]any{ - "msg_type": msg.MsgType, - }) - return - } - - // Message deduplication: Use msg_id to prevent duplicate processing - // As per WeCom documentation, use msg_id for deduplication - msgID := fmt.Sprintf("%d", msg.MsgId) - c.msgMu.Lock() - if c.processedMsgs[msgID] { - c.msgMu.Unlock() - logger.DebugCF("wecom_app", "Skipping duplicate message", map[string]any{ - "msg_id": msgID, - }) - return - } - c.processedMsgs[msgID] = true - c.msgMu.Unlock() - - // Clean up old messages periodically (keep last 1000) - if len(c.processedMsgs) > 1000 { - c.msgMu.Lock() - c.processedMsgs = make(map[string]bool) - c.msgMu.Unlock() - } - - senderID := msg.FromUserName - chatID := senderID // WeCom App uses user ID as chat ID for direct messages - - // Build metadata - // WeCom App only supports direct messages (private chat) - metadata := map[string]string{ - "msg_type": msg.MsgType, - "msg_id": fmt.Sprintf("%d", msg.MsgId), - "agent_id": fmt.Sprintf("%d", msg.AgentID), - "platform": "wecom_app", - "media_id": msg.MediaId, - "create_time": 
fmt.Sprintf("%d", msg.CreateTime), - "peer_kind": "direct", - "peer_id": senderID, - } - - content := msg.Content - - logger.DebugCF("wecom_app", "Received message", map[string]any{ - "sender_id": senderID, - "msg_type": msg.MsgType, - "preview": utils.Truncate(content, 50), - }) - - // Handle the message through the base channel - c.HandleMessage(senderID, chatID, content, nil, metadata) -} - -// tokenRefreshLoop periodically refreshes the access token -func (c *WeComAppChannel) tokenRefreshLoop() { - ticker := time.NewTicker(5 * time.Minute) - defer ticker.Stop() - - for { - select { - case <-c.ctx.Done(): - return - case <-ticker.C: - if err := c.refreshAccessToken(); err != nil { - logger.ErrorCF("wecom_app", "Failed to refresh access token", map[string]any{ - "error": err.Error(), - }) - } - } - } -} - -// refreshAccessToken gets a new access token from WeCom API -func (c *WeComAppChannel) refreshAccessToken() error { - apiURL := fmt.Sprintf("%s/cgi-bin/gettoken?corpid=%s&corpsecret=%s", - wecomAPIBase, url.QueryEscape(c.config.CorpID), url.QueryEscape(c.config.CorpSecret)) - - resp, err := http.Get(apiURL) - if err != nil { - return fmt.Errorf("failed to request access token: %w", err) - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("failed to read response: %w", err) - } - - var tokenResp WeComAccessTokenResponse - if err := json.Unmarshal(body, &tokenResp); err != nil { - return fmt.Errorf("failed to parse response: %w", err) - } - - if tokenResp.ErrCode != 0 { - return fmt.Errorf("API error: %s (code: %d)", tokenResp.ErrMsg, tokenResp.ErrCode) - } - - c.tokenMu.Lock() - c.accessToken = tokenResp.AccessToken - c.tokenExpiry = time.Now().Add(time.Duration(tokenResp.ExpiresIn-300) * time.Second) // Refresh 5 minutes early - c.tokenMu.Unlock() - - logger.DebugC("wecom_app", "Access token refreshed successfully") - return nil -} - -// getAccessToken returns the current valid access token -func (c 
*WeComAppChannel) getAccessToken() string { - c.tokenMu.RLock() - defer c.tokenMu.RUnlock() - - if time.Now().After(c.tokenExpiry) { - return "" - } - - return c.accessToken -} - -// sendTextMessage sends a text message to a user -func (c *WeComAppChannel) sendTextMessage(ctx context.Context, accessToken, userID, content string) error { - apiURL := fmt.Sprintf("%s/cgi-bin/message/send?access_token=%s", wecomAPIBase, accessToken) - - msg := WeComTextMessage{ - ToUser: userID, - MsgType: "text", - AgentID: c.config.AgentID, - } - msg.Text.Content = content - - jsonData, err := json.Marshal(msg) - if err != nil { - return fmt.Errorf("failed to marshal message: %w", err) - } - - // Use configurable timeout (default 5 seconds) - timeout := c.config.ReplyTimeout - if timeout <= 0 { - timeout = 5 - } - - reqCtx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second) - defer cancel() - - req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, apiURL, bytes.NewBuffer(jsonData)) - if err != nil { - return fmt.Errorf("failed to create request: %w", err) - } - req.Header.Set("Content-Type", "application/json") - - client := &http.Client{Timeout: time.Duration(timeout) * time.Second} - resp, err := client.Do(req) - if err != nil { - return fmt.Errorf("failed to send message: %w", err) - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("failed to read response: %w", err) - } - - var sendResp WeComSendMessageResponse - if err := json.Unmarshal(body, &sendResp); err != nil { - return fmt.Errorf("failed to parse response: %w", err) - } - - if sendResp.ErrCode != 0 { - return fmt.Errorf("API error: %s (code: %d)", sendResp.ErrMsg, sendResp.ErrCode) - } - - return nil -} - -// sendMarkdownMessage sends a markdown message to a user -func (c *WeComAppChannel) sendMarkdownMessage(ctx context.Context, accessToken, userID, content string) error { - apiURL := 
fmt.Sprintf("%s/cgi-bin/message/send?access_token=%s", wecomAPIBase, accessToken) - - msg := WeComMarkdownMessage{ - ToUser: userID, - MsgType: "markdown", - AgentID: c.config.AgentID, - } - msg.Markdown.Content = content - - jsonData, err := json.Marshal(msg) - if err != nil { - return fmt.Errorf("failed to marshal message: %w", err) - } - - // Use configurable timeout (default 5 seconds) - timeout := c.config.ReplyTimeout - if timeout <= 0 { - timeout = 5 - } - - reqCtx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second) - defer cancel() - - req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, apiURL, bytes.NewBuffer(jsonData)) - if err != nil { - return fmt.Errorf("failed to create request: %w", err) - } - req.Header.Set("Content-Type", "application/json") - - client := &http.Client{Timeout: time.Duration(timeout) * time.Second} - resp, err := client.Do(req) - if err != nil { - return fmt.Errorf("failed to send message: %w", err) - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return fmt.Errorf("failed to read response: %w", err) - } - - var sendResp WeComSendMessageResponse - if err := json.Unmarshal(body, &sendResp); err != nil { - return fmt.Errorf("failed to parse response: %w", err) - } - - if sendResp.ErrCode != 0 { - return fmt.Errorf("API error: %s (code: %d)", sendResp.ErrMsg, sendResp.ErrCode) - } - - return nil -} - -// handleHealth handles health check requests -func (c *WeComAppChannel) handleHealth(w http.ResponseWriter, r *http.Request) { - status := map[string]any{ - "status": "ok", - "running": c.IsRunning(), - "has_token": c.getAccessToken() != "", - } - - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(status) -} diff --git a/pkg/channels/wecom_app_test.go b/pkg/channels/wecom_app_test.go deleted file mode 100644 index abf15c52b..000000000 --- a/pkg/channels/wecom_app_test.go +++ /dev/null @@ -1,1104 +0,0 @@ -// PicoClaw - Ultra-lightweight 
personal AI agent -// WeCom App (企业微信自建应用) channel tests - -package channels - -import ( - "bytes" - "context" - "crypto/aes" - "crypto/cipher" - "crypto/sha1" - "encoding/base64" - "encoding/binary" - "encoding/json" - "encoding/xml" - "fmt" - "net/http" - "net/http/httptest" - "sort" - "strings" - "testing" - "time" - - "github.com/sipeed/picoclaw/pkg/bus" - "github.com/sipeed/picoclaw/pkg/config" -) - -// generateTestAESKeyApp generates a valid test AES key for WeCom App -func generateTestAESKeyApp() string { - // AES key needs to be 32 bytes (256 bits) for AES-256 - key := make([]byte, 32) - for i := range key { - key[i] = byte(i + 1) - } - // Return base64 encoded key without padding - return base64.StdEncoding.EncodeToString(key)[:43] -} - -// encryptTestMessageApp encrypts a message for testing WeCom App -func encryptTestMessageApp(message, aesKey string) (string, error) { - // Decode AES key - key, err := base64.StdEncoding.DecodeString(aesKey + "=") - if err != nil { - return "", err - } - - // Prepare message: random(16) + msg_len(4) + msg + corp_id - random := make([]byte, 0, 16) - for i := 0; i < 16; i++ { - random = append(random, byte(i+1)) - } - - msgBytes := []byte(message) - corpID := []byte("test_corp_id") - - msgLen := uint32(len(msgBytes)) - lenBytes := make([]byte, 4) - binary.BigEndian.PutUint32(lenBytes, msgLen) - - plainText := append(random, lenBytes...) - plainText = append(plainText, msgBytes...) - plainText = append(plainText, corpID...) - - // PKCS7 padding - blockSize := aes.BlockSize - padding := blockSize - len(plainText)%blockSize - padText := bytes.Repeat([]byte{byte(padding)}, padding) - plainText = append(plainText, padText...) 
- - // Encrypt - block, err := aes.NewCipher(key) - if err != nil { - return "", err - } - - mode := cipher.NewCBCEncrypter(block, key[:aes.BlockSize]) - cipherText := make([]byte, len(plainText)) - mode.CryptBlocks(cipherText, plainText) - - return base64.StdEncoding.EncodeToString(cipherText), nil -} - -// generateSignatureApp generates a signature for testing WeCom App -func generateSignatureApp(token, timestamp, nonce, msgEncrypt string) string { - params := []string{token, timestamp, nonce, msgEncrypt} - sort.Strings(params) - str := strings.Join(params, "") - hash := sha1.Sum([]byte(str)) - return fmt.Sprintf("%x", hash) -} - -func TestNewWeComAppChannel(t *testing.T) { - msgBus := bus.NewMessageBus() - - t.Run("missing corp_id", func(t *testing.T) { - cfg := config.WeComAppConfig{ - CorpID: "", - CorpSecret: "test_secret", - AgentID: 1000002, - } - _, err := NewWeComAppChannel(cfg, msgBus) - if err == nil { - t.Error("expected error for missing corp_id, got nil") - } - }) - - t.Run("missing corp_secret", func(t *testing.T) { - cfg := config.WeComAppConfig{ - CorpID: "test_corp_id", - CorpSecret: "", - AgentID: 1000002, - } - _, err := NewWeComAppChannel(cfg, msgBus) - if err == nil { - t.Error("expected error for missing corp_secret, got nil") - } - }) - - t.Run("missing agent_id", func(t *testing.T) { - cfg := config.WeComAppConfig{ - CorpID: "test_corp_id", - CorpSecret: "test_secret", - AgentID: 0, - } - _, err := NewWeComAppChannel(cfg, msgBus) - if err == nil { - t.Error("expected error for missing agent_id, got nil") - } - }) - - t.Run("valid config", func(t *testing.T) { - cfg := config.WeComAppConfig{ - CorpID: "test_corp_id", - CorpSecret: "test_secret", - AgentID: 1000002, - AllowFrom: []string{"user1", "user2"}, - } - ch, err := NewWeComAppChannel(cfg, msgBus) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if ch.Name() != "wecom_app" { - t.Errorf("Name() = %q, want %q", ch.Name(), "wecom_app") - } - if ch.IsRunning() { - 
t.Error("new channel should not be running") - } - }) -} - -func TestWeComAppChannelIsAllowed(t *testing.T) { - msgBus := bus.NewMessageBus() - - t.Run("empty allowlist allows all", func(t *testing.T) { - cfg := config.WeComAppConfig{ - CorpID: "test_corp_id", - CorpSecret: "test_secret", - AgentID: 1000002, - AllowFrom: []string{}, - } - ch, _ := NewWeComAppChannel(cfg, msgBus) - if !ch.IsAllowed("any_user") { - t.Error("empty allowlist should allow all users") - } - }) - - t.Run("allowlist restricts users", func(t *testing.T) { - cfg := config.WeComAppConfig{ - CorpID: "test_corp_id", - CorpSecret: "test_secret", - AgentID: 1000002, - AllowFrom: []string{"allowed_user"}, - } - ch, _ := NewWeComAppChannel(cfg, msgBus) - if !ch.IsAllowed("allowed_user") { - t.Error("allowed user should pass allowlist check") - } - if ch.IsAllowed("blocked_user") { - t.Error("non-allowed user should be blocked") - } - }) -} - -func TestWeComAppVerifySignature(t *testing.T) { - msgBus := bus.NewMessageBus() - cfg := config.WeComAppConfig{ - CorpID: "test_corp_id", - CorpSecret: "test_secret", - AgentID: 1000002, - Token: "test_token", - } - ch, _ := NewWeComAppChannel(cfg, msgBus) - - t.Run("valid signature", func(t *testing.T) { - timestamp := "1234567890" - nonce := "test_nonce" - msgEncrypt := "test_message" - expectedSig := generateSignatureApp("test_token", timestamp, nonce, msgEncrypt) - - if !WeComVerifySignature(ch.config.Token, expectedSig, timestamp, nonce, msgEncrypt) { - t.Error("valid signature should pass verification") - } - }) - - t.Run("invalid signature", func(t *testing.T) { - timestamp := "1234567890" - nonce := "test_nonce" - msgEncrypt := "test_message" - - if WeComVerifySignature(ch.config.Token, "invalid_sig", timestamp, nonce, msgEncrypt) { - t.Error("invalid signature should fail verification") - } - }) - - t.Run("empty token skips verification", func(t *testing.T) { - cfgEmpty := config.WeComAppConfig{ - CorpID: "test_corp_id", - CorpSecret: "test_secret", 
- AgentID: 1000002, - Token: "", - } - chEmpty, _ := NewWeComAppChannel(cfgEmpty, msgBus) - - if !WeComVerifySignature(chEmpty.config.Token, "any_sig", "any_ts", "any_nonce", "any_msg") { - t.Error("empty token should skip verification and return true") - } - }) -} - -func TestWeComAppDecryptMessage(t *testing.T) { - msgBus := bus.NewMessageBus() - - t.Run("decrypt without AES key", func(t *testing.T) { - cfg := config.WeComAppConfig{ - CorpID: "test_corp_id", - CorpSecret: "test_secret", - AgentID: 1000002, - EncodingAESKey: "", - } - ch, _ := NewWeComAppChannel(cfg, msgBus) - - // Without AES key, message should be base64 decoded only - plainText := "hello world" - encoded := base64.StdEncoding.EncodeToString([]byte(plainText)) - - result, err := WeComDecryptMessage(encoded, ch.config.EncodingAESKey) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if result != plainText { - t.Errorf("decryptMessage() = %q, want %q", result, plainText) - } - }) - - t.Run("decrypt with AES key", func(t *testing.T) { - aesKey := generateTestAESKeyApp() - cfg := config.WeComAppConfig{ - CorpID: "test_corp_id", - CorpSecret: "test_secret", - AgentID: 1000002, - EncodingAESKey: aesKey, - } - ch, _ := NewWeComAppChannel(cfg, msgBus) - - originalMsg := "Hello" - encrypted, err := encryptTestMessageApp(originalMsg, aesKey) - if err != nil { - t.Fatalf("failed to encrypt test message: %v", err) - } - - result, err := WeComDecryptMessage(encrypted, ch.config.EncodingAESKey) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if result != originalMsg { - t.Errorf("WeComDecryptMessage() = %q, want %q", result, originalMsg) - } - }) - - t.Run("invalid base64", func(t *testing.T) { - cfg := config.WeComAppConfig{ - CorpID: "test_corp_id", - CorpSecret: "test_secret", - AgentID: 1000002, - EncodingAESKey: "", - } - ch, _ := NewWeComAppChannel(cfg, msgBus) - - _, err := WeComDecryptMessage("invalid_base64!!!", ch.config.EncodingAESKey) - if err == nil { - 
t.Error("expected error for invalid base64, got nil") - } - }) - - t.Run("invalid AES key", func(t *testing.T) { - cfg := config.WeComAppConfig{ - CorpID: "test_corp_id", - CorpSecret: "test_secret", - AgentID: 1000002, - EncodingAESKey: "invalid_key", - } - ch, _ := NewWeComAppChannel(cfg, msgBus) - - _, err := WeComDecryptMessage(base64.StdEncoding.EncodeToString([]byte("test")), ch.config.EncodingAESKey) - if err == nil { - t.Error("expected error for invalid AES key, got nil") - } - }) - - t.Run("ciphertext too short", func(t *testing.T) { - aesKey := generateTestAESKeyApp() - cfg := config.WeComAppConfig{ - CorpID: "test_corp_id", - CorpSecret: "test_secret", - AgentID: 1000002, - EncodingAESKey: aesKey, - } - ch, _ := NewWeComAppChannel(cfg, msgBus) - - // Encrypt a very short message that results in ciphertext less than block size - shortData := make([]byte, 8) - _, err := WeComDecryptMessage(base64.StdEncoding.EncodeToString(shortData), ch.config.EncodingAESKey) - if err == nil { - t.Error("expected error for short ciphertext, got nil") - } - }) -} - -func TestWeComAppPKCS7Unpad(t *testing.T) { - tests := []struct { - name string - input []byte - expected []byte - }{ - { - name: "empty input", - input: []byte{}, - expected: []byte{}, - }, - { - name: "valid padding 3 bytes", - input: append([]byte("hello"), bytes.Repeat([]byte{3}, 3)...), - expected: []byte("hello"), - }, - { - name: "valid padding 16 bytes (full block)", - input: append([]byte("123456789012345"), bytes.Repeat([]byte{16}, 16)...), - expected: []byte("123456789012345"), - }, - { - name: "invalid padding larger than data", - input: []byte{20}, - expected: nil, // should return error - }, - { - name: "invalid padding zero", - input: append([]byte("test"), byte(0)), - expected: nil, // should return error - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := pkcs7UnpadWeCom(tt.input) - if tt.expected == nil { - // This case should return an error - if 
err == nil { - t.Errorf("pkcs7Unpad() expected error for invalid padding, got result: %v", result) - } - return - } - if err != nil { - t.Errorf("pkcs7Unpad() unexpected error: %v", err) - return - } - if !bytes.Equal(result, tt.expected) { - t.Errorf("pkcs7Unpad() = %v, want %v", result, tt.expected) - } - }) - } -} - -func TestWeComAppHandleVerification(t *testing.T) { - msgBus := bus.NewMessageBus() - aesKey := generateTestAESKeyApp() - cfg := config.WeComAppConfig{ - CorpID: "test_corp_id", - CorpSecret: "test_secret", - AgentID: 1000002, - Token: "test_token", - EncodingAESKey: aesKey, - } - ch, _ := NewWeComAppChannel(cfg, msgBus) - - t.Run("valid verification request", func(t *testing.T) { - echostr := "test_echostr_123" - encryptedEchostr, _ := encryptTestMessageApp(echostr, aesKey) - timestamp := "1234567890" - nonce := "test_nonce" - signature := generateSignatureApp("test_token", timestamp, nonce, encryptedEchostr) - - req := httptest.NewRequest( - http.MethodGet, - "/webhook/wecom-app?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encryptedEchostr, - nil, - ) - w := httptest.NewRecorder() - - ch.handleVerification(context.Background(), w, req) - - if w.Code != http.StatusOK { - t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) - } - if w.Body.String() != echostr { - t.Errorf("response body = %q, want %q", w.Body.String(), echostr) - } - }) - - t.Run("missing parameters", func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "/webhook/wecom-app?msg_signature=sig×tamp=ts", nil) - w := httptest.NewRecorder() - - ch.handleVerification(context.Background(), w, req) - - if w.Code != http.StatusBadRequest { - t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest) - } - }) - - t.Run("invalid signature", func(t *testing.T) { - echostr := "test_echostr" - encryptedEchostr, _ := encryptTestMessageApp(echostr, aesKey) - timestamp := "1234567890" - nonce := "test_nonce" - - req := httptest.NewRequest( - 
http.MethodGet, - "/webhook/wecom-app?msg_signature=invalid_sig×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encryptedEchostr, - nil, - ) - w := httptest.NewRecorder() - - ch.handleVerification(context.Background(), w, req) - - if w.Code != http.StatusForbidden { - t.Errorf("status code = %d, want %d", w.Code, http.StatusForbidden) - } - }) -} - -func TestWeComAppHandleMessageCallback(t *testing.T) { - msgBus := bus.NewMessageBus() - aesKey := generateTestAESKeyApp() - cfg := config.WeComAppConfig{ - CorpID: "test_corp_id", - CorpSecret: "test_secret", - AgentID: 1000002, - Token: "test_token", - EncodingAESKey: aesKey, - } - ch, _ := NewWeComAppChannel(cfg, msgBus) - - t.Run("valid message callback", func(t *testing.T) { - // Create XML message - xmlMsg := WeComXMLMessage{ - ToUserName: "corp_id", - FromUserName: "user123", - CreateTime: 1234567890, - MsgType: "text", - Content: "Hello World", - MsgId: 123456, - AgentID: 1000002, - } - xmlData, _ := xml.Marshal(xmlMsg) - - // Encrypt message - encrypted, _ := encryptTestMessageApp(string(xmlData), aesKey) - - // Create encrypted XML wrapper - encryptedWrapper := struct { - XMLName xml.Name `xml:"xml"` - Encrypt string `xml:"Encrypt"` - }{ - Encrypt: encrypted, - } - wrapperData, _ := xml.Marshal(encryptedWrapper) - - timestamp := "1234567890" - nonce := "test_nonce" - signature := generateSignatureApp("test_token", timestamp, nonce, encrypted) - - req := httptest.NewRequest( - http.MethodPost, - "/webhook/wecom-app?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce, - bytes.NewReader(wrapperData), - ) - w := httptest.NewRecorder() - - ch.handleMessageCallback(context.Background(), w, req) - - if w.Code != http.StatusOK { - t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) - } - if w.Body.String() != "success" { - t.Errorf("response body = %q, want %q", w.Body.String(), "success") - } - }) - - t.Run("missing parameters", func(t *testing.T) { - req := httptest.NewRequest(http.MethodPost, 
"/webhook/wecom-app?msg_signature=sig", nil) - w := httptest.NewRecorder() - - ch.handleMessageCallback(context.Background(), w, req) - - if w.Code != http.StatusBadRequest { - t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest) - } - }) - - t.Run("invalid XML", func(t *testing.T) { - timestamp := "1234567890" - nonce := "test_nonce" - signature := generateSignatureApp("test_token", timestamp, nonce, "") - - req := httptest.NewRequest( - http.MethodPost, - "/webhook/wecom-app?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce, - strings.NewReader("invalid xml"), - ) - w := httptest.NewRecorder() - - ch.handleMessageCallback(context.Background(), w, req) - - if w.Code != http.StatusBadRequest { - t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest) - } - }) - - t.Run("invalid signature", func(t *testing.T) { - encryptedWrapper := struct { - XMLName xml.Name `xml:"xml"` - Encrypt string `xml:"Encrypt"` - }{ - Encrypt: "encrypted_data", - } - wrapperData, _ := xml.Marshal(encryptedWrapper) - - timestamp := "1234567890" - nonce := "test_nonce" - - req := httptest.NewRequest( - http.MethodPost, - "/webhook/wecom-app?msg_signature=invalid_sig×tamp="+timestamp+"&nonce="+nonce, - bytes.NewReader(wrapperData), - ) - w := httptest.NewRecorder() - - ch.handleMessageCallback(context.Background(), w, req) - - if w.Code != http.StatusForbidden { - t.Errorf("status code = %d, want %d", w.Code, http.StatusForbidden) - } - }) -} - -func TestWeComAppProcessMessage(t *testing.T) { - msgBus := bus.NewMessageBus() - cfg := config.WeComAppConfig{ - CorpID: "test_corp_id", - CorpSecret: "test_secret", - AgentID: 1000002, - } - ch, _ := NewWeComAppChannel(cfg, msgBus) - - t.Run("process text message", func(t *testing.T) { - msg := WeComXMLMessage{ - ToUserName: "corp_id", - FromUserName: "user123", - CreateTime: 1234567890, - MsgType: "text", - Content: "Hello World", - MsgId: 123456, - AgentID: 1000002, - } - - // Should not panic - 
ch.processMessage(context.Background(), msg) - }) - - t.Run("process image message", func(t *testing.T) { - msg := WeComXMLMessage{ - ToUserName: "corp_id", - FromUserName: "user123", - CreateTime: 1234567890, - MsgType: "image", - PicUrl: "https://example.com/image.jpg", - MediaId: "media_123", - MsgId: 123456, - AgentID: 1000002, - } - - // Should not panic - ch.processMessage(context.Background(), msg) - }) - - t.Run("process voice message", func(t *testing.T) { - msg := WeComXMLMessage{ - ToUserName: "corp_id", - FromUserName: "user123", - CreateTime: 1234567890, - MsgType: "voice", - MediaId: "media_123", - Format: "amr", - MsgId: 123456, - AgentID: 1000002, - } - - // Should not panic - ch.processMessage(context.Background(), msg) - }) - - t.Run("skip unsupported message type", func(t *testing.T) { - msg := WeComXMLMessage{ - ToUserName: "corp_id", - FromUserName: "user123", - CreateTime: 1234567890, - MsgType: "video", - MsgId: 123456, - AgentID: 1000002, - } - - // Should not panic - ch.processMessage(context.Background(), msg) - }) - - t.Run("process event message", func(t *testing.T) { - msg := WeComXMLMessage{ - ToUserName: "corp_id", - FromUserName: "user123", - CreateTime: 1234567890, - MsgType: "event", - Event: "subscribe", - MsgId: 123456, - AgentID: 1000002, - } - - // Should not panic - ch.processMessage(context.Background(), msg) - }) -} - -func TestWeComAppHandleWebhook(t *testing.T) { - msgBus := bus.NewMessageBus() - cfg := config.WeComAppConfig{ - CorpID: "test_corp_id", - CorpSecret: "test_secret", - AgentID: 1000002, - Token: "test_token", - } - ch, _ := NewWeComAppChannel(cfg, msgBus) - - t.Run("GET request calls verification", func(t *testing.T) { - echostr := "test_echostr" - encoded := base64.StdEncoding.EncodeToString([]byte(echostr)) - timestamp := "1234567890" - nonce := "test_nonce" - signature := generateSignatureApp("test_token", timestamp, nonce, encoded) - - req := httptest.NewRequest( - http.MethodGet, - 
"/webhook/wecom-app?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encoded, - nil, - ) - w := httptest.NewRecorder() - - ch.handleWebhook(w, req) - - if w.Code != http.StatusOK { - t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) - } - }) - - t.Run("POST request calls message callback", func(t *testing.T) { - encryptedWrapper := struct { - XMLName xml.Name `xml:"xml"` - Encrypt string `xml:"Encrypt"` - }{ - Encrypt: base64.StdEncoding.EncodeToString([]byte("test")), - } - wrapperData, _ := xml.Marshal(encryptedWrapper) - - timestamp := "1234567890" - nonce := "test_nonce" - signature := generateSignatureApp("test_token", timestamp, nonce, encryptedWrapper.Encrypt) - - req := httptest.NewRequest( - http.MethodPost, - "/webhook/wecom-app?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce, - bytes.NewReader(wrapperData), - ) - w := httptest.NewRecorder() - - ch.handleWebhook(w, req) - - // Should not be method not allowed - if w.Code == http.StatusMethodNotAllowed { - t.Error("POST request should not return Method Not Allowed") - } - }) - - t.Run("unsupported method", func(t *testing.T) { - req := httptest.NewRequest(http.MethodPut, "/webhook/wecom-app", nil) - w := httptest.NewRecorder() - - ch.handleWebhook(w, req) - - if w.Code != http.StatusMethodNotAllowed { - t.Errorf("status code = %d, want %d", w.Code, http.StatusMethodNotAllowed) - } - }) -} - -func TestWeComAppHandleHealth(t *testing.T) { - msgBus := bus.NewMessageBus() - cfg := config.WeComAppConfig{ - CorpID: "test_corp_id", - CorpSecret: "test_secret", - AgentID: 1000002, - } - ch, _ := NewWeComAppChannel(cfg, msgBus) - - req := httptest.NewRequest(http.MethodGet, "/health/wecom-app", nil) - w := httptest.NewRecorder() - - ch.handleHealth(w, req) - - if w.Code != http.StatusOK { - t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) - } - - contentType := w.Header().Get("Content-Type") - if contentType != "application/json" { - t.Errorf("Content-Type 
= %q, want %q", contentType, "application/json") - } - - body := w.Body.String() - if !strings.Contains(body, "status") || !strings.Contains(body, "running") || !strings.Contains(body, "has_token") { - t.Errorf("response body should contain status, running, and has_token fields, got: %s", body) - } -} - -func TestWeComAppAccessToken(t *testing.T) { - msgBus := bus.NewMessageBus() - cfg := config.WeComAppConfig{ - CorpID: "test_corp_id", - CorpSecret: "test_secret", - AgentID: 1000002, - } - ch, _ := NewWeComAppChannel(cfg, msgBus) - - t.Run("get empty access token initially", func(t *testing.T) { - token := ch.getAccessToken() - if token != "" { - t.Errorf("getAccessToken() = %q, want empty string", token) - } - }) - - t.Run("set and get access token", func(t *testing.T) { - ch.tokenMu.Lock() - ch.accessToken = "test_token_123" - ch.tokenExpiry = time.Now().Add(1 * time.Hour) - ch.tokenMu.Unlock() - - token := ch.getAccessToken() - if token != "test_token_123" { - t.Errorf("getAccessToken() = %q, want %q", token, "test_token_123") - } - }) - - t.Run("expired token returns empty", func(t *testing.T) { - ch.tokenMu.Lock() - ch.accessToken = "expired_token" - ch.tokenExpiry = time.Now().Add(-1 * time.Hour) - ch.tokenMu.Unlock() - - token := ch.getAccessToken() - if token != "" { - t.Errorf("getAccessToken() = %q, want empty string for expired token", token) - } - }) -} - -func TestWeComAppMessageStructures(t *testing.T) { - t.Run("WeComTextMessage structure", func(t *testing.T) { - msg := WeComTextMessage{ - ToUser: "user123", - MsgType: "text", - AgentID: 1000002, - } - msg.Text.Content = "Hello World" - - if msg.ToUser != "user123" { - t.Errorf("ToUser = %q, want %q", msg.ToUser, "user123") - } - if msg.MsgType != "text" { - t.Errorf("MsgType = %q, want %q", msg.MsgType, "text") - } - if msg.AgentID != 1000002 { - t.Errorf("AgentID = %d, want %d", msg.AgentID, 1000002) - } - if msg.Text.Content != "Hello World" { - t.Errorf("Text.Content = %q, want %q", 
msg.Text.Content, "Hello World") - } - - // Test JSON marshaling - jsonData, err := json.Marshal(msg) - if err != nil { - t.Fatalf("failed to marshal JSON: %v", err) - } - - var unmarshaled WeComTextMessage - err = json.Unmarshal(jsonData, &unmarshaled) - if err != nil { - t.Fatalf("failed to unmarshal JSON: %v", err) - } - - if unmarshaled.ToUser != msg.ToUser { - t.Errorf("JSON round-trip failed for ToUser") - } - }) - - t.Run("WeComMarkdownMessage structure", func(t *testing.T) { - msg := WeComMarkdownMessage{ - ToUser: "user123", - MsgType: "markdown", - AgentID: 1000002, - } - msg.Markdown.Content = "# Hello\nWorld" - - if msg.Markdown.Content != "# Hello\nWorld" { - t.Errorf("Markdown.Content = %q, want %q", msg.Markdown.Content, "# Hello\nWorld") - } - - // Test JSON marshaling - jsonData, err := json.Marshal(msg) - if err != nil { - t.Fatalf("failed to marshal JSON: %v", err) - } - - if !bytes.Contains(jsonData, []byte("markdown")) { - t.Error("JSON should contain 'markdown' field") - } - }) - - t.Run("WeComAccessTokenResponse structure", func(t *testing.T) { - jsonData := `{ - "errcode": 0, - "errmsg": "ok", - "access_token": "test_access_token", - "expires_in": 7200 - }` - - var resp WeComAccessTokenResponse - err := json.Unmarshal([]byte(jsonData), &resp) - if err != nil { - t.Fatalf("failed to unmarshal JSON: %v", err) - } - - if resp.ErrCode != 0 { - t.Errorf("ErrCode = %d, want %d", resp.ErrCode, 0) - } - if resp.ErrMsg != "ok" { - t.Errorf("ErrMsg = %q, want %q", resp.ErrMsg, "ok") - } - if resp.AccessToken != "test_access_token" { - t.Errorf("AccessToken = %q, want %q", resp.AccessToken, "test_access_token") - } - if resp.ExpiresIn != 7200 { - t.Errorf("ExpiresIn = %d, want %d", resp.ExpiresIn, 7200) - } - }) - - t.Run("WeComSendMessageResponse structure", func(t *testing.T) { - jsonData := `{ - "errcode": 0, - "errmsg": "ok", - "invaliduser": "", - "invalidparty": "", - "invalidtag": "" - }` - - var resp WeComSendMessageResponse - err := 
json.Unmarshal([]byte(jsonData), &resp) - if err != nil { - t.Fatalf("failed to unmarshal JSON: %v", err) - } - - if resp.ErrCode != 0 { - t.Errorf("ErrCode = %d, want %d", resp.ErrCode, 0) - } - if resp.ErrMsg != "ok" { - t.Errorf("ErrMsg = %q, want %q", resp.ErrMsg, "ok") - } - }) -} - -func TestWeComAppXMLMessageStructure(t *testing.T) { - xmlData := ` - - - - 1234567890 - - - 1234567890123456 - 1000002 -` - - var msg WeComXMLMessage - err := xml.Unmarshal([]byte(xmlData), &msg) - if err != nil { - t.Fatalf("failed to unmarshal XML: %v", err) - } - - if msg.ToUserName != "corp_id" { - t.Errorf("ToUserName = %q, want %q", msg.ToUserName, "corp_id") - } - if msg.FromUserName != "user123" { - t.Errorf("FromUserName = %q, want %q", msg.FromUserName, "user123") - } - if msg.CreateTime != 1234567890 { - t.Errorf("CreateTime = %d, want %d", msg.CreateTime, 1234567890) - } - if msg.MsgType != "text" { - t.Errorf("MsgType = %q, want %q", msg.MsgType, "text") - } - if msg.Content != "Hello World" { - t.Errorf("Content = %q, want %q", msg.Content, "Hello World") - } - if msg.MsgId != 1234567890123456 { - t.Errorf("MsgId = %d, want %d", msg.MsgId, 1234567890123456) - } - if msg.AgentID != 1000002 { - t.Errorf("AgentID = %d, want %d", msg.AgentID, 1000002) - } -} - -func TestWeComAppXMLMessageImage(t *testing.T) { - xmlData := ` - - - - 1234567890 - - - - 1234567890123456 - 1000002 -` - - var msg WeComXMLMessage - err := xml.Unmarshal([]byte(xmlData), &msg) - if err != nil { - t.Fatalf("failed to unmarshal XML: %v", err) - } - - if msg.MsgType != "image" { - t.Errorf("MsgType = %q, want %q", msg.MsgType, "image") - } - if msg.PicUrl != "https://example.com/image.jpg" { - t.Errorf("PicUrl = %q, want %q", msg.PicUrl, "https://example.com/image.jpg") - } - if msg.MediaId != "media_123" { - t.Errorf("MediaId = %q, want %q", msg.MediaId, "media_123") - } -} - -func TestWeComAppXMLMessageVoice(t *testing.T) { - xmlData := ` - - - - 1234567890 - - - - 1234567890123456 - 1000002 -` 
- - var msg WeComXMLMessage - err := xml.Unmarshal([]byte(xmlData), &msg) - if err != nil { - t.Fatalf("failed to unmarshal XML: %v", err) - } - - if msg.MsgType != "voice" { - t.Errorf("MsgType = %q, want %q", msg.MsgType, "voice") - } - if msg.Format != "amr" { - t.Errorf("Format = %q, want %q", msg.Format, "amr") - } -} - -func TestWeComAppXMLMessageLocation(t *testing.T) { - xmlData := ` - - - - 1234567890 - - 39.9042 - 116.4074 - 16 - - 1234567890123456 - 1000002 -` - - var msg WeComXMLMessage - err := xml.Unmarshal([]byte(xmlData), &msg) - if err != nil { - t.Fatalf("failed to unmarshal XML: %v", err) - } - - if msg.MsgType != "location" { - t.Errorf("MsgType = %q, want %q", msg.MsgType, "location") - } - if msg.LocationX != 39.9042 { - t.Errorf("LocationX = %f, want %f", msg.LocationX, 39.9042) - } - if msg.LocationY != 116.4074 { - t.Errorf("LocationY = %f, want %f", msg.LocationY, 116.4074) - } - if msg.Scale != 16 { - t.Errorf("Scale = %d, want %d", msg.Scale, 16) - } - if msg.Label != "Beijing" { - t.Errorf("Label = %q, want %q", msg.Label, "Beijing") - } -} - -func TestWeComAppXMLMessageLink(t *testing.T) { - xmlData := ` - - - - 1234567890 - - <![CDATA[Link Title]]> - - - 1234567890123456 - 1000002 -` - - var msg WeComXMLMessage - err := xml.Unmarshal([]byte(xmlData), &msg) - if err != nil { - t.Fatalf("failed to unmarshal XML: %v", err) - } - - if msg.MsgType != "link" { - t.Errorf("MsgType = %q, want %q", msg.MsgType, "link") - } - if msg.Title != "Link Title" { - t.Errorf("Title = %q, want %q", msg.Title, "Link Title") - } - if msg.Description != "Link Description" { - t.Errorf("Description = %q, want %q", msg.Description, "Link Description") - } - if msg.Url != "https://example.com" { - t.Errorf("Url = %q, want %q", msg.Url, "https://example.com") - } -} - -func TestWeComAppXMLMessageEvent(t *testing.T) { - xmlData := ` - - - - 1234567890 - - - - 1000002 -` - - var msg WeComXMLMessage - err := xml.Unmarshal([]byte(xmlData), &msg) - if err != nil { 
- t.Fatalf("failed to unmarshal XML: %v", err) - } - - if msg.MsgType != "event" { - t.Errorf("MsgType = %q, want %q", msg.MsgType, "event") - } - if msg.Event != "subscribe" { - t.Errorf("Event = %q, want %q", msg.Event, "subscribe") - } - if msg.EventKey != "event_key_123" { - t.Errorf("EventKey = %q, want %q", msg.EventKey, "event_key_123") - } -} diff --git a/pkg/channels/wecom_test.go b/pkg/channels/wecom_test.go deleted file mode 100644 index 8afa7e8c3..000000000 --- a/pkg/channels/wecom_test.go +++ /dev/null @@ -1,785 +0,0 @@ -// PicoClaw - Ultra-lightweight personal AI agent -// WeCom Bot (企业微信智能机器人) channel tests - -package channels - -import ( - "bytes" - "context" - "crypto/aes" - "crypto/cipher" - "crypto/sha1" - "encoding/base64" - "encoding/binary" - "encoding/json" - "encoding/xml" - "fmt" - "net/http" - "net/http/httptest" - "sort" - "strings" - "testing" - - "github.com/sipeed/picoclaw/pkg/bus" - "github.com/sipeed/picoclaw/pkg/config" -) - -// generateTestAESKey generates a valid test AES key -func generateTestAESKey() string { - // AES key needs to be 32 bytes (256 bits) for AES-256 - key := make([]byte, 32) - for i := range key { - key[i] = byte(i) - } - // Return base64 encoded key without padding - return base64.StdEncoding.EncodeToString(key)[:43] -} - -// encryptTestMessage encrypts a message for testing (AIBOT JSON format) -func encryptTestMessage(message, aesKey string) (string, error) { - // Decode AES key - key, err := base64.StdEncoding.DecodeString(aesKey + "=") - if err != nil { - return "", err - } - - // Prepare message: random(16) + msg_len(4) + msg + receiveid - random := make([]byte, 0, 16) - for i := 0; i < 16; i++ { - random = append(random, byte(i)) - } - - msgBytes := []byte(message) - receiveID := []byte("test_aibot_id") - - msgLen := uint32(len(msgBytes)) - lenBytes := make([]byte, 4) - binary.BigEndian.PutUint32(lenBytes, msgLen) - - plainText := append(random, lenBytes...) - plainText = append(plainText, msgBytes...) 
- plainText = append(plainText, receiveID...) - - // PKCS7 padding - blockSize := aes.BlockSize - padding := blockSize - len(plainText)%blockSize - padText := bytes.Repeat([]byte{byte(padding)}, padding) - plainText = append(plainText, padText...) - - // Encrypt - block, err := aes.NewCipher(key) - if err != nil { - return "", err - } - - mode := cipher.NewCBCEncrypter(block, key[:aes.BlockSize]) - cipherText := make([]byte, len(plainText)) - mode.CryptBlocks(cipherText, plainText) - - return base64.StdEncoding.EncodeToString(cipherText), nil -} - -// generateSignature generates a signature for testing -func generateSignature(token, timestamp, nonce, msgEncrypt string) string { - params := []string{token, timestamp, nonce, msgEncrypt} - sort.Strings(params) - str := strings.Join(params, "") - hash := sha1.Sum([]byte(str)) - return fmt.Sprintf("%x", hash) -} - -func TestNewWeComBotChannel(t *testing.T) { - msgBus := bus.NewMessageBus() - - t.Run("missing token", func(t *testing.T) { - cfg := config.WeComConfig{ - Token: "", - WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", - } - _, err := NewWeComBotChannel(cfg, msgBus) - if err == nil { - t.Error("expected error for missing token, got nil") - } - }) - - t.Run("missing webhook_url", func(t *testing.T) { - cfg := config.WeComConfig{ - Token: "test_token", - WebhookURL: "", - } - _, err := NewWeComBotChannel(cfg, msgBus) - if err == nil { - t.Error("expected error for missing webhook_url, got nil") - } - }) - - t.Run("valid config", func(t *testing.T) { - cfg := config.WeComConfig{ - Token: "test_token", - WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", - AllowFrom: []string{"user1", "user2"}, - } - ch, err := NewWeComBotChannel(cfg, msgBus) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if ch.Name() != "wecom" { - t.Errorf("Name() = %q, want %q", ch.Name(), "wecom") - } - if ch.IsRunning() { - t.Error("new channel should not be running") - } - }) -} - 
-func TestWeComBotChannelIsAllowed(t *testing.T) { - msgBus := bus.NewMessageBus() - - t.Run("empty allowlist allows all", func(t *testing.T) { - cfg := config.WeComConfig{ - Token: "test_token", - WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", - AllowFrom: []string{}, - } - ch, _ := NewWeComBotChannel(cfg, msgBus) - if !ch.IsAllowed("any_user") { - t.Error("empty allowlist should allow all users") - } - }) - - t.Run("allowlist restricts users", func(t *testing.T) { - cfg := config.WeComConfig{ - Token: "test_token", - WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", - AllowFrom: []string{"allowed_user"}, - } - ch, _ := NewWeComBotChannel(cfg, msgBus) - if !ch.IsAllowed("allowed_user") { - t.Error("allowed user should pass allowlist check") - } - if ch.IsAllowed("blocked_user") { - t.Error("non-allowed user should be blocked") - } - }) -} - -func TestWeComBotVerifySignature(t *testing.T) { - msgBus := bus.NewMessageBus() - cfg := config.WeComConfig{ - Token: "test_token", - WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", - } - ch, _ := NewWeComBotChannel(cfg, msgBus) - - t.Run("valid signature", func(t *testing.T) { - timestamp := "1234567890" - nonce := "test_nonce" - msgEncrypt := "test_message" - expectedSig := generateSignature("test_token", timestamp, nonce, msgEncrypt) - - if !WeComVerifySignature(ch.config.Token, expectedSig, timestamp, nonce, msgEncrypt) { - t.Error("valid signature should pass verification") - } - }) - - t.Run("invalid signature", func(t *testing.T) { - timestamp := "1234567890" - nonce := "test_nonce" - msgEncrypt := "test_message" - - if WeComVerifySignature(ch.config.Token, "invalid_sig", timestamp, nonce, msgEncrypt) { - t.Error("invalid signature should fail verification") - } - }) - - t.Run("empty token skips verification", func(t *testing.T) { - // Create a channel manually with empty token to test the behavior - cfgEmpty := config.WeComConfig{ - Token: "", - 
WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", - } - chEmpty := &WeComBotChannel{ - config: cfgEmpty, - } - - if !WeComVerifySignature(chEmpty.config.Token, "any_sig", "any_ts", "any_nonce", "any_msg") { - t.Error("empty token should skip verification and return true") - } - }) -} - -func TestWeComBotDecryptMessage(t *testing.T) { - msgBus := bus.NewMessageBus() - - t.Run("decrypt without AES key", func(t *testing.T) { - cfg := config.WeComConfig{ - Token: "test_token", - WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", - EncodingAESKey: "", - } - ch, _ := NewWeComBotChannel(cfg, msgBus) - - // Without AES key, message should be base64 decoded only - plainText := "hello world" - encoded := base64.StdEncoding.EncodeToString([]byte(plainText)) - - result, err := WeComDecryptMessage(encoded, ch.config.EncodingAESKey) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if result != plainText { - t.Errorf("decryptMessage() = %q, want %q", result, plainText) - } - }) - - t.Run("decrypt with AES key", func(t *testing.T) { - aesKey := generateTestAESKey() - cfg := config.WeComConfig{ - Token: "test_token", - WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", - EncodingAESKey: aesKey, - } - ch, _ := NewWeComBotChannel(cfg, msgBus) - - originalMsg := "Hello" - encrypted, err := encryptTestMessage(originalMsg, aesKey) - if err != nil { - t.Fatalf("failed to encrypt test message: %v", err) - } - - result, err := WeComDecryptMessage(encrypted, ch.config.EncodingAESKey) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - if result != originalMsg { - t.Errorf("WeComDecryptMessage() = %q, want %q", result, originalMsg) - } - }) - - t.Run("invalid base64", func(t *testing.T) { - cfg := config.WeComConfig{ - Token: "test_token", - WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", - EncodingAESKey: "", - } - ch, _ := NewWeComBotChannel(cfg, msgBus) - - _, err := 
WeComDecryptMessage("invalid_base64!!!", ch.config.EncodingAESKey) - if err == nil { - t.Error("expected error for invalid base64, got nil") - } - }) - - t.Run("invalid AES key", func(t *testing.T) { - cfg := config.WeComConfig{ - Token: "test_token", - WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", - EncodingAESKey: "invalid_key", - } - ch, _ := NewWeComBotChannel(cfg, msgBus) - - _, err := WeComDecryptMessage(base64.StdEncoding.EncodeToString([]byte("test")), ch.config.EncodingAESKey) - if err == nil { - t.Error("expected error for invalid AES key, got nil") - } - }) -} - -func TestWeComBotPKCS7Unpad(t *testing.T) { - tests := []struct { - name string - input []byte - expected []byte - }{ - { - name: "empty input", - input: []byte{}, - expected: []byte{}, - }, - { - name: "valid padding 3 bytes", - input: append([]byte("hello"), bytes.Repeat([]byte{3}, 3)...), - expected: []byte("hello"), - }, - { - name: "valid padding 16 bytes (full block)", - input: append([]byte("123456789012345"), bytes.Repeat([]byte{16}, 16)...), - expected: []byte("123456789012345"), - }, - { - name: "invalid padding larger than data", - input: []byte{20}, - expected: nil, // should return error - }, - { - name: "invalid padding zero", - input: append([]byte("test"), byte(0)), - expected: nil, // should return error - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result, err := pkcs7UnpadWeCom(tt.input) - if tt.expected == nil { - // This case should return an error - if err == nil { - t.Errorf("pkcs7UnpadWeCom() expected error for invalid padding, got result: %v", result) - } - return - } - if err != nil { - t.Errorf("pkcs7UnpadWeCom() unexpected error: %v", err) - return - } - if !bytes.Equal(result, tt.expected) { - t.Errorf("pkcs7UnpadWeCom() = %v, want %v", result, tt.expected) - } - }) - } -} - -func TestWeComBotHandleVerification(t *testing.T) { - msgBus := bus.NewMessageBus() - aesKey := generateTestAESKey() - cfg := 
config.WeComConfig{ - Token: "test_token", - EncodingAESKey: aesKey, - WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", - } - ch, _ := NewWeComBotChannel(cfg, msgBus) - - t.Run("valid verification request", func(t *testing.T) { - echostr := "test_echostr_123" - encryptedEchostr, _ := encryptTestMessage(echostr, aesKey) - timestamp := "1234567890" - nonce := "test_nonce" - signature := generateSignature("test_token", timestamp, nonce, encryptedEchostr) - - req := httptest.NewRequest( - http.MethodGet, - "/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encryptedEchostr, - nil, - ) - w := httptest.NewRecorder() - - ch.handleVerification(context.Background(), w, req) - - if w.Code != http.StatusOK { - t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) - } - if w.Body.String() != echostr { - t.Errorf("response body = %q, want %q", w.Body.String(), echostr) - } - }) - - t.Run("missing parameters", func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "/webhook/wecom?msg_signature=sig×tamp=ts", nil) - w := httptest.NewRecorder() - - ch.handleVerification(context.Background(), w, req) - - if w.Code != http.StatusBadRequest { - t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest) - } - }) - - t.Run("invalid signature", func(t *testing.T) { - echostr := "test_echostr" - encryptedEchostr, _ := encryptTestMessage(echostr, aesKey) - timestamp := "1234567890" - nonce := "test_nonce" - - req := httptest.NewRequest( - http.MethodGet, - "/webhook/wecom?msg_signature=invalid_sig×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encryptedEchostr, - nil, - ) - w := httptest.NewRecorder() - - ch.handleVerification(context.Background(), w, req) - - if w.Code != http.StatusForbidden { - t.Errorf("status code = %d, want %d", w.Code, http.StatusForbidden) - } - }) -} - -func TestWeComBotHandleMessageCallback(t *testing.T) { - msgBus := bus.NewMessageBus() - aesKey := generateTestAESKey() - cfg 
:= config.WeComConfig{ - Token: "test_token", - EncodingAESKey: aesKey, - WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", - } - ch, _ := NewWeComBotChannel(cfg, msgBus) - - t.Run("valid direct message callback", func(t *testing.T) { - // Create JSON message for direct chat (single) - jsonMsg := `{ - "msgid": "test_msg_id_123", - "aibotid": "test_aibot_id", - "chattype": "single", - "from": {"userid": "user123"}, - "response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", - "msgtype": "text", - "text": {"content": "Hello World"} - }` - - // Encrypt message - encrypted, _ := encryptTestMessage(jsonMsg, aesKey) - - // Create encrypted XML wrapper - encryptedWrapper := struct { - XMLName xml.Name `xml:"xml"` - Encrypt string `xml:"Encrypt"` - }{ - Encrypt: encrypted, - } - wrapperData, _ := xml.Marshal(encryptedWrapper) - - timestamp := "1234567890" - nonce := "test_nonce" - signature := generateSignature("test_token", timestamp, nonce, encrypted) - - req := httptest.NewRequest( - http.MethodPost, - "/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce, - bytes.NewReader(wrapperData), - ) - w := httptest.NewRecorder() - - ch.handleMessageCallback(context.Background(), w, req) - - if w.Code != http.StatusOK { - t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) - } - if w.Body.String() != "success" { - t.Errorf("response body = %q, want %q", w.Body.String(), "success") - } - }) - - t.Run("valid group message callback", func(t *testing.T) { - // Create JSON message for group chat - jsonMsg := `{ - "msgid": "test_msg_id_456", - "aibotid": "test_aibot_id", - "chatid": "group_chat_id_123", - "chattype": "group", - "from": {"userid": "user456"}, - "response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", - "msgtype": "text", - "text": {"content": "Hello Group"} - }` - - // Encrypt message - encrypted, _ := encryptTestMessage(jsonMsg, aesKey) - - // Create encrypted XML wrapper - 
encryptedWrapper := struct { - XMLName xml.Name `xml:"xml"` - Encrypt string `xml:"Encrypt"` - }{ - Encrypt: encrypted, - } - wrapperData, _ := xml.Marshal(encryptedWrapper) - - timestamp := "1234567890" - nonce := "test_nonce" - signature := generateSignature("test_token", timestamp, nonce, encrypted) - - req := httptest.NewRequest( - http.MethodPost, - "/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce, - bytes.NewReader(wrapperData), - ) - w := httptest.NewRecorder() - - ch.handleMessageCallback(context.Background(), w, req) - - if w.Code != http.StatusOK { - t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) - } - if w.Body.String() != "success" { - t.Errorf("response body = %q, want %q", w.Body.String(), "success") - } - }) - - t.Run("missing parameters", func(t *testing.T) { - req := httptest.NewRequest(http.MethodPost, "/webhook/wecom?msg_signature=sig", nil) - w := httptest.NewRecorder() - - ch.handleMessageCallback(context.Background(), w, req) - - if w.Code != http.StatusBadRequest { - t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest) - } - }) - - t.Run("invalid XML", func(t *testing.T) { - timestamp := "1234567890" - nonce := "test_nonce" - signature := generateSignature("test_token", timestamp, nonce, "") - - req := httptest.NewRequest( - http.MethodPost, - "/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce, - strings.NewReader("invalid xml"), - ) - w := httptest.NewRecorder() - - ch.handleMessageCallback(context.Background(), w, req) - - if w.Code != http.StatusBadRequest { - t.Errorf("status code = %d, want %d", w.Code, http.StatusBadRequest) - } - }) - - t.Run("invalid signature", func(t *testing.T) { - encryptedWrapper := struct { - XMLName xml.Name `xml:"xml"` - Encrypt string `xml:"Encrypt"` - }{ - Encrypt: "encrypted_data", - } - wrapperData, _ := xml.Marshal(encryptedWrapper) - - timestamp := "1234567890" - nonce := "test_nonce" - - req := httptest.NewRequest( - 
http.MethodPost, - "/webhook/wecom?msg_signature=invalid_sig×tamp="+timestamp+"&nonce="+nonce, - bytes.NewReader(wrapperData), - ) - w := httptest.NewRecorder() - - ch.handleMessageCallback(context.Background(), w, req) - - if w.Code != http.StatusForbidden { - t.Errorf("status code = %d, want %d", w.Code, http.StatusForbidden) - } - }) -} - -func TestWeComBotProcessMessage(t *testing.T) { - msgBus := bus.NewMessageBus() - cfg := config.WeComConfig{ - Token: "test_token", - WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", - } - ch, _ := NewWeComBotChannel(cfg, msgBus) - - t.Run("process direct text message", func(t *testing.T) { - msg := WeComBotMessage{ - MsgID: "test_msg_id_123", - AIBotID: "test_aibot_id", - ChatType: "single", - ResponseURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", - MsgType: "text", - } - msg.From.UserID = "user123" - msg.Text.Content = "Hello World" - - // Should not panic - ch.processMessage(context.Background(), msg) - }) - - t.Run("process group text message", func(t *testing.T) { - msg := WeComBotMessage{ - MsgID: "test_msg_id_456", - AIBotID: "test_aibot_id", - ChatID: "group_chat_id_123", - ChatType: "group", - ResponseURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", - MsgType: "text", - } - msg.From.UserID = "user456" - msg.Text.Content = "Hello Group" - - // Should not panic - ch.processMessage(context.Background(), msg) - }) - - t.Run("process voice message", func(t *testing.T) { - msg := WeComBotMessage{ - MsgID: "test_msg_id_789", - AIBotID: "test_aibot_id", - ChatType: "single", - ResponseURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", - MsgType: "voice", - } - msg.From.UserID = "user123" - msg.Voice.Content = "Voice message text" - - // Should not panic - ch.processMessage(context.Background(), msg) - }) - - t.Run("skip unsupported message type", func(t *testing.T) { - msg := WeComBotMessage{ - MsgID: "test_msg_id_000", - AIBotID: "test_aibot_id", - 
ChatType: "single", - ResponseURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", - MsgType: "video", - } - msg.From.UserID = "user123" - - // Should not panic - ch.processMessage(context.Background(), msg) - }) -} - -func TestWeComBotHandleWebhook(t *testing.T) { - msgBus := bus.NewMessageBus() - cfg := config.WeComConfig{ - Token: "test_token", - WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", - } - ch, _ := NewWeComBotChannel(cfg, msgBus) - - t.Run("GET request calls verification", func(t *testing.T) { - echostr := "test_echostr" - encoded := base64.StdEncoding.EncodeToString([]byte(echostr)) - timestamp := "1234567890" - nonce := "test_nonce" - signature := generateSignature("test_token", timestamp, nonce, encoded) - - req := httptest.NewRequest( - http.MethodGet, - "/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encoded, - nil, - ) - w := httptest.NewRecorder() - - ch.handleWebhook(w, req) - - if w.Code != http.StatusOK { - t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) - } - }) - - t.Run("POST request calls message callback", func(t *testing.T) { - encryptedWrapper := struct { - XMLName xml.Name `xml:"xml"` - Encrypt string `xml:"Encrypt"` - }{ - Encrypt: base64.StdEncoding.EncodeToString([]byte("test")), - } - wrapperData, _ := xml.Marshal(encryptedWrapper) - - timestamp := "1234567890" - nonce := "test_nonce" - signature := generateSignature("test_token", timestamp, nonce, encryptedWrapper.Encrypt) - - req := httptest.NewRequest( - http.MethodPost, - "/webhook/wecom?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce, - bytes.NewReader(wrapperData), - ) - w := httptest.NewRecorder() - - ch.handleWebhook(w, req) - - // Should not be method not allowed - if w.Code == http.StatusMethodNotAllowed { - t.Error("POST request should not return Method Not Allowed") - } - }) - - t.Run("unsupported method", func(t *testing.T) { - req := 
httptest.NewRequest(http.MethodPut, "/webhook/wecom", nil) - w := httptest.NewRecorder() - - ch.handleWebhook(w, req) - - if w.Code != http.StatusMethodNotAllowed { - t.Errorf("status code = %d, want %d", w.Code, http.StatusMethodNotAllowed) - } - }) -} - -func TestWeComBotHandleHealth(t *testing.T) { - msgBus := bus.NewMessageBus() - cfg := config.WeComConfig{ - Token: "test_token", - WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", - } - ch, _ := NewWeComBotChannel(cfg, msgBus) - - req := httptest.NewRequest(http.MethodGet, "/health/wecom", nil) - w := httptest.NewRecorder() - - ch.handleHealth(w, req) - - if w.Code != http.StatusOK { - t.Errorf("status code = %d, want %d", w.Code, http.StatusOK) - } - - contentType := w.Header().Get("Content-Type") - if contentType != "application/json" { - t.Errorf("Content-Type = %q, want %q", contentType, "application/json") - } - - body := w.Body.String() - if !strings.Contains(body, "status") || !strings.Contains(body, "running") { - t.Errorf("response body should contain status and running fields, got: %s", body) - } -} - -func TestWeComBotReplyMessage(t *testing.T) { - msg := WeComBotReplyMessage{ - MsgType: "text", - } - msg.Text.Content = "Hello World" - - if msg.MsgType != "text" { - t.Errorf("MsgType = %q, want %q", msg.MsgType, "text") - } - if msg.Text.Content != "Hello World" { - t.Errorf("Text.Content = %q, want %q", msg.Text.Content, "Hello World") - } -} - -func TestWeComBotMessageStructure(t *testing.T) { - jsonData := `{ - "msgid": "test_msg_id_123", - "aibotid": "test_aibot_id", - "chatid": "group_chat_id_123", - "chattype": "group", - "from": {"userid": "user123"}, - "response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", - "msgtype": "text", - "text": {"content": "Hello World"} - }` - - var msg WeComBotMessage - err := json.Unmarshal([]byte(jsonData), &msg) - if err != nil { - t.Fatalf("failed to unmarshal JSON: %v", err) - } - - if msg.MsgID != "test_msg_id_123" 
{ - t.Errorf("MsgID = %q, want %q", msg.MsgID, "test_msg_id_123") - } - if msg.AIBotID != "test_aibot_id" { - t.Errorf("AIBotID = %q, want %q", msg.AIBotID, "test_aibot_id") - } - if msg.ChatID != "group_chat_id_123" { - t.Errorf("ChatID = %q, want %q", msg.ChatID, "group_chat_id_123") - } - if msg.ChatType != "group" { - t.Errorf("ChatType = %q, want %q", msg.ChatType, "group") - } - if msg.From.UserID != "user123" { - t.Errorf("From.UserID = %q, want %q", msg.From.UserID, "user123") - } - if msg.MsgType != "text" { - t.Errorf("MsgType = %q, want %q", msg.MsgType, "text") - } - if msg.Text.Content != "Hello World" { - t.Errorf("Text.Content = %q, want %q", msg.Text.Content, "Hello World") - } -} diff --git a/pkg/channels/whatsapp.go b/pkg/channels/whatsapp.go deleted file mode 100644 index 958d850bb..000000000 --- a/pkg/channels/whatsapp.go +++ /dev/null @@ -1,192 +0,0 @@ -package channels - -import ( - "context" - "encoding/json" - "fmt" - "log" - "sync" - "time" - - "github.com/gorilla/websocket" - - "github.com/sipeed/picoclaw/pkg/bus" - "github.com/sipeed/picoclaw/pkg/config" - "github.com/sipeed/picoclaw/pkg/utils" -) - -type WhatsAppChannel struct { - *BaseChannel - conn *websocket.Conn - config config.WhatsAppConfig - url string - mu sync.Mutex - connected bool -} - -func NewWhatsAppChannel(cfg config.WhatsAppConfig, bus *bus.MessageBus) (*WhatsAppChannel, error) { - base := NewBaseChannel("whatsapp", cfg, bus, cfg.AllowFrom) - - return &WhatsAppChannel{ - BaseChannel: base, - config: cfg, - url: cfg.BridgeURL, - connected: false, - }, nil -} - -func (c *WhatsAppChannel) Start(ctx context.Context) error { - log.Printf("Starting WhatsApp channel connecting to %s...", c.url) - - dialer := websocket.DefaultDialer - dialer.HandshakeTimeout = 10 * time.Second - - conn, _, err := dialer.Dial(c.url, nil) - if err != nil { - return fmt.Errorf("failed to connect to WhatsApp bridge: %w", err) - } - - c.mu.Lock() - c.conn = conn - c.connected = true - c.mu.Unlock() - 
- c.setRunning(true) - log.Println("WhatsApp channel connected") - - go c.listen(ctx) - - return nil -} - -func (c *WhatsAppChannel) Stop(ctx context.Context) error { - log.Println("Stopping WhatsApp channel...") - - c.mu.Lock() - defer c.mu.Unlock() - - if c.conn != nil { - if err := c.conn.Close(); err != nil { - log.Printf("Error closing WhatsApp connection: %v", err) - } - c.conn = nil - } - - c.connected = false - c.setRunning(false) - - return nil -} - -func (c *WhatsAppChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { - c.mu.Lock() - defer c.mu.Unlock() - - if c.conn == nil { - return fmt.Errorf("whatsapp connection not established") - } - - payload := map[string]any{ - "type": "message", - "to": msg.ChatID, - "content": msg.Content, - } - - data, err := json.Marshal(payload) - if err != nil { - return fmt.Errorf("failed to marshal message: %w", err) - } - - if err := c.conn.WriteMessage(websocket.TextMessage, data); err != nil { - return fmt.Errorf("failed to send message: %w", err) - } - - return nil -} - -func (c *WhatsAppChannel) listen(ctx context.Context) { - for { - select { - case <-ctx.Done(): - return - default: - c.mu.Lock() - conn := c.conn - c.mu.Unlock() - - if conn == nil { - time.Sleep(1 * time.Second) - continue - } - - _, message, err := conn.ReadMessage() - if err != nil { - log.Printf("WhatsApp read error: %v", err) - time.Sleep(2 * time.Second) - continue - } - - var msg map[string]any - if err := json.Unmarshal(message, &msg); err != nil { - log.Printf("Failed to unmarshal WhatsApp message: %v", err) - continue - } - - msgType, ok := msg["type"].(string) - if !ok { - continue - } - - if msgType == "message" { - c.handleIncomingMessage(msg) - } - } - } -} - -func (c *WhatsAppChannel) handleIncomingMessage(msg map[string]any) { - senderID, ok := msg["from"].(string) - if !ok { - return - } - - chatID, ok := msg["chat"].(string) - if !ok { - chatID = senderID - } - - content, ok := msg["content"].(string) - if !ok { - 
content = "" - } - - var mediaPaths []string - if mediaData, ok := msg["media"].([]any); ok { - mediaPaths = make([]string, 0, len(mediaData)) - for _, m := range mediaData { - if path, ok := m.(string); ok { - mediaPaths = append(mediaPaths, path) - } - } - } - - metadata := make(map[string]string) - if messageID, ok := msg["id"].(string); ok { - metadata["message_id"] = messageID - } - if userName, ok := msg["from_name"].(string); ok { - metadata["user_name"] = userName - } - - if chatID == senderID { - metadata["peer_kind"] = "direct" - metadata["peer_id"] = senderID - } else { - metadata["peer_kind"] = "group" - metadata["peer_id"] = chatID - } - - log.Printf("WhatsApp message from %s: %s...", senderID, utils.Truncate(content, 50)) - - c.HandleMessage(senderID, chatID, content, mediaPaths, metadata) -} From 420eadc2ba3051e3b00c228fc14d7f12ecfcb24b Mon Sep 17 00:00:00 2001 From: Hoshina Date: Fri, 20 Feb 2026 23:52:41 +0800 Subject: [PATCH 05/28] refactor(channels): remove redundant setRunning method from BaseChannel --- pkg/channels/base.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pkg/channels/base.go b/pkg/channels/base.go index 3f0a766ea..ff734fdb0 100644 --- a/pkg/channels/base.go +++ b/pkg/channels/base.go @@ -98,10 +98,6 @@ func (c *BaseChannel) HandleMessage(senderID, chatID, content string, media []st c.bus.PublishInbound(msg) } -func (c *BaseChannel) setRunning(running bool) { - c.running = running -} - func (c *BaseChannel) SetRunning(running bool) { c.running = running } From b1cbaaba570b1276a4109724a039c8232d2bc437 Mon Sep 17 00:00:00 2001 From: Hoshina Date: Sat, 21 Feb 2026 00:00:29 +0800 Subject: [PATCH 06/28] refactor(channels): replace bool with atomic.Bool for running state in BaseChannel --- pkg/channels/base.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pkg/channels/base.go b/pkg/channels/base.go index ff734fdb0..5d77c6c0d 100644 --- a/pkg/channels/base.go +++ b/pkg/channels/base.go @@ -3,6 +3,7 @@ 
package channels import ( "context" "strings" + "sync/atomic" "github.com/sipeed/picoclaw/pkg/bus" ) @@ -19,7 +20,7 @@ type Channel interface { type BaseChannel struct { config any bus *bus.MessageBus - running bool + running atomic.Bool name string allowList []string } @@ -30,7 +31,6 @@ func NewBaseChannel(name string, config any, bus *bus.MessageBus, allowList []st bus: bus, name: name, allowList: allowList, - running: false, } } @@ -39,7 +39,7 @@ func (c *BaseChannel) Name() string { } func (c *BaseChannel) IsRunning() bool { - return c.running + return c.running.Load() } func (c *BaseChannel) IsAllowed(senderID string) bool { @@ -99,5 +99,5 @@ func (c *BaseChannel) HandleMessage(senderID, chatID, content string, media []st } func (c *BaseChannel) SetRunning(running bool) { - c.running = running + c.running.Store(running) } From 00fd70e1aa2e3068b8caa27db13666be76154985 Mon Sep 17 00:00:00 2001 From: Hoshina Date: Sat, 21 Feb 2026 16:35:56 +0800 Subject: [PATCH 07/28] fix: golangci-lint run --fix --- cmd/picoclaw/cmd_gateway.go | 22 ++--- pkg/channels/dingtalk/dingtalk.go | 13 ++- pkg/channels/discord/discord.go | 5 +- pkg/channels/feishu/feishu_64.go | 6 +- pkg/channels/line/line.go | 36 +++---- pkg/channels/maixcam/maixcam.go | 26 ++--- pkg/channels/manager.go | 28 +++--- pkg/channels/onebot/onebot.go | 91 +++++++++--------- pkg/channels/qq/qq.go | 10 +- pkg/channels/slack/slack.go | 22 ++--- pkg/channels/telegram/telegram.go | 36 +++---- pkg/channels/telegram/telegram_commands.go | 3 + pkg/channels/wecom/app.go | 40 ++++---- pkg/channels/wecom/app_test.go | 73 ++++++++++---- pkg/channels/wecom/bot.go | 27 +++--- pkg/channels/wecom/bot_test.go | 105 +++++++++++++-------- pkg/channels/whatsapp/whatsapp.go | 8 +- 17 files changed, 315 insertions(+), 236 deletions(-) diff --git a/cmd/picoclaw/cmd_gateway.go b/cmd/picoclaw/cmd_gateway.go index 29b31e071..c62c868e3 100644 --- a/cmd/picoclaw/cmd_gateway.go +++ b/cmd/picoclaw/cmd_gateway.go @@ -16,9 +16,17 @@ import ( 
"github.com/sipeed/picoclaw/pkg/agent" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" + _ "github.com/sipeed/picoclaw/pkg/channels/dingtalk" dch "github.com/sipeed/picoclaw/pkg/channels/discord" + _ "github.com/sipeed/picoclaw/pkg/channels/feishu" + _ "github.com/sipeed/picoclaw/pkg/channels/line" + _ "github.com/sipeed/picoclaw/pkg/channels/maixcam" + _ "github.com/sipeed/picoclaw/pkg/channels/onebot" + _ "github.com/sipeed/picoclaw/pkg/channels/qq" slackch "github.com/sipeed/picoclaw/pkg/channels/slack" - tgram "github.com/sipeed/picoclaw/pkg/channels/telegram" + tgramch "github.com/sipeed/picoclaw/pkg/channels/telegram" + _ "github.com/sipeed/picoclaw/pkg/channels/wecom" + _ "github.com/sipeed/picoclaw/pkg/channels/whatsapp" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/cron" "github.com/sipeed/picoclaw/pkg/devices" @@ -29,16 +37,6 @@ import ( "github.com/sipeed/picoclaw/pkg/state" "github.com/sipeed/picoclaw/pkg/tools" "github.com/sipeed/picoclaw/pkg/voice" - - // Channel factory registrations (blank imports trigger init()) - _ "github.com/sipeed/picoclaw/pkg/channels/dingtalk" - _ "github.com/sipeed/picoclaw/pkg/channels/feishu" - _ "github.com/sipeed/picoclaw/pkg/channels/line" - _ "github.com/sipeed/picoclaw/pkg/channels/maixcam" - _ "github.com/sipeed/picoclaw/pkg/channels/onebot" - _ "github.com/sipeed/picoclaw/pkg/channels/qq" - _ "github.com/sipeed/picoclaw/pkg/channels/wecom" - _ "github.com/sipeed/picoclaw/pkg/channels/whatsapp" ) func gatewayCmd() { @@ -151,7 +149,7 @@ func gatewayCmd() { if transcriber != nil { if telegramChannel, ok := channelManager.GetChannel("telegram"); ok { - if tc, ok := telegramChannel.(*tgram.TelegramChannel); ok { + if tc, ok := telegramChannel.(*tgramch.TelegramChannel); ok { tc.SetTranscriber(transcriber) logger.InfoC("voice", "Groq transcription attached to Telegram channel") } diff --git a/pkg/channels/dingtalk/dingtalk.go b/pkg/channels/dingtalk/dingtalk.go 
index 0edb0023c..afc0de47f 100644 --- a/pkg/channels/dingtalk/dingtalk.go +++ b/pkg/channels/dingtalk/dingtalk.go @@ -10,6 +10,7 @@ import ( "github.com/open-dingtalk/dingtalk-stream-sdk-go/chatbot" "github.com/open-dingtalk/dingtalk-stream-sdk-go/client" + "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" @@ -109,7 +110,7 @@ func (c *DingTalkChannel) Send(ctx context.Context, msg bus.OutboundMessage) err return fmt.Errorf("invalid session_webhook type for chat %s", msg.ChatID) } - logger.DebugCF("dingtalk", "Sending message", map[string]interface{}{ + logger.DebugCF("dingtalk", "Sending message", map[string]any{ "chat_id": msg.ChatID, "preview": utils.Truncate(msg.Content, 100), }) @@ -121,12 +122,15 @@ func (c *DingTalkChannel) Send(ctx context.Context, msg bus.OutboundMessage) err // onChatBotMessageReceived implements the IChatBotMessageHandler function signature // This is called by the Stream SDK when a new message arrives // IChatBotMessageHandler is: func(c context.Context, data *chatbot.BotCallbackDataModel) ([]byte, error) -func (c *DingTalkChannel) onChatBotMessageReceived(ctx context.Context, data *chatbot.BotCallbackDataModel) ([]byte, error) { +func (c *DingTalkChannel) onChatBotMessageReceived( + ctx context.Context, + data *chatbot.BotCallbackDataModel, +) ([]byte, error) { // Extract message content from Text field content := data.Text.Content if content == "" { // Try to extract from Content interface{} if Text is empty - if contentMap, ok := data.Content.(map[string]interface{}); ok { + if contentMap, ok := data.Content.(map[string]any); ok { if textContent, ok := contentMap["content"].(string); ok { content = textContent } @@ -164,7 +168,7 @@ func (c *DingTalkChannel) onChatBotMessageReceived(ctx context.Context, data *ch metadata["peer_id"] = data.ConversationId } - logger.DebugCF("dingtalk", "Received message", map[string]interface{}{ + logger.DebugCF("dingtalk", "Received 
message", map[string]any{ "sender_nick": senderNick, "sender_id": senderID, "preview": utils.Truncate(content, 50), @@ -193,7 +197,6 @@ func (c *DingTalkChannel) SendDirectReply(ctx context.Context, sessionWebhook, c titleBytes, contentBytes, ) - if err != nil { return fmt.Errorf("failed to send reply: %w", err) } diff --git a/pkg/channels/discord/discord.go b/pkg/channels/discord/discord.go index 6c4efd87c..b83ac28fd 100644 --- a/pkg/channels/discord/discord.go +++ b/pkg/channels/discord/discord.go @@ -9,6 +9,7 @@ import ( "time" "github.com/bwmarrin/discordgo" + "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" @@ -322,7 +323,7 @@ func (c *DiscordChannel) startTyping(chatID string) { go func() { if err := c.session.ChannelTyping(chatID); err != nil { - logger.DebugCF("discord", "ChannelTyping error", map[string]interface{}{"chatID": chatID, "err": err}) + logger.DebugCF("discord", "ChannelTyping error", map[string]any{"chatID": chatID, "err": err}) } ticker := time.NewTicker(8 * time.Second) defer ticker.Stop() @@ -337,7 +338,7 @@ func (c *DiscordChannel) startTyping(chatID string) { return case <-ticker.C: if err := c.session.ChannelTyping(chatID); err != nil { - logger.DebugCF("discord", "ChannelTyping error", map[string]interface{}{"chatID": chatID, "err": err}) + logger.DebugCF("discord", "ChannelTyping error", map[string]any{"chatID": chatID, "err": err}) } } } diff --git a/pkg/channels/feishu/feishu_64.go b/pkg/channels/feishu/feishu_64.go index a49ee34cb..aa4e141c4 100644 --- a/pkg/channels/feishu/feishu_64.go +++ b/pkg/channels/feishu/feishu_64.go @@ -66,7 +66,7 @@ func (c *FeishuChannel) Start(ctx context.Context) error { go func() { if err := wsClient.Start(runCtx); err != nil { - logger.ErrorCF("feishu", "Feishu websocket stopped with error", map[string]interface{}{ + logger.ErrorCF("feishu", "Feishu websocket stopped with error", map[string]any{ "error": err.Error(), }) } @@ -122,7 
+122,7 @@ func (c *FeishuChannel) Send(ctx context.Context, msg bus.OutboundMessage) error return fmt.Errorf("feishu api error: code=%d msg=%s", resp.Code, resp.Msg) } - logger.DebugCF("feishu", "Feishu message sent", map[string]interface{}{ + logger.DebugCF("feishu", "Feishu message sent", map[string]any{ "chat_id": msg.ChatID, }) @@ -175,7 +175,7 @@ func (c *FeishuChannel) handleMessageReceive(_ context.Context, event *larkim.P2 metadata["peer_id"] = chatID } - logger.InfoCF("feishu", "Feishu message received", map[string]interface{}{ + logger.InfoCF("feishu", "Feishu message received", map[string]any{ "sender_id": senderID, "chat_id": chatID, "preview": utils.Truncate(content, 80), diff --git a/pkg/channels/line/line.go b/pkg/channels/line/line.go index 7df0491d9..4e1d0dfd3 100644 --- a/pkg/channels/line/line.go +++ b/pkg/channels/line/line.go @@ -76,11 +76,11 @@ func (c *LINEChannel) Start(ctx context.Context) error { // Fetch bot profile to get bot's userId for mention detection if err := c.fetchBotInfo(); err != nil { - logger.WarnCF("line", "Failed to fetch bot info (mention detection disabled)", map[string]interface{}{ + logger.WarnCF("line", "Failed to fetch bot info (mention detection disabled)", map[string]any{ "error": err.Error(), }) } else { - logger.InfoCF("line", "Bot info fetched", map[string]interface{}{ + logger.InfoCF("line", "Bot info fetched", map[string]any{ "bot_user_id": c.botUserID, "basic_id": c.botBasicID, "display_name": c.botDisplayName, @@ -101,12 +101,12 @@ func (c *LINEChannel) Start(ctx context.Context) error { } go func() { - logger.InfoCF("line", "LINE webhook server listening", map[string]interface{}{ + logger.InfoCF("line", "LINE webhook server listening", map[string]any{ "addr": addr, "path": path, }) if err := c.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { - logger.ErrorCF("line", "Webhook server error", map[string]interface{}{ + logger.ErrorCF("line", "Webhook server error", map[string]any{ 
"error": err.Error(), }) } @@ -163,7 +163,7 @@ func (c *LINEChannel) Stop(ctx context.Context) error { shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() if err := c.httpServer.Shutdown(shutdownCtx); err != nil { - logger.ErrorCF("line", "Webhook server shutdown error", map[string]interface{}{ + logger.ErrorCF("line", "Webhook server shutdown error", map[string]any{ "error": err.Error(), }) } @@ -183,7 +183,7 @@ func (c *LINEChannel) webhookHandler(w http.ResponseWriter, r *http.Request) { body, err := io.ReadAll(r.Body) if err != nil { - logger.ErrorCF("line", "Failed to read request body", map[string]interface{}{ + logger.ErrorCF("line", "Failed to read request body", map[string]any{ "error": err.Error(), }) http.Error(w, "Bad request", http.StatusBadRequest) @@ -201,7 +201,7 @@ func (c *LINEChannel) webhookHandler(w http.ResponseWriter, r *http.Request) { Events []lineEvent `json:"events"` } if err := json.Unmarshal(body, &payload); err != nil { - logger.ErrorCF("line", "Failed to parse webhook payload", map[string]interface{}{ + logger.ErrorCF("line", "Failed to parse webhook payload", map[string]any{ "error": err.Error(), }) http.Error(w, "Bad request", http.StatusBadRequest) @@ -267,7 +267,7 @@ type lineMentionee struct { func (c *LINEChannel) processEvent(event lineEvent) { if event.Type != "message" { - logger.DebugCF("line", "Ignoring non-message event", map[string]interface{}{ + logger.DebugCF("line", "Ignoring non-message event", map[string]any{ "type": event.Type, }) return @@ -279,7 +279,7 @@ func (c *LINEChannel) processEvent(event lineEvent) { var msg lineMessage if err := json.Unmarshal(event.Message, &msg); err != nil { - logger.ErrorCF("line", "Failed to parse message", map[string]interface{}{ + logger.ErrorCF("line", "Failed to parse message", map[string]any{ "error": err.Error(), }) return @@ -287,7 +287,7 @@ func (c *LINEChannel) processEvent(event lineEvent) { // In group chats, only respond when the bot is mentioned 
if isGroup && !c.isBotMentioned(msg) { - logger.DebugCF("line", "Ignoring group message without mention", map[string]interface{}{ + logger.DebugCF("line", "Ignoring group message without mention", map[string]any{ "chat_id": chatID, }) return @@ -313,7 +313,7 @@ func (c *LINEChannel) processEvent(event lineEvent) { defer func() { for _, file := range localFiles { if err := os.Remove(file); err != nil { - logger.DebugCF("line", "Failed to cleanup temp file", map[string]interface{}{ + logger.DebugCF("line", "Failed to cleanup temp file", map[string]any{ "file": file, "error": err.Error(), }) @@ -375,7 +375,7 @@ func (c *LINEChannel) processEvent(event lineEvent) { metadata["peer_id"] = senderID } - logger.DebugCF("line", "Received message", map[string]interface{}{ + logger.DebugCF("line", "Received message", map[string]any{ "sender_id": senderID, "chat_id": chatID, "message_type": msg.Type, @@ -506,7 +506,7 @@ func (c *LINEChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { tokenEntry := entry.(replyTokenEntry) if time.Since(tokenEntry.timestamp) < lineReplyTokenMaxAge { if err := c.sendReply(ctx, tokenEntry.token, msg.Content, quoteToken); err == nil { - logger.DebugCF("line", "Message sent via Reply API", map[string]interface{}{ + logger.DebugCF("line", "Message sent via Reply API", map[string]any{ "chat_id": msg.ChatID, "quoted": quoteToken != "", }) @@ -534,7 +534,7 @@ func buildTextMessage(content, quoteToken string) map[string]string { // sendReply sends a message using the LINE Reply API. func (c *LINEChannel) sendReply(ctx context.Context, replyToken, content, quoteToken string) error { - payload := map[string]interface{}{ + payload := map[string]any{ "replyToken": replyToken, "messages": []map[string]string{buildTextMessage(content, quoteToken)}, } @@ -544,7 +544,7 @@ func (c *LINEChannel) sendReply(ctx context.Context, replyToken, content, quoteT // sendPush sends a message using the LINE Push API. 
func (c *LINEChannel) sendPush(ctx context.Context, to, content, quoteToken string) error { - payload := map[string]interface{}{ + payload := map[string]any{ "to": to, "messages": []map[string]string{buildTextMessage(content, quoteToken)}, } @@ -554,19 +554,19 @@ func (c *LINEChannel) sendPush(ctx context.Context, to, content, quoteToken stri // sendLoading sends a loading animation indicator to the chat. func (c *LINEChannel) sendLoading(chatID string) { - payload := map[string]interface{}{ + payload := map[string]any{ "chatId": chatID, "loadingSeconds": 60, } if err := c.callAPI(c.ctx, lineLoadingEndpoint, payload); err != nil { - logger.DebugCF("line", "Failed to send loading indicator", map[string]interface{}{ + logger.DebugCF("line", "Failed to send loading indicator", map[string]any{ "error": err.Error(), }) } } // callAPI makes an authenticated POST request to the LINE API. -func (c *LINEChannel) callAPI(ctx context.Context, endpoint string, payload interface{}) error { +func (c *LINEChannel) callAPI(ctx context.Context, endpoint string, payload any) error { body, err := json.Marshal(payload) if err != nil { return fmt.Errorf("failed to marshal payload: %w", err) diff --git a/pkg/channels/maixcam/maixcam.go b/pkg/channels/maixcam/maixcam.go index d3c6662d7..a7bff55e0 100644 --- a/pkg/channels/maixcam/maixcam.go +++ b/pkg/channels/maixcam/maixcam.go @@ -22,10 +22,10 @@ type MaixCamChannel struct { } type MaixCamMessage struct { - Type string `json:"type"` - Tips string `json:"tips"` - Timestamp float64 `json:"timestamp"` - Data map[string]interface{} `json:"data"` + Type string `json:"type"` + Tips string `json:"tips"` + Timestamp float64 `json:"timestamp"` + Data map[string]any `json:"data"` } func NewMaixCamChannel(cfg config.MaixCamConfig, bus *bus.MessageBus) (*MaixCamChannel, error) { @@ -50,7 +50,7 @@ func (c *MaixCamChannel) Start(ctx context.Context) error { c.listener = listener c.SetRunning(true) - logger.InfoCF("maixcam", "MaixCam server 
listening", map[string]interface{}{ + logger.InfoCF("maixcam", "MaixCam server listening", map[string]any{ "host": c.config.Host, "port": c.config.Port, }) @@ -72,14 +72,14 @@ func (c *MaixCamChannel) acceptConnections(ctx context.Context) { conn, err := c.listener.Accept() if err != nil { if c.IsRunning() { - logger.ErrorCF("maixcam", "Failed to accept connection", map[string]interface{}{ + logger.ErrorCF("maixcam", "Failed to accept connection", map[string]any{ "error": err.Error(), }) } return } - logger.InfoCF("maixcam", "New connection from MaixCam device", map[string]interface{}{ + logger.InfoCF("maixcam", "New connection from MaixCam device", map[string]any{ "remote_addr": conn.RemoteAddr().String(), }) @@ -113,7 +113,7 @@ func (c *MaixCamChannel) handleConnection(conn net.Conn, ctx context.Context) { var msg MaixCamMessage if err := decoder.Decode(&msg); err != nil { if err.Error() != "EOF" { - logger.ErrorCF("maixcam", "Failed to decode message", map[string]interface{}{ + logger.ErrorCF("maixcam", "Failed to decode message", map[string]any{ "error": err.Error(), }) } @@ -134,14 +134,14 @@ func (c *MaixCamChannel) processMessage(msg MaixCamMessage, conn net.Conn) { case "status": c.handleStatusUpdate(msg) default: - logger.WarnCF("maixcam", "Unknown message type", map[string]interface{}{ + logger.WarnCF("maixcam", "Unknown message type", map[string]any{ "type": msg.Type, }) } } func (c *MaixCamChannel) handlePersonDetection(msg MaixCamMessage) { - logger.InfoCF("maixcam", "", map[string]interface{}{ + logger.InfoCF("maixcam", "", map[string]any{ "timestamp": msg.Timestamp, "data": msg.Data, }) @@ -179,7 +179,7 @@ func (c *MaixCamChannel) handlePersonDetection(msg MaixCamMessage) { } func (c *MaixCamChannel) handleStatusUpdate(msg MaixCamMessage) { - logger.InfoCF("maixcam", "Status update from MaixCam", map[string]interface{}{ + logger.InfoCF("maixcam", "Status update from MaixCam", map[string]any{ "status": msg.Data, }) } @@ -217,7 +217,7 @@ func (c 
*MaixCamChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro return fmt.Errorf("no connected MaixCam devices") } - response := map[string]interface{}{ + response := map[string]any{ "type": "command", "timestamp": float64(0), "message": msg.Content, @@ -232,7 +232,7 @@ func (c *MaixCamChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro var sendErr error for conn := range c.clients { if _, err := conn.Write(data); err != nil { - logger.ErrorCF("maixcam", "Failed to send to client", map[string]interface{}{ + logger.ErrorCF("maixcam", "Failed to send to client", map[string]any{ "client": conn.RemoteAddr().String(), "error": err.Error(), }) diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go index 091982282..7baef058c 100644 --- a/pkg/channels/manager.go +++ b/pkg/channels/manager.go @@ -47,23 +47,23 @@ func NewManager(cfg *config.Config, messageBus *bus.MessageBus) (*Manager, error func (m *Manager) initChannel(name, displayName string) { f, ok := getFactory(name) if !ok { - logger.WarnCF("channels", "Factory not registered", map[string]interface{}{ + logger.WarnCF("channels", "Factory not registered", map[string]any{ "channel": displayName, }) return } - logger.DebugCF("channels", "Attempting to initialize channel", map[string]interface{}{ + logger.DebugCF("channels", "Attempting to initialize channel", map[string]any{ "channel": displayName, }) ch, err := f(m.config, m.bus) if err != nil { - logger.ErrorCF("channels", "Failed to initialize channel", map[string]interface{}{ + logger.ErrorCF("channels", "Failed to initialize channel", map[string]any{ "channel": displayName, "error": err.Error(), }) } else { m.channels[name] = ch - logger.InfoCF("channels", "Channel enabled successfully", map[string]interface{}{ + logger.InfoCF("channels", "Channel enabled successfully", map[string]any{ "channel": displayName, }) } @@ -120,7 +120,7 @@ func (m *Manager) initChannels() error { m.initChannel("wecom_app", "WeCom App") } - 
logger.InfoCF("channels", "Channel initialization completed", map[string]interface{}{ + logger.InfoCF("channels", "Channel initialization completed", map[string]any{ "enabled_channels": len(m.channels), }) @@ -144,11 +144,11 @@ func (m *Manager) StartAll(ctx context.Context) error { go m.dispatchOutbound(dispatchCtx) for name, channel := range m.channels { - logger.InfoCF("channels", "Starting channel", map[string]interface{}{ + logger.InfoCF("channels", "Starting channel", map[string]any{ "channel": name, }) if err := channel.Start(ctx); err != nil { - logger.ErrorCF("channels", "Failed to start channel", map[string]interface{}{ + logger.ErrorCF("channels", "Failed to start channel", map[string]any{ "channel": name, "error": err.Error(), }) @@ -171,11 +171,11 @@ func (m *Manager) StopAll(ctx context.Context) error { } for name, channel := range m.channels { - logger.InfoCF("channels", "Stopping channel", map[string]interface{}{ + logger.InfoCF("channels", "Stopping channel", map[string]any{ "channel": name, }) if err := channel.Stop(ctx); err != nil { - logger.ErrorCF("channels", "Error stopping channel", map[string]interface{}{ + logger.ErrorCF("channels", "Error stopping channel", map[string]any{ "channel": name, "error": err.Error(), }) @@ -210,14 +210,14 @@ func (m *Manager) dispatchOutbound(ctx context.Context) { m.mu.RUnlock() if !exists { - logger.WarnCF("channels", "Unknown channel for outbound message", map[string]interface{}{ + logger.WarnCF("channels", "Unknown channel for outbound message", map[string]any{ "channel": msg.Channel, }) continue } if err := channel.Send(ctx, msg); err != nil { - logger.ErrorCF("channels", "Error sending message to channel", map[string]interface{}{ + logger.ErrorCF("channels", "Error sending message to channel", map[string]any{ "channel": msg.Channel, "error": err.Error(), }) @@ -233,13 +233,13 @@ func (m *Manager) GetChannel(name string) (Channel, bool) { return channel, ok } -func (m *Manager) GetStatus() 
map[string]interface{} { +func (m *Manager) GetStatus() map[string]any { m.mu.RLock() defer m.mu.RUnlock() - status := make(map[string]interface{}) + status := make(map[string]any) for name, channel := range m.channels { - status[name] = map[string]interface{}{ + status[name] = map[string]any{ "enabled": true, "running": channel.IsRunning(), } diff --git a/pkg/channels/onebot/onebot.go b/pkg/channels/onebot/onebot.go index 209f2dc00..3d2e64e2a 100644 --- a/pkg/channels/onebot/onebot.go +++ b/pkg/channels/onebot/onebot.go @@ -88,14 +88,14 @@ type oneBotSender struct { } type oneBotAPIRequest struct { - Action string `json:"action"` - Params interface{} `json:"params"` - Echo string `json:"echo,omitempty"` + Action string `json:"action"` + Params any `json:"params"` + Echo string `json:"echo,omitempty"` } type oneBotMessageSegment struct { - Type string `json:"type"` - Data map[string]interface{} `json:"data"` + Type string `json:"type"` + Data map[string]any `json:"data"` } func NewOneBotChannel(cfg config.OneBotConfig, messageBus *bus.MessageBus) (*OneBotChannel, error) { @@ -118,13 +118,13 @@ func (c *OneBotChannel) SetTranscriber(transcriber *voice.GroqTranscriber) { func (c *OneBotChannel) setMsgEmojiLike(messageID string, emojiID int, set bool) { go func() { - _, err := c.sendAPIRequest("set_msg_emoji_like", map[string]interface{}{ + _, err := c.sendAPIRequest("set_msg_emoji_like", map[string]any{ "message_id": messageID, "emoji_id": emojiID, "set": set, }, 5*time.Second) if err != nil { - logger.DebugCF("onebot", "Failed to set emoji like", map[string]interface{}{ + logger.DebugCF("onebot", "Failed to set emoji like", map[string]any{ "message_id": messageID, "error": err.Error(), }) @@ -137,14 +137,14 @@ func (c *OneBotChannel) Start(ctx context.Context) error { return fmt.Errorf("OneBot ws_url not configured") } - logger.InfoCF("onebot", "Starting OneBot channel", map[string]interface{}{ + logger.InfoCF("onebot", "Starting OneBot channel", map[string]any{ 
"ws_url": c.config.WSUrl, }) c.ctx, c.cancel = context.WithCancel(ctx) if err := c.connect(); err != nil { - logger.WarnCF("onebot", "Initial connection failed, will retry in background", map[string]interface{}{ + logger.WarnCF("onebot", "Initial connection failed, will retry in background", map[string]any{ "error": err.Error(), }) } else { @@ -209,7 +209,7 @@ func (c *OneBotChannel) pinger(conn *websocket.Conn) { err := conn.WriteMessage(websocket.PingMessage, nil) c.writeMu.Unlock() if err != nil { - logger.DebugCF("onebot", "Ping write failed, stopping pinger", map[string]interface{}{ + logger.DebugCF("onebot", "Ping write failed, stopping pinger", map[string]any{ "error": err.Error(), }) return @@ -221,7 +221,7 @@ func (c *OneBotChannel) pinger(conn *websocket.Conn) { func (c *OneBotChannel) fetchSelfID() { resp, err := c.sendAPIRequest("get_login_info", nil, 5*time.Second) if err != nil { - logger.WarnCF("onebot", "Failed to get_login_info", map[string]interface{}{ + logger.WarnCF("onebot", "Failed to get_login_info", map[string]any{ "error": err.Error(), }) return @@ -251,7 +251,7 @@ func (c *OneBotChannel) fetchSelfID() { } if uid, err := parseJSONInt64(info.UserID); err == nil && uid > 0 { atomic.StoreInt64(&c.selfID, uid) - logger.InfoCF("onebot", "Bot self ID retrieved", map[string]interface{}{ + logger.InfoCF("onebot", "Bot self ID retrieved", map[string]any{ "self_id": uid, "nickname": info.Nickname, }) @@ -259,12 +259,12 @@ func (c *OneBotChannel) fetchSelfID() { } } - logger.WarnCF("onebot", "Could not parse self ID from get_login_info response", map[string]interface{}{ + logger.WarnCF("onebot", "Could not parse self ID from get_login_info response", map[string]any{ "response": string(resp), }) } -func (c *OneBotChannel) sendAPIRequest(action string, params interface{}, timeout time.Duration) (json.RawMessage, error) { +func (c *OneBotChannel) sendAPIRequest(action string, params any, timeout time.Duration) (json.RawMessage, error) { c.mu.Lock() conn 
:= c.conn c.mu.Unlock() @@ -333,7 +333,7 @@ func (c *OneBotChannel) reconnectLoop() { if conn == nil { logger.InfoC("onebot", "Attempting to reconnect...") if err := c.connect(); err != nil { - logger.ErrorCF("onebot", "Reconnect failed", map[string]interface{}{ + logger.ErrorCF("onebot", "Reconnect failed", map[string]any{ "error": err.Error(), }) } else { @@ -406,7 +406,7 @@ func (c *OneBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error c.writeMu.Unlock() if err != nil { - logger.ErrorCF("onebot", "Failed to send message", map[string]interface{}{ + logger.ErrorCF("onebot", "Failed to send message", map[string]any{ "error": err.Error(), }) return err @@ -428,20 +428,20 @@ func (c *OneBotChannel) buildMessageSegments(chatID, content string) []oneBotMes if msgID, ok := lastMsgID.(string); ok && msgID != "" { segments = append(segments, oneBotMessageSegment{ Type: "reply", - Data: map[string]interface{}{"id": msgID}, + Data: map[string]any{"id": msgID}, }) } } segments = append(segments, oneBotMessageSegment{ Type: "text", - Data: map[string]interface{}{"text": content}, + Data: map[string]any{"text": content}, }) return segments } -func (c *OneBotChannel) buildSendRequest(msg bus.OutboundMessage) (string, interface{}, error) { +func (c *OneBotChannel) buildSendRequest(msg bus.OutboundMessage) (string, any, error) { chatID := msg.ChatID segments := c.buildMessageSegments(chatID, msg.Content) @@ -459,7 +459,7 @@ func (c *OneBotChannel) buildSendRequest(msg bus.OutboundMessage) (string, inter if err != nil { return "", nil, fmt.Errorf("invalid %s in chatID: %s", idKey, chatID) } - return action, map[string]interface{}{idKey: id, "message": segments}, nil + return action, map[string]any{idKey: id, "message": segments}, nil } func (c *OneBotChannel) listen() { @@ -479,7 +479,7 @@ func (c *OneBotChannel) listen() { default: _, message, err := conn.ReadMessage() if err != nil { - logger.ErrorCF("onebot", "WebSocket read error", map[string]interface{}{ + 
logger.ErrorCF("onebot", "WebSocket read error", map[string]any{ "error": err.Error(), }) c.mu.Lock() @@ -495,14 +495,14 @@ func (c *OneBotChannel) listen() { var raw oneBotRawEvent if err := json.Unmarshal(message, &raw); err != nil { - logger.WarnCF("onebot", "Failed to unmarshal raw event", map[string]interface{}{ + logger.WarnCF("onebot", "Failed to unmarshal raw event", map[string]any{ "error": err.Error(), "payload": string(message), }) continue } - logger.DebugCF("onebot", "WebSocket event", map[string]interface{}{ + logger.DebugCF("onebot", "WebSocket event", map[string]any{ "length": len(message), "post_type": raw.PostType, "sub_type": raw.SubType, @@ -519,7 +519,7 @@ func (c *OneBotChannel) listen() { default: } } else { - logger.DebugCF("onebot", "Received API response (no waiter)", map[string]interface{}{ + logger.DebugCF("onebot", "Received API response (no waiter)", map[string]any{ "echo": raw.Echo, "status": string(raw.Status), }) @@ -528,7 +528,7 @@ func (c *OneBotChannel) listen() { } if isAPIResponse(raw.Status) { - logger.DebugCF("onebot", "Received API response without echo, skipping", map[string]interface{}{ + logger.DebugCF("onebot", "Received API response without echo, skipping", map[string]any{ "status": string(raw.Status), }) continue @@ -595,7 +595,7 @@ func (c *OneBotChannel) parseMessageSegments(raw json.RawMessage, selfID int64) return parseMessageResult{Text: s, IsBotMentioned: mentioned} } - var segments []map[string]interface{} + var segments []map[string]any if err := json.Unmarshal(raw, &segments); err != nil { return parseMessageResult{} } @@ -609,7 +609,7 @@ func (c *OneBotChannel) parseMessageSegments(raw json.RawMessage, selfID int64) for _, seg := range segments { segType, _ := seg["type"].(string) - data, _ := seg["data"].(map[string]interface{}) + data, _ := seg["data"].(map[string]any) switch segType { case "text": @@ -663,7 +663,7 @@ func (c *OneBotChannel) parseMessageSegments(raw json.RawMessage, selfID int64) result, 
err := c.transcriber.Transcribe(tctx, localPath) tcancel() if err != nil { - logger.WarnCF("onebot", "Voice transcription failed", map[string]interface{}{ + logger.WarnCF("onebot", "Voice transcription failed", map[string]any{ "error": err.Error(), }) textParts = append(textParts, "[voice (transcription failed)]") @@ -714,7 +714,7 @@ func (c *OneBotChannel) handleRawEvent(raw *oneBotRawEvent) { case "message": if userID, err := parseJSONInt64(raw.UserID); err == nil && userID > 0 { if !c.IsAllowed(strconv.FormatInt(userID, 10)) { - logger.DebugCF("onebot", "Message rejected by allowlist", map[string]interface{}{ + logger.DebugCF("onebot", "Message rejected by allowlist", map[string]any{ "user_id": userID, }) return @@ -723,7 +723,7 @@ func (c *OneBotChannel) handleRawEvent(raw *oneBotRawEvent) { c.handleMessage(raw) case "message_sent": - logger.DebugCF("onebot", "Bot sent message event", map[string]interface{}{ + logger.DebugCF("onebot", "Bot sent message event", map[string]any{ "message_type": raw.MessageType, "message_id": parseJSONString(raw.MessageID), }) @@ -735,18 +735,18 @@ func (c *OneBotChannel) handleRawEvent(raw *oneBotRawEvent) { c.handleNoticeEvent(raw) case "request": - logger.DebugCF("onebot", "Request event received", map[string]interface{}{ + logger.DebugCF("onebot", "Request event received", map[string]any{ "sub_type": raw.SubType, }) case "": - logger.DebugCF("onebot", "Event with empty post_type (possibly API response)", map[string]interface{}{ + logger.DebugCF("onebot", "Event with empty post_type (possibly API response)", map[string]any{ "echo": raw.Echo, "status": raw.Status, }) default: - logger.DebugCF("onebot", "Unknown post_type", map[string]interface{}{ + logger.DebugCF("onebot", "Unknown post_type", map[string]any{ "post_type": raw.PostType, }) } @@ -754,14 +754,14 @@ func (c *OneBotChannel) handleRawEvent(raw *oneBotRawEvent) { func (c *OneBotChannel) handleMetaEvent(raw *oneBotRawEvent) { if raw.MetaEventType == "lifecycle" { - 
logger.InfoCF("onebot", "Lifecycle event", map[string]interface{}{"sub_type": raw.SubType}) + logger.InfoCF("onebot", "Lifecycle event", map[string]any{"sub_type": raw.SubType}) } else if raw.MetaEventType != "heartbeat" { logger.DebugCF("onebot", "Meta event: "+raw.MetaEventType, nil) } } func (c *OneBotChannel) handleNoticeEvent(raw *oneBotRawEvent) { - fields := map[string]interface{}{ + fields := map[string]any{ "notice_type": raw.NoticeType, "sub_type": raw.SubType, "group_id": parseJSONString(raw.GroupID), @@ -781,7 +781,7 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) { // Parse fields from raw event userID, err := parseJSONInt64(raw.UserID) if err != nil { - logger.WarnCF("onebot", "Failed to parse user_id", map[string]interface{}{ + logger.WarnCF("onebot", "Failed to parse user_id", map[string]any{ "error": err.Error(), "raw": string(raw.UserID), }) @@ -818,7 +818,7 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) { var sender oneBotSender if len(raw.Sender) > 0 { if err := json.Unmarshal(raw.Sender, &sender); err != nil { - logger.WarnCF("onebot", "Failed to parse sender", map[string]interface{}{ + logger.WarnCF("onebot", "Failed to parse sender", map[string]any{ "error": err.Error(), "sender": string(raw.Sender), }) @@ -830,7 +830,7 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) { defer func() { for _, f := range parsed.LocalFiles { if err := os.Remove(f); err != nil { - logger.DebugCF("onebot", "Failed to remove temp file", map[string]interface{}{ + logger.DebugCF("onebot", "Failed to remove temp file", map[string]any{ "path": f, "error": err.Error(), }) @@ -840,14 +840,14 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) { } if c.isDuplicate(messageID) { - logger.DebugCF("onebot", "Duplicate message, skipping", map[string]interface{}{ + logger.DebugCF("onebot", "Duplicate message, skipping", map[string]any{ "message_id": messageID, }) return } if content == "" { - logger.DebugCF("onebot", "Received 
empty message, ignoring", map[string]interface{}{ + logger.DebugCF("onebot", "Received empty message, ignoring", map[string]any{ "message_id": messageID, }) return @@ -890,7 +890,7 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) { triggered, strippedContent := c.checkGroupTrigger(content, isBotMentioned) if !triggered { - logger.DebugCF("onebot", "Group message ignored (no trigger)", map[string]interface{}{ + logger.DebugCF("onebot", "Group message ignored (no trigger)", map[string]any{ "sender": senderID, "group": groupIDStr, "is_mentioned": isBotMentioned, @@ -901,7 +901,7 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) { content = strippedContent default: - logger.WarnCF("onebot", "Unknown message type, cannot route", map[string]interface{}{ + logger.WarnCF("onebot", "Unknown message type, cannot route", map[string]any{ "type": raw.MessageType, "message_id": messageID, "user_id": userID, @@ -909,7 +909,7 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) { return } - logger.InfoCF("onebot", "Received "+raw.MessageType+" message", map[string]interface{}{ + logger.InfoCF("onebot", "Received "+raw.MessageType+" message", map[string]any{ "sender": senderID, "chat_id": chatID, "message_id": messageID, @@ -962,7 +962,10 @@ func truncate(s string, n int) string { return string(runes[:n]) + "..." 
} -func (c *OneBotChannel) checkGroupTrigger(content string, isBotMentioned bool) (triggered bool, strippedContent string) { +func (c *OneBotChannel) checkGroupTrigger( + content string, + isBotMentioned bool, +) (triggered bool, strippedContent string) { if isBotMentioned { return true, strings.TrimSpace(content) } diff --git a/pkg/channels/qq/qq.go b/pkg/channels/qq/qq.go index 9b07be0cc..2a95bbd06 100644 --- a/pkg/channels/qq/qq.go +++ b/pkg/channels/qq/qq.go @@ -78,7 +78,7 @@ func (c *QQChannel) Start(ctx context.Context) error { return fmt.Errorf("failed to get websocket info: %w", err) } - logger.InfoCF("qq", "Got WebSocket info", map[string]interface{}{ + logger.InfoCF("qq", "Got WebSocket info", map[string]any{ "shards": wsInfo.Shards, }) @@ -88,7 +88,7 @@ func (c *QQChannel) Start(ctx context.Context) error { // 在 goroutine 中启动 WebSocket 连接,避免阻塞 go func() { if err := c.sessionManager.Start(wsInfo, c.tokenSource, &intent); err != nil { - logger.ErrorCF("qq", "WebSocket session error", map[string]interface{}{ + logger.ErrorCF("qq", "WebSocket session error", map[string]any{ "error": err.Error(), }) c.SetRunning(false) @@ -125,7 +125,7 @@ func (c *QQChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { // C2C 消息发送 _, err := c.api.PostC2CMessage(ctx, msg.ChatID, msgToCreate) if err != nil { - logger.ErrorCF("qq", "Failed to send C2C message", map[string]interface{}{ + logger.ErrorCF("qq", "Failed to send C2C message", map[string]any{ "error": err.Error(), }) return err @@ -158,7 +158,7 @@ func (c *QQChannel) handleC2CMessage() event.C2CMessageEventHandler { return nil } - logger.InfoCF("qq", "Received C2C message", map[string]interface{}{ + logger.InfoCF("qq", "Received C2C message", map[string]any{ "sender": senderID, "length": len(content), }) @@ -200,7 +200,7 @@ func (c *QQChannel) handleGroupATMessage() event.GroupATMessageEventHandler { return nil } - logger.InfoCF("qq", "Received group AT message", map[string]interface{}{ + 
logger.InfoCF("qq", "Received group AT message", map[string]any{ "sender": senderID, "group": data.GroupID, "length": len(content), diff --git a/pkg/channels/slack/slack.go b/pkg/channels/slack/slack.go index dc5190fc9..cafe53103 100644 --- a/pkg/channels/slack/slack.go +++ b/pkg/channels/slack/slack.go @@ -76,7 +76,7 @@ func (c *SlackChannel) Start(ctx context.Context) error { c.botUserID = authResp.UserID c.teamID = authResp.TeamID - logger.InfoCF("slack", "Slack bot connected", map[string]interface{}{ + logger.InfoCF("slack", "Slack bot connected", map[string]any{ "bot_user_id": c.botUserID, "team": authResp.Team, }) @@ -86,7 +86,7 @@ func (c *SlackChannel) Start(ctx context.Context) error { go func() { if err := c.socketClient.RunContext(c.ctx); err != nil { if c.ctx.Err() == nil { - logger.ErrorCF("slack", "Socket Mode connection error", map[string]interface{}{ + logger.ErrorCF("slack", "Socket Mode connection error", map[string]any{ "error": err.Error(), }) } @@ -141,7 +141,7 @@ func (c *SlackChannel) Send(ctx context.Context, msg bus.OutboundMessage) error }) } - logger.DebugCF("slack", "Message sent", map[string]interface{}{ + logger.DebugCF("slack", "Message sent", map[string]any{ "channel_id": channelID, "thread_ts": threadTS, }) @@ -203,7 +203,7 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { // 检查白名单,避免为被拒绝的用户下载附件 if !c.IsAllowed(ev.User) { - logger.DebugCF("slack", "Message rejected by allowlist", map[string]interface{}{ + logger.DebugCF("slack", "Message rejected by allowlist", map[string]any{ "user_id": ev.User, }) return @@ -239,7 +239,7 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { defer func() { for _, file := range localFiles { if err := os.Remove(file); err != nil { - logger.DebugCF("slack", "Failed to cleanup temp file", map[string]interface{}{ + logger.DebugCF("slack", "Failed to cleanup temp file", map[string]any{ "file": file, "error": err.Error(), }) @@ -262,7 +262,7 @@ func (c 
*SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { result, err := c.transcriber.Transcribe(ctx, localPath) if err != nil { - logger.ErrorCF("slack", "Voice transcription failed", map[string]interface{}{"error": err.Error()}) + logger.ErrorCF("slack", "Voice transcription failed", map[string]any{"error": err.Error()}) content += fmt.Sprintf("\n[audio: %s (transcription failed)]", file.Name) } else { content += fmt.Sprintf("\n[voice transcription: %s]", result.Text) @@ -294,7 +294,7 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { "team_id": c.teamID, } - logger.DebugCF("slack", "Received message", map[string]interface{}{ + logger.DebugCF("slack", "Received message", map[string]any{ "sender_id": senderID, "chat_id": chatID, "preview": utils.Truncate(content, 50), @@ -310,7 +310,7 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) { } if !c.IsAllowed(ev.User) { - logger.DebugCF("slack", "Mention rejected by allowlist", map[string]interface{}{ + logger.DebugCF("slack", "Mention rejected by allowlist", map[string]any{ "user_id": ev.User, }) return @@ -376,7 +376,7 @@ func (c *SlackChannel) handleSlashCommand(event socketmode.Event) { } if !c.IsAllowed(cmd.UserID) { - logger.DebugCF("slack", "Slash command rejected by allowlist", map[string]interface{}{ + logger.DebugCF("slack", "Slash command rejected by allowlist", map[string]any{ "user_id": cmd.UserID, }) return @@ -401,7 +401,7 @@ func (c *SlackChannel) handleSlashCommand(event socketmode.Event) { "team_id": c.teamID, } - logger.DebugCF("slack", "Slash command received", map[string]interface{}{ + logger.DebugCF("slack", "Slash command received", map[string]any{ "sender_id": senderID, "command": cmd.Command, "text": utils.Truncate(content, 50), @@ -416,7 +416,7 @@ func (c *SlackChannel) downloadSlackFile(file slack.File) string { downloadURL = file.URLPrivate } if downloadURL == "" { - logger.ErrorCF("slack", "No download URL for file", 
map[string]interface{}{"file_id": file.ID}) + logger.ErrorCF("slack", "No download URL for file", map[string]any{"file_id": file.ID}) return "" } diff --git a/pkg/channels/telegram/telegram.go b/pkg/channels/telegram/telegram.go index f4c5108df..7619440e2 100644 --- a/pkg/channels/telegram/telegram.go +++ b/pkg/channels/telegram/telegram.go @@ -11,10 +11,9 @@ import ( "sync" "time" - th "github.com/mymmrac/telego/telegohandler" - "github.com/mymmrac/telego" "github.com/mymmrac/telego/telegohandler" + th "github.com/mymmrac/telego/telegohandler" tu "github.com/mymmrac/telego/telegoutil" "github.com/sipeed/picoclaw/pkg/bus" @@ -128,7 +127,7 @@ func (c *TelegramChannel) Start(ctx context.Context) error { }, th.AnyMessage()) c.SetRunning(true) - logger.InfoCF("telegram", "Telegram bot connected", map[string]interface{}{ + logger.InfoCF("telegram", "Telegram bot connected", map[string]any{ "username": c.bot.Username(), }) @@ -141,6 +140,7 @@ func (c *TelegramChannel) Start(ctx context.Context) error { return nil } + func (c *TelegramChannel) Stop(ctx context.Context) error { logger.InfoC("telegram", "Stopping Telegram bot...") c.SetRunning(false) @@ -183,7 +183,7 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err tgMsg.ParseMode = telego.ModeHTML if _, err = c.bot.SendMessage(ctx, tgMsg); err != nil { - logger.ErrorCF("telegram", "HTML parse failed, falling back to plain text", map[string]interface{}{ + logger.ErrorCF("telegram", "HTML parse failed, falling back to plain text", map[string]any{ "error": err.Error(), }) tgMsg.ParseMode = "" @@ -211,7 +211,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes // 检查白名单,避免为被拒绝的用户下载附件 if !c.IsAllowed(senderID) { - logger.DebugCF("telegram", "Message rejected by allowlist", map[string]interface{}{ + logger.DebugCF("telegram", "Message rejected by allowlist", map[string]any{ "user_id": senderID, }) return nil @@ -228,7 +228,7 @@ func (c *TelegramChannel) 
handleMessage(ctx context.Context, message *telego.Mes defer func() { for _, file := range localFiles { if err := os.Remove(file); err != nil { - logger.DebugCF("telegram", "Failed to cleanup temp file", map[string]interface{}{ + logger.DebugCF("telegram", "Failed to cleanup temp file", map[string]any{ "file": file, "error": err.Error(), }) @@ -268,19 +268,19 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes transcribedText := "" if c.transcriber != nil && c.transcriber.IsAvailable() { - ctx, cancel := context.WithTimeout(ctx, 30*time.Second) + transcriberCtx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() - result, err := c.transcriber.Transcribe(ctx, voicePath) + result, err := c.transcriber.Transcribe(transcriberCtx, voicePath) if err != nil { - logger.ErrorCF("telegram", "Voice transcription failed", map[string]interface{}{ + logger.ErrorCF("telegram", "Voice transcription failed", map[string]any{ "error": err.Error(), "path": voicePath, }) transcribedText = "[voice (transcription failed)]" } else { transcribedText = fmt.Sprintf("[voice transcription: %s]", result.Text) - logger.InfoCF("telegram", "Voice transcribed successfully", map[string]interface{}{ + logger.InfoCF("telegram", "Voice transcribed successfully", map[string]any{ "text": result.Text, }) } @@ -323,7 +323,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes content = "[empty message]" } - logger.DebugCF("telegram", "Received message", map[string]interface{}{ + logger.DebugCF("telegram", "Received message", map[string]any{ "sender_id": senderID, "chat_id": fmt.Sprintf("%d", chatID), "preview": utils.Truncate(content, 50), @@ -332,7 +332,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes // Thinking indicator err := c.bot.SendChatAction(ctx, tu.ChatAction(tu.ID(chatID), telego.ChatActionTyping)) if err != nil { - logger.ErrorCF("telegram", "Failed to send chat action", 
map[string]interface{}{ + logger.ErrorCF("telegram", "Failed to send chat action", map[string]any{ "error": err.Error(), }) } @@ -379,7 +379,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes func (c *TelegramChannel) downloadPhoto(ctx context.Context, fileID string) string { file, err := c.bot.GetFile(ctx, &telego.GetFileParams{FileID: fileID}) if err != nil { - logger.ErrorCF("telegram", "Failed to get photo file", map[string]interface{}{ + logger.ErrorCF("telegram", "Failed to get photo file", map[string]any{ "error": err.Error(), }) return "" @@ -394,7 +394,7 @@ func (c *TelegramChannel) downloadFileWithInfo(file *telego.File, ext string) st } url := c.bot.FileDownloadURL(file.FilePath) - logger.DebugCF("telegram", "File URL", map[string]interface{}{"url": url}) + logger.DebugCF("telegram", "File URL", map[string]any{"url": url}) // Use FilePath as filename for better identification filename := file.FilePath + ext @@ -406,7 +406,7 @@ func (c *TelegramChannel) downloadFileWithInfo(file *telego.File, ext string) st func (c *TelegramChannel) downloadFile(ctx context.Context, fileID, ext string) string { file, err := c.bot.GetFile(ctx, &telego.GetFileParams{FileID: fileID}) if err != nil { - logger.ErrorCF("telegram", "Failed to get file", map[string]interface{}{ + logger.ErrorCF("telegram", "Failed to get file", map[string]any{ "error": err.Error(), }) return "" @@ -464,7 +464,11 @@ func markdownToTelegramHTML(text string) string { for i, code := range codeBlocks.codes { escaped := escapeHTML(code) - text = strings.ReplaceAll(text, fmt.Sprintf("\x00CB%d\x00", i), fmt.Sprintf("
<pre><code>%s</code></pre>", escaped)) + text = strings.ReplaceAll( + text, + fmt.Sprintf("\x00CB%d\x00", i), + fmt.Sprintf("<pre><code>%s</code></pre>
", escaped), + ) } return text diff --git a/pkg/channels/telegram/telegram_commands.go b/pkg/channels/telegram/telegram_commands.go index 4bf1b3aff..f17912260 100644 --- a/pkg/channels/telegram/telegram_commands.go +++ b/pkg/channels/telegram/telegram_commands.go @@ -6,6 +6,7 @@ import ( "strings" "github.com/mymmrac/telego" + "github.com/sipeed/picoclaw/pkg/config" ) @@ -35,6 +36,7 @@ func commandArgs(text string) string { } return strings.TrimSpace(parts[1]) } + func (c *cmd) Help(ctx context.Context, message telego.Message) error { msg := `/start - Start the bot /help - Show this help message @@ -96,6 +98,7 @@ func (c *cmd) Show(ctx context.Context, message telego.Message) error { }) return err } + func (c *cmd) List(ctx context.Context, message telego.Message) error { args := commandArgs(message.Text) if args == "" { diff --git a/pkg/channels/wecom/app.go b/pkg/channels/wecom/app.go index 85c017958..f3557d60f 100644 --- a/pkg/channels/wecom/app.go +++ b/pkg/channels/wecom/app.go @@ -142,7 +142,7 @@ func (c *WeComAppChannel) Start(ctx context.Context) error { // Get initial access token if err := c.refreshAccessToken(); err != nil { - logger.WarnCF("wecom_app", "Failed to get initial access token", map[string]interface{}{ + logger.WarnCF("wecom_app", "Failed to get initial access token", map[string]any{ "error": err.Error(), }) } @@ -168,7 +168,7 @@ func (c *WeComAppChannel) Start(ctx context.Context) error { } c.SetRunning(true) - logger.InfoCF("wecom_app", "WeCom App channel started", map[string]interface{}{ + logger.InfoCF("wecom_app", "WeCom App channel started", map[string]any{ "address": addr, "path": webhookPath, }) @@ -176,7 +176,7 @@ func (c *WeComAppChannel) Start(ctx context.Context) error { // Start server in goroutine go func() { if err := c.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { - logger.ErrorCF("wecom_app", "HTTP server error", map[string]interface{}{ + logger.ErrorCF("wecom_app", "HTTP server error", map[string]any{ 
"error": err.Error(), }) } @@ -215,7 +215,7 @@ func (c *WeComAppChannel) Send(ctx context.Context, msg bus.OutboundMessage) err return fmt.Errorf("no valid access token available") } - logger.DebugCF("wecom_app", "Sending message", map[string]interface{}{ + logger.DebugCF("wecom_app", "Sending message", map[string]any{ "chat_id": msg.ChatID, "preview": utils.Truncate(msg.Content, 100), }) @@ -228,7 +228,7 @@ func (c *WeComAppChannel) handleWebhook(w http.ResponseWriter, r *http.Request) ctx := r.Context() // Log all incoming requests for debugging - logger.DebugCF("wecom_app", "Received webhook request", map[string]interface{}{ + logger.DebugCF("wecom_app", "Received webhook request", map[string]any{ "method": r.Method, "url": r.URL.String(), "path": r.URL.Path, @@ -247,7 +247,7 @@ func (c *WeComAppChannel) handleWebhook(w http.ResponseWriter, r *http.Request) return } - logger.WarnCF("wecom_app", "Method not allowed", map[string]interface{}{ + logger.WarnCF("wecom_app", "Method not allowed", map[string]any{ "method": r.Method, }) http.Error(w, "Method not allowed", http.StatusMethodNotAllowed) @@ -261,7 +261,7 @@ func (c *WeComAppChannel) handleVerification(ctx context.Context, w http.Respons nonce := query.Get("nonce") echostr := query.Get("echostr") - logger.DebugCF("wecom_app", "Handling verification request", map[string]interface{}{ + logger.DebugCF("wecom_app", "Handling verification request", map[string]any{ "msg_signature": msgSignature, "timestamp": timestamp, "nonce": nonce, @@ -277,7 +277,7 @@ func (c *WeComAppChannel) handleVerification(ctx context.Context, w http.Respons // Verify signature if !verifySignature(c.config.Token, msgSignature, timestamp, nonce, echostr) { - logger.WarnCF("wecom_app", "Signature verification failed", map[string]interface{}{ + logger.WarnCF("wecom_app", "Signature verification failed", map[string]any{ "token": c.config.Token, "msg_signature": msgSignature, "timestamp": timestamp, @@ -291,13 +291,13 @@ func (c 
*WeComAppChannel) handleVerification(ctx context.Context, w http.Respons // Decrypt echostr with CorpID verification // For WeCom App (自建应用), receiveid should be corp_id - logger.DebugCF("wecom_app", "Attempting to decrypt echostr", map[string]interface{}{ + logger.DebugCF("wecom_app", "Attempting to decrypt echostr", map[string]any{ "encoding_aes_key": c.config.EncodingAESKey, "corp_id": c.config.CorpID, }) decryptedEchoStr, err := decryptMessageWithVerify(echostr, c.config.EncodingAESKey, c.config.CorpID) if err != nil { - logger.ErrorCF("wecom_app", "Failed to decrypt echostr", map[string]interface{}{ + logger.ErrorCF("wecom_app", "Failed to decrypt echostr", map[string]any{ "error": err.Error(), "encoding_aes_key": c.config.EncodingAESKey, "corp_id": c.config.CorpID, @@ -306,7 +306,7 @@ func (c *WeComAppChannel) handleVerification(ctx context.Context, w http.Respons return } - logger.DebugCF("wecom_app", "Successfully decrypted echostr", map[string]interface{}{ + logger.DebugCF("wecom_app", "Successfully decrypted echostr", map[string]any{ "decrypted": decryptedEchoStr, }) @@ -345,8 +345,8 @@ func (c *WeComAppChannel) handleMessageCallback(ctx context.Context, w http.Resp AgentID string `xml:"AgentID"` } - if err := xml.Unmarshal(body, &encryptedMsg); err != nil { - logger.ErrorCF("wecom_app", "Failed to parse XML", map[string]interface{}{ + if err = xml.Unmarshal(body, &encryptedMsg); err != nil { + logger.ErrorCF("wecom_app", "Failed to parse XML", map[string]any{ "error": err.Error(), }) http.Error(w, "Invalid XML", http.StatusBadRequest) @@ -364,7 +364,7 @@ func (c *WeComAppChannel) handleMessageCallback(ctx context.Context, w http.Resp // For WeCom App (自建应用), receiveid should be corp_id decryptedMsg, err := decryptMessageWithVerify(encryptedMsg.Encrypt, c.config.EncodingAESKey, c.config.CorpID) if err != nil { - logger.ErrorCF("wecom_app", "Failed to decrypt message", map[string]interface{}{ + logger.ErrorCF("wecom_app", "Failed to decrypt message", 
map[string]any{ "error": err.Error(), }) http.Error(w, "Decryption failed", http.StatusInternalServerError) @@ -374,7 +374,7 @@ func (c *WeComAppChannel) handleMessageCallback(ctx context.Context, w http.Resp // Parse decrypted XML message var msg WeComXMLMessage if err := xml.Unmarshal([]byte(decryptedMsg), &msg); err != nil { - logger.ErrorCF("wecom_app", "Failed to parse decrypted message", map[string]interface{}{ + logger.ErrorCF("wecom_app", "Failed to parse decrypted message", map[string]any{ "error": err.Error(), }) http.Error(w, "Invalid message format", http.StatusBadRequest) @@ -393,7 +393,7 @@ func (c *WeComAppChannel) handleMessageCallback(ctx context.Context, w http.Resp func (c *WeComAppChannel) processMessage(ctx context.Context, msg WeComXMLMessage) { // Skip non-text messages for now (can be extended) if msg.MsgType != "text" && msg.MsgType != "image" && msg.MsgType != "voice" { - logger.DebugCF("wecom_app", "Skipping non-supported message type", map[string]interface{}{ + logger.DebugCF("wecom_app", "Skipping non-supported message type", map[string]any{ "msg_type": msg.MsgType, }) return @@ -405,7 +405,7 @@ func (c *WeComAppChannel) processMessage(ctx context.Context, msg WeComXMLMessag c.msgMu.Lock() if c.processedMsgs[msgID] { c.msgMu.Unlock() - logger.DebugCF("wecom_app", "Skipping duplicate message", map[string]interface{}{ + logger.DebugCF("wecom_app", "Skipping duplicate message", map[string]any{ "msg_id": msgID, }) return @@ -438,7 +438,7 @@ func (c *WeComAppChannel) processMessage(ctx context.Context, msg WeComXMLMessag content := msg.Content - logger.DebugCF("wecom_app", "Received message", map[string]interface{}{ + logger.DebugCF("wecom_app", "Received message", map[string]any{ "sender_id": senderID, "msg_type": msg.MsgType, "preview": utils.Truncate(content, 50), @@ -459,7 +459,7 @@ func (c *WeComAppChannel) tokenRefreshLoop() { return case <-ticker.C: if err := c.refreshAccessToken(); err != nil { - logger.ErrorCF("wecom_app", "Failed 
to refresh access token", map[string]interface{}{ + logger.ErrorCF("wecom_app", "Failed to refresh access token", map[string]any{ "error": err.Error(), }) } @@ -625,7 +625,7 @@ func (c *WeComAppChannel) sendMarkdownMessage(ctx context.Context, accessToken, // handleHealth handles health check requests func (c *WeComAppChannel) handleHealth(w http.ResponseWriter, r *http.Request) { - status := map[string]interface{}{ + status := map[string]any{ "status": "ok", "running": c.IsRunning(), "has_token": c.getAccessToken() != "", diff --git a/pkg/channels/wecom/app_test.go b/pkg/channels/wecom/app_test.go index d9817fd49..5420949de 100644 --- a/pkg/channels/wecom/app_test.go +++ b/pkg/channels/wecom/app_test.go @@ -396,7 +396,11 @@ func TestWeComAppHandleVerification(t *testing.T) { nonce := "test_nonce" signature := generateSignatureApp("test_token", timestamp, nonce, encryptedEchostr) - req := httptest.NewRequest(http.MethodGet, "/webhook/wecom-app?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encryptedEchostr, nil) + req := httptest.NewRequest( + http.MethodGet, + "/webhook/wecom-app?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encryptedEchostr, + nil, + ) w := httptest.NewRecorder() ch.handleVerification(context.Background(), w, req) @@ -426,7 +430,11 @@ func TestWeComAppHandleVerification(t *testing.T) { timestamp := "1234567890" nonce := "test_nonce" - req := httptest.NewRequest(http.MethodGet, "/webhook/wecom-app?msg_signature=invalid_sig×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encryptedEchostr, nil) + req := httptest.NewRequest( + http.MethodGet, + "/webhook/wecom-app?msg_signature=invalid_sig×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encryptedEchostr, + nil, + ) w := httptest.NewRecorder() ch.handleVerification(context.Background(), w, req) @@ -478,7 +486,11 @@ func TestWeComAppHandleMessageCallback(t *testing.T) { nonce := "test_nonce" signature := generateSignatureApp("test_token", timestamp, nonce, 
encrypted) - req := httptest.NewRequest(http.MethodPost, "/webhook/wecom-app?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce, bytes.NewReader(wrapperData)) + req := httptest.NewRequest( + http.MethodPost, + "/webhook/wecom-app?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce, + bytes.NewReader(wrapperData), + ) w := httptest.NewRecorder() ch.handleMessageCallback(context.Background(), w, req) @@ -507,7 +519,11 @@ func TestWeComAppHandleMessageCallback(t *testing.T) { nonce := "test_nonce" signature := generateSignatureApp("test_token", timestamp, nonce, "") - req := httptest.NewRequest(http.MethodPost, "/webhook/wecom-app?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce, strings.NewReader("invalid xml")) + req := httptest.NewRequest( + http.MethodPost, + "/webhook/wecom-app?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce, + strings.NewReader("invalid xml"), + ) w := httptest.NewRecorder() ch.handleMessageCallback(context.Background(), w, req) @@ -529,7 +545,11 @@ func TestWeComAppHandleMessageCallback(t *testing.T) { timestamp := "1234567890" nonce := "test_nonce" - req := httptest.NewRequest(http.MethodPost, "/webhook/wecom-app?msg_signature=invalid_sig×tamp="+timestamp+"&nonce="+nonce, bytes.NewReader(wrapperData)) + req := httptest.NewRequest( + http.MethodPost, + "/webhook/wecom-app?msg_signature=invalid_sig×tamp="+timestamp+"&nonce="+nonce, + bytes.NewReader(wrapperData), + ) w := httptest.NewRecorder() ch.handleMessageCallback(context.Background(), w, req) @@ -643,7 +663,11 @@ func TestWeComAppHandleWebhook(t *testing.T) { nonce := "test_nonce" signature := generateSignatureApp("test_token", timestamp, nonce, encoded) - req := httptest.NewRequest(http.MethodGet, "/webhook/wecom-app?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encoded, nil) + req := httptest.NewRequest( + http.MethodGet, + 
"/webhook/wecom-app?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce+"&echostr="+encoded, + nil, + ) w := httptest.NewRecorder() ch.handleWebhook(w, req) @@ -666,7 +690,11 @@ func TestWeComAppHandleWebhook(t *testing.T) { nonce := "test_nonce" signature := generateSignatureApp("test_token", timestamp, nonce, encryptedWrapper.Encrypt) - req := httptest.NewRequest(http.MethodPost, "/webhook/wecom-app?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce, bytes.NewReader(wrapperData)) + req := httptest.NewRequest( + http.MethodPost, + "/webhook/wecom-app?msg_signature="+signature+"×tamp="+timestamp+"&nonce="+nonce, + bytes.NewReader(wrapperData), + ) w := httptest.NewRecorder() ch.handleWebhook(w, req) @@ -832,15 +860,24 @@ func TestWeComAppMessageStructures(t *testing.T) { if msg.Image.MediaID != "media_123456" { t.Errorf("Image.MediaID = %q, want %q", msg.Image.MediaID, "media_123456") } + if msg.ToUser != "user123" { + t.Errorf("ToUser = %q, want %q", msg.ToUser, "user123") + } + if msg.MsgType != "image" { + t.Errorf("MsgType = %q, want %q", msg.MsgType, "image") + } + if msg.AgentID != 1000002 { + t.Errorf("AgentID = %d, want %d", msg.AgentID, 1000002) + } }) t.Run("WeComAccessTokenResponse structure", func(t *testing.T) { jsonData := `{ - "errcode": 0, - "errmsg": "ok", - "access_token": "test_access_token", - "expires_in": 7200 - }` + "errcode": 0, + "errmsg": "ok", + "access_token": "test_access_token", + "expires_in": 7200 + }` var resp WeComAccessTokenResponse err := json.Unmarshal([]byte(jsonData), &resp) @@ -864,12 +901,12 @@ func TestWeComAppMessageStructures(t *testing.T) { t.Run("WeComSendMessageResponse structure", func(t *testing.T) { jsonData := `{ - "errcode": 0, - "errmsg": "ok", - "invaliduser": "", - "invalidparty": "", - "invalidtag": "" - }` + "errcode": 0, + "errmsg": "ok", + "invaliduser": "", + "invalidparty": "", + "invalidtag": "" + }` var resp WeComSendMessageResponse err := json.Unmarshal([]byte(jsonData), &resp) diff 
--git a/pkg/channels/wecom/bot.go b/pkg/channels/wecom/bot.go index 9683a308f..17ee2107f 100644 --- a/pkg/channels/wecom/bot.go +++ b/pkg/channels/wecom/bot.go @@ -125,7 +125,7 @@ func (c *WeComBotChannel) Start(ctx context.Context) error { } c.SetRunning(true) - logger.InfoCF("wecom", "WeCom Bot channel started", map[string]interface{}{ + logger.InfoCF("wecom", "WeCom Bot channel started", map[string]any{ "address": addr, "path": webhookPath, }) @@ -133,7 +133,7 @@ func (c *WeComBotChannel) Start(ctx context.Context) error { // Start server in goroutine go func() { if err := c.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { - logger.ErrorCF("wecom", "HTTP server error", map[string]interface{}{ + logger.ErrorCF("wecom", "HTTP server error", map[string]any{ "error": err.Error(), }) } @@ -169,7 +169,7 @@ func (c *WeComBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) err return fmt.Errorf("wecom channel not running") } - logger.DebugCF("wecom", "Sending message via webhook", map[string]interface{}{ + logger.DebugCF("wecom", "Sending message via webhook", map[string]any{ "chat_id": msg.ChatID, "preview": utils.Truncate(msg.Content, 100), }) @@ -221,7 +221,7 @@ func (c *WeComBotChannel) handleVerification(ctx context.Context, w http.Respons // Reference: https://developer.work.weixin.qq.com/document/path/101033 decryptedEchoStr, err := decryptMessageWithVerify(echostr, c.config.EncodingAESKey, "") if err != nil { - logger.ErrorCF("wecom", "Failed to decrypt echostr", map[string]interface{}{ + logger.ErrorCF("wecom", "Failed to decrypt echostr", map[string]any{ "error": err.Error(), }) http.Error(w, "Decryption failed", http.StatusInternalServerError) @@ -263,8 +263,8 @@ func (c *WeComBotChannel) handleMessageCallback(ctx context.Context, w http.Resp AgentID string `xml:"AgentID"` } - if err := xml.Unmarshal(body, &encryptedMsg); err != nil { - logger.ErrorCF("wecom", "Failed to parse XML", map[string]interface{}{ + if err = 
xml.Unmarshal(body, &encryptedMsg); err != nil { + logger.ErrorCF("wecom", "Failed to parse XML", map[string]any{ "error": err.Error(), }) http.Error(w, "Invalid XML", http.StatusBadRequest) @@ -283,7 +283,7 @@ func (c *WeComBotChannel) handleMessageCallback(ctx context.Context, w http.Resp // Reference: https://developer.work.weixin.qq.com/document/path/101033 decryptedMsg, err := decryptMessageWithVerify(encryptedMsg.Encrypt, c.config.EncodingAESKey, "") if err != nil { - logger.ErrorCF("wecom", "Failed to decrypt message", map[string]interface{}{ + logger.ErrorCF("wecom", "Failed to decrypt message", map[string]any{ "error": err.Error(), }) http.Error(w, "Decryption failed", http.StatusInternalServerError) @@ -293,7 +293,7 @@ func (c *WeComBotChannel) handleMessageCallback(ctx context.Context, w http.Resp // Parse decrypted JSON message (AIBOT uses JSON format) var msg WeComBotMessage if err := json.Unmarshal([]byte(decryptedMsg), &msg); err != nil { - logger.ErrorCF("wecom", "Failed to parse decrypted message", map[string]interface{}{ + logger.ErrorCF("wecom", "Failed to parse decrypted message", map[string]any{ "error": err.Error(), }) http.Error(w, "Invalid message format", http.StatusBadRequest) @@ -311,8 +311,9 @@ func (c *WeComBotChannel) handleMessageCallback(ctx context.Context, w http.Resp // processMessage processes the received message func (c *WeComBotChannel) processMessage(ctx context.Context, msg WeComBotMessage) { // Skip unsupported message types - if msg.MsgType != "text" && msg.MsgType != "image" && msg.MsgType != "voice" && msg.MsgType != "file" && msg.MsgType != "mixed" { - logger.DebugCF("wecom", "Skipping non-supported message type", map[string]interface{}{ + if msg.MsgType != "text" && msg.MsgType != "image" && msg.MsgType != "voice" && msg.MsgType != "file" && + msg.MsgType != "mixed" { + logger.DebugCF("wecom", "Skipping non-supported message type", map[string]any{ "msg_type": msg.MsgType, }) return @@ -323,7 +324,7 @@ func (c 
*WeComBotChannel) processMessage(ctx context.Context, msg WeComBotMessag c.msgMu.Lock() if c.processedMsgs[msgID] { c.msgMu.Unlock() - logger.DebugCF("wecom", "Skipping duplicate message", map[string]interface{}{ + logger.DebugCF("wecom", "Skipping duplicate message", map[string]any{ "msg_id": msgID, }) return @@ -390,7 +391,7 @@ func (c *WeComBotChannel) processMessage(ctx context.Context, msg WeComBotMessag metadata["sender_id"] = senderID } - logger.DebugCF("wecom", "Received message", map[string]interface{}{ + logger.DebugCF("wecom", "Received message", map[string]any{ "sender_id": senderID, "msg_type": msg.MsgType, "peer_kind": peerKind, @@ -459,7 +460,7 @@ func (c *WeComBotChannel) sendWebhookReply(ctx context.Context, userID, content // handleHealth handles health check requests func (c *WeComBotChannel) handleHealth(w http.ResponseWriter, r *http.Request) { - status := map[string]interface{}{ + status := map[string]any{ "status": "ok", "running": c.IsRunning(), } diff --git a/pkg/channels/wecom/bot_test.go b/pkg/channels/wecom/bot_test.go index 460e0058f..328b145c2 100644 --- a/pkg/channels/wecom/bot_test.go +++ b/pkg/channels/wecom/bot_test.go @@ -18,7 +18,6 @@ import ( "testing" "github.com/sipeed/picoclaw/pkg/bus" - "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" ) @@ -196,10 +195,8 @@ func TestWeComBotVerifySignature(t *testing.T) { Token: "", WebhookURL: "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", } - base := channels.NewBaseChannel("wecom", cfgEmpty, msgBus, cfgEmpty.AllowFrom) chEmpty := &WeComBotChannel{ - BaseChannel: base, - config: cfgEmpty, + config: cfgEmpty, } if !verifySignature(chEmpty.config.Token, "any_sig", "any_ts", "any_nonce", "any_msg") { @@ -356,7 +353,11 @@ func TestWeComBotHandleVerification(t *testing.T) { nonce := "test_nonce" signature := generateSignature("test_token", timestamp, nonce, encryptedEchostr) - req := httptest.NewRequest(http.MethodGet,
"/webhook/wecom?msg_signature="+signature+"&timestamp="+timestamp+"&nonce="+nonce+"&echostr="+encryptedEchostr, nil) + req := httptest.NewRequest( + http.MethodGet, + "/webhook/wecom?msg_signature="+signature+"&timestamp="+timestamp+"&nonce="+nonce+"&echostr="+encryptedEchostr, + nil, + ) w := httptest.NewRecorder() ch.handleVerification(context.Background(), w, req) @@ -386,7 +387,11 @@ func TestWeComBotHandleVerification(t *testing.T) { timestamp := "1234567890" nonce := "test_nonce" - req := httptest.NewRequest(http.MethodGet, "/webhook/wecom?msg_signature=invalid_sig&timestamp="+timestamp+"&nonce="+nonce+"&echostr="+encryptedEchostr, nil) + req := httptest.NewRequest( + http.MethodGet, + "/webhook/wecom?msg_signature=invalid_sig&timestamp="+timestamp+"&nonce="+nonce+"&echostr="+encryptedEchostr, + nil, + ) w := httptest.NewRecorder() ch.handleVerification(context.Background(), w, req) @@ -410,14 +415,14 @@ func TestWeComBotHandleMessageCallback(t *testing.T) { t.Run("valid direct message callback", func(t *testing.T) { // Create JSON message for direct chat (single) jsonMsg := `{ - "msgid": "test_msg_id_123", - "aibotid": "test_aibot_id", - "chattype": "single", - "from": {"userid": "user123"}, - "response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", - "msgtype": "text", - "text": {"content": "Hello World"} - }` + "msgid": "test_msg_id_123", + "aibotid": "test_aibot_id", + "chattype": "single", + "from": {"userid": "user123"}, + "response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + "msgtype": "text", + "text": {"content": "Hello World"} + }` // Encrypt message encrypted, _ := encryptTestMessage(jsonMsg, aesKey) @@ -435,7 +440,11 @@ func TestWeComBotHandleMessageCallback(t *testing.T) { nonce := "test_nonce" signature := generateSignature("test_token", timestamp, nonce, encrypted) - req := httptest.NewRequest(http.MethodPost, "/webhook/wecom?msg_signature="+signature+"&timestamp="+timestamp+"&nonce="+nonce, bytes.NewReader(wrapperData)) +
req := httptest.NewRequest( + http.MethodPost, + "/webhook/wecom?msg_signature="+signature+"&timestamp="+timestamp+"&nonce="+nonce, + bytes.NewReader(wrapperData), + ) w := httptest.NewRecorder() ch.handleMessageCallback(context.Background(), w, req) @@ -451,15 +460,15 @@ func TestWeComBotHandleMessageCallback(t *testing.T) { t.Run("valid group message callback", func(t *testing.T) { // Create JSON message for group chat jsonMsg := `{ - "msgid": "test_msg_id_456", - "aibotid": "test_aibot_id", - "chatid": "group_chat_id_123", - "chattype": "group", - "from": {"userid": "user456"}, - "response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", - "msgtype": "text", - "text": {"content": "Hello Group"} - }` + "msgid": "test_msg_id_456", + "aibotid": "test_aibot_id", + "chatid": "group_chat_id_123", + "chattype": "group", + "from": {"userid": "user456"}, + "response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + "msgtype": "text", + "text": {"content": "Hello Group"} + }` // Encrypt message encrypted, _ := encryptTestMessage(jsonMsg, aesKey) @@ -477,7 +486,11 @@ func TestWeComBotHandleMessageCallback(t *testing.T) { nonce := "test_nonce" signature := generateSignature("test_token", timestamp, nonce, encrypted) - req := httptest.NewRequest(http.MethodPost, "/webhook/wecom?msg_signature="+signature+"&timestamp="+timestamp+"&nonce="+nonce, bytes.NewReader(wrapperData)) + req := httptest.NewRequest( + http.MethodPost, + "/webhook/wecom?msg_signature="+signature+"&timestamp="+timestamp+"&nonce="+nonce, + bytes.NewReader(wrapperData), + ) w := httptest.NewRecorder() ch.handleMessageCallback(context.Background(), w, req) @@ -506,7 +519,11 @@ func TestWeComBotHandleMessageCallback(t *testing.T) { nonce := "test_nonce" signature := generateSignature("test_token", timestamp, nonce, "") - req := httptest.NewRequest(http.MethodPost, "/webhook/wecom?msg_signature="+signature+"&timestamp="+timestamp+"&nonce="+nonce, strings.NewReader("invalid xml")) + req :=
httptest.NewRequest( + http.MethodPost, + "/webhook/wecom?msg_signature="+signature+"&timestamp="+timestamp+"&nonce="+nonce, + strings.NewReader("invalid xml"), + ) w := httptest.NewRecorder() ch.handleMessageCallback(context.Background(), w, req) @@ -528,7 +545,11 @@ func TestWeComBotHandleMessageCallback(t *testing.T) { timestamp := "1234567890" nonce := "test_nonce" - req := httptest.NewRequest(http.MethodPost, "/webhook/wecom?msg_signature=invalid_sig&timestamp="+timestamp+"&nonce="+nonce, bytes.NewReader(wrapperData)) + req := httptest.NewRequest( + http.MethodPost, + "/webhook/wecom?msg_signature=invalid_sig&timestamp="+timestamp+"&nonce="+nonce, + bytes.NewReader(wrapperData), + ) w := httptest.NewRecorder() ch.handleMessageCallback(context.Background(), w, req) @@ -623,7 +644,11 @@ func TestWeComBotHandleWebhook(t *testing.T) { nonce := "test_nonce" signature := generateSignature("test_token", timestamp, nonce, encoded) - req := httptest.NewRequest(http.MethodGet, "/webhook/wecom?msg_signature="+signature+"&timestamp="+timestamp+"&nonce="+nonce+"&echostr="+encoded, nil) + req := httptest.NewRequest( + http.MethodGet, + "/webhook/wecom?msg_signature="+signature+"&timestamp="+timestamp+"&nonce="+nonce+"&echostr="+encoded, + nil, + ) w := httptest.NewRecorder() ch.handleWebhook(w, req) @@ -646,7 +671,11 @@ func TestWeComBotHandleWebhook(t *testing.T) { nonce := "test_nonce" signature := generateSignature("test_token", timestamp, nonce, encryptedWrapper.Encrypt) - req := httptest.NewRequest(http.MethodPost, "/webhook/wecom?msg_signature="+signature+"&timestamp="+timestamp+"&nonce="+nonce, bytes.NewReader(wrapperData)) + req := httptest.NewRequest( + http.MethodPost, + "/webhook/wecom?msg_signature="+signature+"&timestamp="+timestamp+"&nonce="+nonce, + bytes.NewReader(wrapperData), + ) w := httptest.NewRecorder() ch.handleWebhook(w, req) @@ -713,15 +742,15 @@ func TestWeComBotReplyMessage(t *testing.T) { func TestWeComBotMessageStructure(t *testing.T) { jsonData := `{ - "msgid":
"test_msg_id_123", - "aibotid": "test_aibot_id", - "chatid": "group_chat_id_123", - "chattype": "group", - "from": {"userid": "user123"}, - "response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", - "msgtype": "text", - "text": {"content": "Hello World"} - }` + "msgid": "test_msg_id_123", + "aibotid": "test_aibot_id", + "chatid": "group_chat_id_123", + "chattype": "group", + "from": {"userid": "user123"}, + "response_url": "https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=test", + "msgtype": "text", + "text": {"content": "Hello World"} + }` var msg WeComBotMessage err := json.Unmarshal([]byte(jsonData), &msg) diff --git a/pkg/channels/whatsapp/whatsapp.go b/pkg/channels/whatsapp/whatsapp.go index 1ac256766..7e8f13ab6 100644 --- a/pkg/channels/whatsapp/whatsapp.go +++ b/pkg/channels/whatsapp/whatsapp.go @@ -87,7 +87,7 @@ func (c *WhatsAppChannel) Send(ctx context.Context, msg bus.OutboundMessage) err return fmt.Errorf("whatsapp connection not established") } - payload := map[string]interface{}{ + payload := map[string]any{ "type": "message", "to": msg.ChatID, "content": msg.Content, @@ -127,7 +127,7 @@ func (c *WhatsAppChannel) listen(ctx context.Context) { continue } - var msg map[string]interface{} + var msg map[string]any if err := json.Unmarshal(message, &msg); err != nil { log.Printf("Failed to unmarshal WhatsApp message: %v", err) continue @@ -145,7 +145,7 @@ func (c *WhatsAppChannel) listen(ctx context.Context) { } } -func (c *WhatsAppChannel) handleIncomingMessage(msg map[string]interface{}) { +func (c *WhatsAppChannel) handleIncomingMessage(msg map[string]any) { senderID, ok := msg["from"].(string) if !ok { return @@ -162,7 +162,7 @@ func (c *WhatsAppChannel) handleIncomingMessage(msg map[string]interface{}) { } var mediaPaths []string - if mediaData, ok := msg["media"].([]interface{}); ok { + if mediaData, ok := msg["media"].([]any); ok { mediaPaths = make([]string, 0, len(mediaData)) for _, m := range mediaData { if path, ok := 
m.(string); ok { From 153198e0f35fef47d44c82af0b4d46ad3e411ff0 Mon Sep 17 00:00:00 2001 From: Hoshina Date: Sun, 22 Feb 2026 21:57:12 +0800 Subject: [PATCH 08/28] refactor(bus,channels): promote peer and messageID from metadata to structured fields Add bus.Peer struct and explicit Peer/MessageID fields to InboundMessage, replacing the implicit peer_kind/peer_id/message_id metadata convention. - Add Peer{Kind, ID} type to pkg/bus/types.go - Extend InboundMessage with Peer and MessageID fields - Change BaseChannel.HandleMessage signature to accept peer and messageID - Adapt all 12 channel implementations to pass structured peer/messageID - Simplify agent extractPeer() to read msg.Peer directly - extractParentPeer unchanged (parent_peer still via metadata) --- pkg/agent/loop.go | 11 +++++------ pkg/bus/types.go | 8 ++++++++ pkg/channels/base.go | 21 +++++++++++++------- pkg/channels/dingtalk/dingtalk.go | 9 ++++----- pkg/channels/discord/discord.go | 7 +++---- pkg/channels/feishu/feishu_32.go | 4 +++- pkg/channels/feishu/feishu_64.go | 14 ++++++------- pkg/channels/line/line.go | 10 ++++------ pkg/channels/maixcam/maixcam.go | 4 +--- pkg/channels/onebot/onebot.go | 14 ++++++------- pkg/channels/qq/qq.go | 33 ++++++++++++++++++++----------- pkg/channels/slack/slack.go | 16 +++++++-------- pkg/channels/telegram/telegram.go | 16 +++++++++++---- pkg/channels/wecom/app.go | 7 ++++--- pkg/channels/wecom/bot.go | 6 +++--- pkg/channels/whatsapp/whatsapp.go | 14 ++++++------- 16 files changed, 109 insertions(+), 85 deletions(-) diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index bf229ad74..d8ea3b091 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -1122,21 +1122,20 @@ func (al *AgentLoop) handleCommand(ctx context.Context, msg bus.InboundMessage) return "", false } -// extractPeer extracts the routing peer from inbound message metadata. +// extractPeer extracts the routing peer from the inbound message's structured Peer field. 
func extractPeer(msg bus.InboundMessage) *routing.RoutePeer { - peerKind := msg.Metadata["peer_kind"] - if peerKind == "" { + if msg.Peer.Kind == "" { return nil } - peerID := msg.Metadata["peer_id"] + peerID := msg.Peer.ID if peerID == "" { - if peerKind == "direct" { + if msg.Peer.Kind == "direct" { peerID = msg.SenderID } else { peerID = msg.ChatID } } - return &routing.RoutePeer{Kind: peerKind, ID: peerID} + return &routing.RoutePeer{Kind: msg.Peer.Kind, ID: peerID} } // extractParentPeer extracts the parent peer (reply-to) from inbound message metadata. diff --git a/pkg/bus/types.go b/pkg/bus/types.go index 44f9181a5..081f13a0b 100644 --- a/pkg/bus/types.go +++ b/pkg/bus/types.go @@ -1,11 +1,19 @@ package bus +// Peer identifies the routing peer for a message (direct, group, channel, etc.) +type Peer struct { + Kind string `json:"kind"` // "direct" | "group" | "channel" | "" + ID string `json:"id"` +} + type InboundMessage struct { Channel string `json:"channel"` SenderID string `json:"sender_id"` ChatID string `json:"chat_id"` Content string `json:"content"` Media []string `json:"media,omitempty"` + Peer Peer `json:"peer"` // routing peer + MessageID string `json:"message_id,omitempty"` // platform message ID SessionKey string `json:"session_key"` Metadata map[string]string `json:"metadata,omitempty"` } diff --git a/pkg/channels/base.go b/pkg/channels/base.go index 5d77c6c0d..5e603f0d4 100644 --- a/pkg/channels/base.go +++ b/pkg/channels/base.go @@ -81,18 +81,25 @@ func (c *BaseChannel) IsAllowed(senderID string) bool { return false } -func (c *BaseChannel) HandleMessage(senderID, chatID, content string, media []string, metadata map[string]string) { +func (c *BaseChannel) HandleMessage( + peer bus.Peer, + messageID, senderID, chatID, content string, + media []string, + metadata map[string]string, +) { if !c.IsAllowed(senderID) { return } msg := bus.InboundMessage{ - Channel: c.name, - SenderID: senderID, - ChatID: chatID, - Content: content, - Media: media, - 
Metadata: metadata, + Channel: c.name, + SenderID: senderID, + ChatID: chatID, + Content: content, + Media: media, + Peer: peer, + MessageID: messageID, + Metadata: metadata, } c.bus.PublishInbound(msg) diff --git a/pkg/channels/dingtalk/dingtalk.go b/pkg/channels/dingtalk/dingtalk.go index afc0de47f..a8aee65d6 100644 --- a/pkg/channels/dingtalk/dingtalk.go +++ b/pkg/channels/dingtalk/dingtalk.go @@ -160,12 +160,11 @@ func (c *DingTalkChannel) onChatBotMessageReceived( "session_webhook": data.SessionWebhook, } + var peer bus.Peer if data.ConversationType == "1" { - metadata["peer_kind"] = "direct" - metadata["peer_id"] = senderID + peer = bus.Peer{Kind: "direct", ID: senderID} } else { - metadata["peer_kind"] = "group" - metadata["peer_id"] = data.ConversationId + peer = bus.Peer{Kind: "group", ID: data.ConversationId} } logger.DebugCF("dingtalk", "Received message", map[string]any{ @@ -175,7 +174,7 @@ func (c *DingTalkChannel) onChatBotMessageReceived( }) // Handle the message through the base channel - c.HandleMessage(senderID, chatID, content, nil, metadata) + c.HandleMessage(peer, "", senderID, chatID, content, nil, metadata) // Return nil to indicate we've handled the message asynchronously // The response will be sent through the message bus diff --git a/pkg/channels/discord/discord.go b/pkg/channels/discord/discord.go index b83ac28fd..416a94710 100644 --- a/pkg/channels/discord/discord.go +++ b/pkg/channels/discord/discord.go @@ -294,19 +294,18 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag peerID = senderID } + peer := bus.Peer{Kind: peerKind, ID: peerID} + metadata := map[string]string{ - "message_id": m.ID, "user_id": senderID, "username": m.Author.Username, "display_name": senderName, "guild_id": m.GuildID, "channel_id": m.ChannelID, "is_dm": fmt.Sprintf("%t", m.GuildID == ""), - "peer_kind": peerKind, - "peer_id": peerID, } - c.HandleMessage(senderID, m.ChannelID, content, mediaPaths, metadata) + 
c.HandleMessage(peer, m.ID, senderID, m.ChannelID, content, mediaPaths, metadata) } // startTyping starts a continuous typing indicator loop for the given chatID. diff --git a/pkg/channels/feishu/feishu_32.go b/pkg/channels/feishu/feishu_32.go index 14711e49e..d0ec758c6 100644 --- a/pkg/channels/feishu/feishu_32.go +++ b/pkg/channels/feishu/feishu_32.go @@ -18,7 +18,9 @@ type FeishuChannel struct { // NewFeishuChannel returns an error on 32-bit architectures where the Feishu SDK is not supported func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChannel, error) { - return nil, errors.New("feishu channel is not supported on 32-bit architectures (armv7l, 386, etc.). Please use a 64-bit system or disable feishu in your config") + return nil, errors.New( + "feishu channel is not supported on 32-bit architectures (armv7l, 386, etc.). Please use a 64-bit system or disable feishu in your config", + ) } // Start is a stub method to satisfy the Channel interface diff --git a/pkg/channels/feishu/feishu_64.go b/pkg/channels/feishu/feishu_64.go index aa4e141c4..d67823974 100644 --- a/pkg/channels/feishu/feishu_64.go +++ b/pkg/channels/feishu/feishu_64.go @@ -153,8 +153,9 @@ func (c *FeishuChannel) handleMessageReceive(_ context.Context, event *larkim.P2 } metadata := map[string]string{} - if messageID := stringValue(message.MessageId); messageID != "" { - metadata["message_id"] = messageID + messageID := "" + if mid := stringValue(message.MessageId); mid != "" { + messageID = mid } if messageType := stringValue(message.MessageType); messageType != "" { metadata["message_type"] = messageType @@ -167,12 +168,11 @@ func (c *FeishuChannel) handleMessageReceive(_ context.Context, event *larkim.P2 } chatType := stringValue(message.ChatType) + var peer bus.Peer if chatType == "p2p" { - metadata["peer_kind"] = "direct" - metadata["peer_id"] = senderID + peer = bus.Peer{Kind: "direct", ID: senderID} } else { - metadata["peer_kind"] = "group" - 
metadata["peer_id"] = chatID + peer = bus.Peer{Kind: "group", ID: chatID} } logger.InfoCF("feishu", "Feishu message received", map[string]any{ @@ -181,7 +181,7 @@ func (c *FeishuChannel) handleMessageReceive(_ context.Context, event *larkim.P2 "preview": utils.Truncate(content, 80), }) - c.HandleMessage(senderID, chatID, content, nil, metadata) + c.HandleMessage(peer, messageID, senderID, chatID, content, nil, metadata) return nil } diff --git a/pkg/channels/line/line.go b/pkg/channels/line/line.go index 4e1d0dfd3..96297e2cd 100644 --- a/pkg/channels/line/line.go +++ b/pkg/channels/line/line.go @@ -364,15 +364,13 @@ func (c *LINEChannel) processEvent(event lineEvent) { metadata := map[string]string{ "platform": "line", "source_type": event.Source.Type, - "message_id": msg.ID, } + var peer bus.Peer if isGroup { - metadata["peer_kind"] = "group" - metadata["peer_id"] = chatID + peer = bus.Peer{Kind: "group", ID: chatID} } else { - metadata["peer_kind"] = "direct" - metadata["peer_id"] = senderID + peer = bus.Peer{Kind: "direct", ID: senderID} } logger.DebugCF("line", "Received message", map[string]any{ @@ -386,7 +384,7 @@ func (c *LINEChannel) processEvent(event lineEvent) { // Show typing/loading indicator (requires user ID, not group ID) c.sendLoading(senderID) - c.HandleMessage(senderID, chatID, content, mediaPaths, metadata) + c.HandleMessage(peer, msg.ID, senderID, chatID, content, mediaPaths, metadata) } // isBotMentioned checks if the bot is mentioned in the message. 
diff --git a/pkg/channels/maixcam/maixcam.go b/pkg/channels/maixcam/maixcam.go index a7bff55e0..280098dda 100644 --- a/pkg/channels/maixcam/maixcam.go +++ b/pkg/channels/maixcam/maixcam.go @@ -171,11 +171,9 @@ func (c *MaixCamChannel) handlePersonDetection(msg MaixCamMessage) { "y": fmt.Sprintf("%.0f", y), "w": fmt.Sprintf("%.0f", w), "h": fmt.Sprintf("%.0f", h), - "peer_kind": "channel", - "peer_id": "default", } - c.HandleMessage(senderID, chatID, content, []string{}, metadata) + c.HandleMessage(bus.Peer{Kind: "channel", ID: "default"}, "", senderID, chatID, content, []string{}, metadata) } func (c *MaixCamChannel) handleStatusUpdate(msg MaixCamMessage) { diff --git a/pkg/channels/onebot/onebot.go b/pkg/channels/onebot/onebot.go index 3d2e64e2a..642eebd1d 100644 --- a/pkg/channels/onebot/onebot.go +++ b/pkg/channels/onebot/onebot.go @@ -856,9 +856,9 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) { senderID := strconv.FormatInt(userID, 10) var chatID string - metadata := map[string]string{ - "message_id": messageID, - } + var peer bus.Peer + + metadata := map[string]string{} if parsed.ReplyTo != "" { metadata["reply_to_message_id"] = parsed.ReplyTo @@ -867,14 +867,12 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) { switch raw.MessageType { case "private": chatID = "private:" + senderID - metadata["peer_kind"] = "direct" - metadata["peer_id"] = senderID + peer = bus.Peer{Kind: "direct", ID: senderID} case "group": groupIDStr := strconv.FormatInt(groupID, 10) chatID = "group:" + groupIDStr - metadata["peer_kind"] = "group" - metadata["peer_id"] = groupIDStr + peer = bus.Peer{Kind: "group", ID: groupIDStr} metadata["group_id"] = groupIDStr senderUserID, _ := parseJSONInt64(sender.UserID) @@ -929,7 +927,7 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) { c.pendingEmojiMsg.Store(chatID, messageID) } - c.HandleMessage(senderID, chatID, content, parsed.Media, metadata) + c.HandleMessage(peer, messageID, senderID, chatID, 
content, parsed.Media, metadata) } func (c *OneBotChannel) isDuplicate(messageID string) bool { diff --git a/pkg/channels/qq/qq.go b/pkg/channels/qq/qq.go index 2a95bbd06..429e23cbf 100644 --- a/pkg/channels/qq/qq.go +++ b/pkg/channels/qq/qq.go @@ -164,13 +164,17 @@ func (c *QQChannel) handleC2CMessage() event.C2CMessageEventHandler { }) // 转发到消息总线 - metadata := map[string]string{ - "message_id": data.ID, - "peer_kind": "direct", - "peer_id": senderID, - } - - c.HandleMessage(senderID, senderID, content, []string{}, metadata) + metadata := map[string]string{} + + c.HandleMessage( + bus.Peer{Kind: "direct", ID: senderID}, + data.ID, + senderID, + senderID, + content, + []string{}, + metadata, + ) return nil } @@ -208,13 +212,18 @@ func (c *QQChannel) handleGroupATMessage() event.GroupATMessageEventHandler { // 转发到消息总线(使用 GroupID 作为 ChatID) metadata := map[string]string{ - "message_id": data.ID, - "group_id": data.GroupID, - "peer_kind": "group", - "peer_id": data.GroupID, + "group_id": data.GroupID, } - c.HandleMessage(senderID, data.GroupID, content, []string{}, metadata) + c.HandleMessage( + bus.Peer{Kind: "group", ID: data.GroupID}, + data.ID, + senderID, + data.GroupID, + content, + []string{}, + metadata, + ) return nil } diff --git a/pkg/channels/slack/slack.go b/pkg/channels/slack/slack.go index cafe53103..b459a7140 100644 --- a/pkg/channels/slack/slack.go +++ b/pkg/channels/slack/slack.go @@ -284,13 +284,13 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { peerID = senderID } + peer := bus.Peer{Kind: peerKind, ID: peerID} + metadata := map[string]string{ "message_ts": messageTS, "channel_id": channelID, "thread_ts": threadTS, "platform": "slack", - "peer_kind": peerKind, - "peer_id": peerID, "team_id": c.teamID, } @@ -301,7 +301,7 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { "has_thread": threadTS != "", }) - c.HandleMessage(senderID, chatID, content, mediaPaths, metadata) + c.HandleMessage(peer, 
messageTS, senderID, chatID, content, mediaPaths, metadata) } func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) { @@ -351,18 +351,18 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) { mentionPeerID = senderID } + mentionPeer := bus.Peer{Kind: mentionPeerKind, ID: mentionPeerID} + metadata := map[string]string{ "message_ts": messageTS, "channel_id": channelID, "thread_ts": threadTS, "platform": "slack", "is_mention": "true", - "peer_kind": mentionPeerKind, - "peer_id": mentionPeerID, "team_id": c.teamID, } - c.HandleMessage(senderID, chatID, content, nil, metadata) + c.HandleMessage(mentionPeer, messageTS, senderID, chatID, content, nil, metadata) } func (c *SlackChannel) handleSlashCommand(event socketmode.Event) { @@ -396,8 +396,6 @@ func (c *SlackChannel) handleSlashCommand(event socketmode.Event) { "platform": "slack", "is_command": "true", "trigger_id": cmd.TriggerID, - "peer_kind": "channel", - "peer_id": channelID, "team_id": c.teamID, } @@ -407,7 +405,7 @@ func (c *SlackChannel) handleSlashCommand(event socketmode.Event) { "text": utils.Truncate(content, 50), }) - c.HandleMessage(senderID, chatID, content, nil, metadata) + c.HandleMessage(bus.Peer{Kind: "channel", ID: channelID}, "", senderID, chatID, content, nil, metadata) } func (c *SlackChannel) downloadSlackFile(file slack.File) string { diff --git a/pkg/channels/telegram/telegram.go b/pkg/channels/telegram/telegram.go index 7619440e2..5703000b4 100644 --- a/pkg/channels/telegram/telegram.go +++ b/pkg/channels/telegram/telegram.go @@ -362,17 +362,25 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes peerID = fmt.Sprintf("%d", chatID) } + peer := bus.Peer{Kind: peerKind, ID: peerID} + messageID := fmt.Sprintf("%d", message.MessageID) + metadata := map[string]string{ - "message_id": fmt.Sprintf("%d", message.MessageID), "user_id": fmt.Sprintf("%d", user.ID), "username": user.Username, "first_name": user.FirstName, 
"is_group": fmt.Sprintf("%t", message.Chat.Type != "private"), - "peer_kind": peerKind, - "peer_id": peerID, } - c.HandleMessage(fmt.Sprintf("%d", user.ID), fmt.Sprintf("%d", chatID), content, mediaPaths, metadata) + c.HandleMessage( + peer, + messageID, + fmt.Sprintf("%d", user.ID), + fmt.Sprintf("%d", chatID), + content, + mediaPaths, + metadata, + ) return nil } diff --git a/pkg/channels/wecom/app.go b/pkg/channels/wecom/app.go index f3557d60f..873431d3c 100644 --- a/pkg/channels/wecom/app.go +++ b/pkg/channels/wecom/app.go @@ -425,6 +425,9 @@ func (c *WeComAppChannel) processMessage(ctx context.Context, msg WeComXMLMessag // Build metadata // WeCom App only supports direct messages (private chat) + peer := bus.Peer{Kind: "direct", ID: senderID} + messageID := fmt.Sprintf("%d", msg.MsgId) + metadata := map[string]string{ "msg_type": msg.MsgType, "msg_id": fmt.Sprintf("%d", msg.MsgId), @@ -432,8 +435,6 @@ func (c *WeComAppChannel) processMessage(ctx context.Context, msg WeComXMLMessag "platform": "wecom_app", "media_id": msg.MediaId, "create_time": fmt.Sprintf("%d", msg.CreateTime), - "peer_kind": "direct", - "peer_id": senderID, } content := msg.Content @@ -445,7 +446,7 @@ func (c *WeComAppChannel) processMessage(ctx context.Context, msg WeComXMLMessag }) // Handle the message through the base channel - c.HandleMessage(senderID, chatID, content, nil, metadata) + c.HandleMessage(peer, messageID, senderID, chatID, content, nil, metadata) } // tokenRefreshLoop periodically refreshes the access token diff --git a/pkg/channels/wecom/bot.go b/pkg/channels/wecom/bot.go index 17ee2107f..3a8a16c43 100644 --- a/pkg/channels/wecom/bot.go +++ b/pkg/channels/wecom/bot.go @@ -378,12 +378,12 @@ func (c *WeComBotChannel) processMessage(ctx context.Context, msg WeComBotMessag } // Build metadata + peer := bus.Peer{Kind: peerKind, ID: peerID} + metadata := map[string]string{ "msg_type": msg.MsgType, "msg_id": msg.MsgID, "platform": "wecom", - "peer_kind": peerKind, - "peer_id": 
peerID, "response_url": msg.ResponseURL, } if isGroupChat { @@ -400,7 +400,7 @@ func (c *WeComBotChannel) processMessage(ctx context.Context, msg WeComBotMessag }) // Handle the message through the base channel - c.HandleMessage(senderID, chatID, content, nil, metadata) + c.HandleMessage(peer, msg.MsgID, senderID, chatID, content, nil, metadata) } // sendWebhookReply sends a reply using the webhook URL diff --git a/pkg/channels/whatsapp/whatsapp.go b/pkg/channels/whatsapp/whatsapp.go index 7e8f13ab6..1a5401172 100644 --- a/pkg/channels/whatsapp/whatsapp.go +++ b/pkg/channels/whatsapp/whatsapp.go @@ -172,22 +172,22 @@ func (c *WhatsAppChannel) handleIncomingMessage(msg map[string]any) { } metadata := make(map[string]string) - if messageID, ok := msg["id"].(string); ok { - metadata["message_id"] = messageID + var messageID string + if mid, ok := msg["id"].(string); ok { + messageID = mid } if userName, ok := msg["from_name"].(string); ok { metadata["user_name"] = userName } + var peer bus.Peer if chatID == senderID { - metadata["peer_kind"] = "direct" - metadata["peer_id"] = senderID + peer = bus.Peer{Kind: "direct", ID: senderID} } else { - metadata["peer_kind"] = "group" - metadata["peer_id"] = chatID + peer = bus.Peer{Kind: "group", ID: chatID} } log.Printf("WhatsApp message from %s: %s...", senderID, utils.Truncate(content, 50)) - c.HandleMessage(senderID, chatID, content, mediaPaths, metadata) + c.HandleMessage(peer, messageID, senderID, chatID, content, mediaPaths, metadata) } From b6161aec3f49beae2acfc597476e20d0e4b8732c Mon Sep 17 00:00:00 2001 From: Hoshina Date: Sun, 22 Feb 2026 22:25:07 +0800 Subject: [PATCH 09/28] refactor(channels): unify Start/Stop lifecycle and fix goroutine/context leaks - OneBot: remove close(ch) race in Stop() pending cleanup; add WriteDeadline to Send/sendAPIRequest - Telegram: add cancelCtx; Stop() now calls bh.Stop(), cancel(), and cleans up thinking CancelFuncs - Discord: add cancelCtx via WithCancel; Stop() calls cancel(); 
remove unused getContext() - WhatsApp: add cancelCtx; Send() adds WriteDeadline; replace stdlib log with project logger - MaixCam: add cancelCtx; Send() adds WriteDeadline; Stop() calls cancel() before closing --- pkg/channels/discord/discord.go | 17 ++++++------ pkg/channels/maixcam/maixcam.go | 25 +++++++++++++---- pkg/channels/onebot/onebot.go | 7 +++-- pkg/channels/telegram/telegram.go | 35 +++++++++++++++++++---- pkg/channels/whatsapp/whatsapp.go | 46 +++++++++++++++++++++++-------- 5 files changed, 96 insertions(+), 34 deletions(-) diff --git a/pkg/channels/discord/discord.go b/pkg/channels/discord/discord.go index 416a94710..faf1e1358 100644 --- a/pkg/channels/discord/discord.go +++ b/pkg/channels/discord/discord.go @@ -29,6 +29,7 @@ type DiscordChannel struct { config config.DiscordConfig transcriber *voice.GroqTranscriber ctx context.Context + cancel context.CancelFunc typingMu sync.Mutex typingStop map[string]chan struct{} // chatID → stop signal botUserID string // stored for mention checking @@ -56,17 +57,10 @@ func (c *DiscordChannel) SetTranscriber(transcriber *voice.GroqTranscriber) { c.transcriber = transcriber } -func (c *DiscordChannel) getContext() context.Context { - if c.ctx == nil { - return context.Background() - } - return c.ctx -} - func (c *DiscordChannel) Start(ctx context.Context) error { logger.InfoC("discord", "Starting Discord bot") - c.ctx = ctx + c.ctx, c.cancel = context.WithCancel(ctx) // Get bot user ID before opening session to avoid race condition botUser, err := c.session.User("@me") @@ -103,6 +97,11 @@ func (c *DiscordChannel) Stop(ctx context.Context) error { } c.typingMu.Unlock() + // Cancel our context so typing goroutines using c.ctx.Done() exit + if c.cancel != nil { + c.cancel() + } + if err := c.session.Close(); err != nil { return fmt.Errorf("failed to close discord session: %w", err) } @@ -236,7 +235,7 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag transcribedText := "" if 
c.transcriber != nil && c.transcriber.IsAvailable() { - ctx, cancel := context.WithTimeout(c.getContext(), transcriptionTimeout) + ctx, cancel := context.WithTimeout(c.ctx, transcriptionTimeout) result, err := c.transcriber.Transcribe(ctx, localPath) cancel() // Release context resources immediately to avoid leaks in for loop diff --git a/pkg/channels/maixcam/maixcam.go b/pkg/channels/maixcam/maixcam.go index 280098dda..05213b095 100644 --- a/pkg/channels/maixcam/maixcam.go +++ b/pkg/channels/maixcam/maixcam.go @@ -6,6 +6,7 @@ import ( "fmt" "net" "sync" + "time" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" @@ -17,6 +18,8 @@ type MaixCamChannel struct { *channels.BaseChannel config config.MaixCamConfig listener net.Listener + ctx context.Context + cancel context.CancelFunc clients map[net.Conn]bool clientsMux sync.RWMutex } @@ -41,9 +44,12 @@ func NewMaixCamChannel(cfg config.MaixCamConfig, bus *bus.MessageBus) (*MaixCamC func (c *MaixCamChannel) Start(ctx context.Context) error { logger.InfoC("maixcam", "Starting MaixCam channel server") + c.ctx, c.cancel = context.WithCancel(ctx) + addr := fmt.Sprintf("%s:%d", c.config.Host, c.config.Port) listener, err := net.Listen("tcp", addr) if err != nil { + c.cancel() return fmt.Errorf("failed to listen on %s: %w", addr, err) } @@ -55,17 +61,17 @@ func (c *MaixCamChannel) Start(ctx context.Context) error { "port": c.config.Port, }) - go c.acceptConnections(ctx) + go c.acceptConnections() return nil } -func (c *MaixCamChannel) acceptConnections(ctx context.Context) { +func (c *MaixCamChannel) acceptConnections() { logger.DebugC("maixcam", "Starting connection acceptor") for { select { - case <-ctx.Done(): + case <-c.ctx.Done(): logger.InfoC("maixcam", "Stopping connection acceptor") return default: @@ -87,12 +93,12 @@ func (c *MaixCamChannel) acceptConnections(ctx context.Context) { c.clients[conn] = true c.clientsMux.Unlock() - go c.handleConnection(conn, ctx) + go c.handleConnection(conn) 
} } } -func (c *MaixCamChannel) handleConnection(conn net.Conn, ctx context.Context) { +func (c *MaixCamChannel) handleConnection(conn net.Conn) { logger.DebugC("maixcam", "Handling MaixCam connection") defer func() { @@ -107,7 +113,7 @@ func (c *MaixCamChannel) handleConnection(conn net.Conn, ctx context.Context) { for { select { - case <-ctx.Done(): + case <-c.ctx.Done(): return default: var msg MaixCamMessage @@ -186,6 +192,11 @@ func (c *MaixCamChannel) Stop(ctx context.Context) error { logger.InfoC("maixcam", "Stopping MaixCam channel") c.SetRunning(false) + // Cancel context first to signal goroutines to exit + if c.cancel != nil { + c.cancel() + } + if c.listener != nil { c.listener.Close() } @@ -229,6 +240,7 @@ func (c *MaixCamChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro var sendErr error for conn := range c.clients { + _ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) if _, err := conn.Write(data); err != nil { logger.ErrorCF("maixcam", "Failed to send to client", map[string]any{ "client": conn.RemoteAddr().String(), @@ -236,6 +248,7 @@ func (c *MaixCamChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro }) sendErr = err } + _ = conn.SetWriteDeadline(time.Time{}) } return sendErr diff --git a/pkg/channels/onebot/onebot.go b/pkg/channels/onebot/onebot.go index 642eebd1d..4f35888ca 100644 --- a/pkg/channels/onebot/onebot.go +++ b/pkg/channels/onebot/onebot.go @@ -298,7 +298,9 @@ func (c *OneBotChannel) sendAPIRequest(action string, params any, timeout time.D } c.writeMu.Lock() + _ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) err = conn.WriteMessage(websocket.TextMessage, data) + _ = conn.SetWriteDeadline(time.Time{}) c.writeMu.Unlock() if err != nil { @@ -354,8 +356,7 @@ func (c *OneBotChannel) Stop(ctx context.Context) error { } c.pendingMu.Lock() - for echo, ch := range c.pending { - close(ch) + for echo := range c.pending { delete(c.pending, echo) } c.pendingMu.Unlock() @@ -402,7 +403,9 @@ func (c 
*OneBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error } c.writeMu.Lock() + _ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) err = conn.WriteMessage(websocket.TextMessage, data) + _ = conn.SetWriteDeadline(time.Time{}) c.writeMu.Unlock() if err != nil { diff --git a/pkg/channels/telegram/telegram.go b/pkg/channels/telegram/telegram.go index 5703000b4..af825ddc9 100644 --- a/pkg/channels/telegram/telegram.go +++ b/pkg/channels/telegram/telegram.go @@ -27,10 +27,13 @@ import ( type TelegramChannel struct { *channels.BaseChannel bot *telego.Bot + bh *telegohandler.BotHandler commands TelegramCommander config *config.Config chatIDs map[string]int64 transcriber *voice.GroqTranscriber + ctx context.Context + cancel context.CancelFunc placeholders sync.Map // chatID -> messageID stopThinking sync.Map // chatID -> thinkingCancel } @@ -94,17 +97,22 @@ func (c *TelegramChannel) SetTranscriber(transcriber *voice.GroqTranscriber) { func (c *TelegramChannel) Start(ctx context.Context) error { logger.InfoC("telegram", "Starting Telegram bot (polling mode)...") - updates, err := c.bot.UpdatesViaLongPolling(ctx, &telego.GetUpdatesParams{ + c.ctx, c.cancel = context.WithCancel(ctx) + + updates, err := c.bot.UpdatesViaLongPolling(c.ctx, &telego.GetUpdatesParams{ Timeout: 30, }) if err != nil { + c.cancel() return fmt.Errorf("failed to start long polling: %w", err) } bh, err := telegohandler.NewBotHandler(c.bot, updates) if err != nil { + c.cancel() return fmt.Errorf("failed to create bot handler: %w", err) } + c.bh = bh bh.HandleMessage(func(ctx *th.Context, message telego.Message) error { c.commands.Help(ctx, message) @@ -133,17 +141,32 @@ func (c *TelegramChannel) Start(ctx context.Context) error { go bh.Start() - go func() { - <-ctx.Done() - bh.Stop() - }() - return nil } func (c *TelegramChannel) Stop(ctx context.Context) error { logger.InfoC("telegram", "Stopping Telegram bot...") c.SetRunning(false) + + // Clean up all thinking cancel functions to 
avoid context leaks + c.stopThinking.Range(func(key, value any) bool { + if cf, ok := value.(*thinkingCancel); ok && cf != nil { + cf.Cancel() + } + c.stopThinking.Delete(key) + return true + }) + + // Stop the bot handler + if c.bh != nil { + c.bh.Stop() + } + + // Cancel our context (stops long polling) + if c.cancel != nil { + c.cancel() + } + return nil } diff --git a/pkg/channels/whatsapp/whatsapp.go b/pkg/channels/whatsapp/whatsapp.go index 1a5401172..cbc82fd09 100644 --- a/pkg/channels/whatsapp/whatsapp.go +++ b/pkg/channels/whatsapp/whatsapp.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "fmt" - "log" "sync" "time" @@ -13,6 +12,7 @@ import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/utils" ) @@ -21,6 +21,8 @@ type WhatsAppChannel struct { conn *websocket.Conn config config.WhatsAppConfig url string + ctx context.Context + cancel context.CancelFunc mu sync.Mutex connected bool } @@ -37,13 +39,18 @@ func NewWhatsAppChannel(cfg config.WhatsAppConfig, bus *bus.MessageBus) (*WhatsA } func (c *WhatsAppChannel) Start(ctx context.Context) error { - log.Printf("Starting WhatsApp channel connecting to %s...", c.url) + logger.InfoCF("whatsapp", "Starting WhatsApp channel", map[string]any{ + "bridge_url": c.url, + }) + + c.ctx, c.cancel = context.WithCancel(ctx) dialer := websocket.DefaultDialer dialer.HandshakeTimeout = 10 * time.Second conn, _, err := dialer.Dial(c.url, nil) if err != nil { + c.cancel() return fmt.Errorf("failed to connect to WhatsApp bridge: %w", err) } @@ -53,22 +60,29 @@ func (c *WhatsAppChannel) Start(ctx context.Context) error { c.mu.Unlock() c.SetRunning(true) - log.Println("WhatsApp channel connected") + logger.InfoC("whatsapp", "WhatsApp channel connected") - go c.listen(ctx) + go c.listen() return nil } func (c *WhatsAppChannel) Stop(ctx context.Context) error { - log.Println("Stopping 
WhatsApp channel...") + logger.InfoC("whatsapp", "Stopping WhatsApp channel...") + + // Cancel context first to signal listen goroutine to exit + if c.cancel != nil { + c.cancel() + } c.mu.Lock() defer c.mu.Unlock() if c.conn != nil { if err := c.conn.Close(); err != nil { - log.Printf("Error closing WhatsApp connection: %v", err) + logger.ErrorCF("whatsapp", "Error closing WhatsApp connection", map[string]any{ + "error": err.Error(), + }) } c.conn = nil } @@ -98,17 +112,20 @@ func (c *WhatsAppChannel) Send(ctx context.Context, msg bus.OutboundMessage) err return fmt.Errorf("failed to marshal message: %w", err) } + _ = c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) if err := c.conn.WriteMessage(websocket.TextMessage, data); err != nil { + _ = c.conn.SetWriteDeadline(time.Time{}) return fmt.Errorf("failed to send message: %w", err) } + _ = c.conn.SetWriteDeadline(time.Time{}) return nil } -func (c *WhatsAppChannel) listen(ctx context.Context) { +func (c *WhatsAppChannel) listen() { for { select { - case <-ctx.Done(): + case <-c.ctx.Done(): return default: c.mu.Lock() @@ -122,14 +139,18 @@ func (c *WhatsAppChannel) listen(ctx context.Context) { _, message, err := conn.ReadMessage() if err != nil { - log.Printf("WhatsApp read error: %v", err) + logger.ErrorCF("whatsapp", "WhatsApp read error", map[string]any{ + "error": err.Error(), + }) time.Sleep(2 * time.Second) continue } var msg map[string]any if err := json.Unmarshal(message, &msg); err != nil { - log.Printf("Failed to unmarshal WhatsApp message: %v", err) + logger.ErrorCF("whatsapp", "Failed to unmarshal WhatsApp message", map[string]any{ + "error": err.Error(), + }) continue } @@ -187,7 +208,10 @@ func (c *WhatsAppChannel) handleIncomingMessage(msg map[string]any) { peer = bus.Peer{Kind: "group", ID: chatID} } - log.Printf("WhatsApp message from %s: %s...", senderID, utils.Truncate(content, 50)) + logger.InfoCF("whatsapp", "WhatsApp message received", map[string]any{ + "sender": senderID, + 
"preview": utils.Truncate(content, 50), + }) c.HandleMessage(peer, messageID, senderID, chatID, content, mediaPaths, metadata) } From 70019836b5185c4dd269258d0c6135e2b9ed1029 Mon Sep 17 00:00:00 2001 From: Hoshina Date: Sun, 22 Feb 2026 22:46:29 +0800 Subject: [PATCH 10/28] refactor(channels): unify message splitting and add per-channel worker queues Move message splitting from individual channels (Discord) to the Manager layer via per-channel worker goroutines. Each channel now declares its max message length through BaseChannelOption/MessageLengthProvider, and the Manager automatically splits oversized outbound messages before dispatch. This prevents one slow channel from blocking all others. - Add WithMaxMessageLength option and MessageLengthProvider interface - Set platform-specific limits (Discord 2000, Telegram 4096, Slack 40000, etc.) - Convert SplitMessage to rune-aware counting for correct Unicode handling - Replace single dispatcher goroutine with per-channel buffered worker queues - Remove Discord's internal SplitMessage call (now handled centrally) --- pkg/channels/base.go | 48 +++++++++++-- pkg/channels/dingtalk/dingtalk.go | 2 +- pkg/channels/discord/discord.go | 15 +--- pkg/channels/line/line.go | 2 +- pkg/channels/manager.go | 112 ++++++++++++++++++++++++++--- pkg/channels/slack/slack.go | 2 +- pkg/channels/telegram/telegram.go | 8 ++- pkg/channels/wecom/app.go | 2 +- pkg/channels/wecom/bot.go | 2 +- pkg/channels/whatsapp/whatsapp.go | 2 +- pkg/utils/message.go | 114 ++++++++++++++++++------------ pkg/utils/message_test.go | 60 +++++++++++----- 12 files changed, 271 insertions(+), 98 deletions(-) diff --git a/pkg/channels/base.go b/pkg/channels/base.go index 5e603f0d4..f70145981 100644 --- a/pkg/channels/base.go +++ b/pkg/channels/base.go @@ -17,21 +17,55 @@ type Channel interface { IsAllowed(senderID string) bool } +// BaseChannelOption is a functional option for configuring a BaseChannel. 
+type BaseChannelOption func(*BaseChannel) + +// WithMaxMessageLength sets the maximum message length (in runes) for a channel. +// Messages exceeding this limit will be automatically split by the Manager. +// A value of 0 means no limit. +func WithMaxMessageLength(n int) BaseChannelOption { + return func(c *BaseChannel) { c.maxMessageLength = n } +} + +// MessageLengthProvider is an opt-in interface that channels implement +// to advertise their maximum message length. The Manager uses this via +// type assertion to decide whether to split outbound messages. +type MessageLengthProvider interface { + MaxMessageLength() int +} + type BaseChannel struct { - config any - bus *bus.MessageBus - running atomic.Bool - name string - allowList []string + config any + bus *bus.MessageBus + running atomic.Bool + name string + allowList []string + maxMessageLength int } -func NewBaseChannel(name string, config any, bus *bus.MessageBus, allowList []string) *BaseChannel { - return &BaseChannel{ +func NewBaseChannel( + name string, + config any, + bus *bus.MessageBus, + allowList []string, + opts ...BaseChannelOption, +) *BaseChannel { + bc := &BaseChannel{ config: config, bus: bus, name: name, allowList: allowList, } + for _, opt := range opts { + opt(bc) + } + return bc +} + +// MaxMessageLength returns the maximum message length (in runes) for this channel. +// A value of 0 means no limit. 
+func (c *BaseChannel) MaxMessageLength() int { + return c.maxMessageLength } func (c *BaseChannel) Name() string { diff --git a/pkg/channels/dingtalk/dingtalk.go b/pkg/channels/dingtalk/dingtalk.go index a8aee65d6..e051add1f 100644 --- a/pkg/channels/dingtalk/dingtalk.go +++ b/pkg/channels/dingtalk/dingtalk.go @@ -38,7 +38,7 @@ func NewDingTalkChannel(cfg config.DingTalkConfig, messageBus *bus.MessageBus) ( return nil, fmt.Errorf("dingtalk client_id and client_secret are required") } - base := channels.NewBaseChannel("dingtalk", cfg, messageBus, cfg.AllowFrom) + base := channels.NewBaseChannel("dingtalk", cfg, messageBus, cfg.AllowFrom, channels.WithMaxMessageLength(20000)) return &DingTalkChannel{ BaseChannel: base, diff --git a/pkg/channels/discord/discord.go b/pkg/channels/discord/discord.go index faf1e1358..623bc9f48 100644 --- a/pkg/channels/discord/discord.go +++ b/pkg/channels/discord/discord.go @@ -41,7 +41,7 @@ func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordC return nil, fmt.Errorf("failed to create discord session: %w", err) } - base := channels.NewBaseChannel("discord", cfg, bus, cfg.AllowFrom) + base := channels.NewBaseChannel("discord", cfg, bus, cfg.AllowFrom, channels.WithMaxMessageLength(2000)) return &DiscordChannel{ BaseChannel: base, @@ -121,20 +121,11 @@ func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro return fmt.Errorf("channel ID is empty") } - runes := []rune(msg.Content) - if len(runes) == 0 { + if len([]rune(msg.Content)) == 0 { return nil } - chunks := utils.SplitMessage(msg.Content, 2000) // Split messages into chunks, Discord length limit: 2000 chars - - for _, chunk := range chunks { - if err := c.sendChunk(ctx, channelID, chunk); err != nil { - return err - } - } - - return nil + return c.sendChunk(ctx, channelID, msg.Content) } func (c *DiscordChannel) sendChunk(ctx context.Context, channelID, content string) error { diff --git a/pkg/channels/line/line.go 
b/pkg/channels/line/line.go index 96297e2cd..9744e1848 100644 --- a/pkg/channels/line/line.go +++ b/pkg/channels/line/line.go @@ -60,7 +60,7 @@ func NewLINEChannel(cfg config.LINEConfig, messageBus *bus.MessageBus) (*LINECha return nil, fmt.Errorf("line channel_secret and channel_access_token are required") } - base := channels.NewBaseChannel("line", cfg, messageBus, cfg.AllowFrom) + base := channels.NewBaseChannel("line", cfg, messageBus, cfg.AllowFrom, channels.WithMaxMessageLength(5000)) return &LINEChannel{ BaseChannel: base, diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go index 7baef058c..081d616da 100644 --- a/pkg/channels/manager.go +++ b/pkg/channels/manager.go @@ -15,10 +15,20 @@ import ( "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/constants" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/utils" ) +const defaultChannelQueueSize = 100 + +type channelWorker struct { + ch Channel + queue chan bus.OutboundMessage + done chan struct{} +} + type Manager struct { channels map[string]Channel + workers map[string]*channelWorker bus *bus.MessageBus config *config.Config dispatchTask *asyncTask @@ -32,6 +42,7 @@ type asyncTask struct { func NewManager(cfg *config.Config, messageBus *bus.MessageBus) (*Manager, error) { m := &Manager{ channels: make(map[string]Channel), + workers: make(map[string]*channelWorker), bus: messageBus, config: cfg, } @@ -63,6 +74,11 @@ func (m *Manager) initChannel(name, displayName string) { }) } else { m.channels[name] = ch + m.workers[name] = &channelWorker{ + ch: ch, + queue: make(chan bus.OutboundMessage, defaultChannelQueueSize), + done: make(chan struct{}), + } logger.InfoCF("channels", "Channel enabled successfully", map[string]any{ "channel": displayName, }) @@ -141,8 +157,6 @@ func (m *Manager) StartAll(ctx context.Context) error { dispatchCtx, cancel := context.WithCancel(ctx) m.dispatchTask = &asyncTask{cancel: cancel} - go m.dispatchOutbound(dispatchCtx) - for 
name, channel := range m.channels { logger.InfoCF("channels", "Starting channel", map[string]any{ "channel": name, @@ -155,6 +169,14 @@ func (m *Manager) StartAll(ctx context.Context) error { } } + // Start per-channel workers + for name, w := range m.workers { + go m.runWorker(dispatchCtx, name, w) + } + + // Start the dispatcher that reads from the bus and routes to workers + go m.dispatchOutbound(dispatchCtx) + logger.InfoC("channels", "All channels started") return nil } @@ -165,11 +187,21 @@ func (m *Manager) StopAll(ctx context.Context) error { logger.InfoC("channels", "Stopping all channels") + // Cancel dispatcher first if m.dispatchTask != nil { m.dispatchTask.cancel() m.dispatchTask = nil } + // Close all worker queues and wait for them to drain + for _, w := range m.workers { + close(w.queue) + } + for _, w := range m.workers { + <-w.done + } + + // Stop all channels for name, channel := range m.channels { logger.InfoCF("channels", "Stopping channel", map[string]any{ "channel": name, @@ -186,6 +218,44 @@ func (m *Manager) StopAll(ctx context.Context) error { return nil } +// runWorker processes outbound messages for a single channel, splitting +// messages that exceed the channel's maximum message length. 
+func (m *Manager) runWorker(ctx context.Context, name string, w *channelWorker) { + defer close(w.done) + for { + select { + case msg, ok := <-w.queue: + if !ok { + return + } + maxLen := 0 + if mlp, ok := w.ch.(MessageLengthProvider); ok { + maxLen = mlp.MaxMessageLength() + } + if maxLen > 0 && len([]rune(msg.Content)) > maxLen { + chunks := utils.SplitMessage(msg.Content, maxLen) + for _, chunk := range chunks { + chunkMsg := msg + chunkMsg.Content = chunk + if err := w.ch.Send(ctx, chunkMsg); err != nil { + logger.ErrorCF("channels", "Error sending chunk", map[string]any{ + "channel": name, "error": err.Error(), + }) + } + } + } else { + if err := w.ch.Send(ctx, msg); err != nil { + logger.ErrorCF("channels", "Error sending message", map[string]any{ + "channel": name, "error": err.Error(), + }) + } + } + case <-ctx.Done(): + return + } + } +} + func (m *Manager) dispatchOutbound(ctx context.Context) { logger.InfoC("channels", "Outbound dispatcher started") @@ -206,7 +276,8 @@ func (m *Manager) dispatchOutbound(ctx context.Context) { } m.mu.RLock() - channel, exists := m.channels[msg.Channel] + _, exists := m.channels[msg.Channel] + w, wExists := m.workers[msg.Channel] m.mu.RUnlock() if !exists { @@ -216,11 +287,12 @@ func (m *Manager) dispatchOutbound(ctx context.Context) { continue } - if err := channel.Send(ctx, msg); err != nil { - logger.ErrorCF("channels", "Error sending message to channel", map[string]any{ - "channel": msg.Channel, - "error": err.Error(), - }) + if wExists { + select { + case w.queue <- msg: + case <-ctx.Done(): + return + } } } } @@ -262,17 +334,28 @@ func (m *Manager) RegisterChannel(name string, channel Channel) { m.mu.Lock() defer m.mu.Unlock() m.channels[name] = channel + m.workers[name] = &channelWorker{ + ch: channel, + queue: make(chan bus.OutboundMessage, defaultChannelQueueSize), + done: make(chan struct{}), + } } func (m *Manager) UnregisterChannel(name string) { m.mu.Lock() defer m.mu.Unlock() + if w, ok := m.workers[name]; 
ok { + close(w.queue) + <-w.done + } + delete(m.workers, name) delete(m.channels, name) } func (m *Manager) SendToChannel(ctx context.Context, channelName, chatID, content string) error { m.mu.RLock() - channel, exists := m.channels[channelName] + _, exists := m.channels[channelName] + w, wExists := m.workers[channelName] m.mu.RUnlock() if !exists { @@ -285,5 +368,16 @@ func (m *Manager) SendToChannel(ctx context.Context, channelName, chatID, conten Content: content, } + if wExists { + select { + case w.queue <- msg: + return nil + case <-ctx.Done(): + return ctx.Err() + } + } + + // Fallback: direct send (should not happen) + channel, _ := m.channels[channelName] return channel.Send(ctx, msg) } diff --git a/pkg/channels/slack/slack.go b/pkg/channels/slack/slack.go index b459a7140..fc0bee505 100644 --- a/pkg/channels/slack/slack.go +++ b/pkg/channels/slack/slack.go @@ -50,7 +50,7 @@ func NewSlackChannel(cfg config.SlackConfig, messageBus *bus.MessageBus) (*Slack socketClient := socketmode.New(api) - base := channels.NewBaseChannel("slack", cfg, messageBus, cfg.AllowFrom) + base := channels.NewBaseChannel("slack", cfg, messageBus, cfg.AllowFrom, channels.WithMaxMessageLength(40000)) return &SlackChannel{ BaseChannel: base, diff --git a/pkg/channels/telegram/telegram.go b/pkg/channels/telegram/telegram.go index af825ddc9..578e3c51e 100644 --- a/pkg/channels/telegram/telegram.go +++ b/pkg/channels/telegram/telegram.go @@ -76,7 +76,13 @@ func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChann return nil, fmt.Errorf("failed to create telegram bot: %w", err) } - base := channels.NewBaseChannel("telegram", telegramCfg, bus, telegramCfg.AllowFrom) + base := channels.NewBaseChannel( + "telegram", + telegramCfg, + bus, + telegramCfg.AllowFrom, + channels.WithMaxMessageLength(4096), + ) return &TelegramChannel{ BaseChannel: base, diff --git a/pkg/channels/wecom/app.go b/pkg/channels/wecom/app.go index 873431d3c..eb1711d75 100644 --- 
a/pkg/channels/wecom/app.go +++ b/pkg/channels/wecom/app.go @@ -120,7 +120,7 @@ func NewWeComAppChannel(cfg config.WeComAppConfig, messageBus *bus.MessageBus) ( return nil, fmt.Errorf("wecom_app corp_id, corp_secret and agent_id are required") } - base := channels.NewBaseChannel("wecom_app", cfg, messageBus, cfg.AllowFrom) + base := channels.NewBaseChannel("wecom_app", cfg, messageBus, cfg.AllowFrom, channels.WithMaxMessageLength(2048)) return &WeComAppChannel{ BaseChannel: base, diff --git a/pkg/channels/wecom/bot.go b/pkg/channels/wecom/bot.go index 3a8a16c43..bbac8611a 100644 --- a/pkg/channels/wecom/bot.go +++ b/pkg/channels/wecom/bot.go @@ -87,7 +87,7 @@ func NewWeComBotChannel(cfg config.WeComConfig, messageBus *bus.MessageBus) (*We return nil, fmt.Errorf("wecom token and webhook_url are required") } - base := channels.NewBaseChannel("wecom", cfg, messageBus, cfg.AllowFrom) + base := channels.NewBaseChannel("wecom", cfg, messageBus, cfg.AllowFrom, channels.WithMaxMessageLength(2048)) return &WeComBotChannel{ BaseChannel: base, diff --git a/pkg/channels/whatsapp/whatsapp.go b/pkg/channels/whatsapp/whatsapp.go index cbc82fd09..b5f3e99d7 100644 --- a/pkg/channels/whatsapp/whatsapp.go +++ b/pkg/channels/whatsapp/whatsapp.go @@ -28,7 +28,7 @@ type WhatsAppChannel struct { } func NewWhatsAppChannel(cfg config.WhatsAppConfig, bus *bus.MessageBus) (*WhatsAppChannel, error) { - base := channels.NewBaseChannel("whatsapp", cfg, bus, cfg.AllowFrom) + base := channels.NewBaseChannel("whatsapp", cfg, bus, cfg.AllowFrom, channels.WithMaxMessageLength(65536)) return &WhatsAppChannel{ BaseChannel: base, diff --git a/pkg/utils/message.go b/pkg/utils/message.go index 1d05950d9..52a967f4c 100644 --- a/pkg/utils/message.go +++ b/pkg/utils/message.go @@ -5,11 +5,20 @@ import ( ) // SplitMessage splits long messages into chunks, preserving code block integrity. +// The maxLen parameter is measured in runes (Unicode characters), not bytes. 
// The function reserves a buffer (10% of maxLen, min 50) to leave room for closing code blocks, // but may extend to maxLen when needed. // Call SplitMessage with the full text content and the maximum allowed length of a single message; // it returns a slice of message chunks that each respect maxLen and avoid splitting fenced code blocks. func SplitMessage(content string, maxLen int) []string { + if maxLen <= 0 { + if content == "" { + return nil + } + return []string{content} + } + + runes := []rune(content) var messages []string // Dynamic buffer: 10% of maxLen, but at least 50 chars if possible @@ -21,9 +30,9 @@ func SplitMessage(content string, maxLen int) []string { codeBlockBuffer = maxLen / 2 } - for len(content) > 0 { - if len(content) <= maxLen { - messages = append(messages, content) + for len(runes) > 0 { + if len(runes) <= maxLen { + messages = append(messages, string(runes)) break } @@ -34,56 +43,66 @@ func SplitMessage(content string, maxLen int) []string { } // Find natural split point within the effective limit - msgEnd := findLastNewline(content[:effectiveLimit], 200) + msgEnd := findLastNewlineRunes(runes[:effectiveLimit], 200) if msgEnd <= 0 { - msgEnd = findLastSpace(content[:effectiveLimit], 100) + msgEnd = findLastSpaceRunes(runes[:effectiveLimit], 100) } if msgEnd <= 0 { msgEnd = effectiveLimit } // Check if this would end with an incomplete code block - candidate := content[:msgEnd] - unclosedIdx := findLastUnclosedCodeBlock(candidate) + candidate := runes[:msgEnd] + unclosedIdx := findLastUnclosedCodeBlockRunes(candidate) if unclosedIdx >= 0 { // Message would end with incomplete code block // Try to extend up to maxLen to include the closing ``` - if len(content) > msgEnd { - closingIdx := findNextClosingCodeBlock(content, msgEnd) + if len(runes) > msgEnd { + closingIdx := findNextClosingCodeBlockRunes(runes, msgEnd) if closingIdx > 0 && closingIdx <= maxLen { // Extend to include the closing ``` msgEnd = closingIdx } else { // Code 
block is too long to fit in one chunk or missing closing fence. // Try to split inside by injecting closing and reopening fences. - headerEnd := strings.Index(content[unclosedIdx:], "\n") + candidateStr := string(candidate) + unclosedStr := string(runes[unclosedIdx:]) + headerEnd := strings.Index(unclosedStr, "\n") + var header string if headerEnd == -1 { - headerEnd = unclosedIdx + 3 + header = strings.TrimSpace(string(runes[unclosedIdx : unclosedIdx+3])) } else { - headerEnd += unclosedIdx + header = strings.TrimSpace(string(runes[unclosedIdx : unclosedIdx+headerEnd])) } - header := strings.TrimSpace(content[unclosedIdx:headerEnd]) + headerEndIdx := unclosedIdx + len([]rune(header)) + if headerEnd != -1 { + headerEndIdx = unclosedIdx + headerEnd + } + + _ = candidateStr // used above for context // If we have a reasonable amount of content after the header, split inside - if msgEnd > headerEnd+20 { + if msgEnd > headerEndIdx+20 { // Find a better split point closer to maxLen innerLimit := maxLen - 5 // Leave room for "\n```" - betterEnd := findLastNewline(content[:innerLimit], 200) - if betterEnd > headerEnd { + betterEnd := findLastNewlineRunes(runes[:innerLimit], 200) + if betterEnd > headerEndIdx { msgEnd = betterEnd } else { msgEnd = innerLimit } - messages = append(messages, strings.TrimRight(content[:msgEnd], " \t\n\r")+"\n```") - content = strings.TrimSpace(header + "\n" + content[msgEnd:]) + chunk := strings.TrimRight(string(runes[:msgEnd]), " \t\n\r") + "\n```" + messages = append(messages, chunk) + remaining := strings.TrimSpace(header + "\n" + string(runes[msgEnd:])) + runes = []rune(remaining) continue } // Otherwise, try to split before the code block starts - newEnd := findLastNewline(content[:unclosedIdx], 200) + newEnd := findLastNewlineRunes(runes[:unclosedIdx], 200) if newEnd <= 0 { - newEnd = findLastSpace(content[:unclosedIdx], 100) + newEnd = findLastSpaceRunes(runes[:unclosedIdx], 100) } if newEnd > 0 { msgEnd = newEnd @@ -93,8 +112,10 @@ 
func SplitMessage(content string, maxLen int) []string { msgEnd = unclosedIdx } else { msgEnd = maxLen - 5 - messages = append(messages, strings.TrimRight(content[:msgEnd], " \t\n\r")+"\n```") - content = strings.TrimSpace(header + "\n" + content[msgEnd:]) + chunk := strings.TrimRight(string(runes[:msgEnd]), " \t\n\r") + "\n```" + messages = append(messages, chunk) + remaining := strings.TrimSpace(header + "\n" + string(runes[msgEnd:])) + runes = []rune(remaining) continue } } @@ -106,21 +127,22 @@ func SplitMessage(content string, maxLen int) []string { msgEnd = effectiveLimit } - messages = append(messages, content[:msgEnd]) - content = strings.TrimSpace(content[msgEnd:]) + messages = append(messages, string(runes[:msgEnd])) + remaining := strings.TrimSpace(string(runes[msgEnd:])) + runes = []rune(remaining) } return messages } -// findLastUnclosedCodeBlock finds the last opening ``` that doesn't have a closing ``` -// Returns the position of the opening ``` or -1 if all code blocks are complete -func findLastUnclosedCodeBlock(text string) int { +// findLastUnclosedCodeBlockRunes finds the last opening ``` that doesn't have a closing ``` +// Returns the rune position of the opening ``` or -1 if all code blocks are complete +func findLastUnclosedCodeBlockRunes(runes []rune) int { inCodeBlock := false lastOpenIdx := -1 - for i := 0; i < len(text); i++ { - if i+2 < len(text) && text[i] == '`' && text[i+1] == '`' && text[i+2] == '`' { + for i := 0; i < len(runes); i++ { + if i+2 < len(runes) && runes[i] == '`' && runes[i+1] == '`' && runes[i+2] == '`' { // Toggle code block state on each fence if !inCodeBlock { // Entering a code block: record this opening fence @@ -137,41 +159,41 @@ func findLastUnclosedCodeBlock(text string) int { return -1 } -// findNextClosingCodeBlock finds the next closing ``` starting from a position -// Returns the position after the closing ``` or -1 if not found -func findNextClosingCodeBlock(text string, startIdx int) int { - for i := 
startIdx; i < len(text); i++ { - if i+2 < len(text) && text[i] == '`' && text[i+1] == '`' && text[i+2] == '`' { +// findNextClosingCodeBlockRunes finds the next closing ``` starting from a rune position +// Returns the rune position after the closing ``` or -1 if not found +func findNextClosingCodeBlockRunes(runes []rune, startIdx int) int { + for i := startIdx; i < len(runes); i++ { + if i+2 < len(runes) && runes[i] == '`' && runes[i+1] == '`' && runes[i+2] == '`' { return i + 3 } } return -1 } -// findLastNewline finds the last newline character within the last N characters -// Returns the position of the newline or -1 if not found -func findLastNewline(s string, searchWindow int) int { - searchStart := len(s) - searchWindow +// findLastNewlineRunes finds the last newline character within the last N runes +// Returns the rune position of the newline or -1 if not found +func findLastNewlineRunes(runes []rune, searchWindow int) int { + searchStart := len(runes) - searchWindow if searchStart < 0 { searchStart = 0 } - for i := len(s) - 1; i >= searchStart; i-- { - if s[i] == '\n' { + for i := len(runes) - 1; i >= searchStart; i-- { + if runes[i] == '\n' { return i } } return -1 } -// findLastSpace finds the last space character within the last N characters -// Returns the position of the space or -1 if not found -func findLastSpace(s string, searchWindow int) int { - searchStart := len(s) - searchWindow +// findLastSpaceRunes finds the last space character within the last N runes +// Returns the rune position of the space or -1 if not found +func findLastSpaceRunes(runes []rune, searchWindow int) int { + searchStart := len(runes) - searchWindow if searchStart < 0 { searchStart = 0 } - for i := len(s) - 1; i >= searchStart; i-- { - if s[i] == ' ' || s[i] == '\t' { + for i := len(runes) - 1; i >= searchStart; i-- { + if runes[i] == ' ' || runes[i] == '\t' { return i } } diff --git a/pkg/utils/message_test.go b/pkg/utils/message_test.go index 338509437..78e1e2b40 100644 
--- a/pkg/utils/message_test.go +++ b/pkg/utils/message_test.go @@ -34,11 +34,15 @@ func TestSplitMessage(t *testing.T) { maxLen: 2000, expectChunks: 2, checkContent: func(t *testing.T, chunks []string) { - if len(chunks[0]) > 2000 { - t.Errorf("Chunk 0 too large: %d", len(chunks[0])) + if len([]rune(chunks[0])) > 2000 { + t.Errorf("Chunk 0 too large: %d runes", len([]rune(chunks[0]))) } - if len(chunks[0])+len(chunks[1]) != len(longText) { - t.Errorf("Total length mismatch. Got %d, want %d", len(chunks[0])+len(chunks[1]), len(longText)) + if len([]rune(chunks[0]))+len([]rune(chunks[1])) != len([]rune(longText)) { + t.Errorf( + "Total rune length mismatch. Got %d, want %d", + len([]rune(chunks[0]))+len([]rune(chunks[1])), + len([]rune(longText)), + ) } }, }, @@ -53,11 +57,11 @@ func TestSplitMessage(t *testing.T) { maxLen: 2000, expectChunks: 2, checkContent: func(t *testing.T, chunks []string) { - if len(chunks[0]) != 1750 { - t.Errorf("Expected chunk 0 to be 1750 length (split at newline), got %d", len(chunks[0])) + if len([]rune(chunks[0])) != 1750 { + t.Errorf("Expected chunk 0 to be 1750 runes (split at newline), got %d", len([]rune(chunks[0]))) } if chunks[1] != strings.Repeat("b", 300) { - t.Errorf("Chunk 1 content mismatch. Len: %d", len(chunks[1])) + t.Errorf("Chunk 1 content mismatch. Len: %d", len([]rune(chunks[1]))) } }, }, @@ -78,17 +82,39 @@ func TestSplitMessage(t *testing.T) { }, }, { - name: "Preserve Unicode characters", - content: strings.Repeat("\u4e16", 1000), // 3000 bytes + name: "Preserve Unicode characters (rune-aware)", + content: strings.Repeat("\u4e16", 2500), // 2500 runes, 7500 bytes maxLen: 2000, expectChunks: 2, checkContent: func(t *testing.T, chunks []string) { - // Just verify we didn't panic and got valid strings. - // Go strings are UTF-8, if we split mid-rune it would be bad, - // but standard slicing might do that. - // Let's assume standard behavior is acceptable or check if it produces invalid rune? 
- if !strings.Contains(chunks[0], "\u4e16") { - t.Error("Chunk should contain unicode characters") + // Verify chunks contain valid unicode and don't split mid-rune + for i, chunk := range chunks { + runeCount := len([]rune(chunk)) + if runeCount > 2000 { + t.Errorf("Chunk %d has %d runes, exceeds maxLen 2000", i, runeCount) + } + if !strings.Contains(chunk, "\u4e16") { + t.Errorf("Chunk %d should contain unicode characters", i) + } + } + // Verify total rune count is preserved + totalRunes := 0 + for _, chunk := range chunks { + totalRunes += len([]rune(chunk)) + } + if totalRunes != 2500 { + t.Errorf("Total rune count mismatch. Got %d, want 2500", totalRunes) + } + }, + }, + { + name: "Zero maxLen returns single chunk", + content: "Hello world", + maxLen: 0, + expectChunks: 1, + checkContent: func(t *testing.T, chunks []string) { + if chunks[0] != "Hello world" { + t.Errorf("Expected original content, got %q", chunks[0]) } }, }, @@ -145,7 +171,7 @@ func TestSplitMessage_CodeBlockIntegrity(t *testing.T) { } // First chunk should contain meaningful content - if len(chunks[0]) > 40 { - t.Errorf("First chunk exceeded maxLen: length %d", len(chunks[0])) + if len([]rune(chunks[0])) > 40 { + t.Errorf("First chunk exceeded maxLen: length %d runes", len([]rune(chunks[0]))) } } From 8116bcb6bc7ae7cb8215d95708efc59b34505248 Mon Sep 17 00:00:00 2001 From: Hoshina Date: Sun, 22 Feb 2026 23:27:55 +0800 Subject: [PATCH 11/28] refactor(media): add MediaStore for unified media file lifecycle management Channels previously deleted downloaded media files via defer os.Remove, racing with the async Agent consumer. Introduce MediaStore to decouple file ownership: channels register files on download, Agent releases them after processing via ReleaseAll(scope). 
- New pkg/media with MediaStore interface + FileMediaStore implementation - InboundMessage gains MediaScope field for lifecycle tracking - BaseChannel gains SetMediaStore/GetMediaStore + BuildMediaScope helper - Manager injects MediaStore into channels; AgentLoop releases on completion - Telegram, Discord, Slack, OneBot, LINE channels migrated from defer os.Remove to store.Store() with media:// refs --- cmd/picoclaw/cmd_gateway.go | 9 +- pkg/agent/loop.go | 65 +++++++---- pkg/bus/types.go | 5 +- pkg/channels/base.go | 38 +++++-- pkg/channels/discord/discord.go | 30 ++--- pkg/channels/line/line.go | 35 +++--- pkg/channels/manager.go | 19 +++- pkg/channels/onebot/onebot.go | 66 ++++++----- pkg/channels/slack/slack.go | 30 ++--- pkg/channels/telegram/telegram.go | 41 +++---- pkg/media/store.go | 102 +++++++++++++++++ pkg/media/store_test.go | 179 ++++++++++++++++++++++++++++++ 12 files changed, 488 insertions(+), 131 deletions(-) create mode 100644 pkg/media/store.go create mode 100644 pkg/media/store_test.go diff --git a/cmd/picoclaw/cmd_gateway.go b/cmd/picoclaw/cmd_gateway.go index c62c868e3..3c2cb021d 100644 --- a/cmd/picoclaw/cmd_gateway.go +++ b/cmd/picoclaw/cmd_gateway.go @@ -33,6 +33,7 @@ import ( "github.com/sipeed/picoclaw/pkg/health" "github.com/sipeed/picoclaw/pkg/heartbeat" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/state" "github.com/sipeed/picoclaw/pkg/tools" @@ -123,14 +124,18 @@ func gatewayCmd() { return tools.SilentResult(response) }) - channelManager, err := channels.NewManager(cfg, msgBus) + // Create media store for file lifecycle management + mediaStore := media.NewFileMediaStore() + + channelManager, err := channels.NewManager(cfg, msgBus, mediaStore) if err != nil { fmt.Printf("Error creating channel manager: %v\n", err) os.Exit(1) } - // Inject channel manager into agent loop for command handling + // Inject channel manager and 
media store into agent loop agentLoop.SetChannelManager(channelManager) + agentLoop.SetMediaStore(mediaStore) var transcriber *voice.GroqTranscriber groqAPIKey := cfg.Providers.Groq.APIKey diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index d8ea3b091..97569bef7 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -21,6 +21,7 @@ import ( "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/constants" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/routing" "github.com/sipeed/picoclaw/pkg/skills" @@ -38,6 +39,7 @@ type AgentLoop struct { summarizing sync.Map fallback *providers.FallbackChain channelManager *channels.Manager + mediaStore media.MediaStore } // processOptions configures how a message is processed @@ -167,33 +169,47 @@ func (al *AgentLoop) Run(ctx context.Context) error { continue } - response, err := al.processMessage(ctx, msg) - if err != nil { - response = fmt.Sprintf("Error processing message: %v", err) - } - - if response != "" { - // Check if the message tool already sent a response during this round. - // If so, skip publishing to avoid duplicate messages to the user. - // Use default agent's tools to check (message tool is shared). 
- alreadySent := false - defaultAgent := al.registry.GetDefaultAgent() - if defaultAgent != nil { - if tool, ok := defaultAgent.Tools.Get("message"); ok { - if mt, ok := tool.(*tools.MessageTool); ok { - alreadySent = mt.HasSentInRound() + // Process message and ensure media is released afterward + func() { + defer func() { + if al.mediaStore != nil && msg.MediaScope != "" { + if releaseErr := al.mediaStore.ReleaseAll(msg.MediaScope); releaseErr != nil { + logger.WarnCF("agent", "Failed to release media", map[string]any{ + "scope": msg.MediaScope, + "error": releaseErr.Error(), + }) } } + }() + + response, err := al.processMessage(ctx, msg) + if err != nil { + response = fmt.Sprintf("Error processing message: %v", err) } - if !alreadySent { - al.bus.PublishOutbound(bus.OutboundMessage{ - Channel: msg.Channel, - ChatID: msg.ChatID, - Content: response, - }) + if response != "" { + // Check if the message tool already sent a response during this round. + // If so, skip publishing to avoid duplicate messages to the user. + // Use default agent's tools to check (message tool is shared). + alreadySent := false + defaultAgent := al.registry.GetDefaultAgent() + if defaultAgent != nil { + if tool, ok := defaultAgent.Tools.Get("message"); ok { + if mt, ok := tool.(*tools.MessageTool); ok { + alreadySent = mt.HasSentInRound() + } + } + } + + if !alreadySent { + al.bus.PublishOutbound(bus.OutboundMessage{ + Channel: msg.Channel, + ChatID: msg.ChatID, + Content: response, + }) + } } - } + }() } } @@ -216,6 +232,11 @@ func (al *AgentLoop) SetChannelManager(cm *channels.Manager) { al.channelManager = cm } +// SetMediaStore injects a MediaStore for media lifecycle management. +func (al *AgentLoop) SetMediaStore(s media.MediaStore) { + al.mediaStore = s +} + // RecordLastChannel records the last active channel for this workspace. // This uses the atomic state save mechanism to prevent data loss on crash. 
func (al *AgentLoop) RecordLastChannel(channel string) error { diff --git a/pkg/bus/types.go b/pkg/bus/types.go index 081f13a0b..e49713eb8 100644 --- a/pkg/bus/types.go +++ b/pkg/bus/types.go @@ -12,8 +12,9 @@ type InboundMessage struct { ChatID string `json:"chat_id"` Content string `json:"content"` Media []string `json:"media,omitempty"` - Peer Peer `json:"peer"` // routing peer - MessageID string `json:"message_id,omitempty"` // platform message ID + Peer Peer `json:"peer"` // routing peer + MessageID string `json:"message_id,omitempty"` // platform message ID + MediaScope string `json:"media_scope,omitempty"` // media lifecycle scope SessionKey string `json:"session_key"` Metadata map[string]string `json:"metadata,omitempty"` } diff --git a/pkg/channels/base.go b/pkg/channels/base.go index f70145981..d967d9e91 100644 --- a/pkg/channels/base.go +++ b/pkg/channels/base.go @@ -5,7 +5,10 @@ import ( "strings" "sync/atomic" + "github.com/google/uuid" + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/media" ) type Channel interface { @@ -41,6 +44,7 @@ type BaseChannel struct { name string allowList []string maxMessageLength int + mediaStore media.MediaStore } func NewBaseChannel( @@ -125,15 +129,18 @@ func (c *BaseChannel) HandleMessage( return } + scope := BuildMediaScope(c.name, chatID, messageID) + msg := bus.InboundMessage{ - Channel: c.name, - SenderID: senderID, - ChatID: chatID, - Content: content, - Media: media, - Peer: peer, - MessageID: messageID, - Metadata: metadata, + Channel: c.name, + SenderID: senderID, + ChatID: chatID, + Content: content, + Media: media, + Peer: peer, + MessageID: messageID, + MediaScope: scope, + Metadata: metadata, } c.bus.PublishInbound(msg) @@ -142,3 +149,18 @@ func (c *BaseChannel) HandleMessage( func (c *BaseChannel) SetRunning(running bool) { c.running.Store(running) } + +// SetMediaStore injects a MediaStore into the channel. 
+func (c *BaseChannel) SetMediaStore(s media.MediaStore) { c.mediaStore = s } + +// GetMediaStore returns the injected MediaStore (may be nil). +func (c *BaseChannel) GetMediaStore() media.MediaStore { return c.mediaStore } + +// BuildMediaScope constructs a scope key for media lifecycle tracking. +func BuildMediaScope(channel, chatID, messageID string) string { + id := messageID + if id == "" { + id = uuid.New().String() + } + return channel + ":" + chatID + ":" + id +} diff --git a/pkg/channels/discord/discord.go b/pkg/channels/discord/discord.go index 623bc9f48..7977d32e1 100644 --- a/pkg/channels/discord/discord.go +++ b/pkg/channels/discord/discord.go @@ -3,7 +3,6 @@ package discord import ( "context" "fmt" - "os" "strings" "sync" "time" @@ -14,6 +13,7 @@ import ( "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/utils" "github.com/sipeed/picoclaw/pkg/voice" ) @@ -202,19 +202,22 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag content := m.Content content = c.stripBotMention(content) mediaPaths := make([]string, 0, len(m.Attachments)) - localFiles := make([]string, 0, len(m.Attachments)) - - // Ensure temp files are cleaned up when function returns - defer func() { - for _, file := range localFiles { - if err := os.Remove(file); err != nil { - logger.DebugCF("discord", "Failed to cleanup temp file", map[string]any{ - "file": file, - "error": err.Error(), - }) + + scope := channels.BuildMediaScope("discord", m.ChannelID, m.ID) + + // Helper to register a local file with the media store + storeMedia := func(localPath, filename string) string { + if store := c.GetMediaStore(); store != nil { + ref, err := store.Store(localPath, media.MediaMeta{ + Filename: filename, + Source: "discord", + }, scope) + if err == nil { + return ref } } - }() + return localPath // fallback + } for _, 
attachment := range m.Attachments { isAudio := utils.IsAudioFile(attachment.Filename, attachment.ContentType) @@ -222,8 +225,6 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag if isAudio { localPath := c.downloadAttachment(attachment.URL, attachment.Filename) if localPath != "" { - localFiles = append(localFiles, localPath) - transcribedText := "" if c.transcriber != nil && c.transcriber.IsAvailable() { ctx, cancel := context.WithTimeout(c.ctx, transcriptionTimeout) @@ -245,6 +246,7 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag transcribedText = fmt.Sprintf("[audio: %s]", attachment.Filename) } + mediaPaths = append(mediaPaths, storeMedia(localPath, attachment.Filename)) content = appendContent(content, transcribedText) } else { logger.WarnCF("discord", "Failed to download audio attachment", map[string]any{ diff --git a/pkg/channels/line/line.go b/pkg/channels/line/line.go index 9744e1848..272a53c6e 100644 --- a/pkg/channels/line/line.go +++ b/pkg/channels/line/line.go @@ -10,7 +10,6 @@ import ( "fmt" "io" "net/http" - "os" "strings" "sync" "time" @@ -19,6 +18,7 @@ import ( "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/utils" ) @@ -308,18 +308,22 @@ func (c *LINEChannel) processEvent(event lineEvent) { var content string var mediaPaths []string - localFiles := []string{} - - defer func() { - for _, file := range localFiles { - if err := os.Remove(file); err != nil { - logger.DebugCF("line", "Failed to cleanup temp file", map[string]any{ - "file": file, - "error": err.Error(), - }) + + scope := channels.BuildMediaScope("line", chatID, msg.ID) + + // Helper to register a local file with the media store + storeMedia := func(localPath, filename string) string { + if store := c.GetMediaStore(); store != nil { + ref, err := store.Store(localPath, 
media.MediaMeta{ + Filename: filename, + Source: "line", + }, scope) + if err == nil { + return ref } } - }() + return localPath // fallback + } switch msg.Type { case "text": @@ -331,22 +335,19 @@ func (c *LINEChannel) processEvent(event lineEvent) { case "image": localPath := c.downloadContent(msg.ID, "image.jpg") if localPath != "" { - localFiles = append(localFiles, localPath) - mediaPaths = append(mediaPaths, localPath) + mediaPaths = append(mediaPaths, storeMedia(localPath, "image.jpg")) content = "[image]" } case "audio": localPath := c.downloadContent(msg.ID, "audio.m4a") if localPath != "" { - localFiles = append(localFiles, localPath) - mediaPaths = append(mediaPaths, localPath) + mediaPaths = append(mediaPaths, storeMedia(localPath, "audio.m4a")) content = "[audio]" } case "video": localPath := c.downloadContent(msg.ID, "video.mp4") if localPath != "" { - localFiles = append(localFiles, localPath) - mediaPaths = append(mediaPaths, localPath) + mediaPaths = append(mediaPaths, storeMedia(localPath, "video.mp4")) content = "[video]" } case "file": diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go index 081d616da..37af01796 100644 --- a/pkg/channels/manager.go +++ b/pkg/channels/manager.go @@ -15,6 +15,7 @@ import ( "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/constants" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/utils" ) @@ -31,6 +32,7 @@ type Manager struct { workers map[string]*channelWorker bus *bus.MessageBus config *config.Config + mediaStore media.MediaStore dispatchTask *asyncTask mu sync.RWMutex } @@ -39,12 +41,13 @@ type asyncTask struct { cancel context.CancelFunc } -func NewManager(cfg *config.Config, messageBus *bus.MessageBus) (*Manager, error) { +func NewManager(cfg *config.Config, messageBus *bus.MessageBus, store media.MediaStore) (*Manager, error) { m := &Manager{ - channels: make(map[string]Channel), - workers: 
make(map[string]*channelWorker), - bus: messageBus, - config: cfg, + channels: make(map[string]Channel), + workers: make(map[string]*channelWorker), + bus: messageBus, + config: cfg, + mediaStore: store, } if err := m.initChannels(); err != nil { @@ -73,6 +76,12 @@ func (m *Manager) initChannel(name, displayName string) { "error": err.Error(), }) } else { + // Inject MediaStore if channel supports it + if m.mediaStore != nil { + if setter, ok := ch.(interface{ SetMediaStore(s media.MediaStore) }); ok { + setter.SetMediaStore(m.mediaStore) + } + } m.channels[name] = ch m.workers[name] = &channelWorker{ ch: ch, diff --git a/pkg/channels/onebot/onebot.go b/pkg/channels/onebot/onebot.go index 4f35888ca..e2fe541f1 100644 --- a/pkg/channels/onebot/onebot.go +++ b/pkg/channels/onebot/onebot.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "fmt" - "os" "strconv" "strings" "sync" @@ -17,6 +16,7 @@ import ( "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/utils" "github.com/sipeed/picoclaw/pkg/voice" ) @@ -575,11 +575,15 @@ type parseMessageResult struct { Text string IsBotMentioned bool Media []string - LocalFiles []string ReplyTo string } -func (c *OneBotChannel) parseMessageSegments(raw json.RawMessage, selfID int64) parseMessageResult { +func (c *OneBotChannel) parseMessageSegments( + raw json.RawMessage, + selfID int64, + store media.MediaStore, + scope string, +) parseMessageResult { if len(raw) == 0 { return parseMessageResult{} } @@ -606,10 +610,23 @@ func (c *OneBotChannel) parseMessageSegments(raw json.RawMessage, selfID int64) var textParts []string mentioned := false selfIDStr := strconv.FormatInt(selfID, 10) - var media []string - var localFiles []string + var mediaRefs []string var replyTo string + // Helper to register a local file with the media store + storeFile := func(localPath, filename string) string { + 
if store != nil { + ref, err := store.Store(localPath, media.MediaMeta{ + Filename: filename, + Source: "onebot", + }, scope) + if err == nil { + return ref + } + } + return localPath // fallback + } + for _, seg := range segments { segType, _ := seg["type"].(string) data, _ := seg["data"].(map[string]any) @@ -645,8 +662,7 @@ func (c *OneBotChannel) parseMessageSegments(raw json.RawMessage, selfID int64) LoggerPrefix: "onebot", }) if localPath != "" { - media = append(media, localPath) - localFiles = append(localFiles, localPath) + mediaRefs = append(mediaRefs, storeFile(localPath, filename)) textParts = append(textParts, fmt.Sprintf("[%s]", segType)) } } @@ -660,7 +676,6 @@ func (c *OneBotChannel) parseMessageSegments(raw json.RawMessage, selfID int64) LoggerPrefix: "onebot", }) if localPath != "" { - localFiles = append(localFiles, localPath) if c.transcriber != nil && c.transcriber.IsAvailable() { tctx, tcancel := context.WithTimeout(c.ctx, 30*time.Second) result, err := c.transcriber.Transcribe(tctx, localPath) @@ -670,13 +685,15 @@ func (c *OneBotChannel) parseMessageSegments(raw json.RawMessage, selfID int64) "error": err.Error(), }) textParts = append(textParts, "[voice (transcription failed)]") - media = append(media, localPath) + mediaRefs = append(mediaRefs, storeFile(localPath, "voice.amr")) } else { textParts = append(textParts, fmt.Sprintf("[voice transcription: %s]", result.Text)) + // Still store the file so it can be released later + storeFile(localPath, "voice.amr") } } else { textParts = append(textParts, "[voice]") - media = append(media, localPath) + mediaRefs = append(mediaRefs, storeFile(localPath, "voice.amr")) } } } @@ -706,8 +723,7 @@ func (c *OneBotChannel) parseMessageSegments(raw json.RawMessage, selfID int64) return parseMessageResult{ Text: strings.TrimSpace(strings.Join(textParts, "")), IsBotMentioned: mentioned, - Media: media, - LocalFiles: localFiles, + Media: mediaRefs, ReplyTo: replyTo, } } @@ -799,7 +815,17 @@ func (c 
*OneBotChannel) handleMessage(raw *oneBotRawEvent) { selfID = atomic.LoadInt64(&c.selfID) } - parsed := c.parseMessageSegments(raw.Message, selfID) + // Compute scope for media store before parsing (parsing may download files) + var chatIDForScope string + switch raw.MessageType { + case "group": + chatIDForScope = "group:" + strconv.FormatInt(groupID, 10) + default: + chatIDForScope = "private:" + strconv.FormatInt(userID, 10) + } + scope := channels.BuildMediaScope("onebot", chatIDForScope, messageID) + + parsed := c.parseMessageSegments(raw.Message, selfID, c.GetMediaStore(), scope) isBotMentioned := parsed.IsBotMentioned content := raw.RawMessage @@ -828,20 +854,6 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) { } } - // Clean up temp files when done - if len(parsed.LocalFiles) > 0 { - defer func() { - for _, f := range parsed.LocalFiles { - if err := os.Remove(f); err != nil { - logger.DebugCF("onebot", "Failed to remove temp file", map[string]any{ - "path": f, - "error": err.Error(), - }) - } - } - }() - } - if c.isDuplicate(messageID) { logger.DebugCF("onebot", "Duplicate message, skipping", map[string]any{ "message_id": messageID, diff --git a/pkg/channels/slack/slack.go b/pkg/channels/slack/slack.go index fc0bee505..53d7c0609 100644 --- a/pkg/channels/slack/slack.go +++ b/pkg/channels/slack/slack.go @@ -3,7 +3,6 @@ package slack import ( "context" "fmt" - "os" "strings" "sync" "time" @@ -16,6 +15,7 @@ import ( "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/utils" "github.com/sipeed/picoclaw/pkg/voice" ) @@ -233,19 +233,22 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { content = c.stripBotMention(content) var mediaPaths []string - localFiles := []string{} // 跟踪需要清理的本地文件 - - // 确保临时文件在函数返回时被清理 - defer func() { - for _, file := range localFiles { - if err := 
os.Remove(file); err != nil { - logger.DebugCF("slack", "Failed to cleanup temp file", map[string]any{ - "file": file, - "error": err.Error(), - }) + + scope := channels.BuildMediaScope("slack", chatID, messageTS) + + // Helper to register a local file with the media store + storeMedia := func(localPath, filename string) string { + if store := c.GetMediaStore(); store != nil { + ref, err := store.Store(localPath, media.MediaMeta{ + Filename: filename, + Source: "slack", + }, scope) + if err == nil { + return ref } } - }() + return localPath // fallback + } if ev.Message != nil && len(ev.Message.Files) > 0 { for _, file := range ev.Message.Files { @@ -253,8 +256,7 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { if localPath == "" { continue } - localFiles = append(localFiles, localPath) - mediaPaths = append(mediaPaths, localPath) + mediaPaths = append(mediaPaths, storeMedia(localPath, file.Name)) if utils.IsAudioFile(file.Name, file.Mimetype) && c.transcriber != nil && c.transcriber.IsAvailable() { ctx, cancel := context.WithTimeout(c.ctx, 30*time.Second) diff --git a/pkg/channels/telegram/telegram.go b/pkg/channels/telegram/telegram.go index 578e3c51e..af7155799 100644 --- a/pkg/channels/telegram/telegram.go +++ b/pkg/channels/telegram/telegram.go @@ -20,6 +20,7 @@ import ( "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/utils" "github.com/sipeed/picoclaw/pkg/voice" ) @@ -251,19 +252,24 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes content := "" mediaPaths := []string{} - localFiles := []string{} // 跟踪需要清理的本地文件 - - // 确保临时文件在函数返回时被清理 - defer func() { - for _, file := range localFiles { - if err := os.Remove(file); err != nil { - logger.DebugCF("telegram", "Failed to cleanup temp file", map[string]any{ - "file": file, - "error": err.Error(), - }) + + 
chatIDStr := fmt.Sprintf("%d", chatID) + messageIDStr := fmt.Sprintf("%d", message.MessageID) + scope := channels.BuildMediaScope("telegram", chatIDStr, messageIDStr) + + // Helper to register a local file with the media store + storeMedia := func(localPath, filename string) string { + if store := c.GetMediaStore(); store != nil { + ref, err := store.Store(localPath, media.MediaMeta{ + Filename: filename, + Source: "telegram", + }, scope) + if err == nil { + return ref } } - }() + return localPath // fallback: use raw path + } if message.Text != "" { content += message.Text @@ -280,8 +286,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes photo := message.Photo[len(message.Photo)-1] photoPath := c.downloadPhoto(ctx, photo.FileID) if photoPath != "" { - localFiles = append(localFiles, photoPath) - mediaPaths = append(mediaPaths, photoPath) + mediaPaths = append(mediaPaths, storeMedia(photoPath, "photo.jpg")) if content != "" { content += "\n" } @@ -292,8 +297,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes if message.Voice != nil { voicePath := c.downloadFile(ctx, message.Voice.FileID, ".ogg") if voicePath != "" { - localFiles = append(localFiles, voicePath) - mediaPaths = append(mediaPaths, voicePath) + mediaPaths = append(mediaPaths, storeMedia(voicePath, "voice.ogg")) transcribedText := "" if c.transcriber != nil && c.transcriber.IsAvailable() { @@ -327,8 +331,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes if message.Audio != nil { audioPath := c.downloadFile(ctx, message.Audio.FileID, ".mp3") if audioPath != "" { - localFiles = append(localFiles, audioPath) - mediaPaths = append(mediaPaths, audioPath) + mediaPaths = append(mediaPaths, storeMedia(audioPath, "audio.mp3")) if content != "" { content += "\n" } @@ -339,8 +342,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes if message.Document != nil { docPath := 
c.downloadFile(ctx, message.Document.FileID, "") if docPath != "" { - localFiles = append(localFiles, docPath) - mediaPaths = append(mediaPaths, docPath) + mediaPaths = append(mediaPaths, storeMedia(docPath, "document")) if content != "" { content += "\n" } @@ -367,7 +369,6 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes } // Stop any previous thinking animation - chatIDStr := fmt.Sprintf("%d", chatID) if prevStop, ok := c.stopThinking.Load(chatIDStr); ok { if cf, ok := prevStop.(*thinkingCancel); ok && cf != nil { cf.Cancel() diff --git a/pkg/media/store.go b/pkg/media/store.go new file mode 100644 index 000000000..8d03c03ef --- /dev/null +++ b/pkg/media/store.go @@ -0,0 +1,102 @@ +package media + +import ( + "fmt" + "os" + "sync" + + "github.com/google/uuid" +) + +// MediaMeta holds metadata about a stored media file. +type MediaMeta struct { + Filename string + ContentType string + Source string // "telegram", "discord", "tool:image-gen", etc. +} + +// MediaStore manages the lifecycle of media files associated with processing scopes. +type MediaStore interface { + // Store registers an existing local file under the given scope. + // Returns a ref identifier (e.g. "media://"). + // Store does not move or copy the file; it only records the mapping. + Store(localPath string, meta MediaMeta, scope string) (ref string, err error) + + // Resolve returns the local file path for a given ref. + Resolve(ref string) (localPath string, err error) + + // ReleaseAll deletes all files registered under the given scope + // and removes the mapping entries. File-not-exist errors are ignored. + ReleaseAll(scope string) error +} + +// FileMediaStore is a pure in-memory implementation of MediaStore. +// Files are expected to already exist on disk (e.g. in /tmp/picoclaw_media/). 
+type FileMediaStore struct { + mu sync.RWMutex + refToPath map[string]string + scopeToRefs map[string]map[string]struct{} +} + +// NewFileMediaStore creates a new FileMediaStore. +func NewFileMediaStore() *FileMediaStore { + return &FileMediaStore{ + refToPath: make(map[string]string), + scopeToRefs: make(map[string]map[string]struct{}), + } +} + +// Store registers a local file under the given scope. The file must exist. +func (s *FileMediaStore) Store(localPath string, meta MediaMeta, scope string) (string, error) { + if _, err := os.Stat(localPath); err != nil { + return "", fmt.Errorf("media store: file does not exist: %s", localPath) + } + + ref := "media://" + uuid.New().String()[:8] + + s.mu.Lock() + defer s.mu.Unlock() + + s.refToPath[ref] = localPath + if s.scopeToRefs[scope] == nil { + s.scopeToRefs[scope] = make(map[string]struct{}) + } + s.scopeToRefs[scope][ref] = struct{}{} + + return ref, nil +} + +// Resolve returns the local path for the given ref. +func (s *FileMediaStore) Resolve(ref string) (string, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + path, ok := s.refToPath[ref] + if !ok { + return "", fmt.Errorf("media store: unknown ref: %s", ref) + } + return path, nil +} + +// ReleaseAll removes all files under the given scope and cleans up mappings. 
+func (s *FileMediaStore) ReleaseAll(scope string) error { + s.mu.Lock() + defer s.mu.Unlock() + + refs, ok := s.scopeToRefs[scope] + if !ok { + return nil + } + + for ref := range refs { + if path, exists := s.refToPath[ref]; exists { + if err := os.Remove(path); err != nil && !os.IsNotExist(err) { + // Log but continue — best effort cleanup + } + delete(s.refToPath, ref) + } + } + + delete(s.scopeToRefs, scope) + return nil +} diff --git a/pkg/media/store_test.go b/pkg/media/store_test.go new file mode 100644 index 000000000..361582307 --- /dev/null +++ b/pkg/media/store_test.go @@ -0,0 +1,179 @@ +package media + +import ( + "os" + "path/filepath" + "strings" + "sync" + "testing" +) + +func createTempFile(t *testing.T, dir, name string) string { + t.Helper() + path := filepath.Join(dir, name) + if err := os.WriteFile(path, []byte("test content"), 0o644); err != nil { + t.Fatalf("failed to create temp file: %v", err) + } + return path +} + +func TestStoreAndResolve(t *testing.T) { + dir := t.TempDir() + store := NewFileMediaStore() + + path := createTempFile(t, dir, "photo.jpg") + + ref, err := store.Store(path, MediaMeta{Filename: "photo.jpg", Source: "telegram"}, "scope1") + if err != nil { + t.Fatalf("Store failed: %v", err) + } + + if !strings.HasPrefix(ref, "media://") { + t.Errorf("ref should start with media://, got %q", ref) + } + + resolved, err := store.Resolve(ref) + if err != nil { + t.Fatalf("Resolve failed: %v", err) + } + if resolved != path { + t.Errorf("Resolve returned %q, want %q", resolved, path) + } +} + +func TestReleaseAll(t *testing.T) { + dir := t.TempDir() + store := NewFileMediaStore() + + paths := make([]string, 3) + refs := make([]string, 3) + for i := 0; i < 3; i++ { + paths[i] = createTempFile(t, dir, strings.Repeat("a", i+1)+".jpg") + var err error + refs[i], err = store.Store(paths[i], MediaMeta{Source: "test"}, "scope1") + if err != nil { + t.Fatalf("Store failed: %v", err) + } + } + + if err := store.ReleaseAll("scope1"); err != 
nil { + t.Fatalf("ReleaseAll failed: %v", err) + } + + // Files should be deleted + for _, p := range paths { + if _, err := os.Stat(p); !os.IsNotExist(err) { + t.Errorf("file %q should have been deleted", p) + } + } + + // Refs should be unresolvable + for _, ref := range refs { + if _, err := store.Resolve(ref); err == nil { + t.Errorf("Resolve(%q) should fail after ReleaseAll", ref) + } + } +} + +func TestMultiScopeIsolation(t *testing.T) { + dir := t.TempDir() + store := NewFileMediaStore() + + pathA := createTempFile(t, dir, "fileA.jpg") + pathB := createTempFile(t, dir, "fileB.jpg") + + refA, _ := store.Store(pathA, MediaMeta{Source: "test"}, "scopeA") + refB, _ := store.Store(pathB, MediaMeta{Source: "test"}, "scopeB") + + // Release only scopeA + if err := store.ReleaseAll("scopeA"); err != nil { + t.Fatalf("ReleaseAll(scopeA) failed: %v", err) + } + + // scopeA file should be gone + if _, err := os.Stat(pathA); !os.IsNotExist(err) { + t.Error("file A should have been deleted") + } + if _, err := store.Resolve(refA); err == nil { + t.Error("refA should be unresolvable after release") + } + + // scopeB file should still exist + if _, err := os.Stat(pathB); err != nil { + t.Error("file B should still exist") + } + resolved, err := store.Resolve(refB) + if err != nil { + t.Fatalf("refB should still resolve: %v", err) + } + if resolved != pathB { + t.Errorf("resolved %q, want %q", resolved, pathB) + } +} + +func TestReleaseAllIdempotent(t *testing.T) { + store := NewFileMediaStore() + + // ReleaseAll on non-existent scope should not error + if err := store.ReleaseAll("nonexistent"); err != nil { + t.Fatalf("ReleaseAll on empty scope should not error: %v", err) + } + + // Create and release, then release again + dir := t.TempDir() + path := createTempFile(t, dir, "file.jpg") + _, _ = store.Store(path, MediaMeta{Source: "test"}, "scope1") + + if err := store.ReleaseAll("scope1"); err != nil { + t.Fatalf("first ReleaseAll failed: %v", err) + } + if err := 
store.ReleaseAll("scope1"); err != nil { + t.Fatalf("second ReleaseAll should not error: %v", err) + } +} + +func TestStoreNonexistentFile(t *testing.T) { + store := NewFileMediaStore() + + _, err := store.Store("/nonexistent/path/file.jpg", MediaMeta{Source: "test"}, "scope1") + if err == nil { + t.Error("Store should fail for nonexistent file") + } +} + +func TestConcurrentSafety(t *testing.T) { + dir := t.TempDir() + store := NewFileMediaStore() + + const goroutines = 20 + const filesPerGoroutine = 5 + + var wg sync.WaitGroup + wg.Add(goroutines) + + for g := 0; g < goroutines; g++ { + go func(gIdx int) { + defer wg.Done() + scope := strings.Repeat("s", gIdx+1) + + for i := 0; i < filesPerGoroutine; i++ { + path := createTempFile(t, dir, strings.Repeat("f", gIdx*filesPerGoroutine+i+1)+".tmp") + ref, err := store.Store(path, MediaMeta{Source: "test"}, scope) + if err != nil { + t.Errorf("Store failed: %v", err) + return + } + + if _, err := store.Resolve(ref); err != nil { + t.Errorf("Resolve failed: %v", err) + } + } + + if err := store.ReleaseAll(scope); err != nil { + t.Errorf("ReleaseAll failed: %v", err) + } + }(g) + } + + wg.Wait() +} From a32d98534c0d98a134cb5cb82fe0f1aae3377783 Mon Sep 17 00:00:00 2001 From: Hoshina Date: Sun, 22 Feb 2026 23:51:55 +0800 Subject: [PATCH 12/28] refactor(channels): add per-channel rate limiting and send retry with error classification MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Define sentinel error types (ErrNotRunning, ErrRateLimit, ErrTemporary, ErrSendFailed) so the Manager can classify Send failures and choose the right retry strategy: permanent errors bail immediately, rate-limit errors use a fixed 1s delay, and temporary/unknown errors use exponential backoff (500ms→1s→2s, capped at 8s, up to 3 retries). A per-channel token-bucket rate limiter (golang.org/x/time/rate) throttles outbound sends before they hit the platform API. 
--- go.mod | 1 + go.sum | 2 + pkg/channels/errors.go | 21 ++ pkg/channels/errors_test.go | 56 +++++ pkg/channels/manager.go | 127 +++++++++-- pkg/channels/manager_test.go | 418 +++++++++++++++++++++++++++++++++++ 6 files changed, 601 insertions(+), 24 deletions(-) create mode 100644 pkg/channels/errors.go create mode 100644 pkg/channels/errors_test.go create mode 100644 pkg/channels/manager_test.go diff --git a/go.mod b/go.mod index 1f88639c8..32436ce53 100644 --- a/go.mod +++ b/go.mod @@ -23,6 +23,7 @@ require ( require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/time v0.14.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 0e95bf5cd..0e2d37cab 100644 --- a/go.sum +++ b/go.sum @@ -226,6 +226,8 @@ golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= +golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= +golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= diff --git a/pkg/channels/errors.go b/pkg/channels/errors.go new file mode 100644 index 000000000..09ee88b3f --- /dev/null +++ b/pkg/channels/errors.go @@ -0,0 +1,21 @@ +package channels + +import "errors" + +var ( + // ErrNotRunning indicates the channel is not running. + // Manager will not retry. 
+ ErrNotRunning = errors.New("channel not running") + + // ErrRateLimit indicates the platform returned a rate-limit response (e.g. HTTP 429). + // Manager will wait a fixed delay and retry. + ErrRateLimit = errors.New("rate limited") + + // ErrTemporary indicates a transient failure (e.g. network timeout, 5xx). + // Manager will use exponential backoff and retry. + ErrTemporary = errors.New("temporary failure") + + // ErrSendFailed indicates a permanent failure (e.g. invalid chat ID, 4xx non-429). + // Manager will not retry. + ErrSendFailed = errors.New("send failed") +) diff --git a/pkg/channels/errors_test.go b/pkg/channels/errors_test.go new file mode 100644 index 000000000..e5592345a --- /dev/null +++ b/pkg/channels/errors_test.go @@ -0,0 +1,56 @@ +package channels + +import ( + "errors" + "fmt" + "testing" +) + +func TestErrorsIs(t *testing.T) { + wrapped := fmt.Errorf("telegram API: %w", ErrRateLimit) + if !errors.Is(wrapped, ErrRateLimit) { + t.Error("wrapped ErrRateLimit should match") + } + if errors.Is(wrapped, ErrTemporary) { + t.Error("wrapped ErrRateLimit should not match ErrTemporary") + } +} + +func TestErrorsIsAllTypes(t *testing.T) { + sentinels := []error{ErrNotRunning, ErrRateLimit, ErrTemporary, ErrSendFailed} + + for _, sentinel := range sentinels { + wrapped := fmt.Errorf("context: %w", sentinel) + if !errors.Is(wrapped, sentinel) { + t.Errorf("wrapped %v should match itself", sentinel) + } + + // Verify it doesn't match other sentinel errors + for _, other := range sentinels { + if other == sentinel { + continue + } + if errors.Is(wrapped, other) { + t.Errorf("wrapped %v should not match %v", sentinel, other) + } + } + } +} + +func TestErrorMessages(t *testing.T) { + tests := []struct { + err error + want string + }{ + {ErrNotRunning, "channel not running"}, + {ErrRateLimit, "rate limited"}, + {ErrTemporary, "temporary failure"}, + {ErrSendFailed, "send failed"}, + } + + for _, tt := range tests { + if got := tt.err.Error(); got != tt.want 
{ + t.Errorf("error message = %q, want %q", got, tt.want) + } + } +} diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go index 37af01796..1bc321cec 100644 --- a/pkg/channels/manager.go +++ b/pkg/channels/manager.go @@ -8,8 +8,13 @@ package channels import ( "context" + "errors" "fmt" + "math" "sync" + "time" + + "golang.org/x/time/rate" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" @@ -19,12 +24,28 @@ import ( "github.com/sipeed/picoclaw/pkg/utils" ) -const defaultChannelQueueSize = 100 +const ( + defaultChannelQueueSize = 100 + defaultRateLimit = 10 // default 10 msg/s + maxRetries = 3 + rateLimitDelay = 1 * time.Second + baseBackoff = 500 * time.Millisecond + maxBackoff = 8 * time.Second +) + +// channelRateConfig maps channel name to per-second rate limit. +var channelRateConfig = map[string]float64{ + "telegram": 20, + "discord": 1, + "slack": 1, + "line": 10, +} type channelWorker struct { - ch Channel - queue chan bus.OutboundMessage - done chan struct{} + ch Channel + queue chan bus.OutboundMessage + done chan struct{} + limiter *rate.Limiter } type Manager struct { @@ -83,11 +104,7 @@ func (m *Manager) initChannel(name, displayName string) { } } m.channels[name] = ch - m.workers[name] = &channelWorker{ - ch: ch, - queue: make(chan bus.OutboundMessage, defaultChannelQueueSize), - done: make(chan struct{}), - } + m.workers[name] = newChannelWorker(name, ch) logger.InfoCF("channels", "Channel enabled successfully", map[string]any{ "channel": displayName, }) @@ -227,6 +244,23 @@ func (m *Manager) StopAll(ctx context.Context) error { return nil } +// newChannelWorker creates a channelWorker with a rate limiter configured +// for the given channel name. 
+func newChannelWorker(name string, ch Channel) *channelWorker { + rateVal := float64(defaultRateLimit) + if r, ok := channelRateConfig[name]; ok { + rateVal = r + } + burst := int(math.Max(1, math.Ceil(rateVal/2))) + + return &channelWorker{ + ch: ch, + queue: make(chan bus.OutboundMessage, defaultChannelQueueSize), + done: make(chan struct{}), + limiter: rate.NewLimiter(rate.Limit(rateVal), burst), + } +} + // runWorker processes outbound messages for a single channel, splitting // messages that exceed the channel's maximum message length. func (m *Manager) runWorker(ctx context.Context, name string, w *channelWorker) { @@ -246,23 +280,72 @@ func (m *Manager) runWorker(ctx context.Context, name string, w *channelWorker) for _, chunk := range chunks { chunkMsg := msg chunkMsg.Content = chunk - if err := w.ch.Send(ctx, chunkMsg); err != nil { - logger.ErrorCF("channels", "Error sending chunk", map[string]any{ - "channel": name, "error": err.Error(), - }) - } + m.sendWithRetry(ctx, name, w, chunkMsg) } } else { - if err := w.ch.Send(ctx, msg); err != nil { - logger.ErrorCF("channels", "Error sending message", map[string]any{ - "channel": name, "error": err.Error(), - }) - } + m.sendWithRetry(ctx, name, w, msg) + } + case <-ctx.Done(): + return + } + } +} + +// sendWithRetry sends a message through the channel with rate limiting and +// retry logic. 
It classifies errors to determine the retry strategy: +// - ErrNotRunning / ErrSendFailed: permanent, no retry +// - ErrRateLimit: fixed delay retry +// - ErrTemporary / unknown: exponential backoff retry +func (m *Manager) sendWithRetry(ctx context.Context, name string, w *channelWorker, msg bus.OutboundMessage) { + // Rate limit: wait for token + if err := w.limiter.Wait(ctx); err != nil { + // ctx cancelled, shutting down + return + } + + var lastErr error + for attempt := 0; attempt <= maxRetries; attempt++ { + lastErr = w.ch.Send(ctx, msg) + if lastErr == nil { + return + } + + // Permanent failures — don't retry + if errors.Is(lastErr, ErrNotRunning) || errors.Is(lastErr, ErrSendFailed) { + break + } + + // Last attempt exhausted — don't sleep + if attempt == maxRetries { + break + } + + // Rate limit error — fixed delay + if errors.Is(lastErr, ErrRateLimit) { + select { + case <-time.After(rateLimitDelay): + continue + case <-ctx.Done(): + return } + } + + // ErrTemporary or unknown error — exponential backoff + backoff := min(time.Duration(float64(baseBackoff)*math.Pow(2, float64(attempt))), maxBackoff) + select { + case <-time.After(backoff): case <-ctx.Done(): return } } + + // All retries exhausted or permanent failure + logger.ErrorCF("channels", "Send failed", map[string]any{ + "channel": name, + "chat_id": msg.ChatID, + "error": lastErr.Error(), + "retries": maxRetries, + }) } func (m *Manager) dispatchOutbound(ctx context.Context) { @@ -343,11 +426,7 @@ func (m *Manager) RegisterChannel(name string, channel Channel) { m.mu.Lock() defer m.mu.Unlock() m.channels[name] = channel - m.workers[name] = &channelWorker{ - ch: channel, - queue: make(chan bus.OutboundMessage, defaultChannelQueueSize), - done: make(chan struct{}), - } + m.workers[name] = newChannelWorker(name, channel) } func (m *Manager) UnregisterChannel(name string) { diff --git a/pkg/channels/manager_test.go b/pkg/channels/manager_test.go new file mode 100644 index 000000000..162c9f8c9 --- 
/dev/null +++ b/pkg/channels/manager_test.go @@ -0,0 +1,418 @@ +package channels + +import ( + "context" + "errors" + "fmt" + "sync" + "sync/atomic" + "testing" + "time" + + "golang.org/x/time/rate" + + "github.com/sipeed/picoclaw/pkg/bus" +) + +// mockChannel is a test double that delegates Send to a configurable function. +type mockChannel struct { + BaseChannel + sendFn func(ctx context.Context, msg bus.OutboundMessage) error +} + +func (m *mockChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + return m.sendFn(ctx, msg) +} + +func (m *mockChannel) Start(ctx context.Context) error { return nil } +func (m *mockChannel) Stop(ctx context.Context) error { return nil } + +// newTestManager creates a minimal Manager suitable for unit tests. +func newTestManager() *Manager { + return &Manager{ + channels: make(map[string]Channel), + workers: make(map[string]*channelWorker), + } +} + +func TestSendWithRetry_Success(t *testing.T) { + m := newTestManager() + var callCount int + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + callCount++ + return nil + }, + } + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + + ctx := context.Background() + msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + + m.sendWithRetry(ctx, "test", w, msg) + + if callCount != 1 { + t.Fatalf("expected 1 Send call, got %d", callCount) + } +} + +func TestSendWithRetry_TemporaryThenSuccess(t *testing.T) { + m := newTestManager() + var callCount int + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + callCount++ + if callCount <= 2 { + return fmt.Errorf("network error: %w", ErrTemporary) + } + return nil + }, + } + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + + ctx := context.Background() + msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + + m.sendWithRetry(ctx, "test", w, msg) + + if callCount != 3 
{ + t.Fatalf("expected 3 Send calls (2 failures + 1 success), got %d", callCount) + } +} + +func TestSendWithRetry_PermanentFailure(t *testing.T) { + m := newTestManager() + var callCount int + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + callCount++ + return fmt.Errorf("bad chat ID: %w", ErrSendFailed) + }, + } + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + + ctx := context.Background() + msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + + m.sendWithRetry(ctx, "test", w, msg) + + if callCount != 1 { + t.Fatalf("expected 1 Send call (no retry for permanent failure), got %d", callCount) + } +} + +func TestSendWithRetry_NotRunning(t *testing.T) { + m := newTestManager() + var callCount int + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + callCount++ + return ErrNotRunning + }, + } + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + + ctx := context.Background() + msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + + m.sendWithRetry(ctx, "test", w, msg) + + if callCount != 1 { + t.Fatalf("expected 1 Send call (no retry for ErrNotRunning), got %d", callCount) + } +} + +func TestSendWithRetry_RateLimitRetry(t *testing.T) { + m := newTestManager() + var callCount int + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + callCount++ + if callCount == 1 { + return fmt.Errorf("429: %w", ErrRateLimit) + } + return nil + }, + } + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + + ctx := context.Background() + msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + + start := time.Now() + m.sendWithRetry(ctx, "test", w, msg) + elapsed := time.Since(start) + + if callCount != 2 { + t.Fatalf("expected 2 Send calls (1 rate limit + 1 success), got %d", callCount) + } + // Should have waited at least 
rateLimitDelay (1s) but allow some slack + if elapsed < 900*time.Millisecond { + t.Fatalf("expected at least ~1s delay for rate limit retry, got %v", elapsed) + } +} + +func TestSendWithRetry_MaxRetriesExhausted(t *testing.T) { + m := newTestManager() + var callCount int + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + callCount++ + return fmt.Errorf("timeout: %w", ErrTemporary) + }, + } + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + + ctx := context.Background() + msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + + m.sendWithRetry(ctx, "test", w, msg) + + expected := maxRetries + 1 // initial attempt + maxRetries retries + if callCount != expected { + t.Fatalf("expected %d Send calls, got %d", expected, callCount) + } +} + +func TestSendWithRetry_UnknownError(t *testing.T) { + m := newTestManager() + var callCount int + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + callCount++ + if callCount == 1 { + return errors.New("random unexpected error") + } + return nil + }, + } + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + + ctx := context.Background() + msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + + m.sendWithRetry(ctx, "test", w, msg) + + if callCount != 2 { + t.Fatalf("expected 2 Send calls (unknown error treated as temporary), got %d", callCount) + } +} + +func TestSendWithRetry_ContextCancelled(t *testing.T) { + m := newTestManager() + var callCount int + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + callCount++ + return fmt.Errorf("timeout: %w", ErrTemporary) + }, + } + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + + ctx, cancel := context.WithCancel(context.Background()) + msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + + // Cancel context after first Send 
attempt returns + ch.sendFn = func(_ context.Context, _ bus.OutboundMessage) error { + callCount++ + cancel() + return fmt.Errorf("timeout: %w", ErrTemporary) + } + + m.sendWithRetry(ctx, "test", w, msg) + + // Should have called Send once, then noticed ctx cancelled during backoff + if callCount != 1 { + t.Fatalf("expected 1 Send call before context cancellation, got %d", callCount) + } +} + +func TestWorkerRateLimiter(t *testing.T) { + m := newTestManager() + + var mu sync.Mutex + var sendTimes []time.Time + + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + mu.Lock() + sendTimes = append(sendTimes, time.Now()) + mu.Unlock() + return nil + }, + } + + // Create a worker with a low rate: 2 msg/s, burst 1 + w := &channelWorker{ + ch: ch, + queue: make(chan bus.OutboundMessage, 10), + done: make(chan struct{}), + limiter: rate.NewLimiter(2, 1), + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go m.runWorker(ctx, "test", w) + + // Enqueue 4 messages + for i := 0; i < 4; i++ { + w.queue <- bus.OutboundMessage{Channel: "test", ChatID: "1", Content: fmt.Sprintf("msg%d", i)} + } + + // Wait enough time for all messages to be sent (4 msgs at 2/s = ~2s, give extra margin) + time.Sleep(3 * time.Second) + + mu.Lock() + times := make([]time.Time, len(sendTimes)) + copy(times, sendTimes) + mu.Unlock() + + if len(times) != 4 { + t.Fatalf("expected 4 sends, got %d", len(times)) + } + + // Verify rate limiting: total duration should be at least 1s + // (first message immediate, then ~500ms between each subsequent one at 2/s) + totalDuration := times[len(times)-1].Sub(times[0]) + if totalDuration < 1*time.Second { + t.Fatalf("expected total duration >= 1s for 4 msgs at 2/s rate, got %v", totalDuration) + } +} + +func TestNewChannelWorker_DefaultRate(t *testing.T) { + ch := &mockChannel{} + w := newChannelWorker("unknown_channel", ch) + + if w.limiter == nil { + t.Fatal("expected limiter to be non-nil") + } + 
if w.limiter.Limit() != rate.Limit(defaultRateLimit) { + t.Fatalf("expected rate limit %v, got %v", rate.Limit(defaultRateLimit), w.limiter.Limit()) + } +} + +func TestNewChannelWorker_ConfiguredRate(t *testing.T) { + ch := &mockChannel{} + + for name, expectedRate := range channelRateConfig { + w := newChannelWorker(name, ch) + if w.limiter.Limit() != rate.Limit(expectedRate) { + t.Fatalf("channel %s: expected rate %v, got %v", name, expectedRate, w.limiter.Limit()) + } + } +} + +func TestRunWorker_MessageSplitting(t *testing.T) { + m := newTestManager() + + var mu sync.Mutex + var received []string + + ch := &mockChannelWithLength{ + mockChannel: mockChannel{ + sendFn: func(_ context.Context, msg bus.OutboundMessage) error { + mu.Lock() + received = append(received, msg.Content) + mu.Unlock() + return nil + }, + }, + maxLen: 5, + } + + w := &channelWorker{ + ch: ch, + queue: make(chan bus.OutboundMessage, 10), + done: make(chan struct{}), + limiter: rate.NewLimiter(rate.Inf, 1), + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + go m.runWorker(ctx, "test", w) + + // Send a message that should be split + w.queue <- bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello world"} + + time.Sleep(100 * time.Millisecond) + + mu.Lock() + count := len(received) + mu.Unlock() + + if count < 2 { + t.Fatalf("expected message to be split into at least 2 chunks, got %d", count) + } +} + +// mockChannelWithLength implements MessageLengthProvider. 
+type mockChannelWithLength struct { + mockChannel + maxLen int +} + +func (m *mockChannelWithLength) MaxMessageLength() int { + return m.maxLen +} + +func TestSendWithRetry_ExponentialBackoff(t *testing.T) { + m := newTestManager() + + var callTimes []time.Time + var callCount atomic.Int32 + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + callTimes = append(callTimes, time.Now()) + callCount.Add(1) + return fmt.Errorf("timeout: %w", ErrTemporary) + }, + } + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + + ctx := context.Background() + msg := bus.OutboundMessage{Channel: "test", ChatID: "1", Content: "hello"} + + start := time.Now() + m.sendWithRetry(ctx, "test", w, msg) + totalElapsed := time.Since(start) + + // With maxRetries=3: attempts at 0, ~500ms, ~1.5s, ~3.5s + // Total backoff: 500ms + 1s + 2s = 3.5s + // Allow some margin + if totalElapsed < 3*time.Second { + t.Fatalf("expected total elapsed >= 3s for exponential backoff, got %v", totalElapsed) + } + + if int(callCount.Load()) != maxRetries+1 { + t.Fatalf("expected %d calls, got %d", maxRetries+1, callCount.Load()) + } +} From 24e2ed79c08e31367e8c3352944fc0d5a5940415 Mon Sep 17 00:00:00 2001 From: Hoshina Date: Mon, 23 Feb 2026 00:44:45 +0800 Subject: [PATCH 13/28] refactor(bus): fix deadlock and concurrency issues in MessageBus MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit PublishInbound/PublishOutbound held RLock during blocking channel sends, deadlocking against Close() which needs a write lock when the buffer is full. ConsumeInbound/SubscribeOutbound used bare receives instead of comma-ok, causing zero-value processing or busy loops after close. Replace sync.RWMutex+bool with atomic.Bool+done channel so Publish methods use a lock-free 3-way select (send / done / ctx.Done). Add context.Context parameter to both Publish methods so callers can cancel or timeout blocked sends. 
Close() now only sets the atomic flag and closes the done channel—never closes the data channels—eliminating send-on-closed-channel panics. - Remove dead code: RegisterHandler, GetHandler, handlers map, MessageHandler type (zero callers across the whole repo) - Add ErrBusClosed sentinel error - Update all 10 caller sites to pass context - Add msgBus.Close() to gateway and agent shutdown flows - Add pkg/bus/bus_test.go with 11 test cases covering basic round-trip, context cancellation, closed-bus behavior, concurrent publish+close, full-buffer timeout, and idempotent Close --- cmd/picoclaw/cmd_agent.go | 1 + cmd/picoclaw/cmd_gateway.go | 1 + pkg/agent/loop.go | 12 +- pkg/bus/bus.go | 81 +++++++------ pkg/bus/bus_test.go | 229 ++++++++++++++++++++++++++++++++++++ pkg/bus/types.go | 2 - pkg/channels/base.go | 2 +- pkg/devices/service.go | 2 +- pkg/heartbeat/service.go | 3 +- pkg/tools/cron.go | 4 +- pkg/tools/subagent.go | 2 +- 11 files changed, 284 insertions(+), 55 deletions(-) create mode 100644 pkg/bus/bus_test.go diff --git a/cmd/picoclaw/cmd_agent.go b/cmd/picoclaw/cmd_agent.go index 8658c9d32..1c92e0b6c 100644 --- a/cmd/picoclaw/cmd_agent.go +++ b/cmd/picoclaw/cmd_agent.go @@ -70,6 +70,7 @@ func agentCmd() { } msgBus := bus.NewMessageBus() + defer msgBus.Close() agentLoop := agent.NewAgentLoop(cfg, msgBus, provider) // Print agent startup info (only for interactive mode) diff --git a/cmd/picoclaw/cmd_gateway.go b/cmd/picoclaw/cmd_gateway.go index 3c2cb021d..3b914f6ae 100644 --- a/cmd/picoclaw/cmd_gateway.go +++ b/cmd/picoclaw/cmd_gateway.go @@ -228,6 +228,7 @@ func gatewayCmd() { fmt.Println("\nShutting down...") cancel() + msgBus.Close() healthServer.Stop(context.Background()) deviceService.Stop() heartbeatService.Stop() diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 97569bef7..e243a6fdb 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -120,7 +120,7 @@ func registerSharedTools( // Message tool messageTool := tools.NewMessageTool() 
messageTool.SetSendCallback(func(channel, chatID, content string) error { - msgBus.PublishOutbound(bus.OutboundMessage{ + msgBus.PublishOutbound(context.TODO(), bus.OutboundMessage{ Channel: channel, ChatID: chatID, Content: content, @@ -202,7 +202,7 @@ func (al *AgentLoop) Run(ctx context.Context) error { } if !alreadySent { - al.bus.PublishOutbound(bus.OutboundMessage{ + al.bus.PublishOutbound(ctx, bus.OutboundMessage{ Channel: msg.Channel, ChatID: msg.ChatID, Content: response, @@ -471,7 +471,7 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opt // 8. Optional: send response via bus if opts.SendResponse { - al.bus.PublishOutbound(bus.OutboundMessage{ + al.bus.PublishOutbound(ctx, bus.OutboundMessage{ Channel: opts.Channel, ChatID: opts.ChatID, Content: finalContent, @@ -586,7 +586,7 @@ func (al *AgentLoop) runLLMIteration( }) if retry == 0 && !constants.IsInternalChannel(opts.Channel) { - al.bus.PublishOutbound(bus.OutboundMessage{ + al.bus.PublishOutbound(ctx, bus.OutboundMessage{ Channel: opts.Channel, ChatID: opts.ChatID, Content: "Context window exceeded. Compressing history and retrying...", @@ -715,7 +715,7 @@ func (al *AgentLoop) runLLMIteration( // Send ForUser content to user immediately if not Silent if !toolResult.Silent && toolResult.ForUser != "" && opts.SendResponse { - al.bus.PublishOutbound(bus.OutboundMessage{ + al.bus.PublishOutbound(ctx, bus.OutboundMessage{ Channel: opts.Channel, ChatID: opts.ChatID, Content: toolResult.ForUser, @@ -780,7 +780,7 @@ func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, c go func() { defer al.summarizing.Delete(summarizeKey) if !constants.IsInternalChannel(channel) { - al.bus.PublishOutbound(bus.OutboundMessage{ + al.bus.PublishOutbound(context.TODO(), bus.OutboundMessage{ Channel: channel, ChatID: chatID, Content: "Memory threshold reached. 
Optimizing conversation history...", diff --git a/pkg/bus/bus.go b/pkg/bus/bus.go index 58c0a25d5..100ddc456 100644 --- a/pkg/bus/bus.go +++ b/pkg/bus/bus.go @@ -2,81 +2,80 @@ package bus import ( "context" - "sync" + "errors" + "sync/atomic" ) +// ErrBusClosed is returned when publishing to a closed MessageBus. +var ErrBusClosed = errors.New("message bus closed") + type MessageBus struct { inbound chan InboundMessage outbound chan OutboundMessage - handlers map[string]MessageHandler - closed bool - mu sync.RWMutex + done chan struct{} + closed atomic.Bool } func NewMessageBus() *MessageBus { return &MessageBus{ inbound: make(chan InboundMessage, 100), outbound: make(chan OutboundMessage, 100), - handlers: make(map[string]MessageHandler), + done: make(chan struct{}), } } -func (mb *MessageBus) PublishInbound(msg InboundMessage) { - mb.mu.RLock() - defer mb.mu.RUnlock() - if mb.closed { - return +func (mb *MessageBus) PublishInbound(ctx context.Context, msg InboundMessage) error { + if mb.closed.Load() { + return ErrBusClosed + } + select { + case mb.inbound <- msg: + return nil + case <-mb.done: + return ErrBusClosed + case <-ctx.Done(): + return ctx.Err() } - mb.inbound <- msg } func (mb *MessageBus) ConsumeInbound(ctx context.Context) (InboundMessage, bool) { select { - case msg := <-mb.inbound: - return msg, true + case msg, ok := <-mb.inbound: + return msg, ok + case <-mb.done: + return InboundMessage{}, false case <-ctx.Done(): return InboundMessage{}, false } } -func (mb *MessageBus) PublishOutbound(msg OutboundMessage) { - mb.mu.RLock() - defer mb.mu.RUnlock() - if mb.closed { - return +func (mb *MessageBus) PublishOutbound(ctx context.Context, msg OutboundMessage) error { + if mb.closed.Load() { + return ErrBusClosed + } + select { + case mb.outbound <- msg: + return nil + case <-mb.done: + return ErrBusClosed + case <-ctx.Done(): + return ctx.Err() } - mb.outbound <- msg } func (mb *MessageBus) SubscribeOutbound(ctx context.Context) (OutboundMessage, bool) 
{ select { - case msg := <-mb.outbound: - return msg, true + case msg, ok := <-mb.outbound: + return msg, ok + case <-mb.done: + return OutboundMessage{}, false case <-ctx.Done(): return OutboundMessage{}, false } } -func (mb *MessageBus) RegisterHandler(channel string, handler MessageHandler) { - mb.mu.Lock() - defer mb.mu.Unlock() - mb.handlers[channel] = handler -} - -func (mb *MessageBus) GetHandler(channel string) (MessageHandler, bool) { - mb.mu.RLock() - defer mb.mu.RUnlock() - handler, ok := mb.handlers[channel] - return handler, ok -} - func (mb *MessageBus) Close() { - mb.mu.Lock() - defer mb.mu.Unlock() - if mb.closed { - return + if mb.closed.CompareAndSwap(false, true) { + close(mb.done) } - mb.closed = true - close(mb.inbound) - close(mb.outbound) } diff --git a/pkg/bus/bus_test.go b/pkg/bus/bus_test.go new file mode 100644 index 000000000..47826824e --- /dev/null +++ b/pkg/bus/bus_test.go @@ -0,0 +1,229 @@ +package bus + +import ( + "context" + "sync" + "testing" + "time" +) + +func TestPublishConsume(t *testing.T) { + mb := NewMessageBus() + defer mb.Close() + + ctx := context.Background() + + msg := InboundMessage{ + Channel: "test", + SenderID: "user1", + ChatID: "chat1", + Content: "hello", + } + + if err := mb.PublishInbound(ctx, msg); err != nil { + t.Fatalf("PublishInbound failed: %v", err) + } + + got, ok := mb.ConsumeInbound(ctx) + if !ok { + t.Fatal("ConsumeInbound returned ok=false") + } + if got.Content != "hello" { + t.Fatalf("expected content 'hello', got %q", got.Content) + } + if got.Channel != "test" { + t.Fatalf("expected channel 'test', got %q", got.Channel) + } +} + +func TestPublishOutboundSubscribe(t *testing.T) { + mb := NewMessageBus() + defer mb.Close() + + ctx := context.Background() + + msg := OutboundMessage{ + Channel: "telegram", + ChatID: "123", + Content: "world", + } + + if err := mb.PublishOutbound(ctx, msg); err != nil { + t.Fatalf("PublishOutbound failed: %v", err) + } + + got, ok := mb.SubscribeOutbound(ctx) + if 
!ok { + t.Fatal("SubscribeOutbound returned ok=false") + } + if got.Content != "world" { + t.Fatalf("expected content 'world', got %q", got.Content) + } +} + +func TestPublishInbound_ContextCancel(t *testing.T) { + mb := NewMessageBus() + defer mb.Close() + + // Fill the buffer + ctx := context.Background() + for i := 0; i < 100; i++ { + if err := mb.PublishInbound(ctx, InboundMessage{Content: "fill"}); err != nil { + t.Fatalf("fill failed at %d: %v", i, err) + } + } + + // Now buffer is full; publish with a cancelled context + cancelCtx, cancel := context.WithCancel(context.Background()) + cancel() + + err := mb.PublishInbound(cancelCtx, InboundMessage{Content: "overflow"}) + if err == nil { + t.Fatal("expected error from cancelled context, got nil") + } + if err != context.Canceled { + t.Fatalf("expected context.Canceled, got %v", err) + } +} + +func TestPublishInbound_BusClosed(t *testing.T) { + mb := NewMessageBus() + mb.Close() + + err := mb.PublishInbound(context.Background(), InboundMessage{Content: "test"}) + if err != ErrBusClosed { + t.Fatalf("expected ErrBusClosed, got %v", err) + } +} + +func TestPublishOutbound_BusClosed(t *testing.T) { + mb := NewMessageBus() + mb.Close() + + err := mb.PublishOutbound(context.Background(), OutboundMessage{Content: "test"}) + if err != ErrBusClosed { + t.Fatalf("expected ErrBusClosed, got %v", err) + } +} + +func TestConsumeInbound_ContextCancel(t *testing.T) { + mb := NewMessageBus() + defer mb.Close() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, ok := mb.ConsumeInbound(ctx) + if ok { + t.Fatal("expected ok=false when context is cancelled") + } +} + +func TestConsumeInbound_BusClosed(t *testing.T) { + mb := NewMessageBus() + mb.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + _, ok := mb.ConsumeInbound(ctx) + if ok { + t.Fatal("expected ok=false when bus is closed") + } +} + +func TestSubscribeOutbound_BusClosed(t 
*testing.T) { + mb := NewMessageBus() + mb.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + _, ok := mb.SubscribeOutbound(ctx) + if ok { + t.Fatal("expected ok=false when bus is closed") + } +} + +func TestConcurrentPublishClose(t *testing.T) { + mb := NewMessageBus() + ctx := context.Background() + + const numGoroutines = 100 + var wg sync.WaitGroup + wg.Add(numGoroutines + 1) + + // Spawn many goroutines trying to publish + for i := 0; i < numGoroutines; i++ { + go func() { + defer wg.Done() + // Use a short timeout context so we don't block forever after close + publishCtx, cancel := context.WithTimeout(ctx, 50*time.Millisecond) + defer cancel() + // Errors are expected; we just must not panic or deadlock + _ = mb.PublishInbound(publishCtx, InboundMessage{Content: "concurrent"}) + }() + } + + // Close from another goroutine + go func() { + defer wg.Done() + time.Sleep(5 * time.Millisecond) + mb.Close() + }() + + // Must complete without deadlock + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + // success + case <-time.After(5 * time.Second): + t.Fatal("test timed out - possible deadlock") + } +} + +func TestPublishInbound_FullBuffer(t *testing.T) { + mb := NewMessageBus() + defer mb.Close() + + ctx := context.Background() + + // Fill the buffer + for i := 0; i < 100; i++ { + if err := mb.PublishInbound(ctx, InboundMessage{Content: "fill"}); err != nil { + t.Fatalf("fill failed at %d: %v", i, err) + } + } + + // Buffer is full; publish with short timeout + timeoutCtx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + + err := mb.PublishInbound(timeoutCtx, InboundMessage{Content: "overflow"}) + if err == nil { + t.Fatal("expected error when buffer is full and context times out") + } + if err != context.DeadlineExceeded { + t.Fatalf("expected context.DeadlineExceeded, got %v", err) + } +} + +func 
TestCloseIdempotent(t *testing.T) { + mb := NewMessageBus() + + // Multiple Close calls must not panic + mb.Close() + mb.Close() + mb.Close() + + // After close, publish should return ErrBusClosed + err := mb.PublishInbound(context.Background(), InboundMessage{Content: "test"}) + if err != ErrBusClosed { + t.Fatalf("expected ErrBusClosed after multiple closes, got %v", err) + } +} diff --git a/pkg/bus/types.go b/pkg/bus/types.go index e49713eb8..358829c55 100644 --- a/pkg/bus/types.go +++ b/pkg/bus/types.go @@ -24,5 +24,3 @@ type OutboundMessage struct { ChatID string `json:"chat_id"` Content string `json:"content"` } - -type MessageHandler func(InboundMessage) error diff --git a/pkg/channels/base.go b/pkg/channels/base.go index d967d9e91..adacb8c78 100644 --- a/pkg/channels/base.go +++ b/pkg/channels/base.go @@ -143,7 +143,7 @@ func (c *BaseChannel) HandleMessage( Metadata: metadata, } - c.bus.PublishInbound(msg) + c.bus.PublishInbound(context.TODO(), msg) } func (c *BaseChannel) SetRunning(running bool) { diff --git a/pkg/devices/service.go b/pkg/devices/service.go index 1541d3c57..408e1c8aa 100644 --- a/pkg/devices/service.go +++ b/pkg/devices/service.go @@ -127,7 +127,7 @@ func (s *Service) sendNotification(ev *events.DeviceEvent) { } msg := ev.FormatMessage() - msgBus.PublishOutbound(bus.OutboundMessage{ + msgBus.PublishOutbound(context.TODO(), bus.OutboundMessage{ Channel: platform, ChatID: userID, Content: msg, diff --git a/pkg/heartbeat/service.go b/pkg/heartbeat/service.go index 75d6248b9..62b321955 100644 --- a/pkg/heartbeat/service.go +++ b/pkg/heartbeat/service.go @@ -7,6 +7,7 @@ package heartbeat import ( + "context" "fmt" "os" "path/filepath" @@ -307,7 +308,7 @@ func (hs *HeartbeatService) sendResponse(response string) { return } - msgBus.PublishOutbound(bus.OutboundMessage{ + msgBus.PublishOutbound(context.TODO(), bus.OutboundMessage{ Channel: platform, ChatID: userID, Content: response, diff --git a/pkg/tools/cron.go b/pkg/tools/cron.go index 
562fffc84..3c13f5968 100644 --- a/pkg/tools/cron.go +++ b/pkg/tools/cron.go @@ -294,7 +294,7 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string { output = fmt.Sprintf("Scheduled command '%s' executed:\n%s", job.Payload.Command, result.ForLLM) } - t.msgBus.PublishOutbound(bus.OutboundMessage{ + t.msgBus.PublishOutbound(context.TODO(), bus.OutboundMessage{ Channel: channel, ChatID: chatID, Content: output, @@ -304,7 +304,7 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string { // If deliver=true, send message directly without agent processing if job.Payload.Deliver { - t.msgBus.PublishOutbound(bus.OutboundMessage{ + t.msgBus.PublishOutbound(context.TODO(), bus.OutboundMessage{ Channel: channel, ChatID: chatID, Content: job.Payload.Message, diff --git a/pkg/tools/subagent.go b/pkg/tools/subagent.go index 91ebff636..99821daf9 100644 --- a/pkg/tools/subagent.go +++ b/pkg/tools/subagent.go @@ -218,7 +218,7 @@ After completing the task, provide a clear summary of what was done.` // Send announce message back to main agent if sm.bus != nil { announceContent := fmt.Sprintf("Task '%s' completed.\n\nResult:\n%s", task.Label, task.Result) - sm.bus.PublishInbound(bus.InboundMessage{ + sm.bus.PublishInbound(context.TODO(), bus.InboundMessage{ Channel: "system", SenderID: fmt.Sprintf("subagent:%s", task.ID), // Format: "original_channel:original_chat_id" for routing back From cc92a6281251c008059c0b2a069cb12631affac0 Mon Sep 17 00:00:00 2001 From: Hoshina Date: Mon, 23 Feb 2026 01:45:48 +0800 Subject: [PATCH 14/28] refactor(channels): standardize Send error classification with sentinel types All 12 channel Send methods now return proper sentinel errors (ErrNotRunning, ErrTemporary, ErrRateLimit, ErrSendFailed) instead of plain fmt.Errorf strings, enabling Manager's sendWithRetry classification logic to actually work. 
- Add ClassifySendError/ClassifyNetError helpers in errutil.go for HTTP-based channels - LINE/WeCom Bot/WeCom App: use ClassifySendError for HTTP status-based classification - SDK channels (Telegram/Discord/Slack/QQ/DingTalk/Feishu): wrap errors as ErrTemporary - WebSocket channels (OneBot/WhatsApp/MaixCam): wrap write errors as ErrTemporary - WhatsApp: add missing IsRunning() check in Send - WhatsApp/OneBot/MaixCam: add ctx.Done() check before entering write path - Telegram Stop: clean up placeholders sync.Map to prevent state leaks --- pkg/channels/dingtalk/dingtalk.go | 4 +- pkg/channels/discord/discord.go | 6 +- pkg/channels/errutil.go | 30 ++++++++++ pkg/channels/errutil_test.go | 97 +++++++++++++++++++++++++++++++ pkg/channels/feishu/feishu_64.go | 6 +- pkg/channels/line/line.go | 6 +- pkg/channels/maixcam/maixcam.go | 11 +++- pkg/channels/onebot/onebot.go | 11 +++- pkg/channels/qq/qq.go | 4 +- pkg/channels/slack/slack.go | 4 +- pkg/channels/telegram/telegram.go | 15 +++-- pkg/channels/wecom/app.go | 16 ++++- pkg/channels/wecom/bot.go | 9 ++- pkg/channels/whatsapp/whatsapp.go | 15 ++++- 14 files changed, 204 insertions(+), 30 deletions(-) create mode 100644 pkg/channels/errutil.go create mode 100644 pkg/channels/errutil_test.go diff --git a/pkg/channels/dingtalk/dingtalk.go b/pkg/channels/dingtalk/dingtalk.go index e051add1f..c49769761 100644 --- a/pkg/channels/dingtalk/dingtalk.go +++ b/pkg/channels/dingtalk/dingtalk.go @@ -96,7 +96,7 @@ func (c *DingTalkChannel) Stop(ctx context.Context) error { // Send sends a message to DingTalk via the chatbot reply API func (c *DingTalkChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { if !c.IsRunning() { - return fmt.Errorf("dingtalk channel not running") + return channels.ErrNotRunning } // Get session webhook from storage @@ -197,7 +197,7 @@ func (c *DingTalkChannel) SendDirectReply(ctx context.Context, sessionWebhook, c contentBytes, ) if err != nil { - return fmt.Errorf("failed to send reply: %w", 
err) + return fmt.Errorf("dingtalk send: %w", channels.ErrTemporary) } return nil diff --git a/pkg/channels/discord/discord.go b/pkg/channels/discord/discord.go index 7977d32e1..d5524f7f9 100644 --- a/pkg/channels/discord/discord.go +++ b/pkg/channels/discord/discord.go @@ -113,7 +113,7 @@ func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro c.stopTyping(msg.ChatID) if !c.IsRunning() { - return fmt.Errorf("discord bot not running") + return channels.ErrNotRunning } channelID := msg.ChatID @@ -142,11 +142,11 @@ func (c *DiscordChannel) sendChunk(ctx context.Context, channelID, content strin select { case err := <-done: if err != nil { - return fmt.Errorf("failed to send discord message: %w", err) + return fmt.Errorf("discord send: %w", channels.ErrTemporary) } return nil case <-sendCtx.Done(): - return fmt.Errorf("send message timeout: %w", sendCtx.Err()) + return sendCtx.Err() } } diff --git a/pkg/channels/errutil.go b/pkg/channels/errutil.go new file mode 100644 index 000000000..319e3c980 --- /dev/null +++ b/pkg/channels/errutil.go @@ -0,0 +1,30 @@ +package channels + +import ( + "fmt" + "net/http" +) + +// ClassifySendError wraps a raw error with the appropriate sentinel based on +// an HTTP status code. Channels that perform HTTP API calls should use this +// in their Send path. +func ClassifySendError(statusCode int, rawErr error) error { + switch { + case statusCode == http.StatusTooManyRequests: + return fmt.Errorf("%w: %v", ErrRateLimit, rawErr) + case statusCode >= 500: + return fmt.Errorf("%w: %v", ErrTemporary, rawErr) + case statusCode >= 400: + return fmt.Errorf("%w: %v", ErrSendFailed, rawErr) + default: + return rawErr + } +} + +// ClassifyNetError wraps a network/timeout error as ErrTemporary. 
+func ClassifyNetError(err error) error { + if err == nil { + return nil + } + return fmt.Errorf("%w: %v", ErrTemporary, err) +} diff --git a/pkg/channels/errutil_test.go b/pkg/channels/errutil_test.go new file mode 100644 index 000000000..e3d35f65b --- /dev/null +++ b/pkg/channels/errutil_test.go @@ -0,0 +1,97 @@ +package channels + +import ( + "errors" + "fmt" + "testing" +) + +func TestClassifySendError(t *testing.T) { + raw := fmt.Errorf("some API error") + + tests := []struct { + name string + statusCode int + wantIs error + wantNil bool + }{ + {"429 -> ErrRateLimit", 429, ErrRateLimit, false}, + {"500 -> ErrTemporary", 500, ErrTemporary, false}, + {"502 -> ErrTemporary", 502, ErrTemporary, false}, + {"503 -> ErrTemporary", 503, ErrTemporary, false}, + {"400 -> ErrSendFailed", 400, ErrSendFailed, false}, + {"403 -> ErrSendFailed", 403, ErrSendFailed, false}, + {"404 -> ErrSendFailed", 404, ErrSendFailed, false}, + {"200 -> raw error", 200, nil, false}, + {"201 -> raw error", 201, nil, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ClassifySendError(tt.statusCode, raw) + if err == nil { + t.Fatal("expected non-nil error") + } + if tt.wantIs != nil { + if !errors.Is(err, tt.wantIs) { + t.Errorf("errors.Is(err, %v) = false, want true; err = %v", tt.wantIs, err) + } + } else { + // Should return the raw error unchanged + if err != raw { + t.Errorf("expected raw error to be returned unchanged for status %d, got %v", tt.statusCode, err) + } + } + }) + } +} + +func TestClassifySendErrorNoFalsePositive(t *testing.T) { + raw := fmt.Errorf("some error") + + // 429 should NOT match ErrTemporary or ErrSendFailed + err := ClassifySendError(429, raw) + if errors.Is(err, ErrTemporary) { + t.Error("429 should not match ErrTemporary") + } + if errors.Is(err, ErrSendFailed) { + t.Error("429 should not match ErrSendFailed") + } + + // 500 should NOT match ErrRateLimit or ErrSendFailed + err = ClassifySendError(500, raw) + if 
errors.Is(err, ErrRateLimit) { + t.Error("500 should not match ErrRateLimit") + } + if errors.Is(err, ErrSendFailed) { + t.Error("500 should not match ErrSendFailed") + } + + // 400 should NOT match ErrRateLimit or ErrTemporary + err = ClassifySendError(400, raw) + if errors.Is(err, ErrRateLimit) { + t.Error("400 should not match ErrRateLimit") + } + if errors.Is(err, ErrTemporary) { + t.Error("400 should not match ErrTemporary") + } +} + +func TestClassifyNetError(t *testing.T) { + t.Run("nil error returns nil", func(t *testing.T) { + if err := ClassifyNetError(nil); err != nil { + t.Errorf("expected nil, got %v", err) + } + }) + + t.Run("non-nil error wraps as ErrTemporary", func(t *testing.T) { + raw := fmt.Errorf("connection refused") + err := ClassifyNetError(raw) + if err == nil { + t.Fatal("expected non-nil error") + } + if !errors.Is(err, ErrTemporary) { + t.Errorf("errors.Is(err, ErrTemporary) = false, want true; err = %v", err) + } + }) +} diff --git a/pkg/channels/feishu/feishu_64.go b/pkg/channels/feishu/feishu_64.go index d67823974..5245cd99d 100644 --- a/pkg/channels/feishu/feishu_64.go +++ b/pkg/channels/feishu/feishu_64.go @@ -91,7 +91,7 @@ func (c *FeishuChannel) Stop(ctx context.Context) error { func (c *FeishuChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { if !c.IsRunning() { - return fmt.Errorf("feishu channel not running") + return channels.ErrNotRunning } if msg.ChatID == "" { @@ -115,11 +115,11 @@ func (c *FeishuChannel) Send(ctx context.Context, msg bus.OutboundMessage) error resp, err := c.client.Im.V1.Message.Create(ctx, req) if err != nil { - return fmt.Errorf("failed to send feishu message: %w", err) + return fmt.Errorf("feishu send: %w", channels.ErrTemporary) } if !resp.Success() { - return fmt.Errorf("feishu api error: code=%d msg=%s", resp.Code, resp.Msg) + return fmt.Errorf("feishu api error (code=%d msg=%s): %w", resp.Code, resp.Msg, channels.ErrTemporary) } logger.DebugCF("feishu", "Feishu message sent", 
map[string]any{ diff --git a/pkg/channels/line/line.go b/pkg/channels/line/line.go index 272a53c6e..fd06334d5 100644 --- a/pkg/channels/line/line.go +++ b/pkg/channels/line/line.go @@ -491,7 +491,7 @@ func (c *LINEChannel) resolveChatID(source lineSource) string { // using a cached reply token, then falls back to the Push API. func (c *LINEChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { if !c.IsRunning() { - return fmt.Errorf("line channel not running") + return channels.ErrNotRunning } // Load and consume quote token for this chat @@ -582,13 +582,13 @@ func (c *LINEChannel) callAPI(ctx context.Context, endpoint string, payload any) client := &http.Client{Timeout: 30 * time.Second} resp, err := client.Do(req) if err != nil { - return fmt.Errorf("API request failed: %w", err) + return channels.ClassifyNetError(err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { respBody, _ := io.ReadAll(resp.Body) - return fmt.Errorf("LINE API error (status %d): %s", resp.StatusCode, string(respBody)) + return channels.ClassifySendError(resp.StatusCode, fmt.Errorf("LINE API error: %s", string(respBody))) } return nil diff --git a/pkg/channels/maixcam/maixcam.go b/pkg/channels/maixcam/maixcam.go index 05213b095..b5b7259f9 100644 --- a/pkg/channels/maixcam/maixcam.go +++ b/pkg/channels/maixcam/maixcam.go @@ -215,7 +215,14 @@ func (c *MaixCamChannel) Stop(ctx context.Context) error { func (c *MaixCamChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { if !c.IsRunning() { - return fmt.Errorf("maixcam channel not running") + return channels.ErrNotRunning + } + + // Check ctx before entering write path + select { + case <-ctx.Done(): + return ctx.Err() + default: } c.clientsMux.RLock() @@ -246,7 +253,7 @@ func (c *MaixCamChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro "client": conn.RemoteAddr().String(), "error": err.Error(), }) - sendErr = err + sendErr = fmt.Errorf("maixcam send: %w", channels.ErrTemporary) } _ = 
conn.SetWriteDeadline(time.Time{}) } diff --git a/pkg/channels/onebot/onebot.go b/pkg/channels/onebot/onebot.go index e2fe541f1..76950663e 100644 --- a/pkg/channels/onebot/onebot.go +++ b/pkg/channels/onebot/onebot.go @@ -373,7 +373,14 @@ func (c *OneBotChannel) Stop(ctx context.Context) error { func (c *OneBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { if !c.IsRunning() { - return fmt.Errorf("OneBot channel not running") + return channels.ErrNotRunning + } + + // Check ctx before entering write path + select { + case <-ctx.Done(): + return ctx.Err() + default: } c.mu.Lock() @@ -412,7 +419,7 @@ func (c *OneBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error logger.ErrorCF("onebot", "Failed to send message", map[string]any{ "error": err.Error(), }) - return err + return fmt.Errorf("onebot send: %w", channels.ErrTemporary) } if msgID, ok := c.pendingEmojiMsg.LoadAndDelete(msg.ChatID); ok { diff --git a/pkg/channels/qq/qq.go b/pkg/channels/qq/qq.go index 429e23cbf..69f323e6e 100644 --- a/pkg/channels/qq/qq.go +++ b/pkg/channels/qq/qq.go @@ -114,7 +114,7 @@ func (c *QQChannel) Stop(ctx context.Context) error { func (c *QQChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { if !c.IsRunning() { - return fmt.Errorf("QQ bot not running") + return channels.ErrNotRunning } // 构造消息 @@ -128,7 +128,7 @@ func (c *QQChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { logger.ErrorCF("qq", "Failed to send C2C message", map[string]any{ "error": err.Error(), }) - return err + return fmt.Errorf("qq send: %w", channels.ErrTemporary) } return nil diff --git a/pkg/channels/slack/slack.go b/pkg/channels/slack/slack.go index 53d7c0609..9e066e00a 100644 --- a/pkg/channels/slack/slack.go +++ b/pkg/channels/slack/slack.go @@ -112,7 +112,7 @@ func (c *SlackChannel) Stop(ctx context.Context) error { func (c *SlackChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { if !c.IsRunning() { - return 
fmt.Errorf("slack channel not running") + return channels.ErrNotRunning } channelID, threadTS := parseSlackChatID(msg.ChatID) @@ -130,7 +130,7 @@ func (c *SlackChannel) Send(ctx context.Context, msg bus.OutboundMessage) error _, _, err := c.api.PostMessageContext(ctx, channelID, opts...) if err != nil { - return fmt.Errorf("failed to send slack message: %w", err) + return fmt.Errorf("slack send: %w", channels.ErrTemporary) } if ref, ok := c.pendingAcks.LoadAndDelete(msg.ChatID); ok { diff --git a/pkg/channels/telegram/telegram.go b/pkg/channels/telegram/telegram.go index af7155799..a07eb6579 100644 --- a/pkg/channels/telegram/telegram.go +++ b/pkg/channels/telegram/telegram.go @@ -164,6 +164,12 @@ func (c *TelegramChannel) Stop(ctx context.Context) error { return true }) + // Clean up placeholder state + c.placeholders.Range(func(key, value any) bool { + c.placeholders.Delete(key) + return true + }) + // Stop the bot handler if c.bh != nil { c.bh.Stop() @@ -179,12 +185,12 @@ func (c *TelegramChannel) Stop(ctx context.Context) error { func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { if !c.IsRunning() { - return fmt.Errorf("telegram bot not running") + return channels.ErrNotRunning } chatID, err := parseChatID(msg.ChatID) if err != nil { - return fmt.Errorf("invalid chat ID: %w", err) + return fmt.Errorf("invalid chat ID %s: %w", msg.ChatID, channels.ErrSendFailed) } // Stop thinking animation @@ -217,8 +223,9 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err "error": err.Error(), }) tgMsg.ParseMode = "" - _, err = c.bot.SendMessage(ctx, tgMsg) - return err + if _, err = c.bot.SendMessage(ctx, tgMsg); err != nil { + return fmt.Errorf("telegram send: %w", channels.ErrTemporary) + } } return nil diff --git a/pkg/channels/wecom/app.go b/pkg/channels/wecom/app.go index eb1711d75..41861e8fc 100644 --- a/pkg/channels/wecom/app.go +++ b/pkg/channels/wecom/app.go @@ -207,7 +207,7 @@ func (c 
*WeComAppChannel) Stop(ctx context.Context) error { // Send sends a message to WeCom user proactively using access token func (c *WeComAppChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { if !c.IsRunning() { - return fmt.Errorf("wecom_app channel not running") + return channels.ErrNotRunning } accessToken := c.getAccessToken() @@ -548,10 +548,15 @@ func (c *WeComAppChannel) sendTextMessage(ctx context.Context, accessToken, user client := &http.Client{Timeout: time.Duration(timeout) * time.Second} resp, err := client.Do(req) if err != nil { - return fmt.Errorf("failed to send message: %w", err) + return channels.ClassifyNetError(err) } defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return channels.ClassifySendError(resp.StatusCode, fmt.Errorf("wecom_app API error: %s", string(body))) + } + body, err := io.ReadAll(resp.Body) if err != nil { return fmt.Errorf("failed to read response: %w", err) @@ -603,10 +608,15 @@ func (c *WeComAppChannel) sendMarkdownMessage(ctx context.Context, accessToken, client := &http.Client{Timeout: time.Duration(timeout) * time.Second} resp, err := client.Do(req) if err != nil { - return fmt.Errorf("failed to send message: %w", err) + return channels.ClassifyNetError(err) } defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return channels.ClassifySendError(resp.StatusCode, fmt.Errorf("wecom_app API error: %s", string(body))) + } + body, err := io.ReadAll(resp.Body) if err != nil { return fmt.Errorf("failed to read response: %w", err) diff --git a/pkg/channels/wecom/bot.go b/pkg/channels/wecom/bot.go index bbac8611a..7960802fb 100644 --- a/pkg/channels/wecom/bot.go +++ b/pkg/channels/wecom/bot.go @@ -166,7 +166,7 @@ func (c *WeComBotChannel) Stop(ctx context.Context) error { // For delayed responses, we use the webhook URL func (c *WeComBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { if 
!c.IsRunning() { - return fmt.Errorf("wecom channel not running") + return channels.ErrNotRunning } logger.DebugCF("wecom", "Sending message via webhook", map[string]any{ @@ -433,10 +433,15 @@ func (c *WeComBotChannel) sendWebhookReply(ctx context.Context, userID, content client := &http.Client{Timeout: time.Duration(timeout) * time.Second} resp, err := client.Do(req) if err != nil { - return fmt.Errorf("failed to send webhook reply: %w", err) + return channels.ClassifyNetError(err) } defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return channels.ClassifySendError(resp.StatusCode, fmt.Errorf("webhook API error: %s", string(body))) + } + body, err := io.ReadAll(resp.Body) if err != nil { return fmt.Errorf("failed to read response: %w", err) diff --git a/pkg/channels/whatsapp/whatsapp.go b/pkg/channels/whatsapp/whatsapp.go index b5f3e99d7..97032334f 100644 --- a/pkg/channels/whatsapp/whatsapp.go +++ b/pkg/channels/whatsapp/whatsapp.go @@ -94,11 +94,22 @@ func (c *WhatsAppChannel) Stop(ctx context.Context) error { } func (c *WhatsAppChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + if !c.IsRunning() { + return channels.ErrNotRunning + } + + // Check ctx before acquiring lock + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + c.mu.Lock() defer c.mu.Unlock() if c.conn == nil { - return fmt.Errorf("whatsapp connection not established") + return fmt.Errorf("whatsapp connection not established: %w", channels.ErrTemporary) } payload := map[string]any{ @@ -115,7 +126,7 @@ func (c *WhatsAppChannel) Send(ctx context.Context, msg bus.OutboundMessage) err _ = c.conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) if err := c.conn.WriteMessage(websocket.TextMessage, data); err != nil { _ = c.conn.SetWriteDeadline(time.Time{}) - return fmt.Errorf("failed to send message: %w", err) + return fmt.Errorf("whatsapp send: %w", channels.ErrTemporary) } _ = c.conn.SetWriteDeadline(time.Time{}) 
From d1551dc4233ad6171a8b801c28f66979e1c1deb3 Mon Sep 17 00:00:00 2001 From: Hoshina Date: Mon, 23 Feb 2026 02:39:09 +0800 Subject: [PATCH 15/28] refactor(channels): consolidate HTTP servers into shared server managed by Manager Merge 3 independent channel HTTP servers (LINE :18791, WeCom Bot :18793, WeCom App :18792) and the health server (:18790) into a single shared HTTP server on the Gateway address. Channels implement WebhookHandler and/or HealthChecker interfaces to register their handlers on the shared mux. Also change Gateway default host from 0.0.0.0 to 127.0.0.1 for security. --- cmd/picoclaw/cmd_gateway.go | 15 ++++---- pkg/channels/line/line.go | 53 ++++++++-------------------- pkg/channels/manager.go | 69 ++++++++++++++++++++++++++++++++++++- pkg/channels/webhook.go | 20 +++++++++++ pkg/channels/wecom/app.go | 63 ++++++++++++++------------------- pkg/channels/wecom/bot.go | 63 ++++++++++++++------------------- pkg/health/server.go | 7 ++++ 7 files changed, 166 insertions(+), 124 deletions(-) create mode 100644 pkg/channels/webhook.go diff --git a/cmd/picoclaw/cmd_gateway.go b/cmd/picoclaw/cmd_gateway.go index 3b914f6ae..4e6ec8bb3 100644 --- a/cmd/picoclaw/cmd_gateway.go +++ b/cmd/picoclaw/cmd_gateway.go @@ -6,7 +6,6 @@ package main import ( "context" "fmt" - "net/http" "os" "os/signal" "path/filepath" @@ -208,16 +207,15 @@ func gatewayCmd() { fmt.Println("✓ Device event service started") } + // Setup shared HTTP server with health endpoints and webhook handlers + healthServer := health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port) + addr := fmt.Sprintf("%s:%d", cfg.Gateway.Host, cfg.Gateway.Port) + channelManager.SetupHTTPServer(addr, healthServer) + if err := channelManager.StartAll(ctx); err != nil { fmt.Printf("Error starting channels: %v\n", err) } - healthServer := health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port) - go func() { - if err := healthServer.Start(); err != nil && err != http.ErrServerClosed { - logger.ErrorCF("health", "Health 
server error", map[string]any{"error": err.Error()}) - } - }() fmt.Printf("✓ Health endpoints available at http://%s:%d/health and /ready\n", cfg.Gateway.Host, cfg.Gateway.Port) go agentLoop.Run(ctx) @@ -229,12 +227,11 @@ func gatewayCmd() { fmt.Println("\nShutting down...") cancel() msgBus.Close() - healthServer.Stop(context.Background()) + channelManager.StopAll(ctx) deviceService.Stop() heartbeatService.Stop() cronService.Stop() agentLoop.Stop() - channelManager.StopAll(ctx) fmt.Println("✓ Gateway stopped") } diff --git a/pkg/channels/line/line.go b/pkg/channels/line/line.go index fd06334d5..6ae048468 100644 --- a/pkg/channels/line/line.go +++ b/pkg/channels/line/line.go @@ -44,7 +44,6 @@ type replyTokenEntry struct { type LINEChannel struct { *channels.BaseChannel config config.LINEConfig - httpServer *http.Server botUserID string // Bot's user ID botBasicID string // Bot's basic ID (e.g. @216ru...) botDisplayName string // Bot's display name for text-based mention detection @@ -68,7 +67,7 @@ func NewLINEChannel(cfg config.LINEConfig, messageBus *bus.MessageBus) (*LINECha }, nil } -// Start launches the HTTP webhook server. +// Start initializes the LINE channel. 
func (c *LINEChannel) Start(ctx context.Context) error { logger.InfoC("line", "Starting LINE channel (Webhook Mode)") @@ -87,31 +86,6 @@ func (c *LINEChannel) Start(ctx context.Context) error { }) } - mux := http.NewServeMux() - path := c.config.WebhookPath - if path == "" { - path = "/webhook/line" - } - mux.HandleFunc(path, c.webhookHandler) - - addr := fmt.Sprintf("%s:%d", c.config.WebhookHost, c.config.WebhookPort) - c.httpServer = &http.Server{ - Addr: addr, - Handler: mux, - } - - go func() { - logger.InfoCF("line", "LINE webhook server listening", map[string]any{ - "addr": addr, - "path": path, - }) - if err := c.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { - logger.ErrorCF("line", "Webhook server error", map[string]any{ - "error": err.Error(), - }) - } - }() - c.SetRunning(true) logger.InfoC("line", "LINE channel started (Webhook Mode)") return nil @@ -151,7 +125,7 @@ func (c *LINEChannel) fetchBotInfo() error { return nil } -// Stop gracefully shuts down the HTTP server. +// Stop gracefully stops the LINE channel. func (c *LINEChannel) Stop(ctx context.Context) error { logger.InfoC("line", "Stopping LINE channel") @@ -159,21 +133,24 @@ func (c *LINEChannel) Stop(ctx context.Context) error { c.cancel() } - if c.httpServer != nil { - shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second) - defer cancel() - if err := c.httpServer.Shutdown(shutdownCtx); err != nil { - logger.ErrorCF("line", "Webhook server shutdown error", map[string]any{ - "error": err.Error(), - }) - } - } - c.SetRunning(false) logger.InfoC("line", "LINE channel stopped") return nil } +// WebhookPath returns the path for registering on the shared HTTP server. +func (c *LINEChannel) WebhookPath() string { + if c.config.WebhookPath != "" { + return c.config.WebhookPath + } + return "/webhook/line" +} + +// ServeHTTP implements http.Handler for the shared HTTP server. 
+func (c *LINEChannel) ServeHTTP(w http.ResponseWriter, r *http.Request) { + c.webhookHandler(w, r) +} + // webhookHandler handles incoming LINE webhook requests. func (c *LINEChannel) webhookHandler(w http.ResponseWriter, r *http.Request) { if r.Method != http.MethodPost { diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go index 1bc321cec..dadc068e9 100644 --- a/pkg/channels/manager.go +++ b/pkg/channels/manager.go @@ -11,6 +11,7 @@ import ( "errors" "fmt" "math" + "net/http" "sync" "time" @@ -19,6 +20,7 @@ import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/constants" + "github.com/sipeed/picoclaw/pkg/health" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/utils" @@ -55,6 +57,8 @@ type Manager struct { config *config.Config mediaStore media.MediaStore dispatchTask *asyncTask + mux *http.ServeMux + httpServer *http.Server mu sync.RWMutex } @@ -169,6 +173,43 @@ func (m *Manager) initChannels() error { return nil } +// SetupHTTPServer creates a shared HTTP server with the given listen address. +// It registers health endpoints from the health server and discovers channels +// that implement WebhookHandler and/or HealthChecker to register their handlers. 
+func (m *Manager) SetupHTTPServer(addr string, healthServer *health.Server) { + m.mux = http.NewServeMux() + + // Register health endpoints + if healthServer != nil { + healthServer.RegisterOnMux(m.mux) + } + + // Discover and register webhook handlers and health checkers + for name, ch := range m.channels { + if wh, ok := ch.(WebhookHandler); ok { + m.mux.Handle(wh.WebhookPath(), wh) + logger.InfoCF("channels", "Webhook handler registered", map[string]any{ + "channel": name, + "path": wh.WebhookPath(), + }) + } + if hc, ok := ch.(HealthChecker); ok { + m.mux.HandleFunc(hc.HealthPath(), hc.HealthHandler) + logger.InfoCF("channels", "Health endpoint registered", map[string]any{ + "channel": name, + "path": hc.HealthPath(), + }) + } + } + + m.httpServer = &http.Server{ + Addr: addr, + Handler: m.mux, + ReadTimeout: 30 * time.Second, + WriteTimeout: 30 * time.Second, + } +} + func (m *Manager) StartAll(ctx context.Context) error { m.mu.Lock() defer m.mu.Unlock() @@ -203,6 +244,20 @@ func (m *Manager) StartAll(ctx context.Context) error { // Start the dispatcher that reads from the bus and routes to workers go m.dispatchOutbound(dispatchCtx) + // Start shared HTTP server if configured + if m.httpServer != nil { + go func() { + logger.InfoCF("channels", "Shared HTTP server listening", map[string]any{ + "addr": m.httpServer.Addr, + }) + if err := m.httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { + logger.ErrorCF("channels", "Shared HTTP server error", map[string]any{ + "error": err.Error(), + }) + } + }() + } + logger.InfoC("channels", "All channels started") return nil } @@ -213,7 +268,19 @@ func (m *Manager) StopAll(ctx context.Context) error { logger.InfoC("channels", "Stopping all channels") - // Cancel dispatcher first + // Shutdown shared HTTP server first + if m.httpServer != nil { + shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + if err := m.httpServer.Shutdown(shutdownCtx); err != nil { + 
logger.ErrorCF("channels", "Shared HTTP server shutdown error", map[string]any{ + "error": err.Error(), + }) + } + m.httpServer = nil + } + + // Cancel dispatcher if m.dispatchTask != nil { m.dispatchTask.cancel() m.dispatchTask = nil diff --git a/pkg/channels/webhook.go b/pkg/channels/webhook.go new file mode 100644 index 000000000..3cf27baf6 --- /dev/null +++ b/pkg/channels/webhook.go @@ -0,0 +1,20 @@ +package channels + +import "net/http" + +// WebhookHandler is an optional interface for channels that receive messages +// via HTTP webhooks. Manager discovers channels implementing this interface +// and registers them on the shared HTTP server. +type WebhookHandler interface { + // WebhookPath returns the path to mount this handler on the shared server. + // Examples: "/webhook/line", "/webhook/wecom" + WebhookPath() string + http.Handler // ServeHTTP(w http.ResponseWriter, r *http.Request) +} + +// HealthChecker is an optional interface for channels that expose +// a health check endpoint on the shared HTTP server. 
+type HealthChecker interface { + HealthPath() string + HealthHandler(w http.ResponseWriter, r *http.Request) +} diff --git a/pkg/channels/wecom/app.go b/pkg/channels/wecom/app.go index 41861e8fc..52750505c 100644 --- a/pkg/channels/wecom/app.go +++ b/pkg/channels/wecom/app.go @@ -28,7 +28,6 @@ const ( type WeComAppChannel struct { *channels.BaseChannel config config.WeComAppConfig - server *http.Server accessToken string tokenExpiry time.Time tokenMu sync.RWMutex @@ -134,7 +133,7 @@ func (c *WeComAppChannel) Name() string { return "wecom_app" } -// Start initializes the WeCom App channel with HTTP webhook server +// Start initializes the WeCom App channel func (c *WeComAppChannel) Start(ctx context.Context) error { logger.InfoC("wecom_app", "Starting WeCom App channel...") @@ -150,37 +149,8 @@ func (c *WeComAppChannel) Start(ctx context.Context) error { // Start token refresh goroutine go c.tokenRefreshLoop() - // Setup HTTP server for webhook - mux := http.NewServeMux() - webhookPath := c.config.WebhookPath - if webhookPath == "" { - webhookPath = "/webhook/wecom-app" - } - mux.HandleFunc(webhookPath, c.handleWebhook) - - // Health check endpoint - mux.HandleFunc("/health/wecom-app", c.handleHealth) - - addr := fmt.Sprintf("%s:%d", c.config.WebhookHost, c.config.WebhookPort) - c.server = &http.Server{ - Addr: addr, - Handler: mux, - } - c.SetRunning(true) - logger.InfoCF("wecom_app", "WeCom App channel started", map[string]any{ - "address": addr, - "path": webhookPath, - }) - - // Start server in goroutine - go func() { - if err := c.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { - logger.ErrorCF("wecom_app", "HTTP server error", map[string]any{ - "error": err.Error(), - }) - } - }() + logger.InfoC("wecom_app", "WeCom App channel started") return nil } @@ -193,12 +163,6 @@ func (c *WeComAppChannel) Stop(ctx context.Context) error { c.cancel() } - if c.server != nil { - shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second) - defer 
cancel() - c.server.Shutdown(shutdownCtx) - } - c.SetRunning(false) logger.InfoC("wecom_app", "WeCom App channel stopped") return nil @@ -223,6 +187,29 @@ func (c *WeComAppChannel) Send(ctx context.Context, msg bus.OutboundMessage) err return c.sendTextMessage(ctx, accessToken, msg.ChatID, msg.Content) } +// WebhookPath returns the path for registering on the shared HTTP server. +func (c *WeComAppChannel) WebhookPath() string { + if c.config.WebhookPath != "" { + return c.config.WebhookPath + } + return "/webhook/wecom-app" +} + +// ServeHTTP implements http.Handler for the shared HTTP server. +func (c *WeComAppChannel) ServeHTTP(w http.ResponseWriter, r *http.Request) { + c.handleWebhook(w, r) +} + +// HealthPath returns the health check endpoint path. +func (c *WeComAppChannel) HealthPath() string { + return "/health/wecom-app" +} + +// HealthHandler handles health check requests. +func (c *WeComAppChannel) HealthHandler(w http.ResponseWriter, r *http.Request) { + c.handleHealth(w, r) +} + // handleWebhook handles incoming webhook requests from WeCom func (c *WeComAppChannel) handleWebhook(w http.ResponseWriter, r *http.Request) { ctx := r.Context() diff --git a/pkg/channels/wecom/bot.go b/pkg/channels/wecom/bot.go index 7960802fb..d5912bddc 100644 --- a/pkg/channels/wecom/bot.go +++ b/pkg/channels/wecom/bot.go @@ -24,7 +24,6 @@ import ( type WeComBotChannel struct { *channels.BaseChannel config config.WeComConfig - server *http.Server ctx context.Context cancel context.CancelFunc processedMsgs map[string]bool // Message deduplication: msg_id -> processed @@ -101,43 +100,14 @@ func (c *WeComBotChannel) Name() string { return "wecom" } -// Start initializes the WeCom Bot channel with HTTP webhook server +// Start initializes the WeCom Bot channel func (c *WeComBotChannel) Start(ctx context.Context) error { logger.InfoC("wecom", "Starting WeCom Bot channel...") c.ctx, c.cancel = context.WithCancel(ctx) - // Setup HTTP server for webhook - mux := http.NewServeMux() 
- webhookPath := c.config.WebhookPath - if webhookPath == "" { - webhookPath = "/webhook/wecom" - } - mux.HandleFunc(webhookPath, c.handleWebhook) - - // Health check endpoint - mux.HandleFunc("/health/wecom", c.handleHealth) - - addr := fmt.Sprintf("%s:%d", c.config.WebhookHost, c.config.WebhookPort) - c.server = &http.Server{ - Addr: addr, - Handler: mux, - } - c.SetRunning(true) - logger.InfoCF("wecom", "WeCom Bot channel started", map[string]any{ - "address": addr, - "path": webhookPath, - }) - - // Start server in goroutine - go func() { - if err := c.server.ListenAndServe(); err != nil && err != http.ErrServerClosed { - logger.ErrorCF("wecom", "HTTP server error", map[string]any{ - "error": err.Error(), - }) - } - }() + logger.InfoC("wecom", "WeCom Bot channel started") return nil } @@ -150,12 +120,6 @@ func (c *WeComBotChannel) Stop(ctx context.Context) error { c.cancel() } - if c.server != nil { - shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second) - defer cancel() - c.server.Shutdown(shutdownCtx) - } - c.SetRunning(false) logger.InfoC("wecom", "WeCom Bot channel stopped") return nil @@ -177,6 +141,29 @@ func (c *WeComBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) err return c.sendWebhookReply(ctx, msg.ChatID, msg.Content) } +// WebhookPath returns the path for registering on the shared HTTP server. +func (c *WeComBotChannel) WebhookPath() string { + if c.config.WebhookPath != "" { + return c.config.WebhookPath + } + return "/webhook/wecom" +} + +// ServeHTTP implements http.Handler for the shared HTTP server. +func (c *WeComBotChannel) ServeHTTP(w http.ResponseWriter, r *http.Request) { + c.handleWebhook(w, r) +} + +// HealthPath returns the health check endpoint path. +func (c *WeComBotChannel) HealthPath() string { + return "/health/wecom" +} + +// HealthHandler handles health check requests. 
+func (c *WeComBotChannel) HealthHandler(w http.ResponseWriter, r *http.Request) { + c.handleHealth(w, r) +} + // handleWebhook handles incoming webhook requests from WeCom func (c *WeComBotChannel) handleWebhook(w http.ResponseWriter, r *http.Request) { ctx := r.Context() diff --git a/pkg/health/server.go b/pkg/health/server.go index 77b36034d..de1ff60fe 100644 --- a/pkg/health/server.go +++ b/pkg/health/server.go @@ -156,6 +156,13 @@ func (s *Server) readyHandler(w http.ResponseWriter, r *http.Request) { }) } +// RegisterOnMux registers /health and /ready handlers onto the given mux. +// This allows the health endpoints to be served by a shared HTTP server. +func (s *Server) RegisterOnMux(mux *http.ServeMux) { + mux.HandleFunc("/health", s.healthHandler) + mux.HandleFunc("/ready", s.readyHandler) +} + func statusString(ok bool) string { if ok { return "ok" From 4c7a5df307627d94e69fa9477ca75844b35f51dc Mon Sep 17 00:00:00 2001 From: Hoshina Date: Mon, 23 Feb 2026 03:10:57 +0800 Subject: [PATCH 16/28] feat(channels): add MediaSender optional interface for outbound media Add outbound media sending capability so the agent can publish media attachments (images, files, audio, video) through channels via the bus. 
- Add MediaPart and OutboundMediaMessage types to bus - Add PublishOutboundMedia/SubscribeOutboundMedia bus methods - Add MediaSender interface discovered via type assertion by Manager - Add media dispatch/worker in Manager with shared retry logic - Extend ToolResult with Media field and MediaResult constructor - Publish outbound media from agent loop on tool results - Implement SendMedia for Telegram, Discord, Slack, LINE, OneBot, WeCom --- pkg/agent/loop.go | 13 ++ pkg/bus/bus.go | 41 +++++-- pkg/bus/types.go | 16 +++ pkg/channels/discord/discord.go | 98 +++++++++++++++ pkg/channels/line/line.go | 30 +++++ pkg/channels/manager.go | 150 +++++++++++++++++++++-- pkg/channels/media.go | 15 +++ pkg/channels/onebot/onebot.go | 111 +++++++++++++++++ pkg/channels/slack/slack.go | 54 +++++++++ pkg/channels/telegram/telegram.go | 85 +++++++++++++ pkg/channels/wecom/app.go | 194 ++++++++++++++++++++++++++++++ pkg/tools/result.go | 17 +++ 12 files changed, 809 insertions(+), 15 deletions(-) create mode 100644 pkg/channels/media.go diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index e243a6fdb..050303101 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -727,6 +727,19 @@ func (al *AgentLoop) runLLMIteration( }) } + // If tool returned media refs, publish them as outbound media + if len(toolResult.Media) > 0 && opts.SendResponse { + parts := make([]bus.MediaPart, 0, len(toolResult.Media)) + for _, ref := range toolResult.Media { + parts = append(parts, bus.MediaPart{Ref: ref}) + } + al.bus.PublishOutboundMedia(ctx, bus.OutboundMediaMessage{ + Channel: opts.Channel, + ChatID: opts.ChatID, + Parts: parts, + }) + } + // Determine content for LLM based on tool result contentForLLM := toolResult.ForLLM if contentForLLM == "" && toolResult.Err != nil { diff --git a/pkg/bus/bus.go b/pkg/bus/bus.go index 100ddc456..6a1c987b7 100644 --- a/pkg/bus/bus.go +++ b/pkg/bus/bus.go @@ -10,17 +10,19 @@ import ( var ErrBusClosed = errors.New("message bus closed") type MessageBus 
struct { - inbound chan InboundMessage - outbound chan OutboundMessage - done chan struct{} - closed atomic.Bool + inbound chan InboundMessage + outbound chan OutboundMessage + outboundMedia chan OutboundMediaMessage + done chan struct{} + closed atomic.Bool } func NewMessageBus() *MessageBus { return &MessageBus{ - inbound: make(chan InboundMessage, 100), - outbound: make(chan OutboundMessage, 100), - done: make(chan struct{}), + inbound: make(chan InboundMessage, 100), + outbound: make(chan OutboundMessage, 100), + outboundMedia: make(chan OutboundMediaMessage, 100), + done: make(chan struct{}), } } @@ -74,6 +76,31 @@ func (mb *MessageBus) SubscribeOutbound(ctx context.Context) (OutboundMessage, b } } +func (mb *MessageBus) PublishOutboundMedia(ctx context.Context, msg OutboundMediaMessage) error { + if mb.closed.Load() { + return ErrBusClosed + } + select { + case mb.outboundMedia <- msg: + return nil + case <-mb.done: + return ErrBusClosed + case <-ctx.Done(): + return ctx.Err() + } +} + +func (mb *MessageBus) SubscribeOutboundMedia(ctx context.Context) (OutboundMediaMessage, bool) { + select { + case msg, ok := <-mb.outboundMedia: + return msg, ok + case <-mb.done: + return OutboundMediaMessage{}, false + case <-ctx.Done(): + return OutboundMediaMessage{}, false + } +} + func (mb *MessageBus) Close() { if mb.closed.CompareAndSwap(false, true) { close(mb.done) diff --git a/pkg/bus/types.go b/pkg/bus/types.go index 358829c55..1a7a14170 100644 --- a/pkg/bus/types.go +++ b/pkg/bus/types.go @@ -24,3 +24,19 @@ type OutboundMessage struct { ChatID string `json:"chat_id"` Content string `json:"content"` } + +// MediaPart describes a single media attachment to send. +type MediaPart struct { + Type string `json:"type"` // "image" | "audio" | "video" | "file" + Ref string `json:"ref"` // media store ref, e.g. 
"media://abc123" + Caption string `json:"caption,omitempty"` // optional caption text + Filename string `json:"filename,omitempty"` // original filename hint + ContentType string `json:"content_type,omitempty"` // MIME type hint +} + +// OutboundMediaMessage carries media attachments from Agent to channels via the bus. +type OutboundMediaMessage struct { + Channel string `json:"channel"` + ChatID string `json:"chat_id"` + Parts []MediaPart `json:"parts"` +} diff --git a/pkg/channels/discord/discord.go b/pkg/channels/discord/discord.go index d5524f7f9..7987f45a9 100644 --- a/pkg/channels/discord/discord.go +++ b/pkg/channels/discord/discord.go @@ -3,6 +3,7 @@ package discord import ( "context" "fmt" + "os" "strings" "sync" "time" @@ -128,6 +129,103 @@ func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro return c.sendChunk(ctx, channelID, msg.Content) } +// SendMedia implements the channels.MediaSender interface. +func (c *DiscordChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error { + c.stopTyping(msg.ChatID) + + if !c.IsRunning() { + return channels.ErrNotRunning + } + + channelID := msg.ChatID + if channelID == "" { + return fmt.Errorf("channel ID is empty") + } + + store := c.GetMediaStore() + if store == nil { + return fmt.Errorf("no media store available: %w", channels.ErrSendFailed) + } + + // Collect all files into a single ChannelMessageSendComplex call + files := make([]*discordgo.File, 0, len(msg.Parts)) + var caption string + + for _, part := range msg.Parts { + localPath, err := store.Resolve(part.Ref) + if err != nil { + logger.ErrorCF("discord", "Failed to resolve media ref", map[string]any{ + "ref": part.Ref, + "error": err.Error(), + }) + continue + } + + file, err := os.Open(localPath) + if err != nil { + logger.ErrorCF("discord", "Failed to open media file", map[string]any{ + "path": localPath, + "error": err.Error(), + }) + continue + } + // Note: discordgo reads from the Reader and we can't close 
it before send + + filename := part.Filename + if filename == "" { + filename = "file" + } + + files = append(files, &discordgo.File{ + Name: filename, + ContentType: part.ContentType, + Reader: file, + }) + + if part.Caption != "" && caption == "" { + caption = part.Caption + } + } + + if len(files) == 0 { + return nil + } + + sendCtx, cancel := context.WithTimeout(ctx, sendTimeout) + defer cancel() + + done := make(chan error, 1) + go func() { + _, err := c.session.ChannelMessageSendComplex(channelID, &discordgo.MessageSend{ + Content: caption, + Files: files, + }) + done <- err + }() + + select { + case err := <-done: + // Close all file readers + for _, f := range files { + if closer, ok := f.Reader.(*os.File); ok { + closer.Close() + } + } + if err != nil { + return fmt.Errorf("discord send media: %w", channels.ErrTemporary) + } + return nil + case <-sendCtx.Done(): + // Close all file readers + for _, f := range files { + if closer, ok := f.Reader.(*os.File); ok { + closer.Close() + } + } + return sendCtx.Err() + } +} + func (c *DiscordChannel) sendChunk(ctx context.Context, channelID, content string) error { // Use the passed ctx for timeout control sendCtx, cancel := context.WithTimeout(ctx, sendTimeout) diff --git a/pkg/channels/line/line.go b/pkg/channels/line/line.go index 6ae048468..5b0af4f1d 100644 --- a/pkg/channels/line/line.go +++ b/pkg/channels/line/line.go @@ -496,6 +496,36 @@ func (c *LINEChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { return c.sendPush(ctx, msg.ChatID, msg.Content, quoteToken) } +// SendMedia implements the channels.MediaSender interface. +// LINE requires media to be accessible via public URL; since we only have local files, +// we fall back to sending a text message with the filename/caption. +// For full support, an external file hosting service would be needed. 
+func (c *LINEChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error { + if !c.IsRunning() { + return channels.ErrNotRunning + } + + store := c.GetMediaStore() + if store == nil { + return fmt.Errorf("no media store available: %w", channels.ErrSendFailed) + } + + // LINE Messaging API requires publicly accessible URLs for media messages. + // Since we only have local file paths, send caption text as fallback. + for _, part := range msg.Parts { + caption := part.Caption + if caption == "" { + caption = fmt.Sprintf("[%s: %s]", part.Type, part.Filename) + } + + if err := c.sendPush(ctx, msg.ChatID, caption, ""); err != nil { + return err + } + } + + return nil +} + // buildTextMessage creates a text message object, optionally with quoteToken. func buildTextMessage(content, quoteToken string) map[string]string { msg := map[string]string{ diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go index dadc068e9..92412edeb 100644 --- a/pkg/channels/manager.go +++ b/pkg/channels/manager.go @@ -44,10 +44,12 @@ var channelRateConfig = map[string]float64{ } type channelWorker struct { - ch Channel - queue chan bus.OutboundMessage - done chan struct{} - limiter *rate.Limiter + ch Channel + queue chan bus.OutboundMessage + mediaQueue chan bus.OutboundMediaMessage + done chan struct{} + mediaDone chan struct{} + limiter *rate.Limiter } type Manager struct { @@ -239,10 +241,12 @@ func (m *Manager) StartAll(ctx context.Context) error { // Start per-channel workers for name, w := range m.workers { go m.runWorker(dispatchCtx, name, w) + go m.runMediaWorker(dispatchCtx, name, w) } // Start the dispatcher that reads from the bus and routes to workers go m.dispatchOutbound(dispatchCtx) + go m.dispatchOutboundMedia(dispatchCtx) // Start shared HTTP server if configured if m.httpServer != nil { @@ -293,6 +297,13 @@ func (m *Manager) StopAll(ctx context.Context) error { for _, w := range m.workers { <-w.done } + // Close all media worker queues and wait for them 
to drain + for _, w := range m.workers { + close(w.mediaQueue) + } + for _, w := range m.workers { + <-w.mediaDone + } // Stop all channels for name, channel := range m.channels { @@ -321,10 +332,12 @@ func newChannelWorker(name string, ch Channel) *channelWorker { burst := int(math.Max(1, math.Ceil(rateVal/2))) return &channelWorker{ - ch: ch, - queue: make(chan bus.OutboundMessage, defaultChannelQueueSize), - done: make(chan struct{}), - limiter: rate.NewLimiter(rate.Limit(rateVal), burst), + ch: ch, + queue: make(chan bus.OutboundMessage, defaultChannelQueueSize), + mediaQueue: make(chan bus.OutboundMediaMessage, defaultChannelQueueSize), + done: make(chan struct{}), + mediaDone: make(chan struct{}), + limiter: rate.NewLimiter(rate.Limit(rateVal), burst), } } @@ -457,6 +470,125 @@ func (m *Manager) dispatchOutbound(ctx context.Context) { } } +func (m *Manager) dispatchOutboundMedia(ctx context.Context) { + logger.InfoC("channels", "Outbound media dispatcher started") + + for { + select { + case <-ctx.Done(): + logger.InfoC("channels", "Outbound media dispatcher stopped") + return + default: + msg, ok := m.bus.SubscribeOutboundMedia(ctx) + if !ok { + continue + } + + // Silently skip internal channels + if constants.IsInternalChannel(msg.Channel) { + continue + } + + m.mu.RLock() + _, exists := m.channels[msg.Channel] + w, wExists := m.workers[msg.Channel] + m.mu.RUnlock() + + if !exists { + logger.WarnCF("channels", "Unknown channel for outbound media message", map[string]any{ + "channel": msg.Channel, + }) + continue + } + + if wExists { + select { + case w.mediaQueue <- msg: + case <-ctx.Done(): + return + } + } + } + } +} + +// runMediaWorker processes outbound media messages for a single channel. 
+func (m *Manager) runMediaWorker(ctx context.Context, name string, w *channelWorker) { + defer close(w.mediaDone) + for { + select { + case msg, ok := <-w.mediaQueue: + if !ok { + return + } + m.sendMediaWithRetry(ctx, name, w, msg) + case <-ctx.Done(): + return + } + } +} + +// sendMediaWithRetry sends a media message through the channel with rate limiting and +// retry logic. If the channel does not implement MediaSender, it silently skips. +func (m *Manager) sendMediaWithRetry(ctx context.Context, name string, w *channelWorker, msg bus.OutboundMediaMessage) { + ms, ok := w.ch.(MediaSender) + if !ok { + logger.DebugCF("channels", "Channel does not support MediaSender, skipping media", map[string]any{ + "channel": name, + }) + return + } + + // Rate limit: wait for token + if err := w.limiter.Wait(ctx); err != nil { + return + } + + var lastErr error + for attempt := 0; attempt <= maxRetries; attempt++ { + lastErr = ms.SendMedia(ctx, msg) + if lastErr == nil { + return + } + + // Permanent failures — don't retry + if errors.Is(lastErr, ErrNotRunning) || errors.Is(lastErr, ErrSendFailed) { + break + } + + // Last attempt exhausted — don't sleep + if attempt == maxRetries { + break + } + + // Rate limit error — fixed delay + if errors.Is(lastErr, ErrRateLimit) { + select { + case <-time.After(rateLimitDelay): + continue + case <-ctx.Done(): + return + } + } + + // ErrTemporary or unknown error — exponential backoff + backoff := min(time.Duration(float64(baseBackoff)*math.Pow(2, float64(attempt))), maxBackoff) + select { + case <-time.After(backoff): + case <-ctx.Done(): + return + } + } + + // All retries exhausted or permanent failure + logger.ErrorCF("channels", "SendMedia failed", map[string]any{ + "channel": name, + "chat_id": msg.ChatID, + "error": lastErr.Error(), + "retries": maxRetries, + }) +} + func (m *Manager) GetChannel(name string) (Channel, bool) { m.mu.RLock() defer m.mu.RUnlock() @@ -502,6 +634,8 @@ func (m *Manager) UnregisterChannel(name string) 
{ if w, ok := m.workers[name]; ok { close(w.queue) <-w.done + close(w.mediaQueue) + <-w.mediaDone } delete(m.workers, name) delete(m.channels, name) diff --git a/pkg/channels/media.go b/pkg/channels/media.go new file mode 100644 index 000000000..c645a6180 --- /dev/null +++ b/pkg/channels/media.go @@ -0,0 +1,15 @@ +package channels + +import ( + "context" + + "github.com/sipeed/picoclaw/pkg/bus" +) + +// MediaSender is an optional interface for channels that can send +// media attachments (images, files, audio, video). +// Manager discovers channels implementing this interface via type +// assertion and routes OutboundMediaMessage to them. +type MediaSender interface { + SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error +} diff --git a/pkg/channels/onebot/onebot.go b/pkg/channels/onebot/onebot.go index 76950663e..fb357cf27 100644 --- a/pkg/channels/onebot/onebot.go +++ b/pkg/channels/onebot/onebot.go @@ -431,6 +431,117 @@ func (c *OneBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error return nil } +// SendMedia implements the channels.MediaSender interface. 
+func (c *OneBotChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error { + if !c.IsRunning() { + return channels.ErrNotRunning + } + + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + c.mu.Lock() + conn := c.conn + c.mu.Unlock() + + if conn == nil { + return fmt.Errorf("OneBot WebSocket not connected") + } + + store := c.GetMediaStore() + if store == nil { + return fmt.Errorf("no media store available: %w", channels.ErrSendFailed) + } + + // Build media segments + var segments []oneBotMessageSegment + for _, part := range msg.Parts { + localPath, err := store.Resolve(part.Ref) + if err != nil { + logger.ErrorCF("onebot", "Failed to resolve media ref", map[string]any{ + "ref": part.Ref, + "error": err.Error(), + }) + continue + } + + segType := "image" + switch part.Type { + case "image": + segType = "image" + case "video": + segType = "video" + case "audio": + segType = "record" + default: + segType = "file" + } + + segments = append(segments, oneBotMessageSegment{ + Type: segType, + Data: map[string]any{"file": "file://" + localPath}, + }) + + if part.Caption != "" { + segments = append(segments, oneBotMessageSegment{ + Type: "text", + Data: map[string]any{"text": part.Caption}, + }) + } + } + + if len(segments) == 0 { + return nil + } + + chatID := msg.ChatID + var action, idKey string + var rawID string + if rest, ok := strings.CutPrefix(chatID, "group:"); ok { + action, idKey, rawID = "send_group_msg", "group_id", rest + } else if rest, ok := strings.CutPrefix(chatID, "private:"); ok { + action, idKey, rawID = "send_private_msg", "user_id", rest + } else { + action, idKey, rawID = "send_private_msg", "user_id", chatID + } + + id, err := strconv.ParseInt(rawID, 10, 64) + if err != nil { + return fmt.Errorf("invalid %s in chatID: %s: %w", idKey, chatID, channels.ErrSendFailed) + } + + echo := fmt.Sprintf("send_%d", atomic.AddInt64(&c.echoCounter, 1)) + + req := oneBotAPIRequest{ + Action: action, + Params: 
map[string]any{idKey: id, "message": segments}, + Echo: echo, + } + + data, err := json.Marshal(req) + if err != nil { + return fmt.Errorf("failed to marshal OneBot request: %w", err) + } + + c.writeMu.Lock() + _ = conn.SetWriteDeadline(time.Now().Add(10 * time.Second)) + err = conn.WriteMessage(websocket.TextMessage, data) + _ = conn.SetWriteDeadline(time.Time{}) + c.writeMu.Unlock() + + if err != nil { + logger.ErrorCF("onebot", "Failed to send media message", map[string]any{ + "error": err.Error(), + }) + return fmt.Errorf("onebot send media: %w", channels.ErrTemporary) + } + + return nil +} + func (c *OneBotChannel) buildMessageSegments(chatID, content string) []oneBotMessageSegment { var segments []oneBotMessageSegment diff --git a/pkg/channels/slack/slack.go b/pkg/channels/slack/slack.go index 9e066e00a..f2dda15ac 100644 --- a/pkg/channels/slack/slack.go +++ b/pkg/channels/slack/slack.go @@ -149,6 +149,60 @@ func (c *SlackChannel) Send(ctx context.Context, msg bus.OutboundMessage) error return nil } +// SendMedia implements the channels.MediaSender interface. 
+func (c *SlackChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error { + if !c.IsRunning() { + return channels.ErrNotRunning + } + + channelID, _ := parseSlackChatID(msg.ChatID) + if channelID == "" { + return fmt.Errorf("invalid slack chat ID: %s", msg.ChatID) + } + + store := c.GetMediaStore() + if store == nil { + return fmt.Errorf("no media store available: %w", channels.ErrSendFailed) + } + + for _, part := range msg.Parts { + localPath, err := store.Resolve(part.Ref) + if err != nil { + logger.ErrorCF("slack", "Failed to resolve media ref", map[string]any{ + "ref": part.Ref, + "error": err.Error(), + }) + continue + } + + filename := part.Filename + if filename == "" { + filename = "file" + } + + title := part.Caption + if title == "" { + title = filename + } + + _, err = c.api.UploadFileV2Context(ctx, slack.UploadFileV2Parameters{ + Channel: channelID, + File: localPath, + Filename: filename, + Title: title, + }) + if err != nil { + logger.ErrorCF("slack", "Failed to upload media", map[string]any{ + "filename": filename, + "error": err.Error(), + }) + return fmt.Errorf("slack send media: %w", channels.ErrTemporary) + } + } + + return nil +} + func (c *SlackChannel) eventLoop() { for { select { diff --git a/pkg/channels/telegram/telegram.go b/pkg/channels/telegram/telegram.go index a07eb6579..f9390b8ed 100644 --- a/pkg/channels/telegram/telegram.go +++ b/pkg/channels/telegram/telegram.go @@ -231,6 +231,91 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err return nil } +// SendMedia implements the channels.MediaSender interface. 
+func (c *TelegramChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error { + if !c.IsRunning() { + return channels.ErrNotRunning + } + + chatID, err := parseChatID(msg.ChatID) + if err != nil { + return fmt.Errorf("invalid chat ID %s: %w", msg.ChatID, channels.ErrSendFailed) + } + + store := c.GetMediaStore() + if store == nil { + return fmt.Errorf("no media store available: %w", channels.ErrSendFailed) + } + + for _, part := range msg.Parts { + localPath, err := store.Resolve(part.Ref) + if err != nil { + logger.ErrorCF("telegram", "Failed to resolve media ref", map[string]any{ + "ref": part.Ref, + "error": err.Error(), + }) + continue + } + + file, err := os.Open(localPath) + if err != nil { + logger.ErrorCF("telegram", "Failed to open media file", map[string]any{ + "path": localPath, + "error": err.Error(), + }) + continue + } + + filename := part.Filename + if filename == "" { + filename = "file" + } + + switch part.Type { + case "image": + params := &telego.SendPhotoParams{ + ChatID: tu.ID(chatID), + Photo: telego.InputFile{File: file}, + Caption: part.Caption, + } + _, err = c.bot.SendPhoto(ctx, params) + case "audio": + params := &telego.SendAudioParams{ + ChatID: tu.ID(chatID), + Audio: telego.InputFile{File: file}, + Caption: part.Caption, + } + _, err = c.bot.SendAudio(ctx, params) + case "video": + params := &telego.SendVideoParams{ + ChatID: tu.ID(chatID), + Video: telego.InputFile{File: file}, + Caption: part.Caption, + } + _, err = c.bot.SendVideo(ctx, params) + default: // "file" or unknown types + params := &telego.SendDocumentParams{ + ChatID: tu.ID(chatID), + Document: telego.InputFile{File: file}, + Caption: part.Caption, + } + _, err = c.bot.SendDocument(ctx, params) + } + + file.Close() + + if err != nil { + logger.ErrorCF("telegram", "Failed to send media", map[string]any{ + "type": part.Type, + "error": err.Error(), + }) + return fmt.Errorf("telegram send media: %w", channels.ErrTemporary) + } + } + + return nil +} + func 
(c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Message) error { if message == nil { return fmt.Errorf("message is nil") diff --git a/pkg/channels/wecom/app.go b/pkg/channels/wecom/app.go index 52750505c..4c2a4d326 100644 --- a/pkg/channels/wecom/app.go +++ b/pkg/channels/wecom/app.go @@ -7,8 +7,11 @@ import ( "encoding/xml" "fmt" "io" + "mime/multipart" "net/http" "net/url" + "os" + "path/filepath" "strings" "sync" "time" @@ -187,6 +190,197 @@ func (c *WeComAppChannel) Send(ctx context.Context, msg bus.OutboundMessage) err return c.sendTextMessage(ctx, accessToken, msg.ChatID, msg.Content) } +// SendMedia implements the channels.MediaSender interface. +func (c *WeComAppChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error { + if !c.IsRunning() { + return channels.ErrNotRunning + } + + accessToken := c.getAccessToken() + if accessToken == "" { + return fmt.Errorf("no valid access token available: %w", channels.ErrTemporary) + } + + store := c.GetMediaStore() + if store == nil { + return fmt.Errorf("no media store available: %w", channels.ErrSendFailed) + } + + for _, part := range msg.Parts { + localPath, err := store.Resolve(part.Ref) + if err != nil { + logger.ErrorCF("wecom_app", "Failed to resolve media ref", map[string]any{ + "ref": part.Ref, + "error": err.Error(), + }) + continue + } + + // Map part type to WeCom media type + mediaType := "file" + switch part.Type { + case "image": + mediaType = "image" + case "audio": + mediaType = "voice" + case "video": + mediaType = "video" + default: + mediaType = "file" + } + + // Upload media to get media_id + mediaID, err := c.uploadMedia(ctx, accessToken, mediaType, localPath) + if err != nil { + logger.ErrorCF("wecom_app", "Failed to upload media", map[string]any{ + "type": mediaType, + "error": err.Error(), + }) + // Fallback: send caption as text + if part.Caption != "" { + _ = c.sendTextMessage(ctx, accessToken, msg.ChatID, part.Caption) + } + continue + } + + // 
Send media message using the media_id + if mediaType == "image" { + err = c.sendImageMessage(ctx, accessToken, msg.ChatID, mediaID) + } else { + // For non-image types, send as text fallback with caption + caption := part.Caption + if caption == "" { + caption = fmt.Sprintf("[%s: %s]", part.Type, part.Filename) + } + err = c.sendTextMessage(ctx, accessToken, msg.ChatID, caption) + } + + if err != nil { + return err + } + } + + return nil +} + +// uploadMedia uploads a local file to WeCom temporary media storage. +func (c *WeComAppChannel) uploadMedia(ctx context.Context, accessToken, mediaType, localPath string) (string, error) { + apiURL := fmt.Sprintf("%s/cgi-bin/media/upload?access_token=%s&type=%s", + wecomAPIBase, url.QueryEscape(accessToken), url.QueryEscape(mediaType)) + + file, err := os.Open(localPath) + if err != nil { + return "", fmt.Errorf("failed to open file: %w", err) + } + defer file.Close() + + body := &bytes.Buffer{} + writer := multipart.NewWriter(body) + + filename := filepath.Base(localPath) + formFile, err := writer.CreateFormFile("media", filename) + if err != nil { + return "", fmt.Errorf("failed to create form file: %w", err) + } + + if _, err = io.Copy(formFile, file); err != nil { + return "", fmt.Errorf("failed to copy file content: %w", err) + } + writer.Close() + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, body) + if err != nil { + return "", fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", writer.FormDataContentType()) + + client := &http.Client{Timeout: 30 * time.Second} + resp, err := client.Do(req) + if err != nil { + return "", channels.ClassifyNetError(err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + return "", channels.ClassifySendError(resp.StatusCode, fmt.Errorf("wecom upload error: %s", string(respBody))) + } + + var result struct { + ErrCode int `json:"errcode"` + ErrMsg string `json:"errmsg"` 
+ MediaID string `json:"media_id"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return "", fmt.Errorf("failed to parse upload response: %w", err) + } + + if result.ErrCode != 0 { + return "", fmt.Errorf("upload API error: %s (code: %d)", result.ErrMsg, result.ErrCode) + } + + return result.MediaID, nil +} + +// sendImageMessage sends an image message using a media_id. +func (c *WeComAppChannel) sendImageMessage(ctx context.Context, accessToken, userID, mediaID string) error { + apiURL := fmt.Sprintf("%s/cgi-bin/message/send?access_token=%s", wecomAPIBase, accessToken) + + msg := WeComImageMessage{ + ToUser: userID, + MsgType: "image", + AgentID: c.config.AgentID, + } + msg.Image.MediaID = mediaID + + jsonData, err := json.Marshal(msg) + if err != nil { + return fmt.Errorf("failed to marshal message: %w", err) + } + + timeout := c.config.ReplyTimeout + if timeout <= 0 { + timeout = 5 + } + + reqCtx, cancel := context.WithTimeout(ctx, time.Duration(timeout)*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, apiURL, bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + + client := &http.Client{Timeout: time.Duration(timeout) * time.Second} + resp, err := client.Do(req) + if err != nil { + return channels.ClassifyNetError(err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + return channels.ClassifySendError(resp.StatusCode, fmt.Errorf("wecom_app API error: %s", string(respBody))) + } + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("failed to read response: %w", err) + } + + var sendResp WeComSendMessageResponse + if err := json.Unmarshal(respBody, &sendResp); err != nil { + return fmt.Errorf("failed to parse response: %w", err) + } + + if sendResp.ErrCode != 0 { + return fmt.Errorf("API 
error: %s (code: %d)", sendResp.ErrMsg, sendResp.ErrCode) + } + + return nil +} + // WebhookPath returns the path for registering on the shared HTTP server. func (c *WeComAppChannel) WebhookPath() string { if c.config.WebhookPath != "" { diff --git a/pkg/tools/result.go b/pkg/tools/result.go index b13055b1c..cab833284 100644 --- a/pkg/tools/result.go +++ b/pkg/tools/result.go @@ -30,6 +30,10 @@ type ToolResult struct { // Err is the underlying error (not JSON serialized). // Used for internal error handling and logging. Err error `json:"-"` + + // Media contains media store refs produced by this tool. + // When non-empty, the agent will publish these as OutboundMediaMessage. + Media []string `json:"media,omitempty"` } // NewToolResult creates a basic ToolResult with content for the LLM. @@ -120,6 +124,19 @@ func UserResult(content string) *ToolResult { } } +// MediaResult creates a ToolResult with media refs for the user. +// The agent will publish these refs as OutboundMediaMessage. +// +// Example: +// +// result := MediaResult("Image generated successfully", []string{"media://abc123"}) +func MediaResult(forLLM string, mediaRefs []string) *ToolResult { + return &ToolResult{ + ForLLM: forLLM, + Media: mediaRefs, + } +} + // MarshalJSON implements custom JSON serialization. // The Err field is excluded from JSON output via the json:"-" tag. func (tr *ToolResult) MarshalJSON() ([]byte, error) { From 437657c5d55e1d249ce05b9662b22956e46c9d2b Mon Sep 17 00:00:00 2001 From: Hoshina Date: Mon, 23 Feb 2026 03:47:12 +0800 Subject: [PATCH 17/28] refactor(channels): remove channel-side voice transcription (Phase 12) Remove SetTranscriber and inline transcription logic from 4 channels (Telegram, Discord, Slack, OneBot) and the gateway wiring. Voice/audio files are still downloaded and stored in MediaStore with simple text annotations ([voice], [audio: filename], [file: name]). The pkg/voice package is preserved for future Agent-level transcription middleware. 
--- cmd/picoclaw/cmd_gateway.go | 44 ++------------------------- pkg/channels/discord/discord.go | 49 +++++++------------------------ pkg/channels/onebot/onebot.go | 27 ++--------------- pkg/channels/slack/slack.go | 23 +-------------- pkg/channels/telegram/telegram.go | 31 +------------------ 5 files changed, 17 insertions(+), 157 deletions(-) diff --git a/cmd/picoclaw/cmd_gateway.go b/cmd/picoclaw/cmd_gateway.go index 4e6ec8bb3..837c55f37 100644 --- a/cmd/picoclaw/cmd_gateway.go +++ b/cmd/picoclaw/cmd_gateway.go @@ -9,21 +9,20 @@ import ( "os" "os/signal" "path/filepath" - "strings" "time" "github.com/sipeed/picoclaw/pkg/agent" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" _ "github.com/sipeed/picoclaw/pkg/channels/dingtalk" - dch "github.com/sipeed/picoclaw/pkg/channels/discord" + _ "github.com/sipeed/picoclaw/pkg/channels/discord" _ "github.com/sipeed/picoclaw/pkg/channels/feishu" _ "github.com/sipeed/picoclaw/pkg/channels/line" _ "github.com/sipeed/picoclaw/pkg/channels/maixcam" _ "github.com/sipeed/picoclaw/pkg/channels/onebot" _ "github.com/sipeed/picoclaw/pkg/channels/qq" - slackch "github.com/sipeed/picoclaw/pkg/channels/slack" - tgramch "github.com/sipeed/picoclaw/pkg/channels/telegram" + _ "github.com/sipeed/picoclaw/pkg/channels/slack" + _ "github.com/sipeed/picoclaw/pkg/channels/telegram" _ "github.com/sipeed/picoclaw/pkg/channels/wecom" _ "github.com/sipeed/picoclaw/pkg/channels/whatsapp" "github.com/sipeed/picoclaw/pkg/config" @@ -36,7 +35,6 @@ import ( "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/state" "github.com/sipeed/picoclaw/pkg/tools" - "github.com/sipeed/picoclaw/pkg/voice" ) func gatewayCmd() { @@ -136,42 +134,6 @@ func gatewayCmd() { agentLoop.SetChannelManager(channelManager) agentLoop.SetMediaStore(mediaStore) - var transcriber *voice.GroqTranscriber - groqAPIKey := cfg.Providers.Groq.APIKey - if groqAPIKey == "" { - for _, mc := range cfg.ModelList { - if 
strings.HasPrefix(mc.Model, "groq/") && mc.APIKey != "" { - groqAPIKey = mc.APIKey - break - } - } - } - if groqAPIKey != "" { - transcriber = voice.NewGroqTranscriber(groqAPIKey) - logger.InfoC("voice", "Groq voice transcription enabled") - } - - if transcriber != nil { - if telegramChannel, ok := channelManager.GetChannel("telegram"); ok { - if tc, ok := telegramChannel.(*tgramch.TelegramChannel); ok { - tc.SetTranscriber(transcriber) - logger.InfoC("voice", "Groq transcription attached to Telegram channel") - } - } - if discordChannel, ok := channelManager.GetChannel("discord"); ok { - if dc, ok := discordChannel.(*dch.DiscordChannel); ok { - dc.SetTranscriber(transcriber) - logger.InfoC("voice", "Groq transcription attached to Discord channel") - } - } - if slackChannel, ok := channelManager.GetChannel("slack"); ok { - if sc, ok := slackChannel.(*slackch.SlackChannel); ok { - sc.SetTranscriber(transcriber) - logger.InfoC("voice", "Groq transcription attached to Slack channel") - } - } - } - enabledChannels := channelManager.GetEnabledChannels() if len(enabledChannels) > 0 { fmt.Printf("✓ Channels enabled: %s\n", enabledChannels) diff --git a/pkg/channels/discord/discord.go b/pkg/channels/discord/discord.go index 7987f45a9..68725b124 100644 --- a/pkg/channels/discord/discord.go +++ b/pkg/channels/discord/discord.go @@ -16,24 +16,21 @@ import ( "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/utils" - "github.com/sipeed/picoclaw/pkg/voice" ) const ( - transcriptionTimeout = 30 * time.Second - sendTimeout = 10 * time.Second + sendTimeout = 10 * time.Second ) type DiscordChannel struct { *channels.BaseChannel - session *discordgo.Session - config config.DiscordConfig - transcriber *voice.GroqTranscriber - ctx context.Context - cancel context.CancelFunc - typingMu sync.Mutex - typingStop map[string]chan struct{} // chatID → stop signal - botUserID string // stored for mention checking + session 
*discordgo.Session + config config.DiscordConfig + ctx context.Context + cancel context.CancelFunc + typingMu sync.Mutex + typingStop map[string]chan struct{} // chatID → stop signal + botUserID string // stored for mention checking } func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordChannel, error) { @@ -48,16 +45,11 @@ func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordC BaseChannel: base, session: session, config: cfg, - transcriber: nil, ctx: context.Background(), typingStop: make(map[string]chan struct{}), }, nil } -func (c *DiscordChannel) SetTranscriber(transcriber *voice.GroqTranscriber) { - c.transcriber = transcriber -} - func (c *DiscordChannel) Start(ctx context.Context) error { logger.InfoC("discord", "Starting Discord bot") @@ -265,7 +257,7 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag return } - // Check allowlist first to avoid downloading attachments and transcribing for rejected users + // Check allowlist first to avoid downloading attachments for rejected users if !c.IsAllowed(m.Author.ID) { logger.DebugCF("discord", "Message rejected by allowlist", map[string]any{ "user_id": m.Author.ID, @@ -323,29 +315,8 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag if isAudio { localPath := c.downloadAttachment(attachment.URL, attachment.Filename) if localPath != "" { - transcribedText := "" - if c.transcriber != nil && c.transcriber.IsAvailable() { - ctx, cancel := context.WithTimeout(c.ctx, transcriptionTimeout) - result, err := c.transcriber.Transcribe(ctx, localPath) - cancel() // Release context resources immediately to avoid leaks in for loop - - if err != nil { - logger.ErrorCF("discord", "Voice transcription failed", map[string]any{ - "error": err.Error(), - }) - transcribedText = fmt.Sprintf("[audio: %s (transcription failed)]", attachment.Filename) - } else { - transcribedText = fmt.Sprintf("[audio transcription: %s]", 
result.Text) - logger.DebugCF("discord", "Audio transcribed successfully", map[string]any{ - "text": result.Text, - }) - } - } else { - transcribedText = fmt.Sprintf("[audio: %s]", attachment.Filename) - } - mediaPaths = append(mediaPaths, storeMedia(localPath, attachment.Filename)) - content = appendContent(content, transcribedText) + content = appendContent(content, fmt.Sprintf("[audio: %s]", attachment.Filename)) } else { logger.WarnCF("discord", "Failed to download audio attachment", map[string]any{ "url": attachment.URL, diff --git a/pkg/channels/onebot/onebot.go b/pkg/channels/onebot/onebot.go index fb357cf27..001965238 100644 --- a/pkg/channels/onebot/onebot.go +++ b/pkg/channels/onebot/onebot.go @@ -18,7 +18,6 @@ import ( "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/utils" - "github.com/sipeed/picoclaw/pkg/voice" ) type OneBotChannel struct { @@ -36,7 +35,6 @@ type OneBotChannel struct { selfID int64 pending map[string]chan json.RawMessage pendingMu sync.Mutex - transcriber *voice.GroqTranscriber lastMessageID sync.Map pendingEmojiMsg sync.Map } @@ -112,10 +110,6 @@ func NewOneBotChannel(cfg config.OneBotConfig, messageBus *bus.MessageBus) (*One }, nil } -func (c *OneBotChannel) SetTranscriber(transcriber *voice.GroqTranscriber) { - c.transcriber = transcriber -} - func (c *OneBotChannel) setMsgEmojiLike(messageID string, emojiID int, set bool) { go func() { _, err := c.sendAPIRequest("set_msg_emoji_like", map[string]any{ @@ -794,25 +788,8 @@ func (c *OneBotChannel) parseMessageSegments( LoggerPrefix: "onebot", }) if localPath != "" { - if c.transcriber != nil && c.transcriber.IsAvailable() { - tctx, tcancel := context.WithTimeout(c.ctx, 30*time.Second) - result, err := c.transcriber.Transcribe(tctx, localPath) - tcancel() - if err != nil { - logger.WarnCF("onebot", "Voice transcription failed", map[string]any{ - "error": err.Error(), - }) - textParts = append(textParts, "[voice 
(transcription failed)]") - mediaRefs = append(mediaRefs, storeFile(localPath, "voice.amr")) - } else { - textParts = append(textParts, fmt.Sprintf("[voice transcription: %s]", result.Text)) - // Still store the file so it can be released later - storeFile(localPath, "voice.amr") - } - } else { - textParts = append(textParts, "[voice]") - mediaRefs = append(mediaRefs, storeFile(localPath, "voice.amr")) - } + textParts = append(textParts, "[voice]") + mediaRefs = append(mediaRefs, storeFile(localPath, "voice.amr")) } } } diff --git a/pkg/channels/slack/slack.go b/pkg/channels/slack/slack.go index f2dda15ac..a8d329d65 100644 --- a/pkg/channels/slack/slack.go +++ b/pkg/channels/slack/slack.go @@ -5,7 +5,6 @@ import ( "fmt" "strings" "sync" - "time" "github.com/slack-go/slack" "github.com/slack-go/slack/slackevents" @@ -17,7 +16,6 @@ import ( "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/utils" - "github.com/sipeed/picoclaw/pkg/voice" ) type SlackChannel struct { @@ -27,7 +25,6 @@ type SlackChannel struct { socketClient *socketmode.Client botUserID string teamID string - transcriber *voice.GroqTranscriber ctx context.Context cancel context.CancelFunc pendingAcks sync.Map @@ -60,10 +57,6 @@ func NewSlackChannel(cfg config.SlackConfig, messageBus *bus.MessageBus) (*Slack }, nil } -func (c *SlackChannel) SetTranscriber(transcriber *voice.GroqTranscriber) { - c.transcriber = transcriber -} - func (c *SlackChannel) Start(ctx context.Context) error { logger.InfoC("slack", "Starting Slack channel (Socket Mode)") @@ -311,21 +304,7 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { continue } mediaPaths = append(mediaPaths, storeMedia(localPath, file.Name)) - - if utils.IsAudioFile(file.Name, file.Mimetype) && c.transcriber != nil && c.transcriber.IsAvailable() { - ctx, cancel := context.WithTimeout(c.ctx, 30*time.Second) - defer cancel() - result, err := c.transcriber.Transcribe(ctx, 
localPath) - - if err != nil { - logger.ErrorCF("slack", "Voice transcription failed", map[string]any{"error": err.Error()}) - content += fmt.Sprintf("\n[audio: %s (transcription failed)]", file.Name) - } else { - content += fmt.Sprintf("\n[voice transcription: %s]", result.Text) - } - } else { - content += fmt.Sprintf("\n[file: %s]", file.Name) - } + content += fmt.Sprintf("\n[file: %s]", file.Name) } } diff --git a/pkg/channels/telegram/telegram.go b/pkg/channels/telegram/telegram.go index f9390b8ed..9544987ec 100644 --- a/pkg/channels/telegram/telegram.go +++ b/pkg/channels/telegram/telegram.go @@ -22,7 +22,6 @@ import ( "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/utils" - "github.com/sipeed/picoclaw/pkg/voice" ) type TelegramChannel struct { @@ -32,7 +31,6 @@ type TelegramChannel struct { commands TelegramCommander config *config.Config chatIDs map[string]int64 - transcriber *voice.GroqTranscriber ctx context.Context cancel context.CancelFunc placeholders sync.Map // chatID -> messageID @@ -91,16 +89,11 @@ func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChann bot: bot, config: cfg, chatIDs: make(map[string]int64), - transcriber: nil, placeholders: sync.Map{}, stopThinking: sync.Map{}, }, nil } -func (c *TelegramChannel) SetTranscriber(transcriber *voice.GroqTranscriber) { - c.transcriber = transcriber -} - func (c *TelegramChannel) Start(ctx context.Context) error { logger.InfoC("telegram", "Starting Telegram bot (polling mode)...") @@ -391,32 +384,10 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes if voicePath != "" { mediaPaths = append(mediaPaths, storeMedia(voicePath, "voice.ogg")) - transcribedText := "" - if c.transcriber != nil && c.transcriber.IsAvailable() { - transcriberCtx, cancel := context.WithTimeout(ctx, 30*time.Second) - defer cancel() - - result, err := c.transcriber.Transcribe(transcriberCtx, voicePath) - if err != 
nil { - logger.ErrorCF("telegram", "Voice transcription failed", map[string]any{ - "error": err.Error(), - "path": voicePath, - }) - transcribedText = "[voice (transcription failed)]" - } else { - transcribedText = fmt.Sprintf("[voice transcription: %s]", result.Text) - logger.InfoCF("telegram", "Voice transcribed successfully", map[string]any{ - "text": result.Text, - }) - } - } else { - transcribedText = "[voice]" - } - if content != "" { content += "\n" } - content += transcribedText + content += "[voice]" } } From 4c653c661db7e686ccd5a2a907ecb70a877fdaa6 Mon Sep 17 00:00:00 2001 From: Hoshina Date: Mon, 23 Feb 2026 04:11:11 +0800 Subject: [PATCH 18/28] refactor(channels): standardize group chat trigger filtering (Phase 8) Add unified ShouldRespondInGroup to BaseChannel, replacing scattered per-channel group filtering logic. Introduce GroupTriggerConfig (with mention_only + prefixes), TypingConfig, and PlaceholderConfig types. Migrate Discord MentionOnly, OneBot checkGroupTrigger, and LINE hardcoded mention-only to the shared mechanism. Add group trigger entry points for Slack, Telegram, QQ, Feishu, DingTalk, and WeCom. Legacy config fields are preserved with automatic migration. 
--- pkg/channels/base.go | 47 +++++++++ pkg/channels/base_test.go | 127 +++++++++++++++++++++- pkg/channels/dingtalk/dingtalk.go | 11 +- pkg/channels/discord/discord.go | 25 +++-- pkg/channels/feishu/feishu_64.go | 10 +- pkg/channels/line/line.go | 26 +++-- pkg/channels/onebot/onebot.go | 28 +---- pkg/channels/qq/qq.go | 11 +- pkg/channels/slack/slack.go | 14 ++- pkg/channels/telegram/telegram.go | 63 +++++++++++ pkg/channels/wecom/app.go | 5 +- pkg/channels/wecom/bot.go | 14 ++- pkg/config/config.go | 170 +++++++++++++++++++----------- pkg/config/defaults.go | 1 + 14 files changed, 446 insertions(+), 106 deletions(-) diff --git a/pkg/channels/base.go b/pkg/channels/base.go index adacb8c78..e345aedf0 100644 --- a/pkg/channels/base.go +++ b/pkg/channels/base.go @@ -8,6 +8,7 @@ import ( "github.com/google/uuid" "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/media" ) @@ -30,6 +31,11 @@ func WithMaxMessageLength(n int) BaseChannelOption { return func(c *BaseChannel) { c.maxMessageLength = n } } +// WithGroupTrigger sets the group trigger configuration for a channel. +func WithGroupTrigger(gt config.GroupTriggerConfig) BaseChannelOption { + return func(c *BaseChannel) { c.groupTrigger = gt } +} + // MessageLengthProvider is an opt-in interface that channels implement // to advertise their maximum message length. The Manager uses this via // type assertion to decide whether to split outbound messages. @@ -44,6 +50,7 @@ type BaseChannel struct { name string allowList []string maxMessageLength int + groupTrigger config.GroupTriggerConfig mediaStore media.MediaStore } @@ -72,6 +79,46 @@ func (c *BaseChannel) MaxMessageLength() int { return c.maxMessageLength } +// ShouldRespondInGroup determines whether the bot should respond in a group chat. +// Each channel is responsible for: +// 1. Detecting isMentioned (platform-specific) +// 2. Stripping bot mention from content (platform-specific) +// 3. 
Calling this method to get the group response decision +// +// Logic: +// - If isMentioned → always respond +// - If mention_only configured and not mentioned → ignore +// - If prefixes configured → respond if content starts with any prefix (strip it) +// - If prefixes configured but no match and not mentioned → ignore +// - Otherwise (no group_trigger configured) → respond to all (permissive default) +func (c *BaseChannel) ShouldRespondInGroup(isMentioned bool, content string) (bool, string) { + gt := c.groupTrigger + + // Mentioned → always respond + if isMentioned { + return true, strings.TrimSpace(content) + } + + // mention_only → require mention + if gt.MentionOnly { + return false, content + } + + // Prefix matching + if len(gt.Prefixes) > 0 { + for _, prefix := range gt.Prefixes { + if prefix != "" && strings.HasPrefix(content, prefix) { + return true, strings.TrimSpace(strings.TrimPrefix(content, prefix)) + } + } + // Prefixes configured but none matched and not mentioned → ignore + return false, content + } + + // No group_trigger configured → permissive (respond to all) + return true, strings.TrimSpace(content) +} + func (c *BaseChannel) Name() string { return c.name } diff --git a/pkg/channels/base_test.go b/pkg/channels/base_test.go index 78c6d1d66..e56ad3ee9 100644 --- a/pkg/channels/base_test.go +++ b/pkg/channels/base_test.go @@ -1,6 +1,10 @@ package channels -import "testing" +import ( + "testing" + + "github.com/sipeed/picoclaw/pkg/config" +) func TestBaseChannelIsAllowed(t *testing.T) { tests := []struct { @@ -50,3 +54,124 @@ func TestBaseChannelIsAllowed(t *testing.T) { }) } } + +func TestShouldRespondInGroup(t *testing.T) { + tests := []struct { + name string + gt config.GroupTriggerConfig + isMentioned bool + content string + wantRespond bool + wantContent string + }{ + { + name: "no config - permissive default", + gt: config.GroupTriggerConfig{}, + isMentioned: false, + content: "hello world", + wantRespond: true, + wantContent: "hello 
world", + }, + { + name: "no config - mentioned", + gt: config.GroupTriggerConfig{}, + isMentioned: true, + content: "hello world", + wantRespond: true, + wantContent: "hello world", + }, + { + name: "mention_only - not mentioned", + gt: config.GroupTriggerConfig{MentionOnly: true}, + isMentioned: false, + content: "hello world", + wantRespond: false, + wantContent: "hello world", + }, + { + name: "mention_only - mentioned", + gt: config.GroupTriggerConfig{MentionOnly: true}, + isMentioned: true, + content: "hello world", + wantRespond: true, + wantContent: "hello world", + }, + { + name: "prefix match", + gt: config.GroupTriggerConfig{Prefixes: []string{"/ask"}}, + isMentioned: false, + content: "/ask hello", + wantRespond: true, + wantContent: "hello", + }, + { + name: "prefix no match - not mentioned", + gt: config.GroupTriggerConfig{Prefixes: []string{"/ask"}}, + isMentioned: false, + content: "hello world", + wantRespond: false, + wantContent: "hello world", + }, + { + name: "prefix no match - but mentioned", + gt: config.GroupTriggerConfig{Prefixes: []string{"/ask"}}, + isMentioned: true, + content: "hello world", + wantRespond: true, + wantContent: "hello world", + }, + { + name: "multiple prefixes - second matches", + gt: config.GroupTriggerConfig{Prefixes: []string{"/ask", "/bot"}}, + isMentioned: false, + content: "/bot help me", + wantRespond: true, + wantContent: "help me", + }, + { + name: "mention_only with prefixes - mentioned overrides", + gt: config.GroupTriggerConfig{MentionOnly: true, Prefixes: []string{"/ask"}}, + isMentioned: true, + content: "hello", + wantRespond: true, + wantContent: "hello", + }, + { + name: "mention_only with prefixes - not mentioned, no prefix", + gt: config.GroupTriggerConfig{MentionOnly: true, Prefixes: []string{"/ask"}}, + isMentioned: false, + content: "hello", + wantRespond: false, + wantContent: "hello", + }, + { + name: "empty prefix in list is skipped", + gt: config.GroupTriggerConfig{Prefixes: []string{"", 
"/ask"}}, + isMentioned: false, + content: "/ask test", + wantRespond: true, + wantContent: "test", + }, + { + name: "prefix strips leading whitespace after prefix", + gt: config.GroupTriggerConfig{Prefixes: []string{"/ask "}}, + isMentioned: false, + content: "/ask hello", + wantRespond: true, + wantContent: "hello", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ch := NewBaseChannel("test", nil, nil, nil, WithGroupTrigger(tt.gt)) + gotRespond, gotContent := ch.ShouldRespondInGroup(tt.isMentioned, tt.content) + if gotRespond != tt.wantRespond { + t.Errorf("ShouldRespondInGroup() respond = %v, want %v", gotRespond, tt.wantRespond) + } + if gotContent != tt.wantContent { + t.Errorf("ShouldRespondInGroup() content = %q, want %q", gotContent, tt.wantContent) + } + }) + } +} diff --git a/pkg/channels/dingtalk/dingtalk.go b/pkg/channels/dingtalk/dingtalk.go index c49769761..b28bc850f 100644 --- a/pkg/channels/dingtalk/dingtalk.go +++ b/pkg/channels/dingtalk/dingtalk.go @@ -38,7 +38,10 @@ func NewDingTalkChannel(cfg config.DingTalkConfig, messageBus *bus.MessageBus) ( return nil, fmt.Errorf("dingtalk client_id and client_secret are required") } - base := channels.NewBaseChannel("dingtalk", cfg, messageBus, cfg.AllowFrom, channels.WithMaxMessageLength(20000)) + base := channels.NewBaseChannel("dingtalk", cfg, messageBus, cfg.AllowFrom, + channels.WithMaxMessageLength(20000), + channels.WithGroupTrigger(cfg.GroupTrigger), + ) return &DingTalkChannel{ BaseChannel: base, @@ -165,6 +168,12 @@ func (c *DingTalkChannel) onChatBotMessageReceived( peer = bus.Peer{Kind: "direct", ID: senderID} } else { peer = bus.Peer{Kind: "group", ID: data.ConversationId} + // In group chats, apply unified group trigger filtering + respond, cleaned := c.ShouldRespondInGroup(false, content) + if !respond { + return nil, nil + } + content = cleaned } logger.DebugCF("dingtalk", "Received message", map[string]any{ diff --git a/pkg/channels/discord/discord.go 
b/pkg/channels/discord/discord.go index 68725b124..4ef4906c1 100644 --- a/pkg/channels/discord/discord.go +++ b/pkg/channels/discord/discord.go @@ -39,7 +39,10 @@ func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordC return nil, fmt.Errorf("failed to create discord session: %w", err) } - base := channels.NewBaseChannel("discord", cfg, bus, cfg.AllowFrom, channels.WithMaxMessageLength(2000)) + base := channels.NewBaseChannel("discord", cfg, bus, cfg.AllowFrom, + channels.WithMaxMessageLength(2000), + channels.WithGroupTrigger(cfg.GroupTrigger), + ) return &DiscordChannel{ BaseChannel: base, @@ -265,9 +268,11 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag return } - // If configured to only respond to mentions, check if bot is mentioned - // Skip this check for DMs (GuildID is empty) - DMs should always be responded to - if c.config.MentionOnly && m.GuildID != "" { + content := m.Content + + // In guild (group) channels, apply unified group trigger filtering + // DMs (GuildID is empty) always get a response + if m.GuildID != "" { isMentioned := false for _, mention := range m.Mentions { if mention.ID == c.botUserID { @@ -275,12 +280,18 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag break } } - if !isMentioned { - logger.DebugCF("discord", "Message ignored - bot not mentioned", map[string]any{ + content = c.stripBotMention(content) + respond, cleaned := c.ShouldRespondInGroup(isMentioned, content) + if !respond { + logger.DebugCF("discord", "Group message ignored by group trigger", map[string]any{ "user_id": m.Author.ID, }) return } + content = cleaned + } else { + // DMs: just strip bot mention without filtering + content = c.stripBotMention(content) } senderID := m.Author.ID @@ -289,8 +300,6 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag senderName += "#" + m.Author.Discriminator } - content := m.Content - content = 
c.stripBotMention(content) mediaPaths := make([]string, 0, len(m.Attachments)) scope := channels.BuildMediaScope("discord", m.ChannelID, m.ID) diff --git a/pkg/channels/feishu/feishu_64.go b/pkg/channels/feishu/feishu_64.go index 5245cd99d..aaaf6cf1b 100644 --- a/pkg/channels/feishu/feishu_64.go +++ b/pkg/channels/feishu/feishu_64.go @@ -32,7 +32,9 @@ type FeishuChannel struct { } func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChannel, error) { - base := channels.NewBaseChannel("feishu", cfg, bus, cfg.AllowFrom) + base := channels.NewBaseChannel("feishu", cfg, bus, cfg.AllowFrom, + channels.WithGroupTrigger(cfg.GroupTrigger), + ) return &FeishuChannel{ BaseChannel: base, @@ -173,6 +175,12 @@ func (c *FeishuChannel) handleMessageReceive(_ context.Context, event *larkim.P2 peer = bus.Peer{Kind: "direct", ID: senderID} } else { peer = bus.Peer{Kind: "group", ID: chatID} + // In group chats, apply unified group trigger filtering + respond, cleaned := c.ShouldRespondInGroup(false, content) + if !respond { + return nil + } + content = cleaned } logger.InfoCF("feishu", "Feishu message received", map[string]any{ diff --git a/pkg/channels/line/line.go b/pkg/channels/line/line.go index 5b0af4f1d..a79931bc9 100644 --- a/pkg/channels/line/line.go +++ b/pkg/channels/line/line.go @@ -59,7 +59,10 @@ func NewLINEChannel(cfg config.LINEConfig, messageBus *bus.MessageBus) (*LINECha return nil, fmt.Errorf("line channel_secret and channel_access_token are required") } - base := channels.NewBaseChannel("line", cfg, messageBus, cfg.AllowFrom, channels.WithMaxMessageLength(5000)) + base := channels.NewBaseChannel("line", cfg, messageBus, cfg.AllowFrom, + channels.WithMaxMessageLength(5000), + channels.WithGroupTrigger(cfg.GroupTrigger), + ) return &LINEChannel{ BaseChannel: base, @@ -262,14 +265,6 @@ func (c *LINEChannel) processEvent(event lineEvent) { return } - // In group chats, only respond when the bot is mentioned - if isGroup && !c.isBotMentioned(msg) 
{ - logger.DebugCF("line", "Ignoring group message without mention", map[string]any{ - "chat_id": chatID, - }) - return - } - // Store reply token for later use if event.ReplyToken != "" { c.replyTokens.Store(chatID, replyTokenEntry{ @@ -339,6 +334,19 @@ func (c *LINEChannel) processEvent(event lineEvent) { return } + // In group chats, apply unified group trigger filtering + if isGroup { + isMentioned := c.isBotMentioned(msg) + respond, cleaned := c.ShouldRespondInGroup(isMentioned, content) + if !respond { + logger.DebugCF("line", "Ignoring group message by group trigger", map[string]any{ + "chat_id": chatID, + }) + return + } + content = cleaned + } + metadata := map[string]string{ "platform": "line", "source_type": event.Source.Type, diff --git a/pkg/channels/onebot/onebot.go b/pkg/channels/onebot/onebot.go index 001965238..f32cb4948 100644 --- a/pkg/channels/onebot/onebot.go +++ b/pkg/channels/onebot/onebot.go @@ -97,7 +97,9 @@ type oneBotMessageSegment struct { } func NewOneBotChannel(cfg config.OneBotConfig, messageBus *bus.MessageBus) (*OneBotChannel, error) { - base := channels.NewBaseChannel("onebot", cfg, messageBus, cfg.AllowFrom) + base := channels.NewBaseChannel("onebot", cfg, messageBus, cfg.AllowFrom, + channels.WithGroupTrigger(cfg.GroupTrigger), + ) const dedupSize = 1024 return &OneBotChannel{ @@ -996,8 +998,8 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) { metadata["sender_name"] = sender.Nickname } - triggered, strippedContent := c.checkGroupTrigger(content, isBotMentioned) - if !triggered { + respond, strippedContent := c.ShouldRespondInGroup(isBotMentioned, content) + if !respond { logger.DebugCF("onebot", "Group message ignored (no trigger)", map[string]any{ "sender": senderID, "group": groupIDStr, @@ -1069,23 +1071,3 @@ func truncate(s string, n int) string { } return string(runes[:n]) + "..." 
} - -func (c *OneBotChannel) checkGroupTrigger( - content string, - isBotMentioned bool, -) (triggered bool, strippedContent string) { - if isBotMentioned { - return true, strings.TrimSpace(content) - } - - for _, prefix := range c.config.GroupTriggerPrefix { - if prefix == "" { - continue - } - if strings.HasPrefix(content, prefix) { - return true, strings.TrimSpace(strings.TrimPrefix(content, prefix)) - } - } - - return false, content -} diff --git a/pkg/channels/qq/qq.go b/pkg/channels/qq/qq.go index 69f323e6e..011eb6c3c 100644 --- a/pkg/channels/qq/qq.go +++ b/pkg/channels/qq/qq.go @@ -32,7 +32,9 @@ type QQChannel struct { } func NewQQChannel(cfg config.QQConfig, messageBus *bus.MessageBus) (*QQChannel, error) { - base := channels.NewBaseChannel("qq", cfg, messageBus, cfg.AllowFrom) + base := channels.NewBaseChannel("qq", cfg, messageBus, cfg.AllowFrom, + channels.WithGroupTrigger(cfg.GroupTrigger), + ) return &QQChannel{ BaseChannel: base, @@ -204,6 +206,13 @@ func (c *QQChannel) handleGroupATMessage() event.GroupATMessageEventHandler { return nil } + // GroupAT event means bot is always mentioned; apply group trigger filtering + respond, cleaned := c.ShouldRespondInGroup(true, content) + if !respond { + return nil + } + content = cleaned + logger.InfoCF("qq", "Received group AT message", map[string]any{ "sender": senderID, "group": data.GroupID, diff --git a/pkg/channels/slack/slack.go b/pkg/channels/slack/slack.go index a8d329d65..6fba2e0b4 100644 --- a/pkg/channels/slack/slack.go +++ b/pkg/channels/slack/slack.go @@ -47,7 +47,10 @@ func NewSlackChannel(cfg config.SlackConfig, messageBus *bus.MessageBus) (*Slack socketClient := socketmode.New(api) - base := channels.NewBaseChannel("slack", cfg, messageBus, cfg.AllowFrom, channels.WithMaxMessageLength(40000)) + base := channels.NewBaseChannel("slack", cfg, messageBus, cfg.AllowFrom, + channels.WithMaxMessageLength(40000), + channels.WithGroupTrigger(cfg.GroupTrigger), + ) return &SlackChannel{ BaseChannel: 
base, @@ -279,6 +282,15 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { content := ev.Text content = c.stripBotMention(content) + // In non-DM channels, apply group trigger filtering + if !strings.HasPrefix(channelID, "D") { + respond, cleaned := c.ShouldRespondInGroup(false, content) + if !respond { + return + } + content = cleaned + } + var mediaPaths []string scope := channels.BuildMediaScope("slack", chatID, messageTS) diff --git a/pkg/channels/telegram/telegram.go b/pkg/channels/telegram/telegram.go index 9544987ec..c5c055163 100644 --- a/pkg/channels/telegram/telegram.go +++ b/pkg/channels/telegram/telegram.go @@ -81,6 +81,7 @@ func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChann bus, telegramCfg.AllowFrom, channels.WithMaxMessageLength(4096), + channels.WithGroupTrigger(telegramCfg.GroupTrigger), ) return &TelegramChannel{ @@ -417,6 +418,19 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes content = "[empty message]" } + // In group chats, apply unified group trigger filtering + if message.Chat.Type != "private" { + isMentioned := c.isBotMentioned(message) + if isMentioned { + content = c.stripBotMention(content) + } + respond, cleaned := c.ShouldRespondInGroup(isMentioned, content) + if !respond { + return nil + } + content = cleaned + } + logger.DebugCF("telegram", "Received message", map[string]any{ "sender_id": senderID, "chat_id": fmt.Sprintf("%d", chatID), @@ -629,3 +643,52 @@ func escapeHTML(text string) string { text = strings.ReplaceAll(text, ">", ">") return text } + +// isBotMentioned checks if the bot is mentioned in the message via entities. 
+func (c *TelegramChannel) isBotMentioned(message *telego.Message) bool { + botUsername := c.bot.Username() + if botUsername == "" { + return false + } + + entities := message.Entities + if entities == nil { + entities = message.CaptionEntities + } + + for _, entity := range entities { + if entity.Type == "mention" { + // Extract the mention text from the message + text := message.Text + if text == "" { + text = message.Caption + } + runes := []rune(text) + end := entity.Offset + entity.Length + if end <= len(runes) { + mention := string(runes[entity.Offset:end]) + if strings.EqualFold(mention, "@"+botUsername) { + return true + } + } + } + if entity.Type == "text_mention" && entity.User != nil { + if entity.User.Username == botUsername { + return true + } + } + } + return false +} + +// stripBotMention removes the @bot mention from the content. +func (c *TelegramChannel) stripBotMention(content string) string { + botUsername := c.bot.Username() + if botUsername == "" { + return content + } + // Case-insensitive replacement + re := regexp.MustCompile(`(?i)@` + regexp.QuoteMeta(botUsername)) + content = re.ReplaceAllString(content, "") + return strings.TrimSpace(content) +} diff --git a/pkg/channels/wecom/app.go b/pkg/channels/wecom/app.go index 4c2a4d326..53b53ffb8 100644 --- a/pkg/channels/wecom/app.go +++ b/pkg/channels/wecom/app.go @@ -122,7 +122,10 @@ func NewWeComAppChannel(cfg config.WeComAppConfig, messageBus *bus.MessageBus) ( return nil, fmt.Errorf("wecom_app corp_id, corp_secret and agent_id are required") } - base := channels.NewBaseChannel("wecom_app", cfg, messageBus, cfg.AllowFrom, channels.WithMaxMessageLength(2048)) + base := channels.NewBaseChannel("wecom_app", cfg, messageBus, cfg.AllowFrom, + channels.WithMaxMessageLength(2048), + channels.WithGroupTrigger(cfg.GroupTrigger), + ) return &WeComAppChannel{ BaseChannel: base, diff --git a/pkg/channels/wecom/bot.go b/pkg/channels/wecom/bot.go index d5912bddc..7ffe4734b 100644 --- 
a/pkg/channels/wecom/bot.go +++ b/pkg/channels/wecom/bot.go @@ -86,7 +86,10 @@ func NewWeComBotChannel(cfg config.WeComConfig, messageBus *bus.MessageBus) (*We return nil, fmt.Errorf("wecom token and webhook_url are required") } - base := channels.NewBaseChannel("wecom", cfg, messageBus, cfg.AllowFrom, channels.WithMaxMessageLength(2048)) + base := channels.NewBaseChannel("wecom", cfg, messageBus, cfg.AllowFrom, + channels.WithMaxMessageLength(2048), + channels.WithGroupTrigger(cfg.GroupTrigger), + ) return &WeComBotChannel{ BaseChannel: base, @@ -367,6 +370,15 @@ func (c *WeComBotChannel) processMessage(ctx context.Context, msg WeComBotMessag // Build metadata peer := bus.Peer{Kind: peerKind, ID: peerID} + // In group chats, apply unified group trigger filtering + if isGroupChat { + respond, cleaned := c.ShouldRespondInGroup(false, content) + if !respond { + return + } + content = cleaned + } + metadata := map[string]string{ "msg_type": msg.MsgType, "msg_id": msg.MsgID, diff --git a/pkg/config/config.go b/pkg/config/config.go index 2595398c7..cf768a79e 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -194,6 +194,23 @@ type ChannelsConfig struct { WeComApp WeComAppConfig `json:"wecom_app"` } +// GroupTriggerConfig controls when the bot responds in group chats. +type GroupTriggerConfig struct { + MentionOnly bool `json:"mention_only,omitempty"` + Prefixes []string `json:"prefixes,omitempty"` +} + +// TypingConfig controls typing indicator behavior (Phase 10). +type TypingConfig struct { + Enabled bool `json:"enabled,omitempty"` +} + +// PlaceholderConfig controls placeholder message behavior (Phase 10). 
+type PlaceholderConfig struct { + Enabled bool `json:"enabled,omitempty"` + Text string `json:"text,omitempty"` +} + type WhatsAppConfig struct { Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WHATSAPP_ENABLED"` BridgeURL string `json:"bridge_url" env:"PICOCLAW_CHANNELS_WHATSAPP_BRIDGE_URL"` @@ -201,26 +218,33 @@ type WhatsAppConfig struct { } type TelegramConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_TELEGRAM_ENABLED"` - Token string `json:"token" env:"PICOCLAW_CHANNELS_TELEGRAM_TOKEN"` - Proxy string `json:"proxy" env:"PICOCLAW_CHANNELS_TELEGRAM_PROXY"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_TELEGRAM_ALLOW_FROM"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_TELEGRAM_ENABLED"` + Token string `json:"token" env:"PICOCLAW_CHANNELS_TELEGRAM_TOKEN"` + Proxy string `json:"proxy" env:"PICOCLAW_CHANNELS_TELEGRAM_PROXY"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_TELEGRAM_ALLOW_FROM"` + GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` + Typing TypingConfig `json:"typing,omitempty"` + Placeholder PlaceholderConfig `json:"placeholder,omitempty"` } type FeishuConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_FEISHU_ENABLED"` - AppID string `json:"app_id" env:"PICOCLAW_CHANNELS_FEISHU_APP_ID"` - AppSecret string `json:"app_secret" env:"PICOCLAW_CHANNELS_FEISHU_APP_SECRET"` - EncryptKey string `json:"encrypt_key" env:"PICOCLAW_CHANNELS_FEISHU_ENCRYPT_KEY"` - VerificationToken string `json:"verification_token" env:"PICOCLAW_CHANNELS_FEISHU_VERIFICATION_TOKEN"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_FEISHU_ALLOW_FROM"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_FEISHU_ENABLED"` + AppID string `json:"app_id" env:"PICOCLAW_CHANNELS_FEISHU_APP_ID"` + AppSecret string `json:"app_secret" env:"PICOCLAW_CHANNELS_FEISHU_APP_SECRET"` + EncryptKey string `json:"encrypt_key" 
env:"PICOCLAW_CHANNELS_FEISHU_ENCRYPT_KEY"` + VerificationToken string `json:"verification_token" env:"PICOCLAW_CHANNELS_FEISHU_VERIFICATION_TOKEN"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_FEISHU_ALLOW_FROM"` + GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` } type DiscordConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DISCORD_ENABLED"` - Token string `json:"token" env:"PICOCLAW_CHANNELS_DISCORD_TOKEN"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_DISCORD_ALLOW_FROM"` - MentionOnly bool `json:"mention_only" env:"PICOCLAW_CHANNELS_DISCORD_MENTION_ONLY"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DISCORD_ENABLED"` + Token string `json:"token" env:"PICOCLAW_CHANNELS_DISCORD_TOKEN"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_DISCORD_ALLOW_FROM"` + MentionOnly bool `json:"mention_only" env:"PICOCLAW_CHANNELS_DISCORD_MENTION_ONLY"` + GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` + Typing TypingConfig `json:"typing,omitempty"` + Placeholder PlaceholderConfig `json:"placeholder,omitempty"` } type MaixCamConfig struct { @@ -231,69 +255,82 @@ type MaixCamConfig struct { } type QQConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_QQ_ENABLED"` - AppID string `json:"app_id" env:"PICOCLAW_CHANNELS_QQ_APP_ID"` - AppSecret string `json:"app_secret" env:"PICOCLAW_CHANNELS_QQ_APP_SECRET"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_QQ_ALLOW_FROM"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_QQ_ENABLED"` + AppID string `json:"app_id" env:"PICOCLAW_CHANNELS_QQ_APP_ID"` + AppSecret string `json:"app_secret" env:"PICOCLAW_CHANNELS_QQ_APP_SECRET"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_QQ_ALLOW_FROM"` + GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` } type DingTalkConfig struct { - Enabled bool `json:"enabled" 
env:"PICOCLAW_CHANNELS_DINGTALK_ENABLED"` - ClientID string `json:"client_id" env:"PICOCLAW_CHANNELS_DINGTALK_CLIENT_ID"` - ClientSecret string `json:"client_secret" env:"PICOCLAW_CHANNELS_DINGTALK_CLIENT_SECRET"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_DINGTALK_ALLOW_FROM"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DINGTALK_ENABLED"` + ClientID string `json:"client_id" env:"PICOCLAW_CHANNELS_DINGTALK_CLIENT_ID"` + ClientSecret string `json:"client_secret" env:"PICOCLAW_CHANNELS_DINGTALK_CLIENT_SECRET"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_DINGTALK_ALLOW_FROM"` + GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` } type SlackConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_SLACK_ENABLED"` - BotToken string `json:"bot_token" env:"PICOCLAW_CHANNELS_SLACK_BOT_TOKEN"` - AppToken string `json:"app_token" env:"PICOCLAW_CHANNELS_SLACK_APP_TOKEN"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_SLACK_ALLOW_FROM"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_SLACK_ENABLED"` + BotToken string `json:"bot_token" env:"PICOCLAW_CHANNELS_SLACK_BOT_TOKEN"` + AppToken string `json:"app_token" env:"PICOCLAW_CHANNELS_SLACK_APP_TOKEN"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_SLACK_ALLOW_FROM"` + GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` + Typing TypingConfig `json:"typing,omitempty"` + Placeholder PlaceholderConfig `json:"placeholder,omitempty"` } type LINEConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_LINE_ENABLED"` - ChannelSecret string `json:"channel_secret" env:"PICOCLAW_CHANNELS_LINE_CHANNEL_SECRET"` - ChannelAccessToken string `json:"channel_access_token" env:"PICOCLAW_CHANNELS_LINE_CHANNEL_ACCESS_TOKEN"` - WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_LINE_WEBHOOK_HOST"` - WebhookPort int `json:"webhook_port" 
env:"PICOCLAW_CHANNELS_LINE_WEBHOOK_PORT"` - WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_LINE_WEBHOOK_PATH"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_LINE_ALLOW_FROM"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_LINE_ENABLED"` + ChannelSecret string `json:"channel_secret" env:"PICOCLAW_CHANNELS_LINE_CHANNEL_SECRET"` + ChannelAccessToken string `json:"channel_access_token" env:"PICOCLAW_CHANNELS_LINE_CHANNEL_ACCESS_TOKEN"` + WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_LINE_WEBHOOK_HOST"` + WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_LINE_WEBHOOK_PORT"` + WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_LINE_WEBHOOK_PATH"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_LINE_ALLOW_FROM"` + GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` + Typing TypingConfig `json:"typing,omitempty"` + Placeholder PlaceholderConfig `json:"placeholder,omitempty"` } type OneBotConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_ONEBOT_ENABLED"` - WSUrl string `json:"ws_url" env:"PICOCLAW_CHANNELS_ONEBOT_WS_URL"` - AccessToken string `json:"access_token" env:"PICOCLAW_CHANNELS_ONEBOT_ACCESS_TOKEN"` - ReconnectInterval int `json:"reconnect_interval" env:"PICOCLAW_CHANNELS_ONEBOT_RECONNECT_INTERVAL"` - GroupTriggerPrefix []string `json:"group_trigger_prefix" env:"PICOCLAW_CHANNELS_ONEBOT_GROUP_TRIGGER_PREFIX"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_ONEBOT_ALLOW_FROM"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_ONEBOT_ENABLED"` + WSUrl string `json:"ws_url" env:"PICOCLAW_CHANNELS_ONEBOT_WS_URL"` + AccessToken string `json:"access_token" env:"PICOCLAW_CHANNELS_ONEBOT_ACCESS_TOKEN"` + ReconnectInterval int `json:"reconnect_interval" env:"PICOCLAW_CHANNELS_ONEBOT_RECONNECT_INTERVAL"` + GroupTriggerPrefix []string `json:"group_trigger_prefix" 
env:"PICOCLAW_CHANNELS_ONEBOT_GROUP_TRIGGER_PREFIX"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_ONEBOT_ALLOW_FROM"` + GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` + Typing TypingConfig `json:"typing,omitempty"` + Placeholder PlaceholderConfig `json:"placeholder,omitempty"` } type WeComConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_ENABLED"` - Token string `json:"token" env:"PICOCLAW_CHANNELS_WECOM_TOKEN"` - EncodingAESKey string `json:"encoding_aes_key" env:"PICOCLAW_CHANNELS_WECOM_ENCODING_AES_KEY"` - WebhookURL string `json:"webhook_url" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_URL"` - WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_HOST"` - WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_PORT"` - WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_PATH"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_ALLOW_FROM"` - ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_REPLY_TIMEOUT"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_ENABLED"` + Token string `json:"token" env:"PICOCLAW_CHANNELS_WECOM_TOKEN"` + EncodingAESKey string `json:"encoding_aes_key" env:"PICOCLAW_CHANNELS_WECOM_ENCODING_AES_KEY"` + WebhookURL string `json:"webhook_url" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_URL"` + WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_HOST"` + WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_PORT"` + WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_PATH"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_ALLOW_FROM"` + ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_REPLY_TIMEOUT"` + GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` } type WeComAppConfig struct { - Enabled bool `json:"enabled" 
env:"PICOCLAW_CHANNELS_WECOM_APP_ENABLED"` - CorpID string `json:"corp_id" env:"PICOCLAW_CHANNELS_WECOM_APP_CORP_ID"` - CorpSecret string `json:"corp_secret" env:"PICOCLAW_CHANNELS_WECOM_APP_CORP_SECRET"` - AgentID int64 `json:"agent_id" env:"PICOCLAW_CHANNELS_WECOM_APP_AGENT_ID"` - Token string `json:"token" env:"PICOCLAW_CHANNELS_WECOM_APP_TOKEN"` - EncodingAESKey string `json:"encoding_aes_key" env:"PICOCLAW_CHANNELS_WECOM_APP_ENCODING_AES_KEY"` - WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_HOST"` - WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_PORT"` - WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_PATH"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_APP_ALLOW_FROM"` - ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_APP_REPLY_TIMEOUT"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_APP_ENABLED"` + CorpID string `json:"corp_id" env:"PICOCLAW_CHANNELS_WECOM_APP_CORP_ID"` + CorpSecret string `json:"corp_secret" env:"PICOCLAW_CHANNELS_WECOM_APP_CORP_SECRET"` + AgentID int64 `json:"agent_id" env:"PICOCLAW_CHANNELS_WECOM_APP_AGENT_ID"` + Token string `json:"token" env:"PICOCLAW_CHANNELS_WECOM_APP_TOKEN"` + EncodingAESKey string `json:"encoding_aes_key" env:"PICOCLAW_CHANNELS_WECOM_APP_ENCODING_AES_KEY"` + WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_HOST"` + WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_PORT"` + WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_PATH"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_APP_ALLOW_FROM"` + ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_APP_REPLY_TIMEOUT"` + GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` } type HeartbeatConfig struct { @@ -507,6 +544,9 @@ func LoadConfig(path string) 
(*Config, error) { return nil, err } + // Migrate legacy channel config fields to new unified structures + cfg.migrateChannelConfigs() + // Auto-migrate: if only legacy providers config exists, convert to model_list if len(cfg.ModelList) == 0 && cfg.HasProvidersConfig() { cfg.ModelList = ConvertProvidersToModelList(cfg) @@ -520,6 +560,18 @@ func LoadConfig(path string) (*Config, error) { return cfg, nil } +func (c *Config) migrateChannelConfigs() { + // Discord: mention_only -> group_trigger.mention_only + if c.Channels.Discord.MentionOnly && !c.Channels.Discord.GroupTrigger.MentionOnly { + c.Channels.Discord.GroupTrigger.MentionOnly = true + } + + // OneBot: group_trigger_prefix -> group_trigger.prefixes + if len(c.Channels.OneBot.GroupTriggerPrefix) > 0 && len(c.Channels.OneBot.GroupTrigger.Prefixes) == 0 { + c.Channels.OneBot.GroupTrigger.Prefixes = c.Channels.OneBot.GroupTriggerPrefix + } +} + func SaveConfig(path string, cfg *Config) error { data, err := json.MarshalIndent(cfg, "", " ") if err != nil { diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go index b96ee4d89..03ad2ab6b 100644 --- a/pkg/config/defaults.go +++ b/pkg/config/defaults.go @@ -80,6 +80,7 @@ func DefaultConfig() *Config { WebhookPort: 18791, WebhookPath: "/webhook/line", AllowFrom: FlexibleStringSlice{}, + GroupTrigger: GroupTriggerConfig{MentionOnly: true}, }, OneBot: OneBotConfig{ Enabled: false, From 90b4a6468311c4f7bdcb148052924c2bc1d436c6 Mon Sep 17 00:00:00 2001 From: Hoshina Date: Mon, 23 Feb 2026 04:55:15 +0800 Subject: [PATCH 19/28] feat(channels): add typing/placeholder automation and Pico Protocol channel (Phase 10 + 7) Phase 10: Define TypingCapable, MessageEditor, PlaceholderRecorder interfaces. Manager orchestrates outbound typing stop and placeholder editing via preSend. Migrate Telegram, Discord, Slack, OneBot to register state with Manager instead of handling locally in Send. 
Phase 7: Add native WebSocket Pico Protocol channel as reference implementation of all optional capability interfaces. --- cmd/picoclaw/cmd_gateway.go | 1 + pkg/channels/base.go | 27 +- pkg/channels/discord/discord.go | 14 +- pkg/channels/interfaces.go | 24 ++ pkg/channels/manager.go | 56 ++++ pkg/channels/manager_test.go | 216 +++++++++++++++ pkg/channels/onebot/onebot.go | 13 +- pkg/channels/pico/init.go | 13 + pkg/channels/pico/pico.go | 430 ++++++++++++++++++++++++++++++ pkg/channels/pico/protocol.go | 46 ++++ pkg/channels/slack/slack.go | 24 ++ pkg/channels/telegram/telegram.go | 113 +++----- pkg/config/config.go | 12 + pkg/config/defaults.go | 14 + 14 files changed, 913 insertions(+), 90 deletions(-) create mode 100644 pkg/channels/interfaces.go create mode 100644 pkg/channels/pico/init.go create mode 100644 pkg/channels/pico/pico.go create mode 100644 pkg/channels/pico/protocol.go diff --git a/cmd/picoclaw/cmd_gateway.go b/cmd/picoclaw/cmd_gateway.go index 837c55f37..33217492d 100644 --- a/cmd/picoclaw/cmd_gateway.go +++ b/cmd/picoclaw/cmd_gateway.go @@ -20,6 +20,7 @@ import ( _ "github.com/sipeed/picoclaw/pkg/channels/line" _ "github.com/sipeed/picoclaw/pkg/channels/maixcam" _ "github.com/sipeed/picoclaw/pkg/channels/onebot" + _ "github.com/sipeed/picoclaw/pkg/channels/pico" _ "github.com/sipeed/picoclaw/pkg/channels/qq" _ "github.com/sipeed/picoclaw/pkg/channels/slack" _ "github.com/sipeed/picoclaw/pkg/channels/telegram" diff --git a/pkg/channels/base.go b/pkg/channels/base.go index e345aedf0..c22a27eb9 100644 --- a/pkg/channels/base.go +++ b/pkg/channels/base.go @@ -44,14 +44,15 @@ type MessageLengthProvider interface { } type BaseChannel struct { - config any - bus *bus.MessageBus - running atomic.Bool - name string - allowList []string - maxMessageLength int - groupTrigger config.GroupTriggerConfig - mediaStore media.MediaStore + config any + bus *bus.MessageBus + running atomic.Bool + name string + allowList []string + maxMessageLength int + 
groupTrigger config.GroupTriggerConfig + mediaStore media.MediaStore + placeholderRecorder PlaceholderRecorder } func NewBaseChannel( @@ -203,6 +204,16 @@ func (c *BaseChannel) SetMediaStore(s media.MediaStore) { c.mediaStore = s } // GetMediaStore returns the injected MediaStore (may be nil). func (c *BaseChannel) GetMediaStore() media.MediaStore { return c.mediaStore } +// SetPlaceholderRecorder injects a PlaceholderRecorder into the channel. +func (c *BaseChannel) SetPlaceholderRecorder(r PlaceholderRecorder) { + c.placeholderRecorder = r +} + +// GetPlaceholderRecorder returns the injected PlaceholderRecorder (may be nil). +func (c *BaseChannel) GetPlaceholderRecorder() PlaceholderRecorder { + return c.placeholderRecorder +} + // BuildMediaScope constructs a scope key for media lifecycle tracking. func BuildMediaScope(channel, chatID, messageID string) string { id := messageID diff --git a/pkg/channels/discord/discord.go b/pkg/channels/discord/discord.go index 4ef4906c1..ee698da61 100644 --- a/pkg/channels/discord/discord.go +++ b/pkg/channels/discord/discord.go @@ -106,8 +106,6 @@ func (c *DiscordChannel) Stop(ctx context.Context) error { } func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { - c.stopTyping(msg.ChatID) - if !c.IsRunning() { return channels.ErrNotRunning } @@ -126,8 +124,6 @@ func (c *DiscordChannel) Send(ctx context.Context, msg bus.OutboundMessage) erro // SendMedia implements the channels.MediaSender interface. func (c *DiscordChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error { - c.stopTyping(msg.ChatID) - if !c.IsRunning() { return channels.ErrNotRunning } @@ -221,6 +217,12 @@ func (c *DiscordChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMes } } +// EditMessage implements channels.MessageEditor. 
+func (c *DiscordChannel) EditMessage(ctx context.Context, chatID string, messageID string, content string) error { + _, err := c.session.ChannelMessageEdit(chatID, messageID, content) + return err +} + func (c *DiscordChannel) sendChunk(ctx context.Context, channelID, content string) error { // Use the passed ctx for timeout control sendCtx, cancel := context.WithTimeout(ctx, sendTimeout) @@ -350,6 +352,10 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag // Start typing after all early returns — guaranteed to have a matching Send() c.startTyping(m.ChannelID) + // Register typing stop with Manager for outbound orchestration + if rec := c.GetPlaceholderRecorder(); rec != nil { + rec.RecordTypingStop("discord", m.ChannelID, func() { c.stopTyping(m.ChannelID) }) + } logger.DebugCF("discord", "Received message", map[string]any{ "sender_name": senderName, diff --git a/pkg/channels/interfaces.go b/pkg/channels/interfaces.go new file mode 100644 index 000000000..32bfe95f8 --- /dev/null +++ b/pkg/channels/interfaces.go @@ -0,0 +1,24 @@ +package channels + +import "context" + +// TypingCapable — channels that can show a typing/thinking indicator. +// StartTyping begins the indicator and returns a stop function. +// The stop function MUST be idempotent and safe to call multiple times. +type TypingCapable interface { + StartTyping(ctx context.Context, chatID string) (stop func(), err error) +} + +// MessageEditor — channels that can edit an existing message. +// messageID is always string; channels convert platform-specific types internally. +type MessageEditor interface { + EditMessage(ctx context.Context, chatID string, messageID string, content string) error +} + +// PlaceholderRecorder is injected into channels by Manager. +// Channels call these methods on inbound to register typing/placeholder state. +// Manager uses the registered state on outbound to stop typing and edit placeholders. 
+type PlaceholderRecorder interface { + RecordPlaceholder(channel, chatID, placeholderID string) + RecordTypingStop(channel, chatID string, stop func()) +} diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go index 92412edeb..4b1a43b7b 100644 --- a/pkg/channels/manager.go +++ b/pkg/channels/manager.go @@ -62,12 +62,55 @@ type Manager struct { mux *http.ServeMux httpServer *http.Server mu sync.RWMutex + placeholders sync.Map // "channel:chatID" → placeholderID (string) + typingStops sync.Map // "channel:chatID" → func() } type asyncTask struct { cancel context.CancelFunc } +// RecordPlaceholder registers a placeholder message for later editing. +// Implements PlaceholderRecorder. +func (m *Manager) RecordPlaceholder(channel, chatID, placeholderID string) { + key := channel + ":" + chatID + m.placeholders.Store(key, placeholderID) +} + +// RecordTypingStop registers a typing stop function for later invocation. +// Implements PlaceholderRecorder. +func (m *Manager) RecordTypingStop(channel, chatID string, stop func()) { + key := channel + ":" + chatID + m.typingStops.Store(key, stop) +} + +// preSend handles typing stop and placeholder editing before sending a message. +// Returns true if the message was edited into a placeholder (skip Send). +func (m *Manager) preSend(ctx context.Context, name string, msg bus.OutboundMessage, ch Channel) bool { + key := name + ":" + msg.ChatID + + // 1. Stop typing + if v, loaded := m.typingStops.LoadAndDelete(key); loaded { + if stop, ok := v.(func()); ok { + stop() // idempotent, safe + } + } + + // 2. 
Try editing placeholder + if v, loaded := m.placeholders.LoadAndDelete(key); loaded { + if placeholderID, ok := v.(string); ok && placeholderID != "" { + if editor, ok := ch.(MessageEditor); ok { + if err := editor.EditMessage(ctx, msg.ChatID, placeholderID, msg.Content); err == nil { + return true // edited successfully, skip Send + } + // edit failed → fall through to normal Send + } + } + } + + return false +} + func NewManager(cfg *config.Config, messageBus *bus.MessageBus, store media.MediaStore) (*Manager, error) { m := &Manager{ channels: make(map[string]Channel), @@ -109,6 +152,10 @@ func (m *Manager) initChannel(name, displayName string) { setter.SetMediaStore(m.mediaStore) } } + // Inject PlaceholderRecorder if channel supports it + if setter, ok := ch.(interface{ SetPlaceholderRecorder(PlaceholderRecorder) }); ok { + setter.SetPlaceholderRecorder(m) + } m.channels[name] = ch m.workers[name] = newChannelWorker(name, ch) logger.InfoCF("channels", "Channel enabled successfully", map[string]any{ @@ -168,6 +215,10 @@ func (m *Manager) initChannels() error { m.initChannel("wecom_app", "WeCom App") } + if m.config.Channels.Pico.Enabled && m.config.Channels.Pico.Token != "" { + m.initChannel("pico", "Pico") + } + logger.InfoCF("channels", "Channel initialization completed", map[string]any{ "enabled_channels": len(m.channels), }) @@ -383,6 +434,11 @@ func (m *Manager) sendWithRetry(ctx context.Context, name string, w *channelWork return } + // Pre-send: stop typing and try to edit placeholder + if m.preSend(ctx, name, msg, w.ch) { + return // placeholder was edited successfully, skip Send + } + var lastErr error for attempt := 0; attempt <= maxRetries; attempt++ { lastErr = w.ch.Send(ctx, msg) diff --git a/pkg/channels/manager_test.go b/pkg/channels/manager_test.go index 162c9f8c9..0573c0a8e 100644 --- a/pkg/channels/manager_test.go +++ b/pkg/channels/manager_test.go @@ -416,3 +416,219 @@ func TestSendWithRetry_ExponentialBackoff(t *testing.T) { 
t.Fatalf("expected %d calls, got %d", maxRetries+1, callCount.Load()) } } + +// --- Phase 10: preSend orchestration tests --- + +// mockMessageEditor is a channel that supports MessageEditor. +type mockMessageEditor struct { + mockChannel + editFn func(ctx context.Context, chatID, messageID, content string) error +} + +func (m *mockMessageEditor) EditMessage(ctx context.Context, chatID, messageID, content string) error { + return m.editFn(ctx, chatID, messageID, content) +} + +func TestPreSend_PlaceholderEditSuccess(t *testing.T) { + m := newTestManager() + var sendCalled bool + var editCalled bool + + ch := &mockMessageEditor{ + mockChannel: mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + sendCalled = true + return nil + }, + }, + editFn: func(_ context.Context, chatID, messageID, content string) error { + editCalled = true + if chatID != "123" { + t.Fatalf("expected chatID 123, got %s", chatID) + } + if messageID != "456" { + t.Fatalf("expected messageID 456, got %s", messageID) + } + if content != "hello" { + t.Fatalf("expected content 'hello', got %s", content) + } + return nil + }, + } + + // Register placeholder + m.RecordPlaceholder("test", "123", "456") + + msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"} + edited := m.preSend(context.Background(), "test", msg, ch) + + if !edited { + t.Fatal("expected preSend to return true (placeholder edited)") + } + if !editCalled { + t.Fatal("expected EditMessage to be called") + } + if sendCalled { + t.Fatal("expected Send to NOT be called when placeholder edited") + } +} + +func TestPreSend_PlaceholderEditFails_FallsThrough(t *testing.T) { + m := newTestManager() + + ch := &mockMessageEditor{ + mockChannel: mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + return nil + }, + }, + editFn: func(_ context.Context, _, _, _ string) error { + return fmt.Errorf("edit failed") + }, + } + + m.RecordPlaceholder("test", "123", "456") + + msg 
:= bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"} + edited := m.preSend(context.Background(), "test", msg, ch) + + if edited { + t.Fatal("expected preSend to return false when edit fails") + } +} + +func TestPreSend_TypingStopCalled(t *testing.T) { + m := newTestManager() + var stopCalled bool + + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + return nil + }, + } + + m.RecordTypingStop("test", "123", func() { + stopCalled = true + }) + + msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"} + m.preSend(context.Background(), "test", msg, ch) + + if !stopCalled { + t.Fatal("expected typing stop func to be called") + } +} + +func TestPreSend_NoRegisteredState(t *testing.T) { + m := newTestManager() + + ch := &mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + return nil + }, + } + + msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"} + edited := m.preSend(context.Background(), "test", msg, ch) + + if edited { + t.Fatal("expected preSend to return false with no registered state") + } +} + +func TestPreSend_TypingAndPlaceholder(t *testing.T) { + m := newTestManager() + var stopCalled bool + var editCalled bool + + ch := &mockMessageEditor{ + mockChannel: mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + return nil + }, + }, + editFn: func(_ context.Context, _, _, _ string) error { + editCalled = true + return nil + }, + } + + m.RecordTypingStop("test", "123", func() { + stopCalled = true + }) + m.RecordPlaceholder("test", "123", "456") + + msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"} + edited := m.preSend(context.Background(), "test", msg, ch) + + if !stopCalled { + t.Fatal("expected typing stop to be called") + } + if !editCalled { + t.Fatal("expected EditMessage to be called") + } + if !edited { + t.Fatal("expected preSend to return true") + } +} + +func 
TestRecordPlaceholder_ConcurrentSafe(t *testing.T) { + m := newTestManager() + + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + chatID := fmt.Sprintf("chat_%d", i%10) + m.RecordPlaceholder("test", chatID, fmt.Sprintf("msg_%d", i)) + }(i) + } + wg.Wait() +} + +func TestRecordTypingStop_ConcurrentSafe(t *testing.T) { + m := newTestManager() + + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + chatID := fmt.Sprintf("chat_%d", i%10) + m.RecordTypingStop("test", chatID, func() {}) + }(i) + } + wg.Wait() +} + +func TestSendWithRetry_PreSendEditsPlaceholder(t *testing.T) { + m := newTestManager() + var sendCalled bool + + ch := &mockMessageEditor{ + mockChannel: mockChannel{ + sendFn: func(_ context.Context, _ bus.OutboundMessage) error { + sendCalled = true + return nil + }, + }, + editFn: func(_ context.Context, _, _, _ string) error { + return nil // edit succeeds + }, + } + + m.RecordPlaceholder("test", "123", "456") + + w := &channelWorker{ + ch: ch, + limiter: rate.NewLimiter(rate.Inf, 1), + } + + msg := bus.OutboundMessage{Channel: "test", ChatID: "123", Content: "hello"} + m.sendWithRetry(context.Background(), "test", w, msg) + + if sendCalled { + t.Fatal("expected Send to NOT be called when placeholder was edited") + } +} diff --git a/pkg/channels/onebot/onebot.go b/pkg/channels/onebot/onebot.go index f32cb4948..682025b67 100644 --- a/pkg/channels/onebot/onebot.go +++ b/pkg/channels/onebot/onebot.go @@ -418,12 +418,6 @@ func (c *OneBotChannel) Send(ctx context.Context, msg bus.OutboundMessage) error return fmt.Errorf("onebot send: %w", channels.ErrTemporary) } - if msgID, ok := c.pendingEmojiMsg.LoadAndDelete(msg.ChatID); ok { - if mid, ok := msgID.(string); ok && mid != "" { - c.setMsgEmojiLike(mid, 289, false) - } - } - return nil } @@ -1037,6 +1031,13 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) { if raw.MessageType == "group" && 
messageID != "" && messageID != "0" { c.setMsgEmojiLike(messageID, 289, true) c.pendingEmojiMsg.Store(chatID, messageID) + // Register emoji stop with Manager for outbound orchestration + if rec := c.GetPlaceholderRecorder(); rec != nil { + capturedMsgID := messageID + rec.RecordTypingStop("onebot", chatID, func() { + c.setMsgEmojiLike(capturedMsgID, 289, false) + }) + } } c.HandleMessage(peer, messageID, senderID, chatID, content, parsed.Media, metadata) diff --git a/pkg/channels/pico/init.go b/pkg/channels/pico/init.go new file mode 100644 index 000000000..96d764418 --- /dev/null +++ b/pkg/channels/pico/init.go @@ -0,0 +1,13 @@ +package pico + +import ( + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" +) + +func init() { + channels.RegisterFactory("pico", func(cfg *config.Config, b *bus.MessageBus) (channels.Channel, error) { + return NewPicoChannel(cfg.Channels.Pico, b) + }) +} diff --git a/pkg/channels/pico/pico.go b/pkg/channels/pico/pico.go new file mode 100644 index 000000000..1c28ca732 --- /dev/null +++ b/pkg/channels/pico/pico.go @@ -0,0 +1,430 @@ +package pico + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/google/uuid" + "github.com/gorilla/websocket" + + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" +) + +// picoConn represents a single WebSocket connection. +type picoConn struct { + id string + conn *websocket.Conn + sessionID string + writeMu sync.Mutex + closed atomic.Bool +} + +// writeJSON sends a JSON message to the connection with write locking. +func (pc *picoConn) writeJSON(v any) error { + if pc.closed.Load() { + return fmt.Errorf("connection closed") + } + pc.writeMu.Lock() + defer pc.writeMu.Unlock() + return pc.conn.WriteJSON(v) +} + +// close closes the connection. 
+func (pc *picoConn) close() { + if pc.closed.CompareAndSwap(false, true) { + pc.conn.Close() + } +} + +// PicoChannel implements the native Pico Protocol WebSocket channel. +// It serves as the reference implementation for all optional capability interfaces. +type PicoChannel struct { + *channels.BaseChannel + config config.PicoConfig + upgrader websocket.Upgrader + connections sync.Map // connID → *picoConn + connCount atomic.Int32 + ctx context.Context + cancel context.CancelFunc +} + +// NewPicoChannel creates a new Pico Protocol channel. +func NewPicoChannel(cfg config.PicoConfig, messageBus *bus.MessageBus) (*PicoChannel, error) { + if cfg.Token == "" { + return nil, fmt.Errorf("pico token is required") + } + + base := channels.NewBaseChannel("pico", cfg, messageBus, cfg.AllowFrom) + + allowOrigins := cfg.AllowOrigins + checkOrigin := func(r *http.Request) bool { + if len(allowOrigins) == 0 { + return true // allow all if not configured + } + origin := r.Header.Get("Origin") + for _, allowed := range allowOrigins { + if allowed == "*" || allowed == origin { + return true + } + } + return false + } + + return &PicoChannel{ + BaseChannel: base, + config: cfg, + upgrader: websocket.Upgrader{ + CheckOrigin: checkOrigin, + ReadBufferSize: 1024, + WriteBufferSize: 1024, + }, + }, nil +} + +// Start implements Channel. +func (c *PicoChannel) Start(ctx context.Context) error { + logger.InfoC("pico", "Starting Pico Protocol channel") + c.ctx, c.cancel = context.WithCancel(ctx) + c.SetRunning(true) + logger.InfoC("pico", "Pico Protocol channel started") + return nil +} + +// Stop implements Channel. 
+func (c *PicoChannel) Stop(ctx context.Context) error { + logger.InfoC("pico", "Stopping Pico Protocol channel") + c.SetRunning(false) + + // Close all connections + c.connections.Range(func(key, value any) bool { + if pc, ok := value.(*picoConn); ok { + pc.close() + } + c.connections.Delete(key) + return true + }) + + if c.cancel != nil { + c.cancel() + } + + logger.InfoC("pico", "Pico Protocol channel stopped") + return nil +} + +// WebhookPath implements channels.WebhookHandler. +func (c *PicoChannel) WebhookPath() string { return "/pico/" } + +// ServeHTTP implements http.Handler for the shared HTTP server. +func (c *PicoChannel) ServeHTTP(w http.ResponseWriter, r *http.Request) { + path := strings.TrimPrefix(r.URL.Path, "/pico") + + switch { + case path == "/ws" || path == "/ws/": + c.handleWebSocket(w, r) + default: + http.NotFound(w, r) + } +} + +// Send implements Channel — sends a message to the appropriate WebSocket connection. +func (c *PicoChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { + if !c.IsRunning() { + return channels.ErrNotRunning + } + + outMsg := newMessage(TypeMessageCreate, map[string]any{ + "content": msg.Content, + }) + + return c.broadcastToSession(msg.ChatID, outMsg) +} + +// EditMessage implements channels.MessageEditor. +func (c *PicoChannel) EditMessage(ctx context.Context, chatID string, messageID string, content string) error { + outMsg := newMessage(TypeMessageUpdate, map[string]any{ + "message_id": messageID, + "content": content, + }) + return c.broadcastToSession(chatID, outMsg) +} + +// StartTyping implements channels.TypingCapable. 
+func (c *PicoChannel) StartTyping(ctx context.Context, chatID string) (func(), error) { + startMsg := newMessage(TypeTypingStart, nil) + if err := c.broadcastToSession(chatID, startMsg); err != nil { + return func() {}, err + } + return func() { + stopMsg := newMessage(TypeTypingStop, nil) + c.broadcastToSession(chatID, stopMsg) + }, nil +} + +// broadcastToSession sends a message to all connections with a matching session. +func (c *PicoChannel) broadcastToSession(chatID string, msg PicoMessage) error { + // chatID format: "pico:" + sessionID := strings.TrimPrefix(chatID, "pico:") + msg.SessionID = sessionID + + var sent bool + c.connections.Range(func(key, value any) bool { + pc, ok := value.(*picoConn) + if !ok { + return true + } + if pc.sessionID == sessionID { + if err := pc.writeJSON(msg); err != nil { + logger.DebugCF("pico", "Write to connection failed", map[string]any{ + "conn_id": pc.id, + "error": err.Error(), + }) + } else { + sent = true + } + } + return true + }) + + if !sent { + return fmt.Errorf("no active connections for session %s: %w", sessionID, channels.ErrSendFailed) + } + return nil +} + +// handleWebSocket upgrades the HTTP connection and manages the WebSocket lifecycle. 
+func (c *PicoChannel) handleWebSocket(w http.ResponseWriter, r *http.Request) { + if !c.IsRunning() { + http.Error(w, "channel not running", http.StatusServiceUnavailable) + return + } + + // Authenticate + if !c.authenticate(r) { + http.Error(w, "unauthorized", http.StatusUnauthorized) + return + } + + // Check connection limit + maxConns := c.config.MaxConnections + if maxConns <= 0 { + maxConns = 100 + } + if int(c.connCount.Load()) >= maxConns { + http.Error(w, "too many connections", http.StatusServiceUnavailable) + return + } + + conn, err := c.upgrader.Upgrade(w, r, nil) + if err != nil { + logger.ErrorCF("pico", "WebSocket upgrade failed", map[string]any{ + "error": err.Error(), + }) + return + } + + // Determine session ID from query param or generate one + sessionID := r.URL.Query().Get("session_id") + if sessionID == "" { + sessionID = uuid.New().String() + } + + pc := &picoConn{ + id: uuid.New().String(), + conn: conn, + sessionID: sessionID, + } + + c.connections.Store(pc.id, pc) + c.connCount.Add(1) + + logger.InfoCF("pico", "WebSocket client connected", map[string]any{ + "conn_id": pc.id, + "session_id": sessionID, + }) + + go c.readLoop(pc) +} + +// authenticate checks the Bearer token from header or query parameter. +func (c *PicoChannel) authenticate(r *http.Request) bool { + token := c.config.Token + if token == "" { + return false + } + + // Check Authorization header + auth := r.Header.Get("Authorization") + if strings.HasPrefix(auth, "Bearer ") { + if strings.TrimPrefix(auth, "Bearer ") == token { + return true + } + } + + // Check query parameter + if r.URL.Query().Get("token") == token { + return true + } + + return false +} + +// readLoop reads messages from a WebSocket connection. 
+func (c *PicoChannel) readLoop(pc *picoConn) { + defer func() { + pc.close() + c.connections.Delete(pc.id) + c.connCount.Add(-1) + logger.InfoCF("pico", "WebSocket client disconnected", map[string]any{ + "conn_id": pc.id, + "session_id": pc.sessionID, + }) + }() + + readTimeout := time.Duration(c.config.ReadTimeout) * time.Second + if readTimeout <= 0 { + readTimeout = 60 * time.Second + } + + _ = pc.conn.SetReadDeadline(time.Now().Add(readTimeout)) + pc.conn.SetPongHandler(func(appData string) error { + _ = pc.conn.SetReadDeadline(time.Now().Add(readTimeout)) + return nil + }) + + // Start ping ticker + pingInterval := time.Duration(c.config.PingInterval) * time.Second + if pingInterval <= 0 { + pingInterval = 30 * time.Second + } + go c.pingLoop(pc, pingInterval) + + for { + select { + case <-c.ctx.Done(): + return + default: + } + + _, rawMsg, err := pc.conn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) { + logger.DebugCF("pico", "WebSocket read error", map[string]any{ + "conn_id": pc.id, + "error": err.Error(), + }) + } + return + } + + _ = pc.conn.SetReadDeadline(time.Now().Add(readTimeout)) + + var msg PicoMessage + if err := json.Unmarshal(rawMsg, &msg); err != nil { + errMsg := newError("invalid_message", "failed to parse message") + pc.writeJSON(errMsg) + continue + } + + c.handleMessage(pc, msg) + } +} + +// pingLoop sends periodic ping frames to keep the connection alive. +func (c *PicoChannel) pingLoop(pc *picoConn, interval time.Duration) { + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-c.ctx.Done(): + return + case <-ticker.C: + if pc.closed.Load() { + return + } + pc.writeMu.Lock() + err := pc.conn.WriteMessage(websocket.PingMessage, nil) + pc.writeMu.Unlock() + if err != nil { + return + } + } + } +} + +// handleMessage processes an inbound Pico Protocol message. 
+func (c *PicoChannel) handleMessage(pc *picoConn, msg PicoMessage) { + switch msg.Type { + case TypePing: + pong := newMessage(TypePong, nil) + pong.ID = msg.ID + pc.writeJSON(pong) + + case TypeMessageSend: + c.handleMessageSend(pc, msg) + + default: + errMsg := newError("unknown_type", fmt.Sprintf("unknown message type: %s", msg.Type)) + pc.writeJSON(errMsg) + } +} + +// handleMessageSend processes an inbound message.send from a client. +func (c *PicoChannel) handleMessageSend(pc *picoConn, msg PicoMessage) { + content, _ := msg.Payload["content"].(string) + if strings.TrimSpace(content) == "" { + errMsg := newError("empty_content", "message content is empty") + pc.writeJSON(errMsg) + return + } + + sessionID := msg.SessionID + if sessionID == "" { + sessionID = pc.sessionID + } + + chatID := "pico:" + sessionID + senderID := "pico-user" + + peer := bus.Peer{Kind: "direct", ID: "pico:" + sessionID} + + metadata := map[string]string{ + "platform": "pico", + "session_id": sessionID, + "conn_id": pc.id, + } + + logger.DebugCF("pico", "Received message", map[string]any{ + "session_id": sessionID, + "preview": truncate(content, 50), + }) + + // Register typing with Manager + if rec := c.GetPlaceholderRecorder(); rec != nil { + stop, err := c.StartTyping(c.ctx, chatID) + if err == nil { + rec.RecordTypingStop("pico", chatID, stop) + } + } + + c.HandleMessage(peer, msg.ID, senderID, chatID, content, nil, metadata) +} + +// truncate truncates a string to maxLen runes. +func truncate(s string, maxLen int) string { + runes := []rune(s) + if len(runes) <= maxLen { + return s + } + return string(runes[:maxLen]) + "..." +} diff --git a/pkg/channels/pico/protocol.go b/pkg/channels/pico/protocol.go new file mode 100644 index 000000000..ca18df1dd --- /dev/null +++ b/pkg/channels/pico/protocol.go @@ -0,0 +1,46 @@ +package pico + +import "time" + +// Protocol message types. 
+const ( + // Client → Server + TypeMessageSend = "message.send" + TypeMediaSend = "media.send" + TypePing = "ping" + + // Server → Client + TypeMessageCreate = "message.create" + TypeMessageUpdate = "message.update" + TypeMediaCreate = "media.create" + TypeTypingStart = "typing.start" + TypeTypingStop = "typing.stop" + TypeError = "error" + TypePong = "pong" +) + +// PicoMessage is the wire format for all Pico Protocol messages. +type PicoMessage struct { + Type string `json:"type"` + ID string `json:"id,omitempty"` + SessionID string `json:"session_id,omitempty"` + Timestamp int64 `json:"timestamp,omitempty"` + Payload map[string]any `json:"payload,omitempty"` +} + +// newMessage creates a PicoMessage with the given type and payload. +func newMessage(msgType string, payload map[string]any) PicoMessage { + return PicoMessage{ + Type: msgType, + Timestamp: time.Now().UnixMilli(), + Payload: payload, + } +} + +// newError creates an error PicoMessage. +func newError(code, message string) PicoMessage { + return newMessage(TypeError, map[string]any{ + "code": code, + "message": message, + }) +} diff --git a/pkg/channels/slack/slack.go b/pkg/channels/slack/slack.go index 6fba2e0b4..e64525310 100644 --- a/pkg/channels/slack/slack.go +++ b/pkg/channels/slack/slack.go @@ -274,6 +274,18 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { Timestamp: messageTS, }) + // Register typing stop (remove "eyes" reaction) with Manager + if rec := c.GetPlaceholderRecorder(); rec != nil { + capturedChannelID := channelID + capturedMessageTS := messageTS + rec.RecordTypingStop("slack", chatID, func() { + c.api.RemoveReaction("eyes", slack.ItemRef{ + Channel: capturedChannelID, + Timestamp: capturedMessageTS, + }) + }) + } + c.pendingAcks.Store(chatID, slackMessageRef{ ChannelID: channelID, Timestamp: messageTS, @@ -380,6 +392,18 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) { Timestamp: messageTS, }) + // Register typing stop 
(remove "eyes" reaction) with Manager + if rec := c.GetPlaceholderRecorder(); rec != nil { + capturedChannelID := channelID + capturedMessageTS := messageTS + rec.RecordTypingStop("slack", chatID, func() { + c.api.RemoveReaction("eyes", slack.ItemRef{ + Channel: capturedChannelID, + Timestamp: capturedMessageTS, + }) + }) + } + c.pendingAcks.Store(chatID, slackMessageRef{ ChannelID: channelID, Timestamp: messageTS, diff --git a/pkg/channels/telegram/telegram.go b/pkg/channels/telegram/telegram.go index c5c055163..98477f3a8 100644 --- a/pkg/channels/telegram/telegram.go +++ b/pkg/channels/telegram/telegram.go @@ -7,8 +7,8 @@ import ( "net/url" "os" "regexp" + "strconv" "strings" - "sync" "time" "github.com/mymmrac/telego" @@ -26,25 +26,13 @@ import ( type TelegramChannel struct { *channels.BaseChannel - bot *telego.Bot - bh *telegohandler.BotHandler - commands TelegramCommander - config *config.Config - chatIDs map[string]int64 - ctx context.Context - cancel context.CancelFunc - placeholders sync.Map // chatID -> messageID - stopThinking sync.Map // chatID -> thinkingCancel -} - -type thinkingCancel struct { - fn context.CancelFunc -} - -func (c *thinkingCancel) Cancel() { - if c != nil && c.fn != nil { - c.fn() - } + bot *telego.Bot + bh *telegohandler.BotHandler + commands TelegramCommander + config *config.Config + chatIDs map[string]int64 + ctx context.Context + cancel context.CancelFunc } func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChannel, error) { @@ -85,13 +73,11 @@ func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChann ) return &TelegramChannel{ - BaseChannel: base, - commands: NewTelegramCommands(bot, cfg), - bot: bot, - config: cfg, - chatIDs: make(map[string]int64), - placeholders: sync.Map{}, - stopThinking: sync.Map{}, + BaseChannel: base, + commands: NewTelegramCommands(bot, cfg), + bot: bot, + config: cfg, + chatIDs: make(map[string]int64), }, nil } @@ -149,21 +135,6 @@ func (c 
*TelegramChannel) Stop(ctx context.Context) error { logger.InfoC("telegram", "Stopping Telegram bot...") c.SetRunning(false) - // Clean up all thinking cancel functions to avoid context leaks - c.stopThinking.Range(func(key, value any) bool { - if cf, ok := value.(*thinkingCancel); ok && cf != nil { - cf.Cancel() - } - c.stopThinking.Delete(key) - return true - }) - - // Clean up placeholder state - c.placeholders.Range(func(key, value any) bool { - c.placeholders.Delete(key) - return true - }) - // Stop the bot handler if c.bh != nil { c.bh.Stop() @@ -187,28 +158,9 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err return fmt.Errorf("invalid chat ID %s: %w", msg.ChatID, channels.ErrSendFailed) } - // Stop thinking animation - if stop, ok := c.stopThinking.Load(msg.ChatID); ok { - if cf, ok := stop.(*thinkingCancel); ok && cf != nil { - cf.Cancel() - } - c.stopThinking.Delete(msg.ChatID) - } - htmlContent := markdownToTelegramHTML(msg.Content) - // Try to edit placeholder - if pID, ok := c.placeholders.Load(msg.ChatID); ok { - c.placeholders.Delete(msg.ChatID) - editMsg := tu.EditMessageText(tu.ID(chatID), pID.(int), htmlContent) - editMsg.ParseMode = telego.ModeHTML - - if _, err = c.bot.EditMessageText(ctx, editMsg); err == nil { - return nil - } - // Fallback to new message if edit fails - } - + // Typing/placeholder handled by Manager.preSend — just send the message tgMsg := tu.Message(tu.ID(chatID), htmlContent) tgMsg.ParseMode = telego.ModeHTML @@ -225,6 +177,23 @@ func (c *TelegramChannel) Send(ctx context.Context, msg bus.OutboundMessage) err return nil } +// EditMessage implements channels.MessageEditor. 
+func (c *TelegramChannel) EditMessage(ctx context.Context, chatID string, messageID string, content string) error { + cid, err := parseChatID(chatID) + if err != nil { + return err + } + mid, err := strconv.Atoi(messageID) + if err != nil { + return err + } + htmlContent := markdownToTelegramHTML(content) + editMsg := tu.EditMessageText(tu.ID(cid), mid, htmlContent) + editMsg.ParseMode = telego.ModeHTML + _, err = c.bot.EditMessageText(ctx, editMsg) + return err +} + // SendMedia implements the channels.MediaSender interface. func (c *TelegramChannel) SendMedia(ctx context.Context, msg bus.OutboundMediaMessage) error { if !c.IsRunning() { @@ -445,21 +414,21 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes }) } - // Stop any previous thinking animation - if prevStop, ok := c.stopThinking.Load(chatIDStr); ok { - if cf, ok := prevStop.(*thinkingCancel); ok && cf != nil { - cf.Cancel() - } - } - - // Create cancel function for thinking state + // Create cancel function for thinking state and register with Manager _, thinkCancel := context.WithTimeout(ctx, 5*time.Minute) - c.stopThinking.Store(chatIDStr, &thinkingCancel{fn: thinkCancel}) + if rec := c.GetPlaceholderRecorder(); rec != nil { + rec.RecordTypingStop("telegram", chatIDStr, thinkCancel) + } else { + // No recorder — cancel immediately to avoid context leak + thinkCancel() + } pMsg, err := c.bot.SendMessage(ctx, tu.Message(tu.ID(chatID), "Thinking... 
💭")) if err == nil { pID := pMsg.MessageID - c.placeholders.Store(chatIDStr, pID) + if rec := c.GetPlaceholderRecorder(); rec != nil { + rec.RecordPlaceholder("telegram", chatIDStr, fmt.Sprintf("%d", pID)) + } } peerKind := "direct" diff --git a/pkg/config/config.go b/pkg/config/config.go index cf768a79e..35bbefb24 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -192,6 +192,7 @@ type ChannelsConfig struct { OneBot OneBotConfig `json:"onebot"` WeCom WeComConfig `json:"wecom"` WeComApp WeComAppConfig `json:"wecom_app"` + Pico PicoConfig `json:"pico"` } // GroupTriggerConfig controls when the bot responds in group chats. @@ -333,6 +334,17 @@ type WeComAppConfig struct { GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` } +type PicoConfig struct { + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_PICO_ENABLED"` + Token string `json:"token" env:"PICOCLAW_CHANNELS_PICO_TOKEN"` + AllowOrigins []string `json:"allow_origins,omitempty"` + PingInterval int `json:"ping_interval,omitempty"` // seconds, default 30 + ReadTimeout int `json:"read_timeout,omitempty"` // seconds, default 60 + WriteTimeout int `json:"write_timeout,omitempty"` // seconds, default 10 + MaxConnections int `json:"max_connections,omitempty"` // default 100 + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_PICO_ALLOW_FROM"` +} + type HeartbeatConfig struct { Enabled bool `json:"enabled" env:"PICOCLAW_HEARTBEAT_ENABLED"` Interval int `json:"interval" env:"PICOCLAW_HEARTBEAT_INTERVAL"` // minutes, min 5 diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go index 03ad2ab6b..604b53e24 100644 --- a/pkg/config/defaults.go +++ b/pkg/config/defaults.go @@ -33,6 +33,11 @@ func DefaultConfig() *Config { Enabled: false, Token: "", AllowFrom: FlexibleStringSlice{}, + Typing: TypingConfig{Enabled: true}, + Placeholder: PlaceholderConfig{ + Enabled: true, + Text: "Thinking... 
💭", + }, }, Feishu: FeishuConfig{ Enabled: false, @@ -114,6 +119,15 @@ func DefaultConfig() *Config { AllowFrom: FlexibleStringSlice{}, ReplyTimeout: 5, }, + Pico: PicoConfig{ + Enabled: false, + Token: "", + PingInterval: 30, + ReadTimeout: 60, + WriteTimeout: 10, + MaxConnections: 100, + AllowFrom: FlexibleStringSlice{}, + }, }, Providers: ProvidersConfig{ OpenAI: OpenAIProviderConfig{WebSearch: true}, From ced55e768ccfb7d14cfee8de691504baff88a9f0 Mon Sep 17 00:00:00 2001 From: Hoshina Date: Mon, 23 Feb 2026 05:22:18 +0800 Subject: [PATCH 20/28] fix: resolve golangci-lint issues in channel system --- pkg/channels/manager.go | 2 +- pkg/channels/pico/protocol.go | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go index 4b1a43b7b..8e72efc5c 100644 --- a/pkg/channels/manager.go +++ b/pkg/channels/manager.go @@ -153,7 +153,7 @@ func (m *Manager) initChannel(name, displayName string) { } } // Inject PlaceholderRecorder if channel supports it - if setter, ok := ch.(interface{ SetPlaceholderRecorder(PlaceholderRecorder) }); ok { + if setter, ok := ch.(interface{ SetPlaceholderRecorder(r PlaceholderRecorder) }); ok { setter.SetPlaceholderRecorder(m) } m.channels[name] = ch diff --git a/pkg/channels/pico/protocol.go b/pkg/channels/pico/protocol.go index ca18df1dd..0a630e193 100644 --- a/pkg/channels/pico/protocol.go +++ b/pkg/channels/pico/protocol.go @@ -4,12 +4,12 @@ import "time" // Protocol message types. const ( - // Client → Server + // TypeMessageSend is sent from client to server. TypeMessageSend = "message.send" TypeMediaSend = "media.send" TypePing = "ping" - // Server → Client + // TypeMessageCreate is sent from server to client. 
TypeMessageCreate = "message.create" TypeMessageUpdate = "message.update" TypeMediaCreate = "media.create" From f4b0f080e29ae3304ded321b67cbb1697e3a5157 Mon Sep 17 00:00:00 2001 From: Hoshina Date: Mon, 23 Feb 2026 05:46:34 +0800 Subject: [PATCH 21/28] refactor(channels): move SplitMessage from pkg/utils to pkg/channels Message splitting is exclusively a Manager responsibility. Moving it into the channels package eliminates the cross-package dependency and aligns with the refactoring plan. --- pkg/channels/manager.go | 3 +-- pkg/{utils/message.go => channels/split.go} | 2 +- pkg/{utils/message_test.go => channels/split_test.go} | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) rename pkg/{utils/message.go => channels/split.go} (99%) rename pkg/{utils/message_test.go => channels/split_test.go} (99%) diff --git a/pkg/channels/manager.go b/pkg/channels/manager.go index 8e72efc5c..07c2ce1e2 100644 --- a/pkg/channels/manager.go +++ b/pkg/channels/manager.go @@ -23,7 +23,6 @@ import ( "github.com/sipeed/picoclaw/pkg/health" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/media" - "github.com/sipeed/picoclaw/pkg/utils" ) const ( @@ -407,7 +406,7 @@ func (m *Manager) runWorker(ctx context.Context, name string, w *channelWorker) maxLen = mlp.MaxMessageLength() } if maxLen > 0 && len([]rune(msg.Content)) > maxLen { - chunks := utils.SplitMessage(msg.Content, maxLen) + chunks := SplitMessage(msg.Content, maxLen) for _, chunk := range chunks { chunkMsg := msg chunkMsg.Content = chunk diff --git a/pkg/utils/message.go b/pkg/channels/split.go similarity index 99% rename from pkg/utils/message.go rename to pkg/channels/split.go index 52a967f4c..a455c5741 100644 --- a/pkg/utils/message.go +++ b/pkg/channels/split.go @@ -1,4 +1,4 @@ -package utils +package channels import ( "strings" diff --git a/pkg/utils/message_test.go b/pkg/channels/split_test.go similarity index 99% rename from pkg/utils/message_test.go rename to pkg/channels/split_test.go index 
78e1e2b40..d6356bdb9 100644 --- a/pkg/utils/message_test.go +++ b/pkg/channels/split_test.go @@ -1,4 +1,4 @@ -package utils +package channels import ( "strings" From db3c1e011ffdc2632377a0a38837a6eba6566b5e Mon Sep 17 00:00:00 2001 From: Hoshina Date: Mon, 23 Feb 2026 06:03:23 +0800 Subject: [PATCH 22/28] fix: address PR review feedback across channel system - MediaStore: use full UUID to prevent ref collisions, preserve and expose metadata via ResolveWithMeta, include underlying OS errors - Agent loop: populate MediaPart Type/Filename/ContentType from MediaStore metadata so channels can dispatch media correctly - SplitMessage: fix byte-vs-rune index mixup in code block header parsing, remove dead candidateStr variable - Pico auth: restrict query-param token behind AllowTokenQuery config flag (default false) to prevent token leakage via logs/referer - HandleMessage: replace context.TODO with caller-propagated ctx, log PublishInbound failures instead of silently discarding - Gateway shutdown: use fresh 15s timeout context for StopAll so graceful shutdown is not short-circuited by the cancelled parent ctx --- cmd/picoclaw/cmd_gateway.go | 8 +++++- pkg/agent/loop.go | 42 ++++++++++++++++++++++++++++- pkg/channels/base.go | 10 ++++++- pkg/channels/dingtalk/dingtalk.go | 2 +- pkg/channels/discord/discord.go | 2 +- pkg/channels/feishu/feishu_64.go | 4 +-- pkg/channels/line/line.go | 2 +- pkg/channels/maixcam/maixcam.go | 11 +++++++- pkg/channels/onebot/onebot.go | 2 +- pkg/channels/pico/pico.go | 13 +++++---- pkg/channels/qq/qq.go | 4 +-- pkg/channels/slack/slack.go | 6 ++--- pkg/channels/split.go | 18 +++++++++---- pkg/channels/telegram/telegram.go | 2 +- pkg/channels/wecom/app.go | 2 +- pkg/channels/wecom/bot.go | 2 +- pkg/channels/whatsapp/whatsapp.go | 2 +- pkg/config/config.go | 17 ++++++------ pkg/media/store.go | 41 +++++++++++++++++++++------- pkg/media/store_test.go | 44 +++++++++++++++++++++++++++++++ 20 files changed, 187 insertions(+), 47 deletions(-) diff 
--git a/cmd/picoclaw/cmd_gateway.go b/cmd/picoclaw/cmd_gateway.go index 33217492d..798ad2813 100644 --- a/cmd/picoclaw/cmd_gateway.go +++ b/cmd/picoclaw/cmd_gateway.go @@ -190,7 +190,13 @@ func gatewayCmd() { fmt.Println("\nShutting down...") cancel() msgBus.Close() - channelManager.StopAll(ctx) + + // Use a fresh context with timeout for graceful shutdown, + // since the original ctx is already cancelled. + shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 15*time.Second) + defer shutdownCancel() + + channelManager.StopAll(shutdownCtx) deviceService.Stop() heartbeatService.Stop() cronService.Stop() diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 050303101..0e2097488 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -10,6 +10,7 @@ import ( "context" "encoding/json" "fmt" + "path/filepath" "strings" "sync" "sync/atomic" @@ -237,6 +238,36 @@ func (al *AgentLoop) SetMediaStore(s media.MediaStore) { al.mediaStore = s } +// inferMediaType determines the media type ("image", "audio", "video", "file") +// from a filename and MIME content type. +func inferMediaType(filename, contentType string) string { + ct := strings.ToLower(contentType) + fn := strings.ToLower(filename) + + if strings.HasPrefix(ct, "image/") { + return "image" + } + if strings.HasPrefix(ct, "audio/") || ct == "application/ogg" { + return "audio" + } + if strings.HasPrefix(ct, "video/") { + return "video" + } + + // Fallback: infer from extension + ext := filepath.Ext(fn) + switch ext { + case ".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".svg": + return "image" + case ".mp3", ".wav", ".ogg", ".m4a", ".flac", ".aac", ".wma", ".opus": + return "audio" + case ".mp4", ".avi", ".mov", ".webm", ".mkv": + return "video" + } + + return "file" +} + // RecordLastChannel records the last active channel for this workspace. // This uses the atomic state save mechanism to prevent data loss on crash. 
func (al *AgentLoop) RecordLastChannel(channel string) error { @@ -731,7 +762,16 @@ func (al *AgentLoop) runLLMIteration( if len(toolResult.Media) > 0 && opts.SendResponse { parts := make([]bus.MediaPart, 0, len(toolResult.Media)) for _, ref := range toolResult.Media { - parts = append(parts, bus.MediaPart{Ref: ref}) + part := bus.MediaPart{Ref: ref} + // Populate metadata from MediaStore when available + if al.mediaStore != nil { + if _, meta, err := al.mediaStore.ResolveWithMeta(ref); err == nil { + part.Filename = meta.Filename + part.ContentType = meta.ContentType + part.Type = inferMediaType(meta.Filename, meta.ContentType) + } + } + parts = append(parts, part) } al.bus.PublishOutboundMedia(ctx, bus.OutboundMediaMessage{ Channel: opts.Channel, diff --git a/pkg/channels/base.go b/pkg/channels/base.go index c22a27eb9..c6a5f1cdc 100644 --- a/pkg/channels/base.go +++ b/pkg/channels/base.go @@ -9,6 +9,7 @@ import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/media" ) @@ -168,6 +169,7 @@ func (c *BaseChannel) IsAllowed(senderID string) bool { } func (c *BaseChannel) HandleMessage( + ctx context.Context, peer bus.Peer, messageID, senderID, chatID, content string, media []string, @@ -191,7 +193,13 @@ func (c *BaseChannel) HandleMessage( Metadata: metadata, } - c.bus.PublishInbound(context.TODO(), msg) + if err := c.bus.PublishInbound(ctx, msg); err != nil { + logger.ErrorCF("channels", "Failed to publish inbound message", map[string]any{ + "channel": c.name, + "chat_id": chatID, + "error": err.Error(), + }) + } } func (c *BaseChannel) SetRunning(running bool) { diff --git a/pkg/channels/dingtalk/dingtalk.go b/pkg/channels/dingtalk/dingtalk.go index b28bc850f..7ab73b4d3 100644 --- a/pkg/channels/dingtalk/dingtalk.go +++ b/pkg/channels/dingtalk/dingtalk.go @@ -183,7 +183,7 @@ func (c *DingTalkChannel) onChatBotMessageReceived( }) // Handle the message through the 
base channel - c.HandleMessage(peer, "", senderID, chatID, content, nil, metadata) + c.HandleMessage(ctx, peer, "", senderID, chatID, content, nil, metadata) // Return nil to indicate we've handled the message asynchronously // The response will be sent through the message bus diff --git a/pkg/channels/discord/discord.go b/pkg/channels/discord/discord.go index ee698da61..464a4db7b 100644 --- a/pkg/channels/discord/discord.go +++ b/pkg/channels/discord/discord.go @@ -381,7 +381,7 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag "is_dm": fmt.Sprintf("%t", m.GuildID == ""), } - c.HandleMessage(peer, m.ID, senderID, m.ChannelID, content, mediaPaths, metadata) + c.HandleMessage(c.ctx, peer, m.ID, senderID, m.ChannelID, content, mediaPaths, metadata) } // startTyping starts a continuous typing indicator loop for the given chatID. diff --git a/pkg/channels/feishu/feishu_64.go b/pkg/channels/feishu/feishu_64.go index aaaf6cf1b..4b8eddd21 100644 --- a/pkg/channels/feishu/feishu_64.go +++ b/pkg/channels/feishu/feishu_64.go @@ -131,7 +131,7 @@ func (c *FeishuChannel) Send(ctx context.Context, msg bus.OutboundMessage) error return nil } -func (c *FeishuChannel) handleMessageReceive(_ context.Context, event *larkim.P2MessageReceiveV1) error { +func (c *FeishuChannel) handleMessageReceive(ctx context.Context, event *larkim.P2MessageReceiveV1) error { if event == nil || event.Event == nil || event.Event.Message == nil { return nil } @@ -189,7 +189,7 @@ func (c *FeishuChannel) handleMessageReceive(_ context.Context, event *larkim.P2 "preview": utils.Truncate(content, 80), }) - c.HandleMessage(peer, messageID, senderID, chatID, content, nil, metadata) + c.HandleMessage(ctx, peer, messageID, senderID, chatID, content, nil, metadata) return nil } diff --git a/pkg/channels/line/line.go b/pkg/channels/line/line.go index a79931bc9..399617064 100644 --- a/pkg/channels/line/line.go +++ b/pkg/channels/line/line.go @@ -370,7 +370,7 @@ func (c *LINEChannel) 
processEvent(event lineEvent) { // Show typing/loading indicator (requires user ID, not group ID) c.sendLoading(senderID) - c.HandleMessage(peer, msg.ID, senderID, chatID, content, mediaPaths, metadata) + c.HandleMessage(c.ctx, peer, msg.ID, senderID, chatID, content, mediaPaths, metadata) } // isBotMentioned checks if the bot is mentioned in the message. diff --git a/pkg/channels/maixcam/maixcam.go b/pkg/channels/maixcam/maixcam.go index b5b7259f9..dceaec4c5 100644 --- a/pkg/channels/maixcam/maixcam.go +++ b/pkg/channels/maixcam/maixcam.go @@ -179,7 +179,16 @@ func (c *MaixCamChannel) handlePersonDetection(msg MaixCamMessage) { "h": fmt.Sprintf("%.0f", h), } - c.HandleMessage(bus.Peer{Kind: "channel", ID: "default"}, "", senderID, chatID, content, []string{}, metadata) + c.HandleMessage( + c.ctx, + bus.Peer{Kind: "channel", ID: "default"}, + "", + senderID, + chatID, + content, + []string{}, + metadata, + ) } func (c *MaixCamChannel) handleStatusUpdate(msg MaixCamMessage) { diff --git a/pkg/channels/onebot/onebot.go b/pkg/channels/onebot/onebot.go index 682025b67..b47685397 100644 --- a/pkg/channels/onebot/onebot.go +++ b/pkg/channels/onebot/onebot.go @@ -1040,7 +1040,7 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) { } } - c.HandleMessage(peer, messageID, senderID, chatID, content, parsed.Media, metadata) + c.HandleMessage(c.ctx, peer, messageID, senderID, chatID, content, parsed.Media, metadata) } func (c *OneBotChannel) isDuplicate(messageID string) bool { diff --git a/pkg/channels/pico/pico.go b/pkg/channels/pico/pico.go index 1c28ca732..9809786e3 100644 --- a/pkg/channels/pico/pico.go +++ b/pkg/channels/pico/pico.go @@ -255,7 +255,8 @@ func (c *PicoChannel) handleWebSocket(w http.ResponseWriter, r *http.Request) { go c.readLoop(pc) } -// authenticate checks the Bearer token from header or query parameter. +// authenticate checks the Bearer token from the Authorization header. 
+// Query parameter authentication is only allowed when AllowTokenQuery is explicitly enabled. func (c *PicoChannel) authenticate(r *http.Request) bool { token := c.config.Token if token == "" { @@ -270,9 +271,11 @@ func (c *PicoChannel) authenticate(r *http.Request) bool { } } - // Check query parameter - if r.URL.Query().Get("token") == token { - return true + // Check query parameter only when explicitly allowed + if c.config.AllowTokenQuery { + if r.URL.Query().Get("token") == token { + return true + } } return false @@ -417,7 +420,7 @@ func (c *PicoChannel) handleMessageSend(pc *picoConn, msg PicoMessage) { } } - c.HandleMessage(peer, msg.ID, senderID, chatID, content, nil, metadata) + c.HandleMessage(c.ctx, peer, msg.ID, senderID, chatID, content, nil, metadata) } // truncate truncates a string to maxLen runes. diff --git a/pkg/channels/qq/qq.go b/pkg/channels/qq/qq.go index 011eb6c3c..c43c13655 100644 --- a/pkg/channels/qq/qq.go +++ b/pkg/channels/qq/qq.go @@ -168,7 +168,7 @@ func (c *QQChannel) handleC2CMessage() event.C2CMessageEventHandler { // 转发到消息总线 metadata := map[string]string{} - c.HandleMessage( + c.HandleMessage(c.ctx, bus.Peer{Kind: "direct", ID: senderID}, data.ID, senderID, @@ -224,7 +224,7 @@ func (c *QQChannel) handleGroupATMessage() event.GroupATMessageEventHandler { "group_id": data.GroupID, } - c.HandleMessage( + c.HandleMessage(c.ctx, bus.Peer{Kind: "group", ID: data.GroupID}, data.ID, senderID, diff --git a/pkg/channels/slack/slack.go b/pkg/channels/slack/slack.go index e64525310..c6b3c829e 100644 --- a/pkg/channels/slack/slack.go +++ b/pkg/channels/slack/slack.go @@ -360,7 +360,7 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { "has_thread": threadTS != "", }) - c.HandleMessage(peer, messageTS, senderID, chatID, content, mediaPaths, metadata) + c.HandleMessage(c.ctx, peer, messageTS, senderID, chatID, content, mediaPaths, metadata) } func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) { 
@@ -433,7 +433,7 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) { "team_id": c.teamID, } - c.HandleMessage(mentionPeer, messageTS, senderID, chatID, content, nil, metadata) + c.HandleMessage(c.ctx, mentionPeer, messageTS, senderID, chatID, content, nil, metadata) } func (c *SlackChannel) handleSlashCommand(event socketmode.Event) { @@ -476,7 +476,7 @@ func (c *SlackChannel) handleSlashCommand(event socketmode.Event) { "text": utils.Truncate(content, 50), }) - c.HandleMessage(bus.Peer{Kind: "channel", ID: channelID}, "", senderID, chatID, content, nil, metadata) + c.HandleMessage(c.ctx, bus.Peer{Kind: "channel", ID: channelID}, "", senderID, chatID, content, nil, metadata) } func (c *SlackChannel) downloadSlackFile(file slack.File) string { diff --git a/pkg/channels/split.go b/pkg/channels/split.go index a455c5741..27d76df1b 100644 --- a/pkg/channels/split.go +++ b/pkg/channels/split.go @@ -66,9 +66,8 @@ func SplitMessage(content string, maxLen int) []string { } else { // Code block is too long to fit in one chunk or missing closing fence. // Try to split inside by injecting closing and reopening fences. - candidateStr := string(candidate) - unclosedStr := string(runes[unclosedIdx:]) - headerEnd := strings.Index(unclosedStr, "\n") + fenceRunes := runes[unclosedIdx:] + headerEnd := findNewlineInRunes(fenceRunes) var header string if headerEnd == -1 { header = strings.TrimSpace(string(runes[unclosedIdx : unclosedIdx+3])) @@ -80,8 +79,6 @@ func SplitMessage(content string, maxLen int) []string { headerEndIdx = unclosedIdx + headerEnd } - _ = candidateStr // used above for context - // If we have a reasonable amount of content after the header, split inside if msgEnd > headerEndIdx+20 { // Find a better split point closer to maxLen @@ -170,6 +167,17 @@ func findNextClosingCodeBlockRunes(runes []rune, startIdx int) int { return -1 } +// findNewlineInRunes finds the first newline character in a rune slice. 
+// Returns the rune index of the newline or -1 if not found. +func findNewlineInRunes(runes []rune) int { + for i, r := range runes { + if r == '\n' { + return i + } + } + return -1 +} + // findLastNewlineRunes finds the last newline character within the last N runes // Returns the rune position of the newline or -1 if not found func findLastNewlineRunes(runes []rune, searchWindow int) int { diff --git a/pkg/channels/telegram/telegram.go b/pkg/channels/telegram/telegram.go index 98477f3a8..31be4d489 100644 --- a/pkg/channels/telegram/telegram.go +++ b/pkg/channels/telegram/telegram.go @@ -448,7 +448,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes "is_group": fmt.Sprintf("%t", message.Chat.Type != "private"), } - c.HandleMessage( + c.HandleMessage(c.ctx, peer, messageID, fmt.Sprintf("%d", user.ID), diff --git a/pkg/channels/wecom/app.go b/pkg/channels/wecom/app.go index 53b53ffb8..e822e67b2 100644 --- a/pkg/channels/wecom/app.go +++ b/pkg/channels/wecom/app.go @@ -630,7 +630,7 @@ func (c *WeComAppChannel) processMessage(ctx context.Context, msg WeComXMLMessag }) // Handle the message through the base channel - c.HandleMessage(peer, messageID, senderID, chatID, content, nil, metadata) + c.HandleMessage(ctx, peer, messageID, senderID, chatID, content, nil, metadata) } // tokenRefreshLoop periodically refreshes the access token diff --git a/pkg/channels/wecom/bot.go b/pkg/channels/wecom/bot.go index 7ffe4734b..401c9c5ec 100644 --- a/pkg/channels/wecom/bot.go +++ b/pkg/channels/wecom/bot.go @@ -399,7 +399,7 @@ func (c *WeComBotChannel) processMessage(ctx context.Context, msg WeComBotMessag }) // Handle the message through the base channel - c.HandleMessage(peer, msg.MsgID, senderID, chatID, content, nil, metadata) + c.HandleMessage(ctx, peer, msg.MsgID, senderID, chatID, content, nil, metadata) } // sendWebhookReply sends a reply using the webhook URL diff --git a/pkg/channels/whatsapp/whatsapp.go 
b/pkg/channels/whatsapp/whatsapp.go index 97032334f..b4599b5a0 100644 --- a/pkg/channels/whatsapp/whatsapp.go +++ b/pkg/channels/whatsapp/whatsapp.go @@ -224,5 +224,5 @@ func (c *WhatsAppChannel) handleIncomingMessage(msg map[string]any) { "preview": utils.Truncate(content, 50), }) - c.HandleMessage(peer, messageID, senderID, chatID, content, mediaPaths, metadata) + c.HandleMessage(c.ctx, peer, messageID, senderID, chatID, content, mediaPaths, metadata) } diff --git a/pkg/config/config.go b/pkg/config/config.go index 35bbefb24..fd5def625 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -335,14 +335,15 @@ type WeComAppConfig struct { } type PicoConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_PICO_ENABLED"` - Token string `json:"token" env:"PICOCLAW_CHANNELS_PICO_TOKEN"` - AllowOrigins []string `json:"allow_origins,omitempty"` - PingInterval int `json:"ping_interval,omitempty"` // seconds, default 30 - ReadTimeout int `json:"read_timeout,omitempty"` // seconds, default 60 - WriteTimeout int `json:"write_timeout,omitempty"` // seconds, default 10 - MaxConnections int `json:"max_connections,omitempty"` // default 100 - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_PICO_ALLOW_FROM"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_PICO_ENABLED"` + Token string `json:"token" env:"PICOCLAW_CHANNELS_PICO_TOKEN"` + AllowTokenQuery bool `json:"allow_token_query,omitempty"` + AllowOrigins []string `json:"allow_origins,omitempty"` + PingInterval int `json:"ping_interval,omitempty"` + ReadTimeout int `json:"read_timeout,omitempty"` + WriteTimeout int `json:"write_timeout,omitempty"` + MaxConnections int `json:"max_connections,omitempty"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_PICO_ALLOW_FROM"` } type HeartbeatConfig struct { diff --git a/pkg/media/store.go b/pkg/media/store.go index 8d03c03ef..2df4420e9 100644 --- a/pkg/media/store.go +++ b/pkg/media/store.go @@ -25,23 +25,32 
@@ type MediaStore interface { // Resolve returns the local file path for a given ref. Resolve(ref string) (localPath string, err error) + // ResolveWithMeta returns the local file path and metadata for a given ref. + ResolveWithMeta(ref string) (localPath string, meta MediaMeta, err error) + // ReleaseAll deletes all files registered under the given scope // and removes the mapping entries. File-not-exist errors are ignored. ReleaseAll(scope string) error } +// mediaEntry holds the path and metadata for a stored media file. +type mediaEntry struct { + path string + meta MediaMeta +} + // FileMediaStore is a pure in-memory implementation of MediaStore. // Files are expected to already exist on disk (e.g. in /tmp/picoclaw_media/). type FileMediaStore struct { mu sync.RWMutex - refToPath map[string]string + refs map[string]mediaEntry scopeToRefs map[string]map[string]struct{} } // NewFileMediaStore creates a new FileMediaStore. func NewFileMediaStore() *FileMediaStore { return &FileMediaStore{ - refToPath: make(map[string]string), + refs: make(map[string]mediaEntry), scopeToRefs: make(map[string]map[string]struct{}), } } @@ -49,15 +58,15 @@ func NewFileMediaStore() *FileMediaStore { // Store registers a local file under the given scope. The file must exist. 
func (s *FileMediaStore) Store(localPath string, meta MediaMeta, scope string) (string, error) { if _, err := os.Stat(localPath); err != nil { - return "", fmt.Errorf("media store: file does not exist: %s", localPath) + return "", fmt.Errorf("media store: %s: %w", localPath, err) } - ref := "media://" + uuid.New().String()[:8] + ref := "media://" + uuid.New().String() s.mu.Lock() defer s.mu.Unlock() - s.refToPath[ref] = localPath + s.refs[ref] = mediaEntry{path: localPath, meta: meta} if s.scopeToRefs[scope] == nil { s.scopeToRefs[scope] = make(map[string]struct{}) } @@ -71,11 +80,23 @@ func (s *FileMediaStore) Resolve(ref string) (string, error) { s.mu.RLock() defer s.mu.RUnlock() - path, ok := s.refToPath[ref] + entry, ok := s.refs[ref] if !ok { return "", fmt.Errorf("media store: unknown ref: %s", ref) } - return path, nil + return entry.path, nil +} + +// ResolveWithMeta returns the local path and metadata for the given ref. +func (s *FileMediaStore) ResolveWithMeta(ref string) (string, MediaMeta, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + entry, ok := s.refs[ref] + if !ok { + return "", MediaMeta{}, fmt.Errorf("media store: unknown ref: %s", ref) + } + return entry.path, entry.meta, nil } // ReleaseAll removes all files under the given scope and cleans up mappings. 
@@ -89,11 +110,11 @@ func (s *FileMediaStore) ReleaseAll(scope string) error { } for ref := range refs { - if path, exists := s.refToPath[ref]; exists { - if err := os.Remove(path); err != nil && !os.IsNotExist(err) { + if entry, exists := s.refs[ref]; exists { + if err := os.Remove(entry.path); err != nil && !os.IsNotExist(err) { // Log but continue — best effort cleanup } - delete(s.refToPath, ref) + delete(s.refs, ref) } } diff --git a/pkg/media/store_test.go b/pkg/media/store_test.go index 361582307..95bd1eb7a 100644 --- a/pkg/media/store_test.go +++ b/pkg/media/store_test.go @@ -139,6 +139,50 @@ func TestStoreNonexistentFile(t *testing.T) { if err == nil { t.Error("Store should fail for nonexistent file") } + // Error message should include the underlying os error, not just "file does not exist" + if !strings.Contains(err.Error(), "no such file or directory") { + t.Errorf("Error should contain OS error detail, got: %v", err) + } +} + +func TestResolveWithMeta(t *testing.T) { + dir := t.TempDir() + store := NewFileMediaStore() + + path := createTempFile(t, dir, "image.png") + meta := MediaMeta{ + Filename: "image.png", + ContentType: "image/png", + Source: "telegram", + } + + ref, err := store.Store(path, meta, "scope1") + if err != nil { + t.Fatalf("Store failed: %v", err) + } + + resolvedPath, resolvedMeta, err := store.ResolveWithMeta(ref) + if err != nil { + t.Fatalf("ResolveWithMeta failed: %v", err) + } + if resolvedPath != path { + t.Errorf("ResolveWithMeta path = %q, want %q", resolvedPath, path) + } + if resolvedMeta.Filename != meta.Filename { + t.Errorf("ResolveWithMeta Filename = %q, want %q", resolvedMeta.Filename, meta.Filename) + } + if resolvedMeta.ContentType != meta.ContentType { + t.Errorf("ResolveWithMeta ContentType = %q, want %q", resolvedMeta.ContentType, meta.ContentType) + } + if resolvedMeta.Source != meta.Source { + t.Errorf("ResolveWithMeta Source = %q, want %q", resolvedMeta.Source, meta.Source) + } + + // Unknown ref should fail + 
_, _, err = store.ResolveWithMeta("media://nonexistent") + if err == nil { + t.Error("ResolveWithMeta should fail for unknown ref") + } } func TestConcurrentSafety(t *testing.T) { From 3fb4469e47506df5c4539419438a4ce30e5da5a2 Mon Sep 17 00:00:00 2001 From: Hoshina Date: Mon, 23 Feb 2026 06:56:48 +0800 Subject: [PATCH 23/28] feat(identity): add unified user identity with canonical platform:id format Introduce SenderInfo struct and pkg/identity package to standardize user identification across all channels. Each channel now constructs structured sender info (platform, platformID, canonicalID, username, displayName) instead of ad-hoc string IDs. Allow-list matching supports all legacy formats (numeric ID, @username, id|username) plus the new canonical "platform:id" format. Session key resolution also handles canonical peerIDs for backward-compatible identity link matching. --- pkg/bus/types.go | 10 ++ pkg/channels/base.go | 44 +++++- pkg/channels/base_test.go | 88 ++++++++++++ pkg/channels/dingtalk/dingtalk.go | 15 +- pkg/channels/discord/discord.go | 26 ++-- pkg/channels/feishu/feishu_64.go | 13 +- pkg/channels/line/line.go | 13 +- pkg/channels/maixcam/maixcam.go | 12 ++ pkg/channels/onebot/onebot.go | 25 +++- pkg/channels/pico/pico.go | 13 +- pkg/channels/qq/qq.go | 23 +++ pkg/channels/slack/slack.go | 42 +++++- pkg/channels/telegram/telegram.go | 22 +-- pkg/channels/wecom/app.go | 10 +- pkg/channels/wecom/bot.go | 14 +- pkg/channels/whatsapp/whatsapp.go | 16 ++- pkg/identity/identity.go | 107 ++++++++++++++ pkg/identity/identity_test.go | 229 ++++++++++++++++++++++++++++++ pkg/routing/session_key.go | 9 ++ pkg/routing/session_key_test.go | 45 ++++++ 20 files changed, 742 insertions(+), 34 deletions(-) create mode 100644 pkg/identity/identity.go create mode 100644 pkg/identity/identity_test.go diff --git a/pkg/bus/types.go b/pkg/bus/types.go index 1a7a14170..7ad8f0417 100644 --- a/pkg/bus/types.go +++ b/pkg/bus/types.go @@ -6,9 +6,19 @@ type Peer struct { ID string 
`json:"id"` } +// SenderInfo provides structured sender identity information. +type SenderInfo struct { + Platform string `json:"platform,omitempty"` // "telegram", "discord", "slack", ... + PlatformID string `json:"platform_id,omitempty"` // raw platform ID, e.g. "123456" + CanonicalID string `json:"canonical_id,omitempty"` // "platform:id" format + Username string `json:"username,omitempty"` // username (e.g. @alice) + DisplayName string `json:"display_name,omitempty"` // display name +} + type InboundMessage struct { Channel string `json:"channel"` SenderID string `json:"sender_id"` + Sender SenderInfo `json:"sender"` ChatID string `json:"chat_id"` Content string `json:"content"` Media []string `json:"media,omitempty"` diff --git a/pkg/channels/base.go b/pkg/channels/base.go index c6a5f1cdc..418933af7 100644 --- a/pkg/channels/base.go +++ b/pkg/channels/base.go @@ -9,6 +9,7 @@ import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/media" ) @@ -20,6 +21,7 @@ type Channel interface { Send(ctx context.Context, msg bus.OutboundMessage) error IsRunning() bool IsAllowed(senderID string) bool + IsAllowedSender(sender bus.SenderInfo) bool } // BaseChannelOption is a functional option for configuring a BaseChannel. @@ -168,22 +170,58 @@ func (c *BaseChannel) IsAllowed(senderID string) bool { return false } +// IsAllowedSender checks whether a structured SenderInfo is permitted by the allow-list. +// It delegates to identity.MatchAllowed for each entry, providing unified matching +// across all legacy formats and the new canonical "platform:id" format. 
+func (c *BaseChannel) IsAllowedSender(sender bus.SenderInfo) bool { + if len(c.allowList) == 0 { + return true + } + + for _, allowed := range c.allowList { + if identity.MatchAllowed(sender, allowed) { + return true + } + } + + return false +} + func (c *BaseChannel) HandleMessage( ctx context.Context, peer bus.Peer, messageID, senderID, chatID, content string, media []string, metadata map[string]string, + senderOpts ...bus.SenderInfo, ) { - if !c.IsAllowed(senderID) { - return + // Use SenderInfo-based allow check when available, else fall back to string + var sender bus.SenderInfo + if len(senderOpts) > 0 { + sender = senderOpts[0] + } + if sender.CanonicalID != "" || sender.PlatformID != "" { + if !c.IsAllowedSender(sender) { + return + } + } else { + if !c.IsAllowed(senderID) { + return + } + } + + // Set SenderID to canonical if available, otherwise keep the raw senderID + resolvedSenderID := senderID + if sender.CanonicalID != "" { + resolvedSenderID = sender.CanonicalID } scope := BuildMediaScope(c.name, chatID, messageID) msg := bus.InboundMessage{ Channel: c.name, - SenderID: senderID, + SenderID: resolvedSenderID, + Sender: sender, ChatID: chatID, Content: content, Media: media, diff --git a/pkg/channels/base_test.go b/pkg/channels/base_test.go index e56ad3ee9..6132b8bf9 100644 --- a/pkg/channels/base_test.go +++ b/pkg/channels/base_test.go @@ -3,6 +3,7 @@ package channels import ( "testing" + "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/config" ) @@ -175,3 +176,90 @@ func TestShouldRespondInGroup(t *testing.T) { }) } } + +func TestIsAllowedSender(t *testing.T) { + tests := []struct { + name string + allowList []string + sender bus.SenderInfo + want bool + }{ + { + name: "empty allowlist allows all", + allowList: nil, + sender: bus.SenderInfo{PlatformID: "anyone"}, + want: true, + }, + { + name: "numeric ID matches PlatformID", + allowList: []string{"123456"}, + sender: bus.SenderInfo{ + Platform: "telegram", + PlatformID: 
"123456", + CanonicalID: "telegram:123456", + }, + want: true, + }, + { + name: "canonical format matches", + allowList: []string{"telegram:123456"}, + sender: bus.SenderInfo{ + Platform: "telegram", + PlatformID: "123456", + CanonicalID: "telegram:123456", + }, + want: true, + }, + { + name: "canonical format wrong platform", + allowList: []string{"discord:123456"}, + sender: bus.SenderInfo{ + Platform: "telegram", + PlatformID: "123456", + CanonicalID: "telegram:123456", + }, + want: false, + }, + { + name: "@username matches", + allowList: []string{"@alice"}, + sender: bus.SenderInfo{ + Platform: "telegram", + PlatformID: "123456", + CanonicalID: "telegram:123456", + Username: "alice", + }, + want: true, + }, + { + name: "compound id|username matches by ID", + allowList: []string{"123456|alice"}, + sender: bus.SenderInfo{ + Platform: "telegram", + PlatformID: "123456", + CanonicalID: "telegram:123456", + Username: "alice", + }, + want: true, + }, + { + name: "non matching sender denied", + allowList: []string{"654321"}, + sender: bus.SenderInfo{ + Platform: "telegram", + PlatformID: "123456", + CanonicalID: "telegram:123456", + }, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ch := NewBaseChannel("test", nil, nil, tt.allowList) + if got := ch.IsAllowedSender(tt.sender); got != tt.want { + t.Fatalf("IsAllowedSender(%+v) = %v, want %v", tt.sender, got, tt.want) + } + }) + } +} diff --git a/pkg/channels/dingtalk/dingtalk.go b/pkg/channels/dingtalk/dingtalk.go index 7ab73b4d3..7a3aaca78 100644 --- a/pkg/channels/dingtalk/dingtalk.go +++ b/pkg/channels/dingtalk/dingtalk.go @@ -14,6 +14,7 @@ import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/utils" ) @@ -182,8 +183,20 @@ func (c *DingTalkChannel) onChatBotMessageReceived( 
"preview": utils.Truncate(content, 50), }) + // Build sender info + sender := bus.SenderInfo{ + Platform: "dingtalk", + PlatformID: senderID, + CanonicalID: identity.BuildCanonicalID("dingtalk", senderID), + DisplayName: senderNick, + } + + if !c.IsAllowedSender(sender) { + return nil, nil + } + // Handle the message through the base channel - c.HandleMessage(ctx, peer, "", senderID, chatID, content, nil, metadata) + c.HandleMessage(ctx, peer, "", senderID, chatID, content, nil, metadata, sender) // Return nil to indicate we've handled the message asynchronously // The response will be sent through the message bus diff --git a/pkg/channels/discord/discord.go b/pkg/channels/discord/discord.go index 464a4db7b..dc49e7413 100644 --- a/pkg/channels/discord/discord.go +++ b/pkg/channels/discord/discord.go @@ -13,6 +13,7 @@ import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/utils" @@ -263,7 +264,20 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag } // Check allowlist first to avoid downloading attachments for rejected users - if !c.IsAllowed(m.Author.ID) { + sender := bus.SenderInfo{ + Platform: "discord", + PlatformID: m.Author.ID, + CanonicalID: identity.BuildCanonicalID("discord", m.Author.ID), + Username: m.Author.Username, + } + // Build display name + displayName := m.Author.Username + if m.Author.Discriminator != "" && m.Author.Discriminator != "0" { + displayName += "#" + m.Author.Discriminator + } + sender.DisplayName = displayName + + if !c.IsAllowedSender(sender) { logger.DebugCF("discord", "Message rejected by allowlist", map[string]any{ "user_id": m.Author.ID, }) @@ -297,10 +311,6 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag } senderID := m.Author.ID - 
senderName := m.Author.Username - if m.Author.Discriminator != "" && m.Author.Discriminator != "0" { - senderName += "#" + m.Author.Discriminator - } mediaPaths := make([]string, 0, len(m.Attachments)) @@ -358,7 +368,7 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag } logger.DebugCF("discord", "Received message", map[string]any{ - "sender_name": senderName, + "sender_name": sender.DisplayName, "sender_id": senderID, "preview": utils.Truncate(content, 50), }) @@ -375,13 +385,13 @@ func (c *DiscordChannel) handleMessage(s *discordgo.Session, m *discordgo.Messag metadata := map[string]string{ "user_id": senderID, "username": m.Author.Username, - "display_name": senderName, + "display_name": sender.DisplayName, "guild_id": m.GuildID, "channel_id": m.ChannelID, "is_dm": fmt.Sprintf("%t", m.GuildID == ""), } - c.HandleMessage(c.ctx, peer, m.ID, senderID, m.ChannelID, content, mediaPaths, metadata) + c.HandleMessage(c.ctx, peer, m.ID, senderID, m.ChannelID, content, mediaPaths, metadata, sender) } // startTyping starts a continuous typing indicator loop for the given chatID. diff --git a/pkg/channels/feishu/feishu_64.go b/pkg/channels/feishu/feishu_64.go index 4b8eddd21..62bf69486 100644 --- a/pkg/channels/feishu/feishu_64.go +++ b/pkg/channels/feishu/feishu_64.go @@ -17,6 +17,7 @@ import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/utils" ) @@ -189,7 +190,17 @@ func (c *FeishuChannel) handleMessageReceive(ctx context.Context, event *larkim. 
"preview": utils.Truncate(content, 80), }) - c.HandleMessage(ctx, peer, messageID, senderID, chatID, content, nil, metadata) + senderInfo := bus.SenderInfo{ + Platform: "feishu", + PlatformID: senderID, + CanonicalID: identity.BuildCanonicalID("feishu", senderID), + } + + if !c.IsAllowedSender(senderInfo) { + return nil + } + + c.HandleMessage(ctx, peer, messageID, senderID, chatID, content, nil, metadata, senderInfo) return nil } diff --git a/pkg/channels/line/line.go b/pkg/channels/line/line.go index 399617064..28d5ad8f7 100644 --- a/pkg/channels/line/line.go +++ b/pkg/channels/line/line.go @@ -17,6 +17,7 @@ import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/utils" @@ -370,7 +371,17 @@ func (c *LINEChannel) processEvent(event lineEvent) { // Show typing/loading indicator (requires user ID, not group ID) c.sendLoading(senderID) - c.HandleMessage(c.ctx, peer, msg.ID, senderID, chatID, content, mediaPaths, metadata) + sender := bus.SenderInfo{ + Platform: "line", + PlatformID: senderID, + CanonicalID: identity.BuildCanonicalID("line", senderID), + } + + if !c.IsAllowedSender(sender) { + return + } + + c.HandleMessage(c.ctx, peer, msg.ID, senderID, chatID, content, mediaPaths, metadata, sender) } // isBotMentioned checks if the bot is mentioned in the message. 
diff --git a/pkg/channels/maixcam/maixcam.go b/pkg/channels/maixcam/maixcam.go index dceaec4c5..142a4b7e7 100644 --- a/pkg/channels/maixcam/maixcam.go +++ b/pkg/channels/maixcam/maixcam.go @@ -11,6 +11,7 @@ import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" "github.com/sipeed/picoclaw/pkg/logger" ) @@ -179,6 +180,16 @@ func (c *MaixCamChannel) handlePersonDetection(msg MaixCamMessage) { "h": fmt.Sprintf("%.0f", h), } + sender := bus.SenderInfo{ + Platform: "maixcam", + PlatformID: "maixcam", + CanonicalID: identity.BuildCanonicalID("maixcam", "maixcam"), + } + + if !c.IsAllowedSender(sender) { + return + } + c.HandleMessage( c.ctx, bus.Peer{Kind: "channel", ID: "default"}, @@ -188,6 +199,7 @@ func (c *MaixCamChannel) handlePersonDetection(msg MaixCamMessage) { content, []string{}, metadata, + sender, ) } diff --git a/pkg/channels/onebot/onebot.go b/pkg/channels/onebot/onebot.go index b47685397..a748acaa0 100644 --- a/pkg/channels/onebot/onebot.go +++ b/pkg/channels/onebot/onebot.go @@ -15,6 +15,7 @@ import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/utils" @@ -823,7 +824,13 @@ func (c *OneBotChannel) handleRawEvent(raw *oneBotRawEvent) { switch raw.PostType { case "message": if userID, err := parseJSONInt64(raw.UserID); err == nil && userID > 0 { - if !c.IsAllowed(strconv.FormatInt(userID, 10)) { + // Build minimal sender for allowlist check + sender := bus.SenderInfo{ + Platform: "onebot", + PlatformID: strconv.FormatInt(userID, 10), + CanonicalID: identity.BuildCanonicalID("onebot", strconv.FormatInt(userID, 10)), + } + if !c.IsAllowedSender(sender) { logger.DebugCF("onebot", "Message rejected by allowlist", 
map[string]any{ "user_id": userID, }) @@ -1040,7 +1047,21 @@ func (c *OneBotChannel) handleMessage(raw *oneBotRawEvent) { } } - c.HandleMessage(c.ctx, peer, messageID, senderID, chatID, content, parsed.Media, metadata) + senderInfo := bus.SenderInfo{ + Platform: "onebot", + PlatformID: senderID, + CanonicalID: identity.BuildCanonicalID("onebot", senderID), + DisplayName: sender.Nickname, + } + + if !c.IsAllowedSender(senderInfo) { + logger.DebugCF("onebot", "Message rejected by allowlist (senderInfo)", map[string]any{ + "sender": senderID, + }) + return + } + + c.HandleMessage(c.ctx, peer, messageID, senderID, chatID, content, parsed.Media, metadata, senderInfo) } func (c *OneBotChannel) isDuplicate(messageID string) bool { diff --git a/pkg/channels/pico/pico.go b/pkg/channels/pico/pico.go index 9809786e3..c646a3b0b 100644 --- a/pkg/channels/pico/pico.go +++ b/pkg/channels/pico/pico.go @@ -16,6 +16,7 @@ import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" "github.com/sipeed/picoclaw/pkg/logger" ) @@ -420,7 +421,17 @@ func (c *PicoChannel) handleMessageSend(pc *picoConn, msg PicoMessage) { } } - c.HandleMessage(c.ctx, peer, msg.ID, senderID, chatID, content, nil, metadata) + sender := bus.SenderInfo{ + Platform: "pico", + PlatformID: senderID, + CanonicalID: identity.BuildCanonicalID("pico", senderID), + } + + if !c.IsAllowedSender(sender) { + return + } + + c.HandleMessage(c.ctx, peer, msg.ID, senderID, chatID, content, nil, metadata, sender) } // truncate truncates a string to maxLen runes. 
diff --git a/pkg/channels/qq/qq.go b/pkg/channels/qq/qq.go index c43c13655..85313efe5 100644 --- a/pkg/channels/qq/qq.go +++ b/pkg/channels/qq/qq.go @@ -16,6 +16,7 @@ import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" "github.com/sipeed/picoclaw/pkg/logger" ) @@ -168,6 +169,16 @@ func (c *QQChannel) handleC2CMessage() event.C2CMessageEventHandler { // 转发到消息总线 metadata := map[string]string{} + sender := bus.SenderInfo{ + Platform: "qq", + PlatformID: data.Author.ID, + CanonicalID: identity.BuildCanonicalID("qq", data.Author.ID), + } + + if !c.IsAllowedSender(sender) { + return nil + } + c.HandleMessage(c.ctx, bus.Peer{Kind: "direct", ID: senderID}, data.ID, @@ -176,6 +187,7 @@ func (c *QQChannel) handleC2CMessage() event.C2CMessageEventHandler { content, []string{}, metadata, + sender, ) return nil @@ -224,6 +236,16 @@ func (c *QQChannel) handleGroupATMessage() event.GroupATMessageEventHandler { "group_id": data.GroupID, } + sender := bus.SenderInfo{ + Platform: "qq", + PlatformID: data.Author.ID, + CanonicalID: identity.BuildCanonicalID("qq", data.Author.ID), + } + + if !c.IsAllowedSender(sender) { + return nil + } + c.HandleMessage(c.ctx, bus.Peer{Kind: "group", ID: data.GroupID}, data.ID, @@ -232,6 +254,7 @@ func (c *QQChannel) handleGroupATMessage() event.GroupATMessageEventHandler { content, []string{}, metadata, + sender, ) return nil diff --git a/pkg/channels/slack/slack.go b/pkg/channels/slack/slack.go index c6b3c829e..90c4297ca 100644 --- a/pkg/channels/slack/slack.go +++ b/pkg/channels/slack/slack.go @@ -13,6 +13,7 @@ import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/utils" @@ -252,7 +253,12 @@ func (c 
*SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { } // 检查白名单,避免为被拒绝的用户下载附件 - if !c.IsAllowed(ev.User) { + sender := bus.SenderInfo{ + Platform: "slack", + PlatformID: ev.User, + CanonicalID: identity.BuildCanonicalID("slack", ev.User), + } + if !c.IsAllowedSender(sender) { logger.DebugCF("slack", "Message rejected by allowlist", map[string]any{ "user_id": ev.User, }) @@ -360,7 +366,7 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { "has_thread": threadTS != "", }) - c.HandleMessage(c.ctx, peer, messageTS, senderID, chatID, content, mediaPaths, metadata) + c.HandleMessage(c.ctx, peer, messageTS, senderID, chatID, content, mediaPaths, metadata, sender) } func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) { @@ -368,7 +374,11 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) { return } - if !c.IsAllowed(ev.User) { + if !c.IsAllowedSender(bus.SenderInfo{ + Platform: "slack", + PlatformID: ev.User, + CanonicalID: identity.BuildCanonicalID("slack", ev.User), + }) { logger.DebugCF("slack", "Mention rejected by allowlist", map[string]any{ "user_id": ev.User, }) @@ -376,6 +386,11 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) { } senderID := ev.User + mentionSender := bus.SenderInfo{ + Platform: "slack", + PlatformID: senderID, + CanonicalID: identity.BuildCanonicalID("slack", senderID), + } channelID := ev.Channel threadTS := ev.ThreadTimeStamp messageTS := ev.TimeStamp @@ -433,7 +448,7 @@ func (c *SlackChannel) handleAppMention(ev *slackevents.AppMentionEvent) { "team_id": c.teamID, } - c.HandleMessage(c.ctx, mentionPeer, messageTS, senderID, chatID, content, nil, metadata) + c.HandleMessage(c.ctx, mentionPeer, messageTS, senderID, chatID, content, nil, metadata, mentionSender) } func (c *SlackChannel) handleSlashCommand(event socketmode.Event) { @@ -446,7 +461,12 @@ func (c *SlackChannel) handleSlashCommand(event socketmode.Event) { 
c.socketClient.Ack(*event.Request) } - if !c.IsAllowed(cmd.UserID) { + cmdSender := bus.SenderInfo{ + Platform: "slack", + PlatformID: cmd.UserID, + CanonicalID: identity.BuildCanonicalID("slack", cmd.UserID), + } + if !c.IsAllowedSender(cmdSender) { logger.DebugCF("slack", "Slash command rejected by allowlist", map[string]any{ "user_id": cmd.UserID, }) @@ -476,7 +496,17 @@ func (c *SlackChannel) handleSlashCommand(event socketmode.Event) { "text": utils.Truncate(content, 50), }) - c.HandleMessage(c.ctx, bus.Peer{Kind: "channel", ID: channelID}, "", senderID, chatID, content, nil, metadata) + c.HandleMessage( + c.ctx, + bus.Peer{Kind: "channel", ID: channelID}, + "", + senderID, + chatID, + content, + nil, + metadata, + cmdSender, + ) } func (c *SlackChannel) downloadSlackFile(file slack.File) string { diff --git a/pkg/channels/telegram/telegram.go b/pkg/channels/telegram/telegram.go index 31be4d489..6b5a84eda 100644 --- a/pkg/channels/telegram/telegram.go +++ b/pkg/channels/telegram/telegram.go @@ -19,6 +19,7 @@ import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/media" "github.com/sipeed/picoclaw/pkg/utils" @@ -289,21 +290,25 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes return fmt.Errorf("message sender (user) is nil") } - senderID := fmt.Sprintf("%d", user.ID) - if user.Username != "" { - senderID = fmt.Sprintf("%d|%s", user.ID, user.Username) + platformID := fmt.Sprintf("%d", user.ID) + sender := bus.SenderInfo{ + Platform: "telegram", + PlatformID: platformID, + CanonicalID: identity.BuildCanonicalID("telegram", platformID), + Username: user.Username, + DisplayName: user.FirstName, } // 检查白名单,避免为被拒绝的用户下载附件 - if !c.IsAllowed(senderID) { + if !c.IsAllowedSender(sender) { logger.DebugCF("telegram", "Message rejected by allowlist", 
map[string]any{ - "user_id": senderID, + "user_id": platformID, }) return nil } chatID := message.Chat.ID - c.chatIDs[senderID] = chatID + c.chatIDs[platformID] = chatID content := "" mediaPaths := []string{} @@ -401,7 +406,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes } logger.DebugCF("telegram", "Received message", map[string]any{ - "sender_id": senderID, + "sender_id": sender.CanonicalID, "chat_id": fmt.Sprintf("%d", chatID), "preview": utils.Truncate(content, 50), }) @@ -451,11 +456,12 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes c.HandleMessage(c.ctx, peer, messageID, - fmt.Sprintf("%d", user.ID), + platformID, fmt.Sprintf("%d", chatID), content, mediaPaths, metadata, + sender, ) return nil } diff --git a/pkg/channels/wecom/app.go b/pkg/channels/wecom/app.go index e822e67b2..f1e764864 100644 --- a/pkg/channels/wecom/app.go +++ b/pkg/channels/wecom/app.go @@ -19,6 +19,7 @@ import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/utils" ) @@ -629,8 +630,15 @@ func (c *WeComAppChannel) processMessage(ctx context.Context, msg WeComXMLMessag "preview": utils.Truncate(content, 50), }) + // Build sender info + appSender := bus.SenderInfo{ + Platform: "wecom", + PlatformID: senderID, + CanonicalID: identity.BuildCanonicalID("wecom", senderID), + } + // Handle the message through the base channel - c.HandleMessage(ctx, peer, messageID, senderID, chatID, content, nil, metadata) + c.HandleMessage(ctx, peer, messageID, senderID, chatID, content, nil, metadata, appSender) } // tokenRefreshLoop periodically refreshes the access token diff --git a/pkg/channels/wecom/bot.go b/pkg/channels/wecom/bot.go index 401c9c5ec..460997dab 100644 --- a/pkg/channels/wecom/bot.go +++ b/pkg/channels/wecom/bot.go @@ -15,6 +15,7 @@ 
import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/utils" ) @@ -398,8 +399,19 @@ func (c *WeComBotChannel) processMessage(ctx context.Context, msg WeComBotMessag "preview": utils.Truncate(content, 50), }) + // Build sender info + sender := bus.SenderInfo{ + Platform: "wecom", + PlatformID: senderID, + CanonicalID: identity.BuildCanonicalID("wecom", senderID), + } + + if !c.IsAllowedSender(sender) { + return + } + // Handle the message through the base channel - c.HandleMessage(ctx, peer, msg.MsgID, senderID, chatID, content, nil, metadata) + c.HandleMessage(ctx, peer, msg.MsgID, senderID, chatID, content, nil, metadata, sender) } // sendWebhookReply sends a reply using the webhook URL diff --git a/pkg/channels/whatsapp/whatsapp.go b/pkg/channels/whatsapp/whatsapp.go index b4599b5a0..106114090 100644 --- a/pkg/channels/whatsapp/whatsapp.go +++ b/pkg/channels/whatsapp/whatsapp.go @@ -12,6 +12,7 @@ import ( "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/identity" "github.com/sipeed/picoclaw/pkg/logger" "github.com/sipeed/picoclaw/pkg/utils" ) @@ -224,5 +225,18 @@ func (c *WhatsAppChannel) handleIncomingMessage(msg map[string]any) { "preview": utils.Truncate(content, 50), }) - c.HandleMessage(c.ctx, peer, messageID, senderID, chatID, content, mediaPaths, metadata) + sender := bus.SenderInfo{ + Platform: "whatsapp", + PlatformID: senderID, + CanonicalID: identity.BuildCanonicalID("whatsapp", senderID), + } + if display, ok := metadata["user_name"]; ok { + sender.DisplayName = display + } + + if !c.IsAllowedSender(sender) { + return + } + + c.HandleMessage(c.ctx, peer, messageID, senderID, chatID, content, mediaPaths, metadata, sender) } diff --git a/pkg/identity/identity.go 
b/pkg/identity/identity.go new file mode 100644 index 000000000..6bc09c210 --- /dev/null +++ b/pkg/identity/identity.go @@ -0,0 +1,107 @@ +// Package identity provides unified user identity utilities for PicoClaw. +// It introduces a canonical "platform:id" format and matching logic +// that is backward-compatible with all legacy allow-list formats. +package identity + +import ( + "strings" + + "github.com/sipeed/picoclaw/pkg/bus" +) + +// BuildCanonicalID constructs a canonical "platform:id" identifier. +// Both platform and platformID are lowercased and trimmed. +func BuildCanonicalID(platform, platformID string) string { + p := strings.ToLower(strings.TrimSpace(platform)) + id := strings.TrimSpace(platformID) + if p == "" || id == "" { + return "" + } + return p + ":" + id +} + +// ParseCanonicalID splits a canonical ID ("platform:id") into its parts. +// Returns ok=false if the input does not contain a colon separator. +func ParseCanonicalID(canonical string) (platform, id string, ok bool) { + canonical = strings.TrimSpace(canonical) + idx := strings.Index(canonical, ":") + if idx <= 0 || idx == len(canonical)-1 { + return "", "", false + } + return canonical[:idx], canonical[idx+1:], true +} + +// MatchAllowed checks whether the given sender matches a single allow-list entry. 
+// It is backward-compatible with all legacy formats: +// +// - "123456" → matches sender.PlatformID +// - "@alice" → matches sender.Username +// - "123456|alice" → matches PlatformID or Username +// - "telegram:123456" → exact match on sender.CanonicalID +func MatchAllowed(sender bus.SenderInfo, allowed string) bool { + allowed = strings.TrimSpace(allowed) + if allowed == "" { + return false + } + + // Try canonical match first: "platform:id" format + if platform, id, ok := ParseCanonicalID(allowed); ok { + // Only treat as canonical if the platform portion looks like a known platform name + // (not a pure-numeric string, which could be a compound ID) + if !isNumeric(platform) { + candidate := BuildCanonicalID(platform, id) + if candidate != "" && sender.CanonicalID != "" { + return strings.EqualFold(sender.CanonicalID, candidate) + } + // If sender has no canonical ID, try matching platform + platformID + return strings.EqualFold(platform, sender.Platform) && + sender.PlatformID == id + } + } + + // Strip leading "@" for username matching + trimmed := strings.TrimPrefix(allowed, "@") + + // Split compound "id|username" format + allowedID := trimmed + allowedUser := "" + if idx := strings.Index(trimmed, "|"); idx > 0 { + allowedID = trimmed[:idx] + allowedUser = trimmed[idx+1:] + } + + // Match against PlatformID + if sender.PlatformID != "" && sender.PlatformID == allowedID { + return true + } + + // Match against Username + if sender.Username != "" { + if sender.Username == trimmed || sender.Username == allowedUser { + return true + } + } + + // Match compound sender format against allowed parts + if allowedUser != "" && sender.PlatformID != "" && sender.PlatformID == allowedID { + return true + } + if allowedUser != "" && sender.Username != "" && sender.Username == allowedUser { + return true + } + + return false +} + +// isNumeric returns true if s consists entirely of digits. 
+func isNumeric(s string) bool { + if s == "" { + return false + } + for _, r := range s { + if r < '0' || r > '9' { + return false + } + } + return true +} diff --git a/pkg/identity/identity_test.go b/pkg/identity/identity_test.go new file mode 100644 index 000000000..3d24bd794 --- /dev/null +++ b/pkg/identity/identity_test.go @@ -0,0 +1,229 @@ +package identity + +import ( + "testing" + + "github.com/sipeed/picoclaw/pkg/bus" +) + +func TestBuildCanonicalID(t *testing.T) { + tests := []struct { + platform string + platformID string + want string + }{ + {"telegram", "123456", "telegram:123456"}, + {"Discord", "98765432", "discord:98765432"}, + {"SLACK", "U123ABC", "slack:U123ABC"}, + {"", "123", ""}, + {"telegram", "", ""}, + {" telegram ", " 123 ", "telegram:123"}, + } + + for _, tt := range tests { + got := BuildCanonicalID(tt.platform, tt.platformID) + if got != tt.want { + t.Errorf("BuildCanonicalID(%q, %q) = %q, want %q", + tt.platform, tt.platformID, got, tt.want) + } + } +} + +func TestParseCanonicalID(t *testing.T) { + tests := []struct { + input string + wantPlatform string + wantID string + wantOk bool + }{ + {"telegram:123456", "telegram", "123456", true}, + {"discord:98765432", "discord", "98765432", true}, + {"slack:U123ABC", "slack", "U123ABC", true}, + {"nocolon", "", "", false}, + {"", "", "", false}, + {":missing", "", "", false}, + {"missing:", "", "", false}, + } + + for _, tt := range tests { + platform, id, ok := ParseCanonicalID(tt.input) + if ok != tt.wantOk || platform != tt.wantPlatform || id != tt.wantID { + t.Errorf("ParseCanonicalID(%q) = (%q, %q, %v), want (%q, %q, %v)", + tt.input, platform, id, ok, + tt.wantPlatform, tt.wantID, tt.wantOk) + } + } +} + +func TestMatchAllowed(t *testing.T) { + telegramSender := bus.SenderInfo{ + Platform: "telegram", + PlatformID: "123456", + CanonicalID: "telegram:123456", + Username: "alice", + DisplayName: "Alice Smith", + } + + discordSender := bus.SenderInfo{ + Platform: "discord", + PlatformID: 
"98765432", + CanonicalID: "discord:98765432", + Username: "bob", + DisplayName: "bob#1234", + } + + noCanonicalSender := bus.SenderInfo{ + Platform: "telegram", + PlatformID: "999", + Username: "carol", + } + + tests := []struct { + name string + sender bus.SenderInfo + allowed string + want bool + }{ + // Pure numeric ID matching + { + name: "numeric ID matches PlatformID", + sender: telegramSender, + allowed: "123456", + want: true, + }, + { + name: "numeric ID does not match", + sender: telegramSender, + allowed: "654321", + want: false, + }, + // Username matching + { + name: "@username matches Username", + sender: telegramSender, + allowed: "@alice", + want: true, + }, + { + name: "@username does not match", + sender: telegramSender, + allowed: "@bob", + want: false, + }, + // Compound format "id|username" + { + name: "compound matches by ID", + sender: telegramSender, + allowed: "123456|alice", + want: true, + }, + { + name: "compound matches by username", + sender: telegramSender, + allowed: "999|alice", + want: true, + }, + { + name: "compound does not match", + sender: telegramSender, + allowed: "654321|bob", + want: false, + }, + // Canonical format "platform:id" + { + name: "canonical matches exactly", + sender: telegramSender, + allowed: "telegram:123456", + want: true, + }, + { + name: "canonical case-insensitive platform", + sender: telegramSender, + allowed: "Telegram:123456", + want: true, + }, + { + name: "canonical wrong platform", + sender: telegramSender, + allowed: "discord:123456", + want: false, + }, + { + name: "canonical wrong ID", + sender: telegramSender, + allowed: "telegram:654321", + want: false, + }, + // Cross-platform canonical + { + name: "discord canonical match", + sender: discordSender, + allowed: "discord:98765432", + want: true, + }, + { + name: "telegram canonical does not match discord sender", + sender: discordSender, + allowed: "telegram:98765432", + want: false, + }, + // Sender without canonical ID + { + name: 
"canonical match falls back to platform+platformID", + sender: noCanonicalSender, + allowed: "telegram:999", + want: true, + }, + { + name: "platform mismatch on fallback", + sender: noCanonicalSender, + allowed: "discord:999", + want: false, + }, + // Empty allowed string + { + name: "empty allowed never matches", + sender: telegramSender, + allowed: "", + want: false, + }, + // Whitespace handling + { + name: "trimmed allowed matches", + sender: telegramSender, + allowed: " 123456 ", + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := MatchAllowed(tt.sender, tt.allowed) + if got != tt.want { + t.Errorf("MatchAllowed(%+v, %q) = %v, want %v", + tt.sender, tt.allowed, got, tt.want) + } + }) + } +} + +func TestIsNumeric(t *testing.T) { + tests := []struct { + input string + want bool + }{ + {"123456", true}, + {"0", true}, + {"", false}, + {"abc", false}, + {"12a34", false}, + {"telegram", false}, + } + + for _, tt := range tests { + got := isNumeric(tt.input) + if got != tt.want { + t.Errorf("isNumeric(%q) = %v, want %v", tt.input, got, tt.want) + } + } +} diff --git a/pkg/routing/session_key.go b/pkg/routing/session_key.go index e12f0d1d8..eab592bec 100644 --- a/pkg/routing/session_key.go +++ b/pkg/routing/session_key.go @@ -163,6 +163,15 @@ func resolveLinkedPeerID(identityLinks map[string][]string, channel, peerID stri scopedCandidate := fmt.Sprintf("%s:%s", channel, strings.ToLower(peerID)) candidates[scopedCandidate] = true } + + // If peerID is already in canonical "platform:id" format, also add the + // bare ID part as a candidate for backward compatibility with identity_links + // that use raw IDs (e.g. "123" instead of "telegram:123"). 
+ if idx := strings.Index(rawCandidate, ":"); idx > 0 && idx < len(rawCandidate)-1 { + bareID := rawCandidate[idx+1:] + candidates[bareID] = true + } + if len(candidates) == 0 { return "" } diff --git a/pkg/routing/session_key_test.go b/pkg/routing/session_key_test.go index 81e4ce018..ad7a1ca02 100644 --- a/pkg/routing/session_key_test.go +++ b/pkg/routing/session_key_test.go @@ -115,6 +115,51 @@ func TestBuildAgentPeerSessionKey_IdentityLink(t *testing.T) { } } +func TestResolveLinkedPeerID_CanonicalPeerID(t *testing.T) { + // When peerID is already in canonical "platform:id" format, + // it should match identity_links that use the bare ID. + links := map[string][]string{ + "john": {"123"}, + } + got := resolveLinkedPeerID(links, "telegram", "telegram:123") + if got != "john" { + t.Errorf("resolveLinkedPeerID with canonical peerID = %q, want %q", got, "john") + } +} + +func TestResolveLinkedPeerID_CanonicalInLinks(t *testing.T) { + // When identity_links contain canonical IDs and peerID is canonical too + links := map[string][]string{ + "john": {"telegram:123", "discord:456"}, + } + got := resolveLinkedPeerID(links, "telegram", "telegram:123") + if got != "john" { + t.Errorf("resolveLinkedPeerID canonical in links = %q, want %q", got, "john") + } +} + +func TestResolveLinkedPeerID_BarePeerIDMatchesCanonicalLink(t *testing.T) { + // When peerID is bare "123" and links have "telegram:123", + // the scoped candidate "telegram:123" should match. 
+ links := map[string][]string{ + "john": {"telegram:123"}, + } + got := resolveLinkedPeerID(links, "telegram", "123") + if got != "john" { + t.Errorf("resolveLinkedPeerID bare peer matches canonical link = %q, want %q", got, "john") + } +} + +func TestResolveLinkedPeerID_NoMatch(t *testing.T) { + links := map[string][]string{ + "john": {"telegram:123"}, + } + got := resolveLinkedPeerID(links, "discord", "999") + if got != "" { + t.Errorf("resolveLinkedPeerID no match = %q, want empty", got) + } +} + func TestParseAgentSessionKey_Valid(t *testing.T) { parsed := ParseAgentSessionKey("agent:sales:telegram:direct:user123") if parsed == nil { From cea0b95f07a54db6078a4e8f7109c3cb9d02edfb Mon Sep 17 00:00:00 2001 From: Hoshina Date: Mon, 23 Feb 2026 08:20:15 +0800 Subject: [PATCH 24/28] refactor(loop): disable media cleanup to prevent premature file deletion --- pkg/agent/loop.go | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 0e2097488..773e8acd5 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -170,18 +170,20 @@ func (al *AgentLoop) Run(ctx context.Context) error { continue } - // Process message and ensure media is released afterward + // Process message func() { - defer func() { - if al.mediaStore != nil && msg.MediaScope != "" { - if releaseErr := al.mediaStore.ReleaseAll(msg.MediaScope); releaseErr != nil { - logger.WarnCF("agent", "Failed to release media", map[string]any{ - "scope": msg.MediaScope, - "error": releaseErr.Error(), - }) - } - } - }() + // TODO: Re-enable media cleanup after inbound media is properly consumed by the agent. + // Currently disabled because files are deleted before the LLM can access their content. 
+ // defer func() { + // if al.mediaStore != nil && msg.MediaScope != "" { + // if releaseErr := al.mediaStore.ReleaseAll(msg.MediaScope); releaseErr != nil { + // logger.WarnCF("agent", "Failed to release media", map[string]any{ + // "scope": msg.MediaScope, + // "error": releaseErr.Error(), + // }) + // } + // } + // }() response, err := al.processMessage(ctx, msg) if err != nil { From 94f59fbcab4b0785a76bb2beb987fc2819156da5 Mon Sep 17 00:00:00 2001 From: Hoshina Date: Mon, 23 Feb 2026 21:34:37 +0800 Subject: [PATCH 25/28] fix: address PR #662 review comments (bus drain, context timeouts, onebot leak) - Drain buffered messages in MessageBus.Close() so they aren't silently lost - Replace all context.TODO() with context.WithTimeout(5s) across 7 call sites - Fix OneBot pending channel leak: send nil sentinel in Stop() and handle nil response in sendAPIRequest() to unblock waiting goroutines --- pkg/agent/loop.go | 9 ++++++--- pkg/bus/bus.go | 38 +++++++++++++++++++++++++++++++++++ pkg/channels/onebot/onebot.go | 9 ++++++++- pkg/devices/service.go | 5 ++++- pkg/heartbeat/service.go | 4 +++- pkg/tools/cron.go | 8 ++++++-- pkg/tools/subagent.go | 4 +++- 7 files changed, 68 insertions(+), 9 deletions(-) diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 773e8acd5..088e8c4d2 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -121,12 +121,13 @@ func registerSharedTools( // Message tool messageTool := tools.NewMessageTool() messageTool.SetSendCallback(func(channel, chatID, content string) error { - msgBus.PublishOutbound(context.TODO(), bus.OutboundMessage{ + pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer pubCancel() + return msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{ Channel: channel, ChatID: chatID, Content: content, }) - return nil }) agent.Tools.Register(messageTool) @@ -835,7 +836,9 @@ func (al *AgentLoop) maybeSummarize(agent *AgentInstance, sessionKey, channel, c go func() { defer 
al.summarizing.Delete(summarizeKey) if !constants.IsInternalChannel(channel) { - al.bus.PublishOutbound(context.TODO(), bus.OutboundMessage{ + pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer pubCancel() + al.bus.PublishOutbound(pubCtx, bus.OutboundMessage{ Channel: channel, ChatID: chatID, Content: "Memory threshold reached. Optimizing conversation history...", diff --git a/pkg/bus/bus.go b/pkg/bus/bus.go index 6a1c987b7..d2b6838c5 100644 --- a/pkg/bus/bus.go +++ b/pkg/bus/bus.go @@ -4,6 +4,8 @@ import ( "context" "errors" "sync/atomic" + + "github.com/sipeed/picoclaw/pkg/logger" ) // ErrBusClosed is returned when publishing to a closed MessageBus. @@ -104,5 +106,41 @@ func (mb *MessageBus) SubscribeOutboundMedia(ctx context.Context) (OutboundMedia func (mb *MessageBus) Close() { if mb.closed.CompareAndSwap(false, true) { close(mb.done) + + // Drain buffered channels so messages aren't silently lost. + // Channels are NOT closed to avoid send-on-closed panics from concurrent publishers. 
+ drained := 0 + for { + select { + case <-mb.inbound: + drained++ + default: + goto doneInbound + } + } + doneInbound: + for { + select { + case <-mb.outbound: + drained++ + default: + goto doneOutbound + } + } + doneOutbound: + for { + select { + case <-mb.outboundMedia: + drained++ + default: + goto doneMedia + } + } + doneMedia: + if drained > 0 { + logger.DebugCF("bus", "Drained buffered messages during close", map[string]any{ + "count": drained, + }) + } } } diff --git a/pkg/channels/onebot/onebot.go b/pkg/channels/onebot/onebot.go index a748acaa0..feb198d7d 100644 --- a/pkg/channels/onebot/onebot.go +++ b/pkg/channels/onebot/onebot.go @@ -306,6 +306,9 @@ func (c *OneBotChannel) sendAPIRequest(action string, params any, timeout time.D select { case resp := <-ch: + if resp == nil { + return nil, fmt.Errorf("API request %s: channel stopped", action) + } return resp, nil case <-time.After(timeout): return nil, fmt.Errorf("API request %s timed out after %v", action, timeout) @@ -353,7 +356,11 @@ func (c *OneBotChannel) Stop(ctx context.Context) error { } c.pendingMu.Lock() - for echo := range c.pending { + for echo, ch := range c.pending { + select { + case ch <- nil: // non-blocking wake for blocked sendAPIRequest goroutines + default: + } delete(c.pending, echo) } c.pendingMu.Unlock() diff --git a/pkg/devices/service.go b/pkg/devices/service.go index 408e1c8aa..1bafe6085 100644 --- a/pkg/devices/service.go +++ b/pkg/devices/service.go @@ -4,6 +4,7 @@ import ( "context" "strings" "sync" + "time" "github.com/sipeed/picoclaw/pkg/bus" "github.com/sipeed/picoclaw/pkg/constants" @@ -127,7 +128,9 @@ func (s *Service) sendNotification(ev *events.DeviceEvent) { } msg := ev.FormatMessage() - msgBus.PublishOutbound(context.TODO(), bus.OutboundMessage{ + pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer pubCancel() + msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{ Channel: platform, ChatID: userID, Content: msg, diff --git 
a/pkg/heartbeat/service.go b/pkg/heartbeat/service.go index 62b321955..475f10509 100644 --- a/pkg/heartbeat/service.go +++ b/pkg/heartbeat/service.go @@ -308,7 +308,9 @@ func (hs *HeartbeatService) sendResponse(response string) { return } - msgBus.PublishOutbound(context.TODO(), bus.OutboundMessage{ + pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer pubCancel() + msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{ Channel: platform, ChatID: userID, Content: response, diff --git a/pkg/tools/cron.go b/pkg/tools/cron.go index 3c13f5968..52f914622 100644 --- a/pkg/tools/cron.go +++ b/pkg/tools/cron.go @@ -294,7 +294,9 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string { output = fmt.Sprintf("Scheduled command '%s' executed:\n%s", job.Payload.Command, result.ForLLM) } - t.msgBus.PublishOutbound(context.TODO(), bus.OutboundMessage{ + pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer pubCancel() + t.msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{ Channel: channel, ChatID: chatID, Content: output, @@ -304,7 +306,9 @@ func (t *CronTool) ExecuteJob(ctx context.Context, job *cron.CronJob) string { // If deliver=true, send message directly without agent processing if job.Payload.Deliver { - t.msgBus.PublishOutbound(context.TODO(), bus.OutboundMessage{ + pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer pubCancel() + t.msgBus.PublishOutbound(pubCtx, bus.OutboundMessage{ Channel: channel, ChatID: chatID, Content: job.Payload.Message, diff --git a/pkg/tools/subagent.go b/pkg/tools/subagent.go index 99821daf9..fee53fc28 100644 --- a/pkg/tools/subagent.go +++ b/pkg/tools/subagent.go @@ -218,7 +218,9 @@ After completing the task, provide a clear summary of what was done.` // Send announce message back to main agent if sm.bus != nil { announceContent := fmt.Sprintf("Task '%s' completed.\n\nResult:\n%s", task.Label, task.Result) - 
sm.bus.PublishInbound(context.TODO(), bus.InboundMessage{ + pubCtx, pubCancel := context.WithTimeout(context.Background(), 5*time.Second) + defer pubCancel() + sm.bus.PublishInbound(pubCtx, bus.InboundMessage{ Channel: "system", SenderID: fmt.Sprintf("subagent:%s", task.ID), // Format: "original_channel:original_chat_id" for routing back From 692efb21287cd6fde4bc53e011a5dbde9706ef58 Mon Sep 17 00:00:00 2001 From: Hoshina Date: Tue, 24 Feb 2026 12:17:11 +0800 Subject: [PATCH 26/28] chore: apply PR #697 comment translations to refactored channel subpackages Translate Chinese comments to English in qq, slack, and telegram channel implementations, following the translation work done in PR #697. The original PR modified the old parent package files, but these have been moved to subpackages during the refactor, so translations are applied to the new locations. --- pkg/channels/qq/qq.go | 36 +++++++++++++++---------------- pkg/channels/slack/slack.go | 2 +- pkg/channels/telegram/telegram.go | 2 +- 3 files changed, 20 insertions(+), 20 deletions(-) diff --git a/pkg/channels/qq/qq.go b/pkg/channels/qq/qq.go index 85313efe5..1e2cc2354 100644 --- a/pkg/channels/qq/qq.go +++ b/pkg/channels/qq/qq.go @@ -51,31 +51,31 @@ func (c *QQChannel) Start(ctx context.Context) error { logger.InfoC("qq", "Starting QQ bot (WebSocket mode)") - // 创建 token source + // create token source credentials := &token.QQBotCredentials{ AppID: c.config.AppID, AppSecret: c.config.AppSecret, } c.tokenSource = token.NewQQBotTokenSource(credentials) - // 创建子 context + // create child context c.ctx, c.cancel = context.WithCancel(ctx) - // 启动自动刷新 token 协程 + // start auto-refresh token goroutine if err := token.StartRefreshAccessToken(c.ctx, c.tokenSource); err != nil { return fmt.Errorf("failed to start token refresh: %w", err) } - // 初始化 OpenAPI 客户端 + // initialize OpenAPI client c.api = botgo.NewOpenAPI(c.config.AppID, c.tokenSource).WithTimeout(5 * time.Second) - // 注册事件处理器 + // register event handlers 
intent := event.RegisterHandlers( c.handleC2CMessage(), c.handleGroupATMessage(), ) - // 获取 WebSocket 接入点 + // get WebSocket endpoint wsInfo, err := c.api.WS(c.ctx, nil, "") if err != nil { return fmt.Errorf("failed to get websocket info: %w", err) @@ -85,10 +85,10 @@ func (c *QQChannel) Start(ctx context.Context) error { "shards": wsInfo.Shards, }) - // 创建并保存 sessionManager + // create and save sessionManager c.sessionManager = botgo.NewSessionManager() - // 在 goroutine 中启动 WebSocket 连接,避免阻塞 + // start WebSocket connection in goroutine to avoid blocking go func() { if err := c.sessionManager.Start(wsInfo, c.tokenSource, &intent); err != nil { logger.ErrorCF("qq", "WebSocket session error", map[string]any{ @@ -120,12 +120,12 @@ func (c *QQChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { return channels.ErrNotRunning } - // 构造消息 + // construct message msgToCreate := &dto.MessageToCreate{ Content: msg.Content, } - // C2C 消息发送 + // send C2C message _, err := c.api.PostC2CMessage(ctx, msg.ChatID, msgToCreate) if err != nil { logger.ErrorCF("qq", "Failed to send C2C message", map[string]any{ @@ -137,15 +137,15 @@ func (c *QQChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { return nil } -// handleC2CMessage 处理 QQ 私聊消息 +// handleC2CMessage handles QQ private messages func (c *QQChannel) handleC2CMessage() event.C2CMessageEventHandler { return func(event *dto.WSPayload, data *dto.WSC2CMessageData) error { - // 去重检查 + // deduplication check if c.isDuplicate(data.ID) { return nil } - // 提取用户信息 + // extract user info var senderID string if data.Author != nil && data.Author.ID != "" { senderID = data.Author.ID @@ -154,7 +154,7 @@ func (c *QQChannel) handleC2CMessage() event.C2CMessageEventHandler { return nil } - // 提取消息内容 + // extract message content content := data.Content if content == "" { logger.DebugC("qq", "Received empty message, ignoring") @@ -194,15 +194,15 @@ func (c *QQChannel) handleC2CMessage() event.C2CMessageEventHandler { } 
} -// handleGroupATMessage 处理群@消息 +// handleGroupATMessage handles QQ group @ messages func (c *QQChannel) handleGroupATMessage() event.GroupATMessageEventHandler { return func(event *dto.WSPayload, data *dto.WSGroupATMessageData) error { - // 去重检查 + // deduplication check if c.isDuplicate(data.ID) { return nil } - // 提取用户信息 + // extract user info var senderID string if data.Author != nil && data.Author.ID != "" { senderID = data.Author.ID @@ -211,7 +211,7 @@ func (c *QQChannel) handleGroupATMessage() event.GroupATMessageEventHandler { return nil } - // 提取消息内容(去掉 @ 机器人部分) + // extract message content (remove @ bot part) content := data.Content if content == "" { logger.DebugC("qq", "Received empty group message, ignoring") diff --git a/pkg/channels/slack/slack.go b/pkg/channels/slack/slack.go index 90c4297ca..7128980e4 100644 --- a/pkg/channels/slack/slack.go +++ b/pkg/channels/slack/slack.go @@ -252,7 +252,7 @@ func (c *SlackChannel) handleMessageEvent(ev *slackevents.MessageEvent) { return } - // 检查白名单,避免为被拒绝的用户下载附件 + // check allowlist to avoid downloading attachments for rejected users sender := bus.SenderInfo{ Platform: "slack", PlatformID: ev.User, diff --git a/pkg/channels/telegram/telegram.go b/pkg/channels/telegram/telegram.go index 6b5a84eda..005b311a2 100644 --- a/pkg/channels/telegram/telegram.go +++ b/pkg/channels/telegram/telegram.go @@ -299,7 +299,7 @@ func (c *TelegramChannel) handleMessage(ctx context.Context, message *telego.Mes DisplayName: user.FirstName, } - // 检查白名单,避免为被拒绝的用户下载附件 + // check allowlist to avoid downloading attachments for rejected users if !c.IsAllowedSender(sender) { logger.DebugCF("telegram", "Message rejected by allowlist", map[string]any{ "user_id": platformID, From b4a14fa199dd59e93c4e593b41eed8a7979153f0 Mon Sep 17 00:00:00 2001 From: Avisek Date: Thu, 26 Feb 2026 10:54:51 +0530 Subject: [PATCH 27/28] feat: Introduce LLM reasoning fields to LLM responses and enable routing reasoning output to dedicated channels. 
--- pkg/agent/loop.go | 34 +++++ pkg/agent/loop_test.go | 167 ++++++++++++++++++++++++ pkg/channels/base.go | 11 ++ pkg/channels/dingtalk/dingtalk.go | 1 + pkg/channels/discord/discord.go | 1 + pkg/channels/feishu/feishu_64.go | 1 + pkg/channels/line/line.go | 1 + pkg/channels/maixcam/maixcam.go | 8 +- pkg/channels/onebot/onebot.go | 1 + pkg/channels/qq/qq.go | 1 + pkg/channels/slack/slack.go | 1 + pkg/channels/telegram/telegram.go | 1 + pkg/channels/wecom/app.go | 1 + pkg/channels/wecom/bot.go | 1 + pkg/channels/whatsapp/whatsapp.go | 9 +- pkg/config/config.go | 146 +++++++++++---------- pkg/providers/openai_compat/provider.go | 17 ++- pkg/providers/protocoltypes/types.go | 17 ++- 18 files changed, 339 insertions(+), 80 deletions(-) diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 088e8c4d2..e2bd222b9 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -525,6 +525,28 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opt return finalContent, nil } +func (al *AgentLoop) targetReasoningChannelID(channelName string) (chatID string) { + if al.channelManager == nil { + return "" + } + if ch, ok := al.channelManager.GetChannel(channelName); ok { + return ch.ReasoningChannelID() + } + return "" +} + +func (al *AgentLoop) handleReasoning(ctx context.Context, reasoningContent, channelName, channelID string) { + if reasoningContent == "" || channelName == "" || channelID == "" { + return + } + + al.bus.PublishOutbound(ctx, bus.OutboundMessage{ + Channel: channelName, + ChatID: channelID, + Content: reasoningContent, + }) +} + // runLLMIteration executes the LLM call loop with tool handling. 
func (al *AgentLoop) runLLMIteration( ctx context.Context, @@ -649,6 +671,18 @@ func (al *AgentLoop) runLLMIteration( return "", iteration, fmt.Errorf("LLM call failed after retries: %w", err) } + go al.handleReasoning(ctx, response.Reasoning, opts.Channel, al.targetReasoningChannelID(opts.Channel)) + + logger.DebugCF("agent", "LLM response", + map[string]any{ + "agent_id": agent.ID, + "iteration": iteration, + "content_chars": len(response.Content), + "tool_calls": len(response.ToolCalls), + "reasoning": response.Reasoning, + "target_channel": al.targetReasoningChannelID(opts.Channel), + "channel": opts.Channel, + }) // Check if no tool calls - we're done if len(response.ToolCalls) == 0 { finalContent = response.Content diff --git a/pkg/agent/loop_test.go b/pkg/agent/loop_test.go index 4414398b1..6dfc7ef3e 100644 --- a/pkg/agent/loop_test.go +++ b/pkg/agent/loop_test.go @@ -9,11 +9,23 @@ import ( "time" "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/channels" "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/providers" "github.com/sipeed/picoclaw/pkg/tools" ) +type fakeChannel struct{ id string } + +func (f *fakeChannel) Name() string { return "fake" } +func (f *fakeChannel) Start(ctx context.Context) error { return nil } +func (f *fakeChannel) Stop(ctx context.Context) error { return nil } +func (f *fakeChannel) Send(ctx context.Context, msg bus.OutboundMessage) error { return nil } +func (f *fakeChannel) IsRunning() bool { return true } +func (f *fakeChannel) IsAllowed(string) bool { return true } +func (f *fakeChannel) IsAllowedSender(sender bus.SenderInfo) bool { return true } +func (f *fakeChannel) ReasoningChannelID() string { return f.id } + func TestRecordLastChannel(t *testing.T) { // Create temp workspace tmpDir, err := os.MkdirTemp("", "agent-test-*") @@ -631,3 +643,158 @@ func TestAgentLoop_ContextExhaustionRetry(t *testing.T) { t.Errorf("Expected history to be compressed (len < 8), got %d", 
len(finalHistory)) } } + +func TestTargetReasoningChannelID_AllChannels(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + + al := NewAgentLoop(cfg, bus.NewMessageBus(), &mockProvider{}) + chManager, err := channels.NewManager(&config.Config{}, bus.NewMessageBus(), nil) + if err != nil { + t.Fatalf("Failed to create channel manager: %v", err) + } + for name, id := range map[string]string{ + "whatsapp": "rid-whatsapp", + "telegram": "rid-telegram", + "feishu": "rid-feishu", + "discord": "rid-discord", + "maixcam": "rid-maixcam", + "qq": "rid-qq", + "dingtalk": "rid-dingtalk", + "slack": "rid-slack", + "line": "rid-line", + "onebot": "rid-onebot", + "wecom": "rid-wecom", + "wecom_app": "rid-wecom-app", + } { + chManager.RegisterChannel(name, &fakeChannel{id: id}) + } + al.SetChannelManager(chManager) + tests := []struct { + channel string + wantID string + }{ + {channel: "whatsapp", wantID: "rid-whatsapp"}, + {channel: "telegram", wantID: "rid-telegram"}, + {channel: "feishu", wantID: "rid-feishu"}, + {channel: "discord", wantID: "rid-discord"}, + {channel: "maixcam", wantID: "rid-maixcam"}, + {channel: "qq", wantID: "rid-qq"}, + {channel: "dingtalk", wantID: "rid-dingtalk"}, + {channel: "slack", wantID: "rid-slack"}, + {channel: "line", wantID: "rid-line"}, + {channel: "onebot", wantID: "rid-onebot"}, + {channel: "wecom", wantID: "rid-wecom"}, + {channel: "wecom_app", wantID: "rid-wecom-app"}, + {channel: "unknown", wantID: ""}, + } + + for _, tt := range tests { + t.Run(tt.channel, func(t *testing.T) { + got := al.targetReasoningChannelID(tt.channel) + if got != tt.wantID { + t.Fatalf("targetReasoningChannelID(%q) = %q, want %q", tt.channel, got, tt.wantID) + 
} + }) + } +} + +func TestHandleReasoning(t *testing.T) { + newLoop := func(t *testing.T) (*AgentLoop, *bus.MessageBus) { + t.Helper() + tmpDir, err := os.MkdirTemp("", "agent-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + t.Cleanup(func() { _ = os.RemoveAll(tmpDir) }) + cfg := &config.Config{ + Agents: config.AgentsConfig{ + Defaults: config.AgentDefaults{ + Workspace: tmpDir, + Model: "test-model", + MaxTokens: 4096, + MaxToolIterations: 10, + }, + }, + } + msgBus := bus.NewMessageBus() + return NewAgentLoop(cfg, msgBus, &mockProvider{}), msgBus + } + + t.Run("skips when any required field is empty", func(t *testing.T) { + al, msgBus := newLoop(t) + al.handleReasoning(context.Background(), "reasoning", "telegram", "") + + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) + defer cancel() + if msg, ok := msgBus.SubscribeOutbound(ctx); ok { + t.Fatalf("expected no outbound message, got %+v", msg) + } + }) + + t.Run("publishes one message for non telegram", func(t *testing.T) { + al, msgBus := newLoop(t) + al.handleReasoning(context.Background(), "hello reasoning", "slack", "channel-1") + + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + msg, ok := msgBus.SubscribeOutbound(ctx) + if !ok { + t.Fatal("expected an outbound message") + } + if msg.Channel != "slack" || msg.ChatID != "channel-1" || msg.Content != "hello reasoning" { + t.Fatalf("unexpected outbound message: %+v", msg) + } + }) + + t.Run("publishes one message for telegram", func(t *testing.T) { + al, msgBus := newLoop(t) + reasoning := "hello telegram reasoning" + al.handleReasoning(context.Background(), reasoning, "telegram", "tg-chat") + + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + msg, ok := msgBus.SubscribeOutbound(ctx) + if !ok { + t.Fatal("expected outbound message") + } + + if msg.Channel != "telegram" { + t.Fatalf("expected telegram 
channel message, got %+v", msg) + } + if msg.ChatID != "tg-chat" { + t.Fatalf("expected chatID tg-chat, got %+v", msg) + } + if msg.Content != reasoning { + t.Fatalf("content mismatch: got %q want %q", msg.Content, reasoning) + } + }) + t.Run("expired ctx", func(t *testing.T) { + al, msgBus := newLoop(t) + reasoning := "hello telegram reasoning" + ctx, cancel := context.WithCancel(context.Background()) + cancel() + al.handleReasoning(ctx, reasoning, "telegram", "tg-chat") + + ctx, cancel = context.WithTimeout(context.Background(), 200*time.Millisecond) + defer cancel() + msg, ok := msgBus.SubscribeOutbound(ctx) + if ok { + t.Fatalf("expected no outbound message, got %+v", msg) + } + }) +} diff --git a/pkg/channels/base.go b/pkg/channels/base.go index 418933af7..c8c721341 100644 --- a/pkg/channels/base.go +++ b/pkg/channels/base.go @@ -22,6 +22,7 @@ type Channel interface { IsRunning() bool IsAllowed(senderID string) bool IsAllowedSender(sender bus.SenderInfo) bool + ReasoningChannelID() string } // BaseChannelOption is a functional option for configuring a BaseChannel. @@ -39,6 +40,11 @@ func WithGroupTrigger(gt config.GroupTriggerConfig) BaseChannelOption { return func(c *BaseChannel) { c.groupTrigger = gt } } +// WithReasoningChannelID sets the reasoning channel ID where thoughts should be sent. +func WithReasoningChannelID(id string) BaseChannelOption { + return func(c *BaseChannel) { c.reasoningChannelID = id } +} + // MessageLengthProvider is an opt-in interface that channels implement // to advertise their maximum message length. The Manager uses this via // type assertion to decide whether to split outbound messages. 
@@ -56,6 +62,7 @@ type BaseChannel struct { groupTrigger config.GroupTriggerConfig mediaStore media.MediaStore placeholderRecorder PlaceholderRecorder + reasoningChannelID string } func NewBaseChannel( @@ -127,6 +134,10 @@ func (c *BaseChannel) Name() string { return c.name } +func (c *BaseChannel) ReasoningChannelID() string { + return c.reasoningChannelID +} + func (c *BaseChannel) IsRunning() bool { return c.running.Load() } diff --git a/pkg/channels/dingtalk/dingtalk.go b/pkg/channels/dingtalk/dingtalk.go index 7a3aaca78..8642ad362 100644 --- a/pkg/channels/dingtalk/dingtalk.go +++ b/pkg/channels/dingtalk/dingtalk.go @@ -42,6 +42,7 @@ func NewDingTalkChannel(cfg config.DingTalkConfig, messageBus *bus.MessageBus) ( base := channels.NewBaseChannel("dingtalk", cfg, messageBus, cfg.AllowFrom, channels.WithMaxMessageLength(20000), channels.WithGroupTrigger(cfg.GroupTrigger), + channels.WithReasoningChannelID(cfg.ReasoningChannelID), ) return &DingTalkChannel{ diff --git a/pkg/channels/discord/discord.go b/pkg/channels/discord/discord.go index dc49e7413..fe0f8e82c 100644 --- a/pkg/channels/discord/discord.go +++ b/pkg/channels/discord/discord.go @@ -43,6 +43,7 @@ func NewDiscordChannel(cfg config.DiscordConfig, bus *bus.MessageBus) (*DiscordC base := channels.NewBaseChannel("discord", cfg, bus, cfg.AllowFrom, channels.WithMaxMessageLength(2000), channels.WithGroupTrigger(cfg.GroupTrigger), + channels.WithReasoningChannelID(cfg.ReasoningChannelID), ) return &DiscordChannel{ diff --git a/pkg/channels/feishu/feishu_64.go b/pkg/channels/feishu/feishu_64.go index 62bf69486..1db1bf669 100644 --- a/pkg/channels/feishu/feishu_64.go +++ b/pkg/channels/feishu/feishu_64.go @@ -35,6 +35,7 @@ type FeishuChannel struct { func NewFeishuChannel(cfg config.FeishuConfig, bus *bus.MessageBus) (*FeishuChannel, error) { base := channels.NewBaseChannel("feishu", cfg, bus, cfg.AllowFrom, channels.WithGroupTrigger(cfg.GroupTrigger), + channels.WithReasoningChannelID(cfg.ReasoningChannelID), 
) return &FeishuChannel{ diff --git a/pkg/channels/line/line.go b/pkg/channels/line/line.go index 28d5ad8f7..21eb4cb67 100644 --- a/pkg/channels/line/line.go +++ b/pkg/channels/line/line.go @@ -63,6 +63,7 @@ func NewLINEChannel(cfg config.LINEConfig, messageBus *bus.MessageBus) (*LINECha base := channels.NewBaseChannel("line", cfg, messageBus, cfg.AllowFrom, channels.WithMaxMessageLength(5000), channels.WithGroupTrigger(cfg.GroupTrigger), + channels.WithReasoningChannelID(cfg.ReasoningChannelID), ) return &LINEChannel{ diff --git a/pkg/channels/maixcam/maixcam.go b/pkg/channels/maixcam/maixcam.go index 142a4b7e7..ff9a3ed1a 100644 --- a/pkg/channels/maixcam/maixcam.go +++ b/pkg/channels/maixcam/maixcam.go @@ -33,7 +33,13 @@ type MaixCamMessage struct { } func NewMaixCamChannel(cfg config.MaixCamConfig, bus *bus.MessageBus) (*MaixCamChannel, error) { - base := channels.NewBaseChannel("maixcam", cfg, bus, cfg.AllowFrom) + base := channels.NewBaseChannel( + "maixcam", + cfg, + bus, + cfg.AllowFrom, + channels.WithReasoningChannelID(cfg.ReasoningChannelID), + ) return &MaixCamChannel{ BaseChannel: base, diff --git a/pkg/channels/onebot/onebot.go b/pkg/channels/onebot/onebot.go index feb198d7d..e0be58fa0 100644 --- a/pkg/channels/onebot/onebot.go +++ b/pkg/channels/onebot/onebot.go @@ -100,6 +100,7 @@ type oneBotMessageSegment struct { func NewOneBotChannel(cfg config.OneBotConfig, messageBus *bus.MessageBus) (*OneBotChannel, error) { base := channels.NewBaseChannel("onebot", cfg, messageBus, cfg.AllowFrom, channels.WithGroupTrigger(cfg.GroupTrigger), + channels.WithReasoningChannelID(cfg.ReasoningChannelID), ) const dedupSize = 1024 diff --git a/pkg/channels/qq/qq.go b/pkg/channels/qq/qq.go index 1e2cc2354..112964143 100644 --- a/pkg/channels/qq/qq.go +++ b/pkg/channels/qq/qq.go @@ -35,6 +35,7 @@ type QQChannel struct { func NewQQChannel(cfg config.QQConfig, messageBus *bus.MessageBus) (*QQChannel, error) { base := channels.NewBaseChannel("qq", cfg, messageBus, 
cfg.AllowFrom, channels.WithGroupTrigger(cfg.GroupTrigger), + channels.WithReasoningChannelID(cfg.ReasoningChannelID), ) return &QQChannel{ diff --git a/pkg/channels/slack/slack.go b/pkg/channels/slack/slack.go index 7128980e4..5e2d5dc4b 100644 --- a/pkg/channels/slack/slack.go +++ b/pkg/channels/slack/slack.go @@ -51,6 +51,7 @@ func NewSlackChannel(cfg config.SlackConfig, messageBus *bus.MessageBus) (*Slack base := channels.NewBaseChannel("slack", cfg, messageBus, cfg.AllowFrom, channels.WithMaxMessageLength(40000), channels.WithGroupTrigger(cfg.GroupTrigger), + channels.WithReasoningChannelID(cfg.ReasoningChannelID), ) return &SlackChannel{ diff --git a/pkg/channels/telegram/telegram.go b/pkg/channels/telegram/telegram.go index 005b311a2..86bfc89f8 100644 --- a/pkg/channels/telegram/telegram.go +++ b/pkg/channels/telegram/telegram.go @@ -71,6 +71,7 @@ func NewTelegramChannel(cfg *config.Config, bus *bus.MessageBus) (*TelegramChann telegramCfg.AllowFrom, channels.WithMaxMessageLength(4096), channels.WithGroupTrigger(telegramCfg.GroupTrigger), + channels.WithReasoningChannelID(telegramCfg.ReasoningChannelID), ) return &TelegramChannel{ diff --git a/pkg/channels/wecom/app.go b/pkg/channels/wecom/app.go index f1e764864..409aa8e96 100644 --- a/pkg/channels/wecom/app.go +++ b/pkg/channels/wecom/app.go @@ -126,6 +126,7 @@ func NewWeComAppChannel(cfg config.WeComAppConfig, messageBus *bus.MessageBus) ( base := channels.NewBaseChannel("wecom_app", cfg, messageBus, cfg.AllowFrom, channels.WithMaxMessageLength(2048), channels.WithGroupTrigger(cfg.GroupTrigger), + channels.WithReasoningChannelID(cfg.ReasoningChannelID), ) return &WeComAppChannel{ diff --git a/pkg/channels/wecom/bot.go b/pkg/channels/wecom/bot.go index 460997dab..4c576b84b 100644 --- a/pkg/channels/wecom/bot.go +++ b/pkg/channels/wecom/bot.go @@ -90,6 +90,7 @@ func NewWeComBotChannel(cfg config.WeComConfig, messageBus *bus.MessageBus) (*We base := channels.NewBaseChannel("wecom", cfg, messageBus, 
cfg.AllowFrom, channels.WithMaxMessageLength(2048), channels.WithGroupTrigger(cfg.GroupTrigger), + channels.WithReasoningChannelID(cfg.ReasoningChannelID), ) return &WeComBotChannel{ diff --git a/pkg/channels/whatsapp/whatsapp.go b/pkg/channels/whatsapp/whatsapp.go index 106114090..5c1b639b3 100644 --- a/pkg/channels/whatsapp/whatsapp.go +++ b/pkg/channels/whatsapp/whatsapp.go @@ -29,7 +29,14 @@ type WhatsAppChannel struct { } func NewWhatsAppChannel(cfg config.WhatsAppConfig, bus *bus.MessageBus) (*WhatsAppChannel, error) { - base := channels.NewBaseChannel("whatsapp", cfg, bus, cfg.AllowFrom, channels.WithMaxMessageLength(65536)) + base := channels.NewBaseChannel( + "whatsapp", + cfg, + bus, + cfg.AllowFrom, + channels.WithMaxMessageLength(65536), + channels.WithReasoningChannelID(cfg.ReasoningChannelID), + ) return &WhatsAppChannel{ BaseChannel: base, diff --git a/pkg/config/config.go b/pkg/config/config.go index fd5def625..b2385ede5 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -213,72 +213,80 @@ type PlaceholderConfig struct { } type WhatsAppConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WHATSAPP_ENABLED"` - BridgeURL string `json:"bridge_url" env:"PICOCLAW_CHANNELS_WHATSAPP_BRIDGE_URL"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WHATSAPP_ALLOW_FROM"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WHATSAPP_ENABLED"` + BridgeURL string `json:"bridge_url" env:"PICOCLAW_CHANNELS_WHATSAPP_BRIDGE_URL"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WHATSAPP_ALLOW_FROM"` + ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WHATSAPP_REASONING_CHANNEL_ID"` } type TelegramConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_TELEGRAM_ENABLED"` - Token string `json:"token" env:"PICOCLAW_CHANNELS_TELEGRAM_TOKEN"` - Proxy string `json:"proxy" env:"PICOCLAW_CHANNELS_TELEGRAM_PROXY"` - AllowFrom FlexibleStringSlice 
`json:"allow_from" env:"PICOCLAW_CHANNELS_TELEGRAM_ALLOW_FROM"` - GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` - Typing TypingConfig `json:"typing,omitempty"` - Placeholder PlaceholderConfig `json:"placeholder,omitempty"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_TELEGRAM_ENABLED"` + Token string `json:"token" env:"PICOCLAW_CHANNELS_TELEGRAM_TOKEN"` + Proxy string `json:"proxy" env:"PICOCLAW_CHANNELS_TELEGRAM_PROXY"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_TELEGRAM_ALLOW_FROM"` + GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` + Typing TypingConfig `json:"typing,omitempty"` + Placeholder PlaceholderConfig `json:"placeholder,omitempty"` + ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_TELEGRAM_REASONING_CHANNEL_ID"` } type FeishuConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_FEISHU_ENABLED"` - AppID string `json:"app_id" env:"PICOCLAW_CHANNELS_FEISHU_APP_ID"` - AppSecret string `json:"app_secret" env:"PICOCLAW_CHANNELS_FEISHU_APP_SECRET"` - EncryptKey string `json:"encrypt_key" env:"PICOCLAW_CHANNELS_FEISHU_ENCRYPT_KEY"` - VerificationToken string `json:"verification_token" env:"PICOCLAW_CHANNELS_FEISHU_VERIFICATION_TOKEN"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_FEISHU_ALLOW_FROM"` - GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_FEISHU_ENABLED"` + AppID string `json:"app_id" env:"PICOCLAW_CHANNELS_FEISHU_APP_ID"` + AppSecret string `json:"app_secret" env:"PICOCLAW_CHANNELS_FEISHU_APP_SECRET"` + EncryptKey string `json:"encrypt_key" env:"PICOCLAW_CHANNELS_FEISHU_ENCRYPT_KEY"` + VerificationToken string `json:"verification_token" env:"PICOCLAW_CHANNELS_FEISHU_VERIFICATION_TOKEN"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_FEISHU_ALLOW_FROM"` + GroupTrigger GroupTriggerConfig 
`json:"group_trigger,omitempty"` + ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_FEISHU_REASONING_CHANNEL_ID"` } type DiscordConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DISCORD_ENABLED"` - Token string `json:"token" env:"PICOCLAW_CHANNELS_DISCORD_TOKEN"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_DISCORD_ALLOW_FROM"` - MentionOnly bool `json:"mention_only" env:"PICOCLAW_CHANNELS_DISCORD_MENTION_ONLY"` - GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` - Typing TypingConfig `json:"typing,omitempty"` - Placeholder PlaceholderConfig `json:"placeholder,omitempty"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DISCORD_ENABLED"` + Token string `json:"token" env:"PICOCLAW_CHANNELS_DISCORD_TOKEN"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_DISCORD_ALLOW_FROM"` + MentionOnly bool `json:"mention_only" env:"PICOCLAW_CHANNELS_DISCORD_MENTION_ONLY"` + GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` + Typing TypingConfig `json:"typing,omitempty"` + Placeholder PlaceholderConfig `json:"placeholder,omitempty"` + ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_DISCORD_REASONING_CHANNEL_ID"` } type MaixCamConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_MAIXCAM_ENABLED"` - Host string `json:"host" env:"PICOCLAW_CHANNELS_MAIXCAM_HOST"` - Port int `json:"port" env:"PICOCLAW_CHANNELS_MAIXCAM_PORT"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_MAIXCAM_ALLOW_FROM"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_MAIXCAM_ENABLED"` + Host string `json:"host" env:"PICOCLAW_CHANNELS_MAIXCAM_HOST"` + Port int `json:"port" env:"PICOCLAW_CHANNELS_MAIXCAM_PORT"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_MAIXCAM_ALLOW_FROM"` + ReasoningChannelID string `json:"reasoning_channel_id" 
env:"PICOCLAW_CHANNELS_MAIXCAM_REASONING_CHANNEL_ID"` } type QQConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_QQ_ENABLED"` - AppID string `json:"app_id" env:"PICOCLAW_CHANNELS_QQ_APP_ID"` - AppSecret string `json:"app_secret" env:"PICOCLAW_CHANNELS_QQ_APP_SECRET"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_QQ_ALLOW_FROM"` - GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_QQ_ENABLED"` + AppID string `json:"app_id" env:"PICOCLAW_CHANNELS_QQ_APP_ID"` + AppSecret string `json:"app_secret" env:"PICOCLAW_CHANNELS_QQ_APP_SECRET"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_QQ_ALLOW_FROM"` + GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` + ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_QQ_REASONING_CHANNEL_ID"` } type DingTalkConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DINGTALK_ENABLED"` - ClientID string `json:"client_id" env:"PICOCLAW_CHANNELS_DINGTALK_CLIENT_ID"` - ClientSecret string `json:"client_secret" env:"PICOCLAW_CHANNELS_DINGTALK_CLIENT_SECRET"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_DINGTALK_ALLOW_FROM"` - GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_DINGTALK_ENABLED"` + ClientID string `json:"client_id" env:"PICOCLAW_CHANNELS_DINGTALK_CLIENT_ID"` + ClientSecret string `json:"client_secret" env:"PICOCLAW_CHANNELS_DINGTALK_CLIENT_SECRET"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_DINGTALK_ALLOW_FROM"` + GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` + ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_DINGTALK_REASONING_CHANNEL_ID"` } type SlackConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_SLACK_ENABLED"` - BotToken string 
`json:"bot_token" env:"PICOCLAW_CHANNELS_SLACK_BOT_TOKEN"` - AppToken string `json:"app_token" env:"PICOCLAW_CHANNELS_SLACK_APP_TOKEN"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_SLACK_ALLOW_FROM"` - GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` - Typing TypingConfig `json:"typing,omitempty"` - Placeholder PlaceholderConfig `json:"placeholder,omitempty"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_SLACK_ENABLED"` + BotToken string `json:"bot_token" env:"PICOCLAW_CHANNELS_SLACK_BOT_TOKEN"` + AppToken string `json:"app_token" env:"PICOCLAW_CHANNELS_SLACK_APP_TOKEN"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_SLACK_ALLOW_FROM"` + GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` + Typing TypingConfig `json:"typing,omitempty"` + Placeholder PlaceholderConfig `json:"placeholder,omitempty"` + ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_SLACK_REASONING_CHANNEL_ID"` } type LINEConfig struct { @@ -292,6 +300,7 @@ type LINEConfig struct { GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` Typing TypingConfig `json:"typing,omitempty"` Placeholder PlaceholderConfig `json:"placeholder,omitempty"` + ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_LINE_REASONING_CHANNEL_ID"` } type OneBotConfig struct { @@ -304,34 +313,37 @@ type OneBotConfig struct { GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` Typing TypingConfig `json:"typing,omitempty"` Placeholder PlaceholderConfig `json:"placeholder,omitempty"` + ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_ONEBOT_REASONING_CHANNEL_ID"` } type WeComConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_ENABLED"` - Token string `json:"token" env:"PICOCLAW_CHANNELS_WECOM_TOKEN"` - EncodingAESKey string `json:"encoding_aes_key" env:"PICOCLAW_CHANNELS_WECOM_ENCODING_AES_KEY"` - WebhookURL 
string `json:"webhook_url" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_URL"` - WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_HOST"` - WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_PORT"` - WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_PATH"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_ALLOW_FROM"` - ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_REPLY_TIMEOUT"` - GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_ENABLED"` + Token string `json:"token" env:"PICOCLAW_CHANNELS_WECOM_TOKEN"` + EncodingAESKey string `json:"encoding_aes_key" env:"PICOCLAW_CHANNELS_WECOM_ENCODING_AES_KEY"` + WebhookURL string `json:"webhook_url" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_URL"` + WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_HOST"` + WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_PORT"` + WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_WEBHOOK_PATH"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_ALLOW_FROM"` + ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_REPLY_TIMEOUT"` + GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` + ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WECOM_REASONING_CHANNEL_ID"` } type WeComAppConfig struct { - Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_APP_ENABLED"` - CorpID string `json:"corp_id" env:"PICOCLAW_CHANNELS_WECOM_APP_CORP_ID"` - CorpSecret string `json:"corp_secret" env:"PICOCLAW_CHANNELS_WECOM_APP_CORP_SECRET"` - AgentID int64 `json:"agent_id" env:"PICOCLAW_CHANNELS_WECOM_APP_AGENT_ID"` - Token string `json:"token" env:"PICOCLAW_CHANNELS_WECOM_APP_TOKEN"` - EncodingAESKey string `json:"encoding_aes_key" 
env:"PICOCLAW_CHANNELS_WECOM_APP_ENCODING_AES_KEY"` - WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_HOST"` - WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_PORT"` - WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_PATH"` - AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_APP_ALLOW_FROM"` - ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_APP_REPLY_TIMEOUT"` - GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` + Enabled bool `json:"enabled" env:"PICOCLAW_CHANNELS_WECOM_APP_ENABLED"` + CorpID string `json:"corp_id" env:"PICOCLAW_CHANNELS_WECOM_APP_CORP_ID"` + CorpSecret string `json:"corp_secret" env:"PICOCLAW_CHANNELS_WECOM_APP_CORP_SECRET"` + AgentID int64 `json:"agent_id" env:"PICOCLAW_CHANNELS_WECOM_APP_AGENT_ID"` + Token string `json:"token" env:"PICOCLAW_CHANNELS_WECOM_APP_TOKEN"` + EncodingAESKey string `json:"encoding_aes_key" env:"PICOCLAW_CHANNELS_WECOM_APP_ENCODING_AES_KEY"` + WebhookHost string `json:"webhook_host" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_HOST"` + WebhookPort int `json:"webhook_port" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_PORT"` + WebhookPath string `json:"webhook_path" env:"PICOCLAW_CHANNELS_WECOM_APP_WEBHOOK_PATH"` + AllowFrom FlexibleStringSlice `json:"allow_from" env:"PICOCLAW_CHANNELS_WECOM_APP_ALLOW_FROM"` + ReplyTimeout int `json:"reply_timeout" env:"PICOCLAW_CHANNELS_WECOM_APP_REPLY_TIMEOUT"` + GroupTrigger GroupTriggerConfig `json:"group_trigger,omitempty"` + ReasoningChannelID string `json:"reasoning_channel_id" env:"PICOCLAW_CHANNELS_WECOM_APP_REASONING_CHANNEL_ID"` } type PicoConfig struct { diff --git a/pkg/providers/openai_compat/provider.go b/pkg/providers/openai_compat/provider.go index 236a048c4..e8bd21ebd 100644 --- a/pkg/providers/openai_compat/provider.go +++ b/pkg/providers/openai_compat/provider.go @@ -25,6 +25,7 @@ type ( ToolFunctionDefinition = 
protocoltypes.ToolFunctionDefinition ExtraContent = protocoltypes.ExtraContent GoogleExtra = protocoltypes.GoogleExtra + ReasoningDetail = protocoltypes.ReasoningDetail ) type Provider struct { @@ -148,8 +149,10 @@ func parseResponse(body []byte) (*LLMResponse, error) { var apiResponse struct { Choices []struct { Message struct { - Content string `json:"content"` - ToolCalls []struct { + Content string `json:"content"` + Reasoning string `json:"reasoning"` + ReasoningDetails []ReasoningDetail `json:"reasoning_details"` + ToolCalls []struct { ID string `json:"id"` Type string `json:"type"` Function *struct { @@ -221,10 +224,12 @@ func parseResponse(body []byte) (*LLMResponse, error) { } return &LLMResponse{ - Content: choice.Message.Content, - ToolCalls: toolCalls, - FinishReason: choice.FinishReason, - Usage: apiResponse.Usage, + Reasoning: choice.Message.Reasoning, + ReasoningDetails: choice.Message.ReasoningDetails, + Content: choice.Message.Content, + ToolCalls: toolCalls, + FinishReason: choice.FinishReason, + Usage: apiResponse.Usage, }, nil } diff --git a/pkg/providers/protocoltypes/types.go b/pkg/providers/protocoltypes/types.go index 5e1c6d397..3e18fca43 100644 --- a/pkg/providers/protocoltypes/types.go +++ b/pkg/providers/protocoltypes/types.go @@ -25,12 +25,19 @@ type FunctionCall struct { } type LLMResponse struct { - Content string `json:"content"` - ToolCalls []ToolCall `json:"tool_calls,omitempty"` - FinishReason string `json:"finish_reason"` - Usage *UsageInfo `json:"usage,omitempty"` + Content string `json:"content"` + ToolCalls []ToolCall `json:"tool_calls,omitempty"` + FinishReason string `json:"finish_reason"` + Usage *UsageInfo `json:"usage,omitempty"` + Reasoning string `json:"reasoning"` + ReasoningDetails []ReasoningDetail `json:"reasoning_details"` +} +type ReasoningDetail struct { + Format string `json:"format"` + Index int `json:"index"` + Type string `json:"type"` + Text string `json:"text"` } - type UsageInfo struct { PromptTokens int 
`json:"prompt_tokens"` CompletionTokens int `json:"completion_tokens"` From 5ae3bbceab9c6de161e0e48fa509e567ede0cb3f Mon Sep 17 00:00:00 2001 From: Avisek Date: Thu, 26 Feb 2026 11:33:01 +0530 Subject: [PATCH 28/28] feat: Add `reasoning_channel_id` to communication platform configurations and improve message bus context cancellation handling. --- config/config.example.json | 38 +++++++++++++++++++++++++------------- pkg/bus/bus.go | 9 +++++++++ 2 files changed, 34 insertions(+), 13 deletions(-) diff --git a/config/config.example.json b/config/config.example.json index 555509732..149247526 100644 --- a/config/config.example.json +++ b/config/config.example.json @@ -52,30 +52,35 @@ "proxy": "", "allow_from": [ "YOUR_USER_ID" - ] + ], + "reasoning_channel_id": "" }, "discord": { "enabled": false, "token": "YOUR_DISCORD_BOT_TOKEN", "allow_from": [], - "mention_only": false + "mention_only": false, + "reasoning_channel_id": "" }, "qq": { "enabled": false, "app_id": "YOUR_QQ_APP_ID", "app_secret": "YOUR_QQ_APP_SECRET", - "allow_from": [] + "allow_from": [], + "reasoning_channel_id": "" }, "maixcam": { "enabled": false, "host": "0.0.0.0", "port": 18790, - "allow_from": [] + "allow_from": [], + "reasoning_channel_id": "" }, "whatsapp": { "enabled": false, "bridge_url": "ws://localhost:3001", - "allow_from": [] + "allow_from": [], + "reasoning_channel_id": "" }, "feishu": { "enabled": false, @@ -83,19 +88,22 @@ "app_secret": "", "encrypt_key": "", "verification_token": "", - "allow_from": [] + "allow_from": [], + "reasoning_channel_id": "" }, "dingtalk": { "enabled": false, "client_id": "YOUR_CLIENT_ID", "client_secret": "YOUR_CLIENT_SECRET", - "allow_from": [] + "allow_from": [], + "reasoning_channel_id": "" }, "slack": { "enabled": false, "bot_token": "xoxb-YOUR-BOT-TOKEN", "app_token": "xapp-YOUR-APP-TOKEN", - "allow_from": [] + "allow_from": [], + "reasoning_channel_id": "" }, "line": { "enabled": false, @@ -104,7 +112,8 @@ "webhook_host": "0.0.0.0", "webhook_port": 
18791, "webhook_path": "/webhook/line", - "allow_from": [] + "allow_from": [], + "reasoning_channel_id": "" }, "onebot": { "enabled": false, @@ -112,7 +121,8 @@ "access_token": "", "reconnect_interval": 5, "group_trigger_prefix": [], - "allow_from": [] + "allow_from": [], + "reasoning_channel_id": "" }, "wecom": { "_comment": "WeCom Bot (智能机器人) - Easier setup, supports group chats", @@ -124,7 +134,8 @@ "webhook_port": 18793, "webhook_path": "/webhook/wecom", "allow_from": [], - "reply_timeout": 5 + "reply_timeout": 5, + "reasoning_channel_id": "" }, "wecom_app": { "_comment": "WeCom App (自建应用) - More features, proactive messaging, private chat only. See docs/wecom-app-configuration.md", @@ -138,7 +149,8 @@ "webhook_port": 18792, "webhook_path": "/webhook/wecom-app", "allow_from": [], - "reply_timeout": 5 + "reply_timeout": 5, + "reasoning_channel_id": "" } }, "providers": { @@ -250,4 +262,4 @@ "host": "127.0.0.1", "port": 18790 } -} +} \ No newline at end of file diff --git a/pkg/bus/bus.go b/pkg/bus/bus.go index d2b6838c5..c749b6535 100644 --- a/pkg/bus/bus.go +++ b/pkg/bus/bus.go @@ -32,6 +32,9 @@ func (mb *MessageBus) PublishInbound(ctx context.Context, msg InboundMessage) er if mb.closed.Load() { return ErrBusClosed } + if err := ctx.Err(); err != nil { + return err + } select { case mb.inbound <- msg: return nil @@ -57,6 +60,9 @@ func (mb *MessageBus) PublishOutbound(ctx context.Context, msg OutboundMessage) if mb.closed.Load() { return ErrBusClosed } + if err := ctx.Err(); err != nil { + return err + } select { case mb.outbound <- msg: return nil @@ -82,6 +88,9 @@ func (mb *MessageBus) PublishOutboundMedia(ctx context.Context, msg OutboundMedi if mb.closed.Load() { return ErrBusClosed } + if err := ctx.Err(); err != nil { + return err + } select { case mb.outboundMedia <- msg: return nil