mirror of
https://github.com/Dvorinka/Containr.git
synced 2026-06-03 20:12:58 +00:00
small fix, don't worry about it
This commit is contained in:
@@ -0,0 +1,367 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"containr/internal/database"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// Logger middleware
|
||||
func Logger() gin.HandlerFunc {
|
||||
return gin.LoggerWithFormatter(func(param gin.LogFormatterParams) string {
|
||||
return fmt.Sprintf("%s - [%s] \"%s %s %s %d %s \"%s\" %s\"\n",
|
||||
param.ClientIP,
|
||||
param.TimeStamp.Format(time.RFC1123),
|
||||
param.Method,
|
||||
param.Path,
|
||||
param.Request.Proto,
|
||||
param.StatusCode,
|
||||
param.Latency,
|
||||
param.Request.UserAgent(),
|
||||
param.ErrorMessage,
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
// Recovery middleware
|
||||
func Recovery() gin.HandlerFunc {
|
||||
return gin.Recovery()
|
||||
}
|
||||
|
||||
// SecurityHeaders adds secure default HTTP response headers.
|
||||
func SecurityHeaders() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Header("X-Content-Type-Options", "nosniff")
|
||||
c.Header("X-Frame-Options", "DENY")
|
||||
c.Header("Referrer-Policy", "strict-origin-when-cross-origin")
|
||||
c.Header("X-XSS-Protection", "1; mode=block")
|
||||
|
||||
if strings.EqualFold(c.Request.Header.Get("X-Forwarded-Proto"), "https") || c.Request.TLS != nil {
|
||||
c.Header("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// RequestID middleware adds a unique request ID to each request
|
||||
func RequestID() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
requestID := c.GetHeader("X-Request-ID")
|
||||
if requestID == "" {
|
||||
requestID = uuid.New().String()
|
||||
}
|
||||
c.Set("request_id", requestID)
|
||||
c.Header("X-Request-ID", requestID)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// Auth middleware for JWT authentication
|
||||
func Auth(jwtSecret string) gin.HandlerFunc {
|
||||
sessionVerifier := newBetterAuthSessionVerifier()
|
||||
|
||||
return func(c *gin.Context) {
|
||||
tokenString, tokenErr, hasToken := extractJWTToken(c)
|
||||
if tokenString != "" {
|
||||
if claims, valid := validateJWTClaims(tokenString, jwtSecret); valid {
|
||||
userIDClaim, exists := claims["user_id"]
|
||||
if exists {
|
||||
userID := strings.TrimSpace(fmt.Sprint(userIDClaim))
|
||||
if _, err := uuid.Parse(userID); err == nil {
|
||||
email := ""
|
||||
if emailClaim, ok := claims["email"]; ok && emailClaim != nil {
|
||||
email = strings.TrimSpace(fmt.Sprint(emailClaim))
|
||||
}
|
||||
c.Set("user_id", userID)
|
||||
c.Set("email", email)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
tokenErr = "Invalid token claims"
|
||||
} else if tokenErr == "" {
|
||||
tokenErr = "Invalid token"
|
||||
}
|
||||
}
|
||||
|
||||
if sessionVerifier != nil {
|
||||
if userID, email, ok := sessionVerifier.resolveUser(c); ok {
|
||||
c.Set("user_id", userID)
|
||||
c.Set("email", email)
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if tokenErr != "" {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": tokenErr})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
if hasToken {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token"})
|
||||
} else {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "Authentication required"})
|
||||
}
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
|
||||
type betterAuthSessionVerifier struct {
|
||||
internalURL string
|
||||
internalToken string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
type betterAuthSessionUser struct {
|
||||
ID string `json:"id"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
Image string `json:"image"`
|
||||
}
|
||||
|
||||
type betterAuthSessionResponse struct {
|
||||
Authenticated bool `json:"authenticated"`
|
||||
User betterAuthSessionUser `json:"user"`
|
||||
}
|
||||
|
||||
func newBetterAuthSessionVerifier() *betterAuthSessionVerifier {
|
||||
internalURL := strings.TrimSpace(os.Getenv("BETTER_AUTH_INTERNAL_URL"))
|
||||
if internalURL == "" {
|
||||
internalURL = "http://127.0.0.1:3001/internal/session"
|
||||
}
|
||||
internalToken := strings.TrimSpace(os.Getenv("BETTER_AUTH_INTERNAL_TOKEN"))
|
||||
|
||||
if internalToken == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &betterAuthSessionVerifier{
|
||||
internalURL: internalURL,
|
||||
internalToken: internalToken,
|
||||
client: &http.Client{
|
||||
Timeout: 3 * time.Second,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (v *betterAuthSessionVerifier) resolveUser(c *gin.Context) (string, string, bool) {
|
||||
request, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, v.internalURL, nil)
|
||||
if err != nil {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
request.Header.Set("X-Containr-Auth-Internal", v.internalToken)
|
||||
copyHeaderIfPresent(c.Request, request, "Cookie")
|
||||
copyHeaderIfPresent(c.Request, request, "User-Agent")
|
||||
copyHeaderIfPresent(c.Request, request, "X-Forwarded-For")
|
||||
copyHeaderIfPresent(c.Request, request, "X-Forwarded-Proto")
|
||||
|
||||
response, err := v.client.Do(request)
|
||||
if err != nil {
|
||||
return "", "", false
|
||||
}
|
||||
defer response.Body.Close()
|
||||
|
||||
if response.StatusCode != http.StatusOK {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
var payload betterAuthSessionResponse
|
||||
if err := json.NewDecoder(response.Body).Decode(&payload); err != nil {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
if !payload.Authenticated {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
email := strings.ToLower(strings.TrimSpace(payload.User.Email))
|
||||
if email == "" {
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
localUserID, err := ensureLocalUserRecord(c, payload.User)
|
||||
if err != nil {
|
||||
log.Printf("Failed to map Better Auth user to local user: %v", err)
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
return localUserID, email, true
|
||||
}
|
||||
|
||||
func ensureLocalUserRecord(c *gin.Context, user betterAuthSessionUser) (string, error) {
|
||||
dbValue, ok := c.Get("db")
|
||||
if !ok || dbValue == nil {
|
||||
return "", fmt.Errorf("database context missing")
|
||||
}
|
||||
|
||||
db, ok := dbValue.(*database.DB)
|
||||
if !ok || db == nil {
|
||||
return "", fmt.Errorf("invalid database context")
|
||||
}
|
||||
|
||||
email := strings.ToLower(strings.TrimSpace(user.Email))
|
||||
if email == "" {
|
||||
return "", fmt.Errorf("email is required")
|
||||
}
|
||||
|
||||
name := strings.TrimSpace(user.Name)
|
||||
if name == "" {
|
||||
name = "Containr User"
|
||||
}
|
||||
|
||||
avatarURL := strings.TrimSpace(user.Image)
|
||||
|
||||
var localUserID string
|
||||
err := db.QueryRow(`SELECT id FROM users WHERE email = $1`, email).Scan(&localUserID)
|
||||
switch {
|
||||
case err == nil:
|
||||
_, _ = db.Exec(`
|
||||
UPDATE users
|
||||
SET name = $1,
|
||||
avatar_url = CASE WHEN $2 = '' THEN avatar_url ELSE $2 END,
|
||||
updated_at = NOW()
|
||||
WHERE id = $3
|
||||
`, name, avatarURL, localUserID)
|
||||
return localUserID, nil
|
||||
case !errors.Is(err, sql.ErrNoRows):
|
||||
return "", err
|
||||
}
|
||||
|
||||
hashedPassword, hashErr := bcrypt.GenerateFromPassword([]byte(uuid.NewString()), bcrypt.DefaultCost)
|
||||
if hashErr != nil {
|
||||
return "", hashErr
|
||||
}
|
||||
|
||||
if err := db.QueryRow(`
|
||||
INSERT INTO users (email, password_hash, name, avatar_url)
|
||||
VALUES ($1, $2, $3, NULLIF($4, ''))
|
||||
RETURNING id
|
||||
`, email, string(hashedPassword), name, avatarURL).Scan(&localUserID); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return localUserID, nil
|
||||
}
|
||||
|
||||
func validateJWTClaims(tokenString, jwtSecret string) (jwt.MapClaims, bool) {
|
||||
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return []byte(jwtSecret), nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok || !token.Valid {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return claims, true
|
||||
}
|
||||
|
||||
func extractJWTToken(c *gin.Context) (string, string, bool) {
|
||||
authHeader := strings.TrimSpace(c.GetHeader("Authorization"))
|
||||
if authHeader != "" {
|
||||
parts := strings.Fields(authHeader)
|
||||
if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
|
||||
return "", "Invalid authorization header format", true
|
||||
}
|
||||
return strings.TrimSpace(parts[1]), "", true
|
||||
}
|
||||
|
||||
if strings.EqualFold(c.GetHeader("Upgrade"), "websocket") {
|
||||
token := strings.TrimSpace(c.Query("token"))
|
||||
if token == "" {
|
||||
return "", "", false
|
||||
}
|
||||
return token, "", true
|
||||
}
|
||||
|
||||
return "", "", false
|
||||
}
|
||||
|
||||
func copyHeaderIfPresent(src *http.Request, dst *http.Request, key string) {
|
||||
value := strings.TrimSpace(src.Header.Get(key))
|
||||
if value != "" {
|
||||
dst.Header.Set(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
// ErrorHandler middleware for consistent error handling
|
||||
func ErrorHandler() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Next()
|
||||
|
||||
// Check if there are any errors
|
||||
if len(c.Errors) > 0 {
|
||||
err := c.Errors.Last()
|
||||
log.Printf("Request error: %v", err)
|
||||
|
||||
// Return JSON error response
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"error": "Internal server error",
|
||||
"code": "INTERNAL_ERROR",
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// CORSMiddleware for CORS handling
|
||||
func CORSMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
c.Header("Access-Control-Allow-Credentials", "true")
|
||||
c.Header("Access-Control-Allow-Headers", "Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization, accept, origin, Cache-Control, X-Requested-With")
|
||||
c.Header("Access-Control-Allow-Methods", "POST, OPTIONS, GET, PUT, DELETE")
|
||||
|
||||
if c.Request.Method == "OPTIONS" {
|
||||
c.AbortWithStatus(204)
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// RequestBodyLimit enforces a maximum HTTP request body size.
|
||||
func RequestBodyLimit(maxBytes int64) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if maxBytes <= 0 {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
if c.Request.ContentLength > maxBytes {
|
||||
c.AbortWithStatusJSON(http.StatusRequestEntityTooLarge, gin.H{
|
||||
"error": "Request body too large",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if c.Request.Body != nil {
|
||||
c.Request.Body = http.MaxBytesReader(c.Writer, c.Request.Body, maxBytes)
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,142 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func TestSecurityHeaders(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(SecurityHeaders())
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if got := rec.Header().Get("X-Content-Type-Options"); got != "nosniff" {
|
||||
t.Fatalf("expected X-Content-Type-Options nosniff, got %q", got)
|
||||
}
|
||||
if got := rec.Header().Get("X-Frame-Options"); got != "DENY" {
|
||||
t.Fatalf("expected X-Frame-Options DENY, got %q", got)
|
||||
}
|
||||
if got := rec.Header().Get("Referrer-Policy"); got != "strict-origin-when-cross-origin" {
|
||||
t.Fatalf("expected Referrer-Policy strict-origin-when-cross-origin, got %q", got)
|
||||
}
|
||||
if got := rec.Header().Get("X-XSS-Protection"); got != "1; mode=block" {
|
||||
t.Fatalf("expected X-XSS-Protection header, got %q", got)
|
||||
}
|
||||
if got := rec.Header().Get("Strict-Transport-Security"); got != "" {
|
||||
t.Fatalf("expected no HSTS header for plain HTTP request, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecurityHeadersAddsHSTSWhenForwardedProtoHTTPS(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(SecurityHeaders())
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if got := rec.Header().Get("Strict-Transport-Security"); got == "" {
|
||||
t.Fatal("expected HSTS header for https-forwarded request")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestBodyLimitRejectsLargeRequest(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(RequestBodyLimit(8))
|
||||
router.POST("/test", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/test", strings.NewReader("0123456789"))
|
||||
req.ContentLength = 10
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusRequestEntityTooLarge {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusRequestEntityTooLarge, rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthRejectsNonUUIDUserIDClaim(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(Auth("secret"))
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.String(http.StatusOK, "ok")
|
||||
})
|
||||
|
||||
token := issueJWT(t, "secret", jwt.MapClaims{
|
||||
"user_id": "not-a-uuid",
|
||||
"email": "test@example.com",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, rec.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthStoresStringUserIDForValidClaims(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(Auth("secret"))
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
userID, _ := c.Get("user_id")
|
||||
c.String(http.StatusOK, fmt.Sprint(userID))
|
||||
})
|
||||
|
||||
expectedUserID := uuid.NewString()
|
||||
token := issueJWT(t, "secret", jwt.MapClaims{
|
||||
"user_id": expectedUserID,
|
||||
"email": "test@example.com",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
rec := httptest.NewRecorder()
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusOK, rec.Code)
|
||||
}
|
||||
if got := rec.Body.String(); got != expectedUserID {
|
||||
t.Fatalf("expected user_id %q, got %q", expectedUserID, got)
|
||||
}
|
||||
}
|
||||
|
||||
func issueJWT(t *testing.T, secret string, claims jwt.MapClaims) string {
|
||||
t.Helper()
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
signed, err := token.SignedString([]byte(secret))
|
||||
if err != nil {
|
||||
t.Fatalf("failed to sign token: %v", err)
|
||||
}
|
||||
return signed
|
||||
}
|
||||
@@ -0,0 +1,167 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// RateLimiter implements a token bucket rate limiter
|
||||
type RateLimiter struct {
|
||||
mu sync.RWMutex
|
||||
buckets map[string]*bucket
|
||||
rate int // requests per window
|
||||
window time.Duration // time window
|
||||
cleanupInterval time.Duration
|
||||
}
|
||||
|
||||
type bucket struct {
|
||||
tokens int
|
||||
lastReset time.Time
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewRateLimiter creates a new rate limiter
|
||||
func NewRateLimiter(rate int, window time.Duration) *RateLimiter {
|
||||
rl := &RateLimiter{
|
||||
buckets: make(map[string]*bucket),
|
||||
rate: rate,
|
||||
window: window,
|
||||
cleanupInterval: window * 2,
|
||||
}
|
||||
|
||||
// Start cleanup goroutine
|
||||
go rl.cleanup()
|
||||
|
||||
return rl
|
||||
}
|
||||
|
||||
// Allow checks if a request should be allowed
|
||||
func (rl *RateLimiter) Allow(key string) bool {
|
||||
rl.mu.RLock()
|
||||
b, exists := rl.buckets[key]
|
||||
rl.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
rl.mu.Lock()
|
||||
b = &bucket{
|
||||
tokens: rl.rate,
|
||||
lastReset: time.Now(),
|
||||
}
|
||||
rl.buckets[key] = b
|
||||
rl.mu.Unlock()
|
||||
}
|
||||
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
// Reset bucket if window has passed
|
||||
if time.Since(b.lastReset) > rl.window {
|
||||
b.tokens = rl.rate
|
||||
b.lastReset = time.Now()
|
||||
}
|
||||
|
||||
if b.tokens > 0 {
|
||||
b.tokens--
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// cleanup removes old buckets
|
||||
func (rl *RateLimiter) cleanup() {
|
||||
ticker := time.NewTicker(rl.cleanupInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
rl.mu.Lock()
|
||||
now := time.Now()
|
||||
for key, b := range rl.buckets {
|
||||
b.mu.Lock()
|
||||
if now.Sub(b.lastReset) > rl.cleanupInterval {
|
||||
delete(rl.buckets, key)
|
||||
}
|
||||
b.mu.Unlock()
|
||||
}
|
||||
rl.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// RateLimit returns a middleware that limits requests
|
||||
func RateLimit(rate int, window time.Duration) gin.HandlerFunc {
|
||||
limiter := NewRateLimiter(rate, window)
|
||||
|
||||
return func(c *gin.Context) {
|
||||
// Use IP address as key
|
||||
key := c.ClientIP()
|
||||
|
||||
// For authenticated requests, use user ID if available
|
||||
if userID, exists := c.Get("user_id"); exists {
|
||||
if uid, ok := userID.(string); ok && uid != "" {
|
||||
key = "user:" + uid
|
||||
}
|
||||
}
|
||||
|
||||
if !limiter.Allow(key) {
|
||||
c.JSON(http.StatusTooManyRequests, gin.H{
|
||||
"error": "rate_limit_exceeded",
|
||||
"message": "Too many requests. Please try again later.",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// RateLimitByIP limits requests by IP address
|
||||
func RateLimitByIP(rate int, window time.Duration) gin.HandlerFunc {
|
||||
limiter := NewRateLimiter(rate, window)
|
||||
|
||||
return func(c *gin.Context) {
|
||||
if !limiter.Allow(c.ClientIP()) {
|
||||
c.JSON(http.StatusTooManyRequests, gin.H{
|
||||
"error": "rate_limit_exceeded",
|
||||
"message": "Too many requests from your IP. Please try again later.",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// RateLimitByUser limits requests by authenticated user
|
||||
func RateLimitByUser(rate int, window time.Duration) gin.HandlerFunc {
|
||||
limiter := NewRateLimiter(rate, window)
|
||||
|
||||
return func(c *gin.Context) {
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
uid, ok := userID.(string)
|
||||
if !ok || uid == "" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
if !limiter.Allow("user:" + uid) {
|
||||
c.JSON(http.StatusTooManyRequests, gin.H{
|
||||
"error": "rate_limit_exceeded",
|
||||
"message": "Too many requests. Please slow down.",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,85 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestRateLimiter(t *testing.T) {
|
||||
limiter := NewRateLimiter(5, time.Second)
|
||||
|
||||
// Should allow first 5 requests
|
||||
for i := 0; i < 5; i++ {
|
||||
assert.True(t, limiter.Allow("test-key"), "Request %d should be allowed", i+1)
|
||||
}
|
||||
|
||||
// Should block 6th request
|
||||
assert.False(t, limiter.Allow("test-key"), "6th request should be blocked")
|
||||
|
||||
// Wait for window to reset
|
||||
time.Sleep(time.Second + 100*time.Millisecond)
|
||||
|
||||
// Should allow requests again
|
||||
assert.True(t, limiter.Allow("test-key"), "Request after reset should be allowed")
|
||||
}
|
||||
|
||||
func TestRateLimitMiddleware(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
router := gin.New()
|
||||
router.Use(RateLimit(2, time.Second))
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "success"})
|
||||
})
|
||||
|
||||
// First request should succeed
|
||||
req1 := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
w1 := httptest.NewRecorder()
|
||||
router.ServeHTTP(w1, req1)
|
||||
assert.Equal(t, http.StatusOK, w1.Code)
|
||||
|
||||
// Second request should succeed
|
||||
req2 := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
w2 := httptest.NewRecorder()
|
||||
router.ServeHTTP(w2, req2)
|
||||
assert.Equal(t, http.StatusOK, w2.Code)
|
||||
|
||||
// Third request should be rate limited
|
||||
req3 := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
w3 := httptest.NewRecorder()
|
||||
router.ServeHTTP(w3, req3)
|
||||
assert.Equal(t, http.StatusTooManyRequests, w3.Code)
|
||||
}
|
||||
|
||||
func TestRateLimitByUser(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
router := gin.New()
|
||||
router.Use(func(c *gin.Context) {
|
||||
c.Set("user_id", "user123")
|
||||
c.Next()
|
||||
})
|
||||
router.Use(RateLimitByUser(2, time.Second))
|
||||
router.GET("/test", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"message": "success"})
|
||||
})
|
||||
|
||||
// First two requests should succeed
|
||||
for i := 0; i < 2; i++ {
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusOK, w.Code)
|
||||
}
|
||||
|
||||
// Third request should be rate limited
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
w := httptest.NewRecorder()
|
||||
router.ServeHTTP(w, req)
|
||||
assert.Equal(t, http.StatusTooManyRequests, w.Code)
|
||||
}
|
||||
@@ -0,0 +1,161 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// ValidateContentType ensures the request has the correct content type
|
||||
func ValidateContentType(allowedTypes ...string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Skip validation for GET, DELETE, HEAD requests
|
||||
if c.Request.Method == http.MethodGet ||
|
||||
c.Request.Method == http.MethodDelete ||
|
||||
c.Request.Method == http.MethodHead {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
contentType := c.GetHeader("Content-Type")
|
||||
if contentType == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "invalid_content_type",
|
||||
"message": "Content-Type header is required",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Check if content type matches any allowed type
|
||||
valid := false
|
||||
for _, allowed := range allowedTypes {
|
||||
if strings.HasPrefix(contentType, allowed) {
|
||||
valid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !valid {
|
||||
c.JSON(http.StatusUnsupportedMediaType, gin.H{
|
||||
"error": "unsupported_media_type",
|
||||
"message": "Content-Type must be one of: " + strings.Join(allowedTypes, ", "),
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateJSONBody ensures the request body is valid JSON
|
||||
func ValidateJSONBody() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Skip validation for GET, DELETE, HEAD requests
|
||||
if c.Request.Method == http.MethodGet ||
|
||||
c.Request.Method == http.MethodDelete ||
|
||||
c.Request.Method == http.MethodHead {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Read body
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "invalid_request_body",
|
||||
"message": "Failed to read request body",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Restore body for handlers
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(body))
|
||||
|
||||
// Empty body is allowed for some requests
|
||||
if len(body) == 0 {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Validate JSON structure (basic check)
|
||||
trimmed := bytes.TrimSpace(body)
|
||||
if len(trimmed) > 0 {
|
||||
if trimmed[0] != '{' && trimmed[0] != '[' {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "invalid_json",
|
||||
"message": "Request body must be valid JSON",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// RequireFields validates that required fields are present in the request
|
||||
func RequireFields(fields ...string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var body map[string]interface{}
|
||||
if err := c.ShouldBindJSON(&body); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "invalid_request_body",
|
||||
"message": "Failed to parse request body",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
missing := []string{}
|
||||
for _, field := range fields {
|
||||
if _, exists := body[field]; !exists {
|
||||
missing = append(missing, field)
|
||||
}
|
||||
}
|
||||
|
||||
if len(missing) > 0 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "missing_required_fields",
|
||||
"message": "Missing required fields: " + strings.Join(missing, ", "),
|
||||
"fields": missing,
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
// Store parsed body for handlers
|
||||
c.Set("parsed_body", body)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateQueryParams validates required query parameters
|
||||
func ValidateQueryParams(params ...string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
missing := []string{}
|
||||
for _, param := range params {
|
||||
if c.Query(param) == "" {
|
||||
missing = append(missing, param)
|
||||
}
|
||||
}
|
||||
|
||||
if len(missing) > 0 {
|
||||
c.JSON(http.StatusBadRequest, gin.H{
|
||||
"error": "missing_query_parameters",
|
||||
"message": "Missing required query parameters: " + strings.Join(missing, ", "),
|
||||
"params": missing,
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user