raft: Propose in raft node wait the proposal result so we can fail fast while dropping proposal.

This commit is contained in:
Vincent Lee 2018-01-12 15:41:59 +08:00
parent bf052ef491
commit f0dffb4163
2 changed files with 101 additions and 16 deletions

View File

@ -224,9 +224,14 @@ func RestartNode(c *Config) Node {
return &n return &n
} }
type msgWithResult struct {
m pb.Message
result chan error
}
// node is the canonical implementation of the Node interface // node is the canonical implementation of the Node interface
type node struct { type node struct {
propc chan pb.Message propc chan msgWithResult
recvc chan pb.Message recvc chan pb.Message
confc chan pb.ConfChange confc chan pb.ConfChange
confstatec chan pb.ConfState confstatec chan pb.ConfState
@ -242,7 +247,7 @@ type node struct {
func newNode() node { func newNode() node {
return node{ return node{
propc: make(chan pb.Message), propc: make(chan msgWithResult),
recvc: make(chan pb.Message), recvc: make(chan pb.Message),
confc: make(chan pb.ConfChange), confc: make(chan pb.ConfChange),
confstatec: make(chan pb.ConfState), confstatec: make(chan pb.ConfState),
@ -271,7 +276,7 @@ func (n *node) Stop() {
} }
func (n *node) run(r *raft) { func (n *node) run(r *raft) {
var propc chan pb.Message var propc chan msgWithResult
var readyc chan Ready var readyc chan Ready
var advancec chan struct{} var advancec chan struct{}
var prevLastUnstablei, prevLastUnstablet uint64 var prevLastUnstablei, prevLastUnstablet uint64
@ -314,13 +319,18 @@ func (n *node) run(r *raft) {
// TODO: maybe buffer the config propose if there exists one (the way // TODO: maybe buffer the config propose if there exists one (the way
// described in raft dissertation) // described in raft dissertation)
// Currently it is dropped in Step silently. // Currently it is dropped in Step silently.
case m := <-propc: case pm := <-propc:
m := pm.m
m.From = r.id m.From = r.id
r.Step(m) err := r.Step(m)
if pm.result != nil {
pm.result <- err
close(pm.result)
}
case m := <-n.recvc: case m := <-n.recvc:
// filter out response message from unknown From. // filter out response message from unknown From.
if pr := r.getProgress(m.From); pr != nil || !IsResponseMsg(m.Type) { if pr := r.getProgress(m.From); pr != nil || !IsResponseMsg(m.Type) {
r.Step(m) // raft never returns an error r.Step(m)
} }
case cc := <-n.confc: case cc := <-n.confc:
if cc.NodeID == None { if cc.NodeID == None {
@ -408,7 +418,7 @@ func (n *node) Tick() {
func (n *node) Campaign(ctx context.Context) error { return n.step(ctx, pb.Message{Type: pb.MsgHup}) } func (n *node) Campaign(ctx context.Context) error { return n.step(ctx, pb.Message{Type: pb.MsgHup}) }
func (n *node) Propose(ctx context.Context, data []byte) error { func (n *node) Propose(ctx context.Context, data []byte) error {
return n.step(ctx, pb.Message{Type: pb.MsgProp, Entries: []pb.Entry{{Data: data}}}) return n.stepWait(ctx, pb.Message{Type: pb.MsgProp, Entries: []pb.Entry{{Data: data}}})
} }
func (n *node) Step(ctx context.Context, m pb.Message) error { func (n *node) Step(ctx context.Context, m pb.Message) error {
@ -428,16 +438,20 @@ func (n *node) ProposeConfChange(ctx context.Context, cc pb.ConfChange) error {
return n.Step(ctx, pb.Message{Type: pb.MsgProp, Entries: []pb.Entry{{Type: pb.EntryConfChange, Data: data}}}) return n.Step(ctx, pb.Message{Type: pb.MsgProp, Entries: []pb.Entry{{Type: pb.EntryConfChange, Data: data}}})
} }
// Step advances the state machine using msgs. The ctx.Err() will be returned,
// if any.
func (n *node) step(ctx context.Context, m pb.Message) error { func (n *node) step(ctx context.Context, m pb.Message) error {
ch := n.recvc return n.stepWithWaitOption(ctx, m, false)
if m.Type == pb.MsgProp {
ch = n.propc
} }
func (n *node) stepWait(ctx context.Context, m pb.Message) error {
return n.stepWithWaitOption(ctx, m, true)
}
// Step advances the state machine using msgs. The ctx.Err() will be returned,
// if any.
func (n *node) stepWithWaitOption(ctx context.Context, m pb.Message, wait bool) error {
if m.Type != pb.MsgProp {
select { select {
case ch <- m: case n.recvc <- m:
return nil return nil
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return ctx.Err()
@ -445,6 +459,33 @@ func (n *node) step(ctx context.Context, m pb.Message) error {
return ErrStopped return ErrStopped
} }
} }
ch := n.propc
pm := msgWithResult{m: m}
if wait {
pm.result = make(chan error, 1)
}
select {
case ch <- pm:
if !wait {
return nil
}
case <-ctx.Done():
return ctx.Err()
case <-n.done:
return ErrStopped
}
select {
case rsp := <-pm.result:
if rsp != nil {
return rsp
}
case <-ctx.Done():
return ctx.Err()
case <-n.done:
return ErrStopped
}
return nil
}
func (n *node) Ready() <-chan Ready { return n.readyc } func (n *node) Ready() <-chan Ready { return n.readyc }

View File

@ -18,6 +18,7 @@ import (
"bytes" "bytes"
"context" "context"
"reflect" "reflect"
"strings"
"testing" "testing"
"time" "time"
@ -30,7 +31,7 @@ import (
func TestNodeStep(t *testing.T) { func TestNodeStep(t *testing.T) {
for i, msgn := range raftpb.MessageType_name { for i, msgn := range raftpb.MessageType_name {
n := &node{ n := &node{
propc: make(chan raftpb.Message, 1), propc: make(chan msgWithResult, 1),
recvc: make(chan raftpb.Message, 1), recvc: make(chan raftpb.Message, 1),
} }
msgt := raftpb.MessageType(i) msgt := raftpb.MessageType(i)
@ -64,7 +65,7 @@ func TestNodeStep(t *testing.T) {
func TestNodeStepUnblock(t *testing.T) { func TestNodeStepUnblock(t *testing.T) {
// a node without buffer to block step // a node without buffer to block step
n := &node{ n := &node{
propc: make(chan raftpb.Message), propc: make(chan msgWithResult),
done: make(chan struct{}), done: make(chan struct{}),
} }
@ -433,6 +434,49 @@ func TestBlockProposal(t *testing.T) {
} }
} }
func TestNodeProposeWaitDropped(t *testing.T) {
msgs := []raftpb.Message{}
droppingMsg := []byte("test_dropping")
dropStep := func(r *raft, m raftpb.Message) error {
if m.Type == raftpb.MsgProp && strings.Contains(m.String(), string(droppingMsg)) {
t.Logf("dropping message: %v", m.String())
return ErrProposalDropped
}
msgs = append(msgs, m)
return nil
}
n := newNode()
s := NewMemoryStorage()
r := newTestRaft(1, []uint64{1}, 10, 1, s)
go n.run(r)
n.Campaign(context.TODO())
for {
rd := <-n.Ready()
s.Append(rd.Entries)
// change the step function to dropStep until this raft becomes leader
if rd.SoftState.Lead == r.id {
r.step = dropStep
n.Advance()
break
}
n.Advance()
}
proposalTimeout := time.Millisecond * 100
ctx, cancel := context.WithTimeout(context.Background(), proposalTimeout)
// propose with cancel should be cancelled earyly if dropped
err := n.Propose(ctx, droppingMsg)
if err != ErrProposalDropped {
t.Errorf("should drop proposal : %v", err)
}
cancel()
n.Stop()
if len(msgs) != 0 {
t.Fatalf("len(msgs) = %d, want %d", len(msgs), 1)
}
}
// TestNodeTick ensures that node.Tick() will increase the // TestNodeTick ensures that node.Tick() will increase the
// elapsed of the underlying raft state machine. // elapsed of the underlying raft state machine.
func TestNodeTick(t *testing.T) { func TestNodeTick(t *testing.T) {