diff --git a/dali/operators/generic/select.cc b/dali/operators/generic/select.cc
new file mode 100644
index 00000000000..9629bddadbc
--- /dev/null
+++ b/dali/operators/generic/select.cc
@@ -0,0 +1,118 @@
+// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <vector>
+#include "dali/operators/generic/select.h"
+
+namespace dali {
+
+DALI_SCHEMA(Select)
+  .DocStr(R"(Builds an output by selecting each sample from one of the inputs.
+
+This operator is useful for conditionally selecting results of different operations.
+The shapes of the corresponding samples in the inputs may differ, but the number of dimensions
+and data type of all the inputs must be the same.
+
+The output layout can be specified via ``layout`` argument or will be taken from the first
+input with a non-empty layout.)")
+  .NumInput(1, 99)
+  .NumOutput(1)
+  .AllowSequences()
+  .SupportVolumetric()
+  .AddArg("input_idx", R"(Index of the input to take the sample from.
+
+This argument contains (per-sample) indices that point to the input from which each
+sample is taken.
+
+Providing a negative index will produce an empty tensor with the same number of dimensions as
+the inputs. Negative indices cannot be used with batches of scalars (0D tensors) since they
+can never be empty.
+)", DALI_INT32, true)
+  .AddOptionalArg<TensorLayout>("layout", R"(Layout string for the output.
+
+If not specified, the input layouts are checked and the first non-empty is used.)", nullptr);
+
+/**
+ * @brief Copies every output sample from the input selected by `input_idx` (CPU variant).
+ *
+ * Empty samples are skipped. Each remaining sample is split into blocks of about `block_size`
+ * bytes and the blocks are copied with memcpy by thread pool tasks.
+ */
+template <>
+void Select<CPUBackend>::RunImpl(HostWorkspace &ws) {
+  SetOutputLayout(ws);
+  TensorVector<CPUBackend> &out = ws.OutputRef<CPUBackend>(0);
+  const auto &out_shape = out.shape();
+  int N = out_shape.num_samples();
+  int64_t element_size = out.type().size();
+  int64_t total_size = out_shape.num_elements() * element_size;
+  const int64_t min_size = 16 << 10;  // lower bound for the block size: 16 KiB
+
+  ThreadPool &tp = ws.GetThreadPool();
+  // Target at least 10 blocks per thread, unless that would make the blocks smaller than min_size.
+  int min_blocks = tp.NumThreads() * 10;
+  int64_t block_size = std::max(total_size / min_blocks, min_size);
+  for (int i = 0; i < N; i++) {
+    auto sample_size = out_shape.tensor_size(i) * element_size;
+    if (sample_size == 0)
+      continue;
+    int idx = *input_idx_[i].data;
+    // Check the index before it is used to look up the input.
+    assert(idx >= 0 && idx < spec_.NumRegularInput());
+    auto &inp = ws.InputRef<CPUBackend>(idx);
+    int blocks = div_ceil(sample_size, block_size);
+    char *out_ptr = static_cast<char*>(out.raw_mutable_tensor(i));
+    const char *in_ptr = static_cast<const char*>(inp.raw_tensor(i));
+    ptrdiff_t start = 0;
+    for (int block = 0; block < blocks; block++) {
+      ptrdiff_t end = sample_size * (block + 1) / blocks;
+      tp.AddWork([in_ptr, out_ptr, start, end](int) {
+        memcpy(out_ptr + start, in_ptr + start, end - start);
+      }, end - start);
+      start = end;
+    }
+  }
+  tp.RunAll();
+}
+
+/**
+ * @brief Copies every output sample from the input selected by `input_idx` (GPU variant).
+ *
+ * Empty samples are skipped. The remaining copies are queued in `sg_` and issued with
+ * a single `sg_.Run` call on the workspace stream.
+ */
+template <>
+void Select<GPUBackend>::RunImpl(DeviceWorkspace &ws) {
+  SetOutputLayout(ws);
+  TensorList<GPUBackend> &out = ws.OutputRef<GPUBackend>(0);
+  const auto &out_shape = out.shape();
+  int N = out_shape.num_samples();
+  int64_t element_size = out.type().size();
+  for (int i = 0; i < N; i++) {
+    auto sample_size = out_shape.tensor_size(i) * element_size;
+    if (sample_size == 0)
+      continue;
+    int idx = *input_idx_[i].data;
+    // Check the index before it is used to look up the input.
+    assert(idx >= 0 && idx < spec_.NumRegularInput());
+    auto &inp = ws.InputRef<GPUBackend>(idx);
+    sg_.AddCopy(out.raw_mutable_tensor(i), inp.raw_tensor(i), sample_size);
+  }
+  sg_.Run(ws.stream());
+}
+
+DALI_REGISTER_OPERATOR(Select, Select<CPUBackend>, CPU);
+DALI_REGISTER_OPERATOR(Select, Select<GPUBackend>, GPU);
+
+}  // namespace
dali
diff --git a/dali/operators/generic/select.h b/dali/operators/generic/select.h
new file mode 100644
index 00000000000..e34c94794b0
--- /dev/null
+++ b/dali/operators/generic/select.h
@@ -0,0 +1,125 @@
+// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef DALI_OPERATORS_GENERIC_SELECT_H_
+#define DALI_OPERATORS_GENERIC_SELECT_H_
+
+#include <vector>
+#include "dali/pipeline/operator/operator.h"
+#include "dali/pipeline/operator/arg_helper.h"
+#include "dali/kernels/common/scatter_gather.h"
+
+namespace dali {
+
+/**
+ * @brief Builds the output batch by taking each sample from the input pointed to by the
+ *        per-sample `input_idx` argument.
+ *
+ * All inputs must have the same type and number of dimensions. A negative index is accepted
+ * only when the inputs have at least one dimension; the sample then gets `empty_shape`.
+ */
+template <typename Backend>
+class Select : public Operator<Backend> {
+ public:
+  explicit Select(const OpSpec &spec) : Operator<Backend>(spec), input_idx_("input_idx", spec) {
+    has_layout_arg_ = spec.TryGetArgument(layout_, "layout");
+  }
+
+  bool CanInferOutputs() const override { return true; }
+
+  USE_OPERATOR_MEMBERS();
+
+  /**
+   * @brief Validates the inputs and describes the output (shape and type).
+   *
+   * Enforces that all inputs share the type and dimensionality, that the `layout` argument
+   * (when given) has one entry per dimension and that every per-sample index is valid.
+   */
+  bool SetupImpl(vector<OutputDesc> &outputs, const workspace_t<Backend> &ws) override {
+    auto &inp0 = ws.template InputRef<Backend>(0);
+    int num_inp = spec_.NumRegularInput();
+    int sample_dim = inp0.sample_dim();
+    for (int i = 1; i < num_inp; i++) {
+      auto &inp = ws.template InputRef<Backend>(i);
+      DALI_ENFORCE(inp.type() == inp0.type(), make_string(
+          "All inputs must have the same type. "
+          "Got: ", inp0.type().id(), " and ", inp.type().id()));
+
+      DALI_ENFORCE(inp.sample_dim() == sample_dim, make_string(
+          "All inputs must have the same number of dimensions. "
+          "Got: ", sample_dim, " and ", inp.sample_dim()));
+    }
+
+    if (has_layout_arg_) {
+      DALI_ENFORCE(layout_.size() == sample_dim, make_string("The layout '", layout_, "' is not "
+        "a valid layout for ", sample_dim, "-D tensors."));
+    }
+
+    int num_samples = inp0.ntensor();
+
+    outputs.resize(1);
+    outputs[0].shape.resize(num_samples, sample_dim);
+    outputs[0].type = inp0.type();
+
+    input_idx_.Acquire(spec_, ws, num_samples);
+
+    TensorShape<> empty_shape;
+    empty_shape.resize(sample_dim);
+
+    for (int i = 0; i < num_samples; i++) {
+      int idx = *input_idx_[i].data;
+      // A negative index is valid only for non-scalar samples (sample_dim > 0).
+      bool is_valid_idx = (sample_dim > 0 && idx < 0) || (idx >= 0 && idx < num_inp);
+      DALI_ENFORCE(is_valid_idx, make_string("Invalid input index for sample ", i, ": ", idx));
+      if (idx < 0) {
+        outputs[0].shape.set_tensor_shape(i, empty_shape);
+      } else {
+        auto &inp = ws.template InputRef<Backend>(idx);
+        outputs[0].shape.set_tensor_shape(i, inp.tensor_shape(i));
+      }
+    }
+    return true;
+  }
+
+  void RunImpl(workspace_t<Backend> &ws) override;
+
+ private:
+  /**
+   * @brief Sets the output layout to the `layout` argument or, when it was not provided,
+   *        to the first non-empty input layout (if any).
+   */
+  void SetOutputLayout(workspace_t<Backend> &ws) {
+    if (has_layout_arg_) {
+      ws.template OutputRef<Backend>(0).SetLayout(layout_);
+    } else {
+      for (int i = 0; i < spec_.NumRegularInput(); i++) {
+        auto &inp = ws.template InputRef<Backend>(i);
+        auto layout = inp.GetLayout();
+        if (!layout.empty()) {
+          ws.template OutputRef<Backend>(0).SetLayout(layout);
+          break;
+        }
+      }
+    }
+  }
+
+  ArgValue<int> input_idx_;       // per-sample index of the input to take the sample from
+  TensorLayout layout_;           // value of the optional `layout` argument
+  bool has_layout_arg_ = false;   // true when the `layout` argument was specified
+  kernels::ScatterGatherGPU sg_;  // collects the copies issued by the GPU RunImpl
+};
+
+}  // namespace dali
+
+#endif  // DALI_OPERATORS_GENERIC_SELECT_H_
diff --git a/dali/test/python/test_dali_cpu_only.py b/dali/test/python/test_dali_cpu_only.py
index 82b6ac68aa2..bd93cae56fd 100644
--- a/dali/test/python/test_dali_cpu_only.py
+++ b/dali/test/python/test_dali_cpu_only.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -891,6 +891,17 @@ def get_data(): return out check_single_input(fn.squeeze, axis_names="YZ", get_data=get_data, input_layout="HWCYZ") +def test_select_cpu(): + pipe = Pipeline(batch_size=batch_size, num_threads=3, device_id=None) + data = fn.external_source(source=get_data, layout="HWC") + data2 = fn.external_source(source=get_data, layout="HWC") + data3 = fn.external_source(source=get_data, layout="HWC") + idx = fn.random.uniform(range=[0, 2], dtype=types.INT32) + pipe.set_outputs(fn.select(data, data2, data3, input_idx=idx)) + pipe.build() + for _ in range(3): + pipe.run() + def test_peek_image_shape_cpu(): pipe = Pipeline(batch_size=batch_size, num_threads=4, device_id=None) input, _ = fn.readers.file(file_root=images_dir, shard_id=0, num_shards=1) @@ -1025,6 +1036,7 @@ def test_separated_exec_setup(): "resize_crop_mirror", "fast_resize_crop_mirror", "segmentation.select_masks", + "select", "slice", "segmentation.random_mask_pixel", "transpose", diff --git a/dali/test/python/test_operator_select.py b/dali/test/python/test_operator_select.py new file mode 100644 index 00000000000..581d4f4385e --- /dev/null +++ b/dali/test/python/test_operator_select.py @@ -0,0 +1,157 @@ +# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from numpy import random +from numpy.core.fromnumeric import shape +from nvidia.dali import Pipeline +import nvidia.dali as dali +from nvidia.dali.external_source import external_source +import nvidia.dali.fn as fn +import numpy as np +import random +import test_utils + +random.seed(1234) +np.random.seed(1234) + +def generate_data(ndim, ninp, type, max_batch_size): + batch_size = np.random.randint(1, max_batch_size+1) + inp_sel = np.random.randint(0 if ndim == 0 else -1, ninp, (batch_size,), dtype=np.int32) + dtype = test_utils.dali_type_to_np(type) + if ndim > 0: + max_extent = max(10, int(100000 ** (1/ndim))) + random_shape = lambda: np.random.randint(1, max_extent, (ndim,)) + else: + random_shape = lambda: () + if type in (dali.types.FLOAT, dali.types.FLOAT64, dali.types.FLOAT16): + rnd = lambda: np.random.random(random_shape()).astype(dtype=dtype) + else: + rnd = lambda: np.random.randint(0, 100, size=random_shape(), dtype=dtype) + data = [[rnd() for _ in range(batch_size)] for _ in range(ninp)] + return data + [inp_sel] + +def make_layout(ndim, start='a'): + out = '' + for i in range(ndim): + out += chr(ord(start) + i) + return out + +def _test_select(ndim, ninp, type, max_batch_size, device, layout): + pipe = Pipeline(max_batch_size, 3, None if device == "cpu" else 0) + with pipe: + *inputs_cpu, idx = fn.external_source( + source=lambda: generate_data(ndim, ninp, type, max_batch_size), + num_outputs=ninp+1) + inputs_cpu = list(inputs_cpu) + input_layout = make_layout(ndim) + inputs_cpu[ninp - 1] = fn.reshape(inputs_cpu[ninp-1], layout=input_layout) + inputs = [x.gpu() for x in inputs_cpu] if device == "gpu" else inputs_cpu + out = fn.select(*inputs, input_idx=idx, device=device, layout=layout) + pipe.set_outputs(out, *inputs_cpu, idx) + pipe.build() + expected_layout = layout if layout is not None else input_layout + for iter in range(5): + out_tl, *input_tls, idx_tl = pipe.run() + assert out_tl.layout() == expected_layout + if device == "gpu": + out_tl = 
out_tl.as_cpu() + batch_size = len(out_tl) + for inp in input_tls: + assert len(inp) == batch_size, "Inconsistent batch size" + for i in range(batch_size): + idx = int(idx_tl.at(i)) + if idx < 0: + assert out_tl.at(i).size == 0 + else: + assert np.array_equal(out_tl.at(i), input_tls[idx].at(i)) + +def test_select(): + for device in ["cpu", "gpu"]: + for ndim in (0, 1, 2, 3): + for ninp in (1, 2, 3, 4): + type = random.choice([dali.types.FLOAT, dali.types.UINT8, dali.types.INT32]) + layout = make_layout(ndim, 'p') if np.random.randint(0, 2) else None + yield _test_select, ndim, ninp, type, 10, device, layout + + +def test_error_inconsistent_ndim(): + def check(device): + pipe = Pipeline(1, 3, 0) + pipe.set_outputs(fn.select(np.float32([0,1]), np.float32([[2,3,4]]), input_idx=1)) + pipe.build() + try: + pipe.run() + assert False, "Expected an exception" + except RuntimeError as e: + assert "same number of dimensions" in str(e), "Unexpected exception" + + for device in ["gpu", "cpu"]: + yield check, device + +def test_error_inconsistent_type(): + def check(device): + pipe = Pipeline(1, 3, 0) + pipe.set_outputs(fn.select(np.float32([0,1]), np.int32([2,3,4]), input_idx=1)) + pipe.build() + try: + pipe.run() + assert False, "Expected an exception" + except RuntimeError as e: + assert "same type" in str(e), "Unexpected exception" + + for device in ["gpu", "cpu"]: + yield check, device + +def test_error_input_out_of_range(): + def check(device): + pipe = Pipeline(1, 3, 0) + pipe.set_outputs(fn.select(np.float32([0,1]), np.float32([2,3,4]), input_idx=2)) + pipe.build() + try: + pipe.run() + assert False, "Expected an exception" + except RuntimeError as e: + assert "Invalid input index" in str(e), "Unexpected exception" + + for device in ["gpu", "cpu"]: + yield check, device + + +def test_error_empty_scalar(): + def check(device): + pipe = Pipeline(1, 3, 0) + pipe.set_outputs(fn.select(np.array(0, dtype=np.float32), input_idx=-1)) + pipe.build() + try: + pipe.run() + assert 
False, "Expected an exception" + except RuntimeError as e: + assert "Invalid input index" in str(e), "Unexpected exception" + + for device in ["gpu", "cpu"]: + yield check, device + +def test_error_bad_layout(): + def check(device): + pipe = Pipeline(1, 3, 0) + pipe.set_outputs(fn.select(np.float32([[[0,1]]]), np.float32([[[2,3,4]]]), input_idx=0, layout="ab")) + pipe.build() + try: + pipe.run() + assert False, "Expected an exception" + except RuntimeError as e: + assert "valid layout" in str(e), "Unexpected exception" + + for device in ["gpu", "cpu"]: + yield check, device