Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Psbt bugfixing and test covering #3

Merged
merged 8 commits into from
Oct 5, 2023
66 changes: 56 additions & 10 deletions psbt/src/coders.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ use bpstd::{
CompressedPk, ConsensusDataError, ConsensusDecode, ConsensusDecodeError, ConsensusEncode,
DerivationIndex, DerivationPath, HardenedIndex, Idx, InternalPk, KeyOrigin, LegacyPk, LockTime,
RedeemScript, Sats, ScriptBytes, ScriptPubkey, SeqNo, SigScript, TaprootPk, Tx, TxOut, TxVer,
Txid, UncompressedPk, VarInt, Vout, Witness, WitnessScript, Xpub, XpubDecodeError, XpubFp,
XpubOrigin,
Txid, UncompressedPk, VarInt, VarIntArray, Vout, Witness, WitnessScript, Xpub, XpubDecodeError,
XpubFp, XpubOrigin,
};

use crate::keys::KeyValue;
Expand Down Expand Up @@ -223,7 +223,7 @@ impl Psbt {
.map(PsbtVer::deserialize)
.transpose()?
.unwrap_or(PsbtVer::V0);
let mut psbt = Psbt::create();
let mut psbt = Psbt::create(PsbtVer::V0);
psbt.parse_map(version, map)?;

for input in &mut psbt.inputs {
Expand Down Expand Up @@ -442,10 +442,14 @@ impl Encode for LegacyPk {

impl Decode for LegacyPk {
fn decode(reader: &mut impl Read) -> Result<Self, DecodeError> {
let flag = u8::decode(reader)?;
match flag {
02 | 03 => CompressedPk::decode(reader).map(Self::from),
04 => UncompressedPk::decode(reader).map(Self::from),
let mut buf = [0u8; 65];
reader.read_exact(&mut buf[..33])?;
match buf[0] {
02 | 03 => CompressedPk::decode(&mut Cursor::new(&buf[..33])).map(Self::from),
04 => {
reader.read_exact(&mut buf[33..])?;
UncompressedPk::decode(&mut Cursor::new(buf)).map(Self::from)
}
other => Err(PsbtError::UnrecognizedKeyFormat(other).into()),
}
}
Expand Down Expand Up @@ -622,11 +626,33 @@ impl Decode for LockHeight {
}
}

psbt_code_using_consensus!(ScriptBytes);
psbt_code_using_consensus!(SigScript);
psbt_code_using_consensus!(ScriptPubkey);
psbt_code_using_consensus!(Witness);

impl Encode for ScriptBytes {
fn encode(&self, writer: &mut dyn Write) -> Result<usize, IoError> {
RawBytes(self.as_inner()).encode(writer)
}
}

impl Decode for ScriptBytes {
fn decode(reader: &mut impl Read) -> Result<Self, DecodeError> {
let bytes = RawBytes::<VarIntArray<u8>>::decode(reader)?;
Ok(ScriptBytes::from_inner(bytes.0))
}
}

impl Encode for ScriptPubkey {
fn encode(&self, writer: &mut dyn Write) -> Result<usize, IoError> {
self.as_script_bytes().encode(writer)
}
}

impl Decode for ScriptPubkey {
fn decode(reader: &mut impl Read) -> Result<Self, DecodeError> {
ScriptBytes::decode(reader).map(Self::from_inner)
}
}

impl Encode for WitnessScript {
fn encode(&self, writer: &mut dyn Write) -> Result<usize, IoError> {
self.as_script_bytes().encode(writer)
Expand All @@ -651,6 +677,18 @@ impl Decode for RedeemScript {
}
}

impl Encode for SigScript {
fn encode(&self, writer: &mut dyn Write) -> Result<usize, IoError> {
self.as_script_bytes().encode(writer)
}
}

impl Decode for SigScript {
fn decode(reader: &mut impl Read) -> Result<Self, DecodeError> {
ScriptBytes::decode(reader).map(Self::from_inner)
}
}

psbt_code_using_consensus!(Sats);
psbt_code_using_consensus!(u8);
psbt_code_using_consensus!(u32);
Expand All @@ -674,6 +712,14 @@ impl Decode for RawBytes<Vec<u8>> {
}
}

impl Decode for RawBytes<VarIntArray<u8>> {
fn decode(reader: &mut impl Read) -> Result<Self, DecodeError> {
let mut buf = Vec::new();
reader.read_to_end(&mut buf)?;
VarIntArray::try_from(buf).map(Self).map_err(DecodeError::from)
}
}

impl<const LEN: usize> Encode for Array<u8, LEN> {
fn encode(&self, writer: &mut dyn Write) -> Result<usize, IoError> {
writer.write_all(self.as_inner())?;
Expand Down
152 changes: 81 additions & 71 deletions psbt/src/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::fmt;
use std::fmt::{Display, Formatter, LowerHex};
use std::str::FromStr;

use amplify::confinement::Confined;
use amplify::hex::{FromHex, ToHex};
use amplify::num::u5;
use amplify::{hex, Bytes20, Bytes32};
use base64::Engine;
use amplify::{Bytes20, Bytes32};
use bpstd::{
CompressedPk, Descriptor, InternalPk, KeyOrigin, LegacyPk, LockTime, NormalIndex, Outpoint,
RedeemScript, Sats, ScriptPubkey, SeqNo, SigScript, TaprootPk, Terminal, Tx, TxIn, TxOut,
TxVer, Txid, Vout, Witness, WitnessScript, Xpub, XpubOrigin,
};
use indexmap::IndexMap;

pub use self::display_from_str::PsbtParseError;
use crate::{
Bip340Sig, KeyData, LegacySig, LockHeight, LockTimestamp, PropKey, PsbtError, PsbtVer,
SighashType, ValueData,
Expand Down Expand Up @@ -103,13 +98,13 @@ pub struct Psbt {
}

impl Default for Psbt {
fn default() -> Self { Psbt::create() }
fn default() -> Self { Psbt::create(PsbtVer::V2) }
}

impl Psbt {
pub fn create() -> Psbt {
pub fn create(version: PsbtVer) -> Psbt {
Psbt {
version: PsbtVer::V2,
version,
tx_version: TxVer::V2,
fallback_locktime: None,
inputs: vec![],
Expand All @@ -122,12 +117,13 @@ impl Psbt {
}

pub fn from_unsigned_tx(unsigned_tx: Tx) -> Self {
let mut psbt = Psbt::create();
let mut psbt = Psbt::create(PsbtVer::V0);
psbt.reset_from_unsigned_tx(unsigned_tx);
psbt
}

pub(crate) fn reset_from_unsigned_tx(&mut self, tx: Tx) {
self.version = PsbtVer::V0;
self.tx_version = tx.version;
self.fallback_locktime = Some(tx.lock_time);
self.inputs = tx.inputs.into_iter().enumerate().map(Input::from_unsigned_txin).collect();
Expand Down Expand Up @@ -317,84 +313,92 @@ impl Psbt {
}
}

#[derive(Clone, Debug, Display, Error, From)]
#[display(inner)]
pub enum PsbtParseError {
#[from]
Hex(hex::Error),
mod display_from_str {
use std::fmt::{self, Display, Formatter, LowerHex};
use std::str::FromStr;

#[from]
Base64(base64::DecodeError),
use amplify::hex::{self, FromHex, ToHex};
use base64::display::Base64Display;
use base64::prelude::BASE64_STANDARD;
use base64::Engine;

#[from]
Psbt(PsbtError),
}
use super::*;

impl Psbt {
fn base64_engine() -> base64::engine::GeneralPurpose {
base64::engine::GeneralPurpose::new(
&base64::alphabet::STANDARD,
base64::engine::GeneralPurposeConfig::new(),
)
}
#[derive(Clone, Debug, Display, Error, From)]
#[display(inner)]
pub enum PsbtParseError {
#[from]
Hex(hex::Error),

pub fn from_base64(s: &str) -> Result<Psbt, PsbtParseError> {
let engine = Self::base64_engine();
let data = engine.decode(s)?;
Psbt::deserialize(data).map_err(PsbtParseError::from)
}
#[from]
Base64(base64::DecodeError),

pub fn from_base16(s: &str) -> Result<Psbt, PsbtParseError> {
let data = Vec::<u8>::from_hex(s)?;
Psbt::deserialize(data).map_err(PsbtParseError::from)
#[from]
Psbt(PsbtError),
}

pub fn to_base64(&self) -> String { self.to_base64_ver(self.version) }
impl Psbt {
pub fn from_base64(s: &str) -> Result<Psbt, PsbtParseError> {
Psbt::deserialize(BASE64_STANDARD.decode(s)?).map_err(PsbtParseError::from)
}

pub fn to_base64_ver(&self, version: PsbtVer) -> String {
let engine = Self::base64_engine();
engine.encode(self.serialize(version))
}
pub fn from_base16(s: &str) -> Result<Psbt, PsbtParseError> {
let data = Vec::<u8>::from_hex(s)?;
Psbt::deserialize(data).map_err(PsbtParseError::from)
}

pub fn to_base16(&self) -> String { self.to_base16_ver(self.version) }
pub fn to_base64(&self) -> String { self.to_base64_ver(self.version) }

pub fn to_base16_ver(&self, version: PsbtVer) -> String { self.serialize(version).to_hex() }
}
pub fn to_base64_ver(&self, version: PsbtVer) -> String {
BASE64_STANDARD.encode(self.serialize(version))
}

impl FromStr for Psbt {
type Err = PsbtParseError;
pub fn to_base16(&self) -> String { self.to_base16_ver(self.version) }

#[inline]
fn from_str(s: &str) -> Result<Self, Self::Err> {
Self::from_base16(s).or_else(|_| Self::from_base64(s))
pub fn to_base16_ver(&self, version: PsbtVer) -> String { self.serialize(version).to_hex() }
}
}

impl Display for Psbt {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
let mut ver = match f.width().unwrap_or(0) {
0 => PsbtVer::V0,
2 => PsbtVer::V2,
_ => return Err(fmt::Error),
};
if f.alternate() {
ver = PsbtVer::V2;
/// FromStr implementation parses both Base64 and Hex (Base16) encodings.
impl FromStr for Psbt {
type Err = PsbtParseError;

#[inline]
fn from_str(s: &str) -> Result<Self, Self::Err> {
Self::from_base16(s).or_else(|_| Self::from_base64(s))
}
f.write_str(&self.to_base64_ver(ver))
}
}

impl LowerHex for Psbt {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
let mut ver = match f.width().unwrap_or(0) {
0 => PsbtVer::V0,
2 => PsbtVer::V2,
_ => return Err(fmt::Error),
};
if f.alternate() {
ver = PsbtVer::V2;
/// PSBT displays Base64-encoded string. The selection of the version if the following:
/// - by default, it uses version specified in the PSBT itself;
/// - if zero `{:0}` is given and no width (`{:0}`) or a zero width (`{:00}`) is provided, than
/// the PSBT is encoded as V0 even if the structure itself uses V2;
/// - if a width equal to two is given like in `{:2}`, than zero flag is ignored (so `{:02}`
/// also works that way) and PSBT is encoded as V2 even if the structure itself uses V1;
/// - all other flags has no effect on the display.
impl Display for Psbt {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
let ver = match (f.width(), f.sign_aware_zero_pad()) {
(None, true) => PsbtVer::V0,
(Some(0), _) => PsbtVer::V0,
(Some(2), _) => PsbtVer::V2,
_ => self.version,
};
write!(f, "{}", Base64Display::new(&self.serialize(ver), &BASE64_STANDARD))
}
}

impl LowerHex for Psbt {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
let mut ver = match f.width().unwrap_or(0) {
0 => PsbtVer::V0,
2 => PsbtVer::V2,
_ => return Err(fmt::Error),
};
if f.alternate() {
ver = PsbtVer::V2;
}
f.write_str(&self.to_base16_ver(ver))
}
f.write_str(&self.to_base16_ver(ver))
}
}

Expand Down Expand Up @@ -616,6 +620,9 @@ impl Input {

#[inline]
pub fn value(&self) -> Sats { self.prev_txout().value }

#[inline]
pub fn index(&self) -> usize { self.index }
}

#[derive(Clone, Eq, PartialEq, Debug)]
Expand Down Expand Up @@ -708,6 +715,9 @@ impl Output {

#[inline]
pub fn value(&self) -> Sats { self.amount }

#[inline]
pub fn index(&self) -> usize { self.index }
}

#[derive(Clone, Eq, PartialEq, Debug)]
Expand Down
1 change: 1 addition & 0 deletions psbt/src/keys.rs
Original file line number Diff line number Diff line change
Expand Up @@ -704,6 +704,7 @@ pub enum KeyValue<T: KeyType> {
Separator,
}

#[derive(Debug)]
pub struct KeyPair<T: KeyType, K, V> {
pub key_type: T,
pub key_data: K,
Expand Down
Loading
Loading