rafthttp: check decode size before buffer alloc

Fix https://github.com/coreos/etcd/issues/5386.
This commit is contained in:
Gyu-Ho Lee 2016-09-05 11:57:17 +09:00
parent 2e0dc8467d
commit 5c8ba23767
2 changed files with 58 additions and 24 deletions

View File

@ -16,6 +16,7 @@ package rafthttp
import (
"encoding/binary"
"errors"
"io"
"github.com/coreos/etcd/pkg/pbutil"
@ -41,12 +42,20 @@ type messageDecoder struct {
r io.Reader
}
var (
readBytesLimit uint64 = 512 * 1024 // 512 MB
ErrExceedSizeLimit = errors.New("rafthttp: error limit exceeded")
)
func (dec *messageDecoder) decode() (raftpb.Message, error) {
var m raftpb.Message
var l uint64
if err := binary.Read(dec.r, binary.BigEndian, &l); err != nil {
return m, err
}
if l > readBytesLimit {
return m, ErrExceedSizeLimit
}
buf := make([]byte, int(l))
if _, err := io.ReadFull(dec.r, buf); err != nil {
return m, err

View File

@ -23,43 +23,68 @@ import (
)
func TestMessage(t *testing.T) {
tests := []raftpb.Message{
tests := []struct {
msg raftpb.Message
encodeErr error
decodeErr error
}{
{
Type: raftpb.MsgApp,
From: 1,
To: 2,
Term: 1,
LogTerm: 1,
Index: 3,
Entries: []raftpb.Entry{{Term: 1, Index: 4}},
},
{
Type: raftpb.MsgProp,
From: 1,
To: 2,
Entries: []raftpb.Entry{
{Data: []byte("some data")},
{Data: []byte("some data")},
{Data: []byte("some data")},
raftpb.Message{
Type: raftpb.MsgApp,
From: 1,
To: 2,
Term: 1,
LogTerm: 1,
Index: 3,
Entries: []raftpb.Entry{{Term: 1, Index: 4}},
},
nil,
nil,
},
{
raftpb.Message{
Type: raftpb.MsgProp,
From: 1,
To: 2,
Entries: []raftpb.Entry{
{Data: []byte("some data")},
{Data: []byte("some data")},
{Data: []byte("some data")},
},
},
nil,
nil,
},
{
raftpb.Message{
Type: raftpb.MsgProp,
From: 1,
To: 2,
Entries: []raftpb.Entry{
{Data: bytes.Repeat([]byte("a"), int(readBytesLimit+10))},
},
},
nil,
ErrExceedSizeLimit,
},
linkHeartbeatMessage,
}
for i, tt := range tests {
b := &bytes.Buffer{}
enc := &messageEncoder{w: b}
if err := enc.encode(&tt); err != nil {
t.Errorf("#%d: unexpected encode message error: %v", i, err)
if err := enc.encode(&tt.msg); err != tt.encodeErr {
t.Errorf("#%d: encode message error expected %v, got %v", i, tt.encodeErr, err)
continue
}
dec := &messageDecoder{r: b}
m, err := dec.decode()
if err != nil {
t.Errorf("#%d: unexpected decode message error: %v", i, err)
if err != tt.decodeErr {
t.Errorf("#%d: decode message error expected %v, got %v", i, tt.decodeErr, err)
continue
}
if !reflect.DeepEqual(m, tt) {
t.Errorf("#%d: message = %+v, want %+v", i, m, tt)
if err == nil {
if !reflect.DeepEqual(m, tt.msg) {
t.Errorf("#%d: message = %+v, want %+v", i, m, tt.msg)
}
}
}
}