@@ -11,8 +11,7 @@ fn test_behavior() {
11
11
ptr:: { null, null_mut} ,
12
12
} ;
13
13
14
- const CODE : & str =
15
- r#"extern "C" __mlu_global__ void kernel() { printf("Hello from MLU!\n"); }"# ;
14
+ const CODE : & str = r#"extern "C" __mlu_entry__ void kernel() { printf("Hello from MLU!\n"); }"# ;
16
15
17
16
crate :: init ( ) ;
18
17
let Some ( dev) = crate :: Device :: fetch ( ) else {
@@ -52,7 +51,6 @@ fn test_behavior() {
52
51
bin
53
52
} ;
54
53
55
- crate :: init ( ) ;
56
54
if let Some ( dev) = crate :: Device :: fetch ( ) {
57
55
dev. context ( ) . apply ( |ctx| {
58
56
use crate :: AsRaw ;
@@ -82,3 +80,70 @@ fn test_behavior() {
82
80
} ) ;
83
81
} ;
84
82
}
83
+
84
/// Compiles a kernel that adds two `f32` buffers with `__bang_add`, launches it once, and
/// checks that every element copied back to the host equals `3.0` (`1.0 + 2.0`).
///
/// Returns early without asserting anything when `Device::fetch` yields no device.
#[test]
fn test_add() {
    use crate::memcpy_d2h;
    use std::ffi::{c_void, CString};

    const N: usize = 64;

    // `format!` substitutes `{N}`; the kernel's own braces are doubled to escape them.
    let src = format!(
        r#"
        extern "C" __mlu_entry__ void kernel(
            float * ans_,
            float const* lhs_,
            float const* rhs_
        ) {{
            __nram__ float lhs[{N}];
            __nram__ float rhs[{N}];
            __nram__ float ans[{N}];
            __memcpy(lhs, lhs_, {N} * sizeof(float), GDRAM2NRAM);
            __memcpy(rhs, rhs_, {N} * sizeof(float), GDRAM2NRAM);
            __bang_add(ans, lhs, rhs, {N});
            __memcpy(ans_, ans, {N} * sizeof(float), NRAM2GDRAM);
        }}"#
    );

    crate::init();
    let Some(dev) = crate::Device::fetch() else {
        return;
    };

    // Print whatever the compiler reported before unwrapping, so a failure is diagnosable.
    let (compiled, diagnostics) = CnrtcBinary::compile(src, dev.isa());
    if !diagnostics.is_empty() {
        eprintln!("{diagnostics}");
    }
    let bin = compiled.unwrap();

    // Host-side operands and the buffer that receives the result.
    let host_lhs = vec![1.0f32; N];
    let host_rhs = vec![2.0f32; N];
    let mut host_ans = vec![0.0f32; N];

    // The device is fetched a second time here, exactly as the test did before.
    let Some(dev) = crate::Device::fetch() else {
        return;
    };
    dev.context().apply(|ctx| {
        let mut dev_lhs = ctx.malloc::<f32>(N);
        let mut dev_rhs = ctx.malloc::<f32>(N);
        let mut dev_ans = ctx.malloc::<f32>(N);

        let queue = ctx.queue();
        queue.memcpy_h2d(&mut dev_lhs, &host_lhs);
        queue.memcpy_h2d(&mut dev_rhs, &host_rhs);

        // Each argument slot holds the address of a local that stores a device pointer,
        // listed in the order of the kernel signature: `ans_`, `lhs_`, `rhs_`.
        let lhs_arg = dev_lhs.as_ptr();
        let rhs_arg = dev_rhs.as_ptr();
        let ans_arg = dev_ans.as_mut_ptr();
        let args: [*const c_void; 3] = [
            &ans_arg as *const _ as _,
            &lhs_arg as *const _ as _,
            &rhs_arg as *const _ as _,
        ];

        let entry = CString::new("kernel").unwrap();
        ctx.load(&bin)
            .get_kernel(&entry)
            .launch(1, 1, 1, args.as_ptr() as _, &queue);

        // NOTE(review): there is no explicit queue synchronization between the launch and
        // this copy — presumably `memcpy_d2h` waits for the kernel; confirm.
        memcpy_d2h(&mut host_ans, &dev_ans);
    });
    assert_eq!(host_ans, &[3.0f32; N]);
}
0 commit comments