103 changes: 103 additions & 0 deletions dali/operators/generic/select.cc
@@ -0,0 +1,103 @@
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include "dali/operators/generic/select.h"

namespace dali {

DALI_SCHEMA(Select)
.DocStr(R"(Builds an output by selecting each sample from one of the inputs.

This operator is useful for conditionally selecting results of different operations.
The shapes of the corresponding samples in the inputs may differ, but the number of dimensions
and data type of all the inputs must be the same.

Contributor: How about layouts?

Contributor Author: Done. See below.

The output layout can be specified via the ``layout`` argument or will be taken from the first
input with a non-empty layout.)")
.NumInput(1, 99)
.NumOutput(1)
.AllowSequences()
.SupportVolumetric()
.AddArg("input_idx", R"(Index of the input to take the sample from.

This argument contains (per-sample) indices that point to the input from which each
sample is taken.

Providing a negative index will produce an empty tensor with the same number of dimensions as
the inputs. Negative indices cannot be used with batches of scalars (0D tensors) since they
can never be empty.
)", DALI_INT32, true)
.AddOptionalArg<TensorLayout>("layout", R"(Layout string for the output.

If not specified, the input layouts are checked and the first non-empty one is used.)", nullptr);
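
For illustration only, not part of this file's diff: a minimal Python sketch of how the operator defined by this schema could be used from the pipeline API. It assumes the ``fn.select`` binding generated from the schema above; the source callables ``variant_a``/``variant_b`` are hypothetical.

import numpy as np
from nvidia.dali.pipeline import Pipeline
import nvidia.dali.fn as fn
import nvidia.dali.types as types

def variant_a():  # hypothetical source: one candidate version of each sample
    return [np.full((4, 4, 3), 1, dtype=np.uint8) for _ in range(8)]

def variant_b():  # hypothetical source: another candidate version
    return [np.full((4, 4, 3), 2, dtype=np.uint8) for _ in range(8)]

pipe = Pipeline(batch_size=8, num_threads=2, device_id=None)
a = fn.external_source(source=variant_a, layout="HWC")
b = fn.external_source(source=variant_b, layout="HWC")
# per-sample index: 0 takes the sample from `a`, 1 takes it from `b`
idx = fn.random.uniform(range=[0, 2], dtype=types.INT32)
pipe.set_outputs(fn.select(a, b, input_idx=idx))
pipe.build()
out, = pipe.run()  # each output sample comes from `a` or `b`, as chosen by idx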

template <>
void Select<CPUBackend>::RunImpl(HostWorkspace &ws) {
SetOutputLayout(ws);
TensorVector<CPUBackend> &out = ws.OutputRef<CPUBackend>(0);
const auto &out_shape = out.shape();
int N = out_shape.num_samples();
int64_t element_size = out.type().size();
int64_t total_size = out_shape.num_elements() * element_size;
const int64_t min_size = 16 << 10;

ThreadPool &tp = ws.GetThreadPool();
int min_blocks = tp.NumThreads() * 10;
// aim for roughly min_blocks blocks in total, but keep each block at least min_size bytes
int64_t block_size = std::max<int64_t>(total_size / min_blocks, min_size);
for (int i = 0; i < N; i++) {
auto sample_size = out_shape.tensor_size(i) * element_size;
if (sample_size == 0)
continue;
int idx = *input_idx_[i].data;
assert(idx >= 0 && idx < spec_.NumRegularInput());  // validated in SetupImpl
auto &inp = ws.InputRef<CPUBackend>(idx);
int blocks = div_ceil(sample_size, block_size);
char *out_ptr = static_cast<char*>(out.raw_mutable_tensor(i));
const char *in_ptr = static_cast<const char*>(inp.raw_tensor(i));
ptrdiff_t start = 0;
for (int block = 0; block < blocks; block++) {
ptrdiff_t end = sample_size * (block + 1) / blocks;
tp.AddWork([in_ptr, out_ptr, start, end](int) {
memcpy(out_ptr + start, in_ptr + start, end - start);
}, end - start);
start = end;
}
Comment on lines +67 to +76

Contributor: I think this is a similar pattern to the one used in the NumPy reader. Maybe we can have a function for this (even MakeContiguous could use it in the future).

Contributor Author (@mzient, Jul 7, 2021): I guess we can factor it out when there are more usages of this kind of copy. NumpyReader is a bit more involved, though.
Regarding MakeContiguous - we won't have cpu2cpu (mixed) MakeContiguous when we unify workspaces and buffer objects.
}
tp.RunAll();
}

template <>
void Select<GPUBackend>::RunImpl(DeviceWorkspace &ws) {
SetOutputLayout(ws);
TensorList<GPUBackend> &out = ws.OutputRef<GPUBackend>(0);
const auto &out_shape = out.shape();
int N = out_shape.num_samples();
int64_t element_size = out.type().size();
for (int i = 0; i < N; i++) {
auto sample_size = out_shape.tensor_size(i) * element_size;
if (sample_size == 0)
continue;
int idx = *input_idx_[i].data;
assert(idx >= 0 && idx < spec_.NumRegularInput());  // validated in SetupImpl
auto &inp = ws.InputRef<GPUBackend>(idx);
sg_.AddCopy(out.raw_mutable_tensor(i), inp.raw_tensor(i), sample_size);
}
sg_.Run(ws.stream());
}

DALI_REGISTER_OPERATOR(Select, Select<CPUBackend>, CPU);
DALI_REGISTER_OPERATOR(Select, Select<GPUBackend>, GPU);

}  // namespace dali
107 changes: 107 additions & 0 deletions dali/operators/generic/select.h
@@ -0,0 +1,107 @@
// Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef DALI_OPERATORS_GENERIC_SELECT_H_
#define DALI_OPERATORS_GENERIC_SELECT_H_

#include <vector>
#include "dali/pipeline/operator/operator.h"
#include "dali/pipeline/operator/arg_helper.h"
#include "dali/kernels/common/scatter_gather.h"

namespace dali {

template <typename Backend>
class Select : public Operator<Backend> {
public:
explicit Select(const OpSpec &spec) : Operator<Backend>(spec), input_idx_("input_idx", spec) {
has_layout_arg_ = spec.TryGetArgument(layout_, "layout");
}

bool CanInferOutputs() const override { return true; }

USE_OPERATOR_MEMBERS();

bool SetupImpl(vector<OutputDesc> &outputs, const workspace_t<Backend> &ws) override {
auto &inp0 = ws.template InputRef<Backend>(0);
int num_inp = spec_.NumRegularInput();
int sample_dim = inp0.sample_dim();
for (int i = 1; i < spec_.NumRegularInput(); i++) {
auto &inp = ws.template InputRef<Backend>(i);
DALI_ENFORCE(inp.type() == inp0.type(), make_string(
"All inputs must have the same type. "
"Got: ", inp0.type().id(), " and ", inp.type().id()));

DALI_ENFORCE(inp.sample_dim() == sample_dim, make_string(
"All inputs must have the same number of dimensions. "
"Got: ", sample_dim, " and ", inp.sample_dim()));

Contributor: Shouldn't we check layouts as well?

Contributor Author: No; the semantics are that we pick the first non-empty layout - this sort of implies that they can differ - and there are checks in the executor that require the non-empty input layouts to be of the correct length.
}

if (has_layout_arg_) {
DALI_ENFORCE(layout_.size() == sample_dim, make_string("The layout '", layout_, "' is not "
"a valid layout for ", sample_dim, "-D tensors."));
}

int num_samples = inp0.ntensor();

outputs.resize(1);
outputs[0].shape.resize(num_samples, sample_dim);
outputs[0].type = inp0.type();

input_idx_.Acquire(spec_, ws, num_samples);

TensorShape<> empty_shape;
empty_shape.resize(sample_dim);

for (int i = 0; i < num_samples; i++) {
int idx = *input_idx_[i].data;
bool is_valid_idx = (sample_dim > 0 && idx < 0) || (idx >= 0 && idx < num_inp);
DALI_ENFORCE(is_valid_idx, make_string("Invalid input index for sample ", i, ": ", idx));
if (idx < 0) {
outputs[0].shape.set_tensor_shape(i, empty_shape);
} else {
auto &inp = ws.template InputRef<Backend>(idx);
outputs[0].shape.set_tensor_shape(i, inp.tensor_shape(i));
}
}
return true;
}

void RunImpl(workspace_t<Backend> &ws) override;

private:
void SetOutputLayout(workspace_t<Backend> &ws) {
if (has_layout_arg_) {
ws.template OutputRef<Backend>(0).SetLayout(layout_);
} else {
for (int i = 0; i < spec_.NumRegularInput(); i++) {
auto &inp = ws.template InputRef<Backend>(i);
auto layout = inp.GetLayout();
if (!layout.empty()) {
ws.template OutputRef<Backend>(0).SetLayout(layout);
break;
}
}
}
}

ArgValue<int> input_idx_;
TensorLayout layout_;
bool has_layout_arg_;
kernels::ScatterGatherGPU sg_;
};

} // namespace dali

#endif // DALI_OPERATORS_GENERIC_SELECT_H_
14 changes: 13 additions & 1 deletion dali/test/python/test_dali_cpu_only.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2020-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -891,6 +891,17 @@ def get_data():
        return out
    check_single_input(fn.squeeze, axis_names="YZ", get_data=get_data, input_layout="HWCYZ")

Contributor: I think you need to update test_dali_variable_batch_size.py as well.

Contributor Author (@mzient, Jul 7, 2021): The regular test uses variable batch size. I don't think we need to double that.

Contributor: We do need it, as test_dali_variable_batch_size.py checks that all operators are tested. There is no other way to enforce this kind of test.

def test_select_cpu():
    pipe = Pipeline(batch_size=batch_size, num_threads=3, device_id=None)
    data = fn.external_source(source=get_data, layout="HWC")
    data2 = fn.external_source(source=get_data, layout="HWC")
    data3 = fn.external_source(source=get_data, layout="HWC")
    idx = fn.random.uniform(range=[0, 2], dtype=types.INT32)
    pipe.set_outputs(fn.select(data, data2, data3, input_idx=idx))
    pipe.build()
    for _ in range(3):
        pipe.run()
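
For illustration only, not part of this PR: a hedged sketch of a test exercising the negative-index behaviour documented in the schema (input_idx < 0 yields an empty sample), which the test above does not cover since it only draws non-negative indices. The function names are hypothetical and the imports mirror the module-level ones of this test file.

import numpy as np
from nvidia.dali.pipeline import Pipeline
import nvidia.dali.fn as fn

def _select_samples():
    # three 2x2 uint8 samples
    return [np.full((2, 2), i, dtype=np.uint8) for i in range(3)]

def _select_indices():
    # one index per sample; -1 requests an empty output sample
    return [np.array(0, dtype=np.int32), np.array(-1, dtype=np.int32), np.array(1, dtype=np.int32)]

def test_select_negative_idx_cpu():
    pipe = Pipeline(batch_size=3, num_threads=2, device_id=None)
    a = fn.external_source(source=_select_samples, layout="HW")
    b = fn.external_source(source=_select_samples, layout="HW")
    idx = fn.external_source(source=_select_indices)
    pipe.set_outputs(fn.select(a, b, input_idx=idx))
    pipe.build()
    out, = pipe.run()
    # sample 1 is expected to come out empty (shape [0, 0]); the others are 2x2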

def test_peek_image_shape_cpu():
    pipe = Pipeline(batch_size=batch_size, num_threads=4, device_id=None)
    input, _ = fn.readers.file(file_root=images_dir, shard_id=0, num_shards=1)
@@ -1025,6 +1036,7 @@ def test_separated_exec_setup():
"resize_crop_mirror",
"fast_resize_crop_mirror",
"segmentation.select_masks",
"select",
"slice",
"segmentation.random_mask_pixel",
"transpose",