Skip to content

Commit

Permalink
fix: isolation fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey committed Sep 13, 2024
1 parent fe0415a commit 06600c2
Show file tree
Hide file tree
Showing 6 changed files with 255 additions and 151 deletions.
166 changes: 68 additions & 98 deletions src/ape/pytest/fixtures.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import inspect
from collections.abc import Iterable, Iterator
from dataclasses import dataclass, field
from fnmatch import fnmatch
Expand Down Expand Up @@ -98,50 +97,64 @@ def _function_isolation(self) -> Iterator[None]:

@dataclass
class Snapshot:
"""
All the data necessary for accurately supporting isolation.
"""

scope: "Scope" # Assuming 'Scope' is defined elsewhere
identifier: Optional[str] = None
"""Corresponds to fixture scope."""

identifier: Optional[SnapshotID] = None
"""Snapshot ID taken before the peer-fixtures in the same scope."""

fixtures: list = field(default_factory=list)
"""All peer fixtures, tracked so we know when new ones are added."""

def append_fixtures(self, fixtures: Iterable[str]):
for fixture in fixtures:
if fixture in self.fixtures:
continue

def _get_lower_scopes(scope: Scope) -> tuple[Scope, ...]:
if scope is Scope.SESSION:
return (Scope.FUNCTION, Scope.CLASS, Scope.MODULE, Scope.PACKAGE)
elif scope is Scope.PACKAGE:
return (Scope.FUNCTION, Scope.CLASS, Scope.MODULE)
elif scope is Scope.MODULE:
return (Scope.FUNCTION, Scope.CLASS)
elif scope is Scope.CLASS:
return (Scope.FUNCTION,)
self.fixtures.append(fixture)

return ()

class SnapshotRegistry(dict[Scope, Snapshot]):
def __init__(self):
super().__init__(
{
Scope.SESSION: Snapshot(Scope.SESSION),
Scope.PACKAGE: Snapshot(Scope.PACKAGE),
Scope.MODULE: Snapshot(Scope.MODULE),
Scope.CLASS: Snapshot(Scope.CLASS),
Scope.FUNCTION: Snapshot(Scope.FUNCTION),
}
)

class IsolationManager(ManagerAccessMixin):
INVALID_KEY = "__invalid_snapshot__"
def get_snapshot_id(self, scope: Scope) -> Optional[SnapshotID]:
return self[scope].identifier

def set_snapshot_id(self, scope: Scope, snapshot_id: SnapshotID):
self[scope].identifier = snapshot_id

_supported: bool = True
_snapshot_registry: dict[Scope, Snapshot] = {
Scope.SESSION: Snapshot(Scope.SESSION),
Scope.PACKAGE: Snapshot(Scope.PACKAGE),
Scope.MODULE: Snapshot(Scope.MODULE),
Scope.CLASS: Snapshot(Scope.CLASS),
Scope.FUNCTION: Snapshot(Scope.FUNCTION),
}
def clear_snapshot_id(self, scope: Scope):
self[scope].identifier = None

def next_snapshots(self, scope: Scope) -> Iterator[Snapshot]:
for lower_scope in scope.lower_scopes:
yield self[lower_scope]

def extend_fixtures(self, scope: Scope, fixtures: Iterable[str]):
self[scope].fixtures.extend(fixtures)


class IsolationManager(ManagerAccessMixin):
supported: bool = True
snapshots: SnapshotRegistry = SnapshotRegistry()

def __init__(self, config_wrapper: ConfigWrapper, receipt_capture: "ReceiptCapture"):
self.config_wrapper = config_wrapper
self.receipt_capture = receipt_capture

@cached_property
def builtin_ape_fixtures(self) -> tuple[str, ...]:
return tuple(
[
n
for n, itm in inspect.getmembers(PytestApeFixtures)
if callable(itm) and not n.startswith("_")
]
)

@cached_property
def _track_transactions(self) -> bool:
return (
Expand All @@ -150,33 +163,18 @@ def _track_transactions(self) -> bool:
and (self.config_wrapper.track_gas or self.config_wrapper.track_coverage)
)

