Skip to content

Commit

Permalink
Change Hashtbl config to be part of the transactional data
Browse files: browse the repository at this point in the history
  • Loading branch information
polytypic committed May 26, 2023
1 parent 7fefe99 commit 368ccee
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 89 deletions.
200 changes: 115 additions & 85 deletions src/kcas_data/hashtbl.ml
Original file line number Diff line number Diff line change
Expand Up @@ -72,16 +72,17 @@ type ('k, 'v) pending =
new_buckets : ('k, 'v) Assoc.t Loc.t array Loc.t;
}

(* Hash table representation. Diff note: this hunk shows the removed record
   [t], whose [pending] and [buckets] fields were each held in their own
   [Loc.t], interleaved with the added plain record [r]; after the change the
   table type [t] is a single [Loc.t] holding an [r]. *)
type ('k, 'v) t = {
pending : ('k, 'v) pending Loc.t;
type ('k, 'v) r = {
(* Operation awaiting completion; [Nothing] when there is none. *)
pending : ('k, 'v) pending;
(* Binding count, kept in an [Accumulator.t]. *)
length : Accumulator.t;
buckets : ('k, 'v) Assoc.t Loc.t array Loc.t;
(* One [Loc.t] per bucket, each holding an association list. *)
buckets : ('k, 'v) Assoc.t Loc.t array;
(* Key hashing and key equality functions. *)
hash : 'k -> int;
equal : 'k -> 'k -> bool;
(* Bounds on the bucket array size -- presumably powers of two; confirm
   against [create]. *)
min_buckets : int;
max_buckets : int;
}

(* The table itself: one location holding the whole record. *)
type ('k, 'v) t = ('k, 'v) r Loc.t
type 'k hashed_type = (module Stdlib.Hashtbl.HashedType with type t = 'k)

let lo_buckets = 1 lsl 5
Expand Down Expand Up @@ -112,6 +113,7 @@ let create ?hashed_type ?min_buckets ?max_buckets ?n_way () =
| None -> min_buckets_default
| Some c -> Int.max lo_buckets c |> Int.min hi_buckets |> Bits.ceil_pow_2
in
let t = Loc.make (Obj.magic ()) in
let max_buckets =
match max_buckets with
| None -> Int.max min_buckets max_buckets_default
Expand All @@ -120,16 +122,19 @@ let create ?hashed_type ?min_buckets ?max_buckets ?n_way () =
match hashed_type with
| None -> (Stdlib.Hashtbl.seeded_hash (Random.bits ()), ( = ))
| Some hashed_type -> HashedType.unpack hashed_type
and pending = Loc.make Nothing
and buckets = Loc.make [||]
and pending = Nothing
and buckets = Loc.make_array min_buckets []
and length = Accumulator.make ?n_way 0 in
Loc.set buckets @@ Loc.make_array min_buckets [];
{ pending; length; buckets; hash; equal; min_buckets; max_buckets }
Loc.set t { pending; length; buckets; hash; equal; min_buckets; max_buckets };
t

(* Configuration accessors. Diff note: the first three definitions and the
   multi-line [hashed_type_of] are the added forms, which first read the
   record out of the table with [Loc.get]; the four one-line definitions in
   between are the removed forms, which accessed the fields of [t] directly. *)
let n_way_of t = Accumulator.n_way_of (Loc.get t).length
let min_buckets_of t = (Loc.get t).min_buckets
let max_buckets_of t = (Loc.get t).max_buckets

let n_way_of t = Accumulator.n_way_of t.length
let min_buckets_of t = t.min_buckets
let max_buckets_of t = t.max_buckets
let hashed_type_of t = HashedType.pack t.hash t.equal
(* Packs the [hash] and [equal] functions of the table with
   [HashedType.pack]. *)
let hashed_type_of t =
let r = Loc.get t in
HashedType.pack r.hash r.equal

(* [bucket_of hash key buckets] selects the element of [buckets] at index
   [hash key land (Array.length buckets - 1)].  The access is unchecked, so
   [buckets] must be non-empty; masking only yields a uniform index when the
   length is a power of two. *)
let bucket_of hash key buckets =
  let mask = Array.length buckets - 1 in
  let index = hash key land mask in
  Array.unsafe_get buckets index
Expand All @@ -138,16 +143,16 @@ exception Done

