-
I passed a jnp array containing only one integer into a scan loop, in order to keep track of the iteration index so that I can tell when I am in the first few iterations, similar to fori_loop. But I had to write it in the form of scan. |
Beta Was this translation helpful? Give feedback.
Replies: 2 comments 4 replies
-
Hi - it's hard to figure out what code you're executing from the description you gave. Could you edit your question and add a minimal reproducible example that shows the issue you're having, so that we can better answer your question? Your question would also be clearer with properly formatted code; refer to Creating and highlighting code blocks for information on that. Thanks! |
Beta Was this translation helpful? Give feedback.
-
Sure! Thank you!!! Here is a minimal example: import jax
from jax.config import config
config.update("jax_enable_x64", True)
import flax.linen as nn
import jax.numpy as jnp
import flax
from typing import Union
from functools import partial
def cplx_variance_scaling(rng, shape, dtype, scale=1.0):
    """Complex-valued variance-scaling initializer.

    Draws magnitude and phase independently: the magnitude is uniform in
    [0, 1) normalized by sqrt(fan_in + fan_out) (times the product of any
    leading axes), and the phase is uniform in [0, ~pi).

    Args:
        rng: a ``jax.random`` PRNG key; split into one key for the
            magnitude and one for the phase.
        shape: output shape; the last two axes act as (fan_in, fan_out)
            and any leading axes multiply into the normalization.
        dtype: requested dtype.  NOTE(review): the original always samples
            in float64 regardless of this argument — kept as-is, confirm
            whether ``dtype`` should be honoured here.
        scale: overall multiplicative factor on the returned weights.
            Defaults to 1.0 (original behaviour).  BUG FIX: callers build
            this initializer via ``partial(cplx_variance_scaling,
            scale=...)``, which previously raised a TypeError because the
            parameter did not exist.

    Returns:
        A complex array of shape ``shape``.
    """
    rng1, rng2 = jax.random.split(rng)
    unif = jax.nn.initializers.uniform(scale=1.)
    # Product of all leading ("batch") axes enters the fan normalization.
    elems = 1
    for k in shape[:-2]:
        elems *= k
    w = jax.numpy.sqrt((shape[-1] + shape[-2]) * elems)
    # 3.141593 approximates pi; kept verbatim to preserve exact numerics.
    return (scale / w) * unif(rng1, shape, dtype=jnp.float64) * jax.numpy.exp(1.j * 3.141593 * unif(rng2, shape, dtype=jnp.float64))
def _flax_version_tuple(version):
    """Parse a dotted version string into a tuple of ints for comparison.

    Non-numeric suffixes inside a component (e.g. "0rc1") are truncated at
    the first non-digit character; empty components count as 0.
    """
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if ch.isdigit():
                digits += ch
            else:
                break
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


def init_fn_args(dtype=None, **kwargs):
    """Build keyword arguments configuring a flax layer's initializers.

    Keyword arguments whose names end in ``"init"`` (e.g. ``kernel_init``)
    are forwarded: on flax >= 0.4.0 they are passed through unchanged and
    ``param_dtype`` is set; on older flax the dtype is baked into each
    initializer via ``functools.partial``.

    Args:
        dtype: parameter dtype, or None.  NOTE(review): when ``dtype`` is
            None, the ``*_init`` kwargs are silently dropped and an empty
            dict is returned — confirm this is intended.
        **kwargs: candidate initializer kwargs; only keys ending in
            ``init`` are kept.

    Returns:
        Dict of keyword arguments suitable for a flax layer constructor.
    """
    init_args = {}
    if dtype is not None:
        init_args["dtype"] = dtype
        # BUG FIX: the original compared version *strings* lexically
        # ("0.10.0" < "0.4.0"), misclassifying newer flax releases.
        if _flax_version_tuple(flax.__version__) >= (0, 4, 0):
            init_args["param_dtype"] = dtype
            for k in kwargs.keys():
                if k[-4:] == "init":
                    init_args[k] = kwargs[k]
        else:
            for k in kwargs.keys():
                if k[-4:] == "init":
                    init_args[k] = partial(kwargs[k], dtype=dtype)
    return init_args
