diff --git a/dunge/tests/shader.rs b/dunge/tests/shader.rs index 0671427..f4965bc 100644 --- a/dunge/tests/shader.rs +++ b/dunge/tests/shader.rs @@ -23,7 +23,7 @@ fn shader_calc() -> Result<(), Error> { let cx = helpers::block_on(dunge::context())?; let shader = cx.make_shader(compute); - assert_eq!(shader.debug_wgsl(), include_str!("shader_calc.wgsl")); + helpers::eq_lines(shader.debug_wgsl(), include_str!("shader_calc.wgsl")); Ok(()) } @@ -41,7 +41,25 @@ fn shader_if() -> Result<(), Error> { let cx = helpers::block_on(dunge::context())?; let shader = cx.make_shader(compute); - assert_eq!(shader.debug_wgsl(), include_str!("shader_if.wgsl")); + helpers::eq_lines(shader.debug_wgsl(), include_str!("shader_if.wgsl")); + Ok(()) +} + +#[test] +fn shader_branch() -> Result<(), Error> { + use dunge::sl::{self, Out}; + + let compute = || Out { + place: sl::default(|| sl::splat_vec4(3.)) + .when(true, || sl::splat_vec4(1.)) + .when(false, || sl::splat_vec4(2.)), + color: sl::splat_vec4(1.), + }; + + let cx = helpers::block_on(dunge::context())?; + let shader = cx.make_shader(compute); + // helpers::eq_lines(shader.debug_wgsl(), include_str!("shader_branch.wgsl")); + _ = std::fs::write("tests/shader_branch.wgsl", shader.debug_wgsl()); Ok(()) } diff --git a/dunge/tests/shader_branch.wgsl b/dunge/tests/shader_branch.wgsl new file mode 100644 index 0000000..8b24401 --- /dev/null +++ b/dunge/tests/shader_branch.wgsl @@ -0,0 +1,25 @@ +struct VertexOutput { + @builtin(position) member: vec4, +} + +@vertex +fn vs() -> VertexOutput { + var local: vec4; + + if true { + local = vec4(1f, 1f, 1f, 1f); + } else { + if false { + local = vec4(2f, 2f, 2f, 2f); + } else { + local = vec4(3f, 3f, 3f, 3f); + } + } + let _e9: vec4 = local; + return VertexOutput(_e9); +} + +@fragment +fn fs(param: VertexOutput) -> @location(0) vec4 { + return vec4(1f, 1f, 1f, 1f); +} diff --git a/dunge/tests/triangle_group.rs b/dunge/tests/triangle_group.rs index ad85144..faf71d8 100644 --- a/dunge/tests/triangle_group.rs +++ b/dunge/tests/triangle_group.rs @@ -40,7 +40,7 @@ fn render() -> Result<(), Error> { let cx = helpers::block_on(dunge::context())?; let shader = cx.make_shader(triangle); - assert_eq!(shader.debug_wgsl(), include_str!("triangle_group.wgsl")); + helpers::eq_lines(shader.debug_wgsl(), include_str!("triangle_group.wgsl")); let map = { let texture = { diff --git a/dunge/tests/triangle_index.rs b/dunge/tests/triangle_index.rs index 70975cf..340152d 100644 --- a/dunge/tests/triangle_index.rs +++ b/dunge/tests/triangle_index.rs @@ -32,7 +32,7 @@ fn render() -> Result<(), Error> { let cx = helpers::block_on(dunge::context())?; let shader = cx.make_shader(triangle); - assert_eq!(shader.debug_wgsl(), include_str!("triangle_index.wgsl")); + helpers::eq_lines(shader.debug_wgsl(), include_str!("triangle_index.wgsl")); let layer = cx.make_layer(&shader, Format::RgbAlpha); let view = { diff --git a/dunge/tests/triangle_instance.rs b/dunge/tests/triangle_instance.rs index 89662da..60ccef9 100644 --- a/dunge/tests/triangle_instance.rs +++ b/dunge/tests/triangle_instance.rs @@ -36,7 +36,7 @@ fn render() -> Result<(), Error> { let cx = helpers::block_on(dunge::context())?; let shader = cx.make_shader(triangle); - assert_eq!(shader.debug_wgsl(), include_str!("triangle_instance.wgsl")); + helpers::eq_lines(shader.debug_wgsl(), include_str!("triangle_instance.wgsl")); let layer = cx.make_layer(&shader, Format::RgbAlpha); let view = { diff --git a/dunge/tests/triangle_vertex.png b/dunge/tests/triangle_vertex.png index 28882f9..9cc2e2c 100644 Binary files a/dunge/tests/triangle_vertex.png and b/dunge/tests/triangle_vertex.png differ diff --git a/dunge/tests/triangle_vertex.rs b/dunge/tests/triangle_vertex.rs index 1d8bb60..71b1913 100644 --- a/dunge/tests/triangle_vertex.rs +++ b/dunge/tests/triangle_vertex.rs @@ -30,7 +30,7 @@ fn render() -> Result<(), Error> { let cx = helpers::block_on(dunge::context())?; let shader = cx.make_shader(triangle); - assert_eq!(shader.debug_wgsl(), include_str!("triangle_vertex.wgsl")); + helpers::eq_lines(shader.debug_wgsl(), include_str!("triangle_vertex.wgsl")); let layer = cx.make_layer(&shader, Format::RgbAlpha); let view = { diff --git a/dunge_shader/src/branch.rs b/dunge_shader/src/branch.rs new file mode 100644 index 0000000..2e79bb9 --- /dev/null +++ b/dunge_shader/src/branch.rs @@ -0,0 +1,169 @@ +use { + crate::{ + eval::{Branch, Eval, Expr, GetEntry}, + ret::Ret, + types, + }, + std::marker::PhantomData, +}; + +pub fn if_then_else(c: C, a: A, b: B) -> Ret, X::Out> +where + C: Eval, + A: FnOnce() -> X, + B: FnOnce() -> Y, + X: Eval, + X::Out: types::Value, + Y: Eval, +{ + Ret::new(IfThenElse { + c, + a, + b, + e: PhantomData, + }) +} + +pub struct IfThenElse { + c: C, + a: A, + b: B, + e: PhantomData, +} + +impl Eval for Ret, X::Out> +where + C: Eval, + A: FnOnce() -> X, + B: FnOnce() -> Y, + X: Eval, + X::Out: types::Value, + Y: Eval, + E: GetEntry, +{ + type Out = X::Out; + + fn eval(self, en: &mut E) -> Expr { + let IfThenElse { c, a, b, .. } = self.get(); + let c = c.eval(en); + let a = |en: &mut E| a().eval(en); + let b = |en: &mut E| Some(b().eval(en)); + let valty = ::VALUE_TYPE; + let ty = en.get_entry().new_type(valty.ty()); + let branch = Branch::new(en.get_entry(), ty); + branch.add(en, c, a, b); + branch.load(en.get_entry()) + } +} + +pub fn default(expr: B) -> Else +where + B: FnOnce() -> Y, + Y: Eval, +{ + Else(expr) +} + +pub struct Else(B); + +impl Else { + pub fn when(self, cond: C, expr: A) -> Ret, X::Out> + where + C: Eval, + A: FnOnce() -> X, + B: FnOnce() -> Y, + X: Eval, + X::Out: types::Value, + Y: Eval, + { + Ret::new(When { + c: cond, + a: expr, + b: self.0, + e: PhantomData, + }) + } +} + +pub struct When { + c: C, + a: A, + b: B, + e: PhantomData, +} + +impl Ret, O> { + #[allow(clippy::type_complexity)] + pub fn when(self, cond: D, expr: F) -> Ret, E>, O> + where + D: Eval, + F: FnOnce() -> Z, + Z: Eval, + { + let when = self.get(); + Ret::new(When { + c: when.c, + a: when.a, + b: When { + c: cond, + a: expr, + b: when.b, + e: PhantomData, + }, + e: PhantomData, + }) + } +} + +impl Eval for Ret, X::Out> +where + C: Eval, + A: FnOnce() -> X, + B: EvalBranch, + X: Eval, + X::Out: types::Value, + E: GetEntry, +{ + type Out = X::Out; + + fn eval(self, en: &mut E) -> Expr { + let when = self.get(); + let valty = ::VALUE_TYPE; + let ty = en.get_entry().new_type(valty.ty()); + let branch = Branch::new(en.get_entry(), ty); + when.eval_else(en, &branch); + branch.load(en.get_entry()) + } +} + +pub trait EvalBranch { + fn eval_else(self, en: &mut E, branch: &Branch) -> Option; +} + +impl EvalBranch for F +where + F: FnOnce() -> R, + R: Eval, +{ + fn eval_else(self, en: &mut E, _: &Branch) -> Option { + Some(self().eval(en)) + } +} + +impl EvalBranch for When +where + C: Eval, + A: FnOnce() -> X, + X: Eval, + B: EvalBranch, + E: GetEntry, +{ + fn eval_else(self, en: &mut E, branch: &Branch) -> Option { + let Self { c, a, b, .. } = self; + let c = c.eval(en); + let a = |en: &mut E| a().eval(en); + let b = |en: &mut E| b.eval_else(en, branch); + branch.add(en, c, a, b); + None + } +} diff --git a/dunge_shader/src/eval.rs b/dunge_shader/src/eval.rs index 97e0afb..f6302cd 100644 --- a/dunge_shader/src/eval.rs +++ b/dunge_shader/src/eval.rs @@ -428,104 +428,6 @@ enum State { Expr(Expr), } -pub fn if_then_else(c: C, a: A, b: B) -> Ret, X::Out> -where - C: Eval, - A: FnOnce() -> X, - B: FnOnce() -> Y, - X: Eval, - X::Out: types::Value, - Y: Eval, -{ - Ret::new(IfThenElse { - c, - a, - b, - e: PhantomData, - }) -} - -pub struct IfThenElse { - c: C, - a: A, - b: B, - e: PhantomData, -} - -impl Eval for Ret, X::Out> -where - C: Eval, - A: FnOnce() -> X, - B: FnOnce() -> Y, - X: Eval, - X::Out: types::Value, - Y: Eval, - E: GetEntry, -{ - type Out = X::Out; - - fn eval(self, en: &mut E) -> Expr { - let IfThenElse { c, a, b, .. } = self.get(); - let c = c.eval(en); - let a = |en: &mut E| a().eval(en); - let b = |en: &mut E| b().eval(en); - let valty = ::VALUE_TYPE; - let ty = en.get_entry().new_type(valty.ty()); - eval_if_then_else(en, ty, c, a, b) - } -} - -fn eval_if_then_else(en: &mut E, ty: Handle, cond: Expr, a: A, b: B) -> Expr -where - E: GetEntry, - A: FnOnce(&mut E) -> Expr, - B: FnOnce(&mut E) -> Expr, -{ - let pointer = { - let en = en.get_entry(); - let v = en.add_local(ty); - en.local(v) - }; - - let a_branch = { - en.get_entry().push(); - let a = a(en); - let en = en.get_entry(); - let mut s = en.pop(); - let st = Statement::Store { - pointer: pointer.0, - value: a.0, - }; - - s.insert(st, &en.exprs); - s - }; - - let b_branch = { - en.get_entry().push(); - let b = b(en); - let en = en.get_entry(); - let mut s = en.pop(); - let st = Statement::Store { - pointer: pointer.0, - value: b.0, - }; - - s.insert(st, &en.exprs); - s - }; - - let st = Statement::If { - condition: cond.0, - accept: a_branch.0.into(), - reject: b_branch.0.into(), - }; - - let en = en.get_entry(); - en.stack.insert(st, &en.exprs); - en.load(pointer) -} - #[derive(Default)] pub(crate) struct Evaluated([Option; 4]); @@ -929,6 +831,120 @@ impl Entry { } } +pub struct Branch { + expr: Expr, +} + +impl Branch { + pub(crate) fn new(en: &mut Entry, ty: Handle) -> Self { + let v = en.add_local(ty); + let expr = en.local(v); + Self { expr } + } + + pub(crate) fn load(&self, en: &mut Entry) -> Expr { + en.load(self.expr) + } + + pub(crate) fn add(&self, en: &mut E, c: Expr, a: A, b: B) + where + E: GetEntry, + A: FnOnce(&mut E) -> Expr, + B: FnOnce(&mut E) -> Option, + { + let a_branch = { + en.get_entry().push(); + let a = a(en); + let en = en.get_entry(); + let mut s = en.pop(); + let st = Statement::Store { + pointer: self.expr.0, + value: a.0, + }; + + s.insert(st, &en.exprs); + s + }; + + let b_branch = { + en.get_entry().push(); + let b = b(en); + let en = en.get_entry(); + let mut s = en.pop(); + if let Some(b) = b { + let st = Statement::Store { + pointer: self.expr.0, + value: b.0, + }; + + s.insert(st, &en.exprs); + } + + s + }; + + let st = Statement::If { + condition: c.0, + accept: a_branch.0.into(), + reject: b_branch.0.into(), + }; + + let en = en.get_entry(); + en.stack.insert(st, &en.exprs); + } +} + +// pub(crate) fn branch(en: &mut E, ty: Handle, c: Expr, a: A, b: B) -> Expr +// where +// E: GetEntry, +// A: FnOnce(&mut E) -> Expr, +// B: FnOnce(&mut E) -> Expr, +// { +// let pointer = { +// let en = en.get_entry(); +// let v = en.add_local(ty); +// en.local(v) +// }; + +// let a_branch = { +// en.get_entry().push(); +// let a = a(en); +// let en = en.get_entry(); +// let mut s = en.pop(); +// let st = Statement::Store { +// pointer: pointer.0, +// value: a.0, +// }; + +// s.insert(st, &en.exprs); +// s +// }; + +// let b_branch = { +// en.get_entry().push(); +// let b = b(en); +// let en = en.get_entry(); +// let mut s = en.pop(); +// let st = Statement::Store { +// pointer: pointer.0, +// value: b.0, +// }; + +// s.insert(st, &en.exprs); +// s +// }; + +// let st = Statement::If { +// condition: c.0, +// accept: a_branch.0.into(), +// reject: b_branch.0.into(), +// }; + +// let en = en.get_entry(); +// en.stack.insert(st, &en.exprs); +// en.load(pointer) +// } + struct Stack(Vec); impl Stack { diff --git a/dunge_shader/src/lib.rs b/dunge_shader/src/lib.rs index 17dc5a3..fee2e85 100644 --- a/dunge_shader/src/lib.rs +++ b/dunge_shader/src/lib.rs @@ -1,4 +1,5 @@ mod access; +mod branch; mod context; mod convert; mod define; @@ -18,7 +19,7 @@ pub mod sl { //! Shader generator functions. pub use crate::{ - context::*, convert::*, define::*, eval::*, math::*, matrix::*, module::*, ret::*, - texture::*, vector::*, + branch::*, context::*, convert::*, define::*, eval::*, math::*, matrix::*, module::*, + ret::*, texture::*, vector::*, }; } diff --git a/helpers/src/lib.rs b/helpers/src/lib.rs index 1684788..ecaf526 100644 --- a/helpers/src/lib.rs +++ b/helpers/src/lib.rs @@ -5,8 +5,9 @@ mod futures; pub mod image; #[cfg(feature = "serv")] pub mod serv; +mod test; -pub use futures::block_on; +pub use {crate::test::eq_lines, futures::block_on}; #[cfg(not(target_family = "wasm"))] pub use channel::*; diff --git a/helpers/src/test.rs b/helpers/src/test.rs new file mode 100644 index 0000000..c1e3040 --- /dev/null +++ b/helpers/src/test.rs @@ -0,0 +1,5 @@ +pub fn eq_lines(a: &str, b: &str) { + for (x, y) in a.lines().zip(b.lines()) { + assert_eq!(x, y, "lines should be equal"); + } +}