diff --git a/etcd_test.go b/etcd_test.go index cb43a2a4f..786c6a745 100644 --- a/etcd_test.go +++ b/etcd_test.go @@ -6,6 +6,8 @@ import ( "github.com/coreos/go-etcd/etcd" "math/rand" "net/http" + "net/http/httptest" + "net/url" "os" "strconv" "strings" @@ -54,6 +56,53 @@ func TestSingleNode(t *testing.T) { } } +// TestInternalVersionFail will ensure that etcd does not come up if the internal raft +// versions do not match. +func TestInternalVersionFail(t *testing.T) { + checkedVersion := false + testMux := http.NewServeMux() + + testMux.HandleFunc("/version", func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "This is not a version number") + checkedVersion = true + }) + + testMux.HandleFunc("/join", func(w http.ResponseWriter, r *http.Request) { + t.Fatal("should not attempt to join!") + }) + + ts := httptest.NewServer(testMux) + defer ts.Close() + + fakeURL, _ := url.Parse(ts.URL) + + procAttr := new(os.ProcAttr) + procAttr.Files = []*os.File{nil, os.Stdout, os.Stderr} + args := []string{"etcd", "-n=node1", "-f", "-d=/tmp/node1", "-vv", "-C="+fakeURL.Host} + + process, err := os.StartProcess("etcd", args, procAttr) + if err != nil { + t.Fatal("start process failed:" + err.Error()) + return + } + defer process.Kill() + + time.Sleep(time.Second) + + _, err = http.Get("http://127.0.0.1:4001") + + if err == nil { + t.Fatal("etcd node should not be up") + return + } + + if checkedVersion == false { + t.Fatal("etcd did not check the version") + return + } +} + + // This test creates a single node and then set a value to it. // Then this test kills the node and restart it and tries to get the value again. func TestSingleNodeRecovery(t *testing.T) { diff --git a/raft_server.go b/raft_server.go index ec4a3feff..800da0152 100644 --- a/raft_server.go +++ b/raft_server.go @@ -4,6 +4,7 @@ import ( "bytes" "crypto/tls" "encoding/json" + "io/ioutil" "fmt" etcdErr "github.com/coreos/etcd/error" "github.com/coreos/go-raft" @@ -163,15 +164,41 @@ func (r *raftServer) startTransport(scheme string, tlsConf tls.Config) { } +func getLeaderVersion(t transporter, versionURL url.URL) (string, error) { + resp, err := t.Get(versionURL.String()) + + if err != nil { + return "", err + } + + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + + return string(body), nil +} + // Send join requests to the leader. func joinCluster(s *raft.Server, raftURL string, scheme string) error { var b bytes.Buffer - json.NewEncoder(&b).Encode(newJoinCommand()) - // t must be ok t, _ := r.Transporter().(transporter) + // Our version must match the leaders version + versionURL := url.URL{Host: raftURL, Scheme: scheme, Path: "/version"} + version, err := getLeaderVersion(t, versionURL) + if err != nil { + return fmt.Errorf("Unable to join: %v", err) + } + + // TODO: versioning of the internal protocol. See: + // Documentation/internatl-protocol-versioning.md + if version != r.version { + return fmt.Errorf("Unable to join: internal version mismatch, entire cluster must be running identical versions of etcd") + } + + json.NewEncoder(&b).Encode(newJoinCommand()) + joinURL := url.URL{Host: raftURL, Scheme: scheme, Path: "/join"} debugf("Send Join Request to %s", raftURL)