Skip to content

Commit

Permalink
Refactor bootstrap controller to move filter behavior to entity
Browse files Browse the repository at this point in the history
  • Loading branch information
elct9620 committed Sep 30, 2024
1 parent 38cf66d commit 8b709be
Show file tree
Hide file tree
Showing 15 changed files with 433 additions and 74 deletions.
5 changes: 5 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8=
github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU=
github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE=
github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk=
github.com/google/wire v0.6.0 h1:HBkoIh4BdSxoyo9PveV8giw7ZsaBOvzWKfcg/6MrVwI=
github.com/google/wire v0.6.0/go.mod h1:F4QhpQ9EDIdJ1Mbop/NZBRB+5yrR6qg3BnctaoUk6NA=
Expand Down Expand Up @@ -113,6 +114,8 @@ golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.14.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA=
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
Expand Down Expand Up @@ -168,6 +171,8 @@ golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
golang.org/x/tools v0.17.0/go.mod h1:xsh6VxdV005rRVaS6SSAf9oiAqljS7UZUacMZ8Bnsps=
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg=
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.zx2c4.com/wireguard v0.0.0-20230325221338-052af4a8072b h1:J1CaxgLerRR5lgx3wnr6L04cJFbWoceSK9JWBdglINo=
Expand Down
18 changes: 8 additions & 10 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,15 @@ import (
"errors"
"time"

"github.com/google/wire"
"github.com/spf13/viper"
"github.com/tjjh89017/stunmesh-go/internal/entity"
)

var DefaultSet = wire.NewSet(
Load,
NewDeviceConfig,
wire.Bind(new(entity.PeerAllower), new(*DeviceConfig)),
)

const Name = "config"
Expand All @@ -30,16 +38,6 @@ var envs = map[string][]string{
"refresh_interval": {"REFRESH_INTERVAL"},
}

type Peer struct {
Description string `mapstructure:"description"`
PublicKey string `mapstructure:"public_key"`
}

type Interface struct {
Peers map[string]Peer `mapstructure:"peers"`
}
type Interfaces map[string]Interface

type Logger struct {
Level string `mapstructure:"level"`
}
Expand Down
55 changes: 55 additions & 0 deletions internal/config/device.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package config

import (
"context"
"encoding/base64"

"github.com/rs/zerolog"
"github.com/tjjh89017/stunmesh-go/internal/entity"
)

type Peer struct {
Description string `mapstructure:"description"`
PublicKey string `mapstructure:"public_key"`
}

type Interface struct {
Peers map[string]Peer `mapstructure:"peers"`
}
type Interfaces map[string]Interface

var _ entity.PeerAllower = &DeviceConfig{}

type DeviceConfig struct {
interfaces Interfaces
}

func NewDeviceConfig(config *Config) *DeviceConfig {
return &DeviceConfig{
interfaces: config.Interfaces,
}
}

func (c *DeviceConfig) Allow(ctx context.Context, deviceName string, publicKey []byte, peerId entity.PeerId) bool {
logger := zerolog.Ctx(ctx)

device, ok := c.interfaces[deviceName]
if !ok {
return false
}

for _, peer := range device.Peers {
peerPublicKey, err := base64.StdEncoding.DecodeString(peer.PublicKey)
if err != nil {
logger.Error().Err(err).Str("device", deviceName).Str("public_key", peer.PublicKey).Msg("failed to decode public key")
continue
}

currentPeerId := entity.NewPeerId(publicKey, peerPublicKey)
if peerId == currentPeerId {
return true
}
}

return false
}
83 changes: 33 additions & 50 deletions internal/ctrl/bootstrap.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,85 +3,68 @@ package ctrl
import (
"context"

"encoding/base64"
"github.com/rs/zerolog"
"github.com/tjjh89017/stunmesh-go/internal/config"
"github.com/tjjh89017/stunmesh-go/internal/entity"
)

type BootstrapController struct {
wg WireGuardClient
config *config.Config
devices DeviceRepository
peers PeerRepository
logger zerolog.Logger
wg WireGuardClient
config *config.Config
devices DeviceRepository
peers PeerRepository
logger zerolog.Logger
filterService *entity.FilterPeerService
}

