Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 0 additions & 29 deletions src/frontends/pytorch/src/op/neg.cpp

This file was deleted.

199 changes: 95 additions & 104 deletions src/frontends/pytorch/src/op/pixel_shuffle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,21 @@
// SPDX-License-Identifier: Apache-2.0
//

#include <utility>

#include "openvino/frontend/pytorch/node_context.hpp"
#include "openvino/op/add.hpp"
#include "openvino/op/concat.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/depth_to_space.hpp"
#include "openvino/op/divide.hpp"
#include "openvino/op/gather.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/range.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/op/shuffle_channels.hpp"
#include "openvino/op/slice.hpp"
#include "openvino/op/space_to_depth.hpp"
#include "openvino/op/split.hpp"
#include "openvino/op/squeeze.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "utils.hpp"

namespace ov {
Expand All @@ -26,116 +26,107 @@ namespace op {

using namespace ov::op;

namespace {

// Holds the split shape information; dims_before captures all pre-spatial dims
struct PixelSpatialInfo {
    Output<Node> dims_before;  // every leading dim before the trailing (C, H, W) triple
    Output<Node> channels;     // dim at index -3
    Output<Node> height;       // dim at index -2
    Output<Node> width;        // dim at index -1
};

// Splits the (dynamic) input shape into the leading dims and the last three
// dims, returning each of C, H, W as a separate 1-element shape output.
PixelSpatialInfo get_pixel_spatial_info(const NodeContext& context, const Output<Node>& x) {
    const auto slice_start = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {0}));
    const auto axis_scalar = context.mark_node(v0::Constant::create(element::i32, Shape{}, {0}));
    const auto slice_stop = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-3}));
    const auto slice_step = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {1}));
    const auto full_shape = context.mark_node(std::make_shared<v3::ShapeOf>(x, element::i32));
    // Everything up to (but excluding) the trailing (C, H, W) triple.
    const auto leading =
        context.mark_node(std::make_shared<v8::Slice>(full_shape, slice_start, slice_stop, slice_step));
    // Pick the last three dims and split them into three 1-element outputs.
    const auto chw_indices = context.mark_node(v0::Constant::create(element::i32, Shape{3}, {-3, -2, -1}));
    const auto chw = context.mark_node(std::make_shared<v8::Gather>(full_shape, chw_indices, axis_scalar));
    const auto parts = context.mark_node(std::make_shared<v1::Split>(chw, axis_scalar, 3));
    return {leading, parts->output(0), parts->output(1), parts->output(2)};
}

// Builds the target shape [-1, C, H, W]: all leading dims collapse into one
// batch dim so the 4-D DepthToSpace/SpaceToDepth ops can be applied.
Output<Node> make_flatten_shape(const NodeContext& context,
                                const Output<Node>& channels,
                                const Output<Node>& height,
                                const Output<Node>& width) {
    const auto collapse_marker = context.mark_node(v0::Constant::create(element::i32, Shape{1}, {-1}));
    const auto spatial_part =
        context.mark_node(std::make_shared<v0::Concat>(OutputVector{channels, height, width}, 0));
    return context.mark_node(std::make_shared<v0::Concat>(OutputVector{collapse_marker, spatial_part}, 0));
}

// Builds the output shape [*dims_before, new_C, new_H, new_W] that restores
// the original leading dims after the transform.
Output<Node> make_final_shape(const NodeContext& context,
                              const Output<Node>& dims_before,
                              const Output<Node>& new_c,
                              const Output<Node>& new_h,
                              const Output<Node>& new_w) {
    const auto trailing =
        context.mark_node(std::make_shared<v0::Concat>(OutputVector{new_c, new_h, new_w}, 0));
    return context.mark_node(std::make_shared<v0::Concat>(OutputVector{dims_before, trailing}, 0));
}

// Shared lowering for aten::pixel_shuffle / aten::pixel_unshuffle.
// Collapses all leading dims into one batch dim, applies DepthToSpace
// (shuffle) or SpaceToDepth (unshuffle) in DEPTH_FIRST mode, then reshapes
// back so the original leading dims are restored around the new (C, H, W).
OutputVector translate_pixel_transform(const NodeContext& context, bool is_shuffle) {
    num_inputs_check(context, 2, 2);
    const auto input = context.get_input(0);
    const auto factor = context.const_input<int64_t>(1);
    PYTORCH_OP_CONVERSION_CHECK(factor > 0, "Upscale factor for pixel shuffle ops must be positive");

    const auto factor_size = static_cast<size_t>(factor);
    const auto factor_scalar =
        context.mark_node(v0::Constant::create(element::i32, Shape{}, {static_cast<int32_t>(factor)}));
    const auto factor_sq_scalar =
        context.mark_node(v0::Constant::create(element::i32, Shape{}, {static_cast<int32_t>(factor * factor)}));

    const auto [leading_dims, c, h, w] = get_pixel_spatial_info(context, input);
    const auto collapsed_shape = make_flatten_shape(context, c, h, w);
    const auto collapsed = context.mark_node(std::make_shared<v1::Reshape>(input, collapsed_shape, false));

    Output<Node> transformed;
    Output<Node> out_c;
    Output<Node> out_h;
    Output<Node> out_w;

    if (is_shuffle) {
        // pixel_shuffle: C -> C / r^2, H -> H * r, W -> W * r
        transformed = context.mark_node(
            std::make_shared<v0::DepthToSpace>(collapsed, v0::DepthToSpace::DepthToSpaceMode::DEPTH_FIRST, factor_size));
        out_c = context.mark_node(std::make_shared<v1::Divide>(c, factor_sq_scalar));
        out_h = context.mark_node(std::make_shared<v1::Multiply>(h, factor_scalar));
        out_w = context.mark_node(std::make_shared<v1::Multiply>(w, factor_scalar));
    } else {
        // pixel_unshuffle: C -> C * r^2, H -> H / r, W -> W / r
        transformed = context.mark_node(
            std::make_shared<v0::SpaceToDepth>(collapsed, v0::SpaceToDepth::SpaceToDepthMode::DEPTH_FIRST, factor_size));
        out_c = context.mark_node(std::make_shared<v1::Multiply>(c, factor_sq_scalar));
        out_h = context.mark_node(std::make_shared<v1::Divide>(h, factor_scalar, true));
        out_w = context.mark_node(std::make_shared<v1::Divide>(w, factor_scalar, true));
    }

    const auto restored_shape = make_final_shape(context, leading_dims, out_c, out_h, out_w);
    auto output = context.mark_node(std::make_shared<v1::Reshape>(transformed, restored_shape, false));
    return {std::move(output)};
}

} // namespace

