diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000..98d76c1 --- /dev/null +++ b/.clang-format @@ -0,0 +1,29 @@ +AllowShortFunctionsOnASingleLine: Empty +AllowShortIfStatementsOnASingleLine: false +AllowShortLoopsOnASingleLine: false +BasedOnStyle: Google +BinPackArguments: false +BinPackParameters: false +ColumnLimit: 100 +Cpp11BracedListStyle: true +DerivePointerAlignment: false +IncludeBlocks: Preserve +IncludeCategories: + - Regex: "<.*>" + Priority: 1 + SortPriority: 0 + - Regex: '^(<|"(boost|google|grpc)/)' + Priority: 3 + SortPriority: 0 + - Regex: ".*" + Priority: 1 + SortPriority: 0 +IncludeIsMainRegex: "(Test)?$" +IncludeIsMainSourceRegex: "" +IndentWidth: 4 +MaxEmptyLinesToKeep: 1 +NamespaceIndentation: None +SortIncludes: true +SpaceBeforeAssignmentOperators: true +Standard: Cpp11 +UseTab: Never diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 4453a13..6017b21 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -2,16 +2,16 @@ name: Run linter and build on: push: - branches: [ "main" ] + branches: ["main"] pull_request: - branches: [ "main" ] + branches: ["main"] jobs: build: strategy: matrix: # TODO(RSDK-10636): run the build on windows-2019, too, when it's is tolerably fast - runs_on: [ubuntu-22.04, ubuntu-22.04-arm, macos-14, macos-13] + runs_on: [ubuntu-22.04, ubuntu-22.04-arm, macos-14] name: "Lint and build on each platform" runs-on: ${{ matrix.runs_on }} diff --git a/.gitignore b/.gitignore index f267040..1e9dbed 100644 --- a/.gitignore +++ b/.gitignore @@ -28,4 +28,5 @@ *.app CMakeUserPresets.json -build-conan +build-conan/ +build/ diff --git a/CMakeLists.txt b/CMakeLists.txt index 3bbdf48..9f6b592 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -38,16 +38,20 @@ find_package(Threads REQUIRED) find_package(viam-cpp-sdk REQUIRED) find_package(tensorflowlite REQUIRED) -add_executable(tflite_cpu - src/tflite_cpu.cpp -) +add_library(tflite_cpu_service src/tflite_cpu.cpp) -target_link_libraries(tflite_cpu - PRIVATE Threads::Threads - PRIVATE viam-cpp-sdk::viamsdk +target_link_libraries(tflite_cpu_service + PUBLIC Threads::Threads + PUBLIC viam-cpp-sdk::viamsdk PRIVATE tensorflow::tensorflowlite ) +add_executable(tflite_cpu + src/main.cpp +) + +target_link_libraries(tflite_cpu PRIVATE tflite_cpu_service) + install( TARGETS tflite_cpu -) +) \ No newline at end of file diff --git a/conanfile.py b/conanfile.py index 9813e25..bd117b6 100644 --- a/conanfile.py +++ b/conanfile.py @@ -44,6 +44,7 @@ def requirements(self): # NOTE: If you update the `viam-cpp-sdk` dependency here, it # should also be updated in `bin/setup.{sh,ps1}`. self.requires("viam-cpp-sdk/0.20.1") + self.requires("boost/[>=1.74.0]") self.requires("tensorflow-lite/2.15.0") # NOTE: This should match what the viam-cpp-sdk pulls (indirectly, via grpc/protobuf) # TODO: Is there a way to express that better than hardcoding it? diff --git a/src/main.cpp b/src/main.cpp new file mode 100644 index 0000000..e4418b9 --- /dev/null +++ b/src/main.cpp @@ -0,0 +1,64 @@ +#include + +#include +#include +#include + +#include "tflite_cpu.hpp" + +namespace { + +int serve(const std::string& socket_path) try { + // Every Viam C++ SDK program must have one and only one Instance object which is created before + // any other C++ SDK objects and stays alive until all Viam C++ SDK objects are destroyed. + viam::sdk::Instance inst; + + // Create a new model registration for the service. 
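+    //
+    // For illustration only (hypothetical values, not produced by this module):
+    // a machine configuration would typically reference the model triple
+    // registered below along these lines:
+    //
+    //   {
+    //     "name": "my_tflite_model",
+    //     "api": "rdk:service:mlmodel",
+    //     "model": "viam:mlmodel-tflite:tflite_cpu",
+    //     "attributes": { "model_path": "/path/to/model.tflite" }
+    //   }
+    //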
+ auto module_registration = std::make_shared( + // Identify that this resource offers the MLModelService API + viam::sdk::API::get(), + + // Declare a model triple for this service. + viam::sdk::Model{"viam", "mlmodel-tflite", "tflite_cpu"}, + + // Define the factory for instances of the resource. + [](viam::sdk::Dependencies deps, viam::sdk::ResourceConfig config) { + return std::make_shared(std::move(deps), + std::move(config)); + }); + + // Register the newly created registration with the Registry. + viam::sdk::Registry::get().register_model(module_registration); + + // Construct the module service and tell it where to place the socket path. + auto module_service = std::make_shared(socket_path); + + // Add the server as providing the API and model declared in the + // registration. + module_service->add_model_from_registry(module_registration->api(), + module_registration->model()); + + // Start the module service. + module_service->serve(); + + return EXIT_SUCCESS; +} catch (const std::exception& ex) { + std::cout << "ERROR: A std::exception was thrown from `serve`: " << ex.what() << std::endl; + return EXIT_FAILURE; +} catch (...) { + std::cout << "ERROR: An unknown exception was thrown from `serve`" << std::endl; + return EXIT_FAILURE; +} +} // namespace + +int main(int argc, char* argv[]) { + const auto usage = std::string("usage: ") + argv[0] + " /path/to/unix/socket"; + + if (argc < 2) { + std::cout << "ERROR: insufficient arguments\n"; + std::cout << usage << "\n"; + return EXIT_FAILURE; + } + + return serve(argv[1]); +} diff --git a/src/tflite_cpu.cpp b/src/tflite_cpu.cpp index 364b4d1..a864581 100644 --- a/src/tflite_cpu.cpp +++ b/src/tflite_cpu.cpp @@ -34,632 +34,552 @@ #include #include -namespace { +#include "tflite_cpu.hpp" -namespace vsdk = ::viam::sdk; -constexpr char service_name[] = "viam_tflite_cpu"; +namespace viam { +namespace mlmodel_tflite { -// An MLModelService instance which runs TensorFlow Lite models. -// -// Configuration requires the following parameters: -// -- `model_path`: An absolute filesystem path to a TensorFlow Lite model file. -// -// The following optional parameters are honored: -// -- `num_threads`: Sets the number of threads to be used, where applicable. -// -// -- `label_path`: An absolute filesystem path to a .txt file of the model's category labels. -// -// Any additional configuration fields are ignored. -class MLModelServiceTFLite : public vsdk::MLModelService, - public vsdk::Stoppable, - public vsdk::Reconfigurable { - class write_to_tflite_tensor_visitor_; +namespace { +namespace vsdk = ::viam::sdk; +using namespace vsdk; - public: - explicit MLModelServiceTFLite(vsdk::Dependencies dependencies, - vsdk::ResourceConfig configuration) - : MLModelService(configuration.name()), - state_(configure_(std::move(dependencies), std::move(configuration))) {} - - ~MLModelServiceTFLite() final { - // All invocations arrive via gRPC, so we know we are idle - // here. It should be safe to tear down all state - // automatically without needing to wait for anything more to - // drain. - } +constexpr char k_service_name[] = "viam_tflite_cpu"; - void stop(const vsdk::ProtoStruct& extra) noexcept final { - return stop(); +// Converts from tflites type enumeration into the model service +// type enumeration or throws if there is no such conversion. 
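+// For example, kTfLiteFloat32 maps to k_float32 and kTfLiteUInt8 maps to
+// k_uint8, while tflite types with no MLModelService counterpart (e.g.
+// kTfLiteString or kTfLiteBool) fall through to the default case and raise
+// std::invalid_argument.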
+MLModelService::tensor_info::data_types service_data_type_from_tflite_data_type(TfLiteType type) { + switch (type) { + case kTfLiteInt8: { + return MLModelService::tensor_info::data_types::k_int8; + } + case kTfLiteUInt8: { + return MLModelService::tensor_info::data_types::k_uint8; + } + case kTfLiteInt16: { + return MLModelService::tensor_info::data_types::k_int16; + } + case kTfLiteUInt16: { + return MLModelService::tensor_info::data_types::k_uint16; + } + case kTfLiteInt32: { + return MLModelService::tensor_info::data_types::k_int32; + } + case kTfLiteUInt32: { + return MLModelService::tensor_info::data_types::k_uint32; + } + case kTfLiteInt64: { + return MLModelService::tensor_info::data_types::k_int64; + } + case kTfLiteUInt64: { + return MLModelService::tensor_info::data_types::k_uint64; + } + case kTfLiteFloat32: { + return MLModelService::tensor_info::data_types::k_float32; + } + case kTfLiteFloat64: { + return MLModelService::tensor_info::data_types::k_float64; + } + default: { + std::ostringstream buffer; + buffer << k_service_name << ": Model contains unsupported tflite data type" << type; + throw std::invalid_argument(buffer.str()); + } } +} - /// @brief Stops the MLModelServiceTFLite from running. - void stop() noexcept { - const std::unique_lock state_wlock(state_rwmutex_); - state_.reset(); +// The type specific version of the above function, it just +// reinterpret_casts the tensor buffer into an MLModelService +// tensor view and applies the necessary shape info. +template +MLModelService::tensor_views tensor_views_from_tflite_tensor_t( + const MLModelService::tensor_info& info, const TfLiteTensor* const tflite_tensor) { + const auto* const tensor_data = reinterpret_cast(TfLiteTensorData(tflite_tensor)); + const auto tensor_size_bytes = TfLiteTensorByteSize(tflite_tensor); + const auto tensor_size_t = tensor_size_bytes / sizeof(T); + // TODO: We are just feeding back out what we cached in the + // metadata for shape. Should this instead be re-querying the + // output tensor NumDims / DimN after each invocation in case + // the shape is dynamic? The possibility of a dynamically + // sized extent is why we represent the dimensions as signed + // quantities in the tensor metadata. But an actual tensor has + // a real extent. How would tflite ever communicate that to us + // differently given that we use the same API to obtain + // metadata as we would here? + std::vector shape; + shape.reserve(info.shape.size()); + for (const auto s : info.shape) { + shape.push_back(static_cast(s)); } + return mlmodel_tflite::MLModelServiceTFLite::MLModelService::make_tensor_view( + tensor_data, tensor_size_t, std::move(shape)); +} - void reconfigure(const vsdk::Dependencies& dependencies, - const vsdk::ResourceConfig& configuration) final { - const std::unique_lock state_wlock(state_rwmutex_); - check_stopped_inlock_(); - state_.reset(); - state_ = configure_(dependencies, configuration); +// Creates a tensor_view which views a tflite tensor buffer. It dispatches on the +// type and delegates to the templated version below. 
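+// For instance, a kTfLiteFloat32 tensor is handed to the float instantiation
+// of the templated helper, which wraps the interpreter-owned buffer in a
+// tensor view without copying the underlying data.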
+MLModelService::tensor_views tensor_views_from_tflite_tensor( + const MLModelService::tensor_info& info, const TfLiteTensor* const tflite_tensor) { + const auto tflite_tensor_type = TfLiteTensorType(tflite_tensor); + switch (tflite_tensor_type) { + case kTfLiteInt8: { + return tensor_views_from_tflite_tensor_t(info, tflite_tensor); + } + case kTfLiteUInt8: { + return tensor_views_from_tflite_tensor_t(info, tflite_tensor); + } + case kTfLiteInt16: { + return tensor_views_from_tflite_tensor_t(info, tflite_tensor); + } + case kTfLiteUInt16: { + return tensor_views_from_tflite_tensor_t(info, tflite_tensor); + } + case kTfLiteInt32: { + return tensor_views_from_tflite_tensor_t(info, tflite_tensor); + } + case kTfLiteUInt32: { + return tensor_views_from_tflite_tensor_t(info, tflite_tensor); + } + case kTfLiteInt64: { + return tensor_views_from_tflite_tensor_t(info, tflite_tensor); + } + case kTfLiteUInt64: { + return tensor_views_from_tflite_tensor_t(info, tflite_tensor); + } + case kTfLiteFloat32: { + return tensor_views_from_tflite_tensor_t(info, tflite_tensor); + } + case kTfLiteFloat64: { + return tensor_views_from_tflite_tensor_t(info, tflite_tensor); + } + default: { + std::ostringstream buffer; + buffer << k_service_name + << ": Model returned unsupported tflite data type: " << tflite_tensor_type; + throw std::invalid_argument(buffer.str()); + } } +} - std::shared_ptr infer(const named_tensor_views& inputs, - const vsdk::ProtoStruct& extra) final { - - // We need to lock state so we are protected against reconfiguration, but - // we don't want to block access to `metadata`. We use a shared lock here, - // and an exclusive lock to protect the interpreter itself, below. - std::shared_lock state_rlock(state_rwmutex_); - check_stopped_inlock_(); +// A visitor that can populate a TFLiteTensor given a MLModelService::tensor_view. +class write_to_tflite_tensor_visitor : public boost::static_visitor { + public: + write_to_tflite_tensor_visitor(const std::string* name, TfLiteTensor* tflite_tensor) + : name_(name), tflite_tensor_(tflite_tensor) {}; - // Ensure that enough inputs were provided. - if (inputs.size() < state_->input_tensor_indices_by_name.size()) { + template + TfLiteStatus operator()(const T& mlmodel_tensor) const { + const auto expected_size = TfLiteTensorByteSize(tflite_tensor_); + const auto* const mlmodel_data_begin = + reinterpret_cast(mlmodel_tensor.data()); + const auto* const mlmodel_data_end = + reinterpret_cast(mlmodel_tensor.data() + mlmodel_tensor.size()); + const auto mlmodel_data_size = static_cast(mlmodel_data_end - mlmodel_data_begin); + if (expected_size != mlmodel_data_size) { std::ostringstream buffer; - buffer << service_name - << ": Too few inputs provided for inference: " << state_->input_tensor_indices_by_name.size() - << " expected, but got " << inputs.size() << " instead"; + buffer << k_service_name << ": tensor `" << *name_ + << "` was expected to have byte size " << expected_size << " but " + << mlmodel_data_size << " bytes were provided"; throw std::invalid_argument(buffer.str()); } + return TfLiteTensorCopyFromBuffer(tflite_tensor_, mlmodel_data_begin, expected_size); + } - // Only one thread can actually interact with `state->interpreter` at the same time. - std::unique_lock interpreter_lock(state_->interpreter_mutex); - - // Walk the inputs, and copy the data from each of the input - // tensor views we were given into the associated tflite input - // tensor buffer. 
- for (const auto& kv : inputs) { - const auto where = state_->input_tensor_indices_by_name.find(kv.first); - if (where == state_->input_tensor_indices_by_name.end()) { - std::ostringstream buffer; - buffer << service_name << ": Tensor name `" << kv.first << "`" - << " is not a known input tensor name for the model"; - throw std::invalid_argument(buffer.str()); - } - auto* const tensor = state_->interpreter->tensor(where->second); - if (!tensor) { - std::ostringstream buffer; - buffer << service_name << ": Failed to obtain tflite input tensor for `" << kv.first - << "` (index " << where->second << ")"; - throw std::invalid_argument(buffer.str()); - } - - const auto tflite_status = - boost::apply_visitor(write_to_tflite_tensor_visitor_(&kv.first, tensor), kv.second); - - if (tflite_status != TfLiteStatus::kTfLiteOk) { - std::ostringstream buffer; - buffer << service_name - << ": input tensor `" << kv.first - << "` failed population: " << state_->interpreter_error_data; - throw std::invalid_argument(buffer.str()); - } - } - - // Invoke the interpreter and return any failure information. - if (state_->interpreter->Invoke() != TfLiteStatus::kTfLiteOk) { - std::ostringstream buffer; - buffer << service_name - << ": interpreter invocation failed: " << state_->interpreter_error_data; - throw std::runtime_error(buffer.str()); - } + private: + const std::string* name_; + TfLiteTensor* tflite_tensor_; +}; - // A local type that we will keep on the heap to hold - // inference results until the caller is done with them. In - // our case, the caller is MLModelServiceServer, which will - // copy the data into the reply gRPC proto and then unwind. So - // we can avoid copying the data by letting the views alias - // the tensorflow tensor buffers and keep the interpreter lock - // held until the gRPC work is done. Note that this means the - // state and interpreter locks will remain held until the - // inference_result_type object tracked by the shared pointer - // we return is destroyed. Callers that want to make use of - // the inference results without keeping the interpreter - // locked would need to copy the data out of the views and - // then release the return value. - struct inference_result_type { - std::shared_lock state_rlock; - std::unique_lock interpreter_lock; - named_tensor_views views; - }; - auto inference_result = std::make_shared(); - - // Walk the outputs per our metadata and emplace an - // appropriately typed tensor_view aliasing the interpreter - // output tensor buffer into the inference results. - for (const auto& output : state_->metadata.outputs) { - const auto where = state_->output_tensor_indices_by_name.find(output.name); - if (where == state_->output_tensor_indices_by_name.end()) { - continue; // Should be impossible - } - const auto* const tflite_tensor = state_->interpreter->tensor(where->second); - inference_result->views.emplace(output.name, - std::move(make_tensor_view_(output, tflite_tensor))); - } - - // The views created in the loop above are only valid until - // the interpreter lock is released, so we keep the lock held - // by moving the unique_lock into the inference_result - // object. We also need the state lock to protect our configuration. - inference_result->state_rlock = std::move(state_rlock); - inference_result->interpreter_lock = std::move(interpreter_lock); - - // Finally, construct an aliasing shared_ptr which appears to - // the caller as a shared_ptr to views, but in fact manages - // the lifetime of the inference_result. 
When the - // inference_result object is destroyed, the lock will be - // released and the next caller can invoke the interpreter. - auto* const views = &inference_result->views; - // NOLINTNEXTLINE(performance-move-const-arg): C++20 - return {std::move(inference_result), views}; - } +} // namespace - struct metadata metadata(const vsdk::ProtoStruct& extra) final { - // Just return a copy of our metadata from leased state. - const std::shared_lock state_rlock(state_rwmutex_); - check_stopped_inlock_(); - return state_->metadata; +// All of the meaningful internal state of the service is held in +// a separate state object to help ensure clean replacement of our +// internals during reconfiguration. +struct MLModelServiceTFLite::state_ final : public tflite::ErrorReporter { + explicit state_(vsdk::Dependencies dependencies, vsdk::ResourceConfig configuration) + : dependencies(std::move(dependencies)), configuration(std::move(configuration)) {} + + int Report(const char* format, va_list args) override { + char buffer[4096]; + static_cast(vsnprintf(buffer, sizeof(buffer), format, args)); + interpreter_error_data = buffer; + return 0; } - private: - struct state_; + // The dependencies and configuration we were given at + // construction / reconfiguration. + vsdk::Dependencies dependencies; + vsdk::ResourceConfig configuration; - void check_stopped_inlock_() const { - if (!state_) { - std::ostringstream buffer; - buffer << service_name << ": service is stopped: "; - throw std::runtime_error(buffer.str()); - } - } + // This data must outlive any interpreters created from the + // model we build against model data. + std::string model_data; + std::unique_ptr model; + + // Metadata about input and output tensors that was extracted + // during configuration. Callers need this in order to know + // how to interact with the service. + struct MLModelService::metadata metadata; + + // The label path is a file that relates the outputs of the label tensor ints to strings + std::string label_path; + + // Maps from string names of tensors to the numeric + // value. Note that the keys here are the renamed tensors, if + // applicable. + std::unordered_map input_tensor_indices_by_name; + std::unordered_map output_tensor_indices_by_name; + + // Protects interpreter_error_data and interpreter + std::mutex interpreter_mutex; + + // The `Report` method will overwrite this string. + std::string interpreter_error_data; + + // The interpreter itself. + std::unique_ptr interpreter; +}; + +MLModelServiceTFLite::MLModelServiceTFLite(vsdk::Dependencies dependencies, + vsdk::ResourceConfig configuration) + : MLModelService(configuration.name()), + state_(configure_(std::move(dependencies), std::move(configuration))) {} - static std::unique_ptr configure_(vsdk::Dependencies dependencies, - vsdk::ResourceConfig configuration) { +MLModelServiceTFLite::~MLModelServiceTFLite() { + // All invocations arrive via gRPC, so we know we are idle + // here. It should be safe to tear down all state + // automatically without needing to wait for anything more to + // drain. +} + +void MLModelServiceTFLite::stop(const vsdk::ProtoStruct& extra) noexcept { + return stop(); +} + +/// @brief Stops the MLModelServiceTFLite from running. 
+void MLModelServiceTFLite::stop() noexcept { + const std::unique_lock state_wlock(state_rwmutex_); + state_.reset(); +} + +void MLModelServiceTFLite::reconfigure(const vsdk::Dependencies& dependencies, + const vsdk::ResourceConfig& configuration) { + const std::unique_lock state_wlock(state_rwmutex_); + check_stopped_inlock_(); + state_.reset(); + state_ = configure_(dependencies, configuration); +} + +std::shared_ptr MLModelServiceTFLite::infer( + const named_tensor_views& inputs, const vsdk::ProtoStruct& extra) { + // We need to lock state so we are protected against reconfiguration, but + // we don't want to block access to `metadata`. We use a shared lock here, + // and an exclusive lock to protect the interpreter itself, below. + std::shared_lock state_rlock(state_rwmutex_); + check_stopped_inlock_(); + + // Ensure that enough inputs were provided. + if (inputs.size() < state_->input_tensor_indices_by_name.size()) { + std::ostringstream buffer; + buffer << k_service_name << ": Too few inputs provided for inference: " + << state_->input_tensor_indices_by_name.size() << " expected, but got " + << inputs.size() << " instead"; + throw std::invalid_argument(buffer.str()); + } - auto state = - std::make_unique(std::move(dependencies), std::move(configuration)); + // Only one thread can actually interact with `state->interpreter` at the same time. + std::unique_lock interpreter_lock(state_->interpreter_mutex); - // Now we can begin parsing and validating the provided `configuration`. - // Pull the model path out of the configuration. - const auto& attributes = state->configuration.attributes(); - auto model_path = attributes.find("model_path"); - if (model_path == attributes.end()) { + // Walk the inputs, and copy the data from each of the input + // tensor views we were given into the associated tflite input + // tensor buffer. + for (const auto& kv : inputs) { + const auto where = state_->input_tensor_indices_by_name.find(kv.first); + if (where == state_->input_tensor_indices_by_name.end()) { std::ostringstream buffer; - buffer << service_name - << ": Required parameter `model_path` not found in configuration"; + buffer << k_service_name << ": Tensor name `" << kv.first << "`" + << " is not a known input tensor name for the model"; throw std::invalid_argument(buffer.str()); } - const auto* const model_path_string = model_path->second.get(); - if (!model_path_string || model_path_string->empty()) { + auto* const tensor = state_->interpreter->tensor(where->second); + if (!tensor) { std::ostringstream buffer; - buffer << service_name - << ": Required non-empty string parameter `model_path` is either not a string " - "or is an empty string"; + buffer << k_service_name << ": Failed to obtain tflite input tensor for `" << kv.first + << "` (index " << where->second << ")"; throw std::invalid_argument(buffer.str()); } - std::string label_path_string = ""; // default value for label_path - auto label_path = attributes.find("label_path"); - if (label_path != attributes.end()) { - const auto* const lp_string = label_path->second.get(); - if (!lp_string) { - std::ostringstream buffer; - buffer << service_name - << ": string parameter `label_path` is not a string "; - throw std::invalid_argument(buffer.str()); - } - label_path_string = *lp_string; - } - state->label_path = std::move(label_path_string); - - // Configuration parsing / extraction is complete. Move on to - // building the actual model with the provided information. - - // Try to load the provided `model_path`. 
The TFLite API - // declares that if you use `TfLiteModelCreateFromFile` that - // the file must remain unaltered during execution, but - // reconfiguration might cause it to change on disk while - // inference is in progress. Instead we read the file into a - // buffer which we can use with `TfLiteModelCreate`. That - // still requires that the buffer be kept valid, but that's - // more easily done. - const std::ifstream in(*model_path_string, std::ios::in | std::ios::binary); - if (!in) { + const auto tflite_status = + boost::apply_visitor(write_to_tflite_tensor_visitor(&kv.first, tensor), kv.second); + + if (tflite_status != TfLiteStatus::kTfLiteOk) { std::ostringstream buffer; - buffer << service_name << ": Failed to open file for `model_path` " - << *model_path_string; + buffer << k_service_name << ": input tensor `" << kv.first + << "` failed population: " << state_->interpreter_error_data; throw std::invalid_argument(buffer.str()); } - std::ostringstream model_path_contents_stream; - model_path_contents_stream << in.rdbuf(); - state->model_data = std::move(model_path_contents_stream.str()); + } + + // Invoke the interpreter and return any failure information. + if (state_->interpreter->Invoke() != TfLiteStatus::kTfLiteOk) { + std::ostringstream buffer; + buffer << k_service_name + << ": interpreter invocation failed: " << state_->interpreter_error_data; + throw std::runtime_error(buffer.str()); + } + + // A local type that we will keep on the heap to hold + // inference results until the caller is done with them. In + // our case, the caller is MLModelServiceServer, which will + // copy the data into the reply gRPC proto and then unwind. So + // we can avoid copying the data by letting the views alias + // the tensorflow tensor buffers and keep the interpreter lock + // held until the gRPC work is done. Note that this means the + // state and interpreter locks will remain held until the + // inference_result_type object tracked by the shared pointer + // we return is destroyed. Callers that want to make use of + // the inference results without keeping the interpreter + // locked would need to copy the data out of the views and + // then release the return value. + struct inference_result_type { + std::shared_lock state_rlock; + std::unique_lock interpreter_lock; + named_tensor_views views; + }; + auto inference_result = std::make_shared(); + + // Walk the outputs per our metadata and emplace an + // appropriately typed tensor_view aliasing the interpreter + // output tensor buffer into the inference results. + for (const auto& output : state_->metadata.outputs) { + const auto where = state_->output_tensor_indices_by_name.find(output.name); + if (where == state_->output_tensor_indices_by_name.end()) { + continue; // Should be impossible + } + const auto* const tflite_tensor = state_->interpreter->tensor(where->second); + inference_result->views.emplace( + output.name, std::move(tensor_views_from_tflite_tensor(output, tflite_tensor))); + } + + // The views created in the loop above are only valid until + // the interpreter lock is released, so we keep the lock held + // by moving the unique_lock into the inference_result + // object. We also need the state lock to protect our configuration. + inference_result->state_rlock = std::move(state_rlock); + inference_result->interpreter_lock = std::move(interpreter_lock); + + // Finally, construct an aliasing shared_ptr which appears to + // the caller as a shared_ptr to views, but in fact manages + // the lifetime of the inference_result. 
When the + // inference_result object is destroyed, the lock will be + // released and the next caller can invoke the interpreter. + auto* const views = &inference_result->views; + // NOLINTNEXTLINE(performance-move-const-arg): C++20 + return {std::move(inference_result), views}; +} + +struct MLModelServiceTFLite::metadata MLModelServiceTFLite::metadata( + const vsdk::ProtoStruct& extra) { + // Just return a copy of our metadata from leased state. + const std::shared_lock state_rlock(state_rwmutex_); + check_stopped_inlock_(); + return state_->metadata; +} + +void MLModelServiceTFLite::check_stopped_inlock_() const { + if (!state_) { + std::ostringstream buffer; + buffer << k_service_name << ": service is stopped: "; + throw std::runtime_error(buffer.str()); + } +} - state->model = tflite::impl::FlatBufferModel::BuildFromBuffer( - &state->model_data[0], - std::distance(cbegin(state->model_data), cend(state->model_data)), state.get()); +std::unique_ptr MLModelServiceTFLite::configure_( + vsdk::Dependencies dependencies, vsdk::ResourceConfig configuration) { + auto state = std::make_unique(std::move(dependencies), std::move(configuration)); + + // Now we can begin parsing and validating the provided `configuration`. + // Pull the model path out of the configuration. + const auto& attributes = state->configuration.attributes(); + auto model_path = attributes.find("model_path"); + if (model_path == attributes.end()) { + std::ostringstream buffer; + buffer << k_service_name << ": Required parameter `model_path` not found in configuration"; + throw std::invalid_argument(buffer.str()); + } + const auto* const model_path_string = model_path->second.get(); + if (!model_path_string || model_path_string->empty()) { + std::ostringstream buffer; + buffer << k_service_name + << ": Required non-empty string parameter `model_path` is either not a string " + "or is an empty string"; + throw std::invalid_argument(buffer.str()); + } - if (!state->model) { + std::string label_path_string = ""; // default value for label_path + auto label_path = attributes.find("label_path"); + if (label_path != attributes.end()) { + const auto* const lp_string = label_path->second.get(); + if (!lp_string) { std::ostringstream buffer; - buffer << service_name << ": Failed to load model from file `" << model_path_string - << "`: " << state->interpreter_error_data; + buffer << k_service_name << ": string parameter `label_path` is not a string "; throw std::invalid_argument(buffer.str()); } + label_path_string = *lp_string; + } + state->label_path = std::move(label_path_string); + + // Configuration parsing / extraction is complete. Move on to + // building the actual model with the provided information. + + // Try to load the provided `model_path`. The TFLite API + // declares that if you use `TfLiteModelCreateFromFile` that + // the file must remain unaltered during execution, but + // reconfiguration might cause it to change on disk while + // inference is in progress. Instead we read the file into a + // buffer which we can use with `TfLiteModelCreate`. That + // still requires that the buffer be kept valid, but that's + // more easily done. 
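+    // Here that buffer is `state->model_data`, which lives in the state
+    // object and therefore outlives both the FlatBufferModel built from it
+    // and any interpreter created against that model.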
+ const std::ifstream in(*model_path_string, std::ios::in | std::ios::binary); + if (!in) { + std::ostringstream buffer; + buffer << k_service_name << ": Failed to open file for `model_path` " << *model_path_string; + throw std::invalid_argument(buffer.str()); + } + std::ostringstream model_path_contents_stream; + model_path_contents_stream << in.rdbuf(); + state->model_data = std::move(model_path_contents_stream.str()); + + state->model = tflite::impl::FlatBufferModel::BuildFromBuffer( + &state->model_data[0], + std::distance(cbegin(state->model_data), cend(state->model_data)), + state.get()); + + if (!state->model) { + std::ostringstream buffer; + buffer << k_service_name << ": Failed to load model from file `" << model_path_string + << "`: " << state->interpreter_error_data; + throw std::invalid_argument(buffer.str()); + } - // Create an InterpreterBuilder so we can set the number of threads. - tflite::ops::builtin::BuiltinOpResolver resolver; - tflite::impl::InterpreterBuilder builder(*state->model, resolver); - - // If present, extract and validate the number of threads to - // use in the interpreter and create an interpreter options - // object to carry that information. - auto num_threads = attributes.find("num_threads"); - if (num_threads != attributes.end()) { - const auto* num_threads_double = num_threads->second.get(); - if (!num_threads_double || !std::isnormal(*num_threads_double) || - (*num_threads_double < 0) || - (*num_threads_double >= std::numeric_limits::max()) || - (std::trunc(*num_threads_double) != *num_threads_double)) { - std::ostringstream buffer; - buffer << service_name - << ": Value for field `num_threads` is not a positive integer: " - << *num_threads_double; - throw std::invalid_argument(buffer.str()); - } - if (builder.SetNumThreads(static_cast(*num_threads_double)) != kTfLiteOk) { - std::ostringstream buffer; - buffer << service_name - << ": Failed to set number of threads in interpreter builder: " << state->interpreter_error_data; - throw std::invalid_argument(buffer.str()); - } - } - - if (builder(&state->interpreter) != kTfLiteOk) { + // Create an InterpreterBuilder so we can set the number of threads. + tflite::ops::builtin::BuiltinOpResolver resolver; + tflite::impl::InterpreterBuilder builder(*state->model, resolver); + + // If present, extract and validate the number of threads to + // use in the interpreter and create an interpreter options + // object to carry that information. 
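+    //
+    // For example (hypothetical attribute values), `"num_threads": 2` passes
+    // the checks below, whereas `"num_threads": 0`, `"num_threads": -1`, and
+    // `"num_threads": 2.5` are all rejected as not being positive integers.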
+ auto num_threads = attributes.find("num_threads"); + if (num_threads != attributes.end()) { + const auto* num_threads_double = num_threads->second.get(); + if (!num_threads_double || !std::isnormal(*num_threads_double) || + (*num_threads_double < 0) || (*num_threads_double >= std::numeric_limits::max()) || + (std::trunc(*num_threads_double) != *num_threads_double)) { std::ostringstream buffer; - buffer << service_name - << ": Failed to create tflite interpreter: " << state->interpreter_error_data; - throw std::runtime_error(buffer.str()); + buffer << k_service_name + << ": Value for field `num_threads` is not a positive integer: " + << *num_threads_double; + throw std::invalid_argument(buffer.str()); } - - // Have the interpreter allocate tensors for the model - if (state->interpreter->AllocateTensors() != kTfLiteOk) { + if (builder.SetNumThreads(static_cast(*num_threads_double)) != kTfLiteOk) { std::ostringstream buffer; - buffer << service_name << ": Failed to allocate tensors for tflite interpreter: " + buffer << k_service_name << ": Failed to set number of threads in interpreter builder: " << state->interpreter_error_data; - throw std::runtime_error(buffer.str()); + throw std::invalid_argument(buffer.str()); } - - // Walk the input tensors now that they have been allocated - // and extract information about tensor names, types, and - // dimensions. Apply any tensor renamings per our - // configuration. Stash the relevant data in our `metadata` - // fields. - const auto input_tensor_indices = state->interpreter->inputs(); - for (auto input_tensor_index : input_tensor_indices) { - const auto* const tensor = state->interpreter->tensor(input_tensor_index); - - auto ndims = TfLiteTensorNumDims(tensor); - if (ndims == -1) { - std::ostringstream buffer; - buffer << service_name - << ": Unable to determine input tensor shape at configuration time, " - "inference not possible"; - throw std::runtime_error(buffer.str()); - } - - MLModelService::tensor_info input_info; - const auto* name = TfLiteTensorName(tensor); - input_info.name = name; - input_info.data_type = - service_data_type_from_tflite_data_type_(TfLiteTensorType(tensor)); - for (decltype(ndims) j = 0; j != ndims; ++j) { - input_info.shape.push_back(TfLiteTensorDim(tensor, j)); - } - state->input_tensor_indices_by_name[input_info.name] = input_tensor_index; - state->metadata.inputs.emplace_back(std::move(input_info)); - } - - // NOTE: The tflite C API docs state that information about - // output tensors may not be available until after one round - // of inference. We do a best effort inference on all zero - // inputs to try to account for this. - for (auto input_tensor_index : input_tensor_indices) { - auto* const tensor = state->interpreter->tensor(input_tensor_index); - const auto tensor_size = TfLiteTensorByteSize(tensor); - const std::vector zero_buffer(tensor_size, 0); - TfLiteTensorCopyFromBuffer(tensor, &zero_buffer[0], tensor_size); - } - - if (state->interpreter->Invoke() != TfLiteStatus::kTfLiteOk) { - // TODO: After C++ SDK 0.11.0 is released, use the new logging API. - std::cout << "WARNING: Inference with all zero input tensors failed: returned output tensor metadata may be unreliable" << std::endl; - } - - // Now that we have hopefully done one round of inference, dig out the actual - // metadata that we will return to clients. 
- const auto output_tensor_indices = state->interpreter->outputs(); - for (auto output_tensor_index : output_tensor_indices) { - const auto* const tensor = state->interpreter->tensor(output_tensor_index); - - auto ndims = TfLiteTensorNumDims(tensor); - if (ndims == -1) { - std::ostringstream buffer; - buffer << service_name - << ": Unable to determine output tensor shape at configuration time, " - "inference not possible"; - throw std::runtime_error(buffer.str()); - } - - MLModelService::tensor_info output_info; - const auto* name = TfLiteTensorName(tensor); - output_info.name = name; - output_info.data_type = - service_data_type_from_tflite_data_type_(TfLiteTensorType(tensor)); - for (decltype(ndims) j = 0; j != ndims; ++j) { - output_info.shape.push_back(TfLiteTensorDim(tensor, j)); - } - if (state->label_path != "") { - output_info.extra.insert({"labels", state->label_path}); - } - state->output_tensor_indices_by_name[output_info.name] = output_tensor_index; - state->metadata.outputs.emplace_back(std::move(output_info)); - } - - return state; } - // Converts from tflites type enumeration into the model service - // type enumeration or throws if there is no such conversion. - static MLModelService::tensor_info::data_types service_data_type_from_tflite_data_type_( - TfLiteType type) { - switch (type) { - case kTfLiteInt8: { - return MLModelService::tensor_info::data_types::k_int8; - } - case kTfLiteUInt8: { - return MLModelService::tensor_info::data_types::k_uint8; - } - case kTfLiteInt16: { - return MLModelService::tensor_info::data_types::k_int16; - } - case kTfLiteUInt16: { - return MLModelService::tensor_info::data_types::k_uint16; - } - case kTfLiteInt32: { - return MLModelService::tensor_info::data_types::k_int32; - } - case kTfLiteUInt32: { - return MLModelService::tensor_info::data_types::k_uint32; - } - case kTfLiteInt64: { - return MLModelService::tensor_info::data_types::k_int64; - } - case kTfLiteUInt64: { - return MLModelService::tensor_info::data_types::k_uint64; - } - case kTfLiteFloat32: { - return MLModelService::tensor_info::data_types::k_float32; - } - case kTfLiteFloat64: { - return MLModelService::tensor_info::data_types::k_float64; - } - default: { - std::ostringstream buffer; - buffer << service_name << ": Model contains unsupported tflite data type" << type; - throw std::invalid_argument(buffer.str()); - } - } + if (builder(&state->interpreter) != kTfLiteOk) { + std::ostringstream buffer; + buffer << k_service_name + << ": Failed to create tflite interpreter: " << state->interpreter_error_data; + throw std::runtime_error(buffer.str()); } - // All of the meaningful internal state of the service is held in - // a separate state object to help ensure clean replacement of our - // internals during reconfiguration. 
- struct state_ final : public tflite::ErrorReporter { - explicit state_(vsdk::Dependencies dependencies, vsdk::ResourceConfig configuration) - : dependencies(std::move(dependencies)), configuration(std::move(configuration)) {} + // Have the interpreter allocate tensors for the model + if (state->interpreter->AllocateTensors() != kTfLiteOk) { + std::ostringstream buffer; + buffer << k_service_name << ": Failed to allocate tensors for tflite interpreter: " + << state->interpreter_error_data; + throw std::runtime_error(buffer.str()); + } - int Report(const char *format, va_list args) override { - char buffer[4096]; - static_cast(vsnprintf(buffer, sizeof(buffer), format, args)); - interpreter_error_data = buffer; - return 0; + // Walk the input tensors now that they have been allocated + // and extract information about tensor names, types, and + // dimensions. Apply any tensor renamings per our + // configuration. Stash the relevant data in our `metadata` + // fields. + const auto input_tensor_indices = state->interpreter->inputs(); + for (auto input_tensor_index : input_tensor_indices) { + const auto* const tensor = state->interpreter->tensor(input_tensor_index); + + auto ndims = TfLiteTensorNumDims(tensor); + if (ndims == -1) { + std::ostringstream buffer; + buffer << k_service_name + << ": Unable to determine input tensor shape at configuration time, " + "inference not possible"; + throw std::runtime_error(buffer.str()); } - // The dependencies and configuration we were given at - // construction / reconfiguration. - vsdk::Dependencies dependencies; - vsdk::ResourceConfig configuration; - - // This data must outlive any interpreters created from the - // model we build against model data. - std::string model_data; - std::unique_ptr model; - - // Metadata about input and output tensors that was extracted - // during configuration. Callers need this in order to know - // how to interact with the service. - struct MLModelService::metadata metadata; - - // The label path is a file that relates the outputs of the label tensor ints to strings - std::string label_path; - - // Maps from string names of tensors to the numeric - // value. Note that the keys here are the renamed tensors, if - // applicable. - std::unordered_map input_tensor_indices_by_name; - std::unordered_map output_tensor_indices_by_name; - - // Protects interpreter_error_data and interpreter - std::mutex interpreter_mutex; - - // The `Report` method will overwrite this string. - std::string interpreter_error_data; - - // The interpreter itself. - std::unique_ptr interpreter; - }; - - // A visitor that can populate a TFLiteTensor given a MLModelService::tensor_view. 
- class write_to_tflite_tensor_visitor_ : public boost::static_visitor { - public: - write_to_tflite_tensor_visitor_(const std::string* name, TfLiteTensor* tflite_tensor) - : name_(name), tflite_tensor_(tflite_tensor) {} - - template - TfLiteStatus operator()(const T& mlmodel_tensor) const { - const auto expected_size = TfLiteTensorByteSize(tflite_tensor_); - const auto* const mlmodel_data_begin = - reinterpret_cast(mlmodel_tensor.data()); - const auto* const mlmodel_data_end = reinterpret_cast( - mlmodel_tensor.data() + mlmodel_tensor.size()); - const auto mlmodel_data_size = - static_cast(mlmodel_data_end - mlmodel_data_begin); - if (expected_size != mlmodel_data_size) { - std::ostringstream buffer; - buffer << service_name << ": tensor `" << *name_ - << "` was expected to have byte size " << expected_size << " but " - << mlmodel_data_size << " bytes were provided"; - throw std::invalid_argument(buffer.str()); - } - return TfLiteTensorCopyFromBuffer(tflite_tensor_, mlmodel_data_begin, expected_size); - } - - private: - const std::string* name_; - TfLiteTensor* tflite_tensor_; - }; - - // Creates a tensor_view which views a tflite tensor buffer. It dispatches on the - // type and delegates to the templated version below. - MLModelService::tensor_views make_tensor_view_(const MLModelService::tensor_info& info, - const TfLiteTensor* const tflite_tensor) { - const auto tflite_tensor_type = TfLiteTensorType(tflite_tensor); - switch (tflite_tensor_type) { - case kTfLiteInt8: { - return make_tensor_view_t_(info, tflite_tensor); - } - case kTfLiteUInt8: { - return make_tensor_view_t_(info, tflite_tensor); - } - case kTfLiteInt16: { - return make_tensor_view_t_(info, tflite_tensor); - } - case kTfLiteUInt16: { - return make_tensor_view_t_(info, tflite_tensor); - } - case kTfLiteInt32: { - return make_tensor_view_t_(info, tflite_tensor); - } - case kTfLiteUInt32: { - return make_tensor_view_t_(info, tflite_tensor); - } - case kTfLiteInt64: { - return make_tensor_view_t_(info, tflite_tensor); - } - case kTfLiteUInt64: { - return make_tensor_view_t_(info, tflite_tensor); - } - case kTfLiteFloat32: { - return make_tensor_view_t_(info, tflite_tensor); - } - case kTfLiteFloat64: { - return make_tensor_view_t_(info, tflite_tensor); - } - default: { - std::ostringstream buffer; - buffer << service_name - << ": Model returned unsupported tflite data type: " << tflite_tensor_type; - throw std::invalid_argument(buffer.str()); - } + MLModelService::tensor_info input_info; + const auto* name = TfLiteTensorName(tensor); + input_info.name = name; + input_info.data_type = service_data_type_from_tflite_data_type(TfLiteTensorType(tensor)); + for (decltype(ndims) j = 0; j != ndims; ++j) { + input_info.shape.push_back(TfLiteTensorDim(tensor, j)); } + state->input_tensor_indices_by_name[input_info.name] = input_tensor_index; + state->metadata.inputs.emplace_back(std::move(input_info)); } - // The type specific version of the above function, it just - // reinterpret_casts the tensor buffer into an MLModelService - // tensor view and applies the necessary shape info. - template - MLModelService::tensor_views make_tensor_view_t_(const MLModelService::tensor_info& info, - const TfLiteTensor* const tflite_tensor) { - const auto* const tensor_data = reinterpret_cast(TfLiteTensorData(tflite_tensor)); - const auto tensor_size_bytes = TfLiteTensorByteSize(tflite_tensor); - const auto tensor_size_t = tensor_size_bytes / sizeof(T); - // TODO: We are just feeding back out what we cached in the - // metadata for shape. 
Should this instead be re-querying the - // output tensor NumDims / DimN after each invocation in case - // the shape is dynamic? The possibility of a dynamically - // sized extent is why we represent the dimensions as signed - // quantities in the tensor metadata. But an actual tensor has - // a real extent. How would tflite ever communicate that to us - // differently given that we use the same API to obtain - // metadata as we would here? - std::vector shape; - shape.reserve(info.shape.size()); - for (const auto s : info.shape) { - shape.push_back(static_cast(s)); - } - return MLModelService::make_tensor_view(tensor_data, tensor_size_t, std::move(shape)); + // NOTE: The tflite C API docs state that information about + // output tensors may not be available until after one round + // of inference. We do a best effort inference on all zero + // inputs to try to account for this. + for (auto input_tensor_index : input_tensor_indices) { + auto* const tensor = state->interpreter->tensor(input_tensor_index); + const auto tensor_size = TfLiteTensorByteSize(tensor); + const std::vector zero_buffer(tensor_size, 0); + TfLiteTensorCopyFromBuffer(tensor, &zero_buffer[0], tensor_size); } - // Accesss to the module state is serialized. All configuration - // state is held in the `state` type to make it easier to destroy - // the current state and replace it with a new one. - std::shared_mutex state_rwmutex_; - - // In C++17, this could be `std::optional`. - std::unique_ptr state_; -}; + if (state->interpreter->Invoke() != TfLiteStatus::kTfLiteOk) { + // TODO: After C++ SDK 0.11.0 is released, use the new logging API. + std::cout << "WARNING: Inference with all zero input tensors failed: returned output " + "tensor metadata may be unreliable" + << std::endl; + } -int serve(const std::string& socket_path) try { - // Every Viam C++ SDK program must have one and only one Instance object which is created before - // any other C++ SDK objects and stays alive until all Viam C++ SDK objects are destroyed. - vsdk::Instance inst; - - // Create a new model registration for the service. - auto module_registration = std::make_shared( - // Identify that this resource offers the MLModelService API - vsdk::API::get(), - - // Declare a model triple for this service. - vsdk::Model{"viam", "mlmodel-tflite", "tflite_cpu"}, - - // Define the factory for instances of the resource. - [](vsdk::Dependencies deps, vsdk::ResourceConfig config) { - return std::make_shared(std::move(deps), std::move(config)); - }); - - // Register the newly created registration with the Registry. - vsdk::Registry::get().register_model(module_registration); - - // Construct the module service and tell it where to place the socket path. - auto module_service = std::make_shared(socket_path); - - // Add the server as providing the API and model declared in the - // registration. - module_service->add_model_from_registry(module_registration->api(), - module_registration->model()); - - // Start the module service. - module_service->serve(); - - return EXIT_SUCCESS; -} catch (const std::exception& ex) { - std::cout << "ERROR: A std::exception was thrown from `serve`: " << ex.what() << std::endl; - return EXIT_FAILURE; -} catch (...) { - std::cout << "ERROR: An unknown exception was thrown from `serve`" << std::endl; - return EXIT_FAILURE; -} + // Now that we have hopefully done one round of inference, dig out the actual + // metadata that we will return to clients. 
+ const auto output_tensor_indices = state->interpreter->outputs(); + for (auto output_tensor_index : output_tensor_indices) { + const auto* const tensor = state->interpreter->tensor(output_tensor_index); -} // namespace - -int main(int argc, char* argv[]) { - const std::string usage = "usage: mlmodelservice_tflite /path/to/unix/socket"; + auto ndims = TfLiteTensorNumDims(tensor); + if (ndims == -1) { + std::ostringstream buffer; + buffer << k_service_name + << ": Unable to determine output tensor shape at configuration time, " + "inference not possible"; + throw std::runtime_error(buffer.str()); + } - if (argc < 2) { - std::cout << "ERROR: insufficient arguments\n"; - std::cout << usage << "\n"; - return EXIT_FAILURE; + MLModelService::tensor_info output_info; + const auto* name = TfLiteTensorName(tensor); + output_info.name = name; + output_info.data_type = service_data_type_from_tflite_data_type(TfLiteTensorType(tensor)); + for (decltype(ndims) j = 0; j != ndims; ++j) { + output_info.shape.push_back(TfLiteTensorDim(tensor, j)); + } + if (state->label_path != "") { + output_info.extra.insert({"labels", state->label_path}); + } + state->output_tensor_indices_by_name[output_info.name] = output_tensor_index; + state->metadata.outputs.emplace_back(std::move(output_info)); } - return serve(argv[1]); + return state; } + +} // namespace mlmodel_tflite +} // namespace viam \ No newline at end of file diff --git a/src/tflite_cpu.hpp b/src/tflite_cpu.hpp new file mode 100644 index 0000000..d22ca49 --- /dev/null +++ b/src/tflite_cpu.hpp @@ -0,0 +1,62 @@ +#include + +#include +#include +#include + +namespace viam { +namespace mlmodel_tflite { + +namespace vsdk = ::viam::sdk; + +// An MLModelService instance which runs TensorFlow Lite models. +// +// Configuration requires the following parameters: +// -- `model_path`: An absolute filesystem path to a TensorFlow Lite model file. +// +// The following optional parameters are honored: +// -- `num_threads`: Sets the number of threads to be used, where applicable. +// +// -- `label_path`: An absolute filesystem path to a .txt file of the model's category labels. +// +// Any additional configuration fields are ignored. +class MLModelServiceTFLite final : public vsdk::MLModelService, + public vsdk::Stoppable, + public vsdk::Reconfigurable { + public: + MLModelServiceTFLite(vsdk::Dependencies dependencies, vsdk::ResourceConfig configuration); + + ~MLModelServiceTFLite() final; + + void stop(const vsdk::ProtoStruct& extra) noexcept final; + + /// @brief Stops the MLModelServiceTFLite from running. + void stop() noexcept; + + void reconfigure(const vsdk::Dependencies& dependencies, + const vsdk::ResourceConfig& configuration) final; + + std::shared_ptr infer(const named_tensor_views& inputs, + const vsdk::ProtoStruct& extra) final; + + struct metadata metadata(const vsdk::ProtoStruct& extra) final; + + private: + struct state_; + + void check_stopped_inlock_() const; + + static std::unique_ptr configure_(vsdk::Dependencies dependencies, + vsdk::ResourceConfig configuration); + + // Accesss to the module state is serialized. All configuration + // state is held in the `state` type to make it easier to destroy + // the current state and replace it with a new one. + std::shared_mutex state_rwmutex_; + + // In C++17, this could be `std::optional`. + std::unique_ptr state_; +}; + +} // namespace mlmodel_tflite +} \ No newline at end of file