#pragma once

#include <torch/csrc/jit/codegen/cuda/executor_utils.h>
#include <torch/csrc/jit/codegen/cuda/expr_evaluator.h>
#include <torch/csrc/jit/codegen/cuda/fusion.h>
#include <torch/csrc/jit/codegen/cuda/ir_iostream.h>
#include <torch/csrc/jit/codegen/cuda/lower_utils.h>

#include <unordered_map>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

struct ValidationConstants {
  // Tolerances generated from randn + add + sum fusion
  // compared against double precision
  std::array<std::array<double, 2>, 20> sum_tolerances_float = {
      {{4, 1.51992e-06},      {8, 2.23704e-06},      {16, 2.95788e-06},
       {32, 4.4778e-06},      {64, 6.75395e-06},     {128, 8.57934e-06},
       {256, 1.30594e-05},    {512, 2.19122e-05},    {1024, 3.3451e-05},
       {2048, 5.78476e-05},   {4096, 0.000108292},   {8192, 0.00012207},
       {16384, 0.000136882},  {32768, 0.000248561},  {65536, 0.000407594},
       {131072, 0.000500901}, {262144, 0.000923019}, {524288, 0.00156909},
       {1048576, 0.00223107}, {2097152, 0.00343043}}};

  // Tolerances generated from randn + add + sum fusion
  // compared against double precision
  std::array<std::array<double, 2>, 20> sum_tolerances_half = {
      {{4, 0.00390625},    {8, 0.0078125},    {16, 0.0078125},
       {32, 0.0155334},    {64, 0.0156269},   {128, 0.0312042},
       {256, 0.0312548},   {512, 0.0619979},  {1024, 0.0625103},
       {2048, 0.124686},   {4096, 0.12501},   {8192, 0.24945},
       {16384, 0.250049},  {32768, 0.498946}, {65536, 0.500071},
       {131072, 0.985087}, {262144, 1.00006}, {524288, 1.99234},
       {1048576, 2.00032}, {2097152, 3.99073}}};

  double base_half_abs_tol = -1;
  double base_half_rel_tol = -1;
  double base_float_abs_tol = -1;
  double base_float_rel_tol = -1;
};
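
// A minimal sketch of overriding the base tolerances (hypothetical usage, not
// taken from any existing test): the base_* fields default to -1, which means
// "fall back to the first rows of the sum tables". A test that needs looser
// pointwise bounds can pass a customized instance to testValidate below:
//
//   ValidationConstants relaxed;
//   relaxed.base_float_abs_tol = 1e-5;
//   relaxed.base_float_rel_tol = 1e-7;
//   // ... build and run the fusion ...
//   testValidate(
//       &fusion, cg_outputs, aten_inputs, aten_outputs,
//       __LINE__, __FILE__, "", LaunchParams(), relaxed);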

namespace {

// Returns the absolute and relative tolerances to use for validation
std::pair<double, double> getTolerance(
    DataType dtype,
    int64_t reduction_size,
    const ValidationConstants& tolerances) {
  switch (dtype) {
    case DataType::Float:
    // TODO: Pull new tolerances for Double; for now we will just use float
    // tolerances as they should be no worse.
    case DataType::Double: {
      const auto& sum_tolerance_entry = tolerances.sum_tolerances_float;
      const auto& base_abs = tolerances.base_float_abs_tol;
      const auto& base_rel = tolerances.base_float_rel_tol;

      if (reduction_size <= 1) {
        // No reduction case
        if (base_abs == -1 || base_rel == -1) {
          return {sum_tolerance_entry[0][1], sum_tolerance_entry[1][1]};
        } else {
          return {base_abs, base_rel};
        }
      } else {
        // Reduction case: scan for the first table row whose element count
        // reaches reduction_size. Check the bound first so we never read
        // past the end of the table.
        size_t entry = 0;
        while (entry < sum_tolerance_entry.size() &&
               sum_tolerance_entry[entry][0] < reduction_size) {
          entry++;
        }

        double abs_tol = 0.0;
        if (entry + 1 < sum_tolerance_entry.size()) {
          // Grab the next entry up so we have some margin
          abs_tol = sum_tolerance_entry[entry + 1][1];
        } else {
          // If we hit the end of the list, return twice the max error we
          // measured
          abs_tol = sum_tolerance_entry[sum_tolerance_entry.size() - 1][1] * 2.;
        }
        // Set the relative tolerance to 1% of the absolute tolerance for a
        // small margin of relative error.
        return {abs_tol, abs_tol * 0.01};
      }
    }
    case DataType::Half: {
      // Copied from the float case
      const auto& sum_tolerance_entry = tolerances.sum_tolerances_half;
      const auto& base_abs = tolerances.base_half_abs_tol;
      const auto& base_rel = tolerances.base_half_rel_tol;

      if (reduction_size <= 1) {
        // No reduction case
        if (base_abs == -1 || base_rel == -1) {
          return {sum_tolerance_entry[0][1], sum_tolerance_entry[1][1]};
        } else {
          return {base_abs, base_rel};
        }
      } else {
        // Reduction case (same bounds-checked scan as the float case)
        size_t entry = 0;
        while (entry < sum_tolerance_entry.size() &&
               sum_tolerance_entry[entry][0] < reduction_size) {
          entry++;
        }

        double abs_tol = 0.0;
        if (entry + 1 < sum_tolerance_entry.size()) {
          // Grab the next entry up so we have some margin
          abs_tol = sum_tolerance_entry[entry + 1][1];
        } else {
          // If we hit the end of the list, return twice the max error we
          // measured
          abs_tol = sum_tolerance_entry[sum_tolerance_entry.size() - 1][1] * 2.;
        }
        // Set the relative tolerance to 1% of the absolute tolerance for a
        // small margin of relative error.
        return {abs_tol, abs_tol * 0.01};
      }
    }
    case DataType::BFloat16: {
      // Copied from the half case, with tolerances widened 10x since
      // bfloat16 carries fewer mantissa bits than fp16.
      const auto& sum_tolerance_entry = tolerances.sum_tolerances_half;
      const auto& base_abs = tolerances.base_half_abs_tol;
      const auto& base_rel = tolerances.base_half_rel_tol;

      if (reduction_size <= 1) {
        // No reduction case
        if (base_abs == -1 || base_rel == -1) {
          return {sum_tolerance_entry[0][1], sum_tolerance_entry[1][1]};
        } else {
          return {base_abs * 10.0, base_rel * 10.0};
        }
      } else {
        // Reduction case (same bounds-checked scan as the float case)
        size_t entry = 0;
        while (entry < sum_tolerance_entry.size() &&
               sum_tolerance_entry[entry][0] < reduction_size) {
          entry++;
        }

        double abs_tol = 0.0;
        if (entry + 1 < sum_tolerance_entry.size()) {
          // Grab the next entry up so we have some margin
          abs_tol = sum_tolerance_entry[entry + 1][1];
        } else {
          // If we hit the end of the list, return twice the max error we
          // measured
          abs_tol = sum_tolerance_entry[sum_tolerance_entry.size() - 1][1] * 2.;
        }
        // Set the relative tolerance to 1% of the absolute tolerance for a
        // small margin of relative error.
        return {abs_tol * 10.0, abs_tol * 0.01 * 10.0};
      }
    }
    case DataType::Int:
    case DataType::Int32:
    case DataType::Bool:
      // Integer and boolean outputs are compared exactly.
      return {0.0, 0.0};
    default:
      TORCH_INTERNAL_ASSERT(
          false, "Do not have tolerance computation for type ", dtype, ".");
  }
}
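
// Worked example of the table lookup above (values from
// sum_tolerances_float): for DataType::Float with reduction_size = 1000, the
// scan stops at the 1024-element row, the next row up (2048, 5.78476e-05)
// supplies the absolute tolerance, and 1% of that (5.78476e-07) becomes the
// relative tolerance.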

class ReductionSizeMapper : private IterVisitor {
 public:
  //! Runs through the fusion and determines how many reductions were performed
  //! to compute each tensorview.
  static std::unordered_map<TensorView*, int64_t> computeReductionSizes(
      Fusion* fusion,
      ExpressionEvaluator& expr_eval) {
    ReductionSizeMapper mapper(fusion, expr_eval);
    return mapper.reduction_map;
  }

 private:
  ReductionSizeMapper(Fusion* fusion, ExpressionEvaluator& expr_eval)
      : expr_eval_(expr_eval) {
    // Initialize input values
    for (auto inp : fusion->inputs()) {
      if (inp->isA<TensorView>()) {
        auto tv = inp->as<TensorView>();
        // Shouldn't have any reductions, but run it through analysis anyway.
        reduction_map[tv] = getReductionSize(tv);
      }
    }

    IterVisitor::traverse(fusion);

    // Catch up with dangling outputs. It's possible that an output is not
    // generated by any expression, e.g. a zero-size workspace or a null
    // tensor.
    for (auto out : fusion->outputs()) {
      if (out->isA<TensorView>()) {
        auto tv = out->as<TensorView>();
        if (reduction_map.count(tv) == 0) {
          // Shouldn't have any reductions, but run it through analysis anyway.
          reduction_map[tv] = getReductionSize(tv);
        }
      }
    }
  }

  int64_t getReductionSize(const TensorView* tv) {
    int64_t reduction_elements = 1;
    for (auto id : tv->getMaybeRFactorDomain()) {
      if (id->isReduction()) {
        auto inferred_extent = expr_eval_.evaluate(id->extent());
        TORCH_INTERNAL_ASSERT(
            inferred_extent.has_value(),
            "Couldn't figure out what the dimensions of a tensorview are in evaluation for validation. ",
            id,
            " in ",
            tv);
        reduction_elements = reduction_elements * inferred_extent.value();
      }
    }
    return reduction_elements;
  }

  void handle(Expr* expr) override {
    if (!ir_utils::isTVOp(expr)) {
      return;
    }

    int64_t inp_reduction_elements = 1;
    for (auto inp : expr->inputs()) {
      if (inp->isA<TensorView>()) {
        auto tv = inp->as<TensorView>();
        inp_reduction_elements =
            std::max(inp_reduction_elements, reduction_map.at(tv));
      }
    }

    for (auto out : expr->outputs()) {
      if (out->isA<TensorView>()) {
        auto tv = out->as<TensorView>();
        reduction_map[tv] = getReductionSize(tv) * inp_reduction_elements;
      }
    }
  }

 private:
  using IterVisitor::handle;

  std::unordered_map<TensorView*, int64_t> reduction_map;
  ExpressionEvaluator& expr_eval_;
};
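
// Example of the propagation above: for t2 = sum(t0 + t1, {1}) with t0 and t1
// of shape [8, 1024], the mapper records a reduction size of 1 for t0, t1,
// and the add's output, and 1024 for t2, so getTolerance will read the
// 1024-element row of the sum tolerance tables when validating t2.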

ExpressionEvaluator bindInputsAndLaunchParams(
    Fusion* fusion,
    const at::ArrayRef<IValue>& aten_inputs,
    const LaunchParams& launch_constraints) {
  auto expr_eval = executor_utils::bindFusionInputs(aten_inputs, fusion);
  for (auto val : fusion->vals()) {
    if (!val->isA<TensorView>()) {
      continue;
    }

    // Roughly taken from executor.cpp/computeLaunchParams
    auto tv = val->as<TensorView>();
    for (auto id : tv->domain()->domain()) {
      if (!(id->isThread() && id->extent()->definition() == nullptr)) {
        continue;
      }

      if (id->isBroadcast()) {
        continue;
      }

      auto extent = id->extent();
      auto inferred_extent = expr_eval.evaluate(extent);
      auto p_type = id->getParallelType();

      if (inferred_extent.has_value()) {
        // This value could have been inferred, make sure it was set right.
        TORCH_CHECK(
            inferred_extent.value() == launch_constraints.getDim(p_type) ||
                launch_constraints.getRawVal(p_type) == -1,
            "inferred that ",
            p_type,
            " should be set to ",
            inferred_extent.value(),
            " but launch constraints specified ",
            launch_constraints.getRawVal(p_type));
      } else {
        // Bind the launch constraint into our evaluation context
        if (launch_constraints.hasDim(id->getParallelType())) {
          expr_eval.bind(extent, launch_constraints.getDim(p_type));
        }
      }
    }
  }
  return expr_eval;
}
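
// Example of the binding above (hypothetical): if an axis is parallelized on
// TIDx with a symbolic extent that the inputs don't determine, and the caller
// passes LaunchParams with bdimx = 128, that extent is bound to 128 in the
// returned evaluator; if the extent is already inferable from the inputs, it
// is instead checked against the launch constraint.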

} // namespace

// Validation looks through the fusion and figures out how many elements were
// reduced to create each output. It then computes a tolerance to use for
// allclose based on experimental results. The experimental results were based
// on adding two tensors then summing them. This of course assumes we're
// always summing values between -2 and 2. If we start summing values larger
// than that, this approach might not hold.
inline void testValidate(
    Fusion* fusion,
    const std::vector<at::Tensor>& fusion_outputs,
    const at::ArrayRef<IValue>& aten_inputs,
    const std::vector<at::Tensor>& aten_outputs,
    int line_number,
    const char* file_name,
    std::string err_msg = "",
    const LaunchParams& lparams = LaunchParams(),
    const ValidationConstants& tolerances = ValidationConstants()) {
  FusionGuard fg(fusion);

  auto expr_eval = bindInputsAndLaunchParams(fusion, aten_inputs, lparams);

  auto reduction_sizes =
      ReductionSizeMapper::computeReductionSizes(fusion, expr_eval);

  TORCH_INTERNAL_ASSERT(
      fusion_outputs.size() == aten_outputs.size() &&
          aten_outputs.size() == fusion->outputs().size(),
      "Number of outputs doesn't match.");

  TORCH_INTERNAL_ASSERT(
      fusion->inputs().size() == aten_inputs.size(),
      "Number of inputs doesn't match.");

  for (size_t i = 0; i < fusion->inputs().size(); i++) {
    if (fusion->inputs()[i]->isA<TensorView>()) {
      TORCH_INTERNAL_ASSERT(
          aten_inputs[i].isTensor(), "Mismatch of tensor inputs.");

      auto fusion_input_tv = fusion->inputs()[i]->as<TensorView>();
      auto at_tensor = aten_inputs[i].toTensor();

      TORCH_INTERNAL_ASSERT(
          at_tensor.dim() ==
              TensorDomain::noReductions(
                  fusion_input_tv->getMaybeRFactorDomain())
                  .size(),
          "Dimensionality mismatch in inputs.");
    }
  }

  for (size_t i = 0; i < fusion->outputs().size(); i++) {
    TORCH_INTERNAL_ASSERT(
        fusion->outputs()[i]->isA<TensorView>(), "Mismatch of tensor outputs.");

    auto fusion_output_tensor = fusion_outputs[i];
    auto fusion_output_tv = fusion->outputs()[i]->as<TensorView>();
    auto aten_output_tensor = aten_outputs[i];

    TORCH_INTERNAL_ASSERT(
        reduction_sizes.count(fusion_output_tv),
        "Missed reduction size count on fusion output at index: ",
        i);

    int64_t reduction_size = reduction_sizes.at(fusion_output_tv);

    TORCH_INTERNAL_ASSERT(
        aten_output_tensor.dim() == fusion_output_tensor.dim() &&
            fusion_outputs[i].dim() ==
                TensorDomain::noReductions(
                    fusion_output_tv->getMaybeRFactorDomain())
                    .size(),
        "Dimensionality mismatch in outputs.");

    auto tolerance_values = getTolerance(
        fusion_output_tv->getDataType().value(), reduction_size, tolerances);

    if (aten_output_tensor.is_floating_point()) {
      TORCH_INTERNAL_ASSERT(
          aten_output_tensor.allclose(
              fusion_output_tensor.to(aten_output_tensor.dtype()),
              tolerance_values.second,
              tolerance_values.first),
          "\n",
          err_msg,
          "\nValidation error in output ",
          i,
          " on line ",
          line_number,
          " in file ",
          file_name,
          ".\n Detected abs error of: ",
          aten_output_tensor.sub(fusion_output_tensor)
              .abs()
              .max()
              .item()
              .to<double>(),
          "\n absolute tolerance was set to ",
          tolerance_values.first,
          "\n and relative tolerance set to ",
          tolerance_values.second);
    } else {
      TORCH_INTERNAL_ASSERT(
          aten_output_tensor.equal(
              fusion_output_tensor.to(aten_output_tensor.dtype())),
          "\n",
          err_msg,
          ".\n Validation error in output ",
          i,
          " on line ",
          line_number,
          " in file ",
          file_name,
          ".\n Values are not equal and are not a floating type.");
    }
  }
}
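
// A typical call site (hypothetical sketch; fe, t0, and t1 are assumed to be
// a FusionExecutor and input tensors set up by the test):
//
//   auto cg_outputs = fe.runFusion({t0, t1});
//   auto aten_output = t0.add(t1).sum({1});
//   testValidate(
//       &fusion, cg_outputs, {t0, t1}, {aten_output}, __LINE__, __FILE__);
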
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch