Skip to content

Commit

Permalink
fix: tf backend get_item for tensor queries
Browse files Browse the repository at this point in the history
  • Loading branch information
Sam-Armstrong committed Sep 15, 2024
1 parent 8183a22 commit 25ff001
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion ivy/functional/backends/tensorflow/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,15 @@ def get_item(
) -> Union[tf.Tensor, tf.Variable]:
if ivy.is_array(query) and ivy.is_bool_dtype(query) and not len(query.shape):
return tf.expand_dims(x, 0)
return x[query]
if isinstance(query, tf.Tensor):
if query.dtype == tf.bool:
return tf.boolean_mask(x, query, axis=0)
else:
query = tf.cast(query, tf.int64)
return tf.gather(x, query, axis=0)
else:
# for slices and other basic indexing, use __getitem__
return x[query]


def to_numpy(x: Union[tf.Tensor, tf.Variable], /, *, copy: bool = True) -> np.ndarray:
Expand Down

0 comments on commit 25ff001

Please sign in to comment.