diff --git a/cors.go b/cors.go index f23ab6d19..19e92007b 100644 --- a/cors.go +++ b/cors.go @@ -20,24 +20,37 @@ import ( "fmt" "net/http" "net/url" + "strings" ) type CORSInfo map[string]bool -func newCORSInfo(origins []string) (*CORSInfo, error) { - // Construct a lookup of all origins. +// CORSInfo implements the flag.Value interface to allow users to define a list of CORS origins +func (ci *CORSInfo) Set(s string) error { m := make(map[string]bool) - for _, v := range origins { + for _, v := range strings.Split(s, ",") { + v = strings.TrimSpace(v) + if v == "" { + continue + } if v != "*" { if _, err := url.Parse(v); err != nil { - return nil, fmt.Errorf("Invalid CORS origin: %s", err) + return fmt.Errorf("Invalid CORS origin: %s", err) } } m[v] = true - } - info := CORSInfo(m) - return &info, nil + } + *ci = CORSInfo(m) + return nil +} + +func (ci *CORSInfo) String() string { + o := make([]string, 0) + for k, _ := range *ci { + o = append(o, k) + } + return strings.Join(o, ",") } // OriginAllowed determines whether the server will allow a given CORS origin. diff --git a/main.go b/main.go index 874b3a88e..c59b40f7b 100644 --- a/main.go +++ b/main.go @@ -40,6 +40,7 @@ var ( peers = &etcdhttp.Peers{} addrs = &Addrs{} + cors = &CORSInfo{} proxyFlag = new(ProxyFlag) proxyFlagValues = []string{ @@ -52,6 +53,7 @@ var ( func init() { flag.Var(peers, "peers", "your peers") flag.Var(addrs, "bind-addr", "List of HTTP service addresses (e.g., '127.0.0.1:4001,10.0.0.1:8080')") + flag.Var(cors, "cors", "Comma-separated white list of origins for CORS (cross-origin resource sharing).") flag.Var(proxyFlag, "proxy", fmt.Sprintf("Valid values include %s", strings.Join(proxyFlagValues, ", "))) peers.Set("0x1=localhost:8080") addrs.Set("127.0.0.1:4001") @@ -156,8 +158,14 @@ func startEtcd() { } s.Start() - ch := etcdhttp.NewClientHandler(s, *peers, *timeout) - ph := etcdhttp.NewPeerHandler(s) + ch := &CORSHandler{ + Handler: etcdhttp.NewClientHandler(s, *peers, *timeout), + Info: cors, + } + ph := &CORSHandler{ + Handler: etcdhttp.NewPeerHandler(s), + Info: cors, + } // Start the peer server in a goroutine go func() { @@ -181,6 +189,10 @@ func startProxy() { if err != nil { log.Fatal(err) } + ph = &CORSHandler{ + Handler: ph, + Info: cors, + } if string(*proxyFlag) == proxyFlagValueReadonly { ph = proxy.NewReadonlyHandler(ph)