diff --git a/third_party/code.google.com/p/go.net/html/token.go b/third_party/code.google.com/p/go.net/html/token.go index a1f43896a..3a1ee7ccb 100644 --- a/third_party/code.google.com/p/go.net/html/token.go +++ b/third_party/code.google.com/p/go.net/html/token.go @@ -133,6 +133,11 @@ type Tokenizer struct { // subsequent Next calls would return an ErrorToken. // err is never reset. Once it becomes non-nil, it stays non-nil. err error + // readErr is the error returned by the io.Reader r. It is separate from + // err because it is valid for an io.Reader to return (n int, err1 error) + // such that n > 0 && err1 != nil, and callers should always process the + // n > 0 bytes before considering the error err1. + readErr error // buf[raw.start:raw.end] holds the raw bytes of the current token. // buf[raw.end:] is buffered input that will yield future tokens. raw span @@ -222,7 +227,12 @@ func (z *Tokenizer) Err() error { // Pre-condition: z.err == nil. func (z *Tokenizer) readByte() byte { if z.raw.end >= len(z.buf) { - // Our buffer is exhausted and we have to read from z.r. + // Our buffer is exhausted and we have to read from z.r. Check if the + // previous read resulted in an error. + if z.readErr != nil { + z.err = z.readErr + return 0 + } // We copy z.buf[z.raw.start:z.raw.end] to the beginning of z.buf. If the length // z.raw.end - z.raw.start is more than half the capacity of z.buf, then we // allocate a new buffer before the copy. @@ -253,9 +263,10 @@ func (z *Tokenizer) readByte() byte { z.raw.start, z.raw.end, z.buf = 0, d, buf1[:d] // Now that we have copied the live bytes to the start of the buffer, // we read from z.r into the remainder. - n, err := z.r.Read(buf1[d:cap(buf1)]) - if err != nil { - z.err = err + var n int + n, z.readErr = readAtLeastOneByte(z.r, buf1[d:cap(buf1)]) + if n == 0 { + z.err = z.readErr return 0 } z.buf = buf1[:d+n] @@ -265,6 +276,19 @@ func (z *Tokenizer) readByte() byte { return x } +// readAtLeastOneByte wraps an io.Reader so that reading cannot return (0, nil). +// It returns io.ErrNoProgress if the underlying r.Read method returns (0, nil) +// too many times in succession. +func readAtLeastOneByte(r io.Reader, b []byte) (int, error) { + for i := 0; i < 100; i++ { + n, err := r.Read(b) + if n != 0 || err != nil { + return n, err + } + } + return 0, io.ErrNoProgress +} + // skipWhiteSpace skips past any white space. func (z *Tokenizer) skipWhiteSpace() { if z.err != nil { diff --git a/third_party/code.google.com/p/go.net/html/token_test.go b/third_party/code.google.com/p/go.net/html/token_test.go index fd49ea61b..ca408da40 100644 --- a/third_party/code.google.com/p/go.net/html/token_test.go +++ b/third_party/code.google.com/p/go.net/html/token_test.go @@ -8,6 +8,7 @@ import ( "bytes" "io" "io/ioutil" + "reflect" "runtime" "strings" "testing" @@ -531,6 +532,85 @@ func TestConvertNewlines(t *testing.T) { } } +func TestReaderEdgeCases(t *testing.T) { + const s = "
An io.Reader can return (0, nil) or (n, io.EOF).
" + testCases := []io.Reader{ + &zeroOneByteReader{s: s}, + &eofStringsReader{s: s}, + &stuckReader{}, + } + for i, tc := range testCases { + got := []TokenType{} + z := NewTokenizer(tc) + for { + tt := z.Next() + if tt == ErrorToken { + break + } + got = append(got, tt) + } + if err := z.Err(); err != nil && err != io.EOF { + if err != io.ErrNoProgress { + t.Errorf("i=%d: %v", i, err) + } + continue + } + want := []TokenType{ + StartTagToken, + TextToken, + EndTagToken, + } + if !reflect.DeepEqual(got, want) { + t.Errorf("i=%d: got %v, want %v", i, got, want) + continue + } + } +} + +// zeroOneByteReader is like a strings.Reader that alternates between +// returning 0 bytes and 1 byte at a time. +type zeroOneByteReader struct { + s string + n int +} + +func (r *zeroOneByteReader) Read(p []byte) (int, error) { + if len(p) == 0 { + return 0, nil + } + if len(r.s) == 0 { + return 0, io.EOF + } + r.n++ + if r.n%2 != 0 { + return 0, nil + } + p[0], r.s = r.s[0], r.s[1:] + return 1, nil +} + +// eofStringsReader is like a strings.Reader but can return an (n, err) where +// n > 0 && err != nil. +type eofStringsReader struct { + s string +} + +func (r *eofStringsReader) Read(p []byte) (int, error) { + n := copy(p, r.s) + r.s = r.s[n:] + if r.s != "" { + return n, nil + } + return n, io.EOF +} + +// stuckReader is an io.Reader that always returns no data and no error. +type stuckReader struct{} + +func (*stuckReader) Read(p []byte) (int, error) { + return 0, nil +} + const ( rawLevel = iota lowLevel