Skip to content

Commit

Permalink
add: type inference
Browse files Browse the repository at this point in the history
  • Loading branch information
momeemt committed Aug 13, 2024
1 parent 9696825 commit 7343d83
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 22 deletions.
4 changes: 2 additions & 2 deletions app/app.ml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
open Compiler.Tokenizer
open Compiler.Parser
open Compiler.Codegen
(* open Compiler.Inferer *)
open Compiler.Inferer

let () =
if Array.length Sys.argv <> 2 then
Expand All @@ -17,7 +17,7 @@ let () =
close_in in_channel;
let tokens = tokenize source_code in
let ast = parse tokens in
(* let _ = tinf ast in *)
let _ = tinf ast in
let wat = codegen ast in
let out_channel = open_out wat_file in
output_string out_channel wat;
Expand Down
68 changes: 51 additions & 17 deletions compiler/inferer.ml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ let rec string_of_tyenv tenv =
let rec string_of_tysubst (tsubst : (tyvar * ty) list) =
match tsubst with
| (tyv, t) :: rest ->
(Printf.sprintf "%s :: %s, " tyv (string_of_ty t)) ^ string_of_tysubst rest
Printf.sprintf "%s :: %s, " tyv (string_of_ty t) ^ string_of_tysubst rest
| [] -> ""

