limit.go (view raw)
1package main
2
3import (
4 "log"
5 "net"
6 "net/http"
7 "sync"
8
9 "golang.org/x/time/rate"
10)
11
12// Create a map to hold the rate limiters for each visitor and a mutex.
13var visitors = make(map[string]*rate.Limiter)
14var mu sync.Mutex
15
16// Retrieve and return the rate limiter for the current visitor if it
17// already exists. Otherwise create a new rate limiter and add it to
18// the visitors map, using the IP address as the key.
19func getVisitor(ip string) *rate.Limiter {
20 mu.Lock()
21 defer mu.Unlock()
22
23 limiter, exists := visitors[ip]
24 if !exists {
25 limiter = rate.NewLimiter(1, 3)
26 visitors[ip] = limiter
27 }
28
29 return limiter
30}
31
32func limit(next http.Handler) http.Handler {
33 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
34 // Get the IP address for the current user.
35 ip, _, err := net.SplitHostPort(r.RemoteAddr)
36 if err != nil {
37 log.Println(err.Error())
38 http.Error(w, "Internal Server Error", http.StatusInternalServerError)
39 return
40 }
41
42 // Call the getVisitor function to retreive the rate limiter for
43 // the current user.
44 limiter := getVisitor(ip)
45 if limiter.Allow() == false {
46 http.Error(w, http.StatusText(429), http.StatusTooManyRequests)
47 return
48 }
49
50 next.ServeHTTP(w, r)
51 })
52}