Skip to content

Commit

Permalink
Merge pull request #75 from 0xffea/wip-pool
Browse files Browse the repository at this point in the history
implement a connection pool
  • Loading branch information
c-cube authored Oct 6, 2023
2 parents a21b734 + a14afcc commit 716b0f1
Show file tree
Hide file tree
Showing 12 changed files with 308 additions and 59 deletions.
149 changes: 96 additions & 53 deletions examples/bench_merge_sort.ml
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
module C = Redis_sync.Client
open Lwt.Infix
module C = Redis_lwt.Client
module P = Redis_lwt.Pool

type t = {
c: C.connection;
pool: P.t;
l: int list;
n: int;
}
Expand All @@ -12,95 +14,136 @@ let mk_list n : int list =
CCList.init n (fun _ -> Random.State.int st 5_000)

(* make a fresh index *)
let mk_id (self:t) (pre:string) : string =
let i = C.incr self.c "bms:cur_id" in
let mk_id (self:t) (pre:string) : string Lwt.t =
P.with_connection self.pool (fun c -> C.incr c "bms:cur_id") >|= fun i ->
Printf.sprintf "bms:id:%s:%d" pre i

let ignore_int (_:int) = ()
let ignore_int (_x:int Lwt.t) = _x >|= fun _ -> ()

let str_of_list (self:t) (id:string) : string =
Printf.sprintf "[%s]" (String.concat ","@@ C.lrange self.c id 0 self.n)
let str_of_list (self:t) (id:string) : (int * string) Lwt.t =
P.with_connection self.pool (fun c -> C.lrange c id 0 self.n) >|= fun l ->
List.length l, Printf.sprintf "[%s]" (String.concat "," l)

let unwrap_opt_ msg = function
| Some x -> x
| None -> failwith msg

