Files

117 lines
2.5 KiB
Go

package postgres
import (
"context"
"database/sql"
"embed"
"fmt"
"strconv"
"time"
_ "github.com/jackc/pgx/v5/stdlib"
"github.com/pressly/goose/v3"
)
// migrationFS holds the files matching migrations/*.sql, embedded into the
// binary at build time. Migrate hands it to goose via goose.SetBaseFS.
//
//go:embed migrations/*.sql
var migrationFS embed.FS
type (
	// DB embeds a *sql.DB. Its ExecContext, QueryContext and QueryRowContext
	// methods rewrite '?' placeholders via Rebind before delegating. BeginTx
	// returns a *Tx whose query methods do the same.
	//
	// NOTE(review): methods promoted from the embedded *sql.DB that are not
	// overridden in this package (e.g. Exec, Query, QueryRow, PrepareContext)
	// pass the query through unchanged.
	DB struct {
		*sql.DB
	}

	// Tx embeds a *sql.Tx. Its ExecContext, QueryContext and QueryRowContext
	// methods rewrite '?' placeholders via Rebind before delegating.
	Tx struct {
		*sql.Tx
	}
)
// Open returns a *DB that uses the "pgx" database/sql driver for databaseURL.
// It configures the pool for at most 25 open and 5 idle connections, with
// each connection recycled after 30 minutes.
//
// It returns an error if databaseURL is empty or if sql.Open rejects it; the
// latter is wrapped with %w so callers can still inspect it with errors.Is/As.
// database/sql may defer dialing until first use, so callers that need to
// fail fast should call PingContext on the returned DB.
func Open(databaseURL string) (*DB, error) {
	if databaseURL == "" {
		return nil, fmt.Errorf("DATABASE_URL is required")
	}
	db, err := sql.Open("pgx", databaseURL)
	if err != nil {
		return nil, fmt.Errorf("opening postgres connection pool: %w", err)
	}
	db.SetMaxOpenConns(25)
	db.SetMaxIdleConns(5)
	db.SetConnMaxLifetime(30 * time.Minute)
	return &DB{DB: db}, nil
}
// Migrate applies pending goose migrations from the embedded "migrations"
// directory (migrationFS) to db, using goose's "postgres" dialect.
//
// Errors from selecting the dialect or applying migrations are wrapped with
// %w and the failing step.
//
// NOTE(review): goose.SetBaseFS and goose.SetDialect are package-level
// functions, presumably mutating goose-wide state; concurrent Migrate calls
// or other goose users in the same process may interfere — confirm.
func Migrate(ctx context.Context, db *sql.DB) error {
	goose.SetBaseFS(migrationFS)
	if err := goose.SetDialect("postgres"); err != nil {
		return fmt.Errorf("setting goose dialect: %w", err)
	}
	if err := goose.UpContext(ctx, db, "migrations"); err != nil {
		return fmt.Errorf("applying migrations: %w", err)
	}
	return nil
}
// ExecContext executes a statement that returns no rows on the underlying
// pool. Before delegating, it rewrites query's '?' placeholders with Rebind.
func (db *DB) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
	rebound := Rebind(query)
	pool := db.DB
	return pool.ExecContext(ctx, rebound, args...)
}
// QueryContext runs a row-returning query on the underlying pool. Before
// delegating, it rewrites query's '?' placeholders with Rebind.
func (db *DB) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
	rebound := Rebind(query)
	pool := db.DB
	return pool.QueryContext(ctx, rebound, args...)
}
// QueryRowContext runs a query expected to return at most one row on the
// underlying pool. Before delegating, it rewrites query's '?' placeholders
// with Rebind. Any error is reported when the returned *sql.Row is scanned.
func (db *DB) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
	rebound := Rebind(query)
	pool := db.DB
	return pool.QueryRowContext(ctx, rebound, args...)
}
// BeginTx starts a transaction on the underlying pool with opts. It returns
// it wrapped in a *Tx so statements run inside it also go through Rebind.
// A failure from the underlying BeginTx is returned as-is with a nil *Tx.
func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) {
	inner, beginErr := db.DB.BeginTx(ctx, opts)
	if beginErr != nil {
		return nil, beginErr
	}
	wrapped := &Tx{Tx: inner}
	return wrapped, nil
}
// ExecContext executes a statement that returns no rows inside the
// transaction. Before delegating, it rewrites query's '?' placeholders
// with Rebind.
func (tx *Tx) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
	rebound := Rebind(query)
	inner := tx.Tx
	return inner.ExecContext(ctx, rebound, args...)
}
// QueryContext runs a row-returning query inside the transaction. Before
// delegating, it rewrites query's '?' placeholders with Rebind.
func (tx *Tx) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
	rebound := Rebind(query)
	inner := tx.Tx
	return inner.QueryContext(ctx, rebound, args...)
}
// QueryRowContext runs a query expected to return at most one row inside the
// transaction. Before delegating, it rewrites query's '?' placeholders with
// Rebind. Any error is reported when the returned *sql.Row is scanned.
func (tx *Tx) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
	rebound := Rebind(query)
	inner := tx.Tx
	return inner.QueryRowContext(ctx, rebound, args...)
}
// Rebind rewrites '?' bind placeholders in query into PostgreSQL's positional
// form ($1, $2, ...), numbering them from left to right.
//
// A '?' is left untouched when it appears in any of these places:
//   - inside a single-quoted string literal or a double-quoted identifier
//     (a doubled quote such as '' or "" is an escaped quote, not the end);
//   - inside a "--" line comment;
//   - inside a "/* ... */" block comment (PostgreSQL block comments nest).
//
// Quote characters inside comments do not open literals. When query contains
// no placeholder to rewrite, it is returned as-is without allocating.
//
// Limitations: dollar-quoted strings ($tag$...$tag$) and backslash escapes in
// E'...' strings are not recognised, and a '?' placeholder cannot be told
// apart from the jsonb '?' operator.
func Rebind(query string) string {
	var out []byte // allocated lazily on the first placeholder
	arg := 1
	last := 0 // start of the part of query not yet copied into out
	for i := 0; i < len(query); {
		switch c := query[i]; {
		case c == '\'' || c == '"':
			i = sqlQuotedEnd(query, i)
		case c == '-' && i+1 < len(query) && query[i+1] == '-':
			i = sqlLineCommentEnd(query, i)
		case c == '/' && i+1 < len(query) && query[i+1] == '*':
			i = sqlBlockCommentEnd(query, i)
		case c == '?':
			if out == nil {
				out = make([]byte, 0, len(query)+8)
			}
			out = append(out, query[last:i]...)
			out = append(out, '$')
			out = strconv.AppendInt(out, int64(arg), 10)
			arg++
			i++
			last = i
		default:
			i++
		}
	}
	if out == nil {
		return query
	}
	out = append(out, query[last:]...)
	return string(out)
}

