Home Download Docs Code Community
     1	/*
     2	Copyright 2011 The Perkeep Authors
     3	
     4	Licensed under the Apache License, Version 2.0 (the "License");
     5	you may not use this file except in compliance with the License.
     6	You may obtain a copy of the License at
     7	
     8	     http://www.apache.org/licenses/LICENSE-2.0
     9	
    10	Unless required by applicable law or agreed to in writing, software
    11	distributed under the License is distributed on an "AS IS" BASIS,
    12	WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13	See the License for the specific language governing permissions and
    14	limitations under the License.
    15	*/
    16	
    17	package main
    18	
import (
	"bytes"
	"database/sql"
	"errors"
	"flag"
	"fmt"
	"net"
	"net/url"
	"os"
	"strings"

	"perkeep.org/pkg/cmdmain"
	"perkeep.org/pkg/sorted/mongo"
	"perkeep.org/pkg/sorted/mysql"
	"perkeep.org/pkg/sorted/postgres"
	"perkeep.org/pkg/sorted/sqlite"

	_ "github.com/go-sql-driver/mysql"
	_ "github.com/lib/pq"
	"gopkg.in/mgo.v2"
)
    39	
    40	type dbinitCmd struct {
    41		user     string
    42		password string
    43		host     string
    44		dbName   string
    45		dbType   string
    46		sslMode  string // Postgres SSL mode configuration
    47	
    48		wipe bool
    49		keep bool
    50	}
    51	
    52	func init() {
    53		cmdmain.RegisterMode("dbinit", func(flags *flag.FlagSet) cmdmain.CommandRunner {
    54			cmd := new(dbinitCmd)
    55			flags.StringVar(&cmd.user, "user", "root", "Admin user.")
    56			flags.StringVar(&cmd.password, "password", "", "Admin password.")
    57			flags.StringVar(&cmd.host, "host", "localhost", "host[:port]")
    58			flags.StringVar(&cmd.dbName, "dbname", "", "Database to wipe or create. For sqlite, this is the db filename.")
    59			flags.StringVar(&cmd.dbType, "dbtype", "mysql", "Which RDMS to use; possible values: mysql, postgres, sqlite, mongo.")
    60			flags.StringVar(&cmd.sslMode, "sslmode", "require", "Configure SSL mode for postgres. Possible values: require, verify-full, disable.")
    61	
    62			flags.BoolVar(&cmd.wipe, "wipe", false, "Wipe the database and re-create it?")
    63			flags.BoolVar(&cmd.keep, "ignoreexists", false, "Do nothing if database already exists.")
    64			return cmd
    65		})
    66	}
    67	
    68	func (c *dbinitCmd) Demote() bool { return true }
    69	
    70	func (c *dbinitCmd) Describe() string {
    71		return "Set up the database for the indexer."
    72	}
    73	
    74	func (c *dbinitCmd) Usage() {
    75		fmt.Fprintf(os.Stderr, "Usage: pk [globalopts] dbinit [dbinitopts] \n")
    76	}
    77	
    78	func (c *dbinitCmd) Examples() []string {
    79		return []string{
    80			"-user root -password root -host localhost -dbname camliprod -wipe",
    81		}
    82	}
    83	
    84	func (c *dbinitCmd) RunCommand(args []string) error {
    85		if c.dbName == "" {
    86			return cmdmain.UsageError("--dbname flag required")
    87		}
    88	
    89		if c.dbType != "mysql" && c.dbType != "postgres" && c.dbType != "mongo" {
    90			if c.dbType == "sqlite" {
    91				if !WithSQLite {
    92					return ErrNoSQLite
    93				}
    94			} else {
    95				return cmdmain.UsageError(fmt.Sprintf("--dbtype flag: got %v, want %v", c.dbType, `"mysql" or "postgres" or "sqlite", or "mongo"`))
    96			}
    97		}
    98	
    99		var rootdb *sql.DB
   100		var err error
   101		switch c.dbType {
   102		case "postgres":
   103			conninfo := fmt.Sprintf("user=%s dbname=%s host=%s password=%s sslmode=%s", c.user, "postgres", c.host, c.password, c.sslMode)
   104			rootdb, err = sql.Open("postgres", conninfo)
   105		case "mysql":
   106			// need to use an empty dbname to query tables
   107			rootdb, err = sql.Open("mysql", c.mysqlDSN(""))
   108		case "sqlite":
   109			rootdb, err = sql.Open("sqlite", c.dbName)
   110		}
   111		if err != nil {
   112			exitf("Error connecting to the root %s database: %v", c.dbType, err)
   113		}
   114		defer rootdb.Close()
   115	
   116		// Validate the DSN to avoid confusion here
   117		err = rootdb.Ping()
   118		if err != nil {
   119			exitf("Error connecting to the root %s database: %v", c.dbType, err)
   120		}
   121	
   122		dbname := c.dbName
   123		exists := c.dbExists(rootdb)
   124		if exists {
   125			if c.keep {
   126				return nil
   127			}
   128			if !c.wipe {
   129				return cmdmain.UsageError(fmt.Sprintf("Database %q already exists, but --wipe not given. Stopping.", dbname))
   130			}
   131			if c.dbType == "mongo" {
   132				return c.wipeMongo()
   133			}
   134			if c.dbType != "sqlite" {
   135				do(rootdb, "DROP DATABASE "+dbname)
   136			}
   137		}
   138		switch c.dbType {
   139		case "sqlite":
   140			_, err := os.Create(dbname)
   141			if err != nil {
   142				exitf("Error creating file %v for sqlite db: %v", dbname, err)
   143			}
   144		case "mongo":
   145			return nil
   146		case "postgres":
   147			// because we want string comparison to work as on MySQL and SQLite.
   148			// in particular we want: 'foo|bar' < 'foo}' (which is not the case with an utf8 collation apparently).
   149			do(rootdb, "CREATE DATABASE "+dbname+" LC_COLLATE = 'C' TEMPLATE = template0")
   150		default:
   151			do(rootdb, "CREATE DATABASE "+dbname)
   152		}
   153	
   154		var db *sql.DB
   155		switch c.dbType {
   156		case "postgres":
   157			conninfo := fmt.Sprintf("user=%s dbname=%s host=%s password=%s sslmode=%s", c.user, dbname, c.host, c.password, c.sslMode)
   158			db, err = sql.Open("postgres", conninfo)
   159		case "sqlite":
   160			db, err = sql.Open("sqlite", dbname)
   161		default:
   162			db, err = sql.Open("mysql", c.mysqlDSN(dbname))
   163		}
   164		if err != nil {
   165			return fmt.Errorf("Connecting to the %s %s database: %v", dbname, c.dbType, err)
   166		}
   167		defer db.Close()
   168	
   169		switch c.dbType {
   170		case "postgres":
   171			for _, tableSql := range postgres.SQLCreateTables() {
   172				do(db, tableSql)
   173			}
   174			for _, statement := range postgres.SQLDefineReplace() {
   175				do(db, statement)
   176			}
   177			doQuery(db, fmt.Sprintf(`SELECT replaceintometa('version', '%d')`, postgres.SchemaVersion()))
   178		case "mysql":
   179			if err := mysql.CreateDB(db, dbname); err != nil {
   180				exitf("error in CreateDB(%s): %v", dbname, err)
   181			}
   182			for _, tableSQL := range mysql.SQLCreateTables() {
   183				do(db, tableSQL)
   184			}
   185			do(db, fmt.Sprintf(`REPLACE INTO meta VALUES ('version', '%d')`, mysql.SchemaVersion()))
   186		case "sqlite":
   187			if err := sqlite.InitDB(dbname); err != nil {
   188				exitf("error calling InitDB(%s): %v", dbname, err)
   189			}
   190		}
   191		return nil
   192	}
   193	
   194	func do(db *sql.DB, sql string) {
   195		_, err := db.Exec(sql)
   196		if err != nil {
   197			exitf("Error %q running SQL: %q", err, sql)
   198		}
   199	}
   200	
   201	func doQuery(db *sql.DB, sql string) {
   202		r, err := db.Query(sql)
   203		if err == nil {
   204			r.Close()
   205			return
   206		}
   207		exitf("Error %q running SQL: %q", err, sql)
   208	}
   209	
   210	func (c *dbinitCmd) dbExists(db *sql.DB) bool {
   211		query := "SHOW DATABASES"
   212		switch c.dbType {
   213		case "postgres":
   214			query = "SELECT datname FROM pg_database"
   215		case "mysql":
   216			query = "SHOW DATABASES"
   217		case "sqlite":
   218			// There is no point in using sql.Open because it apparently does
   219			// not return an error when the file does not exist.
   220			fi, err := os.Stat(c.dbName)
   221			return err == nil && fi.Size() > 0
   222		case "mongo":
   223			session, err := c.mongoSession()
   224			if err != nil {
   225				exitf("%v", err)
   226			}
   227			defer session.Close()
   228			n, err := session.DB(c.dbName).C(mongo.CollectionName).Find(nil).Limit(1).Count()
   229			if err != nil {
   230				exitf("%v", err)
   231			}
   232			return n != 0
   233		}
   234		rows, err := db.Query(query)
   235		check(err, query)
   236		defer rows.Close()
   237		for rows.Next() {
   238			var db string
   239			check(rows.Scan(&db), query)
   240			if db == c.dbName {
   241				return true
   242			}
   243		}
   244		return false
   245	}
   246	
   247	func check(err error, query string) {
   248		if err == nil {
   249			return
   250		}
   251		exitf("SQL error for query %q: %v", query, err)
   252	}
   253	
   254	func exitf(format string, args ...interface{}) {
   255		if !strings.HasSuffix(format, "\n") {
   256			format = format + "\n"
   257		}
   258		cmdmain.Errorf(format, args...)
   259		cmdmain.Exit(1)
   260	}
   261	
   262	var WithSQLite = false
   263	
   264	var ErrNoSQLite = errors.New("the command was not built with SQLite support. See https://code.google.com/p/camlistore/wiki/SQLite" + compileHint())
   265	
   266	func compileHint() string {
   267		return ""
   268	}
   269	
   270	// mongoSession returns an *mgo.Session or nil if c.dbtype is
   271	// not "mongo" or if there was an error.
   272	func (c *dbinitCmd) mongoSession() (*mgo.Session, error) {
   273		if c.dbType != "mongo" {
   274			return nil, nil
   275		}
   276		url := ""
   277		if c.user == "" || c.password == "" {
   278			url = c.host
   279		} else {
   280			url = c.user + ":" + c.password + "@" + c.host + "/" + c.dbName
   281		}
   282		return mgo.Dial(url)
   283	}
   284	
   285	// wipeMongo erases all documents from the mongo collection
   286	// if c.dbType is "mongo".
   287	func (c *dbinitCmd) wipeMongo() error {
   288		if c.dbType != "mongo" {
   289			return nil
   290		}
   291		session, err := c.mongoSession()
   292		if err != nil {
   293			return err
   294		}
   295		defer session.Close()
   296		if _, err := session.DB(c.dbName).C(mongo.CollectionName).RemoveAll(nil); err != nil {
   297			return err
   298		}
   299		return nil
   300	}
   301	
   302	func (c *dbinitCmd) mysqlDSN(dbname string) string {
   303		var buf bytes.Buffer
   304		fmt.Fprintf(&buf, "%s:%s@", c.user, c.password)
   305		if c.host != "localhost" {
   306			host := c.host
   307			if _, _, err := net.SplitHostPort(host); err != nil {
   308				host = net.JoinHostPort(host, "3306")
   309			}
   310			fmt.Fprintf(&buf, "tcp(%s)", host)
   311		}
   312		fmt.Fprintf(&buf, "/%s", dbname)
   313		return buf.String()
   314	}
Website layout inspired by memcached.
Content by the authors.