From 7ab080b1c23ee0576b53c6957071fd311f8ebeb8 Mon Sep 17 00:00:00 2001 From: yue-fred-gao <132678244+yue-fred-gao@users.noreply.github.com> Date: Tue, 26 Aug 2025 17:36:30 -0400 Subject: [PATCH 01/14] Adjust path of debs in common-libs artifact (#103) ### why pr checker failed due to common-libs are packaging bookworm debs instead of bullseye. ### what this PR does change bullseye to bookworm --- azure-pipelines.yml | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/azure-pipelines.yml b/azure-pipelines.yml index e7d4c21..00a7b7f 100644 --- a/azure-pipelines.yml +++ b/azure-pipelines.yml @@ -44,15 +44,15 @@ stages: runBranch: 'refs/heads/master' path: $(Build.ArtifactStagingDirectory)/download/common patterns: | - target/debs/bullseye/libnl-3-200_*.deb - target/debs/bullseye/libnl-3-dev_*.deb - target/debs/bullseye/libnl-genl-3-200_*.deb - target/debs/bullseye/libnl-genl-3-dev_*.deb - target/debs/bullseye/libnl-route-3-200_*.deb - target/debs/bullseye/libnl-route-3-dev_*.deb - target/debs/bullseye/libnl-nf-3-200_*.deb - target/debs/bullseye/libnl-nf-3-dev_*.deb - target/debs/bullseye/libyang_*.deb + target/debs/bookworm/libnl-3-200_*.deb + target/debs/bookworm/libnl-3-dev_*.deb + target/debs/bookworm/libnl-genl-3-200_*.deb + target/debs/bookworm/libnl-genl-3-dev_*.deb + target/debs/bookworm/libnl-route-3-200_*.deb + target/debs/bookworm/libnl-route-3-dev_*.deb + target/debs/bookworm/libnl-nf-3-200_*.deb + target/debs/bookworm/libnl-nf-3-dev_*.deb + target/debs/bookworm/libyang_*.deb displayName: "Download common-libs deb packages" - task: DownloadPipelineArtifact@2 From fe8dbe6c5c561fe1263faadece310f60de6c995a Mon Sep 17 00:00:00 2001 From: yue-fred-gao <132678244+yue-fred-gao@users.noreply.github.com> Date: Thu, 28 Aug 2025 17:27:11 -0400 Subject: [PATCH 02/14] Fix producer bridge CI issue (#105) ### why producer_state_table_bridge_check_dup test failed. Suspect swss-common behaviour changed. 
### what this PR does Make the check more strict to make sure no entries are received from consumer. --- crates/swss-common-bridge/src/producer.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/crates/swss-common-bridge/src/producer.rs b/crates/swss-common-bridge/src/producer.rs index 1b04a4b..00c222e 100644 --- a/crates/swss-common-bridge/src/producer.rs +++ b/crates/swss-common-bridge/src/producer.rs @@ -268,10 +268,12 @@ mod test { if result.is_ok() { // If we got here, it means the bridge did not skip the updates let received = consumer_table.pops().await; - for kfv in received { - println!("Received: {}", kfv.key); + if !received.is_empty() { + for kfv in received { + println!("Received: {}", kfv.key); + } + panic!("Expected bridge to skip duplicate updates, but it processed them"); } - panic!("Expected bridge to skip duplicate updates, but it processed them"); } } From c54339a1160e2cec5c47849f1a40ef2a92fbc7cb Mon Sep 17 00:00:00 2001 From: dypet Date: Fri, 29 Aug 2025 09:44:37 -0600 Subject: [PATCH 03/14] Convert Unspecified to Standby in DashHaScopeTable. (#97) Translate "Unspecified" DesiredHaState from dash ha scope config to "standby" in DashHaScopeTable --- crates/hamgrd/src/actors/ha_scope.rs | 14 +++++++++----- crates/hamgrd/src/actors/ha_set.rs | 2 +- crates/hamgrd/src/actors/test.rs | 2 +- crates/hamgrd/src/db_structs.rs | 3 +++ 4 files changed, 14 insertions(+), 7 deletions(-) diff --git a/crates/hamgrd/src/actors/ha_scope.rs b/crates/hamgrd/src/actors/ha_scope.rs index 07d3091..9b1b273 100644 --- a/crates/hamgrd/src/actors/ha_scope.rs +++ b/crates/hamgrd/src/actors/ha_scope.rs @@ -241,11 +241,15 @@ impl HaScopeActor { ha_set_id: dash_ha_scope_config.ha_set_id.clone(), vip_v4: haset.ha_set.vip_v4.clone(), vip_v6: haset.ha_set.vip_v6.clone(), - ha_role: format!( - "{}", - DesiredHaState::try_from(dash_ha_scope_config.desired_ha_state).unwrap() - ) - .to_lowercase(), /*todo, how switching_to_active is derived. 
Is it relevant to dpu driven mode */ + ha_role: if dash_ha_scope_config.desired_ha_state == DesiredHaState::Unspecified as i32 { + "standby".to_string() + } else { + format!( + "{}", + DesiredHaState::try_from(dash_ha_scope_config.desired_ha_state).unwrap() + ) + .to_lowercase() + }, /*todo, how switching_to_active is derived. Is it relevant to dpu driven mode */ flow_reconcile_requested, activate_role_requested, }; diff --git a/crates/hamgrd/src/actors/ha_set.rs b/crates/hamgrd/src/actors/ha_set.rs index 4031b60..f76e88a 100644 --- a/crates/hamgrd/src/actors/ha_set.rs +++ b/crates/hamgrd/src/actors/ha_set.rs @@ -105,7 +105,7 @@ impl HaSetActor { scope: sonic_dash_api_proto::types::HaScope::try_from(dash_ha_set_config.scope) .map(|s| { let name = s.as_str_name(); - name.strip_prefix("SCOPE_").unwrap_or(name).to_lowercase() + name.strip_prefix("HA_SCOPE_").unwrap_or(name).to_lowercase() }) .ok(), local_npu_ip: local_vdpu.dpu.npu_ipv4.clone(), diff --git a/crates/hamgrd/src/actors/test.rs b/crates/hamgrd/src/actors/test.rs index 91de47a..7b0ba54 100644 --- a/crates/hamgrd/src/actors/test.rs +++ b/crates/hamgrd/src/actors/test.rs @@ -483,7 +483,7 @@ pub fn make_dpu_scope_ha_set_obj(switch: u16, dpu: u16) -> (String, DashHaSetTab vip_v4: ip_to_string(&haset_cfg.vip_v4.unwrap()), vip_v6: Some(ip_to_string(&haset_cfg.vip_v6.unwrap())), owner: None, - scope: Some("ha_scope_dpu".to_string()), + scope: Some("dpu".to_string()), local_npu_ip: format!("10.0.{switch}.{dpu}"), local_ip: format!("18.0.{switch}.{dpu}"), peer_ip: format!("18.0.{}.{dpu}", switch_pair_id * 2 + 1), diff --git a/crates/hamgrd/src/db_structs.rs b/crates/hamgrd/src/db_structs.rs index 61bae62..1b275d0 100644 --- a/crates/hamgrd/src/db_structs.rs +++ b/crates/hamgrd/src/db_structs.rs @@ -310,10 +310,13 @@ pub struct DpuDashHaScopeState { // The current term confirmed by ASIC. pub ha_term: String, // DPU is pending on role activation. 
+ #[serde(default)] pub activate_role_pending: bool, // Flow reconcile is requested and pending approval. + #[serde(default)] pub flow_reconcile_pending: bool, // Brainsplit is detected, and DPU is pending on recovery. + #[serde(default)] pub brainsplit_recover_pending: bool, } From fe71e6f893e0617e515befe6522bd698eafc1563 Mon Sep 17 00:00:00 2001 From: yue-fred-gao <132678244+yue-fred-gao@users.noreply.github.com> Date: Fri, 29 Aug 2025 11:47:27 -0400 Subject: [PATCH 04/14] Route exchange (#55) Implements route exchange feature specified in the wiki. For the detail behaviour, see https://github.com/sonic-net/sonic-dash-ha/wiki/SWBus-(Switch-Bus)#route-exchange-working-theory --- crates/hamgrd/src/actors.rs | 2 +- crates/hamgrd/src/actors/test.rs | 3 +- crates/hamgrd/src/main.rs | 12 +- crates/swbus-actor/src/state/incoming.rs | 3 +- crates/swbus-actor/tests/echo.rs | 4 +- crates/swbus-actor/tests/kvstore.rs | 4 +- crates/swbus-actor/tests/test_executor.rs | 3 +- crates/swbus-cli/src/main.rs | 8 +- crates/swbus-cli/src/show/hamgrd/actor.rs | 2 +- crates/swbus-cli/src/show/hamgrd/mod.rs | 4 +- crates/swbus-cli/src/show/mod.rs | 4 +- crates/swbus-cli/src/show/swbusd/mod.rs | 4 +- crates/swbus-cli/src/show/swbusd/route.rs | 13 +- crates/swbus-config/src/lib.rs | 46 +- crates/swbus-core/src/mux/conn.rs | 5 +- crates/swbus-core/src/mux/conn_info.rs | 27 +- crates/swbus-core/src/mux/conn_store.rs | 59 +- crates/swbus-core/src/mux/conn_worker.rs | 94 +- crates/swbus-core/src/mux/mod.rs | 1 + crates/swbus-core/src/mux/multiplexer.rs | 1389 ++++++++++++++--- crates/swbus-core/src/mux/nexthop.rs | 171 +- crates/swbus-core/src/mux/route_annoucer.rs | 252 +++ crates/swbus-core/src/mux/service.rs | 55 +- crates/swbus-core/tests/README.md | 95 ++ crates/swbus-core/tests/basic_tests.rs | 33 +- .../swbus-core/tests/common/test_executor.rs | 103 +- .../tests/data/{ => b2b}/test_ping.json | 6 - .../tests/data/{ => b2b}/test_show_route.json | 12 +- .../data/{ => 
b2b}/test_trace_route.json | 0 crates/swbus-core/tests/data/b2b/topo.json | 43 + .../tests/data/inter-cluster/test_ping.json | 49 + .../data/inter-cluster/test_show_route.json | 93 ++ .../tests/data/inter-cluster/topo.json | 93 ++ crates/swbus-core/tests/data/topos.json | 45 - .../swbus-core/tests/inter_cluster_tests.rs | 22 + crates/swbus-edge/src/core_client.rs | 16 +- crates/swbus-edge/src/edge_runtime.rs | 16 +- crates/swbus-proto/build.rs | 8 +- crates/swbus-proto/proto/swbus.proto | 36 +- crates/swbus-proto/src/swbus.rs | 138 +- .../swbusd_cluster-a.10.0.0.1.cfg | 14 + .../swbusd_cluster-a.10.0.0.2.cfg | 8 + .../swbusd_cluster-b.11.0.0.1.cfg | 10 + crates/swbusd/sample/swbusd1.cfg | 6 +- crates/swbusd/sample/swbusd2.cfg | 4 +- crates/swss-common-bridge/src/consumer.rs | 6 +- crates/swss-common-bridge/src/producer.rs | 6 +- 47 files changed, 2404 insertions(+), 623 deletions(-) create mode 100644 crates/swbus-core/src/mux/route_annoucer.rs create mode 100644 crates/swbus-core/tests/README.md rename crates/swbus-core/tests/data/{ => b2b}/test_ping.json (96%) rename crates/swbus-core/tests/data/{ => b2b}/test_show_route.json (82%) rename crates/swbus-core/tests/data/{ => b2b}/test_trace_route.json (100%) create mode 100644 crates/swbus-core/tests/data/b2b/topo.json create mode 100644 crates/swbus-core/tests/data/inter-cluster/test_ping.json create mode 100644 crates/swbus-core/tests/data/inter-cluster/test_show_route.json create mode 100644 crates/swbus-core/tests/data/inter-cluster/topo.json delete mode 100644 crates/swbus-core/tests/data/topos.json create mode 100644 crates/swbus-core/tests/inter_cluster_tests.rs create mode 100644 crates/swbusd/sample/inter-cluster/swbusd_cluster-a.10.0.0.1.cfg create mode 100644 crates/swbusd/sample/inter-cluster/swbusd_cluster-a.10.0.0.2.cfg create mode 100644 crates/swbusd/sample/inter-cluster/swbusd_cluster-b.11.0.0.1.cfg diff --git a/crates/hamgrd/src/actors.rs b/crates/hamgrd/src/actors.rs index adc8599..b6690e1 100644 
--- a/crates/hamgrd/src/actors.rs +++ b/crates/hamgrd/src/actors.rs @@ -108,7 +108,7 @@ where SwbusError::RouteError { code, detail } => (code, detail), }; let response = SwbusMessage::new_response( - &msg, + msg.header.as_ref().unwrap(), Some(&self.sp), code, &err_msg, diff --git a/crates/hamgrd/src/actors/test.rs b/crates/hamgrd/src/actors/test.rs index 7b0ba54..6d501c5 100644 --- a/crates/hamgrd/src/actors/test.rs +++ b/crates/hamgrd/src/actors/test.rs @@ -10,7 +10,7 @@ use std::{net::Ipv4Addr, net::Ipv6Addr, sync::Arc}; use swbus_actor::{ActorMessage, ActorRuntime}; use swbus_edge::{ simple_client::{IncomingMessage, MessageBody, OutgoingMessage, SimpleSwbusEdgeClient}, - swbus_proto::swbus::{ServicePath, SwbusErrorCode}, + swbus_proto::swbus::{ConnectionType, ServicePath, SwbusErrorCode}, SwbusEdgeRuntime, }; use swss_common::{FieldValues, Table}; @@ -283,6 +283,7 @@ pub async fn create_edge_runtime() -> SwbusEdgeRuntime { let mut swbus_edge: SwbusEdgeRuntime = SwbusEdgeRuntime::new( "none".to_string(), ServicePath::from_string("unknown.unknown.unknown/hamgrd/0").unwrap(), + ConnectionType::InNode, ); swbus_edge.start().await.unwrap(); swbus_edge diff --git a/crates/hamgrd/src/main.rs b/crates/hamgrd/src/main.rs index 1991dbb..419d6aa 100644 --- a/crates/hamgrd/src/main.rs +++ b/crates/hamgrd/src/main.rs @@ -9,7 +9,11 @@ use std::{ }; use swbus_actor::{set_global_runtime, ActorRuntime}; use swbus_config::swbus_config_from_db; -use swbus_edge::{simple_client::SimpleSwbusEdgeClient, swbus_proto::swbus::ServicePath, RuntimeEnv, SwbusEdgeRuntime}; +use swbus_edge::{ + simple_client::SimpleSwbusEdgeClient, + swbus_proto::swbus::{ConnectionType, ServicePath}, + RuntimeEnv, SwbusEdgeRuntime, +}; use swss_common::{sonic_db_config_initialize_global, DbConnector}; use swss_common_bridge::consumer::ConsumerBridge; use tokio::{signal, task::JoinHandle, time::timeout}; @@ -63,7 +67,11 @@ async fn main() { let runtime_data = RuntimeData::new(args.slot_id, 
swbus_config.npu_ipv4, swbus_config.npu_ipv6); // Setup swbus and actor runtime - let mut swbus_edge = SwbusEdgeRuntime::new(format!("http://{}", swbus_config.endpoint), swbus_sp.clone()); + let mut swbus_edge = SwbusEdgeRuntime::new( + format!("http://{}", swbus_config.endpoint), + swbus_sp.clone(), + ConnectionType::InNode, + ); swbus_edge.set_runtime_env(Box::new(runtime_data)); swbus_edge.start().await.unwrap(); diff --git a/crates/swbus-actor/src/state/incoming.rs b/crates/swbus-actor/src/state/incoming.rs index c2cd387..dc11b70 100644 --- a/crates/swbus-actor/src/state/incoming.rs +++ b/crates/swbus-actor/src/state/incoming.rs @@ -176,7 +176,7 @@ impl PartialEq for IncomingTableEntry { mod test { use super::*; use crate::actor_message::ActorMessage; - use swbus_edge::swbus_proto::swbus::ServicePath; + use swbus_edge::swbus_proto::swbus::{ConnectionType, ServicePath}; use swbus_edge::SwbusEdgeRuntime; #[test] @@ -184,6 +184,7 @@ mod test { let swbus_edge = Arc::new(SwbusEdgeRuntime::new( "none".to_string(), ServicePath::from_string("unknown.unknown.unknown/hamgrd/0").unwrap(), + ConnectionType::InNode, )); let swbus_edge = Arc::new(SimpleSwbusEdgeClient::new( diff --git a/crates/swbus-actor/tests/echo.rs b/crates/swbus-actor/tests/echo.rs index 5661f7f..e6b4192 100644 --- a/crates/swbus-actor/tests/echo.rs +++ b/crates/swbus-actor/tests/echo.rs @@ -1,6 +1,6 @@ use std::{mem, time::Duration}; use swbus_actor::{Actor, ActorMessage, ActorRuntime, Context, Result, State}; -use swbus_edge::{swbus_proto::swbus::ServicePath, SwbusEdgeRuntime}; +use swbus_edge::{swbus_proto::swbus::ConnectionType, swbus_proto::swbus::ServicePath, SwbusEdgeRuntime}; use tokio::{ sync::oneshot::{channel, Sender}, time::timeout, @@ -12,7 +12,7 @@ fn sp(name: &str) -> ServicePath { #[tokio::test] async fn echo() { - let mut swbus_edge = SwbusEdgeRuntime::new("none".to_string(), sp("none")); + let mut swbus_edge = SwbusEdgeRuntime::new("none".to_string(), sp("none"), 
ConnectionType::InNode); swbus_edge.start().await.unwrap(); let actor_runtime = ActorRuntime::new(swbus_edge.into()); swbus_actor::set_global_runtime(actor_runtime); diff --git a/crates/swbus-actor/tests/kvstore.rs b/crates/swbus-actor/tests/kvstore.rs index 3b90533..78a2631 100644 --- a/crates/swbus-actor/tests/kvstore.rs +++ b/crates/swbus-actor/tests/kvstore.rs @@ -1,7 +1,7 @@ use serde::{Deserialize, Serialize}; use std::{mem, time::Duration}; use swbus_actor::{Actor, ActorMessage, ActorRuntime, Context, Result, State}; -use swbus_edge::{swbus_proto::swbus::ServicePath, SwbusEdgeRuntime}; +use swbus_edge::{swbus_proto::swbus::ConnectionType, swbus_proto::swbus::ServicePath, SwbusEdgeRuntime}; use swss_common::Table; use swss_common_testing::Redis; use tokio::{ @@ -26,7 +26,7 @@ async fn echo() { // Add a handler to the runtime to receive management response let (mgmt_resp_queue_tx, mut mgmt_resp_queue_rx) = mpsc::channel::(1); - let mut swbus_edge = SwbusEdgeRuntime::new("none".to_string(), sp("none")); + let mut swbus_edge = SwbusEdgeRuntime::new("none".to_string(), sp("none"), ConnectionType::InNode); swbus_edge.add_handler(sp("mgmt_resp"), mgmt_resp_queue_tx); swbus_edge.start().await.unwrap(); diff --git a/crates/swbus-actor/tests/test_executor.rs b/crates/swbus-actor/tests/test_executor.rs index 40256fe..08cf619 100644 --- a/crates/swbus-actor/tests/test_executor.rs +++ b/crates/swbus-actor/tests/test_executor.rs @@ -4,7 +4,7 @@ use std::{ collections::HashMap, fs::File, future::pending, io::BufReader, mem, path::PathBuf, sync::Arc, time::Duration, }; use swbus_actor::{Actor, ActorMessage, ActorRuntime, Context, State}; -use swbus_edge::{swbus_proto::swbus::ServicePath, SwbusEdgeRuntime}; +use swbus_edge::{swbus_proto::swbus::ConnectionType, swbus_proto::swbus::ServicePath, SwbusEdgeRuntime}; use swss_common::{DbConnector, Table}; use swss_common_testing::{random_string, Redis}; use tokio::{ @@ -123,6 +123,7 @@ async fn run_test(mut t: TestSpec) -> 
Option<&'static str> { let mut swbus_edge = SwbusEdgeRuntime::new( "".into(), ServicePath::from_string("test.test.test/test/test").unwrap(), + ConnectionType::InNode, ); swbus_edge.start().await.unwrap(); let actor_rt = ActorRuntime::new(Arc::new(swbus_edge)); diff --git a/crates/swbus-cli/src/main.rs b/crates/swbus-cli/src/main.rs index ecb78a3..9b33fb8 100644 --- a/crates/swbus-cli/src/main.rs +++ b/crates/swbus-cli/src/main.rs @@ -158,7 +158,11 @@ async fn main() { sp.service_type = "swbus-cli".to_string(); sp.service_id = Uuid::new_v4().to_string(); - let mut runtime = SwbusEdgeRuntime::new(format!("http://{}", swbus_config.endpoint), sp.clone()); + let mut runtime = SwbusEdgeRuntime::new( + format!("http://{}", swbus_config.endpoint), + sp.clone(), + ConnectionType::Client, + ); runtime.start().await.unwrap(); let runtime = Arc::new(runtime); @@ -279,7 +283,7 @@ mod tests { assert!(config .routes .iter() - .any(|r| r.scope == RouteScope::Cluster && r.key == expected_sp)); + .any(|r| r.scope == RouteScope::InCluster && r.key == expected_sp)); cleanup_configdb_for_test(); } diff --git a/crates/swbus-cli/src/show/hamgrd/actor.rs b/crates/swbus-cli/src/show/hamgrd/actor.rs index b310171..5769887 100644 --- a/crates/swbus-cli/src/show/hamgrd/actor.rs +++ b/crates/swbus-cli/src/show/hamgrd/actor.rs @@ -260,7 +260,7 @@ impl ShowCmdHandler for ShowActorCmd { } } - fn process_response(&self, response: &RequestResponse) { + fn process_response(&self, _: &CommandContext, response: &RequestResponse) { let result = match &response.response_body { Some(request_response::ResponseBody::ManagementQueryResult(ref result)) => &result.value, _ => { diff --git a/crates/swbus-cli/src/show/hamgrd/mod.rs b/crates/swbus-cli/src/show/hamgrd/mod.rs index 7179758..2c6d1ba 100644 --- a/crates/swbus-cli/src/show/hamgrd/mod.rs +++ b/crates/swbus-cli/src/show/hamgrd/mod.rs @@ -22,8 +22,8 @@ impl ShowCmdHandler for ShowHamgrdCmd { sub_cmd.create_request(ctx, src_sp) } - fn 
process_response(&self, response: &RequestResponse) { + fn process_response(&self, ctx: &CommandContext, response: &RequestResponse) { let HamgrdCmd::Actor(sub_cmd) = &self.subcommand; - sub_cmd.process_response(response); + sub_cmd.process_response(ctx, response); } } diff --git a/crates/swbus-cli/src/show/mod.rs b/crates/swbus-cli/src/show/mod.rs index a9f079d..1ffc546 100644 --- a/crates/swbus-cli/src/show/mod.rs +++ b/crates/swbus-cli/src/show/mod.rs @@ -22,7 +22,7 @@ enum ShowSubCmd { trait ShowCmdHandler { fn create_request(&self, ctx: &super::CommandContext, src_sp: &ServicePath) -> SwbusMessage; - fn process_response(&self, response: &RequestResponse); + fn process_response(&self, ctx: &super::CommandContext, response: &RequestResponse); } impl super::CmdHandler for ShowCmd { @@ -55,7 +55,7 @@ impl super::CmdHandler for ShowCmd { let body = result.msg.unwrap().body.unwrap(); match body { swbus_message::Body::Response(response) => { - sub_cmd.process_response(&response); + sub_cmd.process_response(ctx, &response); } _ => { info!("Invalid response"); diff --git a/crates/swbus-cli/src/show/swbusd/mod.rs b/crates/swbus-cli/src/show/swbusd/mod.rs index 9c57fe8..86e57c7 100644 --- a/crates/swbus-cli/src/show/swbusd/mod.rs +++ b/crates/swbus-cli/src/show/swbusd/mod.rs @@ -23,8 +23,8 @@ impl ShowCmdHandler for ShowSwbusdCmd { sub_cmd.create_request(ctx, src_sp) } - fn process_response(&self, response: &RequestResponse) { + fn process_response(&self, ctx: &CommandContext, response: &RequestResponse) { let SwbusdCmd::Route(sub_cmd) = &self.subcommand; - sub_cmd.process_response(response); + sub_cmd.process_response(ctx, response); } } diff --git a/crates/swbus-cli/src/show/swbusd/route.rs b/crates/swbus-cli/src/show/swbusd/route.rs index 362ad99..f743f5c 100644 --- a/crates/swbus-cli/src/show/swbusd/route.rs +++ b/crates/swbus-cli/src/show/swbusd/route.rs @@ -13,7 +13,7 @@ struct RouteDisplay { service_path: String, hop_count: u32, nh_id: String, - nh_scope: String, 
+ route_scope: String, nh_service_path: String, } @@ -29,18 +29,21 @@ impl ShowCmdHandler for ShowRouteCmd { } } - fn process_response(&self, response: &RequestResponse) { + fn process_response(&self, ctx: &CommandContext, response: &RequestResponse) { let routes = match &response.response_body { - Some(request_response::ResponseBody::RouteQueryResult(route_result)) => route_result, + Some(request_response::ResponseBody::RouteEntries(route_result)) => route_result, _ => { - info!("Expecting RouteQueryResult but got something else: {:?}", response); + info!("Expecting RouteEntries but got something else: {:?}", response); return; } }; + let my_sp = Some(ctx.sp.clone()); let routes: Vec = routes .entries .iter() + // Filter out the entry for the show_route request itself + .filter(|entry| entry.service_path != my_sp) .map(|entry| RouteDisplay { service_path: entry .service_path @@ -49,7 +52,7 @@ impl ShowCmdHandler for ShowRouteCmd { .to_longest_path(), hop_count: entry.hop_count, nh_id: entry.nh_id.clone(), - nh_scope: RouteScope::try_from(entry.nh_scope).unwrap().as_str_name().to_string(), + route_scope: RouteScope::try_from(entry.route_scope).unwrap().as_str_name().to_string(), nh_service_path: entry .nh_service_path .as_ref() diff --git a/crates/swbus-config/src/lib.rs b/crates/swbus-config/src/lib.rs index dcd593b..c2f7925 100644 --- a/crates/swbus-config/src/lib.rs +++ b/crates/swbus-config/src/lib.rs @@ -40,7 +40,7 @@ pub struct PeerConfig { impl SwbusConfig { pub fn get_swbusd_service_path(&self) -> Option { for route in &self.routes { - if route.scope == RouteScope::Cluster { + if route.scope == RouteScope::InCluster { return Some(route.key.clone()); } } @@ -139,7 +139,7 @@ fn route_config_from_dpu_entry(dpu_entry: &ConfigDBDPUEntry, region: &str, clust let sp = ServicePath::with_node(region, cluster, &format!("{npu_ipv4}-dpu{dpu_id}"), "", "", "", ""); routes.push(RouteConfig { key: sp, - scope: RouteScope::Cluster, + scope: RouteScope::InCluster, }); } @@ 
-147,7 +147,7 @@ fn route_config_from_dpu_entry(dpu_entry: &ConfigDBDPUEntry, region: &str, clust let sp = ServicePath::with_node(region, cluster, &format!("{npu_ipv6}-dpu{dpu_id}"), "", "", "", ""); routes.push(RouteConfig { key: sp, - scope: RouteScope::Cluster, + scope: RouteScope::InCluster, }); } @@ -195,7 +195,7 @@ fn peer_config_from_dpu_entry( peers.push(PeerConfig { id: sp, endpoint: SocketAddr::new(IpAddr::V4(npu_ipv4), swbusd_port), - conn_type: ConnectionType::Cluster, + conn_type: ConnectionType::InCluster, }); } @@ -208,7 +208,7 @@ fn peer_config_from_dpu_entry( peers.push(PeerConfig { id: sp, endpoint: SocketAddr::new(IpAddr::V6(npu_ipv6), swbusd_port), - conn_type: ConnectionType::Cluster, + conn_type: ConnectionType::InCluster, }); } @@ -497,40 +497,40 @@ mod tests { endpoint: "10.0.1.0:23606" routes: - key: "region-a.cluster-a.10.0.1.0-dpu0" - scope: "Cluster" + scope: "InCluster" - key: "region-a.cluster-a.2001:db8:1::-dpu0" - scope: "Cluster" + scope: "InCluster" peers: - id: "region-a.cluster-a.10.0.1.0-dpu1" endpoint: "10.0.1.0:23607" - conn_type: "Cluster" + conn_type: "InCluster" - id: "region-a.cluster-a.2001:db8:1::-dpu1" endpoint: "[2001:db8:1::]:23607" - conn_type: "Cluster" + conn_type: "InCluster" - id: "region-a.cluster-a.10.0.1.1-dpu0" endpoint: "10.0.1.1:23606" - conn_type: "Cluster" + conn_type: "InCluster" - id: "region-a.cluster-a.2001:db8:1::1-dpu0" endpoint: "[2001:db8:1::1]:23606" - conn_type: "Cluster" + conn_type: "InCluster" - id: "region-a.cluster-a.10.0.1.1-dpu1" endpoint: "10.0.1.1:23607" - conn_type: "Cluster" + conn_type: "InCluster" - id: "region-a.cluster-a.2001:db8:1::1-dpu1" endpoint: "[2001:db8:1::1]:23607" - conn_type: "Cluster" + conn_type: "InCluster" - id: "region-a.cluster-a.10.0.1.2-dpu0" endpoint: "10.0.1.2:23606" - conn_type: "Cluster" + conn_type: "InCluster" - id: "region-a.cluster-a.2001:db8:1::2-dpu0" endpoint: "[2001:db8:1::2]:23606" - conn_type: "Cluster" + conn_type: "InCluster" - id: 
"region-a.cluster-a.10.0.1.2-dpu1" endpoint: "10.0.1.2:23607" - conn_type: "Cluster" + conn_type: "InCluster" - id: "region-a.cluster-a.2001:db8:1::2-dpu1" endpoint: "[2001:db8:1::2]:23607" - conn_type: "Cluster" + conn_type: "InCluster" "#; let dir = tempdir().unwrap(); @@ -557,14 +557,14 @@ mod tests { endpoint: 10.0.0.1:8000 routes: - key: "region-a.cluster-a.10.0.0.1-dpu0" - scope: "Cluster" + scope: "InCluster" peers: - id: "region-a.cluster-a.10.0.0.2-dpu0" endpoint: "10.0.0.2:8000" - conn_type: "Cluster" + conn_type: "InCluster" - id: "region-a.cluster-a.10.0.0.3-dpu0" endpoint: "10.0.0.3:8000" - conn_type: "Cluster" + conn_type: "InCluster" "#; let dir = tempdir().unwrap(); @@ -588,7 +588,7 @@ mod tests { config.routes[0].key, ServicePath::from_string("region-a.cluster-a.10.0.0.1-dpu0").unwrap() ); - assert_eq!(config.routes[0].scope, RouteScope::Cluster); + assert_eq!(config.routes[0].scope, RouteScope::InCluster); assert_eq!( config.peers[0].id, @@ -598,7 +598,7 @@ mod tests { config.peers[0].endpoint, "10.0.0.2:8000".parse().expect("not expecting error") ); - assert_eq!(config.peers[0].conn_type, ConnectionType::Cluster); + assert_eq!(config.peers[0].conn_type, ConnectionType::InCluster); assert_eq!( config.peers[1].id, ServicePath::from_string("region-a.cluster-a.10.0.0.3-dpu0").unwrap() @@ -607,6 +607,6 @@ mod tests { config.peers[1].endpoint, "10.0.0.3:8000".parse().expect("not expecting error") ); - assert_eq!(config.peers[1].conn_type, ConnectionType::Cluster); + assert_eq!(config.peers[1].conn_type, ConnectionType::InCluster); } } diff --git a/crates/swbus-core/src/mux/conn.rs b/crates/swbus-core/src/mux/conn.rs index 82f13cf..8a996bf 100644 --- a/crates/swbus-core/src/mux/conn.rs +++ b/crates/swbus-core/src/mux/conn.rs @@ -98,10 +98,7 @@ impl SwbusConn { let mut stream_message_request = Request::new(request_stream); - let sp_str = conn_info - .local_service_path() - .expect("missing local service path") - .to_string(); + let sp_str = 
mux.get_my_service_path().to_string(); let meta = stream_message_request.metadata_mut(); diff --git a/crates/swbus-core/src/mux/conn_info.rs b/crates/swbus-core/src/mux/conn_info.rs index 8cd84c5..8e93f53 100644 --- a/crates/swbus-core/src/mux/conn_info.rs +++ b/crates/swbus-core/src/mux/conn_info.rs @@ -23,10 +23,6 @@ pub struct SwbusConnInfo { #[getset(get_copy = "pub")] connection_type: ConnectionType, - // Local service path is only used for client mode to send my route to the server - // this will be removed when we implement route update - local_service_path: Option, - #[getset(get = "pub")] remote_service_path: ServicePath, } @@ -36,14 +32,12 @@ impl SwbusConnInfo { conn_type: ConnectionType, remote_addr: SocketAddr, remote_service_path: ServicePath, - local_service_path: ServicePath, ) -> SwbusConnInfo { SwbusConnInfo { id: format!("swbs-to://{}:{}", remote_addr.ip(), remote_addr.port()), mode: SwbusConnMode::Client, remote_addr, connection_type: conn_type, - local_service_path: Some(local_service_path), remote_service_path, } } @@ -58,14 +52,9 @@ impl SwbusConnInfo { mode: SwbusConnMode::Server, remote_addr, connection_type: conn_type, - local_service_path: None, remote_service_path, } } - - pub fn local_service_path(&self) -> Option<&ServicePath> { - self.local_service_path.as_ref() - } } #[cfg(test)] @@ -77,33 +66,25 @@ mod tests { fn new_client_conn_info_should_succeed() { let remote_addr = "127.0.0.1:8080".parse().unwrap(); let remote_service_path = ServicePath::from_string("region-a.cluster-a.10.0.0.1-dpu0").unwrap(); - let local_service_path = ServicePath::from_string("region-a.cluster-a.10.0.0.2-dpu0").unwrap(); - let conn_info = SwbusConnInfo::new_client( - ConnectionType::Cluster, - remote_addr, - remote_service_path.clone(), - local_service_path.clone(), - ); + let conn_info = SwbusConnInfo::new_client(ConnectionType::InCluster, remote_addr, remote_service_path.clone()); assert_eq!(conn_info.id(), "swbs-to://127.0.0.1:8080"); 
assert_eq!(conn_info.mode(), SwbusConnMode::Client); assert_eq!(conn_info.remote_addr(), remote_addr); - assert_eq!(conn_info.connection_type(), ConnectionType::Cluster); + assert_eq!(conn_info.connection_type(), ConnectionType::InCluster); assert_eq!(conn_info.remote_service_path(), &remote_service_path); - assert_eq!(conn_info.local_service_path(), Some(&local_service_path)); } #[test] fn new_server_conn_info_should_succeed() { let remote_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080); let remote_service_path = ServicePath::from_string("region-a.cluster-a.10.0.0.1-dpu0").unwrap(); - let conn_info = SwbusConnInfo::new_server(ConnectionType::Region, remote_addr, remote_service_path.clone()); + let conn_info = SwbusConnInfo::new_server(ConnectionType::InRegion, remote_addr, remote_service_path.clone()); assert_eq!(conn_info.id(), "swbs-from://127.0.0.1:8080"); assert_eq!(conn_info.mode(), SwbusConnMode::Server); assert_eq!(conn_info.remote_addr(), remote_addr); - assert_eq!(conn_info.connection_type(), ConnectionType::Region); + assert_eq!(conn_info.connection_type(), ConnectionType::InRegion); assert_eq!(conn_info.remote_service_path(), &remote_service_path); - assert_eq!(conn_info.local_service_path(), None); } } diff --git a/crates/swbus-core/src/mux/conn_store.rs b/crates/swbus-core/src/mux/conn_store.rs index 2851163..8604050 100644 --- a/crates/swbus-core/src/mux/conn_store.rs +++ b/crates/swbus-core/src/mux/conn_store.rs @@ -2,9 +2,9 @@ use crate::mux::conn::SwbusConn; use crate::mux::SwbusConnInfo; use crate::mux::SwbusConnMode; use crate::mux::SwbusMultiplexer; -use dashmap::{DashMap, DashSet}; +use dashmap::DashMap; use std::sync::Arc; -use swbus_config::{PeerConfig, RouteConfig}; +use swbus_config::PeerConfig; use tokio::time::Duration; use tokio_util::sync::CancellationToken; use tracing::*; @@ -18,7 +18,6 @@ enum ConnTracker { pub struct SwbusConnStore { mux: Arc, connections: DashMap, ConnTracker>, - my_routes: DashSet, } impl 
SwbusConnStore { @@ -26,7 +25,6 @@ impl SwbusConnStore { SwbusConnStore { mux, connections: DashMap::new(), - my_routes: DashSet::new(), } } @@ -68,18 +66,11 @@ impl SwbusConnStore { self.connections.insert(conn_info_clone, ConnTracker::Task(token)); } - pub fn add_my_route(&self, my_route: RouteConfig) { - self.my_routes.insert(my_route); - } - pub fn add_peer(self: &Arc, peer: PeerConfig) { - // todo: assuming only one route for now. Will be improved to send routes in route update message and remove this - let my_route = self.my_routes.iter().next().expect("My service path is not set"); let conn_info = Arc::new(SwbusConnInfo::new_client( peer.conn_type, peer.endpoint, peer.id.clone(), - my_route.key.clone(), )); self.start_connect_task(conn_info, false); } @@ -123,24 +114,25 @@ impl SwbusConnStore { #[cfg(test)] mod tests { use super::*; + use swbus_config::RouteConfig; use swbus_proto::swbus::ConnectionType; use swbus_proto::swbus::RouteScope; use swbus_proto::swbus::ServicePath; use tokio::sync::mpsc; #[tokio::test] async fn test_add_peer() { - let mux = Arc::new(SwbusMultiplexer::new()); - let conn_store = Arc::new(SwbusConnStore::new(mux.clone())); let peer_config = PeerConfig { - conn_type: ConnectionType::Local, + conn_type: ConnectionType::InNode, endpoint: "127.0.0.1:8080".to_string().parse().unwrap(), id: ServicePath::from_string("region-a.cluster-a.10.0.0.2-dpu0").unwrap(), }; let route_config = RouteConfig { key: ServicePath::from_string("region-a.cluster-a.10.0.0.1-dpu0").unwrap(), - scope: RouteScope::Cluster, + scope: RouteScope::InCluster, }; - conn_store.add_my_route(route_config); + + let mux = Arc::new(SwbusMultiplexer::new(vec![route_config])); + let conn_store = Arc::new(SwbusConnStore::new(mux.clone())); conn_store.add_peer(peer_config); @@ -149,35 +141,19 @@ mod tests { })); } - #[tokio::test] - async fn test_add_my_route() { - let mux = Arc::new(SwbusMultiplexer::new()); - let conn_store = Arc::new(SwbusConnStore::new(mux.clone())); - let 
route_config = RouteConfig { - key: ServicePath::from_string("region-a.cluster-a.10.0.0.1-dpu0").unwrap(), - scope: RouteScope::Cluster, - }; - - conn_store.add_my_route(route_config.clone()); - - assert!(conn_store.my_routes.contains(&route_config)); - } - #[tokio::test] async fn test_conn_lost() { - let mux = Arc::new(SwbusMultiplexer::new()); - let conn_store = Arc::new(SwbusConnStore::new(mux.clone())); let route_config = RouteConfig { key: ServicePath::from_string("region-a.cluster-a.10.0.0.1-dpu0").unwrap(), - scope: RouteScope::Cluster, + scope: RouteScope::InCluster, }; - conn_store.add_my_route(route_config); + let mux = Arc::new(SwbusMultiplexer::new(vec![route_config])); + let conn_store = Arc::new(SwbusConnStore::new(mux.clone())); let conn_info = Arc::new(SwbusConnInfo::new_client( - ConnectionType::Cluster, + ConnectionType::InCluster, "127.0.0.1:8080".parse().unwrap(), ServicePath::from_string("regiona.clustera.10.0.0.2-dpu0").unwrap(), - ServicePath::from_string("regiona.clustera.10.0.0.1-dpu0").unwrap(), )); conn_store.conn_lost(conn_info.clone()); @@ -188,19 +164,18 @@ mod tests { #[tokio::test] async fn test_conn_established() { - let mux = Arc::new(SwbusMultiplexer::new()); - let conn_store = Arc::new(SwbusConnStore::new(mux.clone())); let route_config = RouteConfig { key: ServicePath::from_string("region-a.cluster-a.10.0.0.1-dpu0").unwrap(), - scope: RouteScope::Cluster, + scope: RouteScope::InCluster, }; - conn_store.add_my_route(route_config); + + let mux = Arc::new(SwbusMultiplexer::new(vec![route_config])); + let conn_store = Arc::new(SwbusConnStore::new(mux.clone())); let conn_info = Arc::new(SwbusConnInfo::new_client( - ConnectionType::Cluster, + ConnectionType::InCluster, "127.0.0.1:8080".parse().unwrap(), ServicePath::from_string("regiona.clustera.10.0.0.2-dpu0").unwrap(), - ServicePath::from_string("regiona.clustera.10.0.0.1-dpu0").unwrap(), )); let (send_queue_tx, _) = mpsc::channel(16); let conn = SwbusConn::new(&conn_info, 
send_queue_tx); diff --git a/crates/swbus-core/src/mux/conn_worker.rs b/crates/swbus-core/src/mux/conn_worker.rs index d11a561..26dde13 100644 --- a/crates/swbus-core/src/mux/conn_worker.rs +++ b/crates/swbus-core/src/mux/conn_worker.rs @@ -75,7 +75,7 @@ where } fn unregister_from_mux(&self) -> Result<()> { - self.mux.unregister(self.info.clone()); + self.mux.unregister(&self.info); Ok(()) } @@ -98,7 +98,10 @@ where } } Some(Err(err)) => { - error!("Failed to receive message: {}.", err); + if self.info.connection_type() != ConnectionType::Client { + // we don't care CLI client disconnected + error!("Failed to receive message: {}.", err); + } return Err(SwbusError::connection( SwbusErrorCode::ConnectionError, io::Error::new(io::ErrorKind::ConnectionReset, err.to_string()), @@ -129,14 +132,30 @@ where let id = self.mux.generate_message_id(); let my_sp = self.mux.get_my_service_path(); - let response = SwbusMessage::new_response(&message, Some(&my_sp), SwbusErrorCode::Ok, "", id, None); + let response = SwbusMessage::new_response( + message.header.as_ref().unwrap(), + Some(my_sp), + SwbusErrorCode::Ok, + "", + id, + None, + ); self.mux.route_message(response).await?; - if message.header.as_ref().unwrap().destination.as_ref().unwrap() != &my_sp { + if message.header.as_ref().unwrap().destination.as_ref().unwrap() != my_sp { self.mux.route_message(message).await?; } } + Some(swbus_message::Body::RouteAnnouncement(route_entries)) => { + // drop route announcement message + debug!("Received route announcement"); + self.mux.process_route_announcement(route_entries, &self.info)?; + } + Some(swbus_message::Body::ManagementRequest(mgmt_request)) => { + let response = self.process_mgmt_request(message.header.as_ref().unwrap(), mgmt_request)?; + self.mux.route_message(response).await?; + } _ => { self.mux.route_message(message).await?; } @@ -144,6 +163,40 @@ where Ok(()) } + fn process_mgmt_request( + &self, + request_header: &SwbusMessageHeader, + mgmt_request: 
ManagementRequest, + ) -> Result { + let request_type = ManagementRequestType::try_from(mgmt_request.request).map_err(|_| { + SwbusError::input( + SwbusErrorCode::InvalidArgs, + format!("Invalid management request: {:?}", mgmt_request.request), + ) + })?; + + match request_type { + ManagementRequestType::SwbusdGetRoutes => { + debug!("Received show_route request"); + let routes = self.mux.dump_route_table(); + debug!("show_route response: {:?}", routes); + let response_msg = SwbusMessage::new_response( + request_header, + Some(self.mux.get_my_service_path()), + SwbusErrorCode::Ok, + "", + self.mux.generate_message_id(), + Some(request_response::ResponseBody::RouteEntries(routes)), + ); + Ok(response_msg) + } + _ => Err(SwbusError::input( + SwbusErrorCode::InvalidArgs, + format!("Invalid management request: {:?}", mgmt_request), + )), + } + } + fn validate_message_common(&mut self, message: &SwbusMessage) -> Result<()> { if message.header.is_none() { return Err(SwbusError::input( @@ -193,16 +246,21 @@ mod tests { #[tokio::test] async fn conn_worker_can_be_shutdown() { + let route_config = RouteConfig { + key: ServicePath::from_string("region-a.cluster-a.10.0.0.1-dpu0").unwrap(), + scope: RouteScope::InCluster, + }; + let shutdown_ct = CancellationToken::new(); let message_stream = stream::iter(vec![]); - let mux = Arc::new(SwbusMultiplexer::new()); + + let mux = Arc::new(SwbusMultiplexer::new(vec![route_config])); let conn_store = Arc::new(SwbusConnStore::new(mux.clone())); let conn_info = Arc::new(SwbusConnInfo::new_client( - ConnectionType::Cluster, + ConnectionType::InCluster, "127.0.0.1:8080".parse().unwrap(), ServicePath::from_string("regiona.clustera.10.0.0.2-dpu0").unwrap(), - ServicePath::from_string("regiona.clustera.10.0.0.1-dpu0").unwrap(), )); let mut worker = SwbusConnWorker::new(conn_info, shutdown_ct.clone(), message_stream, mux, conn_store); @@ -228,19 +286,18 @@ mod tests { }; let message_stream = stream::iter(vec![Ok(ping_msg)]); - let mux = 
Arc::new(SwbusMultiplexer::new()); - let conn_store = Arc::new(SwbusConnStore::new(mux.clone())); let route_config = RouteConfig { key: ServicePath::from_string("region-a.cluster-a.10.0.0.1-dpu0").unwrap(), - scope: RouteScope::Cluster, + scope: RouteScope::InCluster, }; - mux.set_my_routes(vec![route_config.clone()]); + + let mux = Arc::new(SwbusMultiplexer::new(vec![route_config])); + let conn_store = Arc::new(SwbusConnStore::new(mux.clone())); let conn_info = Arc::new(SwbusConnInfo::new_client( - ConnectionType::Cluster, + ConnectionType::InCluster, "127.0.0.1:8080".parse().unwrap(), ServicePath::from_string("regiona.clustera.10.0.0.2-dpu0").unwrap(), - ServicePath::from_string("regiona.clustera.10.0.0.1-dpu0").unwrap(), )); let mut worker = SwbusConnWorker::new(conn_info, shutdown_ct.clone(), message_stream, mux, conn_store); @@ -253,16 +310,21 @@ mod tests { #[tokio::test] async fn test_worker_invalid_message() { + let route_config = RouteConfig { + key: ServicePath::from_string("region-a.cluster-a.10.0.0.1-dpu0").unwrap(), + scope: RouteScope::InCluster, + }; + let shutdown_ct = CancellationToken::new(); let message_stream = stream::iter(vec![]); - let mux = Arc::new(SwbusMultiplexer::new()); + + let mux = Arc::new(SwbusMultiplexer::new(vec![route_config])); let conn_store = Arc::new(SwbusConnStore::new(mux.clone())); let conn_info = Arc::new(SwbusConnInfo::new_client( - ConnectionType::Cluster, + ConnectionType::InCluster, "127.0.0.1:8080".parse().unwrap(), ServicePath::from_string("regiona.clustera.10.0.0.2-dpu0").unwrap(), - ServicePath::from_string("regiona.clustera.10.0.0.1-dpu0").unwrap(), )); let mut worker = SwbusConnWorker::new(conn_info, shutdown_ct.clone(), message_stream, mux, conn_store); diff --git a/crates/swbus-core/src/mux/mod.rs b/crates/swbus-core/src/mux/mod.rs index 78fb5d9..b23b347 100644 --- a/crates/swbus-core/src/mux/mod.rs +++ b/crates/swbus-core/src/mux/mod.rs @@ -6,6 +6,7 @@ mod conn_worker; mod message_handler; mod multiplexer; 
pub mod nexthop; +mod route_annoucer; pub mod service; pub use conn::*; diff --git a/crates/swbus-core/src/mux/multiplexer.rs b/crates/swbus-core/src/mux/multiplexer.rs index 5b9438f..2fd463e 100644 --- a/crates/swbus-core/src/mux/multiplexer.rs +++ b/crates/swbus-core/src/mux/multiplexer.rs @@ -1,11 +1,16 @@ +use super::route_annoucer::{RouteAnnounceTask, TriggerType}; use super::{NextHopType, SwbusConnInfo, SwbusConnProxy, SwbusNextHop}; -use dashmap::mapref::entry::*; use dashmap::{DashMap, DashSet}; +use std::collections::BTreeSet; +use std::collections::HashMap; use std::sync::Arc; use swbus_config::RouteConfig; use swbus_proto::message_id_generator::MessageIdGenerator; use swbus_proto::result::*; use swbus_proto::swbus::*; +use tokio::sync::mpsc; +use tokio::sync::mpsc::error::TrySendError; +use tokio_util::sync::CancellationToken; use tracing::*; enum RouteStage { @@ -21,102 +26,409 @@ const ROUTE_STAGES: [RouteStage; 4] = [ RouteStage::Global, ]; -#[derive(Default)] pub struct SwbusMultiplexer { /// Route table. Each entry is a registered prefix to a next hop, which points to a connection. - routes: DashMap, + routes: DashMap>, id_generator: MessageIdGenerator, - my_routes: DashSet, + my_routes: HashMap, + my_service_path: ServicePath, + connections: DashSet>, + route_annouce_task_tx: Option>, + route_announer_ct: Option, + routes_by_conn: DashMap, BTreeSet>, } +/// `SwbusMultiplexer` is responsible for managing routes and connections within the Swbus system. +/// It maintains a route table and handles route announcements, registrations, and unregistrations of connections. +/// +/// # Methods +/// +/// - `new(routes: Vec) -> Self` +/// - Creates a new `SwbusMultiplexer` instance with the given routes. +/// - Panics if no `InCluster` route is provided. +/// +/// - `set_route_announcer(&mut self, tx: mpsc::Sender, ct: CancellationToken)` +/// - Sets the route announcer with the provided sender and cancellation token. 
+/// +/// - `generate_message_id(&self) -> u64` +/// - Generates a new message ID using the internal ID generator. +/// +/// - `route_from_conn(&self, conn_info: &Arc) -> String` +/// - Determines the route key based on the connection type of the provided connection info. +/// +/// - `register(&self, conn_info: &Arc, proxy: SwbusConnProxy)` +/// - Registers a new connection and updates the route table accordingly. +/// - Announces the routes to peers if the connection is not of type `Client`. +/// +/// - `unregister(&self, conn_info: &Arc)` +/// - Unregisters a connection and updates the route table accordingly. +/// - Announces the route removal to peers if necessary. +/// +/// - `update_route(&self, route_key: String, nexthop: SwbusNextHop) -> bool` +/// - Updates the route for the given service path with the new next hop. +/// - Returns `true` if the route was updated, otherwise `false`. +/// +/// - `remove_route_to_nh(&self, route_key: &str, nexthop: &SwbusNextHop) -> bool` +/// - Removes the specified route to the given next hop. +/// - Returns `true` if the route was removed, otherwise `false`. +/// +/// - `remove_route(&self, route_key: &str) -> bool` +/// - Removes the specified route. +/// - Returns `true` if the route was removed, otherwise `false`. +/// +/// - `get_nexthop_to_conn(&self, conn_info: &Arc) -> Result` +/// - Retrieves the next hop to the given connection. +/// - Returns an error if the route to the connection is not found. +/// +/// - `key_from_route_entry(entry: &RouteEntry) -> Result` +/// - Generates a route key from the given route entry. +/// +/// - `process_route_announcement(&self, routes: RouteEntries, conn_info: Arc) -> Result<()>` +/// - Processes a route announcement from a connection. +/// - Updates the route table and announces the routes to peers if necessary. +/// +/// - `announce_routes(&self, conn_info: &Arc, trigger: TriggerType) -> Result<()>` +/// - Announces routes to peers based on the provided trigger type. 
+/// +/// - `get_my_service_path(&self) -> &ServicePath` +/// - Returns the service path of the current instance. +/// +/// - `route_message(&self, message: SwbusMessage) -> Result<()>` +/// - Routes a message to the appropriate next hop. +/// - Returns an error if no route is found. +/// +/// - `dump_route_table(&self) -> RouteEntries` +/// - Dumps the current route table for display purposes. +/// +/// - `export_routes(&self, target_conn: &Arc) -> Option` +/// - Exports routes for route exchange between Swbus instances. +/// - Skips routes that go through the target connection. +/// +/// - `export_connections(&self) -> Vec>` +/// - Exports the list of current connections. +/// +/// - `shutdown(&mut self)` +/// - Shuts down the multiplexer by canceling the route announcer task. impl SwbusMultiplexer { - pub fn new() -> Self { + pub fn new(routes: Vec) -> Self { + let mut my_service_path = None; + let initial_routes: DashMap> = DashMap::new(); + let mut my_routes: HashMap = HashMap::new(); + + for route in routes { + if route.scope == RouteScope::InCluster { + // Create local service route + let route_key = route.key.to_incluster_prefix(); + let local_nh = SwbusNextHop::new_local(); + initial_routes.entry(route_key).or_default().insert(local_nh); + my_service_path = Some(route.key.clone()); + } + // my_routes are to be announced to the peers + my_routes.insert(route.key, route.scope); + } + if my_service_path.is_none() { + panic!("my routes must include a InCluster route"); + } + SwbusMultiplexer { - routes: DashMap::new(), + routes: initial_routes, id_generator: MessageIdGenerator::new(), - my_routes: DashSet::new(), + my_routes, + my_service_path: my_service_path.unwrap(), + connections: DashSet::new(), + route_annouce_task_tx: None, + route_announer_ct: None, + routes_by_conn: DashMap::new(), } } + pub(crate) fn set_route_announcer(&mut self, tx: mpsc::Sender, ct: CancellationToken) { + self.route_annouce_task_tx = Some(tx); + self.route_announer_ct = Some(ct); 
+ } + pub fn generate_message_id(&self) -> u64 { self.id_generator.generate() } + /// get route from conn_info based on connection type. This is a direct route which means it is to a immediate nexthop (1 hop away). + fn route_from_conn(&self, conn_info: &Arc) -> String { + match conn_info.connection_type() { + // direct route Global, InRegion, InCluster connection type is always to node level + ConnectionType::Global | ConnectionType::InRegion | ConnectionType::InCluster => { + conn_info.remote_service_path().to_incluster_prefix() + } + // both Client (for CLI) and InNode route by service prefix. IOW, service prefix uniquely + // identify the endpoint to the client (CLI or hamgrd) + ConnectionType::InNode | ConnectionType::Client => conn_info.remote_service_path().to_service_prefix(), + } + } + pub(crate) fn register(&self, conn_info: &Arc, proxy: SwbusConnProxy) { // Update the route table. - let path = conn_info.remote_service_path(); - let route_key = match conn_info.connection_type() { - ConnectionType::Global => path.to_regional_prefix(), - ConnectionType::Region => path.to_cluster_prefix(), - ConnectionType::Cluster => path.to_node_prefix(), - ConnectionType::Local => path.to_service_prefix(), - ConnectionType::Client => path.to_string(), - }; + let direct_route = self.route_from_conn(conn_info); + let nexthop = SwbusNextHop::new_remote(conn_info.clone(), proxy, 1); - self.update_route(route_key, nexthop); + self.update_route(direct_route, nexthop); + + // no need to add client connection or anounce client connection to the peers + if conn_info.connection_type() == ConnectionType::Client { + return; + } + + self.connections.insert(conn_info.clone()); + + self.announce_routes(conn_info, TriggerType::ConnectionUp) + .unwrap_or_else(|e| { + error!("Failed to send route announcement: {:?}", e); + }); } - pub(crate) fn unregister(&self, conn_info: Arc) { + pub(crate) fn unregister(&self, conn_info: &Arc) { // remove the route entry from the route table. 
- let path = conn_info.remote_service_path(); - let route_key = match conn_info.connection_type() { - ConnectionType::Global => path.to_regional_prefix(), - ConnectionType::Region => path.to_cluster_prefix(), - ConnectionType::Cluster => path.to_node_prefix(), - ConnectionType::Local => path.to_service_prefix(), - ConnectionType::Client => path.to_string(), - }; - self.routes.remove(&route_key); + let direct_route = self.route_from_conn(conn_info); + + // no need to add client connection or anounce client connection to the peers + if conn_info.connection_type() == ConnectionType::Client { + self.remove_route(&direct_route); + return; + } + + // get old routes, or create an empty one, and hold lock on the entry + let old_routes = self.routes_by_conn.entry(conn_info.clone()).or_default(); + + // create a dummy nexthop from conn_info to compare with the nexthop in the old_routes + let dummy_nh = SwbusNextHop::new_dummy_remote(conn_info.clone(), 1); + let mut need_announce = false; + + for entry in old_routes.iter() { + match Self::key_from_route_entry(entry) { + Ok(route_key) => { + let old_nh = dummy_nh.clone_with_hop_count(entry.hop_count + 1); + let route_removed = self.remove_route_to_nh(&route_key, &old_nh); + if route_removed { + need_announce = true; + } + } + Err(e) => { + error!("Failed to get route key from the route entry: {:?}, {:?}", entry, e); + } + } + } + + // remove the direct route + let route_removed = self.remove_route_to_nh(&direct_route, &dummy_nh); + if route_removed { + need_announce = true; + } + + self.connections.remove(conn_info); + + // release the lock on the entry + drop(old_routes); + self.routes_by_conn.remove(conn_info); + + if need_announce { + self.announce_routes(conn_info, TriggerType::ConnectionDown) + .unwrap_or_else(|e| { + error!("Failed to send route announcement: {:?}", e); + }); + } } + // Update route for the give service path with the new nexthop. If the route is updated, return true. + // Otherwise, return false. 
#[instrument(name = "update_route", level = "info", skip(self, nexthop), fields(nh_type=?nexthop.nh_type(), hop_count=nexthop.hop_count(), conn_info=nexthop.conn_info().as_ref().map(|x| x.id()).unwrap_or(&"None".to_string())))] - pub(crate) fn update_route(&self, route_key: String, nexthop: SwbusNextHop) { + fn update_route(&self, route_key: String, nexthop: SwbusNextHop) -> bool { // If route entry doesn't exist, we insert the next hop as a new one. - info!("Update route entry"); - match self.routes.entry(route_key) { - Entry::Occupied(mut existing) => { - let route_entry = existing.get(); - if route_entry.hop_count() > nexthop.hop_count() { - existing.insert(nexthop); - } else { - info!("Route entry already exists with smaller hop count"); - } - } - Entry::Vacant(entry) => { - entry.insert(nexthop); + let inserted = self.routes.entry(route_key).or_default().insert(nexthop.clone()); + info!("Update route entry: inserted={}", inserted); + inserted + } + + /// remove a specified route and to the specified nexthop + #[instrument(name = "remove_route_to_nh", level = "info", skip(self, nexthop), fields(nh_type=?nexthop.nh_type(), hop_count=nexthop.hop_count(), conn_info=nexthop.conn_info().as_ref().map(|x| x.id()).unwrap_or(&"None".to_string())))] + fn remove_route_to_nh(&self, route_key: &str, nexthop: &SwbusNextHop) -> bool { + let mut removed = false; + let mut remove_all = false; + if let Some(mut entry) = self.routes.get_mut(route_key) { + removed = entry.remove(nexthop); + if entry.is_empty() { + remove_all = true; + // can't remove route here because we are holding the lock on the entry. } } + if remove_all { + self.routes.remove(route_key); + } - // // If we already have one, then we update the entry only when we have a smaller hop count. - // // The dashmap RefMut reference will hold a lock to the entry, which makes this function atomic. 
- // if route_entry.hop_count > nexthop.hop_count { - // *route_entry.value_mut() = nexthop; - // } + info!("Remove route entry: removed={}", removed); + removed } - // Riff: The my route part is very confusing. Looks to be made for local service, but not really sure how it works. - pub fn set_my_routes(&self, routes: Vec) { - for route in routes { - if route.scope == RouteScope::Cluster { - // Create local service route - let route_key = route.key.to_node_prefix(); - let local_nh = SwbusNextHop::new_local(); - self.update_route(route_key, local_nh); + /// remove a specified route + #[instrument(name = "remove_route", level = "info", skip(self))] + fn remove_route(&self, route_key: &str) -> bool { + let removed = self.routes.remove(route_key); + info!("Remove route entry: removed={}", removed.is_some()); + removed.is_some() + } + + /// get the next hop to the connection. This is a direct route which means it is to a immediate nexthop (1 hop away). + fn get_nexthop_to_conn(&self, conn_info: &Arc) -> Result { + // get nexthop from conn_info + let direct_route = self.route_from_conn(conn_info); + + let nhs = self.routes.get(&direct_route).ok_or_else(|| { + // This could happen due to race condition if the connection is just removed. 
+ SwbusError::internal( + SwbusErrorCode::Fail, + format!( + "Receive route announcement from {} but route to the connection {} is not found ", + conn_info.remote_service_path(), + direct_route + ), + ) + })?; + + let nh = nhs + .iter() + .find(|nh| nh.conn_info().as_ref().unwrap().id() == conn_info.id()); + + if let Some(nh) = nh { + debug!("Found direct route to the connection: {}", direct_route); + Ok(nh.clone()) + } else { + Err(SwbusError::internal( + SwbusErrorCode::Fail, + format!( + "Receive route announcement from {} but nh of the route is not from same connection: {}", + conn_info.remote_service_path(), + direct_route + ), + )) + } + } + + fn key_from_route_entry(entry: &RouteEntry) -> Result { + let sp: &ServicePath = match entry.service_path { + None => Err(SwbusError::input( + SwbusErrorCode::InvalidRoute, + "Route entry missing service path".to_string(), + ))?, + Some(ref sp) => sp, + }; + Ok(match RouteScope::try_from(entry.route_scope) { + Ok(RouteScope::Global) => sp.to_global_prefix(), + Ok(RouteScope::InRegion) => sp.to_inregion_prefix(), + Ok(RouteScope::InCluster) => sp.to_incluster_prefix(), + _ => Err(SwbusError::input( + SwbusErrorCode::InvalidRoute, + "Invalid route scope".to_string(), + ))?, + }) + } + + /// processing route announcement from a connection. RouteEntry in the announcement is a set of routes + /// that the connection can reach. RouteEntry only includes the service path and hop count from the connection. + pub(crate) fn process_route_announcement( + &self, + routes: RouteEntries, + conn_info: &Arc, + ) -> Result<()> { + debug!( + "Begin processing route announcement from {}", + conn_info.remote_service_path() + ); + + // get nh to the connection, which will be used as the next hop to the routes in the announcement + let nh = self.get_nexthop_to_conn(conn_info)?; + + let mut new_routes: BTreeSet = BTreeSet::new(); + + // FILTER OUT MY ROUTES. MY ROUTES already HAVE LOWEST HOP COUNT. 
+ for entry in routes.entries.into_iter() { + if entry.service_path.is_none() { + error!( + "Received route announcement with missing service path from {}", + conn_info.remote_service_path() + ); + continue; + } + if self.my_routes.contains_key(entry.service_path.as_ref().unwrap()) { + debug!("Skip adding my route: {}", entry.service_path.as_ref().unwrap()); + continue; } + new_routes.insert(entry.clone()); + } + let mut need_announce = false; - self.my_routes.insert(route); + // get old routes, or create an empty one, and hold lock on the entry + let mut old_routes = self.routes_by_conn.entry(conn_info.clone()).or_default(); + let routes_to_remove: BTreeSet = old_routes.difference(&new_routes).cloned().collect(); + let routes_to_add: BTreeSet = new_routes.difference(&old_routes).cloned().collect(); + + for entry in routes_to_remove { + let route_key = match Self::key_from_route_entry(&entry) { + Ok(route_key) => route_key, + Err(e) => { + error!("Failed to get route key from the route entry: {:?}, {:?}", entry, e); + continue; + } + }; + let old_nh = nh.clone_with_hop_count(entry.hop_count + 1); + let route_removed = self.remove_route_to_nh(&route_key, &old_nh); + if route_removed { + need_announce = true; + } + } + + for entry in routes_to_add { + // clone the nexthop but increment hop count and use it in the routes from the announcement + let route_key = match Self::key_from_route_entry(&entry) { + Ok(route_key) => route_key, + Err(e) => { + error!("Failed to get route key from the route entry: {:?}, {:?}", entry, e); + continue; + } + }; + let new_nh = nh.clone_with_hop_count(entry.hop_count + 1); + + let route_changed = self.update_route(route_key, new_nh); + if route_changed { + need_announce = true; + } + } + + // notify route_announcer to send the routes to the other peers only if routes are updated. 
+ if need_announce { + self.announce_routes(conn_info, TriggerType::RouteUpdated)?; } + *old_routes = new_routes; + debug!( + "Finished processing route announcement from {}", + conn_info.remote_service_path() + ); + Ok(()) } - pub fn get_my_service_path(&self) -> ServicePath { - self.my_routes - .iter() - .by_ref() - .next() - .expect("My route is not set") - .key - .clone() + fn announce_routes(&self, conn_info: &Arc, trigger: TriggerType) -> Result<()> { + if let Some(tx) = &self.route_annouce_task_tx { + return match tx.try_send(RouteAnnounceTask::new(trigger, conn_info.clone())) { + Ok(_) => Ok(()), + Err(e) => match e { + TrySendError::Full(_) => Err(SwbusError::route(SwbusErrorCode::QueueFull, e.to_string())), + _ => Err(SwbusError::route(SwbusErrorCode::NoRoute, e.to_string())), + }, + }; + } + Ok(()) + } + + pub fn get_my_service_path(&self) -> &ServicePath { + &self.my_service_path } + #[instrument(name="route_message", parent=None, level="debug", skip_all, fields(message_id=?message.header.as_ref().unwrap().id))] pub async fn route_message(&self, message: SwbusMessage) -> Result<()> { debug!( @@ -130,6 +442,7 @@ impl SwbusMultiplexer { .to_longest_path(), "Routing message" ); + let header = match message.header { Some(ref header) => header, None => { @@ -139,6 +452,7 @@ impl SwbusMultiplexer { )) } }; + let destination = match header.destination { Some(ref destination) => destination, None => { @@ -152,32 +466,40 @@ impl SwbusMultiplexer { for stage in &ROUTE_STAGES { let route_key = match stage { RouteStage::Local => destination.to_service_prefix(), - RouteStage::Cluster => destination.to_node_prefix(), - RouteStage::Region => destination.to_cluster_prefix(), - RouteStage::Global => destination.to_regional_prefix(), + RouteStage::Cluster => destination.to_incluster_prefix(), + RouteStage::Region => destination.to_inregion_prefix(), + RouteStage::Global => destination.to_global_prefix(), }; // If the route entry doesn't exist, we drop the message. 
- let nexthop = match self.routes.get(&route_key) { + let nhs = match self.routes.get(&route_key) { Some(entry) => entry, None => { continue; } }; - - // If the route entry is resolved, we forward the message to the next hop. - let response = nexthop.queue_message(self, message).await.unwrap(); - if let Some(response) = response { - Box::pin(self.route_message(response)).await.unwrap(); - } else { - // todo: try another nexthop if there is one + for nh in nhs.iter() { + // If the route entry is resolved, we forward the message to the next hop. + match nh.queue_message(self, message.clone()).await { + Ok(response) => { + if let Some(response) = response { + Box::pin(self.route_message(response)).await?; + } + return Ok(()); + } + Err(e) => { + error!( + "Failed to queue message to the next hop: {}. Will try with the next nexthop", + e + ); + } + } } - return Ok(()); } info!("No route found for destination: {}", destination.to_longest_path()); let response = SwbusMessage::new_response( - &message, - Some(&self.get_my_service_path()), + message.header.as_ref().unwrap(), + Some(self.get_my_service_path()), SwbusErrorCode::NoRoute, "Route not found", self.id_generator.generate(), @@ -189,105 +511,209 @@ impl SwbusMultiplexer { // Here it will send 'no-route' response[2] for response[1]. response[1] has source SP of A // because the response is originated from A. So response[2]'s dest is to A (itself). // Response[2] will be sent to a drop nexhop, which should drop the unexpected response packet. - Box::pin(self.route_message(response)).await.unwrap(); + Box::pin(self.route_message(response)).await?; Ok(()) } - pub fn export_routes(&self, scope: Option) -> RouteQueryResult { - let entries: Vec = self + /// dump_route_table exports routes for swbus-cli show routes, which needs all the details about routes and nexthops. 
+ pub(crate) fn dump_route_table(&self) -> RouteEntries { + debug!("Dumping route table"); + let entries: Vec = self .routes .iter() .filter(|entry| { - if !matches!(entry.value().nh_type(), NextHopType::Remote) { + if !matches!(entry.value().first(), Some(s) if s.nh_type() == NextHopType::Remote) { return false; } - let route_scope = ServicePath::from_string(entry.key()).unwrap().route_scope(); - match scope { - Some(s) => route_scope >= s && route_scope >= RouteScope::Cluster, - None => true, + true + }) + .flat_map(|entry| { + // flatten the routes with one entry per nexthop + debug!("Dumping route {}: {:?}", entry.key(), entry.value()); + + entry + .value() + .iter() + .filter(|nh| { + // skip Client connection from CLI + nh.conn_info().is_some() + && nh.conn_info().as_ref().unwrap().connection_type() != ConnectionType::Client + }) + .map(|nh| RouteEntry { + service_path: Some( + ServicePath::from_string(entry.key()) + .expect("Not expecting service_path in route table to be invalid"), + ), + hop_count: nh.hop_count(), + nh_id: nh.conn_info().as_ref().unwrap().id().to_string(), + nh_service_path: Some(nh.conn_info().as_ref().unwrap().remote_service_path().clone()), + route_scope: ServicePath::from_string(entry.key()).unwrap().route_scope() as i32, + }) + .collect::>() + }) + .collect(); + RouteEntries { entries } + } + + /// Exports routes based on the connection type and target scope. + /// + /// This function filters and processes the routes to be exported to a target connection. + /// It considers the connection type and target scope to determine which routes are eligible + /// for export. The function performs the following steps: + /// + /// 1. Determines the target scope based on the connection type of the target connection. + /// 2. Filters the routes to include only those with a `NextHopType::Remote` and a route scope + /// greater than or equal to the target scope. + /// 3. 
For each eligible route, it finds the lowest hop count and creates new `RouteEntry` objects + /// for routes with the lowest hop count and active status, excluding routes that go through + /// the peer to which the routes are being exported. + /// 4. Additionally, it exports the routes from `my_routes` with a hop count of 0 and a route scope + /// greater than or equal to the target scope. + /// 5. Combines the filtered and processed routes into a `RouteEntries` object and returns it. + /// + /// # Arguments + /// + /// * `target_conn` - A reference to the target connection information. + /// + /// # Returns + /// + /// An `Option` containing the exported routes. If no routes are eligible for export, + /// an empty `RouteEntries` object is returned to allow the caller to send an empty route announcement + /// to the peer to remove all routes.usd to reach the destination. + pub fn export_routes(&self, target_conn: &Arc) -> Option { + let conn_type = target_conn.connection_type(); + let target_scope = match conn_type { + ConnectionType::InCluster => RouteScope::InCluster, + ConnectionType::InRegion => RouteScope::InRegion, + ConnectionType::Global => RouteScope::Global, + _ => return None, + }; + + let mut entries: Vec = self + .routes + .iter() + .filter(|entry| { + if !matches!(entry.value().first(), Some(s) if s.nh_type() == NextHopType::Remote) { + return false; } + let route_scope = ServicePath::from_string(entry.key()).unwrap().route_scope(); + route_scope >= target_scope }) - .map(|entry| RouteQueryResultEntry { - service_path: Some( - ServicePath::from_string(entry.key()) - .expect("Not expecting service_path in route table to be invalid"), - ), - hop_count: entry.value().hop_count(), - nh_id: entry.value().conn_info().as_ref().unwrap().id().to_string(), - nh_service_path: Some( - entry - .value() - .conn_info() - .as_ref() - .unwrap() - .remote_service_path() - .clone(), - ), - nh_scope: entry.value().conn_info().as_ref().unwrap().connection_type() as i32, + 
.flat_map(|entry| { + // Get the lowest hop_count (if set is non-empty) + let lowest_hop = match entry.value().iter().next() { + Some(first) => first.hop_count(), + None => return vec![], // Empty set, return empty vec + }; + + // Iterate, filter for lowest hop_count and active, create new objects + entry.value().iter() + .take_while(|nh| nh.hop_count() <= lowest_hop + 1) // Stop after lowest group + // do not export routes that go through the peer to which the routes are being exported + .find(|nh| nh.conn_info().as_ref().unwrap().remote_service_path() != target_conn.remote_service_path()) + .map(|nh| vec![RouteEntry { + service_path: Some( + ServicePath::from_string(entry.key()) + .expect("Not expecting service_path in route table to be invalid"), + ), + hop_count: nh.hop_count(), + nh_id: "".to_string(), + nh_service_path: None, + route_scope: ServicePath::from_string(entry.key()).unwrap().route_scope() as i32, + }]) + .unwrap_or(vec![]) + }) + .collect(); + + // export my_routes with hop count 0 + let my_route_entries: Vec = self + .my_routes + .iter() + .filter(|(_, &value)| value >= target_scope) + .map(|(key, value)| RouteEntry { + service_path: Some(key.clone()), + hop_count: 0, + nh_id: "".to_string(), + nh_service_path: None, + route_scope: *value as i32, }) .collect(); - RouteQueryResult { entries } + entries.extend(my_route_entries); + // even if there is no route, return an empty RouteEntries so the caller can send an empty route announcement to the peer to remove all routes. 
+ + Some(RouteEntries { entries }) + } + + pub fn export_connections(&self) -> Vec> { + self.connections.iter().map(|conn| conn.clone()).collect() + } + + pub fn shutdown(&mut self) { + if let Some(tx) = &self.route_announer_ct { + tx.cancel(); + self.route_announer_ct = None; + self.route_annouce_task_tx = None; + } } } #[cfg(test)] -mod tests { - use pretty_assertions::assert_eq; - use tokio::{sync::mpsc, time}; - use tonic::Status; - +pub mod test_utils { use super::*; use crate::mux::SwbusConn; use tokio::time::Duration; + use tokio::{sync::mpsc, time}; + use tonic::Status; - #[test] - fn test_set_my_routes() { - let mux = Arc::new(SwbusMultiplexer::new()); - let route_config = RouteConfig { - key: ServicePath::from_string("region-a.cluster-a.10.0.0.1-dpu0").unwrap(), - scope: RouteScope::Cluster, - }; - mux.set_my_routes(vec![route_config.clone()]); - assert!(mux.my_routes.contains(&route_config)); - - let nh = mux.routes.get(&route_config.key.to_node_prefix()).unwrap(); - assert_eq!(nh.nh_type(), NextHopType::Local); + // Helper function to create a new SwbusConn for testing. 
+ pub(crate) fn new_conn_for_test( + conn_type: ConnectionType, + sp: &str, + ) -> (SwbusConn, mpsc::Receiver>) { + new_conn_for_test_with_endpoint(conn_type, sp, "127.0.0.1:8080") } - fn add_route( - mux: &SwbusMultiplexer, - route_key: &str, - hop_count: u32, - nh_sp: &str, - nh_conn_type: ConnectionType, - ) -> mpsc::Receiver> { + pub(crate) fn new_conn_for_test_with_endpoint( + conn_type: ConnectionType, + sp: &str, + nh_endpoint: &str, + ) -> (SwbusConn, mpsc::Receiver>) { let conn_info = Arc::new(SwbusConnInfo::new_client( - nh_conn_type, - "127.0.0.1:8080".parse().unwrap(), - ServicePath::from_string(nh_sp).unwrap(), - ServicePath::from_string("regiona.clustera.10.0.0.1-dpu0").unwrap(), + conn_type, + nh_endpoint.parse().unwrap(), + ServicePath::from_string(sp).unwrap(), )); let (send_queue_tx, send_queue_rx) = mpsc::channel(16); - let conn = SwbusConn::new(&conn_info, send_queue_tx); - - let nexthop_nh1 = SwbusNextHop::new_remote(conn_info.clone(), conn.new_proxy(), hop_count); - mux.update_route(route_key.to_string(), nexthop_nh1); - send_queue_rx + (SwbusConn::new(&conn_info, send_queue_tx), send_queue_rx) } - - async fn route_message_and_compare( + // Helper function to add a route to the multiplexer. 
+ pub(crate) fn add_route( mux: &SwbusMultiplexer, + route_key: &str, + hop_count: u32, + conn: &SwbusConn, + ) -> Option { + let nexthop_nh1 = SwbusNextHop::new_remote(conn.info().clone(), conn.new_proxy(), hop_count); + if mux.update_route(route_key.to_string(), nexthop_nh1) { + Some(RouteEntry { + service_path: Some(ServicePath::from_string(route_key).unwrap()), + hop_count, + nh_id: conn.info().id().to_string(), + nh_service_path: Some(conn.info().remote_service_path().clone()), + route_scope: ServicePath::from_string(route_key).unwrap().route_scope() as i32, + }) + } else { + None + } + } + + pub(crate) async fn receive_and_compare( send_queue_rx: &mut mpsc::Receiver>, - request: &str, expected: &str, ) { - let request_msg: SwbusMessage = serde_json::from_str(request).unwrap(); let expected_msg: SwbusMessage = serde_json::from_str(expected).unwrap(); - - let result = mux.route_message(request_msg).await; - assert!(result.is_ok()); match time::timeout(Duration::from_secs(1), send_queue_rx.recv()).await { Ok(Some(msg)) => { let normalized_msg = swbus_proto::swbus::normalize_msg(&msg.ok().unwrap()); @@ -300,31 +726,100 @@ mod tests { } } + pub(crate) async fn route_message_and_compare( + mux: &SwbusMultiplexer, + send_queue_rx: &mut mpsc::Receiver>, + request: &str, + expected: &str, + ) { + let request_msg: SwbusMessage = serde_json::from_str(request).unwrap(); + + let result = mux.route_message(request_msg).await; + assert!(result.is_ok()); + receive_and_compare(send_queue_rx, expected).await; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::mux::test_utils::*; + use crate::mux::SwbusConn; + use pretty_assertions::assert_eq; + use tokio::sync::mpsc; + + // Helper function to create a new RouteEntry for route announcement. 
+ fn new_route_entry_for_ra(sp: &str, hop_count: u32) -> RouteEntry { + let sp = ServicePath::from_string(sp).unwrap(); + let route_scope = sp.route_scope(); + RouteEntry { + service_path: Some(sp), + hop_count, + nh_id: "".to_string(), + nh_service_path: None, + route_scope: route_scope as i32, + } + } + + fn route_config_to_route_entry(route_config: &RouteConfig) -> RouteEntry { + RouteEntry { + service_path: Some(route_config.key.clone()), + hop_count: 0, + nh_id: "".to_string(), + nh_service_path: None, + route_scope: route_config.scope as i32, + } + } + + fn clone_route_entry_without_nh(route_entry: &RouteEntry) -> RouteEntry { + RouteEntry { + service_path: route_entry.service_path.clone(), + hop_count: route_entry.hop_count, + nh_id: "".to_string(), + nh_service_path: None, + route_scope: route_entry.route_scope, + } + } + + #[test] + fn test_set_my_routes() { + let incluster_route = RouteConfig { + key: ServicePath::from_string("region-a.cluster-a.10.0.0.1-dpu0").unwrap(), + scope: RouteScope::InCluster, + }; + let inregion_route = RouteConfig { + key: ServicePath::from_string("region-a.cluster-a").unwrap(), + scope: RouteScope::InRegion, + }; + + let mux = Arc::new(SwbusMultiplexer::new(vec![ + incluster_route.clone(), + inregion_route.clone(), + ])); + + assert_eq!(*mux.my_routes.get(&incluster_route.key).unwrap(), RouteScope::InCluster); + assert_eq!(*mux.my_routes.get(&inregion_route.key).unwrap(), RouteScope::InRegion); + + let nhs = mux.routes.get(&incluster_route.key.to_incluster_prefix()).unwrap(); + assert_eq!(nhs.first().unwrap().nh_type(), NextHopType::Local); + assert_eq!(&incluster_route.key, mux.get_my_service_path()); + } + #[tokio::test] async fn test_route_message() { - let mux = Arc::new(SwbusMultiplexer::new()); - let route_config = RouteConfig { key: ServicePath::from_string("region-a.cluster-a.10.0.0.2-dpu0").unwrap(), - scope: RouteScope::Cluster, + scope: RouteScope::InCluster, }; - mux.set_my_routes(vec![route_config.clone()]); + let 
mux = Arc::new(SwbusMultiplexer::new(vec![route_config.clone()])); - let _ = add_route( - &mux, - "region-a.cluster-a.10.0.0.1-dpu0", - 1, - "region-a.cluster-a.10.0.0.1-dpu0", - ConnectionType::Cluster, - ); - let mut send_queue_rx3 = add_route( - &mux, - "region-a.cluster-a.10.0.0.3-dpu0", - 1, - "region-a.cluster-a.10.0.0.3-dpu0", - ConnectionType::Cluster, - ); + let (conn1, _) = new_conn_for_test(ConnectionType::InCluster, "region-a.cluster-a.10.0.0.1-dpu0"); + add_route(&mux, "region-a.cluster-a.10.0.0.1-dpu0", 1, &conn1).unwrap(); + + let (conn3, mut send_queue_rx3) = + new_conn_for_test(ConnectionType::InCluster, "region-a.cluster-a.10.0.0.3-dpu0"); + add_route(&mux, "region-a.cluster-a.10.0.0.3-dpu0", 1, &conn3).unwrap(); let request = r#" { @@ -361,29 +856,19 @@ mod tests { #[tokio::test] async fn test_route_message_unreachable() { - let mux = Arc::new(SwbusMultiplexer::new()); - let route_config = RouteConfig { key: ServicePath::from_string("region-a.cluster-a.10.0.0.2-dpu0").unwrap(), - scope: RouteScope::Cluster, + scope: RouteScope::InCluster, }; - mux.set_my_routes(vec![route_config.clone()]); + let mux = Arc::new(SwbusMultiplexer::new(vec![route_config.clone()])); - let mut send_queue_rx1 = add_route( - &mux, - "region-a.cluster-a.10.0.0.1-dpu0", - 1, - "region-a.cluster-a.10.0.0.1-dpu0", - ConnectionType::Cluster, - ); - let _ = add_route( - &mux, - "region-a.cluster-a.10.0.0.3-dpu0", - 1, - "region-a.cluster-a.10.0.0.3-dpu0", - ConnectionType::Cluster, - ); + let (conn1, mut send_queue_rx1) = + new_conn_for_test(ConnectionType::InCluster, "region-a.cluster-a.10.0.0.1-dpu0"); + add_route(&mux, "region-a.cluster-a.10.0.0.1-dpu0", 1, &conn1).unwrap(); + + let (conn3, _) = new_conn_for_test(ConnectionType::InCluster, "region-a.cluster-a.10.0.0.3-dpu0"); + add_route(&mux, "region-a.cluster-a.10.0.0.3-dpu0", 1, &conn3).unwrap(); let request = r#" { @@ -425,22 +910,16 @@ mod tests { #[tokio::test] async fn test_route_message_noroute() { - let mux = 
Arc::new(SwbusMultiplexer::new()); - let route_config = RouteConfig { key: ServicePath::from_string("region-a.cluster-a.10.0.0.2-dpu0").unwrap(), - scope: RouteScope::Cluster, + scope: RouteScope::InCluster, }; - mux.set_my_routes(vec![route_config.clone()]); + let mux = Arc::new(SwbusMultiplexer::new(vec![route_config.clone()])); - let mut send_queue_rx1 = add_route( - &mux, - "region-a.cluster-a.10.0.0.1-dpu0", - 1, - "region-a.cluster-a.10.0.0.1-dpu0", - ConnectionType::Cluster, - ); + let (conn1, mut send_queue_rx1) = + new_conn_for_test(ConnectionType::InCluster, "region-a.cluster-a.10.0.0.1-dpu0"); + add_route(&mux, "region-a.cluster-a.10.0.0.1-dpu0", 1, &conn1).unwrap(); let request = r#" { @@ -482,14 +961,12 @@ mod tests { #[tokio::test] async fn test_route_message_isolated() { - let mux = Arc::new(SwbusMultiplexer::new()); - let route_config = RouteConfig { key: ServicePath::from_string("region-a.cluster-a.10.0.0.2-dpu0").unwrap(), - scope: RouteScope::Cluster, + scope: RouteScope::InCluster, }; - mux.set_my_routes(vec![route_config.clone()]); + let mux = Arc::new(SwbusMultiplexer::new(vec![route_config.clone()])); let request = r#" { @@ -514,60 +991,514 @@ mod tests { assert!(result.is_ok()); } + #[tokio::test] + /// Create alternative paths and verify the message is routed to the new path when the primary path is down + async fn test_route_message_failover() { + let route_config = RouteConfig { + key: ServicePath::from_string("region-a.cluster-a.10.0.0.2-dpu0").unwrap(), + scope: RouteScope::InCluster, + }; + + let mux = Arc::new(SwbusMultiplexer::new(vec![route_config.clone()])); + + let (conn1, mut send_queue_rx1) = + new_conn_for_test(ConnectionType::InCluster, "region-a.cluster-a.10.0.0.1-dpu0"); + add_route(&mux, "region-a.cluster-a.10.0.0.3-dpu0", 1, &conn1).unwrap(); + + // add an alternative route + let (conn3, mut send_queue_rx3) = + new_conn_for_test(ConnectionType::InCluster, "region-a.cluster-a.10.0.0.3-dpu0"); + add_route(&mux, 
"region-a.cluster-a.10.0.0.3-dpu0", 3, &conn3).unwrap(); + + let request = r#" + { + "header": { + "version": 1, + "id": 0, + "flag": 0, + "ttl": 63, + "source": "region-a.cluster-a.10.0.0.1-dpu0/testsvc/0/ping/0", + "destination": "region-a.cluster-a.10.0.0.3-dpu0/local-mgmt/0" + }, + "body": { + "PingRequest": {} + } + } + "#; + let expected = r#" + { + "header": { + "version": 1, + "id": 0, + "flag": 0, + "ttl": 62, + "source": "region-a.cluster-a.10.0.0.1-dpu0/testsvc/0/ping/0", + "destination": "region-a.cluster-a.10.0.0.3-dpu0/local-mgmt/0" + }, + "body": { + "PingRequest": {} + } + } + "#; + route_message_and_compare(&mux, &mut send_queue_rx1, request, expected).await; + + // close the primary path + send_queue_rx1.close(); + route_message_and_compare(&mux, &mut send_queue_rx3, request, expected).await; + } + + #[test] + fn test_dump_route_table() { + let route_config = RouteConfig { + key: ServicePath::from_string("region-a.cluster-a.node0").unwrap(), + scope: RouteScope::InCluster, + }; + + let mux = Arc::new(SwbusMultiplexer::new(vec![route_config.clone()])); + + // build route table + let (conn1, _) = new_conn_for_test(ConnectionType::InCluster, "region-a.cluster-a.node1"); + let node1_re1 = add_route(&mux, "region-a.cluster-a.node1", 1, &conn1).unwrap(); + let node1_re2 = add_route(&mux, "region-a.cluster-b", 1, &conn1).unwrap(); + + let (conn2, _) = new_conn_for_test(ConnectionType::InCluster, "region-a.cluster-a.node2"); + let node2_re1 = add_route(&mux, "region-a.cluster-a.node2", 1, &conn2).unwrap(); + let node2_re2 = add_route(&mux, "region-a.cluster-b", 2, &conn2).unwrap(); + + // dump route table. 
Expect: all routes are dumped but not my routes + let routes = mux.dump_route_table(); + + let mut actual: Vec = routes.entries; + actual.sort(); + + let mut expected = vec![node1_re1, node1_re2, node2_re1, node2_re2]; + expected.sort(); + assert_eq!(actual, expected); + } + #[test] fn test_export_routes() { - let mux = Arc::new(SwbusMultiplexer::new()); + // create a mux with 2 my-routes + let my_route1 = RouteConfig { + key: ServicePath::from_string("region-a.cluster-a.node0").unwrap(), + scope: RouteScope::InCluster, + }; + let my_route2 = RouteConfig { + key: ServicePath::from_string("region-a.cluster-a").unwrap(), + scope: RouteScope::InRegion, + }; + let mux = Arc::new(SwbusMultiplexer::new(vec![my_route1.clone(), my_route2.clone()])); + + let my_re1 = route_config_to_route_entry(&my_route1); + let my_re2 = route_config_to_route_entry(&my_route2); + + // build route table with routes over 2 peers + let (conn1, _) = new_conn_for_test(ConnectionType::InCluster, "region-a.cluster-a.node1"); + let _ = add_route(&mux, "region-a.cluster-a.node1", 1, &conn1).unwrap(); + // add a regional route + let node1_re2 = add_route(&mux, "region-a.cluster-b", 2, &conn1).unwrap(); + let (conn2, _) = new_conn_for_test(ConnectionType::InCluster, "region-a.cluster-a.node2"); + let node2_re1 = add_route(&mux, "region-a.cluster-a.node2", 1, &conn2).unwrap(); + // add a global route + let node2_re2 = add_route(&mux, "region-a", 2, &conn2).unwrap(); + // add a route with hop-count+1 to the same destination as node1_re2 + let node2_re3 = add_route(&mux, "region-a.cluster-b", 3, &conn2).unwrap(); + // add a route with hop-count+2 to the same destination as node1_re1 + let _ = add_route(&mux, "region-a.cluster-a.node1", 3, &conn2).unwrap(); + + let (conn3, _) = new_conn_for_test(ConnectionType::InRegion, "region-a.cluster-b.node1"); + + // export routes to one of the connection. + // Expect: + // 1. routes via the target peer are suppressed. + // 2. 
alternative routes with the same hop count or hop-count+1 are included. + // 3. alternative routes with hop-count+2 are suppressed. + let routes = mux.export_routes(conn1.info()); + let mut actual: Vec = routes.unwrap().entries; + actual.sort(); + let mut expected = vec![ + clone_route_entry_without_nh(&node2_re1), + clone_route_entry_without_nh(&node2_re2), + clone_route_entry_without_nh(&node2_re3), + my_re1.clone(), + my_re2.clone(), + ]; + expected.sort(); + assert_eq!(actual, expected); + // export routes to a InRegion connection. Expect: routes with InCluster scope are suppressed. + let routes = mux.export_routes(conn3.info()); + let mut actual: Vec = routes.unwrap().entries; + actual.sort(); + let mut expected = vec![ + clone_route_entry_without_nh(&node1_re2), + clone_route_entry_without_nh(&node2_re2), + my_re2.clone(), + ]; + expected.sort(); + assert_eq!(actual, expected); + } + + #[tokio::test] + async fn test_process_route_announcement() { + // create a mux with my routes and add a dummy route announcer + // create a mux with 2 my-routes + let my_route1 = RouteConfig { + key: ServicePath::from_string("region-a.cluster-a.node0").unwrap(), + scope: RouteScope::InCluster, + }; + let my_route2 = RouteConfig { + key: ServicePath::from_string("region-a.cluster-a").unwrap(), + scope: RouteScope::InRegion, + }; + let mut mux = SwbusMultiplexer::new(vec![my_route1.clone(), my_route2.clone()]); + let (route_announce_task_tx, mut route_announce_task_rx) = mpsc::channel(16); + mux.set_route_announcer(route_announce_task_tx, CancellationToken::new()); + let mux = Arc::new(mux); + // add a direct route to the node1 + let (conn1, _) = new_conn_for_test(ConnectionType::InCluster, "region-a.cluster-a.node1"); + add_route(&mux, "region-a.cluster-a.node1", 1, &conn1).unwrap(); + let conn1_nh = SwbusNextHop::new_remote(conn1.info().clone(), conn1.new_proxy(), 1); + + // create some route entries + // route entry matches my routes. 
Should be skipped + let node1_re1 = new_route_entry_for_ra("region-a.cluster-a.node1", 0); + // negative test: route announcement with scope lower than InCluster + let node1_re2 = new_route_entry_for_ra("region-a.cluster-a.node1/ha/0", 1); + let node2_re1 = new_route_entry_for_ra("region-a.cluster-a.node2", 1); + let clusterb_re1 = new_route_entry_for_ra("region-a.cluster-b", 2); + + // Step1: process a route announcement + let routes = RouteEntries { + entries: vec![ + node1_re1.clone(), + node1_re2.clone(), + node2_re1.clone(), + clusterb_re1.clone(), + ], + }; + let expected_route_entries = BTreeSet::from_iter(vec![ + node1_re1.clone(), + node1_re2.clone(), + node2_re1.clone(), + clusterb_re1.clone(), + ]); + mux.process_route_announcement(routes.clone(), conn1.info()).unwrap(); + // Expect: + // 1. new routes are updated + // 2. my routes are skipped + // 3. route anouncement task is created + // 4. route entries are updated in routes_by_conn + let mut expected = HashMap::new(); + // add my incluster route + expected.insert( + "region-a.cluster-a.node0".to_string(), + BTreeSet::from([SwbusNextHop::new_local()]), + ); + expected.insert( + SwbusMultiplexer::key_from_route_entry(&node1_re1).unwrap(), + BTreeSet::from([conn1_nh.clone_with_hop_count(1)]), + ); + expected.insert( + SwbusMultiplexer::key_from_route_entry(&node2_re1).unwrap(), + BTreeSet::from([conn1_nh.clone_with_hop_count(2)]), + ); + expected.insert( + SwbusMultiplexer::key_from_route_entry(&clusterb_re1).unwrap(), + BTreeSet::from([conn1_nh.clone_with_hop_count(3)]), + ); + + let actual: HashMap> = mux + .routes + .iter() + .map(|entry| (entry.key().clone(), entry.value().clone())) + .collect(); + assert_eq!(actual, expected); + assert_eq!( + route_announce_task_rx.try_recv(), + Ok(RouteAnnounceTask::new(TriggerType::RouteUpdated, conn1.info().clone())) + ); + { + let routes_by_conn = mux.routes_by_conn.get(conn1.info()).unwrap(); + assert_eq!(*routes_by_conn, expected_route_entries); + } + + // 
Step 2: process a route announcement with the same routes + mux.process_route_announcement(routes.clone(), conn1.info()).unwrap(); + // Expect: + // 1. no route is updated + // 2. no route anouncement task is created + assert_eq!( + route_announce_task_rx.try_recv(), + Err(tokio::sync::mpsc::error::TryRecvError::Empty) + ); + + // Step 3: process a route announcement with changed and removed route + let node2_re1_changed = new_route_entry_for_ra("region-a.cluster-a.node2", 2); + let clusterb_re1 = new_route_entry_for_ra("region-a.cluster-b", 2); + let node3_re1 = new_route_entry_for_ra("region-a.cluster-a.node3", 2); + // process a route announcement + let routes = RouteEntries { + entries: vec![ + node1_re1.clone(), + node1_re2.clone(), + node2_re1_changed.clone(), + clusterb_re1.clone(), + node3_re1.clone(), + ], + }; + let expected_route_entries = BTreeSet::from_iter(vec![ + node1_re1.clone(), + node1_re2.clone(), + node2_re1_changed.clone(), + clusterb_re1.clone(), + node3_re1.clone(), + ]); + let mut expected = HashMap::new(); + // add my incluster route + expected.insert( + "region-a.cluster-a.node0".to_string(), + BTreeSet::from([SwbusNextHop::new_local()]), + ); + expected.insert( + SwbusMultiplexer::key_from_route_entry(&node1_re1).unwrap(), + BTreeSet::from([conn1_nh.clone_with_hop_count(1)]), + ); + expected.insert( + SwbusMultiplexer::key_from_route_entry(&node2_re1_changed).unwrap(), + BTreeSet::from([conn1_nh.clone_with_hop_count(3)]), + ); + expected.insert( + SwbusMultiplexer::key_from_route_entry(&clusterb_re1).unwrap(), + BTreeSet::from([conn1_nh.clone_with_hop_count(3)]), + ); + expected.insert( + SwbusMultiplexer::key_from_route_entry(&node3_re1).unwrap(), + BTreeSet::from([conn1_nh.clone_with_hop_count(3)]), + ); + // Expect: + // 1. updated routes are updated + // 2. removed routes are removed + // 3. route anouncement task is created + // 4. 
route entries are updated in routes_by_conn + mux.process_route_announcement(routes.clone(), conn1.info()).unwrap(); + + let actual: HashMap> = mux + .routes + .iter() + .map(|entry| (entry.key().clone(), entry.value().clone())) + .collect(); + assert_eq!(actual, expected); + assert_eq!( + route_announce_task_rx.try_recv(), + Ok(RouteAnnounceTask::new(TriggerType::RouteUpdated, conn1.info().clone())) + ); + { + let routes_by_conn = mux.routes_by_conn.get(conn1.info()).unwrap(); + assert_eq!(*routes_by_conn, expected_route_entries); + } + } + + #[tokio::test] + async fn test_get_nexthop_to_conn() { let route_config = RouteConfig { key: ServicePath::from_string("region-a.cluster-a.10.0.0.2-dpu0").unwrap(), - scope: RouteScope::Cluster, + scope: RouteScope::InCluster, }; - mux.set_my_routes(vec![route_config.clone()]); + let mux = Arc::new(SwbusMultiplexer::new(vec![route_config.clone()])); - let mut _send_queue_rx1 = add_route( - &mux, - "region-a.cluster-a.10.0.0.1-dpu0", - 1, + let (conn1, _) = new_conn_for_test(ConnectionType::InCluster, "region-a.cluster-a.10.0.0.1-dpu0"); + add_route(&mux, "region-a.cluster-a.10.0.0.1-dpu0", 1, &conn1).unwrap(); + + // add the same route with a different connection + let conn_info = Arc::new(SwbusConnInfo::new_client( + ConnectionType::InCluster, + "127.0.0.1:8081".parse().unwrap(), + ServicePath::from_string("region-a.cluster-a.10.0.0.1-dpu0").unwrap(), + )); + let (send_queue_tx, _) = mpsc::channel(16); + let conn2 = SwbusConn::new(&conn_info, send_queue_tx); + add_route(&mux, "region-a.cluster-a.10.0.0.1-dpu0", 1, &conn2).unwrap(); + + // Test case: verify we get nexthop for both connections + let result = mux.get_nexthop_to_conn(conn1.info()); + assert!(result.is_ok()); + let nexthop = result.unwrap(); + assert_eq!(nexthop.conn_info().as_ref().unwrap(), conn1.info()); + + let result = mux.get_nexthop_to_conn(conn2.info()); + assert!(result.is_ok()); + let nexthop = result.unwrap(); + 
assert_eq!(nexthop.conn_info().as_ref().unwrap(), conn2.info()); + + // Test case: connection not found + let (conn2, _) = new_conn_for_test(ConnectionType::InCluster, "region-a.cluster-a.10.0.0.3-dpu0"); + let result = mux.get_nexthop_to_conn(conn2.info()); + assert!(result.is_err()); + let error = result.unwrap_err(); + assert!(matches!(error, SwbusError::InternalError { code: _, detail: _ })); + } + + #[test] + fn test_remove_existing_nexthop() { + let route_config = RouteConfig { + key: ServicePath::from_string("region-a.cluster-a.10.0.0.2-dpu0").unwrap(), + scope: RouteScope::InCluster, + }; + let mux = Arc::new(SwbusMultiplexer::new(vec![route_config.clone()])); + + let route_key = "region-a.cluster-a.node1"; + let (conn1, _) = new_conn_for_test_with_endpoint( + ConnectionType::InCluster, "region-a.cluster-a.10.0.0.1-dpu0", - ConnectionType::Cluster, + "127.0.0.1:8080", ); - let mut _send_queue_rx3 = add_route( - &mux, - "region-a.cluster-b", - 1, - "region-a.cluster-b.10.0.0.1-dpu0", - ConnectionType::Region, + add_route(&mux, route_key, 1, &conn1).unwrap(); + let nexthop1 = SwbusNextHop::new_remote(conn1.info().clone(), conn1.new_proxy(), 1); + + let (conn2, _) = new_conn_for_test_with_endpoint( + ConnectionType::InCluster, + "region-a.cluster-a.10.0.0.1-dpu1", + "127.0.0.1:8081", ); + add_route(&mux, route_key, 1, &conn2).unwrap(); + let nexthop2 = SwbusNextHop::new_remote(conn2.info().clone(), conn2.new_proxy(), 1); - let routes = mux.export_routes(Some(RouteScope::Cluster)); - let json_string = serde_json::to_string(&routes).unwrap(); - let normalized_routes: RouteQueryResult = serde_json::from_str(&json_string).unwrap(); + assert!(mux.remove_route_to_nh(route_key, &nexthop1)); + assert!(mux.routes.contains_key(route_key)); + assert!(mux.remove_route_to_nh(route_key, &nexthop2)); - let entry1 = RouteQueryResultEntry { - service_path: Some(ServicePath::from_string("region-a.cluster-a.10.0.0.1-dpu0").unwrap()), - hop_count: 1, - nh_id: "".to_string(), - 
nh_service_path: Some(ServicePath::from_string("region-a.cluster-a.10.0.0.1-dpu0").unwrap()), - nh_scope: RouteScope::Cluster as i32, + assert!(!mux.routes.contains_key(route_key)); // Ensure route is removed when empty + } + + #[test] + fn test_remove_non_existing_nexthop() { + let route_config = RouteConfig { + key: ServicePath::from_string("region-a.cluster-a.10.0.0.2-dpu0").unwrap(), + scope: RouteScope::InCluster, }; - let entry2 = RouteQueryResultEntry { - service_path: Some(ServicePath::from_string("region-a.cluster-b").unwrap()), - hop_count: 1, - nh_id: "".to_string(), - nh_service_path: Some(ServicePath::from_string("region-a.cluster-b.10.0.0.1-dpu0").unwrap()), - nh_scope: RouteScope::Region as i32, + let mux = Arc::new(SwbusMultiplexer::new(vec![route_config.clone()])); + + let route_key = "region-a.cluster-a.node1"; + let (conn1, _) = new_conn_for_test_with_endpoint( + ConnectionType::InCluster, + "region-a.cluster-a.10.0.0.1-dpu0", + "127.0.0.1:8080", + ); + add_route(&mux, route_key, 1, &conn1).unwrap(); + + let (conn2, _) = new_conn_for_test_with_endpoint( + ConnectionType::InCluster, + "region-a.cluster-a.10.0.0.1-dpu1", + "127.0.0.1:8081", + ); + let nexthop2 = SwbusNextHop::new_remote(conn2.info().clone(), conn2.new_proxy(), 1); + + assert!(!mux.remove_route_to_nh(route_key, &nexthop2)); + + assert!(mux.routes.contains_key(route_key)); // Ensure route is not removed + } + + #[test] + fn test_register_unregister() { + // create mux + // create a mux with my routes and add a dummy route announcer + // create a mux with 2 my-routes + let my_route1 = RouteConfig { + key: ServicePath::from_string("region-a.cluster-a.node0").unwrap(), + scope: RouteScope::InCluster, }; + let mut mux = SwbusMultiplexer::new(vec![my_route1.clone()]); + let (route_announce_task_tx, mut route_announce_task_rx) = mpsc::channel(16); + mux.set_route_announcer(route_announce_task_tx, CancellationToken::new()); + let mux = Arc::new(mux); + + // register a connection + let 
(conn1, _) = new_conn_for_test(ConnectionType::InCluster, "region-a.cluster-a.node1"); + mux.register(conn1.info(), conn1.new_proxy()); + let conn1_nh = SwbusNextHop::new_remote(conn1.info().clone(), conn1.new_proxy(), 1); + assert_eq!( + route_announce_task_rx.try_recv(), + Ok(RouteAnnounceTask::new(TriggerType::ConnectionUp, conn1.info().clone())) + ); + assert!(mux.connections.contains(conn1.info())); + + // create some route entries + let node1_re1 = new_route_entry_for_ra("region-a.cluster-a.node1", 0); + let node2_re1 = new_route_entry_for_ra("region-a.cluster-a.node2", 1); + let clusterb_re1 = new_route_entry_for_ra("region-a.cluster-b", 2); - let expected = RouteQueryResult { - entries: vec![entry1.clone(), entry2.clone()], + // Step1: process a route announcement + let routes = RouteEntries { + entries: vec![node1_re1.clone(), node2_re1.clone(), clusterb_re1.clone()], }; - assert_eq!(normalized_routes, expected); + let expected_route_entries = + BTreeSet::from_iter(vec![node1_re1.clone(), node2_re1.clone(), clusterb_re1.clone()]); + mux.process_route_announcement(routes.clone(), conn1.info()).unwrap(); + // Expect: + // 1. new routes are updated + // 2. my routes are skipped + // 3. route anouncement task is created + // 4. 
route entries are updated in routes_by_conn + let mut expected = HashMap::new(); + // add my incluster route + expected.insert( + "region-a.cluster-a.node0".to_string(), + BTreeSet::from([SwbusNextHop::new_local()]), + ); + expected.insert( + SwbusMultiplexer::key_from_route_entry(&node1_re1).unwrap(), + BTreeSet::from([conn1_nh.clone_with_hop_count(1)]), + ); + expected.insert( + SwbusMultiplexer::key_from_route_entry(&node2_re1).unwrap(), + BTreeSet::from([conn1_nh.clone_with_hop_count(2)]), + ); + expected.insert( + SwbusMultiplexer::key_from_route_entry(&clusterb_re1).unwrap(), + BTreeSet::from([conn1_nh.clone_with_hop_count(3)]), + ); + + let actual: HashMap> = mux + .routes + .iter() + .map(|entry| (entry.key().clone(), entry.value().clone())) + .collect(); + assert_eq!(actual, expected); + + assert_eq!( + route_announce_task_rx.try_recv(), + Ok(RouteAnnounceTask::new(TriggerType::RouteUpdated, conn1.info().clone())) + ); + { + let routes_by_conn = mux.routes_by_conn.get(conn1.info()).unwrap(); + assert_eq!(*routes_by_conn, expected_route_entries); + } - let routes = mux.export_routes(Some(RouteScope::Region)); - let json_string = serde_json::to_string(&routes).unwrap(); - let normalized_routes: RouteQueryResult = serde_json::from_str(&json_string).unwrap(); - let expected = RouteQueryResult { entries: vec![entry2] }; - assert_eq!(normalized_routes, expected); + // unregister the connection + mux.unregister(conn1.info()); + // Expect: + // the connection is removed from routes_by_conn + assert!(!mux.routes_by_conn.contains_key(conn1.info())); + // the routes in route announcement task are removed + assert!(!mux + .routes + .contains_key(&SwbusMultiplexer::key_from_route_entry(&node1_re1).unwrap())); + assert!(!mux + .routes + .contains_key(&SwbusMultiplexer::key_from_route_entry(&node2_re1).unwrap())); + assert!(!mux + .routes + .contains_key(&SwbusMultiplexer::key_from_route_entry(&clusterb_re1).unwrap())); + + assert_eq!( + 
route_announce_task_rx.try_recv(), + Ok(RouteAnnounceTask::new( + TriggerType::ConnectionDown, + conn1.info().clone() + )) + ); + assert!(!mux.connections.contains(conn1.info())); } } diff --git a/crates/swbus-core/src/mux/nexthop.rs b/crates/swbus-core/src/mux/nexthop.rs index d8e451d..3efe52c 100644 --- a/crates/swbus-core/src/mux/nexthop.rs +++ b/crates/swbus-core/src/mux/nexthop.rs @@ -3,19 +3,20 @@ use super::SwbusConnProxy; use super::SwbusMultiplexer; use getset::CopyGetters; use getset::Getters; +use std::cmp::Ordering; use std::sync::Arc; use swbus_proto::result::*; use swbus_proto::swbus::*; -use swbus_proto::swbus::{swbus_message, ManagementRequestType, SwbusMessage}; +use swbus_proto::swbus::{swbus_message, SwbusMessage}; use tracing::*; -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Debug, Copy, Clone, PartialEq, Eq)] pub(crate) enum NextHopType { Local, Remote, } -#[derive(Clone, Getters, CopyGetters)] +#[derive(Clone, Getters, CopyGetters, Debug)] pub(crate) struct SwbusNextHop { #[getset(get_copy = "pub")] nh_type: NextHopType, @@ -30,6 +31,34 @@ pub(crate) struct SwbusNextHop { hop_count: u32, } +impl PartialEq for SwbusNextHop { + fn eq(&self, other: &Self) -> bool { + self.hop_count == other.hop_count && self.nh_type == other.nh_type && self.conn_info == other.conn_info + } +} + +impl Eq for SwbusNextHop {} + +impl PartialOrd for SwbusNextHop { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for SwbusNextHop { + fn cmp(&self, other: &Self) -> Ordering { + match self.hop_count.cmp(&other.hop_count) { + Ordering::Equal => match (&self.conn_info, &other.conn_info) { + (Some(ref a), Some(ref b)) => a.id().cmp(b.id()), + (Some(_), None) => Ordering::Greater, + (None, Some(_)) => Ordering::Less, + (None, None) => Ordering::Equal, + }, + other => other, + } + } +} + impl SwbusNextHop { pub fn new_remote(conn_info: Arc, conn_proxy: SwbusConnProxy, hop_count: u32) -> Self { SwbusNextHop { @@ -49,6 +78,25 
@@ impl SwbusNextHop { } } + pub fn clone_with_hop_count(&self, hop_count: u32) -> Self { + SwbusNextHop { + nh_type: self.nh_type, + conn_info: self.conn_info.clone(), + conn_proxy: self.conn_proxy.clone(), + hop_count, + } + } + + /// used in finding existing routes to remove + pub fn new_dummy_remote(conn_info: Arc, hop_count: u32) -> Self { + SwbusNextHop { + nh_type: NextHopType::Remote, + conn_info: Some(conn_info), + conn_proxy: None, + hop_count, + } + } + #[instrument(name="queue_message", parent=None, level="debug", skip_all, fields(nh_type=?self.nh_type, conn_info=self.conn_info.as_ref().map(|x| x.id()).unwrap_or(&"None".to_string()), message.id=?message.header.as_ref().unwrap().id))] pub async fn queue_message( &self, @@ -69,8 +117,8 @@ impl SwbusNextHop { if header.ttl == 0 { debug!("TTL expired"); let response = SwbusMessage::new_response( - &message, - Some(&mux.get_my_service_path()), + message.header.as_ref().unwrap(), + Some(mux.get_my_service_path()), SwbusErrorCode::Unreachable, "TTL expired", mux.generate_message_id(), @@ -101,7 +149,7 @@ impl SwbusNextHop { // there is no route to the service, the packet will be routed to here. We need to // return no route error in this case. let response = SwbusMessage::new_response( - &message, + message.header.as_ref().unwrap(), None, SwbusErrorCode::NoRoute, "Route not found", @@ -111,10 +159,7 @@ impl SwbusNextHop { return Ok(Some(response)); } let response = match message.body.as_ref() { - Some(swbus_message::Body::PingRequest(_)) => self.process_ping_request(mux, message).unwrap(), - Some(swbus_message::Body::ManagementRequest(mgmt_request)) => { - self.process_mgmt_request(mux, &message, mgmt_request).unwrap() - } + Some(swbus_message::Body::PingRequest(_)) => self.process_ping_request(mux, message)?, _ => { // drop all other messages. This could happen due to message loop or other invaid messages to swbusd. 
debug!("Drop unknown message to a local endpoint"); @@ -128,7 +173,7 @@ impl SwbusNextHop { debug!("Received ping request"); let id = mux.generate_message_id(); Ok(SwbusMessage::new_response( - &message, + message.header.as_ref().unwrap(), None, SwbusErrorCode::Ok, "", @@ -136,45 +181,12 @@ impl SwbusNextHop { None, )) } - - fn process_mgmt_request( - &self, - mux: &SwbusMultiplexer, - message: &SwbusMessage, - mgmt_request: &ManagementRequest, - ) -> Result { - let request_type = ManagementRequestType::try_from(mgmt_request.request).map_err(|_| { - SwbusError::input( - SwbusErrorCode::InvalidArgs, - format!("Invalid management request: {:?}", mgmt_request.request), - ) - })?; - - match request_type { - ManagementRequestType::SwbusdGetRoutes => { - debug!("Received show_route request"); - let routes = mux.export_routes(None); - let response_msg = SwbusMessage::new_response( - message, - None, - SwbusErrorCode::Ok, - "", - mux.generate_message_id(), - Some(request_response::ResponseBody::RouteQueryResult(routes)), - ); - Ok(response_msg) - } - _ => Err(SwbusError::input( - SwbusErrorCode::InvalidArgs, - format!("Invalid management request: {mgmt_request:?}"), - )), - } - } } #[cfg(test)] mod tests { use super::*; + use crate::mux::test_utils::*; use crate::mux::SwbusConn; use std::sync::Arc; use swbus_config::RouteConfig; @@ -184,10 +196,9 @@ mod tests { #[tokio::test] async fn test_new_remote() { let conn_info = Arc::new(SwbusConnInfo::new_client( - ConnectionType::Cluster, + ConnectionType::InCluster, "127.0.0.1:8080".parse().unwrap(), ServicePath::from_string("regiona.clustera.10.0.0.2-dpu0").unwrap(), - ServicePath::from_string("regiona.clustera.10.0.0.1-dpu0").unwrap(), )); let (send_queue_tx, _) = mpsc::channel(16); let conn = SwbusConn::new(&conn_info, send_queue_tx); @@ -211,8 +222,14 @@ mod tests { #[tokio::test] async fn test_queue_message_drop() { + let route_config = RouteConfig { + key: 
ServicePath::from_string("region-a.cluster-a.10.0.0.1-dpu0").unwrap(), + scope: RouteScope::InCluster, + }; + + let mux = Arc::new(SwbusMultiplexer::new(vec![route_config.clone()])); + let nexthop = SwbusNextHop::new_local(); - let mux = SwbusMultiplexer::default(); let message = SwbusMessage { header: Some(SwbusMessageHeader::new( ServicePath::from_string("region-a.cluster-a.10.0.0.1-dpu0").unwrap(), @@ -227,14 +244,13 @@ mod tests { #[tokio::test] async fn test_queue_message_local_ping() { - let nexthop = SwbusNextHop::new_local(); - let mux = Arc::new(SwbusMultiplexer::default()); let route_config = RouteConfig { key: ServicePath::from_string("region-a.cluster-a.10.0.0.2-dpu0").unwrap(), - scope: RouteScope::Cluster, + scope: RouteScope::InCluster, }; + let mux = Arc::new(SwbusMultiplexer::new(vec![route_config.clone()])); - mux.set_my_routes(vec![route_config.clone()]); + let nexthop = SwbusNextHop::new_local(); let request = r#" { @@ -265,22 +281,21 @@ mod tests { #[tokio::test] async fn test_queue_message_remote_ttl_expired() { let conn_info = Arc::new(SwbusConnInfo::new_client( - ConnectionType::Cluster, + ConnectionType::InCluster, "127.0.0.1:8080".parse().unwrap(), ServicePath::from_string("regiona.clustera.10.0.0.2-dpu0").unwrap(), - ServicePath::from_string("regiona.clustera.10.0.0.1-dpu0").unwrap(), )); let (send_queue_tx, _) = mpsc::channel(16); let conn = SwbusConn::new(&conn_info, send_queue_tx); let hop_count = 5; let nexthop = SwbusNextHop::new_remote(conn_info.clone(), conn.new_proxy(), hop_count); - let mux = Arc::new(SwbusMultiplexer::default()); + let route_config = RouteConfig { key: ServicePath::from_string("region-a.cluster-a.10.0.0.2-dpu0").unwrap(), - scope: RouteScope::Cluster, + scope: RouteScope::InCluster, }; - mux.set_my_routes(vec![route_config.clone()]); + let mux = Arc::new(SwbusMultiplexer::new(vec![route_config.clone()])); let request = r#" { @@ -310,4 +325,46 @@ mod tests { _ => panic!("Expected response message"), } } + + 
#[test] + fn test_clone_with_hop_count() { + let conn_info = Arc::new(SwbusConnInfo::new_client( + ConnectionType::InCluster, + "127.0.0.1:8080".parse().unwrap(), + ServicePath::from_string("regiona.clustera.10.0.0.2-dpu0").unwrap(), + )); + let (send_queue_tx, _) = mpsc::channel(16); + let conn = SwbusConn::new(&conn_info, send_queue_tx); + let hop_count = 5; + let nexthop = SwbusNextHop::new_remote(conn_info.clone(), conn.new_proxy(), hop_count); + + let new_hop_count = 10; + let cloned_nexthop = nexthop.clone_with_hop_count(new_hop_count); + + assert_eq!(cloned_nexthop.nh_type, NextHopType::Remote); + assert_eq!(cloned_nexthop.conn_info, Some(conn_info)); + assert_eq!(cloned_nexthop.hop_count, new_hop_count); + } + + #[test] + fn test_nexthop_ord_and_eq() { + let (conn1, _) = + new_conn_for_test_with_endpoint(ConnectionType::InCluster, "region-a.cluster-a.node0", "127.0.0.1:61000"); + let (conn2, _) = + new_conn_for_test_with_endpoint(ConnectionType::InCluster, "region-a.cluster-a.node0", "127.0.0.1:61001"); + + let nh1 = SwbusNextHop::new_remote(conn1.info().clone(), conn1.new_proxy(), 1); + let nh2 = SwbusNextHop::new_remote(conn2.info().clone(), conn2.new_proxy(), 1); + let nh3 = SwbusNextHop::new_remote(conn2.info().clone(), conn2.new_proxy(), 2); + + assert!(nh1 != nh2); + assert!(nh1 < nh2 && nh2 < nh3); + + #[allow(clippy::mutable_key_type)] + let mut nhs = std::collections::BTreeSet::new(); + assert!(nhs.insert(nh1.clone())); + assert!(nhs.contains(&nh1)); + assert!(nhs.insert(nh2.clone())); + assert!(nhs.contains(&nh2)); + } } diff --git a/crates/swbus-core/src/mux/route_annoucer.rs b/crates/swbus-core/src/mux/route_annoucer.rs new file mode 100644 index 0000000..7ad9097 --- /dev/null +++ b/crates/swbus-core/src/mux/route_annoucer.rs @@ -0,0 +1,252 @@ +use super::SwbusConnInfo; +use super::SwbusMultiplexer; +use std::sync::Arc; +use swbus_proto::result::*; +use swbus_proto::swbus::*; +use tokio::sync::mpsc; +use tokio_util::sync::CancellationToken; +use 
tracing::*; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) enum TriggerType { + ConnectionUp, + ConnectionDown, + RouteUpdated, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct RouteAnnounceTask { + trigger: TriggerType, + conn_info: Arc, +} + +impl RouteAnnounceTask { + pub fn new(trigger: TriggerType, conn_info: Arc) -> Self { + Self { trigger, conn_info } + } +} +impl std::fmt::Display for RouteAnnounceTask { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!( + f, + "trigger: {:?}, conn_info: {}", + self.trigger, + self.conn_info.remote_service_path() + ) + } +} +pub struct RouteAnnouncer { + task_rx: mpsc::Receiver, + ct: CancellationToken, + mux: Arc, +} +impl RouteAnnouncer { + pub fn new(task_rx: mpsc::Receiver, ct: CancellationToken, mux: Arc) -> Self { + Self { task_rx, ct, mux } + } + pub async fn run(&mut self) { + self.run_loop().await; + } + + async fn run_loop(&mut self) { + let mut tasks: Vec = Vec::new(); + loop { + tokio::select! { + _ = self.ct.cancelled() => { + info!("Route announcer stopped. Shutting down immediately."); + break; + } + // currently, we send routes to all connections on any trigger so we don't need to differentiate + // between the tasks. If there are multiple tasks in the queue, we only process the last one. + // This is to avoid sending multiple route updates because only the last one is the latest. + num_received = self.task_rx.recv_many(&mut tasks, 100) => { + if num_received == 0 { + info!("Route announcer task channel closed. Shutting down."); + break; + } + let last_task = tasks.pop().expect("Last task is always present since num_received > 0"); + for task in tasks.drain(..) 
{ + info!("skipping route announce task: {}", task); + } + self.process_task(last_task).await; + } + } + } + } + + async fn process_task(&self, task: RouteAnnounceTask) { + info!("Processing route announce task: {}", task); + let connections = self.mux.export_connections(); + + for conn_info in connections { + let routes = self.mux.export_routes(&conn_info); + if routes.is_none() { + continue; + } + let _ = self + .send_route_announcement(&conn_info, &routes.unwrap()) + .await + .map_err(|e| { + error!( + "Failed to send route announcement to {}: {}", + conn_info.remote_service_path(), + e + ); + }); + } + } + + #[instrument(name = "send_route_announcement", level = "debug", skip_all)] + async fn send_route_announcement(&self, conn_info: &SwbusConnInfo, routes: &RouteEntries) -> Result<()> { + let dest_sp = conn_info.remote_service_path().clone(); + let msg = SwbusMessage { + header: Some(SwbusMessageHeader::new( + self.mux.get_my_service_path().clone(), + dest_sp, + self.mux.generate_message_id(), + )), + body: Some(swbus_message::Body::RouteAnnouncement(routes.clone())), + }; + debug!( + "Sending route announcement to {}, conn_info {:?}, message {:?}", + conn_info.remote_service_path(), + conn_info, + &msg + ); + self.mux.route_message(msg).await?; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::mux::test_utils::*; + use std::sync::Arc; + use swbus_config::RouteConfig; + use tokio::sync::mpsc; + + #[tokio::test] + async fn test_process_task() { + // create a mux with 2 my-routes + let my_route1 = RouteConfig { + key: ServicePath::from_string("region-a.cluster-a.node0").unwrap(), + scope: RouteScope::InCluster, + }; + let my_route2 = RouteConfig { + key: ServicePath::from_string("region-a.cluster-a").unwrap(), + scope: RouteScope::InRegion, + }; + let mux = SwbusMultiplexer::new(vec![my_route1.clone(), my_route2.clone()]); + // Register 2 connections + let (conn1, mut conn1_sendq_rx) = new_conn_for_test(ConnectionType::InCluster, 
"region-a.cluster-a.node1"); + mux.register(conn1.info(), conn1.new_proxy()); + let (conn2, mut conn2_sendq_rx) = new_conn_for_test(ConnectionType::InCluster, "region-a.cluster-a.node2"); + mux.register(conn2.info(), conn2.new_proxy()); + + // Add some routes + // add a regional route + add_route(&mux, "region-a.cluster-b", 2, &conn1).unwrap(); + + // add a global route + add_route(&mux, "region-a", 2, &conn2).unwrap(); + + // Create a route announce task + let (tx, rx) = mpsc::channel(1); + + let mut announcer = RouteAnnouncer::new(rx, CancellationToken::new(), Arc::new(mux)); + tokio::spawn(async move { + announcer.run().await; + }); + + let task = RouteAnnounceTask::new(TriggerType::RouteUpdated, conn1.info().clone()); + tx.send(task).await.unwrap(); + // verify that the routes are sent to both connections + let conn1_ra = r#" + { + "header": { + "version": 1, + "flag": 0, + "ttl": 63, + "source": "region-a.cluster-a.node0", + "destination": "region-a.cluster-a.node1" + }, + "body": { + "RouteAnnouncement": { + "entries": [ + { + "service_path": "region-a", + "nh_service_path": null, + "route_scope": 4, + "hop_count": 2 + }, + { + "service_path": "region-a.cluster-a", + "nh_service_path": null, + "route_scope": 3, + "hop_count": 0 + }, + { + "service_path": "region-a.cluster-a.node0", + "nh_service_path": null, + "route_scope": 2, + "hop_count": 0 + }, + { + "service_path": "region-a.cluster-a.node2", + "nh_service_path": null, + "route_scope": 2, + "hop_count": 1 + } + ] + } + } + } + "#; + let _: SwbusMessage = serde_json::from_str(conn1_ra).unwrap(); + receive_and_compare(&mut conn1_sendq_rx, conn1_ra).await; + + let conn2_ra = r#" + { + "header": { + "version": 1, + "flag": 0, + "ttl": 63, + "source": "region-a.cluster-a.node0", + "destination": "region-a.cluster-a.node2" + }, + "body": { + "RouteAnnouncement": { + "entries": [ + { + "service_path": "region-a.cluster-a", + "nh_service_path": null, + "route_scope": 3, + "hop_count": 0 + }, + { + 
"service_path": "region-a.cluster-a.node0", + "nh_service_path": null, + "route_scope": 2, + "hop_count": 0 + }, + { + "service_path": "region-a.cluster-a.node1", + "nh_service_path": null, + "route_scope": 2, + "hop_count": 1 + }, + { + "service_path": "region-a.cluster-b", + "nh_service_path": null, + "route_scope": 3, + "hop_count": 2 + } + ] + } + } + } + "#; + receive_and_compare(&mut conn2_sendq_rx, conn2_ra).await; + } +} diff --git a/crates/swbus-core/src/mux/service.rs b/crates/swbus-core/src/mux/service.rs index e1f1cbf..22715f9 100644 --- a/crates/swbus-core/src/mux/service.rs +++ b/crates/swbus-core/src/mux/service.rs @@ -1,3 +1,4 @@ +use super::route_annoucer::{RouteAnnounceTask, RouteAnnouncer}; use super::SwbusConn; use super::SwbusMultiplexer; use crate::mux::conn_store::SwbusConnStore; @@ -16,13 +17,14 @@ use tokio::sync::{ }; use tokio_stream::wrappers::ReceiverStream; use tokio_stream::Stream; +use tokio_util::sync::CancellationToken; use tonic::{transport::Server, Request, Response, Status, Streaming}; use tracing::*; pub struct SwbusServiceHost { swbus_server_addr: SocketAddr, - mux: Arc, - conn_store: Arc, + mux: Option>, + conn_store: Option>, shutdown_tx: Option>, shutdown_rx: Option>, } @@ -32,13 +34,11 @@ type SwbusMessageStream = Pin impl SwbusServiceHost { pub fn new(swbus_server_addr: &SocketAddr) -> Self { - let mux = Arc::new(SwbusMultiplexer::new()); - let conn_store = Arc::new(SwbusConnStore::new(mux.clone())); let (shutdown_tx, shutdown_rx) = oneshot::channel::<()>(); Self { swbus_server_addr: *swbus_server_addr, - mux, - conn_store, + mux: None, + conn_store: None, shutdown_tx: Some(shutdown_tx), shutdown_rx: Some(shutdown_rx), } @@ -65,25 +65,39 @@ impl SwbusServiceHost { )); } - // register local nexthops for local services - self.mux.set_my_routes(config.routes.clone()); - for route in config.routes { - self.conn_store.add_my_route(route); - } + // create mux and set route announce task queue + let mut mux = 
SwbusMultiplexer::new(config.routes); + let (route_annouce_task_tx, route_annouce_task_rx) = mpsc::channel::(100); + let route_announcer_ct = CancellationToken::new(); + mux.set_route_announcer(route_annouce_task_tx, route_announcer_ct.clone()); + + let mux = Arc::new(mux); + let mux_clone = mux.clone(); + + // start the route announcer + tokio::spawn(async move { + let mut route_announcer = RouteAnnouncer::new(route_annouce_task_rx, route_announcer_ct, mux_clone); + route_announcer.run().await; + }); + + let conn_store = Arc::new(SwbusConnStore::new(mux.clone())); // add peers to the connection store for peer in config.peers { - self.conn_store.add_peer(peer); + conn_store.add_peer(peer); } - let conn_store = self.conn_store.clone(); + self.mux = Some(mux); + let conn_store_clone = conn_store.clone(); + self.conn_store = Some(conn_store); let shutdown_rx = self.shutdown_rx.take().unwrap(); + Server::builder() .add_service(SwbusServiceServer::new(self)) .serve_with_shutdown(addr, async { shutdown_rx.await.ok(); info!("SwbusServiceServer received shutdown signal"); - conn_store.shutdown().await; + conn_store_clone.shutdown().await; }) .await .map_err(|e| { @@ -149,10 +163,15 @@ impl SwbusService for SwbusServiceHost { let (out_tx, out_rx) = mpsc::channel(16); let conn_info = Arc::new(SwbusConnInfo::new_server(conn_type, client_addr, service_path)); - let conn = - SwbusConn::from_incoming_stream(conn_info, in_stream, out_tx, self.mux.clone(), self.conn_store.clone()) - .await; - self.conn_store.conn_established(conn); + let conn = SwbusConn::from_incoming_stream( + conn_info, + in_stream, + out_tx, + self.mux.as_ref().unwrap().clone(), + self.conn_store.as_ref().unwrap().clone(), + ) + .await; + self.conn_store.as_ref().unwrap().conn_established(conn); let out_stream = ReceiverStream::new(out_rx); Ok(Response::new(Box::pin(out_stream) as Self::StreamMessagesStream)) } diff --git a/crates/swbus-core/tests/README.md b/crates/swbus-core/tests/README.md new file mode 
100644 index 0000000..8d006c5 --- /dev/null +++ b/crates/swbus-core/tests/README.md @@ -0,0 +1,95 @@ +# Introduction +A test infrastructure is implemented by common/test_executor.rs. It automates test topology bringup, sending swbus message and comparing received swbus messages to assert test pass or failure. Here are the high level functionalities of the infra. +- Test topology bringup. It takes topology definition in a specified JSON file and brings up the topology, which includes multiple swbusd instances and swbus clients. +- Execute tests in the topology. Test is defined in JSON file, which specifies the requests (swbus message) sent from clients and expected responses received by clients, which can be same or different from the clients sending the requests. +- Record received messages to help creating test JSON file to be used in future runs. + +# Topology Definition +Topology is defined in a JSON file. The topology has a few servers (swbusd e.g.) and clients. + +1. Add a Description +Provide a short summary of the topology under the "description" field. + +"description": "Brief explanation of the topology structure and its connections" + +2. Define the Servers +Each server represents a node in the network and should have: + +- A unique name (e.g., "swbusd1", "swbusd2") +- An endpoint specifying the IP address and port +- A routes array that defines "my-route" for the swbusd +- A peers array listing neighboring swbusd it connects to +``` +Example: + +"servers": { + "server-name": { + "endpoint": "IP:PORT", + "routes": [ + { + "key": "region.cluster.node-id", + "scope": "InCluster" + } + ], + "peers": [ + { + "id": "region.cluster.peer-id", + "endpoint": "IP:PORT", + "conn_type": "InCluster" + } + ] + } +} +``` + +3. 
Define the Clients +Each client connects to a specific server and interacts with a server directly +client-sp defines a service path for the client +``` +"clients": { + "client-name": { + "swbusd": "server-name", + "client_sp": "region.cluster.node-id/service-name/instance" + } +} +``` + +# Test Data Definition +Test data is also defined in JSON file. It contains multiple test cases, which will be run in sequence. +- Each test case has a name, description. +- Each test case can include multiple test steps. In Each test step, multiple requests (swbus message) will be sent from the corresponding clients in parallel. Expected responses will be compared to received messages collected from the corresponding clients to assert test pass or fail. + +Below guide explains how to create a structured test data entry in JSON format. + +1. Define the Test Case Name +Each test case should have a unique identifier under the name field, representing its purpose. +"name": "test_case_name" + +2. Provide a Description +Describe what the test case is verifying. +"description": "Brief explanation of the test scenario" + +3. Define Test Steps +Each test consists of one or more steps, where a client sends a request, and an expected response is verified. +"steps": [ { ... } ] + +4. Define Requests +Each step contains a list of requests that are sent by a client. A request consists of: +- client: The client initiating the request. +- message: The swbus message to be sent + +5. Define Expected Responses +Each step also contains expected responses that validate the outcome. The response structure mirrors the request, with: +- client: The client receiving the response. +- message: The swbus response message + +# Record Response +Instead of hand-crafting a test data with complete info, we can define a test data without response and let test run to collect received messages and fill in response in test data. 
This can be used to create initial test data or when system changes behavior, it can update the test data with new responses. To do that, +1. Create a test data with empty response, "responses": [], in JSON. +2. Run test (make or cargo test) with env GENERATE_TEST_DATA=1. + +# Run tests with trace enabled +In order to troubleshoot a test run with full traces from swbusd, use env ENABLE_TRACE=1 to run make or cargo test. trace output will be printed to stdout or stderr. + +# Limitations +We can't have multiple tokio tests in the same test rust file. rust executes tests in the same file in parallel. If different test brings up different topology, there might be conflict between them. For example, multiple swbusd uses the same GRPC port. It also increases difficulties to trouble shoot because traces from different test cases will mix together. \ No newline at end of file diff --git a/crates/swbus-core/tests/basic_tests.rs b/crates/swbus-core/tests/basic_tests.rs index 5407efe..7dfb2c4 100644 --- a/crates/swbus-core/tests/basic_tests.rs +++ b/crates/swbus-core/tests/basic_tests.rs @@ -1,15 +1,28 @@ mod common; -use common::test_executor::{run_tests, TopoRuntime}; +use common::test_executor::{init_logger, run_tests, TopoRuntime}; #[tokio::test] -async fn test_all() { - let mut topo = TopoRuntime::new("2-swbusd"); +async fn test_b2b() { + let trace_enabled: bool = std::env::var("ENABLE_TRACE") + .map(|val| val == "1" || val.eq_ignore_ascii_case("true")) + .unwrap_or(false); + + if trace_enabled { + init_logger(); + } + + let mut topo = TopoRuntime::new("tests/data/b2b/topo.json"); topo.bring_up().await; - // can't split run_tests into multiple test cases. Each tokio::test creates a new runtime, from which bring_up_topo runs. - // when a test case is done, the runtime is dropped and the topology is torn down. It can't be reused to run another test case. 
- // If move bring_up_topo outside of test cases to a setup function and create a single shared runtime, test case cannot - // use the shared runtime. It will panic with "fatal runtime error: thread::set_current should only be called once per thread". - run_tests(&mut topo, "tests/data/test_ping.json", None).await; - run_tests(&mut topo, "tests/data/test_show_route.json", None).await; - run_tests(&mut topo, "tests/data/test_trace_route.json", None).await; + // Can't split run_tests into multiple test cases. Each tokio::test creates a new runtime, + // from which bring_up_topo runs. When a test case is done, the runtime is dropped and the + // topology is torn down. It can't be reused to run another test case. + // If move bring_up_topo outside of test cases to a setup function and create a single shared + // runtime, test case cannot use the shared runtime. It will panic with "fatal runtime + // error: thread::set_current should only be called once per thread". + run_tests(&mut topo, "tests/data/b2b/test_ping.json", None).await; + run_tests(&mut topo, "tests/data/b2b/test_show_route.json", None).await; + run_tests(&mut topo, "tests/data/b2b/test_trace_route.json", None).await; } +//Can't add another test that brings up a topology and runs tests unless the new topo has +// different ports than the existing ones. Otherwise, the tests will fail because the ports +// are already in use. This is because rust will run all tests in parallel. 
diff --git a/crates/swbus-core/tests/common/test_executor.rs b/crates/swbus-core/tests/common/test_executor.rs index d7fbc75..cb0ff8b 100644 --- a/crates/swbus-core/tests/common/test_executor.rs +++ b/crates/swbus-core/tests/common/test_executor.rs @@ -12,14 +12,15 @@ use swbus_proto::swbus::*; use tokio::sync::mpsc; use tokio::task::JoinHandle; use tokio::time::{self, Duration, Instant}; -use tracing::{error, info}; +use tracing_subscriber::{fmt, prelude::*, Layer}; // 3 seconds receive timeout pub const RECEIVE_TIMEOUT: u32 = 3; /// The Topo struct contains the server jobs and clients' TX and RX of its message queues. pub struct TopoRuntime { - pub name: String, + /// test folder name + pub topo_file: String, /// The server jobs are the tokio tasks that run the swbusd servers. pub server_jobs: Vec>, /// The client_receivers are the message queues of the clients to receive messages. @@ -28,13 +29,11 @@ pub struct TopoRuntime { pub client_senders: HashMap>, } -/// The test case data including the name, topo, description, and test steps. -/// topo is optional and if it is not provided, the test will be skipped if it doesn't match current topo. +/// The test case data including the name, description, and test steps. /// The test steps contain the requests to be sent from the client and the expected responses from specified clients. 
#[derive(Serialize, Deserialize, Debug)] struct TestCaseData { pub name: String, - pub topo: Option, pub description: Option, pub steps: Vec, } @@ -67,9 +66,9 @@ struct SwbusClientConfig { } impl TopoRuntime { - pub fn new(name: &str) -> Self { + pub fn new(topo_file: &str) -> Self { TopoRuntime { - name: name.to_string(), + topo_file: topo_file.to_string(), server_jobs: Vec::new(), client_receivers: HashMap::new(), client_senders: HashMap::new(), @@ -85,15 +84,11 @@ impl TopoRuntime { pub async fn bring_up(&mut self) { init_logger_for_test(); - let file = File::open("tests/data/topos.json").unwrap(); + let file = File::open(&self.topo_file).unwrap(); let reader = BufReader::new(file); // Parse the topo data - let topo_cfgs: HashMap = serde_json::from_reader(reader).expect("failed to parse topos.json"); - - let topo_cfg = topo_cfgs - .get(&self.name) - .unwrap_or_else(|| panic!("Failed to find topo {}", self.name)); + let topo_cfg: TopoData = serde_json::from_reader(reader).expect("failed to parse topo.json"); for (name, server) in topo_cfg.servers.clone() { self.start_server(&name, &server).await; @@ -112,7 +107,7 @@ impl TopoRuntime { .await; } - info!("Topo {} is up", self.name); + println!("Topo {} is up", self.topo_file); } async fn start_server(&mut self, name: &str, route_config: &SwbusConfig) { @@ -124,7 +119,7 @@ impl TopoRuntime { self.server_jobs.push(server_task); - info!("Server {} started at {}", name, &route_config.endpoint); + println!("Server {} started at {}", name, &route_config.endpoint); } async fn start_client(&mut self, name: &str, node_addr: &SocketAddr, client_sp: ServicePath) { @@ -133,16 +128,26 @@ impl TopoRuntime { let addr = format!("http://{node_addr}"); while start.elapsed() < Duration::from_secs(10) { - match SwbusCoreClient::connect(addr.clone(), client_sp.clone(), receive_queue_tx.clone()).await { + println!("Trying to connect to the server at {}", node_addr); + match SwbusCoreClient::connect( + addr.clone(), + client_sp.clone(), 
+ ConnectionType::InNode, + receive_queue_tx.clone(), + ) + .await + { Ok((_, send_queue_tx)) => { self.client_receivers.insert(name.to_string(), receive_queue_rx); self.client_senders.insert(name.to_string(), send_queue_tx); - info!("Client {} connected to {}", name, node_addr); + println!("Client {} connected to {}", name, node_addr); return; } Err(e) => { - error!("Failed to connect to the server: {:?}", e); + println!("Failed to connect to the server: {:?}", e); + //std::thread::sleep(std::time::Duration::from_secs(1)); time::sleep(Duration::from_secs(1)).await; + println!("awake from sleep"); } } } @@ -163,26 +168,22 @@ pub async fn run_tests(topo: &mut TopoRuntime, test_json_file: &str, test_case_n continue; } } - if test.topo.is_some() && test.topo.as_ref().unwrap() != &topo.name { - info!( - "Skipping test {} due to mismatched topo: test.topo={}, running-topo={}", - test.name, - test.topo.as_ref().unwrap(), - topo.name - ); - continue; - } - info!("Running test: {}", test.name); + + println!("Running test: {}", test.name); for (i, step) in test.steps.iter_mut().enumerate() { - info!(" --- Step {} ---", i); + println!(" --- Step {} ---", i); for req in &step.requests { - let sender = topo.client_senders.get(&req.client).unwrap(); + let sender = topo + .client_senders + .get(&req.client) + .unwrap_or_else(|| panic!("Failed to find client sender {} in topo file", &req.client)); + match sender.send(req.message.clone()).await { Ok(_) => { - info!("Sent message from client {}", req.client); + println!("Sent message from client {}", req.client); } Err(e) => { - error!("Failed to send message from client {}: {:?}", req.client, e); + println!("Failed to send message from client {}: {:?}", req.client, e); } } } @@ -190,7 +191,7 @@ pub async fn run_tests(topo: &mut TopoRuntime, test_json_file: &str, test_case_n if to_generate { let responses = record_received_messages(topo, RECEIVE_TIMEOUT).await; step.responses = responses; - info!(" --- Recorded {} messages ---", 
&step.responses.len()); + println!(" --- Recorded {} messages ---", &step.responses.len()); } else { receive_and_compare(topo, &step.responses, RECEIVE_TIMEOUT).await; } @@ -206,12 +207,19 @@ pub async fn run_tests(topo: &mut TopoRuntime, test_json_file: &str, test_case_n /// if the responses are not in order. async fn receive_and_compare(topo: &mut TopoRuntime, expected_responses: &[MessageClientPair], timeout: u32) { for resp in expected_responses.iter() { - let receiver = topo.client_receivers.get_mut(&resp.client).unwrap(); + let receiver = topo + .client_receivers + .get_mut(&resp.client) + .unwrap_or_else(|| panic!("Failed to find client receiver {} in topo file", &resp.client)); + match time::timeout(Duration::from_secs(timeout as u64), receiver.recv()).await { Ok(Some(msg)) => { let normalized_msg = swbus_proto::swbus::normalize_msg(&msg); - assert_eq!(normalized_msg, resp.message); + assert_eq!( + normalized_msg, resp.message, + "Received(left) message does not match the expected(right)" + ); } Ok(None) => { panic!("channel broken"); @@ -242,3 +250,28 @@ async fn record_received_messages(topo: &mut TopoRuntime, timeout: u32) -> Vec swbusd1 <-> swbusd2", + "servers": { + "swbusd1": { + "endpoint": "127.0.0.1:60001", + "routes": [ + { + "key": "region-a.cluster-a.10.0.0.1-dpu0", + "scope": "InCluster" + } + ], + "peers": [ + { + "id": "region-a.cluster-a.10.0.0.2-dpu0", + "endpoint": "127.0.0.1:60002", + "conn_type": "InCluster" + } + ] + }, + "swbusd2": { + "endpoint": "127.0.0.1:60002", + "routes": [ + { + "key": "region-a.cluster-a.10.0.0.2-dpu0", + "scope": "InCluster" + } + ], + "peers": [ + { + "id": "region-a.cluster-a.10.0.0.1-dpu0", + "endpoint": "127.0.0.1:60001", + "conn_type": "InCluster" + } + ] + } + }, + "clients": { + "swbusd1-client": { + "swbusd": "swbusd1", + "client_sp": "region-a.cluster-a.10.0.0.1-dpu0/testsvc/0" + } + } +} \ No newline at end of file diff --git a/crates/swbus-core/tests/data/inter-cluster/test_ping.json 
b/crates/swbus-core/tests/data/inter-cluster/test_ping.json new file mode 100644 index 0000000..1ec0d69 --- /dev/null +++ b/crates/swbus-core/tests/data/inter-cluster/test_ping.json @@ -0,0 +1,49 @@ +[ + { + "name": "inter_cluster_ping", + "description": "verify ping and response", + "steps": [ + { + "requests": [ + { + "client": "swbusd_cluster-a.node1_client", + "message": { + "header": { + "version": 1, + "flag": 0, + "ttl": 64, + "source": "region-a.cluster-a.node1/testsvc/0", + "destination": "region-a.cluster-b.node1" + }, + "body": { + "PingRequest": {} + } + } + } + ], + "responses": [ + { + "client": "swbusd_cluster-a.node1_client", + "message": { + "header": { + "version": 1, + "flag": 0, + "ttl": 60, + "source": "region-a.cluster-b.node1", + "destination": "region-a.cluster-a.node1/testsvc/0" + }, + "body": { + "Response": { + "request_id": 0, + "error_code": 1, + "error_message": "", + "response_body": null + } + } + } + } + ] + } + ] + } +] \ No newline at end of file diff --git a/crates/swbus-core/tests/data/inter-cluster/test_show_route.json b/crates/swbus-core/tests/data/inter-cluster/test_show_route.json new file mode 100644 index 0000000..378f5b2 --- /dev/null +++ b/crates/swbus-core/tests/data/inter-cluster/test_show_route.json @@ -0,0 +1,93 @@ +[ + { + "name": "show_route_inter-cluster", + "description": "verify routes are exchanged between gateways", + "steps": [ + { + "requests": [ + { + "client": "swbusd_cluster-a.node1_client", + "message": { + "header": { + "version": 1, + "flag": 0, + "ttl": 1, + "source": "region-a.cluster-a.node1/testsvc/0", + "destination": "region-a.cluster-b.node1" + }, + "body": { + "ManagementRequest": { + "request": 0, + "arguments": [] + } + } + } + } + ], + "responses": [ + { + "client": "swbusd_cluster-a.node1_client", + "message": { + "header": { + "version": 1, + "flag": 0, + "ttl": 63, + "source": "region-a.cluster-a.node1", + "destination": "region-a.cluster-a.node1/testsvc/0" + }, + "body": { + "Response": 
{ + "request_id": 0, + "error_code": 1, + "error_message": "", + "response_body": { + "RouteEntries": { + "entries": [ + { + "service_path": "region-a.cluster-a", + "nh_service_path": "region-a.cluster-a.gw", + "route_scope": 3, + "hop_count": 1 + }, + { + "service_path": "region-a.cluster-a.gw", + "nh_service_path": "region-a.cluster-a.gw", + "route_scope": 2, + "hop_count": 1 + }, + { + "service_path": "region-a.cluster-a.gw", + "nh_service_path": "region-a.cluster-a.gw", + "route_scope": 2, + "hop_count": 1 + }, + { + "service_path": "region-a.cluster-a.node1/testsvc/0", + "nh_service_path": "region-a.cluster-a.node1/testsvc/0", + "route_scope": 0, + "hop_count": 1 + }, + { + "service_path": "region-a.cluster-b", + "nh_service_path": "region-a.cluster-a.gw", + "route_scope": 3, + "hop_count": 2 + }, + { + "service_path": "region-a.cluster-b.gw", + "nh_service_path": "region-a.cluster-a.gw", + "route_scope": 2, + "hop_count": 2 + } + ] + } + } + } + } + } + } + ] + } + ] + } +] \ No newline at end of file diff --git a/crates/swbus-core/tests/data/inter-cluster/topo.json b/crates/swbus-core/tests/data/inter-cluster/topo.json new file mode 100644 index 0000000..044fd09 --- /dev/null +++ b/crates/swbus-core/tests/data/inter-cluster/topo.json @@ -0,0 +1,93 @@ +{ + "description": "Simple topo with 4 swbusd including 2 gateways and 1 client: client <-> swbusd_cluster-a.node1 <-> swbusd_cluster-a.gw <-> swbusd_cluster-b.gw <-> swbusd_cluster-b.node1", + "servers": { + "swbusd_cluster-a.node1": { + "endpoint": "127.0.0.1:60001", + "routes": [ + { + "key": "region-a.cluster-a.node1", + "scope": "InCluster" + } + ], + "peers": [ + { + "id": "region-a.cluster-a.gw", + "endpoint": "127.0.0.1:60002", + "conn_type": "InCluster" + } + ] + }, + "swbusd_cluster-a.gw": { + "endpoint": "127.0.0.1:60002", + "routes": [ + { + "key": "region-a.cluster-a.gw", + "scope": "InCluster" + }, + { + "key": "region-a.cluster-a", + "scope": "InRegion" + } + ], + "peers": [ + { + "id": 
"region-a.cluster-a.node1", + "endpoint": "127.0.0.1:60001", + "conn_type": "InCluster" + }, + { + "id": "region-a.cluster-b.gw", + "endpoint": "127.0.0.1:60003", + "conn_type": "InRegion" + } + ] + }, + "swbusd_cluster-b.gw": { + "endpoint": "127.0.0.1:60003", + "routes": [ + { + "key": "region-a.cluster-b.gw", + "scope": "InCluster" + }, + { + "key": "region-a.cluster-b", + "scope": "InRegion" + } + ], + "peers": [ + { + "id": "region-a.cluster-a.gw", + "endpoint": "127.0.0.1:60002", + "conn_type": "InRegion" + }, + { + "id": "region-a.cluster-b.node1", + "endpoint": "127.0.0.1:60004", + "conn_type": "InCluster" + } + ] + }, + "swbusd_cluster-b.node1": { + "endpoint": "127.0.0.1:60004", + "routes": [ + { + "key": "region-a.cluster-b.node1", + "scope": "InCluster" + } + ], + "peers": [ + { + "id": "region-a.cluster-b.gw", + "endpoint": "127.0.0.1:60003", + "conn_type": "InCluster" + } + ] + } + }, + "clients": { + "swbusd_cluster-a.node1_client": { + "swbusd": "swbusd_cluster-a.node1", + "client_sp": "region-a.cluster-a.node1/testsvc/0" + } + } +} \ No newline at end of file diff --git a/crates/swbus-core/tests/data/topos.json b/crates/swbus-core/tests/data/topos.json deleted file mode 100644 index 76b6f57..0000000 --- a/crates/swbus-core/tests/data/topos.json +++ /dev/null @@ -1,45 +0,0 @@ -{ - "2-swbusd": { - "description": "Simple topo with 2 swbusd and 1 client: client <-> swbusd1 <-> swbusd2", - "servers": { - "swbusd1": { - "endpoint": "127.0.0.1:60001", - "routes": [ - { - "key": "region-a.cluster-a.10.0.0.1-dpu0", - "scope": "Cluster" - } - ], - "peers": [ - { - "id": "region-a.cluster-a.10.0.0.2-dpu0", - "endpoint": "127.0.0.1:60002", - "conn_type": "Cluster" - } - ] - }, - "swbusd2": { - "endpoint": "127.0.0.1:60002", - "routes": [ - { - "key": "region-a.cluster-a.10.0.0.2-dpu0", - "scope": "Cluster" - } - ], - "peers": [ - { - "id": "region-a.cluster-a.10.0.0.1-dpu0", - "endpoint": "127.0.0.1:60001", - "conn_type": "Cluster" - } - ] - } - }, - 
"clients": { - "swbusd1-client": { - "swbusd": "swbusd1", - "client_sp": "region-a.cluster-a.10.0.0.1-dpu0/testsvc/0" - } - } - } -} \ No newline at end of file diff --git a/crates/swbus-core/tests/inter_cluster_tests.rs b/crates/swbus-core/tests/inter_cluster_tests.rs new file mode 100644 index 0000000..c18a34d --- /dev/null +++ b/crates/swbus-core/tests/inter_cluster_tests.rs @@ -0,0 +1,22 @@ +mod common; +use common::test_executor::{init_logger, run_tests, TopoRuntime}; + +#[tokio::test] +async fn test_inter_cluster() { + let trace_enabled: bool = std::env::var("ENABLE_TRACE") + .map(|val| val == "1" || val.eq_ignore_ascii_case("true")) + .unwrap_or(false); + + if trace_enabled { + init_logger(); + } + + let mut topo = TopoRuntime::new("tests/data/inter-cluster/topo.json"); + topo.bring_up().await; + // can't split run_tests into multiple test cases. Each tokio::test creates a new runtime, from which bring_up_topo runs. + // when a test case is done, the runtime is dropped and the topology is torn down. It can't be reused to run another test case. + // If move bring_up_topo outside of test cases to a setup function and create a single shared runtime, test case cannot + // use the shared runtime. It will panic with "fatal runtime error: thread::set_current should only be called once per thread". 
+ run_tests(&mut topo, "tests/data/inter-cluster/test_ping.json", None).await; + run_tests(&mut topo, "tests/data/inter-cluster/test_show_route.json", None).await; +} diff --git a/crates/swbus-edge/src/core_client.rs b/crates/swbus-edge/src/core_client.rs index c41839d..66ce52c 100644 --- a/crates/swbus-edge/src/core_client.rs +++ b/crates/swbus-edge/src/core_client.rs @@ -18,6 +18,7 @@ use tracing::{debug, error, info}; pub struct SwbusCoreClient { uri: String, sp: ServicePath, + conn_type: ConnectionType, // tx queue to send messages to swbusd pub(crate) send_queue_tx: Arc>>>, @@ -29,10 +30,16 @@ pub struct SwbusCoreClient { // Factory functions impl SwbusCoreClient { - pub fn new(uri: String, sp: ServicePath, message_processor_tx: mpsc::Sender) -> Self { + pub fn new( + uri: String, + sp: ServicePath, + message_processor_tx: mpsc::Sender, + conn_type: ConnectionType, + ) -> Self { Self { uri, sp, + conn_type, send_queue_tx: Arc::new(RwLock::new(None)), message_processor_tx, swbusd_connect_task: None, @@ -49,6 +56,7 @@ impl SwbusCoreClient { pub async fn connect( uri: String, sp: ServicePath, + conn_type: ConnectionType, receive_queue_tx: mpsc::Sender, ) -> Result<(tokio::task::JoinHandle>, mpsc::Sender)> { let (send_queue_tx, send_queue_rx) = mpsc::channel::(100); @@ -80,7 +88,7 @@ impl SwbusCoreClient { meta.insert( SWBUS_CONNECTION_TYPE, - MetadataValue::from_str(ConnectionType::Local.as_str_name()).unwrap(), + MetadataValue::from_str(conn_type.as_str_name()).unwrap(), ); let recv_stream = match client.stream_messages(send_stream_request).await { @@ -105,10 +113,10 @@ impl SwbusCoreClient { let sp = self.sp.clone(); let message_processor_tx = self.message_processor_tx.clone(); let send_queue_tx_arc = self.send_queue_tx.clone(); - + let conn_type = self.conn_type; let handle = tokio::spawn(async move { loop { - match Self::connect(uri.clone(), sp.clone(), message_processor_tx.clone()).await { + match Self::connect(uri.clone(), sp.clone(), conn_type, 
message_processor_tx.clone()).await { Ok((recv_stream_task, send_queue_tx)) => { info!("Successfully connected to swbusd at {}", uri); send_queue_tx_arc.write().await.replace(send_queue_tx); diff --git a/crates/swbus-edge/src/edge_runtime.rs b/crates/swbus-edge/src/edge_runtime.rs index 0bf8b6c..e15ac6a 100644 --- a/crates/swbus-edge/src/edge_runtime.rs +++ b/crates/swbus-edge/src/edge_runtime.rs @@ -25,11 +25,12 @@ pub struct SwbusEdgeRuntime { } impl SwbusEdgeRuntime { - pub fn new(swbus_uri: String, sp: ServicePath) -> Self { + pub fn new(swbus_uri: String, sp: ServicePath, conn_type: ConnectionType) -> Self { + assert!(conn_type == ConnectionType::Client || conn_type == ConnectionType::InNode); let (local_msg_tx, local_msg_rx) = channel(SWBUS_RECV_QUEUE_SIZE); let (remote_msg_tx, remote_msg_rx) = channel(SWBUS_RECV_QUEUE_SIZE); let base_sp = sp.clone(); - let swbus_client = SwbusCoreClient::new(swbus_uri.clone(), sp, remote_msg_tx); + let swbus_client = SwbusCoreClient::new(swbus_uri.clone(), sp, remote_msg_tx, conn_type); let tx_to_swbusd = swbus_client.send_queue_tx.clone(); let message_router = SwbusMessageRouter::new(swbus_client, local_msg_rx, remote_msg_rx); @@ -128,7 +129,7 @@ mod tests { endpoint: "127.0.0.1:{port}" routes: - key: "region-a.cluster-a.10.0.1.0-dpu0" - scope: "Cluster" + scope: "InCluster" peers: "# ); @@ -187,7 +188,11 @@ mod tests { sp.service_type = "swbus-edge".to_string(); sp.service_id = "test".to_string(); - let mut runtime = SwbusEdgeRuntime::new(format!("http://{}", swbus_config.endpoint), sp.clone()); + let mut runtime = SwbusEdgeRuntime::new( + format!("http://{}", swbus_config.endpoint), + sp.clone(), + ConnectionType::InNode, + ); runtime.start().await.unwrap(); let runtime = Arc::new(runtime); @@ -222,7 +227,8 @@ mod tests { sp.service_type = "swbus-edge".to_string(); sp.service_id = "test".to_string(); - let mut runtime = SwbusEdgeRuntime::new(format!("http://{}", swbus_config.endpoint), sp); + let mut runtime = + 
SwbusEdgeRuntime::new(format!("http://{}", swbus_config.endpoint), sp, ConnectionType::InNode); runtime.start().await.unwrap(); let base_sp = swbus_config.routes[0].key.to_swbusd_service_path().to_longest_path(); diff --git a/crates/swbus-proto/build.rs b/crates/swbus-proto/build.rs index b4ce856..58596c7 100644 --- a/crates/swbus-proto/build.rs +++ b/crates/swbus-proto/build.rs @@ -9,11 +9,11 @@ fn main() -> Result<(), Box> { .message_attribute("swbus.ServicePath", "#[derive(Eq, Hash, Ord, PartialOrd)]") .field_attribute("swbus.SwbusMessageHeader.id", "#[serde(default, skip_serializing)]") .field_attribute( - "swbus.RouteQueryResultEntry.nh_id", + "swbus.RouteEntry.nh_id", "#[serde(default, skip_serializing)]", ) .field_attribute( - "swbus.RouteQueryResult.entries", + "swbus.RouteEntries.entries", "#[serde(serialize_with = \"sorted_vec_serializer\")]", ) .field_attribute( @@ -25,11 +25,11 @@ fn main() -> Result<(), Box> { "#[serde(serialize_with = \"serialize_service_path_opt\",deserialize_with = \"deserialize_service_path_opt\")]", ) .field_attribute( - "swbus.RouteQueryResultEntry.service_path", + "swbus.RouteEntry.service_path", "#[serde(serialize_with = \"serialize_service_path_opt\",deserialize_with = \"deserialize_service_path_opt\")]", ) .field_attribute( - "swbus.RouteQueryResultEntry.nh_service_path", + "swbus.RouteEntry.nh_service_path", "#[serde(serialize_with = \"serialize_service_path_opt\",deserialize_with = \"deserialize_service_path_opt\")]", ); diff --git a/crates/swbus-proto/proto/swbus.proto b/crates/swbus-proto/proto/swbus.proto index 9ad810c..674b599 100644 --- a/crates/swbus-proto/proto/swbus.proto +++ b/crates/swbus-proto/proto/swbus.proto @@ -104,6 +104,9 @@ enum SwbusErrorCode { // Invalid message payload. SWBUS_ERROR_CODE_INVALID_PAYLOAD = 211; + // Invalid route announcement + SWBUS_ERROR_CODE_INVALID_ROUTE = 212; + // Input error ends. 
SWBUS_ERROR_CODE_INPUT_ERROR_MAX = 299; @@ -146,7 +149,7 @@ message RequestResponse { SwbusErrorCode error_code = 20; string error_message = 30; oneof ResponseBody { - RouteQueryResult route_query_result = 100; + RouteEntries route_entries = 100; ManagementQueryResult management_query_result = 110; } } @@ -156,9 +159,9 @@ message RequestResponse { // enum ConnectionType { CONNECTION_TYPE_CLIENT = 0; - CONNECTION_TYPE_LOCAL = 1; - CONNECTION_TYPE_CLUSTER = 2; - CONNECTION_TYPE_REGION = 3; + CONNECTION_TYPE_IN_NODE = 1; + CONNECTION_TYPE_IN_CLUSTER = 2; + CONNECTION_TYPE_IN_REGION = 3; CONNECTION_TYPE_GLOBAL = 4; } @@ -167,27 +170,21 @@ enum ConnectionType { // enum RouteScope { ROUTE_SCOPE_CLIENT = 0; - ROUTE_SCOPE_LOCAL = 1; - ROUTE_SCOPE_CLUSTER = 2; - ROUTE_SCOPE_REGION = 3; + ROUTE_SCOPE_IN_NODE = 1; + ROUTE_SCOPE_IN_CLUSTER = 2; + ROUTE_SCOPE_IN_REGION = 3; ROUTE_SCOPE_GLOBAL = 4; } -message RegistrationQueryRequest { -} - -message RegistrationQueryResponse { -} - -message RouteQueryResult { - repeated RouteQueryResultEntry entries = 10; +message RouteEntries { + repeated RouteEntry entries = 10; } -message RouteQueryResultEntry { +message RouteEntry { ServicePath service_path = 10; string nh_id = 20; ServicePath nh_service_path = 30; - RouteScope nh_scope = 40; + RouteScope route_scope = 40; uint32 hop_count = 50; } @@ -237,9 +234,8 @@ message SwbusMessage { oneof Body { RequestResponse response = 20; - // Registration - RegistrationQueryRequest registration_query_request = 101; - RegistrationQueryResponse registration_query_response = 102; + // Route Update + RouteEntries route_announcement = 100; // Ping PingRequest ping_request = 310; diff --git a/crates/swbus-proto/src/swbus.rs b/crates/swbus-proto/src/swbus.rs index 1418a5f..3e4e939 100644 --- a/crates/swbus-proto/src/swbus.rs +++ b/crates/swbus-proto/src/swbus.rs @@ -90,15 +90,15 @@ impl ServicePath { }) } - pub fn to_regional_prefix(&self) -> String { + pub fn to_global_prefix(&self) -> String { 
self.region_id.clone() } - pub fn to_cluster_prefix(&self) -> String { + pub fn to_inregion_prefix(&self) -> String { format!("{}.{}", self.region_id, self.cluster_id) } - pub fn to_node_prefix(&self) -> String { + pub fn to_incluster_prefix(&self) -> String { format!("{}.{}.{}", self.region_id, self.cluster_id, self.node_id) } @@ -149,10 +149,10 @@ impl ServicePath { return RouteScope::Global; } if self.node_id.is_empty() { - return RouteScope::Region; + return RouteScope::InRegion; } if self.service_id.is_empty() { - return RouteScope::Cluster; + return RouteScope::InCluster; } RouteScope::Client } @@ -238,13 +238,16 @@ pub fn deserialize_service_path_opt<'de, D>(deserializer: D) -> Result, { - let s = String::deserialize(deserializer)?; + let opt = Option::::deserialize(deserializer)?; - match ServicePath::from_string(&s) { - Ok(sp) => Ok(Some(sp)), - Err(_) => Err(serde::de::Error::custom(format!( - "Failed to parse service path from string: {s}" - ))), + match opt { + Some(s) => match ServicePath::from_string(&s) { + Ok(sp) => Ok(Some(sp)), + Err(_) => Err(serde::de::Error::custom(format!( + "Failed to parse service path from string: {s}" + ))), + }, + None => Ok(None), } } @@ -263,10 +266,18 @@ where } } -impl PartialOrd for RouteQueryResultEntry { +impl PartialOrd for RouteEntry { fn partial_cmp(&self, other: &Self) -> Option { - match self.service_path.partial_cmp(&other.service_path) { - Some(std::cmp::Ordering::Equal) => self.hop_count.partial_cmp(&other.hop_count), + Some(self.cmp(other)) + } +} + +impl Eq for RouteEntry {} + +impl Ord for RouteEntry { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + match self.service_path.cmp(&other.service_path) { + std::cmp::Ordering::Equal => self.hop_count.cmp(&other.hop_count), x => x, } } @@ -350,16 +361,16 @@ impl SwbusMessage { /// send response to the sender of the request pub fn new_response( - request: &SwbusMessage, - source: Option<&ServicePath>, + request_header: &SwbusMessageHeader, + dest: 
Option<&ServicePath>, error_code: SwbusErrorCode, error_message: &str, request_id: u64, response_body: Option, ) -> Self { let mut request_response = match error_code { - SwbusErrorCode::Ok => RequestResponse::ok(request.header.as_ref().unwrap().id), - _ => RequestResponse::infra_error(request.header.as_ref().unwrap().id, error_code, error_message), + SwbusErrorCode::Ok => RequestResponse::ok(request_header.id), + _ => RequestResponse::infra_error(request_header.id, error_code, error_message), }; if response_body.is_some() { @@ -367,27 +378,15 @@ impl SwbusMessage { }; // if dest is not provided, use the source of the request - let src_sp = match source { + let dest_sp = match dest { Some(sp) => sp.clone(), - None => request - .header - .as_ref() - .unwrap() - .destination - .clone() - .expect("missing dest service_path"), + None => request_header.destination.clone().expect("missing dest service_path"), }; SwbusMessage { header: Some(SwbusMessageHeader::new( - src_sp, - request - .header - .as_ref() - .unwrap() - .source - .clone() - .expect("missing source service_path"), + dest_sp, + request_header.source.clone().expect("missing source service_path"), request_id, )), body: Some(swbus_message::Body::Response(request_response)), @@ -503,15 +502,16 @@ mod tests { } #[test] - fn registration_query_request_can_be_created() { - let request = RegistrationQueryRequest {}; - test_packing_with_swbus_message(swbus_message::Body::RegistrationQueryRequest(request)); - } - - #[test] - fn registration_query_response_can_be_created() { - let response = RegistrationQueryResponse {}; - test_packing_with_swbus_message(swbus_message::Body::RegistrationQueryResponse(response)); + fn route_announcement_can_be_created() { + let entry = RouteEntry { + service_path: Some(create_mock_service_path()), + nh_id: "test".to_string(), + nh_service_path: Some(create_mock_service_path()), + route_scope: RouteScope::InRegion.into(), + hop_count: 1, + }; + let request = RouteEntries { entries: 
vec![entry] }; + test_packing_with_swbus_message(swbus_message::Body::RouteAnnouncement(request)); } #[test] @@ -628,8 +628,14 @@ mod tests { let request_id = request.header.as_ref().unwrap().id; let src = request.header.as_ref().unwrap().source.as_ref().unwrap().clone(); let dest = request.header.as_ref().unwrap().destination.as_ref().unwrap().clone(); - let response = - SwbusMessage::new_response(&request, None, SwbusErrorCode::Ok, "", create_mock_message_id(), None); + let response = SwbusMessage::new_response( + request.header.as_ref().unwrap(), + None, + SwbusErrorCode::Ok, + "", + create_mock_message_id(), + None, + ); assert_eq!(response.header.as_ref().unwrap().version, 1); assert_eq!(response.header.as_ref().unwrap().flag, 0); assert_eq!(response.header.as_ref().unwrap().ttl, 64); @@ -646,4 +652,46 @@ mod tests { // assert_eq!(response.body.as_ref().unwrap().request_, true); } + + #[test] + fn test_sp_compare() { + let sp1 = ServicePath::from_string("region-a.cluster-a.1.0.0.1-dpu0").unwrap(); + let sp2 = ServicePath::from_string("region-a.cluster-a.1.0.0.2-dpu0").unwrap(); + let sp3 = ServicePath::from_string("region-a.cluster-b").unwrap(); + + assert_eq!(sp1 < sp2, true); + assert_eq!(sp2 < sp3, true); + } + + #[test] + fn test_route_entry_compare() { + let entry1 = RouteEntry { + service_path: Some(ServicePath::from_string("region-a.cluster-a.10.0.0.1-dpu0").unwrap()), + hop_count: 1, + nh_id: "".to_string(), + nh_service_path: None, + route_scope: RouteScope::InCluster as i32, + }; + let entry2 = RouteEntry { + service_path: Some(ServicePath::from_string("region-a.cluster-a.10.0.0.2-dpu0").unwrap()), + hop_count: 0, //my route has hop count 0 + nh_id: "".to_string(), + nh_service_path: None, + route_scope: RouteScope::InCluster as i32, + }; + let entry3 = RouteEntry { + service_path: Some(ServicePath::from_string("region-a.cluster-b").unwrap()), + hop_count: 1, + nh_id: "".to_string(), + nh_service_path: None, + route_scope: RouteScope::InRegion as 
i32, + }; + + assert_eq!(entry1 < entry2, true); + assert_eq!(entry2 < entry3, true); + + let mut entries = vec![entry1.clone(), entry3.clone(), entry2.clone()]; + entries.sort(); + assert_eq!(entries, vec![entry1, entry2, entry3]); + } } diff --git a/crates/swbusd/sample/inter-cluster/swbusd_cluster-a.10.0.0.1.cfg b/crates/swbusd/sample/inter-cluster/swbusd_cluster-a.10.0.0.1.cfg new file mode 100644 index 0000000..4b489a8 --- /dev/null +++ b/crates/swbusd/sample/inter-cluster/swbusd_cluster-a.10.0.0.1.cfg @@ -0,0 +1,14 @@ + endpoint: "127.0.0.1:50001" + routes: + - key: "region-a.cluster-a.10.0.0.1-dpu0" + scope: "InCluster" + - key: "region-a.cluster-a" + scope: "InRegion" + peers: + - id: "region-a.cluster-a.10.0.0.2-dpu0" + endpoint: "127.0.0.1:50002" + conn_type: "InCluster" + - id: "region-a.cluster-b.11.0.0.1-dpu0" + endpoint: "127.0.0.1:50000" + conn_type: "InRegion" + diff --git a/crates/swbusd/sample/inter-cluster/swbusd_cluster-a.10.0.0.2.cfg b/crates/swbusd/sample/inter-cluster/swbusd_cluster-a.10.0.0.2.cfg new file mode 100644 index 0000000..68378fc --- /dev/null +++ b/crates/swbusd/sample/inter-cluster/swbusd_cluster-a.10.0.0.2.cfg @@ -0,0 +1,8 @@ + endpoint: 127.0.0.1:50002 + routes: + - key: "region-a.cluster-a.10.0.0.2-dpu0" + scope: "InCluster" + peers: + - id: "region-a.cluster-a.10.0.0.1-dpu0" + endpoint: "127.0.0.1:50001" + conn_type: "InCluster" diff --git a/crates/swbusd/sample/inter-cluster/swbusd_cluster-b.11.0.0.1.cfg b/crates/swbusd/sample/inter-cluster/swbusd_cluster-b.11.0.0.1.cfg new file mode 100644 index 0000000..724315a --- /dev/null +++ b/crates/swbusd/sample/inter-cluster/swbusd_cluster-b.11.0.0.1.cfg @@ -0,0 +1,10 @@ + endpoint: "127.0.0.1:50000" + routes: + - key: "region-a.cluster-b.11.0.0.1-dpu0" + scope: "InCluster" + - key: "region-a.cluster-b" + scope: "InRegion" + peers: + - id: "region-a.cluster-a.10.0.0.1-dpu0" + endpoint: "127.0.0.1:50001" + conn_type: "InRegion" diff --git a/crates/swbusd/sample/swbusd1.cfg 
b/crates/swbusd/sample/swbusd1.cfg index 12f1de8..e97b8d7 100644 --- a/crates/swbusd/sample/swbusd1.cfg +++ b/crates/swbusd/sample/swbusd1.cfg @@ -1,8 +1,10 @@ endpoint: "127.0.0.1:50001" routes: - key: "region-a.cluster-a.10.0.0.1-dpu0" - scope: "Cluster" + scope: "InCluster" + - key: "region-a.cluster-a" + scope: "InRegion" peers: - id: "region-a.cluster-a.10.0.0.2-dpu0" endpoint: "127.0.0.1:50002" - conn_type: "Cluster" + conn_type: "InCluster" diff --git a/crates/swbusd/sample/swbusd2.cfg b/crates/swbusd/sample/swbusd2.cfg index f67e1db..68378fc 100644 --- a/crates/swbusd/sample/swbusd2.cfg +++ b/crates/swbusd/sample/swbusd2.cfg @@ -1,8 +1,8 @@ endpoint: 127.0.0.1:50002 routes: - key: "region-a.cluster-a.10.0.0.2-dpu0" - scope: "Cluster" + scope: "InCluster" peers: - id: "region-a.cluster-a.10.0.0.1-dpu0" endpoint: "127.0.0.1:50001" - conn_type: "Cluster" + conn_type: "InCluster" diff --git a/crates/swss-common-bridge/src/consumer.rs b/crates/swss-common-bridge/src/consumer.rs index 8bdb5c3..9bdea5f 100644 --- a/crates/swss-common-bridge/src/consumer.rs +++ b/crates/swss-common-bridge/src/consumer.rs @@ -188,7 +188,7 @@ mod test { use swbus_actor::ActorMessage; use swbus_edge::{ simple_client::{IncomingMessage, MessageBody, SimpleSwbusEdgeClient}, - swbus_proto::swbus::ServicePath, + swbus_proto::swbus::{ConnectionType, ServicePath}, SwbusEdgeRuntime, }; use swss_common::{ @@ -244,7 +244,7 @@ mod test { mut producer_table: P, ) { // Setup swbus - let mut swbus_edge = SwbusEdgeRuntime::new("".to_string(), sp("edge")); + let mut swbus_edge = SwbusEdgeRuntime::new("".to_string(), sp("edge"), ConnectionType::InNode); swbus_edge.start().await.unwrap(); let rt = Arc::new(swbus_edge); @@ -319,7 +319,7 @@ mod test { async fn run_proto_test(consumer_table: C, mut producer_table: P) { // Setup swbus - let mut swbus_edge = SwbusEdgeRuntime::new("".to_string(), sp("edge")); + let mut swbus_edge = SwbusEdgeRuntime::new("".to_string(), sp("edge"), ConnectionType::InNode); 
swbus_edge.start().await.unwrap(); let rt = Arc::new(swbus_edge); diff --git a/crates/swss-common-bridge/src/producer.rs b/crates/swss-common-bridge/src/producer.rs index 00c222e..22f1c39 100644 --- a/crates/swss-common-bridge/src/producer.rs +++ b/crates/swss-common-bridge/src/producer.rs @@ -121,7 +121,7 @@ mod test { use swbus_actor::ActorMessage; use swbus_edge::{ simple_client::{MessageBody, OutgoingMessage, SimpleSwbusEdgeClient}, - swbus_proto::swbus::ServicePath, + swbus_proto::swbus::{ConnectionType, ServicePath}, SwbusEdgeRuntime, }; use swss_common::{ @@ -172,7 +172,7 @@ mod test { async fn run_test(mut consumer_table: C, producer_table: P) { // Setup swbus - let mut swbus_edge = SwbusEdgeRuntime::new("".to_string(), sp("edge")); + let mut swbus_edge = SwbusEdgeRuntime::new("".to_string(), sp("edge"), ConnectionType::InNode); swbus_edge.start().await.unwrap(); let rt = Arc::new(swbus_edge); @@ -209,7 +209,7 @@ mod test { async fn run_dup_test(mut consumer_table: C, producer_table: P) { // Setup swbus - let mut swbus_edge = SwbusEdgeRuntime::new("".to_string(), sp("edge")); + let mut swbus_edge = SwbusEdgeRuntime::new("".to_string(), sp("edge"), ConnectionType::InNode); swbus_edge.start().await.unwrap(); let rt = Arc::new(swbus_edge); From 0b2aedaab74e9e15ff829c506a1c6aab1e5b1930 Mon Sep 17 00:00:00 2001 From: dypet Date: Tue, 2 Sep 2025 11:56:58 -0600 Subject: [PATCH 05/14] Update dpu scope state table name. 
(#106) Update dpu scope state table name to match https://github.com/sonic-net/sonic-swss-common/blob/e7ee75dfcd44de934d49aea43c991a6aa20db63b/common/schema.h#L557 --- crates/hamgrd/src/db_structs.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/hamgrd/src/db_structs.rs b/crates/hamgrd/src/db_structs.rs index 1b275d0..3c679d1 100644 --- a/crates/hamgrd/src/db_structs.rs +++ b/crates/hamgrd/src/db_structs.rs @@ -295,7 +295,7 @@ pub struct DashHaScopeTable { /// #[derive(Debug, Deserialize, Serialize, PartialEq, Default, Clone, SonicDb)] #[sonicdb( - table_name = "DASH_HA_SCOPE_STATE", + table_name = "DASH_HA_SCOPE_STATE_TABLE", key_separator = "|", db_name = "DPU_STATE_DB", is_dpu = "true" From 34f1387d4be395080392dfbd25f7a88d0254a87e Mon Sep 17 00:00:00 2001 From: yue-fred-gao <132678244+yue-fred-gao@users.noreply.github.com> Date: Tue, 2 Sep 2025 13:57:58 -0400 Subject: [PATCH 06/14] Implement mark-delete in actor framework (#104) ### why actor has the retry logic in outgoing state. If a message is not acked, it will resend the message to make sure receiver has received it successfully. When an actor is terminated, the retry will be terminated as well so it can't guarantee the receiver getting the message. ### what this PR does Introduce mark-delete concept. 1. when an actor is going to terminate, add "mark_deleted" flag to the driver of the actor. 2. In the run loop of the actor, which is triggered each time it receives a message, it will check if the actor is ready_for_delete. 3. ready_for_delete checks if there is unacked message in outgoing state. Only exits from the run loop when there is non 4. When an actor is in mark_deleted state, stop processing incoming requests but always replies OK. So 2 mark_deleted actors won't form a dead loop. 5. response is processed normally so unacked messages can be ACKed. 6. 
management_request is processed normally so we can still dump actor state using swbus-cli --- crates/hamgrd/src/actors/ha_scope.rs | 41 +++++++++++++++++++-- crates/hamgrd/src/actors/ha_set.rs | 8 +++++ crates/hamgrd/src/actors/test.rs | 11 +++--- crates/hamgrd/src/actors/vdpu.rs | 3 +- crates/swbus-actor/src/driver.rs | 45 +++++++++++++++++++++--- crates/swbus-actor/src/state/outgoing.rs | 10 ++++-- 6 files changed, 103 insertions(+), 15 deletions(-) diff --git a/crates/hamgrd/src/actors/ha_scope.rs b/crates/hamgrd/src/actors/ha_scope.rs index 9b1b273..66cf6cd 100644 --- a/crates/hamgrd/src/actors/ha_scope.rs +++ b/crates/hamgrd/src/actors/ha_scope.rs @@ -821,7 +821,8 @@ mod test { "flow_reconcile_requested": "false" }, }, - addr: crate::common_bridge_sp::(&runtime.get_swbus_edge()) }, + addr: crate::common_bridge_sp::(&runtime.get_swbus_edge()) + }, // Write to NPU DASH_HA_SCOPE_STATE through internal state with no pending activation chkdb! { type: NpuDashHaScopeState, @@ -840,7 +841,23 @@ mod test { // Send vdpu state update after bfd session up send! { key: VDpuActorState::msg_key(&vdpu0_id), data: vdpu0_state_obj, addr: runtime.sp("vdpu", &vdpu0_id) }, - + // Recv update to DPU DASH_HA_SCOPE_TABLE, triggered by vdpu state update + recv! { key: &ha_set_id, data: { + "key": &ha_set_id, + "operation": "Set", + "field_values": { + "version": "3", + "ha_role": "active", + "disabled": "false", + "ha_set_id": &ha_set_id, + "vip_v4": ha_set_obj.vip_v4.clone(), + "vip_v6": ha_set_obj.vip_v6.clone(), + "activate_role_requested": "false", + "flow_reconcile_requested": "false" + }, + }, + addr: crate::common_bridge_sp::(&runtime.get_swbus_edge()) + }, // Write to NPU DASH_HA_SCOPE_STATE through internal state with bfd session up chkdb! { type: NpuDashHaScopeState, key: &scope_id_in_state, data: npu_ha_scope_state_fvs6, @@ -857,10 +874,26 @@ mod test { let commands = [ // Send DASH_HA_SCOPE_CONFIG_TABLE with desired_ha_state = dead send! 
{ key: HaScopeConfig::table_name(), data: { "key": &scope_id, "operation": "Set", - "field_values": {"json": format!(r#"{{"version":"2","disabled":false,"desired_ha_state":{},"owner":{},"ha_set_id":"{ha_set_id}","approved_pending_operation_ids":[]}}"#, DesiredHaState::Dead as i32, HaOwner::Dpu as i32)}, + "field_values": {"json": format!(r#"{{"version":"4","disabled":false,"desired_ha_state":{},"owner":{},"ha_set_id":"{ha_set_id}","approved_pending_operation_ids":[]}}"#, DesiredHaState::Dead as i32, HaOwner::Dpu as i32)}, }, addr: crate::common_bridge_sp::(&runtime.get_swbus_edge()) }, + recv! { key: &ha_set_id, data: { + "key": &ha_set_id, + "operation": "Set", + "field_values": { + "version": "4", + "ha_role": "dead", + "disabled": "false", + "ha_set_id": &ha_set_id, + "vip_v4": ha_set_obj.vip_v4.clone(), + "vip_v6": ha_set_obj.vip_v6.clone(), + "activate_role_requested": "false", + "flow_reconcile_requested": "false" + }, + }, + addr: crate::common_bridge_sp::(&runtime.get_swbus_edge()) + }, // Check NPU DASH_HA_SCOPE_STATE is updated with desired_ha_state = dead chkdb! { type: NpuDashHaScopeState, key: &scope_id_in_state, data: npu_ha_scope_state_fvs7, @@ -871,6 +904,8 @@ mod test { "field_values": {"json": format!(r#"{{"version":"2","disabled":false,"desired_ha_state":{},"owner":{},"ha_set_id":"{ha_set_id}","approved_pending_operation_ids":[]}}"#, DesiredHaState::Dead as i32, HaOwner::Dpu as i32)}, }, addr: crate::common_bridge_sp::(&runtime.get_swbus_edge()) }, + recv! { key: ActorRegistration::msg_key(RegistrationType::VDPUState, &scope_id), data: { "active": false }, addr: runtime.sp(VDpuActor::name(), &vdpu0_id) }, + recv! 
{ key: ActorRegistration::msg_key(RegistrationType::HaSetState, &scope_id), data: { "active": false }, addr: runtime.sp(HaSetActor::name(), &ha_set_id) }, ]; test::run_commands(&runtime, runtime.sp(HaScopeActor::name(), &scope_id), &commands).await; diff --git a/crates/hamgrd/src/actors/ha_set.rs b/crates/hamgrd/src/actors/ha_set.rs index f76e88a..fa8db72 100644 --- a/crates/hamgrd/src/actors/ha_set.rs +++ b/crates/hamgrd/src/actors/ha_set.rs @@ -516,6 +516,10 @@ mod test { // simulate delete of ha-set entry send! { key: HaSetActor::table_name(), data: { "key": HaSetActor::table_name(), "operation": "Del", "field_values": ha_set_cfg_fvs }, addr: crate::common_bridge_sp::(&runtime.get_swbus_edge()) }, + recv! { key: ActorRegistration::msg_key(RegistrationType::VDPUState, &ha_set_id), data: { "active": false }, + addr: runtime.sp(VDpuActor::name(), &vdpu0_id) }, + recv! { key: ActorRegistration::msg_key(RegistrationType::VDPUState, &ha_set_id), data: { "active": false }, + addr: runtime.sp(VDpuActor::name(), &vdpu1_id) }, ]; test::run_commands(&runtime, runtime.sp(HaSetActor::name(), &ha_set_id), &commands).await; @@ -591,6 +595,10 @@ mod test { // simulate delete of ha-set entry send! { key: HaSetActor::table_name(), data: { "key": HaSetActor::table_name(), "operation": "Del", "field_values": ha_set_cfg_fvs }, addr: crate::common_bridge_sp::(&runtime.get_swbus_edge()) }, + recv! { key: ActorRegistration::msg_key(RegistrationType::VDPUState, &ha_set_id), data: { "active": false }, + addr: runtime.sp(VDpuActor::name(), &vdpu0_id) }, + recv! 
{ key: ActorRegistration::msg_key(RegistrationType::VDPUState, &ha_set_id), data: { "active": false }, + addr: runtime.sp(VDpuActor::name(), &vdpu1_id) }, ]; test::run_commands(&runtime, runtime.sp(HaSetActor::name(), &ha_set_id), &commands).await; diff --git a/crates/hamgrd/src/actors/test.rs b/crates/hamgrd/src/actors/test.rs index 6d501c5..74aa188 100644 --- a/crates/hamgrd/src/actors/test.rs +++ b/crates/hamgrd/src/actors/test.rs @@ -136,7 +136,8 @@ pub async fn run_commands(runtime: &ActorRuntime, aut: ServicePath, commands: &[ } // Execute commands - for cmd in commands { + for (index, cmd) in commands.iter().enumerate() { + let step = index + 1; match cmd { Send { key, data, addr, fail } => { let client = &clients[addr]; @@ -149,7 +150,7 @@ pub async fn run_commands(runtime: &ActorRuntime, aut: ServicePath, commands: &[ }; let sent_id = client.send(msg).await.unwrap(); - print!("Sent {key}, "); + print!("Step {step} - Sent {key}, "); if *fail { print!("expecting Fail, "); @@ -187,7 +188,7 @@ pub async fn run_commands(runtime: &ActorRuntime, aut: ServicePath, commands: &[ Recv { key, data, addr } => { let client = &clients[addr]; - print!("Receiving {key}, "); + print!("Step {step} - Receiving {key}, "); let (am, request_id) = match timeout(client.recv()).await { Ok(Some(IncomingMessage { body: MessageBody::Request { payload }, @@ -225,6 +226,7 @@ pub async fn run_commands(runtime: &ActorRuntime, aut: ServicePath, commands: &[ data, exclude, } => { + print!("Step {step} - Checking {table_name}/{db_name} for key {key}, "); let db = crate::db_named(db_name, *is_dpu).await.unwrap(); let mut table = Table::new(db, table_name).unwrap(); @@ -249,6 +251,7 @@ pub async fn run_commands(runtime: &ActorRuntime, aut: ServicePath, commands: &[ if actual_data == fvs { // Success, break out of retry loop + println!("found"); break; } else { last_error = Some(format!( @@ -272,7 +275,7 @@ pub async fn run_commands(runtime: &ActorRuntime, aut: ServicePath, commands: &[ } // If 
we got here and there's still an error, panic on the last attempt if let Some(error) = last_error { - panic!("{error}"); + panic!("Step {step} - {error}"); } } } diff --git a/crates/hamgrd/src/actors/vdpu.rs b/crates/hamgrd/src/actors/vdpu.rs index 33e0f12..04033cb 100644 --- a/crates/hamgrd/src/actors/vdpu.rs +++ b/crates/hamgrd/src/actors/vdpu.rs @@ -192,7 +192,8 @@ mod test { send! { key: VDpuActor::table_name(), data: { "key": VDpuActor::table_name(), "operation": "Del", "field_values": {"main_dpu_ids": "switch1_dpu0"}}, addr: crate::common_bridge_sp::(&runtime.get_swbus_edge()) }, - + recv! { key: ActorRegistration::msg_key(RegistrationType::DPUState, "test-vdpu"), data: { "active": false }, + addr: runtime.sp(DpuActor::name(), "switch1_dpu0") }, ]; test::run_commands(&runtime, runtime.sp(VDpuActor::name(), "test-vdpu"), &commands).await; diff --git a/crates/swbus-actor/src/driver.rs b/crates/swbus-actor/src/driver.rs index 5254675..a34f93d 100644 --- a/crates/swbus-actor/src/driver.rs +++ b/crates/swbus-actor/src/driver.rs @@ -13,6 +13,7 @@ pub(crate) struct ActorDriver { state: State, swbus_edge: Arc, context: Context, + mark_deleted: bool, } impl ActorDriver { @@ -24,9 +25,15 @@ impl ActorDriver { state: State::new(swbus_edge.clone()), swbus_edge, context: Context::new(edge_runtime), + mark_deleted: false, } } + /// Check if the actor is ready for deletion (no unacked messages) + fn ready_for_delete(&self) -> bool { + self.state.outgoing.ready_for_delete() + } + /// Run the actor's main loop pub(crate) async fn run(mut self) { self.actor.init(&mut self.state).await.unwrap(); @@ -46,11 +53,21 @@ impl ActorDriver { } } if self.context.stopped { - info!( - "actor {} terminated", - self.swbus_edge.get_service_path().to_longest_path() - ); - break; + self.mark_deleted = true; + } + if self.mark_deleted { + if self.ready_for_delete() { + info!( + "actor {} terminated", + self.swbus_edge.get_service_path().to_longest_path() + ); + break; + } else { + debug!( + 
"actor {} marked for deletion, waiting for unacked messages to complete", + self.swbus_edge.get_service_path().to_longest_path() + ); + } } } } @@ -60,6 +77,24 @@ impl ActorDriver { let IncomingMessage { id, source, body, .. } = msg; match body { MessageBody::Request { payload } => { + if self.mark_deleted { + // Actor is marked for deletion, send OK response but don't process + debug!("Actor marked for deletion, skipping request processing"); + self.swbus_edge + .send(OutgoingMessage { + destination: source.clone(), + body: MessageBody::Response { + request_id: id, + error_code: SwbusErrorCode::Ok, + error_message: "Request ignored: actor has been marked for deletion".to_string(), + response_body: None, + }, + }) + .await + .expect("failed to send swbus message"); + return; + } + let Ok(actor_msg) = ActorMessage::deserialize(&payload) else { eprintln!("Received invalid actor message from {source}"); return; diff --git a/crates/swbus-actor/src/state/outgoing.rs b/crates/swbus-actor/src/state/outgoing.rs index 24215b3..2d30a09 100644 --- a/crates/swbus-actor/src/state/outgoing.rs +++ b/crates/swbus-actor/src/state/outgoing.rs @@ -1,3 +1,4 @@ +use super::get_unix_time; use crate::actor_message::{actor_msg_to_swbus_msg, ActorMessage}; use serde::{Deserialize, Serialize}; use std::{ @@ -10,8 +11,7 @@ use swbus_edge::{ swbus_proto::swbus::{ServicePath, SwbusErrorCode, SwbusMessage}, }; use tokio::time::{interval, Interval}; - -use super::get_unix_time; +use tracing::debug; const RESEND_TIME: Duration = Duration::from_secs(60); @@ -59,6 +59,7 @@ impl Outgoing { /// Actor logic succeeded, so send out messages. pub(crate) async fn send_queued_messages(&mut self) { for msg in self.queued_messages.drain(..) 
{ + debug!("Sending message: {msg:?}"); self.swbus_client .send_raw(msg.swbus_message.clone()) .await @@ -150,6 +151,11 @@ impl Outgoing { self.from_my_sp("swss-common-bridge", &resource_id) } + /// Check if there are no unacked messages + pub fn ready_for_delete(&self) -> bool { + self.unacked_messages.is_empty() + } + pub(crate) fn dump_state(&self) -> OutgoingStateData { let state_data = OutgoingStateData { outgoing_queued: self.queued_messages.clone(), From 489637c2d2716dcacda65e91f486e40928494d1a Mon Sep 17 00:00:00 2001 From: yue-fred-gao <132678244+yue-fred-gao@users.noreply.github.com> Date: Tue, 9 Sep 2025 03:00:12 -0400 Subject: [PATCH 07/14] fix show hamgrd actor command (#108) ### why show hamgrd actor command is broken after route_exchange PR is merged. In log we can see this below error Sep 5 12:11:25 ott-ss-010 swbusd: 2025-09-05T16:11:25.820184Z ERROR ConnWorker{conn_id="swbs-from://127.0.0.1:39642"}: 96: Failed to process the incoming message: Input:InvalidArgs - Invalid management request: ManagementRequest { request: HamgrdGetActorState, arguments: [] } This is because swbusd incorrectly intercepting all ManagementRequest. ### What this PR does 1. check if the ManagementRequest has swbusd's service path as destination. If not, route the message 2. fix some misc issues exposed after above code change. 3. use init_logger_for_test from Logger and remove the proprietary implementation. 
--- crates/swbus-core/src/mux/conn_worker.rs | 13 +++++++-- crates/swbus-core/tests/basic_tests.rs | 12 ++------ .../swbus-core/tests/common/test_executor.rs | 29 ------------------- .../data/inter-cluster/test_show_route.json | 2 +- .../swbus-core/tests/inter_cluster_tests.rs | 12 ++------ 5 files changed, 17 insertions(+), 51 deletions(-) diff --git a/crates/swbus-core/src/mux/conn_worker.rs b/crates/swbus-core/src/mux/conn_worker.rs index 26dde13..7a38ed2 100644 --- a/crates/swbus-core/src/mux/conn_worker.rs +++ b/crates/swbus-core/src/mux/conn_worker.rs @@ -152,9 +152,16 @@ where debug!("Received route announcement"); self.mux.process_route_announcement(route_entries, &self.info)?; } - Some(swbus_message::Body::ManagementRequest(mgmt_request)) => { - let response = self.process_mgmt_request(message.header.as_ref().unwrap(), mgmt_request)?; - self.mux.route_message(response).await?; + Some(swbus_message::Body::ManagementRequest(ref mgmt_request)) => { + let my_sp = self.mux.get_my_service_path(); + if message.header.as_ref().unwrap().destination.as_ref().unwrap() == my_sp { + // Message is destined for this service, process it locally + let response = self.process_mgmt_request(message.header.as_ref().unwrap(), mgmt_request.clone())?; + self.mux.route_message(response).await?; + } else { + // Message is not for us, route it to the intended destination + self.mux.route_message(message).await?; + } } _ => { self.mux.route_message(message).await?; diff --git a/crates/swbus-core/tests/basic_tests.rs b/crates/swbus-core/tests/basic_tests.rs index 7dfb2c4..074d653 100644 --- a/crates/swbus-core/tests/basic_tests.rs +++ b/crates/swbus-core/tests/basic_tests.rs @@ -1,15 +1,9 @@ mod common; -use common::test_executor::{init_logger, run_tests, TopoRuntime}; - +use common::test_executor::{run_tests, TopoRuntime}; +use sonic_common::log::init_logger_for_test; #[tokio::test] async fn test_b2b() { - let trace_enabled: bool = std::env::var("ENABLE_TRACE") - .map(|val| val == 
"1" || val.eq_ignore_ascii_case("true")) - .unwrap_or(false); - - if trace_enabled { - init_logger(); - } + init_logger_for_test(); let mut topo = TopoRuntime::new("tests/data/b2b/topo.json"); topo.bring_up().await; diff --git a/crates/swbus-core/tests/common/test_executor.rs b/crates/swbus-core/tests/common/test_executor.rs index cb0ff8b..67b19a8 100644 --- a/crates/swbus-core/tests/common/test_executor.rs +++ b/crates/swbus-core/tests/common/test_executor.rs @@ -1,5 +1,4 @@ use serde::{Deserialize, Serialize}; -use sonic_common::log::init_logger_for_test; use std::collections::HashMap; use std::env; use std::fs::{self, File}; @@ -12,7 +11,6 @@ use swbus_proto::swbus::*; use tokio::sync::mpsc; use tokio::task::JoinHandle; use tokio::time::{self, Duration, Instant}; -use tracing_subscriber::{fmt, prelude::*, Layer}; // 3 seconds receive timeout pub const RECEIVE_TIMEOUT: u32 = 3; @@ -82,8 +80,6 @@ impl TopoRuntime { /// The client configurations are a map of client names to the client configuration. /// The client configuration contains the server (swbusd) name where the client is connected and the service path of the client. pub async fn bring_up(&mut self) { - init_logger_for_test(); - let file = File::open(&self.topo_file).unwrap(); let reader = BufReader::new(file); @@ -250,28 +246,3 @@ async fn record_received_messages(topo: &mut TopoRuntime, timeout: u32) -> Vec Date: Tue, 9 Sep 2025 01:00:44 -0600 Subject: [PATCH 08/14] Fix ha state. (#107) local_ha_state was being set to ha_role, update so it is correctly set to the DPU's ha_state. 
This fixes issue #91 --- crates/hamgrd/src/actors/ha_scope.rs | 4 ++-- crates/hamgrd/src/actors/test.rs | 2 ++ crates/hamgrd/src/db_structs.rs | 2 ++ 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/crates/hamgrd/src/actors/ha_scope.rs b/crates/hamgrd/src/actors/ha_scope.rs index 66cf6cd..b3f0610 100644 --- a/crates/hamgrd/src/actors/ha_scope.rs +++ b/crates/hamgrd/src/actors/ha_scope.rs @@ -404,7 +404,7 @@ impl HaScopeActor { }; // in dpu driven mode, local_ha_state is same as dpu acked ha state - npu_ha_scope_state.local_ha_state = Some(dpu_ha_scope_state.ha_role.clone()); + npu_ha_scope_state.local_ha_state = Some(dpu_ha_scope_state.ha_state.clone()); npu_ha_scope_state.local_ha_state_last_updated_time_in_ms = Some(dpu_ha_scope_state.ha_role_start_time); // The reason of the last HA state change. npu_ha_scope_state.local_ha_state_last_updated_reason = Some("dpu initiated".to_string()); @@ -418,7 +418,7 @@ impl HaScopeActor { .to_lowercase(), ); // The HA state that ASIC acked. - npu_ha_scope_state.local_acked_asic_ha_state = Some(dpu_ha_scope_state.ha_role.clone()); + npu_ha_scope_state.local_acked_asic_ha_state = Some(dpu_ha_scope_state.ha_state.clone()); // The current target term of the HA state machine. in dpu-driven mode, use the term acked by asic npu_ha_scope_state.local_target_term = Some(dpu_ha_scope_state.ha_term.clone()); diff --git a/crates/hamgrd/src/actors/test.rs b/crates/hamgrd/src/actors/test.rs index 74aa188..77ae80f 100644 --- a/crates/hamgrd/src/actors/test.rs +++ b/crates/hamgrd/src/actors/test.rs @@ -594,6 +594,8 @@ pub fn make_dpu_ha_scope_state(role: &str) -> DpuDashHaScopeState { ha_role_start_time: now_in_millis(), // The current term confirmed by ASIC. ha_term: "1".to_string(), + // The DPU HA state. 
+ ha_state: role.to_string(), activate_role_pending: false, flow_reconcile_pending: false, brainsplit_recover_pending: false, diff --git a/crates/hamgrd/src/db_structs.rs b/crates/hamgrd/src/db_structs.rs index 3c679d1..230c209 100644 --- a/crates/hamgrd/src/db_structs.rs +++ b/crates/hamgrd/src/db_structs.rs @@ -309,6 +309,8 @@ pub struct DpuDashHaScopeState { pub ha_role_start_time: i64, // The current term confirmed by ASIC. pub ha_term: String, + // The DPU ha state. + pub ha_state: String, // DPU is pending on role activation. #[serde(default)] pub activate_role_pending: bool, From b0e267245aa89b5cdfd0b0f0dcba78e1add4099a Mon Sep 17 00:00:00 2001 From: yue-fred-gao <132678244+yue-fred-gao@users.noreply.github.com> Date: Tue, 9 Sep 2025 13:07:38 -0400 Subject: [PATCH 09/14] Implement cleanup logic for all the actors (#102) ### why This addresses issue #100. When upstream deletes the DB entry that is the originator of the actor, the actor should cleanup all the db entries it has created before terminating itself. For example, deleting DashHaSetConfig entry should triggers the cleanup actor in the corresponding HaSetActor, which includes removing DASH_HA_SET_TABLE it creates in DPU_APPL_DB and VNET_ROUTE_TUNNEL_TABLE in APPL_DB. ### what this PR does 1. Implements cleanup for all the actors. - DpuActor: remove entries in DPU_APPL_DB/BFD_SESSION_TABLE - VDpuActor: unregister from DpuActor - HASetActor: remove entry from DPU_APPL_DB/DASH_HA_SET_TABLE, remove entry from APPL_DB/VNET_ROUTE_TUNNEL_TABLE, unregistered from VDpuActor - HAScopeActor: remove entry from DPU_APPL_DB/DASH_HA_SCOPE_TABLE, remove entry from STATE_DB/DASH_HA_SCOPE_STATE and unregister from VDpuActor and HaSetActor 3. Extend ChkDb macro to check a db entry doesn't exist 4. 
Extend Internal state with deleting an entry from db --- crates/hamgrd/src/actors/dpu.rs | 68 +++++++--- crates/hamgrd/src/actors/ha_scope.rs | 47 ++++++- crates/hamgrd/src/actors/ha_set.rs | 65 ++++++++- crates/hamgrd/src/actors/test.rs | 83 +++++++++--- crates/hamgrd/src/actors/vdpu.rs | 14 +- crates/swbus-actor/src/state/internal.rs | 44 ++++++- crates/swbus-actor/tests/kvstore.rs | 160 ++++++++++++++++++++++- 7 files changed, 414 insertions(+), 67 deletions(-) diff --git a/crates/hamgrd/src/actors/dpu.rs b/crates/hamgrd/src/actors/dpu.rs index 1ee801a..f149b67 100644 --- a/crates/hamgrd/src/actors/dpu.rs +++ b/crates/hamgrd/src/actors/dpu.rs @@ -6,7 +6,7 @@ use crate::ha_actor_messages::{ActorRegistration, DpuActorState, RegistrationTyp use crate::ServicePath; use anyhow::{anyhow, Result}; use sonic_common::SonicDbTable; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use std::sync::Arc; use swbus_actor::{state::incoming::Incoming, state::outgoing::Outgoing, Actor, ActorMessage, Context, State}; use swbus_edge::SwbusEdgeRuntime; @@ -125,10 +125,17 @@ impl DpuActor { Ok(bridges) } + fn do_cleanup(&mut self, _context: &mut Context, state: &mut State) { + if let Err(e) = self.update_bfd_sessions(state, true) { + error!("Failed to cleanup BFD sessions: {}", e); + } + } + async fn handle_dpu_message(&mut self, state: &mut State, key: &str, context: &mut Context) -> Result<()> { let (_internal, incoming, outgoing) = state.get_all(); let dpu_kfv: KeyOpFieldValues = incoming.get(key)?.deserialize_data()?; if dpu_kfv.operation == KeyOperation::Del { + self.do_cleanup(context, state); context.stop(); return Ok(()); } @@ -316,28 +323,40 @@ impl DpuActor { peer_ip: &str, global_cfg: &DashHaGlobalConfig, outgoing: &mut Outgoing, + remove: bool, ) -> Result<()> { // todo: this needs to wait until HaScope has been activated. let Some(DpuData::LocalDpu { ref dpu, .. }) = self.dpu else { debug!("DPU is not managed by this HA instance. 
Ignore BFD session creation"); return Ok(()); }; - let bfd_session = BfdSessionTable { - tx_interval: global_cfg.dpu_bfd_probe_interval_in_ms, - rx_interval: global_cfg.dpu_bfd_probe_interval_in_ms, - multiplier: global_cfg.dpu_bfd_probe_multiplier, - multihop: true, - local_addr: dpu.pa_ipv4.clone(), - session_type: Some("passive".to_string()), - shutdown: false, - }; - let fv = swss_serde::to_field_values(&bfd_session)?; let sep = BfdSessionTable::key_separator(); - let kfv = KeyOpFieldValues { - key: format!("default{sep}default{sep}{peer_ip}"), - operation: KeyOperation::Set, - field_values: fv, + let key = format!("default{sep}default{sep}{peer_ip}"); + + let kfv = if remove { + KeyOpFieldValues { + key, + operation: KeyOperation::Del, + field_values: HashMap::new(), + } + } else { + let bfd_session = BfdSessionTable { + tx_interval: global_cfg.dpu_bfd_probe_interval_in_ms, + rx_interval: global_cfg.dpu_bfd_probe_interval_in_ms, + multiplier: global_cfg.dpu_bfd_probe_multiplier, + multihop: true, + local_addr: dpu.pa_ipv4.clone(), + session_type: Some("passive".to_string()), + shutdown: false, + }; + + let fv = swss_serde::to_field_values(&bfd_session)?; + KeyOpFieldValues { + key, + operation: KeyOperation::Set, + field_values: fv, + } }; let msg = ActorMessage::new(self.id.clone(), &kfv)?; @@ -345,9 +364,9 @@ impl DpuActor { Ok(()) } - fn update_bfd_sessions(&self, state: &mut State) -> Result<()> { + fn update_bfd_sessions(&self, state: &mut State, remove: bool) -> Result<()> { if !self.is_local_managed() { - debug!("DPU is not managed by this HA instance. Ignore BFD session creation"); + debug!("DPU is not managed by this HA instance. 
Ignore BFD session creation or deletion"); return Ok(()); } let (_internal, incoming, outgoing) = state.get_all(); @@ -375,7 +394,7 @@ impl DpuActor { remote_npus.push(npu_ipv4.clone()); remote_npus.sort(); for npu in remote_npus { - self.update_bfd_session(&npu, &global_cfg, outgoing)?; + self.update_bfd_session(&npu, &global_cfg, outgoing, remove)?; } Ok(()) } @@ -408,7 +427,7 @@ impl DpuActor { // create bfd session let global_cfg = Self::get_dash_ha_global_config(incoming)?; - self.update_bfd_session(&remote_dpu.npu_ipv4, &global_cfg, outgoing)?; + self.update_bfd_session(&remote_dpu.npu_ipv4, &global_cfg, outgoing, false)?; Ok(()) } @@ -421,7 +440,7 @@ impl DpuActor { } fn handle_dash_ha_global_config(&mut self, state: &mut State) -> Result<()> { - self.update_bfd_sessions(state)?; + self.update_bfd_sessions(state, false)?; Ok(()) } @@ -566,6 +585,15 @@ mod test { // simulate delete of Dpu entry send! { key: Dpu::table_name(), data: { "key": DpuActor::dpu_table_name(), "operation": "Del", "field_values": dpu_fvs}, addr: crate::common_bridge_sp::(&runtime.get_swbus_edge()) }, + + recv! { key: "switch0_dpu0", data: {"key": "default:default:10.0.0.0", "operation": "Del", "field_values": {}}, + addr: crate::common_bridge_sp::(&runtime.get_swbus_edge()) }, + recv! { key: "switch0_dpu0", data: {"key": "default:default:10.0.1.0", "operation": "Del", "field_values": {}}, + addr: crate::common_bridge_sp::(&runtime.get_swbus_edge()) }, + recv! { key: "switch0_dpu0", data: {"key": "default:default:10.0.2.0", "operation": "Del", "field_values": {}}, + addr: crate::common_bridge_sp::(&runtime.get_swbus_edge()) }, + recv! 
{ key: "switch0_dpu0", data: {"key": "default:default:10.0.3.0", "operation": "Del", "field_values": {}}, + addr: crate::common_bridge_sp::(&runtime.get_swbus_edge()) }, ]; test::run_commands(&runtime, runtime.sp("dpu", "switch0_dpu0"), &commands).await; if tokio::time::timeout(Duration::from_secs(1), handle).await.is_err() { diff --git a/crates/hamgrd/src/actors/ha_scope.rs b/crates/hamgrd/src/actors/ha_scope.rs index b3f0610..fa03246 100644 --- a/crates/hamgrd/src/actors/ha_scope.rs +++ b/crates/hamgrd/src/actors/ha_scope.rs @@ -189,6 +189,38 @@ impl HaScopeActor { Ok(()) } + fn delete_dash_ha_scope_table(&self, outgoing: &mut Outgoing) -> Result<()> { + let kfv = KeyOpFieldValues { + key: self.ha_scope_id.clone(), + operation: KeyOperation::Del, + field_values: HashMap::new(), + }; + + let msg = ActorMessage::new(self.ha_scope_id.clone(), &kfv)?; + outgoing.send(outgoing.common_bridge_sp::(), msg); + + Ok(()) + } + + fn delete_npu_ha_scope_state(&self, internal: &mut Internal) -> Result<()> { + if self.dash_ha_scope_config.is_none() { + return Ok(()); + }; + + internal.delete(NpuDashHaScopeState::table_name()); + + Ok(()) + } + + fn do_cleanup(&mut self, state: &mut State) -> Result<()> { + let (internal, _incoming, outgoing) = state.get_all(); + self.delete_dash_ha_scope_table(outgoing)?; + self.delete_npu_ha_scope_state(internal)?; + self.register_to_vdpu_actor(outgoing, false)?; + self.register_to_haset_actor(outgoing, false)?; + Ok(()) + } + fn update_dpu_ha_scope_table(&self, state: &mut State) -> Result<()> { let Some(dash_ha_scope_config) = self.dash_ha_scope_config.as_ref() else { return Ok(()); @@ -449,9 +481,10 @@ impl HaScopeActor { let kfv: KeyOpFieldValues = incoming.get(key)?.deserialize_data()?; if kfv.operation == KeyOperation::Del { - // unregister from the vDPU Actor and ha-set actor - self.register_to_vdpu_actor(outgoing, false)?; - self.register_to_haset_actor(outgoing, false)?; + // cleanup resources before stopping + if let Err(e) = 
self.do_cleanup(state) { + error!("Failed to cleanup HaScopeActor resources: {}", e); + } context.stop(); return Ok(()); } @@ -904,8 +937,16 @@ mod test { "field_values": {"json": format!(r#"{{"version":"2","disabled":false,"desired_ha_state":{},"owner":{},"ha_set_id":"{ha_set_id}","approved_pending_operation_ids":[]}}"#, DesiredHaState::Dead as i32, HaOwner::Dpu as i32)}, }, addr: crate::common_bridge_sp::(&runtime.get_swbus_edge()) }, + + // Verify that cleanup removed the NPU DASH_HA_SCOPE_STATE table entry + chkdb! { type: NpuDashHaScopeState, key: &scope_id_in_state, nonexist }, + + // Recv delete of DPU DASH_HA_SCOPE_TABLE + recv! { key: &ha_set_id, data: { "key": &ha_set_id, "operation": "Del", "field_values": {} }, + addr: crate::common_bridge_sp::(&runtime.get_swbus_edge()) }, recv! { key: ActorRegistration::msg_key(RegistrationType::VDPUState, &scope_id), data: { "active": false }, addr: runtime.sp(VDpuActor::name(), &vdpu0_id) }, recv! { key: ActorRegistration::msg_key(RegistrationType::HaSetState, &scope_id), data: { "active": false }, addr: runtime.sp(HaSetActor::name(), &ha_set_id) }, + ]; test::run_commands(&runtime, runtime.sp(HaScopeActor::name(), &scope_id), &commands).await; diff --git a/crates/hamgrd/src/actors/ha_set.rs b/crates/hamgrd/src/actors/ha_set.rs index fa8db72..7a742aa 100644 --- a/crates/hamgrd/src/actors/ha_set.rs +++ b/crates/hamgrd/src/actors/ha_set.rs @@ -7,6 +7,7 @@ use sonic_common::SonicDbTable; use sonic_dash_api_proto::decode_from_field_values; use sonic_dash_api_proto::ha_set_config::HaSetConfig; use sonic_dash_api_proto::ip_to_string; +use std::collections::HashMap; use swbus_actor::{ state::{incoming::Incoming, internal::Internal, outgoing::Outgoing}, Actor, ActorMessage, Context, State, @@ -149,6 +150,24 @@ impl HaSetActor { Ok(()) } + fn delete_dash_ha_set_table(&self, vdpus: &[VDpuStateExt], outgoing: &mut Outgoing) -> Result<()> { + if !vdpus.iter().any(|vdpu_ext| vdpu_ext.vdpu.dpu.is_managed) { + debug!("None of DPUs 
is managed by local HAMGRD. Skip dash_ha_set deletion"); + return Ok(()); + } + + let kfv = KeyOpFieldValues { + key: self.id.clone(), + operation: KeyOperation::Del, + field_values: HashMap::new(), + }; + + let msg = ActorMessage::new(self.id.clone(), &kfv)?; + outgoing.send(outgoing.common_bridge_sp::(), msg); + + Ok(()) + } + async fn update_vnet_route_tunnel_table( &self, vdpus: &Vec, @@ -213,7 +232,12 @@ impl HaSetActor { Ok(()) } - async fn register_to_vdpu_actor(&self, outgoing: &mut Outgoing, active: bool) -> Result<()> { + fn delete_vnet_route_tunnel_table(&self, internal: &mut Internal) -> Result<()> { + internal.delete(VnetRouteTunnelTable::table_name()); + Ok(()) + } + + fn register_to_vdpu_actor(&self, outgoing: &mut Outgoing, active: bool) -> Result<()> { let Some(ref dash_ha_set_config) = self.dash_ha_set_config else { return Ok(()); }; @@ -304,9 +328,10 @@ impl HaSetActor { let (_internal, incoming, outgoing) = state.get_all(); let dpu_kfv: KeyOpFieldValues = incoming.get(key)?.deserialize_data()?; if dpu_kfv.operation == KeyOperation::Del { - // unregister from the DPU Actor - self.register_to_vdpu_actor(outgoing, false).await?; - + // cleanup resources before stopping + if let Err(e) = self.do_cleanup(state) { + error!("Failed to cleanup HaSetActor resources: {}", e); + } context.stop(); return Ok(()); } @@ -315,7 +340,7 @@ impl HaSetActor { self.dash_ha_set_config = Some(decode_from_field_values(&dpu_kfv.field_values).unwrap()); // Subscribe to the DPU Actor for state updates. 
- self.register_to_vdpu_actor(outgoing, true).await?; + self.register_to_vdpu_actor(outgoing, true)?; if first_time { self.bridges.push( @@ -379,6 +404,26 @@ impl HaSetActor { } Ok(()) } + + fn do_cleanup(&mut self, state: &mut State) -> Result<()> { + let (internal, incoming, outgoing) = state.get_all(); + + let Some(vdpus) = self.get_vdpus_if_ready(incoming) else { + debug!("Not all DPU info is ready for cleanup"); + return Ok(()); + }; + + if let Err(e) = self.delete_dash_ha_set_table(&vdpus, outgoing) { + error!("Failed to delete dash_ha_set_table: {}", e); + } + + if let Err(e) = self.delete_vnet_route_tunnel_table(internal) { + error!("Failed to delete vnet_route_tunnel_table: {}", e); + } + + self.register_to_vdpu_actor(outgoing, false)?; + Ok(()) + } } impl Actor for HaSetActor { @@ -512,10 +557,13 @@ mod test { // Verify that haset actor state is sent to ha-scope actor recv! { key: HaSetActorState::msg_key(&ha_set_id), data: { "up": true, "ha_set": &ha_set_obj }, addr: runtime.sp("ha-scope", &format!("vdpu0:{ha_set_id}")) }, - chkdb! { type: VnetRouteTunnelTable, key: &format!("{}:{}", global_cfg.vnet_name.unwrap(), ip_to_string(&ha_set_cfg.vip_v4.unwrap())), data: expected_vnet_route }, + chkdb! { type: VnetRouteTunnelTable, key: &format!("{}:{}", global_cfg.vnet_name.as_ref().unwrap(), ip_to_string(ha_set_cfg.vip_v4.as_ref().unwrap())), data: expected_vnet_route }, // simulate delete of ha-set entry send! { key: HaSetActor::table_name(), data: { "key": HaSetActor::table_name(), "operation": "Del", "field_values": ha_set_cfg_fvs }, addr: crate::common_bridge_sp::(&runtime.get_swbus_edge()) }, + recv! { key: &ha_set_id, data: {"key": &ha_set_id, "operation": "Del", "field_values": {}}, + addr: crate::common_bridge_sp::(&runtime.get_swbus_edge()) }, + chkdb! { type: VnetRouteTunnelTable, key: &format!("{}:{}", global_cfg.vnet_name.as_ref().unwrap(), ip_to_string(ha_set_cfg.vip_v4.as_ref().unwrap())), nonexist }, recv! 
{ key: ActorRegistration::msg_key(RegistrationType::VDPUState, &ha_set_id), data: { "active": false }, addr: runtime.sp(VDpuActor::name(), &vdpu0_id) }, recv! { key: ActorRegistration::msg_key(RegistrationType::VDPUState, &ha_set_id), data: { "active": false }, @@ -590,13 +638,16 @@ mod test { send! { key: VDpuActorState::msg_key(&vdpu1_id), data: vdpu1_state, addr: runtime.sp("vdpu", &vdpu1_id) }, // Verify that the DASH_HA_SET_TABLE was updated - chkdb! { type: VnetRouteTunnelTable, key: &format!("{}:{}", global_cfg.vnet_name.unwrap(), ip_to_string(&ha_set_cfg.vip_v4.unwrap())), + chkdb! { type: VnetRouteTunnelTable, key: &format!("{}:{}", global_cfg.vnet_name.as_ref().unwrap(), ip_to_string(ha_set_cfg.vip_v4.as_ref().unwrap())), data: expected_vnet_route }, // simulate delete of ha-set entry send! { key: HaSetActor::table_name(), data: { "key": HaSetActor::table_name(), "operation": "Del", "field_values": ha_set_cfg_fvs }, addr: crate::common_bridge_sp::(&runtime.get_swbus_edge()) }, + chkdb! { type: VnetRouteTunnelTable, key: &format!("{}:{}", global_cfg.vnet_name.as_ref().unwrap(), ip_to_string(ha_set_cfg.vip_v4.as_ref().unwrap())), nonexist }, + recv! { key: ActorRegistration::msg_key(RegistrationType::VDPUState, &ha_set_id), data: { "active": false }, addr: runtime.sp(VDpuActor::name(), &vdpu0_id) }, + recv! { key: ActorRegistration::msg_key(RegistrationType::VDPUState, &ha_set_id), data: { "active": false }, addr: runtime.sp(VDpuActor::name(), &vdpu1_id) }, ]; diff --git a/crates/hamgrd/src/actors/test.rs b/crates/hamgrd/src/actors/test.rs index 77ae80f..e254f6a 100644 --- a/crates/hamgrd/src/actors/test.rs +++ b/crates/hamgrd/src/actors/test.rs @@ -68,6 +68,7 @@ macro_rules! chkdb { key: String::from($key), data: serde_json::json!($data), exclude: "".to_string(), + nonexist: false, } }; @@ -79,6 +80,19 @@ macro_rules! 
chkdb { key: String::from($key), data: serde_json::json!($data), exclude: String::from($exclude), + nonexist: false, + } + }; + + (type: $type:ty, key: $key:expr, nonexist) => { + $crate::actors::test::Command::ChkDb { + db: String::from(<$type>::db_name()), + is_dpu: <$type>::is_dpu(), + table: String::from(<$type>::table_name()), + key: String::from($key), + data: serde_json::json!({}), + exclude: "".to_string(), + nonexist: true, } }; @@ -90,6 +104,19 @@ macro_rules! chkdb { key: String::from($key), data: serde_json::json!($data), exclude: String::from($exclude), + nonexist: false, + } + }; + + (db: $db:expr, is_dpu: $is_dpu:expr, table: $table:expr, key: $key:expr, nonexist) => { + $crate::actors::test::Command::ChkDb { + db: String::from($db), + is_dpu: $is_dpu, + table: String::from($table), + key: String::from($key), + data: serde_json::json!({}), + exclude: "".to_string(), + nonexist: true, } }; } @@ -114,6 +141,7 @@ pub enum Command { key: String, data: Value, exclude: String, + nonexist: bool, }, } @@ -225,44 +253,56 @@ pub async fn run_commands(runtime: &ActorRuntime, aut: ServicePath, commands: &[ key, data, exclude, + nonexist, } => { print!("Step {step} - Checking {table_name}/{db_name} for key {key}, "); let db = crate::db_named(db_name, *is_dpu).await.unwrap(); let mut table = Table::new(db, table_name).unwrap(); let mut last_error = None; - + print!("Checking DB {db_name}/{table_name} for key {key}, "); // Retry loop: 5 attempts with 100ms sleep between retries. This is needed because the previous send operation // may not have been fully committed yet. send is asynchronous and may not complete immediately. 
for attempt in 1..=5 { last_error = None; match table.get_async(key).await { Ok(Some(mut actual_data)) => { - let mut fvs: FieldValues = serde_json::from_value(data.clone()).unwrap(); - - exclude - .split(',') - .map(str::trim) - .filter(|s| !s.is_empty()) - .for_each(|id| { - fvs.remove(id); - actual_data.remove(id); - }); - - if actual_data == fvs { - // Success, break out of retry loop - println!("found"); - break; - } else { + if *nonexist { last_error = Some(format!( - "Data mismatch on attempt {attempt}: expected {fvs:?}, got {actual_data:?}" + "Key {key} unexpectedly found in {table_name}/{db_name} on attempt {attempt} (expected nonexist)" )); + } else { + let mut fvs: FieldValues = serde_json::from_value(data.clone()).unwrap(); + + exclude + .split(',') + .map(str::trim) + .filter(|s| !s.is_empty()) + .for_each(|id| { + fvs.remove(id); + actual_data.remove(id); + }); + + if actual_data == fvs { + // Success, break out of retry loop + println!("found"); + break; + } else { + last_error = Some(format!( + "Data mismatch on attempt {attempt}: expected {fvs:?}, got {actual_data:?}" + )); + } } } Ok(None) => { - last_error = Some(format!( - "Key {key} not found in {table_name}/{db_name} on attempt {attempt}" - )); + if *nonexist { + // Success, key doesn't exist as expected + break; + } else { + last_error = Some(format!( + "Key {key} not found in {table_name}/{db_name} on attempt {attempt}" + )); + } } Err(e) => { last_error = Some(format!("Database error on attempt {attempt}: {e}")); @@ -277,6 +317,7 @@ pub async fn run_commands(runtime: &ActorRuntime, aut: ServicePath, commands: &[ if let Some(error) = last_error { panic!("Step {step} - {error}"); } + println!("check passed"); } } } diff --git a/crates/hamgrd/src/actors/vdpu.rs b/crates/hamgrd/src/actors/vdpu.rs index 04033cb..f4cfae6 100644 --- a/crates/hamgrd/src/actors/vdpu.rs +++ b/crates/hamgrd/src/actors/vdpu.rs @@ -31,7 +31,7 @@ impl DbBasedActor for VDpuActor { } impl VDpuActor { - async fn 
register_to_dpu_actor(&self, outgoing: &mut Outgoing, active: bool) -> Result<()> { + fn register_to_dpu_actor(&self, outgoing: &mut Outgoing, active: bool) -> Result<()> { if self.vdpu.is_none() { return Ok(()); } @@ -42,12 +42,20 @@ impl VDpuActor { Ok(()) } + fn do_cleanup(&mut self, _context: &mut Context, state: &mut State) { + // unregister from the DPU Actor + let result = self.register_to_dpu_actor(state.outgoing(), false); + if result.is_err() { + error!("Failed to unregister from DPU Actor: {:?}", result.err()); + } + } + async fn handle_vdpu_message(&mut self, state: &mut State, key: &str, context: &mut Context) -> Result<()> { let (_internal, incoming, outgoing) = state.get_all(); let dpu_kfv: KeyOpFieldValues = incoming.get(key)?.deserialize_data()?; if dpu_kfv.operation == KeyOperation::Del { // unregister from the DPU Actor - self.register_to_dpu_actor(outgoing, false).await?; + self.do_cleanup(context, state); context.stop(); return Ok(()); } @@ -55,7 +63,7 @@ impl VDpuActor { self.vdpu = Some(swss_serde::from_field_values(&dpu_kfv.field_values)?); // Subscribe to the DPU Actor for state updates - self.register_to_dpu_actor(outgoing, true).await?; + self.register_to_dpu_actor(outgoing, true)?; Ok(()) } diff --git a/crates/swbus-actor/src/state/internal.rs b/crates/swbus-actor/src/state/internal.rs index b7ed071..cd60042 100644 --- a/crates/swbus-actor/src/state/internal.rs +++ b/crates/swbus-actor/src/state/internal.rs @@ -29,6 +29,12 @@ impl Internal { self.table.insert(key.into(), entry); } + pub fn delete(&mut self, key: &str) { + if let Some(entry) = self.table.get_mut(key) { + entry.delete(); + } + } + pub fn has_entry(&self, key: &str, swss_key: &str) -> bool { let entry = self.table.get(key); match entry { @@ -48,8 +54,17 @@ impl Internal { } pub(crate) async fn commit_changes(&mut self) { - for entry in self.table.values_mut() { + let mut keys_to_remove = Vec::new(); + + for (key, entry) in self.table.iter_mut() { 
entry.commit_changes().await; + if entry.data.to_delete { + keys_to_remove.push(key.clone()); + } + } + + for key in keys_to_remove { + self.table.remove(&key); } } @@ -74,6 +89,7 @@ pub struct InternalTableData { // Local cache/copy of the table's FVs pub fvs: FieldValues, pub mutated: bool, + pub to_delete: bool, // FVs that will be restored if an actor callback fails pub backup_fvs: FieldValues, @@ -91,6 +107,7 @@ impl PartialEq for InternalTableData { && self.fvs == other.fvs && self.backup_fvs == other.backup_fvs && self.mutated == other.mutated + && self.to_delete == other.to_delete } } @@ -110,6 +127,7 @@ impl InternalTableEntry { swss_key, fvs, mutated: false, + to_delete: false, backup_fvs, last_updated_time: None, }, @@ -129,17 +147,29 @@ impl InternalTableEntry { &mut self.data.fvs } + fn delete(&mut self) { + self.data.to_delete = true; + } + async fn commit_changes(&mut self) { - self.data.mutated = false; - self.swss_table - .set_async(&self.data.swss_key, self.data.fvs.clone()) - .await - .expect("Table::set threw an exception"); - self.data.last_updated_time = Some(get_unix_time()); + if self.data.to_delete { + self.swss_table + .del_async(&self.data.swss_key) + .await + .expect("Table::del threw an exception"); + } else { + self.data.mutated = false; + self.swss_table + .set_async(&self.data.swss_key, self.data.fvs.clone()) + .await + .expect("Table::set threw an exception"); + self.data.last_updated_time = Some(get_unix_time()); + } } fn drop_changes(&mut self) { self.data.mutated = false; + self.data.to_delete = false; self.data.fvs.clone_from(&self.data.backup_fvs); } } diff --git a/crates/swbus-actor/tests/kvstore.rs b/crates/swbus-actor/tests/kvstore.rs index 78a2631..5463cc0 100644 --- a/crates/swbus-actor/tests/kvstore.rs +++ b/crates/swbus-actor/tests/kvstore.rs @@ -38,13 +38,23 @@ async fn echo() { let (notify_done, is_done) = channel(); swbus_actor::spawn(KVStore(Redis::start()), "test", "kv"); - swbus_actor::spawn(KVClient(notify_done), 
"test", "client"); + swbus_actor::spawn(KVClient(notify_done, false), "test", "client"); timeout(Duration::from_secs(3), is_done) .await .expect("timeout") .unwrap(); - verify_actor_state(swbus_edge, &mut mgmt_resp_queue_rx).await; + verify_actor_state(swbus_edge.clone(), &mut mgmt_resp_queue_rx).await; + + // Test delete operation + let (notify_done_delete, is_done_delete) = channel(); + swbus_actor::spawn(KVClient(notify_done_delete, true), "test", "client"); + + timeout(Duration::from_secs(3), is_done_delete) + .await + .expect("timeout") + .unwrap(); + verify_actor_state_after_delete(swbus_edge, &mut mgmt_resp_queue_rx).await; } async fn verify_actor_state(swbus_edge: Arc, mgmt_resp_queue_rx: &mut mpsc::Receiver) { @@ -95,6 +105,7 @@ async fn verify_actor_state(swbus_edge: Arc, mgmt_resp_queue_r "count": "1000" }, "mutated": false, + "to_delete": false, "backup_fvs": { "count": "999" }, @@ -166,10 +177,127 @@ async fn verify_actor_state(swbus_edge: Arc, mgmt_resp_queue_r } } +async fn verify_actor_state_after_delete( + swbus_edge: Arc, + mgmt_resp_queue_rx: &mut mpsc::Receiver, +) { + // send a request to get actor state + let mgmt_request = ManagementRequest::new(ManagementRequestType::HamgrdGetActorState); + + let header = SwbusMessageHeader::new(sp("mgmt_resp"), sp("kv"), 2); + + let request_msg = SwbusMessage { + header: Some(header), + body: Some(Body::ManagementRequest(mgmt_request)), + }; + swbus_edge.send(request_msg).await.unwrap(); + let expected_json = r#" + { + "incoming": { + "": { + "msg": { + "key": "", + "data": { + "Del": { + "key": "count" + } + } + }, + "source": { + "region_id": "test", + "cluster_id": "test", + "node_id": "test", + "service_type": "test", + "service_id": "test", + "resource_type": "test", + "resource_id": "client" + }, + "request_id": 0, + "version": 2002, + "created_time": 0, + "last_updated_time": 0, + "response": "Ok", + "acked": true + } + }, + "internal": {}, + "outgoing":{ + "outgoing_queued":[], + "outgoing_sent":{ + 
"kv-get":{ + "msg":{ + "key":"kv-get", + "data":{ + "key":"count", + "val":"1000" + } + }, + "id":0, + "created_time":0, + "last_updated_time":0, + "last_sent_time":0, + "version":1001, + "acked":true, + "response":"Ok", + "response_source":{ + "region_id":"test", + "cluster_id":"test", + "node_id":"test", + "service_type":"test", + "service_id":"test", + "resource_type":"test", + "resource_id":"client" + } + } + } + } + } + "#; + + let expected: ActorStateDump = serde_json::from_str(expected_json).unwrap(); + match timeout(Duration::from_secs(3), mgmt_resp_queue_rx.recv()).await { + Ok(Some(msg)) => match msg.body { + Some(Body::Response(ref response)) => { + assert_eq!(response.request_id, 2); + assert_eq!(response.error_code, SwbusErrorCode::Ok as i32); + match response.response_body { + Some(ResponseBody::ManagementQueryResult(ref result)) => { + println!("{}", &result.value); + let mut state: ActorStateDump = serde_json::from_str(&result.value).unwrap(); + if let Some(inner_fields) = state.outgoing.outgoing_sent.get_mut("kv-get") { + inner_fields.id = 0; + inner_fields.created_time = 0; + inner_fields.last_updated_time = 0; + inner_fields.last_sent_time = 0; + } + // Reset timestamps for incoming message + if let Some(incoming_entry) = state.incoming.get_mut("") { + incoming_entry.request_id = 0; + incoming_entry.created_time = 0; + incoming_entry.last_updated_time = 0; + } + + assert_eq!(state, expected); + } + _ => panic!("message body is not a ManagementQueryResult"), + } + } + _ => panic!("message body is not a Response"), + }, + Ok(None) => { + panic!("channel broken"); + } + Err(_) => { + panic!("request timeout: didn't receive response"); + } + } +} + #[derive(Serialize, Deserialize)] enum KVMessage { Get { key: String }, Set { key: String, val: String }, + Del { key: String }, } impl KVMessage { @@ -183,6 +311,10 @@ impl KVMessage { val: v.into(), } } + + fn del(k: impl Into) -> Self { + KVMessage::Del { key: k.into() } + } } #[derive(Serialize, 
Deserialize)] @@ -191,7 +323,7 @@ struct KVGetResult { val: String, } -struct KVClient(Sender<()>); +struct KVClient(Sender<()>, bool); impl KVClient { fn notify_done(&mut self) { @@ -201,13 +333,26 @@ impl KVClient { impl Actor for KVClient { async fn init(&mut self, state: &mut State) -> Result<()> { - state - .outgoing() - .send(sp("kv"), ActorMessage::new("", &KVMessage::get("count"))?); + if self.1 { + // to_delete is true, send delete message and immediately notify done + state + .outgoing() + .send(sp("kv"), ActorMessage::new("", &KVMessage::del("count"))?); + self.notify_done(); + } else { + state + .outgoing() + .send(sp("kv"), ActorMessage::new("", &KVMessage::get("count"))?); + } Ok(()) } async fn handle_message(&mut self, state: &mut State, key: &str, _context: &mut Context) -> Result<()> { + if self.1 { + // For delete operations, we shouldn't receive any messages + return Ok(()); + } + assert_eq!(key, "kv-get"); let KVGetResult { key, val } = state.incoming().get(key)?.deserialize_data::()?; @@ -268,6 +413,9 @@ impl Actor for KVStore { KVMessage::Set { key, val } => { state.internal().get_mut("data").insert(key, val.into()); } + KVMessage::Del { key: _ } => { + state.internal().delete("data"); + } } Ok(()) } From 82b1a3f8efe6d622429aacf8e72f3c55c6516a97 Mon Sep 17 00:00:00 2001 From: yue-fred-gao <132678244+yue-fred-gao@users.noreply.github.com> Date: Tue, 16 Sep 2025 13:31:50 -0400 Subject: [PATCH 10/14] Move vnet_tunnel_route_table to producer bridge (#110) ### why vnet_tunnel_route_table needs to be updated via ProducerStateTable to properly trigger orchagent handlers ### what this PR does move the table from internal state to outgoing state via producer bridge --- crates/hamgrd/src/actors.rs | 20 +++++++- crates/hamgrd/src/actors/ha_set.rs | 82 +++++++++++++++++++++--------- crates/hamgrd/src/main.rs | 6 ++- 3 files changed, 81 insertions(+), 27 deletions(-) diff --git a/crates/hamgrd/src/actors.rs b/crates/hamgrd/src/actors.rs index 
b6690e1..b61f242 100644 --- a/crates/hamgrd/src/actors.rs +++ b/crates/hamgrd/src/actors.rs @@ -16,7 +16,9 @@ use swbus_edge::swbus_proto::message_id_generator::MessageIdGenerator; use swbus_edge::swbus_proto::result::*; use swbus_edge::swbus_proto::swbus::{swbus_message::Body, DataRequest, ServicePath, SwbusErrorCode, SwbusMessage}; use swbus_edge::SwbusEdgeRuntime; -use swss_common::{KeyOpFieldValues, KeyOperation, SubscriberStateTable, ZmqClient, ZmqProducerStateTable}; +use swss_common::{ + KeyOpFieldValues, KeyOperation, ProducerStateTable, SubscriberStateTable, ZmqClient, ZmqProducerStateTable, +}; use swss_common_bridge::{consumer::ConsumerBridge, producer::spawn_producer_bridge}; use tokio::sync::mpsc::{channel, Receiver}; use tokio::task::JoinHandle; @@ -277,3 +279,19 @@ where anyhow::bail!("Failed to connect to ZMQ server at {}", zmq_endpoint); } } + +pub async fn spawn_vanilla_producer_bridge(edge_runtime: Arc) -> AnyhowResult> +where + T: SonicDbTable + 'static, +{ + let db = crate::db_for_table::().await?; + let pst = ProducerStateTable::new(db, T::table_name()).unwrap(); + + let sp = crate::common_bridge_sp::(&edge_runtime); + info!( + "spawned ZMQ producer bridge for {} at {}", + T::table_name(), + sp.to_longest_path() + ); + Ok(spawn_producer_bridge(edge_runtime.clone(), sp, pst)) +} diff --git a/crates/hamgrd/src/actors/ha_set.rs b/crates/hamgrd/src/actors/ha_set.rs index 7a742aa..a813563 100644 --- a/crates/hamgrd/src/actors/ha_set.rs +++ b/crates/hamgrd/src/actors/ha_set.rs @@ -9,10 +9,9 @@ use sonic_dash_api_proto::ha_set_config::HaSetConfig; use sonic_dash_api_proto::ip_to_string; use std::collections::HashMap; use swbus_actor::{ - state::{incoming::Incoming, internal::Internal, outgoing::Outgoing}, + state::{incoming::Incoming, outgoing::Outgoing}, Actor, ActorMessage, Context, State, }; -use swss_common::Table; use swss_common::{KeyOpFieldValues, KeyOperation}; use swss_common_bridge::consumer::ConsumerBridge; use tracing::{debug, error, 
info, instrument}; @@ -172,7 +171,7 @@ impl HaSetActor { &self, vdpus: &Vec, incoming: &Incoming, - internal: &mut Internal, + outgoing: &mut Outgoing, ) -> Result<()> { let Some(global_cfg) = Self::get_dash_global_config(incoming) else { return Ok(()); @@ -192,12 +191,6 @@ impl HaSetActor { .unwrap_or_default() ); - if !internal.has_entry(VnetRouteTunnelTable::table_name(), &swss_key) { - let db = crate::db_for_table::().await?; - let table = Table::new_async(db, VnetRouteTunnelTable::table_name()).await?; - internal.add(VnetRouteTunnelTable::table_name(), table, swss_key).await; - } - let mut endpoint = Vec::new(); let mut endpoint_monitor = Vec::new(); let mut primary = Vec::new(); @@ -228,12 +221,44 @@ impl HaSetActor { }; let fvs = swss_serde::to_field_values(&vnet_route)?; - internal.get_mut(VnetRouteTunnelTable::table_name()).clone_from(&fvs); + let kfv = KeyOpFieldValues { + key: swss_key, + operation: KeyOperation::Set, + field_values: fvs, + }; + + let msg = ActorMessage::new(self.id.clone(), &kfv)?; + outgoing.send(outgoing.common_bridge_sp::(), msg); + Ok(()) } - fn delete_vnet_route_tunnel_table(&self, internal: &mut Internal) -> Result<()> { - internal.delete(VnetRouteTunnelTable::table_name()); + fn delete_vnet_route_tunnel_table(&self, incoming: &Incoming, outgoing: &mut Outgoing) -> Result<()> { + let Some(global_cfg) = Self::get_dash_global_config(incoming) else { + return Ok(()); + }; + let swss_key = format!( + "{}:{}", + global_cfg + .vnet_name + .ok_or(anyhow!("Missing vnet_name in global config"))?, + self.dash_ha_set_config + .as_ref() + .unwrap() + .vip_v4 + .as_ref() + .map(ip_to_string) + .unwrap_or_default() + ); + + let kfv = KeyOpFieldValues { + key: swss_key, + operation: KeyOperation::Del, + field_values: HashMap::new(), + }; + + let msg = ActorMessage::new(self.id.clone(), &kfv)?; + outgoing.send(outgoing.common_bridge_sp::(), msg); Ok(()) } @@ -364,24 +389,24 @@ impl HaSetActor { } async fn handle_dash_ha_global_config(&mut self, 
state: &mut State) -> Result<()> { - let (internal, incoming, outgoing) = state.get_all(); + let (_internal, incoming, outgoing) = state.get_all(); let Some(vdpus) = self.get_vdpus_if_ready(incoming) else { return Ok(()); }; // global config update affects Vxlan tunnel and dash-ha-set in DPU self.update_dash_ha_set_table(&vdpus, incoming, outgoing)?; - self.update_vnet_route_tunnel_table(&vdpus, incoming, internal).await?; + self.update_vnet_route_tunnel_table(&vdpus, incoming, outgoing).await?; Ok(()) } async fn handle_vdpu_state_update(&mut self, state: &mut State) -> Result<()> { - let (internal, incoming, outgoing) = state.get_all(); + let (_internal, incoming, outgoing) = state.get_all(); // vdpu update affects dash-ha-set in DPU and vxlan tunnel let Some(vdpus) = self.get_vdpus_if_ready(incoming) else { return Ok(()); }; self.update_dash_ha_set_table(&vdpus, incoming, outgoing)?; - self.update_vnet_route_tunnel_table(&vdpus, incoming, internal).await?; + self.update_vnet_route_tunnel_table(&vdpus, incoming, outgoing).await?; Ok(()) } @@ -406,7 +431,7 @@ impl HaSetActor { } fn do_cleanup(&mut self, state: &mut State) -> Result<()> { - let (internal, incoming, outgoing) = state.get_all(); + let (_internal, incoming, outgoing) = state.get_all(); let Some(vdpus) = self.get_vdpus_if_ready(incoming) else { debug!("Not all DPU info is ready for cleanup"); @@ -417,7 +442,7 @@ impl HaSetActor { error!("Failed to delete dash_ha_set_table: {}", e); } - if let Err(e) = self.delete_vnet_route_tunnel_table(internal) { + if let Err(e) = self.delete_vnet_route_tunnel_table(incoming, outgoing) { error!("Failed to delete vnet_route_tunnel_table: {}", e); } @@ -557,13 +582,17 @@ mod test { // Verify that haset actor state is sent to ha-scope actor recv! { key: HaSetActorState::msg_key(&ha_set_id), data: { "up": true, "ha_set": &ha_set_obj }, addr: runtime.sp("ha-scope", &format!("vdpu0:{ha_set_id}")) }, - chkdb! 
{ type: VnetRouteTunnelTable, key: &format!("{}:{}", global_cfg.vnet_name.as_ref().unwrap(), ip_to_string(ha_set_cfg.vip_v4.as_ref().unwrap())), data: expected_vnet_route }, + recv! { key: &ha_set_id, data: {"key": format!("{}:{}", global_cfg.vnet_name.as_ref().unwrap(), ip_to_string(ha_set_cfg.vip_v4.as_ref().unwrap())), + "operation": "Set", "field_values": expected_vnet_route}, + addr: crate::common_bridge_sp::(&runtime.get_swbus_edge()) }, // simulate delete of ha-set entry send! { key: HaSetActor::table_name(), data: { "key": HaSetActor::table_name(), "operation": "Del", "field_values": ha_set_cfg_fvs }, addr: crate::common_bridge_sp::(&runtime.get_swbus_edge()) }, recv! { key: &ha_set_id, data: {"key": &ha_set_id, "operation": "Del", "field_values": {}}, addr: crate::common_bridge_sp::(&runtime.get_swbus_edge()) }, - chkdb! { type: VnetRouteTunnelTable, key: &format!("{}:{}", global_cfg.vnet_name.as_ref().unwrap(), ip_to_string(ha_set_cfg.vip_v4.as_ref().unwrap())), nonexist }, + recv! { key: &ha_set_id, data: {"key": format!("{}:{}", global_cfg.vnet_name.as_ref().unwrap(), ip_to_string(ha_set_cfg.vip_v4.as_ref().unwrap())), + "operation": "Del", "field_values": {}}, + addr: crate::common_bridge_sp::(&runtime.get_swbus_edge()) }, recv! { key: ActorRegistration::msg_key(RegistrationType::VDPUState, &ha_set_id), data: { "active": false }, addr: runtime.sp(VDpuActor::name(), &vdpu0_id) }, recv! { key: ActorRegistration::msg_key(RegistrationType::VDPUState, &ha_set_id), data: { "active": false }, @@ -636,14 +665,19 @@ mod test { send! { key: VDpuActorState::msg_key(&vdpu0_id), data: vdpu0_state, addr: runtime.sp("vdpu", &vdpu0_id) }, // Simulate VDPU state update for vdpu1 (backup) send! { key: VDpuActorState::msg_key(&vdpu1_id), data: vdpu1_state, addr: runtime.sp("vdpu", &vdpu1_id) }, - // Verify that the DASH_HA_SET_TABLE was updated - chkdb! 
{ type: VnetRouteTunnelTable, key: &format!("{}:{}", global_cfg.vnet_name.as_ref().unwrap(), ip_to_string(ha_set_cfg.vip_v4.as_ref().unwrap())), - data: expected_vnet_route }, + // Verify that the VnetRouteTunnelTable was updated + recv! { key: &ha_set_id, data: {"key": format!("{}:{}", global_cfg.vnet_name.as_ref().unwrap(), ip_to_string(ha_set_cfg.vip_v4.as_ref().unwrap())), + "operation": "Set", "field_values": expected_vnet_route}, + addr: crate::common_bridge_sp::(&runtime.get_swbus_edge()) }, + // simulate delete of ha-set entry send! { key: HaSetActor::table_name(), data: { "key": HaSetActor::table_name(), "operation": "Del", "field_values": ha_set_cfg_fvs }, addr: crate::common_bridge_sp::(&runtime.get_swbus_edge()) }, - chkdb! { type: VnetRouteTunnelTable, key: &format!("{}:{}", global_cfg.vnet_name.as_ref().unwrap(), ip_to_string(ha_set_cfg.vip_v4.as_ref().unwrap())), nonexist }, + + recv! { key: &ha_set_id, data: {"key": format!("{}:{}", global_cfg.vnet_name.as_ref().unwrap(), ip_to_string(ha_set_cfg.vip_v4.as_ref().unwrap())), + "operation": "Del", "field_values": {}}, + addr: crate::common_bridge_sp::(&runtime.get_swbus_edge()) }, recv! 
{ key: ActorRegistration::msg_key(RegistrationType::VDPUState, &ha_set_id), data: { "active": false }, addr: runtime.sp(VDpuActor::name(), &vdpu0_id) }, diff --git a/crates/hamgrd/src/main.rs b/crates/hamgrd/src/main.rs index 419d6aa..4fc064b 100644 --- a/crates/hamgrd/src/main.rs +++ b/crates/hamgrd/src/main.rs @@ -21,10 +21,10 @@ use tracing::error; mod actors; mod db_structs; mod ha_actor_messages; -use actors::spawn_zmq_producer_bridge; use actors::{dpu::DpuActor, ha_scope::HaScopeActor, ha_set::HaSetActor, vdpu::VDpuActor, DbBasedActor}; +use actors::{spawn_vanilla_producer_bridge, spawn_zmq_producer_bridge}; use anyhow::Result; -use db_structs::{BfdSessionTable, DashHaScopeTable, DashHaSetTable, Dpu, VDpu}; +use db_structs::{BfdSessionTable, DashHaScopeTable, DashHaSetTable, Dpu, VDpu, VnetRouteTunnelTable}; use lazy_static::lazy_static; use sonic_dash_api_proto::{ha_scope_config::HaScopeConfig, ha_set_config::HaSetConfig}; use std::any::Any; @@ -152,6 +152,8 @@ async fn spawn_producer_bridges(edge_runtime: Arc, dpu: &Dpu) let handle = spawn_zmq_producer_bridge::(edge_runtime.clone(), &zmq_endpoint).await?; handles.push(handle); + let handle = spawn_vanilla_producer_bridge::(edge_runtime.clone()).await?; + handles.push(handle); Ok(handles) } From bb38a0164cf5a8945419de8cc8c7360240bacc71 Mon Sep 17 00:00:00 2001 From: dypet Date: Tue, 23 Sep 2025 14:40:20 -0600 Subject: [PATCH 11/14] Add local nexthop ip (#116) Adding local_nexthop_ip so correct nexthop IP is used as endpoint in VNET_ROUTE_TUNNEL_TABLE when DPU is local. 
--- crates/hamgrd/src/actors/ha_set.rs | 29 +++++++++++++++++++------- crates/hamgrd/src/actors/test.rs | 2 ++ crates/hamgrd/src/db_structs.rs | 4 ++++ crates/hamgrd/src/ha_actor_messages.rs | 3 +++ 4 files changed, 31 insertions(+), 7 deletions(-) diff --git a/crates/hamgrd/src/actors/ha_set.rs b/crates/hamgrd/src/actors/ha_set.rs index a813563..11bc843 100644 --- a/crates/hamgrd/src/actors/ha_set.rs +++ b/crates/hamgrd/src/actors/ha_set.rs @@ -195,18 +195,33 @@ impl HaSetActor { let mut endpoint_monitor = Vec::new(); let mut primary = Vec::new(); let mut check_directly_connected = false; + let mut any_managed = false; for vdpu_ext in vdpus { - if vdpu_ext.vdpu.dpu.is_managed { - // if it is locally managed dpu, use dpu pa_ipv4 as endpoint - endpoint.push(vdpu_ext.vdpu.dpu.pa_ipv4.clone()); + if !vdpu_ext.vdpu.dpu.remote_dpu { + // if it is locally managed dpu, use local nexthop as endpoint + endpoint.push(vdpu_ext.vdpu.dpu.local_nexthop_ip.clone()); } else { endpoint.push(vdpu_ext.vdpu.dpu.npu_ipv4.clone()); } endpoint_monitor.push(vdpu_ext.vdpu.dpu.pa_ipv4.clone()); - primary.push(vdpu_ext.is_primary.to_string()); - check_directly_connected |= vdpu_ext.vdpu.dpu.is_managed; + if vdpu_ext.is_primary { + if !vdpu_ext.vdpu.dpu.remote_dpu { + primary.push(vdpu_ext.vdpu.dpu.local_nexthop_ip.clone()); + } else { + primary.push(vdpu_ext.vdpu.dpu.npu_ipv4.clone()); + } + } + check_directly_connected |= !vdpu_ext.vdpu.dpu.remote_dpu; + any_managed |= vdpu_ext.vdpu.dpu.is_managed; + } + + if check_directly_connected && !any_managed { + debug!( + "Skipping VnetRouteTunnelTable update as directly connected DPU and no locally managed DPU are present." 
+ ); + return Ok(()); } // update vnet route tunnel table @@ -544,7 +559,7 @@ mod test { vdpu1_state_obj.dpu.pa_ipv4.clone(), ]), monitoring: None, - primary: Some(vec!["true".to_string(), "false".to_string()]), + primary: Some(vec![vdpu0_state_obj.dpu.pa_ipv4.clone()]), rx_monitor_timer: global_cfg.dpu_bfd_probe_interval_in_ms, tx_monitor_timer: global_cfg.dpu_bfd_probe_interval_in_ms, check_directly_connected: Some(true), @@ -637,7 +652,7 @@ mod test { vdpu1_state_obj.dpu.pa_ipv4.clone(), ]), monitoring: None, - primary: Some(vec!["true".to_string(), "false".to_string()]), + primary: Some(vec![vdpu0_state_obj.dpu.npu_ipv4.clone()]), rx_monitor_timer: global_cfg.dpu_bfd_probe_interval_in_ms, tx_monitor_timer: global_cfg.dpu_bfd_probe_interval_in_ms, check_directly_connected: Some(false), diff --git a/crates/hamgrd/src/actors/test.rs b/crates/hamgrd/src/actors/test.rs index e254f6a..567277f 100644 --- a/crates/hamgrd/src/actors/test.rs +++ b/crates/hamgrd/src/actors/test.rs @@ -373,6 +373,7 @@ pub fn make_dpu_object(switch: u16, dpu: u32) -> Dpu { vip_ipv6: Some(normalize_ipv6(&format!("3:2:{switch_pair_id}::{dpu}"))), pa_ipv4: format!("18.0.{switch}.{dpu}"), pa_ipv6: Some(normalize_ipv6(&format!("18:0:{switch}::{dpu}"))), + local_nexthop_ip: format!("18.0.{switch}.{dpu}"), dpu_id: dpu, vdpu_id: Some(format!("vdpu{}", switch * 8 + dpu as u16)), orchagent_zmq_port: 8100, @@ -452,6 +453,7 @@ pub fn to_local_dpu(dpu_actor_state: &DpuActorState) -> Dpu { vip_ipv6: dpu_actor_state.vip_ipv6.clone(), pa_ipv4: dpu_actor_state.pa_ipv4.clone(), pa_ipv6: dpu_actor_state.pa_ipv6.clone(), + local_nexthop_ip: dpu_actor_state.local_nexthop_ip.clone(), dpu_id: dpu_actor_state.dpu_id, vdpu_id: dpu_actor_state.vdpu_id.clone(), orchagent_zmq_port: dpu_actor_state.orchagent_zmq_port, diff --git a/crates/hamgrd/src/db_structs.rs b/crates/hamgrd/src/db_structs.rs index 230c209..635e540 100644 --- a/crates/hamgrd/src/db_structs.rs +++ b/crates/hamgrd/src/db_structs.rs @@ -43,6 +43,7 @@ 
pub struct Dpu { pub vip_ipv6: Option, pub pa_ipv4: String, pub pa_ipv6: Option, + pub local_nexthop_ip: String, pub dpu_id: u32, pub vdpu_id: Option, pub orchagent_zmq_port: u16, @@ -446,6 +447,7 @@ mod test { "operation": "Set", "field_values": { "pa_ipv4": "1.2.3.4", + "local_nexthop_ip": "2.2.2.5", "dpu_id": "1", "orchagent_zmq_port": "8100", "swbus_port": "23606", @@ -497,6 +499,7 @@ mod test { pa_ipv4: "1.2.3.6".to_string(), vip_ipv4: Some("4.5.6.6".to_string()), pa_ipv6: None, + local_nexthop_ip: "2.2.2.5".to_string(), dpu_id: 6, orchagent_zmq_port: 8100, swbus_port: 23612, @@ -515,6 +518,7 @@ mod test { for d in 6..8 { let dpu_fvs = vec![ ("pa_ipv4".to_string(), Ipv4Addr::new(1, 2, 3, d).to_string()), + ("local_nexthop_ip".to_string(), Ipv4Addr::new(2, 2, 2, 5).to_string()), ("vip_ipv4".to_string(), Ipv4Addr::new(4, 5, 6, d).to_string()), ("dpu_id".to_string(), d.to_string()), ("orchagent_zmq_port".to_string(), "8100".to_string()), diff --git a/crates/hamgrd/src/ha_actor_messages.rs b/crates/hamgrd/src/ha_actor_messages.rs index 934da16..f6e5582 100644 --- a/crates/hamgrd/src/ha_actor_messages.rs +++ b/crates/hamgrd/src/ha_actor_messages.rs @@ -25,6 +25,7 @@ pub struct DpuActorState { pub vip_ipv6: Option, pub pa_ipv4: String, pub pa_ipv6: Option, + pub local_nexthop_ip: String, pub dpu_id: u32, pub vdpu_id: Option, pub orchagent_zmq_port: u16, @@ -56,6 +57,7 @@ impl DpuActorState { vip_ipv6: dpu.vip_ipv6.clone(), pa_ipv4: dpu.pa_ipv4.clone(), pa_ipv6: dpu.pa_ipv6.clone(), + local_nexthop_ip: dpu.local_nexthop_ip.clone(), dpu_id: dpu.dpu_id, vdpu_id: dpu.vdpu_id.clone(), orchagent_zmq_port: dpu.orchagent_zmq_port, @@ -79,6 +81,7 @@ impl DpuActorState { vip_ipv6: None, pa_ipv4: rdpu.pa_ipv4.clone(), pa_ipv6: rdpu.pa_ipv6.clone(), + local_nexthop_ip: "".to_string(), dpu_id: rdpu.dpu_id, vdpu_id: None, orchagent_zmq_port: 0, From f282d41ee0bf04e6fa3050799c409e1e6008e320 Mon Sep 17 00:00:00 2001 From: yue-fred-gao 
<132678244+yue-fred-gao@users.noreply.github.com> Date: Tue, 23 Sep 2025 16:41:02 -0400 Subject: [PATCH 12/14] Unregister handlers after actor terminates (#115) ### why when actor terminates itself, the handler in SwbusEdgeRuntime is not removed. When a new actor is spawned with the same service path, it will be ignored because the ActorCreator replies on "NoRoute" to discover new actor. ### what this PR does when ActorDriver exits from run loop, the SimpleSwbusEdgeClient it owns will be destructed. From the destructor, handler will be removed. This addresses issue #111 --- crates/hamgrd/src/actors.rs | 2 +- crates/swbus-actor/tests/echo.rs | 25 +++++++++++++++++-- crates/swbus-edge/src/edge_runtime.rs | 22 ++++++++++++++++ crates/swbus-edge/src/message_router.rs | 8 ++++++ .../src/message_router/route_map.rs | 8 ++++++ crates/swbus-edge/src/simple_client.rs | 6 +++++ 6 files changed, 68 insertions(+), 3 deletions(-) diff --git a/crates/hamgrd/src/actors.rs b/crates/hamgrd/src/actors.rs index b61f242..245c567 100644 --- a/crates/hamgrd/src/actors.rs +++ b/crates/hamgrd/src/actors.rs @@ -289,7 +289,7 @@ where let sp = crate::common_bridge_sp::(&edge_runtime); info!( - "spawned ZMQ producer bridge for {} at {}", + "spawned producer bridge for {} at {}", T::table_name(), sp.to_longest_path() ); diff --git a/crates/swbus-actor/tests/echo.rs b/crates/swbus-actor/tests/echo.rs index e6b4192..13a2adc 100644 --- a/crates/swbus-actor/tests/echo.rs +++ b/crates/swbus-actor/tests/echo.rs @@ -14,7 +14,7 @@ fn sp(name: &str) -> ServicePath { async fn echo() { let mut swbus_edge = SwbusEdgeRuntime::new("none".to_string(), sp("none"), ConnectionType::InNode); swbus_edge.start().await.unwrap(); - let actor_runtime = ActorRuntime::new(swbus_edge.into()); + let actor_runtime: ActorRuntime = ActorRuntime::new(swbus_edge.into()); swbus_actor::set_global_runtime(actor_runtime); let (notify_done, is_done) = channel(); @@ -22,6 +22,25 @@ async fn echo() { swbus_actor::spawn(EchoServer, 
"test", "echo"); swbus_actor::spawn(EchoClient(notify_done), "test", "client"); + timeout(Duration::from_secs(3), is_done) + .await + .expect("timeout") + .unwrap(); + + { + // Access the actor runtime to get the service path and check if handler exists + let runtime_guard = swbus_actor::get_global_runtime(); + let runtime = runtime_guard.as_ref().unwrap(); + let sp = runtime.sp("test", "client"); + let has_handler = runtime.get_swbus_edge().has_handler(&sp); + assert!(!has_handler); + } + + let (notify_done, is_done) = channel(); + + // spawn the same actor to make sure the previous actor cleans up properly + swbus_actor::spawn(EchoClient(notify_done), "test", "client"); + timeout(Duration::from_secs(3), is_done) .await .expect("timeout") @@ -43,7 +62,7 @@ impl Actor for EchoClient { Ok(()) } - async fn handle_message(&mut self, state: &mut State, key: &str, _context: &mut Context) -> Result<()> { + async fn handle_message(&mut self, state: &mut State, key: &str, context: &mut Context) -> Result<()> { let count = key.parse::().unwrap(); // Assert that the incoming table has messages 0..=count still cached @@ -54,6 +73,8 @@ impl Actor for EchoClient { if count == 1000 { self.notify_done(); + // terminate this actor + context.stop(); } else { state .outgoing() diff --git a/crates/swbus-edge/src/edge_runtime.rs b/crates/swbus-edge/src/edge_runtime.rs index e15ac6a..e6f3f27 100644 --- a/crates/swbus-edge/src/edge_runtime.rs +++ b/crates/swbus-edge/src/edge_runtime.rs @@ -74,6 +74,28 @@ impl SwbusEdgeRuntime { self.message_router.add_private_route(svc_path, proxy); } + /// Remove handler by ServicePath. 
+ pub fn remove_handler(&self, svc_path: &ServicePath) -> bool { + match self.message_router.remove_route(svc_path) { + Some(_) => { + info!("Removed handler for service path: {}", svc_path.to_longest_path()); + true + } + None => { + info!( + "No handler found to remove for service path: {}", + svc_path.to_longest_path() + ); + false + } + } + } + + /// Check if a handler exists for the given ServicePath. + pub fn has_handler(&self, svc_path: &ServicePath) -> bool { + self.message_router.has_route(svc_path) + } + pub async fn send(&self, message: SwbusMessage) -> Result<()> { // Send message to the message router match self.sender_to_message_router.send(message).await { diff --git a/crates/swbus-edge/src/message_router.rs b/crates/swbus-edge/src/message_router.rs index bb04b6f..bc583d2 100644 --- a/crates/swbus-edge/src/message_router.rs +++ b/crates/swbus-edge/src/message_router.rs @@ -77,6 +77,14 @@ impl SwbusMessageRouter { self.routes.insert(svc_path, handler, Privacy::Private); } + pub fn remove_route(&self, svc_path: &ServicePath) -> Option { + self.routes.remove(svc_path).map(|(handler, _)| handler) + } + + pub fn has_route(&self, svc_path: &ServicePath) -> bool { + self.routes.contains(svc_path) + } + async fn route_message( swbus_client: &mut SwbusCoreClient, routes: &RouteMap, diff --git a/crates/swbus-edge/src/message_router/route_map.rs b/crates/swbus-edge/src/message_router/route_map.rs index 25cca6f..bd19bc0 100644 --- a/crates/swbus-edge/src/message_router/route_map.rs +++ b/crates/swbus-edge/src/message_router/route_map.rs @@ -22,4 +22,12 @@ impl RouteMap { } }) } + + pub(super) fn remove(&self, svc_path: &ServicePath) -> Option<(SwbusMessageHandlerProxy, Privacy)> { + self.0.remove(svc_path).map(|(_, value)| value) + } + + pub(super) fn contains(&self, svc_path: &ServicePath) -> bool { + self.0.contains_key(svc_path) + } } diff --git a/crates/swbus-edge/src/simple_client.rs b/crates/swbus-edge/src/simple_client.rs index 4582acf..9c99454 100644 --- 
a/crates/swbus-edge/src/simple_client.rs +++ b/crates/swbus-edge/src/simple_client.rs @@ -238,3 +238,9 @@ pub struct OutgoingMessage { pub destination: ServicePath, pub body: MessageBody, } + +impl Drop for SimpleSwbusEdgeClient { + fn drop(&mut self) { + self.rt.remove_handler(&self.source); + } +} From a35415e0492134eb04779941f323b5d0826a04f7 Mon Sep 17 00:00:00 2001 From: yue-fred-gao <132678244+yue-fred-gao@users.noreply.github.com> Date: Tue, 23 Sep 2025 16:41:32 -0400 Subject: [PATCH 13/14] Proper handling of entry not found in incoming vs decoding error (#114) ### why currently Incoming::get returns error if the entry is not found and caller typically propagates the error further. Sometimes it is normal that an entry doesn't exist. It needs to be treated differently from message decode error. ### what this PR does Incoming::get returns Option. Caller needs to handle the None return accordingly, which means the entry is not found. --- crates/hamgrd/src/actors/dpu.rs | 47 +++++++++++++++--------- crates/hamgrd/src/actors/ha_scope.rs | 30 +++++++++------ crates/hamgrd/src/actors/ha_set.rs | 20 ++++++---- crates/hamgrd/src/actors/vdpu.rs | 41 +++++++++++++-------- crates/swbus-actor/src/state.rs | 2 +- crates/swbus-actor/src/state/incoming.rs | 16 ++++++-- crates/swbus-actor/tests/echo.rs | 5 ++- crates/swbus-actor/tests/kvstore.rs | 2 +- 8 files changed, 105 insertions(+), 58 deletions(-) diff --git a/crates/hamgrd/src/actors/dpu.rs b/crates/hamgrd/src/actors/dpu.rs index f149b67..66e1f56 100644 --- a/crates/hamgrd/src/actors/dpu.rs +++ b/crates/hamgrd/src/actors/dpu.rs @@ -12,7 +12,7 @@ use swbus_actor::{state::incoming::Incoming, state::outgoing::Outgoing, Actor, A use swbus_edge::SwbusEdgeRuntime; use swss_common::{KeyOpFieldValues, KeyOperation, SubscriberStateTable}; use swss_common_bridge::consumer::ConsumerBridge; -use tracing::{debug, error, info, instrument}; +use tracing::{debug, error, instrument}; use 
super::spawn_consumer_bridge_for_actor_with_selector; @@ -60,19 +60,30 @@ impl DpuActor { RemoteDpu::table_name() } - fn get_dpu_state(incoming: &Incoming) -> Result { - let dpu_state_kfv: KeyOpFieldValues = incoming.get(DpuState::table_name())?.deserialize_data()?; - Ok(swss_serde::from_field_values(&dpu_state_kfv.field_values)?) + fn get_dpu_state(incoming: &Incoming) -> Result> { + match incoming.get(DpuState::table_name()) { + Some(msg) => { + let dpu_state_kfv: KeyOpFieldValues = msg.deserialize_data()?; + Ok(Some(swss_serde::from_field_values(&dpu_state_kfv.field_values)?)) + } + None => Ok(None), + } } - fn get_bfd_probe_state(incoming: &Incoming) -> Result { - let bfd_probe_kfv: KeyOpFieldValues = incoming.get(DashBfdProbeState::table_name())?.deserialize_data()?; - Ok(swss_serde::from_field_values(&bfd_probe_kfv.field_values)?) + fn get_bfd_probe_state(incoming: &Incoming) -> Result> { + match incoming.get(DashBfdProbeState::table_name()) { + Some(msg) => { + let bfd_probe_kfv: KeyOpFieldValues = msg.deserialize_data()?; + Ok(Some(swss_serde::from_field_values(&bfd_probe_kfv.field_values)?)) + } + None => Ok(None), + } } fn get_dash_ha_global_config(incoming: &Incoming) -> Result { - let ha_global_config_kfv: KeyOpFieldValues = - incoming.get(DashHaGlobalConfig::table_name())?.deserialize_data()?; + let ha_global_config_kfv: KeyOpFieldValues = incoming + .get_or_fail(DashHaGlobalConfig::table_name())? + .deserialize_data()?; Ok(swss_serde::from_field_values(&ha_global_config_kfv.field_values)?) 
} @@ -133,7 +144,7 @@ impl DpuActor { async fn handle_dpu_message(&mut self, state: &mut State, key: &str, context: &mut Context) -> Result<()> { let (_internal, incoming, outgoing) = state.get_all(); - let dpu_kfv: KeyOpFieldValues = incoming.get(key)?.deserialize_data()?; + let dpu_kfv: KeyOpFieldValues = incoming.get_or_fail(key)?.deserialize_data()?; if dpu_kfv.operation == KeyOperation::Del { self.do_cleanup(context, state); context.stop(); @@ -233,16 +244,16 @@ impl DpuActor { } // Check pmon state from DPU_STATE table let dpu_state = match Self::get_dpu_state(incoming) { - Ok(dpu_state) => Some(dpu_state), + Ok(dpu_state) => dpu_state, Err(e) => { - info!("Not able to get DPU state. Assume DPU is down. Error: {}", e); + error!("Failed to decode DPU_STATE. Error: {}", e); None } }; let bfd_probe_state = match Self::get_bfd_probe_state(incoming) { - Ok(bfd_probe_state) => Some(bfd_probe_state), + Ok(bfd_probe_state) => bfd_probe_state, Err(e) => { - debug!("Not able to get BFD probe state. Error: {}", e); + error!("Failed to decode DASH_BFD_PROBE_STATE. Error: {}", e); None } }; @@ -310,7 +321,9 @@ impl DpuActor { // handle DPU state registration request. In response to the request, this actor will send its current state. fn handle_dpu_state_registration(&mut self, key: &str, incoming: &Incoming, outgoing: &mut Outgoing) -> Result<()> { - let entry = incoming.get_entry(key)?; + let entry = incoming + .get_entry(key) + .ok_or_else(|| anyhow!("Entry not found for key: {}", key))?; let ActorRegistration { active, .. 
} = entry.msg.deserialize_data()?; if active { self.update_dpu_state(incoming, outgoing, Some(entry.source.clone()))?; @@ -406,7 +419,7 @@ impl DpuActor { context: &mut Context, ) -> Result<()> { let (_internal, incoming, outgoing) = state.get_all(); - let dpu_kfv: KeyOpFieldValues = incoming.get(key)?.deserialize_data()?; + let dpu_kfv: KeyOpFieldValues = incoming.get_or_fail(key)?.deserialize_data()?; if dpu_kfv.operation == KeyOperation::Del { context.stop(); return Ok(()); @@ -421,7 +434,7 @@ impl DpuActor { fn handle_remote_dpu_message_to_local_dpu(&mut self, state: &mut State, key: &str) -> Result<()> { let (_internal, incoming, outgoing) = state.get_all(); - let dpu_kfv: KeyOpFieldValues = incoming.get(key)?.deserialize_data()?; + let dpu_kfv: KeyOpFieldValues = incoming.get_or_fail(key)?.deserialize_data()?; let remote_dpu: RemoteDpu = swss_serde::from_field_values(&dpu_kfv.field_values)?; diff --git a/crates/hamgrd/src/actors/ha_scope.rs b/crates/hamgrd/src/actors/ha_scope.rs index fa03246..3830781 100644 --- a/crates/hamgrd/src/actors/ha_scope.rs +++ b/crates/hamgrd/src/actors/ha_scope.rs @@ -57,20 +57,28 @@ impl HaScopeActor { // get vdpu data received via vdpu udpate fn get_vdpu(&self, incoming: &Incoming) -> Option { let key = VDpuActorState::msg_key(&self.vdpu_id); - let Ok(msg) = incoming.get(&key) else { - return None; - }; - msg.deserialize_data().ok() + let msg = incoming.get(&key)?; + match msg.deserialize_data() { + Ok(data) => Some(data), + Err(e) => { + error!("Failed to deserialize VDpuActorState from message: {}", e); + None + } + } } fn get_haset(&self, incoming: &Incoming) -> Option { let ha_set_id = self.get_haset_id()?; let key = HaSetActorState::msg_key(&ha_set_id); - let Ok(msg) = incoming.get(&key) else { - return None; - }; - msg.deserialize_data().ok() + let msg = incoming.get(&key)?; + match msg.deserialize_data() { + Ok(data) => Some(data), + Err(e) => { + error!("Failed to deserialize HaSetActorState from message: {}", e); + None 
+ } + } } fn get_haset_id(&self) -> Option { @@ -79,9 +87,7 @@ impl HaScopeActor { } fn get_dpu_ha_scope_state(&self, incoming: &Incoming) -> Option { - let Ok(msg) = incoming.get(DpuDashHaScopeState::table_name()) else { - return None; - }; + let msg = incoming.get(DpuDashHaScopeState::table_name())?; let kfv = match msg.deserialize_data::() { Ok(data) => data, Err(e) => { @@ -478,7 +484,7 @@ impl HaScopeActor { let (_internal, incoming, outgoing) = state.get_all(); // Retrieve the config update from the incoming message - let kfv: KeyOpFieldValues = incoming.get(key)?.deserialize_data()?; + let kfv: KeyOpFieldValues = incoming.get_or_fail(key)?.deserialize_data()?; if kfv.operation == KeyOperation::Del { // cleanup resources before stopping diff --git a/crates/hamgrd/src/actors/ha_set.rs b/crates/hamgrd/src/actors/ha_set.rs index 11bc843..949a4fa 100644 --- a/crates/hamgrd/src/actors/ha_set.rs +++ b/crates/hamgrd/src/actors/ha_set.rs @@ -48,7 +48,7 @@ struct VDpuStateExt { impl HaSetActor { fn get_dash_global_config(incoming: &Incoming) -> Option { - let Ok(msg) = incoming.get(DashHaGlobalConfig::table_name()) else { + let Some(msg) = incoming.get(DashHaGlobalConfig::table_name()) else { debug!("DASH_HA_GLOBAL_CONFIG table is not available"); return None; }; @@ -298,10 +298,14 @@ impl HaSetActor { // get vdpu data received via vdpu udpate fn get_vdpu(&self, incoming: &Incoming, vdpu_id: &str) -> Option { let key = VDpuActorState::msg_key(vdpu_id); - let Ok(msg) = incoming.get(&key) else { - return None; - }; - msg.deserialize_data().ok() + let msg = incoming.get(&key)?; + match msg.deserialize_data() { + Ok(vdpu) => Some(vdpu), + Err(e) => { + error!("Failed to deserialize VDpuActorState from the message: {}", e); + None + } + } } /// Get vdpu data received via vdpu update and return them in a list with primary DPUs first. 
@@ -366,7 +370,7 @@ impl HaSetActor { context: &mut Context, ) -> Result<()> { let (_internal, incoming, outgoing) = state.get_all(); - let dpu_kfv: KeyOpFieldValues = incoming.get(key)?.deserialize_data()?; + let dpu_kfv: KeyOpFieldValues = incoming.get_or_fail(key)?.deserialize_data()?; if dpu_kfv.operation == KeyOperation::Del { // cleanup resources before stopping if let Err(e) = self.do_cleanup(state) { @@ -428,7 +432,9 @@ impl HaSetActor { async fn handle_haset_state_registration(&mut self, state: &mut State, key: &str) -> Result<()> { let (_, incoming, outgoing) = state.get_all(); - let entry = incoming.get_entry(key)?; + let entry = incoming + .get_entry(key) + .ok_or_else(|| anyhow!("Entry not found for key: {}", key))?; let ActorRegistration { active, .. } = entry.msg.deserialize_data()?; if active { let Some(vdpus) = self.get_vdpus_if_ready(incoming) else { diff --git a/crates/hamgrd/src/actors/vdpu.rs b/crates/hamgrd/src/actors/vdpu.rs index f4cfae6..1740823 100644 --- a/crates/hamgrd/src/actors/vdpu.rs +++ b/crates/hamgrd/src/actors/vdpu.rs @@ -2,7 +2,7 @@ use crate::actors::dpu::DpuActor; use crate::actors::DbBasedActor; use crate::db_structs::VDpu; use crate::ha_actor_messages::{ActorRegistration, DpuActorState, RegistrationType, VDpuActorState}; -use anyhow::Result; +use anyhow::{anyhow, Result}; use sonic_common::SonicDbTable; use swbus_actor::Context; use swbus_actor::{state::incoming::Incoming, state::outgoing::Outgoing, Actor, State}; @@ -31,6 +31,21 @@ impl DbBasedActor for VDpuActor { } impl VDpuActor { + fn get_dpu_actor_state(incoming: &Incoming, dpu_id: &str) -> Result> { + let msg = incoming.get(&format!("{}{}", DpuActorState::msg_key_prefix(), dpu_id)); + let Some(msg) = msg else { + // dpu data is not available yet + return Ok(None); + }; + match msg.deserialize_data() { + Ok(dpu) => Ok(Some(dpu)), + Err(e) => { + error!("Failed to deserialize DpuActorState from the message: {}", e); + Err(e) + } + } + } + fn register_to_dpu_actor(&self, 
outgoing: &mut Outgoing, active: bool) -> Result<()> { if self.vdpu.is_none() { return Ok(()); @@ -52,7 +67,7 @@ impl VDpuActor { async fn handle_vdpu_message(&mut self, state: &mut State, key: &str, context: &mut Context) -> Result<()> { let (_internal, incoming, outgoing) = state.get_all(); - let dpu_kfv: KeyOpFieldValues = incoming.get(key)?.deserialize_data()?; + let dpu_kfv: KeyOpFieldValues = incoming.get_or_fail(key)?.deserialize_data()?; if dpu_kfv.operation == KeyOperation::Del { // unregister from the DPU Actor self.do_cleanup(context, state); @@ -86,20 +101,14 @@ impl VDpuActor { } // only one dpu is supported for now let dpu_id = &self.vdpu.as_ref().unwrap().main_dpu_ids[0]; - let msg = incoming.get(&format!("{}{}", DpuActorState::msg_key_prefix(), dpu_id)); - - let Ok(msg) = msg else { - // dpu data is not available yet - return None; + let dpu = match Self::get_dpu_actor_state(incoming, dpu_id) { + Ok(None) => return None, + Ok(Some(dpu_actor_state)) => dpu_actor_state, + Err(_) => return None, }; - if let Ok(dpu) = msg.deserialize_data::() { - let vdpu = VDpuActorState { up: dpu.up, dpu }; - Some(vdpu) - } else { - error!("Failed to deserialize DpuActorState from the message"); - None - } + let vdpu = VDpuActorState { up: dpu.up, dpu }; + Some(vdpu) } async fn handle_vdpu_state_registration( @@ -108,7 +117,9 @@ impl VDpuActor { incoming: &Incoming, outgoing: &mut Outgoing, ) -> Result<()> { - let entry = incoming.get_entry(key)?; + let entry = incoming + .get_entry(key) + .ok_or_else(|| anyhow!("Entry not found for key: {}", key))?; let ActorRegistration { active, .. 
} = entry.msg.deserialize_data()?; if active { let Some(vdpu_state) = self.calculate_vdpu_state(incoming) else { diff --git a/crates/swbus-actor/src/state.rs b/crates/swbus-actor/src/state.rs index 9a3a6f2..48d04f1 100644 --- a/crates/swbus-actor/src/state.rs +++ b/crates/swbus-actor/src/state.rs @@ -40,7 +40,7 @@ impl State { /// # fn foo() -> anyhow::Result<()> { /// # let state: swbus_actor::State = todo!(); /// let (internal, incoming, outgoing) = state.get_all(); - /// let x = incoming.get("x")?; + /// let x = incoming.get_or_fail("x")?; /// let tbl = internal.get_mut("tbl"); /// tbl["x"] = x.deserialize_data::()?.into(); /// // Ok diff --git a/crates/swbus-actor/src/state/incoming.rs b/crates/swbus-actor/src/state/incoming.rs index dc11b70..0b55b6c 100644 --- a/crates/swbus-actor/src/state/incoming.rs +++ b/crates/swbus-actor/src/state/incoming.rs @@ -15,13 +15,17 @@ pub struct Incoming { } impl Incoming { - pub fn get(&self, key: &str) -> Result<&ActorMessage> { + pub fn get(&self, key: &str) -> Option<&ActorMessage> { self.get_entry(key).map(|entry| &entry.msg) } - pub fn get_entry(&self, key: &str) -> Result<&IncomingTableEntry> { - self.table - .get(key) + pub fn get_entry(&self, key: &str) -> Option<&IncomingTableEntry> { + self.table.get(key) + } + + pub fn get_or_fail(&self, key: &str) -> Result<&ActorMessage> { + self.get_entry(key) + .map(|entry| &entry.msg) .ok_or_else(|| anyhow!("Incoming state table has no key '{key}'")) } @@ -210,6 +214,10 @@ mod test { assert_eq!(incoming.get_entry("actor_registration-source/0").unwrap().msg, msg1); assert_eq!(incoming.get_entry("actor_registration-source/1").unwrap().msg, msg2); + // Test get_or_fail + assert_eq!(incoming.get_or_fail("actor_registration-source/0").unwrap(), &msg1); + assert!(incoming.get_or_fail("nonexistent").is_err()); + let regs = incoming.get_by_prefix("actor_registration-"); assert_eq!(regs.len(), 2); } diff --git a/crates/swbus-actor/tests/echo.rs b/crates/swbus-actor/tests/echo.rs index 
13a2adc..45f1198 100644 --- a/crates/swbus-actor/tests/echo.rs +++ b/crates/swbus-actor/tests/echo.rs @@ -67,7 +67,10 @@ impl Actor for EchoClient { // Assert that the incoming table has messages 0..=count still cached for i in 0..=count { - let n = state.incoming().get(&format!("{i}"))?.deserialize_data::()?; + let n = state + .incoming() + .get_or_fail(&format!("{i}"))? + .deserialize_data::()?; assert_eq!(n, i); } diff --git a/crates/swbus-actor/tests/kvstore.rs b/crates/swbus-actor/tests/kvstore.rs index 5463cc0..3a174ea 100644 --- a/crates/swbus-actor/tests/kvstore.rs +++ b/crates/swbus-actor/tests/kvstore.rs @@ -354,7 +354,7 @@ impl Actor for KVClient { } assert_eq!(key, "kv-get"); - let KVGetResult { key, val } = state.incoming().get(key)?.deserialize_data::()?; + let KVGetResult { key, val } = state.incoming().get_or_fail(key)?.deserialize_data::()?; match &*key { "count" => { From aad309c4f9e34ee913ec173a8e3a520e41032df5 Mon Sep 17 00:00:00 2001 From: yue-fred-gao <132678244+yue-fred-gao@users.noreply.github.com> Date: Tue, 23 Sep 2025 16:42:17 -0400 Subject: [PATCH 14/14] Make deserializer more forgiving when parsing dash_bfd_probe_state (#113) ### why currently the deserializer for dash_bfd_probe_state has strict requirements on the format of the fields. If a field doesn't follow the format, the deserializer will reject it. Specifically, the timestamp field is enclosed in double quotes, which caused a parsing error. ### what this PR does make the deserializer more forgiving with the format. If the value has double or single quotes or whitespace, remove them first. If the value of v4_bfd_up_sessions or v6_bfd_up_sessions has quotes or spaces between commas, remove them.
--- crates/hamgrd/src/actors/dpu.rs | 4 +- crates/hamgrd/src/db_structs.rs | 115 ++++++++++++++++++++++++++++++-- 2 files changed, 111 insertions(+), 8 deletions(-) diff --git a/crates/hamgrd/src/actors/dpu.rs b/crates/hamgrd/src/actors/dpu.rs index 66e1f56..0c994e4 100644 --- a/crates/hamgrd/src/actors/dpu.rs +++ b/crates/hamgrd/src/actors/dpu.rs @@ -582,7 +582,7 @@ mod test { recv! { key: "switch0_dpu0", data: {"key": "default:default:10.0.3.0", "operation": "Set", "field_values": bfd_fvs}, addr: crate::common_bridge_sp::(&runtime.get_swbus_edge()) }, - send! { key: DashBfdProbeState::table_name(), data: { "key": "", "operation": "Set", "field_values":serde_json::to_value(to_field_values(&dpu_bfd_up_state).unwrap()).unwrap()} }, + send! { key: DashBfdProbeState::table_name(), data: { "key": "dash_ha", "operation": "Set", "field_values":serde_json::to_value(to_field_values(&dpu_bfd_up_state).unwrap()).unwrap()} }, recv! { key: "DPUStateUpdate|switch0_dpu0", data: dpu_actor_up_state, addr: runtime.sp("vdpu", "test-vdpu") }, // Simulate DPU_STATE planes going down then up @@ -592,7 +592,7 @@ mod test { recv! { key: "DPUStateUpdate|switch0_dpu0", data: dpu_actor_up_state, addr: runtime.sp("vdpu", "test-vdpu") }, // Simulate BFD probe going down - send! { key: DashBfdProbeState::table_name(), data: { "key": "", "operation": "Set", "field_values": serde_json::to_value(to_field_values(&dpu_bfd_down_state).unwrap()).unwrap()} }, + send! { key: DashBfdProbeState::table_name(), data: { "key": "dash_ha", "operation": "Set", "field_values": serde_json::to_value(to_field_values(&dpu_bfd_down_state).unwrap()).unwrap()} }, recv! 
{ key: "DPUStateUpdate|switch0_dpu0", data: dpu_actor_bfd_down_state, addr: runtime.sp("vdpu", "test-vdpu") }, // simulate delete of Dpu entry diff --git a/crates/hamgrd/src/db_structs.rs b/crates/hamgrd/src/db_structs.rs index 635e540..5fb3c0a 100644 --- a/crates/hamgrd/src/db_structs.rs +++ b/crates/hamgrd/src/db_structs.rs @@ -145,7 +145,6 @@ impl Default for DpuState { } /// -#[serde_as] #[derive(Serialize, Deserialize, Clone, PartialEq, Eq, Debug, SonicDb)] #[sonicdb( table_name = "DASH_BFD_PROBE_STATE", @@ -154,8 +153,11 @@ impl Default for DpuState { is_dpu = "true" )] pub struct DashBfdProbeState { - #[serde(default)] - #[serde_as(as = "StringWithSeparator::")] + #[serde( + default, + deserialize_with = "string_array_deserialize", + serialize_with = "string_array_serialize" + )] pub v4_bfd_up_sessions: Vec, #[serde( default = "now_in_millis", @@ -163,8 +165,11 @@ pub struct DashBfdProbeState { serialize_with = "timestamp_serialize" )] pub v4_bfd_up_sessions_timestamp: i64, - #[serde(default)] - #[serde_as(as = "StringWithSeparator::")] + #[serde( + default, + deserialize_with = "string_array_deserialize", + serialize_with = "string_array_serialize" + )] pub v6_bfd_up_sessions: Vec, #[serde( default = "now_in_millis", @@ -195,6 +200,47 @@ where serializer.serialize_str(&formatted) } +fn string_array_deserialize<'de, D>(deserializer: D) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + // Handle both missing fields and present fields + let opt: Option = Option::deserialize(deserializer)?; + match opt { + Some(s) => { + if s.trim().is_empty() { + return Ok(Vec::new()); + } + + let sessions: Vec = s + .split(',') + .map(|item| { + // Trim whitespace around each item + let trimmed = item.trim(); + // Remove enclosing single or double quotes if present + let trimmed = trimmed.strip_prefix('"').unwrap_or(trimmed); + let trimmed = trimmed.strip_suffix('"').unwrap_or(trimmed); + let trimmed = trimmed.strip_prefix('\'').unwrap_or(trimmed); + let trimmed = 
trimmed.strip_suffix('\'').unwrap_or(trimmed); + trimmed.to_string() + }) + .filter(|item| !item.is_empty()) // Filter out empty strings + .collect(); + + Ok(sessions) + } + None => Ok(Vec::new()), // Use empty vector when field is missing + } +} + +fn string_array_serialize(sessions: &[String], serializer: S) -> Result +where + S: Serializer, +{ + let joined = sessions.join(","); + serializer.serialize_str(&joined) +} + fn timestamp_deserialize<'de, D>(deserializer: D) -> Result where D: Deserializer<'de>, @@ -203,7 +249,14 @@ where let opt: Option = Option::deserialize(deserializer)?; match opt { Some(s) => { - let naive = chrono::NaiveDateTime::parse_from_str(&s, TIMESTAMP_FORMAT).map_err(de::Error::custom)?; + // Trim whitespace and remove enclosing single or double quotes if present + let s = s.trim(); + let s = s.strip_prefix('"').unwrap_or(s); + let s = s.strip_suffix('"').unwrap_or(s); + let s = s.strip_prefix('\'').unwrap_or(s); + let s = s.strip_suffix('\'').unwrap_or(s); + + let naive = chrono::NaiveDateTime::parse_from_str(s, TIMESTAMP_FORMAT).map_err(de::Error::custom)?; Ok(naive.and_utc().timestamp_millis()) } None => Ok(now_in_millis()), // Use default when field is missing @@ -705,4 +758,54 @@ mod test { let now = chrono::Utc::now().timestamp_millis(); assert!((bfd_state.v6_bfd_up_sessions_timestamp - now).abs() < 1000); } + + #[test] + fn test_bfd_sessions_quote_and_whitespace_handling() { + let json = r#" + { + "v4_bfd_up_sessions": " \"10.0.1.1\" , '10.0.1.2', 10.0.1.3 , \"10.0.1.4\" ", + "v6_bfd_up_sessions": "'2001:db8::1', \"2001:db8::2\" , 2001:db8::3 " + }"#; + + let fvs: FieldValues = serde_json::from_str(json).unwrap(); + let bfd_state: DashBfdProbeState = swss_serde::from_field_values(&fvs).unwrap(); + + // Test v4 sessions - should have quotes removed and whitespace trimmed + assert_eq!(bfd_state.v4_bfd_up_sessions.len(), 4); + assert_eq!(bfd_state.v4_bfd_up_sessions[0], "10.0.1.1"); + assert_eq!(bfd_state.v4_bfd_up_sessions[1], 
"10.0.1.2"); + assert_eq!(bfd_state.v4_bfd_up_sessions[2], "10.0.1.3"); + assert_eq!(bfd_state.v4_bfd_up_sessions[3], "10.0.1.4"); + + // Test v6 sessions - should have quotes removed and whitespace trimmed + assert_eq!(bfd_state.v6_bfd_up_sessions.len(), 3); + assert_eq!(bfd_state.v6_bfd_up_sessions[0], "2001:db8::1"); + assert_eq!(bfd_state.v6_bfd_up_sessions[1], "2001:db8::2"); + assert_eq!(bfd_state.v6_bfd_up_sessions[2], "2001:db8::3"); + } + + #[test] + fn test_bfd_sessions_serialization_roundtrip() { + // Test data with quotes and whitespace + let json = r#" + { + "v4_bfd_up_sessions": " \"10.0.1.1\" , '10.0.1.2', 10.0.1.3 ", + "v6_bfd_up_sessions": "'2001:db8::1', \"2001:db8::2\"" + }"#; + + let fvs: FieldValues = serde_json::from_str(json).unwrap(); + let bfd_state: DashBfdProbeState = swss_serde::from_field_values(&fvs).unwrap(); + + // Serialize back to field values + let serialized_fvs = swss_serde::to_field_values(&bfd_state).unwrap(); + + // Check that serialized format is clean comma-separated without quotes + assert_eq!(serialized_fvs["v4_bfd_up_sessions"], "10.0.1.1,10.0.1.2,10.0.1.3"); + assert_eq!(serialized_fvs["v6_bfd_up_sessions"], "2001:db8::1,2001:db8::2"); + + // Deserialize again to ensure consistency + let bfd_state2: DashBfdProbeState = swss_serde::from_field_values(&serialized_fvs).unwrap(); + assert_eq!(bfd_state.v4_bfd_up_sessions, bfd_state2.v4_bfd_up_sessions); + assert_eq!(bfd_state.v6_bfd_up_sessions, bfd_state2.v6_bfd_up_sessions); + } }