Skip to content

Commit

Permalink
Merge pull request #923 from im-kulikov/socketmode-pass-context
Browse files Browse the repository at this point in the history
[socketmode] Add methods with passing context
  • Loading branch information
kanata2 authored Apr 18, 2021
2 parents c5aa3bb + 153f70a commit 24ed201
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 11 deletions.
2 changes: 1 addition & 1 deletion socketmode/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ type ConnectionInfo struct {
}

type SocketModeMessagePayload struct {
Event json.RawMessage `json:"´event"`
Event json.RawMessage `json:"event"`
}

// Client is a Socket Mode client that allows programs to use [Events API](https://api.slack.com/events-api)
Expand Down
36 changes: 26 additions & 10 deletions socketmode/socket_mode_managed_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,20 @@ import (
// If you want to retry even on reconnection failure, you'd need to write your own wrapper for this function
// to do so.
func (smc *Client) Run() error {
ctx := context.TODO()
return smc.RunContext(context.TODO())
}

// RunContext is a blocking function that connects the Slack Socket Mode API and handles all incoming
// requests and outgoing responses.
//
// The consumer of the Client and this function should read the Client.Events channel to receive
// `socketmode.Event`s that includes the client-specific events that may or may not wrap Socket Mode requests.
//
// Note that this function automatically reconnect on requested by Slack through a `disconnect` message.
// This function exists with an error only when a reconnection is failued due to some reason.
// If you want to retry even on reconnection failure, you'd need to write your own wrapper for this function
// to do so.
func (smc *Client) RunContext(ctx context.Context) error {
for connectionCount := 0; ; connectionCount++ {
if err := smc.run(ctx, connectionCount); err != nil {
return err
Expand All @@ -44,6 +56,7 @@ func (smc *Client) Run() error {

func (smc *Client) run(ctx context.Context, connectionCount int) error {
messages := make(chan json.RawMessage)
defer close(messages)

deadmanTimer := newDeadmanTimer(smc.maxPingInterval)

Expand Down Expand Up @@ -131,7 +144,10 @@ func (smc *Client) run(ctx context.Context, connectionCount int) error {

select {
case <-ctx.Done():
// Detect when the connection is dead.
// Detect when the connection is dead and try close connection.
if err = conn.Close(); err != nil {
smc.Debugf("Failed to close connection: %v", err)
}
case <-deadmanTimer.Elapsed():
firstErrOnce.Do(func() {
firstErr = errors.New("ping timeout: Slack did not send us WebSocket PING for more than Client.maxInterval")
Expand All @@ -143,14 +159,14 @@ func (smc *Client) run(ctx context.Context, connectionCount int) error {

wg.Wait()

if firstErr == context.Canceled {
return firstErr
}

// wg.Wait() finishes only after any of the above go routines finishes.
// Also, we can expect firstErr to be not nil, as goroutines can finish only on error.
smc.Debugf("Reconnecting due to %v", firstErr)

if err = conn.Close(); err != nil {
smc.Debugf("Failed to close connection: %v", err)
}

return nil
}

Expand Down Expand Up @@ -183,7 +199,7 @@ func (smc *Client) connect(ctx context.Context, connectionCount int, additionalP
})

// attempt to start the connection
info, conn, err := smc.openAndDial(additionalPingHandler)
info, conn, err := smc.openAndDial(ctx, additionalPingHandler)
if err == nil {
return info, conn, nil
}
Expand Down Expand Up @@ -234,13 +250,13 @@ func (smc *Client) connect(ctx context.Context, connectionCount int, additionalP
// openAndDial attempts to open a Socket Mode connection and dial to the connection endpoint using WebSocket.
// It returns the full information returned by the "apps.connections.open" method on the
// Slack API.
func (smc *Client) openAndDial(additionalPingHandler func(string) error) (info *slack.SocketModeConnection, _ *websocket.Conn, err error) {
func (smc *Client) openAndDial(ctx context.Context, additionalPingHandler func(string) error) (info *slack.SocketModeConnection, _ *websocket.Conn, err error) {
var (
url string
)

smc.Debugf("Starting SocketMode")
info, url, err = smc.Open()
info, url, err = smc.OpenContext(ctx)

if err != nil {
smc.Debugf("Failed to start or connect with SocketMode: %s", err)
Expand All @@ -255,7 +271,7 @@ func (smc *Client) openAndDial(additionalPingHandler func(string) error) (info *
if smc.dialer != nil {
dialer = smc.dialer
}
conn, _, err := dialer.Dial(url, upgradeHeader)
conn, _, err := dialer.DialContext(ctx, url, upgradeHeader)
if err != nil {
smc.Debugf("Failed to dial to the websocket: %s", err)
return nil, nil, err
Expand Down
45 changes: 45 additions & 0 deletions socketmode/socket_mode_managed_conn_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// +build go1.13

package socketmode

import (
"context"
"errors"
"testing"
"time"

"github.com/slack-go/slack"
"github.com/slack-go/slack/slacktest"

"github.com/stretchr/testify/assert"
)

func Test_passContext(t *testing.T) {
s := slacktest.NewTestServer()
go s.Start()

api := slack.New("ABCDEFG", slack.OptionAPIURL(s.GetAPIURL()))
cli := New(api)

ctx, cancel := context.WithTimeout(context.TODO(), time.Nanosecond)
defer cancel()

t.Run("RunWithContext", func(t *testing.T) {
// should fail imidiatly.
assert.EqualError(t, cli.RunContext(ctx), context.DeadlineExceeded.Error())
})

t.Run("openAndDial", func(t *testing.T) {
_, _, err := cli.openAndDial(ctx, func(_ string) error { return nil })

// should fail imidiatly.
assert.EqualError(t, errors.Unwrap(err), context.DeadlineExceeded.Error())
})

t.Run("OpenWithContext", func(t *testing.T) {
_, _, err := cli.OpenContext(ctx)

// should fail imidiatly.
assert.EqualError(t, errors.Unwrap(err), context.DeadlineExceeded.Error())
})
}
7 changes: 7 additions & 0 deletions socketmode/socketmode.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ func (smc *Client) Open() (info *slack.SocketModeConnection, websocketURL string
return smc.apiClient.StartSocketModeContext(ctx)
}

// OpenContext calls the "apps.connections.open" endpoint and returns the provided URL and the full Info block.
//
// To have a fully managed Websocket connection, use `New`, and call `Run()` on it.
func (smc *Client) OpenContext(ctx context.Context) (info *slack.SocketModeConnection, websocketURL string, err error) {
return smc.apiClient.StartSocketModeContext(ctx)
}

// Option options for the managed Client.
type Option func(client *Client)

Expand Down

0 comments on commit 24ed201

Please sign in to comment.