class RNN1DGeneral(nn.Module):
    """Autoregressive 1D RNN that scans a cell stack over an input sequence.

    Attributes (flax dataclass fields):
        L: sequence length.
        depth: number of stacked cells.
        inputDim: size of the local (one-hot) input space.
        hiddenSize: hidden-state size of each cell.
        actFun: nonlinearity used inside the cell stack.
        initScale: scale for the complex initializer when
            ``realValuedParams`` is False.
        logProbFactor: unused in the visible code — TODO confirm.
        realValuedOutput: unused in the visible code — TODO confirm.
        realValuedParams: selects real (float64) vs complex (complex128)
            parameters.
        cell: a cell name ("NADE") or a ``[cells, zero_carry]`` pair.
        initDist, initialized: unused in the visible code — TODO confirm.
        MConstrains: per-site constraint array threaded through the scan.
    """
    L: int = 10
    depth: int = 1
    inputDim: int = 2
    hiddenSize: int = 10
    actFun: callable = nn.elu
    initScale: float = 1.0
    logProbFactor: float = 1
    realValuedOutput: bool = False
    realValuedParams: bool = True
    cell: Union[str, list] = "NADE"
    initDist: str = None
    initialized: bool = True
    MConstrains: jnp.ndarray = jnp.array([2, 2, 21, 21])

    def setup(self):
        # BUG FIX: the original built ValueError instances without raising
        # them, making both checks silent no-ops.  The first check is also
        # narrowed to LSTM/GRU with complex parameters, since the original
        # condition (any string cell != "RNN") would have rejected the
        # supported "NADE" cell with an LSTM/GRU-specific message.
        if isinstance(self.cell, str) and self.cell in ("LSTM", "GRU") and not self.realValuedParams:
            raise ValueError("Complex parameters for LSTM/GRU not yet implemented.")
        if self.realValuedParams:
            self.dtype = jnp.float64
            self.initFunction = jax.nn.initializers.normal(stddev=1)
        else:
            self.dtype = jnp.complex128
            self.initFunction = partial(cplx_variance_scaling, scale=self.initScale)
        if isinstance(self.cell, str):
            self.zero_carry = jnp.zeros((self.depth, 1, self.hiddenSize), dtype=self.dtype)
            if self.cell == "NADE":
                self.cells = [NADECell() for _ in range(self.depth)]
                # NADE carries an integer site index per layer, not a
                # hidden-state vector.
                self.zero_carry = jnp.zeros((self.depth, 1), dtype=jnp.int32)
            else:
                # BUG FIX: was a bare ``ValueError(...)`` expression.
                # NOTE(review): "RNN" has no construction path here either,
                # so it now raises explicitly instead of failing later with
                # an AttributeError on ``self.cells``.
                raise ValueError("Cell name not recognized.")
        else:
            self.cells = self.cell[0]
            self.zero_carry = self.cell[1]
        self.rnnCell = RNNCellStack(self.cells, actFun=self.actFun)

    def __call__(self, x):
        # Scan the cell over the one-hot encoded input sequence.  Carry is
        # (cell state, previous input, constraint array).
        _, probs = self.rnn_cell((self.zero_carry, jnp.zeros(self.inputDim)[None, :], self.MConstrains),
                                 (jax.nn.one_hot(x, self.inputDim)))
        return jnp.sum(probs)

    @partial(nn.transforms.scan,
             variable_broadcast='params',
             split_rngs={'params': False})
    def rnn_cell(self, carry, x):
        # x1, MConstrain = x
        if self.cell == "NADE":
            newCarry, logProb = self.rnnCell(carry[0], carry[1])
        else:
            newCarry, out = self.rnnCell(carry[0], carry[1])
            # NOTE(review): log_coeffs_to_log_probs / outputDense are not
            # defined in the visible code — confirm they exist upstream.
            logProb = self.log_coeffs_to_log_probs(self.outputDense(out), carry[2])
            logProb = jnp.sum(logProb * x, axis=-1)
        # BUG FIX: scan requires the carry to keep an identical pytree
        # structure; the original returned a 2-tuple for a 3-tuple input
        # carry.  The constraint array is threaded through unchanged.
        return (newCarry, x, carry[2]), jnp.nan_to_num(logProb, nan=-35)
