diff --git a/server/wal/decoder.go b/server/wal/decoder.go index 0251a7213..2656d286a 100644 --- a/server/wal/decoder.go +++ b/server/wal/decoder.go @@ -15,12 +15,13 @@ package wal import ( - "bufio" "encoding/binary" + "fmt" "hash" "io" "sync" + "go.etcd.io/etcd/client/pkg/v3/fileutil" "go.etcd.io/etcd/pkg/v3/crc" "go.etcd.io/etcd/pkg/v3/pbutil" "go.etcd.io/etcd/raft/v3/raftpb" @@ -34,17 +35,17 @@ const frameSizeBytes = 8 type decoder struct { mu sync.Mutex - brs []*bufio.Reader + brs []*fileutil.FileBufReader // lastValidOff file offset following the last valid decoded record lastValidOff int64 crc hash.Hash32 } -func newDecoder(r ...io.Reader) *decoder { - readers := make([]*bufio.Reader, len(r)) +func newDecoder(r ...fileutil.FileReader) *decoder { + readers := make([]*fileutil.FileBufReader, len(r)) for i := range r { - readers[i] = bufio.NewReader(r[i]) + readers[i] = fileutil.NewFileBufReader(r[i]) } return &decoder{ brs: readers, @@ -59,17 +60,13 @@ func (d *decoder) decode(rec *walpb.Record) error { return d.decodeRecord(rec) } -// raft max message size is set to 1 MB in etcd server -// assume projects set reasonable message size limit, -// thus entry size should never exceed 10 MB -const maxWALEntrySizeLimit = int64(10 * 1024 * 1024) - func (d *decoder) decodeRecord(rec *walpb.Record) error { if len(d.brs) == 0 { return io.EOF } - l, err := readInt64(d.brs[0]) + fileBufReader := d.brs[0] + l, err := readInt64(fileBufReader) if err == io.EOF || (err == nil && l == 0) { // hit end of file or preallocated space d.brs = d.brs[1:] @@ -84,12 +81,15 @@ func (d *decoder) decodeRecord(rec *walpb.Record) error { } recBytes, padBytes := decodeFrameSize(l) - if recBytes >= maxWALEntrySizeLimit-padBytes { - return ErrMaxWALEntrySizeLimitExceeded + // The length of current WAL entry must be less than the remaining file size. + maxEntryLimit := fileBufReader.FileInfo().Size() - d.lastValidOff - padBytes + if recBytes > maxEntryLimit { + return fmt.Errorf("wal: max entry size limit exceeded, recBytes: %d, fileSize(%d) - offset(%d) - padBytes(%d) = entryLimit(%d)", + recBytes, fileBufReader.FileInfo().Size(), d.lastValidOff, padBytes, maxEntryLimit) } data := make([]byte, recBytes+padBytes) - if _, err = io.ReadFull(d.brs[0], data); err != nil { + if _, err = io.ReadFull(fileBufReader, data); err != nil { // ReadFull returns io.EOF only if no bytes were read // the decoder should treat this as an ErrUnexpectedEOF instead. if err == io.EOF { diff --git a/server/wal/record_test.go b/server/wal/record_test.go index d28807ebb..49a2c8eac 100644 --- a/server/wal/record_test.go +++ b/server/wal/record_test.go @@ -19,10 +19,11 @@ import ( "errors" "hash/crc32" "io" - "io/ioutil" + "os" "reflect" "testing" + "go.etcd.io/etcd/client/pkg/v3/fileutil" "go.etcd.io/etcd/server/v3/wal/walpb" ) @@ -43,8 +44,7 @@ func TestReadRecord(t *testing.T) { }{ {infoRecord, &walpb.Record{Type: 1, Crc: crc32.Checksum(infoData, crcTable), Data: infoData}, nil}, {[]byte(""), &walpb.Record{}, io.EOF}, - {infoRecord[:8], &walpb.Record{}, io.ErrUnexpectedEOF}, - {infoRecord[:len(infoRecord)-len(infoData)-8], &walpb.Record{}, io.ErrUnexpectedEOF}, + {infoRecord[:14], &walpb.Record{}, io.ErrUnexpectedEOF}, {infoRecord[:len(infoRecord)-len(infoData)], &walpb.Record{}, io.ErrUnexpectedEOF}, {infoRecord[:len(infoRecord)-8], &walpb.Record{}, io.ErrUnexpectedEOF}, {badInfoRecord, &walpb.Record{}, walpb.ErrCRCMismatch}, @@ -53,7 +53,11 @@ func TestReadRecord(t *testing.T) { rec := &walpb.Record{} for i, tt := range tests { buf := bytes.NewBuffer(tt.data) - decoder := newDecoder(ioutil.NopCloser(buf)) + f, err := createFileWithData(t, buf) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + decoder := newDecoder(fileutil.NewFileReader(f)) e := decoder.decode(rec) if !reflect.DeepEqual(rec, tt.wr) { t.Errorf("#%d: block = %v, want %v", i, rec, tt.wr) @@ -73,8 +77,12 @@ func TestWriteRecord(t *testing.T) { e := newEncoder(buf, 0, 0) e.encode(&walpb.Record{Type: typ, Data: d}) e.flush() - decoder := newDecoder(ioutil.NopCloser(buf)) - err := decoder.decode(b) + f, err := createFileWithData(t, buf) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + decoder := newDecoder(fileutil.NewFileReader(f)) + err = decoder.decode(b) if err != nil { t.Errorf("err = %v, want nil", err) } @@ -85,3 +93,15 @@ func TestWriteRecord(t *testing.T) { t.Errorf("data = %v, want %v", b.Data, d) } } + +func createFileWithData(t *testing.T, bf *bytes.Buffer) (*os.File, error) { + f, err := os.CreateTemp(t.TempDir(), "wal") + if err != nil { + return nil, err + } + if _, err := f.Write(bf.Bytes()); err != nil { + return nil, err + } + f.Seek(0, 0) + return f, nil +} diff --git a/server/wal/repair.go b/server/wal/repair.go index 122ee49a6..0ed842546 100644 --- a/server/wal/repair.go +++ b/server/wal/repair.go @@ -40,7 +40,7 @@ func Repair(lg *zap.Logger, dirpath string) bool { lg.Info("repairing", zap.String("path", f.Name())) rec := &walpb.Record{} - decoder := newDecoder(f) + decoder := newDecoder(fileutil.NewFileReader(f.File)) for { lastOffset := decoder.lastOffset() err := decoder.decode(rec) diff --git a/server/wal/wal.go b/server/wal/wal.go index 3c940e0cd..0d652220e 100644 --- a/server/wal/wal.go +++ b/server/wal/wal.go @@ -54,15 +54,14 @@ var ( // so that tests can set a different segment size. SegmentSizeBytes int64 = 64 * 1000 * 1000 // 64MB - ErrMetadataConflict = errors.New("wal: conflicting metadata found") - ErrFileNotFound = errors.New("wal: file not found") - ErrCRCMismatch = errors.New("wal: crc mismatch") - ErrSnapshotMismatch = errors.New("wal: snapshot mismatch") - ErrSnapshotNotFound = errors.New("wal: snapshot not found") - ErrSliceOutOfRange = errors.New("wal: slice bounds out of range") - ErrMaxWALEntrySizeLimitExceeded = errors.New("wal: max entry size limit exceeded") - ErrDecoderNotFound = errors.New("wal: decoder not found") - crcTable = crc32.MakeTable(crc32.Castagnoli) + ErrMetadataConflict = errors.New("wal: conflicting metadata found") + ErrFileNotFound = errors.New("wal: file not found") + ErrCRCMismatch = errors.New("wal: crc mismatch") + ErrSnapshotMismatch = errors.New("wal: snapshot mismatch") + ErrSnapshotNotFound = errors.New("wal: snapshot not found") + ErrSliceOutOfRange = errors.New("wal: slice bounds out of range") + ErrDecoderNotFound = errors.New("wal: decoder not found") + crcTable = crc32.MakeTable(crc32.Castagnoli) ) // WAL is a logical representation of the stable storage. @@ -378,12 +377,13 @@ func selectWALFiles(lg *zap.Logger, dirpath string, snap walpb.Snapshot) ([]stri return names, nameIndex, nil } -func openWALFiles(lg *zap.Logger, dirpath string, names []string, nameIndex int, write bool) ([]io.Reader, []*fileutil.LockedFile, func() error, error) { +func openWALFiles(lg *zap.Logger, dirpath string, names []string, nameIndex int, write bool) ([]fileutil.FileReader, []*fileutil.LockedFile, func() error, error) { rcs := make([]io.ReadCloser, 0) - rs := make([]io.Reader, 0) + rs := make([]fileutil.FileReader, 0) ls := make([]*fileutil.LockedFile, 0) for _, name := range names[nameIndex:] { p := filepath.Join(dirpath, name) + var f *os.File if write { l, err := fileutil.TryLockFile(p, os.O_RDWR, fileutil.PrivateFileMode) if err != nil { @@ -392,6 +392,7 @@ func openWALFiles(lg *zap.Logger, dirpath string, names []string, nameIndex int, } ls = append(ls, l) rcs = append(rcs, l) + f = l.File } else { rf, err := os.OpenFile(p, os.O_RDONLY, fileutil.PrivateFileMode) if err != nil { @@ -400,8 +401,10 @@ func openWALFiles(lg *zap.Logger, dirpath string, names []string, nameIndex int, } ls = append(ls, nil) rcs = append(rcs, rf) + f = rf } - rs = append(rs, rcs[len(rcs)-1]) + fileReader := fileutil.NewFileReader(f) + rs = append(rs, fileReader) } closer := func() error { return closeAll(lg, rcs...) } diff --git a/server/wal/wal_test.go b/server/wal/wal_test.go index 05014086c..e3b829cff 100644 --- a/server/wal/wal_test.go +++ b/server/wal/wal_test.go @@ -20,6 +20,7 @@ import ( "io" "io/ioutil" "math" + "math/rand" "os" "path" "path/filepath" @@ -335,7 +336,7 @@ func TestCut(t *testing.T) { } defer f.Close() nw := &WAL{ - decoder: newDecoder(f), + decoder: newDecoder(fileutil.NewFileReader(f)), start: snap, } _, gst, _, err := nw.ReadAll() @@ -411,51 +412,76 @@ func TestSaveWithCut(t *testing.T) { } func TestRecover(t *testing.T) { - p, err := ioutil.TempDir(t.TempDir(), "waltest") - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll(p) - - w, err := Create(zap.NewExample(), p, []byte("metadata")) - if err != nil { - t.Fatal(err) - } - if err = w.SaveSnapshot(walpb.Snapshot{}); err != nil { - t.Fatal(err) - } - ents := []raftpb.Entry{{Index: 1, Term: 1, Data: []byte{1}}, {Index: 2, Term: 2, Data: []byte{2}}} - if err = w.Save(raftpb.HardState{}, ents); err != nil { - t.Fatal(err) - } - sts := []raftpb.HardState{{Term: 1, Vote: 1, Commit: 1}, {Term: 2, Vote: 2, Commit: 2}} - for _, s := range sts { - if err = w.Save(s, nil); err != nil { - t.Fatal(err) - } - } - w.Close() - - if w, err = Open(zap.NewExample(), p, walpb.Snapshot{}); err != nil { - t.Fatal(err) - } - metadata, state, entries, err := w.ReadAll() - if err != nil { - t.Fatal(err) + cases := []struct { + name string + size int + }{ + { + name: "10MB", + size: 10 * 1024 * 1024, + }, + { + name: "20MB", + size: 20 * 1024 * 1024, + }, + { + name: "40MB", + size: 40 * 1024 * 1024, + }, } - if !bytes.Equal(metadata, []byte("metadata")) { - t.Errorf("metadata = %s, want %s", metadata, "metadata") + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + p := t.TempDir() + + w, err := Create(zap.NewExample(), p, []byte("metadata")) + if err != nil { + t.Fatal(err) + } + if err = w.SaveSnapshot(walpb.Snapshot{}); err != nil { + t.Fatal(err) + } + + data := make([]byte, tc.size) + n, err := rand.Read(data) + assert.Equal(t, tc.size, n) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + ents := []raftpb.Entry{{Index: 1, Term: 1, Data: data}, {Index: 2, Term: 2, Data: data}} + if err = w.Save(raftpb.HardState{}, ents); err != nil { + t.Fatal(err) + } + sts := []raftpb.HardState{{Term: 1, Vote: 1, Commit: 1}, {Term: 2, Vote: 2, Commit: 2}} + for _, s := range sts { + if err = w.Save(s, nil); err != nil { + t.Fatal(err) + } + } + w.Close() + + if w, err = Open(zap.NewExample(), p, walpb.Snapshot{}); err != nil { + t.Fatal(err) + } + metadata, state, entries, err := w.ReadAll() + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(metadata, []byte("metadata")) { + t.Errorf("metadata = %s, want %s", metadata, "metadata") + } + if !reflect.DeepEqual(entries, ents) { + t.Errorf("ents = %+v, want %+v", entries, ents) + } + // only the latest state is recorded + s := sts[len(sts)-1] + if !reflect.DeepEqual(state, s) { + t.Errorf("state = %+v, want %+v", state, s) + } + w.Close() + }) } - if !reflect.DeepEqual(entries, ents) { - t.Errorf("ents = %+v, want %+v", entries, ents) - } - // only the latest state is recorded - s := sts[len(sts)-1] - if !reflect.DeepEqual(state, s) { - t.Errorf("state = %+v, want %+v", state, s) - } - w.Close() } func TestSearchIndex(t *testing.T) {