Skip to content

Commit cecd206

Browse files
committed
Better operator overloading inference
This commit introduces a weak form a bi-directional typing, and does a two-pass typing of overloading operators arguments.
1 parent 4f84b7c commit cecd206

File tree

7 files changed

+120
-43
lines changed

7 files changed

+120
-43
lines changed

src/ecHiInductive.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ let trans_matchfix
284284
let filter = fun _ op -> EcDecl.is_ctor op in
285285
let PPApp ((cname, tvi), cargs) = pb.pop_pattern in
286286
let tvi = tvi |> omap (TT.transtvi env ue) in
287-
let cts = EcUnify.select_op ~filter tvi env (unloc cname) ue [] in
287+
let cts = EcUnify.select_op ~filter tvi env (unloc cname) ue ([], None) in
288288

289289
match cts with
290290
| [] ->

src/ecPrinting.ml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -941,7 +941,7 @@ let pp_opapp
941941
(es : 'a list))
942942
=
943943
let (nm, opname) =
944-
PPEnv.op_symb ppe op (Some (pred, tvi, List.map t_ty es)) in
944+
PPEnv.op_symb ppe op (Some (pred, tvi, (List.map t_ty es, None))) in
945945

946946
let inm = if nm = [] then fst outer else nm in
947947

@@ -1250,7 +1250,7 @@ let pp_chained_orderings (ppe : PPEnv.t) t_ty pp_sub outer fmt (f, fs) =
12501250
ignore (List.fold_left
12511251
(fun fe (op, tvi, f) ->
12521252
let (nm, opname) =
1253-
PPEnv.op_symb ppe op (Some (`Form, tvi, [t_ty fe; t_ty f]))
1253+
PPEnv.op_symb ppe op (Some (`Form, tvi, ([t_ty fe; t_ty f], None)))
12541254
in
12551255
Format.fprintf fmt " %t@ %a"
12561256
(fun fmt ->
@@ -1343,7 +1343,7 @@ let lower_left (ppe : PPEnv.t) (t_ty : form -> EcTypes.ty) (f : form)
13431343
else l_l f2 onm e_bin_prio_rop4
13441344
| Fapp ({f_node = Fop (op, tys)}, [f1; f2]) ->
13451345
(let (inm, opname) =
1346-
PPEnv.op_symb ppe op (Some (`Form, tys, List.map t_ty [f1; f2])) in
1346+
PPEnv.op_symb ppe op (Some (`Form, tys, (List.map t_ty [f1; f2], None))) in
13471347
if inm <> [] && inm <> onm
13481348
then None
13491349
else match priority_of_binop opname with

src/ecScope.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1689,7 +1689,7 @@ module Ty = struct
16891689
let tvi = List.map (TT.transty tp_tydecl env ue) tvi in
16901690
let selected =
16911691
EcUnify.select_op ~filter:(fun _ -> EcDecl.is_oper)
1692-
(Some (EcUnify.TVIunamed tvi)) env (unloc op) ue []
1692+
(Some (EcUnify.TVIunamed tvi)) env (unloc op) ue ([], None)
16931693
in
16941694
let op =
16951695
match selected with

src/ecTyping.ml

Lines changed: 84 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -300,15 +300,15 @@ let select_local env (qs,s) =
300300
else None
301301

302302
(* -------------------------------------------------------------------- *)
303-
let select_pv env side name ue tvi psig =
303+
let select_pv env side name ue tvi (psig, retty) =
304304
if tvi <> None
305305
then []
306306
else
307307
try
308308
let pvs = EcEnv.Var.lookup_progvar ?side name env in
309309
let select (pv,ty) =
310310
let subue = UE.copy ue in
311-
let texpected = EcUnify.tfun_expected subue psig in
311+
let texpected = EcUnify.tfun_expected subue ?retty psig in
312312
try
313313
EcUnify.unify env subue ty texpected;
314314
[(pv, ty, subue)]
@@ -346,7 +346,7 @@ let gen_select_op
346346
(env : EcEnv.env)
347347
(name : EcSymbols.qsymbol)
348348
(ue : EcUnify.unienv)
349-
(psig : EcTypes.dom)
349+
(psig : EcTypes.dom * EcTypes.ty option)
350350

351351
: OpSelect.gopsel list
352352
=
@@ -432,7 +432,7 @@ let select_form_op env mode ~forcepv opsc name ue tvi psig =
432432
(* -------------------------------------------------------------------- *)
433433
let select_proj env opsc name ue tvi recty =
434434
let filter = (fun _ op -> EcDecl.is_proj op) in
435-
let ops = EcUnify.select_op ~filter tvi env name ue [recty] in
435+
let ops = EcUnify.select_op ~filter tvi env name ue ([recty], None) in
436436
let ops = List.map (fun (p, ty, ue, _) -> (p, ty, ue)) ops in
437437

438438
match ops, opsc with
@@ -1060,7 +1060,7 @@ let transpattern1 env ue (p : EcParsetree.plpattern) =
10601060
let fields =
10611061
let for1 (name, v) =
10621062
let filter = fun _ op -> EcDecl.is_proj op in
1063-
let fds = EcUnify.select_op ~filter None env (unloc name) ue [] in
1063+
let fds = EcUnify.select_op ~filter None env (unloc name) ue ([], None) in
10641064
match List.ohead fds with
10651065
| None ->
10661066
let exn = UnknownRecFieldName (unloc name) in
@@ -1200,7 +1200,7 @@ let trans_record env ue (subtt, proj) (loc, b, fields) =
12001200
let for1 rf =
12011201
let filter = fun _ op -> EcDecl.is_proj op in
12021202
let tvi = rf.rf_tvi |> omap (transtvi env ue) in
1203-
let fds = EcUnify.select_op ~filter tvi env (unloc rf.rf_name) ue [] in
1203+
let fds = EcUnify.select_op ~filter tvi env (unloc rf.rf_name) ue ([], None) in
12041204
match List.ohead fds with
12051205
| None ->
12061206
let exn = UnknownRecFieldName (unloc rf.rf_name) in
@@ -1289,7 +1289,7 @@ let trans_branch ~loc env ue gindty ((pb, body) : ppattern * _) =
12891289
let filter = fun _ op -> EcDecl.is_ctor op in
12901290
let PPApp ((cname, tvi), cargs) = pb in
12911291
let tvi = tvi |> omap (transtvi env ue) in
1292-
let cts = EcUnify.select_op ~filter tvi env (unloc cname) ue [] in
1292+
let cts = EcUnify.select_op ~filter tvi env (unloc cname) ue ([], None) in
12931293

12941294
match cts with
12951295
| [] ->
@@ -2512,7 +2512,7 @@ and translvalue ue (env : EcEnv.env) lvalue =
25122512
let e, ety = e_tuple e, ttuple ety in
25132513
let name = ([], EcCoreLib.s_set) in
25142514
let esig = [xty; ety; codomty] in
2515-
let ops = select_exp_op env `InProc None name ue tvi esig in
2515+
let ops = select_exp_op env `InProc None name ue tvi (esig, None) in
25162516

25172517
match ops with
25182518
| [] ->
@@ -2581,8 +2581,9 @@ and trans_gbinding env ue decl =
25812581
and trans_form_or_pattern env mode ?mv ?ps ue pf tt =
25822582
let state = PFS.create () in
25832583

2584-
let rec transf_r opsc env f =
2585-
let transf = transf_r opsc in
2584+
let rec transf_r_tyinfo opsc env ?tt f =
2585+
let transf env ?tt f =
2586+
transf_r opsc env ?tt f in
25862587

25872588
match f.pl_desc with
25882589
| PFhole -> begin
@@ -2814,20 +2815,18 @@ and trans_form_or_pattern env mode ?mv ?ps ue pf tt =
28142815
| PFdecimal (n, f) ->
28152816
f_decimal (n, f)
28162817

2817-
| PFtuple args -> begin
2818-
let args = List.map (transf env) args in
2819-
match args with
2820-
| [] -> f_tt
2821-
| [f] -> f
2822-
| fs -> f_tuple fs
2823-
end
2818+
| PFtuple pes ->
2819+
let esig = List.map (fun _ -> EcUnify.UniEnv.fresh ue) pes in
2820+
tt |> oiter (fun tt -> unify_or_fail env ue f.pl_loc ~expct:tt (ttuple esig));
2821+
let es = List.map2 (fun tt pe -> transf env ~tt pe) esig pes in
2822+
f_tuple es
28242823

28252824
| PFident ({ pl_desc = name; pl_loc = loc }, tvi) ->
28262825
let tvi = tvi |> omap (transtvi env ue) in
28272826
let ops =
28282827
select_form_op
28292828
~forcepv:(PFS.isforced state)
2830-
env mode opsc name ue tvi [] in
2829+
env mode opsc name ue tvi ([], tt) in
28312830
begin match ops with
28322831
| [] ->
28332832
tyerror loc env (UnknownVarOrOp (name, []))
@@ -2962,13 +2961,43 @@ and trans_form_or_pattern env mode ?mv ?ps ue pf tt =
29622961
check_mem f.pl_loc EcFol.mright;
29632962
EcFol.f_ands (List.map (do1 (EcFol.mleft, EcFol.mright)) fs)
29642963

2965-
| PFapp ({pl_desc = PFident ({ pl_desc = name; pl_loc = loc }, tvi)}, pes) ->
2964+
| PFapp ({pl_desc = PFident ({ pl_desc = name; pl_loc = loc }, tvi)}, pes) -> begin
2965+
let try_trans ?tt pe =
2966+
let ue' = EcUnify.UniEnv.copy ue in
2967+
let ps' = Option.map (fun ps -> ref !ps) ps in
2968+
match transf env ?tt pe with
2969+
| e -> Some e
2970+
| exception TyError (_, _, MultipleOpMatch _) ->
2971+
Option.iter (fun ps -> ps := !(Option.get ps')) ps;
2972+
EcUnify.UniEnv.restore ~dst:ue ~src:ue';
2973+
None
2974+
in
2975+
2976+
match
2977+
let ue' = EcUnify.UniEnv.copy ue in
2978+
let ps' = Option.map (fun ps -> ref !ps) ps in
2979+
let es = List.map (fun pe -> try_trans pe) pes in
2980+
let tvi = tvi |> omap (transtvi env ue) in
2981+
let esig = List.map (fun e ->
2982+
match e with Some e -> e.f_ty | None -> EcUnify.UniEnv.fresh ue
2983+
) es in
2984+
match
2985+
select_form_op ~forcepv:(PFS.isforced state)
2986+
env mode opsc name ue tvi (esig, tt)
2987+
with
2988+
| [sel] -> Some (sel, (es, esig, tvi))
2989+
| _ ->
2990+
Option.iter (fun ps -> ps := !(Option.get ps')) ps;
2991+
EcUnify.UniEnv.restore ~dst:ue ~src:ue';
2992+
None
2993+
with
2994+
| None -> begin
29662995
let tvi = tvi |> omap (transtvi env ue) in
29672996
let es = List.map (transf env) pes in
29682997
let esig = List.map EcFol.f_ty es in
29692998
let ops =
29702999
select_form_op ~forcepv:(PFS.isforced state)
2971-
env mode opsc name ue tvi esig in
3000+
env mode opsc name ue tvi (esig, tt) in
29723001

29733002
begin match ops with
29743003
| [] ->
@@ -2986,6 +3015,24 @@ and trans_form_or_pattern env mode ?mv ?ps ue pf tt =
29863015
let matches = List.map (fun (_, _, subue, m) -> (m, subue)) ops in
29873016
tyerror loc env (MultipleOpMatch (name, esig, matches))
29883017
end
3018+
end
3019+
3020+
| Some ((_, _, subue, _) as sel, (es, esig, _tvi)) ->
3021+
EcUnify.UniEnv.restore ~dst:ue ~src:subue;
3022+
let es =
3023+
List.map2 (
3024+
fun (e, ty) pe ->
3025+
match e with None -> try_trans ~tt:ty pe | Some e -> Some e
3026+
) (List.combine es esig) pes in
3027+
let es =
3028+
List.map2 (
3029+
fun (e, ty) pe ->
3030+
match e with None -> transf env ~tt:ty pe | Some e -> e
3031+
) (List.combine es esig) pes in
3032+
let es = List.map2 (fun e l -> mk_loc l.pl_loc e) es pes in
3033+
EcUnify.UniEnv.restore ~src:ue ~dst:subue;
3034+
form_of_opselect (env, ue) loc sel es
3035+
end
29893036

29903037
| PFapp (e, pes) ->
29913038
let es = List.map (transf env) pes in
@@ -3041,25 +3088,30 @@ and trans_form_or_pattern env mode ?mv ?ps ue pf tt =
30413088
let f1 = transf env pf1 in
30423089
unify_or_fail env ue pf1.pl_loc ~expct:pty f1.f_ty;
30433090
aty |> oiter (fun aty-> unify_or_fail env ue pf1.pl_loc ~expct:pty aty);
3044-
let f2 = transf penv f2 in
3091+
let f2 = transf penv ?tt f2 in
30453092
f_let p f1 f2
30463093

30473094
| PFforall (xs, pf) ->
30483095
let env, xs = trans_gbinding env ue xs in
30493096
let f = transf env pf in
3050-
unify_or_fail env ue pf.pl_loc ~expct:tbool f.f_ty;
3051-
f_forall xs f
3097+
unify_or_fail env ue pf.pl_loc ~expct:tbool f.f_ty;
3098+
f_forall xs f
30523099

30533100
| PFexists (xs, f1) ->
30543101
let env, xs = trans_gbinding env ue xs in
30553102
let f = transf env f1 in
3056-
unify_or_fail env ue f1.pl_loc ~expct:tbool f.f_ty;
3057-
f_exists xs f
3103+
unify_or_fail env ue f1.pl_loc ~expct:tbool f.f_ty;
3104+
f_exists xs f
30583105

30593106
| PFlambda (xs, f1) ->
30603107
let env, xs = trans_binding env ue xs in
3061-
let f = transf env f1 in
3062-
f_lambda (List.map (fun (x,ty) -> (x,GTty ty)) xs) f
3108+
let subtt = tt |> Option.map (fun tt ->
3109+
let codom = EcUnify.UniEnv.fresh ue in
3110+
unify_or_fail env ue (loc f) ~expct:(toarrow (List.snd xs) codom) tt;
3111+
codom
3112+
) in
3113+
let f = transf env ?tt:subtt f1 in
3114+
f_lambda (List.map (fun (x, ty) -> (x, GTty ty)) xs) f
30633115

30643116
| PFrecord (b, fields) ->
30653117
let (ctor, fields, (rtvi, reccty)) =
@@ -3190,11 +3242,12 @@ and trans_form_or_pattern env mode ?mv ?ps ue pf tt =
31903242
unify_or_fail qenv ue post.pl_loc ~expct:tbool post'.f_ty;
31913243
f_eagerF pre' s1 fpath1 fpath2 s2 post'
31923244

3193-
in
3245+
and transf_r opsc env ?tt pf =
3246+
let f = transf_r_tyinfo opsc env ?tt pf in
3247+
let () = oiter (fun tt -> unify_or_fail env ue pf.pl_loc ~expct:tt f.f_ty) tt in
3248+
f
31943249

3195-
let f = transf_r None env pf in
3196-
tt |> oiter (fun tt -> unify_or_fail env ue pf.pl_loc ~expct:tt f.f_ty);
3197-
f
3250+
in transf_r None env ?tt pf
31983251

31993252
(* Type-check a memtype. *)
32003253
and trans_memtype env ue (pmemtype : pmemtype) : memtype =

src/ecUnify.ml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -396,15 +396,15 @@ let hastc env ue ty tc =
396396
ue := { !ue with ue_uf = uf; }
397397

398398
(* -------------------------------------------------------------------- *)
399-
let tfun_expected ue psig =
400-
let tres = UniEnv.fresh ue in
401-
EcTypes.toarrow psig tres
399+
let tfun_expected ue ?retty psig =
400+
let retty = ofdfl (fun () -> UniEnv.fresh ue) retty in
401+
EcTypes.toarrow psig retty
402402

403403
(* -------------------------------------------------------------------- *)
404404
type sbody = ((EcIdent.t * ty) list * expr) Lazy.t
405405

406406
(* -------------------------------------------------------------------- *)
407-
let select_op ?(hidden = false) ?(filter = fun _ _ -> true) tvi env name ue psig =
407+
let select_op ?(hidden = false) ?(filter = fun _ _ -> true) tvi env name ue (psig, retty) =
408408
ignore hidden; (* FIXME *)
409409
410410
let module D = EcDecl in
@@ -457,7 +457,7 @@ let select_op ?(hidden = false) ?(filter = fun _ _ -> true) tvi env name ue psig
457457
458458
let (tip, tvs) = UniEnv.openty_r subue op.D.op_tparams tvi in
459459
let top = ty_subst tip op.D.op_ty in
460-
let texpected = tfun_expected subue psig in
460+
let texpected = tfun_expected subue ?retty psig in
461461
462462
(try unify env subue top texpected
463463
with UnificationFailure _ -> raise E.Failure);

src/ecUnify.mli

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ end
3737
val unify : EcEnv.env -> unienv -> ty -> ty -> unit
3838
val hastc : EcEnv.env -> unienv -> ty -> Sp.t -> unit
3939

40-
val tfun_expected : unienv -> EcTypes.ty list -> EcTypes.ty
40+
val tfun_expected : unienv -> ?retty:ty -> EcTypes.ty list -> EcTypes.ty
4141

4242
type sbody = ((EcIdent.t * ty) list * expr) Lazy.t
4343

@@ -48,5 +48,5 @@ val select_op :
4848
-> EcEnv.env
4949
-> qsymbol
5050
-> unienv
51-
-> dom
51+
-> dom * ty option
5252
-> ((EcPath.path * ty list) * ty * unienv * sbody option) list

tests/overloading.ec

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
require import AllCore List.
2+
3+
theory T.
4+
op o : int.
5+
op a : int -> int -> int.
6+
end T.
7+
8+
theory U.
9+
op o : bool.
10+
op a : bool -> bool -> bool.
11+
end U.
12+
13+
import T U.
14+
15+
op foo : int -> unit.
16+
17+
op bar = foo o.
18+
19+
op plop1 = foldr a false [].
20+
21+
op plop2 = foldr (fun x => a x) false [].
22+
23+
op plop3 = foldr (fun x y => a x y) false [].
24+

0 commit comments

Comments
 (0)