Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add items iterator #38

Merged
merged 1 commit into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 84 additions & 36 deletions envier/env.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections import deque
from collections import namedtuple
import os
import typing as t
Expand Down Expand Up @@ -68,6 +69,12 @@ def __init__(
self.help_type = help_type
self.help_default = help_default

self._full_name = _normalized(name) # Will be set by the EnvMeta metaclass

@property
def full_name(self) -> str:
return f"_{self._full_name}" if self.private else self._full_name

def _cast(self, _type: t.Any, raw: str, env: "Env") -> t.Any:
if _type is bool:
return t.cast(T, raw.lower() in env.__truthy__)
Expand Down Expand Up @@ -100,9 +107,7 @@ def _cast(self, _type: t.Any, raw: str, env: "Env") -> t.Any:
def _retrieve(self, env: "Env", prefix: str) -> T:
source = env.source

full_name = prefix + _normalized(self.name)
if self.private:
full_name = f"_{full_name}"
full_name = self.full_name
raw = source.get(full_name.format(**env.dynamic))
if raw is None and self.deprecations:
for name, deprecated_when, removed_when in self.deprecations:
Expand Down Expand Up @@ -167,10 +172,8 @@ def __call__(self, env: "Env", prefix: str) -> T:
try:
self.validator(value)
except ValueError as e:
full_name = prefix + _normalized(self.name)
raise ValueError(
"Invalid value for environment variable %s: %s" % (full_name, e)
)
msg = f"Invalid value for environment variable {self.full_name}: {e}"
raise ValueError(msg)

return value

Expand All @@ -191,7 +194,22 @@ def __call__(self, env: "Env") -> T:
return value


class Env(object):
class EnvMeta(type):
def __new__(
cls, name: str, bases: t.Tuple[t.Type], ns: t.Dict[str, t.Any]
) -> t.Any:
env = t.cast("Env", super().__new__(cls, name, bases, ns))

prefix = ns.get("__prefix__")
if prefix:
for v in env.values(recursive=True):
if isinstance(v, EnvVariable):
v._full_name = f"{_normalized(prefix)}_{v._full_name}".upper()

return env


class Env(metaclass=EnvMeta):
"""Env base class.

This class is meant to be subclassed. The configuration is declared by using
Expand Down Expand Up @@ -336,26 +354,42 @@ def d(
return DerivedVariable(type, derivation)

@classmethod
def keys(cls) -> t.Iterator[str]:
"""Return the names of all the items."""
return (
k
for k, v in cls.__dict__.items()
if isinstance(v, (EnvVariable, DerivedVariable))
or isinstance(v, type)
and issubclass(v, Env)
)
def items(
cls, recursive: bool = False, include_derived: bool = False
) -> t.Iterator[t.Tuple[str, t.Union[EnvVariable, DerivedVariable]]]:
classes = (EnvVariable, DerivedVariable) if include_derived else (EnvVariable,)
q: t.Deque[t.Tuple[t.Tuple[str], t.Type["Env"]]] = deque()
path: t.Tuple[str] = tuple() # type: ignore[assignment]
q.append((path, cls))
while q:
path, env = q.popleft()
for k, v in env.__dict__.items():
if isinstance(v, classes):
yield (
".".join((*path, k)),
t.cast(t.Union[EnvVariable, DerivedVariable], v),
)
elif isinstance(v, type) and issubclass(v, Env) and recursive:
item_name = getattr(v, "__item__", k)
if item_name is None:
item_name = k
q.append(((*path, item_name), v)) # type: ignore[arg-type]

@classmethod
def values(cls) -> t.Iterator[t.Union[EnvVariable, DerivedVariable, t.Type["Env"]]]:
"""Return the names of all the items."""
return (
v
for v in cls.__dict__.values()
if isinstance(v, (EnvVariable, DerivedVariable))
or isinstance(v, type)
and issubclass(v, Env)
)
def keys(
cls, recursive: bool = False, include_derived: bool = False
) -> t.Iterator[str]:
"""Return the name of all the configuration items."""
for k, _ in cls.items(recursive, include_derived):
yield k

@classmethod
def values(
cls, recursive: bool = False, include_derived: bool = False
) -> t.Iterator[t.Union[EnvVariable, DerivedVariable, t.Type["Env"]]]:
"""Return the value of all the configuration items."""
for _, v in cls.items(recursive, include_derived):
yield v

@classmethod
def include(
Expand All @@ -371,14 +405,6 @@ def include(
operation would result in some variables being overwritten. This can
be disabled by setting the ``overwrite`` argument to ``True``.
"""
if namespace is not None:
if not overwrite and hasattr(cls, namespace):
raise ValueError("Namespace already in use: {}".format(namespace))

setattr(cls, namespace, env_spec)

return None

# Pick only the attributes that define variables.
to_include = {
k: v
Expand All @@ -387,14 +413,36 @@ def include(
or isinstance(v, type)
and issubclass(v, Env)
}

if not overwrite:
overlap = set(cls.__dict__.keys()) & set(to_include.keys())
if overlap:
raise ValueError("Configuration clashes detected: {}".format(overlap))

own_prefix = _normalized(getattr(cls, "__prefix__", ""))

if namespace is not None:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟠 Code Quality Violation

Found too many nested ifs within this condition (...read more)

Too many nested loops make the code hard to read and understand. Simplify your code by removing nesting levels and separate code in small units.

View in Datadog  Leave us feedback  Documentation

if not overwrite and hasattr(cls, namespace):
raise ValueError("Namespace already in use: {}".format(namespace))

if getattr(cls, namespace, None) is not env_spec:
setattr(cls, namespace, env_spec)

if own_prefix:
for _, v in to_include.items():
if isinstance(v, EnvVariable):
v._full_name = f"{own_prefix}_{v._full_name}"
Comment on lines +432 to +433

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 Code Quality Violation

too many nesting levels (...read more)

Avoid to nest too many loops together. Having too many loops make your code harder to understand.
Prefer to organize your code in functions and unit of code you can clearly understand.

Learn More

View in Datadog  Leave us feedback  Documentation


return None

other_prefix = getattr(env_spec, "__prefix__", "")
for k, v in to_include.items():
setattr(cls, k, v)
if getattr(cls, k, None) is not v:
setattr(cls, k, v)
if isinstance(v, EnvVariable):
if other_prefix:
v._full_name = v._full_name[len(other_prefix) + 1 :] # noqa
if own_prefix:
v._full_name = f"{own_prefix}_{v._full_name}"
Comment on lines +444 to +445

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔴 Code Quality Violation

too many nesting levels (...read more)

Avoid to nest too many loops together. Having too many loops make your code harder to understand.
Prefer to organize your code in functions and unit of code you can clearly understand.

Learn More

View in Datadog  Leave us feedback  Documentation


@classmethod
def help_info(
Expand Down
50 changes: 48 additions & 2 deletions tests/test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,12 @@ class GlobalConfig(Env):
service = ServiceConfig

config = GlobalConfig()
assert set(config.keys()) == {"debug_mode", "service"}
assert set(config.keys()) == {"debug_mode"}
assert set(config.keys(recursive=True)) == {
"debug_mode",
"service.host",
"service.port",
}
assert config.service.port == 8080


Expand All @@ -178,11 +183,23 @@ class ServiceConfig(Env):

host = Env.var(str, "host", default="localhost")
port = Env.var(int, "port", default=3000)
_private = Env.var(int, "private", default=42, private=True)

config = GlobalConfig()
assert set(config.keys()) == {"debug_mode", "service"}
assert set(config.keys()) == {"debug_mode"}
assert set(config.keys(recursive=True)) == {
"debug_mode",
"service.host",
"service.port",
"service._private",
}
assert config.service.port == 8080

assert GlobalConfig.debug_mode.full_name == "MYAPP_DEBUG"
assert GlobalConfig.service.host.full_name == "MYAPP_SERVICE_HOST"
assert GlobalConfig.service.port.full_name == "MYAPP_SERVICE_PORT"
assert GlobalConfig.service._private.full_name == "_MYAPP_SERVICE_PRIVATE"


def test_env_include():
class GlobalConfig(Env):
Expand Down Expand Up @@ -383,3 +400,32 @@ class Config(Env):
("_PRIVATE_FOO", "int", "42", ""),
("PUBLIC_FOO", "int", "42", ""),
}

assert Config.private.full_name == "_PRIVATE_FOO"


def test_env_items(monkeypatch):
monkeypatch.setenv("MYAPP_SERVICE_PORT", "8080")

class GlobalConfig(Env):
__prefix__ = "myapp"

debug_mode = Env.var(bool, "debug", default=False)

class ServiceConfig(Env):
__item__ = __prefix__ = "service"

host = Env.var(str, "host", default="localhost")
port = Env.var(int, "port", default=3000)
_private = Env.var(int, "private", default=42, private=True)

items = list(GlobalConfig.items())
assert items == [("debug_mode", GlobalConfig.debug_mode)]

items = list(GlobalConfig.items(recursive=True))
assert items == [
("debug_mode", GlobalConfig.debug_mode),
("service.host", GlobalConfig.ServiceConfig.host),
("service.port", GlobalConfig.ServiceConfig.port),
("service._private", GlobalConfig.ServiceConfig._private),
]
Loading