Skip to content

Commit

Permalink
Fix reconnect bug and alot of races
Browse files Browse the repository at this point in the history
  • Loading branch information
Jonas Falck committed Dec 9, 2021
1 parent b3419c6 commit 5d5322f
Show file tree
Hide file tree
Showing 8 changed files with 251 additions and 182 deletions.
49 changes: 43 additions & 6 deletions device.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package gocast

import (
"context"
"net"
"strconv"
"sync"
Expand All @@ -22,6 +23,8 @@ type Device struct {
wrapper *packetStream
reconnect chan struct{}

stop context.CancelFunc

eventListners []func(event events.Event)
subscriptions map[string]*Subscription

Expand All @@ -36,49 +39,83 @@ func NewDevice() *Device {
reconnect: make(chan struct{}),
subscriptions: make(map[string]*Subscription),
connectionHandler: &handlers.Connection{},
heartbeatHandler: &handlers.Heartbeat{},
heartbeatHandler: handlers.NewHeartbeat(),
ReceiverHandler: &handlers.Receiver{},
}

d.heartbeatHandler.OnFailure = func() {
d.Disconnect()
}

return d
}

func (d *Device) SetName(name string) {
d.Lock()
d.name = name
d.Unlock()
}

func (d *Device) SetUuid(uuid string) {
d.Lock()
d.uuid = uuid
d.Unlock()
}

func (d *Device) SetIp(ip net.IP) {
d.Lock()
d.ip = ip
d.Unlock()
}

func (d *Device) SetPort(port int) {
d.Lock()
d.port = port
d.Unlock()
}

func (d *Device) Name() string {
d.RLock()
defer d.RUnlock()
return d.name
}

func (d *Device) Uuid() string {
d.RLock()
defer d.RUnlock()
return d.uuid
}

func (d *Device) Ip() net.IP {
d.RLock()
defer d.RUnlock()
return d.ip
}

func (d *Device) Port() int {
d.RLock()
defer d.RUnlock()
return d.port
}

func (d *Device) Connected() bool {
d.RLock()
defer d.RUnlock()
return d.connected
}
func (d *Device) getConn() net.Conn {
d.RLock()
defer d.RUnlock()
return d.conn
}
func (d *Device) getSubscriptionsAsSlice() []*Subscription {
d.RLock()
subs := make([]*Subscription, len(d.subscriptions))
i := 0
for _, v := range d.subscriptions {
subs[i] = v
i++
}
defer d.RUnlock()
return subs
}

func (d *Device) String() string {
return d.name + " - " + d.ip.String() + ":" + strconv.Itoa(d.port)
}
Expand Down Expand Up @@ -117,10 +154,10 @@ func (d *Device) UnsubscribeByUrn(urn string) {
}
d.RUnlock()
d.Lock()
defer d.Unlock()
for _, sub := range subs {
delete(d.subscriptions, sub)
}
d.Unlock()
}

