Skip to content
Merged
2 changes: 1 addition & 1 deletion FSharpActivePatterns/bin/REPL.ml
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ let run_repl dump_parsetree input_file =
(match ic with
| None ->
List.iter
(fun (n, t) -> fprintf std_formatter "%s : %a" n pp_typ t)
(fun (n, t) -> fprintf std_formatter "%s : %a\n" n pp_typ t)
names_and_types;
print_flush ();
run_repl_helper run env new_state
Expand Down
7 changes: 3 additions & 4 deletions FSharpActivePatterns/lib/ast.ml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

open KeywordChecker
open TypedTree
open TypesPp

type ident = Ident of string (** identifier *) [@@deriving show { with_path = false }]

Expand Down Expand Up @@ -99,11 +100,9 @@ type pattern =
| PVar of ident (** pattern identifier *)
| POption of pattern option
(*| Variant of (ident list[@gen gen_ident_small_list]) (** | [Blue, Green, Yellow] -> *) *)
| PConstraint of pattern * (typ[@gen gen_typ_sized (n / 4)])
| PConstraint of pattern * (typ[@gen gen_typ_primitive])
[@@deriving show { with_path = false }, qcheck]

let gen_typed_pattern_sized n = QCheck.Gen.(pair (gen_pattern_sized n) (return None))

type is_recursive =
| Nonrec (** let factorial n = ... *)
| Rec (** let rec factorial n = ... *)
Expand Down Expand Up @@ -151,7 +150,7 @@ and expr =
[@gen QCheck.Gen.(list_size (0 -- 2) (gen_let_bind_sized (n / 20)))])
* expr (** [let rec f x = if (x <= 0) then x else g x and g x = f (x-2) in f 3] *)
| Option of expr option (** [int option] *)
| EConstraint of expr * (typ[@gen gen_typ_sized (n / 4)])
| EConstraint of expr * (typ[@gen gen_typ_primitive])
[@@deriving show { with_path = false }, qcheck]

and let_bind =
Expand Down
28 changes: 18 additions & 10 deletions FSharpActivePatterns/lib/astPrinter.ml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

open Format
open Ast
open TypesPp

let print_bin_op indent fmt = function
| Binary_equal -> fprintf fmt "%s| Binary Equal\n" (String.make indent '-')
Expand Down Expand Up @@ -51,13 +52,18 @@ let rec print_pattern indent fmt = function
print_pattern (indent + 2) fmt r
| PVar (Ident name) -> fprintf fmt "%s| PVar(%s)\n" (String.make indent '-') name
| POption p ->
fprintf fmt "%s| POption: " (String.make indent '-');
fprintf fmt "%s| POption " (String.make indent '-');
(match p with
| None -> fprintf fmt "None\n"
| Some p ->
fprintf fmt "Some:\n";
print_pattern (indent + 2) fmt p)
| PConstraint (p, _) -> print_pattern indent fmt p
| PConstraint (p, t) ->
fprintf fmt "%s| PConstraint\n" (String.make indent ' ');
fprintf fmt "%sPattern:\n" (String.make (indent + 2) ' ');
print_pattern (indent + 2) fmt p;
fprintf fmt "%sType:\n" (String.make (indent + 2) ' ');
fprintf fmt "%s| %a\n" (String.make (indent + 2) '-') pp_typ t
;;

let print_unary_op indent fmt = function
Expand All @@ -69,11 +75,9 @@ let rec print_let_bind indent fmt = function
| Let_bind (name, args, body) ->
fprintf fmt "%s| Let_bind:\n" (String.make indent '-');
fprintf fmt "%sNAME:\n" (String.make (indent + 4) ' ');
fprintf fmt "%s| %a\n" (String.make (indent + 4) '-') pp_pattern name;
print_pattern (indent + 4) fmt name;
fprintf fmt "%sARGS:\n" (String.make (indent + 4) ' ');
List.iter
(fun arg -> fprintf fmt "%s| %a\n" (String.make (indent + 2) '-') pp_pattern arg)
args;
List.iter (fun arg -> print_pattern (indent + 2) fmt arg) args;
fprintf fmt "%sBODY:\n" (String.make (indent + 4) ' ');
print_expr (indent + 2) fmt body

