diff --git a/internal/jet/sql_builder.go b/internal/jet/sql_builder.go index 642f0cd..59b776f 100644 --- a/internal/jet/sql_builder.go +++ b/internal/jet/sql_builder.go @@ -163,10 +163,17 @@ func argToString(value interface{}) string { case time.Time: return stringQuote(string(pq.FormatTimestamp(bindVal))) default: + if strBindValue, ok := bindVal.(toStringInterface); ok { + return stringQuote(strBindValue.String()) + } panic(fmt.Sprintf("jet: %s type can not be used as SQL query parameter", reflect.TypeOf(value).String())) } } +type toStringInterface interface { + String() string +} + func integerTypesToString(value interface{}) string { switch bindVal := value.(type) { case int: diff --git a/internal/jet/utils.go b/internal/jet/utils.go index 50f4e13..2f301be 100644 --- a/internal/jet/utils.go +++ b/internal/jet/utils.go @@ -59,7 +59,7 @@ func SerializeColumnNames(columns []Column, out *SQLBuilder) { panic("jet: nil column in columns list") } - out.WriteString(col.Name()) + out.WriteIdentifier(col.Name()) } } diff --git a/internal/testutils/test_utils.go b/internal/testutils/test_utils.go index 8673ff4..aa53b1e 100644 --- a/internal/testutils/test_utils.go +++ b/internal/testutils/test_utils.go @@ -7,6 +7,7 @@ import ( "github.com/go-jet/jet/internal/jet" "github.com/go-jet/jet/internal/utils" "github.com/go-jet/jet/qrm" + "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "io/ioutil" @@ -14,6 +15,7 @@ import ( "path/filepath" "runtime" "testing" + "time" "github.com/google/go-cmp/cmp" ) @@ -224,3 +226,80 @@ func AssertFileNamesEqual(t *testing.T, fileInfos []os.FileInfo, fileNames ...st func AssertDeepEqual(t *testing.T, actual, expected interface{}, msg ...string) { assert.True(t, cmp.Equal(actual, expected), msg) } + +// BoolPtr returns address of bool parameter +func BoolPtr(b bool) *bool { + return &b +} + +// Int8Ptr returns address of int8 parameter +func Int8Ptr(i int8) *int8 { + return &i +} + +// UInt8Ptr returns address of uint8 parameter +func UInt8Ptr(i uint8) *uint8 { + return &i +} + +// Int16Ptr returns address of int16 parameter +func Int16Ptr(i int16) *int16 { + return &i +} + +// UInt16Ptr returns address of uint16 parameter +func UInt16Ptr(i uint16) *uint16 { + return &i +} + +// Int32Ptr returns address of int32 parameter +func Int32Ptr(i int32) *int32 { + return &i +} + +// UInt32Ptr returns address of uint32 parameter +func UInt32Ptr(i uint32) *uint32 { + return &i +} + +// Int64Ptr returns address of int64 parameter +func Int64Ptr(i int64) *int64 { + return &i +} + +// UInt64Ptr returns address of uint64 parameter +func UInt64Ptr(i uint64) *uint64 { + return &i +} + +// StringPtr returns address of string parameter +func StringPtr(s string) *string { + return &s +} + +// TimePtr returns address of time.Time parameter +func TimePtr(t time.Time) *time.Time { + return &t +} + +// ByteArrayPtr returns address of []byte parameter +func ByteArrayPtr(arr []byte) *[]byte { + return &arr +} + +// Float32Ptr returns address of float32 parameter +func Float32Ptr(f float32) *float32 { + return &f +} + +// Float64Ptr returns address of float64 parameter +func Float64Ptr(f float64) *float64 { + return &f +} + +// UUIDPtr returns address of uuid.UUID +func UUIDPtr(u string) *uuid.UUID { + newUUID := uuid.MustParse(u) + + return &newUUID +} diff --git a/mysql/dialect.go b/mysql/dialect.go index cfd452a..55862f9 100644 --- a/mysql/dialect.go +++ b/mysql/dialect.go @@ -26,6 +26,7 @@ func newDialect() jet.Dialect { ArgumentPlaceholder: func(int) string { return "?" }, + ReservedWords: reservedWords, } return jet.NewDialect(mySQLDialectParams) @@ -160,3 +161,267 @@ func mysqlNOTREGEXPLIKEoperator(expressions ...jet.Serializer) jet.SerializerFun jet.Serialize(expressions[1], statement, out, options...) } } + +var reservedWords = []string{ + "ACCESSIBLE", + "ADD", + "ALL", + "ALTER", + "ANALYZE", + "AND", + "AS", + "ASC", + "ASENSITIVE", + "BEFORE", + "BETWEEN", + "BIGINT", + "BINARY", + "BLOB", + "BOTH", + "BY", + "CALL", + "CASCADE", + "CASE", + "CHANGE", + "CHAR", + "CHARACTER", + "CHECK", + "COLLATE", + "COLUMN", + "CONDITION", + "CONSTRAINT", + "CONTINUE", + "CONVERT", + "CREATE", + "CROSS", + "CUBE", + "CUME_DIST", + "CURRENT_DATE", + "CURRENT_TIME", + "CURRENT_TIMESTAMP", + "CURRENT_USER", + "CURSOR", + "DATABASE", + "DATABASES", + "DAY_HOUR", + "DAY_MICROSECOND", + "DAY_MINUTE", + "DAY_SECOND", + "DEC", + "DECIMAL", + "DECLARE", + "DEFAULT", + "DELAYED", + "DELETE", + "DENSE_RANK", + "DESC", + "DESCRIBE", + "DETERMINISTIC", + "DISTINCT", + "DISTINCTROW", + "DIV", + "DOUBLE", + "DROP", + "DUAL", + "EACH", + "ELSE", + "ELSEIF", + "EMPTY", + "ENCLOSED", + "ESCAPED", + "EXCEPT", + "EXISTS", + "EXIT", + "EXPLAIN", + "FALSE", + "FETCH", + "FIRST_VALUE", + "FLOAT", + "FLOAT4", + "FLOAT8", + "FOR", + "FORCE", + "FOREIGN", + "FROM", + "FULLTEXT", + "FUNCTION", + "GENERATED", + "GET", + "GRANT", + "GROUP", + "GROUPING", + "GROUPS", + "HAVING", + "HIGH_PRIORITY", + "HOUR_MICROSECOND", + "HOUR_MINUTE", + "HOUR_SECOND", + "IF", + "IGNORE", + "IN", + "INDEX", + "INFILE", + "INNER", + "INOUT", + "INSENSITIVE", + "INSERT", + "INT", + "INT1", + "INT2", + "INT3", + "INT4", + "INT8", + "INTEGER", + "INTERVAL", + "INTO", + "IO_AFTER_GTIDS", + "IO_BEFORE_GTIDS", + "IS", + "ITERATE", + "JOIN", + "JSON_TABLE", + "KEY", + "KEYS", + "KILL", + "LAG", + "LAST_VALUE", + "LATERAL", + "LEAD", + "LEADING", + "LEAVE", + "LEFT", + "LIKE", + "LIMIT", + "LINEAR", + "LINES", + "LOAD", + "LOCALTIME", + "LOCALTIMESTAMP", + "LOCK", + "LONG", + "LONGBLOB", + "LONGTEXT", + "LOOP", + "LOW_PRIORITY", + "MASTER_BIND", + "MASTER_SSL_VERIFY_SERVER_CERT", + "MATCH", + "MAXVALUE", + "MEDIUMBLOB", + "MEDIUMINT", + "MEDIUMTEXT", + "MIDDLEINT", + "MINUTE_MICROSECOND", + "MINUTE_SECOND", + "MOD", + "MODIFIES", + "NATURAL", + "NOT", + "NO_WRITE_TO_BINLOG", + "NTH_VALUE", + "NTILE", + "NULL", + "NUMERIC", + "OF", + "ON", + "OPTIMIZE", + "OPTIMIZER_COSTS", + "OPTION", + "OPTIONALLY", + "OR", + "ORDER", + "OUT", + "OUTER", + "OUTFILE", + "OVER", + "PARTITION", + "PERCENT_RANK", + "PRECISION", + "PRIMARY", + "PROCEDURE", + "PURGE", + "RANGE", + "RANK", + "READ", + "READS", + "READ_WRITE", + "REAL", + "RECURSIVE", + "REFERENCES", + "REGEXP", + "RELEASE", + "RENAME", + "REPEAT", + "REPLACE", + "REQUIRE", + "RESIGNAL", + "RESTRICT", + "RETURN", + "REVOKE", + "RIGHT", + "RLIKE", + "ROW", + "ROWS", + "ROW_NUMBER", + "SCHEMA", + "SCHEMAS", + "SECOND_MICROSECOND", + "SELECT", + "SENSITIVE", + "SEPARATOR", + "SET", + "SHOW", + "SIGNAL", + "SMALLINT", + "SPATIAL", + "SPECIFIC", + "SQL", + "SQLEXCEPTION", + "SQLSTATE", + "SQLWARNING", + "SQL_BIG_RESULT", + "SQL_CALC_FOUND_ROWS", + "SQL_SMALL_RESULT", + "SSL", + "STARTING", + "STORED", + "STRAIGHT_JOIN", + "SYSTEM", + "TABLE", + "TERMINATED", + "THEN", + "TINYBLOB", + "TINYINT", + "TINYTEXT", + "TO", + "TRAILING", + "TRIGGER", + "TRUE", + "UNDO", + "UNION", + "UNIQUE", + "UNLOCK", + "UNSIGNED", + "UPDATE", + "USAGE", + "USE", + "USING", + "UTC_DATE", + "UTC_TIME", + "UTC_TIMESTAMP", + "VALUES", + "VARBINARY", + "VARCHAR", + "VARCHARACTER", + "VARYING", + "VIRTUAL", + "WHEN", + "WHERE", + "WHILE", + "WINDOW", + "WITH", + "WRITE", + "XOR", + "YEAR_MONTH", + "ZEROFILL", +} diff --git a/tests/mysql/alltypes_test.go b/tests/mysql/alltypes_test.go index 2c6768e..636359d 100644 --- a/tests/mysql/alltypes_test.go +++ b/tests/mysql/alltypes_test.go @@ -1,6 +1,9 @@ package mysql import ( + "fmt" + "github.com/stretchr/testify/require" + "strings" "testing" "time" @@ -95,23 +98,23 @@ func TestExpressionOperators(t *testing.T) { //fmt.Println(query.Sql()) - testutils.AssertStatementSql(t, query, ` -SELECT all_types.integer IS NULL AS "result.is_null", + testutils.AssertStatementSql(t, query, strings.Replace(` +SELECT all_types.'integer' IS NULL AS "result.is_null", all_types.date_ptr IS NOT NULL AS "result.is_not_null", (all_types.small_int_ptr IN (?, ?)) AS "result.in", (all_types.small_int_ptr IN (( - SELECT all_types.integer AS "all_types.integer" + SELECT all_types.'integer' AS "all_types.integer" FROM test_sample.all_types ))) AS "result.in_select", (all_types.small_int_ptr NOT IN (?, ?, NULL)) AS "result.not_in", (all_types.small_int_ptr NOT IN (( - SELECT all_types.integer AS "all_types.integer" + SELECT all_types.'integer' AS "all_types.integer" FROM test_sample.all_types ))) AS "result.not_in_select", DATABASE() FROM test_sample.all_types LIMIT ?; -`, int64(11), int64(22), int64(11), int64(22), int64(2)) +`, "'", "`", -1), int64(11), int64(22), int64(11), int64(22), int64(2)) var dest []struct { common.ExpressionTestResult `alias:"result.*"` @@ -261,45 +264,47 @@ func TestFloatOperators(t *testing.T) { queryStr, _ := query.Sql() - assert.Equal(t, queryStr, ` -SELECT (all_types.numeric = all_types.numeric) AS "eq1", - (all_types.decimal = ?) AS "eq2", - (all_types.real = ?) AS "eq3", - (NOT(all_types.numeric <=> all_types.numeric)) AS "distinct1", - (NOT(all_types.decimal <=> ?)) AS "distinct2", - (NOT(all_types.real <=> ?)) AS "distinct3", - (all_types.numeric <=> all_types.numeric) AS "not_distinct1", - (all_types.decimal <=> ?) AS "not_distinct2", - (all_types.real <=> ?) AS "not_distinct3", - (all_types.numeric < ?) AS "lt1", - (all_types.numeric < ?) AS "lt2", - (all_types.numeric > ?) AS "gt1", - (all_types.numeric > ?) AS "gt2", - TRUNCATE((all_types.decimal + all_types.decimal), ?) AS "add1", - TRUNCATE((all_types.decimal + ?), ?) AS "add2", - TRUNCATE((all_types.decimal - all_types.decimal_ptr), ?) AS "sub1", - TRUNCATE((all_types.decimal - ?), ?) AS "sub2", - TRUNCATE((all_types.decimal * all_types.decimal_ptr), ?) AS "mul1", - TRUNCATE((all_types.decimal * ?), ?) AS "mul2", - TRUNCATE((all_types.decimal / all_types.decimal_ptr), ?) AS "div1", - TRUNCATE((all_types.decimal / ?), ?) AS "div2", - TRUNCATE((all_types.decimal % all_types.decimal_ptr), ?) AS "mod1", - TRUNCATE((all_types.decimal % ?), ?) AS "mod2", - TRUNCATE(POW(all_types.decimal, all_types.decimal_ptr), ?) AS "pow1", - TRUNCATE(POW(all_types.decimal, ?), ?) AS "pow2", - TRUNCATE(ABS(all_types.decimal), ?) AS "abs", - TRUNCATE(POWER(all_types.decimal, ?), ?) AS "power", - TRUNCATE(SQRT(all_types.decimal), ?) AS "sqrt", - TRUNCATE(POWER(all_types.decimal, (? / ?)), ?) AS "cbrt", - CEIL(all_types.real) AS "ceil", - FLOOR(all_types.real) AS "floor", - ROUND(all_types.decimal) AS "round1", - ROUND(all_types.decimal, ?) AS "round2", - SIGN(all_types.real) AS "sign", - TRUNCATE(all_types.decimal, ?) AS "trunc" + //fmt.Println(queryStr) + + assert.Equal(t, queryStr, strings.Replace(` +SELECT (all_types.'numeric' = all_types.'numeric') AS "eq1", + (all_types.'decimal' = ?) AS "eq2", + (all_types.'real' = ?) AS "eq3", + (NOT(all_types.'numeric' <=> all_types.'numeric')) AS "distinct1", + (NOT(all_types.'decimal' <=> ?)) AS "distinct2", + (NOT(all_types.'real' <=> ?)) AS "distinct3", + (all_types.'numeric' <=> all_types.'numeric') AS "not_distinct1", + (all_types.'decimal' <=> ?) AS "not_distinct2", + (all_types.'real' <=> ?) AS "not_distinct3", + (all_types.'numeric' < ?) AS "lt1", + (all_types.'numeric' < ?) AS "lt2", + (all_types.'numeric' > ?) AS "gt1", + (all_types.'numeric' > ?) AS "gt2", + TRUNCATE((all_types.'decimal' + all_types.'decimal'), ?) AS "add1", + TRUNCATE((all_types.'decimal' + ?), ?) AS "add2", + TRUNCATE((all_types.'decimal' - all_types.decimal_ptr), ?) AS "sub1", + TRUNCATE((all_types.'decimal' - ?), ?) AS "sub2", + TRUNCATE((all_types.'decimal' * all_types.decimal_ptr), ?) AS "mul1", + TRUNCATE((all_types.'decimal' * ?), ?) AS "mul2", + TRUNCATE((all_types.'decimal' / all_types.decimal_ptr), ?) AS "div1", + TRUNCATE((all_types.'decimal' / ?), ?) AS "div2", + TRUNCATE((all_types.'decimal' % all_types.decimal_ptr), ?) AS "mod1", + TRUNCATE((all_types.'decimal' % ?), ?) AS "mod2", + TRUNCATE(POW(all_types.'decimal', all_types.decimal_ptr), ?) AS "pow1", + TRUNCATE(POW(all_types.'decimal', ?), ?) AS "pow2", + TRUNCATE(ABS(all_types.'decimal'), ?) AS "abs", + TRUNCATE(POWER(all_types.'decimal', ?), ?) AS "power", + TRUNCATE(SQRT(all_types.'decimal'), ?) AS "sqrt", + TRUNCATE(POWER(all_types.'decimal', (? / ?)), ?) AS "cbrt", + CEIL(all_types.'real') AS "ceil", + FLOOR(all_types.'real') AS "floor", + ROUND(all_types.'decimal') AS "round1", + ROUND(all_types.'decimal', ?) AS "round2", + SIGN(all_types.'real') AS "sign", + TRUNCATE(all_types.'decimal', ?) AS "trunc" FROM test_sample.all_types LIMIT ?; -`) +`, "'", "`", -1)) var dest []struct { common.FloatExpressionTestResult `alias:"."` @@ -568,7 +573,7 @@ func TestTimeExpressions(t *testing.T) { //fmt.Println(query.DebugSql()) - testutils.AssertDebugStatementSql(t, query, ` + testutils.AssertDebugStatementSql(t, query, strings.Replace(` SELECT CAST('20:34:58' AS TIME), all_types.time = all_types.time, all_types.time = CAST('23:06:06' AS TIME), @@ -589,7 +594,7 @@ SELECT CAST('20:34:58' AS TIME), all_types.time >= all_types.time, all_types.time >= CAST('14:26:36' AS TIME), all_types.time + INTERVAL 10 MINUTE, - all_types.time + INTERVAL all_types.integer MINUTE, + all_types.time + INTERVAL all_types.''integer'' MINUTE, all_types.time + INTERVAL 3 HOUR, all_types.time - INTERVAL 20 MINUTE, all_types.time - INTERVAL all_types.small_int MINUTE, @@ -598,7 +603,7 @@ SELECT CAST('20:34:58' AS TIME), CURRENT_TIME, CURRENT_TIME(3) FROM test_sample.all_types; -`, "20:34:58", "23:06:06", "22:06:06.011", "21:06:06.011111", "20:16:06", +`, "''", "`", -1), "20:34:58", "23:06:06", "22:06:06.011", "21:06:06.011111", "20:16:06", "19:26:06", "18:36:06", "17:46:06", "16:56:56", "15:16:46", "14:26:36") dest := []struct{}{} @@ -648,25 +653,25 @@ func TestDateExpressions(t *testing.T) { //fmt.Println(query.DebugSql()) - testutils.AssertStatementSql(t, query, ` -SELECT CAST(? AS DATE), + testutils.AssertDebugStatementSql(t, query, ` +SELECT CAST('2009-11-17' AS DATE), all_types.date = all_types.date, - all_types.date = CAST(? AS DATE), + all_types.date = CAST('2019-06-06' AS DATE), all_types.date_ptr != all_types.date, - all_types.date_ptr != CAST(? AS DATE), + all_types.date_ptr != CAST('2019-01-06' AS DATE), NOT(all_types.date <=> all_types.date), - NOT(all_types.date <=> CAST(? AS DATE)), + NOT(all_types.date <=> CAST('2019-02-06' AS DATE)), all_types.date <=> all_types.date, - all_types.date <=> CAST(? AS DATE), + all_types.date <=> CAST('2019-03-06' AS DATE), all_types.date < all_types.date, - all_types.date < CAST(? AS DATE), + all_types.date < CAST('2019-04-06' AS DATE), all_types.date <= all_types.date, - all_types.date <= CAST(? AS DATE), + all_types.date <= CAST('2019-05-05' AS DATE), all_types.date > all_types.date, - all_types.date > CAST(? AS DATE), + all_types.date > CAST('2019-01-04' AS DATE), all_types.date >= all_types.date, - all_types.date >= CAST(? AS DATE), - all_types.date + INTERVAL ? MINUTE_MICROSECOND, + all_types.date >= CAST('2019-02-03' AS DATE), + all_types.date + INTERVAL '10:20.000100' MINUTE_MICROSECOND, all_types.date + INTERVAL all_types.big_int MINUTE, all_types.date + INTERVAL 15 HOUR, all_types.date - INTERVAL 20 MINUTE, @@ -963,6 +968,91 @@ func TestINTERVAL(t *testing.T) { assert.NoError(t, err) } +func TestAllTypesInsert(t *testing.T) { + tx, err := db.Begin() + require.NoError(t, err) + + stmt := AllTypes.INSERT(AllTypes.AllColumns). + MODEL(toInsert) + + fmt.Println(stmt.DebugSql()) + + testutils.AssertExec(t, stmt, tx, 1) + + var dest model.AllTypes + err = AllTypes.SELECT(AllTypes.AllColumns). + WHERE(AllTypes.BigInt.EQ(Int(toInsert.BigInt))). + Query(tx, &dest) + + require.NoError(t, err) + require.Equal(t, toInsert.TinyInt, dest.TinyInt) + + err = tx.Rollback() + require.NoError(t, err) +} + +var toInsert = model.AllTypes{ + Boolean: false, + BooleanPtr: testutils.BoolPtr(true), + TinyInt: 1, + UTinyInt: 2, + SmallInt: 3, + USmallInt: 4, + MediumInt: 5, + UMediumInt: 6, + Integer: 7, + UInteger: 8, + BigInt: 9, + UBigInt: 1122334455, + TinyIntPtr: testutils.Int8Ptr(11), + UTinyIntPtr: testutils.UInt8Ptr(22), + SmallIntPtr: testutils.Int16Ptr(33), + USmallIntPtr: testutils.UInt16Ptr(44), + MediumIntPtr: testutils.Int32Ptr(55), + UMediumIntPtr: testutils.UInt32Ptr(66), + IntegerPtr: testutils.Int32Ptr(77), + UIntegerPtr: testutils.UInt32Ptr(88), + BigIntPtr: testutils.Int64Ptr(99), + UBigIntPtr: testutils.UInt64Ptr(111), + Decimal: 11.22, + DecimalPtr: testutils.Float64Ptr(33.44), + Numeric: 55.66, + NumericPtr: testutils.Float64Ptr(77.88), + Float: 99.00, + FloatPtr: testutils.Float64Ptr(11.22), + Double: 33.44, + DoublePtr: testutils.Float64Ptr(55.66), + Real: 77.88, + RealPtr: testutils.Float64Ptr(99.00), + Bit: "1", + BitPtr: testutils.StringPtr("0"), + Time: time.Date(0, 0, 0, 10, 11, 12, 100, &time.Location{}), + TimePtr: testutils.TimePtr(time.Date(0, 0, 0, 10, 11, 12, 100, time.UTC)), + Date: time.Now(), + DatePtr: testutils.TimePtr(time.Now()), + DateTime: time.Now(), + DateTimePtr: testutils.TimePtr(time.Now()), + Timestamp: time.Now(), + TimestampPtr: testutils.TimePtr(time.Now()), + Year: 2000, + YearPtr: testutils.Int16Ptr(2001), + Char: "abcd", + CharPtr: testutils.StringPtr("absd"), + VarChar: "abcd", + VarCharPtr: testutils.StringPtr("absd"), + Binary: []byte("1010"), + BinaryPtr: testutils.ByteArrayPtr([]byte("100001")), + VarBinary: []byte("1010"), + VarBinaryPtr: testutils.ByteArrayPtr([]byte("100001")), + Blob: []byte("large file"), + BlobPtr: testutils.ByteArrayPtr([]byte("very large file")), + Text: "some text", + TextPtr: testutils.StringPtr("text"), + Enum: model.AllTypesEnum_Value1, + JSON: "{}", + JSONPtr: testutils.StringPtr(`{"a": 1}`), +} + var allTypesJson = ` [ { @@ -1100,24 +1190,22 @@ func TestReservedWord(t *testing.T) { stmt := SELECT(User.AllColumns). FROM(User) - // NOTE: A word that follows a period in a qualified name must be an identifier, so it - // need not be quoted even if it is reserved - testutils.AssertDebugStatementSql(t, stmt, ` -SELECT user.column AS "user.column", - user.use AS "user.use", + testutils.AssertDebugStatementSql(t, stmt, strings.Replace(` +SELECT user.''column'' AS "user.column", + user.''use'' AS "user.use", user.ceil AS "user.ceil", user.commit AS "user.commit", - user.create AS "user.create", - user.default AS "user.default", - user.desc AS "user.desc", - user.empty AS "user.empty", - user.float AS "user.float", - user.join AS "user.join", - user.like AS "user.like", + user.''create'' AS "user.create", + user.''default'' AS "user.default", + user.''desc'' AS "user.desc", + user.''empty'' AS "user.empty", + user.''float'' AS "user.float", + user.''join'' AS "user.join", + user.''like'' AS "user.like", user.max AS "user.max", - user.rank AS "user.rank" + user.''rank'' AS "user.rank" FROM test_sample.user; -`) +`, "''", "`", -1)) var dest []model.User err := stmt.Query(db, &dest) diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go index 0b30080..f03f38b 100644 --- a/tests/postgres/alltypes_test.go +++ b/tests/postgres/alltypes_test.go @@ -1066,32 +1066,32 @@ LIMIT $6; } var allTypesRow0 = model.AllTypes{ - SmallIntPtr: Int16Ptr(14), + SmallIntPtr: testutils.Int16Ptr(14), SmallInt: 14, - IntegerPtr: Int32Ptr(300), + IntegerPtr: testutils.Int32Ptr(300), Integer: 300, - BigIntPtr: Int64Ptr(50000), + BigIntPtr: testutils.Int64Ptr(50000), BigInt: 5000, - DecimalPtr: Float64Ptr(1.11), + DecimalPtr: testutils.Float64Ptr(1.11), Decimal: 1.11, - NumericPtr: Float64Ptr(2.22), + NumericPtr: testutils.Float64Ptr(2.22), Numeric: 2.22, - RealPtr: Float32Ptr(5.55), + RealPtr: testutils.Float32Ptr(5.55), Real: 5.55, - DoublePrecisionPtr: Float64Ptr(11111111.22), + DoublePrecisionPtr: testutils.Float64Ptr(11111111.22), DoublePrecision: 11111111.22, Smallserial: 1, Serial: 1, Bigserial: 1, //MoneyPtr: nil, //Money: - VarCharPtr: StringPtr("ABBA"), + VarCharPtr: testutils.StringPtr("ABBA"), VarChar: "ABBA", - CharPtr: StringPtr("JOHN "), + CharPtr: testutils.StringPtr("JOHN "), Char: "JOHN ", - TextPtr: StringPtr("Some text"), + TextPtr: testutils.StringPtr("Some text"), Text: "Some text", - ByteaPtr: ByteArrayPtr([]byte("bytea")), + ByteaPtr: testutils.ByteArrayPtr([]byte("bytea")), Bytea: []byte("bytea"), TimestampzPtr: testutils.TimestampWithTimeZone("1999-01-08 13:05:06 +0100 CET", 0), Timestampz: *testutils.TimestampWithTimeZone("1999-01-08 13:05:06 +0100 CET", 0), @@ -1103,31 +1103,31 @@ var allTypesRow0 = model.AllTypes{ Timez: *testutils.TimeWithTimeZone("04:05:06 -0800"), TimePtr: testutils.TimeWithoutTimeZone("04:05:06"), Time: *testutils.TimeWithoutTimeZone("04:05:06"), - IntervalPtr: StringPtr("3 days 04:05:06"), + IntervalPtr: testutils.StringPtr("3 days 04:05:06"), Interval: "3 days 04:05:06", - BooleanPtr: BoolPtr(true), + BooleanPtr: testutils.BoolPtr(true), Boolean: false, - PointPtr: StringPtr("(2,3)"), - BitPtr: StringPtr("101"), + PointPtr: testutils.StringPtr("(2,3)"), + BitPtr: testutils.StringPtr("101"), Bit: "101", - BitVaryingPtr: StringPtr("101111"), + BitVaryingPtr: testutils.StringPtr("101111"), BitVarying: "101111", - TsvectorPtr: StringPtr("'supernova':1"), + TsvectorPtr: testutils.StringPtr("'supernova':1"), Tsvector: "'supernova':1", - UUIDPtr: UUIDPtr("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"), + UUIDPtr: testutils.UUIDPtr("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"), UUID: uuid.MustParse("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"), - XMLPtr: StringPtr("abc"), + XMLPtr: testutils.StringPtr("abc"), XML: "abc", - JSONPtr: StringPtr(`{"a": 1, "b": 3}`), + JSONPtr: testutils.StringPtr(`{"a": 1, "b": 3}`), JSON: `{"a": 1, "b": 3}`, - JsonbPtr: StringPtr(`{"a": 1, "b": 3}`), + JsonbPtr: testutils.StringPtr(`{"a": 1, "b": 3}`), Jsonb: `{"a": 1, "b": 3}`, - IntegerArrayPtr: StringPtr("{1,2,3}"), + IntegerArrayPtr: testutils.StringPtr("{1,2,3}"), IntegerArray: "{1,2,3}", - TextArrayPtr: StringPtr("{breakfast,consulting}"), + TextArrayPtr: testutils.StringPtr("{breakfast,consulting}"), TextArray: "{breakfast,consulting}", JsonbArray: `{"{\"a\": 1, \"b\": 2}","{\"a\": 3, \"b\": 4}"}`, - TextMultiDimArrayPtr: StringPtr("{{meeting,lunch},{training,presentation}}"), + TextMultiDimArrayPtr: testutils.StringPtr("{{meeting,lunch},{training,presentation}}"), TextMultiDimArray: "{{meeting,lunch},{training,presentation}}", } diff --git a/tests/postgres/sample_test.go b/tests/postgres/sample_test.go index b2bed1d..ea35c16 100644 --- a/tests/postgres/sample_test.go +++ b/tests/postgres/sample_test.go @@ -28,7 +28,7 @@ WHERE all_types.uuid = 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11'; err := query.Query(db, &result) assert.NoError(t, err) assert.Equal(t, result.UUID, uuid.MustParse("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11")) - testutils.AssertDeepEqual(t, result.UUIDPtr, UUIDPtr("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11")) + testutils.AssertDeepEqual(t, result.UUIDPtr, testutils.UUIDPtr("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11")) } func TestUUIDComplex(t *testing.T) { @@ -280,7 +280,7 @@ ORDER BY employee.employee_id; FirstName: "Salley", LastName: "Lester", EmploymentDate: testutils.TimestampWithTimeZone("1999-01-08 04:05:06 +0100 CET", 1), - ManagerID: Int32Ptr(3), + ManagerID: testutils.Int32Ptr(3), }) } @@ -322,7 +322,7 @@ FROM test_sample."WEIRD NAMES TABLE"; WeirdColumnName5: "Doe", WeirdColumnName6: "Doe", WeirdColumnName7: "Doe", - Weirdcolumnname8: StringPtr("Doe"), + Weirdcolumnname8: testutils.StringPtr("Doe"), WeirdColName9: "Doe", WeirdColuName10: "Doe", WeirdColuName11: "Doe", diff --git a/tests/postgres/scan_test.go b/tests/postgres/scan_test.go index 40e18f8..c5b6d3e 100644 --- a/tests/postgres/scan_test.go +++ b/tests/postgres/scan_test.go @@ -505,10 +505,10 @@ func TestScanToSlice(t *testing.T) { assert.NoError(t, err) assert.Equal(t, len(dest), 2) testutils.AssertDeepEqual(t, dest[0].Film, film1) - testutils.AssertDeepEqual(t, dest[0].IDs, []*int32{Int32Ptr(1), Int32Ptr(2), Int32Ptr(3), Int32Ptr(4), - Int32Ptr(5), Int32Ptr(6), Int32Ptr(7), Int32Ptr(8)}) + testutils.AssertDeepEqual(t, dest[0].IDs, []*int32{testutils.Int32Ptr(1), testutils.Int32Ptr(2), testutils.Int32Ptr(3), testutils.Int32Ptr(4), + testutils.Int32Ptr(5), testutils.Int32Ptr(6), testutils.Int32Ptr(7), testutils.Int32Ptr(8)}) testutils.AssertDeepEqual(t, dest[1].Film, film2) - testutils.AssertDeepEqual(t, dest[1].IDs, []*int32{Int32Ptr(9), Int32Ptr(10)}) + testutils.AssertDeepEqual(t, dest[1].IDs, []*int32{testutils.Int32Ptr(9), testutils.Int32Ptr(10)}) }) t.Run("complex struct 1", func(t *testing.T) { @@ -726,10 +726,10 @@ func TestStructScanAllNull(t *testing.T) { var address256 = model.Address{ AddressID: 256, Address: "1497 Yuzhou Drive", - Address2: StringPtr(""), + Address2: testutils.StringPtr(""), District: "England", CityID: 312, - PostalCode: StringPtr("3433"), + PostalCode: testutils.StringPtr("3433"), Phone: "246810237916", LastUpdate: *testutils.TimestampWithoutTimeZone("2006-02-15 09:45:30", 0), } @@ -737,10 +737,10 @@ var address256 = model.Address{ var addres517 = model.Address{ AddressID: 517, Address: "548 Uruapan Street", - Address2: StringPtr(""), + Address2: testutils.StringPtr(""), District: "Ontario", CityID: 312, - PostalCode: StringPtr("35653"), + PostalCode: testutils.StringPtr("35653"), Phone: "879347453467", LastUpdate: *testutils.TimestampWithoutTimeZone("2006-02-15 09:45:30", 0), } @@ -750,12 +750,12 @@ var customer256 = model.Customer{ StoreID: 2, FirstName: "Mattie", LastName: "Hoffman", - Email: StringPtr("mattie.hoffman@sakilacustomer.org"), + Email: testutils.StringPtr("mattie.hoffman@sakilacustomer.org"), AddressID: 256, Activebool: true, CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0), LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 0), - Active: Int32Ptr(1), + Active: testutils.Int32Ptr(1), } var customer512 = model.Customer{ @@ -763,12 +763,12 @@ var customer512 = model.Customer{ StoreID: 1, FirstName: "Cecil", LastName: "Vines", - Email: StringPtr("cecil.vines@sakilacustomer.org"), + Email: testutils.StringPtr("cecil.vines@sakilacustomer.org"), AddressID: 517, Activebool: true, CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0), LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 0), - Active: Int32Ptr(1), + Active: testutils.Int32Ptr(1), } var countryUk = model.Country{ @@ -801,32 +801,32 @@ var inventory2 = model.Inventory{ var film1 = model.Film{ FilmID: 1, Title: "Academy Dinosaur", - Description: StringPtr("A Epic Drama of a Feminist And a Mad Scientist who must Battle a Teacher in The Canadian Rockies"), - ReleaseYear: Int32Ptr(2006), + Description: testutils.StringPtr("A Epic Drama of a Feminist And a Mad Scientist who must Battle a Teacher in The Canadian Rockies"), + ReleaseYear: testutils.Int32Ptr(2006), LanguageID: 1, RentalDuration: 6, RentalRate: 0.99, - Length: Int16Ptr(86), + Length: testutils.Int16Ptr(86), ReplacementCost: 20.99, Rating: &pgRating, LastUpdate: *testutils.TimestampWithoutTimeZone("2013-05-26 14:50:58.951", 3), - SpecialFeatures: StringPtr("{\"Deleted Scenes\",\"Behind the Scenes\"}"), + SpecialFeatures: testutils.StringPtr("{\"Deleted Scenes\",\"Behind the Scenes\"}"), Fulltext: "'academi':1 'battl':15 'canadian':20 'dinosaur':2 'drama':5 'epic':4 'feminist':8 'mad':11 'must':14 'rocki':21 'scientist':12 'teacher':17", } var film2 = model.Film{ FilmID: 2, Title: "Ace Goldfinger", - Description: StringPtr("A Astounding Epistle of a Database Administrator And a Explorer who must Find a Car in Ancient China"), - ReleaseYear: Int32Ptr(2006), + Description: testutils.StringPtr("A Astounding Epistle of a Database Administrator And a Explorer who must Find a Car in Ancient China"), + ReleaseYear: testutils.Int32Ptr(2006), LanguageID: 1, RentalDuration: 3, RentalRate: 4.99, - Length: Int16Ptr(48), + Length: testutils.Int16Ptr(48), ReplacementCost: 12.99, Rating: &gRating, LastUpdate: *testutils.TimestampWithoutTimeZone("2013-05-26 14:50:58.951", 3), - SpecialFeatures: StringPtr(`{Trailers,"Deleted Scenes"}`), + SpecialFeatures: testutils.StringPtr(`{Trailers,"Deleted Scenes"}`), Fulltext: `'ace':1 'administr':9 'ancient':19 'astound':4 'car':17 'china':20 'databas':8 'epistl':5 'explor':12 'find':15 'goldfing':2 'must':14`, } diff --git a/tests/postgres/select_test.go b/tests/postgres/select_test.go index b641ec0..28afd8c 100644 --- a/tests/postgres/select_test.go +++ b/tests/postgres/select_test.go @@ -982,16 +982,16 @@ ORDER BY film.film_id ASC; testutils.AssertDeepEqual(t, maxRentalRateFilms[0], model.Film{ FilmID: 2, Title: "Ace Goldfinger", - Description: StringPtr("A Astounding Epistle of a Database Administrator And a Explorer who must Find a Car in Ancient China"), - ReleaseYear: Int32Ptr(2006), + Description: testutils.StringPtr("A Astounding Epistle of a Database Administrator And a Explorer who must Find a Car in Ancient China"), + ReleaseYear: testutils.Int32Ptr(2006), LanguageID: 1, RentalRate: 4.99, - Length: Int16Ptr(48), + Length: testutils.Int16Ptr(48), ReplacementCost: 12.99, Rating: &gRating, RentalDuration: 3, LastUpdate: *testutils.TimestampWithoutTimeZone("2013-05-26 14:50:58.951", 3), - SpecialFeatures: StringPtr("{Trailers,\"Deleted Scenes\"}"), + SpecialFeatures: testutils.StringPtr("{Trailers,\"Deleted Scenes\"}"), Fulltext: "'ace':1 'administr':9 'ancient':19 'astound':4 'car':17 'china':20 'databas':8 'epistl':5 'explor':12 'find':15 'goldfing':2 'must':14", }) } @@ -1130,11 +1130,11 @@ ORDER BY customer_payment_sum."amount_sum" ASC; FirstName: "Brian", LastName: "Wyman", AddressID: 323, - Email: StringPtr("brian.wyman@sakilacustomer.org"), + Email: testutils.StringPtr("brian.wyman@sakilacustomer.org"), Activebool: true, CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0), LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 3), - Active: Int32Ptr(1), + Active: testutils.Int32Ptr(1), }) assert.Equal(t, customersWithAmounts[0].AmountSum, 27.93) @@ -1846,8 +1846,8 @@ func TestDynamicCondition(t *testing.T) { Active *bool } - request.CustomerID = Int64Ptr(1) - request.Active = BoolPtr(true) + request.CustomerID = testutils.Int64Ptr(1) + request.Active = testutils.BoolPtr(true) // ... diff --git a/tests/postgres/update_test.go b/tests/postgres/update_test.go index 57f579e..ca07332 100644 --- a/tests/postgres/update_test.go +++ b/tests/postgres/update_test.go @@ -154,7 +154,7 @@ WHERE link.id = 0; ` testutils.AssertDebugStatementSql(t, stmt, expectedSQL, int64(0), int64(0)) - assertExecErr(t, stmt, "pq: number of columns does not match number of values") + testutils.AssertExecErr(t, stmt, db, "pq: number of columns does not match number of values") } func TestUpdateWithModelData(t *testing.T) { @@ -241,7 +241,7 @@ WHERE link.id = 201; ` testutils.AssertDebugStatementSql(t, stmt, expectedSQL, "http://www.duckduckgo.com", "DuckDuckGo", nil, nil, int64(201)) - assertExecErr(t, stmt, "pq: number of columns does not match number of values") + testutils.AssertExecErr(t, stmt, db, "pq: number of columns does not match number of values") } func TestUpdateQueryContext(t *testing.T) { diff --git a/tests/postgres/util_test.go b/tests/postgres/util_test.go index c30a365..b2d5452 100644 --- a/tests/postgres/util_test.go +++ b/tests/postgres/util_test.go @@ -4,7 +4,6 @@ import ( "github.com/go-jet/jet/internal/jet" "github.com/go-jet/jet/internal/testutils" "github.com/go-jet/jet/tests/.gentestdata/jetdb/dvds/model" - "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "testing" @@ -19,60 +18,17 @@ func AssertExec(t *testing.T, stmt jet.Statement, rowsAffected int64) { assert.Equal(t, rows, rowsAffected) } -func assertExecErr(t *testing.T, stmt jet.Statement, errorStr string) { - _, err := stmt.Exec(db) - - assert.Error(t, err, errorStr) -} - -func BoolPtr(b bool) *bool { - return &b -} - -func Int16Ptr(i int16) *int16 { - return &i -} - -func Int32Ptr(i int32) *int32 { - return &i -} - -func Int64Ptr(i int64) *int64 { - return &i -} - -func StringPtr(s string) *string { - return &s -} - -func ByteArrayPtr(arr []byte) *[]byte { - return &arr -} - -func Float32Ptr(f float32) *float32 { - return &f -} -func Float64Ptr(f float64) *float64 { - return &f -} - -func UUIDPtr(u string) *uuid.UUID { - newUUID := uuid.MustParse(u) - - return &newUUID -} - var customer0 = model.Customer{ CustomerID: 1, StoreID: 1, FirstName: "Mary", LastName: "Smith", - Email: StringPtr("mary.smith@sakilacustomer.org"), + Email: testutils.StringPtr("mary.smith@sakilacustomer.org"), AddressID: 5, Activebool: true, CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0), LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 3), - Active: Int32Ptr(1), + Active: testutils.Int32Ptr(1), } var customer1 = model.Customer{ @@ -80,12 +36,12 @@ var customer1 = model.Customer{ StoreID: 1, FirstName: "Patricia", LastName: "Johnson", - Email: StringPtr("patricia.johnson@sakilacustomer.org"), + Email: testutils.StringPtr("patricia.johnson@sakilacustomer.org"), AddressID: 6, Activebool: true, CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0), LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 3), - Active: Int32Ptr(1), + Active: testutils.Int32Ptr(1), } var lastCustomer = model.Customer{ @@ -93,10 +49,10 @@ var lastCustomer = model.Customer{ StoreID: 2, FirstName: "Austin", LastName: "Cintron", - Email: StringPtr("austin.cintron@sakilacustomer.org"), + Email: testutils.StringPtr("austin.cintron@sakilacustomer.org"), AddressID: 605, Activebool: true, CreateDate: *testutils.TimestampWithoutTimeZone("2006-02-14 00:00:00", 0), LastUpdate: testutils.TimestampWithoutTimeZone("2013-05-26 14:49:45.738", 3), - Active: Int32Ptr(1), + Active: testutils.Int32Ptr(1), } diff --git a/tests/testdata b/tests/testdata index 889e07c..1745be3 160000 --- a/tests/testdata +++ b/tests/testdata @@ -1 +1 @@ -Subproject commit 889e07c0ebaf6b4021e31cce29b5861eb5c8cc17 +Subproject commit 1745be34a649c0f37d0d31d7c0352a1248ace2dc