mirror of
https://github.com/Dvorinka/Bookra.git
synced 2026-06-03 20:13:00 +00:00
334 lines
8.2 KiB
Go
334 lines
8.2 KiB
Go
package auth
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"encoding/base64"
|
|
"fmt"
|
|
"time"
|
|
|
|
"bookra/apps/auth-service/internal/db"
|
|
"bookra/apps/auth-service/internal/email"
|
|
|
|
"github.com/golang-jwt/jwt/v5"
|
|
"github.com/google/uuid"
|
|
"golang.org/x/crypto/bcrypt"
|
|
)
|
|
|
|
// Lifetimes of the tokens minted by signToken.
const (
	// accessTokenTTL is how long an access token stays valid after issuance;
	// it is also what TokenPair.ExpiresIn reports to clients.
	accessTokenTTL time.Duration = 24 * time.Hour

	// refreshTokenTTL is how long a refresh token can be exchanged for a new
	// pair (30 days).
	refreshTokenTTL time.Duration = 30 * 24 * time.Hour
)
|
|
|
|
// Service implements the authentication flows (magic link, OAuth, email +
// password) and issues and verifies the HS256-signed JWTs they produce.
type Service struct {
	// db is the user and magic-link store. RefreshTokens tolerates a nil db;
	// the other flows call it without a nil check.
	db *db.DB

	// email delivers magic-link messages.
	email *email.Service

	// jwtSecret is the HMAC key used both to sign and to verify tokens.
	jwtSecret []byte

	// frontendURL is the base that magic-link callback URLs are appended to.
	frontendURL string
}
|
|
|
|
// TokenPair is the credential bundle returned to a client after a successful
// login, registration, magic-link verification, or refresh.
type TokenPair struct {
	// AccessToken is a signed JWT whose "type" claim is "access".
	AccessToken string `json:"access_token"`

	// RefreshToken is a signed JWT whose "type" claim is "refresh"; it is
	// omitted from the JSON encoding when empty.
	RefreshToken string `json:"refresh_token,omitempty"`

	// TokenType is the HTTP authorization scheme; generateTokensAt always
	// sets it to "Bearer".
	TokenType string `json:"token_type"`

	// ExpiresIn is the access token's lifetime in seconds.
	ExpiresIn int `json:"expires_in"`
}
|
|
|
|
// Claims is the JWT payload this service issues for both access and refresh
// tokens.
//
// NOTE(review): UserID and the embedded RegisteredClaims.Subject both map to
// the JSON key "sub". Under encoding/json's field-precedence rules the
// shallower UserID field wins, so RegisteredClaims.Subject is never
// serialized and stays empty after parsing; read the subject via UserID.
type Claims struct {
	// UserID is the user's UUID in its string form.
	UserID string `json:"sub"`

	// Email is the user's email address at signing time.
	Email string `json:"email"`

	// Name is the user's display name; omitted when empty.
	Name string `json:"name,omitempty"`

	// Role is the authorization role; signToken always writes "authenticated".
	Role string `json:"role,omitempty"`

	// Type distinguishes token kinds ("access" or "refresh"), so one kind
	// cannot be presented in place of the other.
	Type string `json:"type"`

	jwt.RegisteredClaims
}
|
|
|
|
func NewService(database *db.DB, emailSvc *email.Service, jwtSecret string, frontendURL string) *Service {
|
|
return &Service{
|
|
db: database,
|
|
email: emailSvc,
|
|
jwtSecret: []byte(jwtSecret),
|
|
frontendURL: frontendURL,
|
|
}
|
|
}
|
|
|
|
func (s *Service) GenerateMagicLink(ctx context.Context, emailAddr string, locale string) error {
|
|
user, err := s.db.GetUserByEmail(ctx, emailAddr)
|
|
if err != nil {
|
|
return fmt.Errorf("get user: %w", err)
|
|
}
|
|
|
|
if user == nil {
|
|
user = &db.User{
|
|
Email: emailAddr,
|
|
Provider: "email",
|
|
}
|
|
user, err = s.db.CreateUser(ctx, user)
|
|
if err != nil {
|
|
return fmt.Errorf("create user: %w", err)
|
|
}
|
|
}
|
|
|
|
token := generateRandomToken(32)
|
|
expiresAt := time.Now().Add(15 * time.Minute)
|
|
|
|
if err := s.db.CreateMagicLink(ctx, token, emailAddr, user.ID, expiresAt); err != nil {
|
|
return fmt.Errorf("create magic link: %w", err)
|
|
}
|
|
|
|
magicURL := fmt.Sprintf("%s/auth/callback?token=%s", s.frontendURL, token)
|
|
|
|
var name string
|
|
if user.Name != nil {
|
|
name = *user.Name
|
|
}
|
|
|
|
if err := s.email.SendMagicLink(emailAddr, name, magicURL, locale); err != nil {
|
|
return fmt.Errorf("send email: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *Service) VerifyMagicLink(ctx context.Context, token string) (*TokenPair, error) {
|
|
ml, err := s.db.GetMagicLink(ctx, token)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("get magic link: %w", err)
|
|
}
|
|
|
|
if ml == nil || ml.Used {
|
|
return nil, fmt.Errorf("invalid or used token")
|
|
}
|
|
|
|
if time.Now().After(ml.ExpiresAt) {
|
|
return nil, fmt.Errorf("token expired")
|
|
}
|
|
|
|
if err := s.db.MarkMagicLinkUsed(ctx, token); err != nil {
|
|
return nil, fmt.Errorf("mark used: %w", err)
|
|
}
|
|
|
|
user, err := s.db.GetUserByID(ctx, ml.UserID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("get user: %w", err)
|
|
}
|
|
|
|
if user == nil {
|
|
return nil, fmt.Errorf("user not found")
|
|
}
|
|
|
|
if err := s.db.UpdateLastLogin(ctx, user.ID); err != nil {
|
|
return nil, fmt.Errorf("update login: %w", err)
|
|
}
|
|
|
|
return s.generateTokens(user)
|
|
}
|
|
|
|
func (s *Service) OAuthLoginOrCreate(ctx context.Context, provider, providerID, email, name string) (*TokenPair, error) {
|
|
user, err := s.db.GetUserByProviderID(ctx, provider, providerID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("get user by provider: %w", err)
|
|
}
|
|
|
|
if user == nil {
|
|
existing, err := s.db.GetUserByEmail(ctx, email)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("check existing email: %w", err)
|
|
}
|
|
|
|
if existing != nil {
|
|
existing.Provider = provider
|
|
existing.ProviderID = &providerID
|
|
existing.Name = &name
|
|
existing.EmailVerified = true
|
|
if err := s.db.UpdateUser(ctx, existing); err != nil {
|
|
return nil, fmt.Errorf("link provider: %w", err)
|
|
}
|
|
user = existing
|
|
} else {
|
|
user = &db.User{
|
|
Email: email,
|
|
Name: &name,
|
|
Provider: provider,
|
|
ProviderID: &providerID,
|
|
EmailVerified: true,
|
|
}
|
|
user, err = s.db.CreateUser(ctx, user)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create oauth user: %w", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
if err := s.db.UpdateLastLogin(ctx, user.ID); err != nil {
|
|
return nil, fmt.Errorf("update login: %w", err)
|
|
}
|
|
|
|
return s.generateTokens(user)
|
|
}
|
|
|
|
func (s *Service) RegisterWithPassword(ctx context.Context, email, password, name string) (*TokenPair, error) {
|
|
existing, err := s.db.GetUserByEmail(ctx, email)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("check existing: %w", err)
|
|
}
|
|
if existing != nil {
|
|
return nil, fmt.Errorf("email already registered")
|
|
}
|
|
|
|
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("hash password: %w", err)
|
|
}
|
|
|
|
hashStr := string(hash)
|
|
user := &db.User{
|
|
Email: email,
|
|
Name: &name,
|
|
PasswordHash: &hashStr,
|
|
Provider: "email",
|
|
EmailVerified: false,
|
|
}
|
|
|
|
user, err = s.db.CreateUser(ctx, user)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create user: %w", err)
|
|
}
|
|
|
|
return s.generateTokens(user)
|
|
}
|
|
|
|
func (s *Service) LoginWithPassword(ctx context.Context, email, password string) (*TokenPair, error) {
|
|
user, err := s.db.GetUserByEmail(ctx, email)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("get user: %w", err)
|
|
}
|
|
if user == nil || user.PasswordHash == nil {
|
|
return nil, fmt.Errorf("invalid credentials")
|
|
}
|
|
|
|
if err := bcrypt.CompareHashAndPassword([]byte(*user.PasswordHash), []byte(password)); err != nil {
|
|
return nil, fmt.Errorf("invalid credentials")
|
|
}
|
|
|
|
if err := s.db.UpdateLastLogin(ctx, user.ID); err != nil {
|
|
return nil, fmt.Errorf("update login: %w", err)
|
|
}
|
|
|
|
return s.generateTokens(user)
|
|
}
|
|
|
|
func (s *Service) generateTokens(user *db.User) (*TokenPair, error) {
|
|
now := time.Now()
|
|
return s.generateTokensAt(user, now)
|
|
}
|
|
|
|
func (s *Service) generateTokensAt(user *db.User, now time.Time) (*TokenPair, error) {
|
|
name := ""
|
|
if user.Name != nil {
|
|
name = *user.Name
|
|
}
|
|
|
|
accessTokenString, err := s.signToken(user, name, "access", now, accessTokenTTL)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("sign access token: %w", err)
|
|
}
|
|
|
|
refreshTokenString, err := s.signToken(user, name, "refresh", now, refreshTokenTTL)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("sign refresh token: %w", err)
|
|
}
|
|
|
|
return &TokenPair{
|
|
AccessToken: accessTokenString,
|
|
RefreshToken: refreshTokenString,
|
|
TokenType: "Bearer",
|
|
ExpiresIn: int(accessTokenTTL.Seconds()),
|
|
}, nil
|
|
}
|
|
|
|
func (s *Service) VerifyToken(tokenString string) (*Claims, error) {
|
|
return s.verifyTokenOfType(tokenString, "access")
|
|
}
|
|
|
|
func (s *Service) VerifyRefreshToken(tokenString string) (*Claims, error) {
|
|
return s.verifyTokenOfType(tokenString, "refresh")
|
|
}
|
|
|
|
func (s *Service) RefreshTokens(ctx context.Context, refreshToken string) (*TokenPair, error) {
|
|
claims, err := s.VerifyRefreshToken(refreshToken)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
user := &db.User{
|
|
ID: uuid.MustParse(claims.UserID),
|
|
Email: claims.Email,
|
|
}
|
|
if claims.Name != "" {
|
|
user.Name = &claims.Name
|
|
}
|
|
|
|
if s.db != nil {
|
|
storedUser, err := s.db.GetUserByID(ctx, user.ID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("get user: %w", err)
|
|
}
|
|
if storedUser == nil {
|
|
return nil, fmt.Errorf("user not found")
|
|
}
|
|
user = storedUser
|
|
}
|
|
|
|
return s.generateTokens(user)
|
|
}
|
|
|
|
func (s *Service) verifyTokenOfType(tokenString string, expectedType string) (*Claims, error) {
|
|
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
|
|
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
|
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
|
}
|
|
return s.jwtSecret, nil
|
|
})
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if claims, ok := token.Claims.(*Claims); ok && token.Valid {
|
|
if claims.Type != expectedType {
|
|
return nil, fmt.Errorf("invalid token type")
|
|
}
|
|
return claims, nil
|
|
}
|
|
|
|
return nil, fmt.Errorf("invalid token")
|
|
}
|
|
|
|
func (s *Service) signToken(user *db.User, name string, tokenType string, now time.Time, ttl time.Duration) (string, error) {
|
|
claims := Claims{
|
|
UserID: user.ID.String(),
|
|
Email: user.Email,
|
|
Name: name,
|
|
Role: "authenticated",
|
|
Type: tokenType,
|
|
RegisteredClaims: jwt.RegisteredClaims{
|
|
Issuer: "bookra-auth",
|
|
Subject: user.ID.String(),
|
|
Audience: jwt.ClaimStrings{"bookra"},
|
|
ID: generateRandomToken(12),
|
|
ExpiresAt: jwt.NewNumericDate(now.Add(ttl)),
|
|
IssuedAt: jwt.NewNumericDate(now),
|
|
},
|
|
}
|
|
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|
return token.SignedString(s.jwtSecret)
|
|
}
|
|
|
|
// generateRandomToken returns length bytes from crypto/rand, encoded as
// URL-safe base64 with "=" padding. For length 32 the result is 44 characters.
//
// These values serve as secrets (magic-link tokens, JWT IDs). If the system
// random source fails, the function panics rather than return an encoding of
// a zero-filled, predictable buffer.
func generateRandomToken(length int) string {
	b := make([]byte, length)
	if _, err := rand.Read(b); err != nil {
		panic(fmt.Sprintf("auth: crypto/rand unavailable: %v", err))
	}
	return base64.URLEncoding.EncodeToString(b)
}
|