Connection Limiter
Posted on
作为服务提供商,我们肯定是希望越多客户使用我们的服务越好,为此我们不惜花大价钱雇程序员来搞高并发服务器编程,砸很多的钱买最好的服务器,但是总有些无良用户想要和我们作对,搞很多 HTTP
连接请求恶意占用服务器的资源,导致其他用户的服务体验下降,从而最终导致客户的流失,这种情况我们肯定是希望极力避免的。
一般来说,正常的客户端(人为操作)不会在短时间内对同一服务发送过多的请求,只有想要实施恶意攻击行为的客户端(例如爬虫)才会同时发送很多请求来占用服务器的资源。为了避免这种情况的发生,我们需要限制同一个 IP
地址的请求数量。
实现连接限制的思路比较简单,我们通过一个 map
来记录来自同一个 IP
地址的 HTTP
请求的个数,如果在正常范围内,则给予该客户端正常的服务,如果超过上限,此时该客户端被怀疑正在进行爬虫之类的非善意行为,对此我们返回一个错误并拒绝服务该请求。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
| type ConnLimiter struct { sync.Mutex connections map[string]int64 maxConnections int64 totalConnections int64 next http.Handler errHandler ErrorHandler }
func New(next http.Handler, maxConnections int64) (*ConnLimiter, error) { cl := &ConnLimiter { maxConnections: maxConnections, connections: make(map[string]int64), next: next, } if cl.errHandler == nil { cl.errHandler = defaultErrHandler } return cl, nil }
|
因为我们是对同一个IP
地址设置连接限制,故对于每个请求,通过读取请求 http.Request
中的 RemoteAddr
来获取 IP
地址。但是上有政策,下游对策,很多写爬虫的人通过代理 IP
池来规避这种审查,这里我们先不讨论如何处理这种情况。
1 2 3 4 5 6 7
| func extractClientIP(req *http.Request) (string, int64, error) { vals := strings.SplitN(req.RemoteAddr, ":", 2) if len(vals[0]) == 0 { return "", 0, fmt.Errorf("Failed to parse client IP: %v", req.RemoteAddr) } return vals[0], 1, nil }
|
我们还需要定义当同一个 IP
地址的连接数量超过最高连接数量的错误 MaxConnError
以及处理该错误的方法。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
| type MaxConnError struct { max int64 }
func (e *MaxConnError) Error() string { return fmt.Sprintf("max connections reached: %d", m.max) }
type ConnErrHandler struct {}
func (h *ConnErrHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { statusCode := http.StatusInternalServerError w.WriteHeader(statusCode) w.Write([]byte(http.StatusText(statusCode))) }
var defaultErrHandler = &ConnErrHandler{}
|
acquire
和 release
方法在加锁的条件下对 map
进行操作,这是因为在并发情况下对同一个数据进行读写操作时,会发生数据竞争的情况,所以需要使用 sync.Mutex
来对数据读写进行保护。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
| func (cl *ConnLimiter) acquire(token string, amount int64) error { cl.Lock() defer cl.Unlock()
connections := cl.connections[token] if connections >= cl.maxConnections { return &MaxConnError{max: cl.maxConnections} }
cl.connections[token] += amount cl.totalConnections += amount return nil }
func (cl *ConnLimiter) release(token string, amount int64) { cl.Lock() defer cl.Unlock()
cl.connections[token] -= amount cl.totalConnections -= amount if cl.connections[token] == 0 { delete(cl.connections, token) } }
|
在提供服务( ServeHTTP
)之前,我们通过 acquire
判断该客户端是否有资格获得服务,如果有资格则使用正常的 handler
来处理,否则则使用 errHandler
来处理,最后服务完了需要将相关资源释放( release
)。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
| func (cl *ConnLimiter) ServeHTTP(w http.ResponseWriter, r *http.Request) { token, amount, err := extractClientIP(r) if err != nil { log.Errorf("failed to extract source of the connection: %v", err) cl.errHandler.ServeHTTP(w, r) return } if err = cl.acquire(token, amount); err != nil { log.Debugf("limiting request source %s: %v", token, err) cl.errHandler.ServeHTTP(w, r) return } defer cl.release(token, amount)
cl.next.ServeHTTP(w, r) }
func (cl *ConnLimiter) Wrap(h http.Handler) { cl.next = h }
|