From abdfb744723f5931add7c248ffc242a11d52b976 Mon Sep 17 00:00:00 2001
From: Henry Gressmann <mail@henrygressmann.de>
Date: Mon, 11 Dec 2023 00:08:19 +0100
Subject: [PATCH] test(tinywasm): add AssertMalformed to spec test harness

Signed-off-by: Henry Gressmann <mail@henrygressmann.de>
---
 crates/parser/src/conversion.rs |   4 +-
 crates/tinywasm/tests/mvp.rs    | 125 +++++++++++++++++++++++---------
 2 files changed, 92 insertions(+), 37 deletions(-)

diff --git a/crates/parser/src/conversion.rs b/crates/parser/src/conversion.rs
index abc401c..f15c9ae 100644
--- a/crates/parser/src/conversion.rs
+++ b/crates/parser/src/conversion.rs
@@ -31,14 +31,16 @@ pub(crate) fn convert_module_code(
 ) -> Result<CodeSection> {
     let locals_reader = func.get_locals_reader()?;
     let count = locals_reader.get_count();
+
     let mut locals = Vec::with_capacity(count as usize);
 
     for (i, local) in locals_reader.into_iter().enumerate() {
         let local = local?;
+        validator.define_locals(i, local.0, local.1)?;
+
         for _ in 0..local.0 {
             locals.push(convert_valtype(&local.1));
         }
-        validator.define_locals(i, local.0, local.1)?;
     }
 
     if locals.len() != count as usize {
diff --git a/crates/tinywasm/tests/mvp.rs b/crates/tinywasm/tests/mvp.rs
index 96bdc37..dacffe3 100644
--- a/crates/tinywasm/tests/mvp.rs
+++ b/crates/tinywasm/tests/mvp.rs
@@ -1,5 +1,5 @@
 use std::{
-    collections::HashMap,
+    collections::BTreeMap,
     fmt::{Debug, Formatter},
 };
 
@@ -21,29 +21,64 @@ fn parse_module(mut module: wast::core::Module) -> Result<TinyWasmModule, Error>
 fn test_mvp() {
     let mut test_suite = TestSuite::new();
 
-    wasm_testsuite::MVP_TESTS.iter().for_each(|name| {
-        println!("test: {}", name);
+    wasm_testsuite::MVP_TESTS.iter().for_each(|group| {
+        println!("test: {}", group);
 
-        let test_group = test_suite.test_group("mvp");
+        let test_group = test_suite.test_group(group);
 
-        let wast = wasm_testsuite::get_test_wast(name).expect("failed to get test wast");
+        let wast = wasm_testsuite::get_test_wast(group).expect("failed to get test wast");
         let wast = std::str::from_utf8(&wast).expect("failed to convert wast to utf8");
 
         let mut lexer = Lexer::new(&wast);
+        // we need to allow confusing unicode characters since they are technically valid wasm
         lexer.allow_confusing_unicode(true);
 
         let buf = ParseBuffer::new_with_lexer(lexer).expect("failed to create parse buffer");
         let wast_data = parser::parse::<Wast>(&buf).expect("failed to parse wat");
 
-        for directive in wast_data.directives {
+        for (i, directive) in wast_data.directives.into_iter().enumerate() {
             let span = directive.span();
 
             use wast::WastDirective::*;
+            let name = format!("{}-{}", group, i);
             match directive {
+                // TODO: needs to support more binary sections
                 Wat(QuoteWat::Wat(wast::Wat::Module(module))) => {
-                    let module = parse_module(module).map(|_| ());
-                    test_group.module_compiles(name, span, module);
+                    let module = std::panic::catch_unwind(|| parse_module(module));
+                    test_group.add_result(
+                        &format!("{}-parse", name),
+                        span,
+                        match module {
+                            Ok(Ok(_)) => Ok(()),
+                            Ok(Err(e)) => Err(e),
+                            Err(e) => Err(Error::Other(format!("failed to parse module: {:?}", e))),
+                        },
+                    );
                 }
+                // these all pass already :)
+                AssertMalformed {
+                    span,
+                    module: QuoteWat::Wat(wast::Wat::Module(module)),
+                    message,
+                } => {
+                    println!("  assert_malformed: {}", message);
+                    let res = std::panic::catch_unwind(|| parse_module(module).map(|_| ()));
+
+                    test_group.add_result(
+                        &format!("{}-malformed", name),
+                        span,
+                        match res {
+                            Ok(Ok(_)) => Err(Error::Other("expected module to be malformed".to_string())),
+                            Err(_) | Ok(Err(_)) => Ok(()),
+                        },
+                    );
+                }
+                // _ => test_group.add_result(
+                //     &format!("{}-unknown", name),
+                //     span,
+                //     Err(Error::Other("test not implemented".to_string())),
+                // ),
+                // TODO: implement more test directives
                 _ => {}
             }
         }
@@ -51,18 +86,20 @@ fn test_mvp() {
 
     if test_suite.failed() {
         panic!("failed one or more tests: {:#?}", test_suite);
+    } else {
+        println!("passed all tests: {:#?}", test_suite);
     }
 }
 
-struct TestSuite(HashMap<String, TestGroup>);
+struct TestSuite(BTreeMap<String, TestGroup>);
 
 impl TestSuite {
     fn new() -> Self {
-        Self(HashMap::new())
+        Self(BTreeMap::new())
     }
 
     fn failed(&self) -> bool {
-        self.0.values().any(|group| group.failed())
+        self.0.values().any(|group| group.stats().1 > 0)
     }
 
     fn test_group(&mut self, name: &str) -> &mut TestGroup {
@@ -73,49 +110,65 @@ impl TestSuite {
 impl Debug for TestSuite {
     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
         use owo_colors::OwoColorize;
-        let mut passed_count = 0;
-        let mut failed_count = 0;
+        let mut total_passed = 0;
+        let mut total_failed = 0;
 
         for (group_name, group) in &self.0 {
+            let (group_passed, group_failed) = group.stats();
+            total_passed += group_passed;
+            total_failed += group_failed;
+
             writeln!(f, "{}", group_name.bold().underline())?;
-            for (test_name, test) in &group.tests {
-                writeln!(f, "  {}", test_name.bold())?;
-                match test.result {
-                    Ok(()) => {
-                        writeln!(f, "    Result: {}", "Passed".green())?;
-                        passed_count += 1;
-                    }
-                    Err(_) => {
-                        writeln!(f, "    Result: {}", "Failed".red())?;
-                        failed_count += 1;
-                    }
-                }
-                writeln!(f, "    Span: {:?}", test.span)?;
-            }
+            writeln!(f, "  Tests Passed: {}", group_passed.to_string().green())?;
+            writeln!(f, "  Tests Failed: {}", group_failed.to_string().red())?;
+
+            // for (test_name, test) in &group.tests {
+            //     write!(f, "    {}: ", test_name.bold())?;
+            //     match &test.result {
+            //         Ok(()) => {
+            //             writeln!(f, "{}", "Passed".green())?;
+            //         }
+            //         Err(e) => {
+            //             writeln!(f, "{}", "Failed".red())?;
+            //             // writeln!(f, "Error: {:?}", e)?;
+            //         }
+            //     }
+            //     writeln!(f, "      Span: {:?}", test.span)?;
+            // }
         }
 
-        writeln!(f, "\n{}", "Test Summary:".bold().underline())?;
-        writeln!(f, "  Total Tests: {}", (passed_count + failed_count))?;
-        writeln!(f, "  Passed: {}", passed_count.to_string().green())?;
-        writeln!(f, "  Failed: {}", failed_count.to_string().red())?;
+        writeln!(f, "\n{}", "Total Test Summary:".bold().underline())?;
+        writeln!(f, "  Total Tests: {}", (total_passed + total_failed))?;
+        writeln!(f, "  Total Passed: {}", total_passed.to_string().green())?;
+        writeln!(f, "  Total Failed: {}", total_failed.to_string().red())?;
         Ok(())
     }
 }
 
 struct TestGroup {
-    tests: HashMap<String, TestCase>,
+    tests: BTreeMap<String, TestCase>,
 }
 
 impl TestGroup {
     fn new() -> Self {
-        Self { tests: HashMap::new() }
+        Self { tests: BTreeMap::new() }
     }
 
-    fn failed(&self) -> bool {
-        self.tests.values().any(|test| test.result.is_err())
+    fn stats(&self) -> (usize, usize) {
+        let mut passed_count = 0;
+        let mut failed_count = 0;
+
+        for test in self.tests.values() {
+            match test.result {
+                Ok(()) => passed_count += 1,
+                Err(_) => failed_count += 1,
+            }
+        }
+
+        (passed_count, failed_count)
     }
 
-    fn module_compiles(&mut self, name: &str, span: wast::token::Span, result: Result<()>) {
+    fn add_result(&mut self, name: &str, span: wast::token::Span, result: Result<()>) {
         self.tests.insert(name.to_string(), TestCase { result, span });
     }
 }