Skip to content

Commit

Permalink
Add MatX, update ORT version, pybind11 version (#189)
Browse files Browse the repository at this point in the history
* Add MatX

* matx

* better

* skip whatever does not work

* use 1.19.2

* e

* improve error message

* listdir

* fix cmake

* fix windows build
  • Loading branch information
xadupre authored Sep 17, 2024
1 parent 5004ce6 commit ff2c70e
Show file tree
Hide file tree
Showing 20 changed files with 314 additions and 29 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,4 @@ onnx_extended/validation/cython/*.c
onnx_extended/validation/cython/*.cpp
onnx_extended/validation/cython/vector_function_cy.c*
onnx_extended/ortcy/wrap/ortinf.c*
onnx_extended/ortcy/wrap/*.lib
2 changes: 1 addition & 1 deletion CHANGELOGS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Change Logs
0.3.0
+++++

* :pr:`186`: support numpy 2.0
* :pr:`189`: use onnxruntime==1.19.2 as default, pybind11 2.13.5, MatX 0.8.0
* :pr:`187`: Fix compilation with GCC>=13 #187
* :pr:`185`: adds custom operator MulMulSigmoid on CUDA
* :pr:`184`: use onnxruntime==1.18.0 as default
Expand Down
10 changes: 10 additions & 0 deletions _cmake/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,16 @@ message(STATUS "--------------------------------------------")

include("targets/common.cmake")

#
# display all variables
#

get_cmake_property(_variableNames VARIABLES)
list (SORT _variableNames)
foreach (_variableName ${_variableNames})
message(STATUS "---- ${_variableName}=${${_variableName}}")
endforeach()

#
# standalone modules
#
Expand Down
32 changes: 32 additions & 0 deletions _cmake/externals/FindLocalMatX.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
#
# initialization
#
# defines matx matx_SOURCE_DIR matx_BINARY_DIR

#
# matx
#

set(matx_TAG "v0.8.0")

include(FetchContent)
FetchContent_Declare(
matx
GIT_REPOSITORY https://github.com/NVIDIA/matx
GIT_TAG ${matx_TAG})

FetchContent_MakeAvailable(matx)
FetchContent_GetProperties(matx)

set(matx_VERSION ${matx_TAG})
set(MATX_INCLUDE_DIR "${matx_SOURCE_DIR}/include")
message(STATUS "matx_BINARY_DIR=${matx_BINARY_DIR}")
message(STATUS "matx_SOURCE_DIR=${matx_SOURCE_DIR}")
message(STATUS "MATX_INCLUDE_DIR=${MATX_INCLUDE_DIR}")
message(STATUS "matx_VERSION=${matx_VERSION}")

include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(
LocalMatX
VERSION_VAR matx_VERSION
REQUIRED_VARS matx_SOURCE_DIR matx_BINARY_DIR)
4 changes: 3 additions & 1 deletion _cmake/externals/FindLocalPyBind11.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
# pybind11
#

set(pybind11_TAG "v2.10.4")
set(pybind11_TAG "v2.13.5")

include(FetchContent)
FetchContent_Declare(
Expand All @@ -19,6 +19,8 @@ FetchContent_Declare(
FetchContent_GetProperties(pybind11)
if(NOT pybind11_POPULATED)
FetchContent_Populate(pybind11)
message(STATUS "pybind11_SOURCE_DIR=${pybind11_SOURCE_DIR}")
message(STATUS "pybind11_BINARY_DIR=${pybind11_BINARY_DIR}")
add_subdirectory(${pybind11_SOURCE_DIR} ${pybind11_BINARY_DIR})
else()
message(FATAL_ERROR "Pybind11 was not found.")
Expand Down
18 changes: 14 additions & 4 deletions _cmake/externals/FindOrt.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,16 @@ if(MSVC)
"with extension 'pdb'.")
endif()
list(APPEND ORT_LIB_FILES ${ORT_LIB_FILES_PDB})

file(GLOB ORT_LIB_FILES_LIB ${ONNXRUNTIME_LIB_DIR}/*.lib)
list(LENGTH ORT_LIB_FILES_LIB ORT_LIB_FILES_LIB_LENGTH)
if (ORT_LIB_FILES_LIB_LENGTH LESS_EQUAL 1)
message(FATAL_ERROR "No pdb file found in '${ONNXRUNTIME_LIB_DIR}' "
"from path or url '${ORT_URL}', "
"found files [${ORT_LIB_FILES}] "
"with extension 'lib'.")
endif()
list(APPEND ORT_LIB_FILES ${ORT_LIB_FILES_LIB})
endif()

#
Expand All @@ -133,19 +143,19 @@ function(ort_add_dependency name folder_copy)
foreach(file_i ${ORT_LIB_FILES})
message(STATUS "ort: copy ${file_i} to '${destination_dir}'")
add_custom_command(
TARGET ${name} POST_BUILD
TARGET ${name} PRE_BUILD
COMMAND ${CMAKE_COMMAND} ARGS -E copy ${file_i} ${destination_dir})
if(folder_copy)
message(STATUS "ort: copy '${file_i}' to '${ROOT_PROJECT_PATH}/${folder_copy}'")
add_custom_command(
TARGET ${name} POST_BUILD
TARGET ${name} PRE_BUILD
COMMAND ${CMAKE_COMMAND} ARGS -E copy "${file_i}" "${ROOT_PROJECT_PATH}/${folder_copy}")
message(STATUS "ort: copy '${file_i}' to '${SETUP_BUILD_LIB}/${folder_copy}'")
add_custom_command(
TARGET ${name} POST_BUILD
TARGET ${name} PRE_BUILD
COMMAND ${CMAKE_COMMAND} ARGS -E make_directory "${SETUP_BUILD_LIB}/${folder_copy}")
add_custom_command(
TARGET ${name} POST_BUILD
TARGET ${name} PRE_BUILD
COMMAND ${CMAKE_COMMAND} ARGS -E copy "${file_i}" "${SETUP_BUILD_LIB}/${folder_copy}")
endif()
endforeach()
Expand Down
10 changes: 10 additions & 0 deletions _cmake/load_externals.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -156,4 +156,14 @@ else()
message(FATAL_ERROR "Module eigen is not installed.")
endif()

if(USE_CUDA)
message(STATUS "-------------------")
find_package(LocalMatX REQUIRED)
if(LocalMatX_FOUND)
message(STATUS "Found MatX ${LocalMatX_VERSION}")
else()
message(FATAL_ERROR "Module MatX is not installed.")
endif()
endif()

message(STATUS "-------------------")
20 changes: 13 additions & 7 deletions _cmake/targets/ortinf.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -11,24 +11,30 @@ target_include_directories(
${ROOT_INCLUDE_PATH})
target_link_libraries(lib_ortapi PRIVATE common)

set(ORTAPI_INCLUDE_DIR "${ROOT_PROJECT_PATH}/onnx_extended/ortcy/wrap")

cython_add_module(
ortinf
../onnx_extended/ortcy/wrap/ortinf.pyx
OpenMP::OpenMP_CXX)
target_link_directories(ortinf PRIVATE ${ONNXRUNTIME_LIB_DIR})
message(STATUS " LINK ortinf <- lib_ortapi onnxruntime")

message(STATUS " LINK ortinf <- lib_ortapi onnxruntime ${ORTAPI_INCLUDE_DIR}")

ort_add_dependency(
ortinf
onnx_extended/ortcy/wrap)

# If ONNXRUNTIME_LIB_DIR is used, then it seems a local installation does
# does not the binaries anymore if they are removed.
target_link_directories(ortinf PRIVATE ${ORTAPI_INCLUDE_DIR})

target_link_libraries(
ortinf
PRIVATE
lib_ortapi
onnxruntime
common_kernels)
target_include_directories(ortinf PRIVATE ${ROOT_INCLUDE_PATH})
ort_add_dependency(
ortinf
onnx_extended/ortcy/wrap)

set(ORTAPI_INCLUDE_DIR "${ROOT_INCLUDE_PATH}/onnx_extended/ortcy/wrap")

add_executable(test_ortcy_inference_cpp ../_unittests/ut_ortcy/test_inference.cpp)
target_compile_definitions(test_ortcy_inference_cpp PRIVATE PYTHON_MANYLINUX=${PYTHON_MANYLINUX})
Expand Down
8 changes: 6 additions & 2 deletions _cmake/targets/ortops_tutorial_cuda.cmake
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# module: onnx_extended.reference.c_ops.cpu.c_op_conv_
# custom ops: onnx_extended.ortops.tutorial.cuda
#

if(CUDA_AVAILABLE)
Expand All @@ -12,6 +12,7 @@ if(CUDA_AVAILABLE)
onnx_extended/ortops/tutorial/cuda
../onnx_extended/cpp/onnx_extended_helpers.cpp
../onnx_extended/ortops/tutorial/cuda/custom_gemm.cu
../onnx_extended/ortops/tutorial/cuda/matx_matmul.cu
../onnx_extended/ortops/tutorial/cuda/ort_tutorial_cuda_lib.cc)

# needed to include onnx_extended_helpers.h
Expand All @@ -20,6 +21,9 @@ if(CUDA_AVAILABLE)
PRIVATE
"${ROOT_INCLUDE_PATH}"
"${ORTAPI_INCLUDE_DIR}"
"${ORTOPS_INCLUDE_DIR}")
"${ORTOPS_INCLUDE_DIR}"
"${matx_INCLUDE_DIR}")

target_link_libraries(ortops_tutorial_cuda PRIVATE matx::matx)

endif()
1 change: 1 addition & 0 deletions _doc/tutorial/old_version.rst
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ It calls function :func:`bench_virtual <onnx_extended.tools.run_onnx.bench_virtu
runtimes = ["onnxruntime"]
modules = [
{"onnx-extended": "0.3.0", "onnx": "1.15.0", "onnxruntime": "1.19.2"},
{"onnx-extended": "0.3.0", "onnx": "1.15.0", "onnxruntime": "1.18.0"},
{"onnx-extended": "0.2.3", "onnx": "1.15.0", "onnxruntime": "1.17.3"},
{"onnx-extended": "0.2.3", "onnx": "1.15.0", "onnxruntime": "1.16.3"},
Expand Down
42 changes: 34 additions & 8 deletions _unittests/ut_ortcy/test_ortcy.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
path = onnx_extended.ortcy.wrap.ortinf.__file__
except ImportError as ee:
path = str(ee)
msg = "libonnxruntime.so.1.18.0: cannot open shared object file"
msg = "libonnxruntime.so.1.19.2: cannot open shared object file"
if msg in str(e):
from onnx_extended.ortcy.wrap import __file__ as loc

Expand All @@ -41,22 +41,39 @@
get_ort_c_api_supported_version = None
here = os.path.dirname(__file__)
else:
OrtSession = "OrtSession is not initialized"
import onnx_extended.ortcy.wrap as wp

OrtSession = (
f"OrtSession is not initialized:\n{e}"
f"\n--found--\n{os.listdir(os.path.dirname(wp.__file__))}"
)
get_ort_c_api_supported_version = (
"get_ort_c_api_supported_version is not initialized"
f"get_ort_c_api_supported_version is not initialized:\n{e}"
f"\n--found--\n{os.listdir(os.path.dirname(wp.__file__))}"
)


class TestOrtCy(ExtTestCase):

def test_get_ort_c_api_supported_version_str(self):
assert not isinstance(get_ort_c_api_supported_version, str), (
f"Unexpected value for get_ort_c_api_supported_version="
f"{get_ort_c_api_supported_version!r}"
)

@unittest.skipIf(
get_ort_c_api_supported_version is None,
get_ort_c_api_supported_version is None
or isinstance(get_ort_c_api_supported_version, str),
reason="libonnxruntime installation failed",
)
def test_get_ort_c_api_supported_version(self):
v = get_ort_c_api_supported_version()
self.assertGreaterEqual(v, 16)

@unittest.skipIf(OrtSession is None, reason="libonnxruntime installation failed")
@unittest.skipIf(
OrtSession is None or isinstance(OrtSession, str),
reason="libonnxruntime installation failed",
)
def test_ort_get_available_providers(self):
from onnx_extended.ortcy.wrap.ortinf import ort_get_available_providers

Expand All @@ -65,7 +82,10 @@ def test_ort_get_available_providers(self):
self.assertGreater(len(res), 0)
self.assertIn("CPUExecutionProvider", res)

@unittest.skipIf(OrtSession is None, reason="libonnxruntime installation failed")
@unittest.skipIf(
OrtSession is None or isinstance(OrtSession, str),
reason="libonnxruntime installation failed",
)
@unittest.skipIf(
sys.platform == "darwin",
reason="The test is unstable and leads to a crash `Illegal instruction`. "
Expand Down Expand Up @@ -115,7 +135,10 @@ def test_session(self):
self.assertEqual(len(got), 1)
self.assertEqualArray(got[0], x + y)

@unittest.skipIf(OrtSession is None, reason="libonnxruntime installation failed")
@unittest.skipIf(
OrtSession is None or isinstance(OrtSession, str),
reason="libonnxruntime installation failed",
)
@unittest.skipIf(
sys.platform == "darwin", reason="Compilation settings fails on darwin"
)
Expand Down Expand Up @@ -148,7 +171,10 @@ def test_my_custom_ops_cy(self):
got = session.run_2(x, y)[0]
self.assertEqualArray(x + y, got)

@unittest.skipIf(OrtSession is None, reason="libonnxruntime installation failed")
@unittest.skipIf(
OrtSession is None or isinstance(OrtSession, str),
reason="libonnxruntime installation failed",
)
@unittest.skipIf(
sys.platform == "darwin", reason="Compilation settings fails on darwin"
)
Expand Down
2 changes: 1 addition & 1 deletion _unittests/ut_xrun_doc/test_documentation_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
try:
from onnx_extended.ortcy.wrap.ortinf import OrtSession
except ImportError as e:
msg = "libonnxruntime.so.1.18.0: cannot open shared object file"
msg = "libonnxruntime.so.1.19.2: cannot open shared object file"
if msg in str(e):
from onnx_extended.ortcy.wrap import __file__ as loc

Expand Down
2 changes: 1 addition & 1 deletion onnx_extended/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def local_print(msg):

this = os.path.dirname(cyfile)
files = os.listdir(this)
if "libonnxruntime.so.1.18.0" in files:
if "libonnxruntime.so.1.19.2" in files:
if verbose:
local_print(
"[check_installation_ortcy] weird issue as the "
Expand Down
4 changes: 2 additions & 2 deletions onnx_extended/ortops/tutorial/cuda/custom_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -177,14 +177,14 @@ void CustomGemmKernel::SetParams(const std::vector<int64_t> &shape_A,
}
}

void check_device(const Ort::ConstValue &input, const char *name) {
static void check_device(const Ort::ConstValue &input, const char *name) {
EXT_ENFORCE(input.HasValue(), "Input '", name, "' is not empty.");
auto mem = input.GetTensorMemoryInfo();
EXT_ENFORCE(mem.GetDeviceType() == OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU,
"Input '", name, "' is not on CUDA");
}

void check_device(const Ort::UnownedValue &output, const char *name) {
static void check_device(const Ort::UnownedValue &output, const char *name) {
auto mem = output.GetTensorMemoryInfo();
EXT_ENFORCE(mem.GetDeviceType() == OrtMemoryInfoDeviceType::OrtMemoryInfoDeviceType_GPU,
"Output '", name, "' is not on CUDA");
Expand Down
Loading

0 comments on commit ff2c70e

Please sign in to comment.