Skip to content

Commit

Permalink
fix doc strings & README.md example
Browse files Browse the repository at this point in the history
  • Loading branch information
pfackeldey committed Mar 15, 2024
1 parent d7c4583 commit 37910ed
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class Model(eqx.Module):

def __call__(self, hists: dict[str, Array]) -> Array:
mu_modifier = self.mu.unconstrained()
syst_modifier = self.syst.lnN(width=jnp.array([0.9, 1.1]))
syst_modifier = self.syst.lnN(up=jnp.array([1.1]), down=jnp.array([0.9]))
return mu_modifier(hists["signal"]) + syst_modifier(hists["bkg"])


Expand Down
5 changes: 3 additions & 2 deletions src/evermore/modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,8 +361,9 @@ def scale_factor(self, hist: Array) -> SF:
groups[jtu.tree_structure(mod)].append(mod)
# then do the `jax.lax.scan` loops
for _, group_mods in groups.items():
# Essentially we are turning an array of modifiers into a single modifier with a stack of scale factors and effect leaves (e.g. `width`).
# Then we can use XLA's loop constructs (e.g.: `jax.lax.scan`) to calculate the scale factors without having to compile the fully unrolled loop.
# Essentially we are turning an array of modifiers into a single modifier of stacked leaves.
# Then we can use XLA's loop constructs (e.g.: `jax.lax.scan`) to calculate the scale factors
# without having to compile the fully unrolled loop.
stack = tree_stack(group_mods, broadcast_leaves=True) # type: ignore[arg-type]
# scan over first axis of stack
dynamic_stack, static_stack = eqx.partition(stack, eqx.is_array)
Expand Down

0 comments on commit 37910ed

Please sign in to comment.