Merge pull request #4831 from xiang90/tlx

*: http and https on the same port
This commit is contained in:
Xiang Li 2016-03-23 15:59:58 -07:00
commit 333ac5789a
14 changed files with 278 additions and 133 deletions

View File

@ -35,11 +35,18 @@ const (
caPath = "../integration/fixtures/ca.crt"
)
type clientConnType int
const (
clientNonTLS clientConnType = iota
clientTLS
clientTLSAndNonTLS
)
var (
configNoTLS = etcdProcessClusterConfig{
clusterSize: 3,
proxySize: 0,
isClientTLS: false,
isPeerTLS: false,
initialToken: "new",
}
@ -52,42 +59,46 @@ var (
configTLS = etcdProcessClusterConfig{
clusterSize: 3,
proxySize: 0,
isClientTLS: true,
clientTLS: clientTLS,
isPeerTLS: true,
initialToken: "new",
}
configClientTLS = etcdProcessClusterConfig{
clusterSize: 3,
proxySize: 0,
isClientTLS: true,
clientTLS: clientTLS,
isPeerTLS: false,
initialToken: "new",
}
configClientBoth = etcdProcessClusterConfig{
clusterSize: 1,
proxySize: 0,
clientTLS: clientTLSAndNonTLS,
isPeerTLS: false,
initialToken: "new",
}
configPeerTLS = etcdProcessClusterConfig{
clusterSize: 3,
proxySize: 0,
isClientTLS: false,
isPeerTLS: true,
initialToken: "new",
}
configWithProxy = etcdProcessClusterConfig{
clusterSize: 3,
proxySize: 1,
isClientTLS: false,
isPeerTLS: false,
initialToken: "new",
}
configWithProxyTLS = etcdProcessClusterConfig{
clusterSize: 3,
proxySize: 1,
isClientTLS: true,
clientTLS: clientTLS,
isPeerTLS: true,
initialToken: "new",
}
configWithProxyPeerTLS = etcdProcessClusterConfig{
clusterSize: 3,
proxySize: 1,
isClientTLS: false,
isPeerTLS: true,
initialToken: "new",
}
@ -107,6 +118,7 @@ func TestBasicOpsClientTLS(t *testing.T) { testBasicOpsPutGet(t, &configClien
func TestBasicOpsProxyNoTLS(t *testing.T) { testBasicOpsPutGet(t, &configWithProxy) }
func TestBasicOpsProxyTLS(t *testing.T) { testBasicOpsPutGet(t, &configWithProxyTLS) }
func TestBasicOpsProxyPeerTLS(t *testing.T) { testBasicOpsPutGet(t, &configWithProxyPeerTLS) }
func TestBasicOpsClientBoth(t *testing.T) { testBasicOpsPutGet(t, &configClientBoth) }
func testBasicOpsPutGet(t *testing.T, cfg *etcdProcessClusterConfig) {
defer testutil.AfterTest(t)
@ -126,13 +138,27 @@ func testBasicOpsPutGet(t *testing.T, cfg *etcdProcessClusterConfig) {
}()
expectPut := `{"action":"set","node":{"key":"/testKey","value":"foo","`
if err := cURLPut(epc, "testKey", "foo", expectPut); err != nil {
t.Fatalf("failed put with curl (%v)", err)
}
expectGet := `{"action":"get","node":{"key":"/testKey","value":"foo","`
if err := cURLGet(epc, "testKey", expectGet); err != nil {
t.Fatalf("failed get with curl (%v)", err)
if cfg.clientTLS == clientTLSAndNonTLS {
if err := cURLPut(epc, "testKey", "foo", expectPut); err != nil {
t.Fatalf("failed put with curl (%v)", err)
}
if err := cURLGet(epc, "testKey", expectGet); err != nil {
t.Fatalf("failed get with curl (%v)", err)
}
if err := cURLGetUseTLS(epc, "testKey", expectGet); err != nil {
t.Fatalf("failed get with curl (%v)", err)
}
} else {
if err := cURLPut(epc, "testKey", "foo", expectPut); err != nil {
t.Fatalf("failed put with curl (%v)", err)
}
if err := cURLGet(epc, "testKey", expectGet); err != nil {
t.Fatalf("failed get with curl (%v)", err)
}
}
}
@ -140,11 +166,24 @@ func testBasicOpsPutGet(t *testing.T, cfg *etcdProcessClusterConfig) {
// addressed to a random URL in the given cluster.
func cURLPrefixArgs(clus *etcdProcessCluster, key string) []string {
cmdArgs := []string{"curl"}
if clus.cfg.isClientTLS {
acurl := clus.procs[rand.Intn(clus.cfg.clusterSize)].cfg.acurl
if clus.cfg.clientTLS == clientTLS {
cmdArgs = append(cmdArgs, "--cacert", caPath, "--cert", certPath, "--key", privateKeyPath)
}
acurl := clus.procs[rand.Intn(clus.cfg.clusterSize)].cfg.acurl
keyURL := acurl.String() + "/v2/keys/testKey"
keyURL := acurl + "/v2/keys/testKey"
cmdArgs = append(cmdArgs, "-L", keyURL)
return cmdArgs
}
func cURLPrefixArgsUseTLS(clus *etcdProcessCluster, key string) []string {
cmdArgs := []string{"curl"}
if clus.cfg.clientTLS != clientTLSAndNonTLS {
panic("should not use cURLPrefixArgsUseTLS when serving only TLS or non-TLS")
}
cmdArgs = append(cmdArgs, "--cacert", caPath, "--cert", certPath, "--key", privateKeyPath)
acurl := clus.procs[rand.Intn(clus.cfg.clusterSize)].cfg.acurltls
keyURL := acurl + "/v2/keys/testKey"
cmdArgs = append(cmdArgs, "-L", keyURL)
return cmdArgs
}
@ -158,6 +197,10 @@ func cURLGet(clus *etcdProcessCluster, key, expected string) error {
return spawnWithExpectedString(cURLPrefixArgs(clus, key), expected)
}
func cURLGetUseTLS(clus *etcdProcessCluster, key, expected string) error {
return spawnWithExpectedString(cURLPrefixArgsUseTLS(clus, key), expected)
}
type etcdProcessCluster struct {
cfg *etcdProcessClusterConfig
procs []*etcdProcess
@ -172,14 +215,17 @@ type etcdProcess struct {
type etcdProcessConfig struct {
args []string
dataDirPath string
acurl url.URL
isProxy bool
acurl string
// additional url for tls connection when the etcd process
// serves both http and https
acurltls string
isProxy bool
}
type etcdProcessClusterConfig struct {
clusterSize int
proxySize int
isClientTLS bool
clientTLS clientConnType
isPeerTLS bool
isPeerAutoTLS bool
initialToken string
@ -254,7 +300,7 @@ func newEtcdProcess(cfg *etcdProcessConfig) (*etcdProcess, error) {
func (cfg *etcdProcessClusterConfig) etcdProcessConfigs() []*etcdProcessConfig {
clientScheme := "http"
if cfg.isClientTLS {
if cfg.clientTLS == clientTLS {
clientScheme = "https"
}
peerScheme := "http"
@ -265,8 +311,20 @@ func (cfg *etcdProcessClusterConfig) etcdProcessConfigs() []*etcdProcessConfig {
etcdCfgs := make([]*etcdProcessConfig, cfg.clusterSize+cfg.proxySize)
initialCluster := make([]string, cfg.clusterSize)
for i := 0; i < cfg.clusterSize; i++ {
var curls []string
var curl, curltls string
port := etcdProcessBasePort + 2*i
curl := url.URL{Scheme: clientScheme, Host: fmt.Sprintf("localhost:%d", port)}
switch cfg.clientTLS {
case clientNonTLS, clientTLS:
curl = (&url.URL{Scheme: clientScheme, Host: fmt.Sprintf("localhost:%d", port)}).String()
curls = []string{curl}
case clientTLSAndNonTLS:
curl = (&url.URL{Scheme: "http", Host: fmt.Sprintf("localhost:%d", port)}).String()
curltls = (&url.URL{Scheme: "https", Host: fmt.Sprintf("localhost:%d", port)}).String()
curls = []string{curl, curltls}
}
purl := url.URL{Scheme: peerScheme, Host: fmt.Sprintf("localhost:%d", port+1)}
name := fmt.Sprintf("testname%d", i)
dataDirPath, derr := ioutil.TempDir("", name+".etcd")
@ -277,8 +335,8 @@ func (cfg *etcdProcessClusterConfig) etcdProcessConfigs() []*etcdProcessConfig {
args := []string{
"--name", name,
"--listen-client-urls", curl.String(),
"--advertise-client-urls", curl.String(),
"--listen-client-urls", strings.Join(curls, ","),
"--advertise-client-urls", strings.Join(curls, ","),
"--listen-peer-urls", purl.String(),
"--initial-advertise-peer-urls", purl.String(),
"--initial-cluster-token", cfg.initialToken,
@ -294,6 +352,7 @@ func (cfg *etcdProcessClusterConfig) etcdProcessConfigs() []*etcdProcessConfig {
args: args,
dataDirPath: dataDirPath,
acurl: curl,
acurltls: curltls,
}
}
for i := 0; i < cfg.proxySize; i++ {
@ -314,7 +373,7 @@ func (cfg *etcdProcessClusterConfig) etcdProcessConfigs() []*etcdProcessConfig {
etcdCfgs[cfg.clusterSize+i] = &etcdProcessConfig{
args: args,
dataDirPath: dataDirPath,
acurl: curl,
acurl: curl.String(),
isProxy: true,
}
}
@ -328,7 +387,7 @@ func (cfg *etcdProcessClusterConfig) etcdProcessConfigs() []*etcdProcessConfig {
}
func (cfg *etcdProcessClusterConfig) tlsArgs() (args []string) {
if cfg.isClientTLS {
if cfg.clientTLS != clientNonTLS {
tlsClientArgs := []string{
"--cert-file", certPath,
"--key-file", privateKeyPath,

View File

@ -226,16 +226,16 @@ func TestCtlV2RoleList(t *testing.T) {
func etcdctlPrefixArgs(clus *etcdProcessCluster) []string {
endpoints := ""
if proxies := clus.proxies(); len(proxies) != 0 {
endpoints = proxies[0].cfg.acurl.String()
endpoints = proxies[0].cfg.acurl
} else if backends := clus.backends(); len(backends) != 0 {
es := []string{}
for _, b := range backends {
es = append(es, b.cfg.acurl.String())
es = append(es, b.cfg.acurl)
}
endpoints = strings.Join(es, ",")
}
cmdArgs := []string{"../bin/etcdctl", "--endpoints", endpoints}
if clus.cfg.isClientTLS {
if clus.cfg.clientTLS == clientTLS {
cmdArgs = append(cmdArgs, "--ca-file", caPath, "--cert-file", certPath, "--key-file", privateKeyPath)
}
return cmdArgs

View File

@ -104,12 +104,12 @@ func ctlV3PrefixArgs(clus *etcdProcessCluster, dialTimeout time.Duration) []stri
if backends := clus.backends(); len(backends) != 0 {
es := []string{}
for _, b := range backends {
es = append(es, stripSchema(b.cfg.acurl.String()))
es = append(es, stripSchema(b.cfg.acurl))
}
endpoints = strings.Join(es, ",")
}
cmdArgs := []string{"../bin/etcdctlv3", "--endpoints", endpoints, "--dial-timeout", dialTimeout.String()}
if clus.cfg.isClientTLS {
if clus.cfg.clientTLS == clientTLS {
cmdArgs = append(cmdArgs, "--cacert", caPath, "--cert", certPath, "--key", privateKeyPath)
}
return cmdArgs

View File

@ -18,6 +18,7 @@
package etcdmain
import (
"crypto/tls"
"encoding/json"
"fmt"
"io/ioutil"
@ -33,7 +34,6 @@ import (
"github.com/coreos/etcd/discovery"
"github.com/coreos/etcd/etcdserver"
"github.com/coreos/etcd/etcdserver/api/v3rpc"
"github.com/coreos/etcd/etcdserver/etcdhttp"
"github.com/coreos/etcd/pkg/cors"
"github.com/coreos/etcd/pkg/fileutil"
@ -49,7 +49,6 @@ import (
systemdutil "github.com/coreos/go-systemd/util"
"github.com/coreos/pkg/capnslog"
"github.com/prometheus/client_golang/prometheus"
"google.golang.org/grpc"
)
type dirType string
@ -220,14 +219,24 @@ func startEtcd(cfg *config) (<-chan struct{}, error) {
if !cfg.peerTLSInfo.Empty() {
plog.Infof("peerTLS: %s", cfg.peerTLSInfo)
}
plns := make([]net.Listener, 0)
for _, u := range cfg.lpurls {
if u.Scheme == "http" && !cfg.peerTLSInfo.Empty() {
plog.Warningf("The scheme of peer url %s is http while peer key/cert files are presented. Ignored peer key/cert files.", u.String())
}
var l net.Listener
l, err = rafthttp.NewListener(u, cfg.peerTLSInfo)
var (
l net.Listener
tlscfg *tls.Config
)
if !cfg.peerTLSInfo.Empty() {
tlscfg, err = cfg.peerTLSInfo.ServerConfig()
if err != nil {
return nil, err
}
}
l, err = rafthttp.NewListener(u, tlscfg)
if err != nil {
return nil, err
}
@ -243,15 +252,40 @@ func startEtcd(cfg *config) (<-chan struct{}, error) {
plns = append(plns, l)
}
var ctlscfg *tls.Config
if !cfg.clientTLSInfo.Empty() {
plog.Infof("clientTLS: %s", cfg.clientTLSInfo)
ctlscfg, err = cfg.clientTLSInfo.ServerConfig()
if err != nil {
return nil, err
}
}
clns := make([]net.Listener, 0)
sctxs := make(map[string]*serveCtx)
for _, u := range cfg.lcurls {
if u.Scheme == "http" && !cfg.clientTLSInfo.Empty() {
plog.Warningf("The scheme of client url %s is http while client key/cert files are presented. Ignored client key/cert files.", u.String())
}
ctx := &serveCtx{host: u.Host}
if u.Scheme == "https" {
ctx.secure = true
} else {
ctx.insecure = true
}
if sctxs[u.Host] != nil {
if ctx.secure {
sctxs[u.Host].secure = true
}
if ctx.insecure {
sctxs[u.Host].insecure = true
}
continue
}
var l net.Listener
l, err = net.Listen("tcp", u.Host)
if err != nil {
return nil, err
@ -265,22 +299,20 @@ func startEtcd(cfg *config) (<-chan struct{}, error) {
l = transport.LimitListener(l, int(fdLimit-reservedInternalFDNum))
}
// Do not wrap around this listener if TLS Info is set.
// HTTPS server expects TLS Conn created by TLSListener.
l, err = transport.NewKeepAliveListener(l, u.Scheme, cfg.clientTLSInfo)
l, err = transport.NewKeepAliveListener(l, "tcp", nil)
ctx.l = l
if err != nil {
return nil, err
}
urlStr := u.String()
plog.Info("listening for client requests on ", urlStr)
plog.Info("listening for client requests on ", u.Host)
defer func() {
if err != nil {
l.Close()
plog.Info("stopping listening for client requests on ", urlStr)
plog.Info("stopping listening for client requests on ", u.Host)
}
}()
clns = append(clns, l)
sctxs[u.Host] = ctx
}
srvcfg := &etcdserver.ServerConfig{
@ -317,40 +349,25 @@ func startEtcd(cfg *config) (<-chan struct{}, error) {
if cfg.corsInfo.String() != "" {
plog.Infof("cors = %s", cfg.corsInfo)
}
ch := &cors.CORSHandler{
ch := http.Handler(&cors.CORSHandler{
Handler: etcdhttp.NewClientHandler(s, srvcfg.ReqTimeout()),
Info: cfg.corsInfo,
}
})
ph := etcdhttp.NewPeerHandler(s)
var grpcS *grpc.Server
if cfg.v3demo {
// set up v3 demo rpc
tls := &cfg.clientTLSInfo
if cfg.clientTLSInfo.Empty() {
tls = nil
}
grpcS, err = v3rpc.Server(s, tls)
if err != nil {
s.Stop()
<-s.StopNotify()
return nil, err
}
}
// Start the peer server in a goroutine
for _, l := range plns {
go func(l net.Listener) {
plog.Fatal(serve(l, nil, ph, 5*time.Minute))
plog.Fatal(servePeerHTTP(l, ph))
}(l)
}
// Start a client server goroutine for each listen address
for _, l := range clns {
go func(l net.Listener) {
for _, sctx := range sctxs {
go func(sctx *serveCtx) {
// read timeout does not work with http close notify
// TODO: https://github.com/golang/go/issues/9524
plog.Fatal(serve(l, grpcS, ch, 0))
}(l)
plog.Fatal(serve(sctx, s, ctlscfg, ch))
}(sctx)
}
return s.StopNotify(), nil
@ -419,11 +436,11 @@ func startProxy(cfg *config) error {
clientURLs := []string{}
uf := func() []string {
gcls, err := etcdserver.GetClusterFromRemotePeers(peerURLs, tr)
gcls, gerr := etcdserver.GetClusterFromRemotePeers(peerURLs, tr)
// TODO: remove the 2nd check when we fix GetClusterFromRemotePeers
// GetClusterFromRemotePeers should not return nil error with an invalid empty cluster
if err != nil {
plog.Warningf("proxy: %v", err)
if gerr != nil {
plog.Warningf("proxy: %v", gerr)
return []string{}
}
if len(gcls.Members()) == 0 {
@ -432,9 +449,9 @@ func startProxy(cfg *config) error {
clientURLs = gcls.ClientURLs()
urls := struct{ PeerURLs []string }{gcls.PeerURLs()}
b, err := json.Marshal(urls)
if err != nil {
plog.Warningf("proxy: error on marshal peer urls %s", err)
b, jerr := json.Marshal(urls)
if jerr != nil {
plog.Warningf("proxy: error on marshal peer urls %s", jerr)
return clientURLs
}
@ -466,7 +483,18 @@ func startProxy(cfg *config) error {
}
// Start a proxy server goroutine for each listen address
for _, u := range cfg.lcurls {
l, err := transport.NewListener(u.Host, u.Scheme, cfg.clientTLSInfo)
var (
l net.Listener
tlscfg *tls.Config
)
if !cfg.clientTLSInfo.Empty() {
tlscfg, err = cfg.clientTLSInfo.ServerConfig()
if err != nil {
return err
}
}
l, err := transport.NewListener(u.Host, u.Scheme, tlscfg)
if err != nil {
return err
}

View File

@ -15,37 +15,87 @@
package etcdmain
import (
"crypto/tls"
"io/ioutil"
defaultLog "log"
"net"
"net/http"
"strings"
"time"
"github.com/cockroachdb/cmux"
"github.com/coreos/etcd/etcdserver"
"github.com/coreos/etcd/etcdserver/api/v3rpc"
"google.golang.org/grpc"
)
type serveCtx struct {
l net.Listener
host string
secure bool
insecure bool
}
// serve accepts incoming connections on the listener l,
// creating a new service goroutine for each. The service goroutines
// read requests and then call handler to reply to them.
func serve(l net.Listener, grpcS *grpc.Server, handler http.Handler, readTimeout time.Duration) error {
// TODO: assert net.Listener type? Arbitrary listener might break HTTPS server which
// expect a TLS Conn type.
httpl := l
if grpcS != nil {
m := cmux.New(l)
grpcl := m.Match(cmux.HTTP2HeaderField("content-type", "application/grpc"))
httpl = m.Match(cmux.Any())
go func() { plog.Fatal(m.Serve()) }()
go plog.Fatal(grpcS.Serve(grpcl))
func serve(sctx *serveCtx, s *etcdserver.EtcdServer, tlscfg *tls.Config, handler http.Handler) error {
logger := defaultLog.New(ioutil.Discard, "etcdhttp", 0)
m := cmux.New(sctx.l)
if sctx.insecure {
gs := v3rpc.Server(s, nil)
grpcl := m.Match(cmux.HTTP2())
go func() { plog.Fatal(gs.Serve(grpcl)) }()
srvhttp := &http.Server{
Handler: handler,
ErrorLog: logger, // do not log user error
}
httpl := m.Match(cmux.HTTP1())
go func() { plog.Fatal(srvhttp.Serve(httpl)) }()
plog.Noticef("serving insecure client requests on %s, this is strongly discouraged!", sctx.host)
}
if sctx.secure {
gs := v3rpc.Server(s, tlscfg)
handler = grpcHandlerFunc(gs, handler)
tlsl := tls.NewListener(m.Match(cmux.Any()), tlscfg)
// TODO: add debug flag; enable logging when debug flag is set
srv := &http.Server{
Handler: handler,
TLSConfig: tlscfg,
ErrorLog: logger, // do not log user error
}
go func() { plog.Fatal(srv.Serve(tlsl)) }()
plog.Infof("serving client requests on %s", sctx.host)
}
return m.Serve()
}
// grpcHandlerFunc returns an http.Handler that delegates to grpcServer on incoming gRPC
// connections or otherHandler otherwise. Copied from cockroachdb.
func grpcHandlerFunc(grpcServer *grpc.Server, otherHandler http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.ProtoMajor == 2 && strings.Contains(r.Header.Get("Content-Type"), "application/grpc") {
grpcServer.ServeHTTP(w, r)
} else {
otherHandler.ServeHTTP(w, r)
}
})
}
func servePeerHTTP(l net.Listener, handler http.Handler) error {
logger := defaultLog.New(ioutil.Discard, "etcdhttp", 0)
// TODO: add debug flag; enable logging when debug flag is set
srv := &http.Server{
Handler: handler,
ReadTimeout: readTimeout,
ReadTimeout: 5 * time.Minute,
ErrorLog: logger, // do not log user error
}
return srv.Serve(httpl)
return srv.Serve(l)
}

View File

@ -15,21 +15,18 @@
package v3rpc
import (
"crypto/tls"
"github.com/coreos/etcd/etcdserver"
pb "github.com/coreos/etcd/etcdserver/etcdserverpb"
"github.com/coreos/etcd/pkg/transport"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
)
func Server(s *etcdserver.EtcdServer, tls *transport.TLSInfo) (*grpc.Server, error) {
func Server(s *etcdserver.EtcdServer, tls *tls.Config) *grpc.Server {
var opts []grpc.ServerOption
if tls != nil {
creds, err := credentials.NewServerTLSFromFile(tls.CertFile, tls.KeyFile)
if err != nil {
return nil, err
}
opts = append(opts, grpc.Creds(creds))
opts = append(opts, grpc.Creds(credentials.NewTLS(tls)))
}
grpcServer := grpc.NewServer(opts...)
@ -39,5 +36,5 @@ func Server(s *etcdserver.EtcdServer, tls *transport.TLSInfo) (*grpc.Server, err
pb.RegisterClusterServer(grpcServer, NewClusterServer(s))
pb.RegisterAuthServer(grpcServer, NewAuthServer(s))
pb.RegisterMaintenanceServer(grpcServer, NewMaintenanceServer(s))
return grpcServer, nil
return grpcServer
}

View File

@ -15,6 +15,7 @@
package integration
import (
"crypto/tls"
"fmt"
"io/ioutil"
"math/rand"
@ -585,7 +586,16 @@ func (m *member) Launch() error {
m.hss = append(m.hss, hs)
}
if m.grpcListener != nil {
m.grpcServer, err = v3rpc.Server(m.s, m.ClientTLSInfo)
var (
tlscfg *tls.Config
)
if m.ClientTLSInfo != nil && !m.ClientTLSInfo.Empty() {
tlscfg, err = m.ClientTLSInfo.ServerConfig()
if err != nil {
return err
}
}
m.grpcServer = v3rpc.Server(m.s, tlscfg)
go m.grpcServer.Serve(m.grpcListener)
}
return nil

View File

@ -30,17 +30,12 @@ type keepAliveConn interface {
// Be careful when wrap around KeepAliveListener with another Listener if TLSInfo is not nil.
// Some pkgs (like go/http) might expect Listener to return TLSConn type to start TLS handshake.
// http://tldp.org/HOWTO/TCP-Keepalive-HOWTO/overview.html
func NewKeepAliveListener(l net.Listener, scheme string, info TLSInfo) (net.Listener, error) {
func NewKeepAliveListener(l net.Listener, scheme string, tlscfg *tls.Config) (net.Listener, error) {
if scheme == "https" {
if info.Empty() {
if tlscfg == nil {
return nil, fmt.Errorf("cannot listen on TLS for given listener: KeyFile and CertFile are not presented")
}
cfg, err := info.ServerConfig()
if err != nil {
return nil, err
}
return newTLSKeepaliveListener(l, cfg), nil
return newTLSKeepaliveListener(l, tlscfg), nil
}
return &keepaliveListener{

View File

@ -31,7 +31,7 @@ func TestNewKeepAliveListener(t *testing.T) {
t.Fatalf("unexpected listen error: %v", err)
}
ln, err = NewKeepAliveListener(ln, "http", TLSInfo{})
ln, err = NewKeepAliveListener(ln, "http", nil)
if err != nil {
t.Fatalf("unexpected NewKeepAliveListener error: %v", err)
}
@ -53,7 +53,11 @@ func TestNewKeepAliveListener(t *testing.T) {
defer os.Remove(tmp)
tlsInfo := TLSInfo{CertFile: tmp, KeyFile: tmp}
tlsInfo.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, nil)
tlsln, err := NewKeepAliveListener(ln, "https", tlsInfo)
tlscfg, err := tlsInfo.ServerConfig()
if err != nil {
t.Fatalf("unexpected serverConfig error: %v", err)
}
tlsln, err := NewKeepAliveListener(ln, "https", tlscfg)
if err != nil {
t.Fatalf("unexpected NewKeepAliveListener error: %v", err)
}
@ -70,13 +74,13 @@ func TestNewKeepAliveListener(t *testing.T) {
tlsln.Close()
}
func TestNewKeepAliveListenerTLSEmptyInfo(t *testing.T) {
func TestNewKeepAliveListenerTLSEmptyConfig(t *testing.T) {
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("unexpected listen error: %v", err)
}
_, err = NewKeepAliveListener(ln, "https", TLSInfo{})
_, err = NewKeepAliveListener(ln, "https", nil)
if err == nil {
t.Errorf("err = nil, want not presented error")
}

View File

@ -33,7 +33,7 @@ import (
"time"
)
func NewListener(addr string, scheme string, info TLSInfo) (net.Listener, error) {
func NewListener(addr string, scheme string, tlscfg *tls.Config) (net.Listener, error) {
nettype := "tcp"
if scheme == "unix" {
// unix sockets via unix://laddr
@ -46,15 +46,11 @@ func NewListener(addr string, scheme string, info TLSInfo) (net.Listener, error)
}
if scheme == "https" {
if info.Empty() {
if tlscfg == nil {
return nil, fmt.Errorf("cannot listen on TLS for %s: KeyFile and CertFile are not presented", scheme+"://"+addr)
}
cfg, err := info.ServerConfig()
if err != nil {
return nil, err
}
l = tls.NewListener(l, cfg)
l = tls.NewListener(l, tlscfg)
}
return l, nil

View File

@ -58,7 +58,11 @@ func TestNewListenerTLSInfo(t *testing.T) {
}
func testNewListenerTLSInfoAccept(t *testing.T, tlsInfo TLSInfo) {
ln, err := NewListener("127.0.0.1:0", "https", tlsInfo)
tlscfg, err := tlsInfo.ServerConfig()
if err != nil {
t.Fatalf("unexpected serverConfig error: %v", err)
}
ln, err := NewListener("127.0.0.1:0", "https", tlscfg)
if err != nil {
t.Fatalf("unexpected NewListener error: %v", err)
}
@ -76,25 +80,12 @@ func testNewListenerTLSInfoAccept(t *testing.T, tlsInfo TLSInfo) {
}
func TestNewListenerTLSEmptyInfo(t *testing.T) {
_, err := NewListener("127.0.0.1:0", "https", TLSInfo{})
_, err := NewListener("127.0.0.1:0", "https", nil)
if err == nil {
t.Errorf("err = nil, want not presented error")
}
}
func TestNewListenerTLSInfoNonexist(t *testing.T) {
tlsInfo := TLSInfo{CertFile: "@badname", KeyFile: "@badname"}
_, err := NewListener("127.0.0.1:0", "https", tlsInfo)
werr := &os.PathError{
Op: "open",
Path: "@badname",
Err: errors.New("no such file or directory"),
}
if err.Error() != werr.Error() {
t.Errorf("err = %v, want %v", err, werr)
}
}
func TestNewTransportTLSInfo(t *testing.T) {
tmp, err := createTempFile([]byte("XXX"))
if err != nil {
@ -131,6 +122,19 @@ func TestNewTransportTLSInfo(t *testing.T) {
}
}
func TestTLSInfoNonexist(t *testing.T) {
tlsInfo := TLSInfo{CertFile: "@badname", KeyFile: "@badname"}
_, err := tlsInfo.ServerConfig()
werr := &os.PathError{
Op: "open",
Path: "@badname",
Err: errors.New("no such file or directory"),
}
if err.Error() != werr.Error() {
t.Errorf("err = %v, want %v", err, werr)
}
}
func TestTLSInfoEmpty(t *testing.T) {
tests := []struct {
info TLSInfo
@ -247,7 +251,7 @@ func TestTLSInfoConfigFuncs(t *testing.T) {
}
func TestNewListenerUnixSocket(t *testing.T) {
l, err := NewListener("testsocket", "unix", TLSInfo{})
l, err := NewListener("testsocket", "unix", nil)
if err != nil {
t.Errorf("error listening on unix socket (%v)", err)
}

View File

@ -15,6 +15,7 @@
package transport
import (
"crypto/tls"
"net"
"time"
)
@ -22,8 +23,8 @@ import (
// NewTimeoutListener returns a listener that listens on the given address.
// If read/write on the accepted connection blocks longer than its time limit,
// it will return timeout error.
func NewTimeoutListener(addr string, scheme string, info TLSInfo, rdtimeoutd, wtimeoutd time.Duration) (net.Listener, error) {
ln, err := NewListener(addr, scheme, info)
func NewTimeoutListener(addr string, scheme string, tlscfg *tls.Config, rdtimeoutd, wtimeoutd time.Duration) (net.Listener, error) {
ln, err := NewListener(addr, scheme, tlscfg)
if err != nil {
return nil, err
}

View File

@ -23,7 +23,7 @@ import (
// TestNewTimeoutListener tests that NewTimeoutListener returns a
// rwTimeoutListener struct with timeouts set.
func TestNewTimeoutListener(t *testing.T) {
l, err := NewTimeoutListener("127.0.0.1:0", "http", TLSInfo{}, time.Hour, time.Hour)
l, err := NewTimeoutListener("127.0.0.1:0", "http", nil, time.Hour, time.Hour)
if err != nil {
t.Fatalf("unexpected NewTimeoutListener error: %v", err)
}

View File

@ -15,6 +15,7 @@
package rafthttp
import (
"crypto/tls"
"encoding/binary"
"fmt"
"io"
@ -38,8 +39,8 @@ var (
// NewListener returns a listener for raft message transfer between peers.
// It uses timeout listener to identify broken streams promptly.
func NewListener(u url.URL, tlsInfo transport.TLSInfo) (net.Listener, error) {
return transport.NewTimeoutListener(u.Host, u.Scheme, tlsInfo, ConnReadTimeout, ConnWriteTimeout)
func NewListener(u url.URL, tlscfg *tls.Config) (net.Listener, error) {
return transport.NewTimeoutListener(u.Host, u.Scheme, tlscfg, ConnReadTimeout, ConnWriteTimeout)
}
// NewRoundTripper returns a roundTripper used to send requests