Skip to content

Commit 2796e91

Browse files
committed
merging greptile changes again
2 parents 11d1bfb + 3ea1c2f commit 2796e91

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

docs/examples/quickstart_jax_utils.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ def create_train_step_fn(
6262
if forward_kwargs is None:
6363
forward_kwargs = {}
6464

65-
def loss_fn(variables : Any, inp: jnp.ndarray, grad_target: jnp.ndarray, dropout_key):
66-
rngs = {'dropout': dropout_key}
65+
def loss_fn(variables: Any, inp: jnp.ndarray, grad_target: jnp.ndarray, dropout_key):
66+
rngs = {"dropout": dropout_key}
6767
with te.fp8_autocast(**fp8_autocast_kwargs):
6868
# Forward Pass: Apply the model using current parameters and variables
6969
call_kwargs = {**forward_kwargs, "rngs": rngs}
@@ -97,21 +97,20 @@ def create_train_step_fn_vjp(
9797

9898
def train_step_fn(variables: Any, inp: jnp.ndarray, grad_target: jnp.ndarray, dropout_key):
9999
"""Compute forward pass and VJP in one step"""
100-
100+
101101
# Define forward function that closes over grad_target and dropout_key
102102
def forward_fn(variables: Any, inp: jnp.ndarray):
103103
"""Pure forward function for VJP computation"""
104-
rngs = {'dropout': dropout_key}
104+
rngs = {"dropout": dropout_key}
105105
with te.fp8_autocast(**fp8_autocast_kwargs):
106-
call_kwargs = {**forward_kwargs, 'rngs': rngs}
106+
call_kwargs = {**forward_kwargs, "rngs": rngs}
107107
return model_apply_fn(variables, inp, **call_kwargs)
108-
108+
109109
# Compute forward pass and get VJP function (w.r.t. variables and inp)
110110
output, vjp_fn = jax.vjp(forward_fn, variables, inp)
111111

112112
# Compute gradients using VJP - returns gradients w.r.t. variables and inp
113113
var_grads, inp_grads = vjp_fn(grad_target)
114-
115114
# Return loss value and gradients
116115
loss_value = jnp.vdot(output, grad_target)
117116
return loss_value, (var_grads, inp_grads)

0 commit comments

Comments
 (0)