proxy: rewrite stdlib ReverseProxy

The ReverseProxy code from the standard library doesn't actually
give us the control that we want. Pull it down and rip out what
we don't need, adding tests in the process.

All available endpoints are attempted when proxying a request. If a
proxied request fails, the upstream will be considered unavailable
for 5s and no more requests will be proxied to it. After the 5s is
up, the endpoint will be put back to rotation.
This commit is contained in:
Brian Waldon 2014-09-10 16:38:52 -07:00
parent df253a2b14
commit a155f0bda6
6 changed files with 531 additions and 119 deletions

106
proxy/director.go Normal file
View File

@ -0,0 +1,106 @@
package proxy
import (
"errors"
"fmt"
"log"
"net/url"
"sync"
"time"
)
const (
// amount of time an endpoint will be held in a failed
// state before being reconsidered for proxied requests
endpointFailureWait = 5 * time.Second
)
func newDirector(urls []string) (*director, error) {
if len(urls) == 0 {
return nil, errors.New("one or more endpoints required")
}
endpoints := make([]*endpoint, len(urls))
for i, v := range urls {
u, err := url.Parse(v)
if err != nil {
return nil, fmt.Errorf("invalid endpoint %q: %v", v, err)
}
if u.Scheme == "" {
return nil, fmt.Errorf("invalid endpoint %q: scheme required", v)
}
if u.Host == "" {
return nil, fmt.Errorf("invalid endpoint %q: host empty", v)
}
endpoints[i] = newEndpoint(*u)
}
d := director{ep: endpoints}
return &d, nil
}
type director struct {
ep []*endpoint
}
func (d *director) endpoints() []*endpoint {
filtered := make([]*endpoint, 0)
for _, ep := range d.ep {
if ep.Available {
filtered = append(filtered, ep)
}
}
return filtered
}
func newEndpoint(u url.URL) *endpoint {
ep := endpoint{
URL: u,
Available: true,
failFunc: timedUnavailabilityFunc(endpointFailureWait),
}
return &ep
}
type endpoint struct {
sync.Mutex
URL url.URL
Available bool
failFunc func(ep *endpoint)
}
func (ep *endpoint) Failed() {
ep.Lock()
if !ep.Available {
ep.Unlock()
return
}
ep.Available = false
ep.Unlock()
log.Printf("proxy: marked endpoint %s unavailable", ep.URL.String())
if ep.failFunc == nil {
log.Printf("proxy: no failFunc defined, endpoint %s will be unavailable forever.", ep.URL.String())
return
}
ep.failFunc(ep)
}
func timedUnavailabilityFunc(wait time.Duration) func(*endpoint) {
return func(ep *endpoint) {
time.AfterFunc(wait, func() {
ep.Available = true
log.Printf("proxy: marked endpoint %s available", ep.URL.String())
})
}
}

61
proxy/director_test.go Normal file
View File

@ -0,0 +1,61 @@
package proxy
import (
"net/url"
"reflect"
"testing"
)
func TestNewDirectorEndpointValidation(t *testing.T) {
tests := []struct {
good bool
endpoints []string
}{
{true, []string{"http://192.0.2.8"}},
{true, []string{"http://192.0.2.8:8001"}},
{true, []string{"http://example.com"}},
{true, []string{"http://example.com:8001"}},
{true, []string{"http://192.0.2.8:8001", "http://example.com:8002"}},
{false, []string{"://"}},
{false, []string{"http://"}},
{false, []string{"192.0.2.8"}},
{false, []string{"192.0.2.8:8001"}},
{false, []string{""}},
{false, []string{}},
}
for i, tt := range tests {
_, err := newDirector(tt.endpoints)
if tt.good != (err == nil) {
t.Errorf("#%d: expected success = %t, got err = %v", i, tt.good, err)
}
}
}
func TestDirectorEndpointsFiltering(t *testing.T) {
d := director{
ep: []*endpoint{
&endpoint{
URL: url.URL{Scheme: "http", Host: "192.0.2.5:5050"},
Available: false,
},
&endpoint{
URL: url.URL{Scheme: "http", Host: "192.0.2.4:4000"},
Available: true,
},
},
}
got := d.endpoints()
want := []*endpoint{
&endpoint{
URL: url.URL{Scheme: "http", Host: "192.0.2.4:4000"},
Available: true,
},
}
if !reflect.DeepEqual(want, got) {
t.Fatalf("directed to incorrect endpoint: want = %#v, got = %#v", want, got)
}
}

