diff --git a/client.go b/client.go index 6e3ca33..3b43cce 100644 --- a/client.go +++ b/client.go @@ -1,7 +1,6 @@ package xrpl import ( - "encoding/json" "errors" "fmt" "log" @@ -18,35 +17,38 @@ type ClientConfig struct { URL string Authorization string Certificate string - ConnectionTimeout time.Duration FeeCushion uint32 Key string MaxFeeXRP uint64 Passphrase byte Proxy byte ProxyAuthorization byte - Timeout time.Duration + ReadTimeout time.Duration + WriteTimeout time.Duration + HeartbeatInterval time.Duration QueueCapacity int } type Client struct { - config ClientConfig - connection *websocket.Conn - closed bool - mutex sync.Mutex - response *http.Response - StreamLedger chan []byte - StreamTransaction chan []byte - StreamValidation chan []byte - StreamManifest chan []byte - StreamPeerStatus chan []byte - StreamConsensus chan []byte - StreamPathFind chan []byte - StreamServer chan []byte - StreamDefault chan []byte - requestQueue map[string](chan<- BaseResponse) - nextId int - err error + config ClientConfig + connection *websocket.Conn + heartbeatDone chan bool + closed bool + mutex sync.Mutex + response *http.Response + StreamLedger chan []byte + StreamTransaction chan []byte + StreamValidation chan []byte + StreamManifest chan []byte + StreamPeerStatus chan []byte + StreamConsensus chan []byte + StreamPathFind chan []byte + StreamServer chan []byte + StreamDefault chan []byte + StreamSubscriptions map[string]bool + requestQueue map[string](chan<- BaseResponse) + nextId int + err error } func (config *ClientConfig) Validate() error { @@ -54,143 +56,149 @@ func (config *ClientConfig) Validate() error { return errors.New("cannot create a new connection with an empty URL") } - if config.ConnectionTimeout < 0 || config.ConnectionTimeout >= math.MaxInt32 { - return fmt.Errorf("connection timeout out of bounds: %d", config.ConnectionTimeout) + if config.ReadTimeout < 0 || + config.ReadTimeout <= config.HeartbeatInterval || + config.ReadTimeout >= math.MaxInt32 { + return fmt.Errorf("connection read timeout out of bounds: %d", config.ReadTimeout) } - - if config.Timeout < 0 || config.Timeout >= math.MaxInt32 { - return fmt.Errorf("timeout out of bounds: %d", config.Timeout) + if config.WriteTimeout < 0 || + config.WriteTimeout <= config.HeartbeatInterval || + config.WriteTimeout >= math.MaxInt32 { + return fmt.Errorf("connection write timeout out of bounds: %d", config.WriteTimeout) + } + if config.HeartbeatInterval < 0 || + config.HeartbeatInterval >= math.MaxInt32 { + return fmt.Errorf("connection heartbeat interval out of bounds: %d", config.HeartbeatInterval) } return nil } func NewClient(config ClientConfig) *Client { - if err := config.Validate(); err != nil { - panic(err) + if config.ReadTimeout == 0 { + config.ReadTimeout = 20 } - - if config.ConnectionTimeout == 0 { - config.ConnectionTimeout = 60 * time.Second + if config.WriteTimeout == 0 { + config.WriteTimeout = 20 + } + if config.HeartbeatInterval == 0 { + config.HeartbeatInterval = 5 } if config.QueueCapacity == 0 { config.QueueCapacity = 128 } + if err := config.Validate(); err != nil { + panic(err) + } + client := &Client{ - config: config, - StreamLedger: make(chan []byte, config.QueueCapacity), - StreamTransaction: make(chan []byte, config.QueueCapacity), - StreamValidation: make(chan []byte, config.QueueCapacity), - StreamManifest: make(chan []byte, config.QueueCapacity), - StreamPeerStatus: make(chan []byte, config.QueueCapacity), - StreamConsensus: make(chan []byte, config.QueueCapacity), - StreamPathFind: make(chan []byte, config.QueueCapacity), - StreamServer: make(chan []byte, config.QueueCapacity), - StreamDefault: make(chan []byte, config.QueueCapacity), - requestQueue: make(map[string](chan<- BaseResponse)), - nextId: 0, - } - c, r, err := websocket.DefaultDialer.Dial(config.URL, nil) + config: config, + heartbeatDone: make(chan bool), + StreamLedger: make(chan []byte, config.QueueCapacity), + StreamTransaction: make(chan []byte, config.QueueCapacity), + StreamValidation: make(chan []byte, config.QueueCapacity), + StreamManifest: make(chan []byte, config.QueueCapacity), + StreamPeerStatus: make(chan []byte, config.QueueCapacity), + StreamConsensus: make(chan []byte, config.QueueCapacity), + StreamPathFind: make(chan []byte, config.QueueCapacity), + StreamServer: make(chan []byte, config.QueueCapacity), + StreamDefault: make(chan []byte, config.QueueCapacity), + StreamSubscriptions: make(map[string]bool), + requestQueue: make(map[string](chan<- BaseResponse)), + nextId: 0, + } + + _, err := client.NewConnection() if err != nil { - client.err = err - return nil + log.Println("WS connection error:", client.config.URL, err) } - defer r.Body.Close() - client.connection = c - client.response = r - client.connection.SetPongHandler(client.handlePong) - go client.handleResponse() return client } -func (c *Client) Ping(message []byte) error { - if err := c.connection.WriteMessage(websocket.PingMessage, message); err != nil { - return err - } - return nil -} - -// Returns incremental ID that may be used as request ID for websocket requests -func (c *Client) NextID() string { +func (c *Client) NewConnection() (*websocket.Conn, error) { c.mutex.Lock() - c.nextId++ - c.mutex.Unlock() - return strconv.Itoa(c.nextId) -} + defer c.mutex.Unlock() -func (c *Client) Subscribe(streams []string) (BaseResponse, error) { - req := BaseRequest{ - "command": "subscribe", - "streams": streams, - } - res, err := c.Request(req) + conn, r, err := websocket.DefaultDialer.Dial(c.config.URL, nil) if err != nil { + c.err = err return nil, err } - return res, nil + defer r.Body.Close() + c.connection = conn + c.response = r + c.closed = false + + // Set connection handlers and heartbeat + c.connection.SetPongHandler(c.handlePong) + go c.handleResponse() + go c.heartbeat() + return c.connection, nil } -func (c *Client) Unsubscribe(streams []string) (BaseResponse, error) { - req := BaseRequest{ - "command": "unsubscribe", - "streams": streams, - } - res, err := c.Request(req) +func (c *Client) Reconnect() error { + // Close old websocket connection + c.Close() + + // Create a new websocket connection + _, err := c.NewConnection() if err != nil { - return nil, err + log.Println("WS reconnection error:", c.config.URL, err) + return err } - return res, nil -} -// Send a websocket request. This method takes a BaseRequest object and automatically adds -// incremental request ID to it. -// -// Example usage: -// -// req := BaseRequest{ -// "command": "account_info", -// "account": "rG1QQv2nh2gr7RCZ1P8YYcBUKCCN633jCn", -// "ledger_index": "current", -// } -// -// err := client.Request(req, func(){}) -func (c *Client) Request(req BaseRequest) (BaseResponse, error) { - requestId := c.NextID() - req["id"] = requestId - data, err := json.Marshal(req) + // Re-subscribe xrpl streams + _, err = c.Subscribe(c.Subscriptions()) if err != nil { - return nil, err + log.Println("WS stream subscription error:", err) } + return nil +} - ch := make(chan BaseResponse, 1) - +func (c *Client) Ping(message []byte) error { c.mutex.Lock() - c.requestQueue[requestId] = ch - err = c.connection.WriteMessage(websocket.TextMessage, data) - if err != nil { - return nil, err + defer c.mutex.Unlock() + // log.Println("PING:", string(message)) + if err := c.connection.WriteMessage(websocket.PingMessage, message); err != nil { + return err } + return nil +} + +// Returns incremental ID that may be used as request ID for websocket requests +func (c *Client) NextID() string { + c.mutex.Lock() + c.nextId++ c.mutex.Unlock() + return strconv.Itoa(c.nextId) +} - res := <-ch - return res, nil +func (c *Client) Subscriptions() []string { + c.mutex.Lock() + defer c.mutex.Unlock() + subs := make([]string, 0, len(c.StreamSubscriptions)) + for k := range c.StreamSubscriptions { + subs = append(subs, k) + } + return subs } func (c *Client) Close() error { c.mutex.Lock() defer c.mutex.Unlock() c.closed = true + c.heartbeatDone <- true err := c.connection.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) if err != nil { - log.Println("Write close error: ", err) + log.Println("WS write error: ", err) return err } err = c.connection.Close() if err != nil { - log.Println("Write close error: ", err) + log.Println("WS close error: ", err) return err } return nil diff --git a/command.go b/command.go new file mode 100644 index 0000000..3f1b165 --- /dev/null +++ b/command.go @@ -0,0 +1,79 @@ +package xrpl + +import ( + "encoding/json" + + "github.com/gorilla/websocket" +) + +func (c *Client) Subscribe(streams []string) (BaseResponse, error) { + req := BaseRequest{ + "command": "subscribe", + "streams": streams, + } + res, err := c.Request(req) + if err != nil { + return nil, err + } + + c.mutex.Lock() + for _, stream := range streams { + c.StreamSubscriptions[stream] = true + } + c.mutex.Unlock() + + return res, nil +} + +func (c *Client) Unsubscribe(streams []string) (BaseResponse, error) { + req := BaseRequest{ + "command": "unsubscribe", + "streams": streams, + } + res, err := c.Request(req) + if err != nil { + return nil, err + } + + c.mutex.Lock() + for _, stream := range streams { + delete(c.StreamSubscriptions, stream) + } + c.mutex.Unlock() + + return res, nil +} + +// Send a websocket request. This method takes a BaseRequest object and automatically adds +// incremental request ID to it. +// +// Example usage: +// +// req := BaseRequest{ +// "command": "account_info", +// "account": "rG1QQv2nh2gr7RCZ1P8YYcBUKCCN633jCn", +// "ledger_index": "current", +// } +// +// err := client.Request(req, func(){}) +func (c *Client) Request(req BaseRequest) (BaseResponse, error) { + requestId := c.NextID() + req["id"] = requestId + data, err := json.Marshal(req) + if err != nil { + return nil, err + } + + ch := make(chan BaseResponse, 1) + + c.mutex.Lock() + c.requestQueue[requestId] = ch + err = c.connection.WriteMessage(websocket.TextMessage, data) + if err != nil { + return nil, err + } + c.mutex.Unlock() + + res := <-ch + return res, nil +} diff --git a/handlers.go b/handlers.go index 32dd10e..e2b78f0 100644 --- a/handlers.go +++ b/handlers.go @@ -4,12 +4,15 @@ import ( "encoding/json" "fmt" "log" + "time" "github.com/gorilla/websocket" ) func (c *Client) handlePong(message string) error { - fmt.Println("PONG response:", message) + // log.Println("PONG:", message) + c.connection.SetReadDeadline(time.Now().Add(c.config.ReadTimeout * time.Second)) + c.connection.SetWriteDeadline(time.Now().Add(c.config.WriteTimeout * time.Second)) return nil } @@ -19,8 +22,10 @@ func (c *Client) handleResponse() error { break } messageType, message, err := c.connection.ReadMessage() - if err != nil && websocket.IsCloseError(err) { - log.Println("XRPL read error: ", err) + if err != nil { + log.Println("WS read error:", err) + c.Reconnect() + break } switch messageType { @@ -38,7 +43,7 @@ func (c *Client) handleResponse() error { func (c *Client) resolveStream(message []byte) { var m BaseResponse if err := json.Unmarshal(message, &m); err != nil { - fmt.Println("json.Unmarshal error: ", err) + log.Println("json.Unmarshal error: ", err) } switch m["type"] { diff --git a/heartbeat.go b/heartbeat.go new file mode 100644 index 0000000..8f435c5 --- /dev/null +++ b/heartbeat.go @@ -0,0 +1,23 @@ +package xrpl + +import ( + "time" +) + +// Heartbeat runner to send Pings periodically. If a Pong is received, it is +// handled by handlePong handler which further extends websocket connection's +// read and write deadline into the future. +func (c *Client) heartbeat() { + // log.Println("INF: Heartbeat started") + ticker := time.NewTicker(c.config.HeartbeatInterval * time.Second) + for { + select { + case <-c.heartbeatDone: + ticker.Stop() + // log.Println("ERR: Heartbeat stopped") + return + case t := <-ticker.C: + c.Ping([]byte(t.String())) + } + } +}