Skip to content

Commit

Permalink
Fix test_validate_tt_tensor test (#24167)
Browse files Browse the repository at this point in the history
  • Loading branch information
mobley-trent authored Sep 19, 2023
1 parent 7b42da8 commit 0e0816f
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions ivy_tests/test_ivy/test_misc/test_tt_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,9 @@ def test_validate_tt_rank(coef):
tensor_shape = tuple(ivy.random.randint(5, 10, shape=(4,)))
n_param_tensor = ivy.prod(tensor_shape)

# TODO: This test fails even for the native implementation
# https://github.com/tensorly/tensorly/issues/529
# rank = ivy.TTTensor.validate_tt_rank(tensor_shape, coef, rounding="floor")
# n_param = ivy.TTTensor._tt_n_param(tensor_shape, rank)
# np.testing.assert_(n_param >= n_param_tensor * coef)
rank = ivy.TTTensor.validate_tt_rank(tensor_shape, coef, rounding="floor")
n_param = ivy.TTTensor._tt_n_param(tensor_shape, rank)
np.testing.assert_(n_param <= n_param_tensor * coef)

rank = ivy.TTTensor.validate_tt_rank(tensor_shape, coef, rounding="ceil")
n_param = ivy.TTTensor._tt_n_param(tensor_shape, rank)
Expand Down

0 comments on commit 0e0816f

Please sign in to comment.