server-mux: Simplify graceful shutdown behavior (#3681)

`*http.Server` is no more used, doing some cleanup.
master
Anis Elleuch 8 years ago committed by Harshavardhana
parent ed4fcb63f7
commit b6ebf2aba8
  1. 4
      cmd/server-main.go
  2. 5
      cmd/server-main_test.go
  3. 119
      cmd/server-mux.go
  4. 143
      cmd/server-mux_test.go
  5. 5
      cmd/server-startup-utils.go

@ -449,8 +449,8 @@ func serverMain(c *cli.Context) {
initGlobalAdminPeers(endpoints) initGlobalAdminPeers(endpoints)
// Determine API endpoints where we are going to serve the S3 API from. // Determine API endpoints where we are going to serve the S3 API from.
apiEndPoints, err := finalizeAPIEndpoints(apiServer.Server) apiEndPoints, err := finalizeAPIEndpoints(apiServer.Addr)
fatalIf(err, "Unable to finalize API endpoints for %s", apiServer.Server.Addr) fatalIf(err, "Unable to finalize API endpoints for %s", apiServer.Addr)
// Set the global API endpoints value. // Set the global API endpoints value.
globalAPIEndpoints = apiEndPoints globalAPIEndpoints = apiEndPoints

@ -19,7 +19,6 @@ package cmd
import ( import (
"errors" "errors"
"flag" "flag"
"net/http"
"os" "os"
"reflect" "reflect"
"runtime" "runtime"
@ -120,9 +119,7 @@ func TestFinalizeAPIEndpoints(t *testing.T) {
} }
for i, test := range testCases { for i, test := range testCases {
endPoints, err := finalizeAPIEndpoints(&http.Server{ endPoints, err := finalizeAPIEndpoints(test.addr)
Addr: test.addr,
})
if err != nil && len(endPoints) <= 0 { if err != nil && len(endPoints) <= 0 {
t.Errorf("Test case %d returned with no API end points for %s", t.Errorf("Test case %d returned with no API end points for %s",
i+1, test.addr) i+1, test.addr)

@ -309,32 +309,28 @@ func (l *ListenerMux) Accept() (net.Conn, error) {
// ServerMux - the main mux server // ServerMux - the main mux server
type ServerMux struct { type ServerMux struct {
*http.Server Addr string
listeners []*ListenerMux handler http.Handler
WaitGroup *sync.WaitGroup listeners []*ListenerMux
GracefulTimeout time.Duration
mu sync.Mutex // guards closed, conns, and listener gracefulWait *sync.WaitGroup
closed bool gracefulTimeout time.Duration
conns map[net.Conn]http.ConnState // except terminal states
mu sync.Mutex // guards closed, and listener
closed bool
} }
// NewServerMux constructor to create a ServerMux // NewServerMux constructor to create a ServerMux
func NewServerMux(addr string, handler http.Handler) *ServerMux { func NewServerMux(addr string, handler http.Handler) *ServerMux {
m := &ServerMux{ m := &ServerMux{
Server: &http.Server{ Addr: addr,
Addr: addr, handler: handler,
Handler: handler,
MaxHeaderBytes: 1 << 20,
},
WaitGroup: &sync.WaitGroup{},
// Wait for 5 seconds for new incoming connnections, otherwise // Wait for 5 seconds for new incoming connnections, otherwise
// forcibly close them during graceful stop or restart. // forcibly close them during graceful stop or restart.
GracefulTimeout: 5 * time.Second, gracefulTimeout: 5 * time.Second,
gracefulWait: &sync.WaitGroup{},
} }
// Track connection state
m.connState()
// Returns configured HTTP server. // Returns configured HTTP server.
return m return m
} }
@ -421,7 +417,7 @@ func (m *ServerMux) ListenAndServe(certFile, keyFile string) (err error) {
go m.handleServiceSignals() go m.handleServiceSignals()
listeners, err := initListeners(m.Server.Addr, config) listeners, err := initListeners(m.Addr, config)
if err != nil { if err != nil {
return err return err
} }
@ -445,8 +441,11 @@ func (m *ServerMux) ListenAndServe(certFile, keyFile string) (err error) {
} }
http.Redirect(w, r, u.String(), http.StatusTemporaryRedirect) http.Redirect(w, r, u.String(), http.StatusTemporaryRedirect)
} else { } else {
// Execute registered handlers // Execute registered handlers, protect with a waitgroup
m.Server.Handler.ServeHTTP(w, r) // to accomplish a graceful shutdown when the user asks to quit
m.gracefulWait.Add(1)
m.handler.ServeHTTP(w, r)
m.gracefulWait.Done()
} }
}) })
@ -470,6 +469,7 @@ func (m *ServerMux) ListenAndServe(certFile, keyFile string) (err error) {
// Close initiates the graceful shutdown // Close initiates the graceful shutdown
func (m *ServerMux) Close() error { func (m *ServerMux) Close() error {
m.mu.Lock() m.mu.Lock()
if m.closed { if m.closed {
m.mu.Unlock() m.mu.Unlock()
return errors.New("Server has been closed") return errors.New("Server has been closed")
@ -484,76 +484,21 @@ func (m *ServerMux) Close() error {
return err return err
} }
} }
m.SetKeepAlivesEnabled(false)
// Force close any idle and new connections. Waiting for other connections
// to close on their own (within the timeout period)
for c, st := range m.conns {
if st == http.StateIdle || st == http.StateNew {
c.Close()
}
}
// If the GracefulTimeout happens then forcefully close all connections
t := time.AfterFunc(m.GracefulTimeout, func() {
for c := range m.conns {
c.Close()
}
})
// Wait for graceful timeout of connections.
defer t.Stop()
m.mu.Unlock() m.mu.Unlock()
// Block until all connections are closed // Prepare for a graceful shutdown
m.WaitGroup.Wait() waitSignal := make(chan struct{})
go func() {
return nil defer close(waitSignal)
} m.gracefulWait.Wait()
}()
// connState setups the ConnState tracking hook to know which connections are idle
func (m *ServerMux) connState() {
// Set our ConnState to track idle connections
m.Server.ConnState = func(c net.Conn, cs http.ConnState) {
m.mu.Lock()
defer m.mu.Unlock()
switch cs {
case http.StateNew:
// New connections increment the WaitGroup and are added the the conns dictionary
m.WaitGroup.Add(1)
if m.conns == nil {
m.conns = make(map[net.Conn]http.ConnState)
}
m.conns[c] = cs
case http.StateActive:
// Only update status to StateActive if it's in the conns dictionary
if _, ok := m.conns[c]; ok {
m.conns[c] = cs
}
case http.StateIdle:
// Only update status to StateIdle if it's in the conns dictionary
if _, ok := m.conns[c]; ok {
m.conns[c] = cs
}
// If we've already closed then we need to close this connection. select {
// We don't allow connections to become idle after server is closed // Wait for everything to be properly closed
if m.closed { case <-waitSignal:
c.Close() // Forced shutdown
} case <-time.After(m.gracefulTimeout):
case http.StateHijacked, http.StateClosed:
// If the connection is hijacked or closed we forget it
m.forgetConn(c)
}
} }
}
// forgetConn removes c from conns and decrements WaitGroup return nil
func (m *ServerMux) forgetConn(c net.Conn) {
if _, ok := m.conns[c]; ok {
delete(m.conns, c)
m.WaitGroup.Done()
}
} }

@ -29,7 +29,6 @@ import (
"math/big" "math/big"
"net" "net"
"net/http" "net/http"
"net/http/httptest"
"os" "os"
"runtime" "runtime"
"sync" "sync"
@ -181,113 +180,111 @@ func TestClose(t *testing.T) {
} }
func TestServerMux(t *testing.T) { func TestServerMux(t *testing.T) {
ts := httptest.NewUnstartedServer(nil) var err error
defer ts.Close() var got []byte
var res *http.Response
// Create ServerMux // Create ServerMux
m := NewServerMux("", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { m := NewServerMux("127.0.0.1:0", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "hello") fmt.Fprint(w, "hello")
})) }))
// Start serving requests
// Set the test server config to the mux go m.ListenAndServe("", "")
ts.Config = m.Server
ts.Start() // Issue a GET request. Since we started server in a goroutine, it could be not ready
// at this point. So we allow until 5 failed retries before declare there is an error
// Create a ListenerMux for i := 0; i < 5; i++ {
lm := &ListenerMux{ // Sleep one second
Listener: ts.Listener, time.Sleep(1 * time.Second)
config: &tls.Config{}, // Check if one listener is ready
cond: sync.NewCond(&sync.Mutex{}), m.mu.Lock()
} if len(m.listeners) == 0 {
m.listeners = []*ListenerMux{lm} m.mu.Unlock()
continue
client := http.Client{} }
res, err := client.Get(ts.URL) m.mu.Unlock()
if err != nil { // Issue the GET request
t.Fatal(err) client := http.Client{}
m.mu.Lock()
res, err = client.Get("http://" + m.listeners[0].Addr().String())
m.mu.Unlock()
if err != nil {
continue
}
// Read the request response
got, err = ioutil.ReadAll(res.Body)
} }
got, err := ioutil.ReadAll(res.Body) // Check for error persisted after 5 times
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
// Check the web service response
if string(got) != "hello" { if string(got) != "hello" {
t.Errorf("got %q, want hello", string(got)) t.Errorf("got %q, want hello", string(got))
} }
// Make sure there is only 1 connection
m.mu.Lock()
if len(m.conns) < 1 {
t.Fatal("Should have 1 connections")
}
m.mu.Unlock()
// Close the server
m.Close()
// Make sure there are zero connections
m.mu.Lock()
if len(m.conns) > 0 {
t.Fatal("Should have 0 connections")
}
m.mu.Unlock()
} }
func TestServerCloseBlocking(t *testing.T) { func TestServerCloseBlocking(t *testing.T) {
ts := httptest.NewUnstartedServer(nil)
defer ts.Close()
// Create ServerMux // Create ServerMux
m := NewServerMux("", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { m := NewServerMux("127.0.0.1:0", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "hello") fmt.Fprint(w, "hello")
})) }))
// Set the test server config to the mux // Start serving requests in a goroutine
ts.Config = m.Server go m.ListenAndServe("", "")
ts.Start()
// Dial, try until 5 times before declaring a failure
// Create a ListenerMux. dial := func() (net.Conn, error) {
lm := &ListenerMux{ var c net.Conn
Listener: ts.Listener, var err error
config: &tls.Config{}, for i := 0; i < 5; i++ {
cond: sync.NewCond(&sync.Mutex{}), // Sleep one second in case of the server is not ready yet
} time.Sleep(1 * time.Second)
m.listeners = []*ListenerMux{lm} // Check if there is at least one listener configured
m.mu.Lock()
dial := func() net.Conn { if len(m.listeners) == 0 {
c, cerr := net.Dial("tcp", ts.Listener.Addr().String()) m.mu.Unlock()
if cerr != nil { continue
t.Fatal(cerr) }
m.mu.Unlock()
// Run the actual Dial
m.mu.Lock()
c, err = net.Dial("tcp", m.listeners[0].Addr().String())
m.mu.Unlock()
if err != nil {
continue
}
} }
return c return c, err
} }
// Dial to open a StateNew but don't send anything // Dial to open a StateNew but don't send anything
cnew := dial() cnew, err := dial()
if err != nil {
t.Fatal(err)
}
defer cnew.Close() defer cnew.Close()
// Dial another connection but idle after a request to have StateIdle // Dial another connection but idle after a request to have StateIdle
cidle := dial() cidle, err := dial()
if err != nil {
t.Fatal(err)
}
defer cidle.Close() defer cidle.Close()
cidle.Write([]byte("HEAD / HTTP/1.1\r\nHost: foo\r\n\r\n")) cidle.Write([]byte("HEAD / HTTP/1.1\r\nHost: foo\r\n\r\n"))
_, err := http.ReadResponse(bufio.NewReader(cidle), nil) _, err = http.ReadResponse(bufio.NewReader(cidle), nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
// Make sure we don't block forever. // Make sure we don't block forever.
m.Close() m.Close()
// Make sure there are zero connections
m.mu.Lock()
if len(m.conns) > 0 {
t.Fatal("Should have 0 connections")
}
m.mu.Unlock()
} }
func TestListenAndServePlain(t *testing.T) { func TestServerListenAndServePlain(t *testing.T) {
wait := make(chan struct{}) wait := make(chan struct{})
addr := net.JoinHostPort("127.0.0.1", getFreePort()) addr := net.JoinHostPort("127.0.0.1", getFreePort())
errc := make(chan error) errc := make(chan error)
@ -295,8 +292,6 @@ func TestListenAndServePlain(t *testing.T) {
// Initialize done channel specifically for each tests. // Initialize done channel specifically for each tests.
globalServiceDoneCh = make(chan struct{}, 1) globalServiceDoneCh = make(chan struct{}, 1)
// Initialize signal channel specifically for each tests.
globalServiceSignalCh = make(chan serviceSignal, 1)
// Create ServerMux and when we receive a request we stop waiting // Create ServerMux and when we receive a request we stop waiting
m := NewServerMux(addr, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { m := NewServerMux(addr, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@ -337,7 +332,7 @@ func TestListenAndServePlain(t *testing.T) {
} }
} }
func TestListenAndServeTLS(t *testing.T) { func TestServerListenAndServeTLS(t *testing.T) {
wait := make(chan struct{}) wait := make(chan struct{})
addr := net.JoinHostPort("127.0.0.1", getFreePort()) addr := net.JoinHostPort("127.0.0.1", getFreePort())
errc := make(chan error) errc := make(chan error)

@ -19,7 +19,6 @@ package cmd
import ( import (
"fmt" "fmt"
"net" "net"
"net/http"
) )
// getListenIPs - gets all the ips to listen on. // getListenIPs - gets all the ips to listen on.
@ -49,7 +48,7 @@ func getListenIPs(serverAddr string) (hosts []string, port string, err error) {
} }
// Finalizes the API endpoints based on the host list and port. // Finalizes the API endpoints based on the host list and port.
func finalizeAPIEndpoints(apiServer *http.Server) (endPoints []string, err error) { func finalizeAPIEndpoints(addr string) (endPoints []string, err error) {
// Verify current scheme. // Verify current scheme.
scheme := httpScheme scheme := httpScheme
if globalIsSSL { if globalIsSSL {
@ -57,7 +56,7 @@ func finalizeAPIEndpoints(apiServer *http.Server) (endPoints []string, err error
} }
// Get list of listen ips and port. // Get list of listen ips and port.
hosts, port, err1 := getListenIPs(apiServer.Addr) hosts, port, err1 := getListenIPs(addr)
if err1 != nil { if err1 != nil {
return nil, err1 return nil, err1
} }

Loading…
Cancel
Save