diff --git a/raft/raftpb/raft.proto b/raft/raftpb/raft.proto index f21f3afe5..3b058ba62 100644 --- a/raft/raftpb/raft.proto +++ b/raft/raftpb/raft.proto @@ -60,6 +60,9 @@ enum MessageType { MsgReadIndexResp = 16; MsgPreVote = 17; MsgPreVoteResp = 18; + // NOTE: when adding new message types, remember to update the isLocalMsg and + // isResponseMsg arrays in raft/util.go and update the corresponding tests in + // raft/util_test.go. } message Message { diff --git a/raft/util.go b/raft/util.go index 6e728fb01..0510d3f7e 100644 --- a/raft/util.go +++ b/raft/util.go @@ -40,13 +40,34 @@ func max(a, b uint64) uint64 { return b } +var isLocalMsg = [...]bool{ + pb.MsgHup: true, + pb.MsgBeat: true, + pb.MsgUnreachable: true, + pb.MsgSnapStatus: true, + pb.MsgCheckQuorum: true, +} + +var isResponseMsg = [...]bool{ + pb.MsgAppResp: true, + pb.MsgVoteResp: true, + pb.MsgHeartbeatResp: true, + pb.MsgUnreachable: true, + pb.MsgReadIndexResp: true, + pb.MsgPreVoteResp: true, +} + +func isMsgInArray(msgt pb.MessageType, arr []bool) bool { + i := int(msgt) + return i < len(arr) && arr[i] +} + func IsLocalMsg(msgt pb.MessageType) bool { - return msgt == pb.MsgHup || msgt == pb.MsgBeat || msgt == pb.MsgUnreachable || - msgt == pb.MsgSnapStatus || msgt == pb.MsgCheckQuorum + return isMsgInArray(msgt, isLocalMsg[:]) } func IsResponseMsg(msgt pb.MessageType) bool { - return msgt == pb.MsgAppResp || msgt == pb.MsgVoteResp || msgt == pb.MsgHeartbeatResp || msgt == pb.MsgUnreachable || msgt == pb.MsgPreVoteResp + return isMsgInArray(msgt, isResponseMsg[:]) } // voteResponseType maps vote and prevote message types to their corresponding responses. diff --git a/raft/util_test.go b/raft/util_test.go index eeedebd41..627bdf676 100644 --- a/raft/util_test.go +++ b/raft/util_test.go @@ -96,3 +96,37 @@ func TestIsLocalMsg(t *testing.T) { }) } } + +func TestIsResponseMsg(t *testing.T) { + tests := []struct { + msgt pb.MessageType + isResponse bool + }{ + {pb.MsgHup, false}, + {pb.MsgBeat, false}, + {pb.MsgUnreachable, true}, + {pb.MsgSnapStatus, false}, + {pb.MsgCheckQuorum, false}, + {pb.MsgTransferLeader, false}, + {pb.MsgProp, false}, + {pb.MsgApp, false}, + {pb.MsgAppResp, true}, + {pb.MsgVote, false}, + {pb.MsgVoteResp, true}, + {pb.MsgSnap, false}, + {pb.MsgHeartbeat, false}, + {pb.MsgHeartbeatResp, true}, + {pb.MsgTimeoutNow, false}, + {pb.MsgReadIndex, false}, + {pb.MsgReadIndexResp, true}, + {pb.MsgPreVote, false}, + {pb.MsgPreVoteResp, true}, + } + + for i, tt := range tests { + got := IsResponseMsg(tt.msgt) + if got != tt.isResponse { + t.Errorf("#%d: got %v, want %v", i, got, tt.isResponse) + } + } +}