Skip to content

Commit

Permalink
Crazy sekai overturns the small pond
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Oct 20, 2024
1 parent d59ac57 commit 863160f
Show file tree
Hide file tree
Showing 15 changed files with 580 additions and 130 deletions.
11 changes: 6 additions & 5 deletions common/bufio/addr_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,20 @@ import (

type AddrConn struct {
net.Conn
M.Metadata
Source M.Socksaddr
Destination M.Socksaddr
}

func (c *AddrConn) LocalAddr() net.Addr {
if c.Metadata.Destination.IsValid() {
return c.Metadata.Destination.TCPAddr()
if c.Destination.IsValid() {
return c.Destination.TCPAddr()
}
return c.Conn.LocalAddr()
}

func (c *AddrConn) RemoteAddr() net.Addr {
if c.Metadata.Source.IsValid() {
return c.Metadata.Source.TCPAddr()
if c.Source.IsValid() {
return c.Source.TCPAddr()
}
return c.Conn.RemoteAddr()
}
Expand Down
1 change: 0 additions & 1 deletion common/bufio/vectorised_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ func (w *SyscallVectorisedWriter) WriteVectorised(buffers []*buf.Buffer) error {
var innerErr unix.Errno
err := w.rawConn.Write(func(fd uintptr) (done bool) {
//nolint:staticcheck
//goland:noinspection GoDeprecation
_, _, innerErr = unix.Syscall(unix.SYS_WRITEV, fd, uintptr(unsafe.Pointer(&iovecList[0])), uintptr(len(iovecList)))
return innerErr != unix.EAGAIN && innerErr != unix.EWOULDBLOCK
})
Expand Down
1 change: 1 addition & 0 deletions common/exceptions/error.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
F "github.com/sagernet/sing/common/format"
)

// Deprecated: wtf is this?
type Handler interface {
NewError(ctx context.Context, err error)
}
Expand Down
1 change: 0 additions & 1 deletion common/exceptions/timeout.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ type TimeoutError interface {
func IsTimeout(err error) bool {
var netErr net.Error
if errors.As(err, &netErr) {
//goland:noinspection GoDeprecation
//nolint:staticcheck
return netErr.Temporary() && netErr.Timeout()
}
Expand Down
1 change: 1 addition & 0 deletions common/metadata/metadata.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package metadata

// Deprecated: wtf is this?
type Metadata struct {
Protocol string
Source Socksaddr
Expand Down
51 changes: 48 additions & 3 deletions common/network/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"io"
"net"
"sync"
"time"

"github.com/sagernet/sing/common"
Expand Down Expand Up @@ -70,8 +71,38 @@ type ExtendedConn interface {
net.Conn
}

type CloseHandlerFunc = func(it error)

func AppendClose(parent CloseHandlerFunc, onClose CloseHandlerFunc) CloseHandlerFunc {
if parent == nil {
return parent
} else if onClose == nil {
return onClose
}
return func(it error) {
onClose(it)
parent(it)
}
}

func OnceClose(onClose CloseHandlerFunc) CloseHandlerFunc {
var once sync.Once
return func(it error) {
once.Do(func() {
onClose(it)
})
}
}

// Deprecated: Use TCPConnectionHandlerEx instead.
type TCPConnectionHandler interface {
NewConnection(ctx context.Context, conn net.Conn, metadata M.Metadata) error
NewConnection(ctx context.Context, conn net.Conn,
//nolint:staticcheck
metadata M.Metadata) error
}

type TCPConnectionHandlerEx interface {
NewConnectionEx(ctx context.Context, conn net.Conn, source M.Socksaddr, destination M.Socksaddr, onClose CloseHandlerFunc)
}

type NetPacketConn interface {
Expand All @@ -85,12 +116,26 @@ type BindPacketConn interface {
net.Conn
}

// Deprecated: Use UDPHandlerEx instead.
type UDPHandler interface {
NewPacket(ctx context.Context, conn PacketConn, buffer *buf.Buffer, metadata M.Metadata) error
NewPacket(ctx context.Context, conn PacketConn, buffer *buf.Buffer,
//nolint:staticcheck
metadata M.Metadata) error
}

type UDPHandlerEx interface {
NewPacket(ctx context.Context, conn PacketConn, buffer *buf.Buffer, source M.Socksaddr, destination M.Socksaddr) error
}

// Deprecated: Use UDPConnectionHandlerEx instead.
type UDPConnectionHandler interface {
NewPacketConnection(ctx context.Context, conn PacketConn, metadata M.Metadata) error
NewPacketConnection(ctx context.Context, conn PacketConn,
//nolint:staticcheck
metadata M.Metadata) error
}

type UDPConnectionHandlerEx interface {
NewPacketConnectionEx(ctx context.Context, conn PacketConn, source M.Socksaddr, destination M.Socksaddr, onClose CloseHandlerFunc)
}

type CachedReader interface {
Expand Down
65 changes: 61 additions & 4 deletions common/network/handshake.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package network

import (
"io"
"net"

"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
)
Expand All @@ -13,17 +16,71 @@ type HandshakeSuccess interface {
HandshakeSuccess() error
}

func ReportHandshakeFailure(conn any, err error) error {
if handshakeConn, isHandshakeConn := common.Cast[HandshakeFailure](conn); isHandshakeConn {
type ConnHandshakeSuccess interface {
ConnHandshakeSuccess(conn net.Conn) error
}

type PacketConnHandshakeSuccess interface {
PacketConnHandshakeSuccess(conn net.PacketConn) error
}

func ReportHandshakeFailure(reporter any, err error) error {
if handshakeConn, isHandshakeConn := common.Cast[HandshakeFailure](reporter); isHandshakeConn {
return E.Append(err, handshakeConn.HandshakeFailure(err), func(err error) error {
return E.Cause(err, "write handshake failure")
})
}
return nil
}

func CloseOnHandshakeFailure(reporter any, onClose CloseHandlerFunc, err error) error {
if err != nil {
if handshakeConn, isHandshakeConn := common.Cast[HandshakeFailure](reporter); isHandshakeConn {
err = E.Append(err, handshakeConn.HandshakeFailure(err), func(err error) error {
return E.Cause(err, "write handshake failure")
})
} else {
if tcpConn, isTCPConn := common.Cast[interface {
SetLinger(sec int) error
}](reporter); isTCPConn {
tcpConn.SetLinger(0)
}
if closer, isCloser := reporter.(io.Closer); isCloser {
err = E.Append(err, closer.Close(), func(err error) error {
return E.Cause(err, "close")
})
}
}
}
if onClose != nil {
onClose(err)
}
return err
}

func ReportHandshakeSuccess(conn any) error {
if handshakeConn, isHandshakeConn := common.Cast[HandshakeSuccess](conn); isHandshakeConn {
// Deprecated: use ReportConnHandshakeSuccess/ReportPacketConnHandshakeSuccess instead
func ReportHandshakeSuccess(reporter any) error {
if handshakeConn, isHandshakeConn := common.Cast[HandshakeSuccess](reporter); isHandshakeConn {
return handshakeConn.HandshakeSuccess()
}
return nil
}

func ReportConnHandshakeSuccess(reporter any, conn net.Conn) error {
if handshakeConn, isHandshakeConn := common.Cast[ConnHandshakeSuccess](reporter); isHandshakeConn {
return handshakeConn.ConnHandshakeSuccess(conn)
}
if handshakeConn, isHandshakeConn := common.Cast[HandshakeSuccess](reporter); isHandshakeConn {
return handshakeConn.HandshakeSuccess()
}
return nil
}

func ReportPacketConnHandshakeSuccess(reporter any, conn net.PacketConn) error {
if handshakeConn, isHandshakeConn := common.Cast[PacketConnHandshakeSuccess](reporter); isHandshakeConn {
return handshakeConn.PacketConnHandshakeSuccess(conn)
}
if handshakeConn, isHandshakeConn := common.Cast[HandshakeSuccess](reporter); isHandshakeConn {
return handshakeConn.HandshakeSuccess()
}
return nil
Expand Down
2 changes: 1 addition & 1 deletion common/network/thread.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@ type ThreadUnsafeWriter interface {
}

// Deprecated: Use ReadWaiter interface instead.

type ThreadSafeReader interface {
// Deprecated: Use ReadWaiter interface instead.
ReadBufferThreadSafe() (buffer *buf.Buffer, err error)
}

// Deprecated: Use ReadWaiter interface instead.
type ThreadSafePacketReader interface {
// Deprecated: Use ReadWaiter interface instead.
ReadPacketThreadSafe() (buffer *buf.Buffer, addr M.Socksaddr, err error)
}

Expand Down
1 change: 0 additions & 1 deletion common/random/seed_go119.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,5 @@ func InitializeSeed() {
func initializeSeed() {
var seed int64
common.Must(binary.Read(rand.Reader, binary.LittleEndian, &seed))
//goland:noinspection GoDeprecation
mRand.Seed(seed)
}
1 change: 0 additions & 1 deletion common/rw/varint.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ func ToByteReader(reader io.Reader) io.ByteReader {

// Deprecated: Use binary.ReadUvarint instead.
func ReadUVariant(reader io.Reader) (uint64, error) {
//goland:noinspection GoDeprecation
return binary.ReadUvarint(ToByteReader(reader))
}

Expand Down
69 changes: 56 additions & 13 deletions common/udpnat/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,23 @@ import (
"github.com/sagernet/sing/common/pipe"
)

// Deprecated: Use N.UDPConnectionHandler instead.
//
//nolint:staticcheck
type Handler interface {
N.UDPConnectionHandler
E.Handler
}

type Service[K comparable] struct {
nat *cache.LruCache[K, *conn]
handler Handler
nat *cache.LruCache[K, *conn]
handler Handler
handlerEx N.UDPConnectionHandlerEx
}

// Deprecated: Use NewEx instead.
func New[K comparable](maxAge int64, handler Handler) *Service[K] {
return &Service[K]{
service := &Service[K]{
nat: cache.New(
cache.WithAge[K, *conn](maxAge),
cache.WithUpdateAgeOnGet[K, *conn](),
Expand All @@ -37,11 +42,27 @@ func New[K comparable](maxAge int64, handler Handler) *Service[K] {
),
handler: handler,
}
return service
}

func NewEx[K comparable](maxAge int64, handler N.UDPConnectionHandlerEx) *Service[K] {
service := &Service[K]{
nat: cache.New(
cache.WithAge[K, *conn](maxAge),
cache.WithUpdateAgeOnGet[K, *conn](),
cache.WithEvict[K, *conn](func(key K, conn *conn) {
conn.Close()
}),
),
handlerEx: handler,
}
return service
}

func (s *Service[T]) WriteIsThreadUnsafe() {
}

// Deprecated: don't use
func (s *Service[T]) NewPacketDirect(ctx context.Context, key T, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) {
s.NewContextPacket(ctx, key, buffer, metadata, func(natConn N.PacketConn) (context.Context, N.PacketWriter) {
return ctx, &DirectBackWriter{conn, natConn}
Expand All @@ -61,18 +82,30 @@ func (w *DirectBackWriter) Upstream() any {
return w.Source
}

// Deprecated: use NewPacketEx instead.
func (s *Service[T]) NewPacket(ctx context.Context, key T, buffer *buf.Buffer, metadata M.Metadata, init func(natConn N.PacketConn) N.PacketWriter) {
s.NewContextPacket(ctx, key, buffer, metadata, func(natConn N.PacketConn) (context.Context, N.PacketWriter) {
return ctx, init(natConn)
})
}

func (s *Service[T]) NewPacketEx(ctx context.Context, key T, buffer *buf.Buffer, source M.Socksaddr, destination M.Socksaddr, init func(natConn N.PacketConn) N.PacketWriter) {
s.NewContextPacketEx(ctx, key, buffer, source, destination, func(natConn N.PacketConn) (context.Context, N.PacketWriter) {
return ctx, init(natConn)
})
}

// Deprecated: Use NewPacketConnectionEx instead.
func (s *Service[T]) NewContextPacket(ctx context.Context, key T, buffer *buf.Buffer, metadata M.Metadata, init func(natConn N.PacketConn) (context.Context, N.PacketWriter)) {
s.NewContextPacketEx(ctx, key, buffer, metadata.Source, metadata.Destination, init)
}

func (s *Service[T]) NewContextPacketEx(ctx context.Context, key T, buffer *buf.Buffer, source M.Socksaddr, destination M.Socksaddr, init func(natConn N.PacketConn) (context.Context, N.PacketWriter)) {
c, loaded := s.nat.LoadOrStore(key, func() *conn {
c := &conn{
data: make(chan packet, 64),
localAddr: metadata.Source,
remoteAddr: metadata.Destination,
localAddr: source,
remoteAddr: destination,
readDeadline: pipe.MakeDeadline(),
}
c.ctx, c.cancel = common.ContextWithCancelCause(ctx)
Expand All @@ -81,26 +114,36 @@ func (s *Service[T]) NewContextPacket(ctx context.Context, key T, buffer *buf.Bu
if !loaded {
ctx, c.source = init(c)
go func() {
err := s.handler.NewPacketConnection(ctx, c, metadata)
if err != nil {
s.handler.NewError(ctx, err)
if s.handlerEx != nil {
s.handlerEx.NewPacketConnectionEx(ctx, c, source, destination, func(err error) {
s.nat.Delete(key)
})
} else {
//nolint:staticcheck
err := s.handler.NewPacketConnection(ctx, c, M.Metadata{
Source: source,
Destination: destination,
})
if err != nil {
s.handler.NewError(ctx, err)
}
c.Close()
s.nat.Delete(key)
}
c.Close()
s.nat.Delete(key)
}()
} else {
c.localAddr = metadata.Source
c.localAddr = source
}
if common.Done(c.ctx) {
s.nat.Delete(key)
if !common.Done(ctx) {
s.NewContextPacket(ctx, key, buffer, metadata, init)
s.NewContextPacketEx(ctx, key, buffer, source, destination, init)
}
return
}
c.data <- packet{
data: buffer,
destination: metadata.Destination,
destination: destination,
}
}

Expand Down
Loading

0 comments on commit 863160f

Please sign in to comment.