module Xt = struct
(* [find_opt ~xt t k] reads, within the transaction [xt], the bucket selected
   by [bucket_of] for key [k] and searches it with [Assoc.find_opt].
   Diff note: the first two body lines are the removed form (the bucket array
   was read from the location [t.buckets]); the last two are the added form
   (the whole record is read from [t] and its fields are used). *)
let find_opt ~xt t k =
Xt.get ~xt t.buckets |> bucket_of t.hash k |> Xt.get ~xt
|> Assoc.find_opt t.equal k
let r = Xt.get ~xt t in
r.buckets |> bucket_of r.hash k |> Xt.get ~xt |> Assoc.find_opt r.equal k

(* [find_all ~xt t k] reads, within the transaction [xt], the bucket selected
   by [bucket_of] for key [k] and passes it to [Assoc.find_all].
   Diff note: the first two body lines are the removed form (bucket array
   read from [t.buckets]); the last two are the added form (record read from
   [t] with [Xt.get]). *)
let find_all ~xt t k =
Xt.get ~xt t.buckets |> bucket_of t.hash k |> Xt.get ~xt
|> Assoc.find_all t.equal k
let r = Xt.get ~xt t in
r.buckets |> bucket_of r.hash k |> Xt.get ~xt |> Assoc.find_all r.equal k

(* [mem ~xt t k] reads, within the transaction [xt], the bucket selected by
   [bucket_of] for key [k] and tests it with [Assoc.mem].
   Diff note: the first two body lines are the removed form (bucket array
   read from [t.buckets]); the last two are the added form (record read from
   [t] with [Xt.get]). *)
let mem ~xt t k =
Xt.get ~xt t.buckets |> bucket_of t.hash k |> Xt.get ~xt
|> Assoc.mem t.equal k
let r = Xt.get ~xt t in
r.buckets |> bucket_of r.hash k |> Xt.get ~xt |> Assoc.mem r.equal k

let get_or_alloc array_loc alloc =
let tx ~xt =
Expand All @@ -167,15 +172,18 @@ module Xt = struct
(* TODO: Implement pending operations such that multiple domains may be
working to complete them in parallel by extending the [state] to an array
of multiple partition [states]. *)
let must_be_done_in_this_tx = Xt.is_in_log ~xt t.pending in
match Xt.exchange ~xt t.pending Nothing with
| Nothing -> ()
let must_be_done_in_this_tx = Xt.is_in_log ~xt t in
let r = Xt.get ~xt t in
match r.pending with
| Nothing -> r
| Rehash { state; new_capacity; new_buckets } -> (
let new_buckets =
get_or_alloc new_buckets @@ fun () -> Loc.make_array new_capacity []
in
let old_buckets = Xt.exchange ~xt t.buckets new_buckets in
let hash = t.hash and mask = new_capacity - 1 in
let old_buckets = r.buckets in
let r = { r with pending = Nothing; buckets = new_buckets } in
Xt.set ~xt t r;
let hash = r.hash and mask = new_capacity - 1 in
let rehash_a_few_buckets ~xt =
(* We process buckets in descending order as that is slightly faster
with the transaction log. It also makes sure that we know when the
Expand Down Expand Up @@ -211,11 +219,14 @@ module Xt = struct
at a time. This gives expected linear time, O(n). *)
while true do
Xt.commit { tx = rehash_a_few_buckets }
done
with Done -> ())
done;
r
with Done -> r)
| Snapshot { state; snapshot } -> (
assert (not must_be_done_in_this_tx);
let buckets = Xt.get ~xt t.buckets in
let buckets = r.buckets in
let r = { r with pending = Nothing } in
Xt.set ~xt t r;
(* Check state to ensure that buckets have not been updated. *)
if Loc.fenceless_get state < 0 then Retry.invalid ();
let snapshot =
Expand All @@ -233,11 +244,12 @@ module Xt = struct
try
while true do
Xt.commit { tx = snapshot_a_few_buckets }
done
with Done -> ())
done;
r
with Done -> r)
| Filter_map { state; fn; raised; new_buckets } -> (
assert (not must_be_done_in_this_tx);
let old_buckets = Xt.get ~xt t.buckets in
let old_buckets = r.buckets in
(* Check state to ensure that buckets have not been updated. *)
if Loc.fenceless_get state < 0 then Retry.invalid ();
let new_capacity = Array.length old_buckets in
Expand All @@ -260,108 +272,122 @@ module Xt = struct
while true do
total_delta :=
!total_delta + Xt.commit { tx = filter_map_a_few_buckets }
done
done;
r
with
| Done ->
Accumulator.Xt.add ~xt t.length !total_delta;
Xt.set ~xt t.buckets new_buckets
| exn -> Loc.compare_and_set raised Done exn |> ignore)
Accumulator.Xt.add ~xt r.length !total_delta;
let r = { r with pending = Nothing; buckets = new_buckets } in
Xt.set ~xt t r;
r
| exn ->
Loc.compare_and_set raised Done exn |> ignore;
let r = { r with pending = Nothing } in
Xt.set ~xt t r;
r)

