rafthttp: version enforcement on rafthttp messages

This PR sets etcd version and min cluster version in request header,
and let server check version compatibility. rafthttp server
will reject any message from peer with incompatible version(too low
version or too high version), and print out warning logs.
This commit is contained in:
Yicheng Qin 2015-06-02 11:46:36 -07:00
parent 8825af47a0
commit c371d8c65c
9 changed files with 290 additions and 122 deletions

View File

@ -15,6 +15,7 @@
package rafthttp package rafthttp
import ( import (
"errors"
"io/ioutil" "io/ioutil"
"log" "log"
"net/http" "net/http"
@ -34,6 +35,9 @@ const (
var ( var (
RaftPrefix = "/raft" RaftPrefix = "/raft"
RaftStreamPrefix = path.Join(RaftPrefix, "stream") RaftStreamPrefix = path.Join(RaftPrefix, "stream")
errIncompatibleVersion = errors.New("incompatible version")
errClusterIDMismatch = errors.New("cluster ID mismatch")
) )
func NewHandler(r Raft, cid types.ID) http.Handler { func NewHandler(r Raft, cid types.ID) http.Handler {
@ -72,13 +76,19 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return return
} }
if err := checkVersionCompability(r.Header.Get("X-Server-From"), serverVersion(r.Header), minClusterVersion(r.Header)); err != nil {
log.Printf("rafthttp: request received was ignored (%v)", err)
http.Error(w, errIncompatibleVersion.Error(), http.StatusPreconditionFailed)
return
}
wcid := h.cid.String() wcid := h.cid.String()
w.Header().Set("X-Etcd-Cluster-ID", wcid) w.Header().Set("X-Etcd-Cluster-ID", wcid)
gcid := r.Header.Get("X-Etcd-Cluster-ID") gcid := r.Header.Get("X-Etcd-Cluster-ID")
if gcid != wcid { if gcid != wcid {
log.Printf("rafthttp: request ignored due to cluster ID mismatch got %s want %s", gcid, wcid) log.Printf("rafthttp: request ignored due to cluster ID mismatch got %s want %s", gcid, wcid)
http.Error(w, "clusterID mismatch", http.StatusPreconditionFailed) http.Error(w, errClusterIDMismatch.Error(), http.StatusPreconditionFailed)
return return
} }
@ -126,17 +136,23 @@ func (h *streamHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return return
} }
w.Header().Set("X-Server-Version", version.Version)
if err := checkVersionCompability(r.Header.Get("X-Server-From"), serverVersion(r.Header), minClusterVersion(r.Header)); err != nil {
log.Printf("rafthttp: request received was ignored (%v)", err)
http.Error(w, errIncompatibleVersion.Error(), http.StatusPreconditionFailed)
return
}
wcid := h.cid.String() wcid := h.cid.String()
w.Header().Set("X-Etcd-Cluster-ID", wcid) w.Header().Set("X-Etcd-Cluster-ID", wcid)
if gcid := r.Header.Get("X-Etcd-Cluster-ID"); gcid != wcid { if gcid := r.Header.Get("X-Etcd-Cluster-ID"); gcid != wcid {
log.Printf("rafthttp: streaming request ignored due to cluster ID mismatch got %s want %s", gcid, wcid) log.Printf("rafthttp: streaming request ignored due to cluster ID mismatch got %s want %s", gcid, wcid)
http.Error(w, "clusterID mismatch", http.StatusPreconditionFailed) http.Error(w, errClusterIDMismatch.Error(), http.StatusPreconditionFailed)
return return
} }
w.Header().Add("X-Server-Version", version.Version)
var t streamType var t streamType
switch path.Dir(r.URL.Path) { switch path.Dir(r.URL.Path) {
// backward compatibility // backward compatibility

View File

@ -124,7 +124,7 @@ func startPeer(tr http.RoundTripper, urls types.URLs, local, to, cid types.ID, r
r: r, r: r,
msgAppWriter: startStreamWriter(to, fs, r), msgAppWriter: startStreamWriter(to, fs, r),
writer: startStreamWriter(to, fs, r), writer: startStreamWriter(to, fs, r),
pipeline: newPipeline(tr, picker, to, cid, fs, r, errorc), pipeline: newPipeline(tr, picker, local, to, cid, fs, r, errorc),
sendc: make(chan raftpb.Message), sendc: make(chan raftpb.Message),
recvc: make(chan raftpb.Message, recvBufSize), recvc: make(chan raftpb.Message, recvBufSize),
propc: make(chan raftpb.Message, maxPendingProposals), propc: make(chan raftpb.Message, maxPendingProposals),

View File

@ -17,8 +17,10 @@ package rafthttp
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"io/ioutil"
"log" "log"
"net/http" "net/http"
"strings"
"sync" "sync"
"time" "time"
@ -27,6 +29,7 @@ import (
"github.com/coreos/etcd/pkg/types" "github.com/coreos/etcd/pkg/types"
"github.com/coreos/etcd/raft" "github.com/coreos/etcd/raft"
"github.com/coreos/etcd/raft/raftpb" "github.com/coreos/etcd/raft/raftpb"
"github.com/coreos/etcd/version"
) )
const ( const (
@ -39,7 +42,7 @@ const (
) )
type pipeline struct { type pipeline struct {
id types.ID from, to types.ID
cid types.ID cid types.ID
tr http.RoundTripper tr http.RoundTripper
@ -58,9 +61,10 @@ type pipeline struct {
errored error errored error
} }
func newPipeline(tr http.RoundTripper, picker *urlPicker, id, cid types.ID, fs *stats.FollowerStats, r Raft, errorc chan error) *pipeline { func newPipeline(tr http.RoundTripper, picker *urlPicker, from, to, cid types.ID, fs *stats.FollowerStats, r Raft, errorc chan error) *pipeline {
p := &pipeline{ p := &pipeline{
id: id, from: from,
to: to,
cid: cid, cid: cid,
tr: tr, tr: tr,
picker: picker, picker: picker,
@ -94,11 +98,11 @@ func (p *pipeline) handle() {
reportSentFailure(pipelineMsg, m) reportSentFailure(pipelineMsg, m)
if p.errored == nil || p.errored.Error() != err.Error() { if p.errored == nil || p.errored.Error() != err.Error() {
log.Printf("pipeline: error posting to %s: %v", p.id, err) log.Printf("pipeline: error posting to %s: %v", p.to, err)
p.errored = err p.errored = err
} }
if p.active { if p.active {
log.Printf("pipeline: the connection with %s became inactive", p.id) log.Printf("pipeline: the connection with %s became inactive", p.to)
p.active = false p.active = false
} }
if m.Type == raftpb.MsgApp && p.fs != nil { if m.Type == raftpb.MsgApp && p.fs != nil {
@ -110,7 +114,7 @@ func (p *pipeline) handle() {
} }
} else { } else {
if !p.active { if !p.active {
log.Printf("pipeline: the connection with %s became active", p.id) log.Printf("pipeline: the connection with %s became active", p.to)
p.active = true p.active = true
p.errored = nil p.errored = nil
} }
@ -138,19 +142,35 @@ func (p *pipeline) post(data []byte) error {
return err return err
} }
req.Header.Set("Content-Type", "application/protobuf") req.Header.Set("Content-Type", "application/protobuf")
req.Header.Set("X-Server-From", p.from.String())
req.Header.Set("X-Server-Version", version.Version)
req.Header.Set("X-Min-Cluster-Version", version.MinClusterVersion)
req.Header.Set("X-Etcd-Cluster-ID", p.cid.String()) req.Header.Set("X-Etcd-Cluster-ID", p.cid.String())
resp, err := p.tr.RoundTrip(req) resp, err := p.tr.RoundTrip(req)
if err != nil { if err != nil {
p.picker.unreachable(u) p.picker.unreachable(u)
return err return err
} }
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
p.picker.unreachable(u)
return err
}
resp.Body.Close() resp.Body.Close()
switch resp.StatusCode { switch resp.StatusCode {
case http.StatusPreconditionFailed: case http.StatusPreconditionFailed:
log.Printf("rafthttp: request sent was ignored due to cluster ID mismatch (remote[%s]:%s, local:%s)", switch strings.TrimSuffix(string(b), "\n") {
uu.Host, resp.Header.Get("X-Etcd-Cluster-ID"), p.cid) case errIncompatibleVersion.Error():
return fmt.Errorf("cluster ID mismatch") log.Printf("rafthttp: request sent was ignored by peer %s (server version incompatible)", p.to)
return errIncompatibleVersion
case errClusterIDMismatch.Error():
log.Printf("rafthttp: request sent was ignored (cluster ID mismatch: remote[%s]=%s, local=%s)",
p.to, resp.Header.Get("X-Etcd-Cluster-ID"), p.cid)
return errClusterIDMismatch
default:
return fmt.Errorf("unhandled error %q when precondition failed", string(b))
}
case http.StatusForbidden: case http.StatusForbidden:
err := fmt.Errorf("the member has been permanently removed from the cluster") err := fmt.Errorf("the member has been permanently removed from the cluster")
select { select {

View File

@ -16,6 +16,7 @@ package rafthttp
import ( import (
"errors" "errors"
"io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"sync" "sync"
@ -25,6 +26,7 @@ import (
"github.com/coreos/etcd/pkg/testutil" "github.com/coreos/etcd/pkg/testutil"
"github.com/coreos/etcd/pkg/types" "github.com/coreos/etcd/pkg/types"
"github.com/coreos/etcd/raft/raftpb" "github.com/coreos/etcd/raft/raftpb"
"github.com/coreos/etcd/version"
) )
// TestPipelineSend tests that pipeline could send data using roundtripper // TestPipelineSend tests that pipeline could send data using roundtripper
@ -33,7 +35,7 @@ func TestPipelineSend(t *testing.T) {
tr := &roundTripperRecorder{} tr := &roundTripperRecorder{}
picker := mustNewURLPicker(t, []string{"http://localhost:2380"}) picker := mustNewURLPicker(t, []string{"http://localhost:2380"})
fs := &stats.FollowerStats{} fs := &stats.FollowerStats{}
p := newPipeline(tr, picker, types.ID(1), types.ID(1), fs, &fakeRaft{}, nil) p := newPipeline(tr, picker, types.ID(2), types.ID(1), types.ID(1), fs, &fakeRaft{}, nil)
p.msgc <- raftpb.Message{Type: raftpb.MsgApp} p.msgc <- raftpb.Message{Type: raftpb.MsgApp}
p.stop() p.stop()
@ -52,7 +54,7 @@ func TestPipelineExceedMaximalServing(t *testing.T) {
tr := newRoundTripperBlocker() tr := newRoundTripperBlocker()
picker := mustNewURLPicker(t, []string{"http://localhost:2380"}) picker := mustNewURLPicker(t, []string{"http://localhost:2380"})
fs := &stats.FollowerStats{} fs := &stats.FollowerStats{}
p := newPipeline(tr, picker, types.ID(1), types.ID(1), fs, &fakeRaft{}, nil) p := newPipeline(tr, picker, types.ID(2), types.ID(1), types.ID(1), fs, &fakeRaft{}, nil)
// keep the sender busy and make the buffer full // keep the sender busy and make the buffer full
// nothing can go out as we block the sender // nothing can go out as we block the sender
@ -92,7 +94,7 @@ func TestPipelineExceedMaximalServing(t *testing.T) {
func TestPipelineSendFailed(t *testing.T) { func TestPipelineSendFailed(t *testing.T) {
picker := mustNewURLPicker(t, []string{"http://localhost:2380"}) picker := mustNewURLPicker(t, []string{"http://localhost:2380"})
fs := &stats.FollowerStats{} fs := &stats.FollowerStats{}
p := newPipeline(newRespRoundTripper(0, errors.New("blah")), picker, types.ID(1), types.ID(1), fs, &fakeRaft{}, nil) p := newPipeline(newRespRoundTripper(0, errors.New("blah")), picker, types.ID(2), types.ID(1), types.ID(1), fs, &fakeRaft{}, nil)
p.msgc <- raftpb.Message{Type: raftpb.MsgApp} p.msgc <- raftpb.Message{Type: raftpb.MsgApp}
p.stop() p.stop()
@ -107,7 +109,7 @@ func TestPipelineSendFailed(t *testing.T) {
func TestPipelinePost(t *testing.T) { func TestPipelinePost(t *testing.T) {
tr := &roundTripperRecorder{} tr := &roundTripperRecorder{}
picker := mustNewURLPicker(t, []string{"http://localhost:2380"}) picker := mustNewURLPicker(t, []string{"http://localhost:2380"})
p := newPipeline(tr, picker, types.ID(1), types.ID(1), nil, &fakeRaft{}, nil) p := newPipeline(tr, picker, types.ID(2), types.ID(1), types.ID(1), nil, &fakeRaft{}, nil)
if err := p.post([]byte("some data")); err != nil { if err := p.post([]byte("some data")); err != nil {
t.Fatalf("unexpect post error: %v", err) t.Fatalf("unexpect post error: %v", err)
} }
@ -122,6 +124,12 @@ func TestPipelinePost(t *testing.T) {
if g := tr.Request().Header.Get("Content-Type"); g != "application/protobuf" { if g := tr.Request().Header.Get("Content-Type"); g != "application/protobuf" {
t.Errorf("content type = %s, want %s", g, "application/protobuf") t.Errorf("content type = %s, want %s", g, "application/protobuf")
} }
if g := tr.Request().Header.Get("X-Server-Version"); g != version.Version {
t.Errorf("version = %s, want %s", g, version.Version)
}
if g := tr.Request().Header.Get("X-Min-Cluster-Version"); g != version.MinClusterVersion {
t.Errorf("min version = %s, want %s", g, version.MinClusterVersion)
}
if g := tr.Request().Header.Get("X-Etcd-Cluster-ID"); g != "1" { if g := tr.Request().Header.Get("X-Etcd-Cluster-ID"); g != "1" {
t.Errorf("cluster id = %s, want %s", g, "1") t.Errorf("cluster id = %s, want %s", g, "1")
} }
@ -148,7 +156,7 @@ func TestPipelinePostBad(t *testing.T) {
} }
for i, tt := range tests { for i, tt := range tests {
picker := mustNewURLPicker(t, []string{tt.u}) picker := mustNewURLPicker(t, []string{tt.u})
p := newPipeline(newRespRoundTripper(tt.code, tt.err), picker, types.ID(1), types.ID(1), nil, &fakeRaft{}, make(chan error)) p := newPipeline(newRespRoundTripper(tt.code, tt.err), picker, types.ID(2), types.ID(1), types.ID(1), nil, &fakeRaft{}, make(chan error))
err := p.post([]byte("some data")) err := p.post([]byte("some data"))
p.stop() p.stop()
@ -169,7 +177,7 @@ func TestPipelinePostErrorc(t *testing.T) {
for i, tt := range tests { for i, tt := range tests {
picker := mustNewURLPicker(t, []string{tt.u}) picker := mustNewURLPicker(t, []string{tt.u})
errorc := make(chan error, 1) errorc := make(chan error, 1)
p := newPipeline(newRespRoundTripper(tt.code, tt.err), picker, types.ID(1), types.ID(1), nil, &fakeRaft{}, errorc) p := newPipeline(newRespRoundTripper(tt.code, tt.err), picker, types.ID(2), types.ID(1), types.ID(1), nil, &fakeRaft{}, errorc)
p.post([]byte("some data")) p.post([]byte("some data"))
p.stop() p.stop()
select { select {
@ -227,5 +235,5 @@ func (t *roundTripperRecorder) Request() *http.Request {
type nopReadCloser struct{} type nopReadCloser struct{}
func (n *nopReadCloser) Read(p []byte) (int, error) { return 0, nil } func (n *nopReadCloser) Read(p []byte) (int, error) { return 0, io.EOF }
func (n *nopReadCloser) Close() error { return nil } func (n *nopReadCloser) Close() error { return nil }

View File

@ -31,7 +31,7 @@ func startRemote(tr http.RoundTripper, urls types.URLs, local, to, cid types.ID,
picker := newURLPicker(urls) picker := newURLPicker(urls)
return &remote{ return &remote{
id: to, id: to,
pipeline: newPipeline(tr, picker, to, cid, nil, r, errorc), pipeline: newPipeline(tr, picker, local, to, cid, nil, r, errorc),
} }
} }

View File

@ -17,11 +17,13 @@ package rafthttp
import ( import (
"fmt" "fmt"
"io" "io"
"io/ioutil"
"log" "log"
"net" "net"
"net/http" "net/http"
"path" "path"
"strconv" "strconv"
"strings"
"sync" "sync"
"time" "time"
@ -389,6 +391,9 @@ func (cr *streamReader) dial(t streamType) (io.ReadCloser, error) {
cr.picker.unreachable(u) cr.picker.unreachable(u)
return nil, fmt.Errorf("new request to %s error: %v", u, err) return nil, fmt.Errorf("new request to %s error: %v", u, err)
} }
req.Header.Set("X-Server-From", cr.from.String())
req.Header.Set("X-Server-Version", version.Version)
req.Header.Set("X-Min-Cluster-Version", version.MinClusterVersion)
req.Header.Set("X-Etcd-Cluster-ID", cr.cid.String()) req.Header.Set("X-Etcd-Cluster-ID", cr.cid.String())
req.Header.Set("X-Raft-To", cr.to.String()) req.Header.Set("X-Raft-To", cr.to.String())
if t == streamTypeMsgApp { if t == streamTypeMsgApp {
@ -425,10 +430,24 @@ func (cr *streamReader) dial(t streamType) (io.ReadCloser, error) {
resp.Body.Close() resp.Body.Close()
return nil, fmt.Errorf("local member has not been added to the peer list of member %s", cr.to) return nil, fmt.Errorf("local member has not been added to the peer list of member %s", cr.to)
case http.StatusPreconditionFailed: case http.StatusPreconditionFailed:
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
cr.picker.unreachable(u)
return nil, err
}
resp.Body.Close() resp.Body.Close()
log.Printf("rafthttp: request sent was ignored due to cluster ID mismatch (remote[%s]:%s, local:%s)",
uu.Host, resp.Header.Get("X-Etcd-Cluster-ID"), cr.cid) switch strings.TrimSuffix(string(b), "\n") {
return nil, fmt.Errorf("cluster ID mismatch") case errIncompatibleVersion.Error():
log.Printf("rafthttp: request sent was ignored by peer %s (server version incompatible)", cr.to)
return nil, errIncompatibleVersion
case errClusterIDMismatch.Error():
log.Printf("rafthttp: request sent was ignored (cluster ID mismatch: remote[%s]=%s, local=%s)",
cr.to, resp.Header.Get("X-Etcd-Cluster-ID"), cr.cid)
return nil, errClusterIDMismatch
default:
return nil, fmt.Errorf("unhandled error %q when precondition failed", string(b))
}
default: default:
resp.Body.Close() resp.Body.Close()
return nil, fmt.Errorf("unhandled http status %d", resp.StatusCode) return nil, fmt.Errorf("unhandled http status %d", resp.StatusCode)
@ -457,32 +476,6 @@ func isClosedConnectionError(err error) bool {
return ok && operr.Err.Error() == "use of closed network connection" return ok && operr.Err.Error() == "use of closed network connection"
} }
// serverVersion returns the version from the given header.
func serverVersion(h http.Header) *semver.Version {
verStr := h.Get("X-Server-Version")
// backward compatibility with etcd 2.0
if verStr == "" {
verStr = "2.0.0"
}
return semver.Must(semver.NewVersion(verStr))
}
// compareMajorMinorVersion returns an integer comparing two versions based on
// their major and minor version. The result will be 0 if a==b, -1 if a < b,
// and 1 if a > b.
func compareMajorMinorVersion(a, b *semver.Version) int {
na := &semver.Version{Major: a.Major, Minor: a.Minor}
nb := &semver.Version{Major: b.Major, Minor: b.Minor}
switch {
case na.LessThan(*nb):
return -1
case nb.LessThan(*na):
return 1
default:
return 0
}
}
// checkStreamSupport checks whether the stream type is supported in the // checkStreamSupport checks whether the stream type is supported in the
// given version. // given version.
func checkStreamSupport(v *semver.Version, t streamType) bool { func checkStreamSupport(v *semver.Version, t streamType) bool {

View File

@ -302,76 +302,6 @@ func TestStream(t *testing.T) {
} }
} }
func TestServerVersion(t *testing.T) {
tests := []struct {
h http.Header
wv *semver.Version
}{
// backward compatibility with etcd 2.0
{
http.Header{},
semver.Must(semver.NewVersion("2.0.0")),
},
{
http.Header{"X-Server-Version": []string{"2.1.0"}},
semver.Must(semver.NewVersion("2.1.0")),
},
{
http.Header{"X-Server-Version": []string{"2.1.0-alpha.0+git"}},
semver.Must(semver.NewVersion("2.1.0-alpha.0+git")),
},
}
for i, tt := range tests {
v := serverVersion(tt.h)
if v.String() != tt.wv.String() {
t.Errorf("#%d: version = %s, want %s", i, v, tt.wv)
}
}
}
func TestCompareMajorMinorVersion(t *testing.T) {
tests := []struct {
va, vb *semver.Version
w int
}{
// equal to
{
semver.Must(semver.NewVersion("2.1.0")),
semver.Must(semver.NewVersion("2.1.0")),
0,
},
// smaller than
{
semver.Must(semver.NewVersion("2.0.0")),
semver.Must(semver.NewVersion("2.1.0")),
-1,
},
// bigger than
{
semver.Must(semver.NewVersion("2.2.0")),
semver.Must(semver.NewVersion("2.1.0")),
1,
},
// ignore patch
{
semver.Must(semver.NewVersion("2.1.1")),
semver.Must(semver.NewVersion("2.1.0")),
0,
},
// ignore prerelease
{
semver.Must(semver.NewVersion("2.1.0-alpha.0")),
semver.Must(semver.NewVersion("2.1.0")),
0,
},
}
for i, tt := range tests {
if g := compareMajorMinorVersion(tt.va, tt.vb); g != tt.w {
t.Errorf("#%d: compare = %d, want %d", i, g, tt.w)
}
}
}
func TestCheckStreamSupport(t *testing.T) { func TestCheckStreamSupport(t *testing.T) {
tests := []struct { tests := []struct {
v *semver.Version v *semver.Version

View File

@ -16,9 +16,13 @@ package rafthttp
import ( import (
"encoding/binary" "encoding/binary"
"fmt"
"io" "io"
"net/http"
"github.com/coreos/etcd/Godeps/_workspace/src/github.com/coreos/go-semver/semver"
"github.com/coreos/etcd/raft/raftpb" "github.com/coreos/etcd/raft/raftpb"
"github.com/coreos/etcd/version"
) )
func writeEntryTo(w io.Writer, ent *raftpb.Entry) error { func writeEntryTo(w io.Writer, ent *raftpb.Entry) error {
@ -45,3 +49,53 @@ func readEntryFrom(r io.Reader, ent *raftpb.Entry) error {
} }
return ent.Unmarshal(buf) return ent.Unmarshal(buf)
} }
// compareMajorMinorVersion returns an integer comparing two versions based on
// their major and minor version. The result will be 0 if a==b, -1 if a < b,
// and 1 if a > b.
func compareMajorMinorVersion(a, b *semver.Version) int {
na := &semver.Version{Major: a.Major, Minor: a.Minor}
nb := &semver.Version{Major: b.Major, Minor: b.Minor}
switch {
case na.LessThan(*nb):
return -1
case nb.LessThan(*na):
return 1
default:
return 0
}
}
// serverVersion returns the server version from the given header.
func serverVersion(h http.Header) *semver.Version {
verStr := h.Get("X-Server-Version")
// backward compatibility with etcd 2.0
if verStr == "" {
verStr = "2.0.0"
}
return semver.Must(semver.NewVersion(verStr))
}
// serverVersion returns the min cluster version from the given header.
func minClusterVersion(h http.Header) *semver.Version {
verStr := h.Get("X-Min-Cluster-Version")
// backward compatibility with etcd 2.0
if verStr == "" {
verStr = "2.0.0"
}
return semver.Must(semver.NewVersion(verStr))
}
// checkVersionCompability checks whether the given version is compatible
// with the local version.
func checkVersionCompability(name string, server, minCluster *semver.Version) error {
localServer := semver.Must(semver.NewVersion(version.Version))
localMinCluster := semver.Must(semver.NewVersion(version.MinClusterVersion))
if compareMajorMinorVersion(server, localMinCluster) == -1 {
return fmt.Errorf("remote version is too low: remote[%s]=%s, local=%s", name, server, localServer)
}
if compareMajorMinorVersion(minCluster, localServer) == 1 {
return fmt.Errorf("local version is too low: remote[%s]=%s, local=%s", name, server, localServer)
}
return nil
}

View File

@ -16,10 +16,13 @@ package rafthttp
import ( import (
"bytes" "bytes"
"net/http"
"reflect" "reflect"
"testing" "testing"
"github.com/coreos/etcd/Godeps/_workspace/src/github.com/coreos/go-semver/semver"
"github.com/coreos/etcd/raft/raftpb" "github.com/coreos/etcd/raft/raftpb"
"github.com/coreos/etcd/version"
) )
func TestEntry(t *testing.T) { func TestEntry(t *testing.T) {
@ -44,3 +47,147 @@ func TestEntry(t *testing.T) {
} }
} }
} }
func TestCompareMajorMinorVersion(t *testing.T) {
tests := []struct {
va, vb *semver.Version
w int
}{
// equal to
{
semver.Must(semver.NewVersion("2.1.0")),
semver.Must(semver.NewVersion("2.1.0")),
0,
},
// smaller than
{
semver.Must(semver.NewVersion("2.0.0")),
semver.Must(semver.NewVersion("2.1.0")),
-1,
},
// bigger than
{
semver.Must(semver.NewVersion("2.2.0")),
semver.Must(semver.NewVersion("2.1.0")),
1,
},
// ignore patch
{
semver.Must(semver.NewVersion("2.1.1")),
semver.Must(semver.NewVersion("2.1.0")),
0,
},
// ignore prerelease
{
semver.Must(semver.NewVersion("2.1.0-alpha.0")),
semver.Must(semver.NewVersion("2.1.0")),
0,
},
}
for i, tt := range tests {
if g := compareMajorMinorVersion(tt.va, tt.vb); g != tt.w {
t.Errorf("#%d: compare = %d, want %d", i, g, tt.w)
}
}
}
func TestServerVersion(t *testing.T) {
tests := []struct {
h http.Header
wv *semver.Version
}{
// backward compatibility with etcd 2.0
{
http.Header{},
semver.Must(semver.NewVersion("2.0.0")),
},
{
http.Header{"X-Server-Version": []string{"2.1.0"}},
semver.Must(semver.NewVersion("2.1.0")),
},
{
http.Header{"X-Server-Version": []string{"2.1.0-alpha.0+git"}},
semver.Must(semver.NewVersion("2.1.0-alpha.0+git")),
},
}
for i, tt := range tests {
v := serverVersion(tt.h)
if v.String() != tt.wv.String() {
t.Errorf("#%d: version = %s, want %s", i, v, tt.wv)
}
}
}
func TestMinClusterVersion(t *testing.T) {
tests := []struct {
h http.Header
wv *semver.Version
}{
// backward compatibility with etcd 2.0
{
http.Header{},
semver.Must(semver.NewVersion("2.0.0")),
},
{
http.Header{"X-Min-Cluster-Version": []string{"2.1.0"}},
semver.Must(semver.NewVersion("2.1.0")),
},
{
http.Header{"X-Min-Cluster-Version": []string{"2.1.0-alpha.0+git"}},
semver.Must(semver.NewVersion("2.1.0-alpha.0+git")),
},
}
for i, tt := range tests {
v := minClusterVersion(tt.h)
if v.String() != tt.wv.String() {
t.Errorf("#%d: version = %s, want %s", i, v, tt.wv)
}
}
}
func TestCheckVersionCompatibility(t *testing.T) {
ls := semver.Must(semver.NewVersion(version.Version))
lmc := semver.Must(semver.NewVersion(version.MinClusterVersion))
tests := []struct {
server *semver.Version
minCluster *semver.Version
wok bool
}{
// the same version as local
{
ls,
lmc,
true,
},
// one version lower
{
lmc,
&semver.Version{},
true,
},
// one version higher
{
&semver.Version{Major: ls.Major + 1},
ls,
true,
},
// too low version
{
&semver.Version{Major: lmc.Major - 1},
&semver.Version{},
false,
},
// too high version
{
&semver.Version{Major: ls.Major + 1, Minor: 1},
&semver.Version{Major: ls.Major + 1},
false,
},
}
for i, tt := range tests {
err := checkVersionCompability("", tt.server, tt.minCluster)
if ok := err == nil; ok != tt.wok {
t.Errorf("#%d: ok = %v, want %v", i, ok, tt.wok)
}
}
}