let run (self:t) : unit =
let c = self.c in
let id_list = mk_id self "list" in
let run (self:t) : unit Lwt.t =
mk_id self "list" >>= fun id_list ->
(* insert the whole list *)
let _n = C.rpush c id_list (List.rev_map string_of_int self.l) in
P.with_connection self.pool
(fun c -> C.rpush c id_list (List.rev_map string_of_int self.l))
>>= fun _n ->
assert (_n = self.n);
Printf.printf "initial: %s\n%!" (str_of_list self id_list);
str_of_list self id_list >>= fun (len,s_list) ->
Printf.printf "initial (len %d): %s\n%!" len s_list;
(* merge [id1] and [id2] into [into] *)
let merge (id1:string) (id2:string) ~into : unit =
(*Printf.printf "merge %s=%s and %s=%s into %s=%s\n%!"
id1 (str_of_list self id1)
id2 (str_of_list self id2) into (str_of_list self into); *)
let rec loop () : unit =
let len1 = C.llen c id1 in
let len2 = C.llen c id2 in
let merge (id1:string) (id2:string) ~into : unit Lwt.t =
(*Lwt.async (fun () ->
str_of_list self id1 >>= fun (_,s1) ->
str_of_list self id2 >>= fun (_,s2) ->
str_of_list self into >|= fun (_,sinto) ->
Printf.printf "merge %s=%s and %s=%s into %s=%s\n%!"
id1 s1 id2 s2 into sinto);*)
assert (id1 <> id2);
let rec loop () : unit Lwt.t =
let len1 = P.with_connection self.pool (fun c -> C.llen c id1) in
let len2 = P.with_connection self.pool (fun c -> C.llen c id2) in
len1 >>= fun len1 ->
len2 >>= fun len2 ->
(* Printf.printf " len1=%d, len2=%d\n%!" len1 len2; *)
if len1=0 && len2=0 then ()
if len1=0 && len2=0 then Lwt.return ()
else if len1=0 then (
C.rpush c into (C.lrange c id2 0 len2) |> ignore_int;
P.with_connection self.pool
(fun c -> C.lrange c id2 0 len2 >>= C.rpush c into) |> ignore_int
) else if len2=0 then (
C.rpush c into (C.lrange c id1 0 len1) |> ignore_int;
P.with_connection self.pool
(fun c -> C.lrange c id1 0 len1 >>= C.rpush c into) |> ignore_int
) else (
let x = C.lpop c id1 |> unwrap_opt_ "lpop id1" |> int_of_string in
let y = C.lpop c id2 |> unwrap_opt_ "lpop id2" |> int_of_string in
let x =
P.with_connection self.pool
(fun c -> C.lpop c id1 >|= unwrap_opt_ "lpop id1" >|= int_of_string)
and y =
P.with_connection self.pool
(fun c -> C.lpop c id2 >|= unwrap_opt_ "lpop id2" >|= int_of_string)
in
x >>= fun x ->
y >>= fun y ->
(* Printf.printf " x=%d, y=%d\n%!" x y; *)
if x<y then (
C.lpush c id2 [string_of_int y] |> ignore_int;
C.rpush c into [string_of_int x] |> ignore_int;
loop ();
P.with_connection self.pool (fun c ->
C.lpush c id2 [string_of_int y] >>= fun _ ->
C.rpush c into [string_of_int x] |> ignore_int)
>>= loop
) else (
C.lpush c id1 [string_of_int x] |> ignore_int;
C.rpush c into [string_of_int y] |> ignore_int;
loop ();
P.with_connection self.pool (fun c ->
C.lpush c id1 [string_of_int x] >>= fun _ ->
C.rpush c into [string_of_int y] |> ignore_int)
>>= loop
)
)
in
loop ();
(* Printf.printf " -> %s\n%!" (str_of_list self into); *)
(* str_of_list self into >>= fun (_,s) -> Printf.printf " -> [%s]=%s\n%!" into s; *)
loop ()
in
(* now recursively do merge sort *)
let rec sort (id_list:string) : unit =
let len = C.llen c id_list in
let rec sort (id_list:string) : unit Lwt.t =
P.with_connection self.pool (fun c -> C.llen c id_list)
>>= fun len ->
if len >= 2 then (
let mid = len/2 in
let l1 = mk_id self "list_tmp" in
let l2 = mk_id self "list_tmp" in
C.rpush c l1 (C.lrange c id_list 0 (mid-1)) |> ignore_int;
C.rpush c l2 (C.lrange c id_list mid len) |> ignore_int;
assert (C.llen c l1 + C.llen c l2 = len);
C.del c [id_list] |> ignore_int;
sort l1;
sort l2;
merge l1 l2 ~into:id_list;
C.del self.c [l1; l2] |> ignore_int; (* collect tmp clauses *)
)
l1 >>= fun l1 ->
l2 >>= fun l2 ->
let fut1 =
P.with_connection self.pool
(fun c -> C.lrange c id_list 0 (mid-1) >>= C.rpush c l1)
and fut2 =
P.with_connection self.pool
(fun c -> C.lrange c id_list mid len >>= C.rpush c l2)
in
fut1 >>= fun len1 ->
fut2 >>= fun len2 ->
assert (len1 + len2 = len);
P.with_connection self.pool
(fun c -> C.del c [id_list] |> ignore_int) >>= fun () ->
(* sort sublists in parallel *)
let fut1 = sort l1 in
let fut2 = sort l2 in
fut1 >>= fun () ->
fut2 >>= fun () ->
merge l1 l2 ~into:id_list >>= fun () ->
(* cleanup tmp clauses *)
P.with_connection self.pool (fun c -> C.del c [l1; l2]) >|= fun _ -> ()
) else Lwt.return ()
in
sort id_list;
Printf.printf "result: %s\n%!" (str_of_list self id_list);
let l = C.lrange c id_list 0 self.n |> List.map int_of_string in
C.del self.c [id_list] |> ignore_int;
C.del self.c ["bms:cur_id"] |> ignore_int;
sort id_list >>= fun () ->
str_of_list self id_list >>= fun (len,s_res) ->
Printf.printf "result (len %d): %s\n%!" len s_res;
P.with_connection self.pool (fun c ->
(C.lrange c id_list 0 self.n >|= List.map int_of_string) >>= fun l ->
C.del c [id_list] >>= fun _ ->
C.del c ["bms:cur_id"] >|= fun _ -> l)
>>= fun l ->
(* must be sorted *)
assert ( CCList.is_sorted ~cmp:CCInt.compare l);
assert (CCList.is_sorted ~cmp:CCInt.compare l);
(* same length *)
assert (List.length l = List.length self.l);
(* same elements *)
assert (
let module IS = CCSet.Make(CCInt) in
IS.equal (IS.of_list l) (IS.of_list self.l));
()
Lwt.return ()

let run ?(n=100_000) host port : unit =
let c = C.connect {C.host; port} in
let st = {n; c; l=mk_list n} in
let spec = {C.host; port} in
let start = Unix.gettimeofday () in
run st;
Lwt_main.run
(P.with_pool ~size:32 spec
(fun pool ->
let st = {n; pool; l=mk_list n} in
run st));
let stop = Unix.gettimeofday () in
Printf.printf "time: %.3fs\n%!" (stop -. start);
()
Expand Down
1 change: 0 additions & 1 deletion examples/dune
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
(modes native)
(name examples))