def update_fixtures(self, scope: Scope, fixtures: Iterable[str]):
snapshot = self._snapshot_registry[scope]
if not (
new_fixtures := [
p
for p in fixtures
if p not in snapshot.fixtures and p not in self.builtin_ape_fixtures
]
):
return

# If the snapshot is already set, we have to invalidate it.
# We need to replace the snapshot with one that happens after
# the new fixtures.
# if snapshot is not None:
# breakpoint()
# self._snapshot_registry[scope].identifier = self.INVALID_KEY
def get_snapshot(self, scope: Scope) -> Snapshot:
return self.snapshots[scope]

# Add or update peer-fixtures.
self._snapshot_registry[scope].fixtures.extend(new_fixtures)
def extend_fixtures(self, scope: Scope, fixtures: Iterable[str]):
self.snapshots.extend_fixtures(scope, fixtures)

def isolation(self, scope: Scope) -> Iterator[None]:
"""
Isolation logic used to implement isolation fixtures for each pytest scope.
When tracing support is available, will also assist in capturing receipts.
"""
self._set_snapshot(scope)
self.set_snapshot(scope)
if self._track_transactions:
did_yield = False
try:
Expand All @@ -193,47 +191,26 @@ def isolation(self, scope: Scope) -> Iterator[None]:

# NOTE: self._supported may have gotten set to False
# someplace else _after_ snapshotting succeeded.
if not self._supported:
if not self.supported:
return

self._restore(scope)
self.restore(scope)

def _set_snapshot(self, scope: Scope):
def set_snapshot(self, scope: Scope):
# Also can be used to re-set snapshot.
if not self._supported:
if not self.supported:
return

# Here is something tricky: If a snapshot exists
# already at a lower-level, we must use that one.
# Like if a session comes in _after_ a module, have
# the session just use the module.
# Else, it falls apart.
snapshot_id = None
if scope is not Scope.FUNCTION:
lower_scopes = _get_lower_scopes(scope)
for lower_scope in lower_scopes:
snapshot = self._snapshot_registry[lower_scope]
if snapshot.identifier is not None:
snapshot_id = snapshot.identifier
break

try:
snapshot_id = self.take_snapshot()
except Exception:
self.supported = False
else:
if snapshot_id is not None:
# Clear out others
for lower_scope in lower_scopes:
snapshot = self._snapshot_registry[lower_scope]
snapshot.identifier = None

if snapshot_id is None:
try:
snapshot_id = self._take_snapshot()
except Exception:
self._supported = False

if snapshot_id is not None:
self._snapshot_registry[scope].identifier = snapshot_id
self.snapshots.set_snapshot_id(scope, snapshot_id)

@allow_disconnected
def _take_snapshot(self) -> Optional[SnapshotID]:
def take_snapshot(self) -> Optional[SnapshotID]:
try:
return self.chain_manager.snapshot()
except NotImplementedError:
Expand All @@ -242,19 +219,19 @@ def _take_snapshot(self) -> Optional[SnapshotID]:
"Tests will not be completely isolated."
)
# To avoid trying again
self._supported = False
self.supported = False

return None

@allow_disconnected
def _restore(self, scope: Scope):
snapshot_id = self._snapshot_registry[scope].identifier
def restore(self, scope: Scope):
snapshot_id = self.snapshots.get_snapshot_id(scope)
if snapshot_id is None:
return

elif snapshot_id not in self.chain_manager._snapshots or snapshot_id == self.INVALID_KEY:
elif snapshot_id not in self.chain_manager._snapshots:
# Still clear out.
self._snapshot_registry[scope].identifier = None
self.snapshots.clear_snapshot_id(scope)
return

try:
Expand All @@ -265,16 +242,9 @@ def _restore(self, scope: Scope):
"Tests will not be completely isolated."
)
# To avoid trying again
self._supported = False

self._snapshot_registry[scope].identifier = None
self.supported = False

# If we are reverting to a session-state, there is no
# reason to revert back to a function state (if one exists).
# and so forth.
lower_scopes = _get_lower_scopes(scope)
for lower_scope in lower_scopes:
self._snapshot_registry[lower_scope].identifier = None
self.snapshots.clear_snapshot_id(scope)


class ReceiptCapture(ManagerAccessMixin):
Expand Down
Loading

0 comments on commit 06600c2

Please sign in to comment.