server-mux: Rewrite graceful shutdown mechanism (#3771)

Old code uses waitgroup Add() and Wait() in different threads,
which eventually can lead to a race.
master
Anis Elleuch 8 years ago committed by Harshavardhana
parent d12f3e06b1
commit 7e84c7427d
  1. 64
      cmd/server-mux.go
  2. 16
      cmd/server-mux_test.go

@ -26,9 +26,14 @@ import (
"net/url" "net/url"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time" "time"
) )
const (
serverShutdownPoll = 500 * time.Millisecond
)
// The value chosen below is longest word chosen // The value chosen below is longest word chosen
// from all the http verbs comprising of // from all the http verbs comprising of
// "PRI", "OPTIONS", "GET", "HEAD", "POST", // "PRI", "OPTIONS", "GET", "HEAD", "POST",
@ -324,11 +329,13 @@ type ServerMux struct {
handler http.Handler handler http.Handler
listeners []*ListenerMux listeners []*ListenerMux
gracefulWait *sync.WaitGroup // Current number of concurrent http requests
currentReqs int32
// Time to wait before forcing server shutdown
gracefulTimeout time.Duration gracefulTimeout time.Duration
mu sync.Mutex // guards closed, and listener mu sync.Mutex // guards closing, and listeners
closed bool closing bool
} }
// NewServerMux constructor to create a ServerMux // NewServerMux constructor to create a ServerMux
@ -339,7 +346,6 @@ func NewServerMux(addr string, handler http.Handler) *ServerMux {
// Wait for 5 seconds for new incoming connnections, otherwise // Wait for 5 seconds for new incoming connnections, otherwise
// forcibly close them during graceful stop or restart. // forcibly close them during graceful stop or restart.
gracefulTimeout: 5 * time.Second, gracefulTimeout: 5 * time.Second,
gracefulWait: &sync.WaitGroup{},
} }
// Returns configured HTTP server. // Returns configured HTTP server.
@ -452,11 +458,22 @@ func (m *ServerMux) ListenAndServe(certFile, keyFile string) (err error) {
} }
http.Redirect(w, r, u.String(), http.StatusTemporaryRedirect) http.Redirect(w, r, u.String(), http.StatusTemporaryRedirect)
} else { } else {
// Execute registered handlers, protect with a waitgroup
// to accomplish a graceful shutdown when the user asks to quit // Return ServiceUnavailable for clients which are sending requests
m.gracefulWait.Add(1) // in shutdown phase
m.mu.Lock()
closing := m.closing
m.mu.Unlock()
if closing {
w.WriteHeader(http.StatusServiceUnavailable)
return
}
// Execute registered handlers, update currentReqs to keep
// tracks of current requests currently processed by the server
atomic.AddInt32(&m.currentReqs, 1)
m.handler.ServeHTTP(w, r) m.handler.ServeHTTP(w, r)
m.gracefulWait.Done() atomic.AddInt32(&m.currentReqs, -1)
} }
}) })
@ -481,12 +498,12 @@ func (m *ServerMux) ListenAndServe(certFile, keyFile string) (err error) {
func (m *ServerMux) Close() error { func (m *ServerMux) Close() error {
m.mu.Lock() m.mu.Lock()
if m.closed { if m.closing {
m.mu.Unlock() m.mu.Unlock()
return errors.New("Server has been closed") return errors.New("Server has been closed")
} }
// Closed completely. // Closed completely.
m.closed = true m.closing = true
// Close the listeners. // Close the listeners.
for _, listener := range m.listeners { for _, listener := range m.listeners {
@ -497,19 +514,18 @@ func (m *ServerMux) Close() error {
} }
m.mu.Unlock() m.mu.Unlock()
// Prepare for a graceful shutdown // Starting graceful shutdown. Check if all requests are finished
waitSignal := make(chan struct{}) // in regular interval or force the shutdown
go func() { ticker := time.NewTicker(serverShutdownPoll)
defer close(waitSignal) defer ticker.Stop()
m.gracefulWait.Wait() for {
}() select {
case <-time.After(m.gracefulTimeout):
select { return nil
// Wait for everything to be properly closed case <-ticker.C:
case <-waitSignal: if atomic.LoadInt32(&m.currentReqs) <= 0 {
// Forced shutdown return nil
case <-time.After(m.gracefulTimeout): }
}
} }
return nil
} }

@ -198,21 +198,27 @@ func TestServerMux(t *testing.T) {
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
// Check if one listener is ready // Check if one listener is ready
m.mu.Lock() m.mu.Lock()
if len(m.listeners) == 0 { listenersCount := len(m.listeners)
m.mu.Unlock() m.mu.Unlock()
if listenersCount == 0 {
continue continue
} }
m.mu.Lock()
listenerAddr := m.listeners[0].Addr().String()
m.mu.Unlock() m.mu.Unlock()
// Issue the GET request // Issue the GET request
client := http.Client{} client := http.Client{}
m.mu.Lock() res, err = client.Get("http://" + listenerAddr)
res, err = client.Get("http://" + m.listeners[0].Addr().String())
m.mu.Unlock()
if err != nil { if err != nil {
continue continue
} }
// Read the request response // Read the request response
got, err = ioutil.ReadAll(res.Body) got, err = ioutil.ReadAll(res.Body)
if err != nil {
continue
}
// We've got a response, quit the loop
break
} }
// Check for error persisted after 5 times // Check for error persisted after 5 times

Loading…
Cancel
Save