rafthttp: version enforcement on rafthttp messages

This PR sets etcd version and min cluster version in request header,
and let server check version compatibility. rafthttp server
will reject any message from peer with incompatible version(too low
version or too high version), and print out warning logs.
This commit is contained in:
Yicheng Qin 2015-06-02 11:46:36 -07:00
parent 8825af47a0
commit c371d8c65c
9 changed files with 290 additions and 122 deletions

View File

@ -15,6 +15,7 @@
package rafthttp
import (
"errors"
"io/ioutil"
"log"
"net/http"
@ -34,6 +35,9 @@ const (
var (
RaftPrefix = "/raft"
RaftStreamPrefix = path.Join(RaftPrefix, "stream")
errIncompatibleVersion = errors.New("incompatible version")
errClusterIDMismatch = errors.New("cluster ID mismatch")
)
func NewHandler(r Raft, cid types.ID) http.Handler {
@ -72,13 +76,19 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
if err := checkVersionCompability(r.Header.Get("X-Server-From"), serverVersion(r.Header), minClusterVersion(r.Header)); err != nil {
log.Printf("rafthttp: request received was ignored (%v)", err)
http.Error(w, errIncompatibleVersion.Error(), http.StatusPreconditionFailed)
return
}
wcid := h.cid.String()
w.Header().Set("X-Etcd-Cluster-ID", wcid)
gcid := r.Header.Get("X-Etcd-Cluster-ID")
if gcid != wcid {
log.Printf("rafthttp: request ignored due to cluster ID mismatch got %s want %s", gcid, wcid)
http.Error(w, "clusterID mismatch", http.StatusPreconditionFailed)
http.Error(w, errClusterIDMismatch.Error(), http.StatusPreconditionFailed)
return
}
@ -126,17 +136,23 @@ func (h *streamHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
w.Header().Set("X-Server-Version", version.Version)
if err := checkVersionCompability(r.Header.Get("X-Server-From"), serverVersion(r.Header), minClusterVersion(r.Header)); err != nil {
log.Printf("rafthttp: request received was ignored (%v)", err)
http.Error(w, errIncompatibleVersion.Error(), http.StatusPreconditionFailed)
return
}
wcid := h.cid.String()
w.Header().Set("X-Etcd-Cluster-ID", wcid)
if gcid := r.Header.Get("X-Etcd-Cluster-ID"); gcid != wcid {
log.Printf("rafthttp: streaming request ignored due to cluster ID mismatch got %s want %s", gcid, wcid)
http.Error(w, "clusterID mismatch", http.StatusPreconditionFailed)
http.Error(w, errClusterIDMismatch.Error(), http.StatusPreconditionFailed)
return
}
w.Header().Add("X-Server-Version", version.Version)
var t streamType
switch path.Dir(r.URL.Path) {
// backward compatibility

View File

@ -124,7 +124,7 @@ func startPeer(tr http.RoundTripper, urls types.URLs, local, to, cid types.ID, r
r: r,
msgAppWriter: startStreamWriter(to, fs, r),
writer: startStreamWriter(to, fs, r),
pipeline: newPipeline(tr, picker, to, cid, fs, r, errorc),
pipeline: newPipeline(tr, picker, local, to, cid, fs, r, errorc),
sendc: make(chan raftpb.Message),
recvc: make(chan raftpb.Message, recvBufSize),
propc: make(chan raftpb.Message, maxPendingProposals),

View File

@ -17,8 +17,10 @@ package rafthttp
import (
"bytes"
"fmt"
"io/ioutil"
"log"
"net/http"
"strings"
"sync"
"time"
@ -27,6 +29,7 @@ import (
"github.com/coreos/etcd/pkg/types"
"github.com/coreos/etcd/raft"
"github.com/coreos/etcd/raft/raftpb"
"github.com/coreos/etcd/version"
)
const (
@ -39,8 +42,8 @@ const (
)
type pipeline struct {
id types.ID
cid types.ID
from, to types.ID
cid types.ID
tr http.RoundTripper
picker *urlPicker
@ -58,9 +61,10 @@ type pipeline struct {
errored error
}
func newPipeline(tr http.RoundTripper, picker *urlPicker, id, cid types.ID, fs *stats.FollowerStats, r Raft, errorc chan error) *pipeline {
func newPipeline(tr http.RoundTripper, picker *urlPicker, from, to, cid types.ID, fs *stats.FollowerStats, r Raft, errorc chan error) *pipeline {
p := &pipeline{
id: id,
from: from,
to: to,
cid: cid,
tr: tr,
picker: picker,
@ -94,11 +98,11 @@ func (p *pipeline) handle() {
reportSentFailure(pipelineMsg, m)
if p.errored == nil || p.errored.Error() != err.Error() {
log.Printf("pipeline: error posting to %s: %v", p.id, err)
log.Printf("pipeline: error posting to %s: %v", p.to, err)
p.errored = err
}
if p.active {
log.Printf("pipeline: the connection with %s became inactive", p.id)
log.Printf("pipeline: the connection with %s became inactive", p.to)
p.active = false
}
if m.Type == raftpb.MsgApp && p.fs != nil {
@ -110,7 +114,7 @@ func (p *pipeline) handle() {
}
} else {
if !p.active {
log.Printf("pipeline: the connection with %s became active", p.id)
log.Printf("pipeline: the connection with %s became active", p.to)
p.active = true
p.errored = nil
}
@ -138,19 +142,35 @@ func (p *pipeline) post(data []byte) error {
return err
}
req.Header.Set("Content-Type", "application/protobuf")
req.Header.Set("X-Server-From", p.from.String())
req.Header.Set("X-Server-Version", version.Version)
req.Header.Set("X-Min-Cluster-Version", version.MinClusterVersion)
req.Header.Set("X-Etcd-Cluster-ID", p.cid.String())
resp, err := p.tr.RoundTrip(req)
if err != nil {
p.picker.unreachable(u)
return err
}
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
p.picker.unreachable(u)
return err
}
resp.Body.Close()
switch resp.StatusCode {
case http.StatusPreconditionFailed:
log.Printf("rafthttp: request sent was ignored due to cluster ID mismatch (remote[%s]:%s, local:%s)",
uu.Host, resp.Header.Get("X-Etcd-Cluster-ID"), p.cid)
return fmt.Errorf("cluster ID mismatch")
switch strings.TrimSuffix(string(b), "\n") {
case errIncompatibleVersion.Error():
log.Printf("rafthttp: request sent was ignored by peer %s (server version incompatible)", p.to)
return errIncompatibleVersion
case errClusterIDMismatch.Error():
log.Printf("rafthttp: request sent was ignored (cluster ID mismatch: remote[%s]=%s, local=%s)",
p.to, resp.Header.Get("X-Etcd-Cluster-ID"), p.cid)
return errClusterIDMismatch
default:
return fmt.Errorf("unhandled error %q when precondition failed", string(b))
}
case http.StatusForbidden:
err := fmt.Errorf("the member has been permanently removed from the cluster")
select {

View File

@ -16,6 +16,7 @@ package rafthttp
import (
"errors"
"io"
"io/ioutil"
"net/http"
"sync"
@ -25,6 +26,7 @@ import (
"github.com/coreos/etcd/pkg/testutil"
"github.com/coreos/etcd/pkg/types"
"github.com/coreos/etcd/raft/raftpb"
"github.com/coreos/etcd/version"
)
// TestPipelineSend tests that pipeline could send data using roundtripper
@ -33,7 +35,7 @@ func TestPipelineSend(t *testing.T) {
tr := &roundTripperRecorder{}
picker := mustNewURLPicker(t, []string{"http://localhost:2380"})
fs := &stats.FollowerStats{}
p := newPipeline(tr, picker, types.ID(1), types.ID(1), fs, &fakeRaft{}, nil)
p := newPipeline(tr, picker, types.ID(2), types.ID(1), types.ID(1), fs, &fakeRaft{}, nil)
p.msgc <- raftpb.Message{Type: raftpb.MsgApp}
p.stop()
@ -52,7 +54,7 @@ func TestPipelineExceedMaximalServing(t *testing.T) {
tr := newRoundTripperBlocker()
picker := mustNewURLPicker(t, []string{"http://localhost:2380"})
fs := &stats.FollowerStats{}
p := newPipeline(tr, picker, types.ID(1), types.ID(1), fs, &fakeRaft{}, nil)
p := newPipeline(tr, picker, types.ID(2), types.ID(1), types.ID(1), fs, &fakeRaft{}, nil)
// keep the sender busy and make the buffer full
// nothing can go out as we block the sender
@ -92,7 +94,7 @@ func TestPipelineExceedMaximalServing(t *testing.T) {
func TestPipelineSendFailed(t *testing.T) {
picker := mustNewURLPicker(t, []string{"http://localhost:2380"})
fs := &stats.FollowerStats{}
p := newPipeline(newRespRoundTripper(0, errors.New("blah")), picker, types.ID(1), types.ID(1), fs, &fakeRaft{}, nil)
p := newPipeline(newRespRoundTripper(0, errors.New("blah")), picker, types.ID(2), types.ID(1), types.ID(1), fs, &fakeRaft{}, nil)
p.msgc <- raftpb.Message{Type: raftpb.MsgApp}
p.stop()
@ -107,7 +109,7 @@ func TestPipelineSendFailed(t *testing.T) {
func TestPipelinePost(t *testing.T) {
tr := &roundTripperRecorder{}
picker := mustNewURLPicker(t, []string{"http://localhost:2380"})
p := newPipeline(tr, picker, types.ID(1), types.ID(1), nil, &fakeRaft{}, nil)
p := newPipeline(tr, picker, types.ID(2), types.ID(1), types.ID(1), nil, &fakeRaft{}, nil)
if err := p.post([]byte("some data")); err != nil {
t.Fatalf("unexpect post error: %v", err)
}
@ -122,6 +124,12 @@ func TestPipelinePost(t *testing.T) {
if g := tr.Request().Header.Get("Content-Type"); g != "application/protobuf" {
t.Errorf("content type = %s, want %s", g, "application/protobuf")
}
if g := tr.Request().Header.Get("X-Server-Version"); g != version.Version {
t.Errorf("version = %s, want %s", g, version.Version)
}
if g := tr.Request().Header.Get("X-Min-Cluster-Version"); g != version.MinClusterVersion {
t.Errorf("min version = %s, want %s", g, version.MinClusterVersion)
}
if g := tr.Request().Header.Get("X-Etcd-Cluster-ID"); g != "1" {
t.Errorf("cluster id = %s, want %s", g, "1")
}
@ -148,7 +156,7 @@ func TestPipelinePostBad(t *testing.T) {
}
for i, tt := range tests {
picker := mustNewURLPicker(t, []string{tt.u})
p := newPipeline(newRespRoundTripper(tt.code, tt.err), picker, types.ID(1), types.ID(1), nil, &fakeRaft{}, make(chan error))
p := newPipeline(newRespRoundTripper(tt.code, tt.err), picker, types.ID(2), types.ID(1), types.ID(1), nil, &fakeRaft{}, make(chan error))
err := p.post([]byte("some data"))
p.stop()
@ -169,7 +177,7 @@ func TestPipelinePostErrorc(t *testing.T) {
for i, tt := range tests {
picker := mustNewURLPicker(t, []string{tt.u})
errorc := make(chan error, 1)
p := newPipeline(newRespRoundTripper(tt.code, tt.err), picker, types.ID(1), types.ID(1), nil, &fakeRaft{}, errorc)
p := newPipeline(newRespRoundTripper(tt.code, tt.err), picker, types.ID(2), types.ID(1), types.ID(1), nil, &fakeRaft{}, errorc)
p.post([]byte("some data"))
p.stop()
select {
@ -227,5 +235,5 @@ func (t *roundTripperRecorder) Request() *http.Request {
type nopReadCloser struct{}
func (n *nopReadCloser) Read(p []byte) (int, error) { return 0, nil }
func (n *nopReadCloser) Read(p []byte) (int, error) { return 0, io.EOF }
func (n *nopReadCloser) Close() error { return nil }

View File

@ -31,7 +31,7 @@ func startRemote(tr http.RoundTripper, urls types.URLs, local, to, cid types.ID,
picker := newURLPicker(urls)
return &remote{
id: to,
pipeline: newPipeline(tr, picker, to, cid, nil, r, errorc),
pipeline: newPipeline(tr, picker, local, to, cid, nil, r, errorc),
}
}

View File

@ -17,11 +17,13 @@ package rafthttp
import (
"fmt"
"io"
"io/ioutil"
"log"
"net"
"net/http"
"path"
"strconv"
"strings"
"sync"
"time"
@ -389,6 +391,9 @@ func (cr *streamReader) dial(t streamType) (io.ReadCloser, error) {
cr.picker.unreachable(u)
return nil, fmt.Errorf("new request to %s error: %v", u, err)
}
req.Header.Set("X-Server-From", cr.from.String())
req.Header.Set("X-Server-Version", version.Version)
req.Header.Set("X-Min-Cluster-Version", version.MinClusterVersion)
req.Header.Set("X-Etcd-Cluster-ID", cr.cid.String())
req.Header.Set("X-Raft-To", cr.to.String())
if t == streamTypeMsgApp {
@ -425,10 +430,24 @@ func (cr *streamReader) dial(t streamType) (io.ReadCloser, error) {
resp.Body.Close()
return nil, fmt.Errorf("local member has not been added to the peer list of member %s", cr.to)
case http.StatusPreconditionFailed:
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
cr.picker.unreachable(u)
return nil, err
}
resp.Body.Close()
log.Printf("rafthttp: request sent was ignored due to cluster ID mismatch (remote[%s]:%s, local:%s)",
uu.Host, resp.Header.Get("X-Etcd-Cluster-ID"), cr.cid)
return nil, fmt.Errorf("cluster ID mismatch")
switch strings.TrimSuffix(string(b), "\n") {
case errIncompatibleVersion.Error():
log.Printf("rafthttp: request sent was ignored by peer %s (server version incompatible)", cr.to)
return nil, errIncompatibleVersion
case errClusterIDMismatch.Error():
log.Printf("rafthttp: request sent was ignored (cluster ID mismatch: remote[%s]=%s, local=%s)",
cr.to, resp.Header.Get("X-Etcd-Cluster-ID"), cr.cid)
return nil, errClusterIDMismatch
default:
return nil, fmt.Errorf("unhandled error %q when precondition failed", string(b))
}
default:
resp.Body.Close()
return nil, fmt.Errorf("unhandled http status %d", resp.StatusCode)
@ -457,32 +476,6 @@ func isClosedConnectionError(err error) bool {
return ok && operr.Err.Error() == "use of closed network connection"
}
// serverVersion returns the version from the given header.
func serverVersion(h http.Header) *semver.Version {
verStr := h.Get("X-Server-Version")
// backward compatibility with etcd 2.0
if verStr == "" {
verStr = "2.0.0"
}
return semver.Must(semver.NewVersion(verStr))
}
// compareMajorMinorVersion returns an integer comparing two versions based on
// their major and minor version. The result will be 0 if a==b, -1 if a < b,
// and 1 if a > b.
func compareMajorMinorVersion(a, b *semver.Version) int {
na := &semver.Version{Major: a.Major, Minor: a.Minor}
nb := &semver.Version{Major: b.Major, Minor: b.Minor}
switch {
case na.LessThan(*nb):
return -1
case nb.LessThan(*na):
return 1
default:
return 0
}
}
// checkStreamSupport checks whether the stream type is supported in the
// given version.
func checkStreamSupport(v *semver.Version, t streamType) bool {

View File

@ -302,76 +302,6 @@ func TestStream(t *testing.T) {
}
}
func TestServerVersion(t *testing.T) {
tests := []struct {
h http.Header
wv *semver.Version
}{
// backward compatibility with etcd 2.0
{
http.Header{},
semver.Must(semver.NewVersion("2.0.0")),
},
{
http.Header{"X-Server-Version": []string{"2.1.0"}},
semver.Must(semver.NewVersion("2.1.0")),
},
{
http.Header{"X-Server-Version": []string{"2.1.0-alpha.0+git"}},
semver.Must(semver.NewVersion("2.1.0-alpha.0+git")),
},
}
for i, tt := range tests {
v := serverVersion(tt.h)
if v.String() != tt.wv.String() {
t.Errorf("#%d: version = %s, want %s", i, v, tt.wv)
}
}
}
func TestCompareMajorMinorVersion(t *testing.T) {
tests := []struct {
va, vb *semver.Version
w int
}{
// equal to
{
semver.Must(semver.NewVersion("2.1.0")),
semver.Must(semver.NewVersion("2.1.0")),
0,
},
// smaller than
{
semver.Must(semver.NewVersion("2.0.0")),
semver.Must(semver.NewVersion("2.1.0")),
-1,
},
// bigger than
{
semver.Must(semver.NewVersion("2.2.0")),
semver.Must(semver.NewVersion("2.1.0")),
1,
},
// ignore patch
{
semver.Must(semver.NewVersion("2.1.1")),
semver.Must(semver.NewVersion("2.1.0")),
0,
},
// ignore prerelease
{
semver.Must(semver.NewVersion("2.1.0-alpha.0")),
semver.Must(semver.NewVersion("2.1.0")),
0,
},
}
for i, tt := range tests {
if g := compareMajorMinorVersion(tt.va, tt.vb); g != tt.w {
t.Errorf("#%d: compare = %d, want %d", i, g, tt.w)
}
}
}
func TestCheckStreamSupport(t *testing.T) {
tests := []struct {
v *semver.Version

View File

@ -16,9 +16,13 @@ package rafthttp
import (
"encoding/binary"
"fmt"
"io"
"net/http"
"github.com/coreos/etcd/Godeps/_workspace/src/github.com/coreos/go-semver/semver"
"github.com/coreos/etcd/raft/raftpb"
"github.com/coreos/etcd/version"
)
func writeEntryTo(w io.Writer, ent *raftpb.Entry) error {
@ -45,3 +49,53 @@ func readEntryFrom(r io.Reader, ent *raftpb.Entry) error {
}
return ent.Unmarshal(buf)
}
// compareMajorMinorVersion returns an integer comparing two versions based on
// their major and minor version. The result will be 0 if a==b, -1 if a < b,
// and 1 if a > b.
func compareMajorMinorVersion(a, b *semver.Version) int {
na := &semver.Version{Major: a.Major, Minor: a.Minor}
nb := &semver.Version{Major: b.Major, Minor: b.Minor}
switch {
case na.LessThan(*nb):
return -1
case nb.LessThan(*na):
return 1
default:
return 0
}
}
// serverVersion returns the server version from the given header.
func serverVersion(h http.Header) *semver.Version {
verStr := h.Get("X-Server-Version")
// backward compatibility with etcd 2.0
if verStr == "" {
verStr = "2.0.0"
}
return semver.Must(semver.NewVersion(verStr))
}
// serverVersion returns the min cluster version from the given header.
func minClusterVersion(h http.Header) *semver.Version {
verStr := h.Get("X-Min-Cluster-Version")
// backward compatibility with etcd 2.0
if verStr == "" {
verStr = "2.0.0"
}
return semver.Must(semver.NewVersion(verStr))
}
// checkVersionCompability checks whether the given version is compatible
// with the local version.
func checkVersionCompability(name string, server, minCluster *semver.Version) error {
localServer := semver.Must(semver.NewVersion(version.Version))
localMinCluster := semver.Must(semver.NewVersion(version.MinClusterVersion))
if compareMajorMinorVersion(server, localMinCluster) == -1 {
return fmt.Errorf("remote version is too low: remote[%s]=%s, local=%s", name, server, localServer)
}
if compareMajorMinorVersion(minCluster, localServer) == 1 {
return fmt.Errorf("local version is too low: remote[%s]=%s, local=%s", name, server, localServer)
}
return nil
}

View File

@ -16,10 +16,13 @@ package rafthttp
import (
"bytes"
"net/http"
"reflect"
"testing"
"github.com/coreos/etcd/Godeps/_workspace/src/github.com/coreos/go-semver/semver"
"github.com/coreos/etcd/raft/raftpb"
"github.com/coreos/etcd/version"
)
func TestEntry(t *testing.T) {
@ -44,3 +47,147 @@ func TestEntry(t *testing.T) {
}
}
}
func TestCompareMajorMinorVersion(t *testing.T) {
tests := []struct {
va, vb *semver.Version
w int
}{
// equal to
{
semver.Must(semver.NewVersion("2.1.0")),
semver.Must(semver.NewVersion("2.1.0")),
0,
},
// smaller than
{
semver.Must(semver.NewVersion("2.0.0")),
semver.Must(semver.NewVersion("2.1.0")),
-1,
},
// bigger than
{
semver.Must(semver.NewVersion("2.2.0")),
semver.Must(semver.NewVersion("2.1.0")),
1,
},
// ignore patch
{
semver.Must(semver.NewVersion("2.1.1")),
semver.Must(semver.NewVersion("2.1.0")),
0,
},
// ignore prerelease
{
semver.Must(semver.NewVersion("2.1.0-alpha.0")),
semver.Must(semver.NewVersion("2.1.0")),
0,
},
}
for i, tt := range tests {
if g := compareMajorMinorVersion(tt.va, tt.vb); g != tt.w {
t.Errorf("#%d: compare = %d, want %d", i, g, tt.w)
}
}
}
func TestServerVersion(t *testing.T) {
tests := []struct {
h http.Header
wv *semver.Version
}{
// backward compatibility with etcd 2.0
{
http.Header{},
semver.Must(semver.NewVersion("2.0.0")),
},
{
http.Header{"X-Server-Version": []string{"2.1.0"}},
semver.Must(semver.NewVersion("2.1.0")),
},
{
http.Header{"X-Server-Version": []string{"2.1.0-alpha.0+git"}},
semver.Must(semver.NewVersion("2.1.0-alpha.0+git")),
},
}
for i, tt := range tests {
v := serverVersion(tt.h)
if v.String() != tt.wv.String() {
t.Errorf("#%d: version = %s, want %s", i, v, tt.wv)
}
}
}
func TestMinClusterVersion(t *testing.T) {
tests := []struct {
h http.Header
wv *semver.Version
}{
// backward compatibility with etcd 2.0
{
http.Header{},
semver.Must(semver.NewVersion("2.0.0")),
},
{
http.Header{"X-Min-Cluster-Version": []string{"2.1.0"}},
semver.Must(semver.NewVersion("2.1.0")),
},
{
http.Header{"X-Min-Cluster-Version": []string{"2.1.0-alpha.0+git"}},
semver.Must(semver.NewVersion("2.1.0-alpha.0+git")),
},
}
for i, tt := range tests {
v := minClusterVersion(tt.h)
if v.String() != tt.wv.String() {
t.Errorf("#%d: version = %s, want %s", i, v, tt.wv)
}
}
}
func TestCheckVersionCompatibility(t *testing.T) {
ls := semver.Must(semver.NewVersion(version.Version))
lmc := semver.Must(semver.NewVersion(version.MinClusterVersion))
tests := []struct {
server *semver.Version
minCluster *semver.Version
wok bool
}{
// the same version as local
{
ls,
lmc,
true,
},
// one version lower
{
lmc,
&semver.Version{},
true,
},
// one version higher
{
&semver.Version{Major: ls.Major + 1},
ls,
true,
},
// too low version
{
&semver.Version{Major: lmc.Major - 1},
&semver.Version{},
false,
},
// too high version
{
&semver.Version{Major: ls.Major + 1, Minor: 1},
&semver.Version{Major: ls.Major + 1},
false,
},
}
for i, tt := range tests {
err := checkVersionCompability("", tt.server, tt.minCluster)
if ok := err == nil; ok != tt.wok {
t.Errorf("#%d: ok = %v, want %v", i, ok, tt.wok)
}
}
}