Questions on Using nnx.value_and_grad for Loss Calculation and Model Decoupling in Flax NNX #4476

Open · Tomato-toast opened this issue Jan 10, 2025 · 4 comments


@Tomato-toast

Hello everyone.

While implementing the A2C reinforcement learning algorithm using Flax NNX, I encountered some challenges and would appreciate your guidance. Below is a simplified version of my code:

  # simplified excerpt; assumes `import jax`, `import jax.numpy as jnp`, and `from flax import nnx`
  def run_epoch(self, observation):

      def data(policy_network, observation):
          data = policy_network(observation)
          return data

      def compute_value(value_network, data):
          value = value_network(data.observation)
          return value

      def compute_advantage(value):
          advantage = ...  # derived from a series of computations based on value
          return advantage

      def compute_policy_loss(policy_network, data, advantage):
          policy_loss = -jnp.mean(jax.lax.stop_gradient(advantage) * data.log_prob)
          return policy_loss

      def compute_critic_loss(value_network, advantage):
          critic_loss = jnp.mean(advantage**2)
          return critic_loss

      data = data(self.policy_network, observation)
      value = compute_value(self.value_network, data)
      advantage = compute_advantage(value)

      policy_loss, policy_grad = nnx.value_and_grad(
          compute_policy_loss, has_aux=False
      )(self.policy_network, data, advantage)

      value_loss, value_grad = nnx.value_and_grad(
          compute_critic_loss, has_aux=False
      )(self.value_network, advantage)

      return policy_grad, value_grad

Background

  • Both policy_network and value_network are complex models based on Transformer modules, and their forward pass involves multi-layer computational logic. These networks are implemented as nnx.Module.
  • Since both loss functions require advantage, which depends on the forward pass of both policy_network and value_network, I attempted to extract the forward pass as a separate step. I then passed only the results of the forward pass to compute_policy_loss and compute_critic_loss.
  • However, during execution, if the models are not passed directly to the loss functions, the code raises an error. All the examples in the documentation seem to require passing the models directly to the loss function used with nnx.value_and_grad. As a workaround, I passed policy_network and value_network into compute_policy_loss and compute_critic_loss, and the code worked, even though the two nnx.Module instances are not explicitly used within these functions.

Questions

  • Is it mandatory to pass the model to the loss function for proper gradient computation? Is it possible to use only the results of the forward pass in the loss calculation, without passing the entire model?
  • Can the forward pass of the model be decoupled from the loss calculation? Specifically, can the forward pass be extracted as a separate step without affecting the proper functioning of nnx.value_and_grad?
@cgarciae
Collaborator

Hi @Tomato-toast.

Is it mandatory to pass the model to the loss function for proper gradient computation? Is it possible to use only the results of the forward pass in the loss calculation, without passing the entire model?

Modules can be passed as captures as long as no updates/mutations are performed inside. If any Module or Variable passed as a capture is updated inside a transform, you will get an error; this guards against tracer leakage (and correctness issues in general).
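
For illustration, a minimal sketch (the names frozen_encoder and head are made up for this example) of a Module that is only read as a capture inside the differentiated function, while gradients are taken with respect to another Module passed as an argument:

  from flax import nnx
  import jax.numpy as jnp

  frozen_encoder = nnx.Linear(4, 8, rngs=nnx.Rngs(0))  # captured by the closure below
  head = nnx.Linear(8, 1, rngs=nnx.Rngs(1))            # Module being differentiated
  x = jnp.ones((2, 4))

  def loss_fn(head):
      feats = frozen_encoder(x)          # capture: fine, since it is not updated here
      return jnp.mean(head(feats) ** 2)

  grads = nnx.grad(loss_fn)(head)        # gradients only w.r.t. `head`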

Can the forward pass of the model be decoupled from the loss calculation? Specifically, can the forward pass be extracted as a separate step without affecting the proper functioning of nnx.value_and_grad?

I'm not sure I fully understand the question here. Do you mean you want to cache the forward-pass value and pass it as an input to the losses, or do you want to return the forward-pass value from the losses?

BTW: in the real code, do compute_policy_loss and compute_critic_loss actually use the input networks?

@Tomato-toast
Author

Hello @cgarciae, thank you so much for your response, but I still have some questions.

Question 1:

When using nnx.value_and_grad, I noticed that the first argument of the loss function must be an nnx.Module instance; otherwise, an error occurs. For example:

  policy_loss, policy_grad = nnx.value_and_grad(
      compute_policy_loss, has_aux=False
  )(data, advantage)

This throws the following error:

  ValueError: Expected named tuple, got State({
      'policy_heard': {...
      },
      'torso': {...
      }
  })

However, if I write it as:

  policy_loss, policy_grad = nnx.value_and_grad(
      compute_policy_loss, has_aux=False
  )(self.policy_network, data, advantage)

It works fine, even though self.policy_network is not actually used in the compute_policy_loss function. I’m unsure why this happens.

Question 2:

I would like to know whether caching intermediate values from the forward pass and passing them as inputs to the loss function would interfere with the gradient computation in nnx.value_and_grad.

BTW: in my real code:

policy_network is used to compute data:

  def data(policy_network, observation):
      data = policy_network(observation)
      return data

value_network is used to compute value:

  def compute_value(value_network, data):
      value = value_network(data.observation)
      return value

During the computation of policy_loss and critic_loss, self.policy_network and self.value_network are passed in but are not directly utilized in the function bodies:

  def compute_policy_loss(policy_network, data, advantage):
      policy_loss = -jnp.mean(jax.lax.stop_gradient(advantage) * data.log_prob)
      return policy_loss
  
  def compute_critic_loss(value_network, advantage):
      critic_loss = jnp.mean(advantage**2)
      return critic_loss

I’d greatly appreciate your clarification. Thank you again for your help!

@cgarciae
Collaborator

cgarciae commented Jan 20, 2025

Hi @Tomato-toast.

Question 1: this might be a bug; you should be able to get a gradient for non-Modules. I'll look into it. Can you post a minimal repro here?

Question 2: yes, if you pass a cached intermediate from one loss function into a second loss function to avoid recomputation, you will get a different gradient (e.g. lots of zeros).
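
A minimal sketch of that effect (toy model, hypothetical names): when the forward pass is cached outside the differentiated function, the loss no longer depends on the model's parameters, so every gradient leaf comes out zero:

  from flax import nnx
  import jax.numpy as jnp

  model = nnx.Linear(4, 1, rngs=nnx.Rngs(0))
  x = jnp.ones((2, 4))

  preds = model(x)  # cached forward pass, computed outside value_and_grad

  def loss_fn(model, preds):
      # `model` is passed in but never called here
      return jnp.mean(preds ** 2)

  loss, grads = nnx.value_and_grad(loss_fn)(model, preds)
  # every leaf in `grads` is zero, because nothing inside loss_fn uses `model`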

@Tomato-toast
Author

Hi @cgarciae, thank you very much for your answer; it helped me a lot.
Here is the minimal repro:

  from flax import nnx
  import jax
  import optax
  import jax.numpy as jnp
  import matplotlib.pyplot as plt
  
  class Block(nnx.Module):
    def __init__(self, in_features: int, out_features: int, rngs: nnx.Rngs):
      self.linear = nnx.Linear(in_features, out_features, rngs=rngs)
      self.dropout = nnx.Dropout(0.5, rngs=rngs)
  
    def __call__(self, x):
      x = self.linear(x)
      x = self.dropout(x)
      x = jax.nn.relu(x)
      return x
  
  class Model(nnx.Module):
    def __init__(self, din: int, dmid: int, dout: int, rngs: nnx.Rngs):
      self.block = Block(din, dmid, rngs=rngs)
      self.linear = nnx.Linear(dmid, dout, rngs=rngs)
  
    def __call__(self, x):
      x = self.block(x)
      x = self.linear(x)
      return x
  
  # test
  inputs = jnp.ones((32, 784))  
  labels = jnp.zeros(32, dtype=jnp.int32)  
  
  model = Model(784, 256, 10, rngs=nnx.Rngs(0))
  
  logits_history = []
  loss_history = []
  grads_history = []
  
  # train
  for step in range(5):
      def train_step(model, inputs, labels):
          def forward(model):
              logits = model(inputs)
              return logits

          # forward pass is cached here, outside the differentiated function
          logits = forward(model)

          def loss_fn(model, logits):
              # logits = model(inputs)  # the forward pass is NOT repeated inside loss_fn
              loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
              return loss, logits

          grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
          (loss, logits), grads = grad_fn(model, logits)
          optimizer = nnx.Optimizer(model, optax.adam(2e-4))  # note: re-created every step in this repro
          optimizer.update(grads)

          return loss, logits, grads
  
  
      loss, logits, grads = train_step(model, inputs, labels)
  
      logits_history.append(logits.mean().item())
      loss_history.append(loss.item())
      grads_norm = jax.tree_util.tree_map(lambda g: jnp.linalg.norm(g), grads)
      grads_mean_norm = jnp.mean(jnp.array([g.item() for g in jax.tree_util.tree_flatten(grads_norm)[0]]))
      grads_history.append(grads_mean_norm)
  
  
  
  
  plt.figure(figsize=(15, 5))
  
  # logits 
  plt.subplot(1, 3, 1)
  plt.plot(logits_history, label="Logits Mean")
  plt.xlabel("Steps")
  plt.ylabel("Logits Mean")
  plt.title("Logits")
  plt.legend()
  
  # loss 
  plt.subplot(1, 3, 2)
  plt.plot(loss_history, label="Loss")
  plt.xlabel("Steps")
  plt.ylabel("Loss")
  plt.title("Loss")
  plt.legend()
  
  # grads 
  plt.subplot(1, 3, 3)
  plt.plot(grads_history, label="Grad Norm Mean")
  plt.xlabel("Steps")
  plt.ylabel("Gradient Norm Mean")
  plt.title("Grads")
  plt.legend()
  
  plt.tight_layout()
  plt.savefig("output.png")
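
For contrast, a sketch (assuming the same Model, inputs, labels, and optax import as in the repro above) of the variant hinted at by the commented-out line: calling the model inside loss_fn keeps the loss connected to the parameters, so non-zero gradients flow to the model:

  def loss_fn(model):
      logits = model(inputs)  # forward pass inside the differentiated function
      loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
      return loss, logits

  (loss, logits), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model)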
