diff --git a/server-main.go b/server-main.go index 900b20f1d..76d9f7596 100644 --- a/server-main.go +++ b/server-main.go @@ -90,24 +90,6 @@ type serverCmdConfig struct { ignoredDisks []string } -// configureServer configure a new server instance -func configureServer(srvCmdConfig serverCmdConfig) *MuxServer { - // Minio server config - apiServer := &MuxServer{ - Server: http.Server{ - Addr: srvCmdConfig.serverAddr, - // Adding timeout of 10 minutes for unresponsive client connections. - ReadTimeout: 10 * time.Minute, - WriteTimeout: 10 * time.Minute, - Handler: configureServerHandler(srvCmdConfig), - MaxHeaderBytes: 1 << 20, - }, - } - - // Returns configured HTTP server. - return apiServer -} - // getListenIPs - gets all the ips to listen on. func getListenIPs(httpServerConf *http.Server) (hosts []string, port string) { host, port, err := net.SplitHostPort(httpServerConf.Addr) @@ -263,12 +245,14 @@ func serverMain(c *cli.Context) { disks := c.Args() // Configure server. - apiServer := configureServer(serverCmdConfig{ + handler := configureServerHandler(serverCmdConfig{ serverAddr: serverAddress, disks: disks, ignoredDisks: ignoredDisks, }) + apiServer := NewMuxServer(serverAddress, handler) + // Fetch endpoints which we are going to serve from. endPoints := finalizeEndpoints(tls, &apiServer.Server) diff --git a/server-mux.go b/server-mux.go index b28159409..822c0810b 100644 --- a/server-mux.go +++ b/server-mux.go @@ -18,11 +18,14 @@ package main import ( "crypto/tls" + "errors" "io" "net" "net/http" "net/url" "strings" + "sync" + "time" ) var defaultHTTP2Methods = []string{ @@ -120,31 +123,41 @@ func (c *MuxConn) Read(b []byte) (int, error) { type MuxListener struct { net.Listener config *tls.Config + wg *sync.WaitGroup } // NewMuxListener - creates new MuxListener, returns error when cert/key files are not found // or invalid -func NewMuxListener(listener net.Listener, certPath, keyPath string) (MuxListener, error) { +func NewMuxListener(listener net.Listener, wg *sync.WaitGroup, certPath, keyPath string) (*MuxListener, error) { var err error - config := &tls.Config{} - if config.NextProtos == nil { - config.NextProtos = []string{"http/1.1", "h2"} - } - config.Certificates = make([]tls.Certificate, 1) - config.Certificates[0], err = tls.LoadX509KeyPair(certPath, keyPath) - if err != nil { - return MuxListener{}, err + var config *tls.Config + config = nil + + if certPath != "" { + config = &tls.Config{} + if config.NextProtos == nil { + config.NextProtos = []string{"http/1.1", "h2"} + } + config.Certificates = make([]tls.Certificate, 1) + config.Certificates[0], err = tls.LoadX509KeyPair(certPath, keyPath) + if err != nil { + return &MuxListener{}, err + } } - return MuxListener{Listener: listener, config: config}, nil + + l := &MuxListener{Listener: listener, config: config, wg: wg} + + return l, nil } // Accept - peek the protocol to decide if we should wrap the // network stream with the TLS server -func (l MuxListener) Accept() (net.Conn, error) { +func (l *MuxListener) Accept() (net.Conn, error) { c, err := l.Listener.Accept() if err != nil { return c, err } + cmux := NewMuxConn(c) protocol := cmux.PeekProtocol() if protocol == "tls" { @@ -153,9 +166,46 @@ func (l MuxListener) Accept() (net.Conn, error) { return cmux, nil } +// Close Listener +func (l *MuxListener) Close() error { + if l == nil { + return nil + } + + return l.Listener.Close() +} + // MuxServer - the main mux server type MuxServer struct { http.Server + listener *MuxListener + WaitGroup *sync.WaitGroup + GracefulTimeout time.Duration + mu sync.Mutex // guards closed and conns + closed bool + conns map[net.Conn]http.ConnState // except terminal states +} + +// NewMuxServer constructor to create a MuxServer +func NewMuxServer(addr string, handler http.Handler) *MuxServer { + m := &MuxServer{ + Server: http.Server{ + Addr: addr, + // Adding timeout of 10 minutes for unresponsive client connections. + ReadTimeout: 10 * time.Minute, + WriteTimeout: 10 * time.Minute, + Handler: handler, + MaxHeaderBytes: 1 << 20, + }, + WaitGroup: &sync.WaitGroup{}, + GracefulTimeout: 5 * time.Second, + } + + // Track connection state + m.connState() + + // Returns configured HTTP server. + return m } // ListenAndServeTLS - similar to the http.Server version. However, it has the @@ -166,10 +216,13 @@ func (m *MuxServer) ListenAndServeTLS(certFile, keyFile string) error { if err != nil { return err } - mux, err := NewMuxListener(listener, mustGetCertFile(), mustGetKeyFile()) + mux, err := NewMuxListener(listener, m.WaitGroup, mustGetCertFile(), mustGetKeyFile()) if err != nil { return err } + + m.listener = mux + err = http.Serve(mux, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // We reach here when MuxListener.MuxConn is not wrapped with tls.Server @@ -194,7 +247,19 @@ func (m *MuxServer) ListenAndServeTLS(certFile, keyFile string) error { // ListenAndServe - Same as the http.Server version func (m *MuxServer) ListenAndServe() error { - return m.Server.ListenAndServe() + listener, err := net.Listen("tcp", m.Server.Addr) + if err != nil { + return err + } + + mux, err := NewMuxListener(listener, m.WaitGroup, "", "") + if err != nil { + return err + } + + m.listener = mux + + return m.Server.Serve(mux) } func longestWord(strings []string) int { @@ -212,3 +277,89 @@ func longestWord(strings []string) int { return maxLen } + +// Close initiates the graceful shutdown +func (m *MuxServer) Close() error { + if m.closed { + return errors.New("Server has been closed") + } + + m.mu.Lock() + m.Server.SetKeepAlivesEnabled(false) + m.closed = true + m.mu.Unlock() + if err := m.listener.Close(); err != nil { + return err + } + + // force connections to close after timeout + wait := make(chan struct{}) + go func() { + defer close(wait) + m.mu.Lock() + for c, st := range m.conns { + // Force close any idle and new connections. + if st == http.StateIdle || st == http.StateNew { + c.Close() + } + } + m.mu.Unlock() + + // Wait for all connections to be gracefully closed + m.WaitGroup.Wait() + }() + + // We block until all active connections are closed or the GracefulTimeout happens + select { + case <-time.After(m.GracefulTimeout): + return nil + case <-wait: + return nil + } +} + +// connState setups the ConnState tracking hook to know which connections are idle +func (m *MuxServer) connState() { + // Set our ConnState to track idle connections + m.Server.ConnState = func(c net.Conn, cs http.ConnState) { + m.mu.Lock() + defer m.mu.Unlock() + + switch cs { + case http.StateNew: + // New connections increment the WaitGroup and are added the the conns dictionary + m.WaitGroup.Add(1) + if m.conns == nil { + m.conns = make(map[net.Conn]http.ConnState) + } + m.conns[c] = cs + case http.StateActive: + // Only update status to StateActive if it's in the conns dictionary + if _, ok := m.conns[c]; ok { + m.conns[c] = cs + } + case http.StateIdle: + // Only update status to StateIdle if it's in the conns dictionary + if _, ok := m.conns[c]; ok { + m.conns[c] = cs + } + + // If we've already closed then we need to close this connection. + // We don't allow connections to become idle after server is closed + if m.closed { + c.Close() + } + case http.StateHijacked, http.StateClosed: + // If the connection is hijacked or closed we forget it + m.forgetConn(c) + } + } +} + +// forgetConn removes c from conns and decrements WaitGroup +func (m *MuxServer) forgetConn(c net.Conn) { + if _, ok := m.conns[c]; ok { + delete(m.conns, c) + m.WaitGroup.Done() + } +} diff --git a/server-mux_test.go b/server-mux_test.go new file mode 100644 index 000000000..2cbf64cb4 --- /dev/null +++ b/server-mux_test.go @@ -0,0 +1,90 @@ +/* + * Minio Cloud Storage, (C) 2015, 2016 Minio, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package main + +import ( + "fmt" + "io/ioutil" + "net/http" + "net/http/httptest" + "testing" +) + +func TestClose(t *testing.T) { + // Create ServerMux + m := NewMuxServer("", nil) + + if err := m.Close(); err != nil { + t.Error("Server errored while trying to Close", err) + } +} + +func TestMuxServer(t *testing.T) { + ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, "Hello, client") + })) + defer ts.Close() + + // Create ServerMux + m := NewMuxServer("", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "hello") + })) + + // Set the test server config to the mux + ts.Config = &m.Server + ts.Start() + + // Create a MuxListener + ml, err := NewMuxListener(ts.Listener, m.WaitGroup, "", "") + if err != nil { + t.Fatal(err) + } + m.listener = ml + + client := http.Client{} + res, err := client.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + + got, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + + if string(got) != "hello" { + t.Errorf("got %q, want hello", string(got)) + } + + // Make sure there is only 1 connection + m.mu.Lock() + if len(m.conns) < 1 { + t.Fatal("Should have 1 connections") + } + m.mu.Unlock() + + // Close the server + m.Close() + + // Make sure there are zero connections + m.mu.Lock() + if len(m.conns) < 0 { + t.Fatal("Should have 0 connections") + } + m.mu.Unlock() + +}