Skip to content

Commit

Permalink
Refactor and upgrade test suite
Browse files Browse the repository at this point in the history
  • Loading branch information
Shiritai committed Nov 7, 2024
1 parent c2d1bdd commit 9e0515b
Showing 1 changed file with 174 additions and 130 deletions.
304 changes: 174 additions & 130 deletions src/self-test.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import os
import random
import subprocess
import sys
from typing import Optional
from unittest import TextTestRunner, TestSuite, defaultTestLoader
from cfg import CFG, BasicBlock
from bril import Const, ValueOperation, parse_bril, serialize_bril
from instruction.common import ValType
from is_ssa import is_ssa
from instruction.instruction import Instruction
from instruction.value import CoreValType
Expand All @@ -17,23 +17,11 @@
from logger.test import LoggerTest
from instruction.test import InstTest

class BasicBlockTest(LoggedTestCase):
def test_eq(self):
b1 = BasicBlock('b1')
b2 = b1
self.assertEqual(b1, b2)
b3 = BasicBlock('b1')
self.assertNotEqual(b1, b3)

def test_hash(self):
b1 = BasicBlock('b1')
bb_set = { b1 }
bb_set2 = { b1 }
self.assertSetEqual(bb_set, bb_set2)

script_dir = os.path.dirname(os.path.realpath(sys.argv[0]))
example_path = os.path.realpath(f"{script_dir}/../tests/example.bril")

# -------- [Helper functions] --------

def load_program(bril_file: Optional[str] = None):
bril_file = bril_file if bril_file is not None else example_path
with open(bril_file, "r") as f:
Expand All @@ -43,10 +31,32 @@ def load_program(bril_file: Optional[str] = None):
def load_args(bril_file: Optional[str] = None):
bril_file = bril_file if bril_file is not None else example_path
with open(bril_file, "r") as f:
line = f.readline()
lines = list(map(lambda l: l.strip('\n'), f.readlines()))
line = lines[0]
possible_args = []
flags = []
flag_black_list = ("-c", "-f")
if line.startswith("# ARGS: "):
return line.strip("# ARGS: ").split()
return None
args = line.removeprefix("# ARGS: ").split()
possible_args = list(filter(lambda a: '-' not in a, args))
flags = list(filter(lambda a: '-' in a and a not in flag_black_list, args))
for line in lines:
if line.startswith("@main"):
args = line.removeprefix("@main")\
.strip().removeprefix('(').removesuffix('{')\
.strip().removesuffix(')').split(", ")
ret: str = []
for n, arg in enumerate(filter(lambda a: len(a) > 0, args)):
_, tp = arg.split(": ")
tp = ValType.find(tp)
if isinstance(tp, ValType):
if len(possible_args) > n:
ret.append(possible_args[n])
else:
ret.append(tp.random_bril_val)
if len(ret) > 0:
return [*flags, *ret]
return flags if len(flags) != 0 else None

def load_golden_program(bril_file: Optional[str]):
bril_file = bril_file if bril_file is not None else example_path
Expand All @@ -70,6 +80,148 @@ def _traverse(entry: str):
def bb2labels(s: set[BasicBlock]):
return set(bb.label for bb in s)

def compare_ssa(cfg1: CFG, cfg2: CFG):
name_map: dict[str, str] = {}
for bb1, bb2 in zip(cfg1.blocks.values(), cfg2.blocks.values()):
for i1, i2 in zip(sorted(bb1.get_by_op(SsaOpType.PHI), key=lambda i: i.dest),
sorted(bb2.get_by_op(SsaOpType.PHI), key=lambda i: i.dest)):
if isinstance(i1, (ValueOperation, Const)):
if (not isinstance(i2, (ValueOperation, Const)) or
i1.op != i2.op):
err = ValueError(f"Unequal CFGs (inst type not match): {i1} can't map to {i2}")
logger.error(err)
raise err
if i1.dest in name_map:
err = ValueError(f"CFG1 is not in SSA form: duplicated {i1.dest}")
logger.error(err)
raise err
if i2.dest in name_map.values() is not None:
err = ValueError(f"CFG2 is not in SSA form: duplicated {i2.dest}")
logger.error(err)
raise err
name_map[i1.dest] = i2.dest

for bb1, bb2 in zip(cfg1.blocks.values(), cfg2.blocks.values()):
for i1, i2 in zip(list(i for i in bb1.insts if i.op != SsaOpType.PHI),
list(i for i in bb2.insts if i.op != SsaOpType.PHI)):
if isinstance(i1, (ValueOperation, Const)) and isinstance(i2, (ValueOperation, Const)):
name_map[i1.dest] = i2.dest

