Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better operators overloading inference #665

Merged
merged 1 commit into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/ecHiInductive.ml
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ let trans_matchfix
let filter = fun _ op -> EcDecl.is_ctor op in
let PPApp ((cname, tvi), cargs) = pb.pop_pattern in
let tvi = tvi |> omap (TT.transtvi env ue) in
let cts = EcUnify.select_op ~filter tvi env (unloc cname) ue [] in
let cts = EcUnify.select_op ~filter tvi env (unloc cname) ue ([], None) in

match cts with
| [] ->
Expand Down
6 changes: 3 additions & 3 deletions src/ecPrinting.ml
Original file line number Diff line number Diff line change
Expand Up @@ -941,7 +941,7 @@ let pp_opapp
(es : 'a list))
=
let (nm, opname) =
PPEnv.op_symb ppe op (Some (pred, tvi, List.map t_ty es)) in
PPEnv.op_symb ppe op (Some (pred, tvi, (List.map t_ty es, None))) in

let inm = if nm = [] then fst outer else nm in

Expand Down Expand Up @@ -1250,7 +1250,7 @@ let pp_chained_orderings (ppe : PPEnv.t) t_ty pp_sub outer fmt (f, fs) =
ignore (List.fold_left
(fun fe (op, tvi, f) ->
let (nm, opname) =
PPEnv.op_symb ppe op (Some (`Form, tvi, [t_ty fe; t_ty f]))
PPEnv.op_symb ppe op (Some (`Form, tvi, ([t_ty fe; t_ty f], None)))
in
Format.fprintf fmt " %t@ %a"
(fun fmt ->
Expand Down Expand Up @@ -1343,7 +1343,7 @@ let lower_left (ppe : PPEnv.t) (t_ty : form -> EcTypes.ty) (f : form)
else l_l f2 onm e_bin_prio_rop4
| Fapp ({f_node = Fop (op, tys)}, [f1; f2]) ->
(let (inm, opname) =
PPEnv.op_symb ppe op (Some (`Form, tys, List.map t_ty [f1; f2])) in
PPEnv.op_symb ppe op (Some (`Form, tys, (List.map t_ty [f1; f2], None))) in
if inm <> [] && inm <> onm
then None
else match priority_of_binop opname with
Expand Down
2 changes: 1 addition & 1 deletion src/ecScope.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1689,7 +1689,7 @@ module Ty = struct
let tvi = List.map (TT.transty tp_tydecl env ue) tvi in
let selected =
EcUnify.select_op ~filter:(fun _ -> EcDecl.is_oper)
(Some (EcUnify.TVIunamed tvi)) env (unloc op) ue []
(Some (EcUnify.TVIunamed tvi)) env (unloc op) ue ([], None)
in
let op =
match selected with
Expand Down
115 changes: 84 additions & 31 deletions src/ecTyping.ml
Original file line number Diff line number Diff line change
Expand Up @@ -300,15 +300,15 @@ let select_local env (qs,s) =
else None

(* -------------------------------------------------------------------- *)
let select_pv env side name ue tvi psig =
let select_pv env side name ue tvi (psig, retty) =
if tvi <> None
then []
else
try
let pvs = EcEnv.Var.lookup_progvar ?side name env in
let select (pv,ty) =
let subue = UE.copy ue in
let texpected = EcUnify.tfun_expected subue psig in
let texpected = EcUnify.tfun_expected subue ?retty psig in
try
EcUnify.unify env subue ty texpected;
[(pv, ty, subue)]
Expand Down Expand Up @@ -346,7 +346,7 @@ let gen_select_op
(env : EcEnv.env)
(name : EcSymbols.qsymbol)
(ue : EcUnify.unienv)
(psig : EcTypes.dom)
(psig : EcTypes.dom * EcTypes.ty option)

: OpSelect.gopsel list
=
Expand Down Expand Up @@ -432,7 +432,7 @@ let select_form_op env mode ~forcepv opsc name ue tvi psig =
(* -------------------------------------------------------------------- *)
let select_proj env opsc name ue tvi recty =
let filter = (fun _ op -> EcDecl.is_proj op) in
let ops = EcUnify.select_op ~filter tvi env name ue [recty] in
let ops = EcUnify.select_op ~filter tvi env name ue ([recty], None) in
let ops = List.map (fun (p, ty, ue, _) -> (p, ty, ue)) ops in

match ops, opsc with
Expand Down Expand Up @@ -1060,7 +1060,7 @@ let transpattern1 env ue (p : EcParsetree.plpattern) =
let fields =
let for1 (name, v) =
let filter = fun _ op -> EcDecl.is_proj op in
let fds = EcUnify.select_op ~filter None env (unloc name) ue [] in
let fds = EcUnify.select_op ~filter None env (unloc name) ue ([], None) in
match List.ohead fds with
| None ->
let exn = UnknownRecFieldName (unloc name) in
Expand Down Expand Up @@ -1200,7 +1200,7 @@ let trans_record env ue (subtt, proj) (loc, b, fields) =
let for1 rf =
let filter = fun _ op -> EcDecl.is_proj op in
let tvi = rf.rf_tvi |> omap (transtvi env ue) in
let fds = EcUnify.select_op ~filter tvi env (unloc rf.rf_name) ue [] in
let fds = EcUnify.select_op ~filter tvi env (unloc rf.rf_name) ue ([], None) in
match List.ohead fds with
| None ->
let exn = UnknownRecFieldName (unloc rf.rf_name) in
Expand Down Expand Up @@ -1289,7 +1289,7 @@ let trans_branch ~loc env ue gindty ((pb, body) : ppattern * _) =
let filter = fun _ op -> EcDecl.is_ctor op in
let PPApp ((cname, tvi), cargs) = pb in
let tvi = tvi |> omap (transtvi env ue) in
let cts = EcUnify.select_op ~filter tvi env (unloc cname) ue [] in
let cts = EcUnify.select_op ~filter tvi env (unloc cname) ue ([], None) in

match cts with
| [] ->
Expand Down Expand Up @@ -2512,7 +2512,7 @@ and translvalue ue (env : EcEnv.env) lvalue =
let e, ety = e_tuple e, ttuple ety in
let name = ([], EcCoreLib.s_set) in
let esig = [xty; ety; codomty] in
let ops = select_exp_op env `InProc None name ue tvi esig in
let ops = select_exp_op env `InProc None name ue tvi (esig, None) in

match ops with
| [] ->
Expand Down Expand Up @@ -2581,8 +2581,9 @@ and trans_gbinding env ue decl =
and trans_form_or_pattern env mode ?mv ?ps ue pf tt =
let state = PFS.create () in

let rec transf_r opsc env f =
let transf = transf_r opsc in
let rec transf_r_tyinfo opsc env ?tt f =
let transf env ?tt f =
transf_r opsc env ?tt f in

match f.pl_desc with
| PFhole -> begin
Expand Down Expand Up @@ -2814,20 +2815,18 @@ and trans_form_or_pattern env mode ?mv ?ps ue pf tt =
| PFdecimal (n, f) ->
f_decimal (n, f)

| PFtuple args -> begin
let args = List.map (transf env) args in
match args with
| [] -> f_tt
| [f] -> f
| fs -> f_tuple fs
end
| PFtuple pes ->
let esig = List.map (fun _ -> EcUnify.UniEnv.fresh ue) pes in
tt |> oiter (fun tt -> unify_or_fail env ue f.pl_loc ~expct:tt (ttuple esig));
let es = List.map2 (fun tt pe -> transf env ~tt pe) esig pes in
f_tuple es

| PFident ({ pl_desc = name; pl_loc = loc }, tvi) ->
let tvi = tvi |> omap (transtvi env ue) in
let ops =
select_form_op
~forcepv:(PFS.isforced state)
env mode opsc name ue tvi [] in
env mode opsc name ue tvi ([], tt) in
begin match ops with
| [] ->
tyerror loc env (UnknownVarOrOp (name, []))
Expand Down Expand Up @@ -2962,13 +2961,43 @@ and trans_form_or_pattern env mode ?mv ?ps ue pf tt =
check_mem f.pl_loc EcFol.mright;
EcFol.f_ands (List.map (do1 (EcFol.mleft, EcFol.mright)) fs)

| PFapp ({pl_desc = PFident ({ pl_desc = name; pl_loc = loc }, tvi)}, pes) ->
| PFapp ({pl_desc = PFident ({ pl_desc = name; pl_loc = loc }, tvi)}, pes) -> begin
let try_trans ?tt pe =
let ue' = EcUnify.UniEnv.copy ue in
let ps' = Option.map (fun ps -> ref !ps) ps in
match transf env ?tt pe with
| e -> Some e
| exception TyError (_, _, MultipleOpMatch _) ->
Option.iter (fun ps -> ps := !(Option.get ps')) ps;
EcUnify.UniEnv.restore ~dst:ue ~src:ue';
None
in

match
let ue' = EcUnify.UniEnv.copy ue in
let ps' = Option.map (fun ps -> ref !ps) ps in
let es = List.map (fun pe -> try_trans pe) pes in
let tvi = tvi |> omap (transtvi env ue) in
let esig = List.map (fun e ->
match e with Some e -> e.f_ty | None -> EcUnify.UniEnv.fresh ue
) es in
match
select_form_op ~forcepv:(PFS.isforced state)
env mode opsc name ue tvi (esig, tt)
with
| [sel] -> Some (sel, (es, esig, tvi))
| _ ->
Option.iter (fun ps -> ps := !(Option.get ps')) ps;
EcUnify.UniEnv.restore ~dst:ue ~src:ue';
None
with
| None -> begin
let tvi = tvi |> omap (transtvi env ue) in
let es = List.map (transf env) pes in
let esig = List.map EcFol.f_ty es in
let ops =
select_form_op ~forcepv:(PFS.isforced state)
env mode opsc name ue tvi esig in
env mode opsc name ue tvi (esig, tt) in

begin match ops with
| [] ->
Expand All @@ -2986,6 +3015,24 @@ and trans_form_or_pattern env mode ?mv ?ps ue pf tt =
let matches = List.map (fun (_, _, subue, m) -> (m, subue)) ops in
tyerror loc env (MultipleOpMatch (name, esig, matches))
end
end

| Some ((_, _, subue, _) as sel, (es, esig, _tvi)) ->
EcUnify.UniEnv.restore ~dst:ue ~src:subue;
let es =
List.map2 (
fun (e, ty) pe ->
match e with None -> try_trans ~tt:ty pe | Some e -> Some e
) (List.combine es esig) pes in
let es =
List.map2 (
fun (e, ty) pe ->
match e with None -> transf env ~tt:ty pe | Some e -> e
) (List.combine es esig) pes in
let es = List.map2 (fun e l -> mk_loc l.pl_loc e) es pes in
EcUnify.UniEnv.restore ~src:ue ~dst:subue;
form_of_opselect (env, ue) loc sel es
end

| PFapp (e, pes) ->
let es = List.map (transf env) pes in
Expand Down Expand Up @@ -3041,25 +3088,30 @@ and trans_form_or_pattern env mode ?mv ?ps ue pf tt =
let f1 = transf env pf1 in
unify_or_fail env ue pf1.pl_loc ~expct:pty f1.f_ty;
aty |> oiter (fun aty-> unify_or_fail env ue pf1.pl_loc ~expct:pty aty);
let f2 = transf penv f2 in
let f2 = transf penv ?tt f2 in
f_let p f1 f2

| PFforall (xs, pf) ->
let env, xs = trans_gbinding env ue xs in
let f = transf env pf in
unify_or_fail env ue pf.pl_loc ~expct:tbool f.f_ty;
f_forall xs f
unify_or_fail env ue pf.pl_loc ~expct:tbool f.f_ty;
f_forall xs f

| PFexists (xs, f1) ->
let env, xs = trans_gbinding env ue xs in
let f = transf env f1 in
unify_or_fail env ue f1.pl_loc ~expct:tbool f.f_ty;
f_exists xs f
unify_or_fail env ue f1.pl_loc ~expct:tbool f.f_ty;
f_exists xs f

| PFlambda (xs, f1) ->
let env, xs = trans_binding env ue xs in
let f = transf env f1 in
f_lambda (List.map (fun (x,ty) -> (x,GTty ty)) xs) f
let subtt = tt |> Option.map (fun tt ->
let codom = EcUnify.UniEnv.fresh ue in
unify_or_fail env ue (loc f) ~expct:(toarrow (List.snd xs) codom) tt;
codom
) in
let f = transf env ?tt:subtt f1 in
f_lambda (List.map (fun (x, ty) -> (x, GTty ty)) xs) f

| PFrecord (b, fields) ->
let (ctor, fields, (rtvi, reccty)) =
Expand Down Expand Up @@ -3190,11 +3242,12 @@ and trans_form_or_pattern env mode ?mv ?ps ue pf tt =
unify_or_fail qenv ue post.pl_loc ~expct:tbool post'.f_ty;
f_eagerF pre' s1 fpath1 fpath2 s2 post'

in
and transf_r opsc env ?tt pf =
let f = transf_r_tyinfo opsc env ?tt pf in
let () = oiter (fun tt -> unify_or_fail env ue pf.pl_loc ~expct:tt f.f_ty) tt in
f

let f = transf_r None env pf in
tt |> oiter (fun tt -> unify_or_fail env ue pf.pl_loc ~expct:tt f.f_ty);
f
in transf_r None env ?tt pf

(* Type-check a memtype. *)
and trans_memtype env ue (pmemtype : pmemtype) : memtype =
Expand Down
10 changes: 5 additions & 5 deletions src/ecUnify.ml
Original file line number Diff line number Diff line change
Expand Up @@ -396,15 +396,15 @@ let hastc env ue ty tc =
ue := { !ue with ue_uf = uf; }

(* -------------------------------------------------------------------- *)
let tfun_expected ue psig =
let tres = UniEnv.fresh ue in
EcTypes.toarrow psig tres
let tfun_expected ue ?retty psig =
let retty = ofdfl (fun () -> UniEnv.fresh ue) retty in
EcTypes.toarrow psig retty

(* -------------------------------------------------------------------- *)
type sbody = ((EcIdent.t * ty) list * expr) Lazy.t

(* -------------------------------------------------------------------- *)
let select_op ?(hidden = false) ?(filter = fun _ _ -> true) tvi env name ue psig =
let select_op ?(hidden = false) ?(filter = fun _ _ -> true) tvi env name ue (psig, retty) =
ignore hidden; (* FIXME *)

let module D = EcDecl in
Expand Down Expand Up @@ -457,7 +457,7 @@ let select_op ?(hidden = false) ?(filter = fun _ _ -> true) tvi env name ue psig

let (tip, tvs) = UniEnv.openty_r subue op.D.op_tparams tvi in
let top = ty_subst tip op.D.op_ty in
let texpected = tfun_expected subue psig in
let texpected = tfun_expected subue ?retty psig in

(try unify env subue top texpected
with UnificationFailure _ -> raise E.Failure);
Expand Down
4 changes: 2 additions & 2 deletions src/ecUnify.mli
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ end
val unify : EcEnv.env -> unienv -> ty -> ty -> unit
val hastc : EcEnv.env -> unienv -> ty -> Sp.t -> unit

val tfun_expected : unienv -> EcTypes.ty list -> EcTypes.ty
val tfun_expected : unienv -> ?retty:ty -> EcTypes.ty list -> EcTypes.ty

type sbody = ((EcIdent.t * ty) list * expr) Lazy.t

Expand All @@ -48,5 +48,5 @@ val select_op :
-> EcEnv.env
-> qsymbol
-> unienv
-> dom
-> dom * ty option
-> ((EcPath.path * ty list) * ty * unienv * sbody option) list
24 changes: 24 additions & 0 deletions tests/overloading.ec
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
require import AllCore List.

theory T.
op o : int.
op a : int -> int -> int.
end T.

theory U.
op o : bool.
op a : bool -> bool -> bool.
end U.

import T U.

op foo : int -> unit.

op bar = foo o.

op plop1 = foldr a false [].

op plop2 = foldr (fun x => a x) false [].

op plop3 = foldr (fun x y => a x y) false [].