Files
Containr/app/backend/internal/middleware/middleware_test.go
T
2026-04-10 12:02:36 +02:00

143 lines
4.0 KiB
Go

package middleware
import (
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
)
// TestSecurityHeaders verifies that SecurityHeaders sets the expected baseline
// security headers on a plain HTTP request and does not emit HSTS for it.
func TestSecurityHeaders(t *testing.T) {
	gin.SetMode(gin.TestMode)
	router := gin.New()
	router.Use(SecurityHeaders())
	router.GET("/test", func(c *gin.Context) {
		c.String(http.StatusOK, "ok")
	})

	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)

	// The checks are independent, so report every mismatch (t.Errorf) rather
	// than stopping at the first one and hiding the rest.
	wantHeaders := []struct {
		name, value string
	}{
		{"X-Content-Type-Options", "nosniff"},
		{"X-Frame-Options", "DENY"},
		{"Referrer-Policy", "strict-origin-when-cross-origin"},
		{"X-XSS-Protection", "1; mode=block"},
	}
	for _, want := range wantHeaders {
		if got := rec.Header().Get(want.name); got != want.value {
			t.Errorf("expected %s %q, got %q", want.name, want.value, got)
		}
	}

	// The request above carries no X-Forwarded-Proto header, so it must be
	// treated as plain HTTP and must not receive HSTS.
	if got := rec.Header().Get("Strict-Transport-Security"); got != "" {
		t.Errorf("expected no HSTS header for plain HTTP request, got %q", got)
	}
}
// TestSecurityHeadersAddsHSTSWhenForwardedProtoHTTPS verifies that a request
// marked as HTTPS via X-Forwarded-Proto receives a usable
// Strict-Transport-Security header.
func TestSecurityHeadersAddsHSTSWhenForwardedProtoHTTPS(t *testing.T) {
	gin.SetMode(gin.TestMode)
	router := gin.New()
	router.Use(SecurityHeaders())
	router.GET("/test", func(c *gin.Context) {
		c.String(http.StatusOK, "ok")
	})

	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	req.Header.Set("X-Forwarded-Proto", "https")
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)

	got := rec.Header().Get("Strict-Transport-Security")
	if got == "" {
		t.Fatal("expected HSTS header for https-forwarded request")
	}
	// RFC 6797 makes max-age a REQUIRED directive (directive names are
	// case-insensitive) and user agents ignore an HSTS header without it, so a
	// merely non-empty value is not enough to prove HSTS is in effect.
	if !strings.Contains(strings.ToLower(got), "max-age=") {
		t.Fatalf("expected HSTS header with a max-age directive, got %q", got)
	}
}
// TestRequestBodyLimitRejectsLargeRequest verifies that a request whose
// declared Content-Length exceeds the configured limit is answered with 413
// and never reaches the route handler.
func TestRequestBodyLimitRejectsLargeRequest(t *testing.T) {
	const limit = 8
	body := "0123456789" // 10 bytes: deliberately larger than limit

	gin.SetMode(gin.TestMode)
	router := gin.New()
	router.Use(RequestBodyLimit(limit))
	handlerCalled := false
	router.POST("/test", func(c *gin.Context) {
		handlerCalled = true
		c.String(http.StatusOK, "ok")
	})

	req := httptest.NewRequest(http.MethodPost, "/test", strings.NewReader(body))
	// Set explicitly (derived from the body, not a magic number) because the
	// handler never reads the body: rejection here depends on the declared length.
	req.ContentLength = int64(len(body))
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)

	if rec.Code != http.StatusRequestEntityTooLarge {
		t.Fatalf("expected status %d, got %d", http.StatusRequestEntityTooLarge, rec.Code)
	}
	// A 413 written without aborting would still let the handler process the
	// oversized request; the middleware must stop the chain.
	if handlerCalled {
		t.Fatal("expected oversized request to be aborted before reaching the handler")
	}
}
// TestAuthRejectsNonUUIDUserIDClaim verifies that Auth answers 401 for a
// validly signed, unexpired token whose user_id claim is not a UUID, and that
// the protected handler does not run.
func TestAuthRejectsNonUUIDUserIDClaim(t *testing.T) {
	gin.SetMode(gin.TestMode)
	router := gin.New()
	router.Use(Auth("secret"))
	handlerCalled := false
	router.GET("/test", func(c *gin.Context) {
		handlerCalled = true
		c.String(http.StatusOK, "ok")
	})

	// Same secret as Auth and an expiry an hour ahead, so the malformed
	// user_id is the only reason for rejection.
	token := issueJWT(t, "secret", jwt.MapClaims{
		"user_id": "not-a-uuid",
		"email":   "test@example.com",
		"exp":     time.Now().Add(time.Hour).Unix(),
	})

	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	req.Header.Set("Authorization", "Bearer "+token)
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)

	if rec.Code != http.StatusUnauthorized {
		t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, rec.Code)
	}
	// Writing 401 without aborting would still execute the protected handler.
	if handlerCalled {
		t.Fatal("expected Auth to abort the chain; protected handler ran for a rejected token")
	}
}
// TestAuthStoresStringUserIDForValidClaims verifies that, for a valid token,
// Auth stores the user_id claim in the gin context as a plain string holding
// the claim's value.
func TestAuthStoresStringUserIDForValidClaims(t *testing.T) {
	gin.SetMode(gin.TestMode)
	router := gin.New()
	router.Use(Auth("secret"))
	router.GET("/test", func(c *gin.Context) {
		userID, exists := c.Get("user_id")
		if !exists {
			c.String(http.StatusInternalServerError, "user_id missing from context")
			return
		}
		// Echo the dynamic type along with the value. fmt.Sprint alone would
		// render a uuid.UUID (a fmt.Stringer) exactly like its string form, so a
		// regression that stored uuid.UUID instead of string would go unnoticed.
		c.String(http.StatusOK, "%s", fmt.Sprintf("%T:%v", userID, userID))
	})

	expectedUserID := uuid.NewString()
	token := issueJWT(t, "secret", jwt.MapClaims{
		"user_id": expectedUserID,
		"email":   "test@example.com",
		"exp":     time.Now().Add(time.Hour).Unix(),
	})

	req := httptest.NewRequest(http.MethodGet, "/test", nil)
	req.Header.Set("Authorization", "Bearer "+token)
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)

	if rec.Code != http.StatusOK {
		t.Fatalf("expected status %d, got %d (body %q)", http.StatusOK, rec.Code, rec.Body.String())
	}
	want := "string:" + expectedUserID
	if got := rec.Body.String(); got != want {
		t.Fatalf("expected user_id stored as %q, got %q", want, got)
	}
}
// issueJWT signs claims with HS256 using secret and returns the compact token
// string. A signing error fails the calling test immediately.
func issueJWT(t *testing.T, secret string, claims jwt.MapClaims) string {
	t.Helper()
	key := []byte(secret)
	tokenString, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(key)
	if err != nil {
		t.Fatalf("failed to sign token: %v", err)
	}
	return tokenString
}