aboutsummaryrefslogtreecommitdiffstats
path: root/src/letsqlite.ml
blob: 83bf0c6058d23adb78f973ea347c8a29cc867a6c (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
(* A small result monad around Sqlite3 prepared statements: every step of a
   chain carries a value together with the statement it works on, and the
   first failing return code short-circuits the rest of the chain. *)
module S = Sqlite3
module Rc = S.Rc

(* A prepared statement paired with a [keep] flag.  When the flag is true,
   the helpers below that would otherwise call [S.finalize] on the statement
   ([fail], [reprepare], [finalize]) leave it alone. *)
type stmt_wrap = bool * S.stmt
(* Outcome of one step of a chain. *)
type 'a stmt_m =
| Failed of Rc.t (* the return code that stopped the chain *)
| Norm of ('a * stmt_wrap) (* the carried value and the current statement *)

(* Shape of a chain step: consumes a value/statement pair, yields the next outcome. *)
type ('a,'b) monad_fun = ('a * stmt_wrap) -> 'b stmt_m
(* One result row: column name (as returned by [S.row_names]) -> column value. *)
type rowdata = (string, Sqlite3.Data.t) Hashtbl.t
(* Monad definitions *)

(* Bind: hand the value and statement of a successful step to [f].
   A [Failed] code skips [f] and is passed through unchanged. *)
let (>>$) m f =
	match m with
	| Norm (value, stmt) -> f (value, stmt)
	| Failed rc -> Failed rc

(* Binding-operator form of [>>$], for chains written with let$. *)
let (let$) m f = m >>$ f

(* Bind that keeps the incoming value: [m >-$ f] runs [f] for its effect on
   the statement, then continues with the value carried by [m] and the
   statement returned by [f].  A failure on either side is propagated. *)
let (>-$) m f =
	m >>$ fun (value, stmt) ->
	match f (value, stmt) with
	| Norm (_, stmt') -> Norm (value, stmt')
	| Failed rc -> Failed rc

(* Kleisli composition: [g @>$ h] is the step that runs [g], then [h]. *)
let (@>$) g h x = g x >>$ h

(* Second projection; used as [gs s] to pull the raw [S.stmt] out of a
   [stmt_wrap]. *)
let gs = snd

(* Helper definitions *)

(* Apply the pure function [f] to the carried value; the statement is
   left untouched. *)
let transform f (value, stmt) =
	let value' = f value in
	Norm (value', stmt)

(* Replace the carried value with the constant [v]. *)
let inject v pair = transform (fun _ -> v) pair
(* Abort the chain with return code [er].
   A statement not marked keep is finalized first.  If [er] is itself a
   success code, the code returned by [S.finalize] is reported instead;
   a kept statement always reports [er]. *)
let fail er (_, (keep, stmt)) =
	match keep with
	| true -> Failed er
	| false ->
		let fin_rc = S.finalize stmt in
		Failed (if Rc.is_success er then fin_rc else er)

(* [fail] for callers that only hold the wrapped statement. *)
let stmtfail er wrapped = fail er ((), wrapped)
(* Turn a return code into a chain outcome: a success code lets [tp]
   continue unchanged, anything else goes through [fail]. *)
let lift_err e tp =
	match Rc.is_success e with
	| true -> Norm tp
	| false -> fail e tp

(* Helper functions *)

(* Add the binding [keys.(i) -> vals.(i)] to [tbl] for every index [i].
   Both arrays must have the same length (checked with an assertion).
   [Hashtbl.add] is used, so a repeated key shadows its earlier binding. *)
let arr_to_tbl tbl keys vals =
	assert (Array.length keys = Array.length vals);
	for i = 0 to Array.length keys - 1 do
		Hashtbl.add tbl keys.(i) vals.(i)
	done

(* Build a [rowdata] table for the wrapped statement [s]: each column name
   reported by [S.row_names] maps to the matching entry of [rowdata]. *)
let create_rowtbl s rowdata =
	let names = S.row_names (snd s) in
	let table = Hashtbl.create (Array.length names) in
	arr_to_tbl table names rowdata;
	table

(* Monadic functions *)

(* [prepare_keep keep db stmt v] compiles the SQL text [stmt] on [db] and
   starts a chain carrying [v].  The [keep] flag is stored next to the
   statement: when true, [fail], [reprepare] and [finalize] do not
   finalize it.
   Exceptions raised by the Sqlite3 bindings ([S.Error], [S.SqliteError])
   become [Failed Rc.ERROR].  Any other exception (e.g. [Out_of_memory],
   [Sys.Break]) propagates to the caller rather than being disguised as an
   SQL error. *)
let prepare_keep keep db stmt v =
	match S.prepare db stmt with
	| compiled -> Norm (v, (keep, compiled))
	| exception (S.Error _ | S.SqliteError _) -> Failed Rc.ERROR

(* [prepare_keep] with [keep = false]: the chain finalizes the statement. *)
let prepare db s v = prepare_keep false db s v
(* Replace the current statement with a freshly prepared [stmt], keeping
   the carried value [v].  The old statement is finalized first unless it
   is marked keep; if that finalize does not succeed, its code is returned
   as [Failed] and nothing new is prepared.  The new statement is prepared
   with [prepare], i.e. not kept. *)
let reprepare db stmt (v, (keep, old)) =
	let rc =
		match keep with
		| true -> Rc.OK
		| false -> S.finalize old
	in
	match Rc.is_success rc with
	| false -> Failed rc
	| true -> prepare db stmt v

(* Run a raw statement operation [op] (returning an [Rc.t]) inside the
   chain: a success code continues with the unchanged pair, anything else
   goes through [fail]. *)
let run_on_stmt op ((_, (_, st)) as pair) = lift_err (op st) pair

let reset pair = run_on_stmt S.reset pair
let step pair = run_on_stmt S.step pair
let bind_values l pair = run_on_stmt (fun st -> S.bind_values st l) pair
let clear_bindings pair = run_on_stmt S.clear_bindings pair

(* Same contract as [lift_err]: continue with [v] on a success code,
   otherwise [fail]. *)
let execif r v = lift_err r v

(* [rowfold f (init, s)] steps the wrapped statement [s] while it yields
   rows, folding [f] over each row given as a [rowdata] table.
   [f acc row] returns [Ok acc'] to continue or [Error rc] to abort the
   chain through [stmtfail].  When stepping stops with a success code the
   statement is reset and the final accumulator is carried on; any other
   step code aborts the chain through [fail]. *)
let rowfold f (init,s) =
	let raw = gs s in
	let rec go acc =
		match S.step raw with
		| Rc.ROW ->
			begin match f acc (create_rowtbl s (S.row_data raw)) with
			| Ok acc' -> go acc'
			| Error rc -> stmtfail rc s
			end
		| rc -> execif rc (acc, s)
	in
	go init >>$ reset

(* Call [f] on every remaining row (as [rowdata]); any return code other
   than [Rc.OK] aborts the iteration with that code. *)
let iter f (_,s) =
	let visit () row =
		match f row with
		| Rc.OK -> Ok ()
		| rc -> Error rc
	in
	rowfold visit ((), s)

(* [map f l (ival, s)] calls [f x (ival, s')] for each [x] of [l] in order,
   where [s'] is the statement produced by the previous call (initially
   [s]).  Every call receives the same incoming value [ival]; the results
   are carried on as a list in the order of [l].  The first [Failed] stops
   the traversal and is propagated. *)
let map f l (ival,s) =
	let rec collect acc stmt = function
		| [] -> Norm (List.rev acc, stmt)
		| x :: rest ->
			f x (ival, stmt) >>$ fun (y, stmt') -> collect (y :: acc) stmt' rest
	in
	collect [] s l

(* Thread the chain through [f] once per element of [l], left to right:
   each call sees the value and statement produced by the previous one.
   Stops at the first [Failed]. *)
let fold f l init =
	let rec go acc = function
		| [] -> Norm acc
		| x :: rest -> f x acc >>$ fun next -> go next rest
	in
	go init l

(* [get_exactly_one_row (_, s)] steps the wrapped statement [s], expecting a
   single result row.  On success the carried value becomes [((), tbl)],
   where [tbl] is the row as [rowdata], and the statement is reset so it
   can be re-bound and run again.  Otherwise it fails through [stmtfail]
   with:
   - [Rc.NOTFOUND] when the query yields no row, or when [S.data_count]
     reports 0 for the row;
   - the second step code when it is not a success code, e.g. [Rc.ROW]
     when more than one row is available;
   - the first step code when stepping itself fails. *)
let get_exactly_one_row (_,s) =
	let stmt = gs s in
	match S.step stmt with
	| Rc.ROW ->
		if S.data_count stmt = 0 then stmtfail Rc.NOTFOUND s
		else
			let tbl = create_rowtbl s (S.row_data stmt) in
			let r = S.step stmt in
			if Rc.is_success r then reset (((),tbl),s)
			else stmtfail r s
	| Rc.OK | Rc.DONE ->
		(* Zero rows.  Handing this success code to [stmtfail] would give a
		   [Failed] that holds a success code (the finalize result, or DONE
		   for a kept statement), which [finalize] reports as a misleading
		   error.  Report NOTFOUND, as for the empty-row case above. *)
		stmtfail Rc.NOTFOUND s
	| x -> stmtfail x s

(* Look up a field of the carried argument [arg] with [extrfun], then
   convert it with [convfun].  The converted value is paired with [arg] so
   further extractions can follow.  A missing field or a failed conversion
   fails with [Rc.NOTFOUND]. *)
let exec_extract extrfun convfun ((_,arg),s) =
	match Option.bind (extrfun arg) convfun with
	| Some x -> Norm ((x, arg), s)
	| None -> stmtfail Rc.NOTFOUND s

(* Extract column [name] from the carried [rowdata] table and convert it
   with [conv] (for instance [S.Data.to_int64]). *)
let extract name conv tu =
	let lookup row = Hashtbl.find_opt row name in
	exec_extract lookup conv tu

(* Leave the monad.
   - [Failed rc] raises [S.SqliteError] carrying [Rc.to_string rc].
   - A successful chain finalizes its statement (unless marked keep),
     passes the finalize code to [Rc.check], and returns the carried
     value. *)
let finalize = function
	| Norm (v, (true, _)) -> v
	| Norm (v, (false, stmt)) ->
		Rc.check (S.finalize stmt);
		v
	| Failed rc -> raise (S.SqliteError (Rc.to_string rc))

(* Prepare [s] on [db] and step it once. *)
let exec db s =
	let$ started = prepare db s () in
	step started

(* Swap the current statement for a newly prepared [s], then step it once. *)
let reexec db s pair = reprepare db s pair >>$ step

(* Bind [l], step once, then reset and clear the bindings so the same
   statement can run again with other values. *)
let bsrc l pair =
	bind_values l pair
	>>$ step
	>>$ reset
	>>$ clear_bindings

(* End-to-end check on an in-memory database: create a table, insert two
   rows through one reused INSERT statement, then read [y] back for [x = 1]
   and [x = 2] with one reused SELECT statement.  Passes when the chain
   finalizes without raising, [S.db_close] returns true, and the values read
   were 2 then 3 (collected newest-first). *)
let%test "simple usage" = 
	let db = S.db_open ":memory:"
	(* values read back from the SELECT, most recent first *)
	in let s1 = ref []
	in let fin = (
		let m = exec db "CREATE TABLE tbl (x,y);"
		>>$ reprepare db "INSERT INTO tbl VALUES(?,?);"
		(* bsrc binds, steps, resets and clears, so the INSERT can be reused *)
		>>$ bsrc [S.Data.INT 1L; S.Data.INT 2L]
		>>$ bsrc [S.Data.INT 2L; S.Data.INT 3L]
		>>$ reprepare db "SELECT y FROM tbl WHERE x = ?;"
		>>$ bind_values [S.Data.INT 1L]
		>>$ get_exactly_one_row
		>>$ extract "y" S.Data.to_int64
		(* v is the extracted y, r the row table it came from *)
		in let$ ((v,r),s) = m
		in s1 := [v];
		(* get_exactly_one_row reset the SELECT, so it can be re-bound directly *)
		let$ ((v,_),s) = bind_values [S.Data.INT 2L] ((v,r),s)
		>>$ get_exactly_one_row
		>>$ extract "y" S.Data.to_int64
		in s1 := v::!s1;
		Norm ((),s)
	)
	(* finalize raises on a Failed chain and finalizes the SELECT otherwise,
	   which lets db_close succeed *)
	in finalize fin; S.db_close db && !s1 = [3L; 2L]