Skip to content

Commit

Permalink
Merge pull request #11 from OlofHarrysson/develop
Browse files Browse the repository at this point in the history
Release 0.2.0
  • Loading branch information
OlofHarrysson authored Dec 23, 2020
2 parents f1f90af + 17b31f3 commit 8c64069
Show file tree
Hide file tree
Showing 21 changed files with 870 additions and 339 deletions.
30 changes: 29 additions & 1 deletion anyfig/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
__author__ = """Olof Harrysson"""
__email__ = 'harrysson.olof@gmail.com'
__version__ = '0.1.0'
__version__ = '0.2.0'

import sys
from functools import wraps
from anyfig.figutils import *
from anyfig.anyfig_setup import *

Expand All @@ -12,3 +13,30 @@
from anyfig.fields import *
else:
from anyfig.dummyfields import *


def get_global_cfg(func):
''' Decorator for GlobalConfig methods. Saves the config if it's not already saved '''
@wraps(func)
def wrapper(*args, **kwargs):
self = args[0]
if self.global_cfg is None:
self.global_cfg = get_config()
return func(*args, **kwargs)

return wrapper


class GlobalConfig:
global_cfg = None

@get_global_cfg
def __getattr__(self, name):
return getattr(self.global_cfg, name)

@get_global_cfg
def __str__(self):
return str(self.global_cfg)


global_cfg = GlobalConfig()
43 changes: 23 additions & 20 deletions anyfig/anyfig_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,22 +33,20 @@ def init_config(default_config, cli_args=None):
# Create config
config_str = cli_args.pop('config', default_config.__name__)
config = create_config(config_str)
fields.validate_fields(config)

# Print config help
if 'help' in cli_args:
config_classes = list(figutils.get_config_classes())
print(
f"Available config classes {config_classes}. Set config with --config=OtherConfigClass\n",
f"\nCurrent config is '{config_str}'. The available input arguments are")

help_string = config.comments_string()
print(help_string)
if 'help' in cli_args or 'h' in cli_args:
print(config.cli_help())
sys.exit(0)

# Overwrite parameters via optional input flags
config = overwrite(config, cli_args)

# Resolve required values
# Perform deep post init after input flags
figutils.post_init(config)

# Unwrap the field values
fields.resolve_fields(config)

# Freezes config
Expand Down Expand Up @@ -110,20 +108,20 @@ def overwrite(main_config_obj, args):
config_obj = main_config_obj
config_class = type(config_obj).__name__

for key_part in outer_keys:
for key_idx, key_part in enumerate(argument_key.split('.')):
err_msg = f"{base_err_msg}. '{key_part}' isn't an attribute in '{config_class}'"
assert hasattr(config_obj, key_part), err_msg

config_obj = getattr(config_obj, key_part)
config_class = type(config_obj).__name__
err_msg = f"{base_err_msg}. '{'.'.join(outer_keys)}' isn't a registered Anyfig config class"
assert figutils.is_config_class(config_obj), err_msg
# Check if the config allows the argument
figutils.check_allowed_input_argument(config_obj, key_part, argument_key)

# Error if trying to set unknown attribute key
err_msg = f"{base_err_msg}. '{inner_key}' isn't an attribute in '{config_class}'"
assert inner_key in vars(config_obj), err_msg
# Check if the outer attributes are config classes
if key_idx < len(outer_keys):
config_obj = getattr(config_obj, key_part)
config_class = type(config_obj).__name__
err_msg = f"{base_err_msg}. '{'.'.join(outer_keys)}' isn't a registered Anyfig config class"
assert figutils.is_config_class(config_obj), err_msg

# Class definition
value_class = type(getattr(config_obj, inner_key))
base_err_msg = f"Input argument '{argument_key}' with value {val} can't create an object of the expected type"

Expand All @@ -134,9 +132,14 @@ def overwrite(main_config_obj, args):
# Create new object that follows the InterfaceField's rules
elif issubclass(value_class, fields.InterfaceField):
field = getattr(config_obj, inner_key)

