Skip to content

Commit

Permalink
feat: add deployment logs to network mode (#312)
Browse files Browse the repository at this point in the history
* feat(draft): add deployment logs to network mode

if the deployments db is initialized, insert all deployed contracts into the deployments db.
include information like tx data, ts, source bundle (for verification), session id
  • Loading branch information
charles-cooper authored Oct 1, 2024
1 parent 48c03ab commit f00e12b
Show file tree
Hide file tree
Showing 9 changed files with 271 additions and 21 deletions.
4 changes: 2 additions & 2 deletions boa/contracts/vyper/vyper_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,9 @@ def at(self, address: Any) -> "VyperContract":
return ret

@cached_property
def standard_json(self):
def solc_json(self):
"""
Generates a standard JSON representation of the Vyper contract.
Generates a solc "standard json" representation of the Vyper contract.
"""
return build_solc_json(self.compiler_data)

Expand Down
150 changes: 150 additions & 0 deletions boa/deployments.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import json
import sqlite3
import uuid
from dataclasses import asdict, dataclass, field, fields
from pathlib import Path
from typing import Any, Optional

from boa.util.abi import Address
from boa.util.open_ctx import Open

"""
Module to handle deployment objects. When a contract is deployed in network
mode, we enter it into the deployments database so that it can be
queried/verified later.
This module could potentially be handled as plugin functionality / or left
as functionality for higher-level frameworks.
"""

_session_id: str = None # type: ignore


# generate a unique session id, so that deployments can be queried by session
def get_session_id():
global _session_id
if _session_id is None:
_session_id = str(uuid.uuid4())
return _session_id


@dataclass(frozen=True)
class Deployment:
contract_address: Address # receipt_dict["createAddress"]
contract_name: str
rpc: str
deployer: Address # ostensibly equal to tx_dict["from"]
tx_hash: str
broadcast_ts: float # time the tx was broadcast
tx_dict: dict # raw tx fields
receipt_dict: dict # raw receipt fields
source_code: Optional[Any] # optional source code or bundle
session_id: str = field(default_factory=get_session_id)
deployment_id: Optional[int] = None # the db-assigned id - primary key

def sql_values(self):
ret = asdict(self)
# sqlite doesn't have json, just dump to string
ret["tx_dict"] = json.dumps(ret["tx_dict"])
ret["receipt_dict"] = json.dumps(ret["receipt_dict"])
if ret["source_code"] is not None:
ret["source_code"] = json.dumps(ret["source_code"])
return ret

def to_dict(self):
"""
Convert Deployment object to a dict, which is prepared to be
dumped to json.
"""
return asdict(self)

def to_json(self, *args, **kwargs):
"""
Convert a Deployment object to a json object. *args and **kwargs
are forwarded to the `json.dumps()` call.
"""
return json.dumps(self.to_dict(), *args, **kwargs)

@classmethod
def from_sql_tuple(cls, values):
assert len(values) == len(fields(cls))
ret = dict(zip([field.name for field in fields(cls)], values))
ret["contract_address"] = Address(ret["contract_address"])
ret["deployer"] = Address(ret["deployer"])
ret["tx_dict"] = json.loads(ret["tx_dict"])
ret["receipt_dict"] = json.loads(ret["receipt_dict"])
if ret["source_code"] is not None:
ret["source_code"] = json.loads(ret["source_code"])
return cls(**ret)


_CREATE_CMD = """
CREATE TABLE IF NOT EXISTS
deployments(
deployment_id integer primary key autoincrement,
session_id text,
contract_address text,
contract_name text,
rpc text,
deployer text,
tx_hash text,
broadcast_ts real,
tx_dict text,
receipt_dict text,
source_code text
);
"""


class DeploymentsDB:
def __init__(self, path=":memory:"):
if path != ":memory:": # sqlite magic path
path = Path(path)
path.parent.mkdir(parents=True, exist_ok=True)

# once 3.12 is min version, use autocommit=True
self.db = sqlite3.connect(path)

self.db.execute(_CREATE_CMD)

def __del__(self):
self.db.close()

def insert_deployment(self, deployment: Deployment):
values = deployment.sql_values()

values_placeholder = ",".join(["?"] * len(values))
colnames = ",".join(values.keys())

insert_cmd = f"INSERT INTO deployments({colnames}) VALUES({values_placeholder})"

self.db.execute(insert_cmd, tuple(values.values()))
self.db.commit()

def _get_deployments_from_sql(self, sql_query: str, parameters=(), /):
cur = self.db.execute(sql_query, parameters)
ret = [Deployment.from_sql_tuple(item) for item in cur.fetchall()]
return ret

def _get_fieldnames_str(self) -> str:
return ",".join(field.name for field in fields(Deployment))

def get_deployments(self) -> list[Deployment]:
fieldnames = self._get_fieldnames_str()
return self._get_deployments_from_sql(f"SELECT {fieldnames} FROM deployments")


_db: Optional[DeploymentsDB] = None


def set_deployments_db(db: Optional[DeploymentsDB]):
def set_(db):
global _db
_db = db

return Open(get_deployments_db, set_, db)


def get_deployments_db():
global _db
return _db
39 changes: 35 additions & 4 deletions boa/network.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# an Environment which interacts with a real (prod or test) chain
import contextlib
import time
import warnings
from dataclasses import dataclass
from functools import cached_property
Expand All @@ -8,6 +9,7 @@
from eth_account import Account
from requests.exceptions import HTTPError

from boa.deployments import Deployment, get_deployments_db
from boa.environment import Env, _AddressType
from boa.rpc import (
RPC,
Expand All @@ -20,6 +22,7 @@
trim_dict,
)
from boa.util.abi import Address
from boa.verifiers import get_verification_bundle


class TraceObject:
Expand Down Expand Up @@ -300,7 +303,7 @@ def execute_code(

if is_modifying:
try:
receipt, trace = self._send_txn(
txdata, receipt, trace = self._send_txn(
from_=sender, to=to_address, value=value, gas=gas, data=hexdata
)
except _EstimateGasFailed:
Expand Down Expand Up @@ -375,7 +378,9 @@ def deploy(
bytecode = to_hex(bytecode)
sender = self._check_sender(self._get_sender(sender))

receipt, trace = self._send_txn(
broadcast_ts = time.time()

txdata, receipt, trace = self._send_txn(
from_=sender, value=value, gas=gas, data=bytecode
)

Expand All @@ -394,9 +399,35 @@ def deploy(
if local_address != create_address:
raise RuntimeError(f"uh oh! {local_address} != {create_address}")

# TODO get contract info in here
print(f"contract deployed at {create_address}")

if (deployments_db := get_deployments_db()) is not None:
contract_name = getattr(contract, "contract_name", None)
try:
source_bundle = get_verification_bundle(contract)
except Exception as e:
# there was a problem constructing the verification bundle.
# assume the user cares more about continuing, than getting
# the bundle into the db
msg = "While saving deployment data, couldn't construct"
msg += f" verification bundle for {contract_name}! Full stack"
msg += f" trace:\n```\n{e}\n```\nContinuing.\n"
warnings.warn(msg, stacklevel=2)
source_bundle = None

deployment_data = Deployment(
create_address,
contract_name,
self._rpc.name,
sender,
receipt["transactionHash"],
broadcast_ts,
txdata,
receipt,
source_bundle,
)
deployments_db.insert_deployment(deployment_data)

return create_address, computation

@cached_property
Expand Down Expand Up @@ -538,7 +569,7 @@ def _send_txn(self, from_, to=None, gas=None, value=None, data=None):
self._reset_fork(block_identifier=receipt["blockNumber"])

t_obj = TraceObject(trace) if trace is not None else None
return receipt, t_obj
return tx_data, receipt, t_obj

def get_chain_id(self) -> int:
"""Get the current chain ID of the network as an integer."""
Expand Down
4 changes: 2 additions & 2 deletions boa/util/open_ctx.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ def __init__(self, get, set_, item):
self.anchor = get()
self._set = set_
self._set(item)
self._item = item

def __enter__(self):
# dummy implementation, no-op
pass
return self._item

def __exit__(self, *args):
self._set(self.anchor)
31 changes: 20 additions & 11 deletions boa/verifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,15 @@ def verify(
self,
address: Address,
contract_name: str,
standard_json: dict,
solc_json: dict,
license_type: str = None,
wait: bool = False,
) -> Optional["VerificationResult"]:
"""
Verify the Vyper contract on Blockscout.
:param address: The address of the contract.
:param contract_name: The name of the contract.
:param standard_json: The standard JSON output of the Vyper compiler.
:param solc_json: The solc_json output of the Vyper compiler.
:param license_type: The license to use for the contract. Defaults to "none".
:param wait: Whether to return a VerificationResult immediately
or wait for verification to complete. Defaults to False
Expand All @@ -57,13 +57,13 @@ def verify(
url = f"{self.uri}/api/v2/smart-contracts/{address}/"
url += f"verification/via/vyper-standard-input?apikey={api_key}"
data = {
"compiler_version": standard_json["compiler_version"],
"compiler_version": solc_json["compiler_version"],
"license_type": license_type,
}
files = {
"files[0]": (
contract_name,
json.dumps(standard_json).encode("utf-8"),
json.dumps(solc_json).encode("utf-8"),
"application/json",
)
}
Expand Down Expand Up @@ -137,7 +137,18 @@ def set_verifier(verifier):
return Open(get_verifier, _set_verifier, verifier)


def verify(contract, verifier=None, license_type: str = None) -> VerificationResult:
def get_verification_bundle(contract_like):
if not hasattr(contract_like, "deployer"):
return None
if not hasattr(contract_like.deployer, "solc_json"):
return None
return contract_like.deployer.solc_json


# should we also add a `verify_deployment` function?
def verify(
contract, verifier=None, license_type: str = None, wait=False
) -> VerificationResult:
"""
Verifies the contract on a block explorer.
:param contract: The contract to verify.
Expand All @@ -148,15 +159,13 @@ def verify(contract, verifier=None, license_type: str = None) -> VerificationRes
if verifier is None:
verifier = get_verifier()

if not hasattr(contract, "deployer") or not hasattr(
contract.deployer, "standard_json"
):
if (bundle := get_verification_bundle(contract)) is None:
raise ValueError(f"Not a contract! {contract}")

address = contract.address
return verifier.verify(
address=address,
standard_json=contract.deployer.standard_json,
address=contract.address,
solc_json=bundle,
contract_name=contract.contract_name,
license_type=license_type,
wait=wait,
)
3 changes: 2 additions & 1 deletion tests/integration/network/anvil/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from eth_account import Account

import boa
from boa.deployments import DeploymentsDB, set_deployments_db
from boa.network import NetworkEnv

ANVIL_FORK_PKEYS = [
Expand Down Expand Up @@ -76,7 +77,7 @@ def anvil_env(free_port):
# max coverage across VM implementations?
@pytest.fixture(scope="module", autouse=True)
def networked_env(accounts, anvil_env):
with boa.swap_env(anvil_env):
with boa.swap_env(anvil_env), set_deployments_db(DeploymentsDB(":memory:")):
for account in accounts:
boa.env.add_account(account)
yield
28 changes: 28 additions & 0 deletions tests/integration/network/anvil/test_network_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@

import boa
import boa.test.strategies as vy
from boa.deployments import DeploymentsDB, set_deployments_db
from boa.network import NetworkEnv
from boa.rpc import to_bytes
from boa.util.abi import Address

code = """
totalSupply: public(uint256)
Expand Down Expand Up @@ -68,3 +71,28 @@ def test_failed_transaction():


# XXX: probably want to test deployment revert behavior


def test_deployment_db():
with set_deployments_db(DeploymentsDB(":memory:")) as db:
arg = 5

# contract is written to deployments db
contract = boa.loads(code, arg)

# test get_deployments()
deployment = db.get_deployments()[-1]

initcode = contract.compiler_data.bytecode + arg.to_bytes(32, "big")

# sanity check all the fields
assert deployment.contract_address == contract.address
assert deployment.contract_name == contract.contract_name
assert deployment.deployer == boa.env.eoa
assert deployment.rpc == boa.env._rpc.name
assert deployment.source_code == contract.deployer.solc_json

# some sanity checks on tx_dict and rx_dict fields
assert to_bytes(deployment.tx_dict["data"]) == initcode
assert deployment.tx_dict["chainId"] == hex(boa.env.get_chain_id())
assert Address(deployment.receipt_dict["contractAddress"]) == contract.address
Loading

0 comments on commit f00e12b

Please sign in to comment.