Expand Down Expand Up @@ -134,7 +138,6 @@ and print_expr indent fmt expr =
| Lambda (arg1, args, body) ->
fprintf fmt "%s| Lambda:\n" (String.make indent '-');
fprintf fmt "%sARGS\n" (String.make (indent + 2) ' ');
print_pattern (indent + 4) fmt arg1;
List.iter (fun pat -> print_pattern (indent + 4) fmt pat) (arg1 :: args);
fprintf fmt "%sBODY\n" (String.make (indent + 2) ' ');
print_expr (indent + 4) fmt body
Expand All @@ -147,11 +150,11 @@ and print_expr indent fmt expr =
| LetIn (rec_flag, let_bind, let_bind_list, inner_e) ->
fprintf
fmt
"%s | %s LetIn=\n"
"%s| %sLetIn=\n"
(String.make indent '-')
(match rec_flag with
| Nonrec -> ""
| Rec -> "Rec");
| Rec -> "Rec ");
fprintf fmt "%sLet_binds\n" (String.make (indent + 2) ' ');
List.iter (print_let_bind (indent + 2) fmt) (let_bind :: let_bind_list);
fprintf fmt "%sINNER_EXPRESSION\n" (String.make (indent + 2) ' ');
Expand All @@ -162,7 +165,12 @@ and print_expr indent fmt expr =
| Some e ->
fprintf fmt "%s| Option: Some\n" (String.make indent '-');
print_expr (indent + 2) fmt e)
| EConstraint (e, _) -> print_expr indent fmt e
| EConstraint (e, t) ->
fprintf fmt "%s| EConstraint\n" (String.make indent ' ');
fprintf fmt "%sExpr:\n" (String.make (indent + 2) ' ');
print_expr (indent + 2) fmt e;
fprintf fmt "%sType:\n" (String.make (indent + 2) ' ');
fprintf fmt "%s| %a\n" (String.make (indent + 2) '-') pp_typ t
;;

let print_statement indent fmt = function
Expand Down
144 changes: 85 additions & 59 deletions FSharpActivePatterns/lib/inferencer.ml
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,21 @@ type error =
| `Not_allowed_right_hand_side_let_rec
| `Not_allowed_left_hand_side_let_rec
| `Args_after_not_variable_let
| `Bound_several_times
]

let pp_error fmt : error -> _ = function
| `Occurs_check -> fprintf fmt "Occurs check failed"
| `Undef_var s -> fprintf fmt "Undefined variable '%s'" s
| `Unification_failed (fst, snd) ->
fprintf fmt "unification failed on %a and %a" pp_typ fst pp_typ snd
fprintf fmt "unification failed on %a and %a\n" pp_typ fst pp_typ snd
| `Not_allowed_right_hand_side_let_rec ->
fprintf fmt "This kind of expression is not allowed as right-hand side of `let rec'"
| `Not_allowed_left_hand_side_let_rec ->
fprintf fmt "Only variables are allowed as left-hand side of `let rec'"
| `Args_after_not_variable_let ->
fprintf fmt "Arguments in let allowed only after variable"
| `Bound_several_times -> fprintf fmt "Variable is bound several times"
;;

(* for treating result of type inference *)
Expand Down Expand Up @@ -357,7 +359,8 @@ end = struct
;; *)

let pp_without_freevars fmt t =
Map.iteri t ~f:(fun ~key ~data -> fprintf fmt "%s : %a" key pp_typ (Scheme.typ data))
Map.iteri t ~f:(fun ~key ~data ->
fprintf fmt "%s : %a\n" key pp_typ (Scheme.typ data))
;;

(* collect all free vars from environment *)
Expand Down Expand Up @@ -468,27 +471,40 @@ let infer_patterns env ~shadow patterns =
return (new_env, typ :: typs))
;;

let extract_names_from_pattern pat =
let rec helper = function
| PVar (Ident name) -> [ name ]
| PList l -> List.concat (List.map l ~f:helper)
| PCons (hd, tl) -> List.concat [ helper hd; helper tl ]
| PTuple (fst, snd, rest) ->
List.concat [ helper fst; helper snd; List.concat (List.map rest ~f:helper) ]
| POption (Some p) -> helper p
| PConstraint (p, _) -> helper p
| POption None -> []
| Wild -> []
| PConst _ -> []
in
helper pat
module StringSet = struct
include Stdlib.Set.Make (String)

