package auth import ( "context" "errors" "fmt" "strings" "time" "github.com/golang-jwt/jwt/v5" "github.com/google/uuid" "github.com/tdvorak/seen/backend/internal/config" "github.com/tdvorak/seen/backend/internal/domain" "github.com/tdvorak/seen/backend/internal/repositories/postgres" "go.uber.org/zap" "golang.org/x/crypto/bcrypt" ) var ( ErrInvalidInput = errors.New("invalid input") ErrInvalidCredentials = errors.New("invalid credentials") ErrEmailTaken = errors.New("email already exists") ErrInvalidSession = errors.New("invalid session") ErrInvalidToken = errors.New("invalid token") ) type Repository interface { CreateUser(ctx context.Context, user domain.User) error FindUserByEmail(ctx context.Context, email string) (*domain.User, error) FindUserByID(ctx context.Context, userID uuid.UUID) (*domain.User, error) CreateSession(ctx context.Context, session domain.Session) error FindSessionByRefreshToken(ctx context.Context, refreshToken string) (*domain.Session, error) RevokeSession(ctx context.Context, sessionID uuid.UUID) error } type Service struct { repo Repository cfg config.AuthConfig log *zap.Logger } type RegisterInput struct { Email string Password string DisplayName string } type LoginInput struct { Email string Password string UserAgent string IP string } type RefreshInput struct { RefreshToken string UserAgent string IP string } type AuthResult struct { AccessToken string `json:"accessToken"` RefreshToken string `json:"refreshToken"` ExpiresAt time.Time `json:"expiresAt"` User domain.User `json:"user"` } func NewService(repo Repository, cfg config.AuthConfig, log *zap.Logger) *Service { return &Service{repo: repo, cfg: cfg, log: log} } func (s *Service) Register(ctx context.Context, input RegisterInput) (*AuthResult, error) { email := strings.ToLower(strings.TrimSpace(input.Email)) if email == "" || len(input.Password) < 8 { return nil, ErrInvalidInput } displayName := strings.TrimSpace(input.DisplayName) if displayName == "" { displayName = strings.Split(email, "@")[0] } passwordHash, err := bcrypt.GenerateFromPassword([]byte(input.Password), bcrypt.DefaultCost) if err != nil { return nil, fmt.Errorf("hash password: %w", err) } now := time.Now().UTC() user := domain.User{ ID: uuid.New(), Email: email, DisplayName: displayName, Role: domain.RoleUser, PasswordHash: string(passwordHash), CreatedAt: now, UpdatedAt: now, } if err := s.repo.CreateUser(ctx, user); err != nil { if errors.Is(err, postgres.ErrUserAlreadyExists) { return nil, ErrEmailTaken } return nil, fmt.Errorf("create user: %w", err) } return s.createTokens(ctx, user, "", "") } func (s *Service) Login(ctx context.Context, input LoginInput) (*AuthResult, error) { email := strings.ToLower(strings.TrimSpace(input.Email)) if email == "" || input.Password == "" { return nil, ErrInvalidInput } user, err := s.repo.FindUserByEmail(ctx, email) if err != nil { return nil, fmt.Errorf("find user: %w", err) } if user == nil { return nil, ErrInvalidCredentials } if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(input.Password)); err != nil { return nil, ErrInvalidCredentials } return s.createTokens(ctx, *user, input.UserAgent, input.IP) } func (s *Service) Refresh(ctx context.Context, input RefreshInput) (*AuthResult, error) { if strings.TrimSpace(input.RefreshToken) == "" { return nil, ErrInvalidInput } session, err := s.repo.FindSessionByRefreshToken(ctx, input.RefreshToken) if err != nil { return nil, fmt.Errorf("find session: %w", err) } if session == nil || session.RevokedAt != nil || session.ExpiresAt.Before(time.Now().UTC()) { return nil, ErrInvalidSession } if err := s.repo.RevokeSession(ctx, session.ID); err != nil { return nil, fmt.Errorf("revoke session: %w", err) } user, err := s.repo.FindUserByID(ctx, session.UserID) if err != nil { return nil, fmt.Errorf("find user: %w", err) } if user == nil { return nil, ErrInvalidSession } user.PasswordHash = "" return s.createTokens(ctx, *user, input.UserAgent, input.IP) } func (s *Service) UserFromAccessToken(ctx context.Context, accessToken string) (*domain.User, error) { token := strings.TrimSpace(accessToken) if token == "" { return nil, ErrInvalidToken } parsed, err := jwt.Parse(token, func(token *jwt.Token) (any, error) { if token.Method != jwt.SigningMethodHS256 { return nil, ErrInvalidToken } return []byte(s.cfg.JWTSecret), nil }) if err != nil || !parsed.Valid { return nil, ErrInvalidToken } claims, ok := parsed.Claims.(jwt.MapClaims) if !ok { return nil, ErrInvalidToken } subject, ok := claims["sub"].(string) if !ok || strings.TrimSpace(subject) == "" { return nil, ErrInvalidToken } userID, err := uuid.Parse(subject) if err != nil { return nil, ErrInvalidToken } user, err := s.repo.FindUserByID(ctx, userID) if err != nil { return nil, fmt.Errorf("find user by token: %w", err) } if user == nil { return nil, ErrInvalidToken } user.PasswordHash = "" return user, nil } func (s *Service) createTokens( ctx context.Context, user domain.User, userAgent string, ip string, ) (*AuthResult, error) { accessToken, expiresAt, err := s.signAccessToken(user) if err != nil { return nil, err } session := domain.Session{ ID: uuid.New(), UserID: user.ID, RefreshToken: uuid.NewString(), UserAgent: strings.TrimSpace(userAgent), IP: strings.TrimSpace(ip), ExpiresAt: time.Now().UTC().Add(time.Duration(s.cfg.RefreshTokenTTLHours) * time.Hour), CreatedAt: time.Now().UTC(), } if err := s.repo.CreateSession(ctx, session); err != nil { return nil, fmt.Errorf("create session: %w", err) } user.PasswordHash = "" return &AuthResult{ AccessToken: accessToken, RefreshToken: session.RefreshToken, ExpiresAt: expiresAt, User: user, }, nil } func (s *Service) signAccessToken(user domain.User) (string, time.Time, error) { expiresAt := time.Now().UTC().Add(time.Duration(s.cfg.AccessTokenTTLMinutes) * time.Minute) claims := jwt.MapClaims{ "sub": user.ID.String(), "role": string(user.Role), "exp": expiresAt.Unix(), "iat": time.Now().UTC().Unix(), } token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) signed, err := token.SignedString([]byte(s.cfg.JWTSecret)) if err != nil { return "", time.Time{}, fmt.Errorf("sign access token: %w", err) } return signed, expiresAt, nil }