Skip to content

Commit

Permalink
remove flattening and unflattening
Browse files Browse the repository at this point in the history
of tensor subclasses from `nondifferentiable_vjp_symbols`
since the trace transform of tensor subclasses comes after VJP

Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
  • Loading branch information
crcrpar committed Dec 23, 2024
1 parent e5d26fc commit 56c69df
Showing 1 changed file with 0 additions and 2 deletions.
2 changes: 0 additions & 2 deletions thunder/core/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2509,8 +2509,6 @@ def uniform_backward(primal, minval, maxval, g):
prims.PrimIDs.BITWISE_XOR,
prims.PrimIDs.SIGNBIT,
prims.PrimIDs.FULL,
prims.PrimIDs.FLATTEN_TENSOR_SUBCLASS,
prims.PrimIDs.UNFLATTEN_TENSOR_SUBCLASS,
}


Expand Down

0 comments on commit 56c69df

Please sign in to comment.