From fd2203b1b7f981a804b4bd8981d690f0d0e58174 Mon Sep 17 00:00:00 2001 From: Harshavardhana Date: Wed, 29 Apr 2015 18:01:49 -0700 Subject: [PATCH] Some more improvements to connection limit --- pkg/api/quota/conn_limit.go | 22 ++++++++++++++-------- pkg/api/quota/errors.go | 2 +- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/pkg/api/quota/conn_limit.go b/pkg/api/quota/conn_limit.go index 913e2142a..0089013d6 100644 --- a/pkg/api/quota/conn_limit.go +++ b/pkg/api/quota/conn_limit.go @@ -32,20 +32,24 @@ type connLimit struct { limit int } +func (c *connLimit) IsLimitExceeded(ip uint32) bool { + if c.connections[ip] >= c.limit { + return true + } + return false +} + func (c *connLimit) GetUsed(ip uint32) int { return c.connections[ip] } -func (c *connLimit) TestAndAdd(ip uint32) bool { +func (c *connLimit) Add(ip uint32) { c.Lock() defer c.Unlock() - count, _ := c.connections[ip] - if count >= c.limit { - return false - } + count := c.connections[ip] count = count + 1 c.connections[ip] = count - return true + return } func (c *connLimit) Remove(ip uint32) { @@ -64,11 +68,13 @@ func (c *connLimit) Remove(ip uint32) { func (c *connLimit) ServeHTTP(w http.ResponseWriter, req *http.Request) { host, _, _ := net.SplitHostPort(req.RemoteAddr) longIP := longIP{net.ParseIP(host)}.IptoUint32() - if !c.TestAndAdd(longIP) { + if c.IsLimitExceeded(longIP) { hosts, _ := net.LookupAddr(uint32ToIP(longIP).String()) - log.Debug.Printf("Offending Host: %s, ConnectionsUSED: %d\n", hosts, c.GetUsed(longIP)) + log.Debug.Printf("Connection limit reached - Host: %s, Total Connections: %d\n", hosts, c.GetUsed(longIP)) writeErrorResponse(w, req, ConnectionLimitExceeded, req.URL.Path) + return } + c.Add(longIP) defer c.Remove(longIP) c.handler.ServeHTTP(w, req) } diff --git a/pkg/api/quota/errors.go b/pkg/api/quota/errors.go index 8fe91a965..c4ad019e5 100644 --- a/pkg/api/quota/errors.go +++ b/pkg/api/quota/errors.go @@ -58,11 +58,11 @@ const ( func writeErrorResponse(w http.ResponseWriter, req *http.Request, errorType int, resource string) { error := getErrorCode(errorType) errorResponse := getErrorResponse(error, resource) + encodedErrorResponse := encodeErrorResponse(errorResponse) // set headers writeErrorHeaders(w) w.WriteHeader(error.HTTPStatusCode) // write body - encodedErrorResponse := encodeErrorResponse(errorResponse) w.Write(encodedErrorResponse) }