OutputVector translate_pixel_shuffle(const NodeContext& context) {
    // aten::pixel_shuffle(Tensor self, int upscale_factor) -> Tensor
    // Defect fixed: the stale pre-refactor implementation was left fused above
    // the delegating return (diff residue), making that return unreachable.
    // The op now delegates entirely to the shared DepthToSpace-based lowering.
    return translate_pixel_transform(context, true);
};

OutputVector translate_pixel_unshuffle(const NodeContext& context) {
    // aten::pixel_unshuffle(Tensor self, int upscale_factor) -> Tensor
    // Defect fixed: the stale reshape/transpose/reshape implementation was
    // left fused above the delegating return (diff residue), making that
    // return unreachable. The op now delegates to the shared
    // SpaceToDepth-based lowering (is_shuffle = false).
    return translate_pixel_transform(context, false);
};

OutputVector translate_channel_shuffle(const NodeContext& context) {
    // aten::channel_shuffle(Tensor self, int groups) -> Tensor
    // Defect fixed: the stale manual reshape/transpose/reshape implementation
    // was left fused above the new lowering (diff residue), making the new
    // code unreachable. A single ShuffleChannels op on axis 1 (the channels
    // dim) replaces the whole hand-rolled sequence.
    num_inputs_check(context, 2, 2);
    const auto x = context.get_input(0);
    const auto groups = context.const_input<int64_t>(1);
    PYTORCH_OP_CONVERSION_CHECK(groups > 0, "groups argument for channel_shuffle must be positive");
    auto shuffled = context.mark_node(std::make_shared<v0::ShuffleChannels>(x, 1, static_cast<size_t>(groups)));
    return {std::move(shuffled)};
};

} // namespace op
Expand Down
7 changes: 3 additions & 4 deletions src/frontends/pytorch/src/op_table.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,6 @@ OP_CONVERTER(translate_movedim);
OP_CONVERTER(translate_multinomial);
OP_CONVERTER(translate_narrow);
OP_CONVERTER(translate_native_multi_head_attention);
OP_CONVERTER(translate_neg);
OP_CONVERTER(translate_new_full);
OP_CONVERTER(translate_new_ones);
OP_CONVERTER(translate_new_zeros);
Expand Down Expand Up @@ -638,7 +637,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::multinomial", op::translate_multinomial},
{"aten::narrow", op::translate_narrow},
{"aten::ne", op::translate_1to1_match_2_inputs_align_types<opset10::NotEqual>},
{"aten::neg", op::translate_neg},
{"aten::neg", op::translate_1to1_match_1_inputs<opset10::Negative>},
{"aten::new_empty", op::translate_new_zeros},
{"aten::new_full", op::translate_new_full},
{"aten::new_ones", op::translate_new_ones},
Expand Down Expand Up @@ -827,7 +826,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_fx() {
{"<built-in function floordiv>", op::translate_floor_divide},
{"<built-in function getitem>", op::translate_getitem}, // TODO: Check if there is any other way to handle this
{"<built-in function mul>", op::translate_mul},
{"<built-in function neg>", op::translate_neg},
{"<built-in function neg>", op::translate_1to1_match_1_inputs<opset10::Negative>},
{"<built-in function sub>", op::translate_sub},
{"aten._adaptive_avg_pool1d.default", op::translate_adaptive_avg_pool1d},
{"aten._adaptive_avg_pool2d.default", op::translate_adaptive_avg_pool2d},
Expand Down Expand Up @@ -1014,7 +1013,7 @@ const std::unordered_map<std::string, CreatorFunction> get_supported_ops_fx() {
{"aten.native_layer_norm.default", op::translate_layer_norm_fx},
{"aten.ne.Scalar", op::translate_1to1_match_2_inputs_align_types<opset10::NotEqual>},
{"aten.ne.Tensor", op::translate_1to1_match_2_inputs_align_types<opset10::NotEqual>},
{"aten.neg.default", op::translate_neg},
{"aten.neg.default", op::translate_1to1_match_1_inputs<opset10::Negative>},
{"aten.new_full.default", op::translate_new_full_fx},
{"aten.new_ones.default", op::translate_new_ones_fx},
{"aten.new_zeros.default", op::translate_new_zeros_fx},
Expand Down
Loading