diff --git a/python/cpp/module.cc b/python/cpp/module.cc index 4489d5314..0d31666eb 100644 --- a/python/cpp/module.cc +++ b/python/cpp/module.cc @@ -76,6 +76,7 @@ PYBIND11_MODULE(_ext, m) m.def("set_random_seed", &ctranslate2::set_random_seed, py::arg("seed"), "Sets the seed of random generators."); + ctranslate2::python::register_profiling(m); ctranslate2::python::register_logging(m); ctranslate2::python::register_storage_view(m); ctranslate2::python::register_translation_stats(m); diff --git a/python/cpp/module.h b/python/cpp/module.h index 9c9a9a2ff..207e0abf6 100644 --- a/python/cpp/module.h +++ b/python/cpp/module.h @@ -1,12 +1,14 @@ #pragma once #include +#include namespace py = pybind11; namespace ctranslate2 { namespace python { + void register_profiling(py::module& m); void register_encoder(py::module& m); void register_generation_result(py::module& m); void register_generator(py::module& m); diff --git a/python/cpp/profiling.cc b/python/cpp/profiling.cc new file mode 100644 index 000000000..4b8dc4192 --- /dev/null +++ b/python/cpp/profiling.cc @@ -0,0 +1,19 @@ +#include "module.h" +#include +#include + +namespace ctranslate2 { + namespace python { + + void register_profiling(py::module& m) { + + m.def("init_profiling", &ctranslate2::init_profiling); + m.def("dump_profiling", []() { + std::ostringstream oss; + ctranslate2::dump_profiling(oss); + return oss.str(); + }); + } + + } +} diff --git a/python/ctranslate2/__init__.py b/python/ctranslate2/__init__.py index 88da68aec..46ada00d5 100644 --- a/python/ctranslate2/__init__.py +++ b/python/ctranslate2/__init__.py @@ -42,6 +42,7 @@ ) from ctranslate2.extensions import register_extensions from ctranslate2.logging import get_log_level, set_log_level + from ctranslate2.profiling import dump_profiler, init_profiler register_extensions() del register_extensions diff --git a/python/ctranslate2/profiling.py b/python/ctranslate2/profiling.py new file mode 100644 index 000000000..6663a0ea5 --- /dev/null +++ b/python/ctranslate2/profiling.py @@ -0,0 +1,12 @@ +import sys + +from ctranslate2 import Device, _ext + + +def init_profiler(device=Device.cpu, num_threads=1): + _ext.init_profiling(device, num_threads) + + +def dump_profiler(): + profiling_data = _ext.dump_profiling() + sys.stdout.write(profiling_data)