Skip to content

Commit e1e7bf5

Browse files
authored
Add kron, vecdot, and tensordot (#304)
* add kron * add max_value & min_value metafunctions * add tests * add vecdot * add tensordot * move out index::contains form expand_dims, add range index function * add is_clipped_index metafunction * update tests * skip kron compile-time shape inference when on gcc * update compiler-notes * fix for gcc werror * fix for gcc werror
1 parent 64d1bb5 commit e1e7bf5

31 files changed

+7546
-21
lines changed

docs/compiler-notes.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,4 +101,7 @@ Documenting various note on behaviour difference between clang & gcc (or with so
101101

102102
1. clang vs gcc disagree on capturing constexpr value in lambda expression
103103
clang ok, gcc not ok
104-
https://godbolt.org/z/a1o8P9957
104+
https://godbolt.org/z/a1o8P9957
105+
106+
1. gcc `for` loop becomes goto in constexpr context (and breaks), works fine on clang
107+
https://github.com/alifahrri/nmtools/issues/303

include/nmtools/array/array/kron.hpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#ifndef NMTOOLS_ARRAY_ARRAY_KRON_HPP
2+
#define NMTOOLS_ARRAY_ARRAY_KRON_HPP
3+
4+
#include "nmtools/array/view/kron.hpp"
5+
#include "nmtools/array/eval.hpp"
6+
7+
namespace nmtools::array
8+
{
9+
template <typename output_t=none_t, typename context_t=none_t, typename resolver_t=eval_result_t<>
10+
, typename lhs_t, typename rhs_t>
11+
constexpr auto kron(const lhs_t& lhs, const rhs_t& rhs
12+
, context_t&& context=context_t{}, output_t&& output=output_t{}, meta::as_value<resolver_t> resolver=meta::as_value_v<resolver_t>)
13+
{
14+
auto a = view::kron(lhs,rhs);
15+
return eval(
16+
a
17+
, nmtools::forward<context_t>(context)
18+
, nmtools::forward<output_t>(output)
19+
, resolver
20+
);
21+
} // kron
22+
}
23+
24+
#endif // NMTOOLS_ARRAY_ARRAY_KRON_HPP
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#ifndef NMTOOLS_ARRAY_ARRAY_TENSORDOT_HPP
2+
#define NMTOOLS_ARRAY_ARRAY_TENSORDOT_HPP
3+
4+
#include "nmtools/array/view/tensordot.hpp"
5+
#include "nmtools/array/eval.hpp"
6+
7+
namespace nmtools::array
8+
{
9+
template <typename output_t=none_t, typename context_t=none_t, typename resolver_t=eval_result_t<>
10+
, typename lhs_t, typename rhs_t, typename axes_t=meta::ct<2>>
11+
constexpr auto tensordot(const lhs_t& lhs, const rhs_t& rhs, axes_t axes=axes_t{}
12+
, context_t&& context=context_t{}, output_t&& output=output_t{},meta::as_value<resolver_t> resolver=meta::as_value_v<resolver_t>)
13+
{
14+
auto a = view::tensordot(lhs,rhs,axes);
15+
return eval(
16+
a
17+
, nmtools::forward<context_t>(context)
18+
, nmtools::forward<output_t>(output)
19+
, resolver
20+
);
21+
} // tensordot
22+
} // nmtools::array
23+
24+
#endif // NMTOOLS_ARRAY_ARRAY_TENSORDOT_HPP
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#ifndef NMTOOLS_ARRAY_ARRAY_VECDOT_HPP
2+
#define NMTOOLS_ARRAY_ARRAY_VECDOT_HPP
3+
4+
#include "nmtools/array/view/vecdot.hpp"
5+
#include "nmtools/array/eval.hpp"
6+
7+
namespace nmtools::array
8+
{
9+
template <typename output_t=none_t, typename context_t=none_t, typename resolver_t=eval_result_t<>
10+
, typename lhs_t, typename rhs_t, typename dtype_t=none_t, typename keepdims_t=meta::false_type>
11+
constexpr auto vecdot(const lhs_t& lhs, const rhs_t& rhs, dtype_t dtype=dtype_t{}, keepdims_t keepdims=keepdims_t{}
12+
, context_t&& context=context_t{}, output_t&& output=output_t{},meta::as_value<resolver_t> resolver=meta::as_value_v<resolver_t>)
13+
{
14+
auto a = view::vecdot(lhs,rhs,dtype,keepdims);
15+
return eval(
16+
a
17+
, nmtools::forward<context_t>(context)
18+
, nmtools::forward<output_t>(output)
19+
, resolver
20+
);
21+
} // vecdot
22+
} // nmtools::array
23+
24+
#endif // NMTOOLS_ARRAY_ARRAY_VECDOT_HPP
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#ifndef NMTOOLS_ARRAY_INDEX_CONTAINS_HPP
2+
#define NMTOOLS_ARRAY_INDEX_CONTAINS_HPP
3+
4+
#include "nmtools/meta.hpp"
5+
#include "nmtools/array/shape.hpp"
6+
#include "nmtools/utils/isequal.hpp"
7+
8+
namespace nmtools::index
9+
{
10+
template <typename array_t, typename value_t>
11+
constexpr auto contains(const array_t& array, const value_t& value)
12+
{
13+
for (nm_size_t i=0; i<(nm_size_t)len(array); i++) {
14+
if (utils::isequal(at(array,i),value)) {
15+
return true;
16+
}
17+
}
18+
return false;
19+
} // contains
20+
} // nmtools::index
21+
22+
#endif // NMTOOLS_ARRAY_INDEX_CONTAINS_HPP

include/nmtools/array/index/expand_dims.hpp

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "nmtools/array/utility/at.hpp"
77
#include "nmtools/utils/isequal.hpp"
88
#include "nmtools/array/ndarray/hybrid.hpp"
9+
#include "nmtools/array/index/contains.hpp"
910
#include "nmtools/array/index/normalize_axis.hpp"
1011
#include "nmtools/utility/unwrap.hpp"
1112

@@ -22,26 +23,6 @@ namespace nmtools::index
2223
*/
2324
struct shape_expand_dims_t {};
2425

25-
// TODO: remove
26-
template <typename array_t, typename value_t>
27-
constexpr auto contains(const array_t& array, const value_t& value)
28-
{
29-
if constexpr (meta::is_fixed_index_array_v<array_t>) {
30-
bool contain = false;
31-
meta::template_for<meta::len_v<array_t>>([&](auto i){
32-
if (utils::isequal(at(array,i),value))
33-
contain = true;
34-
});
35-
return contain;
36-
}
37-
else {
38-
for (size_t i=0; i<len(array); i++)
39-
if (utils::isequal(at(array,i),value))
40-
return true;
41-
return false;
42-
}
43-
} // contains
44-
4526
/**
4627
* @brief extend the shape with value 1 for each given axis
4728
*

include/nmtools/array/index/range.hpp

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
#ifndef NMTOOLS_ARRAY_INDEX_RANGE_HPP
2+
#define NMTOOLS_ARRAY_INDEX_RANGE_HPP
3+
4+
#include "nmtools/meta.hpp"
5+
#include "nmtools/array/shape.hpp"
6+
7+
namespace nmtools::index
8+
{
9+
struct range_t {};
10+
11+
template <typename start_t, typename stop_t, typename step_t=meta::ct<1>>
12+
constexpr auto range([[maybe_unused]] start_t start
13+
, [[maybe_unused]] stop_t stop
14+
, [[maybe_unused]] step_t step=step_t{}
15+
) {
16+
using result_t = meta::resolve_optype_t<range_t,start_t,stop_t,step_t>;
17+
18+
auto result = result_t {};
19+
20+
if constexpr (!meta::is_fail_v<result_t>
21+
&& !meta::is_constant_index_array_v<result_t>
22+
) {
23+
auto n = (stop - start) / step;
24+
if constexpr (meta::is_resizable_v<result_t>) {
25+
result.resize(n);
26+
}
27+
28+
for (nm_size_t i=0; i<(nm_size_t)n; i++) {
29+
at(result,i) = i * step;
30+
}
31+
}
32+
33+
return result;
34+
} // range
35+
36+
template <typename stop_t>
37+
constexpr auto range(stop_t stop)
38+
{
39+
return range(meta::ct_v<0>,stop,meta::ct_v<1>);
40+
}
41+
} // nmtools::index
42+
43+
namespace nmtools::meta
44+
{
45+
namespace error
46+
{
47+
template <typename...>
48+
struct RANGE_UNSUPPORTED : detail::fail_t {};
49+
}
50+
51+
template <typename start_t, typename stop_t, typename step_t>
52+
struct resolve_optype<
53+
void, index::range_t, start_t, stop_t, step_t
54+
> {
55+
static constexpr auto vtype = [](){
56+
if constexpr (
57+
!is_index_v<start_t>
58+
|| !is_index_v<stop_t>
59+
|| !is_index_v<step_t>
60+
) {
61+
using type = error::RANGE_UNSUPPORTED<start_t,stop_t,step_t>;
62+
return as_value_v<type>;
63+
} else if constexpr (is_constant_index_v<start_t>
64+
&& is_constant_index_v<stop_t>
65+
&& is_constant_index_v<step_t>
66+
) {
67+
constexpr auto start = to_value_v<start_t>;
68+
constexpr auto stop = to_value_v<stop_t>;
69+
constexpr auto step = to_value_v<step_t>;
70+
constexpr auto start_cl = clipped_int64_t<int64_t(start > 0 ? start : 1)>(start);
71+
constexpr auto stop_cl = clipped_int64_t<(int64_t)stop>(stop);
72+
constexpr auto step_cl = clipped_int64_t<(int64_t)step>(step);
73+
constexpr auto result = index::range(start_cl,stop_cl,step_cl);
74+
using nmtools::at, nmtools::len;
75+
return template_reduce<len(result)>([&](auto init, auto I){
76+
using init_t = type_t<decltype(init)>;
77+
using type = append_type_t<init_t,ct<at(result,I)>>;
78+
return as_value_v<type>;
79+
}, as_value_v<nmtools_tuple<>>);
80+
} else {
81+
constexpr auto max_dim = max_value_v<stop_t>;
82+
if constexpr (!is_fail_v<decltype(max_dim)>) {
83+
using type = nmtools_static_vector<nm_size_t,max_dim>;
84+
return as_value_v<type>;
85+
} else {
86+
// TODO: small vector optimization
87+
using type = nmtools_list<nm_size_t>;
88+
return as_value_v<type>;
89+
}
90+
}
91+
}();
92+
using type = type_t<decltype(vtype)>;
93+
}; // index::range_t
94+
} // nmtools::meta
95+
96+
#endif // NMTOOLS_ARRAY_INDEX_RANGE_HPP

0 commit comments

Comments
 (0)