-
Notifications
You must be signed in to change notification settings - Fork 21
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
5db1ab6
commit 36f0812
Showing
3 changed files
with
326 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
// Package probe provides utilities for checking the internet connectivity of TCP and UDP clients. | ||
package probe |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
package probe | ||
|
||
import ( | ||
"bufio" | ||
"context" | ||
"fmt" | ||
"net/http" | ||
"time" | ||
|
||
"github.com/database64128/shadowsocks-go/conn" | ||
"github.com/database64128/shadowsocks-go/zerocopy" | ||
) | ||
|
||
// TCPProbeConfig is the configuration for a TCP probe. | ||
type TCPProbeConfig struct { | ||
// Addr is the address of the HTTP test endpoint. | ||
Addr conn.Addr | ||
|
||
// EscapedPath is the escaped URL path of the HTTP test endpoint. | ||
EscapedPath string | ||
|
||
// Host specifies the value of the Host header field in the HTTP request. | ||
Host string | ||
} | ||
|
||
// NewProbe creates a new TCP probe from the configuration. | ||
func (c TCPProbeConfig) NewProbe() TCPProbe { | ||
return TCPProbe{ | ||
addr: c.Addr, | ||
req: fmt.Appendf(nil, "GET %s HTTP/1.1\r\nHost: %s\r\n\r\n", c.EscapedPath, c.Host), | ||
} | ||
} | ||
|
||
// TCPProbe tests the connectivity of a TCP client by sending an HTTP GET request | ||
// to the configured endpoint. The response status code must be 204 No Content. | ||
type TCPProbe struct { | ||
addr conn.Addr | ||
req []byte | ||
} | ||
|
||
// Probe runs the connectivity test. | ||
func (p TCPProbe) Probe(ctx context.Context, client zerocopy.TCPClient) error { | ||
dialer, _ := client.NewDialer() | ||
|
||
_, rw, err := dialer.Dial(ctx, p.addr, p.req) | ||
if err != nil { | ||
return fmt.Errorf("failed to create remote connection: %w", err) | ||
} | ||
defer rw.Close() | ||
|
||
cr := zerocopy.NewCopyReader(rw) | ||
br := bufio.NewReader(cr) | ||
|
||
resp, err := http.ReadResponse(br, nil) | ||
if err != nil { | ||
return fmt.Errorf("failed to read HTTP response: %w", err) | ||
} | ||
defer resp.Body.Close() | ||
|
||
if resp.StatusCode != http.StatusNoContent { | ||
return fmt.Errorf("unexpected HTTP status code: %d", resp.StatusCode) | ||
} | ||
|
||
return nil | ||
} | ||
|
||
// ProbeRTT runs the connectivity test and returns the round-trip time. | ||
func (p TCPProbe) ProbeRTT(ctx context.Context, client zerocopy.TCPClient) (rtt time.Duration, err error) { | ||
start := time.Now() | ||
err = p.Probe(ctx, client) | ||
return time.Since(start), err | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,252 @@ | ||
package probe | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"fmt" | ||
"math/rand/v2" | ||
"os" | ||
"slices" | ||
"time" | ||
|
||
"github.com/database64128/shadowsocks-go/conn" | ||
"github.com/database64128/shadowsocks-go/zerocopy" | ||
"go.uber.org/zap" | ||
"golang.org/x/net/dns/dnsmessage" | ||
) | ||
|
||
// UDPProbeConfig is the configuration for a UDP probe. | ||
type UDPProbeConfig struct { | ||
// Addr is the address of the UDP DNS server. | ||
Addr conn.Addr | ||
|
||
// Logger is the logger to use for the probe. | ||
Logger *zap.Logger | ||
} | ||
|
||
// NewProbe creates a new UDP probe from the configuration. | ||
func (c UDPProbeConfig) NewProbe() UDPProbe { | ||
return UDPProbe{ | ||
addr: c.Addr, | ||
logger: c.Logger, | ||
} | ||
} | ||
|
||
// UDPProbe tests the connectivity of a UDP client by sending a DNS query to the configured server. | ||
// The DNS server must support the HTTPS RR type and return a response indicating success. | ||
type UDPProbe struct { | ||
addr conn.Addr | ||
logger *zap.Logger | ||
} | ||
|
||
// Probe runs the connectivity test. | ||
func (p UDPProbe) Probe(ctx context.Context, client zerocopy.UDPClient) error { | ||
ctx, cancel := context.WithCancel(ctx) | ||
defer cancel() | ||
|
||
sessionInfo, session, err := client.NewSession(ctx) | ||
if err != nil { | ||
return fmt.Errorf("failed to create client session: %w", err) | ||
} | ||
defer session.Close() | ||
|
||
uc, _, err := sessionInfo.ListenConfig.ListenUDP(ctx, "udp", "") | ||
if err != nil { | ||
return fmt.Errorf("failed to create UDP socket: %w", err) | ||
} | ||
defer uc.Close() | ||
|
||
go func() { | ||
<-ctx.Done() | ||
_ = uc.SetReadDeadline(conn.ALongTimeAgo) | ||
}() | ||
|
||
b := make([]byte, session.MaxPacketSize) | ||
|
||
const domainName = "www.google.com." | ||
name, err := dnsmessage.NewName(domainName) | ||
if err != nil { | ||
return fmt.Errorf("failed to create DNS name: %w", err) | ||
} | ||
|
||
// maxDNSPacketSize is the maximum packet size to advertise in EDNS(0). | ||
// We use the same value as Go itself. | ||
const maxDNSPacketSize = 1232 | ||
var rh dnsmessage.ResourceHeader | ||
if err := rh.SetEDNS0(maxDNSPacketSize, dnsmessage.RCodeSuccess, false); err != nil { | ||
return fmt.Errorf("failed to set EDNS(0) options: %w", err) | ||
} | ||
|
||
const rrTypeHTTPS = 65 | ||
msg := dnsmessage.Message{ | ||
Header: dnsmessage.Header{ | ||
ID: uint16(rand.Uint64()), | ||
RecursionDesired: true, | ||
}, | ||
Questions: []dnsmessage.Question{ | ||
{ | ||
Name: name, | ||
Type: rrTypeHTTPS, | ||
Class: dnsmessage.ClassINET, | ||
}, | ||
}, | ||
Additionals: []dnsmessage.Resource{ | ||
{ | ||
Header: rh, | ||
Body: &dnsmessage.OPTResource{}, | ||
}, | ||
}, | ||
} | ||
sendBuf, err := msg.AppendPack(b[:sessionInfo.PackerHeadroom.Front]) | ||
if err != nil { | ||
return fmt.Errorf("failed to pack DNS message: %w", err) | ||
} | ||
payloadLen := len(sendBuf) - sessionInfo.PackerHeadroom.Front | ||
sendBuf = slices.Grow(sendBuf, sessionInfo.PackerHeadroom.Rear)[:len(sendBuf)+sessionInfo.PackerHeadroom.Rear] | ||
|
||
destAddrPort, packetStart, packetLen, err := session.Packer.PackInPlace(ctx, sendBuf, p.addr, sessionInfo.PackerHeadroom.Front, payloadLen) | ||
if err != nil { | ||
return fmt.Errorf("failed to pack DNS query packet: %w", err) | ||
} | ||
|
||
if _, err = uc.WriteToUDPAddrPort(sendBuf[packetStart:packetStart+packetLen], destAddrPort); err != nil { | ||
return fmt.Errorf("failed to send DNS query packet: %w", err) | ||
} | ||
|
||
for { | ||
n, _, flags, packetSourceAddress, err := uc.ReadMsgUDPAddrPort(b, nil) | ||
if err != nil { | ||
if errors.Is(err, os.ErrDeadlineExceeded) { | ||
return err | ||
} | ||
p.logger.Warn("Failed to read DNS response packet", | ||
zap.String("client", sessionInfo.Name), | ||
zap.Stringer("targetAddr", p.addr), | ||
zap.Stringer("packetSourceAddress", packetSourceAddress), | ||
zap.Int("packetLength", n), | ||
zap.Error(err), | ||
) | ||
continue | ||
} | ||
if err = conn.ParseFlagsForError(flags); err != nil { | ||
p.logger.Warn("Failed to read DNS response packet", | ||
zap.String("client", sessionInfo.Name), | ||
zap.Stringer("targetAddr", p.addr), | ||
zap.Stringer("packetSourceAddress", packetSourceAddress), | ||
zap.Int("packetLength", n), | ||
zap.Error(err), | ||
) | ||
continue | ||
} | ||
|
||
payloadSourceAddrPort, payloadStart, payloadLen, err := session.Unpacker.UnpackInPlace(b, packetSourceAddress, 0, n) | ||
if err != nil { | ||
p.logger.Warn("Failed to unpack DNS response packet", | ||
zap.String("client", sessionInfo.Name), | ||
zap.Stringer("targetAddr", p.addr), | ||
zap.Stringer("packetSourceAddress", packetSourceAddress), | ||
zap.Int("packetLength", n), | ||
zap.Error(err), | ||
) | ||
continue | ||
} | ||
if p.addr.IsIP() { | ||
if !conn.AddrPortMappedEqual(payloadSourceAddrPort, p.addr.IPPort()) { | ||
p.logger.Warn("Ignoring DNS response packet from unexpected source", | ||
zap.String("client", sessionInfo.Name), | ||
zap.Stringer("targetAddr", p.addr), | ||
zap.Stringer("payloadSourceAddrPort", payloadSourceAddrPort), | ||
) | ||
continue | ||
} | ||
} | ||
|
||
var parser dnsmessage.Parser | ||
|
||
header, err := parser.Start(b[payloadStart : payloadStart+payloadLen]) | ||
if err != nil { | ||
p.logger.Warn("Failed to parse DNS response header", | ||
zap.String("client", sessionInfo.Name), | ||
zap.Stringer("targetAddr", p.addr), | ||
zap.Stringer("payloadSourceAddrPort", payloadSourceAddrPort), | ||
zap.Int("payloadLength", payloadLen), | ||
zap.Error(err), | ||
) | ||
continue | ||
} | ||
if header.ID != msg.Header.ID { | ||
p.logger.Warn("Ignoring DNS response packet with unexpected transaction ID", | ||
zap.String("client", sessionInfo.Name), | ||
zap.Stringer("targetAddr", p.addr), | ||
zap.Stringer("payloadSourceAddrPort", payloadSourceAddrPort), | ||
zap.Uint16("receivedID", header.ID), | ||
zap.Uint16("expectedID", msg.Header.ID), | ||
) | ||
continue | ||
} | ||
if !header.Response { | ||
p.logger.Warn("Ignoring non-response DNS packet", | ||
zap.String("client", sessionInfo.Name), | ||
zap.Stringer("targetAddr", p.addr), | ||
zap.Stringer("payloadSourceAddrPort", payloadSourceAddrPort), | ||
) | ||
continue | ||
} | ||
if header.RCode != dnsmessage.RCodeSuccess { | ||
p.logger.Warn("Ignoring non-success DNS response", | ||
zap.String("client", sessionInfo.Name), | ||
zap.Stringer("targetAddr", p.addr), | ||
zap.Stringer("payloadSourceAddrPort", payloadSourceAddrPort), | ||
zap.Stringer("rcode", header.RCode), | ||
) | ||
continue | ||
} | ||
|
||
question, err := parser.Question() | ||
if err != nil { | ||
p.logger.Warn("Failed to parse question in DNS response packet", | ||
zap.String("client", sessionInfo.Name), | ||
zap.Stringer("targetAddr", p.addr), | ||
zap.Stringer("payloadSourceAddrPort", payloadSourceAddrPort), | ||
zap.Error(err), | ||
) | ||
continue | ||
} | ||
if question.Name.String() != domainName { | ||
p.logger.Warn("Ignoring DNS response packet with unexpected question name", | ||
zap.String("client", sessionInfo.Name), | ||
zap.Stringer("targetAddr", p.addr), | ||
zap.Stringer("payloadSourceAddrPort", payloadSourceAddrPort), | ||
zap.Stringer("receivedName", question.Name), | ||
) | ||
continue | ||
} | ||
if question.Type != rrTypeHTTPS { | ||
p.logger.Warn("Ignoring DNS response packet with unexpected question type", | ||
zap.String("client", sessionInfo.Name), | ||
zap.Stringer("targetAddr", p.addr), | ||
zap.Stringer("payloadSourceAddrPort", payloadSourceAddrPort), | ||
zap.Stringer("receivedType", question.Type), | ||
) | ||
continue | ||
} | ||
if question.Class != dnsmessage.ClassINET { | ||
p.logger.Warn("Ignoring DNS response packet with unexpected question class", | ||
zap.String("client", sessionInfo.Name), | ||
zap.Stringer("targetAddr", p.addr), | ||
zap.Stringer("payloadSourceAddrPort", payloadSourceAddrPort), | ||
zap.Stringer("receivedClass", question.Class), | ||
) | ||
continue | ||
} | ||
|
||
return nil | ||
} | ||
} | ||
|
||
// ProbeRTT runs the connectivity test and returns the round-trip time. | ||
func (p UDPProbe) ProbeRTT(ctx context.Context, client zerocopy.UDPClient) (rtt time.Duration, err error) { | ||
start := time.Now() | ||
err = p.Probe(ctx, client) | ||
return time.Since(start), err | ||
} |