This commit is contained in:
Tomas Dvorak
2026-05-05 09:48:07 +02:00
parent d854614a87
commit 48c3e15a38
295 changed files with 178381 additions and 1039 deletions
+79
View File
@@ -0,0 +1,79 @@
package auth
import (
"context"
"errors"
"fmt"
"net/url"
"strings"
"time"
"github.com/MicahParks/keyfunc/v3"
"github.com/golang-jwt/jwt/v5"
)
type NeonVerifier struct {
jwks keyfunc.Keyfunc
expectedIssuer string
enabled bool
cancel context.CancelFunc
}
func NewNeonVerifier(neonAuthURL string) (*NeonVerifier, error) {
trimmed := strings.TrimRight(strings.TrimSpace(neonAuthURL), "/")
if trimmed == "" {
return &NeonVerifier{enabled: false}, nil
}
parsed, err := url.Parse(trimmed)
if err != nil {
return nil, fmt.Errorf("parse neon auth url: %w", err)
}
expectedIssuer := fmt.Sprintf("%s://%s", parsed.Scheme, parsed.Host)
jwksURL := fmt.Sprintf("%s/.well-known/jwks.json", trimmed)
ctx, cancel := context.WithCancel(context.Background())
jwks, err := keyfunc.NewDefaultCtx(ctx, []string{jwksURL})
if err != nil {
cancel()
return nil, fmt.Errorf("create neon jwks: %w", err)
}
return &NeonVerifier{jwks: jwks, expectedIssuer: expectedIssuer, enabled: true, cancel: cancel}, nil
}
func (v *NeonVerifier) Enabled() bool {
return v != nil && v.enabled
}
func (v *NeonVerifier) Close() {
if v != nil && v.cancel != nil {
v.cancel()
}
}
func (v *NeonVerifier) Verify(tokenString string) (*Claims, error) {
if !v.Enabled() {
return nil, errors.New("neon auth verifier is disabled")
}
token, err := jwt.Parse(tokenString, v.jwks.Keyfunc,
jwt.WithIssuer(v.expectedIssuer),
jwt.WithValidMethods([]string{"EdDSA"}),
jwt.WithAudience(v.expectedIssuer),
jwt.WithLeeway(15*time.Second),
)
if err != nil {
return nil, err
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok || !token.Valid {
return nil, errors.New("invalid neon claims")
}
subject, _ := claims["sub"].(string)
email, _ := claims["email"].(string)
name, _ := claims["name"].(string)
if name == "" {
name, _ = claims["display_name"].(string)
}
if strings.TrimSpace(subject) == "" {
return nil, errors.New("missing neon subject")
}
return &Claims{UserID: subject, Email: email, Name: name, Role: "authenticated", Type: "access"}, nil
}
+333
View File
@@ -0,0 +1,333 @@
package auth
import (
"context"
"crypto/rand"
"encoding/base64"
"fmt"
"time"
"bookra/apps/auth-service/internal/db"
"bookra/apps/auth-service/internal/email"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
)
const (
accessTokenTTL = 24 * time.Hour
refreshTokenTTL = 30 * 24 * time.Hour
)
type Service struct {
db *db.DB
email *email.Service
jwtSecret []byte
frontendURL string
}
type TokenPair struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token,omitempty"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
}
type Claims struct {
UserID string `json:"sub"`
Email string `json:"email"`
Name string `json:"name,omitempty"`
Role string `json:"role,omitempty"`
Type string `json:"type"`
jwt.RegisteredClaims
}
func NewService(database *db.DB, emailSvc *email.Service, jwtSecret string, frontendURL string) *Service {
return &Service{
db: database,
email: emailSvc,
jwtSecret: []byte(jwtSecret),
frontendURL: frontendURL,
}
}
func (s *Service) GenerateMagicLink(ctx context.Context, emailAddr string, locale string) error {
user, err := s.db.GetUserByEmail(ctx, emailAddr)
if err != nil {
return fmt.Errorf("get user: %w", err)
}
if user == nil {
user = &db.User{
Email: emailAddr,
Provider: "email",
}
user, err = s.db.CreateUser(ctx, user)
if err != nil {
return fmt.Errorf("create user: %w", err)
}
}
token := generateRandomToken(32)
expiresAt := time.Now().Add(15 * time.Minute)
if err := s.db.CreateMagicLink(ctx, token, emailAddr, user.ID, expiresAt); err != nil {
return fmt.Errorf("create magic link: %w", err)
}
magicURL := fmt.Sprintf("%s/auth/callback?token=%s", s.frontendURL, token)
var name string
if user.Name != nil {
name = *user.Name
}
if err := s.email.SendMagicLink(emailAddr, name, magicURL, locale); err != nil {
return fmt.Errorf("send email: %w", err)
}
return nil
}
func (s *Service) VerifyMagicLink(ctx context.Context, token string) (*TokenPair, error) {
ml, err := s.db.GetMagicLink(ctx, token)
if err != nil {
return nil, fmt.Errorf("get magic link: %w", err)
}
if ml == nil || ml.Used {
return nil, fmt.Errorf("invalid or used token")
}
if time.Now().After(ml.ExpiresAt) {
return nil, fmt.Errorf("token expired")
}
if err := s.db.MarkMagicLinkUsed(ctx, token); err != nil {
return nil, fmt.Errorf("mark used: %w", err)
}
user, err := s.db.GetUserByID(ctx, ml.UserID)
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
}
if user == nil {
return nil, fmt.Errorf("user not found")
}
if err := s.db.UpdateLastLogin(ctx, user.ID); err != nil {
return nil, fmt.Errorf("update login: %w", err)
}
return s.generateTokens(user)
}
func (s *Service) OAuthLoginOrCreate(ctx context.Context, provider, providerID, email, name string) (*TokenPair, error) {
user, err := s.db.GetUserByProviderID(ctx, provider, providerID)
if err != nil {
return nil, fmt.Errorf("get user by provider: %w", err)
}
if user == nil {
existing, err := s.db.GetUserByEmail(ctx, email)
if err != nil {
return nil, fmt.Errorf("check existing email: %w", err)
}
if existing != nil {
existing.Provider = provider
existing.ProviderID = &providerID
existing.Name = &name
existing.EmailVerified = true
if err := s.db.UpdateUser(ctx, existing); err != nil {
return nil, fmt.Errorf("link provider: %w", err)
}
user = existing
} else {
user = &db.User{
Email: email,
Name: &name,
Provider: provider,
ProviderID: &providerID,
EmailVerified: true,
}
user, err = s.db.CreateUser(ctx, user)
if err != nil {
return nil, fmt.Errorf("create oauth user: %w", err)
}
}
}
if err := s.db.UpdateLastLogin(ctx, user.ID); err != nil {
return nil, fmt.Errorf("update login: %w", err)
}
return s.generateTokens(user)
}
func (s *Service) RegisterWithPassword(ctx context.Context, email, password, name string) (*TokenPair, error) {
existing, err := s.db.GetUserByEmail(ctx, email)
if err != nil {
return nil, fmt.Errorf("check existing: %w", err)
}
if existing != nil {
return nil, fmt.Errorf("email already registered")
}
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return nil, fmt.Errorf("hash password: %w", err)
}
hashStr := string(hash)
user := &db.User{
Email: email,
Name: &name,
PasswordHash: &hashStr,
Provider: "email",
EmailVerified: false,
}
user, err = s.db.CreateUser(ctx, user)
if err != nil {
return nil, fmt.Errorf("create user: %w", err)
}
return s.generateTokens(user)
}
func (s *Service) LoginWithPassword(ctx context.Context, email, password string) (*TokenPair, error) {
user, err := s.db.GetUserByEmail(ctx, email)
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
}
if user == nil || user.PasswordHash == nil {
return nil, fmt.Errorf("invalid credentials")
}
if err := bcrypt.CompareHashAndPassword([]byte(*user.PasswordHash), []byte(password)); err != nil {
return nil, fmt.Errorf("invalid credentials")
}
if err := s.db.UpdateLastLogin(ctx, user.ID); err != nil {
return nil, fmt.Errorf("update login: %w", err)
}
return s.generateTokens(user)
}
func (s *Service) generateTokens(user *db.User) (*TokenPair, error) {
now := time.Now()
return s.generateTokensAt(user, now)
}
func (s *Service) generateTokensAt(user *db.User, now time.Time) (*TokenPair, error) {
name := ""
if user.Name != nil {
name = *user.Name
}
accessTokenString, err := s.signToken(user, name, "access", now, accessTokenTTL)
if err != nil {
return nil, fmt.Errorf("sign access token: %w", err)
}
refreshTokenString, err := s.signToken(user, name, "refresh", now, refreshTokenTTL)
if err != nil {
return nil, fmt.Errorf("sign refresh token: %w", err)
}
return &TokenPair{
AccessToken: accessTokenString,
RefreshToken: refreshTokenString,
TokenType: "Bearer",
ExpiresIn: int(accessTokenTTL.Seconds()),
}, nil
}
func (s *Service) VerifyToken(tokenString string) (*Claims, error) {
return s.verifyTokenOfType(tokenString, "access")
}
func (s *Service) VerifyRefreshToken(tokenString string) (*Claims, error) {
return s.verifyTokenOfType(tokenString, "refresh")
}
func (s *Service) RefreshTokens(ctx context.Context, refreshToken string) (*TokenPair, error) {
claims, err := s.VerifyRefreshToken(refreshToken)
if err != nil {
return nil, err
}
user := &db.User{
ID: uuid.MustParse(claims.UserID),
Email: claims.Email,
}
if claims.Name != "" {
user.Name = &claims.Name
}
if s.db != nil {
storedUser, err := s.db.GetUserByID(ctx, user.ID)
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
}
if storedUser == nil {
return nil, fmt.Errorf("user not found")
}
user = storedUser
}
return s.generateTokens(user)
}
func (s *Service) verifyTokenOfType(tokenString string, expectedType string) (*Claims, error) {
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return s.jwtSecret, nil
})
if err != nil {
return nil, err
}
if claims, ok := token.Claims.(*Claims); ok && token.Valid {
if claims.Type != expectedType {
return nil, fmt.Errorf("invalid token type")
}
return claims, nil
}
return nil, fmt.Errorf("invalid token")
}
func (s *Service) signToken(user *db.User, name string, tokenType string, now time.Time, ttl time.Duration) (string, error) {
claims := Claims{
UserID: user.ID.String(),
Email: user.Email,
Name: name,
Role: "authenticated",
Type: tokenType,
RegisteredClaims: jwt.RegisteredClaims{
Issuer: "bookra-auth",
Subject: user.ID.String(),
Audience: jwt.ClaimStrings{"bookra"},
ID: generateRandomToken(12),
ExpiresAt: jwt.NewNumericDate(now.Add(ttl)),
IssuedAt: jwt.NewNumericDate(now),
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
return token.SignedString(s.jwtSecret)
}
func generateRandomToken(length int) string {
b := make([]byte, length)
rand.Read(b)
return base64.URLEncoding.EncodeToString(b)
}
@@ -0,0 +1,88 @@
package auth
import (
"context"
"testing"
"time"
"bookra/apps/auth-service/internal/db"
"github.com/google/uuid"
)
func TestGenerateTokensProducesVerifiableAccessAndRefreshTokens(t *testing.T) {
service := NewService(nil, nil, "test-secret", "http://localhost:3000")
name := "Token Tester"
user := &db.User{
ID: uuid.MustParse("019daeaa-bc14-7712-9224-e347a96bd5c3"),
Email: "tester@bookra.dev",
Name: &name,
}
tokens, err := service.generateTokensAt(user, time.Now().UTC())
if err != nil {
t.Fatalf("generate tokens: %v", err)
}
accessClaims, err := service.VerifyToken(tokens.AccessToken)
if err != nil {
t.Fatalf("verify access token: %v", err)
}
if accessClaims.Type != "access" {
t.Fatalf("expected access type, got %s", accessClaims.Type)
}
refreshClaims, err := service.VerifyRefreshToken(tokens.RefreshToken)
if err != nil {
t.Fatalf("verify refresh token: %v", err)
}
if refreshClaims.Type != "refresh" {
t.Fatalf("expected refresh type, got %s", refreshClaims.Type)
}
if _, err := service.VerifyToken(tokens.RefreshToken); err == nil {
t.Fatal("expected refresh token to fail access verification")
}
if _, err := service.VerifyRefreshToken(tokens.AccessToken); err == nil {
t.Fatal("expected access token to fail refresh verification")
}
}
func TestRefreshTokensReturnsRotatedPair(t *testing.T) {
service := NewService(nil, nil, "test-secret", "http://localhost:3000")
user := &db.User{
ID: uuid.MustParse("019daeaa-bc14-7712-9224-e347a96bd5c3"),
Email: "tester@bookra.dev",
}
original, err := service.generateTokens(user)
if err != nil {
t.Fatalf("generate tokens: %v", err)
}
refreshed, err := service.RefreshTokens(context.Background(), original.RefreshToken)
if err != nil {
t.Fatalf("refresh tokens: %v", err)
}
if refreshed.AccessToken == original.AccessToken {
t.Fatal("expected rotated access token")
}
if refreshed.RefreshToken == original.RefreshToken {
t.Fatal("expected rotated refresh token")
}
if _, err := service.VerifyToken(refreshed.AccessToken); err != nil {
t.Fatalf("verify refreshed access token: %v", err)
}
if _, err := service.VerifyRefreshToken(refreshed.RefreshToken); err != nil {
t.Fatalf("verify refreshed refresh token: %v", err)
}
}
func TestRefreshTokensRejectsInvalidToken(t *testing.T) {
service := NewService(nil, nil, "test-secret", "http://localhost:3000")
if _, err := service.RefreshTokens(context.Background(), "bad-token"); err == nil {
t.Fatal("expected invalid refresh token error")
}
}