From afa86358892fc0a497f2c22959e6cb3c4789f6a3 Mon Sep 17 00:00:00 2001 From: Simon Cruanes Date: Thu, 28 Apr 2022 10:42:35 -0400 Subject: [PATCH] feat: add access to last_row_id in cursor --- src/sqlite3_utils.ml | 41 ++++++++++++++++++++++++++++++----------- src/sqlite3_utils.mli | 16 +++++++++++++--- 2 files changed, 43 insertions(+), 14 deletions(-) diff --git a/src/sqlite3_utils.ml b/src/sqlite3_utils.ml index d5323f2..fab9017 100644 --- a/src/sqlite3_utils.ml +++ b/src/sqlite3_utils.ml @@ -160,9 +160,11 @@ end module Cursor = struct type 'a t = { + db: db; stmt: Sqlite3.stmt; read: Sqlite3.stmt -> 'a; mutable cur: 'a option; + mutable last_row_id: int64; } let ignore _ = () @@ -173,11 +175,12 @@ module Cursor = struct self.cur <- None; | Sqlite3.Rc.ROW -> let x = self.read self.stmt in + self.last_row_id <- Sqlite3.last_insert_rowid self.db; self.cur <- Some x | rc -> raise (RcError rc) - let make_ stmt read = - let self = { stmt; cur=None; read; } in + let make_ db stmt read = + let self = { db; stmt; cur=None; read; last_row_id=0L; } in next_ self; self @@ -192,17 +195,17 @@ module Cursor = struct | Error rc -> raise (RcError rc) let map ~f c : _ t = - {stmt=c.stmt; + {c with read=(fun stmt -> f (c.read stmt)); cur=opt_map_ f c.cur; } - let make stmt ty f = + let make db stmt ty f = let read stmt = Ty.tr_row (Sqlite3.column stmt) 0 ty f in - make_ stmt read + make_ db stmt read - let make_raw stmt : Data.t array t = - make_ stmt Sqlite3.row_data + let make_raw db stmt : Data.t array t = + make_ db stmt Sqlite3.row_data (* next value in the cursor *) let next self : _ option = @@ -212,6 +215,15 @@ module Cursor = struct next_ self; x + (* next value in the cursor *) + let next_with_last_row_id self : _ option = + match self.cur with + | None -> None + | Some x -> + let last = self.last_row_id in + next_ self; + Some (x, last) + let rec iter ~f self = match self.cur with | None -> () | Some res -> @@ -219,6 +231,13 @@ module Cursor = struct next_ self; iter ~f self + let rec iter_with_last_row_id ~f self = match self.cur with + | None -> () + | Some res -> + f self.last_row_id res; + next_ self; + iter_with_last_row_id ~f self + let to_seq self = let rec get_next () = let n = lazy ( @@ -307,7 +326,7 @@ let exec_raw_exn db str ~f = with_stmt db str ~f:(fun stmt -> check_arity_params_ stmt 0; - f (Cursor.make_raw stmt)) + f (Cursor.make_raw db stmt)) let exec_raw db str ~f = try Ok (exec_raw_exn db str ~f) @@ -319,7 +338,7 @@ let exec_raw_args_exn db str a ~f = ~f:(fun stmt -> check_arity_params_ stmt (Array.length a); Array.iteri (fun i x -> check_ret_exn (Sqlite3.bind stmt (i+1) x)) a; - f (Cursor.make_raw stmt)) + f (Cursor.make_raw db stmt)) let exec_raw_args db str a ~f = try Ok (exec_raw_args_exn db str a ~f) @@ -353,7 +372,7 @@ let exec_ db str ~ty ~f = | Ok () -> try check_arity_res_ stmt (Ty.count ty_r); - let c = Cursor.make stmt ty_r f_r in + let c = Cursor.make db stmt ty_r f_r in let r = f (Ok c) in finalize_check_ stmt; r @@ -383,7 +402,7 @@ let exec_no_params_exn db str ~ty ~f = ~f:(fun stmt -> check_arity_params_ stmt 0; let ty_r, f_r = ty in - f (Cursor.make stmt ty_r f_r)) + f (Cursor.make db stmt ty_r f_r)) let exec_no_params db str ~ty ~f = exec_no_params_exn db str ~ty diff --git a/src/sqlite3_utils.mli b/src/sqlite3_utils.mli index d354d90..59b2bb0 100644 --- a/src/sqlite3_utils.mli +++ b/src/sqlite3_utils.mli @@ -121,7 +121,7 @@ module Ty : sig val mkp6: 'a -> 'b -> 'c -> 'd -> 'e -> 'f -> 'a * 'b * 'c * 'd * 'e * 'f end -(** {2 Cursor API} +(** {2 Cursor API} A Cursor is a special iterator over Sqlite rows of results. It should be consumed quickly as it will not survive the call to @@ -138,6 +138,10 @@ module Cursor : sig val next : 'a t -> 'a option (** Get next value, or [None] if all values have been enumerated *) + val next_with_last_row_id : 'a t -> ('a * int64) option + (** Same as {!next}, but also return last inserted row ID + @since NEXT_RELEASE *) + val get_one : 'a t -> ('a, Rc.t) result (** Get the first element (useful when querying a scalar, like "count( * )"). returns [Error Rc.NOTFOUND] if it's empty. @@ -151,6 +155,10 @@ module Cursor : sig val iter : f:('a -> unit) -> 'a t -> unit (** Iterate over the values *) + val iter_with_last_row_id : f:(int64 -> 'a -> unit) -> 'a t -> unit + (** Iterate over the values with last row id + @since NEXT_RELEASE *) + val map : f:('a -> 'b) -> 'a t -> 'b t (** Map over values of the cursor. Once [map ~f c] is built, [c] should not be used. *) @@ -290,8 +298,10 @@ val exec_get_column_names : t -> string -> string list val transact : t -> (t -> 'a) -> 'a (** [transact db f] runs [f db] within a transaction (begin/commit/rollback). Useful to perform a batch of insertions or updates, as Sqlite doesn't - write very fast. *) + write very fast. Cannot be nested. + See https://www.sqlite.org/lang_transaction.html . *) val atomically : t -> (t -> 'a) -> 'a (** Same as {!transact} but uses Sqlite's savepoint/release/rollback mechanism. - instead of begin/commit/rolllback *) + instead of begin/commit/rollback. + This can be nested, unlike {!transact}. *)