Connection Limiter

作为服务提供商,我们肯定是希望越多客户使用我们的服务越好,为此我们不惜花大价钱雇程序员来搞高并发服务器编程,砸很多的钱买最好的服务器,但是总有些无良用户想要和我们作对,搞很多 HTTP 连接请求恶意占用服务器的资源,导致其他用户的服务体验下降,从而最终导致客户的流失,这种情况我们肯定是希望极力避免的。

一般来说,正常的客户端(人为操作)不会在短时间内对同一服务发送过多的请求,只有想要实施恶意攻击行为的客户端(例如爬虫)才会同时发送很多请求来占用服务器的资源。为了避免这种情况的发生,我们需要限制同一个 IP 地址的请求数量。

实现连接限制的思路比较简单,我们通过一个 map 来记录来自同一个 IP 地址的 HTTP 请求的个数,如果在正常范围内,则给予该客户端正常的服务,如果超过上限,此时该客户端被怀疑正在进行爬虫之类的非善意行为,对此我们返回一个错误并拒绝服务该请求。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
type ConnLimiter struct {
sync.Mutex
connections map[string]int64
maxConnections int64
totalConnections int64
next http.Handler
errHandler ErrorHandler
}

func New(next http.Handler, maxConnections int64) (*ConnLimiter, error) {
cl := &ConnLimiter {
maxConnections: maxConnections,
connections: make(map[string]int64),
next: next,
}
if cl.errHandler == nil {
cl.errHandler = defaultErrHandler
}
return cl, nil
}

因为我们是对同一个IP 地址设置连接限制,故对于每个请求,通过读取请求 http.Request 中的 RemoteAddr 来获取 IP 地址。但是上有政策,下游对策,很多写爬虫的人通过代理 IP 池来规避这种审查,这里我们先不讨论如何处理这种情况。

1
2
3
4
5
6
7
func extractClientIP(req *http.Request) (string, int64, error) {
vals := strings.SplitN(req.RemoteAddr, ":", 2)
if len(vals[0]) == 0 {
return "", 0, fmt.Errorf("Failed to parse client IP: %v", req.RemoteAddr)
}
return vals[0], 1, nil
}

我们还需要定义当同一个 IP 地址的连接数量超过最高连接数量的错误 MaxConnError 以及处理该错误的方法。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
type MaxConnError struct {
max int64
}

func (e *MaxConnError) Error() string {
return fmt.Sprintf("max connections reached: %d", m.max)
}

type ConnErrHandler struct {}

func (h *ConnErrHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
statusCode := http.StatusInternalServerError
w.WriteHeader(statusCode)
w.Write([]byte(http.StatusText(statusCode)))
}

var defaultErrHandler = &ConnErrHandler{}

acquirerelease 方法在加锁的条件下对 map 进行操作,这是因为在并发情况下对同一个数据进行读写操作时,会发生数据竞争的情况,所以需要使用 sync.Mutex 来对数据读写进行保护。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
func (cl *ConnLimiter) acquire(token string, amount int64) error {
cl.Lock()
defer cl.Unlock()

connections := cl.connections[token]
if connections >= cl.maxConnections {
return &MaxConnError{max: cl.maxConnections}
}

cl.connections[token] += amount
cl.totalConnections += amount
return nil
}

func (cl *ConnLimiter) release(token string, amount int64) {
cl.Lock()
defer cl.Unlock()

cl.connections[token] -= amount
cl.totalConnections -= amount

if cl.connections[token] == 0 {
delete(cl.connections, token)
}
}

在提供服务( ServeHTTP )之前,我们通过 acquire 判断该客户端是否有资格获得服务,如果有资格则使用正常的 handler 来处理,否则则使用 errHandler 来处理,最后服务完了需要将相关资源释放( release )。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
func (cl *ConnLimiter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
token, amount, err := extractClientIP(r)
if err != nil {
log.Errorf("failed to extract source of the connection: %v", err)
cl.errHandler.ServeHTTP(w, r)
return
}
if err = cl.acquire(token, amount); err != nil {
log.Debugf("limiting request source %s: %v", token, err)
cl.errHandler.ServeHTTP(w, r)
return
}
defer cl.release(token, amount)

cl.next.ServeHTTP(w, r)
}

func (cl *ConnLimiter) Wrap(h http.Handler) {
cl.next = h
}
Pieces of Valuable Programming Knowledges