Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "Wallet: deprecate old hash to curve" #458

Merged
merged 1 commit into from
Feb 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 36 additions & 38 deletions cashu/core/crypto/b_dhke.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,26 @@

from secp256k1 import PrivateKey, PublicKey


def hash_to_curve(message: bytes) -> PublicKey:
"""Generates a point from the message hash and checks if the point lies on the curve.
If it does not, iteratively tries to compute a new point from the hash."""
point = None
msg_to_hash = message
while point is None:
_hash = hashlib.sha256(msg_to_hash).digest()
try:
# will error if point does not lie on curve
point = PublicKey(b"\x02" + _hash, raw=True)
except Exception:
msg_to_hash = _hash
return point


DOMAIN_SEPARATOR = b"Secp256k1_HashToCurve_Cashu_"


def hash_to_curve(message: bytes) -> PublicKey:
def hash_to_curve_domain_separated(message: bytes) -> PublicKey:
"""Generates a secp256k1 point from a message.

The point is generated by hashing the message with a domain separator and then
Expand Down Expand Up @@ -94,6 +110,15 @@ def step1_alice(
return B_, r


def step1_alice_domain_separated(
secret_msg: str, blinding_factor: Optional[PrivateKey] = None
) -> tuple[PublicKey, PrivateKey]:
Y: PublicKey = hash_to_curve_domain_separated(secret_msg.encode("utf-8"))
r = blinding_factor or PrivateKey()
B_: PublicKey = Y + r.pubkey # type: ignore
return B_, r


def step2_bob(B_: PublicKey, a: PrivateKey) -> Tuple[PublicKey, PrivateKey, PrivateKey]:
C_: PublicKey = B_.mult(a) # type: ignore
# produce dleq proof
Expand All @@ -111,11 +136,17 @@ def verify(a: PrivateKey, C: PublicKey, secret_msg: str) -> bool:
valid = C == Y.mult(a) # type: ignore
# BEGIN: BACKWARDS COMPATIBILITY < 0.15.1
if not valid:
valid = verify_deprecated(a, C, secret_msg)
valid = verify_domain_separated(a, C, secret_msg)
# END: BACKWARDS COMPATIBILITY < 0.15.1
return valid


def verify_domain_separated(a: PrivateKey, C: PublicKey, secret_msg: str) -> bool:
Y: PublicKey = hash_to_curve_domain_separated(secret_msg.encode("utf-8"))
valid = C == Y.mult(a) # type: ignore
return valid


def hash_e(*publickeys: PublicKey) -> bytes:
e_ = ""
for p in publickeys:
Expand Down Expand Up @@ -171,53 +202,20 @@ def carol_verify_dleq(
valid = alice_verify_dleq(B_, C_, e, s, A)
# BEGIN: BACKWARDS COMPATIBILITY < 0.15.1
if not valid:
return carol_verify_dleq_deprecated(secret_msg, r, C, e, s, A)
return carol_verify_dleq_domain_separated(secret_msg, r, C, e, s, A)
# END: BACKWARDS COMPATIBILITY < 0.15.1
return valid


# -------- Deprecated hash_to_curve before 0.15.0 --------


def hash_to_curve_deprecated(message: bytes) -> PublicKey:
"""Generates a point from the message hash and checks if the point lies on the curve.
If it does not, iteratively tries to compute a new point from the hash."""
point = None
msg_to_hash = message
while point is None:
_hash = hashlib.sha256(msg_to_hash).digest()
try:
# will error if point does not lie on curve
point = PublicKey(b"\x02" + _hash, raw=True)
except Exception:
msg_to_hash = _hash
return point


def step1_alice_deprecated(
secret_msg: str, blinding_factor: Optional[PrivateKey] = None
) -> tuple[PublicKey, PrivateKey]:
Y: PublicKey = hash_to_curve_deprecated(secret_msg.encode("utf-8"))
r = blinding_factor or PrivateKey()
B_: PublicKey = Y + r.pubkey # type: ignore
return B_, r


def verify_deprecated(a: PrivateKey, C: PublicKey, secret_msg: str) -> bool:
Y: PublicKey = hash_to_curve_deprecated(secret_msg.encode("utf-8"))
valid = C == Y.mult(a) # type: ignore
return valid


def carol_verify_dleq_deprecated(
def carol_verify_dleq_domain_separated(
secret_msg: str,
r: PrivateKey,
C: PublicKey,
e: PrivateKey,
s: PrivateKey,
A: PublicKey,
) -> bool:
Y: PublicKey = hash_to_curve_deprecated(secret_msg.encode("utf-8"))
Y: PublicKey = hash_to_curve_domain_separated(secret_msg.encode("utf-8"))
C_: PublicKey = C + A.mult(r) # type: ignore
B_: PublicKey = Y + r.pubkey # type: ignore
valid = alice_verify_dleq(B_, C_, e, s, A)
Expand Down
2 changes: 1 addition & 1 deletion cashu/core/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ class WalletSettings(CashuSettings):
mint_port: int = Field(default=3338)
wallet_name: str = Field(default="wallet")
wallet_unit: str = Field(default="sat")
wallet_use_deprecated_h2c: bool = Field(default=False)
wallet_domain_separation: bool = Field(default=False)
api_port: int = Field(default=4448)
api_host: str = Field(default="127.0.0.1")

Expand Down
26 changes: 11 additions & 15 deletions cashu/wallet/wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -998,11 +998,9 @@ async def pay_lightning(
# NUT-08, the mint will imprint these outputs with a value depending on the
# amount of fees we overpaid.
n_change_outputs = calculate_number_of_blank_outputs(fee_reserve_sat)
(
change_secrets,
change_rs,
change_derivation_paths,
) = await self.generate_n_secrets(n_change_outputs)
change_secrets, change_rs, change_derivation_paths = (
await self.generate_n_secrets(n_change_outputs)
)
change_outputs, change_rs = self._construct_outputs(
n_change_outputs * [1], change_secrets, change_rs
)
Expand Down Expand Up @@ -1128,15 +1126,14 @@ async def _construct_proofs(
C = b_dhke.step3_alice(
C_, r, self.keysets[promise.id].public_keys[promise.amount]
)

if not settings.wallet_use_deprecated_h2c:
B_, r = b_dhke.step1_alice(secret, r) # recompute B_ for dleq proofs
# BEGIN: BACKWARDS COMPATIBILITY < 0.15.1
if not settings.wallet_domain_separation:
B_, r = b_dhke.step1_alice(secret, r) # recompute B_ for dleq proofs
# END: BACKWARDS COMPATIBILITY < 0.15.1
else:
B_, r = b_dhke.step1_alice_deprecated(
B_, r = b_dhke.step1_alice_domain_separated(
secret, r
) # recompute B_ for dleq proofs
# END: BACKWARDS COMPATIBILITY < 0.15.1

proof = Proof(
id=promise.id,
Expand Down Expand Up @@ -1199,13 +1196,12 @@ def _construct_outputs(
rs_ = [None] * len(amounts) if not rs else rs
rs_return: List[PrivateKey] = []
for secret, amount, r in zip(secrets, amounts, rs_):
if not settings.wallet_use_deprecated_h2c:
B_, r = b_dhke.step1_alice(secret, r or None)
# BEGIN: BACKWARDS COMPATIBILITY < 0.15.1
else:
B_, r = b_dhke.step1_alice_deprecated(secret, r or None)
if not settings.wallet_domain_separation:
B_, r = b_dhke.step1_alice(secret, r or None)
# END: BACKWARDS COMPATIBILITY < 0.15.1

else:
B_, r = b_dhke.step1_alice_domain_separated(secret, r or None)
rs_return.append(r)
output = BlindedMessage(
amount=amount, B_=B_.serialize().hex(), id=self.keyset_id
Expand Down
Loading
Loading