Files
SEEN/backend/internal/services/auth/service.go
T
2026-04-10 12:06:24 +02:00

257 lines
6.5 KiB
Go

package auth
import (
"context"
"errors"
"fmt"
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
"github.com/tdvorak/seen/backend/internal/config"
"github.com/tdvorak/seen/backend/internal/domain"
"github.com/tdvorak/seen/backend/internal/repositories/postgres"
"go.uber.org/zap"
"golang.org/x/crypto/bcrypt"
)
// Sentinel errors returned by Service. Callers should match them with
// errors.Is rather than by comparing message text.
var (
	// ErrInvalidInput reports that a request failed basic validation, such as
	// an empty email or a password that is too short.
	ErrInvalidInput = errors.New("invalid input")

	// ErrInvalidCredentials reports that the email/password pair did not match.
	ErrInvalidCredentials = errors.New("invalid credentials")

	// ErrEmailTaken reports that registration collided with an existing account.
	ErrEmailTaken = errors.New("email already exists")

	// ErrInvalidSession reports that a refresh token is unknown, revoked,
	// expired, or belongs to a user that no longer exists.
	ErrInvalidSession = errors.New("invalid session")

	// ErrInvalidToken reports that an access token failed verification or
	// does not identify an existing user.
	ErrInvalidToken = errors.New("invalid token")
)
// Repository is the persistence contract the auth Service depends on.
//
// The service treats a (nil, nil) result from the Find* methods as "not
// found". NOTE(review): confirm the postgres implementation follows that
// convention.
type Repository interface {
	// CreateUser inserts a new user. The service maps
	// postgres.ErrUserAlreadyExists from this call to ErrEmailTaken.
	CreateUser(ctx context.Context, user domain.User) error

	// FindUserByEmail looks a user up by email. The service always passes a
	// trimmed, lower-cased address.
	FindUserByEmail(ctx context.Context, email string) (*domain.User, error)

	// FindUserByID looks a user up by its ID.
	FindUserByID(ctx context.Context, userID uuid.UUID) (*domain.User, error)

	// CreateSession persists a new refresh-token session.
	CreateSession(ctx context.Context, session domain.Session) error

	// FindSessionByRefreshToken looks a session up by its refresh token.
	FindSessionByRefreshToken(ctx context.Context, refreshToken string) (*domain.Session, error)

	// RevokeSession revokes a session so later refreshes with its token fail
	// (presumably by setting RevokedAt, which Refresh checks — confirm).
	RevokeSession(ctx context.Context, sessionID uuid.UUID) error
}
// Service implements registration, login, refresh-token rotation and
// access-token verification on top of a Repository.
type Service struct {
	repo Repository        // persistence for users and sessions
	cfg  config.AuthConfig // JWT secret and token lifetimes
	log  *zap.Logger       // NOTE(review): not read by any method in this file
}
// RegisterInput holds the caller-supplied fields for Service.Register.
type RegisterInput struct {
	Email       string // trimmed and lower-cased by Register
	Password    string // at least 8 bytes; only its bcrypt hash is stored
	DisplayName string // optional; derived from Email when blank
}
// LoginInput holds the credentials and client metadata for Service.Login.
type LoginInput struct {
	Email     string // trimmed and lower-cased before lookup
	Password  string // checked against the stored bcrypt hash
	UserAgent string // stored (whitespace-trimmed) on the new session
	IP        string // stored (whitespace-trimmed) on the new session
}
// RefreshInput holds the refresh token and client metadata for Service.Refresh.
type RefreshInput struct {
	RefreshToken string // token issued by an earlier Register/Login/Refresh
	UserAgent    string // stored (whitespace-trimmed) on the replacement session
	IP           string // stored (whitespace-trimmed) on the replacement session
}
// AuthResult is what Register, Login and Refresh return on success. Field
// order is significant: encoding/json emits the keys in this order.
type AuthResult struct {
	AccessToken  string      `json:"accessToken"`  // HS256-signed JWT
	RefreshToken string      `json:"refreshToken"` // refresh token of the newly created session
	ExpiresAt    time.Time   `json:"expiresAt"`    // expiry of AccessToken, not of RefreshToken
	User         domain.User `json:"user"`         // authenticated user with PasswordHash cleared
}
// NewService wires a Service to its repository, auth configuration and logger.
// The arguments are stored as given; nothing is validated here.
func NewService(repo Repository, cfg config.AuthConfig, log *zap.Logger) *Service {
	svc := &Service{
		repo: repo,
		cfg:  cfg,
		log:  log,
	}
	return svc
}
// Register creates a user account and immediately signs the new user in.
//
// The email is trimmed and lower-cased. The password must be 8 to 72 bytes
// long. bcrypt only processes the first 72 bytes of its input, and recent
// golang.org/x/crypto releases make GenerateFromPassword fail for longer
// input. The upper bound is therefore checked here and reported as
// ErrInvalidInput rather than surfacing as an internal hashing error.
//
// A blank DisplayName defaults to the email's local part. If the local part
// is empty (e.g. "@example.com"), it defaults to the whole email.
//
// Register returns ErrInvalidInput on validation failure and ErrEmailTaken
// when the repository reports a duplicate email. Other failures are returned
// wrapped. The session created here records no user agent or IP, because
// RegisterInput does not carry them.
func (s *Service) Register(ctx context.Context, input RegisterInput) (*AuthResult, error) {
	const (
		minPasswordBytes = 8
		maxPasswordBytes = 72 // bcrypt's input limit
	)

	email := strings.ToLower(strings.TrimSpace(input.Email))
	if email == "" || len(input.Password) < minPasswordBytes || len(input.Password) > maxPasswordBytes {
		return nil, ErrInvalidInput
	}

	displayName := strings.TrimSpace(input.DisplayName)
	if displayName == "" {
		displayName = email
		if local, _, _ := strings.Cut(email, "@"); local != "" {
			displayName = local
		}
	}

	passwordHash, err := bcrypt.GenerateFromPassword([]byte(input.Password), bcrypt.DefaultCost)
	if err != nil {
		return nil, fmt.Errorf("hash password: %w", err)
	}

	now := time.Now().UTC()
	user := domain.User{
		ID:           uuid.New(),
		Email:        email,
		DisplayName:  displayName,
		Role:         domain.RoleUser,
		PasswordHash: string(passwordHash),
		CreatedAt:    now,
		UpdatedAt:    now,
	}
	if err := s.repo.CreateUser(ctx, user); err != nil {
		if errors.Is(err, postgres.ErrUserAlreadyExists) {
			return nil, ErrEmailTaken
		}
		return nil, fmt.Errorf("create user: %w", err)
	}
	return s.createTokens(ctx, user, "", "")
}
// loginTimingDummyHash is a hand-crafted, well-formed bcrypt hash with cost 10,
// which equals bcrypt.DefaultCost, the cost Register hashes with. It is not
// derived from any known password. Login compares against it when the email
// is unknown, so the failure path still pays for a full bcrypt computation.
const loginTimingDummyHash = "$2a$10$abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0"

