diff --git a/common/network/conn.go b/common/network/conn.go index 01fe13517..c795a19de 100644 --- a/common/network/conn.go +++ b/common/network/conn.go @@ -124,7 +124,7 @@ type UDPHandler interface { } type UDPHandlerEx interface { - NewPacket(ctx context.Context, conn PacketConn, buffer *buf.Buffer, source M.Socksaddr, destination M.Socksaddr) error + NewPacketEx(buffer *buf.Buffer, source M.Socksaddr, destination M.Socksaddr) } // Deprecated: Use UDPConnectionHandlerEx instead. diff --git a/common/network/direct.go b/common/network/direct.go index 24f38d7f7..822536cde 100644 --- a/common/network/direct.go +++ b/common/network/direct.go @@ -19,15 +19,27 @@ func (o ReadWaitOptions) NeedHeadroom() bool { return o.FrontHeadroom > 0 || o.RearHeadroom > 0 } +func (o ReadWaitOptions) Copy(buffer *buf.Buffer) *buf.Buffer { + if o.FrontHeadroom > buffer.Start() || + o.RearHeadroom > buffer.FreeLen() { + newBuffer := o.newBuffer(buf.UDPBufferSize, false) + newBuffer.Write(buffer.Bytes()) + buffer.Release() + return newBuffer + } else { + return buffer + } +} + func (o ReadWaitOptions) NewBuffer() *buf.Buffer { - return o.newBuffer(buf.BufferSize) + return o.newBuffer(buf.BufferSize, true) } func (o ReadWaitOptions) NewPacketBuffer() *buf.Buffer { - return o.newBuffer(buf.UDPBufferSize) + return o.newBuffer(buf.UDPBufferSize, true) } -func (o ReadWaitOptions) newBuffer(defaultBufferSize int) *buf.Buffer { +func (o ReadWaitOptions) newBuffer(defaultBufferSize int, reserve bool) *buf.Buffer { var bufferSize int if o.MTU > 0 { bufferSize = o.MTU + o.FrontHeadroom + o.RearHeadroom @@ -36,9 +48,9 @@ func (o ReadWaitOptions) newBuffer(defaultBufferSize int) *buf.Buffer { } buffer := buf.NewSize(bufferSize) if o.FrontHeadroom > 0 { - buffer.Resize(o.FrontHeadroom, 0) + buffer.Advance(o.FrontHeadroom) } - if o.RearHeadroom > 0 { + if o.RearHeadroom > 0 && reserve { buffer.Reserve(o.RearHeadroom) } return buffer diff --git a/common/udpnat/service.go b/common/udpnat/service.go index a5b37dbf7..6f95dbb7f 100644 --- a/common/udpnat/service.go +++ b/common/udpnat/service.go @@ -131,8 +131,6 @@ func (s *Service[T]) NewContextPacketEx(ctx context.Context, key T, buffer *buf. s.nat.Delete(key) } }() - } else { - c.localAddr = source } if common.Done(c.ctx) { s.nat.Delete(key) @@ -215,10 +213,6 @@ func (c *conn) SetWriteDeadline(t time.Time) error { return os.ErrInvalid } -func (c *conn) NeedAdditionalReadDeadline() bool { - return true -} - func (c *conn) Upstream() any { return c.source } diff --git a/common/udpnat2/conn.go b/common/udpnat2/conn.go new file mode 100644 index 000000000..a5ca8ac22 --- /dev/null +++ b/common/udpnat2/conn.go @@ -0,0 +1,90 @@ +package udpnat + +import ( + "io" + "net" + "os" + "time" + + "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/pipe" +) + +type natConn struct { + writer N.PacketWriter + localAddr M.Socksaddr + packetChan chan *Packet + doneChan chan struct{} + readDeadline pipe.Deadline + readWaitOptions N.ReadWaitOptions +} + +func (c *natConn) ReadPacket(buffer *buf.Buffer) (addr M.Socksaddr, err error) { + select { + case p := <-c.packetChan: + _, err = buffer.ReadOnceFrom(p.Buffer) + destination := p.Destination + p.Buffer.Release() + PutPacket(p) + return destination, err + case <-c.doneChan: + return M.Socksaddr{}, io.ErrClosedPipe + case <-c.readDeadline.Wait(): + return M.Socksaddr{}, os.ErrDeadlineExceeded + } +} + +func (c *natConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + return c.writer.WritePacket(buffer, destination) +} + +func (c *natConn) InitializeReadWaiter(options N.ReadWaitOptions) (needCopy bool) { + c.readWaitOptions = options + return false +} + +func (c *natConn) WaitReadPacket() (buffer *buf.Buffer, destination M.Socksaddr, err error) { + select { + case packet := <-c.packetChan: + buffer = c.readWaitOptions.Copy(packet.Buffer) + destination = packet.Destination + PutPacket(packet) + return + case <-c.doneChan: + return nil, M.Socksaddr{}, io.ErrClosedPipe + case <-c.readDeadline.Wait(): + return nil, M.Socksaddr{}, os.ErrDeadlineExceeded + } +} + +func (c *natConn) Close() error { + select { + case <-c.doneChan: + default: + close(c.doneChan) + } + return nil +} + +func (c *natConn) LocalAddr() net.Addr { + return c.localAddr +} + +func (c *natConn) RemoteAddr() net.Addr { + return M.Socksaddr{} +} + +func (c *natConn) SetDeadline(t time.Time) error { + return os.ErrInvalid +} + +func (c *natConn) SetReadDeadline(t time.Time) error { + c.readDeadline.Set(t) + return nil +} + +func (c *natConn) SetWriteDeadline(t time.Time) error { + return os.ErrInvalid +} diff --git a/common/udpnat2/packet.go b/common/udpnat2/packet.go new file mode 100644 index 000000000..c3761e84e --- /dev/null +++ b/common/udpnat2/packet.go @@ -0,0 +1,28 @@ +package udpnat + +import ( + "net/netip" + "sync" + + "github.com/sagernet/sing/common/buf" +) + +var packetPool = sync.Pool{ + New: func() any { + return new(Packet) + }, +} + +type Packet struct { + Buffer *buf.Buffer + Destination netip.AddrPort +} + +func NewPacket() *Packet { + return packetPool.Get().(*Packet) +} + +func PutPacket(packet *Packet) { + *packet = Packet{} + packetPool.Put(packet) +} diff --git a/common/udpnat2/service.go b/common/udpnat2/service.go new file mode 100644 index 000000000..8f720dfe2 --- /dev/null +++ b/common/udpnat2/service.go @@ -0,0 +1,100 @@ +package udpnat + +import ( + "context" + "net/netip" + "time" + + "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/buf" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/pipe" + "github.com/sagernet/sing/contrab/freelru" + "github.com/sagernet/sing/contrab/maphash" +) + +type Service struct { + nat *freelru.LRU[netip.AddrPort, *natConn] + handler Handler + metrics Metrics +} + +type Handler interface { + PreparePacketConnection(buffer *buf.Buffer, source netip.AddrPort, destination netip.AddrPort, userData any) (bool, context.Context, N.PacketWriter, N.CloseHandlerFunc) + N.UDPConnectionHandlerEx +} + +type Metrics struct { + Creates uint64 + Rejects uint64 + Inputs uint64 + Drops uint64 +} + +func New(handler Handler, timeout time.Duration) *Service { + nat := common.Must1(freelru.New[netip.AddrPort, *natConn](1024, maphash.NewHasher[netip.AddrPort]().Hash32)) + nat.SetLifetime(timeout) + nat.SetHealthCheck(func(port netip.AddrPort, conn *natConn) bool { + select { + case <-conn.doneChan: + return false + default: + return true + } + }) + nat.SetOnEvict(func(_ netip.AddrPort, conn *natConn) { + conn.Close() + }) + return &Service{ + nat: nat, + handler: handler, + } +} + +func (s *Service) NewPacket(buffer *buf.Buffer, source netip.AddrPort, destination netip.AddrPort, userData any) { + conn, loaded := s.nat.Get(source) + if !loaded { + ok, ctx, writer, onClose := s.handler.PreparePacketConnection(buffer, source, destination, userData) + if !ok { + buffer.Release() + s.metrics.Rejects++ + return + } + conn = &natConn{ + writer: writer, + localAddr: M.SocksaddrFromNetIP(source), + packetChan: make(chan *Packet, 64), + doneChan: make(chan struct{}), + readDeadline: pipe.MakeDeadline(), + } + packet := NewPacket() + *packet = Packet{ + Buffer: buffer, + Destination: destination, + } + conn.packetChan <- packet + s.nat.Add(source, conn) + s.handler.NewPacketConnectionEx(ctx, conn, M.SocksaddrFromNetIP(source), M.SocksaddrFromNetIP(destination), onClose) + s.metrics.Creates++ + s.metrics.Inputs++ + return + } + packet := NewPacket() + *packet = Packet{ + Buffer: conn.readWaitOptions.Copy(buffer), + Destination: destination, + } + select { + case conn.packetChan <- packet: + s.metrics.Inputs++ + default: + packet.Buffer.Release() + PutPacket(packet) + s.metrics.Drops++ + } +} + +func (s *Service) Metrics() Metrics { + return s.metrics +} diff --git a/contrab/freelru/lru.go b/contrab/freelru/lru.go index af8b8e919..045cc3ee0 100644 --- a/contrab/freelru/lru.go +++ b/contrab/freelru/lru.go @@ -31,6 +31,8 @@ type OnEvictCallback[K comparable, V any] func(K, V) // HashKeyCallback is the function that creates a hash from the passed key. type HashKeyCallback[K comparable] func(K) uint32 +type HealthCheckCallback[K comparable, V any] func(K, V) bool + type element[K comparable, V any] struct { key K value V @@ -61,12 +63,13 @@ const emptyBucket = math.MaxUint32 // LRU implements a non-thread safe fixed size LRU cache. type LRU[K comparable, V any] struct { - buckets []uint32 // contains positions of bucket lists or 'emptyBucket' - elements []element[K, V] - onEvict OnEvictCallback[K, V] - hash HashKeyCallback[K] - lifetime time.Duration - metrics Metrics + buckets []uint32 // contains positions of bucket lists or 'emptyBucket' + elements []element[K, V] + onEvict OnEvictCallback[K, V] + hash HashKeyCallback[K] + healthCheck HealthCheckCallback[K, V] + lifetime time.Duration + metrics Metrics // used for element clearing after removal or expiration emptyKey K @@ -108,6 +111,10 @@ func (lru *LRU[K, V]) SetOnEvict(onEvict OnEvictCallback[K, V]) { lru.onEvict = onEvict } +func (lru *LRU[K, V]) SetHealthCheck(healthCheck HealthCheckCallback[K, V]) { + lru.healthCheck = healthCheck +} + // New constructs an LRU with the given capacity of elements. // The hash function calculates a hash value from the keys. func New[K comparable, V any](capacity uint32, hash HashKeyCallback[K]) (*LRU[K, V], error) { @@ -120,7 +127,8 @@ func New[K comparable, V any](capacity uint32, hash HashKeyCallback[K]) (*LRU[K, // by reducing the chance of collisions. // Size must not be lower than the capacity. func NewWithSize[K comparable, V any](capacity, size uint32, hash HashKeyCallback[K]) ( - *LRU[K, V], error) { + *LRU[K, V], error, +) { if capacity == 0 { return nil, errors.New("capacity must be positive") } @@ -144,7 +152,8 @@ func NewWithSize[K comparable, V any](capacity, size uint32, hash HashKeyCallbac } func initLRU[K comparable, V any](lru *LRU[K, V], capacity, size uint32, hash HashKeyCallback[K], - buckets []uint32, elements []element[K, V]) { + buckets []uint32, elements []element[K, V], +) { lru.cap = capacity lru.size = size lru.hash = hash @@ -294,7 +303,7 @@ func (lru *LRU[K, V]) clearKeyAndValue(pos uint32) { lru.elements[pos].value = lru.emptyValue } -func (lru *LRU[K, V]) findKey(hash uint32, key K) (uint32, bool) { +func (lru *LRU[K, V]) findKey(hash uint32, key K, updateLifetimeOnGet bool) (uint32, bool) { _, startPos := lru.hashToPos(hash) if startPos == emptyBucket { return emptyBucket, false @@ -303,10 +312,14 @@ func (lru *LRU[K, V]) findKey(hash uint32, key K) (uint32, bool) { pos := startPos for { if key == lru.elements[pos].key { - if lru.elements[pos].expire != 0 && lru.elements[pos].expire <= now() { + elem := lru.elements[pos] + if (elem.expire != 0 && elem.expire <= now()) || (lru.healthCheck != nil && !lru.healthCheck(key, elem.value)) { lru.removeAt(pos) return emptyBucket, false } + if updateLifetimeOnGet { + lru.elements[pos].expire = expire(lru.lifetime) + } return pos, true } @@ -330,7 +343,8 @@ func (lru *LRU[K, V]) AddWithLifetime(key K, value V, lifetime time.Duration) (e } func (lru *LRU[K, V]) addWithLifetime(hash uint32, key K, value V, - lifetime time.Duration) (evicted bool) { + lifetime time.Duration, +) (evicted bool) { bucketPos, startPos := lru.hashToPos(hash) if startPos == emptyBucket { pos := lru.len @@ -425,11 +439,11 @@ func (lru *LRU[K, V]) add(hash uint32, key K, value V) (evicted bool) { // If the found cache item is already expired, the evict function is called // and the return value indicates that the key was not found. func (lru *LRU[K, V]) Get(key K) (value V, ok bool) { - return lru.get(lru.hash(key), key) + return lru.get(lru.hash(key), key, true) } -func (lru *LRU[K, V]) get(hash uint32, key K) (value V, ok bool) { - if pos, ok := lru.findKey(hash, key); ok { +func (lru *LRU[K, V]) get(hash uint32, key K, updateLifetime bool) (value V, ok bool) { + if pos, ok := lru.findKey(hash, key, updateLifetime); ok { if pos != lru.head { lru.unlinkElement(pos) lru.setHead(pos) @@ -449,7 +463,7 @@ func (lru *LRU[K, V]) Peek(key K) (value V, ok bool) { } func (lru *LRU[K, V]) peek(hash uint32, key K) (value V, ok bool) { - if pos, ok := lru.findKey(hash, key); ok { + if pos, ok := lru.findKey(hash, key, false); ok { return lru.elements[pos].value, ok } @@ -476,7 +490,7 @@ func (lru *LRU[K, V]) Remove(key K) (removed bool) { } func (lru *LRU[K, V]) remove(hash uint32, key K) (removed bool) { - if pos, ok := lru.findKey(hash, key); ok { + if pos, ok := lru.findKey(hash, key, false); ok { lru.removeAt(pos) return ok } diff --git a/contrab/freelru/lru_test.go b/contrab/freelru/lru_test.go new file mode 100644 index 000000000..4c4ba5919 --- /dev/null +++ b/contrab/freelru/lru_test.go @@ -0,0 +1,35 @@ +package freelru_test + +import ( + "testing" + "time" + + "github.com/sagernet/sing/contrab/freelru" + "github.com/sagernet/sing/contrab/maphash" + + "github.com/stretchr/testify/require" +) + +func TestMyChange0(t *testing.T) { + t.Parallel() + lru, err := freelru.New[string, string](1024, maphash.NewHasher[string]().Hash32) + require.NoError(t, err) + lru.AddWithLifetime("hello", "world", 2*time.Second) + time.Sleep(time.Second) + lru.Get("hello") + time.Sleep(time.Second + time.Millisecond*100) + _, ok := lru.Get("hello") + require.True(t, ok) +} + +func TestMyChange1(t *testing.T) { + t.Parallel() + lru, err := freelru.New[string, string](1024, maphash.NewHasher[string]().Hash32) + require.NoError(t, err) + lru.AddWithLifetime("hello", "world", 2*time.Second) + time.Sleep(time.Second) + lru.Peek("hello") + time.Sleep(time.Second + time.Millisecond*100) + _, ok := lru.Get("hello") + require.False(t, ok) +} diff --git a/contrab/maphash/hasher.go b/contrab/maphash/hasher.go index ef53596a2..cc60b2e4c 100644 --- a/contrab/maphash/hasher.go +++ b/contrab/maphash/hasher.go @@ -46,3 +46,8 @@ func (h Hasher[K]) Hash(key K) uint64 { p := noescape(unsafe.Pointer(&key)) return uint64(h.hash(p, h.seed)) } + +func (h Hasher[K]) Hash32(key K) uint32 { + p := noescape(unsafe.Pointer(&key)) + return uint32(h.hash(p, h.seed)) +} diff --git a/contrab/maphash/runtime.go b/contrab/maphash/runtime.go index 29cd6a8ed..f2aa2e06f 100644 --- a/contrab/maphash/runtime.go +++ b/contrab/maphash/runtime.go @@ -52,6 +52,7 @@ func newHashSeed() uintptr { //go:nocheckptr func noescape(p unsafe.Pointer) unsafe.Pointer { x := uintptr(p) + //nolint:staticcheck return unsafe.Pointer(x ^ 0) } @@ -91,9 +92,11 @@ type hmap struct { } // go/src/runtime/type.go -type tflag uint8 -type nameOff int32 -type typeOff int32 +type ( + tflag uint8 + nameOff int32 + typeOff int32 +) // go/src/runtime/type.go type _type struct {