Skip to content

Commit

Permalink
add: ast with inferred types
Browse files Browse the repository at this point in the history
  • Loading branch information
momeemt committed Aug 15, 2024
1 parent 85eb490 commit 5f27c45
Show file tree
Hide file tree
Showing 9 changed files with 296 additions and 210 deletions.
2 changes: 1 addition & 1 deletion app/app.ml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ let () =
close_in in_channel;
let tokens = tokenize source_code in
let ast = parse tokens in
let (te, _, _, _) = tinf ast in
let te, _, _, _, _ = tinf ast in
let wat = codegen ast te in
let out_channel = open_out wat_file in
output_string out_channel wat;
Expand Down
119 changes: 70 additions & 49 deletions compiler/ast.ml
Original file line number Diff line number Diff line change
@@ -1,60 +1,81 @@
open Types

type ast =
| Let of string * string list * ast * ast
| LetRec of string * string list * ast * ast
| Fun of string * ast
| App of string * ast list
| Sequence of ast list
| IntLit of int
| FloatLit of float
| StringLit of string
| BoolLit of bool
| List of ast list
| If of ast * ast * ast
| Eq of ast * ast
| Less of ast * ast
| Greater of ast * ast
| Plus of ast * ast
| Minus of ast * ast
| Times of ast * ast
| Div of ast * ast
| Cons of ast * ast
| Append of ast * ast
| Let of ty * string * string list * ast * ast
| LetRec of ty * string * string list * ast * ast
| App of ty * string * ast list
| Fun of ty * string * ast
| Sequence of ty * ast list
| IntLit of ty * int
| FloatLit of ty * float
| StringLit of ty * string
| BoolLit of ty * bool
| List of ty * ast list
| If of ty * ast * ast * ast
| Eq of ty * ast * ast
| Less of ty * ast * ast
| Greater of ty * ast * ast
| Plus of ty * ast * ast
| Minus of ty * ast * ast
| Times of ty * ast * ast
| Div of ty * ast * ast
| Cons of ty * ast * ast
| Append of ty * ast * ast

