1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
16
17
18
19 package postgres
20
21 import (
22 "database/sql"
23 "errors"
24 "fmt"
25 "regexp"
26
27 "go4.org/jsonconfig"
28 "perkeep.org/pkg/env"
29 "perkeep.org/pkg/sorted"
30 "perkeep.org/pkg/sorted/sqlkv"
31
32 _ "github.com/lib/pq"
33 )
34
35 func init() {
36 sorted.RegisterKeyValue("postgres", newKeyValueFromJSONConfig)
37 }
38
39 func newKeyValueFromJSONConfig(cfg jsonconfig.Obj) (sorted.KeyValue, error) {
40 var (
41 user = cfg.RequiredString("user")
42 database = cfg.RequiredString("database")
43 host = cfg.OptionalString("host", "localhost")
44 password = cfg.OptionalString("password", "")
45 sslmode = cfg.OptionalString("sslmode", "require")
46 )
47 if err := cfg.Validate(); err != nil {
48 return nil, err
49 }
50
51
52 conninfo := fmt.Sprintf("user=%s host=%s sslmode=%s", user, host, sslmode)
53 if password != "" {
54 conninfo += fmt.Sprintf(" password=%s", password)
55 }
56 db, err := sql.Open("postgres", conninfo)
57 if err != nil {
58 return nil, err
59 }
60 err = createDB(db, database)
61 db.Close()
62 if err != nil {
63 return nil, err
64 }
65
66
67 conninfo += fmt.Sprintf(" dbname=%s", database)
68 db, err = sql.Open("postgres", conninfo)
69 if err != nil {
70 return nil, err
71 }
72
73 for _, tableSql := range SQLCreateTables() {
74 if _, err := db.Exec(tableSql); err != nil {
75 return nil, fmt.Errorf("error creating table with %q: %v", tableSql, err)
76 }
77 }
78 for _, statement := range SQLDefineReplace() {
79 if _, err := db.Exec(statement); err != nil {
80 return nil, fmt.Errorf("error setting up replace statement with %q: %v", statement, err)
81 }
82 }
83 r, err := db.Query(fmt.Sprintf(`SELECT replaceintometa('version', '%d')`, SchemaVersion()))
84 if err != nil {
85 return nil, fmt.Errorf("error setting schema version: %v", err)
86 }
87 r.Close()
88
89 kv := &keyValue{
90 db: db,
91 KeyValue: &sqlkv.KeyValue{
92 DB: db,
93 SetFunc: altSet,
94 BatchSetFunc: altBatchSet,
95 PlaceHolderFunc: replacePlaceHolders,
96 },
97 }
98 if err := kv.ping(); err != nil {
99 return nil, fmt.Errorf("PostgreSQL db unreachable: %v", err)
100 }
101 version, err := kv.SchemaVersion()
102 if err != nil {
103 return nil, fmt.Errorf("error getting schema version (need to init database?): %v", err)
104 }
105 if version != requiredSchemaVersion {
106 if env.IsDev() {
107
108
109 return nil, fmt.Errorf("database schema version is %d; expect %d (run \"devcam server --wipe\" to wipe both your blobs and re-populate the database schema)", version, requiredSchemaVersion)
110 }
111 return nil, fmt.Errorf("database schema version is %d; expect %d (need to re-init/upgrade database?)",
112 version, requiredSchemaVersion)
113 }
114
115 return kv, nil
116 }
117
// keyValue is the sorted.KeyValue returned by newKeyValueFromJSONConfig.
// Most operations come from the embedded generic *sqlkv.KeyValue, which is
// configured with this package's PostgreSQL-specific set functions and
// placeholder rewriting.
type keyValue struct {
	*sqlkv.KeyValue
	// db is the same handle given to the embedded sqlkv.KeyValue; it is used
	// directly for queries this package issues itself, such as SchemaVersion.
	db *sql.DB
}
122
123
124
// altSet is the SetFunc handed to sqlkv: it writes key/value by calling the
// replaceinto SQL function on db (presumably an insert-or-update defined by
// the SQLDefineReplace statements — not visible here). Both the query error
// and the error from closing its result set are reported.
func altSet(db *sql.DB, key, value string) error {
	rows, err := db.Query("SELECT replaceinto($1, $2)", key, value)
	if err == nil {
		err = rows.Close()
	}
	return err
}
132
133
134
// altBatchSet is the BatchSetFunc handed to sqlkv: the transactional
// counterpart of altSet, calling the replaceinto SQL function inside tx.
// Both the query error and the error from closing its result set are
// reported.
func altBatchSet(tx *sql.Tx, key, value string) error {
	rows, err := tx.Query("SELECT replaceinto($1, $2)", key, value)
	if err == nil {
		err = rows.Close()
	}
	return err
}
142
// qmark matches a single "?" placeholder.
var qmark = regexp.MustCompile(`\?`)

// replacePlaceHolders is installed as sqlkv's PlaceHolderFunc. It rewrites
// each "?" in query into PostgreSQL's numbered form, $1, $2, ... from left
// to right, restarting at $1 on every call. Note that every "?" is
// rewritten, including any that appear inside quoted SQL text.
var replacePlaceHolders = func(query string) string {
	next := 0
	return qmark.ReplaceAllStringFunc(query, func(string) string {
		next++
		return fmt.Sprintf("$%d", next)
	})
}
154
155 func (kv *keyValue) ping() error {
156 _, err := kv.SchemaVersion()
157 return err
158 }
159
160 func (kv *keyValue) SchemaVersion() (version int, err error) {
161 err = kv.db.QueryRow("SELECT value FROM meta WHERE metakey='version'").Scan(&version)
162 return
163 }
164
// validDatabaseRegex accepts one or more ASCII letters, digits, hyphens or
// underscores, and nothing else.
var validDatabaseRegex = regexp.MustCompile(`^[a-zA-Z0-9\-_]+$`)

// validDatabaseName reports whether database is made up solely of the
// characters validDatabaseRegex allows (and is non-empty).
func validDatabaseName(database string) bool {
	ok := validDatabaseRegex.Match([]byte(database))
	return ok
}
170
171 func createDB(db *sql.DB, database string) error {
172 if database == "" {
173 return errors.New("database name can't be empty")
174 }
175
176 rows, err := db.Query(`SELECT 1 FROM pg_database WHERE datname = $1`, database)
177 if err != nil {
178 return err
179 }
180 defer rows.Close()
181 if rows.Next() {
182 return nil
183 }
184
185
186 if !validDatabaseName(database) {
187 return fmt.Errorf("Invalid postgres database name: %q", database)
188 }
189 _, err = db.Exec(fmt.Sprintf("CREATE DATABASE %s", database))
190 if err != nil {
191 return err
192 }
193 return err
194 }