Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewwardrop committed Dec 1, 2024
1 parent 89a4164 commit d6c3acc
Show file tree
Hide file tree
Showing 15 changed files with 808 additions and 47 deletions.
36 changes: 30 additions & 6 deletions formulaic/formula.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import sys
from abc import ABCMeta, abstractmethod
from collections.abc import MutableSequence
import warnings
from enum import Enum
from typing import (
Any,
Expand Down Expand Up @@ -82,7 +83,7 @@ def __call__(
`SimpleFormula` instance will be returned; otherwise, a
`StructuredFormula`.
Some arguments a prefixed with underscores to prevent collision with
Some arguments are prefixed with underscores to prevent collision with
formula structure.
Args:
Expand Down Expand Up @@ -136,6 +137,7 @@ def from_spec(
ordering: Union[OrderingMethod, str] = OrderingMethod.DEGREE,
parser: Optional[FormulaParser] = None,
nested_parser: Optional[FormulaParser] = None,
context: Optional[Mapping[str, Any]] = None,
) -> Union[SimpleFormula, StructuredFormula]:
"""
Construct a `SimpleFormula` or `StructuredFormula` instance from a
Expand Down Expand Up @@ -164,7 +166,9 @@ def from_spec(
if isinstance(spec, str):
spec = cast(
FormulaSpec,
(parser or DefaultFormulaParser()).get_terms(spec)._simplify(),
(parser or DefaultFormulaParser())
.get_terms(spec, context=context)
._simplify(),
)

if isinstance(spec, dict):
Expand All @@ -190,7 +194,7 @@ def from_spec(
term
for value in spec
for term in (
nested_parser.get_terms(value) # type: ignore[attr-defined]
nested_parser.get_terms(value, context=context) # type: ignore[attr-defined]
if isinstance(value, str)
else [value]
)
Expand Down Expand Up @@ -248,9 +252,11 @@ class Formula(metaclass=_FormulaMeta):
def __init__(
self,
root: Union[FormulaSpec, _MissingType] = MISSING,
*,
_parser: Optional[FormulaParser] = None,
_nested_parser: Optional[FormulaParser] = None,
_ordering: Union[OrderingMethod, str] = OrderingMethod.DEGREE,
_context: Optional[Mapping[str, Any]] = None,
**structure: FormulaSpec,
):
"""
Expand Down Expand Up @@ -354,6 +360,7 @@ def __init__(
_ordering: Union[OrderingMethod, str] = OrderingMethod.DEGREE,
_parser: Optional[FormulaParser] = None,
_nested_parser: Optional[FormulaParser] = None,
_context: Optional[Mapping[str, Any]] = None,
**structure: FormulaSpec,
):
if root is MISSING:
Expand Down Expand Up @@ -667,19 +674,22 @@ class StructuredFormula(Structured[SimpleFormula], Formula):
formula specifications. Can be: "none", "degree" (default), or "sort".
"""

__slots__ = ("_parser", "_nested_parser", "_ordering")
__slots__ = ("_parser", "_nested_parser", "_ordering", "_context")

def __init__(
self,
root: Union[FormulaSpec, _MissingType] = MISSING,
*,
_ordering: Union[OrderingMethod, str] = OrderingMethod.DEGREE,
_parser: Optional[FormulaParser] = None,
_nested_parser: Optional[FormulaParser] = None,
_ordering: Union[OrderingMethod, str] = OrderingMethod.DEGREE,
_context: Optional[Mapping[str, Any]] = None,
**structure: FormulaSpec,
):
self._ordering = OrderingMethod(_ordering)
self._parser = _parser or DEFAULT_PARSER
self._nested_parser = _nested_parser or _parser or DEFAULT_NESTED_PARSER
self._ordering = OrderingMethod(_ordering)
self._context = _context
super().__init__(root, **structure) # type: ignore
self._simplify(unwrap=False, inplace=True)

Expand All @@ -704,6 +714,7 @@ def _prepare_item( # type: ignore[override]
ordering=self._ordering,
parser=(self._parser if key == "root" else self._nested_parser),
nested_parser=self._nested_parser,
context=self._context,
)

def get_model_matrix(
Expand Down Expand Up @@ -782,3 +793,16 @@ def differentiate( # pylint: disable=redefined-builtin
SimpleFormula,
self._map(lambda formula: formula.differentiate(*wrt, use_sympy=use_sympy)),
)

# Ensure pickling never includes context
def __getstate__(self):
if self._context is not None:
warnings.warn(
"Dropping context from Formula instance during pickling.",
RuntimeWarning,
stacklevel=2,
)

state = super().__getstate__()
state[1]["_context"] = None
return state
262 changes: 262 additions & 0 deletions formulaic/materializers/Untitled-1.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,262 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
".lhs:\n",
" y\n",
".rhs:\n",
" root:\n",
" 1 + x_hat + `a:b_hat` + d_hat\n",
" .deps:\n",
" [0]:\n",
" .lhs:\n",
" x + a:b\n",
" .rhs:\n",
" 1 + z\n",
" [1]:\n",
" .lhs:\n",
" d\n",
" .rhs:\n",
" 1 + z2"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from formulaic import Formula\n",
"\n",
"Formula(\" y ~ [x + a:b ~ z] + [d ~ z2]\")"
]
},
{
"cell_type": "code",
"execution_count": 65,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'factor', 'prior_apps'}"
]
},
"execution_count": 65,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from formulaic import Formula\n",
"from formulaic.utils.variables import get_expression_variables\n",
"f = Formula(\"apps ~ prior_apps + I(prior_apps**2) + factor + prior_apps:factor\")\n",
"set(\n",
" variable\n",
" for term in f.rhs\n",
" for factor in term.factors\n",
" for variable in get_expression_variables(factor.expr, {})\n",
" if \"value\" in variable.roles\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [],
"source": [
"from formulaic.parser.types import Term, Factor"
]
},
{
"cell_type": "code",
"execution_count": 56,
"metadata": {},
"outputs": [
{
"ename": "TypeError",
"evalue": "unsupported operand type(s) for +: 'Term' and 'Term'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[56], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m (\u001b[43mTerm\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\u001b[43mFactor\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mx\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m]\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m2\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mTerm\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\u001b[43mFactor\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mz\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m) \u001b[38;5;241m*\u001b[39m Term([Factor(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124my\u001b[39m\u001b[38;5;124m\"\u001b[39m)])\n",
"\u001b[0;31mTypeError\u001b[0m: unsupported operand type(s) for +: 'Term' and 'Term'"
]
}
],
"source": [
"(Term([Factor(\"x\")]*2) + Term([Factor(\"z\")])) * Term([Factor(\"y\")])"
]
},
{
"cell_type": "code",
"execution_count": 58,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
".lhs:\n",
" y\n",
".rhs:\n",
" [0]:\n",
" 1 + w\n",
" [1]:\n",
" 1 + x\n",
" [2]:\n",
" 1 + z"
]
},
"execution_count": 58,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from formulaic import Formula\n",
"\n",
"Formula(\"y ~ w | x | z\")"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Intercept</th>\n",
" <th>varlist('X*')[X1]</th>\n",
" <th>varlist('X*')[X2]</th>\n",
" <th>varlist2('X*')[X1]</th>\n",
" <th>varlist2('X*')[X2]</th>\n",
" <th>varlist('X*')[X1]:varlist2('X*')[X1]</th>\n",
" <th>varlist('X*')[X2]:varlist2('X*')[X1]</th>\n",
" <th>varlist('X*')[X1]:varlist2('X*')[X2]</th>\n",
" <th>varlist('X*')[X2]:varlist2('X*')[X2]</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1.0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1.0</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>2</td>\n",
" <td>4</td>\n",
" <td>4</td>\n",
" <td>4</td>\n",
" <td>4</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1.0</td>\n",
" <td>3</td>\n",
" <td>3</td>\n",
" <td>3</td>\n",
" <td>3</td>\n",
" <td>9</td>\n",
" <td>9</td>\n",
" <td>9</td>\n",
" <td>9</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Intercept ... varlist('X*')[X2]:varlist2('X*')[X2]\n",
"0 1.0 ... 1\n",
"1 1.0 ... 4\n",
"2 1.0 ... 9\n",
"\n",
"[3 rows x 9 columns]"
]
},
"execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas\n",
"import re\n",
"\n",
"from formulaic import Formula\n",
"from formulaic.utils.stateful_transforms import stateful_transform\n",
"\n",
"@stateful_transform\n",
"def varlist(pattern, _context=None):\n",
" pattern = re.compile(pattern)\n",
" return {\n",
" variable: values\n",
" for variable, values in _context.named_layers.get(\"data\", {}).items()\n",
" if pattern.match(variable)\n",
" }\n",
"\n",
"Formula(\"varlist('X*') * varlist2('X*')\").get_model_matrix(pandas.DataFrame({\"X1\": [1,2,3], \"X2\": [1,2,3]}), context={\"varlist\": varlist, \"varlist2\": varlist})\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading

0 comments on commit d6c3acc

Please sign in to comment.