Merge pull request #3164 from yichengq/pin-endpoint

client: pin itself to an endpoint that given
This commit is contained in:
Yicheng Qin 2015-07-27 14:35:51 -07:00
commit 6184e271a4
2 changed files with 78 additions and 12 deletions

View File

@ -18,6 +18,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"math/rand"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
@ -131,6 +132,7 @@ type Client interface {
func New(cfg Config) (Client, error) { func New(cfg Config) (Client, error) {
c := &httpClusterClient{ c := &httpClusterClient{
clientFactory: newHTTPClientFactory(cfg.transport(), cfg.checkRedirect()), clientFactory: newHTTPClientFactory(cfg.transport(), cfg.checkRedirect()),
rand: rand.New(rand.NewSource(int64(time.Now().Nanosecond()))),
} }
if cfg.Username != "" { if cfg.Username != "" {
c.credentials = &credentials{ c.credentials = &credentials{
@ -174,8 +176,10 @@ type httpAction interface {
type httpClusterClient struct { type httpClusterClient struct {
clientFactory httpClientFactory clientFactory httpClientFactory
endpoints []url.URL endpoints []url.URL
pinned int
credentials *credentials credentials *credentials
sync.RWMutex sync.RWMutex
rand *rand.Rand
} }
func (c *httpClusterClient) reset(eps []string) error { func (c *httpClusterClient) reset(eps []string) error {
@ -192,7 +196,9 @@ func (c *httpClusterClient) reset(eps []string) error {
neps[i] = *u neps[i] = *u
} }
c.endpoints = neps c.endpoints = shuffleEndpoints(c.rand, neps)
// TODO: pin old endpoint if possible, and rebalance when new endpoint appears
c.pinned = 0
return nil return nil
} }
@ -203,6 +209,7 @@ func (c *httpClusterClient) Do(ctx context.Context, act httpAction) (*http.Respo
leps := len(c.endpoints) leps := len(c.endpoints)
eps := make([]url.URL, leps) eps := make([]url.URL, leps)
n := copy(eps, c.endpoints) n := copy(eps, c.endpoints)
pinned := c.pinned
if c.credentials != nil { if c.credentials != nil {
action = &authedAction{ action = &authedAction{
@ -224,8 +231,9 @@ func (c *httpClusterClient) Do(ctx context.Context, act httpAction) (*http.Respo
var body []byte var body []byte
var err error var err error
for _, ep := range eps { for i := pinned; i < leps+pinned; i++ {
hc := c.clientFactory(ep) k := i % leps
hc := c.clientFactory(eps[k])
resp, body, err = hc.Do(ctx, action) resp, body, err = hc.Do(ctx, action)
if err != nil { if err != nil {
if err == context.DeadlineExceeded || err == context.Canceled { if err == context.DeadlineExceeded || err == context.Canceled {
@ -236,6 +244,11 @@ func (c *httpClusterClient) Do(ctx context.Context, act httpAction) (*http.Respo
if resp.StatusCode/100 == 5 { if resp.StatusCode/100 == 5 {
continue continue
} }
if k != pinned {
c.Lock()
c.pinned = k
c.Unlock()
}
break break
} }
@ -401,3 +414,12 @@ func (r *redirectedHTTPAction) HTTPRequest(ep url.URL) *http.Request {
orig.URL = &r.location orig.URL = &r.location
return orig return orig
} }
func shuffleEndpoints(r *rand.Rand, eps []url.URL) []url.URL {
p := r.Perm(len(eps))
neps := make([]url.URL, len(eps))
for i, k := range p {
neps[i] = eps[k]
}
return neps
}

View File

@ -18,9 +18,11 @@ import (
"errors" "errors"
"io" "io"
"io/ioutil" "io/ioutil"
"math/rand"
"net/http" "net/http"
"net/url" "net/url"
"reflect" "reflect"
"sort"
"strings" "strings"
"testing" "testing"
"time" "time"
@ -299,9 +301,10 @@ func TestHTTPClusterClientDo(t *testing.T) {
fakeErr := errors.New("fake!") fakeErr := errors.New("fake!")
fakeURL := url.URL{} fakeURL := url.URL{}
tests := []struct { tests := []struct {
client *httpClusterClient client *httpClusterClient
wantCode int wantCode int
wantErr error wantErr error
wantPinned int
}{ }{
// first good response short-circuits Do // first good response short-circuits Do
{ {
@ -313,6 +316,7 @@ func TestHTTPClusterClientDo(t *testing.T) {
staticHTTPResponse{err: fakeErr}, staticHTTPResponse{err: fakeErr},
}, },
), ),
rand: rand.New(rand.NewSource(0)),
}, },
wantCode: http.StatusTeapot, wantCode: http.StatusTeapot,
}, },
@ -327,8 +331,10 @@ func TestHTTPClusterClientDo(t *testing.T) {
staticHTTPResponse{resp: http.Response{StatusCode: http.StatusTeapot}}, staticHTTPResponse{resp: http.Response{StatusCode: http.StatusTeapot}},
}, },
), ),
rand: rand.New(rand.NewSource(0)),
}, },
wantCode: http.StatusTeapot, wantCode: http.StatusTeapot,
wantPinned: 1,
}, },
// context.DeadlineExceeded short-circuits Do // context.DeadlineExceeded short-circuits Do
@ -341,6 +347,7 @@ func TestHTTPClusterClientDo(t *testing.T) {
staticHTTPResponse{resp: http.Response{StatusCode: http.StatusTeapot}}, staticHTTPResponse{resp: http.Response{StatusCode: http.StatusTeapot}},
}, },
), ),
rand: rand.New(rand.NewSource(0)),
}, },
wantErr: context.DeadlineExceeded, wantErr: context.DeadlineExceeded,
}, },
@ -355,6 +362,7 @@ func TestHTTPClusterClientDo(t *testing.T) {
staticHTTPResponse{resp: http.Response{StatusCode: http.StatusTeapot}}, staticHTTPResponse{resp: http.Response{StatusCode: http.StatusTeapot}},
}, },
), ),
rand: rand.New(rand.NewSource(0)),
}, },
wantErr: context.Canceled, wantErr: context.Canceled,
}, },
@ -364,6 +372,7 @@ func TestHTTPClusterClientDo(t *testing.T) {
client: &httpClusterClient{ client: &httpClusterClient{
endpoints: []url.URL{}, endpoints: []url.URL{},
clientFactory: newHTTPClientFactory(nil, nil), clientFactory: newHTTPClientFactory(nil, nil),
rand: rand.New(rand.NewSource(0)),
}, },
wantErr: ErrNoEndpoints, wantErr: ErrNoEndpoints,
}, },
@ -378,6 +387,7 @@ func TestHTTPClusterClientDo(t *testing.T) {
staticHTTPResponse{err: fakeErr}, staticHTTPResponse{err: fakeErr},
}, },
), ),
rand: rand.New(rand.NewSource(0)),
}, },
wantErr: fakeErr, wantErr: fakeErr,
}, },
@ -392,8 +402,10 @@ func TestHTTPClusterClientDo(t *testing.T) {
staticHTTPResponse{resp: http.Response{StatusCode: http.StatusTeapot}}, staticHTTPResponse{resp: http.Response{StatusCode: http.StatusTeapot}},
}, },
), ),
rand: rand.New(rand.NewSource(0)),
}, },
wantCode: http.StatusTeapot, wantCode: http.StatusTeapot,
wantPinned: 1,
}, },
} }
@ -415,6 +427,10 @@ func TestHTTPClusterClientDo(t *testing.T) {
t.Errorf("#%d: resp code=%d, want=%d", i, resp.StatusCode, tt.wantCode) t.Errorf("#%d: resp code=%d, want=%d", i, resp.StatusCode, tt.wantCode)
continue continue
} }
if tt.client.pinned != tt.wantPinned {
t.Errorf("#%d: pinned=%d, want=%d", i, tt.client.pinned, tt.wantPinned)
}
} }
} }
@ -671,7 +687,10 @@ func TestHTTPClusterClientSync(t *testing.T) {
}, },
}) })
hc := &httpClusterClient{clientFactory: cf} hc := &httpClusterClient{
clientFactory: cf,
rand: rand.New(rand.NewSource(0)),
}
err := hc.reset([]string{"http://127.0.0.1:2379"}) err := hc.reset([]string{"http://127.0.0.1:2379"})
if err != nil { if err != nil {
t.Fatalf("unexpected error during setup: %#v", err) t.Fatalf("unexpected error during setup: %#v", err)
@ -688,8 +707,9 @@ func TestHTTPClusterClientSync(t *testing.T) {
t.Fatalf("unexpected error during Sync: %#v", err) t.Fatalf("unexpected error during Sync: %#v", err)
} }
want = []string{"http://127.0.0.1:4003", "http://127.0.0.1:2379", "http://127.0.0.1:4001", "http://127.0.0.1:4002"} want = []string{"http://127.0.0.1:2379", "http://127.0.0.1:4001", "http://127.0.0.1:4002", "http://127.0.0.1:4003"}
got = hc.Endpoints() got = hc.Endpoints()
sort.Sort(sort.StringSlice(got))
if !reflect.DeepEqual(want, got) { if !reflect.DeepEqual(want, got) {
t.Fatalf("incorrect endpoints post-Sync: want=%#v got=%#v", want, got) t.Fatalf("incorrect endpoints post-Sync: want=%#v got=%#v", want, got)
} }
@ -711,7 +731,10 @@ func TestHTTPClusterClientSyncFail(t *testing.T) {
staticHTTPResponse{err: errors.New("fail!")}, staticHTTPResponse{err: errors.New("fail!")},
}) })
hc := &httpClusterClient{clientFactory: cf} hc := &httpClusterClient{
clientFactory: cf,
rand: rand.New(rand.NewSource(0)),
}
err := hc.reset([]string{"http://127.0.0.1:2379"}) err := hc.reset([]string{"http://127.0.0.1:2379"})
if err != nil { if err != nil {
t.Fatalf("unexpected error during setup: %#v", err) t.Fatalf("unexpected error during setup: %#v", err)
@ -744,10 +767,31 @@ func TestHTTPClusterClientResetFail(t *testing.T) {
} }
for i, tt := range tests { for i, tt := range tests {
hc := &httpClusterClient{} hc := &httpClusterClient{rand: rand.New(rand.NewSource(0))}
err := hc.reset(tt) err := hc.reset(tt)
if err == nil { if err == nil {
t.Errorf("#%d: expected non-nil error", i) t.Errorf("#%d: expected non-nil error", i)
} }
} }
} }
func TestHTTPClusterClientResetPinRandom(t *testing.T) {
round := 2000
pinNum := 0
for i := 0; i < round; i++ {
hc := &httpClusterClient{rand: rand.New(rand.NewSource(int64(i)))}
err := hc.reset([]string{"http://127.0.0.1:4001", "http://127.0.0.1:4002", "http://127.0.0.1:4003"})
if err != nil {
t.Fatalf("#%d: reset error (%v)", i, err)
}
if hc.endpoints[hc.pinned].String() == "http://127.0.0.1:4001" {
pinNum++
}
}
min := 1.0/3.0 - 0.05
max := 1.0/3.0 + 0.05
if ratio := float64(pinNum) / float64(round); ratio > max || ratio < min {
t.Errorf("pinned ratio = %v, want [%v, %v]", ratio, min, max)
}
}