Add support for CockorachDB.

This commit is contained in:
go-jet 2022-05-05 13:01:42 +02:00
parent 3ff9241eea
commit bc776f947b
33 changed files with 1040 additions and 1037 deletions

View file

@ -33,6 +33,13 @@ jobs:
MYSQL_USER: jet MYSQL_USER: jet
MYSQL_PASSWORD: jet MYSQL_PASSWORD: jet
- image: cockroachdb/cockroach-unstable:v22.1.0-beta.4
command: ['start-single-node', '--insecure']
environment:
COCKROACH_USER: jet
COCKROACH_PASSWORD: jet
COCKROACH_DATABASE: jetdb
environment: # environment variables for the build itself environment: # environment variables for the build itself
TEST_RESULTS: /tmp/test-results # path to where test results will be saved TEST_RESULTS: /tmp/test-results # path to where test results will be saved
@ -84,6 +91,17 @@ jobs:
done done
echo Failed waiting for MySQL && exit 1 echo Failed waiting for MySQL && exit 1
- run:
name: Waiting for Cockroach to be ready
command: |
for i in `seq 1 10`;
do
nc -z localhost 26257 && echo Success && exit 0
echo -n .
sleep 1
done
echo Failed waiting for Cockroach && exit 1
- run: - run:
name: Install MySQL CLI; name: Install MySQL CLI;
command: | command: |
@ -122,8 +140,9 @@ jobs:
-coverpkg=github.com/go-jet/jet/v2/postgres/...,github.com/go-jet/jet/v2/mysql/...,github.com/go-jet/jet/v2/sqlite/...,github.com/go-jet/jet/v2/qrm/...,github.com/go-jet/jet/v2/generator/...,github.com/go-jet/jet/v2/internal/... \ -coverpkg=github.com/go-jet/jet/v2/postgres/...,github.com/go-jet/jet/v2/mysql/...,github.com/go-jet/jet/v2/sqlite/...,github.com/go-jet/jet/v2/qrm/...,github.com/go-jet/jet/v2/generator/...,github.com/go-jet/jet/v2/internal/... \
-coverprofile=cover.out 2>&1 | go-junit-report > $TEST_RESULTS/results.xml -coverprofile=cover.out 2>&1 | go-junit-report > $TEST_RESULTS/results.xml
# run mariaDB tests. No need to collect coverage, because coverage is already included with mysql tests # run mariaDB and cockroachdb tests. No need to collect coverage, because coverage is already included with mysql and postgres tests
- run: MY_SQL_SOURCE=MariaDB go test -v ./tests/mysql/ - run: MY_SQL_SOURCE=MariaDB go test -v ./tests/mysql/
- run: PG_SOURCE=COCKROACH_DB go test -v ./tests/postgres/
- save_cache: - save_cache:
key: go-mod-v4-{{ checksum "go.sum" }} key: go-mod-v4-{{ checksum "go.sum" }}

View file

@ -35,7 +35,9 @@ WITH primaryKeys AS (
SELECT column_name SELECT column_name
FROM information_schema.key_column_usage AS c FROM information_schema.key_column_usage AS c
LEFT JOIN information_schema.table_constraints AS t LEFT JOIN information_schema.table_constraints AS t
ON t.constraint_name = c.constraint_name ON t.constraint_name = c.constraint_name AND
c.table_schema = t.table_schema AND
c.table_name = t.table_name
WHERE t.table_schema = $1 AND t.table_name = $2 AND t.constraint_type = 'PRIMARY KEY' WHERE t.table_schema = $1 AND t.table_name = $2 AND t.constraint_type = 'PRIMARY KEY'
) )
SELECT column_name as "column.Name", SELECT column_name as "column.Name",

View file