for bb1, bb2 in zip(cfg1.blocks.values(), cfg2.blocks.values()):
for i1, i2 in zip(list(i for i in bb1.insts if i.op != SsaOpType.PHI),
list(i for i in bb2.insts if i.op != SsaOpType.PHI)):
if i1.op != i2.op:
err = ValueError(f"Can't map [op] {i1} to {i2}")
logger.error(err)
raise err
if hasattr(i1, 'type'):
if not hasattr(i2, 'type') or i1.type != i2.type:
err = ValueError(f"Can't map [type] {i1} to {i2}")
logger.error(err)
raise err
if hasattr(i1, 'value'):
if not hasattr(i2, 'value') or i1.value != i2.value:
err = ValueError(f"Can't map [value] {i1} to {i2}")
logger.error(err)
raise err
if hasattr(i1, 'dest'):
if not hasattr(i2, 'dest') or i1.dest not in name_map or name_map[i1.dest] != i2.dest:
err = ValueError(f"Can't map [dest] {i1} to {i2}")
logger.error(err)
raise err
if hasattr(i1, 'args') and i1.args is not None:
if not hasattr(i2, 'args') or i2.args is None:
err = ValueError(f"Can't map [args] {i1} to {i2}")
logger.error(err)
raise err
else:
for arg1, arg2 in zip(i1.args, i2.args):
if name_map[arg1] != arg2:
err = ValueError(f"Can't map [args] {i1} to {i2}")
logger.error(err)
raise err
if hasattr(i1, 'funcs'):
if not hasattr(i2, 'funcs'):
err = ValueError(f"Can't map [funcs] {i1} to {i2}")
logger.error(err)
raise err
else:
if i1.funcs != i2.funcs:
err = ValueError(f"Can't map [funcs] {i1} to {i2}")
logger.error(err)
raise err

def test_brils_ssa(path: str):
brils = find_all_bril(path)
failed = []
for bril_file in brils:
program = load_program(bril_file)

# skip all ssa bril
for f in program.functions:
to_skip = False
for i in f.instrs:
if i.op == SsaOpType.PHI:
logger.debug(f"Skip test {bril_file} since its in ssa form")
to_skip = True # No need to test this
break
if to_skip:
break
if to_skip:
continue

logger.debug(f"Test {bril_file}")
# logger.flush()
cmd = ["brili"]
args = load_args(bril_file)
if args is not None:
cmd.extend(args)

json_input = serialize_bril(program)

p = subprocess.Popen(cmd,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE)
golden = p.communicate(input=json_input.encode())

for func1 in program.functions:
construct_ssa(func1)

if not is_ssa(program):
err = ValueError(f"Program is not in ssa form")
logger.error(err)
raise err
json_output = serialize_bril(program)

p = subprocess.Popen(cmd,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE)
attempt = p.communicate(input=json_output.encode())
if golden != attempt:
err = ValueError(f"Computation comparison does not match\n\tGolden: <{golden}>\n\tAttempt: <{attempt}>")
logger.error(err)
for func in program.functions:
CFG(func).view_blocks()
failed.append(err)
if len(failed) > 0:
raise ValueError(f"Errors: {failed}")

# -------- [Test suites] --------

class BasicBlockTest(LoggedTestCase):
def test_eq(self):
b1 = BasicBlock('b1')
b2 = b1
self.assertEqual(b1, b2)
b3 = BasicBlock('b1')
self.assertNotEqual(b1, b3)

def test_hash(self):
b1 = BasicBlock('b1')
bb_set = { b1 }
bb_set2 = { b1 }
self.assertSetEqual(bb_set, bb_set2)

class CfgTest(LoggedTestCase):

def test_make_cfg(self):
Expand Down Expand Up @@ -239,81 +391,6 @@ def test_rename_runnable(self):
insert_phi_functions(dom_tree, global_d2b)
rename_variables(cfg, dom_tree, defs, global_names)

def ssa_checker(cfg1: CFG, cfg2: CFG):
name_map: dict[str, str] = {}
for bb1, bb2 in zip(cfg1.blocks.values(), cfg2.blocks.values()):
for i1, i2 in zip(sorted(bb1.get_by_op(SsaOpType.PHI), key=lambda i: i.dest),
sorted(bb2.get_by_op(SsaOpType.PHI), key=lambda i: i.dest)):
if isinstance(i1, (ValueOperation, Const)):
if (not isinstance(i2, (ValueOperation, Const)) or
i1.op != i2.op):
err = ValueError(f"Unequal CFGs (inst type not match): {i1} can't map to {i2}")
logger.error(err)
raise err
if i1.dest in name_map:
err = ValueError(f"CFG1 is not in SSA form: duplicated {i1.dest}")
logger.error(err)
raise err
if i2.dest in name_map.values() is not None:
err = ValueError(f"CFG2 is not in SSA form: duplicated {i2.dest}")
logger.error(err)
raise err
name_map[i1.dest] = i2.dest

