Skip to content

Commit

Permalink
Add caml_assume_no_effects primitive and tests
Browse files Browse the repository at this point in the history
Passing a function [f] as argument of `caml_assume_no_effects`
guarantees that, when compiling with `--enable doubletranslate`, the
direct-style version of [f] is called, which is faster than the CPS
version. As a consequence, performing an effect in a transitive callee
of [f] will raise `Effect.Unhandled`, regardless of any effect handlers
installed before the call to `caml_assume_no_effects`, unless a new
effect handler was installed in the meantime.

Usage:

```
external assume_no_effects : (unit -> 'a) -> 'a = "caml_assume_no_effects"

... caml_assume_no_effects (fun () -> (* Will be called in direct style... *)) ...
```

When double translation is disabled, `caml_assume_no_effects` simply
acts like `fun f -> f ()`.

This primitive is exposed via `Js_of_ocaml.Js.Effect.assume_no_perform`.
  • Loading branch information
OlivierNicole committed Oct 10, 2024
1 parent 416feb2 commit f473083
Show file tree
Hide file tree
Showing 15 changed files with 484 additions and 8 deletions.
87 changes: 79 additions & 8 deletions compiler/lib/effects.ml
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,16 @@ let dominance_frontier g idom =

(****)

let effect_primitive_or_application = function
| Prim (Extern ("%resume" | "%perform" | "%reperform" | "caml_assume_no_perform"), _)
| Apply _ -> true
| Block (_, _, _, _)
| Field (_, _, _)
| Closure (_, _)
| Constant _
| Prim (_, _)
| Special _ -> false

