forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test_backend_compiler_lib.cpp
173 lines (161 loc) · 6.36 KB
/
test_backend_compiler_lib.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
#include <ATen/Utils.h>
#include <c10/core/TensorImpl.h>
#include <torch/csrc/jit/backends/backend.h>
#include <torch/csrc/jit/backends/backend_exception.h>
#include <torch/csrc/jit/mobile/profiler_edge.h>
namespace torch {
namespace jit {
// Implementation of a PyTorch Backend that can process, compile and execute
// TorchScript Modules composed of 'add' and 'sub' operators. It only supports
// modules that implement a sum or subtraction of 2 inputs (i.e. in1 + in2
// or in1 - in2). Hence the methods of the models expect exactly 2 inputs of
// type Tensor. This backend is used to demonstrate the flow of compilation and
// execution with a minimum amount of work. It is not intended to be a practical
// backend that can be used for actual inference.
// Implementation details:
//
// Compilation
// 1. A backend with minimum compilation features, "backend_with_compiler_demo"
// is added.
// 2. The compilation happens AOT in the preprocess function registered to this
// backend.
// 3. Compiled results are stored in a string blob for each method. They are
// serialized to the lowered module with the __getstate__ function.
// 4. An error message with the model source code is thrown for features not
// handled by the backend compiler.
//
// Runtime
// 1. The compiled blob is loaded in the __setstate__ method.
// 2. The compile function of the backend parses the preprocessed blob into the
// format (a list of tokens) that the backend can understand.
// 3. The execute function of the backend executes the specified method
// (handle).
namespace {
// Parses one method's preprocessed blob into (instruction, debug_handle) pairs.
//
// The blob is a comma-separated list of entries. An entry of the form
// `instr<debug_handle>N` yields (instr, N), where N is parsed as a 64-bit
// integer. An entry without the "<debug_handle>" marker yields (entry, -1).
//
// An empty blob yields an empty vector, and a trailing comma does not add an
// extra empty entry. Empty entries between two commas are kept as empty
// instructions.
//
// Throws std::invalid_argument / std::out_of_range (from std::stoll) if no
// integer can be parsed after the marker or it does not fit in 64 bits.
std::vector<std::tuple<std::string, int64_t>> parseMethodHandle(
    const std::string& blob) {
  std::vector<std::tuple<std::string, int64_t>> result;
  std::stringstream s_stream(blob);
  constexpr char debug_handle_token[] = "<debug_handle>";
  // Marker length without the terminating NUL; replaces a hard-coded 14.
  constexpr std::size_t debug_handle_token_len = sizeof(debug_handle_token) - 1;
  std::string substr;
  // Looping on getline's result (rather than on good()) avoids emitting a
  // spurious empty entry for an empty blob or a trailing comma.
  while (std::getline(s_stream, substr, ',')) {
    int64_t debug_handle{-1};
    std::string instruction = substr;
    const auto debug_handle_pos = substr.find(debug_handle_token);
    if (debug_handle_pos != std::string::npos) {
      instruction = substr.substr(0, debug_handle_pos);
      // stoll, not stoi: debug handles are int64_t and must not be truncated.
      debug_handle = std::stoll(
          substr.substr(debug_handle_pos + debug_handle_token_len));
    }
    result.emplace_back(std::move(instruction), debug_handle);
  }
  return result;
}
// Returns a raw float pointer to the data of `t`, obtained directly from its
// TensorImpl via data_ptr_impl<float>(). This helper performs no dtype check
// of its own, so callers must ensure `t` holds float data before using it.
float* float_data_ptr(const at::Tensor& t) {
  auto* impl = t.unsafeGetTensorImpl();
  return impl->data_ptr_impl<float>();
}
} // namespace
class BackendWithCompiler : public PyTorchBackendInterface {
public:
// Constructor.
// NOLINTNEXTLINE(modernize-use-equals-default)
explicit BackendWithCompiler() {}
// NOLINTNEXTLINE(modernize-use-override)
virtual ~BackendWithCompiler() = default;
bool is_available() override {
return true;
}
// Since the actual compilation is done AOT,
c10::impl::GenericDict compile(
c10::IValue processed,
c10::impl::GenericDict method_compile_spec) override {
auto dict = processed.toGenericDict();
auto handles =
c10::Dict<std::string, std::vector<std::tuple<std::string, int64_t>>>();
for (const auto& kv : dict) {
auto tokens = parseMethodHandle(kv.value().toStringRef());
handles.insert(kv.key().toStringRef(), tokens);
}
return c10::impl::toGenericDict(handles);
}
c10::impl::GenericList execute(
c10::IValue handle,
c10::impl::GenericList inputs) override {
TORCH_INTERNAL_ASSERT(inputs.size() == 2);
c10::IValue val0 = inputs[0];
at::Tensor x = val0.toTensor();
c10::IValue val1 = inputs[1];
at::Tensor h = val1.toTensor();
std::vector<std::tuple<int64_t, int64_t, std::string>> op_runtimes_us;
op_runtimes_us.reserve(handle.toList().size());
c10::List<at::Tensor> output_list;
auto start_us = autograd::profiler::getTime() / 1000;
for (const auto& token : handle.toList()) {
IValue val = token;
auto instruction = val.toTupleRef().elements()[0].toStringRef();
auto debug_handle = val.toTupleRef().elements()[1].toInt();
double const_val = 1.0;
auto start_time_us = autograd::profiler::getTime() / 1000;
try {
if (instruction.rfind("prim::Constant", 0) == 0) {
TORCH_CHECK(
instruction.size() > 15,
"Constant value is expected in ",
instruction);
// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
auto sub = instruction.substr(15);
// NOLINTNEXTLINE(clang-analyzer-deadcode.DeadStores)
const_val = stod(sub);
} else if (instruction == "aten::add" || instruction == "aten::sub") {
TORCH_CHECK(x.sizes() == h.sizes());
if (x.dim() > 1 || (x.dim() == 1 && x.size(0) > 1)) {
TORCH_WARN(
"Only the first elements of the tensors are added or subbed.");
}
TORCH_CHECK(
(x.scalar_type() == c10::ScalarType::Float &&
h.scalar_type() == c10::ScalarType::Float),
"Only float tensors are compatible for add and sub.");
auto y = at::detail::empty_cpu(
x.sizes(), c10::ScalarType::Float, {}, {}, {}, c10::nullopt);
auto x_ptr = float_data_ptr(x);
auto h_ptr = float_data_ptr(h);
auto y_ptr = float_data_ptr(y);
if (instruction == "aten::add") {
y_ptr[0] = x_ptr[0] + h_ptr[0];
} else {
y_ptr[0] = x_ptr[0] - h_ptr[0];
}
output_list.emplace_back(y);
} else {
TORCH_CHECK(
false,
"Instruction, ",
instruction,
" is not supported. ",
"Contact the backend POC for details. ");
}
} catch (c10::Error& e) {
TORCH_DELEGATED_BACKEND_THROW(false, e.what(), debug_handle);
}
auto end_time_us = autograd::profiler::getTime() / 1000;
auto duration = end_time_us - start_time_us;
op_runtimes_us.emplace_back(duration, debug_handle, instruction);
}
for (const auto& tup : op_runtimes_us) {
RECORD_BACKEND_EVENT_TO_EDGE_PROFILER(
start_us,
start_us + std::get<0>(tup),
std::get<1>(tup),
std::get<2>(tup),
"test_backend");
start_us = start_us + std::get<0>(tup);
}
return c10::impl::toList(output_list);
}
};
namespace {
// Name under which BackendWithCompiler is registered.
constexpr const char* backend_name = "backend_with_compiler_demo";
// File-scope registration object. Constructing torch::jit::backend<> with the
// name above presumably registers the backend when this translation unit is
// initialized (see torch/csrc/jit/backends/backend.h). The unnamed namespace
// already gives internal linkage, so `static` is unnecessary.
auto cls = torch::jit::backend<BackendWithCompiler>(backend_name);
} // namespace
} // namespace jit
} // namespace torch