mirror of
https://github.com/Dvorinka/SEEN.git
synced 2026-06-04 20:43:03 +00:00
small fix, don't worry about it
This commit is contained in:
@@ -0,0 +1,256 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/tdvorak/seen/backend/internal/config"
|
||||
"github.com/tdvorak/seen/backend/internal/domain"
|
||||
"github.com/tdvorak/seen/backend/internal/repositories/postgres"
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidInput = errors.New("invalid input")
|
||||
ErrInvalidCredentials = errors.New("invalid credentials")
|
||||
ErrEmailTaken = errors.New("email already exists")
|
||||
ErrInvalidSession = errors.New("invalid session")
|
||||
ErrInvalidToken = errors.New("invalid token")
|
||||
)
|
||||
|
||||
type Repository interface {
|
||||
CreateUser(ctx context.Context, user domain.User) error
|
||||
FindUserByEmail(ctx context.Context, email string) (*domain.User, error)
|
||||
FindUserByID(ctx context.Context, userID uuid.UUID) (*domain.User, error)
|
||||
CreateSession(ctx context.Context, session domain.Session) error
|
||||
FindSessionByRefreshToken(ctx context.Context, refreshToken string) (*domain.Session, error)
|
||||
RevokeSession(ctx context.Context, sessionID uuid.UUID) error
|
||||
}
|
||||
|
||||
type Service struct {
|
||||
repo Repository
|
||||
cfg config.AuthConfig
|
||||
log *zap.Logger
|
||||
}
|
||||
|
||||
type RegisterInput struct {
|
||||
Email string
|
||||
Password string
|
||||
DisplayName string
|
||||
}
|
||||
|
||||
type LoginInput struct {
|
||||
Email string
|
||||
Password string
|
||||
UserAgent string
|
||||
IP string
|
||||
}
|
||||
|
||||
type RefreshInput struct {
|
||||
RefreshToken string
|
||||
UserAgent string
|
||||
IP string
|
||||
}
|
||||
|
||||
type AuthResult struct {
|
||||
AccessToken string `json:"accessToken"`
|
||||
RefreshToken string `json:"refreshToken"`
|
||||
ExpiresAt time.Time `json:"expiresAt"`
|
||||
User domain.User `json:"user"`
|
||||
}
|
||||
|
||||
func NewService(repo Repository, cfg config.AuthConfig, log *zap.Logger) *Service {
|
||||
return &Service{repo: repo, cfg: cfg, log: log}
|
||||
}
|
||||
|
||||
func (s *Service) Register(ctx context.Context, input RegisterInput) (*AuthResult, error) {
|
||||
email := strings.ToLower(strings.TrimSpace(input.Email))
|
||||
if email == "" || len(input.Password) < 8 {
|
||||
return nil, ErrInvalidInput
|
||||
}
|
||||
|
||||
displayName := strings.TrimSpace(input.DisplayName)
|
||||
if displayName == "" {
|
||||
displayName = strings.Split(email, "@")[0]
|
||||
}
|
||||
|
||||
passwordHash, err := bcrypt.GenerateFromPassword([]byte(input.Password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("hash password: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
user := domain.User{
|
||||
ID: uuid.New(),
|
||||
Email: email,
|
||||
DisplayName: displayName,
|
||||
Role: domain.RoleUser,
|
||||
PasswordHash: string(passwordHash),
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
|
||||
if err := s.repo.CreateUser(ctx, user); err != nil {
|
||||
if errors.Is(err, postgres.ErrUserAlreadyExists) {
|
||||
return nil, ErrEmailTaken
|
||||
}
|
||||
return nil, fmt.Errorf("create user: %w", err)
|
||||
}
|
||||
|
||||
return s.createTokens(ctx, user, "", "")
|
||||
}
|
||||
|
||||
func (s *Service) Login(ctx context.Context, input LoginInput) (*AuthResult, error) {
|
||||
email := strings.ToLower(strings.TrimSpace(input.Email))
|
||||
if email == "" || input.Password == "" {
|
||||
return nil, ErrInvalidInput
|
||||
}
|
||||
|
||||
user, err := s.repo.FindUserByEmail(ctx, email)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("find user: %w", err)
|
||||
}
|
||||
if user == nil {
|
||||
return nil, ErrInvalidCredentials
|
||||
}
|
||||
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(user.PasswordHash), []byte(input.Password)); err != nil {
|
||||
return nil, ErrInvalidCredentials
|
||||
}
|
||||
|
||||
return s.createTokens(ctx, *user, input.UserAgent, input.IP)
|
||||
}
|
||||
|
||||
func (s *Service) Refresh(ctx context.Context, input RefreshInput) (*AuthResult, error) {
|
||||
if strings.TrimSpace(input.RefreshToken) == "" {
|
||||
return nil, ErrInvalidInput
|
||||
}
|
||||
|
||||
session, err := s.repo.FindSessionByRefreshToken(ctx, input.RefreshToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("find session: %w", err)
|
||||
}
|
||||
if session == nil || session.RevokedAt != nil || session.ExpiresAt.Before(time.Now().UTC()) {
|
||||
return nil, ErrInvalidSession
|
||||
}
|
||||
|
||||
if err := s.repo.RevokeSession(ctx, session.ID); err != nil {
|
||||
return nil, fmt.Errorf("revoke session: %w", err)
|
||||
}
|
||||
|
||||
user, err := s.repo.FindUserByID(ctx, session.UserID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("find user: %w", err)
|
||||
}
|
||||
if user == nil {
|
||||
return nil, ErrInvalidSession
|
||||
}
|
||||
|
||||
user.PasswordHash = ""
|
||||
|
||||
return s.createTokens(ctx, *user, input.UserAgent, input.IP)
|
||||
}
|
||||
|
||||
func (s *Service) UserFromAccessToken(ctx context.Context, accessToken string) (*domain.User, error) {
|
||||
token := strings.TrimSpace(accessToken)
|
||||
if token == "" {
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
parsed, err := jwt.Parse(token, func(token *jwt.Token) (any, error) {
|
||||
if token.Method != jwt.SigningMethodHS256 {
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
return []byte(s.cfg.JWTSecret), nil
|
||||
})
|
||||
if err != nil || !parsed.Valid {
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
claims, ok := parsed.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
subject, ok := claims["sub"].(string)
|
||||
if !ok || strings.TrimSpace(subject) == "" {
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
userID, err := uuid.Parse(subject)
|
||||
if err != nil {
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
user, err := s.repo.FindUserByID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("find user by token: %w", err)
|
||||
}
|
||||
if user == nil {
|
||||
return nil, ErrInvalidToken
|
||||
}
|
||||
|
||||
user.PasswordHash = ""
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *Service) createTokens(
|
||||
ctx context.Context,
|
||||
user domain.User,
|
||||
userAgent string,
|
||||
ip string,
|
||||
) (*AuthResult, error) {
|
||||
accessToken, expiresAt, err := s.signAccessToken(user)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
session := domain.Session{
|
||||
ID: uuid.New(),
|
||||
UserID: user.ID,
|
||||
RefreshToken: uuid.NewString(),
|
||||
UserAgent: strings.TrimSpace(userAgent),
|
||||
IP: strings.TrimSpace(ip),
|
||||
ExpiresAt: time.Now().UTC().Add(time.Duration(s.cfg.RefreshTokenTTLHours) * time.Hour),
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
if err := s.repo.CreateSession(ctx, session); err != nil {
|
||||
return nil, fmt.Errorf("create session: %w", err)
|
||||
}
|
||||
|
||||
user.PasswordHash = ""
|
||||
|
||||
return &AuthResult{
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: session.RefreshToken,
|
||||
ExpiresAt: expiresAt,
|
||||
User: user,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Service) signAccessToken(user domain.User) (string, time.Time, error) {
|
||||
expiresAt := time.Now().UTC().Add(time.Duration(s.cfg.AccessTokenTTLMinutes) * time.Minute)
|
||||
|
||||
claims := jwt.MapClaims{
|
||||
"sub": user.ID.String(),
|
||||
"role": string(user.Role),
|
||||
"exp": expiresAt.Unix(),
|
||||
"iat": time.Now().UTC().Unix(),
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
signed, err := token.SignedString([]byte(s.cfg.JWTSecret))
|
||||
if err != nil {
|
||||
return "", time.Time{}, fmt.Errorf("sign access token: %w", err)
|
||||
}
|
||||
|
||||
return signed, expiresAt, nil
|
||||
}
|
||||
@@ -0,0 +1,165 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/tdvorak/seen/backend/internal/config"
|
||||
"github.com/tdvorak/seen/backend/internal/domain"
|
||||
"github.com/tdvorak/seen/backend/internal/repositories/postgres"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type inMemoryRepo struct {
|
||||
usersByEmail map[string]domain.User
|
||||
sessions map[string]domain.Session
|
||||
}
|
||||
|
||||
func newInMemoryRepo() *inMemoryRepo {
|
||||
return &inMemoryRepo{
|
||||
usersByEmail: make(map[string]domain.User),
|
||||
sessions: make(map[string]domain.Session),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *inMemoryRepo) CreateUser(_ context.Context, user domain.User) error {
|
||||
if _, exists := r.usersByEmail[user.Email]; exists {
|
||||
return postgres.ErrUserAlreadyExists
|
||||
}
|
||||
r.usersByEmail[user.Email] = user
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *inMemoryRepo) FindUserByEmail(_ context.Context, email string) (*domain.User, error) {
|
||||
user, exists := r.usersByEmail[email]
|
||||
if !exists {
|
||||
return nil, nil
|
||||
}
|
||||
copy := user
|
||||
return ©, nil
|
||||
}
|
||||
|
||||
func (r *inMemoryRepo) FindUserByID(_ context.Context, userID uuid.UUID) (*domain.User, error) {
|
||||
for _, user := range r.usersByEmail {
|
||||
if user.ID == userID {
|
||||
copy := user
|
||||
return ©, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (r *inMemoryRepo) CreateSession(_ context.Context, session domain.Session) error {
|
||||
r.sessions[session.RefreshToken] = session
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *inMemoryRepo) FindSessionByRefreshToken(_ context.Context, refreshToken string) (*domain.Session, error) {
|
||||
session, exists := r.sessions[refreshToken]
|
||||
if !exists {
|
||||
return nil, nil
|
||||
}
|
||||
copy := session
|
||||
return ©, nil
|
||||
}
|
||||
|
||||
func (r *inMemoryRepo) RevokeSession(_ context.Context, sessionID uuid.UUID) error {
|
||||
now := time.Now().UTC()
|
||||
for token, session := range r.sessions {
|
||||
if session.ID == sessionID {
|
||||
session.RevokedAt = &now
|
||||
r.sessions[token] = session
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return errors.New("session not found")
|
||||
}
|
||||
|
||||
func TestRegisterValidation(t *testing.T) {
|
||||
svc := NewService(newInMemoryRepo(), config.AuthConfig{AccessTokenTTLMinutes: 10, RefreshTokenTTLHours: 1, JWTSecret: "test"}, zap.NewNop())
|
||||
|
||||
_, err := svc.Register(context.Background(), RegisterInput{
|
||||
Email: "",
|
||||
Password: "short",
|
||||
})
|
||||
if !errors.Is(err, ErrInvalidInput) {
|
||||
t.Fatalf("expected invalid input error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegisterAndLoginFlow(t *testing.T) {
|
||||
repo := newInMemoryRepo()
|
||||
svc := NewService(repo, config.AuthConfig{AccessTokenTTLMinutes: 10, RefreshTokenTTLHours: 1, JWTSecret: "test"}, zap.NewNop())
|
||||
|
||||
registered, err := svc.Register(context.Background(), RegisterInput{
|
||||
Email: "user@example.com",
|
||||
Password: "password123",
|
||||
DisplayName: "Seen User",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("register failed: %v", err)
|
||||
}
|
||||
|
||||
if registered.AccessToken == "" || registered.RefreshToken == "" {
|
||||
t.Fatalf("expected issued tokens")
|
||||
}
|
||||
|
||||
loggedIn, err := svc.Login(context.Background(), LoginInput{
|
||||
Email: "user@example.com",
|
||||
Password: "password123",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("login failed: %v", err)
|
||||
}
|
||||
|
||||
if loggedIn.AccessToken == "" {
|
||||
t.Fatalf("expected login access token")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoginWrongPassword(t *testing.T) {
|
||||
repo := newInMemoryRepo()
|
||||
svc := NewService(repo, config.AuthConfig{AccessTokenTTLMinutes: 10, RefreshTokenTTLHours: 1, JWTSecret: "test"}, zap.NewNop())
|
||||
|
||||
_, err := svc.Register(context.Background(), RegisterInput{
|
||||
Email: "user@example.com",
|
||||
Password: "password123",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("register failed: %v", err)
|
||||
}
|
||||
|
||||
_, err = svc.Login(context.Background(), LoginInput{
|
||||
Email: "user@example.com",
|
||||
Password: "wrongpass",
|
||||
})
|
||||
if !errors.Is(err, ErrInvalidCredentials) {
|
||||
t.Fatalf("expected invalid credentials error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserFromAccessToken(t *testing.T) {
|
||||
repo := newInMemoryRepo()
|
||||
svc := NewService(repo, config.AuthConfig{AccessTokenTTLMinutes: 10, RefreshTokenTTLHours: 1, JWTSecret: "test"}, zap.NewNop())
|
||||
|
||||
authResult, err := svc.Register(context.Background(), RegisterInput{
|
||||
Email: "user@example.com",
|
||||
Password: "password123",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("register failed: %v", err)
|
||||
}
|
||||
|
||||
user, err := svc.UserFromAccessToken(context.Background(), authResult.AccessToken)
|
||||
if err != nil {
|
||||
t.Fatalf("user from access token failed: %v", err)
|
||||
}
|
||||
|
||||
if user.Email != "user@example.com" {
|
||||
t.Fatalf("expected user email, got %s", user.Email)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user