1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
16
17
18 package sqlkv
19
20 import (
21 "context"
22 "database/sql"
23 "errors"
24 "fmt"
25 "log"
26 "strings"
27 "sync"
28
29 "go4.org/syncutil"
30 "perkeep.org/internal/leak"
31 "perkeep.org/pkg/sorted"
32 )
33
34
// KeyValue stores sorted key/value pairs in a SQL table. Its methods (Get,
// Set, Delete, Find, Wipe, BeginBatch/CommitBatch, BeginReadTx, Close)
// presumably satisfy sorted.KeyValue — NOTE(review): confirm against the
// sorted package.
//
// Every statement targets a table written as /*TPRE*/rows with a key column
// k and a value column v; the /*TPRE*/ marker is replaced by TablePrefix.
// The table itself is created elsewhere.
type KeyValue struct {
	// DB is the database every statement runs against. Close closes it.
	DB *sql.DB

	// SetFunc, if non-nil, is called by Set instead of executing the default
	// "REPLACE INTO" statement. BatchSetFunc plays the same role for Set on
	// a batch, and receives the batch's *sql.Tx instead of the DB.
	SetFunc      func(*sql.DB, string, string) error
	BatchSetFunc func(*sql.Tx, string, string) error

	// PlaceHolderFunc, if non-nil, rewrites each statement (which is written
	// with "?" placeholders) before TablePrefix substitution. Presumably this
	// converts placeholders to the target database's syntax — confirm with
	// the drivers that set it.
	PlaceHolderFunc func(string) string

	// Gate, if non-nil, bounds the number of concurrent operations: each
	// operation calls Start before touching the database and Done when it is
	// finished. Find keeps its slot until the returned iterator is closed,
	// and transactions from BeginBatch/BeginReadTx acquire theirs when they
	// begin and release it in CommitBatch/Close.
	Gate *syncutil.Gate

	// TablePrefix is substituted for every "/*TPRE*/" marker in the SQL
	// statements. It is read only once, when the first statement is
	// rewritten; changing it afterwards has no effect.
	TablePrefix string

	queriesInitOnce sync.Once         // lazily creates queries and replacer in sql
	replacer        *strings.Replacer // replaces "/*TPRE*/" with TablePrefix

	queriesMu sync.RWMutex      // guards queries
	queries   map[string]string // cache: original statement -> rewritten statement
}
67
68
69
70 func (kv *KeyValue) sql(sqlStmt string) string {
71
72 kv.queriesInitOnce.Do(func() {
73 kv.queries = make(map[string]string, 8)
74 kv.replacer = strings.NewReplacer("/*TPRE*/", kv.TablePrefix)
75 })
76 kv.queriesMu.RLock()
77 sqlQuery, ok := kv.queries[sqlStmt]
78 kv.queriesMu.RUnlock()
79 if ok {
80 return sqlQuery
81 }
82 kv.queriesMu.Lock()
83
84 if sqlQuery, ok = kv.queries[sqlStmt]; ok {
85 kv.queriesMu.Unlock()
86 return sqlQuery
87 }
88 sqlQuery = sqlStmt
89 if f := kv.PlaceHolderFunc; f != nil {
90 sqlQuery = f(sqlQuery)
91 }
92 sqlQuery = kv.replacer.Replace(sqlQuery)
93 kv.queries[sqlStmt] = sqlQuery
94 kv.queriesMu.Unlock()
95 return sqlQuery
96 }
97
// batchTx wraps a single *sql.Tx. It is returned both as the
// sorted.BatchMutation from BeginBatch and as the sorted.ReadTransaction
// from BeginReadTx.
type batchTx struct {
	tx  *sql.Tx   // nil when BeginTx failed (err is then set)
	err error     // first recorded error; once set, later operations are skipped
	kv  *KeyValue // owning store: statement rewriting, BatchSetFunc and Gate
}
103
104 func (b *batchTx) Set(key, value string) {
105 if b.err != nil {
106 return
107 }
108 if err := sorted.CheckSizes(key, value); err != nil {
109 log.Printf("Skipping storing (%q:%q): %v", key, value, err)
110 return
111 }
112 if b.kv.BatchSetFunc != nil {
113 b.err = b.kv.BatchSetFunc(b.tx, key, value)
114 return
115 }
116 _, b.err = b.tx.Exec(b.kv.sql("REPLACE INTO /*TPRE*/rows (k, v) VALUES (?, ?)"), key, value)
117 }
118
119 func (b *batchTx) Delete(key string) {
120 if b.err != nil {
121 return
122 }
123 _, b.err = b.tx.Exec(b.kv.sql("DELETE FROM /*TPRE*/rows WHERE k=?"), key)
124 }
125
126 func (b *batchTx) Find(start, end string) sorted.Iterator {
127 if b.err != nil {
128 return &iter{
129 kv: b.kv,
130 closeCheck: leak.NewChecker(),
131 err: b.err,
132 }
133 }
134 return find(b.kv, b.tx, start, end)
135 }
136
137 func (b *batchTx) Get(key string) (value string, err error) {
138 if b.err != nil {
139 return "", b.err
140 }
141 return get(b.kv, b.tx, key)
142 }
143
144 func (b *batchTx) Close() error {
145 if b.err != nil {
146 return b.err
147 }
148 if b.kv.Gate != nil {
149 defer b.kv.Gate.Done()
150 }
151 return b.tx.Commit()
152 }
153
154 func (kv *KeyValue) beginTx(txOpts *sql.TxOptions) *batchTx {
155 if kv.Gate != nil {
156 kv.Gate.Start()
157 }
158 tx, err := kv.DB.BeginTx(context.TODO(), txOpts)
159 if err != nil {
160 log.Printf("SQL BEGIN BATCH: %v", err)
161 }
162 return &batchTx{
163 tx: tx,
164 err: err,
165 kv: kv,
166 }
167 }
168
169 func (kv *KeyValue) BeginBatch() sorted.BatchMutation {
170 return kv.beginTx(nil)
171 }
172
173 func (kv *KeyValue) CommitBatch(b sorted.BatchMutation) error {
174 if kv.Gate != nil {
175 defer kv.Gate.Done()
176 }
177 bt, ok := b.(*batchTx)
178 if !ok {
179 return fmt.Errorf("wrong BatchMutation type %T", b)
180 }
181 if bt.err != nil {
182 if err := bt.tx.Rollback(); err != nil {
183 log.Printf("Transaction rollback error: %v", err)
184 }
185 return bt.err
186 }
187 return bt.tx.Commit()
188 }
189
190 func (kv *KeyValue) BeginReadTx() sorted.ReadTransaction {
191 return kv.beginTx(&sql.TxOptions{
192 ReadOnly: true,
193
194
195 Isolation: sql.LevelSerializable,
196 })
197
198 }
199
200 func (kv *KeyValue) Get(key string) (value string, err error) {
201 if kv.Gate != nil {
202 kv.Gate.Start()
203 defer kv.Gate.Done()
204 }
205 return get(kv, kv.DB, key)
206 }
207
208 func (kv *KeyValue) Set(key, value string) error {
209 if err := sorted.CheckSizes(key, value); err != nil {
210 log.Printf("Skipping storing (%q:%q): %v", key, value, err)
211 return nil
212 }
213 if kv.Gate != nil {
214 kv.Gate.Start()
215 defer kv.Gate.Done()
216 }
217 if kv.SetFunc != nil {
218 return kv.SetFunc(kv.DB, key, value)
219 }
220 _, err := kv.DB.Exec(kv.sql("REPLACE INTO /*TPRE*/rows (k, v) VALUES (?, ?)"), key, value)
221 return err
222 }
223
224 func (kv *KeyValue) Delete(key string) error {
225 if kv.Gate != nil {
226 kv.Gate.Start()
227 defer kv.Gate.Done()
228 }
229 _, err := kv.DB.Exec(kv.sql("DELETE FROM /*TPRE*/rows WHERE k=?"), key)
230 return err
231 }
232
233
234
235
236 func (kv *KeyValue) Wipe() error {
237 if kv.Gate != nil {
238 kv.Gate.Start()
239 defer kv.Gate.Done()
240 }
241 _, err := kv.DB.Exec(kv.sql("DELETE FROM /*TPRE*/rows"))
242 return err
243 }
244
245 func (kv *KeyValue) Close() error { return kv.DB.Close() }
246
247
// queryObject is the pair of query methods shared by *sql.DB and *sql.Tx.
// It lets get and find run either directly against the database
// (KeyValue.Get/Find) or inside a transaction (batchTx.Get/Find).
type queryObject interface {
	QueryRow(query string, args ...interface{}) *sql.Row
	Query(query string, args ...interface{}) (*sql.Rows, error)
}
252
253
254 func find(kv *KeyValue, qobj queryObject, start, end string) *iter {
255 var rows *sql.Rows
256 var err error
257 if end == "" {
258 rows, err = qobj.Query(kv.sql("SELECT k, v FROM /*TPRE*/rows WHERE k >= ? ORDER BY k "), start)
259 } else {
260 rows, err = qobj.Query(kv.sql("SELECT k, v FROM /*TPRE*/rows WHERE k >= ? AND k < ? ORDER BY k "), start, end)
261 }
262 if err != nil {
263 log.Printf("unexpected query error: %v", err)
264 return &iter{err: err}
265 }
266
267 return &iter{
268 kv: kv,
269 rows: rows,
270 closeCheck: leak.NewChecker(),
271 }
272 }
273
274
275 func get(kv *KeyValue, qobj queryObject, key string) (value string, err error) {
276 err = qobj.QueryRow(kv.sql("SELECT v FROM /*TPRE*/rows WHERE k=?"), key).Scan(&value)
277 if err == sql.ErrNoRows {
278 err = sorted.ErrNotFound
279 }
280 return
281 }
282
283 func (kv *KeyValue) Find(start, end string) sorted.Iterator {
284 var releaseGate func()
285 if kv.Gate != nil {
286 var once sync.Once
287 kv.Gate.Start()
288 releaseGate = func() {
289 once.Do(kv.Gate.Done)
290 }
291 }
292 it := find(kv, kv.DB, start, end)
293 it.releaseGate = releaseGate
294 return it
295 }
296
297
// iter is the sorted.Iterator produced by find and batchTx.Find. It walks
// the rows of a "SELECT k, v ... ORDER BY k" query.
type iter struct {
	kv  *KeyValue // store the iterator was created from
	err error     // query/scan error, or errClosed after Close; stops Next

	// closeCheck is a perkeep leak checker, presumably reporting iterators
	// that are never closed — confirm in perkeep.org/internal/leak.
	closeCheck  *leak.Checker
	releaseGate func() // set by KeyValue.Find when a Gate slot is held; called by Close

	rows *sql.Rows // nil when the query failed, and after Close

	// key and val hold the current row. As sql.RawBytes they are only valid
	// until the next call to Next or Close.
	key        sql.RawBytes
	val        sql.RawBytes
	skey, sval *string // cached string copies of key/val; cleared by Next
}
311
// errClosed is stored in an iter by Close. It makes later Next calls return
// false, and a repeated Close returns it.
var (
	errClosed = errors.New("sqlkv: Iterator already closed")
)
313
314 func (t *iter) KeyBytes() []byte { return t.key }
315 func (t *iter) Key() string {
316 if t.skey != nil {
317 return *t.skey
318 }
319 str := string(t.key)
320 t.skey = &str
321 return str
322 }
323
324 func (t *iter) ValueBytes() []byte { return t.val }
325 func (t *iter) Value() string {
326 if t.sval != nil {
327 return *t.sval
328 }
329 str := string(t.val)
330 t.sval = &str
331 return str
332 }
333
334 func (t *iter) Close() error {
335 t.closeCheck.Close()
336 if t.rows != nil {
337 t.rows.Close()
338 t.rows = nil
339 }
340 if t.releaseGate != nil {
341 t.releaseGate()
342 }
343 err := t.err
344 t.err = errClosed
345 return err
346 }
347
348 func (t *iter) Next() bool {
349 if t.err != nil {
350 return false
351 }
352 t.skey, t.sval = nil, nil
353 if !t.rows.Next() {
354 return false
355 }
356 t.err = t.rows.Scan(&t.key, &t.val)
357 if t.err != nil {
358 log.Printf("unexpected Scan error: %v", t.err)
359 return false
360 }
361 return true
362 }