(* [make_rehash old_capacity new_capacity] builds a [Rehash] pending
   operation.  Its [state] location starts at [old_capacity] and its
   [new_buckets] location starts as the empty array; the array of
   [new_capacity] buckets is not allocated here. *)
let make_rehash old_capacity new_capacity =
let state = Loc.make old_capacity and new_buckets = Loc.make [||] in
Rehash { state; new_capacity; new_buckets }
[@@inline]

(* [reset ~xt t] empties the table within the transaction [xt]: it first
   completes any pending operation, sets the length accumulator to 0,
   replaces the bucket array with the empty array and installs a pending
   rehash from capacity 0 to [min_buckets].
   Diff note: the first four body lines are the removed form (three separate
   locations were written); the remaining lines are the added form (the
   record returned by [perform_pending] is updated and written back to [t]
   with a single [Xt.set]). *)
let reset ~xt t =
perform_pending ~xt t;
Xt.set ~xt t.buckets [||];
Accumulator.Xt.set ~xt t.length 0;
Xt.set ~xt t.pending @@ make_rehash 0 t.min_buckets
let r = perform_pending ~xt t in
Accumulator.Xt.set ~xt r.length 0;
Xt.set ~xt t
{ r with pending = make_rehash 0 r.min_buckets; buckets = [||] }

(* [clear] simply calls [reset]. *)
let clear ~xt t = reset ~xt t

(* [remove ~xt t k] removes a binding of [k] from its bucket within the
   transaction [xt], after completing any pending operation.  When a binding
   was actually removed the length accumulator is decremented and the table
   may be scheduled to shrink.
   Diff note: lines using [t.buckets], [t.hash], [t.equal], [t.length],
   [t.min_buckets] and [t.pending] are the removed form; the adjacent lines
   using the record [r] returned by [perform_pending] are the added form. *)
