Skip to content

Commit 41ece2f

Browse files
authored
Add more sycl tests (#262)
* add cuda & sycl dockerfile & devcontainer * fix for sycl werror compilation * add where test with integer condition * add more sycl tests
1 parent 61d19b1 commit 41ece2f

22 files changed

+3313
-4
lines changed

.devcontainer/devcontainer.json

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
// "dockerFile": "../docker/eoan.dockerfile",
1717

1818
// DOCKERFILE: ubuntu jammy (22.04)
19-
"name": "ubuntu jammy - gcc11, clang14",
20-
"dockerFile": "../docker/jammy.dockerfile",
19+
// "name": "ubuntu jammy - gcc11, clang14",
20+
// "dockerFile": "../docker/jammy.dockerfile",
2121

2222
// DOCKERFILE: ubuntu jammy (22.04)
2323
// "name": "ubuntu lunar",
@@ -32,13 +32,17 @@
3232
// "dockerFile": "../docker/circle.dockerfile",
3333

3434
// DOCKERFILE: cuda
35-
// "name": "cuda-12 ubuntu20.04",
35+
// "name": "nvidia/cuda:11.8.0-devel-ubuntu22.04",
3636
// "dockerFile": "../docker/cuda.dockerfile",
3737

3838
// DOCKERFILE: sycl
3939
// "name": "sycl-clang14 ubuntu22.04",
4040
// "dockerFile": "../docker/sycl.dockerfile",
4141

42+
// DOCKERFILE: cuda-sycl
43+
"name": "sycl-clang14 ubuntu22.04 with cuda toolchain",
44+
"dockerFile": "../docker/cuda-sycl.dockerfile",
45+
4246
"build": {
4347
"args": { "USERNAME": "${localEnv:USER}" },
4448
"target": "dev"

docker/cuda-sycl.dockerfile

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
ARG CUDA_BASE=nvidia/cuda:11.8.0-devel-ubuntu22.04
2+
from ${CUDA_BASE} as dev
3+
4+
ARG DEBIAN_FRONTEND=noninteractive
5+
ENV TZ=Asia
6+
7+
run apt update && apt install -y \
8+
build-essential cmake clang \
9+
curl git-core gnupg locales \
10+
zsh wget fonts-powerline \
11+
&& locale-gen en_US.UTF-8
12+
13+
ENV DEBIAN_FRONTEND=dialog
14+
15+
# generate locale for agnoster
16+
RUN echo "en_US.UTF-8 UTF-8" > /etc/locale.gen && /usr/sbin/locale-gen
17+
18+
ENV TERM xterm
19+
20+
# Set the default shell to bash rather than sh
21+
ENV SHELL /bin/zsh
22+
23+
# run the installation script
24+
RUN wget https://github.com/robbyrussell/oh-my-zsh/raw/master/tools/install.sh -O - | zsh || true
25+
26+
# install powerlevel10k
27+
RUN git clone https://github.com/romkatv/powerlevel10k.git ~/.oh-my-zsh/custom/themes/powerlevel10k
28+
29+
RUN cd $HOME && curl -fsSLO https://raw.githubusercontent.com/romkatv/dotfiles-public/master/.purepower
30+
31+
ADD .devcontainer/.zshrc $HOME
32+
33+
from dev as build
34+
35+
WORKDIR /workspace/nmtools
36+
37+
COPY cmake cmake
38+
COPY scripts scripts
39+
COPY include include
40+
COPY tests tests
41+
COPY CMakeLists.txt CMakeLists.txt
42+
COPY nmtools.pc.in nmtools.pc.in
43+
COPY nmtoolsConfig.cmake.in nmtoolsConfig.cmake.in
44+
45+
## install doctest
46+
COPY scripts/install_doctest.sh scripts/install_doctest.sh
47+
RUN bash scripts/install_doctest.sh
48+
49+
RUN apt install -y libclang-dev clang-tools libomp-dev llvm-dev lld libboost-dev libboost-fiber-dev libboost-context-dev
50+
RUN bash scripts/install_opensycl.sh
51+
52+
ARG toolchain=sycl-clang14-omp
53+
RUN mkdir -p build && cd build \
54+
&& cmake -DCMAKE_TOOLCHAIN_FILE=cmake/toolchains/${toolchain}.cmake \
55+
-DNMTOOLS_BUILD_META_TESTS=OFF -DNMTOOLS_BUILD_UTL_TESTS=OFF -DNMTOOLS_TEST_ALL=OFF \
56+
-DNMTOOLS_BUILD_SYCL_TESTS=ON \
57+
../ \
58+
&& make -j2 VERBOSE=1
59+
60+
CMD ["/workspace/nmtools/build/tests/sycl/numeric-tests-sycl-doctest"]

include/nmtools/array/eval.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ namespace nmtools::array
305305
* @return constexpr auto
306306
*/
307307
template <typename output_t, typename context_t, typename resolver_t, typename views_t, template<auto...>typename index_sequence, auto...Is>
308-
constexpr auto apply_eval(const views_t& views, context_t&& context, output_t&& output, resolver_t resolver, index_sequence<Is...>)
308+
constexpr auto apply_eval(const views_t& views, context_t&& context, output_t&& output, [[maybe_unused]] resolver_t resolver, index_sequence<Is...>)
309309
{
310310
return nmtools_tuple{array::eval(nmtools::get<Is>(views)
311311
, nmtools::forward<context_t>(context)

include/nmtools/testing/data/array/where.hpp

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,83 @@ NMTOOLS_TESTING_DECLARE_CASE(view, where)
8686
}
8787
};
8888
}
89+
90+
NMTOOLS_TESTING_DECLARE_ARGS(case1b)
91+
{
92+
inline int8_t condition[10] = {1, 1, 1, 1, 1, 0, 0, 0, 0, 0};
93+
inline int8_t x[10] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
94+
inline int8_t y[10] = {0,10,20,30,40,50,60,70,80,90};
95+
NMTOOLS_CAST_ARRAYS(condition)
96+
NMTOOLS_CAST_ARRAYS(x)
97+
NMTOOLS_CAST_ARRAYS(y)
98+
}
99+
NMTOOLS_TESTING_DECLARE_EXPECT(case1b)
100+
{
101+
inline int8_t dim = 1;
102+
inline int8_t shape[1] = {10};
103+
inline int8_t result[10] = {0, 1, 2, 3, 4, 50, 60, 70, 80, 90};
104+
}
105+
106+
NMTOOLS_TESTING_DECLARE_ARGS(case2b)
107+
{
108+
inline int8_t condition[3][4] = {
109+
{0, 1, 1, 1},
110+
{0, 0, 1, 1},
111+
{0, 0, 0, 1},
112+
};
113+
inline int8_t x[3][1] = {
114+
{0},
115+
{1},
116+
{2},
117+
};
118+
inline int8_t y[1][4] = {{10,11,12,13}};
119+
NMTOOLS_CAST_ARRAYS(condition)
120+
NMTOOLS_CAST_ARRAYS(x)
121+
NMTOOLS_CAST_ARRAYS(y)
122+
}
123+
NMTOOLS_TESTING_DECLARE_EXPECT(case2b)
124+
{
125+
inline int8_t dim = 2;
126+
inline int8_t shape[2] = {3,4};
127+
inline int8_t result[3][4] = {
128+
{10, 0, 0, 0},
129+
{10, 11, 1, 1},
130+
{10, 11, 12, 2},
131+
};
132+
}
133+
134+
135+
NMTOOLS_TESTING_DECLARE_ARGS(case3b)
136+
{
137+
inline int8_t condition[3][4] = {
138+
{0, 1, 1, 1},
139+
{0, 0, 1, 1},
140+
{0, 0, 0, 1},
141+
};
142+
inline int8_t x[1][3][1] = {
143+
{
144+
{0},
145+
{1},
146+
{2},
147+
}
148+
};
149+
inline int8_t y[1][4] = {{10,11,12,13}};
150+
NMTOOLS_CAST_ARRAYS(condition)
151+
NMTOOLS_CAST_ARRAYS(x)
152+
NMTOOLS_CAST_ARRAYS(y)
153+
}
154+
NMTOOLS_TESTING_DECLARE_EXPECT(case3b)
155+
{
156+
inline int8_t dim = 3;
157+
inline int8_t shape[3] = {1,3,4};
158+
inline int8_t result[1][3][4] = {
159+
{
160+
{10, 0, 0, 0},
161+
{10, 11, 1, 1},
162+
{10, 11, 12, 2},
163+
}
164+
};
165+
}
89166
#endif
90167
}
91168

tests/sycl/CMakeLists.txt

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,23 @@ set(NMTOOLS_SYCL_TEST_SOURCES ${NMTOOLS_SYCL_TEST_SOURCES}
5353
array/outer/multiply.cpp
5454

5555
array/resize.cpp
56+
array/atleast_1d.cpp
57+
array/atleast_2d.cpp
58+
array/atleast_3d.cpp
59+
array/broadcast_to.cpp
60+
array/concatenate.cpp
61+
array/expand_dims.cpp
62+
array/flatten.cpp
63+
array/flip.cpp
64+
array/pad.cpp
65+
array/pooling.cpp
66+
array/repeat.cpp
67+
array/reshape.cpp
68+
array/slice.cpp
69+
array/squeeze.cpp
70+
array/tile.cpp
71+
array/transpose.cpp
72+
array/where.cpp
5673
)
5774

5875
## TODO: support nvcc compilation

tests/sycl/array/atleast_1d.cpp

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
#define NMTOOLS_CAST_ARRAYS_EXTRA(name) \
2+
inline auto name##_cs_fb = nmtools::cast(name, nmtools::array::kind::ndarray_cs_fb); \
3+
inline auto name##_cs_hb = nmtools::cast(name, nmtools::array::kind::ndarray_cs_hb); \
4+
inline auto name##_cs_db = nmtools::cast(name, nmtools::array::kind::ndarray_cs_db); \
5+
inline auto name##_fs_fb = nmtools::cast(name, nmtools::array::kind::ndarray_fs_fb); \
6+
inline auto name##_fs_hb = nmtools::cast(name, nmtools::array::kind::ndarray_fs_hb); \
7+
inline auto name##_fs_db = nmtools::cast(name, nmtools::array::kind::ndarray_fs_db); \
8+
inline auto name##_hs_fb = nmtools::cast(name, nmtools::array::kind::ndarray_hs_fb); \
9+
inline auto name##_hs_hb = nmtools::cast(name, nmtools::array::kind::ndarray_hs_hb); \
10+
inline auto name##_hs_db = nmtools::cast(name, nmtools::array::kind::ndarray_hs_db); \
11+
inline auto name##_ds_fb = nmtools::cast(name, nmtools::array::kind::ndarray_ds_fb); \
12+
inline auto name##_ds_hb = nmtools::cast(name, nmtools::array::kind::ndarray_ds_hb); \
13+
inline auto name##_ds_db = nmtools::cast(name, nmtools::array::kind::ndarray_ds_db); \
14+
inline auto name##_ls_fb = nmtools::cast(name, nmtools::array::kind::ndarray_ls_fb); \
15+
inline auto name##_ls_hb = nmtools::cast(name, nmtools::array::kind::ndarray_ls_hb); \
16+
inline auto name##_ls_db = nmtools::cast(name, nmtools::array::kind::ndarray_ls_db);
17+
18+
#include "nmtools/array/array/atleast_1d.hpp"
19+
#include "nmtools/testing/doctest.hpp"
20+
#include "nmtools/array/eval/sycl.hpp"
21+
#include "nmtools/testing/data/array/atleast_1d.hpp"
22+
23+
namespace nm = nmtools;
24+
namespace na = nm::array;
25+
26+
#define ATLEAST_1D_SUBCASE(case_name, ...) \
27+
SUBCASE(#case_name) \
28+
{ \
29+
NMTOOLS_TESTING_DECLARE_NS(array, atleast_1d, case_name); \
30+
using namespace args; \
31+
auto result = na::atleast_1d(__VA_ARGS__, na::sycl::default_context()); \
32+
auto expect = na::atleast_1d(__VA_ARGS__); \
33+
NMTOOLS_ASSERT_EQUAL( nm::shape(result), nm::shape(expect) ); \
34+
NMTOOLS_ASSERT_CLOSE( result, expect ); \
35+
}
36+
37+
// TEST_CASE("atleast_1d(case1)" * doctest::test_suite("array::atleast_1d"))
38+
// {
39+
// ATLEAST_1D_SUBCASE(case1, a);
40+
// }
41+
42+
TEST_CASE("atleast_1d(case2)" * doctest::test_suite("array::atleast_1d"))
43+
{
44+
// ATLEAST_1D_SUBCASE(case2, a);
45+
// ATLEAST_1D_SUBCASE(case2, a_a);
46+
// ATLEAST_1D_SUBCASE(case2, a_f);
47+
// ATLEAST_1D_SUBCASE(case2, a_h);
48+
// ATLEAST_1D_SUBCASE(case2, a_d);
49+
50+
// ATLEAST_1D_SUBCASE(case2, a_cs_fb);
51+
// ATLEAST_1D_SUBCASE(case2, a_cs_hb);
52+
// ATLEAST_1D_SUBCASE(case2, a_cs_db);
53+
54+
ATLEAST_1D_SUBCASE(case2, a_fs_fb);
55+
ATLEAST_1D_SUBCASE(case2, a_fs_hb);
56+
ATLEAST_1D_SUBCASE(case2, a_fs_db);
57+
58+
ATLEAST_1D_SUBCASE(case2, a_hs_fb);
59+
ATLEAST_1D_SUBCASE(case2, a_hs_hb);
60+
ATLEAST_1D_SUBCASE(case2, a_hs_db);
61+
62+
ATLEAST_1D_SUBCASE(case2, a_ds_fb);
63+
ATLEAST_1D_SUBCASE(case2, a_ds_hb);
64+
ATLEAST_1D_SUBCASE(case2, a_ds_db);
65+
66+
// ATLEAST_1D_SUBCASE(case2, a_ls_fb);
67+
// ATLEAST_1D_SUBCASE(case2, a_ls_hb);
68+
// ATLEAST_1D_SUBCASE(case2, a_ls_db);
69+
}
70+
71+
TEST_CASE("atleast_1d(case3)" * doctest::test_suite("array::atleast_1d"))
72+
{
73+
// ATLEAST_1D_SUBCASE(case3, a);
74+
// ATLEAST_1D_SUBCASE(case3, a_a);
75+
// ATLEAST_1D_SUBCASE(case3, a_f);
76+
// ATLEAST_1D_SUBCASE(case3, a_h);
77+
// ATLEAST_1D_SUBCASE(case3, a_d);
78+
79+
// ATLEAST_1D_SUBCASE(case3, a_cs_fb);
80+
// ATLEAST_1D_SUBCASE(case3, a_cs_hb);
81+
// ATLEAST_1D_SUBCASE(case3, a_cs_db);
82+
83+
ATLEAST_1D_SUBCASE(case3, a_fs_fb);
84+
ATLEAST_1D_SUBCASE(case3, a_fs_hb);
85+
ATLEAST_1D_SUBCASE(case3, a_fs_db);
86+
87+
ATLEAST_1D_SUBCASE(case3, a_hs_fb);
88+
ATLEAST_1D_SUBCASE(case3, a_hs_hb);
89+
ATLEAST_1D_SUBCASE(case3, a_hs_db);
90+
91+
ATLEAST_1D_SUBCASE(case3, a_ds_fb);
92+
ATLEAST_1D_SUBCASE(case3, a_ds_hb);
93+
ATLEAST_1D_SUBCASE(case3, a_ds_db);
94+
95+
// ATLEAST_1D_SUBCASE(case3, a_ls_fb);
96+
// ATLEAST_1D_SUBCASE(case3, a_ls_hb);
97+
// ATLEAST_1D_SUBCASE(case3, a_ls_db);
98+
}
99+
100+
TEST_CASE("atleast_1d(case4)" * doctest::test_suite("array::atleast_1d"))
101+
{
102+
// ATLEAST_1D_SUBCASE(case4, a);
103+
// ATLEAST_1D_SUBCASE(case4, a_a);
104+
// ATLEAST_1D_SUBCASE(case4, a_f);
105+
// ATLEAST_1D_SUBCASE(case4, a_h);
106+
// ATLEAST_1D_SUBCASE(case4, a_d);
107+
108+
// ATLEAST_1D_SUBCASE(case4, a_cs_fb);
109+
// ATLEAST_1D_SUBCASE(case4, a_cs_hb);
110+
// ATLEAST_1D_SUBCASE(case4, a_cs_db);
111+
112+
ATLEAST_1D_SUBCASE(case4, a_fs_fb);
113+
ATLEAST_1D_SUBCASE(case4, a_fs_hb);
114+
ATLEAST_1D_SUBCASE(case4, a_fs_db);
115+
116+
ATLEAST_1D_SUBCASE(case4, a_hs_fb);
117+
ATLEAST_1D_SUBCASE(case4, a_hs_hb);
118+
ATLEAST_1D_SUBCASE(case4, a_hs_db);
119+
120+
ATLEAST_1D_SUBCASE(case4, a_ds_fb);
121+
ATLEAST_1D_SUBCASE(case4, a_ds_hb);
122+
ATLEAST_1D_SUBCASE(case4, a_ds_db);
123+
124+
// ATLEAST_1D_SUBCASE(case4, a_ls_fb);
125+
// ATLEAST_1D_SUBCASE(case4, a_ls_hb);
126+
// ATLEAST_1D_SUBCASE(case4, a_ls_db);
127+
}

0 commit comments

Comments
 (0)