if isinstance(value_class, fields.InputField):
value_class = field.type_pattern
else:
value_class = type(field.value)

try:
# TODO: Naive solution. Doesn't handle e.g. typing.Union[Path, str]
val = field.type_pattern(val)
val = value_class(val)
except Exception as e:
err_msg = f"{base_err_msg} {field.type_pattern}. {e}"
raise RuntimeError(err_msg) from None
Expand Down
18 changes: 14 additions & 4 deletions anyfig/config_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,23 @@ class MasterConfig(ABC):
def __init__(self):
pass # Add empty init if config doesn't have one

def comments_string(self):
''' Returns string for config class's attributes and comments '''
return print_utils.comments_string(self)
def allowed_cli_args(self):
''' Returns the attribute names that can be be overwritten from command line input '''
return self.get_parameters()

def post_init(self):
''' A function that is called after overwriting from command line input '''
pass

def cli_help(self):
return print_utils.cli_help(self)

def frozen(self, freeze=True):
''' Freeze/unfreeze config '''
self.__class__._frozen = freeze # TODO: .__class__ needed? type(self) instead?
self._frozen = freeze
for _, val in self.get_parameters(copy=False).items():
if figutils.is_config_class(val):
val.frozen(freeze)
return self

def get_parameters(self, copy=True):
Expand Down
6 changes: 5 additions & 1 deletion anyfig/dummyfields.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


def raise_error():
err_msg = f"This feature isn't supported in Python {sys.version}. See our website '{figutils.get_website()}' for more information"
err_msg = f"This feature isn't supported in Python {sys.version_info.major}.{sys.version_info.minor}. See our website '{figutils.get_website()}' for more information"
raise RuntimeError(err_msg)


Expand All @@ -14,3 +14,7 @@ def field(*args, **kwargs):

def constant(value, strict=False):
raise_error()


def cli_input(type_pattern):
raise_error()
54 changes: 45 additions & 9 deletions anyfig/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,30 +7,52 @@
from .figutils import is_config_class


def field(*args, **kwargs):
def field(type_pattern=typing.Any, tests=None):
''' Returns an InterfaceField '''
return InterfaceField(*args, **kwargs)
return InterfaceField(type_pattern, tests)


def constant(value, strict=False):
''' Returns a ConstantField '''
return ConstantField(value, strict)


def resolve_fields(config):
def cli_input(type_pattern):
''' Returns an InputField '''
assert type_pattern in [str, int, tuple, list, dict]
return InputField(type_pattern)


def validate_fields(config):
''' Validates that fields has a value '''
for key, val in vars(config).items():
if type(val) is InterfaceField: # Don't check InputField or ConstantField
err_msg = (
f"Missing value for '{key}' in config '{type(config).__name__}'. "
"Set a value or change the type to 'anyfig.cli_input' to allow input arguments without default values"
)
assert hasattr(val, 'value'), err_msg

# Resolve nested configs
if is_config_class(val):
validate_fields(val)


def resolve_fields(config, cli_name=''):
''' Removes wrapping for InterfaceFields '''
for key, val in vars(config).items():
if isinstance(val, InterfaceField):
cli_name = '.'.join([cli_name, key]).lstrip('.')
config_class = type(config).__name__
value = val.finish_wrapping_phase(key, config_class)
value = val.finish_wrapping_phase(cli_name, config_class)
setattr(config, key, value)

# Resolve nested configs
if is_config_class(val):
resolve_fields(val)
resolve_fields(val, cli_name=key)


class InterfaceField():
class InterfaceField:
''' Used to define allowed values for a config-attribute '''
def __init__(self, type_pattern=typing.Any, tests=None):
err_msg = f"Expected 'type_pattern' to be a type or a typing pattern but got {type(type_pattern)}"
Expand All @@ -44,8 +66,7 @@ def __init__(self, type_pattern=typing.Any, tests=None):

