Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add access to last_row_id in cursor #3

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 30 additions & 11 deletions src/sqlite3_utils.ml
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,11 @@ end

module Cursor = struct
type 'a t = {
db: db;
stmt: Sqlite3.stmt;
read: Sqlite3.stmt -> 'a;
mutable cur: 'a option;
mutable last_row_id: int64;
}

let ignore _ = ()
Expand All @@ -173,11 +175,12 @@ module Cursor = struct
self.cur <- None;
| Sqlite3.Rc.ROW ->
let x = self.read self.stmt in
self.last_row_id <- Sqlite3.last_insert_rowid self.db;
self.cur <- Some x
| rc -> raise (RcError rc)

let make_ stmt read =
let self = { stmt; cur=None; read; } in
let make_ db stmt read =
let self = { db; stmt; cur=None; read; last_row_id=0L; } in
next_ self;
self

Expand All @@ -192,17 +195,17 @@ module Cursor = struct
| Error rc -> raise (RcError rc)

let map ~f c : _ t =
{stmt=c.stmt;
{c with
read=(fun stmt -> f (c.read stmt));
cur=opt_map_ f c.cur;
}

let make stmt ty f =
let make db stmt ty f =
let read stmt = Ty.tr_row (Sqlite3.column stmt) 0 ty f in
make_ stmt read
make_ db stmt read

let make_raw stmt : Data.t array t =
make_ stmt Sqlite3.row_data
let make_raw db stmt : Data.t array t =
make_ db stmt Sqlite3.row_data

(* next value in the cursor *)
let next self : _ option =
Expand All @@ -212,13 +215,29 @@ module Cursor = struct
next_ self;
x

(* next value in the cursor *)
let next_with_last_row_id self : _ option =
match self.cur with
| None -> None
| Some x ->
let last = self.last_row_id in
next_ self;
Some (x, last)

let rec iter ~f self = match self.cur with
| None -> ()
| Some res ->
f res;
next_ self;
iter ~f self

let rec iter_with_last_row_id ~f self = match self.cur with
| None -> ()
| Some res ->
f self.last_row_id res;
next_ self;
iter_with_last_row_id ~f self

let to_seq self =
let rec get_next () =
let n = lazy (
Expand Down Expand Up @@ -307,7 +326,7 @@ let exec_raw_exn db str ~f =
with_stmt db str
~f:(fun stmt ->
check_arity_params_ stmt 0;
f (Cursor.make_raw stmt))
f (Cursor.make_raw db stmt))

let exec_raw db str ~f =
try Ok (exec_raw_exn db str ~f)
Expand All @@ -319,7 +338,7 @@ let exec_raw_args_exn db str a ~f =
~f:(fun stmt ->
check_arity_params_ stmt (Array.length a);
Array.iteri (fun i x -> check_ret_exn (Sqlite3.bind stmt (i+1) x)) a;
f (Cursor.make_raw stmt))
f (Cursor.make_raw db stmt))

let exec_raw_args db str a ~f =
try Ok (exec_raw_args_exn db str a ~f)
Expand Down Expand Up @@ -353,7 +372,7 @@ let exec_ db str ~ty ~f =
| Ok () ->
try
check_arity_res_ stmt (Ty.count ty_r);
let c = Cursor.make stmt ty_r f_r in
let c = Cursor.make db stmt ty_r f_r in
let r = f (Ok c) in
finalize_check_ stmt;
r
Expand Down Expand Up @@ -383,7 +402,7 @@ let exec_no_params_exn db str ~ty ~f =
~f:(fun stmt ->
check_arity_params_ stmt 0;
let ty_r, f_r = ty in
f (Cursor.make stmt ty_r f_r))
f (Cursor.make db stmt ty_r f_r))

let exec_no_params db str ~ty ~f =
exec_no_params_exn db str ~ty
Expand Down
16 changes: 13 additions & 3 deletions src/sqlite3_utils.mli
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ module Ty : sig
val mkp6: 'a -> 'b -> 'c -> 'd -> 'e -> 'f -> 'a * 'b * 'c * 'd * 'e * 'f
end

(** {2 Cursor API}
(** {2 Cursor API}

A Cursor is a special iterator over Sqlite rows of results.
It should be consumed quickly as it will not survive the call to
Expand All @@ -138,6 +138,10 @@ module Cursor : sig
val next : 'a t -> 'a option
(** Get next value, or [None] if all values have been enumerated *)

val next_with_last_row_id : 'a t -> ('a * int64) option
(** Same as {!next}, but also return last inserted row ID
@since NEXT_RELEASE *)

val get_one : 'a t -> ('a, Rc.t) result
(** Get the first element (useful when querying a scalar, like "count( * )").
returns [Error Rc.NOTFOUND] if it's empty.
Expand All @@ -151,6 +155,10 @@ module Cursor : sig
val iter : f:('a -> unit) -> 'a t -> unit
(** Iterate over the values *)

val iter_with_last_row_id : f:(int64 -> 'a -> unit) -> 'a t -> unit
(** Iterate over the values with last row id
@since NEXT_RELEASE *)

val map : f:('a -> 'b) -> 'a t -> 'b t
(** Map over values of the cursor. Once [map ~f c] is built, [c] should
not be used. *)
Expand Down Expand Up @@ -290,8 +298,10 @@ val exec_get_column_names : t -> string -> string list
val transact : t -> (t -> 'a) -> 'a
(** [transact db f] runs [f db] within a transaction (begin/commit/rollback).
Useful to perform a batch of insertions or updates, as Sqlite doesn't
write very fast. *)
write very fast. Cannot be nested.
See https://www.sqlite.org/lang_transaction.html . *)

val atomically : t -> (t -> 'a) -> 'a
(** Same as {!transact} but uses Sqlite's savepoint/release/rollback mechanism.
instead of begin/commit/rolllback *)
instead of begin/commit/rollback.
This can be nested, unlike {!transact}. *)