Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cache prover keys #14396

Merged
merged 9 commits into from
Oct 31, 2023
105 changes: 66 additions & 39 deletions src/lib/pickles/cache.ml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ module Step = struct
[@@warning "-4"]
end

type storable =
(Key.Proving.t, Backend.Tick.Keypair.t) Key_cache.Sync.Disk_storable.t

type vk_storable =
( Key.Verification.t
, Kimchi_bindings.Protocol.VerifierIndex.Fp.t )
Key_cache.Sync.Disk_storable.t

let storable =
Key_cache.Sync.Disk_storable.simple Key.Proving.to_string
(fun (_, header, _, cs) ~path ->
Expand Down Expand Up @@ -83,9 +91,8 @@ module Step = struct
(Kimchi_bindings.Protocol.VerifierIndex.Fp.write (Some true) x)
header path ) )

let read_or_generate ~prev_challenges cache k_p k_v typ return_typ main =
let s_p = storable in
let s_v = vk_storable in
let read_or_generate ~prev_challenges cache ?(s_p = storable) k_p
?(s_v = vk_storable) k_v typ return_typ main =
let open Impls.Step in
let pk =
lazy
Expand Down Expand Up @@ -155,6 +162,12 @@ module Wrap = struct
end
end

type storable =
(Key.Proving.t, Backend.Tock.Keypair.t) Key_cache.Sync.Disk_storable.t

type vk_storable =
(Key.Verification.t, Verification_key.t) Key_cache.Sync.Disk_storable.t

let storable =
Key_cache.Sync.Disk_storable.simple Key.Proving.to_string
(fun (_, header, cs) ~path ->
Expand Down Expand Up @@ -182,10 +195,42 @@ module Wrap = struct
(Kimchi_bindings.Protocol.Index.Fq.write (Some true) t.index)
header path ) )

let read_or_generate ~prev_challenges cache k_p k_v typ return_typ main =
let vk_storable =
Key_cache.Sync.Disk_storable.simple Key.Verification.to_string
(fun (_, header, _cs) ~path ->
Or_error.try_with_join (fun () ->
let open Or_error.Let_syntax in
let%map header_read, index =
Snark_keys_header.read_with_header
~read_data:(fun ~offset:_ path ->
Binable.of_string
(module Verification_key.Stable.Latest)
(In_channel.read_all path) )
path
in
[%test_eq: int] header.header_version header_read.header_version ;
[%test_eq: Snark_keys_header.Kind.t] header.kind header_read.kind ;
[%test_eq: Snark_keys_header.Constraint_constants.t]
header.constraint_constants header_read.constraint_constants ;
[%test_eq: string] header.constraint_system_hash
header_read.constraint_system_hash ;
index ) )
(fun (_, header, _) t path ->
Or_error.try_with (fun () ->
Snark_keys_header.write_with_header
~expected_max_size_log2:33 (* 8 GB should be enough *)
~append_data:(fun path ->
Out_channel.with_file ~append:true path ~f:(fun file ->
Out_channel.output_string file
(Binable.to_string
(module Verification_key.Stable.Latest)
t ) ) )
header path ) )

