Skip to content

Commit

Permalink
Add transform_unary_fusion (#312)
Browse files Browse the repository at this point in the history
* add transform_unary_fusion

* add compose operator for node_t, support printing functor_composition_t on graphviz

* add predecessors for adjacency_list

* update tests

* temporary testing workaround on gcc
  • Loading branch information
alifahrri authored Dec 26, 2024
1 parent 03ad8ac commit 51cfdfa
Show file tree
Hide file tree
Showing 6 changed files with 1,060 additions and 3 deletions.
52 changes: 51 additions & 1 deletion include/nmtools/array/functional/compute_graph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,17 @@ namespace nmtools::functional
operands_type operands;
output_shape_type output_shape = {};
output_element_type output_element = {};

template <typename other_functor_t, typename other_operands_t, typename other_output_shape_t, typename other_output_element_t>
constexpr auto operator*(const node_t<other_functor_t,other_operands_t,other_output_shape_t,other_output_element_t>& other) const
{
auto composition = functor * other.functor;
// todo handle different arity
auto operands = other.operands;
using result_t = node_t<decltype(composition),other_operands_t,output_shape_type,output_element_type>;
// TODO: do not discard intermediate shape/type
return result_t{composition,operands,output_shape,output_element};
}
};

template <typename functor_t, typename operands_t>
Expand Down Expand Up @@ -414,14 +425,53 @@ namespace nmtools::utils::impl
auto str = nmtools_string("");
str += "[graphviz_record_layout_open]";
str += fmap_str;
str += " | ";
str += " ";
str += attr_str;
str += "[graphviz_record_layout_close]";

return str;
}
};

template <template<typename...>typename tuple, typename...functors_t, typename operands_t>
struct to_string_t<functional::functor_composition_t<tuple<functors_t...>,operands_t>, graphviz_t>
{
using composition_type = functional::functor_composition_t<tuple<functors_t...>,operands_t>;
using formatter_type = graphviz_t;
using result_type = nmtools_string;

auto operator()(const composition_type& composition) const noexcept
{
auto str = nmtools_string("");
str += "[graphviz_record_layout_open]";

constexpr auto N = sizeof...(functors_t);
meta::template_for<N>([&](auto index){
constexpr auto I = decltype(index)::value;
const auto& functor = nmtools::get<I>(composition.functors);
auto fmap_str = to_string(functor.fmap,utils::Compact);
auto attr_str = nmtools_string("");
using attributes_t = decltype(functor.attributes);
constexpr auto M = meta::len_v<attributes_t>;
meta::template_for<M>([&](auto index){
const auto& attribute = nmtools::at(functor.attributes,index);
attr_str += to_string(attribute,utils::Compact);
if (index < (M-1)) {
attr_str += ",";
}
});
str += fmap_str;
str += attr_str;
if constexpr (index < N-1) {
str += " | ";
}
});

str += "[graphviz_record_layout_close]";
return str;
}
};

template <typename functor_t, typename operands_t, typename output_shape_t, typename output_element_t>
struct to_string_t<
functional::node_t<functor_t,operands_t,output_shape_t,output_element_t>, graphviz_t, void
Expand Down
58 changes: 58 additions & 0 deletions include/nmtools/array/functional/transform/unary_fusion.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#ifndef NMTOOLS_ARRAY_FUNCTIONAL_TRANSFORM_UNARY_FUSION_HPP
#define NMTOOLS_ARRAY_FUNCTIONAL_TRANSFORM_UNARY_FUSION_HPP

#include "nmtools/meta.hpp"
#include "nmtools/utility/ct_digraph.hpp"
#include "nmtools/array/functional/functor.hpp"

