// Mirror of https://github.com/Dvorinka/MyClubServer.git
// (synced 2026-06-03 18:22:57 +00:00; 142 lines, 2.7 KiB, Go).

package middleware
|
|
|
|
import (
|
|
"net"
|
|
"net/http"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
// limiterKey is the map key under which a rate-limit counter is stored:
// one counter per client address and request path combination.
type limiterKey struct {
	// IP is the client address; Path is the request path being limited.
	IP, Path string
}
|
|
|
|
// counter tracks how many requests one limiterKey has made in its current
// window and when that window ends.
type counter struct {
	// Count is the number of requests admitted in the current window.
	Count int
	// ExpiresAt is the instant after which the window is considered over.
	ExpiresAt time.Time
}
|
|
|
|
// in-memory store (process local)
|
|
var (
|
|
limitStore = struct {
|
|
sync.Mutex
|
|
m map[limiterKey]*counter
|
|
}{m: make(map[limiterKey]*counter)}
|
|
)
|
|
|
|
// RateLimit returns a middleware that limits requests to `max` per given `window` per IP and path.
|
|
func RateLimit(max int, window time.Duration) gin.HandlerFunc {
|
|
if max <= 0 {
|
|
max = 10
|
|
}
|
|
if window <= 0 {
|
|
window = time.Minute
|
|
}
|
|
return func(c *gin.Context) {
|
|
ip := clientIP(c.Request)
|
|
key := limiterKey{IP: ip, Path: c.FullPath()}
|
|
|
|
limitStore.Lock()
|
|
ct, ok := limitStore.m[key]
|
|
now := time.Now()
|
|
if !ok || now.After(ct.ExpiresAt) {
|
|
ct = &counter{Count: 0, ExpiresAt: now.Add(window)}
|
|
limitStore.m[key] = ct
|
|
}
|
|
if ct.Count >= max {
|
|
retryAfter := int(ct.ExpiresAt.Sub(now).Seconds())
|
|
limitStore.Unlock()
|
|
c.Header("Retry-After", strconvItoaSafe(retryAfter))
|
|
c.JSON(http.StatusTooManyRequests, gin.H{"error": "Příliš mnoho požadavků, zkuste to prosím později."})
|
|
c.Abort()
|
|
return
|
|
}
|
|
ct.Count++
|
|
limitStore.Unlock()
|
|
|
|
c.Next()
|
|
}
|
|
}
|
|
|
|
func clientIP(r *http.Request) string {
|
|
// Prefer X-Forwarded-For if present (first IP)
|
|
if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
|
|
if p := parseFirstIP(xff); p != "" {
|
|
return p
|
|
}
|
|
}
|
|
// Fallback to RemoteAddr
|
|
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
|
if err == nil && host != "" {
|
|
return host
|
|
}
|
|
return r.RemoteAddr
|
|
}
|
|
|
|
func parseFirstIP(s string) string {
|
|
for _, part := range splitAndTrim(s, ',') {
|
|
ip := net.ParseIP(part)
|
|
if ip != nil {
|
|
return ip.String()
|
|
}
|
|
}
|
|
return ""
|
|
}
|
|
|
|
func splitAndTrim(s string, sep rune) []string {
|
|
var out []string
|
|
cur := make([]rune, 0, len(s))
|
|
for _, ch := range s {
|
|
if ch == sep {
|
|
part := string(cur)
|
|
cur = cur[:0]
|
|
if t := trimSpace(part); t != "" {
|
|
out = append(out, t)
|
|
}
|
|
continue
|
|
}
|
|
cur = append(cur, ch)
|
|
}
|
|
if t := trimSpace(string(cur)); t != "" {
|
|
out = append(out, t)
|
|
}
|
|
return out
|
|
}
|
|
|
|
// trimSpace returns s without leading and trailing spaces and tabs. Other
// whitespace (such as newlines) is left in place.
func trimSpace(s string) string {
	isBlank := func(b byte) bool { return b == ' ' || b == '\t' }

	lo := 0
	for ; lo < len(s); lo++ {
		if !isBlank(s[lo]) {
			break
		}
	}

	hi := len(s)
	for ; hi > lo; hi-- {
		if !isBlank(s[hi-1]) {
			break
		}
	}

	return s[lo:hi]
}
|
|
|
|
// strconvItoaSafe formats i as a base-10 string, with a leading '-' for
// negative values. It exists so this file does not need to import strconv
// for a single small header value.
func strconvItoaSafe(i int) string {
	if i == 0 {
		return "0"
	}

	// Work on the unsigned magnitude: negating the minimum int as a signed
	// value overflows back to itself, which used to yield just "-".
	neg := i < 0
	u := uint(i)
	if neg {
		u = -u
	}

	// 20 bytes hold the longest 64-bit value: 19 digits plus the sign.
	// Digits are written from the end so nothing has to be shifted.
	var buf [20]byte
	pos := len(buf)
	for u > 0 {
		pos--
		buf[pos] = byte('0' + u%10)
		u /= 10
	}
	if neg {
		pos--
		buf[pos] = '-'
	}
	return string(buf[pos:])
}
|