|
|
@ -50,28 +50,44 @@ type ConnBuf struct { |
|
|
|
offset int |
|
|
|
offset int |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// MuxConn - implements a Read() which streams twice the firs bytes from
|
|
|
|
// ConnMux - implements a Read() which streams twice the firs bytes from
|
|
|
|
// the incoming connection, to help peeking protocol
|
|
|
|
// the incoming connection, to help peeking protocol
|
|
|
|
type MuxConn struct { |
|
|
|
type ConnMux struct { |
|
|
|
net.Conn |
|
|
|
net.Conn |
|
|
|
lastError error |
|
|
|
lastError error |
|
|
|
dataBuf ConnBuf |
|
|
|
dataBuf ConnBuf |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// NewMuxConn - creates a new MuxConn instance
|
|
|
|
func longestWord(strings []string) int { |
|
|
|
func NewMuxConn(c net.Conn) *MuxConn { |
|
|
|
maxLen := 0 |
|
|
|
|
|
|
|
for _, m := range defaultHTTP1Methods { |
|
|
|
|
|
|
|
if maxLen < len(m) { |
|
|
|
|
|
|
|
maxLen = len(m) |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
for _, m := range defaultHTTP2Methods { |
|
|
|
|
|
|
|
if maxLen < len(m) { |
|
|
|
|
|
|
|
maxLen = len(m) |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return maxLen |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// NewConnMux - creates a new ConnMux instance
|
|
|
|
|
|
|
|
func NewConnMux(c net.Conn) *ConnMux { |
|
|
|
h1 := longestWord(defaultHTTP1Methods) |
|
|
|
h1 := longestWord(defaultHTTP1Methods) |
|
|
|
h2 := longestWord(defaultHTTP2Methods) |
|
|
|
h2 := longestWord(defaultHTTP2Methods) |
|
|
|
max := h1 |
|
|
|
max := h1 |
|
|
|
if h2 > max { |
|
|
|
if h2 > max { |
|
|
|
max = h2 |
|
|
|
max = h2 |
|
|
|
} |
|
|
|
} |
|
|
|
return &MuxConn{Conn: c, dataBuf: ConnBuf{buffer: make([]byte, max+1)}} |
|
|
|
return &ConnMux{Conn: c, dataBuf: ConnBuf{buffer: make([]byte, max+1)}} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// PeekProtocol - reads the first bytes, then checks if it is similar
|
|
|
|
// PeekProtocol - reads the first bytes, then checks if it is similar
|
|
|
|
// to one of the default http methods
|
|
|
|
// to one of the default http methods
|
|
|
|
func (c *MuxConn) PeekProtocol() string { |
|
|
|
func (c *ConnMux) PeekProtocol() string { |
|
|
|
var n int |
|
|
|
var n int |
|
|
|
n, c.lastError = c.Conn.Read(c.dataBuf.buffer) |
|
|
|
n, c.lastError = c.Conn.Read(c.dataBuf.buffer) |
|
|
|
if n == 0 || (c.lastError != nil && c.lastError != io.EOF) { |
|
|
|
if n == 0 || (c.lastError != nil && c.lastError != io.EOF) { |
|
|
@ -91,9 +107,9 @@ func (c *MuxConn) PeekProtocol() string { |
|
|
|
return "tls" |
|
|
|
return "tls" |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// Read - streams the MuxConn buffer when reset flag is activated, otherwise
|
|
|
|
// Read - streams the ConnMux buffer when reset flag is activated, otherwise
|
|
|
|
// streams from the incoming network connection
|
|
|
|
// streams from the incoming network connection
|
|
|
|
func (c *MuxConn) Read(b []byte) (int, error) { |
|
|
|
func (c *ConnMux) Read(b []byte) (int, error) { |
|
|
|
if c.dataBuf.unRead { |
|
|
|
if c.dataBuf.unRead { |
|
|
|
n := copy(b, c.dataBuf.buffer[c.dataBuf.offset:]) |
|
|
|
n := copy(b, c.dataBuf.buffer[c.dataBuf.offset:]) |
|
|
|
c.dataBuf.offset += n |
|
|
|
c.dataBuf.offset += n |
|
|
@ -118,64 +134,40 @@ func (c *MuxConn) Read(b []byte) (int, error) { |
|
|
|
return c.Conn.Read(b) |
|
|
|
return c.Conn.Read(b) |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// MuxListener - encapuslates the standard net.Listener to inspect
|
|
|
|
// ListenerMux - encapuslates the standard net.Listener to inspect
|
|
|
|
// the communication protocol upon network connection
|
|
|
|
// the communication protocol upon network connection
|
|
|
|
type MuxListener struct { |
|
|
|
type ListenerMux struct { |
|
|
|
net.Listener |
|
|
|
net.Listener |
|
|
|
config *tls.Config |
|
|
|
config *tls.Config |
|
|
|
wg *sync.WaitGroup |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// NewMuxListener - creates new MuxListener, returns error when cert/key files are not found
|
|
|
|
|
|
|
|
// or invalid
|
|
|
|
|
|
|
|
func NewMuxListener(listener net.Listener, wg *sync.WaitGroup, certPath, keyPath string) (*MuxListener, error) { |
|
|
|
|
|
|
|
var err error |
|
|
|
|
|
|
|
config := &tls.Config{} // Always instantiate.
|
|
|
|
|
|
|
|
if certPath != "" && keyPath != "" { |
|
|
|
|
|
|
|
if config.NextProtos == nil { |
|
|
|
|
|
|
|
config.NextProtos = []string{"http/1.1", "h2"} |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
config.Certificates = make([]tls.Certificate, 1) |
|
|
|
|
|
|
|
config.Certificates[0], err = tls.LoadX509KeyPair(certPath, keyPath) |
|
|
|
|
|
|
|
if err != nil { |
|
|
|
|
|
|
|
return &MuxListener{}, err |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
l := &MuxListener{Listener: listener, config: config, wg: wg} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return l, nil |
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// Accept - peek the protocol to decide if we should wrap the
|
|
|
|
// Accept - peek the protocol to decide if we should wrap the
|
|
|
|
// network stream with the TLS server
|
|
|
|
// network stream with the TLS server
|
|
|
|
func (l *MuxListener) Accept() (net.Conn, error) { |
|
|
|
func (l *ListenerMux) Accept() (net.Conn, error) { |
|
|
|
c, err := l.Listener.Accept() |
|
|
|
conn, err := l.Listener.Accept() |
|
|
|
if err != nil { |
|
|
|
if err != nil { |
|
|
|
return c, err |
|
|
|
return conn, err |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
connMux := NewConnMux(conn) |
|
|
|
cmux := NewMuxConn(c) |
|
|
|
protocol := connMux.PeekProtocol() |
|
|
|
protocol := cmux.PeekProtocol() |
|
|
|
|
|
|
|
if protocol == "tls" { |
|
|
|
if protocol == "tls" { |
|
|
|
return tls.Server(cmux, l.config), nil |
|
|
|
return tls.Server(connMux, l.config), nil |
|
|
|
} |
|
|
|
} |
|
|
|
return cmux, nil |
|
|
|
return connMux, nil |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// Close Listener
|
|
|
|
// Close Listener
|
|
|
|
func (l *MuxListener) Close() error { |
|
|
|
func (l *ListenerMux) Close() error { |
|
|
|
if l == nil { |
|
|
|
if l == nil { |
|
|
|
return nil |
|
|
|
return nil |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
return l.Listener.Close() |
|
|
|
return l.Listener.Close() |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// MuxServer - the main mux server
|
|
|
|
// ServerMux - the main mux server
|
|
|
|
type MuxServer struct { |
|
|
|
type ServerMux struct { |
|
|
|
http.Server |
|
|
|
http.Server |
|
|
|
listener *MuxListener |
|
|
|
listener *ListenerMux |
|
|
|
WaitGroup *sync.WaitGroup |
|
|
|
WaitGroup *sync.WaitGroup |
|
|
|
GracefulTimeout time.Duration |
|
|
|
GracefulTimeout time.Duration |
|
|
|
mu sync.Mutex // guards closed, conns, and listener
|
|
|
|
mu sync.Mutex // guards closed, conns, and listener
|
|
|
@ -183,9 +175,9 @@ type MuxServer struct { |
|
|
|
conns map[net.Conn]http.ConnState // except terminal states
|
|
|
|
conns map[net.Conn]http.ConnState // except terminal states
|
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// NewMuxServer constructor to create a MuxServer
|
|
|
|
// NewServerMux constructor to create a ServerMux
|
|
|
|
func NewMuxServer(addr string, handler http.Handler) *MuxServer { |
|
|
|
func NewServerMux(addr string, handler http.Handler) *ServerMux { |
|
|
|
m := &MuxServer{ |
|
|
|
m := &ServerMux{ |
|
|
|
Server: http.Server{ |
|
|
|
Server: http.Server{ |
|
|
|
Addr: addr, |
|
|
|
Addr: addr, |
|
|
|
// Do not add any timeouts Golang net.Conn
|
|
|
|
// Do not add any timeouts Golang net.Conn
|
|
|
@ -208,23 +200,31 @@ func NewMuxServer(addr string, handler http.Handler) *MuxServer { |
|
|
|
// ListenAndServeTLS - similar to the http.Server version. However, it has the
|
|
|
|
// ListenAndServeTLS - similar to the http.Server version. However, it has the
|
|
|
|
// ability to redirect http requests to the correct HTTPS url if the client
|
|
|
|
// ability to redirect http requests to the correct HTTPS url if the client
|
|
|
|
// mistakenly initiates a http connection over the https port
|
|
|
|
// mistakenly initiates a http connection over the https port
|
|
|
|
func (m *MuxServer) ListenAndServeTLS(certFile, keyFile string) error { |
|
|
|
func (m *ServerMux) ListenAndServeTLS(certFile, keyFile string) error { |
|
|
|
listener, err := net.Listen("tcp", m.Server.Addr) |
|
|
|
listener, err := net.Listen("tcp", m.Server.Addr) |
|
|
|
if err != nil { |
|
|
|
if err != nil { |
|
|
|
return err |
|
|
|
return err |
|
|
|
} |
|
|
|
} |
|
|
|
mux, err := NewMuxListener(listener, m.WaitGroup, mustGetCertFile(), mustGetKeyFile()) |
|
|
|
|
|
|
|
|
|
|
|
config := &tls.Config{} // Always instantiate.
|
|
|
|
|
|
|
|
if config.NextProtos == nil { |
|
|
|
|
|
|
|
config.NextProtos = []string{"http/1.1", "h2"} |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
config.Certificates = make([]tls.Certificate, 1) |
|
|
|
|
|
|
|
config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) |
|
|
|
if err != nil { |
|
|
|
if err != nil { |
|
|
|
return err |
|
|
|
return err |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
listenerMux := &ListenerMux{Listener: listener, config: config} |
|
|
|
|
|
|
|
|
|
|
|
m.mu.Lock() |
|
|
|
m.mu.Lock() |
|
|
|
m.listener = mux |
|
|
|
m.listener = listenerMux |
|
|
|
m.mu.Unlock() |
|
|
|
m.mu.Unlock() |
|
|
|
|
|
|
|
|
|
|
|
err = http.Serve(mux, |
|
|
|
err = http.Serve(listenerMux, |
|
|
|
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|
|
|
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|
|
|
// We reach here when MuxListener.MuxConn is not wrapped with tls.Server
|
|
|
|
// We reach here when ListenerMux.ConnMux is not wrapped with tls.Server
|
|
|
|
if r.TLS == nil { |
|
|
|
if r.TLS == nil { |
|
|
|
u := url.URL{ |
|
|
|
u := url.URL{ |
|
|
|
Scheme: "https", |
|
|
|
Scheme: "https", |
|
|
@ -245,42 +245,23 @@ func (m *MuxServer) ListenAndServeTLS(certFile, keyFile string) error { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// ListenAndServe - Same as the http.Server version
|
|
|
|
// ListenAndServe - Same as the http.Server version
|
|
|
|
func (m *MuxServer) ListenAndServe() error { |
|
|
|
func (m *ServerMux) ListenAndServe() error { |
|
|
|
listener, err := net.Listen("tcp", m.Server.Addr) |
|
|
|
listener, err := net.Listen("tcp", m.Server.Addr) |
|
|
|
if err != nil { |
|
|
|
if err != nil { |
|
|
|
return err |
|
|
|
return err |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
mux, err := NewMuxListener(listener, m.WaitGroup, "", "") |
|
|
|
listenerMux := &ListenerMux{Listener: listener, config: &tls.Config{}} |
|
|
|
if err != nil { |
|
|
|
|
|
|
|
return err |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
m.mu.Lock() |
|
|
|
m.mu.Lock() |
|
|
|
m.listener = mux |
|
|
|
m.listener = listenerMux |
|
|
|
m.mu.Unlock() |
|
|
|
m.mu.Unlock() |
|
|
|
|
|
|
|
|
|
|
|
return m.Server.Serve(mux) |
|
|
|
return m.Server.Serve(listenerMux) |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
func longestWord(strings []string) int { |
|
|
|
|
|
|
|
maxLen := 0 |
|
|
|
|
|
|
|
for _, m := range defaultHTTP1Methods { |
|
|
|
|
|
|
|
if maxLen < len(m) { |
|
|
|
|
|
|
|
maxLen = len(m) |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
for _, m := range defaultHTTP2Methods { |
|
|
|
|
|
|
|
if maxLen < len(m) { |
|
|
|
|
|
|
|
maxLen = len(m) |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return maxLen |
|
|
|
|
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// Close initiates the graceful shutdown
|
|
|
|
// Close initiates the graceful shutdown
|
|
|
|
func (m *MuxServer) Close() error { |
|
|
|
func (m *ServerMux) Close() error { |
|
|
|
m.mu.Lock() |
|
|
|
m.mu.Lock() |
|
|
|
if m.closed { |
|
|
|
if m.closed { |
|
|
|
return errors.New("Server has been closed") |
|
|
|
return errors.New("Server has been closed") |
|
|
@ -317,7 +298,7 @@ func (m *MuxServer) Close() error { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// connState setups the ConnState tracking hook to know which connections are idle
|
|
|
|
// connState setups the ConnState tracking hook to know which connections are idle
|
|
|
|
func (m *MuxServer) connState() { |
|
|
|
func (m *ServerMux) connState() { |
|
|
|
// Set our ConnState to track idle connections
|
|
|
|
// Set our ConnState to track idle connections
|
|
|
|
m.Server.ConnState = func(c net.Conn, cs http.ConnState) { |
|
|
|
m.Server.ConnState = func(c net.Conn, cs http.ConnState) { |
|
|
|
m.mu.Lock() |
|
|
|
m.mu.Lock() |
|
|
@ -355,7 +336,7 @@ func (m *MuxServer) connState() { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
// forgetConn removes c from conns and decrements WaitGroup
|
|
|
|
// forgetConn removes c from conns and decrements WaitGroup
|
|
|
|
func (m *MuxServer) forgetConn(c net.Conn) { |
|
|
|
func (m *ServerMux) forgetConn(c net.Conn) { |
|
|
|
if _, ok := m.conns[c]; ok { |
|
|
|
if _, ok := m.conns[c]; ok { |
|
|
|
delete(m.conns, c) |
|
|
|
delete(m.conns, c) |
|
|
|
m.WaitGroup.Done() |
|
|
|
m.WaitGroup.Done() |
|
|
|