mirror of
https://github.com/etcd-io/etcd.git
synced 2024-09-27 06:25:44 +00:00
Merge pull request #2762 from yichengq/343
rafthttp: stop etcd if it is found removed when stream dial
This commit is contained in:
commit
d080c33c07
@ -348,6 +348,8 @@ func (s *EtcdServer) Process(ctx context.Context, m raftpb.Message) error {
|
||||
return s.r.Step(ctx, m)
|
||||
}
|
||||
|
||||
func (s *EtcdServer) IsIDRemoved(id uint64) bool { return s.Cluster.IsIDRemoved(types.ID(id)) }
|
||||
|
||||
func (s *EtcdServer) ReportUnreachable(id uint64) { s.r.ReportUnreachable(id) }
|
||||
|
||||
func (s *EtcdServer) ReportSnapshot(id uint64, status raft.SnapshotStatus) {
|
||||
|
@ -134,8 +134,9 @@ func waitStreamWorking(p *peer) bool {
|
||||
}
|
||||
|
||||
type fakeRaft struct {
|
||||
recvc chan<- raftpb.Message
|
||||
err error
|
||||
recvc chan<- raftpb.Message
|
||||
err error
|
||||
removedID uint64
|
||||
}
|
||||
|
||||
func (p *fakeRaft) Process(ctx context.Context, m raftpb.Message) error {
|
||||
@ -146,6 +147,8 @@ func (p *fakeRaft) Process(ctx context.Context, m raftpb.Message) error {
|
||||
return p.err
|
||||
}
|
||||
|
||||
func (p *fakeRaft) IsIDRemoved(id uint64) bool { return id == p.removedID }
|
||||
|
||||
func (p *fakeRaft) ReportUnreachable(id uint64) {}
|
||||
|
||||
func (p *fakeRaft) ReportSnapshot(id uint64, status raft.SnapshotStatus) {}
|
||||
|
@ -46,9 +46,10 @@ type peerGetter interface {
|
||||
Get(id types.ID) Peer
|
||||
}
|
||||
|
||||
func newStreamHandler(peerGetter peerGetter, id, cid types.ID) http.Handler {
|
||||
func newStreamHandler(peerGetter peerGetter, r Raft, id, cid types.ID) http.Handler {
|
||||
return &streamHandler{
|
||||
peerGetter: peerGetter,
|
||||
r: r,
|
||||
id: id,
|
||||
cid: cid,
|
||||
}
|
||||
@ -112,6 +113,7 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
type streamHandler struct {
|
||||
peerGetter peerGetter
|
||||
r Raft
|
||||
id types.ID
|
||||
cid types.ID
|
||||
}
|
||||
@ -145,6 +147,11 @@ func (h *streamHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "invalid from", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
if h.r.IsIDRemoved(uint64(from)) {
|
||||
log.Printf("rafthttp: reject the stream from peer %s since it was removed", from)
|
||||
http.Error(w, "removed member", http.StatusGone)
|
||||
return
|
||||
}
|
||||
p := h.peerGetter.Get(from)
|
||||
if p == nil {
|
||||
log.Printf("rafthttp: fail to find sender %s", from)
|
||||
|
@ -17,6 +17,7 @@ package rafthttp
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@ -185,7 +186,7 @@ func TestServeRaftStreamPrefix(t *testing.T) {
|
||||
|
||||
peer := newFakePeer()
|
||||
peerGetter := &fakePeerGetter{peers: map[types.ID]Peer{types.ID(1): peer}}
|
||||
h := newStreamHandler(peerGetter, types.ID(2), types.ID(1))
|
||||
h := newStreamHandler(peerGetter, &fakeRaft{}, types.ID(2), types.ID(1))
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
go h.ServeHTTP(rw, req)
|
||||
@ -207,6 +208,7 @@ func TestServeRaftStreamPrefix(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestServeRaftStreamPrefixBad(t *testing.T) {
|
||||
removedID := uint64(5)
|
||||
tests := []struct {
|
||||
method string
|
||||
path string
|
||||
@ -263,6 +265,14 @@ func TestServeRaftStreamPrefixBad(t *testing.T) {
|
||||
"1",
|
||||
http.StatusNotFound,
|
||||
},
|
||||
// removed peer
|
||||
{
|
||||
"GET",
|
||||
RaftStreamPrefix + "/message/" + fmt.Sprint(removedID),
|
||||
"1",
|
||||
"1",
|
||||
http.StatusGone,
|
||||
},
|
||||
// wrong cluster ID
|
||||
{
|
||||
"GET",
|
||||
@ -289,7 +299,8 @@ func TestServeRaftStreamPrefixBad(t *testing.T) {
|
||||
req.Header.Set("X-Raft-To", tt.remote)
|
||||
rw := httptest.NewRecorder()
|
||||
peerGetter := &fakePeerGetter{peers: map[types.ID]Peer{types.ID(1): newFakePeer()}}
|
||||
h := newStreamHandler(peerGetter, types.ID(1), types.ID(1))
|
||||
r := &fakeRaft{removedID: removedID}
|
||||
h := newStreamHandler(peerGetter, r, types.ID(1), types.ID(1))
|
||||
h.ServeHTTP(rw, req)
|
||||
|
||||
if rw.Code != tt.wcode {
|
||||
|
@ -149,8 +149,8 @@ func startPeer(tr http.RoundTripper, urls types.URLs, local, to, cid types.ID, r
|
||||
|
||||
go func() {
|
||||
var paused bool
|
||||
msgAppReader := startStreamReader(tr, picker, streamTypeMsgAppV2, local, to, cid, p.recvc, p.propc)
|
||||
reader := startStreamReader(tr, picker, streamTypeMessage, local, to, cid, p.recvc, p.propc)
|
||||
msgAppReader := startStreamReader(tr, picker, streamTypeMsgAppV2, local, to, cid, p.recvc, p.propc, errorc)
|
||||
reader := startStreamReader(tr, picker, streamTypeMessage, local, to, cid, p.recvc, p.propc, errorc)
|
||||
for {
|
||||
select {
|
||||
case m := <-p.sendc:
|
||||
|
@ -226,6 +226,7 @@ type streamReader struct {
|
||||
cid types.ID
|
||||
recvc chan<- raftpb.Message
|
||||
propc chan<- raftpb.Message
|
||||
errorc chan<- error
|
||||
|
||||
mu sync.Mutex
|
||||
msgAppTerm uint64
|
||||
@ -235,7 +236,7 @@ type streamReader struct {
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
func startStreamReader(tr http.RoundTripper, picker *urlPicker, t streamType, from, to, cid types.ID, recvc chan<- raftpb.Message, propc chan<- raftpb.Message) *streamReader {
|
||||
func startStreamReader(tr http.RoundTripper, picker *urlPicker, t streamType, from, to, cid types.ID, recvc chan<- raftpb.Message, propc chan<- raftpb.Message, errorc chan<- error) *streamReader {
|
||||
r := &streamReader{
|
||||
tr: tr,
|
||||
picker: picker,
|
||||
@ -245,6 +246,7 @@ func startStreamReader(tr http.RoundTripper, picker *urlPicker, t streamType, fr
|
||||
cid: cid,
|
||||
recvc: recvc,
|
||||
propc: propc,
|
||||
errorc: errorc,
|
||||
stopc: make(chan struct{}),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
@ -367,11 +369,21 @@ func (cr *streamReader) dial() (io.ReadCloser, error) {
|
||||
cr.picker.unreachable(u)
|
||||
return nil, fmt.Errorf("error roundtripping to %s: %v", req.URL, err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
switch resp.StatusCode {
|
||||
case http.StatusGone:
|
||||
resp.Body.Close()
|
||||
err := fmt.Errorf("the member has been permanently removed from the cluster")
|
||||
select {
|
||||
case cr.errorc <- err:
|
||||
default:
|
||||
}
|
||||
return nil, err
|
||||
case http.StatusOK:
|
||||
return resp.Body, nil
|
||||
default:
|
||||
resp.Body.Close()
|
||||
return nil, fmt.Errorf("unhandled http status %d", resp.StatusCode)
|
||||
}
|
||||
return resp.Body, nil
|
||||
}
|
||||
|
||||
func (cr *streamReader) cancelRequest() {
|
||||
|
@ -119,15 +119,17 @@ func TestStreamReaderDialRequest(t *testing.T) {
|
||||
// HTTP response received.
|
||||
func TestStreamReaderDialResult(t *testing.T) {
|
||||
tests := []struct {
|
||||
code int
|
||||
err error
|
||||
wok bool
|
||||
code int
|
||||
err error
|
||||
wok bool
|
||||
whalt bool
|
||||
}{
|
||||
{0, errors.New("blah"), false},
|
||||
{http.StatusOK, nil, true},
|
||||
{http.StatusMethodNotAllowed, nil, false},
|
||||
{http.StatusNotFound, nil, false},
|
||||
{http.StatusPreconditionFailed, nil, false},
|
||||
{0, errors.New("blah"), false, false},
|
||||
{http.StatusOK, nil, true, false},
|
||||
{http.StatusMethodNotAllowed, nil, false, false},
|
||||
{http.StatusNotFound, nil, false, false},
|
||||
{http.StatusPreconditionFailed, nil, false, false},
|
||||
{http.StatusGone, nil, false, true},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
tr := newRespRoundTripper(tt.code, tt.err)
|
||||
@ -138,12 +140,16 @@ func TestStreamReaderDialResult(t *testing.T) {
|
||||
from: types.ID(1),
|
||||
to: types.ID(2),
|
||||
cid: types.ID(1),
|
||||
errorc: make(chan error, 1),
|
||||
}
|
||||
|
||||
_, err := sr.dial()
|
||||
if ok := err == nil; ok != tt.wok {
|
||||
t.Errorf("#%d: ok = %v, want %v", i, ok, tt.wok)
|
||||
}
|
||||
if halt := len(sr.errorc) > 0; halt != tt.whalt {
|
||||
t.Errorf("#%d: halt = %v, want %v", i, halt, tt.whalt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -203,7 +209,7 @@ func TestStream(t *testing.T) {
|
||||
h.sw = sw
|
||||
|
||||
picker := mustNewURLPicker(t, []string{srv.URL})
|
||||
sr := startStreamReader(&http.Transport{}, picker, tt.t, types.ID(1), types.ID(2), types.ID(1), recvc, propc)
|
||||
sr := startStreamReader(&http.Transport{}, picker, tt.t, types.ID(1), types.ID(2), types.ID(1), recvc, propc, nil)
|
||||
defer sr.stop()
|
||||
if tt.t == streamTypeMsgApp {
|
||||
sr.updateMsgAppTerm(tt.term)
|
||||
|
@ -28,6 +28,7 @@ import (
|
||||
|
||||
type Raft interface {
|
||||
Process(ctx context.Context, m raftpb.Message) error
|
||||
IsIDRemoved(id uint64) bool
|
||||
ReportUnreachable(id uint64)
|
||||
ReportSnapshot(id uint64, status raft.SnapshotStatus)
|
||||
}
|
||||
@ -98,7 +99,7 @@ func NewTransporter(rt http.RoundTripper, id, cid types.ID, r Raft, errorc chan
|
||||
|
||||
func (t *transport) Handler() http.Handler {
|
||||
pipelineHandler := NewHandler(t.raft, t.clusterID)
|
||||
streamHandler := newStreamHandler(t, t.id, t.clusterID)
|
||||
streamHandler := newStreamHandler(t, t.raft, t.id, t.clusterID)
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle(RaftPrefix, pipelineHandler)
|
||||
mux.Handle(RaftStreamPrefix+"/", streamHandler)
|
||||
|
@ -88,6 +88,8 @@ func (r *countRaft) Process(ctx context.Context, m raftpb.Message) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *countRaft) IsIDRemoved(id uint64) bool { return false }
|
||||
|
||||
func (r *countRaft) ReportUnreachable(id uint64) {}
|
||||
|
||||
func (r *countRaft) ReportSnapshot(id uint64, status raft.SnapshotStatus) {}
|
||||
|
Loading…
x
Reference in New Issue
Block a user