// Login authenticates a user by email and password. On success it issues a
// new access/refresh token pair and records the caller's user agent and IP on
// the new session.
//
// The email is trimmed and lower-cased before lookup. An empty email or
// password yields ErrInvalidInput. An unknown email and a wrong password both
// yield ErrInvalidCredentials. For an unknown email a bcrypt comparison is
// still performed, so response latency does not reveal whether the account
// exists. Repository failures are returned wrapped.
func (s *Service) Login(ctx context.Context, input LoginInput) (*AuthResult, error) {
	email := strings.ToLower(strings.TrimSpace(input.Email))
	if email == "" || input.Password == "" {
		return nil, ErrInvalidInput
	}
	user, err := s.repo.FindUserByEmail(ctx, email)
	if err != nil {
		return nil, fmt.Errorf("find user: %w", err)
	}
	if user == nil {
		// Equalize timing with the known-user path. The result is irrelevant
		// and deliberately ignored.
		_ = bcrypt.CompareHashAndPassword([]byte(loginTimingDummyHash), []byte(input.Password))
		return nil, ErrInvalidCredentials
	}
	if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(input.Password)); err != nil {
		return nil, ErrInvalidCredentials
	}
	return s.createTokens(ctx, *user, input.UserAgent, input.IP)
}
// Refresh rotates a refresh token. It validates the session behind
// input.RefreshToken, revokes it, and issues a fresh access/refresh token pair
// bound to the caller's user agent and IP.
//
// The token is whitespace-trimmed for both the emptiness check and the lookup.
// The owning user is loaded before the old session is revoked. A transient
// repository failure during that lookup therefore leaves the client's refresh
// token usable for a retry, instead of logging the client out.
//
// A session that is missing, revoked, at or past its expiry, or whose user no
// longer exists yields ErrInvalidSession. A session whose user is gone is
// still revoked. An empty token yields ErrInvalidInput.
//
// NOTE(review): lookup and revocation are separate repository calls, so two
// concurrent requests with the same token can both succeed. Closing that gap
// needs an atomic "revoke if not yet revoked" operation in the repository.
func (s *Service) Refresh(ctx context.Context, input RefreshInput) (*AuthResult, error) {
	refreshToken := strings.TrimSpace(input.RefreshToken)
	if refreshToken == "" {
		return nil, ErrInvalidInput
	}

	session, err := s.repo.FindSessionByRefreshToken(ctx, refreshToken)
	if err != nil {
		return nil, fmt.Errorf("find session: %w", err)
	}
	if session == nil || session.RevokedAt != nil || !time.Now().UTC().Before(session.ExpiresAt) {
		return nil, ErrInvalidSession
	}

	user, err := s.repo.FindUserByID(ctx, session.UserID)
	if err != nil {
		return nil, fmt.Errorf("find user: %w", err)
	}
	if err := s.repo.RevokeSession(ctx, session.ID); err != nil {
		return nil, fmt.Errorf("revoke session: %w", err)
	}
	if user == nil {
		return nil, ErrInvalidSession
	}

	user.PasswordHash = ""
	return s.createTokens(ctx, *user, input.UserAgent, input.IP)
}
// UserFromAccessToken verifies an access token and returns the user it
// identifies, with PasswordHash cleared.
//
// The token must be HS256-signed with cfg.JWTSecret and must carry an "exp"
// claim that has not passed. Its "sub" claim must be the UUID of an existing
// user. Any verification failure, including an unconfigured (empty) secret,
// yields ErrInvalidToken. Repository failures are returned wrapped.
func (s *Service) UserFromAccessToken(ctx context.Context, accessToken string) (*domain.User, error) {
	raw := strings.TrimSpace(accessToken)
	if raw == "" {
		return nil, ErrInvalidToken
	}

	keyFunc := func(*jwt.Token) (any, error) {
		// An empty HMAC key would let anyone mint tokens that verify.
		if s.cfg.JWTSecret == "" {
			return nil, ErrInvalidToken
		}
		return []byte(s.cfg.JWTSecret), nil
	}
	parsed, err := jwt.Parse(raw, keyFunc,
		// The parser rejects any other algorithm (including "none") before
		// keyFunc is consulted.
		jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Alg()}),
		// jwt/v5 validates "exp" only when present. signAccessToken always
		// sets it, so a token without one is not ours.
		jwt.WithExpirationRequired(),
	)
	if err != nil || !parsed.Valid {
		return nil, ErrInvalidToken
	}

	subject, err := parsed.Claims.GetSubject()
	if err != nil || strings.TrimSpace(subject) == "" {
		return nil, ErrInvalidToken
	}
	userID, err := uuid.Parse(subject)
	if err != nil {
		return nil, ErrInvalidToken
	}

	user, err := s.repo.FindUserByID(ctx, userID)
	if err != nil {
		return nil, fmt.Errorf("find user by token: %w", err)
	}
	if user == nil {
		return nil, ErrInvalidToken
	}
	user.PasswordHash = ""
	return user, nil
}
// createTokens signs an access token for user and persists a new session
// holding a freshly generated refresh token. It returns both tokens along with
// the user, with PasswordHash cleared on the local copy. userAgent and ip are
// whitespace-trimmed before being stored on the session.
//
// The session's CreatedAt and ExpiresAt come from a single clock reading, so
// the stored lifetime is exactly RefreshTokenTTLHours.
func (s *Service) createTokens(
	ctx context.Context,
	user domain.User,
	userAgent string,
	ip string,
) (*AuthResult, error) {
	accessToken, expiresAt, err := s.signAccessToken(user)
	if err != nil {
		return nil, err
	}

	now := time.Now().UTC()
	session := domain.Session{
		ID:     uuid.New(),
		UserID: user.ID,
		// uuid.NewString yields a random (v4) UUID. The repository looks the
		// session up by this exact string.
		RefreshToken: uuid.NewString(),
		UserAgent:    strings.TrimSpace(userAgent),
		IP:           strings.TrimSpace(ip),
		ExpiresAt:    now.Add(time.Duration(s.cfg.RefreshTokenTTLHours) * time.Hour),
		CreatedAt:    now,
	}
	if err := s.repo.CreateSession(ctx, session); err != nil {
		return nil, fmt.Errorf("create session: %w", err)
	}

	user.PasswordHash = ""
	return &AuthResult{
		AccessToken:  accessToken,
		RefreshToken: session.RefreshToken,
		ExpiresAt:    expiresAt,
		User:         user,
	}, nil
}
// signAccessToken issues an HS256 JWT for user, signed with cfg.JWTSecret. The
// token carries the claims "sub" (user ID), "role", "iat" and "exp". It returns
// the signed token and its expiry, which is AccessTokenTTLMinutes after
// issuance.
//
// iat and exp derive from one clock reading, so exp - iat equals the TTL. An
// empty secret is refused, because an empty HMAC key would make tokens
// trivially forgeable.
func (s *Service) signAccessToken(user domain.User) (string, time.Time, error) {
	if s.cfg.JWTSecret == "" {
		return "", time.Time{}, errors.New("sign access token: JWT secret is not configured")
	}

	issuedAt := time.Now().UTC()
	expiresAt := issuedAt.Add(time.Duration(s.cfg.AccessTokenTTLMinutes) * time.Minute)
	claims := jwt.MapClaims{
		"sub":  user.ID.String(),
		"role": string(user.Role),
		"exp":  expiresAt.Unix(),
		"iat":  issuedAt.Unix(),
	}

	signed, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(s.cfg.JWTSecret))
	if err != nil {
		return "", time.Time{}, fmt.Errorf("sign access token: %w", err)
	}
	return signed, expiresAt, nil
}