let remove ~xt t k =
perform_pending ~xt t;
let buckets = Xt.get ~xt t.buckets in
let r = perform_pending ~xt t in
let buckets = r.buckets in
let mask = Array.length buckets - 1 in
let bucket = Array.unsafe_get buckets (t.hash k land mask) in
let bucket = Array.unsafe_get buckets (r.hash k land mask) in
(* [change] is filled in by [Assoc.remove] to report what happened. *)
let change = ref `None in
Xt.unsafe_modify ~xt bucket (fun kvs ->
let kvs' = Assoc.remove t.equal change k kvs in
let kvs' = Assoc.remove r.equal change k kvs in
(* Hand back the original list when nothing changed -- presumably so the
   bucket value stays physically equal; confirm against [Xt.unsafe_modify]. *)
if !change != `None then kvs' else kvs);
if !change == `Removed then (
Accumulator.Xt.decr ~xt t.length;
if t.min_buckets <= mask && Random.bits () land mask = 0 then
Accumulator.Xt.decr ~xt r.length;
(* Sampled check: the length is read only when the random bits masked by
   [mask] are all zero, and only while the capacity exceeds [min_buckets]. *)
if r.min_buckets <= mask && Random.bits () land mask = 0 then
let capacity = mask + 1 in
let length = Accumulator.Xt.get ~xt t.length in
let length = Accumulator.Xt.get ~xt r.length in
(* Below one quarter full: schedule a rehash to half the capacity. *)
if length * 4 < capacity then
Xt.set ~xt t.pending @@ make_rehash capacity (capacity asr 1))
Xt.set ~xt t
{ r with pending = make_rehash capacity (capacity asr 1) })

(* [add ~xt t k v] conses the binding [(k, v)] onto the front of the bucket
   selected for [k] within the transaction [xt], after completing any pending
   operation, and increments the length accumulator.  Existing bindings of
   [k] are left in the bucket.
   Diff note: lines using fields of [t] directly are the removed form; the
   adjacent lines using the record [r] returned by [perform_pending] are the
   added form. *)
let add ~xt t k v =
perform_pending ~xt t;
let buckets = Xt.get ~xt t.buckets in
let r = perform_pending ~xt t in
let buckets = r.buckets in
let mask = Array.length buckets - 1 in
let bucket = Array.unsafe_get buckets (t.hash k land mask) in
let bucket = Array.unsafe_get buckets (r.hash k land mask) in
Xt.unsafe_modify ~xt bucket (List.cons (k, v));
Accumulator.Xt.incr ~xt t.length;
if mask + 1 < t.max_buckets && Random.bits () land mask = 0 then
Accumulator.Xt.incr ~xt r.length;
(* Sampled check: the length is read only when the random bits masked by
   [mask] are all zero, and only while the capacity is below [max_buckets]. *)
if mask + 1 < r.max_buckets && Random.bits () land mask = 0 then
let capacity = mask + 1 in
let length = Accumulator.Xt.get ~xt t.length in
let length = Accumulator.Xt.get ~xt r.length in
(* More bindings than buckets: schedule a rehash to twice the capacity. *)
if capacity < length then
Xt.set ~xt t.pending @@ make_rehash capacity (capacity * 2)
Xt.set ~xt t { r with pending = make_rehash capacity (capacity * 2) }

(* [replace ~xt t k v] updates the bucket selected for [k] through
   [Assoc.replace] within the transaction [xt], after completing any pending
   operation.  Only when [Assoc.replace] reports that a binding was added is
   the length accumulator incremented and growth considered.
   Diff note: lines using fields of [t] directly are the removed form; the
   adjacent lines using the record [r] returned by [perform_pending] are the
   added form. *)
let replace ~xt t k v =
perform_pending ~xt t;
let buckets = Xt.get ~xt t.buckets in
let r = perform_pending ~xt t in
let buckets = r.buckets in
let mask = Array.length buckets - 1 in
let bucket = Array.unsafe_get buckets (t.hash k land mask) in
let bucket = Array.unsafe_get buckets (r.hash k land mask) in
(* [change] is filled in by [Assoc.replace] to report what happened. *)
let change = ref `None in
Xt.unsafe_modify ~xt bucket (fun kvs ->
let kvs' = Assoc.replace t.equal change k v kvs in
let kvs' = Assoc.replace r.equal change k v kvs in
(* Hand back the original list when nothing changed. *)
if !change != `None then kvs' else kvs);
if !change == `Added then (
Accumulator.Xt.incr ~xt t.length;
if mask + 1 < t.max_buckets && Random.bits () land mask = 0 then
Accumulator.Xt.incr ~xt r.length;
(* Sampled check, same condition as in [add]. *)
if mask + 1 < r.max_buckets && Random.bits () land mask = 0 then
let capacity = mask + 1 in
let length = Accumulator.Xt.get ~xt t.length in
let length = Accumulator.Xt.get ~xt r.length in
(* More bindings than buckets: schedule a rehash to twice the capacity. *)
if capacity < length then
Xt.set ~xt t.pending @@ make_rehash capacity (capacity * 2))
Xt.set ~xt t { r with pending = make_rehash capacity (capacity * 2) })

(* [length ~xt t] reads the length accumulator within the transaction [xt].
   Diff note: the first line is the removed form; the second is the added
   form, which first reads the record out of [t] with [Xt.get]. *)
let length ~xt t = Accumulator.Xt.get ~xt t.length
let length ~xt t = Accumulator.Xt.get ~xt (Xt.get ~xt t).length
(* Added by this change: [swap] is the [swap] of the enclosing scope's [Xt]
   module, usable here because a table is now itself a [Loc.t]. *)
let swap = Xt.swap
end

(* [find_opt t k] looks [k] up outside of any transaction: it selects the
   bucket with [bucket_of], reads it with [Loc.fenceless_get] and searches it
   with [Assoc.find_opt].
   Diff note: the first body line is the removed form (bucket array read from
   [t.buckets]); the next two are the added form, where [t] is shadowed by
   the record read with [Loc.get]; the final line is shared by both. *)
let find_opt t k =
Loc.get t.buckets |> bucket_of t.hash k |> Loc.fenceless_get
let t = Loc.get t in
t.buckets |> bucket_of t.hash k |> Loc.fenceless_get
|> Assoc.find_opt t.equal k