namespace nmtools::functional
{
// find a node that has single output and single input, in which the input only used by that node
template <typename adjacency_list_t>
constexpr auto find_unary_fusion(const adjacency_list_t& adjacency_list)
{
auto sorted = utility::topological_sort(adjacency_list);
auto predecessors = utility::predecessors(adjacency_list);

auto from = -1;
auto to = -1;
for (nm_size_t i=0; i<(nm_size_t)sorted.size(); i++) {
auto node = sorted[i];
if (predecessors[node].size() != 1) {
continue;
}
auto predecessor = predecessors[node][0];
if ((predecessors[predecessor].size() == 0) || (adjacency_list[predecessor].size() > 1)) {
continue;
}
from = predecessor;
to = node;
break;
}
return nmtools_tuple{from,to};
} // find_unary_fusion

template <typename graph_t>
constexpr auto transform_unary_fusion(const graph_t& graph)
{
// TODO: check if graph is fn::compute_graph_t or utility::ct_digraph
constexpr auto adjacency_result = utility::adjacency_list(decltype(graph.digraph){});
constexpr auto adjacency_list = nmtools::get<0>(adjacency_result);
constexpr auto src_id_map = nmtools::get<1>(adjacency_result);

constexpr auto unary_fusion = find_unary_fusion(adjacency_list);
constexpr auto from = nmtools::get<0>(unary_fusion);
constexpr auto to = nmtools::get<1>(unary_fusion);

if constexpr ((from < 0) || (to < 0)) {
return graph;
} else {
auto from_ct = meta::ct_v<src_id_map[from]>;
auto to_ct = meta::ct_v<src_id_map[to]>;
auto fused = graph.nodes(to_ct) * graph.nodes(from_ct);
return utility::contracted_edge(graph,nmtools_tuple{from_ct,to_ct},to_ct,fused);
}
}
} // namespace nmtools::functional

#endif // NMTOOLS_ARRAY_FUNCTIONAL_TRANSFORM_UNARY_FUSION_HPP
27 changes: 25 additions & 2 deletions include/nmtools/utility/ct_digraph.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,14 @@ namespace nmtools::utility

constexpr auto nodes = meta::to_value_v<nodes_t>;

for (nm_size_t i=0; i<N; i++) {
for (nm_size_t i=0; i<(nm_size_t)N; i++) {
id_map[i] = nodes[i];
}

// given id, return node
auto reverse_id = [&](auto id)->nm_index_t{
for (nm_size_t i=0; i<(nm_size_t)len(id_map); i++) {
if (id_map[i] == id) {
if (id_map[i] == (nm_size_t)id) {
return i;
}
}
Expand Down Expand Up @@ -124,6 +124,29 @@ namespace nmtools::utility
result[list[i][j]]++;
}
}
return result;
} // in_degree

template <typename adjacency_list_t>
constexpr auto predecessors(const adjacency_list_t& adj_list)
{
// assume static_vector or array
// TODO: check for another type
constexpr auto N = meta::bounded_size_v<typename adjacency_list_t::value_type>;

using result_t = utl::static_vector<utl::static_vector<nm_size_t,N>,N>;
auto result = result_t {};

result.resize(adj_list.size());

for (nm_size_t i=0; i<(nm_size_t)adj_list.size(); i++) {
auto edges = adj_list[i];
for (nm_size_t j=0; j<(nm_size_t)edges.size(); j++) {
auto edge = edges.at(j);
result[edge].push_back(i);
}
}

return result;
}

Expand Down
2 changes: 2 additions & 0 deletions tests/functional/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,8 @@ set(FUNCTIONAL_TEST_GRAPH_SOURCES
src/combinator/dig.cpp
src/combinator/swap.cpp
src/combinator/dup.cpp

src/transform/unary_fusion.cpp
)

set(FUNCTIONAL_TEST_MISC_SOURCES
Expand Down
159 changes: 159 additions & 0 deletions tests/functional/src/misc/ct_digraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
namespace nm = nmtools;
namespace fn = nmtools::functional;
namespace meta = nmtools::meta;
namespace view = nmtools::view;
namespace na = nmtools::array;
namespace ix = nmtools::index;
namespace utils = nmtools::utils;
Expand Down Expand Up @@ -944,6 +945,7 @@ TEST_CASE("contracted_edge(case3)" * doctest::test_suite("ct_digraph"))
}
}

// matmul
TEST_CASE("contracted_edge(case1)" * doctest::test_suite("ct_digraph"))
{
auto lhs_shape = array{3,4};
Expand Down Expand Up @@ -1532,4 +1534,161 @@ TEST_CASE("topological_sort(case3a)" * doctest::test_suite("ct_digraph"))
auto expected = array{0,472,435,470,51,428,391,433,1022};

NMTOOLS_ASSERT_EQUAL( result, expected );
}

