diff --git a/cmd/auth-rpc-client.go b/cmd/auth-rpc-client.go index c397de20b..905eb2083 100644 --- a/cmd/auth-rpc-client.go +++ b/cmd/auth-rpc-client.go @@ -111,7 +111,7 @@ func newAuthClient(cfg *authConfig) *AuthRPCClient { // Save the config. config: cfg, // Initialize a new reconnectable rpc client. - rpc: newClient(cfg.address, cfg.path, cfg.secureConn), + rpc: newRPCClient(cfg.address, cfg.path, cfg.secureConn), // Allocated auth client not logged in yet. isLoggedIn: false, } diff --git a/cmd/browser-peer-rpc_test.go b/cmd/browser-peer-rpc_test.go index 0f6cc74f8..2f34da764 100644 --- a/cmd/browser-peer-rpc_test.go +++ b/cmd/browser-peer-rpc_test.go @@ -70,7 +70,7 @@ func (s *TestRPCBrowserPeerSuite) testBrowserPeerRPC(t *testing.T) { // Validate for invalid token. args := SetAuthPeerArgs{Creds: creds} args.Token = "garbage" - rclient := newClient(s.testAuthConf.address, s.testAuthConf.path, false) + rclient := newRPCClient(s.testAuthConf.address, s.testAuthConf.path, false) defer rclient.Close() err := rclient.Call("BrowserPeer.SetAuthPeer", &args, &GenericReply{}) if err != nil { @@ -89,7 +89,7 @@ func (s *TestRPCBrowserPeerSuite) testBrowserPeerRPC(t *testing.T) { } // Validate for failure in login handler with previous credentials. - rclient = newClient(s.testAuthConf.address, s.testAuthConf.path, false) + rclient = newRPCClient(s.testAuthConf.address, s.testAuthConf.path, false) defer rclient.Close() rargs := &RPCLoginArgs{ Username: s.testAuthConf.accessKey, diff --git a/cmd/lock-rpc-server.go b/cmd/lock-rpc-server.go index 42cba5303..6e362557c 100644 --- a/cmd/lock-rpc-server.go +++ b/cmd/lock-rpc-server.go @@ -279,7 +279,7 @@ func (l *lockServer) lockMaintenance(interval time.Duration) { // Validate if long lived locks are indeed clean. for _, nlrip := range nlripLongLived { // Initialize client based on the long live locks. - c := newClient(nlrip.lri.node, nlrip.lri.rpcPath, isSSL()) + c := newRPCClient(nlrip.lri.node, nlrip.lri.rpcPath, isSSL()) var expired bool diff --git a/cmd/net-rpc-client.go b/cmd/net-rpc-client.go index b74b531c2..1ba1bb8e5 100644 --- a/cmd/net-rpc-client.go +++ b/cmd/net-rpc-client.go @@ -30,19 +30,21 @@ import ( "time" ) +// defaultDialTimeout is used for non-secure connection. +const defaultDialTimeout = 3 * time.Second + // RPCClient is a wrapper type for rpc.Client which provides reconnect on first failure. type RPCClient struct { - mu sync.Mutex - rpcPrivate *rpc.Client - node string - rpcPath string - secureConn bool + mu sync.Mutex + netRPCClient *rpc.Client + node string + rpcPath string + secureConn bool } // newClient constructs a RPCClient object with node and rpcPath initialized. -// It _doesn't_ connect to the remote endpoint. See Call method to see when the -// connect happens. -func newClient(node, rpcPath string, secureConn bool) *RPCClient { +// It does lazy connect to the remote endpoint on Call(). +func newRPCClient(node, rpcPath string, secureConn bool) *RPCClient { return &RPCClient{ node: node, rpcPath: rpcPath, @@ -50,34 +52,19 @@ func newClient(node, rpcPath string, secureConn bool) *RPCClient { } } -// clearRPCClient clears the pointer to the rpc.Client object in a safe manner -func (rpcClient *RPCClient) clearRPCClient() { +// dial tries to establish a connection to the server in a safe manner. +// If there is a valid rpc.Cliemt, it returns that else creates a new one. +func (rpcClient *RPCClient) dial() (*rpc.Client, error) { rpcClient.mu.Lock() - rpcClient.rpcPrivate = nil - rpcClient.mu.Unlock() -} + defer rpcClient.mu.Unlock() -// getRPCClient gets the pointer to the rpc.Client object in a safe manner -func (rpcClient *RPCClient) getRPCClient() *rpc.Client { - rpcClient.mu.Lock() - rpcLocalStack := rpcClient.rpcPrivate - rpcClient.mu.Unlock() - return rpcLocalStack -} - -// dialRPCClient tries to establish a connection to the server in a safe manner -func (rpcClient *RPCClient) dialRPCClient() (*rpc.Client, error) { - rpcClient.mu.Lock() - // After acquiring lock, check whether another thread may not have already dialed and established connection - if rpcClient.rpcPrivate != nil { - rpcClient.mu.Unlock() - return rpcClient.rpcPrivate, nil + // Nothing to do as we already have valid connection. + if rpcClient.netRPCClient != nil { + return rpcClient.netRPCClient, nil } - rpcClient.mu.Unlock() var err error var conn net.Conn - if rpcClient.secureConn { hostname, _, splitErr := net.SplitHostPort(rpcClient.node) if splitErr != nil { @@ -92,14 +79,14 @@ func (rpcClient *RPCClient) dialRPCClient() (*rpc.Client, error) { // ServerName in tls.Config needs to be specified to support SNI certificates conn, err = tls.Dial("tcp", rpcClient.node, &tls.Config{ServerName: hostname, RootCAs: globalRootCAs}) } else { - // Have a dial timeout with 3 secs. - conn, err = net.DialTimeout("tcp", rpcClient.node, 3*time.Second) + // Dial with 3 seconds timeout. + conn, err = net.DialTimeout("tcp", rpcClient.node, defaultDialTimeout) } if err != nil { // Print RPC connection errors that are worthy to display in log switch err.(type) { case x509.HostnameError: - errorIf(err, "Unable to establish RPC to %s", rpcClient.node) + errorIf(err, "Unable to establish secure connection to %s", rpcClient.node) } return nil, &net.OpError{ Op: "dial-http", @@ -108,25 +95,27 @@ func (rpcClient *RPCClient) dialRPCClient() (*rpc.Client, error) { Err: err, } } + io.WriteString(conn, "CONNECT "+rpcClient.rpcPath+" HTTP/1.0\n\n") // Require successful HTTP response before switching to RPC protocol. resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{Method: "CONNECT"}) if err == nil && resp.Status == "200 Connected to Go RPC" { - rpc := rpc.NewClient(conn) - if rpc == nil { + netRPCClient := rpc.NewClient(conn) + if netRPCClient == nil { return nil, &net.OpError{ Op: "dial-http", Net: rpcClient.node + " " + rpcClient.rpcPath, Addr: nil, - Err: fmt.Errorf("Unable to initialize new rpcClient, %s", errUnexpected), + Err: fmt.Errorf("Unable to initialize new rpc.Client, %s", errUnexpected), } } - rpcClient.mu.Lock() - rpcClient.rpcPrivate = rpc - rpcClient.mu.Unlock() - return rpc, nil + + rpcClient.netRPCClient = netRPCClient + + return netRPCClient, nil } + if err == nil { err = errors.New("unexpected HTTP response: " + resp.Status) } @@ -141,38 +130,31 @@ func (rpcClient *RPCClient) dialRPCClient() (*rpc.Client, error) { // Call makes a RPC call to the remote endpoint using the default codec, namely encoding/gob. func (rpcClient *RPCClient) Call(serviceMethod string, args interface{}, reply interface{}) error { - // Make a copy below so that we can safely (continue to) work with the rpc.Client. - // Even in the case the two threads would simultaneously find that the connection is not initialised, - // they would both attempt to dial and only one of them would succeed in doing so. - rpcLocalStack := rpcClient.getRPCClient() - - // If the rpc.Client is nil, we attempt to (re)connect with the remote endpoint. - if rpcLocalStack == nil { - var err error - rpcLocalStack, err = rpcClient.dialRPCClient() - if err != nil { - return err - } + // Get a new or existing rpc.Client. + netRPCClient, err := rpcClient.dial() + if err != nil { + return err } - // If the RPC fails due to a network-related error - return rpcLocalStack.Call(serviceMethod, args, reply) + return netRPCClient.Call(serviceMethod, args, reply) } -// Close closes the underlying socket file descriptor. +// Close closes underlying rpc.Client. func (rpcClient *RPCClient) Close() error { - // See comment above for making a copy on local stack - rpcLocalStack := rpcClient.getRPCClient() + rpcClient.mu.Lock() - // If rpc client has not connected yet there is nothing to close. - if rpcLocalStack == nil { - return nil + if rpcClient.netRPCClient != nil { + // We make a copy of rpc.Client and unlock it immediately so that another + // goroutine could try to dial or close in parallel. + netRPCClient := rpcClient.netRPCClient + rpcClient.netRPCClient = nil + rpcClient.mu.Unlock() + + return netRPCClient.Close() } - // Reset rpcClient.rpc to allow for subsequent calls to use a new - // (socket) connection. - rpcClient.clearRPCClient() - return rpcLocalStack.Close() + rpcClient.mu.Unlock() + return nil } // Node returns the node (network address) of the connection diff --git a/cmd/s3-peer-rpc-handlers_test.go b/cmd/s3-peer-rpc-handlers_test.go index c310cfbfe..ff46414bc 100644 --- a/cmd/s3-peer-rpc-handlers_test.go +++ b/cmd/s3-peer-rpc-handlers_test.go @@ -63,7 +63,7 @@ func TestS3PeerRPC(t *testing.T) { func (s *TestRPCS3PeerSuite) testS3PeerRPC(t *testing.T) { // Validate for invalid token. args := GenericArgs{Token: "garbage", Timestamp: time.Now().UTC()} - rclient := newClient(s.testAuthConf.address, s.testAuthConf.path, false) + rclient := newRPCClient(s.testAuthConf.address, s.testAuthConf.path, false) defer rclient.Close() err := rclient.Call("S3.SetBucketNotificationPeer", &args, &GenericReply{}) if err != nil {