From 5f220569dda43c5f75b8cf54e1d049abd8505b8b Mon Sep 17 00:00:00 2001 From: go-jet Date: Sat, 19 Oct 2024 14:06:12 +0200 Subject: [PATCH] Add support for prepared statements caching. --- .circleci/config.yml | 9 +- go.mod | 2 +- internal/testutils/test_utils.go | 19 +++- mysql/statement.go | 10 -- postgres/statement.go | 10 -- sqlite/statement.go | 10 -- {internal/jet/db => stmtcache}/db.go | 63 ++++++++---- {internal/jet/db => stmtcache}/tx.go | 11 ++- tests/mysql/main_test.go | 53 +++++++--- tests/mysql/stmtcache_test.go | 128 ++++++++++++++++++++++++ tests/postgres/alltypes_test.go | 5 +- tests/postgres/main_test.go | 54 +++++++---- tests/postgres/sample_test.go | 2 +- tests/postgres/stmtcache_test.go | 139 +++++++++++++++++++++++++++ tests/postgres/values_test.go | 4 +- tests/sqlite/delete_test.go | 2 +- tests/sqlite/insert_test.go | 3 +- tests/sqlite/main_test.go | 66 ++++++++----- tests/sqlite/stmtcache_test.go | 131 +++++++++++++++++++++++++ tests/sqlite/values_test.go | 4 +- 20 files changed, 591 insertions(+), 134 deletions(-) rename {internal/jet/db => stmtcache}/db.go (72%) rename {internal/jet/db => stmtcache}/tx.go (93%) create mode 100644 tests/mysql/stmtcache_test.go create mode 100644 tests/postgres/stmtcache_test.go create mode 100644 tests/sqlite/stmtcache_test.go diff --git a/.circleci/config.yml b/.circleci/config.yml index 74f2bd8..5db44b9 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -126,15 +126,20 @@ jobs: cd tests go run ./init/init.go -testsuite all - # to create test results report - run: name: Install gotestsum command: go install gotest.tools/gotestsum@latest + + # to create test results report - run: mkdir -p $TEST_RESULTS - run: name: Running tests - command: gotestsum --junitfile $TEST_RESULTS/report.xml --format testname -- -coverprofile=cover.out -covermode=atomic -coverpkg=github.com/go-jet/jet/v2/postgres/...,github.com/go-jet/jet/v2/mysql/...,github.com/go-jet/jet/v2/sqlite/...,github.com/go-jet/jet/v2/qrm/...,github.com/go-jet/jet/v2/generator/...,github.com/go-jet/jet/v2/internal/... ./... + command: gotestsum --junitfile $TEST_RESULTS/report.xml --format testname -- -coverprofile=cover.out -covermode=atomic -coverpkg=github.com/go-jet/jet/v2/postgres/...,github.com/go-jet/jet/v2/mysql/...,github.com/go-jet/jet/v2/sqlite/...,github.com/go-jet/jet/v2/qrm/...,github.com/go-jet/jet/v2/generator/...,github.com/go-jet/jet/v2/internal/...,github.com/go-jet/jet/v2/stmtcache/... ./... + + - run: + name: Running tests with statement caching enabled + command: JET_TESTS_WITH_STMT_CACHE=true go test -v ./tests/... # run mariaDB and cockroachdb tests. No need to collect coverage, because coverage is already included with mysql and postgres tests - run: MY_SQL_SOURCE=MariaDB go test -v ./tests/mysql/ diff --git a/go.mod b/go.mod index 890dc0d..4b48c66 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/go-jet/jet/v2 -go 1.18 +go 1.20 require ( github.com/go-sql-driver/mysql v1.8.1 diff --git a/internal/testutils/test_utils.go b/internal/testutils/test_utils.go index a4a06ee..817de32 100644 --- a/internal/testutils/test_utils.go +++ b/internal/testutils/test_utils.go @@ -6,9 +6,9 @@ import ( "encoding/json" "fmt" "github.com/go-jet/jet/v2/internal/jet" - jet2 "github.com/go-jet/jet/v2/internal/jet/db" "github.com/go-jet/jet/v2/internal/utils/throw" "github.com/go-jet/jet/v2/qrm" + "github.com/go-jet/jet/v2/stmtcache" "github.com/google/go-cmp/cmp" "github.com/google/uuid" "github.com/stretchr/testify/assert" @@ -26,7 +26,7 @@ var UnixTimeComparer = cmp.Comparer(func(t1, t2 time.Time) bool { }) // AssertExecAndRollback will execute and rollback statement in sql transaction -func AssertExecAndRollback(t *testing.T, stmt jet.Statement, db *jet2.DB, rowsAffected ...int64) { +func AssertExecAndRollback(t *testing.T, stmt jet.Statement, db *stmtcache.DB, rowsAffected ...int64) { tx, err := db.Begin() require.NoError(t, err) defer func() { @@ -50,8 +50,21 @@ func AssertExec(t *testing.T, stmt jet.Statement, db qrm.DB, rowsAffected ...int } } +// AssertExecContext assert statement execution for successful execution and number of rows affected +func AssertExecContext(t *testing.T, stmt jet.Statement, ctx context.Context, db qrm.DB, rowsAffected ...int64) { + res, err := stmt.ExecContext(ctx, db) + + require.NoError(t, err) + rows, err := res.RowsAffected() + require.NoError(t, err) + + if len(rowsAffected) > 0 { + require.Equal(t, rowsAffected[0], rows) + } +} + // ExecuteInTxAndRollback will execute function in sql transaction and then rollback transaction -func ExecuteInTxAndRollback(t *testing.T, db *jet2.DB, f func(tx qrm.DB)) { +func ExecuteInTxAndRollback(t *testing.T, db *stmtcache.DB, f func(tx qrm.DB)) { tx, err := db.Begin() require.NoError(t, err) defer func() { diff --git a/mysql/statement.go b/mysql/statement.go index 008ace5..5219ffb 100644 --- a/mysql/statement.go +++ b/mysql/statement.go @@ -2,19 +2,9 @@ package mysql import ( "github.com/go-jet/jet/v2/internal/jet" - "github.com/go-jet/jet/v2/internal/jet/db" ) // RawStatement creates new sql statements from raw query and optional map of named arguments func RawStatement(rawQuery string, namedArguments ...RawArgs) jet.SerializerStatement { return jet.RawStatement(Dialect, rawQuery, namedArguments...) } - -// DB is a wrapper around sql.DB, adding prepared statement caching capability. -type DB = db.DB - -// NewDB creates new DB wrapper with statements caching disabled -var NewDB = db.NewDB - -// Tx is a wrapper around *sql.Tx, adding prepared statement caching capability. -type Tx = db.Tx diff --git a/postgres/statement.go b/postgres/statement.go index 4199fa9..be645c9 100644 --- a/postgres/statement.go +++ b/postgres/statement.go @@ -2,19 +2,9 @@ package postgres import ( "github.com/go-jet/jet/v2/internal/jet" - "github.com/go-jet/jet/v2/internal/jet/db" ) // RawStatement creates new sql statements from raw query and optional map of named arguments func RawStatement(rawQuery string, namedArguments ...RawArgs) Statement { return jet.RawStatement(Dialect, rawQuery, namedArguments...) } - -// DB is a wrapper around sql.DB, adding prepared statement caching capability. -type DB = db.DB - -// NewDB creates new DB wrapper with statements caching disabled -var NewDB = db.NewDB - -// Tx is a wrapper around *sql.Tx, adding prepared statement caching capability. -type Tx = db.Tx diff --git a/sqlite/statement.go b/sqlite/statement.go index eb701e7..3e837cf 100644 --- a/sqlite/statement.go +++ b/sqlite/statement.go @@ -2,19 +2,9 @@ package sqlite import ( "github.com/go-jet/jet/v2/internal/jet" - "github.com/go-jet/jet/v2/internal/jet/db" ) // RawStatement creates new sql statements from raw query and optional map of named arguments func RawStatement(rawQuery string, namedArguments ...RawArgs) Statement { return jet.RawStatement(Dialect, rawQuery, namedArguments...) } - -// DB is a wrapper around sql.DB, adding prepared statement caching capability. -type DB = db.DB - -// NewDB creates new DB wrapper with statements caching disabled -var NewDB = db.NewDB - -// Tx is a wrapper around *sql.Tx, adding prepared statement caching capability. -type Tx = db.Tx diff --git a/internal/jet/db/db.go b/stmtcache/db.go similarity index 72% rename from internal/jet/db/db.go rename to stmtcache/db.go index 94dfeca..a558993 100644 --- a/internal/jet/db/db.go +++ b/stmtcache/db.go @@ -1,39 +1,54 @@ -package db +package stmtcache import ( "context" "database/sql" + "errors" "fmt" "sync" ) -// DB is a wrapper around sql.DB, adding prepared statement caching capability. +// DB is a wrapper for sql.DB, providing an additional layer for caching prepared statements +// to optimize database interactions and improve performance. type DB struct { *sql.DB - statementsCaching bool + cachingEnabled bool lock sync.RWMutex statements map[string]*sql.Stmt } -// NewDB creates new DB wrapper with statements caching disabled -func NewDB(db *sql.DB) *DB { +// New creates new DB wrapper with statements caching enabled +func New(db *sql.DB) *DB { return &DB{ - DB: db, - statementsCaching: false, - statements: make(map[string]*sql.Stmt), + DB: db, + cachingEnabled: true, + statements: make(map[string]*sql.Stmt), } } -// WithStatementsCaching returns *DB wrapper with prepared statements caching enabled or disabled. This method should be +// SetCaching returns *DB wrapper with prepared statements caching enabled or disabled. This method should be // called only once. It is not concurrency-safe. -func (d *DB) WithStatementsCaching(enabled bool) *DB { - d.statementsCaching = enabled +func (d *DB) SetCaching(enabled bool) *DB { + d.cachingEnabled = enabled return d } -// Begin starts sql transaction and returns wrapped Tx object. +// CachingEnabled returns true if statements caching is enabled +func (d *DB) CachingEnabled() bool { + return d.cachingEnabled +} + +// CacheSize returns the current number of prepared statements stored in the cache. +func (d *DB) CacheSize() int { + d.lock.RLock() + ret := len(d.statements) + d.lock.RUnlock() + return ret +} + +// Begin starts a new SQL transaction and returns a Tx object with statement caching capabilities. func (d *DB) Begin() (*Tx, error) { tx, err := d.DB.Begin() @@ -48,7 +63,7 @@ func (d *DB) Begin() (*Tx, error) { }, nil } -// BeginTx starts sql transaction and returns wrapped Tx object. +// BeginTx starts a new SQL transaction and returns a Tx object with statement caching capabilities. func (d *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) { tx, err := d.DB.BeginTx(ctx, opts) @@ -73,7 +88,7 @@ func (d *DB) Exec(query string, args ...interface{}) (sql.Result, error) { // first call PrepareContext to retrieve a prepared statement, and then execute a query using a prepared statement. // If statement caching is disabled, this method delegates the call to the *sql.DB ExecContext method. func (d *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { - if !d.statementsCaching { + if !d.cachingEnabled { return d.DB.ExecContext(ctx, query, args...) } @@ -95,7 +110,7 @@ func (d *DB) Query(query string, args ...interface{}) (*sql.Rows, error) { // first call PrepareContext to retrieve a prepared statement, and then execute a query using a prepared statement. // If statement caching is disabled, this method delegates the call to the *sql.DB QueryContext method. func (d *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { - if !d.statementsCaching { + if !d.cachingEnabled { return d.DB.QueryContext(ctx, query, args...) } @@ -122,7 +137,7 @@ func (d *DB) Prepare(query string) (*sql.Stmt, error) { // There's no need to manually close the returned statement; it operates within the transaction scope and will be closed // automatically upon the completion of the transaction, whether it's committed or rolled back. func (d *DB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { - if !d.statementsCaching { + if !d.cachingEnabled { return d.DB.PrepareContext(ctx, query) } @@ -157,8 +172,8 @@ func (d *DB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error return prepStmt, nil } -// Clear will close all cached prepared statements -func (d *DB) Clear() error { +// ClearCache will close all cached prepared statements and clear statements cache map +func (d *DB) ClearCache() error { d.lock.Lock() defer d.lock.Unlock() @@ -168,15 +183,23 @@ func (d *DB) Clear() error { closeErr := statement.Close() if closeErr != nil { - err = closeErr + err = errors.Join(err, closeErr) } } d.statements = make(map[string]*sql.Stmt) if err != nil { - return fmt.Errorf("some of the prepared statements failed to close, last err: %w", err) + return errors.Join(errors.New("jet: some of the prepared statements failed to close"), err) } return nil } + +// Close will clear the statements cache and close the underlying db connection +func (d *DB) Close() error { + clearErr := d.ClearCache() + closeErr := d.DB.Close() + + return errors.Join(clearErr, closeErr) +} diff --git a/internal/jet/db/tx.go b/stmtcache/tx.go similarity index 93% rename from internal/jet/db/tx.go rename to stmtcache/tx.go index b64b231..c02fb6b 100644 --- a/internal/jet/db/tx.go +++ b/stmtcache/tx.go @@ -1,4 +1,4 @@ -package db +package stmtcache import ( "context" @@ -7,6 +7,7 @@ import ( ) // Tx is a wrapper around *sql.Tx, adding prepared statement caching capability. +// Tx is not thread safe and should not be shared between goroutines. type Tx struct { *sql.Tx @@ -24,7 +25,7 @@ func (t *Tx) Exec(query string, args ...interface{}) (sql.Result, error) { // first call PrepareContext to retrieve a prepared statement, and then execute a query using a prepared statement. // If statement caching is disabled, this method delegates the call to the *sql.Tx ExecContext method. func (t *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { - if !t.db.statementsCaching { + if !t.db.cachingEnabled { return t.Tx.ExecContext(ctx, query, args...) } @@ -46,7 +47,7 @@ func (t *Tx) Query(query string, args ...interface{}) (*sql.Rows, error) { // first call PrepareContext to retrieve a prepared statement, and then execute a query using a prepared statement. // If statement caching is disabled, this method delegates the call to the *sql.Tx QueryContext method. func (t *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { - if !t.db.statementsCaching { + if !t.db.cachingEnabled { return t.Tx.QueryContext(ctx, query, args...) } @@ -73,8 +74,8 @@ func (t *Tx) Prepare(query string) (*sql.Stmt, error) { // There's no need to manually close the returned statement; it operates within the transaction scope and will be closed // automatically upon the completion of the transaction, whether it's committed or rolled back. func (t *Tx) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { - if !t.db.statementsCaching { - return t.PrepareContext(ctx, query) + if !t.db.cachingEnabled { + return t.Tx.PrepareContext(ctx, query) } prepStmt, ok := t.statements[query] diff --git a/tests/mysql/main_test.go b/tests/mysql/main_test.go index ee02a77..e4f4322 100644 --- a/tests/mysql/main_test.go +++ b/tests/mysql/main_test.go @@ -3,8 +3,10 @@ package mysql import ( "context" "database/sql" + "fmt" + "github.com/go-jet/jet/v2/mysql" jetmysql "github.com/go-jet/jet/v2/mysql" - "github.com/go-jet/jet/v2/postgres" + "github.com/go-jet/jet/v2/stmtcache" "github.com/go-jet/jet/v2/tests/dbconfig" _ "github.com/go-sql-driver/mysql" "github.com/stretchr/testify/require" @@ -15,14 +17,16 @@ import ( "testing" ) -var db *jetmysql.DB +var db *stmtcache.DB var source string +var withStatementCaching bool const MariaDB = "MariaDB" func init() { source = os.Getenv("MY_SQL_SOURCE") + withStatementCaching = os.Getenv("JET_TESTS_WITH_STMT_CACHE") == "true" } func sourceIsMariaDB() bool { @@ -32,21 +36,38 @@ func sourceIsMariaDB() bool { func TestMain(m *testing.M) { defer profile.Start().Stop() - var err error - sqlDB, err := sql.Open("mysql", dbconfig.MySQLConnectionString(sourceIsMariaDB(), "")) - if err != nil { - panic("Failed to connect to test db" + err.Error()) - } + func() { + fmt.Printf("\nRunning mysql tests caching enabled: %t \n", withStatementCaching) - db = jetmysql.NewDB(sqlDB).WithStatementsCaching(true) - defer db.Close() - - for i := 0; i < 2; i++ { - ret := m.Run() - if ret != 0 { - os.Exit(ret) + sqlDB, err := sql.Open("mysql", dbconfig.MySQLConnectionString(sourceIsMariaDB(), "")) + if err != nil { + panic("Failed to connect to test db" + err.Error()) } + + db = stmtcache.New(sqlDB).SetCaching(withStatementCaching) + defer db.Close() + + for i := 0; i < runCount(withStatementCaching); i++ { + ret := m.Run() + if ret != 0 { + fmt.Printf("\nFAIL: Running mysql tests failed, caching enabled: %t \n", withStatementCaching) + os.Exit(ret) + } + } + }() + +} + +func getConnectionString() string { + return dbconfig.MySQLConnectionString(sourceIsMariaDB(), "") +} + +func runCount(stmtCaching bool) int { + if stmtCaching { + return 3 } + + return 1 } var loggedSQL string @@ -70,14 +91,14 @@ func init() { }) } -func requireLogged(t *testing.T, statement postgres.Statement) { +func requireLogged(t *testing.T, statement mysql.Statement) { query, args := statement.Sql() require.Equal(t, loggedSQL, query) require.Equal(t, loggedSQLArgs, args) require.Equal(t, loggedDebugSQL, statement.DebugSql()) } -func requireQueryLogged(t *testing.T, statement postgres.Statement, rowsProcessed int64) { +func requireQueryLogged(t *testing.T, statement mysql.Statement, rowsProcessed int64) { query, args := statement.Sql() queryLogged, argsLogged := queryInfo.Statement.Sql() diff --git a/tests/mysql/stmtcache_test.go b/tests/mysql/stmtcache_test.go new file mode 100644 index 0000000..87774c1 --- /dev/null +++ b/tests/mysql/stmtcache_test.go @@ -0,0 +1,128 @@ +package mysql + +import ( + "context" + "database/sql" + "github.com/go-jet/jet/v2/internal/testutils" + . "github.com/go-jet/jet/v2/mysql" + "github.com/go-jet/jet/v2/stmtcache" + "github.com/go-jet/jet/v2/tests/.gentestdata/mysql/dvds/model" + . "github.com/go-jet/jet/v2/tests/.gentestdata/mysql/dvds/table" + "github.com/stretchr/testify/require" + "testing" +) + +func TestPreparedStatementCache(t *testing.T) { + sqlDB, err := sql.Open("mysql", getConnectionString()) + require.NoError(t, err) + stmtCachedDB := stmtcache.New(sqlDB) + defer func(db *stmtcache.DB) { + err := db.Close() + require.NoError(t, err) + require.Equal(t, db.CacheSize(), 0) + }(stmtCachedDB) + + require.True(t, stmtCachedDB.CachingEnabled()) + require.Equal(t, stmtCachedDB.CacheSize(), 0) + + testStatementCaching := func(cachingEnabled bool) { + + stmtCachedDB.SetCaching(cachingEnabled) + require.Equal(t, stmtCachedDB.CachingEnabled(), cachingEnabled) + + ctx := context.TODO() + + stmt := SELECT(Actor.AllColumns). + FROM(Actor). + WHERE(Actor.ActorID.BETWEEN(Int(1), Int(10))) + + query, args := stmt.Sql() + + preStmt, err := stmtCachedDB.Prepare(query) + require.NoError(t, err) + + preStmt2, err := stmtCachedDB.PrepareContext(ctx, query) + require.NoError(t, err) + require.Equal(t, preStmt == preStmt2, cachingEnabled) + + t.Run("Exec", func(t *testing.T) { + testutils.AssertExec(t, stmt, stmtCachedDB) + testutils.AssertExecContext(t, stmt, ctx, stmtCachedDB) + _, err := stmtCachedDB.Exec(query, args...) + require.NoError(t, err) + }) + + t.Run("Query", func(t *testing.T) { + var dest []model.Actor + + err := stmt.Query(stmtCachedDB, &dest) + require.NoError(t, err) + require.Len(t, dest, 10) + rows, err := stmtCachedDB.Query(query, args...) + rows.Close() + require.NoError(t, err) + + t.Run("ctx", func(t *testing.T) { + var dest []model.Actor + err := stmt.QueryContext(ctx, stmtCachedDB, &dest) + require.NoError(t, err) + require.Len(t, dest, 10) + }) + + }) + + t.Run("tx", func(t *testing.T) { + tx, err := stmtCachedDB.Begin() + require.NoError(t, err) + preStmtTx, err := tx.Prepare(query) + require.NoError(t, err) + _, err = preStmtTx.Exec(args...) + require.NoError(t, err) + preStmtTx2, err := tx.PrepareContext(ctx, query) + require.NoError(t, err) + require.Equal(t, preStmtTx == preStmtTx2, cachingEnabled) + _, err = preStmtTx2.ExecContext(ctx, args...) + require.NoError(t, err) + + t.Run("Exec", func(t *testing.T) { + testutils.AssertExec(t, stmt, tx) + testutils.AssertExecContext(t, stmt, ctx, tx) + + _, err := tx.Exec(query, args...) + require.NoError(t, err) + }) + + t.Run("Query", func(t *testing.T) { + var dest []model.Actor + err = stmt.QueryContext(ctx, tx, &dest) + require.NoError(t, err) + require.Len(t, dest, 10) + + rows, err := tx.Query(query, args...) + require.NoError(t, err) + require.NoError(t, rows.Close()) + }) + + t.Run("new tx", func(t *testing.T) { + txCtx, err := stmtCachedDB.BeginTx(ctx, nil) + require.NoError(t, err) + + preStmtTxCtx, err := txCtx.PrepareContext(ctx, query) + require.NoError(t, err) + require.NotEqual(t, preStmtTx, preStmtTxCtx) + + require.NoError(t, txCtx.Rollback()) + }) + + require.NoError(t, tx.Commit()) + }) + } + + testStatementCaching(true) + require.Equal(t, stmtCachedDB.CacheSize(), 1) + testStatementCaching(false) + require.Equal(t, stmtCachedDB.CacheSize(), 1) + + require.NoError(t, stmtCachedDB.ClearCache()) + require.Equal(t, stmtCachedDB.CacheSize(), 0) +} diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go index 6d9c0b7..d6cdfce 100644 --- a/tests/postgres/alltypes_test.go +++ b/tests/postgres/alltypes_test.go @@ -1,7 +1,6 @@ package postgres import ( - "database/sql" "github.com/go-jet/jet/v2/internal/utils/ptr" "github.com/stretchr/testify/assert" @@ -944,7 +943,7 @@ RETURNING employee.employee_id AS "employee.employee_id", employee.manager_id AS "employee.manager_id", employee.pto_accrual AS "employee.pto_accrual"; ` - testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + testutils.ExecuteInTxAndRollback(t, db, func(tx qrm.DB) { var windy model.Employee windy.PtoAccrual = ptr.Of("3h") stmt := Employee.UPDATE(Employee.PtoAccrual).SET( @@ -972,7 +971,7 @@ RETURNING employee.employee_id AS "employee.employee_id", employee.manager_id AS "employee.manager_id", employee.pto_accrual AS "employee.pto_accrual"; ` - testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + testutils.ExecuteInTxAndRollback(t, db, func(tx qrm.DB) { var employee model.Employee employee.PtoAccrual = ptr.Of("5h") stmt := Employee.INSERT(Employee.AllColumns). diff --git a/tests/postgres/main_test.go b/tests/postgres/main_test.go index 1f02e69..46dfd94 100644 --- a/tests/postgres/main_test.go +++ b/tests/postgres/main_test.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "fmt" + "github.com/go-jet/jet/v2/stmtcache" "github.com/go-jet/jet/v2/tests/internal/utils/repo" "github.com/jackc/pgx/v4/stdlib" "os" @@ -19,15 +20,17 @@ import ( _ "github.com/jackc/pgx/v4/stdlib" ) -var db *postgres.DB +var db *stmtcache.DB var testRoot string var source string +var withStatementCaching bool const CockroachDB = "COCKROACH_DB" func init() { source = os.Getenv("PG_SOURCE") + withStatementCaching = os.Getenv("JET_TESTS_WITH_STMT_CACHE") == "true" } func sourceIsCockroachDB() bool { @@ -45,39 +48,50 @@ func TestMain(m *testing.M) { setTestRoot() - for _, driverName := range []string{"pgx", "postgres"} { - fmt.Printf("\nRunning postgres tests for '%s' driver\n", driverName) + for _, driverName := range []string{"postgres", "pgx"} { + + fmt.Printf("\nRunning postgres tests for driver: %s, caching enabled: %t \n", driverName, withStatementCaching) func() { - - connectionString := dbconfig.PostgresConnectString - - if sourceIsCockroachDB() { - connectionString = dbconfig.CockroachConnectString - } - - sqlDB, err := sql.Open(driverName, connectionString) + sqlDB, err := sql.Open(driverName, getConnectionString()) if err != nil { fmt.Println(err.Error()) panic("Failed to connect to test db") } - db = postgres.NewDB(sqlDB).WithStatementsCaching(true) - defer db.Close() + db = stmtcache.New(sqlDB).SetCaching(withStatementCaching) + defer func(db *stmtcache.DB) { + err := db.Close() + if err != nil { + fmt.Printf("ERROR: Failed to close db connection, %v", err) + } + }(db) - for i := 0; i < 2; i++ { + for i := 0; i < runCount(withStatementCaching); i++ { ret := m.Run() if ret != 0 { + fmt.Printf("\nFAIL: Running postgres tests failed for driver: %s, caching enabled: %t \n", driverName, withStatementCaching) os.Exit(ret) } } - - err = db.Clear() - - if err != nil { - os.Exit(-2) - } }() } + +} + +func runCount(stmtCaching bool) int { + if stmtCaching { + return 2 + } + + return 1 +} + +func getConnectionString() string { + if sourceIsCockroachDB() { + return dbconfig.CockroachConnectString + } + + return dbconfig.PostgresConnectString } func setTestRoot() { diff --git a/tests/postgres/sample_test.go b/tests/postgres/sample_test.go index 73b652c..7ac0506 100644 --- a/tests/postgres/sample_test.go +++ b/tests/postgres/sample_test.go @@ -1,8 +1,8 @@ package postgres import ( - "github.com/go-jet/jet/v2/qrm" "github.com/go-jet/jet/v2/internal/utils/ptr" + "github.com/go-jet/jet/v2/qrm" "github.com/google/uuid" "testing" diff --git a/tests/postgres/stmtcache_test.go b/tests/postgres/stmtcache_test.go new file mode 100644 index 0000000..99f12a2 --- /dev/null +++ b/tests/postgres/stmtcache_test.go @@ -0,0 +1,139 @@ +package postgres + +import ( + "context" + "database/sql" + "github.com/go-jet/jet/v2/internal/testutils" + . "github.com/go-jet/jet/v2/postgres" + "github.com/go-jet/jet/v2/stmtcache" + "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/model" + . "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/table" + "github.com/stretchr/testify/require" + "testing" +) + +func TestPreparedStatementCache(t *testing.T) { + sqlDB, err := sql.Open("postgres", getConnectionString()) + require.NoError(t, err) + stmtCachedDB := stmtcache.New(sqlDB) + defer func(db *stmtcache.DB) { + err := db.Close() + require.NoError(t, err) + require.Equal(t, db.CacheSize(), 0) + }(stmtCachedDB) + ctx := context.TODO() + + require.True(t, stmtCachedDB.CachingEnabled()) + require.Equal(t, stmtCachedDB.CacheSize(), 0) + + testStatementCaching := func(cachingEnabled bool) { + + stmtCachedDB.SetCaching(cachingEnabled) + require.Equal(t, stmtCachedDB.CachingEnabled(), cachingEnabled) + + stmt := Actor.UPDATE(). + SET(Actor.LastName.SET(Actor.LastName)). + WHERE(Actor.ActorID.BETWEEN(Int(1), Int(10))). + RETURNING(Actor.AllColumns) + + query, args := stmt.Sql() + + preStmt, err := stmtCachedDB.Prepare(query) + require.NoError(t, err) + + preStmt2, err := stmtCachedDB.PrepareContext(ctx, query) + require.NoError(t, err) + require.Equal(t, preStmt == preStmt2, cachingEnabled) + + t.Run("Exec", func(t *testing.T) { + testutils.AssertExec(t, stmt, stmtCachedDB, 10) + testutils.AssertExecContext(t, stmt, ctx, stmtCachedDB, 10) + _, err := stmtCachedDB.Exec(query, args...) + require.NoError(t, err) + }) + + t.Run("Query", func(t *testing.T) { + var dest []model.Actor + + err := stmt.Query(stmtCachedDB, &dest) + require.NoError(t, err) + require.Len(t, dest, 10) + rows, err := stmtCachedDB.Query(query, args...) + rows.Close() + require.NoError(t, err) + + t.Run("ctx", func(t *testing.T) { + var dest []model.Actor + err := stmt.QueryContext(ctx, stmtCachedDB, &dest) + require.NoError(t, err) + require.Len(t, dest, 10) + }) + + }) + + t.Run("tx", func(t *testing.T) { + tx, err := stmtCachedDB.Begin() + require.NoError(t, err) + preStmtTx, err := tx.Prepare(query) + require.NoError(t, err) + _, err = preStmtTx.Exec(args...) + require.NoError(t, err) + preStmtTx2, err := tx.PrepareContext(ctx, query) + require.NoError(t, err) + require.Equal(t, preStmtTx == preStmtTx2, cachingEnabled) + _, err = preStmtTx2.ExecContext(ctx, args...) + require.NoError(t, err) + + t.Run("Exec", func(t *testing.T) { + testutils.AssertExec(t, stmt, tx, 10) + testutils.AssertExecContext(t, stmt, ctx, tx, 10) + + _, err := tx.Exec(query, args...) + require.NoError(t, err) + }) + + t.Run("Query", func(t *testing.T) { + var dest []model.Actor + err = stmt.QueryContext(ctx, tx, &dest) + require.NoError(t, err) + require.Len(t, dest, 10) + + rows, err := tx.Query(query, args...) + require.NoError(t, err) + require.NoError(t, rows.Close()) + }) + + t.Run("new tx", func(t *testing.T) { + txCtx, err := stmtCachedDB.BeginTx(ctx, nil) + require.NoError(t, err) + + preStmtTxCtx, err := txCtx.PrepareContext(ctx, query) + require.NoError(t, err) + require.NotEqual(t, preStmtTx, preStmtTxCtx) + + require.NoError(t, txCtx.Rollback()) + }) + + require.NoError(t, tx.Commit()) + }) + + // second prepared statement + stmt2 := SELECT(Actor.AllColumns). + FROM(Actor). + WHERE(Actor.ActorID.EQ(Int(11))) + + var actor model.Actor + + err = stmt2.Query(stmtCachedDB, &actor) + require.NoError(t, err) + } + + testStatementCaching(true) + require.Equal(t, stmtCachedDB.CacheSize(), 2) + testStatementCaching(false) + require.Equal(t, stmtCachedDB.CacheSize(), 2) + + // clear all + require.NoError(t, stmtCachedDB.ClearCache()) + require.Equal(t, stmtCachedDB.CacheSize(), 0) +} diff --git a/tests/postgres/values_test.go b/tests/postgres/values_test.go index 9e89f7e..ce26276 100644 --- a/tests/postgres/values_test.go +++ b/tests/postgres/values_test.go @@ -1,9 +1,9 @@ package postgres import ( - "database/sql" "github.com/go-jet/jet/v2/internal/testutils" . "github.com/go-jet/jet/v2/postgres" + "github.com/go-jet/jet/v2/qrm" "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/model" . "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/table" "github.com/stretchr/testify/assert" @@ -251,7 +251,7 @@ RETURNING payment.payment_id AS "payment.payment_id", payment.payment_date AS "payment.payment_date"; `) - testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + testutils.ExecuteInTxAndRollback(t, db, func(tx qrm.DB) { var payments []model.Payment diff --git a/tests/sqlite/delete_test.go b/tests/sqlite/delete_test.go index 3c8c0c4..f700902 100644 --- a/tests/sqlite/delete_test.go +++ b/tests/sqlite/delete_test.go @@ -69,7 +69,7 @@ func TestDeleteContextDeadlineExceeded(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 1*time.Microsecond) defer cancel() - time.Sleep(10 * time.Millisecond) + time.Sleep(20 * time.Millisecond) testutils.ExecuteInTxAndRollback(t, sampleDB, func(tx qrm.DB) { var dest []model.Link diff --git a/tests/sqlite/insert_test.go b/tests/sqlite/insert_test.go index c504d90..0cd7804 100644 --- a/tests/sqlite/insert_test.go +++ b/tests/sqlite/insert_test.go @@ -2,9 +2,8 @@ package sqlite import ( "context" - "github.com/go-jet/jet/v2/qrm" - "database/sql" "github.com/go-jet/jet/v2/internal/utils/ptr" + "github.com/go-jet/jet/v2/qrm" "math/rand" "testing" diff --git a/tests/sqlite/main_test.go b/tests/sqlite/main_test.go index 593d630..a113c14 100644 --- a/tests/sqlite/main_test.go +++ b/tests/sqlite/main_test.go @@ -5,8 +5,8 @@ import ( "database/sql" "fmt" "github.com/go-jet/jet/v2/internal/utils/throw" - "github.com/go-jet/jet/v2/postgres" "github.com/go-jet/jet/v2/sqlite" + "github.com/go-jet/jet/v2/stmtcache" "github.com/go-jet/jet/v2/tests/dbconfig" "github.com/pkg/profile" "github.com/stretchr/testify/require" @@ -17,38 +17,52 @@ import ( _ "github.com/mattn/go-sqlite3" ) -var db *sqlite.DB -var sampleDB *sqlite.DB -var testRoot string +var db *stmtcache.DB +var sampleDB *stmtcache.DB + +var withStatementCaching bool + +func init() { + withStatementCaching = os.Getenv("JET_TESTS_WITH_STMT_CACHE") == "true" +} func TestMain(m *testing.M) { defer profile.Start().Stop() - sqlDB, err := sql.Open("sqlite3", "file:"+dbconfig.SakilaDBPath) - throw.OnError(err) - db = sqlite.NewDB(sqlDB).WithStatementsCaching(true) - defer db.Close() + func() { + fmt.Printf("\nRunning sqlite tests caching enabled: %t \n", withStatementCaching) - _, err = db.Exec(fmt.Sprintf("ATTACH DATABASE '%s' as 'chinook';", dbconfig.ChinookDBPath)) - throw.OnError(err) + sqlDB, err := sql.Open("sqlite3", "file:"+dbconfig.SakilaDBPath) + throw.OnError(err) + db = stmtcache.New(sqlDB).SetCaching(withStatementCaching) + defer db.Close() - sqlSampleDB, err := sql.Open("sqlite3", dbconfig.TestSampleDBPath) - throw.OnError(err) - sampleDB = sqlite.NewDB(sqlSampleDB).WithStatementsCaching(true) - defer sampleDB.Close() + _, err = db.Exec(fmt.Sprintf("ATTACH DATABASE '%s' as 'chinook';", dbconfig.ChinookDBPath)) + throw.OnError(err) - for i := 0; i < 2; i++ { - ret := m.Run() - if ret != 0 { - os.Exit(ret) + sqlSampleDB, err := sql.Open("sqlite3", dbconfig.TestSampleDBPath) + throw.OnError(err) + sampleDB = stmtcache.New(sqlSampleDB).SetCaching(withStatementCaching) + defer sampleDB.Close() + + for i := 0; i < runCount(withStatementCaching); i++ { + ret := m.Run() + if ret != 0 { + fmt.Printf("\nFAIL: Running sqlite tests failed, caching enabled: %t \n", withStatementCaching) + os.Exit(ret) + } } + + }() + +} + +func runCount(stmtCaching bool) int { + if stmtCaching { + return 4 } - err = sampleDB.Clear() - - if err != nil { - panic(err) - } + return 1 } var loggedSQL string @@ -72,7 +86,7 @@ func init() { }) } -func requireQueryLogged(t *testing.T, statement postgres.Statement, rowsProcessed int64) { +func requireQueryLogged(t *testing.T, statement sqlite.Statement, rowsProcessed int64) { query, args := statement.Sql() queryLogged, argsLogged := queryInfo.Statement.Sql() @@ -94,13 +108,13 @@ func requireLogged(t *testing.T, statement sqlite.Statement) { require.Equal(t, loggedDebugSQL, statement.DebugSql()) } -func beginSampleDBTx(t *testing.T) *sqlite.Tx { +func beginSampleDBTx(t *testing.T) *stmtcache.Tx { tx, err := sampleDB.BeginTx(context.Background(), nil) require.NoError(t, err) return tx } -func beginDBTx(t *testing.T) *sqlite.Tx { +func beginDBTx(t *testing.T) *stmtcache.Tx { tx, err := db.Begin() require.NoError(t, err) return tx diff --git a/tests/sqlite/stmtcache_test.go b/tests/sqlite/stmtcache_test.go new file mode 100644 index 0000000..be8cb91 --- /dev/null +++ b/tests/sqlite/stmtcache_test.go @@ -0,0 +1,131 @@ +package sqlite + +import ( + "context" + "database/sql" + "github.com/go-jet/jet/v2/internal/testutils" + . "github.com/go-jet/jet/v2/sqlite" + "github.com/go-jet/jet/v2/stmtcache" + "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/sakila/model" + . "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/sakila/table" + "github.com/go-jet/jet/v2/tests/dbconfig" + "github.com/stretchr/testify/require" + "testing" +) + +func TestPreparedStatementCache(t *testing.T) { + sqlDB, err := sql.Open("sqlite3", "file:"+dbconfig.SakilaDBPath) + require.NoError(t, err) + stmtCachedDB := stmtcache.New(sqlDB) + defer func(db *stmtcache.DB) { + err := db.Close() + require.NoError(t, err) + require.Equal(t, db.CacheSize(), 0) + }(stmtCachedDB) + + require.True(t, stmtCachedDB.CachingEnabled()) + require.Equal(t, stmtCachedDB.CacheSize(), 0) + + testStatementCaching := func(cachingEnabled bool) { + + stmtCachedDB.SetCaching(cachingEnabled) + require.Equal(t, stmtCachedDB.CachingEnabled(), cachingEnabled) + + ctx := context.TODO() + + stmt := SELECT(Actor.AllColumns). + FROM(Actor). + WHERE(Actor.ActorID.BETWEEN(Int(1), Int(10))) + + query, args := stmt.Sql() + + preStmt, err := stmtCachedDB.Prepare(query) + require.NoError(t, err) + + preStmt2, err := stmtCachedDB.PrepareContext(ctx, query) + require.NoError(t, err) + require.Equal(t, preStmt == preStmt2, cachingEnabled) + + t.Run("Exec", func(t *testing.T) { + testutils.AssertExec(t, stmt, stmtCachedDB) + testutils.AssertExecContext(t, stmt, ctx, stmtCachedDB) + _, err := stmtCachedDB.Exec(query, args...) + require.NoError(t, err) + }) + + t.Run("Query", func(t *testing.T) { + var dest []model.Actor + + err := stmt.Query(stmtCachedDB, &dest) + require.NoError(t, err) + require.Len(t, dest, 10) + rows, err := stmtCachedDB.Query(query, args...) + rows.Close() + require.NoError(t, err) + + t.Run("ctx", func(t *testing.T) { + var dest []model.Actor + err := stmt.QueryContext(ctx, stmtCachedDB, &dest) + require.NoError(t, err) + require.Len(t, dest, 10) + }) + + }) + + t.Run("tx", func(t *testing.T) { + tx, err := stmtCachedDB.Begin() + require.NoError(t, err) + preStmtTx, err := tx.Prepare(query) + require.NoError(t, err) + _, err = preStmtTx.Exec(args...) + require.NoError(t, err) + preStmtTx2, err := tx.PrepareContext(ctx, query) + require.NoError(t, err) + require.Equal(t, preStmtTx == preStmtTx2, cachingEnabled) + _, err = preStmtTx2.ExecContext(ctx, args...) + require.NoError(t, err) + + t.Run("Exec", func(t *testing.T) { + testutils.AssertExec(t, stmt, tx) + testutils.AssertExecContext(t, stmt, ctx, tx) + + _, err := tx.Exec(query, args...) + require.NoError(t, err) + }) + + t.Run("Query", func(t *testing.T) { + var dest []model.Actor + err = stmt.QueryContext(ctx, tx, &dest) + require.NoError(t, err) + require.Len(t, dest, 10) + + rows, err := tx.Query(query, args...) + require.NoError(t, err) + require.NoError(t, rows.Close()) + }) + + t.Run("new tx", func(t *testing.T) { + txCtx, err := stmtCachedDB.BeginTx(ctx, nil) + require.NoError(t, err) + + preStmtTxCtx, err := txCtx.PrepareContext(ctx, query) + require.NoError(t, err) + require.NotEqual(t, preStmtTx, preStmtTxCtx) + + require.NoError(t, txCtx.Rollback()) + }) + + require.NoError(t, preStmtTx.Close()) + require.NoError(t, preStmtTx2.Close()) + require.NoError(t, tx.Commit()) + }) + } + + testStatementCaching(true) + require.Equal(t, stmtCachedDB.CacheSize(), 1) + testStatementCaching(false) + require.Equal(t, stmtCachedDB.CacheSize(), 1) + + require.NoError(t, stmtCachedDB.ClearCache()) + require.Equal(t, stmtCachedDB.CacheSize(), 0) +} diff --git a/tests/sqlite/values_test.go b/tests/sqlite/values_test.go index 0793397..948d211 100644 --- a/tests/sqlite/values_test.go +++ b/tests/sqlite/values_test.go @@ -1,8 +1,8 @@ package sqlite import ( - "database/sql" "github.com/go-jet/jet/v2/internal/testutils" + "github.com/go-jet/jet/v2/qrm" "github.com/stretchr/testify/require" "strings" "testing" @@ -293,7 +293,7 @@ RETURNING payment.payment_id AS "payment.payment_id", payment.last_update AS "payment.last_update"; `, "''", "`")) - testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + testutils.ExecuteInTxAndRollback(t, db, func(tx qrm.DB) { var payments []model.Payment err := stmt.Query(tx, &payments)