Merge pull request #6349 from gyuho/decode-length-limit

rafthttp: check decode size before buffer alloc
This commit is contained in:
Gyu-Ho Lee 2016-09-05 14:25:23 +09:00 committed by GitHub
commit a66b1e7c60
2 changed files with 58 additions and 24 deletions

View File

@ -16,6 +16,7 @@ package rafthttp
import ( import (
"encoding/binary" "encoding/binary"
"errors"
"io" "io"
"github.com/coreos/etcd/pkg/pbutil" "github.com/coreos/etcd/pkg/pbutil"
@ -41,12 +42,20 @@ type messageDecoder struct {
r io.Reader r io.Reader
} }
var (
readBytesLimit uint64 = 512 * 1024 // 512 MB
ErrExceedSizeLimit = errors.New("rafthttp: error limit exceeded")
)
func (dec *messageDecoder) decode() (raftpb.Message, error) { func (dec *messageDecoder) decode() (raftpb.Message, error) {
var m raftpb.Message var m raftpb.Message
var l uint64 var l uint64
if err := binary.Read(dec.r, binary.BigEndian, &l); err != nil { if err := binary.Read(dec.r, binary.BigEndian, &l); err != nil {
return m, err return m, err
} }
if l > readBytesLimit {
return m, ErrExceedSizeLimit
}
buf := make([]byte, int(l)) buf := make([]byte, int(l))
if _, err := io.ReadFull(dec.r, buf); err != nil { if _, err := io.ReadFull(dec.r, buf); err != nil {
return m, err return m, err

View File

@ -23,8 +23,13 @@ import (
) )
func TestMessage(t *testing.T) { func TestMessage(t *testing.T) {
tests := []raftpb.Message{ tests := []struct {
msg raftpb.Message
encodeErr error
decodeErr error
}{
{ {
raftpb.Message{
Type: raftpb.MsgApp, Type: raftpb.MsgApp,
From: 1, From: 1,
To: 2, To: 2,
@ -33,7 +38,11 @@ func TestMessage(t *testing.T) {
Index: 3, Index: 3,
Entries: []raftpb.Entry{{Term: 1, Index: 4}}, Entries: []raftpb.Entry{{Term: 1, Index: 4}},
}, },
nil,
nil,
},
{ {
raftpb.Message{
Type: raftpb.MsgProp, Type: raftpb.MsgProp,
From: 1, From: 1,
To: 2, To: 2,
@ -43,23 +52,39 @@ func TestMessage(t *testing.T) {
{Data: []byte("some data")}, {Data: []byte("some data")},
}, },
}, },
linkHeartbeatMessage, nil,
nil,
},
{
raftpb.Message{
Type: raftpb.MsgProp,
From: 1,
To: 2,
Entries: []raftpb.Entry{
{Data: bytes.Repeat([]byte("a"), int(readBytesLimit+10))},
},
},
nil,
ErrExceedSizeLimit,
},
} }
for i, tt := range tests { for i, tt := range tests {
b := &bytes.Buffer{} b := &bytes.Buffer{}
enc := &messageEncoder{w: b} enc := &messageEncoder{w: b}
if err := enc.encode(&tt); err != nil { if err := enc.encode(&tt.msg); err != tt.encodeErr {
t.Errorf("#%d: unexpected encode message error: %v", i, err) t.Errorf("#%d: encode message error expected %v, got %v", i, tt.encodeErr, err)
continue continue
} }
dec := &messageDecoder{r: b} dec := &messageDecoder{r: b}
m, err := dec.decode() m, err := dec.decode()
if err != nil { if err != tt.decodeErr {
t.Errorf("#%d: unexpected decode message error: %v", i, err) t.Errorf("#%d: decode message error expected %v, got %v", i, tt.decodeErr, err)
continue continue
} }
if !reflect.DeepEqual(m, tt) { if err == nil {
t.Errorf("#%d: message = %+v, want %+v", i, m, tt) if !reflect.DeepEqual(m, tt.msg) {
t.Errorf("#%d: message = %+v, want %+v", i, m, tt.msg)
}
} }
} }
} }