Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ endif()
# Library with C and Fortran bindings
add_library(${LIB_NAME} src/ctorch.cpp src/ftorch.f90 src/ftorch_devices.F90
src/ftorch_types.f90 src/ftorch_tensor.f90
src/ftorch_model.f90 src/ftorch_test_utils.f90)
src/ftorch_model.f90 src/ftorch_method.f90
src/ftorch_test_utils.f90)

# Define compile definitions, including GPU devices
target_compile_definitions(
Expand Down Expand Up @@ -187,6 +188,8 @@ install(FILES "${CMAKE_Fortran_MODULE_DIRECTORY}/ftorch_tensor.mod"
DESTINATION "${CMAKE_INSTALL_MODULEDIR}")
install(FILES "${CMAKE_Fortran_MODULE_DIRECTORY}/ftorch_model.mod"
DESTINATION "${CMAKE_INSTALL_MODULEDIR}")
install(FILES "${CMAKE_Fortran_MODULE_DIRECTORY}/ftorch_method.mod"
DESTINATION "${CMAKE_INSTALL_MODULEDIR}")
install(FILES "${CMAKE_Fortran_MODULE_DIRECTORY}/ftorch_test_utils.mod"
DESTINATION "${CMAKE_INSTALL_MODULEDIR}")

Expand Down
85 changes: 72 additions & 13 deletions src/ctorch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,25 @@ void set_is_training(torch_jit_script_module_t module, const bool is_training =
}
}

std::vector<torch::jit::IValue> tensors_to_ivalue_vec(const torch_tensor_t *tensors, const int ntensors) {
// Local IValue for checking we are passed types
torch::jit::IValue LocalTensor;
// Generate a vector of IValues (placeholders for various Torch types)
std::vector<torch::jit::IValue> ivalue_vec;
ivalue_vec.reserve(ntensors);
// Populate with Tensors pointed at by pointers
// For each IValue check it is of Tensor type
for (int i = 0; i < ntensors; ++i) {
LocalTensor = *(tensors[i]);
if (LocalTensor.isTensor()) {
ivalue_vec.push_back(LocalTensor);
} else {
ctorch_error("One of the inputs to tensors_to_ivalue_vec is not a Tensor");
}
}
return ivalue_vec;
}

