gobyexample/server.go
2013-09-11 08:40:37 -07:00

209 lines
4.6 KiB
Go

package main
import (
"encoding/base64"
"fmt"
"github.com/gorilla/mux"
"io/ioutil"
"net"
"net/http"
"os"
"os/signal"
"strings"
"sync/atomic"
"syscall"
"time"
)
// check panics with err when err is non-nil and does nothing otherwise.
func check(err error) {
	if err == nil {
		return
	}
	panic(err)
}
// config returns the value of the environment variable k. It panics with
// the message "missing <k>" when the variable is unset or empty.
func config(k string) string {
	if v := os.Getenv(k); v != "" {
		return v
	}
	panic("missing " + k)
}
// runLogging prints every message received on logs to standard output, one
// per line, and returns once the channel has been closed and emptied.
func runLogging(logs chan string) {
	for {
		line, open := <-logs
		if !open {
			return
		}
		fmt.Println(line)
	}
}
// wrapLogging decorates f so that, once f has returned, one line of the form
//
//	request at=finish method=<method> path=<path> elapsed=<ms>
//
// is sent on logs. elapsed is the time spent in f, in milliseconds. The send
// blocks if logs cannot accept the line.
func wrapLogging(f http.HandlerFunc, logs chan string) http.HandlerFunc {
	// time.Duration counts nanoseconds; this converts it to milliseconds.
	const nanosPerMilli = 1000000.0

	return func(res http.ResponseWriter, req *http.Request) {
		began := time.Now()
		f(res, req)
		millis := float64(time.Since(began)) / nanosPerMilli
		logs <- fmt.Sprintf("request at=finish method=%s path=%s elapsed=%f", req.Method, req.URL.Path, millis)
	}
}
// wrapCanonicalHost decorates f so that it only runs for requests addressed
// to canonicalHost and, when forceHttps is set, received over HTTPS. Any
// other request is answered with a 301 redirect to the same path and query
// on canonicalHost.
//
// The request scheme is taken from the X-Forwarded-Proto header: "https"
// when the header's first value is exactly "https", "http" otherwise. A port
// present in the request's Host is carried over to the redirect target.
func wrapCanonicalHost(f http.HandlerFunc, canonicalHost string, forceHttps bool) http.HandlerFunc {
	return func(res http.ResponseWriter, req *http.Request) {
		// Header.Get returns "" for an absent or empty header, so there is
		// no slice index that could go out of range.
		scheme := "http"
		if req.Header.Get("X-Forwarded-Proto") == "https" {
			scheme = "https"
		}

		// SplitHostPort understands bracketed IPv6 literals such as
		// "[::1]:8080", which a plain split on ":" tears apart. It fails
		// when Host carries no port; the whole value is the host then.
		host, port, err := net.SplitHostPort(req.Host)
		if err != nil {
			host, port = req.Host, ""
		}

		if host == canonicalHost && (!forceHttps || scheme == "https") {
			f(res, req)
			return
		}

		if forceHttps {
			scheme = "https"
		}
		target := canonicalHost
		if port != "" {
			target = net.JoinHostPort(canonicalHost, port)
		}
		// RequestURI is always just path and query. URL.String would also
		// include scheme and host for absolute-form request lines and
		// corrupt the redirect target.
		http.Redirect(res, req, scheme+"://"+target+req.URL.RequestURI(), http.StatusMovedPermanently)
	}
}
// Authenticator reports whether the given user name and password, in that
// order, are acceptable credentials.
type Authenticator func(string, string) bool
// testAuth reports whether r carries HTTP Basic credentials that auth
// accepts.
//
// The Authorization header must have the form "Basic <base64>", where the
// decoded payload is "user:password". The payload is split at the first
// colon, so the password may itself contain colons. A missing or malformed
// header yields false without calling auth.
func testAuth(r *http.Request, auth Authenticator) bool {
	s := strings.SplitN(r.Header.Get("Authorization"), " ", 2)
	// Authentication scheme names are case-insensitive (RFC 7235, section
	// 2.1), so "basic" and "BASIC" are as valid as "Basic".
	if len(s) != 2 || !strings.EqualFold(s[0], "Basic") {
		return false
	}
	b, err := base64.StdEncoding.DecodeString(s[1])
	if err != nil {
		return false
	}
	pair := strings.SplitN(string(b), ":", 2)
	if len(pair) != 2 {
		return false
	}
	return auth(pair[0], pair[1])
}
// requireAuth answers with 401 Unauthorized and a WWW-Authenticate header
// asking for Basic credentials in the realm "private". The request r is not
// inspected.
func requireAuth(w http.ResponseWriter, r *http.Request) {
	const body = "401 Unauthorized\n"

	w.Header().Set("WWW-Authenticate", `Basic realm="private"`)
	w.WriteHeader(http.StatusUnauthorized)
	_, _ = w.Write([]byte(body))
}
// wrapAuth decorates h so that it only runs for requests whose credentials
// pass testAuth with a; every other request is answered by requireAuth.
func wrapAuth(h http.HandlerFunc, a Authenticator) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		if !testAuth(r, a) {
			requireAuth(w, r)
			return
		}
		h(w, r)
	}
}
// reqCount is the number of requests currently being handled. It starts at
// zero and is accessed through sync/atomic.
var reqCount int64
// wrapReqCount decorates h so that *reqCountPtr holds the number of calls
// to h that are currently in progress: the counter is atomically incremented
// before h runs and decremented when it finishes.
func wrapReqCount(h http.HandlerFunc, reqCountPtr *int64) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		atomic.AddInt64(reqCountPtr, 1)
		// Deferred so the counter is restored even when h panics. net/http
		// recovers handler panics and keeps serving, so a skipped decrement
		// would leave the count above zero permanently.
		defer atomic.AddInt64(reqCountPtr, -1)
		h(w, r)
	}
}
// static serves the file found by appending the request's URL path to the
// "public" directory, relative to the working directory.
func static(res http.ResponseWriter, req *http.Request) {
	const root = "public"

	file := root + req.URL.Path
	http.ServeFile(res, req, file)
}
// notFound answers with status 404 and the contents of public/404.html.
//
// The page is written directly rather than through http.ServeFile, which
// would send it with status 200 and would answer request paths ending in
// "/index.html" with a redirect. If the page cannot be read, the standard
// library's plain-text 404 response is sent instead.
func notFound(res http.ResponseWriter, req *http.Request) {
	page, err := ioutil.ReadFile("public/404.html")
	if err != nil {
		http.NotFound(res, req)
		return
	}
	res.Header().Set("Content-Type", "text/html; charset=utf-8")
	res.WriteHeader(http.StatusNotFound)
	// A failed write means the client went away; there is nothing to do.
	_, _ = res.Write(page)
}
// checkAuth reports whether "user:pass" equals the value of the AUTH
// environment variable. When AUTH is unset or empty every credential pair
// is accepted.
func checkAuth(user, pass string) bool {
	want := os.Getenv("AUTH")
	return want == "" || want == user+":"+pass
}
// routerHandlerFunc adapts router to an http.HandlerFunc, so that it can be
// passed to the wrap* decorators, which operate on handler functions.
func routerHandlerFunc(router *mux.Router) http.HandlerFunc {
	// The bound method value has exactly the handler-function signature.
	return router.ServeHTTP
}
// router builds the routing table for the site. GET requests are routed to
// static for a fixed set of paths and for "/<name>" of every entry in the
// "public" directory whose name contains no dot. Everything else is handled
// by notFound. It panics (via check) when "public" cannot be listed.
func router() *mux.Router {
	r := mux.NewRouter()

	// Fixed paths served straight from public/.
	fixed := []string{"/", "/favicon.ico", "/play.png", "/site.css"}
	for _, path := range fixed {
		r.HandleFunc(path, static).Methods("GET")
	}

	// Dot-less entries of public/, presumably the extension-less pages of
	// the site, as opposed to assets like the ones listed above.
	entries, err := ioutil.ReadDir("public")
	check(err)
	for _, entry := range entries {
		name := entry.Name()
		if strings.Contains(name, ".") {
			continue
		}
		r.HandleFunc("/"+name, static).Methods("GET")
	}

	r.NotFoundHandler = http.HandlerFunc(notFound)
	return r
}
// main assembles the handler chain, serves HTTP on the port named by the
// PORT environment variable and shuts down gracefully on SIGINT or SIGTERM.
//
// Configuration comes from the environment: PORT, CANONICAL_HOST and
// FORCE_HTTPS are required (config panics when one is missing); AUTH is
// optional and enables Basic authentication when non-empty.
//
// Shutdown first stops accepting connections and disables keep-alives, and
// only then waits for in-flight requests to finish. Draining while the
// listener is still open could go on forever under steady traffic, and
// requests accepted just before exit would be cut off.
func main() {
	logs := make(chan string, 10000)
	go runLogging(logs)

	// Decorators are applied inside-out: the request counter is outermost,
	// so a request counts as in flight until its log line has been sent.
	handler := routerHandlerFunc(router())
	if os.Getenv("AUTH") != "" {
		handler = wrapAuth(handler, checkAuth)
	}
	handler = wrapCanonicalHost(handler, config("CANONICAL_HOST"), config("FORCE_HTTPS") == "1")
	handler = wrapLogging(handler, logs)
	handler = wrapReqCount(handler, &reqCount)

	server := &http.Server{Handler: handler}
	listener, err := net.Listen("tcp", ":"+config("PORT"))
	check(err)

	// Register for signals before serving so an early signal is not lost.
	sig := make(chan os.Signal, 1)
	signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM)
	logs <- "trap at=start"

	go func() {
		logs <- "serve at=start"
		// Serve returns an error once the listener is closed below; that
		// is the expected way out, so the error is deliberately dropped.
		_ = server.Serve(listener)
		logs <- "serve at=finish"
	}()

	<-sig

	// Stop taking new work before waiting for the current work to end.
	logs <- "close at=start"
	if err := listener.Close(); err != nil {
		panic(err)
	}
	// Without keep-alives, open connections are closed after their
	// current response instead of delivering further requests.
	server.SetKeepAlivesEnabled(false)
	logs <- "close at=finish"

	for {
		remaining := atomic.LoadInt64(&reqCount)
		if remaining <= 0 {
			break
		}
		logs <- fmt.Sprintf("trap at=draining remaining=%d", remaining)
		time.Sleep(time.Second)
	}
	// NOTE: lines still queued in logs when main returns are not
	// guaranteed to be printed by the runLogging goroutine.
	logs <- "trap at=finish"
}