func (d *Device) UnsubscribeByUrnAndDestinationId(urn, destinationId string) {
Expand Down
154 changes: 83 additions & 71 deletions device_connection.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package gocast

import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
Expand All @@ -14,88 +15,86 @@ import (
"github.com/stampzilla/gocast/responses"
)

func (d *Device) reader() {
func (d *Device) reader(ctx context.Context) {
for {
packet, err := d.wrapper.Read()
if err != nil {
logrus.Errorf("Error reading from chromecast error: %s Packet: %#v", err, packet)
d.Disconnect()
select {
case <-ctx.Done():
if ctx.Err() != nil {
logrus.Errorf("closing reader %s: %s", d.Name(), ctx.Err())
}
return
}
case p := <-d.wrapper.packets:
if p.err != nil {
logrus.Errorf("Error reading from chromecast error: %s Packet: %#v", p.err, p)
return
}
packet := p.payload

message := &api.CastMessage{}
err = proto.Unmarshal(*packet, message)
if err != nil {
logrus.Errorf("Failed to unmarshal CastMessage: %s", err)
continue
}
message := &api.CastMessage{}
err := proto.Unmarshal(packet, message)
if err != nil {
logrus.Errorf("Failed to unmarshal CastMessage: %s", err)
continue
}

headers := &responses.Headers{}
headers := &responses.Headers{}

err = json.Unmarshal([]byte(*message.PayloadUtf8), headers)
err = json.Unmarshal([]byte(*message.PayloadUtf8), headers)

if err != nil {
logrus.Errorf("Failed to unmarshal message: %s", err)
continue
}
if err != nil {
logrus.Errorf("Failed to unmarshal message: %s", err)
continue
}

catched := false
d.RLock()
for _, subscription := range d.subscriptions {
if subscription.Receive(message, headers) {
catched = true
catched := false
for _, subscription := range d.getSubscriptionsAsSlice() {
if subscription.Receive(message, headers) {
catched = true
}
}
}
d.RUnlock()

if !catched {
logrus.Debug("LOST MESSAGE:")
logrus.Debug(spew.Sdump(message))
if !catched {
logrus.Debug("LOST MESSAGE:")
logrus.Debug(spew.Sdump(message))
}
}
}
}

func (d *Device) Connected() bool {
d.RLock()
defer d.RUnlock()
return d.connected
}

func (d *Device) Connect() error {
go d.reconnector()
return d.connect()
}
func (d *Device) Connect(ctx context.Context) error {
d.heartbeatHandler.OnFailure = func() { // make sure we reconnect if we loose heartbeat
logrus.Errorf("heartbeat timeout for: %s trying to reconnect", d.Name())

func (d *Device) Reconnect() {
select {
case d.reconnect <- struct{}{}:
default:
}
}

func (d *Device) reconnector() {
for {
select {
case <-d.reconnect:
logrus.Info("Reconnect signal received")
time.Sleep(time.Second * 2)
err := d.connect()
if err != nil {
logrus.Error(err)
d.Disconnect()
for { // try to connect until no error
err := d.connect(ctx)
if err == nil {
break
}
logrus.Error("error reconnect: ", err)
time.Sleep(2 * time.Second)
}
}
return d.connect(ctx)
}

func (d *Device) connect() error {
logrus.Infof("connecting to %s:%d ...", d.ip, d.port)
func (d *Device) connect(pCtx context.Context) error {
ctx, cancel := context.WithCancel(pCtx)
d.stop = cancel

ip := d.Ip()
port := d.Port()

if d.conn != nil {
return fmt.Errorf("already connected to: %s (%s:%d)", d.Name(), d.Ip().String(), d.Port())
logrus.Infof("connecting to %s:%d ...", ip, port)

if d.getConn() != nil {
err := d.conn.Close()
if err != nil {
logrus.Error("trying to connect with existing connection. error closing: ", err)
}
}

var err error
d.conn, err = tls.Dial("tcp", fmt.Sprintf("%s:%d", d.ip, d.port), &tls.Config{
conn, err := tls.Dial("tcp", fmt.Sprintf("%s:%d", ip, port), &tls.Config{
InsecureSkipVerify: true,
})

Expand All @@ -104,37 +103,50 @@ func (d *Device) connect() error {
}

d.Lock()
d.conn = conn
d.connected = true
d.Unlock()

d.Dispatch(events.Connected{})

d.wrapper = NewPacketStream(d.conn)
go d.reader()
go d.wrapper.readPackets(ctx)
go d.reader(ctx)

d.Subscribe("urn:x-cast:com.google.cast.tp.connection", "receiver-0", d.connectionHandler)
d.Subscribe("urn:x-cast:com.google.cast.tp.heartbeat", "receiver-0", d.heartbeatHandler)
d.Subscribe("urn:x-cast:com.google.cast.receiver", "receiver-0", d.ReceiverHandler)

d.Dispatch(events.Connected{})

return nil
}

func (d *Device) Disconnect() {
logrus.Debug("disconnecting: ", d.Name())

for _, subscription := range d.getSubscriptionsAsSlice() {
logrus.Debugf("disconnect subscription %s: %s ", d.Name(), subscription.Urn)
subscription.Handler.Disconnect()
}
d.Lock()
if d.conn != nil {
for _, subscription := range d.subscriptions {
subscription.Handler.Disconnect()
}
d.subscriptions = make(map[string]*Subscription)
d.Unlock()

d.subscriptions = make(map[string]*Subscription, 0)
d.Dispatch(events.Disconnected{})
if d.stop != nil { // make sure any old goroutines are stopped
d.stop()
}

d.conn.Close()
if c := d.getConn(); d != nil {
c.Close()
d.Lock()
d.conn = nil
d.Unlock()
}

d.Lock()
d.connected = false
d.Unlock()

d.Dispatch(events.Disconnected{})
}

func (d *Device) Send(urn, sourceId, destinationId string, payload responses.Payload) error {
Expand All @@ -161,7 +173,7 @@ func (d *Device) Send(urn, sourceId, destinationId string, payload responses.Pay
}

if *message.Namespace != "urn:x-cast:com.google.cast.tp.heartbeat" {
logrus.Debug("Writing:", spew.Sdump(message))
logrus.Debugf("Writing to %s: %s", d.Name(), spew.Sdump(message))
}

if d.conn == nil {
Expand Down
Loading

0 comments on commit 5d5322f

Please sign in to comment.