mirror of
https://github.com/etcd-io/etcd.git
synced 2024-09-27 06:25:44 +00:00
client: Add grpc authority header integration tests
This commit is contained in:
parent
6e04e8ae42
commit
58d2b12a50
69
pkg/grpc_testing/recorder.go
Normal file
69
pkg/grpc_testing/recorder.go
Normal file
@ -0,0 +1,69 @@
|
||||
// Copyright 2021 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package grpc_testing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/metadata"
|
||||
)
|
||||
|
||||
type GrpcRecorder struct {
|
||||
mux sync.RWMutex
|
||||
requests []RequestInfo
|
||||
}
|
||||
|
||||
type RequestInfo struct {
|
||||
FullMethod string
|
||||
Authority string
|
||||
}
|
||||
|
||||
func (ri *GrpcRecorder) UnaryInterceptor() grpc.UnaryServerInterceptor {
|
||||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
||||
ri.record(toRequestInfo(ctx, info))
|
||||
resp, err := handler(ctx, req)
|
||||
return resp, err
|
||||
}
|
||||
}
|
||||
|
||||
func (ri *GrpcRecorder) RecordedRequests() []RequestInfo {
|
||||
ri.mux.RLock()
|
||||
defer ri.mux.RUnlock()
|
||||
reqs := make([]RequestInfo, len(ri.requests))
|
||||
copy(reqs, ri.requests)
|
||||
return reqs
|
||||
}
|
||||
|
||||
func toRequestInfo(ctx context.Context, info *grpc.UnaryServerInfo) RequestInfo {
|
||||
req := RequestInfo{
|
||||
FullMethod: info.FullMethod,
|
||||
}
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
if ok {
|
||||
as := md.Get(":authority")
|
||||
if len(as) != 0 {
|
||||
req.Authority = as[0]
|
||||
}
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
func (ri *GrpcRecorder) record(r RequestInfo) {
|
||||
ri.mux.Lock()
|
||||
defer ri.mux.Unlock()
|
||||
ri.requests = append(ri.requests, r)
|
||||
}
|
@ -539,7 +539,7 @@ func (e *Etcd) servePeers() (err error) {
|
||||
|
||||
for _, p := range e.Peers {
|
||||
u := p.Listener.Addr().String()
|
||||
gs := v3rpc.Server(e.Server, peerTLScfg)
|
||||
gs := v3rpc.Server(e.Server, peerTLScfg, nil)
|
||||
m := cmux.New(p.Listener)
|
||||
go gs.Serve(m.Match(cmux.HTTP2()))
|
||||
srv := &http.Server{
|
||||
|
@ -110,7 +110,7 @@ func (sctx *serveCtx) serve(
|
||||
}()
|
||||
|
||||
if sctx.insecure {
|
||||
gs = v3rpc.Server(s, nil, gopts...)
|
||||
gs = v3rpc.Server(s, nil, nil, gopts...)
|
||||
v3electionpb.RegisterElectionServer(gs, servElection)
|
||||
v3lockpb.RegisterLockServer(gs, servLock)
|
||||
if sctx.serviceRegister != nil {
|
||||
@ -148,7 +148,7 @@ func (sctx *serveCtx) serve(
|
||||
if tlsErr != nil {
|
||||
return tlsErr
|
||||
}
|
||||
gs = v3rpc.Server(s, tlscfg, gopts...)
|
||||
gs = v3rpc.Server(s, tlscfg, nil, gopts...)
|
||||
v3electionpb.RegisterElectionServer(gs, servElection)
|
||||
v3lockpb.RegisterLockServer(gs, servLock)
|
||||
if sctx.serviceRegister != nil {
|
||||
|
@ -36,19 +36,21 @@ const (
|
||||
maxSendBytes = math.MaxInt32
|
||||
)
|
||||
|
||||
func Server(s *etcdserver.EtcdServer, tls *tls.Config, gopts ...grpc.ServerOption) *grpc.Server {
|
||||
func Server(s *etcdserver.EtcdServer, tls *tls.Config, interceptor grpc.UnaryServerInterceptor, gopts ...grpc.ServerOption) *grpc.Server {
|
||||
var opts []grpc.ServerOption
|
||||
opts = append(opts, grpc.CustomCodec(&codec{}))
|
||||
if tls != nil {
|
||||
bundle := credentials.NewBundle(credentials.Config{TLSConfig: tls})
|
||||
opts = append(opts, grpc.Creds(bundle.TransportCredentials()))
|
||||
}
|
||||
|
||||
chainUnaryInterceptors := []grpc.UnaryServerInterceptor{
|
||||
newLogUnaryInterceptor(s),
|
||||
newUnaryInterceptor(s),
|
||||
grpc_prometheus.UnaryServerInterceptor,
|
||||
}
|
||||
if interceptor != nil {
|
||||
chainUnaryInterceptors = append(chainUnaryInterceptors, interceptor)
|
||||
}
|
||||
|
||||
chainStreamInterceptors := []grpc.StreamServerInterceptor{
|
||||
newStreamInterceptor(s),
|
||||
|
@ -39,6 +39,7 @@ import (
|
||||
"go.etcd.io/etcd/client/pkg/v3/types"
|
||||
"go.etcd.io/etcd/client/v2"
|
||||
"go.etcd.io/etcd/client/v3"
|
||||
"go.etcd.io/etcd/pkg/v3/grpc_testing"
|
||||
"go.etcd.io/etcd/raft/v3"
|
||||
"go.etcd.io/etcd/server/v3/config"
|
||||
"go.etcd.io/etcd/server/v3/embed"
|
||||
@ -602,6 +603,8 @@ type member struct {
|
||||
|
||||
isLearner bool
|
||||
closed bool
|
||||
|
||||
grpcServerRecorder *grpc_testing.GrpcRecorder
|
||||
}
|
||||
|
||||
func (m *member) GRPCURL() string { return m.grpcURL }
|
||||
@ -733,7 +736,7 @@ func mustNewMember(t testutil.TB, mcfg memberConfig) *member {
|
||||
m.WarningApplyDuration = embed.DefaultWarningApplyDuration
|
||||
|
||||
m.V2Deprecation = config.V2_DEPR_DEFAULT
|
||||
|
||||
m.grpcServerRecorder = &grpc_testing.GrpcRecorder{}
|
||||
m.Logger = memberLogger(t, mcfg.name)
|
||||
t.Cleanup(func() {
|
||||
// if we didn't cleanup the logger, the consecutive test
|
||||
@ -945,8 +948,8 @@ func (m *member) Launch() error {
|
||||
return err
|
||||
}
|
||||
}
|
||||
m.grpcServer = v3rpc.Server(m.s, tlscfg, m.grpcServerOpts...)
|
||||
m.grpcServerPeer = v3rpc.Server(m.s, peerTLScfg)
|
||||
m.grpcServer = v3rpc.Server(m.s, tlscfg, m.grpcServerRecorder.UnaryInterceptor(), m.grpcServerOpts...)
|
||||
m.grpcServerPeer = v3rpc.Server(m.s, peerTLScfg, m.grpcServerRecorder.UnaryInterceptor())
|
||||
m.serverClient = v3client.New(m.s)
|
||||
lockpb.RegisterLockServer(m.grpcServer, v3lock.NewLockServer(m.serverClient))
|
||||
epb.RegisterElectionServer(m.grpcServer, v3election.NewElectionServer(m.serverClient))
|
||||
@ -1081,6 +1084,10 @@ func (m *member) Launch() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *member) RecordedRequests() []grpc_testing.RequestInfo {
|
||||
return m.grpcServerRecorder.RecordedRequests()
|
||||
}
|
||||
|
||||
func (m *member) WaitOK(t testutil.TB) {
|
||||
m.WaitStarted(t)
|
||||
for m.s.Leader() == 0 {
|
||||
@ -1370,8 +1377,9 @@ func (p SortableMemberSliceByPeerURLs) Swap(i, j int) { p[i], p[j] = p[j], p[i]
|
||||
type ClusterV3 struct {
|
||||
*cluster
|
||||
|
||||
mu sync.Mutex
|
||||
clients []*clientv3.Client
|
||||
mu sync.Mutex
|
||||
clients []*clientv3.Client
|
||||
clusterClient *clientv3.Client
|
||||
}
|
||||
|
||||
// NewClusterV3 returns a launched cluster with a grpc client connection
|
||||
@ -1417,6 +1425,11 @@ func (c *ClusterV3) Terminate(t testutil.TB) {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
if c.clusterClient != nil {
|
||||
if err := c.clusterClient.Close(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
c.mu.Unlock()
|
||||
c.cluster.Terminate(t)
|
||||
}
|
||||
@ -1429,6 +1442,25 @@ func (c *ClusterV3) Client(i int) *clientv3.Client {
|
||||
return c.clients[i]
|
||||
}
|
||||
|
||||
func (c *ClusterV3) ClusterClient() (client *clientv3.Client, err error) {
|
||||
if c.clusterClient == nil {
|
||||
endpoints := []string{}
|
||||
for _, m := range c.Members {
|
||||
endpoints = append(endpoints, m.grpcURL)
|
||||
}
|
||||
cfg := clientv3.Config{
|
||||
Endpoints: endpoints,
|
||||
DialTimeout: 5 * time.Second,
|
||||
DialOptions: []grpc.DialOption{grpc.WithBlock()},
|
||||
}
|
||||
c.clusterClient, err = newClientV3(cfg, cfg.Logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return c.clusterClient, nil
|
||||
}
|
||||
|
||||
// NewClientV3 creates a new grpc client connection to the member
|
||||
func (c *ClusterV3) NewClientV3(memberIndex int) (*clientv3.Client, error) {
|
||||
return NewClientV3(c.Members[memberIndex])
|
||||
|
182
tests/integration/grpc_test.go
Normal file
182
tests/integration/grpc_test.go
Normal file
@ -0,0 +1,182 @@
|
||||
// Copyright 2021 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package integration
|
||||
|
||||
import (
|
||||
"context"
|
||||
tls "crypto/tls"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
clientv3 "go.etcd.io/etcd/client/v3"
|
||||
"google.golang.org/grpc"
|
||||
)
|
||||
|
||||
func TestAuthority(t *testing.T) {
|
||||
tcs := []struct {
|
||||
name string
|
||||
useTCP bool
|
||||
useTLS bool
|
||||
// Pattern used to generate endpoints for client. Fields filled
|
||||
// %d - will be filled with member grpc port
|
||||
// %s - will be filled with member name
|
||||
clientURLPattern string
|
||||
|
||||
// Pattern used to validate authority received by server. Fields filled:
|
||||
// %s - list of endpoints concatenated with ";"
|
||||
expectAuthorityPattern string
|
||||
}{
|
||||
{
|
||||
name: "unix:path",
|
||||
clientURLPattern: "unix:localhost:%s",
|
||||
expectAuthorityPattern: "#initially=[%s]",
|
||||
},
|
||||
{
|
||||
name: "unix://absolute_path",
|
||||
clientURLPattern: "unix://localhost:%s",
|
||||
expectAuthorityPattern: "#initially=[%s]",
|
||||
},
|
||||
// "unixs" is not standard schema supported by etcd
|
||||
{
|
||||
name: "unixs:absolute_path",
|
||||
useTLS: true,
|
||||
clientURLPattern: "unixs:localhost:%s",
|
||||
expectAuthorityPattern: "#initially=[%s]",
|
||||
},
|
||||
{
|
||||
name: "unixs://absolute_path",
|
||||
useTLS: true,
|
||||
clientURLPattern: "unixs://localhost:%s",
|
||||
expectAuthorityPattern: "#initially=[%s]",
|
||||
},
|
||||
{
|
||||
name: "http://domain[:port]",
|
||||
useTCP: true,
|
||||
clientURLPattern: "http://localhost:%d",
|
||||
expectAuthorityPattern: "#initially=[%s]",
|
||||
},
|
||||
{
|
||||
name: "https://domain[:port]",
|
||||
useTLS: true,
|
||||
useTCP: true,
|
||||
clientURLPattern: "https://localhost:%d",
|
||||
expectAuthorityPattern: "#initially=[%s]",
|
||||
},
|
||||
{
|
||||
name: "http://address[:port]",
|
||||
useTCP: true,
|
||||
clientURLPattern: "http://127.0.0.1:%d",
|
||||
expectAuthorityPattern: "#initially=[%s]",
|
||||
},
|
||||
{
|
||||
name: "https://address[:port]",
|
||||
useTCP: true,
|
||||
useTLS: true,
|
||||
clientURLPattern: "https://127.0.0.1:%d",
|
||||
expectAuthorityPattern: "#initially=[%s]",
|
||||
},
|
||||
}
|
||||
for _, tc := range tcs {
|
||||
for _, clusterSize := range []int{1, 3} {
|
||||
t.Run(fmt.Sprintf("Size: %d, Scenario: %q", clusterSize, tc.name), func(t *testing.T) {
|
||||
BeforeTest(t)
|
||||
cfg := ClusterConfig{
|
||||
Size: clusterSize,
|
||||
UseTCP: tc.useTCP,
|
||||
UseIP: tc.useTCP,
|
||||
}
|
||||
cfg, tlsConfig := setupTLS(t, tc.useTLS, cfg)
|
||||
clus := NewClusterV3(t, &cfg)
|
||||
defer clus.Terminate(t)
|
||||
endpoints := templateEndpoints(t, tc.clientURLPattern, clus)
|
||||
|
||||
kv := setupClient(t, tc.clientURLPattern, clus, tlsConfig)
|
||||
defer kv.Close()
|
||||
|
||||
_, err := kv.Put(context.TODO(), "foo", "bar")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
assertAuthority(t, fmt.Sprintf(tc.expectAuthorityPattern, strings.Join(endpoints, ";")), clus)
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func setupTLS(t *testing.T, useTLS bool, cfg ClusterConfig) (ClusterConfig, *tls.Config) {
|
||||
t.Helper()
|
||||
if useTLS {
|
||||
cfg.ClientTLS = &testTLSInfo
|
||||
tlsConfig, err := testTLSInfo.ClientConfig()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return cfg, tlsConfig
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func setupClient(t *testing.T, endpointPattern string, clus *ClusterV3, tlsConfig *tls.Config) *clientv3.Client {
|
||||
t.Helper()
|
||||
endpoints := templateEndpoints(t, endpointPattern, clus)
|
||||
kv, err := clientv3.New(clientv3.Config{
|
||||
Endpoints: endpoints,
|
||||
DialTimeout: 5 * time.Second,
|
||||
DialOptions: []grpc.DialOption{grpc.WithBlock()},
|
||||
TLS: tlsConfig,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return kv
|
||||
}
|
||||
|
||||
func templateEndpoints(t *testing.T, pattern string, clus *ClusterV3) []string {
|
||||
t.Helper()
|
||||
endpoints := []string{}
|
||||
for _, m := range clus.Members {
|
||||
ent := pattern
|
||||
if strings.Contains(ent, "%d") {
|
||||
ent = fmt.Sprintf(ent, GrpcPortNumber(m.UniqNumber, m.MemberNumber))
|
||||
}
|
||||
if strings.Contains(ent, "%s") {
|
||||
ent = fmt.Sprintf(ent, m.Name)
|
||||
}
|
||||
if strings.Contains(ent, "%") {
|
||||
t.Fatalf("Failed to template pattern, %% symbol left %q", ent)
|
||||
}
|
||||
endpoints = append(endpoints, ent)
|
||||
}
|
||||
return endpoints
|
||||
}
|
||||
|
||||
func assertAuthority(t *testing.T, expectedAuthority string, clus *ClusterV3) {
|
||||
t.Helper()
|
||||
requestsFound := 0
|
||||
for _, m := range clus.Members {
|
||||
for _, r := range m.RecordedRequests() {
|
||||
requestsFound++
|
||||
if r.Authority != expectedAuthority {
|
||||
t.Errorf("Got unexpected authority header, member: %q, request: %q, got authority: %q, expected %q", m.Name, r.FullMethod, r.Authority, expectedAuthority)
|
||||
}
|
||||
}
|
||||
}
|
||||
if requestsFound == 0 {
|
||||
t.Errorf("Expected at least one request")
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user