Skip to content

Commit

Permalink
Add root saturation function (issue #702) (#858)
Browse files Browse the repository at this point in the history
* feat: adding root_saturation to transformers.py

* feat: adding RootSaturation class to saturation.py

* chore: adding missing RootSaturation to SATURATION_TRANSFORMATIONS

* feat: adding root_saturation to transformers.py

* feat: adding RootSaturation class to saturation.py

* chore: adding missing RootSaturation to SATURATION_TRANSFORMATIONS

* chore: linting edits

* chore: adding coefficient to function

* chore: linting corrections

* chore: removed empty References section of docstring

* chore: produce visual examples of root saturation

* chore: adding root to test_saturation.py

* chore: adding RootSaturation to init file

---------

Co-authored-by: ruari.walker <ruari.walker@qonto.com>
Co-authored-by: Will Dean <57733339+wd60622@users.noreply.github.com>
  • Loading branch information
3 people authored Jul 25, 2024
1 parent 19aea61 commit 9755d3b
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 0 deletions.
2 changes: 2 additions & 0 deletions pymc_marketing/mmm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
InverseScaledLogisticSaturation,
LogisticSaturation,
MichaelisMentenSaturation,
RootSaturation,
SaturationTransformation,
TanhSaturation,
TanhSaturationBaselined,
Expand All @@ -51,6 +52,7 @@
"MMMModelBuilder",
"MichaelisMentenSaturation",
"MonthlyFourier",
"RootSaturation",
"SaturationTransformation",
"TanhSaturation",
"TanhSaturationBaselined",
Expand Down
35 changes: 35 additions & 0 deletions pymc_marketing/mmm/components/saturation.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def function(self, x, b):
inverse_scaled_logistic_saturation,
logistic_saturation,
michaelis_menten,
root_saturation,
tanh_saturation,
tanh_saturation_baselined,
)
Expand Down Expand Up @@ -369,6 +370,39 @@ class HillSaturation(SaturationTransformation):
}


class RootSaturation(SaturationTransformation):
"""Wrapper around Root saturation function.
For more information, see :func:`pymc_marketing.mmm.transformers.root_saturation`.
.. plot::
:context: close-figs
import matplotlib.pyplot as plt
import numpy as np
from pymc_marketing.mmm import RootSaturation
rng = np.random.default_rng(0)
saturation = RootSaturation()
prior = saturation.sample_prior(random_seed=rng)
curve = saturation.sample_curve(prior)
saturation.plot_curve(curve, sample_kwargs={"rng": rng})
plt.show()
"""

lookup_name = "root"

def function(self, x, alpha, beta):
return beta * root_saturation(x, alpha)

default_priors = {
"alpha": Prior("Beta", alpha=1, beta=2),
"beta": Prior("Gamma", mu=1, sigma=1),
}


SATURATION_TRANSFORMATIONS: dict[str, type[SaturationTransformation]] = {
cls.lookup_name: cls
for cls in [
Expand All @@ -378,6 +412,7 @@ class HillSaturation(SaturationTransformation):
TanhSaturationBaselined,
MichaelisMentenSaturation,
HillSaturation,
RootSaturation,
]
}

Expand Down
46 changes: 46 additions & 0 deletions pymc_marketing/mmm/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -988,3 +988,49 @@ def hill_saturation(
The value of the Hill function for each input value of x.
"""
return sigma / (1 + pt.exp(-beta * (x - lam)))


def root_saturation(
x: pt.TensorLike,
alpha: pt.TensorLike,
) -> pt.TensorVariable:
r"""Root saturation transformation.
.. math::
f(x) = x^{\alpha}
.. plot::
:context: close-figs
import matplotlib.pyplot as plt
import numpy as np
import arviz as az
from pymc_marketing.mmm.transformers import root_saturation
plt.style.use('arviz-darkgrid')
alpha = np.array([0.1, 0.3, 0.5, 0.7])
x = np.linspace(0, 5, 100)
ax = plt.subplot(111)
for a in alpha:
y = root_saturation(x, alpha=a)
plt.plot(x, y, label=f'alpha = {a}')
plt.xlabel('spend', fontsize=12)
plt.ylabel('f(spend)', fontsize=12)
box = ax.get_position()
ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
plt.show()
Parameters
----------
x : tensor
Input tensor.
alpha : float
Exponent for the root transformation. Must be non-negative.
Returns
-------
tensor
Transformed tensor.
"""
return x**alpha
3 changes: 3 additions & 0 deletions tests/mmm/components/test_saturation.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
InverseScaledLogisticSaturation,
LogisticSaturation,
MichaelisMentenSaturation,
RootSaturation,
TanhSaturation,
TanhSaturationBaselined,
_get_saturation_function,
Expand All @@ -46,6 +47,7 @@ def saturation_functions():
TanhSaturationBaselined(),
MichaelisMentenSaturation(),
HillSaturation(),
RootSaturation(),
]


Expand Down Expand Up @@ -101,6 +103,7 @@ def test_support_for_lift_test_integrations(saturation) -> None:
("tanh_baselined", TanhSaturationBaselined),
("michaelis_menten", MichaelisMentenSaturation),
("hill", HillSaturation),
("root", RootSaturation),
],
)
def test_get_saturation_function(name, saturation_cls) -> None:
Expand Down

0 comments on commit 9755d3b

Please sign in to comment.