diff --git a/dan_layer/consensus_tests/src/support/messaging_impls.rs b/dan_layer/consensus_tests/src/support/messaging_impls.rs index 275b36982..be0cec58b 100644 --- a/dan_layer/consensus_tests/src/support/messaging_impls.rs +++ b/dan_layer/consensus_tests/src/support/messaging_impls.rs @@ -7,27 +7,32 @@ use tari_consensus::{ traits::{InboundMessaging, InboundMessagingError, OutboundMessaging, OutboundMessagingError}, }; use tari_dan_common_types::ShardGroup; +use tari_epoch_manager::EpochManagerReader; use tokio::sync::mpsc; +use super::epoch_manager::TestEpochManager; use crate::support::TestAddress; #[derive(Debug, Clone)] pub struct TestOutboundMessaging { + epoch_manager: TestEpochManager, tx_leader: mpsc::Sender<(TestAddress, HotstuffMessage)>, - _tx_broadcast: mpsc::Sender<(Vec, HotstuffMessage)>, + tx_broadcast: mpsc::Sender<(Vec, HotstuffMessage)>, loopback_sender: mpsc::Sender, } impl TestOutboundMessaging { pub fn create( + epoch_manager: TestEpochManager, tx_leader: mpsc::Sender<(TestAddress, HotstuffMessage)>, tx_broadcast: mpsc::Sender<(Vec, HotstuffMessage)>, ) -> (Self, mpsc::Receiver) { let (loopback_sender, loopback_receiver) = mpsc::channel(100); ( Self { + epoch_manager, tx_leader, - _tx_broadcast: tx_broadcast, + tx_broadcast, loopback_sender, }, loopback_receiver, @@ -61,12 +66,30 @@ impl OutboundMessaging for TestOutboundMessaging { }) } - async fn multicast<'a, T>(&mut self, _shard_group: ShardGroup, _message: T) -> Result<(), OutboundMessagingError> + async fn multicast<'a, T>(&mut self, shard_group: ShardGroup, message: T) -> Result<(), OutboundMessagingError> where Self::Addr: 'a, T: Into + Send, { - Ok(()) + let epoch = self + .epoch_manager + .current_epoch() + .await + .map_err(|e| OutboundMessagingError::UpstreamError(e.into()))?; + let peers: Vec = self + .epoch_manager + .get_committees_by_shard_group(epoch, shard_group) + .await + .map_err(|e| OutboundMessagingError::UpstreamError(e.into()))? + .values() + .flat_map(|c| c.addresses().cloned()) + .collect(); + + self.tx_broadcast.send((peers, message.into())).await.map_err(|_| { + OutboundMessagingError::FailedToEnqueueMessage { + reason: "broadcast channel closed".to_string(), + } + }) } } diff --git a/dan_layer/consensus_tests/src/support/validator/builder.rs b/dan_layer/consensus_tests/src/support/validator/builder.rs index edb72e708..bf807f46f 100644 --- a/dan_layer/consensus_tests/src/support/validator/builder.rs +++ b/dan_layer/consensus_tests/src/support/validator/builder.rs @@ -109,7 +109,14 @@ impl ValidatorBuilder { let (tx_hs_message, rx_hs_message) = mpsc::channel(100); let (tx_leader, rx_leader) = mpsc::channel(100); - let (outbound_messaging, rx_loopback) = TestOutboundMessaging::create(tx_leader, tx_broadcast); + let epoch_manager = self.epoch_manager.as_ref().unwrap().clone_for( + self.address.clone(), + self.public_key.clone(), + self.shard_address, + ); + + let (outbound_messaging, rx_loopback) = + TestOutboundMessaging::create(epoch_manager.clone(), tx_leader, tx_broadcast); let inbound_messaging = TestInboundMessaging::new(self.address.clone(), rx_hs_message, rx_loopback); let store = SqliteStateStore::connect(&self.sql_url).unwrap(); @@ -117,12 +124,6 @@ impl ValidatorBuilder { let transaction_pool = TransactionPool::new(); let (tx_events, _) = broadcast::channel(100); - let epoch_manager = self.epoch_manager.as_ref().unwrap().clone_for( - self.address.clone(), - self.public_key.clone(), - self.shard_address, - ); - let transaction_executor = TestBlockTransactionProcessor::new(self.transaction_executions.clone()); let worker = HotstuffWorker::::new(