From bc776f947b70e0a5037c24be59ebf3865b3d4de4 Mon Sep 17 00:00:00 2001 From: go-jet Date: Thu, 5 May 2022 13:01:42 +0200 Subject: [PATCH] Add support for CockorachDB. --- .circleci/config.yml | 23 ++- generator/postgres/query_set.go | 4 +- internal/testutils/test_utils.go | 33 ++++ postgres/cast_test.go | 2 +- postgres/dialect_test.go | 12 +- postgres/expressions_test.go | 2 +- postgres/insert_statement_test.go | 4 +- postgres/literal.go | 4 +- postgres/literal_test.go | 2 +- tests/Makefile | 7 + tests/dbconfig/dbconfig.go | 19 +- tests/docker-compose.yaml | 13 ++ tests/init/init.go | 97 ++++++---- tests/mysql/alltypes_test.go | 4 +- tests/mysql/delete_test.go | 38 +--- tests/mysql/insert_test.go | 257 ++++++++++++-------------- tests/mysql/main_test.go | 6 - tests/mysql/update_test.go | 116 +++++------- tests/postgres/alltypes_test.go | 183 ++++++++++++++---- tests/postgres/chinook_db_test.go | 82 ++++---- tests/postgres/delete_test.go | 106 +++++------ tests/postgres/generator_test.go | 2 + tests/postgres/insert_test.go | 244 +++++++++++------------- tests/postgres/lock_test.go | 4 + tests/postgres/main_test.go | 33 +++- tests/postgres/raw_statements_test.go | 45 ++--- tests/postgres/sample_test.go | 98 ++-------- tests/postgres/select_test.go | 110 ++++++++--- tests/postgres/update_test.go | 221 +++++++++------------- tests/postgres/util_test.go | 57 ------ tests/postgres/with_test.go | 26 ++- tests/sqlite/insert_test.go | 220 ++++++++++------------ tests/sqlite/main_test.go | 3 +- 33 files changed, 1040 insertions(+), 1037 deletions(-) delete mode 100644 tests/postgres/util_test.go diff --git a/.circleci/config.yml b/.circleci/config.yml index 1571843..f83283c 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -33,6 +33,13 @@ jobs: MYSQL_USER: jet MYSQL_PASSWORD: jet + - image: cockroachdb/cockroach-unstable:v22.1.0-beta.4 + command: ['start-single-node', '--insecure'] + environment: + COCKROACH_USER: jet + COCKROACH_PASSWORD: jet + COCKROACH_DATABASE: jetdb + environment: # environment variables for the build itself TEST_RESULTS: /tmp/test-results # path to where test results will be saved @@ -82,7 +89,18 @@ jobs: echo -n . sleep 1 done - echo Failed waiting for MySQL && exit 1 + echo Failed waiting for MySQL && exit 1 + + - run: + name: Waiting for Cockroach to be ready + command: | + for i in `seq 1 10`; + do + nc -z localhost 26257 && echo Success && exit 0 + echo -n . + sleep 1 + done + echo Failed waiting for Cockroach && exit 1 - run: name: Install MySQL CLI; @@ -122,8 +140,9 @@ jobs: -coverpkg=github.com/go-jet/jet/v2/postgres/...,github.com/go-jet/jet/v2/mysql/...,github.com/go-jet/jet/v2/sqlite/...,github.com/go-jet/jet/v2/qrm/...,github.com/go-jet/jet/v2/generator/...,github.com/go-jet/jet/v2/internal/... \ -coverprofile=cover.out 2>&1 | go-junit-report > $TEST_RESULTS/results.xml - # run mariaDB tests. No need to collect coverage, because coverage is already included with mysql tests + # run mariaDB and cockroachdb tests. No need to collect coverage, because coverage is already included with mysql and postgres tests - run: MY_SQL_SOURCE=MariaDB go test -v ./tests/mysql/ + - run: PG_SOURCE=COCKROACH_DB go test -v ./tests/postgres/ - save_cache: key: go-mod-v4-{{ checksum "go.sum" }} diff --git a/generator/postgres/query_set.go b/generator/postgres/query_set.go index 93e6ffb..da48505 100644 --- a/generator/postgres/query_set.go +++ b/generator/postgres/query_set.go @@ -35,7 +35,9 @@ WITH primaryKeys AS ( SELECT column_name FROM information_schema.key_column_usage AS c LEFT JOIN information_schema.table_constraints AS t - ON t.constraint_name = c.constraint_name + ON t.constraint_name = c.constraint_name AND + c.table_schema = t.table_schema AND + c.table_name = t.table_name WHERE t.table_schema = $1 AND t.table_name = $2 AND t.constraint_type = 'PRIMARY KEY' ) SELECT column_name as "column.Name", diff --git a/internal/testutils/test_utils.go b/internal/testutils/test_utils.go index 37c1665..7e6e21a 100644 --- a/internal/testutils/test_utils.go +++ b/internal/testutils/test_utils.go @@ -2,6 +2,8 @@ package testutils import ( "bytes" + "context" + "database/sql" "encoding/json" "fmt" "github.com/go-jet/jet/v2/internal/jet" @@ -25,6 +27,18 @@ var UnixTimeComparer = cmp.Comparer(func(t1, t2 time.Time) bool { return t1.Unix() == t2.Unix() }) +// AssertExecAndRollback will execute and rollback statement in sql transaction +func AssertExecAndRollback(t *testing.T, stmt jet.Statement, db *sql.DB, rowsAffected ...int64) { + tx, err := db.Begin() + require.NoError(t, err) + defer func() { + err := tx.Rollback() + require.NoError(t, err) + }() + + AssertExec(t, stmt, tx, rowsAffected...) +} + // AssertExec assert statement execution for successful execution and number of rows affected func AssertExec(t *testing.T, stmt jet.Statement, db qrm.DB, rowsAffected ...int64) { res, err := stmt.Exec(db) @@ -38,6 +52,18 @@ func AssertExec(t *testing.T, stmt jet.Statement, db qrm.DB, rowsAffected ...int } } +// ExecuteInTxAndRollback will execute function in sql transaction and then rollback transaction +func ExecuteInTxAndRollback(t *testing.T, db *sql.DB, f func(tx *sql.Tx)) { + tx, err := db.Begin() + require.NoError(t, err) + defer func() { + err := tx.Rollback() + require.NoError(t, err) + }() + + f(tx) +} + // AssertExecErr assert statement execution for failed execution with error string errorStr func AssertExecErr(t *testing.T, stmt jet.Statement, db qrm.DB, errorStr string) { _, err := stmt.Exec(db) @@ -45,6 +71,13 @@ func AssertExecErr(t *testing.T, stmt jet.Statement, db qrm.DB, errorStr string) require.Error(t, err, errorStr) } +// AssertExecContextErr assert statement execution for failed execution with error string errorStr +func AssertExecContextErr(t *testing.T, stmt jet.Statement, ctx context.Context, db qrm.DB, errorStr string) { + _, err := stmt.ExecContext(ctx, db) + + require.Error(t, err, errorStr) +} + func getFullPath(relativePath string) string { path, _ := os.Getwd() return filepath.Join(path, "../", relativePath) diff --git a/postgres/cast_test.go b/postgres/cast_test.go index e02336a..c0586d3 100644 --- a/postgres/cast_test.go +++ b/postgres/cast_test.go @@ -5,7 +5,7 @@ import ( ) func TestExpressionCAST_AS(t *testing.T) { - assertSerialize(t, CAST(String("test")).AS("text"), `$1::text`, "test") + assertSerialize(t, CAST(Int(11)).AS("text"), `$1::text`, int64(11)) } func TestExpressionCAST_AS_BOOL(t *testing.T) { diff --git a/postgres/dialect_test.go b/postgres/dialect_test.go index 45ed739..9aadbc9 100644 --- a/postgres/dialect_test.go +++ b/postgres/dialect_test.go @@ -4,16 +4,16 @@ import "testing" func TestString_REGEXP_LIKE_operator(t *testing.T) { assertSerialize(t, table3StrCol.REGEXP_LIKE(table2ColStr), "(table3.col2 ~* table2.col_str)") - assertSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN")), "(table3.col2 ~* $1)", "JOHN") - assertSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN"), false), "(table3.col2 ~* $1)", "JOHN") - assertSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN"), true), "(table3.col2 ~ $1)", "JOHN") + assertSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN")), "(table3.col2 ~* $1::text)", "JOHN") + assertSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN"), false), "(table3.col2 ~* $1::text)", "JOHN") + assertSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN"), true), "(table3.col2 ~ $1::text)", "JOHN") } func TestString_NOT_REGEXP_LIKE_operator(t *testing.T) { assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(table2ColStr), "(table3.col2 !~* table2.col_str)") - assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN")), "(table3.col2 !~* $1)", "JOHN") - assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN"), false), "(table3.col2 !~* $1)", "JOHN") - assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN"), true), "(table3.col2 !~ $1)", "JOHN") + assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN")), "(table3.col2 !~* $1::text)", "JOHN") + assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN"), false), "(table3.col2 !~* $1::text)", "JOHN") + assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN"), true), "(table3.col2 !~ $1::text)", "JOHN") } func TestExists(t *testing.T) { diff --git a/postgres/expressions_test.go b/postgres/expressions_test.go index 77c3dee..76403fb 100644 --- a/postgres/expressions_test.go +++ b/postgres/expressions_test.go @@ -60,7 +60,7 @@ func TestRawHelperMethods(t *testing.T) { assertSerialize(t, RawFloat("table.colInt + :float", RawArgs{":float": 11.22}).EQ(Float(3.14)), "((table.colInt + $1) = $2)", 11.22, 3.14) assertSerialize(t, RawString("table.colStr || str", RawArgs{"str": "doe"}).EQ(String("john doe")), - "((table.colStr || $1) = $2)", "doe", "john doe") + "((table.colStr || $1) = $2::text)", "doe", "john doe") now := time.Now() assertSerialize(t, RawTime("table.colTime").EQ(TimeT(now)), diff --git a/postgres/insert_statement_test.go b/postgres/insert_statement_test.go index 3ec333e..25300c2 100644 --- a/postgres/insert_statement_test.go +++ b/postgres/insert_statement_test.go @@ -167,7 +167,7 @@ VALUES ('one', 'two'), ON CONFLICT (col_bool) WHERE col_bool IS NOT FALSE DO UPDATE SET col_bool = TRUE::boolean, col_int = 1, - (col1, col_bool) = ROW(2, 'two') + (col1, col_bool) = ROW(2, 'two'::text) WHERE table1.col1 > 2 RETURNING table1.col1 AS "table1.col1", table1.col_bool AS "table1.col_bool"; @@ -193,7 +193,7 @@ VALUES ('one', 'two'), ON CONFLICT ON CONSTRAINT idk_primary_key DO UPDATE SET col_bool = FALSE::boolean, col_int = 1, - (col1, col_bool) = ROW(2, 'two') + (col1, col_bool) = ROW(2, 'two'::text) WHERE table1.col1 > 2 RETURNING table1.col1 AS "table1.col1", table1.col_bool AS "table1.col_bool"; diff --git a/postgres/literal.go b/postgres/literal.go index e46b874..7fffd9d 100644 --- a/postgres/literal.go +++ b/postgres/literal.go @@ -61,7 +61,9 @@ var Float = jet.Float var Decimal = jet.Decimal // String creates new string literal expression -var String = jet.String +func String(value string) StringExpression { + return CAST(jet.String(value)).AS_TEXT() +} // UUID is a helper function to create string literal expression from uuid object // value can be any uuid type with a String method diff --git a/postgres/literal_test.go b/postgres/literal_test.go index f95e486..5c5160e 100644 --- a/postgres/literal_test.go +++ b/postgres/literal_test.go @@ -59,7 +59,7 @@ func TestFloat(t *testing.T) { } func TestString(t *testing.T) { - assertSerialize(t, String("Some text"), `$1`, "Some text") + assertSerialize(t, String("Some text"), `$1::text`, "Some text") } func TestBytea(t *testing.T) { diff --git a/tests/Makefile b/tests/Makefile index 632c3d2..26d63cd 100644 --- a/tests/Makefile +++ b/tests/Makefile @@ -39,6 +39,7 @@ jet-gen-postgres: jet -dsn=postgres://jet:jet@localhost:50901/jetdb?sslmode=disable -schema=dvds -path=./.gentestdata/ jet -dsn=postgres://jet:jet@localhost:50901/jetdb?sslmode=disable -schema=chinook -path=./.gentestdata/ jet -dsn=postgres://jet:jet@localhost:50901/jetdb?sslmode=disable -schema=chinook2 -path=./.gentestdata/ + jet -dsn=postgres://jet:jet@localhost:50901/jetdb?sslmode=disable -schema=northwind -path=./.gentestdata/ jet -dsn=postgres://jet:jet@localhost:50901/jetdb?sslmode=disable -schema=test_sample -path=./.gentestdata/ jet-gen-mysql: @@ -56,6 +57,12 @@ jet-gen-sqlite: jet -source=sqlite -dsn="./testdata/init/sqlite/sakila.db" -schema=dvds -path=./.gentestdata/sqlite/sakila jet -source=sqlite -dsn="./testdata/init/sqlite/test_sample.db" -schema=dvds -path=./.gentestdata/sqlite/test_sample +jet-gen-cockroach: + jet -dsn=postgres://jet:jet@127.0.0.1:26257/jetdb?sslmode=disable -schema=dvds -path=./.gentestdata/ + jet -dsn=postgres://jet:jet@localhost:26257/jetdb?sslmode=disable -schema=chinook -path=./.gentestdata/ + jet -dsn=postgres://jet:jet@localhost:26257/jetdb?sslmode=disable -schema=chinook2 -path=./.gentestdata/ + jet -dsn=postgres://jet:jet@localhost:26257/jetdb?sslmode=disable -schema=northwind -path=./.gentestdata/ + jet -dsn=postgres://jet:jet@localhost:26257/jetdb?sslmode=disable -schema=test_sample -path=./.gentestdata/ # docker-compose-cleanup will stop and remove test containers, volumes, and images. cleanup: diff --git a/tests/dbconfig/dbconfig.go b/tests/dbconfig/dbconfig.go index bbf73f9..59ff402 100644 --- a/tests/dbconfig/dbconfig.go +++ b/tests/dbconfig/dbconfig.go @@ -15,7 +15,24 @@ const ( ) // PostgresConnectString is PostgreSQL test database connection string -var PostgresConnectString = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", PgHost, PgPort, PgUser, PgPassword, PgDBName) +var PostgresConnectString = pgConnectionString(PgHost, PgPort, PgUser, PgPassword, PgDBName) + +// Postgres test database connection parameters +const ( + CockroachHost = "localhost" + CockroachPort = 26257 + CockroachUser = "jet" + CockroachPassword = "jet" + CockroachDBName = "jetdb" +) + +// CockroachConnectString is Cockroach test database connection string +var CockroachConnectString = pgConnectionString(CockroachHost, CockroachPort, CockroachUser, CockroachPassword, CockroachDBName) + +func pgConnectionString(host string, port int, user, password, dbName string) string { + return fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", + host, port, user, password, dbName) +} // MySQL test database connection parameters const ( diff --git a/tests/docker-compose.yaml b/tests/docker-compose.yaml index 2e913f1..9c562fb 100644 --- a/tests/docker-compose.yaml +++ b/tests/docker-compose.yaml @@ -37,3 +37,16 @@ services: - '50903:3306' volumes: - ./testdata/init/mysql:/docker-entrypoint-initdb.d + + cockroach: + image: cockroachdb/cockroach-unstable:v22.1.0-beta.4 + environment: + - COCKROACH_USER=jet + - COCKROACH_PASSWORD=jet + - COCKROACH_DATABASE=jetdb + ports: + - "26257:26257" + command: start-single-node --insecure +# volumes: +# - ./testdata/init/cockroach:/docker-entrypoint-initdb.d + diff --git a/tests/init/init.go b/tests/init/init.go index c1c842a..5b7273b 100644 --- a/tests/init/init.go +++ b/tests/init/init.go @@ -1,10 +1,12 @@ package main import ( + "context" "database/sql" "flag" "fmt" "github.com/go-jet/jet/v2/generator/mysql" + "github.com/go-jet/jet/v2/generator/postgres" "github.com/go-jet/jet/v2/generator/sqlite" "github.com/go-jet/jet/v2/tests/internal/utils/repo" "io/ioutil" @@ -12,46 +14,52 @@ import ( "os/exec" "strings" - "github.com/go-jet/jet/v2/generator/postgres" "github.com/go-jet/jet/v2/internal/utils/throw" "github.com/go-jet/jet/v2/tests/dbconfig" _ "github.com/go-sql-driver/mysql" - _ "github.com/lib/pq" + _ "github.com/jackc/pgx/v4/stdlib" + _ "github.com/lib/pq" _ "github.com/mattn/go-sqlite3" ) var testSuite string func init() { - flag.StringVar(&testSuite, "testsuite", "all", "Test suite name (postgres or mysql)") - + flag.StringVar(&testSuite, "testsuite", "all", "Test suite name (postgres, mysql, mariadb, cockroach, sqlite or all)") flag.Parse() } +const ( + Postgres = "postgres" + MySql = "mysql" + MariaDB = "mariadb" + Sqlite = "sqlite" + Cockroach = "cockroach" +) + func main() { - testSuite = strings.ToLower(testSuite) - - if testSuite == "postgres" { - initPostgresDB() - return - } - - if testSuite == "mysql" || testSuite == "mariadb" { - initMySQLDB(testSuite == "mariadb") - return - } - - if testSuite == "sqlite" { + switch strings.ToLower(testSuite) { + case Postgres: + initPostgresDB(Postgres, dbconfig.PostgresConnectString) + case Cockroach: + initPostgresDB(Cockroach, dbconfig.CockroachConnectString) + case MySql: + initMySQLDB(false) + case MariaDB: + initMySQLDB(true) + case Sqlite: initSQLiteDB() - return + case "all": + initPostgresDB(Cockroach, dbconfig.CockroachConnectString) + initPostgresDB(Postgres, dbconfig.PostgresConnectString) + initMySQLDB(false) + initMySQLDB(true) + initSQLiteDB() + default: + panic("invalid testsuite flag. Test suite name (postgres, mysql, mariadb, cockroach, sqlite or all)") } - - initPostgresDB() - initMySQLDB(false) - initMySQLDB(true) - initSQLiteDB() } func initSQLiteDB() { @@ -109,8 +117,8 @@ func initMySQLDB(isMariaDB bool) { } } -func initPostgresDB() { - db, err := sql.Open("postgres", dbconfig.PostgresConnectString) +func initPostgresDB(dbType string, connectionString string) { + db, err := sql.Open("postgres", connectionString) if err != nil { panic("Failed to connect to test db: " + err.Error()) } @@ -120,26 +128,19 @@ func initPostgresDB() { }() schemaNames := []string{ + "northwind", "dvds", "test_sample", "chinook", "chinook2", - "northwind", } for _, schemaName := range schemaNames { + fmt.Println("\nInitializing", schemaName, "schema...") - execFile(db, "./testdata/init/postgres/"+schemaName+".sql") + execFile(db, fmt.Sprintf("./testdata/init/%s/%s.sql", dbType, schemaName)) - err = postgres.Generate("./.gentestdata", postgres.DBConnection{ - Host: dbconfig.PgHost, - Port: dbconfig.PgPort, - User: dbconfig.PgUser, - Password: dbconfig.PgPassword, - DBName: dbconfig.PgDBName, - SchemaName: schemaName, - SslMode: "disable", - }) + err = postgres.GenerateDSN(connectionString, schemaName, "./.gentestdata") throw.OnError(err) } } @@ -148,10 +149,32 @@ func execFile(db *sql.DB, sqlFilePath string) { testSampleSql, err := ioutil.ReadFile(sqlFilePath) throw.OnError(err) - _, err = db.Exec(string(testSampleSql)) + err = execInTx(db, func(tx *sql.Tx) error { + _, err := tx.Exec(string(testSampleSql)) + return err + }) throw.OnError(err) } +func execInTx(db *sql.DB, f func(tx *sql.Tx) error) error { + tx, err := db.BeginTx(context.Background(), &sql.TxOptions{ + Isolation: sql.LevelReadUncommitted, // to speed up initialization of test database + }) + + if err != nil { + return err + } + + err = f(tx) + + if err != nil { + tx.Rollback() + return err + } + + return tx.Commit() +} + func printOnError(err error) { if err != nil { fmt.Println(err.Error()) diff --git a/tests/mysql/alltypes_test.go b/tests/mysql/alltypes_test.go index 428a0e6..85268c7 100644 --- a/tests/mysql/alltypes_test.go +++ b/tests/mysql/alltypes_test.go @@ -20,7 +20,7 @@ import ( func TestAllTypes(t *testing.T) { - dest := []model.AllTypes{} + var dest []model.AllTypes err := AllTypes. SELECT(AllTypes.AllColumns). @@ -39,7 +39,7 @@ func TestAllTypesViewSelect(t *testing.T) { type AllTypesView model.AllTypes - dest := []AllTypesView{} + var dest []AllTypesView err := view.AllTypesView.SELECT(view.AllTypesView.AllColumns).Query(db, &dest) require.NoError(t, err) diff --git a/tests/mysql/delete_test.go b/tests/mysql/delete_test.go index 709ce1a..2c92367 100644 --- a/tests/mysql/delete_test.go +++ b/tests/mysql/delete_test.go @@ -13,24 +13,20 @@ import ( ) func TestDeleteWithWhere(t *testing.T) { - initForDeleteTest(t) - - var expectedSQL = ` -DELETE FROM test_sample.link -WHERE link.name IN ('Gmail', 'Outlook'); -` deleteStmt := Link. DELETE(). WHERE(Link.Name.IN(String("Gmail"), String("Outlook"))) - testutils.AssertDebugStatementSql(t, deleteStmt, expectedSQL, "Gmail", "Outlook") - testutils.AssertExec(t, deleteStmt, db, 2) + testutils.AssertDebugStatementSql(t, deleteStmt, ` +DELETE FROM test_sample.link +WHERE link.name IN ('Gmail', 'Outlook'); +`, "Gmail", "Outlook") + + testutils.AssertExecAndRollback(t, deleteStmt, db, 2) requireLogged(t, deleteStmt) } func TestDeleteWithWhereOrderByLimit(t *testing.T) { - initForDeleteTest(t) - var expectedSQL = ` DELETE FROM test_sample.link WHERE link.name IN ('Gmail', 'Outlook') @@ -44,13 +40,11 @@ LIMIT 1; LIMIT(1) testutils.AssertDebugStatementSql(t, deleteStmt, expectedSQL, "Gmail", "Outlook", int64(1)) - testutils.AssertExec(t, deleteStmt, db, 1) + testutils.AssertExecAndRollback(t, deleteStmt, db, 1) requireLogged(t, deleteStmt) } func TestDeleteQueryContext(t *testing.T) { - initForDeleteTest(t) - deleteStmt := Link. DELETE(). WHERE(Link.Name.IN(String("Gmail"), String("Outlook"))) @@ -60,7 +54,7 @@ func TestDeleteQueryContext(t *testing.T) { time.Sleep(10 * time.Millisecond) - dest := []model.Link{} + var dest []model.Link err := deleteStmt.QueryContext(ctx, db, &dest) require.Error(t, err, "context deadline exceeded") @@ -68,8 +62,6 @@ func TestDeleteQueryContext(t *testing.T) { } func TestDeleteExecContext(t *testing.T) { - initForDeleteTest(t) - deleteStmt := Link. DELETE(). WHERE(Link.Name.IN(String("Gmail"), String("Outlook"))) @@ -84,19 +76,7 @@ func TestDeleteExecContext(t *testing.T) { require.Error(t, err, "context deadline exceeded") } -func initForDeleteTest(t *testing.T) { - cleanUpLinkTable(t) - stmt := Link.INSERT(Link.URL, Link.Name, Link.Description). - VALUES("www.gmail.com", "Gmail", "Email service developed by Google"). - VALUES("www.outlook.live.com", "Outlook", "Email service developed by Microsoft") - - testutils.AssertExec(t, stmt, db, 2) -} - func TestDeleteWithUsing(t *testing.T) { - tx := beginTx(t) - defer tx.Rollback() - stmt := table.Rental.DELETE(). USING( table.Rental. @@ -116,5 +96,5 @@ USING dvds.rental WHERE (staff.staff_id != ?) AND (rental.rental_id < ?); `) - testutils.AssertExec(t, stmt, tx) + testutils.AssertExecAndRollback(t, stmt, db) } diff --git a/tests/mysql/insert_test.go b/tests/mysql/insert_test.go index 55fc706..10887f5 100644 --- a/tests/mysql/insert_test.go +++ b/tests/mysql/insert_test.go @@ -2,6 +2,7 @@ package mysql import ( "context" + "database/sql" "github.com/go-jet/jet/v2/internal/testutils" . "github.com/go-jet/jet/v2/mysql" "github.com/go-jet/jet/v2/tests/.gentestdata/mysql/test_sample/model" @@ -13,51 +14,47 @@ import ( ) func TestInsertValues(t *testing.T) { - cleanUpLinkTable(t) - - var expectedSQL = ` -INSERT INTO test_sample.link (id, url, name, description) -VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT), - (101, 'http://www.google.com', 'Google', DEFAULT), - (102, 'http://www.yahoo.com', 'Yahoo', NULL); -` - insertQuery := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description). VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). VALUES(101, "http://www.google.com", "Google", DEFAULT). VALUES(102, "http://www.yahoo.com", "Yahoo", nil) - testutils.AssertDebugStatementSql(t, insertQuery, expectedSQL, + testutils.AssertDebugStatementSql(t, insertQuery, ` +INSERT INTO test_sample.link (id, url, name, description) +VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT), + (101, 'http://www.google.com', 'Google', DEFAULT), + (102, 'http://www.yahoo.com', 'Yahoo', NULL); +`, 100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", 101, "http://www.google.com", "Google", 102, "http://www.yahoo.com", "Yahoo", nil) - _, err := insertQuery.Exec(db) - require.NoError(t, err) - requireLogged(t, insertQuery) + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + _, err := insertQuery.Exec(tx) + require.NoError(t, err) + requireLogged(t, insertQuery) - insertedLinks := []model.Link{} + var insertedLinks []model.Link - err = Link.SELECT(Link.AllColumns). - WHERE(Link.ID.GT_EQ(Int(100))). - ORDER_BY(Link.ID). - Query(db, &insertedLinks) + err = Link.SELECT(Link.AllColumns). + WHERE(Link.ID.BETWEEN(Int(100), Int(199))). + ORDER_BY(Link.ID). + Query(tx, &insertedLinks) - require.NoError(t, err) - require.Equal(t, len(insertedLinks), 3) + require.NoError(t, err) + require.Equal(t, len(insertedLinks), 3) - testutils.AssertDeepEqual(t, insertedLinks[0], postgreTutorial) - - testutils.AssertDeepEqual(t, insertedLinks[1], model.Link{ - ID: 101, - URL: "http://www.google.com", - Name: "Google", - }) - - testutils.AssertDeepEqual(t, insertedLinks[2], model.Link{ - ID: 102, - URL: "http://www.yahoo.com", - Name: "Yahoo", + testutils.AssertDeepEqual(t, insertedLinks[0], postgreTutorial) + testutils.AssertDeepEqual(t, insertedLinks[1], model.Link{ + ID: 101, + URL: "http://www.google.com", + Name: "Google", + }) + testutils.AssertDeepEqual(t, insertedLinks[2], model.Link{ + ID: 102, + URL: "http://www.yahoo.com", + Name: "Yahoo", + }) }) } @@ -68,42 +65,34 @@ var postgreTutorial = model.Link{ } func TestInsertEmptyColumnList(t *testing.T) { - cleanUpLinkTable(t) - - expectedSQL := ` -INSERT INTO test_sample.link -VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT); -` - stmt := Link.INSERT(). VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT) - testutils.AssertDebugStatementSql(t, stmt, expectedSQL, + testutils.AssertDebugStatementSql(t, stmt, ` +INSERT INTO test_sample.link +VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT); +`, 100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial") - _, err := stmt.Exec(db) - require.NoError(t, err) - requireLogged(t, stmt) + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + _, err := stmt.Exec(tx) + require.NoError(t, err) + requireLogged(t, stmt) - insertedLinks := []model.Link{} + var insertedLinks []model.Link - err = Link.SELECT(Link.AllColumns). - WHERE(Link.ID.GT_EQ(Int(100))). - ORDER_BY(Link.ID). - Query(db, &insertedLinks) + err = Link.SELECT(Link.AllColumns). + WHERE(Link.ID.BETWEEN(Int(100), Int(199))). + ORDER_BY(Link.ID). + Query(tx, &insertedLinks) - require.NoError(t, err) - require.Equal(t, len(insertedLinks), 1) - testutils.AssertDeepEqual(t, insertedLinks[0], postgreTutorial) + require.NoError(t, err) + require.Equal(t, len(insertedLinks), 1) + testutils.AssertDeepEqual(t, insertedLinks[0], postgreTutorial) + }) } func TestInsertModelObject(t *testing.T) { - cleanUpLinkTable(t) - var expectedSQL = ` -INSERT INTO test_sample.link (url, name) -VALUES ('http://www.duckduckgo.com', 'Duck Duck go'); -` - linkData := model.Link{ URL: "http://www.duckduckgo.com", Name: "Duck Duck go", @@ -113,19 +102,19 @@ VALUES ('http://www.duckduckgo.com', 'Duck Duck go'); INSERT(Link.URL, Link.Name). MODEL(linkData) - testutils.AssertDebugStatementSql(t, query, expectedSQL, "http://www.duckduckgo.com", "Duck Duck go") + testutils.AssertDebugStatementSql(t, query, ` +INSERT INTO test_sample.link (url, name) +VALUES ('http://www.duckduckgo.com', 'Duck Duck go'); +`, + "http://www.duckduckgo.com", "Duck Duck go") - _, err := query.Exec(db) - require.NoError(t, err) + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + _, err := query.Exec(tx) + require.NoError(t, err) + }) } func TestInsertModelObjectEmptyColumnList(t *testing.T) { - cleanUpLinkTable(t) - var expectedSQL = ` -INSERT INTO test_sample.link -VALUES (1000, 'http://www.duckduckgo.com', 'Duck Duck go', NULL); -` - linkData := model.Link{ ID: 1000, URL: "http://www.duckduckgo.com", @@ -136,20 +125,18 @@ VALUES (1000, 'http://www.duckduckgo.com', 'Duck Duck go', NULL); INSERT(). MODEL(linkData) - testutils.AssertDebugStatementSql(t, query, expectedSQL, int32(1000), "http://www.duckduckgo.com", "Duck Duck go", nil) + testutils.AssertDebugStatementSql(t, query, ` +INSERT INTO test_sample.link +VALUES (1000, 'http://www.duckduckgo.com', 'Duck Duck go', NULL); +`, int32(1000), "http://www.duckduckgo.com", "Duck Duck go", nil) - _, err := query.Exec(db) - require.NoError(t, err) + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + _, err := query.Exec(tx) + require.NoError(t, err) + }) } func TestInsertModelsObject(t *testing.T) { - expectedSQL := ` -INSERT INTO test_sample.link (url, name) -VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial'), - ('http://www.google.com', 'Google'), - ('http://www.yahoo.com', 'Yahoo'); -` - tutorial := model.Link{ URL: "http://www.postgresqltutorial.com", Name: "PostgreSQL Tutorial", @@ -169,24 +156,23 @@ VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial'), INSERT(Link.URL, Link.Name). MODELS([]model.Link{tutorial, google, yahoo}) - testutils.AssertDebugStatementSql(t, query, expectedSQL, + testutils.AssertDebugStatementSql(t, query, ` +INSERT INTO test_sample.link (url, name) +VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial'), + ('http://www.google.com', 'Google'), + ('http://www.yahoo.com', 'Yahoo'); +`, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", "http://www.google.com", "Google", "http://www.yahoo.com", "Yahoo") - _, err := query.Exec(db) - require.NoError(t, err) + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + _, err := query.Exec(tx) + require.NoError(t, err) + }) } func TestInsertUsingMutableColumns(t *testing.T) { - var expectedSQL = ` -INSERT INTO test_sample.link (url, name, description) -VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT), - ('http://www.google.com', 'Google', NULL), - ('http://www.google.com', 'Google', NULL), - ('http://www.yahoo.com', 'Yahoo', NULL); -` - google := model.Link{ URL: "http://www.google.com", Name: "Google", @@ -203,31 +189,25 @@ VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT), MODEL(google). MODELS([]model.Link{google, yahoo}) - testutils.AssertDebugStatementSql(t, stmt, expectedSQL, + testutils.AssertDebugStatementSql(t, stmt, ` +INSERT INTO test_sample.link (url, name, description) +VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT), + ('http://www.google.com', 'Google', NULL), + ('http://www.google.com', 'Google', NULL), + ('http://www.yahoo.com', 'Yahoo', NULL); +`, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", "http://www.google.com", "Google", nil, "http://www.google.com", "Google", nil, "http://www.yahoo.com", "Yahoo", nil) - _, err := stmt.Exec(db) - require.NoError(t, err) + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + _, err := stmt.Exec(tx) + require.NoError(t, err) + }) } func TestInsertQuery(t *testing.T) { - _, err := Link.DELETE(). - WHERE(Link.ID.NOT_EQ(Int(1)).AND(Link.Name.EQ(String("Youtube")))). - Exec(db) - require.NoError(t, err) - - var expectedSQL = ` -INSERT INTO test_sample.link (url, name) ( - SELECT link.url AS "link.url", - link.name AS "link.name" - FROM test_sample.link - WHERE link.id = 1 -); -` - query := Link. INSERT(Link.URL, Link.Name). QUERY( @@ -236,19 +216,28 @@ INSERT INTO test_sample.link (url, name) ( WHERE(Link.ID.EQ(Int(1))), ) - testutils.AssertDebugStatementSql(t, query, expectedSQL, int64(1)) + testutils.AssertDebugStatementSql(t, query, ` +INSERT INTO test_sample.link (url, name) ( + SELECT link.url AS "link.url", + link.name AS "link.name" + FROM test_sample.link + WHERE link.id = 1 +); +`, int64(1)) - _, err = query.Exec(db) - require.NoError(t, err) + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + _, err := query.Exec(tx) + require.NoError(t, err) - youtubeLinks := []model.Link{} - err = Link. - SELECT(Link.AllColumns). - WHERE(Link.Name.EQ(String("Youtube"))). - Query(db, &youtubeLinks) + var youtubeLinks []model.Link + err = Link. + SELECT(Link.AllColumns). + WHERE(Link.Name.EQ(String("Youtube"))). + Query(tx, &youtubeLinks) - require.NoError(t, err) - require.Equal(t, len(youtubeLinks), 2) + require.NoError(t, err) + require.Equal(t, len(youtubeLinks), 2) + }) } func TestInsertOnDuplicateKey(t *testing.T) { @@ -272,28 +261,29 @@ ON DUPLICATE KEY UPDATE id = (id + ?), randId, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", int64(11), "PostgreSQL Tutorial 2") - testutils.AssertExec(t, stmt, db, 3) + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + _, err := stmt.Exec(tx) + require.NoError(t, err) - newLinks := []model.Link{} + var newLinks []model.Link - err := SELECT(Link.AllColumns). - FROM(Link). - WHERE(Link.ID.EQ(Int32(randId).ADD(Int(11)))). - Query(db, &newLinks) + err = SELECT(Link.AllColumns). + FROM(Link). + WHERE(Link.ID.EQ(Int32(randId).ADD(Int(11)))). + Query(tx, &newLinks) - require.NoError(t, err) - require.Len(t, newLinks, 1) - require.Equal(t, newLinks[0], model.Link{ - ID: randId + 11, - URL: "http://www.postgresqltutorial.com", - Name: "PostgreSQL Tutorial 2", - Description: nil, + require.NoError(t, err) + require.Len(t, newLinks, 1) + require.Equal(t, newLinks[0], model.Link{ + ID: randId + 11, + URL: "http://www.postgresqltutorial.com", + Name: "PostgreSQL Tutorial 2", + Description: nil, + }) }) } func TestInsertWithQueryContext(t *testing.T) { - cleanUpLinkTable(t) - stmt := Link.INSERT(). VALUES(1100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT) @@ -302,15 +292,13 @@ func TestInsertWithQueryContext(t *testing.T) { time.Sleep(10 * time.Millisecond) - dest := []model.Link{} + var dest []model.Link err := stmt.QueryContext(ctx, db, &dest) require.Error(t, err, "context deadline exceeded") } func TestInsertWithExecContext(t *testing.T) { - cleanUpLinkTable(t) - stmt := Link.INSERT(). VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT) @@ -323,8 +311,3 @@ func TestInsertWithExecContext(t *testing.T) { require.Error(t, err, "context deadline exceeded") } - -func cleanUpLinkTable(t *testing.T) { - _, err := Link.DELETE().WHERE(Link.ID.GT(Int(1))).Exec(db) - require.NoError(t, err) -} diff --git a/tests/mysql/main_test.go b/tests/mysql/main_test.go index e04580d..f6ce57d 100644 --- a/tests/mysql/main_test.go +++ b/tests/mysql/main_test.go @@ -96,9 +96,3 @@ func skipForMariaDB(t *testing.T) { t.SkipNow() } } - -func beginTx(t *testing.T) *sql.Tx { - tx, err := db.Begin() - require.NoError(t, err) - return tx -} diff --git a/tests/mysql/update_test.go b/tests/mysql/update_test.go index ba628a1..c03d424 100644 --- a/tests/mysql/update_test.go +++ b/tests/mysql/update_test.go @@ -2,6 +2,7 @@ package mysql import ( "context" + "database/sql" "github.com/go-jet/jet/v2/internal/testutils" . "github.com/go-jet/jet/v2/mysql" "github.com/go-jet/jet/v2/tests/.gentestdata/mysql/dvds/table" @@ -13,8 +14,6 @@ import ( ) func TestUpdateValues(t *testing.T) { - setupLinkTableForUpdateTest(t) - var expectedSQL = ` UPDATE test_sample.link SET name = 'Bong', @@ -28,8 +27,26 @@ WHERE link.name = 'Bing'; WHERE(Link.Name.EQ(String("Bing"))) testutils.AssertDebugStatementSql(t, query, expectedSQL, "Bong", "http://bong.com", "Bing") - testutils.AssertExec(t, query, db) - requireLogged(t, query) + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + testutils.AssertExec(t, query, tx) + requireLogged(t, query) + + var links []model.Link + + err := Link. + SELECT(Link.AllColumns). + WHERE(Link.Name.EQ(String("Bong"))). + Query(tx, &links) + + require.NoError(t, err) + require.Equal(t, len(links), 1) + testutils.AssertDeepEqual(t, links[0], model.Link{ + ID: 204, + URL: "http://bong.com", + Name: "Bong", + }) + }) + }) t.Run("new version", func(t *testing.T) { @@ -41,29 +58,29 @@ WHERE link.name = 'Bing'; WHERE(Link.Name.EQ(String("Bing"))) testutils.AssertDebugStatementSql(t, stmt, expectedSQL, "Bong", "http://bong.com", "Bing") - testutils.AssertExec(t, stmt, db) - requireLogged(t, stmt) - }) + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + testutils.AssertExec(t, stmt, tx) + requireLogged(t, stmt) - links := []model.Link{} + var links []model.Link - err := Link. - SELECT(Link.AllColumns). - WHERE(Link.Name.EQ(String("Bong"))). - Query(db, &links) + err := Link. + SELECT(Link.AllColumns). + WHERE(Link.Name.EQ(String("Bong"))). + Query(tx, &links) - require.NoError(t, err) - require.Equal(t, len(links), 1) - testutils.AssertDeepEqual(t, links[0], model.Link{ - ID: 204, - URL: "http://bong.com", - Name: "Bong", + require.NoError(t, err) + require.Equal(t, len(links), 1) + testutils.AssertDeepEqual(t, links[0], model.Link{ + ID: 204, + URL: "http://bong.com", + Name: "Bong", + }) + }) }) } func TestUpdateWithSubQueries(t *testing.T) { - setupLinkTableForUpdateTest(t) - expectedSQL := ` UPDATE test_sample.link SET name = ?, @@ -86,7 +103,7 @@ WHERE link.name = ?; WHERE(Link.Name.EQ(String("Bing"))) testutils.AssertStatementSql(t, query, expectedSQL, "Bong", "Youtube", "Bing") - testutils.AssertExec(t, query, db) + testutils.AssertExecAndRollback(t, query, db) requireLogged(t, query) }) @@ -104,14 +121,12 @@ WHERE link.name = ?; WHERE(Link.Name.EQ(String("Bing"))) testutils.AssertStatementSql(t, query, expectedSQL, "Bong", "Youtube", "Bing") - testutils.AssertExec(t, query, db) + testutils.AssertExecAndRollback(t, query, db) requireLogged(t, query) }) } func TestUpdateWithModelData(t *testing.T) { - setupLinkTableForUpdateTest(t) - link := model.Link{ ID: 201, URL: "http://www.duckduckgo.com", @@ -123,24 +138,20 @@ func TestUpdateWithModelData(t *testing.T) { MODEL(link). WHERE(Link.ID.EQ(Int32(link.ID))) - expectedSQL := ` + testutils.AssertStatementSql(t, stmt, ` UPDATE test_sample.link SET id = ?, url = ?, name = ?, description = ? WHERE link.id = ?; -` - testutils.AssertStatementSql(t, stmt, expectedSQL, int32(201), "http://www.duckduckgo.com", "DuckDuckGo", nil, int32(201)) +`, int32(201), "http://www.duckduckgo.com", "DuckDuckGo", nil, int32(201)) - testutils.AssertExec(t, stmt, db) + testutils.AssertExecAndRollback(t, stmt, db) requireLogged(t, stmt) } func TestUpdateWithModelDataAndPredefinedColumnList(t *testing.T) { - - setupLinkTableForUpdateTest(t) - link := model.Link{ ID: 201, URL: "http://www.duckduckgo.com", @@ -154,23 +165,19 @@ func TestUpdateWithModelDataAndPredefinedColumnList(t *testing.T) { MODEL(link). WHERE(Link.ID.EQ(Int32(link.ID))) - var expectedSQL = ` + testutils.AssertDebugStatementSql(t, stmt, ` UPDATE test_sample.link SET description = NULL, name = 'DuckDuckGo', url = 'http://www.duckduckgo.com' WHERE link.id = 201; -` +`, nil, "DuckDuckGo", "http://www.duckduckgo.com", int32(201)) - testutils.AssertDebugStatementSql(t, stmt, expectedSQL, nil, "DuckDuckGo", "http://www.duckduckgo.com", int32(201)) - - testutils.AssertExec(t, stmt, db) + testutils.AssertExecAndRollback(t, stmt, db) requireLogged(t, stmt) } func TestUpdateWithModelDataAndMutableColumns(t *testing.T) { - setupLinkTableForUpdateTest(t) - link := model.Link{ ID: 201, URL: "http://www.duckduckgo.com", @@ -192,7 +199,7 @@ WHERE link.id = 201; //fmt.Println(stmt.DebugSql()) testutils.AssertDebugStatementSql(t, stmt, expectedSQL, "http://www.duckduckgo.com", "DuckDuckGo", nil, int32(201)) - testutils.AssertExec(t, stmt, db) + testutils.AssertExecAndRollback(t, stmt, db) } func TestUpdateWithInvalidModelData(t *testing.T) { @@ -201,8 +208,6 @@ func TestUpdateWithInvalidModelData(t *testing.T) { require.Equal(t, r, "missing struct field for column : id") }() - setupLinkTableForUpdateTest(t) - link := struct { Ident int URL string @@ -215,17 +220,13 @@ func TestUpdateWithInvalidModelData(t *testing.T) { Name: "DuckDuckGo", } - stmt := Link. + _ = Link. UPDATE(Link.AllColumns). MODEL(link). WHERE(Link.ID.EQ(Int(int64(link.Ident)))) - - stmt.Sql() } func TestUpdateQueryContext(t *testing.T) { - setupLinkTableForUpdateTest(t) - updateStmt := Link. UPDATE(Link.Name, Link.URL). SET("Bong", "http://bong.com"). @@ -243,8 +244,6 @@ func TestUpdateQueryContext(t *testing.T) { } func TestUpdateExecContext(t *testing.T) { - setupLinkTableForUpdateTest(t) - updateStmt := Link. UPDATE(Link.Name, Link.URL). SET("Bong", "http://bong.com"). @@ -261,9 +260,6 @@ func TestUpdateExecContext(t *testing.T) { } func TestUpdateWithJoin(t *testing.T) { - tx := beginTx(t) - defer tx.Rollback() - statement := table.Staff.INNER_JOIN(table.Address, table.Address.AddressID.EQ(table.Staff.AddressID)). UPDATE(table.Staff.LastName). SET(String("New staff name")). @@ -276,21 +272,5 @@ SET last_name = ? WHERE staff.staff_id = ?; `, "New staff name", int64(1)) - _, err := statement.Exec(tx) - require.NoError(t, err) -} - -func setupLinkTableForUpdateTest(t *testing.T) { - - cleanUpLinkTable(t) - - _, err := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description). - VALUES(200, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). - VALUES(201, "http://www.ask.com", "Ask", DEFAULT). - VALUES(202, "http://www.ask.com", "Ask", DEFAULT). - VALUES(203, "http://www.yahoo.com", "Yahoo", DEFAULT). - VALUES(204, "http://www.bing.com", "Bing", DEFAULT). - Exec(db) - - require.NoError(t, err) + testutils.AssertExecAndRollback(t, statement, db) } diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go index 405ec9e..f1c82af 100644 --- a/tests/postgres/alltypes_test.go +++ b/tests/postgres/alltypes_test.go @@ -1,6 +1,7 @@ package postgres import ( + "database/sql" "testing" "time" @@ -17,10 +18,10 @@ import ( ) func TestAllTypesSelect(t *testing.T) { - dest := []model.AllTypes{} + var dest []model.AllTypes err := AllTypes.SELECT( - AllTypes.AllColumns, + AllTypesAllColumns, ).LIMIT(2). Query(db, &dest) require.NoError(t, err) @@ -32,7 +33,7 @@ func TestAllTypesSelect(t *testing.T) { func TestAllTypesViewSelect(t *testing.T) { type AllTypesView model.AllTypes - dest := []AllTypesView{} + var dest []AllTypesView err := view.AllTypesView.SELECT(view.AllTypesView.AllColumns).Query(db, &dest) require.NoError(t, err) @@ -44,40 +45,123 @@ func TestAllTypesViewSelect(t *testing.T) { func TestAllTypesInsertModel(t *testing.T) { skipForPgxDriver(t) // pgx driver bug ERROR: date/time field value out of range: "0000-01-01 12:05:06Z" (SQLSTATE 22008) - query := AllTypes.INSERT(AllTypes.AllColumns). + query := AllTypes.INSERT(AllTypesAllColumns). MODEL(allTypesRow0). MODEL(&allTypesRow1). RETURNING(AllTypes.AllColumns) - dest := []model.AllTypes{} - err := query.Query(db, &dest) - require.NoError(t, err) + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + var dest []model.AllTypes + err := query.Query(tx, &dest) + require.NoError(t, err) - require.Equal(t, len(dest), 2) - testutils.AssertDeepEqual(t, dest[0], allTypesRow0) - testutils.AssertDeepEqual(t, dest[1], allTypesRow1) + if sourceIsCockroachDB() { + return + } + require.Equal(t, len(dest), 2) + testutils.AssertDeepEqual(t, dest[0], allTypesRow0) + testutils.AssertDeepEqual(t, dest[1], allTypesRow1) + }) } +var AllTypesAllColumns = AllTypes.AllColumns.Except(IntegerColumn("rowid")) + func TestAllTypesInsertQuery(t *testing.T) { - query := AllTypes.INSERT(AllTypes.AllColumns). + query := AllTypes.INSERT(AllTypesAllColumns). QUERY( AllTypes. - SELECT(AllTypes.AllColumns). + SELECT(AllTypesAllColumns). LIMIT(2), ). - RETURNING(AllTypes.AllColumns) + RETURNING(AllTypesAllColumns) - dest := []model.AllTypes{} - err := query.Query(db, &dest) + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + var dest []model.AllTypes + err := query.Query(tx, &dest) + require.NoError(t, err) + require.Equal(t, len(dest), 2) + testutils.AssertDeepEqual(t, dest[0], allTypesRow0) + testutils.AssertDeepEqual(t, dest[1], allTypesRow1) + }) +} + +func TestUUIDType(t *testing.T) { + id := uuid.MustParse("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11") + + query := AllTypes. + SELECT(AllTypes.UUID, AllTypes.UUIDPtr). + WHERE(AllTypes.UUID.EQ(UUID(id))) + + testutils.AssertDebugStatementSql(t, query, ` +SELECT all_types.uuid AS "all_types.uuid", + all_types.uuid_ptr AS "all_types.uuid_ptr" +FROM test_sample.all_types +WHERE all_types.uuid = 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11'; +`, "a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11") + + result := model.AllTypes{} + + err := query.Query(db, &result) require.NoError(t, err) - require.Equal(t, len(dest), 2) - testutils.AssertDeepEqual(t, dest[0], allTypesRow0) - testutils.AssertDeepEqual(t, dest[1], allTypesRow1) + require.Equal(t, result.UUID, uuid.MustParse("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11")) + testutils.AssertDeepEqual(t, result.UUIDPtr, testutils.UUIDPtr("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11")) + requireLogged(t, query) +} + +func TestBytea(t *testing.T) { + byteArrHex := "\\x48656c6c6f20476f7068657221" + byteArrBin := []byte("\x48\x65\x6c\x6c\x6f\x20\x47\x6f\x70\x68\x65\x72\x21") + + insertStmt := AllTypes.INSERT(AllTypes.Bytea, AllTypes.ByteaPtr). + VALUES(byteArrHex, byteArrBin). + RETURNING(AllTypes.Bytea, AllTypes.ByteaPtr) + + testutils.AssertStatementSql(t, insertStmt, ` +INSERT INTO test_sample.all_types (bytea, bytea_ptr) +VALUES ($1, $2) +RETURNING all_types.bytea AS "all_types.bytea", + all_types.bytea_ptr AS "all_types.bytea_ptr"; +`, byteArrHex, byteArrBin) + + var inserted model.AllTypes + err := insertStmt.Query(db, &inserted) + require.NoError(t, err) + + require.Equal(t, string(*inserted.ByteaPtr), "Hello Gopher!") + // It is not possible to initiate bytea column using hex format '\xDEADBEEF' with pq driver. + // pq driver always encodes parameter string if destination column is of type bytea. + // Probably pq driver error. + // require.Equal(t, string(inserted.Bytea), "Hello Gopher!") + + stmt := SELECT( + AllTypes.Bytea, + AllTypes.ByteaPtr, + ).FROM( + AllTypes, + ).WHERE( + AllTypes.ByteaPtr.EQ(Bytea(byteArrBin)), + ) + + testutils.AssertStatementSql(t, stmt, ` +SELECT all_types.bytea AS "all_types.bytea", + all_types.bytea_ptr AS "all_types.bytea_ptr" +FROM test_sample.all_types +WHERE all_types.bytea_ptr = $1::bytea; +`, byteArrBin) + + var dest model.AllTypes + + err = stmt.Query(db, &dest) + require.NoError(t, err) + + require.Equal(t, string(*dest.ByteaPtr), "Hello Gopher!") + // Probably pq driver error. + // require.Equal(t, string(dest.Bytea), "Hello Gopher!") } func TestAllTypesFromSubQuery(t *testing.T) { - subQuery := SELECT(AllTypes.AllColumns). + subQuery := SELECT(AllTypesAllColumns). FROM(AllTypes). AsTable("allTypesSubQuery") @@ -214,7 +298,7 @@ FROM ( LIMIT 2; `) - dest := []model.AllTypes{} + var dest []model.AllTypes err := mainQuery.Query(db, &dest) require.NoError(t, err) @@ -298,7 +382,6 @@ LIMIT $11; } func TestExpressionCast(t *testing.T) { - skipForPgxDriver(t) // pgx driver bug 'cannot convert 151 to Text' query := AllTypes.SELECT( @@ -315,19 +398,28 @@ func TestExpressionCast(t *testing.T) { CAST(Int(234)).AS_TEXT(), CAST(String("1/8/1999")).AS_DATE(), CAST(String("04:05:06.789")).AS_TIME(), - CAST(String("04:05:06 PST")).AS_TIMEZ(), + CAST(String("04:05:06+01:00")).AS_TIMEZ(), CAST(String("1999-01-08 04:05:06")).AS_TIMESTAMP(), - CAST(String("January 8 04:05:06 1999 PST")).AS_TIMESTAMPZ(), + CAST(String("1999-01-08 04:05:06+01:00")).AS_TIMESTAMPZ(), CAST(String("04:05:06")).AS_INTERVAL(), - TO_CHAR(AllTypes.Timestamp, String("HH12:MI:SS")), - TO_CHAR(AllTypes.Integer, String("999")), - TO_CHAR(AllTypes.DoublePrecision, String("999D9")), - TO_CHAR(AllTypes.Numeric, String("999D99S")), + func() ProjectionList { + if sourceIsCockroachDB() { + return ProjectionList{NULL} + } - TO_DATE(String("05 Dec 2000"), String("DD Mon YYYY")), - TO_NUMBER(String("12,454"), String("99G999D9S")), - TO_TIMESTAMP(String("05 Dec 2000"), String("DD Mon YYYY")), + // cockroach doesn't support currently + return ProjectionList{ + TO_CHAR(AllTypes.Timestamp, String("HH12:MI:SS")), + TO_CHAR(AllTypes.Integer, String("999")), + TO_CHAR(AllTypes.DoublePrecision, String("999D9")), + TO_CHAR(AllTypes.Numeric, String("999D99S")), + + TO_DATE(String("05 Dec 2000"), String("DD Mon YYYY")), + TO_NUMBER(String("12,454"), String("99G999D9S")), + TO_TIMESTAMP(String("05 Dec 2000"), String("DD Mon YYYY")), + } + }(), COALESCE(AllTypes.IntegerPtr, AllTypes.SmallIntPtr, NULL, Int(11)), NULLIF(AllTypes.Text, String("(none)")), @@ -337,16 +429,15 @@ func TestExpressionCast(t *testing.T) { Raw("current_database()"), ) - //fmt.Println(query.DebugSql()) - - dest := []struct{}{} + var dest []struct{} err := query.Query(db, &dest) require.NoError(t, err) } func TestStringOperators(t *testing.T) { - skipForPgxDriver(t) // pgx driver bug 'cannot convert 11 to Text' + skipForCockroachDB(t) // some string functions are still unimplemented + skipForPgxDriver(t) // pgx driver bug 'cannot convert 11 to Text' query := AllTypes.SELECT( AllTypes.Text.EQ(AllTypes.Char), @@ -395,18 +486,18 @@ func TestStringOperators(t *testing.T) { CONCAT(AllTypes.VarCharPtr, AllTypes.VarCharPtr, String("aaa"), Int(1)), CONCAT(Bool(false), Int(1), Float(22.2), String("test test")), CONCAT_WS(String("string1"), Int(1), Float(11.22), String("bytea"), Bool(false)), //Float(11.12)), - CONVERT(String("bytea"), String("UTF8"), String("LATIN1")), + CONVERT(Bytea("bytea"), String("UTF8"), String("LATIN1")), CONVERT(AllTypes.Bytea, String("UTF8"), String("LATIN1")), - CONVERT_FROM(String("text_in_utf8"), String("UTF8")), + CONVERT_FROM(Bytea("text_in_utf8"), String("UTF8")), CONVERT_TO(String("text_in_utf8"), String("UTF8")), - ENCODE(String("123\000\001"), String("base64")), + ENCODE(Bytea("123\000\001"), String("base64")), DECODE(String("MTIzAAE="), String("base64")), FORMAT(String("Hello %s, %1$s"), String("World")), INITCAP(String("hi THOMAS")), LEFT(String("abcde"), Int(2)), RIGHT(String("abcde"), Int(2)), - LENGTH(String("jose")), - LENGTH(String("jose"), String("UTF8")), + LENGTH(Bytea("jose")), + LENGTH(Bytea("jose"), String("UTF8")), LPAD(String("Hi"), Int(5)), LPAD(String("Hi"), Int(5), String("xy")), RPAD(String("Hi"), Int(5)), @@ -421,8 +512,6 @@ func TestStringOperators(t *testing.T) { TO_HEX(AllTypes.IntegerPtr), ) - //fmt.Println(query.DebugSql()) - dest := []struct{}{} err := query.Query(db, &dest) @@ -501,6 +590,8 @@ LIMIT $5; } func TestFloatOperators(t *testing.T) { + skipForCockroachDB(t) // some functions are still unimplemented + query := AllTypes.SELECT( AllTypes.Numeric.EQ(AllTypes.Numeric).AS("eq1"), AllTypes.Decimal.EQ(Float(12.22)).AS("eq2"), @@ -604,6 +695,8 @@ LIMIT $38; } func TestIntegerOperators(t *testing.T) { + skipForCockroachDB(t) // some functions are still unimplemented + query := AllTypes.SELECT( AllTypes.BigInt, AllTypes.BigIntPtr, @@ -733,6 +826,8 @@ LIMIT $27; } func TestTimeExpression(t *testing.T) { + skipForCockroachDB(t) + query := AllTypes.SELECT( AllTypes.Time.EQ(AllTypes.Time), AllTypes.Time.EQ(Time(23, 6, 6, 1)), @@ -813,6 +908,8 @@ func TestTimeExpression(t *testing.T) { } func TestInterval(t *testing.T) { + skipForCockroachDB(t) + stmt := SELECT( INTERVAL(1, YEAR), INTERVAL(1, MONTH), @@ -1084,6 +1181,10 @@ LIMIT $6; dest.Timez = dest.Timez.UTC() dest.Timestampz = dest.Timestampz.UTC() + if sourceIsCockroachDB() { + return // rounding differences + } + testutils.AssertJSON(t, dest, ` { "Date": "2009-11-17T00:00:00Z", diff --git a/tests/postgres/chinook_db_test.go b/tests/postgres/chinook_db_test.go index 2d9821c..65eae7f 100644 --- a/tests/postgres/chinook_db_test.go +++ b/tests/postgres/chinook_db_test.go @@ -188,34 +188,36 @@ func TestJoinEverything(t *testing.T) { manager := Employee.AS("Manager") - stmt := Artist. - LEFT_JOIN(Album, Artist.ArtistId.EQ(Album.ArtistId)). - LEFT_JOIN(Track, Track.AlbumId.EQ(Album.AlbumId)). - LEFT_JOIN(Genre, Genre.GenreId.EQ(Track.GenreId)). - LEFT_JOIN(MediaType, MediaType.MediaTypeId.EQ(Track.MediaTypeId)). - LEFT_JOIN(PlaylistTrack, PlaylistTrack.TrackId.EQ(Track.TrackId)). - LEFT_JOIN(Playlist, Playlist.PlaylistId.EQ(PlaylistTrack.PlaylistId)). - LEFT_JOIN(InvoiceLine, InvoiceLine.TrackId.EQ(Track.TrackId)). - LEFT_JOIN(Invoice, Invoice.InvoiceId.EQ(InvoiceLine.InvoiceId)). - LEFT_JOIN(Customer, Customer.CustomerId.EQ(Invoice.CustomerId)). - LEFT_JOIN(Employee, Employee.EmployeeId.EQ(Customer.SupportRepId)). - LEFT_JOIN(manager, manager.EmployeeId.EQ(Employee.ReportsTo)). - SELECT( - Artist.AllColumns, - Album.AllColumns, - Track.AllColumns, - Genre.AllColumns, - MediaType.AllColumns, - PlaylistTrack.AllColumns, - Playlist.AllColumns, - Invoice.AllColumns, - Customer.AllColumns, - Employee.AllColumns, - manager.AllColumns, - ). - ORDER_BY(Artist.ArtistId, Album.AlbumId, Track.TrackId, - Genre.GenreId, MediaType.MediaTypeId, Playlist.PlaylistId, - Invoice.InvoiceId, Customer.CustomerId) + stmt := SELECT( + Artist.AllColumns, + Album.AllColumns, + Track.AllColumns, + Genre.AllColumns, + MediaType.AllColumns, + PlaylistTrack.AllColumns, + Playlist.AllColumns, + Invoice.AllColumns, + Customer.AllColumns, + Employee.AllColumns, + manager.AllColumns, + ).FROM( + Artist. + LEFT_JOIN(Album, Artist.ArtistId.EQ(Album.ArtistId)). + LEFT_JOIN(Track, Track.AlbumId.EQ(Album.AlbumId)). + LEFT_JOIN(Genre, Genre.GenreId.EQ(Track.GenreId)). + LEFT_JOIN(MediaType, MediaType.MediaTypeId.EQ(Track.MediaTypeId)). + LEFT_JOIN(PlaylistTrack, PlaylistTrack.TrackId.EQ(Track.TrackId)). + LEFT_JOIN(Playlist, Playlist.PlaylistId.EQ(PlaylistTrack.PlaylistId)). + LEFT_JOIN(InvoiceLine, InvoiceLine.TrackId.EQ(Track.TrackId)). + LEFT_JOIN(Invoice, Invoice.InvoiceId.EQ(InvoiceLine.InvoiceId)). + LEFT_JOIN(Customer, Customer.CustomerId.EQ(Invoice.CustomerId)). + LEFT_JOIN(Employee, Employee.EmployeeId.EQ(Customer.SupportRepId)). + LEFT_JOIN(manager, manager.EmployeeId.EQ(Employee.ReportsTo)), + ).ORDER_BY( + Artist.ArtistId, Album.AlbumId, Track.TrackId, + Genre.GenreId, MediaType.MediaTypeId, Playlist.PlaylistId, + Invoice.InvoiceId, Customer.CustomerId, + ) var dest []struct { //list of all artist model.Artist @@ -398,11 +400,11 @@ FROM ( SELECT "subQuery1"."Artist.ArtistId" AS "Artist.ArtistId", "subQuery1"."Artist.Name" AS "Artist.Name", "subQuery1".custom_column_1 AS "custom_column_1", - $1 AS "custom_column_2" + $1::text AS "custom_column_2" FROM ( SELECT "Artist"."ArtistId" AS "Artist.ArtistId", "Artist"."Name" AS "Artist.Name", - $2 AS "custom_column_1" + $2::text AS "custom_column_1" FROM chinook."Artist" ORDER BY "Artist"."ArtistId" ASC ) AS "subQuery1" @@ -721,11 +723,14 @@ ORDER BY "Album.AlbumId"; } func TestQueryWithContext(t *testing.T) { + if sourceIsCockroachDB() && !isPgxDriver() { + return // context cancellation doesn't work for pq driver + } ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() - dest := []model.Album{} + var dest []model.Album err := Album. CROSS_JOIN(Track). @@ -737,6 +742,9 @@ func TestQueryWithContext(t *testing.T) { } func TestExecWithContext(t *testing.T) { + if sourceIsCockroachDB() && !isPgxDriver() { + return // context cancellation doesn't work for pq driver + } ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() @@ -828,10 +836,12 @@ func Test_SchemaRename(t *testing.T) { albumArtistID := Album2.ArtistId.From(first10Albums) - stmt := SELECT(first10Artist.AllColumns(), first10Albums.AllColumns()). - FROM(first10Artist. - INNER_JOIN(first10Albums, artistID.EQ(albumArtistID))). - ORDER_BY(artistID) + stmt := SELECT( + first10Artist.AllColumns(), + first10Albums.AllColumns(), + ).FROM(first10Artist. + INNER_JOIN(first10Albums, artistID.EQ(albumArtistID)), + ).ORDER_BY(artistID) testutils.AssertDebugStatementSql(t, stmt, ` SELECT "first10Artist"."Artist.ArtistId" AS "Artist.ArtistId", @@ -891,6 +901,8 @@ var album347 = model.Album{ } func TestAggregateFunc(t *testing.T) { + skipForCockroachDB(t) + stmt := SELECT( PERCENTILE_DISC(Float(0.1)).WITHIN_GROUP_ORDER_BY(Invoice.InvoiceId).AS("percentile_disc_1"), PERCENTILE_DISC(Invoice.Total.DIV(Float(100))).WITHIN_GROUP_ORDER_BY(Invoice.InvoiceDate.ASC()).AS("percentile_disc_2"), diff --git a/tests/postgres/delete_test.go b/tests/postgres/delete_test.go index abbb344..ee8d320 100644 --- a/tests/postgres/delete_test.go +++ b/tests/postgres/delete_test.go @@ -2,6 +2,7 @@ package postgres import ( "context" + "database/sql" "github.com/go-jet/jet/v2/internal/testutils" . "github.com/go-jet/jet/v2/postgres" model2 "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/model" @@ -14,69 +15,49 @@ import ( ) func TestDeleteWithWhere(t *testing.T) { - initForDeleteTest(t) - - var expectedSQL = ` -DELETE FROM test_sample.link -WHERE link.name IN ('Gmail', 'Outlook'); -` deleteStmt := Link. DELETE(). WHERE(Link.Name.IN(String("Gmail"), String("Outlook"))) - testutils.AssertDebugStatementSql(t, deleteStmt, expectedSQL, "Gmail", "Outlook") + testutils.AssertDebugStatementSql(t, deleteStmt, ` +DELETE FROM test_sample.link +WHERE link.name IN ('Gmail'::text, 'Outlook'::text); +`, "Gmail", "Outlook") - res, err := deleteStmt.ExecContext(context.Background(), db) - - require.NoError(t, err) - rows, err := res.RowsAffected() - require.NoError(t, err) - require.Equal(t, rows, int64(2)) + testutils.AssertExecAndRollback(t, deleteStmt, db, 2) requireQueryLogged(t, deleteStmt, int64(2)) } func TestDeleteWithWhereAndReturning(t *testing.T) { - initForDeleteTest(t) - - var expectedSQL = ` -DELETE FROM test_sample.link -WHERE link.name IN ('Gmail', 'Outlook') -RETURNING link.id AS "link.id", - link.url AS "link.url", - link.name AS "link.name", - link.description AS "link.description"; -` deleteStmt := Link. DELETE(). WHERE(Link.Name.IN(String("Gmail"), String("Outlook"))). RETURNING(Link.AllColumns) - testutils.AssertDebugStatementSql(t, deleteStmt, expectedSQL, "Gmail", "Outlook") + testutils.AssertDebugStatementSql(t, deleteStmt, ` +DELETE FROM test_sample.link +WHERE link.name IN ('Gmail'::text, 'Outlook'::text) +RETURNING link.id AS "link.id", + link.url AS "link.url", + link.name AS "link.name", + link.description AS "link.description"; +`, "Gmail", "Outlook") - dest := []model.Link{} + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + var dest []model.Link - err := deleteStmt.Query(db, &dest) + err := deleteStmt.Query(tx, &dest) - require.NoError(t, err) + require.NoError(t, err) - require.Equal(t, len(dest), 2) - testutils.AssertDeepEqual(t, dest[0].Name, "Gmail") - testutils.AssertDeepEqual(t, dest[1].Name, "Outlook") - requireLogged(t, deleteStmt) -} - -func initForDeleteTest(t *testing.T) { - cleanUpLinkTable(t) - stmt := Link.INSERT(Link.URL, Link.Name, Link.Description). - VALUES("www.gmail.com", "Gmail", "Email service developed by Google"). - VALUES("www.outlook.live.com", "Outlook", "Email service developed by Microsoft") - - AssertExec(t, stmt, 2) + require.Equal(t, len(dest), 2) + testutils.AssertDeepEqual(t, dest[0].Name, "Gmail") + testutils.AssertDeepEqual(t, dest[1].Name, "Outlook") + requireLogged(t, deleteStmt) + }) } func TestDeleteQueryContext(t *testing.T) { - initForDeleteTest(t) - deleteStmt := Link. DELETE(). WHERE(Link.Name.IN(String("Gmail"), String("Outlook"))) @@ -86,16 +67,16 @@ func TestDeleteQueryContext(t *testing.T) { time.Sleep(10 * time.Millisecond) - dest := []model.Link{} - err := deleteStmt.QueryContext(ctx, db, &dest) + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + dest := []model.Link{} + err := deleteStmt.QueryContext(ctx, tx, &dest) - require.Error(t, err, "context deadline exceeded") - requireLogged(t, deleteStmt) + require.Error(t, err, "context deadline exceeded") + requireLogged(t, deleteStmt) + }) } func TestDeleteExecContext(t *testing.T) { - initForDeleteTest(t) - list := []Expression{String("Gmail"), String("Outlook")} deleteStmt := Link. @@ -107,15 +88,16 @@ func TestDeleteExecContext(t *testing.T) { time.Sleep(10 * time.Millisecond) - _, err := deleteStmt.ExecContext(ctx, db) + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + _, err := deleteStmt.ExecContext(ctx, tx) - require.Error(t, err, "context deadline exceeded") - requireLogged(t, deleteStmt) + require.Error(t, err, "context deadline exceeded") + requireLogged(t, deleteStmt) + }) } func TestDeleteFrom(t *testing.T) { - tx := beginTx(t) - defer tx.Rollback() + skipForCockroachDB(t) // USING is not supported stmt := table.Rental.DELETE(). USING( @@ -158,16 +140,17 @@ RETURNING rental.rental_id AS "rental.rental_id", store.last_update AS "store.last_update"; `) - var dest []struct { - Rental model2.Rental - Store model2.Store - } + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + var dest []struct { + Rental model2.Rental + Store model2.Store + } - err := stmt.Query(tx, &dest) + err := stmt.Query(tx, &dest) - require.NoError(t, err) - require.Len(t, dest, 3) - testutils.AssertJSON(t, dest[0], ` + require.NoError(t, err) + require.Len(t, dest, 3) + testutils.AssertJSON(t, dest[0], ` { "Rental": { "RentalID": 4, @@ -186,4 +169,5 @@ RETURNING rental.rental_id AS "rental.rental_id", } } `) + }) } diff --git a/tests/postgres/generator_test.go b/tests/postgres/generator_test.go index 0e420bb..328de17 100644 --- a/tests/postgres/generator_test.go +++ b/tests/postgres/generator_test.go @@ -497,6 +497,8 @@ func newActorInfoTableImpl(schemaName, tableName, alias string) actorInfoTable { ` func TestGeneratedAllTypesSQLBuilderFiles(t *testing.T) { + skipForCockroachDB(t) + enumDir := filepath.Join(testRoot, "/.gentestdata/jetdb/test_sample/enum/") modelDir := filepath.Join(testRoot, "/.gentestdata/jetdb/test_sample/model/") tableDir := filepath.Join(testRoot, "/.gentestdata/jetdb/test_sample/table/") diff --git a/tests/postgres/insert_test.go b/tests/postgres/insert_test.go index 8a50e02..1274869 100644 --- a/tests/postgres/insert_test.go +++ b/tests/postgres/insert_test.go @@ -2,6 +2,7 @@ package postgres import ( "context" + "database/sql" "github.com/go-jet/jet/v2/internal/testutils" . "github.com/go-jet/jet/v2/postgres" "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/test_sample/model" @@ -13,9 +14,13 @@ import ( ) func TestInsertValues(t *testing.T) { - cleanUpLinkTable(t) + insertQuery := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description). + VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). + VALUES(101, "http://www.google.com", "Google", DEFAULT). + VALUES(102, "http://www.yahoo.com", "Yahoo", nil). + RETURNING(Link.AllColumns) - var expectedSQL = ` + testutils.AssertDebugStatementSql(t, insertQuery, ` INSERT INTO test_sample.link (id, url, name, description) VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT), (101, 'http://www.google.com', 'Google', DEFAULT), @@ -24,76 +29,61 @@ RETURNING link.id AS "link.id", link.url AS "link.url", link.name AS "link.name", link.description AS "link.description"; -` - insertQuery := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description). - VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). - VALUES(101, "http://www.google.com", "Google", DEFAULT). - VALUES(102, "http://www.yahoo.com", "Yahoo", nil). - RETURNING(Link.AllColumns) - - testutils.AssertDebugStatementSql(t, insertQuery, expectedSQL, +`, 100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", 101, "http://www.google.com", "Google", 102, "http://www.yahoo.com", "Yahoo", nil) - insertedLinks := []model.Link{} + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + var insertedLinks []model.Link - err := insertQuery.Query(db, &insertedLinks) + err := insertQuery.Query(tx, &insertedLinks) - require.NoError(t, err) + require.NoError(t, err) + require.Equal(t, len(insertedLinks), 3) + testutils.AssertDeepEqual(t, insertedLinks[0], model.Link{ + ID: 100, + URL: "http://www.postgresqltutorial.com", + Name: "PostgreSQL Tutorial", + }) + testutils.AssertDeepEqual(t, insertedLinks[1], model.Link{ + ID: 101, + URL: "http://www.google.com", + Name: "Google", + }) + testutils.AssertDeepEqual(t, insertedLinks[2], model.Link{ + ID: 102, + URL: "http://www.yahoo.com", + Name: "Yahoo", + }) - require.Equal(t, len(insertedLinks), 3) + var allLinks []model.Link - testutils.AssertDeepEqual(t, insertedLinks[0], model.Link{ - ID: 100, - URL: "http://www.postgresqltutorial.com", - Name: "PostgreSQL Tutorial", + err = Link.SELECT(Link.AllColumns). + WHERE(Link.ID.BETWEEN(Int(100), Int(199))). + ORDER_BY(Link.ID). + Query(tx, &allLinks) + + require.NoError(t, err) + testutils.AssertDeepEqual(t, insertedLinks, allLinks) }) - - testutils.AssertDeepEqual(t, insertedLinks[1], model.Link{ - ID: 101, - URL: "http://www.google.com", - Name: "Google", - }) - - testutils.AssertDeepEqual(t, insertedLinks[2], model.Link{ - ID: 102, - URL: "http://www.yahoo.com", - Name: "Yahoo", - }) - - allLinks := []model.Link{} - - err = Link.SELECT(Link.AllColumns). - WHERE(Link.ID.GT_EQ(Int(100))). - ORDER_BY(Link.ID). - Query(db, &allLinks) - - require.NoError(t, err) - - testutils.AssertDeepEqual(t, insertedLinks, allLinks) } func TestInsertEmptyColumnList(t *testing.T) { - cleanUpLinkTable(t) - - expectedSQL := ` -INSERT INTO test_sample.link -VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT); -` - stmt := Link.INSERT(). VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT) - testutils.AssertDebugStatementSql(t, stmt, expectedSQL, + testutils.AssertDebugStatementSql(t, stmt, ` +INSERT INTO test_sample.link +VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT); +`, 100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial") - AssertExec(t, stmt, 1) + testutils.AssertExecAndRollback(t, stmt, db, 1) requireLogged(t, stmt) } func TestInsertOnConflict(t *testing.T) { - t.Run("do nothing", func(t *testing.T) { employee := model.Employee{EmployeeID: rand.Int31()} @@ -108,11 +98,12 @@ VALUES ($1, $2, $3, $4, $5), ($6, $7, $8, $9, $10) ON CONFLICT (employee_id) DO NOTHING; `) - AssertExec(t, stmt, 1) + testutils.AssertExecAndRollback(t, stmt, db, 1) requireLogged(t, stmt) }) t.Run("on constraint do nothing", func(t *testing.T) { + skipForCockroachDB(t) // does not support employee := model.Employee{EmployeeID: rand.Int31()} stmt := Employee.INSERT(Employee.AllColumns). @@ -126,12 +117,11 @@ VALUES ($1, $2, $3, $4, $5), ($6, $7, $8, $9, $10) ON CONFLICT ON CONSTRAINT employee_pkey DO NOTHING; `) - AssertExec(t, stmt, 1) + testutils.AssertExecAndRollback(t, stmt, db, 1) requireLogged(t, stmt) }) t.Run("do update", func(t *testing.T) { - cleanUpLinkTable(t) stmt := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description). VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). VALUES(200, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). @@ -148,18 +138,19 @@ VALUES ($1, $2, $3, DEFAULT), ($4, $5, $6, DEFAULT) ON CONFLICT (id) DO UPDATE SET id = excluded.id, - url = $7 + url = $7::text RETURNING link.id AS "link.id", link.url AS "link.url", link.name AS "link.name", link.description AS "link.description"; `) - AssertExec(t, stmt, 2) + testutils.AssertExecAndRollback(t, stmt, db, 2) }) t.Run("on constraint do update", func(t *testing.T) { - cleanUpLinkTable(t) + skipForCockroachDB(t) // does not support + stmt := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description). VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). VALUES(200, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). @@ -177,18 +168,18 @@ VALUES ($1, $2, $3, DEFAULT), ($4, $5, $6, DEFAULT) ON CONFLICT ON CONSTRAINT link_pkey DO UPDATE SET id = excluded.id, - url = $7 + url = $7::text RETURNING link.id AS "link.id", link.url AS "link.url", link.name AS "link.name", link.description AS "link.description"; `) - AssertExec(t, stmt, 2) + testutils.AssertExecAndRollback(t, stmt, db, 2) }) t.Run("do update complex", func(t *testing.T) { - cleanUpLinkTable(t) + skipForCockroachDB(t) // does not support ROW stmt := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description). VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). @@ -210,21 +201,15 @@ ON CONFLICT (id) WHERE (id * 2) > 10 DO UPDATE SELECT MAX(link.id) + 1 FROM test_sample.link ), - (name, description) = ROW(excluded.name, 'new description') + (name, description) = ROW(excluded.name, 'new description'::text) WHERE link.description IS NOT NULL; `) - AssertExec(t, stmt, 1) + testutils.AssertExecAndRollback(t, stmt, db, 1) }) } func TestInsertModelObject(t *testing.T) { - cleanUpLinkTable(t) - var expectedSQL = ` -INSERT INTO test_sample.link (url, name) -VALUES ('http://www.duckduckgo.com', 'Duck Duck go'); -` - linkData := model.Link{ URL: "http://www.duckduckgo.com", Name: "Duck Duck go", @@ -234,18 +219,15 @@ VALUES ('http://www.duckduckgo.com', 'Duck Duck go'); INSERT(Link.URL, Link.Name). MODEL(linkData) - testutils.AssertDebugStatementSql(t, query, expectedSQL, "http://www.duckduckgo.com", "Duck Duck go") + testutils.AssertDebugStatementSql(t, query, ` +INSERT INTO test_sample.link (url, name) +VALUES ('http://www.duckduckgo.com', 'Duck Duck go'); +`, "http://www.duckduckgo.com", "Duck Duck go") - AssertExec(t, query, 1) + testutils.AssertExecAndRollback(t, query, db, 1) } func TestInsertModelObjectEmptyColumnList(t *testing.T) { - cleanUpLinkTable(t) - var expectedSQL = ` -INSERT INTO test_sample.link -VALUES (1000, 'http://www.duckduckgo.com', 'Duck Duck go', NULL); -` - linkData := model.Link{ ID: 1000, URL: "http://www.duckduckgo.com", @@ -256,19 +238,16 @@ VALUES (1000, 'http://www.duckduckgo.com', 'Duck Duck go', NULL); INSERT(). MODEL(linkData) - testutils.AssertDebugStatementSql(t, query, expectedSQL, int32(1000), "http://www.duckduckgo.com", "Duck Duck go", nil) + testutils.AssertDebugStatementSql(t, query, ` +INSERT INTO test_sample.link +VALUES (1000, 'http://www.duckduckgo.com', 'Duck Duck go', NULL); +`, + int64(1000), "http://www.duckduckgo.com", "Duck Duck go", nil) - AssertExec(t, query, 1) + testutils.AssertExecAndRollback(t, query, db, 1) } func TestInsertModelsObject(t *testing.T) { - expectedSQL := ` -INSERT INTO test_sample.link (url, name) -VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial'), - ('http://www.google.com', 'Google'), - ('http://www.yahoo.com', 'Yahoo'); -` - tutorial := model.Link{ URL: "http://www.postgresqltutorial.com", Name: "PostgreSQL Tutorial", @@ -288,23 +267,20 @@ VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial'), INSERT(Link.URL, Link.Name). MODELS([]model.Link{tutorial, google, yahoo}) - testutils.AssertDebugStatementSql(t, stmt, expectedSQL, + testutils.AssertDebugStatementSql(t, stmt, ` +INSERT INTO test_sample.link (url, name) +VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial'), + ('http://www.google.com', 'Google'), + ('http://www.yahoo.com', 'Yahoo'); +`, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", "http://www.google.com", "Google", "http://www.yahoo.com", "Yahoo") - AssertExec(t, stmt, 3) + testutils.AssertExecAndRollback(t, stmt, db, 3) } func TestInsertUsingMutableColumns(t *testing.T) { - var expectedSQL = ` -INSERT INTO test_sample.link (url, name, description) -VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT), - ('http://www.google.com', 'Google', NULL), - ('http://www.google.com', 'Google', NULL), - ('http://www.yahoo.com', 'Yahoo', NULL); -` - google := model.Link{ URL: "http://www.google.com", Name: "Google", @@ -321,22 +297,32 @@ VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT), MODEL(google). MODELS([]model.Link{google, yahoo}) - testutils.AssertDebugStatementSql(t, stmt, expectedSQL, + testutils.AssertDebugStatementSql(t, stmt, ` +INSERT INTO test_sample.link (url, name, description) +VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT), + ('http://www.google.com', 'Google', NULL), + ('http://www.google.com', 'Google', NULL), + ('http://www.yahoo.com', 'Yahoo', NULL); +`, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", "http://www.google.com", "Google", nil, "http://www.google.com", "Google", nil, "http://www.yahoo.com", "Yahoo", nil) - AssertExec(t, stmt, 4) + testutils.AssertExecAndRollback(t, stmt, db, 4) } func TestInsertQuery(t *testing.T) { - _, err := Link.DELETE(). - WHERE(Link.ID.NOT_EQ(Int(0)).AND(Link.Name.EQ(String("Youtube")))). - Exec(db) - require.NoError(t, err) + query := Link. + INSERT(Link.URL, Link.Name). + QUERY( + SELECT(Link.URL, Link.Name). + FROM(Link). + WHERE(Link.ID.EQ(Int(0))), + ). + RETURNING(Link.AllColumns) - var expectedSQL = ` + testutils.AssertDebugStatementSql(t, query, ` INSERT INTO test_sample.link (url, name) ( SELECT link.url AS "link.url", link.name AS "link.name" @@ -347,38 +333,26 @@ RETURNING link.id AS "link.id", link.url AS "link.url", link.name AS "link.name", link.description AS "link.description"; -` +`, int64(0)) - query := Link. - INSERT(Link.URL, Link.Name). - QUERY( - SELECT(Link.URL, Link.Name). - FROM(Link). - WHERE(Link.ID.EQ(Int(0))), - ). - RETURNING(Link.AllColumns) + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + var dest []model.Link - testutils.AssertDebugStatementSql(t, query, expectedSQL, int64(0)) + err := query.Query(tx, &dest) + require.NoError(t, err) - dest := []model.Link{} + var youtubeLinks []model.Link + err = Link. + SELECT(Link.AllColumns). + WHERE(Link.Name.EQ(String("Youtube"))). + Query(tx, &youtubeLinks) - err = query.Query(db, &dest) - - require.NoError(t, err) - - youtubeLinks := []model.Link{} - err = Link. - SELECT(Link.AllColumns). - WHERE(Link.Name.EQ(String("Youtube"))). - Query(db, &youtubeLinks) - - require.NoError(t, err) - require.Equal(t, len(youtubeLinks), 2) + require.NoError(t, err) + require.Equal(t, len(youtubeLinks), 2) + }) } func TestInsertWithQueryContext(t *testing.T) { - cleanUpLinkTable(t) - stmt := Link.INSERT(). VALUES(1100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). RETURNING(Link.AllColumns) @@ -388,15 +362,15 @@ func TestInsertWithQueryContext(t *testing.T) { time.Sleep(10 * time.Millisecond) - dest := []model.Link{} - err := stmt.QueryContext(ctx, db, &dest) + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + dest := []model.Link{} + err := stmt.QueryContext(ctx, tx, &dest) - require.Error(t, err, "context deadline exceeded") + require.Error(t, err, "context deadline exceeded") + }) } func TestInsertWithExecContext(t *testing.T) { - cleanUpLinkTable(t) - stmt := Link.INSERT(). VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT) @@ -405,7 +379,7 @@ func TestInsertWithExecContext(t *testing.T) { time.Sleep(10 * time.Millisecond) - _, err := stmt.ExecContext(ctx, db) - - require.Error(t, err, "context deadline exceeded") + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + testutils.AssertExecContextErr(t, stmt, ctx, tx, "context deadline exceeded") + }) } diff --git a/tests/postgres/lock_test.go b/tests/postgres/lock_test.go index c028629..4caed10 100644 --- a/tests/postgres/lock_test.go +++ b/tests/postgres/lock_test.go @@ -12,6 +12,8 @@ import ( ) func TestLockTable(t *testing.T) { + skipForCockroachDB(t) // doesn't support + expectedSQL := ` LOCK TABLE dvds.address IN` @@ -62,6 +64,8 @@ LOCK TABLE dvds.address IN` } func TestLockExecContext(t *testing.T) { + skipForCockroachDB(t) + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Microsecond) defer cancel() diff --git a/tests/postgres/main_test.go b/tests/postgres/main_test.go index aa05e0f..08af67a 100644 --- a/tests/postgres/main_test.go +++ b/tests/postgres/main_test.go @@ -25,6 +25,24 @@ import ( var db *sql.DB var testRoot string +var source string + +const CockroachDB = "COCKROACH_DB" + +func init() { + source = os.Getenv("PG_SOURCE") +} + +func sourceIsCockroachDB() bool { + return source == CockroachDB +} + +func skipForCockroachDB(t *testing.T) { + if sourceIsCockroachDB() { + t.SkipNow() + } +} + func TestMain(m *testing.M) { rand.Seed(time.Now().Unix()) defer profile.Start().Stop() @@ -35,8 +53,15 @@ func TestMain(m *testing.M) { fmt.Printf("\nRunning postgres tests for '%s' driver\n", driverName) func() { + + connectionString := dbconfig.PostgresConnectString + + if sourceIsCockroachDB() { + connectionString = dbconfig.CockroachConnectString + } + var err error - db, err = sql.Open(driverName, dbconfig.PostgresConnectString) + db, err = sql.Open(driverName, connectionString) if err != nil { fmt.Println(err.Error()) panic("Failed to connect to test db") @@ -113,9 +138,3 @@ func isPgxDriver() bool { return false } - -func beginTx(t *testing.T) *sql.Tx { - tx, err := db.Begin() - require.NoError(t, err) - return tx -} diff --git a/tests/postgres/raw_statements_test.go b/tests/postgres/raw_statements_test.go index 4bbf90c..e201c75 100644 --- a/tests/postgres/raw_statements_test.go +++ b/tests/postgres/raw_statements_test.go @@ -2,6 +2,7 @@ package postgres import ( "context" + "database/sql" "testing" "time" @@ -85,12 +86,10 @@ func TestRawStatementSelectWithArguments(t *testing.T) { } func TestRawInsert(t *testing.T) { - cleanUpLinkTable(t) - stmt := RawStatement(` INSERT INTO test_sample.link (id, url, name, description) VALUES (@id1, @url1, @name1, DEFAULT), - (200, @url1, @name1, NULL), + (2000, @url1, @name1, NULL), (@id2, @url2, @name2, DEFAULT), (@id3, @url3, @name3, NULL) RETURNING link.id AS "link.id", @@ -98,45 +97,47 @@ RETURNING link.id AS "link.id", link.name AS "link.name", link.description AS "link.description"`, RawArgs{ - "@id1": 100, "@url1": "http://www.postgresqltutorial.com", "@name1": "PostgreSQL Tutorial", - "@id2": 101, "@url2": "http://www.google.com", "@name2": "Google", - "@id3": 102, "@url3": "http://www.yahoo.com", "@name3": "Yahoo", + "@id1": 1000, "@url1": "http://www.postgresqltutorial.com", "@name1": "PostgreSQL Tutorial", + "@id2": 1010, "@url2": "http://www.google.com", "@name2": "Google", + "@id3": 1020, "@url3": "http://www.yahoo.com", "@name3": "Yahoo", }) testutils.AssertStatementSql(t, stmt, ` INSERT INTO test_sample.link (id, url, name, description) VALUES ($1, $2, $3, DEFAULT), - (200, $2, $3, NULL), + (2000, $2, $3, NULL), ($4, $5, $6, DEFAULT), ($7, $8, $9, NULL) RETURNING link.id AS "link.id", link.url AS "link.url", link.name AS "link.name", link.description AS "link.description"; -`, 100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", - 101, "http://www.google.com", "Google", - 102, "http://www.yahoo.com", "Yahoo") +`, 1000, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", + 1010, "http://www.google.com", "Google", + 1020, "http://www.yahoo.com", "Yahoo") testutils.AssertDebugStatementSql(t, stmt, ` INSERT INTO test_sample.link (id, url, name, description) -VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT), - (200, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', NULL), - (101, 'http://www.google.com', 'Google', DEFAULT), - (102, 'http://www.yahoo.com', 'Yahoo', NULL) +VALUES (1000, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT), + (2000, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', NULL), + (1010, 'http://www.google.com', 'Google', DEFAULT), + (1020, 'http://www.yahoo.com', 'Yahoo', NULL) RETURNING link.id AS "link.id", link.url AS "link.url", link.name AS "link.name", link.description AS "link.description"; `) - var links []model2.Link - err := stmt.Query(db, &links) - require.NoError(t, err) - require.Len(t, links, 4) - require.Equal(t, links[0].ID, int32(100)) - require.Equal(t, links[1].URL, "http://www.postgresqltutorial.com") - require.Equal(t, links[2].Name, "Google") - require.Nil(t, links[2].Description) + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + var links []model2.Link + err := stmt.Query(tx, &links) + require.NoError(t, err) + require.Len(t, links, 4) + require.Equal(t, links[0].ID, int64(1000)) + require.Equal(t, links[1].URL, "http://www.postgresqltutorial.com") + require.Equal(t, links[2].Name, "Google") + require.Nil(t, links[2].Description) + }) } func TestRawStatementRows(t *testing.T) { diff --git a/tests/postgres/sample_test.go b/tests/postgres/sample_test.go index f1d9999..a13a30b 100644 --- a/tests/postgres/sample_test.go +++ b/tests/postgres/sample_test.go @@ -1,9 +1,9 @@ package postgres import ( + "github.com/google/uuid" "testing" - "github.com/google/uuid" "github.com/stretchr/testify/require" "github.com/go-jet/jet/v2/internal/testutils" @@ -14,30 +14,6 @@ import ( "github.com/shopspring/decimal" ) -func TestUUIDType(t *testing.T) { - - id := uuid.MustParse("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11") - - query := AllTypes. - SELECT(AllTypes.UUID, AllTypes.UUIDPtr). - WHERE(AllTypes.UUID.EQ(UUID(id))) - - testutils.AssertDebugStatementSql(t, query, ` -SELECT all_types.uuid AS "all_types.uuid", - all_types.uuid_ptr AS "all_types.uuid_ptr" -FROM test_sample.all_types -WHERE all_types.uuid = 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11'; -`, "a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11") - - result := model.AllTypes{} - - err := query.Query(db, &result) - require.NoError(t, err) - require.Equal(t, result.UUID, uuid.MustParse("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11")) - testutils.AssertDeepEqual(t, result.UUIDPtr, testutils.UUIDPtr("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11")) - requireLogged(t, query) -} - func TestExactDecimals(t *testing.T) { type floats struct { @@ -80,7 +56,7 @@ func TestExactDecimals(t *testing.T) { t.Run("should insert decimal", func(t *testing.T) { insertQuery := Floats.INSERT( - Floats.AllColumns, + Floats.MutableColumns, ).MODEL( floats{ Floats: model.Floats{ @@ -102,7 +78,7 @@ func TestExactDecimals(t *testing.T) { DecimalPtr: decimal.RequireFromString("3.3333333333333333333"), }, ).RETURNING( - Floats.AllColumns, + Floats.MutableColumns, ) testutils.AssertDebugStatementSql(t, insertQuery, ` @@ -199,7 +175,9 @@ func TestUUIDComplex(t *testing.T) { }) t.Run("single struct", func(t *testing.T) { - singleQuery := query.WHERE(Person.PersonID.EQ(String("b68dbff6-a87d-11e9-a7f2-98ded00c39c8"))) + uuid, err := uuid.Parse("b68dbff6-a87d-11e9-a7f2-98ded00c39c8") + require.NoError(t, err) + singleQuery := query.WHERE(Person.PersonID.EQ(UUID(uuid))) var dest struct { model.Person @@ -207,7 +185,7 @@ func TestUUIDComplex(t *testing.T) { model.PersonPhone } } - err := singleQuery.Query(db, &dest) + err = singleQuery.Query(db, &dest) require.NoError(t, err) testutils.AssertJSON(t, dest, ` @@ -304,7 +282,7 @@ SELECT person.person_id AS "person.person_id", FROM test_sample.person; `) - result := []model.Person{} + var result []model.Person err := query.Query(db, &result) @@ -333,7 +311,7 @@ FROM test_sample.person; `) } -func TestSelecSelfJoin1(t *testing.T) { +func TestSelectSelfJoin1(t *testing.T) { // clean up _, err := Employee.DELETE().WHERE(Employee.EmployeeID.GT(Int(100))).Exec(db) @@ -398,7 +376,7 @@ ORDER BY employee.employee_id; } func TestWierdNamesTable(t *testing.T) { - stmt := WeirdNamesTable.SELECT(WeirdNamesTable.AllColumns) + stmt := WeirdNamesTable.SELECT(WeirdNamesTable.MutableColumns) testutils.AssertDebugStatementSql(t, stmt, ` SELECT "WEIRD NAMES TABLE".weird_column_name1 AS "WEIRD NAMES TABLE.weird_column_name1", @@ -420,7 +398,7 @@ SELECT "WEIRD NAMES TABLE".weird_column_name1 AS "WEIRD NAMES TABLE.weird_column FROM test_sample."WEIRD NAMES TABLE"; `) - dest := []model.WeirdNamesTable{} + var dest []model.WeirdNamesTable err := stmt.Query(db, &dest) @@ -448,7 +426,7 @@ FROM test_sample."WEIRD NAMES TABLE"; } func TestReserwedWordEscape(t *testing.T) { - stmt := SELECT(User.AllColumns). + stmt := SELECT(User.MutableColumns). FROM(User) //fmt.Println(stmt.DebugSql()) @@ -480,6 +458,7 @@ FROM test_sample."User"; testutils.AssertJSON(t, dest, ` [ { + "ID": 0, "Column": "Column", "Check": "CHECK", "Ceil": "CEIL", @@ -497,54 +476,3 @@ FROM test_sample."User"; ] `) } - -func TestBytea(t *testing.T) { - byteArrHex := "\\x48656c6c6f20476f7068657221" - byteArrBin := []byte("\x48\x65\x6c\x6c\x6f\x20\x47\x6f\x70\x68\x65\x72\x21") - - insertStmt := AllTypes.INSERT(AllTypes.Bytea, AllTypes.ByteaPtr). - VALUES(byteArrHex, byteArrBin). - RETURNING(AllTypes.Bytea, AllTypes.ByteaPtr) - - testutils.AssertStatementSql(t, insertStmt, ` -INSERT INTO test_sample.all_types (bytea, bytea_ptr) -VALUES ($1, $2) -RETURNING all_types.bytea AS "all_types.bytea", - all_types.bytea_ptr AS "all_types.bytea_ptr"; -`, byteArrHex, byteArrBin) - - var inserted model.AllTypes - err := insertStmt.Query(db, &inserted) - require.NoError(t, err) - - require.Equal(t, string(*inserted.ByteaPtr), "Hello Gopher!") - // It is not possible to initiate bytea column using hex format '\xDEADBEEF' with pq driver. - // pq driver always encodes parameter string if destination column is of type bytea. - // Probably pq driver error. - // require.Equal(t, string(inserted.Bytea), "Hello Gopher!") - - stmt := SELECT( - AllTypes.Bytea, - AllTypes.ByteaPtr, - ).FROM( - AllTypes, - ).WHERE( - AllTypes.ByteaPtr.EQ(Bytea(byteArrBin)), - ) - - testutils.AssertStatementSql(t, stmt, ` -SELECT all_types.bytea AS "all_types.bytea", - all_types.bytea_ptr AS "all_types.bytea_ptr" -FROM test_sample.all_types -WHERE all_types.bytea_ptr = $1::bytea; -`, byteArrBin) - - var dest model.AllTypes - - err = stmt.Query(db, &dest) - require.NoError(t, err) - - require.Equal(t, string(*dest.ByteaPtr), "Hello Gopher!") - // Probably pq driver error. - // require.Equal(t, string(dest.Bytea), "Hello Gopher!") -} diff --git a/tests/postgres/select_test.go b/tests/postgres/select_test.go index 304acb8..b52508f 100644 --- a/tests/postgres/select_test.go +++ b/tests/postgres/select_test.go @@ -416,8 +416,8 @@ FROM dvds.city INNER JOIN dvds.address ON (address.city_id = city.city_id) INNER JOIN dvds.customer ON (customer.address_id = address.address_id) WHERE ( - (city.city = 'London') - OR (city.city = 'York') + (city.city = 'London'::text) + OR (city.city = 'York'::text) ) ORDER BY city.city_id, address.address_id, customer.customer_id; `, "London", "York") @@ -492,7 +492,7 @@ SELECT city.city_id AS "my_city.id", FROM dvds.city INNER JOIN dvds.address ON (address.city_id = city.city_id) INNER JOIN dvds.customer ON (customer.address_id = address.address_id) -WHERE (city.city = 'London') OR (city.city = 'York') +WHERE (city.city = 'London'::text) OR (city.city = 'York'::text) ORDER BY city.city_id, address.address_id, customer.customer_id; `, "London", "York") @@ -550,7 +550,7 @@ SELECT city.city_id AS "city_id", FROM dvds.city INNER JOIN dvds.address ON (address.city_id = city.city_id) INNER JOIN dvds.customer ON (customer.address_id = address.address_id) -WHERE (city.city = 'London') OR (city.city = 'York') +WHERE (city.city = 'London'::text) OR (city.city = 'York'::text) ORDER BY city.city_id, address.address_id, customer.customer_id; `, "London", "York") @@ -607,7 +607,7 @@ SELECT city.city_id AS "city.city_id", FROM dvds.city INNER JOIN dvds.address ON (address.city_id = city.city_id) INNER JOIN dvds.customer ON (customer.address_id = address.address_id) -WHERE (city.city = 'London') OR (city.city = 'York') +WHERE (city.city = 'London'::text) OR (city.city = 'York'::text) ORDER BY city.city_id, address.address_id, customer.customer_id; `, "London", "York") @@ -685,9 +685,6 @@ func TestSelect_WithoutUniqueColumnSelected(t *testing.T) { err := query.Query(db, &customers) require.NoError(t, err) - - //spew.Dump(customers) - require.Equal(t, len(customers), 599) } @@ -770,27 +767,35 @@ ORDER BY customer.customer_id ASC; testutils.AssertDebugStatementSql(t, query, expectedSQL) - allCustomersAndAddress := []struct { + var allCustomersAndAddress []struct { Address *model.Address Customer *model.Customer - }{} + } err := query.Query(db, &allCustomersAndAddress) require.NoError(t, err) require.Equal(t, len(allCustomersAndAddress), 603) - testutils.AssertDeepEqual(t, allCustomersAndAddress[0].Customer, &customer0) - require.True(t, allCustomersAndAddress[0].Address != nil) + if sourceIsCockroachDB() { + nullsFirst := allCustomersAndAddress[0] + require.True(t, nullsFirst.Customer == nil) + require.True(t, nullsFirst.Address != nil) - lastCustomerAddress := allCustomersAndAddress[len(allCustomersAndAddress)-1] + testutils.AssertDeepEqual(t, allCustomersAndAddress[4].Customer, &customer0) + require.True(t, allCustomersAndAddress[0].Address != nil) + } else { // postgres + testutils.AssertDeepEqual(t, allCustomersAndAddress[0].Customer, &customer0) + require.True(t, allCustomersAndAddress[0].Address != nil) - require.True(t, lastCustomerAddress.Customer == nil) - require.True(t, lastCustomerAddress.Address != nil) + nullsLast := allCustomersAndAddress[len(allCustomersAndAddress)-1] + require.True(t, nullsLast.Customer == nil) + require.True(t, nullsLast.Address != nil) + } } -func TestSelectFullCrossJoin(t *testing.T) { +func TestSelectCrossJoin(t *testing.T) { expectedSQL := ` SELECT customer.customer_id AS "customer.customer_id", customer.store_id AS "customer.store_id", @@ -1128,6 +1133,7 @@ ORDER BY film.film_id ASC; } func TestSelectGroupByHaving(t *testing.T) { + expectedSQL := ` SELECT customer.customer_id AS "customer.customer_id", customer.store_id AS "customer.store_id", @@ -1197,6 +1203,9 @@ ORDER BY customer.customer_id, SUM(payment.amount) ASC; require.Equal(t, len(dest), 104) + if sourceIsCockroachDB() { + return // small precision difference in result + } //testutils.SaveJsonFile(dest, "postgres/testdata/customer_payment_sum.json") testutils.AssertJSONFile(t, dest, "./testdata/results/postgres/customer_payment_sum.json") } @@ -1395,9 +1404,6 @@ ORDER BY payment.payment_date ASC; err := query.Query(db, &payments) require.NoError(t, err) - - //spew.Dump(payments) - require.Equal(t, len(payments), 9) testutils.AssertDeepEqual(t, payments[0], model.Payment{ PaymentID: 17793, @@ -1531,7 +1537,7 @@ func TestAllSetOperators(t *testing.T) { func TestSelectWithCase(t *testing.T) { expectedQuery := ` -SELECT (CASE payment.staff_id WHEN 1 THEN 'ONE' WHEN 2 THEN 'TWO' WHEN 3 THEN 'THREE' ELSE 'OTHER' END) AS "staff_id_num" +SELECT (CASE payment.staff_id WHEN 1 THEN 'ONE'::text WHEN 2 THEN 'TWO'::text WHEN 3 THEN 'THREE'::text ELSE 'OTHER'::text END) AS "staff_id_num" FROM dvds.payment ORDER BY payment.payment_id ASC LIMIT 20; @@ -1611,6 +1617,10 @@ FOR` require.NoError(t, err) } + if sourceIsCockroachDB() { + return // SKIP LOCKED lock wait policy is not supported + } + for lockType, lockTypeStr := range getRowLockTestData() { query.FOR(lockType.SKIP_LOCKED()) @@ -1660,7 +1670,7 @@ FROM dvds.actor INNER JOIN dvds.language ON (language.language_id = film.language_id) INNER JOIN dvds.film_category ON (film_category.film_id = film.film_id) INNER JOIN dvds.category ON (category.category_id = film_category.category_id) -WHERE ((language.name = 'English') AND (category.name != 'Action')) AND (film.length > 180) +WHERE ((language.name = 'English'::text) AND (category.name != 'Action'::text)) AND (film.length > 180) ORDER BY actor.actor_id ASC, film.film_id ASC; ` @@ -1927,10 +1937,11 @@ func TestSimpleView(t *testing.T) { query := SELECT( view.ActorInfo.AllColumns, - ). - FROM(view.ActorInfo). - ORDER_BY(view.ActorInfo.ActorID). - LIMIT(10) + ).FROM( + view.ActorInfo, + ).ORDER_BY( + view.ActorInfo.ActorID, + ).LIMIT(10) type ActorInfo struct { ActorID int @@ -1944,6 +1955,10 @@ func TestSimpleView(t *testing.T) { err := query.Query(db, &dest) require.NoError(t, err) + if sourceIsCockroachDB() { + return // skip for cockroach db, FilmInfo is set to '' in ddl + } + testutils.AssertJSON(t, dest[1:2], ` [ { @@ -2117,7 +2132,7 @@ FROM dvds.film language.name AS "language.name", language.last_update AS "language.last_update" FROM dvds.language - WHERE (language.name NOT IN ('spanish')) AND (film.language_id = language.language_id) + WHERE (language.name NOT IN ('spanish'::text)) AND (film.language_id = language.language_id) ) AS films WHERE film.film_id = 1 ORDER BY film.film_id @@ -2162,7 +2177,7 @@ FROM dvds.film, language.name AS "language.name", language.last_update AS "language.last_update" FROM dvds.language - WHERE (language.name NOT IN ('spanish')) AND (film.language_id = language.language_id) + WHERE (language.name NOT IN ('spanish'::text)) AND (film.language_id = language.language_id) ) AS films WHERE film.film_id = 1 ORDER BY film.film_id @@ -2630,6 +2645,8 @@ func GET_FILM_COUNT(lenFrom, lenTo IntegerExpression) IntegerExpression { } func TestCustomFunctionCall(t *testing.T) { + skipForCockroachDB(t) + stmt := SELECT( GET_FILM_COUNT(Int(100), Int(120)).AS("film_count"), ) @@ -2662,3 +2679,42 @@ SELECT dvds.get_film_count(100, 120) AS "film_count"; require.NoError(t, err) require.Equal(t, dest.FilmCount, 165) } + +var customer0 = model.Customer{ + CustomerID: 1, + StoreID: 1, + FirstName: "Mary", + LastName: "Smith", + Email: testutils.StringPtr("mary.smith@sakilacustomer.org"), + AddressID: 5, + Activebool: true, + CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0), + LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 3), + Active: testutils.Int32Ptr(1), +} + +var customer1 = model.Customer{ + CustomerID: 2, + StoreID: 1, + FirstName: "Patricia", + LastName: "Johnson", + Email: testutils.StringPtr("patricia.johnson@sakilacustomer.org"), + AddressID: 6, + Activebool: true, + CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0), + LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 3), + Active: testutils.Int32Ptr(1), +} + +var lastCustomer = model.Customer{ + CustomerID: 599, + StoreID: 2, + FirstName: "Austin", + LastName: "Cintron", + Email: testutils.StringPtr("austin.cintron@sakilacustomer.org"), + AddressID: 605, + Activebool: true, + CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0), + LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 3), + Active: testutils.Int32Ptr(1), +} diff --git a/tests/postgres/update_test.go b/tests/postgres/update_test.go index 6cde276..975b684 100644 --- a/tests/postgres/update_test.go +++ b/tests/postgres/update_test.go @@ -2,6 +2,7 @@ package postgres import ( "context" + "database/sql" "github.com/go-jet/jet/v2/internal/testutils" . "github.com/go-jet/jet/v2/postgres" model2 "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/model" @@ -14,9 +15,7 @@ import ( ) func TestUpdateValues(t *testing.T) { - setupLinkTableForUpdateTest(t) - - t.Run("deprecated version", func(t *testing.T) { + t.Run("deprecated update", func(t *testing.T) { query := Link. UPDATE(Link.Name, Link.URL). SET("Bong", "http://bong.com"). @@ -25,31 +24,34 @@ func TestUpdateValues(t *testing.T) { testutils.AssertDebugStatementSql(t, query, ` UPDATE test_sample.link SET (name, url) = ('Bong', 'http://bong.com') -WHERE link.name = 'Bing'; +WHERE link.name = 'Bing'::text; `, "Bong", "http://bong.com", "Bing") - testutils.AssertExec(t, query, db, 1) - requireLogged(t, query) + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { - links := []model.Link{} + testutils.AssertExec(t, query, tx, 1) + requireLogged(t, query) - selQuery := Link. - SELECT(Link.AllColumns). - WHERE(Link.Name.IN(String("Bong"))) + var links []model.Link - err := selQuery.Query(db, &links) + selQuery := Link. + SELECT(Link.AllColumns). + WHERE(Link.Name.IN(String("Bong"))) - require.NoError(t, err) - require.Equal(t, len(links), 1) - testutils.AssertDeepEqual(t, links[0], model.Link{ - ID: 204, - URL: "http://bong.com", - Name: "Bong", + err := selQuery.Query(tx, &links) + + require.NoError(t, err) + require.Equal(t, len(links), 1) + testutils.AssertDeepEqual(t, links[0], model.Link{ + ID: 204, + URL: "http://bong.com", + Name: "Bong", + }) + requireLogged(t, selQuery) }) - requireLogged(t, selQuery) }) - t.Run("new version", func(t *testing.T) { + t.Run("new type safe update", func(t *testing.T) { stmt := Link.UPDATE(). SET( Link.Name.SET(String("DuckDuckGo")), @@ -59,18 +61,16 @@ WHERE link.name = 'Bing'; testutils.AssertDebugStatementSql(t, stmt, ` UPDATE test_sample.link -SET name = 'DuckDuckGo', - url = 'www.duckduckgo.com' -WHERE link.name = 'Yahoo'; +SET name = 'DuckDuckGo'::text, + url = 'www.duckduckgo.com'::text +WHERE link.name = 'Yahoo'::text; `) - testutils.AssertExec(t, stmt, db, 1) + testutils.AssertExecAndRollback(t, stmt, db, 1) requireLogged(t, stmt) }) } func TestUpdateWithSubQueries(t *testing.T) { - setupLinkTableForUpdateTest(t) - t.Run("deprecated version", func(t *testing.T) { query := Link. UPDATE(Link.Name, Link.URL). @@ -82,20 +82,19 @@ func TestUpdateWithSubQueries(t *testing.T) { ). WHERE(Link.Name.EQ(String("Bing"))) - expectedSQL := ` + testutils.AssertDebugStatementSql(t, query, ` UPDATE test_sample.link SET (name, url) = (( - SELECT 'Bong' + SELECT 'Bong'::text ), ( SELECT link.url AS "link.url" FROM test_sample.link - WHERE link.name = 'Bing' + WHERE link.name = 'Bing'::text )) -WHERE link.name = 'Bing'; -` - testutils.AssertDebugStatementSql(t, query, expectedSQL, "Bong", "Bing", "Bing") +WHERE link.name = 'Bing'::text; +`, "Bong", "Bing", "Bing") - AssertExec(t, query, 1) + testutils.AssertExecAndRollback(t, query, db, 1) requireLogged(t, query) }) @@ -113,50 +112,48 @@ WHERE link.name = 'Bing'; testutils.AssertStatementSql(t, query, ` UPDATE test_sample.link -SET name = $1, +SET name = $1::text, url = ( SELECT link.url AS "link.url" FROM test_sample.link - WHERE link.name = $2 + WHERE link.name = $2::text ) -WHERE link.name = $3; +WHERE link.name = $3::text; `, "Bong", "Bing", "Bing") - _, err := query.Exec(db) - require.NoError(t, err) + testutils.AssertExecAndRollback(t, query, db) requireLogged(t, query) }) } func TestUpdateAndReturning(t *testing.T) { - setupLinkTableForUpdateTest(t) - - expectedSQL := ` -UPDATE test_sample.link -SET (name, url) = ('DuckDuckGo', 'http://www.duckduckgo.com') -WHERE link.name = 'Ask' -RETURNING link.id AS "link.id", - link.url AS "link.url", - link.name AS "link.name", - link.description AS "link.description"; -` - stmt := Link. UPDATE(Link.Name, Link.URL). SET("DuckDuckGo", "http://www.duckduckgo.com"). WHERE(Link.Name.EQ(String("Ask"))). RETURNING(Link.AllColumns) - testutils.AssertDebugStatementSql(t, stmt, expectedSQL, "DuckDuckGo", "http://www.duckduckgo.com", "Ask") + testutils.AssertDebugStatementSql(t, stmt, ` +UPDATE test_sample.link +SET (name, url) = ('DuckDuckGo', 'http://www.duckduckgo.com') +WHERE link.name = 'Ask'::text +RETURNING link.id AS "link.id", + link.url AS "link.url", + link.name AS "link.name", + link.description AS "link.description"; +`, "DuckDuckGo", "http://www.duckduckgo.com", "Ask") - links := []model.Link{} + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + links := []model.Link{} - err := stmt.Query(db, &links) + err := stmt.Query(tx, &links) + + require.NoError(t, err) + require.Equal(t, len(links), 2) + require.Equal(t, links[0].Name, "DuckDuckGo") + require.Equal(t, links[1].Name, "DuckDuckGo") + requireLogged(t, stmt) + }) - require.NoError(t, err) - require.Equal(t, len(links), 2) - require.Equal(t, links[0].Name, "DuckDuckGo") - require.Equal(t, links[1].Name, "DuckDuckGo") - requireLogged(t, stmt) } func TestUpdateWithSelect(t *testing.T) { @@ -170,7 +167,7 @@ func TestUpdateWithSelect(t *testing.T) { ). WHERE(Link.ID.EQ(Int(0))) - expectedSQL := ` + testutils.AssertDebugStatementSql(t, stmt, ` UPDATE test_sample.link SET (id, url, name, description) = ( SELECT link.id AS "link.id", @@ -181,10 +178,9 @@ SET (id, url, name, description) = ( WHERE link.id = 0 ) WHERE link.id = 0; -` - testutils.AssertDebugStatementSql(t, stmt, expectedSQL, int64(0), int64(0)) +`, int64(0), int64(0)) - AssertExec(t, stmt, 1) + testutils.AssertExecAndRollback(t, stmt, db, 1) }) t.Run("new version", func(t *testing.T) { @@ -210,12 +206,11 @@ SET (url, name, description) = ( WHERE link.id = 0; `, int64(0), int64(0)) - AssertExec(t, stmt, 1) + testutils.AssertExecAndRollback(t, stmt, db, 1) }) } func TestUpdateWithInvalidSelect(t *testing.T) { - t.Run("deprecated version", func(t *testing.T) { stmt := Link.UPDATE(Link.AllColumns). SET( @@ -236,7 +231,6 @@ SET (id, url, name, description) = ( WHERE link.id = 0; ` testutils.AssertDebugStatementSql(t, stmt, expectedSQL, int64(0), int64(0)) - testutils.AssertExecErr(t, stmt, db, "pq: number of columns does not match number of values") }) @@ -250,8 +244,6 @@ WHERE link.id = 0; } func TestUpdateWithModelData(t *testing.T) { - setupLinkTableForUpdateTest(t) - link := model.Link{ ID: 201, URL: "http://www.duckduckgo.com", @@ -261,24 +253,20 @@ func TestUpdateWithModelData(t *testing.T) { stmt := Link. UPDATE(Link.AllColumns). MODEL(link). - WHERE(Link.ID.EQ(Int32(link.ID))) + WHERE(Link.ID.EQ(Int64(link.ID))) expectedSQL := ` UPDATE test_sample.link SET (id, url, name, description) = (201, 'http://www.duckduckgo.com', 'DuckDuckGo', NULL) -WHERE link.id = 201::integer; +WHERE link.id = 201::bigint; ` - testutils.AssertDebugStatementSql(t, stmt, expectedSQL, int32(201), "http://www.duckduckgo.com", "DuckDuckGo", nil, int32(201)) + testutils.AssertDebugStatementSql(t, stmt, expectedSQL, int64(201), "http://www.duckduckgo.com", "DuckDuckGo", nil, int64(201)) - _, err := stmt.Exec(db) - require.NoError(t, err) + testutils.AssertExecAndRollback(t, stmt, db, 1) requireQueryLogged(t, stmt, 1) } func TestUpdateWithModelDataAndPredefinedColumnList(t *testing.T) { - - setupLinkTableForUpdateTest(t) - link := model.Link{ ID: 201, URL: "http://www.duckduckgo.com", @@ -290,27 +278,24 @@ func TestUpdateWithModelDataAndPredefinedColumnList(t *testing.T) { stmt := Link. UPDATE(updateColumnList). MODEL(link). - WHERE(Link.ID.EQ(Int32(link.ID))) + WHERE(Link.ID.EQ(Int64(link.ID))) - var expectedSQL = ` + testutils.AssertDebugStatementSql(t, stmt, ` UPDATE test_sample.link SET (description, name, url) = (NULL, 'DuckDuckGo', 'http://www.duckduckgo.com') -WHERE link.id = 201::integer; -` - testutils.AssertDebugStatementSql(t, stmt, expectedSQL, nil, "DuckDuckGo", "http://www.duckduckgo.com", int32(201)) +WHERE link.id = 201::bigint; +`, + nil, "DuckDuckGo", "http://www.duckduckgo.com", int64(201)) - AssertExec(t, stmt, 1) + testutils.AssertExecAndRollback(t, stmt, db, 1) } func TestUpdateWithInvalidModelData(t *testing.T) { defer func() { r := recover() - require.Equal(t, r, "missing struct field for column : id") }() - setupLinkTableForUpdateTest(t) - link := struct { Ident int URL string @@ -323,24 +308,13 @@ func TestUpdateWithInvalidModelData(t *testing.T) { Name: "DuckDuckGo", } - stmt := Link. + _ = Link. UPDATE(Link.AllColumns). - MODEL(link). + MODEL(link). // panics WHERE(Link.ID.EQ(Int(int64(link.Ident)))) - - var expectedSQL = ` -UPDATE test_sample.link -SET (id, url, name, description, rel) = ('http://www.duckduckgo.com', 'DuckDuckGo', NULL, NULL) -WHERE link.id = 201; -` - testutils.AssertDebugStatementSql(t, stmt, expectedSQL, "http://www.duckduckgo.com", "DuckDuckGo", nil, nil, int64(201)) - - testutils.AssertExecErr(t, stmt, db, "pq: number of columns does not match number of values") } func TestUpdateQueryContext(t *testing.T) { - setupLinkTableForUpdateTest(t) - updateStmt := Link. UPDATE(Link.Name, Link.URL). SET("Bong", "http://bong.com"). @@ -351,15 +325,15 @@ func TestUpdateQueryContext(t *testing.T) { time.Sleep(10 * time.Millisecond) - dest := []model.Link{} - err := updateStmt.QueryContext(ctx, db, &dest) + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + dest := []model.Link{} + err := updateStmt.QueryContext(ctx, tx, &dest) - require.Error(t, err, "context deadline exceeded") + require.Error(t, err, "context deadline exceeded") + }) } func TestUpdateExecContext(t *testing.T) { - setupLinkTableForUpdateTest(t) - updateStmt := Link. UPDATE(Link.Name, Link.URL). SET("Bong", "http://bong.com"). @@ -370,15 +344,10 @@ func TestUpdateExecContext(t *testing.T) { time.Sleep(10 * time.Millisecond) - _, err := updateStmt.ExecContext(ctx, db) - - require.Error(t, err, "context deadline exceeded") + testutils.AssertExecContextErr(t, updateStmt, ctx, db, "context deadline exceeded") } func TestUpdateFrom(t *testing.T) { - tx := beginTx(t) - defer tx.Rollback() - stmt := table.Rental.UPDATE(). SET( table.Rental.RentalDate.SET(Timestamp(2020, 2, 2, 0, 0, 0)), @@ -416,16 +385,17 @@ RETURNING rental.rental_id AS "rental.rental_id", store.address_id AS "store.address_id"; `) - var dest []struct { - Rental model2.Rental - Store model2.Store - } + testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) { + var dest []struct { + Rental model2.Rental + Store model2.Store + } - err := stmt.Query(tx, &dest) + err := stmt.Query(tx, &dest) - require.NoError(t, err) - require.Len(t, dest, 3) - testutils.AssertJSON(t, dest[0], ` + require.NoError(t, err) + require.Len(t, dest, 3) + testutils.AssertJSON(t, dest[0], ` { "Rental": { "RentalID": 4, @@ -444,24 +414,5 @@ RETURNING rental.rental_id AS "rental.rental_id", } } `) -} - -func setupLinkTableForUpdateTest(t *testing.T) { - - cleanUpLinkTable(t) - - _, err := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description). - VALUES(200, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). - VALUES(201, "http://www.ask.com", "Ask", DEFAULT). - VALUES(202, "http://www.ask.com", "Ask", DEFAULT). - VALUES(203, "http://www.yahoo.com", "Yahoo", DEFAULT). - VALUES(204, "http://www.bing.com", "Bing", DEFAULT). - Exec(db) - - require.NoError(t, err) -} - -func cleanUpLinkTable(t *testing.T) { - _, err := Link.DELETE().WHERE(Link.ID.GT(Int(0))).Exec(db) - require.NoError(t, err) + }) } diff --git a/tests/postgres/util_test.go b/tests/postgres/util_test.go deleted file mode 100644 index 847056f..0000000 --- a/tests/postgres/util_test.go +++ /dev/null @@ -1,57 +0,0 @@ -package postgres - -import ( - "github.com/go-jet/jet/v2/internal/jet" - "github.com/go-jet/jet/v2/internal/testutils" - "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/model" - "github.com/stretchr/testify/require" - "testing" -) - -func AssertExec(t *testing.T, stmt jet.Statement, rowsAffected int64) { - res, err := stmt.Exec(db) - - require.NoError(t, err) - rows, err := res.RowsAffected() - require.NoError(t, err) - require.Equal(t, rows, rowsAffected) -} - -var customer0 = model.Customer{ - CustomerID: 1, - StoreID: 1, - FirstName: "Mary", - LastName: "Smith", - Email: testutils.StringPtr("mary.smith@sakilacustomer.org"), - AddressID: 5, - Activebool: true, - CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0), - LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 3), - Active: testutils.Int32Ptr(1), -} - -var customer1 = model.Customer{ - CustomerID: 2, - StoreID: 1, - FirstName: "Patricia", - LastName: "Johnson", - Email: testutils.StringPtr("patricia.johnson@sakilacustomer.org"), - AddressID: 6, - Activebool: true, - CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0), - LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 3), - Active: testutils.Int32Ptr(1), -} - -var lastCustomer = model.Customer{ - CustomerID: 599, - StoreID: 2, - FirstName: "Austin", - LastName: "Cintron", - Email: testutils.StringPtr("austin.cintron@sakilacustomer.org"), - AddressID: 605, - Activebool: true, - CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0), - LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 3), - Active: testutils.Int32Ptr(1), -} diff --git a/tests/postgres/with_test.go b/tests/postgres/with_test.go index c78ca8a4..21fca32 100644 --- a/tests/postgres/with_test.go +++ b/tests/postgres/with_test.go @@ -106,9 +106,11 @@ func TestWithStatementDeleteAndInsert(t *testing.T) { removeDiscontinuedOrders.AS( OrderDetails.DELETE(). WHERE(OrderDetails.ProductID.IN( - SELECT(Products.ProductID). - FROM(Products). - WHERE(Products.Discontinued.EQ(Int(1)))), + SELECT( + Products.ProductID, + ).FROM( + Products, + ).WHERE(Products.Discontinued.EQ(Int(1)))), ).RETURNING(OrderDetails.ProductID), ), updateDiscontinuedPrice.AS( @@ -121,7 +123,13 @@ func TestWithStatementDeleteAndInsert(t *testing.T) { ), logDiscontinuedProducts.AS( ProductLogs.INSERT(ProductLogs.AllColumns). - QUERY(SELECT(updateDiscontinuedPrice.AllColumns()).FROM(updateDiscontinuedPrice)). + QUERY( + SELECT( + updateDiscontinuedPrice.AllColumns(), + ).FROM( + updateDiscontinuedPrice, + ), + ). RETURNING( ProductLogs.ProductID, ProductLogs.ProductName, @@ -384,7 +392,7 @@ WITH cte1 AS ( SELECT territories.territory_id AS "territories.territory_id", territories.territory_description AS "territories.territory_description", territories.region_id AS "territories.region_id", - $1 AS "custom_column_1" + $1::text AS "custom_column_1" FROM northwind.territories ORDER BY territories.territory_id ASC ),cte2 AS ( @@ -392,7 +400,7 @@ WITH cte1 AS ( cte1."territories.territory_description" AS "territories.territory_description", cte1."territories.region_id" AS "territories.region_id", cte1.custom_column_1 AS "custom_column_1", - $2 AS "custom_column_2" + $2::text AS "custom_column_2" FROM cte1 ) SELECT cte2."territories.territory_id" AS "territories.territory_id", @@ -485,7 +493,7 @@ func TestRecursiveWithStatement(t *testing.T) { Employees, ).WHERE( Employees.EmployeeID.EQ(Int(2)), - ).UNION( + ).UNION_ALL( SELECT( Employees.AllColumns, ).FROM( @@ -790,13 +798,13 @@ WITH suppliers_fax AS ( suppliers_fax."suppliers.contact_name" AS "suppliers.contact_name", suppliers_fax."suppliers.country" AS "suppliers.country" FROM suppliers_fax - WHERE suppliers_fax."suppliers.country" NOT IN ('US', 'Australia') + WHERE suppliers_fax."suppliers.country" NOT IN ('US'::text, 'Australia'::text) ) SELECT not_from_us_or_aus."suppliers.supplier_id" AS "suppliers.supplier_id", not_from_us_or_aus."suppliers.contact_name" AS "suppliers.contact_name", not_from_us_or_aus."suppliers.country" AS "suppliers.country" FROM not_from_us_or_aus -WHERE not_from_us_or_aus."suppliers.contact_name" != 'John'; +WHERE not_from_us_or_aus."suppliers.contact_name" != 'John'::text; `) var dest []model.Suppliers diff --git a/tests/sqlite/insert_test.go b/tests/sqlite/insert_test.go index f5939bb..e1b3a54 100644 --- a/tests/sqlite/insert_test.go +++ b/tests/sqlite/insert_test.go @@ -2,6 +2,7 @@ package sqlite import ( "context" + "database/sql" "math/rand" "testing" @@ -15,9 +16,6 @@ import ( ) func TestInsertValues(t *testing.T) { - tx := beginSampleDBTx(t) - defer tx.Rollback() - insertQuery := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description). VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil). VALUES(101, "http://www.google.com", "Google", "Search engine"). @@ -32,31 +30,32 @@ VALUES (?, ?, ?, ?), 101, "http://www.google.com", "Google", "Search engine", 102, "http://www.yahoo.com", "Yahoo", nil) - _, err := insertQuery.Exec(tx) - require.NoError(t, err) - requireLogged(t, insertQuery) + testutils.ExecuteInTxAndRollback(t, sampleDB, func(tx *sql.Tx) { + testutils.AssertExec(t, insertQuery, tx) + requireLogged(t, insertQuery) - insertedLinks := []model.Link{} + var insertedLinks []model.Link - err = SELECT(Link.AllColumns). - FROM(Link). - WHERE(Link.ID.GT_EQ(Int(100))). - ORDER_BY(Link.ID). - Query(tx, &insertedLinks) + err := SELECT(Link.AllColumns). + FROM(Link). + WHERE(Link.ID.GT_EQ(Int(100))). + ORDER_BY(Link.ID). + Query(tx, &insertedLinks) - require.NoError(t, err) - require.Equal(t, len(insertedLinks), 3) - testutils.AssertDeepEqual(t, insertedLinks[0], postgreTutorial) - testutils.AssertDeepEqual(t, insertedLinks[1], model.Link{ - ID: 101, - URL: "http://www.google.com", - Name: "Google", - Description: testutils.StringPtr("Search engine"), - }) - testutils.AssertDeepEqual(t, insertedLinks[2], model.Link{ - ID: 102, - URL: "http://www.yahoo.com", - Name: "Yahoo", + require.NoError(t, err) + require.Equal(t, len(insertedLinks), 3) + testutils.AssertDeepEqual(t, insertedLinks[0], postgreTutorial) + testutils.AssertDeepEqual(t, insertedLinks[1], model.Link{ + ID: 101, + URL: "http://www.google.com", + Name: "Google", + Description: testutils.StringPtr("Search engine"), + }) + testutils.AssertDeepEqual(t, insertedLinks[2], model.Link{ + ID: 102, + URL: "http://www.yahoo.com", + Name: "Yahoo", + }) }) } @@ -67,41 +66,35 @@ var postgreTutorial = model.Link{ } func TestInsertEmptyColumnList(t *testing.T) { - tx := beginSampleDBTx(t) - defer tx.Rollback() - - expectedSQL := ` -INSERT INTO link -VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', NULL); -` - stmt := Link.INSERT(). VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil) - testutils.AssertDebugStatementSql(t, stmt, expectedSQL, + testutils.AssertDebugStatementSql(t, stmt, ` +INSERT INTO link +VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', NULL); +`, 100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil) - _, err := stmt.Exec(tx) - require.NoError(t, err) - requireLogged(t, stmt) + testutils.ExecuteInTxAndRollback(t, sampleDB, func(tx *sql.Tx) { + _, err := stmt.Exec(tx) + require.NoError(t, err) + requireLogged(t, stmt) - insertedLinks := []model.Link{} + var insertedLinks []model.Link - err = SELECT(Link.AllColumns). - FROM(Link). - WHERE(Link.ID.GT_EQ(Int(100))). - ORDER_BY(Link.ID). - Query(tx, &insertedLinks) + err = SELECT(Link.AllColumns). + FROM(Link). + WHERE(Link.ID.GT_EQ(Int(100))). + ORDER_BY(Link.ID). + Query(tx, &insertedLinks) - require.NoError(t, err) - require.Equal(t, len(insertedLinks), 1) - testutils.AssertDeepEqual(t, insertedLinks[0], postgreTutorial) + require.NoError(t, err) + require.Equal(t, len(insertedLinks), 1) + testutils.AssertDeepEqual(t, insertedLinks[0], postgreTutorial) + }) } func TestInsertModelObject(t *testing.T) { - tx := beginSampleDBTx(t) - defer tx.Rollback() - linkData := model.Link{ URL: "http://www.duckduckgo.com", Name: "Duck Duck go", @@ -115,19 +108,13 @@ INSERT INTO link (url, name) VALUES ('http://www.duckduckgo.com', 'Duck Duck go'); `, "http://www.duckduckgo.com", "Duck Duck go") - _, err := query.Exec(tx) - require.NoError(t, err) + testutils.ExecuteInTxAndRollback(t, sampleDB, func(tx *sql.Tx) { + _, err := query.Exec(tx) + require.NoError(t, err) + }) } func TestInsertModelObjectEmptyColumnList(t *testing.T) { - tx := beginSampleDBTx(t) - defer tx.Rollback() - - var expectedSQL = ` -INSERT INTO link -VALUES (1000, 'http://www.duckduckgo.com', 'Duck Duck go', NULL); -` - linkData := model.Link{ ID: 1000, URL: "http://www.duckduckgo.com", @@ -138,23 +125,18 @@ VALUES (1000, 'http://www.duckduckgo.com', 'Duck Duck go', NULL); INSERT(). MODEL(linkData) - testutils.AssertDebugStatementSql(t, query, expectedSQL, int32(1000), "http://www.duckduckgo.com", "Duck Duck go", nil) + testutils.AssertDebugStatementSql(t, query, ` +INSERT INTO link +VALUES (1000, 'http://www.duckduckgo.com', 'Duck Duck go', NULL); +`, int32(1000), "http://www.duckduckgo.com", "Duck Duck go", nil) - _, err := query.Exec(tx) - require.NoError(t, err) + testutils.ExecuteInTxAndRollback(t, sampleDB, func(tx *sql.Tx) { + _, err := query.Exec(tx) + require.NoError(t, err) + }) } func TestInsertModelsObject(t *testing.T) { - tx := beginSampleDBTx(t) - defer tx.Rollback() - - expectedSQL := ` -INSERT INTO link (url, name) -VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial'), - ('http://www.google.com', 'Google'), - ('http://www.yahoo.com', 'Yahoo'); -` - tutorial := model.Link{ URL: "http://www.postgresqltutorial.com", Name: "PostgreSQL Tutorial", @@ -176,27 +158,20 @@ VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial'), yahoo, }) - testutils.AssertDebugStatementSql(t, query, expectedSQL, + testutils.AssertDebugStatementSql(t, query, ` +INSERT INTO link (url, name) +VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial'), + ('http://www.google.com', 'Google'), + ('http://www.yahoo.com', 'Yahoo'); +`, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", "http://www.google.com", "Google", "http://www.yahoo.com", "Yahoo") - _, err := query.Exec(tx) - require.NoError(t, err) + testutils.AssertExecAndRollback(t, query, sampleDB) } func TestInsertUsingMutableColumns(t *testing.T) { - tx := beginSampleDBTx(t) - defer tx.Rollback() - - var expectedSQL = ` -INSERT INTO link (url, name, description) -VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', NULL), - ('http://www.google.com', 'Google', NULL), - ('http://www.google.com', 'Google', NULL), - ('http://www.yahoo.com', 'Yahoo', NULL); -` - google := model.Link{ URL: "http://www.google.com", Name: "Google", @@ -213,20 +188,22 @@ VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', NULL), MODEL(google). MODELS([]model.Link{google, yahoo}) - testutils.AssertDebugStatementSql(t, stmt, expectedSQL, + testutils.AssertDebugStatementSql(t, stmt, ` +INSERT INTO link (url, name, description) +VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', NULL), + ('http://www.google.com', 'Google', NULL), + ('http://www.google.com', 'Google', NULL), + ('http://www.yahoo.com', 'Yahoo', NULL); +`, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil, "http://www.google.com", "Google", nil, "http://www.google.com", "Google", nil, "http://www.yahoo.com", "Yahoo", nil) - _, err := stmt.Exec(tx) - require.NoError(t, err) + testutils.AssertExecAndRollback(t, stmt, sampleDB) } func TestInsertQuery(t *testing.T) { - tx := beginSampleDBTx(t) - defer tx.Rollback() - var expectedSQL = ` INSERT INTO link (url, name) SELECT link.url AS "link.url", @@ -242,24 +219,22 @@ WHERE link.id = 24; ) testutils.AssertDebugStatementSql(t, query, expectedSQL, int64(24)) + testutils.ExecuteInTxAndRollback(t, sampleDB, func(tx *sql.Tx) { + _, err := query.Exec(tx) + require.NoError(t, err) - _, err := query.Exec(tx) - require.NoError(t, err) + var youtubeLinks []model.Link + err = Link. + SELECT(Link.AllColumns). + WHERE(Link.Name.EQ(String("Bing"))). + Query(tx, &youtubeLinks) - youtubeLinks := []model.Link{} - err = Link. - SELECT(Link.AllColumns). - WHERE(Link.Name.EQ(String("Bing"))). - Query(tx, &youtubeLinks) - - require.NoError(t, err) - require.Equal(t, len(youtubeLinks), 2) + require.NoError(t, err) + require.Equal(t, len(youtubeLinks), 2) + }) } func TestInsert_DEFAULT_VALUES_RETURNING(t *testing.T) { - tx := beginSampleDBTx(t) - defer tx.Rollback() - stmt := Link.INSERT(). DEFAULT_VALUES(). RETURNING(Link.AllColumns) @@ -273,24 +248,23 @@ RETURNING link.id AS "link.id", link.description AS "link.description"; `) - var link model.Link - err := stmt.Query(tx, &link) - require.NoError(t, err) + testutils.ExecuteInTxAndRollback(t, sampleDB, func(tx *sql.Tx) { + var link model.Link + err := stmt.Query(tx, &link) + require.NoError(t, err) - require.EqualValues(t, link, model.Link{ - ID: 25, - URL: "www.", - Name: "_", - Description: nil, + require.EqualValues(t, link, model.Link{ + ID: 25, + URL: "www.", + Name: "_", + Description: nil, + }) }) } func TestInsertOnConflict(t *testing.T) { t.Run("do nothing", func(t *testing.T) { - tx := beginSampleDBTx(t) - defer tx.Rollback() - link := model.Link{ID: rand.Int31()} stmt := Link.INSERT(Link.AllColumns). @@ -304,14 +278,11 @@ VALUES (?, ?, ?, ?), (?, ?, ?, ?) ON CONFLICT (id) DO NOTHING; `) - testutils.AssertExec(t, stmt, tx, 1) + testutils.AssertExecAndRollback(t, stmt, sampleDB, 1) requireLogged(t, stmt) }) t.Run("do update", func(t *testing.T) { - tx := beginSampleDBTx(t) - defer tx.Rollback() - stmt := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description). VALUES(21, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil). VALUES(22, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil). @@ -336,14 +307,11 @@ RETURNING link.id AS "link.id", link.description AS "link.description"; `) - testutils.AssertExec(t, stmt, tx) + testutils.AssertExecAndRollback(t, stmt, sampleDB) requireLogged(t, stmt) }) t.Run("do update complex", func(t *testing.T) { - tx := beginSampleDBTx(t) - defer tx.Rollback() - stmt := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description). VALUES(21, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil). ON_CONFLICT(Link.ID). @@ -370,7 +338,7 @@ ON CONFLICT (id) WHERE (id * 2) > 10 DO UPDATE WHERE link.description IS NOT NULL; `) - testutils.AssertExec(t, stmt, tx) + testutils.AssertExecAndRollback(t, stmt, sampleDB) requireLogged(t, stmt) }) } @@ -384,7 +352,7 @@ func TestInsertContextDeadlineExceeded(t *testing.T) { time.Sleep(10 * time.Millisecond) - dest := []model.Link{} + var dest []model.Link err := stmt.QueryContext(ctx, sampleDB, &dest) require.Error(t, err, "context deadline exceeded") diff --git a/tests/sqlite/main_test.go b/tests/sqlite/main_test.go index 4eb274e..4975845 100644 --- a/tests/sqlite/main_test.go +++ b/tests/sqlite/main_test.go @@ -35,6 +35,7 @@ func TestMain(m *testing.M) { var err error db, err = sql.Open("sqlite3", "file:"+dbconfig.SakilaDBPath) throw.OnError(err) + defer db.Close() _, err = db.Exec(fmt.Sprintf("ATTACH DATABASE '%s' as 'chinook';", dbconfig.ChinookDBPath)) throw.OnError(err) @@ -42,8 +43,6 @@ func TestMain(m *testing.M) { sampleDB, err = sql.Open("sqlite3", dbconfig.TestSampleDBPath) throw.OnError(err) - defer db.Close() - ret := m.Run() if ret != 0 {