diff --git a/src/cry.ml b/src/cry.ml index 749b1fa..fe19cac 100644 --- a/src/cry.ml +++ b/src/cry.ml @@ -20,27 +20,32 @@ (** OCaml low level implementation of the shout source protocol. *) -let poll r w timeout = - let timeout = - match timeout with - | x when x < 0. -> Poll.Timeout.never - | 0. -> Poll.Timeout.immediate - | x -> - let frac, int = modf x in - let int = Int64.mul (Int64.of_float int) 1_000_000_000L in - let frac = Int64.of_float (frac *. 1_000_000_000.) in - let timeout = Int64.add int frac in - Poll.Timeout.after timeout - in +let poll = let poll = Poll.create () in - List.iter (fun fd -> Poll.set poll fd Poll.Event.read) r; - List.iter (fun fd -> Poll.set poll fd Poll.Event.write) w; - ignore (Poll.wait poll timeout); - let r = ref [] in - let w = ref [] in - Poll.iter_ready poll ~f:(fun fd -> function - | { Poll.Event.readable = true; _ } -> r := fd :: !r | _ -> w := fd :: !w); - (!r, !w) + fun r w timeout -> + let timeout = + match timeout with + | x when x < 0. -> Poll.Timeout.never + | 0. -> Poll.Timeout.immediate + | x -> + let frac, int = modf x in + let int = Int64.mul (Int64.of_float int) 1_000_000_000L in + let frac = Int64.of_float (frac *. 1_000_000_000.) in + let timeout = Int64.add int frac in + Poll.Timeout.after timeout + in + List.iter (fun fd -> Poll.set poll fd Poll.Event.read) r; + List.iter (fun fd -> Poll.set poll fd Poll.Event.write) w; + Fun.protect + (fun () -> + ignore (Poll.wait poll timeout); + let r = ref [] in + let w = ref [] in + Poll.iter_ready poll ~f:(fun fd -> function + | { Poll.Event.readable = true; _ } -> r := fd :: !r + | _ -> w := fd :: !w); + (!r, !w)) + ~finally:(fun () -> Poll.clear poll) type error = | Create of exn