[TensorV2] Add functions to split and concatenate tensors
This pull request adds two new methods to the TensorV2 class: split() and cat().
The split() method divides a tensor into multiple smaller tensors along a specified axis.
The cat() method combines multiple tensors into a single larger tensor along a given axis.

**Self-evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test:   [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Donghyeon Jeong <dhyeon.jeong@samsung.com>
djeong20 authored and jijoongmoon committed Mar 11, 2024
1 parent a4eac32 commit 847c552
Showing 7 changed files with 543 additions and 12 deletions.
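As a quick illustration of the new API, here is a hypothetical usage sketch. It is not taken from this commit; the tensor shape, the TensorDim constructor call, and the assumption that TensorV2 mirrors the split()/cat() signatures added to FloatTensor below are all illustrative.

// Hypothetical usage sketch; assumes a TensorV2 built from a TensorDim of
// {batch, channel, height, width} in the default NCHW format.
TensorV2 t(TensorDim(1, 1, 2, 6)); // 1x1x2x6 tensor

// Split the width axis (axis 3) into chunks of 1, 2, and 3 columns.
// The sizes must sum to the dimension being split: 1 + 2 + 3 == 6.
std::vector<TensorV2> parts = t.split({1, 2, 3}, 3);

// Concatenate the pieces back along the same axis. All inputs must agree on
// every dimension except the concatenation axis; the result is 1x1x2x6 again.
TensorV2 joined = TensorV2::cat(parts, 3);

Note that split() takes explicit chunk sizes rather than a chunk count, so uneven splits such as {1, 2, 3} are expressed directly.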
198 changes: 195 additions & 3 deletions nntrainer/tensor/float_tensor.cpp
@@ -162,12 +162,19 @@ const void *FloatTensor::getAddress(unsigned int i) const {
  return &((float *)getData())[i];
}

-const float FloatTensor::getValue(unsigned int i) const {
+const float &FloatTensor::getValue(unsigned int i) const {
  return ((float *)getData())[i];
}

-const float FloatTensor::getValue(unsigned int b, unsigned int c,
-                                  unsigned int h, unsigned int w) const {
+float &FloatTensor::getValue(unsigned int i) { return ((float *)getData())[i]; }
+
+const float &FloatTensor::getValue(unsigned int b, unsigned int c,
+                                   unsigned int h, unsigned int w) const {
  return getValue(getIndex(b, c, h, w));
}
+
+float &FloatTensor::getValue(unsigned int b, unsigned int c, unsigned int h,
+                             unsigned int w) {
+  return getValue(getIndex(b, c, h, w));
+}

@@ -896,6 +903,191 @@ void FloatTensor::zoneout_mask(TensorV2 &opposite, float zoneout) {
  }
}

std::vector<TensorV2> FloatTensor::split(std::vector<size_t> sizes, int axis) {
  size_t num_size = sizes.size();

  if (axis == -1) {
    axis = 3;
  }

  size_t total_size = std::accumulate(sizes.begin(), sizes.end(), size_t(0));
  NNTR_THROW_IF(dim.getTensorDim(axis) != total_size, std::invalid_argument)
    << "sum of given sizes does not match the tensor dimension along the "
       "axis, tensor dim: "
    << dim.getTensorDim(axis) << " total size: " << total_size;

  std::vector<TensorDim> ret_dims;
  ret_dims.resize(num_size); // resize rather than reserve: the elements are
                             // assigned through operator[] below
  for (unsigned int i = 0; i < num_size; ++i) {
    ret_dims[i] = dim;
    ret_dims[i].setTensorDim(axis, sizes[i]);
  }

  bool is_format_nchw = (dim.getFormat() == Tformat::NCHW);
  std::vector<TensorV2> ret;

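  // iter_value reads the element at loc, then advances loc like an odometer:
  // the rightmost index increments first and carries into the next index when
  // it reaches end_loc, so successive calls sweep this sub-tensor in
  // row-major order.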
  auto iter_value = [this, is_format_nchw](
                      std::array<size_t, 4> &loc,
                      const std::array<size_t, 4> &end_loc,
                      const std::array<size_t, 4> &reset_dim_arr) -> float & {
    auto &value = (is_format_nchw) ? getValue(loc[0], loc[1], loc[2], loc[3])
                                   : getValue(loc[0], loc[3], loc[1], loc[2]);
    for (int i = 3; i >= 0; --i) {
      loc[i]++;
      if (loc[i] == end_loc[i]) {
        loc[i] -= reset_dim_arr[i];
        continue;
      }
      break;
    }
    return value;
  };

  ret.reserve(num_size);

  unsigned int accumulated_size = 0;
  for (unsigned int i = 0; i < num_size; ++i) {
    std::array<size_t, 4> loc = {0, 0, 0, 0};

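    // In NHWC storage the logical axes (batch, channel, height, width) sit at
    // positions (0, 3, 1, 2) of loc, so the split axis is remapped before
    // offsetting the start location and, further down, the end location.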
    if (is_format_nchw) {
      loc[axis] += accumulated_size;
    } else {
      if (axis == 0) {
        loc[0] += accumulated_size;
      } else if (axis == 1) {
        loc[3] += accumulated_size;
      } else if (axis == 2 || axis == 3) {
        loc[axis - 1] += accumulated_size;
      }
    }

    ret.emplace_back(ret_dims[i]);
    auto &ret_t = ret.back();

    std::array<size_t, 4> end_loc;

    if (is_format_nchw) {
      end_loc = {ret_dims[i].batch(), ret_dims[i].channel(),
                 ret_dims[i].height(), ret_dims[i].width()};
    } else {
      end_loc = {ret_dims[i].batch(), ret_dims[i].height(),
                 ret_dims[i].width(), ret_dims[i].channel()};
    }

    accumulated_size += sizes[i];

    if (is_format_nchw) {
      end_loc[axis] = accumulated_size;
    } else {
      if (axis == 0) {
        end_loc[0] = accumulated_size;
      } else if (axis == 1) {
        end_loc[3] = accumulated_size;
      } else if (axis == 2 || axis == 3) {
        end_loc[axis - 1] = accumulated_size;
      }
    }

    std::array<size_t, 4> reset_dim_arr;
    if (is_format_nchw) {
      reset_dim_arr = {ret_dims[i].batch(), ret_dims[i].channel(),
                       ret_dims[i].height(), ret_dims[i].width()};
    } else {
      reset_dim_arr = {ret_dims[i].batch(), ret_dims[i].height(),
                       ret_dims[i].width(), ret_dims[i].channel()};
    }

    ret_t.apply_i<float>([&iter_value, &loc, &end_loc, &reset_dim_arr](float _) {
      return iter_value(loc, end_loc, reset_dim_arr);
    });
  }

  return ret;
}

TensorV2 FloatTensor::cat(const std::vector<TensorV2> &tensors, int axis) {
  if (axis == -1) {
    axis = 3;
  }

  TensorV2 ret;
  auto ref_dim = tensors.front().getDim();
  bool is_format_nchw = (ref_dim.getFormat() == Tformat::NCHW);
  ref_dim.setTensorDim(axis, 1);
  NNTR_THROW_IF(!std::all_of(tensors.begin(), tensors.end(),
                             [&ref_dim, axis](const TensorV2 &t) {
                               auto cur_dim = t.getDim();
                               cur_dim.setTensorDim(axis, 1);
                               return ref_dim == cur_dim;
                             }),
                std::invalid_argument)
    << "all tensors must have the same dimensions except for the given axis, "
       "ref_dim: "
    << ref_dim << " axis: " << axis;

  auto axis_dim = std::accumulate(tensors.begin(), tensors.end(), 0u,
                                  [axis](unsigned cur, const TensorV2 &t) {
                                    return cur + t.getDim().getTensorDim(axis);
                                  });
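  // Same odometer-style walk as in split(), but indices wrap back to
  // start_loc instead of the origin, so each input writes into its own
  // slice of the output tensor.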
  auto iter_value =
    [is_format_nchw](std::array<unsigned, 4> &loc,
                     const std::array<unsigned, 4> &start_loc, TensorV2 &t,
                     const std::array<unsigned, 4> &ref_dim_arr) -> float & {
    auto &value = is_format_nchw
                    ? t.getValue<float>(loc[0], loc[1], loc[2], loc[3])
                    : t.getValue<float>(loc[0], loc[3], loc[1], loc[2]);

    for (int i = 3; i >= 0; --i) {
      loc[i]++;
      if (loc[i] - start_loc[i] == ref_dim_arr[i]) {
        loc[i] = start_loc[i];
        continue;
      }
      break;
    }
    return value;
  };

  auto ret_dim = ref_dim;
  ret_dim.setTensorDim(axis, axis_dim);

  ret = TensorV2(ret_dim);

  std::array<unsigned, 4> loc = {0, 0, 0, 0};
  for (auto &t : tensors) {
    std::array<unsigned, 4> start_loc = loc;
    std::array<unsigned, 4> tensor_dim_arr;
    if (is_format_nchw) {
      tensor_dim_arr[0] = t.getDim().getTensorDim(0);
      tensor_dim_arr[1] = t.getDim().getTensorDim(1);
      tensor_dim_arr[2] = t.getDim().getTensorDim(2);
      tensor_dim_arr[3] = t.getDim().getTensorDim(3);
    } else {
      tensor_dim_arr[0] = t.getDim().getTensorDim(0);
      tensor_dim_arr[1] = t.getDim().getTensorDim(2);
      tensor_dim_arr[2] = t.getDim().getTensorDim(3);
      tensor_dim_arr[3] = t.getDim().getTensorDim(1);
    }

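    // Copy every element of t into ret; iter_value advances the write
    // location within this tensor's slice of the output.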
    for (size_t i = 0u, sz = t.size(); i < sz; ++i) {
      iter_value(loc, start_loc, ret, tensor_dim_arr) = t.getValue<float>(i);
    }

    if (is_format_nchw) {
      loc[axis] += t.getDim().getTensorDim(axis);
    } else {
      if (axis == 0) {
        loc[0] += t.getDim().getTensorDim(axis);
      } else if (axis == 1) {
        loc[3] += t.getDim().getTensorDim(axis);
      } else if (axis == 2 || axis == 3) {
        loc[axis - 1] += t.getDim().getTensorDim(axis);
      }
    }
  }

  return ret;
}

void FloatTensor::print(std::ostream &out) const {
  printInstance(out, this);
  const float *data = (float *)getData();
32 changes: 29 additions & 3 deletions nntrainer/tensor/float_tensor.h
@@ -119,7 +119,13 @@ class FloatTensor : public TensorBase {
   * @brief     return value at specific location
   * @param[in] i index
   */
-  const float getValue(unsigned int i) const;
+  const float &getValue(unsigned int i) const;
+
+  /**
+   * @brief     return value at specific location
+   * @param[in] i index
+   */
+  float &getValue(unsigned int i);

  /**
   * @brief     return value at specific location
@@ -128,8 +134,18 @@
   * @param[in] h height location
   * @param[in] w width location
   */
-  const float getValue(unsigned int b, unsigned int c, unsigned int h,
-                       unsigned int w) const;
+  const float &getValue(unsigned int b, unsigned int c, unsigned int h,
+                        unsigned int w) const;
+
+  /**
+   * @brief     return value at specific location
+   * @param[in] b batch location
+   * @param[in] c channel location
+   * @param[in] h height location
+   * @param[in] w width location
+   */
+  float &getValue(unsigned int b, unsigned int c, unsigned int h,
+                  unsigned int w);

  /**
   * @copydoc TensorV2::setValue(float value)
@@ -302,6 +318,16 @@
   */
  void zoneout_mask(TensorV2 &opposite, float zoneout) override;

  /**
   * @copydoc TensorV2::split(std::vector<size_t> sizes, int axis)
   */
  std::vector<TensorV2> split(std::vector<size_t> sizes, int axis) override;

  /**
   * @copydoc TensorV2::cat(const std::vector<TensorV2> &tensors, int axis)
   */
  static TensorV2 cat(const std::vector<TensorV2> &tensors, int axis);

  /**
   * @copydoc TensorV2::copy(const TensorV2 &from)
   */
