Skip to content

Commit

Permalink
Refactor ws package to make it more decoupled
Browse files Browse the repository at this point in the history
  • Loading branch information
ksysoev committed Nov 13, 2023
1 parent 224931a commit f76fcaa
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 12 deletions.
8 changes: 4 additions & 4 deletions pkg/cli/cli.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ const (

type CLI struct {
formater *formater.Formater
wsConn *ws.Connection
wsConn ws.ConnectionHandler
editor *Editor
cmdEditor *Editor
input Inputer
Expand All @@ -53,7 +53,7 @@ type Inputer interface {
// NewCLI creates a new CLI instance with the given wsConn, input, and output.
// It returns an error if it fails to get the current user, create the necessary directories,
// load the macro for the domain, or initialize the CLI instance.
func NewCLI(wsConn *ws.Connection, input Inputer, output io.Writer) (*CLI, error) {
func NewCLI(wsConn ws.ConnectionHandler, input Inputer, output io.Writer) (*CLI, error) {
currentUser, err := user.Current()
if err != nil {
return nil, fmt.Errorf("fail to get current user: %s", err)
Expand All @@ -67,7 +67,7 @@ func NewCLI(wsConn *ws.Connection, input Inputer, output io.Writer) (*CLI, error
history := NewHistory(homeDir+"/"+HistoryFilename, HistoryLimit)
cmdHistory := NewHistory(homeDir+"/"+HistoryCmdFilename, HistoryLimit)

macro, err := LoadMacroForDomain(homeDir+"/"+ConfigDir+"/"+MacroDir, wsConn.Hostname)
macro, err := LoadMacroForDomain(homeDir+"/"+ConfigDir+"/"+MacroDir, wsConn.Hostname())
if err != nil {
return nil, fmt.Errorf("fail to load macro: %s", err)
}
Expand Down Expand Up @@ -147,7 +147,7 @@ func (c *CLI) Run(opts RunOptions) error {
}
}

case msg, ok := <-c.wsConn.Messages:
case msg, ok := <-c.wsConn.Messages():
if !ok {
return nil
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/cli/commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ func NewCommandWaitForResp(timeout time.Duration) *CommandWaitForResp {
// If the WebSocket connection is closed, it will return an error.
func (c *CommandWaitForResp) Execute(exCtx *ExecutionContext) (Executer, error) {
if c.timeout.Seconds() == 0 {
msg, ok := <-exCtx.cli.wsConn.Messages
msg, ok := <-exCtx.cli.wsConn.Messages()
if !ok {
return nil, fmt.Errorf("connection closed")
}
Expand All @@ -198,7 +198,7 @@ func (c *CommandWaitForResp) Execute(exCtx *ExecutionContext) (Executer, error)
select {
case <-time.After(c.timeout):
return nil, fmt.Errorf("timeout")
case msg, ok := <-exCtx.cli.wsConn.Messages:
case msg, ok := <-exCtx.cli.wsConn.Messages():
if !ok {
return nil, fmt.Errorf("connection closed")
}
Expand Down
27 changes: 22 additions & 5 deletions pkg/ws/ws.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,19 @@ type Message struct {

type Connection struct {
ws *websocket.Conn
Messages chan Message
messages chan Message
waitGroup *sync.WaitGroup
Hostname string
hostname string
isClosed atomic.Bool
}

type ConnectionHandler interface {
Messages() <-chan Message
Hostname() string
Send(msg string) (*Message, error)
Close()
}

type Options struct {
Headers []string
SkipSSLVerification bool
Expand Down Expand Up @@ -103,19 +110,29 @@ func NewWS(wsURL string, opts Options) (*Connection, error) {

messages := make(chan Message, WSMessageBufferSize)

wsInsp := &Connection{ws: ws, Messages: messages, waitGroup: &waitGroup, Hostname: parsedURL.Hostname()}
wsInsp := &Connection{ws: ws, messages: messages, waitGroup: &waitGroup, hostname: parsedURL.Hostname()}

go wsInsp.handleResponses()

return wsInsp, nil
}

// Messages returns a channel that receives messages from the WebSocket connection.
func (wsInsp *Connection) Messages() <-chan Message {
return wsInsp.messages
}

// Hostname returns the hostname of the WebSocket server.
func (wsInsp *Connection) Hostname() string {
return wsInsp.hostname
}

// handleResponses reads messages from the websocket connection and sends them to the Messages channel.
// It runs in a loop until the connection is closed or an error occurs.
func (wsInsp *Connection) handleResponses() {
defer func() {
wsInsp.waitGroup.Wait()
close(wsInsp.Messages)
close(wsInsp.messages)
}()

for {
Expand All @@ -127,7 +144,7 @@ func (wsInsp *Connection) handleResponses() {
return
}

wsInsp.Messages <- Message{Type: Response, Data: msg}
wsInsp.messages <- Message{Type: Response, Data: msg}
}
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/ws/ws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func TestNewWSDisconnect(t *testing.T) {
ws.Close()

select {
case _, ok := <-ws.Messages:
case _, ok := <-ws.Messages():
if ok {
t.Errorf("Expected channel to be closed")
}
Expand Down

0 comments on commit f76fcaa

Please sign in to comment.