From bd2131ba349ec06e832d467e2a74526305637b19 Mon Sep 17 00:00:00 2001 From: Harshavardhana Date: Fri, 16 Oct 2020 14:49:05 -0700 Subject: [PATCH] add DNS cache support to avoid DNS flooding (#10693) Go stdlib resolver doesn't support caching DNS resolutions, since we compile with CGO disabled we are more probe to DNS flooding for all network calls to resolve for DNS from the DNS server. Under various containerized environments such as VMWare this becomes a problem because there are no DNS caches available and we may end up overloading the kube-dns resolver under concurrent I/O. To circumvent this issue implement a DNSCache resolver which resolves DNS and caches them for around 10secs with every 3sec invalidation attempted. --- cmd/gateway-main.go | 7 ++ cmd/globals.go | 2 + cmd/http/dial_dnscache.go | 197 +++++++++++++++++++++++++++++++ cmd/http/dial_dnscache_test.go | 204 +++++++++++++++++++++++++++++++++ cmd/server-main.go | 6 + cmd/test-utils_test.go | 3 + cmd/utils.go | 6 +- 7 files changed, 422 insertions(+), 3 deletions(-) create mode 100644 cmd/http/dial_dnscache.go create mode 100644 cmd/http/dial_dnscache_test.go diff --git a/cmd/gateway-main.go b/cmd/gateway-main.go index 0db4cf369..7e7187d4c 100644 --- a/cmd/gateway-main.go +++ b/cmd/gateway-main.go @@ -19,12 +19,14 @@ package cmd import ( "context" "fmt" + "math/rand" "net" "net/url" "os" "os/signal" "strings" "syscall" + "time" "github.com/gorilla/mux" "github.com/minio/cli" @@ -153,6 +155,11 @@ func ValidateGatewayArguments(serverAddr, endpointAddr string) error { // StartGateway - handler for 'minio gateway '. func StartGateway(ctx *cli.Context, gw Gateway) { + rand.Seed(time.Now().UTC().UnixNano()) + + globalDNSCache = xhttp.NewDNSCache(3*time.Second, 10*time.Second) + defer globalDNSCache.Stop() + // This is only to uniquely identify each gateway deployments. globalDeploymentID = env.Get("MINIO_GATEWAY_DEPLOYMENT_ID", mustGetUUID()) logger.SetDeploymentID(globalDeploymentID) diff --git a/cmd/globals.go b/cmd/globals.go index ceac88e65..b96c31071 100644 --- a/cmd/globals.go +++ b/cmd/globals.go @@ -274,6 +274,8 @@ var ( globalFSOSync bool globalProxyEndpoints []ProxyEndpoint + + globalDNSCache *xhttp.DNSCache // Add new variable global values here. ) diff --git a/cmd/http/dial_dnscache.go b/cmd/http/dial_dnscache.go new file mode 100644 index 000000000..8c139f9a1 --- /dev/null +++ b/cmd/http/dial_dnscache.go @@ -0,0 +1,197 @@ +/* + * MinIO Cloud Storage, (C) 2020 MinIO, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package http + +import ( + "context" + "log" + "math/rand" + "net" + "sync" + "time" +) + +var randPerm = func(n int) []int { + return rand.Perm(n) +} + +// DialContextWithDNSCache is a helper function which returns `net.DialContext` function. +// It randomly fetches an IP from the DNS cache and dials it by the given dial +// function. It dials one by one and returns first connected `net.Conn`. +// If it fails to dial all IPs from cache it returns first error. If no baseDialFunc +// is given, it sets default dial function. +// +// You can use returned dial function for `http.Transport.DialContext`. +// +// In this function, it uses functions from `rand` package. To make it really random, +// you MUST call `rand.Seed` and change the value from the default in your application +func DialContextWithDNSCache(cache *DNSCache, baseDialCtx DialContext) DialContext { + if baseDialCtx == nil { + // This is same as which `http.DefaultTransport` uses. + baseDialCtx = (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + DualStack: true, + }).DialContext + } + return func(ctx context.Context, network, host string) (net.Conn, error) { + h, p, err := net.SplitHostPort(host) + if err != nil { + return nil, err + } + + // Fetch DNS result from cache. + // + // ctxLookup is only used for canceling DNS Lookup. + ctxLookup, cancelF := context.WithTimeout(ctx, cache.lookupTimeout) + defer cancelF() + addrs, err := cache.Fetch(ctxLookup, h) + if err != nil { + return nil, err + } + + var firstErr error + for _, randomIndex := range randPerm(len(addrs)) { + conn, err := baseDialCtx(ctx, "tcp", net.JoinHostPort(addrs[randomIndex], p)) + if err == nil { + return conn, nil + } + if firstErr == nil { + firstErr = err + } + } + + return nil, firstErr + } +} + +const ( + // cacheSize is initial size of addr and IP list cache map. + cacheSize = 64 +) + +// defaultFreq is default frequency a resolver refreshes DNS cache. +var ( + defaultFreq = 3 * time.Second + defaultLookupTimeout = 10 * time.Second +) + +// DNSCache is DNS cache resolver which cache DNS resolve results in memory. +type DNSCache struct { + sync.RWMutex + + lookupHostFn func(ctx context.Context, host string) ([]string, error) + lookupTimeout time.Duration + + cache map[string][]string + closer func() +} + +// NewDNSCache initializes DNS cache resolver and starts auto refreshing +// in a new goroutine. To stop auto refreshing, call `Stop()` function. +// Once `Stop()` is called auto refreshing cannot be resumed. +func NewDNSCache(freq time.Duration, lookupTimeout time.Duration) *DNSCache { + if freq <= 0 { + freq = defaultFreq + } + + if lookupTimeout <= 0 { + lookupTimeout = defaultLookupTimeout + } + + ticker := time.NewTicker(freq) + ch := make(chan struct{}) + closer := func() { + ticker.Stop() + close(ch) + } + + r := &DNSCache{ + lookupHostFn: net.DefaultResolver.LookupHost, + lookupTimeout: lookupTimeout, + cache: make(map[string][]string, cacheSize), + closer: closer, + } + + go func() { + for { + select { + case <-ticker.C: + r.Refresh() + case <-ch: + return + } + } + }() + + return r +} + +// LookupHost lookups address list from DNS server, persist the results +// in-memory cache. `Fetch` is used to obtain the values for a given host. +func (r *DNSCache) LookupHost(ctx context.Context, host string) ([]string, error) { + addrs, err := r.lookupHostFn(ctx, host) + if err != nil { + return nil, err + } + + r.Lock() + r.cache[host] = addrs + r.Unlock() + + return addrs, nil +} + +// Fetch fetches IP list from the cache. If IP list of the given addr is not in the cache, +// then it lookups from DNS server by `Lookup` function. +func (r *DNSCache) Fetch(ctx context.Context, host string) ([]string, error) { + r.RLock() + addrs, ok := r.cache[host] + r.RUnlock() + if ok { + return addrs, nil + } + return r.LookupHost(ctx, host) +} + +// Refresh refreshes IP list cache, automatically. +func (r *DNSCache) Refresh() { + r.RLock() + hosts := make([]string, 0, len(r.cache)) + for host := range r.cache { + hosts = append(hosts, host) + } + r.RUnlock() + + for _, host := range hosts { + ctx, cancelF := context.WithTimeout(context.Background(), r.lookupTimeout) + if _, err := r.LookupHost(ctx, host); err != nil { + log.Println("failed to refresh DNS cache, resolver is unavailable", err) + } + cancelF() + } +} + +// Stop stops auto refreshing. +func (r *DNSCache) Stop() { + r.Lock() + defer r.Unlock() + if r.closer != nil { + r.closer() + r.closer = nil + } +} diff --git a/cmd/http/dial_dnscache_test.go b/cmd/http/dial_dnscache_test.go new file mode 100644 index 000000000..2833a1a84 --- /dev/null +++ b/cmd/http/dial_dnscache_test.go @@ -0,0 +1,204 @@ +/* + * MinIO Cloud Storage, (C) 2020 MinIO, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package http + +import ( + "context" + "errors" + "fmt" + "math/rand" + "net" + "testing" + "time" +) + +var ( + testFreq = 1 * time.Second + testDefaultLookupTimeout = 1 * time.Second +) + +func testDNSCache(t *testing.T) *DNSCache { + t.Helper() // skip printing file and line information from this function + return NewDNSCache(testFreq, testDefaultLookupTimeout) +} + +func TestDialContextWithDNSCache(t *testing.T) { + resolver := &DNSCache{ + cache: map[string][]string{ + "play.min.io": { + "127.0.0.1", + "127.0.0.2", + "127.0.0.3", + }, + }, + } + + cases := []struct { + permF func(n int) []int + dialF DialContext + }{ + { + permF: func(n int) []int { + return []int{0} + }, + dialF: func(ctx context.Context, network, addr string) (net.Conn, error) { + if got, want := addr, net.JoinHostPort("127.0.0.1", "443"); got != want { + t.Fatalf("got addr %q, want %q", got, want) + } + return nil, nil + }, + }, + { + permF: func(n int) []int { + return []int{1} + }, + dialF: func(ctx context.Context, network, addr string) (net.Conn, error) { + if got, want := addr, net.JoinHostPort("127.0.0.2", "443"); got != want { + t.Fatalf("got addr %q, want %q", got, want) + } + return nil, nil + }, + }, + { + permF: func(n int) []int { + return []int{2} + }, + dialF: func(ctx context.Context, network, addr string) (net.Conn, error) { + if got, want := addr, net.JoinHostPort("127.0.0.3", "443"); got != want { + t.Fatalf("got addr %q, want %q", got, want) + } + return nil, nil + }, + }, + } + + origFunc := randPerm + defer func() { + randPerm = origFunc + }() + + for _, tc := range cases { + t.Run("", func(t *testing.T) { + randPerm = tc.permF + if _, err := DialContextWithDNSCache(resolver, tc.dialF)(context.Background(), "tcp", "play.min.io:443"); err != nil { + t.Fatalf("err: %s", err) + } + }) + } + +} + +func TestDialContextWithDNSCacheRand(t *testing.T) { + rand.Seed(time.Now().UTC().UnixNano()) + defer func() { + rand.Seed(1) + }() + + resolver := &DNSCache{ + cache: map[string][]string{ + "play.min.io": { + "127.0.0.1", + "127.0.0.2", + "127.0.0.3", + }, + }, + } + + count := make(map[string]int) + dialF := func(ctx context.Context, network, addr string) (net.Conn, error) { + count[addr]++ + return nil, nil + } + + for i := 0; i < 100; i++ { + if _, err := DialContextWithDNSCache(resolver, dialF)(context.Background(), "tcp", "play.min.io:443"); err != nil { + t.Fatalf("err: %s", err) + } + } + + for _, c := range count { + got := float32(c) / float32(100) + if got < float32(0.2) { + t.Fatalf("expected 0.2 rate got %f", got) + } + } +} + +// Verify without port Dial fails, Go stdlib net.Dial expects port +func TestDialContextWithDNSCacheScenario1(t *testing.T) { + resolver := testDNSCache(t) + if _, err := DialContextWithDNSCache(resolver, nil)(context.Background(), "tcp", "play.min.io"); err == nil { + t.Fatalf("expect to fail") // expected port + } +} + +// Verify if the host lookup function failed to return addresses +func TestDialContextWithDNSCacheScenario2(t *testing.T) { + res := testDNSCache(t) + originalFunc := res.lookupHostFn + defer func() { + res.lookupHostFn = originalFunc + }() + + res.lookupHostFn = func(ctx context.Context, host string) ([]string, error) { + return nil, fmt.Errorf("err") + } + + if _, err := DialContextWithDNSCache(res, nil)(context.Background(), "tcp", "min.io:443"); err == nil { + t.Fatalf("exect to fail") + } +} + +// Verify we always return the first error from net.Dial failure +func TestDialContextWithDNSCacheScenario3(t *testing.T) { + resolver := &DNSCache{ + cache: map[string][]string{ + "min.io": { + "1.1.1.1", + "2.2.2.2", + "3.3.3.3", + }, + }, + } + + origFunc := randPerm + randPerm = func(n int) []int { + return []int{0, 1, 2} + } + defer func() { + randPerm = origFunc + }() + + want := errors.New("error1") + dialF := func(ctx context.Context, network, addr string) (net.Conn, error) { + if addr == net.JoinHostPort("1.1.1.1", "443") { + return nil, want // first error should be returned + } + if addr == net.JoinHostPort("2.2.2.2", "443") { + return nil, fmt.Errorf("error2") + } + if addr == net.JoinHostPort("3.3.3.3", "443") { + return nil, fmt.Errorf("error3") + } + return nil, nil + } + + _, got := DialContextWithDNSCache(resolver, dialF)(context.Background(), "tcp", "min.io:443") + if got != want { + t.Fatalf("got error %v, want %v", got, want) + } +} diff --git a/cmd/server-main.go b/cmd/server-main.go index a7af0e596..f2b47e515 100644 --- a/cmd/server-main.go +++ b/cmd/server-main.go @@ -20,6 +20,7 @@ import ( "context" "errors" "fmt" + "math/rand" "net" "os" "os/signal" @@ -361,6 +362,11 @@ func initAllSubsystems(ctx context.Context, newObject ObjectLayer) (err error) { // serverMain handler called for 'minio server' command. func serverMain(ctx *cli.Context) { + rand.Seed(time.Now().UTC().UnixNano()) + + globalDNSCache = xhttp.NewDNSCache(3*time.Second, 10*time.Second) + defer globalDNSCache.Stop() + signal.Notify(globalOSSignalCh, os.Interrupt, syscall.SIGTERM, syscall.SIGQUIT) go handleSignals() diff --git a/cmd/test-utils_test.go b/cmd/test-utils_test.go index 1f7df4426..d207c9674 100644 --- a/cmd/test-utils_test.go +++ b/cmd/test-utils_test.go @@ -58,6 +58,7 @@ import ( "github.com/minio/minio-go/v7/pkg/signer" "github.com/minio/minio/cmd/config" "github.com/minio/minio/cmd/crypto" + xhttp "github.com/minio/minio/cmd/http" "github.com/minio/minio/cmd/logger" "github.com/minio/minio/pkg/auth" "github.com/minio/minio/pkg/bucket/policy" @@ -99,6 +100,8 @@ func init() { logger.Disable = true + globalDNSCache = xhttp.NewDNSCache(3*time.Second, 10*time.Second) + initHelp() resetTestGlobals() diff --git a/cmd/utils.go b/cmd/utils.go index 19d1f7f58..deaed5e8f 100644 --- a/cmd/utils.go +++ b/cmd/utils.go @@ -468,7 +468,7 @@ func newInternodeHTTPTransport(tlsConfig *tls.Config, dialTimeout time.Duration) // https://golang.org/pkg/net/http/#Transport documentation tr := &http.Transport{ Proxy: http.ProxyFromEnvironment, - DialContext: xhttp.NewInternodeDialContext(dialTimeout), + DialContext: xhttp.DialContextWithDNSCache(globalDNSCache, xhttp.NewInternodeDialContext(dialTimeout)), MaxIdleConnsPerHost: 1024, IdleConnTimeout: 15 * time.Second, ResponseHeaderTimeout: 3 * time.Minute, // Set conservative timeouts for MinIO internode. @@ -496,7 +496,7 @@ func newCustomHTTPProxyTransport(tlsConfig *tls.Config, dialTimeout time.Duratio // https://golang.org/pkg/net/http/#Transport documentation tr := &http.Transport{ Proxy: http.ProxyFromEnvironment, - DialContext: xhttp.NewCustomDialContext(dialTimeout), + DialContext: xhttp.DialContextWithDNSCache(globalDNSCache, xhttp.NewInternodeDialContext(dialTimeout)), MaxIdleConnsPerHost: 1024, IdleConnTimeout: 15 * time.Second, ResponseHeaderTimeout: 30 * time.Minute, // Set larger timeouts for proxied requests. @@ -519,7 +519,7 @@ func newCustomHTTPTransport(tlsConfig *tls.Config, dialTimeout time.Duration) fu // https://golang.org/pkg/net/http/#Transport documentation tr := &http.Transport{ Proxy: http.ProxyFromEnvironment, - DialContext: xhttp.NewCustomDialContext(dialTimeout), + DialContext: xhttp.DialContextWithDNSCache(globalDNSCache, xhttp.NewInternodeDialContext(dialTimeout)), MaxIdleConnsPerHost: 1024, IdleConnTimeout: 15 * time.Second, ResponseHeaderTimeout: 3 * time.Minute, // Set conservative timeouts for MinIO internode.