torch_jit_script_module_t torch_jit_load(const char *filename,
const torch_device_t device_type = torch_kCPU,
const int device_index = -1,
Expand All @@ -571,6 +590,20 @@ torch_jit_script_module_t torch_jit_load(const char *filename,
return module;
}

torch_jit_method_t torch_jit_get_method(const torch_jit_script_module_t module, const char *method_name) {
  // Look up a named method on a loaded TorchScript module and return it as an
  // opaque heap-allocated handle (freed later via torch_jit_method_delete).
  auto script_module = static_cast<torch::jit::script::Module *>(module);
  torch::jit::Method *method_ptr = nullptr;
  try {
    method_ptr = new torch::jit::Method(script_module->get_method(method_name));
  } catch (const torch::Error &e) {
    // Report the Torch error, releasing the allocation if one was made
    ctorch_error(e.msg(), [&]() { delete method_ptr; });
  } catch (const std::exception &e) {
    ctorch_error(e.what(), [&]() { delete method_ptr; });
  }

  return method_ptr;
}

void torch_jit_module_forward(const torch_jit_script_module_t module,
const torch_tensor_t *inputs, const int nin,
torch_tensor_t *outputs, const int nout,
Expand All @@ -580,20 +613,8 @@ void torch_jit_module_forward(const torch_jit_script_module_t module,
auto model = static_cast<torch::jit::script::Module *>(module);
auto in = reinterpret_cast<torch::Tensor *const *>(inputs);
auto out = reinterpret_cast<torch::Tensor **>(outputs);
// Local IValue for checking we are passed types
torch::jit::IValue LocalTensor;
// Generate a vector of IValues (placeholders for various Torch types)
std::vector<torch::jit::IValue> inputs_vec;
// Populate with Tensors pointed at by pointers
// For each IValue check it is of Tensor type
for (int i = 0; i < nin; ++i) {
LocalTensor = *(in[i]);
if (LocalTensor.isTensor()) {
inputs_vec.push_back(LocalTensor);
} else {
ctorch_error("One of the inputs to torch_jit_module_forward is not a Tensor");
}
}
auto inputs_vec = tensors_to_ivalue_vec(inputs, nin);
try {
auto model_out = model->forward(inputs_vec);
if (model_out.isTensor()) {
Expand All @@ -616,6 +637,39 @@ void torch_jit_module_forward(const torch_jit_script_module_t module,
}
}

void torch_jit_method_call(const torch_jit_method_t method,
                           const torch_tensor_t *inputs, const int nin,
                           torch_tensor_t *outputs, const int nout,
                           const bool requires_grad = false) {
  torch::AutoGradMode enable_grad(requires_grad);
  // Here we cast the pointers we received in to the Torch objects they hold.
  // NOTE: the local must not reuse the parameter name `method` — redeclaring a
  // parameter in a function's outermost block is ill-formed, and a shadow
  // would self-initialize from an uninitialized variable.
  auto method_ptr = static_cast<torch::jit::Method *>(method);
  auto out = reinterpret_cast<torch::Tensor **>(outputs);
  // Generate a vector of IValues (placeholders for various Torch types)
  auto inputs_vec = tensors_to_ivalue_vec(inputs, nin);
  try {
    // Method must be dereferenced before it can be invoked via operator()
    auto method_out = (*method_ptr)(inputs_vec);
    if (method_out.isTensor()) {
      // Single output models will return a tensor directly.
      // Assigning through std::move selects the rvalue-qualified Tensor
      // operator=, which writes in place into the caller-owned tensor rather
      // than rebinding the handle — presumably to preserve the memory the
      // Fortran side points at (same pattern as torch_jit_module_forward).
      std::move(*out[0]) = method_out.toTensor();
    } else if (method_out.isTuple()) {
      // Multiple output models will return a tuple => cast to tensors.
      for (int i = 0; i < nout; ++i) {
        std::move(*out[i]) = method_out.toTuple()->elements()[i].toTensor();
      }
    } else {
      // If for some reason the method does not return a Tensor it
      // should raise an error when trying to cast to a Tensor type
      ctorch_error("Method Output is neither Tensor nor Tuple");
    }
  } catch (const torch::Error &e) {
    ctorch_error(e.msg());
  } catch (const std::exception &e) {
    ctorch_error(e.what());
  }
}

void torch_jit_module_print_parameters(const torch_jit_script_module_t module) {
auto m = reinterpret_cast<torch::jit::script::Module *>(module);
for (const auto &[key, value] : m->named_parameters()) {
Expand All @@ -632,3 +686,8 @@ void torch_jit_module_delete(torch_jit_script_module_t module) {
auto m = reinterpret_cast<torch::jit::script::Module *>(module);
delete m;
}

void torch_jit_method_delete(torch_jit_method_t method) {
  // Free the Method allocated by torch_jit_get_method. static_cast is the
  // idiomatic cast back from an opaque void * handle, matching the cast used
  // in torch_jit_get_method and torch_jit_method_call.
  auto method_ptr = static_cast<torch::jit::Method *>(method);
  delete method_ptr;
}
33 changes: 33 additions & 0 deletions src/ctorch.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
// Opaque pointer type alias for torch::jit::script::Module class
typedef void *torch_jit_script_module_t;

// Opaque pointer type alias for torch::jit::Method class
typedef void *torch_jit_method_t;

// Opaque pointer type alias for at::Tensor
typedef void *torch_tensor_t;

Expand Down Expand Up @@ -347,6 +350,16 @@ EXPORT_C torch_jit_script_module_t torch_jit_load(const char *filename,
const bool requires_grad,
const bool is_training);

/**
 * Function to get a named method from a Torch Module
 * (previously loaded from a TorchScript file, e.g. via torch_jit_load)
 * @param module containing the model to get the method from
 * @param method_name name of the method to get from the module
 * @return Torch Method loaded from module
 */
EXPORT_C torch_jit_method_t torch_jit_get_method(const torch_jit_script_module_t module,
const char *method_name);

/**
* Function to run the `forward` method of a Torch Module
* @param Torch Module containing the model
Expand All @@ -361,6 +374,20 @@ EXPORT_C void torch_jit_module_forward(const torch_jit_script_module_t module,
torch_tensor_t *outputs, const int nout,
const bool requires_grad);

/**
 * Function to call a Torch Method
 * @param method Torch Method to be called (as returned by torch_jit_get_method)
 * @param inputs array of Torch Tensors as inputs to the method
 * @param nin number of input Tensors in the input array
 * @param outputs array of Torch Tensors as outputs from running the method
 * @param nout number of output Tensors in the output array
 * @param requires_grad whether gradient is required
 */
EXPORT_C void torch_jit_method_call(const torch_jit_method_t method,
                                    const torch_tensor_t *inputs, const int nin,
                                    torch_tensor_t *outputs, const int nout,
                                    const bool requires_grad);

/**
* Function to print out the parameters of a Torch Module
*
Expand All @@ -381,4 +408,10 @@ EXPORT_C bool torch_jit_module_is_training(const torch_jit_script_module_t modul
*/
EXPORT_C void torch_jit_module_delete(torch_jit_script_module_t module);

/**
* Function to delete a Torch Method to clean up
* @param Torch Method to delete
*/
EXPORT_C void torch_jit_method_delete(torch_jit_method_t method);

#endif /* C_TORCH_H*/
137 changes: 137 additions & 0 deletions src/ftorch_method.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
!| Module for the FTorch `torch_method` type and associated procedures.
!
! * License
!   FTorch is released under an MIT license.
!   See the [LICENSE](https://github.com/Cambridge-ICCS/FTorch/blob/main/LICENSE)
!   file for details.
module ftorch_method
  use, intrinsic :: iso_c_binding, only : c_null_ptr, c_ptr
  use ftorch_devices, only: torch_kCPU, torch_kCUDA, torch_kHIP, torch_kXPU, torch_kMPS
  use ftorch_model, only: torch_model
  use ftorch_types, only: ftorch_int
  use ftorch_tensor, only: torch_tensor

  implicit none

  public

  !> Type for holding a Torch Method from a TorchScript model.
  type torch_method
    type(c_ptr) :: p = c_null_ptr  !! pointer to the method in memory
  contains
    final :: torch_method_delete
  end type torch_method

contains

  ! ============================================================================
  ! --- Procedures for getting methods from a model
  ! ============================================================================

  !> Gets a named method from a loaded TorchScript model
  !> (a pre-trained PyTorch model saved with TorchScript).
  subroutine torch_get_method(method, model, methodname)
    use, intrinsic :: iso_c_binding, only : c_null_char
    type(torch_method), intent(out) :: method !! Returned method handle
    type(torch_model), intent(in) :: model    !! Model to get the method from
    character(*), intent(in) :: methodname    !! Name of the method

    interface
      function torch_jit_get_method_c(model_c, methodname_c) result(method_c) &
          bind(c, name = 'torch_jit_get_method')
        use, intrinsic :: iso_c_binding, only : c_char, c_ptr
        implicit none
        type(c_ptr), value, intent(in) :: model_c
        character(c_char), intent(in) :: methodname_c(*)
        type(c_ptr) :: method_c
      end function torch_jit_get_method_c
    end interface

    ! C strings are null-terminated, so append c_null_char to the method name
    method%p = torch_jit_get_method_c(model%p, trim(adjustl(methodname))//c_null_char)
  end subroutine torch_get_method

  ! ============================================================================
  ! --- Procedures for performing inference
  ! ============================================================================

  !> Calls the method with the input tensors, populating the output tensors.
  subroutine torch_method_call(method, input_tensors, output_tensors, requires_grad)
    use, intrinsic :: iso_c_binding, only : c_bool, c_ptr, c_int, c_loc
    type(torch_method), intent(in) :: method !! Method to call
    type(torch_tensor), intent(in), dimension(:) :: input_tensors  !! Array of input tensors
    type(torch_tensor), intent(in), dimension(:) :: output_tensors !! Tensors written to by the method
    logical, optional, intent(in) :: requires_grad !! Whether gradients need to be computed (default .false.)
    logical :: requires_grad_value !! Resolved value of the optional requires_grad argument

    integer(ftorch_int) :: i
    integer(c_int) :: n_inputs
    integer(c_int) :: n_outputs
    ! Arrays of C pointers to the underlying tensors, passed across the C API
    type(c_ptr), dimension(size(input_tensors)), target :: input_ptrs
    type(c_ptr), dimension(size(output_tensors)), target :: output_ptrs

    interface
      subroutine torch_jit_method_call_c(method_c, input_tensors_c, n_inputs_c, &
          output_tensors_c, n_outputs_c, requires_grad_c) &
          bind(c, name = 'torch_jit_method_call')
        use, intrinsic :: iso_c_binding, only : c_bool, c_ptr, c_int
        implicit none
        type(c_ptr), value, intent(in) :: method_c
        type(c_ptr), value, intent(in) :: input_tensors_c
        integer(c_int), value, intent(in) :: n_inputs_c
        type(c_ptr), value, intent(in) :: output_tensors_c
        integer(c_int), value, intent(in) :: n_outputs_c
        logical(c_bool), value, intent(in) :: requires_grad_c
      end subroutine torch_jit_method_call_c
    end interface

    n_inputs = size(input_tensors)
    n_outputs = size(output_tensors)

    ! Gradients are not computed unless explicitly requested
    if (present(requires_grad)) then
      requires_grad_value = requires_grad
    else
      requires_grad_value = .false.
    end if

    ! Assign array of pointers to the input tensors
    do i = 1, n_inputs
      input_ptrs(i) = input_tensors(i)%p
    end do

    ! Assign array of pointers to the output tensors
    do i = 1, n_outputs
      output_ptrs(i) = output_tensors(i)%p
    end do

    call torch_jit_method_call_c(method%p, c_loc(input_ptrs), n_inputs, &
                                 c_loc(output_ptrs), n_outputs, &
                                 logical(requires_grad_value, c_bool))
  end subroutine torch_method_call

  ! ============================================================================
  ! --- Procedures for deallocating methods
  ! ============================================================================

  !> Deallocates a TorchScript method.
  subroutine torch_method_delete(method)
    use, intrinsic :: iso_c_binding, only : c_associated, c_null_ptr
    type(torch_method), intent(inout) :: method !! Torch Method to deallocate

    interface
      subroutine torch_jit_method_delete_c(method_c) &
          bind(c, name = 'torch_jit_method_delete')
        use, intrinsic :: iso_c_binding, only : c_ptr
        implicit none
        type(c_ptr), value, intent(in) :: method_c
      end subroutine torch_jit_method_delete_c
    end interface

    ! Call the destructor, if it hasn't already been called
    if (c_associated(method%p)) then
      call torch_jit_method_delete_c(method%p)
      method%p = c_null_ptr
    end if
  end subroutine torch_method_delete

end module ftorch_method