From 1c72a57f420a7c95a6fa669dccb01ec219dcbcf9 Mon Sep 17 00:00:00 2001 From: Aotokitsuruya Date: Wed, 11 Sep 2024 20:58:36 +0800 Subject: [PATCH] Refactor to make publish by device --- internal/ctrl/bootstrap.go | 2 +- internal/ctrl/publish.go | 53 +++++++------- internal/ctrl/repository.go | 2 + internal/daemon/daemon.go | 2 +- internal/entity/device.go | 8 ++- internal/entity/peer.go | 8 +-- internal/repo/devices.go | 12 ++++ internal/repo/devices_test.go | 51 ++++++++++++- internal/repo/peers.go | 14 ++++ internal/repo/peers_test.go | 130 +++++++++++++++++++++++++++++++++- 10 files changed, 246 insertions(+), 36 deletions(-) diff --git a/internal/ctrl/bootstrap.go b/internal/ctrl/bootstrap.go index ffa6e52..89e005b 100644 --- a/internal/ctrl/bootstrap.go +++ b/internal/ctrl/bootstrap.go @@ -34,6 +34,7 @@ func (ctrl *BootstrapController) Execute(ctx context.Context) { deviceEntity := entity.NewDevice( entity.DeviceId(device.Name), + device.ListenPort, device.PrivateKey[:], ) @@ -43,7 +44,6 @@ func (ctrl *BootstrapController) Execute(ctx context.Context) { peer := entity.NewPeer( entity.NewPeerId(device.PublicKey[:], p.PublicKey[:]), device.Name, - device.ListenPort, p.PublicKey, ) diff --git a/internal/ctrl/publish.go b/internal/ctrl/publish.go index 56be24e..9978af5 100644 --- a/internal/ctrl/publish.go +++ b/internal/ctrl/publish.go @@ -4,7 +4,6 @@ import ( "context" "log" - "github.com/tjjh89017/stunmesh-go/internal/entity" "github.com/tjjh89017/stunmesh-go/plugin" ) @@ -26,36 +25,42 @@ func NewPublishController(devices DeviceRepository, peers PeerRepository, store } } -func (c *PublishController) Execute(ctx context.Context, peerId entity.PeerId) { - peer, err := c.peers.Find(ctx, peerId) +func (c *PublishController) Execute(ctx context.Context) { + devices, err := c.devices.List(ctx) if err != nil { log.Print(err) return } - device, err := c.devices.Find(ctx, entity.DeviceId(peer.DeviceName())) - if err != nil { - log.Print(err) - return - } + for _, device := range devices { + host, port, err := c.resolver.Resolve(ctx, uint16(device.ListenPort())) + if err != nil { + log.Panic(err) + } - host, port, err := c.resolver.Resolve(ctx, uint16(peer.ListenPort())) - if err != nil { - log.Panic(err) - } + peers, err := c.peers.ListByDevice(ctx, device.Name()) + if err != nil { + log.Print(err) + continue + } - res, err := c.encryptor.Encrypt(ctx, &EndpointEncryptRequest{ - PeerPublicKey: peer.PublicKey(), - PrivateKey: device.PrivateKey(), - Host: host, - Port: port, - }) - if err != nil { - log.Panic(err) - } + for _, peer := range peers { + res, err := c.encryptor.Encrypt(ctx, &EndpointEncryptRequest{ + PeerPublicKey: peer.PublicKey(), + PrivateKey: device.PrivateKey(), + Host: host, + Port: port, + }) + if err != nil { + log.Print(err) + continue + } - err = c.store.Set(ctx, peer.LocalId(), res.Data) - if err != nil { - log.Panic(err) + err = c.store.Set(ctx, peer.LocalId(), res.Data) + if err != nil { + log.Print(err) + continue + } + } } } diff --git a/internal/ctrl/repository.go b/internal/ctrl/repository.go index 38b7ed1..3819a13 100644 --- a/internal/ctrl/repository.go +++ b/internal/ctrl/repository.go @@ -7,12 +7,14 @@ import ( ) type DeviceRepository interface { + List(ctx context.Context) ([]*entity.Device, error) Find(ctx context.Context, name entity.DeviceId) (*entity.Device, error) Save(ctx context.Context, device *entity.Device) } type PeerRepository interface { List(ctx context.Context) ([]*entity.Peer, error) + ListByDevice(ctx context.Context, deviceName entity.DeviceId) ([]*entity.Peer, error) Find(ctx context.Context, id entity.PeerId) (*entity.Peer, error) Save(ctx context.Context, peer *entity.Peer) } diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go index 7c7e042..cd41008 100644 --- a/internal/daemon/daemon.go +++ b/internal/daemon/daemon.go @@ -62,11 +62,11 @@ func (d *Daemon) Run(ctx context.Context) { case peerId := <-d.queue.Dequeue(): log.Printf("Processing peer %s", peerId) - go d.publishCtrl.Execute(daemonCtx, peerId) go d.establishCtrl.Execute(daemonCtx, peerId) case <-ticker.C: log.Println("Refreshing peers") + go d.publishCtrl.Execute(daemonCtx) go d.refreshCtrl.Execute(daemonCtx) } } diff --git a/internal/entity/device.go b/internal/entity/device.go index b188a84..944ea25 100644 --- a/internal/entity/device.go +++ b/internal/entity/device.go @@ -10,12 +10,14 @@ type DeviceId string type Device struct { name DeviceId + listenPort int privateKey []byte } -func NewDevice(name DeviceId, privateKey []byte) *Device { +func NewDevice(name DeviceId, listenPort int, privateKey []byte) *Device { return &Device{ name: name, + listenPort: listenPort, privateKey: privateKey, } } @@ -29,3 +31,7 @@ func (d *Device) PrivateKey() [32]byte { copy(key[:], d.privateKey) return key } + +func (d *Device) ListenPort() int { + return d.listenPort +} diff --git a/internal/entity/peer.go b/internal/entity/peer.go index 110bb1b..050eb65 100644 --- a/internal/entity/peer.go +++ b/internal/entity/peer.go @@ -10,14 +10,12 @@ type Peer struct { id PeerId deviceName string publicKey [32]byte - listenPort int } -func NewPeer(id PeerId, deviceName string, listenPort int, publicKey [32]byte) *Peer { +func NewPeer(id PeerId, deviceName string, publicKey [32]byte) *Peer { return &Peer{ id: id, deviceName: deviceName, - listenPort: listenPort, publicKey: publicKey, } } @@ -41,7 +39,3 @@ func (p *Peer) DeviceName() string { func (p *Peer) PublicKey() [32]byte { return p.publicKey } - -func (p *Peer) ListenPort() int { - return p.listenPort -} diff --git a/internal/repo/devices.go b/internal/repo/devices.go index 31840ba..b05c043 100644 --- a/internal/repo/devices.go +++ b/internal/repo/devices.go @@ -18,6 +18,18 @@ func NewDevices() *Devices { } } +func (r *Devices) List(ctx context.Context) ([]*entity.Device, error) { + r.mutex.RLock() + defer r.mutex.RUnlock() + + devices := make([]*entity.Device, 0, len(r.items)) + for _, device := range r.items { + devices = append(devices, device) + } + + return devices, nil +} + func (r *Devices) Find(ctx context.Context, name entity.DeviceId) (*entity.Device, error) { r.mutex.RLock() defer r.mutex.RUnlock() diff --git a/internal/repo/devices_test.go b/internal/repo/devices_test.go index d0ace0b..c0029eb 100644 --- a/internal/repo/devices_test.go +++ b/internal/repo/devices_test.go @@ -10,7 +10,7 @@ import ( func Test_DeviceFind(t *testing.T) { deviceName := entity.DeviceId("wg0") - device := entity.NewDevice(deviceName, []byte{}) + device := entity.NewDevice(deviceName, 6379, []byte{}) devices := repo.NewDevices() devices.Save(context.TODO(), device) @@ -44,3 +44,52 @@ func Test_DeviceFind(t *testing.T) { }) } } + +func Test_DeviceList(t *testing.T) { + tests := []struct { + name string + devices []*entity.Device + }{ + { + name: "no devices", + devices: []*entity.Device{}, + }, + { + name: "single device", + devices: []*entity.Device{ + entity.NewDevice(entity.DeviceId("wg0"), 6379, []byte{}), + }, + }, + { + name: "multiple devices", + devices: []*entity.Device{ + entity.NewDevice(entity.DeviceId("wg0"), 6379, []byte{}), + entity.NewDevice(entity.DeviceId("wg1"), 6380, []byte{}), + entity.NewDevice(entity.DeviceId("wg2"), 6381, []byte{}), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + devices := repo.NewDevices() + for _, device := range tt.devices { + devices.Save(context.TODO(), device) + } + + entities, err := devices.List(context.TODO()) + if err != nil { + t.Errorf("DeviceRepository.List() error = %v", err) + return + } + + expectedSize := len(tt.devices) + actualSize := len(entities) + if actualSize != expectedSize { + t.Errorf("DeviceRepository.List() size = %v, want %v", actualSize, expectedSize) + } + }) + } +} diff --git a/internal/repo/peers.go b/internal/repo/peers.go index 9cb6def..4f0aeb1 100644 --- a/internal/repo/peers.go +++ b/internal/repo/peers.go @@ -33,6 +33,20 @@ func (r *Peers) List(ctx context.Context) ([]*entity.Peer, error) { return peers, nil } +func (r *Peers) ListByDevice(ctx context.Context, deviceName entity.DeviceId) ([]*entity.Peer, error) { + r.mutex.RLock() + defer r.mutex.RUnlock() + + peers := make([]*entity.Peer, 0) + for _, peer := range r.entities { + if peer.DeviceName() == string(deviceName) { + peers = append(peers, peer) + } + } + + return peers, nil +} + func (r *Peers) Find(ctx context.Context, id entity.PeerId) (*entity.Peer, error) { r.mutex.RLock() defer r.mutex.RUnlock() diff --git a/internal/repo/peers_test.go b/internal/repo/peers_test.go index c618bdb..4b2d995 100644 --- a/internal/repo/peers_test.go +++ b/internal/repo/peers_test.go @@ -14,7 +14,6 @@ func Test_PeerRepository_Find(t *testing.T) { peer := entity.NewPeer( peerId, "wg0", - 8080, [32]byte{}, ) @@ -50,3 +49,132 @@ func Test_PeerRepository_Find(t *testing.T) { }) } } + +func Test_PeerRepository_List(t *testing.T) { + tests := []struct { + name string + peers []*entity.Peer + }{ + { + name: "no peers", + peers: []*entity.Peer{}, + }, + { + name: "one peer", + peers: []*entity.Peer{ + entity.NewPeer( + entity.NewPeerId([]byte{}, []byte{}), + "wg0", + [32]byte{}, + ), + }, + }, + { + name: "two peers", + peers: []*entity.Peer{ + entity.NewPeer( + entity.NewPeerId([]byte{}, []byte{}), + "wg0", + [32]byte{}, + ), + entity.NewPeer( + entity.NewPeerId([]byte{1}, []byte{1}), + "wg1", + [32]byte{}, + ), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + peers := repo.NewPeers() + for _, peer := range tt.peers { + peers.Save(context.TODO(), peer) + } + + entities, err := peers.List(context.TODO()) + if err != nil { + t.Errorf("PeerRepository.List() error = %v", err) + return + } + + expectedSize := len(tt.peers) + actualSize := len(entities) + if actualSize != expectedSize { + t.Errorf("PeerRepository.List() = %v, want %v", actualSize, expectedSize) + return + } + }) + } +} + +func Test_PeerListByDevice(t *testing.T) { + tests := []struct { + name string + deviceName entity.DeviceId + peers []*entity.Peer + expected int + }{ + { + name: "no peers", + deviceName: "wg0", + peers: []*entity.Peer{}, + expected: 0, + }, + { + name: "one peer", + deviceName: "wg0", + peers: []*entity.Peer{ + entity.NewPeer( + entity.NewPeerId([]byte{}, []byte{}), + "wg0", + [32]byte{}, + ), + }, + expected: 1, + }, + { + name: "two peers with one matching device", + deviceName: "wg0", + peers: []*entity.Peer{ + entity.NewPeer( + entity.NewPeerId([]byte{}, []byte{}), + "wg0", + [32]byte{}, + ), + entity.NewPeer( + entity.NewPeerId([]byte{1}, []byte{1}), + "wg1", + [32]byte{}, + ), + }, + expected: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + peers := repo.NewPeers() + for _, peer := range tt.peers { + peers.Save(context.TODO(), peer) + } + + entities, err := peers.ListByDevice(context.TODO(), tt.deviceName) + if err != nil { + t.Errorf("PeerRepository.ListByDevice() error = %v", err) + return + } + + actualSize := len(entities) + if actualSize != tt.expected { + t.Errorf("PeerRepository.ListByDevice() = %v, want %v", actualSize, tt.expected) + return + } + }) + } +}