diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/client.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/client.go index 7c254d99164..1fe1f61bfb1 100644 --- a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/client.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/wsclient/client.go @@ -471,6 +471,31 @@ func (cs *ClientServerImpl) WriteCloseMessage() error { return cs.conn.WriteControl(websocket.CloseMessage, send, time.Now().Add(cs.RWTimeout)) } +func (cs *ClientServerImpl) readMessage() (int, []byte, error) { + return cs.conn.ReadMessage() +} + +func (cs *ClientServerImpl) readMessageWithTimeout(timeout time.Duration) (int, []byte, error) { + type result struct { + messageType int + message []byte + err error + } + + resultChan := make(chan result, 1) + go func() { + messageType, message, err := cs.readMessage() + resultChan <- result{messageType, message, err} + }() + + select { + case res := <-resultChan: + return res.messageType, res.message, res.err + case <-time.After(timeout): + return 0, nil, fmt.Errorf("read message timeout after %v", timeout) + } +} + // ConsumeMessages reads messages from the websocket connection and handles read // messages from an active connection. func (cs *ClientServerImpl) ConsumeMessages(ctx context.Context) error { @@ -482,7 +507,11 @@ func (cs *ClientServerImpl) ConsumeMessages(ctx context.Context) error { errChan <- err return } - messageType, message, err := cs.conn.ReadMessage() + logger.Debug("Will now read message", logger.Fields{"url": cs.URL}) + messageType, message, err := cs.readMessageWithTimeout(cs.RWTimeout) + logger.Debug("Returned from ReadMessage", logger.Fields{ + "url": cs.URL, "messageType": messageType, "message": string(message), "error": err, + }) switch { case err == nil: diff --git a/ecs-agent/wsclient/client.go b/ecs-agent/wsclient/client.go index 7c254d99164..1fe1f61bfb1 100644 --- a/ecs-agent/wsclient/client.go +++ b/ecs-agent/wsclient/client.go @@ -471,6 +471,31 @@ func (cs *ClientServerImpl) WriteCloseMessage() error { return cs.conn.WriteControl(websocket.CloseMessage, send, time.Now().Add(cs.RWTimeout)) } +func (cs *ClientServerImpl) readMessage() (int, []byte, error) { + return cs.conn.ReadMessage() +} + +func (cs *ClientServerImpl) readMessageWithTimeout(timeout time.Duration) (int, []byte, error) { + type result struct { + messageType int + message []byte + err error + } + + resultChan := make(chan result, 1) + go func() { + messageType, message, err := cs.readMessage() + resultChan <- result{messageType, message, err} + }() + + select { + case res := <-resultChan: + return res.messageType, res.message, res.err + case <-time.After(timeout): + return 0, nil, fmt.Errorf("read message timeout after %v", timeout) + } +} + // ConsumeMessages reads messages from the websocket connection and handles read // messages from an active connection. func (cs *ClientServerImpl) ConsumeMessages(ctx context.Context) error { @@ -482,7 +507,11 @@ func (cs *ClientServerImpl) ConsumeMessages(ctx context.Context) error { errChan <- err return } - messageType, message, err := cs.conn.ReadMessage() + logger.Debug("Will now read message", logger.Fields{"url": cs.URL}) + messageType, message, err := cs.readMessageWithTimeout(cs.RWTimeout) + logger.Debug("Returned from ReadMessage", logger.Fields{ + "url": cs.URL, "messageType": messageType, "message": string(message), "error": err, + }) switch { case err == nil: