From 74a26981695de17809f473c7bcb2828aa7561554 Mon Sep 17 00:00:00 2001 From: YS Liu Date: Wed, 25 Feb 2026 10:35:46 +0800 Subject: [PATCH] gateway: add --help, default WebSocket on, --no-websocket, remove -e - Add --help/-h: print gatewayHelp() and exit before loading config - Default enableWebSocket=true (WebSocket gateway + /health, /ready) - Add --no-websocket for health-only mode - Remove -e; keep --enable-websocket as no-op for compatibility - Add gatewayHelp() with usage, options, examples Co-authored-by: Cursor --- cmd/picoclaw/cmd_gateway.go | 59 ++++- config/config.example.json | 4 +- pkg/agent/loop.go | 27 +- pkg/bus/types.go | 4 + pkg/config/config.go | 13 +- pkg/config/defaults.go | 6 +- pkg/gateway/message.go | 81 ++++++ pkg/gateway/server.go | 501 ++++++++++++++++++++++++++++++++++++ pkg/gateway/server_test.go | 395 ++++++++++++++++++++++++++++ pkg/gateway/types.go | 81 ++++++ pkg/session/manager.go | 52 ++++ 11 files changed, 1200 insertions(+), 23 deletions(-) create mode 100644 pkg/gateway/message.go create mode 100644 pkg/gateway/server.go create mode 100644 pkg/gateway/server_test.go create mode 100644 pkg/gateway/types.go diff --git a/cmd/picoclaw/cmd_gateway.go b/cmd/picoclaw/cmd_gateway.go index cf7f3563a..03c9a6a6b 100644 --- a/cmd/picoclaw/cmd_gateway.go +++ b/cmd/picoclaw/cmd_gateway.go @@ -19,6 +19,7 @@ import ( "github.com/sipeed/picoclaw/pkg/config" "github.com/sipeed/picoclaw/pkg/cron" "github.com/sipeed/picoclaw/pkg/devices" + "github.com/sipeed/picoclaw/pkg/gateway" "github.com/sipeed/picoclaw/pkg/health" "github.com/sipeed/picoclaw/pkg/heartbeat" "github.com/sipeed/picoclaw/pkg/logger" @@ -29,14 +30,25 @@ import ( ) func gatewayCmd() { - // Check for --debug flag args := os.Args[2:] + for _, arg := range args { + if arg == "--help" || arg == "-h" { + gatewayHelp() + os.Exit(0) + } + } + + // Parse flags: default WebSocket on; --no-websocket for health-only + enableWebSocket := true for _, arg := range args { if arg == "--debug" || arg == "-d" { logger.SetLevel(logger.DEBUG) fmt.Println("🔍 Debug mode enabled") - break } + if arg == "--no-websocket" { + enableWebSocket = false + } + // --enable-websocket is a no-op (WebSocket is on by default) } cfg, err := loadConfig() @@ -196,13 +208,24 @@ func gatewayCmd() { fmt.Printf("Error starting channels: %v\n", err) } - healthServer := health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port) - go func() { - if err := healthServer.Start(); err != nil && err != http.ErrServerClosed { - logger.ErrorCF("health", "Health server error", map[string]any{"error": err.Error()}) - } - }() - fmt.Printf("✓ Health endpoints available at http://%s:%d/health and /ready\n", cfg.Gateway.Host, cfg.Gateway.Port) + var healthServer *health.Server + if enableWebSocket { + gw := gateway.NewServer(&cfg.Gateway, agentLoop.GetRegistry(), msgBus) + go func() { + if err := gw.Start(ctx); err != nil && err != http.ErrServerClosed { + logger.ErrorCF("gateway", "Gateway server error", map[string]any{"error": err.Error()}) + } + }() + fmt.Printf("✓ WebSocket Gateway and health at http://%s:%d (/, /health, /ready)\n", cfg.Gateway.Host, cfg.Gateway.Port) + } else { + healthServer = health.NewServer(cfg.Gateway.Host, cfg.Gateway.Port) + go func() { + if err := healthServer.Start(); err != nil && err != http.ErrServerClosed { + logger.ErrorCF("health", "Health server error", map[string]any{"error": err.Error()}) + } + }() + fmt.Printf("✓ Health endpoints available at http://%s:%d/health and /ready\n", cfg.Gateway.Host, cfg.Gateway.Port) + } go agentLoop.Run(ctx) @@ -212,7 +235,9 @@ func gatewayCmd() { fmt.Println("\nShutting down...") cancel() - healthServer.Stop(context.Background()) + if healthServer != nil { + healthServer.Stop(context.Background()) + } deviceService.Stop() heartbeatService.Stop() cronService.Stop() @@ -221,6 +246,20 @@ func gatewayCmd() { fmt.Println("✓ Gateway stopped") } +func gatewayHelp() { + fmt.Println("\nStart the PicoClaw gateway (channels, agent, health). WebSocket gateway is enabled by default.") + fmt.Println() + fmt.Println("Usage: picoclaw gateway [options]") + fmt.Println() + fmt.Println("Options:") + fmt.Println(" -d, --debug Enable debug logging") + fmt.Println(" --no-websocket Only serve /health and /ready (no WebSocket gateway)") + fmt.Println() + fmt.Println("Examples:") + fmt.Println(" picoclaw gateway Start with WebSocket gateway and health endpoints") + fmt.Println(" picoclaw gateway --no-websocket Health endpoints only") +} + func setupCronTool( agentLoop *agent.AgentLoop, msgBus *bus.MessageBus, diff --git a/config/config.example.json b/config/config.example.json index e8c6b3d3f..d85ff9b7f 100644 --- a/config/config.example.json +++ b/config/config.example.json @@ -248,6 +248,8 @@ }, "gateway": { "host": "127.0.0.1", - "port": 18790 + "port": 18790, + "token": "", + "password": "" } } diff --git a/pkg/agent/loop.go b/pkg/agent/loop.go index 9a2bb1198..ed510c485 100644 --- a/pkg/agent/loop.go +++ b/pkg/agent/loop.go @@ -172,10 +172,9 @@ func (al *AgentLoop) Run(ctx context.Context) error { response = fmt.Sprintf("Error processing message: %v", err) } - if response != "" { + if response != "" && msg.Channel != "web" { + // Web channel already got final response in processMessage (state "final"); only publish for other channels. // Check if the message tool already sent a response during this round. - // If so, skip publishing to avoid duplicate messages to the user. - // Use default agent's tools to check (message tool is shared). alreadySent := false defaultAgent := al.registry.GetDefaultAgent() if defaultAgent != nil { @@ -185,7 +184,6 @@ func (al *AgentLoop) Run(ctx context.Context) error { } } } - if !alreadySent { al.bus.PublishOutbound(bus.OutboundMessage{ Channel: msg.Channel, @@ -200,6 +198,11 @@ func (al *AgentLoop) Run(ctx context.Context) error { return nil } +// GetRegistry returns the agent registry for gateway/session resolution. +func (al *AgentLoop) GetRegistry() *AgentRegistry { + return al.registry +} + func (al *AgentLoop) Stop() { al.running.Store(false) } @@ -323,6 +326,7 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) "matched_by": route.MatchedBy, }) + sendResponse := msg.Channel == "web" return al.runAgentLoop(ctx, agent, processOptions{ SessionKey: sessionKey, Channel: msg.Channel, @@ -330,7 +334,7 @@ func (al *AgentLoop) processMessage(ctx context.Context, msg bus.InboundMessage) UserMessage: msg.Content, DefaultResponse: "I've completed processing but have no response to give.", EnableSummary: true, - SendResponse: false, + SendResponse: sendResponse, }) } @@ -448,13 +452,20 @@ func (al *AgentLoop) runAgentLoop(ctx context.Context, agent *AgentInstance, opt al.maybeSummarize(agent, opts.SessionKey, opts.Channel, opts.ChatID) } - // 8. Optional: send response via bus + // 8. Optional: send response via bus (e.g. web gateway for WebClaw) if opts.SendResponse { - al.bus.PublishOutbound(bus.OutboundMessage{ + out := bus.OutboundMessage{ Channel: opts.Channel, ChatID: opts.ChatID, Content: finalContent, - }) + State: "final", + } + if opts.Channel == "web" { + if idx := strings.LastIndex(opts.ChatID, "|"); idx >= 0 && idx < len(opts.ChatID)-1 { + out.RunID = opts.ChatID[idx+1:] + } + } + al.bus.PublishOutbound(out) } // 9. Log response diff --git a/pkg/bus/types.go b/pkg/bus/types.go index 44f9181a5..17cfdec3c 100644 --- a/pkg/bus/types.go +++ b/pkg/bus/types.go @@ -10,10 +10,14 @@ type InboundMessage struct { Metadata map[string]string `json:"metadata,omitempty"` } +// OutboundMessage is sent from the agent loop to channels (e.g. web gateway). +// State and RunID support streaming: "streaming" for chunks, "final" for complete reply. type OutboundMessage struct { Channel string `json:"channel"` ChatID string `json:"chat_id"` Content string `json:"content"` + State string `json:"state,omitempty"` // e.g. "streaming", "final" + RunID string `json:"run_id,omitempty"` // idempotency/run identifier for the turn } type MessageHandler func(InboundMessage) error diff --git a/pkg/config/config.go b/pkg/config/config.go index 85135c820..954d9901b 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -419,9 +419,18 @@ func (c *ModelConfig) Validate() error { return nil } +// WebSessionAgentBinding maps a session key prefix to an agent ID for web channel. +type WebSessionAgentBinding struct { + SessionKeyPrefix string `json:"session_key_prefix"` + AgentID string `json:"agent_id"` +} + type GatewayConfig struct { - Host string `json:"host" env:"PICOCLAW_GATEWAY_HOST"` - Port int `json:"port" env:"PICOCLAW_GATEWAY_PORT"` + Host string `json:"host" env:"PICOCLAW_GATEWAY_HOST"` + Port int `json:"port" env:"PICOCLAW_GATEWAY_PORT"` + Token string `json:"token" env:"PICOCLAW_GATEWAY_TOKEN"` + Password string `json:"password" env:"PICOCLAW_GATEWAY_PASSWORD"` + WebSessionAgentBindings []WebSessionAgentBinding `json:"web_session_agent_bindings,omitempty"` } type BraveConfig struct { diff --git a/pkg/config/defaults.go b/pkg/config/defaults.go index b96ee4d89..46a7f1880 100644 --- a/pkg/config/defaults.go +++ b/pkg/config/defaults.go @@ -272,8 +272,10 @@ func DefaultConfig() *Config { }, }, Gateway: GatewayConfig{ - Host: "127.0.0.1", - Port: 18790, + Host: "127.0.0.1", + Port: 18790, + Token: "", + Password: "", }, Tools: ToolsConfig{ Web: WebToolsConfig{ diff --git a/pkg/gateway/message.go b/pkg/gateway/message.go new file mode 100644 index 000000000..545c44c42 --- /dev/null +++ b/pkg/gateway/message.go @@ -0,0 +1,81 @@ +package gateway + +import ( + "crypto/sha1" + "fmt" + "time" + + "github.com/sipeed/picoclaw/pkg/providers" +) + +// messageToGateway converts a providers.Message to a WebClaw GatewayMessage (as map for JSON). +// Order of messages is preserved by the caller (session history order). +func messageToGateway(m providers.Message, index int, baseTime int64) map[string]any { + out := map[string]any{ + "role": m.Role, + } + if baseTime > 0 { + out["createdAt"] = baseTime + int64(index)*1000 + } + // Generate a stable id for the message. + id := msgID(m, index) + out["id"] = id + + switch m.Role { + case "user": + out["content"] = []map[string]any{ + {"type": "text", "text": m.Content}, + } + case "assistant": + var content []map[string]any + if m.Content != "" { + content = append(content, map[string]any{"type": "text", "text": m.Content}) + } + for _, tc := range m.ToolCalls { + content = append(content, map[string]any{ + "type": "toolCall", + "id": tc.ID, + "name": tc.Name, + "arguments": tc.Arguments, + }) + } + if len(content) == 0 { + content = []map[string]any{{"type": "text", "text": ""}} + } + out["content"] = content + case "tool": + // WebClaw expects role "toolResult" for tool results. + out["role"] = "toolResult" + out["toolCallId"] = m.ToolCallID + out["content"] = []map[string]any{ + {"type": "text", "text": m.Content}, + } + default: + out["content"] = []map[string]any{ + {"type": "text", "text": m.Content}, + } + } + return out +} + +func msgID(m providers.Message, index int) string { + h := sha1.New() + h.Write([]byte(m.Role + m.Content + m.ToolCallID)) + for _, tc := range m.ToolCalls { + h.Write([]byte(tc.ID + tc.Name)) + } + return fmt.Sprintf("msg-%x-%d", h.Sum(nil)[:8], index) +} + +// historyToGatewayMessages converts session history to WebClaw messages array. +// baseTime is used for createdAt when not stored (e.g. 2026-01-01 00:00:00 UTC ms). +func historyToGatewayMessages(history []providers.Message, baseTime int64) []map[string]any { + if baseTime == 0 { + baseTime = time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC).UnixMilli() + } + out := make([]map[string]any, 0, len(history)) + for i := range history { + out = append(out, messageToGateway(history[i], i, baseTime)) + } + return out +} diff --git a/pkg/gateway/server.go b/pkg/gateway/server.go new file mode 100644 index 000000000..19ec285e2 --- /dev/null +++ b/pkg/gateway/server.go @@ -0,0 +1,501 @@ +package gateway + +import ( + "context" + "encoding/json" + "fmt" + "net" + "net/http" + "sort" + "strings" + "sync" + "time" + + "github.com/gorilla/websocket" + + "github.com/sipeed/picoclaw/pkg/agent" + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/logger" + "github.com/sipeed/picoclaw/pkg/routing" +) + +var upgrader = websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { return true }, +} + +// Server is the WebSocket Gateway server (WebClaw/OpenClaw protocol). +type Server struct { + cfg *config.GatewayConfig + registry *agent.AgentRegistry + bus *bus.MessageBus + server *http.Server + subs map[string]map[*websocket.Conn]struct{} + subsMu sync.RWMutex + connWriteMu sync.Map // *websocket.Conn -> *sync.Mutex, serializes writes per conn + seq int + seqMu sync.Mutex +} + +// NewServer creates a new Gateway server. It serves /health, /ready, and / (WebSocket). +func NewServer(cfg *config.GatewayConfig, registry *agent.AgentRegistry, msgBus *bus.MessageBus) *Server { + s := &Server{ + cfg: cfg, + registry: registry, + bus: msgBus, + subs: make(map[string]map[*websocket.Conn]struct{}), + } + mux := http.NewServeMux() + mux.HandleFunc("/health", s.serveHealth) + mux.HandleFunc("/ready", s.serveReady) + mux.HandleFunc("/", s.serveWebSocket) + s.server = &http.Server{ + Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port), + Handler: mux, + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + } + return s +} + +func (s *Server) serveHealth(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"status":"ok"}`)) +} + +func (s *Server) serveReady(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"status":"ready"}`)) +} + +func (s *Server) serveWebSocket(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + logger.ErrorCF("gateway", "WebSocket upgrade failed", map[string]any{"error": err.Error()}) + return + } + defer conn.Close() + + defer func() { + s.subsMu.Lock() + for _, m := range s.subs { + delete(m, conn) + } + s.subsMu.Unlock() + s.connWriteMu.Delete(conn) + }() + + for { + _, data, err := conn.ReadMessage() + if err != nil { + break + } + var frame GatewayFrame + if err := json.Unmarshal(data, &frame); err != nil { + continue + } + if frame.Type != "req" { + continue + } + s.handleRequest(conn, &frame) + } +} + +func (s *Server) handleRequest(conn *websocket.Conn, frame *GatewayFrame) { + var payload interface{} + var errMsg string + var errCode string + + switch frame.Method { + case "connect": + payload, errCode, errMsg = s.handleConnect(frame.Params) + case "sessions.list": + payload, errCode, errMsg = s.handleSessionsList(frame.Params) + case "sessions.patch": + payload, errCode, errMsg = s.handleSessionsPatch(frame.Params) + case "sessions.resolve": + payload, errCode, errMsg = s.handleSessionsResolve(frame.Params) + case "sessions.delete": + payload, errCode, errMsg = s.handleSessionsDelete(frame.Params) + case "chat.send": + payload, errCode, errMsg = s.handleChatSend(frame.Params) + case "chat.history": + payload, errCode, errMsg = s.handleChatHistory(frame.Params) + case "chat.subscribe": + payload, errCode, errMsg = s.handleChatSubscribe(conn, frame.Params) + default: + errCode = "METHOD_NOT_FOUND" + errMsg = "Method not implemented: " + frame.Method + } + + if errCode != "" { + s.sendError(conn, frame.ID, errCode, errMsg) + return + } + s.sendResponse(conn, frame.ID, payload) +} + +func (s *Server) handleConnect(params json.RawMessage) (interface{}, string, string) { + var p ConnectParams + if err := json.Unmarshal(params, &p); err != nil { + return nil, "BAD_REQUEST", "Invalid connect params" + } + token := strings.TrimSpace(p.Auth.Token) + password := strings.TrimSpace(p.Auth.Password) + cfgToken := strings.TrimSpace(s.cfg.Token) + cfgPassword := strings.TrimSpace(s.cfg.Password) + if cfgToken != "" && token != cfgToken { + return nil, "UNAUTHORIZED", "Invalid gateway token" + } + if cfgPassword != "" && password != cfgPassword { + return nil, "UNAUTHORIZED", "Invalid gateway password" + } + if cfgToken == "" && cfgPassword == "" && (token == "" && password == "") { + return nil, "UNAUTHORIZED", "Missing gateway auth" + } + return map[string]any{"protocol": 3, "server": "picoclaw"}, "", "" +} + +// resolveSessionKey returns internal key and agentID (4.3.1). +func (s *Server) resolveSessionKey(key string) (internalKey string, agentID string) { + key = strings.TrimSpace(key) + if key == "" { + return "agent:main:main", "main" + } + if parsed := routing.ParseAgentSessionKey(key); parsed != nil { + return key, parsed.AgentID + } + // Longest prefix match on WebSessionAgentBindings + bindings := make([]config.WebSessionAgentBinding, len(s.cfg.WebSessionAgentBindings)) + copy(bindings, s.cfg.WebSessionAgentBindings) + sort.Slice(bindings, func(i, j int) bool { + return len(bindings[i].SessionKeyPrefix) > len(bindings[j].SessionKeyPrefix) + }) + for _, b := range bindings { + prefix := strings.TrimSpace(b.SessionKeyPrefix) + if prefix != "" && strings.HasPrefix(key, prefix) { + aid := strings.TrimSpace(b.AgentID) + if aid == "" { + aid = "main" + } + return "agent:" + aid + ":" + key, routing.NormalizeAgentID(aid) + } + } + return "agent:main:" + key, "main" +} + +func (s *Server) getAgent(agentID string) *agent.AgentInstance { + ag, _ := s.registry.GetAgent(agentID) + if ag != nil { + return ag + } + return s.registry.GetDefaultAgent() +} + +func (s *Server) handleSessionsList(params json.RawMessage) (interface{}, string, string) { + var p SessionsListParams + _ = json.Unmarshal(params, &p) + if p.Limit <= 0 { + p.Limit = 50 + } + // List from default agent only for now; multi-agent can merge later. + ag := s.registry.GetDefaultAgent() + if ag == nil { + return map[string]any{"sessions": []any{}}, "", "" + } + meta := ag.Sessions.ListSessions() + sessions := make([]map[string]any, 0, len(meta)) + for _, m := range meta { + // Display key: strip "agent:main:" prefix for friendlyId + displayKey := m.Key + if strings.HasPrefix(displayKey, "agent:main:") { + displayKey = strings.TrimPrefix(displayKey, "agent:main:") + } + // Internal heartbeat session is not exposed to WebClaw + if displayKey == "heartbeat" { + continue + } + ent := map[string]any{ + "key": displayKey, + "friendlyId": displayKey, + "updatedAt": m.UpdatedAt.UnixMilli(), + "label": m.Label, + } + if p.IncludeLastMessage { + hist := ag.Sessions.GetHistory(m.Key) + if len(hist) > 0 { + last := hist[len(hist)-1] + ent["lastMessage"] = messageToGateway(last, len(hist)-1, 0) + } + } + if p.IncludeDerivedTitles { + ent["derivedTitle"] = m.Label + } + sessions = append(sessions, ent) + } + return map[string]any{"sessions": sessions}, "", "" +} + +func (s *Server) handleSessionsPatch(params json.RawMessage) (interface{}, string, string) { + var p SessionsPatchParams + if err := json.Unmarshal(params, &p); err != nil || strings.TrimSpace(p.Key) == "" { + return nil, "BAD_REQUEST", "key required" + } + internalKey, agentID := s.resolveSessionKey(p.Key) + ag := s.getAgent(agentID) + if ag == nil { + return nil, "INTERNAL", "no agent" + } + ag.Sessions.GetOrCreate(internalKey) + if p.Label != "" { + ag.Sessions.SetLabel(internalKey, strings.TrimSpace(p.Label)) + } + if err := ag.Sessions.Save(internalKey); err != nil { + return nil, "INTERNAL", err.Error() + } + displayKey := internalKey + if strings.HasPrefix(internalKey, "agent:") { + if idx := strings.Index(internalKey[6:], ":"); idx >= 0 { + displayKey = internalKey[6+idx+1:] + } + } + return map[string]any{"ok": true, "key": displayKey, "entry": map[string]any{"key": displayKey}}, "", "" +} + +func (s *Server) handleSessionsResolve(params json.RawMessage) (interface{}, string, string) { + var p SessionsResolveParams + if err := json.Unmarshal(params, &p); err != nil { + return nil, "BAD_REQUEST", "invalid params" + } + internalKey, _ := s.resolveSessionKey(strings.TrimSpace(p.Key)) + // Return key in display form for WebClaw (friendlyId); chat.send/history accept either and resolve again. + displayKey := internalKey + if strings.HasPrefix(internalKey, "agent:") { + if idx := strings.Index(internalKey[6:], ":"); idx >= 0 { + displayKey = internalKey[6+idx+1:] + } + } + return map[string]any{"ok": true, "key": displayKey}, "", "" +} + +func (s *Server) handleSessionsDelete(params json.RawMessage) (interface{}, string, string) { + var p SessionsDeleteParams + if err := json.Unmarshal(params, &p); err != nil { + return nil, "BAD_REQUEST", "invalid params" + } + internalKey, agentID := s.resolveSessionKey(strings.TrimSpace(p.Key)) + mainKey := routing.BuildAgentMainSessionKey(agentID) + if internalKey == mainKey { + return nil, "INVALID_REQUEST", fmt.Sprintf("Cannot delete the main session (%s).", mainKey) + } + displayKey := internalKey + if strings.HasPrefix(internalKey, "agent:") { + if idx := strings.Index(internalKey[6:], ":"); idx >= 0 { + displayKey = internalKey[6+idx+1:] + } + } + if displayKey == "heartbeat" { + return nil, "INVALID_REQUEST", "Cannot delete the heartbeat session." + } + ag := s.getAgent(agentID) + if ag != nil { + _ = ag.Sessions.Delete(internalKey) + } + return map[string]any{"ok": true}, "", "" +} + +func (s *Server) handleChatSend(params json.RawMessage) (interface{}, string, string) { + var p ChatSendParams + if err := json.Unmarshal(params, &p); err != nil { + return nil, "BAD_REQUEST", "invalid params" + } + sessionKey := strings.TrimSpace(p.SessionKey) + if sessionKey == "" && strings.TrimSpace(p.IdempotencyKey) != "" { + sessionKey = "main" + } + internalKey, agentID := s.resolveSessionKey(sessionKey) + runID := strings.TrimSpace(p.IdempotencyKey) + if runID == "" { + runID = fmt.Sprintf("run-%d", time.Now().UnixNano()) + } + chatID := internalKey + "|" + runID + content := strings.TrimSpace(p.Message) + if content == "" { + return nil, "BAD_REQUEST", "message required" + } + ag := s.getAgent(agentID) + if ag == nil { + return nil, "INTERNAL", "no agent" + } + s.bus.PublishInbound(bus.InboundMessage{ + Channel: "web", + SenderID: "webclaw", + ChatID: chatID, + Content: content, + SessionKey: internalKey, + }) + displayKey := internalKey + if strings.HasPrefix(internalKey, "agent:") { + if idx := strings.Index(internalKey[6:], ":"); idx >= 0 { + displayKey = internalKey[6+idx+1:] + } + } + return map[string]any{"ok": true, "runId": runID, "sessionKey": displayKey}, "", "" +} + +func (s *Server) handleChatHistory(params json.RawMessage) (interface{}, string, string) { + var p ChatHistoryParams + if err := json.Unmarshal(params, &p); err != nil { + return nil, "BAD_REQUEST", "invalid params" + } + internalKey, agentID := s.resolveSessionKey(strings.TrimSpace(p.SessionKey)) + if p.Limit <= 0 { + p.Limit = 200 + } + ag := s.getAgent(agentID) + if ag == nil { + return map[string]any{"sessionKey": p.SessionKey, "messages": []any{}}, "", "" + } + history := ag.Sessions.GetHistory(internalKey) + if len(history) > p.Limit { + history = history[len(history)-p.Limit:] + } + msgs := historyToGatewayMessages(history, 0) + return map[string]any{"sessionKey": p.SessionKey, "messages": msgs}, "", "" +} + +func (s *Server) handleChatSubscribe(conn *websocket.Conn, params json.RawMessage) (interface{}, string, string) { + var p ChatSubscribeParams + _ = json.Unmarshal(params, &p) + key := strings.TrimSpace(p.SessionKey) + if key == "" { + key = strings.TrimSpace(p.FriendlyId) + } + if key == "" { + return map[string]any{"ok": true}, "", "" + } + internalKey, _ := s.resolveSessionKey(key) + s.subsMu.Lock() + if s.subs[internalKey] == nil { + s.subs[internalKey] = make(map[*websocket.Conn]struct{}) + } + s.subs[internalKey][conn] = struct{}{} + s.subsMu.Unlock() + return map[string]any{"ok": true}, "", "" +} + +// writeJSONLocked serializes writes per connection (gorilla/websocket forbids concurrent writes). +func (s *Server) writeJSONLocked(conn *websocket.Conn, v any) error { + muI, _ := s.connWriteMu.LoadOrStore(conn, &sync.Mutex{}) + mu := muI.(*sync.Mutex) + mu.Lock() + defer mu.Unlock() + return conn.WriteJSON(v) +} + +func (s *Server) sendResponse(conn *websocket.Conn, id string, payload interface{}) { + frame := GatewayFrame{Type: "res", ID: id, Ok: true, Payload: payload} + _ = s.writeJSONLocked(conn, frame) +} + +func (s *Server) sendError(conn *websocket.Conn, id, code, message string) { + frame := GatewayFrame{ + Type: "res", + ID: id, + Ok: false, + Error: &GatewayError{Code: code, Message: message}, + } + _ = s.writeJSONLocked(conn, frame) +} + +// Start starts the HTTP server and the outbound consumer goroutine. Blocks until ctx is done or server errors. +func (s *Server) Start(ctx context.Context) error { + go s.consumeOutbound(ctx) + errCh := make(chan error, 1) + go func() { errCh <- s.server.ListenAndServe() }() + select { + case err := <-errCh: + return err + case <-ctx.Done(): + return s.server.Shutdown(context.Background()) + } +} + +// Serve runs the outbound consumer and serves HTTP on the given listener. Used by tests to bind to a random port. +func (s *Server) Serve(ctx context.Context, listener net.Listener) error { + go s.consumeOutbound(ctx) + errCh := make(chan error, 1) + go func() { errCh <- s.server.Serve(listener) }() + select { + case err := <-errCh: + return err + case <-ctx.Done(): + return s.server.Shutdown(context.Background()) + } +} + +// consumeOutbound reads from the bus and pushes events to subscribed WebSocket connections. +func (s *Server) consumeOutbound(ctx context.Context) { + for { + msg, ok := s.bus.SubscribeOutbound(ctx) + if !ok { + return + } + if msg.Channel != "web" { + continue + } + // ChatID format: sessionKey|runId + parts := strings.SplitN(msg.ChatID, "|", 2) + sessionKey := msg.ChatID + runID := "" + if len(parts) == 2 { + sessionKey = parts[0] + runID = parts[1] + } + state := strings.TrimSpace(msg.State) + if state == "" { + state = "final" + } + payload := map[string]any{ + "runId": runID, + "sessionKey": sessionKey, + "state": state, + "message": map[string]any{ + "role": "assistant", + "content": []map[string]any{ + {"type": "text", "text": msg.Content}, + }, + }, + } + s.seqMu.Lock() + s.seq++ + seq := s.seq + s.seqMu.Unlock() + eventFrame := GatewayFrame{ + Type: "event", + Event: "chat", + Payload: payload, + Seq: seq, + } + s.subsMu.RLock() + conns := s.subs[sessionKey] + var snapshot []*websocket.Conn + if len(conns) > 0 { + snapshot = make([]*websocket.Conn, 0, len(conns)) + for c := range conns { + snapshot = append(snapshot, c) + } + } + s.subsMu.RUnlock() + for _, c := range snapshot { + conn := c + go func() { + _ = s.writeJSONLocked(conn, eventFrame) + }() + } + } +} \ No newline at end of file diff --git a/pkg/gateway/server_test.go b/pkg/gateway/server_test.go new file mode 100644 index 000000000..7dbef7d7c --- /dev/null +++ b/pkg/gateway/server_test.go @@ -0,0 +1,395 @@ +package gateway + +import ( + "context" + "encoding/json" + "net" + "net/http" + "net/url" + "strings" + "testing" + "time" + + "github.com/gorilla/websocket" + "github.com/sipeed/picoclaw/pkg/agent" + "github.com/sipeed/picoclaw/pkg/bus" + "github.com/sipeed/picoclaw/pkg/config" + "github.com/sipeed/picoclaw/pkg/providers" +) + +// stubLLMProvider implements LLMProvider for gateway tests (no real LLM). +type stubLLMProvider struct{} + +func (stubLLMProvider) Chat( + _ context.Context, + _ []providers.Message, + _ []providers.ToolDefinition, + _ string, + _ map[string]any, +) (*providers.LLMResponse, error) { + return &providers.LLMResponse{Content: "ok", ToolCalls: nil}, nil +} + +func (stubLLMProvider) GetDefaultModel() string { return "test" } + +// TestGatewayVerification runs the full Gateway protocol verification (plan-based). +// See docs/gateway-verification.md for the checklist. +func TestGatewayVerification(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + cfg := config.DefaultConfig() + cfg.Gateway.Host = "127.0.0.1" + cfg.Gateway.Port = 18790 // unused when using Serve(listener) + cfg.Gateway.Password = "test" + cfg.Agents.Defaults.Workspace = t.TempDir() + + msgBus := bus.NewMessageBus() + registry := agent.NewAgentRegistry(cfg, stubLLMProvider{}) + if registry.GetDefaultAgent() == nil { + t.Fatal("registry has no default agent") + } + + srv := NewServer(&cfg.Gateway, registry, msgBus) + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + defer listener.Close() + addr := listener.Addr().String() + + go func() { + _ = srv.Serve(ctx, listener) + }() + + // Fake "agent": consume inbound, write to session (so chat.history returns messages), then publish outbound + go func() { + for { + msg, ok := msgBus.ConsumeInbound(ctx) + if !ok { + return + } + ag := registry.GetDefaultAgent() + if ag != nil && msg.SessionKey != "" { + ag.Sessions.AddMessage(msg.SessionKey, "user", msg.Content) + ag.Sessions.AddMessage(msg.SessionKey, "assistant", "echo:"+msg.Content) + _ = ag.Sessions.Save(msg.SessionKey) + } + msgBus.PublishOutbound(bus.OutboundMessage{ + Channel: "web", + ChatID: msg.ChatID, + Content: "echo:" + msg.Content, + State: "final", + }) + } + }() + + // 1. HTTP /health + resp, err := http.Get("http://" + addr + "/health") + if err != nil { + t.Fatalf("GET /health: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + t.Errorf("GET /health status: %d", resp.StatusCode) + } + + // 2. HTTP /ready + resp2, err := http.Get("http://" + addr + "/ready") + if err != nil { + t.Fatalf("GET /ready: %v", err) + } + resp2.Body.Close() + if resp2.StatusCode != http.StatusOK { + t.Errorf("GET /ready status: %d", resp2.StatusCode) + } + + // 3–11. WebSocket RPC + event + wsURL := "ws://" + addr + "/" + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + t.Fatalf("WebSocket dial: %v", err) + } + defer conn.Close() + + readRes := func(id string) (ok bool, payload interface{}, errMsg string) { + _ = conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + _, data, err := conn.ReadMessage() + if err != nil { + t.Fatalf("read: %v", err) + } + var frame GatewayFrame + if err := json.Unmarshal(data, &frame); err != nil { + t.Fatalf("unmarshal res: %v", err) + } + if frame.Type != "res" || frame.ID != id { + t.Fatalf("unexpected frame: type=%s id=%s", frame.Type, frame.ID) + } + if !frame.Ok && frame.Error != nil { + return false, nil, frame.Error.Message + } + return frame.Ok, frame.Payload, "" + } + + sendReq := func(method string, params interface{}) string { + id := "req-" + method + "-1" + body := GatewayFrame{Type: "req", ID: id, Method: method, Params: mustMarshal(params)} + if err := conn.WriteJSON(body); err != nil { + t.Fatalf("write %s: %v", method, err) + } + return id + } + + // 3. connect + connectID := sendReq("connect", map[string]any{"auth": map[string]string{"password": "test"}}) + ok, payload, errStr := readRes(connectID) + if !ok || errStr != "" { + t.Fatalf("connect: ok=%v err=%s", ok, errStr) + } + pm, _ := payload.(map[string]interface{}) + if pm["protocol"] == nil || pm["server"] == nil { + t.Errorf("connect payload: %v", payload) + } + + // 4. sessions.list + id := sendReq("sessions.list", map[string]any{"limit": 10}) + ok, payload, errStr = readRes(id) + if !ok || errStr != "" { + t.Fatalf("sessions.list: ok=%v err=%s", ok, errStr) + } + if _, ok := payload.(map[string]interface{}); !ok { + t.Errorf("sessions.list payload type: %T", payload) + } + + // 5. sessions.patch (create session) + sessionKey := "verify-session-" + time.Now().Format("20060102150405") + id = sendReq("sessions.patch", map[string]any{"key": sessionKey, "label": "Verify"}) + ok, payload, errStr = readRes(id) + if !ok || errStr != "" { + t.Fatalf("sessions.patch: ok=%v err=%s", ok, errStr) + } + + // 6. sessions.resolve + id = sendReq("sessions.resolve", map[string]any{"key": sessionKey}) + ok, payload, errStr = readRes(id) + if !ok || errStr != "" { + t.Fatalf("sessions.resolve: ok=%v err=%s", ok, errStr) + } + resolved, _ := payload.(map[string]interface{}) + resolvedKey, _ := resolved["key"].(string) + if resolvedKey == "" { + t.Errorf("sessions.resolve key empty: %v", payload) + } + + // 8. chat.subscribe (before send so we receive the event) + id = sendReq("chat.subscribe", map[string]any{"sessionKey": resolvedKey}) + ok, _, errStr = readRes(id) + if !ok || errStr != "" { + t.Fatalf("chat.subscribe: ok=%v err=%s", ok, errStr) + } + + // 9. chat.send + userMsg := "hello-verification" + id = sendReq("chat.send", map[string]any{ + "sessionKey": resolvedKey, + "message": userMsg, + "deliver": true, + }) + ok, payload, errStr = readRes(id) + if !ok || errStr != "" { + t.Fatalf("chat.send: ok=%v err=%s", ok, errStr) + } + + // 10. wait for event (chat, state final) + _ = conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + var eventFrame GatewayFrame + for { + _, data, err := conn.ReadMessage() + if err != nil { + t.Fatalf("read event: %v", err) + } + if err := json.Unmarshal(data, &eventFrame); err != nil { + t.Fatalf("unmarshal event: %v", err) + } + if eventFrame.Type == "event" && eventFrame.Event == "chat" { + break + } + } + pl, _ := eventFrame.Payload.(map[string]interface{}) + state, _ := pl["state"].(string) + if state != "final" { + t.Errorf("event state: want final, got %s", state) + } + msgObj, _ := pl["message"].(map[string]interface{}) + contentArr, _ := msgObj["content"].([]interface{}) + if len(contentArr) == 0 { + t.Error("event message.content empty") + } else { + first, _ := contentArr[0].(map[string]interface{}) + text, _ := first["text"].(string) + if !strings.Contains(text, "echo:"+userMsg) { + t.Errorf("event content: want echo:%s, got %s", userMsg, text) + } + } + + // 11. chat.history + id = sendReq("chat.history", map[string]any{"sessionKey": resolvedKey, "limit": 20}) + ok, payload, errStr = readRes(id) + if !ok || errStr != "" { + t.Fatalf("chat.history: ok=%v err=%s", ok, errStr) + } + hist, _ := payload.(map[string]interface{}) + msgs, _ := hist["messages"].([]interface{}) + if len(msgs) < 2 { + t.Errorf("chat.history messages: want at least 2, got %d", len(msgs)) + } + + // 7. sessions.delete (run after history so list/patch/resolve/history are covered) + id = sendReq("sessions.delete", map[string]any{"key": sessionKey}) + ok, _, errStr = readRes(id) + if !ok || errStr != "" { + t.Fatalf("sessions.delete: ok=%v err=%s", ok, errStr) + } +} + +func mustMarshal(v interface{}) json.RawMessage { + b, err := json.Marshal(v) + if err != nil { + panic(err) + } + return b +} + +// TestGatewayConnectAuth checks connect with token/password and invalid auth. +func TestGatewayConnectAuth(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + cfg := config.DefaultConfig() + cfg.Gateway.Host = "127.0.0.1" + cfg.Gateway.Port = 18790 + cfg.Gateway.Password = "secret" + cfg.Agents.Defaults.Workspace = t.TempDir() + + msgBus := bus.NewMessageBus() + registry := agent.NewAgentRegistry(cfg, stubLLMProvider{}) + srv := NewServer(&cfg.Gateway, registry, msgBus) + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + defer listener.Close() + + go func() { _ = srv.Serve(ctx, listener) }() + go func() { + for { + _, ok := msgBus.ConsumeInbound(ctx) + if !ok { + return + } + } + }() + + u, _ := url.Parse("ws://" + listener.Addr().String() + "/") + conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil) + if err != nil { + t.Fatalf("dial: %v", err) + } + defer conn.Close() + + sendAndRead := func(method string, params map[string]any) (ok bool, errMsg string) { + id := "auth-" + method + "-1" + _ = conn.WriteJSON(GatewayFrame{Type: "req", ID: id, Method: method, Params: mustMarshal(params)}) + conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + _, data, err := conn.ReadMessage() + if err != nil { + t.Fatalf("read: %v", err) + } + var res GatewayFrame + _ = json.Unmarshal(data, &res) + if res.Error != nil { + errMsg = res.Error.Message + } + return res.Ok, errMsg + } + + ok, _ := sendAndRead("connect", map[string]any{"auth": map[string]string{"password": "wrong"}}) + if ok { + t.Error("connect with wrong password should fail") + } + ok, errStr := sendAndRead("connect", map[string]any{"auth": map[string]string{"password": "secret"}}) + if !ok || errStr != "" { + t.Errorf("connect with correct password: ok=%v err=%s", ok, errStr) + } +} + +// TestGatewaySessionsDeleteRejectsMain ensures sessions.delete returns error for main session. +func TestGatewaySessionsDeleteRejectsMain(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + cfg := config.DefaultConfig() + cfg.Gateway.Host = "127.0.0.1" + cfg.Gateway.Port = 0 + cfg.Gateway.Password = "secret" + cfg.Agents.Defaults.Workspace = t.TempDir() + + msgBus := bus.NewMessageBus() + registry := agent.NewAgentRegistry(cfg, stubLLMProvider{}) + srv := NewServer(&cfg.Gateway, registry, msgBus) + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + defer listener.Close() + + go func() { _ = srv.Serve(ctx, listener) }() + go func() { + for { + _, ok := msgBus.ConsumeInbound(ctx) + if !ok { + return + } + } + }() + + u, _ := url.Parse("ws://" + listener.Addr().String() + "/") + conn, _, err := websocket.DefaultDialer.Dial(u.String(), nil) + if err != nil { + t.Fatalf("dial: %v", err) + } + defer conn.Close() + + sendReq := func(method string, params map[string]any) string { + id := "del-main-1" + _ = conn.WriteJSON(GatewayFrame{Type: "req", ID: id, Method: method, Params: mustMarshal(params)}) + return id + } + readRes := func(id string) (ok bool, errMsg string) { + conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + _, data, err := conn.ReadMessage() + if err != nil { + t.Fatalf("read: %v", err) + } + var res GatewayFrame + _ = json.Unmarshal(data, &res) + if res.Error != nil { + errMsg = res.Error.Message + } + return res.Ok, errMsg + } + + // connect + sendReq("connect", map[string]any{"auth": map[string]string{"password": "secret"}}) + ok, _ := readRes("del-main-1") + if !ok { + t.Fatal("connect failed") + } + + // sessions.delete with key "main" must fail + sendReq("sessions.delete", map[string]any{"key": "main"}) + ok, errStr := readRes("del-main-1") + if ok || !strings.Contains(errStr, "Cannot delete the main session") { + t.Errorf("sessions.delete main: want error containing 'Cannot delete the main session', got ok=%v err=%q", ok, errStr) + } +} diff --git a/pkg/gateway/types.go b/pkg/gateway/types.go new file mode 100644 index 000000000..a5796df34 --- /dev/null +++ b/pkg/gateway/types.go @@ -0,0 +1,81 @@ +// Package gateway implements the WebSocket Gateway protocol for WebClaw/OpenClaw compatibility. +package gateway + +import "encoding/json" + +// GatewayFrame is the JSON wire format: req (method+params), res (ok+payload/error), or event. +type GatewayFrame struct { + Type string `json:"type"` + ID string `json:"id,omitempty"` + Method string `json:"method,omitempty"` + Params json.RawMessage `json:"params,omitempty"` + Event string `json:"event,omitempty"` + Payload interface{} `json:"payload,omitempty"` + Seq int `json:"seq,omitempty"` + StateVersion int `json:"stateVersion,omitempty"` + Ok bool `json:"ok,omitempty"` + Error *GatewayError `json:"error,omitempty"` +} + +// GatewayError is the error payload in a res frame. +type GatewayError struct { + Code string `json:"code"` + Message string `json:"message"` + Details interface{} `json:"details,omitempty"` +} + +// ConnectParams is the params for the connect method (WebClaw sends auth.token / auth.password). +type ConnectParams struct { + Auth struct { + Token string `json:"token"` + Password string `json:"password"` + } `json:"auth"` +} + +// SessionsListParams for sessions.list. +type SessionsListParams struct { + Limit int `json:"limit"` + IncludeLastMessage bool `json:"includeLastMessage"` + IncludeDerivedTitles bool `json:"includeDerivedTitles"` +} + +// SessionsPatchParams for sessions.patch (create or update). +type SessionsPatchParams struct { + Key string `json:"key"` + Label string `json:"label,omitempty"` +} + +// SessionsResolveParams for sessions.resolve. +type SessionsResolveParams struct { + Key string `json:"key"` + IncludeUnknown bool `json:"includeUnknown"` + IncludeGlobal bool `json:"includeGlobal"` +} + +// SessionsDeleteParams for sessions.delete. +type SessionsDeleteParams struct { + Key string `json:"key"` +} + +// ChatSendParams for chat.send. +type ChatSendParams struct { + SessionKey string `json:"sessionKey"` + Message string `json:"message"` + Thinking string `json:"thinking,omitempty"` + Attachments []any `json:"attachments,omitempty"` + Deliver bool `json:"deliver"` + TimeoutMs int `json:"timeoutMs,omitempty"` + IdempotencyKey string `json:"idempotencyKey,omitempty"` +} + +// ChatHistoryParams for chat.history. +type ChatHistoryParams struct { + SessionKey string `json:"sessionKey"` + Limit int `json:"limit"` +} + +// ChatSubscribeParams for chat.subscribe. +type ChatSubscribeParams struct { + SessionKey string `json:"sessionKey,omitempty"` + FriendlyId string `json:"friendlyId,omitempty"` +} diff --git a/pkg/session/manager.go b/pkg/session/manager.go index 08f0b0ad2..4c560109a 100644 --- a/pkg/session/manager.go +++ b/pkg/session/manager.go @@ -15,10 +15,18 @@ type Session struct { Key string `json:"key"` Messages []providers.Message `json:"messages"` Summary string `json:"summary,omitempty"` + Label string `json:"label,omitempty"` Created time.Time `json:"created"` Updated time.Time `json:"updated"` } +// SessionMeta is a lightweight session entry for listing (e.g. sessions.list). +type SessionMeta struct { + Key string + UpdatedAt time.Time + Label string +} + type SessionManager struct { sessions map[string]*Session mu sync.RWMutex @@ -180,6 +188,7 @@ func (sm *SessionManager) Save(key string) error { snapshot := Session{ Key: stored.Key, Summary: stored.Summary, + Label: stored.Label, Created: stored.Created, Updated: stored.Updated, } @@ -265,6 +274,49 @@ func (sm *SessionManager) loadSessions() error { return nil } +// ListSessions returns metadata for all sessions (key, updated time, label) for listing. +func (sm *SessionManager) ListSessions() []SessionMeta { + sm.mu.RLock() + defer sm.mu.RUnlock() + out := make([]SessionMeta, 0, len(sm.sessions)) + for _, s := range sm.sessions { + out = append(out, SessionMeta{ + Key: s.Key, + UpdatedAt: s.Updated, + Label: s.Label, + }) + } + return out +} + +// Delete removes a session from memory and deletes its storage file. +func (sm *SessionManager) Delete(key string) error { + sm.mu.Lock() + delete(sm.sessions, key) + sm.mu.Unlock() + if sm.storage == "" { + return nil + } + filename := sanitizeFilename(key) + if filename == "." || !filepath.IsLocal(filename) || strings.ContainsAny(filename, `/\`) { + return nil + } + sessionPath := filepath.Join(sm.storage, filename+".json") + _ = os.Remove(sessionPath) + return nil +} + +// SetLabel sets the label for a session and updates its Updated time. +func (sm *SessionManager) SetLabel(key string, label string) { + sm.mu.Lock() + defer sm.mu.Unlock() + session, ok := sm.sessions[key] + if ok { + session.Label = label + session.Updated = time.Now() + } +} + // SetHistory updates the messages of a session. func (sm *SessionManager) SetHistory(key string, history []providers.Message) { sm.mu.Lock()