@ -2,6 +2,8 @@ package testutils
import ( import (
"bytes" "bytes"
"context"
"database/sql"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/go-jet/jet/v2/internal/jet" "github.com/go-jet/jet/v2/internal/jet"
@ -25,6 +27,18 @@ var UnixTimeComparer = cmp.Comparer(func(t1, t2 time.Time) bool {
return t1.Unix() == t2.Unix() return t1.Unix() == t2.Unix()
}) })
// AssertExecAndRollback will execute and rollback statement in sql transaction
func AssertExecAndRollback(t *testing.T, stmt jet.Statement, db *sql.DB, rowsAffected ...int64) {
tx, err := db.Begin()
require.NoError(t, err)
defer func() {
err := tx.Rollback()
require.NoError(t, err)
}()
AssertExec(t, stmt, tx, rowsAffected...)
}
// AssertExec assert statement execution for successful execution and number of rows affected // AssertExec assert statement execution for successful execution and number of rows affected
func AssertExec(t *testing.T, stmt jet.Statement, db qrm.DB, rowsAffected ...int64) { func AssertExec(t *testing.T, stmt jet.Statement, db qrm.DB, rowsAffected ...int64) {
res, err := stmt.Exec(db) res, err := stmt.Exec(db)
@ -38,6 +52,18 @@ func AssertExec(t *testing.T, stmt jet.Statement, db qrm.DB, rowsAffected ...int
} }
} }
// ExecuteInTxAndRollback will execute function in sql transaction and then rollback transaction
func ExecuteInTxAndRollback(t *testing.T, db *sql.DB, f func(tx *sql.Tx)) {
tx, err := db.Begin()
require.NoError(t, err)
defer func() {
err := tx.Rollback()
require.NoError(t, err)
}()
f(tx)
}
// AssertExecErr assert statement execution for failed execution with error string errorStr // AssertExecErr assert statement execution for failed execution with error string errorStr
func AssertExecErr(t *testing.T, stmt jet.Statement, db qrm.DB, errorStr string) { func AssertExecErr(t *testing.T, stmt jet.Statement, db qrm.DB, errorStr string) {
_, err := stmt.Exec(db) _, err := stmt.Exec(db)
@ -45,6 +71,13 @@ func AssertExecErr(t *testing.T, stmt jet.Statement, db qrm.DB, errorStr string)
require.Error(t, err, errorStr) require.Error(t, err, errorStr)
} }
// AssertExecContextErr assert statement execution for failed execution with error string errorStr
func AssertExecContextErr(t *testing.T, stmt jet.Statement, ctx context.Context, db qrm.DB, errorStr string) {
_, err := stmt.ExecContext(ctx, db)
require.Error(t, err, errorStr)
}
func getFullPath(relativePath string) string { func getFullPath(relativePath string) string {
path, _ := os.Getwd() path, _ := os.Getwd()
return filepath.Join(path, "../", relativePath) return filepath.Join(path, "../", relativePath)

View file

@ -5,7 +5,7 @@ import (
) )
func TestExpressionCAST_AS(t *testing.T) { func TestExpressionCAST_AS(t *testing.T) {
assertSerialize(t, CAST(String("test")).AS("text"), `$1::text`, "test") assertSerialize(t, CAST(Int(11)).AS("text"), `$1::text`, int64(11))
} }
func TestExpressionCAST_AS_BOOL(t *testing.T) { func TestExpressionCAST_AS_BOOL(t *testing.T) {

View file

@ -4,16 +4,16 @@ import "testing"
func TestString_REGEXP_LIKE_operator(t *testing.T) { func TestString_REGEXP_LIKE_operator(t *testing.T) {
assertSerialize(t, table3StrCol.REGEXP_LIKE(table2ColStr), "(table3.col2 ~* table2.col_str)") assertSerialize(t, table3StrCol.REGEXP_LIKE(table2ColStr), "(table3.col2 ~* table2.col_str)")
assertSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN")), "(table3.col2 ~* $1)", "JOHN") assertSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN")), "(table3.col2 ~* $1::text)", "JOHN")
assertSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN"), false), "(table3.col2 ~* $1)", "JOHN") assertSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN"), false), "(table3.col2 ~* $1::text)", "JOHN")
assertSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN"), true), "(table3.col2 ~ $1)", "JOHN") assertSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN"), true), "(table3.col2 ~ $1::text)", "JOHN")
} }
func TestString_NOT_REGEXP_LIKE_operator(t *testing.T) { func TestString_NOT_REGEXP_LIKE_operator(t *testing.T) {
assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(table2ColStr), "(table3.col2 !~* table2.col_str)") assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(table2ColStr), "(table3.col2 !~* table2.col_str)")
assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN")), "(table3.col2 !~* $1)", "JOHN") assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN")), "(table3.col2 !~* $1::text)", "JOHN")
assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN"), false), "(table3.col2 !~* $1)", "JOHN") assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN"), false), "(table3.col2 !~* $1::text)", "JOHN")
assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN"), true), "(table3.col2 !~ $1)", "JOHN") assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN"), true), "(table3.col2 !~ $1::text)", "JOHN")
} }
func TestExists(t *testing.T) { func TestExists(t *testing.T) {

View file

@ -60,7 +60,7 @@ func TestRawHelperMethods(t *testing.T) {
assertSerialize(t, RawFloat("table.colInt + :float", RawArgs{":float": 11.22}).EQ(Float(3.14)), assertSerialize(t, RawFloat("table.colInt + :float", RawArgs{":float": 11.22}).EQ(Float(3.14)),
"((table.colInt + $1) = $2)", 11.22, 3.14) "((table.colInt + $1) = $2)", 11.22, 3.14)
assertSerialize(t, RawString("table.colStr || str", RawArgs{"str": "doe"}).EQ(String("john doe")), assertSerialize(t, RawString("table.colStr || str", RawArgs{"str": "doe"}).EQ(String("john doe")),
"((table.colStr || $1) = $2)", "doe", "john doe") "((table.colStr || $1) = $2::text)", "doe", "john doe")
now := time.Now() now := time.Now()
assertSerialize(t, RawTime("table.colTime").EQ(TimeT(now)), assertSerialize(t, RawTime("table.colTime").EQ(TimeT(now)),

View file

@ -167,7 +167,7 @@ VALUES ('one', 'two'),
ON CONFLICT (col_bool) WHERE col_bool IS NOT FALSE DO UPDATE ON CONFLICT (col_bool) WHERE col_bool IS NOT FALSE DO UPDATE
SET col_bool = TRUE::boolean, SET col_bool = TRUE::boolean,
col_int = 1, col_int = 1,
(col1, col_bool) = ROW(2, 'two') (col1, col_bool) = ROW(2, 'two'::text)
WHERE table1.col1 > 2 WHERE table1.col1 > 2
RETURNING table1.col1 AS "table1.col1", RETURNING table1.col1 AS "table1.col1",
table1.col_bool AS "table1.col_bool"; table1.col_bool AS "table1.col_bool";
@ -193,7 +193,7 @@ VALUES ('one', 'two'),
ON CONFLICT ON CONSTRAINT idk_primary_key DO UPDATE ON CONFLICT ON CONSTRAINT idk_primary_key DO UPDATE
SET col_bool = FALSE::boolean, SET col_bool = FALSE::boolean,
col_int = 1, col_int = 1,
(col1, col_bool) = ROW(2, 'two') (col1, col_bool) = ROW(2, 'two'::text)
WHERE table1.col1 > 2 WHERE table1.col1 > 2
RETURNING table1.col1 AS "table1.col1", RETURNING table1.col1 AS "table1.col1",
table1.col_bool AS "table1.col_bool"; table1.col_bool AS "table1.col_bool";

View file

@ -61,7 +61,9 @@ var Float = jet.Float
var Decimal = jet.Decimal var Decimal = jet.Decimal
// String creates new string literal expression // String creates new string literal expression
var String = jet.String func String(value string) StringExpression {
return CAST(jet.String(value)).AS_TEXT()
}
// UUID is a helper function to create string literal expression from uuid object // UUID is a helper function to create string literal expression from uuid object
// value can be any uuid type with a String method // value can be any uuid type with a String method

View file

@ -59,7 +59,7 @@ func TestFloat(t *testing.T) {
} }
func TestString(t *testing.T) { func TestString(t *testing.T) {
assertSerialize(t, String("Some text"), `$1`, "Some text") assertSerialize(t, String("Some text"), `$1::text`, "Some text")
} }
func TestBytea(t *testing.T) { func TestBytea(t *testing.T) {

View file

@ -39,6 +39,7 @@ jet-gen-postgres:
jet -dsn=postgres://jet:jet@localhost:50901/jetdb?sslmode=disable -schema=dvds -path=./.gentestdata/ jet -dsn=postgres://jet:jet@localhost:50901/jetdb?sslmode=disable -schema=dvds -path=./.gentestdata/
jet -dsn=postgres://jet:jet@localhost:50901/jetdb?sslmode=disable -schema=chinook -path=./.gentestdata/ jet -dsn=postgres://jet:jet@localhost:50901/jetdb?sslmode=disable -schema=chinook -path=./.gentestdata/
jet -dsn=postgres://jet:jet@localhost:50901/jetdb?sslmode=disable -schema=chinook2 -path=./.gentestdata/ jet -dsn=postgres://jet:jet@localhost:50901/jetdb?sslmode=disable -schema=chinook2 -path=./.gentestdata/
jet -dsn=postgres://jet:jet@localhost:50901/jetdb?sslmode=disable -schema=northwind -path=./.gentestdata/
jet -dsn=postgres://jet:jet@localhost:50901/jetdb?sslmode=disable -schema=test_sample -path=./.gentestdata/ jet -dsn=postgres://jet:jet@localhost:50901/jetdb?sslmode=disable -schema=test_sample -path=./.gentestdata/
jet-gen-mysql: jet-gen-mysql:
@ -56,6 +57,12 @@ jet-gen-sqlite:
jet -source=sqlite -dsn="./testdata/init/sqlite/sakila.db" -schema=dvds -path=./.gentestdata/sqlite/sakila jet -source=sqlite -dsn="./testdata/init/sqlite/sakila.db" -schema=dvds -path=./.gentestdata/sqlite/sakila
jet -source=sqlite -dsn="./testdata/init/sqlite/test_sample.db" -schema=dvds -path=./.gentestdata/sqlite/test_sample jet -source=sqlite -dsn="./testdata/init/sqlite/test_sample.db" -schema=dvds -path=./.gentestdata/sqlite/test_sample
jet-gen-cockroach:
jet -dsn=postgres://jet:jet@127.0.0.1:26257/jetdb?sslmode=disable -schema=dvds -path=./.gentestdata/
jet -dsn=postgres://jet:jet@localhost:26257/jetdb?sslmode=disable -schema=chinook -path=./.gentestdata/
jet -dsn=postgres://jet:jet@localhost:26257/jetdb?sslmode=disable -schema=chinook2 -path=./.gentestdata/
jet -dsn=postgres://jet:jet@localhost:26257/jetdb?sslmode=disable -schema=northwind -path=./.gentestdata/
jet -dsn=postgres://jet:jet@localhost:26257/jetdb?sslmode=disable -schema=test_sample -path=./.gentestdata/
# docker-compose-cleanup will stop and remove test containers, volumes, and images. # docker-compose-cleanup will stop and remove test containers, volumes, and images.
cleanup: cleanup:

View file

@ -15,7 +15,24 @@ const (
) )
// PostgresConnectString is PostgreSQL test database connection string // PostgresConnectString is PostgreSQL test database connection string
var PostgresConnectString = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", PgHost, PgPort, PgUser, PgPassword, PgDBName) var PostgresConnectString = pgConnectionString(PgHost, PgPort, PgUser, PgPassword, PgDBName)
// Postgres test database connection parameters
const (
CockroachHost = "localhost"
CockroachPort = 26257
CockroachUser = "jet"
CockroachPassword = "jet"
CockroachDBName = "jetdb"
)
// CockroachConnectString is Cockroach test database connection string
var CockroachConnectString = pgConnectionString(CockroachHost, CockroachPort, CockroachUser, CockroachPassword, CockroachDBName)
func pgConnectionString(host string, port int, user, password, dbName string) string {
return fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
host, port, user, password, dbName)
}
// MySQL test database connection parameters // MySQL test database connection parameters
const ( const (

View file

@ -37,3 +37,16 @@ services:
- '50903:3306' - '50903:3306'
volumes: volumes:
- ./testdata/init/mysql:/docker-entrypoint-initdb.d - ./testdata/init/mysql:/docker-entrypoint-initdb.d
cockroach:
image: cockroachdb/cockroach-unstable:v22.1.0-beta.4
environment:
- COCKROACH_USER=jet
- COCKROACH_PASSWORD=jet
- COCKROACH_DATABASE=jetdb
ports:
- "26257:26257"
command: start-single-node --insecure
# volumes:
# - ./testdata/init/cockroach:/docker-entrypoint-initdb.d

View file

@ -1,10 +1,12 @@
package main package main
import ( import (
"context"
"database/sql" "database/sql"
"flag" "flag"
"fmt" "fmt"
"github.com/go-jet/jet/v2/generator/mysql" "github.com/go-jet/jet/v2/generator/mysql"
"github.com/go-jet/jet/v2/generator/postgres"
"github.com/go-jet/jet/v2/generator/sqlite" "github.com/go-jet/jet/v2/generator/sqlite"
"github.com/go-jet/jet/v2/tests/internal/utils/repo" "github.com/go-jet/jet/v2/tests/internal/utils/repo"
"io/ioutil" "io/ioutil"
@ -12,46 +14,52 @@ import (
"os/exec" "os/exec"
"strings" "strings"
"github.com/go-jet/jet/v2/generator/postgres"
"github.com/go-jet/jet/v2/internal/utils/throw" "github.com/go-jet/jet/v2/internal/utils/throw"
"github.com/go-jet/jet/v2/tests/dbconfig" "github.com/go-jet/jet/v2/tests/dbconfig"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq" _ "github.com/jackc/pgx/v4/stdlib"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
) )
var testSuite string var testSuite string
func init() { func init() {
flag.StringVar(&testSuite, "testsuite", "all", "Test suite name (postgres or mysql)") flag.StringVar(&testSuite, "testsuite", "all", "Test suite name (postgres, mysql, mariadb, cockroach, sqlite or all)")
flag.Parse() flag.Parse()
} }
const (
Postgres = "postgres"
MySql = "mysql"
MariaDB = "mariadb"
Sqlite = "sqlite"
Cockroach = "cockroach"
)
func main() { func main() {
testSuite = strings.ToLower(testSuite) switch strings.ToLower(testSuite) {
case Postgres:
if testSuite == "postgres" { initPostgresDB(Postgres, dbconfig.PostgresConnectString)
initPostgresDB() case Cockroach:
return initPostgresDB(Cockroach, dbconfig.CockroachConnectString)
} case MySql:
initMySQLDB(false)
if testSuite == "mysql" || testSuite == "mariadb" { case MariaDB:
initMySQLDB(testSuite == "mariadb") initMySQLDB(true)
return case Sqlite:
}
if testSuite == "sqlite" {
initSQLiteDB() initSQLiteDB()
return case "all":
} initPostgresDB(Cockroach, dbconfig.CockroachConnectString)
initPostgresDB(Postgres, dbconfig.PostgresConnectString)
initPostgresDB()
initMySQLDB(false) initMySQLDB(false)
initMySQLDB(true) initMySQLDB(true)
initSQLiteDB() initSQLiteDB()
default:
panic("invalid testsuite flag. Test suite name (postgres, mysql, mariadb, cockroach, sqlite or all)")
}
} }
func initSQLiteDB() { func initSQLiteDB() {
@ -109,8 +117,8 @@ func initMySQLDB(isMariaDB bool) {
} }
} }
func initPostgresDB() { func initPostgresDB(dbType string, connectionString string) {
db, err := sql.Open("postgres", dbconfig.PostgresConnectString) db, err := sql.Open("postgres", connectionString)
if err != nil { if err != nil {
panic("Failed to connect to test db: " + err.Error()) panic("Failed to connect to test db: " + err.Error())
} }
@ -120,26 +128,19 @@ func initPostgresDB() {
}() }()
schemaNames := []string{ schemaNames := []string{
"northwind",
"dvds", "dvds",
"test_sample", "test_sample",
"chinook", "chinook",
"chinook2", "chinook2",
"northwind",
} }
for _, schemaName := range schemaNames { for _, schemaName := range schemaNames {
fmt.Println("\nInitializing", schemaName, "schema...")
execFile(db, "./testdata/init/postgres/"+schemaName+".sql") execFile(db, fmt.Sprintf("./testdata/init/%s/%s.sql", dbType, schemaName))
err = postgres.Generate("./.gentestdata", postgres.DBConnection{ err = postgres.GenerateDSN(connectionString, schemaName, "./.gentestdata")
Host: dbconfig.PgHost,
Port: dbconfig.PgPort,
User: dbconfig.PgUser,
Password: dbconfig.PgPassword,
DBName: dbconfig.PgDBName,
SchemaName: schemaName,
SslMode: "disable",
})
throw.OnError(err) throw.OnError(err)
} }
} }
@ -148,10 +149,32 @@ func execFile(db *sql.DB, sqlFilePath string) {
testSampleSql, err := ioutil.ReadFile(sqlFilePath) testSampleSql, err := ioutil.ReadFile(sqlFilePath)
throw.OnError(err) throw.OnError(err)
_, err = db.Exec(string(testSampleSql)) err = execInTx(db, func(tx *sql.Tx) error {
_, err := tx.Exec(string(testSampleSql))
return err
})
throw.OnError(err) throw.OnError(err)
} }
func execInTx(db *sql.DB, f func(tx *sql.Tx) error) error {
tx, err := db.BeginTx(context.Background(), &sql.TxOptions{
Isolation: sql.LevelReadUncommitted, // to speed up initialization of test database
})
if err != nil {
return err
}
err = f(tx)
if err != nil {
tx.Rollback()
return err
}
return tx.Commit()
}
func printOnError(err error) { func printOnError(err error) {
if err != nil { if err != nil {
fmt.Println(err.Error()) fmt.Println(err.Error())

View file

@ -20,7 +20,7 @@ import (
func TestAllTypes(t *testing.T) { func TestAllTypes(t *testing.T) {
dest := []model.AllTypes{} var dest []model.AllTypes
err := AllTypes. err := AllTypes.
SELECT(AllTypes.AllColumns). SELECT(AllTypes.AllColumns).
@ -39,7 +39,7 @@ func TestAllTypesViewSelect(t *testing.T) {
type AllTypesView model.AllTypes type AllTypesView model.AllTypes
dest := []AllTypesView{} var dest []AllTypesView
err := view.AllTypesView.SELECT(view.AllTypesView.AllColumns).Query(db, &dest) err := view.AllTypesView.SELECT(view.AllTypesView.AllColumns).Query(db, &dest)
require.NoError(t, err) require.NoError(t, err)

View file

@ -13,24 +13,20 @@ import (
) )
func TestDeleteWithWhere(t *testing.T) { func TestDeleteWithWhere(t *testing.T) {
initForDeleteTest(t)
var expectedSQL = `
DELETE FROM test_sample.link
WHERE link.name IN ('Gmail', 'Outlook');
`
deleteStmt := Link. deleteStmt := Link.
DELETE(). DELETE().
WHERE(Link.Name.IN(String("Gmail"), String("Outlook"))) WHERE(Link.Name.IN(String("Gmail"), String("Outlook")))
testutils.AssertDebugStatementSql(t, deleteStmt, expectedSQL, "Gmail", "Outlook") testutils.AssertDebugStatementSql(t, deleteStmt, `
testutils.AssertExec(t, deleteStmt, db, 2) DELETE FROM test_sample.link
WHERE link.name IN ('Gmail', 'Outlook');
`, "Gmail", "Outlook")
testutils.AssertExecAndRollback(t, deleteStmt, db, 2)
requireLogged(t, deleteStmt) requireLogged(t, deleteStmt)
} }
func TestDeleteWithWhereOrderByLimit(t *testing.T) { func TestDeleteWithWhereOrderByLimit(t *testing.T) {
initForDeleteTest(t)
var expectedSQL = ` var expectedSQL = `
DELETE FROM test_sample.link DELETE FROM test_sample.link
WHERE link.name IN ('Gmail', 'Outlook') WHERE link.name IN ('Gmail', 'Outlook')
@ -44,13 +40,11 @@ LIMIT 1;
LIMIT(1) LIMIT(1)
testutils.AssertDebugStatementSql(t, deleteStmt, expectedSQL, "Gmail", "Outlook", int64(1)) testutils.AssertDebugStatementSql(t, deleteStmt, expectedSQL, "Gmail", "Outlook", int64(1))
testutils.AssertExec(t, deleteStmt, db, 1) testutils.AssertExecAndRollback(t, deleteStmt, db, 1)
requireLogged(t, deleteStmt) requireLogged(t, deleteStmt)
} }
func TestDeleteQueryContext(t *testing.T) { func TestDeleteQueryContext(t *testing.T) {
initForDeleteTest(t)
deleteStmt := Link. deleteStmt := Link.
DELETE(). DELETE().
WHERE(Link.Name.IN(String("Gmail"), String("Outlook"))) WHERE(Link.Name.IN(String("Gmail"), String("Outlook")))
@ -60,7 +54,7 @@ func TestDeleteQueryContext(t *testing.T) {
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
dest := []model.Link{} var dest []model.Link
err := deleteStmt.QueryContext(ctx, db, &dest) err := deleteStmt.QueryContext(ctx, db, &dest)
require.Error(t, err, "context deadline exceeded") require.Error(t, err, "context deadline exceeded")
@ -68,8 +62,6 @@ func TestDeleteQueryContext(t *testing.T) {
} }
func TestDeleteExecContext(t *testing.T) { func TestDeleteExecContext(t *testing.T) {
initForDeleteTest(t)
deleteStmt := Link. deleteStmt := Link.
DELETE(). DELETE().
WHERE(Link.Name.IN(String("Gmail"), String("Outlook"))) WHERE(Link.Name.IN(String("Gmail"), String("Outlook")))
@ -84,19 +76,7 @@ func TestDeleteExecContext(t *testing.T) {
require.Error(t, err, "context deadline exceeded") require.Error(t, err, "context deadline exceeded")
} }
func initForDeleteTest(t *testing.T) {
cleanUpLinkTable(t)
stmt := Link.INSERT(Link.URL, Link.Name, Link.Description).
VALUES("www.gmail.com", "Gmail", "Email service developed by Google").
VALUES("www.outlook.live.com", "Outlook", "Email service developed by Microsoft")
testutils.AssertExec(t, stmt, db, 2)
}
func TestDeleteWithUsing(t *testing.T) { func TestDeleteWithUsing(t *testing.T) {
tx := beginTx(t)
defer tx.Rollback()
stmt := table.Rental.DELETE(). stmt := table.Rental.DELETE().
USING( USING(
table.Rental. table.Rental.
@ -116,5 +96,5 @@ USING dvds.rental
WHERE (staff.staff_id != ?) AND (rental.rental_id < ?); WHERE (staff.staff_id != ?) AND (rental.rental_id < ?);
`) `)
testutils.AssertExec(t, stmt, tx) testutils.AssertExecAndRollback(t, stmt, db)
} }

View file

@ -2,6 +2,7 @@ package mysql
import ( import (
"context" "context"
"database/sql"
"github.com/go-jet/jet/v2/internal/testutils" "github.com/go-jet/jet/v2/internal/testutils"
. "github.com/go-jet/jet/v2/mysql" . "github.com/go-jet/jet/v2/mysql"
"github.com/go-jet/jet/v2/tests/.gentestdata/mysql/test_sample/model" "github.com/go-jet/jet/v2/tests/.gentestdata/mysql/test_sample/model"
@ -13,52 +14,48 @@ import (
) )
func TestInsertValues(t *testing.T) { func TestInsertValues(t *testing.T) {
cleanUpLinkTable(t)
var expectedSQL = `
INSERT INTO test_sample.link (id, url, name, description)
VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT),
(101, 'http://www.google.com', 'Google', DEFAULT),
(102, 'http://www.yahoo.com', 'Yahoo', NULL);
`
insertQuery := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description). insertQuery := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description).
VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT).
VALUES(101, "http://www.google.com", "Google", DEFAULT). VALUES(101, "http://www.google.com", "Google", DEFAULT).
VALUES(102, "http://www.yahoo.com", "Yahoo", nil) VALUES(102, "http://www.yahoo.com", "Yahoo", nil)
testutils.AssertDebugStatementSql(t, insertQuery, expectedSQL, testutils.AssertDebugStatementSql(t, insertQuery, `
INSERT INTO test_sample.link (id, url, name, description)
VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT),
(101, 'http://www.google.com', 'Google', DEFAULT),
(102, 'http://www.yahoo.com', 'Yahoo', NULL);
`,
100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", 100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial",
101, "http://www.google.com", "Google", 101, "http://www.google.com", "Google",
102, "http://www.yahoo.com", "Yahoo", nil) 102, "http://www.yahoo.com", "Yahoo", nil)
_, err := insertQuery.Exec(db) testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) {
_, err := insertQuery.Exec(tx)
require.NoError(t, err) require.NoError(t, err)
requireLogged(t, insertQuery) requireLogged(t, insertQuery)
insertedLinks := []model.Link{} var insertedLinks []model.Link
err = Link.SELECT(Link.AllColumns). err = Link.SELECT(Link.AllColumns).
WHERE(Link.ID.GT_EQ(Int(100))). WHERE(Link.ID.BETWEEN(Int(100), Int(199))).
ORDER_BY(Link.ID). ORDER_BY(Link.ID).
Query(db, &insertedLinks) Query(tx, &insertedLinks)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, len(insertedLinks), 3) require.Equal(t, len(insertedLinks), 3)
testutils.AssertDeepEqual(t, insertedLinks[0], postgreTutorial) testutils.AssertDeepEqual(t, insertedLinks[0], postgreTutorial)
testutils.AssertDeepEqual(t, insertedLinks[1], model.Link{ testutils.AssertDeepEqual(t, insertedLinks[1], model.Link{
ID: 101, ID: 101,
URL: "http://www.google.com", URL: "http://www.google.com",
Name: "Google", Name: "Google",
}) })
testutils.AssertDeepEqual(t, insertedLinks[2], model.Link{ testutils.AssertDeepEqual(t, insertedLinks[2], model.Link{
ID: 102, ID: 102,
URL: "http://www.yahoo.com", URL: "http://www.yahoo.com",
Name: "Yahoo", Name: "Yahoo",
}) })
})
} }
var postgreTutorial = model.Link{ var postgreTutorial = model.Link{
@ -68,42 +65,34 @@ var postgreTutorial = model.Link{
} }
func TestInsertEmptyColumnList(t *testing.T) { func TestInsertEmptyColumnList(t *testing.T) {
cleanUpLinkTable(t)
expectedSQL := `
INSERT INTO test_sample.link
VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT);
`
stmt := Link.INSERT(). stmt := Link.INSERT().
VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT) VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT)
testutils.AssertDebugStatementSql(t, stmt, expectedSQL, testutils.AssertDebugStatementSql(t, stmt, `
INSERT INTO test_sample.link
VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT);
`,
100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial") 100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial")
_, err := stmt.Exec(db) testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) {
_, err := stmt.Exec(tx)
require.NoError(t, err) require.NoError(t, err)
requireLogged(t, stmt) requireLogged(t, stmt)
insertedLinks := []model.Link{} var insertedLinks []model.Link
err = Link.SELECT(Link.AllColumns). err = Link.SELECT(Link.AllColumns).
WHERE(Link.ID.GT_EQ(Int(100))). WHERE(Link.ID.BETWEEN(Int(100), Int(199))).
ORDER_BY(Link.ID). ORDER_BY(Link.ID).
Query(db, &insertedLinks) Query(tx, &insertedLinks)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, len(insertedLinks), 1) require.Equal(t, len(insertedLinks), 1)
testutils.AssertDeepEqual(t, insertedLinks[0], postgreTutorial) testutils.AssertDeepEqual(t, insertedLinks[0], postgreTutorial)
})
} }
func TestInsertModelObject(t *testing.T) { func TestInsertModelObject(t *testing.T) {
cleanUpLinkTable(t)
var expectedSQL = `
INSERT INTO test_sample.link (url, name)
VALUES ('http://www.duckduckgo.com', 'Duck Duck go');
`
linkData := model.Link{ linkData := model.Link{
URL: "http://www.duckduckgo.com", URL: "http://www.duckduckgo.com",
Name: "Duck Duck go", Name: "Duck Duck go",
@ -113,19 +102,19 @@ VALUES ('http://www.duckduckgo.com', 'Duck Duck go');
INSERT(Link.URL, Link.Name). INSERT(Link.URL, Link.Name).
MODEL(linkData) MODEL(linkData)
testutils.AssertDebugStatementSql(t, query, expectedSQL, "http://www.duckduckgo.com", "Duck Duck go") testutils.AssertDebugStatementSql(t, query, `
INSERT INTO test_sample.link (url, name)
VALUES ('http://www.duckduckgo.com', 'Duck Duck go');
`,
"http://www.duckduckgo.com", "Duck Duck go")
_, err := query.Exec(db) testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) {
_, err := query.Exec(tx)
require.NoError(t, err) require.NoError(t, err)
})
} }
func TestInsertModelObjectEmptyColumnList(t *testing.T) { func TestInsertModelObjectEmptyColumnList(t *testing.T) {
cleanUpLinkTable(t)
var expectedSQL = `
INSERT INTO test_sample.link
VALUES (1000, 'http://www.duckduckgo.com', 'Duck Duck go', NULL);
`
linkData := model.Link{ linkData := model.Link{
ID: 1000, ID: 1000,
URL: "http://www.duckduckgo.com", URL: "http://www.duckduckgo.com",
@ -136,20 +125,18 @@ VALUES (1000, 'http://www.duckduckgo.com', 'Duck Duck go', NULL);
INSERT(). INSERT().
MODEL(linkData) MODEL(linkData)
testutils.AssertDebugStatementSql(t, query, expectedSQL, int32(1000), "http://www.duckduckgo.com", "Duck Duck go", nil) testutils.AssertDebugStatementSql(t, query, `
INSERT INTO test_sample.link
VALUES (1000, 'http://www.duckduckgo.com', 'Duck Duck go', NULL);
`, int32(1000), "http://www.duckduckgo.com", "Duck Duck go", nil)
_, err := query.Exec(db) testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) {
_, err := query.Exec(tx)
require.NoError(t, err) require.NoError(t, err)
})
} }
func TestInsertModelsObject(t *testing.T) { func TestInsertModelsObject(t *testing.T) {
expectedSQL := `
INSERT INTO test_sample.link (url, name)
VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial'),
('http://www.google.com', 'Google'),
('http://www.yahoo.com', 'Yahoo');
`
tutorial := model.Link{ tutorial := model.Link{
URL: "http://www.postgresqltutorial.com", URL: "http://www.postgresqltutorial.com",
Name: "PostgreSQL Tutorial", Name: "PostgreSQL Tutorial",
@ -169,24 +156,23 @@ VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial'),
INSERT(Link.URL, Link.Name). INSERT(Link.URL, Link.Name).
MODELS([]model.Link{tutorial, google, yahoo}) MODELS([]model.Link{tutorial, google, yahoo})
testutils.AssertDebugStatementSql(t, query, expectedSQL, testutils.AssertDebugStatementSql(t, query, `
INSERT INTO test_sample.link (url, name)
VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial'),
('http://www.google.com', 'Google'),
('http://www.yahoo.com', 'Yahoo');
`,
"http://www.postgresqltutorial.com", "PostgreSQL Tutorial", "http://www.postgresqltutorial.com", "PostgreSQL Tutorial",
"http://www.google.com", "Google", "http://www.google.com", "Google",
"http://www.yahoo.com", "Yahoo") "http://www.yahoo.com", "Yahoo")
_, err := query.Exec(db) testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) {
_, err := query.Exec(tx)
require.NoError(t, err) require.NoError(t, err)
})
} }
func TestInsertUsingMutableColumns(t *testing.T) { func TestInsertUsingMutableColumns(t *testing.T) {
var expectedSQL = `
INSERT INTO test_sample.link (url, name, description)
VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT),
('http://www.google.com', 'Google', NULL),
('http://www.google.com', 'Google', NULL),
('http://www.yahoo.com', 'Yahoo', NULL);
`
google := model.Link{ google := model.Link{
URL: "http://www.google.com", URL: "http://www.google.com",
Name: "Google", Name: "Google",
@ -203,31 +189,25 @@ VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT),
MODEL(google). MODEL(google).
MODELS([]model.Link{google, yahoo}) MODELS([]model.Link{google, yahoo})
testutils.AssertDebugStatementSql(t, stmt, expectedSQL, testutils.AssertDebugStatementSql(t, stmt, `
INSERT INTO test_sample.link (url, name, description)
VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT),
('http://www.google.com', 'Google', NULL),
('http://www.google.com', 'Google', NULL),
('http://www.yahoo.com', 'Yahoo', NULL);
`,
"http://www.postgresqltutorial.com", "PostgreSQL Tutorial", "http://www.postgresqltutorial.com", "PostgreSQL Tutorial",
"http://www.google.com", "Google", nil, "http://www.google.com", "Google", nil,
"http://www.google.com", "Google", nil, "http://www.google.com", "Google", nil,
"http://www.yahoo.com", "Yahoo", nil) "http://www.yahoo.com", "Yahoo", nil)
_, err := stmt.Exec(db) testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) {
_, err := stmt.Exec(tx)
require.NoError(t, err) require.NoError(t, err)
})
} }
func TestInsertQuery(t *testing.T) { func TestInsertQuery(t *testing.T) {
_, err := Link.DELETE().
WHERE(Link.ID.NOT_EQ(Int(1)).AND(Link.Name.EQ(String("Youtube")))).
Exec(db)
require.NoError(t, err)
var expectedSQL = `
INSERT INTO test_sample.link (url, name) (
SELECT link.url AS "link.url",
link.name AS "link.name"
FROM test_sample.link
WHERE link.id = 1
);
`
query := Link. query := Link.
INSERT(Link.URL, Link.Name). INSERT(Link.URL, Link.Name).
QUERY( QUERY(
@ -236,19 +216,28 @@ INSERT INTO test_sample.link (url, name) (
WHERE(Link.ID.EQ(Int(1))), WHERE(Link.ID.EQ(Int(1))),
) )
testutils.AssertDebugStatementSql(t, query, expectedSQL, int64(1)) testutils.AssertDebugStatementSql(t, query, `
INSERT INTO test_sample.link (url, name) (
SELECT link.url AS "link.url",
link.name AS "link.name"
FROM test_sample.link
WHERE link.id = 1
);
`, int64(1))
_, err = query.Exec(db) testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) {
_, err := query.Exec(tx)
require.NoError(t, err) require.NoError(t, err)
youtubeLinks := []model.Link{} var youtubeLinks []model.Link
err = Link. err = Link.
SELECT(Link.AllColumns). SELECT(Link.AllColumns).
WHERE(Link.Name.EQ(String("Youtube"))). WHERE(Link.Name.EQ(String("Youtube"))).
Query(db, &youtubeLinks) Query(tx, &youtubeLinks)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, len(youtubeLinks), 2) require.Equal(t, len(youtubeLinks), 2)
})
} }
func TestInsertOnDuplicateKey(t *testing.T) { func TestInsertOnDuplicateKey(t *testing.T) {
@ -272,14 +261,16 @@ ON DUPLICATE KEY UPDATE id = (id + ?),
randId, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", randId, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial",
int64(11), "PostgreSQL Tutorial 2") int64(11), "PostgreSQL Tutorial 2")
testutils.AssertExec(t, stmt, db, 3) testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) {
_, err := stmt.Exec(tx)
require.NoError(t, err)
newLinks := []model.Link{} var newLinks []model.Link
err := SELECT(Link.AllColumns). err = SELECT(Link.AllColumns).
FROM(Link). FROM(Link).
WHERE(Link.ID.EQ(Int32(randId).ADD(Int(11)))). WHERE(Link.ID.EQ(Int32(randId).ADD(Int(11)))).
Query(db, &newLinks) Query(tx, &newLinks)
require.NoError(t, err) require.NoError(t, err)
require.Len(t, newLinks, 1) require.Len(t, newLinks, 1)
@ -289,11 +280,10 @@ ON DUPLICATE KEY UPDATE id = (id + ?),
Name: "PostgreSQL Tutorial 2", Name: "PostgreSQL Tutorial 2",
Description: nil, Description: nil,
}) })
})
} }
func TestInsertWithQueryContext(t *testing.T) { func TestInsertWithQueryContext(t *testing.T) {
cleanUpLinkTable(t)
stmt := Link.INSERT(). stmt := Link.INSERT().
VALUES(1100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT) VALUES(1100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT)
@ -302,15 +292,13 @@ func TestInsertWithQueryContext(t *testing.T) {
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
dest := []model.Link{} var dest []model.Link
err := stmt.QueryContext(ctx, db, &dest) err := stmt.QueryContext(ctx, db, &dest)
require.Error(t, err, "context deadline exceeded") require.Error(t, err, "context deadline exceeded")
} }
func TestInsertWithExecContext(t *testing.T) { func TestInsertWithExecContext(t *testing.T) {
cleanUpLinkTable(t)
stmt := Link.INSERT(). stmt := Link.INSERT().
VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT) VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT)
@ -323,8 +311,3 @@ func TestInsertWithExecContext(t *testing.T) {
require.Error(t, err, "context deadline exceeded") require.Error(t, err, "context deadline exceeded")
} }
func cleanUpLinkTable(t *testing.T) {
_, err := Link.DELETE().WHERE(Link.ID.GT(Int(1))).Exec(db)
require.NoError(t, err)
}