(* [find_all t k] selects the bucket for [k] with [bucket_of], reads it with
   [Loc.fenceless_get] and passes it to [Assoc.find_all], outside of any
   transaction.
   Diff note: the first body line is the removed form; the next two are the
   added form, where [t] is shadowed by the record read with [Loc.get]; the
   final line is shared by both. *)
let find_all t k =
Loc.get t.buckets |> bucket_of t.hash k |> Loc.fenceless_get
let t = Loc.get t in
t.buckets |> bucket_of t.hash k |> Loc.fenceless_get
|> Assoc.find_all t.equal k

(* [find t k] is the value that [find_opt t k] yields for [k].
   @raise Not_found when [find_opt t k] is [None]. *)
let find t k =
  match find_opt t k with
  | Some v -> v
  | None -> raise Not_found

(* [mem t k] selects the bucket for [k] with [bucket_of], reads it with
   [Loc.fenceless_get] and tests it with [Assoc.mem], outside of any
   transaction.
   Diff note: the first two body lines are the removed form (bucket array
   read from [t.buckets]); the last two are the added form, where [t] is
   shadowed by the record read with [Loc.get]. *)
let mem t k =
Loc.get t.buckets |> bucket_of t.hash k |> Loc.fenceless_get
|> Assoc.mem t.equal k
let t = Loc.get t in
t.buckets |> bucket_of t.hash k |> Loc.fenceless_get |> Assoc.mem t.equal k

(* Non-transactional entry points: each one commits the corresponding [Xt]
   operation as its own transaction with [Kcas.Xt.commit]. *)
let clear t = Kcas.Xt.commit { tx = Xt.clear t }
let reset t = Kcas.Xt.commit { tx = Xt.reset t }
let remove t k = Kcas.Xt.commit { tx = Xt.remove t k }
let add t k v = Kcas.Xt.commit { tx = Xt.add t k v }
let replace t k v = Kcas.Xt.commit { tx = Xt.replace t k v }
(* [length] reads the accumulator directly rather than through a
   transaction.  Diff note: the first line is the removed form; the second is
   the added form, which first reads the record with [Loc.get]. *)
let length t = Accumulator.get t.length
let length t = Accumulator.get (Loc.get t).length
(* Added by this change: commits [Xt.swap] on the two tables. *)
let swap t1 t2 = Kcas.Xt.commit { tx = Xt.swap t1 t2 }

(* [snapshot ?length ?record t] installs a [Snapshot] pending operation on
   [t] in one transaction, runs [Xt.perform_pending] in a second transaction
   to carry it out, and returns the contents of the [snapshot] location.
   When given, the [length] reference receives the accumulator value and the
   [record] reference receives the table record, both as seen by the first
   transaction.
   Diff note: the first header line is the removed signature and the second
   the added one with [?record]; in the body, lines using [t.length],
   [t.buckets] and [t.pending] are the removed form and the lines using the
   record [r] are the added form. *)
let snapshot ?length t =
let snapshot ?length ?record t =
let state = Loc.make 0 and snapshot = Loc.make [||] in
let pending = Snapshot { state; snapshot } in
let tx ~xt =
Xt.perform_pending ~xt t;
let r = Xt.perform_pending ~xt t in
length
|> Option.iter (fun length -> length := Accumulator.Xt.get ~xt t.length);
Loc.set state (Array.length (Kcas.Xt.get ~xt t.buckets));
Kcas.Xt.set ~xt t.pending pending
|> Option.iter (fun length -> length := Accumulator.Xt.get ~xt r.length);
record |> Option.iter (fun record -> record := r);
(* [state] is initialised to the current number of buckets. *)
Loc.set state (Array.length r.buckets);
Kcas.Xt.set ~xt t { r with pending }
in
Kcas.Xt.commit { tx };
(* [perform_pending] now returns the record, so its result is discarded. *)
Kcas.Xt.commit { tx = Xt.perform_pending t };
Kcas.Xt.commit { tx = Xt.perform_pending t } |> ignore;
Loc.fenceless_get snapshot

let to_seq t =
Expand All @@ -384,29 +410,33 @@ let of_seq ?hashed_type ?min_buckets ?max_buckets ?n_way xs =
t

