Skip to content

Commit

Permalink
Merge pull request #3 from BP-WG/psbt
Browse files Browse the repository at this point in the history
Psbt bugfixing and test covering
  • Loading branch information
dr-orlovsky authored Oct 5, 2023
2 parents fc52055 + 31e330b commit e1fd99c
Show file tree
Hide file tree
Showing 15 changed files with 194 additions and 97 deletions.
66 changes: 56 additions & 10 deletions psbt/src/coders.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ use bpstd::{
CompressedPk, ConsensusDataError, ConsensusDecode, ConsensusDecodeError, ConsensusEncode,
DerivationIndex, DerivationPath, HardenedIndex, Idx, InternalPk, KeyOrigin, LegacyPk, LockTime,
RedeemScript, Sats, ScriptBytes, ScriptPubkey, SeqNo, SigScript, TaprootPk, Tx, TxOut, TxVer,
Txid, UncompressedPk, VarInt, Vout, Witness, WitnessScript, Xpub, XpubDecodeError, XpubFp,
XpubOrigin,
Txid, UncompressedPk, VarInt, VarIntArray, Vout, Witness, WitnessScript, Xpub, XpubDecodeError,
XpubFp, XpubOrigin,
};

use crate::keys::KeyValue;
Expand Down Expand Up @@ -223,7 +223,7 @@ impl Psbt {
.map(PsbtVer::deserialize)
.transpose()?
.unwrap_or(PsbtVer::V0);
let mut psbt = Psbt::create();
let mut psbt = Psbt::create(PsbtVer::V0);
psbt.parse_map(version, map)?;

for input in &mut psbt.inputs {
Expand Down Expand Up @@ -442,10 +442,14 @@ impl Encode for LegacyPk {

impl Decode for LegacyPk {
fn decode(reader: &mut impl Read) -> Result<Self, DecodeError> {
let flag = u8::decode(reader)?;
match flag {
02 | 03 => CompressedPk::decode(reader).map(Self::from),
04 => UncompressedPk::decode(reader).map(Self::from),
let mut buf = [0u8; 65];
reader.read_exact(&mut buf[..33])?;
match buf[0] {
02 | 03 => CompressedPk::decode(&mut Cursor::new(&buf[..33])).map(Self::from),
04 => {
reader.read_exact(&mut buf[33..])?;
UncompressedPk::decode(&mut Cursor::new(buf)).map(Self::from)
}
other => Err(PsbtError::UnrecognizedKeyFormat(other).into()),
}
}
Expand Down Expand Up @@ -622,11 +626,33 @@ impl Decode for LockHeight {
}
}

psbt_code_using_consensus!(ScriptBytes);
psbt_code_using_consensus!(SigScript);
psbt_code_using_consensus!(ScriptPubkey);
psbt_code_using_consensus!(Witness);

impl Encode for ScriptBytes {
fn encode(&self, writer: &mut dyn Write) -> Result<usize, IoError> {
RawBytes(self.as_inner()).encode(writer)
}
}

impl Decode for ScriptBytes {
fn decode(reader: &mut impl Read) -> Result<Self, DecodeError> {
let bytes = RawBytes::<VarIntArray<u8>>::decode(reader)?;
Ok(ScriptBytes::from_inner(bytes.0))
}
}

impl Encode for ScriptPubkey {
fn encode(&self, writer: &mut dyn Write) -> Result<usize, IoError> {
self.as_script_bytes().encode(writer)
}
}

impl Decode for ScriptPubkey {
fn decode(reader: &mut impl Read) -> Result<Self, DecodeError> {
ScriptBytes::decode(reader).map(Self::from_inner)
}
}

impl Encode for WitnessScript {
fn encode(&self, writer: &mut dyn Write) -> Result<usize, IoError> {
self.as_script_bytes().encode(writer)
Expand All @@ -651,6 +677,18 @@ impl Decode for RedeemScript {
}
}

impl Encode for SigScript {
fn encode(&self, writer: &mut dyn Write) -> Result<usize, IoError> {
self.as_script_bytes().encode(writer)
}
}

impl Decode for SigScript {
fn decode(reader: &mut impl Read) -> Result<Self, DecodeError> {
ScriptBytes::decode(reader).map(Self::from_inner)
}
}

psbt_code_using_consensus!(Sats);
psbt_code_using_consensus!(u8);
psbt_code_using_consensus!(u32);
Expand All @@ -674,6 +712,14 @@ impl Decode for RawBytes<Vec<u8>> {
}
}

impl Decode for RawBytes<VarIntArray<u8>> {
fn decode(reader: &mut impl Read) -> Result<Self, DecodeError> {
let mut buf = Vec::new();
reader.read_to_end(&mut buf)?;
VarIntArray::try_from(buf).map(Self).map_err(DecodeError::from)
}
}

impl<const LEN: usize> Encode for Array<u8, LEN> {
fn encode(&self, writer: &mut dyn Write) -> Result<usize, IoError> {
writer.write_all(self.as_inner())?;
Expand Down
152 changes: 81 additions & 71 deletions psbt/src/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::fmt;
use std::fmt::{Display, Formatter, LowerHex};
use std::str::FromStr;

use amplify::confinement::Confined;
use amplify::hex::{FromHex, ToHex};
use amplify::num::u5;
use amplify::{hex, Bytes20, Bytes32};
use base64::Engine;
use amplify::{Bytes20, Bytes32};
use bpstd::{
CompressedPk, Descriptor, InternalPk, KeyOrigin, LegacyPk, LockTime, NormalIndex, Outpoint,
RedeemScript, Sats, ScriptPubkey, SeqNo, SigScript, TaprootPk, Terminal, Tx, TxIn, TxOut,
TxVer, Txid, Vout, Witness, WitnessScript, Xpub, XpubOrigin,
};
use indexmap::IndexMap;

pub use self::display_from_str::PsbtParseError;
use crate::{
Bip340Sig, KeyData, LegacySig, LockHeight, LockTimestamp, PropKey, PsbtError, PsbtVer,
SighashType, ValueData,
Expand Down Expand Up @@ -103,13 +98,13 @@ pub struct Psbt {
}

impl Default for Psbt {
fn default() -> Self { Psbt::create() }
fn default() -> Self { Psbt::create(PsbtVer::V2) }
}

impl Psbt {
pub fn create() -> Psbt {
pub fn create(version: PsbtVer) -> Psbt {
Psbt {
version: PsbtVer::V2,
version,
tx_version: TxVer::V2,
fallback_locktime: None,
inputs: vec![],
Expand All @@ -122,12 +117,13 @@ impl Psbt {
}

pub fn from_unsigned_tx(unsigned_tx: Tx) -> Self {
let mut psbt = Psbt::create();
let mut psbt = Psbt::create(PsbtVer::V0);
psbt.reset_from_unsigned_tx(unsigned_tx);
psbt
}

pub(crate) fn reset_from_unsigned_tx(&mut self, tx: Tx) {
self.version = PsbtVer::V0;
self.tx_version = tx.version;
self.fallback_locktime = Some(tx.lock_time);
self.inputs = tx.inputs.into_iter().enumerate().map(Input::from_unsigned_txin).collect();
Expand Down Expand Up @@ -317,84 +313,92 @@ impl Psbt {
}
}

#[derive(Clone, Debug, Display, Error, From)]
#[display(inner)]
pub enum PsbtParseError {
#[from]
Hex(hex::Error),
mod display_from_str {
use std::fmt::{self, Display, Formatter, LowerHex};
use std::str::FromStr;

#[from]
Base64(base64::DecodeError),
use amplify::hex::{self, FromHex, ToHex};
use base64::display::Base64Display;
use base64::prelude::BASE64_STANDARD;
use base64::Engine;

#[from]
Psbt(PsbtError),
}
use super::*;

impl Psbt {
fn base64_engine() -> base64::engine::GeneralPurpose {
base64::engine::GeneralPurpose::new(
&base64::alphabet::STANDARD,
base64::engine::GeneralPurposeConfig::new(),
)
}
#[derive(Clone, Debug, Display, Error, From)]
#[display(inner)]
pub enum PsbtParseError {
#[from]
Hex(hex::Error),

pub fn from_base64(s: &str) -> Result<Psbt, PsbtParseError> {
let engine = Self::base64_engine();
let data = engine.decode(s)?;
Psbt::deserialize(data).map_err(PsbtParseError::from)
}
#[from]
Base64(base64::DecodeError),

pub fn from_base16(s: &str) -> Result<Psbt, PsbtParseError> {
let data = Vec::<u8>::from_hex(s)?;
Psbt::deserialize(data).map_err(PsbtParseError::from)
#[from]
Psbt(PsbtError),
}

pub fn to_base64(&self) -> String { self.to_base64_ver(self.version) }
impl Psbt {
pub fn from_base64(s: &str) -> Result<Psbt, PsbtParseError> {
Psbt::deserialize(BASE64_STANDARD.decode(s)?).map_err(PsbtParseError::from)
}

pub fn to_base64_ver(&self, version: PsbtVer) -> String {
let engine = Self::base64_engine();
engine.encode(self.serialize(version))
}
pub fn from_base16(s: &str) -> Result<Psbt, PsbtParseError> {
let data = Vec::<u8>::from_hex(s)?;
Psbt::deserialize(data).map_err(PsbtParseError::from)
}

pub fn to_base16(&self) -> String { self.to_base16_ver(self.version) }
pub fn to_base64(&self) -> String { self.to_base64_ver(self.version) }

pub fn to_base16_ver(&self, version: PsbtVer) -> String { self.serialize(version).to_hex() }
}
pub fn to_base64_ver(&self, version: PsbtVer) -> String {
BASE64_STANDARD.encode(self.serialize(version))
}

impl FromStr for Psbt {
type Err = PsbtParseError;
pub fn to_base16(&self) -> String { self.to_base16_ver(self.version) }

#[inline]
fn from_str(s: &str) -> Result<Self, Self::Err> {
Self::from_base16(s).or_else(|_| Self::from_base64(s))
pub fn to_base16_ver(&self, version: PsbtVer) -> String { self.serialize(version).to_hex() }
}
}

impl Display for Psbt {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
let mut ver = match f.width().unwrap_or(0) {
0 => PsbtVer::V0,
2 => PsbtVer::V2,
_ => return Err(fmt::Error),
};
if f.alternate() {
ver = PsbtVer::V2;
/// FromStr implementation parses both Base64 and Hex (Base16) encodings.
impl FromStr for Psbt {
type Err = PsbtParseError;

#[inline]
fn from_str(s: &str) -> Result<Self, Self::Err> {
Self::from_base16(s).or_else(|_| Self::from_base64(s))
}
f.write_str(&self.to_base64_ver(ver))
}
}

impl LowerHex for Psbt {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
let mut ver = match f.width().unwrap_or(0) {
0 => PsbtVer::V0,
2 => PsbtVer::V2,
_ => return Err(fmt::Error),
};
if f.alternate() {
ver = PsbtVer::V2;
/// PSBT displays Base64-encoded string. The selection of the version if the following:
/// - by default, it uses version specified in the PSBT itself;
/// - if zero `{:0}` is given and no width (`{:0}`) or a zero width (`{:00}`) is provided, than
/// the PSBT is encoded as V0 even if the structure itself uses V2;
/// - if a width equal to two is given like in `{:2}`, than zero flag is ignored (so `{:02}`
/// also works that way) and PSBT is encoded as V2 even if the structure itself uses V1;
/// - all other flags has no effect on the display.
impl Display for Psbt {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
let ver = match (f.width(), f.sign_aware_zero_pad()) {
(None, true) => PsbtVer::V0,
(Some(0), _) => PsbtVer::V0,
(Some(2), _) => PsbtVer::V2,
_ => self.version,
};
write!(f, "{}", Base64Display::new(&self.serialize(ver), &BASE64_STANDARD))
}
}

impl LowerHex for Psbt {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
let mut ver = match f.width().unwrap_or(0) {
0 => PsbtVer::V0,
2 => PsbtVer::V2,
_ => return Err(fmt::Error),
};
if f.alternate() {
ver = PsbtVer::V2;
}
f.write_str(&self.to_base16_ver(ver))
}
f.write_str(&self.to_base16_ver(ver))
}
}

Expand Down Expand Up @@ -616,6 +620,9 @@ impl Input {

#[inline]
pub fn value(&self) -> Sats { self.prev_txout().value }

#[inline]
pub fn index(&self) -> usize { self.index }
}

#[derive(Clone, Eq, PartialEq, Debug)]
Expand Down Expand Up @@ -708,6 +715,9 @@ impl Output {

#[inline]
pub fn value(&self) -> Sats { self.amount }

#[inline]
pub fn index(&self) -> usize { self.index }
}

#[derive(Clone, Eq, PartialEq, Debug)]
Expand Down
1 change: 1 addition & 0 deletions psbt/src/keys.rs
Original file line number Diff line number Diff line change
Expand Up @@ -704,6 +704,7 @@ pub enum KeyValue<T: KeyType> {
Separator,
}

#[derive(Debug)]
pub struct KeyPair<T: KeyType, K, V> {
pub key_type: T,
pub key_data: K,
Expand Down
Loading

0 comments on commit e1fd99c

Please sign in to comment.