mirror of
https://github.com/Dvorinka/SEEN.git
synced 2026-06-04 12:33:02 +00:00
257 lines
6.5 KiB
Go
257 lines
6.5 KiB
Go
package auth
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/golang-jwt/jwt/v5"
|
|
"github.com/google/uuid"
|
|
"github.com/tdvorak/seen/backend/internal/config"
|
|
"github.com/tdvorak/seen/backend/internal/domain"
|
|
"github.com/tdvorak/seen/backend/internal/repositories/postgres"
|
|
"go.uber.org/zap"
|
|
"golang.org/x/crypto/bcrypt"
|
|
)
|
|
|
|
// Sentinel errors returned by Service. Match them with errors.Is rather than
// comparing message text.
var (
	// ErrInvalidCredentials is returned by Login when no account matches the
	// email or the password does not match the stored hash.
	ErrInvalidCredentials = errors.New("invalid credentials")

	// ErrInvalidInput is returned when required fields are missing or fail
	// basic validation (for example, an empty email or a too-short password).
	ErrInvalidInput = errors.New("invalid input")

	// ErrEmailTaken is returned by Register when the repository reports that
	// the email is already registered.
	ErrEmailTaken = errors.New("email already exists")

	// ErrInvalidToken is returned when an access token is blank, fails
	// signature or claims validation, or names a user that does not exist.
	ErrInvalidToken = errors.New("invalid token")

	// ErrInvalidSession is returned by Refresh when the refresh token is
	// unknown, revoked, expired, or belongs to a user that no longer exists.
	ErrInvalidSession = errors.New("invalid session")
)
|
|
|
|
type Repository interface {
|
|
CreateUser(ctx context.Context, user domain.User) error
|
|
FindUserByEmail(ctx context.Context, email string) (*domain.User, error)
|
|
FindUserByID(ctx context.Context, userID uuid.UUID) (*domain.User, error)
|
|
CreateSession(ctx context.Context, session domain.Session) error
|
|
FindSessionByRefreshToken(ctx context.Context, refreshToken string) (*domain.Session, error)
|
|
RevokeSession(ctx context.Context, sessionID uuid.UUID) error
|
|
}
|
|
|
|
type Service struct {
|
|
repo Repository
|
|
cfg config.AuthConfig
|
|
log *zap.Logger
|
|
}
|
|
|
|
// RegisterInput carries the fields accepted by Service.Register.
type RegisterInput struct {
	// Email identifies the account; Register trims and lower-cases it.
	Email string
	// Password is the plaintext password; Register stores only its bcrypt hash.
	Password string
	// DisplayName is optional; when blank, Register derives one from Email.
	DisplayName string
}
|
|
|
|
// LoginInput carries the credentials and client details for Service.Login.
type LoginInput struct {
	// Email is matched after trimming and lower-casing.
	Email string
	// Password is compared against the stored bcrypt hash.
	Password string
	// UserAgent is recorded (whitespace-trimmed) on the session a successful login creates.
	UserAgent string
	// IP is recorded (whitespace-trimmed) on the session a successful login creates.
	IP string
}
|
|
|
|
// RefreshInput carries the arguments for Service.Refresh.
type RefreshInput struct {
	// RefreshToken is a value previously returned in AuthResult.RefreshToken.
	RefreshToken string
	// UserAgent is recorded (whitespace-trimmed) on the replacement session.
	UserAgent string
	// IP is recorded (whitespace-trimmed) on the replacement session.
	IP string
}
|
|
|
|
type AuthResult struct {
|
|
AccessToken string `json:"accessToken"`
|
|
RefreshToken string `json:"refreshToken"`
|
|
ExpiresAt time.Time `json:"expiresAt"`
|
|
User domain.User `json:"user"`
|
|
}
|
|
|
|
func NewService(repo Repository, cfg config.AuthConfig, log *zap.Logger) *Service {
|
|
return &Service{repo: repo, cfg: cfg, log: log}
|
|
}
|
|
|
|
func (s *Service) Register(ctx context.Context, input RegisterInput) (*AuthResult, error) {
|
|
email := strings.ToLower(strings.TrimSpace(input.Email))
|
|
if email == "" || len(input.Password) < 8 {
|
|
return nil, ErrInvalidInput
|
|
}
|
|
|
|
displayName := strings.TrimSpace(input.DisplayName)
|
|
if displayName == "" {
|
|
displayName = strings.Split(email, "@")[0]
|
|
}
|
|
|
|
passwordHash, err := bcrypt.GenerateFromPassword([]byte(input.Password), bcrypt.DefaultCost)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("hash password: %w", err)
|
|
}
|
|
|
|
now := time.Now().UTC()
|
|
user := domain.User{
|
|
ID: uuid.New(),
|
|
Email: email,
|
|
DisplayName: displayName,
|
|
Role: domain.RoleUser,
|
|
PasswordHash: string(passwordHash),
|
|
CreatedAt: now,
|
|
UpdatedAt: now,
|
|
}
|
|
|
|
if err := s.repo.CreateUser(ctx, user); err != nil {
|
|
if errors.Is(err, postgres.ErrUserAlreadyExists) {
|
|
return nil, ErrEmailTaken
|
|
}
|
|
return nil, fmt.Errorf("create user: %w", err)
|
|
}
|
|
|
|
return s.createTokens(ctx, user, "", "")
|
|
}
|
|
|
|
func (s *Service) Login(ctx context.Context, input LoginInput) (*AuthResult, error) {
|
|
email := strings.ToLower(strings.TrimSpace(input.Email))
|
|
if email == "" || input.Password == "" {
|
|
return nil, ErrInvalidInput
|
|
}
|
|
|
|
user, err := s.repo.FindUserByEmail(ctx, email)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("find user: %w", err)
|
|
}
|
|
if user == nil {
|
|
return nil, ErrInvalidCredentials
|
|
}
|
|
|
|
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(input.Password)); err != nil {
|
|
return nil, ErrInvalidCredentials
|
|
}
|
|
|
|
return s.createTokens(ctx, *user, input.UserAgent, input.IP)
|
|
}
|
|
|
|
func (s *Service) Refresh(ctx context.Context, input RefreshInput) (*AuthResult, error) {
|
|
if strings.TrimSpace(input.RefreshToken) == "" {
|
|
return nil, ErrInvalidInput
|
|
}
|
|
|
|
session, err := s.repo.FindSessionByRefreshToken(ctx, input.RefreshToken)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("find session: %w", err)
|
|
}
|
|
if session == nil || session.RevokedAt != nil || session.ExpiresAt.Before(time.Now().UTC()) {
|
|
return nil, ErrInvalidSession
|
|
}
|
|
|
|
if err := s.repo.RevokeSession(ctx, session.ID); err != nil {
|
|
return nil, fmt.Errorf("revoke session: %w", err)
|
|
}
|
|
|
|
user, err := s.repo.FindUserByID(ctx, session.UserID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("find user: %w", err)
|
|
}
|
|
if user == nil {
|
|
return nil, ErrInvalidSession
|
|
}
|
|
|
|
user.PasswordHash = ""
|
|
|
|
return s.createTokens(ctx, *user, input.UserAgent, input.IP)
|
|
}
|
|
|
|
func (s *Service) UserFromAccessToken(ctx context.Context, accessToken string) (*domain.User, error) {
|
|
token := strings.TrimSpace(accessToken)
|
|
if token == "" {
|
|
return nil, ErrInvalidToken
|
|
}
|
|
|
|
parsed, err := jwt.Parse(token, func(token *jwt.Token) (any, error) {
|
|
if token.Method != jwt.SigningMethodHS256 {
|
|
return nil, ErrInvalidToken
|
|
}
|
|
|
|
return []byte(s.cfg.JWTSecret), nil
|
|
})
|
|
if err != nil || !parsed.Valid {
|
|
return nil, ErrInvalidToken
|
|
}
|
|
|
|
claims, ok := parsed.Claims.(jwt.MapClaims)
|
|
if !ok {
|
|
return nil, ErrInvalidToken
|
|
}
|
|
|
|
subject, ok := claims["sub"].(string)
|
|
if !ok || strings.TrimSpace(subject) == "" {
|
|
return nil, ErrInvalidToken
|
|
}
|
|
|
|
userID, err := uuid.Parse(subject)
|
|
if err != nil {
|
|
return nil, ErrInvalidToken
|
|
}
|
|
|
|
user, err := s.repo.FindUserByID(ctx, userID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("find user by token: %w", err)
|
|
}
|
|
if user == nil {
|
|
return nil, ErrInvalidToken
|
|
}
|
|
|
|
user.PasswordHash = ""
|
|
return user, nil
|
|
}
|
|
|
|
func (s *Service) createTokens(
|
|
ctx context.Context,
|
|
user domain.User,
|
|
userAgent string,
|
|
ip string,
|
|
) (*AuthResult, error) {
|
|
accessToken, expiresAt, err := s.signAccessToken(user)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
session := domain.Session{
|
|
ID: uuid.New(),
|
|
UserID: user.ID,
|
|
RefreshToken: uuid.NewString(),
|
|
UserAgent: strings.TrimSpace(userAgent),
|
|
IP: strings.TrimSpace(ip),
|
|
ExpiresAt: time.Now().UTC().Add(time.Duration(s.cfg.RefreshTokenTTLHours) * time.Hour),
|
|
CreatedAt: time.Now().UTC(),
|
|
}
|
|
|
|
if err := s.repo.CreateSession(ctx, session); err != nil {
|
|
return nil, fmt.Errorf("create session: %w", err)
|
|
}
|
|
|
|
user.PasswordHash = ""
|
|
|
|
return &AuthResult{
|
|
AccessToken: accessToken,
|
|
RefreshToken: session.RefreshToken,
|
|
ExpiresAt: expiresAt,
|
|
User: user,
|
|
}, nil
|
|
}
|
|
|
|
func (s *Service) signAccessToken(user domain.User) (string, time.Time, error) {
|
|
expiresAt := time.Now().UTC().Add(time.Duration(s.cfg.AccessTokenTTLMinutes) * time.Minute)
|
|
|
|
claims := jwt.MapClaims{
|
|
"sub": user.ID.String(),
|
|
"role": string(user.Role),
|
|
"exp": expiresAt.Unix(),
|
|
"iat": time.Now().UTC().Unix(),
|
|
}
|
|
|
|
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
|
signed, err := token.SignedString([]byte(s.cfg.JWTSecret))
|
|
if err != nil {
|
|
return "", time.Time{}, fmt.Errorf("sign access token: %w", err)
|
|
}
|
|
|
|
return signed, expiresAt, nil
|
|
}
|