// matmul
TEST_CASE("predecessors(case1)" * doctest::test_suite("ct_digraph"))
{
using vector = utl::static_vector<int,8>;

auto make_vector = [](auto...v){
auto vec = vector();
(vec.push_back(v),...);
return vec;
};

auto adj_list = array<vector,8>{
make_vector(2),
make_vector(3),
make_vector(4),
make_vector(5),
make_vector(6),
make_vector(6),
make_vector(7),
make_vector(),
};

auto result = utility::predecessors(adj_list);
auto expected = array<vector,8>{
make_vector(),
make_vector(),
make_vector(0),
make_vector(1),
make_vector(2),
make_vector(3),
make_vector(4,5),
make_vector(6),
};

NMTOOLS_ASSERT_EQUAL( result.size(), expected.size() );
for (nm_size_t i=0; i<(nm_size_t)result.size(); i++) {
NMTOOLS_ASSERT_EQUAL( result.at(i), expected.at(i) );
}
}

// softmax
TEST_CASE("predecessors(case2)" * doctest::test_suite("ct_digraph"))
{
using vector = utl::static_vector<int,8>;

auto make_vector = [](auto...v){
auto vec = vector();
(vec.push_back(v),...);
return vec;
};

auto adj_list = array<vector,6>{
make_vector(1,2),
make_vector(2),
make_vector(3),
make_vector(4,5),
make_vector(5),
make_vector()
};

auto result = utility::predecessors(adj_list);
auto expected = array<vector,6>{
make_vector(),
make_vector(0),
make_vector(0,1),
make_vector(2),
make_vector(3),
make_vector(3,4),
};

NMTOOLS_ASSERT_EQUAL( result.size(), expected.size() );
for (nm_size_t i=0; i<(nm_size_t)result.size(); i++) {
NMTOOLS_ASSERT_EQUAL( result.at(i), expected.at(i) );
}
}

// var
TEST_CASE("predecessors(case3)" * doctest::test_suite("ct_digraph"))
{
using vector = utl::static_vector<int,10>;

auto make_vector = [](auto...v){
auto vec = vector();
(vec.push_back(v),...);
return vec;
};

auto adj_list = array<vector,10>{
make_vector(2,4),
make_vector(3),
make_vector(3),
make_vector(4),
make_vector(5),
make_vector(6),
make_vector(7),
make_vector(9),
make_vector(9),
make_vector()
};

auto result = utility::predecessors(adj_list);
auto expected = array<vector,10>{
make_vector(),
make_vector(),
make_vector(0),
make_vector(1,2),
make_vector(0,3),
make_vector(4),
make_vector(5),
make_vector(6),
make_vector(),
make_vector(7,8)
};

NMTOOLS_ASSERT_EQUAL( result.size(), expected.size() );
for (nm_size_t i=0; i<(nm_size_t)result.size(); i++) {
NMTOOLS_ASSERT_EQUAL( result.at(i), expected.at(i) );
}
}

TEST_CASE("predecessors(case1b)" * doctest::test_suite("functional"))
{
using vector = utl::static_vector<int,8>;

auto make_vector = [](auto...v){
auto vec = vector();
(vec.push_back(v),...);
return vec;
};

auto adj_list = array<vector,7>{
make_vector(2),
make_vector(3),
make_vector(5),
make_vector(4),
make_vector(5),
make_vector(6),
make_vector()
};

auto result = utility::predecessors(adj_list);
auto expected = array<vector,7>{
make_vector(),
make_vector(),
make_vector(0),
make_vector(1),
make_vector(3),
make_vector(2,4),
make_vector(5),
};

NMTOOLS_ASSERT_EQUAL( result.size(), expected.size() );
for (nm_size_t i=0; i<(nm_size_t)result.size(); i++) {
NMTOOLS_ASSERT_EQUAL( result.at(i), expected.at(i) );
}
}
Loading

0 comments on commit 51cfdfa

Please sign in to comment.