mirror of
https://github.com/etcd-io/etcd.git
synced 2024-09-27 06:25:44 +00:00
Move pkg/* used by client to client/pkg.
This commit is contained in:
@@ -1,28 +0,0 @@
|
||||
// Copyright 2016 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//go:build !windows
|
||||
// +build !windows
|
||||
|
||||
package fileutil
|
||||
|
||||
import "os"
|
||||
|
||||
const (
|
||||
// PrivateDirMode grants owner to make/remove files inside the directory.
|
||||
PrivateDirMode = 0700
|
||||
)
|
||||
|
||||
// OpenDir opens a directory for syncing.
|
||||
func OpenDir(path string) (*os.File, error) { return os.Open(path) }
|
||||
@@ -1,52 +0,0 @@
|
||||
// Copyright 2016 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package fileutil
|
||||
|
||||
import (
|
||||
"os"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
const (
|
||||
// PrivateDirMode grants owner to make/remove files inside the directory.
|
||||
PrivateDirMode = 0777
|
||||
)
|
||||
|
||||
// OpenDir opens a directory in windows with write access for syncing.
|
||||
func OpenDir(path string) (*os.File, error) {
|
||||
fd, err := openDir(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return os.NewFile(uintptr(fd), path), nil
|
||||
}
|
||||
|
||||
func openDir(path string) (fd syscall.Handle, err error) {
|
||||
if len(path) == 0 {
|
||||
return syscall.InvalidHandle, syscall.ERROR_FILE_NOT_FOUND
|
||||
}
|
||||
pathp, err := syscall.UTF16PtrFromString(path)
|
||||
if err != nil {
|
||||
return syscall.InvalidHandle, err
|
||||
}
|
||||
access := uint32(syscall.GENERIC_READ | syscall.GENERIC_WRITE)
|
||||
sharemode := uint32(syscall.FILE_SHARE_READ | syscall.FILE_SHARE_WRITE)
|
||||
createmode := uint32(syscall.OPEN_EXISTING)
|
||||
fl := uint32(syscall.FILE_FLAG_BACKUP_SEMANTICS)
|
||||
return syscall.CreateFile(pathp, access, sharemode, nil, createmode, fl, 0)
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
// Copyright 2018 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Package fileutil implements utility functions related to files and paths.
|
||||
package fileutil
|
||||
@@ -1,137 +0,0 @@
|
||||
// Copyright 2015 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package fileutil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const (
|
||||
// PrivateFileMode grants owner to read/write a file.
|
||||
PrivateFileMode = 0600
|
||||
)
|
||||
|
||||
// IsDirWriteable checks if dir is writable by writing and removing a file
|
||||
// to dir. It returns nil if dir is writable.
|
||||
func IsDirWriteable(dir string) error {
|
||||
f := filepath.Join(dir, ".touch")
|
||||
if err := ioutil.WriteFile(f, []byte(""), PrivateFileMode); err != nil {
|
||||
return err
|
||||
}
|
||||
return os.Remove(f)
|
||||
}
|
||||
|
||||
// TouchDirAll is similar to os.MkdirAll. It creates directories with 0700 permission if any directory
|
||||
// does not exists. TouchDirAll also ensures the given directory is writable.
|
||||
func TouchDirAll(dir string) error {
|
||||
// If path is already a directory, MkdirAll does nothing and returns nil, so,
|
||||
// first check if dir exist with an expected permission mode.
|
||||
if Exist(dir) {
|
||||
err := CheckDirPermission(dir, PrivateDirMode)
|
||||
if err != nil {
|
||||
lg, _ := zap.NewProduction()
|
||||
if lg == nil {
|
||||
lg = zap.NewExample()
|
||||
}
|
||||
lg.Warn("check file permission", zap.Error(err))
|
||||
}
|
||||
} else {
|
||||
err := os.MkdirAll(dir, PrivateDirMode)
|
||||
if err != nil {
|
||||
// if mkdirAll("a/text") and "text" is not
|
||||
// a directory, this will return syscall.ENOTDIR
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return IsDirWriteable(dir)
|
||||
}
|
||||
|
||||
// CreateDirAll is similar to TouchDirAll but returns error
|
||||
// if the deepest directory was not empty.
|
||||
func CreateDirAll(dir string) error {
|
||||
err := TouchDirAll(dir)
|
||||
if err == nil {
|
||||
var ns []string
|
||||
ns, err = ReadDir(dir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(ns) != 0 {
|
||||
err = fmt.Errorf("expected %q to be empty, got %q", dir, ns)
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Exist returns true if a file or directory exists.
|
||||
func Exist(name string) bool {
|
||||
_, err := os.Stat(name)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// DirEmpty returns true if a directory empty and can access.
|
||||
func DirEmpty(name string) bool {
|
||||
ns, err := ReadDir(name)
|
||||
return len(ns) == 0 && err == nil
|
||||
}
|
||||
|
||||
// ZeroToEnd zeros a file starting from SEEK_CUR to its SEEK_END. May temporarily
|
||||
// shorten the length of the file.
|
||||
func ZeroToEnd(f *os.File) error {
|
||||
// TODO: support FALLOC_FL_ZERO_RANGE
|
||||
off, err := f.Seek(0, io.SeekCurrent)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
lenf, lerr := f.Seek(0, io.SeekEnd)
|
||||
if lerr != nil {
|
||||
return lerr
|
||||
}
|
||||
if err = f.Truncate(off); err != nil {
|
||||
return err
|
||||
}
|
||||
// make sure blocks remain allocated
|
||||
if err = Preallocate(f, lenf, true); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = f.Seek(off, io.SeekStart)
|
||||
return err
|
||||
}
|
||||
|
||||
// CheckDirPermission checks permission on an existing dir.
|
||||
// Returns error if dir is empty or exist with a different permission than specified.
|
||||
func CheckDirPermission(dir string, perm os.FileMode) error {
|
||||
if !Exist(dir) {
|
||||
return fmt.Errorf("directory %q empty, cannot check permission", dir)
|
||||
}
|
||||
//check the existing permission on the directory
|
||||
dirInfo, err := os.Stat(dir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
dirMode := dirInfo.Mode().Perm()
|
||||
if dirMode != perm {
|
||||
err = fmt.Errorf("directory %q exist, but the permission is %q. The recommended permission is %q to prevent possible unprivileged access to the data", dir, dirInfo.Mode(), os.FileMode(PrivateDirMode))
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,194 +0,0 @@
|
||||
// Copyright 2015 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package fileutil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"math/rand"
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestIsDirWriteable(t *testing.T) {
|
||||
tmpdir, err := ioutil.TempDir("", "")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected ioutil.TempDir error: %v", err)
|
||||
}
|
||||
defer os.RemoveAll(tmpdir)
|
||||
if err = IsDirWriteable(tmpdir); err != nil {
|
||||
t.Fatalf("unexpected IsDirWriteable error: %v", err)
|
||||
}
|
||||
if err = os.Chmod(tmpdir, 0444); err != nil {
|
||||
t.Fatalf("unexpected os.Chmod error: %v", err)
|
||||
}
|
||||
me, err := user.Current()
|
||||
if err != nil {
|
||||
// err can be non-nil when cross compiled
|
||||
// http://stackoverflow.com/questions/20609415/cross-compiling-user-current-not-implemented-on-linux-amd64
|
||||
t.Skipf("failed to get current user: %v", err)
|
||||
}
|
||||
if me.Name == "root" || runtime.GOOS == "windows" {
|
||||
// ideally we should check CAP_DAC_OVERRIDE.
|
||||
// but it does not matter for tests.
|
||||
// Chmod is not supported under windows.
|
||||
t.Skipf("running as a superuser or in windows")
|
||||
}
|
||||
if err := IsDirWriteable(tmpdir); err == nil {
|
||||
t.Fatalf("expected IsDirWriteable to error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateDirAll(t *testing.T) {
|
||||
tmpdir, err := ioutil.TempDir(os.TempDir(), "foo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.RemoveAll(tmpdir)
|
||||
|
||||
tmpdir2 := filepath.Join(tmpdir, "testdir")
|
||||
if err = CreateDirAll(tmpdir2); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err = ioutil.WriteFile(filepath.Join(tmpdir2, "text.txt"), []byte("test text"), PrivateFileMode); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err = CreateDirAll(tmpdir2); err == nil || !strings.Contains(err.Error(), "to be empty, got") {
|
||||
t.Fatalf("unexpected error %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExist(t *testing.T) {
|
||||
fdir := filepath.Join(os.TempDir(), fmt.Sprint(time.Now().UnixNano()+rand.Int63n(1000)))
|
||||
os.RemoveAll(fdir)
|
||||
if err := os.Mkdir(fdir, 0666); err != nil {
|
||||
t.Skip(err)
|
||||
}
|
||||
defer os.RemoveAll(fdir)
|
||||
if !Exist(fdir) {
|
||||
t.Fatalf("expected Exist true, got %v", Exist(fdir))
|
||||
}
|
||||
|
||||
f, err := ioutil.TempFile(os.TempDir(), "fileutil")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
f.Close()
|
||||
|
||||
if g := Exist(f.Name()); !g {
|
||||
t.Errorf("exist = %v, want true", g)
|
||||
}
|
||||
|
||||
os.Remove(f.Name())
|
||||
if g := Exist(f.Name()); g {
|
||||
t.Errorf("exist = %v, want false", g)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDirEmpty(t *testing.T) {
|
||||
dir, err := ioutil.TempDir(os.TempDir(), "empty_dir")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
if !DirEmpty(dir) {
|
||||
t.Fatalf("expected DirEmpty true, got %v", DirEmpty(dir))
|
||||
}
|
||||
|
||||
file, err := ioutil.TempFile(dir, "new_file")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
file.Close()
|
||||
|
||||
if DirEmpty(dir) {
|
||||
t.Fatalf("expected DirEmpty false, got %v", DirEmpty(dir))
|
||||
}
|
||||
if DirEmpty(file.Name()) {
|
||||
t.Fatalf("expected DirEmpty false, got %v", DirEmpty(file.Name()))
|
||||
}
|
||||
}
|
||||
|
||||
func TestZeroToEnd(t *testing.T) {
|
||||
f, err := ioutil.TempFile(os.TempDir(), "fileutil")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.Remove(f.Name())
|
||||
defer f.Close()
|
||||
|
||||
// Ensure 0 size is a nop so zero-to-end on an empty file won't give EINVAL.
|
||||
if err = ZeroToEnd(f); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
b := make([]byte, 1024)
|
||||
for i := range b {
|
||||
b[i] = 12
|
||||
}
|
||||
if _, err = f.Write(b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err = f.Seek(512, io.SeekStart); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err = ZeroToEnd(f); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
off, serr := f.Seek(0, io.SeekCurrent)
|
||||
if serr != nil {
|
||||
t.Fatal(serr)
|
||||
}
|
||||
if off != 512 {
|
||||
t.Fatalf("expected offset 512, got %d", off)
|
||||
}
|
||||
|
||||
b = make([]byte, 512)
|
||||
if _, err = f.Read(b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
for i := range b {
|
||||
if b[i] != 0 {
|
||||
t.Errorf("expected b[%d] = 0, got %d", i, b[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDirPermission(t *testing.T) {
|
||||
tmpdir, err := ioutil.TempDir(os.TempDir(), "foo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.RemoveAll(tmpdir)
|
||||
|
||||
tmpdir2 := filepath.Join(tmpdir, "testpermission")
|
||||
// create a new dir with 0700
|
||||
if err = CreateDirAll(tmpdir2); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// check dir permission with mode different than created dir
|
||||
if err = CheckDirPermission(tmpdir2, 0600); err == nil {
|
||||
t.Errorf("expected error, got nil")
|
||||
}
|
||||
}
|
||||
@@ -1,26 +0,0 @@
|
||||
// Copyright 2016 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package fileutil
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrLocked = errors.New("fileutil: file already locked")
|
||||
)
|
||||
|
||||
type LockedFile struct{ *os.File }
|
||||
@@ -1,50 +0,0 @@
|
||||
// Copyright 2016 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//go:build !windows && !plan9 && !solaris
|
||||
// +build !windows,!plan9,!solaris
|
||||
|
||||
package fileutil
|
||||
|
||||
import (
|
||||
"os"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func flockTryLockFile(path string, flag int, perm os.FileMode) (*LockedFile, error) {
|
||||
f, err := os.OpenFile(path, flag, perm)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = syscall.Flock(int(f.Fd()), syscall.LOCK_EX|syscall.LOCK_NB); err != nil {
|
||||
f.Close()
|
||||
if err == syscall.EWOULDBLOCK {
|
||||
err = ErrLocked
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &LockedFile{f}, nil
|
||||
}
|
||||
|
||||
func flockLockFile(path string, flag int, perm os.FileMode) (*LockedFile, error) {
|
||||
f, err := os.OpenFile(path, flag, perm)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = syscall.Flock(int(f.Fd()), syscall.LOCK_EX); err != nil {
|
||||
f.Close()
|
||||
return nil, err
|
||||
}
|
||||
return &LockedFile{f}, err
|
||||
}
|
||||
@@ -1,93 +0,0 @@
|
||||
// Copyright 2016 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package fileutil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// This used to call syscall.Flock() but that call fails with EBADF on NFS.
|
||||
// An alternative is lockf() which works on NFS but that call lets a process lock
|
||||
// the same file twice. Instead, use Linux's non-standard open file descriptor
|
||||
// locks which will block if the process already holds the file lock.
|
||||
|
||||
var (
|
||||
wrlck = syscall.Flock_t{
|
||||
Type: syscall.F_WRLCK,
|
||||
Whence: int16(io.SeekStart),
|
||||
Start: 0,
|
||||
Len: 0,
|
||||
}
|
||||
|
||||
linuxTryLockFile = flockTryLockFile
|
||||
linuxLockFile = flockLockFile
|
||||
)
|
||||
|
||||
func init() {
|
||||
// use open file descriptor locks if the system supports it
|
||||
getlk := syscall.Flock_t{Type: syscall.F_RDLCK}
|
||||
if err := syscall.FcntlFlock(0, unix.F_OFD_GETLK, &getlk); err == nil {
|
||||
linuxTryLockFile = ofdTryLockFile
|
||||
linuxLockFile = ofdLockFile
|
||||
}
|
||||
}
|
||||
|
||||
func TryLockFile(path string, flag int, perm os.FileMode) (*LockedFile, error) {
|
||||
return linuxTryLockFile(path, flag, perm)
|
||||
}
|
||||
|
||||
func ofdTryLockFile(path string, flag int, perm os.FileMode) (*LockedFile, error) {
|
||||
f, err := os.OpenFile(path, flag, perm)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ofdTryLockFile failed to open %q (%v)", path, err)
|
||||
}
|
||||
|
||||
flock := wrlck
|
||||
if err = syscall.FcntlFlock(f.Fd(), unix.F_OFD_SETLK, &flock); err != nil {
|
||||
f.Close()
|
||||
if err == syscall.EWOULDBLOCK {
|
||||
err = ErrLocked
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &LockedFile{f}, nil
|
||||
}
|
||||
|
||||
func LockFile(path string, flag int, perm os.FileMode) (*LockedFile, error) {
|
||||
return linuxLockFile(path, flag, perm)
|
||||
}
|
||||
|
||||
func ofdLockFile(path string, flag int, perm os.FileMode) (*LockedFile, error) {
|
||||
f, err := os.OpenFile(path, flag, perm)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ofdLockFile failed to open %q (%v)", path, err)
|
||||
}
|
||||
|
||||
flock := wrlck
|
||||
err = syscall.FcntlFlock(f.Fd(), unix.F_OFD_SETLKW, &flock)
|
||||
if err != nil {
|
||||
f.Close()
|
||||
return nil, err
|
||||
}
|
||||
return &LockedFile{f}, nil
|
||||
}
|
||||
@@ -1,30 +0,0 @@
|
||||
// Copyright 2017 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package fileutil
|
||||
|
||||
import "testing"
|
||||
|
||||
// TestLockAndUnlockSyscallFlock tests the fallback flock using the flock syscall.
|
||||
func TestLockAndUnlockSyscallFlock(t *testing.T) {
|
||||
oldTryLock, oldLock := linuxTryLockFile, linuxLockFile
|
||||
defer func() {
|
||||
linuxTryLockFile, linuxLockFile = oldTryLock, oldLock
|
||||
}()
|
||||
linuxTryLockFile, linuxLockFile = flockTryLockFile, flockLockFile
|
||||
TestLockAndUnlock(t)
|
||||
}
|
||||
@@ -1,45 +0,0 @@
|
||||
// Copyright 2015 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package fileutil
|
||||
|
||||
import (
|
||||
"os"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TryLockFile(path string, flag int, perm os.FileMode) (*LockedFile, error) {
|
||||
if err := os.Chmod(path, syscall.DMEXCL|PrivateFileMode); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
f, err := os.Open(path, flag, perm)
|
||||
if err != nil {
|
||||
return nil, ErrLocked
|
||||
}
|
||||
return &LockedFile{f}, nil
|
||||
}
|
||||
|
||||
func LockFile(path string, flag int, perm os.FileMode) (*LockedFile, error) {
|
||||
if err := os.Chmod(path, syscall.DMEXCL|PrivateFileMode); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for {
|
||||
f, err := os.OpenFile(path, flag, perm)
|
||||
if err == nil {
|
||||
return &LockedFile{f}, nil
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
@@ -1,63 +0,0 @@
|
||||
// Copyright 2015 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//go:build solaris
|
||||
// +build solaris
|
||||
|
||||
package fileutil
|
||||
|
||||
import (
|
||||
"os"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func TryLockFile(path string, flag int, perm os.FileMode) (*LockedFile, error) {
|
||||
var lock syscall.Flock_t
|
||||
lock.Start = 0
|
||||
lock.Len = 0
|
||||
lock.Pid = 0
|
||||
lock.Type = syscall.F_WRLCK
|
||||
lock.Whence = 0
|
||||
lock.Pid = 0
|
||||
f, err := os.OpenFile(path, flag, perm)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := syscall.FcntlFlock(f.Fd(), syscall.F_SETLK, &lock); err != nil {
|
||||
f.Close()
|
||||
if err == syscall.EAGAIN {
|
||||
err = ErrLocked
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return &LockedFile{f}, nil
|
||||
}
|
||||
|
||||
func LockFile(path string, flag int, perm os.FileMode) (*LockedFile, error) {
|
||||
var lock syscall.Flock_t
|
||||
lock.Start = 0
|
||||
lock.Len = 0
|
||||
lock.Pid = 0
|
||||
lock.Type = syscall.F_WRLCK
|
||||
lock.Whence = 0
|
||||
f, err := os.OpenFile(path, flag, perm)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = syscall.FcntlFlock(f.Fd(), syscall.F_SETLKW, &lock); err != nil {
|
||||
f.Close()
|
||||
return nil, err
|
||||
}
|
||||
return &LockedFile{f}, nil
|
||||
}
|
||||
@@ -1,89 +0,0 @@
|
||||
// Copyright 2015 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package fileutil
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestLockAndUnlock(t *testing.T) {
|
||||
f, err := ioutil.TempFile("", "lock")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
f.Close()
|
||||
defer func() {
|
||||
err = os.Remove(f.Name())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}()
|
||||
|
||||
// lock the file
|
||||
l, err := LockFile(f.Name(), os.O_WRONLY, PrivateFileMode)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// try lock a locked file
|
||||
if _, err = TryLockFile(f.Name(), os.O_WRONLY, PrivateFileMode); err != ErrLocked {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// unlock the file
|
||||
if err = l.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// try lock the unlocked file
|
||||
dupl, err := TryLockFile(f.Name(), os.O_WRONLY, PrivateFileMode)
|
||||
if err != nil {
|
||||
t.Errorf("err = %v, want %v", err, nil)
|
||||
}
|
||||
|
||||
// blocking on locked file
|
||||
locked := make(chan struct{}, 1)
|
||||
go func() {
|
||||
bl, blerr := LockFile(f.Name(), os.O_WRONLY, PrivateFileMode)
|
||||
if blerr != nil {
|
||||
t.Error(blerr)
|
||||
}
|
||||
locked <- struct{}{}
|
||||
if blerr = bl.Close(); blerr != nil {
|
||||
t.Error(blerr)
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-locked:
|
||||
t.Error("unexpected unblocking")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
}
|
||||
|
||||
// unlock
|
||||
if err = dupl.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// the previously blocked routine should be unblocked
|
||||
select {
|
||||
case <-locked:
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Error("unexpected blocking")
|
||||
}
|
||||
}
|
||||
@@ -1,30 +0,0 @@
|
||||
// Copyright 2015 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//go:build !windows && !plan9 && !solaris && !linux
|
||||
// +build !windows,!plan9,!solaris,!linux
|
||||
|
||||
package fileutil
|
||||
|
||||
import (
|
||||
"os"
|
||||
)
|
||||
|
||||
func TryLockFile(path string, flag int, perm os.FileMode) (*LockedFile, error) {
|
||||
return flockTryLockFile(path, flag, perm)
|
||||
}
|
||||
|
||||
func LockFile(path string, flag int, perm os.FileMode) (*LockedFile, error) {
|
||||
return flockLockFile(path, flag, perm)
|
||||
}
|
||||
@@ -1,126 +0,0 @@
|
||||
// Copyright 2015 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package fileutil
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
var (
|
||||
modkernel32 = syscall.NewLazyDLL("kernel32.dll")
|
||||
procLockFileEx = modkernel32.NewProc("LockFileEx")
|
||||
|
||||
errLocked = errors.New("the process cannot access the file because another process has locked a portion of the file")
|
||||
)
|
||||
|
||||
const (
|
||||
// https://msdn.microsoft.com/en-us/library/windows/desktop/aa365203(v=vs.85).aspx
|
||||
LOCKFILE_EXCLUSIVE_LOCK = 2
|
||||
LOCKFILE_FAIL_IMMEDIATELY = 1
|
||||
|
||||
// see https://msdn.microsoft.com/en-us/library/windows/desktop/ms681382(v=vs.85).aspx
|
||||
errLockViolation syscall.Errno = 0x21
|
||||
)
|
||||
|
||||
func TryLockFile(path string, flag int, perm os.FileMode) (*LockedFile, error) {
|
||||
f, err := open(path, flag, perm)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := lockFile(syscall.Handle(f.Fd()), LOCKFILE_FAIL_IMMEDIATELY); err != nil {
|
||||
f.Close()
|
||||
return nil, err
|
||||
}
|
||||
return &LockedFile{f}, nil
|
||||
}
|
||||
|
||||
func LockFile(path string, flag int, perm os.FileMode) (*LockedFile, error) {
|
||||
f, err := open(path, flag, perm)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := lockFile(syscall.Handle(f.Fd()), 0); err != nil {
|
||||
f.Close()
|
||||
return nil, err
|
||||
}
|
||||
return &LockedFile{f}, nil
|
||||
}
|
||||
|
||||
func open(path string, flag int, perm os.FileMode) (*os.File, error) {
|
||||
if path == "" {
|
||||
return nil, fmt.Errorf("cannot open empty filename")
|
||||
}
|
||||
var access uint32
|
||||
switch flag {
|
||||
case syscall.O_RDONLY:
|
||||
access = syscall.GENERIC_READ
|
||||
case syscall.O_WRONLY:
|
||||
access = syscall.GENERIC_WRITE
|
||||
case syscall.O_RDWR:
|
||||
access = syscall.GENERIC_READ | syscall.GENERIC_WRITE
|
||||
case syscall.O_WRONLY | syscall.O_CREAT:
|
||||
access = syscall.GENERIC_ALL
|
||||
default:
|
||||
panic(fmt.Errorf("flag %v is not supported", flag))
|
||||
}
|
||||
fd, err := syscall.CreateFile(&(syscall.StringToUTF16(path)[0]),
|
||||
access,
|
||||
syscall.FILE_SHARE_READ|syscall.FILE_SHARE_WRITE|syscall.FILE_SHARE_DELETE,
|
||||
nil,
|
||||
syscall.OPEN_ALWAYS,
|
||||
syscall.FILE_ATTRIBUTE_NORMAL,
|
||||
0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return os.NewFile(uintptr(fd), path), nil
|
||||
}
|
||||
|
||||
func lockFile(fd syscall.Handle, flags uint32) error {
|
||||
var flag uint32 = LOCKFILE_EXCLUSIVE_LOCK
|
||||
flag |= flags
|
||||
if fd == syscall.InvalidHandle {
|
||||
return nil
|
||||
}
|
||||
err := lockFileEx(fd, flag, 1, 0, &syscall.Overlapped{})
|
||||
if err == nil {
|
||||
return nil
|
||||
} else if err.Error() == errLocked.Error() {
|
||||
return ErrLocked
|
||||
} else if err != errLockViolation {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func lockFileEx(h syscall.Handle, flags, locklow, lockhigh uint32, ol *syscall.Overlapped) (err error) {
|
||||
var reserved uint32 = 0
|
||||
r1, _, e1 := syscall.Syscall6(procLockFileEx.Addr(), 6, uintptr(h), uintptr(flags), uintptr(reserved), uintptr(locklow), uintptr(lockhigh), uintptr(unsafe.Pointer(ol)))
|
||||
if r1 == 0 {
|
||||
if e1 != 0 {
|
||||
err = error(e1)
|
||||
} else {
|
||||
err = syscall.EINVAL
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
@@ -1,54 +0,0 @@
|
||||
// Copyright 2015 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package fileutil
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
)
|
||||
|
||||
// Preallocate tries to allocate the space for given
|
||||
// file. This operation is only supported on linux by a
|
||||
// few filesystems (btrfs, ext4, etc.).
|
||||
// If the operation is unsupported, no error will be returned.
|
||||
// Otherwise, the error encountered will be returned.
|
||||
func Preallocate(f *os.File, sizeInBytes int64, extendFile bool) error {
|
||||
if sizeInBytes == 0 {
|
||||
// fallocate will return EINVAL if length is 0; skip
|
||||
return nil
|
||||
}
|
||||
if extendFile {
|
||||
return preallocExtend(f, sizeInBytes)
|
||||
}
|
||||
return preallocFixed(f, sizeInBytes)
|
||||
}
|
||||
|
||||
func preallocExtendTrunc(f *os.File, sizeInBytes int64) error {
|
||||
curOff, err := f.Seek(0, io.SeekCurrent)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
size, err := f.Seek(sizeInBytes, io.SeekEnd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err = f.Seek(curOff, io.SeekStart); err != nil {
|
||||
return err
|
||||
}
|
||||
if sizeInBytes > size {
|
||||
return nil
|
||||
}
|
||||
return f.Truncate(sizeInBytes)
|
||||
}
|
||||
@@ -1,67 +0,0 @@
|
||||
// Copyright 2016 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//go:build darwin
|
||||
// +build darwin
|
||||
|
||||
package fileutil
|
||||
|
||||
import (
|
||||
"os"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
func preallocExtend(f *os.File, sizeInBytes int64) error {
|
||||
if err := preallocFixed(f, sizeInBytes); err != nil {
|
||||
return err
|
||||
}
|
||||
return preallocExtendTrunc(f, sizeInBytes)
|
||||
}
|
||||
|
||||
func preallocFixed(f *os.File, sizeInBytes int64) error {
|
||||
// allocate all requested space or no space at all
|
||||
// TODO: allocate contiguous space on disk with F_ALLOCATECONTIG flag
|
||||
fstore := &unix.Fstore_t{
|
||||
Flags: unix.F_ALLOCATEALL,
|
||||
Posmode: unix.F_PEOFPOSMODE,
|
||||
Length: sizeInBytes,
|
||||
}
|
||||
err := unix.FcntlFstore(f.Fd(), unix.F_PREALLOCATE, fstore)
|
||||
if err == nil || err == unix.ENOTSUP {
|
||||
return nil
|
||||
}
|
||||
|
||||
// wrong argument to fallocate syscall
|
||||
if err == unix.EINVAL {
|
||||
// filesystem "st_blocks" are allocated in the units of
|
||||
// "Allocation Block Size" (run "diskutil info /" command)
|
||||
var stat syscall.Stat_t
|
||||
syscall.Fstat(int(f.Fd()), &stat)
|
||||
|
||||
// syscall.Statfs_t.Bsize is "optimal transfer block size"
|
||||
// and contains matching 4096 value when latest OS X kernel
|
||||
// supports 4,096 KB filesystem block size
|
||||
var statfs syscall.Statfs_t
|
||||
syscall.Fstatfs(int(f.Fd()), &statfs)
|
||||
blockSize := int64(statfs.Bsize)
|
||||
|
||||
if stat.Blocks*blockSize >= sizeInBytes {
|
||||
// enough blocks are already allocated
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
@@ -1,77 +0,0 @@
|
||||
// Copyright 2015 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package fileutil
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPreallocateExtend(t *testing.T) {
|
||||
pf := func(f *os.File, sz int64) error { return Preallocate(f, sz, true) }
|
||||
tf := func(t *testing.T, f *os.File) { testPreallocateExtend(t, f, pf) }
|
||||
runPreallocTest(t, tf)
|
||||
}
|
||||
|
||||
func TestPreallocateExtendTrunc(t *testing.T) {
|
||||
tf := func(t *testing.T, f *os.File) { testPreallocateExtend(t, f, preallocExtendTrunc) }
|
||||
runPreallocTest(t, tf)
|
||||
}
|
||||
|
||||
func testPreallocateExtend(t *testing.T, f *os.File, pf func(*os.File, int64) error) {
|
||||
size := int64(64 * 1000)
|
||||
if err := pf(f, size); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stat, err := f.Stat()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if stat.Size() != size {
|
||||
t.Errorf("size = %d, want %d", stat.Size(), size)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPreallocateFixed(t *testing.T) { runPreallocTest(t, testPreallocateFixed) }
|
||||
func testPreallocateFixed(t *testing.T, f *os.File) {
|
||||
size := int64(64 * 1000)
|
||||
if err := Preallocate(f, size, false); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stat, err := f.Stat()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if stat.Size() != 0 {
|
||||
t.Errorf("size = %d, want %d", stat.Size(), 0)
|
||||
}
|
||||
}
|
||||
|
||||
func runPreallocTest(t *testing.T, test func(*testing.T, *os.File)) {
|
||||
p, err := ioutil.TempDir(os.TempDir(), "preallocateTest")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.RemoveAll(p)
|
||||
|
||||
f, err := ioutil.TempFile(p, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
test(t, f)
|
||||
}
|
||||
@@ -1,50 +0,0 @@
|
||||
// Copyright 2016 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package fileutil
|
||||
|
||||
import (
|
||||
"os"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func preallocExtend(f *os.File, sizeInBytes int64) error {
|
||||
// use mode = 0 to change size
|
||||
err := syscall.Fallocate(int(f.Fd()), 0, 0, sizeInBytes)
|
||||
if err != nil {
|
||||
errno, ok := err.(syscall.Errno)
|
||||
// not supported; fallback
|
||||
// fallocate EINTRs frequently in some environments; fallback
|
||||
if ok && (errno == syscall.ENOTSUP || errno == syscall.EINTR) {
|
||||
return preallocExtendTrunc(f, sizeInBytes)
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func preallocFixed(f *os.File, sizeInBytes int64) error {
|
||||
// use mode = 1 to keep size; see FALLOC_FL_KEEP_SIZE
|
||||
err := syscall.Fallocate(int(f.Fd()), 1, 0, sizeInBytes)
|
||||
if err != nil {
|
||||
errno, ok := err.(syscall.Errno)
|
||||
// treat not supported as nil error
|
||||
if ok && errno == syscall.ENOTSUP {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
@@ -1,26 +0,0 @@
|
||||
// Copyright 2015 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//go:build !linux && !darwin
|
||||
// +build !linux,!darwin
|
||||
|
||||
package fileutil
|
||||
|
||||
import "os"
|
||||
|
||||
func preallocExtend(f *os.File, sizeInBytes int64) error {
|
||||
return preallocExtendTrunc(f, sizeInBytes)
|
||||
}
|
||||
|
||||
func preallocFixed(f *os.File, sizeInBytes int64) error { return nil }
|
||||
@@ -1,93 +0,0 @@
|
||||
// Copyright 2015 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package fileutil
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func PurgeFile(lg *zap.Logger, dirname string, suffix string, max uint, interval time.Duration, stop <-chan struct{}) <-chan error {
|
||||
return purgeFile(lg, dirname, suffix, max, interval, stop, nil, nil)
|
||||
}
|
||||
|
||||
func PurgeFileWithDoneNotify(lg *zap.Logger, dirname string, suffix string, max uint, interval time.Duration, stop <-chan struct{}) (<-chan struct{}, <-chan error) {
|
||||
doneC := make(chan struct{})
|
||||
errC := purgeFile(lg, dirname, suffix, max, interval, stop, nil, doneC)
|
||||
return doneC, errC
|
||||
}
|
||||
|
||||
// purgeFile is the internal implementation for PurgeFile which can post purged files to purgec if non-nil.
|
||||
// if donec is non-nil, the function closes it to notify its exit.
|
||||
func purgeFile(lg *zap.Logger, dirname string, suffix string, max uint, interval time.Duration, stop <-chan struct{}, purgec chan<- string, donec chan<- struct{}) <-chan error {
|
||||
if lg == nil {
|
||||
lg = zap.NewNop()
|
||||
}
|
||||
errC := make(chan error, 1)
|
||||
go func() {
|
||||
if donec != nil {
|
||||
defer close(donec)
|
||||
}
|
||||
for {
|
||||
fnames, err := ReadDir(dirname)
|
||||
if err != nil {
|
||||
errC <- err
|
||||
return
|
||||
}
|
||||
newfnames := make([]string, 0)
|
||||
for _, fname := range fnames {
|
||||
if strings.HasSuffix(fname, suffix) {
|
||||
newfnames = append(newfnames, fname)
|
||||
}
|
||||
}
|
||||
sort.Strings(newfnames)
|
||||
fnames = newfnames
|
||||
for len(newfnames) > int(max) {
|
||||
f := filepath.Join(dirname, newfnames[0])
|
||||
l, err := TryLockFile(f, os.O_WRONLY, PrivateFileMode)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
if err = os.Remove(f); err != nil {
|
||||
errC <- err
|
||||
return
|
||||
}
|
||||
if err = l.Close(); err != nil {
|
||||
lg.Warn("failed to unlock/close", zap.String("path", l.Name()), zap.Error(err))
|
||||
errC <- err
|
||||
return
|
||||
}
|
||||
lg.Info("purged", zap.String("path", f))
|
||||
newfnames = newfnames[1:]
|
||||
}
|
||||
if purgec != nil {
|
||||
for i := 0; i < len(fnames)-len(newfnames); i++ {
|
||||
purgec <- fnames[i]
|
||||
}
|
||||
}
|
||||
select {
|
||||
case <-time.After(interval):
|
||||
case <-stop:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
return errC
|
||||
}
|
||||
@@ -1,179 +0,0 @@
|
||||
// Copyright 2015 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package fileutil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestPurgeFile(t *testing.T) {
|
||||
dir, err := ioutil.TempDir("", "purgefile")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
// minimal file set
|
||||
for i := 0; i < 3; i++ {
|
||||
f, ferr := os.Create(filepath.Join(dir, fmt.Sprintf("%d.test", i)))
|
||||
if ferr != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
f.Close()
|
||||
}
|
||||
|
||||
stop, purgec := make(chan struct{}), make(chan string, 10)
|
||||
|
||||
// keep 3 most recent files
|
||||
errch := purgeFile(zap.NewExample(), dir, "test", 3, time.Millisecond, stop, purgec, nil)
|
||||
select {
|
||||
case f := <-purgec:
|
||||
t.Errorf("unexpected purge on %q", f)
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
}
|
||||
|
||||
// rest of the files
|
||||
for i := 4; i < 10; i++ {
|
||||
go func(n int) {
|
||||
f, ferr := os.Create(filepath.Join(dir, fmt.Sprintf("%d.test", n)))
|
||||
if ferr != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
f.Close()
|
||||
}(i)
|
||||
}
|
||||
|
||||
// watch files purge away
|
||||
for i := 4; i < 10; i++ {
|
||||
select {
|
||||
case <-purgec:
|
||||
case <-time.After(time.Second):
|
||||
t.Errorf("purge took too long")
|
||||
}
|
||||
}
|
||||
|
||||
fnames, rerr := ReadDir(dir)
|
||||
if rerr != nil {
|
||||
t.Fatal(rerr)
|
||||
}
|
||||
wnames := []string{"7.test", "8.test", "9.test"}
|
||||
if !reflect.DeepEqual(fnames, wnames) {
|
||||
t.Errorf("filenames = %v, want %v", fnames, wnames)
|
||||
}
|
||||
|
||||
// no error should be reported from purge routine
|
||||
select {
|
||||
case f := <-purgec:
|
||||
t.Errorf("unexpected purge on %q", f)
|
||||
case err := <-errch:
|
||||
t.Errorf("unexpected purge error %v", err)
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
}
|
||||
close(stop)
|
||||
}
|
||||
|
||||
func TestPurgeFileHoldingLockFile(t *testing.T) {
|
||||
dir, err := ioutil.TempDir("", "purgefile")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
var f *os.File
|
||||
f, err = os.Create(filepath.Join(dir, fmt.Sprintf("%d.test", i)))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
f.Close()
|
||||
}
|
||||
|
||||
// create a purge barrier at 5
|
||||
p := filepath.Join(dir, fmt.Sprintf("%d.test", 5))
|
||||
l, err := LockFile(p, os.O_WRONLY, PrivateFileMode)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
stop, purgec := make(chan struct{}), make(chan string, 10)
|
||||
errch := purgeFile(zap.NewExample(), dir, "test", 3, time.Millisecond, stop, purgec, nil)
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
select {
|
||||
case <-purgec:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("purge took too long")
|
||||
}
|
||||
}
|
||||
|
||||
fnames, rerr := ReadDir(dir)
|
||||
if rerr != nil {
|
||||
t.Fatal(rerr)
|
||||
}
|
||||
|
||||
wnames := []string{"5.test", "6.test", "7.test", "8.test", "9.test"}
|
||||
if !reflect.DeepEqual(fnames, wnames) {
|
||||
t.Errorf("filenames = %v, want %v", fnames, wnames)
|
||||
}
|
||||
|
||||
select {
|
||||
case s := <-purgec:
|
||||
t.Errorf("unexpected purge %q", s)
|
||||
case err = <-errch:
|
||||
t.Errorf("unexpected purge error %v", err)
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
}
|
||||
|
||||
// remove the purge barrier
|
||||
if err = l.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// wait for rest of purges (5, 6)
|
||||
for i := 0; i < 2; i++ {
|
||||
select {
|
||||
case <-purgec:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("purge took too long")
|
||||
}
|
||||
}
|
||||
|
||||
fnames, rerr = ReadDir(dir)
|
||||
if rerr != nil {
|
||||
t.Fatal(rerr)
|
||||
}
|
||||
wnames = []string{"7.test", "8.test", "9.test"}
|
||||
if !reflect.DeepEqual(fnames, wnames) {
|
||||
t.Errorf("filenames = %v, want %v", fnames, wnames)
|
||||
}
|
||||
|
||||
select {
|
||||
case f := <-purgec:
|
||||
t.Errorf("unexpected purge on %q", f)
|
||||
case err := <-errch:
|
||||
t.Errorf("unexpected purge error %v", err)
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
}
|
||||
|
||||
close(stop)
|
||||
}
|
||||
@@ -1,70 +0,0 @@
|
||||
// Copyright 2018 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package fileutil
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
)
|
||||
|
||||
// ReadDirOp represents an read-directory operation.
|
||||
type ReadDirOp struct {
|
||||
ext string
|
||||
}
|
||||
|
||||
// ReadDirOption configures archiver operations.
|
||||
type ReadDirOption func(*ReadDirOp)
|
||||
|
||||
// WithExt filters file names by their extensions.
|
||||
// (e.g. WithExt(".wal") to list only WAL files)
|
||||
func WithExt(ext string) ReadDirOption {
|
||||
return func(op *ReadDirOp) { op.ext = ext }
|
||||
}
|
||||
|
||||
func (op *ReadDirOp) applyOpts(opts []ReadDirOption) {
|
||||
for _, opt := range opts {
|
||||
opt(op)
|
||||
}
|
||||
}
|
||||
|
||||
// ReadDir returns the filenames in the given directory in sorted order.
|
||||
func ReadDir(d string, opts ...ReadDirOption) ([]string, error) {
|
||||
op := &ReadDirOp{}
|
||||
op.applyOpts(opts)
|
||||
|
||||
dir, err := os.Open(d)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer dir.Close()
|
||||
|
||||
names, err := dir.Readdirnames(-1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sort.Strings(names)
|
||||
|
||||
if op.ext != "" {
|
||||
tss := make([]string, 0)
|
||||
for _, v := range names {
|
||||
if filepath.Ext(v) == op.ext {
|
||||
tss = append(tss, v)
|
||||
}
|
||||
}
|
||||
names = tss
|
||||
}
|
||||
return names, nil
|
||||
}
|
||||
@@ -1,67 +0,0 @@
|
||||
// Copyright 2018 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package fileutil
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestReadDir(t *testing.T) {
|
||||
tmpdir, err := ioutil.TempDir("", "")
|
||||
defer os.RemoveAll(tmpdir)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected ioutil.TempDir error: %v", err)
|
||||
}
|
||||
|
||||
files := []string{"def", "abc", "xyz", "ghi"}
|
||||
for _, f := range files {
|
||||
writeFunc(t, filepath.Join(tmpdir, f))
|
||||
}
|
||||
fs, err := ReadDir(tmpdir)
|
||||
if err != nil {
|
||||
t.Fatalf("error calling ReadDir: %v", err)
|
||||
}
|
||||
wfs := []string{"abc", "def", "ghi", "xyz"}
|
||||
if !reflect.DeepEqual(fs, wfs) {
|
||||
t.Fatalf("ReadDir: got %v, want %v", fs, wfs)
|
||||
}
|
||||
|
||||
files = []string{"def.wal", "abc.wal", "xyz.wal", "ghi.wal"}
|
||||
for _, f := range files {
|
||||
writeFunc(t, filepath.Join(tmpdir, f))
|
||||
}
|
||||
fs, err = ReadDir(tmpdir, WithExt(".wal"))
|
||||
if err != nil {
|
||||
t.Fatalf("error calling ReadDir: %v", err)
|
||||
}
|
||||
wfs = []string{"abc.wal", "def.wal", "ghi.wal", "xyz.wal"}
|
||||
if !reflect.DeepEqual(fs, wfs) {
|
||||
t.Fatalf("ReadDir: got %v, want %v", fs, wfs)
|
||||
}
|
||||
}
|
||||
|
||||
func writeFunc(t *testing.T, path string) {
|
||||
fh, err := os.Create(path)
|
||||
if err != nil {
|
||||
t.Fatalf("error creating file: %v", err)
|
||||
}
|
||||
if err = fh.Close(); err != nil {
|
||||
t.Fatalf("error closing file: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -1,30 +0,0 @@
|
||||
// Copyright 2016 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//go:build !linux && !darwin
|
||||
// +build !linux,!darwin
|
||||
|
||||
package fileutil
|
||||
|
||||
import "os"
|
||||
|
||||
// Fsync is a wrapper around file.Sync(). Special handling is needed on darwin platform.
|
||||
func Fsync(f *os.File) error {
|
||||
return f.Sync()
|
||||
}
|
||||
|
||||
// Fdatasync is a wrapper around file.Sync(). Special handling is needed on linux platform.
|
||||
func Fdatasync(f *os.File) error {
|
||||
return f.Sync()
|
||||
}
|
||||
@@ -1,39 +0,0 @@
|
||||
// Copyright 2016 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//go:build darwin
|
||||
// +build darwin
|
||||
|
||||
package fileutil
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// Fsync on HFS/OSX flushes the data on to the physical drive but the drive
|
||||
// may not write it to the persistent media for quite sometime and it may be
|
||||
// written in out-of-order sequence. Using F_FULLFSYNC ensures that the
|
||||
// physical drive's buffer will also get flushed to the media.
|
||||
func Fsync(f *os.File) error {
|
||||
_, err := unix.FcntlInt(f.Fd(), unix.F_FULLFSYNC, 0)
|
||||
return err
|
||||
}
|
||||
|
||||
// Fdatasync on darwin platform invokes fcntl(F_FULLFSYNC) for actual persistence
|
||||
// on physical drive media.
|
||||
func Fdatasync(f *os.File) error {
|
||||
return Fsync(f)
|
||||
}
|
||||
@@ -1,35 +0,0 @@
|
||||
// Copyright 2016 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//go:build linux
|
||||
// +build linux
|
||||
|
||||
package fileutil
|
||||
|
||||
import (
|
||||
"os"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// Fsync is a wrapper around file.Sync(). Special handling is needed on darwin platform.
|
||||
func Fsync(f *os.File) error {
|
||||
return f.Sync()
|
||||
}
|
||||
|
||||
// Fdatasync is similar to fsync(), but does not flush modified metadata
|
||||
// unless that metadata is needed in order to allow a subsequent data retrieval
|
||||
// to be correctly handled.
|
||||
func Fdatasync(f *os.File) error {
|
||||
return syscall.Fdatasync(int(f.Fd()))
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
// Copyright 2018 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Package logutil includes utilities to facilitate logging.
|
||||
package logutil
|
||||
@@ -1,30 +0,0 @@
|
||||
// Copyright 2019 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package logutil
|
||||
|
||||
import (
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
var DefaultLogLevel = "info"
|
||||
|
||||
// ConvertToZapLevel converts log level string to zapcore.Level.
|
||||
func ConvertToZapLevel(lvl string) zapcore.Level {
|
||||
var level zapcore.Level
|
||||
if err := level.Set(lvl); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return level
|
||||
}
|
||||
@@ -1,91 +0,0 @@
|
||||
// Copyright 2019 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package logutil
|
||||
|
||||
import (
|
||||
"sort"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
// DefaultZapLoggerConfig defines default zap logger configuration.
|
||||
var DefaultZapLoggerConfig = zap.Config{
|
||||
Level: zap.NewAtomicLevelAt(ConvertToZapLevel(DefaultLogLevel)),
|
||||
|
||||
Development: false,
|
||||
Sampling: &zap.SamplingConfig{
|
||||
Initial: 100,
|
||||
Thereafter: 100,
|
||||
},
|
||||
|
||||
Encoding: "json",
|
||||
|
||||
// copied from "zap.NewProductionEncoderConfig" with some updates
|
||||
EncoderConfig: zapcore.EncoderConfig{
|
||||
TimeKey: "ts",
|
||||
LevelKey: "level",
|
||||
NameKey: "logger",
|
||||
CallerKey: "caller",
|
||||
MessageKey: "msg",
|
||||
StacktraceKey: "stacktrace",
|
||||
LineEnding: zapcore.DefaultLineEnding,
|
||||
EncodeLevel: zapcore.LowercaseLevelEncoder,
|
||||
EncodeTime: zapcore.ISO8601TimeEncoder,
|
||||
EncodeDuration: zapcore.StringDurationEncoder,
|
||||
EncodeCaller: zapcore.ShortCallerEncoder,
|
||||
},
|
||||
|
||||
// Use "/dev/null" to discard all
|
||||
OutputPaths: []string{"stderr"},
|
||||
ErrorOutputPaths: []string{"stderr"},
|
||||
}
|
||||
|
||||
// MergeOutputPaths merges logging output paths, resolving conflicts.
|
||||
func MergeOutputPaths(cfg zap.Config) zap.Config {
|
||||
outputs := make(map[string]struct{})
|
||||
for _, v := range cfg.OutputPaths {
|
||||
outputs[v] = struct{}{}
|
||||
}
|
||||
outputSlice := make([]string, 0)
|
||||
if _, ok := outputs["/dev/null"]; ok {
|
||||
// "/dev/null" to discard all
|
||||
outputSlice = []string{"/dev/null"}
|
||||
} else {
|
||||
for k := range outputs {
|
||||
outputSlice = append(outputSlice, k)
|
||||
}
|
||||
}
|
||||
cfg.OutputPaths = outputSlice
|
||||
sort.Strings(cfg.OutputPaths)
|
||||
|
||||
errOutputs := make(map[string]struct{})
|
||||
for _, v := range cfg.ErrorOutputPaths {
|
||||
errOutputs[v] = struct{}{}
|
||||
}
|
||||
errOutputSlice := make([]string, 0)
|
||||
if _, ok := errOutputs["/dev/null"]; ok {
|
||||
// "/dev/null" to discard all
|
||||
errOutputSlice = []string{"/dev/null"}
|
||||
} else {
|
||||
for k := range errOutputs {
|
||||
errOutputSlice = append(errOutputSlice, k)
|
||||
}
|
||||
}
|
||||
cfg.ErrorOutputPaths = errOutputSlice
|
||||
sort.Strings(cfg.ErrorOutputPaths)
|
||||
|
||||
return cfg
|
||||
}
|
||||
@@ -1,111 +0,0 @@
|
||||
// Copyright 2018 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package logutil
|
||||
|
||||
import (
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
"google.golang.org/grpc/grpclog"
|
||||
)
|
||||
|
||||
// NewGRPCLoggerV2 converts "*zap.Logger" to "grpclog.LoggerV2".
|
||||
// It discards all INFO level logging in gRPC, if debug level
|
||||
// is not enabled in "*zap.Logger".
|
||||
func NewGRPCLoggerV2(lcfg zap.Config) (grpclog.LoggerV2, error) {
|
||||
lg, err := lcfg.Build(zap.AddCallerSkip(1)) // to annotate caller outside of "logutil"
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &zapGRPCLogger{lg: lg, sugar: lg.Sugar()}, nil
|
||||
}
|
||||
|
||||
// NewGRPCLoggerV2FromZapCore creates "grpclog.LoggerV2" from "zap.Core"
|
||||
// and "zapcore.WriteSyncer". It discards all INFO level logging in gRPC,
|
||||
// if debug level is not enabled in "*zap.Logger".
|
||||
func NewGRPCLoggerV2FromZapCore(cr zapcore.Core, syncer zapcore.WriteSyncer) grpclog.LoggerV2 {
|
||||
// "AddCallerSkip" to annotate caller outside of "logutil"
|
||||
lg := zap.New(cr, zap.AddCaller(), zap.AddCallerSkip(1), zap.ErrorOutput(syncer)).Named("grpc")
|
||||
return &zapGRPCLogger{lg: lg, sugar: lg.Sugar()}
|
||||
}
|
||||
|
||||
type zapGRPCLogger struct {
|
||||
lg *zap.Logger
|
||||
sugar *zap.SugaredLogger
|
||||
}
|
||||
|
||||
func (zl *zapGRPCLogger) Info(args ...interface{}) {
|
||||
if !zl.lg.Core().Enabled(zapcore.DebugLevel) {
|
||||
return
|
||||
}
|
||||
zl.sugar.Info(args...)
|
||||
}
|
||||
|
||||
func (zl *zapGRPCLogger) Infoln(args ...interface{}) {
|
||||
if !zl.lg.Core().Enabled(zapcore.DebugLevel) {
|
||||
return
|
||||
}
|
||||
zl.sugar.Info(args...)
|
||||
}
|
||||
|
||||
func (zl *zapGRPCLogger) Infof(format string, args ...interface{}) {
|
||||
if !zl.lg.Core().Enabled(zapcore.DebugLevel) {
|
||||
return
|
||||
}
|
||||
zl.sugar.Infof(format, args...)
|
||||
}
|
||||
|
||||
func (zl *zapGRPCLogger) Warning(args ...interface{}) {
|
||||
zl.sugar.Warn(args...)
|
||||
}
|
||||
|
||||
func (zl *zapGRPCLogger) Warningln(args ...interface{}) {
|
||||
zl.sugar.Warn(args...)
|
||||
}
|
||||
|
||||
func (zl *zapGRPCLogger) Warningf(format string, args ...interface{}) {
|
||||
zl.sugar.Warnf(format, args...)
|
||||
}
|
||||
|
||||
func (zl *zapGRPCLogger) Error(args ...interface{}) {
|
||||
zl.sugar.Error(args...)
|
||||
}
|
||||
|
||||
func (zl *zapGRPCLogger) Errorln(args ...interface{}) {
|
||||
zl.sugar.Error(args...)
|
||||
}
|
||||
|
||||
func (zl *zapGRPCLogger) Errorf(format string, args ...interface{}) {
|
||||
zl.sugar.Errorf(format, args...)
|
||||
}
|
||||
|
||||
func (zl *zapGRPCLogger) Fatal(args ...interface{}) {
|
||||
zl.sugar.Fatal(args...)
|
||||
}
|
||||
|
||||
func (zl *zapGRPCLogger) Fatalln(args ...interface{}) {
|
||||
zl.sugar.Fatal(args...)
|
||||
}
|
||||
|
||||
func (zl *zapGRPCLogger) Fatalf(format string, args ...interface{}) {
|
||||
zl.sugar.Fatalf(format, args...)
|
||||
}
|
||||
|
||||
func (zl *zapGRPCLogger) V(l int) bool {
|
||||
// infoLog == 0
|
||||
if l <= 0 { // debug level, then we ignore info level in gRPC
|
||||
return !zl.lg.Core().Enabled(zapcore.DebugLevel)
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -1,91 +0,0 @@
|
||||
// Copyright 2018 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package logutil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
func TestNewGRPCLoggerV2(t *testing.T) {
|
||||
logPath := filepath.Join(os.TempDir(), fmt.Sprintf("test-log-%d", time.Now().UnixNano()))
|
||||
defer os.RemoveAll(logPath)
|
||||
|
||||
lcfg := zap.Config{
|
||||
Level: zap.NewAtomicLevelAt(zap.InfoLevel),
|
||||
Development: false,
|
||||
Sampling: &zap.SamplingConfig{
|
||||
Initial: 100,
|
||||
Thereafter: 100,
|
||||
},
|
||||
Encoding: "json",
|
||||
EncoderConfig: DefaultZapLoggerConfig.EncoderConfig,
|
||||
OutputPaths: []string{logPath},
|
||||
ErrorOutputPaths: []string{logPath},
|
||||
}
|
||||
gl, err := NewGRPCLoggerV2(lcfg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// debug level is not enabled,
|
||||
// so info level gRPC-side logging is discarded
|
||||
gl.Info("etcd-logutil-1")
|
||||
data, err := ioutil.ReadFile(logPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if bytes.Contains(data, []byte("etcd-logutil-1")) {
|
||||
t.Fatalf("unexpected line %q", string(data))
|
||||
}
|
||||
|
||||
gl.Warning("etcd-logutil-2")
|
||||
data, err = ioutil.ReadFile(logPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Contains(data, []byte("etcd-logutil-2")) {
|
||||
t.Fatalf("can't find data in log %q", string(data))
|
||||
}
|
||||
if !bytes.Contains(data, []byte("logutil/zap_grpc_test.go:")) {
|
||||
t.Fatalf("unexpected caller; %q", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewGRPCLoggerV2FromZapCore(t *testing.T) {
|
||||
buf := bytes.NewBuffer(nil)
|
||||
syncer := zapcore.AddSync(buf)
|
||||
cr := zapcore.NewCore(
|
||||
zapcore.NewJSONEncoder(zap.NewProductionEncoderConfig()),
|
||||
syncer,
|
||||
zap.NewAtomicLevelAt(zap.InfoLevel),
|
||||
)
|
||||
|
||||
lg := NewGRPCLoggerV2FromZapCore(cr, syncer)
|
||||
lg.Warning("TestNewGRPCLoggerV2FromZapCore")
|
||||
txt := buf.String()
|
||||
if !strings.Contains(txt, "TestNewGRPCLoggerV2FromZapCore") {
|
||||
t.Fatalf("unexpected log %q", txt)
|
||||
}
|
||||
}
|
||||
@@ -1,93 +0,0 @@
|
||||
// Copyright 2018 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//go:build !windows
|
||||
// +build !windows
|
||||
|
||||
package logutil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"go.etcd.io/etcd/pkg/v3/systemd"
|
||||
|
||||
"github.com/coreos/go-systemd/v22/journal"
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
// NewJournalWriter wraps "io.Writer" to redirect log output
|
||||
// to the local systemd journal. If journald send fails, it fails
|
||||
// back to writing to the original writer.
|
||||
// The decode overhead is only <30µs per write.
|
||||
// Reference: https://github.com/coreos/pkg/blob/master/capnslog/journald_formatter.go
|
||||
func NewJournalWriter(wr io.Writer) (io.Writer, error) {
|
||||
return &journalWriter{Writer: wr}, systemd.DialJournal()
|
||||
}
|
||||
|
||||
type journalWriter struct {
|
||||
io.Writer
|
||||
}
|
||||
|
||||
// WARN: assume that etcd uses default field names in zap encoder config
|
||||
// make sure to keep this up-to-date!
|
||||
type logLine struct {
|
||||
Level string `json:"level"`
|
||||
Caller string `json:"caller"`
|
||||
}
|
||||
|
||||
func (w *journalWriter) Write(p []byte) (int, error) {
|
||||
line := &logLine{}
|
||||
if err := json.NewDecoder(bytes.NewReader(p)).Decode(line); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var pri journal.Priority
|
||||
switch line.Level {
|
||||
case zapcore.DebugLevel.String():
|
||||
pri = journal.PriDebug
|
||||
case zapcore.InfoLevel.String():
|
||||
pri = journal.PriInfo
|
||||
|
||||
case zapcore.WarnLevel.String():
|
||||
pri = journal.PriWarning
|
||||
case zapcore.ErrorLevel.String():
|
||||
pri = journal.PriErr
|
||||
|
||||
case zapcore.DPanicLevel.String():
|
||||
pri = journal.PriCrit
|
||||
case zapcore.PanicLevel.String():
|
||||
pri = journal.PriCrit
|
||||
case zapcore.FatalLevel.String():
|
||||
pri = journal.PriCrit
|
||||
|
||||
default:
|
||||
panic(fmt.Errorf("unknown log level: %q", line.Level))
|
||||
}
|
||||
|
||||
err := journal.Send(string(p), pri, map[string]string{
|
||||
"PACKAGE": filepath.Dir(line.Caller),
|
||||
"SYSLOG_IDENTIFIER": filepath.Base(os.Args[0]),
|
||||
})
|
||||
if err != nil {
|
||||
// "journal" also falls back to stderr
|
||||
// "fmt.Fprintln(os.Stderr, s)"
|
||||
return w.Writer.Write(p)
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
@@ -1,51 +0,0 @@
|
||||
// Copyright 2018 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//go:build !windows
|
||||
// +build !windows
|
||||
|
||||
package logutil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
func TestNewJournalWriter(t *testing.T) {
|
||||
buf := bytes.NewBuffer(nil)
|
||||
jw, err := NewJournalWriter(buf)
|
||||
if err != nil {
|
||||
t.Skip(err)
|
||||
}
|
||||
|
||||
syncer := zapcore.AddSync(jw)
|
||||
|
||||
cr := zapcore.NewCore(
|
||||
zapcore.NewJSONEncoder(DefaultZapLoggerConfig.EncoderConfig),
|
||||
syncer,
|
||||
zap.NewAtomicLevelAt(zap.InfoLevel),
|
||||
)
|
||||
|
||||
lg := zap.New(cr, zap.AddCaller(), zap.ErrorOutput(syncer))
|
||||
defer lg.Sync()
|
||||
|
||||
lg.Info("TestNewJournalWriter")
|
||||
if buf.String() == "" {
|
||||
// check with "journalctl -f"
|
||||
t.Log("sent logs successfully to journald")
|
||||
}
|
||||
}
|
||||
@@ -1,31 +0,0 @@
|
||||
// Copyright 2009 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package pathutil implements utility functions for handling slash-separated
|
||||
// paths.
|
||||
package pathutil
|
||||
|
||||
import "path"
|
||||
|
||||
// CanonicalURLPath returns the canonical url path for p, which follows the rules:
|
||||
// 1. the path always starts with "/"
|
||||
// 2. replace multiple slashes with a single slash
|
||||
// 3. replace each '.' '..' path name element with equivalent one
|
||||
// 4. keep the trailing slash
|
||||
// The function is borrowed from stdlib http.cleanPath in server.go.
|
||||
func CanonicalURLPath(p string) string {
|
||||
if p == "" {
|
||||
return "/"
|
||||
}
|
||||
if p[0] != '/' {
|
||||
p = "/" + p
|
||||
}
|
||||
np := path.Clean(p)
|
||||
// path.Clean removes trailing slash except for root,
|
||||
// put the trailing slash back if necessary.
|
||||
if p[len(p)-1] == '/' && np != "/" {
|
||||
np += "/"
|
||||
}
|
||||
return np
|
||||
}
|
||||
@@ -1,38 +0,0 @@
|
||||
// Copyright 2015 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pathutil
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestCanonicalURLPath(t *testing.T) {
|
||||
tests := []struct {
|
||||
p string
|
||||
wp string
|
||||
}{
|
||||
{"/a", "/a"},
|
||||
{"", "/"},
|
||||
{"a", "/a"},
|
||||
{"//a", "/a"},
|
||||
{"/a/.", "/a"},
|
||||
{"/a/..", "/"},
|
||||
{"/a/", "/a/"},
|
||||
{"/a//", "/a/"},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
if g := CanonicalURLPath(tt.p); g != tt.wp {
|
||||
t.Errorf("#%d: canonical path = %s, want %s", i, g, tt.wp)
|
||||
}
|
||||
}
|
||||
}
|
||||
142
pkg/srv/srv.go
142
pkg/srv/srv.go
@@ -1,142 +0,0 @@
|
||||
// Copyright 2015 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Package srv looks up DNS SRV records.
|
||||
package srv
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"go.etcd.io/etcd/pkg/v3/types"
|
||||
)
|
||||
|
||||
var (
|
||||
// indirection for testing
|
||||
lookupSRV = net.LookupSRV // net.DefaultResolver.LookupSRV when ctxs don't conflict
|
||||
resolveTCPAddr = net.ResolveTCPAddr
|
||||
)
|
||||
|
||||
// GetCluster gets the cluster information via DNS discovery.
|
||||
// Also sees each entry as a separate instance.
|
||||
func GetCluster(serviceScheme, service, name, dns string, apurls types.URLs) ([]string, error) {
|
||||
tempName := int(0)
|
||||
tcp2ap := make(map[string]url.URL)
|
||||
|
||||
// First, resolve the apurls
|
||||
for _, url := range apurls {
|
||||
tcpAddr, err := resolveTCPAddr("tcp", url.Host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tcp2ap[tcpAddr.String()] = url
|
||||
}
|
||||
|
||||
stringParts := []string{}
|
||||
updateNodeMap := func(service, scheme string) error {
|
||||
_, addrs, err := lookupSRV(service, "tcp", dns)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, srv := range addrs {
|
||||
port := fmt.Sprintf("%d", srv.Port)
|
||||
host := net.JoinHostPort(srv.Target, port)
|
||||
tcpAddr, terr := resolveTCPAddr("tcp", host)
|
||||
if terr != nil {
|
||||
err = terr
|
||||
continue
|
||||
}
|
||||
n := ""
|
||||
url, ok := tcp2ap[tcpAddr.String()]
|
||||
if ok {
|
||||
n = name
|
||||
}
|
||||
if n == "" {
|
||||
n = fmt.Sprintf("%d", tempName)
|
||||
tempName++
|
||||
}
|
||||
// SRV records have a trailing dot but URL shouldn't.
|
||||
shortHost := strings.TrimSuffix(srv.Target, ".")
|
||||
urlHost := net.JoinHostPort(shortHost, port)
|
||||
if ok && url.Scheme != scheme {
|
||||
err = fmt.Errorf("bootstrap at %s from DNS for %s has scheme mismatch with expected peer %s", scheme+"://"+urlHost, service, url.String())
|
||||
} else {
|
||||
stringParts = append(stringParts, fmt.Sprintf("%s=%s://%s", n, scheme, urlHost))
|
||||
}
|
||||
}
|
||||
if len(stringParts) == 0 {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
err := updateNodeMap(service, serviceScheme)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error querying DNS SRV records for _%s %s", service, err)
|
||||
}
|
||||
return stringParts, nil
|
||||
}
|
||||
|
||||
type SRVClients struct {
|
||||
Endpoints []string
|
||||
SRVs []*net.SRV
|
||||
}
|
||||
|
||||
// GetClient looks up the client endpoints for a service and domain.
|
||||
func GetClient(service, domain string, serviceName string) (*SRVClients, error) {
|
||||
var urls []*url.URL
|
||||
var srvs []*net.SRV
|
||||
|
||||
updateURLs := func(service, scheme string) error {
|
||||
_, addrs, err := lookupSRV(service, "tcp", domain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, srv := range addrs {
|
||||
urls = append(urls, &url.URL{
|
||||
Scheme: scheme,
|
||||
Host: net.JoinHostPort(srv.Target, fmt.Sprintf("%d", srv.Port)),
|
||||
})
|
||||
}
|
||||
srvs = append(srvs, addrs...)
|
||||
return nil
|
||||
}
|
||||
|
||||
errHTTPS := updateURLs(GetSRVService(service, serviceName, "https"), "https")
|
||||
errHTTP := updateURLs(GetSRVService(service, serviceName, "http"), "http")
|
||||
|
||||
if errHTTPS != nil && errHTTP != nil {
|
||||
return nil, fmt.Errorf("dns lookup errors: %s and %s", errHTTPS, errHTTP)
|
||||
}
|
||||
|
||||
endpoints := make([]string, len(urls))
|
||||
for i := range urls {
|
||||
endpoints[i] = urls[i].String()
|
||||
}
|
||||
return &SRVClients{Endpoints: endpoints, SRVs: srvs}, nil
|
||||
}
|
||||
|
||||
// GetSRVService generates a SRV service including an optional suffix.
|
||||
func GetSRVService(service, serviceName string, scheme string) (SRVService string) {
|
||||
if scheme == "https" {
|
||||
service = fmt.Sprintf("%s-ssl", service)
|
||||
}
|
||||
|
||||
if serviceName != "" {
|
||||
return fmt.Sprintf("%s-%s", service, serviceName)
|
||||
}
|
||||
return service
|
||||
}
|
||||
@@ -1,308 +0,0 @@
|
||||
// Copyright 2015 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package srv
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"go.etcd.io/etcd/pkg/v3/testutil"
|
||||
)
|
||||
|
||||
func notFoundErr(service, proto, domain string) error {
|
||||
name := fmt.Sprintf("_%s._%s.%s", service, proto, domain)
|
||||
return &net.DNSError{Err: "no such host", Name: name, Server: "10.0.0.53:53", IsTimeout: false, IsTemporary: false, IsNotFound: true}
|
||||
}
|
||||
|
||||
func TestSRVGetCluster(t *testing.T) {
|
||||
defer func() {
|
||||
lookupSRV = net.LookupSRV
|
||||
resolveTCPAddr = net.ResolveTCPAddr
|
||||
}()
|
||||
|
||||
hasErr := func(err error) bool {
|
||||
return err != nil
|
||||
}
|
||||
|
||||
name := "dnsClusterTest"
|
||||
dns := map[string]string{
|
||||
"1.example.com.:2480": "10.0.0.1:2480",
|
||||
"2.example.com.:2480": "10.0.0.2:2480",
|
||||
"3.example.com.:2480": "10.0.0.3:2480",
|
||||
"4.example.com.:2380": "10.0.0.3:2380",
|
||||
}
|
||||
srvAll := []*net.SRV{
|
||||
{Target: "1.example.com.", Port: 2480},
|
||||
{Target: "2.example.com.", Port: 2480},
|
||||
{Target: "3.example.com.", Port: 2480},
|
||||
}
|
||||
srvNone := []*net.SRV{}
|
||||
|
||||
tests := []struct {
|
||||
service string
|
||||
scheme string
|
||||
withSSL []*net.SRV
|
||||
withoutSSL []*net.SRV
|
||||
urls []string
|
||||
expected string
|
||||
werr bool
|
||||
}{
|
||||
{
|
||||
"etcd-server-ssl",
|
||||
"https",
|
||||
srvNone,
|
||||
srvNone,
|
||||
nil,
|
||||
"",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"etcd-server-ssl",
|
||||
"https",
|
||||
srvAll,
|
||||
srvNone,
|
||||
nil,
|
||||
"0=https://1.example.com:2480,1=https://2.example.com:2480,2=https://3.example.com:2480",
|
||||
false,
|
||||
},
|
||||
{
|
||||
"etcd-server",
|
||||
"http",
|
||||
srvNone,
|
||||
srvAll,
|
||||
nil,
|
||||
"0=http://1.example.com:2480,1=http://2.example.com:2480,2=http://3.example.com:2480",
|
||||
false,
|
||||
},
|
||||
{
|
||||
"etcd-server-ssl",
|
||||
"https",
|
||||
srvAll,
|
||||
srvNone,
|
||||
[]string{"https://10.0.0.1:2480"},
|
||||
"dnsClusterTest=https://1.example.com:2480,0=https://2.example.com:2480,1=https://3.example.com:2480",
|
||||
false,
|
||||
},
|
||||
// matching local member with resolved addr and return unresolved hostnames
|
||||
{
|
||||
"etcd-server-ssl",
|
||||
"https",
|
||||
srvAll,
|
||||
srvNone,
|
||||
[]string{"https://10.0.0.1:2480"},
|
||||
"dnsClusterTest=https://1.example.com:2480,0=https://2.example.com:2480,1=https://3.example.com:2480",
|
||||
false,
|
||||
},
|
||||
// reject if apurls are TLS but SRV is only http
|
||||
{
|
||||
"etcd-server",
|
||||
"http",
|
||||
srvNone,
|
||||
srvAll,
|
||||
[]string{"https://10.0.0.1:2480"},
|
||||
"0=http://2.example.com:2480,1=http://3.example.com:2480",
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
resolveTCPAddr = func(network, addr string) (*net.TCPAddr, error) {
|
||||
if strings.Contains(addr, "10.0.0.") {
|
||||
// accept IP addresses when resolving apurls
|
||||
return net.ResolveTCPAddr(network, addr)
|
||||
}
|
||||
if dns[addr] == "" {
|
||||
return nil, errors.New("missing dns record")
|
||||
}
|
||||
return net.ResolveTCPAddr(network, dns[addr])
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
lookupSRV = func(service string, proto string, domain string) (string, []*net.SRV, error) {
|
||||
if service == "etcd-server-ssl" {
|
||||
if len(tt.withSSL) > 0 {
|
||||
return "", tt.withSSL, nil
|
||||
}
|
||||
return "", nil, notFoundErr(service, proto, domain)
|
||||
}
|
||||
if service == "etcd-server" {
|
||||
if len(tt.withoutSSL) > 0 {
|
||||
return "", tt.withoutSSL, nil
|
||||
}
|
||||
return "", nil, notFoundErr(service, proto, domain)
|
||||
}
|
||||
return "", nil, errors.New("unknown service in mock")
|
||||
}
|
||||
|
||||
urls := testutil.MustNewURLs(t, tt.urls)
|
||||
str, err := GetCluster(tt.scheme, tt.service, name, "example.com", urls)
|
||||
|
||||
if hasErr(err) != tt.werr {
|
||||
t.Fatalf("%d: err = %#v, want = %#v", i, err, tt.werr)
|
||||
}
|
||||
if strings.Join(str, ",") != tt.expected {
|
||||
t.Errorf("#%d: cluster = %s, want %s", i, str, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSRVDiscover(t *testing.T) {
|
||||
defer func() { lookupSRV = net.LookupSRV }()
|
||||
|
||||
hasErr := func(err error) bool {
|
||||
return err != nil
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
withSSL []*net.SRV
|
||||
withoutSSL []*net.SRV
|
||||
expected []string
|
||||
werr bool
|
||||
}{
|
||||
{
|
||||
[]*net.SRV{},
|
||||
[]*net.SRV{},
|
||||
[]string{},
|
||||
true,
|
||||
},
|
||||
{
|
||||
[]*net.SRV{},
|
||||
[]*net.SRV{
|
||||
{Target: "10.0.0.1", Port: 2480},
|
||||
{Target: "10.0.0.2", Port: 2480},
|
||||
{Target: "10.0.0.3", Port: 2480},
|
||||
},
|
||||
[]string{"http://10.0.0.1:2480", "http://10.0.0.2:2480", "http://10.0.0.3:2480"},
|
||||
false,
|
||||
},
|
||||
{
|
||||
[]*net.SRV{
|
||||
{Target: "10.0.0.1", Port: 2480},
|
||||
{Target: "10.0.0.2", Port: 2480},
|
||||
{Target: "10.0.0.3", Port: 2480},
|
||||
},
|
||||
[]*net.SRV{},
|
||||
[]string{"https://10.0.0.1:2480", "https://10.0.0.2:2480", "https://10.0.0.3:2480"},
|
||||
false,
|
||||
},
|
||||
{
|
||||
[]*net.SRV{
|
||||
{Target: "10.0.0.1", Port: 2480},
|
||||
{Target: "10.0.0.2", Port: 2480},
|
||||
{Target: "10.0.0.3", Port: 2480},
|
||||
},
|
||||
[]*net.SRV{
|
||||
{Target: "10.0.0.1", Port: 7001},
|
||||
},
|
||||
[]string{"https://10.0.0.1:2480", "https://10.0.0.2:2480", "https://10.0.0.3:2480", "http://10.0.0.1:7001"},
|
||||
false,
|
||||
},
|
||||
{
|
||||
[]*net.SRV{
|
||||
{Target: "10.0.0.1", Port: 2480},
|
||||
{Target: "10.0.0.2", Port: 2480},
|
||||
{Target: "10.0.0.3", Port: 2480},
|
||||
},
|
||||
[]*net.SRV{
|
||||
{Target: "10.0.0.1", Port: 7001},
|
||||
},
|
||||
[]string{"https://10.0.0.1:2480", "https://10.0.0.2:2480", "https://10.0.0.3:2480", "http://10.0.0.1:7001"},
|
||||
false,
|
||||
},
|
||||
{
|
||||
[]*net.SRV{
|
||||
{Target: "a.example.com", Port: 2480},
|
||||
{Target: "b.example.com", Port: 2480},
|
||||
{Target: "c.example.com", Port: 2480},
|
||||
},
|
||||
[]*net.SRV{},
|
||||
[]string{"https://a.example.com:2480", "https://b.example.com:2480", "https://c.example.com:2480"},
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
lookupSRV = func(service string, proto string, domain string) (string, []*net.SRV, error) {
|
||||
if service == "etcd-client-ssl" {
|
||||
if len(tt.withSSL) > 0 {
|
||||
return "", tt.withSSL, nil
|
||||
}
|
||||
return "", nil, notFoundErr(service, proto, domain)
|
||||
}
|
||||
if service == "etcd-client" {
|
||||
if len(tt.withoutSSL) > 0 {
|
||||
return "", tt.withoutSSL, nil
|
||||
}
|
||||
return "", nil, notFoundErr(service, proto, domain)
|
||||
}
|
||||
return "", nil, errors.New("unknown service in mock")
|
||||
}
|
||||
|
||||
srvs, err := GetClient("etcd-client", "example.com", "")
|
||||
|
||||
if hasErr(err) != tt.werr {
|
||||
t.Fatalf("%d: err = %#v, want = %#v", i, err, tt.werr)
|
||||
}
|
||||
if srvs == nil {
|
||||
if len(tt.expected) > 0 {
|
||||
t.Errorf("#%d: srvs = nil, want non-nil", i)
|
||||
}
|
||||
} else {
|
||||
if !reflect.DeepEqual(srvs.Endpoints, tt.expected) {
|
||||
t.Errorf("#%d: endpoints = %v, want = %v", i, srvs.Endpoints, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSRVService(t *testing.T) {
|
||||
tests := []struct {
|
||||
scheme string
|
||||
serviceName string
|
||||
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
"https",
|
||||
"",
|
||||
"etcd-client-ssl",
|
||||
},
|
||||
{
|
||||
"http",
|
||||
"",
|
||||
"etcd-client",
|
||||
},
|
||||
{
|
||||
"https",
|
||||
"foo",
|
||||
"etcd-client-ssl-foo",
|
||||
},
|
||||
{
|
||||
"http",
|
||||
"bar",
|
||||
"etcd-client-bar",
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
service := GetSRVService("etcd-client", tt.serviceName, tt.scheme)
|
||||
if strings.Compare(service, tt.expected) != 0 {
|
||||
t.Errorf("#%d: service = %s, want %s", i, service, tt.expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
// Copyright 2018 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Package systemd provides utility functions for systemd.
|
||||
package systemd
|
||||
@@ -1,29 +0,0 @@
|
||||
// Copyright 2018 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package systemd
|
||||
|
||||
import "net"
|
||||
|
||||
// DialJournal returns no error if the process can dial journal socket.
|
||||
// Returns an error if dial failed, which indicates journald is not available
|
||||
// (e.g. run embedded etcd as docker daemon).
|
||||
// Reference: https://github.com/coreos/go-systemd/blob/master/journal/journal.go.
|
||||
func DialJournal() error {
|
||||
conn, err := net.Dial("unixgram", "/run/systemd/journal/socket")
|
||||
if conn != nil {
|
||||
defer conn.Close()
|
||||
}
|
||||
return err
|
||||
}
|
||||
@@ -1,67 +0,0 @@
|
||||
// Copyright 2017 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package testutil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func AssertEqual(t *testing.T, e, a interface{}, msg ...string) {
|
||||
t.Helper()
|
||||
if (e == nil || a == nil) && (isNil(e) && isNil(a)) {
|
||||
return
|
||||
}
|
||||
if reflect.DeepEqual(e, a) {
|
||||
return
|
||||
}
|
||||
s := ""
|
||||
if len(msg) > 1 {
|
||||
s = msg[0] + ": "
|
||||
}
|
||||
s = fmt.Sprintf("%sexpected %+v, got %+v", s, e, a)
|
||||
FatalStack(t, s)
|
||||
}
|
||||
|
||||
func AssertNil(t *testing.T, v interface{}) {
|
||||
t.Helper()
|
||||
AssertEqual(t, nil, v)
|
||||
}
|
||||
|
||||
func AssertNotNil(t *testing.T, v interface{}) {
|
||||
t.Helper()
|
||||
if v == nil {
|
||||
t.Fatalf("expected non-nil, got %+v", v)
|
||||
}
|
||||
}
|
||||
|
||||
func AssertTrue(t *testing.T, v bool, msg ...string) {
|
||||
t.Helper()
|
||||
AssertEqual(t, true, v, msg...)
|
||||
}
|
||||
|
||||
func AssertFalse(t *testing.T, v bool, msg ...string) {
|
||||
t.Helper()
|
||||
AssertEqual(t, false, v, msg...)
|
||||
}
|
||||
|
||||
func isNil(v interface{}) bool {
|
||||
if v == nil {
|
||||
return true
|
||||
}
|
||||
rv := reflect.ValueOf(v)
|
||||
return rv.Kind() != reflect.Struct && rv.IsNil()
|
||||
}
|
||||
@@ -1,176 +0,0 @@
|
||||
// Copyright 2013 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package testutil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
/*
|
||||
CheckLeakedGoroutine verifies tests do not leave any leaky
|
||||
goroutines. It returns true when there are goroutines still
|
||||
running(leaking) after all tests.
|
||||
|
||||
import "go.etcd.io/etcd/pkg/v3/testutil"
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
testutil.MustTestMainWithLeakDetection(m)
|
||||
}
|
||||
|
||||
func TestSample(t *testing.T) {
|
||||
BeforeTest(t)
|
||||
...
|
||||
}
|
||||
|
||||
*/
|
||||
func CheckLeakedGoroutine() bool {
|
||||
gs := interestingGoroutines()
|
||||
if len(gs) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
stackCount := make(map[string]int)
|
||||
re := regexp.MustCompile(`\(0[0-9a-fx, ]*\)`)
|
||||
for _, g := range gs {
|
||||
// strip out pointer arguments in first function of stack dump
|
||||
normalized := string(re.ReplaceAll([]byte(g), []byte("(...)")))
|
||||
stackCount[normalized]++
|
||||
}
|
||||
|
||||
fmt.Fprintf(os.Stderr, "Unexpected goroutines running after all test(s).\n")
|
||||
for stack, count := range stackCount {
|
||||
fmt.Fprintf(os.Stderr, "%d instances of:\n%s\n", count, stack)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// CheckAfterTest returns an error if AfterTest would fail with an error.
|
||||
// Waits for go-routines shutdown for 'd'.
|
||||
func CheckAfterTest(d time.Duration) error {
|
||||
http.DefaultTransport.(*http.Transport).CloseIdleConnections()
|
||||
var bad string
|
||||
// Presence of these goroutines causes immediate test failure.
|
||||
badSubstring := map[string]string{
|
||||
").writeLoop(": "a Transport",
|
||||
"created by net/http/httptest.(*Server).Start": "an httptest.Server",
|
||||
"timeoutHandler": "a TimeoutHandler",
|
||||
"net.(*netFD).connect(": "a timing out dial",
|
||||
").noteClientGone(": "a closenotifier sender",
|
||||
").readLoop(": "a Transport",
|
||||
".grpc": "a gRPC resource",
|
||||
").sendCloseSubstream(": "a stream closing routine",
|
||||
}
|
||||
|
||||
var stacks string
|
||||
begin := time.Now()
|
||||
for time.Since(begin) < d {
|
||||
bad = ""
|
||||
goroutines := interestingGoroutines()
|
||||
if len(goroutines) == 0 {
|
||||
return nil
|
||||
}
|
||||
stacks = strings.Join(goroutines, "\n\n")
|
||||
|
||||
for substr, what := range badSubstring {
|
||||
if strings.Contains(stacks, substr) {
|
||||
bad = what
|
||||
}
|
||||
}
|
||||
// Undesired goroutines found, but goroutines might just still be
|
||||
// shutting down, so give it some time.
|
||||
runtime.Gosched()
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
}
|
||||
return fmt.Errorf("appears to have leaked %s:\n%s", bad, stacks)
|
||||
}
|
||||
|
||||
// BeforeTest is a convenient way to register before-and-after code to a test.
|
||||
// If you execute BeforeTest, you don't need to explicitly register AfterTest.
|
||||
func BeforeTest(t TB) {
|
||||
if err := CheckAfterTest(10 * time.Millisecond); err != nil {
|
||||
t.Skip("Found leaked goroutined BEFORE test", err)
|
||||
return
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
AfterTest(t)
|
||||
})
|
||||
}
|
||||
|
||||
// AfterTest is meant to run in a defer that executes after a test completes.
|
||||
// It will detect common goroutine leaks, retrying in case there are goroutines
|
||||
// not synchronously torn down, and fail the test if any goroutines are stuck.
|
||||
func AfterTest(t TB) {
|
||||
if err := CheckAfterTest(1 * time.Second); err != nil {
|
||||
t.Errorf("Test %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func interestingGoroutines() (gs []string) {
|
||||
buf := make([]byte, 2<<20)
|
||||
buf = buf[:runtime.Stack(buf, true)]
|
||||
for _, g := range strings.Split(string(buf), "\n\n") {
|
||||
sl := strings.SplitN(g, "\n", 2)
|
||||
if len(sl) != 2 {
|
||||
continue
|
||||
}
|
||||
stack := strings.TrimSpace(sl[1])
|
||||
if stack == "" ||
|
||||
strings.Contains(stack, "sync.(*WaitGroup).Done") ||
|
||||
strings.Contains(stack, "os.(*file).close") ||
|
||||
strings.Contains(stack, "os.(*Process).Release") ||
|
||||
strings.Contains(stack, "created by os/signal.init") ||
|
||||
strings.Contains(stack, "runtime/panic.go") ||
|
||||
strings.Contains(stack, "created by testing.RunTests") ||
|
||||
strings.Contains(stack, "created by testing.runTests") ||
|
||||
strings.Contains(stack, "created by testing.(*T).Run") ||
|
||||
strings.Contains(stack, "testing.Main(") ||
|
||||
strings.Contains(stack, "runtime.goexit") ||
|
||||
strings.Contains(stack, "go.etcd.io/etcd/pkg/v3/testutil.interestingGoroutines") ||
|
||||
strings.Contains(stack, "go.etcd.io/etcd/pkg/v3/logutil.(*MergeLogger).outputLoop") ||
|
||||
strings.Contains(stack, "github.com/golang/glog.(*loggingT).flushDaemon") ||
|
||||
strings.Contains(stack, "created by runtime.gc") ||
|
||||
strings.Contains(stack, "created by text/template/parse.lex") ||
|
||||
strings.Contains(stack, "runtime.MHeap_Scavenger") ||
|
||||
strings.Contains(stack, "rcrypto/internal/boring.(*PublicKeyRSA).finalize") ||
|
||||
strings.Contains(stack, "net.(*netFD).Close(") ||
|
||||
strings.Contains(stack, "testing.(*T).Run") {
|
||||
continue
|
||||
}
|
||||
gs = append(gs, stack)
|
||||
}
|
||||
sort.Strings(gs)
|
||||
return gs
|
||||
}
|
||||
|
||||
func MustCheckLeakedGoroutine() {
|
||||
http.DefaultTransport.(*http.Transport).CloseIdleConnections()
|
||||
|
||||
CheckAfterTest(5 * time.Second)
|
||||
|
||||
// Let the other goroutines finalize.
|
||||
runtime.Gosched()
|
||||
|
||||
if CheckLeakedGoroutine() {
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
// MustTestMainWithLeakDetection expands standard m.Run with leaked
|
||||
// goroutines detection.
|
||||
func MustTestMainWithLeakDetection(m *testing.M) {
|
||||
v := m.Run()
|
||||
if v == 0 {
|
||||
MustCheckLeakedGoroutine()
|
||||
}
|
||||
os.Exit(v)
|
||||
}
|
||||
@@ -1,45 +0,0 @@
|
||||
// Copyright 2016 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package testutil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// so tests pass if given a -run that doesn't include TestSample
|
||||
var ranSample = false
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
m.Run()
|
||||
isLeaked := CheckLeakedGoroutine()
|
||||
if ranSample && !isLeaked {
|
||||
fmt.Fprintln(os.Stderr, "expected leaky goroutines but none is detected")
|
||||
os.Exit(1)
|
||||
}
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
func TestSample(t *testing.T) {
|
||||
SkipTestIfShortMode(t, "Counting leaked routines is disabled in --short tests")
|
||||
defer AfterTest(t)
|
||||
ranSample = true
|
||||
for range make([]struct{}, 100) {
|
||||
go func() {
|
||||
select {}
|
||||
}()
|
||||
}
|
||||
}
|
||||
@@ -1,57 +0,0 @@
|
||||
// Copyright 2015 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package testutil
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type PauseableHandler struct {
|
||||
Next http.Handler
|
||||
mu sync.Mutex
|
||||
paused bool
|
||||
}
|
||||
|
||||
func (ph *PauseableHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
ph.mu.Lock()
|
||||
paused := ph.paused
|
||||
ph.mu.Unlock()
|
||||
if !paused {
|
||||
ph.Next.ServeHTTP(w, r)
|
||||
} else {
|
||||
hj, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
panic("webserver doesn't support hijacking")
|
||||
}
|
||||
conn, _, err := hj.Hijack()
|
||||
if err != nil {
|
||||
panic(err.Error())
|
||||
}
|
||||
conn.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (ph *PauseableHandler) Pause() {
|
||||
ph.mu.Lock()
|
||||
defer ph.mu.Unlock()
|
||||
ph.paused = true
|
||||
}
|
||||
|
||||
func (ph *PauseableHandler) Resume() {
|
||||
ph.mu.Lock()
|
||||
defer ph.mu.Unlock()
|
||||
ph.paused = false
|
||||
}
|
||||
@@ -1,139 +0,0 @@
|
||||
// Copyright 2015 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package testutil
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Action struct {
|
||||
Name string
|
||||
Params []interface{}
|
||||
}
|
||||
|
||||
type Recorder interface {
|
||||
// Record publishes an Action (e.g., function call) which will
|
||||
// be reflected by Wait() or Chan()
|
||||
Record(a Action)
|
||||
// Wait waits until at least n Actions are available or returns with error
|
||||
Wait(n int) ([]Action, error)
|
||||
// Action returns immediately available Actions
|
||||
Action() []Action
|
||||
// Chan returns the channel for actions published by Record
|
||||
Chan() <-chan Action
|
||||
}
|
||||
|
||||
// RecorderBuffered appends all Actions to a slice
|
||||
type RecorderBuffered struct {
|
||||
sync.Mutex
|
||||
actions []Action
|
||||
}
|
||||
|
||||
func (r *RecorderBuffered) Record(a Action) {
|
||||
r.Lock()
|
||||
r.actions = append(r.actions, a)
|
||||
r.Unlock()
|
||||
}
|
||||
|
||||
func (r *RecorderBuffered) Action() []Action {
|
||||
r.Lock()
|
||||
cpy := make([]Action, len(r.actions))
|
||||
copy(cpy, r.actions)
|
||||
r.Unlock()
|
||||
return cpy
|
||||
}
|
||||
|
||||
func (r *RecorderBuffered) Wait(n int) (acts []Action, err error) {
|
||||
// legacy racey behavior
|
||||
WaitSchedule()
|
||||
acts = r.Action()
|
||||
if len(acts) < n {
|
||||
err = newLenErr(n, len(acts))
|
||||
}
|
||||
return acts, err
|
||||
}
|
||||
|
||||
func (r *RecorderBuffered) Chan() <-chan Action {
|
||||
ch := make(chan Action)
|
||||
go func() {
|
||||
acts := r.Action()
|
||||
for i := range acts {
|
||||
ch <- acts[i]
|
||||
}
|
||||
close(ch)
|
||||
}()
|
||||
return ch
|
||||
}
|
||||
|
||||
// RecorderStream writes all Actions to an unbuffered channel
|
||||
type recorderStream struct {
|
||||
ch chan Action
|
||||
waitTimeout time.Duration
|
||||
}
|
||||
|
||||
func NewRecorderStream() Recorder {
|
||||
return NewRecorderStreamWithWaitTimout(time.Duration(5 * time.Second))
|
||||
}
|
||||
|
||||
func NewRecorderStreamWithWaitTimout(waitTimeout time.Duration) Recorder {
|
||||
return &recorderStream{ch: make(chan Action), waitTimeout: waitTimeout}
|
||||
}
|
||||
|
||||
func (r *recorderStream) Record(a Action) {
|
||||
r.ch <- a
|
||||
}
|
||||
|
||||
func (r *recorderStream) Action() (acts []Action) {
|
||||
for {
|
||||
select {
|
||||
case act := <-r.ch:
|
||||
acts = append(acts, act)
|
||||
default:
|
||||
return acts
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *recorderStream) Chan() <-chan Action {
|
||||
return r.ch
|
||||
}
|
||||
|
||||
func (r *recorderStream) Wait(n int) ([]Action, error) {
|
||||
acts := make([]Action, n)
|
||||
timeoutC := time.After(r.waitTimeout)
|
||||
for i := 0; i < n; i++ {
|
||||
select {
|
||||
case acts[i] = <-r.ch:
|
||||
case <-timeoutC:
|
||||
acts = acts[:i]
|
||||
return acts, newLenErr(n, i)
|
||||
}
|
||||
}
|
||||
// extra wait to catch any Action spew
|
||||
select {
|
||||
case act := <-r.ch:
|
||||
acts = append(acts, act)
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
}
|
||||
return acts, nil
|
||||
}
|
||||
|
||||
func newLenErr(expected int, actual int) error {
|
||||
s := fmt.Sprintf("len(actions) = %d, expected >= %d", actual, expected)
|
||||
return errors.New(s)
|
||||
}
|
||||
@@ -1,130 +0,0 @@
|
||||
// Copyright 2021 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package testutil
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
)
|
||||
|
||||
// TB is a subset of methods of testing.TB interface.
|
||||
// We cannot implement testing.TB due to protection, so we expose this simplified interface.
|
||||
type TB interface {
|
||||
Cleanup(func())
|
||||
Error(args ...interface{})
|
||||
Errorf(format string, args ...interface{})
|
||||
Fail()
|
||||
FailNow()
|
||||
Failed() bool
|
||||
Fatal(args ...interface{})
|
||||
Fatalf(format string, args ...interface{})
|
||||
Logf(format string, args ...interface{})
|
||||
Name() string
|
||||
TempDir() string
|
||||
Helper()
|
||||
Skip(args ...interface{})
|
||||
}
|
||||
|
||||
// NewTestingTBProthesis creates a fake variant of testing.TB implementation.
|
||||
// It's supposed to be used in contexts were real testing.T is not provided,
|
||||
// e.g. in 'examples'.
|
||||
//
|
||||
// The `closef` goroutine should get executed when tb will not be needed any longer.
|
||||
//
|
||||
// The provided implementation is NOT thread safe (Cleanup() method).
|
||||
func NewTestingTBProthesis(name string) (tb TB, closef func()) {
|
||||
testtb := &testingTBProthesis{name: name}
|
||||
return testtb, testtb.close
|
||||
}
|
||||
|
||||
type testingTBProthesis struct {
|
||||
name string
|
||||
failed bool
|
||||
cleanups []func()
|
||||
}
|
||||
|
||||
func (t *testingTBProthesis) Helper() {
|
||||
// Ignored
|
||||
}
|
||||
|
||||
func (t *testingTBProthesis) Skip(args ...interface{}) {
|
||||
t.Log(append([]interface{}{"Skipping due to: "}, args...))
|
||||
}
|
||||
|
||||
func (t *testingTBProthesis) Cleanup(f func()) {
|
||||
t.cleanups = append(t.cleanups, f)
|
||||
}
|
||||
|
||||
func (t *testingTBProthesis) Error(args ...interface{}) {
|
||||
log.Println(args...)
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
func (t *testingTBProthesis) Errorf(format string, args ...interface{}) {
|
||||
log.Printf(format, args...)
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
func (t *testingTBProthesis) Fail() {
|
||||
t.failed = true
|
||||
}
|
||||
|
||||
func (t *testingTBProthesis) FailNow() {
|
||||
t.failed = true
|
||||
panic("FailNow() called")
|
||||
}
|
||||
|
||||
func (t *testingTBProthesis) Failed() bool {
|
||||
return t.failed
|
||||
}
|
||||
|
||||
func (t *testingTBProthesis) Fatal(args ...interface{}) {
|
||||
log.Fatalln(args...)
|
||||
}
|
||||
|
||||
func (t *testingTBProthesis) Fatalf(format string, args ...interface{}) {
|
||||
log.Fatalf(format, args...)
|
||||
}
|
||||
|
||||
func (t *testingTBProthesis) Logf(format string, args ...interface{}) {
|
||||
log.Printf(format, args...)
|
||||
}
|
||||
|
||||
func (t *testingTBProthesis) Log(args ...interface{}) {
|
||||
log.Println(args...)
|
||||
}
|
||||
|
||||
func (t *testingTBProthesis) Name() string {
|
||||
return t.name
|
||||
}
|
||||
|
||||
func (t *testingTBProthesis) TempDir() string {
|
||||
dir, err := ioutil.TempDir("", t.name)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.cleanups = append([]func(){func() {
|
||||
t.Logf("Cleaning UP: %v", dir)
|
||||
os.RemoveAll(dir)
|
||||
}}, t.cleanups...)
|
||||
return dir
|
||||
}
|
||||
|
||||
func (t *testingTBProthesis) close() {
|
||||
for i := len(t.cleanups) - 1; i >= 0; i-- {
|
||||
t.cleanups[i]()
|
||||
}
|
||||
}
|
||||
@@ -1,105 +0,0 @@
|
||||
// Copyright 2015 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Package testutil provides test utility functions.
|
||||
package testutil
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// WaitSchedule briefly sleeps in order to invoke the go scheduler.
|
||||
// TODO: improve this when we are able to know the schedule or status of target go-routine.
|
||||
func WaitSchedule() {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
|
||||
func MustNewURLs(t *testing.T, urls []string) []url.URL {
|
||||
if urls == nil {
|
||||
return nil
|
||||
}
|
||||
var us []url.URL
|
||||
for _, url := range urls {
|
||||
u := MustNewURL(t, url)
|
||||
us = append(us, *u)
|
||||
}
|
||||
return us
|
||||
}
|
||||
|
||||
func MustNewURL(t *testing.T, s string) *url.URL {
|
||||
u, err := url.Parse(s)
|
||||
if err != nil {
|
||||
t.Fatalf("parse %v error: %v", s, err)
|
||||
}
|
||||
return u
|
||||
}
|
||||
|
||||
// FatalStack helps to fatal the test and print out the stacks of all running goroutines.
|
||||
func FatalStack(t *testing.T, s string) {
|
||||
stackTrace := make([]byte, 1024*1024)
|
||||
n := runtime.Stack(stackTrace, true)
|
||||
t.Error(string(stackTrace[:n]))
|
||||
t.Fatalf(s)
|
||||
}
|
||||
|
||||
// ConditionFunc returns true when a condition is met.
|
||||
type ConditionFunc func() (bool, error)
|
||||
|
||||
// Poll calls a condition function repeatedly on a polling interval until it returns true, returns an error
|
||||
// or the timeout is reached. If the condition function returns true or an error before the timeout, Poll
|
||||
// immediately returns with the true value or the error. If the timeout is exceeded, Poll returns false.
|
||||
func Poll(interval time.Duration, timeout time.Duration, condition ConditionFunc) (bool, error) {
|
||||
timeoutCh := time.After(timeout)
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-timeoutCh:
|
||||
return false, nil
|
||||
case <-ticker.C:
|
||||
success, err := condition()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if success {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func SkipTestIfShortMode(t TB, reason string) {
|
||||
if t != nil {
|
||||
t.Helper()
|
||||
if testing.Short() {
|
||||
t.Skip(reason)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ExitInShortMode closes the current process (with 0) if the short test mode detected.
|
||||
//
|
||||
// To be used in Test-main, where test context (testing.TB) is not available.
|
||||
//
|
||||
// Requires custom env-variable (GOLANG_TEST_SHORT) apart of `go test --short flag`.
|
||||
func ExitInShortMode(reason string) {
|
||||
if os.Getenv("GOLANG_TEST_SHORT") == "true" {
|
||||
os.Exit(0)
|
||||
}
|
||||
}
|
||||
@@ -1,22 +0,0 @@
|
||||
// Copyright 2018 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package testutil
|
||||
|
||||
import "time"
|
||||
|
||||
var (
|
||||
ApplyTimeout = time.Second
|
||||
RequestTimeout = 3 * time.Second
|
||||
)
|
||||
@@ -1,39 +0,0 @@
|
||||
// Copyright 2018 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package tlsutil
|
||||
|
||||
import "crypto/tls"
|
||||
|
||||
// GetCipherSuite returns the corresponding cipher suite,
|
||||
// and boolean value if it is supported.
|
||||
func GetCipherSuite(s string) (uint16, bool) {
|
||||
for _, c := range tls.CipherSuites() {
|
||||
if s == c.Name {
|
||||
return c.ID, true
|
||||
}
|
||||
}
|
||||
for _, c := range tls.InsecureCipherSuites() {
|
||||
if s == c.Name {
|
||||
return c.ID, true
|
||||
}
|
||||
}
|
||||
switch s {
|
||||
case "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305":
|
||||
return tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, true
|
||||
case "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305":
|
||||
return tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, true
|
||||
}
|
||||
return 0, false
|
||||
}
|
||||
@@ -1,50 +0,0 @@
|
||||
// Copyright 2018 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package tlsutil
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetCipherSuite_not_existing(t *testing.T) {
|
||||
_, ok := GetCipherSuite("not_existing")
|
||||
if ok {
|
||||
t.Fatal("Expected not ok")
|
||||
}
|
||||
}
|
||||
|
||||
func CipherSuiteExpectedToExist(tb testing.TB, cipher string, expectedId uint16) {
|
||||
vid, ok := GetCipherSuite(cipher)
|
||||
if !ok {
|
||||
tb.Errorf("Expected %v cipher to exist", cipher)
|
||||
}
|
||||
if vid != expectedId {
|
||||
tb.Errorf("For %v expected=%v found=%v", cipher, expectedId, vid)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetCipherSuite_success(t *testing.T) {
|
||||
CipherSuiteExpectedToExist(t, "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA", tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA)
|
||||
CipherSuiteExpectedToExist(t, "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256)
|
||||
|
||||
// Explicit test for legacy names
|
||||
CipherSuiteExpectedToExist(t, "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305", tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256)
|
||||
CipherSuiteExpectedToExist(t, "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305", tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256)
|
||||
}
|
||||
|
||||
func TestGetCipherSuite_insecure(t *testing.T) {
|
||||
CipherSuiteExpectedToExist(t, "TLS_ECDHE_RSA_WITH_RC4_128_SHA", tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA)
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
// Copyright 2016 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Package tlsutil provides utility functions for handling TLS.
|
||||
package tlsutil
|
||||
@@ -1,73 +0,0 @@
|
||||
// Copyright 2016 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package tlsutil
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"io/ioutil"
|
||||
)
|
||||
|
||||
// NewCertPool creates x509 certPool with provided CA files.
|
||||
func NewCertPool(CAFiles []string) (*x509.CertPool, error) {
|
||||
certPool := x509.NewCertPool()
|
||||
|
||||
for _, CAFile := range CAFiles {
|
||||
pemByte, err := ioutil.ReadFile(CAFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for {
|
||||
var block *pem.Block
|
||||
block, pemByte = pem.Decode(pemByte)
|
||||
if block == nil {
|
||||
break
|
||||
}
|
||||
cert, err := x509.ParseCertificate(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
certPool.AddCert(cert)
|
||||
}
|
||||
}
|
||||
|
||||
return certPool, nil
|
||||
}
|
||||
|
||||
// NewCert generates TLS cert by using the given cert,key and parse function.
|
||||
func NewCert(certfile, keyfile string, parseFunc func([]byte, []byte) (tls.Certificate, error)) (*tls.Certificate, error) {
|
||||
cert, err := ioutil.ReadFile(certfile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key, err := ioutil.ReadFile(keyfile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if parseFunc == nil {
|
||||
parseFunc = tls.X509KeyPair
|
||||
}
|
||||
|
||||
tlsCert, err := parseFunc(cert, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &tlsCert, nil
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright 2015 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Package transport implements various HTTP transport utilities based on Go
|
||||
// net package.
|
||||
package transport
|
||||
@@ -1,94 +0,0 @@
|
||||
// Copyright 2015 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package transport
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
type keepAliveConn interface {
|
||||
SetKeepAlive(bool) error
|
||||
SetKeepAlivePeriod(d time.Duration) error
|
||||
}
|
||||
|
||||
// NewKeepAliveListener returns a listener that listens on the given address.
|
||||
// Be careful when wrap around KeepAliveListener with another Listener if TLSInfo is not nil.
|
||||
// Some pkgs (like go/http) might expect Listener to return TLSConn type to start TLS handshake.
|
||||
// http://tldp.org/HOWTO/TCP-Keepalive-HOWTO/overview.html
|
||||
func NewKeepAliveListener(l net.Listener, scheme string, tlscfg *tls.Config) (net.Listener, error) {
|
||||
if scheme == "https" {
|
||||
if tlscfg == nil {
|
||||
return nil, fmt.Errorf("cannot listen on TLS for given listener: KeyFile and CertFile are not presented")
|
||||
}
|
||||
return newTLSKeepaliveListener(l, tlscfg), nil
|
||||
}
|
||||
|
||||
return &keepaliveListener{
|
||||
Listener: l,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type keepaliveListener struct{ net.Listener }
|
||||
|
||||
func (kln *keepaliveListener) Accept() (net.Conn, error) {
|
||||
c, err := kln.Listener.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
kac := c.(keepAliveConn)
|
||||
// detection time: tcp_keepalive_time + tcp_keepalive_probes + tcp_keepalive_intvl
|
||||
// default on linux: 30 + 8 * 30
|
||||
// default on osx: 30 + 8 * 75
|
||||
kac.SetKeepAlive(true)
|
||||
kac.SetKeepAlivePeriod(30 * time.Second)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// A tlsKeepaliveListener implements a network listener (net.Listener) for TLS connections.
|
||||
type tlsKeepaliveListener struct {
|
||||
net.Listener
|
||||
config *tls.Config
|
||||
}
|
||||
|
||||
// Accept waits for and returns the next incoming TLS connection.
|
||||
// The returned connection c is a *tls.Conn.
|
||||
func (l *tlsKeepaliveListener) Accept() (c net.Conn, err error) {
|
||||
c, err = l.Listener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
kac := c.(keepAliveConn)
|
||||
// detection time: tcp_keepalive_time + tcp_keepalive_probes + tcp_keepalive_intvl
|
||||
// default on linux: 30 + 8 * 30
|
||||
// default on osx: 30 + 8 * 75
|
||||
kac.SetKeepAlive(true)
|
||||
kac.SetKeepAlivePeriod(30 * time.Second)
|
||||
c = tls.Server(c, l.config)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// NewListener creates a Listener which accepts connections from an inner
|
||||
// Listener and wraps each connection with Server.
|
||||
// The configuration config must be non-nil and must have
|
||||
// at least one certificate.
|
||||
func newTLSKeepaliveListener(inner net.Listener, config *tls.Config) net.Listener {
|
||||
l := &tlsKeepaliveListener{}
|
||||
l.Listener = inner
|
||||
l.config = config
|
||||
return l
|
||||
}
|
||||
@@ -1,90 +0,0 @@
|
||||
// Copyright 2015 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package transport
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net"
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestNewKeepAliveListener tests NewKeepAliveListener returns a listener
|
||||
// that accepts connections.
|
||||
// TODO: verify the keepalive option is set correctly
|
||||
func TestNewKeepAliveListener(t *testing.T) {
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected listen error: %v", err)
|
||||
}
|
||||
|
||||
ln, err = NewKeepAliveListener(ln, "http", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected NewKeepAliveListener error: %v", err)
|
||||
}
|
||||
|
||||
go http.Get("http://" + ln.Addr().String())
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected Accept error: %v", err)
|
||||
}
|
||||
conn.Close()
|
||||
ln.Close()
|
||||
|
||||
ln, err = net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected Listen error: %v", err)
|
||||
}
|
||||
|
||||
// tls
|
||||
tlsinfo, del, err := createSelfCert()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create tmpfile: %v", err)
|
||||
}
|
||||
defer del()
|
||||
tlsInfo := TLSInfo{CertFile: tlsinfo.CertFile, KeyFile: tlsinfo.KeyFile}
|
||||
tlsInfo.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, nil)
|
||||
tlscfg, err := tlsInfo.ServerConfig()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected serverConfig error: %v", err)
|
||||
}
|
||||
tlsln, err := NewKeepAliveListener(ln, "https", tlscfg)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected NewKeepAliveListener error: %v", err)
|
||||
}
|
||||
|
||||
go http.Get("https://" + tlsln.Addr().String())
|
||||
conn, err = tlsln.Accept()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected Accept error: %v", err)
|
||||
}
|
||||
if _, ok := conn.(*tls.Conn); !ok {
|
||||
t.Errorf("failed to accept *tls.Conn")
|
||||
}
|
||||
conn.Close()
|
||||
tlsln.Close()
|
||||
}
|
||||
|
||||
func TestNewKeepAliveListenerTLSEmptyConfig(t *testing.T) {
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected listen error: %v", err)
|
||||
}
|
||||
|
||||
_, err = NewKeepAliveListener(ln, "https", nil)
|
||||
if err == nil {
|
||||
t.Errorf("err = nil, want not presented error")
|
||||
}
|
||||
}
|
||||
@@ -1,80 +0,0 @@
|
||||
// Copyright 2013 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Package transport provides network utility functions, complementing the more
|
||||
// common ones in the net package.
|
||||
package transport
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrNotTCP = errors.New("only tcp connections have keepalive")
|
||||
)
|
||||
|
||||
// LimitListener returns a Listener that accepts at most n simultaneous
|
||||
// connections from the provided Listener.
|
||||
func LimitListener(l net.Listener, n int) net.Listener {
|
||||
return &limitListener{l, make(chan struct{}, n)}
|
||||
}
|
||||
|
||||
type limitListener struct {
|
||||
net.Listener
|
||||
sem chan struct{}
|
||||
}
|
||||
|
||||
func (l *limitListener) acquire() { l.sem <- struct{}{} }
|
||||
func (l *limitListener) release() { <-l.sem }
|
||||
|
||||
func (l *limitListener) Accept() (net.Conn, error) {
|
||||
l.acquire()
|
||||
c, err := l.Listener.Accept()
|
||||
if err != nil {
|
||||
l.release()
|
||||
return nil, err
|
||||
}
|
||||
return &limitListenerConn{Conn: c, release: l.release}, nil
|
||||
}
|
||||
|
||||
type limitListenerConn struct {
|
||||
net.Conn
|
||||
releaseOnce sync.Once
|
||||
release func()
|
||||
}
|
||||
|
||||
func (l *limitListenerConn) Close() error {
|
||||
err := l.Conn.Close()
|
||||
l.releaseOnce.Do(l.release)
|
||||
return err
|
||||
}
|
||||
|
||||
func (l *limitListenerConn) SetKeepAlive(doKeepAlive bool) error {
|
||||
tcpc, ok := l.Conn.(*net.TCPConn)
|
||||
if !ok {
|
||||
return ErrNotTCP
|
||||
}
|
||||
return tcpc.SetKeepAlive(doKeepAlive)
|
||||
}
|
||||
|
||||
func (l *limitListenerConn) SetKeepAlivePeriod(d time.Duration) error {
|
||||
tcpc, ok := l.Conn.(*net.TCPConn)
|
||||
if !ok {
|
||||
return ErrNotTCP
|
||||
}
|
||||
return tcpc.SetKeepAlivePeriod(d)
|
||||
}
|
||||
@@ -1,562 +0,0 @@
|
||||
// Copyright 2015 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"go.etcd.io/etcd/pkg/v3/fileutil"
|
||||
"go.etcd.io/etcd/pkg/v3/tlsutil"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// NewListener creates a new listner.
|
||||
func NewListener(addr, scheme string, tlsinfo *TLSInfo) (l net.Listener, err error) {
|
||||
return newListener(addr, scheme, WithTLSInfo(tlsinfo))
|
||||
}
|
||||
|
||||
// NewListenerWithOpts creates a new listener which accpets listener options.
|
||||
func NewListenerWithOpts(addr, scheme string, opts ...ListenerOption) (net.Listener, error) {
|
||||
return newListener(addr, scheme, opts...)
|
||||
}
|
||||
|
||||
func newListener(addr, scheme string, opts ...ListenerOption) (net.Listener, error) {
|
||||
if scheme == "unix" || scheme == "unixs" {
|
||||
// unix sockets via unix://laddr
|
||||
return NewUnixListener(addr)
|
||||
}
|
||||
|
||||
lnOpts := newListenOpts(opts...)
|
||||
|
||||
switch {
|
||||
case lnOpts.IsSocketOpts():
|
||||
// new ListenConfig with socket options.
|
||||
config, err := newListenConfig(lnOpts.socketOpts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
lnOpts.ListenConfig = config
|
||||
// check for timeout
|
||||
fallthrough
|
||||
case lnOpts.IsTimeout(), lnOpts.IsSocketOpts():
|
||||
// timeout listener with socket options.
|
||||
ln, err := lnOpts.ListenConfig.Listen(context.TODO(), "tcp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
lnOpts.Listener = &rwTimeoutListener{
|
||||
Listener: ln,
|
||||
readTimeout: lnOpts.readTimeout,
|
||||
writeTimeout: lnOpts.writeTimeout,
|
||||
}
|
||||
case lnOpts.IsTimeout():
|
||||
ln, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
lnOpts.Listener = &rwTimeoutListener{
|
||||
Listener: ln,
|
||||
readTimeout: lnOpts.readTimeout,
|
||||
writeTimeout: lnOpts.writeTimeout,
|
||||
}
|
||||
default:
|
||||
ln, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
lnOpts.Listener = ln
|
||||
}
|
||||
|
||||
// only skip if not passing TLSInfo
|
||||
if lnOpts.skipTLSInfoCheck && !lnOpts.IsTLS() {
|
||||
return lnOpts.Listener, nil
|
||||
}
|
||||
return wrapTLS(scheme, lnOpts.tlsInfo, lnOpts.Listener)
|
||||
}
|
||||
|
||||
func wrapTLS(scheme string, tlsinfo *TLSInfo, l net.Listener) (net.Listener, error) {
|
||||
if scheme != "https" && scheme != "unixs" {
|
||||
return l, nil
|
||||
}
|
||||
if tlsinfo != nil && tlsinfo.SkipClientSANVerify {
|
||||
return NewTLSListener(l, tlsinfo)
|
||||
}
|
||||
return newTLSListener(l, tlsinfo, checkSAN)
|
||||
}
|
||||
|
||||
func newListenConfig(sopts *SocketOpts) (net.ListenConfig, error) {
|
||||
lc := net.ListenConfig{}
|
||||
if sopts != nil {
|
||||
ctls := getControls(sopts)
|
||||
if len(ctls) > 0 {
|
||||
lc.Control = ctls.Control
|
||||
}
|
||||
}
|
||||
return lc, nil
|
||||
}
|
||||
|
||||
type TLSInfo struct {
|
||||
// CertFile is the _server_ cert, it will also be used as a _client_ certificate if ClientCertFile is empty
|
||||
CertFile string
|
||||
// KeyFile is the key for the CertFile
|
||||
KeyFile string
|
||||
// ClientCertFile is a _client_ cert for initiating connections when ClientCertAuth is defined. If ClientCertAuth
|
||||
// is true but this value is empty, the CertFile will be used instead.
|
||||
ClientCertFile string
|
||||
// ClientKeyFile is the key for the ClientCertFile
|
||||
ClientKeyFile string
|
||||
|
||||
TrustedCAFile string
|
||||
ClientCertAuth bool
|
||||
CRLFile string
|
||||
InsecureSkipVerify bool
|
||||
SkipClientSANVerify bool
|
||||
|
||||
// ServerName ensures the cert matches the given host in case of discovery / virtual hosting
|
||||
ServerName string
|
||||
|
||||
// HandshakeFailure is optionally called when a connection fails to handshake. The
|
||||
// connection will be closed immediately afterwards.
|
||||
HandshakeFailure func(*tls.Conn, error)
|
||||
|
||||
// CipherSuites is a list of supported cipher suites.
|
||||
// If empty, Go auto-populates it by default.
|
||||
// Note that cipher suites are prioritized in the given order.
|
||||
CipherSuites []uint16
|
||||
|
||||
selfCert bool
|
||||
|
||||
// parseFunc exists to simplify testing. Typically, parseFunc
|
||||
// should be left nil. In that case, tls.X509KeyPair will be used.
|
||||
parseFunc func([]byte, []byte) (tls.Certificate, error)
|
||||
|
||||
// AllowedCN is a CN which must be provided by a client.
|
||||
AllowedCN string
|
||||
|
||||
// AllowedHostname is an IP address or hostname that must match the TLS
|
||||
// certificate provided by a client.
|
||||
AllowedHostname string
|
||||
|
||||
// Logger logs TLS errors.
|
||||
// If nil, all logs are discarded.
|
||||
Logger *zap.Logger
|
||||
|
||||
// EmptyCN indicates that the cert must have empty CN.
|
||||
// If true, ClientConfig() will return an error for a cert with non empty CN.
|
||||
EmptyCN bool
|
||||
}
|
||||
|
||||
func (info TLSInfo) String() string {
|
||||
return fmt.Sprintf("cert = %s, key = %s, client-cert=%s, client-key=%s, trusted-ca = %s, client-cert-auth = %v, crl-file = %s", info.CertFile, info.KeyFile, info.ClientCertFile, info.ClientKeyFile, info.TrustedCAFile, info.ClientCertAuth, info.CRLFile)
|
||||
}
|
||||
|
||||
func (info TLSInfo) Empty() bool {
|
||||
return info.CertFile == "" && info.KeyFile == ""
|
||||
}
|
||||
|
||||
func SelfCert(lg *zap.Logger, dirpath string, hosts []string, selfSignedCertValidity uint, additionalUsages ...x509.ExtKeyUsage) (info TLSInfo, err error) {
|
||||
info.Logger = lg
|
||||
if selfSignedCertValidity == 0 {
|
||||
err = fmt.Errorf("selfSignedCertValidity is invalid,it should be greater than 0")
|
||||
info.Logger.Warn(
|
||||
"cannot generate cert",
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
err = fileutil.TouchDirAll(dirpath)
|
||||
if err != nil {
|
||||
if info.Logger != nil {
|
||||
info.Logger.Warn(
|
||||
"cannot create cert directory",
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
certPath := filepath.Join(dirpath, "cert.pem")
|
||||
keyPath := filepath.Join(dirpath, "key.pem")
|
||||
_, errcert := os.Stat(certPath)
|
||||
_, errkey := os.Stat(keyPath)
|
||||
if errcert == nil && errkey == nil {
|
||||
info.CertFile = certPath
|
||||
info.KeyFile = keyPath
|
||||
info.ClientCertFile = certPath
|
||||
info.ClientKeyFile = keyPath
|
||||
info.selfCert = true
|
||||
return
|
||||
}
|
||||
|
||||
serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
|
||||
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
|
||||
if err != nil {
|
||||
if info.Logger != nil {
|
||||
info.Logger.Warn(
|
||||
"cannot generate random number",
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
tmpl := x509.Certificate{
|
||||
SerialNumber: serialNumber,
|
||||
Subject: pkix.Name{Organization: []string{"etcd"}},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(time.Duration(selfSignedCertValidity) * 365 * (24 * time.Hour)),
|
||||
|
||||
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: append([]x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, additionalUsages...),
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
|
||||
if info.Logger != nil {
|
||||
info.Logger.Warn(
|
||||
"automatically generate certificates",
|
||||
zap.Time("certificate-validity-bound-not-after", tmpl.NotAfter),
|
||||
)
|
||||
}
|
||||
|
||||
for _, host := range hosts {
|
||||
h, _, _ := net.SplitHostPort(host)
|
||||
if ip := net.ParseIP(h); ip != nil {
|
||||
tmpl.IPAddresses = append(tmpl.IPAddresses, ip)
|
||||
} else {
|
||||
tmpl.DNSNames = append(tmpl.DNSNames, h)
|
||||
}
|
||||
}
|
||||
|
||||
priv, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
|
||||
if err != nil {
|
||||
if info.Logger != nil {
|
||||
info.Logger.Warn(
|
||||
"cannot generate ECDSA key",
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
derBytes, err := x509.CreateCertificate(rand.Reader, &tmpl, &tmpl, &priv.PublicKey, priv)
|
||||
if err != nil {
|
||||
if info.Logger != nil {
|
||||
info.Logger.Warn(
|
||||
"cannot generate x509 certificate",
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
certOut, err := os.Create(certPath)
|
||||
if err != nil {
|
||||
info.Logger.Warn(
|
||||
"cannot cert file",
|
||||
zap.String("path", certPath),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
|
||||
certOut.Close()
|
||||
if info.Logger != nil {
|
||||
info.Logger.Info("created cert file", zap.String("path", certPath))
|
||||
}
|
||||
|
||||
b, err := x509.MarshalECPrivateKey(priv)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
keyOut, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
|
||||
if err != nil {
|
||||
if info.Logger != nil {
|
||||
info.Logger.Warn(
|
||||
"cannot key file",
|
||||
zap.String("path", keyPath),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
return
|
||||
}
|
||||
pem.Encode(keyOut, &pem.Block{Type: "EC PRIVATE KEY", Bytes: b})
|
||||
keyOut.Close()
|
||||
if info.Logger != nil {
|
||||
info.Logger.Info("created key file", zap.String("path", keyPath))
|
||||
}
|
||||
return SelfCert(lg, dirpath, hosts, selfSignedCertValidity)
|
||||
}
|
||||
|
||||
// baseConfig is called on initial TLS handshake start.
|
||||
//
|
||||
// Previously,
|
||||
// 1. Server has non-empty (*tls.Config).Certificates on client hello
|
||||
// 2. Server calls (*tls.Config).GetCertificate iff:
|
||||
// - Server's (*tls.Config).Certificates is not empty, or
|
||||
// - Client supplies SNI; non-empty (*tls.ClientHelloInfo).ServerName
|
||||
//
|
||||
// When (*tls.Config).Certificates is always populated on initial handshake,
|
||||
// client is expected to provide a valid matching SNI to pass the TLS
|
||||
// verification, thus trigger server (*tls.Config).GetCertificate to reload
|
||||
// TLS assets. However, a cert whose SAN field does not include domain names
|
||||
// but only IP addresses, has empty (*tls.ClientHelloInfo).ServerName, thus
|
||||
// it was never able to trigger TLS reload on initial handshake; first
|
||||
// ceritifcate object was being used, never being updated.
|
||||
//
|
||||
// Now, (*tls.Config).Certificates is created empty on initial TLS client
|
||||
// handshake, in order to trigger (*tls.Config).GetCertificate and populate
|
||||
// rest of the certificates on every new TLS connection, even when client
|
||||
// SNI is empty (e.g. cert only includes IPs).
|
||||
func (info TLSInfo) baseConfig() (*tls.Config, error) {
|
||||
if info.KeyFile == "" || info.CertFile == "" {
|
||||
return nil, fmt.Errorf("KeyFile and CertFile must both be present[key: %v, cert: %v]", info.KeyFile, info.CertFile)
|
||||
}
|
||||
if info.Logger == nil {
|
||||
info.Logger = zap.NewNop()
|
||||
}
|
||||
|
||||
_, err := tlsutil.NewCert(info.CertFile, info.KeyFile, info.parseFunc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Perform prevalidation of client cert and key if either are provided. This makes sure we crash before accepting any connections.
|
||||
if (info.ClientKeyFile == "") != (info.ClientCertFile == "") {
|
||||
return nil, fmt.Errorf("ClientKeyFile and ClientCertFile must both be present or both absent: key: %v, cert: %v]", info.ClientKeyFile, info.ClientCertFile)
|
||||
}
|
||||
if info.ClientCertFile != "" {
|
||||
_, err := tlsutil.NewCert(info.ClientCertFile, info.ClientKeyFile, info.parseFunc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
cfg := &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
ServerName: info.ServerName,
|
||||
}
|
||||
|
||||
if len(info.CipherSuites) > 0 {
|
||||
cfg.CipherSuites = info.CipherSuites
|
||||
}
|
||||
|
||||
// Client certificates may be verified by either an exact match on the CN,
|
||||
// or a more general check of the CN and SANs.
|
||||
var verifyCertificate func(*x509.Certificate) bool
|
||||
if info.AllowedCN != "" {
|
||||
if info.AllowedHostname != "" {
|
||||
return nil, fmt.Errorf("AllowedCN and AllowedHostname are mutually exclusive (cn=%q, hostname=%q)", info.AllowedCN, info.AllowedHostname)
|
||||
}
|
||||
verifyCertificate = func(cert *x509.Certificate) bool {
|
||||
return info.AllowedCN == cert.Subject.CommonName
|
||||
}
|
||||
}
|
||||
if info.AllowedHostname != "" {
|
||||
verifyCertificate = func(cert *x509.Certificate) bool {
|
||||
return cert.VerifyHostname(info.AllowedHostname) == nil
|
||||
}
|
||||
}
|
||||
if verifyCertificate != nil {
|
||||
cfg.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
|
||||
for _, chains := range verifiedChains {
|
||||
if len(chains) != 0 {
|
||||
if verifyCertificate(chains[0]) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return errors.New("client certificate authentication failed")
|
||||
}
|
||||
}
|
||||
|
||||
// this only reloads certs when there's a client request
|
||||
// TODO: support server-side refresh (e.g. inotify, SIGHUP), caching
|
||||
cfg.GetCertificate = func(clientHello *tls.ClientHelloInfo) (cert *tls.Certificate, err error) {
|
||||
cert, err = tlsutil.NewCert(info.CertFile, info.KeyFile, info.parseFunc)
|
||||
if os.IsNotExist(err) {
|
||||
if info.Logger != nil {
|
||||
info.Logger.Warn(
|
||||
"failed to find peer cert files",
|
||||
zap.String("cert-file", info.CertFile),
|
||||
zap.String("key-file", info.KeyFile),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
} else if err != nil {
|
||||
if info.Logger != nil {
|
||||
info.Logger.Warn(
|
||||
"failed to create peer certificate",
|
||||
zap.String("cert-file", info.CertFile),
|
||||
zap.String("key-file", info.KeyFile),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
return cert, err
|
||||
}
|
||||
cfg.GetClientCertificate = func(unused *tls.CertificateRequestInfo) (cert *tls.Certificate, err error) {
|
||||
certfile, keyfile := info.CertFile, info.KeyFile
|
||||
if info.ClientCertFile != "" {
|
||||
certfile, keyfile = info.ClientCertFile, info.ClientKeyFile
|
||||
}
|
||||
cert, err = tlsutil.NewCert(certfile, keyfile, info.parseFunc)
|
||||
if os.IsNotExist(err) {
|
||||
if info.Logger != nil {
|
||||
info.Logger.Warn(
|
||||
"failed to find client cert files",
|
||||
zap.String("cert-file", certfile),
|
||||
zap.String("key-file", keyfile),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
} else if err != nil {
|
||||
if info.Logger != nil {
|
||||
info.Logger.Warn(
|
||||
"failed to create client certificate",
|
||||
zap.String("cert-file", certfile),
|
||||
zap.String("key-file", keyfile),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
return cert, err
|
||||
}
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// cafiles returns a list of CA file paths.
|
||||
func (info TLSInfo) cafiles() []string {
|
||||
cs := make([]string, 0)
|
||||
if info.TrustedCAFile != "" {
|
||||
cs = append(cs, info.TrustedCAFile)
|
||||
}
|
||||
return cs
|
||||
}
|
||||
|
||||
// ServerConfig generates a tls.Config object for use by an HTTP server.
|
||||
func (info TLSInfo) ServerConfig() (*tls.Config, error) {
|
||||
cfg, err := info.baseConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cfg.ClientAuth = tls.NoClientCert
|
||||
if info.TrustedCAFile != "" || info.ClientCertAuth {
|
||||
cfg.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
}
|
||||
|
||||
cs := info.cafiles()
|
||||
if len(cs) > 0 {
|
||||
cp, err := tlsutil.NewCertPool(cs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cfg.ClientCAs = cp
|
||||
}
|
||||
|
||||
// "h2" NextProtos is necessary for enabling HTTP2 for go's HTTP server
|
||||
cfg.NextProtos = []string{"h2"}
|
||||
|
||||
// go1.13 enables TLS 1.3 by default
|
||||
// and in TLS 1.3, cipher suites are not configurable
|
||||
// setting Max TLS version to TLS 1.2 for go 1.13
|
||||
cfg.MaxVersion = tls.VersionTLS12
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// ClientConfig generates a tls.Config object for use by an HTTP client.
|
||||
func (info TLSInfo) ClientConfig() (*tls.Config, error) {
|
||||
var cfg *tls.Config
|
||||
var err error
|
||||
|
||||
if !info.Empty() {
|
||||
cfg, err = info.baseConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
cfg = &tls.Config{ServerName: info.ServerName}
|
||||
}
|
||||
cfg.InsecureSkipVerify = info.InsecureSkipVerify
|
||||
|
||||
cs := info.cafiles()
|
||||
if len(cs) > 0 {
|
||||
cfg.RootCAs, err = tlsutil.NewCertPool(cs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if info.selfCert {
|
||||
cfg.InsecureSkipVerify = true
|
||||
}
|
||||
|
||||
if info.EmptyCN {
|
||||
hasNonEmptyCN := false
|
||||
cn := ""
|
||||
_, err := tlsutil.NewCert(info.CertFile, info.KeyFile, func(certPEMBlock []byte, keyPEMBlock []byte) (tls.Certificate, error) {
|
||||
var block *pem.Block
|
||||
block, _ = pem.Decode(certPEMBlock)
|
||||
cert, err := x509.ParseCertificate(block.Bytes)
|
||||
if err != nil {
|
||||
return tls.Certificate{}, err
|
||||
}
|
||||
if len(cert.Subject.CommonName) != 0 {
|
||||
hasNonEmptyCN = true
|
||||
cn = cert.Subject.CommonName
|
||||
}
|
||||
return tls.X509KeyPair(certPEMBlock, keyPEMBlock)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if hasNonEmptyCN {
|
||||
return nil, fmt.Errorf("cert has non empty Common Name (%s): %s", cn, info.CertFile)
|
||||
}
|
||||
}
|
||||
|
||||
// go1.13 enables TLS 1.3 by default
|
||||
// and in TLS 1.3, cipher suites are not configurable
|
||||
// setting Max TLS version to TLS 1.2 for go 1.13
|
||||
cfg.MaxVersion = tls.VersionTLS12
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// IsClosedConnError returns true if the error is from closing listener, cmux.
|
||||
// copied from golang.org/x/net/http2/http2.go
|
||||
func IsClosedConnError(err error) bool {
|
||||
// 'use of closed network connection' (Go <=1.8)
|
||||
// 'use of closed file or network connection' (Go >1.8, internal/poll.ErrClosing)
|
||||
// 'mux: listener closed' (cmux.ErrListenerClosed)
|
||||
return err != nil && strings.Contains(err.Error(), "closed")
|
||||
}
|
||||
@@ -1,76 +0,0 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ListenerOptions struct {
|
||||
Listener net.Listener
|
||||
ListenConfig net.ListenConfig
|
||||
|
||||
socketOpts *SocketOpts
|
||||
tlsInfo *TLSInfo
|
||||
skipTLSInfoCheck bool
|
||||
writeTimeout time.Duration
|
||||
readTimeout time.Duration
|
||||
}
|
||||
|
||||
func newListenOpts(opts ...ListenerOption) *ListenerOptions {
|
||||
lnOpts := &ListenerOptions{}
|
||||
lnOpts.applyOpts(opts)
|
||||
return lnOpts
|
||||
}
|
||||
|
||||
func (lo *ListenerOptions) applyOpts(opts []ListenerOption) {
|
||||
for _, opt := range opts {
|
||||
opt(lo)
|
||||
}
|
||||
}
|
||||
|
||||
// IsTimeout returns true if the listener has a read/write timeout defined.
|
||||
func (lo *ListenerOptions) IsTimeout() bool { return lo.readTimeout != 0 || lo.writeTimeout != 0 }
|
||||
|
||||
// IsSocketOpts returns true if the listener options includes socket options.
|
||||
func (lo *ListenerOptions) IsSocketOpts() bool {
|
||||
if lo.socketOpts == nil {
|
||||
return false
|
||||
}
|
||||
return lo.socketOpts.ReusePort || lo.socketOpts.ReuseAddress
|
||||
}
|
||||
|
||||
// IsTLS returns true if listner options includes TLSInfo.
|
||||
func (lo *ListenerOptions) IsTLS() bool {
|
||||
if lo.tlsInfo == nil {
|
||||
return false
|
||||
}
|
||||
return !lo.tlsInfo.Empty()
|
||||
}
|
||||
|
||||
// ListenerOption are options which can be applied to the listener.
|
||||
type ListenerOption func(*ListenerOptions)
|
||||
|
||||
// WithTimeout allows for a read or write timeout to be applied to the listener.
|
||||
func WithTimeout(read, write time.Duration) ListenerOption {
|
||||
return func(lo *ListenerOptions) {
|
||||
lo.writeTimeout = write
|
||||
lo.readTimeout = read
|
||||
}
|
||||
}
|
||||
|
||||
// WithSocketOpts defines socket options that will be applied to the listener.
|
||||
func WithSocketOpts(s *SocketOpts) ListenerOption {
|
||||
return func(lo *ListenerOptions) { lo.socketOpts = s }
|
||||
}
|
||||
|
||||
// WithTLSInfo adds TLS credentials to the listener.
|
||||
func WithTLSInfo(t *TLSInfo) ListenerOption {
|
||||
return func(lo *ListenerOptions) { lo.tlsInfo = t }
|
||||
}
|
||||
|
||||
// WithSkipTLSInfoCheck when true a transport can be created with an https scheme
|
||||
// without passing TLSInfo, circumventing not presented error. Skipping this check
|
||||
// also requires that TLSInfo is not passed.
|
||||
func WithSkipTLSInfoCheck(skip bool) ListenerOption {
|
||||
return func(lo *ListenerOptions) { lo.skipTLSInfoCheck = skip }
|
||||
}
|
||||
@@ -1,576 +0,0 @@
|
||||
// Copyright 2015 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package transport
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func createSelfCert(hosts ...string) (*TLSInfo, func(), error) {
|
||||
return createSelfCertEx("127.0.0.1")
|
||||
}
|
||||
|
||||
func createSelfCertEx(host string, additionalUsages ...x509.ExtKeyUsage) (*TLSInfo, func(), error) {
|
||||
d, terr := ioutil.TempDir("", "etcd-test-tls-")
|
||||
if terr != nil {
|
||||
return nil, nil, terr
|
||||
}
|
||||
info, err := SelfCert(zap.NewExample(), d, []string{host + ":0"}, 1, additionalUsages...)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return &info, func() { os.RemoveAll(d) }, nil
|
||||
}
|
||||
|
||||
func fakeCertificateParserFunc(cert tls.Certificate, err error) func(certPEMBlock, keyPEMBlock []byte) (tls.Certificate, error) {
|
||||
return func(certPEMBlock, keyPEMBlock []byte) (tls.Certificate, error) {
|
||||
return cert, err
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewListenerTLSInfo tests that NewListener with valid TLSInfo returns
|
||||
// a TLS listener that accepts TLS connections.
|
||||
func TestNewListenerTLSInfo(t *testing.T) {
|
||||
tlsInfo, del, err := createSelfCert()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create cert: %v", err)
|
||||
}
|
||||
defer del()
|
||||
testNewListenerTLSInfoAccept(t, *tlsInfo)
|
||||
}
|
||||
|
||||
func TestNewListenerWithOpts(t *testing.T) {
|
||||
tlsInfo, del, err := createSelfCert()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create cert: %v", err)
|
||||
}
|
||||
defer del()
|
||||
|
||||
tests := map[string]struct {
|
||||
opts []ListenerOption
|
||||
scheme string
|
||||
expectedErr bool
|
||||
}{
|
||||
"https scheme no TLSInfo": {
|
||||
opts: []ListenerOption{},
|
||||
expectedErr: true,
|
||||
scheme: "https",
|
||||
},
|
||||
"https scheme no TLSInfo with skip check": {
|
||||
opts: []ListenerOption{WithSkipTLSInfoCheck(true)},
|
||||
expectedErr: false,
|
||||
scheme: "https",
|
||||
},
|
||||
"https scheme empty TLSInfo with skip check": {
|
||||
opts: []ListenerOption{
|
||||
WithSkipTLSInfoCheck(true),
|
||||
WithTLSInfo(&TLSInfo{}),
|
||||
},
|
||||
expectedErr: false,
|
||||
scheme: "https",
|
||||
},
|
||||
"https scheme empty TLSInfo no skip check": {
|
||||
opts: []ListenerOption{
|
||||
WithTLSInfo(&TLSInfo{}),
|
||||
},
|
||||
expectedErr: true,
|
||||
scheme: "https",
|
||||
},
|
||||
"https scheme with TLSInfo and skip check": {
|
||||
opts: []ListenerOption{
|
||||
WithSkipTLSInfoCheck(true),
|
||||
WithTLSInfo(tlsInfo),
|
||||
},
|
||||
expectedErr: false,
|
||||
scheme: "https",
|
||||
},
|
||||
}
|
||||
for testName, test := range tests {
|
||||
t.Run(testName, func(t *testing.T) {
|
||||
ln, err := NewListenerWithOpts("127.0.0.1:0", test.scheme, test.opts...)
|
||||
if ln != nil {
|
||||
defer ln.Close()
|
||||
}
|
||||
if test.expectedErr && err == nil {
|
||||
t.Fatalf("expected error")
|
||||
}
|
||||
if !test.expectedErr && err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewListenerWithSocketOpts(t *testing.T) {
|
||||
tlsInfo, del, err := createSelfCert()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create cert: %v", err)
|
||||
}
|
||||
defer del()
|
||||
|
||||
tests := map[string]struct {
|
||||
opts []ListenerOption
|
||||
scheme string
|
||||
expectedErr bool
|
||||
}{
|
||||
"nil socketopts": {
|
||||
opts: []ListenerOption{WithSocketOpts(nil)},
|
||||
expectedErr: true,
|
||||
scheme: "http",
|
||||
},
|
||||
"empty socketopts": {
|
||||
opts: []ListenerOption{WithSocketOpts(&SocketOpts{})},
|
||||
expectedErr: true,
|
||||
scheme: "http",
|
||||
},
|
||||
|
||||
"reuse address": {
|
||||
opts: []ListenerOption{WithSocketOpts(&SocketOpts{ReuseAddress: true})},
|
||||
scheme: "http",
|
||||
expectedErr: true,
|
||||
},
|
||||
"reuse address with TLS": {
|
||||
opts: []ListenerOption{
|
||||
WithSocketOpts(&SocketOpts{ReuseAddress: true}),
|
||||
WithTLSInfo(tlsInfo),
|
||||
},
|
||||
scheme: "https",
|
||||
expectedErr: true,
|
||||
},
|
||||
"reuse address and port": {
|
||||
opts: []ListenerOption{WithSocketOpts(&SocketOpts{ReuseAddress: true, ReusePort: true})},
|
||||
scheme: "http",
|
||||
expectedErr: false,
|
||||
},
|
||||
"reuse address and port with TLS": {
|
||||
opts: []ListenerOption{
|
||||
WithSocketOpts(&SocketOpts{ReuseAddress: true, ReusePort: true}),
|
||||
WithTLSInfo(tlsInfo),
|
||||
},
|
||||
scheme: "https",
|
||||
expectedErr: false,
|
||||
},
|
||||
"reuse port with TLS and timeout": {
|
||||
opts: []ListenerOption{
|
||||
WithSocketOpts(&SocketOpts{ReusePort: true}),
|
||||
WithTLSInfo(tlsInfo),
|
||||
WithTimeout(5*time.Second, 5*time.Second),
|
||||
},
|
||||
scheme: "https",
|
||||
expectedErr: false,
|
||||
},
|
||||
"reuse port with https scheme and no TLSInfo skip check": {
|
||||
opts: []ListenerOption{
|
||||
WithSocketOpts(&SocketOpts{ReusePort: true}),
|
||||
WithSkipTLSInfoCheck(true),
|
||||
},
|
||||
scheme: "https",
|
||||
expectedErr: false,
|
||||
},
|
||||
"reuse port": {
|
||||
opts: []ListenerOption{WithSocketOpts(&SocketOpts{ReusePort: true})},
|
||||
scheme: "http",
|
||||
expectedErr: false,
|
||||
},
|
||||
}
|
||||
for testName, test := range tests {
|
||||
t.Run(testName, func(t *testing.T) {
|
||||
ln, err := NewListenerWithOpts("127.0.0.1:0", test.scheme, test.opts...)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected NewListenerWithSocketOpts error: %v", err)
|
||||
}
|
||||
defer ln.Close()
|
||||
ln2, err := NewListenerWithOpts(ln.Addr().String(), test.scheme, test.opts...)
|
||||
if ln2 != nil {
|
||||
ln2.Close()
|
||||
}
|
||||
if test.expectedErr && err == nil {
|
||||
t.Fatalf("expected error")
|
||||
}
|
||||
if !test.expectedErr && err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testNewListenerTLSInfoAccept(t *testing.T, tlsInfo TLSInfo) {
|
||||
ln, err := NewListener("127.0.0.1:0", "https", &tlsInfo)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected NewListener error: %v", err)
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
tr := &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
|
||||
cli := &http.Client{Transport: tr}
|
||||
go cli.Get("https://" + ln.Addr().String())
|
||||
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected Accept error: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
if _, ok := conn.(*tls.Conn); !ok {
|
||||
t.Error("failed to accept *tls.Conn")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewListenerTLSInfoSkipClientSANVerify tests that if client IP address mismatches
|
||||
// with specified address in its certificate the connection is still accepted
|
||||
// if the flag SkipClientSANVerify is set (i.e. checkSAN() is disabled for the client side)
|
||||
func TestNewListenerTLSInfoSkipClientSANVerify(t *testing.T) {
|
||||
tests := []struct {
|
||||
skipClientSANVerify bool
|
||||
goodClientHost bool
|
||||
acceptExpected bool
|
||||
}{
|
||||
{false, true, true},
|
||||
{false, false, false},
|
||||
{true, true, true},
|
||||
{true, false, true},
|
||||
}
|
||||
for _, test := range tests {
|
||||
testNewListenerTLSInfoClientCheck(t, test.skipClientSANVerify, test.goodClientHost, test.acceptExpected)
|
||||
}
|
||||
}
|
||||
|
||||
func testNewListenerTLSInfoClientCheck(t *testing.T, skipClientSANVerify, goodClientHost, acceptExpected bool) {
|
||||
tlsInfo, del, err := createSelfCert()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create cert: %v", err)
|
||||
}
|
||||
defer del()
|
||||
|
||||
host := "127.0.0.222"
|
||||
if goodClientHost {
|
||||
host = "127.0.0.1"
|
||||
}
|
||||
clientTLSInfo, del2, err := createSelfCertEx(host, x509.ExtKeyUsageClientAuth)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create cert: %v", err)
|
||||
}
|
||||
defer del2()
|
||||
|
||||
tlsInfo.SkipClientSANVerify = skipClientSANVerify
|
||||
tlsInfo.TrustedCAFile = clientTLSInfo.CertFile
|
||||
|
||||
rootCAs := x509.NewCertPool()
|
||||
loaded, err := ioutil.ReadFile(tlsInfo.CertFile)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected missing certfile: %v", err)
|
||||
}
|
||||
rootCAs.AppendCertsFromPEM(loaded)
|
||||
|
||||
clientCert, err := tls.LoadX509KeyPair(clientTLSInfo.CertFile, clientTLSInfo.KeyFile)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create peer cert: %v", err)
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{}
|
||||
tlsConfig.InsecureSkipVerify = false
|
||||
tlsConfig.Certificates = []tls.Certificate{clientCert}
|
||||
tlsConfig.RootCAs = rootCAs
|
||||
|
||||
ln, err := NewListener("127.0.0.1:0", "https", tlsInfo)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected NewListener error: %v", err)
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
tr := &http.Transport{TLSClientConfig: tlsConfig}
|
||||
cli := &http.Client{Transport: tr}
|
||||
chClientErr := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := cli.Get("https://" + ln.Addr().String())
|
||||
chClientErr <- err
|
||||
}()
|
||||
|
||||
chAcceptErr := make(chan error, 1)
|
||||
chAcceptConn := make(chan net.Conn, 1)
|
||||
go func() {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
chAcceptErr <- err
|
||||
} else {
|
||||
chAcceptConn <- conn
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-chClientErr:
|
||||
if acceptExpected {
|
||||
t.Errorf("accepted for good client address: skipClientSANVerify=%t, goodClientHost=%t", skipClientSANVerify, goodClientHost)
|
||||
}
|
||||
case acceptErr := <-chAcceptErr:
|
||||
t.Fatalf("unexpected Accept error: %v", acceptErr)
|
||||
case conn := <-chAcceptConn:
|
||||
defer conn.Close()
|
||||
if _, ok := conn.(*tls.Conn); !ok {
|
||||
t.Errorf("failed to accept *tls.Conn")
|
||||
}
|
||||
if !acceptExpected {
|
||||
t.Errorf("accepted for bad client address: skipClientSANVerify=%t, goodClientHost=%t", skipClientSANVerify, goodClientHost)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewListenerTLSEmptyInfo(t *testing.T) {
|
||||
_, err := NewListener("127.0.0.1:0", "https", nil)
|
||||
if err == nil {
|
||||
t.Errorf("err = nil, want not presented error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewTransportTLSInfo(t *testing.T) {
|
||||
tlsinfo, del, err := createSelfCert()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create cert: %v", err)
|
||||
}
|
||||
defer del()
|
||||
|
||||
tests := []TLSInfo{
|
||||
{},
|
||||
{
|
||||
CertFile: tlsinfo.CertFile,
|
||||
KeyFile: tlsinfo.KeyFile,
|
||||
},
|
||||
{
|
||||
CertFile: tlsinfo.CertFile,
|
||||
KeyFile: tlsinfo.KeyFile,
|
||||
TrustedCAFile: tlsinfo.TrustedCAFile,
|
||||
},
|
||||
{
|
||||
TrustedCAFile: tlsinfo.TrustedCAFile,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
tt.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, nil)
|
||||
trans, err := NewTransport(tt, time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("Received unexpected error from NewTransport: %v", err)
|
||||
}
|
||||
|
||||
if trans.TLSClientConfig == nil {
|
||||
t.Fatalf("#%d: want non-nil TLSClientConfig", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLSInfoNonexist(t *testing.T) {
|
||||
tlsInfo := TLSInfo{CertFile: "@badname", KeyFile: "@badname"}
|
||||
_, err := tlsInfo.ServerConfig()
|
||||
werr := &os.PathError{
|
||||
Op: "open",
|
||||
Path: "@badname",
|
||||
Err: errors.New("no such file or directory"),
|
||||
}
|
||||
if err.Error() != werr.Error() {
|
||||
t.Errorf("err = %v, want %v", err, werr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLSInfoEmpty(t *testing.T) {
|
||||
tests := []struct {
|
||||
info TLSInfo
|
||||
want bool
|
||||
}{
|
||||
{TLSInfo{}, true},
|
||||
{TLSInfo{TrustedCAFile: "baz"}, true},
|
||||
{TLSInfo{CertFile: "foo"}, false},
|
||||
{TLSInfo{KeyFile: "bar"}, false},
|
||||
{TLSInfo{CertFile: "foo", KeyFile: "bar"}, false},
|
||||
{TLSInfo{CertFile: "foo", TrustedCAFile: "baz"}, false},
|
||||
{TLSInfo{KeyFile: "bar", TrustedCAFile: "baz"}, false},
|
||||
{TLSInfo{CertFile: "foo", KeyFile: "bar", TrustedCAFile: "baz"}, false},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
got := tt.info.Empty()
|
||||
if tt.want != got {
|
||||
t.Errorf("#%d: result of Empty() incorrect: want=%t got=%t", i, tt.want, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLSInfoMissingFields(t *testing.T) {
|
||||
tlsinfo, del, err := createSelfCert()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create cert: %v", err)
|
||||
}
|
||||
defer del()
|
||||
|
||||
tests := []TLSInfo{
|
||||
{CertFile: tlsinfo.CertFile},
|
||||
{KeyFile: tlsinfo.KeyFile},
|
||||
{CertFile: tlsinfo.CertFile, TrustedCAFile: tlsinfo.TrustedCAFile},
|
||||
{KeyFile: tlsinfo.KeyFile, TrustedCAFile: tlsinfo.TrustedCAFile},
|
||||
}
|
||||
|
||||
for i, info := range tests {
|
||||
if _, err = info.ServerConfig(); err == nil {
|
||||
t.Errorf("#%d: expected non-nil error from ServerConfig()", i)
|
||||
}
|
||||
|
||||
if _, err = info.ClientConfig(); err == nil {
|
||||
t.Errorf("#%d: expected non-nil error from ClientConfig()", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLSInfoParseFuncError(t *testing.T) {
|
||||
tlsinfo, del, err := createSelfCert()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create cert: %v", err)
|
||||
}
|
||||
defer del()
|
||||
|
||||
tests := []struct {
|
||||
info TLSInfo
|
||||
}{
|
||||
{
|
||||
info: *tlsinfo,
|
||||
},
|
||||
|
||||
{
|
||||
info: TLSInfo{CertFile: "", KeyFile: "", TrustedCAFile: tlsinfo.CertFile, EmptyCN: true},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
tt.info.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, errors.New("fake"))
|
||||
|
||||
if _, err = tt.info.ServerConfig(); err == nil {
|
||||
t.Errorf("#%d: expected non-nil error from ServerConfig()", i)
|
||||
}
|
||||
|
||||
if _, err = tt.info.ClientConfig(); err == nil {
|
||||
t.Errorf("#%d: expected non-nil error from ClientConfig()", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLSInfoConfigFuncs(t *testing.T) {
|
||||
tlsinfo, del, err := createSelfCert()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create cert: %v", err)
|
||||
}
|
||||
defer del()
|
||||
|
||||
tests := []struct {
|
||||
info TLSInfo
|
||||
clientAuth tls.ClientAuthType
|
||||
wantCAs bool
|
||||
}{
|
||||
{
|
||||
info: TLSInfo{CertFile: tlsinfo.CertFile, KeyFile: tlsinfo.KeyFile},
|
||||
clientAuth: tls.NoClientCert,
|
||||
wantCAs: false,
|
||||
},
|
||||
|
||||
{
|
||||
info: TLSInfo{CertFile: tlsinfo.CertFile, KeyFile: tlsinfo.KeyFile, TrustedCAFile: tlsinfo.CertFile},
|
||||
clientAuth: tls.RequireAndVerifyClientCert,
|
||||
wantCAs: true,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
tt.info.parseFunc = fakeCertificateParserFunc(tls.Certificate{}, nil)
|
||||
|
||||
sCfg, err := tt.info.ServerConfig()
|
||||
if err != nil {
|
||||
t.Errorf("#%d: expected nil error from ServerConfig(), got non-nil: %v", i, err)
|
||||
}
|
||||
|
||||
if tt.wantCAs != (sCfg.ClientCAs != nil) {
|
||||
t.Errorf("#%d: wantCAs=%t but ClientCAs=%v", i, tt.wantCAs, sCfg.ClientCAs)
|
||||
}
|
||||
|
||||
cCfg, err := tt.info.ClientConfig()
|
||||
if err != nil {
|
||||
t.Errorf("#%d: expected nil error from ClientConfig(), got non-nil: %v", i, err)
|
||||
}
|
||||
|
||||
if tt.wantCAs != (cCfg.RootCAs != nil) {
|
||||
t.Errorf("#%d: wantCAs=%t but RootCAs=%v", i, tt.wantCAs, sCfg.RootCAs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewListenerUnixSocket(t *testing.T) {
|
||||
l, err := NewListener("testsocket", "unix", nil)
|
||||
if err != nil {
|
||||
t.Errorf("error listening on unix socket (%v)", err)
|
||||
}
|
||||
l.Close()
|
||||
}
|
||||
|
||||
// TestNewListenerTLSInfoSelfCert tests that a new certificate accepts connections.
|
||||
func TestNewListenerTLSInfoSelfCert(t *testing.T) {
|
||||
tmpdir, err := ioutil.TempDir(os.TempDir(), "tlsdir")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.RemoveAll(tmpdir)
|
||||
tlsinfo, err := SelfCert(zap.NewExample(), tmpdir, []string{"127.0.0.1"}, 1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if tlsinfo.Empty() {
|
||||
t.Fatalf("tlsinfo should have certs (%+v)", tlsinfo)
|
||||
}
|
||||
testNewListenerTLSInfoAccept(t, tlsinfo)
|
||||
}
|
||||
|
||||
func TestIsClosedConnError(t *testing.T) {
|
||||
l, err := NewListener("testsocket", "unix", nil)
|
||||
if err != nil {
|
||||
t.Errorf("error listening on unix socket (%v)", err)
|
||||
}
|
||||
l.Close()
|
||||
_, err = l.Accept()
|
||||
if !IsClosedConnError(err) {
|
||||
t.Fatalf("expect true, got false (%v)", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSocktOptsEmpty(t *testing.T) {
|
||||
tests := []struct {
|
||||
sopts SocketOpts
|
||||
want bool
|
||||
}{
|
||||
{SocketOpts{}, true},
|
||||
{SocketOpts{ReuseAddress: true, ReusePort: false}, false},
|
||||
{SocketOpts{ReusePort: true}, false},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
got := tt.sopts.Empty()
|
||||
if tt.want != got {
|
||||
t.Errorf("#%d: result of Empty() incorrect: want=%t got=%t", i, tt.want, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,272 +0,0 @@
|
||||
// Copyright 2017 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// tlsListener overrides a TLS listener so it will reject client
|
||||
// certificates with insufficient SAN credentials or CRL revoked
|
||||
// certificates.
|
||||
type tlsListener struct {
|
||||
net.Listener
|
||||
connc chan net.Conn
|
||||
donec chan struct{}
|
||||
err error
|
||||
handshakeFailure func(*tls.Conn, error)
|
||||
check tlsCheckFunc
|
||||
}
|
||||
|
||||
type tlsCheckFunc func(context.Context, *tls.Conn) error
|
||||
|
||||
// NewTLSListener handshakes TLS connections and performs optional CRL checking.
|
||||
func NewTLSListener(l net.Listener, tlsinfo *TLSInfo) (net.Listener, error) {
|
||||
check := func(context.Context, *tls.Conn) error { return nil }
|
||||
return newTLSListener(l, tlsinfo, check)
|
||||
}
|
||||
|
||||
func newTLSListener(l net.Listener, tlsinfo *TLSInfo, check tlsCheckFunc) (net.Listener, error) {
|
||||
if tlsinfo == nil || tlsinfo.Empty() {
|
||||
l.Close()
|
||||
return nil, fmt.Errorf("cannot listen on TLS for %s: KeyFile and CertFile are not presented", l.Addr().String())
|
||||
}
|
||||
tlscfg, err := tlsinfo.ServerConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hf := tlsinfo.HandshakeFailure
|
||||
if hf == nil {
|
||||
hf = func(*tls.Conn, error) {}
|
||||
}
|
||||
|
||||
if len(tlsinfo.CRLFile) > 0 {
|
||||
prevCheck := check
|
||||
check = func(ctx context.Context, tlsConn *tls.Conn) error {
|
||||
if err := prevCheck(ctx, tlsConn); err != nil {
|
||||
return err
|
||||
}
|
||||
st := tlsConn.ConnectionState()
|
||||
if certs := st.PeerCertificates; len(certs) > 0 {
|
||||
return checkCRL(tlsinfo.CRLFile, certs)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
tlsl := &tlsListener{
|
||||
Listener: tls.NewListener(l, tlscfg),
|
||||
connc: make(chan net.Conn),
|
||||
donec: make(chan struct{}),
|
||||
handshakeFailure: hf,
|
||||
check: check,
|
||||
}
|
||||
go tlsl.acceptLoop()
|
||||
return tlsl, nil
|
||||
}
|
||||
|
||||
func (l *tlsListener) Accept() (net.Conn, error) {
|
||||
select {
|
||||
case conn := <-l.connc:
|
||||
return conn, nil
|
||||
case <-l.donec:
|
||||
return nil, l.err
|
||||
}
|
||||
}
|
||||
|
||||
func checkSAN(ctx context.Context, tlsConn *tls.Conn) error {
|
||||
st := tlsConn.ConnectionState()
|
||||
if certs := st.PeerCertificates; len(certs) > 0 {
|
||||
addr := tlsConn.RemoteAddr().String()
|
||||
return checkCertSAN(ctx, certs[0], addr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// acceptLoop launches each TLS handshake in a separate goroutine
|
||||
// to prevent a hanging TLS connection from blocking other connections.
|
||||
func (l *tlsListener) acceptLoop() {
|
||||
var wg sync.WaitGroup
|
||||
var pendingMu sync.Mutex
|
||||
|
||||
pending := make(map[net.Conn]struct{})
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer func() {
|
||||
cancel()
|
||||
pendingMu.Lock()
|
||||
for c := range pending {
|
||||
c.Close()
|
||||
}
|
||||
pendingMu.Unlock()
|
||||
wg.Wait()
|
||||
close(l.donec)
|
||||
}()
|
||||
|
||||
for {
|
||||
conn, err := l.Listener.Accept()
|
||||
if err != nil {
|
||||
l.err = err
|
||||
return
|
||||
}
|
||||
|
||||
pendingMu.Lock()
|
||||
pending[conn] = struct{}{}
|
||||
pendingMu.Unlock()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer func() {
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
}
|
||||
wg.Done()
|
||||
}()
|
||||
|
||||
tlsConn := conn.(*tls.Conn)
|
||||
herr := tlsConn.Handshake()
|
||||
pendingMu.Lock()
|
||||
delete(pending, conn)
|
||||
pendingMu.Unlock()
|
||||
|
||||
if herr != nil {
|
||||
l.handshakeFailure(tlsConn, herr)
|
||||
return
|
||||
}
|
||||
if err := l.check(ctx, tlsConn); err != nil {
|
||||
l.handshakeFailure(tlsConn, err)
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case l.connc <- tlsConn:
|
||||
conn = nil
|
||||
case <-ctx.Done():
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
func checkCRL(crlPath string, cert []*x509.Certificate) error {
|
||||
// TODO: cache
|
||||
crlBytes, err := ioutil.ReadFile(crlPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
certList, err := x509.ParseCRL(crlBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
revokedSerials := make(map[string]struct{})
|
||||
for _, rc := range certList.TBSCertList.RevokedCertificates {
|
||||
revokedSerials[string(rc.SerialNumber.Bytes())] = struct{}{}
|
||||
}
|
||||
for _, c := range cert {
|
||||
serial := string(c.SerialNumber.Bytes())
|
||||
if _, ok := revokedSerials[serial]; ok {
|
||||
return fmt.Errorf("transport: certificate serial %x revoked", serial)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkCertSAN(ctx context.Context, cert *x509.Certificate, remoteAddr string) error {
|
||||
if len(cert.IPAddresses) == 0 && len(cert.DNSNames) == 0 {
|
||||
return nil
|
||||
}
|
||||
h, _, herr := net.SplitHostPort(remoteAddr)
|
||||
if herr != nil {
|
||||
return herr
|
||||
}
|
||||
if len(cert.IPAddresses) > 0 {
|
||||
cerr := cert.VerifyHostname(h)
|
||||
if cerr == nil {
|
||||
return nil
|
||||
}
|
||||
if len(cert.DNSNames) == 0 {
|
||||
return cerr
|
||||
}
|
||||
}
|
||||
if len(cert.DNSNames) > 0 {
|
||||
ok, err := isHostInDNS(ctx, h, cert.DNSNames)
|
||||
if ok {
|
||||
return nil
|
||||
}
|
||||
errStr := ""
|
||||
if err != nil {
|
||||
errStr = " (" + err.Error() + ")"
|
||||
}
|
||||
return fmt.Errorf("tls: %q does not match any of DNSNames %q"+errStr, h, cert.DNSNames)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func isHostInDNS(ctx context.Context, host string, dnsNames []string) (ok bool, err error) {
|
||||
// reverse lookup
|
||||
wildcards, names := []string{}, []string{}
|
||||
for _, dns := range dnsNames {
|
||||
if strings.HasPrefix(dns, "*.") {
|
||||
wildcards = append(wildcards, dns[1:])
|
||||
} else {
|
||||
names = append(names, dns)
|
||||
}
|
||||
}
|
||||
lnames, lerr := net.DefaultResolver.LookupAddr(ctx, host)
|
||||
for _, name := range lnames {
|
||||
// strip trailing '.' from PTR record
|
||||
if name[len(name)-1] == '.' {
|
||||
name = name[:len(name)-1]
|
||||
}
|
||||
for _, wc := range wildcards {
|
||||
if strings.HasSuffix(name, wc) {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
for _, n := range names {
|
||||
if n == name {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
err = lerr
|
||||
|
||||
// forward lookup
|
||||
for _, dns := range names {
|
||||
addrs, lerr := net.DefaultResolver.LookupHost(ctx, dns)
|
||||
if lerr != nil {
|
||||
err = lerr
|
||||
continue
|
||||
}
|
||||
for _, addr := range addrs {
|
||||
if addr == host {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
|
||||
func (l *tlsListener) Close() error {
|
||||
err := l.Listener.Close()
|
||||
<-l.donec
|
||||
return err
|
||||
}
|
||||
@@ -1,45 +0,0 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
)
|
||||
|
||||
type Controls []func(network, addr string, conn syscall.RawConn) error
|
||||
|
||||
func (ctls Controls) Control(network, addr string, conn syscall.RawConn) error {
|
||||
for _, s := range ctls {
|
||||
if err := s(network, addr, conn); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type SocketOpts struct {
|
||||
// ReusePort enables socket option SO_REUSEPORT [1] which allows rebind of
|
||||
// a port already in use. User should keep in mind that flock can fail
|
||||
// in which case lock on data file could result in unexpected
|
||||
// condition. User should take caution to protect against lock race.
|
||||
// [1] https://man7.org/linux/man-pages/man7/socket.7.html
|
||||
ReusePort bool
|
||||
// ReuseAddress enables a socket option SO_REUSEADDR which allows
|
||||
// binding to an address in `TIME_WAIT` state. Useful to improve MTTR
|
||||
// in cases where etcd slow to restart due to excessive `TIME_WAIT`.
|
||||
// [1] https://man7.org/linux/man-pages/man7/socket.7.html
|
||||
ReuseAddress bool
|
||||
}
|
||||
|
||||
func getControls(sopts *SocketOpts) Controls {
|
||||
ctls := Controls{}
|
||||
if sopts.ReuseAddress {
|
||||
ctls = append(ctls, setReuseAddress)
|
||||
}
|
||||
if sopts.ReusePort {
|
||||
ctls = append(ctls, setReusePort)
|
||||
}
|
||||
return ctls
|
||||
}
|
||||
|
||||
func (sopts *SocketOpts) Empty() bool {
|
||||
return !sopts.ReuseAddress && !sopts.ReusePort
|
||||
}
|
||||
@@ -1,21 +0,0 @@
|
||||
//go:build !windows
|
||||
// +build !windows
|
||||
|
||||
package transport
|
||||
|
||||
import (
|
||||
"golang.org/x/sys/unix"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func setReusePort(network, address string, conn syscall.RawConn) error {
|
||||
return conn.Control(func(fd uintptr) {
|
||||
syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, unix.SO_REUSEPORT, 1)
|
||||
})
|
||||
}
|
||||
|
||||
func setReuseAddress(network, address string, conn syscall.RawConn) error {
|
||||
return conn.Control(func(fd uintptr) {
|
||||
syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, unix.SO_REUSEADDR, 1)
|
||||
})
|
||||
}
|
||||
@@ -1,19 +0,0 @@
|
||||
//go:build windows
|
||||
// +build windows
|
||||
|
||||
package transport
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func setReusePort(network, address string, c syscall.RawConn) error {
|
||||
return fmt.Errorf("port reuse is not supported on Windows")
|
||||
}
|
||||
|
||||
// Windows supports SO_REUSEADDR, but it may cause undefined behavior, as
|
||||
// there is no protection against port hijacking.
|
||||
func setReuseAddress(network, addr string, conn syscall.RawConn) error {
|
||||
return fmt.Errorf("address reuse is not supported on Windows")
|
||||
}
|
||||
@@ -1,44 +0,0 @@
|
||||
// Copyright 2015 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package transport
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
type timeoutConn struct {
|
||||
net.Conn
|
||||
writeTimeout time.Duration
|
||||
readTimeout time.Duration
|
||||
}
|
||||
|
||||
func (c timeoutConn) Write(b []byte) (n int, err error) {
|
||||
if c.writeTimeout > 0 {
|
||||
if err := c.SetWriteDeadline(time.Now().Add(c.writeTimeout)); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return c.Conn.Write(b)
|
||||
}
|
||||
|
||||
func (c timeoutConn) Read(b []byte) (n int, err error) {
|
||||
if c.readTimeout > 0 {
|
||||
if err := c.SetReadDeadline(time.Now().Add(c.readTimeout)); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return c.Conn.Read(b)
|
||||
}
|
||||
@@ -1,36 +0,0 @@
|
||||
// Copyright 2015 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package transport
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
type rwTimeoutDialer struct {
|
||||
wtimeoutd time.Duration
|
||||
rdtimeoutd time.Duration
|
||||
net.Dialer
|
||||
}
|
||||
|
||||
func (d *rwTimeoutDialer) Dial(network, address string) (net.Conn, error) {
|
||||
conn, err := d.Dialer.Dial(network, address)
|
||||
tconn := &timeoutConn{
|
||||
readTimeout: d.rdtimeoutd,
|
||||
writeTimeout: d.wtimeoutd,
|
||||
Conn: conn,
|
||||
}
|
||||
return tconn, err
|
||||
}
|
||||
@@ -1,104 +0,0 @@
|
||||
// Copyright 2015 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package transport
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestReadWriteTimeoutDialer(t *testing.T) {
|
||||
stop := make(chan struct{})
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected listen error: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
stop <- struct{}{}
|
||||
}()
|
||||
ts := testBlockingServer{ln, 2, stop}
|
||||
go ts.Start(t)
|
||||
|
||||
d := rwTimeoutDialer{
|
||||
wtimeoutd: 10 * time.Millisecond,
|
||||
rdtimeoutd: 10 * time.Millisecond,
|
||||
}
|
||||
conn, err := d.Dial("tcp", ln.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected dial error: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// fill the socket buffer
|
||||
data := make([]byte, 5*1024*1024)
|
||||
done := make(chan struct{}, 1)
|
||||
go func() {
|
||||
_, err = conn.Write(data)
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// Wait 5s more than timeout to avoid delay in low-end systems;
|
||||
// the slack was 1s extra, but that wasn't enough for CI.
|
||||
case <-time.After(d.wtimeoutd*10 + 5*time.Second):
|
||||
t.Fatal("wait timeout")
|
||||
}
|
||||
|
||||
if operr, ok := err.(*net.OpError); !ok || operr.Op != "write" || !operr.Timeout() {
|
||||
t.Errorf("err = %v, want write i/o timeout error", err)
|
||||
}
|
||||
|
||||
conn, err = d.Dial("tcp", ln.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected dial error: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
buf := make([]byte, 10)
|
||||
go func() {
|
||||
_, err = conn.Read(buf)
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(d.rdtimeoutd * 10):
|
||||
t.Fatal("wait timeout")
|
||||
}
|
||||
|
||||
if operr, ok := err.(*net.OpError); !ok || operr.Op != "read" || !operr.Timeout() {
|
||||
t.Errorf("err = %v, want write i/o timeout error", err)
|
||||
}
|
||||
}
|
||||
|
||||
type testBlockingServer struct {
|
||||
ln net.Listener
|
||||
n int
|
||||
stop chan struct{}
|
||||
}
|
||||
|
||||
func (ts *testBlockingServer) Start(t *testing.T) {
|
||||
for i := 0; i < ts.n; i++ {
|
||||
conn, err := ts.ln.Accept()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
}
|
||||
<-ts.stop
|
||||
}
|
||||
@@ -1,45 +0,0 @@
|
||||
// Copyright 2015 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package transport
|
||||
|
||||
import (
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// NewTimeoutListener returns a listener that listens on the given address.
|
||||
// If read/write on the accepted connection blocks longer than its time limit,
|
||||
// it will return timeout error.
|
||||
func NewTimeoutListener(addr string, scheme string, tlsinfo *TLSInfo, readTimeout, writeTimeout time.Duration) (net.Listener, error) {
|
||||
return newListener(addr, scheme, WithTimeout(readTimeout, writeTimeout), WithTLSInfo(tlsinfo))
|
||||
}
|
||||
|
||||
type rwTimeoutListener struct {
|
||||
net.Listener
|
||||
writeTimeout time.Duration
|
||||
readTimeout time.Duration
|
||||
}
|
||||
|
||||
func (rwln *rwTimeoutListener) Accept() (net.Conn, error) {
|
||||
c, err := rwln.Listener.Accept()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return timeoutConn{
|
||||
Conn: c,
|
||||
writeTimeout: rwln.writeTimeout,
|
||||
readTimeout: rwln.readTimeout,
|
||||
}, nil
|
||||
}
|
||||
@@ -1,116 +0,0 @@
|
||||
// Copyright 2015 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package transport
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestNewTimeoutListener tests that NewTimeoutListener returns a
|
||||
// rwTimeoutListener struct with timeouts set.
|
||||
func TestNewTimeoutListener(t *testing.T) {
|
||||
l, err := NewTimeoutListener("127.0.0.1:0", "http", nil, time.Hour, time.Hour)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected NewTimeoutListener error: %v", err)
|
||||
}
|
||||
defer l.Close()
|
||||
tln := l.(*rwTimeoutListener)
|
||||
if tln.readTimeout != time.Hour {
|
||||
t.Errorf("read timeout = %s, want %s", tln.readTimeout, time.Hour)
|
||||
}
|
||||
if tln.writeTimeout != time.Hour {
|
||||
t.Errorf("write timeout = %s, want %s", tln.writeTimeout, time.Hour)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteReadTimeoutListener(t *testing.T) {
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected listen error: %v", err)
|
||||
}
|
||||
wln := rwTimeoutListener{
|
||||
Listener: ln,
|
||||
writeTimeout: 10 * time.Millisecond,
|
||||
readTimeout: 10 * time.Millisecond,
|
||||
}
|
||||
stop := make(chan struct{}, 1)
|
||||
|
||||
blocker := func() {
|
||||
conn, derr := net.Dial("tcp", ln.Addr().String())
|
||||
if derr != nil {
|
||||
t.Errorf("unexpected dail error: %v", derr)
|
||||
}
|
||||
defer conn.Close()
|
||||
// block the receiver until the writer timeout
|
||||
<-stop
|
||||
}
|
||||
go blocker()
|
||||
|
||||
conn, err := wln.Accept()
|
||||
if err != nil {
|
||||
stop <- struct{}{}
|
||||
t.Fatalf("unexpected accept error: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// fill the socket buffer
|
||||
data := make([]byte, 5*1024*1024)
|
||||
done := make(chan struct{}, 1)
|
||||
go func() {
|
||||
_, err = conn.Write(data)
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// It waits 1s more to avoid delay in low-end system.
|
||||
case <-time.After(wln.writeTimeout*10 + time.Second):
|
||||
stop <- struct{}{}
|
||||
t.Fatal("wait timeout")
|
||||
}
|
||||
|
||||
if operr, ok := err.(*net.OpError); !ok || operr.Op != "write" || !operr.Timeout() {
|
||||
t.Errorf("err = %v, want write i/o timeout error", err)
|
||||
}
|
||||
stop <- struct{}{}
|
||||
|
||||
go blocker()
|
||||
|
||||
conn, err = wln.Accept()
|
||||
if err != nil {
|
||||
stop <- struct{}{}
|
||||
t.Fatalf("unexpected accept error: %v", err)
|
||||
}
|
||||
buf := make([]byte, 10)
|
||||
|
||||
go func() {
|
||||
_, err = conn.Read(buf)
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(wln.readTimeout * 10):
|
||||
stop <- struct{}{}
|
||||
t.Fatal("wait timeout")
|
||||
}
|
||||
|
||||
if operr, ok := err.(*net.OpError); !ok || operr.Op != "read" || !operr.Timeout() {
|
||||
t.Errorf("err = %v, want write i/o timeout error", err)
|
||||
}
|
||||
stop <- struct{}{}
|
||||
}
|
||||
@@ -1,51 +0,0 @@
|
||||
// Copyright 2015 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package transport
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// NewTimeoutTransport returns a transport created using the given TLS info.
|
||||
// If read/write on the created connection blocks longer than its time limit,
|
||||
// it will return timeout error.
|
||||
// If read/write timeout is set, transport will not be able to reuse connection.
|
||||
func NewTimeoutTransport(info TLSInfo, dialtimeoutd, rdtimeoutd, wtimeoutd time.Duration) (*http.Transport, error) {
|
||||
tr, err := NewTransport(info, dialtimeoutd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if rdtimeoutd != 0 || wtimeoutd != 0 {
|
||||
// the timed out connection will timeout soon after it is idle.
|
||||
// it should not be put back to http transport as an idle connection for future usage.
|
||||
tr.MaxIdleConnsPerHost = -1
|
||||
} else {
|
||||
// allow more idle connections between peers to avoid unnecessary port allocation.
|
||||
tr.MaxIdleConnsPerHost = 1024
|
||||
}
|
||||
|
||||
tr.Dial = (&rwTimeoutDialer{
|
||||
Dialer: net.Dialer{
|
||||
Timeout: dialtimeoutd,
|
||||
KeepAlive: 30 * time.Second,
|
||||
},
|
||||
rdtimeoutd: rdtimeoutd,
|
||||
wtimeoutd: wtimeoutd,
|
||||
}).Dial
|
||||
return tr, nil
|
||||
}
|
||||
@@ -1,85 +0,0 @@
|
||||
// Copyright 2015 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package transport
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestNewTimeoutTransport tests that NewTimeoutTransport returns a transport
|
||||
// that can dial out timeout connections.
|
||||
func TestNewTimeoutTransport(t *testing.T) {
|
||||
tr, err := NewTimeoutTransport(TLSInfo{}, time.Hour, time.Hour, time.Hour)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected NewTimeoutTransport error: %v", err)
|
||||
}
|
||||
|
||||
remoteAddr := func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write([]byte(r.RemoteAddr))
|
||||
}
|
||||
srv := httptest.NewServer(http.HandlerFunc(remoteAddr))
|
||||
|
||||
defer srv.Close()
|
||||
conn, err := tr.Dial("tcp", srv.Listener.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected dial error: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
tconn, ok := conn.(*timeoutConn)
|
||||
if !ok {
|
||||
t.Fatalf("failed to dial out *timeoutConn")
|
||||
}
|
||||
if tconn.readTimeout != time.Hour {
|
||||
t.Errorf("read timeout = %s, want %s", tconn.readTimeout, time.Hour)
|
||||
}
|
||||
if tconn.writeTimeout != time.Hour {
|
||||
t.Errorf("write timeout = %s, want %s", tconn.writeTimeout, time.Hour)
|
||||
}
|
||||
|
||||
// ensure not reuse timeout connection
|
||||
req, err := http.NewRequest("GET", srv.URL, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected err %v", err)
|
||||
}
|
||||
resp, err := tr.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected err %v", err)
|
||||
}
|
||||
addr0, err := ioutil.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected err %v", err)
|
||||
}
|
||||
|
||||
resp, err = tr.RoundTrip(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected err %v", err)
|
||||
}
|
||||
addr1, err := ioutil.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected err %v", err)
|
||||
}
|
||||
|
||||
if bytes.Equal(addr0, addr1) {
|
||||
t.Errorf("addr0 = %s addr1= %s, want not equal", string(addr0), string(addr1))
|
||||
}
|
||||
}
|
||||
@@ -1,49 +0,0 @@
|
||||
// Copyright 2016 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package transport
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ValidateSecureEndpoints scans the given endpoints against tls info, returning only those
|
||||
// endpoints that could be validated as secure.
|
||||
func ValidateSecureEndpoints(tlsInfo TLSInfo, eps []string) ([]string, error) {
|
||||
t, err := NewTransport(tlsInfo, 5*time.Second)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var errs []string
|
||||
var endpoints []string
|
||||
for _, ep := range eps {
|
||||
if !strings.HasPrefix(ep, "https://") {
|
||||
errs = append(errs, fmt.Sprintf("%q is insecure", ep))
|
||||
continue
|
||||
}
|
||||
conn, cerr := t.Dial("tcp", ep[len("https://"):])
|
||||
if cerr != nil {
|
||||
errs = append(errs, fmt.Sprintf("%q failed to dial (%v)", ep, cerr))
|
||||
continue
|
||||
}
|
||||
conn.Close()
|
||||
endpoints = append(endpoints, ep)
|
||||
}
|
||||
if len(errs) != 0 {
|
||||
err = fmt.Errorf("%s", strings.Join(errs, ","))
|
||||
}
|
||||
return endpoints, err
|
||||
}
|
||||
@@ -1,71 +0,0 @@
|
||||
// Copyright 2016 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package transport
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type unixTransport struct{ *http.Transport }
|
||||
|
||||
func NewTransport(info TLSInfo, dialtimeoutd time.Duration) (*http.Transport, error) {
|
||||
cfg, err := info.ClientConfig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
Dial: (&net.Dialer{
|
||||
Timeout: dialtimeoutd,
|
||||
// value taken from http.DefaultTransport
|
||||
KeepAlive: 30 * time.Second,
|
||||
}).Dial,
|
||||
// value taken from http.DefaultTransport
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
TLSClientConfig: cfg,
|
||||
}
|
||||
|
||||
dialer := &net.Dialer{
|
||||
Timeout: dialtimeoutd,
|
||||
KeepAlive: 30 * time.Second,
|
||||
}
|
||||
dial := func(net, addr string) (net.Conn, error) {
|
||||
return dialer.Dial("unix", addr)
|
||||
}
|
||||
|
||||
tu := &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
Dial: dial,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
TLSClientConfig: cfg,
|
||||
}
|
||||
ut := &unixTransport{tu}
|
||||
|
||||
t.RegisterProtocol("unix", ut)
|
||||
t.RegisterProtocol("unixs", ut)
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (urt *unixTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
url := *req.URL
|
||||
req.URL = &url
|
||||
req.URL.Scheme = strings.Replace(req.URL.Scheme, "unix", "http", 1)
|
||||
return urt.Transport.RoundTrip(req)
|
||||
}
|
||||
@@ -1,74 +0,0 @@
|
||||
// Copyright 2018 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package transport
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestNewTransportTLSInvalidCipherSuitesTLS12 expects a client with invalid
|
||||
// cipher suites fail to handshake with the server.
|
||||
func TestNewTransportTLSInvalidCipherSuitesTLS12(t *testing.T) {
|
||||
tlsInfo, del, err := createSelfCert()
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create cert: %v", err)
|
||||
}
|
||||
defer del()
|
||||
|
||||
cipherSuites := []uint16{
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
||||
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
|
||||
}
|
||||
|
||||
// make server and client have unmatched cipher suites
|
||||
srvTLS, cliTLS := *tlsInfo, *tlsInfo
|
||||
srvTLS.CipherSuites, cliTLS.CipherSuites = cipherSuites[:2], cipherSuites[2:]
|
||||
|
||||
ln, err := NewListener("127.0.0.1:0", "https", &srvTLS)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected NewListener error: %v", err)
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
donec := make(chan struct{})
|
||||
go func() {
|
||||
ln.Accept()
|
||||
donec <- struct{}{}
|
||||
}()
|
||||
go func() {
|
||||
tr, err := NewTransport(cliTLS, 3*time.Second)
|
||||
tr.TLSClientConfig.MaxVersion = tls.VersionTLS12
|
||||
if err != nil {
|
||||
t.Errorf("unexpected NewTransport error: %v", err)
|
||||
}
|
||||
cli := &http.Client{Transport: tr}
|
||||
_, gerr := cli.Get("https://" + ln.Addr().String())
|
||||
if gerr == nil || !strings.Contains(gerr.Error(), "tls: handshake failure") {
|
||||
t.Error("expected client TLS handshake error")
|
||||
}
|
||||
ln.Close()
|
||||
donec <- struct{}{}
|
||||
}()
|
||||
<-donec
|
||||
<-donec
|
||||
}
|
||||
@@ -1,40 +0,0 @@
|
||||
// Copyright 2016 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package transport
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
)
|
||||
|
||||
type unixListener struct{ net.Listener }
|
||||
|
||||
func NewUnixListener(addr string) (net.Listener, error) {
|
||||
if err := os.Remove(addr); err != nil && !os.IsNotExist(err) {
|
||||
return nil, err
|
||||
}
|
||||
l, err := net.Listen("unix", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &unixListener{l}, nil
|
||||
}
|
||||
|
||||
func (ul *unixListener) Close() error {
|
||||
if err := os.Remove(ul.Addr().String()); err != nil && !os.IsNotExist(err) {
|
||||
return err
|
||||
}
|
||||
return ul.Listener.Close()
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
// Copyright 2015 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Package types declares various data types and implements type-checking
|
||||
// functions.
|
||||
package types
|
||||
@@ -1,39 +0,0 @@
|
||||
// Copyright 2015 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package types
|
||||
|
||||
import "strconv"
|
||||
|
||||
// ID represents a generic identifier which is canonically
|
||||
// stored as a uint64 but is typically represented as a
|
||||
// base-16 string for input/output
|
||||
type ID uint64
|
||||
|
||||
func (i ID) String() string {
|
||||
return strconv.FormatUint(uint64(i), 16)
|
||||
}
|
||||
|
||||
// IDFromString attempts to create an ID from a base-16 string.
|
||||
func IDFromString(s string) (ID, error) {
|
||||
i, err := strconv.ParseUint(s, 16, 64)
|
||||
return ID(i), err
|
||||
}
|
||||
|
||||
// IDSlice implements the sort interface
|
||||
type IDSlice []ID
|
||||
|
||||
func (p IDSlice) Len() int { return len(p) }
|
||||
func (p IDSlice) Less(i, j int) bool { return uint64(p[i]) < uint64(p[j]) }
|
||||
func (p IDSlice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
|
||||
@@ -1,95 +0,0 @@
|
||||
// Copyright 2015 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package types
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sort"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIDString(t *testing.T) {
|
||||
tests := []struct {
|
||||
input ID
|
||||
want string
|
||||
}{
|
||||
{
|
||||
input: 12,
|
||||
want: "c",
|
||||
},
|
||||
{
|
||||
input: 4918257920282737594,
|
||||
want: "444129853c343bba",
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
got := tt.input.String()
|
||||
if tt.want != got {
|
||||
t.Errorf("#%d: ID.String failure: want=%v, got=%v", i, tt.want, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIDFromString(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
want ID
|
||||
}{
|
||||
{
|
||||
input: "17",
|
||||
want: 23,
|
||||
},
|
||||
{
|
||||
input: "612840dae127353",
|
||||
want: 437557308098245459,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
got, err := IDFromString(tt.input)
|
||||
if err != nil {
|
||||
t.Errorf("#%d: IDFromString failure: err=%v", i, err)
|
||||
continue
|
||||
}
|
||||
if tt.want != got {
|
||||
t.Errorf("#%d: IDFromString failure: want=%v, got=%v", i, tt.want, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIDFromStringFail(t *testing.T) {
|
||||
tests := []string{
|
||||
"",
|
||||
"XXX",
|
||||
"612840dae127353612840dae127353",
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
_, err := IDFromString(tt)
|
||||
if err == nil {
|
||||
t.Fatalf("#%d: IDFromString expected error, but err=nil", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIDSlice(t *testing.T) {
|
||||
g := []ID{10, 500, 5, 1, 100, 25}
|
||||
w := []ID{1, 5, 10, 25, 100, 500}
|
||||
sort.Sort(IDSlice(g))
|
||||
if !reflect.DeepEqual(g, w) {
|
||||
t.Errorf("slice after sort = %#v, want %#v", g, w)
|
||||
}
|
||||
}
|
||||
195
pkg/types/set.go
195
pkg/types/set.go
@@ -1,195 +0,0 @@
|
||||
// Copyright 2015 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package types
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sort"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type Set interface {
|
||||
Add(string)
|
||||
Remove(string)
|
||||
Contains(string) bool
|
||||
Equals(Set) bool
|
||||
Length() int
|
||||
Values() []string
|
||||
Copy() Set
|
||||
Sub(Set) Set
|
||||
}
|
||||
|
||||
func NewUnsafeSet(values ...string) *unsafeSet {
|
||||
set := &unsafeSet{make(map[string]struct{})}
|
||||
for _, v := range values {
|
||||
set.Add(v)
|
||||
}
|
||||
return set
|
||||
}
|
||||
|
||||
func NewThreadsafeSet(values ...string) *tsafeSet {
|
||||
us := NewUnsafeSet(values...)
|
||||
return &tsafeSet{us, sync.RWMutex{}}
|
||||
}
|
||||
|
||||
type unsafeSet struct {
|
||||
d map[string]struct{}
|
||||
}
|
||||
|
||||
// Add adds a new value to the set (no-op if the value is already present)
|
||||
func (us *unsafeSet) Add(value string) {
|
||||
us.d[value] = struct{}{}
|
||||
}
|
||||
|
||||
// Remove removes the given value from the set
|
||||
func (us *unsafeSet) Remove(value string) {
|
||||
delete(us.d, value)
|
||||
}
|
||||
|
||||
// Contains returns whether the set contains the given value
|
||||
func (us *unsafeSet) Contains(value string) (exists bool) {
|
||||
_, exists = us.d[value]
|
||||
return exists
|
||||
}
|
||||
|
||||
// ContainsAll returns whether the set contains all given values
|
||||
func (us *unsafeSet) ContainsAll(values []string) bool {
|
||||
for _, s := range values {
|
||||
if !us.Contains(s) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Equals returns whether the contents of two sets are identical
|
||||
func (us *unsafeSet) Equals(other Set) bool {
|
||||
v1 := sort.StringSlice(us.Values())
|
||||
v2 := sort.StringSlice(other.Values())
|
||||
v1.Sort()
|
||||
v2.Sort()
|
||||
return reflect.DeepEqual(v1, v2)
|
||||
}
|
||||
|
||||
// Length returns the number of elements in the set
|
||||
func (us *unsafeSet) Length() int {
|
||||
return len(us.d)
|
||||
}
|
||||
|
||||
// Values returns the values of the Set in an unspecified order.
|
||||
func (us *unsafeSet) Values() (values []string) {
|
||||
values = make([]string, 0)
|
||||
for val := range us.d {
|
||||
values = append(values, val)
|
||||
}
|
||||
return values
|
||||
}
|
||||
|
||||
// Copy creates a new Set containing the values of the first
|
||||
func (us *unsafeSet) Copy() Set {
|
||||
cp := NewUnsafeSet()
|
||||
for val := range us.d {
|
||||
cp.Add(val)
|
||||
}
|
||||
|
||||
return cp
|
||||
}
|
||||
|
||||
// Sub removes all elements in other from the set
|
||||
func (us *unsafeSet) Sub(other Set) Set {
|
||||
oValues := other.Values()
|
||||
result := us.Copy().(*unsafeSet)
|
||||
|
||||
for _, val := range oValues {
|
||||
if _, ok := result.d[val]; !ok {
|
||||
continue
|
||||
}
|
||||
delete(result.d, val)
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
type tsafeSet struct {
|
||||
us *unsafeSet
|
||||
m sync.RWMutex
|
||||
}
|
||||
|
||||
func (ts *tsafeSet) Add(value string) {
|
||||
ts.m.Lock()
|
||||
defer ts.m.Unlock()
|
||||
ts.us.Add(value)
|
||||
}
|
||||
|
||||
func (ts *tsafeSet) Remove(value string) {
|
||||
ts.m.Lock()
|
||||
defer ts.m.Unlock()
|
||||
ts.us.Remove(value)
|
||||
}
|
||||
|
||||
func (ts *tsafeSet) Contains(value string) (exists bool) {
|
||||
ts.m.RLock()
|
||||
defer ts.m.RUnlock()
|
||||
return ts.us.Contains(value)
|
||||
}
|
||||
|
||||
func (ts *tsafeSet) Equals(other Set) bool {
|
||||
ts.m.RLock()
|
||||
defer ts.m.RUnlock()
|
||||
|
||||
// If ts and other represent the same variable, avoid calling
|
||||
// ts.us.Equals(other), to avoid double RLock bug
|
||||
if _other, ok := other.(*tsafeSet); ok {
|
||||
if _other == ts {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return ts.us.Equals(other)
|
||||
}
|
||||
|
||||
func (ts *tsafeSet) Length() int {
|
||||
ts.m.RLock()
|
||||
defer ts.m.RUnlock()
|
||||
return ts.us.Length()
|
||||
}
|
||||
|
||||
func (ts *tsafeSet) Values() (values []string) {
|
||||
ts.m.RLock()
|
||||
defer ts.m.RUnlock()
|
||||
return ts.us.Values()
|
||||
}
|
||||
|
||||
func (ts *tsafeSet) Copy() Set {
|
||||
ts.m.RLock()
|
||||
defer ts.m.RUnlock()
|
||||
usResult := ts.us.Copy().(*unsafeSet)
|
||||
return &tsafeSet{usResult, sync.RWMutex{}}
|
||||
}
|
||||
|
||||
func (ts *tsafeSet) Sub(other Set) Set {
|
||||
ts.m.RLock()
|
||||
defer ts.m.RUnlock()
|
||||
|
||||
// If ts and other represent the same variable, avoid calling
|
||||
// ts.us.Sub(other), to avoid double RLock bug
|
||||
if _other, ok := other.(*tsafeSet); ok {
|
||||
if _other == ts {
|
||||
usResult := NewUnsafeSet()
|
||||
return &tsafeSet{usResult, sync.RWMutex{}}
|
||||
}
|
||||
}
|
||||
usResult := ts.us.Sub(other).(*unsafeSet)
|
||||
return &tsafeSet{usResult, sync.RWMutex{}}
|
||||
}
|
||||
@@ -1,186 +0,0 @@
|
||||
// Copyright 2015 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package types
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sort"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestUnsafeSet(t *testing.T) {
|
||||
driveSetTests(t, NewUnsafeSet())
|
||||
}
|
||||
|
||||
func TestThreadsafeSet(t *testing.T) {
|
||||
driveSetTests(t, NewThreadsafeSet())
|
||||
}
|
||||
|
||||
// Check that two slices contents are equal; order is irrelevant
|
||||
func equal(a, b []string) bool {
|
||||
as := sort.StringSlice(a)
|
||||
bs := sort.StringSlice(b)
|
||||
as.Sort()
|
||||
bs.Sort()
|
||||
return reflect.DeepEqual(as, bs)
|
||||
}
|
||||
|
||||
func driveSetTests(t *testing.T, s Set) {
|
||||
// Verify operations on an empty set
|
||||
eValues := []string{}
|
||||
values := s.Values()
|
||||
if !reflect.DeepEqual(values, eValues) {
|
||||
t.Fatalf("Expect values=%v got %v", eValues, values)
|
||||
}
|
||||
if l := s.Length(); l != 0 {
|
||||
t.Fatalf("Expected length=0, got %d", l)
|
||||
}
|
||||
for _, v := range []string{"foo", "bar", "baz"} {
|
||||
if s.Contains(v) {
|
||||
t.Fatalf("Expect s.Contains(%q) to be fale, got true", v)
|
||||
}
|
||||
}
|
||||
|
||||
// Add three items, ensure they show up
|
||||
s.Add("foo")
|
||||
s.Add("bar")
|
||||
s.Add("baz")
|
||||
|
||||
eValues = []string{"foo", "bar", "baz"}
|
||||
values = s.Values()
|
||||
if !equal(values, eValues) {
|
||||
t.Fatalf("Expect values=%v got %v", eValues, values)
|
||||
}
|
||||
|
||||
for _, v := range eValues {
|
||||
if !s.Contains(v) {
|
||||
t.Fatalf("Expect s.Contains(%q) to be true, got false", v)
|
||||
}
|
||||
}
|
||||
|
||||
if l := s.Length(); l != 3 {
|
||||
t.Fatalf("Expected length=3, got %d", l)
|
||||
}
|
||||
|
||||
// Add the same item a second time, ensuring it is not duplicated
|
||||
s.Add("foo")
|
||||
|
||||
values = s.Values()
|
||||
if !equal(values, eValues) {
|
||||
t.Fatalf("Expect values=%v got %v", eValues, values)
|
||||
}
|
||||
if l := s.Length(); l != 3 {
|
||||
t.Fatalf("Expected length=3, got %d", l)
|
||||
}
|
||||
|
||||
// Remove all items, ensure they are gone
|
||||
s.Remove("foo")
|
||||
s.Remove("bar")
|
||||
s.Remove("baz")
|
||||
|
||||
eValues = []string{}
|
||||
values = s.Values()
|
||||
if !equal(values, eValues) {
|
||||
t.Fatalf("Expect values=%v got %v", eValues, values)
|
||||
}
|
||||
|
||||
if l := s.Length(); l != 0 {
|
||||
t.Fatalf("Expected length=0, got %d", l)
|
||||
}
|
||||
|
||||
// Create new copies of the set, and ensure they are unlinked to the
|
||||
// original Set by making modifications
|
||||
s.Add("foo")
|
||||
s.Add("bar")
|
||||
cp1 := s.Copy()
|
||||
cp2 := s.Copy()
|
||||
s.Remove("foo")
|
||||
cp3 := s.Copy()
|
||||
cp1.Add("baz")
|
||||
|
||||
for i, tt := range []struct {
|
||||
want []string
|
||||
got []string
|
||||
}{
|
||||
{[]string{"bar"}, s.Values()},
|
||||
{[]string{"foo", "bar", "baz"}, cp1.Values()},
|
||||
{[]string{"foo", "bar"}, cp2.Values()},
|
||||
{[]string{"bar"}, cp3.Values()},
|
||||
} {
|
||||
if !equal(tt.want, tt.got) {
|
||||
t.Fatalf("case %d: expect values=%v got %v", i, tt.want, tt.got)
|
||||
}
|
||||
}
|
||||
|
||||
for i, tt := range []struct {
|
||||
want bool
|
||||
got bool
|
||||
}{
|
||||
{true, s.Equals(cp3)},
|
||||
{true, cp3.Equals(s)},
|
||||
{false, s.Equals(cp2)},
|
||||
{false, s.Equals(cp1)},
|
||||
{false, cp1.Equals(s)},
|
||||
{false, cp2.Equals(s)},
|
||||
{false, cp2.Equals(cp1)},
|
||||
} {
|
||||
if tt.got != tt.want {
|
||||
t.Fatalf("case %d: want %t, got %t", i, tt.want, tt.got)
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
// Subtract values from a Set, ensuring a new Set is created and
|
||||
// the original Sets are unmodified
|
||||
sub1 := cp1.Sub(s)
|
||||
sub2 := cp2.Sub(cp1)
|
||||
|
||||
for i, tt := range []struct {
|
||||
want []string
|
||||
got []string
|
||||
}{
|
||||
{[]string{"foo", "bar", "baz"}, cp1.Values()},
|
||||
{[]string{"foo", "bar"}, cp2.Values()},
|
||||
{[]string{"bar"}, s.Values()},
|
||||
{[]string{"foo", "baz"}, sub1.Values()},
|
||||
{[]string{}, sub2.Values()},
|
||||
} {
|
||||
if !equal(tt.want, tt.got) {
|
||||
t.Fatalf("case %d: expect values=%v got %v", i, tt.want, tt.got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnsafeSetContainsAll(t *testing.T) {
|
||||
vals := []string{"foo", "bar", "baz"}
|
||||
s := NewUnsafeSet(vals...)
|
||||
|
||||
tests := []struct {
|
||||
strs []string
|
||||
wcontain bool
|
||||
}{
|
||||
{[]string{}, true},
|
||||
{vals[:1], true},
|
||||
{vals[:2], true},
|
||||
{vals, true},
|
||||
{[]string{"cuz"}, false},
|
||||
{[]string{vals[0], "cuz"}, false},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
if g := s.ContainsAll(tt.strs); g != tt.wcontain {
|
||||
t.Errorf("#%d: ok = %v, want %v", i, g, tt.wcontain)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,22 +0,0 @@
|
||||
// Copyright 2015 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package types
|
||||
|
||||
// Uint64Slice implements sort interface
|
||||
type Uint64Slice []uint64
|
||||
|
||||
func (p Uint64Slice) Len() int { return len(p) }
|
||||
func (p Uint64Slice) Less(i, j int) bool { return p[i] < p[j] }
|
||||
func (p Uint64Slice) Swap(i, j int) { p[i], p[j] = p[j], p[i] }
|
||||
@@ -1,30 +0,0 @@
|
||||
// Copyright 2015 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package types
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sort"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestUint64Slice(t *testing.T) {
|
||||
g := Uint64Slice{10, 500, 5, 1, 100, 25}
|
||||
w := Uint64Slice{1, 5, 10, 25, 100, 500}
|
||||
sort.Sort(g)
|
||||
if !reflect.DeepEqual(g, w) {
|
||||
t.Errorf("slice after sort = %#v, want %#v", g, w)
|
||||
}
|
||||
}
|
||||
@@ -1,82 +0,0 @@
|
||||
// Copyright 2015 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package types
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type URLs []url.URL
|
||||
|
||||
func NewURLs(strs []string) (URLs, error) {
|
||||
all := make([]url.URL, len(strs))
|
||||
if len(all) == 0 {
|
||||
return nil, errors.New("no valid URLs given")
|
||||
}
|
||||
for i, in := range strs {
|
||||
in = strings.TrimSpace(in)
|
||||
u, err := url.Parse(in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if u.Scheme != "http" && u.Scheme != "https" && u.Scheme != "unix" && u.Scheme != "unixs" {
|
||||
return nil, fmt.Errorf("URL scheme must be http, https, unix, or unixs: %s", in)
|
||||
}
|
||||
if _, _, err := net.SplitHostPort(u.Host); err != nil {
|
||||
return nil, fmt.Errorf(`URL address does not have the form "host:port": %s`, in)
|
||||
}
|
||||
if u.Path != "" {
|
||||
return nil, fmt.Errorf("URL must not contain a path: %s", in)
|
||||
}
|
||||
all[i] = *u
|
||||
}
|
||||
us := URLs(all)
|
||||
us.Sort()
|
||||
|
||||
return us, nil
|
||||
}
|
||||
|
||||
func MustNewURLs(strs []string) URLs {
|
||||
urls, err := NewURLs(strs)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return urls
|
||||
}
|
||||
|
||||
func (us URLs) String() string {
|
||||
return strings.Join(us.StringSlice(), ",")
|
||||
}
|
||||
|
||||
func (us *URLs) Sort() {
|
||||
sort.Sort(us)
|
||||
}
|
||||
func (us URLs) Len() int { return len(us) }
|
||||
func (us URLs) Less(i, j int) bool { return us[i].String() < us[j].String() }
|
||||
func (us URLs) Swap(i, j int) { us[i], us[j] = us[j], us[i] }
|
||||
|
||||
func (us URLs) StringSlice() []string {
|
||||
out := make([]string, len(us))
|
||||
for i := range us {
|
||||
out[i] = us[i].String()
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
@@ -1,169 +0,0 @@
|
||||
// Copyright 2015 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package types
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"go.etcd.io/etcd/pkg/v3/testutil"
|
||||
)
|
||||
|
||||
func TestNewURLs(t *testing.T) {
|
||||
tests := []struct {
|
||||
strs []string
|
||||
wurls URLs
|
||||
}{
|
||||
{
|
||||
[]string{"http://127.0.0.1:2379"},
|
||||
testutil.MustNewURLs(t, []string{"http://127.0.0.1:2379"}),
|
||||
},
|
||||
// it can trim space
|
||||
{
|
||||
[]string{" http://127.0.0.1:2379 "},
|
||||
testutil.MustNewURLs(t, []string{"http://127.0.0.1:2379"}),
|
||||
},
|
||||
// it does sort
|
||||
{
|
||||
[]string{
|
||||
"http://127.0.0.2:2379",
|
||||
"http://127.0.0.1:2379",
|
||||
},
|
||||
testutil.MustNewURLs(t, []string{
|
||||
"http://127.0.0.1:2379",
|
||||
"http://127.0.0.2:2379",
|
||||
}),
|
||||
},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
urls, _ := NewURLs(tt.strs)
|
||||
if !reflect.DeepEqual(urls, tt.wurls) {
|
||||
t.Errorf("#%d: urls = %+v, want %+v", i, urls, tt.wurls)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestURLsString(t *testing.T) {
|
||||
tests := []struct {
|
||||
us URLs
|
||||
wstr string
|
||||
}{
|
||||
{
|
||||
URLs{},
|
||||
"",
|
||||
},
|
||||
{
|
||||
testutil.MustNewURLs(t, []string{"http://127.0.0.1:2379"}),
|
||||
"http://127.0.0.1:2379",
|
||||
},
|
||||
{
|
||||
testutil.MustNewURLs(t, []string{
|
||||
"http://127.0.0.1:2379",
|
||||
"http://127.0.0.2:2379",
|
||||
}),
|
||||
"http://127.0.0.1:2379,http://127.0.0.2:2379",
|
||||
},
|
||||
{
|
||||
testutil.MustNewURLs(t, []string{
|
||||
"http://127.0.0.2:2379",
|
||||
"http://127.0.0.1:2379",
|
||||
}),
|
||||
"http://127.0.0.2:2379,http://127.0.0.1:2379",
|
||||
},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
g := tt.us.String()
|
||||
if g != tt.wstr {
|
||||
t.Errorf("#%d: string = %s, want %s", i, g, tt.wstr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestURLsSort(t *testing.T) {
|
||||
g := testutil.MustNewURLs(t, []string{
|
||||
"http://127.0.0.4:2379",
|
||||
"http://127.0.0.2:2379",
|
||||
"http://127.0.0.1:2379",
|
||||
"http://127.0.0.3:2379",
|
||||
})
|
||||
w := testutil.MustNewURLs(t, []string{
|
||||
"http://127.0.0.1:2379",
|
||||
"http://127.0.0.2:2379",
|
||||
"http://127.0.0.3:2379",
|
||||
"http://127.0.0.4:2379",
|
||||
})
|
||||
gurls := URLs(g)
|
||||
gurls.Sort()
|
||||
if !reflect.DeepEqual(g, w) {
|
||||
t.Errorf("URLs after sort = %#v, want %#v", g, w)
|
||||
}
|
||||
}
|
||||
|
||||
func TestURLsStringSlice(t *testing.T) {
|
||||
tests := []struct {
|
||||
us URLs
|
||||
wstr []string
|
||||
}{
|
||||
{
|
||||
URLs{},
|
||||
[]string{},
|
||||
},
|
||||
{
|
||||
testutil.MustNewURLs(t, []string{"http://127.0.0.1:2379"}),
|
||||
[]string{"http://127.0.0.1:2379"},
|
||||
},
|
||||
{
|
||||
testutil.MustNewURLs(t, []string{
|
||||
"http://127.0.0.1:2379",
|
||||
"http://127.0.0.2:2379",
|
||||
}),
|
||||
[]string{"http://127.0.0.1:2379", "http://127.0.0.2:2379"},
|
||||
},
|
||||
{
|
||||
testutil.MustNewURLs(t, []string{
|
||||
"http://127.0.0.2:2379",
|
||||
"http://127.0.0.1:2379",
|
||||
}),
|
||||
[]string{"http://127.0.0.2:2379", "http://127.0.0.1:2379"},
|
||||
},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
g := tt.us.StringSlice()
|
||||
if !reflect.DeepEqual(g, tt.wstr) {
|
||||
t.Errorf("#%d: string slice = %+v, want %+v", i, g, tt.wstr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewURLsFail(t *testing.T) {
|
||||
tests := [][]string{
|
||||
// no urls given
|
||||
{},
|
||||
// missing protocol scheme
|
||||
{"://127.0.0.1:2379"},
|
||||
// unsupported scheme
|
||||
{"mailto://127.0.0.1:2379"},
|
||||
// not conform to host:port
|
||||
{"http://127.0.0.1"},
|
||||
// contain a path
|
||||
{"http://127.0.0.1:2379/path"},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
_, err := NewURLs(tt)
|
||||
if err == nil {
|
||||
t.Errorf("#%d: err = nil, but error", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,107 +0,0 @@
|
||||
// Copyright 2015 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package types
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// URLsMap is a map from a name to its URLs.
|
||||
type URLsMap map[string]URLs
|
||||
|
||||
// NewURLsMap returns a URLsMap instantiated from the given string,
|
||||
// which consists of discovery-formatted names-to-URLs, like:
|
||||
// mach0=http://1.1.1.1:2380,mach0=http://2.2.2.2::2380,mach1=http://3.3.3.3:2380,mach2=http://4.4.4.4:2380
|
||||
func NewURLsMap(s string) (URLsMap, error) {
|
||||
m := parse(s)
|
||||
|
||||
cl := URLsMap{}
|
||||
for name, urls := range m {
|
||||
us, err := NewURLs(urls)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cl[name] = us
|
||||
}
|
||||
return cl, nil
|
||||
}
|
||||
|
||||
// NewURLsMapFromStringMap takes a map of strings and returns a URLsMap. The
|
||||
// string values in the map can be multiple values separated by the sep string.
|
||||
func NewURLsMapFromStringMap(m map[string]string, sep string) (URLsMap, error) {
|
||||
var err error
|
||||
um := URLsMap{}
|
||||
for k, v := range m {
|
||||
um[k], err = NewURLs(strings.Split(v, sep))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return um, nil
|
||||
}
|
||||
|
||||
// String turns URLsMap into discovery-formatted name-to-URLs sorted by name.
|
||||
func (c URLsMap) String() string {
|
||||
var pairs []string
|
||||
for name, urls := range c {
|
||||
for _, url := range urls {
|
||||
pairs = append(pairs, fmt.Sprintf("%s=%s", name, url.String()))
|
||||
}
|
||||
}
|
||||
sort.Strings(pairs)
|
||||
return strings.Join(pairs, ",")
|
||||
}
|
||||
|
||||
// URLs returns a list of all URLs.
|
||||
// The returned list is sorted in ascending lexicographical order.
|
||||
func (c URLsMap) URLs() []string {
|
||||
var urls []string
|
||||
for _, us := range c {
|
||||
for _, u := range us {
|
||||
urls = append(urls, u.String())
|
||||
}
|
||||
}
|
||||
sort.Strings(urls)
|
||||
return urls
|
||||
}
|
||||
|
||||
// Len returns the size of URLsMap.
|
||||
func (c URLsMap) Len() int {
|
||||
return len(c)
|
||||
}
|
||||
|
||||
// parse parses the given string and returns a map listing the values specified for each key.
|
||||
func parse(s string) map[string][]string {
|
||||
m := make(map[string][]string)
|
||||
for s != "" {
|
||||
key := s
|
||||
if i := strings.IndexAny(key, ","); i >= 0 {
|
||||
key, s = key[:i], key[i+1:]
|
||||
} else {
|
||||
s = ""
|
||||
}
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
value := ""
|
||||
if i := strings.Index(key, "="); i >= 0 {
|
||||
key, value = key[:i], key[i+1:]
|
||||
}
|
||||
m[key] = append(m[key], value)
|
||||
}
|
||||
return m
|
||||
}
|
||||
@@ -1,155 +0,0 @@
|
||||
// Copyright 2015 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package types
|
||||
|
||||
import (
|
||||
"go.etcd.io/etcd/pkg/v3/testutil"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseInitialCluster(t *testing.T) {
|
||||
c, err := NewURLsMap("mem1=http://10.0.0.1:2379,mem1=http://128.193.4.20:2379,mem2=http://10.0.0.2:2379,default=http://127.0.0.1:2379")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected parse error: %v", err)
|
||||
}
|
||||
wc := URLsMap(map[string]URLs{
|
||||
"mem1": testutil.MustNewURLs(t, []string{"http://10.0.0.1:2379", "http://128.193.4.20:2379"}),
|
||||
"mem2": testutil.MustNewURLs(t, []string{"http://10.0.0.2:2379"}),
|
||||
"default": testutil.MustNewURLs(t, []string{"http://127.0.0.1:2379"}),
|
||||
})
|
||||
if !reflect.DeepEqual(c, wc) {
|
||||
t.Errorf("cluster = %+v, want %+v", c, wc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseInitialClusterBad(t *testing.T) {
|
||||
tests := []string{
|
||||
// invalid URL
|
||||
"%^",
|
||||
// no URL defined for member
|
||||
"mem1=,mem2=http://128.193.4.20:2379,mem3=http://10.0.0.2:2379",
|
||||
"mem1,mem2=http://128.193.4.20:2379,mem3=http://10.0.0.2:2379",
|
||||
// bad URL for member
|
||||
"default=http://localhost/",
|
||||
}
|
||||
for i, tt := range tests {
|
||||
if _, err := NewURLsMap(tt); err == nil {
|
||||
t.Errorf("#%d: unexpected successful parse, want err", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNameURLPairsString(t *testing.T) {
|
||||
cls := URLsMap(map[string]URLs{
|
||||
"abc": testutil.MustNewURLs(t, []string{"http://1.1.1.1:1111", "http://0.0.0.0:0000"}),
|
||||
"def": testutil.MustNewURLs(t, []string{"http://2.2.2.2:2222"}),
|
||||
"ghi": testutil.MustNewURLs(t, []string{"http://3.3.3.3:1234", "http://127.0.0.1:2380"}),
|
||||
// no PeerURLs = not included
|
||||
"four": testutil.MustNewURLs(t, []string{}),
|
||||
"five": testutil.MustNewURLs(t, nil),
|
||||
})
|
||||
w := "abc=http://0.0.0.0:0000,abc=http://1.1.1.1:1111,def=http://2.2.2.2:2222,ghi=http://127.0.0.1:2380,ghi=http://3.3.3.3:1234"
|
||||
if g := cls.String(); g != w {
|
||||
t.Fatalf("NameURLPairs.String():\ngot %#v\nwant %#v", g, w)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParse(t *testing.T) {
|
||||
tests := []struct {
|
||||
s string
|
||||
wm map[string][]string
|
||||
}{
|
||||
{
|
||||
"",
|
||||
map[string][]string{},
|
||||
},
|
||||
{
|
||||
"a=b",
|
||||
map[string][]string{"a": {"b"}},
|
||||
},
|
||||
{
|
||||
"a=b,a=c",
|
||||
map[string][]string{"a": {"b", "c"}},
|
||||
},
|
||||
{
|
||||
"a=b,a1=c",
|
||||
map[string][]string{"a": {"b"}, "a1": {"c"}},
|
||||
},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
m := parse(tt.s)
|
||||
if !reflect.DeepEqual(m, tt.wm) {
|
||||
t.Errorf("#%d: m = %+v, want %+v", i, m, tt.wm)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewURLsMapIPV6 is only tested in Go1.5+ because Go1.4 doesn't support literal IPv6 address with zone in
|
||||
// URI (https://github.com/golang/go/issues/6530).
|
||||
func TestNewURLsMapIPV6(t *testing.T) {
|
||||
c, err := NewURLsMap("mem1=http://[2001:db8::1]:2380,mem1=http://[fe80::6e40:8ff:feb1:58e4%25en0]:2380,mem2=http://[fe80::92e2:baff:fe7c:3224%25ext0]:2380")
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected parse error: %v", err)
|
||||
}
|
||||
wc := URLsMap(map[string]URLs{
|
||||
"mem1": testutil.MustNewURLs(t, []string{"http://[2001:db8::1]:2380", "http://[fe80::6e40:8ff:feb1:58e4%25en0]:2380"}),
|
||||
"mem2": testutil.MustNewURLs(t, []string{"http://[fe80::92e2:baff:fe7c:3224%25ext0]:2380"}),
|
||||
})
|
||||
if !reflect.DeepEqual(c, wc) {
|
||||
t.Errorf("cluster = %#v, want %#v", c, wc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewURLsMapFromStringMapEmpty(t *testing.T) {
|
||||
mss := make(map[string]string)
|
||||
urlsMap, err := NewURLsMapFromStringMap(mss, ",")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
s := ""
|
||||
um, err := NewURLsMap(s)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if um.String() != urlsMap.String() {
|
||||
t.Errorf("Expected:\n%+v\ngot:\n%+v", um, urlsMap)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewURLsMapFromStringMapNormal(t *testing.T) {
|
||||
mss := make(map[string]string)
|
||||
mss["host0"] = "http://127.0.0.1:2379,http://127.0.0.1:2380"
|
||||
mss["host1"] = "http://127.0.0.1:2381,http://127.0.0.1:2382"
|
||||
mss["host2"] = "http://127.0.0.1:2383,http://127.0.0.1:2384"
|
||||
mss["host3"] = "http://127.0.0.1:2385,http://127.0.0.1:2386"
|
||||
urlsMap, err := NewURLsMapFromStringMap(mss, ",")
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
s := "host0=http://127.0.0.1:2379,host0=http://127.0.0.1:2380," +
|
||||
"host1=http://127.0.0.1:2381,host1=http://127.0.0.1:2382," +
|
||||
"host2=http://127.0.0.1:2383,host2=http://127.0.0.1:2384," +
|
||||
"host3=http://127.0.0.1:2385,host3=http://127.0.0.1:2386"
|
||||
um, err := NewURLsMap(s)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if um.String() != urlsMap.String() {
|
||||
t.Errorf("Expected:\n%+v\ngot:\n%+v", um, urlsMap)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user