Skip to content

Commit

Permalink
Merge pull request #149 from 0Hughman0/sharing-fix
Browse files Browse the repository at this point in the history
Sharing fix
  • Loading branch information
0Hughman0 authored Aug 22, 2024
2 parents 5e9889c + b1aa188 commit 6a180b0
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 27 deletions.
8 changes: 4 additions & 4 deletions cassini/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

if TYPE_CHECKING:
from .core import TierABC, Project
from .sharing import SharedProject
from .sharing import ShareableProject


class _Env:
Expand Down Expand Up @@ -41,7 +41,7 @@ def __new__(cls, *args: Any, **kwargs: Any) -> _Env:
def __init__(self) -> None:
self.project: Union[Project, None] = None
self._o: Union[TierABC, None] = None
self.shareable_project: Union[SharedProject, None] = None
self.shareable_project: Union[ShareableProject, None] = None
self._caches: List[Dict[Any, Any]] = []

@staticmethod
Expand Down Expand Up @@ -89,12 +89,12 @@ def _reset(self):


class _SharingInstance(_Env):
shareable_project: SharedProject
shareable_project: ShareableProject
project: Project


class _SharedInstance(_Env):
shareable_project: SharedProject
shareable_project: ShareableProject
project: None


Expand Down
44 changes: 29 additions & 15 deletions cassini/sharing.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ class SharingTier:
"""

def __init__(self, name: str):
self.shared_project: Union[None, SharedProject] = None
self.sharing_project: Union[None, ShareableProject] = None

self._accessed: Dict[str, Any] = {}
self._called: Dict[str, Dict[ArgsKwargsType, Any]] = defaultdict(dict)
Expand All @@ -243,26 +243,26 @@ def __init__(self, name: str):
self.meta: Union[Meta, None] = None

@classmethod
def with_project(cls, name: str, shared_project: SharedProject):
def with_project(cls, name: str, sharing_project: ShareableProject):
"""
Create a `SharingTier` object, and load it from `shared_project`.
Recommended way to create `SharingTier` objects in contexts where the `shared_project` is available.
"""
tier = cls(name)
tier.load(shared_project=shared_project)
tier.load(sharing_project=sharing_project)

shared_project.shared_tiers.append(tier)
sharing_project.shared_tiers.append(tier)

return tier

def load(self, shared_project: SharedProject):
def load(self, sharing_project: ShareableProject):
"""
Sync this `SharingTier` to wrap around the tier with name `self.name` from the `shared_project`.
"""
self.shared_project = shared_project
self.sharing_project = sharing_project

self._tier = shared_project.project[self.name]
self._tier = sharing_project.project[self.name]

self.meta = getattr(self._tier, "meta", None)

Expand All @@ -287,7 +287,10 @@ def handle_attr(self, name: str, val: Any) -> Any:
self._paths_used.append(val)

if isinstance(val, TierABC):
val = self._accessed[name] = SharingTier(val.name)
assert self.sharing_project
val = self._accessed[name] = SharingTier.with_project(
val.name, self.sharing_project
)

return val

Expand All @@ -296,7 +299,8 @@ def handle_call(self, method: str, args_kwargs: ArgsKwargsType, val: Any) -> Any
Handle call to a method to allow caching of the result.
"""
if isinstance(val, TierABC):
val = SharingTier(val.name)
assert self.sharing_project
val = SharingTier.with_project(val.name, self.sharing_project)

self._called[method][args_kwargs] = val

Expand All @@ -307,6 +311,11 @@ def handle_call(self, method: str, args_kwargs: ArgsKwargsType, val: Any) -> Any
return val

def __getattr__(self, name: str) -> Any:
if self.sharing_project is None:
raise RuntimeError(
"SharingTier attributes can be accessed until `SharingTier.load` is called"
)

val = getattr(self._tier, name)

if isinstance(val, MethodType):
Expand Down Expand Up @@ -410,22 +419,22 @@ class SharedTier:

def __init__(self, name: str) -> None:
self.name = name
self.shared_project: Union[None, SharedProject] = None
self.shared_project: Union[None, ShareableProject] = None
self.base_path: Union[Path, None] = None
self.meta: Union[Meta, None] = None
self._accessed: Dict[str, Any] = {}
self._called: Dict[str, Dict[ArgsKwargsType, Any]] = {}

@classmethod
def with_project(cls, name: str, shared_project: SharedProject):
def with_project(cls, name: str, shared_project: ShareableProject):
"""
Create a `SharedTier` instance, and load it from `shared_project`.
"""
tier = cls(name)
tier.load(shared_project)
return tier

def load(self, shared_project: SharedProject):
def load(self, shared_project: ShareableProject):
"""
Load the contents of the shared tier into this object from the `shared_project`.
"""
Expand Down Expand Up @@ -470,6 +479,11 @@ def adjust_path(self, path: Path) -> Path:
return self.shared_project.requires_path / path.relative_to(self.base_path)

