diff --git a/go.sum b/go.sum index d1d0e38..93927e1 100644 --- a/go.sum +++ b/go.sum @@ -18,6 +18,7 @@ github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8= github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU= +github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE= github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= github.com/google/wire v0.6.0 h1:HBkoIh4BdSxoyo9PveV8giw7ZsaBOvzWKfcg/6MrVwI= github.com/google/wire v0.6.0/go.mod h1:F4QhpQ9EDIdJ1Mbop/NZBRB+5yrR6qg3BnctaoUk6NA= @@ -113,6 +114,8 @@ golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91 golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.14.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= +golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= @@ -168,6 +171,8 @@ golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= golang.org/x/tools v0.17.0/go.mod h1:xsh6VxdV005rRVaS6SSAf9oiAqljS7UZUacMZ8Bnsps= +golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg= +golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b h1:J1CaxgLerRR5lgx3wnr6L04cJFbWoceSK9JWBdglINo= diff --git a/internal/config/config.go b/internal/config/config.go index 25b811d..7d159e3 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -4,7 +4,15 @@ import ( "errors" "time" + "github.com/google/wire" "github.com/spf13/viper" + "github.com/tjjh89017/stunmesh-go/internal/entity" +) + +var DefaultSet = wire.NewSet( + Load, + NewDeviceConfig, + wire.Bind(new(entity.PeerAllower), new(*DeviceConfig)), ) const Name = "config" @@ -30,16 +38,6 @@ var envs = map[string][]string{ "refresh_interval": {"REFRESH_INTERVAL"}, } -type Peer struct { - Description string `mapstructure:"description"` - PublicKey string `mapstructure:"public_key"` -} - -type Interface struct { - Peers map[string]Peer `mapstructure:"peers"` -} -type Interfaces map[string]Interface - type Logger struct { Level string `mapstructure:"level"` } diff --git a/internal/config/device.go b/internal/config/device.go new file mode 100644 index 0000000..c5467f0 --- /dev/null +++ b/internal/config/device.go @@ -0,0 +1,55 @@ +package config + +import ( + "context" + "encoding/base64" + + "github.com/rs/zerolog" + "github.com/tjjh89017/stunmesh-go/internal/entity" +) + +type Peer struct { + Description string `mapstructure:"description"` + PublicKey string `mapstructure:"public_key"` +} + +type Interface struct { + Peers map[string]Peer `mapstructure:"peers"` +} +type Interfaces map[string]Interface + +var _ entity.PeerAllower = &DeviceConfig{} + +type DeviceConfig struct { + interfaces Interfaces +} + +func NewDeviceConfig(config *Config) *DeviceConfig { + return &DeviceConfig{ + interfaces: config.Interfaces, + } +} + +func (c *DeviceConfig) Allow(ctx context.Context, deviceName string, publicKey []byte, peerId entity.PeerId) bool { + logger := zerolog.Ctx(ctx) + + device, ok := c.interfaces[deviceName] + if !ok { + return false + } + + for _, peer := range device.Peers { + peerPublicKey, err := base64.StdEncoding.DecodeString(peer.PublicKey) + if err != nil { + logger.Error().Err(err).Str("device", deviceName).Str("public_key", peer.PublicKey).Msg("failed to decode public key") + continue + } + + currentPeerId := entity.NewPeerId(publicKey, peerPublicKey) + if peerId == currentPeerId { + return true + } + } + + return false +} diff --git a/internal/ctrl/bootstrap.go b/internal/ctrl/bootstrap.go index 0588499..63f8471 100644 --- a/internal/ctrl/bootstrap.go +++ b/internal/ctrl/bootstrap.go @@ -3,85 +3,68 @@ package ctrl import ( "context" - "encoding/base64" "github.com/rs/zerolog" "github.com/tjjh89017/stunmesh-go/internal/config" "github.com/tjjh89017/stunmesh-go/internal/entity" ) type BootstrapController struct { - wg WireGuardClient - config *config.Config - devices DeviceRepository - peers PeerRepository - logger zerolog.Logger + wg WireGuardClient + config *config.Config + devices DeviceRepository + peers PeerRepository + logger zerolog.Logger + filterService *entity.FilterPeerService } -func NewBootstrapController(wg WireGuardClient, config *config.Config, devices DeviceRepository, peers PeerRepository, logger *zerolog.Logger) *BootstrapController { +func NewBootstrapController(wg WireGuardClient, config *config.Config, devices DeviceRepository, peers PeerRepository, logger *zerolog.Logger, filterService *entity.FilterPeerService) *BootstrapController { return &BootstrapController{ - wg: wg, - config: config, - devices: devices, - peers: peers, - logger: logger.With().Str("controller", "bootstrap").Logger(), + wg: wg, + config: config, + devices: devices, + peers: peers, + logger: logger.With().Str("controller", "bootstrap").Logger(), + filterService: filterService, } } func (ctrl *BootstrapController) Execute(ctx context.Context) { - for deviceName, device := range ctrl.config.Interfaces { - if err := ctrl.registerDevice(ctx, deviceName, device.Peers); err != nil { + for deviceName := range ctrl.config.Interfaces { + if err := ctrl.registerDevice(ctx, deviceName); err != nil { ctrl.logger.Error().Err(err).Str("device", deviceName).Msg("failed to register device") continue } } } -func (ctrl *BootstrapController) registerDevice(ctx context.Context, deviceName string, peers map[string]config.Peer) error { - if len(peers) == 0 { - ctrl.logger.Warn().Str("device", deviceName).Msg("Peers list is empty.") - return nil - } - +func (ctrl *BootstrapController) registerDevice(ctx context.Context, deviceName string) error { device, err := ctrl.wg.Device(deviceName) if err != nil { return err } - peerCount := 0 - for _, p := range device.Peers { - base64PublicKey := base64.StdEncoding.EncodeToString(p.PublicKey[:]) - if name, ok := containsPeer(peers, base64PublicKey); ok { - peerCount += 1 - ctrl.logger.Info().Str("device", deviceName).Str("peer", name).Str("publicKey", base64PublicKey).Msg("Register Peer") - peer := entity.NewPeer( - entity.NewPeerId(device.PublicKey[:], p.PublicKey[:]), - device.Name, - p.PublicKey, - ) + deviceEntity := entity.NewDevice( + entity.DeviceId(device.Name), + device.ListenPort, + device.PrivateKey[:], + ) - ctrl.peers.Save(ctx, peer) - } + allowPeers, err := ctrl.filterService.Execute(ctx, deviceEntity.Name(), device.PublicKey[:]) + if err != nil { + ctrl.logger.Error().Err(err).Str("device", deviceName).Msg("failed to filter allowed peers") + return err } - if peerCount > 0 { - ctrl.logger.Info().Str("device", deviceName).Msg("Register Device") - deviceEntity := entity.NewDevice( - entity.DeviceId(device.Name), - device.ListenPort, - device.PrivateKey[:], - ) + isAnyPeerAllowed := len(allowPeers) > 0 + if !isAnyPeerAllowed { + ctrl.logger.Warn().Str("device", deviceName).Msg("no peer is allowed") + return nil + } - ctrl.devices.Save(ctx, deviceEntity) + ctrl.devices.Save(ctx, deviceEntity) + for _, peer := range allowPeers { + ctrl.peers.Save(ctx, peer) } return nil } - -func containsPeer(m map[string]config.Peer, publicKey string) (string, bool) { - for k, v := range m { - if v.PublicKey == publicKey { - return k, true - } - } - return "", false -} diff --git a/internal/ctrl/bootstrap_test.go b/internal/ctrl/bootstrap_test.go index 2114c69..77506f2 100644 --- a/internal/ctrl/bootstrap_test.go +++ b/internal/ctrl/bootstrap_test.go @@ -2,16 +2,56 @@ package ctrl_test import ( "context" + "errors" "testing" "github.com/rs/zerolog" "github.com/tjjh89017/stunmesh-go/internal/config" "github.com/tjjh89017/stunmesh-go/internal/ctrl" mock "github.com/tjjh89017/stunmesh-go/internal/ctrl/mock" + "github.com/tjjh89017/stunmesh-go/internal/entity" + mockEntity "github.com/tjjh89017/stunmesh-go/internal/entity/mock" gomock "go.uber.org/mock/gomock" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) +func TestBootstrap_WithError(t *testing.T) { + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + mockWgClient := mock.NewMockWireGuardClient(mockCtrl) + mockDevices := mock.NewMockDeviceRepository(mockCtrl) + mockPeers := mock.NewMockPeerRepository(mockCtrl) + logger := zerolog.Nop() + cfg := &config.Config{ + Interfaces: map[string]config.Interface{ + "wg0": { + Peers: map[string]config.Peer{ + "test_peer1": { + PublicKey: "XgPRso34lnrSAx8nJtdj1/zlF7CoNj7B64LPElYdOGs=", + }, + }, + }, + }, + } + deviceConfig := config.NewDeviceConfig(cfg) + mockPeerSearcher := mockEntity.NewMockPeerSearcher(mockCtrl) + peerFilterService := entity.NewFilterPeerService(mockPeerSearcher, deviceConfig) + + mockWgClient.EXPECT().Device("wg0").Return(nil, errors.New("device not found")) + + bootstrap := ctrl.NewBootstrapController( + mockWgClient, + cfg, + mockDevices, + mockPeers, + &logger, + peerFilterService, + ) + + bootstrap.Execute(context.TODO()) +} + func TestBootstrap_WithMultipleInterfaces(t *testing.T) { mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() @@ -20,27 +60,30 @@ func TestBootstrap_WithMultipleInterfaces(t *testing.T) { mockDevices := mock.NewMockDeviceRepository(mockCtrl) mockPeers := mock.NewMockPeerRepository(mockCtrl) logger := zerolog.Nop() - config := &config.Config{ + cfg := &config.Config{ Interfaces: map[string]config.Interface{ - "wg0": config.Interface{ + "wg0": { Peers: map[string]config.Peer{ - "test_peer1": config.Peer{ + "test_peer1": { PublicKey: "XgPRso34lnrSAx8nJtdj1/zlF7CoNj7B64LPElYdOGs=", }, }, }, - "wg1": config.Interface{ + "wg1": { Peers: map[string]config.Peer{ - "test_peer2": config.Peer{ + "test_peer2": { PublicKey: "FQ9/2l8t4xmQQbs6SB03+Lh2VijJX74rxRUOv7YT03k=", }, - "test_peer3": config.Peer{ + "test_peer3": { PublicKey: "Cud5HogJJLCppoUuHnWrSvEJuI49D01sQcfiD3Y9RRU=", }, }, }, }, } + mockPeerSearcher := mockEntity.NewMockPeerSearcher(mockCtrl) + deviceConfig := config.NewDeviceConfig(cfg) + peerFilterService := entity.NewFilterPeerService(mockPeerSearcher, deviceConfig) mockDevice0 := &wgtypes.Device{ Name: "wg0", @@ -72,12 +115,37 @@ func TestBootstrap_WithMultipleInterfaces(t *testing.T) { mockDevices.EXPECT().Save(gomock.Any(), gomock.Any()).Times(2) mockPeers.EXPECT().Save(gomock.Any(), gomock.Any()).Times(3) + mockDevice0Peers := []*entity.Peer{ + entity.NewPeer( + entity.NewPeerId(mockDevice0.PublicKey[:], mockDevice0.Peers[0].PublicKey[:]), + mockDevice0.Name, + mockDevice0.Peers[0].PublicKey, + ), + } + + mockDevice1Peers := []*entity.Peer{ + entity.NewPeer( + entity.NewPeerId(mockDevice1.PublicKey[:], mockDevice1.Peers[0].PublicKey[:]), + mockDevice1.Name, + mockDevice1.Peers[0].PublicKey, + ), + entity.NewPeer( + entity.NewPeerId(mockDevice1.PublicKey[:], mockDevice1.Peers[1].PublicKey[:]), + mockDevice1.Name, + mockDevice1.Peers[1].PublicKey, + ), + } + + mockPeerSearcher.EXPECT().SearchByDevice(gomock.Any(), entity.DeviceId("wg0")).Return(mockDevice0Peers, nil) + mockPeerSearcher.EXPECT().SearchByDevice(gomock.Any(), entity.DeviceId("wg1")).Return(mockDevice1Peers, nil) + bootstrap := ctrl.NewBootstrapController( mockWgClient, - config, + cfg, mockDevices, mockPeers, &logger, + peerFilterService, ) bootstrap.Execute(context.TODO()) diff --git a/internal/entity/entity.go b/internal/entity/entity.go new file mode 100644 index 0000000..12723d6 --- /dev/null +++ b/internal/entity/entity.go @@ -0,0 +1,7 @@ +package entity + +import "github.com/google/wire" + +var DefaultSet = wire.NewSet( + NewFilterPeerService, +) diff --git a/internal/entity/filter_peer.go b/internal/entity/filter_peer.go new file mode 100644 index 0000000..332219c --- /dev/null +++ b/internal/entity/filter_peer.go @@ -0,0 +1,40 @@ +//go:generate mockgen -destination=./mock/mock_peer.go -package=mock_entity . PeerSearcher,PeerAllower +package entity + +import "context" + +type PeerSearcher interface { + SearchByDevice(context.Context, DeviceId) ([]*Peer, error) +} + +type PeerAllower interface { + Allow(ctx context.Context, deviceName string, publicKey []byte, peerId PeerId) bool +} + +type FilterPeerService struct { + searcher PeerSearcher + allower PeerAllower +} + +func NewFilterPeerService(searcher PeerSearcher, allower PeerAllower) *FilterPeerService { + return &FilterPeerService{ + searcher: searcher, + allower: allower, + } +} + +func (svc *FilterPeerService) Execute(ctx context.Context, deviceName DeviceId, publicKey []byte) ([]*Peer, error) { + peers, err := svc.searcher.SearchByDevice(ctx, deviceName) + if err != nil { + return nil, err + } + + allowedPeers := make([]*Peer, 0, len(peers)) + for _, peer := range peers { + if svc.allower.Allow(ctx, string(deviceName), publicKey, peer.Id()) { + allowedPeers = append(allowedPeers, peer) + } + } + + return allowedPeers, nil +} diff --git a/internal/entity/mock/mock_peer.go b/internal/entity/mock/mock_peer.go new file mode 100644 index 0000000..91990c3 --- /dev/null +++ b/internal/entity/mock/mock_peer.go @@ -0,0 +1,93 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/tjjh89017/stunmesh-go/internal/entity (interfaces: PeerSearcher,PeerAllower) +// +// Generated by this command: +// +// mockgen -destination=./mock/mock_peer.go -package=mock_entity . PeerSearcher,PeerAllower +// + +// Package mock_entity is a generated GoMock package. +package mock_entity + +import ( + context "context" + reflect "reflect" + + entity "github.com/tjjh89017/stunmesh-go/internal/entity" + gomock "go.uber.org/mock/gomock" +) + +// MockPeerSearcher is a mock of PeerSearcher interface. +type MockPeerSearcher struct { + ctrl *gomock.Controller + recorder *MockPeerSearcherMockRecorder +} + +// MockPeerSearcherMockRecorder is the mock recorder for MockPeerSearcher. +type MockPeerSearcherMockRecorder struct { + mock *MockPeerSearcher +} + +// NewMockPeerSearcher creates a new mock instance. +func NewMockPeerSearcher(ctrl *gomock.Controller) *MockPeerSearcher { + mock := &MockPeerSearcher{ctrl: ctrl} + mock.recorder = &MockPeerSearcherMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockPeerSearcher) EXPECT() *MockPeerSearcherMockRecorder { + return m.recorder +} + +// SearchByDevice mocks base method. +func (m *MockPeerSearcher) SearchByDevice(arg0 context.Context, arg1 entity.DeviceId) ([]*entity.Peer, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SearchByDevice", arg0, arg1) + ret0, _ := ret[0].([]*entity.Peer) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// SearchByDevice indicates an expected call of SearchByDevice. +func (mr *MockPeerSearcherMockRecorder) SearchByDevice(arg0, arg1 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SearchByDevice", reflect.TypeOf((*MockPeerSearcher)(nil).SearchByDevice), arg0, arg1) +} + +// MockPeerAllower is a mock of PeerAllower interface. +type MockPeerAllower struct { + ctrl *gomock.Controller + recorder *MockPeerAllowerMockRecorder +} + +// MockPeerAllowerMockRecorder is the mock recorder for MockPeerAllower. +type MockPeerAllowerMockRecorder struct { + mock *MockPeerAllower +} + +// NewMockPeerAllower creates a new mock instance. +func NewMockPeerAllower(ctrl *gomock.Controller) *MockPeerAllower { + mock := &MockPeerAllower{ctrl: ctrl} + mock.recorder = &MockPeerAllowerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockPeerAllower) EXPECT() *MockPeerAllowerMockRecorder { + return m.recorder +} + +// Allow mocks base method. +func (m *MockPeerAllower) Allow(arg0 context.Context, arg1 string, arg2 []byte, arg3 entity.PeerId) bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Allow", arg0, arg1, arg2, arg3) + ret0, _ := ret[0].(bool) + return ret0 +} + +// Allow indicates an expected call of Allow. +func (mr *MockPeerAllowerMockRecorder) Allow(arg0, arg1, arg2, arg3 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Allow", reflect.TypeOf((*MockPeerAllower)(nil).Allow), arg0, arg1, arg2, arg3) +} diff --git a/internal/repo/api.go b/internal/repo/api.go new file mode 100644 index 0000000..20b9fde --- /dev/null +++ b/internal/repo/api.go @@ -0,0 +1,11 @@ +//go:generate mockgen -destination=./mock/mock_api.go -package=mock_repo . WireGuardClient + +package repo + +import ( + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +type WireGuardClient interface { + Device(deviceName string) (*wgtypes.Device, error) +} diff --git a/internal/repo/mock/mock_api.go b/internal/repo/mock/mock_api.go new file mode 100644 index 0000000..46db916 --- /dev/null +++ b/internal/repo/mock/mock_api.go @@ -0,0 +1,55 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/tjjh89017/stunmesh-go/internal/repo (interfaces: WireGuardClient) +// +// Generated by this command: +// +// mockgen -destination=./mock/mock_api.go -package=mock_repo . WireGuardClient +// + +// Package mock_repo is a generated GoMock package. +package mock_repo + +import ( + reflect "reflect" + + gomock "go.uber.org/mock/gomock" + wgtypes "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +// MockWireGuardClient is a mock of WireGuardClient interface. +type MockWireGuardClient struct { + ctrl *gomock.Controller + recorder *MockWireGuardClientMockRecorder +} + +// MockWireGuardClientMockRecorder is the mock recorder for MockWireGuardClient. +type MockWireGuardClientMockRecorder struct { + mock *MockWireGuardClient +} + +// NewMockWireGuardClient creates a new mock instance. +func NewMockWireGuardClient(ctrl *gomock.Controller) *MockWireGuardClient { + mock := &MockWireGuardClient{ctrl: ctrl} + mock.recorder = &MockWireGuardClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockWireGuardClient) EXPECT() *MockWireGuardClientMockRecorder { + return m.recorder +} + +// Device mocks base method. +func (m *MockWireGuardClient) Device(arg0 string) (*wgtypes.Device, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Device", arg0) + ret0, _ := ret[0].(*wgtypes.Device) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Device indicates an expected call of Device. +func (mr *MockWireGuardClientMockRecorder) Device(arg0 any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Device", reflect.TypeOf((*MockWireGuardClient)(nil).Device), arg0) +} diff --git a/internal/repo/peers.go b/internal/repo/peers.go index 4f0aeb1..a8a0bce 100644 --- a/internal/repo/peers.go +++ b/internal/repo/peers.go @@ -9,14 +9,17 @@ import ( ) var _ ctrl.PeerRepository = &Peers{} +var _ entity.PeerSearcher = &Peers{} type Peers struct { + wgCtrl WireGuardClient mutex sync.RWMutex entities map[entity.PeerId]*entity.Peer } -func NewPeers() *Peers { +func NewPeers(wgCtrl WireGuardClient) *Peers { return &Peers{ + wgCtrl: wgCtrl, entities: make(map[entity.PeerId]*entity.Peer), } } @@ -47,6 +50,25 @@ func (r *Peers) ListByDevice(ctx context.Context, deviceName entity.DeviceId) ([ return peers, nil } +// NOTE: will replace the above ListByDevice +func (r *Peers) SearchByDevice(ctx context.Context, deviceName entity.DeviceId) ([]*entity.Peer, error) { + device, err := r.wgCtrl.Device(string(deviceName)) + if err != nil { + return nil, err + } + + peers := make([]*entity.Peer, len(device.Peers)) + for i, peer := range device.Peers { + peers[i] = entity.NewPeer( + entity.NewPeerId(device.PublicKey[:], peer.PublicKey[:]), + device.Name, + peer.PublicKey, + ) + } + + return peers, nil +} + func (r *Peers) Find(ctx context.Context, id entity.PeerId) (*entity.Peer, error) { r.mutex.RLock() defer r.mutex.RUnlock() diff --git a/internal/repo/peers_test.go b/internal/repo/peers_test.go index 4b2d995..8af9af6 100644 --- a/internal/repo/peers_test.go +++ b/internal/repo/peers_test.go @@ -6,9 +6,16 @@ import ( "github.com/tjjh89017/stunmesh-go/internal/entity" "github.com/tjjh89017/stunmesh-go/internal/repo" + mock "github.com/tjjh89017/stunmesh-go/internal/repo/mock" + "go.uber.org/mock/gomock" ) func Test_PeerRepository_Find(t *testing.T) { + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + mockWgClient := mock.NewMockWireGuardClient(mockCtrl) + peerId := entity.NewPeerId([]byte{}, []byte{}) peer := entity.NewPeer( @@ -17,7 +24,7 @@ func Test_PeerRepository_Find(t *testing.T) { [32]byte{}, ) - peers := repo.NewPeers() + peers := repo.NewPeers(mockWgClient) peers.Save(context.TODO(), peer) tests := []struct { @@ -51,6 +58,11 @@ func Test_PeerRepository_Find(t *testing.T) { } func Test_PeerRepository_List(t *testing.T) { + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + mockWgClient := mock.NewMockWireGuardClient(mockCtrl) + tests := []struct { name string peers []*entity.Peer @@ -90,7 +102,7 @@ func Test_PeerRepository_List(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - peers := repo.NewPeers() + peers := repo.NewPeers(mockWgClient) for _, peer := range tt.peers { peers.Save(context.TODO(), peer) } @@ -112,6 +124,11 @@ func Test_PeerRepository_List(t *testing.T) { } func Test_PeerListByDevice(t *testing.T) { + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + mockWgClient := mock.NewMockWireGuardClient(mockCtrl) + tests := []struct { name string deviceName entity.DeviceId @@ -159,7 +176,7 @@ func Test_PeerListByDevice(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Parallel() - peers := repo.NewPeers() + peers := repo.NewPeers(mockWgClient) for _, peer := range tt.peers { peers.Save(context.TODO(), peer) } diff --git a/internal/repo/repo.go b/internal/repo/repo.go index 2203a59..27d42a2 100644 --- a/internal/repo/repo.go +++ b/internal/repo/repo.go @@ -3,11 +3,13 @@ package repo import ( "github.com/google/wire" "github.com/tjjh89017/stunmesh-go/internal/ctrl" + "github.com/tjjh89017/stunmesh-go/internal/entity" ) var DefaultSet = wire.NewSet( NewPeers, wire.Bind(new(ctrl.PeerRepository), new(*Peers)), + wire.Bind(new(entity.PeerSearcher), new(*Peers)), NewDevices, wire.Bind(new(ctrl.DeviceRepository), new(*Devices)), ) diff --git a/wire.go b/wire.go index ddd857b..4e76675 100644 --- a/wire.go +++ b/wire.go @@ -22,7 +22,6 @@ import ( func setup() (*daemon.Daemon, error) { wire.Build( - config.Load, wgctrl.New, wire.Bind(new(ctrl.WireGuardClient), new(*wgctrl.Client)), provideCloudflareApi, @@ -30,11 +29,13 @@ func setup() (*daemon.Daemon, error) { wire.Bind(new(plugin.Store), new(*store.CloudflareStore)), provideRefreshQueue, wire.Bind(new(ctrl.RefreshQueue), new(*queue.Queue[entity.PeerId])), + config.DefaultSet, logger.DefaultSet, repo.DefaultSet, stun.DefaultSet, crypto.DefaultSet, ctrl.DefaultSet, + entity.DefaultSet, daemon.New, ) diff --git a/wire_gen.go b/wire_gen.go index c6c971f..63b1f8e 100644 --- a/wire_gen.go +++ b/wire_gen.go @@ -34,9 +34,11 @@ func setup() (*daemon.Daemon, error) { return nil, err } devices := repo.NewDevices() - peers := repo.NewPeers() + peers := repo.NewPeers(client) zerologLogger := logger.NewLogger(configConfig) - bootstrapController := ctrl.NewBootstrapController(client, configConfig, devices, peers, zerologLogger) + deviceConfig := config.NewDeviceConfig(configConfig) + filterPeerService := entity.NewFilterPeerService(peers, deviceConfig) + bootstrapController := ctrl.NewBootstrapController(client, configConfig, devices, peers, zerologLogger, filterPeerService) api, err := provideCloudflareApi(configConfig) if err != nil { return nil, err