missing file
alex wennerberg alex@alexwennerberg.com
Sun, 06 Dec 2020 01:27:13 -0800
1 files changed,
52 insertions(+),
0 deletions(-)
jump to
A
limit.go
@@ -0,0 +1,52 @@
+package main + +import ( + "log" + "net" + "net/http" + "sync" + + "golang.org/x/time/rate" +) + +// Create a map to hold the rate limiters for each visitor and a mutex. +var visitors = make(map[string]*rate.Limiter) +var mu sync.Mutex + +// Retrieve and return the rate limiter for the current visitor if it +// already exists. Otherwise create a new rate limiter and add it to +// the visitors map, using the IP address as the key. +func getVisitor(ip string) *rate.Limiter { + mu.Lock() + defer mu.Unlock() + + limiter, exists := visitors[ip] + if !exists { + limiter = rate.NewLimiter(1, 3) + visitors[ip] = limiter + } + + return limiter +} + +func limit(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Get the IP address for the current user. + ip, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + log.Println(err.Error()) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + return + } + + // Call the getVisitor function to retreive the rate limiter for + // the current user. + limiter := getVisitor(ip) + if limiter.Allow() == false { + http.Error(w, http.StatusText(429), http.StatusTooManyRequests) + return + } + + next.ServeHTTP(w, r) + }) +}