Files
Bookra/apps/auth-service/internal/auth/service.go
T
Tomas Dvorak 48c3e15a38 cleanup
2026-05-05 09:48:07 +02:00

334 lines
8.2 KiB
Go

package auth
import (
"context"
"crypto/rand"
"encoding/base64"
"fmt"
"time"
"bookra/apps/auth-service/internal/db"
"bookra/apps/auth-service/internal/email"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"golang.org/x/crypto/bcrypt"
)
// Lifetimes of the JWTs minted by this service.
const (
	// accessTokenTTL is the validity window of an access token. It is also
	// reported to clients (in seconds) as TokenPair.ExpiresIn.
	accessTokenTTL = 24 * time.Hour
	// refreshTokenTTL is the validity window of a refresh token (30 days).
	refreshTokenTTL = 720 * time.Hour
)
// Service implements the authentication flows (magic link, OAuth, password)
// and issues and validates this service's HMAC-signed JWTs.
type Service struct {
// db persists users and magic links.
db *db.DB
// email sends outbound mail such as magic links.
email *email.Service
// jwtSecret is the HMAC key used both to sign and to verify tokens.
jwtSecret []byte
// frontendURL is the base URL that magic-link callback URLs are built on.
frontendURL string
}
// TokenPair is the token response returned by the login and refresh flows.
type TokenPair struct {
AccessToken string `json:"access_token"`
// RefreshToken is omitted from the JSON body when empty.
RefreshToken string `json:"refresh_token,omitempty"`
// TokenType is "Bearer" for pairs built by generateTokensAt.
TokenType string `json:"token_type"`
// ExpiresIn is the access-token lifetime in seconds.
ExpiresIn int `json:"expires_in"`
}
// Claims is the JWT payload issued and accepted by this service.
//
// NOTE(review): UserID is tagged `json:"sub"`, and the embedded
// jwt.RegisteredClaims also maps its Subject field to "sub". Under
// encoding/json's embedded-field rules the shallower UserID should win, so
// RegisteredClaims.Subject is neither written when signing nor filled when
// parsing (GetSubject() would then return ""). Read the subject via UserID.
type Claims struct {
// UserID is the user's UUID in string form, carried as the "sub" claim.
UserID string `json:"sub"`
Email string `json:"email"`
// Name is the display name; omitted when empty.
Name string `json:"name,omitempty"`
// Role is set to "authenticated" by signToken.
Role string `json:"role,omitempty"`
// Type is "access" or "refresh"; verifyTokenOfType rejects a mismatch.
Type string `json:"type"`
jwt.RegisteredClaims
}
// NewService wires an auth Service to its database, mail sender, JWT signing
// secret and the frontend base URL used for magic-link callbacks.
//
// NOTE(review): an empty jwtSecret is accepted as-is; consider rejecting it
// at startup so tokens are never signed with an empty HMAC key.
func NewService(database *db.DB, emailSvc *email.Service, jwtSecret string, frontendURL string) *Service {
	svc := new(Service)
	svc.db = database
	svc.email = emailSvc
	svc.jwtSecret = []byte(jwtSecret)
	svc.frontendURL = frontendURL
	return svc
}
// GenerateMagicLink e-mails a single-use sign-in link to emailAddr.
//
// If no account exists for the address, a bare "email"-provider account is
// created first. The link token is valid for 15 minutes and is sent in the
// given locale.
func (s *Service) GenerateMagicLink(ctx context.Context, emailAddr string, locale string) error {
	const magicLinkTTL = 15 * time.Minute

	user, err := s.db.GetUserByEmail(ctx, emailAddr)
	if err != nil {
		return fmt.Errorf("get user: %w", err)
	}
	if user == nil {
		// First contact with this address: create the account on the fly.
		if user, err = s.db.CreateUser(ctx, &db.User{Email: emailAddr, Provider: "email"}); err != nil {
			return fmt.Errorf("create user: %w", err)
		}
	}

	token := generateRandomToken(32)
	if err := s.db.CreateMagicLink(ctx, token, emailAddr, user.ID, time.Now().Add(magicLinkTTL)); err != nil {
		return fmt.Errorf("create magic link: %w", err)
	}

	link := s.frontendURL + "/auth/callback?token=" + token
	recipientName := ""
	if user.Name != nil {
		recipientName = *user.Name
	}
	if err := s.email.SendMagicLink(emailAddr, recipientName, link, locale); err != nil {
		return fmt.Errorf("send email: %w", err)
	}
	return nil
}
// VerifyMagicLink redeems a magic-link token and, on success, returns a fresh
// token pair for the user the link was issued to.
//
// The link is rejected if it is unknown, already used, or past its expiry.
// Otherwise it is marked used before the user is loaded, so the link is
// consumed even if a later step fails.
//
// NOTE(review): the used-check and MarkMagicLinkUsed are separate DB calls, so
// two concurrent requests with the same token may both pass the check unless
// MarkMagicLinkUsed itself enforces single use. Confirm this in the db layer.
func (s *Service) VerifyMagicLink(ctx context.Context, token string) (*TokenPair, error) {
	link, err := s.db.GetMagicLink(ctx, token)
	if err != nil {
		return nil, fmt.Errorf("get magic link: %w", err)
	}

	switch {
	case link == nil || link.Used:
		return nil, fmt.Errorf("invalid or used token")
	case time.Now().After(link.ExpiresAt):
		return nil, fmt.Errorf("token expired")
	}

	if err = s.db.MarkMagicLinkUsed(ctx, token); err != nil {
		return nil, fmt.Errorf("mark used: %w", err)
	}

	owner, err := s.db.GetUserByID(ctx, link.UserID)
	if err != nil {
		return nil, fmt.Errorf("get user: %w", err)
	}
	if owner == nil {
		return nil, fmt.Errorf("user not found")
	}

	if err = s.db.UpdateLastLogin(ctx, owner.ID); err != nil {
		return nil, fmt.Errorf("update login: %w", err)
	}
	return s.generateTokens(owner)
}
// OAuthLoginOrCreate signs in a user authenticated by an external OAuth
// provider and returns a fresh token pair.
//
// Resolution order:
//  1. an account already linked to (provider, providerID);
//  2. otherwise an account with the same email, which is re-linked to this
//     provider and marked email-verified;
//  3. otherwise a new, email-verified account is created.
//
// When linking an existing account, an empty name from the provider no longer
// overwrites the name the account already has.
//
// NOTE(review): step 2 links purely on email address. If someone registered
// that address earlier via RegisterWithPassword (which leaves EmailVerified
// false), their password keeps working after the real owner links via OAuth.
// Consider clearing PasswordHash when linking an unverified account, and
// confirm the caller only passes provider-verified emails.
func (s *Service) OAuthLoginOrCreate(ctx context.Context, provider, providerID, email, name string) (*TokenPair, error) {
	user, err := s.db.GetUserByProviderID(ctx, provider, providerID)
	if err != nil {
		return nil, fmt.Errorf("get user by provider: %w", err)
	}

	if user == nil {
		existing, err := s.db.GetUserByEmail(ctx, email)
		if err != nil {
			return nil, fmt.Errorf("check existing email: %w", err)
		}
		if existing != nil {
			existing.Provider = provider
			existing.ProviderID = &providerID
			// Only replace the stored name when the provider actually supplied
			// one; otherwise a nameless OAuth profile would erase it.
			if name != "" {
				existing.Name = &name
			}
			existing.EmailVerified = true
			if err := s.db.UpdateUser(ctx, existing); err != nil {
				return nil, fmt.Errorf("link provider: %w", err)
			}
			user = existing
		} else {
			user, err = s.db.CreateUser(ctx, &db.User{
				Email:         email,
				Name:          &name,
				Provider:      provider,
				ProviderID:    &providerID,
				EmailVerified: true,
			})
			if err != nil {
				return nil, fmt.Errorf("create oauth user: %w", err)
			}
		}
	}

	if err := s.db.UpdateLastLogin(ctx, user.ID); err != nil {
		return nil, fmt.Errorf("update login: %w", err)
	}
	return s.generateTokens(user)
}
// RegisterWithPassword creates an "email"-provider account protected by a
// bcrypt-hashed password and immediately returns a token pair for it.
//
// The new account starts with EmailVerified=false. Unlike the login flows,
// the last-login timestamp is not updated here. An address that is already
// registered is rejected with "email already registered".
func (s *Service) RegisterWithPassword(ctx context.Context, email, password, name string) (*TokenPair, error) {
	switch existing, err := s.db.GetUserByEmail(ctx, email); {
	case err != nil:
		return nil, fmt.Errorf("check existing: %w", err)
	case existing != nil:
		return nil, fmt.Errorf("email already registered")
	}

	hashed, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	if err != nil {
		return nil, fmt.Errorf("hash password: %w", err)
	}
	passwordHash := string(hashed)

	created, err := s.db.CreateUser(ctx, &db.User{
		Email:         email,
		Name:          &name,
		PasswordHash:  &passwordHash,
		Provider:      "email",
		EmailVerified: false,
	})
	if err != nil {
		return nil, fmt.Errorf("create user: %w", err)
	}
	return s.generateTokens(created)
}
// LoginWithPassword authenticates an email/password pair and returns a fresh
// token pair.
//
// An unknown email, an account without a password (such as one created via
// OAuth or a magic link), and a wrong password all produce the same
// "invalid credentials" error.
//
// NOTE(review): the unknown-user path skips bcrypt, so it returns noticeably
// faster than a wrong password and leaks account existence through timing.
// RegisterWithPassword already reveals existence ("email already registered"),
// so this is only worth fixing together with that.
func (s *Service) LoginWithPassword(ctx context.Context, email, password string) (*TokenPair, error) {
	user, err := s.db.GetUserByEmail(ctx, email)
	if err != nil {
		return nil, fmt.Errorf("get user: %w", err)
	}

	// Short-circuit evaluation guarantees *user.PasswordHash is only read
	// once both pointers are known to be non-nil.
	if user == nil || user.PasswordHash == nil ||
		bcrypt.CompareHashAndPassword([]byte(*user.PasswordHash), []byte(password)) != nil {
		return nil, fmt.Errorf("invalid credentials")
	}

	if err = s.db.UpdateLastLogin(ctx, user.ID); err != nil {
		return nil, fmt.Errorf("update login: %w", err)
	}
	return s.generateTokens(user)
}
// generateTokens issues an access/refresh token pair for user, stamped with
// the current wall-clock time.
func (s *Service) generateTokens(user *db.User) (*TokenPair, error) {
	return s.generateTokensAt(user, time.Now())
}
// generateTokensAt signs an access token (accessTokenTTL) and a refresh token
// (refreshTokenTTL) for user, both issued at now. A nil user.Name is
// treated as an empty display name.
func (s *Service) generateTokensAt(user *db.User, now time.Time) (*TokenPair, error) {
	var displayName string
	if n := user.Name; n != nil {
		displayName = *n
	}

	access, err := s.signToken(user, displayName, "access", now, accessTokenTTL)
	if err != nil {
		return nil, fmt.Errorf("sign access token: %w", err)
	}
	refresh, err := s.signToken(user, displayName, "refresh", now, refreshTokenTTL)
	if err != nil {
		return nil, fmt.Errorf("sign refresh token: %w", err)
	}

	pair := &TokenPair{
		TokenType:    "Bearer",
		AccessToken:  access,
		RefreshToken: refresh,
		// ExpiresIn describes the access token only, in whole seconds.
		ExpiresIn: int(accessTokenTTL / time.Second),
	}
	return pair, nil
}
// VerifyToken validates tokenString as an access token and returns its
// claims. Refresh tokens are rejected.
func (s *Service) VerifyToken(tokenString string) (*Claims, error) {
	const wantType = "access"
	return s.verifyTokenOfType(tokenString, wantType)
}
// VerifyRefreshToken validates tokenString as a refresh token and returns its
// claims. Access tokens are rejected.
func (s *Service) VerifyRefreshToken(tokenString string) (*Claims, error) {
	const wantType = "refresh"
	return s.verifyTokenOfType(tokenString, wantType)
}
// RefreshTokens exchanges a valid refresh token for a new token pair.
//
// When a database is configured, the user is reloaded so the new tokens
// reflect current profile data, and a deleted user is rejected. Without a
// database, the pair is rebuilt from the refresh token's own claims.
//
// A refresh token whose subject is not a valid UUID now yields an error
// instead of panicking.
//
// NOTE(review): refresh tokens are stateless here (no rotation or revocation
// list), so a leaked refresh token stays usable until it expires.
func (s *Service) RefreshTokens(ctx context.Context, refreshToken string) (*TokenPair, error) {
	claims, err := s.VerifyRefreshToken(refreshToken)
	if err != nil {
		return nil, err
	}

	// Tokens we sign always carry a UUID subject. A malformed one must surface
	// as an error rather than crash the request via uuid.MustParse.
	userID, err := uuid.Parse(claims.UserID)
	if err != nil {
		return nil, fmt.Errorf("invalid token subject: %w", err)
	}

	user := &db.User{
		ID:    userID,
		Email: claims.Email,
	}
	if claims.Name != "" {
		user.Name = &claims.Name
	}

	if s.db != nil {
		storedUser, err := s.db.GetUserByID(ctx, user.ID)
		if err != nil {
			return nil, fmt.Errorf("get user: %w", err)
		}
		if storedUser == nil {
			return nil, fmt.Errorf("user not found")
		}
		user = storedUser
	}
	return s.generateTokens(user)
}
// verifyTokenOfType parses tokenString, checks its HMAC signature against the
// service secret and the standard time-based claims (as enforced by
// jwt.ParseWithClaims), and requires its "type" claim to equal expectedType.
//
// Parse and signature errors from the jwt library are returned unwrapped.
//
// NOTE(review): issuer ("bookra-auth") and audience ("bookra") are set by
// signToken but not validated here; consider jwt.WithIssuer / jwt.WithAudience
// if no other issuer shares this secret.
func (s *Service) verifyTokenOfType(tokenString string, expectedType string) (*Claims, error) {
	keyFunc := func(t *jwt.Token) (any, error) {
		// Refuse anything that is not HMAC (for example "none" or RSA) so the
		// shared secret can't be misused as a public key.
		if _, isHMAC := t.Method.(*jwt.SigningMethodHMAC); !isHMAC {
			return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
		}
		return s.jwtSecret, nil
	}

	parsed, err := jwt.ParseWithClaims(tokenString, &Claims{}, keyFunc)
	if err != nil {
		return nil, err
	}

	claims, ok := parsed.Claims.(*Claims)
	if !ok || !parsed.Valid {
		return nil, fmt.Errorf("invalid token")
	}
	if claims.Type != expectedType {
		return nil, fmt.Errorf("invalid token type")
	}
	return claims, nil
}
// signToken mints an HS256 JWT of the given type ("access" or "refresh") for
// user, issued at now and expiring after ttl, with a random jti.
func (s *Service) signToken(user *db.User, name string, tokenType string, now time.Time, ttl time.Duration) (string, error) {
	subject := user.ID.String()

	registered := jwt.RegisteredClaims{
		ID:        generateRandomToken(12),
		Issuer:    "bookra-auth",
		Subject:   subject,
		Audience:  jwt.ClaimStrings{"bookra"},
		IssuedAt:  jwt.NewNumericDate(now),
		ExpiresAt: jwt.NewNumericDate(now.Add(ttl)),
	}

	claims := Claims{
		UserID:           subject,
		Email:            user.Email,
		Name:             name,
		Role:             "authenticated",
		Type:             tokenType,
		RegisteredClaims: registered,
	}

	return jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(s.jwtSecret)
}
// generateRandomToken returns length bytes from crypto/rand, encoded as padded
// URL-safe base64 (for example, 32 bytes encode to 44 characters).
//
// The result is used as a bearer secret (magic-link token, JWT ID). The
// previous version ignored rand.Read's error; on toolchains where Read can
// fail, that silently produced an all-zero, predictable token. It now panics
// instead. Since Go 1.24 crypto/rand.Read never returns an error, so there the
// panic is unreachable.
func generateRandomToken(length int) string {
	buf := make([]byte, length)
	if _, err := rand.Read(buf); err != nil {
		panic(fmt.Sprintf("auth: reading random bytes: %v", err))
	}
	return base64.URLEncoding.EncodeToString(buf)
}