diff --git a/cmd/host-to-ip.go b/cmd/host-to-ip.go index a20de6c7f..36686b9c3 100644 --- a/cmd/host-to-ip.go +++ b/cmd/host-to-ip.go @@ -17,6 +17,7 @@ package cmd import ( + "fmt" "net" "sort" ) @@ -31,12 +32,20 @@ func (n byLastOctet) Less(i, j int) bool { return []byte(n[i].To4())[3] < []byte(n[j].To4())[3] } -// getIPsFromHosts - returns a reverse sorted list of ips based on the last octet value. -func getIPsFromHosts(hosts []string) (ips []net.IP) { - for _, host := range hosts { - ips = append(ips, net.ParseIP(host)) +// sortIPsByOctet - returns a reverse sorted list of hosts based on the last octet value. +func sortIPsByOctet(ips []string) error { + var nips []net.IP + for _, ip := range ips { + nip := net.ParseIP(ip) + if nip == nil { + return fmt.Errorf("Unable to parse invalid ip %s", ip) + } + nips = append(nips, nip) } // Reverse sort ips by their last octet. - sort.Sort(sort.Reverse(byLastOctet(ips))) - return ips + sort.Sort(sort.Reverse(byLastOctet(nips))) + for i, nip := range nips { + ips[i] = nip.String() + } + return nil } diff --git a/cmd/host-to-ip_test.go b/cmd/host-to-ip_test.go index e29c22de6..85e75b4f2 100644 --- a/cmd/host-to-ip_test.go +++ b/cmd/host-to-ip_test.go @@ -16,18 +16,23 @@ package cmd -import "testing" +import ( + "fmt" + "testing" +) // Tests sorted list generated from hosts to ip. func TestHostToIP(t *testing.T) { // Collection of test cases to validate last octet sorting. testCases := []struct { - hosts []string - sortedHosts []string + ips []string + sortedIPs []string + err error + shouldPass bool }{ { // List of ip addresses that need to be sorted. - []string{ + ips: []string{ "129.95.30.40", "5.24.69.2", "19.20.203.5", @@ -37,7 +42,7 @@ func TestHostToIP(t *testing.T) { "5.220.100.50", }, // Numerical sorting result based on the last octet. - []string{ + sortedIPs: []string{ "5.220.100.50", "129.95.30.40", "19.20.21.22", @@ -46,15 +51,34 @@ func TestHostToIP(t *testing.T) { "5.24.69.2", "127.0.0.1", }, + err: nil, + shouldPass: true, + }, + { + ips: []string{ + "localhost", + }, + sortedIPs: []string{}, + err: fmt.Errorf("Unable to parse invalid ip localhost"), + shouldPass: false, }, } // Tests the correct sorting behavior of getIPsFromHosts. for j, testCase := range testCases { - ips := getIPsFromHosts(testCase.hosts) - for i, ip := range ips { - if ip.String() != testCase.sortedHosts[i] { - t.Fatalf("Test %d expected to pass but failed. Wanted ip %s, but got %s", j+1, testCase.sortedHosts[i], ip.String()) + err := sortIPsByOctet(testCase.ips) + if !testCase.shouldPass && testCase.err.Error() != err.Error() { + t.Fatalf("Test %d: Expected error %s, got %s", j+1, testCase.err, err) + } + if testCase.shouldPass && err != nil { + t.Fatalf("Test %d: Expected error %s", j+1, err) + } + if testCase.shouldPass { + for i, ip := range testCase.ips { + if ip == testCase.sortedIPs[i] { + continue + } + t.Errorf("Test %d expected to pass but failed. Wanted ip %s, but got %s", j+1, testCase.sortedIPs[i], ip) } } } diff --git a/cmd/server-main.go b/cmd/server-main.go index 97028fc93..ba51ab19b 100644 --- a/cmd/server-main.go +++ b/cmd/server-main.go @@ -215,43 +215,52 @@ func parseStorageEndPoints(eps []string, defaultPort int) (endpoints []storageEn } // getListenIPs - gets all the ips to listen on. -func getListenIPs(httpServerConf *http.Server) (hosts []string, port string) { - host, port, err := net.SplitHostPort(httpServerConf.Addr) - fatalIf(err, "Unable to parse host address.", httpServerConf.Addr) - - if host != "" { - hosts = append(hosts, host) - return hosts, port +func getListenIPs(httpServerConf *http.Server) (hosts []string, port string, err error) { + var host string + host, port, err = net.SplitHostPort(httpServerConf.Addr) + if err != nil { + return nil, port, fmt.Errorf("Unable to parse host address %s", err) } - addrs, err := net.InterfaceAddrs() - fatalIf(err, "Unable to determine network interface address.") - for _, addr := range addrs { - if addr.Network() == "ip+net" { - host := strings.Split(addr.String(), "/")[0] - if ip := net.ParseIP(host); ip.To4() != nil { - hosts = append(hosts, host) + if host == "" { + var addrs []net.Addr + addrs, err = net.InterfaceAddrs() + if err != nil { + return nil, port, fmt.Errorf("Unable to determine network interface address. %s", err) + } + for _, addr := range addrs { + if addr.Network() == "ip+net" { + hostname := strings.Split(addr.String(), "/")[0] + if ip := net.ParseIP(hostname); ip.To4() != nil { + hosts = append(hosts, hostname) + } } } - } - return hosts, port + err = sortIPsByOctet(hosts) + if err != nil { + return nil, port, fmt.Errorf("Unable reverse sorted ips from hosts %s", err) + } + return hosts, port, nil + } // if host != "" { + // Proceed to append itself, since user requested a specific endpoint. + hosts = append(hosts, host) + return hosts, port, nil } // Finalizes the endpoints based on the host list and port. func finalizeEndpoints(tls bool, apiServer *http.Server) (endPoints []string) { - // Get list of listen ips and port. - hosts, port := getListenIPs(apiServer) - // Verify current scheme. scheme := "http" if tls { scheme = "https" } - ips := getIPsFromHosts(hosts) + // Get list of listen ips and port. + hosts, port, err := getListenIPs(apiServer) + fatalIf(err, "Unable to get list of ips to listen on") // Construct proper endpoints. - for _, ip := range ips { - endPoints = append(endPoints, fmt.Sprintf("%s://%s:%s", scheme, ip.String(), port)) + for _, host := range hosts { + endPoints = append(endPoints, fmt.Sprintf("%s://%s:%s", scheme, host, port)) } // Success. diff --git a/cmd/server-main_test.go b/cmd/server-main_test.go index b754ea795..89e64575b 100644 --- a/cmd/server-main_test.go +++ b/cmd/server-main_test.go @@ -28,15 +28,39 @@ import ( func TestGetListenIPs(t *testing.T) { testCases := []struct { - addr string - port string + addr string + port string + shouldPass bool }{ - {"localhost", ":80"}, - {"", ":80"}, + {"localhost", "9000", true}, + {"", "9000", true}, + {"", "", false}, } for _, test := range testCases { - ts := &http.Server{Addr: test.addr + test.port} - getListenIPs(ts) + var addr string + // Please keep this we need to do this because + // of odd https://play.golang.org/p/4dMPtM6Wdd + // implementation issue. + if test.port != "" { + addr = test.addr + ":" + test.port + } + hosts, port, err := getListenIPs(&http.Server{ + Addr: addr, + }) + if !test.shouldPass && err == nil { + t.Fatalf("Test should fail but succeeded %s", err) + } + if test.shouldPass && err != nil { + t.Fatalf("Test should succeed but failed %s", err) + } + if test.shouldPass { + if port != test.port { + t.Errorf("Test expected %s, got %s", test.port, port) + } + if len(hosts) == 0 { + t.Errorf("Test unexpected value hosts cannot be empty %#v", test) + } + } } }