diff --git a/formulaic/formula.py b/formulaic/formula.py index a4690f5..107dd89 100644 --- a/formulaic/formula.py +++ b/formulaic/formula.py @@ -1,8 +1,8 @@ from __future__ import annotations -import sys from abc import ABCMeta, abstractmethod -from collections.abc import Generator, Iterable, Mapping, MutableSequence +from collections import Counter +from collections.abc import Generator, Iterable, Mapping from enum import Enum from typing import ( Any, @@ -11,7 +11,6 @@ TypeVar, Union, cast, - overload, ) from typing_extensions import Self, TypeAlias @@ -330,12 +329,12 @@ def differentiate( class SimpleFormula( - MutableSequence[Term] if sys.version_info >= (3, 9) else MutableSequence, # type: ignore + OrderedSet[Term], Formula, ): """ The atomic component of all formulae represented by Formulaic, which in turn - is a mutable sequence of `Term` instances. `StructuredFormula` uses + is a mutable ordered set of `Term` instances. `StructuredFormula` uses `SimpleFormula` as its nodes. Instances of this class can be used directly as a mutable sequence of @@ -355,7 +354,9 @@ class SimpleFormula( def __init__( self, - root: Union[Iterable[Term], MissingType] = MISSING, + root: Union[ + Iterable[Term], Mapping[Term, int], OrderedSet[Term], MissingType + ] = MISSING, *, _ordering: Union[OrderingMethod, str] = OrderingMethod.DEGREE, _parser: Optional[FormulaParser] = None, @@ -376,10 +377,10 @@ def __init__( "`SimpleFormula` does not support nested structure. To create a " "structured formula, use `StructuredFormula` instead." ) - self.__terms = list(root) self.ordering = OrderingMethod(_ordering) - self.__validate_terms(self.__terms) + self.__validate_terms(root) + super().__init__(root) self._reorder() @@ -413,56 +414,36 @@ def _reorder(self, ordering: Optional[OrderingMethod] = None) -> None: ) if orderer is not None: - self.__terms = orderer(self.__terms) - - # MutableSequence implementation - - @overload - def __getitem__(self, key: int) -> Term: ... - - @overload - def __getitem__(self, key: slice) -> SimpleFormula: ... - - def __getitem__(self, key: Union[int, slice]) -> Union[Term, SimpleFormula]: - if isinstance(key, slice): - return self.__class__(self.__terms[key], _ordering=self.ordering) - else: - return self.__terms[key] - - @overload - def __setitem__(self, key: int, value: Term) -> None: ... + self._values = Counter( + {item: self._values[item] for item in orderer(self._values)} + ) - @overload - def __setitem__(self, key: slice, value: Iterable[Term]) -> None: ... + # Overrides to ensure that ordering is maintained - def __setitem__(self, key, value): # type: ignore - self.__validate_terms([value]) - self.__terms[key] = value + def add(self, item: Term) -> None: + self.__validate_terms([item]) + super().add(item) self._reorder() - @overload - def __delitem__(self, key: int) -> None: ... - - @overload - def __delitem__(self, key: slice) -> None: ... - - def __delitem__(self, key): # type: ignore - del self.__terms[key] - - def __len__(self) -> int: - return len(self.__terms) + def update(self, items: Iterable[Term]) -> None: + """ + Update this formula with the terms from another iterable, or a mapping + from terms to observed counts. - def insert(self, index: int, value: Term) -> None: - self.__validate_terms([value]) - self.__terms.insert(index, value) + Args: + items: The terms to add to this formula. If an iterable is + is provided, the terms will be added with a count of 1. + Otherwise the counts will be aggregated from the mapping and/or + ordered set instances. + """ + self.__validate_terms(items) + super().update(items) self._reorder() def __eq__(self, other: Any) -> bool: - if isinstance(other, SimpleFormula): - other = list(other) - if isinstance(other, list): - return self.__terms == other - return NotImplemented + if isinstance(other, (SimpleFormula, list, tuple)): + return tuple(other) == tuple(other) + return super().__eq__(other) # Transforms @@ -487,7 +468,7 @@ def differentiate( # pylint: disable=redefined-builtin return SimpleFormula( [ differentiate_term(term, wrt, use_sympy=use_sympy) - for term in self.__terms + for term in self._values ], # Preserve term ordering even if differentiation modifies degrees/etc. _ordering=OrderingMethod.NONE, @@ -537,7 +518,7 @@ def required_variables(self) -> set[Variable]: variables: list[Variable] = [ variable - for term in self.__terms + for term in self._values for factor in term.factors for variable in get_expression_variables(factor.expr, {}) if "value" in variable.roles @@ -555,7 +536,7 @@ def required_variables(self) -> set[Variable]: ) def __repr__(self) -> str: - return " + ".join([str(t) for t in self.__terms]) + return " + ".join([str(t) for t in self]) # Deprecated shims for legacy `Structured`-like behaviour (previously there # was no distinction between `SimpleFormula` and `StructuredFormula`, and diff --git a/formulaic/utils/ordered_set.py b/formulaic/utils/ordered_set.py index f2a8279..9820ca8 100644 --- a/formulaic/utils/ordered_set.py +++ b/formulaic/utils/ordered_set.py @@ -1,21 +1,25 @@ from __future__ import annotations from collections import Counter -from collections.abc import Iterable, Iterator, Mapping, MutableSet, Sequence -from itertools import islice +from collections.abc import Iterable, Iterator, Mapping, MutableSequence, MutableSet +from itertools import chain, islice from typing import Any, Generic, TypeVar, Union, overload _ItemType = TypeVar("_ItemType") _SelfType = TypeVar("_SelfType", bound="OrderedSet") -class OrderedSet(MutableSet, Sequence, Generic[_ItemType]): +class OrderedSet(MutableSet, MutableSequence, Generic[_ItemType]): """ - A mutable set-like sequenced container that retains the order in which item + A mutable set-like sequence container that retains the order in which item were added to the set, keeps track of multiplicities (how many times an item was added), and provides both set and list-like indexing and mutations. This container keeps track of how many times an item was added to the set, which can be checked using the `.get_multiplicity()` method. + + This class is optimised for set-like operations, but also provides O(n) + lookups by index, insertions, deletions, and updates. We may optimise this + in the future based on need by maintaining index tables. """ def __init__( @@ -58,7 +62,21 @@ def discard(self, item: _ItemType) -> None: if item in self._values: del self._values[item] - # Additional methods for Sequence interface (O(n) lookups by index) + # MutableSet order preservation + + def __ror__(self, other: Any) -> OrderedSet[_ItemType]: + return OrderedSet(other) | self + + def __rxor__(self, other: Any) -> OrderedSet[_ItemType]: + return OrderedSet(other) ^ self + + def __rand__(self, other: Any) -> OrderedSet[_ItemType]: + return OrderedSet(other) & self + + def __rsub__(self, other: Any) -> OrderedSet[_ItemType]: + return OrderedSet(other) - self + + # Additional methods for MutableSequence interface (O(n) lookups by index) @overload def __getitem__(self, index: int) -> _ItemType: ... @@ -74,12 +92,76 @@ def __getitem__( { item: self._values[item] for item in islice( - self._values, index.start, index.stop, index.step + self._values, index.start % len(self) if index.start is not None else None, index.stop % len(self) if index.stop is not None else None, index.step ) } ) else: - return next(islice(self._values, index, None)) + return next(islice(self._values, index % len(self), None)) + + @overload + def __setitem__(self, key: int, value: _ItemType) -> None: ... + + @overload + def __setitem__(self, key: slice, value: Iterable[_ItemType]) -> None: ... + + def __setitem__(self, key, value): # type: ignore + + if isinstance(key, slice): + items_to_replace = self[key] + if len(items_to_replace) != len(value): + raise ValueError( + "It does not make sense to replace a slice with a different number of items." + ) + for item, value in zip(items_to_replace, value): + self[item] = value + return + + item_to_replace = self[key] + if item_to_replace == value: + return + if value in self._values: + raise ValueError( + "It does not make sense to replace an item with another item that is already in the set." + ) + self._values = Counter( + { + value if item == item_to_replace else item: self._values.get( + value if item == item_to_replace else item, 1 + ) + for item in self._values + } + ) + + @overload + def __delitem__(self, key: int) -> None: ... + + @overload + def __delitem__(self, key: slice) -> None: ... + + def __delitem__(self, key): # type: ignore + if isinstance(key, slice): + for item in self[key]: + del self._values[item] + else: + del self._values[self[key]] + + def insert(self, index: int, value: _ItemType) -> None: + if value in self._values: + raise ValueError( + "It does not makes sense to insert an item that is already in the set." + ) + + self._values = Counter( + { + item: self._values.get(item, 1) + for item in chain( + islice(self._values, 0, index), + (value,), + islice(self._values, index, None), + ) + } + ) # Other data model methods