From c8994cff37f072bbe73feb388f9f0559f2d03517 Mon Sep 17 00:00:00 2001 From: Yicheng Qin Date: Mon, 12 Jan 2015 17:45:17 -0800 Subject: [PATCH] pkg/cors: add tests --- pkg/cors/cors_test.go | 126 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 126 insertions(+) create mode 100644 pkg/cors/cors_test.go diff --git a/pkg/cors/cors_test.go b/pkg/cors/cors_test.go new file mode 100644 index 000000000..894d42cb8 --- /dev/null +++ b/pkg/cors/cors_test.go @@ -0,0 +1,126 @@ +/* + Copyright 2014 CoreOS, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +package cors + +import ( + "net/http" + "net/http/httptest" + "reflect" + "testing" +) + +func TestCORSInfo(t *testing.T) { + tests := []struct { + s string + winfo CORSInfo + ws string + }{ + {"", CORSInfo{}, ""}, + {"http://127.0.0.1", CORSInfo{"http://127.0.0.1": true}, "http://127.0.0.1"}, + {"*", CORSInfo{"*": true}, "*"}, + // with space around + {" http://127.0.0.1 ", CORSInfo{"http://127.0.0.1": true}, "http://127.0.0.1"}, + // multiple addrs + { + "http://127.0.0.1,http://127.0.0.2", + CORSInfo{"http://127.0.0.1": true, "http://127.0.0.2": true}, + "http://127.0.0.1,http://127.0.0.2", + }, + } + for i, tt := range tests { + info := CORSInfo{} + if err := info.Set(tt.s); err != nil { + t.Errorf("#%d: set error = %v, want nil", i, err) + } + if !reflect.DeepEqual(info, tt.winfo) { + t.Errorf("#%d: info = %v, want %v", i, info, tt.winfo) + } + if g := info.String(); g != tt.ws { + t.Errorf("#%d: info string = %s, want %s", i, g, tt.ws) + } + } +} + +func TestCORSInfoOriginAllowed(t *testing.T) { + tests := []struct { + set string + origin string + wallowed bool + }{ + {"http://127.0.0.1,http://127.0.0.2", "http://127.0.0.1", true}, + {"http://127.0.0.1,http://127.0.0.2", "http://127.0.0.2", true}, + {"http://127.0.0.1,http://127.0.0.2", "*", false}, + {"http://127.0.0.1,http://127.0.0.2", "http://127.0.0.3", false}, + {"*", "*", true}, + {"*", "http://127.0.0.1", true}, + } + for i, tt := range tests { + info := CORSInfo{} + if err := info.Set(tt.set); err != nil { + t.Errorf("#%d: set error = %v, want nil", i, err) + } + if g := info.OriginAllowed(tt.origin); g != tt.wallowed { + t.Errorf("#%d: allowed = %v, want %v", i, g, tt.wallowed) + } + } +} + +func TestCORSHandler(t *testing.T) { + info := &CORSInfo{} + if err := info.Set("http://127.0.0.1,http://127.0.0.2"); err != nil { + t.Fatalf("unexpected set error: %v", err) + } + h := &CORSHandler{ + Handler: http.NotFoundHandler(), + Info: info, + } + + header := func(origin string) http.Header { + return http.Header{ + "Access-Control-Allow-Methods": []string{"POST, GET, OPTIONS, PUT, DELETE"}, + "Access-Control-Allow-Origin": []string{origin}, + "Access-Control-Allow-Headers": []string{"accept, content-type"}, + } + } + tests := []struct { + method string + origin string + wcode int + wheader http.Header + }{ + {"GET", "http://127.0.0.1", http.StatusNotFound, header("http://127.0.0.1")}, + {"GET", "http://127.0.0.2", http.StatusNotFound, header("http://127.0.0.2")}, + {"GET", "http://127.0.0.3", http.StatusNotFound, http.Header{}}, + {"OPTIONS", "http://127.0.0.1", http.StatusOK, header("http://127.0.0.1")}, + } + for i, tt := range tests { + rr := httptest.NewRecorder() + req := &http.Request{ + Method: tt.method, + Header: http.Header{"Origin": []string{tt.origin}}, + } + h.ServeHTTP(rr, req) + if rr.Code != tt.wcode { + t.Errorf("#%d: code = %v, want %v", i, rr.Code, tt.wcode) + } + // it is set by http package, and there is no need to test it + rr.HeaderMap.Del("Content-Type") + if !reflect.DeepEqual(rr.HeaderMap, tt.wheader) { + t.Errorf("#%d: header = %+v, want %+v", i, rr.HeaderMap, tt.wheader) + } + } +}