From 2429b623fc1c9c4398d60fc16818381bddda80f5 Mon Sep 17 00:00:00 2001 From: Ori Newman Date: Wed, 18 Sep 2019 13:51:20 +0300 Subject: [PATCH] [NOD-327] Add --migrate cli flag to API server (#407) * [NOD-327] Add --migrate cli flag to API server * [NOD-327] Change log messages * [NOD-327] Remove `required` flag from API server RPC CLI arguments * [NOD-327] Add database version in migrations logs --- apiserver/config/config.go | 19 +++++++-- apiserver/database/database.go | 75 +++++++++++++++++++++++++--------- apiserver/main.go | 8 ++++ 3 files changed, 80 insertions(+), 22 deletions(-) diff --git a/apiserver/config/config.go b/apiserver/config/config.go index 1c17b6f3b..3710b7765 100644 --- a/apiserver/config/config.go +++ b/apiserver/config/config.go @@ -23,9 +23,9 @@ var ( // Config defines the configuration options for the API server. type Config struct { LogDir string `long:"logdir" description:"Directory to log output."` - RPCUser string `short:"u" long:"rpcuser" description:"RPC username" required:"true"` - RPCPassword string `short:"P" long:"rpcpass" default-mask:"-" description:"RPC password" required:"true"` - RPCServer string `short:"s" long:"rpcserver" description:"RPC server to connect to" required:"true"` + RPCUser string `short:"u" long:"rpcuser" description:"RPC username"` + RPCPassword string `short:"P" long:"rpcpass" default-mask:"-" description:"RPC password"` + RPCServer string `short:"s" long:"rpcserver" description:"RPC server to connect to"` RPCCert string `short:"c" long:"rpccert" description:"RPC server certificate chain for validation"` DisableTLS bool `long:"notls" description:"Disable TLS"` DBAddress string `long:"dbaddress" description:"Database address"` @@ -33,6 +33,7 @@ type Config struct { DBPassword string `long:"dbpass" description:"Database password" required:"true"` DBName string `long:"dbname" description:"Database name" required:"true"` HTTPListen string `long:"listen" description:"HTTP address to listen on (default: 0.0.0.0:8080)"` + Migrate bool `long:"migrate" description:"Migrate the database to the latest version. The server will not start when using this flag."` } // Parse parses the CLI arguments and returns a config struct. @@ -49,6 +50,18 @@ func Parse() (*Config, error) { return nil, err } + if !cfg.Migrate { + if cfg.RPCUser == "" { + return nil, errors.New("--rpcuser is required if --migrate flag is not used") + } + if cfg.RPCPassword == "" { + return nil, errors.New("--rpcpass is required if --migrate flag is not used") + } + if cfg.RPCServer == "" { + return nil, errors.New("--rpcserver is required if --migrate flag is not used") + } + } + if cfg.RPCCert == "" && !cfg.DisableTLS { return nil, errors.New("--notls has to be disabled if --cert is used") } diff --git a/apiserver/database/database.go b/apiserver/database/database.go index 55e3d2be9..1b7debebc 100644 --- a/apiserver/database/database.go +++ b/apiserver/database/database.go @@ -34,13 +34,17 @@ func (l gormLogger) Print(v ...interface{}) { // config variable. func Connect(cfg *config.Config) error { connectionString := buildConnectionString(cfg) - isCurrent, err := isCurrent(connectionString) + migrator, driver, err := openMigrator(connectionString) + if err != nil { + return err + } + isCurrent, version, err := isCurrent(migrator, driver) if err != nil { return fmt.Errorf("Error checking whether the database is current: %s", err) } if !isCurrent { - return fmt.Errorf("Database is not current. Please migrate" + - " the database and start again.") + return fmt.Errorf("Database is not current (version %d). Please migrate"+ + " the database by running the server with --migrate flag and then run it again.", version) } db, err = gorm.Open("mysql", connectionString) @@ -68,35 +72,68 @@ func buildConnectionString(cfg *config.Config) string { // isCurrent resolves whether the database is on the latest // version of the schema. -func isCurrent(connectionString string) (bool, error) { - driver, err := source.Open("file://migrations") - if err != nil { - return false, err - } - migrator, err := migrate.NewWithSourceInstance( - "migrations", driver, "mysql://"+connectionString) - if err != nil { - return false, err - } - +func isCurrent(migrator *migrate.Migrate, driver source.Driver) (bool, uint, error) { // Get the current version version, isDirty, err := migrator.Version() if err == migrate.ErrNilVersion { - return false, nil + return false, 0, nil } if err != nil { - return false, err + return false, 0, err } if isDirty { - return false, fmt.Errorf("Database is dirty") + return false, 0, fmt.Errorf("Database is dirty") } // The database is current if Next returns ErrNotExist _, err = driver.Next(version) if pathErr, ok := err.(*os.PathError); ok { if pathErr.Err == os.ErrNotExist { - return true, nil + return true, version, nil } } - return false, err + return false, version, err +} + +func openMigrator(connectionString string) (*migrate.Migrate, source.Driver, error) { + driver, err := source.Open("file://migrations") + if err != nil { + return nil, nil, err + } + migrator, err := migrate.NewWithSourceInstance( + "migrations", driver, "mysql://"+connectionString) + if err != nil { + return nil, nil, err + } + return migrator, driver, nil +} + +// Migrate database to the latest version. +func Migrate(cfg *config.Config) error { + connectionString := buildConnectionString(cfg) + migrator, driver, err := openMigrator(connectionString) + if err != nil { + return err + } + isCurrent, version, err := isCurrent(migrator, driver) + if err != nil { + return fmt.Errorf("Error checking whether the database is current: %s", err) + } + if isCurrent { + log.Infof("Database is already up-to-date (version %d)", version) + return nil + } + err = migrator.Up() + if err != nil { + return err + } + version, isDirty, err := migrator.Version() + if err != nil { + return err + } + if isDirty { + return fmt.Errorf("error migrating database: database is dirty") + } + log.Infof("Migrated database to the latest version (version %d)", version) + return nil } diff --git a/apiserver/main.go b/apiserver/main.go index 9711143d9..a6353dd5e 100644 --- a/apiserver/main.go +++ b/apiserver/main.go @@ -23,6 +23,14 @@ func main() { panic(fmt.Errorf("Error parsing command-line arguments: %s", err)) } + if cfg.Migrate { + err := database.Migrate(cfg) + if err != nil { + panic(fmt.Errorf("Error migrating database: %s", err)) + } + return + } + err = database.Connect(cfg) if err != nil { panic(fmt.Errorf("Error connecting to database: %s", err))