From bf67519768d91b049bf3a47c234ff71a81d0ede9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BF=97=E5=AE=87?= Date: Mon, 15 Jan 2024 17:48:36 +0800 Subject: [PATCH] feat(chain): add `LocalChain::disconnect_from` method --- crates/chain/src/local_chain.rs | 22 ++++++++ crates/chain/tests/test_local_chain.rs | 75 +++++++++++++++++++++++++- 2 files changed, 96 insertions(+), 1 deletion(-) diff --git a/crates/chain/src/local_chain.rs b/crates/chain/src/local_chain.rs index bdd25d8e0..32fd72852 100644 --- a/crates/chain/src/local_chain.rs +++ b/crates/chain/src/local_chain.rs @@ -420,6 +420,28 @@ impl LocalChain { Ok(changeset) } + /// Removes blocks from (and inclusive of) the given `block_id`. + /// + /// This will remove blocks with a height equal or greater than `block_id`, but only if + /// `block_id` exists in the chain. + /// + /// # Errors + /// + /// This will fail with [`MissingGenesisError`] if the caller attempts to disconnect from the + /// genesis block. + pub fn disconnect_from(&mut self, block_id: BlockId) -> Result { + if self.index.get(&block_id.height) != Some(&block_id.hash) { + return Ok(ChangeSet::default()); + } + + let changeset = self + .index + .range(block_id.height..) + .map(|(&height, _)| (height, None)) + .collect::(); + self.apply_changeset(&changeset).map(|_| changeset) + } + /// Reindex the heights in the chain from (and including) `from` height fn reindex(&mut self, from: u32) { let _ = self.index.split_off(&from); diff --git a/crates/chain/tests/test_local_chain.rs b/crates/chain/tests/test_local_chain.rs index d09325bd9..25cbbb08e 100644 --- a/crates/chain/tests/test_local_chain.rs +++ b/crates/chain/tests/test_local_chain.rs @@ -1,5 +1,5 @@ use bdk_chain::local_chain::{ - AlterCheckPointError, CannotConnectError, ChangeSet, LocalChain, Update, + AlterCheckPointError, CannotConnectError, ChangeSet, LocalChain, MissingGenesisError, Update, }; use bitcoin::BlockHash; @@ -350,3 +350,76 @@ fn local_chain_insert_block() { assert_eq!(chain, t.expected_final, "[{}] unexpected final chain", i,); } } + +#[test] +fn local_chain_disconnect_from() { + struct TestCase { + name: &'static str, + original: LocalChain, + disconnect_from: (u32, BlockHash), + exp_result: Result, + exp_final: LocalChain, + } + + let test_cases = [ + TestCase { + name: "try_replace_genesis_should_fail", + original: local_chain![(0, h!("_"))], + disconnect_from: (0, h!("_")), + exp_result: Err(MissingGenesisError), + exp_final: local_chain![(0, h!("_"))], + }, + TestCase { + name: "try_replace_genesis_should_fail_2", + original: local_chain![(0, h!("_")), (2, h!("B")), (3, h!("C"))], + disconnect_from: (0, h!("_")), + exp_result: Err(MissingGenesisError), + exp_final: local_chain![(0, h!("_")), (2, h!("B")), (3, h!("C"))], + }, + TestCase { + name: "from_does_not_exist", + original: local_chain![(0, h!("_")), (3, h!("C"))], + disconnect_from: (2, h!("B")), + exp_result: Ok(ChangeSet::default()), + exp_final: local_chain![(0, h!("_")), (3, h!("C"))], + }, + TestCase { + name: "from_has_different_blockhash", + original: local_chain![(0, h!("_")), (2, h!("B"))], + disconnect_from: (2, h!("not_B")), + exp_result: Ok(ChangeSet::default()), + exp_final: local_chain![(0, h!("_")), (2, h!("B"))], + }, + TestCase { + name: "disconnect_one", + original: local_chain![(0, h!("_")), (2, h!("B"))], + disconnect_from: (2, h!("B")), + exp_result: Ok(ChangeSet::from_iter([(2, None)])), + exp_final: local_chain![(0, h!("_"))], + }, + TestCase { + name: "disconnect_three", + original: local_chain![(0, h!("_")), (2, h!("B")), (3, h!("C")), (4, h!("D"))], + disconnect_from: (2, h!("B")), + exp_result: Ok(ChangeSet::from_iter([(2, None), (3, None), (4, None)])), + exp_final: local_chain![(0, h!("_"))], + }, + ]; + + for (i, t) in test_cases.into_iter().enumerate() { + println!("Case {}: {}", i, t.name); + + let mut chain = t.original; + let result = chain.disconnect_from(t.disconnect_from.into()); + assert_eq!( + result, t.exp_result, + "[{}:{}] unexpected changeset result", + i, t.name + ); + assert_eq!( + chain, t.exp_final, + "[{}:{}] unexpected final chain", + i, t.name + ); + } +}