diff --git a/raft/multinode.go b/raft/multinode.go index fe5d45eb4..42d2a6959 100644 --- a/raft/multinode.go +++ b/raft/multinode.go @@ -27,7 +27,7 @@ type MultiNode interface { // CreateGroup adds a new group to the MultiNode. The application must call CreateGroup // on each particpating node with the same group ID; it may create groups on demand as it // receives messages. If the given storage contains existing log entries the list of peers - // may be empty. The Config.ID field will be ignored and replaced by the ID passed + // may be empty. If Config.ID field is zero it will be replaced by the ID passed // to StartMultiNode. CreateGroup(group uint64, c *Config, peers []Peer) error // RemoveGroup removes a group from the MultiNode. @@ -62,9 +62,10 @@ type MultiNode interface { Stop() } -// StartMultiNode creates a MultiNode and starts its background goroutine. -// The id identifies this node and will be used as its node ID in all groups. -// The election and heartbeat timers are in units of ticks. +// StartMultiNode creates a MultiNode and starts its background +// goroutine. If id is non-zero it identifies this node and will be +// used as its node ID in all groups. The election and heartbeat +// timers are in units of ticks. func StartMultiNode(id uint64) MultiNode { mn := newMultiNode(id) go mn.run() @@ -193,7 +194,12 @@ func (mn *multiNode) run() { var group *groupState select { case gc := <-mn.groupc: - gc.config.ID = mn.id + if (gc.config.ID != mn.id) && (gc.config.ID != 0 && mn.id != 0) { + panic("if gc.config.ID and mn.id differ, one of them must be zero") + } + if gc.config.ID == 0 { + gc.config.ID = mn.id + } r := newRaft(gc.config) group = &groupState{ id: gc.id, @@ -240,9 +246,9 @@ func (mn *multiNode) run() { // has a leader; we can't do that since we have one propc for many groups. // We'll have to buffer somewhere on a group-by-group basis, or just let // raft.Step drop any such proposals on the floor. - mm.msg.From = mn.id var ok bool if group, ok = groups[mm.group]; ok { + mm.msg.From = group.raft.id group.raft.Step(mm.msg) } diff --git a/raft/multinode_test.go b/raft/multinode_test.go index dd343d170..060e2a091 100644 --- a/raft/multinode_test.go +++ b/raft/multinode_test.go @@ -476,3 +476,98 @@ func TestMultiNodeStatus(t *testing.T) { t.Errorf("expected nil status, got %+v", status) } } + +// TestMultiNodePerGroupID tests that MultiNode may have a different +// node ID for each group, if and only if the Config.ID field is +// filled in when calling CreateGroup. +func TestMultiNodePerGroupID(t *testing.T) { + storage := NewMemoryStorage() + mn := StartMultiNode(0) + + // Maps group ID to node ID. + groups := map[uint64]uint64{ + 1: 10, + 2: 20, + } + + // Create two groups. + for g, nodeID := range groups { + err := mn.CreateGroup(g, newTestConfig(nodeID, nil, 10, 1, storage), + []Peer{{ID: nodeID}, {ID: nodeID + 1}, {ID: nodeID + 2}}) + if err != nil { + t.Fatal(err) + } + } + + // Campaign on both groups. + for g := range groups { + err := mn.Campaign(context.Background(), g) + if err != nil { + t.Fatal(err) + } + } + + // All outgoing messages (two MsgVotes for each group) should have + // the correct From IDs. + var rd map[uint64]Ready + select { + case rd = <-mn.Ready(): + case <-time.After(100 * time.Millisecond): + t.Fatal("timed out waiting for ready") + } + for g, nodeID := range groups { + if len(rd[g].Messages) != 2 { + t.Errorf("expected 2 messages in group %d; got %d", g, len(rd[g].Messages)) + } + + for _, m := range rd[g].Messages { + if m.From != nodeID { + t.Errorf("expected %s message in group %d to have From: %d; got %d", + m.Type, g, nodeID, m.From) + } + } + } + mn.Advance(rd) + + // Become a follower in both groups. + for g, nodeID := range groups { + err := mn.Step(context.Background(), g, raftpb.Message{ + Type: raftpb.MsgHeartbeat, + To: nodeID, + From: nodeID + 1, + }) + if err != nil { + t.Fatal(err) + } + } + + // Propose a command on each group (Propose is tested separately + // because proposals in follower mode go through a different code path). + for g := range groups { + err := mn.Propose(context.Background(), g, []byte("foo")) + if err != nil { + t.Fatal(err) + } + } + + // Validate that all outgoing messages (heartbeat response and + // proposal) have the correct From IDs. + select { + case rd = <-mn.Ready(): + case <-time.After(100 * time.Millisecond): + t.Fatal("timed out waiting for ready") + } + for g, nodeID := range groups { + if len(rd[g].Messages) != 2 { + t.Errorf("expected 2 messages in group %d; got %d", g, len(rd[g].Messages)) + } + + for _, m := range rd[g].Messages { + if m.From != nodeID { + t.Errorf("expected %s message in group %d to have From: %d; got %d", + m.Type, g, nodeID, m.From) + } + } + } + mn.Advance(rd) +}