From 099899991126ef393fd8ba7e972fe9975181420b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sat, 9 Nov 2024 11:39:59 +0800 Subject: [PATCH] udpnat2: Fix missing shared impl --- common/udpnat2/service.go | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/common/udpnat2/service.go b/common/udpnat2/service.go index e2a1482a..68df6669 100644 --- a/common/udpnat2/service.go +++ b/common/udpnat2/service.go @@ -14,7 +14,7 @@ import ( ) type Service struct { - nat *freelru.LRU[netip.AddrPort, *Conn] + cache freelru.Cache[netip.AddrPort, *Conn] handler N.UDPConnectionHandlerEx prepare PrepareFunc metrics Metrics @@ -29,10 +29,15 @@ type Metrics struct { Drops uint64 } -func New(handler N.UDPConnectionHandlerEx, prepare PrepareFunc, timeout time.Duration) *Service { - nat := common.Must1(freelru.New[netip.AddrPort, *Conn](1024, maphash.NewHasher[netip.AddrPort]().Hash32)) - nat.SetLifetime(timeout) - nat.SetHealthCheck(func(port netip.AddrPort, conn *Conn) bool { +func New(handler N.UDPConnectionHandlerEx, prepare PrepareFunc, timeout time.Duration, shared bool) *Service { + var cache freelru.Cache[netip.AddrPort, *Conn] + if !shared { + cache = common.Must1(freelru.New[netip.AddrPort, *Conn](1024, maphash.NewHasher[netip.AddrPort]().Hash32)) + } else { + cache = common.Must1(freelru.NewSharded[netip.AddrPort, *Conn](1024, maphash.NewHasher[netip.AddrPort]().Hash32)) + } + cache.SetLifetime(timeout) + cache.SetHealthCheck(func(port netip.AddrPort, conn *Conn) bool { select { case <-conn.doneChan: return false @@ -40,18 +45,18 @@ func New(handler N.UDPConnectionHandlerEx, prepare PrepareFunc, timeout time.Dur return true } }) - nat.SetOnEvict(func(_ netip.AddrPort, conn *Conn) { + cache.SetOnEvict(func(_ netip.AddrPort, conn *Conn) { conn.Close() }) return &Service{ - nat: nat, + cache: cache, handler: handler, prepare: prepare, } } func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destination M.Socksaddr, userData any) { - conn, loaded := s.nat.Get(source.AddrPort()) + conn, loaded := s.cache.Get(source.AddrPort()) if !loaded { ok, ctx, writer, onClose := s.prepare(source, destination, userData) if !ok { @@ -65,7 +70,7 @@ func (s *Service) NewPacket(bufferSlices [][]byte, source M.Socksaddr, destinati doneChan: make(chan struct{}), readDeadline: pipe.MakeDeadline(), } - s.nat.Add(source.AddrPort(), conn) + s.cache.Add(source.AddrPort(), conn) go s.handler.NewPacketConnectionEx(ctx, conn, source, destination, onClose) s.metrics.Creates++ }