Skip to content

Commit ab2b1c9

Browse files
committed
fix(cndrv): 增加一个计算类 kernel 的测试,改正 cnrtc 参数错误
Signed-off-by: YdrMaster <ydrml@hotmail.com>
1 parent 9816ec3 commit ab2b1c9

File tree

2 files changed

+69
-3
lines changed

2 files changed

+69
-3
lines changed

cndrv/src/cnrtc/binary.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ impl CnrtcBinary {
2727

2828
let options = [
2929
CString::new("-O3").unwrap(),
30+
CString::new("--bang-fatbin-only").unwrap(),
3031
CString::new(format!("--bang-mlu-arch=mtp_{isa}")).unwrap(),
3132
];
3233
let options = options.iter().map(|s| s.as_ptr()).collect::<Vec<_>>();

cndrv/src/cnrtc/mod.rs

Lines changed: 68 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@ fn test_behavior() {
1111
ptr::{null, null_mut},
1212
};
1313

14-
const CODE: &str =
15-
r#"extern "C" __mlu_global__ void kernel() { printf("Hello from MLU!\n"); }"#;
14+
const CODE: &str = r#"extern "C" __mlu_entry__ void kernel() { printf("Hello from MLU!\n"); }"#;
1615

1716
crate::init();
1817
let Some(dev) = crate::Device::fetch() else {
@@ -52,7 +51,6 @@ fn test_behavior() {
5251
bin
5352
};
5453

55-
crate::init();
5654
if let Some(dev) = crate::Device::fetch() {
5755
dev.context().apply(|ctx| {
5856
use crate::AsRaw;
@@ -82,3 +80,70 @@ fn test_behavior() {
8280
});
8381
};
8482
}
83+
84+
#[test]
85+
fn test_add() {
86+
use crate::memcpy_d2h;
87+
use std::ffi::{c_void, CString};
88+
89+
const N: usize = 64;
90+
let src = format!(
91+
r#"
92+
extern "C" __mlu_entry__ void kernel(
93+
float * ans_,
94+
float const* lhs_,
95+
float const* rhs_
96+
) {{
97+
__nram__ float lhs[{N}];
98+
__nram__ float rhs[{N}];
99+
__nram__ float ans[{N}];
100+
__memcpy(lhs, lhs_, {N} * sizeof(float), GDRAM2NRAM);
101+
__memcpy(rhs, rhs_, {N} * sizeof(float), GDRAM2NRAM);
102+
__bang_add(ans, lhs, rhs, {N});
103+
__memcpy(ans_, ans, {N} * sizeof(float), NRAM2GDRAM);
104+
}}"#
105+
);
106+
107+
crate::init();
108+
let Some(dev) = crate::Device::fetch() else {
109+
return;
110+
};
111+
112+
let (result, log) = CnrtcBinary::compile(src, dev.isa());
113+
if !log.is_empty() {
114+
eprintln!("{log}");
115+
}
116+
let bin = result.unwrap();
117+
118+
let a = vec![1.0f32; N];
119+
let b = vec![2.0f32; N];
120+
let mut c = vec![0.0f32; N];
121+
122+
if let Some(dev) = crate::Device::fetch() {
123+
dev.context().apply(|ctx| {
124+
let mut lhs = ctx.malloc::<f32>(N);
125+
let mut rhs = ctx.malloc::<f32>(N);
126+
let mut ans = ctx.malloc::<f32>(N);
127+
128+
let queue = ctx.queue();
129+
queue.memcpy_h2d(&mut lhs, &a);
130+
queue.memcpy_h2d(&mut rhs, &b);
131+
132+
let lhs_ptr = lhs.as_ptr();
133+
let rhs_ptr = rhs.as_ptr();
134+
let ans_ptr = ans.as_mut_ptr();
135+
let params: [*const c_void; 3] = [
136+
&ans_ptr as *const _ as _,
137+
&lhs_ptr as *const _ as _,
138+
&rhs_ptr as *const _ as _,
139+
];
140+
141+
ctx.load(&bin)
142+
.get_kernel(&CString::new("kernel").unwrap())
143+
.launch(1, 1, 1, params.as_ptr() as _, &queue);
144+
145+
memcpy_d2h(&mut c, &ans);
146+
});
147+
assert_eq!(c, &[3.0f32; N]);
148+
};
149+
}

0 commit comments

Comments
 (0)