mirror of
https://github.com/etcd-io/etcd.git
synced 2024-09-27 06:25:44 +00:00
rafthttp: stop etcd if it is found to be removed during stream dial
Previously, etcd stopped itself only when a pipeline message revealed that the local member had been removed. After this PR, a stream dial detects the removal as well. This lets etcd stop faster: it no longer has to wait for the stream to break and fall back to the pipeline, and then wait for an election timeout to send out a message, before detecting its own removal.
This commit is contained in:
parent
be6f49ba32
commit
1c1cccd236
@ -348,6 +348,8 @@ func (s *EtcdServer) Process(ctx context.Context, m raftpb.Message) error {
|
||||
return s.r.Step(ctx, m)
|
||||
}
|
||||
|
||||
// IsIDRemoved reports whether the member identified by id has been removed,
// as answered by the server's Cluster.
func (s *EtcdServer) IsIDRemoved(id uint64) bool {
	memberID := types.ID(id)
	return s.Cluster.IsIDRemoved(memberID)
}
|
||||
|
||||
// ReportUnreachable passes the unreachable report for id on to s.r.
func (s *EtcdServer) ReportUnreachable(id uint64) {
	s.r.ReportUnreachable(id)
}
|
||||
|
||||
func (s *EtcdServer) ReportSnapshot(id uint64, status raft.SnapshotStatus) {
|
||||
|
@ -134,8 +134,9 @@ func waitStreamWorking(p *peer) bool {
|
||||
}
|
||||
|
||||
type fakeRaft struct {
|
||||
recvc chan<- raftpb.Message
|
||||
err error
|
||||
recvc chan<- raftpb.Message
|
||||
err error
|
||||
removedID uint64
|
||||
}
|
||||
|
||||
func (p *fakeRaft) Process(ctx context.Context, m raftpb.Message) error {
|
||||
@ -146,6 +147,8 @@ func (p *fakeRaft) Process(ctx context.Context, m raftpb.Message) error {
|
||||
return p.err
|
||||
}
|
||||
|
||||
// IsIDRemoved reports whether id equals the removedID configured on the fake.
func (p *fakeRaft) IsIDRemoved(id uint64) bool {
	return p.removedID == id
}
|
||||
|
||||
// ReportUnreachable is a no-op; fakeRaft ignores the report.
func (p *fakeRaft) ReportUnreachable(id uint64) {}
|
||||
|
||||
// ReportSnapshot is a no-op; fakeRaft ignores the snapshot status.
func (p *fakeRaft) ReportSnapshot(id uint64, status raft.SnapshotStatus) {}
|
||||
|
@ -46,9 +46,10 @@ type peerGetter interface {
|
||||
Get(id types.ID) Peer
|
||||
}
|
||||
|
||||
func newStreamHandler(peerGetter peerGetter, id, cid types.ID) http.Handler {
|
||||
func newStreamHandler(peerGetter peerGetter, r Raft, id, cid types.ID) http.Handler {
|
||||
return &streamHandler{
|
||||
peerGetter: peerGetter,
|
||||
r: r,
|
||||
id: id,
|
||||
cid: cid,
|
||||
}
|
||||
@ -112,6 +113,7 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
type streamHandler struct {
|
||||
peerGetter peerGetter
|
||||
r Raft
|
||||
id types.ID
|
||||
cid types.ID
|
||||
}
|
||||
@ -145,6 +147,11 @@ func (h *streamHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "invalid from", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
if h.r.IsIDRemoved(uint64(from)) {
|
||||
log.Printf("rafthttp: reject the stream from peer %s since it was removed", from)
|
||||
http.Error(w, "removed member", http.StatusGone)
|
||||
return
|
||||
}
|
||||
p := h.peerGetter.Get(from)
|
||||
if p == nil {
|
||||
log.Printf("rafthttp: fail to find sender %s", from)
|
||||
|
@ -17,6 +17,7 @@ package rafthttp
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@ -185,7 +186,7 @@ func TestServeRaftStreamPrefix(t *testing.T) {
|
||||
|
||||
peer := newFakePeer()
|
||||
peerGetter := &fakePeerGetter{peers: map[types.ID]Peer{types.ID(1): peer}}
|
||||
h := newStreamHandler(peerGetter, types.ID(2), types.ID(1))
|
||||
h := newStreamHandler(peerGetter, &fakeRaft{}, types.ID(2), types.ID(1))
|
||||
|
||||
rw := httptest.NewRecorder()
|
||||
go h.ServeHTTP(rw, req)
|
||||
@ -207,6 +208,7 @@ func TestServeRaftStreamPrefix(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestServeRaftStreamPrefixBad(t *testing.T) {
|
||||
removedID := uint64(5)
|
||||
tests := []struct {
|
||||
method string
|
||||
path string
|
||||
@ -263,6 +265,14 @@ func TestServeRaftStreamPrefixBad(t *testing.T) {
|
||||
"1",
|
||||
http.StatusNotFound,
|
||||
},
|
||||
// removed peer
|
||||
{
|
||||
"GET",
|
||||
RaftStreamPrefix + "/message/" + fmt.Sprint(removedID),
|
||||
"1",
|
||||
"1",
|
||||
http.StatusGone,
|
||||
},
|
||||
// wrong cluster ID
|
||||
{
|
||||
"GET",
|
||||
@ -289,7 +299,8 @@ func TestServeRaftStreamPrefixBad(t *testing.T) {
|
||||
req.Header.Set("X-Raft-To", tt.remote)
|
||||
rw := httptest.NewRecorder()
|
||||
peerGetter := &fakePeerGetter{peers: map[types.ID]Peer{types.ID(1): newFakePeer()}}
|
||||
h := newStreamHandler(peerGetter, types.ID(1), types.ID(1))
|
||||
r := &fakeRaft{removedID: removedID}
|
||||
h := newStreamHandler(peerGetter, r, types.ID(1), types.ID(1))
|
||||
h.ServeHTTP(rw, req)
|
||||
|
||||
if rw.Code != tt.wcode {
|
||||
|
@ -149,8 +149,8 @@ func startPeer(tr http.RoundTripper, urls types.URLs, local, to, cid types.ID, r
|
||||
|
||||
go func() {
|
||||
var paused bool
|
||||
msgAppReader := startStreamReader(tr, picker, streamTypeMsgAppV2, local, to, cid, p.recvc, p.propc)
|
||||
reader := startStreamReader(tr, picker, streamTypeMessage, local, to, cid, p.recvc, p.propc)
|
||||
msgAppReader := startStreamReader(tr, picker, streamTypeMsgAppV2, local, to, cid, p.recvc, p.propc, errorc)
|
||||
reader := startStreamReader(tr, picker, streamTypeMessage, local, to, cid, p.recvc, p.propc, errorc)
|
||||
for {
|
||||
select {
|
||||
case m := <-p.sendc:
|
||||
|
@ -226,6 +226,7 @@ type streamReader struct {
|
||||
cid types.ID
|
||||
recvc chan<- raftpb.Message
|
||||
propc chan<- raftpb.Message
|
||||
errorc chan<- error
|
||||
|
||||
mu sync.Mutex
|
||||
msgAppTerm uint64
|
||||
@ -235,7 +236,7 @@ type streamReader struct {
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
func startStreamReader(tr http.RoundTripper, picker *urlPicker, t streamType, from, to, cid types.ID, recvc chan<- raftpb.Message, propc chan<- raftpb.Message) *streamReader {
|
||||
func startStreamReader(tr http.RoundTripper, picker *urlPicker, t streamType, from, to, cid types.ID, recvc chan<- raftpb.Message, propc chan<- raftpb.Message, errorc chan<- error) *streamReader {
|
||||
r := &streamReader{
|
||||
tr: tr,
|
||||
picker: picker,
|
||||
@ -245,6 +246,7 @@ func startStreamReader(tr http.RoundTripper, picker *urlPicker, t streamType, fr
|
||||
cid: cid,
|
||||
recvc: recvc,
|
||||
propc: propc,
|
||||
errorc: errorc,
|
||||
stopc: make(chan struct{}),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
@ -367,11 +369,21 @@ func (cr *streamReader) dial() (io.ReadCloser, error) {
|
||||
cr.picker.unreachable(u)
|
||||
return nil, fmt.Errorf("error roundtripping to %s: %v", req.URL, err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
switch resp.StatusCode {
|
||||
case http.StatusGone:
|
||||
resp.Body.Close()
|
||||
err := fmt.Errorf("the member has been permanently removed from the cluster")
|
||||
select {
|
||||
case cr.errorc <- err:
|
||||
default:
|
||||
}
|
||||
return nil, err
|
||||
case http.StatusOK:
|
||||
return resp.Body, nil
|
||||
default:
|
||||
resp.Body.Close()
|
||||
return nil, fmt.Errorf("unhandled http status %d", resp.StatusCode)
|
||||
}
|
||||
return resp.Body, nil
|
||||
}
|
||||
|
||||
func (cr *streamReader) cancelRequest() {
|
||||
|
@ -119,15 +119,17 @@ func TestStreamReaderDialRequest(t *testing.T) {
|
||||
// HTTP response received.
|
||||
func TestStreamReaderDialResult(t *testing.T) {
|
||||
tests := []struct {
|
||||
code int
|
||||
err error
|
||||
wok bool
|
||||
code int
|
||||
err error
|
||||
wok bool
|
||||
whalt bool
|
||||
}{
|
||||
{0, errors.New("blah"), false},
|
||||
{http.StatusOK, nil, true},
|
||||
{http.StatusMethodNotAllowed, nil, false},
|
||||
{http.StatusNotFound, nil, false},
|
||||
{http.StatusPreconditionFailed, nil, false},
|
||||
{0, errors.New("blah"), false, false},
|
||||
{http.StatusOK, nil, true, false},
|
||||
{http.StatusMethodNotAllowed, nil, false, false},
|
||||
{http.StatusNotFound, nil, false, false},
|
||||
{http.StatusPreconditionFailed, nil, false, false},
|
||||
{http.StatusGone, nil, false, true},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
tr := newRespRoundTripper(tt.code, tt.err)
|
||||
@ -138,12 +140,16 @@ func TestStreamReaderDialResult(t *testing.T) {
|
||||
from: types.ID(1),
|
||||
to: types.ID(2),
|
||||
cid: types.ID(1),
|
||||
errorc: make(chan error, 1),
|
||||
}
|
||||
|
||||
_, err := sr.dial()
|
||||
if ok := err == nil; ok != tt.wok {
|
||||
t.Errorf("#%d: ok = %v, want %v", i, ok, tt.wok)
|
||||
}
|
||||
if halt := len(sr.errorc) > 0; halt != tt.whalt {
|
||||
t.Errorf("#%d: halt = %v, want %v", i, halt, tt.whalt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -203,7 +209,7 @@ func TestStream(t *testing.T) {
|
||||
h.sw = sw
|
||||
|
||||
picker := mustNewURLPicker(t, []string{srv.URL})
|
||||
sr := startStreamReader(&http.Transport{}, picker, tt.t, types.ID(1), types.ID(2), types.ID(1), recvc, propc)
|
||||
sr := startStreamReader(&http.Transport{}, picker, tt.t, types.ID(1), types.ID(2), types.ID(1), recvc, propc, nil)
|
||||
defer sr.stop()
|
||||
if tt.t == streamTypeMsgApp {
|
||||
sr.updateMsgAppTerm(tt.term)
|
||||
|
@ -28,6 +28,7 @@ import (
|
||||
|
||||
type Raft interface {
|
||||
Process(ctx context.Context, m raftpb.Message) error
|
||||
IsIDRemoved(id uint64) bool
|
||||
ReportUnreachable(id uint64)
|
||||
ReportSnapshot(id uint64, status raft.SnapshotStatus)
|
||||
}
|
||||
@ -98,7 +99,7 @@ func NewTransporter(rt http.RoundTripper, id, cid types.ID, r Raft, errorc chan
|
||||
|
||||
func (t *transport) Handler() http.Handler {
|
||||
pipelineHandler := NewHandler(t.raft, t.clusterID)
|
||||
streamHandler := newStreamHandler(t, t.id, t.clusterID)
|
||||
streamHandler := newStreamHandler(t, t.raft, t.id, t.clusterID)
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle(RaftPrefix, pipelineHandler)
|
||||
mux.Handle(RaftStreamPrefix+"/", streamHandler)
|
||||
|
@ -88,6 +88,8 @@ func (r *countRaft) Process(ctx context.Context, m raftpb.Message) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsIDRemoved always reports false: countRaft treats no ID as removed.
func (r *countRaft) IsIDRemoved(id uint64) bool {
	return false
}
|
||||
|
||||
// ReportUnreachable is a no-op; countRaft ignores the report.
func (r *countRaft) ReportUnreachable(id uint64) {}
|
||||
|
||||
// ReportSnapshot is a no-op; countRaft ignores the snapshot status.
func (r *countRaft) ReportSnapshot(id uint64, status raft.SnapshotStatus) {}
|
||||
|
Loading…
x
Reference in New Issue
Block a user