Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewwardrop committed Jan 9, 2025
1 parent 3bb90dc commit 1278e86
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 60 deletions.
87 changes: 34 additions & 53 deletions formulaic/formula.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from __future__ import annotations

import sys
from abc import ABCMeta, abstractmethod
from collections.abc import Generator, Iterable, Mapping, MutableSequence
from collections import Counter
from collections.abc import Generator, Iterable, Mapping
from enum import Enum
from typing import (
Any,
Expand All @@ -11,7 +11,6 @@
TypeVar,
Union,
cast,
overload,
)

from typing_extensions import Self, TypeAlias
Expand Down Expand Up @@ -330,12 +329,12 @@ def differentiate(


class SimpleFormula(
MutableSequence[Term] if sys.version_info >= (3, 9) else MutableSequence, # type: ignore
OrderedSet[Term],
Formula,
):
"""
The atomic component of all formulae represented by Formulaic, which in turn
is a mutable sequence of `Term` instances. `StructuredFormula` uses
is a mutable ordered set of `Term` instances. `StructuredFormula` uses
`SimpleFormula` as its nodes.
Instances of this class can be used directly as a mutable sequence of
Expand All @@ -355,7 +354,9 @@ class SimpleFormula(

def __init__(
self,
root: Union[Iterable[Term], MissingType] = MISSING,
root: Union[
Iterable[Term], Mapping[Term, int], OrderedSet[Term], MissingType
] = MISSING,
*,
_ordering: Union[OrderingMethod, str] = OrderingMethod.DEGREE,
_parser: Optional[FormulaParser] = None,
Expand All @@ -376,10 +377,10 @@ def __init__(
"`SimpleFormula` does not support nested structure. To create a "
"structured formula, use `StructuredFormula` instead."
)
self.__terms = list(root)
self.ordering = OrderingMethod(_ordering)

self.__validate_terms(self.__terms)
self.__validate_terms(root)
super().__init__(root)

self._reorder()

Expand Down Expand Up @@ -413,56 +414,36 @@ def _reorder(self, ordering: Optional[OrderingMethod] = None) -> None:
)

if orderer is not None:
self.__terms = orderer(self.__terms)

# MutableSequence implementation

@overload
def __getitem__(self, key: int) -> Term: ...

@overload
def __getitem__(self, key: slice) -> SimpleFormula: ...

def __getitem__(self, key: Union[int, slice]) -> Union[Term, SimpleFormula]:
if isinstance(key, slice):
return self.__class__(self.__terms[key], _ordering=self.ordering)
else:
return self.__terms[key]

@overload
def __setitem__(self, key: int, value: Term) -> None: ...
self._values = Counter(
{item: self._values[item] for item in orderer(self._values)}
)

@overload
def __setitem__(self, key: slice, value: Iterable[Term]) -> None: ...
# Overrides to ensure that ordering is maintained

def __setitem__(self, key, value): # type: ignore
self.__validate_terms([value])
self.__terms[key] = value
def add(self, item: Term) -> None:
    """
    Add a single term to this formula and restore the configured ordering.

    The term is validated first, then handed to the parent class's `add`,
    after which `_reorder` is invoked.

    Args:
        item: The term to add.
    """
    candidate = [item]
    self.__validate_terms(candidate)
    super().add(item)
    self._reorder()

@overload
def __delitem__(self, key: int) -> None: ...

@overload
def __delitem__(self, key: slice) -> None: ...

def __delitem__(self, key): # type: ignore
del self.__terms[key]

def __len__(self) -> int:
return len(self.__terms)
def update(self, items: Iterable[Term]) -> None:
    """
    Update this formula with the terms from another iterable, or a mapping
    from terms to observed counts.

    Args:
        items: The terms to add to this formula. If an iterable is
            provided, the terms will be added with a count of 1.
            Otherwise the counts will be aggregated from the mapping and/or
            ordered set instances.
    """
    if iter(items) is items:
        # `items` is a one-shot iterator (e.g. a generator). It is passed
        # to both the validator and the parent `update`; whichever iterates
        # it first would leave nothing for the other, so materialise it.
        # Mappings and ordered sets are left untouched so that their counts
        # remain visible to the parent implementation.
        items = list(items)
    self.__validate_terms(items)
    super().update(items)
    self._reorder()

def __eq__(self, other: Any) -> bool:
    """
    Compare this formula with another object.

    Lists, tuples and other `SimpleFormula` instances are compared
    element-wise and in order against the terms of this formula. Any other
    object is delegated to the parent class's `__eq__`.

    Args:
        other: The object to compare against.

    Returns:
        Whether `other` holds the same terms in the same order, or whatever
        the parent class's comparison returns for other types.
    """
    if isinstance(other, (list, tuple)) or isinstance(other, SimpleFormula):
        # Compare *this* formula's terms with `other`'s; comparing `other`
        # with itself would make every such comparison succeed.
        return tuple(self) == tuple(other)
    return super().__eq__(other)

# Transforms

Expand All @@ -487,7 +468,7 @@ def differentiate( # pylint: disable=redefined-builtin
return SimpleFormula(
[
differentiate_term(term, wrt, use_sympy=use_sympy)
for term in self.__terms
for term in self._values
],
# Preserve term ordering even if differentiation modifies degrees/etc.
_ordering=OrderingMethod.NONE,
Expand Down Expand Up @@ -537,7 +518,7 @@ def required_variables(self) -> set[Variable]:

variables: list[Variable] = [
variable
for term in self.__terms
for term in self._values
for factor in term.factors
for variable in get_expression_variables(factor.expr, {})
if "value" in variable.roles
Expand All @@ -555,7 +536,7 @@ def required_variables(self) -> set[Variable]:
)

def __repr__(self) -> str:
    """
    Render the formula as its terms' string forms joined by " + ".

    Iterates over the formula itself rather than the removed private
    `__terms` list, so the terms appear in the formula's iteration order.
    """
    return " + ".join(str(term) for term in self)

# Deprecated shims for legacy `Structured`-like behaviour (previously there
# was no distinction between `SimpleFormula` and `StructuredFormula`, and
Expand Down
96 changes: 89 additions & 7 deletions formulaic/utils/ordered_set.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,25 @@
from __future__ import annotations

from collections import Counter
from collections.abc import Iterable, Iterator, Mapping, MutableSet, Sequence
from itertools import islice
from collections.abc import Iterable, Iterator, Mapping, MutableSequence, MutableSet
from itertools import chain, islice
from typing import Any, Generic, TypeVar, Union, overload

_ItemType = TypeVar("_ItemType")
_SelfType = TypeVar("_SelfType", bound="OrderedSet")


class OrderedSet(MutableSet, Sequence, Generic[_ItemType]):
class OrderedSet(MutableSet, MutableSequence, Generic[_ItemType]):
"""
A mutable set-like sequenced container that retains the order in which item
A mutable set-like sequence container that retains the order in which items
were added to the set, keeps track of multiplicities (how many times an item
was added), and provides both set and list-like indexing and mutations. This
container keeps track of how many times an item was added to the set, which
can be checked using the `.get_multiplicity()` method.
This class is optimised for set-like operations, but also provides O(n)
lookups by index, insertions, deletions, and updates. We may optimise this
in the future based on need by maintaining index tables.
"""

def __init__(
Expand Down Expand Up @@ -58,7 +62,21 @@ def discard(self, item: _ItemType) -> None:
if item in self._values:
del self._values[item]

# Additional methods for Sequence interface (O(n) lookups by index)
# MutableSet order preservation

def __ror__(self, other: Any) -> OrderedSet[_ItemType]:
    """
    Reflected union (`other | self`).

    Converts `other` to an `OrderedSet` and applies `|` with it as the
    left-hand operand and this set as the right-hand operand.
    """
    left_operand = OrderedSet(other)
    return left_operand | self

def __rxor__(self, other: Any) -> OrderedSet[_ItemType]:
    """
    Reflected symmetric difference (`other ^ self`).

    Converts `other` to an `OrderedSet` and applies `^` with it as the
    left-hand operand and this set as the right-hand operand.
    """
    left_operand = OrderedSet(other)
    return left_operand ^ self

def __rand__(self, other: Any) -> OrderedSet[_ItemType]:
    """
    Reflected intersection (`other & self`).

    Converts `other` to an `OrderedSet` and applies `&` with it as the
    left-hand operand and this set as the right-hand operand.
    """
    left_operand = OrderedSet(other)
    return left_operand & self

def __rsub__(self, other: Any) -> OrderedSet[_ItemType]:
    """
    Reflected difference (`other - self`).

    Converts `other` to an `OrderedSet` and applies `-` with it as the
    left-hand operand and this set as the right-hand operand.
    """
    left_operand = OrderedSet(other)
    return left_operand - self

# Additional methods for MutableSequence interface (O(n) lookups by index)

@overload
def __getitem__(self, index: int) -> _ItemType: ...
Expand All @@ -74,12 +92,76 @@ def __getitem__(
{
item: self._values[item]
for item in islice(
self._values, index.start, index.stop, index.step
self._values, index.start % len(self) if index.start is not None else None, index.stop % len(self) if index.stop is not None else None, index.step
)
}
)
else:
return next(islice(self._values, index, None))
return next(islice(self._values, index % len(self), None))

@overload
def __setitem__(self, key: int, value: _ItemType) -> None: ...


@overload
def __setitem__(self, key: slice, value: Iterable[_ItemType]) -> None: ...


def __setitem__(self, key, value):  # type: ignore
    """
    Replace the item(s) at an index or slice while preserving item order.

    Items that are left in place keep their multiplicity; an item that
    replaces another enters the set with a multiplicity of 1. All
    replacements are checked before the set is modified, so a failing
    assignment leaves the set unchanged.

    Args:
        key: An integer position (negative values count from the end) or a
            slice selecting the positions to replace.
        value: The replacement item for an integer `key`, or an iterable
            providing exactly one replacement per selected position when
            `key` is a slice.

    Raises:
        IndexError: If an integer `key` is out of range.
        ValueError: If a slice is assigned a different number of items than
            it selects, or if a replacement item is already in the set at
            another position.
    """
    items = list(self._values)

    if isinstance(key, slice):
        # Indexing a `range` resolves the slice to concrete positions,
        # including negative bounds and steps.
        positions = range(len(items))[key]
        replacements = list(value)  # Accept any iterable, e.g. generators.
        if len(positions) != len(replacements):
            raise ValueError(
                "It does not make sense to replace a slice with a different number of items."
            )
    else:
        # Indexing a `range` normalises negative indices and raises
        # `IndexError` for out-of-range ones.
        positions = (range(len(items))[key],)
        replacements = [value]

    present = set(items)
    replaced_positions = set()
    for position, replacement in zip(positions, replacements):
        current = items[position]
        if current == replacement:
            continue  # Replacing an item with itself is a no-op.
        if replacement in present:
            raise ValueError(
                "It does not make sense to replace an item with another item that is already in the set."
            )
        present.discard(current)
        present.add(replacement)
        items[position] = replacement
        replaced_positions.add(position)

    if not replaced_positions:
        return

    # Rebuild the counter in the updated order; only now is `self` mutated.
    self._values = Counter(
        {
            item: 1 if position in replaced_positions else self._values[item]
            for position, item in enumerate(items)
        }
    )

@overload
def __delitem__(self, key: int) -> None: ...


@overload
def __delitem__(self, key: slice) -> None: ...


def __delitem__(self, key):  # type: ignore
    """
    Remove the item(s) at an index or slice.

    Positions are resolved against a snapshot of the current item order, so
    deleting several items in one call is not affected by the removals
    already performed.

    Args:
        key: An integer position (negative values count from the end) or a
            slice selecting the positions to remove.

    Raises:
        IndexError: If an integer `key` is out of range.
    """
    items = list(self._values)
    # Indexing a `range` gives standard sequence semantics: negative
    # indices, slice bounds equal to the length, negative steps, and
    # `IndexError` for out-of-range integers.
    positions = range(len(items))[key]
    if not isinstance(key, slice):
        positions = (positions,)
    for position in positions:
        del self._values[items[position]]

def insert(self, index: int, value: _ItemType) -> None:
    """
    Insert `value` before position `index`, following `list.insert` rules.

    Existing items keep their multiplicity; the new item enters the set
    with a multiplicity of 1.

    Args:
        index: The position to insert before. Negative values count from
            the end, and out-of-range values are clamped to the start or
            end of the set.
        value: The item to insert.

    Raises:
        ValueError: If `value` is already in the set.
    """
    if value in self._values:
        raise ValueError(
            "It does not makes sense to insert an item that is already in the set."
        )

    # Delegate index handling to `list.insert`, which supports negative and
    # out-of-range positions (`itertools.islice` rejects negative bounds).
    items = list(self._values)
    items.insert(index, value)
    self._values = Counter({item: self._values.get(item, 1) for item in items})

# Other data model methods

Expand Down

0 comments on commit 1278e86

Please sign in to comment.