diff --git a/proxy/director.go b/proxy/director.go new file mode 100644 index 000000000..2e124fa05 --- /dev/null +++ b/proxy/director.go @@ -0,0 +1,106 @@ +package proxy + +import ( + "errors" + "fmt" + "log" + "net/url" + "sync" + "time" +) + +const ( + // amount of time an endpoint will be held in a failed + // state before being reconsidered for proxied requests + endpointFailureWait = 5 * time.Second +) + +func newDirector(urls []string) (*director, error) { + if len(urls) == 0 { + return nil, errors.New("one or more endpoints required") + } + + endpoints := make([]*endpoint, len(urls)) + for i, v := range urls { + u, err := url.Parse(v) + if err != nil { + return nil, fmt.Errorf("invalid endpoint %q: %v", v, err) + } + + if u.Scheme == "" { + return nil, fmt.Errorf("invalid endpoint %q: scheme required", v) + } + + if u.Host == "" { + return nil, fmt.Errorf("invalid endpoint %q: host empty", v) + } + + endpoints[i] = newEndpoint(*u) + } + + d := director{ep: endpoints} + return &d, nil +} + +type director struct { + ep []*endpoint +} + +func (d *director) endpoints() []*endpoint { + filtered := make([]*endpoint, 0) + for _, ep := range d.ep { + if ep.Available { + filtered = append(filtered, ep) + } + } + + return filtered +} + +func newEndpoint(u url.URL) *endpoint { + ep := endpoint{ + URL: u, + Available: true, + failFunc: timedUnavailabilityFunc(endpointFailureWait), + } + + return &ep +} + +type endpoint struct { + sync.Mutex + + URL url.URL + Available bool + + failFunc func(ep *endpoint) +} + +func (ep *endpoint) Failed() { + ep.Lock() + if !ep.Available { + ep.Unlock() + return + } + + ep.Available = false + ep.Unlock() + + log.Printf("proxy: marked endpoint %s unavailable", ep.URL.String()) + + if ep.failFunc == nil { + log.Printf("proxy: no failFunc defined, endpoint %s will be unavailable forever.", ep.URL.String()) + return + } + + ep.failFunc(ep) +} + +func timedUnavailabilityFunc(wait time.Duration) func(*endpoint) { + return func(ep *endpoint) { + time.AfterFunc(wait, func() { + ep.Available = true + log.Printf("proxy: marked endpoint %s available", ep.URL.String()) + }) + } +} diff --git a/proxy/director_test.go b/proxy/director_test.go new file mode 100644 index 000000000..bd6ce675f --- /dev/null +++ b/proxy/director_test.go @@ -0,0 +1,61 @@ +package proxy + +import ( + "net/url" + "reflect" + "testing" +) + +func TestNewDirectorEndpointValidation(t *testing.T) { + tests := []struct { + good bool + endpoints []string + }{ + {true, []string{"http://192.0.2.8"}}, + {true, []string{"http://192.0.2.8:8001"}}, + {true, []string{"http://example.com"}}, + {true, []string{"http://example.com:8001"}}, + {true, []string{"http://192.0.2.8:8001", "http://example.com:8002"}}, + + {false, []string{"://"}}, + {false, []string{"http://"}}, + {false, []string{"192.0.2.8"}}, + {false, []string{"192.0.2.8:8001"}}, + {false, []string{""}}, + {false, []string{}}, + } + + for i, tt := range tests { + _, err := newDirector(tt.endpoints) + if tt.good != (err == nil) { + t.Errorf("#%d: expected success = %t, got err = %v", i, tt.good, err) + } + } +} + +func TestDirectorEndpointsFiltering(t *testing.T) { + d := director{ + ep: []*endpoint{ + &endpoint{ + URL: url.URL{Scheme: "http", Host: "192.0.2.5:5050"}, + Available: false, + }, + &endpoint{ + URL: url.URL{Scheme: "http", Host: "192.0.2.4:4000"}, + Available: true, + }, + }, + } + + got := d.endpoints() + want := []*endpoint{ + &endpoint{ + URL: url.URL{Scheme: "http", Host: "192.0.2.4:4000"}, + Available: true, + }, + } + + if !reflect.DeepEqual(want, got) { + t.Fatalf("directed to incorrect endpoint: want = %#v, got = %#v", want, got) + } +} diff --git a/proxy/proxy.go b/proxy/proxy.go index 167b46ad5..1f9c0d7a4 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -1,64 +1,33 @@ package proxy import ( - "errors" - "fmt" + "net" "net/http" - "net/http/httputil" - "net/url" + "time" ) -func NewHandler(endpoints []string) (*httputil.ReverseProxy, error) { +const ( + dialTimeout = 30 * time.Second + responseHeaderTimeout = 30 * time.Second +) + +func NewHandler(endpoints []string) (http.Handler, error) { d, err := newDirector(endpoints) if err != nil { return nil, err } - proxy := httputil.ReverseProxy{ - Director: d.direct, - Transport: &http.Transport{}, - FlushInterval: 0, + tr := http.Transport{ + Dial: func(network, address string) (net.Conn, error) { + return net.DialTimeout(network, address, dialTimeout) + }, + ResponseHeaderTimeout: responseHeaderTimeout, } - return &proxy, nil -} - -func newDirector(endpoints []string) (*director, error) { - if len(endpoints) == 0 { - return nil, errors.New("one or more endpoints required") + rp := reverseProxy{ + director: d, + transport: &tr, } - urls := make([]url.URL, len(endpoints)) - for i, e := range endpoints { - u, err := url.Parse(e) - if err != nil { - return nil, fmt.Errorf("invalid endpoint %q: %v", e, err) - } - - if u.Scheme == "" { - return nil, fmt.Errorf("invalid endpoint %q: scheme required", e) - } - - if u.Host == "" { - return nil, fmt.Errorf("invalid endpoint %q: host empty", e) - } - - urls[i] = *u - } - - d := director{ - endpoints: urls, - } - - return &d, nil -} - -type director struct { - endpoints []url.URL -} - -func (d *director) direct(req *http.Request) { - choice := d.endpoints[0] - req.URL.Scheme = choice.Scheme - req.URL.Host = choice.Host + return &rp, nil } diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go deleted file mode 100644 index 707660cc2..000000000 --- a/proxy/proxy_test.go +++ /dev/null @@ -1,71 +0,0 @@ -package proxy - -import ( - "net/http" - "net/url" - "reflect" - "testing" -) - -func TestNewDirector(t *testing.T) { - tests := []struct { - good bool - endpoints []string - }{ - {true, []string{"http://192.0.2.8"}}, - {true, []string{"http://192.0.2.8:8001"}}, - {true, []string{"http://example.com"}}, - {true, []string{"http://example.com:8001"}}, - {true, []string{"http://192.0.2.8:8001", "http://example.com:8002"}}, - - {false, []string{"192.0.2.8"}}, - {false, []string{"192.0.2.8:8001"}}, - {false, []string{""}}, - } - - for i, tt := range tests { - _, err := newDirector(tt.endpoints) - if tt.good != (err == nil) { - t.Errorf("#%d: expected success = %t, got err = %v", i, tt.good, err) - } - } -} - -func TestDirectorDirect(t *testing.T) { - d := &director{ - endpoints: []url.URL{ - url.URL{ - Scheme: "http", - Host: "bar.example.com", - }, - }, - } - - req := &http.Request{ - Method: "GET", - Host: "foo.example.com", - URL: &url.URL{ - Host: "foo.example.com", - Path: "/v2/keys/baz", - }, - } - - d.direct(req) - - want := &http.Request{ - Method: "GET", - // this field must not change - Host: "foo.example.com", - URL: &url.URL{ - // the Scheme field is updated per the director's first endpoint - Scheme: "http", - // the Host field is updated per the director's first endpoint - Host: "bar.example.com", - Path: "/v2/keys/baz", - }, - } - - if !reflect.DeepEqual(want, req) { - t.Fatalf("HTTP request does not match expected criteria: want=%#v got=%#v", want, req) - } -} diff --git a/proxy/reverse.go b/proxy/reverse.go new file mode 100644 index 000000000..5ab4a2514 --- /dev/null +++ b/proxy/reverse.go @@ -0,0 +1,120 @@ +package proxy + +import ( + "io" + "log" + "net" + "net/http" + "net/url" + "strings" +) + +// Hop-by-hop headers. These are removed when sent to the backend. +// http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html +// This list of headers borrowed from stdlib httputil.ReverseProxy +var singleHopHeaders = []string{ + "Connection", + "Keep-Alive", + "Proxy-Authenticate", + "Proxy-Authorization", + "Te", // canonicalized version of "TE" + "Trailers", + "Transfer-Encoding", + "Upgrade", +} + +func removeSingleHopHeaders(hdrs *http.Header) { + for _, h := range singleHopHeaders { + hdrs.Del(h) + } +} + +type reverseProxy struct { + director *director + transport http.RoundTripper +} + +func (p *reverseProxy) ServeHTTP(rw http.ResponseWriter, clientreq *http.Request) { + proxyreq := new(http.Request) + *proxyreq = *clientreq + + // deep-copy the headers, as these will be modified below + proxyreq.Header = make(http.Header) + copyHeader(proxyreq.Header, clientreq.Header) + + normalizeRequest(proxyreq) + removeSingleHopHeaders(&proxyreq.Header) + maybeSetForwardedFor(proxyreq) + + endpoints := p.director.endpoints() + if len(endpoints) == 0 { + log.Printf("proxy: zero endpoints currently available") + rw.WriteHeader(http.StatusServiceUnavailable) + return + } + + var res *http.Response + var err error + + for _, ep := range endpoints { + redirectRequest(proxyreq, ep.URL) + + res, err = p.transport.RoundTrip(proxyreq) + if err != nil { + log.Printf("proxy: failed to direct request to %s: %v", ep.URL.String(), err) + ep.Failed() + continue + } + + break + } + + if res == nil { + log.Printf("proxy: unable to get response from %d endpoint(s)", len(endpoints)) + rw.WriteHeader(http.StatusBadGateway) + return + } + + defer res.Body.Close() + + removeSingleHopHeaders(&res.Header) + copyHeader(rw.Header(), res.Header) + + rw.WriteHeader(res.StatusCode) + io.Copy(rw, res.Body) +} + +func copyHeader(dst, src http.Header) { + for k, vv := range src { + for _, v := range vv { + dst.Add(k, v) + } + } +} + +func redirectRequest(req *http.Request, loc url.URL) { + req.URL.Scheme = loc.Scheme + req.URL.Host = loc.Host +} + +func normalizeRequest(req *http.Request) { + req.Proto = "HTTP/1.1" + req.ProtoMajor = 1 + req.ProtoMinor = 1 + req.Close = false +} + +func maybeSetForwardedFor(req *http.Request) { + clientIP, _, err := net.SplitHostPort(req.RemoteAddr) + if err != nil { + return + } + + // If we aren't the first proxy retain prior + // X-Forwarded-For information as a comma+space + // separated list and fold multiple headers into one. + if prior, ok := req.Header["X-Forwarded-For"]; ok { + clientIP = strings.Join(prior, ", ") + ", " + clientIP + } + req.Header.Set("X-Forwarded-For", clientIP) +} diff --git a/proxy/reverse_test.go b/proxy/reverse_test.go new file mode 100644 index 000000000..fc6b8ad9e --- /dev/null +++ b/proxy/reverse_test.go @@ -0,0 +1,227 @@ +package proxy + +import ( + "bytes" + "errors" + "io/ioutil" + "net/http" + "net/http/httptest" + "net/url" + "reflect" + "testing" +) + +type staticRoundTripper struct { + res *http.Response + err error +} + +func (srt *staticRoundTripper) RoundTrip(*http.Request) (*http.Response, error) { + return srt.res, srt.err +} + +func TestReverseProxyServe(t *testing.T) { + u := url.URL{Scheme: "http", Host: "192.0.2.3:4040"} + + tests := []struct { + eps []*endpoint + rt http.RoundTripper + want int + }{ + // no endpoints available so no requests are even made + { + eps: []*endpoint{}, + rt: &staticRoundTripper{ + res: &http.Response{ + StatusCode: http.StatusCreated, + Body: ioutil.NopCloser(&bytes.Reader{}), + }, + }, + want: http.StatusServiceUnavailable, + }, + + // error is returned from one endpoint that should be available + { + eps: []*endpoint{&endpoint{URL: u, Available: true}}, + rt: &staticRoundTripper{err: errors.New("what a bad trip")}, + want: http.StatusBadGateway, + }, + + // endpoint is available and returns success + { + eps: []*endpoint{&endpoint{URL: u, Available: true}}, + rt: &staticRoundTripper{ + res: &http.Response{ + StatusCode: http.StatusCreated, + Body: ioutil.NopCloser(&bytes.Reader{}), + }, + }, + want: http.StatusCreated, + }, + } + + for i, tt := range tests { + rp := reverseProxy{ + director: &director{tt.eps}, + transport: tt.rt, + } + + req, _ := http.NewRequest("GET", "http://192.0.2.2:4001", nil) + rr := httptest.NewRecorder() + rp.ServeHTTP(rr, req) + + if rr.Code != tt.want { + t.Errorf("#%d: unexpected HTTP status code: want = %d, got = %d", i, tt.want, rr.Code) + } + } +} + +func TestRedirectRequest(t *testing.T) { + loc := url.URL{ + Scheme: "http", + Host: "bar.example.com", + } + + req := &http.Request{ + Method: "GET", + Host: "foo.example.com", + URL: &url.URL{ + Host: "foo.example.com", + Path: "/v2/keys/baz", + }, + } + + redirectRequest(req, loc) + + want := &http.Request{ + Method: "GET", + // this field must not change + Host: "foo.example.com", + URL: &url.URL{ + // the Scheme field is updated to that of the provided URL + Scheme: "http", + // the Host field is updated to that of the provided URL + Host: "bar.example.com", + Path: "/v2/keys/baz", + }, + } + + if !reflect.DeepEqual(want, req) { + t.Fatalf("HTTP request does not match expected criteria: want=%#v got=%#v", want, req) + } +} + +func TestMaybeSetForwardedFor(t *testing.T) { + tests := []struct { + raddr string + fwdFor string + want string + }{ + {"192.0.2.3:8002", "", "192.0.2.3"}, + {"192.0.2.3:8002", "192.0.2.2", "192.0.2.2, 192.0.2.3"}, + {"192.0.2.3:8002", "192.0.2.1, 192.0.2.2", "192.0.2.1, 192.0.2.2, 192.0.2.3"}, + {"example.com:8002", "", "example.com"}, + + // While these cases look valid, golang net/http will not let it happen + // The RemoteAddr field will always be a valid host:port + {":8002", "", ""}, + {"192.0.2.3", "", ""}, + + // blatantly invalid host w/o a port + {"12", "", ""}, + {"12", "192.0.2.3", "192.0.2.3"}, + } + + for i, tt := range tests { + req := &http.Request{ + RemoteAddr: tt.raddr, + Header: make(http.Header), + } + + if tt.fwdFor != "" { + req.Header.Set("X-Forwarded-For", tt.fwdFor) + } + + maybeSetForwardedFor(req) + got := req.Header.Get("X-Forwarded-For") + if tt.want != got { + t.Errorf("#%d: incorrect header: want = %q, got = %q", i, tt.want, got) + } + } +} + +func TestRemoveSingleHopHeaders(t *testing.T) { + hdr := http.Header(map[string][]string{ + // single-hop headers that should be removed + "Connection": []string{"close"}, + "Keep-Alive": []string{"foo"}, + "Proxy-Authenticate": []string{"Basic realm=example.com"}, + "Proxy-Authorization": []string{"foo"}, + "Te": []string{"deflate,gzip"}, + "Trailers": []string{"ETag"}, + "Transfer-Encoding": []string{"chunked"}, + "Upgrade": []string{"WebSocket"}, + + // headers that should persist + "Accept": []string{"application/json"}, + "X-Foo": []string{"Bar"}, + }) + + removeSingleHopHeaders(&hdr) + + want := http.Header(map[string][]string{ + "Accept": []string{"application/json"}, + "X-Foo": []string{"Bar"}, + }) + + if !reflect.DeepEqual(want, hdr) { + t.Fatalf("unexpected result: want = %#v, got = %#v", want, hdr) + } +} + +func TestCopyHeader(t *testing.T) { + tests := []struct { + src http.Header + dst http.Header + want http.Header + }{ + { + src: http.Header(map[string][]string{ + "Foo": []string{"bar", "baz"}, + }), + dst: http.Header(map[string][]string{}), + want: http.Header(map[string][]string{ + "Foo": []string{"bar", "baz"}, + }), + }, + { + src: http.Header(map[string][]string{ + "Foo": []string{"bar"}, + "Ping": []string{"pong"}, + }), + dst: http.Header(map[string][]string{}), + want: http.Header(map[string][]string{ + "Foo": []string{"bar"}, + "Ping": []string{"pong"}, + }), + }, + { + src: http.Header(map[string][]string{ + "Foo": []string{"bar", "baz"}, + }), + dst: http.Header(map[string][]string{ + "Foo": []string{"qux"}, + }), + want: http.Header(map[string][]string{ + "Foo": []string{"qux", "bar", "baz"}, + }), + }, + } + + for i, tt := range tests { + copyHeader(tt.dst, tt.src) + if !reflect.DeepEqual(tt.dst, tt.want) { + t.Errorf("#%d: unexpected headers: want = %v, got = %v", i, tt.want, tt.dst) + } + } +}