diff --git a/src/commands/import_genesis_keys.py b/src/commands/import_genesis_keys.py new file mode 100644 index 00000000..cb370983 --- /dev/null +++ b/src/commands/import_genesis_keys.py @@ -0,0 +1,126 @@ +import glob +import os +from pathlib import Path +from typing import Dict + +import click +from Cryptodome.Cipher import AES, PKCS1_OAEP +from Cryptodome.PublicKey import RSA +from eth_typing import BLSPrivateKey, HexAddress, HexStr +from py_ecc.bls import G2ProofOfPossession +from web3 import Web3 + +from src.common.credentials import Credential +from src.common.password import get_or_create_password_file +from src.common.utils import greenify +from src.common.validators import validate_eth_address +from src.common.vault_config import VaultConfig +from src.config.settings import settings + + +@click.option( + '--data-dir', + default=str(Path.home() / '.stakewise'), + envvar='DATA_DIR', + help='Path where the vault data will be placed. Default is ~/.stakewise.', + type=click.Path(exists=True, file_okay=False, dir_okay=True), +) +@click.option( + '--rsa-key', + help='The RSA private key to decrypt keystores.', + type=click.Path(exists=True, file_okay=True, dir_okay=False), +) +@click.option( + '--exported-keys-dir', + help='Path where the encrypted keys are placed.', + type=click.Path(exists=True, file_okay=False, dir_okay=True), +) +@click.option( + '--vault', + help='The address of the vault.', + prompt='Enter the vault address', + type=str, + callback=validate_eth_address, +) +@click.command(help='Import encrypted keystores. Only for genesis vault') +# pylint: disable-next=too-many-arguments +def import_genesis_keys( + rsa_key: str, + exported_keys_dir: str, + vault: HexAddress, + data_dir: str, +) -> None: + vault_config = VaultConfig(vault, Path(data_dir)) + vault_config.load() + network = vault_config.network + + settings.set( + vault=vault, + network=network, + vault_dir=vault_config.vault_dir, + ) + if settings.network_config.GENESIS_VAULT_CONTRACT_ADDRESS != vault: + raise click.ClickException('Only genesis vault support keys import.') + + keystores_dir = vault_config.vault_dir / 'keystores' + password_file = keystores_dir / 'password.txt' + password = get_or_create_password_file(password_file) + + click.secho('Decrypting keystores...', bold=True) + + transferred_keypairs = _decrypt_transferred_keys( + keys_dir=exported_keys_dir, decrypt_key=rsa_key + ) + + click.secho(f'Saving keystores to {greenify(keystores_dir)}...', bold=True) + + index = 0 + for private_key in transferred_keypairs.values(): + credential = Credential( + private_key=BLSPrivateKey(private_key), + vault=vault, + network=network, + path=f'imported_{index}', + ) + credential.save_signing_keystore(password=password, folder=str(keystores_dir)) + index += 1 + + click.echo( + f'Done. Imported {greenify(len(transferred_keypairs))} keys for {greenify(vault)} vault.\n' + f'Keystores saved to {greenify(keystores_dir)} file\n' + ) + + +# pylint: disable-next=too-many-locals +def _decrypt_transferred_keys(keys_dir: str, decrypt_key: str) -> Dict[HexStr, int]: + keypairs: Dict[HexStr, int] = {} + + with open(decrypt_key, 'r', encoding='utf-8') as f: + rsa_key = RSA.import_key(f.read()) + for filename in glob.glob(os.path.join(keys_dir, '*.enc')): + with open(os.path.join(os.getcwd(), filename), 'rb') as f: + try: + enc_session_key, nonce, tag, ciphertext = [ + f.read(x) for x in (rsa_key.size_in_bytes(), 16, 16, -1) + ] + except Exception as e: + raise click.ClickException(f'Invalid encrypted private key file: {filename}') from e + + try: + cipher_rsa = PKCS1_OAEP.new(rsa_key) + session_key = cipher_rsa.decrypt(enc_session_key) + except Exception as e: + raise click.ClickException('Failed to decrypt the private key.') from e + + # Decrypt the data with the AES session key + cipher_aes = AES.new(session_key, AES.MODE_EAX, nonce) + try: + private_key = int(cipher_aes.decrypt_and_verify(ciphertext, tag)) + public_key = Web3.to_hex(G2ProofOfPossession.SkToPk(private_key)) + keypairs[public_key] = private_key + except Exception as e: + raise click.ClickException( + 'Failed to decrypt the private key file. Is it corrupted?' + ) from e + + return keypairs diff --git a/src/main.py b/src/main.py index c9e2ea68..74f5fc32 100644 --- a/src/main.py +++ b/src/main.py @@ -10,6 +10,7 @@ from src.commands.create_keys import create_keys from src.commands.create_wallet import create_wallet from src.commands.get_validators_root import get_validators_root +from src.commands.import_genesis_keys import import_genesis_keys from src.commands.init import init from src.commands.merge_deposit_data import merge_deposit_data from src.commands.recover import recover @@ -40,6 +41,7 @@ def cli() -> None: cli.add_command(start) cli.add_command(recover) cli.add_command(get_validators_root) +cli.add_command(import_genesis_keys) cli.add_command(remote_db_group) if __name__ == '__main__':