Skip to content

Commit

Permalink
use NewError
Browse files Browse the repository at this point in the history
  • Loading branch information
firefart committed Dec 12, 2023
1 parent 5b6cabe commit f288793
Show file tree
Hide file tree
Showing 7 changed files with 41 additions and 26 deletions.
5 changes: 4 additions & 1 deletion internal/cmd/memoryleak.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,10 @@ func MemoryLeak(opts MemoryleakOpts) error {
}
defer remote.Close()

channelNumber := helper.RandomChannelNumber()
channelNumber, err := helper.RandomChannelNumber()
if err != nil {
return fmt.Errorf("error on getting random channel number: %w", err)
}
channelBindRequest, err := internal.ChannelBindRequest(opts.Username, opts.Password, nonce, realm, opts.TargetHost, opts.TargetPort, channelNumber)
if err != nil {
return fmt.Errorf("error on generating ChannelBind request: %w", err)
Expand Down
10 changes: 8 additions & 2 deletions internal/cmd/udpscanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,10 @@ func snmpScan(opts UDPScannerOpts, ip netip.Addr, port uint16, community string)
}
defer remote.Close()

channelNumber := helper.RandomChannelNumber()
channelNumber, err := helper.RandomChannelNumber()
if err != nil {
return fmt.Errorf("error on getting random channel number: %w", err)
}
channelBindRequest, err := internal.ChannelBindRequest(opts.Username, opts.Password, nonce, realm, ip, port, channelNumber)
if err != nil {
return fmt.Errorf("error on generating ChannelBindRequest: %w", err)
Expand Down Expand Up @@ -180,7 +183,10 @@ func dnsScan(opts UDPScannerOpts, ip netip.Addr, port uint16, dnsName string) er
}
defer remote.Close()