def __getattr__(self, name: str) -> Any:
if self.shared_project is None:
raise RuntimeError(
"SharedTier attributes can be accessed until `SharedTier.load` is called"
)

if name in self._accessed:
val = self._accessed[name]

Expand Down Expand Up @@ -507,7 +521,7 @@ def __hash__(self) -> int:
return hash(self.name)


class SharedProject:
class ShareableProject:
"""
Shareable version of `Project`. Allows sharing of notebooks that use Cassini with users who don't have
Cassini set up.
Expand Down Expand Up @@ -559,7 +573,7 @@ def env(self, name: str) -> Union[SharedTier, SharingTier]:
Name of the tier to get.
"""
if self.project:
tier = SharingTier.with_project(name=name, shared_project=self)
tier = SharingTier.with_project(name=name, sharing_project=self)
env.update(tier)
return tier
else:
Expand All @@ -576,7 +590,7 @@ def __getitem__(self, name: str) -> Union[SharedTier, SharingTier]:
Name of the tier to get.
"""
if self.project:
return SharingTier.with_project(name=name, shared_project=self)
return SharingTier.with_project(name=name, sharing_project=self)
else:
return SharedTier.with_project(name=name, shared_project=self)

Expand Down
36 changes: 28 additions & 8 deletions tests/test_sharing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import datetime
from typing import List, Any

from cassini.sharing import SharedProject, NoseyPath, SharedTierData, SharedTier, SharedTierCalls, GetChildCall, GetItemCall, TrueDivCall, SharedTierCall
from cassini.sharing import ShareableProject, NoseyPath, SharedTierData, SharedTier, SharedTierCalls, GetChildCall, GetItemCall, TrueDivCall, SharedTierCall
from cassini.testing_utils import get_Project, patch_project
from cassini.magics import hlt
from cassini import DEFAULT_TIERS, env
Expand Down Expand Up @@ -35,7 +35,7 @@ def mk_shared_project(tmp_path):
env.project = None
env.shareable_project = None

shared_project = SharedProject(location=shared)
shared_project = ShareableProject(location=shared)
stier = shared_project['WP1.1']
return stier, shared_project

Expand Down Expand Up @@ -99,7 +99,7 @@ def test_nosey_path_compressing():
def test_attribute_caching(get_Project, tmp_path):
Project = get_Project
project = Project(DEFAULT_TIERS, tmp_path)
shared_project = SharedProject()
shared_project = ShareableProject()
project.setup_files()

tier = project['WP1']
Expand Down Expand Up @@ -139,7 +139,7 @@ def test_attribute_caching(get_Project, tmp_path):
def test_stier_path_finding(get_Project, tmp_path):
Project = get_Project
project = Project(DEFAULT_TIERS, tmp_path)
shared_project = SharedProject()
shared_project = ShareableProject()
project.setup_files()

tier = project['WP1.1']
Expand Down Expand Up @@ -234,7 +234,7 @@ def test_meta_wrapping(get_Project, tmp_path):
tier.parent.setup_files()
tier.setup_files()

shared_project = SharedProject()
shared_project = ShareableProject()
stier = shared_project.env('WP1.1')

assert stier.meta is tier.meta
Expand All @@ -258,7 +258,7 @@ def test_meta_wrapping(get_Project, tmp_path):
def test_making_share(get_Project, tmp_path):
Project = get_Project
project = Project(DEFAULT_TIERS, tmp_path)
shared_project = SharedProject(location=tmp_path / 'shared')
shared_project = ShareableProject(location=tmp_path / 'shared')
project.setup_files()

tier = project['WP1.1']
Expand All @@ -275,7 +275,7 @@ def test_making_share(get_Project, tmp_path):
shared_project.make_shared()

env.shareable_project = None
shared_project = SharedProject(location=tmp_path / 'shared')
shared_project = ShareableProject(location=tmp_path / 'shared')
shared_project.project = None

shared_tier = shared_project.env('WP1.1')
Expand All @@ -295,7 +295,7 @@ def test_making_share(get_Project, tmp_path):
def test_no_meta(get_Project, tmp_path):
Project = get_Project
project = Project(DEFAULT_TIERS, tmp_path)
shared_project = SharedProject(location=tmp_path / 'shared')
shared_project = ShareableProject(location=tmp_path / 'shared')
project.setup_files()

project['WP1'].setup_files()
Expand Down Expand Up @@ -332,3 +332,23 @@ def test_no_magics(mk_shared_project):
out = hlt('hlt', 'print("cell")')

assert out == 'print("cell")'


def test_getting_tier_children(get_Project, tmp_path):
Project = get_Project
project = Project(DEFAULT_TIERS, tmp_path)
shared_project = ShareableProject(location=tmp_path / 'shared')
project.setup_files()
project['WP1'].setup_files()
project['WP1.1'].setup_files()

(project['WP1.1'] / 'file').write_text('data')

stier = shared_project['WP1']
stier_child = stier['1']

assert stier_child.exists()
assert (stier_child / 'file').read_text() == 'data'



0 comments on commit 6a180b0

Please sign in to comment.