diff --git a/go.mod b/go.mod index 6292663..a1676d9 100644 --- a/go.mod +++ b/go.mod @@ -3,13 +3,16 @@ module github.com/projectdiscovery/pd-agent go 1.24.1 require ( + github.com/google/gopacket v1.1.19 github.com/projectdiscovery/gcache v0.0.0-20241015120333-12546c6e3f4c github.com/projectdiscovery/goflags v0.1.74 github.com/projectdiscovery/gologger v1.1.61 + github.com/projectdiscovery/mapcidr v1.1.97 github.com/projectdiscovery/utils v0.7.3 github.com/rs/xid v1.6.0 github.com/shirou/gopsutil/v3 v3.23.7 github.com/tidwall/gjson v1.18.0 + golang.org/x/net v0.47.0 golang.org/x/sys v0.38.0 k8s.io/apimachinery v0.34.2 k8s.io/client-go v0.34.2 @@ -77,7 +80,6 @@ require ( go4.org v0.0.0-20230225012048-214862532bf5 // indirect golang.org/x/exp v0.0.0-20250813145105-42675adae3e6 // indirect golang.org/x/mod v0.29.0 // indirect - golang.org/x/net v0.47.0 // indirect golang.org/x/oauth2 v0.27.0 // indirect golang.org/x/sync v0.18.0 // indirect golang.org/x/term v0.37.0 // indirect diff --git a/go.sum b/go.sum index 52b875b..07c613a 100644 --- a/go.sum +++ b/go.sum @@ -95,6 +95,8 @@ github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= +github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/pprof v0.0.0-20190515194954-54271f7e092f/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= @@ -185,6 +187,8 @@ github.com/projectdiscovery/goflags v0.1.74 
h1:n85uTRj5qMosm0PFBfsvOL24I7TdWRcWq github.com/projectdiscovery/goflags v0.1.74/go.mod h1:UMc9/7dFz2oln+10tv6cy+7WZKTHf9UGhaNkF95emh4= github.com/projectdiscovery/gologger v1.1.61 h1:+jJ0Z0x6X9s69IRjbtsnOfMD8YTFTVADHMKFNu6dUGg= github.com/projectdiscovery/gologger v1.1.61/go.mod h1:EfuwZ1lQX7kH4rgNo0nzk5XPh2j2gpYEQUi9tkoJDJw= +github.com/projectdiscovery/mapcidr v1.1.97 h1:7FkxNNVXp+m1rIu5Nv/2SrF9k4+LwP8QuWs2puwy+2w= +github.com/projectdiscovery/mapcidr v1.1.97/go.mod h1:9dgTJh1SP02gYZdpzMjm6vtYFkEHQHoTyaVNvaeJ7lA= github.com/projectdiscovery/utils v0.7.3 h1:kX+77AA58yK6EZgkTRJEnK9V/7AZYzlXdcu/o/kJhFs= github.com/projectdiscovery/utils v0.7.3/go.mod h1:uDdQ3/VWomai98l+a3Ye/srDXdJ4xUIar/mSXlQ9gBM= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= @@ -284,6 +288,7 @@ golang.org/x/lint v0.0.0-20190909230951-414d861bb4ac/go.mod h1:6SW0HCj/g11FgYtHl golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= golang.org/x/lint v0.0.0-20191125180803-fdd1cda4f05f/go.mod h1:5qLYkcX4OjUUV8bRuDixDT3tpyyb+LUpUlRWLxfhWrs= golang.org/x/lint v0.0.0-20200130185559-910be7a94367/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= +golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/mobile v0.0.0-20190312151609-d3739f865fa6/go.mod h1:z+o9i4GpDbdi3rU15maQ/Ox0txvL9dWGYEHz965HBQE= golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCcRqshq8CkpyQDoeVncDDYHnLhea+o= golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= diff --git a/pkg/peerdiscovery/arp/arp.go b/pkg/peerdiscovery/arp/arp.go new file mode 100644 index 0000000..f97577f --- /dev/null +++ b/pkg/peerdiscovery/arp/arp.go @@ -0,0 +1,183 @@ +package arp + +import ( + "context" + "fmt" + "net" + "time" + + "github.com/projectdiscovery/mapcidr" + 
"github.com/projectdiscovery/pd-agent/pkg/peerdiscovery/common" + mapsutil "github.com/projectdiscovery/utils/maps" + syncutil "github.com/projectdiscovery/utils/sync" +) + +// Peer represents a discovered ARP peer +type Peer struct { + IP net.IP + MAC net.HardwareAddr +} + +// DiscoverPeers retrieves all ARP peers by first reading the local ARP table, +// then scanning the network in parallel to discover additional peers. +func DiscoverPeers(ctx context.Context) ([]Peer, error) { + peers := mapsutil.NewSyncLockMap[string, *Peer]() + + // Read local ARP table + localPeers, err := readLocalARPTable() + if err != nil { + return nil, fmt.Errorf("failed to read local ARP table: %w", err) + } + + for _, peer := range localPeers { + key := peer.IP.String() + peerCopy := peer + _ = peers.Set(key, &peerCopy) + } + + // Get /24 network ranges from local interfaces + networks, err := common.GetLocalNetworks24() + if err != nil { + return nil, fmt.Errorf("failed to get local networks: %w", err) + } + + // Scan networks sequentially (no hurry) + for _, network := range networks { + select { + case <-ctx.Done(): + goto done + default: + } + + discovered, err := scanNetwork24(ctx, network) + if err != nil { + continue + } + + for _, peer := range discovered { + key := peer.IP.String() + if _, exists := peers.Get(key); !exists { + peerCopy := peer + _ = peers.Set(key, &peerCopy) + } + } + } + +done: + // Convert map to slice + var result []Peer + _ = peers.Iterate(func(key string, peer *Peer) error { + if peer != nil { + result = append(result, *peer) + } + return nil + }) + + return result, nil +} + +// scanNetwork24 scans a /24 network range to discover ARP peers +// Uses UDP connections to trigger OS ARP requests and monitors the ARP table +func scanNetwork24(ctx context.Context, network *net.IPNet) ([]Peer, error) { + // Verify it's a /24 network + ones, bits := network.Mask.Size() + if ones != 24 || bits != 32 { + return nil, fmt.Errorf("network %s is not a /24 network", 
network.String()) + } + + // Get initial ARP table state + initialPeers, err := readLocalARPTable() + if err != nil { + return nil, fmt.Errorf("failed to read initial ARP table: %w", err) + } + + initialSet := make(map[string]struct{}) + for _, peer := range initialPeers { + if network.Contains(peer.IP) { + initialSet[peer.IP.String()] = struct{}{} + } + } + + // Expand CIDR to get all IPs in /24 range + cidrStr := network.String() + ips, err := mapcidr.IPAddresses(cidrStr) + if err != nil { + return nil, fmt.Errorf("failed to expand CIDR %s: %w", cidrStr, err) + } + + if len(ips) == 0 { + return []Peer{}, nil + } + + // Use adaptive waitgroup with low parallelism (no hurry) + awg, err := syncutil.New(syncutil.WithSize(5)) + if err != nil { + return nil, fmt.Errorf("failed to create adaptive waitgroup: %w", err) + } + + // Trigger ARP resolution for each IP using UDP connections + for _, ipStr := range ips { + select { + case <-ctx.Done(): + goto done + default: + } + + ip := net.ParseIP(ipStr) + if ip == nil { + continue + } + + // Skip network and broadcast addresses + if common.IsNetworkOrBroadcast(ip, network) { + continue + } + + awg.Add() + go func(targetIP net.IP) { + defer awg.Done() + + // Send UDP packet to trigger ARP resolution + // The OS will handle the ARP request for us + conn, err := net.DialTimeout("udp", net.JoinHostPort(targetIP.String(), "12345"), 50*time.Millisecond) + if err != nil { + // Connection will fail, but ARP resolution may occur + return + } + if conn != nil { + _ = conn.Close() + } + }(ip) + + // Small delay between requests to avoid overwhelming + time.Sleep(10 * time.Millisecond) + } + +done: + awg.Wait() + + // Wait for OS ARP requests to complete and ARP table to update + // Give it time since we're not in a hurry + time.Sleep(2 * time.Second) + + // Read ARP table again to find new entries + finalPeers, err := readLocalARPTable() + if err != nil { + return nil, fmt.Errorf("failed to read final ARP table: %w", err) + } + + // 
Find newly discovered peers + var discovered []Peer + for _, peer := range finalPeers { + if !network.Contains(peer.IP) { + continue + } + + // Check if this is a new peer + if _, exists := initialSet[peer.IP.String()]; !exists { + discovered = append(discovered, peer) + } + } + + return discovered, nil +} diff --git a/pkg/peerdiscovery/arp/arp_unix.go b/pkg/peerdiscovery/arp/arp_unix.go new file mode 100644 index 0000000..ae63a1f --- /dev/null +++ b/pkg/peerdiscovery/arp/arp_unix.go @@ -0,0 +1,145 @@ +//go:build !windows + +package arp + +import ( + "bufio" + "fmt" + "net" + "os" + "os/exec" + "strings" + + osutils "github.com/projectdiscovery/utils/os" +) + +// readLocalARPTable reads the local ARP table (Linux and macOS) +func readLocalARPTable() ([]Peer, error) { + if osutils.IsLinux() { + return readLinuxARPTable() + } else if osutils.IsOSX() { + return readDarwinARPTable() + } + return nil, fmt.Errorf("unsupported OS") +} + +// readLinuxARPTable reads ARP table from /proc/net/arp +func readLinuxARPTable() ([]Peer, error) { + data, err := os.ReadFile("/proc/net/arp") + if err != nil { + return nil, err + } + + var peers []Peer + scanner := bufio.NewScanner(strings.NewReader(string(data))) + + // Skip header line + if !scanner.Scan() { + return peers, nil + } + + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" { + continue + } + + fields := strings.Fields(line) + if len(fields) < 6 { + continue + } + + // Format: IP address HW type Flags HW address Mask Device + ipStr := fields[0] + macStr := fields[3] + + // Skip incomplete entries + if macStr == "00:00:00:00:00:00" || macStr == "" { + continue + } + + ip := net.ParseIP(ipStr) + if ip == nil || ip.To4() == nil { + continue + } + + mac, err := net.ParseMAC(macStr) + if err != nil { + continue + } + + peers = append(peers, Peer{ + IP: ip, + MAC: mac, + }) + } + + return peers, scanner.Err() +} + +// readDarwinARPTable reads ARP table using 'arp -a' command on macOS +func 
readDarwinARPTable() ([]Peer, error) { + cmd := exec.Command("arp", "-a") + output, err := cmd.Output() + if err != nil { + return nil, fmt.Errorf("failed to execute arp -a: %w", err) + } + + var peers []Peer + scanner := bufio.NewScanner(strings.NewReader(string(output))) + + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" { + continue + } + + // macOS arp -a format: "hostname (192.168.1.1) at aa:bb:cc:dd:ee:ff [ethernet] on en0" + // or: "? (192.168.1.1) at aa:bb:cc:dd:ee:ff [ethernet] on en0" + + // Extract IP address (between parentheses) + ipStart := strings.Index(line, "(") + ipEnd := strings.Index(line, ")") + if ipStart == -1 || ipEnd == -1 || ipStart >= ipEnd { + continue + } + ipStr := line[ipStart+1 : ipEnd] + + // Extract MAC address (after "at ") + atIndex := strings.Index(line, " at ") + if atIndex == -1 { + continue + } + macStart := atIndex + 4 + macEnd := strings.Index(line[macStart:], " ") + if macEnd == -1 { + macEnd = strings.Index(line[macStart:], "[") + } + if macEnd == -1 { + macEnd = len(line) - macStart + } + macStr := strings.TrimSpace(line[macStart : macStart+macEnd]) + + // Skip incomplete entries + if macStr == "(incomplete)" || macStr == "" { + continue + } + + ip := net.ParseIP(ipStr) + if ip == nil || ip.To4() == nil { + continue + } + + mac, err := net.ParseMAC(macStr) + if err != nil { + continue + } + + peers = append(peers, Peer{ + IP: ip, + MAC: mac, + }) + } + + return peers, scanner.Err() +} diff --git a/pkg/peerdiscovery/arp/arp_windows.go b/pkg/peerdiscovery/arp/arp_windows.go new file mode 100644 index 0000000..b36328b --- /dev/null +++ b/pkg/peerdiscovery/arp/arp_windows.go @@ -0,0 +1,87 @@ +//go:build windows + +package arp + +import ( + "bufio" + "fmt" + "net" + "os/exec" + "strings" +) + +// readLocalARPTable reads the local ARP table on Windows using 'arp -a' command +func readLocalARPTable() ([]Peer, error) { + cmd := exec.Command("arp", "-a") + output, err := cmd.Output() + if err != 
nil { + return nil, fmt.Errorf("failed to execute arp -a: %w", err) + } + + var peers []Peer + scanner := bufio.NewScanner(strings.NewReader(string(output))) + + // Windows arp -a output has two sections: Interface and ARP entries + // Format example: + // Interface: 192.168.1.100 --- 0xa + // Internet Address Physical Address Type + // 192.168.1.1 aa-bb-cc-dd-ee-ff dynamic + // 192.168.1.255 ff-ff-ff-ff-ff-ff static + + inARPTable := false + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" { + continue + } + + // Check if we're entering the ARP table section + if strings.Contains(line, "Internet Address") && strings.Contains(line, "Physical Address") { + inARPTable = true + continue + } + + // Check if we're entering a new interface section + if strings.HasPrefix(line, "Interface:") { + inARPTable = false + continue + } + + if !inARPTable { + continue + } + + fields := strings.Fields(line) + if len(fields) < 2 { + continue + } + + ipStr := fields[0] + macStr := fields[1] + + // Skip incomplete entries + if macStr == "incomplete" || strings.HasPrefix(macStr, "ff-ff-ff-ff-ff-ff") { + continue + } + + // Convert Windows MAC format (aa-bb-cc-dd-ee-ff) to standard format (aa:bb:cc:dd:ee:ff) + macStr = strings.ReplaceAll(macStr, "-", ":") + + ip := net.ParseIP(ipStr) + if ip == nil || ip.To4() == nil { + continue + } + + mac, err := net.ParseMAC(macStr) + if err != nil { + continue + } + + peers = append(peers, Peer{ + IP: ip, + MAC: mac, + }) + } + + return peers, scanner.Err() +} diff --git a/pkg/peerdiscovery/arp/doc.go b/pkg/peerdiscovery/arp/doc.go new file mode 100644 index 0000000..db03f51 --- /dev/null +++ b/pkg/peerdiscovery/arp/doc.go @@ -0,0 +1,9 @@ +// Package arp provides a library for discovering ARP peers on a network. 
+// discovery is performed in a slow fashion +// - The agent monitors the ARP table for changes +// - The agent sends UDP packets to the target IP to trigger ARP resolution +// - The agent waits for the ARP resolution to complete before continuing with the next target IP +// - The agent reads the ARP table again to find new entries +// +// The agent is designed to be used in a slow fashion to avoid overwhelming the network. +package arp diff --git a/pkg/peerdiscovery/common/ip.go b/pkg/peerdiscovery/common/ip.go new file mode 100644 index 0000000..7b71100 --- /dev/null +++ b/pkg/peerdiscovery/common/ip.go @@ -0,0 +1,35 @@ +package common + +import "net" + +// IsNetworkOrBroadcast checks if an IP is the network or broadcast address. +// For IPv4, it checks both network and broadcast addresses. +// For IPv6, it checks network address and multicast addresses. +func IsNetworkOrBroadcast(ip net.IP, network *net.IPNet) bool { + if network == nil { + return false + } + + // Check if IP equals network address + if ip.Equal(network.IP) { + return true + } + + // For IPv4, check broadcast address + if ip4 := ip.To4(); ip4 != nil { + broadcast := make(net.IP, len(network.IP)) + copy(broadcast, network.IP) + for i := range broadcast { + broadcast[i] |= ^network.Mask[i] + } + return ip.Equal(broadcast) + } + + // For IPv6, check multicast addresses + if ip.IsMulticast() { + return true + } + + return false +} + diff --git a/pkg/peerdiscovery/common/networks.go b/pkg/peerdiscovery/common/networks.go new file mode 100644 index 0000000..23f9c5f --- /dev/null +++ b/pkg/peerdiscovery/common/networks.go @@ -0,0 +1,252 @@ +package common + +import ( + "net" +) + +// GetLocalNetworks24 returns all local network interfaces as /24 IPNet ranges (IPv4 only) +func GetLocalNetworks24() ([]*net.IPNet, error) { + interfaces, err := net.Interfaces() + if err != nil { + return nil, err + } + + var networks []*net.IPNet + seen := make(map[string]struct{}) + + for _, iface := range interfaces { + 
// Skip loopback and down interfaces + if iface.Flags&net.FlagLoopback != 0 { + continue + } + if iface.Flags&net.FlagUp == 0 { + continue + } + + addrs, err := iface.Addrs() + if err != nil { + continue + } + + for _, addr := range addrs { + ipNet, ok := addr.(*net.IPNet) + if !ok { + continue + } + + // Only process IPv4 addresses + ip := ipNet.IP.To4() + if ip == nil { + continue + } + + // Only process private networks + if !ip.IsPrivate() { + continue + } + + // Convert to /24 network + mask24 := net.CIDRMask(24, 32) + network24 := &net.IPNet{ + IP: ip.Mask(mask24), + Mask: mask24, + } + + // Avoid duplicates + key := network24.String() + if _, exists := seen[key]; exists { + continue + } + seen[key] = struct{}{} + + networks = append(networks, network24) + } + } + + return networks, nil +} + +// GetLocalNetworks64 returns all local network interfaces as /64 IPNet ranges (IPv6 only) +func GetLocalNetworks64() ([]*net.IPNet, error) { + interfaces, err := net.Interfaces() + if err != nil { + return nil, err + } + + var networks []*net.IPNet + seen := make(map[string]struct{}) + + for _, iface := range interfaces { + // Skip loopback and down interfaces + if iface.Flags&net.FlagLoopback != 0 { + continue + } + if iface.Flags&net.FlagUp == 0 { + continue + } + + addrs, err := iface.Addrs() + if err != nil { + continue + } + + for _, addr := range addrs { + ipNet, ok := addr.(*net.IPNet) + if !ok { + continue + } + + ip := ipNet.IP + + // Only process IPv6 addresses + if ip.To4() != nil { + continue + } + + // Must be 16 bytes for IPv6 + if len(ip) != net.IPv6len { + continue + } + + // Skip loopback + if ip.IsLoopback() { + continue + } + + // Skip multicast + if ip.IsMulticast() { + continue + } + + // Only process link-local and ULA (private) addresses + // Link-local: fe80::/10 + // ULA: fc00::/7 (actually fd00::/8 is used for ULA) + isLinkLocal := ip.IsLinkLocalUnicast() + // ULA addresses start with fd (fd00::/8) + isULA := len(ip) == net.IPv6len && ip[0] == 
0xfd + + if !isLinkLocal && !isULA { + continue + } + + // Convert to /64 network (IPv6 standard) + mask64 := net.CIDRMask(64, 128) + network64 := &net.IPNet{ + IP: ip.Mask(mask64), + Mask: mask64, + } + + // Avoid duplicates + key := network64.String() + if _, exists := seen[key]; exists { + continue + } + seen[key] = struct{}{} + + networks = append(networks, network64) + } + } + + return networks, nil +} + +// GetLocalNetworks returns all local network interfaces as IPNet ranges +// Supports both IPv4 (/24) and IPv6 (/64) networks +func GetLocalNetworks() ([]*net.IPNet, error) { + interfaces, err := net.Interfaces() + if err != nil { + return nil, err + } + + var networks []*net.IPNet + seen := make(map[string]struct{}) + + for _, iface := range interfaces { + // Skip loopback and down interfaces + if iface.Flags&net.FlagLoopback != 0 { + continue + } + if iface.Flags&net.FlagUp == 0 { + continue + } + + addrs, err := iface.Addrs() + if err != nil { + continue + } + + for _, addr := range addrs { + ipNet, ok := addr.(*net.IPNet) + if !ok { + continue + } + + ip := ipNet.IP + + // Handle IPv4 addresses + if ip4 := ip.To4(); ip4 != nil { + // Only process private networks + if !ip4.IsPrivate() { + continue + } + + // Convert to /24 network + mask24 := net.CIDRMask(24, 32) + network24 := &net.IPNet{ + IP: ip4.Mask(mask24), + Mask: mask24, + } + + // Avoid duplicates + key := network24.String() + if _, exists := seen[key]; exists { + continue + } + seen[key] = struct{}{} + + networks = append(networks, network24) + continue + } + + // Handle IPv6 addresses + if len(ip) == net.IPv6len { + // Skip loopback + if ip.IsLoopback() { + continue + } + + // Skip multicast + if ip.IsMulticast() { + continue + } + + // Only process link-local and ULA (private) addresses + isLinkLocal := ip.IsLinkLocalUnicast() + // ULA addresses start with fd (fd00::/8) + isULA := ip[0] == 0xfd + + if !isLinkLocal && !isULA { + continue + } + + // Convert to /64 network (IPv6 standard) + 
mask64 := net.CIDRMask(64, 128) + network64 := &net.IPNet{ + IP: ip.Mask(mask64), + Mask: mask64, + } + + // Avoid duplicates + key := network64.String() + if _, exists := seen[key]; exists { + continue + } + seen[key] = struct{}{} + + networks = append(networks, network64) + } + } + } + + return networks, nil +} diff --git a/pkg/peerdiscovery/igmp/capture.go b/pkg/peerdiscovery/igmp/capture.go new file mode 100644 index 0000000..42cab90 --- /dev/null +++ b/pkg/peerdiscovery/igmp/capture.go @@ -0,0 +1,43 @@ +package igmp + +import ( + "net" + "time" + + "github.com/google/gopacket" + "github.com/google/gopacket/pcap" +) + +const ( + // DefaultSnapLen is the default snapshot length for packet capture + DefaultSnapLen = 1600 + // DefaultPromisc enables promiscuous mode by default + DefaultPromisc = true + // DefaultTimeout is the default timeout for packet reads (100ms for responsiveness) + DefaultTimeout = 100 * time.Millisecond +) + +// createCaptureHandle creates a pcap handle for packet capture on the given interface +func createCaptureHandle(iface *net.Interface, config *Config) (*pcap.Handle, error) { + // Open live capture with timeout for responsive context cancellation + handle, err := pcap.OpenLive(iface.Name, DefaultSnapLen, DefaultPromisc, DefaultTimeout) + if err != nil { + return nil, err + } + + return handle, nil +} + +// packetSource wraps pcap packet source +type packetSource struct { + source *gopacket.PacketSource + handle *pcap.Handle +} + +// newPacketSource creates a new packet source from a pcap handle +func newPacketSource(handle *pcap.Handle) *packetSource { + return &packetSource{ + source: gopacket.NewPacketSource(handle, handle.LinkType()), + handle: handle, + } +} diff --git a/pkg/peerdiscovery/igmp/doc.go b/pkg/peerdiscovery/igmp/doc.go new file mode 100644 index 0000000..dc45919 --- /dev/null +++ b/pkg/peerdiscovery/igmp/doc.go @@ -0,0 +1,49 @@ +// Package igmp provides a library for discovering alive hosts on a network +// using IGMP 
(Internet Group Management Protocol) monitoring. +// +// The package provides continuous monitoring of network interfaces for IGMP +// membership reports, which can reveal active hosts participating in multicast +// groups without requiring any open ports on target hosts. +// +// The main function is: +// - DiscoverPeers: Continuously monitors network interfaces and returns a channel +// that receives Peer structs as hosts are discovered +// +// Discovery is performed by: +// - Monitoring network interfaces for IGMP membership reports +// - Optionally sending periodic membership queries to trigger immediate reports +// - Extracting source IP addresses from IGMP packets to identify active hosts +// - Using BPF filtering for performance (kernel-level packet filtering) +// +// Example usage: +// +// ctx, cancel := context.WithCancel(context.Background()) +// defer cancel() +// +// config := &igmp.Config{ +// Version: 2, +// EnableQueries: true, +// ChannelBuffer: 100, +// } +// +// // Start continuous monitoring +// peerChan, err := igmp.DiscoverPeers(ctx, config) +// if err != nil { +// log.Fatal(err) +// } +// +// // Process discovered peers +// for peer := range peerChan { +// log.Printf("Discovered host: %s (Groups: %v)", peer.IP, peer.MulticastGroups) +// } +// +// Privilege Requirements: +// - Raw sockets and packet capture require root/admin privileges on most systems +// - libpcap/WinPcap must be installed for packet capture +// +// Limitations: +// - Only discovers hosts that are members of multicast groups +// - Works only on the same network segment (Layer 2) +// - May miss hosts not participating in multicast +// - Requires elevated privileges for packet capture +package igmp diff --git a/pkg/peerdiscovery/igmp/igmp.go b/pkg/peerdiscovery/igmp/igmp.go new file mode 100644 index 0000000..e493bc6 --- /dev/null +++ b/pkg/peerdiscovery/igmp/igmp.go @@ -0,0 +1,259 @@ +package igmp + +import ( + "context" + "fmt" + "net" + "sync" + "time" +) + +// Peer 
represents a discovered IGMP peer +type Peer struct { + IP net.IP + MAC net.HardwareAddr // Optional, may be nil + MulticastGroups []string // List of multicast groups the host belongs to + LastSeen time.Time // When the host was last detected + ResponseTime time.Duration // Time to receive IGMP response + IGMPVersion int // IGMP version detected (1, 2, or 3) +} + +// Config holds configuration for IGMP discovery +type Config struct { + // IGMP version to use (1, 2, or 3) + Version int // Default: 2 + + // Multicast groups to query + MulticastGroups []net.IP // Default: common multicast groups + + // Network interface + Interface *net.Interface // Specific interface to use (nil = all) + + // Discovery settings + QueryInterval time.Duration // Interval between membership queries (default: 125s for IGMPv2) + EnableQueries bool // Enable periodic membership queries (default: true) + + // Channel buffer size for discovered peers + ChannelBuffer int // Default: 100 + + // Enable BPF filtering + EnableBPF bool // Default: true +} + +// DefaultConfig returns a Config with sensible defaults +func DefaultConfig() *Config { + return &Config{ + Version: 2, + MulticastGroups: CommonMulticastGroups, + Interface: nil, // All interfaces + QueryInterval: 125 * time.Second, + EnableQueries: true, + ChannelBuffer: 100, + EnableBPF: true, + } +} + +// DiscoverPeers continuously monitors network interfaces for IGMP messages +// and sends discovered peers to the returned channel. +// The function runs until the context is cancelled. +// Returns a channel that receives Peer structs as hosts are discovered. 
+func DiscoverPeers(ctx context.Context, config *Config) (<-chan Peer, error) { + if config == nil { + config = DefaultConfig() + } + + // Create buffered channel for discovered peers + peerChan := make(chan Peer, config.ChannelBuffer) + + // Get network interfaces to monitor + var interfaces []*net.Interface + + if config.Interface != nil { + // Monitor specific interface + interfaces = []*net.Interface{config.Interface} + } else { + // Get all local network interfaces + allInterfaces, err := net.Interfaces() + if err != nil { + close(peerChan) + return nil, fmt.Errorf("failed to get network interfaces: %w", err) + } + + // Filter to only active, non-loopback interfaces + for i := range allInterfaces { + iface := &allInterfaces[i] + if iface.Flags&net.FlagLoopback != 0 { + continue + } + if iface.Flags&net.FlagUp == 0 { + continue + } + interfaces = append(interfaces, iface) + } + } + + if len(interfaces) == 0 { + close(peerChan) + return nil, fmt.Errorf("no suitable network interfaces found") + } + + // Use sync.WaitGroup to track monitor goroutines + // Channel will be closed when all monitors finish or context is cancelled + var wg sync.WaitGroup + + // Start monitoring goroutines for each interface + for _, iface := range interfaces { + select { + case <-ctx.Done(): + close(peerChan) + return nil, ctx.Err() + default: + } + + wg.Add(1) + // Start monitoring goroutine for this interface + go func(interfaceToMonitor *net.Interface) { + defer wg.Done() + if err := monitorInterface(ctx, interfaceToMonitor, config, peerChan); err != nil { + // Log error but don't fail entire discovery + // Errors are expected if interface becomes unavailable + _ = err + } + }(iface) + } + + // Start periodic query goroutine if enabled + if config.EnableQueries { + go sendPeriodicQueries(ctx, interfaces, config) + } + + // Start goroutine to close channel when all monitors finish or context is done + go func() { + // Wait for context cancellation or all monitors to finish + done := 
make(chan struct{}) + go func() { + <-ctx.Done() + close(done) + }() + go func() { + wg.Wait() + close(done) + }() + + <-done + // Give a small grace period for final packets to be sent + time.Sleep(200 * time.Millisecond) + close(peerChan) + }() + + return peerChan, nil +} + +// monitorInterface monitors a specific network interface for IGMP messages +// and sends discovered peers to the provided channel. +// Runs until context is cancelled. +func monitorInterface(ctx context.Context, iface *net.Interface, config *Config, peerChan chan<- Peer) error { + // Create packet capture handle + handle, err := createCaptureHandle(iface, config) + if err != nil { + return fmt.Errorf("failed to create capture handle for %s: %w", iface.Name, err) + } + defer handle.Close() + + // Set BPF filter if enabled + if config.EnableBPF { + filter := "ip proto 2 and (ip[20] == 0x12 or ip[20] == 0x16 or ip[20] == 0x22)" + if err := handle.SetBPFFilter(filter); err != nil { + // Log warning but continue without filter + // BPF filter is optional for functionality + _ = err + } + } + + // Create packet source + packetSource := newPacketSource(handle) + + // Process packets from packet source channel + packetChan := packetSource.source.Packets() + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case packet, ok := <-packetChan: + if !ok { + // Channel closed (handle closed) + return nil + } + if packet == nil { + continue + } + + // Parse and process IGMP packet + peer, err := parseIGMPPacket(packet, iface) + if err != nil { + // Skip invalid packets (not IGMP, wrong type, etc.) 
+ continue + } + + if peer == nil { + continue + } + + // Send peer to channel (non-blocking with timeout to avoid blocking) + select { + case peerChan <- *peer: + // Successfully sent + case <-ctx.Done(): + return ctx.Err() + case <-time.After(50 * time.Millisecond): + // Channel full, skip this peer to avoid blocking monitor + // This prevents one slow consumer from blocking all monitors + // The timeout ensures we don't block indefinitely + continue + } + } + } +} + +// sendPeriodicQueries sends periodic IGMP membership queries to trigger reports +func sendPeriodicQueries(ctx context.Context, interfaces []*net.Interface, config *Config) { + ticker := time.NewTicker(config.QueryInterval) + defer ticker.Stop() + + // Send initial query immediately + sendQueriesToInterfaces(ctx, interfaces, config) + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + sendQueriesToInterfaces(ctx, interfaces, config) + } + } +} + +// sendQueriesToInterfaces sends membership queries on all interfaces +func sendQueriesToInterfaces(ctx context.Context, interfaces []*net.Interface, config *Config) { + for _, iface := range interfaces { + select { + case <-ctx.Done(): + return + default: + } + + // Send query to common multicast groups + for _, group := range config.MulticastGroups { + select { + case <-ctx.Done(): + return + default: + } + + // Send query (non-blocking, errors are ignored) + _ = sendMembershipQuery(iface, group, config.Version) + } + } +} + diff --git a/pkg/peerdiscovery/igmp/multicast.go b/pkg/peerdiscovery/igmp/multicast.go new file mode 100644 index 0000000..73253f0 --- /dev/null +++ b/pkg/peerdiscovery/igmp/multicast.go @@ -0,0 +1,43 @@ +package igmp + +import "net" + +// CommonMulticastGroups contains commonly used multicast group addresses +var CommonMulticastGroups = []net.IP{ + net.ParseIP("224.0.0.1"), // All Systems (all-hosts) + net.ParseIP("224.0.0.2"), // All Routers + net.ParseIP("224.0.0.22"), // IGMP + net.ParseIP("224.0.0.251"), // 
mDNS + net.ParseIP("224.0.0.252"), // LLMNR + net.ParseIP("239.255.255.250"), // SSDP +} + +// IGMP message type constants +const ( + IGMPMembershipQuery = 0x11 // Membership query + IGMPV1MembershipReport = 0x12 // IGMPv1 membership report + IGMPV2MembershipReport = 0x16 // IGMPv2 membership report + IGMPV2LeaveGroup = 0x17 // IGMPv2 leave group + IGMPV3MembershipReport = 0x22 // IGMPv3 membership report +) + +// IsMembershipReport checks if the IGMP type is a membership report +func IsMembershipReport(msgType uint8) bool { + return msgType == IGMPV1MembershipReport || + msgType == IGMPV2MembershipReport || + msgType == IGMPV3MembershipReport +} + +// GetIGMPVersion returns the IGMP version based on message type +func GetIGMPVersion(msgType uint8) int { + switch msgType { + case IGMPV1MembershipReport: + return 1 + case IGMPV2MembershipReport, IGMPV2LeaveGroup: + return 2 + case IGMPV3MembershipReport: + return 3 + default: + return 0 + } +} diff --git a/pkg/peerdiscovery/igmp/parser.go b/pkg/peerdiscovery/igmp/parser.go new file mode 100644 index 0000000..885dfcc --- /dev/null +++ b/pkg/peerdiscovery/igmp/parser.go @@ -0,0 +1,120 @@ +package igmp + +import ( + "fmt" + "net" + "time" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" +) + +// parseIGMPPacket parses an IGMP packet and extracts peer information +func parseIGMPPacket(packet gopacket.Packet, iface *net.Interface) (*Peer, error) { + // Get IP layer + ipLayer := packet.Layer(layers.LayerTypeIPv4) + if ipLayer == nil { + return nil, fmt.Errorf("packet does not contain IPv4 layer") + } + + ip, ok := ipLayer.(*layers.IPv4) + if !ok { + return nil, fmt.Errorf("failed to cast to IPv4 layer") + } + + // Check if protocol is IGMP + if ip.Protocol != layers.IPProtocolIGMP { + return nil, fmt.Errorf("packet is not IGMP (protocol: %d)", ip.Protocol) + } + + // Get IGMP layer + igmpLayer := packet.Layer(layers.LayerTypeIGMP) + if igmpLayer == nil { + // Try to parse IGMP manually if layer is 
not available
		return parseIGMPManually(packet, ip)
	}

	igmp, ok := igmpLayer.(*layers.IGMP)
	if !ok {
		return nil, fmt.Errorf("failed to cast to IGMP layer")
	}

	// Only process membership reports; queries and leave-group messages are
	// not evidence of a peer joining a group.
	igmpType := uint8(igmp.Type)
	if !IsMembershipReport(igmpType) {
		return nil, fmt.Errorf("not a membership report (type: 0x%02x)", igmpType)
	}

	// Extract peer information from the IP source of the report.
	peer := &Peer{
		IP:              ip.SrcIP,
		LastSeen:        time.Now(),
		IGMPVersion:     GetIGMPVersion(igmpType),
		MulticastGroups: []string{},
	}

	// Extract multicast group (0.0.0.0 in a report means "no group").
	if igmp.GroupAddress != nil && !igmp.GroupAddress.IsUnspecified() {
		peer.MulticastGroups = append(peer.MulticastGroups, igmp.GroupAddress.String())
	}

	// Try to extract MAC address from Ethernet layer (absent on raw IP captures).
	if ethLayer := packet.Layer(layers.LayerTypeEthernet); ethLayer != nil {
		if eth, ok := ethLayer.(*layers.Ethernet); ok {
			peer.MAC = eth.SrcMAC
		}
	}

	return peer, nil
}

// parseIGMPManually parses IGMP packet manually if gopacket layers don't support it.
// It reads the fixed 8-byte IGMPv1/v2 header layout directly from the IPv4 payload.
// NOTE(review): for IGMPv3 reports (type 0x22) bytes 4:8 are reserved/record-count,
// not a group address — confirm whether v3 reports should be decoded differently.
func parseIGMPManually(packet gopacket.Packet, ip *layers.IPv4) (*Peer, error) {
	// Get the payload (IGMP data)
	payload := ip.Payload
	if len(payload) < 8 {
		return nil, fmt.Errorf("IGMP packet too short")
	}

	// Parse IGMP header manually:
	// byte 0 = type, byte 1 = max response time, bytes 2:4 = checksum (big-endian).
	msgType := payload[0]
	maxRespTime := payload[1]
	checksum := uint16(payload[2])<<8 | uint16(payload[3])

	// Validate checksum (simplified - full validation would require recomputing)
	_ = checksum
	_ = maxRespTime

	// Only process membership reports
	if !IsMembershipReport(msgType) {
		return nil, fmt.Errorf("not a membership report (type: 0x%02x)", msgType)
	}

	// Extract group address (bytes 4-7).
	// The length re-check is always true here given the len<8 guard above.
	var groupAddr net.IP
	if len(payload) >= 8 {
		groupAddr = net.IP(payload[4:8])
	}

	// Create peer keyed on the report's IP source.
	peer := &Peer{
		IP:              ip.SrcIP,
		LastSeen:        time.Now(),
		IGMPVersion:     GetIGMPVersion(msgType),
		MulticastGroups: []string{},
	}

	// Add group address if valid
	if groupAddr != nil
&& !groupAddr.IsUnspecified() { + peer.MulticastGroups = append(peer.MulticastGroups, groupAddr.String()) + } + + // Try to extract MAC address from Ethernet layer + if ethLayer := packet.Layer(layers.LayerTypeEthernet); ethLayer != nil { + if eth, ok := ethLayer.(*layers.Ethernet); ok { + peer.MAC = eth.SrcMAC + } + } + + return peer, nil +} diff --git a/pkg/peerdiscovery/igmp/queries.go b/pkg/peerdiscovery/igmp/queries.go new file mode 100644 index 0000000..d316dd4 --- /dev/null +++ b/pkg/peerdiscovery/igmp/queries.go @@ -0,0 +1,189 @@ +package igmp + +import ( + "fmt" + "net" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "golang.org/x/net/ipv4" +) + +// sendMembershipQuery sends an IGMP membership query to the specified multicast group +func sendMembershipQuery(iface *net.Interface, group net.IP, version int) error { + // Get interface IP + srcIP := getInterfaceIP(iface) + if srcIP == nil || srcIP.IsUnspecified() { + return fmt.Errorf("interface %s has no IPv4 address", iface.Name) + } + + // Build IGMP payload + igmpPayload := buildIGMPPayload(group, version) + + // Create raw socket for sending + conn, err := net.ListenPacket("ip4:2", "0.0.0.0") // Protocol 2 = IGMP + if err != nil { + return fmt.Errorf("failed to create raw socket: %w", err) + } + defer func() { + _ = conn.Close() + }() + + // Wrap with ipv4.RawConn for control + rawConn, err := ipv4.NewRawConn(conn) + if err != nil { + return fmt.Errorf("failed to create raw connection: %w", err) + } + defer func() { + _ = rawConn.Close() + }() + + // Set socket options + if err := rawConn.SetMulticastInterface(iface); err != nil { + return fmt.Errorf("failed to set multicast interface: %w", err) + } + + // Build IP header using gopacket for proper construction + ipLayer := &layers.IPv4{ + Version: 4, + IHL: 5, // 20 bytes header (5 * 4) + TTL: 1, // Don't route beyond local network + Protocol: layers.IPProtocolIGMP, + SrcIP: srcIP, + DstIP: group, + } + + // Serialize IP header to 
get proper values
	buffer := gopacket.NewSerializeBuffer()
	opts := gopacket.SerializeOptions{
		ComputeChecksums: true, // fills in the IPv4 header checksum
		FixLengths:       true, // fills in TotalLen from header + payload
	}

	if err := gopacket.SerializeLayers(buffer, opts, ipLayer, gopacket.Payload(igmpPayload)); err != nil {
		return fmt.Errorf("failed to serialize IP layer: %w", err)
	}

	packetBytes := buffer.Bytes()

	// Calculate and set IGMP checksum.
	// The IGMP bytes were attached as an opaque gopacket.Payload, so gopacket did
	// not checksum them; patch the 8-byte query's checksum in place here.
	igmpStart := 20 // IP header length
	if len(packetBytes) >= igmpStart+8 {
		checksum := calculateIGMPChecksum(packetBytes[igmpStart : igmpStart+8])
		packetBytes[igmpStart+2] = byte(checksum >> 8)
		packetBytes[igmpStart+3] = byte(checksum & 0xff)
	}

	// Extract IP header and IGMP payload for raw socket.
	// Fields are re-parsed from the serialized bytes; ipv4.RawConn.WriteTo
	// handles per-platform header byte-order details.
	ipHeader := &ipv4.Header{
		Version:  int(packetBytes[0] >> 4),
		Len:      int((packetBytes[0] & 0x0f) * 4),
		TotalLen: int(uint16(packetBytes[2])<<8 | uint16(packetBytes[3])),
		TTL:      int(packetBytes[8]),
		Protocol: int(packetBytes[9]),
		Src:      net.IP(packetBytes[12:16]),
		Dst:      net.IP(packetBytes[16:20]),
	}

	igmpData := packetBytes[20:]

	// Set control flags so the kernel sends on the requested interface.
	cm := &ipv4.ControlMessage{
		IfIndex: iface.Index,
	}

	// Send packet
	if err := rawConn.WriteTo(ipHeader, igmpData, cm); err != nil {
		return fmt.Errorf("failed to send IGMP query: %w", err)
	}

	return nil
}

// buildIGMPPayload constructs the IGMP payload (without IP header).
// Layout: Type (1 byte) + Max Response Time (1 byte) + Checksum (2 bytes,
// left zero here and patched by the caller) + Group Address (4 bytes).
func buildIGMPPayload(group net.IP, version int) []byte {
	data := make([]byte, 8)

	// Type: Membership Query
	data[0] = IGMPMembershipQuery

	// Max Response Time (in 1/10 second units for IGMPv2)
	// Default: 10 seconds = 100 (0x64)
	if version == 2 {
		data[1] = 100 // 10 seconds
	} else {
		data[1] = 0 // IGMPv1 uses 0
	}

	// Checksum (will be calculated after IP header is known)
	data[2] = 0
	data[3] = 0

	// Group Address
	if group != nil && !group.IsUnspecified() {
copy(data[4:8], group.To4()) + } else { + // General query uses 0.0.0.0 + copy(data[4:8], net.IPv4zero.To4()) + } + + return data +} + +// calculateIGMPChecksum calculates the IGMP checksum (RFC 1071) +// The checksum is the 16-bit one's complement of the one's complement sum +// of all 16-bit words in the IGMP message. +func calculateIGMPChecksum(data []byte) uint16 { + var sum uint32 + + // Sum all 16-bit words + for i := 0; i < len(data); i += 2 { + var word uint16 + if i+1 < len(data) { + word = uint16(data[i])<<8 | uint16(data[i+1]) + } else { + word = uint16(data[i]) << 8 + } + sum += uint32(word) + } + + // Add carry bits (fold 32-bit sum to 16 bits) + for sum>>16 != 0 { + sum = (sum & 0xffff) + (sum >> 16) + } + + // One's complement + return ^uint16(sum) +} + +// getInterfaceIP gets the first IPv4 address of an interface +func getInterfaceIP(iface *net.Interface) net.IP { + addrs, err := iface.Addrs() + if err != nil { + return nil + } + + for _, addr := range addrs { + if ipNet, ok := addr.(*net.IPNet); ok { + if ip := ipNet.IP.To4(); ip != nil { + // Skip link-local addresses (169.254.0.0/16) + if !ip.IsLinkLocalUnicast() { + return ip + } + } + } + } + + // Fallback to any IPv4 address if no non-link-local found + for _, addr := range addrs { + if ipNet, ok := addr.(*net.IPNet); ok { + if ip := ipNet.IP.To4(); ip != nil { + return ip + } + } + } + + return nil +} diff --git a/pkg/peerdiscovery/ndp/ndp.go b/pkg/peerdiscovery/ndp/ndp.go new file mode 100644 index 0000000..97bace9 --- /dev/null +++ b/pkg/peerdiscovery/ndp/ndp.go @@ -0,0 +1,204 @@ +package ndp + +import ( + "context" + "fmt" + "net" + "time" + + "github.com/projectdiscovery/mapcidr" + "github.com/projectdiscovery/pd-agent/pkg/peerdiscovery/common" + mapsutil "github.com/projectdiscovery/utils/maps" + syncutil "github.com/projectdiscovery/utils/sync" +) + +// Peer represents a discovered NDP peer +type Peer struct { + IP net.IP + MAC net.HardwareAddr +} + +// DiscoverPeers retrieves all NDP 
peers by first reading the local NDP table,
// then scanning the network in parallel to discover additional peers.
// The returned slice is deduplicated by IP string; order is unspecified.
func DiscoverPeers(ctx context.Context) ([]Peer, error) {
	peers := mapsutil.NewSyncLockMap[string, *Peer]()

	// Read local NDP table
	localPeers, err := readLocalNDPTable()
	if err != nil {
		return nil, fmt.Errorf("failed to read local NDP table: %w", err)
	}

	for _, peer := range localPeers {
		key := peer.IP.String()
		// copy so the stored pointer does not alias the loop variable
		peerCopy := peer
		_ = peers.Set(key, &peerCopy)
	}

	// Get /64 network ranges from local interfaces
	networks, err := common.GetLocalNetworks64()
	if err != nil {
		return nil, fmt.Errorf("failed to get local networks: %w", err)
	}

	// Scan networks sequentially (no hurry); cancellation jumps straight to
	// result assembly so peers found so far are still returned.
	for _, network := range networks {
		select {
		case <-ctx.Done():
			goto done
		default:
		}

		discovered, err := scanNetwork64(ctx, network)
		if err != nil {
			// best-effort: a failing network scan does not abort discovery
			continue
		}

		for _, peer := range discovered {
			key := peer.IP.String()
			if _, exists := peers.Get(key); !exists {
				peerCopy := peer
				_ = peers.Set(key, &peerCopy)
			}
		}
	}

done:
	// Convert map to slice
	var result []Peer
	_ = peers.Iterate(func(key string, peer *Peer) error {
		if peer != nil {
			result = append(result, *peer)
		}
		return nil
	})

	return result, nil
}

// scanNetwork64 scans a /64 network range to discover NDP peers.
// Uses UDP6 connections to trigger OS NDP requests and monitors the NDP table
// (diffing table snapshots taken before and after the probe burst).
func scanNetwork64(ctx context.Context, network *net.IPNet) ([]Peer, error) {
	// Verify it's a /64 network
	ones, bits := network.Mask.Size()
	if ones != 64 || bits != 128 {
		return nil, fmt.Errorf("network %s is not a /64 network", network.String())
	}

	// Get initial NDP table state
	initialPeers, err := readLocalNDPTable()
	if err != nil {
		return nil, fmt.Errorf("failed to read initial NDP table: %w", err)
	}

	initialSet := make(map[string]struct{})
	for _, peer := range initialPeers {
		if network.Contains(peer.IP) {
			initialSet[peer.IP.String()] = struct{}{}
		}
	}

	// Expand CIDR to get all IPs in /64 range.
	// NOTE(review): a /64 holds 2^64 addresses; materializing them all into a
	// []string (and then probing each with a 10ms delay) cannot complete in
	// practice. This needs a bounded/streamed expansion (e.g. a capped
	// mapcidr stream) before it is usable — confirm intended behavior.
	cidrStr := network.String()
	ips, err := mapcidr.IPAddresses(cidrStr)
	if err != nil {
		return nil, fmt.Errorf("failed to expand CIDR %s: %w", cidrStr, err)
	}

	if len(ips) == 0 {
		return []Peer{}, nil
	}

	// Use adaptive waitgroup with low parallelism (no hurry)
	awg, err := syncutil.New(syncutil.WithSize(5))
	if err != nil {
		return nil, fmt.Errorf("failed to create adaptive waitgroup: %w", err)
	}

	// Trigger NDP resolution for each IP using UDP6 connections
	for _, ipStr := range ips {
		select {
		case <-ctx.Done():
			goto done
		default:
		}

		ip := net.ParseIP(ipStr)
		if ip == nil {
			continue
		}

		// Skip network and multicast addresses
		if isNetworkOrMulticast(ip, network) {
			continue
		}

		awg.Add()
		go func(targetIP net.IP) {
			defer awg.Done()

			// Send UDP6 packet to trigger NDP resolution
			// The OS will handle the NDP Neighbor Solicitation for us
			conn, err := net.DialTimeout("udp6", net.JoinHostPort(targetIP.String(), "12345"), 50*time.Millisecond)
			if err != nil {
				// Connection will fail, but NDP resolution may occur
				return
			}
			if conn != nil {
				_ = conn.Close()
			}
		}(ip)

		// Small delay between requests to avoid overwhelming
		time.Sleep(10 * time.Millisecond)
	}

done:
	awg.Wait()

	// Wait for OS NDP requests to complete and NDP table to update
	// Give it time since we're not in a hurry
	time.Sleep(2 * time.Second)

	// Read NDP table again to find new entries
	finalPeers, err := readLocalNDPTable()
	if err != nil {
		return nil, fmt.Errorf("failed to read final NDP table: %w", err)
	}

	// Find newly discovered peers: in-range entries absent from the first snapshot
	var discovered []Peer
	for _, peer := range finalPeers {
		if !network.Contains(peer.IP) {
			continue
		}

		// Check if this is a new peer
		if _, exists := initialSet[peer.IP.String()]; !exists {
			discovered = append(discovered, peer)
		}
	}

	return discovered,
nil +} + +// isNetworkOrMulticast checks if an IPv6 address is the network or multicast address +// IPv6 doesn't have broadcast, uses multicast instead +func isNetworkOrMulticast(ip net.IP, network *net.IPNet) bool { + // Network address + if ip.Equal(network.IP) { + return true + } + + // Check if it's a multicast address + if ip.IsMulticast() { + return true + } + + // Skip all-nodes multicast (ff02::1) + if ip.Equal(net.ParseIP("ff02::1")) { + return true + } + + return false +} diff --git a/pkg/peerdiscovery/ndp/ndp_unix.go b/pkg/peerdiscovery/ndp/ndp_unix.go new file mode 100644 index 0000000..6c3465c --- /dev/null +++ b/pkg/peerdiscovery/ndp/ndp_unix.go @@ -0,0 +1,171 @@ +//go:build !windows + +package ndp + +import ( + "bufio" + "fmt" + "net" + "os/exec" + "runtime" + "strings" + + osutils "github.com/projectdiscovery/utils/os" +) + +// readLocalNDPTable reads the local NDP table (Linux and macOS) +func readLocalNDPTable() ([]Peer, error) { + if osutils.IsLinux() { + return readLinuxNDPTable() + } else if osutils.IsOSX() { + return readDarwinNDPTable() + } + return nil, fmt.Errorf("unsupported OS: %s", runtime.GOOS) +} + +// readLinuxNDPTable reads NDP table using 'ip -6 neigh show' command on Linux +func readLinuxNDPTable() ([]Peer, error) { + cmd := exec.Command("ip", "-6", "neigh", "show") + output, err := cmd.Output() + if err != nil { + return nil, fmt.Errorf("failed to execute ip -6 neigh show: %w", err) + } + + var peers []Peer + scanner := bufio.NewScanner(strings.NewReader(string(output))) + + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" { + continue + } + + fields := strings.Fields(line) + if len(fields) < 5 { + continue + } + + // Format: fe80::1 dev eth0 lladdr aa:bb:cc:dd:ee:ff REACHABLE + // or: fe80::1 dev eth0 lladdr aa:bb:cc:dd:ee:ff STALE + ipStr := fields[0] + + // Find MAC address (field with lladdr prefix or after lladdr) + var macStr string + for i, field := range fields { + if field == "lladdr" && 
i+1 < len(fields) {
				macStr = fields[i+1]
				break
			}
		}

		// Entries without a link-layer address are not usable peers
		if macStr == "" {
			continue
		}

		// Skip incomplete entries (all-zero MAC or failed resolution)
		if macStr == "00:00:00:00:00:00" || strings.Contains(line, "FAILED") {
			continue
		}

		ip := net.ParseIP(ipStr)
		if ip == nil {
			continue
		}

		// Only process IPv6 addresses (NDP is IPv6-only)
		if ip.To4() != nil {
			continue
		}

		mac, err := net.ParseMAC(macStr)
		if err != nil {
			continue
		}

		peers = append(peers, Peer{
			IP:  ip,
			MAC: mac,
		})
	}

	return peers, scanner.Err()
}

// readDarwinNDPTable reads NDP table using 'ndp -an' command on macOS.
// Each output line is parsed for the parenthesized IP (scope stripped)
// and the MAC following " at "; malformed or incomplete lines are skipped.
func readDarwinNDPTable() ([]Peer, error) {
	cmd := exec.Command("ndp", "-an")
	output, err := cmd.Output()
	if err != nil {
		return nil, fmt.Errorf("failed to execute ndp -an: %w", err)
	}

	var peers []Peer
	scanner := bufio.NewScanner(strings.NewReader(string(output)))

	for scanner.Scan() {
		line := strings.TrimSpace(scanner.Text())
		if line == "" {
			continue
		}

		// macOS ndp -an format: "? 
(fe80::1%en0) at aa:bb:cc:dd:ee:ff on en0 ifscope [ethernet]" + // or: "fe80::1%en0 (fe80::1%en0) at aa:bb:cc:dd:ee:ff on en0 ifscope [ethernet]" + + // Extract IP address (between parentheses, before % if present) + ipStart := strings.Index(line, "(") + ipEnd := strings.Index(line, ")") + if ipStart == -1 || ipEnd == -1 || ipStart >= ipEnd { + continue + } + ipWithScope := line[ipStart+1 : ipEnd] + + // Remove scope identifier (%interface) + ipStr := ipWithScope + if scopeIdx := strings.Index(ipWithScope, "%"); scopeIdx != -1 { + ipStr = ipWithScope[:scopeIdx] + } + + // Extract MAC address (after "at ") + atIndex := strings.Index(line, " at ") + if atIndex == -1 { + continue + } + macStart := atIndex + 4 + macEnd := strings.Index(line[macStart:], " ") + if macEnd == -1 { + macEnd = strings.Index(line[macStart:], " on") + } + if macEnd == -1 { + macEnd = len(line) - macStart + } + macStr := strings.TrimSpace(line[macStart : macStart+macEnd]) + + // Skip incomplete entries + if macStr == "(incomplete)" || macStr == "" { + continue + } + + ip := net.ParseIP(ipStr) + if ip == nil { + continue + } + + // Only process IPv6 addresses + if ip.To4() != nil { + continue + } + + mac, err := net.ParseMAC(macStr) + if err != nil { + continue + } + + peers = append(peers, Peer{ + IP: ip, + MAC: mac, + }) + } + + return peers, scanner.Err() +} + diff --git a/pkg/peerdiscovery/ndp/ndp_windows.go b/pkg/peerdiscovery/ndp/ndp_windows.go new file mode 100644 index 0000000..c2f3301 --- /dev/null +++ b/pkg/peerdiscovery/ndp/ndp_windows.go @@ -0,0 +1,90 @@ +//go:build windows + +package ndp + +import ( + "bufio" + "fmt" + "net" + "os/exec" + "strings" +) + +// readLocalNDPTable reads the local NDP table on Windows using 'netsh interface ipv6 show neighbors' command +func readLocalNDPTable() ([]Peer, error) { + cmd := exec.Command("netsh", "interface", "ipv6", "show", "neighbors") + output, err := cmd.Output() + if err != nil { + return nil, fmt.Errorf("failed to execute netsh interface 
ipv6 show neighbors: %w", err) + } + + var peers []Peer + scanner := bufio.NewScanner(strings.NewReader(string(output))) + + // Windows netsh output format: + // Interface 12: Ethernet + // fe80::1 aa-bb-cc-dd-ee-ff Permanent + // 2001:db8::1 aa-bb-cc-dd-ee-ff Reachable + + inTable := false + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" { + continue + } + + // Check if we're entering the neighbors table + if strings.Contains(line, "Interface") { + inTable = true + continue + } + + if !inTable { + continue + } + + // Skip header lines + if strings.Contains(line, "IPv6 Address") || strings.Contains(line, "---") { + continue + } + + fields := strings.Fields(line) + if len(fields) < 2 { + continue + } + + ipStr := fields[0] + macStr := fields[1] + + // Skip incomplete entries + if macStr == "incomplete" || strings.HasPrefix(macStr, "ff-ff-ff-ff-ff-ff") { + continue + } + + // Convert Windows MAC format (aa-bb-cc-dd-ee-ff) to standard format (aa:bb:cc:dd:ee:ff) + macStr = strings.ReplaceAll(macStr, "-", ":") + + ip := net.ParseIP(ipStr) + if ip == nil { + continue + } + + // Only process IPv6 addresses + if ip.To4() != nil { + continue + } + + mac, err := net.ParseMAC(macStr) + if err != nil { + continue + } + + peers = append(peers, Peer{ + IP: ip, + MAC: mac, + }) + } + + return peers, scanner.Err() +} + diff --git a/pkg/peerdiscovery/pingsweep/doc.go b/pkg/peerdiscovery/pingsweep/doc.go new file mode 100644 index 0000000..1300a1f --- /dev/null +++ b/pkg/peerdiscovery/pingsweep/doc.go @@ -0,0 +1,31 @@ +// Package pingsweep provides a library for discovering active hosts on a network +// using ICMP ping sweep. 
+// +// The package provides two main functions: +// - DiscoverPeers: Scans a provided list of CIDRs or IPs +// - Autodiscover: Automatically discovers and scans local network interfaces +// +// Discovery is performed by: +// - Expanding network ranges to individual IPs +// - Sending ICMP echo requests to each IP in parallel using an adaptive waitgroup (10 workers) +// - Collecting responses to identify active hosts +// +// Example usage: +// +// // Manual scan of specific targets +// targets := []string{"192.168.1.0/24", "10.0.0.1"} +// peers, err := pingsweep.DiscoverPeers(ctx, targets) +// +// // Automatic discovery of local networks +// peers, err := pingsweep.Autodiscover(ctx) +// +// Privilege Requirements: +// - Raw ICMP sockets require root/admin privileges on most systems +// - Consider using alternative methods if privileges are not available +// +// Limitations: +// - Hosts with ICMP disabled or firewalled will not respond +// - Some networks may rate-limit ICMP traffic +// - Large network scans may take significant time +package pingsweep + diff --git a/pkg/peerdiscovery/pingsweep/pingsweep.go b/pkg/peerdiscovery/pingsweep/pingsweep.go new file mode 100644 index 0000000..78e2ada --- /dev/null +++ b/pkg/peerdiscovery/pingsweep/pingsweep.go @@ -0,0 +1,318 @@ +package pingsweep + +import ( + "context" + "fmt" + "net" + "time" + + "github.com/projectdiscovery/mapcidr" + "github.com/projectdiscovery/pd-agent/pkg/peerdiscovery/common" + mapsutil "github.com/projectdiscovery/utils/maps" + "golang.org/x/net/icmp" +) + +// Peer represents a discovered ping peer +type Peer struct { + IP net.IP + MAC net.HardwareAddr // Optional, may be nil + RTT time.Duration // Round-trip time for the ping +} + +// DiscoverPeers scans the provided CIDRs or IPs and returns discovered active peers +// targets can be CIDR notation (e.g., "192.168.1.0/24") or individual IPs (e.g., "192.168.1.1") +func DiscoverPeers(ctx context.Context, targets []string) ([]Peer, error) { + peers := 
mapsutil.NewSyncLockMap[string, *Peer]()

	// Parse targets into networks and individual IPs
	networks, individualIPs, err := parseTargets(targets)
	if err != nil {
		return nil, fmt.Errorf("failed to parse targets: %w", err)
	}

	// Scan networks; cancellation jumps straight to result assembly so
	// peers found so far are still returned.
	for _, network := range networks {
		select {
		case <-ctx.Done():
			goto done
		default:
		}

		discovered, err := scanNetwork(ctx, network)
		if err != nil {
			// best-effort: one failing network does not abort the sweep
			continue
		}

		for _, peer := range discovered {
			key := peer.IP.String()
			if _, exists := peers.Get(key); !exists {
				peerCopy := peer
				_ = peers.Set(key, &peerCopy)
			}
		}
	}

	// Scan individual IPs
	if len(individualIPs) > 0 {
		discovered, err := scanIPs(ctx, individualIPs)
		if err == nil {
			for _, peer := range discovered {
				key := peer.IP.String()
				if _, exists := peers.Get(key); !exists {
					peerCopy := peer
					_ = peers.Set(key, &peerCopy)
				}
			}
		}
	}

done:
	// Convert map to slice
	var result []Peer
	_ = peers.Iterate(func(key string, peer *Peer) error {
		if peer != nil {
			result = append(result, *peer)
		}
		return nil
	})

	return result, nil
}

// Autodiscover retrieves all active peers by automatically discovering and scanning local networks using ICMP pings
func Autodiscover(ctx context.Context) ([]Peer, error) {
	// Get local network ranges from local interfaces
	networks, err := common.GetLocalNetworks()
	if err != nil {
		return nil, fmt.Errorf("failed to get local networks: %w", err)
	}

	// Convert networks to string targets
	targets := make([]string, 0, len(networks))
	for _, network := range networks {
		targets = append(targets, network.String())
	}

	return DiscoverPeers(ctx, targets)
}

// parseTargets parses a list of target strings into networks and individual IPs.
// Duplicates are removed; any entry that is neither a CIDR nor an IP is an error.
func parseTargets(targets []string) ([]*net.IPNet, []net.IP, error) {
	var networks []*net.IPNet
	var individualIPs []net.IP
	seenNetworks := make(map[string]struct{})
	seenIPs := make(map[string]struct{})

	for _, target := range targets {
		// Try to parse as CIDR first
		_, ipNet, err := net.ParseCIDR(target)
		if err == nil {
			// It's a CIDR
			key := ipNet.String()
			if _, exists := seenNetworks[key]; !exists {
				seenNetworks[key] = struct{}{}
				networks = append(networks, ipNet)
			}
			continue
		}

		// Try to parse as individual IP
		ip := net.ParseIP(target)
		if ip != nil {
			key := ip.String()
			if _, exists := seenIPs[key]; !exists {
				seenIPs[key] = struct{}{}
				individualIPs = append(individualIPs, ip)
			}
			continue
		}

		// Neither CIDR nor IP
		return nil, nil, fmt.Errorf("invalid target format: %s (must be CIDR or IP)", target)
	}

	return networks, individualIPs, nil
}

// scanNetwork scans a network range to discover active peers using ICMP pings
func scanNetwork(ctx context.Context, network *net.IPNet) ([]Peer, error) {
	// Expand CIDR to get all IPs in range
	cidrStr := network.String()
	ips, err := mapcidr.IPAddresses(cidrStr)
	if err != nil {
		return nil, fmt.Errorf("failed to expand CIDR %s: %w", cidrStr, err)
	}

	if len(ips) == 0 {
		return []Peer{}, nil
	}

	// Filter IPs and determine if IPv4 or IPv6.
	// NOTE(review): a single flag selects the socket family for the whole
	// batch — a mixed v4/v6 range would send IPv4 targets over an IPv6
	// socket. Confirm inputs are always single-family, or split by family.
	var targetIPs []net.IP
	isIPv6 := false
	for _, ipStr := range ips {
		ip := net.ParseIP(ipStr)
		if ip == nil {
			continue
		}

		// Skip network and broadcast/multicast addresses
		if common.IsNetworkOrBroadcast(ip, network) {
			continue
		}

		if ip.To4() == nil {
			isIPv6 = true
		}
		targetIPs = append(targetIPs, ip)
	}

	if len(targetIPs) == 0 {
		return []Peer{}, nil
	}

	// Use shared connection approach for better reply matching
	return scanIPsWithSharedConnection(ctx, targetIPs, isIPv6)
}

// scanIPs scans a list of individual IPs to discover active peers using ICMP pings
func scanIPs(ctx context.Context, ips []net.IP) ([]Peer, error) {
	if len(ips) == 0 {
		return []Peer{}, nil
	}

	// Determine if IPv4 or IPv6 (see mixed-family note on scanNetwork)
	isIPv6 := false
	for _, ip := range ips {
		if ip.To4() == nil {
			isIPv6 = true
			break
		}
	}

	// Use shared connection approach for better reply matching
	return scanIPsWithSharedConnection(ctx, ips, isIPv6)
}

// scanIPsWithSharedConnection uses a shared ICMP connection to send pings and match replies.
// A single receiver goroutine correlates replies to pending pings by sequence number.
func scanIPsWithSharedConnection(ctx context.Context, ips []net.IP, isIPv6 bool) ([]Peer, error) {
	const maxRetries = 4
	// Match ping command timeout: default is typically 1-2 seconds, use 2 seconds to account for network latency
	const pingTimeout = 2 * time.Second

	// Create shared ICMP connection
	conn, err := createSharedICMPConnection(isIPv6)
	if err != nil {
		return nil, fmt.Errorf("failed to create shared ICMP connection: %w", err)
	}
	defer func() {
		_ = conn.Close()
	}()

	// Map to track pending pings: sequence number -> pending ping info
	pendingPings := mapsutil.NewSyncLockMap[int, *pendingPing]()

	// Map to store successful peers
	peers := mapsutil.NewSyncLockMap[string, *Peer]()

	// Channel for receiver to signal completion
	receiverDone := make(chan struct{})

	// Start receiver goroutine to match replies
	go func() {
		defer close(receiverDone)
		receiveReplies(ctx, conn, pendingPings, peers, isIPv6, pingTimeout)
	}()

	// Generate unique sequence numbers starting from 1.
	// NOTE(review): the ICMP echo sequence field is 16 bits on the wire, but
	// this counter is unbounded — beyond 65535 sends (large ranges x 5
	// attempts) reply Seq values wrap and no longer match these map keys.
	seqCounter := 0
	getNextSeq := func() int {
		seqCounter++
		return seqCounter
	}

	// Iterate over input for retries times (initial attempt + maxRetries)
	for attempt := 0; attempt <= maxRetries; attempt++ {
		select {
		case <-ctx.Done():
			goto done
		default:
		}

		// Send pings for all IPs in this attempt
		for _, ip := range ips {
			select {
			case <-ctx.Done():
				goto done
			default:
			}

			// Skip if we already have a successful peer for this IP
			if _, exists := peers.Get(ip.String()); exists {
				continue
			}

			seq := getNextSeq()
			start := time.Now()

			// Track pending ping before sending so the receiver can match
			// a reply that arrives immediately
			pending := &pendingPing{
				IP:      ip,
				Start:   start,
				Seq:     seq,
				Retries: attempt,
			}
			_ = pendingPings.Set(seq, pending)
+ + // Send ping + if err := sendPing(conn, ip, seq, isIPv6); err != nil { + pendingPings.Delete(seq) + continue + } + } + + // Wait for replies before next attempt (except on last attempt) + // Match ping command behavior: wait for timeout before retrying + if attempt < maxRetries { + time.Sleep(pingTimeout) + } + } + +done: + // Wait for final replies - match ping command final wait time + // Give extra time for any delayed replies + finalTimeout := pingTimeout + select { + case <-receiverDone: + case <-time.After(finalTimeout): + // Receiver should finish on its own, but we have a timeout + case <-ctx.Done(): + } + + // Convert map to slice + var result []Peer + _ = peers.Iterate(func(key string, peer *Peer) error { + if peer != nil { + result = append(result, *peer) + } + return nil + }) + + return result, nil +} + +// pendingPing tracks a sent ping waiting for reply +type pendingPing struct { + IP net.IP + Start time.Time + Seq int + Retries int +} + +// createSharedICMPConnection creates a shared ICMP connection +func createSharedICMPConnection(isIPv6 bool) (net.PacketConn, error) { + if isIPv6 { + return icmp.ListenPacket("ip6:ipv6-icmp", "::") + } + return icmp.ListenPacket("ip4:icmp", "0.0.0.0") +} diff --git a/pkg/peerdiscovery/pingsweep/pingsweep_unix.go b/pkg/peerdiscovery/pingsweep/pingsweep_unix.go new file mode 100644 index 0000000..877c7ca --- /dev/null +++ b/pkg/peerdiscovery/pingsweep/pingsweep_unix.go @@ -0,0 +1,149 @@ +//go:build !windows + +package pingsweep + +import ( + "context" + "fmt" + "net" + "os" + "time" + + mapsutil "github.com/projectdiscovery/utils/maps" + "golang.org/x/net/icmp" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" +) + +// sendPing sends an ICMP echo request through the shared connection +func sendPing(conn net.PacketConn, ip net.IP, seq int, isIPv6 bool) error { + var msgType icmp.Type + + if isIPv6 { + msgType = ipv6.ICMPTypeEchoRequest + } else { + msgType = ipv4.ICMPTypeEcho + } + + msg := &icmp.Message{ + Type: 
msgType,
		Code: 0,
		Body: &icmp.Echo{
			// ID identifies this process's pings; Seq identifies the individual ping
			ID:   os.Getpid() & 0xffff,
			Seq:  seq,
			Data: []byte("HELLO-R-U-THERE"),
		},
	}

	msgBytes, err := msg.Marshal(nil)
	if err != nil {
		return fmt.Errorf("failed to marshal ICMP message: %w", err)
	}

	dst := &net.IPAddr{IP: ip}
	_, err = conn.WriteTo(msgBytes, dst)
	return err
}

// receiveReplies receives and matches ICMP echo replies.
// It loops until the deadline passes with no pending pings or ctx is cancelled,
// matching each reply to a pending ping by ID, sequence number, and source IP.
func receiveReplies(ctx context.Context, conn net.PacketConn, pendingPings *mapsutil.SyncLockMap[int, *pendingPing], peers *mapsutil.SyncLockMap[string, *Peer], isIPv6 bool, timeout time.Duration) {
	var echoReplyType icmp.Type
	var protocol int

	if isIPv6 {
		echoReplyType = ipv6.ICMPTypeEchoReply
		protocol = ipv6.ICMPTypeEchoReply.Protocol()
	} else {
		echoReplyType = ipv4.ICMPTypeEchoReply
		protocol = ipv4.ICMPTypeEchoReply.Protocol()
	}

	expectedID := os.Getpid() & 0xffff
	// Use a longer initial deadline to match ping command behavior
	// Ping typically waits 1-2 seconds per ping, we'll use the timeout parameter
	deadline := time.Now().Add(timeout * 3) // Give more time for all pings to complete

	for {
		// Check if we should stop
		select {
		case <-ctx.Done():
			return
		default:
		}

		// Check if deadline passed
		if time.Now().After(deadline) {
			// Check if there are still pending pings
			hasPending := false
			_ = pendingPings.Iterate(func(key int, ping *pendingPing) error {
				hasPending = true
				return nil
			})
			if !hasPending {
				return
			}
			// Extend deadline if there are still pending pings - give more time
			deadline = time.Now().Add(timeout)
		}

		// Set read deadline - use longer timeout to match ping command
		// Ping command typically uses 1-2 second timeouts, we'll check more frequently but with longer socket timeout
		if err := conn.SetReadDeadline(time.Now().Add(500 * time.Millisecond)); err != nil {
			continue
		}

		// Read reply (1500 bytes covers a standard Ethernet MTU)
		reply := make([]byte, 1500)
		n, peer, err := conn.ReadFrom(reply)
		if err != nil {
			// Timeout 
or error, continue + continue + } + + // Parse reply + rm, err := icmp.ParseMessage(protocol, reply[:n]) + if err != nil { + continue + } + + // Verify it's an echo reply + if rm.Type != echoReplyType { + continue + } + + // Check if this is our reply + echo, ok := rm.Body.(*icmp.Echo) + if !ok { + continue + } + + // Match by ID + if echo.ID != expectedID { + continue + } + + // Find matching pending ping + pending, exists := pendingPings.Get(echo.Seq) + if !exists { + continue + } + + // Verify peer IP matches + if peerAddr, ok := peer.(*net.IPAddr); !ok || !peerAddr.IP.Equal(pending.IP) { + continue + } + + // Calculate RTT + rtt := time.Since(pending.Start) + + // Store successful peer + discoveredPeer := &Peer{ + IP: pending.IP, + RTT: rtt, + } + _ = peers.Set(pending.IP.String(), discoveredPeer) + + // Remove from pending + pendingPings.Delete(echo.Seq) + } +} diff --git a/pkg/peerdiscovery/pingsweep/pingsweep_windows.go b/pkg/peerdiscovery/pingsweep/pingsweep_windows.go new file mode 100644 index 0000000..4de0610 --- /dev/null +++ b/pkg/peerdiscovery/pingsweep/pingsweep_windows.go @@ -0,0 +1,149 @@ +//go:build windows + +package pingsweep + +import ( + "context" + "fmt" + "net" + "os" + "time" + + mapsutil "github.com/projectdiscovery/utils/maps" + "golang.org/x/net/icmp" + "golang.org/x/net/ipv4" + "golang.org/x/net/ipv6" +) + +// sendPing sends an ICMP echo request through the shared connection +func sendPing(conn net.PacketConn, ip net.IP, seq int, isIPv6 bool) error { + var msgType icmp.Type + + if isIPv6 { + msgType = ipv6.ICMPTypeEchoRequest + } else { + msgType = ipv4.ICMPTypeEcho + } + + msg := &icmp.Message{ + Type: msgType, + Code: 0, + Body: &icmp.Echo{ + ID: os.Getpid() & 0xffff, + Seq: seq, + Data: []byte("HELLO-R-U-THERE"), + }, + } + + msgBytes, err := msg.Marshal(nil) + if err != nil { + return fmt.Errorf("failed to marshal ICMP message: %w", err) + } + + dst := &net.IPAddr{IP: ip} + _, err = conn.WriteTo(msgBytes, dst) + return err +} + +// 
// receiveReplies receives and matches ICMP echo replies.
//
// It loops until ctx is cancelled, or until the (extendable) deadline passes
// with no entries left in pendingPings. Every datagram read from conn is
// parsed and filtered in order by: ICMP message type, echo ID (derived from
// the process ID), sequence number (the pendingPings key), and finally the
// sender's IP. Matches are recorded in peers keyed by the IP's string form,
// with RTT measured from the pending entry's start time.
//
// NOTE(review): replies are matched to pendingPings by ICMP sequence number
// alone, so sequence numbers are assumed unique across in-flight pings —
// confirm the sender guarantees this.
func receiveReplies(ctx context.Context, conn net.PacketConn, pendingPings *mapsutil.SyncLockMap[int, *pendingPing], peers *mapsutil.SyncLockMap[string, *Peer], isIPv6 bool, timeout time.Duration) {
	var echoReplyType icmp.Type
	var protocol int

	// Select the reply type and the protocol number used for parsing,
	// depending on the address family of the shared socket.
	if isIPv6 {
		echoReplyType = ipv6.ICMPTypeEchoReply
		protocol = ipv6.ICMPTypeEchoReply.Protocol()
	} else {
		echoReplyType = ipv4.ICMPTypeEchoReply
		protocol = ipv4.ICMPTypeEchoReply.Protocol()
	}

	// Only replies carrying this process's echo ID are ours (must match the
	// ID set by sendPing).
	expectedID := os.Getpid() & 0xffff
	// Use a longer initial deadline to match ping command behavior
	// Ping typically waits 1-2 seconds per ping, we'll use the timeout parameter
	deadline := time.Now().Add(timeout * 3) // Give more time for all pings to complete

	for {
		// Check if we should stop
		select {
		case <-ctx.Done():
			return
		default:
		}

		// Check if deadline passed
		if time.Now().After(deadline) {
			// Check if there are still pending pings
			// (Iterate visits at least one entry if the map is non-empty;
			// the callback only flips the flag.)
			hasPending := false
			_ = pendingPings.Iterate(func(key int, ping *pendingPing) error {
				hasPending = true
				return nil
			})
			if !hasPending {
				return
			}
			// Extend deadline if there are still pending pings - give more time
			// NOTE(review): if pending entries are never answered and never
			// removed elsewhere, this extends forever; ctx cancellation is
			// the only hard stop — confirm the caller always cancels.
			deadline = time.Now().Add(timeout)
		}

		// Set read deadline - use longer timeout to match ping command
		// Ping command typically uses 1-2 second timeouts, we'll check more frequently but with longer socket timeout
		// NOTE(review): a persistently failing SetReadDeadline turns this
		// into a busy loop with no backoff — consider handling the error.
		if err := conn.SetReadDeadline(time.Now().Add(500 * time.Millisecond)); err != nil {
			continue
		}

		// Read reply
		// (1500 bytes covers a typical Ethernet MTU-sized datagram.)
		reply := make([]byte, 1500)
		n, peer, err := conn.ReadFrom(reply)
		if err != nil {
			// Timeout or error, continue
			continue
		}

		// Parse reply
		rm, err := icmp.ParseMessage(protocol, reply[:n])
		if err != nil {
			continue
		}

		// Verify it's an echo reply
		if rm.Type != echoReplyType {
			continue
		}

		// Check if this is our reply
		echo, ok := rm.Body.(*icmp.Echo)
		if !ok {
			continue
		}

		// Match by ID
		if echo.ID != expectedID {
			continue
		}

		// Find matching pending ping
		pending, exists := pendingPings.Get(echo.Seq)
		if !exists {
			continue
		}

		// Verify peer IP matches
		// NOTE(review): assumes ReadFrom yields a *net.IPAddr; replies with
		// any other concrete address type are dropped — confirm this holds
		// for the socket type used on Windows.
		if peerAddr, ok := peer.(*net.IPAddr); !ok || !peerAddr.IP.Equal(pending.IP) {
			continue
		}

		// Calculate RTT
		rtt := time.Since(pending.Start)

		// Store successful peer
		discoveredPeer := &Peer{
			IP:  pending.IP,
			RTT: rtt,
		}
		_ = peers.Set(pending.IP.String(), discoveredPeer)

		// Remove from pending
		pendingPings.Delete(echo.Seq)
	}
}
reserved"}, + {RangeStart: 250, RangeEnd: 253, Priority: PriorityTier2, Description: "High-end reserved"}, + + // Early DHCP - devices that connect first + {RangeStart: 6, RangeEnd: 10, Priority: PriorityTier3, Description: "Early DHCP allocation"}, + + // DHCP allocation peaks + {RangeStart: 50, RangeEnd: 50, Priority: PriorityTier4, Description: "DHCP peak 1"}, + {RangeStart: 100, RangeEnd: 100, Priority: PriorityTier4, Description: "DHCP peak 2"}, + {RangeStart: 150, RangeEnd: 150, Priority: PriorityTier4, Description: "DHCP peak 3"}, + + // Main DHCP pool + {RangeStart: 51, RangeEnd: 99, Priority: PriorityTier5, Description: "DHCP range 1"}, + {RangeStart: 101, RangeEnd: 149, Priority: PriorityTier5, Description: "DHCP range 2"}, + {RangeStart: 151, RangeEnd: 200, Priority: PriorityTier5, Description: "DHCP range 3"}, + + // Long-tail - lower probability + {RangeStart: 11, RangeEnd: 49, Priority: PriorityTier6, Description: "Long-tail 1"}, + {RangeStart: 201, RangeEnd: 249, Priority: PriorityTier6, Description: "Long-tail 2"}, + + // Excluded addresses + {RangeStart: 0, RangeEnd: 0, Priority: PriorityTier7, Description: "Network address"}, + {RangeStart: 255, RangeEnd: 255, Priority: PriorityTier7, Description: "Broadcast address"}, + } +} + +// calculateIPv4Priority scores an IPv4 address based on last octet patterns. +func calculateIPv4Priority(ip net.IP, network *net.IPNet) int { + ip4 := ip.To4() + if ip4 == nil { + return PriorityTier6 + } + + lastOctet := int(ip4[3]) + + // Skip network/broadcast + if common.IsNetworkOrBroadcast(ip, network) { + return PriorityTier7 + } + + // Match against known patterns + patterns := getDistributionPatterns() + for _, pattern := range patterns { + if lastOctet >= pattern.RangeStart && lastOctet <= pattern.RangeEnd { + return pattern.Priority + } + } + + return PriorityTier6 +} + +// adaptPriorityForSubnet handles non-/24 subnets using /24 logic for now. 
+func adaptPriorityForSubnet(ip net.IP, network *net.IPNet) int { + ones, bits := network.Mask.Size() + if bits != 32 { + return PriorityTier6 + } + + // /24 gets the full treatment + if ones == 24 { + return calculateIPv4Priority(ip, network) + } + + ip4 := ip.To4() + if ip4 == nil { + return PriorityTier6 + } + + if common.IsNetworkOrBroadcast(ip, network) { + return PriorityTier7 + } + + // For other sizes, just use /24 logic for now + return calculateIPv4Priority(ip, network) +} diff --git a/pkg/peerdiscovery/prescan/doc.go b/pkg/peerdiscovery/prescan/doc.go new file mode 100644 index 0000000..c753892 --- /dev/null +++ b/pkg/peerdiscovery/prescan/doc.go @@ -0,0 +1,24 @@ +// Package prescan selects the most likely-to-be-online IPs from a CIDR based on +// real-world network patterns. Uses Pareto principle - most active hosts are in +// the top 20% of IPs (routers, gateways, early DHCP allocations). +// +// Priority tiers (0-100): +// - 100: .1, .254 (routers/gateways - always check these first) +// - 90: .2-.5, .250-.253 (reserved infrastructure) +// - 80: .6-.10 (early DHCP - devices that connect first) +// - 70: .50, .100, .150 (DHCP peaks - common allocation points) +// - 50: .51-.99, .101-.149, .151-.200 (main DHCP pool) +// - 20: .11-.49, .201-.249 (long-tail, lower probability) +// - 0: .0, .255 (excluded - network/broadcast) +// +// Example: +// +// // Get top 25% most likely IPs +// ips, err := prescan.SelectIPs("192.168.1.0/24", 0.25) +// +// // Or get exactly 50 IPs +// ips, err := prescan.SelectIPsWithCount("192.168.1.0/24", 50) +// +// O(n log n) complexity. For huge networks, use SelectIPsWithCount. 
+package prescan + diff --git a/pkg/peerdiscovery/prescan/prescan.go b/pkg/peerdiscovery/prescan/prescan.go new file mode 100644 index 0000000..abef19e --- /dev/null +++ b/pkg/peerdiscovery/prescan/prescan.go @@ -0,0 +1,175 @@ +package prescan + +import ( + "fmt" + "math" + "net" + "sort" + + "github.com/projectdiscovery/mapcidr" +) + +// SelectIPs returns the top N% of IPs from a CIDR, sorted by priority. +// ratio is 0.0-1.0 (e.g., 0.25 = 25%). +func SelectIPs(cidr string, ratio float64) ([]net.IP, error) { + // Clamp ratio to valid range + if ratio < 0 { + ratio = 0 + } + if ratio > 1 { + ratio = 1 + } + + _, network, err := net.ParseCIDR(cidr) + if err != nil { + return nil, fmt.Errorf("invalid CIDR: %w", err) + } + + ips, err := mapcidr.IPAddresses(cidr) + if err != nil { + return nil, fmt.Errorf("failed to expand CIDR: %w", err) + } + + if len(ips) == 0 { + return []net.IP{}, nil + } + + // Drop network/broadcast addresses + usableIPs := filterUsableIPs(ips, network) + if len(usableIPs) == 0 { + return []net.IP{}, nil + } + + // Score each IP + prioritized := make([]PrioritizedIP, 0, len(usableIPs)) + for _, ip := range usableIPs { + prioritized = append(prioritized, PrioritizedIP{ + IP: ip, + Priority: CalculatePriority(ip, network), + }) + } + + // Sort by priority (high to low), then by IP for stable ordering + sort.Slice(prioritized, func(i, j int) bool { + if prioritized[i].Priority != prioritized[j].Priority { + return prioritized[i].Priority > prioritized[j].Priority + } + return compareIP(prioritized[i].IP, prioritized[j].IP) < 0 + }) + + targetCount := int(math.Ceil(float64(len(usableIPs)) * ratio)) + // If ratio > 0 but math gives us 0, at least return 1 IP + if ratio > 0 && targetCount == 0 && len(usableIPs) > 0 { + targetCount = 1 + } + if targetCount > len(prioritized) { + targetCount = len(prioritized) + } + + result := make([]net.IP, 0, targetCount) + for i := 0; i < targetCount; i++ { + result = append(result, prioritized[i].IP) + } + + return 
result, nil +} + +// SelectIPsWithCount returns exactly N highest-priority IPs from a CIDR. +func SelectIPsWithCount(cidr string, count int) ([]net.IP, error) { + if count <= 0 { + return []net.IP{}, nil + } + + _, network, err := net.ParseCIDR(cidr) + if err != nil { + return nil, fmt.Errorf("invalid CIDR: %w", err) + } + + ips, err := mapcidr.IPAddresses(cidr) + if err != nil { + return nil, fmt.Errorf("failed to expand CIDR: %w", err) + } + + if len(ips) == 0 { + return []net.IP{}, nil + } + + usableIPs := filterUsableIPs(ips, network) + if len(usableIPs) == 0 { + return []net.IP{}, nil + } + + // Score and sort + prioritized := make([]PrioritizedIP, 0, len(usableIPs)) + for _, ip := range usableIPs { + prioritized = append(prioritized, PrioritizedIP{ + IP: ip, + Priority: CalculatePriority(ip, network), + }) + } + + sort.Slice(prioritized, func(i, j int) bool { + if prioritized[i].Priority != prioritized[j].Priority { + return prioritized[i].Priority > prioritized[j].Priority + } + return compareIP(prioritized[i].IP, prioritized[j].IP) < 0 + }) + + if count > len(prioritized) { + count = len(prioritized) + } + + result := make([]net.IP, 0, count) + for i := 0; i < count; i++ { + result = append(result, prioritized[i].IP) + } + + return result, nil +} + +// compareIP compares two IPs. Returns -1 if ip1 < ip2, 0 if equal, 1 if ip1 > ip2. +// IPv4 always comes before IPv6. 
+func compareIP(ip1, ip2 net.IP) int { + ip1v4 := ip1.To4() + ip2v4 := ip2.To4() + + // IPv4 < IPv6 + if ip1v4 != nil && ip2v4 == nil { + return -1 + } + if ip1v4 == nil && ip2v4 != nil { + return 1 + } + + // Both IPv4 + if ip1v4 != nil && ip2v4 != nil { + for i := 0; i < len(ip1v4); i++ { + if ip1v4[i] < ip2v4[i] { + return -1 + } + if ip1v4[i] > ip2v4[i] { + return 1 + } + } + return 0 + } + + // Both IPv6 + for i := 0; i < len(ip1) && i < len(ip2); i++ { + if ip1[i] < ip2[i] { + return -1 + } + if ip1[i] > ip2[i] { + return 1 + } + } + + if len(ip1) < len(ip2) { + return -1 + } + if len(ip1) > len(ip2) { + return 1 + } + + return 0 +} diff --git a/pkg/peerdiscovery/prescan/prescan_test.go b/pkg/peerdiscovery/prescan/prescan_test.go new file mode 100644 index 0000000..3232ddb --- /dev/null +++ b/pkg/peerdiscovery/prescan/prescan_test.go @@ -0,0 +1,857 @@ +package prescan + +import ( + "net" + "testing" + + "github.com/projectdiscovery/pd-agent/pkg/peerdiscovery/common" +) + +func TestSelectIPs(t *testing.T) { + tests := []struct { + name string + cidr string + ratio float64 + wantCount int + wantErr bool + validate func(t *testing.T, ips []net.IP) + }{ + { + name: "25% of /24 network", + cidr: "192.168.1.0/24", + ratio: 0.25, + wantCount: 64, // 254 usable IPs * 0.25 = 63.5, rounded up to 64 + wantErr: false, + validate: func(t *testing.T, ips []net.IP) { + // Should include high-priority IPs + has1 := false + has254 := false + for _, ip := range ips { + ip4 := ip.To4() + if ip4 != nil && ip4[3] == 1 { + has1 = true + } + if ip4 != nil && ip4[3] == 254 { + has254 = true + } + } + if !has1 { + t.Error("Expected to include .1 (router)") + } + if !has254 { + t.Error("Expected to include .254 (gateway)") + } + }, + }, + { + name: "50% of /24 network", + cidr: "192.168.1.0/24", + ratio: 0.5, + wantCount: 127, // 254 usable IPs * 0.5 = 127 + wantErr: false, + }, + { + name: "100% of /24 network", + cidr: "192.168.1.0/24", + ratio: 1.0, + wantCount: 254, // All usable 
// TestSelectIPsWithCount covers the count-based selection API: exact count,
// zero, over-subscription (capped at the 254 usable /24 addresses), invalid
// input, the degenerate /32 case, and negative counts.
func TestSelectIPsWithCount(t *testing.T) {
	tests := []struct {
		name     string
		cidr     string
		count    int
		wantErr  bool
		validate func(t *testing.T, ips []net.IP) // optional per-case checks on the result
	}{
		{
			name:    "Select 50 IPs from /24",
			cidr:    "192.168.1.0/24",
			count:   50,
			wantErr: false,
			validate: func(t *testing.T, ips []net.IP) {
				if len(ips) != 50 {
					t.Errorf("Expected 50 IPs, got %d", len(ips))
				}
				// Should include high-priority IPs
				has1 := false
				has254 := false
				for _, ip := range ips {
					ip4 := ip.To4()
					if ip4 != nil && ip4[3] == 1 {
						has1 = true
					}
					if ip4 != nil && ip4[3] == 254 {
						has254 = true
					}
				}
				if !has1 {
					t.Error("Expected to include .1")
				}
				if !has254 {
					t.Error("Expected to include .254")
				}
			},
		},
		{
			name:    "Select 0 IPs",
			cidr:    "192.168.1.0/24",
			count:   0,
			wantErr: false,
			validate: func(t *testing.T, ips []net.IP) {
				if len(ips) != 0 {
					t.Errorf("Expected 0 IPs, got %d", len(ips))
				}
			},
		},
		{
			name:    "Select more than available",
			cidr:    "192.168.1.0/24",
			count:   1000,
			wantErr: false,
			validate: func(t *testing.T, ips []net.IP) {
				// Should return all usable IPs (254)
				if len(ips) != 254 {
					t.Errorf("Expected 254 IPs, got %d", len(ips))
				}
			},
		},
		{
			name:    "Invalid CIDR",
			cidr:    "invalid",
			count:   10,
			wantErr: true,
		},
		{
			name:    "Single host /32",
			cidr:    "192.168.1.1/32",
			count:   10,
			wantErr: false,
			validate: func(t *testing.T, ips []net.IP) {
				// /32 has no usable IPs (network = broadcast)
				if len(ips) != 0 {
					t.Errorf("Expected 0 IPs for /32, got %d", len(ips))
				}
			},
		},
		{
			name:    "Negative count",
			cidr:    "192.168.1.0/24",
			count:   -5,
			wantErr: false,
			validate: func(t *testing.T, ips []net.IP) {
				if len(ips) != 0 {
					t.Errorf("Expected 0 IPs for negative count, got %d", len(ips))
				}
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ips, err := SelectIPsWithCount(tt.cidr, tt.count)
			// Error expectation is checked first; result validation only
			// runs for cases that supply a validate func.
			if (err != nil) != tt.wantErr {
				t.Errorf("SelectIPsWithCount() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if tt.validate != nil {
				tt.validate(t, ips)
			}
		})
	}
}
net.ParseCIDR("192.168.1.0/24") + + tests := []struct { + name string + ip string + network *net.IPNet + want int + wantErr bool + }{ + { + name: "Infrastructure .1", + ip: "192.168.1.1", + network: network, + want: PriorityTier1, + }, + { + name: "Infrastructure .254", + ip: "192.168.1.254", + network: network, + want: PriorityTier1, + }, + { + name: "Reserved .2", + ip: "192.168.1.2", + network: network, + want: PriorityTier2, + }, + { + name: "Reserved .5", + ip: "192.168.1.5", + network: network, + want: PriorityTier2, + }, + { + name: "Reserved .250", + ip: "192.168.1.250", + network: network, + want: PriorityTier2, + }, + { + name: "Early DHCP .6", + ip: "192.168.1.6", + network: network, + want: PriorityTier3, + }, + { + name: "Early DHCP .10", + ip: "192.168.1.10", + network: network, + want: PriorityTier3, + }, + { + name: "DHCP peak .50", + ip: "192.168.1.50", + network: network, + want: PriorityTier4, + }, + { + name: "DHCP peak .100", + ip: "192.168.1.100", + network: network, + want: PriorityTier4, + }, + { + name: "DHCP peak .150", + ip: "192.168.1.150", + network: network, + want: PriorityTier4, + }, + { + name: "DHCP range .51", + ip: "192.168.1.51", + network: network, + want: PriorityTier5, + }, + { + name: "DHCP range .200", + ip: "192.168.1.200", + network: network, + want: PriorityTier5, + }, + { + name: "Long-tail .25", + ip: "192.168.1.25", + network: network, + want: PriorityTier6, + }, + { + name: "Long-tail .240", + ip: "192.168.1.240", + network: network, + want: PriorityTier6, + }, + { + name: "Network address .0", + ip: "192.168.1.0", + network: network, + want: PriorityTier7, + }, + { + name: "Broadcast address .255", + ip: "192.168.1.255", + network: network, + want: PriorityTier7, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ip := net.ParseIP(tt.ip) + if ip == nil { + t.Fatalf("Failed to parse IP: %s", tt.ip) + } + got := CalculatePriority(ip, tt.network) + if got != tt.want { + 
t.Errorf("CalculatePriority() = %d, want %d", got, tt.want) + } + }) + } +} + +func TestPriorityOrdering(t *testing.T) { + // Test that SelectIPs returns IPs in priority order + ips, err := SelectIPs("192.168.1.0/24", 0.5) + if err != nil { + t.Fatalf("SelectIPs() error = %v", err) + } + + if len(ips) == 0 { + t.Fatal("Expected at least some IPs") + } + + // Calculate priorities for all returned IPs + _, network, _ := net.ParseCIDR("192.168.1.0/24") + priorities := make([]int, len(ips)) + for i, ip := range ips { + priorities[i] = CalculatePriority(ip, network) + } + + // Check that priorities are in descending order + for i := 1; i < len(priorities); i++ { + if priorities[i] > priorities[i-1] { + t.Errorf("IPs not in priority order: priority[%d]=%d > priority[%d]=%d", + i, priorities[i], i-1, priorities[i-1]) + } + } +} + +func TestHighPriorityIPsIncluded(t *testing.T) { + // Test that high-priority IPs are included even in small ratios + ips, err := SelectIPs("192.168.1.0/24", 0.05) // 5% = ~13 IPs + if err != nil { + t.Fatalf("SelectIPs() error = %v", err) + } + + // Check for high-priority IPs + has1 := false + has254 := false + has2to5 := false + has250to253 := false + + for _, ip := range ips { + ip4 := ip.To4() + if ip4 == nil { + continue + } + lastOctet := ip4[3] + + if lastOctet == 1 { + has1 = true + } + if lastOctet == 254 { + has254 = true + } + if lastOctet >= 2 && lastOctet <= 5 { + has2to5 = true + } + if lastOctet >= 250 && lastOctet <= 253 { + has250to253 = true + } + } + + if !has1 { + t.Error("Expected .1 to be included in top 5%") + } + if !has254 { + t.Error("Expected .254 to be included in top 5%") + } + if !has2to5 { + t.Error("Expected some .2-.5 IPs to be included in top 5%") + } + if !has250to253 { + t.Error("Expected some .250-.253 IPs to be included in top 5%") + } +} + +func TestDHCPPeaksIncluded(t *testing.T) { + // Test that DHCP peaks are included in reasonable ratios + ips, err := SelectIPs("192.168.1.0/24", 0.15) // 15% = ~38 IPs 
+ if err != nil { + t.Fatalf("SelectIPs() error = %v", err) + } + + has50 := false + has100 := false + has150 := false + + for _, ip := range ips { + ip4 := ip.To4() + if ip4 == nil { + continue + } + lastOctet := ip4[3] + + if lastOctet == 50 { + has50 = true + } + if lastOctet == 100 { + has100 = true + } + if lastOctet == 150 { + has150 = true + } + } + + if !has50 { + t.Error("Expected .50 to be included in top 15%") + } + if !has100 { + t.Error("Expected .100 to be included in top 15%") + } + if !has150 { + t.Error("Expected .150 to be included in top 15%") + } +} + +func TestNetworkAndBroadcastExcluded(t *testing.T) { + ips, err := SelectIPs("192.168.1.0/24", 1.0) // 100% should include all but .0 and .255 + if err != nil { + t.Fatalf("SelectIPs() error = %v", err) + } + + has0 := false + has255 := false + + for _, ip := range ips { + ip4 := ip.To4() + if ip4 == nil { + continue + } + lastOctet := ip4[3] + + if lastOctet == 0 { + has0 = true + } + if lastOctet == 255 { + has255 = true + } + } + + if has0 { + t.Error("Network address .0 should be excluded") + } + if has255 { + t.Error("Broadcast address .255 should be excluded") + } +} + +func TestDifferentSubnetSizes(t *testing.T) { + tests := []struct { + name string + cidr string + ratio float64 + wantMin int + wantMax int + }{ + { + name: "/30 network", + cidr: "192.168.1.0/30", + ratio: 0.5, + wantMin: 1, + wantMax: 2, + }, + { + name: "/28 network", + cidr: "192.168.1.0/28", + ratio: 0.5, + wantMin: 7, + wantMax: 14, + }, + { + name: "/25 network", + cidr: "192.168.1.0/25", + ratio: 0.5, + wantMin: 63, + wantMax: 126, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ips, err := SelectIPs(tt.cidr, tt.ratio) + if err != nil { + t.Fatalf("SelectIPs() error = %v", err) + } + if len(ips) < tt.wantMin || len(ips) > tt.wantMax { + t.Errorf("SelectIPs() count = %d, want between %d and %d", len(ips), tt.wantMin, tt.wantMax) + } + }) + } +} + +func TestCompareIP(t *testing.T) { + 
tests := []struct { + name string + ip1 string + ip2 string + want int + }{ + { + name: "ip1 < ip2", + ip1: "192.168.1.1", + ip2: "192.168.1.2", + want: -1, + }, + { + name: "ip1 > ip2", + ip1: "192.168.1.2", + ip2: "192.168.1.1", + want: 1, + }, + { + name: "ip1 == ip2", + ip1: "192.168.1.1", + ip2: "192.168.1.1", + want: 0, + }, + { + name: "Different first octet", + ip1: "192.168.1.1", + ip2: "193.168.1.1", + want: -1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ip1 := net.ParseIP(tt.ip1) + ip2 := net.ParseIP(tt.ip2) + got := compareIP(ip1, ip2) + if got != tt.want { + t.Errorf("compareIP() = %d, want %d", got, tt.want) + } + }) + } +} + +func TestDeterministicOrdering(t *testing.T) { + // Test that the same input produces the same output + ips1, err1 := SelectIPs("192.168.1.0/24", 0.25) + ips2, err2 := SelectIPs("192.168.1.0/24", 0.25) + + if err1 != nil || err2 != nil { + t.Fatalf("SelectIPs() errors: %v, %v", err1, err2) + } + + if len(ips1) != len(ips2) { + t.Fatalf("Different lengths: %d vs %d", len(ips1), len(ips2)) + } + + for i := range ips1 { + if !ips1[i].Equal(ips2[i]) { + t.Errorf("Different IPs at index %d: %s vs %s", i, ips1[i], ips2[i]) + } + } +} + +func BenchmarkSelectIPs(b *testing.B) { + cidr := "192.168.1.0/24" + ratio := 0.25 + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := SelectIPs(cidr, ratio) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkSelectIPsWithCount(b *testing.B) { + cidr := "192.168.1.0/24" + count := 50 + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := SelectIPsWithCount(cidr, count) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkCalculatePriority(b *testing.B) { + _, network, _ := net.ParseCIDR("192.168.1.0/24") + ip := net.ParseIP("192.168.1.100") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = CalculatePriority(ip, network) + } +} + +func TestCalculatePriorityNilNetwork(t *testing.T) { + ip := net.ParseIP("192.168.1.1") + priority := 
CalculatePriority(ip, nil) + if priority != PriorityTier6 { + t.Errorf("Expected default priority %d for nil network, got %d", PriorityTier6, priority) + } +} + +func TestCalculatePriorityIPv6(t *testing.T) { + _, network, _ := net.ParseCIDR("2001:db8::/32") + ip := net.ParseIP("2001:db8::1") + priority := CalculatePriority(ip, network) + // IPv6 should get default priority + if priority != PriorityTier6 { + t.Errorf("Expected default priority %d for IPv6, got %d", PriorityTier6, priority) + } +} + +func TestCompareIPIPv6(t *testing.T) { + ip1 := net.ParseIP("2001:db8::1") + ip2 := net.ParseIP("2001:db8::2") + + result := compareIP(ip1, ip2) + if result >= 0 { + t.Errorf("Expected ip1 < ip2 for IPv6, got %d", result) + } +} + +func TestCompareIPMixed(t *testing.T) { + ip1 := net.ParseIP("192.168.1.1") + ip2 := net.ParseIP("2001:db8::1") + + // IPv4 should come before IPv6 + result := compareIP(ip1, ip2) + if result >= 0 { + t.Errorf("Expected IPv4 < IPv6, got %d", result) + } + + // Reverse order + result = compareIP(ip2, ip1) + if result <= 0 { + t.Errorf("Expected IPv6 > IPv4, got %d", result) + } +} + +func TestAdaptPriorityForSubnet(t *testing.T) { + tests := []struct { + name string + cidr string + ip string + wantNot int // Should not be this priority (excluded) + }{ + { + name: "/25 network", + cidr: "192.168.1.0/25", + ip: "192.168.1.1", + wantNot: PriorityTier7, // Should not be excluded + }, + { + name: "/16 network", + cidr: "192.168.0.0/16", + ip: "192.168.1.1", + wantNot: PriorityTier7, // Should not be excluded + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, network, _ := net.ParseCIDR(tt.cidr) + ip := net.ParseIP(tt.ip) + priority := adaptPriorityForSubnet(ip, network) + if priority == tt.wantNot { + t.Errorf("Priority should not be %d for %s in %s", tt.wantNot, tt.ip, tt.cidr) + } + }) + } +} + +func TestFilterUsableIPs(t *testing.T) { + _, network, _ := net.ParseCIDR("192.168.1.0/24") + ips := []string{ + 
"192.168.1.0", // Network - should be filtered + "192.168.1.1", // Usable + "192.168.1.255", // Broadcast - should be filtered + "192.168.1.100", // Usable + "invalid", // Invalid - should be filtered + } + + usable := filterUsableIPs(ips, network) + + if len(usable) != 2 { + t.Errorf("Expected 2 usable IPs, got %d", len(usable)) + } + + // Check that .0 and .255 are not included + for _, ip := range usable { + ip4 := ip.To4() + if ip4 != nil { + if ip4[3] == 0 || ip4[3] == 255 { + t.Errorf("Network/broadcast address should be filtered: %s", ip) + } + } + } +} + +func TestIsNetworkOrBroadcast(t *testing.T) { + _, network, _ := net.ParseCIDR("192.168.1.0/24") + + tests := []struct { + name string + ip string + want bool + }{ + { + name: "Network address", + ip: "192.168.1.0", + want: true, + }, + { + name: "Broadcast address", + ip: "192.168.1.255", + want: true, + }, + { + name: "Regular IP", + ip: "192.168.1.1", + want: false, + }, + { + name: "Another regular IP", + ip: "192.168.1.100", + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ip := net.ParseIP(tt.ip) + got := common.IsNetworkOrBroadcast(ip, network) + if got != tt.want { + t.Errorf("IsNetworkOrBroadcast() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSelectIPsEdgeCases(t *testing.T) { + tests := []struct { + name string + cidr string + ratio float64 + validate func(t *testing.T, ips []net.IP, err error) + }{ + { + name: "Very small ratio", + cidr: "192.168.1.0/24", + ratio: 0.001, // 0.1% + validate: func(t *testing.T, ips []net.IP, err error) { + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + // Should get at least 1 IP + if len(ips) < 1 { + t.Error("Expected at least 1 IP for very small ratio") + } + }, + }, + { + name: "Exact 50%", + cidr: "192.168.1.0/24", + ratio: 0.5, + validate: func(t *testing.T, ips []net.IP, err error) { + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + // Should get approximately 127 IPs 
(254 * 0.5) + if len(ips) < 120 || len(ips) > 130 { + t.Errorf("Expected approximately 127 IPs for 50%% ratio, got %d", len(ips)) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ips, err := SelectIPs(tt.cidr, tt.ratio) + tt.validate(t, ips, err) + }) + } +} diff --git a/pkg/peerdiscovery/prescan/priority.go b/pkg/peerdiscovery/prescan/priority.go new file mode 100644 index 0000000..5f7be4c --- /dev/null +++ b/pkg/peerdiscovery/prescan/priority.go @@ -0,0 +1,48 @@ +package prescan + +import ( + "net" + + "github.com/projectdiscovery/pd-agent/pkg/peerdiscovery/common" +) + +// PrioritizedIP holds an IP and its priority score (0-100) +type PrioritizedIP struct { + IP net.IP + Priority int +} + +// CalculatePriority returns priority score (0-100) for an IP in a network. +// Higher scores mean more likely to be online. IPv6 uses default priority. +func CalculatePriority(ip net.IP, network *net.IPNet) int { + if network == nil { + return PriorityTier6 + } + + if ip.To4() != nil { + return adaptPriorityForSubnet(ip, network) + } + + // IPv6 support TODO + return PriorityTier6 +} + +// filterUsableIPs drops network/broadcast addresses from the list +func filterUsableIPs(ips []string, network *net.IPNet) []net.IP { + var usableIPs []net.IP + + for _, ipStr := range ips { + ip := net.ParseIP(ipStr) + if ip == nil { + continue + } + + if common.IsNetworkOrBroadcast(ip, network) { + continue + } + + usableIPs = append(usableIPs, ip) + } + + return usableIPs +}