From 56568516b963fd2c929516526d5434d045b4517a Mon Sep 17 00:00:00 2001 From: Reza Khosravivala Date: Mon, 2 Feb 2026 00:52:59 +0100 Subject: [PATCH] fix: ip filtering added --- cli/cmd/start.go | 3 + cli/internal/conduit/service.go | 80 +++++++++++++---- cli/internal/config/config.go | 4 +- cli/internal/filter/country_filter.go | 118 ++++++++++++++++++++++++++ 4 files changed, 188 insertions(+), 17 deletions(-) create mode 100644 cli/internal/filter/country_filter.go diff --git a/cli/cmd/start.go b/cli/cmd/start.go index c45d8346..99cc9722 100644 --- a/cli/cmd/start.go +++ b/cli/cmd/start.go @@ -42,6 +42,7 @@ var ( geoEnabled bool metricsAddr string idleRestart string + allowedCountries []string ) var startCmd = &cobra.Command{ @@ -72,6 +73,7 @@ func init() { startCmd.Flags().StringVar(&metricsAddr, "metrics-addr", "", "address for Prometheus metrics endpoint (e.g., :9090 or 127.0.0.1:9090)") startCmd.Flags().StringVarP(&psiphonConfigPath, "psiphon-config", "c", "", "path to Psiphon network config file (JSON)") startCmd.Flags().StringVar(&idleRestart, "idle-restart", "", "restart service after idle duration (e.g., 30m, 1h, 2h)") + startCmd.Flags().StringSliceVar(&allowedCountries, "allowed-countries", nil, "only allow connections from these countries (e.g., IR)") } func runStart(cmd *cobra.Command, args []string) error { @@ -142,6 +144,7 @@ func runStart(cmd *cobra.Command, args []string) error { GeoEnabled: geoEnabled, MetricsAddr: metricsAddr, IdleRestart: idleRestartDuration, + AllowedCountries: allowedCountries, }) if err != nil { return fmt.Errorf("failed to load configuration: %w", err) diff --git a/cli/internal/conduit/service.go b/cli/internal/conduit/service.go index fea57f5c..228287bd 100644 --- a/cli/internal/conduit/service.go +++ b/cli/internal/conduit/service.go @@ -31,6 +31,7 @@ import ( "time" "github.com/Psiphon-Inc/conduit/cli/internal/config" + "github.com/Psiphon-Inc/conduit/cli/internal/filter" "github.com/Psiphon-Inc/conduit/cli/internal/geo" "github.com/Psiphon-Inc/conduit/cli/internal/metrics" "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon" @@ -42,12 +43,13 @@ var ErrIdleRestart = errors.New("idle restart triggered") // Service represents the Conduit inproxy service type Service struct { - config *config.Config - controller *psiphon.Controller - stats *Stats - geoCollector *geo.Collector - metrics *metrics.Metrics - mu sync.RWMutex + config *config.Config + controller *psiphon.Controller + stats *Stats + geoCollector *geo.Collector + countryFilter *filter.CountryFilter + metrics *metrics.Metrics + mu sync.RWMutex } // Stats tracks proxy activity statistics @@ -108,6 +110,20 @@ func (s *Service) Run(ctx context.Context) error { } } + // Initialize country filter if configured + if len(s.config.AllowedCountries) > 0 { + dbPath := s.config.DataDir + "/GeoLite2-Country.mmdb" + if err := geo.EnsureDatabase(dbPath); err != nil { + return fmt.Errorf("failed to ensure GeoIP database for filter: %w", err) + } + cf, err := filter.NewCountryFilter(dbPath, s.config.AllowedCountries) + if err != nil { + return fmt.Errorf("failed to create country filter: %w", err) + } + s.countryFilter = cf + fmt.Printf("[FILTER] Only allowing connections from: %v\n", s.config.AllowedCountries) + } + if s.metrics != nil && s.config.MetricsAddr != "" { if err := s.metrics.StartServer(s.config.MetricsAddr); err != nil { return fmt.Errorf("failed to start metrics server: %w", err) @@ -243,26 +259,58 @@ func (s *Service) createPsiphonConfig() (*psiphon.Config, error) { return nil, fmt.Errorf("failed to commit config: %w", err) } - // Set up geo tracking callback if enabled - if s.geoCollector != nil { + // Set up connection callbacks for filtering and/or geo tracking + if s.countryFilter != nil || s.geoCollector != nil { psiphonConfig.OnInproxyConnectionEstablished = func(local, remote inproxy.ConnectionStats) { if remote.IP == "" { return } - if remote.CandidateType == "relay" { - s.geoCollector.ConnectRelay(remote.IP) - } else { - s.geoCollector.ConnectIP(remote.IP) + + // Check country filter first (if enabled) + if s.countryFilter != nil { + allowed, countryCode, isRelay := s.countryFilter.IsAllowed(remote.IP) + if !allowed { + if s.config.Verbosity >= 1 { + fmt.Printf("[BLOCKED] Connection from %s (%s)\n", remote.IP, countryCode) + } + return + } + if s.config.Verbosity >= 2 { + if isRelay { + fmt.Printf("[ALLOWED] Relay connection from %s\n", remote.IP) + } else { + fmt.Printf("[ALLOWED] Connection from %s (%s)\n", remote.IP, countryCode) + } + } + } + + // Geo tracking (if enabled) + if s.geoCollector != nil { + if remote.CandidateType == "relay" { + s.geoCollector.ConnectRelay(remote.IP) + } else { + s.geoCollector.ConnectIP(remote.IP) + } } } psiphonConfig.OnInproxyConnectionClosed = func(remote *inproxy.ConnectionStats, bw *inproxy.BandwidthStats) { if remote == nil || remote.IP == "" || bw == nil { return } - if remote.CandidateType == "relay" { - s.geoCollector.DisconnectRelay(remote.IP, bw.BytesUp, bw.BytesDown) - } else { - s.geoCollector.DisconnectIP(remote.IP, bw.BytesUp, bw.BytesDown) + // Only track geo for connections that passed the filter + if s.geoCollector != nil { + if s.countryFilter != nil { + // Re-check filter to ensure we only track allowed connections + allowed, _, _ := s.countryFilter.IsAllowed(remote.IP) + if !allowed { + return + } + } + if remote.CandidateType == "relay" { + s.geoCollector.DisconnectRelay(remote.IP, bw.BytesUp, bw.BytesDown) + } else { + s.geoCollector.DisconnectIP(remote.IP, bw.BytesUp, bw.BytesDown) + } } } } diff --git a/cli/internal/config/config.go b/cli/internal/config/config.go index ee140120..64c56d96 100644 --- a/cli/internal/config/config.go +++ b/cli/internal/config/config.go @@ -55,6 +55,7 @@ type Options struct { GeoEnabled bool // Enable geo tracking via tcpdump MetricsAddr string // Address for Prometheus metrics endpoint (empty = disabled) IdleRestart time.Duration + AllowedCountries []string } // Config represents the validated configuration for the Conduit service @@ -71,6 +72,7 @@ type Config struct { GeoEnabled bool // Enable geo tracking via tcpdump MetricsAddr string // Address for Prometheus metrics endpoint (empty = disabled) IdleRestart time.Duration + AllowedCountries []string } // persistedKey represents the key data saved to disk @@ -185,6 +187,7 @@ func LoadOrCreate(opts Options) (*Config, error) { GeoEnabled: opts.GeoEnabled, MetricsAddr: opts.MetricsAddr, IdleRestart: opts.IdleRestart, + AllowedCountries: opts.AllowedCountries, }, nil } @@ -277,5 +280,4 @@ func LoadKey(dataDir string) (*crypto.KeyPair, string, error) { keyPair, err := crypto.ParsePrivateKey(privateKeyBytes) return keyPair, pk.PrivateKeyBase64, err - } diff --git a/cli/internal/filter/country_filter.go b/cli/internal/filter/country_filter.go new file mode 100644 index 00000000..6e00ad15 --- /dev/null +++ b/cli/internal/filter/country_filter.go @@ -0,0 +1,118 @@ +/* + * Copyright (c) 2026, Psiphon Inc. + * All rights reserved. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +// Package filter provides IP filtering based on country +package filter + +import ( + "net" + "sync" + + "github.com/oschwald/geoip2-golang" +) + +// CountryFilter filters connections based on country +type CountryFilter struct { + db *geoip2.Reader + allowedCountries map[string]bool + mu sync.RWMutex + + // Stats + allowedCount int64 + blockedCount int64 + relayCount int64 +} + +// NewCountryFilter creates a new country filter +func NewCountryFilter(dbPath string, allowedCountries []string) (*CountryFilter, error) { + db, err := geoip2.Open(dbPath) + if err != nil { + return nil, err + } + + allowed := make(map[string]bool) + for _, cc := range allowedCountries { + allowed[cc] = true + } + + return &CountryFilter{ + db: db, + allowedCountries: allowed, + }, nil +} + +// IsAllowed checks if an IP is allowed based on country +// Returns: allowed (bool), countryCode (string), isRelay (bool for private IPs) +func (f *CountryFilter) IsAllowed(ipStr string) (bool, string, bool) { + ip := net.ParseIP(ipStr) + if ip == nil { + // Invalid IP, block it + f.mu.Lock() + f.blockedCount++ + f.mu.Unlock() + return false, "", false + } + + // Allow private/loopback IPs (TURN relay connections) + if isPrivateIP(ip) { + f.mu.Lock() + f.relayCount++ + f.mu.Unlock() + return true, "RELAY", true + } + + f.mu.Lock() + defer f.mu.Unlock() + + record, err := f.db.Country(ip) + if err != nil || record.Country.IsoCode == "" { + // Can't determine country, block it + f.blockedCount++ + return false, "UNKNOWN", false + } + + countryCode := record.Country.IsoCode + if f.allowedCountries[countryCode] { + f.allowedCount++ + return true, countryCode, false + } + + f.blockedCount++ + return false, countryCode, false +} + +// GetStats returns the current filter statistics +func (f *CountryFilter) GetStats() (allowed, blocked, relay int64) { + f.mu.RLock() + defer f.mu.RUnlock() + return f.allowedCount, f.blockedCount, f.relayCount +} + +// Close closes the GeoIP database +func (f *CountryFilter) Close() error { + if f.db != nil { + return f.db.Close() + } + return nil +} + +// isPrivateIP checks if an IP is private/internal +func isPrivateIP(ip net.IP) bool { + return ip.IsLoopback() || ip.IsPrivate() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() +}