diff --git a/README.md b/README.md index 1f4cdf4..29c81ff 100644 --- a/README.md +++ b/README.md @@ -49,7 +49,7 @@ class Model(eqx.Module): def __call__(self, hists: dict[str, Array]) -> Array: mu_modifier = self.mu.unconstrained() - syst_modifier = self.syst.lnN(width=jnp.array([0.9, 1.1])) + syst_modifier = self.syst.lnN(up=jnp.array([1.1]), down=jnp.array([0.9])) return mu_modifier(hists["signal"]) + syst_modifier(hists["bkg"]) diff --git a/src/evermore/modifier.py b/src/evermore/modifier.py index 87c9b75..1d652e9 100644 --- a/src/evermore/modifier.py +++ b/src/evermore/modifier.py @@ -361,8 +361,9 @@ def scale_factor(self, hist: Array) -> SF: groups[jtu.tree_structure(mod)].append(mod) # then do the `jax.lax.scan` loops for _, group_mods in groups.items(): - # Essentially we are turning an array of modifiers into a single modifier with a stack of scale factors and effect leaves (e.g. `width`). - # Then we can use XLA's loop constructs (e.g.: `jax.lax.scan`) to calculate the scale factors without having to compile the fully unrolled loop. + # Essentially we are turning an array of modifiers into a single modifier of stacked leaves. + # Then we can use XLA's loop constructs (e.g.: `jax.lax.scan`) to calculate the scale factors + # without having to compile the fully unrolled loop. stack = tree_stack(group_mods, broadcast_leaves=True) # type: ignore[arg-type] # scan over first axis of stack dynamic_stack, static_stack = eqx.partition(stack, eqx.is_array)