diff --git a/wal/block.go b/wal/block.go new file mode 100644 index 000000000..1ef76f3ec --- /dev/null +++ b/wal/block.go @@ -0,0 +1,43 @@ +package wal + +import ( + "fmt" + "io" +) + +type block struct { + t int64 + l int64 + d []byte +} + +func writeBlock(w io.Writer, t int64, d []byte) error { + if err := writeInt64(w, t); err != nil { + return err + } + if err := writeInt64(w, int64(len(d))); err != nil { + return err + } + _, err := w.Write(d) + return err +} + +func readBlock(r io.Reader) (*block, error) { + t, err := readInt64(r) + if err != nil { + return nil, err + } + l, err := readInt64(r) + if err != nil { + return nil, unexpectedEOF(err) + } + d := make([]byte, l) + n, err := r.Read(d) + if err != nil { + return nil, unexpectedEOF(err) + } + if n != int(l) { + return nil, fmt.Errorf("len(data) = %d, want %d", n, l) + } + return &block{t, l, d}, nil +} diff --git a/wal/block_test.go b/wal/block_test.go new file mode 100644 index 000000000..01a297c67 --- /dev/null +++ b/wal/block_test.go @@ -0,0 +1,50 @@ +package wal + +import ( + "bytes" + "io" + "reflect" + "testing" +) + +func TestReadBlock(t *testing.T) { + tests := []struct { + data []byte + wb *block + we error + }{ + {infoBlock, &block{1, 8, infoData}, nil}, + {[]byte(""), nil, io.EOF}, + {infoBlock[:len(infoBlock)-len(infoData)-8], nil, io.ErrUnexpectedEOF}, + {infoBlock[:len(infoBlock)-len(infoData)], nil, io.ErrUnexpectedEOF}, + {infoBlock[:len(infoBlock)-8], nil, io.ErrUnexpectedEOF}, + } + + for i, tt := range tests { + buf := bytes.NewBuffer(tt.data) + b, e := readBlock(buf) + if !reflect.DeepEqual(b, tt.wb) { + t.Errorf("#%d: block = %v, want %v", i, b, tt.wb) + } + if !reflect.DeepEqual(e, tt.we) { + t.Errorf("#%d: err = %v, want %v", i, e, tt.we) + } + } +} + +func TestWriteBlock(t *testing.T) { + typ := int64(0xABCD) + d := []byte("Hello world!") + buf := new(bytes.Buffer) + writeBlock(buf, typ, d) + b, err := readBlock(buf) + if err != nil { + t.Errorf("err = %v, want nil", err) + } + if b.t != typ { + t.Errorf("type = %d, want %d", b.t, typ) + } + if !reflect.DeepEqual(b.d, d) { + t.Errorf("data = %v, want %v", b.d, d) + } +} diff --git a/wal/wal.go b/wal/wal.go index 04bc3795a..74bec75ef 100644 --- a/wal/wal.go +++ b/wal/wal.go @@ -49,63 +49,44 @@ func Open(path string) (*WAL, error) { func (w *WAL) Close() { if w.f != nil { - w.flush() + w.Flush() w.f.Close() } } -func (w *WAL) writeInfo(id int64) error { - // | 8 bytes | 8 bytes | 8 bytes | - // | type | len | nodeid | +func (w *WAL) SaveInfo(id int64) error { if err := w.checkAtHead(); err != nil { return err } - if err := w.writeInt64(infoType); err != nil { - return err + // cache the buffer? + buf := new(bytes.Buffer) + err := binary.Write(buf, binary.LittleEndian, id) + if err != nil { + panic(err) } - if err := w.writeInt64(8); err != nil { - return err - } - return w.writeInt64(id) + return writeBlock(w.bw, infoType, buf.Bytes()) } -func (w *WAL) writeEntry(e *raft.Entry) error { - // | 8 bytes | 8 bytes | variable length | - // | type | len | entry data | - if err := w.writeInt64(entryType); err != nil { - return err - } +func (w *WAL) SaveEntry(e *raft.Entry) error { + // protobuf? b, err := json.Marshal(e) if err != nil { - return err + panic(err) } - n := len(b) - if err := w.writeInt64(int64(n)); err != nil { - return err - } - if _, err := w.bw.Write(b); err != nil { - return err - } - return nil + return writeBlock(w.bw, entryType, b) } -func (w *WAL) writeState(s *raft.State) error { - // | 8 bytes | 8 bytes | 24 bytes | - // | type | len | state | - if err := w.writeInt64(stateType); err != nil { - return err +func (w *WAL) SaveState(s *raft.State) error { + // cache the buffer? + buf := new(bytes.Buffer) + err := binary.Write(buf, binary.LittleEndian, s) + if err != nil { + panic(err) } - if err := w.writeInt64(24); err != nil { - return err - } - return binary.Write(w.bw, binary.LittleEndian, s) + return writeBlock(w.bw, stateType, buf.Bytes()) } -func (w *WAL) writeInt64(n int64) error { - return binary.Write(w.bw, binary.LittleEndian, n) -} - -func (w *WAL) flush() error { +func (w *WAL) Flush() error { return w.bw.Flush() } @@ -126,61 +107,51 @@ type Node struct { State raft.State } -func (w *WAL) ReadNode() (*Node, error) { +func (w *WAL) LoadNode() (*Node, error) { if err := w.checkAtHead(); err != nil { return nil, err } br := bufio.NewReader(w.f) - n := new(Node) b, err := readBlock(br) if err != nil { return nil, err } - switch b.t { - case infoType: - id, err := parseInfo(b.d) - if err != nil { - return nil, err - } - n.Id = id - default: - return nil, fmt.Errorf("type = %d, want %d", b.t, infoType) + if b.t != infoType { + return nil, fmt.Errorf("the first block of wal is not infoType but %d", b.t) + } + id, err := loadInfo(b.d) + if err != nil { + return nil, err } ents := make([]raft.Entry, 0) var state raft.State - for { - b, err := readBlock(br) - if err == io.EOF { - break - } - if err != nil { - return nil, err - } + for b, err = readBlock(br); err == nil; b, err = readBlock(br) { switch b.t { case entryType: - e, err := parseEntry(b.d) + e, err := loadEntry(b.d) if err != nil { return nil, err } ents = append(ents, e) case stateType: - s, err := parseState(b.d) + s, err := loadState(b.d) if err != nil { return nil, err } state = s default: - return nil, fmt.Errorf("cannot handle type %d", b.t) + return nil, fmt.Errorf("unexpected block type %d", b.t) } } - n.Ents = ents - n.State = state - return n, nil + if err != io.EOF { + return nil, err + } + return &Node{id, ents, state}, nil } -func parseInfo(d []byte) (int64, error) { +func loadInfo(d []byte) (int64, error) { if len(d) != 8 { return 0, fmt.Errorf("len = %d, want 8", len(d)) } @@ -188,49 +159,21 @@ func parseInfo(d []byte) (int64, error) { return readInt64(buf) } -func parseEntry(d []byte) (raft.Entry, error) { +func loadEntry(d []byte) (raft.Entry, error) { var e raft.Entry err := json.Unmarshal(d, &e) return e, err } -func parseState(d []byte) (raft.State, error) { +func loadState(d []byte) (raft.State, error) { var s raft.State buf := bytes.NewBuffer(d) err := binary.Read(buf, binary.LittleEndian, &s) return s, err } -type block struct { - t int64 - l int64 - d []byte -} - -func readBlock(r io.Reader) (*block, error) { - typ, err := readInt64(r) - if err != nil { - return nil, err - } - l, err := readInt64(r) - if err != nil { - if err == io.EOF { - err = io.ErrUnexpectedEOF - } - return nil, err - } - data := make([]byte, l) - n, err := r.Read(data) - if err != nil { - if err == io.EOF { - err = io.ErrUnexpectedEOF - } - return nil, err - } - if n != int(l) { - return nil, fmt.Errorf("len(data) = %d, want %d", n, l) - } - return &block{typ, l, data}, nil +func writeInt64(w io.Writer, n int64) error { + return binary.Write(w, binary.LittleEndian, n) } func readInt64(r io.Reader) (int64, error) { @@ -239,6 +182,13 @@ func readInt64(r io.Reader) (int64, error) { return n, err } +func unexpectedEOF(err error) error { + if err == io.EOF { + return io.ErrUnexpectedEOF + } + return err +} + func max(a, b int64) int64 { if a > b { return a diff --git a/wal/wal_test.go b/wal/wal_test.go index 860de171f..f13762099 100644 --- a/wal/wal_test.go +++ b/wal/wal_test.go @@ -1,8 +1,6 @@ package wal import ( - "bytes" - "io" "io/ioutil" "os" "path" @@ -12,6 +10,17 @@ import ( "github.com/coreos/etcd/raft" ) +var ( + infoData = []byte("\xef\xbe\x00\x00\x00\x00\x00\x00") + infoBlock = append([]byte("\x01\x00\x00\x00\x00\x00\x00\x00\b\x00\x00\x00\x00\x00\x00\x00"), infoData...) + + stateData = []byte("\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00") + stateBlock = append([]byte("\x03\x00\x00\x00\x00\x00\x00\x00\x18\x00\x00\x00\x00\x00\x00\x00"), stateData...) + + entryJsonData = []byte("{\"Type\":1,\"Term\":1,\"Data\":\"AQ==\"}") + entryBlock = append([]byte("\x02\x00\x00\x00\x00\x00\x00\x00\x21\x00\x00\x00\x00\x00\x00\x00"), entryJsonData...) +) + func TestNew(t *testing.T) { f, err := ioutil.TempFile(os.TempDir(), "waltest") if err != nil { @@ -37,14 +46,14 @@ func TestNew(t *testing.T) { } } -func TestWriteEntry(t *testing.T) { +func TestSaveEntry(t *testing.T) { p := path.Join(os.TempDir(), "waltest") w, err := New(p) if err != nil { t.Fatal(err) } e := &raft.Entry{1, 1, []byte{1}} - err = w.writeEntry(e) + err = w.SaveEntry(e) if err != nil { t.Fatal(err) } @@ -54,9 +63,8 @@ func TestWriteEntry(t *testing.T) { if err != nil { t.Fatal(err) } - wb := []byte("\x02\x00\x00\x00\x00\x00\x00\x00!\x00\x00\x00\x00\x00\x00\x00{\"Type\":1,\"Term\":1,\"Data\":\"AQ==\"}") - if !reflect.DeepEqual(b, wb) { - t.Errorf("ent = %q, want %q", b, wb) + if !reflect.DeepEqual(b, entryBlock) { + t.Errorf("ent = %q, want %q", b, entryBlock) } err = os.Remove(p) @@ -65,28 +73,28 @@ func TestWriteEntry(t *testing.T) { } } -func TestWriteInfo(t *testing.T) { +func TestSaveInfo(t *testing.T) { p := path.Join(os.TempDir(), "waltest") w, err := New(p) if err != nil { t.Fatal(err) } id := int64(0xBEEF) - err = w.writeInfo(id) + err = w.SaveInfo(id) if err != nil { t.Fatal(err) } // make sure we can only write info at the head of the wal file // still in buffer - err = w.writeInfo(id) + err = w.SaveInfo(id) if err == nil || err.Error() != "cannot write info at 24, expect 0" { t.Errorf("err = %v, want cannot write info at 8, expect 0", err) } // flush to disk - w.flush() - err = w.writeInfo(id) + w.Flush() + err = w.SaveInfo(id) if err == nil || err.Error() != "cannot write info at 24, expect 0" { t.Errorf("err = %v, want cannot write info at 8, expect 0", err) } @@ -96,9 +104,8 @@ func TestWriteInfo(t *testing.T) { if err != nil { t.Fatal(err) } - wb := []byte("\x01\x00\x00\x00\x00\x00\x00\x00\b\x00\x00\x00\x00\x00\x00\x00\xef\xbe\x00\x00\x00\x00\x00\x00") - if !reflect.DeepEqual(b, wb) { - t.Errorf("ent = %q, want %q", b, wb) + if !reflect.DeepEqual(b, infoBlock) { + t.Errorf("ent = %q, want %q", b, infoBlock) } err = os.Remove(p) @@ -107,14 +114,14 @@ func TestWriteInfo(t *testing.T) { } } -func TestWriteState(t *testing.T) { +func TestSaveState(t *testing.T) { p := path.Join(os.TempDir(), "waltest") w, err := New(p) if err != nil { t.Fatal(err) } st := &raft.State{1, 1, 1} - err = w.writeState(st) + err = w.SaveState(st) if err != nil { t.Fatal(err) } @@ -124,9 +131,8 @@ func TestWriteState(t *testing.T) { if err != nil { t.Fatal(err) } - wb := []byte("\x03\x00\x00\x00\x00\x00\x00\x00\x18\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00") - if !reflect.DeepEqual(b, wb) { - t.Errorf("ent = %q, want %q", b, wb) + if !reflect.DeepEqual(b, stateBlock) { + t.Errorf("ent = %q, want %q", b, stateBlock) } err = os.Remove(p) @@ -135,9 +141,8 @@ func TestWriteState(t *testing.T) { } } -func TestParseInfo(t *testing.T) { - data := []byte("\xef\xbe\x00\x00\x00\x00\x00\x00") - id, err := parseInfo(data) +func TestLoadInfo(t *testing.T) { + id, err := loadInfo(infoData) if err != nil { t.Fatal(err) } @@ -146,9 +151,8 @@ func TestParseInfo(t *testing.T) { } } -func TestParseEntry(t *testing.T) { - data := []byte("{\"Type\":1,\"Term\":1,\"Data\":\"AQ==\"}") - e, err := parseEntry(data) +func TestLoadEntry(t *testing.T) { + e, err := loadEntry(entryJsonData) if err != nil { t.Fatal(err) } @@ -158,9 +162,8 @@ func TestParseEntry(t *testing.T) { } } -func TestParseState(t *testing.T) { - data := []byte("\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00") - s, err := parseState(data) +func TestLoadState(t *testing.T) { + s, err := loadState(stateData) if err != nil { t.Fatal(err) } @@ -170,70 +173,25 @@ func TestParseState(t *testing.T) { } } -func TestReadBlock(t *testing.T) { - tests := []struct { - data []byte - wb *block - we error - }{ - { - []byte("\x01\x00\x00\x00\x00\x00\x00\x00\b\x00\x00\x00\x00\x00\x00\x00\xef\xbe\x00\x00\x00\x00\x00\x00"), - &block{1, 8, []byte("\xef\xbe\x00\x00\x00\x00\x00\x00")}, - nil, - }, - { - []byte(""), - nil, - io.EOF, - }, - { - []byte("\x01\x00\x00\x00"), - nil, - io.ErrUnexpectedEOF, - }, - { - []byte("\x01\x00\x00\x00\x00\x00\x00\x00"), - nil, - io.ErrUnexpectedEOF, - }, - { - []byte("\x01\x00\x00\x00\x00\x00\x00\x00\b\x00\x00\x00\x00\x00\x00\x00"), - nil, - io.ErrUnexpectedEOF, - }, - } - - for i, tt := range tests { - buf := bytes.NewBuffer(tt.data) - b, e := readBlock(buf) - if !reflect.DeepEqual(b, tt.wb) { - t.Errorf("#%d: block = %v, want %v", i, b, tt.wb) - } - if !reflect.DeepEqual(e, tt.we) { - t.Errorf("#%d: err = %v, want %v", i, e, tt.we) - } - } -} - -func TestReadNode(t *testing.T) { +func TestLoadNode(t *testing.T) { p := path.Join(os.TempDir(), "waltest") w, err := New(p) if err != nil { t.Fatal(err) } id := int64(0xBEEF) - if err = w.writeInfo(id); err != nil { + if err = w.SaveInfo(id); err != nil { t.Fatal(err) } ents := []raft.Entry{{1, 1, []byte{1}}, {2, 2, []byte{2}}} for _, e := range ents { - if err = w.writeEntry(&e); err != nil { + if err = w.SaveEntry(&e); err != nil { t.Fatal(err) } } sts := []raft.State{{1, 1, 1}, {2, 2, 2}} for _, s := range sts { - if err = w.writeState(&s); err != nil { + if err = w.SaveState(&s); err != nil { t.Fatal(err) } } @@ -243,7 +201,7 @@ func TestReadNode(t *testing.T) { if err != nil { t.Fatal(err) } - n, err := w.ReadNode() + n, err := w.LoadNode() if err != nil { t.Fatal(err) }