// Package cache provides a small JSON-serializing wrapper around a Redis
// client.
package cache

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"time"

	"github.com/redis/go-redis/v9"
)

// Sentinel errors returned by Service. Match them with errors.Is.
var (
	// ErrCacheMiss reports that the requested key held no value.
	ErrCacheMiss = errors.New("cache miss")
	// ErrCacheSet reports that writing a value to Redis failed.
	ErrCacheSet = errors.New("cache set failed")
)

// setError is the error returned by Set when the Redis write fails. It
// matches ErrCacheSet under errors.Is while still exposing the underlying
// Redis/context error through Unwrap, so callers can inspect both.
type setError struct {
	cause error
}

// Error renders "<ErrCacheSet message>: <cause>".
func (e *setError) Error() string {
	return ErrCacheSet.Error() + ": " + e.cause.Error()
}

// Is reports whether target is ErrCacheSet.
func (e *setError) Is(target error) bool {
	return target == ErrCacheSet
}

// Unwrap returns the underlying error reported by the Redis client.
func (e *setError) Unwrap() error {
	return e.cause
}

// Service provides high-level caching operations on top of a Redis client.
// Values written through Set, SetNX, GetSet and MSet are JSON-encoded.
type Service struct {
	client *redis.Client
}

// NewService creates a cache service that uses the given Redis client.
func NewService(client *redis.Client) *Service {
	return &Service{client: client}
}

// Get retrieves the value stored under key and JSON-decodes it into target,
// which must be a pointer accepted by json.Unmarshal.
//
// It returns ErrCacheMiss when the client reports redis.Nil, and a wrapped
// error for any other Redis or decoding failure.
func (s *Service) Get(ctx context.Context, key string, target interface{}) error {
	data, err := s.client.Get(ctx, key).Bytes()
	if err != nil {
		if errors.Is(err, redis.Nil) {
			return ErrCacheMiss
		}
		return fmt.Errorf("cache get: %w", err)
	}

	if err := json.Unmarshal(data, target); err != nil {
		return fmt.Errorf("cache unmarshal: %w", err)
	}
	return nil
}

// Set JSON-encodes value and stores it under key with the given TTL.
//
// A failed Redis write yields an error that matches ErrCacheSet via
// errors.Is and also unwraps to the underlying client error.
func (s *Service) Set(ctx context.Context, key string, value interface{}, ttl time.Duration) error {
	data, err := json.Marshal(value)
	if err != nil {
		return fmt.Errorf("cache marshal: %w", err)
	}

	if err := s.client.Set(ctx, key, data, ttl).Err(); err != nil {
		// Keep the cause in the chain instead of flattening it to text.
		return &setError{cause: err}
	}
	return nil
}

// Delete removes the given keys from the cache. Calling it with no keys is a
// no-op that returns nil without contacting Redis.
func (s *Service) Delete(ctx context.Context, keys ...string) error {
	if len(keys) == 0 {
		return nil
	}

	if err := s.client.Del(ctx, keys...).Err(); err != nil {
		return fmt.Errorf("cache delete: %w", err)
	}
	return nil
}

// Exists reports whether key is present in the cache.
func (s *Service) Exists(ctx context.Context, key string) (bool, error) {
	count, err := s.client.Exists(ctx, key).Result()
	if err != nil {
		return false, fmt.Errorf("cache exists: %w", err)
	}
	return count > 0, nil
}

// Expire sets a TTL on key.
//
// Only the command error is checked; the boolean result of the command is
// discarded, so this method does not tell the caller whether the key existed.
func (s *Service) Expire(ctx context.Context, key string, ttl time.Duration) error {
	if err := s.client.Expire(ctx, key, ttl).Err(); err != nil {
		return fmt.Errorf("cache expire: %w", err)
	}
	return nil
}

// TTL returns the remaining time to live for key, exactly as reported by the
// client.
//
// NOTE(review): Redis uses negative replies for "no such key" and "no expiry";
// those are passed through unchanged here — confirm how callers treat them.
func (s *Service) TTL(ctx context.Context, key string) (time.Duration, error) {
	ttl, err := s.client.TTL(ctx, key).Result()
	if err != nil {
		return 0, fmt.Errorf("cache ttl: %w", err)
	}
	return ttl, nil
}

// Increment atomically increments the counter stored under key by one and
// returns the new value.
func (s *Service) Increment(ctx context.Context, key string) (int64, error) {
	val, err := s.client.Incr(ctx, key).Result()
	if err != nil {
		return 0, fmt.Errorf("cache increment: %w", err)
	}
	return val, nil
}

// IncrementBy atomically increments the counter stored under key by amount
// and returns the new value.
func (s *Service) IncrementBy(ctx context.Context, key string, amount int64) (int64, error) {
	val, err := s.client.IncrBy(ctx, key, amount).Result()
	if err != nil {
		return 0, fmt.Errorf("cache increment by: %w", err)
	}
	return val, nil
}

