From 7e84c7427d08f9112241d820bfdca2825d0f895a Mon Sep 17 00:00:00 2001 From: Anis Elleuch Date: Sat, 18 Feb 2017 22:28:54 +0100 Subject: [PATCH] server-mux: Rewrite graceful shutdown mechanism (#3771) Old code uses waitgroup Add() and Wait() in different threads, which eventually can lead to a race. --- cmd/server-mux.go | 64 ++++++++++++++++++++++++++---------------- cmd/server-mux_test.go | 16 +++++++---- 2 files changed, 51 insertions(+), 29 deletions(-) diff --git a/cmd/server-mux.go b/cmd/server-mux.go index 8ea33a3d5..4208a5807 100644 --- a/cmd/server-mux.go +++ b/cmd/server-mux.go @@ -26,9 +26,14 @@ import ( "net/url" "strings" "sync" + "sync/atomic" "time" ) +const ( + serverShutdownPoll = 500 * time.Millisecond +) + // The value chosen below is longest word chosen // from all the http verbs comprising of // "PRI", "OPTIONS", "GET", "HEAD", "POST", @@ -324,11 +329,13 @@ type ServerMux struct { handler http.Handler listeners []*ListenerMux - gracefulWait *sync.WaitGroup + // Current number of concurrent http requests + currentReqs int32 + // Time to wait before forcing server shutdown gracefulTimeout time.Duration - mu sync.Mutex // guards closed, and listener - closed bool + mu sync.Mutex // guards closing, and listeners + closing bool } // NewServerMux constructor to create a ServerMux @@ -339,7 +346,6 @@ func NewServerMux(addr string, handler http.Handler) *ServerMux { // Wait for 5 seconds for new incoming connnections, otherwise // forcibly close them during graceful stop or restart. gracefulTimeout: 5 * time.Second, - gracefulWait: &sync.WaitGroup{}, } // Returns configured HTTP server. 
@@ -452,11 +458,22 @@ func (m *ServerMux) ListenAndServe(certFile, keyFile string) (err error) { } http.Redirect(w, r, u.String(), http.StatusTemporaryRedirect) } else { - // Execute registered handlers, protect with a waitgroup - // to accomplish a graceful shutdown when the user asks to quit - m.gracefulWait.Add(1) + + // Return ServiceUnavailable for clients which are sending requests + // in shutdown phase + m.mu.Lock() + closing := m.closing + m.mu.Unlock() + if closing { + w.WriteHeader(http.StatusServiceUnavailable) + return + } + + // Execute registered handlers, update currentReqs to keep + // tracks of current requests currently processed by the server + atomic.AddInt32(&m.currentReqs, 1) m.handler.ServeHTTP(w, r) - m.gracefulWait.Done() + atomic.AddInt32(&m.currentReqs, -1) } }) @@ -481,12 +498,12 @@ func (m *ServerMux) ListenAndServe(certFile, keyFile string) (err error) { func (m *ServerMux) Close() error { m.mu.Lock() - if m.closed { + if m.closing { m.mu.Unlock() return errors.New("Server has been closed") } // Closed completely. - m.closed = true + m.closing = true // Close the listeners. for _, listener := range m.listeners { @@ -497,19 +514,18 @@ func (m *ServerMux) Close() error { } m.mu.Unlock() - // Prepare for a graceful shutdown - waitSignal := make(chan struct{}) - go func() { - defer close(waitSignal) - m.gracefulWait.Wait() - }() - - select { - // Wait for everything to be properly closed - case <-waitSignal: - // Forced shutdown - case <-time.After(m.gracefulTimeout): + // Starting graceful shutdown. 
Check if all requests are finished + // at regular intervals, or force the shutdown + ticker := time.NewTicker(serverShutdownPoll) + defer ticker.Stop() + for { + select { + case <-time.After(m.gracefulTimeout): + return nil + case <-ticker.C: + if atomic.LoadInt32(&m.currentReqs) <= 0 { + return nil + } + } } - - return nil } diff --git a/cmd/server-mux_test.go b/cmd/server-mux_test.go index 1a24473a6..24802abed 100644 --- a/cmd/server-mux_test.go +++ b/cmd/server-mux_test.go @@ -198,21 +198,27 @@ func TestServerMux(t *testing.T) { time.Sleep(1 * time.Second) // Check if one listener is ready m.mu.Lock() - if len(m.listeners) == 0 { - m.mu.Unlock() + listenersCount := len(m.listeners) + m.mu.Unlock() + if listenersCount == 0 { continue } + m.mu.Lock() + listenerAddr := m.listeners[0].Addr().String() m.mu.Unlock() // Issue the GET request client := http.Client{} - m.mu.Lock() - res, err = client.Get("http://" + m.listeners[0].Addr().String()) - m.mu.Unlock() + res, err = client.Get("http://" + listenerAddr) if err != nil { continue } // Read the request response got, err = ioutil.ReadAll(res.Body) + if err != nil { + continue + } + // We've got a response, quit the loop + break } // Check for error persisted after 5 times