diff --git a/internal/testutils/test_utils.go b/internal/testutils/test_utils.go index 349d4db..fcf7655 100644 --- a/internal/testutils/test_utils.go +++ b/internal/testutils/test_utils.go @@ -11,9 +11,7 @@ import ( "os" "path/filepath" "runtime" - "strings" "testing" - "time" ) func AssertExec(t *testing.T, stmt jet.Statement, db execution.DB, rowsAffected ...int64) { @@ -90,6 +88,13 @@ func AssertStatementSql(t *testing.T, query jet.Statement, expectedQuery string, assert.DeepEqual(t, args, expectedArgs) } +func AssertStatementSqlErr(t *testing.T, stmt jet.Statement, errorStr string) { + _, _, err := stmt.Sql() + + assert.Assert(t, err != nil) + assert.Error(t, err, errorStr) +} + func AssertDebugStatementSql(t *testing.T, query jet.Statement, expectedQuery string, expectedArgs ...interface{}) { _, args, err := query.Sql() assert.NilError(t, err) @@ -105,66 +110,33 @@ func AssertDebugStatementSql(t *testing.T, query jet.Statement, expectedQuery st assert.Equal(t, debuqSql, expectedQuery) } -func Date(t string) *time.Time { - newTime, err := time.Parse("2006-01-02", t) +func AssertClauseSerialize(t *testing.T, dialect jet.Dialect, clause jet.Serializer, query string, args ...interface{}) { + out := jet.SqlBuilder{Dialect: dialect} + err := jet.Serialize(clause, jet.SelectStatementType, &out) - if err != nil { - panic(err) - } + assert.NilError(t, err) - return &newTime + //fmt.Println(out.Buff.String()) + + assert.DeepEqual(t, out.Buff.String(), query) + assert.DeepEqual(t, out.Args, args) } -func TimestampWithoutTimeZone(t string, precision int) *time.Time { +func AssertClauseSerializeErr(t *testing.T, dialect jet.Dialect, clause jet.Serializer, errString string) { + out := jet.SqlBuilder{Dialect: dialect} + err := jet.Serialize(clause, jet.SelectStatementType, &out) - precisionStr := "" - - if precision > 0 { - precisionStr = "." + strings.Repeat("9", precision) - } - - newTime, err := time.Parse("2006-01-02 15:04:05"+precisionStr+" +0000", t+" +0000") - - if err != nil { - panic(err) - } - - return &newTime + //fmt.Println(out.buff.String()) + assert.Assert(t, err != nil) + assert.Error(t, err, errString) } -func TimeWithoutTimeZone(t string) *time.Time { - newTime, err := time.Parse("15:04:05", t) +func AssertProjectionSerialize(t *testing.T, dialect jet.Dialect, projection jet.Projection, query string, args ...interface{}) { + out := jet.SqlBuilder{Dialect: dialect} + err := jet.SerializeForProjection(projection, jet.SelectStatementType, &out) - if err != nil { - panic(err) - } + assert.NilError(t, err) - return &newTime -} - -func TimeWithTimeZone(t string) *time.Time { - newTimez, err := time.Parse("15:04:05 -0700", t) - - if err != nil { - panic(err) - } - - return &newTimez -} - -func TimestampWithTimeZone(t string, precision int) *time.Time { - - precisionStr := "" - - if precision > 0 { - precisionStr = "." + strings.Repeat("9", precision) - } - - newTime, err := time.Parse("2006-01-02 15:04:05"+precisionStr+" -0700 MST", t) - - if err != nil { - panic(err) - } - - return &newTime + assert.DeepEqual(t, out.Buff.String(), query) + assert.DeepEqual(t, out.Args, args) } diff --git a/internal/testutils/time_utils.go b/internal/testutils/time_utils.go new file mode 100644 index 0000000..2f83231 --- /dev/null +++ b/internal/testutils/time_utils.go @@ -0,0 +1,70 @@ +package testutils + +import ( + "strings" + "time" +) + +func Date(t string) *time.Time { + newTime, err := time.Parse("2006-01-02", t) + + if err != nil { + panic(err) + } + + return &newTime +} + +func TimestampWithoutTimeZone(t string, precision int) *time.Time { + + precisionStr := "" + + if precision > 0 { + precisionStr = "." + strings.Repeat("9", precision) + } + + newTime, err := time.Parse("2006-01-02 15:04:05"+precisionStr+" +0000", t+" +0000") + + if err != nil { + panic(err) + } + + return &newTime +} + +func TimeWithoutTimeZone(t string) *time.Time { + newTime, err := time.Parse("15:04:05", t) + + if err != nil { + panic(err) + } + + return &newTime +} + +func TimeWithTimeZone(t string) *time.Time { + newTimez, err := time.Parse("15:04:05 -0700", t) + + if err != nil { + panic(err) + } + + return &newTimez +} + +func TimestampWithTimeZone(t string, precision int) *time.Time { + + precisionStr := "" + + if precision > 0 { + precisionStr = "." + strings.Repeat("9", precision) + } + + newTime, err := time.Parse("2006-01-02 15:04:05"+precisionStr+" -0700 MST", t) + + if err != nil { + panic(err) + } + + return &newTime +} diff --git a/mysql/delete_statement_test.go b/mysql/delete_statement_test.go index cfe990e..1bd55ce 100644 --- a/mysql/delete_statement_test.go +++ b/mysql/delete_statement_test.go @@ -5,19 +5,19 @@ import ( ) func TestDeleteUnconditionally(t *testing.T) { - assertStatementErr(t, table1.DELETE(), `jet: WHERE clause not set`) - assertStatementErr(t, table1.DELETE().WHERE(nil), `jet: WHERE clause not set`) + assertStatementSqlErr(t, table1.DELETE(), `jet: WHERE clause not set`) + assertStatementSqlErr(t, table1.DELETE().WHERE(nil), `jet: WHERE clause not set`) } func TestDeleteWithWhere(t *testing.T) { - assertStatement(t, table1.DELETE().WHERE(table1Col1.EQ(Int(1))), ` + assertStatementSql(t, table1.DELETE().WHERE(table1Col1.EQ(Int(1))), ` DELETE FROM db.table1 WHERE table1.col1 = ?; `, int64(1)) } func TestDeleteWithWhereOrderByLimit(t *testing.T) { - assertStatement(t, table1.DELETE().WHERE(table1Col1.EQ(Int(1))).ORDER_BY(table1Col1).LIMIT(1), ` + assertStatementSql(t, table1.DELETE().WHERE(table1Col1.EQ(Int(1))).ORDER_BY(table1Col1).LIMIT(1), ` DELETE FROM db.table1 WHERE table1.col1 = ? ORDER BY table1.col1 diff --git a/mysql/insert_statement_test.go b/mysql/insert_statement_test.go index cb1513b..c65c1f9 100644 --- a/mysql/insert_statement_test.go +++ b/mysql/insert_statement_test.go @@ -6,21 +6,21 @@ import ( "time" ) -//TODO: +// //func TestInvalidInsert(t *testing.T) { -// assertStatementErr(t, table1.INSERT(table1Col1), "jet: no row values or query specified") -// assertStatementErr(t, table1.INSERT(nil).VALUES(1), "jet: nil column in columns list") +// assertStatementSqlErr(t, table1.INSERT(table1Col1), "jet: no row values or query specified") +// assertStatementSqlErr(t, table1.INSERT(nil).VALUES(1), "jet: nil column in columns list") //} func TestInsertNilValue(t *testing.T) { - assertStatement(t, table1.INSERT(table1Col1).VALUES(nil), ` + assertStatementSql(t, table1.INSERT(table1Col1).VALUES(nil), ` INSERT INTO db.table1 (col1) VALUES (?); `, nil) } func TestInsertSingleValue(t *testing.T) { - assertStatement(t, table1.INSERT(table1Col1).VALUES(1), ` + assertStatementSql(t, table1.INSERT(table1Col1).VALUES(1), ` INSERT INTO db.table1 (col1) VALUES (?); `, int(1)) @@ -29,7 +29,7 @@ INSERT INTO db.table1 (col1) VALUES func TestInsertWithColumnList(t *testing.T) { columnList := ColumnList(table3ColInt, table3StrCol) - assertStatement(t, table3.INSERT(columnList).VALUES(1, 3), ` + assertStatementSql(t, table3.INSERT(columnList).VALUES(1, 3), ` INSERT INTO db.table3 (col_int, col2) VALUES (?, ?); `, 1, 3) @@ -38,14 +38,14 @@ INSERT INTO db.table3 (col_int, col2) VALUES func TestInsertDate(t *testing.T) { date := time.Date(1999, 1, 2, 3, 4, 5, 0, time.UTC) - assertStatement(t, table1.INSERT(table1ColTimestamp).VALUES(date), ` + assertStatementSql(t, table1.INSERT(table1ColTimestamp).VALUES(date), ` INSERT INTO db.table1 (col_timestamp) VALUES (?); `, date) } func TestInsertMultipleValues(t *testing.T) { - assertStatement(t, table1.INSERT(table1Col1, table1ColFloat, table1Col3).VALUES(1, 2, 3), ` + assertStatementSql(t, table1.INSERT(table1Col1, table1ColFloat, table1Col3).VALUES(1, 2, 3), ` INSERT INTO db.table1 (col1, col_float, col3) VALUES (?, ?, ?); `, 1, 2, 3) @@ -57,7 +57,7 @@ func TestInsertMultipleRows(t *testing.T) { VALUES(11, 22). VALUES(111, 222) - assertStatement(t, stmt, ` + assertStatementSql(t, stmt, ` INSERT INTO db.table1 (col1, col_float) VALUES (?, ?), (?, ?), @@ -88,7 +88,7 @@ INSERT INTO db.table1 (col1, col_float) VALUES (?, ?); ` - assertStatement(t, stmt, expectedSQL, int(1), float64(1.11), int(1), float64(1.11)) + assertStatementSql(t, stmt, expectedSQL, int(1), float64(1.11), int(1), float64(1.11)) } func TestInsertValuesFromModelColumnMismatch(t *testing.T) { @@ -130,5 +130,5 @@ INSERT INTO db.table1 (col1, col_float) VALUES (DEFAULT, ?); ` - assertStatement(t, stmt, expectedSQL, "two") + assertStatementSql(t, stmt, expectedSQL, "two") } diff --git a/mysql/lock_statement_test.go b/mysql/lock_statement_test.go index b961421..875acf8 100644 --- a/mysql/lock_statement_test.go +++ b/mysql/lock_statement_test.go @@ -3,19 +3,19 @@ package mysql import "testing" func TestLockRead(t *testing.T) { - assertStatement(t, table2.LOCK().READ(), ` + assertStatementSql(t, table2.LOCK().READ(), ` LOCK TABLES db.table2 READ; `) } func TestLockWrite(t *testing.T) { - assertStatement(t, table2.LOCK().WRITE(), ` + assertStatementSql(t, table2.LOCK().WRITE(), ` LOCK TABLES db.table2 WRITE; `) } func TestUNLOCK_TABLES(t *testing.T) { - assertStatement(t, UNLOCK_TABLES(), ` + assertStatementSql(t, UNLOCK_TABLES(), ` UNLOCK TABLES; `) } diff --git a/mysql/select_statement_test.go b/mysql/select_statement_test.go index 96f0c54..e7f3d7e 100644 --- a/mysql/select_statement_test.go +++ b/mysql/select_statement_test.go @@ -6,13 +6,13 @@ import ( ) func TestInvalidSelect(t *testing.T) { - assertStatementErr(t, SELECT(nil), "jet: Projection is nil") + assertStatementSqlErr(t, SELECT(nil), "jet: Projection is nil") } func TestSelectColumnList(t *testing.T) { columnList := ColumnList(table2ColInt, table2ColFloat, table3ColInt) - assertStatement(t, SELECT(columnList).FROM(table2), ` + assertStatementSql(t, SELECT(columnList).FROM(table2), ` SELECT table2.col_int AS "table2.col_int", table2.col_float AS "table2.col_float", table3.col_int AS "table3.col_int" @@ -21,7 +21,7 @@ FROM db.table2; } func TestSelectLiterals(t *testing.T) { - assertStatement(t, SELECT(Int(1), Float(2.2), Bool(false)).FROM(table1), ` + assertStatementSql(t, SELECT(Int(1), Float(2.2), Bool(false)).FROM(table1), ` SELECT ?, ?, ? @@ -30,25 +30,25 @@ FROM db.table1; } func TestSelectDistinct(t *testing.T) { - assertStatement(t, SELECT(table1ColBool).DISTINCT().FROM(table1), ` + assertStatementSql(t, SELECT(table1ColBool).DISTINCT().FROM(table1), ` SELECT DISTINCT table1.col_bool AS "table1.col_bool" FROM db.table1; `) } func TestSelectFrom(t *testing.T) { - assertStatement(t, SELECT(table1ColInt, table2ColFloat).FROM(table1), ` + assertStatementSql(t, SELECT(table1ColInt, table2ColFloat).FROM(table1), ` SELECT table1.col_int AS "table1.col_int", table2.col_float AS "table2.col_float" FROM db.table1; `) - assertStatement(t, SELECT(table1ColInt, table2ColFloat).FROM(table1.INNER_JOIN(table2, table1ColInt.EQ(table2ColInt))), ` + assertStatementSql(t, SELECT(table1ColInt, table2ColFloat).FROM(table1.INNER_JOIN(table2, table1ColInt.EQ(table2ColInt))), ` SELECT table1.col_int AS "table1.col_int", table2.col_float AS "table2.col_float" FROM db.table1 INNER JOIN db.table2 ON (table1.col_int = table2.col_int); `) - assertStatement(t, table1.INNER_JOIN(table2, table1ColInt.EQ(table2ColInt)).SELECT(table1ColInt, table2ColFloat), ` + assertStatementSql(t, table1.INNER_JOIN(table2, table1ColInt.EQ(table2ColInt)).SELECT(table1ColInt, table2ColFloat), ` SELECT table1.col_int AS "table1.col_int", table2.col_float AS "table2.col_float" FROM db.table1 @@ -57,12 +57,12 @@ FROM db.table1 } func TestSelectWhere(t *testing.T) { - assertStatement(t, SELECT(table1ColInt).FROM(table1).WHERE(Bool(true)), ` + assertStatementSql(t, SELECT(table1ColInt).FROM(table1).WHERE(Bool(true)), ` SELECT table1.col_int AS "table1.col_int" FROM db.table1 WHERE ?; `, true) - assertStatement(t, SELECT(table1ColInt).FROM(table1).WHERE(table1ColInt.GT_EQ(Int(10))), ` + assertStatementSql(t, SELECT(table1ColInt).FROM(table1).WHERE(table1ColInt.GT_EQ(Int(10))), ` SELECT table1.col_int AS "table1.col_int" FROM db.table1 WHERE table1.col_int >= ?; @@ -70,7 +70,7 @@ WHERE table1.col_int >= ?; } func TestSelectGroupBy(t *testing.T) { - assertStatement(t, SELECT(table2ColInt).FROM(table2).GROUP_BY(table2ColFloat), ` + assertStatementSql(t, SELECT(table2ColInt).FROM(table2).GROUP_BY(table2ColFloat), ` SELECT table2.col_int AS "table2.col_int" FROM db.table2 GROUP BY table2.col_float; @@ -78,7 +78,7 @@ GROUP BY table2.col_float; } func TestSelectHaving(t *testing.T) { - assertStatement(t, SELECT(table3ColInt).FROM(table3).HAVING(table1ColBool.EQ(Bool(true))), ` + assertStatementSql(t, SELECT(table3ColInt).FROM(table3).HAVING(table1ColBool.EQ(Bool(true))), ` SELECT table3.col_int AS "table3.col_int" FROM db.table3 HAVING table1.col_bool = ?; @@ -86,12 +86,12 @@ HAVING table1.col_bool = ?; } func TestSelectOrderBy(t *testing.T) { - assertStatement(t, SELECT(table2ColFloat).FROM(table2).ORDER_BY(table2ColInt.DESC()), ` + assertStatementSql(t, SELECT(table2ColFloat).FROM(table2).ORDER_BY(table2ColInt.DESC()), ` SELECT table2.col_float AS "table2.col_float" FROM db.table2 ORDER BY table2.col_int DESC; `) - assertStatement(t, SELECT(table2ColFloat).FROM(table2).ORDER_BY(table2ColInt.DESC(), table2ColInt.ASC()), ` + assertStatementSql(t, SELECT(table2ColFloat).FROM(table2).ORDER_BY(table2ColInt.DESC(), table2ColInt.ASC()), ` SELECT table2.col_float AS "table2.col_float" FROM db.table2 ORDER BY table2.col_int DESC, table2.col_int ASC; @@ -99,12 +99,12 @@ ORDER BY table2.col_int DESC, table2.col_int ASC; } func TestSelectLimitOffset(t *testing.T) { - assertStatement(t, SELECT(table2ColInt).FROM(table2).LIMIT(10), ` + assertStatementSql(t, SELECT(table2ColInt).FROM(table2).LIMIT(10), ` SELECT table2.col_int AS "table2.col_int" FROM db.table2 LIMIT ?; `, int64(10)) - assertStatement(t, SELECT(table2ColInt).FROM(table2).LIMIT(10).OFFSET(2), ` + assertStatementSql(t, SELECT(table2ColInt).FROM(table2).LIMIT(10).OFFSET(2), ` SELECT table2.col_int AS "table2.col_int" FROM db.table2 LIMIT ? diff --git a/mysql/set_statement_test.go b/mysql/set_statement_test.go index 950b511..b21e7f3 100644 --- a/mysql/set_statement_test.go +++ b/mysql/set_statement_test.go @@ -8,7 +8,7 @@ func TestSelectSets(t *testing.T) { select1 := SELECT(table1ColBool).FROM(table1) select2 := SELECT(table2ColBool).FROM(table2) - assertStatement(t, select1.UNION(select2), ` + assertStatementSql(t, select1.UNION(select2), ` ( SELECT table1.col_bool AS "table1.col_bool" FROM db.table1 @@ -19,7 +19,7 @@ UNION FROM db.table2 ); `) - assertStatement(t, select1.UNION_ALL(select2), ` + assertStatementSql(t, select1.UNION_ALL(select2), ` ( SELECT table1.col_bool AS "table1.col_bool" FROM db.table1 diff --git a/mysql/update_statement_test.go b/mysql/update_statement_test.go index 2980ef6..55d88d3 100644 --- a/mysql/update_statement_test.go +++ b/mysql/update_statement_test.go @@ -17,7 +17,7 @@ WHERE table1.col_int >= ?; fmt.Println(stmt.Sql()) - assertStatement(t, stmt, expectedSQL, 1, int64(33)) + assertStatementSql(t, stmt, expectedSQL, 1, int64(33)) } func TestUpdateWithValues(t *testing.T) { @@ -33,7 +33,7 @@ WHERE table1.col_int >= ?; fmt.Println(stmt.Sql()) - assertStatement(t, stmt, expectedSQL, 1, 22.2, int64(33)) + assertStatementSql(t, stmt, expectedSQL, 1, 22.2, int64(33)) } func TestUpdateOneColumnWithSelect(t *testing.T) { @@ -54,10 +54,10 @@ WHERE table1.col1 = ?; //fmt.Println(stmt.Sql()) - assertStatement(t, stmt, expectedSQL, int64(2)) + assertStatementSql(t, stmt, expectedSQL, int64(2)) } func TestInvalidInputs(t *testing.T) { - assertStatementErr(t, table1.UPDATE(table1ColInt).SET(1), "jet: WHERE clause not set") - assertStatementErr(t, table1.UPDATE(nil).SET(1), "jet: nil column in columns list") + assertStatementSqlErr(t, table1.UPDATE(table1ColInt).SET(1), "jet: WHERE clause not set") + assertStatementSqlErr(t, table1.UPDATE(nil).SET(1), "jet: nil column in columns list") } diff --git a/mysql/utils_test.go b/mysql/utils_test.go index 6af9545..5804a07 100644 --- a/mysql/utils_test.go +++ b/mysql/utils_test.go @@ -2,7 +2,7 @@ package mysql import ( "github.com/go-jet/jet/internal/jet" - "gotest.tools/assert" + "github.com/go-jet/jet/internal/testutils" "testing" ) @@ -59,48 +59,16 @@ var table3 = NewTable( table3StrCol) func assertClauseSerialize(t *testing.T, clause jet.Serializer, query string, args ...interface{}) { - out := jet.SqlBuilder{Dialect: Dialect} - err := jet.Serialize(clause, jet.SelectStatementType, &out) - - assert.NilError(t, err) - - //fmt.Println(out.Buff.String()) - - assert.DeepEqual(t, out.Buff.String(), query) - assert.DeepEqual(t, out.Args, args) + testutils.AssertClauseSerialize(t, Dialect, clause, query, args...) } func assertClauseSerializeErr(t *testing.T, clause jet.Serializer, errString string) { - out := jet.SqlBuilder{Dialect: Dialect} - err := jet.Serialize(clause, jet.SelectStatementType, &out) - - //fmt.Println(out.buff.String()) - assert.Assert(t, err != nil) - assert.Error(t, err, errString) + testutils.AssertClauseSerializeErr(t, Dialect, clause, errString) } func assertProjectionSerialize(t *testing.T, projection jet.Projection, query string, args ...interface{}) { - out := jet.SqlBuilder{Dialect: Dialect} - err := jet.SerializeForProjection(projection, jet.SelectStatementType, &out) - - assert.NilError(t, err) - - assert.DeepEqual(t, out.Buff.String(), query) - assert.DeepEqual(t, out.Args, args) + testutils.AssertProjectionSerialize(t, Dialect, projection, query, args...) } -func assertStatement(t *testing.T, query jet.Statement, expectedQuery string, expectedArgs ...interface{}) { - queryStr, args, err := query.Sql() - assert.NilError(t, err) - - //fmt.Println(queryStr) - assert.Equal(t, queryStr, expectedQuery) - assert.DeepEqual(t, args, expectedArgs) -} - -func assertStatementErr(t *testing.T, stmt jet.Statement, errorStr string) { - _, _, err := stmt.Sql() - - assert.Assert(t, err != nil) - assert.Error(t, err, errorStr) -} +var assertStatementSql = testutils.AssertStatementSql +var assertStatementSqlErr = testutils.AssertStatementSqlErr diff --git a/postgres/delete_statement_test.go b/postgres/delete_statement_test.go index ebb0cc4..847e1bd 100644 --- a/postgres/delete_statement_test.go +++ b/postgres/delete_statement_test.go @@ -5,19 +5,19 @@ import ( ) func TestDeleteUnconditionally(t *testing.T) { - assertStatementErr(t, table1.DELETE(), `jet: WHERE clause not set`) - assertStatementErr(t, table1.DELETE().WHERE(nil), `jet: WHERE clause not set`) + assertStatementSqlErr(t, table1.DELETE(), `jet: WHERE clause not set`) + assertStatementSqlErr(t, table1.DELETE().WHERE(nil), `jet: WHERE clause not set`) } func TestDeleteWithWhere(t *testing.T) { - assertStatement(t, table1.DELETE().WHERE(table1Col1.EQ(Int(1))), ` + assertStatementSql(t, table1.DELETE().WHERE(table1Col1.EQ(Int(1))), ` DELETE FROM db.table1 WHERE table1.col1 = $1; `, int64(1)) } func TestDeleteWithWhereAndReturning(t *testing.T) { - assertStatement(t, table1.DELETE().WHERE(table1Col1.EQ(Int(1))).RETURNING(table1Col1), ` + assertStatementSql(t, table1.DELETE().WHERE(table1Col1.EQ(Int(1))).RETURNING(table1Col1), ` DELETE FROM db.table1 WHERE table1.col1 = $1 RETURNING table1.col1 AS "table1.col1"; diff --git a/postgres/insert_statement_test.go b/postgres/insert_statement_test.go index 7fd84ba..6426ab3 100644 --- a/postgres/insert_statement_test.go +++ b/postgres/insert_statement_test.go @@ -13,14 +13,14 @@ import ( //} func TestInsertNilValue(t *testing.T) { - assertStatement(t, table1.INSERT(table1Col1).VALUES(nil), ` + assertStatementSql(t, table1.INSERT(table1Col1).VALUES(nil), ` INSERT INTO db.table1 (col1) VALUES ($1); `, nil) } func TestInsertSingleValue(t *testing.T) { - assertStatement(t, table1.INSERT(table1Col1).VALUES(1), ` + assertStatementSql(t, table1.INSERT(table1Col1).VALUES(1), ` INSERT INTO db.table1 (col1) VALUES ($1); `, int(1)) @@ -29,7 +29,7 @@ INSERT INTO db.table1 (col1) VALUES func TestInsertWithColumnList(t *testing.T) { columnList := ColumnList(table3ColInt, table3StrCol) - assertStatement(t, table3.INSERT(columnList).VALUES(1, 3), ` + assertStatementSql(t, table3.INSERT(columnList).VALUES(1, 3), ` INSERT INTO db.table3 (col_int, col2) VALUES ($1, $2); `, 1, 3) @@ -38,14 +38,14 @@ INSERT INTO db.table3 (col_int, col2) VALUES func TestInsertDate(t *testing.T) { date := time.Date(1999, 1, 2, 3, 4, 5, 0, time.UTC) - assertStatement(t, table1.INSERT(table1ColTime).VALUES(date), ` + assertStatementSql(t, table1.INSERT(table1ColTime).VALUES(date), ` INSERT INTO db.table1 (col_time) VALUES ($1); `, date) } func TestInsertMultipleValues(t *testing.T) { - assertStatement(t, table1.INSERT(table1Col1, table1ColFloat, table1Col3).VALUES(1, 2, 3), ` + assertStatementSql(t, table1.INSERT(table1Col1, table1ColFloat, table1Col3).VALUES(1, 2, 3), ` INSERT INTO db.table1 (col1, col_float, col3) VALUES ($1, $2, $3); `, 1, 2, 3) @@ -57,7 +57,7 @@ func TestInsertMultipleRows(t *testing.T) { VALUES(11, 22). VALUES(111, 222) - assertStatement(t, stmt, ` + assertStatementSql(t, stmt, ` INSERT INTO db.table1 (col1, col_float) VALUES ($1, $2), ($3, $4), @@ -88,7 +88,7 @@ INSERT INTO db.table1 (col1, col_float) VALUES ($3, $4); ` - assertStatement(t, stmt, expectedSQL, int(1), float64(1.11), int(1), float64(1.11)) + assertStatementSql(t, stmt, expectedSQL, int(1), float64(1.11), int(1), float64(1.11)) } func TestInsertValuesFromModelColumnMismatch(t *testing.T) { @@ -132,7 +132,7 @@ INSERT INTO db.table1 (col1) ( FROM db.table1 ); ` - assertStatement(t, stmt, expectedSQL) + assertStatementSql(t, stmt, expectedSQL) } func TestInsertDefaultValue(t *testing.T) { @@ -144,5 +144,5 @@ INSERT INTO db.table1 (col1, col_float) VALUES (DEFAULT, $1); ` - assertStatement(t, stmt, expectedSQL, "two") + assertStatementSql(t, stmt, expectedSQL, "two") } diff --git a/postgres/lock_statement_test.go b/postgres/lock_statement_test.go index de0dddb..152aa76 100644 --- a/postgres/lock_statement_test.go +++ b/postgres/lock_statement_test.go @@ -5,28 +5,28 @@ import ( ) func TestLockTable(t *testing.T) { - assertStatement(t, table1.LOCK().IN(LOCK_ACCESS_SHARE), ` + assertStatementSql(t, table1.LOCK().IN(LOCK_ACCESS_SHARE), ` LOCK TABLE db.table1 IN ACCESS SHARE MODE; `) - assertStatement(t, table1.LOCK().IN(LOCK_ROW_SHARE), ` + assertStatementSql(t, table1.LOCK().IN(LOCK_ROW_SHARE), ` LOCK TABLE db.table1 IN ROW SHARE MODE; `) - assertStatement(t, table1.LOCK().IN(LOCK_ROW_EXCLUSIVE), ` + assertStatementSql(t, table1.LOCK().IN(LOCK_ROW_EXCLUSIVE), ` LOCK TABLE db.table1 IN ROW EXCLUSIVE MODE; `) - assertStatement(t, table1.LOCK().IN(LOCK_SHARE_UPDATE_EXCLUSIVE), ` + assertStatementSql(t, table1.LOCK().IN(LOCK_SHARE_UPDATE_EXCLUSIVE), ` LOCK TABLE db.table1 IN SHARE UPDATE EXCLUSIVE MODE; `) - assertStatement(t, table1.LOCK().IN(LOCK_SHARE), ` + assertStatementSql(t, table1.LOCK().IN(LOCK_SHARE), ` LOCK TABLE db.table1 IN SHARE MODE; `) - assertStatement(t, table1.LOCK().IN(LOCK_SHARE_ROW_EXCLUSIVE), ` + assertStatementSql(t, table1.LOCK().IN(LOCK_SHARE_ROW_EXCLUSIVE), ` LOCK TABLE db.table1 IN SHARE ROW EXCLUSIVE MODE; `) - assertStatement(t, table1.LOCK().IN(LOCK_EXCLUSIVE), ` + assertStatementSql(t, table1.LOCK().IN(LOCK_EXCLUSIVE), ` LOCK TABLE db.table1 IN EXCLUSIVE MODE; `) - assertStatement(t, table1.LOCK().IN(LOCK_ACCESS_EXCLUSIVE).NOWAIT(), ` + assertStatementSql(t, table1.LOCK().IN(LOCK_ACCESS_EXCLUSIVE).NOWAIT(), ` LOCK TABLE db.table1 IN ACCESS EXCLUSIVE MODE NOWAIT; `) } diff --git a/postgres/select_statement_test.go b/postgres/select_statement_test.go index 8189644..5c59a5c 100644 --- a/postgres/select_statement_test.go +++ b/postgres/select_statement_test.go @@ -1,18 +1,17 @@ package postgres import ( - "github.com/go-jet/jet/internal/testutils" "testing" ) func TestInvalidSelect(t *testing.T) { - assertStatementErr(t, SELECT(nil), "jet: Projection is nil") + assertStatementSqlErr(t, SELECT(nil), "jet: Projection is nil") } func TestSelectColumnList(t *testing.T) { columnList := ColumnList(table2ColInt, table2ColFloat, table3ColInt) - assertStatement(t, SELECT(columnList).FROM(table2), ` + assertStatementSql(t, SELECT(columnList).FROM(table2), ` SELECT table2.col_int AS "table2.col_int", table2.col_float AS "table2.col_float", table3.col_int AS "table3.col_int" @@ -21,7 +20,7 @@ FROM db.table2; } func TestSelectLiterals(t *testing.T) { - assertStatement(t, SELECT(Int(1), Float(2.2), Bool(false)).FROM(table1), ` + assertStatementSql(t, SELECT(Int(1), Float(2.2), Bool(false)).FROM(table1), ` SELECT $1, $2, $3 @@ -30,25 +29,25 @@ FROM db.table1; } func TestSelectDistinct(t *testing.T) { - assertStatement(t, SELECT(table1ColBool).DISTINCT().FROM(table1), ` + assertStatementSql(t, SELECT(table1ColBool).DISTINCT().FROM(table1), ` SELECT DISTINCT table1.col_bool AS "table1.col_bool" FROM db.table1; `) } func TestSelectFrom(t *testing.T) { - assertStatement(t, SELECT(table1ColInt, table2ColFloat).FROM(table1), ` + assertStatementSql(t, SELECT(table1ColInt, table2ColFloat).FROM(table1), ` SELECT table1.col_int AS "table1.col_int", table2.col_float AS "table2.col_float" FROM db.table1; `) - assertStatement(t, SELECT(table1ColInt, table2ColFloat).FROM(table1.INNER_JOIN(table2, table1ColInt.EQ(table2ColInt))), ` + assertStatementSql(t, SELECT(table1ColInt, table2ColFloat).FROM(table1.INNER_JOIN(table2, table1ColInt.EQ(table2ColInt))), ` SELECT table1.col_int AS "table1.col_int", table2.col_float AS "table2.col_float" FROM db.table1 INNER JOIN db.table2 ON (table1.col_int = table2.col_int); `) - assertStatement(t, table1.INNER_JOIN(table2, table1ColInt.EQ(table2ColInt)).SELECT(table1ColInt, table2ColFloat), ` + assertStatementSql(t, table1.INNER_JOIN(table2, table1ColInt.EQ(table2ColInt)).SELECT(table1ColInt, table2ColFloat), ` SELECT table1.col_int AS "table1.col_int", table2.col_float AS "table2.col_float" FROM db.table1 @@ -57,12 +56,12 @@ FROM db.table1 } func TestSelectWhere(t *testing.T) { - assertStatement(t, SELECT(table1ColInt).FROM(table1).WHERE(Bool(true)), ` + assertStatementSql(t, SELECT(table1ColInt).FROM(table1).WHERE(Bool(true)), ` SELECT table1.col_int AS "table1.col_int" FROM db.table1 WHERE $1; `, true) - assertStatement(t, SELECT(table1ColInt).FROM(table1).WHERE(table1ColInt.GT_EQ(Int(10))), ` + assertStatementSql(t, SELECT(table1ColInt).FROM(table1).WHERE(table1ColInt.GT_EQ(Int(10))), ` SELECT table1.col_int AS "table1.col_int" FROM db.table1 WHERE table1.col_int >= $1; @@ -70,7 +69,7 @@ WHERE table1.col_int >= $1; } func TestSelectGroupBy(t *testing.T) { - assertStatement(t, SELECT(table2ColInt).FROM(table2).GROUP_BY(table2ColFloat), ` + assertStatementSql(t, SELECT(table2ColInt).FROM(table2).GROUP_BY(table2ColFloat), ` SELECT table2.col_int AS "table2.col_int" FROM db.table2 GROUP BY table2.col_float; @@ -78,7 +77,7 @@ GROUP BY table2.col_float; } func TestSelectHaving(t *testing.T) { - assertStatement(t, SELECT(table3ColInt).FROM(table3).HAVING(table1ColBool.EQ(Bool(true))), ` + assertStatementSql(t, SELECT(table3ColInt).FROM(table3).HAVING(table1ColBool.EQ(Bool(true))), ` SELECT table3.col_int AS "table3.col_int" FROM db.table3 HAVING table1.col_bool = $1; @@ -86,12 +85,12 @@ HAVING table1.col_bool = $1; } func TestSelectOrderBy(t *testing.T) { - assertStatement(t, SELECT(table2ColFloat).FROM(table2).ORDER_BY(table2ColInt.DESC()), ` + assertStatementSql(t, SELECT(table2ColFloat).FROM(table2).ORDER_BY(table2ColInt.DESC()), ` SELECT table2.col_float AS "table2.col_float" FROM db.table2 ORDER BY table2.col_int DESC; `) - assertStatement(t, SELECT(table2ColFloat).FROM(table2).ORDER_BY(table2ColInt.DESC(), table2ColInt.ASC()), ` + assertStatementSql(t, SELECT(table2ColFloat).FROM(table2).ORDER_BY(table2ColInt.DESC(), table2ColInt.ASC()), ` SELECT table2.col_float AS "table2.col_float" FROM db.table2 ORDER BY table2.col_int DESC, table2.col_int ASC; @@ -99,12 +98,12 @@ ORDER BY table2.col_int DESC, table2.col_int ASC; } func TestSelectLimitOffset(t *testing.T) { - assertStatement(t, SELECT(table2ColInt).FROM(table2).LIMIT(10), ` + assertStatementSql(t, SELECT(table2ColInt).FROM(table2).LIMIT(10), ` SELECT table2.col_int AS "table2.col_int" FROM db.table2 LIMIT $1; `, int64(10)) - assertStatement(t, SELECT(table2ColInt).FROM(table2).LIMIT(10).OFFSET(2), ` + assertStatementSql(t, SELECT(table2ColInt).FROM(table2).LIMIT(10).OFFSET(2), ` SELECT table2.col_int AS "table2.col_int" FROM db.table2 LIMIT $1 @@ -113,23 +112,23 @@ OFFSET $2; } func TestSelectLock(t *testing.T) { - testutils.AssertStatementSql(t, SELECT(table1ColBool).FROM(table1).FOR(UPDATE()), ` + assertStatementSql(t, SELECT(table1ColBool).FROM(table1).FOR(UPDATE()), ` SELECT table1.col_bool AS "table1.col_bool" FROM db.table1 FOR UPDATE; `) - testutils.AssertStatementSql(t, SELECT(table1ColBool).FROM(table1).FOR(SHARE().NOWAIT()), ` + assertStatementSql(t, SELECT(table1ColBool).FROM(table1).FOR(SHARE().NOWAIT()), ` SELECT table1.col_bool AS "table1.col_bool" FROM db.table1 FOR SHARE NOWAIT; `) - testutils.AssertStatementSql(t, SELECT(table1ColBool).FROM(table1).FOR(KEY_SHARE().NOWAIT()), ` + assertStatementSql(t, SELECT(table1ColBool).FROM(table1).FOR(KEY_SHARE().NOWAIT()), ` SELECT table1.col_bool AS "table1.col_bool" FROM db.table1 FOR KEY SHARE NOWAIT; `) - testutils.AssertStatementSql(t, SELECT(table1ColBool).FROM(table1).FOR(NO_KEY_UPDATE().SKIP_LOCKED()), ` + assertStatementSql(t, SELECT(table1ColBool).FROM(table1).FOR(NO_KEY_UPDATE().SKIP_LOCKED()), ` SELECT table1.col_bool AS "table1.col_bool" FROM db.table1 FOR NO KEY UPDATE SKIP LOCKED; diff --git a/postgres/set_statement_test.go b/postgres/set_statement_test.go index 0ca3bd8..53f0048 100644 --- a/postgres/set_statement_test.go +++ b/postgres/set_statement_test.go @@ -8,7 +8,7 @@ func TestSelectSets(t *testing.T) { select1 := SELECT(table1ColBool).FROM(table1) select2 := SELECT(table2ColBool).FROM(table2) - assertStatement(t, select1.UNION(select2), ` + assertStatementSql(t, select1.UNION(select2), ` ( SELECT table1.col_bool AS "table1.col_bool" FROM db.table1 @@ -19,7 +19,7 @@ UNION FROM db.table2 ); `) - assertStatement(t, select1.UNION_ALL(select2), ` + assertStatementSql(t, select1.UNION_ALL(select2), ` ( SELECT table1.col_bool AS "table1.col_bool" FROM db.table1 @@ -31,7 +31,7 @@ UNION ALL ); `) - assertStatement(t, select1.INTERSECT(select2), ` + assertStatementSql(t, select1.INTERSECT(select2), ` ( SELECT table1.col_bool AS "table1.col_bool" FROM db.table1 @@ -43,7 +43,7 @@ INTERSECT ); `) - assertStatement(t, select1.INTERSECT_ALL(select2), ` + assertStatementSql(t, select1.INTERSECT_ALL(select2), ` ( SELECT table1.col_bool AS "table1.col_bool" FROM db.table1 @@ -54,7 +54,7 @@ INTERSECT ALL FROM db.table2 ); `) - assertStatement(t, select1.EXCEPT(select2), ` + assertStatementSql(t, select1.EXCEPT(select2), ` ( SELECT table1.col_bool AS "table1.col_bool" FROM db.table1 @@ -66,7 +66,7 @@ EXCEPT ); `) - assertStatement(t, select1.EXCEPT_ALL(select2), ` + assertStatementSql(t, select1.EXCEPT_ALL(select2), ` ( SELECT table1.col_bool AS "table1.col_bool" FROM db.table1 diff --git a/postgres/update_statement_test.go b/postgres/update_statement_test.go index 1986806..9c98a3b 100644 --- a/postgres/update_statement_test.go +++ b/postgres/update_statement_test.go @@ -17,7 +17,7 @@ WHERE table1.col_int >= $2; fmt.Println(stmt.Sql()) - assertStatement(t, stmt, expectedSQL, 1, int64(33)) + assertStatementSql(t, stmt, expectedSQL, 1, int64(33)) } func TestUpdateWithValues(t *testing.T) { @@ -32,7 +32,7 @@ WHERE table1.col_int >= $3; fmt.Println(stmt.Sql()) - assertStatement(t, stmt, expectedSQL, 1, 22.2, int64(33)) + assertStatementSql(t, stmt, expectedSQL, 1, 22.2, int64(33)) } func TestUpdateOneColumnWithSelect(t *testing.T) { @@ -53,10 +53,10 @@ RETURNING table1.col1 AS "table1.col1"; WHERE(table1Col1.EQ(Int(2))). RETURNING(table1Col1) - assertStatement(t, stmt, expectedSQL, int64(2)) + assertStatementSql(t, stmt, expectedSQL, int64(2)) } func TestInvalidInputs(t *testing.T) { - assertStatementErr(t, table1.UPDATE(table1ColInt).SET(1), "jet: WHERE clause not set") - assertStatementErr(t, table1.UPDATE(nil).SET(1), "jet: nil column in columns list") + assertStatementSqlErr(t, table1.UPDATE(table1ColInt).SET(1), "jet: WHERE clause not set") + assertStatementSqlErr(t, table1.UPDATE(nil).SET(1), "jet: nil column in columns list") } diff --git a/postgres/utils_test.go b/postgres/utils_test.go index 78d66e0..c65d5b6 100644 --- a/postgres/utils_test.go +++ b/postgres/utils_test.go @@ -2,7 +2,7 @@ package postgres import ( "github.com/go-jet/jet/internal/jet" - "gotest.tools/assert" + "github.com/go-jet/jet/internal/testutils" "testing" ) @@ -71,46 +71,16 @@ var table3 = NewTable( table3StrCol) func assertClauseSerialize(t *testing.T, clause jet.Serializer, query string, args ...interface{}) { - out := jet.SqlBuilder{Dialect: Dialect} - err := jet.Serialize(clause, jet.SelectStatementType, &out) - - assert.NilError(t, err) - - assert.DeepEqual(t, out.Buff.String(), query) - assert.DeepEqual(t, out.Args, args) + testutils.AssertClauseSerialize(t, Dialect, clause, query, args...) } func assertClauseSerializeErr(t *testing.T, clause jet.Serializer, errString string) { - out := jet.SqlBuilder{Dialect: Dialect} - err := jet.Serialize(clause, jet.SelectStatementType, &out) - - //fmt.Println(out.buff.String()) - assert.Assert(t, err != nil) - assert.Error(t, err, errString) + testutils.AssertClauseSerializeErr(t, Dialect, clause, errString) } func assertProjectionSerialize(t *testing.T, projection jet.Projection, query string, args ...interface{}) { - out := jet.SqlBuilder{Dialect: Dialect} - err := jet.SerializeForProjection(projection, jet.SelectStatementType, &out) - - assert.NilError(t, err) - - assert.DeepEqual(t, out.Buff.String(), query) - assert.DeepEqual(t, out.Args, args) + testutils.AssertProjectionSerialize(t, Dialect, projection, query, args...) } -func assertStatement(t *testing.T, query jet.Statement, expectedQuery string, expectedArgs ...interface{}) { - queryStr, args, err := query.Sql() - assert.NilError(t, err) - - //fmt.Println(queryStr) - assert.Equal(t, queryStr, expectedQuery) - assert.DeepEqual(t, args, expectedArgs) -} - -func assertStatementErr(t *testing.T, stmt jet.Statement, errorStr string) { - _, _, err := stmt.Sql() - - assert.Assert(t, err != nil) - assert.Error(t, err, errorStr) -} +var assertStatementSql = testutils.AssertStatementSql +var assertStatementSqlErr = testutils.AssertStatementSqlErr