diff --git a/cmd/server-main.go b/cmd/server-main.go index 1971ef442..e54143d73 100644 --- a/cmd/server-main.go +++ b/cmd/server-main.go @@ -449,8 +449,8 @@ func serverMain(c *cli.Context) { initGlobalAdminPeers(endpoints) // Determine API endpoints where we are going to serve the S3 API from. - apiEndPoints, err := finalizeAPIEndpoints(apiServer.Server) - fatalIf(err, "Unable to finalize API endpoints for %s", apiServer.Server.Addr) + apiEndPoints, err := finalizeAPIEndpoints(apiServer.Addr) + fatalIf(err, "Unable to finalize API endpoints for %s", apiServer.Addr) // Set the global API endpoints value. globalAPIEndpoints = apiEndPoints diff --git a/cmd/server-main_test.go b/cmd/server-main_test.go index df152b3ae..5fbc17029 100644 --- a/cmd/server-main_test.go +++ b/cmd/server-main_test.go @@ -19,7 +19,6 @@ package cmd import ( "errors" "flag" - "net/http" "os" "reflect" "runtime" @@ -120,9 +119,7 @@ func TestFinalizeAPIEndpoints(t *testing.T) { } for i, test := range testCases { - endPoints, err := finalizeAPIEndpoints(&http.Server{ - Addr: test.addr, - }) + endPoints, err := finalizeAPIEndpoints(test.addr) if err != nil && len(endPoints) <= 0 { t.Errorf("Test case %d returned with no API end points for %s", i+1, test.addr) diff --git a/cmd/server-mux.go b/cmd/server-mux.go index 16d525d28..ae2eb0b73 100644 --- a/cmd/server-mux.go +++ b/cmd/server-mux.go @@ -309,32 +309,28 @@ func (l *ListenerMux) Accept() (net.Conn, error) { // ServerMux - the main mux server type ServerMux struct { - *http.Server - listeners []*ListenerMux - WaitGroup *sync.WaitGroup - GracefulTimeout time.Duration - mu sync.Mutex // guards closed, conns, and listener - closed bool - conns map[net.Conn]http.ConnState // except terminal states + Addr string + handler http.Handler + listeners []*ListenerMux + + gracefulWait *sync.WaitGroup + gracefulTimeout time.Duration + + mu sync.Mutex // guards closed, and listener + closed bool } // NewServerMux constructor to create a ServerMux func NewServerMux(addr string, handler http.Handler) *ServerMux { m := &ServerMux{ - Server: &http.Server{ - Addr: addr, - Handler: handler, - MaxHeaderBytes: 1 << 20, - }, - WaitGroup: &sync.WaitGroup{}, + Addr: addr, + handler: handler, // Wait for 5 seconds for new incoming connnections, otherwise // forcibly close them during graceful stop or restart. - GracefulTimeout: 5 * time.Second, + gracefulTimeout: 5 * time.Second, + gracefulWait: &sync.WaitGroup{}, } - // Track connection state - m.connState() - // Returns configured HTTP server. return m } @@ -421,7 +417,7 @@ func (m *ServerMux) ListenAndServe(certFile, keyFile string) (err error) { go m.handleServiceSignals() - listeners, err := initListeners(m.Server.Addr, config) + listeners, err := initListeners(m.Addr, config) if err != nil { return err } @@ -445,8 +441,11 @@ func (m *ServerMux) ListenAndServe(certFile, keyFile string) (err error) { } http.Redirect(w, r, u.String(), http.StatusTemporaryRedirect) } else { - // Execute registered handlers - m.Server.Handler.ServeHTTP(w, r) + // Execute registered handlers, protect with a waitgroup + // to accomplish a graceful shutdown when the user asks to quit + m.gracefulWait.Add(1) + m.handler.ServeHTTP(w, r) + m.gracefulWait.Done() } }) @@ -470,6 +469,7 @@ func (m *ServerMux) ListenAndServe(certFile, keyFile string) (err error) { // Close initiates the graceful shutdown func (m *ServerMux) Close() error { m.mu.Lock() + if m.closed { m.mu.Unlock() return errors.New("Server has been closed") @@ -484,76 +484,21 @@ func (m *ServerMux) Close() error { return err } } - - m.SetKeepAlivesEnabled(false) - // Force close any idle and new connections. Waiting for other connections - // to close on their own (within the timeout period) - for c, st := range m.conns { - if st == http.StateIdle || st == http.StateNew { - c.Close() - } - } - - // If the GracefulTimeout happens then forcefully close all connections - t := time.AfterFunc(m.GracefulTimeout, func() { - for c := range m.conns { - c.Close() - } - }) - - // Wait for graceful timeout of connections. - defer t.Stop() - m.mu.Unlock() - // Block until all connections are closed - m.WaitGroup.Wait() - - return nil -} - -// connState setups the ConnState tracking hook to know which connections are idle -func (m *ServerMux) connState() { - // Set our ConnState to track idle connections - m.Server.ConnState = func(c net.Conn, cs http.ConnState) { - m.mu.Lock() - defer m.mu.Unlock() - - switch cs { - case http.StateNew: - // New connections increment the WaitGroup and are added the the conns dictionary - m.WaitGroup.Add(1) - if m.conns == nil { - m.conns = make(map[net.Conn]http.ConnState) - } - m.conns[c] = cs - case http.StateActive: - // Only update status to StateActive if it's in the conns dictionary - if _, ok := m.conns[c]; ok { - m.conns[c] = cs - } - case http.StateIdle: - // Only update status to StateIdle if it's in the conns dictionary - if _, ok := m.conns[c]; ok { - m.conns[c] = cs - } + // Prepare for a graceful shutdown + waitSignal := make(chan struct{}) + go func() { + defer close(waitSignal) + m.gracefulWait.Wait() + }() - // If we've already closed then we need to close this connection. - // We don't allow connections to become idle after server is closed - if m.closed { - c.Close() - } - case http.StateHijacked, http.StateClosed: - // If the connection is hijacked or closed we forget it - m.forgetConn(c) - } + select { + // Wait for everything to be properly closed + case <-waitSignal: + // Forced shutdown + case <-time.After(m.gracefulTimeout): } -} -// forgetConn removes c from conns and decrements WaitGroup -func (m *ServerMux) forgetConn(c net.Conn) { - if _, ok := m.conns[c]; ok { - delete(m.conns, c) - m.WaitGroup.Done() - } + return nil } diff --git a/cmd/server-mux_test.go b/cmd/server-mux_test.go index 661ad05a4..1a24473a6 100644 --- a/cmd/server-mux_test.go +++ b/cmd/server-mux_test.go @@ -29,7 +29,6 @@ import ( "math/big" "net" "net/http" - "net/http/httptest" "os" "runtime" "sync" @@ -181,113 +180,111 @@ func TestClose(t *testing.T) { } func TestServerMux(t *testing.T) { - ts := httptest.NewUnstartedServer(nil) - defer ts.Close() + var err error + var got []byte + var res *http.Response // Create ServerMux - m := NewServerMux("", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + m := NewServerMux("127.0.0.1:0", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "hello") })) - - // Set the test server config to the mux - ts.Config = m.Server - ts.Start() - - // Create a ListenerMux - lm := &ListenerMux{ - Listener: ts.Listener, - config: &tls.Config{}, - cond: sync.NewCond(&sync.Mutex{}), - } - m.listeners = []*ListenerMux{lm} - - client := http.Client{} - res, err := client.Get(ts.URL) - if err != nil { - t.Fatal(err) + // Start serving requests + go m.ListenAndServe("", "") + + // Issue a GET request. Since we started server in a goroutine, it could be not ready + // at this point. So we allow until 5 failed retries before declare there is an error + for i := 0; i < 5; i++ { + // Sleep one second + time.Sleep(1 * time.Second) + // Check if one listener is ready + m.mu.Lock() + if len(m.listeners) == 0 { + m.mu.Unlock() + continue + } + m.mu.Unlock() + // Issue the GET request + client := http.Client{} + m.mu.Lock() + res, err = client.Get("http://" + m.listeners[0].Addr().String()) + m.mu.Unlock() + if err != nil { + continue + } + // Read the request response + got, err = ioutil.ReadAll(res.Body) } - got, err := ioutil.ReadAll(res.Body) + // Check for error persisted after 5 times if err != nil { t.Fatal(err) } + // Check the web service response if string(got) != "hello" { t.Errorf("got %q, want hello", string(got)) } - - // Make sure there is only 1 connection - m.mu.Lock() - if len(m.conns) < 1 { - t.Fatal("Should have 1 connections") - } - m.mu.Unlock() - - // Close the server - m.Close() - - // Make sure there are zero connections - m.mu.Lock() - if len(m.conns) > 0 { - t.Fatal("Should have 0 connections") - } - m.mu.Unlock() } func TestServerCloseBlocking(t *testing.T) { - ts := httptest.NewUnstartedServer(nil) - defer ts.Close() - // Create ServerMux - m := NewServerMux("", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + m := NewServerMux("127.0.0.1:0", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "hello") })) - // Set the test server config to the mux - ts.Config = m.Server - ts.Start() - - // Create a ListenerMux. - lm := &ListenerMux{ - Listener: ts.Listener, - config: &tls.Config{}, - cond: sync.NewCond(&sync.Mutex{}), - } - m.listeners = []*ListenerMux{lm} - - dial := func() net.Conn { - c, cerr := net.Dial("tcp", ts.Listener.Addr().String()) - if cerr != nil { - t.Fatal(cerr) + // Start serving requests in a goroutine + go m.ListenAndServe("", "") + + // Dial, try until 5 times before declaring a failure + dial := func() (net.Conn, error) { + var c net.Conn + var err error + for i := 0; i < 5; i++ { + // Sleep one second in case of the server is not ready yet + time.Sleep(1 * time.Second) + // Check if there is at least one listener configured + m.mu.Lock() + if len(m.listeners) == 0 { + m.mu.Unlock() + continue + } + m.mu.Unlock() + // Run the actual Dial + m.mu.Lock() + c, err = net.Dial("tcp", m.listeners[0].Addr().String()) + m.mu.Unlock() + if err != nil { + continue + } } - return c + return c, err } // Dial to open a StateNew but don't send anything - cnew := dial() + cnew, err := dial() + if err != nil { + t.Fatal(err) + } defer cnew.Close() // Dial another connection but idle after a request to have StateIdle - cidle := dial() + cidle, err := dial() + if err != nil { + t.Fatal(err) + } defer cidle.Close() + cidle.Write([]byte("HEAD / HTTP/1.1\r\nHost: foo\r\n\r\n")) - _, err := http.ReadResponse(bufio.NewReader(cidle), nil) + _, err = http.ReadResponse(bufio.NewReader(cidle), nil) if err != nil { t.Fatal(err) } // Make sure we don't block forever. m.Close() - - // Make sure there are zero connections - m.mu.Lock() - if len(m.conns) > 0 { - t.Fatal("Should have 0 connections") - } - m.mu.Unlock() } -func TestListenAndServePlain(t *testing.T) { +func TestServerListenAndServePlain(t *testing.T) { wait := make(chan struct{}) addr := net.JoinHostPort("127.0.0.1", getFreePort()) errc := make(chan error) @@ -295,8 +292,6 @@ func TestListenAndServePlain(t *testing.T) { // Initialize done channel specifically for each tests. globalServiceDoneCh = make(chan struct{}, 1) - // Initialize signal channel specifically for each tests. - globalServiceSignalCh = make(chan serviceSignal, 1) // Create ServerMux and when we receive a request we stop waiting m := NewServerMux(addr, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -337,7 +332,7 @@ func TestListenAndServePlain(t *testing.T) { } } -func TestListenAndServeTLS(t *testing.T) { +func TestServerListenAndServeTLS(t *testing.T) { wait := make(chan struct{}) addr := net.JoinHostPort("127.0.0.1", getFreePort()) errc := make(chan error) diff --git a/cmd/server-startup-utils.go b/cmd/server-startup-utils.go index aef08e2b6..f19164595 100644 --- a/cmd/server-startup-utils.go +++ b/cmd/server-startup-utils.go @@ -19,7 +19,6 @@ package cmd import ( "fmt" "net" - "net/http" ) // getListenIPs - gets all the ips to listen on. @@ -49,7 +48,7 @@ func getListenIPs(serverAddr string) (hosts []string, port string, err error) { } // Finalizes the API endpoints based on the host list and port. -func finalizeAPIEndpoints(apiServer *http.Server) (endPoints []string, err error) { +func finalizeAPIEndpoints(addr string) (endPoints []string, err error) { // Verify current scheme. scheme := httpScheme if globalIsSSL { @@ -57,7 +56,7 @@ func finalizeAPIEndpoints(apiServer *http.Server) (endPoints []string, err error } // Get list of listen ips and port. - hosts, port, err1 := getListenIPs(apiServer.Addr) + hosts, port, err1 := getListenIPs(addr) if err1 != nil { return nil, err1 }