diff --git a/client/pkg/fileutil/filereader.go b/client/pkg/fileutil/filereader.go new file mode 100644 index 000000000..55248888c --- /dev/null +++ b/client/pkg/fileutil/filereader.go @@ -0,0 +1,60 @@ +// Copyright 2022 The etcd Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fileutil + +import ( + "bufio" + "io" + "io/fs" + "os" +) + +// FileReader is a wrapper of io.Reader. It also provides file info. +type FileReader interface { + io.Reader + FileInfo() (fs.FileInfo, error) +} + +type fileReader struct { + *os.File +} + +func NewFileReader(f *os.File) FileReader { + return &fileReader{f} +} + +func (fr *fileReader) FileInfo() (fs.FileInfo, error) { + return fr.Stat() +} + +// FileBufReader is a wrapper of bufio.Reader. It also provides file info. +type FileBufReader struct { + *bufio.Reader + fi fs.FileInfo +} + +func NewFileBufReader(fr FileReader) *FileBufReader { + bufReader := bufio.NewReader(fr) + fi, err := fr.FileInfo() + if err != nil { + // This should never happen. + panic(err) + } + return &FileBufReader{bufReader, fi} +} + +func (fbr *FileBufReader) FileInfo() fs.FileInfo { + return fbr.fi +} diff --git a/client/pkg/fileutil/filereader_test.go b/client/pkg/fileutil/filereader_test.go new file mode 100644 index 000000000..2f863cdce --- /dev/null +++ b/client/pkg/fileutil/filereader_test.go @@ -0,0 +1,44 @@ +// Copyright 2022 The etcd Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package fileutil + +import ( + "os" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestFileBufReader(t *testing.T) { + f, err := os.CreateTemp(t.TempDir(), "wal") + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + fi, err := f.Stat() + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + fbr := NewFileBufReader(NewFileReader(f)) + + if !strings.HasPrefix(fbr.FileInfo().Name(), "wal") { + t.Errorf("Unexpected file name: %s", fbr.FileInfo().Name()) + } + assert.Equal(t, fi.Size(), fbr.FileInfo().Size()) + assert.Equal(t, fi.IsDir(), fbr.FileInfo().IsDir()) + assert.Equal(t, fi.Mode(), fbr.FileInfo().Mode()) + assert.Equal(t, fi.ModTime(), fbr.FileInfo().ModTime()) +} diff --git a/server/storage/wal/decoder.go b/server/storage/wal/decoder.go index 7cc634a2e..99ca9fdc4 100644 --- a/server/storage/wal/decoder.go +++ b/server/storage/wal/decoder.go @@ -15,12 +15,13 @@ package wal import ( - "bufio" "encoding/binary" + "fmt" "hash" "io" "sync" + "go.etcd.io/etcd/client/pkg/v3/fileutil" "go.etcd.io/etcd/pkg/v3/crc" "go.etcd.io/etcd/pkg/v3/pbutil" "go.etcd.io/etcd/raft/v3/raftpb" @@ -34,17 +35,17 @@ const frameSizeBytes = 8 type decoder struct { mu sync.Mutex - brs []*bufio.Reader + brs []*fileutil.FileBufReader // lastValidOff file offset following the last valid decoded record lastValidOff int64 crc hash.Hash32 } -func newDecoder(r ...io.Reader) *decoder { - readers := make([]*bufio.Reader, len(r)) +func newDecoder(r ...fileutil.FileReader) *decoder { + readers := make([]*fileutil.FileBufReader, len(r)) for i := range r { - readers[i] = bufio.NewReader(r[i]) + readers[i] = fileutil.NewFileBufReader(r[i]) } return &decoder{ brs: readers, @@ -59,17 +60,13 @@ func (d *decoder) decode(rec *walpb.Record) error { return d.decodeRecord(rec) } -// raft max message size is set to 1 MB in etcd server -// assume projects set reasonable message size limit, -// thus entry size should never exceed 10 MB -const maxWALEntrySizeLimit = int64(10 * 1024 * 1024) - func (d *decoder) decodeRecord(rec *walpb.Record) error { if len(d.brs) == 0 { return io.EOF } - l, err := readInt64(d.brs[0]) + fileBufReader := d.brs[0] + l, err := readInt64(fileBufReader) if err == io.EOF || (err == nil && l == 0) { // hit end of file or preallocated space d.brs = d.brs[1:] @@ -84,12 +81,15 @@ func (d *decoder) decodeRecord(rec *walpb.Record) error { } recBytes, padBytes := decodeFrameSize(l) - if recBytes >= maxWALEntrySizeLimit-padBytes { - return ErrMaxWALEntrySizeLimitExceeded + // The length of current WAL entry must be less than the remaining file size. + maxEntryLimit := fileBufReader.FileInfo().Size() - d.lastValidOff - padBytes + if recBytes > maxEntryLimit { + return fmt.Errorf("wal: max entry size limit exceeded, recBytes: %d, fileSize(%d) - offset(%d) - padBytes(%d) = entryLimit(%d)", + recBytes, fileBufReader.FileInfo().Size(), d.lastValidOff, padBytes, maxEntryLimit) } data := make([]byte, recBytes+padBytes) - if _, err = io.ReadFull(d.brs[0], data); err != nil { + if _, err = io.ReadFull(fileBufReader, data); err != nil { // ReadFull returns io.EOF only if no bytes were read // the decoder should treat this as an ErrUnexpectedEOF instead. if err == io.EOF { diff --git a/server/storage/wal/record_test.go b/server/storage/wal/record_test.go index 68bcb0bda..0a01d6e6f 100644 --- a/server/storage/wal/record_test.go +++ b/server/storage/wal/record_test.go @@ -19,9 +19,11 @@ import ( "errors" "hash/crc32" "io" + "os" "reflect" "testing" + "go.etcd.io/etcd/client/pkg/v3/fileutil" "go.etcd.io/etcd/server/v3/storage/wal/walpb" ) @@ -42,8 +44,7 @@ func TestReadRecord(t *testing.T) { }{ {infoRecord, &walpb.Record{Type: 1, Crc: crc32.Checksum(infoData, crcTable), Data: infoData}, nil}, {[]byte(""), &walpb.Record{}, io.EOF}, - {infoRecord[:8], &walpb.Record{}, io.ErrUnexpectedEOF}, - {infoRecord[:len(infoRecord)-len(infoData)-8], &walpb.Record{}, io.ErrUnexpectedEOF}, + {infoRecord[:14], &walpb.Record{}, io.ErrUnexpectedEOF}, {infoRecord[:len(infoRecord)-len(infoData)], &walpb.Record{}, io.ErrUnexpectedEOF}, {infoRecord[:len(infoRecord)-8], &walpb.Record{}, io.ErrUnexpectedEOF}, {badInfoRecord, &walpb.Record{}, walpb.ErrCRCMismatch}, @@ -52,7 +53,11 @@ func TestReadRecord(t *testing.T) { rec := &walpb.Record{} for i, tt := range tests { buf := bytes.NewBuffer(tt.data) - decoder := newDecoder(io.NopCloser(buf)) + f, err := createFileWithData(t, buf) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + decoder := newDecoder(fileutil.NewFileReader(f)) e := decoder.decode(rec) if !reflect.DeepEqual(rec, tt.wr) { t.Errorf("#%d: block = %v, want %v", i, rec, tt.wr) @@ -72,8 +77,12 @@ func TestWriteRecord(t *testing.T) { e := newEncoder(buf, 0, 0) e.encode(&walpb.Record{Type: typ, Data: d}) e.flush() - decoder := newDecoder(io.NopCloser(buf)) - err := decoder.decode(b) + f, err := createFileWithData(t, buf) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + decoder := newDecoder(fileutil.NewFileReader(f)) + err = decoder.decode(b) if err != nil { t.Errorf("err = %v, want nil", err) } @@ -84,3 +93,15 @@ func TestWriteRecord(t *testing.T) { t.Errorf("data = %v, want %v", b.Data, d) } } + +func createFileWithData(t *testing.T, bf *bytes.Buffer) (*os.File, error) { + f, err := os.CreateTemp(t.TempDir(), "wal") + if err != nil { + return nil, err + } + if _, err := f.Write(bf.Bytes()); err != nil { + return nil, err + } + f.Seek(0, 0) + return f, nil +} diff --git a/server/storage/wal/repair.go b/server/storage/wal/repair.go index c007763de..78083d45b 100644 --- a/server/storage/wal/repair.go +++ b/server/storage/wal/repair.go @@ -40,7 +40,7 @@ func Repair(lg *zap.Logger, dirpath string) bool { lg.Info("repairing", zap.String("path", f.Name())) rec := &walpb.Record{} - decoder := newDecoder(f) + decoder := newDecoder(fileutil.NewFileReader(f.File)) for { lastOffset := decoder.lastOffset() err := decoder.decode(rec) diff --git a/server/storage/wal/wal.go b/server/storage/wal/wal.go index 187cfe397..aa68a1a73 100644 --- a/server/storage/wal/wal.go +++ b/server/storage/wal/wal.go @@ -54,15 +54,14 @@ var ( // so that tests can set a different segment size. SegmentSizeBytes int64 = 64 * 1000 * 1000 // 64MB - ErrMetadataConflict = errors.New("wal: conflicting metadata found") - ErrFileNotFound = errors.New("wal: file not found") - ErrCRCMismatch = errors.New("wal: crc mismatch") - ErrSnapshotMismatch = errors.New("wal: snapshot mismatch") - ErrSnapshotNotFound = errors.New("wal: snapshot not found") - ErrSliceOutOfRange = errors.New("wal: slice bounds out of range") - ErrMaxWALEntrySizeLimitExceeded = errors.New("wal: max entry size limit exceeded") - ErrDecoderNotFound = errors.New("wal: decoder not found") - crcTable = crc32.MakeTable(crc32.Castagnoli) + ErrMetadataConflict = errors.New("wal: conflicting metadata found") + ErrFileNotFound = errors.New("wal: file not found") + ErrCRCMismatch = errors.New("wal: crc mismatch") + ErrSnapshotMismatch = errors.New("wal: snapshot mismatch") + ErrSnapshotNotFound = errors.New("wal: snapshot not found") + ErrSliceOutOfRange = errors.New("wal: slice bounds out of range") + ErrDecoderNotFound = errors.New("wal: decoder not found") + crcTable = crc32.MakeTable(crc32.Castagnoli) ) // WAL is a logical representation of the stable storage. @@ -386,12 +385,13 @@ func selectWALFiles(lg *zap.Logger, dirpath string, snap walpb.Snapshot) ([]stri return names, nameIndex, nil } -func openWALFiles(lg *zap.Logger, dirpath string, names []string, nameIndex int, write bool) ([]io.Reader, []*fileutil.LockedFile, func() error, error) { +func openWALFiles(lg *zap.Logger, dirpath string, names []string, nameIndex int, write bool) ([]fileutil.FileReader, []*fileutil.LockedFile, func() error, error) { rcs := make([]io.ReadCloser, 0) - rs := make([]io.Reader, 0) + rs := make([]fileutil.FileReader, 0) ls := make([]*fileutil.LockedFile, 0) for _, name := range names[nameIndex:] { p := filepath.Join(dirpath, name) + var f *os.File if write { l, err := fileutil.TryLockFile(p, os.O_RDWR, fileutil.PrivateFileMode) if err != nil { @@ -400,6 +400,7 @@ func openWALFiles(lg *zap.Logger, dirpath string, names []string, nameIndex int, } ls = append(ls, l) rcs = append(rcs, l) + f = l.File } else { rf, err := os.OpenFile(p, os.O_RDONLY, fileutil.PrivateFileMode) if err != nil { @@ -408,8 +409,10 @@ func openWALFiles(lg *zap.Logger, dirpath string, names []string, nameIndex int, } ls = append(ls, nil) rcs = append(rcs, rf) + f = rf } - rs = append(rs, rcs[len(rcs)-1]) + fileReader := fileutil.NewFileReader(f) + rs = append(rs, fileReader) } closer := func() error { return closeAll(lg, rcs...) } diff --git a/server/storage/wal/wal_test.go b/server/storage/wal/wal_test.go index 16e552f9a..e988bb4e5 100644 --- a/server/storage/wal/wal_test.go +++ b/server/storage/wal/wal_test.go @@ -19,6 +19,7 @@ import ( "fmt" "io" "math" + "math/rand" "os" "path" "path/filepath" @@ -297,7 +298,7 @@ func TestCut(t *testing.T) { } defer f.Close() nw := &WAL{ - decoder: newDecoder(f), + decoder: newDecoder(fileutil.NewFileReader(f)), start: snap, } _, gst, _, err := nw.ReadAll() @@ -369,47 +370,77 @@ func TestSaveWithCut(t *testing.T) { } func TestRecover(t *testing.T) { - p := t.TempDir() - - w, err := Create(zaptest.NewLogger(t), p, []byte("metadata")) - if err != nil { - t.Fatal(err) - } - if err = w.SaveSnapshot(walpb.Snapshot{}); err != nil { - t.Fatal(err) - } - ents := []raftpb.Entry{{Index: 1, Term: 1, Data: []byte{1}}, {Index: 2, Term: 2, Data: []byte{2}}} - if err = w.Save(raftpb.HardState{}, ents); err != nil { - t.Fatal(err) - } - sts := []raftpb.HardState{{Term: 1, Vote: 1, Commit: 1}, {Term: 2, Vote: 2, Commit: 2}} - for _, s := range sts { - if err = w.Save(s, nil); err != nil { - t.Fatal(err) - } - } - w.Close() - - if w, err = Open(zaptest.NewLogger(t), p, walpb.Snapshot{}); err != nil { - t.Fatal(err) - } - metadata, state, entries, err := w.ReadAll() - if err != nil { - t.Fatal(err) + cases := []struct { + name string + size int + }{ + { + name: "10MB", + size: 10 * 1024 * 1024, + }, + { + name: "20MB", + size: 20 * 1024 * 1024, + }, + { + name: "40MB", + size: 40 * 1024 * 1024, + }, } - if !bytes.Equal(metadata, []byte("metadata")) { - t.Errorf("metadata = %s, want %s", metadata, "metadata") + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + p := t.TempDir() + + w, err := Create(zaptest.NewLogger(t), p, []byte("metadata")) + if err != nil { + t.Fatal(err) + } + if err = w.SaveSnapshot(walpb.Snapshot{}); err != nil { + t.Fatal(err) + } + + data := make([]byte, tc.size) + n, err := rand.Read(data) + assert.Equal(t, tc.size, n) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + ents := []raftpb.Entry{{Index: 1, Term: 1, Data: data}, {Index: 2, Term: 2, Data: data}} + if err = w.Save(raftpb.HardState{}, ents); err != nil { + t.Fatal(err) + } + sts := []raftpb.HardState{{Term: 1, Vote: 1, Commit: 1}, {Term: 2, Vote: 2, Commit: 2}} + for _, s := range sts { + if err = w.Save(s, nil); err != nil { + t.Fatal(err) + } + } + w.Close() + + if w, err = Open(zaptest.NewLogger(t), p, walpb.Snapshot{}); err != nil { + t.Fatal(err) + } + metadata, state, entries, err := w.ReadAll() + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(metadata, []byte("metadata")) { + t.Errorf("metadata = %s, want %s", metadata, "metadata") + } + if !reflect.DeepEqual(entries, ents) { + t.Errorf("ents = %+v, want %+v", entries, ents) + } + // only the latest state is recorded + s := sts[len(sts)-1] + if !reflect.DeepEqual(state, s) { + t.Errorf("state = %+v, want %+v", state, s) + } + w.Close() + }) } - if !reflect.DeepEqual(entries, ents) { - t.Errorf("ents = %+v, want %+v", entries, ents) - } - // only the latest state is recorded - s := sts[len(sts)-1] - if !reflect.DeepEqual(state, s) { - t.Errorf("state = %+v, want %+v", state, s) - } - w.Close() } func TestSearchIndex(t *testing.T) {