diff --git a/ivy/functional/frontends/torch/tensor.py b/ivy/functional/frontends/torch/tensor.py index 46f95c5ec0641..ff6481068e550 100644 --- a/ivy/functional/frontends/torch/tensor.py +++ b/ivy/functional/frontends/torch/tensor.py @@ -633,6 +633,16 @@ def acos_(self): self.ivy_array = self.acos().ivy_array return self + @to_ivy_arrays_and_back + def to_cpu(self): + if ( + ivy_framework.current_framework_str() == "torch" + and ivy_framework.current_device_str() != "cpu" + ): + return ivy.to_device(self, "cpu") + else: + return self + def new_tensor( self, data, diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py index ab4c97560d58c..d8e0b1c19046e 100644 --- a/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py +++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_tensor.py @@ -52,8 +52,14 @@ _quantile_helper, ) +import unittest +import torch +from unittest.mock import patch +from ivy_test import helpers +from ivy_test.helpers import CLASS_TREE, handle_frontend_methodtry: + try: - import torch + import torch except ImportError: torch = SimpleNamespace() @@ -7831,6 +7837,55 @@ def test_torch_index_select( ) +#cpu +@handle_frontend_method +class TestTorchInstanceToCPU(unittest.TestCase):( + class_tree=CLASS_TREE, + init_tree="torch.tensor", + method_name="ivy.to_device", + dtype_and_x=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + num_arrays=1, + min_value=-1e04, + max_value=1e04, + allow_inf=False, + ), + ) + def test_torch_instance_to_cpu( + self, + dtype_and_x, + frontend, + backend_fw, + frontend_method_data, + init_flags, + method_flags, + ): + input_dtype, x = dtype_and_x + with patch("ivy_framework.current_framework_str", return_value="torch"), \ + patch("ivy_framework.current_device_str", return_value="cpu"): + instance = frontend.init_all_as_kwargs_np( + input_dtypes=input_dtype, data=x[0] + ) + + result = frontend.frontend_method_data( + instance, method_name="to_cpu", input_dtypes=input_dtype + ) + + self.assertTrue(torch.all(result.data.cpu() == instance.data.cpu())) + + with patch("ivy_framework.current_framework_str", return_value="numpy"): + result = frontend.frontend_method_data( + instance, method_name="to_cpu", input_dtypes=input_dtype + ) + + + self.assertEqual(result, instance) + + +if __name__ == "__main__": + unittest.main() + + # int @handle_frontend_method( class_tree=CLASS_TREE,