From 43be50abad39024fa2d4fccd229922d41341eb8d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D0=9A=D0=B8=D1=80=D0=B8=D0=BB=D0=BB=20=D0=A1=D1=8B=D1=81?= =?UTF-8?q?=D0=BE=D0=B5=D0=B2?= Date: Wed, 8 Nov 2023 22:25:52 +0800 Subject: [PATCH] conver wait for response into a command --- cmd/wsget/main.go | 11 ++++++++++- pkg/cli/cli_test.go | 2 +- pkg/cli/commands.go | 36 ++++++++++++++++++++++++++++++++++-- pkg/ws/ws.go | 28 +++------------------------- pkg/ws/ws_test.go | 12 ++++-------- 5 files changed, 52 insertions(+), 37 deletions(-) diff --git a/cmd/wsget/main.go b/cmd/wsget/main.go index 0d20bd1..5bb33d3 100644 --- a/cmd/wsget/main.go +++ b/cmd/wsget/main.go @@ -3,6 +3,7 @@ package main import ( "fmt" "os" + "time" "github.com/fatih/color" "github.com/ksysoev/wsget/pkg/cli" @@ -75,7 +76,7 @@ func run(cmd *cobra.Command, args []string) { return } - wsConn, err := ws.NewWS(wsURL, ws.Options{SkipSSLVerification: insecure, Headers: headers, WaitForResp: waitResponse}) + wsConn, err := ws.NewWS(wsURL, ws.Options{SkipSSLVerification: insecure, Headers: headers}) if err != nil { color.New(color.FgRed).Println("Unable to connect to the server: ", err) return @@ -101,6 +102,14 @@ func run(cmd *cobra.Command, args []string) { if request != "" { opts.Commands = []cli.Executer{cli.NewCommandSend(request)} + + if waitResponse >= 0 { + opts.Commands = append( + opts.Commands, + cli.NewCommandWaitForResp(time.Duration(waitResponse)*time.Second), + cli.NewCommandExit(), + ) + } } else { opts.Commands = []cli.Executer{cli.NewCommandEdit("")} } diff --git a/pkg/cli/cli_test.go b/pkg/cli/cli_test.go index 6c6cfba..b41156e 100644 --- a/pkg/cli/cli_test.go +++ b/pkg/cli/cli_test.go @@ -54,7 +54,7 @@ func TestNewCLI(t *testing.T) { t.Error("Expected non-nil editor") } - if err = wsConn.Send("Hello, world!"); err != nil { + if _, err = wsConn.Send("Hello, world!"); err != nil { t.Fatalf("Unexpected error: %v", err) } diff --git a/pkg/cli/commands.go b/pkg/cli/commands.go index 5292781..187cb62 100644 --- a/pkg/cli/commands.go +++ b/pkg/cli/commands.go @@ -3,6 +3,7 @@ package cli import ( "fmt" "io" + "time" "github.com/eiannone/keyboard" "github.com/fatih/color" @@ -55,11 +56,12 @@ func NewCommandSend(request string) *CommandSend { } func (c *CommandSend) Execute(exCtx *ExecutionContext) (Executer, error) { - if err := exCtx.wsConn.Send(c.request); err != nil { + msg, err := exCtx.wsConn.Send(c.request) + if err != nil { return nil, fmt.Errorf("fail to send request: %s", err) } - return nil, nil + return NewCommandPrintMsg(*msg), nil } type CommandPrintMsg struct { @@ -110,3 +112,33 @@ func NewCommandExit() *CommandExit { func (c *CommandExit) Execute(_ *ExecutionContext) (Executer, error) { return nil, fmt.Errorf("interrupted") } + +type CommandWaitForResp struct { + timeout time.Duration +} + +func NewCommandWaitForResp(timeout time.Duration) *CommandWaitForResp { + return &CommandWaitForResp{timeout} +} + +func (c *CommandWaitForResp) Execute(exCtx *ExecutionContext) (Executer, error) { + if c.timeout.Seconds() == 0 { + msg, ok := <-exCtx.wsConn.Messages + if !ok { + return nil, fmt.Errorf("connection closed") + } + + return NewCommandPrintMsg(msg), nil + } + + select { + case <-time.After(c.timeout): + return nil, fmt.Errorf("timeout") + case msg, ok := <-exCtx.wsConn.Messages: + if !ok { + return nil, fmt.Errorf("connection closed") + } + + return NewCommandPrintMsg(msg), nil + } +} diff --git a/pkg/ws/ws.go b/pkg/ws/ws.go index 505cf6b..2fbd712 100644 --- a/pkg/ws/ws.go +++ b/pkg/ws/ws.go @@ -6,7 +6,6 @@ import ( "net/http" "strings" "sync" - "time" "github.com/fatih/color" "golang.org/x/net/websocket" @@ -50,7 +49,6 @@ type Connection struct { type Options struct { Headers []string SkipSSLVerification bool - WaitForResp int } func NewWS(url string, opts Options) (*Connection, error) { @@ -86,14 +84,6 @@ func NewWS(url string, opts Options) (*Connection, error) { ws, err := websocket.DialConfig(cfg) - if opts.WaitForResp > 0 { - go func() { - time.Sleep(time.Duration(opts.WaitForResp) * time.Second) - color.New(color.FgRed).Println("Timeout reached. Closing connection") - ws.Close() - }() - } - if err != nil { return nil, err } @@ -113,12 +103,6 @@ func NewWS(url string, opts Options) (*Connection, error) { err = websocket.Message.Receive(ws, &msg) if err != nil { - if opts.WaitForResp >= 0 { - // If we are waiting for single response and connection is closed - // we just return from the function - return - } - if err.Error() == "EOF" { color.New(color.FgRed).Println("Connection closed by the server") } else { @@ -129,29 +113,23 @@ func NewWS(url string, opts Options) (*Connection, error) { } messages <- Message{Type: Response, Data: msg} - - if opts.WaitForResp >= 0 { - return - } } }() return &Connection{ws: ws, Messages: messages, waitGroup: &waitGroup}, nil } -func (wsInsp *Connection) Send(msg string) error { +func (wsInsp *Connection) Send(msg string) (*Message, error) { wsInsp.waitGroup.Add(1) defer wsInsp.waitGroup.Done() err := websocket.Message.Send(wsInsp.ws, msg) if err != nil { - return err + return nil, err } - wsInsp.Messages <- Message{Type: Request, Data: msg} - - return nil + return &Message{Type: Request, Data: msg}, nil } func (wsInsp *Connection) Close() { diff --git a/pkg/ws/ws_test.go b/pkg/ws/ws_test.go index 6612336..323eec1 100644 --- a/pkg/ws/ws_test.go +++ b/pkg/ws/ws_test.go @@ -28,17 +28,13 @@ func TestNewWS(t *testing.T) { t.Fatalf("Expected ws connection, but got nil") } - if err = ws.Send("Hello, world!"); err != nil { + msg, err := ws.Send("Hello, world!") + if err != nil { t.Fatalf("Unexpected error: %v", err) } - select { - case msg := <-ws.Messages: - if msg.Data != "Hello, world!" { - t.Errorf("Expected message data to be 'Hello, world!', but got %v", msg.Data) - } - default: - t.Errorf("Expected message, but got none") + if msg.Data != "Hello, world!" { + t.Errorf("Expected message data to be 'Hello, world!', but got %v", msg.Data) } }