let union_disjoint s1 s2 =
let* s1 = s1 in
let* s2 = s2 in
if is_empty (inter s1 s2) then return (union s1 s2) else fail `Bound_several_times
;;

let union_disjoint_many sets =
List.fold ~init:(return empty) ~f:(fun acc set -> union_disjoint acc set) sets
;;
end

let rec extract_names_from_pattern =
let extr = extract_names_from_pattern in
function
| PVar (Ident name) -> return (StringSet.singleton name)
| PList l -> StringSet.union_disjoint_many (List.map l ~f:extr)
| PCons (hd, tl) -> StringSet.union_disjoint (extr hd) (extr tl)
| PTuple (fst, snd, rest) ->
StringSet.union_disjoint_many (List.map ~f:extr (fst :: snd :: rest))
| POption (Some p) -> extr p
| PConstraint (p, _) -> extr p
| POption None -> return StringSet.empty
| Wild -> return StringSet.empty
| PConst _ -> return StringSet.empty
;;

let infer_match_pattern env ~shadow pattern match_type =
let* env, pat_typ = infer_pattern env ~shadow pattern in
let* subst = unify pat_typ match_type in
let env = TypeEnvironment.apply subst env in
let pat_names = extract_names_from_pattern pattern in
let* pat_names = extract_names_from_pattern pattern >>| StringSet.elements in
let generalized_schemes =
List.map pat_names ~f:(fun name ->
let typ = TypeEnvironment.find_typ_exn env name in
Expand All @@ -501,12 +517,11 @@ let infer_match_pattern env ~shadow pattern match_type =
;;

let extract_names_from_patterns pats =
List.fold pats ~init:[] ~f:(fun acc p ->
List.concat [ acc; extract_names_from_pattern p ])
StringSet.union_disjoint_many (List.map ~f:extract_names_from_pattern pats)
;;

let extract_bind_names_from_let_binds let_binds =
List.concat
StringSet.union_disjoint_many
(List.map let_binds ~f:(function Let_bind (pat, _, _) ->
extract_names_from_pattern pat))
;;
Expand All @@ -518,7 +533,7 @@ let extract_bind_patterns_from_let_binds let_binds =
let extend_env_with_bind_names env let_binds =
(* to prevent binds like let rec x = x + 1*)
let let_binds =
List.filter let_binds ~f:(function Let_bind (_, args, _) -> List.length args <> 0)
List.filter let_binds ~f:(function Let_bind (_, args, _) -> not (List.is_empty args))
in
let bind_names = extract_bind_patterns_from_let_binds let_binds in
let* env, _ = infer_patterns env ~shadow:true bind_names in
Expand Down Expand Up @@ -661,48 +676,55 @@ let rec infer_expr env = function
let* subst_final = Substitution.compose subst1 subst2 in
return (subst_final, typ)
| Function ((p1, e1), rest) ->
let* arg_type = make_fresh_var in
let* return_type = make_fresh_var in
let* subst, return_type =
List.fold
((p1, e1) :: rest)
~init:(return (Substitution.empty, return_type))
~f:(fun acc (pat, expr) ->
let* subst1, return_type = acc in
let* env, pat = infer_pattern env ~shadow:true pat in
let* subst2 = unify arg_type pat in
let env = TypeEnvironment.apply subst2 env in
let* subst3, expr_typ = infer_expr env expr in
let* subst4 = unify return_type expr_typ in
let* subst = Substitution.compose_all [ subst1; subst2; subst3; subst4 ] in
return (subst, Substitution.apply subst return_type))
in
return (subst, Arrow (Substitution.apply subst arg_type, return_type))
let* match_t = make_fresh_var in
let* return_t = make_fresh_var in
infer_matching_expr
env
((p1, e1) :: rest)
Substitution.empty
match_t
return_t
~with_arg:true
| Match (e, (p1, e1), rest) ->
let* subst_init, match_type = infer_expr env e in
let* subst_init, match_t = infer_expr env e in
let env = TypeEnvironment.apply subst_init env in
let* return_type = make_fresh_var in
let* subst, return_type =
List.fold
((p1, e1) :: rest)
~init:(return (subst_init, return_type))
~f:(fun acc (pat, expr) ->
let* subst1, return_type = acc in
let* env, subst2 = infer_match_pattern env ~shadow:true pat match_type in
let* subst12 = Substitution.compose subst1 subst2 in
let env = TypeEnvironment.apply subst12 env in
let* subst3, expr_typ = infer_expr env expr in
let* subst4 = unify return_type expr_typ in
let* subst = Substitution.compose_all [ subst12; subst3; subst4 ] in
return (subst, Substitution.apply subst return_type))
in
return (subst, return_type)
let* return_t = make_fresh_var in
infer_matching_expr env ((p1, e1) :: rest) subst_init match_t return_t ~with_arg:false
| EConstraint (e, t) ->
let* subst1, e_type = infer_expr env e in
let* subst2 = unify e_type (Substitution.apply subst1 t) in
let* subst_result = Substitution.compose subst1 subst2 in
return (subst_result, Substitution.apply subst2 e_type)

and infer_matching_expr env cases subst_init match_t return_t ~with_arg =
let* subst, return_t =
List.fold
cases
~init:(return (subst_init, return_t))
~f:(fun acc (pat, expr) ->
let* subst1, return_type = acc in
let* env, subst2 =
match with_arg with
| true ->
let* env, pat = infer_pattern env ~shadow:true pat in
let* subst2 = unify match_t pat in
return (env, subst2)
| false -> infer_match_pattern env ~shadow:true pat match_t
in
let* subst12 = Substitution.compose subst1 subst2 in
let env = TypeEnvironment.apply subst12 env in
let* subst3, expr_typ = infer_expr env expr in
let* subst4 = unify return_type expr_typ in
let* subst = Substitution.compose_all [ subst12; subst3; subst4 ] in
return (subst, Substitution.apply subst return_type))
in
let final_typ =
match with_arg with
| true -> Arrow (Substitution.apply subst match_t, return_t)
| false -> return_t
in
return (subst, final_typ)

and extend_env_with_let_binds env is_rec let_binds =
List.fold
let_binds
Expand All @@ -729,8 +751,8 @@ and infer_let_bind env is_rec let_bind =
let* subst2 = unify (Substitution.apply subst1 name_type) bind_type in
let* subst = Substitution.compose subst1 subst2 in
let env = TypeEnvironment.apply subst env in
let names = extract_names_from_pattern name in
let arg_names = extract_names_from_patterns args in
let* names = extract_names_from_pattern name >>| StringSet.elements in
let* arg_names = extract_names_from_patterns args >>| StringSet.elements in
let names_types = List.map names ~f:(fun n -> n, TypeEnvironment.find_typ_exn env n) in
let env = TypeEnvironment.remove_many env (List.concat [ names; arg_names ]) in
let names_schemes_list =
Expand All @@ -744,7 +766,9 @@ let infer_statement env = function
let let_binds = let_bind :: let_binds in
let* env = extend_env_with_bind_names env let_binds in
let* env, _ = extend_env_with_let_binds env Rec let_binds in
let bind_names = extract_bind_names_from_let_binds let_binds in
let* bind_names =
extract_bind_names_from_let_binds let_binds >>| StringSet.elements
in
let bind_names_with_types =
List.map bind_names ~f:(fun name ->
match TypeEnvironment.find_exn env name with
Expand All @@ -754,7 +778,9 @@ let infer_statement env = function
| Let (Nonrec, let_bind, let_binds) ->
let let_binds = let_bind :: let_binds in
let* env, _ = extend_env_with_let_binds env Nonrec let_binds in
let bind_names = extract_bind_names_from_let_binds let_binds in
let* bind_names =
extract_bind_names_from_let_binds let_binds >>| StringSet.elements
in
let bind_names_with_types =
List.map bind_names ~f:(fun name ->
match TypeEnvironment.find_exn env name with
Expand Down
1 change: 1 addition & 0 deletions FSharpActivePatterns/lib/inferencer.mli
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ type error =
| `Not_allowed_right_hand_side_let_rec
| `Not_allowed_left_hand_side_let_rec
| `Args_after_not_variable_let
| `Bound_several_times
]

val pp_error : formatter -> error -> unit
Expand Down
Loading
Loading