// Package auth verifies bearer JWTs, either against a remote JWKS endpoint
// or against a locally configured HMAC secret.
package auth

import (
	"context"
	"errors"
	"fmt"
	"net/url"
	"strings"
	"time"

	"github.com/MicahParks/keyfunc/v3"
	"github.com/golang-jwt/jwt/v5"
)

// verifierLeeway is the clock-skew tolerance applied to time-based claims in
// both verification modes.
const verifierLeeway = 15 * time.Second

// Verifier validates JWTs in one of two modes:
//
//   - remote: signatures are checked against keys served by a JWKS endpoint;
//   - local: signatures are checked against a shared HMAC secret.
//
// Local mode is only used when no JWKS is configured.
type Verifier struct {
	// jwks resolves signing keys in remote mode; nil in local mode.
	jwks keyfunc.Keyfunc
	// expectedIssuer is "<scheme>://<host>" of the auth URL. In remote mode it
	// is required as both the issuer and the audience of a token.
	expectedIssuer string
	// enabled reports whether any key source was configured.
	enabled bool
	// localSecret is the HMAC key used in local mode.
	localSecret []byte
	// cancel cancels the context handed to the JWKS client; nil in local mode.
	cancel context.CancelFunc
}

// NewVerifier builds a Verifier.
//
// When neonAuthURL is blank (after trimming whitespace and trailing slashes)
// the verifier runs in local mode with localJWTSecret; if that secret is blank
// too, the returned verifier is disabled and Verify always fails.
//
// Otherwise neonAuthURL must be an absolute URL with a scheme and a host. Keys
// are loaded from "<neonAuthURL>/.well-known/jwks.json" and tokens must carry
// "<scheme>://<host>" as issuer and audience. Call Close to release the
// context given to the JWKS client.
//
// An error is returned when the URL cannot be parsed, lacks a scheme or host,
// or the JWKS client cannot be created.
func NewVerifier(neonAuthURL string, localJWTSecret string) (*Verifier, error) {
	secret := strings.TrimSpace(localJWTSecret)
	baseURL := strings.TrimRight(strings.TrimSpace(neonAuthURL), "/")
	if baseURL == "" {
		return &Verifier{
			enabled:     secret != "",
			localSecret: []byte(secret),
		}, nil
	}

	parsed, err := url.Parse(baseURL)
	if err != nil {
		return nil, fmt.Errorf("parse neon auth url: %w", err)
	}
	// url.Parse accepts relative references such as "example.com/auth" or
	// "/auth"; those would produce the meaningless expected issuer "://".
	if parsed.Scheme == "" || parsed.Host == "" {
		return nil, fmt.Errorf("parse neon auth url: %q must include a scheme and host", baseURL)
	}

	expectedIssuer := fmt.Sprintf("%s://%s", parsed.Scheme, parsed.Host)
	jwksURL := fmt.Sprintf("%s/.well-known/jwks.json", baseURL)

	ctx, cancel := context.WithCancel(context.Background())
	jwks, err := keyfunc.NewDefaultCtx(ctx, []string{jwksURL})
	if err != nil {
		cancel()
		return nil, fmt.Errorf("create jwks: %w", err)
	}

	return &Verifier{
		jwks:           jwks,
		expectedIssuer: expectedIssuer,
		enabled:        true,
		localSecret:    []byte(secret),
		cancel:         cancel,
	}, nil
}

// Enabled reports whether the verifier has a key source and can verify
// tokens. A nil Verifier is reported as disabled.
func (v *Verifier) Enabled() bool {
	return v != nil && v.enabled
}

// Close cancels the context that was passed to the JWKS client. It is a no-op
// for nil, disabled, and local-mode verifiers.
func (v *Verifier) Close() {
	if v == nil || v.cancel == nil {
		return
	}
	v.cancel()
}

// Verify parses and validates tokenString and returns its claims.
//
// In remote mode the token must be signed with EdDSA by a key from the JWKS
// and carry the expected issuer as both issuer and audience. In local mode it
// must be HMAC-signed with the local secret, have issuer "bookra-auth",
// audience "bookra", and a "type" claim equal to "access". Both modes allow
// 15 seconds of clock skew.
//
// It returns an error when the verifier is nil or disabled, has no key source,
// or the token fails any of the checks above.
func (v *Verifier) Verify(tokenString string) (jwt.MapClaims, error) {
	switch {
	case v == nil || !v.enabled:
		return nil, errors.New("neon auth verifier is disabled")
	case v.jwks != nil:
		return v.verifyRemote(tokenString)
	case len(v.localSecret) > 0:
		return v.verifyLocal(tokenString)
	default:
		// Guards against dereferencing a nil JWKS on a verifier that was not
		// built through NewVerifier.
		return nil, errors.New("neon auth verifier has no key source configured")
	}
}

// verifyLocal validates an HMAC-signed access token against the local secret.
func (v *Verifier) verifyLocal(tokenString string) (jwt.MapClaims, error) {
	// NOTE(review): tokens without an "exp" claim are accepted because
	// jwt.WithExpirationRequired is not set — confirm this is intended.
	hmacKey := func(token *jwt.Token) (any, error) {
		// Reject non-HMAC algorithms so the secret is never used as a key for
		// a different signing scheme.
		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
		}
		return v.localSecret, nil
	}

	token, err := jwt.Parse(tokenString, hmacKey,
		jwt.WithIssuer("bookra-auth"),
		jwt.WithAudience("bookra"),
		jwt.WithLeeway(verifierLeeway),
	)
	if err != nil {
		return nil, err
	}

	claims, err := validatedMapClaims(token)
	if err != nil {
		return nil, err
	}
	if tokenType, _ := claims["type"].(string); tokenType != "access" {
		return nil, errors.New("invalid token type")
	}
	return claims, nil
}

// verifyRemote validates an EdDSA-signed token against the JWKS key set.
func (v *Verifier) verifyRemote(tokenString string) (jwt.MapClaims, error) {
	token, err := jwt.Parse(tokenString, v.jwks.Keyfunc,
		jwt.WithIssuer(v.expectedIssuer),
		jwt.WithValidMethods([]string{"EdDSA"}),
		jwt.WithAudience(v.expectedIssuer),
		jwt.WithLeeway(verifierLeeway),
	)
	if err != nil {
		return nil, err
	}
	return validatedMapClaims(token)
}

// validatedMapClaims returns the claims of a parsed token, or an error when
// the token is not marked valid or its claims are not a jwt.MapClaims.
func validatedMapClaims(token *jwt.Token) (jwt.MapClaims, error) {
	claims, ok := token.Claims.(jwt.MapClaims)
	if !ok || !token.Valid {
		return nil, errors.New("invalid token claims")
	}
	return claims, nil
}