diff --git a/command.go b/command.go index 80e47436..0eb721d4 100644 --- a/command.go +++ b/command.go @@ -191,6 +191,12 @@ func Finish(id MessageID) *Command { return &Command{[]byte("FIN"), params, nil} } +// Stats creates a new command to indiciate that +// query channel statistics data +func Stats() *Command { + return &Command{[]byte("STATS"), nil, nil} +} + // Requeue creates a new Command to indicate that // a given message (by id) should be requeued after the given delay // NOTE: a delay of 0 indicates immediate requeue diff --git a/conn.go b/conn.go index 8ec8a4ab..bf7fa15d 100644 --- a/conn.go +++ b/conn.go @@ -413,6 +413,10 @@ func (c *Conn) identify() (*IdentifyResponse, error) { return resp, nil } +func (c *Conn) stats() error { + return c.WriteCommand(Stats()) +} + func (c *Conn) upgradeTLS(tlsConf *tls.Config) error { host, _, err := net.SplitHostPort(c.addr) if err != nil { @@ -558,6 +562,8 @@ func (c *Conn) readLoop() { atomic.StoreInt64(&c.lastMsgTimestamp, time.Now().UnixNano()) c.delegate.OnMessage(c, msg) + case FrameTypeStats: + c.delegate.OnStats(c, data) case FrameTypeError: c.log(LogLevelError, "protocol error - %s", data) c.delegate.OnError(c, data) diff --git a/consumer.go b/consumer.go index 77d0acdc..d1b1b920 100644 --- a/consumer.go +++ b/consumer.go @@ -2,6 +2,8 @@ package nsq import ( "bytes" + "context" + "encoding/json" "errors" "fmt" "log" @@ -137,6 +139,9 @@ type Consumer struct { stopHandler sync.Once exitHandler sync.Once + channelStatsTimeout time.Duration + channelStatsMapChan chan map[string]*ChannelStats + // read from this channel to block until consumer is cleanly stopped StopChan chan int exitChan chan int @@ -180,6 +185,8 @@ func NewConsumer(topic string, channel string, config *Config) (*Consumer, error rng: rand.New(rand.NewSource(time.Now().UnixNano())), + channelStatsMapChan: make(chan map[string]*ChannelStats), + StopChan: make(chan int), exitChan: make(chan int), } @@ -205,6 +212,45 @@ func (r *Consumer) Stats() *ConsumerStats { } } +// ChannelStats query channel statistical data +func (r *Consumer) ChannelStats(timeout time.Duration) (map[string]*ChannelStats, error) { + if timeout <= 0 { + return nil, errors.New("timeout must be greater than 0") + } + var ( + conns = r.conns() + channelStatsWG sync.WaitGroup + channelStatsMap = make(map[string]*ChannelStats) + ) + if len(conns) == 0 { + return nil, errors.New("no connections") + } + r.channelStatsTimeout = timeout + for _, conn := range conns { + if err := conn.stats(); err != nil { + r.log(LogLevelError, "(%s) error sending STATS - %s", conn.String(), err) + return nil, err + } + channelStatsWG.Add(1) + go func(timeout time.Duration, wg *sync.WaitGroup) { + defer wg.Done() + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + select { + case data := <-r.channelStatsMapChan: + for addr, channelStats := range data { + channelStatsMap[addr] = channelStats + } + return + case <-ctx.Done(): + return + } + }(timeout, &channelStatsWG) + } + channelStatsWG.Wait() + return channelStatsMap, nil +} + func (r *Consumer) conns() []*Conn { r.mtx.RLock() conns := make([]*Conn, 0, len(r.connections)) @@ -712,6 +758,22 @@ func (r *Consumer) onConnResponse(c *Conn, data []byte) { } } +func (r *Consumer) onConnStats(c *Conn, data []byte) { + var channelStats *ChannelStats + if err := json.Unmarshal(data, &channelStats); err != nil { + r.log(LogLevelError, "(%s) failed to unmarshal channel stats response - %s", c.String(), err) + return + } + go func() { + ctx, cancel := context.WithTimeout(context.Background(), r.channelStatsTimeout) + defer cancel() + select { + case <-ctx.Done(): + case r.channelStatsMapChan <- map[string]*ChannelStats{c.String(): channelStats}: + } + }() +} + func (r *Consumer) onConnError(c *Conn, data []byte) {} func (r *Consumer) onConnHeartbeat(c *Conn) {} diff --git a/consumer_test.go b/consumer_test.go index 945f5c0c..7363ef09 100644 --- a/consumer_test.go +++ b/consumer_test.go @@ -70,6 +70,21 @@ func SendMessage(t *testing.T, port int, topic string, method string, body []byt resp.Body.Close() } +func SendDeferMessage(t *testing.T, port int, topic string, ds time.Duration, body []byte) { + httpclient := &http.Client{} + endpoint := fmt.Sprintf("http://127.0.0.1:%d/pub?topic=%s&defer=%d", port, topic, ds/time.Millisecond) + req, _ := http.NewRequest("POST", endpoint, bytes.NewBuffer(body)) + resp, err := httpclient.Do(req) + if err != nil { + t.Fatalf(err.Error()) + return + } + if resp.StatusCode != 200 { + t.Fatalf("status code: %d", resp.StatusCode) + } + resp.Body.Close() +} + func TestConsumer(t *testing.T) { consumerTest(t, nil) } @@ -258,3 +273,39 @@ func consumerTest(t *testing.T, cb func(c *Config)) { t.Fatal("failed message not done") } } + +func TestChannelStats(t *testing.T) { + config := NewConfig() + laddr := "127.0.0.1" + // so that the test can simulate binding consumer to specified address + config.LocalAddr, _ = net.ResolveTCPAddr("tcp", laddr+":0") + // so that the test can simulate reaching max requeues and a call to LogFailedMessage + config.DefaultRequeueDelay = 0 + // so that the test wont timeout from backing off + config.MaxBackoffDuration = time.Millisecond * 50 + topicName := "channel_stats_test" + q, _ := NewConsumer(topicName, "ch", config) + + h := &MyTestHandler{ + t: t, + q: q, + } + q.AddHandler(h) + + // SendMessage(t, 4151, topicName, "mpub", []byte("{\"msg\":\"double\"}\n{\"msg\":\"double\"}")) + // SendDeferMessage(t, 4151, topicName, time.Minute, []byte(`{"msg":"single"}`)) + // time.Sleep(time.Second) + + addr := "127.0.0.1:4150" + err := q.ConnectToNSQD(addr) + if err != nil { + t.Fatal(err) + } + + m, err := q.ChannelStats(time.Second) + if err != nil { + t.Fatal(err) + } + b, _ := json.Marshal(m) + fmt.Println(string(b)) +} diff --git a/delegates.go b/delegates.go index aca72529..736466ef 100644 --- a/delegates.go +++ b/delegates.go @@ -64,6 +64,10 @@ type ConnDelegate interface { // receives a FrameTypeResponse from nsqd OnResponse(*Conn, []byte) + // OnStats is called when the connection + // receives a FrameTypeStats from nsqd + OnStats(*Conn, []byte) + // OnError is called when the connection // receives a FrameTypeError from nsqd OnError(*Conn, []byte) @@ -108,17 +112,22 @@ type consumerConnDelegate struct { r *Consumer } -func (d *consumerConnDelegate) OnResponse(c *Conn, data []byte) { d.r.onConnResponse(c, data) } -func (d *consumerConnDelegate) OnError(c *Conn, data []byte) { d.r.onConnError(c, data) } -func (d *consumerConnDelegate) OnMessage(c *Conn, m *Message) { d.r.onConnMessage(c, m) } -func (d *consumerConnDelegate) OnMessageFinished(c *Conn, m *Message) { d.r.onConnMessageFinished(c, m) } -func (d *consumerConnDelegate) OnMessageRequeued(c *Conn, m *Message) { d.r.onConnMessageRequeued(c, m) } -func (d *consumerConnDelegate) OnBackoff(c *Conn) { d.r.onConnBackoff(c) } -func (d *consumerConnDelegate) OnContinue(c *Conn) { d.r.onConnContinue(c) } -func (d *consumerConnDelegate) OnResume(c *Conn) { d.r.onConnResume(c) } -func (d *consumerConnDelegate) OnIOError(c *Conn, err error) { d.r.onConnIOError(c, err) } -func (d *consumerConnDelegate) OnHeartbeat(c *Conn) { d.r.onConnHeartbeat(c) } -func (d *consumerConnDelegate) OnClose(c *Conn) { d.r.onConnClose(c) } +func (d *consumerConnDelegate) OnResponse(c *Conn, data []byte) { d.r.onConnResponse(c, data) } +func (d *consumerConnDelegate) OnStats(c *Conn, data []byte) { d.r.onConnStats(c, data) } +func (d *consumerConnDelegate) OnError(c *Conn, data []byte) { d.r.onConnError(c, data) } +func (d *consumerConnDelegate) OnMessage(c *Conn, m *Message) { d.r.onConnMessage(c, m) } +func (d *consumerConnDelegate) OnMessageFinished(c *Conn, m *Message) { + d.r.onConnMessageFinished(c, m) +} +func (d *consumerConnDelegate) OnMessageRequeued(c *Conn, m *Message) { + d.r.onConnMessageRequeued(c, m) +} +func (d *consumerConnDelegate) OnBackoff(c *Conn) { d.r.onConnBackoff(c) } +func (d *consumerConnDelegate) OnContinue(c *Conn) { d.r.onConnContinue(c) } +func (d *consumerConnDelegate) OnResume(c *Conn) { d.r.onConnResume(c) } +func (d *consumerConnDelegate) OnIOError(c *Conn, err error) { d.r.onConnIOError(c, err) } +func (d *consumerConnDelegate) OnHeartbeat(c *Conn) { d.r.onConnHeartbeat(c) } +func (d *consumerConnDelegate) OnClose(c *Conn) { d.r.onConnClose(c) } // keeps the exported Producer struct clean of the exported methods // required to implement the ConnDelegate interface @@ -127,6 +136,7 @@ type producerConnDelegate struct { } func (d *producerConnDelegate) OnResponse(c *Conn, data []byte) { d.w.onConnResponse(c, data) } +func (d *producerConnDelegate) OnStats(c *Conn, data []byte) { d.w.onConnStats(c, data) } func (d *producerConnDelegate) OnError(c *Conn, data []byte) { d.w.onConnError(c, data) } func (d *producerConnDelegate) OnMessage(c *Conn, m *Message) {} func (d *producerConnDelegate) OnMessageFinished(c *Conn, m *Message) {} diff --git a/producer.go b/producer.go index 4019fefd..521cb830 100644 --- a/producer.go +++ b/producer.go @@ -427,6 +427,7 @@ func (w *Producer) log(lvl LogLevel, line string, args ...interface{}) { } func (w *Producer) onConnResponse(c *Conn, data []byte) { w.responseChan <- data } +func (w *Producer) onConnStats(c *Conn, data []byte) {} func (w *Producer) onConnError(c *Conn, data []byte) { w.errorChan <- data } func (w *Producer) onConnHeartbeat(c *Conn) {} func (w *Producer) onConnIOError(c *Conn, err error) { w.close() } diff --git a/protocol.go b/protocol.go index 1d0e1a9d..d00e76fa 100644 --- a/protocol.go +++ b/protocol.go @@ -19,6 +19,7 @@ const ( FrameTypeResponse int32 = 0 FrameTypeError int32 = 1 FrameTypeMessage int32 = 2 + FrameTypeStats int32 = 3 ) // Used to detect if an unexpected HTTP response is read diff --git a/stats.go b/stats.go new file mode 100644 index 00000000..c8f751fa --- /dev/null +++ b/stats.go @@ -0,0 +1,13 @@ +package nsq + +type ChannelStats struct { + ChannelName string `json:"channel_name"` + Depth int64 `json:"depth"` + BackendDepth int64 `json:"backend_depth"` + InFlightCount int `json:"in_flight_count"` + DeferredCount int `json:"deferred_count"` + MessageCount uint64 `json:"message_count"` + RequeueCount uint64 `json:"requeue_count"` + TimeoutCount uint64 `json:"timeout_count"` + Paused bool `json:"paused"` +}