From 599c4abaf3b772304035996f3e5bdeb2c3fbc03e Mon Sep 17 00:00:00 2001
From: George Berry
Date: Tue, 26 Aug 2025 18:27:09 -0400
Subject: [PATCH 1/2] update

---
 blayers/layers.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/blayers/layers.py b/blayers/layers.py
index ed79969..906c444 100644
--- a/blayers/layers.py
+++ b/blayers/layers.py
@@ -843,7 +843,7 @@ def __init__(
             coef_dist: distribution for coefficients
             coef_kwargs: kwargs for coef distribution
             lmbda_kwargs: kwargs for scale prior
-            units: number of output dimensions
+            units: number of output dimensions
         """
         self.lmbda_dist = lmbda_dist
         self.coef_dist = coef_dist

From 6a55debdd4712399c35d7f74e0cca20d9db4bb6c Mon Sep 17 00:00:00 2001
From: George Berry
Date: Fri, 29 Aug 2025 11:21:55 -0400
Subject: [PATCH 2/2] update

---
 blayers/layers.py | 426 +++++++++++++++++++++++++++-------------------
 1 file changed, 249 insertions(+), 177 deletions(-)

diff --git a/blayers/layers.py b/blayers/layers.py
index 906c444..c8d7540 100644
--- a/blayers/layers.py
+++ b/blayers/layers.py
@@ -36,7 +36,7 @@
 # ---- Matmul functions ------------------------------------------------------ #


-def pairwise_interactions(x: jax.Array, z: jax.Array) -> jax.Array:
+def _pairwise_interactions(x: jax.Array, z: jax.Array) -> jax.Array:
     """
     Compute all pairwise interactions between features in X and Y.

@@ -186,7 +186,7 @@ def _matmul_interaction(
     """
     # thanks chat GPT
-    interactions = pairwise_interactions(x, z)
+    interactions = _pairwise_interactions(x, z)
     return jnp.einsum("nd,du->nu", interactions, beta)


@@ -199,7 +199,7 @@ class BLayer(ABC):

     @abstractmethod
     def __init__(self, *args: Any) -> None:
-        """Initialize layer parameters."""
+        """Initialize layer parameters, i.e. the priors defining the Bayesian model."""

     @abstractmethod
     def __call__(self, *args: Any) -> Any:
@@ -207,8 +207,6 @@ def __call__(self, *args: Any) -> Any:
         Run the layer's forward pass.

         Args:
-            name: Name scope for sampled variables. Note due to mypy stuff we
-                only write the `name` arg explicitly in subclass.
             *args: Inputs to the layer.

         Returns:
@@ -242,8 +240,6 @@ def __init__(
             coef_dist: NumPyro distribution class for the coefficient prior.
             coef_kwargs: Parameters for the prior distribution.
             lmbda_kwargs: Parameters for the scale distribution.
-            units: The number of outputs
-            dependent_outputs: For multi-output models whether to treat the outputs as dependent. By deafult they are independent.
         """
         self.lmbda_dist = lmbda_dist
         self.coef_dist = coef_dist
@@ -261,11 +257,13 @@ def __call__(
         Forward pass with adaptive prior on coefficients.

         Args:
-            name: Variable name scope.
-            x: Input data array of shape (n, d, u).
+            name: Variable name.
+            x: Input data array of shape ``(n, d)``.
+            units: Number of outputs.
+            activation: Activation function to apply to output.

         Returns:
-            jax.Array: Output array of shape (n, u).
+            jax.Array: Output array of shape ``(n, u)``.
         """
         x = add_trailing_dim(x)
@@ -321,11 +319,13 @@ def __call__(
         Forward pass with fixed prior.

         Args:
-            name: Variable name prefix.
-            x: Input data array of shape (n, d).
+            name: Variable name.
+            x: Input data array of shape ``(n, d)``.
+            units: Number of outputs.
+            activation: Activation function to apply to output.

         Returns:
-            jax.Array: Output array of shape (n, u).
+            jax.Array: Output array of shape ``(n, u)``.
         """
         x = add_trailing_dim(x)
@@ -373,10 +373,12 @@ def __call__(
         Forward pass with fixed prior.

         Args:
-            name: Variable name prefix.
+            name: Variable name.
+            units: Number of outputs.
+            activation: Activation function to apply to output.

         Returns:
-            jax.Array: Output array of shape (1, u).
+            jax.Array: Output array of shape ``(1, u)``.
         """

         # sampling block
@@ -387,117 +389,6 @@ def __call__(

         return activation(beta)


-class EmbeddingLayer(BLayer):
-    """Bayesian embedding layer for sparse categorical features."""
-
-    def __init__(
-        self,
-        lmbda_dist: distributions.Distribution = distributions.HalfNormal,
-        coef_dist: distributions.Distribution = distributions.Normal,
-        coef_kwargs: dict[str, float] = {"loc": 0.0},
-        lmbda_kwargs: dict[str, float] = {"scale": 1.0},
-    ):
-        """
-        Args:
-            num_embeddings: Total number of discrete embedding entries.
-            embedding_dim: Dimensionality of each embedding vector.
-            coef_dist: Prior distribution for embedding weights.
-            coef_kwargs: Parameters for the prior distribution.
-        """
-        self.lmbda_dist = lmbda_dist
-        self.coef_dist = coef_dist
-        self.coef_kwargs = coef_kwargs
-        self.lmbda_kwargs = lmbda_kwargs
-
-    def __call__(
-        self,
-        name: str,
-        x: jax.Array,
-        num_categories: int,
-        embedding_dim: int,
-    ) -> jax.Array:
-        """
-        Forward pass through embedding lookup.
-
-        Args:
-            name: Variable name scope.
-            x: Integer indices of shape (n,) indicating embeddings to use.
-            num_categories: The number of distinct things getting an embedding
-            embedding_dim: The size of each embedding, e.g. 2, 4, 8, etc.
-
-        Returns:
-            jax.Array: Embedding vectors of shape (n, m).
-        """
-
-        # sampling block
-        lmbda = sample(
-            name=f"{self.__class__.__name__}_{name}_lmbda",
-            fn=self.lmbda_dist(**self.lmbda_kwargs),
-        )
-        beta = sample(
-            name=f"{self.__class__.__name__}_{name}_beta",
-            fn=self.coef_dist(scale=lmbda, **self.coef_kwargs).expand(
-                [num_categories, embedding_dim]
-            ),
-        )
-        # matmul and return
-        return beta[x.squeeze()]
-
-
-class RandomEffectsLayer(BLayer):
-    """Exactly like the EmbeddingLayer but with ``embedding_dim=1``."""
-
-    def __init__(
-        self,
-        lmbda_dist: distributions.Distribution = distributions.HalfNormal,
-        coef_dist: distributions.Distribution = distributions.Normal,
-        coef_kwargs: dict[str, float] = {"loc": 0.0},
-        lmbda_kwargs: dict[str, float] = {"scale": 1.0},
-    ):
-        """
-        Args:
-            ``num_embeddings``: Total number of discrete embedding entries.
-            ``embedding_dim``: Dimensionality of each embedding vector.
-            ``coef_dist``: Prior distribution for embedding weights.
-            ``coef_kwargs``: Parameters for the prior distribution.
-        """
-        self.lmbda_dist = lmbda_dist
-        self.coef_dist = coef_dist
-        self.coef_kwargs = coef_kwargs
-        self.lmbda_kwargs = lmbda_kwargs
-
-    def __call__(
-        self,
-        name: str,
-        x: jax.Array,
-        num_categories: int,
-    ) -> jax.Array:
-        """
-        Forward pass through embedding lookup.
-
-        Args:
-            name: Variable name scope.
-            x: Integer indices of shape (n,) indicating embeddings to use.
-            num_categories: The number of distinct things getting an embedding
-
-        Returns:
-            jax.Array: Embedding vectors of shape (n, embedding_dim).
-        """
-
-        # sampling block
-        lmbda = sample(
-            name=f"{self.__class__.__name__}_{name}_lmbda",
-            fn=self.lmbda_dist(**self.lmbda_kwargs),
-        )
-        beta = sample(
-            name=f"{self.__class__.__name__}_{name}_beta",
-            fn=self.coef_dist(scale=lmbda, **self.coef_kwargs).expand(
-                [num_categories, 1]
-            ),
-        )
-        return beta[x.squeeze()]
-
-
 class FMLayer(BLayer):
     """Bayesian factorization machine layer with adaptive priors.

@@ -530,7 +421,6 @@ def __init__(
             coef_dist: Prior for beta parameters.
             coef_kwargs: Arguments for prior distribution.
             lmbda_kwargs: Arguments for λ distribution.
-            low_rank_dim: Dimensionality of low-rank approximation.
         """
         self.lmbda_dist = lmbda_dist
         self.coef_dist = coef_dist
@@ -550,10 +440,13 @@ def __call__(

         Args:
             name: Variable name scope.
-            x: Input matrix of shape (n, d).
+            x: Input matrix of shape ``(n, d)``.
+            low_rank_dim: Dimensionality of low-rank approximation.
+            units: Number of outputs.
+            activation: Activation function to apply to output.

         Returns:
-            jax.Array: Output array of shape (n,).
+            jax.Array: Output array of shape ``(n, u)``.
         """
         # get shapes and reshape if necessary
         x = add_trailing_dim(x)
@@ -590,7 +483,6 @@ def __init__(
             coef_dist: Prior for beta parameters.
             coef_kwargs: Arguments for prior distribution.
             lmbda_kwargs: Arguments for λ distribution.
-            low_rank_dim: Dimensionality of low-rank approximation.
         """
         self.lmbda_dist = lmbda_dist
         self.coef_dist = coef_dist
@@ -610,10 +502,13 @@ def __call__(

         Args:
             name: Variable name scope.
-            x: Input matrix of shape (n, d).
+            x: Input matrix of shape ``(n, d)``.
+            low_rank_dim: Dimensionality of low-rank approximation.
+            units: Number of outputs.
+            activation: Activation function to apply to output.

         Returns:
-            jax.Array: Output array of shape (n,).
+            jax.Array: Output array of shape ``(n, u)``.
         """
         # get shapes and reshape if necessary
         x = add_trailing_dim(x)
@@ -643,7 +538,6 @@ def __init__(
         coef_dist: distributions.Distribution = distributions.Normal,
         coef_kwargs: dict[str, float] = {"loc": 0.0},
         lmbda_kwargs: dict[str, float] = {"scale": 1.0},
-        units: int = 1,
     ):
         self.lmbda_dist = lmbda_dist
         self.coef_dist = coef_dist
@@ -659,6 +553,20 @@ def __call__(
         units: int = 1,
         activation: Callable[[jax.Array], jax.Array] = jnn.identity,
     ) -> jax.Array:
+        """
+        Low-rank interaction between feature matrices X and Z via UV decomposition.
+
+        Args:
+            name: Variable name scope.
+            x: Input matrix of shape ``(n, d1)``.
+            z: Input matrix of shape ``(n, d2)``.
+            low_rank_dim: Dimensionality of low-rank approximation.
+            units: Number of outputs.
+            activation: Activation function to apply to output.
+
+        Returns:
+            jax.Array: Output array of shape ``(n, u)``.
+        """
         # get shapes and reshape if necessary
         x = add_trailing_dim(x)
         z = add_trailing_dim(z)
@@ -689,48 +597,6 @@ def __call__(

         return activation(_matmul_uv_decomp(theta1, theta2, x, z))


-class RandomWalkLayer(BLayer):
-    """Random walk of embedding dim ``m``, defaults to Gaussian walk."""
-
-    def __init__(
-        self,
-        lmbda_dist: distributions.Distribution = distributions.HalfNormal,
-        coef_dist: distributions.Distribution = distributions.Normal,
-        coef_kwargs: dict[str, float] = {"loc": 0.0},
-        lmbda_kwargs: dict[str, float] = {"scale": 1.0},
-    ):
-        self.lmbda_dist = lmbda_dist
-        self.coef_dist = coef_dist
-        self.coef_kwargs = coef_kwargs
-        self.lmbda_kwargs = lmbda_kwargs
-
-    def __call__(
-        self,
-        name: str,
-        x: jax.Array,
-        num_categories: int,
-        embedding_dim: int,
-    ) -> jax.Array:
-        """ """
-
-        # sampling block
-        lmbda = sample(
-            name=f"{self.__class__.__name__}_{name}_lmbda",
-            fn=self.lmbda_dist(**self.lmbda_kwargs),
-        )
-        theta = sample(
-            name=f"{self.__class__.__name__}_{name}_theta",
-            fn=self.coef_dist(scale=lmbda, **self.coef_kwargs).expand(
-                [
-                    num_categories,
-                    embedding_dim,
-                ]
-            ),
-        )
-        # matmul and return
-        return _matmul_randomwalk(theta, x)
-
-
 class InteractionLayer(BLayer):
     """Computes every interaction coefficient between two sets of inputs."""

@@ -754,6 +620,19 @@ def __call__(
         units: int = 1,
         activation: Callable[[jax.Array], jax.Array] = jnn.identity,
     ) -> jax.Array:
+        """
+        Full interaction between feature matrices X and Z: one coefficient per pair.
+
+        Args:
+            name: Variable name scope.
+            x: Input matrix of shape ``(n, d1)``.
+            z: Input matrix of shape ``(n, d2)``.
+            units: Number of outputs.
+            activation: Activation function to apply to output.
+
+        Returns:
+            jax.Array: Output array of shape ``(n, u)``.
+        """
         # get shapes and reshape if necessary
         x = add_trailing_dim(x)
         z = add_trailing_dim(z)
@@ -791,7 +670,6 @@ def __init__(
             coef_dist: distribution for coefficients
             coef_kwargs: kwargs for coef distribution
             lmbda_kwargs: kwargs for scale prior
-            units: number of output dimensions
         """
         self.lmbda_dist = lmbda_dist
         self.coef_dist = coef_dist
@@ -806,6 +684,19 @@ def __call__(
         units: int = 1,
         activation: Callable[[jax.Array], jax.Array] = jnn.identity,
     ) -> jax.Array:
+        """
+        Interaction between feature matrices X and Z.
+
+        Args:
+            name: Variable name scope.
+            x: Input matrix of shape ``(n, d1)``.
+            z: Input matrix of shape ``(n, d2)``.
+            units: Number of outputs.
+            activation: Activation function to apply to output.
+
+        Returns:
+            jax.Array: Output array of shape ``(n, u)``.
+        """
         # ensure inputs are [batch, dim]
         x = add_trailing_dim(x)
         z = add_trailing_dim(z)
@@ -843,7 +734,6 @@ def __init__(
             coef_dist: distribution for coefficients
             coef_kwargs: kwargs for coef distribution
             lmbda_kwargs: kwargs for scale prior
-            units: number of output dimensions
         """
         self.lmbda_dist = lmbda_dist
         self.coef_dist = coef_dist
@@ -859,6 +749,20 @@ def __call__(
         units: int = 1,
         activation: Callable[[jax.Array], jax.Array] = jnn.identity,
     ) -> jax.Array:
+        """
+        Low-rank interaction between feature matrices X and Z via UV decomposition.
+
+        Args:
+            name: Variable name scope.
+            x: Input matrix of shape ``(n, d1)``.
+            z: Input matrix of shape ``(n, d2)``.
+            low_rank_dim: Dimensionality of low-rank approximation.
+            units: Number of outputs.
+            activation: Activation function to apply to output.
+
+        Returns:
+            jax.Array: Output array of shape ``(n, u)``.
+        """
         # ensure inputs are [batch, dim]
         x = add_trailing_dim(x)
         z = add_trailing_dim(z)
@@ -888,3 +792,171 @@ def __call__(
         out = jnp.sum(x_proj * z_proj, axis=1)  # [batch, units]

         return activation(out)
+
+
+# ---- Embeddings ------------------------------------------------------------ #
+
+
+class EmbeddingLayer(BLayer):
+    """Bayesian embedding layer for sparse categorical features."""
+
+    def __init__(
+        self,
+        lmbda_dist: distributions.Distribution = distributions.HalfNormal,
+        coef_dist: distributions.Distribution = distributions.Normal,
+        coef_kwargs: dict[str, float] = {"loc": 0.0},
+        lmbda_kwargs: dict[str, float] = {"scale": 1.0},
+    ):
+        """
+        Args:
+            lmbda_dist: NumPyro distribution class for the scale (λ) of the
+                prior.
+            coef_dist: NumPyro distribution class for the coefficient prior.
+            coef_kwargs: Parameters for the prior distribution.
+            lmbda_kwargs: Parameters for the scale distribution.
+        """
+        self.lmbda_dist = lmbda_dist
+        self.coef_dist = coef_dist
+        self.coef_kwargs = coef_kwargs
+        self.lmbda_kwargs = lmbda_kwargs
+
+    def __call__(
+        self,
+        name: str,
+        x: jax.Array,
+        num_categories: int,
+        embedding_dim: int,
+    ) -> jax.Array:
+        """
+        Forward pass through embedding lookup.
+
+        Args:
+            name: Variable name scope.
+            x: Integer indices indicating embeddings to use.
+            num_categories: Number of distinct categories receiving an embedding.
+            embedding_dim: Size of each embedding vector, e.g. 2, 4, or 8.
+
+        Returns:
+            jax.Array: Embedding vectors of shape ``(n, embedding_dim)``.
+        """
+
+        # sampling block
+        lmbda = sample(
+            name=f"{self.__class__.__name__}_{name}_lmbda",
+            fn=self.lmbda_dist(**self.lmbda_kwargs),
+        )
+        beta = sample(
+            name=f"{self.__class__.__name__}_{name}_beta",
+            fn=self.coef_dist(scale=lmbda, **self.coef_kwargs).expand(
+                [num_categories, embedding_dim]
+            ),
+        )
+        # index into the embedding table and return
+        return beta[x.squeeze()]
+
+
+class RandomEffectsLayer(BLayer):
+    """Exactly like the EmbeddingLayer but with ``embedding_dim=1``."""
+
+    def __init__(
+        self,
+        lmbda_dist: distributions.Distribution = distributions.HalfNormal,
+        coef_dist: distributions.Distribution = distributions.Normal,
+        coef_kwargs: dict[str, float] = {"loc": 0.0},
+        lmbda_kwargs: dict[str, float] = {"scale": 1.0},
+    ):
+        """
+        Args:
+            lmbda_dist: NumPyro distribution class for the scale (λ) prior.
+            coef_dist: Prior distribution for embedding weights.
+            coef_kwargs: Parameters for the prior distribution.
+            lmbda_kwargs: Parameters for the scale distribution.
+        """
+        self.lmbda_dist = lmbda_dist
+        self.coef_dist = coef_dist
+        self.coef_kwargs = coef_kwargs
+        self.lmbda_kwargs = lmbda_kwargs
+
+    def __call__(
+        self,
+        name: str,
+        x: jax.Array,
+        num_categories: int,
+    ) -> jax.Array:
+        """
+        Forward pass through embedding lookup.
+
+        Args:
+            name: Variable name scope.
+            x: Integer indices indicating embeddings to use.
+            num_categories: Number of distinct categories receiving an embedding.
+
+        Returns:
+            jax.Array: Embedding vectors of shape ``(n, 1)``.
+        """
+
+        # sampling block
+        lmbda = sample(
+            name=f"{self.__class__.__name__}_{name}_lmbda",
+            fn=self.lmbda_dist(**self.lmbda_kwargs),
+        )
+        beta = sample(
+            name=f"{self.__class__.__name__}_{name}_beta",
+            fn=self.coef_dist(scale=lmbda, **self.coef_kwargs).expand(
+                [num_categories, 1]
+            ),
+        )
+        return beta[x.squeeze()]
+
+
+class RandomWalkLayer(BLayer):
+    """Random walk of embedding dim ``m``, defaults to Gaussian walk."""
+
+    def __init__(
+        self,
+        lmbda_dist: distributions.Distribution = distributions.HalfNormal,
+        coef_dist: distributions.Distribution = distributions.Normal,
+        coef_kwargs: dict[str, float] = {"loc": 0.0},
+        lmbda_kwargs: dict[str, float] = {"scale": 1.0},
+    ):
+        self.lmbda_dist = lmbda_dist
+        self.coef_dist = coef_dist
+        self.coef_kwargs = coef_kwargs
+        self.lmbda_kwargs = lmbda_kwargs
+
+    def __call__(
+        self,
+        name: str,
+        x: jax.Array,
+        num_categories: int,
+        embedding_dim: int,
+    ) -> jax.Array:
+        """
+        Forward pass through a random-walk embedding lookup.
+
+        Args:
+            name: Variable name scope.
+            x: Integer indices indicating embeddings to use.
+            num_categories: Number of distinct categories receiving an embedding.
+            embedding_dim: Size of each embedding vector, e.g. 2, 4, or 8.
+
+        Returns:
+            jax.Array: Embedding vectors of shape ``(n, embedding_dim)``.
+        """
+
+        # sampling block
+        lmbda = sample(
+            name=f"{self.__class__.__name__}_{name}_lmbda",
+            fn=self.lmbda_dist(**self.lmbda_kwargs),
+        )
+        theta = sample(
+            name=f"{self.__class__.__name__}_{name}_theta",
+            fn=self.coef_dist(scale=lmbda, **self.coef_kwargs).expand(
+                [
+                    num_categories,
+                    embedding_dim,
+                ]
+            ),
+        )
+        # matmul and return
+        return _matmul_randomwalk(theta, x)
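
Note on `_pairwise_interactions`: the patch shows only the rename, not the body. A minimal sketch of what the helper plausibly computes, given its docstring and the downstream `jnp.einsum("nd,du->nu", interactions, beta)` in `_matmul_interaction`: the per-row outer product of the two feature matrices, flattened to shape (n, d1 * d2). The function below (`pairwise_interactions_sketch`) is an illustrative assumption, not the actual code in blayers/layers.py.

import jax.numpy as jnp

def pairwise_interactions_sketch(x, z):
    # Per-row outer product: (n, d1) x (n, d2) -> (n, d1, d2),
    # flattened to (n, d1 * d2) so it can be contracted with beta
    # via einsum("nd,du->nu", ...) to give one output per unit.
    return jnp.einsum("ni,nj->nij", x, z).reshape(x.shape[0], -1)

x = jnp.arange(6.0).reshape(2, 3)
z = jnp.arange(4.0).reshape(2, 2)
assert pairwise_interactions_sketch(x, z).shape == (2, 6)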
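For context on how these layers compose, here is a hedged end-to-end sketch of a NumPyro model using two of the classes touched by this patch. The `FMLayer` and `RandomEffectsLayer` constructor and `__call__` signatures match the diff above; the regression likelihood, guide, optimizer, and toy data are illustrative assumptions, not part of blayers.

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal

from blayers.layers import FMLayer, RandomEffectsLayer

fm = FMLayer()
group_effect = RandomEffectsLayer()

def model(x, group, y=None):
    # Hypothetical glue code: an FM term over dense features plus a
    # per-group random intercept; each layer samples its own priors
    # under the given name scope when the model is traced.
    mu = fm("fm", x, low_rank_dim=4) + group_effect(
        "group", group, num_categories=10
    )
    sigma = numpyro.sample("sigma", dist.HalfNormal(1.0))
    numpyro.sample("y", dist.Normal(mu.squeeze(-1), sigma), obs=y)

# Toy data, made up for illustration.
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (256, 8))
group = jax.random.randint(key, (256,), 0, 10)
y = jnp.tanh(x[:, 0]) + 0.1 * jax.random.normal(key, (256,))

# Fit with SVI, as for any NumPyro model.
svi = SVI(model, AutoNormal(model), numpyro.optim.Adam(1e-2), Trace_ELBO())
result = svi.run(key, 2000, x, group, y)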