diff --git a/server-mux.go b/server-mux.go index 822c0810b..69588db95 100644 --- a/server-mux.go +++ b/server-mux.go @@ -181,7 +181,7 @@ type MuxServer struct { listener *MuxListener WaitGroup *sync.WaitGroup GracefulTimeout time.Duration - mu sync.Mutex // guards closed and conns + mu sync.Mutex // guards closed, conns, and listener closed bool conns map[net.Conn]http.ConnState // except terminal states } @@ -221,7 +221,9 @@ func (m *MuxServer) ListenAndServeTLS(certFile, keyFile string) error { return err } + m.mu.Lock() m.listener = mux + m.mu.Unlock() err = http.Serve(mux, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -257,7 +259,9 @@ func (m *MuxServer) ListenAndServe() error { return err } + m.mu.Lock() m.listener = mux + m.mu.Unlock() return m.Server.Serve(mux) } @@ -280,42 +284,39 @@ func longestWord(strings []string) int { // Close initiates the graceful shutdown func (m *MuxServer) Close() error { + m.mu.Lock() if m.closed { return errors.New("Server has been closed") } - - m.mu.Lock() - m.Server.SetKeepAlivesEnabled(false) m.closed = true - m.mu.Unlock() + + // Make sure a listener was set if err := m.listener.Close(); err != nil { return err } - // force connections to close after timeout - wait := make(chan struct{}) - go func() { - defer close(wait) - m.mu.Lock() - for c, st := range m.conns { - // Force close any idle and new connections. - if st == http.StateIdle || st == http.StateNew { - c.Close() - } + m.SetKeepAlivesEnabled(false) + for c, st := range m.conns { + // Force close any idle and new connections. Waiting for other connections + // to close on their own (within the timeout period) + if st == http.StateIdle || st == http.StateNew { + c.Close() } - m.mu.Unlock() + } - // Wait for all connections to be gracefully closed - m.WaitGroup.Wait() - }() + // If the GracefulTimeout happens then forcefully close all connections + t := time.AfterFunc(m.GracefulTimeout, func() { + for c := range m.conns { + c.Close() + } + }) + defer t.Stop() - // We block until all active connections are closed or the GracefulTimeout happens - select { - case <-time.After(m.GracefulTimeout): - return nil - case <-wait: - return nil - } + m.mu.Unlock() + + // Block until all connections are closed + m.WaitGroup.Wait() + return nil } // connState setups the ConnState tracking hook to know which connections are idle diff --git a/server-mux_test.go b/server-mux_test.go index 2cbf64cb4..96a1232ae 100644 --- a/server-mux_test.go +++ b/server-mux_test.go @@ -17,8 +17,10 @@ package main import ( + "bufio" "fmt" "io/ioutil" + "net" "net/http" "net/http/httptest" "testing" @@ -82,9 +84,79 @@ func TestMuxServer(t *testing.T) { // Make sure there are zero connections m.mu.Lock() - if len(m.conns) < 0 { + if len(m.conns) > 0 { t.Fatal("Should have 0 connections") } m.mu.Unlock() +} + +func TestServerCloseBlocking(t *testing.T) { + ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Hello, client") + })) + defer ts.Close() + + // Create ServerMux + m := NewMuxServer("", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "hello") + })) + + // Set the test server config to the mux + ts.Config = &m.Server + ts.Start() + + // Create a MuxListener + // var err error + ml, err := NewMuxListener(ts.Listener, m.WaitGroup, "", "") + if err != nil { + t.Fatal(err) + } + m.listener = ml + + dial := func() net.Conn { + c, cerr := net.Dial("tcp", ts.Listener.Addr().String()) + if cerr != nil { + t.Fatal(err) + } + return c + } + + // Dial to open a StateNew but don't send anything + cnew := dial() + defer cnew.Close() + + // Dial another connection but idle after a request to have StateIdle + cidle := dial() + defer cidle.Close() + cidle.Write([]byte("HEAD / HTTP/1.1\r\nHost: foo\r\n\r\n")) + _, err = http.ReadResponse(bufio.NewReader(cidle), nil) + if err != nil { + t.Fatal(err) + } + + // Make sure we don't block forever. + m.Close() + + // Make sure there are zero connections + m.mu.Lock() + if len(m.conns) > 0 { + t.Fatal("Should have 0 connections") + } + m.mu.Unlock() +} +func TestListenAndServe(t *testing.T) { + m := NewMuxServer("", nil) + stopc := make(chan struct{}) + errc := make(chan error) + go func() { errc <- m.ListenAndServe() }() + go func() { errc <- m.Close(); close(stopc) }() + select { + case err := <-errc: + if err != nil { + t.Fatal(err) + } + case <-stopc: + return + } }