let theta0 = ([] : tysubst)
Expand Down Expand Up @@ -125,12 +125,31 @@ let tinf e =
(te, TList t, theta0, new_n)
| List (h :: tl) ->
let te1, t1, theta1, n1 = aux te h n in
let te2, t2, theta2, n2 = aux te1 (List tl) n1 in
let te2, types, theta2, n2 =
List.fold_left
(fun (te_acc, t_acc, theta_acc, n_acc) e ->
let te_next, t_next, theta_next, n_next = aux te_acc e n_acc in
let t_next' = subst_ty theta_next t_next in
let theta_acc' = compose_subst theta_next theta_acc in
( subst_tyenv theta_acc' te_next,
t_acc @ [ t_next' ],
theta_acc',
n_next ))
(te1, [], theta0, n1) tl in
let _ = List.iter (fun t -> ignore (unify [ (t1, t) ])) types in
let t11 = subst_ty theta2 t1 in
let theta3 = unify [ (t11, t2) ] in
let theta3 = unify [ (t11, t1) ] in
let te3 = subst_tyenv theta3 te2 in
let theta4 = compose_subst theta3 (compose_subst theta2 theta1) in
(te3, TList t2, theta4, n2)
(te3, TList t11, theta4, n2)
| Cons (h, tl) ->
  (* Type a cons cell [h :: tl]: infer the head, then infer the tail in the
     environment produced by the head, and finally require the tail to be a
     list whose element type is the head's type. *)
  let env_head, ty_head, subst_head, n_head = aux te h n in
  let env_tail, ty_tail, subst_tail, n_tail = aux env_head tl n_head in
  (* Refresh the head type with whatever the tail inference learned. *)
  let elem_ty = subst_ty subst_tail ty_head in
  let list_ty = TList elem_ty in
  let subst_cons = unify [ (ty_tail, list_ty) ] in
  let env_out = subst_tyenv subst_cons env_tail in
  (* Compose so the most recently produced substitution is the outermost. *)
  let subst_out =
    compose_subst subst_cons (compose_subst subst_tail subst_head)
  in
  (env_out, list_ty, subst_out, n_tail)
| Plus (e1, e2) | Minus (e1, e2) | Times (e1, e2) | Div (e1, e2) ->
let te1, t1, theta1, n1 = aux te e1 n in
let te2, t2, theta2, n2 = aux te1 e2 n1 in
Expand Down Expand Up @@ -173,13 +192,15 @@ let tinf e =
te params param_types
in
let te2, t_value, theta1, n2 = aux te1 value n1 in
let te3 = subst_tyenv theta1 te2 in
let te4, t_body, theta2, n3 = aux ((name, t_value) :: te3) body n2 in
let theta3 = unify [ (t_body, TUnit) ] in
let te5 = subst_tyenv theta3 te4 in
let theta4 = compose_subst theta3 (compose_subst theta2 theta1) in
print_string (string_of_tyenv te5);
(te5, TUnit, theta4, n3)
let t_func =
List.fold_right
(fun t_arg t_acc -> TArrow (t_arg, t_acc))
param_types t_value
in
let te3 = (name, t_func) :: subst_tyenv theta1 te2 in
let te4, t_body, theta2, n3 = aux te3 body n2 in
let theta3 = compose_subst theta2 theta1 in
(te4, t_body, theta3, n3)
| LetRec (name, params, value, body) ->
let param_types, n1 =
List.fold_right
Expand All @@ -201,16 +222,13 @@ let tinf e =
te1 params param_types
in
let te3, t_value, theta1, n3 = aux te2 value n2 in
let theta_func = unify [ (t_func, t_value) ] in
let theta_func = unify [ (t_ret, t_value) ] in
let te4 = subst_tyenv theta_func te3 in
let theta2 = compose_subst theta_func theta1 in
let te5, t_body, theta3, n4 = aux te4 body n3 in
let theta4 = unify [ (t_body, TUnit) ] in
let te6 = subst_tyenv theta4 te5 in
let theta_final = compose_subst theta4 (compose_subst theta3 theta2) in
(te6, TUnit, theta_final, n4)
let theta4 = compose_subst theta3 theta2 in
(te5, t_body, theta4, n4)
| App (func_name, args) ->
(* broken - App should have (Fun, params), not (String, params) *)
let t_func =
try List.assoc func_name te
with Not_found -> failwith ("Unknown function: " ^ func_name)
Expand All @@ -237,13 +255,29 @@ let tinf e =
let te_final = subst_tyenv theta_func te' in
let theta_final = compose_subst theta_func theta' in
(te_final, subst_ty theta_final t_ret, theta_final, n2)
| Sequence exprs ->
  (* Type a sequence of expressions from left to right.

     - An empty sequence has type [TUnit] and the empty substitution.
     - A single expression is typed as that expression.
     - Otherwise the leading expression is required to have type [TUnit],
       and the type of the sequence is the type of its last expression.

     The substitution produced by the [TUnit] constraint is applied to the
     environment AND composed into the returned substitution, so callers
     see every constraint discovered while typing the sequence. *)
  let rec aux_sequence te exprs n =
    match exprs with
    | [] -> (te, TUnit, theta0, n)
    | [ e ] -> aux te e n
    | e :: rest ->
      let te1, t1, theta1, n1 = aux te e n in
      (* Every non-final expression must be of type unit. *)
      let theta_unit = unify [ (t1, TUnit) ] in
      let te2 = subst_tyenv theta_unit te1 in
      (* Keep the unit constraint: it used to be dropped from the result. *)
      let theta_head = compose_subst theta_unit theta1 in
      let te3, t2, theta_rest, n2 = aux_sequence te2 rest n1 in
      (te3, t2, compose_subst theta_rest theta_head, n2)
  in
  aux_sequence te exprs n
| rest -> failwith ("not implemented: " ^ string_of_ast rest)
in
let t1, t2, t3, t4 =
aux
[
("print_int32", TArrow (TInt, TUnit));
("print_string", TArrow (TString, TUnit));
("print_list", TArrow (TList (TVar "'list"), TUnit));
("list_length", TArrow (TList (TVar "'list"), TInt));
("discard", TArrow (TVar "'discard", TUnit));
]
e 1
in
Expand Down
7 changes: 4 additions & 3 deletions test/compiler/e2e.ml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
open Compiler.Codegen
open Compiler.Parser
open Compiler.Tokenizer
open Compiler.Inferer

let rec find_project_root current_dir =
let dune_project_path = Filename.concat current_dir "dune-project" in
Expand All @@ -15,7 +16,7 @@ let exec_code code test_name =
let tokens = tokenize code in
let ast = parse tokens in
let wat = codegen ast in
(* let _ = tinf ast in *)
let _ = tinf ast in
let filename =
find_project_root (Sys.getcwd ())
^ "/test/compiler/tmp/" ^ test_name ^ ".wat"
Expand Down Expand Up @@ -116,14 +117,14 @@ let () =
print_int32 (3 * 8)"
"4, 30, 24\n";
] );
("list_1", [ test_case_str "list_1" "print_list [1 2 3]" "[1, 2, 3]\n" ]);
("list_1", [ test_case_str "list_1" "print_list [1 2 3 4 5]" "[1, 2, 3, 4, 5]\n" ]);
( "list_length_1",
[ test_case "list_length_1" "print_int32 (list_length [10 20 30])" 3 ]
);
( "list_length_2",
[ test_case "list_length_2" "print_int32 (list_length [])" 0 ] );
( "list_cons_1",
[ test_case_str "list_cons_1" "print_list (1 :: [2 3])" "[1, 2, 3]\n" ] );
[ test_case_str "list_cons_1" "print_list (1 :: [2 3 4 5])" "[1, 2, 3, 4, 5]\n" ] );
( "list_cons_2",
[ test_case_str "list_cons_2" "print_list (1 :: 2 :: [3])" "[1, 2, 3]\n" ]
);
Expand Down

0 comments on commit 7343d83

Please sign in to comment.