func NewBootstrapController(wg WireGuardClient, config *config.Config, devices DeviceRepository, peers PeerRepository, logger *zerolog.Logger) *BootstrapController {
func NewBootstrapController(wg WireGuardClient, config *config.Config, devices DeviceRepository, peers PeerRepository, logger *zerolog.Logger, filterService *entity.FilterPeerService) *BootstrapController {
return &BootstrapController{
wg: wg,
config: config,
devices: devices,
peers: peers,
logger: logger.With().Str("controller", "bootstrap").Logger(),
wg: wg,
config: config,
devices: devices,
peers: peers,
logger: logger.With().Str("controller", "bootstrap").Logger(),
filterService: filterService,
}
}

func (ctrl *BootstrapController) Execute(ctx context.Context) {
for deviceName, device := range ctrl.config.Interfaces {
if err := ctrl.registerDevice(ctx, deviceName, device.Peers); err != nil {
for deviceName := range ctrl.config.Interfaces {
if err := ctrl.registerDevice(ctx, deviceName); err != nil {
ctrl.logger.Error().Err(err).Str("device", deviceName).Msg("failed to register device")
continue
}
}
}

func (ctrl *BootstrapController) registerDevice(ctx context.Context, deviceName string, peers map[string]config.Peer) error {
if len(peers) == 0 {
ctrl.logger.Warn().Str("device", deviceName).Msg("Peers list is empty.")
return nil
}

func (ctrl *BootstrapController) registerDevice(ctx context.Context, deviceName string) error {
device, err := ctrl.wg.Device(deviceName)
if err != nil {
return err
}

peerCount := 0
for _, p := range device.Peers {
base64PublicKey := base64.StdEncoding.EncodeToString(p.PublicKey[:])
if name, ok := containsPeer(peers, base64PublicKey); ok {
peerCount += 1
ctrl.logger.Info().Str("device", deviceName).Str("peer", name).Str("publicKey", base64PublicKey).Msg("Register Peer")
peer := entity.NewPeer(
entity.NewPeerId(device.PublicKey[:], p.PublicKey[:]),
device.Name,
p.PublicKey,
)
deviceEntity := entity.NewDevice(
entity.DeviceId(device.Name),
device.ListenPort,
device.PrivateKey[:],
)

ctrl.peers.Save(ctx, peer)
}
allowPeers, err := ctrl.filterService.Execute(ctx, deviceEntity.Name(), device.PublicKey[:])
if err != nil {
ctrl.logger.Error().Err(err).Str("device", deviceName).Msg("failed to filter allowed peers")
return err
}

if peerCount > 0 {
ctrl.logger.Info().Str("device", deviceName).Msg("Register Device")
deviceEntity := entity.NewDevice(
entity.DeviceId(device.Name),
device.ListenPort,
device.PrivateKey[:],
)
isAnyPeerAllowed := len(allowPeers) > 0
if !isAnyPeerAllowed {
ctrl.logger.Warn().Str("device", deviceName).Msg("no peer is allowed")
return nil
}

ctrl.devices.Save(ctx, deviceEntity)
ctrl.devices.Save(ctx, deviceEntity)
for _, peer := range allowPeers {
ctrl.peers.Save(ctx, peer)
}

return nil
}

func containsPeer(m map[string]config.Peer, publicKey string) (string, bool) {
for k, v := range m {
if v.PublicKey == publicKey {
return k, true
}
}
return "", false
}
82 changes: 75 additions & 7 deletions internal/ctrl/bootstrap_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,56 @@ package ctrl_test

import (
"context"
"errors"
"testing"

"github.com/rs/zerolog"
"github.com/tjjh89017/stunmesh-go/internal/config"
"github.com/tjjh89017/stunmesh-go/internal/ctrl"
mock "github.com/tjjh89017/stunmesh-go/internal/ctrl/mock"
"github.com/tjjh89017/stunmesh-go/internal/entity"
mockEntity "github.com/tjjh89017/stunmesh-go/internal/entity/mock"
gomock "go.uber.org/mock/gomock"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)

func TestBootstrap_WithError(t *testing.T) {
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()

mockWgClient := mock.NewMockWireGuardClient(mockCtrl)
mockDevices := mock.NewMockDeviceRepository(mockCtrl)
mockPeers := mock.NewMockPeerRepository(mockCtrl)
logger := zerolog.Nop()
cfg := &config.Config{
Interfaces: map[string]config.Interface{
"wg0": {
Peers: map[string]config.Peer{
"test_peer1": {
PublicKey: "XgPRso34lnrSAx8nJtdj1/zlF7CoNj7B64LPElYdOGs=",
},
},
},
},
}
deviceConfig := config.NewDeviceConfig(cfg)
mockPeerSearcher := mockEntity.NewMockPeerSearcher(mockCtrl)
peerFilterService := entity.NewFilterPeerService(mockPeerSearcher, deviceConfig)

mockWgClient.EXPECT().Device("wg0").Return(nil, errors.New("device not found"))

bootstrap := ctrl.NewBootstrapController(
mockWgClient,
cfg,
mockDevices,
mockPeers,
&logger,
peerFilterService,
)

bootstrap.Execute(context.TODO())
}

func TestBootstrap_WithMultipleInterfaces(t *testing.T) {
mockCtrl := gomock.NewController(t)
defer mockCtrl.Finish()
Expand All @@ -20,27 +60,30 @@ func TestBootstrap_WithMultipleInterfaces(t *testing.T) {
mockDevices := mock.NewMockDeviceRepository(mockCtrl)
mockPeers := mock.NewMockPeerRepository(mockCtrl)
logger := zerolog.Nop()
config := &config.Config{
cfg := &config.Config{
Interfaces: map[string]config.Interface{
"wg0": config.Interface{
"wg0": {
Peers: map[string]config.Peer{
"test_peer1": config.Peer{
"test_peer1": {
PublicKey: "XgPRso34lnrSAx8nJtdj1/zlF7CoNj7B64LPElYdOGs=",
},
},
},
"wg1": config.Interface{
"wg1": {
Peers: map[string]config.Peer{
"test_peer2": config.Peer{
"test_peer2": {
PublicKey: "FQ9/2l8t4xmQQbs6SB03+Lh2VijJX74rxRUOv7YT03k=",
},
"test_peer3": config.Peer{
"test_peer3": {
PublicKey: "Cud5HogJJLCppoUuHnWrSvEJuI49D01sQcfiD3Y9RRU=",
},
},
},
},
}
mockPeerSearcher := mockEntity.NewMockPeerSearcher(mockCtrl)
deviceConfig := config.NewDeviceConfig(cfg)
peerFilterService := entity.NewFilterPeerService(mockPeerSearcher, deviceConfig)

mockDevice0 := &wgtypes.Device{
Name: "wg0",
Expand Down Expand Up @@ -72,12 +115,37 @@ func TestBootstrap_WithMultipleInterfaces(t *testing.T) {
mockDevices.EXPECT().Save(gomock.Any(), gomock.Any()).Times(2)
mockPeers.EXPECT().Save(gomock.Any(), gomock.Any()).Times(3)

mockDevice0Peers := []*entity.Peer{
entity.NewPeer(
entity.NewPeerId(mockDevice0.PublicKey[:], mockDevice0.Peers[0].PublicKey[:]),
mockDevice0.Name,
mockDevice0.Peers[0].PublicKey,
),
}

mockDevice1Peers := []*entity.Peer{
entity.NewPeer(
entity.NewPeerId(mockDevice1.PublicKey[:], mockDevice1.Peers[0].PublicKey[:]),
mockDevice1.Name,
mockDevice1.Peers[0].PublicKey,
),
entity.NewPeer(
entity.NewPeerId(mockDevice1.PublicKey[:], mockDevice1.Peers[1].PublicKey[:]),
mockDevice1.Name,
mockDevice1.Peers[1].PublicKey,
),
}

mockPeerSearcher.EXPECT().SearchByDevice(gomock.Any(), entity.DeviceId("wg0")).Return(mockDevice0Peers, nil)
mockPeerSearcher.EXPECT().SearchByDevice(gomock.Any(), entity.DeviceId("wg1")).Return(mockDevice1Peers, nil)

bootstrap := ctrl.NewBootstrapController(
mockWgClient,
config,
cfg,
mockDevices,
mockPeers,
&logger,
peerFilterService,
)

bootstrap.Execute(context.TODO())
Expand Down
7 changes: 7 additions & 0 deletions internal/entity/entity.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package entity

import "github.com/google/wire"

var DefaultSet = wire.NewSet(
NewFilterPeerService,
)
Loading

0 comments on commit 8b709be

Please sign in to comment.