Skip to content

Commit

Permalink
Merge pull request #47 from idohuber/main
Browse files Browse the repository at this point in the history
add timeouts to mysql and Kafka application discovery
  • Loading branch information
idohuber authored Jun 26, 2024
2 parents d539b65 + 2f03a6e commit db0e5e7
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 6 deletions.
5 changes: 3 additions & 2 deletions cmd/scan.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cmd

import (
"context"
"encoding/json"
"fmt"
"net"
Expand Down Expand Up @@ -79,7 +80,7 @@ func scan(cmd *cobra.Command, args []string) error {
for _, target := range scanResults {
// Perform service discovery for open TCP ports
for _, port := range target.TCPPorts {
discoveryResult, err := ScanTargets(target.Host, port)
discoveryResult, err := ScanTargets(context.Background(), target.Host, port)
if err != nil {
fmt.Fprintf(os.Stderr, "Error while discovering services on %s:%d: %s\n", target.Host, port, err)
continue
Expand Down Expand Up @@ -111,7 +112,7 @@ func scan(cmd *cobra.Command, args []string) error {
}
// Perform service discovery for open UDP ports
for _, port := range target.UDPPorts {
discoveryResult, err := ScanTargets(target.Host, port)
discoveryResult, err := ScanTargets(context.Background(), target.Host, port)
if err != nil {
fmt.Fprintf(os.Stderr, "Error while discovering services on %s:%d: %s\n", target.Host, port, err)
continue
Expand Down
7 changes: 5 additions & 2 deletions cmd/servicediscovery.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cmd

import (
"context"
"fmt"
"io"
"sync"
Expand All @@ -21,7 +22,7 @@ type DiscoveryResult struct {
Properties map[string]interface{}
}

func ScanTargets(host string, port int) (result DiscoveryResult, err error) {
func ScanTargets(ctx context.Context, host string, port int) (result DiscoveryResult, err error) {
var sessionWg sync.WaitGroup
var presentationWg sync.WaitGroup
var applicationWg sync.WaitGroup
Expand Down Expand Up @@ -105,6 +106,7 @@ func ScanTargets(host string, port int) (result DiscoveryResult, err error) {
if err != nil {
return
}

applicationLayerChan <- applicationDiscoveryResult
}(applicationDiscoveryItem)
}
Expand All @@ -125,10 +127,11 @@ func ScanTargets(host string, port int) (result DiscoveryResult, err error) {
break // Stop checking application layer protocol
}
}

break // Stop checking presentation layer protocols
}

}

if presentationDiscoveryResult == nil || !presentationDiscoveryResult.GetIsDetected() {
// Continue to discover application layer protocols
applicationLayerChan := make(chan applicationLayerDiscoveryResult)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,24 @@ func (k *KafkaDiscovery) Discover(sessionHandler servicediscovery.ISessionHandle
// Configure the producer
config := sarama.NewConfig()
config.Producer.RequiredAcks = sarama.WaitForAll
// add a more strict connection timeout
config.Net.ReadTimeout = 10 * time.Second
config.Net.WriteTimeout = 10 * time.Second
config.Producer.Retry.Max = 1
config.Producer.Timeout = 500 * time.Millisecond
config.Producer.Return.Successes = true

// Create a new SyncProducer
producer, err := sarama.NewSyncProducer(brokerList, config)

if err != nil {
return &KafkaDiscoveryResult{
isDetected: false,
isAuthenticated: true,
properties: nil, // Set properties to nil as it's not used in this case
}, err
}

defer func() {
if err := producer.Close(); err != nil {
log.Debugf("Failed to close Kafka producer: %s", err)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package applicationlayerdiscovery

import (
"context"
"fmt"
"io"
"log"
"strings"
"time"

"database/sql"

Expand Down Expand Up @@ -55,8 +57,10 @@ func (d *MysqlDiscovery) Discover(sessionHandler servicediscovery.ISessionHandle
}, err
}

// Ping the server
err = db.Ping()
// Ping the server with passed context()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
err = db.PingContext(ctx)
if err != nil {
if strings.Contains(err.Error(), "Access denied") {
return &MysqlDiscoveryResult{
Expand Down

0 comments on commit db0e5e7

Please sign in to comment.