diff --git a/.gitignore b/.gitignore index 6fe064d29..c72ce2a4a 100644 --- a/.gitignore +++ b/.gitignore @@ -16,3 +16,5 @@ bin/ # vendor/ tetrapod.yaml tetrapod.db + +tmp/ diff --git a/Makefile b/Makefile index d341a9c52..d04960f82 100644 --- a/Makefile +++ b/Makefile @@ -34,7 +34,10 @@ hostvrf: bin route-pods: bin CGO_ENABLED=0 go build -o ./bin ./tetracni/cmd/route-pods -cni-plugins: tetra-extra-routes tetra-pod-ipam hostvrf route-pods +nsexec: bin + CGO_ENABLED=0 go build -o ./bin ./tetracni/cmd/nsexec + +cni-plugins: tetra-extra-routes tetra-pod-ipam hostvrf route-pods nsexec .PHONY: test test: envtest ## Run tests. @@ -44,3 +47,16 @@ test: envtest ## Run tests. envtest: $(ENVTEST) ## Download envtest-setup locally if necessary. $(ENVTEST): $(LOCALBIN) test -s $(LOCALBIN)/setup-envtest || GOBIN=$(LOCALBIN) go install sigs.k8s.io/controller-runtime/tools/setup-envtest@latest + +.PHONY: install-cnitools +install-cnitools: + GOBIN=$(PWD)/bin go install github.com/containernetworking/cni/cnitool@latest + ( \ + rm -rf plugins; \ + git clone https://github.com/containernetworking/plugins.git && \ + cd plugins && \ + ./build_linux.sh && \ + mv ./bin/* ../bin && \ + cd .. && \ + rm -rf plugins \ + ) diff --git a/controlplane/config/default/kustomization.yaml b/controlplane/config/default/kustomization.yaml index c46e9a470..a7a398e6c 100644 --- a/controlplane/config/default/kustomization.yaml +++ b/controlplane/config/default/kustomization.yaml @@ -12,7 +12,7 @@ namePrefix: controlplane- #commonLabels: # someName: someValue -bases: +resources: - ../crd - ../rbac - ../manager diff --git a/pkg/nsutil/nsutil.go b/pkg/nsutil/nsutil.go index 6a8ce4a2d..17c4d80a1 100644 --- a/pkg/nsutil/nsutil.go +++ b/pkg/nsutil/nsutil.go @@ -5,6 +5,7 @@ import ( "os" "os/exec" "runtime" + "sync" "github.com/vishvananda/netns" ) @@ -38,31 +39,44 @@ func CreateNamespace(name string) (netns.NsHandle, error) { } func RunInNamespace(handle netns.NsHandle, fn func() error) (err error) { - runtime.LockOSThread() - defer func() { - if err == nil { - runtime.UnlockOSThread() + impl := func() error { + runtime.LockOSThread() + defer func() { + if err == nil { + runtime.UnlockOSThread() + } + }() + + cur, err := netns.Get() + if err != nil { + return err } - }() + defer func() { + if e := netns.Set(cur); e != nil { + err = fmt.Errorf("failed to recover netns: %w", err) + } + cur.Close() + }() - cur, err := netns.Get() - if err != nil { - return err - } - defer func() { - if e := netns.Set(cur); e != nil { - err = fmt.Errorf("failed to recover netns: %w", err) + if err := netns.Set(handle); err != nil { + return fmt.Errorf("failed to set netns: %w", err) } - cur.Close() - }() - if err := netns.Set(handle); err != nil { - return fmt.Errorf("failed to set netns: %w", err) - } + if err := fn(); err != nil { + return fmt.Errorf("fn failed: %w", err) + } - if err := fn(); err != nil { - return fmt.Errorf("fn failed: %w", err) + return nil } - return nil + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + err = impl() + }() + + wg.Wait() + + return err } diff --git a/tetracni/cmd/nsexec/main.go b/tetracni/cmd/nsexec/main.go index 26c103354..b5fafa13c 100644 --- a/tetracni/cmd/nsexec/main.go +++ b/tetracni/cmd/nsexec/main.go @@ -4,8 +4,12 @@ import ( "bytes" "encoding/json" "fmt" + "io" + "os" "os/exec" + "path/filepath" + "github.com/containernetworking/cni/pkg/invoke" "github.com/containernetworking/cni/pkg/skel" "github.com/containernetworking/cni/pkg/types" "github.com/containernetworking/cni/pkg/version" @@ -40,17 +44,35 @@ func cmd(args *skel.CmdArgs) error { return err } + _, _ = nsutil.CreateNamespace(netConf.Sandbox) + handle, err := netns.GetFromName(netConf.Sandbox) if err != nil { return err } + paths := filepath.SplitList(os.Getenv("CNI_PATH")) + pluginPath, err := invoke.FindInPath(netConf.Plugin, paths) + if err != nil { + return err + } + return nsutil.RunInNamespace(handle, func() error { - cmd := exec.Command(netConf.Plugin) + cmd := exec.Command(pluginPath) cmd.Stdin = bytes.NewReader(args.StdinData) - return cmd.Run() + var buf bytes.Buffer + cmd.Stdout = io.MultiWriter(os.Stdout, &buf) + cmd.Stderr = io.MultiWriter(os.Stderr, &buf) + + err := cmd.Run() + + if err != nil { + return fmt.Errorf("executing plugin failed %s: %w", buf.String(), err) + } + + return nil }) } diff --git a/tetracni/cmd/route-pods/conf.go b/tetracni/cmd/route-pods/conf.go index ec61d5cf3..55951e558 100644 --- a/tetracni/cmd/route-pods/conf.go +++ b/tetracni/cmd/route-pods/conf.go @@ -12,6 +12,7 @@ import ( type Conf struct { types.NetConf + Sandbox string `json:"sandbox"` HostVeth string `json:"hostVeth"` PeerVeth string `json:"peerVeth"` Firewall string `json:"firewall"` diff --git a/tetracni/cmd/route-pods/iptables.go b/tetracni/cmd/route-pods/iptables.go index 1374e20a2..21f88aa48 100644 --- a/tetracni/cmd/route-pods/iptables.go +++ b/tetracni/cmd/route-pods/iptables.go @@ -4,7 +4,9 @@ import ( "fmt" "github.com/coreos/go-iptables/iptables" + "github.com/miscord-dev/tetrapod/pkg/nsutil" "github.com/vishvananda/netlink" + "github.com/vishvananda/netns" ) const ( @@ -19,7 +21,7 @@ var ( } ) -func setUpIPTables(ipt *iptables.IPTables, hostVeth *netlink.Veth, redirectedChain string) error { +func setUpIPTables(ipt *iptables.IPTables, peerVeth *netlink.Veth, redirectedChain string) error { exists, err := ipt.ChainExists(iptablesFilterTable, redirectedChain) if err != nil { @@ -41,9 +43,9 @@ func setUpIPTables(ipt *iptables.IPTables, hostVeth *netlink.Veth, redirectedCha } rules := [][]string{ - {"-o", hostVeth.Name, "-j", "ACCEPT"}, - {"-i", hostVeth.Name, "-m", "state", "--state", "RELATED,ESTABLISHED", "-j", "ACCEPT"}, - {"-i", hostVeth.Name, "-j", "DROP"}, + {"-i", peerVeth.Name, "-j", "ACCEPT"}, + {"-o", peerVeth.Name, "-m", "state", "--state", "RELATED,ESTABLISHED", "-j", "ACCEPT"}, + {"-o", peerVeth.Name, "-j", "DROP"}, } for _, rule := range rules { @@ -55,29 +57,37 @@ func setUpIPTables(ipt *iptables.IPTables, hostVeth *netlink.Veth, redirectedCha return nil } -func setUpFirewall(hostVeth *netlink.Veth, conf *Conf) error { +func setUpFirewall(peerNetns netns.NsHandle, peerVeth *netlink.Veth, conf *Conf) error { switch conf.Firewall { case FirewallNever: return nil case FirewallIPTables: - ipt, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) + err := nsutil.RunInNamespace(peerNetns, func() error { + ipt, err := iptables.NewWithProtocol(iptables.ProtocolIPv4) - if err != nil { - return fmt.Errorf("failed to set up iptables: %w", err) - } + if err != nil { + return fmt.Errorf("failed to set up iptables: %w", err) + } - if err := setUpIPTables(ipt, hostVeth, conf.IPTablesChain); err != nil { - return fmt.Errorf("failed to set up iptables for v4: %w", err) - } + if err := setUpIPTables(ipt, peerVeth, conf.IPTablesChain); err != nil { + return fmt.Errorf("failed to set up iptables for v4: %w", err) + } - ipt, err = iptables.NewWithProtocol(iptables.ProtocolIPv6) + ipt, err = iptables.NewWithProtocol(iptables.ProtocolIPv6) - if err != nil { - return fmt.Errorf("failed to set up iptables: %w", err) - } + if err != nil { + return fmt.Errorf("failed to set up iptables: %w", err) + } + + if err := setUpIPTables(ipt, peerVeth, conf.IPTablesChain); err != nil { + return fmt.Errorf("failed to set up iptables for v4: %w", err) + } - if err := setUpIPTables(ipt, hostVeth, conf.IPTablesChain); err != nil { - return fmt.Errorf("failed to set up iptables for v4: %w", err) + return nil + }) + + if err != nil { + return fmt.Errorf("manipulating iptables in %s netns failed: %w", peerNetns, err) } } diff --git a/tetracni/cmd/route-pods/main.go b/tetracni/cmd/route-pods/main.go index 5264ab2ac..75b9c94c0 100644 --- a/tetracni/cmd/route-pods/main.go +++ b/tetracni/cmd/route-pods/main.go @@ -11,13 +11,14 @@ import ( bv "github.com/containernetworking/plugins/pkg/utils/buildversion" "github.com/seancfoley/ipaddress-go/ipaddr" "github.com/vishvananda/netlink" + "github.com/vishvananda/netns" ) func main() { skel.PluginMain(cmdAdd, cmdCheck, cmdDel, version.All, bv.BuildString("hostvrf")) } -func findVeth(hostVethName, peerVethName string) (hostVeth, peerVeth *netlink.Veth, err error) { +func findVeth(hostVethName, peerVethName string, peerNetnsFd int, peerNetnsNetlink *netlink.Handle) (hostVeth, peerVeth *netlink.Veth, err error) { hostVethLink, err := netlink.LinkByName(hostVethName) if err != nil { @@ -30,23 +31,52 @@ func findVeth(hostVethName, peerVethName string) (hostVeth, peerVeth *netlink.Ve return nil, nil, fmt.Errorf("link %s is not veth: %s", hostVethName, hostVethLink.Type()) } - peerVethLink, err := netlink.LinkByName(peerVethName) + peerVethLink, err := peerNetnsNetlink.LinkByName(peerVethName) if err != nil { - return nil, nil, fmt.Errorf("failed to find veth %s: %w", peerVethName, err) + return hostVeth, nil, fmt.Errorf("failed to find veth %s: %w", hostVethName, err) } peerVeth, ok = peerVethLink.(*netlink.Veth) + if !ok { + return hostVeth, nil, fmt.Errorf("link %s is not veth: %s", peerVethName, peerVethLink.Type()) + } + + return +} + +func updateVeth(hostVethName, peerVethName string, peerNetnsFd int, peerNetnsNetlink *netlink.Handle) (hostVeth, peerVeth *netlink.Veth, err error) { + hostVeth, peerVeth, err = findVeth(hostVethName, peerVethName, peerNetnsFd, peerNetnsNetlink) + + if err == nil { + return hostVeth, peerVeth, nil + } + if err != nil && hostVeth == nil { + return nil, nil, fmt.Errorf("failed to find veth pairs: %w", err) + } + + peerVethLink, err := netlink.LinkByName(peerVethName) + + if err != nil { + return hostVeth, nil, fmt.Errorf("failed to find veth %s: %w", hostVethName, err) + } + + peerVeth, ok := peerVethLink.(*netlink.Veth) + if !ok { return nil, nil, fmt.Errorf("link %s is not veth: %s", peerVethName, peerVethLink.Type()) } + if err := netlink.LinkSetNsFd(peerVethLink, peerNetnsFd); err != nil { + return nil, nil, fmt.Errorf("failed to set netns for peer veth %s: %w", peerVethName, err) + } + return } -func setupVeth(hostVethName, peerVethName string) (hostVeth, peerVeth *netlink.Veth, err error) { - hostVeth, peerVeth, err = findVeth(hostVethName, peerVethName) +func setupVeth(hostVethName, peerVethName string, peerNetnsFd int, peerNetnsNetlink *netlink.Handle) (hostVeth, peerVeth *netlink.Veth, err error) { + hostVeth, peerVeth, err = updateVeth(hostVethName, peerVethName, peerNetnsFd, peerNetnsNetlink) switch { case err == nil: @@ -63,7 +93,7 @@ func setupVeth(hostVethName, peerVethName string) (hostVeth, peerVeth *netlink.V return nil, nil, fmt.Errorf("faile create a veth %s: %w", hostVethName, err) } - return findVeth(hostVethName, peerVethName) + return updateVeth(hostVethName, peerVethName, peerNetnsFd, peerNetnsNetlink) default: return nil, nil, fmt.Errorf("failed to find %s, %s: %w", hostVethName, peerVethName, err) } @@ -76,7 +106,19 @@ func cmdAdd(args *skel.CmdArgs) error { return fmt.Errorf("failed to load config: %w", err) } - hostVeth, peerVeth, err := setupVeth(conf.HostVeth, conf.PeerVeth) + ns, err := netns.GetFromName(conf.Sandbox) + + if err != nil { + return fmt.Errorf("failed to find netns %s: %w", conf.Sandbox, err) + } + + peerNetnsNetlink, err := netlink.NewHandleAt(ns) + + if err != nil { + return fmt.Errorf("failed to find netns %s: %w", conf.Sandbox, err) + } + + hostVeth, peerVeth, err := setupVeth(conf.HostVeth, conf.PeerVeth, int(ns), peerNetnsNetlink) if err != nil { return fmt.Errorf("failed to set up veth: %w", err) @@ -88,7 +130,7 @@ func cmdAdd(args *skel.CmdArgs) error { continue } - link, err := netlink.LinkByName(iface.Name) + link, err := peerNetnsNetlink.LinkByName(iface.Name) if err != nil { return fmt.Errorf("failed to get link %s: %w", iface.Name, err) @@ -106,7 +148,7 @@ func cmdAdd(args *skel.CmdArgs) error { return fmt.Errorf("bridge not found") } - if err := netlink.LinkSetMaster(peerVeth, bridge); err != nil { + if err := peerNetnsNetlink.LinkSetMaster(peerVeth, bridge); err != nil { return fmt.Errorf("failed to set master of %s %s: %w", peerVeth.Name, bridge.Name, err) } @@ -126,14 +168,14 @@ func cmdAdd(args *skel.CmdArgs) error { } } - if err := setUpFirewall(hostVeth, conf); err != nil { + if err := setUpFirewall(ns, peerVeth, conf); err != nil { return fmt.Errorf("failed to set up firewall: %w", err) } if err := netlink.LinkSetUp(hostVeth); err != nil { return fmt.Errorf("failed to set %s up: %w", hostVeth.Name, err) } - if err := netlink.LinkSetUp(peerVeth); err != nil { + if err := peerNetnsNetlink.LinkSetUp(peerVeth); err != nil { return fmt.Errorf("failed to set %s up: %w", peerVeth.Name, err) } @@ -147,7 +189,19 @@ func cmdCheck(args *skel.CmdArgs) error { return fmt.Errorf("failed to load config: %w", err) } - _, _, err = findVeth(conf.HostVeth, conf.PeerVeth) + ns, err := netns.GetFromName(conf.Sandbox) + + if err != nil { + return fmt.Errorf("failed to find netns %s: %w", conf.Sandbox, err) + } + + peerNetnsNetlink, err := netlink.NewHandleAt(ns) + + if err != nil { + return fmt.Errorf("failed to find netns %s: %w", conf.Sandbox, err) + } + + _, _, err = findVeth(conf.HostVeth, conf.PeerVeth, int(ns), peerNetnsNetlink) if err != nil { return fmt.Errorf("failed to find veth: %w", err) diff --git a/tetracni/cmd/tetra-extra-routes/main.go b/tetracni/cmd/tetra-extra-routes/main.go index 085fa0bb6..f2425c921 100644 --- a/tetracni/cmd/tetra-extra-routes/main.go +++ b/tetracni/cmd/tetra-extra-routes/main.go @@ -15,6 +15,26 @@ func main() { skel.PluginMain(cmdAdd, cmdCheck, cmdDel, version.All, bv.BuildString("tetra-extra-routes")) } +func vrfTable(conf *Conf) (uint32, error) { + if conf.VRF == "" { + return 0, nil + } + + link, err := netlink.LinkByName(conf.VRF) + + if err != nil { + return 0, fmt.Errorf("failed to find a VRF: %w", err) + } + + vrf, ok := link.(*netlink.Vrf) + + if !ok { + return 0, fmt.Errorf("%s is not VRF", conf.VRF) + } + + return vrf.Table, nil +} + func cmdAdd(args *skel.CmdArgs) error { conf, result, err := loadConfig(args.StdinData, args.Args) @@ -53,20 +73,14 @@ func cmdAdd(args *skel.CmdArgs) error { return fmt.Errorf("no veth to route") } - link, err := netlink.LinkByName(conf.VRF) + table, err := vrfTable(conf) if err != nil { - return fmt.Errorf("failed to find a VRF: %w", err) - } - - vrf, ok := link.(*netlink.Vrf) - - if !ok { - return fmt.Errorf("%s is not VRF", conf.VRF) + return fmt.Errorf("failed to get table index of vrf: %w", err) } routes, err := netlink.RouteListFiltered(netlink.FAMILY_ALL, &netlink.Route{ - Table: int(vrf.Table), + Table: int(table), }, netlink.RT_FILTER_TABLE) if err != nil { @@ -94,11 +108,11 @@ func cmdAdd(args *skel.CmdArgs) error { err = netlink.RouteAdd(&netlink.Route{ Dst: cidr, LinkIndex: veth.Index, - Table: int(vrf.Table), + Table: int(table), }) if err != nil { - return fmt.Errorf("failed to add a route for %s to %s: %w", cidr.String(), vrf.Name, err) + return fmt.Errorf("failed to add a route for %s to %s: %w", cidr.String(), conf.VRF, err) } } @@ -143,20 +157,14 @@ func cmdCheck(args *skel.CmdArgs) error { return fmt.Errorf("no veth to route") } - link, err := netlink.LinkByName(conf.VRF) + table, err := vrfTable(conf) if err != nil { - return fmt.Errorf("failed to find a VRF: %w", err) - } - - vrf, ok := link.(*netlink.Vrf) - - if !ok { - return fmt.Errorf("%s is not VRF", conf.VRF) + return fmt.Errorf("failed to get table index of vrf: %w", err) } routes, err := netlink.RouteListFiltered(netlink.FAMILY_ALL, &netlink.Route{ - Table: int(vrf.Table), + Table: int(table), }, netlink.RT_FILTER_TABLE) if err != nil { diff --git a/tetracni/cni/10-tetra.conflist b/tetracni/cni/10-tetra.conflist index 8ee2d9709..099bcfbd0 100644 --- a/tetracni/cni/10-tetra.conflist +++ b/tetracni/cni/10-tetra.conflist @@ -3,7 +3,9 @@ "name": "tetrapod", "plugins": [ { - "type": "bridge", + "type": "nsexec", + "sandbox": "tetrapod", + "plugin": "bridge", "bridge": "cni0", "isDefaultGateway": true, "mtu": 1280, @@ -13,12 +15,10 @@ } }, { - "type": "hostvrf", - "vrf": "tetrapod-vrf" - }, - { - "type": "tetra-extra-routes", - "vrf": "tetrapod-vrf" + "type": "nsexec", + "sandbox": "tetrapod", + + "plugin": "tetra-extra-routes" }, { "type": "route-pods" diff --git a/tetrad/api/v1alpha1/cniconfig_types.go b/tetrad/api/v1alpha1/cniconfig_types.go index 61a651c38..be58ac220 100644 --- a/tetrad/api/v1alpha1/cniconfig_types.go +++ b/tetrad/api/v1alpha1/cniconfig_types.go @@ -109,8 +109,7 @@ type Wireguard struct { ListenPort int `json:"listenPort"` STUNEndpoint string `json:"stunEndpoint"` Name string `json:"name"` - VRF string `json:"vrf"` - Table int `json:"table"` + Netns string `json:"netns"` } func (wg *Wireguard) Load() { @@ -142,13 +141,8 @@ func (wg *Wireguard) Load() { if wg.Name == "" { wg.Name = "tetrapod0" } - if wg.VRF == "" { - wg.VRF = "tetrapod-vrf" - } else if wg.VRF == "-" { - wg.VRF = "" - } - if wg.Table == 0 { - wg.Table = 1351 + if wg.Netns == "" { + wg.Netns = "tetrapod" } } diff --git a/tetrad/config/default/kustomization.yaml b/tetrad/config/default/kustomization.yaml index b99543886..09aefcd13 100644 --- a/tetrad/config/default/kustomization.yaml +++ b/tetrad/config/default/kustomization.yaml @@ -12,7 +12,7 @@ namePrefix: tetrad- #commonLabels: # someName: someValue -bases: +resources: - ../rbac - ../manager # [WEBHOOK] To enable webhook, uncomment all the sections with [WEBHOOK] prefix including the one in diff --git a/tetrad/main.go b/tetrad/main.go index 4f0ccb17d..f20194bc2 100644 --- a/tetrad/main.go +++ b/tetrad/main.go @@ -144,7 +144,7 @@ func main() { config.Wireguard.PrivateKey = privKey.String() } - engine, err := tetraengine.New(config.Wireguard.Name, config.Wireguard.VRF, uint32(config.Wireguard.Table), &tetraengine.Config{ + engine, err := tetraengine.New(config.Wireguard.Name, config.Wireguard.Netns, &tetraengine.Config{ PrivateKey: config.Wireguard.PrivateKey, ListenPort: config.Wireguard.ListenPort, STUNEndpoint: config.Wireguard.STUNEndpoint, diff --git a/tetraengine/tetraengine.go b/tetraengine/tetraengine.go index fa1a8845c..f8c6481b0 100644 --- a/tetraengine/tetraengine.go +++ b/tetraengine/tetraengine.go @@ -42,23 +42,23 @@ type tetraEngine struct { logger *zap.Logger } -func New(ifaceName, vrf string, table uint32, config *Config, logger *zap.Logger) (res TetraEngine, err error) { +func New(ifaceName, netns string, config *Config, logger *zap.Logger) (res TetraEngine, err error) { engine := &tetraEngine{ logger: logger, reconfigTriggerCh: make(chan struct{}, 1), } - if err := engine.init(ifaceName, vrf, table, config); err != nil { + if err := engine.init(ifaceName, netns, config); err != nil { return nil, fmt.Errorf("failed to init engine: %w", err) } return engine, nil } -func (e *tetraEngine) init(ifaceName, vrf string, table uint32, config *Config) error { +func (e *tetraEngine) init(ifaceName, netns string, config *Config) error { var err error - e.wgEngine, err = wgengine.NewVRF(ifaceName, vrf, table, e.logger.With(zap.String("component", "wgengine"))) + e.wgEngine, err = wgengine.NewNetns(ifaceName, netns, e.logger.With(zap.String("component", "wgengine"))) if err != nil { return fmt.Errorf("failed to set up wgengine: %w", err) } diff --git a/tetraengine/wgengine/wgengine.go b/tetraengine/wgengine/wgengine.go index 8c77166fb..59fdd8b17 100644 --- a/tetraengine/wgengine/wgengine.go +++ b/tetraengine/wgengine/wgengine.go @@ -4,8 +4,11 @@ import ( "errors" "fmt" "io" + "os" + "github.com/miscord-dev/tetrapod/pkg/nsutil" "github.com/vishvananda/netlink" + "github.com/vishvananda/netns" "go.uber.org/zap" "golang.zx2c4.com/wireguard/wgctrl" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" @@ -18,12 +21,11 @@ type Engine interface { var _ Engine = &wgEngine{} -func NewVRF(ifaceName, vrf string, table uint32, logger *zap.Logger) (Engine, error) { +func NewNetns(ifaceName, netns string, logger *zap.Logger) (Engine, error) { e := wgEngine{ - ifaceName: ifaceName, - vrf: vrf, - table: table, - logger: logger, + ifaceName: ifaceName, + wgNetnsName: netns, + logger: logger, } wgctrl, err := wgctrl.New() @@ -41,23 +43,33 @@ func NewVRF(ifaceName, vrf string, table uint32, logger *zap.Logger) (Engine, er } type wgEngine struct { - ifaceName string - vrf string - table uint32 + ifaceName string + wgNetnsName string netlink *netlink.Handle + wgNetns netns.NsHandle + wgNetlink *netlink.Handle wireguard *netlink.Wireguard - vrfLink *netlink.Vrf prevConfig wgtypes.Config wgctrl *wgctrl.Client logger *zap.Logger } +func (e *wgEngine) garbageCollect() { + link, err := e.netlink.LinkByName(e.ifaceName) + + if err == nil { + e.netlink.LinkDel(link) + } +} + func (e *wgEngine) initWireguard() (*netlink.Wireguard, error) { + e.garbageCollect() + var wg *netlink.Wireguard - link, err := e.netlink.LinkByName(e.ifaceName) + link, err := e.wgNetlink.LinkByName(e.ifaceName) switch { case err == nil: @@ -74,57 +86,32 @@ func (e *wgEngine) initWireguard() (*netlink.Wireguard, error) { }, } - if err := netlink.LinkAdd(wg); err != nil { + if err := e.netlink.LinkAdd(wg); err != nil { return nil, fmt.Errorf("failed to create wireguard interface: %w", err) } + + if err := e.netlink.LinkSetNsFd(wg, int(e.wgNetns)); err != nil { + return nil, fmt.Errorf("failed to set netns %s for wireguard interface: %w", e.wgNetnsName, err) + } default: return nil, fmt.Errorf("failed to find the link %s: %w", e.ifaceName, err) } - if err := netlink.LinkSetMTU(wg, 1280); err != nil { - return nil, fmt.Errorf("failed to set MTU: %w", err) - } - return wg, nil } -func (e *wgEngine) initVRF() (*netlink.Vrf, error) { - var vrf *netlink.Vrf - - link, err := e.netlink.LinkByName(e.vrf) - - switch { - case err == nil: - var ok bool - vrf, ok = link.(*netlink.Vrf) - - if !ok { - return nil, fmt.Errorf("the link %s is not vrf", e.vrf) - } - case errors.As(err, &netlink.LinkNotFoundError{}): - vrf = &netlink.Vrf{ - LinkAttrs: netlink.LinkAttrs{ - Name: e.vrf, - }, - Table: e.table, - } - - if err := netlink.LinkAdd(vrf); err != nil { - return nil, fmt.Errorf("failed to create vrf interface: %w", err) - } - default: - return nil, fmt.Errorf("failed to find the link %s: %w", e.vrf, err) - } +func (e *wgEngine) initNetns() (netns.NsHandle, error) { + nsHandle, err := netns.GetFromName(e.wgNetnsName) - if err := netlink.LinkSetMaster(e.wireguard, vrf); err != nil { - return nil, fmt.Errorf("failed to find the link %s: %w", e.vrf, err) + if os.IsNotExist(err) { + return nsutil.CreateNamespace(e.wgNetnsName) } - if err := netlink.LinkSetUp(vrf); err != nil { - return nil, fmt.Errorf("ip link set %s up failed: %w", vrf.LinkAttrs.Name, err) + if err != nil { + return 0, fmt.Errorf("failed to find netns with name %s: %w", e.wgNetnsName, err) } - return vrf, nil + return nsHandle, nil } func (e *wgEngine) init() error { @@ -132,7 +119,20 @@ func (e *wgEngine) init() error { e.netlink, err = netlink.NewHandle() if err != nil { - return fmt.Errorf("failed to initialize handle for main ns: %w", err) + return fmt.Errorf("failed to initialize handle for main netns: %w", err) + } + + nsHandle, err := e.initNetns() + + if err != nil { + return fmt.Errorf("failed to init netns: %w", err) + } + e.wgNetns = nsHandle + + e.wgNetlink, err = netlink.NewHandleAt(nsHandle) + + if err != nil { + return fmt.Errorf("failed to init handle for wg netns: %w", err) } wg, err := e.initWireguard() @@ -142,17 +142,28 @@ func (e *wgEngine) init() error { } e.wireguard = wg - if e.vrf != "" { - vrf, err := e.initVRF() + if err := e.wgNetlink.LinkSetMTU(wg, 1280); err != nil { + return fmt.Errorf("failed to set MTU: %w", err) + } + + if err := e.wgNetlink.LinkSetUp(wg); err != nil { + return fmt.Errorf("ip link set %s up failed: %w", wg.LinkAttrs.Name, err) + } + + err = nsutil.RunInNamespace(nsHandle, func() error { + client, err := wgctrl.New() if err != nil { - return fmt.Errorf("failed to init vrf: %w", err) + return fmt.Errorf("failed to init wgctrl: %w", err) } - e.vrfLink = vrf - } - if err := netlink.LinkSetUp(wg); err != nil { - return fmt.Errorf("ip link set %s up failed: %w", wg.LinkAttrs.Name, err) + e.wgctrl = client + + return nil + }) + + if err != nil { + return fmt.Errorf("running in netns %s failed: %w", e.wgNetnsName, err) } return nil @@ -165,8 +176,16 @@ func (e *wgEngine) reconfigWireguard(config wgtypes.Config) error { return nil } - if err := e.wgctrl.ConfigureDevice(e.ifaceName, diff); err != nil { - return fmt.Errorf("failed to configure device: %w", err) + err := nsutil.RunInNamespace(e.wgNetns, func() error { + if err := e.wgctrl.ConfigureDevice(e.ifaceName, diff); err != nil { + return fmt.Errorf("failed to configure device: %w", err) + } + + return nil + }) + + if err != nil { + return fmt.Errorf("running in netns %s failed: %w", e.wgNetnsName, err) } e.prevConfig = config @@ -175,7 +194,7 @@ func (e *wgEngine) reconfigWireguard(config wgtypes.Config) error { } func (e *wgEngine) reconfigAddresses(addrs []netlink.Addr) error { - current, err := e.netlink.AddrList(e.wireguard, netlink.FAMILY_ALL) + current, err := e.wgNetlink.AddrList(e.wireguard, netlink.FAMILY_ALL) if err != nil { return fmt.Errorf("failed to list addresses for %s: %w", e.ifaceName, err) @@ -185,13 +204,13 @@ func (e *wgEngine) reconfigAddresses(addrs []netlink.Addr) error { var lastErr error for _, d := range deleted { - if err := e.netlink.AddrDel(e.wireguard, &d); err != nil { + if err := e.wgNetlink.AddrDel(e.wireguard, &d); err != nil { lastErr = fmt.Errorf("failed to delete %s: %w", d, err) e.logger.Error("failed to delete an address", zap.Error(err), zap.String("addr", d.String())) } } for _, a := range added { - if err := e.netlink.AddrAdd(e.wireguard, &a); err != nil { + if err := e.wgNetlink.AddrAdd(e.wireguard, &a); err != nil { lastErr = fmt.Errorf("failed to add %s: %w", a, err) e.logger.Error("failed to add an address", zap.Error(err), zap.String("addr", a.String())) } @@ -209,16 +228,18 @@ func printRoutes(msg string, routes []netlink.Route) { } func (e *wgEngine) reconfigRoutes(config wgtypes.Config) error { - current, err := e.netlink.RouteListFiltered(netlink.FAMILY_ALL, &netlink.Route{ + table := 0 + + current, err := e.wgNetlink.RouteListFiltered(netlink.FAMILY_ALL, &netlink.Route{ LinkIndex: e.wireguard.Attrs().Index, - Table: int(e.table), + Table: table, }, netlink.RT_FILTER_OIF|netlink.RT_FILTER_TABLE) if err != nil { return fmt.Errorf("failed to list addresses for %s: %w", e.ifaceName, err) } - desired := generateRoutesFromWGConfig(config, e.wireguard, int(e.table)) + desired := generateRoutesFromWGConfig(config, e.wireguard, table) added, deleted := diffRoutes(desired, current) var lastErr error @@ -230,13 +251,13 @@ func (e *wgEngine) reconfigRoutes(config wgtypes.Config) error { continue } - if err := e.netlink.RouteDel(&d); err != nil { + if err := e.wgNetlink.RouteDel(&d); err != nil { lastErr = fmt.Errorf("failed to delete %s: %w", d.Dst, err) e.logger.Error("failed to delete a route", zap.Error(err), zap.String("dst", d.Dst.String())) } } for _, a := range added { - if err := e.netlink.RouteAdd(&a); err != nil { + if err := e.wgNetlink.RouteAdd(&a); err != nil { lastErr = fmt.Errorf("failed to add %s: %w", a.Dst, err) e.logger.Error("failed to add a route", zap.Error(err), zap.String("dst", a.Dst.String())) } @@ -265,11 +286,10 @@ func (e *wgEngine) Reconfig(config wgtypes.Config, addrs []netlink.Addr) error { } func (e *wgEngine) Close() error { - netlink.LinkDel(e.wireguard) - if e.vrfLink != nil { - netlink.LinkDel(e.vrfLink) - } - e.wgctrl.Close() + nsutil.RunInNamespace(e.wgNetns, func() error { + return e.wgctrl.Close() + }) + netns.DeleteNamed(e.wgNetnsName) return nil } diff --git a/tetraengine/wgengine/wgengine_test.go b/tetraengine/wgengine/wgengine_test.go new file mode 100644 index 000000000..2dd511e90 --- /dev/null +++ b/tetraengine/wgengine/wgengine_test.go @@ -0,0 +1,161 @@ +package wgengine_test + +import ( + "net" + "testing" + + "github.com/miscord-dev/tetrapod/tetraengine/wgengine" + "github.com/vishvananda/netlink" + "github.com/vishvananda/netns" + "go.uber.org/zap" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" +) + +func TestInit(t *testing.T) { + wgName := "tetra0" + netnsName := "netns" + + engine, err := wgengine.NewNetns(wgName, netnsName, zap.NewNop()) + + if err != nil { + t.Fatal(err) + } + // defer engine.Close() + + handle, err := netns.GetFromName(netnsName) + + if err != nil { + t.Fatal(err) + } + + wgNetlink, err := netlink.NewHandleAt(handle) + + if err != nil { + t.Fatal(err) + } + + wg, err := wgNetlink.LinkByName(wgName) + + if err != nil { + t.Fatal(err) + } + + key, _ := wgtypes.GeneratePrivateKey() + + route := net.IPNet{ + IP: net.IPv4(10, 0, 1, 0), + Mask: net.CIDRMask(24, 32), + } + ip, _ := netlink.ParseAddr("10.0.2.1/24") + + config, addrs := wgtypes.Config{ + PrivateKey: &key, + Peers: []wgtypes.PeerConfig{ + { + PublicKey: key.PublicKey(), + AllowedIPs: []net.IPNet{ + route, + }, + }, + }, + }, []netlink.Addr{ + *ip, + } + + t.Run("configure first", func(t *testing.T) { + err := engine.Reconfig(config, addrs) + + if err != nil { + t.Fatal(err) + } + + routes, err := wgNetlink.RouteListFiltered(netlink.FAMILY_ALL, &netlink.Route{ + LinkIndex: wg.Attrs().Index, + }, netlink.RT_FILTER_OIF) + + if err != nil { + t.Fatal(err) + } + + if expected := "10.0.1.0/24"; routes[0].Dst.String() != expected { + t.Error("expected", expected, "got", routes[0].Dst.String()) + } + + if wg.Attrs().MTU != 1280 { + t.Error("MTU is", wg.Attrs().MTU) + } + }) + + t.Run("add routes/addrs", func(t *testing.T) { + route2 := net.IPNet{ + IP: net.IPv4(10, 0, 2, 0), + Mask: net.CIDRMask(24, 32), + } + ip2, _ := netlink.ParseAddr("10.1.2.1/24") + + config.Peers = append(config.Peers, wgtypes.PeerConfig{ + PublicKey: key.PublicKey(), + AllowedIPs: []net.IPNet{ + route2, + }, + }) + addrs = append(addrs, *ip2) + + if err := engine.Reconfig(config, addrs); err != nil { + t.Fatal(err) + } + + routes, err := wgNetlink.RouteListFiltered(netlink.FAMILY_ALL, &netlink.Route{ + LinkIndex: wg.Attrs().Index, + }, netlink.RT_FILTER_OIF) + + if err != nil { + t.Fatal(err) + } + + if len(routes) != 2 { + t.Error("mismatched len of routes", routes) + } + + addrs, err := wgNetlink.AddrList(wg, netlink.FAMILY_ALL) + + if err != nil { + t.Fatal(err) + } + + if len(addrs) != 2 { + t.Error("mismatched len of addrs", addrs) + } + }) + + t.Run("remove routes/addrs", func(t *testing.T) { + config.Peers = config.Peers[:1] + addrs = addrs[:1] + + if err := engine.Reconfig(config, addrs); err != nil { + t.Fatal(err) + } + + routes, err := wgNetlink.RouteListFiltered(netlink.FAMILY_ALL, &netlink.Route{ + LinkIndex: wg.Attrs().Index, + }, netlink.RT_FILTER_OIF) + + if err != nil { + t.Fatal(err) + } + + if len(routes) != 1 { + t.Error("mismatched len of routes", routes) + } + + addrs, err := wgNetlink.AddrList(wg, netlink.FAMILY_ALL) + + if err != nil { + t.Fatal(err) + } + + if len(addrs) != 1 { + t.Error("mismatched len of addrs", addrs) + } + }) +}