package main import ( "encoding/base64" "fmt" "github.com/gorilla/mux" "io/ioutil" "net" "net/http" "os" "os/signal" "strings" "sync/atomic" "syscall" "time" ) func check(err error) { if err != nil { panic(err) } } func config(k string) string { v := os.Getenv(k) if v == "" { panic("missing " + k) } return v } func runLogging(logs chan string) { for log := range logs { fmt.Println(log) } } func wrapLogging(f http.HandlerFunc, logs chan string) http.HandlerFunc { return func(res http.ResponseWriter, req *http.Request) { start := time.Now() f(res, req) method := req.Method path := req.URL.Path elapsed := float64(time.Since(start)) / 1000000.0 logs <- fmt.Sprintf("request at=finish method=%s path=%s elapsed=%f", method, path, elapsed) } } func wrapCanonicalHost(f http.HandlerFunc, canonicalHost string, forceHttps bool) http.HandlerFunc { return func(res http.ResponseWriter, req *http.Request) { scheme := "http" if h, ok := req.Header["X-Forwarded-Proto"]; ok { if h[0] == "https" { scheme = "https" } } hostPort := strings.Split(req.Host, ":") host := hostPort[0] if (forceHttps && (scheme != "https")) || host != canonicalHost { if forceHttps { scheme = "https" } hostPort[0] = canonicalHost url := scheme + "://" + strings.Join(hostPort, ":") + req.URL.String() http.Redirect(res, req, url, 301) return } f(res, req) } } type Authenticator func(string, string) bool func testAuth(r *http.Request, auth Authenticator) bool { s := strings.SplitN(r.Header.Get("Authorization"), " ", 2) if len(s) != 2 || s[0] != "Basic" { return false } b, err := base64.StdEncoding.DecodeString(s[1]) if err != nil { return false } pair := strings.SplitN(string(b), ":", 2) if len(pair) != 2 { return false } return auth(pair[0], pair[1]) } func requireAuth(w http.ResponseWriter, r *http.Request) { w.Header().Set("WWW-Authenticate", `Basic realm="private"`) w.WriteHeader(401) w.Write([]byte("401 Unauthorized\n")) } func wrapAuth(h http.HandlerFunc, a Authenticator) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { if testAuth(r, a) { h(w, r) } else { requireAuth(w, r) } } } var reqCount int64 = 0 func wrapReqCount(h http.HandlerFunc, reqCountPtr *int64) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { atomic.AddInt64(reqCountPtr, 1) h(w, r) atomic.AddInt64(reqCountPtr, -1) } } func static(res http.ResponseWriter, req *http.Request) { http.ServeFile(res, req, "public"+req.URL.Path) } func notFound(res http.ResponseWriter, req *http.Request) { http.ServeFile(res, req, "public/404.html") } func checkAuth(user, pass string) bool { auth := os.Getenv("AUTH") if auth == "" { return true } return auth == strings.Join([]string{user, pass}, ":") } func routerHandlerFunc(router *mux.Router) http.HandlerFunc { return func(res http.ResponseWriter, req *http.Request) { router.ServeHTTP(res, req) } } func router() *mux.Router { router := mux.NewRouter() router.HandleFunc("/", static).Methods("GET") router.HandleFunc("/favicon.ico", static).Methods("GET") router.HandleFunc("/play.png", static).Methods("GET") router.HandleFunc("/site.css", static).Methods("GET") entries, err := ioutil.ReadDir("public") check(err) for _, f := range entries { if !strings.Contains(f.Name(), ".") { router.HandleFunc("/" + f.Name(), static).Methods("GET") } } router.NotFoundHandler = http.HandlerFunc(notFound) return router } func main() { logs := make(chan string, 10000) go runLogging(logs) handler := routerHandlerFunc(router()) if os.Getenv("AUTH") != "" { handler = wrapAuth(handler, checkAuth) } handler = wrapCanonicalHost(handler, config("CANONICAL_HOST"), config("FORCE_HTTPS") == "1") handler = wrapLogging(handler, logs) handler = wrapReqCount(handler, &reqCount) server := &http.Server{Handler: handler} listener, listenErr := net.Listen("tcp", ":"+config("PORT")) if listenErr != nil { panic(listenErr) } stop := make(chan bool, 1) sig := make(chan os.Signal, 1) go func() { signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM) logs <- "trap at=start" <-sig for { reqCountCurrent := atomic.LoadInt64(&reqCount) if reqCountCurrent > 0 { logs <- fmt.Sprintf("trap at=draining remaining=%d", reqCountCurrent) time.Sleep(time.Second) } else { logs <- fmt.Sprintf("trap at=finish") stop <- true return } } }() go func() { logs <- "serve at=start" server.Serve(listener) logs <- "serve at=finish" }() <-stop logs <- "close at=start" closeErr := listener.Close() if closeErr != nil { panic(closeErr) } logs <- "close at=finish" }