diff --git a/common/bufio/addr_conn.go b/common/bufio/addr_conn.go index d74ce9f6..4d095b51 100644 --- a/common/bufio/addr_conn.go +++ b/common/bufio/addr_conn.go @@ -9,19 +9,20 @@ import ( type AddrConn struct { net.Conn - M.Metadata + Source M.Socksaddr + Destination M.Socksaddr } func (c *AddrConn) LocalAddr() net.Addr { - if c.Metadata.Destination.IsValid() { - return c.Metadata.Destination.TCPAddr() + if c.Destination.IsValid() { + return c.Destination.TCPAddr() } return c.Conn.LocalAddr() } func (c *AddrConn) RemoteAddr() net.Addr { - if c.Metadata.Source.IsValid() { - return c.Metadata.Source.TCPAddr() + if c.Source.IsValid() { + return c.Source.TCPAddr() } return c.Conn.RemoteAddr() } diff --git a/common/bufio/vectorised_unix.go b/common/bufio/vectorised_unix.go index 6bb5d7d8..b0697f45 100644 --- a/common/bufio/vectorised_unix.go +++ b/common/bufio/vectorised_unix.go @@ -38,7 +38,6 @@ func (w *SyscallVectorisedWriter) WriteVectorised(buffers []*buf.Buffer) error { var innerErr unix.Errno err := w.rawConn.Write(func(fd uintptr) (done bool) { //nolint:staticcheck - //goland:noinspection GoDeprecation _, _, innerErr = unix.Syscall(unix.SYS_WRITEV, fd, uintptr(unsafe.Pointer(&iovecList[0])), uintptr(len(iovecList))) return innerErr != unix.EAGAIN && innerErr != unix.EWOULDBLOCK }) diff --git a/common/exceptions/error.go b/common/exceptions/error.go index 5d056e6f..16b075a2 100644 --- a/common/exceptions/error.go +++ b/common/exceptions/error.go @@ -12,6 +12,7 @@ import ( F "github.com/sagernet/sing/common/format" ) +// Deprecated: wtf is this? type Handler interface { NewError(ctx context.Context, err error) } diff --git a/common/exceptions/timeout.go b/common/exceptions/timeout.go index f2ae6c33..222123a1 100644 --- a/common/exceptions/timeout.go +++ b/common/exceptions/timeout.go @@ -12,7 +12,6 @@ type TimeoutError interface { func IsTimeout(err error) bool { var netErr net.Error if errors.As(err, &netErr) { - //goland:noinspection GoDeprecation //nolint:staticcheck return netErr.Temporary() && netErr.Timeout() } diff --git a/common/metadata/metadata.go b/common/metadata/metadata.go index db2d7d0c..a67cb6a6 100644 --- a/common/metadata/metadata.go +++ b/common/metadata/metadata.go @@ -1,5 +1,6 @@ package metadata +// Deprecated: wtf is this? type Metadata struct { Protocol string Source Socksaddr diff --git a/common/network/conn.go b/common/network/conn.go index a920ab6e..01fe1351 100644 --- a/common/network/conn.go +++ b/common/network/conn.go @@ -4,6 +4,7 @@ import ( "context" "io" "net" + "sync" "time" "github.com/sagernet/sing/common" @@ -70,8 +71,38 @@ type ExtendedConn interface { net.Conn } +type CloseHandlerFunc = func(it error) + +func AppendClose(parent CloseHandlerFunc, onClose CloseHandlerFunc) CloseHandlerFunc { + if parent == nil { + return parent + } else if onClose == nil { + return onClose + } + return func(it error) { + onClose(it) + parent(it) + } +} + +func OnceClose(onClose CloseHandlerFunc) CloseHandlerFunc { + var once sync.Once + return func(it error) { + once.Do(func() { + onClose(it) + }) + } +} + +// Deprecated: Use TCPConnectionHandlerEx instead. type TCPConnectionHandler interface { - NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error + NewConnection(ctx context.Context, conn net.Conn, + //nolint:staticcheck + metadata M.Metadata) error +} + +type TCPConnectionHandlerEx interface { + NewConnectionEx(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose CloseHandlerFunc) } type NetPacketConn interface { @@ -85,12 +116,26 @@ type BindPacketConn interface { net.Conn } +// Deprecated: Use UDPHandlerEx instead. type UDPHandler interface { - NewPacket(ctx context.Context, conn PacketConn, buffer *buf.Buffer, metadata M.Metadata) error + NewPacket(ctx context.Context, conn PacketConn, buffer *buf.Buffer, + //nolint:staticcheck + metadata M.Metadata) error } +type UDPHandlerEx interface { + NewPacket(ctx context.Context, conn PacketConn, buffer *buf.Buffer, source M.Socksaddr, destination M.Socksaddr) error +} + +// Deprecated: Use UDPConnectionHandlerEx instead. type UDPConnectionHandler interface { - NewPacketConnection(ctx context.Context, conn PacketConn, metadata M.Metadata) error + NewPacketConnection(ctx context.Context, conn PacketConn, + //nolint:staticcheck + metadata M.Metadata) error +} + +type UDPConnectionHandlerEx interface { + NewPacketConnectionEx(ctx context.Context, conn PacketConn, source M.Socksaddr, destination M.Socksaddr, onClose CloseHandlerFunc) } type CachedReader interface { diff --git a/common/network/handshake.go b/common/network/handshake.go index 674211d8..5f13492b 100644 --- a/common/network/handshake.go +++ b/common/network/handshake.go @@ -1,6 +1,9 @@ package network import ( + "io" + "net" + "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" ) @@ -13,17 +16,71 @@ type HandshakeSuccess interface { HandshakeSuccess() error } -func ReportHandshakeFailure(conn any, err error) error { - if handshakeConn, isHandshakeConn := common.Cast[HandshakeFailure](conn); isHandshakeConn { +type ConnHandshakeSuccess interface { + ConnHandshakeSuccess(conn net.Conn) error +} + +type PacketConnHandshakeSuccess interface { + PacketConnHandshakeSuccess(conn net.PacketConn) error +} + +func ReportHandshakeFailure(reporter any, err error) error { + if handshakeConn, isHandshakeConn := common.Cast[HandshakeFailure](reporter); isHandshakeConn { return E.Append(err, handshakeConn.HandshakeFailure(err), func(err error) error { return E.Cause(err, "write handshake failure") }) } + return nil +} + +func CloseOnHandshakeFailure(reporter any, onClose CloseHandlerFunc, err error) error { + if err != nil { + if handshakeConn, isHandshakeConn := common.Cast[HandshakeFailure](reporter); isHandshakeConn { + err = E.Append(err, handshakeConn.HandshakeFailure(err), func(err error) error { + return E.Cause(err, "write handshake failure") + }) + } else { + if tcpConn, isTCPConn := common.Cast[interface { + SetLinger(sec int) error + }](reporter); isTCPConn { + tcpConn.SetLinger(0) + } + if closer, isCloser := reporter.(io.Closer); isCloser { + err = E.Append(err, closer.Close(), func(err error) error { + return E.Cause(err, "close") + }) + } + } + } + if onClose != nil { + onClose(err) + } return err } -func ReportHandshakeSuccess(conn any) error { - if handshakeConn, isHandshakeConn := common.Cast[HandshakeSuccess](conn); isHandshakeConn { +// Deprecated: use ReportConnHandshakeSuccess/ReportPacketConnHandshakeSuccess instead +func ReportHandshakeSuccess(reporter any) error { + if handshakeConn, isHandshakeConn := common.Cast[HandshakeSuccess](reporter); isHandshakeConn { + return handshakeConn.HandshakeSuccess() + } + return nil +} + +func ReportConnHandshakeSuccess(reporter any, conn net.Conn) error { + if handshakeConn, isHandshakeConn := common.Cast[ConnHandshakeSuccess](reporter); isHandshakeConn { + return handshakeConn.ConnHandshakeSuccess(conn) + } + if handshakeConn, isHandshakeConn := common.Cast[HandshakeSuccess](reporter); isHandshakeConn { + return handshakeConn.HandshakeSuccess() + } + return nil +} + +func ReportPacketConnHandshakeSuccess(reporter any, conn net.PacketConn) error { + if handshakeConn, isHandshakeConn := common.Cast[PacketConnHandshakeSuccess](reporter); isHandshakeConn { + return handshakeConn.PacketConnHandshakeSuccess(conn) + } + if handshakeConn, isHandshakeConn := common.Cast[HandshakeSuccess](reporter); isHandshakeConn { return handshakeConn.HandshakeSuccess() } return nil diff --git a/common/network/thread.go b/common/network/thread.go index 22063af4..4e47da7f 100644 --- a/common/network/thread.go +++ b/common/network/thread.go @@ -11,6 +11,7 @@ type ThreadUnsafeWriter interface { } // Deprecated: Use ReadWaiter interface instead. + type ThreadSafeReader interface { // Deprecated: Use ReadWaiter interface instead. ReadBufferThreadSafe() (buffer *buf.Buffer, err error) @@ -18,7 +19,6 @@ type ThreadSafeReader interface { // Deprecated: Use ReadWaiter interface instead. type ThreadSafePacketReader interface { - // Deprecated: Use ReadWaiter interface instead. ReadPacketThreadSafe() (buffer *buf.Buffer, addr M.Socksaddr, err error) } diff --git a/common/random/seed_go119.go b/common/random/seed_go119.go index c0da2efd..d339fcaf 100644 --- a/common/random/seed_go119.go +++ b/common/random/seed_go119.go @@ -20,6 +20,5 @@ func InitializeSeed() { func initializeSeed() { var seed int64 common.Must(binary.Read(rand.Reader, binary.LittleEndian, &seed)) - //goland:noinspection GoDeprecation mRand.Seed(seed) } diff --git a/common/rw/varint.go b/common/rw/varint.go index f9f5ca97..d19f1628 100644 --- a/common/rw/varint.go +++ b/common/rw/varint.go @@ -27,7 +27,6 @@ func ToByteReader(reader io.Reader) io.ByteReader { // Deprecated: Use binary.ReadUvarint instead. func ReadUVariant(reader io.Reader) (uint64, error) { - //goland:noinspection GoDeprecation return binary.ReadUvarint(ToByteReader(reader)) } diff --git a/common/udpnat/service.go b/common/udpnat/service.go index bdd917de..a5b37dbf 100644 --- a/common/udpnat/service.go +++ b/common/udpnat/service.go @@ -16,18 +16,23 @@ import ( "github.com/sagernet/sing/common/pipe" ) +// Deprecated: Use N.UDPConnectionHandler instead. +// +//nolint:staticcheck type Handler interface { N.UDPConnectionHandler E.Handler } type Service[K comparable] struct { - nat *cache.LruCache[K, *conn] - handler Handler + nat *cache.LruCache[K, *conn] + handler Handler + handlerEx N.UDPConnectionHandlerEx } +// Deprecated: Use NewEx instead. func New[K comparable](maxAge int64, handler Handler) *Service[K] { - return &Service[K]{ + service := &Service[K]{ nat: cache.New( cache.WithAge[K, *conn](maxAge), cache.WithUpdateAgeOnGet[K, *conn](), @@ -37,11 +42,27 @@ func New[K comparable](maxAge int64, handler Handler) *Service[K] { ), handler: handler, } + return service +} + +func NewEx[K comparable](maxAge int64, handler N.UDPConnectionHandlerEx) *Service[K] { + service := &Service[K]{ + nat: cache.New( + cache.WithAge[K, *conn](maxAge), + cache.WithUpdateAgeOnGet[K, *conn](), + cache.WithEvict[K, *conn](func(key K, conn *conn) { + conn.Close() + }), + ), + handlerEx: handler, + } + return service } func (s *Service[T]) WriteIsThreadUnsafe() { } +// Deprecated: don't use func (s *Service[T]) NewPacketDirect(ctx context.Context, key T, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) { s.NewContextPacket(ctx, key, buffer, metadata, func(natConn N.PacketConn) (context.Context, N.PacketWriter) { return ctx, &DirectBackWriter{conn, natConn} @@ -61,18 +82,30 @@ func (w *DirectBackWriter) Upstream() any { return w.Source } +// Deprecated: use NewPacketEx instead. func (s *Service[T]) NewPacket(ctx context.Context, key T, buffer *buf.Buffer, metadata M.Metadata, init func(natConn N.PacketConn) N.PacketWriter) { s.NewContextPacket(ctx, key, buffer, metadata, func(natConn N.PacketConn) (context.Context, N.PacketWriter) { return ctx, init(natConn) }) } +func (s *Service[T]) NewPacketEx(ctx context.Context, key T, buffer *buf.Buffer, source M.Socksaddr, destination M.Socksaddr, init func(natConn N.PacketConn) N.PacketWriter) { + s.NewContextPacketEx(ctx, key, buffer, source, destination, func(natConn N.PacketConn) (context.Context, N.PacketWriter) { + return ctx, init(natConn) + }) +} + +// Deprecated: Use NewPacketConnectionEx instead. func (s *Service[T]) NewContextPacket(ctx context.Context, key T, buffer *buf.Buffer, metadata M.Metadata, init func(natConn N.PacketConn) (context.Context, N.PacketWriter)) { + s.NewContextPacketEx(ctx, key, buffer, metadata.Source, metadata.Destination, init) +} + +func (s *Service[T]) NewContextPacketEx(ctx context.Context, key T, buffer *buf.Buffer, source M.Socksaddr, destination M.Socksaddr, init func(natConn N.PacketConn) (context.Context, N.PacketWriter)) { c, loaded := s.nat.LoadOrStore(key, func() *conn { c := &conn{ data: make(chan packet, 64), - localAddr: metadata.Source, - remoteAddr: metadata.Destination, + localAddr: source, + remoteAddr: destination, readDeadline: pipe.MakeDeadline(), } c.ctx, c.cancel = common.ContextWithCancelCause(ctx) @@ -81,26 +114,36 @@ func (s *Service[T]) NewContextPacket(ctx context.Context, key T, buffer *buf.Bu if !loaded { ctx, c.source = init(c) go func() { - err := s.handler.NewPacketConnection(ctx, c, metadata) - if err != nil { - s.handler.NewError(ctx, err) + if s.handlerEx != nil { + s.handlerEx.NewPacketConnectionEx(ctx, c, source, destination, func(err error) { + s.nat.Delete(key) + }) + } else { + //nolint:staticcheck + err := s.handler.NewPacketConnection(ctx, c, M.Metadata{ + Source: source, + Destination: destination, + }) + if err != nil { + s.handler.NewError(ctx, err) + } + c.Close() + s.nat.Delete(key) } - c.Close() - s.nat.Delete(key) }() } else { - c.localAddr = metadata.Source + c.localAddr = source } if common.Done(c.ctx) { s.nat.Delete(key) if !common.Done(ctx) { - s.NewContextPacket(ctx, key, buffer, metadata, init) + s.NewContextPacketEx(ctx, key, buffer, source, destination, init) } return } c.data <- packet{ data: buffer, - destination: metadata.Destination, + destination: destination, } } diff --git a/protocol/http/handshake.go b/protocol/http/handshake.go index 8a156ad4..955a7225 100644 --- a/protocol/http/handshake.go +++ b/protocol/http/handshake.go @@ -20,9 +20,18 @@ import ( "github.com/sagernet/sing/common/pipe" ) -type Handler = N.TCPConnectionHandler +// Deprecated: Use HandleConnectionEx instead. +func HandleConnection(ctx context.Context, conn net.Conn, reader *std_bufio.Reader, authenticator *auth.Authenticator, + //nolint:staticcheck + handler N.TCPConnectionHandler, metadata M.Metadata, +) error { + return HandleConnectionEx(ctx, conn, reader, authenticator, handler, nil, metadata.Source, nil) +} -func HandleConnection(ctx context.Context, conn net.Conn, reader *std_bufio.Reader, authenticator *auth.Authenticator, handler Handler, metadata M.Metadata) error { +func HandleConnectionEx(ctx context.Context, conn net.Conn, reader *std_bufio.Reader, authenticator *auth.Authenticator, + //nolint:staticcheck + handler N.TCPConnectionHandler, handlerEx N.TCPConnectionHandlerEx, source M.Socksaddr, onClose N.CloseHandlerFunc, +) error { for { request, err := ReadRequest(reader) if err != nil { @@ -68,7 +77,7 @@ func HandleConnection(ctx context.Context, conn net.Conn, reader *std_bufio.Read } if sourceAddress := SourceAddress(request); sourceAddress.IsValid() { - metadata.Source = sourceAddress + source = sourceAddress } if request.Method == "CONNECT" { @@ -81,9 +90,6 @@ func HandleConnection(ctx context.Context, conn net.Conn, reader *std_bufio.Read if err != nil { return E.Cause(err, "write http response") } - metadata.Protocol = "http" - metadata.Destination = destination - var requestConn net.Conn if reader.Buffered() > 0 { buffer := buf.NewSize(reader.Buffered()) @@ -95,75 +101,105 @@ func HandleConnection(ctx context.Context, conn net.Conn, reader *std_bufio.Read } else { requestConn = conn } - return handler.NewConnection(ctx, requestConn, metadata) + if handler != nil { + //nolint:staticcheck + return handler.NewConnection(ctx, requestConn, M.Metadata{Protocol: "http", Source: source, Destination: destination}) + } else { + handlerEx.NewConnectionEx(ctx, requestConn, source, destination, onClose) + return nil + } } - keepAlive := !(request.ProtoMajor == 1 && request.ProtoMinor == 0) && strings.TrimSpace(strings.ToLower(request.Header.Get("Proxy-Connection"))) == "keep-alive" - request.RequestURI = "" + err = handleHTTPConnection(ctx, handler, handlerEx, conn, request, source) + if err != nil { + return err + } + } +} - removeHopByHopHeaders(request.Header) - removeExtraHTTPHostPort(request) +func handleHTTPConnection( + ctx context.Context, + //nolint:staticcheck + handler N.TCPConnectionHandler, + handlerEx N.TCPConnectionHandlerEx, + conn net.Conn, + request *http.Request, source M.Socksaddr, +) error { + keepAlive := !(request.ProtoMajor == 1 && request.ProtoMinor == 0) && strings.TrimSpace(strings.ToLower(request.Header.Get("Proxy-Connection"))) == "keep-alive" + request.RequestURI = "" - if hostStr := request.Header.Get("Host"); hostStr != "" { - if hostStr != request.URL.Host { - request.Host = hostStr - } - } + removeHopByHopHeaders(request.Header) + removeExtraHTTPHostPort(request) - if request.URL.Scheme == "" || request.URL.Host == "" { - return responseWith(request, http.StatusBadRequest).Write(conn) + if hostStr := request.Header.Get("Host"); hostStr != "" { + if hostStr != request.URL.Host { + request.Host = hostStr } + } - var innerErr atomic.TypedValue[error] - httpClient := &http.Client{ - Transport: &http.Transport{ - DisableCompression: true, - DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { - metadata.Destination = M.ParseSocksaddr(address) - metadata.Protocol = "http" - input, output := pipe.Pipe() + if request.URL.Scheme == "" || request.URL.Host == "" { + return responseWith(request, http.StatusBadRequest).Write(conn) + } + + var innerErr atomic.TypedValue[error] + httpClient := &http.Client{ + Transport: &http.Transport{ + DisableCompression: true, + DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + input, output := pipe.Pipe() + if handler != nil { go func() { - hErr := handler.NewConnection(ctx, output, metadata) + //nolint:staticcheck + hErr := handler.NewConnection(ctx, output, M.Metadata{Protocol: "http", Source: source, Destination: M.ParseSocksaddr(address)}) if hErr != nil { innerErr.Store(hErr) common.Close(input, output) } }() - return input, nil - }, - }, - CheckRedirect: func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse + } else { + go handlerEx.NewConnectionEx(ctx, output, source, M.ParseSocksaddr(address), func(it error) { + innerErr.Store(it) + common.Close(input, output) + }) + } + return input, nil }, - } - requestCtx, cancel := context.WithCancel(ctx) - response, err := httpClient.Do(request.WithContext(requestCtx)) - if err != nil { - cancel() - return E.Errors(innerErr.Load(), err, responseWith(request, http.StatusBadGateway).Write(conn)) - } + }, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + } + defer httpClient.CloseIdleConnections() - removeHopByHopHeaders(response.Header) + requestCtx, cancel := context.WithCancel(ctx) + response, err := httpClient.Do(request.WithContext(requestCtx)) + if err != nil { + cancel() + return E.Errors(innerErr.Load(), err, responseWith(request, http.StatusBadGateway).Write(conn)) + } - if keepAlive { - response.Header.Set("Proxy-Connection", "keep-alive") - response.Header.Set("Connection", "keep-alive") - response.Header.Set("Keep-Alive", "timeout=4") - } + removeHopByHopHeaders(response.Header) - response.Close = !keepAlive + if keepAlive { + response.Header.Set("Proxy-Connection", "keep-alive") + response.Header.Set("Connection", "keep-alive") + response.Header.Set("Keep-Alive", "timeout=4") + } - err = response.Write(conn) - if err != nil { - cancel() - return E.Errors(innerErr.Load(), err) - } + response.Close = !keepAlive + err = response.Write(conn) + if err != nil { cancel() - if !keepAlive { - return conn.Close() - } + return E.Errors(innerErr.Load(), err) } + + cancel() + if !keepAlive { + return conn.Close() + } + + return nil } func removeHopByHopHeaders(header http.Header) { diff --git a/protocol/socks/handshake.go b/protocol/socks/handshake.go index 8ee2542b..7232eeac 100644 --- a/protocol/socks/handshake.go +++ b/protocol/socks/handshake.go @@ -19,11 +19,19 @@ import ( "github.com/sagernet/sing/protocol/socks/socks5" ) +// Deprecated: Use HandlerEx instead. +// +//nolint:staticcheck type Handler interface { N.TCPConnectionHandler N.UDPConnectionHandler } +type HandlerEx interface { + N.TCPConnectionHandlerEx + N.UDPConnectionHandlerEx +} + func ClientHandshake4(conn io.ReadWriter, command byte, destination M.Socksaddr, username string) (socks4.Response, error) { err := socks4.WriteRequest(conn, socks4.Request{ Command: command, @@ -96,18 +104,33 @@ func ClientHandshake5(conn io.ReadWriter, command byte, destination M.Socksaddr, return response, err } +// Deprecated: use HandleConnectionEx instead. func HandleConnection(ctx context.Context, conn net.Conn, authenticator *auth.Authenticator, handler Handler, metadata M.Metadata) error { return HandleConnection0(ctx, conn, std_bufio.NewReader(conn), authenticator, handler, metadata) } +// Deprecated: Use HandleConnectionEx instead. func HandleConnection0(ctx context.Context, conn net.Conn, reader *std_bufio.Reader, authenticator *auth.Authenticator, handler Handler, metadata M.Metadata) error { + return HandleConnectionEx(ctx, conn, reader, authenticator, handler, nil, metadata.Source, metadata.Destination, nil) +} + +func HandleConnectionEx( + ctx context.Context, conn net.Conn, reader *std_bufio.Reader, + authenticator *auth.Authenticator, + //nolint:staticcheck + handler Handler, + handlerEx HandlerEx, + source M.Socksaddr, destination M.Socksaddr, + onClose N.CloseHandlerFunc, +) error { version, err := reader.ReadByte() if err != nil { return err } switch version { case socks4.Version: - request, err := socks4.ReadRequest0(reader) + var request socks4.Request + request, err = socks4.ReadRequest0(reader) if err != nil { return err } @@ -115,28 +138,31 @@ func HandleConnection0(ctx context.Context, conn net.Conn, reader *std_bufio.Rea case socks4.CommandConnect: if authenticator != nil && !authenticator.Verify(request.Username, "") { err = socks4.WriteResponse(conn, socks4.Response{ - ReplyCode: socks4.ReplyCodeRejectedOrFailed, - Destination: request.Destination, + ReplyCode: socks4.ReplyCodeRejectedOrFailed, }) if err != nil { return err } return E.New("socks4: authentication failed, username=", request.Username) } - err = socks4.WriteResponse(conn, socks4.Response{ - ReplyCode: socks4.ReplyCodeGranted, - Destination: M.SocksaddrFromNet(conn.LocalAddr()), - }) - if err != nil { - return err + destination = request.Destination + if handlerEx != nil { + handlerEx.NewConnectionEx(auth.ContextWithUser(ctx, request.Username), NewLazyConn(conn, version), source, destination, onClose) + } else { + err = socks4.WriteResponse(conn, socks4.Response{ + ReplyCode: socks4.ReplyCodeGranted, + Destination: M.SocksaddrFromNet(conn.LocalAddr()), + }) + if err != nil { + return err + } + //nolint:staticcheck + return handler.NewConnection(auth.ContextWithUser(ctx, request.Username), conn, M.Metadata{Protocol: "socks4", Source: source, Destination: destination}) } - metadata.Protocol = "socks4" - metadata.Destination = request.Destination - return handler.NewConnection(auth.ContextWithUser(ctx, request.Username), conn, metadata) + return nil default: err = socks4.WriteResponse(conn, socks4.Response{ - ReplyCode: socks4.ReplyCodeRejectedOrFailed, - Destination: request.Destination, + ReplyCode: socks4.ReplyCodeRejectedOrFailed, }) if err != nil { return err @@ -144,7 +170,8 @@ func HandleConnection0(ctx context.Context, conn net.Conn, reader *std_bufio.Rea return E.New("socks4: unsupported command ", request.Command) } case socks5.Version: - authRequest, err := socks5.ReadAuthRequest0(reader) + var authRequest socks5.AuthRequest + authRequest, err = socks5.ReadAuthRequest0(reader) if err != nil { return err } @@ -169,7 +196,8 @@ func HandleConnection0(ctx context.Context, conn net.Conn, reader *std_bufio.Rea return err } if authMethod == socks5.AuthTypeUsernamePassword { - usernamePasswordAuthRequest, err := socks5.ReadUsernamePasswordAuthRequest(reader) + var usernamePasswordAuthRequest socks5.UsernamePasswordAuthRequest + usernamePasswordAuthRequest, err = socks5.ReadUsernamePasswordAuthRequest(reader) if err != nil { return err } @@ -188,49 +216,60 @@ func HandleConnection0(ctx context.Context, conn net.Conn, reader *std_bufio.Rea return E.New("socks5: authentication failed, username=", usernamePasswordAuthRequest.Username, ", password=", usernamePasswordAuthRequest.Password) } } - request, err := socks5.ReadRequest(reader) + var request socks5.Request + request, err = socks5.ReadRequest(reader) if err != nil { return err } switch request.Command { case socks5.CommandConnect: - err = socks5.WriteResponse(conn, socks5.Response{ - ReplyCode: socks5.ReplyCodeSuccess, - Bind: M.SocksaddrFromNet(conn.LocalAddr()), - }) - if err != nil { - return err + destination = request.Destination + if handlerEx != nil { + handlerEx.NewConnectionEx(ctx, NewLazyConn(conn, version), source, destination, onClose) + return nil + } else { + err = socks5.WriteResponse(conn, socks5.Response{ + ReplyCode: socks5.ReplyCodeSuccess, + Bind: M.SocksaddrFromNet(conn.LocalAddr()), + }) + if err != nil { + return err + } + //nolint:staticcheck + return handler.NewConnection(ctx, conn, M.Metadata{Protocol: "socks5", Source: source, Destination: destination}) } - metadata.Protocol = "socks5" - metadata.Destination = request.Destination - return handler.NewConnection(ctx, conn, metadata) case socks5.CommandUDPAssociate: var udpConn *net.UDPConn udpConn, err = net.ListenUDP(M.NetworkFromNetAddr("udp", M.AddrFromNet(conn.LocalAddr())), net.UDPAddrFromAddrPort(netip.AddrPortFrom(M.AddrFromNet(conn.LocalAddr()), 0))) if err != nil { return err } - defer udpConn.Close() - err = socks5.WriteResponse(conn, socks5.Response{ - ReplyCode: socks5.ReplyCodeSuccess, - Bind: M.SocksaddrFromNet(udpConn.LocalAddr()), - }) - if err != nil { - return err + if handlerEx == nil { + defer udpConn.Close() + err = socks5.WriteResponse(conn, socks5.Response{ + ReplyCode: socks5.ReplyCodeSuccess, + Bind: M.SocksaddrFromNet(udpConn.LocalAddr()), + }) + if err != nil { + return err + } + destination = request.Destination + associatePacketConn := NewAssociatePacketConn(bufio.NewServerPacketConn(udpConn), destination, conn) + var innerError error + done := make(chan struct{}) + go func() { + //nolint:staticcheck + innerError = handler.NewPacketConnection(ctx, associatePacketConn, M.Metadata{Protocol: "socks5", Source: source, Destination: destination}) + close(done) + }() + err = common.Error(io.Copy(io.Discard, conn)) + associatePacketConn.Close() + <-done + return E.Errors(innerError, err) + } else { + handlerEx.NewPacketConnectionEx(ctx, NewLazyAssociatePacketConn(bufio.NewServerPacketConn(udpConn), destination, conn), source, destination, onClose) + return nil } - metadata.Protocol = "socks5" - metadata.Destination = request.Destination - var innerError error - done := make(chan struct{}) - associatePacketConn := NewAssociatePacketConn(bufio.NewServerPacketConn(udpConn), request.Destination, conn) - go func() { - innerError = handler.NewPacketConnection(ctx, associatePacketConn, metadata) - close(done) - }() - err = common.Error(io.Copy(io.Discard, conn)) - associatePacketConn.Close() - <-done - return E.Errors(innerError, err) default: err = socks5.WriteResponse(conn, socks5.Response{ ReplyCode: socks5.ReplyCodeUnsupported, diff --git a/protocol/socks/lazy.go b/protocol/socks/lazy.go new file mode 100644 index 00000000..e6874752 --- /dev/null +++ b/protocol/socks/lazy.go @@ -0,0 +1,215 @@ +package socks + +import ( + "net" + + "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/bufio" + M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing/protocol/socks/socks4" + "github.com/sagernet/sing/protocol/socks/socks5" +) + +type LazyConn struct { + net.Conn + socksVersion byte + responseWritten bool +} + +func NewLazyConn(conn net.Conn, socksVersion byte) *LazyConn { + return &LazyConn{ + Conn: conn, + socksVersion: socksVersion, + } +} + +func (c *LazyConn) ConnHandshakeSuccess(conn net.Conn) error { + if c.responseWritten { + return nil + } + defer func() { + c.responseWritten = true + }() + switch c.socksVersion { + case socks4.Version: + return socks4.WriteResponse(c.Conn, socks4.Response{ + ReplyCode: socks4.ReplyCodeGranted, + Destination: M.SocksaddrFromNet(conn.LocalAddr()), + }) + case socks5.Version: + return socks5.WriteResponse(conn, socks5.Response{ + ReplyCode: socks5.ReplyCodeSuccess, + Bind: M.SocksaddrFromNet(conn.LocalAddr()), + }) + default: + panic("unknown socks version") + } +} + +func (c *LazyConn) HandshakeFailure(err error) error { + if c.responseWritten { + return nil + } + defer func() { + c.responseWritten = true + }() + switch c.socksVersion { + case socks4.Version: + return socks4.WriteResponse(c.Conn, socks4.Response{ + ReplyCode: socks4.ReplyCodeRejectedOrFailed, + }) + case socks5.Version: + return socks5.WriteResponse(c.Conn, socks5.Response{ + ReplyCode: socks5.ReplyCodeForError(err), + }) + default: + panic("unknown socks version") + } +} + +func (c *LazyConn) Read(p []byte) (n int, err error) { + if !c.responseWritten { + err = c.ConnHandshakeSuccess(c.Conn) + if err != nil { + return + } + } + return c.Conn.Read(p) +} + +func (c *LazyConn) Write(p []byte) (n int, err error) { + if !c.responseWritten { + err = c.ConnHandshakeSuccess(c.Conn) + if err != nil { + return + } + } + return c.Conn.Write(p) +} + +func (c *LazyConn) ReaderReplaceable() bool { + return c.responseWritten +} + +func (c *LazyConn) WriterReplaceable() bool { + return c.responseWritten +} + +func (c *LazyConn) Upstream() any { + return c.Conn +} + +type LazyAssociatePacketConn struct { + AssociatePacketConn + responseWritten bool +} + +func NewLazyAssociatePacketConn(conn net.Conn, remoteAddr M.Socksaddr, underlying net.Conn) *LazyAssociatePacketConn { + return &LazyAssociatePacketConn{ + AssociatePacketConn: AssociatePacketConn{ + AbstractConn: conn, + conn: bufio.NewExtendedConn(conn), + remoteAddr: remoteAddr, + underlying: underlying, + }, + } +} + +func (c *LazyAssociatePacketConn) HandshakeSuccess() error { + if c.responseWritten { + return nil + } + defer func() { + c.responseWritten = true + }() + return socks5.WriteResponse(c.underlying, socks5.Response{ + ReplyCode: socks5.ReplyCodeSuccess, + Bind: M.SocksaddrFromNet(c.conn.LocalAddr()), + }) +} + +func (c *LazyAssociatePacketConn) HandshakeFailure(err error) error { + if c.responseWritten { + return nil + } + defer func() { + c.responseWritten = true + c.conn.Close() + c.underlying.Close() + }() + return socks5.WriteResponse(c.underlying, socks5.Response{ + ReplyCode: socks5.ReplyCodeForError(err), + }) +} + +func (c *LazyAssociatePacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + if !c.responseWritten { + err = c.HandshakeSuccess() + if err != nil { + return + } + } + return c.AssociatePacketConn.ReadFrom(p) +} + +func (c *LazyAssociatePacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + if !c.responseWritten { + err = c.HandshakeSuccess() + if err != nil { + return + } + } + return c.AssociatePacketConn.WriteTo(p, addr) +} + +func (c *LazyAssociatePacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { + if !c.responseWritten { + err = c.HandshakeSuccess() + if err != nil { + return + } + } + return c.AssociatePacketConn.ReadPacket(buffer) +} + +func (c *LazyAssociatePacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { + if !c.responseWritten { + err := c.HandshakeSuccess() + if err != nil { + return err + } + } + return c.AssociatePacketConn.WritePacket(buffer, destination) +} + +func (c *LazyAssociatePacketConn) Read(p []byte) (n int, err error) { + if !c.responseWritten { + err = c.HandshakeSuccess() + if err != nil { + return + } + } + return c.AssociatePacketConn.Read(p) +} + +func (c *LazyAssociatePacketConn) Write(p []byte) (n int, err error) { + if !c.responseWritten { + err = c.HandshakeSuccess() + if err != nil { + return + } + } + return c.AssociatePacketConn.Write(p) +} + +func (c *LazyAssociatePacketConn) ReaderReplaceable() bool { + return c.responseWritten +} + +func (c *LazyAssociatePacketConn) WriterReplaceable() bool { + return c.responseWritten +} + +func (c *LazyAssociatePacketConn) Upstream() any { + return c.underlying +} diff --git a/protocol/socks/socks5/protocol.go b/protocol/socks/socks5/protocol.go index 29ff3db5..cb042270 100644 --- a/protocol/socks/socks5/protocol.go +++ b/protocol/socks/socks5/protocol.go @@ -1,8 +1,10 @@ package socks5 import ( + "errors" "io" "net/netip" + "syscall" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" @@ -37,6 +39,20 @@ const ( ReplyCodeAddressTypeUnsupported byte = 8 ) +func ReplyCodeForError(err error) byte { + if errors.Is(err, syscall.ENETUNREACH) { + return ReplyCodeNetworkUnreachable + } else if errors.Is(err, syscall.EHOSTUNREACH) { + return ReplyCodeHostUnreachable + } else if errors.Is(err, syscall.ECONNREFUSED) { + return ReplyCodeConnectionRefused + } else if errors.Is(err, syscall.EPERM) { + return ReplyCodeNotAllowed + } else { + return ReplyCodeFailure + } +} + // +----+----------+----------+ // |VER | NMETHODS | METHODS | // +----+----------+----------+