From cc7e6309230fe6909280476318ebcf47616efa04 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Mon, 11 Nov 2024 00:06:25 +0800 Subject: [PATCH] control: Refactor interface finder --- common/cond.go | 41 +++++++++--------- common/control/bind_darwin.go | 4 +- common/control/bind_finder.go | 44 +++++++++++++++++-- common/control/bind_finder_default.go | 62 +++++++++++---------------- common/control/bind_linux.go | 4 +- common/control/bind_windows.go | 6 +-- 6 files changed, 91 insertions(+), 70 deletions(-) diff --git a/common/cond.go b/common/cond.go index 6fe11bc2..5558715c 100644 --- a/common/cond.go +++ b/common/cond.go @@ -157,6 +157,18 @@ func IndexIndexed[T any](arr []T, block func(index int, it T) bool) int { return -1 } +func Equal[S ~[]E, E comparable](s1, s2 S) bool { + if len(s1) != len(s2) { + return false + } + for i := range s1 { + if s1[i] != s2[i] { + return false + } + } + return true +} + //go:norace func Dup[T any](obj T) T { pointer := uintptr(unsafe.Pointer(&obj)) @@ -268,6 +280,14 @@ func Reverse[T any](arr []T) []T { return arr } +func ReverseMap[K comparable, V comparable](m map[K]V) map[V]K { + ret := make(map[V]K, len(m)) + for k, v := range m { + ret[v] = k + } + return ret +} + func Done(ctx context.Context) bool { select { case <-ctx.Done(): @@ -362,24 +382,3 @@ func Close(closers ...any) error { } return retErr } - -// Deprecated: wtf is this? -type Starter interface { - Start() error -} - -// Deprecated: wtf is this? -func Start(starters ...any) error { - for _, rawStarter := range starters { - if rawStarter == nil { - continue - } - if starter, isStarter := rawStarter.(Starter); isStarter { - err := starter.Start() - if err != nil { - return err - } - } - } - return nil -} diff --git a/common/control/bind_darwin.go b/common/control/bind_darwin.go index bff6c293..2fb3db96 100644 --- a/common/control/bind_darwin.go +++ b/common/control/bind_darwin.go @@ -9,15 +9,15 @@ import ( func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error { return Raw(conn, func(fd uintptr) error { - var err error if interfaceIndex == -1 { if finder == nil { return os.ErrInvalid } - interfaceIndex, err = finder.InterfaceIndexByName(interfaceName) + iif, err := finder.ByName(interfaceName) if err != nil { return err } + interfaceIndex = iif.Index } switch network { case "tcp6", "udp6": diff --git a/common/control/bind_finder.go b/common/control/bind_finder.go index 9b013d34..f956f190 100644 --- a/common/control/bind_finder.go +++ b/common/control/bind_finder.go @@ -3,21 +3,57 @@ package control import ( "net" "net/netip" + "unsafe" + + "github.com/sagernet/sing/common" + M "github.com/sagernet/sing/common/metadata" ) type InterfaceFinder interface { Update() error Interfaces() []Interface - InterfaceIndexByName(name string) (int, error) - InterfaceNameByIndex(index int) (string, error) - InterfaceByAddr(addr netip.Addr) (*Interface, error) + ByName(name string) (*Interface, error) + ByIndex(index int) (*Interface, error) + ByAddr(addr netip.Addr) (*Interface, error) } type Interface struct { Index int MTU int Name string - Addresses []netip.Prefix HardwareAddr net.HardwareAddr Flags net.Flags + Addresses []netip.Prefix +} + +func (i Interface) Equals(other Interface) bool { + return i.Index == other.Index && + i.MTU == other.MTU && + i.Name == other.Name && + common.Equal(i.HardwareAddr, other.HardwareAddr) && + i.Flags == other.Flags && + common.Equal(i.Addresses, other.Addresses) +} + +func (i Interface) NetInterface() net.Interface { + return *(*net.Interface)(unsafe.Pointer(&i)) +} + +func InterfaceFromNet(iif net.Interface) (Interface, error) { + ifAddrs, err := iif.Addrs() + if err != nil { + return Interface{}, err + } + return InterfaceFromNetAddrs(iif, common.Map(ifAddrs, M.PrefixFromNet)), nil +} + +func InterfaceFromNetAddrs(iif net.Interface, addresses []netip.Prefix) Interface { + return Interface{ + Index: iif.Index, + MTU: iif.MTU, + Name: iif.Name, + HardwareAddr: iif.HardwareAddr, + Flags: iif.Flags, + Addresses: addresses, + } } diff --git a/common/control/bind_finder_default.go b/common/control/bind_finder_default.go index 804497b6..cfc481e9 100644 --- a/common/control/bind_finder_default.go +++ b/common/control/bind_finder_default.go @@ -3,11 +3,8 @@ package control import ( "net" "net/netip" - _ "unsafe" - "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" - M "github.com/sagernet/sing/common/metadata" ) var _ InterfaceFinder = (*DefaultInterfaceFinder)(nil) @@ -27,18 +24,12 @@ func (f *DefaultInterfaceFinder) Update() error { } interfaces := make([]Interface, 0, len(netIfs)) for _, netIf := range netIfs { - ifAddrs, err := netIf.Addrs() + var iif Interface + iif, err = InterfaceFromNet(netIf) if err != nil { return err } - interfaces = append(interfaces, Interface{ - Index: netIf.Index, - MTU: netIf.MTU, - Name: netIf.Name, - Addresses: common.Map(ifAddrs, M.PrefixFromNet), - HardwareAddr: netIf.HardwareAddr, - Flags: netIf.Flags, - }) + interfaces = append(interfaces, iif) } f.interfaces = interfaces return nil @@ -52,46 +43,41 @@ func (f *DefaultInterfaceFinder) Interfaces() []Interface { return f.interfaces } -func (f *DefaultInterfaceFinder) InterfaceIndexByName(name string) (int, error) { +func (f *DefaultInterfaceFinder) ByName(name string) (*Interface, error) { for _, netInterface := range f.interfaces { if netInterface.Name == name { - return netInterface.Index, nil + return &netInterface, nil } } - netInterface, err := net.InterfaceByName(name) - if err != nil { - return 0, err + _, err := net.InterfaceByName(name) + if err == nil { + err = f.Update() + if err != nil { + return nil, err + } + return f.ByName(name) } - f.Update() - return netInterface.Index, nil + return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: &net.IPAddr{IP: nil}, Err: E.New("no such network interface")} } -func (f *DefaultInterfaceFinder) InterfaceNameByIndex(index int) (string, error) { +func (f *DefaultInterfaceFinder) ByIndex(index int) (*Interface, error) { for _, netInterface := range f.interfaces { if netInterface.Index == index { - return netInterface.Name, nil + return &netInterface, nil } } - netInterface, err := net.InterfaceByIndex(index) - if err != nil { - return "", err + _, err := net.InterfaceByIndex(index) + if err == nil { + err = f.Update() + if err != nil { + return nil, err + } + return f.ByIndex(index) } - f.Update() - return netInterface.Name, nil + return nil, &net.OpError{Op: "route", Net: "ip+net", Source: nil, Addr: &net.IPAddr{IP: nil}, Err: E.New("no such network interface")} } -func (f *DefaultInterfaceFinder) InterfaceByAddr(addr netip.Addr) (*Interface, error) { - for _, netInterface := range f.interfaces { - for _, prefix := range netInterface.Addresses { - if prefix.Contains(addr) { - return &netInterface, nil - } - } - } - err := f.Update() - if err != nil { - return nil, err - } +func (f *DefaultInterfaceFinder) ByAddr(addr netip.Addr) (*Interface, error) { for _, netInterface := range f.interfaces { for _, prefix := range netInterface.Addresses { if prefix.Contains(addr) { diff --git a/common/control/bind_linux.go b/common/control/bind_linux.go index c92bf6b0..c5e668d7 100644 --- a/common/control/bind_linux.go +++ b/common/control/bind_linux.go @@ -19,11 +19,11 @@ func bindToInterface(conn syscall.RawConn, network string, address string, finde if interfaceName == "" { return os.ErrInvalid } - var err error - interfaceIndex, err = finder.InterfaceIndexByName(interfaceName) + iif, err := finder.ByName(interfaceName) if err != nil { return err } + interfaceIndex = iif.Index } err := unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_BINDTOIFINDEX, interfaceIndex) if err == nil { diff --git a/common/control/bind_windows.go b/common/control/bind_windows.go index a499556f..cf833868 100644 --- a/common/control/bind_windows.go +++ b/common/control/bind_windows.go @@ -11,19 +11,19 @@ import ( func bindToInterface(conn syscall.RawConn, network string, address string, finder InterfaceFinder, interfaceName string, interfaceIndex int, preferInterfaceName bool) error { return Raw(conn, func(fd uintptr) error { - var err error if interfaceIndex == -1 { if finder == nil { return os.ErrInvalid } - interfaceIndex, err = finder.InterfaceIndexByName(interfaceName) + iif, err := finder.ByName(interfaceName) if err != nil { return err } + interfaceIndex = iif.Index } handle := syscall.Handle(fd) if M.ParseSocksaddr(address).AddrString() == "" { - err = bind4(handle, interfaceIndex) + err := bind4(handle, interfaceIndex) if err != nil { return err }