mirror of
https://github.com/etcd-io/etcd.git
synced 2024-09-27 06:25:44 +00:00
pkg/cors: remove
Signed-off-by: Gyuho Lee <gyuhox@gmail.com>
This commit is contained in:
parent
df6cd22d59
commit
7195bb7ced
@ -1,90 +0,0 @@
|
||||
// Copyright 2015 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Package cors handles cross-origin HTTP requests (CORS).
|
||||
package cors
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type CORSInfo map[string]bool
|
||||
|
||||
// Set implements the flag.Value interface to allow users to define a list of CORS origins
|
||||
func (ci *CORSInfo) Set(s string) error {
|
||||
m := make(map[string]bool)
|
||||
for _, v := range strings.Split(s, ",") {
|
||||
v = strings.TrimSpace(v)
|
||||
if v == "" {
|
||||
continue
|
||||
}
|
||||
if v != "*" {
|
||||
if _, err := url.Parse(v); err != nil {
|
||||
return fmt.Errorf("Invalid CORS origin: %s", err)
|
||||
}
|
||||
}
|
||||
m[v] = true
|
||||
|
||||
}
|
||||
*ci = CORSInfo(m)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ci *CORSInfo) String() string {
|
||||
o := make([]string, 0)
|
||||
for k := range *ci {
|
||||
o = append(o, k)
|
||||
}
|
||||
sort.StringSlice(o).Sort()
|
||||
return strings.Join(o, ",")
|
||||
}
|
||||
|
||||
// OriginAllowed determines whether the server will allow a given CORS origin.
|
||||
func (c CORSInfo) OriginAllowed(origin string) bool {
|
||||
return c["*"] || c[origin]
|
||||
}
|
||||
|
||||
type CORSHandler struct {
|
||||
Handler http.Handler
|
||||
Info *CORSInfo
|
||||
}
|
||||
|
||||
// addHeader adds the correct cors headers given an origin
|
||||
func (h *CORSHandler) addHeader(w http.ResponseWriter, origin string) {
|
||||
w.Header().Add("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE")
|
||||
w.Header().Add("Access-Control-Allow-Origin", origin)
|
||||
w.Header().Add("Access-Control-Allow-Headers", "accept, content-type, authorization")
|
||||
}
|
||||
|
||||
// ServeHTTP adds the correct CORS headers based on the origin and returns immediately
|
||||
// with a 200 OK if the method is OPTIONS.
|
||||
func (h *CORSHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
|
||||
// Write CORS header.
|
||||
if h.Info.OriginAllowed("*") {
|
||||
h.addHeader(w, "*")
|
||||
} else if origin := req.Header.Get("Origin"); h.Info.OriginAllowed(origin) {
|
||||
h.addHeader(w, origin)
|
||||
}
|
||||
|
||||
if req.Method == "OPTIONS" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
h.Handler.ServeHTTP(w, req)
|
||||
}
|
@ -1,125 +0,0 @@
|
||||
// Copyright 2015 The etcd Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package cors
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCORSInfo(t *testing.T) {
|
||||
tests := []struct {
|
||||
s string
|
||||
winfo CORSInfo
|
||||
ws string
|
||||
}{
|
||||
{"", CORSInfo{}, ""},
|
||||
{"http://127.0.0.1", CORSInfo{"http://127.0.0.1": true}, "http://127.0.0.1"},
|
||||
{"*", CORSInfo{"*": true}, "*"},
|
||||
// with space around
|
||||
{" http://127.0.0.1 ", CORSInfo{"http://127.0.0.1": true}, "http://127.0.0.1"},
|
||||
// multiple addrs
|
||||
{
|
||||
"http://127.0.0.1,http://127.0.0.2",
|
||||
CORSInfo{"http://127.0.0.1": true, "http://127.0.0.2": true},
|
||||
"http://127.0.0.1,http://127.0.0.2",
|
||||
},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
info := CORSInfo{}
|
||||
if err := info.Set(tt.s); err != nil {
|
||||
t.Errorf("#%d: set error = %v, want nil", i, err)
|
||||
}
|
||||
if !reflect.DeepEqual(info, tt.winfo) {
|
||||
t.Errorf("#%d: info = %v, want %v", i, info, tt.winfo)
|
||||
}
|
||||
if g := info.String(); g != tt.ws {
|
||||
t.Errorf("#%d: info string = %s, want %s", i, g, tt.ws)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCORSInfoOriginAllowed(t *testing.T) {
|
||||
tests := []struct {
|
||||
set string
|
||||
origin string
|
||||
wallowed bool
|
||||
}{
|
||||
{"http://127.0.0.1,http://127.0.0.2", "http://127.0.0.1", true},
|
||||
{"http://127.0.0.1,http://127.0.0.2", "http://127.0.0.2", true},
|
||||
{"http://127.0.0.1,http://127.0.0.2", "*", false},
|
||||
{"http://127.0.0.1,http://127.0.0.2", "http://127.0.0.3", false},
|
||||
{"*", "*", true},
|
||||
{"*", "http://127.0.0.1", true},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
info := CORSInfo{}
|
||||
if err := info.Set(tt.set); err != nil {
|
||||
t.Errorf("#%d: set error = %v, want nil", i, err)
|
||||
}
|
||||
if g := info.OriginAllowed(tt.origin); g != tt.wallowed {
|
||||
t.Errorf("#%d: allowed = %v, want %v", i, g, tt.wallowed)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCORSHandler(t *testing.T) {
|
||||
info := &CORSInfo{}
|
||||
if err := info.Set("http://127.0.0.1,http://127.0.0.2"); err != nil {
|
||||
t.Fatalf("unexpected set error: %v", err)
|
||||
}
|
||||
h := &CORSHandler{
|
||||
Handler: http.NotFoundHandler(),
|
||||
Info: info,
|
||||
}
|
||||
|
||||
header := func(origin string) http.Header {
|
||||
return http.Header{
|
||||
"Access-Control-Allow-Methods": []string{"POST, GET, OPTIONS, PUT, DELETE"},
|
||||
"Access-Control-Allow-Origin": []string{origin},
|
||||
"Access-Control-Allow-Headers": []string{"accept, content-type, authorization"},
|
||||
}
|
||||
}
|
||||
tests := []struct {
|
||||
method string
|
||||
origin string
|
||||
wcode int
|
||||
wheader http.Header
|
||||
}{
|
||||
{"GET", "http://127.0.0.1", http.StatusNotFound, header("http://127.0.0.1")},
|
||||
{"GET", "http://127.0.0.2", http.StatusNotFound, header("http://127.0.0.2")},
|
||||
{"GET", "http://127.0.0.3", http.StatusNotFound, http.Header{}},
|
||||
{"OPTIONS", "http://127.0.0.1", http.StatusOK, header("http://127.0.0.1")},
|
||||
}
|
||||
for i, tt := range tests {
|
||||
rr := httptest.NewRecorder()
|
||||
req := &http.Request{
|
||||
Method: tt.method,
|
||||
Header: http.Header{"Origin": []string{tt.origin}},
|
||||
}
|
||||
h.ServeHTTP(rr, req)
|
||||
if rr.Code != tt.wcode {
|
||||
t.Errorf("#%d: code = %v, want %v", i, rr.Code, tt.wcode)
|
||||
}
|
||||
// it is set by http package, and there is no need to test it
|
||||
rr.HeaderMap.Del("Content-Type")
|
||||
rr.HeaderMap.Del("X-Content-Type-Options")
|
||||
if !reflect.DeepEqual(rr.HeaderMap, tt.wheader) {
|
||||
t.Errorf("#%d: header = %+v, want %+v", i, rr.HeaderMap, tt.wheader)
|
||||
}
|
||||
}
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user