Skip to content

Commit

Permalink
test grad for normalize: increase eps to make test less flaky (#823)
Browse files Browse the repository at this point in the history
  • Loading branch information
t-vi authored Jul 22, 2024
1 parent 85df4f6 commit 09846ed
Showing 1 changed file with 13 additions and 11 deletions.
24 changes: 13 additions & 11 deletions thunder/tests/opinfos.py
Original file line number Diff line number Diff line change
Expand Up @@ -6144,23 +6144,25 @@ def tensor_constructor_error_generator(op, device, **kwargs):


def normalize_sample_generator(op, device, dtype, requires_grad, **kwargs):
make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
def make(shape):
input_tensor = make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad)
# avoid very small norm tensors, which can be unstable to normalize
input_tensor = input_tensor + 0.2 * torch.sign(input_tensor)
return input_tensor

# input shape
cases = (
shapes = (
(4, 4),
(32, 32),
(16, 16, 16),
(4, 2, 4, 5),
)
for case in cases:
input_tensor = make(case)
# avoid very small norm tensors, which can be unstable to normalize
input_tensor = input_tensor + 0.2 * torch.sign(input_tensor)
yield SampleInput(input_tensor, eps=1e-8)
yield SampleInput(input_tensor, p=0, eps=1e-8)
yield SampleInput(input_tensor, p=1, eps=1e-8)
yield SampleInput(input_tensor, p=4, eps=1e-8)
yield SampleInput(input_tensor, p=math.inf, eps=1e-8)
for shape in shapes:
yield SampleInput(make(shape), eps=1e-6)
yield SampleInput(make(shape), p=0, eps=1e-6)
yield SampleInput(make(shape), p=1, eps=1e-6)
yield SampleInput(make(shape), p=4, eps=1e-6)
yield SampleInput(make(shape), p=math.inf, eps=1e-6)


normalize_opinfo = OpInfo(
Expand Down

0 comments on commit 09846ed

Please sign in to comment.