diff --git a/cmd/storage-rest-client.go b/cmd/storage-rest-client.go index abe2ba6e0..82c3e632b 100644 --- a/cmd/storage-rest-client.go +++ b/cmd/storage-rest-client.go @@ -43,11 +43,13 @@ import ( // data for all internode storage REST requests. const storageRESTTimeout = 5 * time.Minute -func isNetworkDisconnectError(err error) bool { +func isNetworkError(err error) bool { if err == nil { return false } - + if err.Error() == errConnectionStale.Error() { + return true + } if uerr, isURLError := err.(*url.Error); isURLError { if uerr.Timeout() { return true @@ -68,7 +70,7 @@ func toStorageErr(err error) error { return nil } - if isNetworkDisconnectError(err) { + if isNetworkError(err) { return errDiskNotFound } @@ -128,6 +130,7 @@ type storageRESTClient struct { restClient *rest.Client connected bool lastError error + instanceID string // REST server's instanceID which is sent with every request for validation. } // Wrapper to restClient.Call to handle network errors, in case of network error the connection is makred disconnected @@ -137,12 +140,16 @@ func (client *storageRESTClient) call(method string, values url.Values, body io. if !client.connected { return nil, errDiskNotFound } + if values == nil { + values = make(url.Values) + } + values.Set(storageRESTInstanceID, client.instanceID) respBody, err = client.restClient.Call(method, values, body, length) if err == nil { return respBody, nil } client.lastError = err - if isNetworkDisconnectError(err) { + if isNetworkError(err) { client.connected = false } @@ -350,6 +357,22 @@ func (client *storageRESTClient) RenameFile(srcVolume, srcPath, dstVolume, dstPa return err } +// Gets peer storage server's instanceID - to be used with every REST call for validation. +func (client *storageRESTClient) getInstanceID() (err error) { + respBody, err := client.restClient.Call(storageRESTMethodGetInstanceID, nil, nil, -1) + if err != nil { + return err + } + defer http.DrainBody(respBody) + instanceIDBuf := make([]byte, 64) + n, err := io.ReadFull(respBody, instanceIDBuf) + if err != io.EOF && err != io.ErrUnexpectedEOF { + return err + } + client.instanceID = string(instanceIDBuf[:n]) + return nil +} + // Close - marks the client as closed. func (client *storageRESTClient) Close() error { client.connected = false @@ -382,5 +405,7 @@ func newStorageRESTClient(endpoint Endpoint) *storageRESTClient { } restClient := rest.NewClient(serverURL, tlsConfig, storageRESTTimeout, newAuthToken) - return &storageRESTClient{endpoint: endpoint, restClient: restClient, connected: true} + client := &storageRESTClient{endpoint: endpoint, restClient: restClient, connected: true} + client.connected = client.getInstanceID() == nil + return client } diff --git a/cmd/storage-rest-common.go b/cmd/storage-rest-common.go index 737fd48c8..aab67a242 100644 --- a/cmd/storage-rest-common.go +++ b/cmd/storage-rest-common.go @@ -36,6 +36,7 @@ const ( storageRESTMethodListDir = "listdir" storageRESTMethodDeleteFile = "deletefile" storageRESTMethodRenameFile = "renamefile" + storageRESTMethodGetInstanceID = "getinstanceid" ) const ( @@ -51,4 +52,5 @@ const ( storageRESTCount = "count" storageRESTBitrotAlgo = "bitrot-algo" storageRESTBitrotHash = "bitrot-hash" + storageRESTInstanceID = "instance-id" ) diff --git a/cmd/storage-rest-server.go b/cmd/storage-rest-server.go index 4bb8a67af..00c31d97f 100644 --- a/cmd/storage-rest-server.go +++ b/cmd/storage-rest-server.go @@ -17,6 +17,7 @@ package cmd import ( + "errors" "fmt" "io" "path" @@ -33,9 +34,14 @@ import ( "github.com/minio/minio/cmd/logger" ) +var errConnectionStale = errors.New("connection stale, REST client/server instance-id mismatch") + // To abstract a disk over network. type storageRESTServer struct { storage *posix + // Used to detect reboot of servers so that peers revalidate format.json as + // different disk might be available on the same mount point after reboot. + instanceID string } func (s *storageRESTServer) writeErrorResponse(w http.ResponseWriter, err error) { @@ -43,23 +49,15 @@ func (s *storageRESTServer) writeErrorResponse(w http.ResponseWriter, err error) w.Write([]byte(err.Error())) } -// Authenticates storage client's requests. -func storageServerRequestAuthenticate(r *http.Request) error { - _, _, err := webRequestAuthenticate(r) - return err -} - -// IsValid - To authenticate and verify the time difference. -func (s *storageRESTServer) IsValid(w http.ResponseWriter, r *http.Request) bool { - if err := storageServerRequestAuthenticate(r); err != nil { - w.WriteHeader(http.StatusForbidden) - return false +// Authenticates storage client's requests and validates for skewed time. +func storageServerRequestValidate(r *http.Request) error { + if _, _, err := webRequestAuthenticate(r); err != nil { + return err } requestTimeStr := r.Header.Get("X-Minio-Time") requestTime, err := time.Parse(time.RFC3339, requestTimeStr) if err != nil { - s.writeErrorResponse(w, err) - return false + return err } utcNow := UTCNow() delta := requestTime.Sub(utcNow) @@ -67,12 +65,36 @@ func (s *storageRESTServer) IsValid(w http.ResponseWriter, r *http.Request) bool delta = delta * -1 } if delta > DefaultSkewTime { - s.writeErrorResponse(w, fmt.Errorf("client time %v is too apart with server time %v", requestTime, utcNow)) + return fmt.Errorf("client time %v is too apart with server time %v", requestTime, utcNow) + } + return nil +} + +// IsValid - To authenticate and verify the time difference. +func (s *storageRESTServer) IsValid(w http.ResponseWriter, r *http.Request) bool { + if err := storageServerRequestValidate(r); err != nil { + s.writeErrorResponse(w, err) + return false + } + instanceID := r.URL.Query().Get(storageRESTInstanceID) + if instanceID != s.instanceID { + // This will cause the peer to revalidate format.json using a new storage-rest-client instance. + s.writeErrorResponse(w, errConnectionStale) return false } return true } +// GetInstanceID - returns the instance ID of the server. +func (s *storageRESTServer) GetInstanceID(w http.ResponseWriter, r *http.Request) { + if err := storageServerRequestValidate(r); err != nil { + s.writeErrorResponse(w, err) + return + } + w.Header().Set("Content-Length", strconv.Itoa(len(s.instanceID))) + w.Write([]byte(s.instanceID)) +} + // DiskInfoHandler - returns disk info. func (s *storageRESTServer) DiskInfoHandler(w http.ResponseWriter, r *http.Request) { if !s.IsValid(w, r) { @@ -383,7 +405,7 @@ func registerStorageRESTHandlers(router *mux.Router, endpoints EndpointList) { logger.Fatal(uiErrUnableToWriteInBackend(err), "Unable to initialize posix backend") } - server := &storageRESTServer{storage} + server := &storageRESTServer{storage, mustGetUUID()} subrouter := router.PathPrefix(path.Join(storageRESTPath, endpoint.Path)).Subrouter() @@ -414,6 +436,7 @@ func registerStorageRESTHandlers(router *mux.Router, endpoints EndpointList) { Queries(restQueries(storageRESTVolume, storageRESTFilePath)...) subrouter.Methods(http.MethodPost).Path("/" + storageRESTMethodRenameFile).HandlerFunc(httpTraceHdrs(server.RenameFileHandler)). Queries(restQueries(storageRESTSrcVolume, storageRESTSrcPath, storageRESTDstVolume, storageRESTDstPath)...) + subrouter.Methods(http.MethodPost).Path("/" + storageRESTMethodGetInstanceID).HandlerFunc(httpTraceAll(server.GetInstanceID)) } router.NotFoundHandler = http.HandlerFunc(httpTraceAll(notFoundHandler))