Lock: Use REST API instead of RPC (#7469)
In distributed mode, use REST API to acquire and manage locks instead of RPC. RPC has been completely removed from MinIO source. Since we are moving from RPC to REST, we cannot use rolling upgrades as the nodes that have not yet been upgraded cannot talk to the ones that have been upgraded. We expect all minio processes on all nodes to be stopped and then the upgrade process to be completed. Also force http1.1 for inter-node communicationmaster
parent
7686340621
commit
d2f42d830f
@ -0,0 +1,223 @@ |
||||
/* |
||||
* MinIO Cloud Storage, (C) 2019 MinIO, Inc. |
||||
* |
||||
* Licensed under the Apache License, Version 2.0 (the "License"); |
||||
* you may not use this file except in compliance with the License. |
||||
* You may obtain a copy of the License at |
||||
* |
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* |
||||
* Unless required by applicable law or agreed to in writing, software |
||||
* distributed under the License is distributed on an "AS IS" BASIS, |
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
* See the License for the specific language governing permissions and |
||||
* limitations under the License. |
||||
*/ |
||||
|
||||
package cmd |
||||
|
||||
import ( |
||||
"bytes" |
||||
"context" |
||||
"crypto/tls" |
||||
"encoding/gob" |
||||
"errors" |
||||
"io" |
||||
"sync" |
||||
"time" |
||||
|
||||
"net/url" |
||||
|
||||
"github.com/minio/dsync" |
||||
"github.com/minio/minio/cmd/http" |
||||
"github.com/minio/minio/cmd/logger" |
||||
"github.com/minio/minio/cmd/rest" |
||||
xnet "github.com/minio/minio/pkg/net" |
||||
) |
||||
|
||||
// lockRESTClient is authenticable lock REST client
|
||||
type lockRESTClient struct { |
||||
lockSync sync.RWMutex |
||||
host *xnet.Host |
||||
restClient *rest.Client |
||||
serverURL *url.URL |
||||
connected bool |
||||
timer *time.Timer |
||||
} |
||||
|
||||
// ServerAddr - dsync.NetLocker interface compatible method.
|
||||
func (client *lockRESTClient) ServerAddr() string { |
||||
return client.serverURL.Host |
||||
} |
||||
|
||||
// ServiceEndpoint - dsync.NetLocker interface compatible method.
|
||||
func (client *lockRESTClient) ServiceEndpoint() string { |
||||
return client.serverURL.Path |
||||
} |
||||
|
||||
// check if the host is up or if it is fine
|
||||
// to make a call to the lock rest server.
|
||||
func (client *lockRESTClient) isHostUp() bool { |
||||
client.lockSync.Lock() |
||||
defer client.lockSync.Unlock() |
||||
|
||||
if client.connected { |
||||
return true |
||||
} |
||||
select { |
||||
case <-client.timer.C: |
||||
client.connected = true |
||||
client.timer = nil |
||||
return true |
||||
default: |
||||
} |
||||
return false |
||||
} |
||||
|
||||
// Mark the host as down if there is a Network error.
|
||||
func (client *lockRESTClient) markHostDown() { |
||||
client.lockSync.Lock() |
||||
defer client.lockSync.Unlock() |
||||
|
||||
if !client.connected { |
||||
return |
||||
} |
||||
client.connected = false |
||||
client.timer = time.NewTimer(defaultRetryUnit * 5) |
||||
} |
||||
|
||||
// Wrapper to restClient.Call to handle network errors, in case of network error the connection is marked disconnected
|
||||
// permanently. The only way to restore the connection is at the xl-sets layer by xlsets.monitorAndConnectEndpoints()
|
||||
// after verifying format.json
|
||||
func (client *lockRESTClient) call(method string, values url.Values, body io.Reader, length int64) (respBody io.ReadCloser, err error) { |
||||
|
||||
if !client.isHostUp() { |
||||
return nil, errors.New("Lock rest server node is down") |
||||
} |
||||
|
||||
if values == nil { |
||||
values = make(url.Values) |
||||
} |
||||
|
||||
respBody, err = client.restClient.Call(method, values, body, length) |
||||
|
||||
if err == nil { |
||||
return respBody, nil |
||||
} |
||||
|
||||
if isNetworkError(err) { |
||||
client.markHostDown() |
||||
} |
||||
|
||||
return nil, err |
||||
} |
||||
|
||||
// Stringer provides a canonicalized representation of node.
|
||||
func (client *lockRESTClient) String() string { |
||||
return client.host.String() |
||||
} |
||||
|
||||
// IsOnline - returns whether REST client failed to connect or not.
|
||||
func (client *lockRESTClient) IsOnline() bool { |
||||
return client.connected |
||||
} |
||||
|
||||
// Close - marks the client as closed.
|
||||
func (client *lockRESTClient) Close() error { |
||||
client.connected = false |
||||
client.restClient.Close() |
||||
return nil |
||||
} |
||||
|
||||
// restCall makes a call to the lock REST server.
|
||||
func (client *lockRESTClient) restCall(call string, args dsync.LockArgs) (reply bool, err error) { |
||||
|
||||
reader := bytes.NewBuffer(make([]byte, 0, 2048)) |
||||
err = gob.NewEncoder(reader).Encode(args) |
||||
if err != nil { |
||||
return false, err |
||||
} |
||||
respBody, err := client.call(call, nil, reader, -1) |
||||
if err != nil { |
||||
return false, err |
||||
} |
||||
|
||||
var resp lockResponse |
||||
defer http.DrainBody(respBody) |
||||
err = gob.NewDecoder(respBody).Decode(&resp) |
||||
|
||||
if err != nil || !resp.Success { |
||||
reqInfo := &logger.ReqInfo{} |
||||
reqInfo.AppendTags("resource", args.Resource) |
||||
reqInfo.AppendTags("serveraddress", args.ServerAddr) |
||||
reqInfo.AppendTags("serviceendpoint", args.ServiceEndpoint) |
||||
reqInfo.AppendTags("source", args.Source) |
||||
reqInfo.AppendTags("uid", args.UID) |
||||
ctx := logger.SetReqInfo(context.Background(), reqInfo) |
||||
logger.LogIf(ctx, err) |
||||
} |
||||
return resp.Success, err |
||||
} |
||||
|
||||
// RLock calls read lock REST API.
|
||||
func (client *lockRESTClient) RLock(args dsync.LockArgs) (reply bool, err error) { |
||||
return client.restCall(lockRESTMethodRLock, args) |
||||
} |
||||
|
||||
// Lock calls lock REST API.
|
||||
func (client *lockRESTClient) Lock(args dsync.LockArgs) (reply bool, err error) { |
||||
return client.restCall(lockRESTMethodLock, args) |
||||
} |
||||
|
||||
// RUnlock calls read unlock REST API.
|
||||
func (client *lockRESTClient) RUnlock(args dsync.LockArgs) (reply bool, err error) { |
||||
return client.restCall(lockRESTMethodRUnlock, args) |
||||
} |
||||
|
||||
// Unlock calls write unlock RPC.
|
||||
func (client *lockRESTClient) Unlock(args dsync.LockArgs) (reply bool, err error) { |
||||
return client.restCall(lockRESTMethodUnlock, args) |
||||
} |
||||
|
||||
// ForceUnlock calls force unlock RPC.
|
||||
func (client *lockRESTClient) ForceUnlock(args dsync.LockArgs) (reply bool, err error) { |
||||
return client.restCall(lockRESTMethodForceUnlock, args) |
||||
} |
||||
|
||||
// Expired calls expired RPC.
|
||||
func (client *lockRESTClient) Expired(args dsync.LockArgs) (reply bool, err error) { |
||||
return client.restCall(lockRESTMethodExpired, args) |
||||
} |
||||
|
||||
// Returns a lock rest client.
|
||||
func newlockRESTClient(peer *xnet.Host) *lockRESTClient { |
||||
|
||||
scheme := "http" |
||||
if globalIsSSL { |
||||
scheme = "https" |
||||
} |
||||
|
||||
serverURL := &url.URL{ |
||||
Scheme: scheme, |
||||
Host: peer.String(), |
||||
Path: lockRESTPath, |
||||
} |
||||
|
||||
var tlsConfig *tls.Config |
||||
if globalIsSSL { |
||||
tlsConfig = &tls.Config{ |
||||
ServerName: peer.Name, |
||||
RootCAs: globalRootCAs, |
||||
NextProtos: []string{"http/1.1"}, // Force http1.1
|
||||
} |
||||
} |
||||
|
||||
restClient, err := rest.NewClient(serverURL, tlsConfig, rest.DefaultRESTTimeout, newAuthToken) |
||||
|
||||
if err != nil { |
||||
logger.LogIf(context.Background(), err) |
||||
return &lockRESTClient{serverURL: serverURL, host: peer, restClient: restClient, connected: false, timer: time.NewTimer(defaultRetryUnit * 5)} |
||||
} |
||||
|
||||
return &lockRESTClient{serverURL: serverURL, host: peer, restClient: restClient, connected: true} |
||||
} |
@ -0,0 +1,336 @@ |
||||
/* |
||||
* Minio Cloud Storage, (C) 2019 Minio, Inc. |
||||
* |
||||
* Licensed under the Apache License, Version 2.0 (the "License"); |
||||
* you may not use this file except in compliance with the License. |
||||
* You may obtain a copy of the License at |
||||
* |
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* |
||||
* Unless required by applicable law or agreed to in writing, software |
||||
* distributed under the License is distributed on an "AS IS" BASIS, |
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
* See the License for the specific language governing permissions and |
||||
* limitations under the License. |
||||
*/ |
||||
|
||||
package cmd |
||||
|
||||
import ( |
||||
"context" |
||||
"encoding/gob" |
||||
"errors" |
||||
"math/rand" |
||||
"net/http" |
||||
"time" |
||||
|
||||
"github.com/gorilla/mux" |
||||
"github.com/minio/dsync" |
||||
"github.com/minio/minio/cmd/logger" |
||||
xnet "github.com/minio/minio/pkg/net" |
||||
) |
||||
|
||||
const ( |
||||
// Lock rpc server endpoint.
|
||||
lockServiceSubPath = "/lock" |
||||
|
||||
// Lock maintenance interval.
|
||||
lockMaintenanceInterval = 1 * time.Minute |
||||
|
||||
// Lock validity check interval.
|
||||
lockValidityCheckInterval = 2 * time.Minute |
||||
) |
||||
|
||||
// To abstract a node over network.
|
||||
type lockRESTServer struct { |
||||
ll localLocker |
||||
} |
||||
|
||||
func (l *lockRESTServer) writeErrorResponse(w http.ResponseWriter, err error) { |
||||
w.WriteHeader(http.StatusForbidden) |
||||
w.Write([]byte(err.Error())) |
||||
} |
||||
|
||||
// IsValid - To authenticate and verify the time difference.
|
||||
func (l *lockRESTServer) IsValid(w http.ResponseWriter, r *http.Request) bool { |
||||
if err := storageServerRequestValidate(r); err != nil { |
||||
l.writeErrorResponse(w, err) |
||||
return false |
||||
} |
||||
return true |
||||
} |
||||
|
||||
// LockHandler - Acquires a lock.
|
||||
func (l *lockRESTServer) LockHandler(w http.ResponseWriter, r *http.Request) { |
||||
if !l.IsValid(w, r) { |
||||
l.writeErrorResponse(w, errors.New("Invalid request")) |
||||
return |
||||
} |
||||
|
||||
ctx := newContext(r, w, "Lock") |
||||
|
||||
var lockArgs dsync.LockArgs |
||||
if r.ContentLength < 0 { |
||||
l.writeErrorResponse(w, errInvalidArgument) |
||||
return |
||||
} |
||||
|
||||
err := gob.NewDecoder(r.Body).Decode(&lockArgs) |
||||
if err != nil { |
||||
l.writeErrorResponse(w, err) |
||||
return |
||||
} |
||||
|
||||
success, err := l.ll.Lock(lockArgs) |
||||
if err != nil { |
||||
l.writeErrorResponse(w, err) |
||||
return |
||||
} |
||||
resp := lockResponse{Success: success} |
||||
logger.LogIf(ctx, gob.NewEncoder(w).Encode(resp)) |
||||
w.(http.Flusher).Flush() |
||||
} |
||||
|
||||
// UnlockHandler - releases the acquired lock.
|
||||
func (l *lockRESTServer) UnlockHandler(w http.ResponseWriter, r *http.Request) { |
||||
if !l.IsValid(w, r) { |
||||
l.writeErrorResponse(w, errors.New("Invalid request")) |
||||
return |
||||
} |
||||
|
||||
ctx := newContext(r, w, "Unlock") |
||||
|
||||
var lockArgs dsync.LockArgs |
||||
if r.ContentLength < 0 { |
||||
l.writeErrorResponse(w, errInvalidArgument) |
||||
return |
||||
} |
||||
|
||||
err := gob.NewDecoder(r.Body).Decode(&lockArgs) |
||||
if err != nil { |
||||
l.writeErrorResponse(w, err) |
||||
return |
||||
} |
||||
|
||||
success, err := l.ll.Unlock(lockArgs) |
||||
if err != nil { |
||||
l.writeErrorResponse(w, err) |
||||
return |
||||
} |
||||
resp := lockResponse{Success: success} |
||||
logger.LogIf(ctx, gob.NewEncoder(w).Encode(resp)) |
||||
w.(http.Flusher).Flush() |
||||
} |
||||
|
||||
// LockHandler - Acquires an RLock.
|
||||
func (l *lockRESTServer) RLockHandler(w http.ResponseWriter, r *http.Request) { |
||||
if !l.IsValid(w, r) { |
||||
l.writeErrorResponse(w, errors.New("Invalid request")) |
||||
return |
||||
} |
||||
|
||||
ctx := newContext(r, w, "RLock") |
||||
var lockArgs dsync.LockArgs |
||||
if r.ContentLength < 0 { |
||||
l.writeErrorResponse(w, errInvalidArgument) |
||||
return |
||||
} |
||||
|
||||
err := gob.NewDecoder(r.Body).Decode(&lockArgs) |
||||
if err != nil { |
||||
l.writeErrorResponse(w, err) |
||||
return |
||||
} |
||||
|
||||
success, err := l.ll.RLock(lockArgs) |
||||
if err != nil { |
||||
l.writeErrorResponse(w, err) |
||||
return |
||||
} |
||||
resp := lockResponse{Success: success} |
||||
logger.LogIf(ctx, gob.NewEncoder(w).Encode(resp)) |
||||
w.(http.Flusher).Flush() |
||||
} |
||||
|
||||
// RUnlockHandler - releases the acquired read lock.
|
||||
func (l *lockRESTServer) RUnlockHandler(w http.ResponseWriter, r *http.Request) { |
||||
if !l.IsValid(w, r) { |
||||
l.writeErrorResponse(w, errors.New("Invalid request")) |
||||
return |
||||
} |
||||
|
||||
ctx := newContext(r, w, "RUnlock") |
||||
var lockArgs dsync.LockArgs |
||||
if r.ContentLength < 0 { |
||||
l.writeErrorResponse(w, errInvalidArgument) |
||||
return |
||||
} |
||||
|
||||
err := gob.NewDecoder(r.Body).Decode(&lockArgs) |
||||
if err != nil { |
||||
l.writeErrorResponse(w, err) |
||||
return |
||||
} |
||||
|
||||
success, err := l.ll.RUnlock(lockArgs) |
||||
if err != nil { |
||||
l.writeErrorResponse(w, err) |
||||
return |
||||
} |
||||
resp := lockResponse{Success: success} |
||||
logger.LogIf(ctx, gob.NewEncoder(w).Encode(resp)) |
||||
w.(http.Flusher).Flush() |
||||
} |
||||
|
||||
// ForceUnlockHandler - force releases the acquired lock.
|
||||
func (l *lockRESTServer) ForceUnlockHandler(w http.ResponseWriter, r *http.Request) { |
||||
if !l.IsValid(w, r) { |
||||
l.writeErrorResponse(w, errors.New("Invalid request")) |
||||
return |
||||
} |
||||
|
||||
ctx := newContext(r, w, "ForceUnlock") |
||||
|
||||
var lockArgs dsync.LockArgs |
||||
if r.ContentLength < 0 { |
||||
l.writeErrorResponse(w, errInvalidArgument) |
||||
return |
||||
} |
||||
|
||||
err := gob.NewDecoder(r.Body).Decode(&lockArgs) |
||||
if err != nil { |
||||
l.writeErrorResponse(w, err) |
||||
return |
||||
} |
||||
|
||||
success, err := l.ll.ForceUnlock(lockArgs) |
||||
if err != nil { |
||||
l.writeErrorResponse(w, err) |
||||
return |
||||
} |
||||
resp := lockResponse{Success: success} |
||||
logger.LogIf(ctx, gob.NewEncoder(w).Encode(resp)) |
||||
w.(http.Flusher).Flush() |
||||
} |
||||
|
||||
// ExpiredHandler - query expired lock status.
|
||||
func (l *lockRESTServer) ExpiredHandler(w http.ResponseWriter, r *http.Request) { |
||||
if !l.IsValid(w, r) { |
||||
l.writeErrorResponse(w, errors.New("Invalid request")) |
||||
return |
||||
} |
||||
|
||||
ctx := newContext(r, w, "Expired") |
||||
|
||||
var lockArgs dsync.LockArgs |
||||
if r.ContentLength < 0 { |
||||
l.writeErrorResponse(w, errInvalidArgument) |
||||
return |
||||
} |
||||
|
||||
err := gob.NewDecoder(r.Body).Decode(&lockArgs) |
||||
if err != nil { |
||||
l.writeErrorResponse(w, err) |
||||
return |
||||
} |
||||
success := true |
||||
l.ll.mutex.Lock() |
||||
defer l.ll.mutex.Unlock() |
||||
// Lock found, proceed to verify if belongs to given uid.
|
||||
if lri, ok := l.ll.lockMap[lockArgs.Resource]; ok { |
||||
// Check whether uid is still active
|
||||
for _, entry := range lri { |
||||
if entry.UID == lockArgs.UID { |
||||
success = false // When uid found, lock is still active so return not expired.
|
||||
break |
||||
} |
||||
} |
||||
} |
||||
// When we get here lock is no longer active due to either dsync.LockArgs.Resource
|
||||
// being absent from map or uid not found for given dsync.LockArgs.Resource
|
||||
resp := lockResponse{Success: success} |
||||
logger.LogIf(ctx, gob.NewEncoder(w).Encode(resp)) |
||||
w.(http.Flusher).Flush() |
||||
} |
||||
|
||||
// lockMaintenance loops over locks that have been active for some time and checks back
|
||||
// with the original server whether it is still alive or not
|
||||
//
|
||||
// Following logic inside ignores the errors generated for Dsync.Active operation.
|
||||
// - server at client down
|
||||
// - some network error (and server is up normally)
|
||||
//
|
||||
// We will ignore the error, and we will retry later to get a resolve on this lock
|
||||
func (l *lockRESTServer) lockMaintenance(interval time.Duration) { |
||||
l.ll.mutex.Lock() |
||||
// Get list of long lived locks to check for staleness.
|
||||
nlripLongLived := getLongLivedLocks(l.ll.lockMap, interval) |
||||
l.ll.mutex.Unlock() |
||||
|
||||
// Validate if long lived locks are indeed clean.
|
||||
for _, nlrip := range nlripLongLived { |
||||
// Initialize client based on the long live locks.
|
||||
host, err := xnet.ParseHost(nlrip.lri.Node) |
||||
if err != nil { |
||||
logger.LogIf(context.Background(), err) |
||||
continue |
||||
} |
||||
c := newlockRESTClient(host) |
||||
if !c.connected { |
||||
continue |
||||
} |
||||
|
||||
// Call back to original server verify whether the lock is still active (based on name & uid)
|
||||
expired, _ := c.Expired(dsync.LockArgs{ |
||||
UID: nlrip.lri.UID, |
||||
Resource: nlrip.name, |
||||
}) |
||||
|
||||
// Close the connection regardless of the call response.
|
||||
c.Close() |
||||
|
||||
// For successful response, verify if lock is indeed active or stale.
|
||||
if expired { |
||||
// The lock is no longer active at server that originated the lock
|
||||
// So remove the lock from the map.
|
||||
l.ll.mutex.Lock() |
||||
l.ll.removeEntryIfExists(nlrip) // Purge the stale entry if it exists.
|
||||
l.ll.mutex.Unlock() |
||||
} |
||||
} |
||||
} |
||||
|
||||
// Start lock maintenance from all lock servers.
|
||||
func startLockMaintenance(lkSrv *lockRESTServer) { |
||||
// Initialize a new ticker with a minute between each ticks.
|
||||
ticker := time.NewTicker(lockMaintenanceInterval) |
||||
// Stop the timer upon service closure and cleanup the go-routine.
|
||||
defer ticker.Stop() |
||||
|
||||
// Start with random sleep time, so as to avoid "synchronous checks" between servers
|
||||
time.Sleep(time.Duration(rand.Float64() * float64(lockMaintenanceInterval))) |
||||
for { |
||||
// Verifies every minute for locks held more than 2 minutes.
|
||||
select { |
||||
case <-GlobalServiceDoneCh: |
||||
return |
||||
case <-ticker.C: |
||||
lkSrv.lockMaintenance(lockValidityCheckInterval) |
||||
} |
||||
} |
||||
} |
||||
|
||||
// registerLockRESTHandlers - register lock rest router.
|
||||
func registerLockRESTHandlers(router *mux.Router) { |
||||
subrouter := router.PathPrefix(lockRESTPath).Subrouter() |
||||
subrouter.Methods(http.MethodPost).Path("/" + lockRESTMethodLock).HandlerFunc(httpTraceHdrs(globalLockServer.LockHandler)) |
||||
subrouter.Methods(http.MethodPost).Path("/" + lockRESTMethodRLock).HandlerFunc(httpTraceHdrs(globalLockServer.RLockHandler)) |
||||
subrouter.Methods(http.MethodPost).Path("/" + lockRESTMethodUnlock).HandlerFunc(httpTraceHdrs(globalLockServer.UnlockHandler)) |
||||
subrouter.Methods(http.MethodPost).Path("/" + lockRESTMethodRUnlock).HandlerFunc(httpTraceHdrs(globalLockServer.RUnlockHandler)) |
||||
subrouter.Methods(http.MethodPost).Path("/" + lockRESTMethodForceUnlock).HandlerFunc(httpTraceHdrs(globalLockServer.ForceUnlockHandler)) |
||||
subrouter.Methods(http.MethodPost).Path("/" + lockRESTMethodExpired).HandlerFunc(httpTraceAll(globalLockServer.ExpiredHandler)) |
||||
router.NotFoundHandler = http.HandlerFunc(httpTraceAll(notFoundHandler)) |
||||
|
||||
// Start lock maintenance from all lock servers.
|
||||
go startLockMaintenance(globalLockServer) |
||||
} |
@ -1,114 +0,0 @@ |
||||
/* |
||||
* MinIO Cloud Storage, (C) 2016 MinIO, Inc. |
||||
* |
||||
* Licensed under the Apache License, Version 2.0 (the "License"); |
||||
* you may not use this file except in compliance with the License. |
||||
* You may obtain a copy of the License at |
||||
* |
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* |
||||
* Unless required by applicable law or agreed to in writing, software |
||||
* distributed under the License is distributed on an "AS IS" BASIS, |
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
* See the License for the specific language governing permissions and |
||||
* limitations under the License. |
||||
*/ |
||||
|
||||
package cmd |
||||
|
||||
import ( |
||||
"crypto/tls" |
||||
|
||||
"github.com/minio/dsync" |
||||
xnet "github.com/minio/minio/pkg/net" |
||||
) |
||||
|
||||
// LockRPCClient is authenticable lock RPC client compatible to dsync.NetLocker
|
||||
type LockRPCClient struct { |
||||
*RPCClient |
||||
} |
||||
|
||||
// ServerAddr - dsync.NetLocker interface compatible method.
|
||||
func (lockRPC *LockRPCClient) ServerAddr() string { |
||||
url := lockRPC.ServiceURL() |
||||
return url.Host |
||||
} |
||||
|
||||
// ServiceEndpoint - dsync.NetLocker interface compatible method.
|
||||
func (lockRPC *LockRPCClient) ServiceEndpoint() string { |
||||
url := lockRPC.ServiceURL() |
||||
return url.Path |
||||
} |
||||
|
||||
// RLock calls read lock RPC.
|
||||
func (lockRPC *LockRPCClient) RLock(args dsync.LockArgs) (reply bool, err error) { |
||||
err = lockRPC.Call(lockServiceName+".RLock", &LockArgs{LockArgs: args}, &reply) |
||||
return reply, err |
||||
} |
||||
|
||||
// Lock calls write lock RPC.
|
||||
func (lockRPC *LockRPCClient) Lock(args dsync.LockArgs) (reply bool, err error) { |
||||
err = lockRPC.Call(lockServiceName+".Lock", &LockArgs{LockArgs: args}, &reply) |
||||
return reply, err |
||||
} |
||||
|
||||
// RUnlock calls read unlock RPC.
|
||||
func (lockRPC *LockRPCClient) RUnlock(args dsync.LockArgs) (reply bool, err error) { |
||||
err = lockRPC.Call(lockServiceName+".RUnlock", &LockArgs{LockArgs: args}, &reply) |
||||
return reply, err |
||||
} |
||||
|
||||
// Unlock calls write unlock RPC.
|
||||
func (lockRPC *LockRPCClient) Unlock(args dsync.LockArgs) (reply bool, err error) { |
||||
err = lockRPC.Call(lockServiceName+".Unlock", &LockArgs{LockArgs: args}, &reply) |
||||
return reply, err |
||||
} |
||||
|
||||
// ForceUnlock calls force unlock RPC.
|
||||
func (lockRPC *LockRPCClient) ForceUnlock(args dsync.LockArgs) (reply bool, err error) { |
||||
err = lockRPC.Call(lockServiceName+".ForceUnlock", &LockArgs{LockArgs: args}, &reply) |
||||
return reply, err |
||||
} |
||||
|
||||
// Expired calls expired RPC.
|
||||
func (lockRPC *LockRPCClient) Expired(args dsync.LockArgs) (reply bool, err error) { |
||||
err = lockRPC.Call(lockServiceName+".Expired", &LockArgs{LockArgs: args}, &reply) |
||||
return reply, err |
||||
} |
||||
|
||||
// NewLockRPCClient - returns new lock RPC client.
|
||||
func NewLockRPCClient(host *xnet.Host) (*LockRPCClient, error) { |
||||
scheme := "http" |
||||
if globalIsSSL { |
||||
scheme = "https" |
||||
} |
||||
|
||||
serviceURL := &xnet.URL{ |
||||
Scheme: scheme, |
||||
Host: host.String(), |
||||
Path: lockServicePath, |
||||
} |
||||
|
||||
var tlsConfig *tls.Config |
||||
if globalIsSSL { |
||||
tlsConfig = &tls.Config{ |
||||
ServerName: host.Name, |
||||
RootCAs: globalRootCAs, |
||||
} |
||||
} |
||||
|
||||
rpcClient, err := NewRPCClient( |
||||
RPCClientArgs{ |
||||
NewAuthTokenFunc: newAuthToken, |
||||
RPCVersion: globalRPCAPIVersion, |
||||
ServiceName: lockServiceName, |
||||
ServiceURL: serviceURL, |
||||
TLSConfig: tlsConfig, |
||||
}, |
||||
) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return &LockRPCClient{rpcClient}, nil |
||||
} |
@ -1,196 +0,0 @@ |
||||
/* |
||||
* MinIO Cloud Storage, (C) 2016, 2017, 2018, 2019 MinIO, Inc. |
||||
* |
||||
* Licensed under the Apache License, Version 2.0 (the "License"); |
||||
* you may not use this file except in compliance with the License. |
||||
* You may obtain a copy of the License at |
||||
* |
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* |
||||
* Unless required by applicable law or agreed to in writing, software |
||||
* distributed under the License is distributed on an "AS IS" BASIS, |
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
* See the License for the specific language governing permissions and |
||||
* limitations under the License. |
||||
*/ |
||||
|
||||
package cmd |
||||
|
||||
import ( |
||||
"context" |
||||
"math/rand" |
||||
"path" |
||||
"time" |
||||
|
||||
"github.com/gorilla/mux" |
||||
"github.com/minio/dsync" |
||||
"github.com/minio/minio/cmd/logger" |
||||
xrpc "github.com/minio/minio/cmd/rpc" |
||||
xnet "github.com/minio/minio/pkg/net" |
||||
) |
||||
|
||||
const ( |
||||
// Lock rpc server endpoint.
|
||||
lockServiceSubPath = "/lock" |
||||
|
||||
// Lock rpc service name.
|
||||
lockServiceName = "Dsync" |
||||
|
||||
// Lock maintenance interval.
|
||||
lockMaintenanceInterval = 1 * time.Minute |
||||
|
||||
// Lock validity check interval.
|
||||
lockValidityCheckInterval = 2 * time.Minute |
||||
) |
||||
|
||||
var lockServicePath = path.Join(minioReservedBucketPath, lockServiceSubPath) |
||||
|
||||
// LockArgs represents arguments for any authenticated lock RPC call.
|
||||
type LockArgs struct { |
||||
AuthArgs |
||||
LockArgs dsync.LockArgs |
||||
} |
||||
|
||||
// lockRPCReceiver is type for RPC handlers
|
||||
type lockRPCReceiver struct { |
||||
ll localLocker |
||||
} |
||||
|
||||
// Lock - rpc handler for (single) write lock operation.
|
||||
func (l *lockRPCReceiver) Lock(args *LockArgs, reply *bool) (err error) { |
||||
*reply, err = l.ll.Lock(args.LockArgs) |
||||
return err |
||||
} |
||||
|
||||
// Unlock - rpc handler for (single) write unlock operation.
|
||||
func (l *lockRPCReceiver) Unlock(args *LockArgs, reply *bool) (err error) { |
||||
*reply, err = l.ll.Unlock(args.LockArgs) |
||||
return err |
||||
} |
||||
|
||||
// RLock - rpc handler for read lock operation.
|
||||
func (l *lockRPCReceiver) RLock(args *LockArgs, reply *bool) (err error) { |
||||
*reply, err = l.ll.RLock(args.LockArgs) |
||||
return err |
||||
} |
||||
|
||||
// RUnlock - rpc handler for read unlock operation.
|
||||
func (l *lockRPCReceiver) RUnlock(args *LockArgs, reply *bool) (err error) { |
||||
*reply, err = l.ll.RUnlock(args.LockArgs) |
||||
return err |
||||
} |
||||
|
||||
// ForceUnlock - rpc handler for force unlock operation.
|
||||
func (l *lockRPCReceiver) ForceUnlock(args *LockArgs, reply *bool) (err error) { |
||||
*reply, err = l.ll.ForceUnlock(args.LockArgs) |
||||
return err |
||||
} |
||||
|
||||
// Expired - rpc handler for expired lock status.
|
||||
func (l *lockRPCReceiver) Expired(args *LockArgs, reply *bool) error { |
||||
l.ll.mutex.Lock() |
||||
defer l.ll.mutex.Unlock() |
||||
// Lock found, proceed to verify if belongs to given uid.
|
||||
if lri, ok := l.ll.lockMap[args.LockArgs.Resource]; ok { |
||||
// Check whether uid is still active
|
||||
for _, entry := range lri { |
||||
if entry.UID == args.LockArgs.UID { |
||||
*reply = false // When uid found, lock is still active so return not expired.
|
||||
return nil // When uid found *reply is set to true.
|
||||
} |
||||
} |
||||
} |
||||
// When we get here lock is no longer active due to either args.LockArgs.Resource
|
||||
// being absent from map or uid not found for given args.LockArgs.Resource
|
||||
*reply = true |
||||
return nil |
||||
} |
||||
|
||||
// lockMaintenance loops over locks that have been active for some time and checks back
|
||||
// with the original server whether it is still alive or not
|
||||
//
|
||||
// Following logic inside ignores the errors generated for Dsync.Active operation.
|
||||
// - server at client down
|
||||
// - some network error (and server is up normally)
|
||||
//
|
||||
// We will ignore the error, and we will retry later to get a resolve on this lock
|
||||
func (l *lockRPCReceiver) lockMaintenance(interval time.Duration) { |
||||
l.ll.mutex.Lock() |
||||
// Get list of long lived locks to check for staleness.
|
||||
nlripLongLived := getLongLivedLocks(l.ll.lockMap, interval) |
||||
l.ll.mutex.Unlock() |
||||
|
||||
// Validate if long lived locks are indeed clean.
|
||||
for _, nlrip := range nlripLongLived { |
||||
// Initialize client based on the long live locks.
|
||||
host, err := xnet.ParseHost(nlrip.lri.Node) |
||||
if err != nil { |
||||
logger.LogIf(context.Background(), err) |
||||
continue |
||||
} |
||||
c, err := NewLockRPCClient(host) |
||||
if err != nil { |
||||
logger.LogIf(context.Background(), err) |
||||
continue |
||||
} |
||||
|
||||
// Call back to original server verify whether the lock is still active (based on name & uid)
|
||||
expired, _ := c.Expired(dsync.LockArgs{ |
||||
UID: nlrip.lri.UID, |
||||
Resource: nlrip.name, |
||||
}) |
||||
|
||||
// Close the connection regardless of the call response.
|
||||
c.Close() |
||||
|
||||
// For successful response, verify if lock is indeed active or stale.
|
||||
if expired { |
||||
// The lock is no longer active at server that originated the lock
|
||||
// So remove the lock from the map.
|
||||
l.ll.mutex.Lock() |
||||
l.ll.removeEntryIfExists(nlrip) // Purge the stale entry if it exists.
|
||||
l.ll.mutex.Unlock() |
||||
} |
||||
} |
||||
} |
||||
|
||||
// Start lock maintenance from all lock servers.
|
||||
func startLockMaintenance(lkSrv *lockRPCReceiver) { |
||||
// Initialize a new ticker with a minute between each ticks.
|
||||
ticker := time.NewTicker(lockMaintenanceInterval) |
||||
// Stop the timer upon service closure and cleanup the go-routine.
|
||||
defer ticker.Stop() |
||||
|
||||
// Start with random sleep time, so as to avoid "synchronous checks" between servers
|
||||
time.Sleep(time.Duration(rand.Float64() * float64(lockMaintenanceInterval))) |
||||
for { |
||||
// Verifies every minute for locks held more than 2minutes.
|
||||
select { |
||||
case <-GlobalServiceDoneCh: |
||||
return |
||||
case <-ticker.C: |
||||
lkSrv.lockMaintenance(lockValidityCheckInterval) |
||||
} |
||||
} |
||||
} |
||||
|
||||
// NewLockRPCServer - returns new lock RPC server.
|
||||
func NewLockRPCServer() (*xrpc.Server, error) { |
||||
rpcServer := xrpc.NewServer() |
||||
if err := rpcServer.RegisterName(lockServiceName, globalLockServer); err != nil { |
||||
return nil, err |
||||
} |
||||
return rpcServer, nil |
||||
} |
||||
|
||||
// Register distributed NS lock handlers.
|
||||
func registerDistNSLockRouter(router *mux.Router) { |
||||
rpcServer, err := NewLockRPCServer() |
||||
logger.FatalIf(err, "Unable to initialize Lock RPC Server") |
||||
|
||||
// Start lock maintenance from all lock servers.
|
||||
go startLockMaintenance(globalLockServer) |
||||
|
||||
subrouter := router.PathPrefix(minioReservedBucketPath).Subrouter() |
||||
subrouter.Path(lockServiceSubPath).HandlerFunc(httpTraceHdrs(rpcServer.ServeHTTP)) |
||||
} |
@ -1,538 +0,0 @@ |
||||
/* |
||||
* MinIO Cloud Storage, (C) 2016, 2017, 2018, 2019 MinIO, Inc. |
||||
* |
||||
* Licensed under the Apache License, Version 2.0 (the "License"); |
||||
* you may not use this file except in compliance with the License. |
||||
* You may obtain a copy of the License at |
||||
* |
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* |
||||
* Unless required by applicable law or agreed to in writing, software |
||||
* distributed under the License is distributed on an "AS IS" BASIS, |
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
* See the License for the specific language governing permissions and |
||||
* limitations under the License. |
||||
*/ |
||||
|
||||
package cmd |
||||
|
||||
import ( |
||||
"os" |
||||
"runtime" |
||||
"sync" |
||||
"testing" |
||||
|
||||
"github.com/minio/dsync" |
||||
) |
||||
|
||||
// Helper function to test equality of locks (without taking timing info into account)
|
||||
func testLockEquality(lriLeft, lriRight []lockRequesterInfo) bool { |
||||
if len(lriLeft) != len(lriRight) { |
||||
return false |
||||
} |
||||
|
||||
for i := 0; i < len(lriLeft); i++ { |
||||
if lriLeft[i].Writer != lriRight[i].Writer || |
||||
lriLeft[i].Node != lriRight[i].Node || |
||||
lriLeft[i].ServiceEndpoint != lriRight[i].ServiceEndpoint || |
||||
lriLeft[i].UID != lriRight[i].UID { |
||||
return false |
||||
} |
||||
} |
||||
return true |
||||
} |
||||
|
||||
// Helper function to create a lock server for testing
|
||||
func createLockTestServer(t *testing.T) (string, *lockRPCReceiver, string) { |
||||
obj, fsDir, err := prepareFS() |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
if err = newTestConfig(globalMinioDefaultRegion, obj); err != nil { |
||||
t.Fatalf("unable initialize config file, %s", err) |
||||
} |
||||
|
||||
locker := &lockRPCReceiver{ |
||||
ll: localLocker{ |
||||
mutex: sync.Mutex{}, |
||||
serviceEndpoint: "rpc-path", |
||||
lockMap: make(map[string][]lockRequesterInfo), |
||||
}, |
||||
} |
||||
creds := globalServerConfig.GetCredential() |
||||
token, err := authenticateNode(creds.AccessKey, creds.SecretKey) |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
return fsDir, locker, token |
||||
} |
||||
|
||||
// Test Lock functionality
|
||||
func TestLockRpcServerLock(t *testing.T) { |
||||
testPath, locker, token := createLockTestServer(t) |
||||
defer os.RemoveAll(testPath) |
||||
|
||||
la := LockArgs{ |
||||
AuthArgs: AuthArgs{ |
||||
Token: token, |
||||
RPCVersion: globalRPCAPIVersion, |
||||
RequestTime: UTCNow(), |
||||
}, |
||||
LockArgs: dsync.LockArgs{ |
||||
UID: "0123-4567", |
||||
Resource: "name", |
||||
ServerAddr: "node", |
||||
ServiceEndpoint: "rpc-path", |
||||
}} |
||||
|
||||
// Claim a lock
|
||||
var result bool |
||||
err := locker.Lock(&la, &result) |
||||
if err != nil { |
||||
t.Errorf("Expected %#v, got %#v", nil, err) |
||||
} else { |
||||
if !result { |
||||
t.Errorf("Expected %#v, got %#v", true, result) |
||||
} else { |
||||
gotLri := locker.ll.lockMap["name"] |
||||
expectedLri := []lockRequesterInfo{ |
||||
{ |
||||
Writer: true, |
||||
Node: "node", |
||||
ServiceEndpoint: "rpc-path", |
||||
UID: "0123-4567", |
||||
}, |
||||
} |
||||
if !testLockEquality(expectedLri, gotLri) { |
||||
t.Errorf("Expected %#v, got %#v", expectedLri, gotLri) |
||||
} |
||||
} |
||||
} |
||||
|
||||
// Try to claim same lock again (will fail)
|
||||
la2 := LockArgs{ |
||||
AuthArgs: AuthArgs{ |
||||
Token: token, |
||||
RPCVersion: globalRPCAPIVersion, |
||||
RequestTime: UTCNow(), |
||||
}, |
||||
LockArgs: dsync.LockArgs{ |
||||
UID: "89ab-cdef", |
||||
Resource: "name", |
||||
ServerAddr: "node", |
||||
ServiceEndpoint: "rpc-path", |
||||
}} |
||||
|
||||
err = locker.Lock(&la2, &result) |
||||
if err != nil { |
||||
t.Errorf("Expected %#v, got %#v", nil, err) |
||||
} else { |
||||
if result { |
||||
t.Errorf("Expected %#v, got %#v", false, result) |
||||
} |
||||
} |
||||
} |
||||
|
||||
// Test Unlock functionality
|
||||
func TestLockRpcServerUnlock(t *testing.T) { |
||||
testPath, locker, token := createLockTestServer(t) |
||||
defer os.RemoveAll(testPath) |
||||
|
||||
la := LockArgs{ |
||||
AuthArgs: AuthArgs{ |
||||
Token: token, |
||||
RPCVersion: globalRPCAPIVersion, |
||||
RequestTime: UTCNow(), |
||||
}, |
||||
LockArgs: dsync.LockArgs{ |
||||
UID: "0123-4567", |
||||
Resource: "name", |
||||
ServerAddr: "node", |
||||
ServiceEndpoint: "rpc-path", |
||||
}} |
||||
|
||||
// First test return of error when attempting to unlock a lock that does not exist
|
||||
var result bool |
||||
err := locker.Unlock(&la, &result) |
||||
if err == nil { |
||||
t.Errorf("Expected error, got %#v", nil) |
||||
} |
||||
|
||||
// Create lock (so that we can release)
|
||||
err = locker.Lock(&la, &result) |
||||
if err != nil { |
||||
t.Errorf("Expected %#v, got %#v", nil, err) |
||||
} else if !result { |
||||
t.Errorf("Expected %#v, got %#v", true, result) |
||||
} |
||||
|
||||
// Finally test successful release of lock
|
||||
err = locker.Unlock(&la, &result) |
||||
if err != nil { |
||||
t.Errorf("Expected %#v, got %#v", nil, err) |
||||
} else { |
||||
if !result { |
||||
t.Errorf("Expected %#v, got %#v", true, result) |
||||
} else { |
||||
gotLri := locker.ll.lockMap["name"] |
||||
expectedLri := []lockRequesterInfo(nil) |
||||
if !testLockEquality(expectedLri, gotLri) { |
||||
t.Errorf("Expected %#v, got %#v", expectedLri, gotLri) |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
// Test RLock functionality
|
||||
func TestLockRpcServerRLock(t *testing.T) { |
||||
testPath, locker, token := createLockTestServer(t) |
||||
defer os.RemoveAll(testPath) |
||||
|
||||
la := LockArgs{ |
||||
AuthArgs: AuthArgs{ |
||||
Token: token, |
||||
RPCVersion: globalRPCAPIVersion, |
||||
RequestTime: UTCNow(), |
||||
}, |
||||
LockArgs: dsync.LockArgs{ |
||||
UID: "0123-4567", |
||||
Resource: "name", |
||||
ServerAddr: "node", |
||||
ServiceEndpoint: "rpc-path", |
||||
}} |
||||
|
||||
// Claim a lock
|
||||
var result bool |
||||
err := locker.RLock(&la, &result) |
||||
if err != nil { |
||||
t.Errorf("Expected %#v, got %#v", nil, err) |
||||
} else { |
||||
if !result { |
||||
t.Errorf("Expected %#v, got %#v", true, result) |
||||
} else { |
||||
gotLri := locker.ll.lockMap["name"] |
||||
expectedLri := []lockRequesterInfo{ |
||||
{ |
||||
Writer: false, |
||||
Node: "node", |
||||
ServiceEndpoint: "rpc-path", |
||||
UID: "0123-4567", |
||||
}, |
||||
} |
||||
if !testLockEquality(expectedLri, gotLri) { |
||||
t.Errorf("Expected %#v, got %#v", expectedLri, gotLri) |
||||
} |
||||
} |
||||
} |
||||
|
||||
// Try to claim same again (will succeed)
|
||||
la2 := LockArgs{ |
||||
AuthArgs: AuthArgs{ |
||||
Token: token, |
||||
RPCVersion: globalRPCAPIVersion, |
||||
RequestTime: UTCNow(), |
||||
}, |
||||
LockArgs: dsync.LockArgs{ |
||||
UID: "89ab-cdef", |
||||
Resource: "name", |
||||
ServerAddr: "node", |
||||
ServiceEndpoint: "rpc-path", |
||||
}} |
||||
|
||||
err = locker.RLock(&la2, &result) |
||||
if err != nil { |
||||
t.Errorf("Expected %#v, got %#v", nil, err) |
||||
} else { |
||||
if !result { |
||||
t.Errorf("Expected %#v, got %#v", true, result) |
||||
} |
||||
} |
||||
} |
||||
|
||||
// Test RUnlock functionality
|
||||
func TestLockRpcServerRUnlock(t *testing.T) { |
||||
testPath, locker, token := createLockTestServer(t) |
||||
defer os.RemoveAll(testPath) |
||||
|
||||
la := LockArgs{ |
||||
AuthArgs: AuthArgs{ |
||||
Token: token, |
||||
RPCVersion: globalRPCAPIVersion, |
||||
RequestTime: UTCNow(), |
||||
}, |
||||
LockArgs: dsync.LockArgs{ |
||||
UID: "0123-4567", |
||||
Resource: "name", |
||||
ServerAddr: "node", |
||||
ServiceEndpoint: "rpc-path", |
||||
}} |
||||
|
||||
// First test return of error when attempting to unlock a read-lock that does not exist
|
||||
var result bool |
||||
err := locker.Unlock(&la, &result) |
||||
if err == nil { |
||||
t.Errorf("Expected error, got %#v", nil) |
||||
} |
||||
|
||||
// Create first lock ... (so that we can release)
|
||||
err = locker.RLock(&la, &result) |
||||
if err != nil { |
||||
t.Errorf("Expected %#v, got %#v", nil, err) |
||||
} else if !result { |
||||
t.Errorf("Expected %#v, got %#v", true, result) |
||||
} |
||||
|
||||
// Try to claim same again (will succeed)
|
||||
la2 := LockArgs{ |
||||
AuthArgs: AuthArgs{ |
||||
Token: token, |
||||
RPCVersion: globalRPCAPIVersion, |
||||
RequestTime: UTCNow(), |
||||
}, |
||||
LockArgs: dsync.LockArgs{ |
||||
UID: "89ab-cdef", |
||||
Resource: "name", |
||||
ServerAddr: "node", |
||||
ServiceEndpoint: "rpc-path", |
||||
}} |
||||
|
||||
// ... and create a second lock on same resource
|
||||
err = locker.RLock(&la2, &result) |
||||
if err != nil { |
||||
t.Errorf("Expected %#v, got %#v", nil, err) |
||||
} else if !result { |
||||
t.Errorf("Expected %#v, got %#v", true, result) |
||||
} |
||||
|
||||
// Test successful release of first read lock
|
||||
err = locker.RUnlock(&la, &result) |
||||
if err != nil { |
||||
t.Errorf("Expected %#v, got %#v", nil, err) |
||||
} else { |
||||
if !result { |
||||
t.Errorf("Expected %#v, got %#v", true, result) |
||||
} else { |
||||
gotLri := locker.ll.lockMap["name"] |
||||
expectedLri := []lockRequesterInfo{ |
||||
{ |
||||
Writer: false, |
||||
Node: "node", |
||||
ServiceEndpoint: "rpc-path", |
||||
UID: "89ab-cdef", |
||||
}, |
||||
} |
||||
if !testLockEquality(expectedLri, gotLri) { |
||||
t.Errorf("Expected %#v, got %#v", expectedLri, gotLri) |
||||
} |
||||
|
||||
} |
||||
} |
||||
|
||||
// Finally test successful release of second (and last) read lock
|
||||
err = locker.RUnlock(&la2, &result) |
||||
if err != nil { |
||||
t.Errorf("Expected %#v, got %#v", nil, err) |
||||
} else { |
||||
if !result { |
||||
t.Errorf("Expected %#v, got %#v", true, result) |
||||
} else { |
||||
gotLri := locker.ll.lockMap["name"] |
||||
expectedLri := []lockRequesterInfo(nil) |
||||
if !testLockEquality(expectedLri, gotLri) { |
||||
t.Errorf("Expected %#v, got %#v", expectedLri, gotLri) |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
// Test ForceUnlock functionality
|
||||
func TestLockRpcServerForceUnlock(t *testing.T) { |
||||
testPath, locker, token := createLockTestServer(t) |
||||
defer os.RemoveAll(testPath) |
||||
|
||||
laForce := LockArgs{ |
||||
AuthArgs: AuthArgs{ |
||||
Token: token, |
||||
RPCVersion: globalRPCAPIVersion, |
||||
RequestTime: UTCNow(), |
||||
}, |
||||
LockArgs: dsync.LockArgs{ |
||||
UID: "1234-5678", |
||||
Resource: "name", |
||||
ServerAddr: "node", |
||||
ServiceEndpoint: "rpc-path", |
||||
}} |
||||
|
||||
// First test that UID should be empty
|
||||
var result bool |
||||
err := locker.ForceUnlock(&laForce, &result) |
||||
if err == nil { |
||||
t.Errorf("Expected error, got %#v", nil) |
||||
} |
||||
|
||||
// Then test force unlock of a lock that does not exist (not returning an error)
|
||||
laForce.LockArgs.UID = "" |
||||
err = locker.ForceUnlock(&laForce, &result) |
||||
if err != nil { |
||||
t.Errorf("Expected no error, got %#v", err) |
||||
} |
||||
|
||||
la := LockArgs{ |
||||
AuthArgs: AuthArgs{ |
||||
Token: token, |
||||
RPCVersion: globalRPCAPIVersion, |
||||
RequestTime: UTCNow(), |
||||
}, |
||||
LockArgs: dsync.LockArgs{ |
||||
UID: "0123-4567", |
||||
Resource: "name", |
||||
ServerAddr: "node", |
||||
ServiceEndpoint: "rpc-path", |
||||
}} |
||||
|
||||
// Create lock ... (so that we can force unlock)
|
||||
err = locker.Lock(&la, &result) |
||||
if err != nil { |
||||
t.Errorf("Expected %#v, got %#v", nil, err) |
||||
} else if !result { |
||||
t.Errorf("Expected %#v, got %#v", true, result) |
||||
} |
||||
|
||||
// Forcefully unlock the lock (not returning an error)
|
||||
err = locker.ForceUnlock(&laForce, &result) |
||||
if err != nil { |
||||
t.Errorf("Expected no error, got %#v", err) |
||||
} |
||||
|
||||
// Try to get lock again (should be granted)
|
||||
err = locker.Lock(&la, &result) |
||||
if err != nil { |
||||
t.Errorf("Expected %#v, got %#v", nil, err) |
||||
} else if !result { |
||||
t.Errorf("Expected %#v, got %#v", true, result) |
||||
} |
||||
|
||||
// Finally forcefully unlock the lock once again
|
||||
err = locker.ForceUnlock(&laForce, &result) |
||||
if err != nil { |
||||
t.Errorf("Expected no error, got %#v", err) |
||||
} |
||||
} |
||||
|
||||
// Test Expired functionality
|
||||
func TestLockRpcServerExpired(t *testing.T) { |
||||
testPath, locker, token := createLockTestServer(t) |
||||
defer os.RemoveAll(testPath) |
||||
|
||||
la := LockArgs{ |
||||
AuthArgs: AuthArgs{ |
||||
Token: token, |
||||
RPCVersion: globalRPCAPIVersion, |
||||
RequestTime: UTCNow(), |
||||
}, |
||||
LockArgs: dsync.LockArgs{ |
||||
UID: "0123-4567", |
||||
Resource: "name", |
||||
ServerAddr: "node", |
||||
ServiceEndpoint: "rpc-path", |
||||
}} |
||||
|
||||
// Unknown lock at server will return expired = true
|
||||
var expired bool |
||||
err := locker.Expired(&la, &expired) |
||||
if err != nil { |
||||
t.Errorf("Expected no error, got %#v", err) |
||||
} else { |
||||
if !expired { |
||||
t.Errorf("Expected %#v, got %#v", true, expired) |
||||
} |
||||
} |
||||
|
||||
// Create lock (so that we can test that it is not expired)
|
||||
var result bool |
||||
err = locker.Lock(&la, &result) |
||||
if err != nil { |
||||
t.Errorf("Expected %#v, got %#v", nil, err) |
||||
} else if !result { |
||||
t.Errorf("Expected %#v, got %#v", true, result) |
||||
} |
||||
|
||||
err = locker.Expired(&la, &expired) |
||||
if err != nil { |
||||
t.Errorf("Expected no error, got %#v", err) |
||||
} else { |
||||
if expired { |
||||
t.Errorf("Expected %#v, got %#v", false, expired) |
||||
} |
||||
} |
||||
} |
||||
|
||||
// Test initialization of lock server.
|
||||
func TestLockServerInit(t *testing.T) { |
||||
if runtime.GOOS == globalWindowsOSName { |
||||
return |
||||
} |
||||
|
||||
obj, fsDir, err := prepareFS() |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
defer os.RemoveAll(fsDir) |
||||
if err = newTestConfig(globalMinioDefaultRegion, obj); err != nil { |
||||
t.Fatalf("unable initialize config file, %s", err) |
||||
} |
||||
|
||||
currentIsDistXL := globalIsDistXL |
||||
currentLockServer := globalLockServer |
||||
defer func() { |
||||
globalIsDistXL = currentIsDistXL |
||||
globalLockServer = currentLockServer |
||||
}() |
||||
|
||||
case1Endpoints := mustGetNewEndpointList( |
||||
"http://localhost:9000/mnt/disk1", |
||||
"http://1.1.1.2:9000/mnt/disk2", |
||||
"http://1.1.2.1:9000/mnt/disk3", |
||||
"http://1.1.2.2:9000/mnt/disk4", |
||||
) |
||||
for i := range case1Endpoints { |
||||
if case1Endpoints[i].Host == "localhost:9000" { |
||||
case1Endpoints[i].IsLocal = true |
||||
} |
||||
} |
||||
|
||||
case2Endpoints := mustGetNewEndpointList( |
||||
"http://localhost:9000/mnt/disk1", |
||||
"http://localhost:9000/mnt/disk2", |
||||
"http://1.1.2.1:9000/mnt/disk3", |
||||
"http://1.1.2.2:9000/mnt/disk4", |
||||
) |
||||
for i := range case2Endpoints { |
||||
if case2Endpoints[i].Host == "localhost:9000" { |
||||
case2Endpoints[i].IsLocal = true |
||||
} |
||||
} |
||||
|
||||
globalMinioHost = "" |
||||
testCases := []struct { |
||||
isDistXL bool |
||||
endpoints EndpointList |
||||
}{ |
||||
// Test - 1 one lock server initialized.
|
||||
{true, case1Endpoints}, |
||||
// Test - similar endpoint hosts should
|
||||
// converge to single lock server
|
||||
// initialized.
|
||||
{true, case2Endpoints}, |
||||
} |
||||
|
||||
// Validates lock server initialization.
|
||||
for i, testCase := range testCases { |
||||
globalIsDistXL = testCase.isDistXL |
||||
globalLockServer = nil |
||||
_, _ = newDsyncNodes(testCase.endpoints) |
||||
if globalLockServer == nil && testCase.isDistXL { |
||||
t.Errorf("Test %d: Expected initialized lock RPC receiver, but got uninitialized", i+1) |
||||
} |
||||
} |
||||
} |
@ -1,281 +0,0 @@ |
||||
/* |
||||
* MinIO Cloud Storage, (C) 2018 MinIO, Inc. |
||||
* |
||||
* Licensed under the Apache License, Version 2.0 (the "License"); |
||||
* you may not use this file except in compliance with the License. |
||||
* You may obtain a copy of the License at |
||||
* |
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* |
||||
* Unless required by applicable law or agreed to in writing, software |
||||
* distributed under the License is distributed on an "AS IS" BASIS, |
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
* See the License for the specific language governing permissions and |
||||
* limitations under the License. |
||||
*/ |
||||
|
||||
package cmd |
||||
|
||||
import ( |
||||
"crypto/tls" |
||||
"fmt" |
||||
"net" |
||||
"net/url" |
||||
"sync" |
||||
"time" |
||||
|
||||
xrpc "github.com/minio/minio/cmd/rpc" |
||||
xnet "github.com/minio/minio/pkg/net" |
||||
) |
||||
|
||||
// DefaultSkewTime - skew time is 15 minutes between minio peers.
|
||||
const DefaultSkewTime = 15 * time.Minute |
||||
|
||||
// defaultRPCTimeout - default RPC timeout is one minute.
|
||||
const defaultRPCTimeout = 5 * time.Minute |
||||
|
||||
// defaultRPCRetryTime - default RPC time to wait before retry after a network error
|
||||
const defaultRPCRetryTime = 1 * time.Minute |
||||
|
||||
var errRPCRetry = fmt.Errorf("rpc: retry error") |
||||
|
||||
func isNetError(err error) bool { |
||||
if err == nil { |
||||
return false |
||||
} |
||||
|
||||
if uerr, isURLError := err.(*url.Error); isURLError { |
||||
if uerr.Timeout() { |
||||
return true |
||||
} |
||||
|
||||
err = uerr.Err |
||||
} |
||||
|
||||
_, isNetOpError := err.(*net.OpError) |
||||
return isNetOpError |
||||
} |
||||
|
||||
// RPCVersion - RPC semantic version based on semver 2.0.0 https://semver.org/.
|
||||
type RPCVersion struct { |
||||
Major uint64 |
||||
Minor uint64 |
||||
Patch uint64 |
||||
} |
||||
|
||||
// Compare - compares given version with this version.
|
||||
func (v RPCVersion) Compare(o RPCVersion) int { |
||||
compare := func(v1, v2 uint64) int { |
||||
if v1 == v2 { |
||||
return 0 |
||||
} |
||||
|
||||
if v1 > v2 { |
||||
return 1 |
||||
} |
||||
|
||||
return -1 |
||||
} |
||||
|
||||
if r := compare(v.Major, o.Major); r != 0 { |
||||
return r |
||||
} |
||||
|
||||
if r := compare(v.Minor, o.Minor); r != 0 { |
||||
return r |
||||
} |
||||
|
||||
return compare(v.Patch, o.Patch) |
||||
} |
||||
|
||||
func (v RPCVersion) String() string { |
||||
return fmt.Sprintf("%v.%v.%v", v.Major, v.Minor, v.Patch) |
||||
} |
||||
|
||||
// AuthArgs - base argument for any RPC call for authentication.
|
||||
type AuthArgs struct { |
||||
Token string |
||||
RPCVersion RPCVersion |
||||
RequestTime time.Time |
||||
} |
||||
|
||||
// Authenticate - checks if given arguments are valid to allow RPC call.
|
||||
// This is xrpc.Authenticator and is called in RPC server.
|
||||
func (args AuthArgs) Authenticate() error { |
||||
// Check whether request time is within acceptable skew time.
|
||||
utcNow := time.Now().UTC() |
||||
if args.RequestTime.Sub(utcNow) > DefaultSkewTime || utcNow.Sub(args.RequestTime) > DefaultSkewTime { |
||||
return fmt.Errorf("client time %v is too apart with server time %v", args.RequestTime, utcNow) |
||||
} |
||||
|
||||
if globalRPCAPIVersion.Compare(args.RPCVersion) != 0 { |
||||
return fmt.Errorf("version mismatch. expected: %v, received: %v", globalRPCAPIVersion, args.RPCVersion) |
||||
} |
||||
|
||||
if !isAuthTokenValid(args.Token) { |
||||
return errAuthentication |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
// SetAuthArgs - sets given authentication arguments to this args. This is called in RPC client.
|
||||
func (args *AuthArgs) SetAuthArgs(authArgs AuthArgs) { |
||||
*args = authArgs |
||||
} |
||||
|
||||
// VoidReply - void (empty) RPC reply.
|
||||
type VoidReply struct{} |
||||
|
||||
// RPCClientArgs - RPC client arguments.
|
||||
type RPCClientArgs struct { |
||||
NewAuthTokenFunc func() string |
||||
RPCVersion RPCVersion |
||||
ServiceName string |
||||
ServiceURL *xnet.URL |
||||
TLSConfig *tls.Config |
||||
} |
||||
|
||||
// validate - checks whether given args are valid or not.
|
||||
func (args RPCClientArgs) validate() error { |
||||
if args.NewAuthTokenFunc == nil { |
||||
return fmt.Errorf("NewAuthTokenFunc must not be empty") |
||||
} |
||||
|
||||
if args.ServiceName == "" { |
||||
return fmt.Errorf("ServiceName must not be empty") |
||||
} |
||||
|
||||
if args.ServiceURL.Scheme != "http" && args.ServiceURL.Scheme != "https" { |
||||
return fmt.Errorf("unknown RPC URL %v", args.ServiceURL) |
||||
} |
||||
|
||||
if args.ServiceURL.User != nil || args.ServiceURL.ForceQuery || args.ServiceURL.RawQuery != "" || args.ServiceURL.Fragment != "" { |
||||
return fmt.Errorf("unknown RPC URL %v", args.ServiceURL) |
||||
} |
||||
|
||||
if args.ServiceURL.Scheme == "https" && args.TLSConfig == nil { |
||||
return fmt.Errorf("tls configuration must not be empty for https url %v", args.ServiceURL) |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
// RPCClient - base RPC client.
|
||||
type RPCClient struct { |
||||
sync.RWMutex |
||||
args RPCClientArgs |
||||
authToken string |
||||
rpcClient *xrpc.Client |
||||
retryTicker *time.Ticker |
||||
} |
||||
|
||||
func (client *RPCClient) setRetryTicker(ticker *time.Ticker) { |
||||
if ticker == nil { |
||||
client.RLock() |
||||
isNil := client.retryTicker == nil |
||||
client.RUnlock() |
||||
if isNil { |
||||
return |
||||
} |
||||
} |
||||
|
||||
client.Lock() |
||||
defer client.Unlock() |
||||
|
||||
if client.retryTicker != nil { |
||||
client.retryTicker.Stop() |
||||
} |
||||
|
||||
client.retryTicker = ticker |
||||
} |
||||
|
||||
// Call - calls servicemethod on remote server.
|
||||
func (client *RPCClient) Call(serviceMethod string, args interface { |
||||
SetAuthArgs(args AuthArgs) |
||||
}, reply interface{}) (err error) { |
||||
lockedCall := func() error { |
||||
client.RLock() |
||||
retryTicker := client.retryTicker |
||||
client.RUnlock() |
||||
if retryTicker != nil { |
||||
select { |
||||
case <-retryTicker.C: |
||||
default: |
||||
return errRPCRetry |
||||
} |
||||
} |
||||
|
||||
client.RLock() |
||||
authToken := client.authToken |
||||
client.RUnlock() |
||||
|
||||
// Make RPC call.
|
||||
args.SetAuthArgs(AuthArgs{authToken, client.args.RPCVersion, time.Now().UTC()}) |
||||
return client.rpcClient.Call(serviceMethod, args, reply) |
||||
} |
||||
|
||||
call := func() error { |
||||
err = lockedCall() |
||||
|
||||
if err == errRPCRetry { |
||||
return err |
||||
} |
||||
|
||||
if isNetError(err) { |
||||
client.setRetryTicker(time.NewTicker(defaultRPCRetryTime)) |
||||
} else { |
||||
client.setRetryTicker(nil) |
||||
} |
||||
|
||||
return err |
||||
} |
||||
|
||||
// If authentication error is received, retry the same call only once
|
||||
// with new authentication token.
|
||||
if err = call(); err == nil { |
||||
return nil |
||||
} |
||||
if err.Error() != errAuthentication.Error() { |
||||
return err |
||||
} |
||||
|
||||
client.Lock() |
||||
client.authToken = client.args.NewAuthTokenFunc() |
||||
client.Unlock() |
||||
return call() |
||||
} |
||||
|
||||
// Close - closes underneath RPC client.
|
||||
func (client *RPCClient) Close() error { |
||||
client.Lock() |
||||
defer client.Unlock() |
||||
|
||||
client.authToken = "" |
||||
return client.rpcClient.Close() |
||||
} |
||||
|
||||
// ServiceURL - returns service URL used for RPC call.
|
||||
func (client *RPCClient) ServiceURL() *xnet.URL { |
||||
// Take copy of ServiceURL
|
||||
u := *(client.args.ServiceURL) |
||||
return &u |
||||
} |
||||
|
||||
// NewRPCClient - returns new RPC client.
|
||||
func NewRPCClient(args RPCClientArgs) (*RPCClient, error) { |
||||
if err := args.validate(); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
rpcClient, err := xrpc.NewClient(args.ServiceURL, args.TLSConfig, defaultRPCTimeout) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return &RPCClient{ |
||||
args: args, |
||||
authToken: args.NewAuthTokenFunc(), |
||||
rpcClient: rpcClient, |
||||
}, nil |
||||
} |
@ -1,142 +0,0 @@ |
||||
/* |
||||
* MinIO Cloud Storage, (C) 2018 MinIO, Inc. |
||||
* |
||||
* Licensed under the Apache License, Version 2.0 (the "License"); |
||||
* you may not use this file except in compliance with the License. |
||||
* You may obtain a copy of the License at |
||||
* |
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* |
||||
* Unless required by applicable law or agreed to in writing, software |
||||
* distributed under the License is distributed on an "AS IS" BASIS, |
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
* See the License for the specific language governing permissions and |
||||
* limitations under the License. |
||||
*/ |
||||
|
||||
package rpc |
||||
|
||||
import ( |
||||
"bytes" |
||||
"context" |
||||
"crypto/tls" |
||||
"encoding/gob" |
||||
"errors" |
||||
"fmt" |
||||
"net" |
||||
"net/http" |
||||
"reflect" |
||||
"time" |
||||
|
||||
xhttp "github.com/minio/minio/cmd/http" |
||||
xnet "github.com/minio/minio/pkg/net" |
||||
"golang.org/x/net/http2" |
||||
) |
||||
|
||||
// DefaultRPCTimeout - default RPC timeout is one minute.
|
||||
const DefaultRPCTimeout = 1 * time.Minute |
||||
|
||||
// Client - http based RPC client.
|
||||
type Client struct { |
||||
httpClient *http.Client |
||||
httpIdleConnsCloser func() |
||||
serviceURL *xnet.URL |
||||
} |
||||
|
||||
// Call - calls service method on RPC server.
|
||||
func (client *Client) Call(serviceMethod string, args, reply interface{}) error { |
||||
replyKind := reflect.TypeOf(reply).Kind() |
||||
if replyKind != reflect.Ptr { |
||||
return fmt.Errorf("rpc reply must be a pointer type, but found %v", replyKind) |
||||
} |
||||
|
||||
argBuf := bytes.NewBuffer(make([]byte, 0, 1024)) |
||||
|
||||
if err := gobEncodeBuf(args, argBuf); err != nil { |
||||
return err |
||||
} |
||||
|
||||
callRequest := CallRequest{ |
||||
Method: serviceMethod, |
||||
ArgBytes: argBuf.Bytes(), |
||||
} |
||||
|
||||
reqBuf := bytes.NewBuffer(make([]byte, 0, 1024)) |
||||
if err := gob.NewEncoder(reqBuf).Encode(callRequest); err != nil { |
||||
return err |
||||
} |
||||
|
||||
response, err := client.httpClient.Post(client.serviceURL.String(), "", reqBuf) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
defer xhttp.DrainBody(response.Body) |
||||
|
||||
if response.StatusCode != http.StatusOK { |
||||
return fmt.Errorf("%v rpc call failed with error code %v", serviceMethod, response.StatusCode) |
||||
} |
||||
|
||||
var callResponse CallResponse |
||||
if err := gob.NewDecoder(response.Body).Decode(&callResponse); err != nil { |
||||
return err |
||||
} |
||||
|
||||
if callResponse.Error != "" { |
||||
return errors.New(callResponse.Error) |
||||
} |
||||
|
||||
return gobDecode(callResponse.ReplyBytes, reply) |
||||
} |
||||
|
||||
// Close closes all idle connections of the underlying http client
|
||||
func (client *Client) Close() error { |
||||
if client.httpIdleConnsCloser != nil { |
||||
client.httpIdleConnsCloser() |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func newCustomDialContext(timeout time.Duration) func(ctx context.Context, network, addr string) (net.Conn, error) { |
||||
return func(ctx context.Context, network, addr string) (net.Conn, error) { |
||||
dialer := &net.Dialer{ |
||||
Timeout: timeout, |
||||
KeepAlive: timeout, |
||||
DualStack: true, |
||||
} |
||||
|
||||
conn, err := dialer.DialContext(ctx, network, addr) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return xhttp.NewTimeoutConn(conn, timeout, timeout), nil |
||||
} |
||||
} |
||||
|
||||
// NewClient - returns new RPC client.
|
||||
func NewClient(serviceURL *xnet.URL, tlsConfig *tls.Config, timeout time.Duration) (*Client, error) { |
||||
// Transport is exactly same as Go default in https://golang.org/pkg/net/http/#RoundTripper
|
||||
// except custom DialContext and TLSClientConfig.
|
||||
tr := &http.Transport{ |
||||
Proxy: http.ProxyFromEnvironment, |
||||
DialContext: newCustomDialContext(timeout), |
||||
MaxIdleConnsPerHost: 4096, |
||||
MaxIdleConns: 4096, |
||||
IdleConnTimeout: 120 * time.Second, |
||||
TLSHandshakeTimeout: 30 * time.Second, |
||||
ExpectContinueTimeout: 10 * time.Second, |
||||
TLSClientConfig: tlsConfig, |
||||
DisableCompression: true, |
||||
} |
||||
if tlsConfig != nil { |
||||
// If TLS is enabled configure http2
|
||||
if err := http2.ConfigureTransport(tr); err != nil { |
||||
return nil, err |
||||
} |
||||
} |
||||
return &Client{ |
||||
httpClient: &http.Client{Transport: tr}, |
||||
httpIdleConnsCloser: tr.CloseIdleConnections, |
||||
serviceURL: serviceURL, |
||||
}, nil |
||||
} |
@ -1,75 +0,0 @@ |
||||
/* |
||||
* MinIO Cloud Storage, (C) 2018 MinIO, Inc. |
||||
* |
||||
* Licensed under the Apache License, Version 2.0 (the "License"); |
||||
* you may not use this file except in compliance with the License. |
||||
* You may obtain a copy of the License at |
||||
* |
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* |
||||
* Unless required by applicable law or agreed to in writing, software |
||||
* distributed under the License is distributed on an "AS IS" BASIS, |
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
* See the License for the specific language governing permissions and |
||||
* limitations under the License. |
||||
*/ |
||||
|
||||
package rpc |
||||
|
||||
import ( |
||||
"net/http" |
||||
"net/http/httptest" |
||||
"testing" |
||||
|
||||
xnet "github.com/minio/minio/pkg/net" |
||||
) |
||||
|
||||
func TestClientCall(t *testing.T) { |
||||
rpcServer := NewServer() |
||||
if err := rpcServer.RegisterName("Arith", &Arith{}); err != nil { |
||||
t.Fatalf("unexpected error %v", err) |
||||
} |
||||
|
||||
httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
||||
rpcServer.ServeHTTP(w, r) |
||||
})) |
||||
defer httpServer.Close() |
||||
|
||||
url, err := xnet.ParseURL(httpServer.URL) |
||||
if err != nil { |
||||
t.Fatalf("unexpected error %v", err) |
||||
} |
||||
rpcClient, err := NewClient(url, nil, DefaultRPCTimeout) |
||||
if err != nil { |
||||
t.Fatalf("NewClient initialization error %v", err) |
||||
} |
||||
|
||||
var reply int |
||||
var boolReply bool |
||||
var intArg int |
||||
|
||||
testCases := []struct { |
||||
serviceMethod string |
||||
args interface{} |
||||
reply interface{} |
||||
expectErr bool |
||||
}{ |
||||
{"Arith.Multiply", Args{7, 8}, &reply, false}, |
||||
{"Arith.Multiply", &Args{7, 8}, &reply, false}, |
||||
// rpc reply must be a pointer type but found int error.
|
||||
{"Arith.Multiply", &Args{7, 8}, reply, true}, |
||||
// gob: type mismatch in decoder: want struct type rpc.Args; got non-struct error.
|
||||
{"Arith.Multiply", intArg, &reply, true}, |
||||
// gob: decoding into local type *bool, received remote type int error.
|
||||
{"Arith.Multiply", &Args{7, 8}, &boolReply, true}, |
||||
} |
||||
|
||||
for i, testCase := range testCases { |
||||
err := rpcClient.Call(testCase.serviceMethod, testCase.args, testCase.reply) |
||||
expectErr := (err != nil) |
||||
|
||||
if expectErr != testCase.expectErr { |
||||
t.Fatalf("case %v: expected: %v, got: %v", i+1, testCase.expectErr, expectErr) |
||||
} |
||||
} |
||||
} |
@ -1,48 +0,0 @@ |
||||
/* |
||||
* MinIO Cloud Storage, (C) 2018 MinIO, Inc. |
||||
* |
||||
* Licensed under the Apache License, Version 2.0 (the "License"); |
||||
* you may not use this file except in compliance with the License. |
||||
* You may obtain a copy of the License at |
||||
* |
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* |
||||
* Unless required by applicable law or agreed to in writing, software |
||||
* distributed under the License is distributed on an "AS IS" BASIS, |
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
* See the License for the specific language governing permissions and |
||||
* limitations under the License. |
||||
*/ |
||||
|
||||
package rpc |
||||
|
||||
import ( |
||||
"bytes" |
||||
"sync" |
||||
) |
||||
|
||||
// A Pool is a type-safe wrapper around a sync.Pool.
|
||||
type Pool struct { |
||||
p *sync.Pool |
||||
} |
||||
|
||||
// NewPool constructs a new Pool.
|
||||
func NewPool() Pool { |
||||
return Pool{p: &sync.Pool{ |
||||
New: func() interface{} { |
||||
return &bytes.Buffer{} |
||||
}, |
||||
}} |
||||
} |
||||
|
||||
// Get retrieves a bytes.Buffer from the pool, creating one if necessary.
|
||||
func (p Pool) Get() *bytes.Buffer { |
||||
buf := p.p.Get().(*bytes.Buffer) |
||||
return buf |
||||
} |
||||
|
||||
// Put - returns a bytes.Buffer to the pool.
|
||||
func (p Pool) Put(buf *bytes.Buffer) { |
||||
buf.Reset() |
||||
p.p.Put(buf) |
||||
} |
@ -1,256 +0,0 @@ |
||||
/* |
||||
* MinIO Cloud Storage, (C) 2018 MinIO, Inc. |
||||
* |
||||
* Licensed under the Apache License, Version 2.0 (the "License"); |
||||
* you may not use this file except in compliance with the License. |
||||
* You may obtain a copy of the License at |
||||
* |
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* |
||||
* Unless required by applicable law or agreed to in writing, software |
||||
* distributed under the License is distributed on an "AS IS" BASIS, |
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
* See the License for the specific language governing permissions and |
||||
* limitations under the License. |
||||
*/ |
||||
|
||||
package rpc |
||||
|
||||
import ( |
||||
"bytes" |
||||
"encoding/gob" |
||||
"errors" |
||||
"fmt" |
||||
"net/http" |
||||
"reflect" |
||||
"strings" |
||||
"unicode" |
||||
"unicode/utf8" |
||||
) |
||||
|
||||
// Authenticator - validator of first argument of any RPC call.
|
||||
type Authenticator interface { |
||||
// Method to validate first argument of any RPC call.
|
||||
Authenticate() error |
||||
} |
||||
|
||||
// reflect.Type of error interface.
|
||||
var errorType = reflect.TypeOf((*error)(nil)).Elem() |
||||
|
||||
// reflect.Type of Authenticator interface.
|
||||
var authenticatorType = reflect.TypeOf((*Authenticator)(nil)).Elem() |
||||
|
||||
func gobEncodeBuf(e interface{}, buf *bytes.Buffer) error { |
||||
return gob.NewEncoder(buf).Encode(e) |
||||
} |
||||
|
||||
func gobDecode(data []byte, e interface{}) error { |
||||
return gob.NewDecoder(bytes.NewReader(data)).Decode(e) |
||||
} |
||||
|
||||
// Returns whether given type is exported or builin type or not.
|
||||
func isExportedOrBuiltinType(t reflect.Type) bool { |
||||
for t.Kind() == reflect.Ptr { |
||||
t = t.Elem() |
||||
} |
||||
|
||||
rune, _ := utf8.DecodeRuneInString(t.Name()) |
||||
return unicode.IsUpper(rune) || t.PkgPath() == "" |
||||
} |
||||
|
||||
// Makes method name map from given type.
|
||||
func getMethodMap(receiverType reflect.Type) map[string]reflect.Method { |
||||
methodMap := make(map[string]reflect.Method) |
||||
for i := 0; i < receiverType.NumMethod(); i++ { |
||||
// Method.PkgPath is empty for this package.
|
||||
method := receiverType.Method(i) |
||||
|
||||
// Methods must have three arguments (receiver, args, reply)
|
||||
if method.Type.NumIn() != 3 { |
||||
continue |
||||
} |
||||
|
||||
// First argument must be exported.
|
||||
if !isExportedOrBuiltinType(method.Type.In(1)) { |
||||
continue |
||||
} |
||||
|
||||
// First argument must be Authenticator.
|
||||
if !method.Type.In(1).Implements(authenticatorType) { |
||||
continue |
||||
} |
||||
|
||||
// Second argument must be exported or builtin type.
|
||||
if !isExportedOrBuiltinType(method.Type.In(2)) { |
||||
continue |
||||
} |
||||
|
||||
// Second argument must be a pointer.
|
||||
if method.Type.In(2).Kind() != reflect.Ptr { |
||||
continue |
||||
} |
||||
|
||||
// Method must return one value.
|
||||
if method.Type.NumOut() != 1 { |
||||
continue |
||||
} |
||||
|
||||
// The return type of the method must be error.
|
||||
if method.Type.Out(0) != errorType { |
||||
continue |
||||
} |
||||
|
||||
methodMap[method.Name] = method |
||||
} |
||||
|
||||
return methodMap |
||||
} |
||||
|
||||
// Server - HTTP based RPC server.
|
||||
type Server struct { |
||||
serviceName string |
||||
receiverValue reflect.Value |
||||
methodMap map[string]reflect.Method |
||||
} |
||||
|
||||
// RegisterName - registers receiver with given name to handle RPC requests.
|
||||
func (server *Server) RegisterName(name string, receiver interface{}) error { |
||||
server.serviceName = name |
||||
|
||||
server.receiverValue = reflect.ValueOf(receiver) |
||||
if !reflect.Indirect(server.receiverValue).IsValid() { |
||||
return fmt.Errorf("nil receiver") |
||||
} |
||||
|
||||
receiverName := reflect.Indirect(server.receiverValue).Type().Name() |
||||
receiverType := reflect.TypeOf(receiver) |
||||
server.methodMap = getMethodMap(receiverType) |
||||
if len(server.methodMap) == 0 { |
||||
str := "rpc.Register: type " + receiverName + " has no exported methods of suitable type" |
||||
|
||||
// To help the user, see if a pointer receiver would work.
|
||||
if len(getMethodMap(reflect.PtrTo(receiverType))) != 0 { |
||||
str += " (hint: pass a pointer to value of that type)" |
||||
} |
||||
|
||||
return errors.New(str) |
||||
} |
||||
|
||||
return nil |
||||
} |
||||
|
||||
// call - call service method in receiver.
|
||||
func (server *Server) call(serviceMethod string, argBytes []byte, replyBytes *bytes.Buffer) (err error) { |
||||
tokens := strings.SplitN(serviceMethod, ".", 2) |
||||
if len(tokens) != 2 { |
||||
return fmt.Errorf("invalid service/method request ill-formed %v", serviceMethod) |
||||
} |
||||
|
||||
serviceName := tokens[0] |
||||
if serviceName != server.serviceName { |
||||
return fmt.Errorf("can't find service %v", serviceName) |
||||
} |
||||
|
||||
methodName := tokens[1] |
||||
method, found := server.methodMap[methodName] |
||||
if !found { |
||||
return fmt.Errorf("can't find method %v", methodName) |
||||
} |
||||
|
||||
var argv reflect.Value |
||||
|
||||
// Decode the argument value.
|
||||
argIsValue := false // if true, need to indirect before calling.
|
||||
if method.Type.In(1).Kind() == reflect.Ptr { |
||||
argv = reflect.New(method.Type.In(1).Elem()) |
||||
} else { |
||||
argv = reflect.New(method.Type.In(1)) |
||||
argIsValue = true |
||||
} |
||||
|
||||
if err = gobDecode(argBytes, argv.Interface()); err != nil { |
||||
return err |
||||
} |
||||
|
||||
if argIsValue { |
||||
argv = argv.Elem() |
||||
} |
||||
|
||||
// call Authenticate() method.
|
||||
authMethod, ok := method.Type.In(1).MethodByName("Authenticate") |
||||
if !ok { |
||||
panic("Authenticate() method not found. This should not happen.") |
||||
} |
||||
returnValues := authMethod.Func.Call([]reflect.Value{argv}) |
||||
errInter := returnValues[0].Interface() |
||||
if errInter != nil { |
||||
err = errInter.(error) |
||||
} |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
replyv := reflect.New(method.Type.In(2).Elem()) |
||||
|
||||
switch method.Type.In(2).Elem().Kind() { |
||||
case reflect.Map: |
||||
replyv.Elem().Set(reflect.MakeMap(method.Type.In(2).Elem())) |
||||
case reflect.Slice: |
||||
replyv.Elem().Set(reflect.MakeSlice(method.Type.In(2).Elem(), 0, 0)) |
||||
} |
||||
|
||||
returnValues = method.Func.Call([]reflect.Value{server.receiverValue, argv, replyv}) |
||||
errInter = returnValues[0].Interface() |
||||
if errInter != nil { |
||||
err = errInter.(error) |
||||
} |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
return gobEncodeBuf(replyv.Interface(), replyBytes) |
||||
} |
||||
|
||||
// CallRequest - RPC call request parameters.
|
||||
type CallRequest struct { |
||||
Method string |
||||
ArgBytes []byte |
||||
} |
||||
|
||||
// CallResponse - RPC call response parameters.
|
||||
type CallResponse struct { |
||||
Error string |
||||
ReplyBytes []byte |
||||
} |
||||
|
||||
// ServeHTTP - handles RPC on HTTP request.
|
||||
func (server *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) { |
||||
if req.Method != http.MethodPost { |
||||
w.WriteHeader(http.StatusMethodNotAllowed) |
||||
return |
||||
} |
||||
|
||||
var callRequest CallRequest |
||||
if err := gob.NewDecoder(req.Body).Decode(&callRequest); err != nil { |
||||
w.WriteHeader(http.StatusBadRequest) |
||||
return |
||||
} |
||||
|
||||
callResponse := CallResponse{} |
||||
|
||||
buf := bytes.NewBuffer(make([]byte, 0, 1024)) |
||||
|
||||
if err := server.call(callRequest.Method, callRequest.ArgBytes, buf); err != nil { |
||||
callResponse.Error = err.Error() |
||||
} |
||||
callResponse.ReplyBytes = buf.Bytes() |
||||
|
||||
gob.NewEncoder(w).Encode(callResponse) |
||||
|
||||
w.(http.Flusher).Flush() |
||||
} |
||||
|
||||
// NewServer - returns new RPC server.
|
||||
func NewServer() *Server { |
||||
return &Server{} |
||||
} |
@ -1,349 +0,0 @@ |
||||
/* |
||||
* MinIO Cloud Storage, (C) 2018 MinIO, Inc. |
||||
* |
||||
* Licensed under the Apache License, Version 2.0 (the "License"); |
||||
* you may not use this file except in compliance with the License. |
||||
* You may obtain a copy of the License at |
||||
* |
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* |
||||
* Unless required by applicable law or agreed to in writing, software |
||||
* distributed under the License is distributed on an "AS IS" BASIS, |
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
* See the License for the specific language governing permissions and |
||||
* limitations under the License. |
||||
*/ |
||||
|
||||
package rpc |
||||
|
||||
import ( |
||||
"bytes" |
||||
"encoding/gob" |
||||
"errors" |
||||
"net/http" |
||||
"net/http/httptest" |
||||
"reflect" |
||||
"testing" |
||||
) |
||||
|
||||
func gobEncode(e interface{}) ([]byte, error) { |
||||
var buf bytes.Buffer |
||||
err := gob.NewEncoder(&buf).Encode(e) |
||||
return buf.Bytes(), err |
||||
} |
||||
|
||||
type Args struct { |
||||
A, B int |
||||
} |
||||
|
||||
func (a *Args) Authenticate() (err error) { |
||||
if a.A == 0 && a.B == 0 { |
||||
err = errors.New("authenticated failed") |
||||
} |
||||
|
||||
return |
||||
} |
||||
|
||||
type Quotient struct { |
||||
Quo, Rem int |
||||
} |
||||
|
||||
type Arith struct{} |
||||
|
||||
func (t *Arith) Multiply(args *Args, reply *int) error { |
||||
*reply = args.A * args.B |
||||
return nil |
||||
} |
||||
|
||||
func (t *Arith) Divide(args *Args, quo *Quotient) error { |
||||
if args.B == 0 { |
||||
return errors.New("divide by zero") |
||||
} |
||||
quo.Quo = args.A / args.B |
||||
quo.Rem = args.A % args.B |
||||
return nil |
||||
} |
||||
|
||||
type mytype int |
||||
|
||||
type Auth struct{} |
||||
|
||||
func (a Auth) Authenticate() error { |
||||
return nil |
||||
} |
||||
|
||||
// exported method.
|
||||
func (t mytype) Foo(a *Auth, b *int) error { |
||||
return nil |
||||
} |
||||
|
||||
// incompatible method because of first argument is not Authenticator.
|
||||
func (t *mytype) Bar(a, b *int) error { |
||||
return nil |
||||
} |
||||
|
||||
// incompatible method because of error is not returned.
|
||||
func (t mytype) IncompatFoo(a, b *int) { |
||||
} |
||||
|
||||
// incompatible method because of second argument is not a pointer.
|
||||
func (t *mytype) IncompatBar(a *int, b int) error { |
||||
return nil |
||||
} |
||||
|
||||
func TestIsExportedOrBuiltinType(t *testing.T) { |
||||
var i int |
||||
case1Type := reflect.TypeOf(i) |
||||
|
||||
var iptr *int |
||||
case2Type := reflect.TypeOf(iptr) |
||||
|
||||
var a Arith |
||||
case3Type := reflect.TypeOf(a) |
||||
|
||||
var aptr *Arith |
||||
case4Type := reflect.TypeOf(aptr) |
||||
|
||||
var m mytype |
||||
case5Type := reflect.TypeOf(m) |
||||
|
||||
var mptr *mytype |
||||
case6Type := reflect.TypeOf(mptr) |
||||
|
||||
testCases := []struct { |
||||
t reflect.Type |
||||
expectedResult bool |
||||
}{ |
||||
{case1Type, true}, |
||||
{case2Type, true}, |
||||
{case3Type, true}, |
||||
{case4Type, true}, |
||||
// Type.Name() starts with lower case and Type.PkgPath() is not empty.
|
||||
{case5Type, false}, |
||||
// Type.Name() starts with lower case and Type.PkgPath() is not empty.
|
||||
{case6Type, false}, |
||||
} |
||||
|
||||
for i, testCase := range testCases { |
||||
result := isExportedOrBuiltinType(testCase.t) |
||||
|
||||
if result != testCase.expectedResult { |
||||
t.Fatalf("case %v: expected: %v, got: %v\n", i+1, testCase.expectedResult, result) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestGetMethodMap(t *testing.T) { |
||||
var a Arith |
||||
case1Type := reflect.TypeOf(a) |
||||
|
||||
var aptr *Arith |
||||
case2Type := reflect.TypeOf(aptr) |
||||
|
||||
var m mytype |
||||
case3Type := reflect.TypeOf(m) |
||||
|
||||
var mptr *mytype |
||||
case4Type := reflect.TypeOf(mptr) |
||||
|
||||
testCases := []struct { |
||||
t reflect.Type |
||||
expectedResult int |
||||
}{ |
||||
// No methods exported.
|
||||
{case1Type, 0}, |
||||
// Multiply and Divide methods are exported.
|
||||
{case2Type, 2}, |
||||
// Foo method is exported.
|
||||
{case3Type, 1}, |
||||
// Foo method is exported.
|
||||
{case4Type, 1}, |
||||
} |
||||
|
||||
for i, testCase := range testCases { |
||||
m := getMethodMap(testCase.t) |
||||
result := len(m) |
||||
|
||||
if result != testCase.expectedResult { |
||||
t.Fatalf("case %v: expected: %v, got: %v\n", i+1, testCase.expectedResult, result) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestServerRegisterName(t *testing.T) { |
||||
case1Receiver := &Arith{} |
||||
var case2Receiver mytype |
||||
var case3Receiver *Arith |
||||
i := 0 |
||||
var case4Receiver = &i |
||||
var case5Receiver Arith |
||||
|
||||
testCases := []struct { |
||||
name string |
||||
receiver interface{} |
||||
expectErr bool |
||||
}{ |
||||
{"Arith", case1Receiver, false}, |
||||
{"arith", case1Receiver, false}, |
||||
{"Arith", case2Receiver, false}, |
||||
// nil receiver error.
|
||||
{"Arith", nil, true}, |
||||
// nil receiver error.
|
||||
{"Arith", case3Receiver, true}, |
||||
// rpc.Register: type Arith has no exported methods of suitable type error.
|
||||
{"Arith", case4Receiver, true}, |
||||
// rpc.Register: type Arith has no exported methods of suitable type (hint: pass a pointer to value of that type) error.
|
||||
{"Arith", case5Receiver, true}, |
||||
} |
||||
|
||||
for i, testCase := range testCases { |
||||
err := NewServer().RegisterName(testCase.name, testCase.receiver) |
||||
expectErr := (err != nil) |
||||
|
||||
if expectErr != testCase.expectErr { |
||||
t.Fatalf("case %v: expected: %v, got: %v\n", i+1, testCase.expectErr, expectErr) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestServerCall(t *testing.T) { |
||||
server1 := NewServer() |
||||
if err := server1.RegisterName("Arith", &Arith{}); err != nil { |
||||
t.Fatalf("unexpected error %v", err) |
||||
} |
||||
server2 := NewServer() |
||||
if err := server2.RegisterName("arith", &Arith{}); err != nil { |
||||
t.Fatalf("unexpected error %v", err) |
||||
} |
||||
|
||||
case1ArgBytes, err := gobEncode(&Args{7, 8}) |
||||
if err != nil { |
||||
t.Fatalf("unexpected error %v", err) |
||||
} |
||||
reply := 7 * 8 |
||||
case1ExpectedResult, err := gobEncode(&reply) |
||||
if err != nil { |
||||
t.Fatalf("unexpected error %v", err) |
||||
} |
||||
|
||||
case2ArgBytes, err := gobEncode(&Args{}) |
||||
if err != nil { |
||||
t.Fatalf("unexpected error %v", err) |
||||
} |
||||
|
||||
testCases := []struct { |
||||
server *Server |
||||
serviceMethod string |
||||
argBytes []byte |
||||
expectedResult []byte |
||||
expectErr bool |
||||
}{ |
||||
{server1, "Arith.Multiply", case1ArgBytes, case1ExpectedResult, false}, |
||||
{server2, "arith.Multiply", case1ArgBytes, case1ExpectedResult, false}, |
||||
// invalid service/method request ill-formed error.
|
||||
{server1, "Multiply", nil, nil, true}, |
||||
// can't find service error.
|
||||
{server1, "arith.Multiply", nil, nil, true}, |
||||
// can't find method error.
|
||||
{server1, "Arith.Add", nil, nil, true}, |
||||
// gob decode error.
|
||||
{server1, "Arith.Multiply", []byte{10}, nil, true}, |
||||
// authentication error.
|
||||
{server1, "Arith.Multiply", case2ArgBytes, nil, true}, |
||||
} |
||||
|
||||
for i, testCase := range testCases { |
||||
buf := bytes.NewBuffer([]byte{}) |
||||
|
||||
err := testCase.server.call(testCase.serviceMethod, testCase.argBytes, buf) |
||||
expectErr := (err != nil) |
||||
|
||||
if expectErr != testCase.expectErr { |
||||
t.Fatalf("case %v: error: expected: %v, got: %v\n", i+1, testCase.expectErr, expectErr) |
||||
} |
||||
|
||||
if !testCase.expectErr { |
||||
if !reflect.DeepEqual(buf.Bytes(), testCase.expectedResult) { |
||||
t.Fatalf("case %v: result: expected: %v, got: %v\n", i+1, testCase.expectedResult, buf.Bytes()) |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestServerServeHTTP(t *testing.T) { |
||||
server1 := NewServer() |
||||
if err := server1.RegisterName("Arith", &Arith{}); err != nil { |
||||
t.Fatalf("unexpected error %v", err) |
||||
} |
||||
argBytes, err := gobEncode(&Args{7, 8}) |
||||
if err != nil { |
||||
t.Fatalf("unexpected error %v", err) |
||||
} |
||||
requestBodyData, err := gobEncode(CallRequest{Method: "Arith.Multiply", ArgBytes: argBytes}) |
||||
if err != nil { |
||||
t.Fatalf("unexpected error %v", err) |
||||
} |
||||
case1Request, err := http.NewRequest("POST", "http://localhost:12345/", bytes.NewReader(requestBodyData)) |
||||
if err != nil { |
||||
t.Fatalf("unexpected error %v", err) |
||||
} |
||||
reply := 7 * 8 |
||||
replyBytes, err := gobEncode(&reply) |
||||
if err != nil { |
||||
t.Fatalf("unexpected error %v", err) |
||||
} |
||||
case1Result, err := gobEncode(CallResponse{ReplyBytes: replyBytes}) |
||||
if err != nil { |
||||
t.Fatalf("unexpected error %v", err) |
||||
} |
||||
|
||||
case2Request, err := http.NewRequest("GET", "http://localhost:12345/", bytes.NewReader([]byte{})) |
||||
if err != nil { |
||||
t.Fatalf("unexpected error %v", err) |
||||
} |
||||
|
||||
case3Request, err := http.NewRequest("POST", "http://localhost:12345/", bytes.NewReader([]byte{10, 20})) |
||||
if err != nil { |
||||
t.Fatalf("unexpected error %v", err) |
||||
} |
||||
|
||||
requestBodyData, err = gobEncode(CallRequest{Method: "Arith.Add", ArgBytes: argBytes}) |
||||
if err != nil { |
||||
t.Fatalf("unexpected error %v", err) |
||||
} |
||||
case4Request, err := http.NewRequest("POST", "http://localhost:12345/", bytes.NewReader(requestBodyData)) |
||||
if err != nil { |
||||
t.Fatalf("unexpected error %v", err) |
||||
} |
||||
case4Result, err := gobEncode(CallResponse{Error: "can't find method Add"}) |
||||
if err != nil { |
||||
t.Fatalf("unexpected error %v", err) |
||||
} |
||||
|
||||
testCases := []struct { |
||||
server *Server |
||||
httpRequest *http.Request |
||||
expectedCode int |
||||
expectedResult []byte |
||||
}{ |
||||
{server1, case1Request, http.StatusOK, case1Result}, |
||||
{server1, case2Request, http.StatusMethodNotAllowed, nil}, |
||||
{server1, case3Request, http.StatusBadRequest, nil}, |
||||
{server1, case4Request, http.StatusOK, case4Result}, |
||||
} |
||||
|
||||
for i, testCase := range testCases { |
||||
writer := httptest.NewRecorder() |
||||
testCase.server.ServeHTTP(writer, testCase.httpRequest) |
||||
if writer.Code != testCase.expectedCode { |
||||
t.Fatalf("case %v: code: expected: %v, got: %v\n", i+1, testCase.expectedCode, writer.Code) |
||||
} |
||||
|
||||
if testCase.expectedCode == http.StatusOK { |
||||
result := writer.Body.Bytes() |
||||
if !reflect.DeepEqual(result, testCase.expectedResult) { |
||||
t.Fatalf("case %v: result: expected: %v, got: %v\n", i+1, testCase.expectedResult, result) |
||||
} |
||||
} |
||||
} |
||||
} |
@ -1,402 +0,0 @@ |
||||
/* |
||||
* MinIO Cloud Storage, (C) 2018 MinIO, Inc. |
||||
* |
||||
* Licensed under the Apache License, Version 2.0 (the "License"); |
||||
* you may not use this file except in compliance with the License. |
||||
* You may obtain a copy of the License at |
||||
* |
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* |
||||
* Unless required by applicable law or agreed to in writing, software |
||||
* distributed under the License is distributed on an "AS IS" BASIS, |
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
* See the License for the specific language governing permissions and |
||||
* limitations under the License. |
||||
*/ |
||||
|
||||
package cmd |
||||
|
||||
import ( |
||||
"crypto/tls" |
||||
"errors" |
||||
"net/http" |
||||
"net/http/httptest" |
||||
"reflect" |
||||
"testing" |
||||
"time" |
||||
|
||||
xrpc "github.com/minio/minio/cmd/rpc" |
||||
xnet "github.com/minio/minio/pkg/net" |
||||
) |
||||
|
||||
func TestAuthArgsAuthenticate(t *testing.T) { |
||||
tmpGlobalServerConfig := globalServerConfig |
||||
defer func() { |
||||
globalServerConfig = tmpGlobalServerConfig |
||||
}() |
||||
globalServerConfig = newServerConfig() |
||||
|
||||
case1Args := AuthArgs{ |
||||
Token: newAuthToken(), |
||||
RPCVersion: globalRPCAPIVersion, |
||||
RequestTime: UTCNow(), |
||||
} |
||||
|
||||
case2Args := AuthArgs{ |
||||
Token: newAuthToken(), |
||||
RPCVersion: globalRPCAPIVersion, |
||||
RequestTime: UTCNow().Add(15 * time.Minute), |
||||
} |
||||
|
||||
case3Args := AuthArgs{ |
||||
Token: newAuthToken(), |
||||
RPCVersion: globalRPCAPIVersion, |
||||
RequestTime: UTCNow().Add(-16 * time.Minute), |
||||
} |
||||
|
||||
case4Args := AuthArgs{ |
||||
Token: newAuthToken(), |
||||
RPCVersion: RPCVersion{99, 99, 99}, |
||||
RequestTime: UTCNow(), |
||||
} |
||||
|
||||
case5Args := AuthArgs{ |
||||
Token: "invalid-token", |
||||
RPCVersion: globalRPCAPIVersion, |
||||
RequestTime: UTCNow(), |
||||
} |
||||
|
||||
testCases := []struct { |
||||
args AuthArgs |
||||
expectErr bool |
||||
}{ |
||||
{case1Args, false}, |
||||
{case2Args, false}, |
||||
{case3Args, true}, |
||||
{case4Args, true}, |
||||
{case5Args, true}, |
||||
} |
||||
|
||||
for i, testCase := range testCases { |
||||
err := testCase.args.Authenticate() |
||||
expectErr := (err != nil) |
||||
|
||||
if expectErr != testCase.expectErr { |
||||
t.Fatalf("case %v: expected: %v, got: %v", i+1, testCase.expectErr, expectErr) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestAuthArgsSetAuthArgs(t *testing.T) { |
||||
tmpGlobalServerConfig := globalServerConfig |
||||
defer func() { |
||||
globalServerConfig = tmpGlobalServerConfig |
||||
}() |
||||
globalServerConfig = newServerConfig() |
||||
|
||||
case1Args := AuthArgs{ |
||||
Token: newAuthToken(), |
||||
RPCVersion: globalRPCAPIVersion, |
||||
RequestTime: UTCNow(), |
||||
} |
||||
|
||||
case2Args := AuthArgs{ |
||||
Token: newAuthToken(), |
||||
RPCVersion: globalRPCAPIVersion, |
||||
RequestTime: UTCNow().Add(15 * time.Minute), |
||||
} |
||||
|
||||
testCases := []struct { |
||||
args *AuthArgs |
||||
authArgs AuthArgs |
||||
expectedResult *AuthArgs |
||||
}{ |
||||
{&AuthArgs{}, case1Args, &case1Args}, |
||||
{&case2Args, case1Args, &case1Args}, |
||||
} |
||||
|
||||
for i, testCase := range testCases { |
||||
testCase.args.SetAuthArgs(testCase.authArgs) |
||||
result := testCase.args |
||||
|
||||
if !reflect.DeepEqual(result, testCase.expectedResult) { |
||||
t.Fatalf("case %v: expected: %v, got: %v", i+1, testCase.expectedResult, result) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestRPCClientArgsValidate(t *testing.T) { |
||||
case1URL, err := xnet.ParseURL("http://localhost:12345/rpc") |
||||
if err != nil { |
||||
t.Fatalf("unexpected error %v", err) |
||||
} |
||||
case1Args := RPCClientArgs{ |
||||
NewAuthTokenFunc: newAuthToken, |
||||
RPCVersion: globalRPCAPIVersion, |
||||
ServiceName: "Arith", |
||||
ServiceURL: case1URL, |
||||
TLSConfig: nil, |
||||
} |
||||
|
||||
case2URL, err := xnet.ParseURL("https://localhost:12345/rpc") |
||||
if err != nil { |
||||
t.Fatalf("unexpected error %v", err) |
||||
} |
||||
case2Args := RPCClientArgs{ |
||||
NewAuthTokenFunc: newAuthToken, |
||||
RPCVersion: globalRPCAPIVersion, |
||||
ServiceName: "Arith", |
||||
ServiceURL: case1URL, |
||||
TLSConfig: &tls.Config{}, |
||||
} |
||||
|
||||
case3Args := RPCClientArgs{ |
||||
NewAuthTokenFunc: nil, |
||||
RPCVersion: globalRPCAPIVersion, |
||||
ServiceName: "Arith", |
||||
ServiceURL: case1URL, |
||||
TLSConfig: &tls.Config{}, |
||||
} |
||||
|
||||
case4Args := RPCClientArgs{ |
||||
NewAuthTokenFunc: newAuthToken, |
||||
RPCVersion: globalRPCAPIVersion, |
||||
ServiceURL: case1URL, |
||||
TLSConfig: &tls.Config{}, |
||||
} |
||||
|
||||
case5URL, err := xnet.ParseURL("ftp://localhost:12345/rpc") |
||||
if err != nil { |
||||
t.Fatalf("unexpected error %v", err) |
||||
} |
||||
case5Args := RPCClientArgs{ |
||||
NewAuthTokenFunc: newAuthToken, |
||||
RPCVersion: globalRPCAPIVersion, |
||||
ServiceName: "Arith", |
||||
ServiceURL: case5URL, |
||||
TLSConfig: &tls.Config{}, |
||||
} |
||||
|
||||
case6URL, err := xnet.ParseURL("http://localhost:12345/rpc?location") |
||||
if err != nil { |
||||
t.Fatalf("unexpected error %v", err) |
||||
} |
||||
case6Args := RPCClientArgs{ |
||||
NewAuthTokenFunc: newAuthToken, |
||||
RPCVersion: globalRPCAPIVersion, |
||||
ServiceName: "Arith", |
||||
ServiceURL: case6URL, |
||||
TLSConfig: &tls.Config{}, |
||||
} |
||||
|
||||
case7Args := RPCClientArgs{ |
||||
NewAuthTokenFunc: newAuthToken, |
||||
RPCVersion: globalRPCAPIVersion, |
||||
ServiceName: "Arith", |
||||
ServiceURL: case2URL, |
||||
TLSConfig: nil, |
||||
} |
||||
|
||||
testCases := []struct { |
||||
args RPCClientArgs |
||||
expectErr bool |
||||
}{ |
||||
{case1Args, false}, |
||||
{case2Args, false}, |
||||
// NewAuthTokenFunc must not be empty error.
|
||||
{case3Args, true}, |
||||
// ServiceName must not be empty.
|
||||
{case4Args, true}, |
||||
// unknown RPC URL error.
|
||||
{case5Args, true}, |
||||
// unknown RPC URL error.
|
||||
{case6Args, true}, |
||||
// tls configuration must not be empty for https url error.
|
||||
{case7Args, true}, |
||||
} |
||||
|
||||
for i, testCase := range testCases { |
||||
err := testCase.args.validate() |
||||
expectErr := (err != nil) |
||||
|
||||
if expectErr != testCase.expectErr { |
||||
t.Fatalf("case %v: expected: %v, got: %v", i+1, testCase.expectErr, expectErr) |
||||
} |
||||
} |
||||
} |
||||
|
||||
type Args struct { |
||||
AuthArgs |
||||
A, B int |
||||
} |
||||
|
||||
type Quotient struct { |
||||
Quo, Rem int |
||||
} |
||||
|
||||
type Arith struct{} |
||||
|
||||
func (t *Arith) Multiply(args *Args, reply *int) error { |
||||
*reply = args.A * args.B |
||||
return nil |
||||
} |
||||
|
||||
func (t *Arith) Divide(args *Args, quo *Quotient) error { |
||||
if args.B == 0 { |
||||
return errors.New("divide by zero") |
||||
} |
||||
quo.Quo = args.A / args.B |
||||
quo.Rem = args.A % args.B |
||||
return nil |
||||
} |
||||
|
||||
func TestRPCClientCall(t *testing.T) { |
||||
tmpGlobalServerConfig := globalServerConfig |
||||
defer func() { |
||||
globalServerConfig = tmpGlobalServerConfig |
||||
}() |
||||
globalServerConfig = newServerConfig() |
||||
|
||||
rpcServer := xrpc.NewServer() |
||||
if err := rpcServer.RegisterName("Arith", &Arith{}); err != nil { |
||||
t.Fatalf("unexpected error %v", err) |
||||
} |
||||
|
||||
httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
||||
rpcServer.ServeHTTP(w, r) |
||||
})) |
||||
defer httpServer.Close() |
||||
|
||||
url, err := xnet.ParseURL(httpServer.URL) |
||||
if err != nil { |
||||
t.Fatalf("unexpected error %v", err) |
||||
} |
||||
rpcClient, err := NewRPCClient(RPCClientArgs{ |
||||
NewAuthTokenFunc: newAuthToken, |
||||
RPCVersion: globalRPCAPIVersion, |
||||
ServiceName: "Arith", |
||||
ServiceURL: url, |
||||
}) |
||||
if err != nil { |
||||
t.Fatalf("unexpected error %v", err) |
||||
} |
||||
|
||||
var case1Result int |
||||
case1ExpectedResult := 19 * 8 |
||||
|
||||
testCases := []struct { |
||||
serviceMethod string |
||||
args *Args |
||||
result interface{} |
||||
changeConfig bool |
||||
expectedResult interface{} |
||||
expectErr bool |
||||
}{ |
||||
{"Arith.Multiply", &Args{A: 19, B: 8}, &case1Result, false, &case1ExpectedResult, false}, |
||||
{"Arith.Divide", &Args{A: 19, B: 8}, &Quotient{}, false, &Quotient{2, 3}, false}, |
||||
{"Arith.Multiply", &Args{A: 19, B: 8}, &case1Result, true, &case1ExpectedResult, false}, |
||||
{"Arith.Divide", &Args{A: 19, B: 8}, &Quotient{}, true, &Quotient{2, 3}, false}, |
||||
{"Arith.Divide", &Args{A: 19, B: 0}, &Quotient{}, false, nil, true}, |
||||
{"Arith.Divide", &Args{A: 19, B: 8}, &case1Result, false, nil, true}, |
||||
} |
||||
|
||||
for i, testCase := range testCases { |
||||
if testCase.changeConfig { |
||||
globalServerConfig = newServerConfig() |
||||
} |
||||
|
||||
err := rpcClient.Call(testCase.serviceMethod, testCase.args, testCase.result) |
||||
expectErr := (err != nil) |
||||
|
||||
if expectErr != testCase.expectErr { |
||||
t.Fatalf("case %v: expected: %v, got: %v", i+1, testCase.expectErr, expectErr) |
||||
} |
||||
|
||||
if !testCase.expectErr { |
||||
if !reflect.DeepEqual(testCase.result, testCase.expectedResult) { |
||||
t.Fatalf("case %v: expected: %v, got: %v", i+1, testCase.expectedResult, testCase.result) |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestRPCClientClose(t *testing.T) { |
||||
tmpGlobalServerConfig := globalServerConfig |
||||
defer func() { |
||||
globalServerConfig = tmpGlobalServerConfig |
||||
}() |
||||
globalServerConfig = newServerConfig() |
||||
|
||||
url, err := xnet.ParseURL("http://localhost:12345") |
||||
if err != nil { |
||||
t.Fatalf("unexpected error %v", err) |
||||
} |
||||
rpcClient, err := NewRPCClient(RPCClientArgs{ |
||||
NewAuthTokenFunc: newAuthToken, |
||||
RPCVersion: globalRPCAPIVersion, |
||||
ServiceName: "Arith", |
||||
ServiceURL: url, |
||||
}) |
||||
if err != nil { |
||||
t.Fatalf("unexpected error %v", err) |
||||
} |
||||
|
||||
testCases := []struct { |
||||
rpcClient *RPCClient |
||||
expectErr bool |
||||
}{ |
||||
{rpcClient, false}, |
||||
// Double close.
|
||||
{rpcClient, false}, |
||||
} |
||||
|
||||
for i, testCase := range testCases { |
||||
err := testCase.rpcClient.Close() |
||||
expectErr := (err != nil) |
||||
|
||||
if expectErr != testCase.expectErr { |
||||
t.Fatalf("case %v: expected: %v, got: %v", i+1, testCase.expectErr, expectErr) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestRPCClientServiceURL(t *testing.T) { |
||||
tmpGlobalServerConfig := globalServerConfig |
||||
defer func() { |
||||
globalServerConfig = tmpGlobalServerConfig |
||||
}() |
||||
globalServerConfig = newServerConfig() |
||||
|
||||
url, err := xnet.ParseURL("http://localhost:12345") |
||||
if err != nil { |
||||
t.Fatalf("unexpected error %v", err) |
||||
} |
||||
rpcClient, err := NewRPCClient(RPCClientArgs{ |
||||
NewAuthTokenFunc: newAuthToken, |
||||
RPCVersion: globalRPCAPIVersion, |
||||
ServiceName: "Arith", |
||||
ServiceURL: url, |
||||
}) |
||||
if err != nil { |
||||
t.Fatalf("unexpected error %v", err) |
||||
} |
||||
|
||||
case1Result, err := xnet.ParseURL("http://localhost:12345") |
||||
if err != nil { |
||||
t.Fatalf("unexpected error %v", err) |
||||
} |
||||
testCases := []struct { |
||||
rpcClient *RPCClient |
||||
expectedResult *xnet.URL |
||||
}{ |
||||
{rpcClient, case1Result}, |
||||
} |
||||
|
||||
for i, testCase := range testCases { |
||||
result := testCase.rpcClient.ServiceURL() |
||||
|
||||
if !reflect.DeepEqual(result, testCase.expectedResult) { |
||||
t.Fatalf("case %v: expected: %v, got: %v", i+1, testCase.expectedResult, result) |
||||
} |
||||
} |
||||
} |
Loading…
Reference in new issue