class RNNCellStack(nn.Module):
    """
    Implementation of a stack of RNN-cells which is scanned over an input sequence.
    This is achieved by stacking multiple 'vanilla' RNN-cells to obtain a deep RNN.
    Arguments:
        * ``cells``: list of cell modules, one per layer
        * ``dtype``: parameter dtype passed to the input dense layer
        * ``actFun``: non-linear activation function
        * ``initFun``: kernel initializer for the input dense layer
    Returns:
        New set of hidden states (one for each layer), as well as the last hidden state, that serves as input to the output layer
    """
    cells: list
    dtype: type = jnp.float64
    actFun: callable = nn.elu
    initFun: callable = jax.nn.initializers.variance_scaling(scale=0.1, mode="fan_avg", distribution="normal")
    # initFun: callable = jax.nn.initializers.normal(stddev=0.01)

    @ nn.compact
    def __call__(self, carry, newR):
        # carry holds the stacked per-layer hidden states; newCarry
        # accumulates this step's updated states with matching shape/dtype.
        newCarry = jnp.zeros_like(carry)
        # Bias-free projection of the raw input into the hidden dimension
        # (carry.shape[-1]) before it enters the first cell.
        newR = nn.Dense(features=carry.shape[-1], use_bias=False,
                        **init_fn_args(kernel_init=self.initFun, dtype=self.dtype),
                        name="data_in_dense")(newR)
        newR = self.actFun(newR)
        # Layer j consumes its own previous hidden state and the output of
        # layer j-1; the final newR is the last layer's output.
        for j, (c, cell) in enumerate(zip(carry, self.cells)):
            current_carry, newR = cell(c, newR)
            newCarry = newCarry.at[j].set(current_carry)
        return newCarry, newR
class NADECell(nn.Module):
# actFun: callable = nn.elu
actFun: callable = nn.sigmoid
dtype: type = jnp.float32
L: int = 10
depth: int = 1
inputDim: int = 2
hiddenSize: int = 10
def setup(self):
self.W = self.param('W', jax.nn.initializers.normal(), (self.hiddenSize, self.L), self.dtype)
self.c = self.param('c', jax.nn.initializers.zeros, (self.hiddenSize,), self.dtype)
self.V = self.param('V', jax.nn.initializers.normal(), (self.L, self.inputDim, self.hiddenSize), self.dtype)
self.b = self.param('b', jax.nn.initializers.zeros, (self.L, self.inputDim,), self.dtype)
def select_logprobs(self,i,logprobs):
cases = [lambda: logprobs[j, :] for j in range(self.L)]
return jax.lax.switch(i, cases)
@nn.compact
def __call__(self, i, logprobs):
probs = jnp.exp(jax.jit(self.select_logprobs(i,logprobs),static_argnames=0))
states = jax.random.categorical(self.make_rng('params'), probs, axis=-1).reshape()
h_i = self.actFun(self.c + jnp.einsum("hi,bi->bh", self.W[:, :i], states))
logits = jax.nn.log_softmax(self.b[i,:] + jnp.einsum("oh,bh->bo", self.V[i,:], h_i))
ids = logits[:, i].astype(jnp.int64)
log_prob = jnp.take_along_axis(logits, ids[:, None], axis=1).squeeze(1)
logprobs = jnp.concatenate([logprobs, log_prob[None, :]], axis=0)
newCarry = (i + 1, logprobs)
return newCarry, log_prob The example code includes several functions, among which |
Beta Was this translation helpful? Give feedback.
Thanks - the error I get when running your code is this:
This indicates that you're passing a length-1 array in a place where a scalar is expected. I fixed this by changing this:
to this:
Running again, I get this error:
This comes because you're passing an array to JIT rather than a function. To fix this I changed this: