From 3319f716d9146cd283bac61e35be94c0f6020c42 Mon Sep 17 00:00:00 2001 From: Xiang Li Date: Sat, 3 Jan 2015 19:39:33 -0800 Subject: [PATCH] rafthttp: a stopped stream does not accept any methods --- rafthttp/streamer.go | 34 ++++++++++++++++++++++++---------- 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/rafthttp/streamer.go b/rafthttp/streamer.go index c7a49c366..3f1654e3a 100644 --- a/rafthttp/streamer.go +++ b/rafthttp/streamer.go @@ -17,6 +17,7 @@ package rafthttp import ( + "errors" "fmt" "io" "log" @@ -41,24 +42,27 @@ const ( // TODO: a stream might hava one stream server or one stream client, but not both. type stream struct { - // the server might be attached asynchronously with the owner of the stream - // use a mutex to protect it sync.Mutex - w *streamWriter - - r *streamReader + w *streamWriter + r *streamReader + stopped bool } func (s *stream) open(from, to, cid types.ID, term uint64, tr http.RoundTripper, u string, r Raft) error { - if s.r != nil { - panic("open: stream is open") - } - c, err := newStreamReader(from, to, cid, term, tr, u, r) if err != nil { log.Printf("stream: error opening stream: %v", err) return err } + + s.Lock() + defer s.Unlock() + if s.stopped { + return errors.New("stream: stopped") + } + if s.r != nil { + panic("open: stream is open") + } s.r = c return nil } @@ -66,6 +70,9 @@ func (s *stream) open(from, to, cid types.ID, term uint64, tr http.RoundTripper, func (s *stream) attach(sw *streamWriter) error { s.Lock() defer s.Unlock() + if s.stopped { + return errors.New("stream: stopped") + } if s.w != nil { // ignore lower-term streaming request if sw.term < s.w.term { @@ -80,6 +87,9 @@ func (s *stream) attach(sw *streamWriter) error { func (s *stream) write(m raftpb.Message) bool { s.Lock() defer s.Unlock() + if s.stopped { + return false + } if s.w == nil { return false } @@ -105,7 +115,6 @@ func (s *stream) write(m raftpb.Message) bool { func (s *stream) invalidate(term uint64) { s.Lock() defer s.Unlock() - if s.w != nil { if s.w.term < term { s.w.stop() @@ -118,6 +127,9 @@ func (s *stream) invalidate(term uint64) { s.r = nil } } + if term == math.MaxUint64 { + s.stopped = true + } } func (s *stream) stop() { @@ -125,6 +137,8 @@ func (s *stream) stop() { } func (s *stream) isOpen() bool { + s.Lock() + defer s.Unlock() if s.r != nil && s.r.isStopped() { s.r = nil }