diff --git a/grpc/src/client/load_balancing/child_manager.rs b/grpc/src/client/load_balancing/child_manager.rs index 9501bccb1..7f303543f 100644 --- a/grpc/src/client/load_balancing/child_manager.rs +++ b/grpc/src/client/load_balancing/child_manager.rs @@ -37,7 +37,7 @@ use crate::client::load_balancing::{ ChannelController, LbConfig, LbPolicy, LbPolicyBuilder, LbPolicyOptions, LbState, WeakSubchannel, WorkScheduler, }; -use crate::client::name_resolution::{Address, ResolverUpdate}; +use crate::client::name_resolution::{Address, Endpoint, ResolverUpdate}; use crate::client::ConnectivityState; use crate::rt::Runtime; @@ -50,6 +50,7 @@ pub struct ChildManager { update_sharder: Box>, pending_work: Arc>>, runtime: Arc, + updated: bool, } struct Child { @@ -81,6 +82,47 @@ pub trait ResolverUpdateSharder: Send { ) -> Result>>, Box>; } +/// EndpointSharder shards a resolver update into individual endpoints, +/// with each endpoint serving as the unique identifier for a child. +/// +/// The EndpointSharder implements the ResolverUpdateSharder trait, +/// allowing any load-balancing (LB) policy that uses the ChildManager +/// to split a resolver update into individual endpoints, with one endpoint for each child. +pub struct EndpointSharder { + pub builder: Arc, +} + +// Creates a ChildUpdate for each endpoint received. +impl ResolverUpdateSharder for EndpointSharder { + fn shard_update( + &self, + resolver_update: ResolverUpdate, + ) -> Result>>, Box> { + let update: Vec<_> = resolver_update + .endpoints + .unwrap() + .into_iter() + .map(|e| ChildUpdate { + child_identifier: e.clone(), + child_policy_builder: self.builder.clone(), + child_update: ResolverUpdate { + attributes: resolver_update.attributes.clone(), + endpoints: Ok(vec![e.clone()]), + service_config: resolver_update.service_config.clone(), + resolution_note: resolver_update.resolution_note.clone(), + }, + }) + .collect(); + Ok(Box::new(update.into_iter())) + } +} + +impl EndpointSharder { + pub fn new(builder: Arc) -> Self { + Self { builder } + } +} + impl ChildManager { /// Creates a new ChildManager LB policy. shard_update is called whenever a /// resolver_update operation occurs. @@ -94,6 +136,7 @@ impl ChildManager { children: Default::default(), pending_work: Default::default(), runtime, + updated: false, } } @@ -158,8 +201,34 @@ impl ChildManager { // Update the tracked state if the child produced an update. if let Some(state) = channel_controller.picker_update { self.children[child_idx].state = state; + self.updated = true; }; } + + // Forwards ResolverUpdate to all children. This function avoids resharding + // in case you would like to pass resolver errors down to existing children. + pub(crate) fn forward_update_to_children( + &mut self, + channel_controller: &mut dyn ChannelController, + resolver_update: ResolverUpdate, + config: Option<&LbConfig>, + ) { + for child_idx in 0..self.children.len() { + let child = &mut self.children[child_idx]; + let mut channel_controller = WrappedController::new(channel_controller); + let _ = child.policy.resolver_update( + resolver_update.clone(), + config, + &mut channel_controller, + ); + self.resolve_child_controller(channel_controller, child_idx); + } + } + + /// Checks whether a child has produced an update. 
+ pub fn has_updated(&mut self) -> bool { + mem::take(&mut self.updated) + } } impl LbPolicy for ChildManager { @@ -306,8 +375,21 @@ impl LbPolicy for ChildManager } } - fn exit_idle(&mut self, _channel_controller: &mut dyn ChannelController) { - todo!("implement exit_idle") + fn exit_idle(&mut self, channel_controller: &mut dyn ChannelController) { + let has_idle = self + .children + .iter() + .any(|child| child.state.connectivity_state == ConnectivityState::Idle); + + if !has_idle { + return; + } + for child_idx in 0..self.children.len() { + let child = &mut self.children[child_idx]; + let mut channel_controller = WrappedController::new(channel_controller); + child.policy.exit_idle(&mut channel_controller); + self.resolve_child_controller(channel_controller, child_idx); + } } } @@ -359,61 +441,20 @@ impl WorkScheduler for ChildWorkScheduler { #[cfg(test)] mod test { - use crate::client::load_balancing::child_manager::{ - Child, ChildManager, ChildUpdate, ChildWorkScheduler, ResolverUpdateSharder, - }; + use crate::client::load_balancing::child_manager::{ChildManager, EndpointSharder}; use crate::client::load_balancing::test_utils::{ - self, StubPolicy, StubPolicyFuncs, TestChannelController, TestEvent, TestSubchannel, - TestWorkScheduler, + self, StubPolicyData, StubPolicyFuncs, TestChannelController, TestEvent, }; use crate::client::load_balancing::{ - ChannelController, LbPolicy, LbPolicyBuilder, LbPolicyOptions, LbState, ParsedJsonLbConfig, - Pick, PickResult, Picker, QueuingPicker, Subchannel, SubchannelState, GLOBAL_LB_REGISTRY, + ChannelController, LbPolicy, LbPolicyBuilder, LbState, QueuingPicker, Subchannel, + SubchannelState, GLOBAL_LB_REGISTRY, }; use crate::client::name_resolution::{Address, Endpoint, ResolverUpdate}; - use crate::client::service_config::{LbConfig, ServiceConfig}; use crate::client::ConnectivityState; - use crate::rt::{default_runtime, Runtime}; - use crate::service::Request; - use serde::{Deserialize, Serialize}; - use std::collections::{HashMap, HashSet}; - use std::error::Error; + use crate::rt::default_runtime; use std::panic; use std::sync::Arc; - use std::sync::Mutex; use tokio::sync::mpsc; - use tonic::metadata::MetadataMap; - - // TODO: This needs to be moved to a common place that can be shared between - // round_robin and this test. This EndpointSharder maps endpoints to - // children policies. - struct EndpointSharder { - builder: Arc, - } - - impl ResolverUpdateSharder for EndpointSharder { - fn shard_update( - &self, - resolver_update: ResolverUpdate, - ) -> Result>>, Box> - { - let mut sharded_endpoints = Vec::new(); - for endpoint in resolver_update.endpoints.unwrap().iter() { - let child_update = ChildUpdate { - child_identifier: endpoint.clone(), - child_policy_builder: self.builder.clone(), - child_update: ResolverUpdate { - attributes: resolver_update.attributes.clone(), - endpoints: Ok(vec![endpoint.clone()]), - service_config: resolver_update.service_config.clone(), - resolution_note: resolver_update.resolution_note.clone(), - }, - }; - sharded_endpoints.push(child_update); - } - Ok(Box::new(sharded_endpoints.into_iter())) - } - } // Sets up the test environment. 
// @@ -444,7 +485,7 @@ mod test { let (tx_events, rx_events) = mpsc::unbounded_channel::(); let tcc = Box::new(TestChannelController { tx_events }); let builder: Arc = GLOBAL_LB_REGISTRY.get_policy(test_name).unwrap(); - let endpoint_sharder = EndpointSharder { builder: builder }; + let endpoint_sharder = EndpointSharder::new(builder); let child_manager = ChildManager::new(Box::new(endpoint_sharder), default_runtime()); (rx_events, Box::new(child_manager), tcc) } @@ -517,19 +558,20 @@ mod test { // Defines the functions resolver_update and subchannel_update to test // aggregate_states. fn create_verifying_funcs_for_aggregate_tests() -> StubPolicyFuncs { + let _data = StubPolicyData::default(); StubPolicyFuncs { // Closure for resolver_update. resolver_update should only receive // one endpoint and create one subchannel for the endpoint it // receives. - resolver_update: Some(move |update: ResolverUpdate, _, controller| { + resolver_update: Some(move |_data, update: ResolverUpdate, _, controller| { assert_eq!(update.endpoints.iter().len(), 1); let endpoint = update.endpoints.unwrap().pop().unwrap(); - let subchannel = controller.new_subchannel(&endpoint.addresses[0]); + let _ = controller.new_subchannel(&endpoint.addresses[0]); Ok(()) }), // Closure for subchannel_update. Sends a picker of the same state // that was passed to it. - subchannel_update: Some(move |updated_subchannel, state, controller| { + subchannel_update: Some(move |_data, _updated_subchannel, state, controller| { controller.update_picker(LbState { connectivity_state: state.connectivity_state, picker: Arc::new(QueuingPicker {}), diff --git a/grpc/src/client/load_balancing/mod.rs b/grpc/src/client/load_balancing/mod.rs index 168da55e3..4bfebf22f 100644 --- a/grpc/src/client/load_balancing/mod.rs +++ b/grpc/src/client/load_balancing/mod.rs @@ -54,6 +54,7 @@ use crate::client::{ pub mod child_manager; pub mod pick_first; +pub mod round_robin; #[cfg(test)] pub mod test_utils; diff --git a/grpc/src/client/load_balancing/round_robin.rs b/grpc/src/client/load_balancing/round_robin.rs new file mode 100644 index 000000000..256ab188a --- /dev/null +++ b/grpc/src/client/load_balancing/round_robin.rs @@ -0,0 +1,1352 @@ +/* + * + * Copyright 2025 gRPC authors. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. 
+ * + */ + +use crate::client::load_balancing::child_manager::{ChildManager, EndpointSharder}; +use crate::client::load_balancing::pick_first::{self}; +use crate::client::load_balancing::{ + ChannelController, Failing, LbConfig, LbPolicy, LbPolicyBuilder, LbPolicyOptions, LbState, + PickResult, Picker, Subchannel, SubchannelState, WorkScheduler, GLOBAL_LB_REGISTRY, +}; +use crate::client::name_resolution::{Endpoint, ResolverUpdate}; +use crate::client::ConnectivityState; +use crate::service::Request; +use std::error::Error; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::{Arc, Once}; + +pub static POLICY_NAME: &str = "round_robin"; +static START: Once = Once::new(); + +struct RoundRobinBuilder {} + +impl LbPolicyBuilder for RoundRobinBuilder { + fn build(&self, options: LbPolicyOptions) -> Box { + let resolver_update_sharder = EndpointSharder::new( + GLOBAL_LB_REGISTRY + .get_policy(pick_first::POLICY_NAME) + .unwrap(), + ); + let child_manager = Box::new(ChildManager::new( + Box::new(resolver_update_sharder), + options.runtime, + )); + Box::new(RoundRobinPolicy::new(child_manager, options.work_scheduler)) + } + + fn name(&self) -> &'static str { + POLICY_NAME + } +} + +struct RoundRobinPolicy { + child_manager: Box>, + work_scheduler: Arc, + endpoints_available: bool, // Whether endpoints were available from the previous resolver updates. +} + +impl RoundRobinPolicy { + fn new( + child_manager: Box>, + work_scheduler: Arc, + ) -> Self { + Self { + child_manager, + work_scheduler, + endpoints_available: false, + } + } + + // Sends aggregate picker based on states of children. + // + // If the aggregate state is Idle or Connecting, send pickers of IDLE/CONNECTING + // children. If the aggregate state is Ready, send the pickers of all Ready + // children. If the aggregate state is Transient Failure, send pickers of + // TRANSIENT FAILURE children. + fn send_aggregate_picker(&mut self, channel_controller: &mut dyn ChannelController) { + let state = self.child_manager.aggregate_states(); + let pickers = self + .child_manager + .child_states() + .filter(|&(_, cs)| (cs.connectivity_state == state)) + .map(|(_, cs)| cs.picker.clone()) + .collect(); + let picker_update = LbState { + connectivity_state: state, + picker: Arc::new(RoundRobinPicker::new(pickers)), + }; + channel_controller.update_picker(picker_update); + } + + fn move_to_transient_failure( + &mut self, + channel_controller: &mut dyn ChannelController, + error: String, + ) { + channel_controller.update_picker(LbState { + connectivity_state: ConnectivityState::TransientFailure, + picker: Arc::new(Failing { error }), + }); + channel_controller.request_resolution(); + } + + // Moves children from Idle and then sends a picker based on aggregate + // state. + fn resolve_child_updates(&mut self, channel_controller: &mut dyn ChannelController) { + if !self.child_manager.has_updated() { + return; + } + self.child_manager.exit_idle(channel_controller); + self.send_aggregate_picker(channel_controller); + } +} + +impl LbPolicy for RoundRobinPolicy { + fn resolver_update( + &mut self, + update: ResolverUpdate, + config: Option<&LbConfig>, + channel_controller: &mut dyn ChannelController, + ) -> Result<(), Box> { + match update.clone().endpoints { + Ok(endpoints) => { + if endpoints.is_empty() { + // Call resolver_update to clear children. 
+ let _ = self + .child_manager + .resolver_update(update, config, channel_controller); + self.resolve_child_updates(channel_controller); + self.move_to_transient_failure( + channel_controller, + "received empty address list from the name resolver".into(), + ); + self.endpoints_available = false; + return Err("received empty address list from the name resolver".into()); + } + self.endpoints_available = true; + let _ = self + .child_manager + .resolver_update(update, config, channel_controller); + self.resolve_child_updates(channel_controller); + } + Err(resolver_error) => { + if !self.endpoints_available { + self.move_to_transient_failure(channel_controller, resolver_error.clone()); + return Err(resolver_error.into()); + } else { + // If there are children, forward the error to the children. + self.child_manager.forward_update_to_children( + channel_controller, + update, + config, + ); + self.resolve_child_updates(channel_controller); + } + } + } + Ok(()) + } + + fn subchannel_update( + &mut self, + subchannel: Arc, + state: &SubchannelState, + channel_controller: &mut dyn ChannelController, + ) { + self.child_manager + .subchannel_update(subchannel, state, channel_controller); + self.resolve_child_updates(channel_controller); + } + + fn work(&mut self, channel_controller: &mut dyn ChannelController) { + self.child_manager.work(channel_controller); + self.resolve_child_updates(channel_controller); + } + + fn exit_idle(&mut self, channel_controller: &mut dyn ChannelController) { + self.child_manager.exit_idle(channel_controller); + self.resolve_child_updates(channel_controller); + } +} + +/// Register round robin as a LbPolicy. +pub fn reg() { + START.call_once(|| { + GLOBAL_LB_REGISTRY.add_builder(RoundRobinBuilder {}); + }); +} + +struct RoundRobinPicker { + pickers: Vec>, + next: AtomicUsize, +} + +impl RoundRobinPicker { + fn new(pickers: Vec>) -> Self { + let random_index: usize = rand::random_range(..pickers.len()); + Self { + pickers, + next: AtomicUsize::new(random_index), + } + } +} + +impl Picker for RoundRobinPicker { + fn pick(&self, request: &Request) -> PickResult { + let len = self.pickers.len(); + let idx = self.next.fetch_add(1, Ordering::Relaxed) % len; + self.pickers[idx].pick(request) + } +} + +#[cfg(test)] +mod test { + use crate::client::load_balancing::child_manager::{ChildManager, EndpointSharder}; + use crate::client::load_balancing::round_robin::{self, RoundRobinPolicy}; + use crate::client::load_balancing::test_utils::{ + self, StubPolicyData, StubPolicyFuncs, TestChannelController, TestEvent, TestWorkScheduler, + }; + use crate::client::load_balancing::{ + ChannelController, Failing, LbPolicy, LbPolicyBuilder, LbState, Pick, PickResult, Picker, + QueuingPicker, Subchannel, SubchannelState, GLOBAL_LB_REGISTRY, + }; + use crate::client::name_resolution::{Address, Endpoint, ResolverUpdate}; + use crate::client::ConnectivityState; + use crate::rt::default_runtime; + use crate::service::Request; + use std::collections::HashSet; + use std::panic; + use std::sync::Arc; + use tokio::sync::mpsc; + use tonic::metadata::MetadataMap; + + // Sets up the test environment. + // + // Performs the following: + // 1. Creates a work scheduler. + // 2. Creates a fake channel that acts as a channel controller. + // 3. Creates an StubPolicyBuilder with StubFuncs and the name of the test + // passed in. + // 4. Creates an EndpointSharder with StubPolicyBuilder passed in as the + // child policy. + // 5. Creates a ChildManager with the EndpointSharder. + // 6. 
Create a Round Robin policy with the ChildManager passed in. + // + // Returns the following: + // 1. A receiver for events initiated by the LB policy (like creating a new + // subchannel, sending a new picker etc). + // 2. The Round Robin to send resolver and subchannel updates from the test. + // 3. The controller to pass to the LB policy as part of the updates. + fn setup( + funcs: StubPolicyFuncs, + test_name: &'static str, + ) -> ( + mpsc::UnboundedReceiver, + Box, + Box, + ) { + round_robin::reg(); + test_utils::reg_stub_policy(test_name, funcs); + let resolver_update_sharder = + EndpointSharder::new(GLOBAL_LB_REGISTRY.get_policy(test_name).unwrap()); + + let child_manager = Box::new(ChildManager::new( + Box::new(resolver_update_sharder), + default_runtime(), + )); + let (tx_events, rx_events) = mpsc::unbounded_channel(); + let work_scheduler = Arc::new(TestWorkScheduler { + tx_events: tx_events.clone(), + }); + let tcc = Box::new(TestChannelController { tx_events }); + + let lb_policy = Box::new(RoundRobinPolicy::new(child_manager, work_scheduler)); + (rx_events, lb_policy, tcc) + } + + struct TestSubchannelList { + subchannels: Vec>, + } + + impl TestSubchannelList { + fn new(addresses: &Vec
<Address>, channel_controller: &mut dyn ChannelController) -> Self {
+            let mut scl = TestSubchannelList {
+                subchannels: Vec::new(),
+            };
+            for address in addresses {
+                let sc = channel_controller.new_subchannel(address);
+                scl.subchannels.push(sc);
+            }
+            scl
+        }
+
+        fn contains(&self, sc: &Arc<dyn Subchannel>) -> bool {
+            self.subchannels.contains(sc)
+        }
+    }
+
+    fn create_n_endpoints_with_k_addresses(n: usize, k: usize) -> Vec<Endpoint> {
+        let mut endpoints = Vec::with_capacity(n);
+        for i in 0..n {
+            let mut addresses: Vec<Address>
= Vec::with_capacity(k); + for j in 0..k { + addresses.push(Address { + address: format!("{}.{}.{}.{}:{}", i + 1, i + 1, i + 1, i + 1, j).into(), + ..Default::default() + }); + } + endpoints.push(Endpoint { + addresses, + ..Default::default() + }) + } + endpoints + } + + // Sends a resolver update to the LB policy with the specified endpoint. + fn send_resolver_update_to_policy( + lb_policy: &mut dyn LbPolicy, + endpoints: Vec, + tcc: &mut dyn ChannelController, + ) { + let update = ResolverUpdate { + endpoints: Ok(endpoints), + ..Default::default() + }; + let _ = lb_policy.resolver_update(update, None, tcc); + } + + fn send_resolver_error_to_policy( + lb_policy: &mut dyn LbPolicy, + err: String, + tcc: &mut dyn ChannelController, + ) { + let update = ResolverUpdate { + endpoints: Err(err), + ..Default::default() + }; + let _ = lb_policy.resolver_update(update, None, tcc); + } + + fn move_subchannel_to_state( + lb_policy: &mut dyn LbPolicy, + subchannel: Arc, + tcc: &mut dyn ChannelController, + state: ConnectivityState, + ) { + lb_policy.subchannel_update( + subchannel, + &SubchannelState { + connectivity_state: state, + ..Default::default() + }, + tcc, + ); + } + + fn move_subchannel_to_transient_failure( + lb_policy: &mut dyn LbPolicy, + subchannel: Arc, + err: &str, + tcc: &mut dyn ChannelController, + ) { + lb_policy.subchannel_update( + subchannel, + &SubchannelState { + connectivity_state: ConnectivityState::TransientFailure, + last_connection_error: Some(Arc::from(Box::from(err.to_owned()))), + }, + tcc, + ); + } + + struct TestOneSubchannelPicker { + sc: Arc, + } + + impl Picker for TestOneSubchannelPicker { + fn pick(&self, request: &Request) -> PickResult { + PickResult::Pick(Pick { + subchannel: self.sc.clone(), + on_complete: None, + metadata: MetadataMap::new(), + }) + } + } + + fn address_list_from_endpoints(endpoints: &[Endpoint]) -> Vec
<Address> {
+        let mut addresses: Vec<Address>
= endpoints
+            .iter()
+            .flat_map(|ep| ep.addresses.clone())
+            .collect();
+        let mut uniques = HashSet::new();
+        addresses.retain(|e| uniques.insert(e.clone()));
+        addresses
+    }
+
+    struct PickFirstState {
+        subchannel_list: Option<TestSubchannelList>,
+        selected_subchannel: Option<Arc<dyn Subchannel>>,
+        addresses: Vec<Address>
, + connectivity_state: ConnectivityState, + } + + // TODO: Replace with Pick First child once merged. + // Defines the functions resolver_update and subchannel_update to test round + // robin. Simple version of PickFirst. It just creates a subchannel and then + // sends the appropriate picker update. + fn create_funcs_for_roundrobin_tests() -> StubPolicyFuncs { + StubPolicyFuncs { + // Closure for resolver_update. It creates a subchannel for the + // endpoint it receives and stores which endpoint it received and + // which subchannel this child created in the data field. + resolver_update: Some( + move |data: &mut StubPolicyData, update: ResolverUpdate, _, channel_controller| { + let state = data + .test_data + .get_or_insert_with(|| { + Box::new(PickFirstState { + subchannel_list: None, + selected_subchannel: None, + addresses: vec![], + connectivity_state: ConnectivityState::Connecting, + }) + }) + .downcast_mut::() + .unwrap(); + + match update.endpoints { + Ok(endpoints) => { + let new_addresses = address_list_from_endpoints(&endpoints); + if new_addresses.is_empty() { + channel_controller.update_picker(LbState { + connectivity_state: ConnectivityState::TransientFailure, + picker: Arc::new(Failing { + error: "received empty address list from the name resolver" + .to_string(), + }), + }); + state.connectivity_state = ConnectivityState::TransientFailure; + channel_controller.request_resolution(); + return Err( + "received empty address list from the name resolver".into() + ); + } + + if state.connectivity_state != ConnectivityState::Idle { + state.subchannel_list = Some(TestSubchannelList::new( + &new_addresses, + channel_controller, + )); + } + state.addresses = new_addresses; + } + Err(error) => { + if state.addresses.is_empty() + || state.connectivity_state == ConnectivityState::TransientFailure + { + channel_controller.update_picker(LbState { + connectivity_state: ConnectivityState::TransientFailure, + picker: Arc::new(Failing { + error: error.to_string(), + }), + }); + state.connectivity_state = ConnectivityState::TransientFailure; + channel_controller.request_resolution(); + } + } + } + Ok(()) + }, + ), + // Closure for subchannel_update. Verify that the subchannel that + // being updated now is the same one that this child policy created + // in resolver_update. It then sends a picker of the same state that + // was passed to it. + subchannel_update: Some( + move |data: &mut StubPolicyData, subchannel, state, channel_controller| { + // Retrieve the specific TestState from the generic test_data field. + // This downcasts the `Any` trait object + if let Some(test_data) = data.test_data.as_mut() { + if let Some(test_state) = test_data.downcast_mut::() { + if let Some(scl) = &mut test_state.subchannel_list { + assert!( + scl.contains(&subchannel), + "subchannel_update received an update for a subchannel it does not own." 
+ ); + if scl.contains(&subchannel) { + match state.connectivity_state { + ConnectivityState::Ready => { + channel_controller.update_picker(LbState { + connectivity_state: state.connectivity_state, + picker: Arc::new(TestOneSubchannelPicker { + sc: subchannel, + }), + }); + test_state.connectivity_state = + ConnectivityState::Ready; + } + ConnectivityState::Idle => {} + ConnectivityState::Connecting => { + channel_controller.update_picker(LbState { + connectivity_state: state.connectivity_state, + picker: Arc::new(QueuingPicker {}), + }); + test_state.connectivity_state = + ConnectivityState::Connecting; + } + ConnectivityState::TransientFailure => { + channel_controller.update_picker(LbState { + connectivity_state: state.connectivity_state, + picker: Arc::new(Failing { + error: state + .last_connection_error + .as_ref() + .unwrap() + .to_string(), + }), + }); + test_state.connectivity_state = + ConnectivityState::TransientFailure; + } + } + } + } + } + } + }, + ), + } + } + + // Creates a new endpoint with the specified number of addresses. + fn create_endpoint_with_n_addresses(n: usize) -> Endpoint { + let mut addresses = Vec::new(); + for i in 0..n { + addresses.push(Address { + address: format!("{}.{}.{}.{}:{}", i, i, i, i, i).into(), + ..Default::default() + }); + } + Endpoint { + addresses, + ..Default::default() + } + } + + // Verifies that the expected number of subchannels is created. Returns the + // subchannels created. + async fn verify_subchannel_creation_from_policy( + rx_events: &mut mpsc::UnboundedReceiver, + number_of_subchannels: usize, + ) -> Vec> { + let mut subchannels = Vec::new(); + for _ in 0..number_of_subchannels { + match rx_events.recv().await.unwrap() { + TestEvent::NewSubchannel(sc) => { + subchannels.push(sc); + } + other => panic!("unexpected event {:?}", other), + }; + } + subchannels + } + + // Verifies that the channel moves to CONNECTING state with a queuing picker. + // + // Returns the picker for tests to make more picks, if required. + async fn verify_connecting_picker_from_policy( + rx_events: &mut mpsc::UnboundedReceiver, + ) -> Arc { + println!("verify connecting picker"); + match rx_events.recv().await.unwrap() { + TestEvent::UpdatePicker(update) => { + println!("connectivity state is {}", update.connectivity_state); + assert!(update.connectivity_state == ConnectivityState::Connecting); + let req = test_utils::new_request(); + assert!(update.picker.pick(&req) == PickResult::Queue); + return update.picker.clone(); + } + other => panic!("unexpected event {:?}", other), + } + } + + // Verifies that the channel moves to READY state with a picker that returns + // the given subchannel. + // + // Returns the picker for tests to make more picks, if required. 
+ async fn verify_ready_picker_from_policy( + rx_events: &mut mpsc::UnboundedReceiver, + subchannel: Arc, + ) -> Arc { + println!("verify ready picker"); + match rx_events.recv().await.unwrap() { + TestEvent::UpdatePicker(update) => { + println!( + "connectivity state for ready picker is {}", + update.connectivity_state + ); + assert!(update.connectivity_state == ConnectivityState::Ready); + let req = test_utils::new_request(); + match update.picker.pick(&req) { + PickResult::Pick(pick) => { + println!("selected subchannel is {}", pick.subchannel); + println!("should've been selected subchannel is {}", subchannel); + assert!(pick.subchannel == subchannel.clone()); + update.picker.clone() + } + other => panic!("unexpected pick result {}", other), + } + } + other => panic!("unexpected event {:?}", other), + } + } + + // Returns the picker for when there are multiple pickers in the ready + // picker. + async fn verify_roundrobin_ready_picker_from_policy( + rx_events: &mut mpsc::UnboundedReceiver, + ) -> Arc { + println!("verify ready picker"); + match rx_events.recv().await.unwrap() { + TestEvent::UpdatePicker(update) => { + println!( + "connectivity state for ready picker is {}", + update.connectivity_state + ); + assert!(update.connectivity_state == ConnectivityState::Ready); + let req = test_utils::new_request(); + match update.picker.pick(&req) { + PickResult::Pick(pick) => update.picker.clone(), + other => panic!("unexpected pick result {}", other), + } + } + other => panic!("unexpected event {:?}", other), + } + } + + // Verifies that the channel moves to TRANSIENT_FAILURE state with a picker + // that returns an error with the given message. The error code should be + // UNAVAILABLE.. + // + // Returns the picker for tests to make more picks, if required. + async fn verify_transient_failure_picker_from_policy( + rx_events: &mut mpsc::UnboundedReceiver, + want_error: String, + ) -> Arc { + let picker = match rx_events.recv().await.unwrap() { + TestEvent::UpdatePicker(update) => { + assert!(update.connectivity_state == ConnectivityState::TransientFailure); + let req = test_utils::new_request(); + match update.picker.pick(&req) { + PickResult::Fail(status) => { + assert!(status.code() == tonic::Code::Unavailable); + assert!(status.message().contains(&want_error)); + update.picker.clone() + } + other => panic!("unexpected pick result {}", other), + } + } + other => panic!("unexpected event {:?}", other), + }; + picker + } + + // Verifies that the LB policy requests re-resolution. + async fn verify_resolution_request(rx_events: &mut mpsc::UnboundedReceiver) { + println!("verifying resolution request"); + match rx_events.recv().await.unwrap() { + TestEvent::RequestResolution => {} + other => panic!("unexpected event {:?}", other), + }; + } + + const DEFAULT_TEST_SHORT_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(100); + + async fn verify_no_activity_from_policy(rx_events: &mut mpsc::UnboundedReceiver) { + tokio::select! 
{ + _ = tokio::time::sleep(DEFAULT_TEST_SHORT_TIMEOUT) => {} + event = rx_events.recv() => { + panic!("unexpected event {:?}", event.unwrap()); + } + } + } + + #[test] + fn roundrobin_builder_name() -> Result<(), String> { + round_robin::reg(); + + let builder: Arc = match GLOBAL_LB_REGISTRY.get_policy("round_robin") { + Some(b) => b, + None => { + return Err(String::from("round_robin LB policy not registered")); + } + }; + assert_eq!(builder.name(), "round_robin"); + Ok(()) + } + + // Tests the scenario where the resolver returns an error before a valid + // update. The LB policy should move to TRANSIENT_FAILURE state with a + // failing picker. + #[tokio::test] + async fn roundrobin_resolver_error_before_a_valid_update() { + let (mut rx_events, mut lb_policy, mut tcc) = setup( + create_funcs_for_roundrobin_tests(), + "stub-roundrobin_resolver_error_before_a_valid_update", + ); + let lb_policy = lb_policy.as_mut(); + let tcc = tcc.as_mut(); + let resolver_error = String::from("resolver error"); + send_resolver_error_to_policy(lb_policy, resolver_error.clone(), tcc); + verify_transient_failure_picker_from_policy(&mut rx_events, resolver_error).await; + } + + // Tests the scenario where the resolver returns an error after a valid update + // and the LB policy has moved to READY. The LB policy should ignore the error + // and continue using the previously received update. + #[tokio::test] + async fn roundrobin_resolver_error_after_a_valid_update_in_ready() { + let (mut rx_events, mut lb_policy, mut tcc) = setup( + create_funcs_for_roundrobin_tests(), + "stub-roundrobin_resolver_error_after_a_valid_update_in_ready", + ); + let lb_policy = lb_policy.as_mut(); + let tcc = tcc.as_mut(); + let endpoint = create_endpoint_with_n_addresses(1); + send_resolver_update_to_policy(lb_policy, vec![endpoint], tcc); + let subchannels = verify_subchannel_creation_from_policy(&mut rx_events, 1).await; + + move_subchannel_to_state( + lb_policy, + subchannels[0].clone(), + tcc, + ConnectivityState::Connecting, + ); + verify_connecting_picker_from_policy(&mut rx_events).await; + + move_subchannel_to_state( + lb_policy, + subchannels[0].clone(), + tcc, + ConnectivityState::Ready, + ); + let picker = verify_ready_picker_from_policy(&mut rx_events, subchannels[0].clone()).await; + let resolver_error = String::from("resolver error"); + send_resolver_error_to_policy(lb_policy, resolver_error.clone(), tcc); + verify_no_activity_from_policy(&mut rx_events).await; + + let req = test_utils::new_request(); + match picker.pick(&req) { + PickResult::Pick(pick) => { + assert!(pick.subchannel == subchannels[0].clone()); + } + other => panic!("unexpected pick result {}", other), + } + } + + // Tests the scenario where the resolver returns an error after a valid update + // and the LB policy is still trying to connect. The LB policy should ignore the + // error and continue using the previously received update. 
+ #[tokio::test] + async fn roundrobin_resolver_error_after_a_valid_update_in_connecting() { + let (mut rx_events, mut lb_policy, mut tcc) = setup( + create_funcs_for_roundrobin_tests(), + "stub-roundrobin_resolver_error_after_a_valid_update_in_connecting", + ); + let lb_policy = lb_policy.as_mut(); + let tcc = tcc.as_mut(); + + let endpoint = create_endpoint_with_n_addresses(1); + send_resolver_update_to_policy(lb_policy, vec![endpoint], tcc); + let subchannels = verify_subchannel_creation_from_policy(&mut rx_events, 1).await; + + move_subchannel_to_state( + lb_policy, + subchannels[0].clone(), + tcc, + ConnectivityState::Connecting, + ); + let picker = verify_connecting_picker_from_policy(&mut rx_events).await; + + let resolver_error = String::from("resolver error"); + + send_resolver_error_to_policy(lb_policy, resolver_error.clone(), tcc); + + verify_no_activity_from_policy(&mut rx_events).await; + + let req = test_utils::new_request(); + match picker.pick(&req) { + PickResult::Queue => {} + other => panic!("unexpected pick result {}", other), + } + } + + // Tests the scenario where the resolver returns an error after a valid update + // and the LB policy has moved to TRANSIENT_FAILURE after attempting to connect + // to all addresses. The LB policy should send a new picker that returns the + // error from the resolver. + #[tokio::test] + async fn roundrobin_resolver_error_after_a_valid_update_in_tf() { + let (mut rx_events, mut lb_policy, mut tcc) = setup( + create_funcs_for_roundrobin_tests(), + "stub-roundrobin_resolver_error_after_a_valid_update_in_tf", + ); + let lb_policy = lb_policy.as_mut(); + let tcc = tcc.as_mut(); + let endpoint = create_endpoint_with_n_addresses(1); + send_resolver_update_to_policy(lb_policy, vec![endpoint], tcc); + let subchannels = verify_subchannel_creation_from_policy(&mut rx_events, 1).await; + move_subchannel_to_state( + lb_policy, + subchannels[0].clone(), + tcc, + ConnectivityState::Connecting, + ); + verify_connecting_picker_from_policy(&mut rx_events).await; + let connection_error = String::from("test connection error"); + move_subchannel_to_transient_failure( + lb_policy, + subchannels[0].clone(), + &connection_error, + tcc, + ); + verify_transient_failure_picker_from_policy(&mut rx_events, connection_error).await; + let resolver_error = String::from("resolver error"); + send_resolver_error_to_policy(lb_policy, resolver_error.clone(), tcc); + verify_resolution_request(&mut rx_events).await; + verify_transient_failure_picker_from_policy(&mut rx_events, resolver_error).await; + } + + // Round Robin should round robin across endpoints. 
+ #[tokio::test] + async fn roundrobin_picks_are_round_robin() { + let (mut rx_events, mut lb_policy, mut tcc) = setup( + create_funcs_for_roundrobin_tests(), + "stub-roundrobin_picks_are_round_robin", + ); + let lb_policy = lb_policy.as_mut(); + let tcc = tcc.as_mut(); + let endpoints = create_n_endpoints_with_k_addresses(2, 1); + send_resolver_update_to_policy(lb_policy, endpoints, tcc); + let subchannels = verify_subchannel_creation_from_policy(&mut rx_events, 2).await; + move_subchannel_to_state( + lb_policy, + subchannels[0].clone(), + tcc, + ConnectivityState::Connecting, + ); + verify_connecting_picker_from_policy(&mut rx_events).await; + move_subchannel_to_state( + lb_policy, + subchannels[0].clone(), + tcc, + ConnectivityState::Ready, + ); + verify_ready_picker_from_policy(&mut rx_events, subchannels[0].clone()).await; + move_subchannel_to_state( + lb_policy, + subchannels[1].clone(), + tcc, + ConnectivityState::Ready, + ); + let picker = verify_roundrobin_ready_picker_from_policy(&mut rx_events).await; + let req = test_utils::new_request(); + let mut picked = Vec::new(); + for _ in 0..4 { + match picker.pick(&req) { + PickResult::Pick(pick) => { + println!("picked subchannel is {}", pick.subchannel); + picked.push(pick.subchannel.clone()) + } + other => panic!("unexpected pick result {}", other), + } + } + assert!( + picked[0] != picked[1].clone(), + "Should alternate between subchannels" + ); + assert_eq!(&picked[0], &picked[2]); + assert_eq!(&picked[1], &picked[3]); + assert!(picked.contains(&subchannels[0])); + assert!(picked.contains(&subchannels[1])); + } + + // If round robin receives no endpoints in a resolver update, + // it should go into transient failure. + #[tokio::test] + async fn roundrobin_endpoints_removed() { + let (mut rx_events, mut lb_policy, mut tcc) = setup( + create_funcs_for_roundrobin_tests(), + "stub-roundrobin_addresses_removed", + ); + let lb_policy = lb_policy.as_mut(); + let tcc = tcc.as_mut(); + + let endpoints = create_n_endpoints_with_k_addresses(2, 1); + send_resolver_update_to_policy(lb_policy, endpoints, tcc); + let subchannels = verify_subchannel_creation_from_policy(&mut rx_events, 2).await; + move_subchannel_to_state( + lb_policy, + subchannels[0].clone(), + tcc, + ConnectivityState::Connecting, + ); + verify_connecting_picker_from_policy(&mut rx_events).await; + move_subchannel_to_state( + lb_policy, + subchannels[0].clone(), + tcc, + ConnectivityState::Connecting, + ); + verify_connecting_picker_from_policy(&mut rx_events).await; + let update = ResolverUpdate { + endpoints: Ok(vec![]), + ..Default::default() + }; + let _ = lb_policy.resolver_update(update, None, tcc); + let want_error = "received empty address list from the name resolver"; + verify_transient_failure_picker_from_policy(&mut rx_events, want_error.to_string()).await; + verify_resolution_request(&mut rx_events).await; + } + + // Round robin should only round robin across children that are ready. + // If a child leaves the ready state, Round Robin should only + // pick from the children that are still Ready. 
+ #[tokio::test] + async fn roundrobin_one_endpoint_down() { + let (mut rx_events, mut lb_policy, mut tcc) = setup( + create_funcs_for_roundrobin_tests(), + "stub-roundrobin_one_endpoint_down", + ); + let lb_policy = lb_policy.as_mut(); + let tcc = tcc.as_mut(); + let endpoints = create_n_endpoints_with_k_addresses(2, 1); + send_resolver_update_to_policy(lb_policy, endpoints, tcc); + let subchannels = verify_subchannel_creation_from_policy(&mut rx_events, 2).await; + move_subchannel_to_state( + lb_policy, + subchannels[0].clone(), + tcc, + ConnectivityState::Connecting, + ); + verify_connecting_picker_from_policy(&mut rx_events).await; + move_subchannel_to_state( + lb_policy, + subchannels[0].clone(), + tcc, + ConnectivityState::Ready, + ); + let picker = verify_ready_picker_from_policy(&mut rx_events, subchannels[0].clone()).await; + move_subchannel_to_state( + lb_policy, + subchannels[1].clone(), + tcc, + ConnectivityState::Ready, + ); + let picker = verify_roundrobin_ready_picker_from_policy(&mut rx_events).await; + let req = test_utils::new_request(); + let mut picked = Vec::new(); + for _ in 0..4 { + match picker.pick(&req) { + PickResult::Pick(pick) => { + println!("picked subchannel is {}", pick.subchannel); + picked.push(pick.subchannel.clone()) + } + other => panic!("unexpected pick result {}", other), + } + } + assert!( + picked[0] != picked[1].clone(), + "Should alternate between subchannels" + ); + assert_eq!(&picked[0], &picked[2]); + assert_eq!(&picked[1], &picked[3]); + + assert!(picked.contains(&subchannels[0])); + assert!(picked.contains(&subchannels[1])); + let subchannel_being_removed = subchannels[1].clone(); + let error = "endpoint down"; + move_subchannel_to_transient_failure(lb_policy, subchannels[1].clone(), error, tcc); + + let new_picker = verify_roundrobin_ready_picker_from_policy(&mut rx_events).await; + + let req = test_utils::new_request(); + let mut picked = Vec::new(); + for _ in 0..4 { + match new_picker.pick(&req) { + PickResult::Pick(pick) => { + println!("picked subchannel is {}", pick.subchannel); + picked.push(pick.subchannel.clone()) + } + other => panic!("unexpected pick result {}", other), + } + } + + assert_eq!(&picked[0], &picked[2]); + assert_eq!(&picked[1], &picked[3]); + assert!(picked.contains(&subchannels[0])); + assert!(!picked.contains(&subchannel_being_removed)); + } + + // If Round Robin receives a resolver update that removes an endpoint and + // adds a new endpoint from a previous update, that endpoint's subchannels + // should not be apart of its picks anymore and should be removed. It should + // then roundrobin across the endpoints it still has and the new one. 
+ #[tokio::test] + async fn roundrobin_pick_after_resolved_updated_hosts() { + let (mut rx_events, mut lb_policy, mut tcc) = setup( + create_funcs_for_roundrobin_tests(), + "stub-roundrobin_pick_after_resolved_updated_hosts", + ); + let lb_policy = lb_policy.as_mut(); + let tcc = tcc.as_mut(); + + // Two initial endpoints: subchannel_one, subchannel_two + let addr_one = Address { + address: "subchannel_one".to_string().into(), // <-- fixed spelling + ..Default::default() + }; + let addr_two = Address { + address: "subchannel_two".to_string().into(), + ..Default::default() + }; + let endpoint_one = Endpoint { + addresses: vec![addr_one], + ..Default::default() + }; + let endpoint_two = Endpoint { + addresses: vec![addr_two], + ..Default::default() + }; + + send_resolver_update_to_policy(lb_policy, vec![endpoint_one, endpoint_two.clone()], tcc); + + // Start with two subchannels created + let all_subchannels = verify_subchannel_creation_from_policy(&mut rx_events, 2).await; + let subchannel_one = all_subchannels + .iter() + .find(|sc| sc.address().address == "subchannel_one".to_string().into()) + .unwrap(); + let subchannel_two = all_subchannels + .iter() + .find(|sc| sc.address().address == "subchannel_two".to_string().into()) + .unwrap(); + + move_subchannel_to_state( + lb_policy, + subchannel_one.clone(), + tcc, + ConnectivityState::Connecting, + ); + verify_connecting_picker_from_policy(&mut rx_events).await; + move_subchannel_to_state( + lb_policy, + subchannel_two.clone(), + tcc, + ConnectivityState::Connecting, + ); + verify_connecting_picker_from_policy(&mut rx_events).await; + + move_subchannel_to_state( + lb_policy, + subchannel_one.clone(), + tcc, + ConnectivityState::Ready, + ); + verify_ready_picker_from_policy(&mut rx_events, subchannel_one.clone()).await; + move_subchannel_to_state( + lb_policy, + subchannel_two.clone(), + tcc, + ConnectivityState::Ready, + ); + let picker = verify_roundrobin_ready_picker_from_policy(&mut rx_events).await; + + let req = test_utils::new_request(); + let mut picked = Vec::new(); + for _ in 0..4 { + match picker.pick(&req) { + PickResult::Pick(pick) => picked.push(pick.subchannel.clone()), + other => panic!("unexpected pick result {}", other), + } + } + assert!(picked.contains(&subchannel_one)); + assert!(picked.contains(&subchannel_two)); + + // Resolver update removes subchannel_one and adds "new" + let new_addr = Address { + address: "new".to_string().into(), + ..Default::default() + }; + let new_endpoint = Endpoint { + addresses: vec![new_addr], + ..Default::default() + }; + + send_resolver_update_to_policy(lb_policy, vec![endpoint_two, new_endpoint], tcc); + + let new_subchannels = verify_subchannel_creation_from_policy(&mut rx_events, 2).await; + let new_sc = new_subchannels + .iter() + .find(|sc| sc.address().address == "new".to_string().into()) + .unwrap(); + let old_sc = new_subchannels + .iter() + .find(|sc| sc.address().address == "subchannel_two".to_string().into()) // <-- fixed lookup + .unwrap(); + + move_subchannel_to_state(lb_policy, old_sc.clone(), tcc, ConnectivityState::Ready); + let _ = verify_roundrobin_ready_picker_from_policy(&mut rx_events).await; + + move_subchannel_to_state( + lb_policy, + new_sc.clone(), + tcc, + ConnectivityState::Connecting, + ); + let _ = verify_roundrobin_ready_picker_from_policy(&mut rx_events).await; + move_subchannel_to_state(lb_policy, new_sc.clone(), tcc, ConnectivityState::Ready); + let new_picker = verify_roundrobin_ready_picker_from_policy(&mut rx_events).await; + + let req = 
test_utils::new_request(); + let mut picked = Vec::new(); + for _ in 0..4 { + match new_picker.pick(&req) { + PickResult::Pick(pick) => picked.push(pick.subchannel.clone()), + other => panic!("unexpected pick result {}", other), + } + } + assert!(picked.contains(&old_sc)); + assert!(picked.contains(&new_sc)); + assert!(!picked.contains(&subchannel_one)); + } + + // Round robin should stay in transient failure until a child reports ready + #[tokio::test] + async fn roundrobin_stay_transient_failure_until_ready() { + let (mut rx_events, mut lb_policy, mut tcc) = setup( + create_funcs_for_roundrobin_tests(), + "stub-roundrobin_stay_transient_failure_until_ready", + ); + let lb_policy = lb_policy.as_mut(); + let tcc = tcc.as_mut(); + let endpoints = create_n_endpoints_with_k_addresses(2, 1); + send_resolver_update_to_policy(lb_policy, endpoints, tcc); + let subchannels = verify_subchannel_creation_from_policy(&mut rx_events, 2).await; + move_subchannel_to_state( + lb_policy, + subchannels[0].clone(), + tcc, + ConnectivityState::Connecting, + ); + verify_connecting_picker_from_policy(&mut rx_events).await; + move_subchannel_to_state( + lb_policy, + subchannels[1].clone(), + tcc, + ConnectivityState::Connecting, + ); + verify_connecting_picker_from_policy(&mut rx_events).await; + let first_error = String::from("test connection error 1"); + move_subchannel_to_transient_failure(lb_policy, subchannels[0].clone(), &first_error, tcc); + verify_connecting_picker_from_policy(&mut rx_events).await; + move_subchannel_to_transient_failure(lb_policy, subchannels[1].clone(), &first_error, tcc); + verify_transient_failure_picker_from_policy(&mut rx_events, first_error).await; + move_subchannel_to_state( + lb_policy, + subchannels[0].clone(), + tcc, + ConnectivityState::Ready, + ); + verify_ready_picker_from_policy(&mut rx_events, subchannels[0].clone()).await; + } + + // Tests the scenario where the resolver returns an update with no endpoints + // (before sending any valid update). The LB policy should move to + // TRANSIENT_FAILURE state with a failing picker. + #[tokio::test] + async fn roundrobin_zero_endpoints_from_resolver_before_valid_update() { + let (mut rx_events, mut lb_policy, mut tcc) = setup( + create_funcs_for_roundrobin_tests(), + "stub-roundrobin_zero_endpoints_from_resolver_before_valid_update", + ); + let lb_policy = lb_policy.as_mut(); + let tcc = tcc.as_mut(); + send_resolver_update_to_policy(lb_policy, vec![], tcc); + verify_transient_failure_picker_from_policy( + &mut rx_events, + "received empty address list from the name resolver".to_string(), + ) + .await; + } + + // Tests the scenario where the resolver returns an update with no endpoints + // after sending a valid update (and the LB policy has moved to READY). The LB + // policy should move to TRANSIENT_FAILURE state with a failing picker. 
+ #[tokio::test] + async fn roundrobin_zero_endpoints_from_resolver_after_valid_update() { + let (mut rx_events, mut lb_policy, mut tcc) = setup( + create_funcs_for_roundrobin_tests(), + "stub-roundrobin_zero_endpoints_from_resolver_after_valid_update", + ); + let lb_policy = lb_policy.as_mut(); + let tcc = tcc.as_mut(); + + let endpoint = create_endpoint_with_n_addresses(1); + send_resolver_update_to_policy(lb_policy, vec![endpoint], tcc); + let subchannels = verify_subchannel_creation_from_policy(&mut rx_events, 1).await; + move_subchannel_to_state( + lb_policy, + subchannels[0].clone(), + tcc, + ConnectivityState::Connecting, + ); + verify_connecting_picker_from_policy(&mut rx_events).await; + move_subchannel_to_state( + lb_policy, + subchannels[0].clone(), + tcc, + ConnectivityState::Ready, + ); + verify_ready_picker_from_policy(&mut rx_events, subchannels[0].clone()).await; + let update = ResolverUpdate { + endpoints: Ok(vec![]), + ..Default::default() + }; + assert!(lb_policy.resolver_update(update, None, tcc).is_err()); + verify_transient_failure_picker_from_policy( + &mut rx_events, + "received empty address list from the name resolver".to_string(), + ) + .await; + verify_resolution_request(&mut rx_events).await; + } + + // Tests the scenario where the resolver returns an update with multiple + // address. The LB policy should create subchannels for all address, and attempt + // to connect to them in order, until a connection succeeds, at which point it + // should move to READY state with a picker that returns that subchannel. + #[tokio::test] + async fn roundrobin_with_multiple_backends_first_backend_is_ready() { + let (mut rx_events, mut lb_policy, mut tcc) = setup( + create_funcs_for_roundrobin_tests(), + "stub-roundrobin_with_multiple_backends_first_backend_is_ready", + ); + let lb_policy = lb_policy.as_mut(); + let tcc = tcc.as_mut(); + + let endpoint = create_n_endpoints_with_k_addresses(2, 1); + send_resolver_update_to_policy(lb_policy, endpoint, tcc); + let subchannels = verify_subchannel_creation_from_policy(&mut rx_events, 2).await; + + move_subchannel_to_state( + lb_policy, + subchannels[0].clone(), + tcc, + ConnectivityState::Connecting, + ); + verify_connecting_picker_from_policy(&mut rx_events).await; + move_subchannel_to_state( + lb_policy, + subchannels[0].clone(), + tcc, + ConnectivityState::Ready, + ); + + let picker = verify_ready_picker_from_policy(&mut rx_events, subchannels[0].clone()).await; + + let req = test_utils::new_request(); + // First pick determines the only subchannel the picker should yield + let first_sc = match picker.pick(&req) { + PickResult::Pick(p) => p.subchannel.clone(), + other => panic!("unexpected pick result {}", other), + }; + + for _ in 0..7 { + match picker.pick(&req) { + PickResult::Pick(p) => { + assert!( + Arc::ptr_eq(&first_sc, &p.subchannel), + "READY picker should contain exactly one subchannel" + ); + } + other => panic!("unexpected pick result {}", other), + } + } + } + + // Tests the scenario where the resolver returns an update with multiple + // addresses and the LB policy successfully connects to first one and moves to + // READY. The resolver then returns an update with a new address list that + // contains the address of the currently connected subchannel. The LB policy + // should create subchannels for the new addresses, and then see that the + // currently connected subchannel is in the new address list. It should then + // send a new READY picker that returns the currently connected subchannel. 
+ #[tokio::test] + async fn roundrobin_resolver_update_contains_currently_ready_subchannel() { + let (mut rx_events, mut lb_policy, mut tcc) = setup( + create_funcs_for_roundrobin_tests(), + "stub-roundrobin_resolver_update_contains_currently_ready_subchannel", + ); + let lb_policy = lb_policy.as_mut(); + let tcc = tcc.as_mut(); + + let endpoints = create_endpoint_with_n_addresses(2); + send_resolver_update_to_policy(lb_policy, vec![endpoints], tcc); + let subchannels = verify_subchannel_creation_from_policy(&mut rx_events, 2).await; + move_subchannel_to_state( + lb_policy, + subchannels[0].clone(), + tcc, + ConnectivityState::Connecting, + ); + verify_connecting_picker_from_policy(&mut rx_events).await; + move_subchannel_to_state( + lb_policy, + subchannels[0].clone(), + tcc, + ConnectivityState::Ready, + ); + verify_ready_picker_from_policy(&mut rx_events, subchannels[0].clone()).await; + + let mut endpoints = create_endpoint_with_n_addresses(4); + endpoints.addresses.reverse(); + send_resolver_update_to_policy(lb_policy, vec![endpoints], tcc); + let subchannels = verify_subchannel_creation_from_policy(&mut rx_events, 4).await; + // TODO(easwars): Once Pick First gets merged, this won't send a + // CONNECTING picker. + lb_policy.subchannel_update(subchannels[0].clone(), &SubchannelState::default(), tcc); + lb_policy.subchannel_update(subchannels[1].clone(), &SubchannelState::default(), tcc); + lb_policy.subchannel_update(subchannels[2].clone(), &SubchannelState::default(), tcc); + lb_policy.subchannel_update( + subchannels[3].clone(), + &SubchannelState { + connectivity_state: ConnectivityState::Ready, + ..Default::default() + }, + tcc, + ); + verify_ready_picker_from_policy(&mut rx_events, subchannels[3].clone()).await; + } +} diff --git a/grpc/src/client/load_balancing/test_utils.rs b/grpc/src/client/load_balancing/test_utils.rs index a7fb623fb..2aacfe4f4 100644 --- a/grpc/src/client/load_balancing/test_utils.rs +++ b/grpc/src/client/load_balancing/test_utils.rs @@ -159,13 +159,15 @@ impl WorkScheduler for TestWorkScheduler { // The callback to invoke when resolver_update is invoked on the stub policy. type ResolverUpdateFn = fn( + &mut StubPolicyData, ResolverUpdate, Option<&LbConfig>, &mut dyn ChannelController, ) -> Result<(), Box>; // The callback to invoke when subchannel_update is invoked on the stub policy. -type SubchannelUpdateFn = fn(Arc, &SubchannelState, &mut dyn ChannelController); +type SubchannelUpdateFn = + fn(&mut StubPolicyData, Arc, &SubchannelState, &mut dyn ChannelController); /// This struct holds `LbPolicy` trait stub functions that tests are expected to /// implement. @@ -175,9 +177,16 @@ pub struct StubPolicyFuncs { pub subchannel_update: Option, } +#[derive(Default)] +/// Data holds test data that will be passed all to functions in PolicyFuncs +pub struct StubPolicyData { + pub test_data: Option>, +} + /// The stub `LbPolicy` that calls the provided functions. 
pub struct StubPolicy { funcs: StubPolicyFuncs, + data: StubPolicyData, } impl LbPolicy for StubPolicy { @@ -188,7 +197,7 @@ impl LbPolicy for StubPolicy { channel_controller: &mut dyn ChannelController, ) -> Result<(), Box> { if let Some(f) = &self.funcs.resolver_update { - return f(update, config, channel_controller); + return f(&mut self.data, update, config, channel_controller); } Ok(()) } @@ -200,7 +209,7 @@ impl LbPolicy for StubPolicy { channel_controller: &mut dyn ChannelController, ) { if let Some(f) = &self.funcs.subchannel_update { - f(subchannel, state, channel_controller); + f(&mut self.data, subchannel, state, channel_controller); } } @@ -223,6 +232,7 @@ impl LbPolicyBuilder for StubPolicyBuilder { fn build(&self, options: LbPolicyOptions) -> Box { Box::new(StubPolicy { funcs: self.funcs.clone(), + data: StubPolicyData::default(), }) } @@ -234,7 +244,7 @@ impl LbPolicyBuilder for StubPolicyBuilder { &self, _config: &ParsedJsonLbConfig, ) -> Result, Box> { - todo!("Implement parse_config in StubPolicyBuilder") + todo!("Implement parse_config in StubPolicy"); } }
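
Usage note (illustration only, not part of the patch): with the new `&mut StubPolicyData` first parameter on the StubPolicyFuncs callbacks, a test closure can keep typed per-policy state in the `test_data` slot across calls, along these lines:

    // Minimal sketch; assumes the StubPolicyFuncs/StubPolicyData definitions above.
    use crate::client::load_balancing::test_utils::{StubPolicyData, StubPolicyFuncs};
    use crate::client::name_resolution::ResolverUpdate;

    fn counting_funcs() -> StubPolicyFuncs {
        StubPolicyFuncs {
            // Counts resolver updates seen by this child; the counter survives across
            // calls because it lives in StubPolicyData rather than in the closure.
            resolver_update: Some(|data: &mut StubPolicyData, _update: ResolverUpdate, _, _ctrl| {
                let count = data
                    .test_data
                    .get_or_insert_with(|| Box::new(0usize))
                    .downcast_mut::<usize>()
                    .unwrap();
                *count += 1;
                Ok(())
            }),
            subchannel_update: None,
        }
    }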