Commit

Bugfix: Initialize grad_input, grad_weight to None (#68)
* initialize grad_input to None

* minor
adityaranjan authored and Sathwik Yanamaddi committed Apr 12, 2024
1 parent 5087268 commit 5faec5b
Showing 1 changed file with 4 additions and 0 deletions.
axonn/intra_layer/fully_connected.py (4 additions, 0 deletions)

@@ -107,6 +107,8 @@ def backward(ctx, grad_output):
         if dist.get_world_size(ctx.backward_all_reduce_group) > 1 or (
             not overlap_reduce_scatter
         ):
+            grad_input, grad_weight = None, None
+
             if ctx.needs_input_grad[0]:
                 grad_input = grad_output.matmul(weight)
                 handle = dist.all_reduce(
@@ -136,6 +138,8 @@ def backward(ctx, grad_output):
                 grad_weight = None  # weight gradients are not ready yet
             return grad_input, grad_weight, None, None, None, None, None, None, None
         else:
+            grad_input, grad_weight = None, None
+
             if ctx.needs_input_grad[1]:
                 grad_weight = (
                     grad_output.reshape(-1, grad_output.shape[-1])
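
The added lines guard against a Python UnboundLocalError: backward() always returns grad_input and grad_weight, but each is only assigned inside an `if ctx.needs_input_grad[...]:` branch, so any call where one of those flags is False would reference a local name that was never bound. Initializing both names to None at the top of each branch makes the "no gradient needed" case return None, which is what autograd expects for that slot. Below is a minimal, self-contained sketch of the same pattern for a plain linear layer; the class and variable names are illustrative only, not the actual AxoNN code.

    # Sketch: why grad_input/grad_weight must be pre-initialized to None.
    # In a custom autograd Function, backward() returns one gradient per
    # forward() input; inputs that do not require grad should get None.
    import torch


    class LinearSketch(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input_, weight):
            ctx.save_for_backward(input_, weight)
            return input_.matmul(weight.t())

        @staticmethod
        def backward(ctx, grad_output):
            input_, weight = ctx.saved_tensors
            # Without this line, skipping either branch below would leave the
            # corresponding name unbound and the return would raise
            # UnboundLocalError.
            grad_input, grad_weight = None, None

            if ctx.needs_input_grad[0]:
                grad_input = grad_output.matmul(weight)
            if ctx.needs_input_grad[1]:
                grad_weight = grad_output.reshape(-1, grad_output.shape[-1]).t().mm(
                    input_.reshape(-1, input_.shape[-1])
                )
            return grad_input, grad_weight


    # A frozen weight (requires_grad=False) exercises the None path:
    x = torch.randn(4, 3, requires_grad=True)
    w = torch.randn(5, 3)  # weight gradient not needed
    LinearSketch.apply(x, w).sum().backward()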
