diff --git a/wal/wal.go b/wal/wal.go index 676723525..593d4e939 100644 --- a/wal/wal.go +++ b/wal/wal.go @@ -22,7 +22,13 @@ type WAL struct { } func New(path string) (*WAL, error) { - f, err := os.Create(path) + f, err := os.Open(path) + if err == nil { + f.Close() + return nil, os.ErrExist + } + + f, err = os.Create(path) if err != nil { return nil, err } @@ -30,6 +36,13 @@ func New(path string) (*WAL, error) { return &WAL{f, bw}, nil } +func (w *WAL) Close() { + if w.f != nil { + w.flush() + w.f.Close() + } +} + func (w *WAL) writeInfo(id int64) error { // | 8 bytes | 8 bytes | 8 bytes | // | type | len | nodeid | diff --git a/wal/wal_test.go b/wal/wal_test.go new file mode 100644 index 000000000..c79dfb420 --- /dev/null +++ b/wal/wal_test.go @@ -0,0 +1,32 @@ +package wal + +import ( + "io/ioutil" + "os" + "testing" +) + +func TestNew(t *testing.T) { + f, err := ioutil.TempFile(os.TempDir(), "waltest") + if err != nil { + t.Fatal(err) + } + p := f.Name() + _, err = New(p) + if err == nil || err != os.ErrExist { + t.Errorf("err = %v, want %v", err, os.ErrExist) + } + err = os.Remove(p) + if err != nil { + t.Fatal(err) + } + w, err := New(p) + if err != nil { + t.Errorf("err = %v, want nil", err) + } + w.Close() + err = os.Remove(p) + if err != nil { + t.Fatal(err) + } +}