diff --git a/rafthttp/msg_codec.go b/rafthttp/msg_codec.go index 281ff4aa9..8bd4d0119 100644 --- a/rafthttp/msg_codec.go +++ b/rafthttp/msg_codec.go @@ -16,6 +16,7 @@ package rafthttp import ( "encoding/binary" + "errors" "io" "github.com/coreos/etcd/pkg/pbutil" @@ -41,12 +42,20 @@ type messageDecoder struct { r io.Reader } +var ( + readBytesLimit uint64 = 512 * 1024 // 512 MB + ErrExceedSizeLimit = errors.New("rafthttp: error limit exceeded") +) + func (dec *messageDecoder) decode() (raftpb.Message, error) { var m raftpb.Message var l uint64 if err := binary.Read(dec.r, binary.BigEndian, &l); err != nil { return m, err } + if l > readBytesLimit { + return m, ErrExceedSizeLimit + } buf := make([]byte, int(l)) if _, err := io.ReadFull(dec.r, buf); err != nil { return m, err diff --git a/rafthttp/msg_codec_test.go b/rafthttp/msg_codec_test.go index 39e296b3f..56043a41f 100644 --- a/rafthttp/msg_codec_test.go +++ b/rafthttp/msg_codec_test.go @@ -23,43 +23,68 @@ import ( ) func TestMessage(t *testing.T) { - tests := []raftpb.Message{ + tests := []struct { + msg raftpb.Message + encodeErr error + decodeErr error + }{ { - Type: raftpb.MsgApp, - From: 1, - To: 2, - Term: 1, - LogTerm: 1, - Index: 3, - Entries: []raftpb.Entry{{Term: 1, Index: 4}}, - }, - { - Type: raftpb.MsgProp, - From: 1, - To: 2, - Entries: []raftpb.Entry{ - {Data: []byte("some data")}, - {Data: []byte("some data")}, - {Data: []byte("some data")}, + raftpb.Message{ + Type: raftpb.MsgApp, + From: 1, + To: 2, + Term: 1, + LogTerm: 1, + Index: 3, + Entries: []raftpb.Entry{{Term: 1, Index: 4}}, }, + nil, + nil, + }, + { + raftpb.Message{ + Type: raftpb.MsgProp, + From: 1, + To: 2, + Entries: []raftpb.Entry{ + {Data: []byte("some data")}, + {Data: []byte("some data")}, + {Data: []byte("some data")}, + }, + }, + nil, + nil, + }, + { + raftpb.Message{ + Type: raftpb.MsgProp, + From: 1, + To: 2, + Entries: []raftpb.Entry{ + {Data: bytes.Repeat([]byte("a"), int(readBytesLimit+10))}, + }, + }, + nil, + ErrExceedSizeLimit, }, - linkHeartbeatMessage, } for i, tt := range tests { b := &bytes.Buffer{} enc := &messageEncoder{w: b} - if err := enc.encode(&tt); err != nil { - t.Errorf("#%d: unexpected encode message error: %v", i, err) + if err := enc.encode(&tt.msg); err != tt.encodeErr { + t.Errorf("#%d: encode message error expected %v, got %v", i, tt.encodeErr, err) continue } dec := &messageDecoder{r: b} m, err := dec.decode() - if err != nil { - t.Errorf("#%d: unexpected decode message error: %v", i, err) + if err != tt.decodeErr { + t.Errorf("#%d: decode message error expected %v, got %v", i, tt.decodeErr, err) continue } - if !reflect.DeepEqual(m, tt) { - t.Errorf("#%d: message = %+v, want %+v", i, m, tt) + if err == nil { + if !reflect.DeepEqual(m, tt.msg) { + t.Errorf("#%d: message = %+v, want %+v", i, m, tt.msg) + } } } }