// Copyright 2023 The etcd Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package testutils import ( "context" "fmt" "strings" "sync" "time" "go.uber.org/zap" "go.uber.org/zap/zapcore" zapobserver "go.uber.org/zap/zaptest/observer" ) type LogObserver struct { ob *zapobserver.ObservedLogs enc zapcore.Encoder mu sync.Mutex // entries stores all the logged entries after syncLogs. entries []zapobserver.LoggedEntry } func NewLogObserver(level zapcore.LevelEnabler) (zapcore.Core, *LogObserver) { // align with zaptest enc := zapcore.NewConsoleEncoder(zap.NewDevelopmentEncoderConfig()) co, ob := zapobserver.New(level) return co, &LogObserver{ ob: ob, enc: enc, } } // ExpectAtLeastNTimes returns the first N lines containing the given string. func (logOb *LogObserver) ExpectAtLeastNTimes(ctx context.Context, s string, count int) ([]string, error) { return logOb.ExpectFunc(ctx, func(log string) bool { return strings.Contains(log, s) }, count) } // ExpectFunc returns the first N line satisfying the function f. func (logOb *LogObserver) ExpectFunc(ctx context.Context, filter func(string) bool, count int) ([]string, error) { i := 0 res := make([]string, 0, count) for { select { case <-ctx.Done(): return nil, ctx.Err() default: } entries := logOb.syncLogs() // The order of entries won't be changed because of append-only. // It's safe to skip scanned entries by reusing `i`. for ; i < len(entries); i++ { buf, err := logOb.enc.EncodeEntry(entries[i].Entry, entries[i].Context) if err != nil { return nil, fmt.Errorf("failed to encode entry: %w", err) } logInStr := buf.String() if filter(logInStr) { res = append(res, logInStr) } if len(res) >= count { return res, nil } } time.Sleep(10 * time.Millisecond) } } // ExpectExactNTimes returns all lines that satisfy the filter if there are // exactly `count` of them when the duration ends. Otherwise, it returns an error. // Make sure ctx has a timeout longer than duration for ExpectExactNTimes to work properly. func (logOb *LogObserver) ExpectExactNTimes(ctx context.Context, s string, count int, duration time.Duration) ([]string, error) { return logOb.ExpectExactNTimesFunc(ctx, func(log string) bool { return strings.Contains(log, s) }, count, duration) } // ExpectExactNTimesFunc returns all lines that satisfy the filter if there are // exactly `count` of them when the duration ends. Otherwise, it returns an error. // Make sure ctx has a timeout longer than duration for ExpectExactNTimesFunc to work properly. func (logOb *LogObserver) ExpectExactNTimesFunc(ctx context.Context, filter func(string) bool, count int, duration time.Duration) ([]string, error) { timer := time.NewTimer(duration) defer timer.Stop() i := 0 res := make([]string, 0, count) for { select { case <-ctx.Done(): return nil, ctx.Err() case <-timer.C: if len(res) == count { return res, nil } else { return nil, fmt.Errorf("failed to expect, expected: %d, got: %d", count, len(res)) } default: } entries := logOb.syncLogs() // The order of entries won't be changed because of append-only. // It's safe to skip scanned entries by reusing `i`. for ; i < len(entries); i++ { buf, err := logOb.enc.EncodeEntry(entries[i].Entry, entries[i].Context) if err != nil { return nil, fmt.Errorf("failed to encode entry: %w", err) } logInStr := buf.String() if filter(logInStr) { res = append(res, logInStr) } if len(res) > count { return nil, fmt.Errorf("failed to expect; too many occurrences; expected: %d, got:%d", count, len(res)) } } time.Sleep(10 * time.Millisecond) } } // syncLogs is to take all the existing logged entries from zapobserver and // truncate zapobserver's entries slice. func (logOb *LogObserver) syncLogs() []zapobserver.LoggedEntry { logOb.mu.Lock() defer logOb.mu.Unlock() logOb.entries = append(logOb.entries, logOb.ob.TakeAll()...) return logOb.entries }