let rec string_of_ast ast =
match ast with
| Let (id, args, e1, e2) ->
| Let (ty, id, args, e1, e2) ->
"Let (" ^ id ^ ", [" ^ String.concat "; " args ^ "], " ^ string_of_ast e1
^ ", " ^ string_of_ast e2 ^ ")"
| LetRec (f, args, e1, e2) ->
^ ", " ^ string_of_ast e2 ^ ") : " ^ string_of_ty ty
| LetRec (ty, f, args, e1, e2) ->
"LetRec (" ^ f ^ ", [" ^ String.concat "; " args ^ "], "
^ string_of_ast e1 ^ ", " ^ string_of_ast e2 ^ ")"
| Fun (id, body) -> "Fun (" ^ id ^ ", " ^ string_of_ast body ^ ")"
| App (name, exprs) ->
^ string_of_ast e1 ^ ", " ^ string_of_ast e2 ^ ") : " ^ string_of_ty ty
| App (ty, name, exprs) ->
"App (" ^ name ^ ": "
^ (List.map (fun expr -> string_of_ast expr) exprs |> String.concat " ")
^ ")"
| Sequence exprs ->
^ ") : " ^ string_of_ty ty
| Fun (ty, arg, body) ->
"Fun (" ^ arg ^ ", " ^ string_of_ast body ^ ") : " ^ string_of_ty ty
| Sequence (ty, exprs) ->
"Sequence ("
^ (List.map (fun expr -> string_of_ast expr) exprs |> String.concat ";")
^ ")"
| IntLit n -> "IntLit (" ^ string_of_int n ^ ")"
| FloatLit f -> "FloatLit (" ^ string_of_float f ^ ")"
| StringLit s -> "StringLit(" ^ s ^ ")"
| BoolLit b -> "BoolLit (" ^ string_of_bool b ^ ")"
| List l -> "List (" ^ (List.map string_of_ast l |> String.concat "; ") ^ ")"
| If (e1, e2, e3) ->
^ ") : " ^ string_of_ty ty
| IntLit (ty, n) -> "IntLit (" ^ string_of_int n ^ ") : " ^ string_of_ty ty
| FloatLit (ty, f) ->
"FloatLit (" ^ string_of_float f ^ ") : " ^ string_of_ty ty
| StringLit (ty, s) -> "StringLit(" ^ s ^ ") : " ^ string_of_ty ty
| BoolLit (ty, b) -> "BoolLit (" ^ string_of_bool b ^ ") : " ^ string_of_ty ty
| List (ty, l) ->
"List ("
^ (List.map string_of_ast l |> String.concat "; ")
^ ") : " ^ string_of_ty ty
| If (ty, e1, e2, e3) ->
"If (" ^ string_of_ast e1 ^ ", " ^ string_of_ast e2 ^ ", "
^ string_of_ast e3 ^ ")"
| Eq (e1, e2) -> "Eq (" ^ string_of_ast e1 ^ ", " ^ string_of_ast e2 ^ ")"
| Less (e1, e2) -> "Less (" ^ string_of_ast e1 ^ ", " ^ string_of_ast e2 ^ ")"
| Greater (e1, e2) ->
"Greater (" ^ string_of_ast e1 ^ ", " ^ string_of_ast e2 ^ ")"
| Plus (e1, e2) -> "Plus (" ^ string_of_ast e1 ^ ", " ^ string_of_ast e2 ^ ")"
| Minus (e1, e2) ->
"Minus (" ^ string_of_ast e1 ^ ", " ^ string_of_ast e2 ^ ")"
| Times (e1, e2) ->
"Times (" ^ string_of_ast e1 ^ ", " ^ string_of_ast e2 ^ ")"
| Div (e1, e2) -> "Div (" ^ string_of_ast e1 ^ ", " ^ string_of_ast e2 ^ ")"
| Cons (e1, e2) -> "Cons (" ^ string_of_ast e1 ^ ", " ^ string_of_ast e2 ^ ")"
| Append (e1, e2) ->
"Append (" ^ string_of_ast e1 ^ ", " ^ string_of_ast e2 ^ ")"
^ string_of_ast e3 ^ ") : " ^ string_of_ty ty
| Eq (ty, e1, e2) ->
"Eq (" ^ string_of_ast e1 ^ ", " ^ string_of_ast e2 ^ ") : "
^ string_of_ty ty
| Less (ty, e1, e2) ->
"Less (" ^ string_of_ast e1 ^ ", " ^ string_of_ast e2 ^ ") : "
^ string_of_ty ty
| Greater (ty, e1, e2) ->
"Greater (" ^ string_of_ast e1 ^ ", " ^ string_of_ast e2 ^ ") : "
^ string_of_ty ty
| Plus (ty, e1, e2) ->
"Plus (" ^ string_of_ast e1 ^ ", " ^ string_of_ast e2 ^ ") : "
^ string_of_ty ty
| Minus (ty, e1, e2) ->
"Minus (" ^ string_of_ast e1 ^ ", " ^ string_of_ast e2 ^ ") : "
^ string_of_ty ty
| Times (ty, e1, e2) ->
"Times (" ^ string_of_ast e1 ^ ", " ^ string_of_ast e2 ^ ") : "
^ string_of_ty ty
| Div (ty, e1, e2) ->
"Div (" ^ string_of_ast e1 ^ ", " ^ string_of_ast e2 ^ ") : "
^ string_of_ty ty
| Cons (ty, e1, e2) ->
"Cons (" ^ string_of_ast e1 ^ ", " ^ string_of_ast e2 ^ ") : "
^ string_of_ty ty
| Append (ty, e1, e2) ->
"Append (" ^ string_of_ast e1 ^ ", " ^ string_of_ast e2 ^ ") : "
^ string_of_ty ty
43 changes: 26 additions & 17 deletions compiler/codegen.ml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
open Ast
open Builtin
open Inferer
open Types
open Runtime.Instructions
open Runtime.Modules
open Runtime.Wasi
Expand Down Expand Up @@ -64,10 +65,10 @@ let codegen ast te =
aux func_name funcs (Env.add name (rand_name, Func) env) body addr
in
match expr with
| IntLit n ->
| IntLit (_, n) ->
let func = Funcs.find func_name funcs in
(Funcs.add func_name { func with body = [ I32Const n ] } funcs, addr)
| StringLit s ->
| StringLit (_, s) ->
let func = Funcs.find func_name funcs in
let start_addr = addr in
let new_func_body, addr =
Expand All @@ -80,7 +81,7 @@ let codegen ast te =
in
let new_func_body = new_func_body @ [ I32Const start_addr ] in
(Funcs.add func_name { func with body = new_func_body } funcs, addr)
| List lst ->
| List (_, lst) ->
let func = Funcs.find func_name funcs in
let funcs, lst_instrs, end_addr =
List.fold_left
Expand All @@ -104,7 +105,7 @@ let codegen ast te =
let head_addr = if List.length lst = 0 then -1 else addr in
let new_func_body = lst_instrs @ [ I32Const head_addr ] in
(Funcs.add func_name { func with body = new_func_body } funcs, end_addr)
| Cons (cons, lst) ->
| Cons (_, cons, lst) ->
let func = Funcs.find func_name funcs in
let lst_funcs, addr = aux func_name funcs env lst addr in
let lst_expr_instr = (Funcs.find func_name lst_funcs).body in
Expand All @@ -124,7 +125,7 @@ let codegen ast te =
in
( Funcs.add func_name { func with body = new_func_body } funcs,
next_addr + 4 )
| Append (lst1, lst2) ->
| Append (_, lst1, lst2) ->
let func = Funcs.find func_name funcs in
let lst1_funcs, lst2_addr = aux func_name funcs env lst1 addr in
let lst1_expr_instr = (Funcs.find func_name lst1_funcs).body in
Expand Down Expand Up @@ -155,7 +156,7 @@ let codegen ast te =
in
( Funcs.add func_name { func with body = new_func_body } lst2_funcs,
lst_result_addr )
| App (name, args) ->
| App (_, name, args) ->
let func = Funcs.find func_name funcs in
let funcs, args_instrs, end_addr =
List.fold_left
Expand All @@ -176,7 +177,7 @@ let codegen ast te =
| Arg -> args_instrs @ [ LocalGet wat_name ]
in
(Funcs.add func_name { func with body = new_func_body } funcs, end_addr)
| Sequence exprs ->
| Sequence (_, exprs) ->
let func = Funcs.find func_name funcs in
let funcs, exprs_instrs, end_addr =
List.fold_left
Expand All @@ -187,18 +188,26 @@ let codegen ast te =
(funcs, [], addr) exprs
in
(Funcs.add func_name { func with body = exprs_instrs } funcs, end_addr)
| If (cond, then_, else_) -> aux_if cond then_ else_ addr
| Let (name, params, value, body) ->
| If (_, cond, then_, else_) -> aux_if cond then_ else_ addr
| Let (_, name, params, value, body) ->
aux_let name params value body false addr
| LetRec (name, params, value, body) ->
| LetRec (_, name, params, value, body) ->
aux_let name params value body true addr
| Plus (left, right) -> aux_binops left right I32Add addr
| Minus (left, right) -> aux_binops left right I32Sub addr
| Times (left, right) -> aux_binops left right I32Mul addr
| Div (left, right) -> aux_binops left right I32DivS addr
| Eq (left, right) -> aux_binops left right I32Eq addr
| Greater (left, right) -> aux_binops left right I32GtU addr
| Less (left, right) -> aux_binops left right I32LtU addr
| Plus (_, left, right) -> aux_binops left right I32Add addr
| Minus (_, left, right) -> aux_binops left right I32Sub addr
| Times (_, left, right) -> aux_binops left right I32Mul addr
| Div (_, left, right) -> aux_binops left right I32DivS addr
| Eq (_, left, right) ->
let left_funcs, addr = aux func_name funcs env left addr in
let left = (Funcs.find func_name left_funcs).body in
let right_funcs, addr = aux func_name left_funcs env right addr in
let func = Funcs.find func_name right_funcs in
let right = func.body in
let new_func_body = left @ right @ [ I32Eq ] in
let new_func = { func with body = new_func_body } in
(Funcs.add func_name new_func right_funcs, addr)
| Greater (_, left, right) -> aux_binops left right I32GtU addr
| Less (_, left, right) -> aux_binops left right I32LtU addr
| rest ->
raise
(CodegenError
Expand Down
2 changes: 1 addition & 1 deletion compiler/dune
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
(library
(name compiler)
(public_name wascaml.compiler)
(modules ast builtin codegen inferer parser tokenizer tokens)
(modules ast builtin codegen inferer parser tokenizer tokens types)
(libraries wascaml.runtime wascaml.sanitizer))
Loading

0 comments on commit 5f27c45

Please sign in to comment.