1+ // Copyright (c) 2020-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+ //
3+ // Licensed under the Apache License, Version 2.0 (the "License");
4+ // you may not use this file except in compliance with the License.
5+ // You may obtain a copy of the License at
6+ //
7+ // http://www.apache.org/licenses/LICENSE-2.0
8+ //
9+ // Unless required by applicable law or agreed to in writing, software
10+ // distributed under the License is distributed on an "AS IS" BASIS,
11+ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+ // See the License for the specific language governing permissions and
13+ // limitations under the License.
14+
15+
16+ #ifndef DALI_OPERATORS_GENERIC_REDUCE_AXIS_HELPER_H__
17+ #define DALI_OPERATORS_GENERIC_REDUCE_AXIS_HELPER_H__
18+
19+ #include < vector>
20+
21+ #include " dali/pipeline/operator/operator.h"
22+
23+ namespace dali {
24+ namespace detail {
25+
26+ class AxesHelper {
27+ public:
28+ explicit inline AxesHelper (const OpSpec &spec) {
29+ has_axes_arg_ = spec.TryGetRepeatedArgument (axes_, " axes" );
30+ has_axis_names_arg_ = spec.TryGetArgument (axis_names_, " axis_names" );
31+ has_empty_axes_arg_ =
32+ (has_axes_arg_ && axes_.empty ()) || (has_axis_names_arg_ && axis_names_.empty ());
33+
34+ DALI_ENFORCE (!has_axes_arg_ || !has_axis_names_arg_,
35+ " Arguments `axes` and `axis_names` are mutually exclusive" );
36+ }
37+
38+ void PrepareAxes (const TensorLayout &layout, int sample_dim) {
39+ if (has_axis_names_arg_) {
40+ axes_ = GetDimIndices (layout, axis_names_).to_vector ();
41+ return ;
42+ }
43+
44+ if (!has_axes_arg_) {
45+ axes_.resize (sample_dim);
46+ std::iota (axes_.begin (), axes_.end (), 0 );
47+ }
48+ }
49+
50+ bool has_axes_arg_;
51+ bool has_axis_names_arg_;
52+ bool has_empty_axes_arg_;
53+ std::vector<int > axes_;
54+ TensorLayout axis_names_;
55+ };
56+
57+ } // namespace detail
58+ } // namespace dali
59+
60+ #endif // DALI_OPERATORS_GENERIC_REDUCE_AXIS_HELPER_H__
0 commit comments