let rebuild ?hashed_type ?min_buckets ?max_buckets ?n_way t =
let record = ref (Obj.magic ()) and length = ref 0 in
let snapshot = snapshot ~length ~record t in
let r = !record in
let min_buckets =
match min_buckets with
| None -> min_buckets_of t
| None -> r.min_buckets
| Some c -> Int.max lo_buckets c |> Int.min hi_buckets |> Bits.ceil_pow_2
in
let max_buckets =
match max_buckets with
| None -> Int.max min_buckets (max_buckets_of t)
| None -> Int.max min_buckets r.max_buckets
| Some c -> Int.max min_buckets c |> Int.min hi_buckets |> Bits.ceil_pow_2
and n_way = match n_way with None -> n_way_of t | Some n -> n
and length = ref 0 in
let snapshot = snapshot ~length t in
and n_way =
match n_way with None -> Accumulator.n_way_of r.length | Some n -> n
in
let is_same_hashed_type =
match hashed_type with
| None -> true
| Some hashed_type -> HashedType.is_same_as t.hash t.equal hashed_type
| Some hashed_type -> HashedType.is_same_as r.hash r.equal hashed_type
and length = !length in
if is_same_hashed_type && min_buckets <= length && length <= max_buckets then (
let pending = Loc.make Nothing
and buckets = Loc.make [||]
let t = Loc.make (Obj.magic ()) in
let pending = Nothing
and buckets = Array.map Loc.make snapshot
and length = Accumulator.make ~n_way length in
Loc.set buckets @@ Array.map Loc.make snapshot;
{ t with pending; length; buckets; min_buckets; max_buckets })
Loc.set t { r with pending; length; buckets; min_buckets; max_buckets };
t)
else
let t = create ?hashed_type ~min_buckets ~max_buckets ~n_way () in
snapshot
Expand All @@ -427,12 +457,12 @@ let filter_map_inplace fn t =
and new_buckets = Loc.make [||] in
let pending = Filter_map { state; fn; raised; new_buckets } in
let tx ~xt =
Xt.perform_pending ~xt t;
Loc.set state (Array.length (Kcas.Xt.get ~xt t.buckets));
Kcas.Xt.set ~xt t.pending pending
let r = Xt.perform_pending ~xt t in
Loc.set state (Array.length r.buckets);
Kcas.Xt.set ~xt t { r with pending }
in
Kcas.Xt.commit { tx };
Kcas.Xt.commit { tx = Xt.perform_pending t };
Kcas.Xt.commit { tx = Xt.perform_pending t } |> ignore;
match Loc.fenceless_get raised with Done -> () | exn -> raise exn

let stats t =
Expand Down
3 changes: 3 additions & 0 deletions src/kcas_data/hashtbl_intf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ module type Ops = sig
val clear : ('x, ('k, 'v) t -> unit) fn
(** [clear] is a synonym for {!reset}. *)

val swap : ('x, ('k, 'v) t -> ('k, 'v) t -> unit) fn
(** [swap t1 t2] exchanges the contents of the hash tables [t1] and [t2]. *)

val remove : ('x, ('k, 'v) t -> 'k -> unit) fn
(** [remove t k] removes the most recent existing binding of key [k], if any,
from the hash table [t] thereby revealing the earlier binding of [k], if
Expand Down
10 changes: 6 additions & 4 deletions test/kcas_data/hashtbl_test.ml
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,12 @@ let () =
assert (
Hashtbl.to_seq t |> List.of_seq = [ ("key", 3); ("key", 2); ("key", 1) ]);
let u = Hashtbl.to_seq t |> Hashtbl.of_seq in
assert (Hashtbl.find u "key" = 1);
assert (Hashtbl.find t "key" = 3);
Hashtbl.filter_map_inplace (fun _ v -> if v = 1 then None else Some (-v)) t;
assert (Hashtbl.find_all t "key" = [ -3; -2 ]);
Hashtbl.swap t u;
assert (Hashtbl.find t "key" = 1);
assert (Hashtbl.find u "key" = 3);
Hashtbl.filter_map_inplace (fun _ v -> if v = 1 then None else Some (-v)) u;
assert (Hashtbl.find_all u "key" = [ -3; -2 ]);
Hashtbl.swap u t;
assert (Hashtbl.length t = 2);
(match
Hashtbl.filter_map_inplace
Expand Down

0 comments on commit 368ccee

Please sign in to comment.