@@ -62,8 +62,8 @@ def create_train_step_fn(
     if forward_kwargs is None:
         forward_kwargs = {}
 
-    def loss_fn(variables: Any, inp: jnp.ndarray, grad_target: jnp.ndarray, dropout_key):
-        rngs = {'dropout': dropout_key}
+    def loss_fn(variables: Any, inp: jnp.ndarray, grad_target: jnp.ndarray, dropout_key):
+        rngs = {"dropout": dropout_key}
         with te.fp8_autocast(**fp8_autocast_kwargs):
             # Forward Pass: Apply the model using current parameters and variables
             call_kwargs = {**forward_kwargs, "rngs": rngs}
@@ -97,21 +97,20 @@ def create_train_step_fn_vjp(
 
     def train_step_fn(variables: Any, inp: jnp.ndarray, grad_target: jnp.ndarray, dropout_key):
         """Compute forward pass and VJP in one step"""
-
+
         # Define forward function that closes over grad_target and dropout_key
         def forward_fn(variables: Any, inp: jnp.ndarray):
             """Pure forward function for VJP computation"""
-            rngs = {'dropout': dropout_key}
+            rngs = {"dropout": dropout_key}
             with te.fp8_autocast(**fp8_autocast_kwargs):
-                call_kwargs = {**forward_kwargs, 'rngs': rngs}
+                call_kwargs = {**forward_kwargs, "rngs": rngs}
                 return model_apply_fn(variables, inp, **call_kwargs)
-
+
         # Compute forward pass and get VJP function (w.r.t. variables and inp)
         output, vjp_fn = jax.vjp(forward_fn, variables, inp)
 
         # Compute gradients using VJP - returns gradients w.r.t. variables and inp
         var_grads, inp_grads = vjp_fn(grad_target)
-
         # Return loss value and gradients
         loss_value = jnp.vdot(output, grad_target)
         return loss_value, (var_grads, inp_grads)
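
For context, here is a minimal sketch of how the VJP-based train_step_fn above might be driven. The toy Flax module, the input and cotangent shapes, the dropout rate, and the keyword signature assumed for create_train_step_fn_vjp (model_apply_fn, fp8_autocast_kwargs, forward_kwargs) are illustrative assumptions, not taken from this PR; FP8 is disabled so the sketch does not require FP8-capable hardware.

import jax
import jax.numpy as jnp
import flax.linen as nn

# Hypothetical toy model; any Flax module that consumes a "dropout" RNG works here.
class ToyMLP(nn.Module):
    features: int = 16

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.features)(x)
        x = nn.Dropout(rate=0.1, deterministic=False)(x)
        return nn.Dense(self.features)(x)

model = ToyMLP()
init_key, dropout_key = jax.random.split(jax.random.PRNGKey(0))

inp = jnp.ones((4, 16), dtype=jnp.float32)
variables = model.init({"params": init_key, "dropout": dropout_key}, inp)

# Assumed factory signature; the diff only shows the inner train_step_fn.
train_step_fn = create_train_step_fn_vjp(
    model_apply_fn=model.apply,
    fp8_autocast_kwargs={"enabled": False},  # keep FP8 off in this sketch
    forward_kwargs={},
)

# grad_target is the cotangent that jax.vjp pulls back through the model.
grad_target = jnp.ones((4, 16), dtype=jnp.float32)
loss_value, (var_grads, inp_grads) = train_step_fn(variables, inp, grad_target, dropout_key)
print(loss_value, inp_grads.shape)

Unlike the loss_fn path in create_train_step_fn, which differentiates a scalar loss, the VJP variant lets the caller supply an arbitrary grad_target cotangent directly; the returned loss_value = jnp.vdot(output, grad_target) is just the matching scalar summary of that contraction.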