(alias
(name runtest)
(locks ../test)
Expand Down
2 changes: 2 additions & 0 deletions redis.opam
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,6 @@ depends: [
"re" {>= "1.7.2"}
"ocaml" { >= "4.03.0" }
"odoc" {with-doc}
"containers" {with-test}
"ounit2" {with-test}
]
116 changes: 116 additions & 0 deletions src/pool.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@


module Make(IO : S.IO)(Client : S.Client with module IO=IO)
: S.POOL with module IO = IO and module Client = Client
= struct
module IO = IO
module Client = Client

open IO

type t = {
mutex: IO.mutex;
condition: IO.condition; (* for threads waiting for a connection *)
pool: Client.connection Queue.t; (* connections available *)
spec: Client.connection_spec;
size: int;
mutable closed: bool; (* once true, no query accepted *)
}

let size self = self.size

(* initialize [i] connections *)
let rec init_conns (self:t) i : unit IO.t =
if i<=0 then IO.return ()
else (
Client.connect self.spec >>= fun c ->
Queue.push c self.pool;
init_conns self (i-1)
)

let create ~size spec : t IO.t =
if size < 1 then invalid_arg "pool.create: size >= 1 required";
let self = {
mutex=IO.mutex_create ();
condition=IO.condition_create();
pool=Queue.create ();
spec;
size;
closed = false;
} in
init_conns self size >>= fun () ->
Format.printf "queue: %d@." (Queue.length self.pool);
IO.return self

let close (self:t) : unit IO.t =
self.closed <- true; (* should always be atomic *)
(* wake up waiters eagerly, to have them die earlier *)
IO.condition_broadcast self.condition;
(* close remaining connections *)
let rec close_conns_in_pool_ () =
if Queue.is_empty self.pool then IO.return ()
else (
let c = Queue.pop self.pool in
Client.disconnect c >>= close_conns_in_pool_
)
in
close_conns_in_pool_ ()

let with_pool ~size spec f : _ IO.t =
create ~size spec >>= fun pool ->
IO.try_bind
(fun () -> f pool)
(fun x -> close pool >|= fun () -> x)
(fun e -> close pool >>= fun () -> IO.fail e)

(* release a connection back into the pool, or close it if the
pool is closed. *)
let release_conn_ (self:t) (c:Client.connection) : unit IO.t =
IO.mutex_with self.mutex
(fun () ->
if self.closed then (
(* close connection *)
Client.disconnect c
) else (
(* release connection, and potentially wake up a waiter to grab it *)
Queue.push c self.pool;
IO.condition_signal self.condition;
IO.return ()
)
)

(* open a new connection and put it into the pool *)
let reopen_conn_ (self:t) : unit IO.t =
Client.connect self.spec >>= release_conn_ self

let rec with_connection (self:t) (f: _ -> 'a IO.t) : 'a IO.t =
if self.closed then IO.fail (Failure "pool closed")
else (
(* try to acquire a connection *)
IO.mutex_with self.mutex
(fun () ->
if Queue.is_empty self.pool then (
IO.condition_wait self.condition self.mutex >|= fun () ->
None
) else (
let c = Queue.pop self.pool in
IO.return (Some c)
))
>>= function
| None -> with_connection self f (* try again *)
| Some c ->
(* run [f c], and be sure to cleanup afterwards *)
IO.try_bind
(fun () -> f c)
(fun x -> release_conn_ self c >|= fun () -> x)
(fun e ->
(* close [c] and reopen a new one instead;
could have been interrupted during a transfer! *)
let fut1 = reopen_conn_ self in
let fut2 = Client.disconnect c in
fut1 >>= fun () ->
fut2 >>= fun () ->
IO.fail e)
)

end
3 changes: 3 additions & 0 deletions src/pool.mli
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@

module Make(IO : S.IO)(Client : S.Client with module IO=IO)
: S.POOL with module IO = IO and module Client = Client
38 changes: 38 additions & 0 deletions src/s.ml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,15 @@ module type IO = sig
val stream_from : (stream_count -> 'b option t) -> 'b stream
val stream_next: 'a stream -> 'a t

type mutex
val mutex_create : unit -> mutex
val mutex_with : mutex -> (unit -> 'a t) -> 'a t

type condition
val condition_create : unit -> condition
val condition_wait : condition -> mutex -> unit t
val condition_signal : condition -> unit
val condition_broadcast: condition -> unit
end

module type Client = sig
Expand Down Expand Up @@ -715,6 +724,11 @@ module type Client = sig
val unwatch : connection -> unit IO.t

val queue : (unit -> 'a IO.t) -> unit IO.t
(** Within a transaction (see {!multi}, {!exec}, and {!discard}),
commands will not return their normal value. It is necessary to
wrap each of them in their individual [Client.queue (fun () -> the_command)]
to avoid getting an exception [Unexpected (Status "QUEUED")].
*)

(** {2 Scripting commands} *)

Expand Down Expand Up @@ -821,3 +835,27 @@ module type Mutex = sig
val release : Client.connection -> string -> string -> unit IO.t
val with_mutex : Client.connection -> ?atime:float -> ?ltime:int -> string -> (unit -> 'a IO.t) -> 'a IO.t
end

(** {2 Connection pool} *)
module type POOL = sig
module IO : IO
module Client : Client

type t

val size : t -> int

val create : size:int -> Client.connection_spec -> t IO.t
(** Create a pool of [size] connections, using the given spec. *)

val close : t -> unit IO.t
(** Close all connections *)

val with_pool : size:int -> Client.connection_spec -> (t -> 'a IO.t) -> 'a IO.t
(** Create a pool of [size] connections, using the given spec,
pass it to the callback, and then destroy it. *)

val with_connection : t -> (Client.connection -> 'a IO.t) -> 'a IO.t
(** Temporarily require a connection to perform some operation.
The connection must not escape the scope of the callback *)
end
Loading

0 comments on commit 716b0f1

Please sign in to comment.