diff --git a/etcdmain/etcd.go b/etcdmain/etcd.go index aa5ee572d..907f7540c 100644 --- a/etcdmain/etcd.go +++ b/etcdmain/etcd.go @@ -31,6 +31,7 @@ import ( "github.com/coreos/etcd/etcdserver" "github.com/coreos/etcd/etcdserver/etcdhttp" "github.com/coreos/etcd/pkg/cors" + "github.com/coreos/etcd/pkg/osutil" "github.com/coreos/etcd/pkg/transport" "github.com/coreos/etcd/pkg/types" "github.com/coreos/etcd/proxy" @@ -73,7 +74,10 @@ func Main() { } } + osutil.HandleInterrupts() + <-stopped + osutil.Exit(0) } // startEtcd launches the etcd server and HTTP handlers for client/server communication. @@ -160,6 +164,7 @@ func startEtcd(cfg *config) (<-chan struct{}, error) { return nil, err } s.Start() + osutil.RegisterInterruptHandler(s.Stop) if cfg.corsInfo.String() != "" { log.Printf("etcd: cors = %s", cfg.corsInfo) diff --git a/pkg/osutil/osutil.go b/pkg/osutil/osutil.go index 37b38329a..aa9b601ff 100644 --- a/pkg/osutil/osutil.go +++ b/pkg/osutil/osutil.go @@ -15,8 +15,12 @@ package osutil import ( + "log" "os" + "os/signal" "strings" + "sync" + "syscall" ) func Unsetenv(key string) error { @@ -33,3 +37,53 @@ func Unsetenv(key string) error { } return nil } + +// InterruptHandler is a function that is called on receiving a +// SIGTERM or SIGINT signal. +type InterruptHandler func() + +var ( + interruptRegisterMu, interruptExitMu sync.Mutex + // interruptHandlers holds all registered InterruptHandlers in order + // they will be executed. + interruptHandlers = []InterruptHandler{} +) + +// RegisterInterruptHandler registers a new InterruptHandler. Handlers registered +// after interrupt handing was initiated will not be executed. +func RegisterInterruptHandler(h InterruptHandler) { + interruptRegisterMu.Lock() + defer interruptRegisterMu.Unlock() + interruptHandlers = append(interruptHandlers, h) +} + +// HandleInterrupts calls the handler functions on receiving a SIGINT or SIGTERM. +func HandleInterrupts() { + notifier := make(chan os.Signal, 1) + signal.Notify(notifier, syscall.SIGINT, syscall.SIGTERM) + + go func() { + sig := <-notifier + + interruptRegisterMu.Lock() + ihs := make([]InterruptHandler, len(interruptHandlers)) + copy(ihs, interruptHandlers) + interruptRegisterMu.Unlock() + + interruptExitMu.Lock() + + log.Printf("received %v signal, shutting down", sig) + + for _, h := range ihs { + h() + } + signal.Stop(notifier) + syscall.Kill(syscall.Getpid(), sig.(syscall.Signal)) + }() +} + +// Exit relays to os.Exit if no interrupt handlers are running, blocks otherwise. +func Exit(code int) { + interruptExitMu.Lock() + os.Exit(code) +} diff --git a/pkg/osutil/osutil_test.go b/pkg/osutil/osutil_test.go index f2d88d141..bb8681925 100644 --- a/pkg/osutil/osutil_test.go +++ b/pkg/osutil/osutil_test.go @@ -16,8 +16,11 @@ package osutil import ( "os" + "os/signal" "reflect" + "syscall" "testing" + "time" ) func TestUnsetenv(t *testing.T) { @@ -43,3 +46,43 @@ func TestUnsetenv(t *testing.T) { } } } + +func waitSig(t *testing.T, c <-chan os.Signal, sig os.Signal) { + select { + case s := <-c: + if s != sig { + t.Fatalf("signal was %v, want %v", s, sig) + } + case <-time.After(1 * time.Second): + t.Fatalf("timeout waiting for %v", sig) + } +} + +func TestHandleInterrupts(t *testing.T) { + for _, sig := range []syscall.Signal{syscall.SIGINT, syscall.SIGTERM} { + n := 1 + RegisterInterruptHandler(func() { n++ }) + RegisterInterruptHandler(func() { n *= 2 }) + + c := make(chan os.Signal, 2) + signal.Notify(c, sig) + + HandleInterrupts() + syscall.Kill(syscall.Getpid(), sig) + + // we should receive the signal once from our own kill and + // a second time from HandleInterrupts + waitSig(t, c, sig) + waitSig(t, c, sig) + + if n == 3 { + t.Fatalf("interrupt handlers were called in wrong order") + } + if n != 4 { + t.Fatalf("interrupt handlers were not called properly") + } + // reset interrupt handlers + interruptHandlers = interruptHandlers[:0] + interruptExitMu.Unlock() + } +}