channelNumber := helper.RandomChannelNumber()
channelNumber, err := helper.RandomChannelNumber()
if err != nil {
return fmt.Errorf("error on getting random channel number: %w", err)
}
channelBindRequest, err := internal.ChannelBindRequest(opts.Username, opts.Password, nonce, realm, ip, port, channelNumber)
if err != nil {
return fmt.Errorf("error on generating ChannelBindRequest: %w", err)
Expand Down
9 changes: 6 additions & 3 deletions internal/helper/helper.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package helper

import (
cryptorand "crypto/rand"
"encoding/binary"
"math/rand"
"net/netip"
Expand Down Expand Up @@ -45,17 +46,19 @@ func IsPrivateIP(ip netip.Addr) bool {
// RandomChannelNumber generates a random valid channel number
// 0x4000 through 0x7FFF: These values are the allowed channel
// numbers (16,383 possible values).
func RandomChannelNumber() []byte {
func RandomChannelNumber() ([]byte, error) {
token := make([]byte, 2)
valid := false
for !valid {
rand.Read(token)
if _, err := cryptorand.Read(token); err != nil {
return nil, err
}
if token[0] >= 0x40 &&
token[0] <= 0x7f {
break
}
}
return token
return token, nil
}

// PutUint16 is a helper function to convert an uint16 to a buffer
Expand Down
5 changes: 4 additions & 1 deletion internal/helper/helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ import "testing"

func TestRandomChannelNumber(t *testing.T) {
for i := 0; i < 1000; i++ {
channel := RandomChannelNumber()
channel, err := RandomChannelNumber()
if err != nil {
t.Fatal(err)
}
if channel[0] < 0x40 || channel[0] > 0x7F {
t.Fail()
}
Expand Down
12 changes: 6 additions & 6 deletions internal/socksimplementations/socksturntcphandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func (s *SocksTurnTCPHandler) Init(request socks.Request) (io.ReadWriteCloser, *
case socks.RequestAddressTypeIPv4, socks.RequestAddressTypeIPv6:
tmp, ok := netip.AddrFromSlice(request.DestinationAddress)
if !ok {
return nil, &socks.Error{Reason: socks.RequestReplyAddressTypeNotSupported, Err: fmt.Errorf("%02x is no ip address", request.DestinationAddress)}
return nil, socks.NewError(socks.RequestReplyAddressTypeNotSupported, fmt.Errorf("%02x is no ip address", request.DestinationAddress))
}
target = tmp
case socks.RequestAddressTypeDomainname:
Expand All @@ -47,25 +47,25 @@ func (s *SocksTurnTCPHandler) Init(request socks.Request) (io.ReadWriteCloser, *
// input is a hostname
names, err := helper.ResolveName(s.Ctx, string(request.DestinationAddress))
if err != nil {
return nil, &socks.Error{Reason: socks.RequestReplyHostUnreachable, Err: err}
return nil, socks.NewError(socks.RequestReplyHostUnreachable, err)
}
if len(names) == 0 {
return nil, &socks.Error{Reason: socks.RequestReplyHostUnreachable, Err: fmt.Errorf("%s could not be resolved", string(request.DestinationAddress))}
return nil, socks.NewError(socks.RequestReplyHostUnreachable, fmt.Errorf("%s could not be resolved", string(request.DestinationAddress)))
}
target = names[0]
}
default:
return nil, &socks.Error{Reason: socks.RequestReplyAddressTypeNotSupported, Err: fmt.Errorf("AddressType %#x not implemented", request.AddressType)}
return nil, socks.NewError(socks.RequestReplyAddressTypeNotSupported, fmt.Errorf("AddressType %#x not implemented", request.AddressType))
}

if s.DropNonPrivateRequests && !helper.IsPrivateIP(target) {
s.Log.Debugf("dropping non private connection to %s:%d", target.String(), request.DestinationPort)
return nil, &socks.Error{Reason: socks.RequestReplyHostUnreachable, Err: fmt.Errorf("dropping non private connection to %s:%d", target.String(), request.DestinationPort)}
return nil, socks.NewError(socks.RequestReplyHostUnreachable, fmt.Errorf("dropping non private connection to %s:%d", target.String(), request.DestinationPort))
}

controlConnection, dataConnection, err := internal.SetupTurnTCPConnection(s.Log, s.Server, s.UseTLS, s.Timeout, target, request.DestinationPort, s.TURNUsername, s.TURNPassword)
if err != nil {
return nil, &socks.Error{Reason: socks.RequestReplyHostUnreachable, Err: err}
return nil, socks.NewError(socks.RequestReplyHostUnreachable, err)
}

// we need to keep this connection open
Expand Down
23 changes: 13 additions & 10 deletions internal/socksimplementations/socksturnudphandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,46 +37,49 @@ func (s *SocksTurnUDPHandler) PreHandler(request socks.Request) (io.ReadWriteClo
case socks.RequestAddressTypeIPv4, socks.RequestAddressTypeIPv6:
tmp, ok := netip.AddrFromSlice(request.DestinationAddress)
if !ok {
return nil, &socks.Error{Reason: socks.RequestReplyAddressTypeNotSupported, Err: fmt.Errorf("%02x is no ip address", request.DestinationAddress)}
return nil, socks.NewError(socks.RequestReplyAddressTypeNotSupported, fmt.Errorf("%02x is no ip address", request.DestinationAddress))
}
target = tmp
case socks.RequestAddressTypeDomainname:
names, err := helper.ResolveName(s.Ctx, string(request.DestinationAddress))
if err != nil {
return nil, &socks.Error{Reason: socks.RequestReplyHostUnreachable, Err: err}
return nil, socks.NewError(socks.RequestReplyHostUnreachable, err)
}
if len(names) == 0 {
return nil, &socks.Error{Reason: socks.RequestReplyHostUnreachable, Err: fmt.Errorf("%s could not be resolved", string(request.DestinationAddress))}
return nil, socks.NewError(socks.RequestReplyHostUnreachable, fmt.Errorf("%s could not be resolved", string(request.DestinationAddress)))
}
target = names[0]
default:
return nil, &socks.Error{Reason: socks.RequestReplyAddressTypeNotSupported, Err: fmt.Errorf("AddressType %#x not implemented", request.AddressType)}
return nil, socks.NewError(socks.RequestReplyAddressTypeNotSupported, fmt.Errorf("AddressType %#x not implemented", request.AddressType))
}

if s.DropNonPrivateRequests && !helper.IsPrivateIP(target) {
s.Log.Debugf("dropping non private connection to %s:%d", target.String(), request.DestinationPort)
return nil, &socks.Error{Reason: socks.RequestReplyHostUnreachable, Err: fmt.Errorf("dropping non private connection to %s:%d", target.String(), request.DestinationPort)}
return nil, socks.NewError(socks.RequestReplyHostUnreachable, fmt.Errorf("dropping non private connection to %s:%d", target.String(), request.DestinationPort))
}

remote, realm, nonce, err := internal.SetupTurnConnection(s.Log, s.ConnectProtocol, s.Server, s.UseTLS, s.Timeout, target, request.DestinationPort, s.TURNUsername, s.TURNPassword)
if err != nil {
return nil, &socks.Error{Reason: socks.RequestReplyHostUnreachable, Err: err}
return nil, socks.NewError(socks.RequestReplyHostUnreachable, err)
}
defer remote.Close()

s.channelNumber = helper.RandomChannelNumber()
s.channelNumber, err = helper.RandomChannelNumber()
if err != nil {
return nil, socks.NewError(socks.RequestReplyGeneralFailure, fmt.Errorf("error on getting random channel number: %w", err))
}
channelBindRequest, err := internal.ChannelBindRequest(s.TURNUsername, s.TURNPassword, nonce, realm, target, request.DestinationPort, s.channelNumber)
if err != nil {
return nil, &socks.Error{Reason: socks.RequestReplyHostUnreachable, Err: fmt.Errorf("error on generating ChannelBindRequest: %w", err)}
return nil, socks.NewError(socks.RequestReplyHostUnreachable, fmt.Errorf("error on generating ChannelBindRequest: %w", err))
}
s.Log.Debugf("ChannelBind Request:\n%s", channelBindRequest.String())
channelBindResponse, err := channelBindRequest.SendAndReceive(s.Log, remote, s.Timeout)
if err != nil {
return nil, &socks.Error{Reason: socks.RequestReplyHostUnreachable, Err: fmt.Errorf("error on sending ChannelBindRequest: %w", err)}
return nil, socks.NewError(socks.RequestReplyHostUnreachable, fmt.Errorf("error on sending ChannelBindRequest: %w", err))
}
s.Log.Debugf("ChannelBind Response:\n%s", channelBindResponse.String())
if channelBindResponse.Header.MessageType.Class == internal.MsgTypeClassError {
return nil, &socks.Error{Reason: socks.RequestReplyGeneralFailure, Err: fmt.Errorf("error on ChannelBind: %s", channelBindResponse.GetErrorString())}
return nil, socks.NewError(socks.RequestReplyGeneralFailure, fmt.Errorf("error on ChannelBind: %s", channelBindResponse.GetErrorString()))
}
return remote, nil
}
Expand Down
3 changes: 0 additions & 3 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ package main

import (
"fmt"
"math/rand"
"net"
"net/netip"
"os"
Expand All @@ -33,8 +32,6 @@ func main() {
log.SetOutput(os.Stdout)
log.SetLevel(logrus.InfoLevel)

rand.Seed(time.Now().UnixNano())

app := &cli.App{
Name: "stunner",
Usage: "test turn servers for misconfigurations",
Expand Down

0 comments on commit f288793

Please sign in to comment.