View file

@ -96,9 +96,3 @@ func skipForMariaDB(t *testing.T) {
t.SkipNow() t.SkipNow()
} }
} }
func beginTx(t *testing.T) *sql.Tx {
tx, err := db.Begin()
require.NoError(t, err)
return tx
}

View file

@ -2,6 +2,7 @@ package mysql
import ( import (
"context" "context"
"database/sql"
"github.com/go-jet/jet/v2/internal/testutils" "github.com/go-jet/jet/v2/internal/testutils"
. "github.com/go-jet/jet/v2/mysql" . "github.com/go-jet/jet/v2/mysql"
"github.com/go-jet/jet/v2/tests/.gentestdata/mysql/dvds/table" "github.com/go-jet/jet/v2/tests/.gentestdata/mysql/dvds/table"
@ -13,8 +14,6 @@ import (
) )
func TestUpdateValues(t *testing.T) { func TestUpdateValues(t *testing.T) {
setupLinkTableForUpdateTest(t)
var expectedSQL = ` var expectedSQL = `
UPDATE test_sample.link UPDATE test_sample.link
SET name = 'Bong', SET name = 'Bong',
@ -28,8 +27,26 @@ WHERE link.name = 'Bing';
WHERE(Link.Name.EQ(String("Bing"))) WHERE(Link.Name.EQ(String("Bing")))
testutils.AssertDebugStatementSql(t, query, expectedSQL, "Bong", "http://bong.com", "Bing") testutils.AssertDebugStatementSql(t, query, expectedSQL, "Bong", "http://bong.com", "Bing")
testutils.AssertExec(t, query, db) testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) {
testutils.AssertExec(t, query, tx)
requireLogged(t, query) requireLogged(t, query)
var links []model.Link
err := Link.
SELECT(Link.AllColumns).
WHERE(Link.Name.EQ(String("Bong"))).
Query(tx, &links)
require.NoError(t, err)
require.Equal(t, len(links), 1)
testutils.AssertDeepEqual(t, links[0], model.Link{
ID: 204,
URL: "http://bong.com",
Name: "Bong",
})
})
}) })
t.Run("new version", func(t *testing.T) { t.Run("new version", func(t *testing.T) {
@ -41,16 +58,16 @@ WHERE link.name = 'Bing';
WHERE(Link.Name.EQ(String("Bing"))) WHERE(Link.Name.EQ(String("Bing")))
testutils.AssertDebugStatementSql(t, stmt, expectedSQL, "Bong", "http://bong.com", "Bing") testutils.AssertDebugStatementSql(t, stmt, expectedSQL, "Bong", "http://bong.com", "Bing")
testutils.AssertExec(t, stmt, db) testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) {
testutils.AssertExec(t, stmt, tx)
requireLogged(t, stmt) requireLogged(t, stmt)
})
links := []model.Link{} var links []model.Link
err := Link. err := Link.
SELECT(Link.AllColumns). SELECT(Link.AllColumns).
WHERE(Link.Name.EQ(String("Bong"))). WHERE(Link.Name.EQ(String("Bong"))).
Query(db, &links) Query(tx, &links)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, len(links), 1) require.Equal(t, len(links), 1)
@ -59,11 +76,11 @@ WHERE link.name = 'Bing';
URL: "http://bong.com", URL: "http://bong.com",
Name: "Bong", Name: "Bong",
}) })
})
})
} }
func TestUpdateWithSubQueries(t *testing.T) { func TestUpdateWithSubQueries(t *testing.T) {
setupLinkTableForUpdateTest(t)
expectedSQL := ` expectedSQL := `
UPDATE test_sample.link UPDATE test_sample.link
SET name = ?, SET name = ?,
@ -86,7 +103,7 @@ WHERE link.name = ?;
WHERE(Link.Name.EQ(String("Bing"))) WHERE(Link.Name.EQ(String("Bing")))
testutils.AssertStatementSql(t, query, expectedSQL, "Bong", "Youtube", "Bing") testutils.AssertStatementSql(t, query, expectedSQL, "Bong", "Youtube", "Bing")
testutils.AssertExec(t, query, db) testutils.AssertExecAndRollback(t, query, db)
requireLogged(t, query) requireLogged(t, query)
}) })
@ -104,14 +121,12 @@ WHERE link.name = ?;
WHERE(Link.Name.EQ(String("Bing"))) WHERE(Link.Name.EQ(String("Bing")))
testutils.AssertStatementSql(t, query, expectedSQL, "Bong", "Youtube", "Bing") testutils.AssertStatementSql(t, query, expectedSQL, "Bong", "Youtube", "Bing")
testutils.AssertExec(t, query, db) testutils.AssertExecAndRollback(t, query, db)
requireLogged(t, query) requireLogged(t, query)
}) })
} }
func TestUpdateWithModelData(t *testing.T) { func TestUpdateWithModelData(t *testing.T) {
setupLinkTableForUpdateTest(t)
link := model.Link{ link := model.Link{
ID: 201, ID: 201,
URL: "http://www.duckduckgo.com", URL: "http://www.duckduckgo.com",
@ -123,24 +138,20 @@ func TestUpdateWithModelData(t *testing.T) {
MODEL(link). MODEL(link).
WHERE(Link.ID.EQ(Int32(link.ID))) WHERE(Link.ID.EQ(Int32(link.ID)))
expectedSQL := ` testutils.AssertStatementSql(t, stmt, `
UPDATE test_sample.link UPDATE test_sample.link
SET id = ?, SET id = ?,
url = ?, url = ?,
name = ?, name = ?,
description = ? description = ?
WHERE link.id = ?; WHERE link.id = ?;
` `, int32(201), "http://www.duckduckgo.com", "DuckDuckGo", nil, int32(201))
testutils.AssertStatementSql(t, stmt, expectedSQL, int32(201), "http://www.duckduckgo.com", "DuckDuckGo", nil, int32(201))
testutils.AssertExec(t, stmt, db) testutils.AssertExecAndRollback(t, stmt, db)
requireLogged(t, stmt) requireLogged(t, stmt)
} }
func TestUpdateWithModelDataAndPredefinedColumnList(t *testing.T) { func TestUpdateWithModelDataAndPredefinedColumnList(t *testing.T) {
setupLinkTableForUpdateTest(t)
link := model.Link{ link := model.Link{
ID: 201, ID: 201,
URL: "http://www.duckduckgo.com", URL: "http://www.duckduckgo.com",
@ -154,23 +165,19 @@ func TestUpdateWithModelDataAndPredefinedColumnList(t *testing.T) {
MODEL(link). MODEL(link).
WHERE(Link.ID.EQ(Int32(link.ID))) WHERE(Link.ID.EQ(Int32(link.ID)))
var expectedSQL = ` testutils.AssertDebugStatementSql(t, stmt, `
UPDATE test_sample.link UPDATE test_sample.link
SET description = NULL, SET description = NULL,
name = 'DuckDuckGo', name = 'DuckDuckGo',
url = 'http://www.duckduckgo.com' url = 'http://www.duckduckgo.com'
WHERE link.id = 201; WHERE link.id = 201;
` `, nil, "DuckDuckGo", "http://www.duckduckgo.com", int32(201))
testutils.AssertDebugStatementSql(t, stmt, expectedSQL, nil, "DuckDuckGo", "http://www.duckduckgo.com", int32(201)) testutils.AssertExecAndRollback(t, stmt, db)
testutils.AssertExec(t, stmt, db)
requireLogged(t, stmt) requireLogged(t, stmt)
} }
func TestUpdateWithModelDataAndMutableColumns(t *testing.T) { func TestUpdateWithModelDataAndMutableColumns(t *testing.T) {
setupLinkTableForUpdateTest(t)
link := model.Link{ link := model.Link{
ID: 201, ID: 201,
URL: "http://www.duckduckgo.com", URL: "http://www.duckduckgo.com",
@ -192,7 +199,7 @@ WHERE link.id = 201;
//fmt.Println(stmt.DebugSql()) //fmt.Println(stmt.DebugSql())
testutils.AssertDebugStatementSql(t, stmt, expectedSQL, "http://www.duckduckgo.com", "DuckDuckGo", nil, int32(201)) testutils.AssertDebugStatementSql(t, stmt, expectedSQL, "http://www.duckduckgo.com", "DuckDuckGo", nil, int32(201))
testutils.AssertExec(t, stmt, db) testutils.AssertExecAndRollback(t, stmt, db)
} }
func TestUpdateWithInvalidModelData(t *testing.T) { func TestUpdateWithInvalidModelData(t *testing.T) {
@ -201,8 +208,6 @@ func TestUpdateWithInvalidModelData(t *testing.T) {
require.Equal(t, r, "missing struct field for column : id") require.Equal(t, r, "missing struct field for column : id")
}() }()
setupLinkTableForUpdateTest(t)
link := struct { link := struct {
Ident int Ident int
URL string URL string
@ -215,17 +220,13 @@ func TestUpdateWithInvalidModelData(t *testing.T) {
Name: "DuckDuckGo", Name: "DuckDuckGo",
} }
stmt := Link. _ = Link.
UPDATE(Link.AllColumns). UPDATE(Link.AllColumns).
MODEL(link). MODEL(link).
WHERE(Link.ID.EQ(Int(int64(link.Ident)))) WHERE(Link.ID.EQ(Int(int64(link.Ident))))
stmt.Sql()
} }
func TestUpdateQueryContext(t *testing.T) { func TestUpdateQueryContext(t *testing.T) {
setupLinkTableForUpdateTest(t)
updateStmt := Link. updateStmt := Link.
UPDATE(Link.Name, Link.URL). UPDATE(Link.Name, Link.URL).
SET("Bong", "http://bong.com"). SET("Bong", "http://bong.com").
@ -243,8 +244,6 @@ func TestUpdateQueryContext(t *testing.T) {
} }
func TestUpdateExecContext(t *testing.T) { func TestUpdateExecContext(t *testing.T) {
setupLinkTableForUpdateTest(t)
updateStmt := Link. updateStmt := Link.
UPDATE(Link.Name, Link.URL). UPDATE(Link.Name, Link.URL).
SET("Bong", "http://bong.com"). SET("Bong", "http://bong.com").
@ -261,9 +260,6 @@ func TestUpdateExecContext(t *testing.T) {
} }
func TestUpdateWithJoin(t *testing.T) { func TestUpdateWithJoin(t *testing.T) {
tx := beginTx(t)
defer tx.Rollback()
statement := table.Staff.INNER_JOIN(table.Address, table.Address.AddressID.EQ(table.Staff.AddressID)). statement := table.Staff.INNER_JOIN(table.Address, table.Address.AddressID.EQ(table.Staff.AddressID)).
UPDATE(table.Staff.LastName). UPDATE(table.Staff.LastName).
SET(String("New staff name")). SET(String("New staff name")).
@ -276,21 +272,5 @@ SET last_name = ?
WHERE staff.staff_id = ?; WHERE staff.staff_id = ?;
`, "New staff name", int64(1)) `, "New staff name", int64(1))
_, err := statement.Exec(tx) testutils.AssertExecAndRollback(t, statement, db)
require.NoError(t, err)
}
func setupLinkTableForUpdateTest(t *testing.T) {
cleanUpLinkTable(t)
_, err := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description).
VALUES(200, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT).
VALUES(201, "http://www.ask.com", "Ask", DEFAULT).
VALUES(202, "http://www.ask.com", "Ask", DEFAULT).
VALUES(203, "http://www.yahoo.com", "Yahoo", DEFAULT).
VALUES(204, "http://www.bing.com", "Bing", DEFAULT).
Exec(db)
require.NoError(t, err)
} }

View file

@ -1,6 +1,7 @@
package postgres package postgres
import ( import (
"database/sql"
"testing" "testing"
"time" "time"
@ -17,10 +18,10 @@ import (
) )
func TestAllTypesSelect(t *testing.T) { func TestAllTypesSelect(t *testing.T) {
dest := []model.AllTypes{} var dest []model.AllTypes
err := AllTypes.SELECT( err := AllTypes.SELECT(
AllTypes.AllColumns, AllTypesAllColumns,
).LIMIT(2). ).LIMIT(2).
Query(db, &dest) Query(db, &dest)
require.NoError(t, err) require.NoError(t, err)
@ -32,7 +33,7 @@ func TestAllTypesSelect(t *testing.T) {
func TestAllTypesViewSelect(t *testing.T) { func TestAllTypesViewSelect(t *testing.T) {
type AllTypesView model.AllTypes type AllTypesView model.AllTypes
dest := []AllTypesView{} var dest []AllTypesView
err := view.AllTypesView.SELECT(view.AllTypesView.AllColumns).Query(db, &dest) err := view.AllTypesView.SELECT(view.AllTypesView.AllColumns).Query(db, &dest)
require.NoError(t, err) require.NoError(t, err)
@ -44,40 +45,123 @@ func TestAllTypesViewSelect(t *testing.T) {
func TestAllTypesInsertModel(t *testing.T) { func TestAllTypesInsertModel(t *testing.T) {
skipForPgxDriver(t) // pgx driver bug ERROR: date/time field value out of range: "0000-01-01 12:05:06Z" (SQLSTATE 22008) skipForPgxDriver(t) // pgx driver bug ERROR: date/time field value out of range: "0000-01-01 12:05:06Z" (SQLSTATE 22008)
query := AllTypes.INSERT(AllTypes.AllColumns). query := AllTypes.INSERT(AllTypesAllColumns).
MODEL(allTypesRow0). MODEL(allTypesRow0).
MODEL(&allTypesRow1). MODEL(&allTypesRow1).
RETURNING(AllTypes.AllColumns) RETURNING(AllTypes.AllColumns)
dest := []model.AllTypes{} testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) {
err := query.Query(db, &dest) var dest []model.AllTypes
err := query.Query(tx, &dest)
require.NoError(t, err) require.NoError(t, err)
if sourceIsCockroachDB() {
return
}
require.Equal(t, len(dest), 2) require.Equal(t, len(dest), 2)
testutils.AssertDeepEqual(t, dest[0], allTypesRow0) testutils.AssertDeepEqual(t, dest[0], allTypesRow0)
testutils.AssertDeepEqual(t, dest[1], allTypesRow1) testutils.AssertDeepEqual(t, dest[1], allTypesRow1)
})
} }
var AllTypesAllColumns = AllTypes.AllColumns.Except(IntegerColumn("rowid"))
func TestAllTypesInsertQuery(t *testing.T) { func TestAllTypesInsertQuery(t *testing.T) {
query := AllTypes.INSERT(AllTypes.AllColumns). query := AllTypes.INSERT(AllTypesAllColumns).
QUERY( QUERY(
AllTypes. AllTypes.
SELECT(AllTypes.AllColumns). SELECT(AllTypesAllColumns).
LIMIT(2), LIMIT(2),
). ).
RETURNING(AllTypes.AllColumns) RETURNING(AllTypesAllColumns)
dest := []model.AllTypes{} testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) {
err := query.Query(db, &dest) var dest []model.AllTypes
err := query.Query(tx, &dest)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, len(dest), 2) require.Equal(t, len(dest), 2)
testutils.AssertDeepEqual(t, dest[0], allTypesRow0) testutils.AssertDeepEqual(t, dest[0], allTypesRow0)
testutils.AssertDeepEqual(t, dest[1], allTypesRow1) testutils.AssertDeepEqual(t, dest[1], allTypesRow1)
})
}
func TestUUIDType(t *testing.T) {
id := uuid.MustParse("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11")
query := AllTypes.
SELECT(AllTypes.UUID, AllTypes.UUIDPtr).
WHERE(AllTypes.UUID.EQ(UUID(id)))
testutils.AssertDebugStatementSql(t, query, `
SELECT all_types.uuid AS "all_types.uuid",
all_types.uuid_ptr AS "all_types.uuid_ptr"
FROM test_sample.all_types
WHERE all_types.uuid = 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11';
`, "a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11")
result := model.AllTypes{}
err := query.Query(db, &result)
require.NoError(t, err)
require.Equal(t, result.UUID, uuid.MustParse("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"))
testutils.AssertDeepEqual(t, result.UUIDPtr, testutils.UUIDPtr("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"))
requireLogged(t, query)
}
func TestBytea(t *testing.T) {
byteArrHex := "\\x48656c6c6f20476f7068657221"
byteArrBin := []byte("\x48\x65\x6c\x6c\x6f\x20\x47\x6f\x70\x68\x65\x72\x21")
insertStmt := AllTypes.INSERT(AllTypes.Bytea, AllTypes.ByteaPtr).
VALUES(byteArrHex, byteArrBin).
RETURNING(AllTypes.Bytea, AllTypes.ByteaPtr)
testutils.AssertStatementSql(t, insertStmt, `
INSERT INTO test_sample.all_types (bytea, bytea_ptr)
VALUES ($1, $2)
RETURNING all_types.bytea AS "all_types.bytea",
all_types.bytea_ptr AS "all_types.bytea_ptr";
`, byteArrHex, byteArrBin)
var inserted model.AllTypes
err := insertStmt.Query(db, &inserted)
require.NoError(t, err)
require.Equal(t, string(*inserted.ByteaPtr), "Hello Gopher!")
// It is not possible to initiate bytea column using hex format '\xDEADBEEF' with pq driver.
// pq driver always encodes parameter string if destination column is of type bytea.
// Probably pq driver error.
// require.Equal(t, string(inserted.Bytea), "Hello Gopher!")
stmt := SELECT(
AllTypes.Bytea,
AllTypes.ByteaPtr,
).FROM(
AllTypes,
).WHERE(
AllTypes.ByteaPtr.EQ(Bytea(byteArrBin)),
)
testutils.AssertStatementSql(t, stmt, `
SELECT all_types.bytea AS "all_types.bytea",
all_types.bytea_ptr AS "all_types.bytea_ptr"
FROM test_sample.all_types
WHERE all_types.bytea_ptr = $1::bytea;
`, byteArrBin)
var dest model.AllTypes
err = stmt.Query(db, &dest)
require.NoError(t, err)
require.Equal(t, string(*dest.ByteaPtr), "Hello Gopher!")
// Probably pq driver error.
// require.Equal(t, string(dest.Bytea), "Hello Gopher!")
} }
func TestAllTypesFromSubQuery(t *testing.T) { func TestAllTypesFromSubQuery(t *testing.T) {
subQuery := SELECT(AllTypes.AllColumns). subQuery := SELECT(AllTypesAllColumns).
FROM(AllTypes). FROM(AllTypes).
AsTable("allTypesSubQuery") AsTable("allTypesSubQuery")
@ -214,7 +298,7 @@ FROM (
LIMIT 2; LIMIT 2;
`) `)
dest := []model.AllTypes{} var dest []model.AllTypes
err := mainQuery.Query(db, &dest) err := mainQuery.Query(db, &dest)
require.NoError(t, err) require.NoError(t, err)
@ -298,7 +382,6 @@ LIMIT $11;
} }
func TestExpressionCast(t *testing.T) { func TestExpressionCast(t *testing.T) {
skipForPgxDriver(t) // pgx driver bug 'cannot convert 151 to Text' skipForPgxDriver(t) // pgx driver bug 'cannot convert 151 to Text'
query := AllTypes.SELECT( query := AllTypes.SELECT(
@ -315,11 +398,18 @@ func TestExpressionCast(t *testing.T) {
CAST(Int(234)).AS_TEXT(), CAST(Int(234)).AS_TEXT(),
CAST(String("1/8/1999")).AS_DATE(), CAST(String("1/8/1999")).AS_DATE(),
CAST(String("04:05:06.789")).AS_TIME(), CAST(String("04:05:06.789")).AS_TIME(),
CAST(String("04:05:06 PST")).AS_TIMEZ(), CAST(String("04:05:06+01:00")).AS_TIMEZ(),
CAST(String("1999-01-08 04:05:06")).AS_TIMESTAMP(), CAST(String("1999-01-08 04:05:06")).AS_TIMESTAMP(),
CAST(String("January 8 04:05:06 1999 PST")).AS_TIMESTAMPZ(), CAST(String("1999-01-08 04:05:06+01:00")).AS_TIMESTAMPZ(),
CAST(String("04:05:06")).AS_INTERVAL(), CAST(String("04:05:06")).AS_INTERVAL(),
func() ProjectionList {
if sourceIsCockroachDB() {
return ProjectionList{NULL}
}
// cockroach doesn't support currently
return ProjectionList{
TO_CHAR(AllTypes.Timestamp, String("HH12:MI:SS")), TO_CHAR(AllTypes.Timestamp, String("HH12:MI:SS")),
TO_CHAR(AllTypes.Integer, String("999")), TO_CHAR(AllTypes.Integer, String("999")),
TO_CHAR(AllTypes.DoublePrecision, String("999D9")), TO_CHAR(AllTypes.DoublePrecision, String("999D9")),
@ -328,6 +418,8 @@ func TestExpressionCast(t *testing.T) {
TO_DATE(String("05 Dec 2000"), String("DD Mon YYYY")), TO_DATE(String("05 Dec 2000"), String("DD Mon YYYY")),
TO_NUMBER(String("12,454"), String("99G999D9S")), TO_NUMBER(String("12,454"), String("99G999D9S")),
TO_TIMESTAMP(String("05 Dec 2000"), String("DD Mon YYYY")), TO_TIMESTAMP(String("05 Dec 2000"), String("DD Mon YYYY")),
}
}(),
COALESCE(AllTypes.IntegerPtr, AllTypes.SmallIntPtr, NULL, Int(11)), COALESCE(AllTypes.IntegerPtr, AllTypes.SmallIntPtr, NULL, Int(11)),
NULLIF(AllTypes.Text, String("(none)")), NULLIF(AllTypes.Text, String("(none)")),
@ -337,15 +429,14 @@ func TestExpressionCast(t *testing.T) {
Raw("current_database()"), Raw("current_database()"),
) )
//fmt.Println(query.DebugSql()) var dest []struct{}
dest := []struct{}{}
err := query.Query(db, &dest) err := query.Query(db, &dest)
require.NoError(t, err) require.NoError(t, err)
} }
func TestStringOperators(t *testing.T) { func TestStringOperators(t *testing.T) {
skipForCockroachDB(t) // some string functions are still unimplemented
skipForPgxDriver(t) // pgx driver bug 'cannot convert 11 to Text' skipForPgxDriver(t) // pgx driver bug 'cannot convert 11 to Text'
query := AllTypes.SELECT( query := AllTypes.SELECT(
@ -395,18 +486,18 @@ func TestStringOperators(t *testing.T) {
CONCAT(AllTypes.VarCharPtr, AllTypes.VarCharPtr, String("aaa"), Int(1)), CONCAT(AllTypes.VarCharPtr, AllTypes.VarCharPtr, String("aaa"), Int(1)),
CONCAT(Bool(false), Int(1), Float(22.2), String("test test")), CONCAT(Bool(false), Int(1), Float(22.2), String("test test")),
CONCAT_WS(String("string1"), Int(1), Float(11.22), String("bytea"), Bool(false)), //Float(11.12)), CONCAT_WS(String("string1"), Int(1), Float(11.22), String("bytea"), Bool(false)), //Float(11.12)),
CONVERT(String("bytea"), String("UTF8"), String("LATIN1")), CONVERT(Bytea("bytea"), String("UTF8"), String("LATIN1")),
CONVERT(AllTypes.Bytea, String("UTF8"), String("LATIN1")), CONVERT(AllTypes.Bytea, String("UTF8"), String("LATIN1")),
CONVERT_FROM(String("text_in_utf8"), String("UTF8")), CONVERT_FROM(Bytea("text_in_utf8"), String("UTF8")),
CONVERT_TO(String("text_in_utf8"), String("UTF8")), CONVERT_TO(String("text_in_utf8"), String("UTF8")),
ENCODE(String("123\000\001"), String("base64")), ENCODE(Bytea("123\000\001"), String("base64")),
DECODE(String("MTIzAAE="), String("base64")), DECODE(String("MTIzAAE="), String("base64")),
FORMAT(String("Hello %s, %1$s"), String("World")), FORMAT(String("Hello %s, %1$s"), String("World")),
INITCAP(String("hi THOMAS")), INITCAP(String("hi THOMAS")),
LEFT(String("abcde"), Int(2)), LEFT(String("abcde"), Int(2)),
RIGHT(String("abcde"), Int(2)), RIGHT(String("abcde"), Int(2)),
LENGTH(String("jose")), LENGTH(Bytea("jose")),
LENGTH(String("jose"), String("UTF8")), LENGTH(Bytea("jose"), String("UTF8")),
LPAD(String("Hi"), Int(5)), LPAD(String("Hi"), Int(5)),
LPAD(String("Hi"), Int(5), String("xy")), LPAD(String("Hi"), Int(5), String("xy")),
RPAD(String("Hi"), Int(5)), RPAD(String("Hi"), Int(5)),
@ -421,8 +512,6 @@ func TestStringOperators(t *testing.T) {
TO_HEX(AllTypes.IntegerPtr), TO_HEX(AllTypes.IntegerPtr),
) )
//fmt.Println(query.DebugSql())
dest := []struct{}{} dest := []struct{}{}
err := query.Query(db, &dest) err := query.Query(db, &dest)
@ -501,6 +590,8 @@ LIMIT $5;
} }
func TestFloatOperators(t *testing.T) { func TestFloatOperators(t *testing.T) {
skipForCockroachDB(t) // some functions are still unimplemented
query := AllTypes.SELECT( query := AllTypes.SELECT(
AllTypes.Numeric.EQ(AllTypes.Numeric).AS("eq1"), AllTypes.Numeric.EQ(AllTypes.Numeric).AS("eq1"),
AllTypes.Decimal.EQ(Float(12.22)).AS("eq2"), AllTypes.Decimal.EQ(Float(12.22)).AS("eq2"),
@ -604,6 +695,8 @@ LIMIT $38;
} }
func TestIntegerOperators(t *testing.T) { func TestIntegerOperators(t *testing.T) {
skipForCockroachDB(t) // some functions are still unimplemented
query := AllTypes.SELECT( query := AllTypes.SELECT(
AllTypes.BigInt, AllTypes.BigInt,
AllTypes.BigIntPtr, AllTypes.BigIntPtr,
@ -733,6 +826,8 @@ LIMIT $27;
} }
func TestTimeExpression(t *testing.T) { func TestTimeExpression(t *testing.T) {
skipForCockroachDB(t)
query := AllTypes.SELECT( query := AllTypes.SELECT(
AllTypes.Time.EQ(AllTypes.Time), AllTypes.Time.EQ(AllTypes.Time),
AllTypes.Time.EQ(Time(23, 6, 6, 1)), AllTypes.Time.EQ(Time(23, 6, 6, 1)),
@ -813,6 +908,8 @@ func TestTimeExpression(t *testing.T) {
} }
func TestInterval(t *testing.T) { func TestInterval(t *testing.T) {
skipForCockroachDB(t)
stmt := SELECT( stmt := SELECT(
INTERVAL(1, YEAR), INTERVAL(1, YEAR),
INTERVAL(1, MONTH), INTERVAL(1, MONTH),
@ -1084,6 +1181,10 @@ LIMIT $6;
dest.Timez = dest.Timez.UTC() dest.Timez = dest.Timez.UTC()
dest.Timestampz = dest.Timestampz.UTC() dest.Timestampz = dest.Timestampz.UTC()
if sourceIsCockroachDB() {
return // rounding differences
}
testutils.AssertJSON(t, dest, ` testutils.AssertJSON(t, dest, `
{ {
"Date": "2009-11-17T00:00:00Z", "Date": "2009-11-17T00:00:00Z",

View file

@ -188,19 +188,7 @@ func TestJoinEverything(t *testing.T) {
manager := Employee.AS("Manager") manager := Employee.AS("Manager")
stmt := Artist. stmt := SELECT(
LEFT_JOIN(Album, Artist.ArtistId.EQ(Album.ArtistId)).
LEFT_JOIN(Track, Track.AlbumId.EQ(Album.AlbumId)).
LEFT_JOIN(Genre, Genre.GenreId.EQ(Track.GenreId)).
LEFT_JOIN(MediaType, MediaType.MediaTypeId.EQ(Track.MediaTypeId)).
LEFT_JOIN(PlaylistTrack, PlaylistTrack.TrackId.EQ(Track.TrackId)).
LEFT_JOIN(Playlist, Playlist.PlaylistId.EQ(PlaylistTrack.PlaylistId)).
LEFT_JOIN(InvoiceLine, InvoiceLine.TrackId.EQ(Track.TrackId)).
LEFT_JOIN(Invoice, Invoice.InvoiceId.EQ(InvoiceLine.InvoiceId)).
LEFT_JOIN(Customer, Customer.CustomerId.EQ(Invoice.CustomerId)).
LEFT_JOIN(Employee, Employee.EmployeeId.EQ(Customer.SupportRepId)).
LEFT_JOIN(manager, manager.EmployeeId.EQ(Employee.ReportsTo)).
SELECT(
Artist.AllColumns, Artist.AllColumns,
Album.AllColumns, Album.AllColumns,
Track.AllColumns, Track.AllColumns,
@ -212,10 +200,24 @@ func TestJoinEverything(t *testing.T) {
Customer.AllColumns, Customer.AllColumns,
Employee.AllColumns, Employee.AllColumns,
manager.AllColumns, manager.AllColumns,
). ).FROM(
ORDER_BY(Artist.ArtistId, Album.AlbumId, Track.TrackId, Artist.
LEFT_JOIN(Album, Artist.ArtistId.EQ(Album.ArtistId)).
LEFT_JOIN(Track, Track.AlbumId.EQ(Album.AlbumId)).
LEFT_JOIN(Genre, Genre.GenreId.EQ(Track.GenreId)).
LEFT_JOIN(MediaType, MediaType.MediaTypeId.EQ(Track.MediaTypeId)).
LEFT_JOIN(PlaylistTrack, PlaylistTrack.TrackId.EQ(Track.TrackId)).
LEFT_JOIN(Playlist, Playlist.PlaylistId.EQ(PlaylistTrack.PlaylistId)).
LEFT_JOIN(InvoiceLine, InvoiceLine.TrackId.EQ(Track.TrackId)).
LEFT_JOIN(Invoice, Invoice.InvoiceId.EQ(InvoiceLine.InvoiceId)).
LEFT_JOIN(Customer, Customer.CustomerId.EQ(Invoice.CustomerId)).
LEFT_JOIN(Employee, Employee.EmployeeId.EQ(Customer.SupportRepId)).
LEFT_JOIN(manager, manager.EmployeeId.EQ(Employee.ReportsTo)),
).ORDER_BY(
Artist.ArtistId, Album.AlbumId, Track.TrackId,
Genre.GenreId, MediaType.MediaTypeId, Playlist.PlaylistId, Genre.GenreId, MediaType.MediaTypeId, Playlist.PlaylistId,
Invoice.InvoiceId, Customer.CustomerId) Invoice.InvoiceId, Customer.CustomerId,
)
var dest []struct { //list of all artist var dest []struct { //list of all artist
model.Artist model.Artist
@ -398,11 +400,11 @@ FROM (
SELECT "subQuery1"."Artist.ArtistId" AS "Artist.ArtistId", SELECT "subQuery1"."Artist.ArtistId" AS "Artist.ArtistId",
"subQuery1"."Artist.Name" AS "Artist.Name", "subQuery1"."Artist.Name" AS "Artist.Name",
"subQuery1".custom_column_1 AS "custom_column_1", "subQuery1".custom_column_1 AS "custom_column_1",
$1 AS "custom_column_2" $1::text AS "custom_column_2"
FROM ( FROM (
SELECT "Artist"."ArtistId" AS "Artist.ArtistId", SELECT "Artist"."ArtistId" AS "Artist.ArtistId",
"Artist"."Name" AS "Artist.Name", "Artist"."Name" AS "Artist.Name",
$2 AS "custom_column_1" $2::text AS "custom_column_1"
FROM chinook."Artist" FROM chinook."Artist"
ORDER BY "Artist"."ArtistId" ASC ORDER BY "Artist"."ArtistId" ASC
) AS "subQuery1" ) AS "subQuery1"
@ -721,11 +723,14 @@ ORDER BY "Album.AlbumId";
} }
func TestQueryWithContext(t *testing.T) { func TestQueryWithContext(t *testing.T) {
if sourceIsCockroachDB() && !isPgxDriver() {
return // context cancellation doesn't work for pq driver
}
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel() defer cancel()
dest := []model.Album{} var dest []model.Album
err := Album. err := Album.
CROSS_JOIN(Track). CROSS_JOIN(Track).
@ -737,6 +742,9 @@ func TestQueryWithContext(t *testing.T) {
} }
func TestExecWithContext(t *testing.T) { func TestExecWithContext(t *testing.T) {
if sourceIsCockroachDB() && !isPgxDriver() {
return // context cancellation doesn't work for pq driver
}
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel() defer cancel()
@ -828,10 +836,12 @@ func Test_SchemaRename(t *testing.T) {
albumArtistID := Album2.ArtistId.From(first10Albums) albumArtistID := Album2.ArtistId.From(first10Albums)
stmt := SELECT(first10Artist.AllColumns(), first10Albums.AllColumns()). stmt := SELECT(
FROM(first10Artist. first10Artist.AllColumns(),
INNER_JOIN(first10Albums, artistID.EQ(albumArtistID))). first10Albums.AllColumns(),
ORDER_BY(artistID) ).FROM(first10Artist.
INNER_JOIN(first10Albums, artistID.EQ(albumArtistID)),
).ORDER_BY(artistID)
testutils.AssertDebugStatementSql(t, stmt, ` testutils.AssertDebugStatementSql(t, stmt, `
SELECT "first10Artist"."Artist.ArtistId" AS "Artist.ArtistId", SELECT "first10Artist"."Artist.ArtistId" AS "Artist.ArtistId",
@ -891,6 +901,8 @@ var album347 = model.Album{
} }
func TestAggregateFunc(t *testing.T) { func TestAggregateFunc(t *testing.T) {
skipForCockroachDB(t)
stmt := SELECT( stmt := SELECT(
PERCENTILE_DISC(Float(0.1)).WITHIN_GROUP_ORDER_BY(Invoice.InvoiceId).AS("percentile_disc_1"), PERCENTILE_DISC(Float(0.1)).WITHIN_GROUP_ORDER_BY(Invoice.InvoiceId).AS("percentile_disc_1"),
PERCENTILE_DISC(Invoice.Total.DIV(Float(100))).WITHIN_GROUP_ORDER_BY(Invoice.InvoiceDate.ASC()).AS("percentile_disc_2"), PERCENTILE_DISC(Invoice.Total.DIV(Float(100))).WITHIN_GROUP_ORDER_BY(Invoice.InvoiceDate.ASC()).AS("percentile_disc_2"),

View file

@ -2,6 +2,7 @@ package postgres
import ( import (
"context" "context"
"database/sql"
"github.com/go-jet/jet/v2/internal/testutils" "github.com/go-jet/jet/v2/internal/testutils"
. "github.com/go-jet/jet/v2/postgres" . "github.com/go-jet/jet/v2/postgres"
model2 "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/model" model2 "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/model"
@ -14,48 +15,38 @@ import (
) )
func TestDeleteWithWhere(t *testing.T) { func TestDeleteWithWhere(t *testing.T) {
initForDeleteTest(t)
var expectedSQL = `
DELETE FROM test_sample.link
WHERE link.name IN ('Gmail', 'Outlook');
`
deleteStmt := Link. deleteStmt := Link.
DELETE(). DELETE().
WHERE(Link.Name.IN(String("Gmail"), String("Outlook"))) WHERE(Link.Name.IN(String("Gmail"), String("Outlook")))
testutils.AssertDebugStatementSql(t, deleteStmt, expectedSQL, "Gmail", "Outlook") testutils.AssertDebugStatementSql(t, deleteStmt, `
DELETE FROM test_sample.link
WHERE link.name IN ('Gmail'::text, 'Outlook'::text);
`, "Gmail", "Outlook")
res, err := deleteStmt.ExecContext(context.Background(), db) testutils.AssertExecAndRollback(t, deleteStmt, db, 2)
require.NoError(t, err)
rows, err := res.RowsAffected()
require.NoError(t, err)
require.Equal(t, rows, int64(2))
requireQueryLogged(t, deleteStmt, int64(2)) requireQueryLogged(t, deleteStmt, int64(2))
} }
func TestDeleteWithWhereAndReturning(t *testing.T) { func TestDeleteWithWhereAndReturning(t *testing.T) {
initForDeleteTest(t)
var expectedSQL = `
DELETE FROM test_sample.link
WHERE link.name IN ('Gmail', 'Outlook')
RETURNING link.id AS "link.id",
link.url AS "link.url",
link.name AS "link.name",
link.description AS "link.description";
`
deleteStmt := Link. deleteStmt := Link.
DELETE(). DELETE().
WHERE(Link.Name.IN(String("Gmail"), String("Outlook"))). WHERE(Link.Name.IN(String("Gmail"), String("Outlook"))).
RETURNING(Link.AllColumns) RETURNING(Link.AllColumns)
testutils.AssertDebugStatementSql(t, deleteStmt, expectedSQL, "Gmail", "Outlook") testutils.AssertDebugStatementSql(t, deleteStmt, `
DELETE FROM test_sample.link
WHERE link.name IN ('Gmail'::text, 'Outlook'::text)
RETURNING link.id AS "link.id",
link.url AS "link.url",
link.name AS "link.name",
link.description AS "link.description";
`, "Gmail", "Outlook")
dest := []model.Link{} testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) {
var dest []model.Link
err := deleteStmt.Query(db, &dest) err := deleteStmt.Query(tx, &dest)
require.NoError(t, err) require.NoError(t, err)
@ -63,20 +54,10 @@ RETURNING link.id AS "link.id",
testutils.AssertDeepEqual(t, dest[0].Name, "Gmail") testutils.AssertDeepEqual(t, dest[0].Name, "Gmail")
testutils.AssertDeepEqual(t, dest[1].Name, "Outlook") testutils.AssertDeepEqual(t, dest[1].Name, "Outlook")
requireLogged(t, deleteStmt) requireLogged(t, deleteStmt)
} })
func initForDeleteTest(t *testing.T) {
cleanUpLinkTable(t)
stmt := Link.INSERT(Link.URL, Link.Name, Link.Description).
VALUES("www.gmail.com", "Gmail", "Email service developed by Google").
VALUES("www.outlook.live.com", "Outlook", "Email service developed by Microsoft")
AssertExec(t, stmt, 2)
} }
func TestDeleteQueryContext(t *testing.T) { func TestDeleteQueryContext(t *testing.T) {
initForDeleteTest(t)
deleteStmt := Link. deleteStmt := Link.
DELETE(). DELETE().
WHERE(Link.Name.IN(String("Gmail"), String("Outlook"))) WHERE(Link.Name.IN(String("Gmail"), String("Outlook")))
@ -86,16 +67,16 @@ func TestDeleteQueryContext(t *testing.T) {
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) {
dest := []model.Link{} dest := []model.Link{}
err := deleteStmt.QueryContext(ctx, db, &dest) err := deleteStmt.QueryContext(ctx, tx, &dest)
require.Error(t, err, "context deadline exceeded") require.Error(t, err, "context deadline exceeded")
requireLogged(t, deleteStmt) requireLogged(t, deleteStmt)
})
} }
func TestDeleteExecContext(t *testing.T) { func TestDeleteExecContext(t *testing.T) {
initForDeleteTest(t)
list := []Expression{String("Gmail"), String("Outlook")} list := []Expression{String("Gmail"), String("Outlook")}
deleteStmt := Link. deleteStmt := Link.
@ -107,15 +88,16 @@ func TestDeleteExecContext(t *testing.T) {
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
_, err := deleteStmt.ExecContext(ctx, db) testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) {
_, err := deleteStmt.ExecContext(ctx, tx)
require.Error(t, err, "context deadline exceeded") require.Error(t, err, "context deadline exceeded")
requireLogged(t, deleteStmt) requireLogged(t, deleteStmt)
})
} }
func TestDeleteFrom(t *testing.T) { func TestDeleteFrom(t *testing.T) {
tx := beginTx(t) skipForCockroachDB(t) // USING is not supported
defer tx.Rollback()
stmt := table.Rental.DELETE(). stmt := table.Rental.DELETE().
USING( USING(
@ -158,6 +140,7 @@ RETURNING rental.rental_id AS "rental.rental_id",
store.last_update AS "store.last_update"; store.last_update AS "store.last_update";
`) `)
testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) {
var dest []struct { var dest []struct {
Rental model2.Rental Rental model2.Rental
Store model2.Store Store model2.Store
@ -186,4 +169,5 @@ RETURNING rental.rental_id AS "rental.rental_id",
} }
} }
`) `)
})
} }

View file

@ -497,6 +497,8 @@ func newActorInfoTableImpl(schemaName, tableName, alias string) actorInfoTable {
` `
func TestGeneratedAllTypesSQLBuilderFiles(t *testing.T) { func TestGeneratedAllTypesSQLBuilderFiles(t *testing.T) {
skipForCockroachDB(t)
enumDir := filepath.Join(testRoot, "/.gentestdata/jetdb/test_sample/enum/") enumDir := filepath.Join(testRoot, "/.gentestdata/jetdb/test_sample/enum/")
modelDir := filepath.Join(testRoot, "/.gentestdata/jetdb/test_sample/model/") modelDir := filepath.Join(testRoot, "/.gentestdata/jetdb/test_sample/model/")
tableDir := filepath.Join(testRoot, "/.gentestdata/jetdb/test_sample/table/") tableDir := filepath.Join(testRoot, "/.gentestdata/jetdb/test_sample/table/")

View file

@ -2,6 +2,7 @@ package postgres
import ( import (
"context" "context"
"database/sql"
"github.com/go-jet/jet/v2/internal/testutils" "github.com/go-jet/jet/v2/internal/testutils"
. "github.com/go-jet/jet/v2/postgres" . "github.com/go-jet/jet/v2/postgres"
"github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/test_sample/model" "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/test_sample/model"
@ -13,9 +14,13 @@ import (
) )
func TestInsertValues(t *testing.T) { func TestInsertValues(t *testing.T) {
cleanUpLinkTable(t) insertQuery := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description).
VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT).
VALUES(101, "http://www.google.com", "Google", DEFAULT).
VALUES(102, "http://www.yahoo.com", "Yahoo", nil).
RETURNING(Link.AllColumns)
var expectedSQL = ` testutils.AssertDebugStatementSql(t, insertQuery, `
INSERT INTO test_sample.link (id, url, name, description) INSERT INTO test_sample.link (id, url, name, description)
VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT), VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT),
(101, 'http://www.google.com', 'Google', DEFAULT), (101, 'http://www.google.com', 'Google', DEFAULT),
@ -24,76 +29,61 @@ RETURNING link.id AS "link.id",
link.url AS "link.url", link.url AS "link.url",
link.name AS "link.name", link.name AS "link.name",
link.description AS "link.description"; link.description AS "link.description";
` `,
insertQuery := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description).
VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT).
VALUES(101, "http://www.google.com", "Google", DEFAULT).
VALUES(102, "http://www.yahoo.com", "Yahoo", nil).
RETURNING(Link.AllColumns)
testutils.AssertDebugStatementSql(t, insertQuery, expectedSQL,
100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", 100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial",
101, "http://www.google.com", "Google", 101, "http://www.google.com", "Google",
102, "http://www.yahoo.com", "Yahoo", nil) 102, "http://www.yahoo.com", "Yahoo", nil)
insertedLinks := []model.Link{} testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) {
var insertedLinks []model.Link
err := insertQuery.Query(db, &insertedLinks) err := insertQuery.Query(tx, &insertedLinks)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, len(insertedLinks), 3) require.Equal(t, len(insertedLinks), 3)
testutils.AssertDeepEqual(t, insertedLinks[0], model.Link{ testutils.AssertDeepEqual(t, insertedLinks[0], model.Link{
ID: 100, ID: 100,
URL: "http://www.postgresqltutorial.com", URL: "http://www.postgresqltutorial.com",
Name: "PostgreSQL Tutorial", Name: "PostgreSQL Tutorial",
}) })
testutils.AssertDeepEqual(t, insertedLinks[1], model.Link{ testutils.AssertDeepEqual(t, insertedLinks[1], model.Link{
ID: 101, ID: 101,
URL: "http://www.google.com", URL: "http://www.google.com",
Name: "Google", Name: "Google",
}) })
testutils.AssertDeepEqual(t, insertedLinks[2], model.Link{ testutils.AssertDeepEqual(t, insertedLinks[2], model.Link{
ID: 102, ID: 102,
URL: "http://www.yahoo.com", URL: "http://www.yahoo.com",
Name: "Yahoo", Name: "Yahoo",
}) })
allLinks := []model.Link{} var allLinks []model.Link
err = Link.SELECT(Link.AllColumns). err = Link.SELECT(Link.AllColumns).
WHERE(Link.ID.GT_EQ(Int(100))). WHERE(Link.ID.BETWEEN(Int(100), Int(199))).
ORDER_BY(Link.ID). ORDER_BY(Link.ID).
Query(db, &allLinks) Query(tx, &allLinks)
require.NoError(t, err) require.NoError(t, err)
testutils.AssertDeepEqual(t, insertedLinks, allLinks) testutils.AssertDeepEqual(t, insertedLinks, allLinks)
})
} }
func TestInsertEmptyColumnList(t *testing.T) { func TestInsertEmptyColumnList(t *testing.T) {
cleanUpLinkTable(t)
expectedSQL := `
INSERT INTO test_sample.link
VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT);
`
stmt := Link.INSERT(). stmt := Link.INSERT().
VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT) VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT)
testutils.AssertDebugStatementSql(t, stmt, expectedSQL, testutils.AssertDebugStatementSql(t, stmt, `
INSERT INTO test_sample.link
VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT);
`,
100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial") 100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial")
AssertExec(t, stmt, 1) testutils.AssertExecAndRollback(t, stmt, db, 1)
requireLogged(t, stmt) requireLogged(t, stmt)
} }
func TestInsertOnConflict(t *testing.T) { func TestInsertOnConflict(t *testing.T) {
t.Run("do nothing", func(t *testing.T) { t.Run("do nothing", func(t *testing.T) {
employee := model.Employee{EmployeeID: rand.Int31()} employee := model.Employee{EmployeeID: rand.Int31()}
@ -108,11 +98,12 @@ VALUES ($1, $2, $3, $4, $5),
($6, $7, $8, $9, $10) ($6, $7, $8, $9, $10)
ON CONFLICT (employee_id) DO NOTHING; ON CONFLICT (employee_id) DO NOTHING;
`) `)
AssertExec(t, stmt, 1) testutils.AssertExecAndRollback(t, stmt, db, 1)
requireLogged(t, stmt) requireLogged(t, stmt)
}) })
t.Run("on constraint do nothing", func(t *testing.T) { t.Run("on constraint do nothing", func(t *testing.T) {
skipForCockroachDB(t) // does not support
employee := model.Employee{EmployeeID: rand.Int31()} employee := model.Employee{EmployeeID: rand.Int31()}
stmt := Employee.INSERT(Employee.AllColumns). stmt := Employee.INSERT(Employee.AllColumns).
@ -126,12 +117,11 @@ VALUES ($1, $2, $3, $4, $5),
($6, $7, $8, $9, $10) ($6, $7, $8, $9, $10)
ON CONFLICT ON CONSTRAINT employee_pkey DO NOTHING; ON CONFLICT ON CONSTRAINT employee_pkey DO NOTHING;
`) `)
AssertExec(t, stmt, 1) testutils.AssertExecAndRollback(t, stmt, db, 1)
requireLogged(t, stmt) requireLogged(t, stmt)
}) })
t.Run("do update", func(t *testing.T) { t.Run("do update", func(t *testing.T) {
cleanUpLinkTable(t)
stmt := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description). stmt := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description).
VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT).
VALUES(200, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). VALUES(200, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT).
@ -148,18 +138,19 @@ VALUES ($1, $2, $3, DEFAULT),
($4, $5, $6, DEFAULT) ($4, $5, $6, DEFAULT)
ON CONFLICT (id) DO UPDATE ON CONFLICT (id) DO UPDATE
SET id = excluded.id, SET id = excluded.id,
url = $7 url = $7::text
RETURNING link.id AS "link.id", RETURNING link.id AS "link.id",
link.url AS "link.url", link.url AS "link.url",
link.name AS "link.name", link.name AS "link.name",
link.description AS "link.description"; link.description AS "link.description";
`) `)
AssertExec(t, stmt, 2) testutils.AssertExecAndRollback(t, stmt, db, 2)
}) })
t.Run("on constraint do update", func(t *testing.T) { t.Run("on constraint do update", func(t *testing.T) {
cleanUpLinkTable(t) skipForCockroachDB(t) // does not support
stmt := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description). stmt := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description).
VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT).
VALUES(200, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). VALUES(200, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT).
@ -177,18 +168,18 @@ VALUES ($1, $2, $3, DEFAULT),
($4, $5, $6, DEFAULT) ($4, $5, $6, DEFAULT)
ON CONFLICT ON CONSTRAINT link_pkey DO UPDATE ON CONFLICT ON CONSTRAINT link_pkey DO UPDATE
SET id = excluded.id, SET id = excluded.id,
url = $7 url = $7::text
RETURNING link.id AS "link.id", RETURNING link.id AS "link.id",
link.url AS "link.url", link.url AS "link.url",
link.name AS "link.name", link.name AS "link.name",
link.description AS "link.description"; link.description AS "link.description";
`) `)
AssertExec(t, stmt, 2) testutils.AssertExecAndRollback(t, stmt, db, 2)
}) })
t.Run("do update complex", func(t *testing.T) { t.Run("do update complex", func(t *testing.T) {
cleanUpLinkTable(t) skipForCockroachDB(t) // does not support ROW
stmt := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description). stmt := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description).
VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT).
@ -210,21 +201,15 @@ ON CONFLICT (id) WHERE (id * 2) > 10 DO UPDATE
SELECT MAX(link.id) + 1 SELECT MAX(link.id) + 1
FROM test_sample.link FROM test_sample.link
), ),
(name, description) = ROW(excluded.name, 'new description') (name, description) = ROW(excluded.name, 'new description'::text)
WHERE link.description IS NOT NULL; WHERE link.description IS NOT NULL;
`) `)
AssertExec(t, stmt, 1) testutils.AssertExecAndRollback(t, stmt, db, 1)
}) })
} }
func TestInsertModelObject(t *testing.T) { func TestInsertModelObject(t *testing.T) {
cleanUpLinkTable(t)
var expectedSQL = `
INSERT INTO test_sample.link (url, name)
VALUES ('http://www.duckduckgo.com', 'Duck Duck go');
`
linkData := model.Link{ linkData := model.Link{
URL: "http://www.duckduckgo.com", URL: "http://www.duckduckgo.com",
Name: "Duck Duck go", Name: "Duck Duck go",
@ -234,18 +219,15 @@ VALUES ('http://www.duckduckgo.com', 'Duck Duck go');
INSERT(Link.URL, Link.Name). INSERT(Link.URL, Link.Name).
MODEL(linkData) MODEL(linkData)
testutils.AssertDebugStatementSql(t, query, expectedSQL, "http://www.duckduckgo.com", "Duck Duck go") testutils.AssertDebugStatementSql(t, query, `
INSERT INTO test_sample.link (url, name)
VALUES ('http://www.duckduckgo.com', 'Duck Duck go');
`, "http://www.duckduckgo.com", "Duck Duck go")
AssertExec(t, query, 1) testutils.AssertExecAndRollback(t, query, db, 1)
} }
func TestInsertModelObjectEmptyColumnList(t *testing.T) { func TestInsertModelObjectEmptyColumnList(t *testing.T) {
cleanUpLinkTable(t)
var expectedSQL = `
INSERT INTO test_sample.link
VALUES (1000, 'http://www.duckduckgo.com', 'Duck Duck go', NULL);
`
linkData := model.Link{ linkData := model.Link{
ID: 1000, ID: 1000,
URL: "http://www.duckduckgo.com", URL: "http://www.duckduckgo.com",
@ -256,19 +238,16 @@ VALUES (1000, 'http://www.duckduckgo.com', 'Duck Duck go', NULL);
INSERT(). INSERT().
MODEL(linkData) MODEL(linkData)
testutils.AssertDebugStatementSql(t, query, expectedSQL, int32(1000), "http://www.duckduckgo.com", "Duck Duck go", nil) testutils.AssertDebugStatementSql(t, query, `
INSERT INTO test_sample.link
VALUES (1000, 'http://www.duckduckgo.com', 'Duck Duck go', NULL);
`,
int64(1000), "http://www.duckduckgo.com", "Duck Duck go", nil)
AssertExec(t, query, 1) testutils.AssertExecAndRollback(t, query, db, 1)
} }
func TestInsertModelsObject(t *testing.T) { func TestInsertModelsObject(t *testing.T) {
expectedSQL := `
INSERT INTO test_sample.link (url, name)
VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial'),
('http://www.google.com', 'Google'),
('http://www.yahoo.com', 'Yahoo');
`
tutorial := model.Link{ tutorial := model.Link{
URL: "http://www.postgresqltutorial.com", URL: "http://www.postgresqltutorial.com",
Name: "PostgreSQL Tutorial", Name: "PostgreSQL Tutorial",
@ -288,23 +267,20 @@ VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial'),
INSERT(Link.URL, Link.Name). INSERT(Link.URL, Link.Name).
MODELS([]model.Link{tutorial, google, yahoo}) MODELS([]model.Link{tutorial, google, yahoo})
testutils.AssertDebugStatementSql(t, stmt, expectedSQL, testutils.AssertDebugStatementSql(t, stmt, `
INSERT INTO test_sample.link (url, name)
VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial'),
('http://www.google.com', 'Google'),
('http://www.yahoo.com', 'Yahoo');
`,
"http://www.postgresqltutorial.com", "PostgreSQL Tutorial", "http://www.postgresqltutorial.com", "PostgreSQL Tutorial",
"http://www.google.com", "Google", "http://www.google.com", "Google",
"http://www.yahoo.com", "Yahoo") "http://www.yahoo.com", "Yahoo")
AssertExec(t, stmt, 3) testutils.AssertExecAndRollback(t, stmt, db, 3)
} }
func TestInsertUsingMutableColumns(t *testing.T) { func TestInsertUsingMutableColumns(t *testing.T) {
var expectedSQL = `
INSERT INTO test_sample.link (url, name, description)
VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT),
('http://www.google.com', 'Google', NULL),
('http://www.google.com', 'Google', NULL),
('http://www.yahoo.com', 'Yahoo', NULL);
`
google := model.Link{ google := model.Link{
URL: "http://www.google.com", URL: "http://www.google.com",
Name: "Google", Name: "Google",
@ -321,22 +297,32 @@ VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT),
MODEL(google). MODEL(google).
MODELS([]model.Link{google, yahoo}) MODELS([]model.Link{google, yahoo})
testutils.AssertDebugStatementSql(t, stmt, expectedSQL, testutils.AssertDebugStatementSql(t, stmt, `
INSERT INTO test_sample.link (url, name, description)
VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT),
('http://www.google.com', 'Google', NULL),
('http://www.google.com', 'Google', NULL),
('http://www.yahoo.com', 'Yahoo', NULL);
`,
"http://www.postgresqltutorial.com", "PostgreSQL Tutorial", "http://www.postgresqltutorial.com", "PostgreSQL Tutorial",
"http://www.google.com", "Google", nil, "http://www.google.com", "Google", nil,
"http://www.google.com", "Google", nil, "http://www.google.com", "Google", nil,
"http://www.yahoo.com", "Yahoo", nil) "http://www.yahoo.com", "Yahoo", nil)
AssertExec(t, stmt, 4) testutils.AssertExecAndRollback(t, stmt, db, 4)
} }
func TestInsertQuery(t *testing.T) { func TestInsertQuery(t *testing.T) {
_, err := Link.DELETE(). query := Link.
WHERE(Link.ID.NOT_EQ(Int(0)).AND(Link.Name.EQ(String("Youtube")))). INSERT(Link.URL, Link.Name).
Exec(db) QUERY(
require.NoError(t, err) SELECT(Link.URL, Link.Name).
FROM(Link).
WHERE(Link.ID.EQ(Int(0))),
).
RETURNING(Link.AllColumns)
var expectedSQL = ` testutils.AssertDebugStatementSql(t, query, `
INSERT INTO test_sample.link (url, name) ( INSERT INTO test_sample.link (url, name) (
SELECT link.url AS "link.url", SELECT link.url AS "link.url",
link.name AS "link.name" link.name AS "link.name"
@ -347,38 +333,26 @@ RETURNING link.id AS "link.id",
link.url AS "link.url", link.url AS "link.url",
link.name AS "link.name", link.name AS "link.name",
link.description AS "link.description"; link.description AS "link.description";
` `, int64(0))
query := Link. testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) {
INSERT(Link.URL, Link.Name). var dest []model.Link
QUERY(
SELECT(Link.URL, Link.Name).
FROM(Link).
WHERE(Link.ID.EQ(Int(0))),
).
RETURNING(Link.AllColumns)
testutils.AssertDebugStatementSql(t, query, expectedSQL, int64(0))
dest := []model.Link{}
err = query.Query(db, &dest)
err := query.Query(tx, &dest)
require.NoError(t, err) require.NoError(t, err)
youtubeLinks := []model.Link{} var youtubeLinks []model.Link
err = Link. err = Link.
SELECT(Link.AllColumns). SELECT(Link.AllColumns).
WHERE(Link.Name.EQ(String("Youtube"))). WHERE(Link.Name.EQ(String("Youtube"))).
Query(db, &youtubeLinks) Query(tx, &youtubeLinks)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, len(youtubeLinks), 2) require.Equal(t, len(youtubeLinks), 2)
})
} }
func TestInsertWithQueryContext(t *testing.T) { func TestInsertWithQueryContext(t *testing.T) {
cleanUpLinkTable(t)
stmt := Link.INSERT(). stmt := Link.INSERT().
VALUES(1100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). VALUES(1100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT).
RETURNING(Link.AllColumns) RETURNING(Link.AllColumns)
@ -388,15 +362,15 @@ func TestInsertWithQueryContext(t *testing.T) {
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) {
dest := []model.Link{} dest := []model.Link{}
err := stmt.QueryContext(ctx, db, &dest) err := stmt.QueryContext(ctx, tx, &dest)
require.Error(t, err, "context deadline exceeded") require.Error(t, err, "context deadline exceeded")
})
} }
func TestInsertWithExecContext(t *testing.T) { func TestInsertWithExecContext(t *testing.T) {
cleanUpLinkTable(t)
stmt := Link.INSERT(). stmt := Link.INSERT().
VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT) VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT)
@ -405,7 +379,7 @@ func TestInsertWithExecContext(t *testing.T) {
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
_, err := stmt.ExecContext(ctx, db) testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) {
testutils.AssertExecContextErr(t, stmt, ctx, tx, "context deadline exceeded")
require.Error(t, err, "context deadline exceeded") })
} }

View file

@ -12,6 +12,8 @@ import (
) )
func TestLockTable(t *testing.T) { func TestLockTable(t *testing.T) {
skipForCockroachDB(t) // doesn't support
expectedSQL := ` expectedSQL := `
LOCK TABLE dvds.address IN` LOCK TABLE dvds.address IN`
@ -62,6 +64,8 @@ LOCK TABLE dvds.address IN`
} }
func TestLockExecContext(t *testing.T) { func TestLockExecContext(t *testing.T) {
skipForCockroachDB(t)
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Microsecond) ctx, cancel := context.WithTimeout(context.Background(), 1*time.Microsecond)
defer cancel() defer cancel()

View file

@ -25,6 +25,24 @@ import (
var db *sql.DB var db *sql.DB
var testRoot string var testRoot string
var source string
const CockroachDB = "COCKROACH_DB"
func init() {
source = os.Getenv("PG_SOURCE")
}
func sourceIsCockroachDB() bool {
return source == CockroachDB
}
func skipForCockroachDB(t *testing.T) {
if sourceIsCockroachDB() {
t.SkipNow()
}
}
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
rand.Seed(time.Now().Unix()) rand.Seed(time.Now().Unix())
defer profile.Start().Stop() defer profile.Start().Stop()
@ -35,8 +53,15 @@ func TestMain(m *testing.M) {
fmt.Printf("\nRunning postgres tests for '%s' driver\n", driverName) fmt.Printf("\nRunning postgres tests for '%s' driver\n", driverName)
func() { func() {
connectionString := dbconfig.PostgresConnectString
if sourceIsCockroachDB() {
connectionString = dbconfig.CockroachConnectString
}
var err error var err error
db, err = sql.Open(driverName, dbconfig.PostgresConnectString) db, err = sql.Open(driverName, connectionString)
if err != nil { if err != nil {
fmt.Println(err.Error()) fmt.Println(err.Error())
panic("Failed to connect to test db") panic("Failed to connect to test db")
@ -113,9 +138,3 @@ func isPgxDriver() bool {
return false return false
} }
func beginTx(t *testing.T) *sql.Tx {
tx, err := db.Begin()
require.NoError(t, err)
return tx
}

View file

@ -2,6 +2,7 @@ package postgres
import ( import (
"context" "context"
"database/sql"
"testing" "testing"
"time" "time"
@ -85,12 +86,10 @@ func TestRawStatementSelectWithArguments(t *testing.T) {
} }
func TestRawInsert(t *testing.T) { func TestRawInsert(t *testing.T) {
cleanUpLinkTable(t)
stmt := RawStatement(` stmt := RawStatement(`
INSERT INTO test_sample.link (id, url, name, description) INSERT INTO test_sample.link (id, url, name, description)
VALUES (@id1, @url1, @name1, DEFAULT), VALUES (@id1, @url1, @name1, DEFAULT),
(200, @url1, @name1, NULL), (2000, @url1, @name1, NULL),
(@id2, @url2, @name2, DEFAULT), (@id2, @url2, @name2, DEFAULT),
(@id3, @url3, @name3, NULL) (@id3, @url3, @name3, NULL)
RETURNING link.id AS "link.id", RETURNING link.id AS "link.id",
@ -98,45 +97,47 @@ RETURNING link.id AS "link.id",
link.name AS "link.name", link.name AS "link.name",
link.description AS "link.description"`, link.description AS "link.description"`,
RawArgs{ RawArgs{
"@id1": 100, "@url1": "http://www.postgresqltutorial.com", "@name1": "PostgreSQL Tutorial", "@id1": 1000, "@url1": "http://www.postgresqltutorial.com", "@name1": "PostgreSQL Tutorial",
"@id2": 101, "@url2": "http://www.google.com", "@name2": "Google", "@id2": 1010, "@url2": "http://www.google.com", "@name2": "Google",
"@id3": 102, "@url3": "http://www.yahoo.com", "@name3": "Yahoo", "@id3": 1020, "@url3": "http://www.yahoo.com", "@name3": "Yahoo",
}) })
testutils.AssertStatementSql(t, stmt, ` testutils.AssertStatementSql(t, stmt, `
INSERT INTO test_sample.link (id, url, name, description) INSERT INTO test_sample.link (id, url, name, description)
VALUES ($1, $2, $3, DEFAULT), VALUES ($1, $2, $3, DEFAULT),
(200, $2, $3, NULL), (2000, $2, $3, NULL),
($4, $5, $6, DEFAULT), ($4, $5, $6, DEFAULT),
($7, $8, $9, NULL) ($7, $8, $9, NULL)
RETURNING link.id AS "link.id", RETURNING link.id AS "link.id",
link.url AS "link.url", link.url AS "link.url",
link.name AS "link.name", link.name AS "link.name",
link.description AS "link.description"; link.description AS "link.description";
`, 100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", `, 1000, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial",
101, "http://www.google.com", "Google", 1010, "http://www.google.com", "Google",
102, "http://www.yahoo.com", "Yahoo") 1020, "http://www.yahoo.com", "Yahoo")
testutils.AssertDebugStatementSql(t, stmt, ` testutils.AssertDebugStatementSql(t, stmt, `
INSERT INTO test_sample.link (id, url, name, description) INSERT INTO test_sample.link (id, url, name, description)
VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT), VALUES (1000, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT),
(200, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', NULL), (2000, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', NULL),
(101, 'http://www.google.com', 'Google', DEFAULT), (1010, 'http://www.google.com', 'Google', DEFAULT),
(102, 'http://www.yahoo.com', 'Yahoo', NULL) (1020, 'http://www.yahoo.com', 'Yahoo', NULL)
RETURNING link.id AS "link.id", RETURNING link.id AS "link.id",
link.url AS "link.url", link.url AS "link.url",
link.name AS "link.name", link.name AS "link.name",
link.description AS "link.description"; link.description AS "link.description";
`) `)
testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) {
var links []model2.Link var links []model2.Link
err := stmt.Query(db, &links) err := stmt.Query(tx, &links)
require.NoError(t, err) require.NoError(t, err)
require.Len(t, links, 4) require.Len(t, links, 4)
require.Equal(t, links[0].ID, int32(100)) require.Equal(t, links[0].ID, int64(1000))
require.Equal(t, links[1].URL, "http://www.postgresqltutorial.com") require.Equal(t, links[1].URL, "http://www.postgresqltutorial.com")
require.Equal(t, links[2].Name, "Google") require.Equal(t, links[2].Name, "Google")
require.Nil(t, links[2].Description) require.Nil(t, links[2].Description)
})
} }
func TestRawStatementRows(t *testing.T) { func TestRawStatementRows(t *testing.T) {

View file

@ -1,9 +1,9 @@
package postgres package postgres
import ( import (
"github.com/google/uuid"
"testing" "testing"
"github.com/google/uuid"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/go-jet/jet/v2/internal/testutils" "github.com/go-jet/jet/v2/internal/testutils"
@ -14,30 +14,6 @@ import (
"github.com/shopspring/decimal" "github.com/shopspring/decimal"
) )
func TestUUIDType(t *testing.T) {
id := uuid.MustParse("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11")
query := AllTypes.
SELECT(AllTypes.UUID, AllTypes.UUIDPtr).
WHERE(AllTypes.UUID.EQ(UUID(id)))
testutils.AssertDebugStatementSql(t, query, `
SELECT all_types.uuid AS "all_types.uuid",
all_types.uuid_ptr AS "all_types.uuid_ptr"
FROM test_sample.all_types
WHERE all_types.uuid = 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11';
`, "a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11")
result := model.AllTypes{}
err := query.Query(db, &result)
require.NoError(t, err)
require.Equal(t, result.UUID, uuid.MustParse("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"))
testutils.AssertDeepEqual(t, result.UUIDPtr, testutils.UUIDPtr("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"))
requireLogged(t, query)
}
func TestExactDecimals(t *testing.T) { func TestExactDecimals(t *testing.T) {
type floats struct { type floats struct {
@ -80,7 +56,7 @@ func TestExactDecimals(t *testing.T) {
t.Run("should insert decimal", func(t *testing.T) { t.Run("should insert decimal", func(t *testing.T) {
insertQuery := Floats.INSERT( insertQuery := Floats.INSERT(
Floats.AllColumns, Floats.MutableColumns,
).MODEL( ).MODEL(
floats{ floats{
Floats: model.Floats{ Floats: model.Floats{
@ -102,7 +78,7 @@ func TestExactDecimals(t *testing.T) {
DecimalPtr: decimal.RequireFromString("3.3333333333333333333"), DecimalPtr: decimal.RequireFromString("3.3333333333333333333"),
}, },
).RETURNING( ).RETURNING(
Floats.AllColumns, Floats.MutableColumns,
) )
testutils.AssertDebugStatementSql(t, insertQuery, ` testutils.AssertDebugStatementSql(t, insertQuery, `
@ -199,7 +175,9 @@ func TestUUIDComplex(t *testing.T) {
}) })
t.Run("single struct", func(t *testing.T) { t.Run("single struct", func(t *testing.T) {
singleQuery := query.WHERE(Person.PersonID.EQ(String("b68dbff6-a87d-11e9-a7f2-98ded00c39c8"))) uuid, err := uuid.Parse("b68dbff6-a87d-11e9-a7f2-98ded00c39c8")
require.NoError(t, err)
singleQuery := query.WHERE(Person.PersonID.EQ(UUID(uuid)))
var dest struct { var dest struct {
model.Person model.Person
@ -207,7 +185,7 @@ func TestUUIDComplex(t *testing.T) {
model.PersonPhone model.PersonPhone
} }
} }
err := singleQuery.Query(db, &dest) err = singleQuery.Query(db, &dest)
require.NoError(t, err) require.NoError(t, err)
testutils.AssertJSON(t, dest, ` testutils.AssertJSON(t, dest, `
@ -304,7 +282,7 @@ SELECT person.person_id AS "person.person_id",
FROM test_sample.person; FROM test_sample.person;
`) `)
result := []model.Person{} var result []model.Person
err := query.Query(db, &result) err := query.Query(db, &result)
@ -333,7 +311,7 @@ FROM test_sample.person;
`) `)
} }
func TestSelecSelfJoin1(t *testing.T) { func TestSelectSelfJoin1(t *testing.T) {
// clean up // clean up
_, err := Employee.DELETE().WHERE(Employee.EmployeeID.GT(Int(100))).Exec(db) _, err := Employee.DELETE().WHERE(Employee.EmployeeID.GT(Int(100))).Exec(db)
@ -398,7 +376,7 @@ ORDER BY employee.employee_id;
} }
func TestWierdNamesTable(t *testing.T) { func TestWierdNamesTable(t *testing.T) {
stmt := WeirdNamesTable.SELECT(WeirdNamesTable.AllColumns) stmt := WeirdNamesTable.SELECT(WeirdNamesTable.MutableColumns)
testutils.AssertDebugStatementSql(t, stmt, ` testutils.AssertDebugStatementSql(t, stmt, `
SELECT "WEIRD NAMES TABLE".weird_column_name1 AS "WEIRD NAMES TABLE.weird_column_name1", SELECT "WEIRD NAMES TABLE".weird_column_name1 AS "WEIRD NAMES TABLE.weird_column_name1",
@ -420,7 +398,7 @@ SELECT "WEIRD NAMES TABLE".weird_column_name1 AS "WEIRD NAMES TABLE.weird_column
FROM test_sample."WEIRD NAMES TABLE"; FROM test_sample."WEIRD NAMES TABLE";
`) `)
dest := []model.WeirdNamesTable{} var dest []model.WeirdNamesTable
err := stmt.Query(db, &dest) err := stmt.Query(db, &dest)
@ -448,7 +426,7 @@ FROM test_sample."WEIRD NAMES TABLE";
} }
func TestReserwedWordEscape(t *testing.T) { func TestReserwedWordEscape(t *testing.T) {
stmt := SELECT(User.AllColumns). stmt := SELECT(User.MutableColumns).
FROM(User) FROM(User)
//fmt.Println(stmt.DebugSql()) //fmt.Println(stmt.DebugSql())
@ -480,6 +458,7 @@ FROM test_sample."User";
testutils.AssertJSON(t, dest, ` testutils.AssertJSON(t, dest, `
[ [
{ {
"ID": 0,
"Column": "Column", "Column": "Column",
"Check": "CHECK", "Check": "CHECK",
"Ceil": "CEIL", "Ceil": "CEIL",
@ -497,54 +476,3 @@ FROM test_sample."User";
] ]
`) `)
} }
func TestBytea(t *testing.T) {
byteArrHex := "\\x48656c6c6f20476f7068657221"
byteArrBin := []byte("\x48\x65\x6c\x6c\x6f\x20\x47\x6f\x70\x68\x65\x72\x21")
insertStmt := AllTypes.INSERT(AllTypes.Bytea, AllTypes.ByteaPtr).
VALUES(byteArrHex, byteArrBin).
RETURNING(AllTypes.Bytea, AllTypes.ByteaPtr)
testutils.AssertStatementSql(t, insertStmt, `
INSERT INTO test_sample.all_types (bytea, bytea_ptr)
VALUES ($1, $2)
RETURNING all_types.bytea AS "all_types.bytea",
all_types.bytea_ptr AS "all_types.bytea_ptr";
`, byteArrHex, byteArrBin)
var inserted model.AllTypes
err := insertStmt.Query(db, &inserted)
require.NoError(t, err)
require.Equal(t, string(*inserted.ByteaPtr), "Hello Gopher!")
// It is not possible to initiate bytea column using hex format '\xDEADBEEF' with pq driver.
// pq driver always encodes parameter string if destination column is of type bytea.
// Probably pq driver error.
// require.Equal(t, string(inserted.Bytea), "Hello Gopher!")
stmt := SELECT(
AllTypes.Bytea,
AllTypes.ByteaPtr,
).FROM(
AllTypes,
).WHERE(
AllTypes.ByteaPtr.EQ(Bytea(byteArrBin)),
)
testutils.AssertStatementSql(t, stmt, `
SELECT all_types.bytea AS "all_types.bytea",
all_types.bytea_ptr AS "all_types.bytea_ptr"
FROM test_sample.all_types
WHERE all_types.bytea_ptr = $1::bytea;
`, byteArrBin)
var dest model.AllTypes
err = stmt.Query(db, &dest)
require.NoError(t, err)
require.Equal(t, string(*dest.ByteaPtr), "Hello Gopher!")
// Probably pq driver error.
// require.Equal(t, string(dest.Bytea), "Hello Gopher!")
}

View file

@ -416,8 +416,8 @@ FROM dvds.city
INNER JOIN dvds.address ON (address.city_id = city.city_id) INNER JOIN dvds.address ON (address.city_id = city.city_id)
INNER JOIN dvds.customer ON (customer.address_id = address.address_id) INNER JOIN dvds.customer ON (customer.address_id = address.address_id)
WHERE ( WHERE (
(city.city = 'London') (city.city = 'London'::text)
OR (city.city = 'York') OR (city.city = 'York'::text)
) )
ORDER BY city.city_id, address.address_id, customer.customer_id; ORDER BY city.city_id, address.address_id, customer.customer_id;
`, "London", "York") `, "London", "York")
@ -492,7 +492,7 @@ SELECT city.city_id AS "my_city.id",
FROM dvds.city FROM dvds.city
INNER JOIN dvds.address ON (address.city_id = city.city_id) INNER JOIN dvds.address ON (address.city_id = city.city_id)
INNER JOIN dvds.customer ON (customer.address_id = address.address_id) INNER JOIN dvds.customer ON (customer.address_id = address.address_id)
WHERE (city.city = 'London') OR (city.city = 'York') WHERE (city.city = 'London'::text) OR (city.city = 'York'::text)
ORDER BY city.city_id, address.address_id, customer.customer_id; ORDER BY city.city_id, address.address_id, customer.customer_id;
`, "London", "York") `, "London", "York")
@ -550,7 +550,7 @@ SELECT city.city_id AS "city_id",
FROM dvds.city FROM dvds.city
INNER JOIN dvds.address ON (address.city_id = city.city_id) INNER JOIN dvds.address ON (address.city_id = city.city_id)
INNER JOIN dvds.customer ON (customer.address_id = address.address_id) INNER JOIN dvds.customer ON (customer.address_id = address.address_id)
WHERE (city.city = 'London') OR (city.city = 'York') WHERE (city.city = 'London'::text) OR (city.city = 'York'::text)
ORDER BY city.city_id, address.address_id, customer.customer_id; ORDER BY city.city_id, address.address_id, customer.customer_id;
`, "London", "York") `, "London", "York")
@ -607,7 +607,7 @@ SELECT city.city_id AS "city.city_id",
FROM dvds.city FROM dvds.city
INNER JOIN dvds.address ON (address.city_id = city.city_id) INNER JOIN dvds.address ON (address.city_id = city.city_id)
INNER JOIN dvds.customer ON (customer.address_id = address.address_id) INNER JOIN dvds.customer ON (customer.address_id = address.address_id)
WHERE (city.city = 'London') OR (city.city = 'York') WHERE (city.city = 'London'::text) OR (city.city = 'York'::text)
ORDER BY city.city_id, address.address_id, customer.customer_id; ORDER BY city.city_id, address.address_id, customer.customer_id;
`, "London", "York") `, "London", "York")
@ -685,9 +685,6 @@ func TestSelect_WithoutUniqueColumnSelected(t *testing.T) {
err := query.Query(db, &customers) err := query.Query(db, &customers)
require.NoError(t, err) require.NoError(t, err)
//spew.Dump(customers)
require.Equal(t, len(customers), 599) require.Equal(t, len(customers), 599)
} }
@ -770,27 +767,35 @@ ORDER BY customer.customer_id ASC;
testutils.AssertDebugStatementSql(t, query, expectedSQL) testutils.AssertDebugStatementSql(t, query, expectedSQL)
allCustomersAndAddress := []struct { var allCustomersAndAddress []struct {
Address *model.Address Address *model.Address
Customer *model.Customer Customer *model.Customer
}{} }
err := query.Query(db, &allCustomersAndAddress) err := query.Query(db, &allCustomersAndAddress)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, len(allCustomersAndAddress), 603) require.Equal(t, len(allCustomersAndAddress), 603)
if sourceIsCockroachDB() {
nullsFirst := allCustomersAndAddress[0]
require.True(t, nullsFirst.Customer == nil)
require.True(t, nullsFirst.Address != nil)
testutils.AssertDeepEqual(t, allCustomersAndAddress[4].Customer, &customer0)
require.True(t, allCustomersAndAddress[0].Address != nil)
} else { // postgres
testutils.AssertDeepEqual(t, allCustomersAndAddress[0].Customer, &customer0) testutils.AssertDeepEqual(t, allCustomersAndAddress[0].Customer, &customer0)
require.True(t, allCustomersAndAddress[0].Address != nil) require.True(t, allCustomersAndAddress[0].Address != nil)
lastCustomerAddress := allCustomersAndAddress[len(allCustomersAndAddress)-1] nullsLast := allCustomersAndAddress[len(allCustomersAndAddress)-1]
require.True(t, nullsLast.Customer == nil)
require.True(t, lastCustomerAddress.Customer == nil) require.True(t, nullsLast.Address != nil)
require.True(t, lastCustomerAddress.Address != nil) }
} }
func TestSelectFullCrossJoin(t *testing.T) { func TestSelectCrossJoin(t *testing.T) {
expectedSQL := ` expectedSQL := `
SELECT customer.customer_id AS "customer.customer_id", SELECT customer.customer_id AS "customer.customer_id",
customer.store_id AS "customer.store_id", customer.store_id AS "customer.store_id",
@ -1128,6 +1133,7 @@ ORDER BY film.film_id ASC;
} }
func TestSelectGroupByHaving(t *testing.T) { func TestSelectGroupByHaving(t *testing.T) {
expectedSQL := ` expectedSQL := `
SELECT customer.customer_id AS "customer.customer_id", SELECT customer.customer_id AS "customer.customer_id",
customer.store_id AS "customer.store_id", customer.store_id AS "customer.store_id",
@ -1197,6 +1203,9 @@ ORDER BY customer.customer_id, SUM(payment.amount) ASC;
require.Equal(t, len(dest), 104) require.Equal(t, len(dest), 104)
if sourceIsCockroachDB() {
return // small precision difference in result
}
//testutils.SaveJsonFile(dest, "postgres/testdata/customer_payment_sum.json") //testutils.SaveJsonFile(dest, "postgres/testdata/customer_payment_sum.json")
testutils.AssertJSONFile(t, dest, "./testdata/results/postgres/customer_payment_sum.json") testutils.AssertJSONFile(t, dest, "./testdata/results/postgres/customer_payment_sum.json")
} }
@ -1395,9 +1404,6 @@ ORDER BY payment.payment_date ASC;
err := query.Query(db, &payments) err := query.Query(db, &payments)
require.NoError(t, err) require.NoError(t, err)
//spew.Dump(payments)
require.Equal(t, len(payments), 9) require.Equal(t, len(payments), 9)
testutils.AssertDeepEqual(t, payments[0], model.Payment{ testutils.AssertDeepEqual(t, payments[0], model.Payment{
PaymentID: 17793, PaymentID: 17793,
@ -1531,7 +1537,7 @@ func TestAllSetOperators(t *testing.T) {
func TestSelectWithCase(t *testing.T) { func TestSelectWithCase(t *testing.T) {
expectedQuery := ` expectedQuery := `
SELECT (CASE payment.staff_id WHEN 1 THEN 'ONE' WHEN 2 THEN 'TWO' WHEN 3 THEN 'THREE' ELSE 'OTHER' END) AS "staff_id_num" SELECT (CASE payment.staff_id WHEN 1 THEN 'ONE'::text WHEN 2 THEN 'TWO'::text WHEN 3 THEN 'THREE'::text ELSE 'OTHER'::text END) AS "staff_id_num"
FROM dvds.payment FROM dvds.payment
ORDER BY payment.payment_id ASC ORDER BY payment.payment_id ASC
LIMIT 20; LIMIT 20;
@ -1611,6 +1617,10 @@ FOR`
require.NoError(t, err) require.NoError(t, err)
} }
if sourceIsCockroachDB() {
return // SKIP LOCKED lock wait policy is not supported
}
for lockType, lockTypeStr := range getRowLockTestData() { for lockType, lockTypeStr := range getRowLockTestData() {
query.FOR(lockType.SKIP_LOCKED()) query.FOR(lockType.SKIP_LOCKED())
@ -1660,7 +1670,7 @@ FROM dvds.actor
INNER JOIN dvds.language ON (language.language_id = film.language_id) INNER JOIN dvds.language ON (language.language_id = film.language_id)
INNER JOIN dvds.film_category ON (film_category.film_id = film.film_id) INNER JOIN dvds.film_category ON (film_category.film_id = film.film_id)
INNER JOIN dvds.category ON (category.category_id = film_category.category_id) INNER JOIN dvds.category ON (category.category_id = film_category.category_id)
WHERE ((language.name = 'English') AND (category.name != 'Action')) AND (film.length > 180) WHERE ((language.name = 'English'::text) AND (category.name != 'Action'::text)) AND (film.length > 180)
ORDER BY actor.actor_id ASC, film.film_id ASC; ORDER BY actor.actor_id ASC, film.film_id ASC;
` `
@ -1927,10 +1937,11 @@ func TestSimpleView(t *testing.T) {
query := SELECT( query := SELECT(
view.ActorInfo.AllColumns, view.ActorInfo.AllColumns,
). ).FROM(
FROM(view.ActorInfo). view.ActorInfo,
ORDER_BY(view.ActorInfo.ActorID). ).ORDER_BY(
LIMIT(10) view.ActorInfo.ActorID,
).LIMIT(10)
type ActorInfo struct { type ActorInfo struct {
ActorID int ActorID int
@ -1944,6 +1955,10 @@ func TestSimpleView(t *testing.T) {
err := query.Query(db, &dest) err := query.Query(db, &dest)
require.NoError(t, err) require.NoError(t, err)
if sourceIsCockroachDB() {
return // skip for cockroach db, FilmInfo is set to '' in ddl
}
testutils.AssertJSON(t, dest[1:2], ` testutils.AssertJSON(t, dest[1:2], `
[ [
{ {
@ -2117,7 +2132,7 @@ FROM dvds.film
language.name AS "language.name", language.name AS "language.name",
language.last_update AS "language.last_update" language.last_update AS "language.last_update"
FROM dvds.language FROM dvds.language
WHERE (language.name NOT IN ('spanish')) AND (film.language_id = language.language_id) WHERE (language.name NOT IN ('spanish'::text)) AND (film.language_id = language.language_id)
) AS films ) AS films
WHERE film.film_id = 1 WHERE film.film_id = 1
ORDER BY film.film_id ORDER BY film.film_id
@ -2162,7 +2177,7 @@ FROM dvds.film,
language.name AS "language.name", language.name AS "language.name",
language.last_update AS "language.last_update" language.last_update AS "language.last_update"
FROM dvds.language FROM dvds.language
WHERE (language.name NOT IN ('spanish')) AND (film.language_id = language.language_id) WHERE (language.name NOT IN ('spanish'::text)) AND (film.language_id = language.language_id)
) AS films ) AS films
WHERE film.film_id = 1 WHERE film.film_id = 1
ORDER BY film.film_id ORDER BY film.film_id
@ -2630,6 +2645,8 @@ func GET_FILM_COUNT(lenFrom, lenTo IntegerExpression) IntegerExpression {
} }
func TestCustomFunctionCall(t *testing.T) { func TestCustomFunctionCall(t *testing.T) {
skipForCockroachDB(t)
stmt := SELECT( stmt := SELECT(
GET_FILM_COUNT(Int(100), Int(120)).AS("film_count"), GET_FILM_COUNT(Int(100), Int(120)).AS("film_count"),
) )
@ -2662,3 +2679,42 @@ SELECT dvds.get_film_count(100, 120) AS "film_count";
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, dest.FilmCount, 165) require.Equal(t, dest.FilmCount, 165)
} }
var customer0 = model.Customer{
CustomerID: 1,
StoreID: 1,
FirstName: "Mary",
LastName: "Smith",
Email: testutils.StringPtr("mary.smith@sakilacustomer.org"),
AddressID: 5,
Activebool: true,
CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0),
LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 3),
Active: testutils.Int32Ptr(1),
}
var customer1 = model.Customer{
CustomerID: 2,
StoreID: 1,
FirstName: "Patricia",
LastName: "Johnson",
Email: testutils.StringPtr("patricia.johnson@sakilacustomer.org"),
AddressID: 6,
Activebool: true,
CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0),
LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 3),
Active: testutils.Int32Ptr(1),
}
var lastCustomer = model.Customer{
CustomerID: 599,
StoreID: 2,
FirstName: "Austin",
LastName: "Cintron",
Email: testutils.StringPtr("austin.cintron@sakilacustomer.org"),
AddressID: 605,
Activebool: true,
CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0),
LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 3),
Active: testutils.Int32Ptr(1),
}

View file

@ -2,6 +2,7 @@ package postgres
import ( import (
"context" "context"
"database/sql"
"github.com/go-jet/jet/v2/internal/testutils" "github.com/go-jet/jet/v2/internal/testutils"
. "github.com/go-jet/jet/v2/postgres" . "github.com/go-jet/jet/v2/postgres"
model2 "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/model" model2 "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/model"
@ -14,9 +15,7 @@ import (
) )
func TestUpdateValues(t *testing.T) { func TestUpdateValues(t *testing.T) {
setupLinkTableForUpdateTest(t) t.Run("deprecated update", func(t *testing.T) {
t.Run("deprecated version", func(t *testing.T) {
query := Link. query := Link.
UPDATE(Link.Name, Link.URL). UPDATE(Link.Name, Link.URL).
SET("Bong", "http://bong.com"). SET("Bong", "http://bong.com").
@ -25,19 +24,21 @@ func TestUpdateValues(t *testing.T) {
testutils.AssertDebugStatementSql(t, query, ` testutils.AssertDebugStatementSql(t, query, `
UPDATE test_sample.link UPDATE test_sample.link
SET (name, url) = ('Bong', 'http://bong.com') SET (name, url) = ('Bong', 'http://bong.com')
WHERE link.name = 'Bing'; WHERE link.name = 'Bing'::text;
`, "Bong", "http://bong.com", "Bing") `, "Bong", "http://bong.com", "Bing")
testutils.AssertExec(t, query, db, 1) testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) {
testutils.AssertExec(t, query, tx, 1)
requireLogged(t, query) requireLogged(t, query)
links := []model.Link{} var links []model.Link
selQuery := Link. selQuery := Link.
SELECT(Link.AllColumns). SELECT(Link.AllColumns).
WHERE(Link.Name.IN(String("Bong"))) WHERE(Link.Name.IN(String("Bong")))
err := selQuery.Query(db, &links) err := selQuery.Query(tx, &links)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, len(links), 1) require.Equal(t, len(links), 1)
@ -48,8 +49,9 @@ WHERE link.name = 'Bing';
}) })
requireLogged(t, selQuery) requireLogged(t, selQuery)
}) })
})
t.Run("new version", func(t *testing.T) { t.Run("new type safe update", func(t *testing.T) {
stmt := Link.UPDATE(). stmt := Link.UPDATE().
SET( SET(
Link.Name.SET(String("DuckDuckGo")), Link.Name.SET(String("DuckDuckGo")),
@ -59,18 +61,16 @@ WHERE link.name = 'Bing';
testutils.AssertDebugStatementSql(t, stmt, ` testutils.AssertDebugStatementSql(t, stmt, `
UPDATE test_sample.link UPDATE test_sample.link
SET name = 'DuckDuckGo', SET name = 'DuckDuckGo'::text,
url = 'www.duckduckgo.com' url = 'www.duckduckgo.com'::text
WHERE link.name = 'Yahoo'; WHERE link.name = 'Yahoo'::text;
`) `)
testutils.AssertExec(t, stmt, db, 1) testutils.AssertExecAndRollback(t, stmt, db, 1)
requireLogged(t, stmt) requireLogged(t, stmt)
}) })
} }
func TestUpdateWithSubQueries(t *testing.T) { func TestUpdateWithSubQueries(t *testing.T) {
setupLinkTableForUpdateTest(t)
t.Run("deprecated version", func(t *testing.T) { t.Run("deprecated version", func(t *testing.T) {
query := Link. query := Link.
UPDATE(Link.Name, Link.URL). UPDATE(Link.Name, Link.URL).
@ -82,20 +82,19 @@ func TestUpdateWithSubQueries(t *testing.T) {
). ).
WHERE(Link.Name.EQ(String("Bing"))) WHERE(Link.Name.EQ(String("Bing")))
expectedSQL := ` testutils.AssertDebugStatementSql(t, query, `
UPDATE test_sample.link UPDATE test_sample.link
SET (name, url) = (( SET (name, url) = ((
SELECT 'Bong' SELECT 'Bong'::text
), ( ), (
SELECT link.url AS "link.url" SELECT link.url AS "link.url"
FROM test_sample.link FROM test_sample.link
WHERE link.name = 'Bing' WHERE link.name = 'Bing'::text
)) ))
WHERE link.name = 'Bing'; WHERE link.name = 'Bing'::text;
` `, "Bong", "Bing", "Bing")
testutils.AssertDebugStatementSql(t, query, expectedSQL, "Bong", "Bing", "Bing")
AssertExec(t, query, 1) testutils.AssertExecAndRollback(t, query, db, 1)
requireLogged(t, query) requireLogged(t, query)
}) })
@ -113,50 +112,48 @@ WHERE link.name = 'Bing';
testutils.AssertStatementSql(t, query, ` testutils.AssertStatementSql(t, query, `
UPDATE test_sample.link UPDATE test_sample.link
SET name = $1, SET name = $1::text,
url = ( url = (
SELECT link.url AS "link.url" SELECT link.url AS "link.url"
FROM test_sample.link FROM test_sample.link
WHERE link.name = $2 WHERE link.name = $2::text
) )
WHERE link.name = $3; WHERE link.name = $3::text;
`, "Bong", "Bing", "Bing") `, "Bong", "Bing", "Bing")
_, err := query.Exec(db) testutils.AssertExecAndRollback(t, query, db)
require.NoError(t, err)
requireLogged(t, query) requireLogged(t, query)
}) })
} }
func TestUpdateAndReturning(t *testing.T) { func TestUpdateAndReturning(t *testing.T) {
setupLinkTableForUpdateTest(t)
expectedSQL := `
UPDATE test_sample.link
SET (name, url) = ('DuckDuckGo', 'http://www.duckduckgo.com')
WHERE link.name = 'Ask'
RETURNING link.id AS "link.id",
link.url AS "link.url",
link.name AS "link.name",
link.description AS "link.description";
`
stmt := Link. stmt := Link.
UPDATE(Link.Name, Link.URL). UPDATE(Link.Name, Link.URL).
SET("DuckDuckGo", "http://www.duckduckgo.com"). SET("DuckDuckGo", "http://www.duckduckgo.com").
WHERE(Link.Name.EQ(String("Ask"))). WHERE(Link.Name.EQ(String("Ask"))).
RETURNING(Link.AllColumns) RETURNING(Link.AllColumns)
testutils.AssertDebugStatementSql(t, stmt, expectedSQL, "DuckDuckGo", "http://www.duckduckgo.com", "Ask") testutils.AssertDebugStatementSql(t, stmt, `
UPDATE test_sample.link
SET (name, url) = ('DuckDuckGo', 'http://www.duckduckgo.com')
WHERE link.name = 'Ask'::text
RETURNING link.id AS "link.id",
link.url AS "link.url",
link.name AS "link.name",
link.description AS "link.description";
`, "DuckDuckGo", "http://www.duckduckgo.com", "Ask")
testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) {
links := []model.Link{} links := []model.Link{}
err := stmt.Query(db, &links) err := stmt.Query(tx, &links)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, len(links), 2) require.Equal(t, len(links), 2)
require.Equal(t, links[0].Name, "DuckDuckGo") require.Equal(t, links[0].Name, "DuckDuckGo")
require.Equal(t, links[1].Name, "DuckDuckGo") require.Equal(t, links[1].Name, "DuckDuckGo")
requireLogged(t, stmt) requireLogged(t, stmt)
})
} }
func TestUpdateWithSelect(t *testing.T) { func TestUpdateWithSelect(t *testing.T) {
@ -170,7 +167,7 @@ func TestUpdateWithSelect(t *testing.T) {
). ).
WHERE(Link.ID.EQ(Int(0))) WHERE(Link.ID.EQ(Int(0)))
expectedSQL := ` testutils.AssertDebugStatementSql(t, stmt, `
UPDATE test_sample.link UPDATE test_sample.link
SET (id, url, name, description) = ( SET (id, url, name, description) = (
SELECT link.id AS "link.id", SELECT link.id AS "link.id",
@ -181,10 +178,9 @@ SET (id, url, name, description) = (
WHERE link.id = 0 WHERE link.id = 0
) )
WHERE link.id = 0; WHERE link.id = 0;
` `, int64(0), int64(0))
testutils.AssertDebugStatementSql(t, stmt, expectedSQL, int64(0), int64(0))
AssertExec(t, stmt, 1) testutils.AssertExecAndRollback(t, stmt, db, 1)
}) })
t.Run("new version", func(t *testing.T) { t.Run("new version", func(t *testing.T) {
@ -210,12 +206,11 @@ SET (url, name, description) = (
WHERE link.id = 0; WHERE link.id = 0;
`, int64(0), int64(0)) `, int64(0), int64(0))
AssertExec(t, stmt, 1) testutils.AssertExecAndRollback(t, stmt, db, 1)
}) })
} }
func TestUpdateWithInvalidSelect(t *testing.T) { func TestUpdateWithInvalidSelect(t *testing.T) {
t.Run("deprecated version", func(t *testing.T) { t.Run("deprecated version", func(t *testing.T) {
stmt := Link.UPDATE(Link.AllColumns). stmt := Link.UPDATE(Link.AllColumns).
SET( SET(
@ -236,7 +231,6 @@ SET (id, url, name, description) = (
WHERE link.id = 0; WHERE link.id = 0;
` `
testutils.AssertDebugStatementSql(t, stmt, expectedSQL, int64(0), int64(0)) testutils.AssertDebugStatementSql(t, stmt, expectedSQL, int64(0), int64(0))
testutils.AssertExecErr(t, stmt, db, "pq: number of columns does not match number of values") testutils.AssertExecErr(t, stmt, db, "pq: number of columns does not match number of values")
}) })
@ -250,8 +244,6 @@ WHERE link.id = 0;
} }
func TestUpdateWithModelData(t *testing.T) { func TestUpdateWithModelData(t *testing.T) {
setupLinkTableForUpdateTest(t)
link := model.Link{ link := model.Link{
ID: 201, ID: 201,
URL: "http://www.duckduckgo.com", URL: "http://www.duckduckgo.com",
@ -261,24 +253,20 @@ func TestUpdateWithModelData(t *testing.T) {
stmt := Link. stmt := Link.
UPDATE(Link.AllColumns). UPDATE(Link.AllColumns).
MODEL(link). MODEL(link).
WHERE(Link.ID.EQ(Int32(link.ID))) WHERE(Link.ID.EQ(Int64(link.ID)))
expectedSQL := ` expectedSQL := `
UPDATE test_sample.link UPDATE test_sample.link
SET (id, url, name, description) = (201, 'http://www.duckduckgo.com', 'DuckDuckGo', NULL) SET (id, url, name, description) = (201, 'http://www.duckduckgo.com', 'DuckDuckGo', NULL)
WHERE link.id = 201::integer; WHERE link.id = 201::bigint;
` `
testutils.AssertDebugStatementSql(t, stmt, expectedSQL, int32(201), "http://www.duckduckgo.com", "DuckDuckGo", nil, int32(201)) testutils.AssertDebugStatementSql(t, stmt, expectedSQL, int64(201), "http://www.duckduckgo.com", "DuckDuckGo", nil, int64(201))
_, err := stmt.Exec(db) testutils.AssertExecAndRollback(t, stmt, db, 1)
require.NoError(t, err)
requireQueryLogged(t, stmt, 1) requireQueryLogged(t, stmt, 1)
} }
func TestUpdateWithModelDataAndPredefinedColumnList(t *testing.T) { func TestUpdateWithModelDataAndPredefinedColumnList(t *testing.T) {
setupLinkTableForUpdateTest(t)
link := model.Link{ link := model.Link{
ID: 201, ID: 201,
URL: "http://www.duckduckgo.com", URL: "http://www.duckduckgo.com",
@ -290,27 +278,24 @@ func TestUpdateWithModelDataAndPredefinedColumnList(t *testing.T) {
stmt := Link. stmt := Link.
UPDATE(updateColumnList). UPDATE(updateColumnList).
MODEL(link). MODEL(link).
WHERE(Link.ID.EQ(Int32(link.ID))) WHERE(Link.ID.EQ(Int64(link.ID)))
var expectedSQL = ` testutils.AssertDebugStatementSql(t, stmt, `
UPDATE test_sample.link UPDATE test_sample.link
SET (description, name, url) = (NULL, 'DuckDuckGo', 'http://www.duckduckgo.com') SET (description, name, url) = (NULL, 'DuckDuckGo', 'http://www.duckduckgo.com')
WHERE link.id = 201::integer; WHERE link.id = 201::bigint;
` `,
testutils.AssertDebugStatementSql(t, stmt, expectedSQL, nil, "DuckDuckGo", "http://www.duckduckgo.com", int32(201)) nil, "DuckDuckGo", "http://www.duckduckgo.com", int64(201))
AssertExec(t, stmt, 1) testutils.AssertExecAndRollback(t, stmt, db, 1)
} }
func TestUpdateWithInvalidModelData(t *testing.T) { func TestUpdateWithInvalidModelData(t *testing.T) {
defer func() { defer func() {
r := recover() r := recover()
require.Equal(t, r, "missing struct field for column : id") require.Equal(t, r, "missing struct field for column : id")
}() }()
setupLinkTableForUpdateTest(t)
link := struct { link := struct {
Ident int Ident int
URL string URL string
@ -323,24 +308,13 @@ func TestUpdateWithInvalidModelData(t *testing.T) {
Name: "DuckDuckGo", Name: "DuckDuckGo",
} }
stmt := Link. _ = Link.
UPDATE(Link.AllColumns). UPDATE(Link.AllColumns).
MODEL(link). MODEL(link). // panics
WHERE(Link.ID.EQ(Int(int64(link.Ident)))) WHERE(Link.ID.EQ(Int(int64(link.Ident))))
var expectedSQL = `
UPDATE test_sample.link
SET (id, url, name, description, rel) = ('http://www.duckduckgo.com', 'DuckDuckGo', NULL, NULL)
WHERE link.id = 201;
`
testutils.AssertDebugStatementSql(t, stmt, expectedSQL, "http://www.duckduckgo.com", "DuckDuckGo", nil, nil, int64(201))
testutils.AssertExecErr(t, stmt, db, "pq: number of columns does not match number of values")
} }
func TestUpdateQueryContext(t *testing.T) { func TestUpdateQueryContext(t *testing.T) {
setupLinkTableForUpdateTest(t)
updateStmt := Link. updateStmt := Link.
UPDATE(Link.Name, Link.URL). UPDATE(Link.Name, Link.URL).
SET("Bong", "http://bong.com"). SET("Bong", "http://bong.com").
@ -351,15 +325,15 @@ func TestUpdateQueryContext(t *testing.T) {
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) {
dest := []model.Link{} dest := []model.Link{}
err := updateStmt.QueryContext(ctx, db, &dest) err := updateStmt.QueryContext(ctx, tx, &dest)
require.Error(t, err, "context deadline exceeded") require.Error(t, err, "context deadline exceeded")
})
} }
func TestUpdateExecContext(t *testing.T) { func TestUpdateExecContext(t *testing.T) {
setupLinkTableForUpdateTest(t)
updateStmt := Link. updateStmt := Link.
UPDATE(Link.Name, Link.URL). UPDATE(Link.Name, Link.URL).
SET("Bong", "http://bong.com"). SET("Bong", "http://bong.com").
@ -370,15 +344,10 @@ func TestUpdateExecContext(t *testing.T) {
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
_, err := updateStmt.ExecContext(ctx, db) testutils.AssertExecContextErr(t, updateStmt, ctx, db, "context deadline exceeded")
require.Error(t, err, "context deadline exceeded")
} }
func TestUpdateFrom(t *testing.T) { func TestUpdateFrom(t *testing.T) {
tx := beginTx(t)
defer tx.Rollback()
stmt := table.Rental.UPDATE(). stmt := table.Rental.UPDATE().
SET( SET(
table.Rental.RentalDate.SET(Timestamp(2020, 2, 2, 0, 0, 0)), table.Rental.RentalDate.SET(Timestamp(2020, 2, 2, 0, 0, 0)),
@ -416,6 +385,7 @@ RETURNING rental.rental_id AS "rental.rental_id",
store.address_id AS "store.address_id"; store.address_id AS "store.address_id";
`) `)
testutils.ExecuteInTxAndRollback(t, db, func(tx *sql.Tx) {
var dest []struct { var dest []struct {
Rental model2.Rental Rental model2.Rental
Store model2.Store Store model2.Store
@ -444,24 +414,5 @@ RETURNING rental.rental_id AS "rental.rental_id",
} }
} }
`) `)
} })
func setupLinkTableForUpdateTest(t *testing.T) {
cleanUpLinkTable(t)
_, err := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description).
VALUES(200, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT).
VALUES(201, "http://www.ask.com", "Ask", DEFAULT).
VALUES(202, "http://www.ask.com", "Ask", DEFAULT).
VALUES(203, "http://www.yahoo.com", "Yahoo", DEFAULT).
VALUES(204, "http://www.bing.com", "Bing", DEFAULT).
Exec(db)
require.NoError(t, err)
}
func cleanUpLinkTable(t *testing.T) {
_, err := Link.DELETE().WHERE(Link.ID.GT(Int(0))).Exec(db)
require.NoError(t, err)
} }

View file

@ -1,57 +0,0 @@
package postgres
import (
"github.com/go-jet/jet/v2/internal/jet"
"github.com/go-jet/jet/v2/internal/testutils"
"github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/model"
"github.com/stretchr/testify/require"
"testing"
)
func AssertExec(t *testing.T, stmt jet.Statement, rowsAffected int64) {
res, err := stmt.Exec(db)
require.NoError(t, err)
rows, err := res.RowsAffected()
require.NoError(t, err)
require.Equal(t, rows, rowsAffected)
}
var customer0 = model.Customer{
CustomerID: 1,
StoreID: 1,
FirstName: "Mary",
LastName: "Smith",
Email: testutils.StringPtr("mary.smith@sakilacustomer.org"),
AddressID: 5,
Activebool: true,
CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0),
LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 3),
Active: testutils.Int32Ptr(1),
}
var customer1 = model.Customer{
CustomerID: 2,
StoreID: 1,
FirstName: "Patricia",
LastName: "Johnson",
Email: testutils.StringPtr("patricia.johnson@sakilacustomer.org"),
AddressID: 6,
Activebool: true,
CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0),
LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 3),
Active: testutils.Int32Ptr(1),
}
var lastCustomer = model.Customer{
CustomerID: 599,
StoreID: 2,
FirstName: "Austin",
LastName: "Cintron",
Email: testutils.StringPtr("austin.cintron@sakilacustomer.org"),
AddressID: 605,
Activebool: true,
CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0),
LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 3),
Active: testutils.Int32Ptr(1),
}

View file

@ -106,9 +106,11 @@ func TestWithStatementDeleteAndInsert(t *testing.T) {
removeDiscontinuedOrders.AS( removeDiscontinuedOrders.AS(
OrderDetails.DELETE(). OrderDetails.DELETE().
WHERE(OrderDetails.ProductID.IN( WHERE(OrderDetails.ProductID.IN(
SELECT(Products.ProductID). SELECT(
FROM(Products). Products.ProductID,
WHERE(Products.Discontinued.EQ(Int(1)))), ).FROM(
Products,
).WHERE(Products.Discontinued.EQ(Int(1)))),
).RETURNING(OrderDetails.ProductID), ).RETURNING(OrderDetails.ProductID),
), ),
updateDiscontinuedPrice.AS( updateDiscontinuedPrice.AS(
@ -121,7 +123,13 @@ func TestWithStatementDeleteAndInsert(t *testing.T) {
), ),
logDiscontinuedProducts.AS( logDiscontinuedProducts.AS(
ProductLogs.INSERT(ProductLogs.AllColumns). ProductLogs.INSERT(ProductLogs.AllColumns).
QUERY(SELECT(updateDiscontinuedPrice.AllColumns()).FROM(updateDiscontinuedPrice)). QUERY(
SELECT(
updateDiscontinuedPrice.AllColumns(),
).FROM(
updateDiscontinuedPrice,
),
).
RETURNING( RETURNING(
ProductLogs.ProductID, ProductLogs.ProductID,
ProductLogs.ProductName, ProductLogs.ProductName,
@ -384,7 +392,7 @@ WITH cte1 AS (
SELECT territories.territory_id AS "territories.territory_id", SELECT territories.territory_id AS "territories.territory_id",
territories.territory_description AS "territories.territory_description", territories.territory_description AS "territories.territory_description",
territories.region_id AS "territories.region_id", territories.region_id AS "territories.region_id",
$1 AS "custom_column_1" $1::text AS "custom_column_1"
FROM northwind.territories FROM northwind.territories
ORDER BY territories.territory_id ASC ORDER BY territories.territory_id ASC
),cte2 AS ( ),cte2 AS (
@ -392,7 +400,7 @@ WITH cte1 AS (
cte1."territories.territory_description" AS "territories.territory_description", cte1."territories.territory_description" AS "territories.territory_description",
cte1."territories.region_id" AS "territories.region_id", cte1."territories.region_id" AS "territories.region_id",
cte1.custom_column_1 AS "custom_column_1", cte1.custom_column_1 AS "custom_column_1",
$2 AS "custom_column_2" $2::text AS "custom_column_2"
FROM cte1 FROM cte1
) )
SELECT cte2."territories.territory_id" AS "territories.territory_id", SELECT cte2."territories.territory_id" AS "territories.territory_id",
@ -485,7 +493,7 @@ func TestRecursiveWithStatement(t *testing.T) {
Employees, Employees,
).WHERE( ).WHERE(
Employees.EmployeeID.EQ(Int(2)), Employees.EmployeeID.EQ(Int(2)),
).UNION( ).UNION_ALL(
SELECT( SELECT(
Employees.AllColumns, Employees.AllColumns,
).FROM( ).FROM(
@ -790,13 +798,13 @@ WITH suppliers_fax AS (
suppliers_fax."suppliers.contact_name" AS "suppliers.contact_name", suppliers_fax."suppliers.contact_name" AS "suppliers.contact_name",
suppliers_fax."suppliers.country" AS "suppliers.country" suppliers_fax."suppliers.country" AS "suppliers.country"
FROM suppliers_fax FROM suppliers_fax
WHERE suppliers_fax."suppliers.country" NOT IN ('US', 'Australia') WHERE suppliers_fax."suppliers.country" NOT IN ('US'::text, 'Australia'::text)
) )
SELECT not_from_us_or_aus."suppliers.supplier_id" AS "suppliers.supplier_id", SELECT not_from_us_or_aus."suppliers.supplier_id" AS "suppliers.supplier_id",
not_from_us_or_aus."suppliers.contact_name" AS "suppliers.contact_name", not_from_us_or_aus."suppliers.contact_name" AS "suppliers.contact_name",
not_from_us_or_aus."suppliers.country" AS "suppliers.country" not_from_us_or_aus."suppliers.country" AS "suppliers.country"
FROM not_from_us_or_aus FROM not_from_us_or_aus
WHERE not_from_us_or_aus."suppliers.contact_name" != 'John'; WHERE not_from_us_or_aus."suppliers.contact_name" != 'John'::text;
`) `)
var dest []model.Suppliers var dest []model.Suppliers

View file

@ -2,6 +2,7 @@ package sqlite
import ( import (
"context" "context"
"database/sql"
"math/rand" "math/rand"
"testing" "testing"
@ -15,9 +16,6 @@ import (
) )
func TestInsertValues(t *testing.T) { func TestInsertValues(t *testing.T) {
tx := beginSampleDBTx(t)
defer tx.Rollback()
insertQuery := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description). insertQuery := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description).
VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil). VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil).
VALUES(101, "http://www.google.com", "Google", "Search engine"). VALUES(101, "http://www.google.com", "Google", "Search engine").
@ -32,13 +30,13 @@ VALUES (?, ?, ?, ?),
101, "http://www.google.com", "Google", "Search engine", 101, "http://www.google.com", "Google", "Search engine",
102, "http://www.yahoo.com", "Yahoo", nil) 102, "http://www.yahoo.com", "Yahoo", nil)
_, err := insertQuery.Exec(tx) testutils.ExecuteInTxAndRollback(t, sampleDB, func(tx *sql.Tx) {
require.NoError(t, err) testutils.AssertExec(t, insertQuery, tx)
requireLogged(t, insertQuery) requireLogged(t, insertQuery)
insertedLinks := []model.Link{} var insertedLinks []model.Link
err = SELECT(Link.AllColumns). err := SELECT(Link.AllColumns).
FROM(Link). FROM(Link).
WHERE(Link.ID.GT_EQ(Int(100))). WHERE(Link.ID.GT_EQ(Int(100))).
ORDER_BY(Link.ID). ORDER_BY(Link.ID).
@ -58,6 +56,7 @@ VALUES (?, ?, ?, ?),
URL: "http://www.yahoo.com", URL: "http://www.yahoo.com",
Name: "Yahoo", Name: "Yahoo",
}) })
})
} }
var postgreTutorial = model.Link{ var postgreTutorial = model.Link{
@ -67,25 +66,21 @@ var postgreTutorial = model.Link{
} }
func TestInsertEmptyColumnList(t *testing.T) { func TestInsertEmptyColumnList(t *testing.T) {
tx := beginSampleDBTx(t)
defer tx.Rollback()
expectedSQL := `
INSERT INTO link
VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', NULL);
`
stmt := Link.INSERT(). stmt := Link.INSERT().
VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil) VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil)
testutils.AssertDebugStatementSql(t, stmt, expectedSQL, testutils.AssertDebugStatementSql(t, stmt, `
INSERT INTO link
VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', NULL);
`,
100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil) 100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil)
testutils.ExecuteInTxAndRollback(t, sampleDB, func(tx *sql.Tx) {
_, err := stmt.Exec(tx) _, err := stmt.Exec(tx)
require.NoError(t, err) require.NoError(t, err)
requireLogged(t, stmt) requireLogged(t, stmt)
insertedLinks := []model.Link{} var insertedLinks []model.Link
err = SELECT(Link.AllColumns). err = SELECT(Link.AllColumns).
FROM(Link). FROM(Link).
@ -96,12 +91,10 @@ VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', NULL);
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, len(insertedLinks), 1) require.Equal(t, len(insertedLinks), 1)
testutils.AssertDeepEqual(t, insertedLinks[0], postgreTutorial) testutils.AssertDeepEqual(t, insertedLinks[0], postgreTutorial)
})
} }
func TestInsertModelObject(t *testing.T) { func TestInsertModelObject(t *testing.T) {
tx := beginSampleDBTx(t)
defer tx.Rollback()
linkData := model.Link{ linkData := model.Link{
URL: "http://www.duckduckgo.com", URL: "http://www.duckduckgo.com",
Name: "Duck Duck go", Name: "Duck Duck go",
@ -115,19 +108,13 @@ INSERT INTO link (url, name)
VALUES ('http://www.duckduckgo.com', 'Duck Duck go'); VALUES ('http://www.duckduckgo.com', 'Duck Duck go');
`, "http://www.duckduckgo.com", "Duck Duck go") `, "http://www.duckduckgo.com", "Duck Duck go")
testutils.ExecuteInTxAndRollback(t, sampleDB, func(tx *sql.Tx) {
_, err := query.Exec(tx) _, err := query.Exec(tx)
require.NoError(t, err) require.NoError(t, err)
})
} }
func TestInsertModelObjectEmptyColumnList(t *testing.T) { func TestInsertModelObjectEmptyColumnList(t *testing.T) {
tx := beginSampleDBTx(t)
defer tx.Rollback()
var expectedSQL = `
INSERT INTO link
VALUES (1000, 'http://www.duckduckgo.com', 'Duck Duck go', NULL);
`
linkData := model.Link{ linkData := model.Link{
ID: 1000, ID: 1000,
URL: "http://www.duckduckgo.com", URL: "http://www.duckduckgo.com",
@ -138,23 +125,18 @@ VALUES (1000, 'http://www.duckduckgo.com', 'Duck Duck go', NULL);
INSERT(). INSERT().
MODEL(linkData) MODEL(linkData)
testutils.AssertDebugStatementSql(t, query, expectedSQL, int32(1000), "http://www.duckduckgo.com", "Duck Duck go", nil) testutils.AssertDebugStatementSql(t, query, `
INSERT INTO link
VALUES (1000, 'http://www.duckduckgo.com', 'Duck Duck go', NULL);
`, int32(1000), "http://www.duckduckgo.com", "Duck Duck go", nil)
testutils.ExecuteInTxAndRollback(t, sampleDB, func(tx *sql.Tx) {
_, err := query.Exec(tx) _, err := query.Exec(tx)
require.NoError(t, err) require.NoError(t, err)
})
} }
func TestInsertModelsObject(t *testing.T) { func TestInsertModelsObject(t *testing.T) {
tx := beginSampleDBTx(t)
defer tx.Rollback()
expectedSQL := `
INSERT INTO link (url, name)
VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial'),
('http://www.google.com', 'Google'),
('http://www.yahoo.com', 'Yahoo');
`
tutorial := model.Link{ tutorial := model.Link{
URL: "http://www.postgresqltutorial.com", URL: "http://www.postgresqltutorial.com",
Name: "PostgreSQL Tutorial", Name: "PostgreSQL Tutorial",
@ -176,27 +158,20 @@ VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial'),
yahoo, yahoo,
}) })
testutils.AssertDebugStatementSql(t, query, expectedSQL, testutils.AssertDebugStatementSql(t, query, `
INSERT INTO link (url, name)
VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial'),
('http://www.google.com', 'Google'),
('http://www.yahoo.com', 'Yahoo');
`,
"http://www.postgresqltutorial.com", "PostgreSQL Tutorial", "http://www.postgresqltutorial.com", "PostgreSQL Tutorial",
"http://www.google.com", "Google", "http://www.google.com", "Google",
"http://www.yahoo.com", "Yahoo") "http://www.yahoo.com", "Yahoo")
_, err := query.Exec(tx) testutils.AssertExecAndRollback(t, query, sampleDB)
require.NoError(t, err)
} }
func TestInsertUsingMutableColumns(t *testing.T) { func TestInsertUsingMutableColumns(t *testing.T) {
tx := beginSampleDBTx(t)
defer tx.Rollback()
var expectedSQL = `
INSERT INTO link (url, name, description)
VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', NULL),
('http://www.google.com', 'Google', NULL),
('http://www.google.com', 'Google', NULL),
('http://www.yahoo.com', 'Yahoo', NULL);
`
google := model.Link{ google := model.Link{
URL: "http://www.google.com", URL: "http://www.google.com",
Name: "Google", Name: "Google",
@ -213,20 +188,22 @@ VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', NULL),
MODEL(google). MODEL(google).
MODELS([]model.Link{google, yahoo}) MODELS([]model.Link{google, yahoo})
testutils.AssertDebugStatementSql(t, stmt, expectedSQL, testutils.AssertDebugStatementSql(t, stmt, `
INSERT INTO link (url, name, description)
VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', NULL),
('http://www.google.com', 'Google', NULL),
('http://www.google.com', 'Google', NULL),
('http://www.yahoo.com', 'Yahoo', NULL);
`,
"http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil,
"http://www.google.com", "Google", nil, "http://www.google.com", "Google", nil,
"http://www.google.com", "Google", nil, "http://www.google.com", "Google", nil,
"http://www.yahoo.com", "Yahoo", nil) "http://www.yahoo.com", "Yahoo", nil)
_, err := stmt.Exec(tx) testutils.AssertExecAndRollback(t, stmt, sampleDB)
require.NoError(t, err)
} }
func TestInsertQuery(t *testing.T) { func TestInsertQuery(t *testing.T) {
tx := beginSampleDBTx(t)
defer tx.Rollback()
var expectedSQL = ` var expectedSQL = `
INSERT INTO link (url, name) INSERT INTO link (url, name)
SELECT link.url AS "link.url", SELECT link.url AS "link.url",
@ -242,11 +219,11 @@ WHERE link.id = 24;
) )
testutils.AssertDebugStatementSql(t, query, expectedSQL, int64(24)) testutils.AssertDebugStatementSql(t, query, expectedSQL, int64(24))
testutils.ExecuteInTxAndRollback(t, sampleDB, func(tx *sql.Tx) {
_, err := query.Exec(tx) _, err := query.Exec(tx)
require.NoError(t, err) require.NoError(t, err)
youtubeLinks := []model.Link{} var youtubeLinks []model.Link
err = Link. err = Link.
SELECT(Link.AllColumns). SELECT(Link.AllColumns).
WHERE(Link.Name.EQ(String("Bing"))). WHERE(Link.Name.EQ(String("Bing"))).
@ -254,12 +231,10 @@ WHERE link.id = 24;
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, len(youtubeLinks), 2) require.Equal(t, len(youtubeLinks), 2)
})
} }
func TestInsert_DEFAULT_VALUES_RETURNING(t *testing.T) { func TestInsert_DEFAULT_VALUES_RETURNING(t *testing.T) {
tx := beginSampleDBTx(t)
defer tx.Rollback()
stmt := Link.INSERT(). stmt := Link.INSERT().
DEFAULT_VALUES(). DEFAULT_VALUES().
RETURNING(Link.AllColumns) RETURNING(Link.AllColumns)
@ -273,6 +248,7 @@ RETURNING link.id AS "link.id",
link.description AS "link.description"; link.description AS "link.description";
`) `)
testutils.ExecuteInTxAndRollback(t, sampleDB, func(tx *sql.Tx) {
var link model.Link var link model.Link
err := stmt.Query(tx, &link) err := stmt.Query(tx, &link)
require.NoError(t, err) require.NoError(t, err)
@ -283,14 +259,12 @@ RETURNING link.id AS "link.id",
Name: "_", Name: "_",
Description: nil, Description: nil,
}) })
})
} }
func TestInsertOnConflict(t *testing.T) { func TestInsertOnConflict(t *testing.T) {
t.Run("do nothing", func(t *testing.T) { t.Run("do nothing", func(t *testing.T) {
tx := beginSampleDBTx(t)
defer tx.Rollback()
link := model.Link{ID: rand.Int31()} link := model.Link{ID: rand.Int31()}
stmt := Link.INSERT(Link.AllColumns). stmt := Link.INSERT(Link.AllColumns).
@ -304,14 +278,11 @@ VALUES (?, ?, ?, ?),
(?, ?, ?, ?) (?, ?, ?, ?)
ON CONFLICT (id) DO NOTHING; ON CONFLICT (id) DO NOTHING;
`) `)
testutils.AssertExec(t, stmt, tx, 1) testutils.AssertExecAndRollback(t, stmt, sampleDB, 1)
requireLogged(t, stmt) requireLogged(t, stmt)
}) })
t.Run("do update", func(t *testing.T) { t.Run("do update", func(t *testing.T) {
tx := beginSampleDBTx(t)
defer tx.Rollback()
stmt := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description). stmt := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description).
VALUES(21, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil). VALUES(21, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil).
VALUES(22, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil). VALUES(22, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil).
@ -336,14 +307,11 @@ RETURNING link.id AS "link.id",
link.description AS "link.description"; link.description AS "link.description";
`) `)
testutils.AssertExec(t, stmt, tx) testutils.AssertExecAndRollback(t, stmt, sampleDB)
requireLogged(t, stmt) requireLogged(t, stmt)
}) })
t.Run("do update complex", func(t *testing.T) { t.Run("do update complex", func(t *testing.T) {
tx := beginSampleDBTx(t)
defer tx.Rollback()
stmt := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description). stmt := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description).
VALUES(21, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil). VALUES(21, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil).
ON_CONFLICT(Link.ID). ON_CONFLICT(Link.ID).
@ -370,7 +338,7 @@ ON CONFLICT (id) WHERE (id * 2) > 10 DO UPDATE
WHERE link.description IS NOT NULL; WHERE link.description IS NOT NULL;
`) `)
testutils.AssertExec(t, stmt, tx) testutils.AssertExecAndRollback(t, stmt, sampleDB)
requireLogged(t, stmt) requireLogged(t, stmt)
}) })
} }
@ -384,7 +352,7 @@ func TestInsertContextDeadlineExceeded(t *testing.T) {
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
dest := []model.Link{} var dest []model.Link
err := stmt.QueryContext(ctx, sampleDB, &dest) err := stmt.QueryContext(ctx, sampleDB, &dest)
require.Error(t, err, "context deadline exceeded") require.Error(t, err, "context deadline exceeded")

View file

@ -35,6 +35,7 @@ func TestMain(m *testing.M) {
var err error var err error
db, err = sql.Open("sqlite3", "file:"+dbconfig.SakilaDBPath) db, err = sql.Open("sqlite3", "file:"+dbconfig.SakilaDBPath)
throw.OnError(err) throw.OnError(err)
defer db.Close()
_, err = db.Exec(fmt.Sprintf("ATTACH DATABASE '%s' as 'chinook';", dbconfig.ChinookDBPath)) _, err = db.Exec(fmt.Sprintf("ATTACH DATABASE '%s' as 'chinook';", dbconfig.ChinookDBPath))
throw.OnError(err) throw.OnError(err)
@ -42,8 +43,6 @@ func TestMain(m *testing.M) {
sampleDB, err = sql.Open("sqlite3", dbconfig.TestSampleDBPath) sampleDB, err = sql.Open("sqlite3", dbconfig.TestSampleDBPath)
throw.OnError(err) throw.OnError(err)
defer db.Close()
ret := m.Run() ret := m.Run()
if ret != 0 { if ret != 0 {