def update_value(self, name, value, config_class):
# Updates value and return wrapped value or value if setup is finished
if self.type_pattern:
check_type(name, value, self.type_pattern)
check_type(name, value, self.type_pattern)

for test in self.tests:
self._check_test(test, name, value, config_class)
Expand All @@ -55,7 +76,8 @@ def update_value(self, name, value, config_class):

def finish_wrapping_phase(self, name, config_class):
# Verifies that attribute is overridden and finishes setup
err_msg = f"Attribute '{name}' in '{config_class}' is required to be overridden"
inner_key = name.split('.')[-1]
err_msg = f"The field '{inner_key}' in '{config_class}' is required to be overridden"
assert hasattr(self, 'value'), err_msg

self.wrapping_phase = False
Expand Down Expand Up @@ -95,3 +117,17 @@ def _check_test(self, test, name, value, config_class):
''' Calls the test with the new attribute value. Raises error if test doesn't pass '''
err_msg = f"Can't override constant '{name}' with value '{value}' in config '{config_class}'"
assert test(value), err_msg


class InputField(InterfaceField):
''' Used to define required config-attribute from command line input '''
def __init__(self, type_pattern):
super().__init__(type_pattern=type_pattern)

def finish_wrapping_phase(self, name, config_class):
# Verifies that attribute is overridden and finishes setup
err_msg = f"Missing required input argument --{name}. See --help for more info"
assert hasattr(self, 'value'), err_msg

self.wrapping_phase = False
return self.value
50 changes: 45 additions & 5 deletions anyfig/figutils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import inspect
import dill

from pathlib import Path
from collections.abc import Iterable

import dill

registered_config_classes = {}
global_configs = {}
Expand Down Expand Up @@ -46,16 +47,16 @@ def is_config_class(obj):
return inspect.isclass(obj) and obj.__name__ in registered_config_classes


def cfg():
''' Returns the config object that is registed with anyfig '''
def get_config():
''' Returns the config object that is registered with anyfig '''

# Normal case
if len(global_configs) == 1:
return next(iter(global_configs.values()))

# init_config function adds one so this should never happen
elif len(global_configs) == 0:
raise RuntimeError("No global config has been registered")
raise RuntimeError("No config object has been registered")

# If multiple config objects has been marked as global
raise RuntimeError(
Expand Down Expand Up @@ -117,3 +118,42 @@ def find_arguments(callable_):
if param.default == inspect.Parameter.empty
]
return list(parameters), required_args


def check_allowed_input_argument(config_obj, name, deep_name):
''' Raises error if the input argument isn't marked as "allowed" '''
allowed_args = get_allowed_cli_args(config_obj)
if name not in allowed_args:
err_msg = f"Input argument '{deep_name}' is not allowed to be overwritten. See --help for more info"
raise ValueError(err_msg)


def get_allowed_cli_args(config_obj):
''' Returns the attribute names that can be be overwritten from command line input.
Raises AttributeError if an attribute doesn't exist '''
allowed_items = config_obj.allowed_cli_args()
if allowed_items is None:
allowed_items = []
if isinstance(allowed_items, str):
allowed_items = [allowed_items]
err_msg = (
f"Expected return type 'String, None or Iterable' for {type(config_obj).__name__}'s allowed_cli_args method, "
f"was {allowed_items} with type {type(allowed_items)}")
assert isinstance(allowed_items, Iterable), err_msg

attributes = config_obj.get_parameters()
for item in allowed_items:
if item not in attributes:
err_msg = (
f"'{type(config_obj).__name__}' has no attribute '{item}' and should not be marked as an allowed command line "
"input argument")
raise AttributeError(err_msg)
return allowed_items


def post_init(config_obj):
''' Recursively calls the post_init method on a config and it's attributes '''
config_obj.post_init()
for _, val in config_obj.get_parameters(copy=False).items():
if is_config_class(val):
post_init(val)
Loading

1 comment on commit 8c64069

@vercel
Copy link

@vercel vercel bot commented on 8c64069 Dec 23, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.