// sqlQuotedEnd returns the index just past the quoted run that starts at
// s[i] (a ' or " character). If the run is unterminated, it returns len(s).
func sqlQuotedEnd(s string, i int) int {
	quote := s[i]
	for j := i + 1; j < len(s); j++ {
		if s[j] != quote {
			continue
		}
		if j+1 < len(s) && s[j+1] == quote {
			j++ // doubled quote is an escaped quote; stay inside the run
			continue
		}
		return j + 1
	}
	return len(s)
}

// sqlLineCommentEnd returns the index of the newline ending the "--" comment
// that starts at s[i], or len(s) if the comment runs to the end of s.
func sqlLineCommentEnd(s string, i int) int {
	for i < len(s) && s[i] != '\n' {
		i++
	}
	return i
}

// sqlBlockCommentEnd returns the index just past the (possibly nested)
// "/* ... */" comment that starts at s[i], or len(s) if it is unterminated.
func sqlBlockCommentEnd(s string, i int) int {
	depth := 0
	for i < len(s) {
		switch {
		case s[i] == '/' && i+1 < len(s) && s[i+1] == '*':
			depth++
			i += 2
		case s[i] == '*' && i+1 < len(s) && s[i+1] == '/':
			depth--
			i += 2
			if depth == 0 {
				return i
			}
		default:
			i++
		}
	}
	return len(s)
}