diff --git a/include/nmtools/utility/ct_digraph.hpp b/include/nmtools/utility/ct_digraph.hpp index a5893f60a..b5f70740d 100644 --- a/include/nmtools/utility/ct_digraph.hpp +++ b/include/nmtools/utility/ct_digraph.hpp @@ -5,6 +5,7 @@ #include "nmtools/utility/at.hpp" #include "nmtools/utility/ct_map.hpp" #include "nmtools/utl/static_vector.hpp" +#include "nmtools/utl/static_queue.hpp" #include "nmtools/assert.hpp" namespace nmtools::utility @@ -103,6 +104,27 @@ namespace nmtools::utility }); return nmtools_tuple{list,id_map}; + } // adjacency_list + + template + constexpr auto in_degree(const adjacency_list_t& list) + { + // assume static_vector or array + // TODO: check for another type + constexpr auto N = meta::bounded_size_v; + + using result_t = utl::static_vector; + auto result = result_t {}; + + auto n = len(list); + result.resize(n); + + for (nm_size_t i=0; i<(nm_size_t)n; i++) { + for (nm_size_t j=0; j<(nm_size_t)list[i].size(); j++) { + result[list[i][j]]++; + } + } + return result; } template , typename edges_t=nmtools_tuple<>, typename node_data_t=nmtools_tuple<>> @@ -354,6 +376,55 @@ namespace nmtools::utility return result; } -} + + template + constexpr auto topological_sort(const adjacency_list_t& adjacency_list) + { + // assume static_vector or array + // TODO: check for another type + constexpr auto N = meta::bounded_size_v; + + auto in_degree = utility::in_degree(adjacency_list); + + auto queue = utl::static_queue(); + auto result = utl::static_vector(); + + // initialize queue with input (in_degree 0) + for (nm_size_t i=0; i + constexpr auto topological_sort(const ct_digraph& graph) + { + constexpr auto adjacency_result = adjacency_list(decltype(graph.digraph){}); + constexpr auto adjacency_list = nmtools::get<0>(adjacency_result); + constexpr auto src_id_map = nmtools::get<1>(adjacency_result); + + constexpr auto sorted = topological_sort(adjacency_list); + constexpr auto N = sorted.size(); + + return meta::template_reduce([&](auto init, auto index){ + auto node_id = meta::ct_v; + return tuple_append(init,node_id); + },nmtools_tuple{}); + } +} // nmtools::utility #endif // NMTOOLS_UTILITY_CT_DIGRAPH_HPP \ No newline at end of file diff --git a/include/nmtools/utl/static_queue.hpp b/include/nmtools/utl/static_queue.hpp new file mode 100644 index 000000000..cf2b0e09e --- /dev/null +++ b/include/nmtools/utl/static_queue.hpp @@ -0,0 +1,69 @@ +#ifndef NMTOOLS_UTL_STATIC_QUEUE_HPP +#define NMTOOLS_UTL_STATIC_QUEUE_HPP + +#include "nmtools/def.hpp" +#include "nmtools/platform.hpp" +#include "nmtools/assert.hpp" +#include "nmtools/utl/common.hpp" +#include "nmtools/utl/array.hpp" +#include "nmtools/meta/bits/array/resize_bounded_size.hpp" + +namespace nmtools::utl +{ + template + struct static_queue + { + using buffer_type = utl::array; + using size_type = nm_size_t; + + buffer_type buffer_; + size_type size_; + + constexpr static_queue() + : buffer_{} + , size_{0} + {} + + constexpr auto front() const noexcept + { + return buffer_[0]; + } + + constexpr auto back() const noexcept + { + return buffer_[size_-1]; + } + + constexpr auto size() const noexcept + { + return size_; + } + + constexpr auto empty() const noexcept + { + return size_ == 0; + } + + constexpr auto push(const T& t) + { + nmtools_assert( size_ < N + , "static_queue already full"); + + size_ += 1; + buffer_[size_-1] = t; + } + + constexpr auto pop() + { + size_ -= 1; + auto res = buffer_[0]; + // simply move all elements + for (nm_size_t i=0; i<(nm_size_t)size_; i++) { + buffer_[i] = buffer_[i+1]; + } + return res; + } + }; +} + +#endif // NMTOOLS_UTL_STATIC_QUEUE_HPP \ No newline at end of file diff --git a/tests/functional/src/misc/ct_digraph.cpp b/tests/functional/src/misc/ct_digraph.cpp index 2d016a0b5..4c15f5a23 100644 --- a/tests/functional/src/misc/ct_digraph.cpp +++ b/tests/functional/src/misc/ct_digraph.cpp @@ -836,7 +836,6 @@ TEST_CASE("contracted_edge(case1)" * doctest::test_suite("ct_digraph")) make_vector(5), make_vector(), make_vector(4), - }; auto expected_id = array{0,1,769,447,722,635,765}; @@ -892,6 +891,8 @@ TEST_CASE("contracted_edge(case2)" * doctest::test_suite("ct_digraph")) } } + +// var TEST_CASE("contracted_edge(case3)" * doctest::test_suite("ct_digraph")) { using vector = utl::static_vector; @@ -1266,4 +1267,269 @@ TEST_CASE("adjacency_to_graph(case1)" * doctest::test_suite("ct_digraph")) meta::remove_cvref_t , meta::remove_cvref_t ); +} + +// matmul +TEST_CASE("in_degree(case1)" * doctest::test_suite("ct_digraph")) +{ + using vector = utl::static_vector; + + auto make_vector = [](auto...v){ + auto vec = vector(); + (vec.push_back(v),...); + return vec; + }; + auto list = array{ + make_vector(6), + make_vector(2), + make_vector(3), + make_vector(4), + make_vector(5), + make_vector(), + make_vector(4), + }; + + auto result = utility::in_degree(list); + + auto expected = make_vector(0,0,1,1,2,1,1); + + NMTOOLS_ASSERT_EQUAL(result,expected); +} + +// softmax +TEST_CASE("in_degree(case2)" * doctest::test_suite("ct_digraph")) +{ + using vector = utl::static_vector; + + auto make_vector = [](auto...v){ + auto vec = vector(); + (vec.push_back(v),...); + return vec; + }; + + auto list = array{ + make_vector(1,2), + make_vector(2), + make_vector(3), + make_vector(4,5), + make_vector(5), + make_vector() + }; + + auto result = utility::in_degree(list); + + auto expected = make_vector(0,1,2,1,1,2); + + NMTOOLS_ASSERT_EQUAL(result,expected); +} + +// var +TEST_CASE("in_degree(case3)" * doctest::test_suite("ct_digraph")) +{ + using vector = utl::static_vector; + + auto make_vector = [](auto...v){ + auto vec = vector(); + (vec.push_back(v),...); + return vec; + }; + + auto list = array{ + make_vector(2,4), + make_vector(3), + make_vector(3), + make_vector(4), + make_vector(5), + make_vector(6), + make_vector(7), + make_vector(9), + make_vector(9), + make_vector() + }; + + auto result = utility::in_degree(list); + + auto expected = make_vector(0,0,1,2,2,1,1,1,0,2); + + NMTOOLS_ASSERT_EQUAL( result, expected ); +} + +// matmul +TEST_CASE("topological_sort(case1)" * doctest::test_suite("ct_digraph")) +{ + using vector = utl::static_vector; + + auto make_vector = [](auto...v){ + auto vec = vector(); + (vec.push_back(v),...); + return vec; + }; + auto list = array{ + make_vector(6), + make_vector(2), + make_vector(3), + make_vector(4), + make_vector(5), + make_vector(), + make_vector(4), + }; + // auto id_map = array{0,1,769,447,722,635,765}; + + auto result = utility::topological_sort(list); + + auto expected = vector{0,1,6,2,3,4,5}; + + NMTOOLS_ASSERT_EQUAL( result, expected ); +} + +// softmax +TEST_CASE("topological_sort(case2)" * doctest::test_suite("ct_digraph")) +{ + using vector = utl::static_vector; + + auto make_vector = [](auto...v){ + auto vec = vector(); + (vec.push_back(v),...); + return vec; + }; + + auto adj_list = array{ + make_vector(1,4), + make_vector(4), + make_vector(3), + make_vector(), + make_vector(2,3) + }; + + auto result = utility::topological_sort(adj_list); + + auto expected = vector{0,1,4,2,3}; + + NMTOOLS_ASSERT_EQUAL( result, expected ); +} + +// var +TEST_CASE("topological_sort(case3)" * doctest::test_suite("ct_digraph")) +{ + using vector = utl::static_vector; + + auto make_vector = [](auto...v){ + auto vec = vector(); + (vec.push_back(v),...); + return vec; + }; + + auto adj_list = array{ + make_vector(2,4), + make_vector(3), + make_vector(3), + make_vector(4), + make_vector(8), + make_vector(7), + make_vector(7), + make_vector(), + make_vector(5) + }; + + auto result = utility::topological_sort(adj_list); + + auto expected = vector{0,1,6,2,3,4,8,5,7}; + + NMTOOLS_ASSERT_EQUAL( result, expected ); +} + +TEST_CASE("topological_sort(case1a)" * doctest::test_suite("ct_digraph")) +{ + auto lhs_shape = array{3,4}; + auto rhs_shape = array{4,3}; + + auto lhs = na::reshape(na::arange(ix::product(lhs_shape)),lhs_shape); + auto rhs = na::reshape(na::arange(ix::product(rhs_shape)),rhs_shape); + + auto graph = utility::ct_digraph() + .add_node(0_ct,&lhs) + .add_node(1_ct,&rhs) + .add_node(769_ct,fn::transpose[/*axes=*/array{1,0}]) + .add_node(447_ct,fn::reshape[/*dst_shape=*/array{1,3,4}]) + .add_node(722_ct,fn::multiply) + .add_node(635_ct,fn::sum[/*axis=*/-1]) + .add_node(765_ct,fn::reshape[/*dst_shape=*/array{3,3,4}] * fn::tile[/*reps=*/array{1,3}]) + .add_edge(0_ct,765_ct) + .add_edge(1_ct,769_ct) + .add_edge(769_ct,447_ct) + .add_edge(447_ct,722_ct) + .add_edge(722_ct,635_ct) + .add_edge(765_ct,722_ct) + ; + + auto result = utility::topological_sort(graph); + + // auto expected = array{0,1,6,2,3,4,5}; + auto expected = array{0,1,765,769,447,722,635}; + + NMTOOLS_ASSERT_EQUAL( result, expected ); +} + +// softmax +TEST_CASE("topological_sort(case2a)" * doctest::test_suite("ct_digraph")) +{ + auto input_shape = array{3,4}; + auto input = na::reshape(na::arange(ix::product(input_shape)),input_shape); + + auto fused = fn::subtract * fn::exp; + + auto graph = utility::ct_digraph() + .add_node(0_ct,&input) + .add_node(263_ct,fn::reduce_maximum[/*axis=*/0]) + .add_node(407_ct,fn::sum[/*axis=*/0]) + .add_node(850_ct,fn::divide) + .add_node(111_ct,fused) + .add_edges(0_ct,tuple{263_ct,111_ct}) + .add_edge(263_ct,111_ct) + .add_edge(407_ct,850_ct) + .add_edges(111_ct,tuple{407_ct,850_ct}) + ; + + auto result = utility::topological_sort(graph); + + // auto expected = array{0,1,4,2,3}; + auto expected = array{0,263,111,407,850}; + + NMTOOLS_ASSERT_EQUAL( result, expected ); +} + +// var +TEST_CASE("topological_sort(case3a)" * doctest::test_suite("ct_digraph")) +{ + auto input_shape = array{3,4}; + auto input = na::reshape(na::arange(ix::product(input_shape)),input_shape); + + auto fused = fn::square * fn::fabs; + + auto graph = utility::ct_digraph() + .add_node(0_ct,&input) + .add_node(472_ct,3) + .add_node(470_ct,fn::reduce_add[/*axis=*/0]) + .add_node(51_ct,fn::divide) + .add_node(428_ct,fn::subtract) + .add_node(433_ct,fn::reduce_add[/*axis=*/0]) + .add_node(435_ct,3) + .add_node(1022_ct,fn::divide) + .add_node(391_ct,fused) + .add_edges(0_ct,tuple{470_ct,428_ct}) + .add_edge(472_ct,51_ct) + .add_edge(470_ct,51_ct) + .add_edge(51_ct,428_ct) + .add_edge(428_ct,391_ct) + .add_edge(433_ct,1022_ct) + .add_edge(435_ct,1022_ct) + .add_edge(391_ct,433_ct) + ; + + auto result = utility::topological_sort(graph); + + // auto expected = vector{0,1,6,2,3,4,8,5,7}; + auto expected = array{0,472,435,470,51,428,391,433,1022}; + + NMTOOLS_ASSERT_EQUAL( result, expected ); } \ No newline at end of file diff --git a/tests/utl/utl/CMakeLists.txt b/tests/utl/utl/CMakeLists.txt index a7bce4669..b0669c3c2 100644 --- a/tests/utl/utl/CMakeLists.txt +++ b/tests/utl/utl/CMakeLists.txt @@ -10,6 +10,7 @@ add_executable(${PROJECT_NAME} tests.cpp src/either.cpp src/vector.cpp src/clipped_integer.cpp + src/static_queue.cpp src/static_vector.cpp src/tuplev2.cpp ) diff --git a/tests/utl/utl/src/static_queue.cpp b/tests/utl/utl/src/static_queue.cpp new file mode 100644 index 000000000..a981078ae --- /dev/null +++ b/tests/utl/utl/src/static_queue.cpp @@ -0,0 +1,38 @@ +#include "nmtools/utl/static_queue.hpp" +#include "nmtools/testing/doctest.hpp" + +namespace utl = nmtools::utl; + +TEST_CASE("static_queue(case1)" * doctest::test_suite("utl")) +{ + auto q = utl::static_queue(); + + NMTOOLS_ASSERT_EQUAL( q.empty(), true ); +} + +TEST_CASE("static_queue(case2)" * doctest::test_suite("utl")) +{ + auto q = utl::static_queue(); + + q.push(1); + + NMTOOLS_ASSERT_EQUAL( q.empty(), false ); + NMTOOLS_ASSERT_EQUAL( q.size(), 1 ); + + auto v = q.pop(); + + NMTOOLS_ASSERT_EQUAL( q.empty(), true ); + NMTOOLS_ASSERT_EQUAL( v, 1 ); + + q.push(2); + q.push(3); + q.push(5); + + v = q.pop(); + + NMTOOLS_ASSERT_EQUAL( q.empty(), false ); + NMTOOLS_ASSERT_EQUAL( q.size(), 2 ); + NMTOOLS_ASSERT_EQUAL( v, 2 ); + NMTOOLS_ASSERT_EQUAL( q.front(), 3 ); + NMTOOLS_ASSERT_EQUAL( q.back(), 5 ); +} \ No newline at end of file