diff --git a/HARK/core.py b/HARK/core.py index adf141556..be72e83a6 100644 --- a/HARK/core.py +++ b/HARK/core.py @@ -15,11 +15,13 @@ from copy import copy, deepcopy from dataclasses import dataclass, field from time import time -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Union from warnings import warn import numpy as np import pandas as pd +from xarray import DataArray + from HARK.distribution import ( Distribution, IndexDistribution, @@ -28,7 +30,6 @@ ) from HARK.parallel import multi_thread_commands, multi_thread_commands_fake from HARK.utilities import NullFunc, get_arg_names -from xarray import DataArray logging.basicConfig(format="%(message)s") _log = logging.getLogger("HARK") @@ -61,205 +62,387 @@ def set_verbosity_level(level): class Parameters: """ - This class defines an object that stores all of the parameters for a model - as an internal dictionary. It is designed to also handle the age-varying - dynamics of parameters. + A smart container for model parameters that handles age-varying dynamics. + + This class stores parameters as an internal dictionary and manages their + age-varying properties, providing both attribute-style and dictionary-style + access. It is designed to handle the time-varying dynamics of parameters + in economic models. Attributes ---------- - _length : int The terminal age of the agents in the model. - _invariant_params : list - A list of the names of the parameters that are invariant over time. - _varying_params : list - A list of the names of the parameters that vary over time. + _invariant_params : Set[str] + A set of parameter names that are invariant over time. + _varying_params : Set[str] + A set of parameter names that vary over time. + _parameters : Dict[str, Any] + The internal dictionary storing all parameters. """ - def __init__(self, **parameters: Any): + __slots__ = ("_length", "_invariant_params", "_varying_params", "_parameters") + + def __init__(self, **parameters: Any) -> None: """ - Initializes a Parameters object and parses the age-varying - dynamics of the parameters. + Initialize a Parameters object and parse the age-varying dynamics of parameters. Parameters ---------- - - parameters : keyword arguments - Any number of keyword arguments of the form key=value. - To parse a dictionary of parameters, use the ** operator. + **parameters : Any + Any number of parameters in the form key=value. """ - params = parameters.copy() - self._length = params.pop("T_cycle", None) - self._invariant_params = set() - self._varying_params = set() - self._parameters: Dict[str, Union[int, float, np.ndarray, list, tuple]] = {} + self._length: int = parameters.pop("T_cycle", 1) + self._invariant_params: Set[str] = set() + self._varying_params: Set[str] = set() + self._parameters: Dict[str, Any] = {"T_cycle": self._length} - for key, value in params.items(): - self._parameters[key] = self.__infer_dims__(key, value) + for key, value in parameters.items(): + self[key] = value - def __infer_dims__( - self, key: str, value: Union[int, float, np.ndarray, list, tuple, None] - ) -> Union[int, float, np.ndarray, list, tuple]: + def __getitem__(self, item_or_key: Union[int, str]) -> Union["Parameters", Any]: """ - Infers the age-varying dimensions of a parameter. + Access parameters by age index or parameter name. - If the parameter is a scalar, numpy array, boolean, distribution, callable or None, - it is assumed to be invariant over time. If the parameter is a list or - tuple, it is assumed to be varying over time. If the parameter is a list - or tuple of length greater than 1, the length of the list or tuple must match - the `_term_age` attribute of the Parameters object. + If item_or_key is an integer, returns a Parameters object with the parameters + that apply to that age. This includes all invariant parameters and the + `item_or_key`th element of all age-varying parameters. If item_or_key is a + string, it returns the value of the parameter with that name. Parameters ---------- - key : str - name of parameter - value : Any - value of parameter - - """ - if isinstance( - value, (int, float, np.ndarray, type(None), Distribution, bool, Callable) - ): - self.__add_to_invariant__(key) - return value - if isinstance(value, (list, tuple)): - if len(value) == 1: - self.__add_to_invariant__(key) - return value[0] - if self._length is None or self._length == 1: - self._length = len(value) - if len(value) == self._length: - self.__add_to_varying__(key) - return value - raise ValueError( - f"Parameter {key} must be of length 1 or {self._length}, not {len(value)}" - ) - raise ValueError(f"Parameter {key} has unsupported type {type(value)}") - - def __add_to_invariant__(self, key: str): - """ - Adds parameter name to invariant set and removes from varying set. - """ - self._varying_params.discard(key) - self._invariant_params.add(key) - - def __add_to_varying__(self, key: str): - """ - Adds parameter name to varying set and removes from invariant set. - """ - self._invariant_params.discard(key) - self._varying_params.add(key) + item_or_key : Union[int, str] + Age index or parameter name. - def __getitem__(self, item_or_key: Union[int, str]): - """ - If item_or_key is an integer, returns a Parameters object with the parameters - that apply to that age. This includes all invariant parameters and the - `item_or_key`th element of all age-varying parameters. If item_or_key is a string, - it returns the value of the parameter with that name. + Returns + ------- + Union[Parameters, Any] + A new Parameters object for the specified age, or the value of the + specified parameter. + + Raises + ------ + ValueError: + If the age index is out of bounds. + KeyError: + If the parameter name is not found. + TypeError: + If the key is neither an integer nor a string. """ if isinstance(item_or_key, int): if item_or_key >= self._length: raise ValueError( - f"Age {item_or_key} is greater than or equal to terminal age {self._length}." + f"Age {item_or_key} is out of bounds (max: {self._length - 1})." ) params = {key: self._parameters[key] for key in self._invariant_params} params.update( { key: self._parameters[key][item_or_key] + if isinstance(self._parameters[key], (list, tuple, np.ndarray)) + else self._parameters[key] for key in self._varying_params } ) return Parameters(**params) elif isinstance(item_or_key, str): return self._parameters[item_or_key] + else: + raise TypeError("Key must be an integer (age) or string (parameter name).") - def __setitem__(self, key: str, value: Any): + def __setitem__(self, key: str, value: Any) -> None: """ - Sets the value of a parameter. + Set parameter values, automatically inferring time variance. + + If the parameter is a scalar, numpy array, boolean, distribution, callable + or None, it is assumed to be invariant over time. If the parameter is a + list or tuple, it is assumed to be varying over time. If the parameter + is a list or tuple of length greater than 1, the length of the list or + tuple must match the `_length` attribute of the Parameters object. Parameters ---------- key : str - name of parameter + Name of the parameter. value : Any - value of parameter + Value of the parameter. + Raises + ------ + ValueError: + If the parameter name is not a string or if the value type is unsupported. + If the parameter value is inconsistent with the current model length. """ if not isinstance(key, str): - raise ValueError("Parameters must be set with a string key") - self._parameters[key] = self.__infer_dims__(key, value) + raise ValueError(f"Parameter name must be a string, got {type(key)}") + + if isinstance( + value, (int, float, np.ndarray, type(None), Distribution, bool, Callable) + ): + self._invariant_params.add(key) + self._varying_params.discard(key) + elif isinstance(value, (list, tuple)): + if len(value) == 1: + value = value[0] + self._invariant_params.add(key) + self._varying_params.discard(key) + elif self._length is None or self._length == 1: + self._length = len(value) + self._varying_params.add(key) + self._invariant_params.discard(key) + elif len(value) == self._length: + self._varying_params.add(key) + self._invariant_params.discard(key) + else: + raise ValueError( + f"Parameter {key} must have length 1 or {self._length}, not {len(value)}" + ) + else: + raise ValueError(f"Unsupported type for parameter {key}: {type(value)}") - def keys(self): + self._parameters[key] = value + + def __iter__(self) -> Iterator[str]: + """Allow iteration over parameter names.""" + return iter(self._parameters) + + def __len__(self) -> int: + """Return the number of parameters.""" + return len(self._parameters) + + def keys(self) -> Iterator[str]: + """Return a view of parameter names.""" + return self._parameters.keys() + + def values(self) -> Iterator[Any]: + """Return a view of parameter values.""" + return self._parameters.values() + + def items(self) -> Iterator[Tuple[str, Any]]: + """Return a view of parameter (name, value) pairs.""" + return self._parameters.items() + + def to_dict(self) -> Dict[str, Any]: """ - Returns a list of the names of the parameters. + Convert parameters to a plain dictionary. + + Returns + ------- + Dict[str, Any] + A dictionary containing all parameters. """ - return self._invariant_params | self._varying_params + return dict(self._parameters) - def values(self): + def to_namedtuple(self) -> namedtuple: """ - Returns a list of the values of the parameters. + Convert parameters to a namedtuple. + + Returns + ------- + namedtuple + A namedtuple containing all parameters. """ - return list(self._parameters.values()) + return namedtuple("Parameters", self.keys())(**self.to_dict()) - def items(self): + def update(self, other: Union["Parameters", Dict[str, Any]]) -> None: """ - Returns a list of tuples of the form (name, value) for each parameter. + Update parameters from another Parameters object or dictionary. + + Parameters + ---------- + other : Union[Parameters, Dict[str, Any]] + The source of parameters to update from. + + Raises + ------ + TypeError + If the input is neither a Parameters object nor a dictionary. + """ + if isinstance(other, Parameters): + for key, value in other._parameters.items(): + self[key] = value + elif isinstance(other, dict): + for key, value in other.items(): + self[key] = value + else: + raise TypeError( + "Update source must be a Parameters object or a dictionary." + ) + + def __repr__(self) -> str: + """Return a detailed string representation of the Parameters object.""" + return ( + f"Parameters(_length={self._length}, " + f"_invariant_params={self._invariant_params}, " + f"_varying_params={self._varying_params}, " + f"_parameters={self._parameters})" + ) + + def __str__(self) -> str: + """Return a simple string representation of the Parameters object.""" + return f"Parameters({str(self._parameters)})" + + def __getattr__(self, name: str) -> Any: """ - return list(self._parameters.items()) + Allow attribute-style access to parameters. - def __iter__(self): + Parameters + ---------- + name : str + Name of the parameter to access. + + Returns + ------- + Any + The value of the specified parameter. + + Raises + ------ + AttributeError: + If the parameter name is not found. """ - Allows for iterating over the parameter names. + if name.startswith("_"): + return super().__getattribute__(name) + try: + return self._parameters[name] + except KeyError: + raise AttributeError(f"'Parameters' object has no attribute '{name}'") + + def __setattr__(self, name: str, value: Any) -> None: """ - return iter(self.keys()) + Allow attribute-style setting of parameters. - def __deepcopy__(self, memo): + Parameters + ---------- + name : str + Name of the parameter to set. + value : Any + Value to set for the parameter. """ - Returns a deep copy of the Parameters object. + if name.startswith("_"): + super().__setattr__(name, value) + else: + self[name] = value + + def __contains__(self, item: str) -> bool: + """Check if a parameter exists in the Parameters object.""" + return item in self._parameters + + def copy(self) -> "Parameters": + """ + Create a deep copy of the Parameters object. + + Returns + ------- + Parameters + A new Parameters object with the same contents. """ - return Parameters(**deepcopy(self.to_dict(), memo)) + return deepcopy(self) - def to_dict(self): + def add_to_time_vary(self, *params: str) -> None: """ - Returns a dictionary of the parameters. + Adds any number of parameters to the time-varying set. + + Parameters + ---------- + *params : str + Any number of strings naming parameters to be added to time_vary. """ - return {key: self._parameters[key] for key in self.keys()} + for param in params: + if param in self._parameters: + self._varying_params.add(param) + self._invariant_params.discard(param) + else: + warn( + f"Parameter '{param}' does not exist and cannot be added to time_vary." + ) - def to_namedtuple(self): + def add_to_time_inv(self, *params: str) -> None: """ - Returns a namedtuple of the parameters. + Adds any number of parameters to the time-invariant set. + + Parameters + ---------- + *params : str + Any number of strings naming parameters to be added to time_inv. """ - return namedtuple("Parameters", self.keys())(**self.to_dict()) + for param in params: + if param in self._parameters: + self._invariant_params.add(param) + self._varying_params.discard(param) + else: + warn( + f"Parameter '{param}' does not exist and cannot be added to time_inv." + ) - def update(self, other_params): + def del_from_time_vary(self, *params: str) -> None: """ - Updates the parameters with the values from another - Parameters object or a dictionary. + Removes any number of parameters from the time-varying set. Parameters ---------- - other_params : Parameters or dict - Parameters object or dictionary of parameters to update with. + *params : str + Any number of strings naming parameters to be removed from time_vary. """ - if isinstance(other_params, Parameters): - self._parameters.update(other_params.to_dict()) - elif isinstance(other_params, dict): - self._parameters.update(other_params) - else: - raise ValueError("Parameters must be a dict or a Parameters object") + for param in params: + self._varying_params.discard(param) - def __str__(self): + def del_from_time_inv(self, *params: str) -> None: """ - Returns a simple string representation of the Parameters object. + Removes any number of parameters from the time-invariant set. + + Parameters + ---------- + *params : str + Any number of strings naming parameters to be removed from time_inv. """ - return f"Parameters({str(self.to_dict())})" + for param in params: + self._invariant_params.discard(param) - def __repr__(self): + def get(self, key: str, default: Any = None) -> Any: """ - Returns a detailed string representation of the Parameters object. + Get a parameter value, returning a default if not found. + + Parameters + ---------- + key : str + The parameter name. + default : Any, optional + The default value to return if the key is not found. + + Returns + ------- + Any + The parameter value or the default. + """ + return self._parameters.get(key, default) + + def set_many(self, **kwargs: Any) -> None: """ - return f"Parameters( _age_inv = {self._invariant_params}, _age_var = {self._varying_params}, | {self.to_dict()})" + Set multiple parameters at once. + + Parameters + ---------- + **kwargs : Keyword arguments representing parameter names and values. + """ + for key, value in kwargs.items(): + self[key] = value + + def is_time_varying(self, key: str) -> bool: + """ + Check if a parameter is time-varying. + + Parameters + ---------- + key : str + The parameter name. + + Returns + ------- + bool + True if the parameter is time-varying, False otherwise. + """ + return key in self._varying_params class Model: @@ -280,12 +463,12 @@ def assign_parameters(self, **kwds): Parameters ---------- **kwds : keyword arguments - Any number of keyword arguments of the form key=value. Each value - will be assigned to the attribute named in self. + Any number of keyword arguments of the form key=value. + Each value will be assigned to the attribute named in self. Returns ------- - none + None """ self.parameters.update(kwds) for key in kwds: @@ -297,13 +480,12 @@ def get_parameter(self, name): Parameters ---------- - name : string + name : str The name of the parameter to get Returns ------- - value : - The value of the parameter + value : The value of the parameter """ return self.parameters[name] @@ -343,7 +525,7 @@ def del_param(self, param_name): Returns ------- - None. + None """ if param_name in self.parameters: del self.parameters[param_name] @@ -366,8 +548,8 @@ def construct(self, *args, force=False): Parameters ---------- *args : str, optional - Keys of self.constructors that are requested to be constructed. If - no arguments are passed, *all* elements of the dictionary are implied. + Keys of self.constructors that are requested to be constructed. + If no arguments are passed, *all* elements of the dictionary are implied. force : bool, optional When True, the method will force its way past any errors, including missing constructors, missing arguments for constructors, and errors @@ -492,13 +674,13 @@ def describe_constructors(self, *args): Parameters ---------- - *args : str + *args : str, optional Optional list of strings naming constructed inputs to be described. If none are passed, all constructors are described. Returns ------- - None. + None """ if len(args) > 0: keys = args @@ -701,8 +883,7 @@ def unpack(self, parameter): """ Unpacks a parameter from a solution object for easier access. After the model has been solved, the parameters (like consumption function) - reside in the attributes of each element of `ConsumerType.solution` - (e.g. `cFunc`). This method creates a (time varying) attribute of the given + reside in the attributes of each element of `ConsumerType.solution` (e.g. `cFunc`). This method creates a (time varying) attribute of the given parameter name that contains a list of functions accessible by `ConsumerType.parameter`. Parameters @@ -1576,7 +1757,7 @@ class Market(Model): A list of all the AgentTypes in this market. sow_vars : [string] Names of variables generated by the "aggregate market process" that should - be "sown" to the agents in the market. Aggregate state, etc. + "sown" to the agents in the market. Aggregate state, etc. reap_vars : [string] Names of variables to be collected ("reaped") from agents in the market to be used in the "aggregate market process". diff --git a/HARK/tests/test_core.py b/HARK/tests/test_core.py index 102d23deb..f08d05af5 100644 --- a/HARK/tests/test_core.py +++ b/HARK/tests/test_core.py @@ -182,103 +182,69 @@ def test_create_agents(self): self.assertEqual(len(self.agent_pop.agents), 12) -class test_parameters(unittest.TestCase): - def setUp(self): - self.params = Parameters( - T_cycle=3, - a=1, - b=[2, 3, 4], - c=np.array([5, 6, 7]), - d=[lambda x: x, lambda x: x**2, lambda x: x**3], - e=Uniform(), - f=[True, False, True], - ) - - def test_init(self): - self.assertEqual(self.params._length, 3) - self.assertEqual(self.params._invariant_params, {"a", "c", "e"}) - self.assertEqual(self.params._varying_params, {"b", "d", "f"}) - - def test_getitem(self): - self.assertEqual(self.params["a"], 1) - self.assertEqual(self.params[0]["b"], 2) - self.assertEqual(self.params["c"][1], 6) - - def test_setitem(self): - self.params["d"] = 8 - self.assertEqual(self.params["d"], 8) - - def test_update(self): - self.params.update({"a": 9, "b": [10, 11, 12]}) - self.assertEqual(self.params["a"], 9) - self.assertEqual(self.params[0]["b"], 10) - - def test_initialization(self): - params = Parameters(a=1, b=[1, 2], T_cycle=2) - assert params._length == 2 - assert params._invariant_params == {"a"} - assert params._varying_params == {"b"} - - def test_infer_dims_scalar(self): - params = Parameters(a=1) - assert params["a"] == 1 - - def test_infer_dims_array(self): - params = Parameters(b=np.array([1, 2])) - assert all(params["b"] == np.array([1, 2])) - - def test_infer_dims_list_varying(self): - params = Parameters(b=[1, 2], T_cycle=2) - assert params["b"] == [1, 2] - - def test_infer_dims_list_invariant(self): - params = Parameters(b=[1]) - assert params["b"] == 1 - - def test_setitem(self): - params = Parameters(a=1) - params["b"] = 2 - assert params["b"] == 2 - - def test_keys_values_items(self): - params = Parameters(a=1, b=2) - assert set(params.keys()) == {"a", "b"} - assert set(params.values()) == {1, 2} - assert set(params.items()) == {("a", 1), ("b", 2)} - - def test_to_dict(self): - params = Parameters(a=1, b=2) - assert params.to_dict() == {"a": 1, "b": 2} - - def test_to_namedtuple(self): - params = Parameters(a=1, b=2) - named_tuple = params.to_namedtuple() - assert named_tuple.a == 1 - assert named_tuple.b == 2 - - def test_update_params(self): - params1 = Parameters(a=1, b=2) - params2 = Parameters(a=3, c=4) - params1.update(params2) - assert params1["a"] == 3 - assert params1["c"] == 4 - - def test_unsupported_type_error(self): +import pytest +import numpy as np +from HARK.distribution import Uniform +from HARK.core import Parameters + + +@pytest.fixture +def sample_params(): + return Parameters(a=1, b=[2, 3, 4], c=5.0, d=[6.0, 7.0, 8.0], T_cycle=3) + + +class TestParameters: + def test_initialization(self, sample_params): + assert sample_params._length == 3 + assert sample_params._invariant_params == {"a", "c"} + assert sample_params._varying_params == {"b", "d"} + assert sample_params._parameters["T_cycle"] == 3 + + def test_getitem(self, sample_params): + assert sample_params["a"] == 1 + assert sample_params["b"] == [2, 3, 4] + assert sample_params[0]["b"] == 2 + assert sample_params[1]["d"] == 7.0 + + def test_setitem(self, sample_params): + sample_params["e"] = 9 + assert sample_params["e"] == 9 + assert "e" in sample_params._invariant_params + + sample_params["f"] = [10, 11, 12] + assert sample_params["f"] == [10, 11, 12] + assert "f" in sample_params._varying_params + + def test_get(self, sample_params): + assert sample_params.get("a") == 1 + assert sample_params.get("z", 100) == 100 + + def test_set_many(self, sample_params): + sample_params.set_many(g=13, h=[14, 15, 16]) + assert sample_params["g"] == 13 + assert sample_params["h"] == [14, 15, 16] + + def test_is_time_varying(self, sample_params): + assert sample_params.is_time_varying("b") is True + assert sample_params.is_time_varying("a") is False + + def test_to_dict(self, sample_params): + params_dict = sample_params.to_dict() + assert isinstance(params_dict, dict) + assert params_dict["a"] == 1 + assert params_dict["b"] == [2, 3, 4] + + def test_update(self, sample_params): + new_params = Parameters(a=100, e=200) + sample_params.update(new_params) + assert sample_params["a"] == 100 + assert sample_params["e"] == 200 + + @pytest.mark.parametrize("invalid_key", [1, 2.0, None, []]) + def test_setitem_invalid_key(self, sample_params, invalid_key): with pytest.raises(ValueError): - Parameters(b={1, 2}) + sample_params[invalid_key] = 42 - def test_get_item_dimension_error(self): - params = Parameters(b=[1, 2], T_cycle=2) + def test_setitem_invalid_value_length(self, sample_params): with pytest.raises(ValueError): - params[2] - - def test_getitem_with_key(self): - params = Parameters(a=1, b=[2, 3], T_cycle=2) - assert params["a"] == 1 - assert params["b"] == [2, 3] - - def test_getitem_with_item(self): - params = Parameters(a=1, b=[2, 3], T_cycle=2) - age_params = params[1] - assert age_params["a"] == 1 - assert age_params["b"] == 3 + sample_params["invalid"] = [1, 2] # Should be length 1 or 3