From 346ee3efacc147e3ad0faa56caf1de1635dac266 Mon Sep 17 00:00:00 2001 From: Sebastian Ehlert Date: Mon, 16 Feb 2026 12:48:18 +0100 Subject: [PATCH] Add support for torch::jit::Method --- CMakeLists.txt | 5 +- src/ctorch.cpp | 85 ++++++++++++++++++++++---- src/ctorch.h | 33 ++++++++++ src/ftorch_method.f90 | 137 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 246 insertions(+), 14 deletions(-) create mode 100644 src/ftorch_method.f90 diff --git a/CMakeLists.txt b/CMakeLists.txt index dfa6a9e19..9a48d5860 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -113,7 +113,8 @@ endif() # Library with C and Fortran bindings add_library(${LIB_NAME} src/ctorch.cpp src/ftorch.f90 src/ftorch_devices.F90 src/ftorch_types.f90 src/ftorch_tensor.f90 - src/ftorch_model.f90 src/ftorch_test_utils.f90) + src/ftorch_model.f90 src/ftorch_method.f90 + src/ftorch_test_utils.f90) # Define compile definitions, including GPU devices target_compile_definitions( @@ -187,6 +188,8 @@ install(FILES "${CMAKE_Fortran_MODULE_DIRECTORY}/ftorch_tensor.mod" DESTINATION "${CMAKE_INSTALL_MODULEDIR}") install(FILES "${CMAKE_Fortran_MODULE_DIRECTORY}/ftorch_model.mod" DESTINATION "${CMAKE_INSTALL_MODULEDIR}") +install(FILES "${CMAKE_Fortran_MODULE_DIRECTORY}/ftorch_method.mod" + DESTINATION "${CMAKE_INSTALL_MODULEDIR}") install(FILES "${CMAKE_Fortran_MODULE_DIRECTORY}/ftorch_test_utils.mod" DESTINATION "${CMAKE_INSTALL_MODULEDIR}") diff --git a/src/ctorch.cpp b/src/ctorch.cpp index 723cf0e01..5e18c2b64 100644 --- a/src/ctorch.cpp +++ b/src/ctorch.cpp @@ -550,6 +550,25 @@ void set_is_training(torch_jit_script_module_t module, const bool is_training = } } +std::vector tensors_to_ivalue_vec(const torch_tensor_t *tensors, const int ntensors) { + // Local IValue for checking we are passed types + torch::jit::IValue LocalTensor; + // Generate a vector of IValues (placeholders for various Torch types) + std::vector ivalue_vec; + ivalue_vec.reserve(ntensors); + // Populate with Tensors pointed at by pointers + // For each IValue check it is of Tensor type + for (int i = 0; i < ntensors; ++i) { + LocalTensor = *(tensors[i]); + if (LocalTensor.isTensor()) { + ivalue_vec.push_back(LocalTensor); + } else { + ctorch_error("One of the inputs to tensors_to_ivalue_vec is not a Tensor"); + } + } + return ivalue_vec; +} + torch_jit_script_module_t torch_jit_load(const char *filename, const torch_device_t device_type = torch_kCPU, const int device_index = -1, @@ -571,6 +590,20 @@ torch_jit_script_module_t torch_jit_load(const char *filename, return module; } +torch_jit_method_t torch_jit_get_method(const torch_jit_script_module_t module, const char *method_name) { + torch::jit::Method *method = nullptr; + auto m = static_cast(module); + try { + method = new torch::jit::Method(m->get_method(method_name)); + } catch (const torch::Error &e) { + ctorch_error(e.msg(), [&]() { delete method; }); + } catch (const std::exception &e) { + ctorch_error(e.what(), [&]() { delete method; }); + } + + return method; +} + void torch_jit_module_forward(const torch_jit_script_module_t module, const torch_tensor_t *inputs, const int nin, torch_tensor_t *outputs, const int nout, @@ -580,20 +613,8 @@ void torch_jit_module_forward(const torch_jit_script_module_t module, auto model = static_cast(module); auto in = reinterpret_cast(inputs); auto out = reinterpret_cast(outputs); - // Local IValue for checking we are passed types - torch::jit::IValue LocalTensor; // Generate a vector of IValues (placeholders for various Torch types) - std::vector inputs_vec; - // Populate with Tensors pointed at by pointers - // For each IValue check it is of Tensor type - for (int i = 0; i < nin; ++i) { - LocalTensor = *(in[i]); - if (LocalTensor.isTensor()) { - inputs_vec.push_back(LocalTensor); - } else { - ctorch_error("One of the inputs to torch_jit_module_forward is not a Tensor"); - } - } + auto inputs_vec = tensors_to_ivalue_vec(inputs, nin); try { auto model_out = model->forward(inputs_vec); if (model_out.isTensor()) { @@ -616,6 +637,39 @@ void torch_jit_module_forward(const torch_jit_script_module_t module, } } +void torch_jit_method_call(const torch_jit_method_t method, + const torch_tensor_t *inputs, const int nin, + torch_tensor_t *outputs, const int nout, + const bool requires_grad = false) { + torch::AutoGradMode enable_grad(requires_grad); + // Here we cast the pointers we recieved in to Tensor objects + auto method = static_cast(method); + auto in = reinterpret_cast(inputs); + auto out = reinterpret_cast(outputs); + // Generate a vector of IValues (placeholders for various Torch types) + auto inputs_vec = tensors_to_ivalue_vec(inputs, nin); + try { + auto method_out = method(inputs_vec); + if (method_out.isTensor()) { + // Single output models will return a tensor directly. + std::move(*out[0]) = method_out.toTensor(); + } else if (method_out.isTuple()) { + // Multiple output models will return a tuple => cast to tensors. + for (int i = 0; i < nout; ++i) { + std::move(*out[i]) = method_out.toTuple()->elements()[i].toTensor(); + } + } else { + // If for some reason the forward method does not return a Tensor it + // should raise an error when trying to cast to a Tensor type + ctorch_error("Method Output is neither Tensor nor Tuple"); + } + } catch (const torch::Error &e) { + ctorch_error(e.msg()); + } catch (const std::exception &e) { + ctorch_error(e.what()); + } +} + void torch_jit_module_print_parameters(const torch_jit_script_module_t module) { auto m = reinterpret_cast(module); for (const auto &[key, value] : m->named_parameters()) { @@ -632,3 +686,8 @@ void torch_jit_module_delete(torch_jit_script_module_t module) { auto m = reinterpret_cast(module); delete m; } + +void torch_jit_method_delete(torch_jit_method_t method) { + auto m = reinterpret_cast(method); + delete m; +} \ No newline at end of file diff --git a/src/ctorch.h b/src/ctorch.h index c60e5b2ea..8ecf62b64 100644 --- a/src/ctorch.h +++ b/src/ctorch.h @@ -16,6 +16,9 @@ // Opaque pointer type alias for torch::jit::script::Module class typedef void *torch_jit_script_module_t; +// Opaque pointer type alias for torch::jit::Method class +typedef void *torch_jit_method_t; + // Opaque pointer type alias for at::Tensor typedef void *torch_tensor_t; @@ -347,6 +350,16 @@ EXPORT_C torch_jit_script_module_t torch_jit_load(const char *filename, const bool requires_grad, const bool is_training); +/** + * Function to load in a Torch model from a TorchScript file and store in a + * Torch Module + * @param module containing the model to get the method from + * @param method_name name of the method to get from the module + * @return Torch Method loaded from module + */ +EXPORT_C torch_jit_method_t torch_jit_get_method(const torch_jit_script_module_t module, + const char *method_name); + /** * Function to run the `forward` method of a Torch Module * @param Torch Module containing the model @@ -361,6 +374,20 @@ EXPORT_C void torch_jit_module_forward(const torch_jit_script_module_t module, torch_tensor_t *outputs, const int nout, const bool requires_grad); +/** + * Function to call a Torch Method + * @param method containing the model method to be called + * @param vector of Torch Tensors as inputs to the method + * @param number of input Tensors in the input vector + * @param vector of Torch Tensors as outputs from running the method + * @param number of output Tensors in the output vector + * @param whether gradient is required + */ +EXPORT_C void torch_jit_method_call(const torch_jit_script_module_t module, + const torch_tensor_t *inputs, const int nin, + torch_tensor_t *outputs, const int nout, + const bool requires_grad); + /** * Function to print out the parameters of a Torch Module * @@ -381,4 +408,10 @@ EXPORT_C bool torch_jit_module_is_training(const torch_jit_script_module_t modul */ EXPORT_C void torch_jit_module_delete(torch_jit_script_module_t module); +/** + * Function to delete a Torch Method to clean up + * @param Torch Method to delete + */ +EXPORT_C void torch_jit_method_delete(torch_jit_method_t method); + #endif /* C_TORCH_H*/ diff --git a/src/ftorch_method.f90 b/src/ftorch_method.f90 new file mode 100644 index 000000000..eba580eca --- /dev/null +++ b/src/ftorch_method.f90 @@ -0,0 +1,137 @@ +!| Module for the FTorch `torch_method` type and associated procedures. +! +! * License +! FTorch is released under an MIT license. +! See the [LICENSE](https://github.com/Cambridge-ICCS/FTorch/blob/main/LICENSE) +! file for details. +module ftorch_method + use, intrinsic :: iso_c_binding, only : c_null_ptr, c_ptr + use ftorch_devices, only: torch_kCPU, torch_kCUDA, torch_kHIP, torch_kXPU, torch_kMPS + use ftorch_model, only: torch_model + use ftorch_types, only: ftorch_int + use ftorch_tensor, only: torch_tensor + + implicit none + + public + + !> Type for holding a Torch Method from a TorchScript model. + type torch_method + type(c_ptr) :: p = c_null_ptr !! pointer to the method in memory + contains + final :: torch_method_delete + end type torch_method + +contains + + ! ============================================================================ + ! --- Procedures for getting methods from a model + ! ============================================================================ + + !> Loads a Torch Method from a TorchScript nn.module (pre-trained PyTorch model saved with TorchScript) + subroutine torch_get_method(method, model, methodname) + use, intrinsic :: iso_c_binding, only : c_bool, c_int, c_null_char + type(torch_method), intent(out) :: method !! Returned deserialized method + type(torch_model), intent(in) :: model !! Model to associate with the method + character(*), intent(in) :: methodname !! Name of the method + integer(c_int) :: device_index_value + + interface + function torch_jit_get_method_c(model_c, methodname_c) result(method_c) & + bind(c, name = 'torch_jit_get_method') + use, intrinsic :: iso_c_binding, only : c_bool, c_char, c_int, c_ptr + implicit none + type(c_ptr), value, intent(in) :: model_c + character(c_char), intent(in) :: methodname_c(*) + type(c_ptr) :: method_c + end function torch_jit_get_method_c + end interface + + ! Need to append c_null_char at end of methodname + method%p = torch_jit_get_method_c(model%p, trim(adjustl(methodname))//c_null_char) + end subroutine torch_get_method + + ! ============================================================================ + ! --- Procedures for performing inference + ! ============================================================================ + + !> Performs a forward pass of the method with the input tensors + subroutine torch_method_call(method, input_tensors, output_tensors, requires_grad) + use, intrinsic :: iso_c_binding, only : c_bool, c_ptr, c_int, c_loc + type(torch_method), intent(in) :: method !! Method + type(torch_tensor), intent(in), dimension(:) :: input_tensors !! Array of Input tensors + type(torch_tensor), intent(in), dimension(:) :: output_tensors !! Returned output tensors + logical, optional, intent(in) :: requires_grad !! Whether gradients need to be computed for the created tensor + logical :: requires_grad_value !! Whether gradients need to be computed for the created tensor + + integer(ftorch_int) :: i + integer(c_int) :: n_inputs + integer(c_int) :: n_outputs + type(c_ptr), dimension(size(input_tensors)), target :: input_ptrs + type(c_ptr), dimension(size(output_tensors)), target :: output_ptrs + + interface + subroutine torch_jit_method_call_c(method_c, input_tensors_c, n_inputs_c, & + output_tensors_c, n_outputs_c, requires_grad_c) & + bind(c, name = 'torch_jit_method_call') + use, intrinsic :: iso_c_binding, only : c_bool, c_ptr, c_int + implicit none + type(c_ptr), value, intent(in) :: method_c + type(c_ptr), value, intent(in) :: input_tensors_c + integer(c_int), value, intent(in) :: n_inputs_c + type(c_ptr), value, intent(in) :: output_tensors_c + integer(c_int), value, intent(in) :: n_outputs_c + logical(c_bool), value, intent(in) :: requires_grad_c + end subroutine torch_jit_method_call_c + end interface + + n_inputs = size(input_tensors) + n_outputs = size(output_tensors) + + if (.not. present(requires_grad)) then + requires_grad_value = .false. + else + requires_grad_value = requires_grad + end if + + ! Assign array of pointers to the input tensors + do i = 1, n_inputs + input_ptrs(i) = input_tensors(i)%p + end do + + ! Assign array of pointers to the output tensors + do i = 1, n_outputs + output_ptrs(i) = output_tensors(i)%p + end do + + call torch_jit_method_call_c(method%p, c_loc(input_ptrs), n_inputs, & + c_loc(output_ptrs), n_outputs, & + logical(requires_grad_value, c_bool)) + end subroutine torch_method_call + + ! ============================================================================ + ! --- Procedures for deallocating methods + ! ============================================================================ + + !> Deallocates a TorchScript method + subroutine torch_method_delete(method) + use, intrinsic :: iso_c_binding, only : c_associated, c_null_ptr + type(torch_method), intent(inout) :: method !! Torch Method to deallocate + + interface + subroutine torch_jit_method_delete_c(method_c) & + bind(c, name = 'torch_jit_method_delete') + use, intrinsic :: iso_c_binding, only : c_ptr + implicit none + type(c_ptr), value, intent(in) :: method_c + end subroutine torch_jit_method_delete_c + end interface + + ! Call the destructor, if it hasn't already been called + if (c_associated(method%p)) then + call torch_jit_method_delete_c(method%p) + method%p = c_null_ptr + end if + end subroutine torch_method_delete + +end module ftorch_method