server/mux: Remove unused waitgroup from listenerMux. (#2614)

Change struct names to be more meaningful.
master
Harshavardhana 8 years ago committed by GitHub
parent 2dc7ecc59b
commit bc8f34bfe7
  1. 2
      cmd/server-main.go
  2. 139
      cmd/server-mux.go
  3. 37
      cmd/server-mux_test.go

@ -251,7 +251,7 @@ func serverMain(c *cli.Context) {
ignoredDisks: ignoredDisks, ignoredDisks: ignoredDisks,
}) })
apiServer := NewMuxServer(serverAddress, handler) apiServer := NewServerMux(serverAddress, handler)
// Fetch endpoints which we are going to serve from. // Fetch endpoints which we are going to serve from.
endPoints := finalizeEndpoints(tls, &apiServer.Server) endPoints := finalizeEndpoints(tls, &apiServer.Server)

@ -50,28 +50,44 @@ type ConnBuf struct {
offset int offset int
} }
// MuxConn - implements a Read() which streams twice the firs bytes from // ConnMux - implements a Read() which streams twice the firs bytes from
// the incoming connection, to help peeking protocol // the incoming connection, to help peeking protocol
type MuxConn struct { type ConnMux struct {
net.Conn net.Conn
lastError error lastError error
dataBuf ConnBuf dataBuf ConnBuf
} }
// NewMuxConn - creates a new MuxConn instance func longestWord(strings []string) int {
func NewMuxConn(c net.Conn) *MuxConn { maxLen := 0
for _, m := range defaultHTTP1Methods {
if maxLen < len(m) {
maxLen = len(m)
}
}
for _, m := range defaultHTTP2Methods {
if maxLen < len(m) {
maxLen = len(m)
}
}
return maxLen
}
// NewConnMux - creates a new ConnMux instance
func NewConnMux(c net.Conn) *ConnMux {
h1 := longestWord(defaultHTTP1Methods) h1 := longestWord(defaultHTTP1Methods)
h2 := longestWord(defaultHTTP2Methods) h2 := longestWord(defaultHTTP2Methods)
max := h1 max := h1
if h2 > max { if h2 > max {
max = h2 max = h2
} }
return &MuxConn{Conn: c, dataBuf: ConnBuf{buffer: make([]byte, max+1)}} return &ConnMux{Conn: c, dataBuf: ConnBuf{buffer: make([]byte, max+1)}}
} }
// PeekProtocol - reads the first bytes, then checks if it is similar // PeekProtocol - reads the first bytes, then checks if it is similar
// to one of the default http methods // to one of the default http methods
func (c *MuxConn) PeekProtocol() string { func (c *ConnMux) PeekProtocol() string {
var n int var n int
n, c.lastError = c.Conn.Read(c.dataBuf.buffer) n, c.lastError = c.Conn.Read(c.dataBuf.buffer)
if n == 0 || (c.lastError != nil && c.lastError != io.EOF) { if n == 0 || (c.lastError != nil && c.lastError != io.EOF) {
@ -91,9 +107,9 @@ func (c *MuxConn) PeekProtocol() string {
return "tls" return "tls"
} }
// Read - streams the MuxConn buffer when reset flag is activated, otherwise // Read - streams the ConnMux buffer when reset flag is activated, otherwise
// streams from the incoming network connection // streams from the incoming network connection
func (c *MuxConn) Read(b []byte) (int, error) { func (c *ConnMux) Read(b []byte) (int, error) {
if c.dataBuf.unRead { if c.dataBuf.unRead {
n := copy(b, c.dataBuf.buffer[c.dataBuf.offset:]) n := copy(b, c.dataBuf.buffer[c.dataBuf.offset:])
c.dataBuf.offset += n c.dataBuf.offset += n
@ -118,64 +134,40 @@ func (c *MuxConn) Read(b []byte) (int, error) {
return c.Conn.Read(b) return c.Conn.Read(b)
} }
// MuxListener - encapuslates the standard net.Listener to inspect // ListenerMux - encapuslates the standard net.Listener to inspect
// the communication protocol upon network connection // the communication protocol upon network connection
type MuxListener struct { type ListenerMux struct {
net.Listener net.Listener
config *tls.Config config *tls.Config
wg *sync.WaitGroup
}
// NewMuxListener - creates new MuxListener, returns error when cert/key files are not found
// or invalid
func NewMuxListener(listener net.Listener, wg *sync.WaitGroup, certPath, keyPath string) (*MuxListener, error) {
var err error
config := &tls.Config{} // Always instantiate.
if certPath != "" && keyPath != "" {
if config.NextProtos == nil {
config.NextProtos = []string{"http/1.1", "h2"}
}
config.Certificates = make([]tls.Certificate, 1)
config.Certificates[0], err = tls.LoadX509KeyPair(certPath, keyPath)
if err != nil {
return &MuxListener{}, err
}
}
l := &MuxListener{Listener: listener, config: config, wg: wg}
return l, nil
} }
// Accept - peek the protocol to decide if we should wrap the // Accept - peek the protocol to decide if we should wrap the
// network stream with the TLS server // network stream with the TLS server
func (l *MuxListener) Accept() (net.Conn, error) { func (l *ListenerMux) Accept() (net.Conn, error) {
c, err := l.Listener.Accept() conn, err := l.Listener.Accept()
if err != nil { if err != nil {
return c, err return conn, err
} }
connMux := NewConnMux(conn)
cmux := NewMuxConn(c) protocol := connMux.PeekProtocol()
protocol := cmux.PeekProtocol()
if protocol == "tls" { if protocol == "tls" {
return tls.Server(cmux, l.config), nil return tls.Server(connMux, l.config), nil
} }
return cmux, nil return connMux, nil
} }
// Close Listener // Close Listener
func (l *MuxListener) Close() error { func (l *ListenerMux) Close() error {
if l == nil { if l == nil {
return nil return nil
} }
return l.Listener.Close() return l.Listener.Close()
} }
// MuxServer - the main mux server // ServerMux - the main mux server
type MuxServer struct { type ServerMux struct {
http.Server http.Server
listener *MuxListener listener *ListenerMux
WaitGroup *sync.WaitGroup WaitGroup *sync.WaitGroup
GracefulTimeout time.Duration GracefulTimeout time.Duration
mu sync.Mutex // guards closed, conns, and listener mu sync.Mutex // guards closed, conns, and listener
@ -183,9 +175,9 @@ type MuxServer struct {
conns map[net.Conn]http.ConnState // except terminal states conns map[net.Conn]http.ConnState // except terminal states
} }
// NewMuxServer constructor to create a MuxServer // NewServerMux constructor to create a ServerMux
func NewMuxServer(addr string, handler http.Handler) *MuxServer { func NewServerMux(addr string, handler http.Handler) *ServerMux {
m := &MuxServer{ m := &ServerMux{
Server: http.Server{ Server: http.Server{
Addr: addr, Addr: addr,
// Do not add any timeouts Golang net.Conn // Do not add any timeouts Golang net.Conn
@ -208,23 +200,31 @@ func NewMuxServer(addr string, handler http.Handler) *MuxServer {
// ListenAndServeTLS - similar to the http.Server version. However, it has the // ListenAndServeTLS - similar to the http.Server version. However, it has the
// ability to redirect http requests to the correct HTTPS url if the client // ability to redirect http requests to the correct HTTPS url if the client
// mistakenly initiates a http connection over the https port // mistakenly initiates a http connection over the https port
func (m *MuxServer) ListenAndServeTLS(certFile, keyFile string) error { func (m *ServerMux) ListenAndServeTLS(certFile, keyFile string) error {
listener, err := net.Listen("tcp", m.Server.Addr) listener, err := net.Listen("tcp", m.Server.Addr)
if err != nil { if err != nil {
return err return err
} }
mux, err := NewMuxListener(listener, m.WaitGroup, mustGetCertFile(), mustGetKeyFile())
config := &tls.Config{} // Always instantiate.
if config.NextProtos == nil {
config.NextProtos = []string{"http/1.1", "h2"}
}
config.Certificates = make([]tls.Certificate, 1)
config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
if err != nil { if err != nil {
return err return err
} }
listenerMux := &ListenerMux{Listener: listener, config: config}
m.mu.Lock() m.mu.Lock()
m.listener = mux m.listener = listenerMux
m.mu.Unlock() m.mu.Unlock()
err = http.Serve(mux, err = http.Serve(listenerMux,
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// We reach here when MuxListener.MuxConn is not wrapped with tls.Server // We reach here when ListenerMux.ConnMux is not wrapped with tls.Server
if r.TLS == nil { if r.TLS == nil {
u := url.URL{ u := url.URL{
Scheme: "https", Scheme: "https",
@ -245,42 +245,23 @@ func (m *MuxServer) ListenAndServeTLS(certFile, keyFile string) error {
} }
// ListenAndServe - Same as the http.Server version // ListenAndServe - Same as the http.Server version
func (m *MuxServer) ListenAndServe() error { func (m *ServerMux) ListenAndServe() error {
listener, err := net.Listen("tcp", m.Server.Addr) listener, err := net.Listen("tcp", m.Server.Addr)
if err != nil { if err != nil {
return err return err
} }
mux, err := NewMuxListener(listener, m.WaitGroup, "", "") listenerMux := &ListenerMux{Listener: listener, config: &tls.Config{}}
if err != nil {
return err
}
m.mu.Lock() m.mu.Lock()
m.listener = mux m.listener = listenerMux
m.mu.Unlock() m.mu.Unlock()
return m.Server.Serve(mux) return m.Server.Serve(listenerMux)
}
func longestWord(strings []string) int {
maxLen := 0
for _, m := range defaultHTTP1Methods {
if maxLen < len(m) {
maxLen = len(m)
}
}
for _, m := range defaultHTTP2Methods {
if maxLen < len(m) {
maxLen = len(m)
}
}
return maxLen
} }
// Close initiates the graceful shutdown // Close initiates the graceful shutdown
func (m *MuxServer) Close() error { func (m *ServerMux) Close() error {
m.mu.Lock() m.mu.Lock()
if m.closed { if m.closed {
return errors.New("Server has been closed") return errors.New("Server has been closed")
@ -317,7 +298,7 @@ func (m *MuxServer) Close() error {
} }
// connState setups the ConnState tracking hook to know which connections are idle // connState setups the ConnState tracking hook to know which connections are idle
func (m *MuxServer) connState() { func (m *ServerMux) connState() {
// Set our ConnState to track idle connections // Set our ConnState to track idle connections
m.Server.ConnState = func(c net.Conn, cs http.ConnState) { m.Server.ConnState = func(c net.Conn, cs http.ConnState) {
m.mu.Lock() m.mu.Lock()
@ -355,7 +336,7 @@ func (m *MuxServer) connState() {
} }
// forgetConn removes c from conns and decrements WaitGroup // forgetConn removes c from conns and decrements WaitGroup
func (m *MuxServer) forgetConn(c net.Conn) { func (m *ServerMux) forgetConn(c net.Conn) {
if _, ok := m.conns[c]; ok { if _, ok := m.conns[c]; ok {
delete(m.conns, c) delete(m.conns, c)
m.WaitGroup.Done() m.WaitGroup.Done()

@ -40,19 +40,19 @@ import (
func TestClose(t *testing.T) { func TestClose(t *testing.T) {
// Create ServerMux // Create ServerMux
m := NewMuxServer("", nil) m := NewServerMux("", nil)
if err := m.Close(); err != nil { if err := m.Close(); err != nil {
t.Error("Server errored while trying to Close", err) t.Error("Server errored while trying to Close", err)
} }
} }
func TestMuxServer(t *testing.T) { func TestServerMux(t *testing.T) {
ts := httptest.NewUnstartedServer(nil) ts := httptest.NewUnstartedServer(nil)
defer ts.Close() defer ts.Close()
// Create ServerMux // Create ServerMux
m := NewMuxServer("", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { m := NewServerMux("", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "hello") fmt.Fprint(w, "hello")
})) }))
@ -60,12 +60,12 @@ func TestMuxServer(t *testing.T) {
ts.Config = &m.Server ts.Config = &m.Server
ts.Start() ts.Start()
// Create a MuxListener // Create a ListenerMux
ml, err := NewMuxListener(ts.Listener, m.WaitGroup, "", "") lm := &ListenerMux{
if err != nil { Listener: ts.Listener,
t.Fatal(err) config: &tls.Config{},
} }
m.listener = ml m.listener = lm
client := http.Client{} client := http.Client{}
res, err := client.Get(ts.URL) res, err := client.Get(ts.URL)
@ -105,7 +105,7 @@ func TestServerCloseBlocking(t *testing.T) {
defer ts.Close() defer ts.Close()
// Create ServerMux // Create ServerMux
m := NewMuxServer("", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { m := NewServerMux("", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "hello") fmt.Fprint(w, "hello")
})) }))
@ -113,18 +113,17 @@ func TestServerCloseBlocking(t *testing.T) {
ts.Config = &m.Server ts.Config = &m.Server
ts.Start() ts.Start()
// Create a MuxListener // Create a ListenerMux.
// var err error lm := &ListenerMux{
ml, err := NewMuxListener(ts.Listener, m.WaitGroup, "", "") Listener: ts.Listener,
if err != nil { config: &tls.Config{},
t.Fatal(err)
} }
m.listener = ml m.listener = lm
dial := func() net.Conn { dial := func() net.Conn {
c, cerr := net.Dial("tcp", ts.Listener.Addr().String()) c, cerr := net.Dial("tcp", ts.Listener.Addr().String())
if cerr != nil { if cerr != nil {
t.Fatal(err) t.Fatal(cerr)
} }
return c return c
} }
@ -137,7 +136,7 @@ func TestServerCloseBlocking(t *testing.T) {
cidle := dial() cidle := dial()
defer cidle.Close() defer cidle.Close()
cidle.Write([]byte("HEAD / HTTP/1.1\r\nHost: foo\r\n\r\n")) cidle.Write([]byte("HEAD / HTTP/1.1\r\nHost: foo\r\n\r\n"))
_, err = http.ReadResponse(bufio.NewReader(cidle), nil) _, err := http.ReadResponse(bufio.NewReader(cidle), nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -160,7 +159,7 @@ func TestListenAndServePlain(t *testing.T) {
once := &sync.Once{} once := &sync.Once{}
// Create ServerMux and when we receive a request we stop waiting // Create ServerMux and when we receive a request we stop waiting
m := NewMuxServer(addr, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { m := NewServerMux(addr, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "hello") fmt.Fprint(w, "hello")
once.Do(func() { close(wait) }) once.Do(func() { close(wait) })
})) }))
@ -209,7 +208,7 @@ func TestListenAndServeTLS(t *testing.T) {
once := &sync.Once{} once := &sync.Once{}
// Create ServerMux and when we receive a request we stop waiting // Create ServerMux and when we receive a request we stop waiting
m := NewMuxServer(addr, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { m := NewServerMux(addr, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "hello") fmt.Fprint(w, "hello")
once.Do(func() { close(wait) }) once.Do(func() { close(wait) })
})) }))

Loading…
Cancel
Save