mirror of
https://github.com/etcd-io/etcd.git
synced 2024-09-27 06:25:44 +00:00
@@ -1,90 +0,0 @@
|
|||||||
// Copyright 2015 The etcd Authors
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
// Package cors handles cross-origin HTTP requests (CORS).
|
|
||||||
package cors
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"sort"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
type CORSInfo map[string]bool
|
|
||||||
|
|
||||||
// Set implements the flag.Value interface to allow users to define a list of CORS origins
|
|
||||||
func (ci *CORSInfo) Set(s string) error {
|
|
||||||
m := make(map[string]bool)
|
|
||||||
for _, v := range strings.Split(s, ",") {
|
|
||||||
v = strings.TrimSpace(v)
|
|
||||||
if v == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if v != "*" {
|
|
||||||
if _, err := url.Parse(v); err != nil {
|
|
||||||
return fmt.Errorf("Invalid CORS origin: %s", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
m[v] = true
|
|
||||||
|
|
||||||
}
|
|
||||||
*ci = CORSInfo(m)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ci *CORSInfo) String() string {
|
|
||||||
o := make([]string, 0)
|
|
||||||
for k := range *ci {
|
|
||||||
o = append(o, k)
|
|
||||||
}
|
|
||||||
sort.StringSlice(o).Sort()
|
|
||||||
return strings.Join(o, ",")
|
|
||||||
}
|
|
||||||
|
|
||||||
// OriginAllowed determines whether the server will allow a given CORS origin.
|
|
||||||
func (c CORSInfo) OriginAllowed(origin string) bool {
|
|
||||||
return c["*"] || c[origin]
|
|
||||||
}
|
|
||||||
|
|
||||||
type CORSHandler struct {
|
|
||||||
Handler http.Handler
|
|
||||||
Info *CORSInfo
|
|
||||||
}
|
|
||||||
|
|
||||||
// addHeader adds the correct cors headers given an origin
|
|
||||||
func (h *CORSHandler) addHeader(w http.ResponseWriter, origin string) {
|
|
||||||
w.Header().Add("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE")
|
|
||||||
w.Header().Add("Access-Control-Allow-Origin", origin)
|
|
||||||
w.Header().Add("Access-Control-Allow-Headers", "accept, content-type, authorization")
|
|
||||||
}
|
|
||||||
|
|
||||||
// ServeHTTP adds the correct CORS headers based on the origin and returns immediately
|
|
||||||
// with a 200 OK if the method is OPTIONS.
|
|
||||||
func (h *CORSHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
|
||||||
// Write CORS header.
|
|
||||||
if h.Info.OriginAllowed("*") {
|
|
||||||
h.addHeader(w, "*")
|
|
||||||
} else if origin := req.Header.Get("Origin"); h.Info.OriginAllowed(origin) {
|
|
||||||
h.addHeader(w, origin)
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.Method == "OPTIONS" {
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
h.Handler.ServeHTTP(w, req)
|
|
||||||
}
|
|
||||||
@@ -1,125 +0,0 @@
|
|||||||
// Copyright 2015 The etcd Authors
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package cors
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestCORSInfo(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
s string
|
|
||||||
winfo CORSInfo
|
|
||||||
ws string
|
|
||||||
}{
|
|
||||||
{"", CORSInfo{}, ""},
|
|
||||||
{"http://127.0.0.1", CORSInfo{"http://127.0.0.1": true}, "http://127.0.0.1"},
|
|
||||||
{"*", CORSInfo{"*": true}, "*"},
|
|
||||||
// with space around
|
|
||||||
{" http://127.0.0.1 ", CORSInfo{"http://127.0.0.1": true}, "http://127.0.0.1"},
|
|
||||||
// multiple addrs
|
|
||||||
{
|
|
||||||
"http://127.0.0.1,http://127.0.0.2",
|
|
||||||
CORSInfo{"http://127.0.0.1": true, "http://127.0.0.2": true},
|
|
||||||
"http://127.0.0.1,http://127.0.0.2",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for i, tt := range tests {
|
|
||||||
info := CORSInfo{}
|
|
||||||
if err := info.Set(tt.s); err != nil {
|
|
||||||
t.Errorf("#%d: set error = %v, want nil", i, err)
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(info, tt.winfo) {
|
|
||||||
t.Errorf("#%d: info = %v, want %v", i, info, tt.winfo)
|
|
||||||
}
|
|
||||||
if g := info.String(); g != tt.ws {
|
|
||||||
t.Errorf("#%d: info string = %s, want %s", i, g, tt.ws)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCORSInfoOriginAllowed(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
set string
|
|
||||||
origin string
|
|
||||||
wallowed bool
|
|
||||||
}{
|
|
||||||
{"http://127.0.0.1,http://127.0.0.2", "http://127.0.0.1", true},
|
|
||||||
{"http://127.0.0.1,http://127.0.0.2", "http://127.0.0.2", true},
|
|
||||||
{"http://127.0.0.1,http://127.0.0.2", "*", false},
|
|
||||||
{"http://127.0.0.1,http://127.0.0.2", "http://127.0.0.3", false},
|
|
||||||
{"*", "*", true},
|
|
||||||
{"*", "http://127.0.0.1", true},
|
|
||||||
}
|
|
||||||
for i, tt := range tests {
|
|
||||||
info := CORSInfo{}
|
|
||||||
if err := info.Set(tt.set); err != nil {
|
|
||||||
t.Errorf("#%d: set error = %v, want nil", i, err)
|
|
||||||
}
|
|
||||||
if g := info.OriginAllowed(tt.origin); g != tt.wallowed {
|
|
||||||
t.Errorf("#%d: allowed = %v, want %v", i, g, tt.wallowed)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCORSHandler(t *testing.T) {
|
|
||||||
info := &CORSInfo{}
|
|
||||||
if err := info.Set("http://127.0.0.1,http://127.0.0.2"); err != nil {
|
|
||||||
t.Fatalf("unexpected set error: %v", err)
|
|
||||||
}
|
|
||||||
h := &CORSHandler{
|
|
||||||
Handler: http.NotFoundHandler(),
|
|
||||||
Info: info,
|
|
||||||
}
|
|
||||||
|
|
||||||
header := func(origin string) http.Header {
|
|
||||||
return http.Header{
|
|
||||||
"Access-Control-Allow-Methods": []string{"POST, GET, OPTIONS, PUT, DELETE"},
|
|
||||||
"Access-Control-Allow-Origin": []string{origin},
|
|
||||||
"Access-Control-Allow-Headers": []string{"accept, content-type, authorization"},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
method string
|
|
||||||
origin string
|
|
||||||
wcode int
|
|
||||||
wheader http.Header
|
|
||||||
}{
|
|
||||||
{"GET", "http://127.0.0.1", http.StatusNotFound, header("http://127.0.0.1")},
|
|
||||||
{"GET", "http://127.0.0.2", http.StatusNotFound, header("http://127.0.0.2")},
|
|
||||||
{"GET", "http://127.0.0.3", http.StatusNotFound, http.Header{}},
|
|
||||||
{"OPTIONS", "http://127.0.0.1", http.StatusOK, header("http://127.0.0.1")},
|
|
||||||
}
|
|
||||||
for i, tt := range tests {
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
req := &http.Request{
|
|
||||||
Method: tt.method,
|
|
||||||
Header: http.Header{"Origin": []string{tt.origin}},
|
|
||||||
}
|
|
||||||
h.ServeHTTP(rr, req)
|
|
||||||
if rr.Code != tt.wcode {
|
|
||||||
t.Errorf("#%d: code = %v, want %v", i, rr.Code, tt.wcode)
|
|
||||||
}
|
|
||||||
// it is set by http package, and there is no need to test it
|
|
||||||
rr.HeaderMap.Del("Content-Type")
|
|
||||||
rr.HeaderMap.Del("X-Content-Type-Options")
|
|
||||||
if !reflect.DeepEqual(rr.HeaderMap, tt.wheader) {
|
|
||||||
t.Errorf("#%d: header = %+v, want %+v", i, rr.HeaderMap, tt.wheader)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Reference in New Issue
Block a user