let read_or_generate ~prev_challenges cache ?(s_p = storable) k_p
?(s_v = vk_storable) k_v typ return_typ main =
let module Vk = Verification_key in
let open Impls.Wrap in
let s_p = storable in
let pk =
lazy
(let k = Lazy.force k_p in
Expand All @@ -209,40 +254,6 @@ module Wrap = struct
let vk =
lazy
(let k_v = Lazy.force k_v in
let s_v =
Key_cache.Sync.Disk_storable.simple Key.Verification.to_string
(fun (_, header, _cs) ~path ->
Or_error.try_with_join (fun () ->
let open Or_error.Let_syntax in
let%map header_read, index =
Snark_keys_header.read_with_header
~read_data:(fun ~offset:_ path ->
Binable.of_string
(module Vk.Stable.Latest)
(In_channel.read_all path) )
path
in
[%test_eq: int] header.header_version
header_read.header_version ;
[%test_eq: Snark_keys_header.Kind.t] header.kind
header_read.kind ;
[%test_eq: Snark_keys_header.Constraint_constants.t]
header.constraint_constants
header_read.constraint_constants ;
[%test_eq: string] header.constraint_system_hash
header_read.constraint_system_hash ;
index ) )
(fun (_, header, _) t path ->
Or_error.try_with (fun () ->
Snark_keys_header.write_with_header
~expected_max_size_log2:33 (* 8 GB should be enough *)
~append_data:(fun path ->
Out_channel.with_file ~append:true path ~f:(fun file ->
Out_channel.output_string file
(Binable.to_string (module Vk.Stable.Latest) t) )
)
header path ) )
in
match Key_cache.Sync.read cache s_v k_v with
| Ok (vk, d) ->
(vk, d)
Expand All @@ -265,3 +276,19 @@ module Wrap = struct
in
(pk, vk)
end

module Storables = struct
mitschabaude marked this conversation as resolved.
Show resolved Hide resolved
type t =
{ step_storable : Step.storable
; step_vk_storable : Step.vk_storable
; wrap_storable : Wrap.storable
; wrap_vk_storable : Wrap.vk_storable
}

let default =
{ step_storable = Step.storable
; step_vk_storable = Step.vk_storable
; wrap_storable = Wrap.storable
; wrap_vk_storable = Wrap.vk_storable
}
end
47 changes: 45 additions & 2 deletions src/lib/pickles/cache.mli
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ module Step : sig
* Snark_keys_header.t
* int
* Backend.Tick.R1CS_constraint_system.t

val to_string : t -> string
end

module Verification : sig
Expand All @@ -17,13 +19,29 @@ module Step : sig
* Snark_keys_header.t
* int
* Core_kernel.Md5.t

val to_string : t -> string
end
end

type storable =
(Key.Proving.t, Backend.Tick.Keypair.t) Key_cache.Sync.Disk_storable.t

type vk_storable =
( Key.Verification.t
, Kimchi_bindings.Protocol.VerifierIndex.Fp.t )
Key_cache.Sync.Disk_storable.t

val storable : storable

val vk_storable : vk_storable

val read_or_generate :
prev_challenges:int
-> Key_cache.Spec.t list
-> ?s_p:storable
-> Key.Proving.t lazy_t
-> ?s_v:vk_storable
-> Key.Verification.t lazy_t
-> ('a, 'b) Impls.Step.Typ.t
-> ('c, 'd) Impls.Step.Typ.t
Expand All @@ -43,6 +61,8 @@ module Wrap : sig
Core_kernel.Type_equal.Id.Uid.t
* Snark_keys_header.t
* Backend.Tock.R1CS_constraint_system.t

val to_string : t -> string
end

module Verification : sig
Expand All @@ -59,11 +79,23 @@ module Wrap : sig
end
end

type storable =
(Key.Proving.t, Backend.Tock.Keypair.t) Key_cache.Sync.Disk_storable.t

type vk_storable =
(Key.Verification.t, Verification_key.t) Key_cache.Sync.Disk_storable.t

val storable : storable

val vk_storable : vk_storable

val read_or_generate :
prev_challenges:Core_kernel.Int.t
-> Key_cache.Spec.t list
-> Key.Proving.t Core_kernel.Lazy.t
-> Key.Verification.t Core_kernel.Lazy.t
-> ?s_p:storable
-> Key.Proving.t lazy_t
-> ?s_v:vk_storable
-> Key.Verification.t lazy_t
-> ('a, 'b) Impls.Wrap.Typ.t
-> ('c, 'd) Impls.Wrap.Typ.t
-> ('a -> unit -> 'c)
Expand All @@ -74,3 +106,14 @@ module Wrap : sig
* [> `Cache_hit | `Generated_something | `Locally_generated ] )
lazy_t
end

module Storables : sig
mitschabaude marked this conversation as resolved.
Show resolved Hide resolved
type t =
{ step_storable : Step.storable
; step_vk_storable : Step.vk_storable
; wrap_storable : Wrap.storable
; wrap_vk_storable : Wrap.vk_storable
}

val default : t
end
24 changes: 15 additions & 9 deletions src/lib/pickles/compile.ml
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,7 @@ struct
type var value prev_varss prev_valuess widthss heightss max_proofs_verified branches.
self:(var, value, max_proofs_verified, branches) Tag.t
-> cache:Key_cache.Spec.t list
-> storables:Cache.Storables.t
-> proof_cache:Proof_cache.t option
-> ?disk_keys:
(Cache.Step.Key.Verification.t, branches) Vector.t
Expand Down Expand Up @@ -378,10 +379,13 @@ struct
* _
* _
* _ =
fun ~self ~cache ~proof_cache ?disk_keys
?(return_early_digest_exception = false) ?override_wrap_domain
?override_wrap_main ~branches:(module Branches) ~max_proofs_verified
~name ~constraint_constants ~public_input ~auxiliary_typ ~choices () ->
fun ~self ~cache
~storables:
{ step_storable; step_vk_storable; wrap_storable; wrap_vk_storable }
~proof_cache ?disk_keys ?(return_early_digest_exception = false)
?override_wrap_domain ?override_wrap_main ~branches:(module Branches)
~max_proofs_verified ~name ~constraint_constants ~public_input
~auxiliary_typ ~choices () ->
let snark_keys_header kind constraint_system_hash =
{ Snark_keys_header.header_version = Snark_keys_header.header_version
; kind
Expand Down Expand Up @@ -595,7 +599,7 @@ struct
Common.time "step read or generate" (fun () ->
Cache.Step.read_or_generate
~prev_challenges:(Nat.to_int (fst b.proofs_verified))
cache k_p k_v
cache ~s_p:step_storable k_p ~s_v:step_vk_storable k_v
(Snarky_backendless.Typ.unit ())
typ main )
in
Expand Down Expand Up @@ -671,7 +675,8 @@ struct
let r =
Common.time "wrap read or generate " (fun () ->
Cache.Wrap.read_or_generate (* Due to Wrap_hack *)
~prev_challenges:2 cache disk_key_prover disk_key_verifier typ
~prev_challenges:2 cache ~s_p:wrap_storable disk_key_prover
~s_v:wrap_vk_storable disk_key_verifier typ
(Snarky_backendless.Typ.unit ())
main )
in
Expand Down Expand Up @@ -938,6 +943,7 @@ let compile_with_wrap_main_override_promise :
type var value a_var a_value ret_var ret_value auxiliary_var auxiliary_value prev_varss prev_valuess widthss heightss max_proofs_verified branches.
?self:(var, value, max_proofs_verified, branches) Tag.t
-> ?cache:Key_cache.Spec.t list
-> ?storables:Cache.Storables.t
-> ?proof_cache:Proof_cache.t
-> ?disk_keys:
(Cache.Step.Key.Verification.t, branches) Vector.t
Expand Down Expand Up @@ -991,8 +997,8 @@ let compile_with_wrap_main_override_promise :
(* This function is an adapter between the user-facing Pickles.compile API
and the underlying Make(_).compile function which builds the circuits.
*)
fun ?self ?(cache = []) ?proof_cache ?disk_keys
?(return_early_digest_exception = false) ?override_wrap_domain
fun ?self ?(cache = []) ?(storables = Cache.Storables.default) ?proof_cache
?disk_keys ?(return_early_digest_exception = false) ?override_wrap_domain
?override_wrap_main ~public_input ~auxiliary_typ ~branches
~max_proofs_verified ~name ~constraint_constants ~choices () ->
let self =
Expand Down Expand Up @@ -1061,7 +1067,7 @@ let compile_with_wrap_main_override_promise :
in
let provers, wrap_vk, wrap_disk_key, cache_handle =
M.compile ~return_early_digest_exception ~self ~proof_cache ~cache
?disk_keys ?override_wrap_domain ?override_wrap_main ~branches
~storables ?disk_keys ?override_wrap_domain ?override_wrap_main ~branches
~max_proofs_verified ~name ~public_input ~auxiliary_typ
~constraint_constants
~choices:(fun ~self -> conv_irs (choices ~self))
Expand Down
1 change: 1 addition & 0 deletions src/lib/pickles/compile.mli
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ type ('max_proofs_verified, 'branches, 'prev_varss) wrap_main_generic =
val compile_with_wrap_main_override_promise :
?self:('var, 'value, 'max_proofs_verified, 'branches) Tag.t
-> ?cache:Key_cache.Spec.t list
-> ?storables:Cache.Storables.t
-> ?proof_cache:Proof_cache.t
-> ?disk_keys:
(Cache.Step.Key.Verification.t, 'branches) Vector.t
Expand Down
23 changes: 12 additions & 11 deletions src/lib/pickles/pickles.ml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ module Make_str (_ : Wire_types.Concrete) = struct
module Step_main_inputs = Step_main_inputs
module Step_verifier = Step_verifier
module Proof_cache = Proof_cache
module Cache = Cache

exception Return_digest = Compile.Return_digest

Expand Down Expand Up @@ -306,22 +307,22 @@ module Make_str (_ : Wire_types.Concrete) = struct
let compile_with_wrap_main_override_promise =
Compile.compile_with_wrap_main_override_promise

let compile_promise ?self ?cache ?proof_cache ?disk_keys
let compile_promise ?self ?cache ?storables ?proof_cache ?disk_keys
?return_early_digest_exception ?override_wrap_domain ~public_input
~auxiliary_typ ~branches ~max_proofs_verified ~name ~constraint_constants
~choices () =
compile_with_wrap_main_override_promise ?self ?cache ?proof_cache ?disk_keys
?return_early_digest_exception ?override_wrap_domain ~public_input
~auxiliary_typ ~branches ~max_proofs_verified ~name ~constraint_constants
~choices ()

let compile ?self ?cache ?proof_cache ?disk_keys ?override_wrap_domain
compile_with_wrap_main_override_promise ?self ?cache ?storables ?proof_cache
?disk_keys ?return_early_digest_exception ?override_wrap_domain
~public_input ~auxiliary_typ ~branches ~max_proofs_verified ~name
~constraint_constants ~choices () =
~constraint_constants ~choices ()

let compile ?self ?cache ?storables ?proof_cache ?disk_keys
?override_wrap_domain ~public_input ~auxiliary_typ ~branches
~max_proofs_verified ~name ~constraint_constants ~choices () =
let self, cache_handle, proof_module, provers =
compile_promise ?self ?cache ?proof_cache ?disk_keys ?override_wrap_domain
~public_input ~auxiliary_typ ~branches ~max_proofs_verified ~name
~constraint_constants ~choices ()
compile_promise ?self ?cache ?storables ?proof_cache ?disk_keys
?override_wrap_domain ~public_input ~auxiliary_typ ~branches
~max_proofs_verified ~name ~constraint_constants ~choices ()
in
let rec adjust_provers :
type a1 a2 a3 s1 s2_inner.
Expand Down
5 changes: 4 additions & 1 deletion src/lib/pickles/pickles_intf.mli
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ module type S = sig
module Step_verifier = Step_verifier
module Common = Common
module Proof_cache = Proof_cache
module Cache = Cache

exception Return_digest of Md5.t

Expand All @@ -37,7 +38,7 @@ module type S = sig
[%%versioned:
module Stable : sig
module V2 : sig
type t [@@deriving to_yojson]
type t [@@deriving to_yojson, of_yojson]
end
end]

Expand Down Expand Up @@ -366,6 +367,7 @@ module type S = sig
val compile_promise :
?self:('var, 'value, 'max_proofs_verified, 'branches) Tag.t
-> ?cache:Key_cache.Spec.t list
-> ?storables:Cache.Storables.t
-> ?proof_cache:Proof_cache.t
-> ?disk_keys:
(Cache.Step.Key.Verification.t, 'branches) Vector.t
Expand Down Expand Up @@ -421,6 +423,7 @@ module type S = sig
val compile :
?self:('var, 'value, 'max_proofs_verified, 'branches) Tag.t
-> ?cache:Key_cache.Spec.t list
-> ?storables:Cache.Storables.t
-> ?proof_cache:Proof_cache.t
-> ?disk_keys:
(Cache.Step.Key.Verification.t, 'branches) Vector.t
Expand Down
Loading
Loading