mirror of
https://github.com/Dvorinka/Containr.git
synced 2026-06-03 20:12:58 +00:00
143 lines
4.0 KiB
Go
143 lines
4.0 KiB
Go
package middleware
|
|
|
|
import (
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/golang-jwt/jwt/v5"
|
|
"github.com/google/uuid"
|
|
)
|
|
|
|
func TestSecurityHeaders(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
router := gin.New()
|
|
router.Use(SecurityHeaders())
|
|
router.GET("/test", func(c *gin.Context) {
|
|
c.String(http.StatusOK, "ok")
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
|
rec := httptest.NewRecorder()
|
|
router.ServeHTTP(rec, req)
|
|
|
|
if got := rec.Header().Get("X-Content-Type-Options"); got != "nosniff" {
|
|
t.Fatalf("expected X-Content-Type-Options nosniff, got %q", got)
|
|
}
|
|
if got := rec.Header().Get("X-Frame-Options"); got != "DENY" {
|
|
t.Fatalf("expected X-Frame-Options DENY, got %q", got)
|
|
}
|
|
if got := rec.Header().Get("Referrer-Policy"); got != "strict-origin-when-cross-origin" {
|
|
t.Fatalf("expected Referrer-Policy strict-origin-when-cross-origin, got %q", got)
|
|
}
|
|
if got := rec.Header().Get("X-XSS-Protection"); got != "1; mode=block" {
|
|
t.Fatalf("expected X-XSS-Protection header, got %q", got)
|
|
}
|
|
if got := rec.Header().Get("Strict-Transport-Security"); got != "" {
|
|
t.Fatalf("expected no HSTS header for plain HTTP request, got %q", got)
|
|
}
|
|
}
|
|
|
|
func TestSecurityHeadersAddsHSTSWhenForwardedProtoHTTPS(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
router := gin.New()
|
|
router.Use(SecurityHeaders())
|
|
router.GET("/test", func(c *gin.Context) {
|
|
c.String(http.StatusOK, "ok")
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
|
req.Header.Set("X-Forwarded-Proto", "https")
|
|
rec := httptest.NewRecorder()
|
|
router.ServeHTTP(rec, req)
|
|
|
|
if got := rec.Header().Get("Strict-Transport-Security"); got == "" {
|
|
t.Fatal("expected HSTS header for https-forwarded request")
|
|
}
|
|
}
|
|
|
|
func TestRequestBodyLimitRejectsLargeRequest(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
router := gin.New()
|
|
router.Use(RequestBodyLimit(8))
|
|
router.POST("/test", func(c *gin.Context) {
|
|
c.String(http.StatusOK, "ok")
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/test", strings.NewReader("0123456789"))
|
|
req.ContentLength = 10
|
|
rec := httptest.NewRecorder()
|
|
router.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusRequestEntityTooLarge {
|
|
t.Fatalf("expected status %d, got %d", http.StatusRequestEntityTooLarge, rec.Code)
|
|
}
|
|
}
|
|
|
|
func TestAuthRejectsNonUUIDUserIDClaim(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
router := gin.New()
|
|
router.Use(Auth("secret"))
|
|
router.GET("/test", func(c *gin.Context) {
|
|
c.String(http.StatusOK, "ok")
|
|
})
|
|
|
|
token := issueJWT(t, "secret", jwt.MapClaims{
|
|
"user_id": "not-a-uuid",
|
|
"email": "test@example.com",
|
|
"exp": time.Now().Add(time.Hour).Unix(),
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
rec := httptest.NewRecorder()
|
|
router.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusUnauthorized {
|
|
t.Fatalf("expected status %d, got %d", http.StatusUnauthorized, rec.Code)
|
|
}
|
|
}
|
|
|
|
func TestAuthStoresStringUserIDForValidClaims(t *testing.T) {
|
|
gin.SetMode(gin.TestMode)
|
|
router := gin.New()
|
|
router.Use(Auth("secret"))
|
|
router.GET("/test", func(c *gin.Context) {
|
|
userID, _ := c.Get("user_id")
|
|
c.String(http.StatusOK, fmt.Sprint(userID))
|
|
})
|
|
|
|
expectedUserID := uuid.NewString()
|
|
token := issueJWT(t, "secret", jwt.MapClaims{
|
|
"user_id": expectedUserID,
|
|
"email": "test@example.com",
|
|
"exp": time.Now().Add(time.Hour).Unix(),
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/test", nil)
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
rec := httptest.NewRecorder()
|
|
router.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusOK {
|
|
t.Fatalf("expected status %d, got %d", http.StatusOK, rec.Code)
|
|
}
|
|
if got := rec.Body.String(); got != expectedUserID {
|
|
t.Fatalf("expected user_id %q, got %q", expectedUserID, got)
|
|
}
|
|
}
|
|
|
|
func issueJWT(t *testing.T, secret string, claims jwt.MapClaims) string {
|
|
t.Helper()
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|
signed, err := token.SignedString([]byte(secret))
|
|
if err != nil {
|
|
t.Fatalf("failed to sign token: %v", err)
|
|
}
|
|
return signed
|
|
}
|