diff --git a/server/embed/config.go b/server/embed/config.go index b10c5dc52..7ffc85266 100644 --- a/server/embed/config.go +++ b/server/embed/config.go @@ -330,6 +330,23 @@ type Config struct { // - https://bugs.chromium.org/p/project-zero/issues/detail?id=1447#c2 // - https://github.com/transmission/transmission/pull/468 // - https://github.com/etcd-io/etcd/issues/9353 + // + // 1. If client connection is secure via HTTPS, allow any hostnames. + // 2. If client connection is not secure and "HostWhitelist" is not empty, + // only allow HTTP requests whose Host field is listed in whitelist. + // + // Note that the client origin policy is enforced whether authentication + // is enabled or not, for tighter controls. + // + // By default, "HostWhitelist" is "*", which allows any hostnames. + // Note that when specifying hostnames, loopback addresses are not added + // automatically. To allow loopback interfaces, leave it empty or set it "*", + // or add them to whitelist manually (e.g. "localhost", "127.0.0.1", etc.). + // + // CVE-2018-5702 reference: + // - https://bugs.chromium.org/p/project-zero/issues/detail?id=1447#c2 + // - https://github.com/transmission/transmission/pull/468 + // - https://github.com/etcd-io/etcd/issues/9353 HostWhitelist map[string]struct{} // UserHandlers is for registering users handlers and only used for @@ -462,6 +479,11 @@ type Config struct { // ServerFeatureGate is a server level feature gate ServerFeatureGate featuregate.FeatureGate + + // BackendType specifies the type of backend storage to use + BackendType string `json:"backend-type"` + // MySQLDSN is the Data Source Name for the MySQL backend + MySQLDSN string `json:"mysql-dsn"` } // configYAML holds the config suitable for yaml parsing @@ -586,6 +608,9 @@ func NewConfig() *Config { AutoCompactionMode: DefaultAutoCompactionMode, ServerFeatureGate: features.NewDefaultServerFeatureGate(DefaultName, nil), + + BackendType: "mysql", + MySQLDSN: "root:password@tcp(localhost:3306)/etcd", } cfg.InitialCluster = cfg.InitialClusterFromName(cfg.Name) return cfg diff --git a/server/etcdmain/config.go b/server/etcdmain/config.go index 60233445c..2650c3000 100644 --- a/server/etcdmain/config.go +++ b/server/etcdmain/config.go @@ -65,6 +65,8 @@ type config struct { configFile string printVersion bool ignored []string + BackendType string `json:"backend-type"` + MySQLDSN string `json:"mysql-dsn"` } // configFlags has the set of flags used for command line parsing a Config diff --git a/server/etcdmain/etcd.go b/server/etcdmain/etcd.go index ebb2964de..7ce9563df 100644 --- a/server/etcdmain/etcd.go +++ b/server/etcdmain/etcd.go @@ -30,6 +30,7 @@ import ( "go.etcd.io/etcd/server/v3/embed" "go.etcd.io/etcd/server/v3/etcdserver/api/v2discovery" "go.etcd.io/etcd/server/v3/etcdserver/errors" + _ "github.com/go-sql-driver/mysql" ) type dirType string diff --git a/server/go.mod b/server/go.mod index 8e3dea92f..fbb4b50dd 100644 --- a/server/go.mod +++ b/server/go.mod @@ -48,12 +48,14 @@ require ( ) require ( + filippo.io/edwards25519 v1.1.0 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/go-logr/logr v1.4.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect + github.com/go-sql-driver/mysql v1.8.1 // indirect github.com/google/uuid v1.6.0 // indirect github.com/gorilla/websocket v1.4.2 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect diff --git a/server/go.sum b/server/go.sum index ddd0a084c..771ac6b38 100644 --- a/server/go.sum +++ b/server/go.sum @@ -1,4 +1,6 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= @@ -32,6 +34,8 @@ github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= +github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= diff --git a/server/storage/backend/backend.go b/server/storage/backend/backend.go index 95f5cf96f..fe3888b2e 100644 --- a/server/storage/backend/backend.go +++ b/server/storage/backend/backend.go @@ -129,6 +129,13 @@ type backend struct { lg *zap.Logger } +type BackendType string + +const ( + BackendTypeBoltDB BackendType = "boltdb" + BackendTypeMySQL BackendType = "mysql" +) + type BackendConfig struct { // Path is the file path to the backend file. Path string @@ -149,6 +156,9 @@ type BackendConfig struct { // Hooks are getting executed during lifecycle of Backend's transactions. Hooks Hooks + + BackendType BackendType + MySQLDSN string // MySQL Data Source Name } type BackendConfigOption func(*BackendConfig) @@ -159,11 +169,25 @@ func DefaultBackendConfig(lg *zap.Logger) BackendConfig { BatchLimit: defaultBatchLimit, MmapSize: InitialMmapSize, Logger: lg, + BackendType: BackendTypeMySQL, // Change this to MySQL + MySQLDSN: "root:password@tcp(localhost:3306)/etcd", // Default MySQL DSN } } func New(bcfg BackendConfig) Backend { - return newBackend(bcfg) + switch bcfg.BackendType { + case BackendTypeBoltDB: + return newBackend(bcfg) + case BackendTypeMySQL: + be, err := newMySQLBackend(bcfg) + if err != nil { + bcfg.Logger.Panic("failed to create MySQL backend", zap.Error(err)) + } + return be + default: + bcfg.Logger.Panic("unknown backend type", zap.String("type", string(bcfg.BackendType))) + return nil + } } func WithMmapSize(size uint64) BackendConfigOption { diff --git a/server/storage/backend/backend_mysql.go b/server/storage/backend/backend_mysql.go new file mode 100644 index 000000000..96f82dc20 --- /dev/null +++ b/server/storage/backend/backend_mysql.go @@ -0,0 +1,325 @@ +package backend + +import ( + "database/sql" + "fmt" + "io" + "sync" + "time" + + _ "github.com/go-sql-driver/mysql" + "go.uber.org/zap" +) + +type mysqlBackend struct { + db *sql.DB + lg *zap.Logger + mu sync.RWMutex + stopc chan struct{} + donec chan struct{} + hooks Hooks + postLockInsideApplyHook func() +} + +func newMySQLBackend(bcfg BackendConfig) (*mysqlBackend, error) { + db, err := sql.Open("mysql", bcfg.MySQLDSN) + if err != nil { + return nil, err + } + + // Set connection pool settings + db.SetMaxOpenConns(100) + db.SetMaxIdleConns(10) + db.SetConnMaxLifetime(time.Hour) + + be := &mysqlBackend{ + db: db, + lg: bcfg.Logger, + stopc: make(chan struct{}), + donec: make(chan struct{}), + hooks: bcfg.Hooks, + } + + // Initialize tables + if err := be.initTables(); err != nil { + return nil, err + } + + return be, nil +} + +func (m *mysqlBackend) initTables() error { + _, err := m.db.Exec(` + CREATE TABLE IF NOT EXISTS kv_store ( + key VARBINARY(512) PRIMARY KEY, + value LONGBLOB, + create_revision BIGINT, + mod_revision BIGINT, + version BIGINT + ) + `) + return err +} + +func (m *mysqlBackend) BatchTx() BatchTx { + return &mysqlBatchTx{be: m} +} + +func (m *mysqlBackend) ReadTx() ReadTx { + return &mysqlReadTx{be: m} +} + +func (m *mysqlBackend) ConcurrentReadTx() ReadTx { + return &mysqlReadTx{be: m} +} + +func (m *mysqlBackend) Snapshot() Snapshot { + return &mysqlSnapshot{be: m} +} + +func (m *mysqlBackend) Hash(ignores func([]byte, []byte) bool) (uint32, error) { + // Implement hash calculation for MySQL + // This is a placeholder implementation; you should replace it with your actual logic. + return 0, fmt.Errorf("Hash not implemented for MySQL backend") +} + +func (m *mysqlBackend) Size() int64 { + var size int64 + row := m.db.QueryRow("SELECT SUM(DATA_LENGTH + INDEX_LENGTH) FROM information_schema.TABLES WHERE TABLE_SCHEMA = DATABASE()") + err := row.Scan(&size) + if err != nil { + m.lg.Error("failed to get database size", zap.Error(err)) + return 0 + } + return size +} + +func (m *mysqlBackend) SizeInUse() int64 { + return m.Size() // For MySQL, Size and SizeInUse are the same +} + +func (m *mysqlBackend) OpenReadTxN() int64 { + // MySQL doesn't have a concept of read transactions, so return 0 + return 0 +} + +func (m *mysqlBackend) Defrag() error { + // MySQL handles fragmentation internally, so this is a no-op + return nil +} + +func (m *mysqlBackend) ForceCommit() { + // MySQL commits automatically, so this is a no-op +} + +func (m *mysqlBackend) Close() error { + close(m.stopc) + <-m.donec + return m.db.Close() +} + +func (m *mysqlBackend) SetTxPostLockInsideApplyHook(hook func()) { + m.postLockInsideApplyHook = hook +} + +// mysqlBatchTx implements BatchTx interface +type mysqlBatchTx struct { + be *mysqlBackend + tx *sql.Tx +} + +func (t *mysqlBatchTx) Lock() { + t.be.mu.Lock() +} + +func (t *mysqlBatchTx) Unlock() { + t.be.mu.Unlock() +} + +func (t *mysqlBatchTx) UnsafeCreateBucket(bucket Bucket) { + // MySQL doesn't use buckets, so this is a no-op + t.be.lg.Warn("UnsafeCreateBucket called on MySQL backend", zap.String("bucket", "n/a")) +} + +func (t *mysqlBatchTx) UnsafePut(bucket Bucket, key []byte, value []byte) { + if t.tx == nil { + var err error + t.tx, err = t.be.db.Begin() + if err != nil { + t.be.lg.Error("failed to begin transaction", zap.Error(err)) + return + } + } + _, err := t.tx.Exec("INSERT INTO kv_store (key, value) VALUES (?, ?) ON DUPLICATE KEY UPDATE value = ?", key, value, value) + if err != nil { + t.be.lg.Error("failed to put key-value pair", zap.Error(err)) + } +} + +func (t *mysqlBatchTx) UnsafeSeqPut(bucket Bucket, key []byte, value []byte) { + t.UnsafePut(bucket, key, value) +} + +func (t *mysqlBatchTx) UnsafeDelete(bucket Bucket, key []byte) { + if t.tx == nil { + var err error + t.tx, err = t.be.db.Begin() + if err != nil { + t.be.lg.Error("failed to begin transaction", zap.Error(err)) + return + } + } + _, err := t.tx.Exec("DELETE FROM kv_store WHERE key = ?", key) + if err != nil { + t.be.lg.Error("failed to delete key", zap.Error(err)) + } +} + +func (t *mysqlBatchTx) UnsafeDeleteBucket(bucket Bucket) { + t.be.lg.Warn("UnsafeDeleteBucket called on MySQL backend", zap.String("bucket", "n/a")) + // No-op for MySQL as it doesn't use buckets +} + +func (t *mysqlBatchTx) UnsafeForEach(bucket Bucket, visitor func(k, v []byte) error) error { + rows, err := t.be.db.Query("SELECT key, value FROM kv_store") + if err != nil { + return err + } + defer rows.Close() + + for rows.Next() { + var k, v []byte + if err := rows.Scan(&k, &v); err != nil { + return err + } + if err := visitor(k, v); err != nil { + return err + } + } + return rows.Err() +} + +func (t *mysqlBatchTx) UnsafeRange(bucket Bucket, key, endKey []byte, limit int64) ([][]byte, [][]byte) { + var keys, values [][]byte + query := "SELECT key, value FROM kv_store WHERE key >= ? AND key < ? ORDER BY key LIMIT ?" + rows, err := t.be.db.Query(query, key, endKey, limit) + if err != nil { + t.be.lg.Error("failed to query range", zap.Error(err)) + return nil, nil + } + defer rows.Close() + for rows.Next() { + var k, v []byte + if err := rows.Scan(&k, &v); err != nil { + t.be.lg.Error("failed to scan row", zap.Error(err)) + continue + } + keys = append(keys, k) + values = append(values, v) + } + return keys, values +} + +func (t *mysqlBatchTx) Commit() { + if t.tx != nil { + err := t.tx.Commit() + if err != nil { + t.be.lg.Error("failed to commit transaction", zap.Error(err)) + } + t.tx = nil + } +} + +func (t *mysqlBatchTx) CommitAndStop() { + t.Commit() + // Additional cleanup if needed +} + +func (t *mysqlBatchTx) LockInsideApply() { + t.be.mu.Lock() + if t.be.postLockInsideApplyHook != nil { + t.be.postLockInsideApplyHook() + } +} + +func (t *mysqlBatchTx) LockOutsideApply() { + t.be.mu.Lock() +} + +// mysqlReadTx implements ReadTx interface +type mysqlReadTx struct { + be *mysqlBackend +} + +func (t *mysqlReadTx) Lock() {} +func (t *mysqlReadTx) Unlock() {} +func (t *mysqlReadTx) Reset() {} +func (t *mysqlReadTx) RLock() {} +func (t *mysqlReadTx) RUnlock() {} + +func (t *mysqlReadTx) UnsafeRange(bucket Bucket, key, endKey []byte, limit int64) ([][]byte, [][]byte) { + var keys, values [][]byte + query := "SELECT key, value FROM kv_store WHERE key >= ? AND key < ? ORDER BY key LIMIT ?" + rows, err := t.be.db.Query(query, key, endKey, limit) + if err != nil { + t.be.lg.Error("failed to query range", zap.Error(err)) + return nil, nil + } + defer rows.Close() + for rows.Next() { + var k, v []byte + if err := rows.Scan(&k, &v); err != nil { + t.be.lg.Error("failed to scan row", zap.Error(err)) + continue + } + keys = append(keys, k) + values = append(values, v) + } + return keys, values +} + +func (t *mysqlReadTx) UnsafeGet(bucket Bucket, key []byte) (value []byte, err error) { + err = t.be.db.QueryRow("SELECT value FROM kv_store WHERE key = ?", key).Scan(&value) + if err == sql.ErrNoRows { + return nil, nil + } + return +} + +func (t *mysqlReadTx) UnsafeForEach(bucket Bucket, visitor func(k, v []byte) error) error { + rows, err := t.be.db.Query("SELECT key, value FROM kv_store") + if err != nil { + return err + } + defer rows.Close() + + for rows.Next() { + var k, v []byte + if err := rows.Scan(&k, &v); err != nil { + return err + } + if err := visitor(k, v); err != nil { + return err + } + } + return rows.Err() +} + +// mysqlSnapshot implements Snapshot interface +type mysqlSnapshot struct { + be *mysqlBackend +} + +func (s *mysqlSnapshot) Close() error { + // MySQL doesn't require explicit snapshot closing + return nil +} + +func (s *mysqlSnapshot) Size() int64 { + return s.be.Size() +} + +func (s *mysqlSnapshot) WriteTo(w io.Writer) (int64, error) { + // Implement snapshot writing logic + return 0, fmt.Errorf("WriteTo not implemented for MySQL snapshot") +} diff --git a/server/storage/backend/backend_mysql_test.go b/server/storage/backend/backend_mysql_test.go new file mode 100644 index 000000000..220c41c28 --- /dev/null +++ b/server/storage/backend/backend_mysql_test.go @@ -0,0 +1,188 @@ +package backend + +import ( + "database/sql" + "os" + "testing" + "time" + + "github.com/go-sql-driver/mysql" + "go.uber.org/zap" +) + +var testMySQLDSN string + +// Define TestBucket +var TestBucket Bucket = testBucket("test") + +type testBucket string + +func (b testBucket) ID() BucketID { + return BucketID(0) // You might want to implement a proper ID system +} + +func (b testBucket) Name() []byte { + return []byte(b) +} + +func (b testBucket) String() string { + return string(b) +} + +func (b testBucket) IsSafeRangeBucket() bool { + // Implement this method based on your requirements + // For testing purposes, we'll return true + return true +} + +func init() { + // Set up the test MySQL DSN + // You might want to make this configurable via environment variables + cfg := mysql.NewConfig() + cfg.User = "root" + cfg.Passwd = "password" + cfg.DBName = "etcd_test" + cfg.ParseTime = true + cfg.Loc = time.UTC + + testMySQLDSN = cfg.FormatDSN() +} + +func setupTestMySQL(t *testing.T) *mysqlBackend { + db, err := sql.Open("mysql", testMySQLDSN) + if err != nil { + t.Fatalf("Failed to open database: %v", err) + } + + // Create test database + _, err = db.Exec("CREATE DATABASE IF NOT EXISTS etcd_test") + if err != nil { + t.Fatalf("Failed to create test database: %v", err) + } + + db.Close() + + // Create backend + lg, _ := zap.NewDevelopment() + bcfg := BackendConfig{ + Logger: lg, + MySQLDSN: testMySQLDSN, + } + + be, err := newMySQLBackend(bcfg) + if err != nil { + t.Fatalf("Failed to create MySQL backend: %v", err) + } + + return be +} + +func teardownTestMySQL(t *testing.T, be *mysqlBackend) { + be.Close() + + // Drop test database + db, err := sql.Open("mysql", testMySQLDSN) + if err != nil { + t.Fatalf("Failed to open database: %v", err) + } + defer db.Close() + + _, err = db.Exec("DROP DATABASE IF EXISTS etcd_test") + if err != nil { + t.Fatalf("Failed to drop test database: %v", err) + } +} + +func TestMySQLBackend_BatchTx(t *testing.T) { + be := setupTestMySQL(t) + defer teardownTestMySQL(t, be) + + tx := be.BatchTx() + + // Test Put and Get + bucket := TestBucket + key := []byte("testkey") + value := []byte("testvalue") + + tx.Lock() + tx.UnsafePut(bucket, key, value) + tx.Unlock() + + tx.Commit() + + rtx := be.ReadTx() + rtx.RLock() + gotValues, _ := rtx.UnsafeRange(bucket, key, nil, 0) + rtx.RUnlock() + + if len(gotValues) == 0 || string(gotValues[0]) != string(value) { + t.Errorf("Got %s, want %s", string(gotValues[0]), string(value)) + } + + // Test Delete + tx.Lock() + tx.UnsafeDelete(bucket, key) + tx.Unlock() + + tx.Commit() + + rtx.RLock() + gotValues, _ = rtx.UnsafeRange(bucket, key, nil, 0) + rtx.RUnlock() + + if len(gotValues) != 0 { + t.Errorf("Got %s, want nil", string(gotValues[0])) + } +} + +func TestMySQLBackend_ReadTx(t *testing.T) { + be := setupTestMySQL(t) + defer teardownTestMySQL(t, be) + + tx := be.BatchTx() + + // Insert test data + bucket := TestBucket + testData := map[string]string{ + "key1": "value1", + "key2": "value2", + "key3": "value3", + } + + tx.Lock() + for k, v := range testData { + tx.UnsafePut(bucket, []byte(k), []byte(v)) + } + tx.Unlock() + + tx.Commit() + + // Test Range + rtx := be.ReadTx() + rtx.RLock() + keys, values := rtx.UnsafeRange(bucket, []byte("key"), []byte("key4"), 0) + rtx.RUnlock() + + if len(keys) != len(testData) || len(values) != len(testData) { + t.Errorf("Got %d keys and %d values, want %d each", len(keys), len(values), len(testData)) + } + + for i, key := range keys { + value := values[i] + if testData[string(key)] != string(value) { + t.Errorf("For key %s, got value %s, want %s", string(key), string(value), testData[string(key)]) + } + } +} + +func TestMain(m *testing.M) { + // Set up any global test environment here if needed + + // Run the tests + code := m.Run() + + // Tear down any global test environment here if needed + + os.Exit(code) +} +