server: Refactor wal version to use visitor pattern

This commit is contained in:
Marek Siarkowicz 2021-11-19 17:04:02 +01:00
parent 6d808e5d7d
commit d865bb96f1
3 changed files with 143 additions and 122 deletions

View File

@ -53,32 +53,63 @@ func (w *walVersion) MinimalEtcdVersion() *semver.Version {
func MinimalEtcdVersion(ents []raftpb.Entry) *semver.Version { func MinimalEtcdVersion(ents []raftpb.Entry) *semver.Version {
var maxVer *semver.Version var maxVer *semver.Version
for _, ent := range ents { for _, ent := range ents {
maxVer = maxVersion(maxVer, etcdVersionFromEntry(ent)) err := visitEntry(ent, func(path protoreflect.FullName, ver *semver.Version) error {
maxVer = maxVersion(maxVer, ver)
return nil
})
if err != nil {
panic(err)
}
} }
return maxVer return maxVer
} }
func etcdVersionFromEntry(ent raftpb.Entry) *semver.Version { type Visitor func(path protoreflect.FullName, ver *semver.Version) error
msgVer := etcdVersionFromMessage(proto.MessageReflect(&ent))
dataVer := etcdVersionFromData(ent.Type, ent.Data) func VisitFileDescriptor(file protoreflect.FileDescriptor, visitor Visitor) error {
return maxVersion(msgVer, dataVer) msgs := file.Messages()
for i := 0; i < msgs.Len(); i++ {
err := visitMessageDescriptor(msgs.Get(i), visitor)
if err != nil {
return err
}
}
enums := file.Enums()
for i := 0; i < enums.Len(); i++ {
err := visitEnumDescriptor(enums.Get(i), visitor)
if err != nil {
return err
}
}
return nil
} }
func etcdVersionFromData(entryType raftpb.EntryType, data []byte) *semver.Version { func visitEntry(ent raftpb.Entry, visitor Visitor) error {
err := visitMessage(proto.MessageReflect(&ent), visitor)
if err != nil {
return err
}
return visitEntryData(ent.Type, ent.Data, visitor)
}
func visitEntryData(entryType raftpb.EntryType, data []byte, visitor Visitor) error {
var msg protoreflect.Message var msg protoreflect.Message
var ver *semver.Version
switch entryType { switch entryType {
case raftpb.EntryNormal: case raftpb.EntryNormal:
var raftReq etcdserverpb.InternalRaftRequest var raftReq etcdserverpb.InternalRaftRequest
err := pbutil.Unmarshaler(&raftReq).Unmarshal(data) err := pbutil.Unmarshaler(&raftReq).Unmarshal(data)
if err != nil { if err != nil {
return nil return err
} }
msg = proto.MessageReflect(&raftReq) msg = proto.MessageReflect(&raftReq)
if raftReq.ClusterVersionSet != nil { if raftReq.ClusterVersionSet != nil {
ver, err = semver.NewVersion(raftReq.ClusterVersionSet.Ver) ver, err := semver.NewVersion(raftReq.ClusterVersionSet.Ver)
if err != nil { if err != nil {
panic(err) return err
}
err = visitor(msg.Descriptor().FullName(), ver)
if err != nil {
return err
} }
} }
case raftpb.EntryConfChange: case raftpb.EntryConfChange:
@ -98,46 +129,106 @@ func etcdVersionFromData(entryType raftpb.EntryType, data []byte) *semver.Versio
default: default:
panic("unhandled") panic("unhandled")
} }
return maxVersion(etcdVersionFromMessage(msg), ver) return visitMessage(msg, visitor)
} }
func etcdVersionFromMessage(m protoreflect.Message) *semver.Version { func visitMessageDescriptor(md protoreflect.MessageDescriptor, visitor Visitor) error {
var maxVer *semver.Version err := visitDescriptor(md, visitor)
md := m.Descriptor() if err != nil {
opts := md.Options().(*descriptorpb.MessageOptions) return err
if opts != nil { }
ver, _ := EtcdVersionFromOptionsString(opts.String()) fields := md.Fields()
maxVer = maxVersion(maxVer, ver) for i := 0; i < fields.Len(); i++ {
fd := fields.Get(i)
err = visitDescriptor(fd, visitor)
if err != nil {
return err
}
} }
enums := md.Enums()
for i := 0; i < enums.Len(); i++ {
err := visitEnumDescriptor(enums.Get(i), visitor)
if err != nil {
return err
}
}
return err
}
func visitMessage(m protoreflect.Message, visitor Visitor) error {
md := m.Descriptor()
err := visitDescriptor(md, visitor)
if err != nil {
return err
}
m.Range(func(field protoreflect.FieldDescriptor, value protoreflect.Value) bool { m.Range(func(field protoreflect.FieldDescriptor, value protoreflect.Value) bool {
fd := md.Fields().Get(field.Index()) fd := md.Fields().Get(field.Index())
maxVer = maxVersion(maxVer, etcdVersionFromField(fd)) err = visitDescriptor(fd, visitor)
if err != nil {
return false
}
switch m := value.Interface().(type) { switch m := value.Interface().(type) {
case protoreflect.Message: case protoreflect.Message:
maxVer = maxVersion(maxVer, etcdVersionFromMessage(m)) err = visitMessage(m, visitor)
case protoreflect.EnumNumber: case protoreflect.EnumNumber:
maxVer = maxVersion(maxVer, etcdVersionFromEnum(field.Enum(), m)) err = visitEnumNumber(fd.Enum(), m, visitor)
}
if err != nil {
return false
} }
return true return true
}) })
return maxVer return err
} }
func etcdVersionFromEnum(enum protoreflect.EnumDescriptor, value protoreflect.EnumNumber) *semver.Version { func visitEnumDescriptor(enum protoreflect.EnumDescriptor, visitor Visitor) error {
var maxVer *semver.Version err := visitDescriptor(enum, visitor)
enumOpts := enum.Options().(*descriptorpb.EnumOptions) if err != nil {
if enumOpts != nil { return err
ver, _ := EtcdVersionFromOptionsString(enumOpts.String())
maxVer = maxVersion(maxVer, ver)
} }
valueDesc := enum.Values().Get(int(value)) fields := enum.Values()
valueOpts := valueDesc.Options().(*descriptorpb.EnumValueOptions) for i := 0; i < fields.Len(); i++ {
fd := fields.Get(i)
err = visitDescriptor(fd, visitor)
if err != nil {
return err
}
}
return err
}
func visitEnumNumber(enum protoreflect.EnumDescriptor, number protoreflect.EnumNumber, visitor Visitor) error {
err := visitDescriptor(enum, visitor)
if err != nil {
return err
}
return visitEnumValue(enum.Values().Get(int(number)), visitor)
}
func visitEnumValue(enum protoreflect.EnumValueDescriptor, visitor Visitor) error {
valueOpts := enum.Options().(*descriptorpb.EnumValueOptions)
if valueOpts != nil { if valueOpts != nil {
ver, _ := EtcdVersionFromOptionsString(valueOpts.String()) ver, _ := etcdVersionFromOptionsString(valueOpts.String())
maxVer = maxVersion(maxVer, ver) err := visitor(enum.FullName(), ver)
if err != nil {
return err
}
} }
return maxVer return nil
}
func visitDescriptor(md protoreflect.Descriptor, visitor Visitor) error {
opts, ok := md.Options().(fmt.Stringer)
if !ok {
return nil
}
ver, err := etcdVersionFromOptionsString(opts.String())
if err != nil {
return fmt.Errorf("%s: %s", md.FullName(), err)
}
return visitor(md.FullName(), ver)
} }
func maxVersion(a *semver.Version, b *semver.Version) *semver.Version { func maxVersion(a *semver.Version, b *semver.Version) *semver.Version {
@ -147,16 +238,7 @@ func maxVersion(a *semver.Version, b *semver.Version) *semver.Version {
return b return b
} }
func etcdVersionFromField(fd protoreflect.FieldDescriptor) *semver.Version { func etcdVersionFromOptionsString(opts string) (*semver.Version, error) {
opts := fd.Options().(*descriptorpb.FieldOptions)
if opts == nil {
return nil
}
ver, _ := EtcdVersionFromOptionsString(opts.String())
return ver
}
func EtcdVersionFromOptionsString(opts string) (*semver.Version, error) {
// TODO: Use proto.GetExtention when gogo/protobuf is usable with protoreflect // TODO: Use proto.GetExtention when gogo/protobuf is usable with protoreflect
msgs := []string{"[versionpb.etcd_version_msg]:", "[versionpb.etcd_version_field]:", "[versionpb.etcd_version_enum]:", "[versionpb.etcd_version_enum_value]:"} msgs := []string{"[versionpb.etcd_version_msg]:", "[versionpb.etcd_version_field]:", "[versionpb.etcd_version_enum]:", "[versionpb.etcd_version_enum_value]:"}
var end, index int var end, index int

View File

@ -25,6 +25,7 @@ import (
"go.etcd.io/etcd/api/v3/membershippb" "go.etcd.io/etcd/api/v3/membershippb"
"go.etcd.io/etcd/pkg/v3/pbutil" "go.etcd.io/etcd/pkg/v3/pbutil"
"go.etcd.io/etcd/raft/v3/raftpb" "go.etcd.io/etcd/raft/v3/raftpb"
"google.golang.org/protobuf/reflect/protoreflect"
) )
var ( var (
@ -97,8 +98,13 @@ func TestEtcdVersionFromEntry(t *testing.T) {
} }
for _, tc := range tcs { for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
ver := etcdVersionFromEntry(tc.input) var maxVer *semver.Version
assert.Equal(t, tc.expect, ver) err := visitEntry(tc.input, func(path protoreflect.FullName, ver *semver.Version) error {
maxVer = maxVersion(maxVer, ver)
return nil
})
assert.NoError(t, err)
assert.Equal(t, tc.expect, maxVer)
}) })
} }
} }
@ -162,8 +168,13 @@ func TestEtcdVersionFromMessage(t *testing.T) {
} }
for _, tc := range tcs { for _, tc := range tcs {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
ver := etcdVersionFromMessage(proto.MessageReflect(tc.input)) var maxVer *semver.Version
assert.Equal(t, tc.expect, ver) err := visitMessage(proto.MessageReflect(tc.input), func(path protoreflect.FullName, ver *semver.Version) error {
maxVer = maxVersion(maxVer, ver)
return nil
})
assert.NoError(t, err)
assert.Equal(t, tc.expect, maxVer)
}) })
} }
} }
@ -237,7 +248,7 @@ func TestEtcdVersionFromFieldOptionsString(t *testing.T) {
} }
for _, tc := range tcs { for _, tc := range tcs {
t.Run(tc.input, func(t *testing.T) { t.Run(tc.input, func(t *testing.T) {
ver, err := EtcdVersionFromOptionsString(tc.input) ver, err := etcdVersionFromOptionsString(tc.input)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, ver, tc.expect) assert.Equal(t, ver, tc.expect)
}) })

View File

@ -101,7 +101,7 @@ func allEtcdVersionAnnotations() (annotations []etcdVersionAnnotation, err error
} }
func fileEtcdVersionAnnotations(file protoreflect.FileDescriptor) (annotations []etcdVersionAnnotation, err error) { func fileEtcdVersionAnnotations(file protoreflect.FileDescriptor) (annotations []etcdVersionAnnotation, err error) {
err = visitFileDescriptor(file, func(path string, ver *semver.Version) error { err = wal.VisitFileDescriptor(file, func(path protoreflect.FullName, ver *semver.Version) error {
a := etcdVersionAnnotation{fullName: path, version: ver} a := etcdVersionAnnotation{fullName: path, version: ver}
annotations = append(annotations, a) annotations = append(annotations, a)
return nil return nil
@ -109,80 +109,8 @@ func fileEtcdVersionAnnotations(file protoreflect.FileDescriptor) (annotations [
return annotations, err return annotations, err
} }
type Visitor func(path string, ver *semver.Version) error
func visitFileDescriptor(file protoreflect.FileDescriptor, visitor Visitor) error {
msgs := file.Messages()
for i := 0; i < msgs.Len(); i++ {
err := visitMessageDescriptor(msgs.Get(i), visitor)
if err != nil {
return err
}
}
enums := file.Enums()
for i := 0; i < enums.Len(); i++ {
err := visitEnumDescriptor(enums.Get(i), visitor)
if err != nil {
return err
}
}
return nil
}
func visitMessageDescriptor(md protoreflect.MessageDescriptor, visitor Visitor) error {
err := VisitDescriptor(md, visitor)
if err != nil {
return err
}
fields := md.Fields()
for i := 0; i < fields.Len(); i++ {
fd := fields.Get(i)
err = VisitDescriptor(fd, visitor)
if err != nil {
return err
}
}
enums := md.Enums()
for i := 0; i < enums.Len(); i++ {
err := visitEnumDescriptor(enums.Get(i), visitor)
if err != nil {
return err
}
}
return err
}
func visitEnumDescriptor(enum protoreflect.EnumDescriptor, visitor Visitor) error {
err := VisitDescriptor(enum, visitor)
if err != nil {
return err
}
fields := enum.Values()
for i := 0; i < fields.Len(); i++ {
fd := fields.Get(i)
err = VisitDescriptor(fd, visitor)
if err != nil {
return err
}
}
return err
}
func VisitDescriptor(md protoreflect.Descriptor, visitor Visitor) error {
s, ok := md.Options().(fmt.Stringer)
if !ok {
return nil
}
ver, err := wal.EtcdVersionFromOptionsString(s.String())
if err != nil {
return fmt.Errorf("%s: %s", md.FullName(), err)
}
return visitor(string(md.FullName()), ver)
}
type etcdVersionAnnotation struct { type etcdVersionAnnotation struct {
fullName string fullName protoreflect.FullName
version *semver.Version version *semver.Version
} }