View File

@ -1,64 +1,33 @@
package proxy
import (
"errors"
"fmt"
"net"
"net/http"
"net/http/httputil"
"net/url"
"time"
)
func NewHandler(endpoints []string) (*httputil.ReverseProxy, error) {
const (
dialTimeout = 30 * time.Second
responseHeaderTimeout = 30 * time.Second
)
func NewHandler(endpoints []string) (http.Handler, error) {
d, err := newDirector(endpoints)
if err != nil {
return nil, err
}
proxy := httputil.ReverseProxy{
Director: d.direct,
Transport: &http.Transport{},
FlushInterval: 0,
tr := http.Transport{
Dial: func(network, address string) (net.Conn, error) {
return net.DialTimeout(network, address, dialTimeout)
},
ResponseHeaderTimeout: responseHeaderTimeout,
}
return &proxy, nil
}
func newDirector(endpoints []string) (*director, error) {
if len(endpoints) == 0 {
return nil, errors.New("one or more endpoints required")
rp := reverseProxy{
director: d,
transport: &tr,
}
urls := make([]url.URL, len(endpoints))
for i, e := range endpoints {
u, err := url.Parse(e)
if err != nil {
return nil, fmt.Errorf("invalid endpoint %q: %v", e, err)
}
if u.Scheme == "" {
return nil, fmt.Errorf("invalid endpoint %q: scheme required", e)
}
if u.Host == "" {
return nil, fmt.Errorf("invalid endpoint %q: host empty", e)
}
urls[i] = *u
}
d := director{
endpoints: urls,
}
return &d, nil
}
type director struct {
endpoints []url.URL
}
func (d *director) direct(req *http.Request) {
choice := d.endpoints[0]
req.URL.Scheme = choice.Scheme
req.URL.Host = choice.Host
return &rp, nil
}

View File

@ -1,71 +0,0 @@
package proxy
import (
"net/http"
"net/url"
"reflect"
"testing"
)
func TestNewDirector(t *testing.T) {
tests := []struct {
good bool
endpoints []string
}{
{true, []string{"http://192.0.2.8"}},
{true, []string{"http://192.0.2.8:8001"}},
{true, []string{"http://example.com"}},
{true, []string{"http://example.com:8001"}},
{true, []string{"http://192.0.2.8:8001", "http://example.com:8002"}},
{false, []string{"192.0.2.8"}},
{false, []string{"192.0.2.8:8001"}},
{false, []string{""}},
}
for i, tt := range tests {
_, err := newDirector(tt.endpoints)
if tt.good != (err == nil) {
t.Errorf("#%d: expected success = %t, got err = %v", i, tt.good, err)
}
}
}
func TestDirectorDirect(t *testing.T) {
d := &director{
endpoints: []url.URL{
url.URL{
Scheme: "http",
Host: "bar.example.com",
},
},
}
req := &http.Request{
Method: "GET",
Host: "foo.example.com",
URL: &url.URL{
Host: "foo.example.com",
Path: "/v2/keys/baz",
},
}
d.direct(req)
want := &http.Request{
Method: "GET",
// this field must not change
Host: "foo.example.com",
URL: &url.URL{
// the Scheme field is updated per the director's first endpoint
Scheme: "http",
// the Host field is updated per the director's first endpoint
Host: "bar.example.com",
Path: "/v2/keys/baz",
},
}
if !reflect.DeepEqual(want, req) {
t.Fatalf("HTTP request does not match expected criteria: want=%#v got=%#v", want, req)
}
}

120
proxy/reverse.go Normal file
View File

@ -0,0 +1,120 @@
package proxy
import (
"io"
"log"
"net"
"net/http"
"net/url"
"strings"
)
// Hop-by-hop headers. These are removed when sent to the backend.
// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
// This list of headers borrowed from stdlib httputil.ReverseProxy
var singleHopHeaders = []string{
"Connection",
"Keep-Alive",
"Proxy-Authenticate",
"Proxy-Authorization",
"Te", // canonicalized version of "TE"
"Trailers",
"Transfer-Encoding",
"Upgrade",
}
func removeSingleHopHeaders(hdrs *http.Header) {
for _, h := range singleHopHeaders {
hdrs.Del(h)
}
}
type reverseProxy struct {
director *director
transport http.RoundTripper
}
func (p *reverseProxy) ServeHTTP(rw http.ResponseWriter, clientreq *http.Request) {
proxyreq := new(http.Request)
*proxyreq = *clientreq
// deep-copy the headers, as these will be modified below
proxyreq.Header = make(http.Header)
copyHeader(proxyreq.Header, clientreq.Header)
normalizeRequest(proxyreq)
removeSingleHopHeaders(&proxyreq.Header)
maybeSetForwardedFor(proxyreq)
endpoints := p.director.endpoints()
if len(endpoints) == 0 {
log.Printf("proxy: zero endpoints currently available")
rw.WriteHeader(http.StatusServiceUnavailable)
return
}
var res *http.Response
var err error
for _, ep := range endpoints {
redirectRequest(proxyreq, ep.URL)
res, err = p.transport.RoundTrip(proxyreq)
if err != nil {
log.Printf("proxy: failed to direct request to %s: %v", ep.URL.String(), err)
ep.Failed()
continue
}
break
}
if res == nil {
log.Printf("proxy: unable to get response from %d endpoint(s)", len(endpoints))
rw.WriteHeader(http.StatusBadGateway)
return
}
defer res.Body.Close()
removeSingleHopHeaders(&res.Header)
copyHeader(rw.Header(), res.Header)
rw.WriteHeader(res.StatusCode)
io.Copy(rw, res.Body)
}
func copyHeader(dst, src http.Header) {
for k, vv := range src {
for _, v := range vv {
dst.Add(k, v)
}
}
}
func redirectRequest(req *http.Request, loc url.URL) {
req.URL.Scheme = loc.Scheme
req.URL.Host = loc.Host
}
func normalizeRequest(req *http.Request) {
req.Proto = "HTTP/1.1"
req.ProtoMajor = 1
req.ProtoMinor = 1
req.Close = false
}
func maybeSetForwardedFor(req *http.Request) {
clientIP, _, err := net.SplitHostPort(req.RemoteAddr)
if err != nil {
return
}
// If we aren't the first proxy retain prior
// X-Forwarded-For information as a comma+space
// separated list and fold multiple headers into one.
if prior, ok := req.Header["X-Forwarded-For"]; ok {
clientIP = strings.Join(prior, ", ") + ", " + clientIP
}
req.Header.Set("X-Forwarded-For", clientIP)
}

227
proxy/reverse_test.go Normal file
View File

@ -0,0 +1,227 @@
package proxy
import (
"bytes"
"errors"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"reflect"
"testing"
)
type staticRoundTripper struct {
res *http.Response
err error
}
func (srt *staticRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
return srt.res, srt.err
}
func TestReverseProxyServe(t *testing.T) {
u := url.URL{Scheme: "http", Host: "192.0.2.3:4040"}
tests := []struct {
eps []*endpoint
rt http.RoundTripper
want int
}{
// no endpoints available so no requests are even made
{
eps: []*endpoint{},
rt: &staticRoundTripper{
res: &http.Response{
StatusCode: http.StatusCreated,
Body: ioutil.NopCloser(&bytes.Reader{}),
},
},
want: http.StatusServiceUnavailable,
},
// error is returned from one endpoint that should be available
{
eps: []*endpoint{&endpoint{URL: u, Available: true}},
rt: &staticRoundTripper{err: errors.New("what a bad trip")},
want: http.StatusBadGateway,
},
// endpoint is available and returns success
{
eps: []*endpoint{&endpoint{URL: u, Available: true}},
rt: &staticRoundTripper{
res: &http.Response{
StatusCode: http.StatusCreated,
Body: ioutil.NopCloser(&bytes.Reader{}),
},
},
want: http.StatusCreated,
},
}
for i, tt := range tests {
rp := reverseProxy{
director: &director{tt.eps},
transport: tt.rt,
}
req, _ := http.NewRequest("GET", "http://192.0.2.2:4001", nil)
rr := httptest.NewRecorder()
rp.ServeHTTP(rr, req)
if rr.Code != tt.want {
t.Errorf("#%d: unexpected HTTP status code: want = %d, got = %d", i, tt.want, rr.Code)
}
}
}
func TestRedirectRequest(t *testing.T) {
loc := url.URL{
Scheme: "http",
Host: "bar.example.com",
}
req := &http.Request{
Method: "GET",
Host: "foo.example.com",
URL: &url.URL{
Host: "foo.example.com",
Path: "/v2/keys/baz",
},
}
redirectRequest(req, loc)
want := &http.Request{
Method: "GET",
// this field must not change
Host: "foo.example.com",
URL: &url.URL{
// the Scheme field is updated to that of the provided URL
Scheme: "http",
// the Host field is updated to that of the provided URL
Host: "bar.example.com",
Path: "/v2/keys/baz",
},
}
if !reflect.DeepEqual(want, req) {
t.Fatalf("HTTP request does not match expected criteria: want=%#v got=%#v", want, req)
}
}
func TestMaybeSetForwardedFor(t *testing.T) {
tests := []struct {
raddr string
fwdFor string
want string
}{
{"192.0.2.3:8002", "", "192.0.2.3"},
{"192.0.2.3:8002", "192.0.2.2", "192.0.2.2, 192.0.2.3"},
{"192.0.2.3:8002", "192.0.2.1, 192.0.2.2", "192.0.2.1, 192.0.2.2, 192.0.2.3"},
{"example.com:8002", "", "example.com"},
// While these cases look valid, golang net/http will not let it happen
// The RemoteAddr field will always be a valid host:port
{":8002", "", ""},
{"192.0.2.3", "", ""},
// blatantly invalid host w/o a port
{"12", "", ""},
{"12", "192.0.2.3", "192.0.2.3"},
}
for i, tt := range tests {
req := &http.Request{
RemoteAddr: tt.raddr,
Header: make(http.Header),
}
if tt.fwdFor != "" {
req.Header.Set("X-Forwarded-For", tt.fwdFor)
}
maybeSetForwardedFor(req)
got := req.Header.Get("X-Forwarded-For")
if tt.want != got {
t.Errorf("#%d: incorrect header: want = %q, got = %q", i, tt.want, got)
}
}
}
func TestRemoveSingleHopHeaders(t *testing.T) {
hdr := http.Header(map[string][]string{
// single-hop headers that should be removed
"Connection": []string{"close"},
"Keep-Alive": []string{"foo"},
"Proxy-Authenticate": []string{"Basic realm=example.com"},
"Proxy-Authorization": []string{"foo"},
"Te": []string{"deflate,gzip"},
"Trailers": []string{"ETag"},
"Transfer-Encoding": []string{"chunked"},
"Upgrade": []string{"WebSocket"},
// headers that should persist
"Accept": []string{"application/json"},
"X-Foo": []string{"Bar"},
})
removeSingleHopHeaders(&hdr)
want := http.Header(map[string][]string{
"Accept": []string{"application/json"},
"X-Foo": []string{"Bar"},
})
if !reflect.DeepEqual(want, hdr) {
t.Fatalf("unexpected result: want = %#v, got = %#v", want, hdr)
}
}
func TestCopyHeader(t *testing.T) {
tests := []struct {
src http.Header
dst http.Header
want http.Header
}{
{
src: http.Header(map[string][]string{
"Foo": []string{"bar", "baz"},
}),
dst: http.Header(map[string][]string{}),
want: http.Header(map[string][]string{
"Foo": []string{"bar", "baz"},
}),
},
{
src: http.Header(map[string][]string{
"Foo": []string{"bar"},
"Ping": []string{"pong"},
}),
dst: http.Header(map[string][]string{}),
want: http.Header(map[string][]string{
"Foo": []string{"bar"},
"Ping": []string{"pong"},
}),
},
{
src: http.Header(map[string][]string{
"Foo": []string{"bar", "baz"},
}),
dst: http.Header(map[string][]string{
"Foo": []string{"qux"},
}),
want: http.Header(map[string][]string{
"Foo": []string{"qux", "bar", "baz"},
}),
},
}
for i, tt := range tests {
copyHeader(tt.dst, tt.src)
if !reflect.DeepEqual(tt.dst, tt.want) {
t.Errorf("#%d: unexpected headers: want = %v, got = %v", i, tt.want, tt.dst)
}
}
}