diff --git a/cmd/bulletin-board/main.go b/cmd/bulletin-board/main.go index c1ca7de..34a9bb4 100755 --- a/cmd/bulletin-board/main.go +++ b/cmd/bulletin-board/main.go @@ -40,12 +40,13 @@ func main() { ctx, cancel := context.WithCancel(context.Background()) - cfg, err := config.NewConfig() - if err != nil { - slog.Error("failed get config", err) + if err := config.InitConfig(); err != nil { + slog.Error("failed to init config", err) os.Exit(1) } + cfg := config.GlobalConfig + host := cfg.BulletinBoard.Host port := cfg.BulletinBoard.Port diff --git a/cmd/clients/main.go b/cmd/clients/main.go index c01aab9..7609d41 100644 --- a/cmd/clients/main.go +++ b/cmd/clients/main.go @@ -44,12 +44,13 @@ func main() { ctx, cancel := context.WithCancel(context.Background()) - cfg, err := config.NewConfig() - if err != nil { - slog.Error("failed get config", err) + if err = config.InitConfig(); err != nil { + slog.Error("failed to init config", err) os.Exit(1) } + cfg := config.GlobalConfig + slog.Info("⚡ init client", "heartbeat_interval", cfg.HeartbeatInterval) // set up logrus diff --git a/cmd/config/config.go b/cmd/config/config.go index c70cd75..8a241fc 100755 --- a/cmd/config/config.go +++ b/cmd/config/config.go @@ -18,7 +18,7 @@ type Node struct { Port int `yaml:"port"` } -type GlobalConfig struct { +type Config struct { ServerLoad int `yaml:"server_load"` HeartbeatInterval int `yaml:"heartbeat_interval"` MinNodes int `yaml:"min_nodes"` @@ -30,16 +30,18 @@ type GlobalConfig struct { Nodes []Node `yaml:"nodes"` } -func NewConfig() (*GlobalConfig, error) { - cfg := &GlobalConfig{} +var GlobalConfig *Config + +func InitConfig() error { + GlobalConfig = &Config{} if dir, err := os.Getwd(); err != nil { - return nil, fmt.Errorf("config.NewConfig(): global config error: %w", err) - } else if err2 := cleanenv.ReadConfig(dir+"/cmd/config/config.yml", cfg); err2 != nil { - return nil, fmt.Errorf("config.NewConfig(): global config error: %w", err2) - } else if err3 := cleanenv.ReadEnv(cfg); err3 != nil { - return nil, fmt.Errorf("config.NewConfig(): global config error: %w", err3) + return fmt.Errorf("config.NewConfig(): global config error: %w", err) + } else if err2 := cleanenv.ReadConfig(dir+"/cmd/config/config.yml", GlobalConfig); err2 != nil { + return fmt.Errorf("config.NewConfig(): global config error: %w", err2) + } else if err3 := cleanenv.ReadEnv(GlobalConfig); err3 != nil { + return fmt.Errorf("config.NewConfig(): global config error: %w", err3) } else { - return cfg, nil + return nil } } diff --git a/cmd/config/config.yml b/cmd/config/config.yml index de83d4f..8d74fad 100755 --- a/cmd/config/config.yml +++ b/cmd/config/config.yml @@ -14,4 +14,13 @@ nodes: port: 8081 - id: 2 host: 'localhost' - port: 8082 \ No newline at end of file + port: 8082 + - id: 3 + host: 'localhost' + port: 8083 + - id: 4 + host: 'localhost' + port: 8084 + - id: 5 + host: 'localhost' + port: 8085 \ No newline at end of file diff --git a/cmd/node/main.go b/cmd/node/main.go index b10cca8..972c591 100755 --- a/cmd/node/main.go +++ b/cmd/node/main.go @@ -52,12 +52,13 @@ func main() { ctx, cancel := context.WithCancel(context.Background()) - cfg, err := config.NewConfig() - if err != nil { - slog.Error("failed get config", err) + if err = config.InitConfig(); err != nil { + slog.Error("failed to init config", err) os.Exit(1) } + cfg := config.GlobalConfig + var nodeConfig *config.Node for _, n := range cfg.Nodes { if n.ID == *id { diff --git a/go.mod b/go.mod index 9c9a49d..4e49e93 100755 --- a/go.mod +++ b/go.mod @@ -1,20 +1,21 @@ module github.com/HannahMarsh/pi_t-experiment -go 1.19 +go 1.21 + +toolchain go1.21.2 require ( - github.com/enriquebris/goconcurrentqueue v0.7.0 + github.com/emirpasic/gods v1.18.1 github.com/google/uuid v1.3.0 github.com/google/wire v0.5.0 github.com/grpc-ecosystem/grpc-gateway/v2 v2.12.0 github.com/ilyakaznacheev/cleanenv v1.3.0 + github.com/jfcg/sorty/v2 v2.1.1 github.com/kyleconroy/sqlc v1.16.0 github.com/lib/pq v1.10.7 - github.com/orcaman/concurrent-map/v2 v2.0.1 github.com/pkg/errors v0.9.1 github.com/rabbitmq/amqp091-go v1.5.0 github.com/sirupsen/logrus v1.9.0 - github.com/thoas/go-funk v0.9.3 go.uber.org/automaxprocs v1.5.1 golang.org/x/exp v0.0.0-20221026153819-32f3d567a233 google.golang.org/genproto v0.0.0-20221014213838-99cd37c6964a @@ -26,6 +27,7 @@ require ( github.com/BurntSushi/toml v1.1.0 // indirect github.com/golang/glog v1.0.0 // indirect github.com/golang/protobuf v1.5.2 // indirect + github.com/jfcg/sixb v1.4.1 // indirect github.com/joho/godotenv v1.4.0 // indirect github.com/kr/text v0.2.0 // indirect github.com/stretchr/testify v1.8.1 // indirect diff --git a/go.sum b/go.sum index d5b1753..43cd0a3 100755 --- a/go.sum +++ b/go.sum @@ -4,8 +4,8 @@ github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ3 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/enriquebris/goconcurrentqueue v0.7.0 h1:JYrDa45N3xo3Sr9mjvlRaWiBHvBEJIhAdLXO3VGVghA= -github.com/enriquebris/goconcurrentqueue v0.7.0/go.mod h1:OZ+KC2BcRYzjg0vgoUs1GFqdAjkD9mz2Ots7Jbm1yS4= +github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc= +github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ= github.com/golang/glog v1.0.0 h1:nfP3RFugxnNRyKgeWd4oI1nYvXpxrx8ck8ZrcizshdQ= github.com/golang/glog v1.0.0/go.mod h1:EWib/APOK0SL3dFbYqvxE3UYd8E6s1ouQ7iEp/0LWV4= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= @@ -14,6 +14,7 @@ github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiu github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/subcommands v1.0.1/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= @@ -23,6 +24,14 @@ github.com/grpc-ecosystem/grpc-gateway/v2 v2.12.0 h1:kr3j8iIMR4ywO/O0rvksXaJvauG github.com/grpc-ecosystem/grpc-gateway/v2 v2.12.0/go.mod h1:ummNFgdgLhhX7aIiy35vVmQNS0rWXknfPE0qe6fmFXg= github.com/ilyakaznacheev/cleanenv v1.3.0 h1:RapuLclPPUbmdd5Bi5UXScwMEZA6+ZNLU5OW9itPjj0= github.com/ilyakaznacheev/cleanenv v1.3.0/go.mod h1:i0owW+HDxeGKE0/JPREJOdSCPIyOnmh6C0xhWAkF/xA= +github.com/jfcg/opt v0.3.1 h1:6zgKvv3fR5OlX2nxUYJC4wtosY30N4vypILgXmRNr34= +github.com/jfcg/opt v0.3.1/go.mod h1:3ZUYQhiqKM6vVjMRYV1fVZ9a91EQ47b5kg7KsnfRClk= +github.com/jfcg/rng v1.0.6 h1:JCYvI/GaSSd3lL0zl15J7FNHqMZYNosPmNKDrysKhH0= +github.com/jfcg/rng v1.0.6/go.mod h1:UqYFfcn9XCugyejaC+8cXQgrWtHQtvlWGiLQXu/7Cnw= +github.com/jfcg/sixb v1.4.1 h1:lb/fWXTn7G+Om2K2/NhZByppgPyQ7mWyQ5XOHu9PAfU= +github.com/jfcg/sixb v1.4.1/go.mod h1:hofNeC6Ua8uwQ7X14L2/byXF1xJyUyX4lk53SmeAopI= +github.com/jfcg/sorty/v2 v2.1.1 h1:jMgkME/JZ4dVFxOVtAeQUXbSzLhbPGxjgfhtnCMXXVM= +github.com/jfcg/sorty/v2 v2.1.1/go.mod h1:wFv8kNl8smeqwsx62BPpgxyjxMY+4ylIug1ARe4nLnI= github.com/joho/godotenv v1.4.0 h1:3l4+N6zfMWnkbPEXKng2o2/MR5mSwTrBih4ZEkkz1lg= github.com/joho/godotenv v1.4.0/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= @@ -36,13 +45,12 @@ github.com/kyleconroy/sqlc v1.16.0 h1:PE5xrrnUiV5T2b97sLWKHgpBPQoPo/N1K/gWU/GFwa github.com/kyleconroy/sqlc v1.16.0/go.mod h1:m+cX/UyBRnKP58lFfUsq+0gw87UUw9AmxwqU/AaQeDA= github.com/lib/pq v1.10.7 h1:p7ZhMD+KsSRozJr34udlUrhboJwWAgCg34+/ZZNvZZw= github.com/lib/pq v1.10.7/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= -github.com/orcaman/concurrent-map/v2 v2.0.1 h1:jOJ5Pg2w1oeB6PeDurIYf6k9PQ+aTITr/6lP/L/zp6c= -github.com/orcaman/concurrent-map/v2 v2.0.1/go.mod h1:9Eq3TG2oBe5FirmYWQfYO5iH1q0Jv47PLaNK++uCdOM= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g= +github.com/prashantv/gostub v1.1.0/go.mod h1:A5zLQHz7ieHGG7is6LLXLz7I8+3LZzsrV0P1IAHhP5U= github.com/rabbitmq/amqp091-go v1.5.0 h1:VouyHPBu1CrKyJVfteGknGOGCzmOz0zcv/tONLkb7rg= github.com/rabbitmq/amqp091-go v1.5.0/go.mod h1:JsV0ofX5f1nwOGafb8L5rBItt9GyhfQfcJj+oyz0dGg= github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0= @@ -50,14 +58,11 @@ github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVs github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= -github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/thoas/go-funk v0.9.3 h1:7+nAEx3kn5ZJcnDm2Bh23N2yOtweO14bi//dvRtgLpw= -github.com/thoas/go-funk v0.9.3/go.mod h1:+IWnUfUmFO1+WVYQWQtIJHeRRdaIyyYglZN7xzUPe4Q= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= go.uber.org/automaxprocs v1.5.1 h1:e1YG66Lrk73dn4qhg8WFSvhF0JuFQF0ERIp4rpuV8Qk= go.uber.org/automaxprocs v1.5.1/go.mod h1:BF4eumQw0P9GtnuxxovUd06vwm1o18oMzFtK66vU6XU= @@ -111,7 +116,6 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/api/nodeApi.go b/internal/api/nodeApi.go index 30cc74d..11e02bd 100644 --- a/internal/api/nodeApi.go +++ b/internal/api/nodeApi.go @@ -5,13 +5,13 @@ import "time" type PublicNodeApi struct { ID int Address string - PublicKey []byte + PublicKey string } type PrivateNodeApi struct { TimeOfRequest time.Time ID int Address string - PublicKey []byte + PublicKey string MessageQueue []int } diff --git a/internal/bulletin_board/NodeView.go b/internal/bulletin_board/NodeView.go index 99705a7..3dbe12d 100644 --- a/internal/bulletin_board/NodeView.go +++ b/internal/bulletin_board/NodeView.go @@ -9,7 +9,7 @@ import ( type NodeView struct { ID int Address string - PublicKey []byte + PublicKey string MessageQueue []int mu sync.RWMutex LastHeartbeat time.Time diff --git a/internal/bulletin_board/bulletin_board.go b/internal/bulletin_board/bulletin_board.go index 8660474..64699f0 100644 --- a/internal/bulletin_board/bulletin_board.go +++ b/internal/bulletin_board/bulletin_board.go @@ -14,11 +14,11 @@ import ( type BulletinBoard struct { Network map[int]*NodeView // Maps node IDs to their queue sizes mu sync.RWMutex - config *config.GlobalConfig + config *config.Config } // NewBulletinBoard creates a new bulletin board -func NewBulletinBoard(config *config.GlobalConfig) *BulletinBoard { +func NewBulletinBoard(config *config.Config) *BulletinBoard { return &BulletinBoard{ Network: make(map[int]*NodeView), config: config, diff --git a/internal/node/node.go b/internal/node/node.go index 62238a4..227ac94 100644 --- a/internal/node/node.go +++ b/internal/node/node.go @@ -2,10 +2,11 @@ package node import ( "bytes" - "context" "encoding/json" + "errors" "fmt" "github.com/HannahMarsh/pi_t-experiment/internal/api" + "github.com/HannahMarsh/pi_t-experiment/internal/pi_t" "github.com/HannahMarsh/pi_t-experiment/pkg/utils" "golang.org/x/exp/slog" "io" @@ -13,45 +14,51 @@ import ( "sync" "time" - "math/rand" + rng "math/rand" ) +var rand = rng.New(rng.NewSource(time.Now().UnixNano())) + // Node represents a node in the onion routing network type Node struct { ID int Host string Port int - PublicKey []byte - PrivateKey []byte + PublicKey string + PrivateKey string ActiveNodes []api.PublicNodeApi - MessageQueue []*api.Message - OnionQueue []*Onion - mu sync.Mutex - NodeInfo api.PublicNodeApi + OnionQueue *utils.SafeHeap[QueuedOnion] + mu sync.RWMutex BulletinBoardUrl string wg sync.WaitGroup requests sync.Map + lastUpdate time.Time +} + +type QueuedOnion struct { + ConstructedOnion string + DestinationAddress string + OriginalMessage api.Message + TimeReceived time.Time +} + +func qoLess(a, b QueuedOnion) bool { + return a.TimeReceived.Before(b.TimeReceived) } // NewNode creates a new node func NewNode(id int, host string, port int, bulletinBoardUrl string) (*Node, error) { - if publicKey, privateKey, err := utils.GenerateKeyPair(); err != nil { + if publicKey, privateKey, err := pi_t.KeyGen(); err != nil { return nil, fmt.Errorf("node.NewNode(): failed to generate key pair: %w", err) } else { n := &Node{ - ID: id, - Host: host, - Port: port, - PublicKey: publicKey, - PrivateKey: privateKey, - ActiveNodes: make([]api.PublicNodeApi, 0), - MessageQueue: make([]*api.Message, 0), - OnionQueue: make([]*Onion, 0), - NodeInfo: api.PublicNodeApi{ - ID: id, - Address: fmt.Sprintf("%s:%d", host, port), - PublicKey: publicKey, - }, + ID: id, + Host: host, + Port: port, + PublicKey: publicKey, + PrivateKey: privateKey, + ActiveNodes: make([]api.PublicNodeApi, 0), + OnionQueue: utils.NewSafeHeap(qoLess), BulletinBoardUrl: bulletinBoardUrl, wg: sync.WaitGroup{}, } @@ -59,38 +66,52 @@ func NewNode(id int, host string, port int, bulletinBoardUrl string) (*Node, err return nil, fmt.Errorf("node.NewNode(): failed to register with bulletin board: %w", err2) } - go n.StartPeriodicUpdates(time.Second * 1) + go n.StartPeriodicUpdates(time.Second * 3) return n, nil } } +func (n *Node) getPublicNodeInfo() api.PublicNodeApi { + return api.PublicNodeApi{ + ID: n.ID, + Address: fmt.Sprintf("http://%s:%d", n.Host, n.Port), + PublicKey: n.PublicKey, + } +} + +func (n *Node) getPrivateNodeInfo(timeOfRequest time.Time) api.PrivateNodeApi { + mq := n.OnionQueue.MapToInt(func(qo QueuedOnion) int { + return qo.OriginalMessage.To + }) + return api.PrivateNodeApi{ + TimeOfRequest: timeOfRequest, + ID: n.ID, + Address: fmt.Sprintf("http://%s:%d", n.Host, n.Port), + PublicKey: n.PublicKey, + MessageQueue: mq, + } +} + func (n *Node) StartPeriodicUpdates(interval time.Duration) { ticker := time.NewTicker(interval) go func() { for range ticker.C { - if err := n.updateBulletinBoard(); err != nil { + if err := n.updateBulletinBoard("/update", http.StatusOK); err != nil { fmt.Printf("Error updating bulletin board: %v\n", err) return + } else if activeNodes, err2 := n.GetActiveNodes(); err2 != nil { + fmt.Printf("Error getting active nodes: %v\n", err2) + return + } else { + n.mu.Lock() + n.ActiveNodes = utils.Copy(activeNodes) + n.mu.Unlock() } - n.ProcessMessageQueue() } }() } -func (n *Node) ProcessMessageQueue() { - n.mu.Lock() - defer n.mu.Unlock() - for _, msg := range n.MessageQueue { - // Create an onion from the message - if onion, err := n.NewOnion(msg, 1); err != nil { - fmt.Printf("Error creating onion: %v\n", err) - } else { - n.OnionQueue = append(n.OnionQueue, onion) - } - } -} - func (n *Node) getNode(id int) *api.PublicNodeApi { for _, node := range n.ActiveNodes { if node.ID == id { @@ -105,52 +126,74 @@ func (n *Node) getRandomNode() *api.PublicNodeApi { return &n.ActiveNodes[r] } -func (n *Node) NewOnion(msg *api.Message, pathLength int) (*Onion, error) { - if msg_string, err := json.Marshal(msg); err != nil { - return nil, fmt.Errorf("NewOnion(): failed to marshal message: %w", err) +func (n *Node) QueueOnion(msg api.Message, pathLength int) error { + timeReceived := time.Now() + if msgString, err := json.Marshal(msg); err != nil { + return fmt.Errorf("NewOnion(): failed to marshal message: %w", err) + } else if to := n.getNode(msg.To); to == nil { + return fmt.Errorf("NewOnion(): failed to get node with id %d", msg.To) + } else if routingPath, err2 := n.DetermineRoutingPath(pathLength); err2 != nil { + return fmt.Errorf("NewOnion(): failed to determine routing path: %w", err2) } else { - if to := n.getNode(msg.To); to != nil { - if o, err2 := NewOnion(fmt.Sprintf("%s/receive", to.Address), msg_string, to.PublicKey); err2 != nil { - return nil, fmt.Errorf("NewOnion(): failed to create onion: %w", err2) - } else { - for i := 0; i < pathLength; i++ { - var intermediary *api.PublicNodeApi - for intermediary.ID == n.ID { - intermediary = n.getRandomNode() - } - if err3 := o.AddLayer(intermediary.Address, intermediary.PublicKey); err3 != nil { - return nil, fmt.Errorf("NewOnion(): failed to add layer: %w", err3) - } - } - return o, nil - } + publicKeys := utils.Map(routingPath, func(node api.PublicNodeApi) string { + return node.PublicKey + }) + addresses := utils.Map(routingPath, func(node api.PublicNodeApi) string { + return node.Address + }) + if addr, onion, err3 := pi_t.FormOnion(msgString, publicKeys, addresses); err3 != nil { + return fmt.Errorf("NewOnion(): failed to create onion: %w", err3) } else { - return nil, fmt.Errorf("NewOnion(): failed to get node with id %d", msg.To) + qo := QueuedOnion{ + ConstructedOnion: onion, + DestinationAddress: addr, + OriginalMessage: msg, + TimeReceived: timeReceived, + } + n.OnionQueue.Push(qo) + return nil } } } +// DetermineRoutingPath determines a random routing path of a given length +func (n *Node) DetermineRoutingPath(pathLength int) ([]api.PublicNodeApi, error) { + if len(n.ActiveNodes) < pathLength { + return nil, errors.New("not enough nodes to form a path") + } + + selectedNodes := make([]api.PublicNodeApi, pathLength) + perm := rand.Perm(len(n.ActiveNodes)) + + for i := 0; i < pathLength; i++ { + selectedNodes[i] = n.ActiveNodes[perm[i]] + } + + return selectedNodes, nil +} + +func (n *Node) IDsMatch(nodeApi api.PublicNodeApi) bool { + return n.ID == nodeApi.ID +} + func (n *Node) startRun(activeNodes []api.PublicNodeApi) (didParticipate bool, e error) { + n.wg.Wait() + n.wg.Add(1) + defer n.wg.Done() + n.mu.Lock() if len(activeNodes) == 0 { n.mu.Unlock() return false, fmt.Errorf("startRun(): no active nodes") } - n.ActiveNodes = activeNodes - onionsToSend := n.OnionQueue - n.OnionQueue = make([]*Onion, 0) + n.ActiveNodes = utils.Copy(activeNodes) + onionsToSend := n.OnionQueue.Drain() n.mu.Unlock() + slog.Info("Starting run with", "num_onions", len(onionsToSend)) - n.wg.Wait() - n.wg.Add(1) - defer n.wg.Done() - var participate bool = false - for _, node := range activeNodes { - if node.ID == n.ID { - participate = true - } - } + participate := utils.Contains(activeNodes, n.IDsMatch) + if participate { for _, onion := range onionsToSend { if err2 := sendToNode(onion); err2 != nil { @@ -162,29 +205,27 @@ func (n *Node) startRun(activeNodes []api.PublicNodeApi) (didParticipate bool, e return false, nil } -func (n *Node) Receive(o *Onion) error { - if err := o.RemoveLayer(n.PrivateKey); err != nil { +func (n *Node) Receive(o string) error { + if destination, payload, err := pi_t.PeelOnion(o, n.PrivateKey); err != nil { return fmt.Errorf("node.Receive(): failed to remove layer: %w", err) - } else if o.HasNextLayer() { - if err2 := sendToNode(o); err2 != nil { - return fmt.Errorf("node.Receive(): failed to send to next node: %w", err2) - } - } else if o.HasMessage() { - // Process the final message here - slog.Info("Received onion with message", "message", o.Message) - context.TODO() } else { - slog.Info("Received dummy onion") - context.TODO() + bruised, err2 := pi_t.BruiseOnion(payload) + if err2 != nil { + return fmt.Errorf("node.Receive(): failed to bruise onion: %w", err2) + } + if err3 := sendToNode(QueuedOnion{ + ConstructedOnion: bruised, + DestinationAddress: destination, + }); err != nil { + return fmt.Errorf("node.Receive(): failed to send to next node: %w", err3) + } } return nil } -func sendToNode(o *Onion) error { - url := fmt.Sprintf("http://%s/receive", o.Address) - if data, err := json.Marshal(o); err != nil { - return fmt.Errorf("sendToNode(): failed to marshal onion: %w", err) - } else if resp, err2 := http.Post(url, "application/json", bytes.NewBuffer(data)); err2 != nil { +func sendToNode(onion QueuedOnion) error { + url := fmt.Sprintf("http://%s/receive", onion.DestinationAddress) + if resp, err2 := http.Post(url, "application/json", bytes.NewBuffer([]byte(onion.ConstructedOnion))); err2 != nil { return fmt.Errorf("sendToNode(): failed to send POST request with onion to next node: %w", err2) } else { defer func(Body io.ReadCloser) { diff --git a/internal/node/node_handler.go b/internal/node/node_handler.go index 6aeaa42..31957c1 100644 --- a/internal/node/node_handler.go +++ b/internal/node/node_handler.go @@ -9,17 +9,18 @@ import ( "golang.org/x/exp/slog" "io" "net/http" + "time" ) func (n *Node) HandleReceive(w http.ResponseWriter, r *http.Request) { slog.Info("Received onion") - var o Onion + var o string if err := json.NewDecoder(r.Body).Decode(&o); err != nil { slog.Error("Error decoding onion", err) http.Error(w, err.Error(), http.StatusBadRequest) return } - if err := n.Receive(&o); err != nil { + if err := n.Receive(o); err != nil { slog.Error("Error receiving onion", err) http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -35,6 +36,7 @@ func (n *Node) HandleStartRun(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusBadRequest) return } + slog.Info("Active nodes", "activeNodes", activeNodes) go func() { if didParticipate, err := n.startRun(activeNodes); err != nil { slog.Error("Error starting run", err) @@ -46,41 +48,27 @@ func (n *Node) HandleStartRun(w http.ResponseWriter, r *http.Request) { } func (n *Node) HandleClientRequest(w http.ResponseWriter, r *http.Request) { - slog.Info("Received client request") + var msgs []api.Message if err := json.NewDecoder(r.Body).Decode(&msgs); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } - slog.Info("Enqueuing messages", "num_messages", len(msgs)) + slog.Info("Received client request", "num_messages", len(msgs), "destinations", utils.Map(msgs, func(m api.Message) int { return m.To })) + //slog.Info("Enqueuing messages", "num_messages", len(msgs)) for _, msg := range msgs { - n.mu.Lock() - n.MessageQueue = append(n.MessageQueue, &msg) - n.mu.Unlock() + if err := n.QueueOnion(msg, 2); err != nil { + slog.Error("Error queuing message", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } } w.WriteHeader(http.StatusOK) } func (n *Node) RegisterWithBulletinBoard() error { - if data, err := json.Marshal(n.NodeInfo); err != nil { - return fmt.Errorf("node.RegisterWithBulletinBoard(): failed to marshal node info: %w", err) - } else { - url := n.BulletinBoardUrl + "/register" - slog.Info("Sending node registration request.", "url", url, "id", n.NodeInfo.ID) - if resp, err2 := http.Post(url, "application/json", bytes.NewBuffer(data)); err2 != nil { - return fmt.Errorf("node.RegisterWithBulletinBoard(): failed to send POST request to bulletin board: %w", err2) - } else { - defer func(Body io.ReadCloser) { - if err3 := Body.Close(); err3 != nil { - fmt.Printf("node.RegisterWithBulletinBoard(): error closing response body: %v\n", err2) - } - }(resp.Body) - if resp.StatusCode != http.StatusCreated { - return fmt.Errorf("node.RegisterWithBulletinBoard(): failed to register node, status code: %d, %s", resp.StatusCode, resp.Status) - } - return nil - } - } + slog.Info("Sending node registration request.", "id", n.ID) + return n.updateBulletinBoard("/register", http.StatusCreated) } func (n *Node) GetActiveNodes() ([]api.PublicNodeApi, error) { @@ -90,8 +78,7 @@ func (n *Node) GetActiveNodes() ([]api.PublicNodeApi, error) { return nil, fmt.Errorf("error making GET request to %s: %v", url, err) } defer func(Body io.ReadCloser) { - err := Body.Close() - if err != nil { + if err2 := Body.Close(); err2 != nil { fmt.Printf("error closing response body: %v\n", err) } }(resp.Body) @@ -108,37 +95,27 @@ func (n *Node) GetActiveNodes() ([]api.PublicNodeApi, error) { return activeNodes, nil } -func (n *Node) updateBulletinBoard() error { +func (n *Node) updateBulletinBoard(endpoint string, expectedStatusCode int) error { n.mu.Lock() - a, _ := n.GetActiveNodes() - if a != nil && len(a) > 0 { - n.ActiveNodes = a - } - m := utils.NewStream(n.MessageQueue).MapToInt(func(msg *api.Message) int { - return msg.To - }).Array - n.mu.Unlock() - nodeInfo := api.PrivateNodeApi{ - ID: n.ID, - Address: n.NodeInfo.Address, - PublicKey: n.PublicKey, - MessageQueue: m, - } - if data, err := json.Marshal(nodeInfo); err != nil { - return fmt.Errorf("node.RegisterWithBulletinBoard(): failed to marshal node info: %w", err) + defer n.mu.Unlock() + t := time.Now() + if data, err := json.Marshal(n.getPrivateNodeInfo(t)); err != nil { + return fmt.Errorf("node.UpdateBulletinBoard(): failed to marshal node info: %w", err) } else { - url := n.BulletinBoardUrl + "/update" - slog.Info("Sending node registration request.", "url", url, "id", n.NodeInfo.ID) + url := n.BulletinBoardUrl + endpoint + //slog.Info("Sending node registration request.", "url", url, "id", n.ID) if resp, err2 := http.Post(url, "application/json", bytes.NewBuffer(data)); err2 != nil { - return fmt.Errorf("node.RegisterWithBulletinBoard(): failed to send POST request to bulletin board: %w", err2) + return fmt.Errorf("node.UpdateBulletinBoard(): failed to send POST request to bulletin board: %w", err2) } else { defer func(Body io.ReadCloser) { if err3 := Body.Close(); err3 != nil { - fmt.Printf("node.RegisterWithBulletinBoard(): error closing response body: %v\n", err2) + fmt.Printf("node.UpdateBulletinBoard(): error closing response body: %v\n", err2) } }(resp.Body) - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("node.RegisterWithBulletinBoard(): failed to register node, status code: %d, %s", resp.StatusCode, resp.Status) + if resp.StatusCode != expectedStatusCode { + return fmt.Errorf("node.RegisterWithBulletinBoard(): failed to %s node, status code: %d, %s", endpoint, resp.StatusCode, resp.Status) + } else { + n.lastUpdate = t } return nil } diff --git a/internal/node/onion.go b/internal/node/onion.go deleted file mode 100644 index ccae067..0000000 --- a/internal/node/onion.go +++ /dev/null @@ -1,124 +0,0 @@ -package node - -import ( - "bytes" - "crypto/rand" - "crypto/rsa" - "crypto/x509" - "encoding/gob" - "encoding/pem" - "errors" - "fmt" -) - -type Onion struct { - Address string - Data []byte - Message []byte -} - -func (o *Onion) IsDummy() bool { - return o.Data == nil && o.Message == nil -} - -func (o *Onion) HasNextLayer() bool { - return o.Data != nil -} - -func (o *Onion) HasMessage() bool { - return o.Message != nil -} - -func (o *Onion) RemoveLayer(privateKey []byte) error { - // Parse private key - if o.HasNextLayer() { - if inner, err := decrypt(o.Data, privateKey); err != nil { - return fmt.Errorf("onion.RemoveLayer(): failed to decrypt data: %w", err) - } else if on, err2 := fromBytes(inner); err2 != nil { - return fmt.Errorf("onion.RemoveLayer(): failed to decrypt data: %w", err2) - } else { - o.Address = on.Address - o.Data = on.Data - return nil - } - } else if o.HasMessage() { - if message, err := decrypt(o.Message, privateKey); err != nil { - return fmt.Errorf("onion.RemoveLayer(): failed to decrypt message: %w", err) - } else { - o.Message = message - return nil - } - } else { - return nil - } -} - -func NewOnion(addr string, msg []byte, publicKey []byte) (*Onion, error) { - if encryptedData, err := encrypt(msg, publicKey); err != nil { - return nil, fmt.Errorf("newOnion(): failed to encrypt message: %w", err) - } else { - return &Onion{ - Address: addr, - Data: nil, - Message: encryptedData, - }, nil - } -} - -func (o *Onion) AddLayer(addr string, publicKey []byte) error { - if b, err := toBytes(o); err != nil { - return fmt.Errorf("onion.AddLayer(): failed to add layer: %w", err) - } else if encryptedData, err2 := encrypt(b, publicKey); err2 != nil { - return fmt.Errorf("onion.AddLayer(): failed to add layer: %w", err2) - } else { - o.Address = addr - o.Data = encryptedData - o.Message = nil - return nil - } -} - -func toBytes(o *Onion) ([]byte, error) { - var buf bytes.Buffer // Stand-in for a buf connection - enc := gob.NewEncoder(&buf) // Will write to buf. - // Encode (send) the value. - if err := enc.Encode(o); err != nil { - return nil, fmt.Errorf("toBytes(): failed to encode onion: %w", err) - } - return buf.Bytes(), nil -} - -func fromBytes(data []byte) (*Onion, error) { - dec := gob.NewDecoder(bytes.NewReader(data)) - var o Onion - if err := dec.Decode(&o); err != nil { - return nil, fmt.Errorf("fromBytes(): failed to decode onion: %w", err) - } - return &o, nil -} - -func decrypt(data []byte, privateKey []byte) ([]byte, error) { - if block, _ := pem.Decode(privateKey); block == nil { - return nil, fmt.Errorf("decrypt(): failed to parse private key: %s", string(privateKey)) - } else if privKey, err := x509.ParsePKCS1PrivateKey(block.Bytes); err != nil { - return nil, fmt.Errorf("decrypt(): failed to parse private key: %w", err) - } else if result, err2 := rsa.DecryptPKCS1v15(rand.Reader, privKey, data); err2 != nil { // Decrypt address and data - return nil, fmt.Errorf("decrypt(): failed to decrypt address: %w", err2) - } else { - return result, nil - } -} - -func encrypt(data []byte, publicKey []byte) ([]byte, error) { - if block, _ := pem.Decode(publicKey); block == nil { - return nil, errors.New("encrypt(): failed to parse public key") - } else if pubKey, err := x509.ParsePKIXPublicKey(block.Bytes); err != nil { - return nil, fmt.Errorf("encrypt(): failed to parse public key: %w", err) - } else if rsaPubKey, ok := pubKey.(*rsa.PublicKey); !ok { - return nil, errors.New("encrypt(): failed to parse RSA public key") - } else if encryptedData, err2 := rsa.EncryptPKCS1v15(rand.Reader, rsaPubKey, data); err2 != nil { - return nil, fmt.Errorf("encrypt(): failed to encrypt address: %w", err2) - } else { - return encryptedData, nil - } -} diff --git a/internal/pi_t/pi_t_functions.go b/internal/pi_t/pi_t_functions.go new file mode 100644 index 0000000..1adcc81 --- /dev/null +++ b/internal/pi_t/pi_t_functions.go @@ -0,0 +1,128 @@ +package pi_t + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/base64" + "encoding/json" + "encoding/pem" + "errors" + "fmt" +) + +// KeyGen generates an RSA key pair and returns the public and private keys in PEM format +func KeyGen() (privateKeyPEM, publicKeyPEM string, err error) { + // Generate RSA key + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return "", "", fmt.Errorf("failed to generate private key: %w", err) + } + + // Encode private key to PEM format + privateKeyPEMBytes := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(privateKey), + }) + privateKeyPEM = string(privateKeyPEMBytes) + + // Generate public key + publicKey := &privateKey.PublicKey + + // Encode public key to PEM format + publicKeyBytes, err := x509.MarshalPKIXPublicKey(publicKey) + if err != nil { + return "", "", fmt.Errorf("failed to marshal public key: %w", err) + } + publicKeyPEMBytes := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PUBLIC KEY", + Bytes: publicKeyBytes, + }) + publicKeyPEM = string(publicKeyPEMBytes) + + return privateKeyPEM, publicKeyPEM, nil +} + +type OnionLayer struct { + NextHop string + Payload string +} + +// FormOnion creates an onion by encapsulating a message in multiple encryption layers +func FormOnion(payload []byte, publicKeys []string, routingPath []string) (string, string, error) { + + for i := len(publicKeys) - 1; i >= 0; i-- { + layer := OnionLayer{ + NextHop: routingPath[i], + Payload: base64.StdEncoding.EncodeToString(payload), + } + + layerBytes, err := json.Marshal(layer) + if err != nil { + return "", "", err + } + + pubKeyBlock, _ := pem.Decode([]byte(publicKeys[i])) + if pubKeyBlock == nil || pubKeyBlock.Type != "RSA PUBLIC KEY" { + return "", "", errors.New("invalid public key PEM block") + } + + pubKey, err := x509.ParsePKIXPublicKey(pubKeyBlock.Bytes) + if err != nil { + return "", "", err + } + + payload, err = rsa.EncryptPKCS1v15(rand.Reader, pubKey.(*rsa.PublicKey), layerBytes) + if err != nil { + return "", "", err + } + } + + return routingPath[0], base64.StdEncoding.EncodeToString(payload), nil +} + +// PeelOnion removes the outermost layer of the onion +func PeelOnion(onion string, privateKeyPEM string) (string, string, error) { + privateKeyBlock, _ := pem.Decode([]byte(privateKeyPEM)) + if privateKeyBlock == nil || privateKeyBlock.Type != "RSA PRIVATE KEY" { + return "", "", errors.New("invalid private key PEM block") + } + + privateKey, err := x509.ParsePKCS1PrivateKey(privateKeyBlock.Bytes) + if err != nil { + return "", "", err + } + + onionBytes, err := base64.StdEncoding.DecodeString(onion) + if err != nil { + return "", "", err + } + + decryptedBytes, err := rsa.DecryptPKCS1v15(rand.Reader, privateKey, onionBytes) + if err != nil { + return "", "", err + } + + var layer OnionLayer + err = json.Unmarshal(decryptedBytes, &layer) + if err != nil { + return "", "", err + } + + return layer.NextHop, layer.Payload, nil +} + +// BruiseOnion modifies the onion payload to introduce bruising +func BruiseOnion(onion string) (string, error) { + onionBytes, err := base64.StdEncoding.DecodeString(onion) + if err != nil { + return "", err + } + + // Introduce bruising by modifying a small portion of the payload + if len(onionBytes) > 0 { + onionBytes[0] ^= 0xFF + } + + return base64.StdEncoding.EncodeToString(onionBytes), nil +} diff --git a/internal/pi_t/pi_t_functions_test.go b/internal/pi_t/pi_t_functions_test.go new file mode 100644 index 0000000..2afba42 --- /dev/null +++ b/internal/pi_t/pi_t_functions_test.go @@ -0,0 +1,104 @@ +package pi_t + +import ( + "encoding/base64" + "encoding/json" + "testing" +) + +func TestKeyGen(t *testing.T) { + privateKeyPEM, publicKeyPEM, err := KeyGen() + if err != nil { + t.Fatalf("KeyGen() error: %v", err) + } + if privateKeyPEM == "" || publicKeyPEM == "" { + t.Fatal("KeyGen() returned empty keys") + } +} + +func TestFormOnion(t *testing.T) { + _, publicKeyPEM, err := KeyGen() + if err != nil { + t.Fatalf("KeyGen() error: %v", err) + } + + payload := []byte("secret message") + publicKeys := []string{publicKeyPEM, publicKeyPEM} + routingPath := []string{"node1", "node2"} + + addr, onion, err := FormOnion(payload, publicKeys, routingPath) + if err != nil { + t.Fatalf("FormOnion() error: %v", err) + } + + if addr != "node1" { + t.Fatalf("FormOnion() expected address 'node1', got %s", addr) + } + + if onion == "" { + t.Fatal("FormOnion() returned empty onion") + } +} + +func TestPeelOnion(t *testing.T) { + privateKeyPEM, publicKeyPEM, err := KeyGen() + if err != nil { + t.Fatalf("KeyGen() error: %v", err) + } + + payload := []byte("secret message") + publicKeys := []string{publicKeyPEM, publicKeyPEM} + routingPath := []string{"node1", "node2"} + + _, onion, err := FormOnion(payload, publicKeys, routingPath) + if err != nil { + t.Fatalf("FormOnion() error: %v", err) + } + + nextHop, peeledPayload, err := PeelOnion(onion, privateKeyPEM) + if err != nil { + t.Fatalf("PeelOnion() error: %v", err) + } + + if nextHop != "node2" { + t.Fatalf("PeelOnion() expected next hop 'node2', got %s", nextHop) + } + + decodedPayload, err := base64.StdEncoding.DecodeString(peeledPayload) + if err != nil { + t.Fatalf("PeelOnion() error decoding payload: %v", err) + } + + var layer OnionLayer + err = json.Unmarshal(decodedPayload, &layer) + if err != nil { + t.Fatalf("PeelOnion() error unmarshaling layer: %v", err) + } + + if layer.Payload != base64.StdEncoding.EncodeToString(payload) { + t.Fatalf("PeelOnion() expected payload %s, got %s", base64.StdEncoding.EncodeToString(payload), layer.Payload) + } +} + +func TestBruiseOnion(t *testing.T) { + payload := []byte("secret message") + onion := base64.StdEncoding.EncodeToString(payload) + + bruisedOnion, err := BruiseOnion(onion) + if err != nil { + t.Fatalf("BruiseOnion() error: %v", err) + } + + if bruisedOnion == onion { + t.Fatal("BruiseOnion() did not modify the onion") + } + + decodedBruisedOnion, err := base64.StdEncoding.DecodeString(bruisedOnion) + if err != nil { + t.Fatalf("BruiseOnion() error decoding bruised onion: %v", err) + } + + if decodedBruisedOnion[0] != payload[0]^0xFF { + t.Fatalf("BruiseOnion() did not correctly modify the onion, got %x", decodedBruisedOnion[0]) + } +} diff --git a/pkg/utils/laplace.go b/pkg/utils/laplace.go new file mode 100644 index 0000000..3c50479 --- /dev/null +++ b/pkg/utils/laplace.go @@ -0,0 +1,30 @@ +package utils + +import ( + "math" + "math/rand" + "time" +) + +// Initialize the random number generator +func init() { + rand.Seed(time.Now().UnixNano()) +} + +// Function to generate Laplace noise +func laplaceMechanism(value, sensitivity, epsilon float64) float64 { + scale := sensitivity / epsilon + u := rand.Float64() - 0.5 + return value - scale*math.Copysign(math.Log(1-2*math.Abs(u)), u) +} + +// LaplaceNoise adds Laplace noise to a given value +func LaplaceNoise(value, epsilon, delta float64) float64 { + b := epsilon / math.Sqrt(2*math.Log(1.25/delta)) + u := rand.Float64() - 0.5 + sign := 1.0 + if u < 0 { + sign = -1.0 + } + return value - b*sign*math.Log(1-2*math.Abs(u)) +} diff --git a/pkg/utils/safeHeap.go b/pkg/utils/safeHeap.go new file mode 100644 index 0000000..a87d5f2 --- /dev/null +++ b/pkg/utils/safeHeap.go @@ -0,0 +1,117 @@ +package utils + +import ( + pq "github.com/emirpasic/gods/queues/priorityqueue" + "sync" +) + +type SafeHeap[T any] struct { + p *pq.Queue + less func(a, b T) bool + mu sync.RWMutex +} + +func NewSafeHeap[T any](less func(a, b T) bool) *SafeHeap[T] { + return &SafeHeap[T]{ + p: pq.NewWith(Comparator(less)), + less: less, + } +} + +func (sh *SafeHeap[T]) Push(value T) { + sh.mu.Lock() + sh.p.Enqueue(value) + sh.mu.Unlock() +} + +func (sh *SafeHeap[T]) Pop() (*T, bool) { + sh.mu.Lock() + defer sh.mu.Unlock() + if sh.p.Empty() { + return nil, false + } + if v, b := sh.p.Dequeue(); b || v == nil { + return nil, false + } else { + return v.(*T), true + } +} + +func (sh *SafeHeap[T]) Size() int { + sh.mu.RLock() + defer sh.mu.RUnlock() + return sh.p.Size() +} + +func (sh *SafeHeap[T]) Drain() []T { + sh.mu.Lock() + defer sh.mu.Unlock() + values := make([]T, sh.p.Size()) + for i := 0; i < len(values); i++ { + if v, b := sh.Pop(); b { + values[i] = *v + } else { + return values[:i] + } + } + return values +} + +func (sh *SafeHeap[T]) Clear() { + sh.mu.Lock() + defer sh.mu.Unlock() + sh.p.Clear() +} + +func (sh *SafeHeap[T]) Values() []T { + sh.mu.RLock() + defer sh.mu.RUnlock() + values := sh.p.Values() + ret := make([]T, len(values)) + for i, v := range values { + ret[i] = v.(T) + } + return ret +} + +func (sh *SafeHeap[T]) Empty() bool { + sh.mu.RLock() + defer sh.mu.RUnlock() + return sh.p.Empty() +} + +func (sh *SafeHeap[T]) String() string { + sh.mu.RLock() + defer sh.mu.RUnlock() + return sh.p.String() +} + +func (sh *SafeHeap[T]) Peek() (*T, bool) { + sh.mu.RLock() + defer sh.mu.RUnlock() + if v, b := sh.p.Peek(); b { + return v.(*T), true + } else { + return nil, false + } +} + +func (sh *SafeHeap[T]) MapToInt(f func(T) int) []int { + values := sh.Values() + ret := make([]int, len(values)) + for i, v := range values { + ret[i] = f(v) + } + return ret +} + +func Comparator[T any](less func(T, T) bool) func(interface{}, interface{}) int { + return func(a, b interface{}) int { + if less(a.(T), b.(T)) { + return -1 + } else if less(b.(T), a.(T)) { + return 1 + } + return 0 + } +} diff --git a/pkg/utils/stream.go b/pkg/utils/stream.go index 3cc2c64..91ba994 100644 --- a/pkg/utils/stream.go +++ b/pkg/utils/stream.go @@ -1,5 +1,14 @@ package utils +import ( + "cmp" + "context" + "github.com/jfcg/sorty/v2" + "runtime" + "sync" + "sync/atomic" +) + type Stream[T any] struct { Array []T } @@ -189,3 +198,206 @@ func Map[T any, O any](items []T, f func(T) O) []O { } return result } + +func Contains[T any](items []T, f func(T) bool) bool { + for _, item := range items { + if f(item) { + return true + } + } + return false +} + +func DoesNotContain[T any](items []T, f func(T) bool) bool { + return !Contains(items, f) +} + +func Find[T any](items []T, f func(T) bool) *T { + for _, item := range items { + if f(item) { + return &item + } + } + return nil +} + +func FindInMap[K comparable, V any](m map[K]V, f func(K, V) bool, defaultKey K, defaultValue V) (K, V, bool) { + for k, v := range m { + if f(k, v) { + return k, v, true + } + } + return defaultKey, defaultValue, false +} + +func FindKey[K comparable, V any](m map[K]V, f func(K, V) bool, defaultValue K) (K, bool) { + for k, v := range m { + if f(k, v) { + return k, true + } + } + return defaultValue, false +} + +func FindValue[K comparable, V any](m map[K]V, f func(K, V) bool, defaultValue V) (V, bool) { + for k, v := range m { + if f(k, v) { + return v, true + } + } + return defaultValue, false +} + +func DoesMapContain[K comparable, V any](m map[K]V, f func(K, V) bool) bool { + for k, v := range m { + if f(k, v) { + return true + } + } + return false +} + +func DoesMapNotContain[K comparable, V any](m map[K]V, f func(K, V) bool) bool { + return !DoesMapContain(m, f) +} + +func FindIndex[T any](items []T, f func(T) bool) int { + for i, item := range items { + if f(item) { + return i + } + } + return -1 +} + +func Copy[T any](items []T) []T { + result := make([]T, len(items)) + copy(result, items) + return result +} + +func CopyMap[K comparable, V any](m map[K]V) map[K]V { + result := make(map[K]V) + for k, v := range m { + result[k] = v + } + return result +} + +func Swap[T any](items []*T, i, j int) { + items[i], items[j] = items[j], items[i] +} + +func Flatten[T any](items [][]T) []T { + var result []T + for _, item := range items { + result = append(result, item...) + } + return result +} + +func FlatMap[T any, O any](items []T, f func(T) []O) []O { + var result []O + for _, item := range items { + result = append(result, f(item)...) + } + return result +} + +func Fold[T any, O any](items []T, initial O, f func(O, T) O) O { + result := initial + for _, item := range items { + result = f(result, item) + } + return result +} + +func Apply[T any](items []T, f func(T)) { + for _, item := range items { + f(item) + } +} + +func Unless[T any](items []T, f func(T) bool) bool { + for _, item := range items { + if !f(item) { + return false + } + } + return true +} + +func ParallelFind[T any](items []T, f func(T) bool) *T { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() // Ensure all paths cancel the context to avoid a context leak + + var found atomic.Value + found.Store((*T)(nil)) // Initialize with nil + + var wg sync.WaitGroup + numProcs := runtime.NumCPU() // Number of logical CPUs + segmentSize := (len(items) + numProcs - 1) / numProcs + + for i := 0; i < len(items); i += segmentSize { + end := i + segmentSize + if end > len(items) { + end = len(items) + } + + wg.Add(1) + go func(segment []T) { + defer wg.Done() + for _, item := range segment { + select { + case <-ctx.Done(): + return // Exit if context is cancelled + default: + if f(item) { + found.Store(&item) + cancel() // Cancel other goroutines + return + } + } + } + }(items[i:end]) + } + + wg.Wait() + result, _ := found.Load().(*T) + return result +} + +func Sort[T any](items []T, less func(T, T) bool) { + // Define the Lesswap function required by sorty + lesswap := func(i, k, r, s int) bool { + if less(items[i], items[k]) { + if r != s { + items[r], items[s] = items[s], items[r] + } + return true + } + return false + } + + // Call sorty.Sort with the length of the items and the lesswap function + sorty.Sort(len(items), lesswap) +} + +func SortOrdered[T cmp.Ordered](items []T) { + Sort(items, func(a, b T) bool { + return a < b + }) +} + +func ParallelContains[T any](items []T, f func(T) bool) bool { + return ParallelFind(items, f) != nil +} + +func FindLast[T any](items []T, f func(T) bool) *T { + for i := len(items) - 1; i >= 0; i-- { + if f(items[i]) { + return &items[i] + } + } + return nil +} diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index d76e59f..3307a2b 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -1,11 +1,6 @@ package utils import ( - "crypto/rand" - "crypto/rsa" - "crypto/x509" - "encoding/pem" - "fmt" "os" ) @@ -17,33 +12,3 @@ func IsRunningInContainer() bool { return true } - -// GenerateKeyPair generates an RSA key pair and returns the public and private keys in PEM format -func GenerateKeyPair() (privateKeyPEM, publicKeyPEM []byte, err error) { - // Generate RSA key - privateKey, err := rsa.GenerateKey(rand.Reader, 2048) - if err != nil { - return nil, nil, fmt.Errorf("failed to generate private key: %w", err) - } - - // Encode private key to PEM format - privateKeyPEM = pem.EncodeToMemory(&pem.Block{ - Type: "RSA PRIVATE KEY", - Bytes: x509.MarshalPKCS1PrivateKey(privateKey), - }) - - // Generate public key - publicKey := &privateKey.PublicKey - - // Encode public key to PEM format - publicKeyBytes, err := x509.MarshalPKIXPublicKey(publicKey) - if err != nil { - return nil, nil, fmt.Errorf("failed to marshal public key: %w", err) - } - publicKeyPEM = pem.EncodeToMemory(&pem.Block{ - Type: "RSA PUBLIC KEY", - Bytes: publicKeyBytes, - }) - - return privateKeyPEM, publicKeyPEM, nil -}