// SetNX JSON-encodes value and stores it under key with the given TTL only if
// the key does not already exist (useful for locks). The boolean result
// reports whether the value was written.
func (s *Service) SetNX(ctx context.Context, key string, value interface{}, ttl time.Duration) (bool, error) {
	data, err := json.Marshal(value)
	if err != nil {
		return false, fmt.Errorf("cache marshal: %w", err)
	}

	ok, err := s.client.SetNX(ctx, key, data, ttl).Result()
	if err != nil {
		return false, fmt.Errorf("cache setnx: %w", err)
	}
	return ok, nil
}

// GetSet atomically stores the JSON encoding of newValue under key and
// decodes the previous value into target.
//
// It returns ErrCacheMiss when the client reports redis.Nil for the previous
// value. NOTE(review): the write has presumably already happened in that
// case, since GETSET sets unconditionally — confirm callers expect this.
func (s *Service) GetSet(ctx context.Context, key string, newValue interface{}, target interface{}) error {
	data, err := json.Marshal(newValue)
	if err != nil {
		return fmt.Errorf("cache marshal: %w", err)
	}

	oldData, err := s.client.GetSet(ctx, key, data).Bytes()
	if err != nil {
		if errors.Is(err, redis.Nil) {
			return ErrCacheMiss
		}
		return fmt.Errorf("cache getset: %w", err)
	}

	if err := json.Unmarshal(oldData, target); err != nil {
		return fmt.Errorf("cache unmarshal: %w", err)
	}
	return nil
}

// MGet retrieves multiple keys at once.
//
// The values are returned exactly as the client yields them; unlike Get, no
// JSON decoding is performed. Calling it with no keys returns an empty,
// non-nil slice without contacting Redis.
func (s *Service) MGet(ctx context.Context, keys ...string) ([]interface{}, error) {
	if len(keys) == 0 {
		return []interface{}{}, nil
	}

	values, err := s.client.MGet(ctx, keys...).Result()
	if err != nil {
		return nil, fmt.Errorf("cache mget: %w", err)
	}
	return values, nil
}

// MSet JSON-encodes every value in pairs and stores all key-value pairs with
// a single MSET command. An empty map is a no-op that returns nil.
//
// If any value fails to encode, nothing is sent to Redis.
func (s *Service) MSet(ctx context.Context, pairs map[string]interface{}) error {
	if len(pairs) == 0 {
		return nil
	}

	// MSET takes a flat key, value, key, value, ... argument list.
	args := make([]interface{}, 0, len(pairs)*2)
	for key, value := range pairs {
		data, err := json.Marshal(value)
		if err != nil {
			return fmt.Errorf("cache marshal %s: %w", key, err)
		}
		args = append(args, key, data)
	}

	if err := s.client.MSet(ctx, args...).Err(); err != nil {
		return fmt.Errorf("cache mset: %w", err)
	}
	return nil
}

// FlushDB clears all keys in the current database (use with caution!).
func (s *Service) FlushDB(ctx context.Context) error {
	if err := s.client.FlushDB(ctx).Err(); err != nil {
		return fmt.Errorf("cache flush: %w", err)
	}
	return nil
}

// Keys returns all keys matching pattern using the KEYS command. Prefer Scan
// for large datasets.
func (s *Service) Keys(ctx context.Context, pattern string) ([]string, error) {
	keys, err := s.client.Keys(ctx, pattern).Result()
	if err != nil {
		return nil, fmt.Errorf("cache keys: %w", err)
	}
	return keys, nil
}

// Scan collects every key matching pattern by iterating SCAN until the cursor
// returns to zero (better than Keys for large datasets). count is forwarded
// to each SCAN call.
//
// Batches are appended as received and are not de-duplicated. The result is
// nil when no key matches.
func (s *Service) Scan(ctx context.Context, pattern string, count int64) ([]string, error) {
	var (
		keys   []string
		cursor uint64
	)

	for {
		batch, next, err := s.client.Scan(ctx, cursor, pattern, count).Result()
		if err != nil {
			return nil, fmt.Errorf("cache scan: %w", err)
		}
		keys = append(keys, batch...)

		// A zero cursor from the server marks the end of the iteration.
		if next == 0 {
			return keys, nil
		}
		cursor = next
	}
}

// Ping checks if the cache is responsive. The client error is returned
// unwrapped.
func (s *Service) Ping(ctx context.Context) error {
	return s.client.Ping(ctx).Err()
}

// Close closes the underlying Redis client.
func (s *Service) Close() error {
	return s.client.Close()
}

// Client returns the underlying Redis client for advanced operations.
func (s *Service) Client() *redis.Client {
	return s.client
}