diff --git a/cmd/endpoint.go b/cmd/endpoint.go index 7b21c0cbc..30af75657 100644 --- a/cmd/endpoint.go +++ b/cmd/endpoint.go @@ -127,14 +127,10 @@ func NewEndpoint(arg string) (Endpoint, error) { return Endpoint{}, fmt.Errorf("empty or root path is not supported in URL endpoint") } - // Get IPv4 address of the host. - hostIPs, err := getHostIP4(host) + isLocal, err = isLocalHost(host) if err != nil { return Endpoint{}, err } - - // If intersection of two IP sets is not empty, then the host is local host. - isLocal = !localIP4.Intersection(hostIPs).IsEmpty() } else { u = &url.URL{Path: path.Clean(arg)} isLocal = true diff --git a/cmd/gateway-main.go b/cmd/gateway-main.go index 8b72e768f..9ca9d94e1 100644 --- a/cmd/gateway-main.go +++ b/cmd/gateway-main.go @@ -185,11 +185,24 @@ func gatewayMain(ctx *cli.Context) { } // First argument is selected backend type. - backendType := ctx.Args().First() + backendType := ctx.Args().Get(0) + // Second argument is the endpoint address (optional) + endpointAddr := ctx.Args().Get(1) + // Third argument is the address flag + serverAddr := ctx.String("address") + + if endpointAddr != "" { + // Reject the endpoint if it points to the gateway handler itself. + sameTarget, err := sameLocalAddrs(endpointAddr, serverAddr) + fatalIf(err, "Unable to compare server and endpoint addresses.") + if sameTarget { + fatalIf(errors.New("endpoint points to the local gateway"), "Endpoint url is not allowed") + } + } // Second argument is endpoint. If no endpoint is specified then the // gateway implementation should use a default setting. - endPoint, secure, err := parseGatewayEndpoint(ctx.Args().Get(1)) + endPoint, secure, err := parseGatewayEndpoint(endpointAddr) fatalIf(err, "Unable to parse endpoint") // Create certs path for SSL configuration. @@ -223,7 +236,7 @@ func gatewayMain(ctx *cli.Context) { setAuthHandler, } - apiServer := NewServerMux(ctx.String("address"), registerHandlers(router, handlerFns...)) + apiServer := NewServerMux(serverAddr, registerHandlers(router, handlerFns...)) _, _, globalIsSSL, err = getSSLConfig() fatalIf(err, "Invalid SSL key file") diff --git a/cmd/net.go b/cmd/net.go index 9f8d6021c..ba62c0d77 100644 --- a/cmd/net.go +++ b/cmd/net.go @@ -17,11 +17,14 @@ package cmd import ( + "errors" "fmt" "net" + "net/url" "os" "sort" "strconv" + "strings" "syscall" "github.com/minio/minio-go/pkg/set" @@ -186,6 +189,121 @@ func checkPortAvailability(port string) (err error) { return nil } +// extractHostPort - extracts host/port from many address formats +// such as, ":9000", "localhost:9000", "http://localhost:9000/" +func extractHostPort(hostAddr string) (string, string, error) { + var addr, scheme string + + if hostAddr == "" { + return "", "", errors.New("unable to process empty address") + } + + // Parse address to extract host and scheme field + u, err := url.Parse(hostAddr) + if err != nil { + // Ignore scheme not present error + if !strings.Contains(err.Error(), "missing protocol scheme") { + return "", "", err + } + } else { + addr = u.Host + scheme = u.Scheme + } + + // Use the given parameter again if url.Parse() + // didn't return any useful result. + if addr == "" { + addr = hostAddr + scheme = "http" + } + + // At this point, addr can be one of the following form: + // ":9000" + // "localhost:9000" + // "localhost" <- in this case, we check for scheme + + host, port, err := net.SplitHostPort(addr) + if err != nil { + if !strings.Contains(err.Error(), "missing port in address") { + return "", "", err + } + + host = addr + + switch scheme { + case "https": + port = "443" + case "http": + port = "80" + default: + return "", "", errors.New("unable to guess port from scheme") + } + } + + return host, port, nil +} + +// isLocalHost - checks if the given parameter +// correspond to one of the local IP of the +// current machine +func isLocalHost(host string) (bool, error) { + hostIPs, err := getHostIP4(host) + if err != nil { + return false, err + } + + // If intersection of two IP sets is not empty, then the host is local host. + isLocal := !localIP4.Intersection(hostIPs).IsEmpty() + return isLocal, nil +} + +// sameLocalAddrs - returns true if two addresses, even with different +// formats, point to the same machine, e.g: +// ':9000' and 'http://localhost:9000/' will return true +func sameLocalAddrs(addr1, addr2 string) (bool, error) { + + // Extract host & port from given parameters + host1, port1, err := extractHostPort(addr1) + if err != nil { + return false, err + } + host2, port2, err := extractHostPort(addr2) + if err != nil { + return false, err + } + + var addr1Local, addr2Local bool + + if host1 == "" { + // If empty host means it is localhost + addr1Local = true + } else { + // Host not empty, check if it is local + if addr1Local, err = isLocalHost(host1); err != nil { + return false, err + } + } + + if host2 == "" { + // If empty host means it is localhost + addr2Local = true + } else { + // Host not empty, check if it is local + if addr2Local, err = isLocalHost(host2); err != nil { + return false, err + } + } + + // If both of addresses point to the same machine, check if + // have the same port + if addr1Local && addr2Local { + if port1 == port2 { + return true, nil + } + } + return false, nil +} + // CheckLocalServerAddr - checks if serverAddr is valid and local host. func CheckLocalServerAddr(serverAddr string) error { host, port, err := net.SplitHostPort(serverAddr) @@ -202,12 +320,11 @@ func CheckLocalServerAddr(serverAddr string) error { } if host != "" { - hostIPs, err := getHostIP4(host) + isLocalHost, err := isLocalHost(host) if err != nil { return err } - - if localIP4.Intersection(hostIPs).IsEmpty() { + if !isLocalHost { return fmt.Errorf("host in server address should be this server") } } diff --git a/cmd/net_test.go b/cmd/net_test.go index e95e81b9d..9f2e9c9a0 100644 --- a/cmd/net_test.go +++ b/cmd/net_test.go @@ -17,6 +17,7 @@ package cmd import ( + "errors" "fmt" "net" "reflect" @@ -240,3 +241,77 @@ func TestCheckLocalServerAddr(t *testing.T) { } } } + +func TestExtractHostPort(t *testing.T) { + testCases := []struct { + addr string + host string + port string + expectedErr error + }{ + {"", "", "", errors.New("unable to process empty address")}, + {"localhost", "localhost", "80", nil}, + {"localhost:9000", "localhost", "9000", nil}, + {"http://:9000/", "", "9000", nil}, + {"http://8.8.8.8:9000/", "8.8.8.8", "9000", nil}, + {"https://facebook.com:9000/", "facebook.com", "9000", nil}, + } + + for i, testCase := range testCases { + host, port, err := extractHostPort(testCase.addr) + if testCase.expectedErr == nil { + if err != nil { + t.Fatalf("Test %d: should succeed but failed with err: %v", i+1, err) + } + if host != testCase.host { + t.Fatalf("Test %d: expected: %v, found: %v", i+1, testCase.host, host) + } + if port != testCase.port { + t.Fatalf("Test %d: expected: %v, found: %v", i+1, testCase.port, port) + } + + } + if testCase.expectedErr != nil { + if err == nil { + t.Fatalf("Test %d:, should fail but succeeded.", i+1) + } + if testCase.expectedErr.Error() != err.Error() { + t.Fatalf("Test %d: failed with different error, expected: '%v', found:'%v'.", i+1, testCase.expectedErr, err) + } + } + } +} + +func TestSameLocalAddrs(t *testing.T) { + testCases := []struct { + addr1 string + addr2 string + sameAddr bool + expectedErr error + }{ + {"", "", false, errors.New("unable to process empty address")}, + {":9000", ":9000", true, nil}, + {"localhost:9000", ":9000", true, nil}, + {"localhost:9000", "http://localhost:9000", true, nil}, + {"8.8.8.8:9000", "http://localhost:9000", false, nil}, + } + + for i, testCase := range testCases { + sameAddr, err := sameLocalAddrs(testCase.addr1, testCase.addr2) + if testCase.expectedErr != nil && err == nil { + t.Fatalf("Test %d: should fail but succeeded", i+1) + } + if testCase.expectedErr == nil && err != nil { + t.Fatalf("Test %d: should succeed but failed with %v", i+1, err) + } + if err == nil { + if sameAddr != testCase.sameAddr { + t.Fatalf("Test %d: expected: %v, found: %v", i+1, testCase.sameAddr, sameAddr) + } + } else { + if err.Error() != testCase.expectedErr.Error() { + t.Fatalf("Test %d: failed with different error, expected: '%v', found:'%v'.", i+1, testCase.expectedErr, err) + } + } + } +}