for bb1, bb2 in zip(cfg1.blocks.values(), cfg2.blocks.values()):
for i1, i2 in zip(list(i for i in bb1.insts if i.op != SsaOpType.PHI),
list(i for i in bb2.insts if i.op != SsaOpType.PHI)):
if isinstance(i1, (ValueOperation, Const)) and isinstance(i2, (ValueOperation, Const)):
name_map[i1.dest] = i2.dest

for bb1, bb2 in zip(cfg1.blocks.values(), cfg2.blocks.values()):
for i1, i2 in zip(list(i for i in bb1.insts if i.op != SsaOpType.PHI),
list(i for i in bb2.insts if i.op != SsaOpType.PHI)):
if i1.op != i2.op:
err = ValueError(f"Can't map [op] {i1} to {i2}")
logger.error(err)
cfg1.view_blocks()
cfg2.view_blocks()
raise err
if hasattr(i1, 'type'):
if not hasattr(i2, 'type') or i1.type != i2.type:
err = ValueError(f"Can't map [type] {i1} to {i2}")
logger.error(err)
raise err
if hasattr(i1, 'value'):
if not hasattr(i2, 'value') or i1.value != i2.value:
err = ValueError(f"Can't map [value] {i1} to {i2}")
logger.error(err)
raise err
if hasattr(i1, 'dest'):
if not hasattr(i2, 'dest') or i1.dest not in name_map or name_map[i1.dest] != i2.dest:
cfg1.view_blocks()
cfg2.view_blocks()
err = ValueError(f"Can't map [dest] {i1} to {i2}")
logger.error(err)
raise err
if hasattr(i1, 'args') and i1.args is not None:
if not hasattr(i2, 'args') or i2.args is None:
err = ValueError(f"Can't map [args] {i1} to {i2}")
logger.error(err)
raise err
else:
for arg1, arg2 in zip(i1.args, i2.args):
if name_map[arg1] != arg2:
err = ValueError(f"Can't map [args] {i1} to {i2}")
logger.error(err)
raise err
if hasattr(i1, 'funcs'):
if not hasattr(i2, 'funcs'):
err = ValueError(f"Can't map [funcs] {i1} to {i2}")
logger.error(err)
raise err
else:
if i1.funcs != i2.funcs:
err = ValueError(f"Can't map [funcs] {i1} to {i2}")
logger.error(err)
raise err

class SsaCheckerTest(LoggedTestCase):
def test_example(self):
program = load_program()
Expand All @@ -337,7 +414,7 @@ def test_example(self):
insert_phi_functions(dom_tree, global_d2b)
rename_variables(cfg2, dom_tree, defs, global_names)

ssa_checker(cfg1, cfg2)
compare_ssa(cfg1, cfg2)

def rename_var(i: Instruction, a: str, b: str):
if hasattr(i, 'dest') and i.dest == a:
Expand All @@ -352,7 +429,7 @@ def rename_var(i: Instruction, a: str, b: str):
for i in bb.insts:
rename_var(i, 'a.2', 'a.meow')

ssa_checker(cfg1, cfg2)
compare_ssa(cfg1, cfg2)

def test_execute(self):
program = load_program()
Expand All @@ -372,47 +449,14 @@ def test_execute(self):
p = subprocess.Popen(["brili"], stdin=subprocess.PIPE, stdout=subprocess.PIPE)
_ = p.communicate(input=json_output.encode())

def test_brils(path: str):
brils = find_all_bril(path)
failed = []
for bril_file in brils:
logger.debug(f"Test {bril_file}")
program = load_program(bril_file)
cmd = ["brili"]
args = load_args(bril_file)
if args is not None:
cmd.extend(args)

json_input = serialize_bril(program)
p = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
golden = p.communicate(input=json_input.encode())
for func1 in program.functions:
logger.debug(f"Test function {func1.name}")
logger.flush()
construct_ssa(func1)
if not is_ssa(program):
err = ValueError(f"Program is not in ssa form")
logger.error(err)
raise err
json_output = serialize_bril(program)
p = subprocess.Popen(cmd, stdin=subprocess.PIPE, stdout=subprocess.PIPE)
attempt = p.communicate(input=json_output.encode())
if golden != attempt:
err = ValueError(f"Computation comparison does not match\n\tGolden: <{ golden[0].decode().strip()}>\n\tAttempt: <{ attempt[0].decode().strip()}>")
failed.append(err)
if len(failed) > 0:
for err in failed:
logger.error(err)
raise ValueError(f"Errors: {failed}")

class IntegrationTest(LoggedTestCase):
def test_advanced_integration(self):
advanced_tests = os.path.realpath(f"{script_dir}/../bril/examples/test")
test_brils(advanced_tests)
test_brils_ssa(advanced_tests)

def test_basic_integration(self):
basic_tests = os.path.realpath(f"{script_dir}/../tests")
test_brils(basic_tests)
test_brils_ssa(basic_tests)

class GradeTest(LoggedTestCase):
def test_grade(self):
Expand Down

0 comments on commit 9e0515b

Please sign in to comment.