Skip to content

Commit

Permalink
add types whitelist (#94)
Browse files Browse the repository at this point in the history
  • Loading branch information
aguschin authored Mar 24, 2022
1 parent cd47e83 commit 5619e99
Show file tree
Hide file tree
Showing 5 changed files with 110 additions and 48 deletions.
69 changes: 38 additions & 31 deletions gto/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,14 @@
from pathlib import Path
from typing import Any, Dict, List, Optional

from pydantic import BaseSettings, validator
from pydantic import BaseSettings
from pydantic.env_settings import InitSettingsSource
from ruamel.yaml import YAML

from gto.constants import BRANCH, COMMIT, TAG
from gto.exceptions import UnknownEnvironment, UnknownType
from gto.versions import AbstractVersion

from .constants import BRANCH, COMMIT, TAG
from .exceptions import UnknownEnvironment

yaml = YAML(typ="safe", pure=True)
yaml.default_flow_style = False

Expand Down Expand Up @@ -44,16 +43,24 @@ def config_settings_source(settings: "RegistryConfig") -> Dict[str, Any]:

class RegistryConfig(BaseSettings):
INDEX: str = "artifacts.yaml"
TYPE_ALLOWED: List[str] = []
VERSION_BASE: str = TAG
VERSION_CONVENTION: str = "NumberedVersion"
VERSION_REQUIRED_FOR_ENV: bool = True
ENV_BASE: str = TAG
ENV_WHITELIST: List[str] = []
ENV_ALLOWED: List[str] = []
ENV_BRANCH_MAPPING: Dict[str, str] = {}
LOG_LEVEL: str = "INFO"
DEBUG: bool = False
CONFIG_FILE: Optional[str] = CONFIG_FILE

def assert_type(self, name):
if not self.check_type(name):
raise UnknownType(name, self.TYPE_ALLOWED)

def check_type(self, name):
return name in self.TYPE_ALLOWED or not self.TYPE_ALLOWED

@property
def VERSION_SYSTEM_MAPPING(self):
from .versions import NumberedVersion, SemVer
Expand All @@ -76,15 +83,15 @@ def ENV_MANAGERS_MAPPING(self):

def assert_env(self, name):
if not self.check_env(name):
raise UnknownEnvironment(name)
raise UnknownEnvironment(name, self.envs)

def check_env(self, name):
return name in self.envs or not self.envs

@property
def envs(self) -> List[str]:
if self.ENV_BASE == TAG:
return self.ENV_WHITELIST
return self.ENV_ALLOWED
if self.ENV_BASE == BRANCH:
return list(self.ENV_BRANCH_MAPPING)
raise NotImplementedError("Unknown ENV_BASE")
Expand Down Expand Up @@ -132,30 +139,30 @@ def customise_sources(
# raise ValueError(f"ENV_BASE must be one of: {cls.ENV_MANAGERS_MAPPING.keys()}")
# return value

@validator("ENV_WHITELIST", always=True)
def validate_env_whitelist(cls, value, values):
if values["ENV_BASE"] == BRANCH:
# logging.warning("ENV_WHITELIST is ignored when ENV_BASE is BRANCH")
pass
return value

@validator("ENV_BRANCH_MAPPING", always=True)
def validate_env_branch_mapping(
cls, value: Dict[str, str], values
) -> Dict[str, str]:
if values["ENV_BASE"] != BRANCH:
# logging.warning("ENV_BRANCH_MAPPING is ignored when ENV_BASE is not BRANCH")
return value
if not isinstance(value, dict):
raise ValueError(
f"ENV_BRANCH_MAPPING must be a dict, got {type(value)}",
"ENV_BRANCH_MAPPING",
)
if not all(isinstance(k, str) and isinstance(v, str) for k, v in value.items()):
raise ValueError(
"ENV_BRANCH_MAPPING must be a dict of str:str", "ENV_BRANCH_MAPPING"
)
return value
# @validator("ENV_WHITELIST", always=True)
# def validate_env_whitelist(cls, value, values):
# if values["ENV_BASE"] == BRANCH:
# # logging.warning("ENV_WHITELIST is ignored when ENV_BASE is BRANCH")
# pass
# return value

# @validator("ENV_BRANCH_MAPPING", always=True)
# def validate_env_branch_mapping(
# cls, value: Dict[str, str], values
# ) -> Dict[str, str]:
# if values["ENV_BASE"] != BRANCH:
# # logging.warning("ENV_BRANCH_MAPPING is ignored when ENV_BASE is not BRANCH")
# return value
# if not isinstance(value, dict):
# raise ValueError(
# f"ENV_BRANCH_MAPPING must be a dict, got {type(value)}",
# "ENV_BRANCH_MAPPING",
# )
# if not all(isinstance(k, str) and isinstance(v, str) for k, v in value.items()):
# raise ValueError(
# "ENV_BRANCH_MAPPING must be a dict of str:str", "ENV_BRANCH_MAPPING"
# )
# return value

@property
def versions_class(self) -> AbstractVersion:
Expand Down
17 changes: 12 additions & 5 deletions gto/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,16 @@ def __init__(self, path) -> None:
super().__init__(self.message)


class UnknownType(GTOException):
_message = (
"Type '{type}' is not present in your config file. Allowed values are: {types}."
)

def __init__(self, type, types) -> None:
self.message = self._message.format(type=type, types=types)
super().__init__(self.message)


class ArtifactExists(GTOException):
_message = "Artifact '{name}' is already exists in Index"

Expand Down Expand Up @@ -97,11 +107,8 @@ def __init__(self, latest, suggested) -> None:
class UnknownEnvironment(GTOException):
_message = "Environment '{env}' is not present in your config file. Allowed envs are: {envs}."

def __init__(self, env) -> None:
# to avoid circular import
from .config import CONFIG # pylint: disable=import-outside-toplevel

self.message = self._message.format(env=env, envs=CONFIG.ENV_WHITELIST)
def __init__(self, env, envs) -> None:
self.message = self._message.format(env=env, envs=envs)
super().__init__(self.message)


Expand Down
30 changes: 21 additions & 9 deletions gto/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import git
from pydantic import BaseModel, parse_obj_as

from .config import CONFIG, yaml
from .config import CONFIG_FILE, RegistryConfig, yaml
from .exceptions import ArtifactExists, ArtifactNotFound, NoFile, NoRepo, PathIsUsed


Expand Down Expand Up @@ -107,6 +107,7 @@ def remove(self, name):

class BaseIndexManager(BaseModel, ABC):
current: Optional[Index]
config: RegistryConfig

@abstractmethod
def get_index(self) -> Index:
Expand All @@ -121,11 +122,12 @@ def get_history(self) -> Dict[str, Index]:
raise NotImplementedError

def add(self, type, name, path, virtual=False):
index = self.get_index()
self.config.assert_type(type)
if not virtual and not check_if_path_exists(
path, self.repo if hasattr(self, "repo") else None
):
raise NoFile(path)
index = self.get_index()
index.add(type, name, path, virtual)
self.update()

Expand All @@ -138,8 +140,14 @@ def remove(self, name):
class FileIndexManager(BaseIndexManager):
path: str = ""

@classmethod
def from_path(cls, path: str, config: RegistryConfig = None):
if config is None:
config = RegistryConfig(CONFIG_FILE=os.path.join(path, CONFIG_FILE))
return cls(path=path, config=config)

def index_path(self):
return str(Path(self.path) / CONFIG.INDEX)
return str(Path(self.path) / self.config.INDEX)

def get_index(self) -> Index:
if os.path.exists(self.index_path()):
Expand All @@ -163,24 +171,28 @@ class RepoIndexManager(FileIndexManager):
repo: git.Repo

@classmethod
def from_repo(cls, repo: Union[str, git.Repo]):
def from_repo(cls, repo: Union[str, git.Repo], config: RegistryConfig = None):
if isinstance(repo, str):
try:
repo = git.Repo(repo, search_parent_directories=True)
except git.InvalidGitRepositoryError as e:
raise NoRepo(repo) from e
return cls(repo=repo)
if config is None:
config = RegistryConfig(
CONFIG_FILE=os.path.join(repo.working_dir, CONFIG_FILE)
)
return cls(repo=repo, config=config)

def index_path(self):
# TODO: config should be loaded from repo too
return os.path.join(os.path.dirname(self.repo.git_dir), CONFIG.INDEX)
return os.path.join(os.path.dirname(self.repo.git_dir), self.config.INDEX)

class Config:
arbitrary_types_allowed = True

def get_commit_index(self, ref: str) -> Index:
return Index.read(
(self.repo.commit(ref).tree / CONFIG.INDEX).data_stream, frozen=True
(self.repo.commit(ref).tree / self.config.INDEX).data_stream, frozen=True
)

def get_history(self) -> Dict[str, Index]:
Expand All @@ -192,7 +204,7 @@ def get_history(self) -> Dict[str, Index]:
return {
commit.hexsha: self.get_commit_index(commit.hexsha)
for commit in commits
if CONFIG.INDEX in commit.tree
if self.config.INDEX in commit.tree
}

def artifact_centric_representation(self) -> ArtifactCommits:
Expand All @@ -214,4 +226,4 @@ def init_index_manager(path):
try:
return RepoIndexManager.from_repo(path)
except NoRepo:
return FileIndexManager(path)
return FileIndexManager.from_path(path)
5 changes: 2 additions & 3 deletions gto/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class Config:
arbitrary_types_allowed = True

@classmethod
def from_repo(cls, repo=Union[str, Repo], config=None):
def from_repo(cls, repo=Union[str, Repo], config: RegistryConfig = None):
if isinstance(repo, str):
try:
repo = git.Repo(repo, search_parent_directories=True)
Expand All @@ -49,7 +49,7 @@ def from_repo(cls, repo=Union[str, Repo], config=None):

@property
def index(self):
return RepoIndexManager(repo=self.repo)
return RepoIndexManager.from_repo(self.repo)

@property
def state(self):
Expand Down Expand Up @@ -84,7 +84,6 @@ def register(self, name, ref, version=None, bump=None):
is not None
):
raise VersionAlreadyRegistered(version)
print(found_artifact.versions)
if found_artifact.versions:
latest_ver = found_artifact.get_latest_version(
include_deprecated=True
Expand Down
37 changes: 37 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import pytest

from gto.api import add
from gto.config import CONFIG_FILE
from gto.exceptions import UnknownType
from gto.index import init_index_manager
from gto.registry import GitRegistry


@pytest.fixture
def init_repo(empty_git_repo):
repo, write_file = empty_git_repo

write_file(
CONFIG_FILE,
"type_allowed: [model, dataset]",
)
return repo


def test_config_load_index(init_repo):
index = init_index_manager(init_repo)
assert index.config.TYPE_ALLOWED == ["model", "dataset"]


def test_config_load_registry(init_repo):
registry = GitRegistry.from_repo(init_repo)
assert registry.config.TYPE_ALLOWED == ["model", "dataset"]


def test_adding_allowed_type(init_repo):
add(init_repo, "model", "name", "path", virtual=True)


def test_adding_not_allowed_type(init_repo):
with pytest.raises(UnknownType):
add(init_repo, "unknown", "name", "path", virtual=True)

0 comments on commit 5619e99

Please sign in to comment.