Skip to content

Commit

Permalink
Rewrite base transports
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Jun 24, 2024
1 parent af7b448 commit dabf25f
Show file tree
Hide file tree
Showing 4 changed files with 315 additions and 326 deletions.
237 changes: 0 additions & 237 deletions transport_base.go

This file was deleted.

68 changes: 52 additions & 16 deletions transport_tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ package dns
import (
"context"
"encoding/binary"
"net"
"io"
"net/netip"
"net/url"
"os"

"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
Expand All @@ -24,7 +26,9 @@ func init() {
}

type TCPTransport struct {
myTransportAdapter
name string
dialer N.Dialer
serverAddr M.Socksaddr
}

func NewTCPTransport(options TransportOptions) (*TCPTransport, error) {
Expand All @@ -43,29 +47,61 @@ func NewTCPTransport(options TransportOptions) (*TCPTransport, error) {
}

func newTCPTransport(options TransportOptions, serverAddr M.Socksaddr) *TCPTransport {
transport := &TCPTransport{
newAdapter(options, serverAddr, false),
return &TCPTransport{
name: options.Name,
dialer: options.Dialer,
serverAddr: serverAddr,
}
transport.handler = transport
return transport
}

func (t *TCPTransport) DialContext(ctx context.Context) (net.Conn, error) {
return t.dialer.DialContext(ctx, N.NetworkTCP, t.serverAddr)
func (t *TCPTransport) Name() string {
return t.name
}

func (t *TCPTransport) ReadMessage(conn net.Conn) (*dns.Msg, error) {
var length uint16
err := binary.Read(conn, binary.BigEndian, &length)
func (t *TCPTransport) Start() error {
return nil
}

func (t *TCPTransport) Reset() {
}

func (t *TCPTransport) Close() error {
return nil
}

func (t *TCPTransport) Raw() bool {
return true
}

func (t *TCPTransport) Exchange(ctx context.Context, message *dns.Msg) (*dns.Msg, error) {
conn, err := t.dialer.DialContext(ctx, N.NetworkTCP, t.serverAddr)
if err != nil {
return nil, err
}
defer conn.Close()
err = writeMessage(conn, message)
if err != nil {
return nil, err
}
return readMessage(conn)
}

func (t *TCPTransport) Lookup(ctx context.Context, domain string, strategy DomainStrategy) ([]netip.Addr, error) {
return nil, os.ErrInvalid
}

func readMessage(reader io.Reader) (*dns.Msg, error) {
var responseLen uint16
err := binary.Read(reader, binary.BigEndian, &responseLen)
if err != nil {
return nil, err
}
if length < 10 {
if responseLen < 10 {
return nil, dns.ErrShortRead
}
buffer := buf.NewSize(int(length))
buffer := buf.NewSize(int(responseLen))
defer buffer.Release()
_, err = buffer.ReadFullFrom(conn, int(length))
_, err = buffer.ReadFullFrom(reader, int(responseLen))
if err != nil {
return nil, err
}
Expand All @@ -74,7 +110,7 @@ func (t *TCPTransport) ReadMessage(conn net.Conn) (*dns.Msg, error) {
return &message, err
}

func (t *TCPTransport) WriteMessage(conn net.Conn, message *dns.Msg) error {
func writeMessage(writer io.Writer, message *dns.Msg) error {
requestLen := message.Len()
buffer := buf.NewSize(3 + requestLen)
defer buffer.Release()
Expand All @@ -86,5 +122,5 @@ func (t *TCPTransport) WriteMessage(conn net.Conn, message *dns.Msg) error {
return err
}
buffer.Truncate(2 + len(rawMessage))
return common.Error(conn.Write(buffer.Bytes()))
return common.Error(writer.Write(buffer.Bytes()))
}
Loading

0 comments on commit dabf25f

Please sign in to comment.