From 7b8de8194b65a0b605db3b0deb8b9d449cfade6e Mon Sep 17 00:00:00 2001 From: new-dream <111836360+new-dream@users.noreply.github.com> Date: Fri, 25 Aug 2023 22:02:45 +0800 Subject: [PATCH] pkg: add a verification on the pagebytes which must be > 0 Signed-off-by: n00607095 --- pkg/ioutil/pagewriter.go | 4 ++++ pkg/ioutil/pagewriter_test.go | 41 +++++++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+) diff --git a/pkg/ioutil/pagewriter.go b/pkg/ioutil/pagewriter.go index cf9a8dc66..62eb5cd43 100644 --- a/pkg/ioutil/pagewriter.go +++ b/pkg/ioutil/pagewriter.go @@ -15,6 +15,7 @@ package ioutil import ( + "fmt" "io" ) @@ -41,6 +42,9 @@ type PageWriter struct { // NewPageWriter creates a new PageWriter. pageBytes is the number of bytes // to write per page. pageOffset is the starting offset of io.Writer. func NewPageWriter(w io.Writer, pageBytes, pageOffset int) *PageWriter { + if pageBytes <= 0 { + panic(fmt.Sprintf("assertion failed: invalid pageBytes (%d) value, it must be greater than 0", pageBytes)) + } return &PageWriter{ w: w, pageOffset: pageOffset, diff --git a/pkg/ioutil/pagewriter_test.go b/pkg/ioutil/pagewriter_test.go index 305bc6320..e05c71f7c 100644 --- a/pkg/ioutil/pagewriter_test.go +++ b/pkg/ioutil/pagewriter_test.go @@ -17,6 +17,8 @@ package ioutil import ( "math/rand" "testing" + + "github.com/stretchr/testify/assert" ) func TestPageWriterRandom(t *testing.T) { @@ -111,6 +113,45 @@ func TestPageWriterOffset(t *testing.T) { } } +func TestPageWriterPageBytes(t *testing.T) { + cases := []struct { + name string + pageBytes int + expectPanic bool + }{ + { + name: "normal page bytes", + pageBytes: 4096, + expectPanic: false, + }, + { + name: "negative page bytes", + pageBytes: -1, + expectPanic: true, + }, + { + name: "zero page bytes", + pageBytes: 0, + expectPanic: true, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + defaultBufferBytes = 1024 + cw := &checkPageWriter{pageBytes: tc.pageBytes, t: t} + if tc.expectPanic { + assert.Panicsf(t, func() { + NewPageWriter(cw, tc.pageBytes, 0) + }, "expected panic when pageBytes is %d", tc.pageBytes) + } else { + pw := NewPageWriter(cw, tc.pageBytes, 0) + assert.NotEqual(t, pw, nil) + } + }) + } +} + // checkPageWriter implements an io.Writer that fails a test on unaligned writes. type checkPageWriter struct { pageBytes int