From 911c8442b7fa6668ba40c82679ac2f8380a8120c Mon Sep 17 00:00:00 2001 From: Anthony Romano Date: Sun, 14 Aug 2016 18:55:08 -0700 Subject: [PATCH] rafthttp: fix race between streamReader.stop() and connection closer --- rafthttp/stream.go | 11 ++++++++- rafthttp/stream_test.go | 55 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 1 deletion(-) diff --git a/rafthttp/stream.go b/rafthttp/stream.go index 8f17bd97b..d07bad89a 100644 --- a/rafthttp/stream.go +++ b/rafthttp/stream.go @@ -332,7 +332,16 @@ func (cr *streamReader) decodeLoop(rc io.ReadCloser, t streamType) error { default: plog.Panicf("unhandled stream type %s", t) } - cr.closer = rc + select { + case <-cr.stopc: + cr.mu.Unlock() + if err := rc.Close(); err != nil { + return err + } + return io.EOF + default: + cr.closer = rc + } cr.mu.Unlock() for { diff --git a/rafthttp/stream_test.go b/rafthttp/stream_test.go index d6e711b3e..98b220e73 100644 --- a/rafthttp/stream_test.go +++ b/rafthttp/stream_test.go @@ -17,6 +17,7 @@ package rafthttp import ( "errors" "fmt" + "io" "net/http" "net/http/httptest" "reflect" @@ -180,6 +181,60 @@ func TestStreamReaderDialResult(t *testing.T) { } } +// TestStreamReaderStopOnDial tests a stream reader closes the connection on stop. +func TestStreamReaderStopOnDial(t *testing.T) { + defer testutil.AfterTest(t) + h := http.Header{} + h.Add("X-Server-Version", version.Version) + tr := &respWaitRoundTripper{rrt: &respRoundTripper{code: http.StatusOK, header: h}} + sr := &streamReader{ + peerID: types.ID(2), + tr: &Transport{streamRt: tr, ClusterID: types.ID(1)}, + picker: mustNewURLPicker(t, []string{"http://localhost:2380"}), + errorc: make(chan error, 1), + typ: streamTypeMessage, + status: newPeerStatus(types.ID(2)), + } + tr.onResp = func() { + // stop() waits for the run() goroutine to exit, but that exit + // needs a response from RoundTrip() first; use goroutine + go sr.stop() + // wait so that stop() is blocked on run() exiting + time.Sleep(10 * time.Millisecond) + // sr.run() completes dialing then begins decoding while stopped + } + sr.start() + select { + case <-sr.done: + case <-time.After(time.Second): + t.Fatal("streamReader did not stop in time") + } +} + +type respWaitRoundTripper struct { + rrt *respRoundTripper + onResp func() +} + +func (t *respWaitRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + resp, err := t.rrt.RoundTrip(req) + resp.Body = newWaitReadCloser() + t.onResp() + return resp, err +} + +type waitReadCloser struct{ closec chan struct{} } + +func newWaitReadCloser() *waitReadCloser { return &waitReadCloser{make(chan struct{})} } +func (wrc *waitReadCloser) Read(p []byte) (int, error) { + <-wrc.closec + return 0, io.EOF +} +func (wrc *waitReadCloser) Close() error { + close(wrc.closec) + return nil +} + // TestStreamReaderDialDetectUnsupport tests that dial func could find // out that the stream type is not supported by the remote. func TestStreamReaderDialDetectUnsupport(t *testing.T) {