rafthttp: stop etcd if it is found removed when stream dial

Previously, etcd stopped itself only when a pipeline message revealed that
the member had been removed. After this PR, stream dial detects this too.
This lets etcd stop quickly: it no longer needs to wait for the stream to
break, fall back to the pipeline, and then wait for an election timeout to
send out the message that detects its own removal.
This commit is contained in:
Yicheng Qin 2015-04-25 23:24:59 -07:00
parent be6f49ba32
commit 1c1cccd236
9 changed files with 64 additions and 20 deletions

View File

@ -348,6 +348,8 @@ func (s *EtcdServer) Process(ctx context.Context, m raftpb.Message) error {
return s.r.Step(ctx, m)
}
// IsIDRemoved reports whether the member identified by id has been removed,
// as answered by the server's Cluster.
func (s *EtcdServer) IsIDRemoved(id uint64) bool {
	memberID := types.ID(id)
	return s.Cluster.IsIDRemoved(memberID)
}
// ReportUnreachable forwards the report for the given member ID to s.r.
func (s *EtcdServer) ReportUnreachable(id uint64) {
	s.r.ReportUnreachable(id)
}
func (s *EtcdServer) ReportSnapshot(id uint64, status raft.SnapshotStatus) {

View File

@ -134,8 +134,9 @@ func waitStreamWorking(p *peer) bool {
}
type fakeRaft struct {
recvc chan<- raftpb.Message
err error
recvc chan<- raftpb.Message
err error
removedID uint64
}
func (p *fakeRaft) Process(ctx context.Context, m raftpb.Message) error {
@ -146,6 +147,8 @@ func (p *fakeRaft) Process(ctx context.Context, m raftpb.Message) error {
return p.err
}
// IsIDRemoved reports whether id equals the removedID configured on the fake.
func (p *fakeRaft) IsIDRemoved(id uint64) bool {
	return p.removedID == id
}
// ReportUnreachable is a no-op; the fake ignores the report.
func (p *fakeRaft) ReportUnreachable(id uint64) {
	// Intentionally empty.
}
// ReportSnapshot is a no-op; the fake ignores the snapshot status.
func (p *fakeRaft) ReportSnapshot(id uint64, status raft.SnapshotStatus) {
	// Intentionally empty.
}

View File

@ -46,9 +46,10 @@ type peerGetter interface {
Get(id types.ID) Peer
}
func newStreamHandler(peerGetter peerGetter, id, cid types.ID) http.Handler {
func newStreamHandler(peerGetter peerGetter, r Raft, id, cid types.ID) http.Handler {
return &streamHandler{
peerGetter: peerGetter,
r: r,
id: id,
cid: cid,
}
@ -112,6 +113,7 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// streamHandler is the http.Handler that serves incoming stream requests
// from peers.
type streamHandler struct {
	// peerGetter looks up the Peer that corresponds to the request's sender.
	peerGetter peerGetter
	// r is asked whether the sender has been removed, so that streams from
	// removed members can be rejected.
	r Raft
	// id is the local member ID and cid is the cluster ID, as supplied to
	// newStreamHandler.
	id, cid types.ID
}
@ -145,6 +147,11 @@ func (h *streamHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
http.Error(w, "invalid from", http.StatusNotFound)
return
}
if h.r.IsIDRemoved(uint64(from)) {
log.Printf("rafthttp: reject the stream from peer %s since it was removed", from)
http.Error(w, "removed member", http.StatusGone)
return
}
p := h.peerGetter.Get(from)
if p == nil {
log.Printf("rafthttp: fail to find sender %s", from)

View File

@ -17,6 +17,7 @@ package rafthttp
import (
"bytes"
"errors"
"fmt"
"io"
"net/http"
"net/http/httptest"
@ -185,7 +186,7 @@ func TestServeRaftStreamPrefix(t *testing.T) {
peer := newFakePeer()
peerGetter := &fakePeerGetter{peers: map[types.ID]Peer{types.ID(1): peer}}
h := newStreamHandler(peerGetter, types.ID(2), types.ID(1))
h := newStreamHandler(peerGetter, &fakeRaft{}, types.ID(2), types.ID(1))
rw := httptest.NewRecorder()
go h.ServeHTTP(rw, req)
@ -207,6 +208,7 @@ func TestServeRaftStreamPrefix(t *testing.T) {
}
func TestServeRaftStreamPrefixBad(t *testing.T) {
removedID := uint64(5)
tests := []struct {
method string
path string
@ -263,6 +265,14 @@ func TestServeRaftStreamPrefixBad(t *testing.T) {
"1",
http.StatusNotFound,
},
// removed peer
{
"GET",
RaftStreamPrefix + "/message/" + fmt.Sprint(removedID),
"1",
"1",
http.StatusGone,
},
// wrong cluster ID
{
"GET",
@ -289,7 +299,8 @@ func TestServeRaftStreamPrefixBad(t *testing.T) {
req.Header.Set("X-Raft-To", tt.remote)
rw := httptest.NewRecorder()
peerGetter := &fakePeerGetter{peers: map[types.ID]Peer{types.ID(1): newFakePeer()}}
h := newStreamHandler(peerGetter, types.ID(1), types.ID(1))
r := &fakeRaft{removedID: removedID}
h := newStreamHandler(peerGetter, r, types.ID(1), types.ID(1))
h.ServeHTTP(rw, req)
if rw.Code != tt.wcode {

View File

@ -149,8 +149,8 @@ func startPeer(tr http.RoundTripper, urls types.URLs, local, to, cid types.ID, r
go func() {
var paused bool
msgAppReader := startStreamReader(tr, picker, streamTypeMsgAppV2, local, to, cid, p.recvc, p.propc)
reader := startStreamReader(tr, picker, streamTypeMessage, local, to, cid, p.recvc, p.propc)
msgAppReader := startStreamReader(tr, picker, streamTypeMsgAppV2, local, to, cid, p.recvc, p.propc, errorc)
reader := startStreamReader(tr, picker, streamTypeMessage, local, to, cid, p.recvc, p.propc, errorc)
for {
select {
case m := <-p.sendc:

View File

@ -226,6 +226,7 @@ type streamReader struct {
cid types.ID
recvc chan<- raftpb.Message
propc chan<- raftpb.Message
errorc chan<- error
mu sync.Mutex
msgAppTerm uint64
@ -235,7 +236,7 @@ type streamReader struct {
done chan struct{}
}
func startStreamReader(tr http.RoundTripper, picker *urlPicker, t streamType, from, to, cid types.ID, recvc chan<- raftpb.Message, propc chan<- raftpb.Message) *streamReader {
func startStreamReader(tr http.RoundTripper, picker *urlPicker, t streamType, from, to, cid types.ID, recvc chan<- raftpb.Message, propc chan<- raftpb.Message, errorc chan<- error) *streamReader {
r := &streamReader{
tr: tr,
picker: picker,
@ -245,6 +246,7 @@ func startStreamReader(tr http.RoundTripper, picker *urlPicker, t streamType, fr
cid: cid,
recvc: recvc,
propc: propc,
errorc: errorc,
stopc: make(chan struct{}),
done: make(chan struct{}),
}
@ -367,11 +369,21 @@ func (cr *streamReader) dial() (io.ReadCloser, error) {
cr.picker.unreachable(u)
return nil, fmt.Errorf("error roundtripping to %s: %v", req.URL, err)
}
if resp.StatusCode != http.StatusOK {
switch resp.StatusCode {
case http.StatusGone:
resp.Body.Close()
err := fmt.Errorf("the member has been permanently removed from the cluster")
select {
case cr.errorc <- err:
default:
}
return nil, err
case http.StatusOK:
return resp.Body, nil
default:
resp.Body.Close()
return nil, fmt.Errorf("unhandled http status %d", resp.StatusCode)
}
return resp.Body, nil
}
func (cr *streamReader) cancelRequest() {

View File

@ -119,15 +119,17 @@ func TestStreamReaderDialRequest(t *testing.T) {
// HTTP response received.
func TestStreamReaderDialResult(t *testing.T) {
tests := []struct {
code int
err error
wok bool
code int
err error
wok bool
whalt bool
}{
{0, errors.New("blah"), false},
{http.StatusOK, nil, true},
{http.StatusMethodNotAllowed, nil, false},
{http.StatusNotFound, nil, false},
{http.StatusPreconditionFailed, nil, false},
{0, errors.New("blah"), false, false},
{http.StatusOK, nil, true, false},
{http.StatusMethodNotAllowed, nil, false, false},
{http.StatusNotFound, nil, false, false},
{http.StatusPreconditionFailed, nil, false, false},
{http.StatusGone, nil, false, true},
}
for i, tt := range tests {
tr := newRespRoundTripper(tt.code, tt.err)
@ -138,12 +140,16 @@ func TestStreamReaderDialResult(t *testing.T) {
from: types.ID(1),
to: types.ID(2),
cid: types.ID(1),
errorc: make(chan error, 1),
}
_, err := sr.dial()
if ok := err == nil; ok != tt.wok {
t.Errorf("#%d: ok = %v, want %v", i, ok, tt.wok)
}
if halt := len(sr.errorc) > 0; halt != tt.whalt {
t.Errorf("#%d: halt = %v, want %v", i, halt, tt.whalt)
}
}
}
@ -203,7 +209,7 @@ func TestStream(t *testing.T) {
h.sw = sw
picker := mustNewURLPicker(t, []string{srv.URL})
sr := startStreamReader(&http.Transport{}, picker, tt.t, types.ID(1), types.ID(2), types.ID(1), recvc, propc)
sr := startStreamReader(&http.Transport{}, picker, tt.t, types.ID(1), types.ID(2), types.ID(1), recvc, propc, nil)
defer sr.stop()
if tt.t == streamTypeMsgApp {
sr.updateMsgAppTerm(tt.term)

View File

@ -28,6 +28,7 @@ import (
// Raft is the set of callbacks that the transport and its HTTP handlers
// invoke on the local raft implementation.
type Raft interface {
// Process hands the message m to the implementation and returns its error.
Process(ctx context.Context, m raftpb.Message) error
// IsIDRemoved reports whether the member with the given ID has been removed;
// the stream handler uses it to refuse streams from removed members.
IsIDRemoved(id uint64) bool
// ReportUnreachable notifies the implementation about the peer with the given ID.
// NOTE(review): presumably called when sending to that peer fails — confirm at the call sites.
ReportUnreachable(id uint64)
// ReportSnapshot passes along a snapshot status for the peer with the given ID.
ReportSnapshot(id uint64, status raft.SnapshotStatus)
}
@ -98,7 +99,7 @@ func NewTransporter(rt http.RoundTripper, id, cid types.ID, r Raft, errorc chan
func (t *transport) Handler() http.Handler {
pipelineHandler := NewHandler(t.raft, t.clusterID)
streamHandler := newStreamHandler(t, t.id, t.clusterID)
streamHandler := newStreamHandler(t, t.raft, t.id, t.clusterID)
mux := http.NewServeMux()
mux.Handle(RaftPrefix, pipelineHandler)
mux.Handle(RaftStreamPrefix+"/", streamHandler)

View File

@ -88,6 +88,8 @@ func (r *countRaft) Process(ctx context.Context, m raftpb.Message) error {
return nil
}
// IsIDRemoved always returns false, whatever id is passed.
func (r *countRaft) IsIDRemoved(id uint64) bool {
	return false
}
// ReportUnreachable is a no-op; countRaft ignores the report.
func (r *countRaft) ReportUnreachable(id uint64) {
	// Intentionally empty.
}
// ReportSnapshot is a no-op; countRaft ignores the snapshot status.
func (r *countRaft) ReportSnapshot(id uint64, status raft.SnapshotStatus) {
	// Intentionally empty.
}