(*
We establish the list of blocks that needs to be CPS-transformed. We
also mark blocks that correspond to function continuations or
Expand Down Expand Up @@ -182,11 +192,8 @@ let compute_needed_transformations ~cfg ~idom ~cps_needed ~blocks ~start =
(match fst block.branch with
| Branch (dst, _) -> (
match List.last block.body with
| Some
( Let
(x, (Apply _ | Prim (Extern ("%resume" | "%perform" | "%reperform"), _)))
, _ )
when Var.Set.mem x cps_needed ->
| Some (Let (x, e), _)
when effect_primitive_or_application e && Var.Set.mem x cps_needed ->
(* The block after a function application that needs to
be turned to CPS or an effect primitive needs to be
transformed. *)
Expand Down Expand Up @@ -735,7 +742,39 @@ let cps_instr ~st (instr : instr) : instr list =
(* Nothing to do for single-version functions. *)
[ instr ]
| Let (_, (Apply _ | Prim (Extern ("%resume" | "%perform" | "%reperform"), _))) ->
(* Applications of CPS functions and effect primitives require more work
(allocating a continuation and/or modifying end-of-block branches) and
are handled in a specialized function below. *)
assert false
| Let (x, Prim (Extern "caml_assume_no_perform", [ Pv f ])) ->
if double_translate ()
then
(* We just need to call [f] in direct style. *)
let unit = Var.fresh_n "unit" in
let exact = Global_flow.exact_call st.flow_info f 1 in
[ Let (unit, Constant (Int Targetint.zero))
; Let (x, Apply { exact; f; args = [ unit ] })
]
else (
(* The "needs CPS" case should have been taken care of by another, specialized
function below. *)
assert (not (Var.Set.mem x st.cps_needed));
(* Translated like the [Apply] case, with a unit argument *)
assert (
(* If this function is unknown to the global flow analysis, then it was
introduced by the lambda lifting and does not require CPS *)
Var.idx f >= Var.Tbl.length st.flow_info.info_approximation
|| Global_flow.exact_call st.flow_info f 1);
let unit = Var.fresh_n "unit" in
[ Let (unit, Constant (Int Targetint.zero))
; Let (x, Apply { f; args = [ unit ]; exact = true })
])
| Let (_, Prim (Extern "caml_assume_no_perform", args)) ->
invalid_arg
@@ Format.sprintf
"Internal primitive `caml_assume_no_perform` takes exactly 1 argument (%d \
given)"
(List.length args)
| _ -> [ instr ]

let cps_block ~st ~k ~lifter_functions ~orig_pc block =
Expand Down Expand Up @@ -769,6 +808,27 @@ let cps_block ~st ~k ~lifter_functions ~orig_pc block =
|| Global_flow.exact_call st.flow_info f (List.length args)
in
tail_call ~st ~exact ~in_cps:true ~check:true ~f (args @ [ k ]) loc)
| Prim (Extern "caml_assume_no_perform", [ Pv f ])
when (not (double_translate ())) && Var.Set.mem x st.cps_needed ->
(* Translated like the [Apply] case, with a unit argument *)
Some
(fun ~k ->
let exact =
(* If this function is unknown to the global flow analysis, then it was
introduced by the lambda lifting and is exact *)
Var.idx f >= Var.Tbl.length st.flow_info.info_approximation
|| Global_flow.exact_call st.flow_info f 1
in
let unit = Var.fresh_n "unit" in
tail_call
~st
~instrs:[ Let (unit, Constant (Int Targetint.zero)), noloc ]
~exact
~in_cps:false
~check:true
~f
[ unit; k ]
loc)
| Prim (Extern "%resume", [ Pv stack; Pv f; Pv arg ]) ->
Some
(fun ~k ->
Expand Down Expand Up @@ -881,8 +941,7 @@ let rewrite_direct_instr ~st (instr, loc) =
the right number of parameter *)
assert (Global_flow.exact_call st.flow_info f (List.length args));
Let (x, Apply { f; args; exact = true }), loc
| Let (_, (Apply _ | Prim (Extern ("%resume" | "%perform" | "%reperform"), _))) ->
assert false
| Let (_, e) when effect_primitive_or_application e -> assert false
| _ -> instr, loc

(* If double-translating, modify all function applications and closure
Expand Down Expand Up @@ -940,6 +999,18 @@ let rewrite_direct_block
, Prim (Extern "caml_perform_effect", [ Pv effect; Pv continuation; Pc k ])
)
]
| Let (x, Prim (Extern "caml_assume_no_perform", [ Pv f ])) ->
(* We just need to call [f] in direct style. *)
let unit = Var.fresh_n "unit" in
let unit_val = Int Targetint.zero in
let exact = Global_flow.exact_call st.flow_info f 1 in
[ Let (unit, Constant unit_val); Let (x, Apply { exact; f; args = [ unit ] }) ]
| Let (_, Prim (Extern "caml_assume_no_perform", args)) ->
invalid_arg
@@ Format.sprintf
"Internal primitive `caml_assume_no_perform` takes exactly 1 argument (%d \
given)"
(List.length args)
| (Let _ | Assign _ | Set_field _ | Offset_ref _ | Array_set _) as instr ->
[ instr ]
in
Expand Down Expand Up @@ -1384,7 +1455,7 @@ let split_blocks ~cps_needed ~lifter_functions (p : Code.program) =
let split_block pc block p =
let is_split_point i r branch =
match i with
| Let (x, (Apply _ | Prim (Extern ("%resume" | "%perform" | "%reperform"), _))) -> (
| Let (x, e) when effect_primitive_or_application e -> (
((not (List.is_empty r))
||
match fst branch with
Expand Down
13 changes: 13 additions & 0 deletions compiler/lib/partial_cps_analysis.ml
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,15 @@ let block_deps ~info ~vars ~tail_deps ~deps ~blocks ~fun_name pc =
(* If a function contains effect primitives, it must be
in CPS *)
add_dep deps f x)
| Let (x, Prim (Extern "caml_assume_no_perform", _)) -> (
add_var vars x;
match fun_name with
| None -> ()
| Some f ->
add_var vars f;
(* If a function contains effect primitives, it must be
in CPS *)
add_dep deps f x)
| Let (x, Closure _) -> add_var vars x
| Let (_, (Prim _ | Block _ | Constant _ | Field _ | Special _))
| Assign _ | Set_field _ | Offset_ref _ | Array_set _ -> ())
Expand Down Expand Up @@ -141,6 +150,10 @@ let cps_needed ~info ~in_mutual_recursion ~rev_deps st x =
| Expr (Prim (Extern ("%perform" | "%reperform" | "%resume"), _)) ->
(* Effects primitives are in CPS *)
true
| Expr (Prim (Extern "caml_assume_no_perform", _)) ->
(* This primitive calls its function argument in direct style when double translation
is enabled. Otherwise, it simply applies its argument to unit. *)
not (Config.Flag.double_translation ())
| Expr (Prim _ | Block _ | Constant _ | Field _ | Special _) | Phi _ -> false

module SCC = Strongly_connected_components.Make (struct
Expand Down
1 change: 1 addition & 0 deletions compiler/tests-check-prim/main.output
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ Missing

From main.bc:
caml_alloc_dummy_function
caml_assume_no_perform
caml_dynlink_add_primitive
caml_dynlink_close_lib
caml_dynlink_get_current_libs
Expand Down
1 change: 1 addition & 0 deletions compiler/tests-check-prim/main.output5
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ Missing

From main.bc:
caml_alloc_dummy_function
caml_assume_no_perform
caml_continuation_use
caml_drop_continuation
caml_dynlink_add_primitive
Expand Down
1 change: 1 addition & 0 deletions compiler/tests-check-prim/unix-unix.output
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ Missing

From unix.bc:
caml_alloc_dummy_function
caml_assume_no_perform
caml_dynlink_add_primitive
caml_dynlink_close_lib
caml_dynlink_get_current_libs
Expand Down
1 change: 1 addition & 0 deletions compiler/tests-check-prim/unix-unix.output5
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ Missing

From unix.bc:
caml_alloc_dummy_function
caml_assume_no_perform
caml_continuation_use
caml_drop_continuation
caml_dynlink_add_primitive
Expand Down
1 change: 1 addition & 0 deletions compiler/tests-check-prim/unix-win32.output
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ Missing

From unix.bc:
caml_alloc_dummy_function
caml_assume_no_perform
caml_dynlink_add_primitive
caml_dynlink_close_lib
caml_dynlink_get_current_libs
Expand Down
1 change: 1 addition & 0 deletions compiler/tests-check-prim/unix-win32.output5
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ Missing

From unix.bc:
caml_alloc_dummy_function
caml_assume_no_perform
caml_continuation_use
caml_drop_continuation
caml_dynlink_add_primitive
Expand Down
164 changes: 164 additions & 0 deletions compiler/tests-ocaml/lib-effects/assume_no_perform.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
open Printf
open Effect
open Effect.Deep

module type TREE = sig
type 'a t
(** The type of tree. *)

val leaf : 'a t
(** A tree with only a leaf. *)

val node : 'a t -> 'a -> 'a t -> 'a t
(** [node l x r] constructs a new tree with a new node [x] as the value, with
[l] and [r] being the left and right sub-trees. *)

val deep : int -> int t
(** [deep n] constructs a tree of depth n, in linear time, where every node at
level [l] has value [l]. *)

val to_iter : 'a t -> ('a -> unit) -> unit
(** Iterator function. *)

val to_gen : 'a t -> unit -> 'a option
(** Generator function. [to_gen t] returns a generator function [g] for the
tree that traverses the tree in depth-first fashion, returning [Some x]
for each node when [g] is invoked. [g] returns [None] once the traversal
is complete. *)

val to_gen_cps : 'a t -> unit -> 'a option
(** CPS version of the generator function. *)
end

module Tree : TREE = struct
type 'a t =
| Leaf
| Node of 'a t * 'a * 'a t

let leaf = Leaf

let node l x r = Node (l, x, r)

let rec deep = function
| 0 -> Leaf
| n ->
let t = deep (n - 1) in
Node (t, n, t)

let rec iter f = function
| Leaf -> ()
| Node (l, x, r) ->
iter f l;
f x;
iter f r

(* val to_iter : 'a t -> ('a -> unit) -> unit *)
let to_iter t f = iter f t

(* val to_gen : 'a t -> (unit -> 'a option) *)
let to_gen (type a) (t : a t) =
let module M = struct
type _ Effect.t += Next : a -> unit Effect.t
end in
let open M in
let rec step =
ref (fun () ->
try_with
(fun t ->
iter (fun x -> perform (Next x)) t;
None)
t
{ effc =
(fun (type a) (e : a Effect.t) ->
match e with
| Next v ->
Some
(fun (k : (a, _) continuation) ->
(step := fun () -> continue k ());
Some v)
| _ -> None)
})
in
fun () -> !step ()

