diff --git a/device.go b/device.go index 929a525..7f470bc 100644 --- a/device.go +++ b/device.go @@ -1,6 +1,7 @@ package gocast import ( + "context" "net" "strconv" "sync" @@ -22,6 +23,8 @@ type Device struct { wrapper *packetStream reconnect chan struct{} + stop context.CancelFunc + eventListners []func(event events.Event) subscriptions map[string]*Subscription @@ -36,49 +39,83 @@ func NewDevice() *Device { reconnect: make(chan struct{}), subscriptions: make(map[string]*Subscription), connectionHandler: &handlers.Connection{}, - heartbeatHandler: &handlers.Heartbeat{}, + heartbeatHandler: handlers.NewHeartbeat(), ReceiverHandler: &handlers.Receiver{}, } - d.heartbeatHandler.OnFailure = func() { - d.Disconnect() - } - return d } func (d *Device) SetName(name string) { + d.Lock() d.name = name + d.Unlock() } func (d *Device) SetUuid(uuid string) { + d.Lock() d.uuid = uuid + d.Unlock() } func (d *Device) SetIp(ip net.IP) { + d.Lock() d.ip = ip + d.Unlock() } func (d *Device) SetPort(port int) { + d.Lock() d.port = port + d.Unlock() } func (d *Device) Name() string { + d.RLock() + defer d.RUnlock() return d.name } func (d *Device) Uuid() string { + d.RLock() + defer d.RUnlock() return d.uuid } func (d *Device) Ip() net.IP { + d.RLock() + defer d.RUnlock() return d.ip } func (d *Device) Port() int { + d.RLock() + defer d.RUnlock() return d.port } +func (d *Device) Connected() bool { + d.RLock() + defer d.RUnlock() + return d.connected +} +func (d *Device) getConn() net.Conn { + d.RLock() + defer d.RUnlock() + return d.conn +} +func (d *Device) getSubscriptionsAsSlice() []*Subscription { + d.RLock() + subs := make([]*Subscription, len(d.subscriptions)) + i := 0 + for _, v := range d.subscriptions { + subs[i] = v + i++ + } + defer d.RUnlock() + return subs +} + func (d *Device) String() string { return d.name + " - " + d.ip.String() + ":" + strconv.Itoa(d.port) } @@ -117,10 +154,10 @@ func (d *Device) UnsubscribeByUrn(urn string) { } d.RUnlock() d.Lock() - defer d.Unlock() for _, sub := range subs { delete(d.subscriptions, sub) } + d.Unlock() } func (d *Device) UnsubscribeByUrnAndDestinationId(urn, destinationId string) { diff --git a/device_connection.go b/device_connection.go index 5273a13..c2e86e8 100644 --- a/device_connection.go +++ b/device_connection.go @@ -1,6 +1,7 @@ package gocast import ( + "context" "crypto/tls" "encoding/json" "fmt" @@ -14,88 +15,86 @@ import ( "github.com/stampzilla/gocast/responses" ) -func (d *Device) reader() { +func (d *Device) reader(ctx context.Context) { for { - packet, err := d.wrapper.Read() - if err != nil { - logrus.Errorf("Error reading from chromecast error: %s Packet: %#v", err, packet) - d.Disconnect() + select { + case <-ctx.Done(): + if ctx.Err() != nil { + logrus.Errorf("closing reader %s: %s", d.Name(), ctx.Err()) + } return - } + case p := <-d.wrapper.packets: + if p.err != nil { + logrus.Errorf("Error reading from chromecast error: %s Packet: %#v", p.err, p) + return + } + packet := p.payload - message := &api.CastMessage{} - err = proto.Unmarshal(*packet, message) - if err != nil { - logrus.Errorf("Failed to unmarshal CastMessage: %s", err) - continue - } + message := &api.CastMessage{} + err := proto.Unmarshal(packet, message) + if err != nil { + logrus.Errorf("Failed to unmarshal CastMessage: %s", err) + continue + } - headers := &responses.Headers{} + headers := &responses.Headers{} - err = json.Unmarshal([]byte(*message.PayloadUtf8), headers) + err = json.Unmarshal([]byte(*message.PayloadUtf8), headers) - if err != nil { - logrus.Errorf("Failed to unmarshal message: %s", err) - continue - } + if err != nil { + logrus.Errorf("Failed to unmarshal message: %s", err) + continue + } - catched := false - d.RLock() - for _, subscription := range d.subscriptions { - if subscription.Receive(message, headers) { - catched = true + catched := false + for _, subscription := range d.getSubscriptionsAsSlice() { + if subscription.Receive(message, headers) { + catched = true + } } - } - d.RUnlock() - if !catched { - logrus.Debug("LOST MESSAGE:") - logrus.Debug(spew.Sdump(message)) + if !catched { + logrus.Debug("LOST MESSAGE:") + logrus.Debug(spew.Sdump(message)) + } } } } -func (d *Device) Connected() bool { - d.RLock() - defer d.RUnlock() - return d.connected -} - -func (d *Device) Connect() error { - go d.reconnector() - return d.connect() -} +func (d *Device) Connect(ctx context.Context) error { + d.heartbeatHandler.OnFailure = func() { // make sure we reconnect if we loose heartbeat + logrus.Errorf("heartbeat timeout for: %s trying to reconnect", d.Name()) -func (d *Device) Reconnect() { - select { - case d.reconnect <- struct{}{}: - default: - } -} - -func (d *Device) reconnector() { - for { - select { - case <-d.reconnect: - logrus.Info("Reconnect signal received") - time.Sleep(time.Second * 2) - err := d.connect() - if err != nil { - logrus.Error(err) + d.Disconnect() + for { // try to connect until no error + err := d.connect(ctx) + if err == nil { + break } + logrus.Error("error reconnect: ", err) + time.Sleep(2 * time.Second) } } + return d.connect(ctx) } -func (d *Device) connect() error { - logrus.Infof("connecting to %s:%d ...", d.ip, d.port) +func (d *Device) connect(pCtx context.Context) error { + ctx, cancel := context.WithCancel(pCtx) + d.stop = cancel + + ip := d.Ip() + port := d.Port() - if d.conn != nil { - return fmt.Errorf("already connected to: %s (%s:%d)", d.Name(), d.Ip().String(), d.Port()) + logrus.Infof("connecting to %s:%d ...", ip, port) + + if d.getConn() != nil { + err := d.conn.Close() + if err != nil { + logrus.Error("trying to connect with existing connection. error closing: ", err) + } } - var err error - d.conn, err = tls.Dial("tcp", fmt.Sprintf("%s:%d", d.ip, d.port), &tls.Config{ + conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", ip, port), &tls.Config{ InsecureSkipVerify: true, }) @@ -104,37 +103,50 @@ func (d *Device) connect() error { } d.Lock() + d.conn = conn d.connected = true d.Unlock() - d.Dispatch(events.Connected{}) - d.wrapper = NewPacketStream(d.conn) - go d.reader() + go d.wrapper.readPackets(ctx) + go d.reader(ctx) d.Subscribe("urn:x-cast:com.google.cast.tp.connection", "receiver-0", d.connectionHandler) d.Subscribe("urn:x-cast:com.google.cast.tp.heartbeat", "receiver-0", d.heartbeatHandler) d.Subscribe("urn:x-cast:com.google.cast.receiver", "receiver-0", d.ReceiverHandler) + d.Dispatch(events.Connected{}) + return nil } func (d *Device) Disconnect() { + logrus.Debug("disconnecting: ", d.Name()) + + for _, subscription := range d.getSubscriptionsAsSlice() { + logrus.Debugf("disconnect subscription %s: %s ", d.Name(), subscription.Urn) + subscription.Handler.Disconnect() + } d.Lock() - if d.conn != nil { - for _, subscription := range d.subscriptions { - subscription.Handler.Disconnect() - } + d.subscriptions = make(map[string]*Subscription) + d.Unlock() - d.subscriptions = make(map[string]*Subscription, 0) - d.Dispatch(events.Disconnected{}) + if d.stop != nil { // make sure any old goroutines are stopped + d.stop() + } - d.conn.Close() + if c := d.getConn(); d != nil { + c.Close() + d.Lock() d.conn = nil + d.Unlock() } + d.Lock() d.connected = false d.Unlock() + + d.Dispatch(events.Disconnected{}) } func (d *Device) Send(urn, sourceId, destinationId string, payload responses.Payload) error { @@ -161,7 +173,7 @@ func (d *Device) Send(urn, sourceId, destinationId string, payload responses.Pay } if *message.Namespace != "urn:x-cast:com.google.cast.tp.heartbeat" { - logrus.Debug("Writing:", spew.Sdump(message)) + logrus.Debugf("Writing to %s: %s", d.Name(), spew.Sdump(message)) } if d.conn == nil { diff --git a/discovery/service.go b/discovery/service.go index 5a8e284..9c7cfd3 100644 --- a/discovery/service.go +++ b/discovery/service.go @@ -2,13 +2,16 @@ package discovery import ( + "context" "fmt" "regexp" "strconv" "strings" + "sync/atomic" "time" "github.com/micro/mdns" + "github.com/sirupsen/logrus" "github.com/stampzilla/gocast" ) @@ -16,50 +19,57 @@ type Service struct { found chan *gocast.Device entriesCh chan *mdns.ServiceEntry - foundDevices map[string]*gocast.Device - stopPeriodic chan struct{} + foundDevices map[string]*gocast.Device + periodicRunning uint32 + stop context.CancelFunc } func NewService() *Service { - s := &Service{ + return &Service{ found: make(chan *gocast.Device), entriesCh: make(chan *mdns.ServiceEntry), - foundDevices: make(map[string]*gocast.Device, 0), + foundDevices: make(map[string]*gocast.Device), } - - go s.listner() - - return s } -func (d *Service) Periodic(interval time.Duration) error { - if d.stopPeriodic != nil { +func (d *Service) periodic(ctx context.Context, interval time.Duration) error { + if i := atomic.LoadUint32(&d.periodicRunning); i != 0 { return fmt.Errorf("Periodic discovery is already running") } - mdns.Query(&mdns.QueryParam{ + c, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + err := mdns.Query(&mdns.QueryParam{ Service: "_googlecast._tcp", Domain: "local", Timeout: time.Second * 1, Entries: d.entriesCh, + Context: c, }) + if err != nil { + logrus.Error("error doing mdns query: ", err) + } + ticker := time.NewTicker(interval) - d.stopPeriodic = make(chan struct{}) + atomic.AddUint32(&d.periodicRunning, 1) go func() { for { - mdns.Query(&mdns.QueryParam{ + err := mdns.Query(&mdns.QueryParam{ Service: "_googlecast._tcp", Domain: "local", Timeout: time.Second * 1, Entries: d.entriesCh, + Context: c, }) + if err != nil { + logrus.Error("error doing mdns query: ", err) + } select { case <-ticker.C: - case <-d.stopPeriodic: + case <-ctx.Done(): ticker.Stop() - d.foundDevices = make(map[string]*gocast.Device, 0) - + logrus.Debug("stopping periodic goroutine") return } } @@ -68,10 +78,20 @@ func (d *Service) Periodic(interval time.Duration) error { return nil } +func (d *Service) Start(pCtx context.Context, interval time.Duration) { + ctx, cancel := context.WithCancel(pCtx) + d.stop = cancel + + go d.listner(ctx) + err := d.periodic(ctx, interval) + if err != nil { + logrus.Error("error starting periodic mdns query: ", err) + } +} + func (d *Service) Stop() { - if d.stopPeriodic != nil { - close(d.stopPeriodic) - d.stopPeriodic = nil + if d.stop != nil { + d.stop() } } @@ -79,50 +99,56 @@ func (d *Service) Found() chan *gocast.Device { return d.found } -func (d *Service) listner() { - for entry := range d.entriesCh { - // fmt.Printf("Got new entry: %#v\n", entry) - - name := strings.Split(entry.Name, "._googlecast") - - // Skip everything that dont have googlecast in the fdqn - if len(name) < 2 { - continue - } - - info := decodeTxtRecord(entry.Info) - key := info["id"] // Use device ID as key, allowes the device to change IP +func (d *Service) listner(ctx context.Context) { + for { + select { + case <-ctx.Done(): + logrus.Debug("stopping listner goroutine") + d.foundDevices = make(map[string]*gocast.Device) + return + case entry := <-d.entriesCh: + // fmt.Printf("Got new entry: %#v\n", entry) + + name := strings.Split(entry.Name, "._googlecast") + + // Skip everything that dont have googlecast in the fdqn + if len(name) < 2 { + continue + } - if dev, ok := d.foundDevices[key]; ok { - // If not connected, update address and reconnect - if !dev.Connected() { - dev.SetIp(entry.AddrV4) - dev.SetPort(entry.Port) - dev.Reconnect() + info := decodeTxtRecord(entry.Info) + key := info["id"] // Use device ID as key, allowes the device to change IP + + if dev, ok := d.foundDevices[key]; ok { + // If not connected, update address so we can reconnect to it + if !dev.Connected() { + dev.SetIp(entry.AddrV4) + dev.SetPort(entry.Port) + } + // Skip already connected devices + continue } - // Skip already connected devices - continue - } - device := gocast.NewDevice() - device.SetIp(entry.AddrV4) - device.SetPort(entry.Port) + device := gocast.NewDevice() + device.SetIp(entry.AddrV4) + device.SetPort(entry.Port) - device.SetUuid(key) - device.SetName(info["fn"]) + device.SetUuid(key) + device.SetName(info["fn"]) - d.foundDevices[key] = device + d.foundDevices[key] = device - select { - case d.found <- device: - case <-time.After(time.Second): + select { + case d.found <- device: + case <-time.After(time.Second): + } } } } func decodeDnsEntry(text string) string { - text = strings.Replace(text, `\.`, ".", -1) - text = strings.Replace(text, `\ `, " ", -1) + text = strings.ReplaceAll(text, `\.`, ".") + text = strings.ReplaceAll(text, `\ `, " ") re := regexp.MustCompile(`([\\][0-9][0-9][0-9])`) text = re.ReplaceAllStringFunc(text, func(source string) string { diff --git a/discovery/service_test.go b/discovery/service_test.go index 05d6804..1156b9a 100644 --- a/discovery/service_test.go +++ b/discovery/service_test.go @@ -13,7 +13,6 @@ func TestDecodeDnsEntry(t *testing.T) { assert.Equal(t, result, "Stamp.. Är En Liten Fisk") } -} func TestDecodeTxtRecord(t *testing.T) { source := `id=87cf98a003f1f1dbd2efe6d19055a617|ve=04|md=Chromecast|ic=/setup/icon.png|fn=Chromecast PO|ca=5|st=0|bs=FA8FCA7EE8A9|rs=` diff --git a/example/main.go b/example/main.go index c72d489..95163e5 100644 --- a/example/main.go +++ b/example/main.go @@ -2,22 +2,25 @@ package main import ( + "context" "fmt" "time" + "github.com/sirupsen/logrus" "github.com/stampzilla/gocast/discovery" "github.com/stampzilla/gocast/events" ) func main() { + logrus.SetLevel(logrus.DebugLevel) discovery := discovery.NewService() go discoveryListner(discovery) // Start a periodic discovery fmt.Println("Start discovery") - discovery.Periodic(time.Second * 10) - <-time.After(time.Second * 30) + discovery.Start(context.Background(), time.Second*10) + <-time.After(time.Second * 15) fmt.Println("Stop discovery") discovery.Stop() @@ -33,19 +36,17 @@ func discoveryListner(discovery *discovery.Service) { // device.Subscribe("urn:x-cast:plex", plexHandler) // device.Subscribe("urn:x-cast:com.google.cast.media", mediaHandler) + d := device device.OnEvent(func(event events.Event) { switch data := event.(type) { case events.Connected: - fmt.Println(device.Name(), "- Connected, weeihoo") + fmt.Println(d.Name(), "- Connected, weeihoo") case events.Disconnected: - fmt.Println(device.Name(), "- Disconnected, bah :/") - - // Try to reconnect again - device.Connect() + fmt.Println(d.Name(), "- Disconnected, bah :/") case events.AppStarted: - fmt.Println(device.Name(), "- App started:", data.DisplayName, "(", data.AppID, ")") + fmt.Println(d.Name(), "- App started:", data.DisplayName, "(", data.AppID, ")") case events.AppStopped: - fmt.Println(device.Name(), "- App stopped:", data.DisplayName, "(", data.AppID, ")") + fmt.Println(d.Name(), "- App stopped:", data.DisplayName, "(", data.AppID, ")") // gocast.MediaEvent: // plexEvent: default: @@ -53,7 +54,7 @@ func discoveryListner(discovery *discovery.Service) { } }) - device.Connect() + device.Connect(context.Background()) //go func() { //<-time.After(time.Second * 10) diff --git a/handlers/connection.go b/handlers/connection.go index d3d33f5..30db057 100644 --- a/handlers/connection.go +++ b/handlers/connection.go @@ -14,9 +14,13 @@ func (c *Connection) Connect() { } func (c *Connection) Disconnect() { - c.Send(&responses.Headers{Type: "CLOSE"}) + logrus.Debug("sending disconnect from connection handler") + err := c.Send(&responses.Headers{Type: "CLOSE"}) + if err != nil { + logrus.Error("error sending disconnect: ", err) + } } func (c *Connection) Unmarshal(message string) { - logrus.Info("Connection received: ", message) + logrus.Debug("Connection received: ", message) } diff --git a/handlers/heartbeat.go b/handlers/heartbeat.go index f3e6fc1..7e4aeaa 100644 --- a/handlers/heartbeat.go +++ b/handlers/heartbeat.go @@ -1,6 +1,7 @@ package handlers import ( + "context" "time" "github.com/stampzilla/gocast/responses" @@ -9,30 +10,34 @@ import ( type Heartbeat struct { OnFailure func() baseHandler - ticker *time.Ticker - shutdown chan struct{} receivedAnswer chan struct{} + stop context.CancelFunc +} + +func NewHeartbeat() *Heartbeat { + return &Heartbeat{ + receivedAnswer: make(chan struct{}), + } } func (h *Heartbeat) Connect() { - if h.ticker != nil { - h.ticker.Stop() - if h.shutdown != nil { - close(h.shutdown) - h.shutdown = nil - } + if h.stop != nil { + h.stop() } - h.ticker = time.NewTicker(time.Second * 5) - h.shutdown = make(chan struct{}) - h.receivedAnswer = make(chan struct{}) + //TODO take context from parent + ctx, s := context.WithCancel(context.Background()) + h.stop = s + go func() { + ticker := time.NewTicker(time.Second * 5) + defer ticker.Stop() for { // Send out a ping select { - case <-h.ticker.C: + case <-ticker.C: h.Ping() - case <-h.shutdown: + case <-ctx.Done(): return } @@ -40,7 +45,7 @@ func (h *Heartbeat) Connect() { select { case <-time.After(time.Second * 10): h.OnFailure() - case <-h.shutdown: + case <-ctx.Done(): return case <-h.receivedAnswer: // everything great, carry on @@ -50,19 +55,13 @@ func (h *Heartbeat) Connect() { } func (h *Heartbeat) Disconnect() { - if h.ticker != nil { - h.ticker.Stop() - if h.shutdown != nil { - close(h.shutdown) - h.shutdown = nil - } + if h.stop != nil { + h.stop() } } +// Unmarshal takes the message and notifies our timeout goroutine to check if we get pong or not. func (h *Heartbeat) Unmarshal(message string) { - // fmt.Println("Heartbeat received: ", message) - - // Try to notify our timeout montor select { case h.receivedAnswer <- struct{}{}: case <-time.After(time.Second): diff --git a/packetstream.go b/packetstream.go index 15fbd77..09b882d 100644 --- a/packetstream.go +++ b/packetstream.go @@ -1,6 +1,7 @@ package gocast import ( + "context" "encoding/binary" "fmt" "io" @@ -14,25 +15,25 @@ type packetStream struct { } type packetContainer struct { - payload *[]byte + payload []byte err error } func NewPacketStream(stream io.ReadWriteCloser) *packetStream { - wrapper := packetStream{ + return &packetStream{ stream: stream, packets: make(chan packetContainer), } - wrapper.readPackets() - - return &wrapper } -func (w *packetStream) readPackets() { +func (w *packetStream) readPackets(ctx context.Context) { var length uint32 go func() { for { + if ctx.Err() != nil { + logrus.Errorf("closing packetStream reader %s", ctx.Err()) + } err := binary.Read(w.stream, binary.BigEndian, &length) if err != nil { logrus.Errorf("Failed binary.Read packet: %s", err) @@ -40,8 +41,6 @@ func (w *packetStream) readPackets() { return } - // TODO make sure this goroutine is killed on disconnect - if length > 0 { packet := make([]byte, length) @@ -57,7 +56,7 @@ func (w *packetStream) readPackets() { } w.packets <- packetContainer{ - payload: &packet, + payload: packet, err: nil, } } @@ -65,14 +64,6 @@ func (w *packetStream) readPackets() { }() } -func (w *packetStream) Read() (*[]byte, error) { - pkt := <-w.packets - if pkt.err != nil { - close(w.packets) - } - return pkt.payload, pkt.err -} - func (w *packetStream) Write(data []byte) (int, error) { err := binary.Write(w.stream, binary.BigEndian, uint32(len(data))) if err != nil {