205 lines
5.8 KiB
Go
205 lines
5.8 KiB
Go
package stmtcache
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"sync"
|
|
)
|
|
|
|
// DB is a wrapper for sql.DB, providing an additional layer for caching prepared statements
|
|
// to optimize database interactions and improve performance.
|
|
type DB struct {
|
|
*sql.DB
|
|
|
|
cachingEnabled bool
|
|
|
|
lock sync.RWMutex
|
|
statements map[string]*sql.Stmt
|
|
}
|
|
|
|
// New creates new DB wrapper with statements caching enabled
|
|
func New(db *sql.DB) *DB {
|
|
return &DB{
|
|
DB: db,
|
|
cachingEnabled: true,
|
|
statements: make(map[string]*sql.Stmt),
|
|
}
|
|
}
|
|
|
|
// SetCaching returns *DB wrapper with prepared statements caching enabled or disabled. This method should be
|
|
// called only once. It is not concurrency-safe.
|
|
func (d *DB) SetCaching(enabled bool) *DB {
|
|
d.cachingEnabled = enabled
|
|
return d
|
|
}
|
|
|
|
// CachingEnabled returns true if statements caching is enabled
|
|
func (d *DB) CachingEnabled() bool {
|
|
return d.cachingEnabled
|
|
}
|
|
|
|
// CacheSize returns the current number of prepared statements stored in the cache.
|
|
func (d *DB) CacheSize() int {
|
|
d.lock.RLock()
|
|
ret := len(d.statements)
|
|
d.lock.RUnlock()
|
|
return ret
|
|
}
|
|
|
|
// Begin starts a new SQL transaction and returns a Tx object with statement caching capabilities.
|
|
func (d *DB) Begin() (*Tx, error) {
|
|
tx, err := d.DB.Begin()
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &Tx{
|
|
Tx: tx,
|
|
db: d,
|
|
statements: make(map[string]*sql.Stmt),
|
|
}, nil
|
|
}
|
|
|
|
// BeginTx starts a new SQL transaction and returns a Tx object with statement caching capabilities.
|
|
func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) {
|
|
tx, err := d.DB.BeginTx(ctx, opts)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &Tx{
|
|
Tx: tx,
|
|
db: d,
|
|
statements: make(map[string]*sql.Stmt),
|
|
}, nil
|
|
}
|
|
|
|
// Exec executes a query that doesn't return rows. Exec delegates call to ExecContext with contex.Background()
|
|
// as parameter.
|
|
func (d *DB) Exec(query string, args ...interface{}) (sql.Result, error) {
|
|
return d.ExecContext(context.Background(), query, args...)
|
|
}
|
|
|
|
// ExecContext executes a query that doesn't return rows. If statement caching is enabled, ExecContext will
|
|
// first call PrepareContext to retrieve a prepared statement, and then execute a query using a prepared statement.
|
|
// If statement caching is disabled, this method delegates the call to the *sql.DB ExecContext method.
|
|
func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
|
if !d.cachingEnabled {
|
|
return d.DB.ExecContext(ctx, query, args...)
|
|
}
|
|
|
|
prepStmt, err := d.PrepareContext(ctx, query)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return prepStmt.ExecContext(ctx, args...)
|
|
}
|
|
|
|
// Query delegates call to QueryContext using context.Background() as parameter.
|
|
func (d *DB) Query(query string, args ...interface{}) (*sql.Rows, error) {
|
|
return d.QueryContext(context.Background(), query, args...)
|
|
}
|
|
|
|
// QueryContext executes a query that returns rows. If statement caching is enabled, QueryContext will
|
|
// first call PrepareContext to retrieve a prepared statement, and then execute a query using a prepared statement.
|
|
// If statement caching is disabled, this method delegates the call to the *sql.DB QueryContext method.
|
|
func (d *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
|
|
if !d.cachingEnabled {
|
|
return d.DB.QueryContext(ctx, query, args...)
|
|
}
|
|
|
|
prepStmt, err := d.PrepareContext(ctx, query)
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return prepStmt.QueryContext(ctx, args...)
|
|
}
|
|
|
|
// Prepare delegates call to PrepareContext using context.Background as a parameter.
|
|
func (d *DB) Prepare(query string) (*sql.Stmt, error) {
|
|
return d.PrepareContext(context.Background(), query)
|
|
}
|
|
|
|
// PrepareContext returns database prepared statement for a query. When statement caching is enabled, it returns a cached
|
|
// prepared statement if available; otherwise, it creates a new prepared statement and adds it to the cache.
|
|
// Invoking this method directly is unnecessary, as wrapper methods like Exec/ExecContext and Query/QueryContext
|
|
// will call PrepareContext before executing a query on it.
|
|
// If statement caching is disabled, this method delegates the call to the *sql.DB PrepareContext method.
|
|
//
|
|
// There's no need to manually close the returned statement; it operates within the transaction scope and will be closed
|
|
// automatically upon the completion of the transaction, whether it's committed or rolled back.
|
|
func (d *DB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
|
|
if !d.cachingEnabled {
|
|
return d.DB.PrepareContext(ctx, query)
|
|
}
|
|
|
|
d.lock.RLock()
|
|
prepStmt, ok := d.statements[query]
|
|
d.lock.RUnlock()
|
|
|
|
if ok {
|
|
return prepStmt, nil
|
|
}
|
|
|
|
prepStmt, err := d.DB.PrepareContext(ctx, query)
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to prepare statement %s: %w", query, err)
|
|
}
|
|
|
|
d.lock.Lock()
|
|
|
|
existingPrepStmt, exist := d.statements[query]
|
|
|
|
// if in the meantime, another goroutine created prepared statements for this query, we will close this
|
|
// prepared statement and return the existing one.
|
|
if exist {
|
|
_ = prepStmt.Close()
|
|
d.lock.Unlock()
|
|
return existingPrepStmt, nil
|
|
}
|
|
|
|
d.statements[query] = prepStmt
|
|
d.lock.Unlock()
|
|
return prepStmt, nil
|
|
}
|
|
|
|
// ClearCache will close all cached prepared statements and clear statements cache map
|
|
func (d *DB) ClearCache() error {
|
|
d.lock.Lock()
|
|
defer d.lock.Unlock()
|
|
|
|
var err error
|
|
|
|
for _, statement := range d.statements {
|
|
closeErr := statement.Close()
|
|
|
|
if closeErr != nil {
|
|
err = errors.Join(err, closeErr)
|
|
}
|
|
}
|
|
|
|
d.statements = make(map[string]*sql.Stmt)
|
|
|
|
if err != nil {
|
|
return errors.Join(errors.New("jet: some of the prepared statements failed to close"), err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Close will clear the statements cache and close the underlying db connection
|
|
func (d *DB) Close() error {
|
|
clearErr := d.ClearCache()
|
|
closeErr := d.DB.Close()
|
|
|
|
return errors.Join(clearErr, closeErr)
|
|
}
|