diff --git a/cmd/server-main.go b/cmd/server-main.go index 69883404c..608043bec 100644 --- a/cmd/server-main.go +++ b/cmd/server-main.go @@ -251,7 +251,7 @@ func serverMain(c *cli.Context) { ignoredDisks: ignoredDisks, }) - apiServer := NewMuxServer(serverAddress, handler) + apiServer := NewServerMux(serverAddress, handler) // Fetch endpoints which we are going to serve from. endPoints := finalizeEndpoints(tls, &apiServer.Server) diff --git a/cmd/server-mux.go b/cmd/server-mux.go index 4f9bfce43..e4a4c91a7 100644 --- a/cmd/server-mux.go +++ b/cmd/server-mux.go @@ -50,28 +50,44 @@ type ConnBuf struct { offset int } -// MuxConn - implements a Read() which streams twice the firs bytes from +// ConnMux - implements a Read() which streams twice the firs bytes from // the incoming connection, to help peeking protocol -type MuxConn struct { +type ConnMux struct { net.Conn lastError error dataBuf ConnBuf } -// NewMuxConn - creates a new MuxConn instance -func NewMuxConn(c net.Conn) *MuxConn { +func longestWord(strings []string) int { + maxLen := 0 + for _, m := range defaultHTTP1Methods { + if maxLen < len(m) { + maxLen = len(m) + } + } + for _, m := range defaultHTTP2Methods { + if maxLen < len(m) { + maxLen = len(m) + } + } + + return maxLen +} + +// NewConnMux - creates a new ConnMux instance +func NewConnMux(c net.Conn) *ConnMux { h1 := longestWord(defaultHTTP1Methods) h2 := longestWord(defaultHTTP2Methods) max := h1 if h2 > max { max = h2 } - return &MuxConn{Conn: c, dataBuf: ConnBuf{buffer: make([]byte, max+1)}} + return &ConnMux{Conn: c, dataBuf: ConnBuf{buffer: make([]byte, max+1)}} } // PeekProtocol - reads the first bytes, then checks if it is similar // to one of the default http methods -func (c *MuxConn) PeekProtocol() string { +func (c *ConnMux) PeekProtocol() string { var n int n, c.lastError = c.Conn.Read(c.dataBuf.buffer) if n == 0 || (c.lastError != nil && c.lastError != io.EOF) { @@ -91,9 +107,9 @@ func (c *MuxConn) PeekProtocol() string { return "tls" } -// Read - streams the MuxConn buffer when reset flag is activated, otherwise +// Read - streams the ConnMux buffer when reset flag is activated, otherwise // streams from the incoming network connection -func (c *MuxConn) Read(b []byte) (int, error) { +func (c *ConnMux) Read(b []byte) (int, error) { if c.dataBuf.unRead { n := copy(b, c.dataBuf.buffer[c.dataBuf.offset:]) c.dataBuf.offset += n @@ -118,64 +134,40 @@ func (c *MuxConn) Read(b []byte) (int, error) { return c.Conn.Read(b) } -// MuxListener - encapuslates the standard net.Listener to inspect +// ListenerMux - encapuslates the standard net.Listener to inspect // the communication protocol upon network connection -type MuxListener struct { +type ListenerMux struct { net.Listener config *tls.Config - wg *sync.WaitGroup -} - -// NewMuxListener - creates new MuxListener, returns error when cert/key files are not found -// or invalid -func NewMuxListener(listener net.Listener, wg *sync.WaitGroup, certPath, keyPath string) (*MuxListener, error) { - var err error - config := &tls.Config{} // Always instantiate. - if certPath != "" && keyPath != "" { - if config.NextProtos == nil { - config.NextProtos = []string{"http/1.1", "h2"} - } - config.Certificates = make([]tls.Certificate, 1) - config.Certificates[0], err = tls.LoadX509KeyPair(certPath, keyPath) - if err != nil { - return &MuxListener{}, err - } - } - - l := &MuxListener{Listener: listener, config: config, wg: wg} - - return l, nil } // Accept - peek the protocol to decide if we should wrap the // network stream with the TLS server -func (l *MuxListener) Accept() (net.Conn, error) { - c, err := l.Listener.Accept() +func (l *ListenerMux) Accept() (net.Conn, error) { + conn, err := l.Listener.Accept() if err != nil { - return c, err + return conn, err } - - cmux := NewMuxConn(c) - protocol := cmux.PeekProtocol() + connMux := NewConnMux(conn) + protocol := connMux.PeekProtocol() if protocol == "tls" { - return tls.Server(cmux, l.config), nil + return tls.Server(connMux, l.config), nil } - return cmux, nil + return connMux, nil } // Close Listener -func (l *MuxListener) Close() error { +func (l *ListenerMux) Close() error { if l == nil { return nil } - return l.Listener.Close() } -// MuxServer - the main mux server -type MuxServer struct { +// ServerMux - the main mux server +type ServerMux struct { http.Server - listener *MuxListener + listener *ListenerMux WaitGroup *sync.WaitGroup GracefulTimeout time.Duration mu sync.Mutex // guards closed, conns, and listener @@ -183,9 +175,9 @@ type MuxServer struct { conns map[net.Conn]http.ConnState // except terminal states } -// NewMuxServer constructor to create a MuxServer -func NewMuxServer(addr string, handler http.Handler) *MuxServer { - m := &MuxServer{ +// NewServerMux constructor to create a ServerMux +func NewServerMux(addr string, handler http.Handler) *ServerMux { + m := &ServerMux{ Server: http.Server{ Addr: addr, // Do not add any timeouts Golang net.Conn @@ -208,23 +200,31 @@ func NewMuxServer(addr string, handler http.Handler) *MuxServer { // ListenAndServeTLS - similar to the http.Server version. However, it has the // ability to redirect http requests to the correct HTTPS url if the client // mistakenly initiates a http connection over the https port -func (m *MuxServer) ListenAndServeTLS(certFile, keyFile string) error { +func (m *ServerMux) ListenAndServeTLS(certFile, keyFile string) error { listener, err := net.Listen("tcp", m.Server.Addr) if err != nil { return err } - mux, err := NewMuxListener(listener, m.WaitGroup, mustGetCertFile(), mustGetKeyFile()) + + config := &tls.Config{} // Always instantiate. + if config.NextProtos == nil { + config.NextProtos = []string{"http/1.1", "h2"} + } + config.Certificates = make([]tls.Certificate, 1) + config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) if err != nil { return err } + listenerMux := &ListenerMux{Listener: listener, config: config} + m.mu.Lock() - m.listener = mux + m.listener = listenerMux m.mu.Unlock() - err = http.Serve(mux, + err = http.Serve(listenerMux, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // We reach here when MuxListener.MuxConn is not wrapped with tls.Server + // We reach here when ListenerMux.ConnMux is not wrapped with tls.Server if r.TLS == nil { u := url.URL{ Scheme: "https", @@ -245,42 +245,23 @@ func (m *MuxServer) ListenAndServeTLS(certFile, keyFile string) error { } // ListenAndServe - Same as the http.Server version -func (m *MuxServer) ListenAndServe() error { +func (m *ServerMux) ListenAndServe() error { listener, err := net.Listen("tcp", m.Server.Addr) if err != nil { return err } - mux, err := NewMuxListener(listener, m.WaitGroup, "", "") - if err != nil { - return err - } + listenerMux := &ListenerMux{Listener: listener, config: &tls.Config{}} m.mu.Lock() - m.listener = mux + m.listener = listenerMux m.mu.Unlock() - return m.Server.Serve(mux) -} - -func longestWord(strings []string) int { - maxLen := 0 - for _, m := range defaultHTTP1Methods { - if maxLen < len(m) { - maxLen = len(m) - } - } - for _, m := range defaultHTTP2Methods { - if maxLen < len(m) { - maxLen = len(m) - } - } - - return maxLen + return m.Server.Serve(listenerMux) } // Close initiates the graceful shutdown -func (m *MuxServer) Close() error { +func (m *ServerMux) Close() error { m.mu.Lock() if m.closed { return errors.New("Server has been closed") @@ -317,7 +298,7 @@ func (m *MuxServer) Close() error { } // connState setups the ConnState tracking hook to know which connections are idle -func (m *MuxServer) connState() { +func (m *ServerMux) connState() { // Set our ConnState to track idle connections m.Server.ConnState = func(c net.Conn, cs http.ConnState) { m.mu.Lock() @@ -355,7 +336,7 @@ func (m *MuxServer) connState() { } // forgetConn removes c from conns and decrements WaitGroup -func (m *MuxServer) forgetConn(c net.Conn) { +func (m *ServerMux) forgetConn(c net.Conn) { if _, ok := m.conns[c]; ok { delete(m.conns, c) m.WaitGroup.Done() diff --git a/cmd/server-mux_test.go b/cmd/server-mux_test.go index 221b6052f..bdf556025 100644 --- a/cmd/server-mux_test.go +++ b/cmd/server-mux_test.go @@ -40,19 +40,19 @@ import ( func TestClose(t *testing.T) { // Create ServerMux - m := NewMuxServer("", nil) + m := NewServerMux("", nil) if err := m.Close(); err != nil { t.Error("Server errored while trying to Close", err) } } -func TestMuxServer(t *testing.T) { +func TestServerMux(t *testing.T) { ts := httptest.NewUnstartedServer(nil) defer ts.Close() // Create ServerMux - m := NewMuxServer("", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + m := NewServerMux("", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "hello") })) @@ -60,12 +60,12 @@ func TestMuxServer(t *testing.T) { ts.Config = &m.Server ts.Start() - // Create a MuxListener - ml, err := NewMuxListener(ts.Listener, m.WaitGroup, "", "") - if err != nil { - t.Fatal(err) + // Create a ListenerMux + lm := &ListenerMux{ + Listener: ts.Listener, + config: &tls.Config{}, } - m.listener = ml + m.listener = lm client := http.Client{} res, err := client.Get(ts.URL) @@ -105,7 +105,7 @@ func TestServerCloseBlocking(t *testing.T) { defer ts.Close() // Create ServerMux - m := NewMuxServer("", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + m := NewServerMux("", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "hello") })) @@ -113,18 +113,17 @@ func TestServerCloseBlocking(t *testing.T) { ts.Config = &m.Server ts.Start() - // Create a MuxListener - // var err error - ml, err := NewMuxListener(ts.Listener, m.WaitGroup, "", "") - if err != nil { - t.Fatal(err) + // Create a ListenerMux. + lm := &ListenerMux{ + Listener: ts.Listener, + config: &tls.Config{}, } - m.listener = ml + m.listener = lm dial := func() net.Conn { c, cerr := net.Dial("tcp", ts.Listener.Addr().String()) if cerr != nil { - t.Fatal(err) + t.Fatal(cerr) } return c } @@ -137,7 +136,7 @@ func TestServerCloseBlocking(t *testing.T) { cidle := dial() defer cidle.Close() cidle.Write([]byte("HEAD / HTTP/1.1\r\nHost: foo\r\n\r\n")) - _, err = http.ReadResponse(bufio.NewReader(cidle), nil) + _, err := http.ReadResponse(bufio.NewReader(cidle), nil) if err != nil { t.Fatal(err) } @@ -160,7 +159,7 @@ func TestListenAndServePlain(t *testing.T) { once := &sync.Once{} // Create ServerMux and when we receive a request we stop waiting - m := NewMuxServer(addr, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + m := NewServerMux(addr, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "hello") once.Do(func() { close(wait) }) })) @@ -209,7 +208,7 @@ func TestListenAndServeTLS(t *testing.T) { once := &sync.Once{} // Create ServerMux and when we receive a request we stop waiting - m := NewMuxServer(addr, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + m := NewServerMux(addr, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "hello") once.Do(func() { close(wait) }) }))