mirror of
https://github.com/Dvorinka/Bookra.git
synced 2026-06-04 20:43:01 +00:00
cleanup
This commit is contained in:
@@ -0,0 +1,333 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"bookra/apps/auth-service/internal/db"
|
||||
"bookra/apps/auth-service/internal/email"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
const (
|
||||
accessTokenTTL = 24 * time.Hour
|
||||
refreshTokenTTL = 30 * 24 * time.Hour
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
db *db.DB
|
||||
email *email.Service
|
||||
jwtSecret []byte
|
||||
frontendURL string
|
||||
}
|
||||
|
||||
type TokenPair struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
|
||||
type Claims struct {
|
||||
UserID string `json:"sub"`
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name,omitempty"`
|
||||
Role string `json:"role,omitempty"`
|
||||
Type string `json:"type"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
func NewService(database *db.DB, emailSvc *email.Service, jwtSecret string, frontendURL string) *Service {
|
||||
return &Service{
|
||||
db: database,
|
||||
email: emailSvc,
|
||||
jwtSecret: []byte(jwtSecret),
|
||||
frontendURL: frontendURL,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) GenerateMagicLink(ctx context.Context, emailAddr string, locale string) error {
|
||||
user, err := s.db.GetUserByEmail(ctx, emailAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
|
||||
if user == nil {
|
||||
user = &db.User{
|
||||
Email: emailAddr,
|
||||
Provider: "email",
|
||||
}
|
||||
user, err = s.db.CreateUser(ctx, user)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create user: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
token := generateRandomToken(32)
|
||||
expiresAt := time.Now().Add(15 * time.Minute)
|
||||
|
||||
if err := s.db.CreateMagicLink(ctx, token, emailAddr, user.ID, expiresAt); err != nil {
|
||||
return fmt.Errorf("create magic link: %w", err)
|
||||
}
|
||||
|
||||
magicURL := fmt.Sprintf("%s/auth/callback?token=%s", s.frontendURL, token)
|
||||
|
||||
var name string
|
||||
if user.Name != nil {
|
||||
name = *user.Name
|
||||
}
|
||||
|
||||
if err := s.email.SendMagicLink(emailAddr, name, magicURL, locale); err != nil {
|
||||
return fmt.Errorf("send email: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) VerifyMagicLink(ctx context.Context, token string) (*TokenPair, error) {
|
||||
ml, err := s.db.GetMagicLink(ctx, token)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get magic link: %w", err)
|
||||
}
|
||||
|
||||
if ml == nil || ml.Used {
|
||||
return nil, fmt.Errorf("invalid or used token")
|
||||
}
|
||||
|
||||
if time.Now().After(ml.ExpiresAt) {
|
||||
return nil, fmt.Errorf("token expired")
|
||||
}
|
||||
|
||||
if err := s.db.MarkMagicLinkUsed(ctx, token); err != nil {
|
||||
return nil, fmt.Errorf("mark used: %w", err)
|
||||
}
|
||||
|
||||
user, err := s.db.GetUserByID(ctx, ml.UserID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
|
||||
if user == nil {
|
||||
return nil, fmt.Errorf("user not found")
|
||||
}
|
||||
|
||||
if err := s.db.UpdateLastLogin(ctx, user.ID); err != nil {
|
||||
return nil, fmt.Errorf("update login: %w", err)
|
||||
}
|
||||
|
||||
return s.generateTokens(user)
|
||||
}
|
||||
|
||||
func (s *Service) OAuthLoginOrCreate(ctx context.Context, provider, providerID, email, name string) (*TokenPair, error) {
|
||||
user, err := s.db.GetUserByProviderID(ctx, provider, providerID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get user by provider: %w", err)
|
||||
}
|
||||
|
||||
if user == nil {
|
||||
existing, err := s.db.GetUserByEmail(ctx, email)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("check existing email: %w", err)
|
||||
}
|
||||
|
||||
if existing != nil {
|
||||
existing.Provider = provider
|
||||
existing.ProviderID = &providerID
|
||||
existing.Name = &name
|
||||
existing.EmailVerified = true
|
||||
if err := s.db.UpdateUser(ctx, existing); err != nil {
|
||||
return nil, fmt.Errorf("link provider: %w", err)
|
||||
}
|
||||
user = existing
|
||||
} else {
|
||||
user = &db.User{
|
||||
Email: email,
|
||||
Name: &name,
|
||||
Provider: provider,
|
||||
ProviderID: &providerID,
|
||||
EmailVerified: true,
|
||||
}
|
||||
user, err = s.db.CreateUser(ctx, user)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create oauth user: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.db.UpdateLastLogin(ctx, user.ID); err != nil {
|
||||
return nil, fmt.Errorf("update login: %w", err)
|
||||
}
|
||||
|
||||
return s.generateTokens(user)
|
||||
}
|
||||
|
||||
func (s *Service) RegisterWithPassword(ctx context.Context, email, password, name string) (*TokenPair, error) {
|
||||
existing, err := s.db.GetUserByEmail(ctx, email)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("check existing: %w", err)
|
||||
}
|
||||
if existing != nil {
|
||||
return nil, fmt.Errorf("email already registered")
|
||||
}
|
||||
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("hash password: %w", err)
|
||||
}
|
||||
|
||||
hashStr := string(hash)
|
||||
user := &db.User{
|
||||
Email: email,
|
||||
Name: &name,
|
||||
PasswordHash: &hashStr,
|
||||
Provider: "email",
|
||||
EmailVerified: false,
|
||||
}
|
||||
|
||||
user, err = s.db.CreateUser(ctx, user)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create user: %w", err)
|
||||
}
|
||||
|
||||
return s.generateTokens(user)
|
||||
}
|
||||
|
||||
func (s *Service) LoginWithPassword(ctx context.Context, email, password string) (*TokenPair, error) {
|
||||
user, err := s.db.GetUserByEmail(ctx, email)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
if user == nil || user.PasswordHash == nil {
|
||||
return nil, fmt.Errorf("invalid credentials")
|
||||
}
|
||||
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(*user.PasswordHash), []byte(password)); err != nil {
|
||||
return nil, fmt.Errorf("invalid credentials")
|
||||
}
|
||||
|
||||
if err := s.db.UpdateLastLogin(ctx, user.ID); err != nil {
|
||||
return nil, fmt.Errorf("update login: %w", err)
|
||||
}
|
||||
|
||||
return s.generateTokens(user)
|
||||
}
|
||||
|
||||
func (s *Service) generateTokens(user *db.User) (*TokenPair, error) {
|
||||
now := time.Now()
|
||||
return s.generateTokensAt(user, now)
|
||||
}
|
||||
|
||||
func (s *Service) generateTokensAt(user *db.User, now time.Time) (*TokenPair, error) {
|
||||
name := ""
|
||||
if user.Name != nil {
|
||||
name = *user.Name
|
||||
}
|
||||
|
||||
accessTokenString, err := s.signToken(user, name, "access", now, accessTokenTTL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sign access token: %w", err)
|
||||
}
|
||||
|
||||
refreshTokenString, err := s.signToken(user, name, "refresh", now, refreshTokenTTL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sign refresh token: %w", err)
|
||||
}
|
||||
|
||||
return &TokenPair{
|
||||
AccessToken: accessTokenString,
|
||||
RefreshToken: refreshTokenString,
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: int(accessTokenTTL.Seconds()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Service) VerifyToken(tokenString string) (*Claims, error) {
|
||||
return s.verifyTokenOfType(tokenString, "access")
|
||||
}
|
||||
|
||||
func (s *Service) VerifyRefreshToken(tokenString string) (*Claims, error) {
|
||||
return s.verifyTokenOfType(tokenString, "refresh")
|
||||
}
|
||||
|
||||
func (s *Service) RefreshTokens(ctx context.Context, refreshToken string) (*TokenPair, error) {
|
||||
claims, err := s.VerifyRefreshToken(refreshToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user := &db.User{
|
||||
ID: uuid.MustParse(claims.UserID),
|
||||
Email: claims.Email,
|
||||
}
|
||||
if claims.Name != "" {
|
||||
user.Name = &claims.Name
|
||||
}
|
||||
|
||||
if s.db != nil {
|
||||
storedUser, err := s.db.GetUserByID(ctx, user.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
if storedUser == nil {
|
||||
return nil, fmt.Errorf("user not found")
|
||||
}
|
||||
user = storedUser
|
||||
}
|
||||
|
||||
return s.generateTokens(user)
|
||||
}
|
||||
|
||||
func (s *Service) verifyTokenOfType(tokenString string, expectedType string) (*Claims, error) {
|
||||
token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
|
||||
}
|
||||
return s.jwtSecret, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if claims, ok := token.Claims.(*Claims); ok && token.Valid {
|
||||
if claims.Type != expectedType {
|
||||
return nil, fmt.Errorf("invalid token type")
|
||||
}
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("invalid token")
|
||||
}
|
||||
|
||||
func (s *Service) signToken(user *db.User, name string, tokenType string, now time.Time, ttl time.Duration) (string, error) {
|
||||
claims := Claims{
|
||||
UserID: user.ID.String(),
|
||||
Email: user.Email,
|
||||
Name: name,
|
||||
Role: "authenticated",
|
||||
Type: tokenType,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Issuer: "bookra-auth",
|
||||
Subject: user.ID.String(),
|
||||
Audience: jwt.ClaimStrings{"bookra"},
|
||||
ID: generateRandomToken(12),
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(ttl)),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
return token.SignedString(s.jwtSecret)
|
||||
}
|
||||
|
||||
func generateRandomToken(length int) string {
|
||||
b := make([]byte, length)
|
||||
rand.Read(b)
|
||||
return base64.URLEncoding.EncodeToString(b)
|
||||
}
|
||||
Reference in New Issue
Block a user