Skip to content

Commit

Permalink
Add a shorter timeout and close connections earlier
Browse files Browse the repository at this point in the history
Signed-off-by: Alex Ellis (OpenFaaS Ltd) <alexellis2@gmail.com>
  • Loading branch information
alexellis committed Sep 11, 2022
1 parent 7ff9d48 commit 3bbd128
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 32 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ To make the upstream address listen on all interfaces, use `0.0.0.0` instead of

The port for the from and to addresses do not need to match.

See also:
* `-t` - specify the dial timeout for an upstream host in the "to" field of the config file.
* `-v` - verbose logging - set to false to turn off logs of connections established and closed.

## License

This software is licensed MIT.
Expand Down
131 changes: 99 additions & 32 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,45 +28,62 @@ type Rule struct {

func main() {
var (
file string
file string
verbose bool
dialTimeout time.Duration
)

flag.StringVar(&file, "f", "", "Job to run or leave blank for job.yaml in current directory")

flag.BoolVar(&verbose, "v", true, "Verbose output for opened and closed connections")
flag.DurationVar(&dialTimeout, "t", time.Millisecond*1500, "Dial timeout")
flag.Parse()

if len(file) == 0 {
fmt.Fprintf(os.Stderr, "usage: mixctl -f rules.yaml\n")
os.Exit(1)
}

set := ForwardingSet{}
data, err := os.ReadFile(file)
if err != nil {
log.Fatalf("error reading file %s %s", file, err.Error())
fmt.Fprintf(os.Stderr, "error reading file %s %s", file, err.Error())
os.Exit(1)
}
if err = yaml.Unmarshal(data, &set); err != nil {
log.Fatalf("error parsing file %s %s", file, err.Error())
fmt.Fprintf(os.Stderr, "error parsing file %s %s", file, err.Error())
os.Exit(1)
}

if len(set.Rules) == 0 {
fmt.Fprintf(os.Stderr, "no rules found in file %s", file)
os.Exit(1)
}

fmt.Printf("mixctl by inlets..\n")
fmt.Printf("Starting mixctl by https://inlets.dev/\n\n")

wg := sync.WaitGroup{}
wg.Add(len(set.Rules))
for _, f := range set.Rules {

r := f
go func(rule *Rule) {
fmt.Printf("Forward (%s) from: %s to: %s\n", rule.Name, rule.From, rule.To)
for _, rule := range set.Rules {
fmt.Printf("Forward (%s) from: %s to: %s\n", rule.Name, rule.From, rule.To)
}
fmt.Println()

if err := forward(rule.Name, rule.From, rule.To); err != nil {
for _, rule := range set.Rules {
// Copy the value to avoid the loop variable being reused
r := rule
go func() {
if err := forward(r.Name, r.From, r.To, verbose, dialTimeout); err != nil {
log.Printf("error forwarding %s", err.Error())
os.Exit(1)
}

defer wg.Done()
}(&r)
}()
}
wg.Wait()

wg.Wait()
}

func forward(name, from string, to []string) error {
func forward(name, from string, to []string, verbose bool, dialTimeout time.Duration) error {
seed := time.Now().UnixNano()
rand.Seed(seed)

Expand All @@ -76,42 +93,92 @@ func forward(name, from string, to []string) error {
return fmt.Errorf("error listening on %s %s", from, err.Error())
}

defer l.Close()

for {
conn, err := l.Accept()
// accept a connection on the local port of the load balancer
local, err := l.Accept()
if err != nil {
return fmt.Errorf("error accepting connection %s", err.Error())
}

// pick randomly from the list of upstream servers
// available
index := rand.Intn(len(to))
upstream := to[index]

remote, err := net.Dial("tcp", to[index])
if err != nil {
return fmt.Errorf("error dialing %s %s", to[index], err.Error())
}
// A separate Goroutine means the loop can accept another
// incoming connection on the local address
go connect(local, upstream, from, verbose, dialTimeout)
}
}

go func() {
log.Printf("[%s] %s => %s",
from,
conn.RemoteAddr().String(),
remote.RemoteAddr().String())
if err := forwardConnection(conn, remote); err != nil && err.Error() != "done" {
log.Printf("error forwarding connection %s", err.Error())
}
}()
// connect dials the upstream address, then copies data
// between it and connection accepted on a local port
func connect(local net.Conn, upstreamAddr, from string, verbose bool, dialTimeout time.Duration) {
defer local.Close()

// If Dial is used on its own, then the timeout can be as long
// as 2 minutes on MacOS for an unreachable host
upstream, err := net.DialTimeout("tcp", upstreamAddr, dialTimeout)
if err != nil {
log.Printf("error dialing %s %s", upstreamAddr, err.Error())
return
}
defer upstream.Close()

if verbose {
log.Printf("Connected %s => %s (%s)",
from,
upstream.RemoteAddr().String(),
local.RemoteAddr().String())
}

ctx := context.Background()
if err := copy(ctx, local, upstream); err != nil && err.Error() != "done" {
log.Printf("error forwarding connection %s", err.Error())
}

if verbose {
log.Printf("Closed %s => %s (%s)",
from,
upstream.RemoteAddr().String(),
local.RemoteAddr().String())
}
}

func forwardConnection(from, to net.Conn) error {
errgrp, _ := errgroup.WithContext(context.Background())
// copy copies data between two connections using io.Copy
// and will exit when either connection is closed or runs
// into an error
func copy(ctx context.Context, from, to net.Conn) error {

ctx, cancel := context.WithCancel(ctx)
errgrp, _ := errgroup.WithContext(ctx)
errgrp.Go(func() error {
io.Copy(from, to)
cancel()

return fmt.Errorf("done")
})
errgrp.Go(func() error {
io.Copy(to, from)
cancel()

return fmt.Errorf("done")
})
errgrp.Go(func() error {
<-ctx.Done()

// This closes both ends of the connection as
// soon as possible.
from.Close()
to.Close()
return fmt.Errorf("done")
})

if err := errgrp.Wait(); err != nil {
return err
}

return errgrp.Wait()
return nil
}
5 changes: 5 additions & 0 deletions rules.example.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,8 @@ rules:
- 192.168.1.19:22
- 192.168.1.21:22
- 192.168.1.20:22

- name: remap-local-ssh-port
from: 127.0.0.1:2222
to:
- 127.0.0.1:22

0 comments on commit 3bbd128

Please sign in to comment.