Skip to content

Commit

Permalink
added network connection rotations
Browse files Browse the repository at this point in the history
  • Loading branch information
caffix committed Feb 2, 2024
1 parent 89eae36 commit 9361e3e
Showing 1 changed file with 91 additions and 46 deletions.
137 changes: 91 additions & 46 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package resolve

import (
"context"
"errors"
"fmt"
"net"
"runtime"
Expand All @@ -25,82 +26,123 @@ type resp struct {
Addr net.Addr
}

type connection struct {
conn net.PacketConn
done chan struct{}
}

type connections struct {
sync.Mutex
done chan struct{}
conns []net.PacketConn
conns []*connection
resps queue.Queue
nextWrite int
cpus int
}

func newConnections(cpus int, resps queue.Queue) *connections {
conns := &connections{
done: make(chan struct{}, 1),
resps: resps,
done: make(chan struct{}),
cpus: cpus,
}

conns.Lock()
defer conns.Unlock()

for i := 0; i < cpus; i++ {
if err := conns.Add(); err != nil {
conns.Close()
return nil
}
}
go conns.rotations()
return conns
}

func (c *connections) Close() {
select {
case <-c.done:
return
default:
func (r *connections) Close() {
r.Lock()
defer r.Unlock()

if r.conns != nil {
close(r.done)
for _, c := range r.conns {
close(c.done)
}
r.conns = nil
}
close(c.done)
for _, conn := range c.conns {
conn.Close()
}

func (r *connections) rotations() {
t := time.NewTicker(time.Minute)
defer t.Stop()

for {
select {
case <-r.done:
return
case <-t.C:
r.rotate()
}
}
}

func (c *connections) Next() net.PacketConn {
c.Lock()
defer c.Unlock()
func (r *connections) rotate() {
r.Lock()
defer r.Unlock()

cur := c.nextWrite
c.nextWrite = (c.nextWrite + 1) % len(c.conns)
return c.conns[cur]
for _, c := range r.conns {
go func(c *connection) {
t := time.NewTimer(10 * time.Second)
defer t.Stop()

<-t.C
close(c.done)
}(c)
}

r.conns = []*connection{}
for i := 0; i < r.cpus; i++ {
_ = r.Add()
}
}

func (c *connections) Add() error {
func (r *connections) Next() net.PacketConn {
r.Lock()
defer r.Unlock()

if r.conns == nil || len(r.conns) == 0 {
return nil
}

cur := r.nextWrite
r.nextWrite = (r.nextWrite + 1) % len(r.conns)
return r.conns[cur].conn
}

func (r *connections) Add() error {
var err error
var conn net.PacketConn

switch runtime.GOOS {
case "android":
fallthrough
case "linux":
fallthrough
case "darwin":
fallthrough
case "freebsd":
fallthrough
case "netbsd":
fallthrough
case "openbsd":
fallthrough
case "solaris":
conn, err = c.unixListenPacket()
default:
if runtime.GOOS == "linux" {
conn, err = r.linuxListenPacket()
} else {
conn, err = net.ListenPacket("udp", ":0")
}

if err == nil {
_ = conn.SetDeadline(time.Time{})
c.conns = append(c.conns, conn)
go c.responses(conn)
c := &connection{
conn: conn,
done: make(chan struct{}),
}
r.conns = append(r.conns, c)
go r.responses(c)
}
return err
}

func (c *connections) unixListenPacket() (net.PacketConn, error) {
func (r *connections) linuxListenPacket() (net.PacketConn, error) {
lc := net.ListenConfig{
Control: func(network, address string, c syscall.RawConn) error {
var operr error
Expand All @@ -116,43 +158,46 @@ func (c *connections) unixListenPacket() (net.PacketConn, error) {
}

laddr := ":0"
if len(c.conns) > 0 {
laddr = c.conns[0].LocalAddr().String()
if len(r.conns) > 0 {
laddr = r.conns[0].conn.LocalAddr().String()
}

return lc.ListenPacket(context.Background(), "udp", laddr)
}

func (c *connections) WriteMsg(msg *dns.Msg, addr net.Addr) error {
func (r *connections) WriteMsg(msg *dns.Msg, addr net.Addr) error {
var n int
var err error
var out []byte

if out, err = msg.Pack(); err == nil {
conn := c.Next()
err = errors.New("failed to obtain a connection")

_ = conn.SetWriteDeadline(time.Now().Add(500 * time.Millisecond))
if n, err = conn.WriteTo(out, addr); err == nil && n < len(out) {
err = fmt.Errorf("only wrote %d bytes of the %d byte message", n, len(out))
if conn := r.Next(); conn != nil {
_ = conn.SetWriteDeadline(time.Now().Add(500 * time.Millisecond))
if n, err = conn.WriteTo(out, addr); err == nil && n < len(out) {
err = fmt.Errorf("only wrote %d bytes of the %d byte message", n, len(out))
}
}
}
return err
}

func (c *connections) responses(conn net.PacketConn) {
func (r *connections) responses(c *connection) {
b := make([]byte, dns.DefaultMsgSize)

for {
select {
case <-c.done:
_ = c.conn.Close()
return
default:
}
if n, addr, err := conn.ReadFrom(b); err == nil && n >= headerSize {
if n, addr, err := c.conn.ReadFrom(b); err == nil && n >= headerSize {
m := new(dns.Msg)

if err := m.Unpack(b[:n]); err == nil && len(m.Question) > 0 {
c.resps.Append(&resp{
r.resps.Append(&resp{
Msg: m,
Addr: addr,
})
Expand Down

0 comments on commit 9361e3e

Please sign in to comment.