Fix net.Listener to fully close the underlying socket. (#3171)

Not closing it fully leads to races and to connections still being accepted
after Close. This patch implements a way to reject new connections once the
listener has been closed.
Branch: master
Author: Harshavardhana, committed by GitHub, 8 years ago
Parent: bf3c93a8cc
Commit: 1ba497950c

Changed files:
  1. cmd/server-mux.go (94 changed lines)
  2. cmd/server-mux_test.go (75 changed lines)

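The failure mode the commit message refers to is timing-dependent. The standalone sketch below is not part of the patch and may or may not reproduce the race on a given run; it only illustrates the sequence the new code guards against: one goroutine blocked in Accept, another calling Close, and a client dialing the same address right afterwards.

package main

import (
	"fmt"
	"net"
	"time"
)

// Hypothetical illustration of the race described in the commit message
// (not part of the patch). With the stock net.Listener, Close may return
// while another goroutine is still blocked in Accept, so a connection
// dialed right after Close can still be handed to that goroutine.
func main() {
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		panic(err)
	}

	accepted := make(chan struct{}, 1)
	go func() {
		// Goroutine 1: blocked in Accept, holding a reference to the socket.
		conn, aerr := ln.Accept()
		if aerr != nil {
			return // expected once the socket has really been destroyed
		}
		conn.Close()
		accepted <- struct{}{}
	}()

	ln.Close() // Goroutine 2: may return before the socket is destroyed.

	// A new connection made right after Close...
	if conn, derr := net.Dial("tcp", ln.Addr().String()); derr == nil {
		conn.Close()
	}

	// ...may still be accepted by goroutine 1.
	select {
	case <-accepted:
		fmt.Println("accepted a connection after Close returned")
	case <-time.After(100 * time.Millisecond):
		fmt.Println("no late accept on this run")
	}
}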
@@ -134,16 +134,88 @@ func (c *ConnMux) Read(b []byte) (int, error) {
return c.Conn.Read(b)
}
// ListenerMux - encapsulates the standard net.Listener to inspect
// ListenerMux wraps the standard net.Listener to inspect
// the communication protocol upon network connection
// ListenerMux also wraps net.Listener to ensure that once
// Listener.Close returns, the underlying socket has been closed.
//
// - https://github.com/golang/go/issues/10527
//
// The default Listener returns from Close before the underlying
// socket has been closed if another goroutine has an active
// reference (e.g. is in Accept).
//
// The following sequence of events can happen:
//
// Goroutine 1 is running Accept, and is blocked, waiting for epoll
//
// Goroutine 2 calls Close. It sees an extra reference, and so cannot
// destroy the socket, but instead decrements a reference, marks the
// connection as closed and unblocks epoll.
//
// Goroutine 2 returns to the caller, makes a new connection.
// The new connection is sent to the socket (since it hasn't been destroyed)
//
// Goroutine 1 returns from epoll, and accepts the new connection.
//
// To avoid accepting connections after Close, we block Goroutine 2
// from returning from Close till Accept returns an error to the user.
type ListenerMux struct {
net.Listener
config *tls.Config
// Cond is used to signal Close when there are no references to the listener.
cond *sync.Cond
refs int
}
// IsClosed - Returns if the underlying listener is closed fully.
func (l *ListenerMux) IsClosed() bool {
l.cond.L.Lock()
defer l.cond.L.Unlock()
return l.refs == 0
}
func (l *ListenerMux) incRef() {
l.cond.L.Lock()
l.refs++
l.cond.L.Unlock()
}
func (l *ListenerMux) decRef() {
l.cond.L.Lock()
l.refs--
newRefs := l.refs
l.cond.L.Unlock()
if newRefs == 0 {
l.cond.Broadcast()
}
}
// Close closes the listener.
// Any blocked Accept operations will be unblocked and return errors.
func (l *ListenerMux) Close() error {
if l == nil {
return nil
}
if err := l.Listener.Close(); err != nil {
return err
}
l.cond.L.Lock()
for l.refs > 0 {
l.cond.Wait()
}
l.cond.L.Unlock()
return nil
}
// Accept - peek the protocol to decide if we should wrap the
// network stream with the TLS server
func (l *ListenerMux) Accept() (net.Conn, error) {
l.incRef()
defer l.decRef()
conn, err := l.Listener.Accept()
if err != nil {
return conn, err
@@ -156,14 +228,6 @@ func (l *ListenerMux) Accept() (net.Conn, error) {
return connMux, nil
}
// Close Listener
func (l *ListenerMux) Close() error {
if l == nil {
return nil
}
return l.Listener.Close()
}
// ServerMux - the main mux server
type ServerMux struct {
*http.Server
@@ -215,6 +279,7 @@ func initListeners(serverAddr string, tls *tls.Config) ([]*ListenerMux, error) {
listeners = append(listeners, &ListenerMux{
Listener: listener,
config: tls,
cond: sync.NewCond(&sync.Mutex{}),
})
return listeners, nil
}
@@ -239,6 +304,7 @@ func initListeners(serverAddr string, tls *tls.Config) ([]*ListenerMux, error) {
listeners = append(listeners, &ListenerMux{
Listener: listener,
config: tls,
cond: sync.NewCond(&sync.Mutex{}),
})
}
return listeners, nil
@@ -294,7 +360,10 @@ func (m *ServerMux) ListenAndServeTLS(certFile, keyFile string) (err error) {
}
}),
)
errorIf(serr, "Unable to serve incoming requests.")
// Do not print the error if the listener is closed.
if !listener.IsClosed() {
errorIf(serr, "Unable to serve incoming requests.")
}
}(listener)
}
// Waits for all http.Serve's to return.
@@ -321,7 +390,10 @@ func (m *ServerMux) ListenAndServe() error {
go func(listener *ListenerMux) {
defer wg.Done()
serr := m.Server.Serve(listener)
errorIf(serr, "Unable to serve incoming requests.")
// Do not print the error if the listener is closed.
if !listener.IsClosed() {
errorIf(serr, "Unable to serve incoming requests.")
}
}(listener)
}
// Wait for all the http.Serve to finish.

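Read together, the new Accept and Close give callers a simple contract: once Close returns, every in-flight Accept has already returned an error, so nothing dialed afterwards can reach the old accept loop. The sketch below illustrates that contract; shutdownSketch is not a function in the repository, and it assumes the surrounding cmd package (where ListenerMux lives) with its net, crypto/tls and sync imports.

func shutdownSketch() error {
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		return err
	}
	lm := &ListenerMux{
		Listener: ln,
		config:   &tls.Config{},
		cond:     sync.NewCond(&sync.Mutex{}),
	}

	done := make(chan struct{})
	go func() {
		defer close(done)
		for {
			conn, aerr := lm.Accept()
			if aerr != nil {
				return // Accept unblocks with an error once Close runs
			}
			conn.Close()
		}
	}()

	// Close blocks until the Accept above has returned and dropped its
	// reference, so a connection dialed after this line cannot be handed
	// to the old accept loop.
	if cerr := lm.Close(); cerr != nil {
		return cerr
	}
	<-done            // the accept-loop goroutine exits shortly after
	_ = lm.IsClosed() // true: the listener has been fully closed
	return nil
}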
@@ -37,6 +37,79 @@ import (
"time"
)
func TestListenerAcceptAfterClose(t *testing.T) {
var wg sync.WaitGroup
for i := 0; i < 16; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < 10; i++ {
runTest(t)
}
}()
}
wg.Wait()
}
func runTest(t *testing.T) {
const connectionsBeforeClose = 1
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
ln = &ListenerMux{
Listener: ln,
config: &tls.Config{},
cond: sync.NewCond(&sync.Mutex{}),
}
addr := ln.Addr().String()
waitForListener := make(chan error)
go func() {
defer close(waitForListener)
var connCount int
for {
conn, aerr := ln.Accept()
if aerr != nil {
return
}
connCount++
if connCount > connectionsBeforeClose {
waitForListener <- errUnexpected
return
}
conn.Close()
}
}()
for i := 0; i < connectionsBeforeClose; i++ {
err = dial(addr)
if err != nil {
t.Fatal(err)
}
}
ln.Close()
dial(addr)
err = <-waitForListener
if err != nil {
t.Fatal(err)
}
}
func dial(addr string) error {
conn, err := net.Dial("tcp", addr)
if err == nil {
conn.Close()
}
return err
}
// Tests initializing listeners.
func TestInitListeners(t *testing.T) {
testCases := []struct {
@@ -125,6 +198,7 @@ func TestServerMux(t *testing.T) {
lm := &ListenerMux{
Listener: ts.Listener,
config: &tls.Config{},
cond: sync.NewCond(&sync.Mutex{}),
}
m.listeners = []*ListenerMux{lm}
@@ -178,6 +252,7 @@ func TestServerCloseBlocking(t *testing.T) {
lm := &ListenerMux{
Listener: ts.Listener,
config: &tls.Config{},
cond: sync.NewCond(&sync.Mutex{}),
}
m.listeners = []*ListenerMux{lm}

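The new TestListenerAcceptAfterClose above exercises the listener from 16 goroutines, each running ten close-and-redial cycles, and fails if any connection is accepted after Close. Assuming a checkout of the repository and a standard Go toolchain, one way to run it locally is with the race detector enabled:

	go test -race -run TestListenerAcceptAfterClose ./cmd/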