raft: MultiNode.Status returns nil for non-existent groups.

Previously it would panic if the group did not exist.
This commit is contained in:
Ben Darnell 2015-05-20 15:45:38 -04:00
parent 260aad5468
commit d58fac453d
2 changed files with 30 additions and 6 deletions

View File

@ -37,8 +37,9 @@ type MultiNode interface {
// last Ready results. It must be called with the last value returned from the Ready()
// channel.
Advance(map[uint64]Ready)
// Status returns the current status of the given group.
Status(group uint64) Status
// Status returns the current status of the given group. Returns nil if no such group
// exists.
Status(group uint64) *Status
// Report reports the given node is not reachable for the last send.
ReportUnreachable(id, groupID uint64)
// ReportSnapshot reports the stutus of the sent snapshot.
@ -70,7 +71,7 @@ type multiConfChange struct {
type multiStatus struct {
group uint64
ch chan Status
ch chan *Status
}
type groupCreation struct {
@ -299,7 +300,12 @@ func (mn *multiNode) run() {
advancec = nil
case ms := <-mn.status:
ms.ch <- getStatus(groups[ms.group].raft)
if group, ok := groups[ms.group]; ok {
s := getStatus(group.raft)
ms.ch <- &s
} else {
ms.ch <- nil
}
case <-mn.stop:
close(mn.done)
@ -443,10 +449,10 @@ func (mn *multiNode) Advance(rds map[uint64]Ready) {
}
}
func (mn *multiNode) Status(group uint64) Status {
func (mn *multiNode) Status(group uint64) *Status {
ms := multiStatus{
group: group,
ch: make(chan Status),
ch: make(chan *Status),
}
mn.status <- ms
return <-ms.ch

View File

@ -392,3 +392,21 @@ func TestMultiNodeAdvance(t *testing.T) {
t.Errorf("expect Ready after Advance, but there is no Ready available")
}
}
func TestMultiNodeStatus(t *testing.T) {
storage := NewMemoryStorage()
mn := StartMultiNode(1)
err := mn.CreateGroup(1, newTestConfig(1, nil, 10, 1, storage), []Peer{{ID: 1}})
if err != nil {
t.Fatal(err)
}
status := mn.Status(1)
if status == nil {
t.Errorf("expected status struct, got nil")
}
status = mn.Status(2)
if status != nil {
t.Errorf("expected nil status, got %+v", status)
}
}