Home Download Docs Code Community
     1	/*
     2	Copyright 2012 The Perkeep Authors.
     3	
     4	Licensed under the Apache License, Version 2.0 (the "License");
     5	you may not use this file except in compliance with the License.
     6	You may obtain a copy of the License at
     7	
     8	     http://www.apache.org/licenses/LICENSE-2.0
     9	
    10	Unless required by applicable law or agreed to in writing, software
    11	distributed under the License is distributed on an "AS IS" BASIS,
    12	WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13	See the License for the specific language governing permissions and
    14	limitations under the License.
    15	*/
    16	
    17	// Package postgres provides an implementation of sorted.KeyValue
    18	// on top of PostgreSQL.
    19	package postgres // import "perkeep.org/pkg/sorted/postgres"
    20	
    21	import (
    22		"database/sql"
    23		"errors"
    24		"fmt"
    25		"regexp"
    26	
    27		"go4.org/jsonconfig"
    28		"perkeep.org/pkg/env"
    29		"perkeep.org/pkg/sorted"
    30		"perkeep.org/pkg/sorted/sqlkv"
    31	
    32		_ "github.com/lib/pq"
    33	)
    34	
    35	func init() {
    36		sorted.RegisterKeyValue("postgres", newKeyValueFromJSONConfig)
    37	}
    38	
    39	func newKeyValueFromJSONConfig(cfg jsonconfig.Obj) (sorted.KeyValue, error) {
    40		var (
    41			user     = cfg.RequiredString("user")
    42			database = cfg.RequiredString("database")
    43			host     = cfg.OptionalString("host", "localhost")
    44			password = cfg.OptionalString("password", "")
    45			sslmode  = cfg.OptionalString("sslmode", "require")
    46		)
    47		if err := cfg.Validate(); err != nil {
    48			return nil, err
    49		}
    50	
    51		// connect without a database, it may not exist yet
    52		conninfo := fmt.Sprintf("user=%s host=%s sslmode=%s", user, host, sslmode)
    53		if password != "" {
    54			conninfo += fmt.Sprintf(" password=%s", password)
    55		}
    56		db, err := sql.Open("postgres", conninfo)
    57		if err != nil {
    58			return nil, err
    59		}
    60		err = createDB(db, database)
    61		db.Close() // ignoring error, if createDB failed db.Close() will likely also fail
    62		if err != nil {
    63			return nil, err
    64		}
    65	
    66		// reconnect after database is created
    67		conninfo += fmt.Sprintf(" dbname=%s", database)
    68		db, err = sql.Open("postgres", conninfo)
    69		if err != nil {
    70			return nil, err
    71		}
    72	
    73		for _, tableSql := range SQLCreateTables() {
    74			if _, err := db.Exec(tableSql); err != nil {
    75				return nil, fmt.Errorf("error creating table with %q: %v", tableSql, err)
    76			}
    77		}
    78		for _, statement := range SQLDefineReplace() {
    79			if _, err := db.Exec(statement); err != nil {
    80				return nil, fmt.Errorf("error setting up replace statement with %q: %v", statement, err)
    81			}
    82		}
    83		r, err := db.Query(fmt.Sprintf(`SELECT replaceintometa('version', '%d')`, SchemaVersion()))
    84		if err != nil {
    85			return nil, fmt.Errorf("error setting schema version: %v", err)
    86		}
    87		r.Close()
    88	
    89		kv := &keyValue{
    90			db: db,
    91			KeyValue: &sqlkv.KeyValue{
    92				DB:              db,
    93				SetFunc:         altSet,
    94				BatchSetFunc:    altBatchSet,
    95				PlaceHolderFunc: replacePlaceHolders,
    96			},
    97		}
    98		if err := kv.ping(); err != nil {
    99			return nil, fmt.Errorf("PostgreSQL db unreachable: %v", err)
   100		}
   101		version, err := kv.SchemaVersion()
   102		if err != nil {
   103			return nil, fmt.Errorf("error getting schema version (need to init database?): %v", err)
   104		}
   105		if version != requiredSchemaVersion {
   106			if env.IsDev() {
   107				// Good signal that we're using the devcam server, so help out
   108				// the user with a more useful tip:
   109				return nil, fmt.Errorf("database schema version is %d; expect %d (run \"devcam server --wipe\" to wipe both your blobs and re-populate the database schema)", version, requiredSchemaVersion)
   110			}
   111			return nil, fmt.Errorf("database schema version is %d; expect %d (need to re-init/upgrade database?)",
   112				version, requiredSchemaVersion)
   113		}
   114	
   115		return kv, nil
   116	}
   117	
// keyValue is the sorted.KeyValue returned by newKeyValueFromJSONConfig.
// The generic SQL operations come from the embedded sqlkv.KeyValue, and db is
// the same *sql.DB that newKeyValueFromJSONConfig also stores in the embedded
// KeyValue's DB field.
type keyValue struct {
	*sqlkv.KeyValue
	db *sql.DB // queried directly by SchemaVersion
}
   122	
   123	// postgres does not have REPLACE INTO (upsert), so we use that custom
   124	// one for Set operations instead
   125	func altSet(db *sql.DB, key, value string) error {
   126		r, err := db.Query("SELECT replaceinto($1, $2)", key, value)
   127		if err != nil {
   128			return err
   129		}
   130		return r.Close()
   131	}
   132	
   133	// postgres does not have REPLACE INTO (upsert), so we use that custom
   134	// one for Set operations in batch instead
   135	func altBatchSet(tx *sql.Tx, key, value string) error {
   136		r, err := tx.Query("SELECT replaceinto($1, $2)", key, value)
   137		if err != nil {
   138			return err
   139		}
   140		return r.Close()
   141	}
   142	
   143	var qmark = regexp.MustCompile(`\?`)
   144	
   145	// replace all ? placeholders into the corresponding $n in queries
   146	var replacePlaceHolders = func(query string) string {
   147		i := 0
   148		dollarInc := func(b []byte) []byte {
   149			i++
   150			return []byte(fmt.Sprintf("$%d", i))
   151		}
   152		return string(qmark.ReplaceAllFunc([]byte(query), dollarInc))
   153	}
   154	
   155	func (kv *keyValue) ping() error {
   156		_, err := kv.SchemaVersion()
   157		return err
   158	}
   159	
   160	func (kv *keyValue) SchemaVersion() (version int, err error) {
   161		err = kv.db.QueryRow("SELECT value FROM meta WHERE metakey='version'").Scan(&version)
   162		return
   163	}
   164	
   165	var validDatabaseRegex = regexp.MustCompile(`^[a-zA-Z0-9\-_]+$`)
   166	
   167	func validDatabaseName(database string) bool {
   168		return validDatabaseRegex.MatchString(database)
   169	}
   170	
   171	func createDB(db *sql.DB, database string) error {
   172		if database == "" {
   173			return errors.New("database name can't be empty")
   174		}
   175	
   176		rows, err := db.Query(`SELECT 1 FROM pg_database WHERE datname = $1`, database)
   177		if err != nil {
   178			return err
   179		}
   180		defer rows.Close()
   181		if rows.Next() {
   182			return nil // database is already created
   183		}
   184	
   185		// Verify database only has runes we expect
   186		if !validDatabaseName(database) {
   187			return fmt.Errorf("Invalid postgres database name: %q", database)
   188		}
   189		_, err = db.Exec(fmt.Sprintf("CREATE DATABASE %s", database))
   190		if err != nil {
   191			return err
   192		}
   193		return err
   194	}
Website layout inspired by memcached.
Content by the authors.