let to_gen_cps t =
let next = ref t in
let cont = ref Leaf in
let rec iter t k =
match t with
| Leaf -> run k
| Node (left, x, right) -> iter left (Node (k, x, right))
and run = function
| Leaf -> None
| Node (k, x, right) ->
next := right;
cont := k;
Some x
in
fun () -> iter !next !cont
end

let get_mean_sd l =
let get_mean l =
List.fold_right (fun a v -> a +. v) l 0. /. (float_of_int @@ List.length l)
in
let mean = get_mean l in
let sd = get_mean @@ List.map (fun v -> abs_float (v -. mean) ** 2.) l in
mean, sd

let benchmark f n =
let rec run acc = function
| 0 -> acc
| n ->
let t1 = Sys.time () in
let () = f () in
let d = Sys.time () -. t1 in
run (d :: acc) (n - 1)
in
let r = run [] n in
get_mean_sd r

(* Main follows *)

type _ Effect.t += Dummy : unit t

let () =
try_with
(fun () ->
let n = try int_of_string Sys.argv.(1) with _ -> 21 in
let t = Tree.deep n in
let iter_fun () = Tree.to_iter t (fun _ -> ()) in
let rec consume_all f =
match f () with
| None -> ()
| Some _ -> consume_all f
in

(* The code below should be called in direct style despite the installed
effect handler *)
Js_of_ocaml.Js.Effect.assume_no_perform (fun () ->
let m, sd = benchmark iter_fun 5 in
let () = printf "Iter: mean = %f, sd = %f\n%!" m sd in

let gen_cps_fun () =
let f = Tree.to_gen_cps t in
consume_all f
in

let m, sd = benchmark gen_cps_fun 5 in
printf "Gen_cps: mean = %f, sd = %f\n%!" m sd);

let gen_fun () =
let f = Tree.to_gen t in
consume_all f
in

let m, sd = benchmark gen_fun 5 in
printf "Gen_eff: mean = %f, sd = %f\n%!" m sd)
()
{ effc =
(fun (type a) (e : a Effect.t) ->
match e with
| Dummy -> Some (fun (k : (a, _) continuation) -> continue k ())
| _ -> None)
}
Loading

0 comments on commit f473083

Please sign in to comment.