Skip to content

Commit 7681c3e

Browse files
LukeBoyertensorflower-gardener
authored andcommitted
Change the compile function type to take a model rather than a list of subgraphs
PiperOrigin-RevId: 723715188
1 parent a7814f8 commit 7681c3e

16 files changed

+88
-72
lines changed

tensorflow/lite/experimental/litert/compiler/plugin/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ cc_library(
3737
"//tensorflow/lite/experimental/litert/core/model",
3838
"//tensorflow/lite/experimental/litert/core/model:ir_allocator",
3939
"//tensorflow/lite/experimental/litert/core/model:model_serialize",
40+
"//tensorflow/lite/experimental/litert/core/util:flatbuffer_tools",
4041
"//tensorflow/lite/experimental/litert/vendors/c:litert_compiler_plugin",
4142
"//tensorflow/lite/experimental/litert/vendors/c:litert_compiler_plugin_api",
4243
"@com_google_absl//absl/log:absl_check",

tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin.cc

+21-5
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include <array>
1919
#include <cstddef>
2020
#include <cstdint>
21+
#include <memory>
2122
#include <optional>
2223
#include <string>
2324
#include <utility>
@@ -46,6 +47,7 @@
4647
#include "tensorflow/lite/experimental/litert/core/model/ir_allocator.h"
4748
#include "tensorflow/lite/experimental/litert/core/model/model.h"
4849
#include "tensorflow/lite/experimental/litert/core/model/model_serialize.h"
50+
#include "tensorflow/lite/experimental/litert/core/util/flatbuffer_tools.h"
4951
#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h"
5052
#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin_api.h"
5153

@@ -339,16 +341,16 @@ Expected<std::vector<LiteRtOp>> CompilerPlugin::Partition(
339341
return ops.Vec();
340342
}
341343

342-
Expected<CompiledResult> CompilerPlugin::Compile(
343-
absl::Span<LiteRtSubgraph> partitions, absl::string_view soc_model) {
344+
Expected<CompiledResult> CompilerPlugin::Compile(LiteRtModel partitions,
345+
absl::string_view soc_model) {
344346
CompiledResult result = MakeResult();
345347
// If the user has passed an soc_model, then we use it; otherwise we let the
346348
// backend pick the appropriate one by passing nullptr as soc_model. This is
347349
// important for on-device compilation, where the backend must determine the
348350
// SoC model based on the user device.
349351
const char* soc_model_str = !soc_model.empty() ? soc_model.data() : nullptr;
350352
LITERT_RETURN_IF_ERROR(plugin_api_.compiler_plugin_compile(
351-
plugin_handle_, soc_model_str, partitions.data(), partitions.size(),
353+
plugin_handle_, soc_model_str, partitions,
352354
&result.compiled_result_handle_));
353355
return result;
354356
}
@@ -408,9 +410,23 @@ Expected<void> ApplyPlugin(CompilerPlugin& compiler_plugin, LiteRtModelT& model,
408410
auto& dispatch_ops = partitions->first;
409411
auto& subgraphs = partitions->second;
410412

413+
// Wrap the partitioned subgraphs in a LiteRtModel.
414+
LiteRtModelT sliced_model;
415+
sliced_model.TransferSubgraphs(std::move(subgraphs));
416+
417+
// Copy op codes.
418+
const auto& op_codes = detail::GetTflOpCodes(model);
419+
420+
LiteRtModelT::TflOpCodes codes;
421+
codes.reserve(op_codes.size());
422+
for (const auto& op_code : op_codes) {
423+
codes.emplace_back(std::make_unique<TflOpCode>(*op_code));
424+
}
425+
426+
detail::SetTflOpCodes(sliced_model, std::move(codes));
427+
411428
// Pass sliced subgraphs to plugin for compilation.
412-
auto compiled_result =
413-
compiler_plugin.Compile(subgraphs.Elements(), soc_model);
429+
auto compiled_result = compiler_plugin.Compile(&sliced_model, soc_model);
414430
if (!compiled_result) {
415431
return compiled_result.Error();
416432
}

tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ class CompilerPlugin {
103103

104104
// Compile given LiteRtSubgraphs. Result object must be outlived by
105105
// this CompilerPlugin.
106-
Expected<CompiledResult> Compile(absl::Span<LiteRtSubgraph> partitions,
106+
Expected<CompiledResult> Compile(LiteRtModel partitions,
107107
absl::string_view soc_model = "");
108108

109109
// Search for shared library files with prefix "libLiteRtCompilerPlugin" in

tensorflow/lite/experimental/litert/compiler/plugin/compiler_plugin_test.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ TEST(CompilerPluginTest, Compile) {
122122
auto model_wrap = testing::LoadTestFileModel("mul_simple.tflite");
123123
auto& model = *model_wrap.Get();
124124

125-
auto result = plugins->front().Compile(model.Subgraphs());
125+
auto result = plugins->front().Compile(&model);
126126
ASSERT_TRUE(result);
127127

128128
auto byte_code = result->ByteCode();

tensorflow/lite/experimental/litert/tools/apply_plugin.cc

+1-2
Original file line numberDiff line numberDiff line change
@@ -361,8 +361,7 @@ LiteRtStatus Compile(Context& ctx) {
361361
ctx.Dump().Start("Compiling");
362362
DumpCompilationRequest(ctx.Dump(), ctx.SocModelTarget(),
363363
model.NumSubgraphs());
364-
auto compilation_result =
365-
plugin->Compile(model.Subgraphs(), ctx.SocModelTarget());
364+
auto compilation_result = plugin->Compile(&model, ctx.SocModelTarget());
366365
if (!compilation_result) {
367366
ctx.Dump().Fail();
368367
return compilation_result.Error().Status();

tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h

+4-6
Original file line numberDiff line numberDiff line change
@@ -65,14 +65,12 @@ LiteRtStatus LiteRtCompilerPluginPartition(LiteRtCompilerPlugin compiler_plugin,
6565
LiteRtSubgraph subgraph,
6666
LiteRtOpList selected_ops);
6767

68-
// Prepare result to pass to the runtime for given partition and, optionally,
69-
// for a given SoC model (parameter `soc_model` can be NULL to specify a default
70-
// SoC model). The given subgraphs are valid sub-DAG within the ops selected in
71-
// partition step.
68+
// Prepare result to pass to the runtime for given model containing partitioned
69+
// subgraphs. Optionally, handles a SoC model (parameter `soc_model` can be NULL
70+
// to specify a default SoC model).
7271
LiteRtStatus LiteRtCompilerPluginCompile(LiteRtCompilerPlugin compiler_plugin,
7372
const char* soc_model,
74-
LiteRtSubgraph* partitions,
75-
LiteRtParamIndex num_partitions,
73+
LiteRtModel partitions,
7674
LiteRtCompiledResult* compiled_result);
7775

7876
//

tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin_api.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ typedef LiteRtStatus (*LiteRtCompilerPluginPartitionT)(
5454
LiteRtCompilerPlugin, LiteRtSubgraph subgraph, LiteRtOpList selected_ops);
5555

5656
typedef LiteRtStatus (*LiteRtCompilerPluginCompileT)(
57-
LiteRtCompilerPlugin, const char* soc_model, LiteRtSubgraph* partitions,
58-
LiteRtParamIndex num_partitions, LiteRtCompiledResult* compiled_result);
57+
LiteRtCompilerPlugin, const char* soc_model, LiteRtModel partitions,
58+
LiteRtCompiledResult* compiled_result);
5959

6060
typedef void (*LiteRtDestroyCompiledResultT)(LiteRtCompiledResult);
6161

tensorflow/lite/experimental/litert/vendors/examples/example_plugin.cc

+5-3
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,14 @@ LiteRtStatus CompileSinglePartition(LiteRtParamIndex partition_index,
8888

8989
LiteRtStatus LiteRtCompilerPluginCompile(
9090
LiteRtCompilerPlugin compiler_plugin, const char* soc_model,
91-
LiteRtSubgraph* partitions, LiteRtParamIndex num_partitions,
92-
LiteRtCompiledResult* compiled_result) {
91+
LiteRtModel partitions, LiteRtCompiledResult* compiled_result) {
9392
LiteRtCompiledResult result = new LiteRtCompiledResultT;
9493

94+
auto model = litert::Model::CreateFromNonOwnedHandle(partitions);
95+
const auto num_partitions = model.NumSubgraphs();
9596
for (auto i = 0; i < num_partitions; ++i) {
96-
LITERT_RETURN_IF_ERROR(CompileSinglePartition(i, partitions[i], *result));
97+
LITERT_RETURN_IF_ERROR(
98+
CompileSinglePartition(i, model.Subgraph(i)->Get(), *result));
9799
}
98100

99101
*compiled_result = result;

tensorflow/lite/experimental/litert/vendors/examples/example_plugin_test.cc

+1-4
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,9 @@ TEST(TestCallDummyPlugin, CompileMulSubgraph) {
6565
auto plugin = CreatePlugin();
6666
auto model = testing::LoadTestFileModel("mul_simple.tflite");
6767

68-
auto main_subgraph = model.MainSubgraph();
69-
LiteRtSubgraph litert_subgraph = main_subgraph->Get();
70-
7168
LiteRtCompiledResult compiled;
7269
LITERT_ASSERT_OK(LiteRtCompilerPluginCompile(
73-
plugin.get(), /*soc_model=*/nullptr, &litert_subgraph, 1, &compiled));
70+
plugin.get(), /*soc_model=*/nullptr, model.Get(), &compiled));
7471

7572
const void* byte_code;
7673
size_t byte_code_size;

tensorflow/lite/experimental/litert/vendors/examples/example_plugin_with_conversions.cc

+4-3
Original file line numberDiff line numberDiff line change
@@ -118,15 +118,16 @@ LiteRtStatus CompileSinglePartition(
118118
// infrastructure.
119119
LiteRtStatus LiteRtCompilerPluginCompile(
120120
LiteRtCompilerPlugin compiler_plugin, const char* soc_model,
121-
LiteRtSubgraph* partitions, LiteRtParamIndex num_partitions,
122-
LiteRtCompiledResult* compiled_result) {
121+
LiteRtModel partitions, LiteRtCompiledResult* compiled_result) {
123122
auto* result = new LiteRtCompiledResultT;
124123

124+
auto model = litert::Model::CreateFromNonOwnedHandle(partitions);
125+
const auto num_partitions = model.NumSubgraphs();
125126
for (auto i = 0; i < num_partitions; ++i) {
126127
auto name = absl::StrFormat("partition_%lu", i);
127128
LITERT_RETURN_IF_ERROR(
128129
CompileSinglePartition(compiler_plugin->legalizations, std::move(name),
129-
partitions[i], *result));
130+
model.Subgraph(i)->Get(), *result));
130131
}
131132

132133
*compiled_result = result;

tensorflow/lite/experimental/litert/vendors/examples/example_plugin_with_conversions_test.cc

+1-5
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,9 @@ TEST(ExamplePluginWithConvertTypesTest, CompileMulSubgraph) {
6969
auto plugin = CreatePlugin();
7070
auto model = litert::testing::LoadTestFileModel("mul_simple.tflite");
7171

72-
auto main_subgraph = model.MainSubgraph();
73-
LiteRtSubgraph litert_subgraph = main_subgraph->Get();
74-
7572
LiteRtCompiledResult compiled;
7673
LITERT_ASSERT_OK(LiteRtCompilerPluginCompile(
77-
plugin.get(), /*soc_model=*/nullptr, &litert_subgraph,
78-
/*num_partitions*/ 1, &compiled));
74+
plugin.get(), /*soc_model=*/nullptr, model.Get(), &compiled));
7975

8076
const void* byte_code;
8177
size_t byte_code_size;

tensorflow/lite/experimental/litert/vendors/google_tensor/compiler/google_tensor_compiler_plugin_test.cc

+2-5
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,9 @@ TEST(TestCallGoogleTensorPlugin, CompileMulSubgraph) {
6565
auto plugin = CreatePlugin();
6666
auto model = testing::LoadTestFileModel("mul_simple.tflite");
6767

68-
auto main_subgraph = model.MainSubgraph();
69-
LiteRtSubgraph litert_subgraph = main_subgraph->Get();
70-
7168
LiteRtCompiledResult compiled;
72-
LITERT_ASSERT_OK(LiteRtCompilerPluginCompile(plugin.get(), "P25",
73-
&litert_subgraph, 1, &compiled));
69+
LITERT_ASSERT_OK(
70+
LiteRtCompilerPluginCompile(plugin.get(), "P25", model.Get(), &compiled));
7471

7572
LiteRtDestroyCompiledResult(compiled);
7673
} // Todo(abhirs): activate this test once the compiler wrapper is updated

tensorflow/lite/experimental/litert/vendors/mediatek/compiler/compiler_plugin.cc

+6-4
Original file line numberDiff line numberDiff line change
@@ -283,8 +283,10 @@ Expected<std::vector<uint8_t>> CompilePartition(
283283

284284
LiteRtStatus LiteRtCompilerPluginCompile(
285285
LiteRtCompilerPlugin compiler_plugin, const char* soc_model,
286-
LiteRtSubgraph* partitions, LiteRtParamIndex num_partitions,
287-
LiteRtCompiledResult* compiled_result) {
286+
LiteRtModel partitions, LiteRtCompiledResult* compiled_result) {
287+
auto model = litert::Model::CreateFromNonOwnedHandle(partitions);
288+
const auto num_partitions = model.NumSubgraphs();
289+
288290
LITERT_LOG(LITERT_INFO,
289291
"Starting MediaTek Compilation for %d subgraphs, soc_model=%s",
290292
num_partitions, soc_model);
@@ -306,11 +308,11 @@ LiteRtStatus LiteRtCompilerPluginCompile(
306308
}
307309

308310
auto result = std::make_unique<LiteRtCompiledResultT>();
311+
309312
for (auto i = 0; i < num_partitions; ++i) {
310-
auto partition = litert::Subgraph(partitions[i]);
311313
auto graph_name = absl::StrFormat("Partition_%d", i);
312314
auto bytecode =
313-
CompilePartition(**api, partition, graph_name, opt_soc_model);
315+
CompilePartition(**api, *model.Subgraph(i), graph_name, opt_soc_model);
314316
if (!bytecode) {
315317
LITERT_LOG(LITERT_INFO, "%s", bytecode.Error().Message().c_str());
316318
return bytecode.Error().Status();

tensorflow/lite/experimental/litert/vendors/mediatek/compiler/compiler_plugin_test.cc

+28-18
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
#include "tensorflow/lite/experimental/litert/cc/litert_macros.h"
2626
#include "tensorflow/lite/experimental/litert/core/model/model.h"
2727
#include "tensorflow/lite/experimental/litert/test/common.h"
28-
#include "tensorflow/lite/experimental/litert/test/matchers.h"
2928
#include "tensorflow/lite/experimental/litert/test/test_models.h"
3029
#include "tensorflow/lite/experimental/litert/vendors/c/litert_compiler_plugin.h"
3130
#include "tensorflow/lite/experimental/litert/vendors/cc/litert_compiler_plugin.h"
@@ -43,18 +42,24 @@ const auto kSupportedOps = Values(
4342
// clang-format on
4443

4544
TEST(TestQnnPlugin, GetConfigInfo) {
45+
#ifndef __ANDROID__
46+
GTEST_SKIP() << "Loading shared lib not currently supported on linux.";
47+
#endif // __ANDROID__
48+
4649
EXPECT_STREQ(LiteRtGetCompilerPluginSocManufacturer(), "MediaTek");
4750

4851
auto plugin = CreatePlugin();
4952

5053
LiteRtParamIndex num_supported_soc_models;
51-
LITERT_ASSERT_OK(LiteRtGetNumCompilerPluginSupportedSocModels(
52-
plugin.get(), &num_supported_soc_models));
54+
ASSERT_EQ(LiteRtGetNumCompilerPluginSupportedSocModels(
55+
plugin.get(), &num_supported_soc_models),
56+
kLiteRtStatusOk);
5357
ASSERT_EQ(num_supported_soc_models, 12);
5458

5559
const char* config_id;
56-
LITERT_ASSERT_OK(
57-
LiteRtGetCompilerPluginSupportedSocModel(plugin.get(), 0, &config_id));
60+
ASSERT_EQ(
61+
LiteRtGetCompilerPluginSupportedSocModel(plugin.get(), 0, &config_id),
62+
kLiteRtStatusOk);
5863
EXPECT_STREQ(config_id, "mt6853");
5964
}
6065

@@ -63,8 +68,9 @@ TEST(TestQnnPlugin, PartitionAdd) {
6368
auto model = testing::LoadTestFileModel("add_simple.tflite");
6469

6570
LiteRtOpListT selected_op_list;
66-
LITERT_ASSERT_OK(LiteRtCompilerPluginPartition(
67-
plugin.get(), model.Subgraph(0)->Get(), &selected_op_list));
71+
ASSERT_EQ(LiteRtCompilerPluginPartition(
72+
plugin.get(), model.Subgraph(0)->Get(), &selected_op_list),
73+
kLiteRtStatusOk);
6874
const auto selected_ops = selected_op_list.Vec();
6975

7076
ASSERT_EQ(selected_ops.size(), 1);
@@ -77,27 +83,30 @@ class MtkPluginOpCompatibilityTest
7783
: public ::testing::TestWithParam<std::string> {};
7884

7985
TEST_P(MtkPluginOpCompatibilityTest, SupportedOpsTest) {
86+
#ifndef __ANDROID__
87+
GTEST_SKIP() << "Loading shared lib not currently supported on linux.";
88+
#endif // __ANDROID__
89+
8090
LITERT_LOG(LITERT_INFO, "Testing TFLite model: %s", GetParam().c_str());
8191
auto plugin = CreatePlugin();
8292
auto model = testing::LoadTestFileModel(GetParam());
8393

84-
const auto subgraph = model.MainSubgraph();
85-
LiteRtSubgraph litert_subgraph = subgraph->Get();
86-
8794
LiteRtCompiledResult compiled;
88-
LITERT_ASSERT_OK(LiteRtCompilerPluginCompile(
89-
plugin.get(), /*soc_model=*/nullptr, &litert_subgraph, 1, &compiled));
95+
ASSERT_EQ(LiteRtCompilerPluginCompile(plugin.get(), /*soc_model=*/nullptr,
96+
model.Get(), &compiled),
97+
kLiteRtStatusOk);
9098

9199
LiteRtParamIndex num_byte_code;
92-
LITERT_ASSERT_OK(
93-
LiteRtCompiledResultNumByteCodeModules(compiled, &num_byte_code));
100+
ASSERT_EQ(LiteRtCompiledResultNumByteCodeModules(compiled, &num_byte_code),
101+
kLiteRtStatusOk);
94102
ASSERT_EQ(num_byte_code, 1);
95103

96104
const void* byte_code;
97105
size_t byte_code_size;
98106

99-
LITERT_ASSERT_OK(LiteRtGetCompiledResultByteCode(
100-
compiled, /*byte_code_idx=*/0, &byte_code, &byte_code_size));
107+
ASSERT_EQ(LiteRtGetCompiledResultByteCode(compiled, /*byte_code_idx=*/0,
108+
&byte_code, &byte_code_size),
109+
kLiteRtStatusOk);
101110

102111
absl::string_view byte_code_string(reinterpret_cast<const char*>(byte_code),
103112
byte_code_size);
@@ -107,8 +116,9 @@ TEST_P(MtkPluginOpCompatibilityTest, SupportedOpsTest) {
107116
size_t op_data_size;
108117
LiteRtParamIndex byte_code_idx;
109118

110-
LITERT_ASSERT_OK(LiteRtGetCompiledResultCallInfo(
111-
compiled, /*call_idx=*/0, &op_data, &op_data_size, &byte_code_idx));
119+
ASSERT_EQ(LiteRtGetCompiledResultCallInfo(compiled, /*call_idx=*/0, &op_data,
120+
&op_data_size, &byte_code_idx),
121+
kLiteRtStatusOk);
112122

113123
EXPECT_EQ(byte_code_idx, 0);
114124

tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compiler_plugin.cc

+6-3
Original file line numberDiff line numberDiff line change
@@ -257,8 +257,10 @@ LiteRtStatus LiteRtCompilerPluginPartition(LiteRtCompilerPlugin compiler_plugin,
257257

258258
LiteRtStatus LiteRtCompilerPluginCompile(
259259
LiteRtCompilerPlugin compiler_plugin, const char* soc_model,
260-
LiteRtSubgraph* partitions, LiteRtParamIndex num_partitions,
261-
LiteRtCompiledResult* compiled_result) {
260+
LiteRtModel partitions, LiteRtCompiledResult* compiled_result) {
261+
auto model = litert::Model::CreateFromNonOwnedHandle(partitions);
262+
const auto num_partitions = model.NumSubgraphs();
263+
262264
LITERT_LOG(LITERT_INFO,
263265
"Starting QNN Compilation for %d subgraphs, soc_model=%s",
264266
num_partitions, soc_model);
@@ -304,8 +306,9 @@ LiteRtStatus LiteRtCompilerPluginCompile(
304306
{
305307
std::string& entry_point_name = result->graph_names.emplace_back();
306308
entry_point_name = "qnn_partition_0";
309+
LiteRtSubgraph partition = model.Subgraph(0)->Get();
307310
LITERT_RETURN_IF_ERROR(litert::qnn::ComposeGraph(
308-
**qnn_manager, context_handle->get(), partitions[0], entry_point_name));
311+
**qnn_manager, context_handle->get(), partition, entry_point_name));
309312
}
310313
LITERT_LOG(LITERT_INFO, "%s", "Graph composed");
311314

tensorflow/lite/experimental/litert/vendors/qualcomm/compiler/qnn_compiler_plugin_test.cc

+4-10
Original file line numberDiff line numberDiff line change
@@ -123,12 +123,9 @@ TEST(TestQnnPlugin, CompileMulSubgraph) {
123123
auto plugin = CreatePlugin();
124124
auto model = testing::LoadTestFileModel("one_mul.tflite");
125125

126-
const auto subgraph = model.MainSubgraph();
127-
LiteRtSubgraph litert_subgraph = subgraph->Get();
128-
129126
LiteRtCompiledResult compiled;
130-
LITERT_ASSERT_OK(LiteRtCompilerPluginCompile(plugin.get(), "V75",
131-
&litert_subgraph, 1, &compiled));
127+
LITERT_ASSERT_OK(
128+
LiteRtCompilerPluginCompile(plugin.get(), "V75", model.Get(), &compiled));
132129

133130
const void* byte_code;
134131
size_t byte_code_size;
@@ -254,12 +251,9 @@ TEST_P(QnnPluginOpCompatibilityTest, SupportedOpsTest) {
254251
auto plugin = CreatePlugin();
255252
auto model = testing::LoadTestFileModel(GetParam());
256253

257-
const auto subgraph = model.MainSubgraph();
258-
LiteRtSubgraph litert_subgraph = subgraph->Get();
259-
260254
LiteRtCompiledResult compiled;
261-
LITERT_ASSERT_OK(LiteRtCompilerPluginCompile(plugin.get(), "V75",
262-
&litert_subgraph, 1, &compiled));
255+
LITERT_ASSERT_OK(
256+
LiteRtCompilerPluginCompile(plugin.get(), "V75", model.Get(), &compiled));
263257

264258
const void* byte_code;
265259
size_t byte_code_size;

0 commit comments

Comments
 (0)