package middleware

import (
	"net"
	"net/http"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
)

// limiterKey identifies a client + route pair.
type limiterKey struct {
	IP   string
	Path string
}

// counter tracks how many requests a key has made in its current window and
// when that window ends.
type counter struct {
	Count     int
	ExpiresAt time.Time
}

// limitSweepInterval is the minimum time between passes that evict expired
// counters from limitStore. Without eviction the map would keep one entry per
// distinct client/route pair for the whole life of the process.
const limitSweepInterval = time.Minute

// limitStore is the in-memory (process-local) counter store shared by every
// RateLimit middleware. Limits are therefore enforced per process, not across
// replicas.
var limitStore = struct {
	sync.Mutex
	m         map[limiterKey]*counter
	lastSweep time.Time // guarded by the embedded Mutex
}{m: make(map[limiterKey]*counter)}

// RateLimit returns a middleware that limits requests to `max` per given
// `window` per client IP and route.
//
// A non-positive max defaults to 10 and a non-positive window to one minute.
// Requests over the limit are aborted with 429 Too Many Requests, a JSON error
// body, and a Retry-After header holding the whole seconds (rounded up, at
// least 1) until the client's window resets.
//
// The client IP is taken from gin's Context.ClientIP, which only honors
// X-Forwarded-For / X-Real-IP when the request comes from a proxy the engine
// trusts (see gin.Engine.SetTrustedProxies). Trusting those headers
// unconditionally would let any client present a new "IP" on every request
// and bypass the limit.
func RateLimit(max int, window time.Duration) gin.HandlerFunc {
	if max <= 0 {
		max = 10
	}
	if window <= 0 {
		window = time.Minute
	}
	return func(c *gin.Context) {
		// FullPath is the matched route pattern (e.g. "/users/:id"), so all
		// requests to one route share a counter; it is "" for unmatched routes.
		key := limiterKey{IP: c.ClientIP(), Path: c.FullPath()}

		allowed, wait := limitStoreTake(key, time.Now(), max, window)
		if !allowed {
			// Round up so a client honoring the header does not retry while
			// the current window is still open.
			retryAfter := int((wait + time.Second - 1) / time.Second)
			if retryAfter < 1 {
				retryAfter = 1
			}
			c.Header("Retry-After", strconvItoaSafe(retryAfter))
			// Message (Czech): "Too many requests, please try again later."
			c.JSON(http.StatusTooManyRequests, gin.H{"error": "Příliš mnoho požadavků, zkuste to prosím později."})
			c.Abort()
			return
		}
		c.Next()
	}
}

// limitStoreTake counts one request for key at time now against a budget of
// limit requests per window. It reports whether the request is allowed; when
// it is not, it also returns how long remains until the key's window ends.
// A key with no counter, or whose window has expired, starts a fresh window
// at now. Expired counters for all keys are evicted at most once per
// limitSweepInterval.
func limitStoreTake(key limiterKey, now time.Time, limit int, window time.Duration) (bool, time.Duration) {
	limitStore.Lock()
	defer limitStore.Unlock()

	if now.Sub(limitStore.lastSweep) >= limitSweepInterval {
		// Deleting during range is permitted for Go maps.
		for k, ct := range limitStore.m {
			if now.After(ct.ExpiresAt) {
				delete(limitStore.m, k)
			}
		}
		limitStore.lastSweep = now
	}

	ct, ok := limitStore.m[key]
	if !ok || now.After(ct.ExpiresAt) {
		ct = &counter{ExpiresAt: now.Add(window)}
		limitStore.m[key] = ct
	}
	if ct.Count >= limit {
		return false, ct.ExpiresAt.Sub(now)
	}
	ct.Count++
	return true, 0
}

// clientIP returns the first valid IP in the X-Forwarded-For header if one is
// present, otherwise the host part of r.RemoteAddr (or RemoteAddr verbatim if
// it cannot be split).
//
// NOTE: X-Forwarded-For is trusted unconditionally here and its leftmost entry
// is fully client-controlled, so the result must not be used for security
// decisions such as rate limiting. RateLimit uses gin's Context.ClientIP
// instead.
func clientIP(r *http.Request) string {
	// Prefer X-Forwarded-For if present (first valid IP).
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		if p := parseFirstIP(xff); p != "" {
			return p
		}
	}
	// Fallback to RemoteAddr.
	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err == nil && host != "" {
		return host
	}
	return r.RemoteAddr
}

// parseFirstIP returns the canonical string form of the first comma-separated
// entry in s that parses as an IP address, or "" if none does.
func parseFirstIP(s string) string {
	for _, part := range splitAndTrim(s, ',') {
		if ip := net.ParseIP(part); ip != nil {
			return ip.String()
		}
	}
	return ""
}

// splitAndTrim splits s on sep, trims spaces and tabs from each piece, and
// drops pieces that end up empty. It returns nil when nothing remains.
func splitAndTrim(s string, sep rune) []string {
	var out []string
	for _, part := range strings.Split(s, string(sep)) {
		if t := trimSpace(part); t != "" {
			out = append(out, t)
		}
	}
	return out
}

// trimSpace removes leading and trailing ASCII spaces and tabs (only those;
// other whitespace such as newlines is kept).
func trimSpace(s string) string {
	return strings.Trim(s, " \t")
}

// strconvItoaSafe returns the decimal representation of i. It is a thin
// wrapper over strconv.Itoa, which, unlike the previous hand-rolled version,
// handles the minimum int value correctly.
func strconvItoaSafe(i int) string {
	return strconv.Itoa(i)
}