Skip to content

Commit

Permalink
Improve test case for test_vmap_fix
Browse files Browse the repository at this point in the history
  • Loading branch information
BillHuang2001 committed Jan 6, 2025
1 parent 8d415de commit 3e1d31e
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions unit_test/core/test_vmap_fix.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ def test_distance_fn_with_mask(self):
self.assertIsNotNone(distances(self.costs[:-1], self.mask[:-1]))

def test_distance_fn_without_mask(self):
distances = jit(
distance_fn, trace=True, lazy=False, example_inputs=(self.costs,)
)
self.assertIsNotNone(distances(self.costs))

def test_distance_fn_with_none(self):
distances = jit(
distance_fn, trace=True, lazy=False, example_inputs=(self.costs, None)
)
Expand Down

0 comments on commit 3e1d31e

Please sign in to comment.