From e8f4c2b31bf2682f731dd1175ba1cb9cf003efa2 Mon Sep 17 00:00:00 2001 From: go-jet Date: Thu, 21 Oct 2021 13:39:24 +0200 Subject: [PATCH] Add SQLBuilder support for SQLite databases. --- .gitignore | 3 +- go.mod | 1 + go.sum | 2 + internal/jet/func_expression.go | 130 ++-- internal/jet/interval.go | 6 +- internal/testutils/test_utils.go | 16 +- qrm/internal/null_types.go | 5 +- sqlite/cast.go | 55 ++ sqlite/cast_test.go | 14 + sqlite/columns.go | 58 ++ sqlite/delete_statement.go | 61 ++ sqlite/delete_statement_test.go | 26 + sqlite/dialect.go | 225 +++++++ sqlite/dialect_test.go | 59 ++ sqlite/expressions.go | 97 +++ sqlite/expressions_test.go | 52 ++ sqlite/functions.go | 342 +++++++++++ sqlite/insert_statement.go | 117 ++++ sqlite/insert_statement_test.go | 150 +++++ sqlite/literal.go | 70 +++ sqlite/literal_test.go | 80 +++ sqlite/on_conflict_clause.go | 84 +++ sqlite/operators.go | 9 + sqlite/select_statement.go | 186 ++++++ sqlite/select_statement_test.go | 156 +++++ sqlite/select_table.go | 24 + sqlite/set_statement.go | 99 ++++ sqlite/set_statement_test.go | 31 + sqlite/statement.go | 8 + sqlite/table.go | 122 ++++ sqlite/table_test.go | 101 ++++ sqlite/types.go | 27 + sqlite/update_statement.go | 70 +++ sqlite/update_statement_test.go | 82 +++ sqlite/utils_test.go | 55 ++ sqlite/with_statement.go | 26 + tests/dbconfig/dbconfig.go | 12 +- tests/init/init.go | 19 + tests/internal/utils/repo/repo.go | 33 ++ tests/sqlite/alltypes_test.go | 912 +++++++++++++++++++++++++++++ tests/sqlite/cast_test.go | 41 ++ tests/sqlite/delete_test.go | 83 +++ tests/sqlite/generator_test.go | 298 ++++++++++ tests/sqlite/insert_test.go | 393 +++++++++++++ tests/sqlite/main_test.go | 90 +++ tests/sqlite/raw_statement_test.go | 121 ++++ tests/sqlite/select_test.go | 749 +++++++++++++++++++++++ tests/sqlite/update_test.go | 290 +++++++++ tests/sqlite/with_test.go | 234 ++++++++ tests/testdata | 2 +- 50 files changed, 5851 insertions(+), 75 deletions(-) create mode 100644 sqlite/cast.go create mode 100644 sqlite/cast_test.go create mode 100644 sqlite/columns.go create mode 100644 sqlite/delete_statement.go create mode 100644 sqlite/delete_statement_test.go create mode 100644 sqlite/dialect.go create mode 100644 sqlite/dialect_test.go create mode 100644 sqlite/expressions.go create mode 100644 sqlite/expressions_test.go create mode 100644 sqlite/functions.go create mode 100644 sqlite/insert_statement.go create mode 100644 sqlite/insert_statement_test.go create mode 100644 sqlite/literal.go create mode 100644 sqlite/literal_test.go create mode 100644 sqlite/on_conflict_clause.go create mode 100644 sqlite/operators.go create mode 100644 sqlite/select_statement.go create mode 100644 sqlite/select_statement_test.go create mode 100644 sqlite/select_table.go create mode 100644 sqlite/set_statement.go create mode 100644 sqlite/set_statement_test.go create mode 100644 sqlite/statement.go create mode 100644 sqlite/table.go create mode 100644 sqlite/table_test.go create mode 100644 sqlite/types.go create mode 100644 sqlite/update_statement.go create mode 100644 sqlite/update_statement_test.go create mode 100644 sqlite/utils_test.go create mode 100644 sqlite/with_statement.go create mode 100644 tests/internal/utils/repo/repo.go create mode 100644 tests/sqlite/alltypes_test.go create mode 100644 tests/sqlite/cast_test.go create mode 100644 tests/sqlite/delete_test.go create mode 100644 tests/sqlite/generator_test.go create mode 100644 tests/sqlite/insert_test.go create mode 100644 tests/sqlite/main_test.go create mode 100644 tests/sqlite/raw_statement_test.go create mode 100644 tests/sqlite/select_test.go create mode 100644 tests/sqlite/update_test.go create mode 100644 tests/sqlite/with_test.go diff --git a/.gitignore b/.gitignore index f286e83..153be12 100644 --- a/.gitignore +++ b/.gitignore @@ -18,4 +18,5 @@ # Test files gen .gentestdata -.tests/testdata/ \ No newline at end of file +.tests/testdata/ +.gen \ No newline at end of file diff --git a/go.mod b/go.mod index c349db9..12ed6d7 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/jackc/pgconn v1.8.1 github.com/jackc/pgx/v4 v4.11.0 //tests github.com/lib/pq v1.7.0 + github.com/mattn/go-sqlite3 v1.14.8 github.com/pkg/profile v1.5.0 //tests github.com/shopspring/decimal v1.2.0 // tests github.com/stretchr/testify v1.6.1 // tests diff --git a/go.sum b/go.sum index 26a2d4a..d0f5f98 100644 --- a/go.sum +++ b/go.sum @@ -218,6 +218,8 @@ github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hd github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-runewidth v0.0.2/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU= +github.com/mattn/go-sqlite3 v1.14.8 h1:gDp86IdQsN/xWjIEmr9MF6o9mpksUgh0fu+9ByFxzIU= +github.com/mattn/go-sqlite3 v1.14.8/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc= diff --git a/internal/jet/func_expression.go b/internal/jet/func_expression.go index 606e7e1..9a647e9 100644 --- a/internal/jet/func_expression.go +++ b/internal/jet/func_expression.go @@ -2,7 +2,7 @@ package jet // ROW is construct one table row from list of expressions. func ROW(expressions ...Expression) Expression { - return newFunc("ROW", expressions, nil) + return NewFunc("ROW", expressions, nil) } // ------------------ Mathematical functions ---------------// @@ -265,118 +265,118 @@ func OCTET_LENGTH(stringExpression StringExpression) IntegerExpression { // LOWER returns string expression in lower case func LOWER(stringExpression StringExpression) StringExpression { - return newStringFunc("LOWER", stringExpression) + return NewStringFunc("LOWER", stringExpression) } // UPPER returns string expression in upper case func UPPER(stringExpression StringExpression) StringExpression { - return newStringFunc("UPPER", stringExpression) + return NewStringFunc("UPPER", stringExpression) } // BTRIM removes the longest string consisting only of characters // in characters (a space by default) from the start and end of string func BTRIM(stringExpression StringExpression, trimChars ...StringExpression) StringExpression { if len(trimChars) > 0 { - return newStringFunc("BTRIM", stringExpression, trimChars[0]) + return NewStringFunc("BTRIM", stringExpression, trimChars[0]) } - return newStringFunc("BTRIM", stringExpression) + return NewStringFunc("BTRIM", stringExpression) } // LTRIM removes the longest string containing only characters // from characters (a space by default) from the start of string func LTRIM(str StringExpression, trimChars ...StringExpression) StringExpression { if len(trimChars) > 0 { - return newStringFunc("LTRIM", str, trimChars[0]) + return NewStringFunc("LTRIM", str, trimChars[0]) } - return newStringFunc("LTRIM", str) + return NewStringFunc("LTRIM", str) } // RTRIM removes the longest string containing only characters // from characters (a space by default) from the end of string func RTRIM(str StringExpression, trimChars ...StringExpression) StringExpression { if len(trimChars) > 0 { - return newStringFunc("RTRIM", str, trimChars[0]) + return NewStringFunc("RTRIM", str, trimChars[0]) } - return newStringFunc("RTRIM", str) + return NewStringFunc("RTRIM", str) } // CHR returns character with the given code. func CHR(integerExpression IntegerExpression) StringExpression { - return newStringFunc("CHR", integerExpression) + return NewStringFunc("CHR", integerExpression) } // CONCAT adds two or more expressions together func CONCAT(expressions ...Expression) StringExpression { - return newStringFunc("CONCAT", expressions...) + return NewStringFunc("CONCAT", expressions...) } // CONCAT_WS adds two or more expressions together with a separator. func CONCAT_WS(separator Expression, expressions ...Expression) StringExpression { - return newStringFunc("CONCAT_WS", append([]Expression{separator}, expressions...)...) + return NewStringFunc("CONCAT_WS", append([]Expression{separator}, expressions...)...) } // CONVERT converts string to dest_encoding. The original encoding is // specified by src_encoding. The string must be valid in this encoding. func CONVERT(str StringExpression, srcEncoding StringExpression, destEncoding StringExpression) StringExpression { - return newStringFunc("CONVERT", str, srcEncoding, destEncoding) + return NewStringFunc("CONVERT", str, srcEncoding, destEncoding) } // CONVERT_FROM converts string to the database encoding. The original // encoding is specified by src_encoding. The string must be valid in this encoding. func CONVERT_FROM(str StringExpression, srcEncoding StringExpression) StringExpression { - return newStringFunc("CONVERT_FROM", str, srcEncoding) + return NewStringFunc("CONVERT_FROM", str, srcEncoding) } // CONVERT_TO converts string to dest_encoding. func CONVERT_TO(str StringExpression, toEncoding StringExpression) StringExpression { - return newStringFunc("CONVERT_TO", str, toEncoding) + return NewStringFunc("CONVERT_TO", str, toEncoding) } // ENCODE encodes binary data into a textual representation. // Supported formats are: base64, hex, escape. escape converts zero bytes and // high-bit-set bytes to octal sequences (\nnn) and doubles backslashes. func ENCODE(data StringExpression, format StringExpression) StringExpression { - return newStringFunc("ENCODE", data, format) + return NewStringFunc("ENCODE", data, format) } // DECODE decodes binary data from textual representation in string. // Options for format are same as in encode. func DECODE(data StringExpression, format StringExpression) StringExpression { - return newStringFunc("DECODE", data, format) + return NewStringFunc("DECODE", data, format) } // FORMAT formats a number to a format like "#,###,###.##", rounded to a specified number of decimal places, then it returns the result as a string. func FORMAT(formatStr StringExpression, formatArgs ...Expression) StringExpression { args := []Expression{formatStr} args = append(args, formatArgs...) - return newStringFunc("FORMAT", args...) + return NewStringFunc("FORMAT", args...) } // INITCAP converts the first letter of each word to upper case // and the rest to lower case. Words are sequences of alphanumeric // characters separated by non-alphanumeric characters. func INITCAP(str StringExpression) StringExpression { - return newStringFunc("INITCAP", str) + return NewStringFunc("INITCAP", str) } // LEFT returns first n characters in the string. // When n is negative, return all but last |n| characters. func LEFT(str StringExpression, n IntegerExpression) StringExpression { - return newStringFunc("LEFT", str, n) + return NewStringFunc("LEFT", str, n) } // RIGHT returns last n characters in the string. // When n is negative, return all but first |n| characters. func RIGHT(str StringExpression, n IntegerExpression) StringExpression { - return newStringFunc("RIGHT", str, n) + return NewStringFunc("RIGHT", str, n) } // LENGTH returns number of characters in string with a given encoding func LENGTH(str StringExpression, encoding ...StringExpression) StringExpression { if len(encoding) > 0 { - return newStringFunc("LENGTH", str, encoding[0]) + return NewStringFunc("LENGTH", str, encoding[0]) } - return newStringFunc("LENGTH", str) + return NewStringFunc("LENGTH", str) } // LPAD fills up the string to length length by prepending the characters @@ -384,40 +384,40 @@ func LENGTH(str StringExpression, encoding ...StringExpression) StringExpression // then it is truncated (on the right). func LPAD(str StringExpression, length IntegerExpression, text ...StringExpression) StringExpression { if len(text) > 0 { - return newStringFunc("LPAD", str, length, text[0]) + return NewStringFunc("LPAD", str, length, text[0]) } - return newStringFunc("LPAD", str, length) + return NewStringFunc("LPAD", str, length) } // RPAD fills up the string to length length by appending the characters // fill (a space by default). If the string is already longer than length then it is truncated. func RPAD(str StringExpression, length IntegerExpression, text ...StringExpression) StringExpression { if len(text) > 0 { - return newStringFunc("RPAD", str, length, text[0]) + return NewStringFunc("RPAD", str, length, text[0]) } - return newStringFunc("RPAD", str, length) + return NewStringFunc("RPAD", str, length) } // MD5 calculates the MD5 hash of string, returning the result in hexadecimal func MD5(stringExpression StringExpression) StringExpression { - return newStringFunc("MD5", stringExpression) + return NewStringFunc("MD5", stringExpression) } // REPEAT repeats string the specified number of times func REPEAT(str StringExpression, n IntegerExpression) StringExpression { - return newStringFunc("REPEAT", str, n) + return NewStringFunc("REPEAT", str, n) } // REPLACE replaces all occurrences in string of substring from with substring to func REPLACE(text, from, to StringExpression) StringExpression { - return newStringFunc("REPLACE", text, from, to) + return NewStringFunc("REPLACE", text, from, to) } // REVERSE returns reversed string. func REVERSE(stringExpression StringExpression) StringExpression { - return newStringFunc("REVERSE", stringExpression) + return NewStringFunc("REVERSE", stringExpression) } // STRPOS returns location of specified substring (same as position(substring in string), @@ -429,22 +429,22 @@ func STRPOS(str, substring StringExpression) IntegerExpression { // SUBSTR extracts substring func SUBSTR(str StringExpression, from IntegerExpression, count ...IntegerExpression) StringExpression { if len(count) > 0 { - return newStringFunc("SUBSTR", str, from, count[0]) + return NewStringFunc("SUBSTR", str, from, count[0]) } - return newStringFunc("SUBSTR", str, from) + return NewStringFunc("SUBSTR", str, from) } // TO_ASCII convert string to ASCII from another encoding func TO_ASCII(str StringExpression, encoding ...StringExpression) StringExpression { if len(encoding) > 0 { - return newStringFunc("TO_ASCII", str, encoding[0]) + return NewStringFunc("TO_ASCII", str, encoding[0]) } - return newStringFunc("TO_ASCII", str) + return NewStringFunc("TO_ASCII", str) } // TO_HEX converts number to its equivalent hexadecimal representation func TO_HEX(number IntegerExpression) StringExpression { - return newStringFunc("TO_HEX", number) + return NewStringFunc("TO_HEX", number) } // REGEXP_LIKE Returns 1 if the string expr matches the regular expression specified by the pattern pat, 0 otherwise. @@ -460,12 +460,12 @@ func REGEXP_LIKE(stringExp StringExpression, pattern StringExpression, matchType // TO_CHAR converts expression to string with format func TO_CHAR(expression Expression, format StringExpression) StringExpression { - return newStringFunc("TO_CHAR", expression, format) + return NewStringFunc("TO_CHAR", expression, format) } // TO_DATE converts string to date using format func TO_DATE(dateStr, format StringExpression) DateExpression { - return newDateFunc("TO_DATE", dateStr, format) + return NewDateFunc("TO_DATE", dateStr, format) } // TO_NUMBER converts string to numeric using format @@ -482,7 +482,7 @@ func TO_TIMESTAMP(timestampzStr, format StringExpression) TimestampzExpression { // CURRENT_DATE returns current date func CURRENT_DATE() DateExpression { - dateFunc := newDateFunc("CURRENT_DATE") + dateFunc := NewDateFunc("CURRENT_DATE") dateFunc.noBrackets = true return dateFunc } @@ -522,9 +522,9 @@ func LOCALTIME(precision ...int) TimeExpression { var timeFunc *timeFunc if len(precision) > 0 { - timeFunc = newTimeFunc("LOCALTIME", FixedLiteral(precision[0])) + timeFunc = NewTimeFunc("LOCALTIME", FixedLiteral(precision[0])) } else { - timeFunc = newTimeFunc("LOCALTIME") + timeFunc = NewTimeFunc("LOCALTIME") } timeFunc.noBrackets = true @@ -558,26 +558,26 @@ func NOW() TimestampzExpression { func COALESCE(value Expression, values ...Expression) Expression { var allValues = []Expression{value} allValues = append(allValues, values...) - return newFunc("COALESCE", allValues, nil) + return NewFunc("COALESCE", allValues, nil) } // NULLIF function returns a null value if value1 equals value2; otherwise it returns value1. func NULLIF(value1, value2 Expression) Expression { - return newFunc("NULLIF", []Expression{value1, value2}, nil) + return NewFunc("NULLIF", []Expression{value1, value2}, nil) } // GREATEST selects the largest value from a list of expressions func GREATEST(value Expression, values ...Expression) Expression { var allValues = []Expression{value} allValues = append(allValues, values...) - return newFunc("GREATEST", allValues, nil) + return NewFunc("GREATEST", allValues, nil) } // LEAST selects the smallest value from a list of expressions func LEAST(value Expression, values ...Expression) Expression { var allValues = []Expression{value} allValues = append(allValues, values...) - return newFunc("LEAST", allValues, nil) + return NewFunc("LEAST", allValues, nil) } //--------------------------------------------------------------------// @@ -590,7 +590,8 @@ type funcExpressionImpl struct { noBrackets bool } -func newFunc(name string, expressions []Expression, parent Expression) *funcExpressionImpl { +// NewFunc creates new function with name and expressions parameters +func NewFunc(name string, expressions []Expression, parent Expression) *funcExpressionImpl { funcExp := &funcExpressionImpl{ name: name, expressions: expressions, @@ -608,7 +609,7 @@ func newFunc(name string, expressions []Expression, parent Expression) *funcExpr // NewFloatWindowFunc creates new float function with name and expressions func newWindowFunc(name string, expressions ...Expression) windowExpression { - newFun := newFunc(name, expressions, nil) + newFun := NewFunc(name, expressions, nil) windowExpr := newWindowExpression(newFun) newFun.ExpressionInterfaceImpl.Parent = windowExpr @@ -645,7 +646,7 @@ type boolFunc struct { func newBoolFunc(name string, expressions ...Expression) BoolExpression { boolFunc := &boolFunc{} - boolFunc.funcExpressionImpl = *newFunc(name, expressions, boolFunc) + boolFunc.funcExpressionImpl = *NewFunc(name, expressions, boolFunc) boolFunc.boolInterfaceImpl.parent = boolFunc boolFunc.ExpressionInterfaceImpl.Parent = boolFunc @@ -656,7 +657,7 @@ func newBoolFunc(name string, expressions ...Expression) BoolExpression { func newBoolWindowFunc(name string, expressions ...Expression) boolWindowExpression { boolFunc := &boolFunc{} - boolFunc.funcExpressionImpl = *newFunc(name, expressions, boolFunc) + boolFunc.funcExpressionImpl = *NewFunc(name, expressions, boolFunc) intWindowFunc := newBoolWindowExpression(boolFunc) boolFunc.boolInterfaceImpl.parent = intWindowFunc boolFunc.ExpressionInterfaceImpl.Parent = intWindowFunc @@ -673,7 +674,7 @@ type floatFunc struct { func NewFloatFunc(name string, expressions ...Expression) FloatExpression { floatFunc := &floatFunc{} - floatFunc.funcExpressionImpl = *newFunc(name, expressions, floatFunc) + floatFunc.funcExpressionImpl = *NewFunc(name, expressions, floatFunc) floatFunc.floatInterfaceImpl.parent = floatFunc return floatFunc @@ -683,7 +684,7 @@ func NewFloatFunc(name string, expressions ...Expression) FloatExpression { func NewFloatWindowFunc(name string, expressions ...Expression) floatWindowExpression { floatFunc := &floatFunc{} - floatFunc.funcExpressionImpl = *newFunc(name, expressions, floatFunc) + floatFunc.funcExpressionImpl = *NewFunc(name, expressions, floatFunc) floatWindowFunc := newFloatWindowExpression(floatFunc) floatFunc.floatInterfaceImpl.parent = floatWindowFunc floatFunc.ExpressionInterfaceImpl.Parent = floatWindowFunc @@ -699,7 +700,7 @@ type integerFunc struct { func newIntegerFunc(name string, expressions ...Expression) IntegerExpression { floatFunc := &integerFunc{} - floatFunc.funcExpressionImpl = *newFunc(name, expressions, floatFunc) + floatFunc.funcExpressionImpl = *NewFunc(name, expressions, floatFunc) floatFunc.integerInterfaceImpl.parent = floatFunc return floatFunc @@ -709,7 +710,7 @@ func newIntegerFunc(name string, expressions ...Expression) IntegerExpression { func newIntegerWindowFunc(name string, expressions ...Expression) integerWindowExpression { integerFunc := &integerFunc{} - integerFunc.funcExpressionImpl = *newFunc(name, expressions, integerFunc) + integerFunc.funcExpressionImpl = *NewFunc(name, expressions, integerFunc) intWindowFunc := newIntegerWindowExpression(integerFunc) integerFunc.integerInterfaceImpl.parent = intWindowFunc integerFunc.ExpressionInterfaceImpl.Parent = intWindowFunc @@ -722,10 +723,11 @@ type stringFunc struct { stringInterfaceImpl } -func newStringFunc(name string, expressions ...Expression) StringExpression { +// NewStringFunc creates new string function with name and expression parameters +func NewStringFunc(name string, expressions ...Expression) StringExpression { stringFunc := &stringFunc{} - stringFunc.funcExpressionImpl = *newFunc(name, expressions, stringFunc) + stringFunc.funcExpressionImpl = *NewFunc(name, expressions, stringFunc) stringFunc.stringInterfaceImpl.parent = stringFunc return stringFunc @@ -736,10 +738,11 @@ type dateFunc struct { dateInterfaceImpl } -func newDateFunc(name string, expressions ...Expression) *dateFunc { +// NewDateFunc creates new date function with name and expression parameters +func NewDateFunc(name string, expressions ...Expression) *dateFunc { dateFunc := &dateFunc{} - dateFunc.funcExpressionImpl = *newFunc(name, expressions, dateFunc) + dateFunc.funcExpressionImpl = *NewFunc(name, expressions, dateFunc) dateFunc.dateInterfaceImpl.parent = dateFunc return dateFunc @@ -750,10 +753,11 @@ type timeFunc struct { timeInterfaceImpl } -func newTimeFunc(name string, expressions ...Expression) *timeFunc { +// NewTimeFunc creates new time function with name and expression parameters +func NewTimeFunc(name string, expressions ...Expression) *timeFunc { timeFun := &timeFunc{} - timeFun.funcExpressionImpl = *newFunc(name, expressions, timeFun) + timeFun.funcExpressionImpl = *NewFunc(name, expressions, timeFun) timeFun.timeInterfaceImpl.parent = timeFun return timeFun @@ -767,7 +771,7 @@ type timezFunc struct { func newTimezFunc(name string, expressions ...Expression) *timezFunc { timezFun := &timezFunc{} - timezFun.funcExpressionImpl = *newFunc(name, expressions, timezFun) + timezFun.funcExpressionImpl = *NewFunc(name, expressions, timezFun) timezFun.timezInterfaceImpl.parent = timezFun return timezFun @@ -782,7 +786,7 @@ type timestampFunc struct { func NewTimestampFunc(name string, expressions ...Expression) *timestampFunc { timestampFunc := ×tampFunc{} - timestampFunc.funcExpressionImpl = *newFunc(name, expressions, timestampFunc) + timestampFunc.funcExpressionImpl = *NewFunc(name, expressions, timestampFunc) timestampFunc.timestampInterfaceImpl.parent = timestampFunc return timestampFunc @@ -796,7 +800,7 @@ type timestampzFunc struct { func newTimestampzFunc(name string, expressions ...Expression) *timestampzFunc { timestampzFunc := ×tampzFunc{} - timestampzFunc.funcExpressionImpl = *newFunc(name, expressions, timestampzFunc) + timestampzFunc.funcExpressionImpl = *NewFunc(name, expressions, timestampzFunc) timestampzFunc.timestampzInterfaceImpl.parent = timestampzFunc return timestampzFunc @@ -804,5 +808,5 @@ func newTimestampzFunc(name string, expressions ...Expression) *timestampzFunc { // Func can be used to call an custom or as of yet unsupported function in the database. func Func(name string, expressions ...Expression) Expression { - return newFunc(name, expressions, nil) + return NewFunc(name, expressions, nil) } diff --git a/internal/jet/interval.go b/internal/jet/interval.go index 5b371e1..debcb57 100644 --- a/internal/jet/interval.go +++ b/internal/jet/interval.go @@ -19,7 +19,7 @@ func (i *IsIntervalImpl) isInterval() {} // NewInterval creates new interval from serializer func NewInterval(s Serializer) *IntervalImpl { newInterval := &IntervalImpl{ - interval: s, + Value: s, } return newInterval @@ -27,11 +27,11 @@ func NewInterval(s Serializer) *IntervalImpl { // IntervalImpl is implementation of Interval type type IntervalImpl struct { - interval Serializer + Value Serializer IsIntervalImpl } func (i IntervalImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { out.WriteString("INTERVAL") - i.interval.serialize(statement, out, FallTrough(options)...) + i.Value.serialize(statement, out, FallTrough(options)...) } diff --git a/internal/testutils/test_utils.go b/internal/testutils/test_utils.go index 4158977..c1419aa 100644 --- a/internal/testutils/test_utils.go +++ b/internal/testutils/test_utils.go @@ -20,6 +20,11 @@ import ( "github.com/google/go-cmp/cmp" ) +// UnixTimeComparer will compare time equality while ignoring time zone +var UnixTimeComparer = cmp.Comparer(func(t1, t2 time.Time) bool { + return t1.Unix() == t2.Unix() +}) + // AssertExec assert statement execution for successful execution and number of rows affected func AssertExec(t *testing.T, stmt jet.Statement, db qrm.DB, rowsAffected ...int64) { res, err := stmt.Exec(db) @@ -113,7 +118,7 @@ func AssertDebugStatementSql(t *testing.T, query jet.Statement, expectedQuery st _, args := query.Sql() if len(expectedArgs) > 0 { - AssertDeepEqual(t, args, expectedArgs, "arguments are not equal") + AssertDeepEqual(t, args, expectedArgs) } debugSql := query.DebugSql() @@ -223,9 +228,9 @@ func AssertFileNamesEqual(t *testing.T, fileInfos []os.FileInfo, fileNames ...st } // AssertDeepEqual checks if actual and expected objects are deeply equal. -func AssertDeepEqual(t *testing.T, actual, expected interface{}, msg ...string) { - if !assert.True(t, cmp.Equal(actual, expected), msg) { - printDiff(actual, expected) +func AssertDeepEqual(t *testing.T, actual, expected interface{}, option ...cmp.Option) { + if !assert.True(t, cmp.Equal(actual, expected, option...)) { + printDiff(actual, expected, option...) t.FailNow() } } @@ -237,7 +242,8 @@ func assertQueryString(t *testing.T, actual, expected string) { } } -func printDiff(actual, expected interface{}) { +func printDiff(actual, expected interface{}, options ...cmp.Option) { + fmt.Println(cmp.Diff(actual, expected, options...)) fmt.Println("Actual: ") fmt.Println(actual) fmt.Println("Expected: ") diff --git a/qrm/internal/null_types.go b/qrm/internal/null_types.go index d09a712..ab75cf6 100644 --- a/qrm/internal/null_types.go +++ b/qrm/internal/null_types.go @@ -59,7 +59,7 @@ func (nt *NullTime) Scan(value interface{}) error { } // Some of the drivers (pgx, mysql) are not parsing all of the time formats(date, time with time zone,...) and are just forwarding string value. - // At this point we try to parse time using some of the predefined formats + // At this point we try to parse those values using some of the predefined formats nt.Time, nt.Valid = tryParseAsTime(value) if !nt.Valid { @@ -70,6 +70,7 @@ func (nt *NullTime) Scan(value interface{}) error { } var formats = []string{ + "2006-01-02 15:04:05-07:00", // sqlite "2006-01-02 15:04:05.999999", // go-sql-driver/mysql "15:04:05-07", // pgx "15:04:05.999999", // pgx @@ -84,6 +85,8 @@ func tryParseAsTime(value interface{}) (time.Time, bool) { timeStr = v case []byte: timeStr = string(v) + case int64: + return time.Unix(v, 0), true // sqlite default: return time.Time{}, false } diff --git a/sqlite/cast.go b/sqlite/cast.go new file mode 100644 index 0000000..517fb95 --- /dev/null +++ b/sqlite/cast.go @@ -0,0 +1,55 @@ +package sqlite + +import ( + "github.com/go-jet/jet/v2/internal/jet" +) + +type cast interface { + AS(castType string) Expression + AS_TEXT() StringExpression + AS_NUMERIC() FloatExpression + AS_INTEGER() IntegerExpression + AS_REAL() FloatExpression + AS_BLOB() StringExpression +} + +type castImpl struct { + jet.Cast +} + +// CAST function converts a expr (of any type) into latter specified datatype. +func CAST(expr Expression) cast { + castImpl := &castImpl{} + castImpl.Cast = jet.NewCastImpl(expr) + return castImpl +} + +// AS casts expressions to castType +func (c *castImpl) AS(castType string) Expression { + return c.Cast.AS(castType) +} + +// AS_TEXT cast expression to TEXT type +func (c *castImpl) AS_TEXT() StringExpression { + return StringExp(c.AS("TEXT")) +} + +// AS_NUMERIC cast expression to NUMERIC type +func (c *castImpl) AS_NUMERIC() FloatExpression { + return FloatExp(c.AS("NUMERIC")) +} + +// AS_INTEGER cast expression to INTEGER type +func (c *castImpl) AS_INTEGER() IntegerExpression { + return IntExp(c.AS("INTEGER")) +} + +// AS_REAL cast expression to REAL type +func (c *castImpl) AS_REAL() FloatExpression { + return FloatExp(c.AS("REAL")) +} + +// AS_BLOB cast expression to BLOB type +func (c *castImpl) AS_BLOB() StringExpression { + return StringExp(c.AS("BLOB")) +} diff --git a/sqlite/cast_test.go b/sqlite/cast_test.go new file mode 100644 index 0000000..c0ef914 --- /dev/null +++ b/sqlite/cast_test.go @@ -0,0 +1,14 @@ +package sqlite + +import ( + "testing" +) + +func TestCAST(t *testing.T) { + assertSerialize(t, CAST(Float(11.22)).AS("bigint"), `CAST(? AS bigint)`) + assertSerialize(t, CAST(Int(22)).AS_TEXT(), `CAST(? AS TEXT)`) + assertSerialize(t, CAST(Int(22)).AS_NUMERIC(), `CAST(? AS NUMERIC)`) + assertSerialize(t, CAST(String("22")).AS_INTEGER(), `CAST(? AS INTEGER)`) + assertSerialize(t, CAST(String("22.2")).AS_REAL(), `CAST(? AS REAL)`) + assertSerialize(t, CAST(String("blob")).AS_BLOB(), `CAST(? AS BLOB)`) +} diff --git a/sqlite/columns.go b/sqlite/columns.go new file mode 100644 index 0000000..88ae4f6 --- /dev/null +++ b/sqlite/columns.go @@ -0,0 +1,58 @@ +package sqlite + +import "github.com/go-jet/jet/v2/internal/jet" + +// Column is common column interface for all types of columns. +type Column = jet.ColumnExpression + +// ColumnList function returns list of columns that be used as projection or column list for UPDATE and INSERT statement. +type ColumnList = jet.ColumnList + +// ColumnBool is interface for SQL boolean columns. +type ColumnBool = jet.ColumnBool + +// BoolColumn creates named bool column. +var BoolColumn = jet.BoolColumn + +// ColumnString is interface for SQL text, character, character varying +// bytea, uuid columns and enums types. +type ColumnString = jet.ColumnString + +// StringColumn creates named string column. +var StringColumn = jet.StringColumn + +// ColumnInteger is interface for SQL smallint, integer, bigint columns. +type ColumnInteger = jet.ColumnInteger + +// IntegerColumn creates named integer column. +var IntegerColumn = jet.IntegerColumn + +// ColumnFloat is interface for SQL real, numeric, decimal or double precision column. +type ColumnFloat = jet.ColumnFloat + +// FloatColumn creates named float column. +var FloatColumn = jet.FloatColumn + +// ColumnTime is interface for SQL time column. +type ColumnTime = jet.ColumnTime + +// TimeColumn creates named time column +var TimeColumn = jet.TimeColumn + +// ColumnDate is interface of SQL date columns. +type ColumnDate = jet.ColumnDate + +// DateColumn creates named date column. +var DateColumn = jet.DateColumn + +// ColumnDateTime is interface of SQL timestamp columns. +type ColumnDateTime = jet.ColumnTimestamp + +// DateTimeColumn creates named timestamp column +var DateTimeColumn = jet.TimestampColumn + +//ColumnTimestamp is interface of SQL timestamp columns. +type ColumnTimestamp = jet.ColumnTimestamp + +// TimestampColumn creates named timestamp column +var TimestampColumn = jet.TimestampColumn diff --git a/sqlite/delete_statement.go b/sqlite/delete_statement.go new file mode 100644 index 0000000..dee85c0 --- /dev/null +++ b/sqlite/delete_statement.go @@ -0,0 +1,61 @@ +package sqlite + +import "github.com/go-jet/jet/v2/internal/jet" + +// DeleteStatement is interface for MySQL DELETE statement +type DeleteStatement interface { + Statement + + WHERE(expression BoolExpression) DeleteStatement + ORDER_BY(orderByClauses ...OrderByClause) DeleteStatement + LIMIT(limit int64) DeleteStatement + RETURNING(projections ...jet.Projection) DeleteStatement +} + +type deleteStatementImpl struct { + jet.SerializerStatement + + Delete jet.ClauseStatementBegin + Where jet.ClauseWhere + OrderBy jet.ClauseOrderBy + Limit jet.ClauseLimit + Returning jet.ClauseReturning +} + +func newDeleteStatement(table Table) DeleteStatement { + newDelete := &deleteStatementImpl{} + newDelete.SerializerStatement = jet.NewStatementImpl(Dialect, jet.DeleteStatementType, newDelete, + &newDelete.Delete, + &newDelete.Where, + &newDelete.OrderBy, + &newDelete.Limit, + &newDelete.Returning, + ) + + newDelete.Delete.Name = "DELETE FROM" + newDelete.Delete.Tables = append(newDelete.Delete.Tables, table) + newDelete.Where.Mandatory = true + newDelete.Limit.Count = -1 + + return newDelete +} + +func (d *deleteStatementImpl) WHERE(expression BoolExpression) DeleteStatement { + d.Where.Condition = expression + return d +} + +func (d *deleteStatementImpl) ORDER_BY(orderByClauses ...OrderByClause) DeleteStatement { + d.OrderBy.List = orderByClauses + return d +} + +func (d *deleteStatementImpl) LIMIT(limit int64) DeleteStatement { + d.Limit.Count = limit + return d +} + +func (d *deleteStatementImpl) RETURNING(projections ...jet.Projection) DeleteStatement { + d.Returning.ProjectionList = projections + return d +} diff --git a/sqlite/delete_statement_test.go b/sqlite/delete_statement_test.go new file mode 100644 index 0000000..6620c9f --- /dev/null +++ b/sqlite/delete_statement_test.go @@ -0,0 +1,26 @@ +package sqlite + +import ( + "testing" +) + +func TestDeleteUnconditionally(t *testing.T) { + assertStatementSqlErr(t, table1.DELETE(), `jet: WHERE clause not set`) + assertStatementSqlErr(t, table1.DELETE().WHERE(nil), `jet: WHERE clause not set`) +} + +func TestDeleteWithWhere(t *testing.T) { + assertStatementSql(t, table1.DELETE().WHERE(table1Col1.EQ(Int(1))), ` +DELETE FROM db.table1 +WHERE table1.col1 = ?; +`, int64(1)) +} + +func TestDeleteWithWhereOrderByLimit(t *testing.T) { + assertStatementSql(t, table1.DELETE().WHERE(table1Col1.EQ(Int(1))).ORDER_BY(table1Col1).LIMIT(1), ` +DELETE FROM db.table1 +WHERE table1.col1 = ? +ORDER BY table1.col1 +LIMIT ?; +`, int64(1), int64(1)) +} diff --git a/sqlite/dialect.go b/sqlite/dialect.go new file mode 100644 index 0000000..93e1d2f --- /dev/null +++ b/sqlite/dialect.go @@ -0,0 +1,225 @@ +package sqlite + +import ( + "github.com/go-jet/jet/v2/internal/jet" +) + +// Dialect is implementation of SQL Builder for SQLite databases. +var Dialect = newDialect() + +func newDialect() jet.Dialect { + operatorSerializeOverrides := map[string]jet.SerializeOverride{} + operatorSerializeOverrides["IS DISTINCT FROM"] = sqlite_IS_DISTINCT_FROM + operatorSerializeOverrides["IS NOT DISTINCT FROM"] = sqlite_IS_NOT_DISTINCT_FROM + operatorSerializeOverrides["#"] = sqliteBitXOR + + mySQLDialectParams := jet.DialectParams{ + Name: "SQLite", + PackageName: "sqlite", + OperatorSerializeOverrides: operatorSerializeOverrides, + AliasQuoteChar: '"', + IdentifierQuoteChar: '`', + ArgumentPlaceholder: func(int) string { + return "?" + }, + ReservedWords: reservedWords2, + } + + return jet.NewDialect(mySQLDialectParams) +} + +func sqliteBitXOR(expressions ...jet.Serializer) jet.SerializerFunc { + return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { + if len(expressions) < 2 { + panic("jet: invalid number of expressions for operator XOR") + } + + // (~(a&b))&(a|b) + a := expressions[0] + b := expressions[1] + + out.WriteString("(~(") + jet.Serialize(a, statement, out, options...) + out.WriteByte('&') + jet.Serialize(b, statement, out, options...) + out.WriteString("))&(") + jet.Serialize(a, statement, out, options...) + out.WriteByte('|') + jet.Serialize(b, statement, out, options...) + out.WriteByte(')') + } +} + +func sqlite_IS_NOT_DISTINCT_FROM(expressions ...jet.Serializer) jet.SerializerFunc { + return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { + if len(expressions) < 2 { + panic("jet: invalid number of expressions for operator") + } + + jet.Serialize(expressions[0], statement, out) + out.WriteString("IS") + jet.Serialize(expressions[1], statement, out) + } +} + +func sqlite_IS_DISTINCT_FROM(expressions ...jet.Serializer) jet.SerializerFunc { + return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { + if len(expressions) < 2 { + panic("jet: invalid number of expressions for operator") + } + + jet.Serialize(expressions[0], statement, out) + out.WriteString("IS NOT") + jet.Serialize(expressions[1], statement, out) + } +} + +var reservedWords2 = []string{ + "ABORT", + "ACTION", + "ADD", + "AFTER", + "ALL", + "ALTER", + "ALWAYS", + "ANALYZE", + "AND", + "AS", + "ASC", + "ATTACH", + "AUTOINCREMENT", + "BEFORE", + "BEGIN", + "BETWEEN", + "BY", + "CASCADE", + "CASE", + "CAST", + "CHECK", + "COLLATE", + "COLUMN", + "COMMIT", + "CONFLICT", + "CONSTRAINT", + "CREATE", + "CROSS", + "CURRENT", + "CURRENT_DATE", + "CURRENT_TIME", + "CURRENT_TIMESTAMP", + "DATABASE", + "DEFAULT", + "DEFERRABLE", + "DEFERRED", + "DELETE", + "DESC", + "DETACH", + "DISTINCT", + "DO", + "DROP", + "EACH", + "ELSE", + "END", + "ESCAPE", + "EXCEPT", + "EXCLUDE", + "EXCLUSIVE", + "EXISTS", + "EXPLAIN", + "FAIL", + "FILTER", + "FIRST", + "FOLLOWING", + "FOR", + "FOREIGN", + "FROM", + "FULL", + "GENERATED", + "GLOB", + "GROUP", + "GROUPS", + "HAVING", + "IF", + "IGNORE", + "IMMEDIATE", + "IN", + "INDEX", + "INDEXED", + "INITIALLY", + "INNER", + "INSERT", + "INSTEAD", + "INTERSECT", + "INTO", + "IS", + "ISNULL", + "JOIN", + "KEY", + "LAST", + "LEFT", + "LIKE", + "LIMIT", + "MATCH", + "MATERIALIZED", + "NATURAL", + "NO", + "NOT", + "NOTHING", + "NOTNULL", + "NULL", + "NULLS", + "OF", + "OFFSET", + "ON", + "OR", + "ORDER", + "OTHERS", + "OUTER", + "OVER", + "PARTITION", + "PLAN", + "PRAGMA", + "PRECEDING", + "PRIMARY", + "QUERY", + "RAISE", + "RANGE", + "RECURSIVE", + "REFERENCES", + "REGEXP", + "REINDEX", + "RELEASE", + "RENAME", + "REPLACE", + "RESTRICT", + "RETURNING", + "RIGHT", + "ROLLBACK", + "ROW", + "ROWS", + "SAVEPOINT", + "SELECT", + "SET", + "TABLE", + "TEMP", + "TEMPORARY", + "THEN", + "TIES", + "TO", + "TRANSACTION", + "TRIGGER", + "UNBOUNDED", + "UNION", + "UNIQUE", + "UPDATE", + "USING", + "VACUUM", + "VALUES", + "VIEW", + "VIRTUAL", + "WHEN", + "WHERE", + "WINDOW", + "WITH", + "WITHOUT", +} diff --git a/sqlite/dialect_test.go b/sqlite/dialect_test.go new file mode 100644 index 0000000..e90357f --- /dev/null +++ b/sqlite/dialect_test.go @@ -0,0 +1,59 @@ +package sqlite + +import ( + "testing" +) + +func TestBoolExpressionIS_DISTINCT_FROM(t *testing.T) { + assertSerialize(t, table1ColBool.IS_DISTINCT_FROM(table2ColBool), "(table1.col_bool IS NOT table2.col_bool)") + assertSerialize(t, table1ColBool.IS_DISTINCT_FROM(Bool(false)), "(table1.col_bool IS NOT ?)", false) +} + +func TestBoolExpressionIS_NOT_DISTINCT_FROM(t *testing.T) { + assertSerialize(t, table1ColBool.IS_NOT_DISTINCT_FROM(table2ColBool), "(table1.col_bool IS table2.col_bool)") + assertSerialize(t, table1ColBool.IS_NOT_DISTINCT_FROM(Bool(false)), "(table1.col_bool IS ?)", false) +} + +func TestBoolLiteral(t *testing.T) { + assertSerialize(t, Bool(true), "?", true) + assertSerialize(t, Bool(false), "?", false) +} + +func TestIntegerExpressionDIV(t *testing.T) { + assertSerialize(t, table1ColInt.DIV(table2ColInt), "(table1.col_int / table2.col_int)") + assertSerialize(t, table1ColInt.DIV(Int(11)), "(table1.col_int / ?)", int64(11)) +} + +func TestIntExpressionPOW(t *testing.T) { + assertSerialize(t, table1ColInt.POW(table2ColInt), "POW(table1.col_int, table2.col_int)") + assertSerialize(t, table1ColInt.POW(Int(11)), "POW(table1.col_int, ?)", int64(11)) +} + +func TestIntExpressionBIT_XOR(t *testing.T) { + assertSerialize(t, table1ColInt.BIT_XOR(table2ColInt), "((~(table1.col_int & table2.col_int))&(table1.col_int | table2.col_int))") + assertSerialize(t, table1ColInt.BIT_XOR(Int(11)), "((~(table1.col_int & ?))&(table1.col_int | ?))", int64(11), int64(11)) +} + +func TestExists(t *testing.T) { + assertSerialize(t, EXISTS( + table2. + SELECT(Int(1)). + WHERE(table1Col1.EQ(table2Col3)), + ), + `(EXISTS ( + SELECT ? + FROM db.table2 + WHERE table1.col1 = table2.col3 +))`, int64(1)) +} + +func TestString_REGEXP_LIKE_operator(t *testing.T) { + assertSerialize(t, table3StrCol.REGEXP_LIKE(table2ColStr), "(table3.col2 REGEXP table2.col_str)") + assertSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN")), "(table3.col2 REGEXP ?)", "JOHN") + +} + +func TestString_NOT_REGEXP_LIKE_operator(t *testing.T) { + assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(table2ColStr), "(table3.col2 NOT REGEXP table2.col_str)") + assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN")), "(table3.col2 NOT REGEXP ?)", "JOHN") +} diff --git a/sqlite/expressions.go b/sqlite/expressions.go new file mode 100644 index 0000000..d1d4737 --- /dev/null +++ b/sqlite/expressions.go @@ -0,0 +1,97 @@ +package sqlite + +import "github.com/go-jet/jet/v2/internal/jet" + +// Expression is common interface for all expressions. +// Can be Bool, Int, Float, String, Date, Time or Timestamp expressions. +type Expression = jet.Expression + +// BoolExpression interface +type BoolExpression = jet.BoolExpression + +// StringExpression interface +type StringExpression = jet.StringExpression + +// NumericExpression is shared interface for integer or real expression +type NumericExpression = jet.NumericExpression + +// IntegerExpression interface +type IntegerExpression = jet.IntegerExpression + +// FloatExpression interface +type FloatExpression = jet.FloatExpression + +// TimeExpression interface +type TimeExpression = jet.TimeExpression + +// DateExpression interface +type DateExpression = jet.DateExpression + +// DateTimeExpression interface +type DateTimeExpression = jet.TimestampExpression + +// TimestampExpression interface +type TimestampExpression = jet.TimestampExpression + +// BoolExp is bool expression wrapper around arbitrary expression. +// Allows go compiler to see any expression as bool expression. +// Does not add sql cast to generated sql builder output. +var BoolExp = jet.BoolExp + +// StringExp is string expression wrapper around arbitrary expression. +// Allows go compiler to see any expression as string expression. +// Does not add sql cast to generated sql builder output. +var StringExp = jet.StringExp + +// IntExp is int expression wrapper around arbitrary expression. +// Allows go compiler to see any expression as int expression. +// Does not add sql cast to generated sql builder output. +var IntExp = jet.IntExp + +// FloatExp is date expression wrapper around arbitrary expression. +// Allows go compiler to see any expression as float expression. +// Does not add sql cast to generated sql builder output. +var FloatExp = jet.FloatExp + +// TimeExp is time expression wrapper around arbitrary expression. +// Allows go compiler to see any expression as time expression. +// Does not add sql cast to generated sql builder output. +var TimeExp = jet.TimeExp + +// DateExp is date expression wrapper around arbitrary expression. +// Allows go compiler to see any expression as date expression. +// Does not add sql cast to generated sql builder output. +var DateExp = jet.DateExp + +// DateTimeExp is timestamp expression wrapper around arbitrary expression. +// Allows go compiler to see any expression as timestamp expression. +// Does not add sql cast to generated sql builder output. +var DateTimeExp = jet.TimestampExp + +// TimestampExp is timestamp expression wrapper around arbitrary expression. +// Allows go compiler to see any expression as timestamp expression. +// Does not add sql cast to generated sql builder output. +var TimestampExp = jet.TimestampExp + +// RawArgs is type used to pass optional arguments to Raw method +type RawArgs = map[string]interface{} + +// Raw can be used for any unsupported functions, operators or expressions. +// For example: Raw("current_database()") +// Raw helper methods for each of the sqlite types +var ( + Raw = jet.Raw + + RawInt = jet.RawInt + RawFloat = jet.RawFloat + RawString = jet.RawString + RawTime = jet.RawTime + RawTimestamp = jet.RawTimestamp + RawDate = jet.RawDate +) + +// Func can be used to call an custom or as of yet unsupported function in the database. +var Func = jet.Func + +// NewEnumValue creates new named enum value +var NewEnumValue = jet.NewEnumValue diff --git a/sqlite/expressions_test.go b/sqlite/expressions_test.go new file mode 100644 index 0000000..2c2bbef --- /dev/null +++ b/sqlite/expressions_test.go @@ -0,0 +1,52 @@ +package sqlite + +import ( + "github.com/stretchr/testify/require" + "testing" +) + +func TestRaw(t *testing.T) { + assertSerialize(t, Raw("current_database()"), "(current_database())") + assertDebugSerialize(t, Raw("current_database()"), "(current_database())") + + assertSerialize(t, Raw(":first_arg + table.colInt + :second_arg", RawArgs{":first_arg": 11, ":second_arg": 22}), + "(? + table.colInt + ?)", 11, 22) + assertDebugSerialize(t, Raw(":first_arg + table.colInt + :second_arg", RawArgs{":first_arg": 11, ":second_arg": 22}), + "(11 + table.colInt + 22)") + + assertSerialize(t, + Int(700).ADD(RawInt("#1 + table.colInt + #2", RawArgs{"#1": 11, "#2": 22})), + "(? + (? + table.colInt + ?))", + int64(700), 11, 22) + assertDebugSerialize(t, + Int(700).ADD(RawInt("#1 + table.colInt + #2", RawArgs{"#1": 11, "#2": 22})), + "(700 + (11 + table.colInt + 22))") +} + +func TestRawDuplicateArguments(t *testing.T) { + assertSerialize(t, Raw(":arg + table.colInt + :arg", RawArgs{":arg": 11}), + "(? + table.colInt + ?)", 11, 11) + + assertSerialize(t, Raw("#age + table.colInt + #year + #age + #year + 11", RawArgs{"#age": 11, "#year": 2000}), + "(? + table.colInt + ? + ? + ? + 11)", 11, 2000, 11, 2000) + + assertSerialize(t, Raw("#1 + all_types.integer + #2 + #1 + #2 + #3 + #4", + RawArgs{"#1": 11, "#2": 22, "#3": 33, "#4": 44}), + `(? + all_types.integer + ? + ? + ? + ? + ?)`, 11, 22, 11, 22, 33, 44) +} + +func TestRawInvalidArguments(t *testing.T) { + defer func() { + r := recover() + require.Equal(t, "jet: named argument 'first_arg' does not appear in raw query", r) + }() + + assertSerialize(t, Raw("table.colInt + :second_arg", RawArgs{"first_arg": 11}), "(table.colInt + ?)", 22) +} + +func TestRawType(t *testing.T) { + assertSerialize(t, RawFloat("table.colInt + &float", RawArgs{"&float": 11.22}).EQ(Float(3.14)), + "((table.colInt + ?) = ?)", 11.22, 3.14) + assertSerialize(t, RawString("table.colStr || str", RawArgs{"str": "doe"}).EQ(String("john doe")), + "((table.colStr || ?) = ?)", "doe", "john doe") +} diff --git a/sqlite/functions.go b/sqlite/functions.go new file mode 100644 index 0000000..2b70714 --- /dev/null +++ b/sqlite/functions.go @@ -0,0 +1,342 @@ +package sqlite + +import ( + "fmt" + "github.com/go-jet/jet/v2/internal/jet" + "time" +) + +// ROW is construct one table row from list of expressions. +func ROW(expressions ...Expression) Expression { + return jet.NewFunc("", expressions, nil) +} + +// ------------------ Mathematical functions ---------------// + +// ABSf calculates absolute value from float expression +var ABSf = jet.ABSf + +// ABSi calculates absolute value from int expression +var ABSi = jet.ABSi + +// POW calculates power of base with exponent +var POW = jet.POW + +// POWER calculates power of base with exponent +var POWER = jet.POWER + +// SQRT calculates square root of numeric expression +var SQRT = jet.SQRT + +// CBRT calculates cube root of numeric expression +func CBRT(number jet.NumericExpression) jet.FloatExpression { + return POWER(number, Float(1.0).DIV(Float(3.0))) +} + +// CEIL calculates ceil of float expression +var CEIL = jet.CEIL + +// FLOOR calculates floor of float expression +var FLOOR = jet.FLOOR + +// ROUND calculates round of a float expressions with optional precision +var ROUND = jet.ROUND + +// SIGN returns sign of float expression +var SIGN = jet.SIGN + +// TRUNC calculates trunc of float expression with precision +var TRUNC = TRUNCATE + +// TRUNCATE calculates trunc of float expression with precision +var TRUNCATE = func(floatExpression jet.FloatExpression, precision jet.IntegerExpression) jet.FloatExpression { + return jet.NewFloatFunc("TRUNCATE", floatExpression, precision) +} + +// LN calculates natural algorithm of float expression +var LN = jet.LN + +// LOG calculates logarithm of float expression +var LOG = jet.LOG + +// ----------------- Aggregate functions -------------------// + +// AVG is aggregate function used to calculate avg value from numeric expression +var AVG = jet.AVG + +// BIT_AND is aggregate function used to calculates the bitwise AND of all non-null input values, or null if none. +//var BIT_AND = jet.BIT_AND + +// BIT_OR is aggregate function used to calculates the bitwise OR of all non-null input values, or null if none. +//var BIT_OR = jet.BIT_OR + +// COUNT is aggregate function. Returns number of input rows for which the value of expression is not null. +var COUNT = jet.COUNT + +// MAX is aggregate function. Returns maximum value of expression across all input values +var MAX = jet.MAX + +// MAXi is aggregate function. Returns maximum value of int expression across all input values +var MAXi = jet.MAXi + +// MAXf is aggregate function. Returns maximum value of float expression across all input values +var MAXf = jet.MAXf + +// MIN is aggregate function. Returns minimum value of int expression across all input values +var MIN = jet.MIN + +// MINi is aggregate function. Returns minimum value of int expression across all input values +var MINi = jet.MINi + +// MINf is aggregate function. Returns minimum value of float expression across all input values +var MINf = jet.MINf + +// SUM is aggregate function. Returns sum of all expressions +var SUM = jet.SUM + +// SUMi is aggregate function. Returns sum of integer expression. +var SUMi = jet.SUMi + +// SUMf is aggregate function. Returns sum of float expression. +var SUMf = jet.SUMf + +// -------------------- Window functions -----------------------// + +// ROW_NUMBER returns number of the current row within its partition, counting from 1 +var ROW_NUMBER = jet.ROW_NUMBER + +// RANK of the current row with gaps; same as row_number of its first peer +var RANK = jet.RANK + +// DENSE_RANK returns rank of the current row without gaps; this function counts peer groups +var DENSE_RANK = jet.DENSE_RANK + +// PERCENT_RANK calculates relative rank of the current row: (rank - 1) / (total partition rows - 1) +var PERCENT_RANK = jet.PERCENT_RANK + +// CUME_DIST calculates cumulative distribution: (number of partition rows preceding or peer with current row) / total partition rows +var CUME_DIST = jet.CUME_DIST + +// NTILE returns integer ranging from 1 to the argument value, dividing the partition as equally as possible +var NTILE = jet.NTILE + +// LAG returns value evaluated at the row that is offset rows before the current row within the partition; +// if there is no such row, instead return default (which must be of the same type as value). +// Both offset and default are evaluated with respect to the current row. +// If omitted, offset defaults to 1 and default to null +var LAG = jet.LAG + +// LEAD returns value evaluated at the row that is offset rows after the current row within the partition; +// if there is no such row, instead return default (which must be of the same type as value). +// Both offset and default are evaluated with respect to the current row. +// If omitted, offset defaults to 1 and default to null +var LEAD = jet.LEAD + +// FIRST_VALUE returns value evaluated at the row that is the first row of the window frame +var FIRST_VALUE = jet.FIRST_VALUE + +// LAST_VALUE returns value evaluated at the row that is the last row of the window frame +var LAST_VALUE = jet.LAST_VALUE + +// NTH_VALUE returns value evaluated at the row that is the nth row of the window frame (counting from 1); null if no such row +var NTH_VALUE = jet.NTH_VALUE + +//--------------------- String functions ------------------// + +// BIT_LENGTH returns number of bits in string expression +//var BIT_LENGTH = jet.BIT_LENGTH +// +//// CHAR_LENGTH returns number of characters in string expression +//var CHAR_LENGTH = jet.CHAR_LENGTH +// +//// OCTET_LENGTH returns number of bytes in string expression +//var OCTET_LENGTH = jet.OCTET_LENGTH + +// LOWER returns string expression in lower case +var LOWER = jet.LOWER + +// UPPER returns string expression in upper case +var UPPER = jet.UPPER + +// LTRIM removes the longest string containing only characters +// from characters (a space by default) from the start of string +var LTRIM = jet.LTRIM + +// RTRIM removes the longest string containing only characters +// from characters (a space by default) from the end of string +var RTRIM = jet.RTRIM + +// CONCAT adds two or more expressions together +//var CONCAT = jet.CONCAT + +// CONCAT_WS adds two or more expressions together with a separator. +//var CONCAT_WS = jet.CONCAT_WS + +// FORMAT formats a number to a format like "#,###,###.##", rounded to a specified number of decimal places, then it returns the result as a string. +//var FORMAT = jet.FORMAT + +// LEFTSTR returns first n characters in the string. +// When n is negative, return all but last |n| characters. +//func LEFTSTR(str StringExpression, n IntegerExpression) StringExpression { +// return jet.NewStringFunc("LEFTSTR", str, n) +//} +// +//// RIGHT returns last n characters in the string. +//// When n is negative, return all but first |n| characters. +//func RIGHTSTR(str StringExpression, n IntegerExpression) StringExpression { +// return jet.NewStringFunc("RIGHTSTR", str, n) +//} + +// LENGTH returns number of characters in string with a given encoding +func LENGTH(str jet.StringExpression) jet.StringExpression { + return jet.LENGTH(str) +} + +// LPAD fills up the string to length length by prepending the characters +// fill (a space by default). If the string is already longer than length +// then it is truncated (on the right). +//func LPAD(str jet.StringExpression, length jet.IntegerExpression, text jet.StringExpression) jet.StringExpression { +// return jet.LPAD(str, length, text) +//} + +// RPAD fills up the string to length length by appending the characters +// fill (a space by default). If the string is already longer than length then it is truncated. +//func RPAD(str jet.StringExpression, length jet.IntegerExpression, text jet.StringExpression) jet.StringExpression { +// return jet.RPAD(str, length, text) +//} + +// MD5 calculates the MD5 hash of string, returning the result in hexadecimal +//var MD5 = jet.MD5 + +// REPEAT repeats string the specified number of times +//var REPEAT = jet.REPEAT + +// REPLACE replaces all occurrences in string of substring from with substring to +var REPLACE = jet.REPLACE + +// REVERSE returns reversed string. +var REVERSE = jet.REVERSE + +// SUBSTR extracts substring +var SUBSTR = jet.SUBSTR + +// REGEXP_LIKE Returns 1 if the string expr matches the regular expression specified by the pattern pat, 0 otherwise. +var REGEXP_LIKE = jet.REGEXP_LIKE + +//----------------- Date/Time Functions and Operators ------------// + +// CURRENT_DATE returns current date +var CURRENT_DATE = jet.CURRENT_DATE + +// CURRENT_TIME returns current time with time zone +func CURRENT_TIME() TimeExpression { + return TimeExp(jet.CURRENT_TIME()) +} + +// CURRENT_TIMESTAMP returns current timestamp with time zone +func CURRENT_TIMESTAMP() TimestampExpression { + return TimestampExp(jet.CURRENT_TIMESTAMP()) +} + +//// NOW returns current datetime +//func NOW() DateTimeExpression { +// //if len(fsp) > 0 { +// // return jet.NewTimestampFunc("NOW", jet.FixedLiteral(int64(fsp[0]))) +// //} +// //return jet.NewTimestampFunc("NOW") +// return DATETIME(jet.FixedLiteral("now")) +//} + +// time-value modifiers +var ( + YEARS = modifier("YEARS") + MONTHS = modifier("MONTHS") + DAYS = modifier("DAYS") + HOURS = modifier("HOURS") + MINUTES = modifier("MINUTES") + SECONDS = modifier("SECONDS") + + START_OF_YEAR = String("start of year") + START_OF_MONTH = String("start of month") + UNIXEPOCH = String("unixepoch") + LOCALTIME = String("localtime") + UTC = String("UTC") + + WEEKDAY = func(value int) Expression { + return String(fmt.Sprintf("WEEKDAY %d", value)) + } +) + +func modifier(modifierName string) func(value float64) Expression { + return func(value float64) Expression { + return String(fmt.Sprintf("%g %s", value, modifierName)) + } +} + +// DATE function creates new date from time-value and zero or more time modifiers +func DATE(timeValue interface{}, modifiers ...Expression) DateExpression { + exprList := getFuncExprList(timeValue, modifiers...) + + return jet.NewDateFunc("DATE", exprList...) +} + +// TIME function creates new time from time-value and zero or more time modifiers +func TIME(timeValue interface{}, modifiers ...Expression) TimeExpression { + exprList := getFuncExprList(timeValue, modifiers...) + + return jet.NewTimeFunc("TIME", exprList...) +} + +// DATETIME function creates new DateTime from time-value and zero or more time modifiers +func DATETIME(timeValue interface{}, modifiers ...Expression) DateTimeExpression { + exprList := getFuncExprList(timeValue, modifiers...) + + return jet.NewTimestampFunc("DATETIME", exprList...) +} + +// JULIANDAY returns the number of days since noon in Greenwich on November 24, 4714 B.C +func JULIANDAY(timeValue interface{}, modifiers ...Expression) FloatExpression { + exprList := getFuncExprList(timeValue, modifiers...) + return jet.NewFloatFunc("JULIANDAY", exprList...) +} + +// STRFTIME routine returns the date formatted according to the format string specified as the first argument. +func STRFTIME(format StringExpression, timeValue interface{}, modifiers ...Expression) StringExpression { + exprList := append([]Expression{format}, getFuncExprList(timeValue, modifiers...)...) + return jet.NewStringFunc("strftime", exprList...) +} + +func getFuncExprList(timeValue interface{}, modifiers ...Expression) []Expression { + return append([]Expression{getTimeValueExpression(timeValue)}, modifiers...) +} + +func getTimeValueExpression(timeValue interface{}) Expression { + switch t := timeValue.(type) { + case string: + return String(t) + case Expression: + return t + case time.Time, int64: + return jet.Literal(t) + } + + panic(fmt.Sprintf("jet: Invalid time value %T(%q)", timeValue, timeValue)) +} + +// TIMESTAMP return a datetime value based on the arguments: +func TIMESTAMP(str StringExpression) TimestampExpression { + return jet.NewTimestampFunc("TIMESTAMP", str) +} + +// UNIX_TIMESTAMP returns unix timestamp +func UNIX_TIMESTAMP(str StringExpression) TimestampExpression { + return jet.NewTimestampFunc("UNIX_TIMESTAMP", str) +} + +//----------- Comparison operators ---------------// + +// EXISTS checks for existence of the rows in subQuery +var EXISTS = jet.EXISTS + +// CASE create CASE operator with optional list of expressions +var CASE = jet.CASE diff --git a/sqlite/insert_statement.go b/sqlite/insert_statement.go new file mode 100644 index 0000000..3912cc3 --- /dev/null +++ b/sqlite/insert_statement.go @@ -0,0 +1,117 @@ +package sqlite + +import "github.com/go-jet/jet/v2/internal/jet" + +// InsertStatement is interface for SQL INSERT statements +type InsertStatement interface { + Statement + + VALUES(value interface{}, values ...interface{}) InsertStatement + MODEL(data interface{}) InsertStatement + MODELS(data interface{}) InsertStatement + QUERY(selectStatement SelectStatement) InsertStatement + DEFAULT_VALUES() InsertStatement + + ON_CONFLICT(indexExpressions ...jet.ColumnExpression) onConflict + RETURNING(projections ...Projection) InsertStatement +} + +func newInsertStatement(table Table, columns []jet.Column) InsertStatement { + newInsert := &insertStatementImpl{ + DefaultValues: jet.ClauseOptional{Name: "DEFAULT VALUES", InNewLine: true}, + } + + newInsert.SerializerStatement = jet.NewStatementImpl(Dialect, jet.InsertStatementType, newInsert, + &newInsert.Insert, + &newInsert.ValuesQuery, + &newInsert.OnDuplicateKey, + &newInsert.DefaultValues, + &newInsert.OnConflict, + &newInsert.Returning, + ) + + newInsert.Insert.Table = table + newInsert.Insert.Columns = columns + newInsert.ValuesQuery.SkipSelectWrap = true + + return newInsert +} + +type insertStatementImpl struct { + jet.SerializerStatement + + Insert jet.ClauseInsert + ValuesQuery jet.ClauseValuesQuery + OnDuplicateKey onDuplicateKeyUpdateClause + DefaultValues jet.ClauseOptional + OnConflict onConflictClause + Returning jet.ClauseReturning +} + +func (is *insertStatementImpl) VALUES(value interface{}, values ...interface{}) InsertStatement { + is.ValuesQuery.Rows = append(is.ValuesQuery.Rows, jet.UnwindRowFromValues(value, values)) + return is +} + +// MODEL will insert row of values, where value for each column is extracted from filed of structure data. +// If data is not struct or there is no field for every column selected, this method will panic. +func (is *insertStatementImpl) MODEL(data interface{}) InsertStatement { + is.ValuesQuery.Rows = append(is.ValuesQuery.Rows, jet.UnwindRowFromModel(is.Insert.GetColumns(), data)) + return is +} + +func (is *insertStatementImpl) MODELS(data interface{}) InsertStatement { + is.ValuesQuery.Rows = append(is.ValuesQuery.Rows, jet.UnwindRowsFromModels(is.Insert.GetColumns(), data)...) + return is +} + +func (is *insertStatementImpl) ON_DUPLICATE_KEY_UPDATE(assigments ...ColumnAssigment) InsertStatement { + is.OnDuplicateKey = assigments + return is +} + +func (is *insertStatementImpl) QUERY(selectStatement SelectStatement) InsertStatement { + is.ValuesQuery.Query = selectStatement + return is +} + +func (is *insertStatementImpl) DEFAULT_VALUES() InsertStatement { + is.DefaultValues.Show = true + return is +} + +func (is *insertStatementImpl) RETURNING(projections ...jet.Projection) InsertStatement { + is.Returning.ProjectionList = projections + return is +} + +type onDuplicateKeyUpdateClause []jet.ColumnAssigment + +// Serialize for SetClause +func (s onDuplicateKeyUpdateClause) Serialize(statementType jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { + if len(s) == 0 { + return + } + out.NewLine() + out.WriteString("ON DUPLICATE KEY UPDATE") + out.IncreaseIdent(24) + + for i, assigment := range s { + if i > 0 { + out.WriteString(",") + out.NewLine() + } + + jet.Serialize(assigment, statementType, out, jet.ShortName.WithFallTrough(options)...) + } + + out.DecreaseIdent(24) +} + +func (is *insertStatementImpl) ON_CONFLICT(indexExpressions ...jet.ColumnExpression) onConflict { + is.OnConflict = onConflictClause{ + insertStatement: is, + indexExpressions: indexExpressions, + } + return &is.OnConflict +} diff --git a/sqlite/insert_statement_test.go b/sqlite/insert_statement_test.go new file mode 100644 index 0000000..5bb639e --- /dev/null +++ b/sqlite/insert_statement_test.go @@ -0,0 +1,150 @@ +package sqlite + +import ( + "github.com/stretchr/testify/require" + "testing" + "time" +) + +func TestInvalidInsert(t *testing.T) { + assertStatementSqlErr(t, table1.INSERT(nil).VALUES(1), "jet: nil column in columns list") +} + +func TestInsertNilValue(t *testing.T) { + assertStatementSql(t, table1.INSERT(table1Col1).VALUES(nil), ` +INSERT INTO db.table1 (col1) +VALUES (?); +`, nil) +} + +func TestInsertSingleValue(t *testing.T) { + assertStatementSql(t, table1.INSERT(table1Col1).VALUES(1), ` +INSERT INTO db.table1 (col1) +VALUES (?); +`, int(1)) +} + +func TestInsertWithColumnList(t *testing.T) { + columnList := ColumnList{table3ColInt} + + columnList = append(columnList, table3StrCol) + + assertStatementSql(t, table3.INSERT(columnList).VALUES(1, 3), ` +INSERT INTO db.table3 (col_int, col2) +VALUES (?, ?); +`, 1, 3) +} + +func TestInsertDate(t *testing.T) { + date := time.Date(1999, 1, 2, 3, 4, 5, 0, time.UTC) + + assertStatementSql(t, table1.INSERT(table1ColTimestamp).VALUES(date), ` +INSERT INTO db.table1 (col_timestamp) +VALUES (?); +`, date) +} + +func TestInsertMultipleValues(t *testing.T) { + assertStatementSql(t, table1.INSERT(table1Col1, table1ColFloat, table1Col3).VALUES(1, 2, 3), ` +INSERT INTO db.table1 (col1, col_float, col3) +VALUES (?, ?, ?); +`, 1, 2, 3) +} + +func TestInsertMultipleRows(t *testing.T) { + stmt := table1.INSERT(table1Col1, table1ColFloat). + VALUES(1, 2). + VALUES(11, 22). + VALUES(111, 222) + + assertStatementSql(t, stmt, ` +INSERT INTO db.table1 (col1, col_float) +VALUES (?, ?), + (?, ?), + (?, ?); +`, 1, 2, 11, 22, 111, 222) +} + +func TestInsertValuesFromModel(t *testing.T) { + type Table1Model struct { + Col1 *int + ColFloat float64 + } + + one := 1 + + toInsert := Table1Model{ + Col1: &one, + ColFloat: 1.11, + } + + stmt := table1.INSERT(table1Col1, table1ColFloat). + MODEL(toInsert). + MODEL(&toInsert) + + expectedSQL := ` +INSERT INTO db.table1 (col1, col_float) +VALUES (?, ?), + (?, ?); +` + + assertStatementSql(t, stmt, expectedSQL, int(1), float64(1.11), int(1), float64(1.11)) +} + +func TestInsertValuesFromModelColumnMismatch(t *testing.T) { + defer func() { + r := recover() + require.Equal(t, r, "missing struct field for column : col1") + }() + type Table1Model struct { + Col1Prim int + Col2 string + } + + newData := Table1Model{ + Col1Prim: 1, + Col2: "one", + } + + table1. + INSERT(table1Col1, table1ColFloat). + MODEL(newData) +} + +func TestInsertFromNonStructModel(t *testing.T) { + + defer func() { + r := recover() + require.Equal(t, r, "jet: data has to be a struct") + }() + + table2.INSERT(table2ColInt).MODEL([]int{}) +} + +func TestInsert_ON_CONFLICT(t *testing.T) { + stmt := table1.INSERT(table1Col1, table1ColBool). + VALUES("one", "two"). + VALUES("1", "2"). + VALUES("theta", "beta"). + ON_CONFLICT(table1ColBool).WHERE(table1ColBool.IS_NOT_FALSE()).DO_UPDATE( + SET(table1ColBool.SET(Bool(true)), + table2ColInt.SET(Int(1)), + ColumnList{table1Col1, table1ColBool}.SET(ROW(Int(2), String("two"))), + ).WHERE(table1Col1.GT(Int(2))), + ). + RETURNING(table1Col1, table1ColBool) + + assertStatementSql(t, stmt, ` +INSERT INTO db.table1 (col1, col_bool) +VALUES (?, ?), + (?, ?), + (?, ?) +ON CONFLICT (col_bool) WHERE col_bool IS NOT FALSE DO UPDATE + SET col_bool = ?, + col_int = ?, + (col1, col_bool) = (?, ?) + WHERE table1.col1 > ? +RETURNING table1.col1 AS "table1.col1", + table1.col_bool AS "table1.col_bool"; +`) +} diff --git a/sqlite/literal.go b/sqlite/literal.go new file mode 100644 index 0000000..2df5dd7 --- /dev/null +++ b/sqlite/literal.go @@ -0,0 +1,70 @@ +package sqlite + +import ( + "github.com/go-jet/jet/v2/internal/jet" + "time" +) + +// Keywords +var ( + STAR = jet.STAR + NULL = jet.NULL +) + +// Bool creates new bool literal expression +var Bool = jet.Bool + +// Int is constructor for 64 bit signed integer expressions literals. +var Int = jet.Int + +// Int8 is constructor for 8 bit signed integer expressions literals. +var Int8 = jet.Int8 + +// Int16 is constructor for 16 bit signed integer expressions literals. +var Int16 = jet.Int16 + +// Int32 is constructor for 32 bit signed integer expressions literals. +var Int32 = jet.Int32 + +// Int64 is constructor for 64 bit signed integer expressions literals. +var Int64 = jet.Int + +// Uint8 is constructor for 8 bit unsigned integer expressions literals. +var Uint8 = jet.Uint8 + +// Uint16 is constructor for 16 bit unsigned integer expressions literals. +var Uint16 = jet.Uint16 + +// Uint32 is constructor for 32 bit unsigned integer expressions literals. +var Uint32 = jet.Uint32 + +// Uint64 is constructor for 64 bit unsigned integer expressions literals. +var Uint64 = jet.Uint64 + +// Float creates new float literal expression from float64 value +var Float = jet.Float + +// Decimal creates new float literal expression from string value +var Decimal = jet.Decimal + +// String creates new string literal expression +var String = jet.String + +// UUID is a helper function to create string literal expression from uuid object +// value can be any uuid type with a String method +var UUID = jet.UUID + +// Date creates new date literal expression +func Date(year int, month time.Month, day int) DateExpression { + return DATE(jet.Date(year, month, day)) +} + +// Time creates new time literal expression +func Time(hour, minute, second int, nanoseconds ...time.Duration) TimeExpression { + return TIME(jet.Time(hour, minute, second, nanoseconds...)) +} + +// DateTime creates new datetime(timestamp) literal expression +func DateTime(year int, month time.Month, day, hour, minute, second int, nanoseconds ...time.Duration) DateTimeExpression { + return DATETIME(jet.Timestamp(year, month, day, hour, minute, second, nanoseconds...)) +} diff --git a/sqlite/literal_test.go b/sqlite/literal_test.go new file mode 100644 index 0000000..8a40931 --- /dev/null +++ b/sqlite/literal_test.go @@ -0,0 +1,80 @@ +package sqlite + +import ( + "math" + "testing" + "time" +) + +func TestBool(t *testing.T) { + assertSerialize(t, Bool(false), `?`, false) +} + +func TestInt(t *testing.T) { + assertSerialize(t, Int(11), `?`, int64(11)) +} + +func TestInt8(t *testing.T) { + val := int8(math.MinInt8) + assertSerialize(t, Int8(val), `?`, val) +} + +func TestInt16(t *testing.T) { + val := int16(math.MinInt16) + assertSerialize(t, Int16(val), `?`, val) +} + +func TestInt32(t *testing.T) { + val := int32(math.MinInt32) + assertSerialize(t, Int32(val), `?`, val) +} + +func TestInt64(t *testing.T) { + val := int64(math.MinInt64) + assertSerialize(t, Int64(val), `?`, val) +} + +func TestUint8(t *testing.T) { + val := uint8(math.MaxUint8) + assertSerialize(t, Uint8(val), `?`, val) +} + +func TestUint16(t *testing.T) { + val := uint16(math.MaxUint16) + assertSerialize(t, Uint16(val), `?`, val) +} + +func TestUint32(t *testing.T) { + val := uint32(math.MaxUint32) + assertSerialize(t, Uint32(val), `?`, val) +} + +func TestUint64(t *testing.T) { + val := uint64(math.MaxUint64) + assertSerialize(t, Uint64(val), `?`, val) +} + +func TestFloat(t *testing.T) { + assertSerialize(t, Float(12.34), `?`, float64(12.34)) +} + +func TestString(t *testing.T) { + assertSerialize(t, String("Some text"), `?`, "Some text") +} + +var testTime = time.Now() + +func TestDate(t *testing.T) { + assertSerialize(t, Date(2014, time.January, 2), "DATE(?)", "2014-01-02") + assertSerialize(t, DATE(testTime), "DATE(?)", testTime) +} + +func TestTime(t *testing.T) { + assertSerialize(t, Time(10, 15, 30), `TIME(?)`, "10:15:30") + assertSerialize(t, TIME(testTime), "TIME(?)", testTime) +} + +func TestDateTime(t *testing.T) { + assertSerialize(t, DateTime(2010, time.March, 30, 10, 15, 30), `DATETIME(?)`, "2010-03-30 10:15:30") + assertSerialize(t, DATETIME(testTime), `DATETIME(?)`, testTime) +} diff --git a/sqlite/on_conflict_clause.go b/sqlite/on_conflict_clause.go new file mode 100644 index 0000000..d131b9e --- /dev/null +++ b/sqlite/on_conflict_clause.go @@ -0,0 +1,84 @@ +package sqlite + +import ( + "github.com/go-jet/jet/v2/internal/jet" +) + +type onConflict interface { + WHERE(indexPredicate BoolExpression) conflictTarget + conflictTarget +} + +type conflictTarget interface { + DO_NOTHING() InsertStatement + DO_UPDATE(action conflictAction) InsertStatement +} + +type onConflictClause struct { + insertStatement InsertStatement + indexExpressions []jet.ColumnExpression + whereClause jet.ClauseWhere + do jet.Serializer +} + +func (o *onConflictClause) WHERE(indexPredicate BoolExpression) conflictTarget { + o.whereClause.Condition = indexPredicate + return o +} + +func (o *onConflictClause) DO_NOTHING() InsertStatement { + o.do = jet.Keyword("DO NOTHING") + return o.insertStatement +} + +func (o *onConflictClause) DO_UPDATE(action conflictAction) InsertStatement { + o.do = action + return o.insertStatement +} + +func (o *onConflictClause) Serialize(statementType jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { + if len(o.indexExpressions) == 0 && o.do == nil { + return + } + + out.NewLine() + out.WriteString("ON CONFLICT") + if len(o.indexExpressions) > 0 { + out.WriteString("(") + jet.SerializeColumnExpressionNames(o.indexExpressions, statementType, out, jet.ShortName) + out.WriteString(")") + } + + o.whereClause.Serialize(statementType, out, jet.SkipNewLine, jet.ShortName) + + out.IncreaseIdent(7) + jet.Serialize(o.do, statementType, out) + out.DecreaseIdent(7) +} + +type conflictAction interface { + jet.Serializer + WHERE(condition BoolExpression) conflictAction +} + +// SET creates conflict action for ON_CONFLICT clause +func SET(assigments ...ColumnAssigment) conflictAction { + conflictAction := updateConflictActionImpl{} + conflictAction.doUpdate = jet.KeywordClause{Keyword: "DO UPDATE"} + conflictAction.Serializer = jet.NewSerializerClauseImpl(&conflictAction.doUpdate, &conflictAction.set, &conflictAction.where) + conflictAction.set = assigments + return &conflictAction +} + +type updateConflictActionImpl struct { + jet.Serializer + + doUpdate jet.KeywordClause + set jet.SetClauseNew + where jet.ClauseWhere +} + +func (u *updateConflictActionImpl) WHERE(condition BoolExpression) conflictAction { + u.where.Condition = condition + return u +} diff --git a/sqlite/operators.go b/sqlite/operators.go new file mode 100644 index 0000000..8ebecbf --- /dev/null +++ b/sqlite/operators.go @@ -0,0 +1,9 @@ +package sqlite + +import "github.com/go-jet/jet/v2/internal/jet" + +// NOT returns negation of bool expression result +var NOT = jet.NOT + +// BIT_NOT inverts every bit in integer expression result +var BIT_NOT = jet.BIT_NOT diff --git a/sqlite/select_statement.go b/sqlite/select_statement.go new file mode 100644 index 0000000..4406dcd --- /dev/null +++ b/sqlite/select_statement.go @@ -0,0 +1,186 @@ +package sqlite + +import ( + "github.com/go-jet/jet/v2/internal/jet" +) + +// RowLock is interface for SELECT statement row lock types +type RowLock = jet.RowLock + +// Row lock types +var ( + UPDATE = jet.NewRowLock("UPDATE") + SHARE = jet.NewRowLock("SHARE") +) + +// Window function clauses +var ( + PARTITION_BY = jet.PARTITION_BY + ORDER_BY = jet.ORDER_BY + UNBOUNDED = jet.UNBOUNDED + CURRENT_ROW = jet.CURRENT_ROW +) + +// PRECEDING window frame clause +func PRECEDING(offset interface{}) jet.FrameExtent { + return jet.PRECEDING(toJetFrameOffset(offset)) +} + +// FOLLOWING window frame clause +func FOLLOWING(offset interface{}) jet.FrameExtent { + return jet.FOLLOWING(toJetFrameOffset(offset)) +} + +// Window is used to specify window reference from WINDOW clause +var Window = jet.WindowName + +// SelectStatement is interface for MySQL SELECT statement +type SelectStatement interface { + Statement + jet.HasProjections + Expression + + DISTINCT() SelectStatement + FROM(tables ...ReadableTable) SelectStatement + WHERE(expression BoolExpression) SelectStatement + GROUP_BY(groupByClauses ...GroupByClause) SelectStatement + HAVING(boolExpression BoolExpression) SelectStatement + WINDOW(name string) windowExpand + ORDER_BY(orderByClauses ...OrderByClause) SelectStatement + LIMIT(limit int64) SelectStatement + OFFSET(offset int64) SelectStatement + FOR(lock RowLock) SelectStatement + LOCK_IN_SHARE_MODE() SelectStatement + + UNION(rhs SelectStatement) setStatement + UNION_ALL(rhs SelectStatement) setStatement + + AsTable(alias string) SelectTable +} + +//SELECT creates new SelectStatement with list of projections +func SELECT(projection Projection, projections ...Projection) SelectStatement { + return newSelectStatement(nil, append([]Projection{projection}, projections...)) +} + +func newSelectStatement(table ReadableTable, projections []Projection) SelectStatement { + newSelect := &selectStatementImpl{} + newSelect.ExpressionStatement = jet.NewExpressionStatementImpl(Dialect, jet.SelectStatementType, newSelect, &newSelect.Select, + &newSelect.From, &newSelect.Where, &newSelect.GroupBy, &newSelect.Having, &newSelect.Window, &newSelect.OrderBy, + &newSelect.Limit, &newSelect.Offset, &newSelect.For, &newSelect.ShareLock) + + newSelect.Select.ProjectionList = projections + if table != nil { + newSelect.From.Tables = []jet.Serializer{table} + } + newSelect.Limit.Count = -1 + newSelect.Offset.Count = -1 + newSelect.ShareLock.Name = "LOCK IN SHARE MODE" + newSelect.ShareLock.InNewLine = true + + newSelect.setOperatorsImpl.parent = newSelect + + return newSelect +} + +type selectStatementImpl struct { + jet.ExpressionStatement + setOperatorsImpl + + Select jet.ClauseSelect + From jet.ClauseFrom + Where jet.ClauseWhere + GroupBy jet.ClauseGroupBy + Having jet.ClauseHaving + Window jet.ClauseWindow + OrderBy jet.ClauseOrderBy + Limit jet.ClauseLimit + Offset jet.ClauseOffset + For jet.ClauseFor + ShareLock jet.ClauseOptional +} + +func (s *selectStatementImpl) DISTINCT() SelectStatement { + s.Select.Distinct = true + return s +} + +func (s *selectStatementImpl) FROM(tables ...ReadableTable) SelectStatement { + s.From.Tables = nil + for _, table := range tables { + s.From.Tables = append(s.From.Tables, table) + } + return s +} + +func (s *selectStatementImpl) WHERE(condition BoolExpression) SelectStatement { + s.Where.Condition = condition + return s +} + +func (s *selectStatementImpl) GROUP_BY(groupByClauses ...GroupByClause) SelectStatement { + s.GroupBy.List = groupByClauses + return s +} + +func (s *selectStatementImpl) HAVING(boolExpression BoolExpression) SelectStatement { + s.Having.Condition = boolExpression + return s +} + +func (s *selectStatementImpl) WINDOW(name string) windowExpand { + s.Window.Definitions = append(s.Window.Definitions, jet.WindowDefinition{Name: name}) + return windowExpand{selectStatement: s} +} + +func (s *selectStatementImpl) ORDER_BY(orderByClauses ...OrderByClause) SelectStatement { + s.OrderBy.List = orderByClauses + return s +} + +func (s *selectStatementImpl) LIMIT(limit int64) SelectStatement { + s.Limit.Count = limit + return s +} + +func (s *selectStatementImpl) OFFSET(offset int64) SelectStatement { + s.Offset.Count = offset + return s +} + +func (s *selectStatementImpl) FOR(lock RowLock) SelectStatement { + s.For.Lock = lock + return s +} + +func (s *selectStatementImpl) LOCK_IN_SHARE_MODE() SelectStatement { + s.ShareLock.Show = true + return s +} + +func (s *selectStatementImpl) AsTable(alias string) SelectTable { + return newSelectTable(s, alias) +} + +//----------------------------------------------------- + +type windowExpand struct { + selectStatement *selectStatementImpl +} + +func (w windowExpand) AS(window ...jet.Window) SelectStatement { + if len(window) == 0 { + return w.selectStatement + } + windowsDefinition := w.selectStatement.Window.Definitions + windowsDefinition[len(windowsDefinition)-1].Window = window[0] + return w.selectStatement +} + +func toJetFrameOffset(offset interface{}) jet.Serializer { + if offset == UNBOUNDED { + return jet.UNBOUNDED + } + + return jet.FixedLiteral(offset) +} diff --git a/sqlite/select_statement_test.go b/sqlite/select_statement_test.go new file mode 100644 index 0000000..0ba76f0 --- /dev/null +++ b/sqlite/select_statement_test.go @@ -0,0 +1,156 @@ +package sqlite + +import ( + "github.com/go-jet/jet/v2/internal/testutils" + "testing" +) + +func TestInvalidSelect(t *testing.T) { + assertStatementSqlErr(t, SELECT(nil), "jet: Projection is nil") +} + +func TestSelectColumnList(t *testing.T) { + columnList := ColumnList{table2ColInt, table2ColFloat, table3ColInt} + + assertStatementSql(t, SELECT(columnList).FROM(table2), ` +SELECT table2.col_int AS "table2.col_int", + table2.col_float AS "table2.col_float", + table3.col_int AS "table3.col_int" +FROM db.table2; +`) +} + +func TestSelectLiterals(t *testing.T) { + assertStatementSql(t, SELECT(Int(1), Float(2.2), Bool(false)).FROM(table1), ` +SELECT ?, + ?, + ? +FROM db.table1; +`, int64(1), 2.2, false) +} + +func TestSelectDistinct(t *testing.T) { + assertStatementSql(t, SELECT(table1ColBool).DISTINCT().FROM(table1), ` +SELECT DISTINCT table1.col_bool AS "table1.col_bool" +FROM db.table1; +`) +} + +func TestSelectFrom(t *testing.T) { + assertStatementSql(t, SELECT(table1ColInt, table2ColFloat).FROM(table1), ` +SELECT table1.col_int AS "table1.col_int", + table2.col_float AS "table2.col_float" +FROM db.table1; +`) + assertStatementSql(t, SELECT(table1ColInt, table2ColFloat).FROM(table1.INNER_JOIN(table2, table1ColInt.EQ(table2ColInt))), ` +SELECT table1.col_int AS "table1.col_int", + table2.col_float AS "table2.col_float" +FROM db.table1 + INNER JOIN db.table2 ON (table1.col_int = table2.col_int); +`) + assertStatementSql(t, table1.INNER_JOIN(table2, table1ColInt.EQ(table2ColInt)).SELECT(table1ColInt, table2ColFloat), ` +SELECT table1.col_int AS "table1.col_int", + table2.col_float AS "table2.col_float" +FROM db.table1 + INNER JOIN db.table2 ON (table1.col_int = table2.col_int); +`) +} + +func TestSelectWhere(t *testing.T) { + assertStatementSql(t, SELECT(table1ColInt).FROM(table1).WHERE(Bool(true)), ` +SELECT table1.col_int AS "table1.col_int" +FROM db.table1 +WHERE ?; +`, true) + assertStatementSql(t, SELECT(table1ColInt).FROM(table1).WHERE(table1ColInt.GT_EQ(Int(10))), ` +SELECT table1.col_int AS "table1.col_int" +FROM db.table1 +WHERE table1.col_int >= ?; +`, int64(10)) +} + +func TestSelectGroupBy(t *testing.T) { + assertStatementSql(t, SELECT(table2ColInt).FROM(table2).GROUP_BY(table2ColFloat), ` +SELECT table2.col_int AS "table2.col_int" +FROM db.table2 +GROUP BY table2.col_float; +`) +} + +func TestSelectHaving(t *testing.T) { + assertStatementSql(t, SELECT(table3ColInt).FROM(table3).HAVING(table1ColBool.EQ(Bool(true))), ` +SELECT table3.col_int AS "table3.col_int" +FROM db.table3 +HAVING table1.col_bool = ?; +`, true) +} + +func TestSelectOrderBy(t *testing.T) { + assertStatementSql(t, SELECT(table2ColFloat).FROM(table2).ORDER_BY(table2ColInt.DESC()), ` +SELECT table2.col_float AS "table2.col_float" +FROM db.table2 +ORDER BY table2.col_int DESC; +`) + assertStatementSql(t, SELECT(table2ColFloat).FROM(table2).ORDER_BY(table2ColInt.DESC(), table2ColInt.ASC()), ` +SELECT table2.col_float AS "table2.col_float" +FROM db.table2 +ORDER BY table2.col_int DESC, table2.col_int ASC; +`) +} + +func TestSelectLimitOffset(t *testing.T) { + assertStatementSql(t, SELECT(table2ColInt).FROM(table2).LIMIT(10), ` +SELECT table2.col_int AS "table2.col_int" +FROM db.table2 +LIMIT ?; +`, int64(10)) + assertStatementSql(t, SELECT(table2ColInt).FROM(table2).LIMIT(10).OFFSET(2), ` +SELECT table2.col_int AS "table2.col_int" +FROM db.table2 +LIMIT ? +OFFSET ?; +`, int64(10), int64(2)) +} + +func TestSelectLock(t *testing.T) { + testutils.AssertStatementSql(t, SELECT(table1ColBool).FROM(table1).FOR(UPDATE()), ` +SELECT table1.col_bool AS "table1.col_bool" +FROM db.table1 +FOR UPDATE; +`) + testutils.AssertStatementSql(t, SELECT(table1ColBool).FROM(table1).FOR(SHARE().NOWAIT()), ` +SELECT table1.col_bool AS "table1.col_bool" +FROM db.table1 +FOR SHARE NOWAIT; +`) +} + +func TestSelect_LOCK_IN_SHARE_MODE(t *testing.T) { + testutils.AssertStatementSql(t, SELECT(table1ColBool).FROM(table1).LOCK_IN_SHARE_MODE(), ` +SELECT table1.col_bool AS "table1.col_bool" +FROM db.table1 +LOCK IN SHARE MODE; +`) +} + +func TestSelect_NOT_EXISTS(t *testing.T) { + testutils.AssertStatementSql(t, + SELECT(table1ColInt). + FROM(table1). + WHERE( + NOT(EXISTS( + SELECT(table2ColInt). + FROM(table2). + WHERE( + table1ColInt.EQ(table2ColInt), + ), + ))), ` +SELECT table1.col_int AS "table1.col_int" +FROM db.table1 +WHERE (NOT (EXISTS ( + SELECT table2.col_int AS "table2.col_int" + FROM db.table2 + WHERE table1.col_int = table2.col_int + ))); +`) +} diff --git a/sqlite/select_table.go b/sqlite/select_table.go new file mode 100644 index 0000000..4117e06 --- /dev/null +++ b/sqlite/select_table.go @@ -0,0 +1,24 @@ +package sqlite + +import "github.com/go-jet/jet/v2/internal/jet" + +// SelectTable is interface for MySQL sub-queries +type SelectTable interface { + readableTable + jet.SelectTable +} + +type selectTableImpl struct { + jet.SelectTable + readableTableInterfaceImpl +} + +func newSelectTable(selectStmt jet.SerializerStatement, alias string) SelectTable { + subQuery := &selectTableImpl{ + SelectTable: jet.NewSelectTable(selectStmt, alias), + } + + subQuery.readableTableInterfaceImpl.parent = subQuery + + return subQuery +} diff --git a/sqlite/set_statement.go b/sqlite/set_statement.go new file mode 100644 index 0000000..18bcca5 --- /dev/null +++ b/sqlite/set_statement.go @@ -0,0 +1,99 @@ +package sqlite + +import "github.com/go-jet/jet/v2/internal/jet" + +// UNION effectively appends the result of sub-queries(select statements) into single query. +// It eliminates duplicate rows from its result. +func UNION(lhs, rhs jet.SerializerStatement, selects ...jet.SerializerStatement) setStatement { + return newSetStatementImpl(union, false, toSelectList(lhs, rhs, selects...)) +} + +// UNION_ALL effectively appends the result of sub-queries(select statements) into single query. +// It does not eliminates duplicate rows from its result. +func UNION_ALL(lhs, rhs jet.SerializerStatement, selects ...jet.SerializerStatement) setStatement { + return newSetStatementImpl(union, true, toSelectList(lhs, rhs, selects...)) +} + +type setStatement interface { + setOperators + + ORDER_BY(orderByClauses ...OrderByClause) setStatement + + LIMIT(limit int64) setStatement + OFFSET(offset int64) setStatement + + AsTable(alias string) SelectTable +} + +type setOperators interface { + jet.Statement + jet.HasProjections + jet.Expression + + UNION(rhs SelectStatement) setStatement + UNION_ALL(rhs SelectStatement) setStatement +} + +type setOperatorsImpl struct { + parent setOperators +} + +func (s *setOperatorsImpl) UNION(rhs SelectStatement) setStatement { + return UNION(s.parent, rhs) +} + +func (s *setOperatorsImpl) UNION_ALL(rhs SelectStatement) setStatement { + return UNION_ALL(s.parent, rhs) +} + +type setStatementImpl struct { + jet.ExpressionStatement + + setOperatorsImpl + + setOperator jet.ClauseSetStmtOperator +} + +func newSetStatementImpl(operator string, all bool, selects []jet.SerializerStatement) setStatement { + newSetStatement := &setStatementImpl{} + newSetStatement.ExpressionStatement = jet.NewExpressionStatementImpl(Dialect, jet.SetStatementType, newSetStatement, + &newSetStatement.setOperator) + + newSetStatement.setOperator.Operator = operator + newSetStatement.setOperator.All = all + newSetStatement.setOperator.Selects = selects + newSetStatement.setOperator.Limit.Count = -1 + newSetStatement.setOperator.Offset.Count = -1 + newSetStatement.setOperator.SkipSelectWrap = true + + newSetStatement.setOperatorsImpl.parent = newSetStatement + + return newSetStatement +} + +func (s *setStatementImpl) ORDER_BY(orderByClauses ...OrderByClause) setStatement { + s.setOperator.OrderBy.List = orderByClauses + return s +} + +func (s *setStatementImpl) LIMIT(limit int64) setStatement { + s.setOperator.Limit.Count = limit + return s +} + +func (s *setStatementImpl) OFFSET(offset int64) setStatement { + s.setOperator.Offset.Count = offset + return s +} + +func (s *setStatementImpl) AsTable(alias string) SelectTable { + return newSelectTable(s, alias) +} + +const ( + union = "UNION" +) + +func toSelectList(lhs, rhs jet.SerializerStatement, selects ...jet.SerializerStatement) []jet.SerializerStatement { + return append([]jet.SerializerStatement{lhs, rhs}, selects...) +} diff --git a/sqlite/set_statement_test.go b/sqlite/set_statement_test.go new file mode 100644 index 0000000..c822089 --- /dev/null +++ b/sqlite/set_statement_test.go @@ -0,0 +1,31 @@ +package sqlite + +import ( + "testing" +) + +func TestSelectSets(t *testing.T) { + select1 := SELECT(table1ColBool).FROM(table1) + select2 := SELECT(table2ColBool).FROM(table2) + + assertStatementSql(t, select1.UNION(select2), ` + +SELECT table1.col_bool AS "table1.col_bool" +FROM db.table1 + +UNION + +SELECT table2.col_bool AS "table2.col_bool" +FROM db.table2; +`) + assertStatementSql(t, select1.UNION_ALL(select2), ` + +SELECT table1.col_bool AS "table1.col_bool" +FROM db.table1 + +UNION ALL + +SELECT table2.col_bool AS "table2.col_bool" +FROM db.table2; +`) +} diff --git a/sqlite/statement.go b/sqlite/statement.go new file mode 100644 index 0000000..754ae41 --- /dev/null +++ b/sqlite/statement.go @@ -0,0 +1,8 @@ +package sqlite + +import "github.com/go-jet/jet/v2/internal/jet" + +// RawStatement creates new sql statements from raw query and optional map of named arguments +func RawStatement(rawQuery string, namedArguments ...RawArgs) Statement { + return jet.RawStatement(Dialect, rawQuery, namedArguments...) +} diff --git a/sqlite/table.go b/sqlite/table.go new file mode 100644 index 0000000..6d70f7f --- /dev/null +++ b/sqlite/table.go @@ -0,0 +1,122 @@ +package sqlite + +import "github.com/go-jet/jet/v2/internal/jet" + +// Table is interface for MySQL tables +type Table interface { + jet.SerializerTable + readableTable + + INSERT(columns ...jet.Column) InsertStatement + UPDATE(columns ...jet.Column) UpdateStatement + DELETE() DeleteStatement +} + +type readableTable interface { + // Generates a select query on the current tableName. + SELECT(projection Projection, projections ...Projection) SelectStatement + + // Creates a inner join tableName Expression using onCondition. + INNER_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable + + // Creates a left join tableName Expression using onCondition. + LEFT_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable + + // Creates a right join tableName Expression using onCondition. + RIGHT_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable + + // Creates a full join tableName Expression using onCondition. + FULL_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable + + // Creates a cross join tableName Expression using onCondition. + CROSS_JOIN(table ReadableTable) joinSelectUpdateTable +} + +type joinSelectUpdateTable interface { + ReadableTable + UPDATE(columns ...jet.Column) UpdateStatement +} + +// ReadableTable interface +type ReadableTable interface { + readableTable + jet.Serializer +} + +type readableTableInterfaceImpl struct { + parent ReadableTable +} + +// Generates a select query on the current tableName. +func (r readableTableInterfaceImpl) SELECT(projection1 Projection, projections ...Projection) SelectStatement { + return newSelectStatement(r.parent, append([]Projection{projection1}, projections...)) +} + +// Creates a inner join tableName Expression using onCondition. +func (r readableTableInterfaceImpl) INNER_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable { + return newJoinTable(r.parent, table, jet.InnerJoin, onCondition) +} + +// Creates a left join tableName Expression using onCondition. +func (r readableTableInterfaceImpl) LEFT_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable { + return newJoinTable(r.parent, table, jet.LeftJoin, onCondition) +} + +// Creates a right join tableName Expression using onCondition. +func (r readableTableInterfaceImpl) RIGHT_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable { + return newJoinTable(r.parent, table, jet.RightJoin, onCondition) +} + +func (r readableTableInterfaceImpl) FULL_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable { + return newJoinTable(r.parent, table, jet.FullJoin, onCondition) +} + +func (r readableTableInterfaceImpl) CROSS_JOIN(table ReadableTable) joinSelectUpdateTable { + return newJoinTable(r.parent, table, jet.CrossJoin, nil) +} + +// NewTable creates new table with schema Name, table Name and list of columns +func NewTable(schemaName, name, alias string, columns ...jet.ColumnExpression) Table { + t := &tableImpl{ + SerializerTable: jet.NewTable(schemaName, name, alias, columns...), + } + + t.readableTableInterfaceImpl.parent = t + t.parent = t + + return t +} + +type tableImpl struct { + jet.SerializerTable + readableTableInterfaceImpl + parent Table +} + +func (t *tableImpl) INSERT(columns ...jet.Column) InsertStatement { + return newInsertStatement(t.parent, jet.UnwidColumnList(columns)) +} + +func (t *tableImpl) UPDATE(columns ...jet.Column) UpdateStatement { + return newUpdateStatement(t.parent, jet.UnwidColumnList(columns)) +} + +func (t *tableImpl) DELETE() DeleteStatement { + return newDeleteStatement(t.parent) +} + +type joinTable struct { + tableImpl + jet.JoinTable +} + +func newJoinTable(lhs jet.Serializer, rhs jet.Serializer, joinType jet.JoinType, onCondition BoolExpression) Table { + newJoinTable := &joinTable{ + JoinTable: jet.NewJoinTable(lhs, rhs, joinType, onCondition), + } + + newJoinTable.readableTableInterfaceImpl.parent = newJoinTable + newJoinTable.parent = newJoinTable + + return newJoinTable +} diff --git a/sqlite/table_test.go b/sqlite/table_test.go new file mode 100644 index 0000000..a68d562 --- /dev/null +++ b/sqlite/table_test.go @@ -0,0 +1,101 @@ +package sqlite + +import ( + "testing" +) + +func TestJoinNilInputs(t *testing.T) { + assertSerializeErr(t, table2.INNER_JOIN(nil, table1ColBool.EQ(table2ColBool)), + "jet: right hand side of join operation is nil table") + assertSerializeErr(t, table2.INNER_JOIN(table1, nil), + "jet: join condition is nil") +} + +func TestINNER_JOIN(t *testing.T) { + assertSerialize(t, table1. + INNER_JOIN(table2, table1ColInt.EQ(table2ColInt)), + `db.table1 +INNER JOIN db.table2 ON (table1.col_int = table2.col_int)`) + assertSerialize(t, table1. + INNER_JOIN(table2, table1ColInt.EQ(table2ColInt)). + INNER_JOIN(table3, table1ColInt.EQ(table3ColInt)), + `db.table1 +INNER JOIN db.table2 ON (table1.col_int = table2.col_int) +INNER JOIN db.table3 ON (table1.col_int = table3.col_int)`) + assertSerialize(t, table1. + INNER_JOIN(table2, table1ColInt.EQ(Int(1))). + INNER_JOIN(table3, table1ColInt.EQ(Int(2))), + `db.table1 +INNER JOIN db.table2 ON (table1.col_int = ?) +INNER JOIN db.table3 ON (table1.col_int = ?)`, int64(1), int64(2)) +} + +func TestLEFT_JOIN(t *testing.T) { + assertSerialize(t, table1. + LEFT_JOIN(table2, table1ColInt.EQ(table2ColInt)), + `db.table1 +LEFT JOIN db.table2 ON (table1.col_int = table2.col_int)`) + assertSerialize(t, table1. + LEFT_JOIN(table2, table1ColInt.EQ(table2ColInt)). + LEFT_JOIN(table3, table1ColInt.EQ(table3ColInt)), + `db.table1 +LEFT JOIN db.table2 ON (table1.col_int = table2.col_int) +LEFT JOIN db.table3 ON (table1.col_int = table3.col_int)`) + assertSerialize(t, table1. + LEFT_JOIN(table2, table1ColInt.EQ(Int(1))). + LEFT_JOIN(table3, table1ColInt.EQ(Int(2))), + `db.table1 +LEFT JOIN db.table2 ON (table1.col_int = ?) +LEFT JOIN db.table3 ON (table1.col_int = ?)`, int64(1), int64(2)) +} + +func TestRIGHT_JOIN(t *testing.T) { + assertSerialize(t, table1. + RIGHT_JOIN(table2, table1ColInt.EQ(table2ColInt)), + `db.table1 +RIGHT JOIN db.table2 ON (table1.col_int = table2.col_int)`) + assertSerialize(t, table1. + RIGHT_JOIN(table2, table1ColInt.EQ(table2ColInt)). + RIGHT_JOIN(table3, table1ColInt.EQ(table3ColInt)), + `db.table1 +RIGHT JOIN db.table2 ON (table1.col_int = table2.col_int) +RIGHT JOIN db.table3 ON (table1.col_int = table3.col_int)`) + assertSerialize(t, table1. + RIGHT_JOIN(table2, table1ColInt.EQ(Int(1))). + RIGHT_JOIN(table3, table1ColInt.EQ(Int(2))), + `db.table1 +RIGHT JOIN db.table2 ON (table1.col_int = ?) +RIGHT JOIN db.table3 ON (table1.col_int = ?)`, int64(1), int64(2)) +} + +func TestFULL_JOIN(t *testing.T) { + assertSerialize(t, table1. + FULL_JOIN(table2, table1ColInt.EQ(table2ColInt)), + `db.table1 +FULL JOIN db.table2 ON (table1.col_int = table2.col_int)`) + assertSerialize(t, table1. + FULL_JOIN(table2, table1ColInt.EQ(table2ColInt)). + FULL_JOIN(table3, table1ColInt.EQ(table3ColInt)), + `db.table1 +FULL JOIN db.table2 ON (table1.col_int = table2.col_int) +FULL JOIN db.table3 ON (table1.col_int = table3.col_int)`) + assertSerialize(t, table1. + FULL_JOIN(table2, table1ColInt.EQ(Int(1))). + FULL_JOIN(table3, table1ColInt.EQ(Int(2))), + `db.table1 +FULL JOIN db.table2 ON (table1.col_int = ?) +FULL JOIN db.table3 ON (table1.col_int = ?)`, int64(1), int64(2)) +} + +func TestCROSS_JOIN(t *testing.T) { + assertSerialize(t, table1. + CROSS_JOIN(table2), + `db.table1 +CROSS JOIN db.table2`) + assertSerialize(t, table1. + CROSS_JOIN(table2). + CROSS_JOIN(table3), + `db.table1 +CROSS JOIN db.table2 +CROSS JOIN db.table3`) +} diff --git a/sqlite/types.go b/sqlite/types.go new file mode 100644 index 0000000..755be1d --- /dev/null +++ b/sqlite/types.go @@ -0,0 +1,27 @@ +package sqlite + +import "github.com/go-jet/jet/v2/internal/jet" + +// Statement is common interface for all statements(SELECT, INSERT, UPDATE, DELETE, LOCK) +type Statement = jet.Statement + +// Projection is interface for all projection types. Types that can be part of, for instance SELECT clause. +type Projection = jet.Projection + +// ProjectionList can be used to create conditional constructed projection list. +type ProjectionList = jet.ProjectionList + +// ColumnAssigment is interface wrapper around column assigment +type ColumnAssigment = jet.ColumnAssigment + +// PrintableStatement is a statement which sql query can be logged +type PrintableStatement = jet.PrintableStatement + +// OrderByClause is the combination of an expression and the wanted ordering to use as input for ORDER BY. +type OrderByClause = jet.OrderByClause + +// GroupByClause interface to use as input for GROUP_BY +type GroupByClause = jet.GroupByClause + +// SetLogger sets automatic statement logging +var SetLogger = jet.SetLoggerFunc diff --git a/sqlite/update_statement.go b/sqlite/update_statement.go new file mode 100644 index 0000000..53cf72d --- /dev/null +++ b/sqlite/update_statement.go @@ -0,0 +1,70 @@ +package sqlite + +import "github.com/go-jet/jet/v2/internal/jet" + +// UpdateStatement is interface of SQL UPDATE statement +type UpdateStatement interface { + jet.Statement + + SET(value interface{}, values ...interface{}) UpdateStatement + MODEL(data interface{}) UpdateStatement + + WHERE(expression BoolExpression) UpdateStatement + RETURNING(projections ...jet.Projection) UpdateStatement +} + +type updateStatementImpl struct { + jet.SerializerStatement + + Update jet.ClauseUpdate + Set jet.SetClause + SetNew jet.SetClauseNew + Where jet.ClauseWhere + Returning jet.ClauseReturning +} + +func newUpdateStatement(table Table, columns []jet.Column) UpdateStatement { + update := &updateStatementImpl{} + update.SerializerStatement = jet.NewStatementImpl(Dialect, jet.UpdateStatementType, update, + &update.Update, + &update.Set, + &update.SetNew, + &update.Where, + &update.Returning) + + update.Update.Table = table + update.Set.Columns = columns + update.Where.Mandatory = true + + return update +} + +func (u *updateStatementImpl) SET(value interface{}, values ...interface{}) UpdateStatement { + columnAssigment, isColumnAssigment := value.(ColumnAssigment) + + if isColumnAssigment { + u.SetNew = []ColumnAssigment{columnAssigment} + for _, value := range values { + u.SetNew = append(u.SetNew, value.(ColumnAssigment)) + } + } else { + u.Set.Values = jet.UnwindRowFromValues(value, values) + } + + return u +} + +func (u *updateStatementImpl) MODEL(data interface{}) UpdateStatement { + u.Set.Values = jet.UnwindRowFromModel(u.Set.Columns, data) + return u +} + +func (u *updateStatementImpl) WHERE(expression BoolExpression) UpdateStatement { + u.Where.Condition = expression + return u +} + +func (u *updateStatementImpl) RETURNING(projections ...jet.Projection) UpdateStatement { + u.Returning.ProjectionList = projections + return u +} diff --git a/sqlite/update_statement_test.go b/sqlite/update_statement_test.go new file mode 100644 index 0000000..5c468a3 --- /dev/null +++ b/sqlite/update_statement_test.go @@ -0,0 +1,82 @@ +package sqlite + +import ( + "fmt" + "strings" + "testing" +) + +func TestUpdateWithOneValue(t *testing.T) { + expectedSQL := ` +UPDATE db.table1 +SET col_int = ? +WHERE table1.col_int >= ?; +` + stmt := table1.UPDATE(table1ColInt). + SET(1). + WHERE(table1ColInt.GT_EQ(Int(33))) + + fmt.Println(stmt.Sql()) + + assertStatementSql(t, stmt, expectedSQL, 1, int64(33)) +} + +func TestUpdateWithValues(t *testing.T) { + expectedSQL := ` +UPDATE db.table1 +SET col_int = ?, + col_float = ? +WHERE table1.col_int >= ?; +` + stmt := table1.UPDATE(table1ColInt, table1ColFloat). + SET(1, 22.2). + WHERE(table1ColInt.GT_EQ(Int(33))) + + fmt.Println(stmt.Sql()) + + assertStatementSql(t, stmt, expectedSQL, 1, 22.2, int64(33)) +} + +func TestUpdateOneColumnWithSelect(t *testing.T) { + expectedSQL := ` +UPDATE db.table1 +SET col_float = ( + SELECT table1.col_float AS "table1.col_float" + FROM db.table1 + ) +WHERE table1.col1 = ?; +` + stmt := table1. + UPDATE(table1ColFloat). + SET( + table1.SELECT(table1ColFloat), + ). + WHERE(table1Col1.EQ(Int(2))) + + assertStatementSql(t, stmt, expectedSQL, int64(2)) +} + +func TestUpdateReservedWorldColumn(t *testing.T) { + type table struct { + Load string + } + + loadColumn := StringColumn("Load") + assertStatementSql(t, + table1.UPDATE(loadColumn). + MODEL( + table{ + Load: "foo", + }, + ). + WHERE(loadColumn.EQ(String("bar"))), strings.Replace(` +UPDATE db.table1 +SET ''Load'' = ? +WHERE ''Load'' = ?; +`, "''", "`", -1), "foo", "bar") +} + +func TestInvalidInputs(t *testing.T) { + assertStatementSqlErr(t, table1.UPDATE(table1ColInt).SET(1), "jet: WHERE clause not set") + assertStatementSqlErr(t, table1.UPDATE(nil).SET(1), "jet: nil column in columns list for SET clause") +} diff --git a/sqlite/utils_test.go b/sqlite/utils_test.go new file mode 100644 index 0000000..3f9b9f3 --- /dev/null +++ b/sqlite/utils_test.go @@ -0,0 +1,55 @@ +package sqlite + +import ( + "github.com/go-jet/jet/v2/internal/jet" + "github.com/go-jet/jet/v2/internal/testutils" + "testing" +) + +var table1Col1 = IntegerColumn("col1") +var table1ColBool = BoolColumn("col_bool") +var table1ColInt = IntegerColumn("col_int") +var table1ColFloat = FloatColumn("col_float") +var table1ColString = StringColumn("col_string") +var table1Col3 = IntegerColumn("col3") +var table1ColTimestamp = TimestampColumn("col_timestamp") +var table1ColDate = DateColumn("col_date") +var table1ColTime = TimeColumn("col_time") + +var table1 = NewTable("db", "table1", "", table1Col1, table1ColInt, table1ColFloat, table1ColString, table1Col3, table1ColBool, table1ColDate, table1ColTimestamp, table1ColTime) + +var table2Col3 = IntegerColumn("col3") +var table2Col4 = IntegerColumn("col4") +var table2ColInt = IntegerColumn("col_int") +var table2ColFloat = FloatColumn("col_float") +var table2ColStr = StringColumn("col_str") +var table2ColBool = BoolColumn("col_bool") +var table2ColTimestamp = TimestampColumn("col_timestamp") +var table2ColDate = DateColumn("col_date") + +var table2 = NewTable("db", "table2", "", table2Col3, table2Col4, table2ColInt, table2ColFloat, table2ColStr, table2ColBool, table2ColDate, table2ColTimestamp) + +var table3Col1 = IntegerColumn("col1") +var table3ColInt = IntegerColumn("col_int") +var table3StrCol = StringColumn("col2") +var table3 = NewTable("db", "table3", "", table3Col1, table3ColInt, table3StrCol) + +func assertSerialize(t *testing.T, clause jet.Serializer, query string, args ...interface{}) { + testutils.AssertSerialize(t, Dialect, clause, query, args...) +} + +func assertDebugSerialize(t *testing.T, clause jet.Serializer, query string, args ...interface{}) { + testutils.AssertDebugSerialize(t, Dialect, clause, query, args...) +} + +func assertSerializeErr(t *testing.T, clause jet.Serializer, errString string) { + testutils.AssertSerializeErr(t, Dialect, clause, errString) +} + +func assertProjectionSerialize(t *testing.T, projection jet.Projection, query string, args ...interface{}) { + testutils.AssertProjectionSerialize(t, Dialect, projection, query, args...) +} + +var assertPanicErr = testutils.AssertPanicErr +var assertStatementSql = testutils.AssertStatementSql +var assertStatementSqlErr = testutils.AssertStatementSqlErr diff --git a/sqlite/with_statement.go b/sqlite/with_statement.go new file mode 100644 index 0000000..7940dcd --- /dev/null +++ b/sqlite/with_statement.go @@ -0,0 +1,26 @@ +package sqlite + +import "github.com/go-jet/jet/v2/internal/jet" + +// CommonTableExpression contains information about a CTE. +type CommonTableExpression struct { + readableTableInterfaceImpl + jet.CommonTableExpression +} + +// WITH function creates new WITH statement from list of common table expressions +func WITH(cte ...jet.CommonTableExpressionDefinition) func(statement jet.Statement) Statement { + return jet.WITH(Dialect, cte...) +} + +// CTE creates new named CommonTableExpression +func CTE(name string) CommonTableExpression { + cte := CommonTableExpression{ + readableTableInterfaceImpl: readableTableInterfaceImpl{}, + CommonTableExpression: jet.CTE(name), + } + + cte.parent = &cte + + return cte +} diff --git a/tests/dbconfig/dbconfig.go b/tests/dbconfig/dbconfig.go index 0481252..ef89c1b 100644 --- a/tests/dbconfig/dbconfig.go +++ b/tests/dbconfig/dbconfig.go @@ -1,6 +1,9 @@ package dbconfig -import "fmt" +import ( + "fmt" + "github.com/go-jet/jet/v2/tests/internal/utils/repo" +) // Postgres test database connection parameters const ( @@ -24,3 +27,10 @@ const ( // MySQLConnectionString is MySQL driver connection string to test database var MySQLConnectionString = fmt.Sprintf("%s:%s@tcp(%s:%d)/", MySQLUser, MySQLPassword, MySqLHost, MySQLPort) + +// sqllite +var ( + SakilaDBPath = repo.GetTestDataFilePath("/init/sqlite/sakila.db") + ChinookDBPath = repo.GetTestDataFilePath("/init/sqlite/chinook.db") + TestSampleDBPath = repo.GetTestDataFilePath("/init/sqlite/test_sample.db") +) diff --git a/tests/init/init.go b/tests/init/init.go index 3bd7e64..aa04fb5 100644 --- a/tests/init/init.go +++ b/tests/init/init.go @@ -4,6 +4,8 @@ import ( "database/sql" "flag" "fmt" + "github.com/go-jet/jet/v2/generator/sqlite" + "github.com/go-jet/jet/v2/tests/internal/utils/repo" "io/ioutil" "os" "os/exec" @@ -15,6 +17,8 @@ import ( "github.com/go-jet/jet/v2/tests/dbconfig" _ "github.com/go-sql-driver/mysql" _ "github.com/lib/pq" + + _ "github.com/mattn/go-sqlite3" ) var testSuite string @@ -39,8 +43,23 @@ func main() { return } + if testSuite == "sqlite" { + initSQLiteDB() + return + } + initMySQLDB() initPostgresDB() + initSQLiteDB() +} + +func initSQLiteDB() { + err := sqlite.GenerateDSN(dbconfig.SakilaDBPath, repo.GetTestsFilePath("./.gentestdata/sqlite/sakila")) + throw.OnError(err) + err = sqlite.GenerateDSN(dbconfig.ChinookDBPath, repo.GetTestsFilePath("./.gentestdata/sqlite/chinook")) + throw.OnError(err) + err = sqlite.GenerateDSN(dbconfig.TestSampleDBPath, repo.GetTestsFilePath("./.gentestdata/sqlite/test_sample")) + throw.OnError(err) } func initMySQLDB() { diff --git a/tests/internal/utils/repo/repo.go b/tests/internal/utils/repo/repo.go new file mode 100644 index 0000000..3d24039 --- /dev/null +++ b/tests/internal/utils/repo/repo.go @@ -0,0 +1,33 @@ +package repo + +import ( + "os/exec" + "path/filepath" + "strings" +) + +// GetRootDirPath will return this repo full dir path +func GetRootDirPath() string { + cmd := exec.Command("git", "rev-parse", "--show-toplevel") + byteArr, err := cmd.Output() + if err != nil { + panic(err) + } + + return strings.TrimSpace(string(byteArr)) +} + +// GetTestsDirPath will return tests folder full path +func GetTestsDirPath() string { + return filepath.Join(GetRootDirPath(), "tests") +} + +// GetTestsFilePath will return full file path of the file in the tests folder +func GetTestsFilePath(subPath string) string { + return filepath.Join(GetTestsDirPath(), subPath) +} + +// GetTestDataFilePath will return full file path of the file in the testdata folder +func GetTestDataFilePath(subPath string) string { + return filepath.Join(GetTestsDirPath(), "testdata", subPath) +} diff --git a/tests/sqlite/alltypes_test.go b/tests/sqlite/alltypes_test.go new file mode 100644 index 0000000..c52d709 --- /dev/null +++ b/tests/sqlite/alltypes_test.go @@ -0,0 +1,912 @@ +package sqlite + +import ( + "github.com/go-jet/jet/v2/internal/testutils" + . "github.com/go-jet/jet/v2/sqlite" + "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/test_sample/model" + . "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/test_sample/table" + "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/test_sample/view" + "github.com/go-jet/jet/v2/tests/testdata/results/common" + "github.com/google/uuid" + "github.com/shopspring/decimal" + "github.com/stretchr/testify/require" + "strings" + "testing" + "time" +) + +func TestAllTypes(t *testing.T) { + + dest := []model.AllTypes{} + + err := SELECT(AllTypes.AllColumns). + FROM(AllTypes). + Query(sampleDB, &dest) + + require.NoError(t, err) + testutils.AssertJSON(t, dest, allTypesJSON) +} + +var allTypesJSON = ` +[ + { + "Boolean": false, + "BooleanPtr": true, + "TinyInt": -3, + "TinyIntPtr": 3, + "SmallInt": 14, + "SmallIntPtr": 14, + "MediumInt": -150, + "MediumIntPtr": 150, + "Integer": -1600, + "IntegerPtr": 1600, + "BigInt": 5000, + "BigIntPtr": 50000, + "Decimal": 1.11, + "DecimalPtr": 1.01, + "Numeric": 2.22, + "NumericPtr": 2.02, + "Float": 3.33, + "FloatPtr": 3.03, + "Double": 4.44, + "DoublePtr": 4.04, + "Real": 5.55, + "RealPtr": 5.05, + "Time": "0000-01-01T10:11:12.33Z", + "TimePtr": "0000-01-01T10:11:12.123456Z", + "Date": "2008-07-04T00:00:00Z", + "DatePtr": "2008-07-04T00:00:00Z", + "DateTime": "2011-12-18T13:17:17Z", + "DateTimePtr": "2011-12-18T13:17:17Z", + "Timestamp": "2007-12-31T23:00:01Z", + "TimestampPtr": "2007-12-31T23:00:01Z", + "Char": "char1", + "CharPtr": "char-ptr", + "VarChar": "varchar", + "VarCharPtr": "varchar-ptr", + "Text": "text", + "TextPtr": "text-ptr", + "Blob": "YmxvYjE=", + "BlobPtr": "YmxvYi1wdHI=" + }, + { + "Boolean": false, + "BooleanPtr": null, + "TinyInt": -3, + "TinyIntPtr": null, + "SmallInt": 14, + "SmallIntPtr": null, + "MediumInt": -150, + "MediumIntPtr": null, + "Integer": -1600, + "IntegerPtr": null, + "BigInt": 5000, + "BigIntPtr": null, + "Decimal": 1.11, + "DecimalPtr": null, + "Numeric": 2.22, + "NumericPtr": null, + "Float": 3.33, + "FloatPtr": null, + "Double": 4.44, + "DoublePtr": null, + "Real": 5.55, + "RealPtr": null, + "Time": "0000-01-01T10:11:12.33Z", + "TimePtr": null, + "Date": "2008-07-04T00:00:00Z", + "DatePtr": null, + "DateTime": "2011-12-18T13:17:17Z", + "DateTimePtr": null, + "Timestamp": "2007-12-31T23:00:01Z", + "TimestampPtr": null, + "Char": "char2", + "CharPtr": null, + "VarChar": "varchar", + "VarCharPtr": null, + "Text": "text", + "TextPtr": null, + "Blob": "YmxvYjI=", + "BlobPtr": null + } +] +` + +func TestAllTypesViewSelect(t *testing.T) { + var dest []model.AllTypesView + + stmt := SELECT(view.AllTypesView.AllColumns). + FROM(view.AllTypesView) + + err := stmt.Query(sampleDB, &dest) + + require.NoError(t, err) + require.Equal(t, len(dest), 2) + + testutils.AssertJSON(t, dest, allTypesJSON) +} + +func TestAllTypesInsert(t *testing.T) { + tx := beginSampleDBTx(t) + + stmt := AllTypes.INSERT(AllTypes.AllColumns). + MODEL(toInsert). + RETURNING(AllTypes.AllColumns) + + var inserted model.AllTypes + err := stmt.Query(tx, &inserted) + + require.NoError(t, err) + testutils.AssertDeepEqual(t, toInsert, inserted, testutils.UnixTimeComparer) + + var dest model.AllTypes + err = AllTypes.SELECT(AllTypes.AllColumns). + WHERE(AllTypes.BigInt.EQ(Int(toInsert.BigInt))). + Query(tx, &dest) + + require.NoError(t, err) + testutils.AssertDeepEqual(t, dest, toInsert, testutils.UnixTimeComparer) + + err = tx.Rollback() + require.NoError(t, err) +} + +var toInsert = model.AllTypes{ + Boolean: false, + BooleanPtr: testutils.BoolPtr(true), + TinyInt: 1, + SmallInt: 3, + MediumInt: 5, + Integer: 7, + BigInt: 9, + TinyIntPtr: testutils.Int8Ptr(11), + SmallIntPtr: testutils.Int16Ptr(33), + MediumIntPtr: testutils.Int32Ptr(55), + IntegerPtr: testutils.Int32Ptr(77), + BigIntPtr: testutils.Int64Ptr(99), + Decimal: 11.22, + DecimalPtr: testutils.Float64Ptr(33.44), + Numeric: 55.66, + NumericPtr: testutils.Float64Ptr(77.88), + Float: 99.00, + FloatPtr: testutils.Float64Ptr(11.22), + Double: 33.44, + DoublePtr: testutils.Float64Ptr(55.66), + Real: 77.88, + RealPtr: testutils.Float32Ptr(99.00), + Time: time.Date(1, 1, 1, 1, 1, 1, 10, time.UTC), + TimePtr: testutils.TimePtr(time.Date(2, 2, 2, 2, 2, 2, 200, time.UTC)), + Date: time.Now(), + DatePtr: testutils.TimePtr(time.Now()), + DateTime: time.Now(), + DateTimePtr: testutils.TimePtr(time.Now()), + Timestamp: time.Now(), + TimestampPtr: testutils.TimePtr(time.Now()), + Char: "abcd", + CharPtr: testutils.StringPtr("absd"), + VarChar: "abcd", + VarCharPtr: testutils.StringPtr("absd"), + Blob: []byte("large file"), + BlobPtr: testutils.ByteArrayPtr([]byte("very large file")), + Text: "some text", + TextPtr: testutils.StringPtr("text"), +} + +func TestUUID(t *testing.T) { + query := SELECT( + //Raw("uuid()").AS("uuid"), + String("dc8daae3-b83b-11e9-8eb4-98ded00c39c6").AS("str_uuid"), + ) + + var dest struct { + UUID uuid.UUID + StrUUID *uuid.UUID + } + + err := query.Query(sampleDB, &dest) + + require.NoError(t, err) + require.Equal(t, dest.StrUUID.String(), "dc8daae3-b83b-11e9-8eb4-98ded00c39c6") + requireLogged(t, query) +} + +func TestExpressionOperators(t *testing.T) { + query := SELECT( + AllTypes.Integer.IS_NULL().AS("result.is_null"), + AllTypes.DatePtr.IS_NOT_NULL().AS("result.is_not_null"), + AllTypes.SmallIntPtr.IN(Int(11), Int(22)).AS("result.in"), + AllTypes.SmallIntPtr.IN(AllTypes.SELECT(AllTypes.Integer)).AS("result.in_select"), + + Raw("length(121232459)").AS("result.raw"), + Raw(":first + COALESCE(all_types.small_int_ptr, 0) + :second", RawArgs{":first": 78, ":second": 56}). + AS("result.raw_arg"), + Raw("#1 + all_types.integer + #2 + #1 + #3 + #4", RawArgs{"#1": 11, "#2": 22, "#3": 33, "#4": 44}). + AS("result.raw_arg2"), + + AllTypes.SmallIntPtr.NOT_IN(Int(11), Int(22), NULL).AS("result.not_in"), + AllTypes.SmallIntPtr.NOT_IN(AllTypes.SELECT(AllTypes.Integer)).AS("result.not_in_select"), + ).FROM( + AllTypes, + ).LIMIT(2) + + testutils.AssertStatementSql(t, query, strings.Replace(` +SELECT all_types.integer IS NULL AS "result.is_null", + all_types.date_ptr IS NOT NULL AS "result.is_not_null", + (all_types.small_int_ptr IN (?, ?)) AS "result.in", + (all_types.small_int_ptr IN ( + SELECT all_types.integer AS "all_types.integer" + FROM all_types + )) AS "result.in_select", + (length(121232459)) AS "result.raw", + (? + COALESCE(all_types.small_int_ptr, 0) + ?) AS "result.raw_arg", + (? + all_types.integer + ? + ? + ? + ?) AS "result.raw_arg2", + (all_types.small_int_ptr NOT IN (?, ?, NULL)) AS "result.not_in", + (all_types.small_int_ptr NOT IN ( + SELECT all_types.integer AS "all_types.integer" + FROM all_types + )) AS "result.not_in_select" +FROM all_types +LIMIT ?; +`, "'", "`", -1), int64(11), int64(22), 78, 56, 11, 22, 11, 33, 44, int64(11), int64(22), int64(2)) + + var dest []struct { + common.ExpressionTestResult `alias:"result.*"` + } + + err := query.Query(sampleDB, &dest) + require.NoError(t, err) + + require.Equal(t, *dest[0].IsNull, false) + require.Equal(t, *dest[0].IsNotNull, true) + require.Equal(t, *dest[0].In, false) + require.Equal(t, *dest[0].InSelect, false) + require.Equal(t, *dest[0].Raw, "9") + require.Equal(t, *dest[0].RawArg, int32(148)) + require.Equal(t, *dest[0].RawArg2, int32(-1479)) + require.Nil(t, dest[0].NotIn) + require.Equal(t, *dest[0].NotInSelect, true) +} + +func TestBoolOperators(t *testing.T) { + query := AllTypes.SELECT( + AllTypes.Boolean.EQ(AllTypes.BooleanPtr).AS("EQ1"), + AllTypes.Boolean.EQ(Bool(true)).AS("EQ2"), + AllTypes.Boolean.NOT_EQ(AllTypes.BooleanPtr).AS("NEq1"), + AllTypes.Boolean.NOT_EQ(Bool(false)).AS("NEq2"), + AllTypes.Boolean.IS_DISTINCT_FROM(AllTypes.BooleanPtr).AS("distinct1"), + AllTypes.Boolean.IS_DISTINCT_FROM(Bool(true)).AS("distinct2"), + AllTypes.Boolean.IS_NOT_DISTINCT_FROM(AllTypes.BooleanPtr).AS("not_distinct_1"), + AllTypes.Boolean.IS_NOT_DISTINCT_FROM(Bool(true)).AS("NOTDISTINCT2"), + AllTypes.Boolean.IS_TRUE().AS("ISTRUE"), + AllTypes.Boolean.IS_NOT_TRUE().AS("isnottrue"), + AllTypes.Boolean.IS_FALSE().AS("is_False"), + AllTypes.Boolean.IS_NOT_FALSE().AS("is not false"), + AllTypes.Boolean.IS_NULL().AS("is unknown"), + AllTypes.Boolean.IS_NOT_NULL().AS("is_not_unknown"), + + AllTypes.Boolean.AND(AllTypes.Boolean).EQ(AllTypes.Boolean.AND(AllTypes.Boolean)).AS("complex1"), + AllTypes.Boolean.OR(AllTypes.Boolean).EQ(AllTypes.Boolean.AND(AllTypes.Boolean)).AS("complex2"), + ) + + testutils.AssertStatementSql(t, query, ` +SELECT (all_types.boolean = all_types.boolean_ptr) AS "EQ1", + (all_types.boolean = ?) AS "EQ2", + (all_types.boolean != all_types.boolean_ptr) AS "NEq1", + (all_types.boolean != ?) AS "NEq2", + (all_types.boolean IS NOT all_types.boolean_ptr) AS "distinct1", + (all_types.boolean IS NOT ?) AS "distinct2", + (all_types.boolean IS all_types.boolean_ptr) AS "not_distinct_1", + (all_types.boolean IS ?) AS "NOTDISTINCT2", + all_types.boolean IS TRUE AS "ISTRUE", + all_types.boolean IS NOT TRUE AS "isnottrue", + all_types.boolean IS FALSE AS "is_False", + all_types.boolean IS NOT FALSE AS "is not false", + all_types.boolean IS NULL AS "is unknown", + all_types.boolean IS NOT NULL AS "is_not_unknown", + ((all_types.boolean AND all_types.boolean) = (all_types.boolean AND all_types.boolean)) AS "complex1", + ((all_types.boolean OR all_types.boolean) = (all_types.boolean AND all_types.boolean)) AS "complex2" +FROM all_types; +`, true, false, true, true) + + var dest []struct { + Eq1 *bool + Eq2 *bool + NEq1 *bool + NEq2 *bool + Distinct1 *bool + Distinct2 *bool + NotDistinct1 *bool + NotDistinct2 *bool + IsTrue *bool + IsNotTrue *bool + IsFalse *bool + IsNotFalse *bool + IsUnknown *bool + IsNotUnknown *bool + + Complex1 *bool + Complex2 *bool + } + + err := query.Query(sampleDB, &dest) + + require.NoError(t, err) + + testutils.AssertJSONFile(t, dest, "./testdata/results/common/bool_operators.json") +} + +func TestFloatOperators(t *testing.T) { + + query := AllTypes.SELECT( + AllTypes.Numeric.EQ(AllTypes.Numeric).AS("eq1"), + AllTypes.Decimal.EQ(Float(12.22)).AS("eq2"), + AllTypes.Real.EQ(Float(12.12)).AS("eq3"), + AllTypes.Numeric.IS_DISTINCT_FROM(AllTypes.Numeric).AS("distinct1"), + AllTypes.Decimal.IS_DISTINCT_FROM(Float(12)).AS("distinct2"), + AllTypes.Real.IS_DISTINCT_FROM(Float(12.12)).AS("distinct3"), + AllTypes.Numeric.IS_NOT_DISTINCT_FROM(AllTypes.Numeric).AS("not_distinct1"), + AllTypes.Decimal.IS_NOT_DISTINCT_FROM(Float(12)).AS("not_distinct2"), + AllTypes.Real.IS_NOT_DISTINCT_FROM(Float(12.12)).AS("not_distinct3"), + AllTypes.Numeric.LT(Float(124)).AS("lt1"), + AllTypes.Numeric.LT(Float(34.56)).AS("lt2"), + AllTypes.Numeric.GT(Float(124)).AS("gt1"), + AllTypes.Numeric.GT(Float(34.56)).AS("gt2"), + + AllTypes.Decimal.ADD(AllTypes.Decimal).AS("add1"), + AllTypes.Decimal.ADD(Float(11.22)).AS("add2"), + AllTypes.Decimal.SUB(AllTypes.DecimalPtr).AS("sub1"), + AllTypes.Decimal.SUB(Float(11.22)).AS("sub2"), + AllTypes.Decimal.MUL(AllTypes.DecimalPtr).AS("mul1"), + AllTypes.Decimal.MUL(Float(11.22)).AS("mul2"), + AllTypes.Decimal.DIV(AllTypes.DecimalPtr).AS("div1"), + AllTypes.Decimal.DIV(Float(11.22)).AS("div2"), + AllTypes.Decimal.MOD(AllTypes.DecimalPtr).AS("mod1"), + AllTypes.Decimal.MOD(Float(11.22)).AS("mod2"), + + // sqlite driver has to enable SQLITE_ENABLE_MATH_FUNCTIONS before commented math functions can be used + + //AllTypes.Decimal.POW(AllTypes.DecimalPtr).AS("pow1"), + //AllTypes.Decimal.POW(Float(2.1)).AS("pow2"), + + ABSf(AllTypes.Decimal).AS("abs"), + //POWER(AllTypes.Decimal, Float(2.1)).AS("power"), + //SQRT(AllTypes.Decimal).AS("sqrt"), + //CBRT(AllTypes.Decimal).AS("cbrt"), + + //CEIL(AllTypes.Real).AS("ceil"), + //FLOOR(AllTypes.Real).AS("floor"), + ROUND(AllTypes.Decimal).AS("round1"), + ROUND(AllTypes.Decimal, Int(2)).AS("round2"), + //TRUNC(AllTypes.Decimal, Int(1)).AS("trunc"), + SIGN(AllTypes.Real).AS("sign"), + ).LIMIT(1) + + testutils.AssertStatementSql(t, query, ` +SELECT (all_types.numeric = all_types.numeric) AS "eq1", + (all_types.decimal = ?) AS "eq2", + (all_types.real = ?) AS "eq3", + (all_types.numeric IS NOT all_types.numeric) AS "distinct1", + (all_types.decimal IS NOT ?) AS "distinct2", + (all_types.real IS NOT ?) AS "distinct3", + (all_types.numeric IS all_types.numeric) AS "not_distinct1", + (all_types.decimal IS ?) AS "not_distinct2", + (all_types.real IS ?) AS "not_distinct3", + (all_types.numeric < ?) AS "lt1", + (all_types.numeric < ?) AS "lt2", + (all_types.numeric > ?) AS "gt1", + (all_types.numeric > ?) AS "gt2", + (all_types.decimal + all_types.decimal) AS "add1", + (all_types.decimal + ?) AS "add2", + (all_types.decimal - all_types.decimal_ptr) AS "sub1", + (all_types.decimal - ?) AS "sub2", + (all_types.decimal * all_types.decimal_ptr) AS "mul1", + (all_types.decimal * ?) AS "mul2", + (all_types.decimal / all_types.decimal_ptr) AS "div1", + (all_types.decimal / ?) AS "div2", + (all_types.decimal % all_types.decimal_ptr) AS "mod1", + (all_types.decimal % ?) AS "mod2", + ABS(all_types.decimal) AS "abs", + ROUND(all_types.decimal) AS "round1", + ROUND(all_types.decimal, ?) AS "round2", + SIGN(all_types.real) AS "sign" +FROM all_types +LIMIT ?; +`) + + var dest struct { + common.FloatExpressionTestResult `alias:"."` + } + + err := query.Query(sampleDB, &dest) + + require.NoError(t, err) + require.Equal(t, *dest.Eq1, true) + require.Equal(t, *dest.Distinct1, false) + require.Equal(t, *dest.Lt1, true) + require.Equal(t, *dest.Add1, 2.22) + require.Equal(t, *dest.Mod2, float64(1)) + require.Equal(t, *dest.Round1, float64(1)) + require.Equal(t, *dest.Round2, float64(1.11)) + require.Equal(t, *dest.Sign, float64(1)) + + //testutils.AssertJSONFile(t, dest, "./testdata/results/common/float_operators.json") +} + +func TestIntegerOperators(t *testing.T) { + query := AllTypes.SELECT( + AllTypes.BigInt, + AllTypes.BigIntPtr, + AllTypes.SmallInt, + AllTypes.SmallIntPtr, + + AllTypes.BigInt.EQ(AllTypes.BigInt).AS("eq1"), + AllTypes.BigInt.EQ(Int(12)).AS("eq2"), + + AllTypes.BigInt.NOT_EQ(AllTypes.BigIntPtr).AS("neq1"), + AllTypes.BigInt.NOT_EQ(Int(12)).AS("neq2"), + + AllTypes.BigInt.IS_DISTINCT_FROM(AllTypes.BigInt).AS("distinct1"), + AllTypes.BigInt.IS_DISTINCT_FROM(Int(12)).AS("distinct2"), + + AllTypes.BigInt.IS_NOT_DISTINCT_FROM(AllTypes.BigInt).AS("not distinct1"), + AllTypes.BigInt.IS_NOT_DISTINCT_FROM(Int(12)).AS("not distinct2"), + + AllTypes.BigInt.LT(AllTypes.BigIntPtr).AS("lt1"), + AllTypes.BigInt.LT(Int(65)).AS("lt2"), + + AllTypes.BigInt.LT_EQ(AllTypes.BigIntPtr).AS("lte1"), + AllTypes.BigInt.LT_EQ(Int(65)).AS("lte2"), + + AllTypes.BigInt.GT(AllTypes.BigIntPtr).AS("gt1"), + AllTypes.BigInt.GT(Int(65)).AS("gt2"), + + AllTypes.BigInt.GT_EQ(AllTypes.BigIntPtr).AS("gte1"), + AllTypes.BigInt.GT_EQ(Int(65)).AS("gte2"), + + AllTypes.BigInt.ADD(AllTypes.BigInt).AS("add1"), + AllTypes.BigInt.ADD(Int(11)).AS("add2"), + + AllTypes.BigInt.SUB(AllTypes.BigInt).AS("sub1"), + AllTypes.BigInt.SUB(Int(11)).AS("sub2"), + + AllTypes.BigInt.MUL(AllTypes.BigInt).AS("mul1"), + AllTypes.BigInt.MUL(Int(11)).AS("mul2"), + + AllTypes.BigInt.DIV(AllTypes.BigInt).AS("div1"), + AllTypes.BigInt.DIV(Int(11)).AS("div2"), + + AllTypes.BigInt.MOD(AllTypes.BigInt).AS("mod1"), + AllTypes.BigInt.MOD(Int(11)).AS("mod2"), + + //AllTypes.SmallInt.POW(AllTypes.SmallInt.DIV(Int(3))).AS("pow1"), + //AllTypes.SmallInt.POW(Int(6)).AS("pow2"), + + AllTypes.SmallInt.BIT_AND(AllTypes.SmallInt).AS("bit_and1"), + AllTypes.SmallInt.BIT_AND(AllTypes.SmallInt).AS("bit_and2"), + + AllTypes.SmallInt.BIT_OR(AllTypes.SmallInt).AS("bit or 1"), + AllTypes.SmallInt.BIT_OR(Int(22)).AS("bit or 2"), + + AllTypes.SmallInt.BIT_XOR(AllTypes.SmallInt).AS("bit xor 1"), + AllTypes.SmallInt.BIT_XOR(Int(11)).AS("bit xor 2"), + + BIT_NOT(Int(-1).MUL(AllTypes.SmallInt)).AS("bit_not_1"), + BIT_NOT(Int(-1).MUL(Int(11))).AS("bit_not_2"), + + AllTypes.SmallInt.BIT_SHIFT_LEFT(AllTypes.SmallInt.DIV(Int(2))).AS("bit shift left 1"), + AllTypes.SmallInt.BIT_SHIFT_LEFT(Int(4)).AS("bit shift left 2"), + + AllTypes.SmallInt.BIT_SHIFT_RIGHT(AllTypes.SmallInt.DIV(Int(5))).AS("bit shift right 1"), + AllTypes.SmallInt.BIT_SHIFT_RIGHT(Int(1)).AS("bit shift right 2"), + + ABSi(AllTypes.BigInt).AS("abs"), + //SQRT(ABSi(AllTypes.BigInt)).AS("sqrt"), + //CBRT(ABSi(AllTypes.BigInt)).AS("cbrt"), + ).LIMIT(2) + + var dest []struct { + common.AllTypesIntegerExpResult `alias:"."` + } + + err := query.Query(sampleDB, &dest) + + require.NoError(t, err) + + require.Equal(t, *dest[0].Eq1, true) + require.Equal(t, *dest[0].Distinct2, true) + require.Equal(t, *dest[0].Lt2, false) + require.Equal(t, *dest[0].Add1, int64(10000)) + require.Equal(t, *dest[0].Mul1, int64(25000000)) + require.Equal(t, *dest[0].Div2, int64(454)) + require.Equal(t, *dest[0].BitAnd1, int64(14)) + require.Equal(t, *dest[0].BitXor2, int64(5)) + require.Equal(t, *dest[0].BitShiftLeft1, int64(1792)) + require.Equal(t, *dest[0].BitShiftRight2, int64(7)) + +} + +func TestStringOperators(t *testing.T) { + + query := SELECT( + AllTypes.Text.EQ(AllTypes.Char), + AllTypes.Text.EQ(String("Text")), + AllTypes.Text.NOT_EQ(AllTypes.VarCharPtr), + AllTypes.Text.NOT_EQ(String("Text")), + AllTypes.Text.GT(AllTypes.Text), + AllTypes.Text.GT(String("Text")), + AllTypes.Text.GT_EQ(AllTypes.TextPtr), + AllTypes.Text.GT_EQ(String("Text")), + AllTypes.Text.LT(AllTypes.Char), + AllTypes.Text.LT(String("Text")), + AllTypes.Text.LT_EQ(AllTypes.VarCharPtr), + AllTypes.Text.LT_EQ(String("Text")), + AllTypes.Text.CONCAT(String("text2")), + AllTypes.Text.CONCAT(Int(11)), + AllTypes.Text.LIKE(String("abc")), + AllTypes.Text.NOT_LIKE(String("_b_")), + //AllTypes.Text.REGEXP_LIKE(String("aba")), + //AllTypes.Text.REGEXP_LIKE(String("aba"), false), + //String("ABA").REGEXP_LIKE(String("aba"), true), + //AllTypes.Text.NOT_REGEXP_LIKE(String("aba")), + //AllTypes.Text.NOT_REGEXP_LIKE(String("aba"), false), + //String("ABA").NOT_REGEXP_LIKE(String("aba"), true), + + //BIT_LENGTH(AllTypes.Text), + //CHAR_LENGTH(AllTypes.Char), + //OCTET_LENGTH(AllTypes.Text), + LOWER(AllTypes.VarCharPtr), + UPPER(AllTypes.Char), + LTRIM(AllTypes.VarCharPtr), + RTRIM(AllTypes.VarCharPtr), + //CONCAT(String("string1"), Int(1), Float(11.12)), + //CONCAT_WS(String("string1"), Int(1), Float(11.12)), + //FORMAT(String("Hello %s, %1$s"), String("World")), + //LEFTSTR(String("abcde"), Int(2)), + //RIGHTSTR(String("abcde"), Int(2)), + LENGTH(String("jose")), + //LPAD(String("Hi"), Int(5), String("xy")), + //RPAD(String("Hi"), Int(5), String("xy")), + //MD5(AllTypes.VarCharPtr), + //REPEAT(AllTypes.Text, Int(33)), + REPLACE(AllTypes.Char, String("BA"), String("AB")), + //REVERSE(AllTypes.VarCharPtr), + SUBSTR(AllTypes.CharPtr, Int(3)), + SUBSTR(AllTypes.CharPtr, Int(3), Int(2)), + ).FROM(AllTypes) + + dest := []struct{}{} + err := query.Query(sampleDB, &dest) + + require.NoError(t, err) +} + +func TestReservedWord(t *testing.T) { + stmt := SELECT(ReservedWords.AllColumns). + FROM(ReservedWords) + + testutils.AssertDebugStatementSql(t, stmt, strings.Replace(` +SELECT ''ReservedWords''.''column'' AS "ReservedWords.column", + ''ReservedWords''.use AS "ReservedWords.use", + ''ReservedWords''.ceil AS "ReservedWords.ceil", + ''ReservedWords''.''commit'' AS "ReservedWords.commit", + ''ReservedWords''.''create'' AS "ReservedWords.create", + ''ReservedWords''.''default'' AS "ReservedWords.default", + ''ReservedWords''.''desc'' AS "ReservedWords.desc", + ''ReservedWords''.empty AS "ReservedWords.empty", + ''ReservedWords''.float AS "ReservedWords.float", + ''ReservedWords''.''join'' AS "ReservedWords.join", + ''ReservedWords''.''like'' AS "ReservedWords.like", + ''ReservedWords''.max AS "ReservedWords.max", + ''ReservedWords''.rank AS "ReservedWords.rank" +FROM ''ReservedWords''; +`, "''", "`", -1)) + + var dest model.ReservedWords + err := stmt.Query(sampleDB, &dest) + require.NoError(t, err) + require.Equal(t, dest, model.ReservedWords{ + Column: "Column", + Use: "CHECK", + Ceil: "CEIL", + Commit: "COMMIT", + Create: "CREATE", + Default: "DEFAULT", + Desc: "DESC", + Empty: "EMPTY", + Float: "FLOAT", + Join: "JOIN", + Like: "LIKE", + Max: "MAX", + Rank: "RANK", + }) +} + +func TestExactDecimals(t *testing.T) { + + type exactDecimals struct { + model.ExactDecimals + Decimal decimal.Decimal + DecimalPtr decimal.Decimal + } + + t.Run("should query decimal", func(t *testing.T) { + query := SELECT( + ExactDecimals.AllColumns, + ).FROM( + ExactDecimals, + ).WHERE(ExactDecimals.Decimal.EQ(String("1.11111111111111111111"))) + + var result exactDecimals + + err := query.Query(sampleDB, &result) + require.NoError(t, err) + + require.Equal(t, "1.11111111111111111111", result.Decimal.String()) + require.Equal(t, "0", result.DecimalPtr.String()) // NULL + + require.Equal(t, "1.11111111111111111111", result.ExactDecimals.Decimal) // precision loss + require.Equal(t, (*string)(nil), result.ExactDecimals.DecimalPtr) + require.Equal(t, "2.22222222222222222222", result.ExactDecimals.Numeric) + require.Equal(t, (*string)(nil), result.ExactDecimals.NumericPtr) // NULL + }) + + t.Run("should insert decimal", func(t *testing.T) { + + insertQuery := ExactDecimals.INSERT( + ExactDecimals.AllColumns, + ).MODEL( + exactDecimals{ + ExactDecimals: model.ExactDecimals{ + // overwritten by wrapped(exactDecimals) scope + Decimal: "0.1", + DecimalPtr: nil, + + // not overwritten + Numeric: "6.7", + NumericPtr: testutils.StringPtr("7.7"), + }, + Decimal: decimal.RequireFromString("91.23"), + DecimalPtr: decimal.RequireFromString("45.67"), + }, + ).RETURNING(ExactDecimals.AllColumns) + + testutils.AssertDebugStatementSql(t, insertQuery, strings.Replace(` +INSERT INTO exact_decimals (decimal, decimal_ptr, numeric, numeric_ptr) +VALUES ('91.23', '45.67', '6.7', '7.7') +RETURNING exact_decimals.decimal AS "exact_decimals.decimal", + exact_decimals.decimal_ptr AS "exact_decimals.decimal_ptr", + exact_decimals.numeric AS "exact_decimals.numeric", + exact_decimals.numeric_ptr AS "exact_decimals.numeric_ptr"; +`, "''", "`", -1)) + + tx := beginSampleDBTx(t) + defer tx.Rollback() + + var result exactDecimals + + err := insertQuery.Query(tx, &result) + require.NoError(t, err) + + require.Equal(t, "91.23", result.Decimal.String()) + require.Equal(t, "45.67", result.DecimalPtr.String()) + + require.Equal(t, "6.7", result.ExactDecimals.Numeric) + require.Equal(t, "7.7", *result.ExactDecimals.NumericPtr) + require.Equal(t, "91.23", result.ExactDecimals.Decimal) + require.Equal(t, "45.67", *result.ExactDecimals.DecimalPtr) + }) +} + +var timeT = time.Date(2009, 11, 17, 20, 34, 58, 651387237, time.UTC) + +func TestDateExpressions(t *testing.T) { + + query := AllTypes.SELECT( + //Date(2009, 11, 17, 2, MONTH, 1, DAY), + + //DateT(timeT, START_OF_THE_MONTH), + AllTypes.Date.AS("date"), + DATE("2009-11-17").AS("date1"), + DATE("2013-10-07 08:23:19.120", DAYS(1)).AS("date2"), + DATE(AllTypes.Date, START_OF_YEAR, DAYS(2)).AS("date3"), + DATE(timeT, START_OF_MONTH).AS("date3"), + DATE("now", WEEKDAY(1)).AS("date4"), + DATE(timeT.Unix(), UNIXEPOCH).AS("date5"), + DATE(time.Now(), UTC).AS("date6"), + DATE(time.Now().UTC(), LOCALTIME).AS("date7"), + + AllTypes.Date.EQ(AllTypes.Date), + AllTypes.Date.EQ(Date(2019, 6, 6)), + + AllTypes.DatePtr.NOT_EQ(AllTypes.Date), + AllTypes.DatePtr.NOT_EQ(Date(2019, 1, 6)), + + AllTypes.Date.IS_DISTINCT_FROM(AllTypes.Date).AS("distinct1"), + AllTypes.Date.IS_DISTINCT_FROM(Date(2008, 7, 4)).AS("distinct2"), + + AllTypes.Date.IS_NOT_DISTINCT_FROM(AllTypes.Date), + AllTypes.Date.IS_NOT_DISTINCT_FROM(Date(2019, 3, 6)), + + AllTypes.Date.LT(AllTypes.Date), + AllTypes.Date.LT(Date(2019, 4, 6)), + + AllTypes.Date.LT_EQ(AllTypes.Date), + AllTypes.Date.LT_EQ(Date(2019, 5, 5)), + + AllTypes.Date.GT(AllTypes.Date), + AllTypes.Date.GT(Date(2019, 1, 4)), + + AllTypes.Date.GT_EQ(AllTypes.Date), + AllTypes.Date.GT_EQ(Date(2019, 2, 3)), + + //AllTypes.Date.ADD(INTERVAL2(2, HOUR)), + //AllTypes.Date.ADD(INTERVAL2(1, DAY, 7, MONTH)), + //AllTypes.Date.ADD(INTERVALd(25 * time.Hour + 100 * time.Millisecond)), + //AllTypes.Date.ADD(INTERVALd(-25 * time.Hour - 100 * time.Millisecond)), + // + //AllTypes.Date.SUB(INTERVAL(20, MINUTE)), + //AllTypes.Date.SUB(INTERVALe(AllTypes.SmallInt, MINUTE)), + //AllTypes.Date.SUB(INTERVALd(3*time.Minute)), + + CURRENT_DATE().AS("current_date"), + ) + + var dest struct { + Date string + Date1 time.Time + Date2 string + Date3 time.Time + Date4 string + Date5 time.Time + Date6 string + Date7 time.Time + Distinct1 bool + Distinct2 bool + CurrentDate time.Time + } + err := query.Query(sampleDB, &dest) + require.NoError(t, err) + + require.Equal(t, dest.Date, "2008-07-04T00:00:00Z") + require.Equal(t, dest.Date1.Unix(), int64(1258416000)) +} + +func TestTimeExpressions(t *testing.T) { + + query := AllTypes.SELECT( + TIME(AllTypes.Time).AS("time1"), + TIME(timeT).AS("time2"), + TIME("04:23:19.120-04:00", HOURS(1), MINUTES(2), SECONDS(1.234)).AS("time3"), + TIME(timeT.Unix(), UNIXEPOCH).AS("time4"), + TIME(time.Now(), UTC).AS("time5"), + TIME(time.Now().UTC(), LOCALTIME).AS("time6"), + + Time(timeT.Clock()), + + AllTypes.Time.EQ(AllTypes.Time), + AllTypes.Time.EQ(Time(23, 6, 6)), + AllTypes.Time.EQ(Time(22, 6, 6, 11*time.Millisecond)), + AllTypes.Time.EQ(Time(21, 6, 6, 11111*time.Microsecond)), + + AllTypes.TimePtr.NOT_EQ(AllTypes.Time), + AllTypes.TimePtr.NOT_EQ(Time(20, 16, 6)), + + AllTypes.Time.IS_DISTINCT_FROM(AllTypes.Time), + AllTypes.Time.IS_DISTINCT_FROM(Time(19, 26, 6)), + + AllTypes.Time.IS_NOT_DISTINCT_FROM(AllTypes.Time), + AllTypes.Time.IS_NOT_DISTINCT_FROM(Time(18, 36, 6)), + + AllTypes.Time.LT(AllTypes.Time), + AllTypes.Time.LT(Time(17, 46, 6)), + + AllTypes.Time.LT_EQ(AllTypes.Time), + AllTypes.Time.LT_EQ(Time(16, 56, 56)), + + AllTypes.Time.GT(AllTypes.Time), + AllTypes.Time.GT(Time(15, 16, 46)), + + AllTypes.Time.GT_EQ(AllTypes.Time), + AllTypes.Time.GT_EQ(Time(14, 26, 36)), + + //AllTypes.Time.ADD(INTERVAL(10, MINUTE)), + //AllTypes.Time.ADD(INTERVALe(AllTypes.Integer, MINUTE)), + //AllTypes.Time.ADD(INTERVALd(3*time.Hour)), + // + //AllTypes.Time.SUB(INTERVAL(20, MINUTE)), + //AllTypes.Time.SUB(INTERVALe(AllTypes.SmallInt, MINUTE)), + //AllTypes.Time.SUB(INTERVALd(3*time.Minute)), + // + //AllTypes.Time.ADD(INTERVAL(20, MINUTE)).SUB(INTERVAL(11, HOUR)), + + CURRENT_TIME(), + ) + + var dest struct { + Time1 string + Time2 time.Time + Time3 string + Time4 time.Time + Time5 string + Time6 time.Time + } + err := query.Query(sampleDB, &dest) + require.NoError(t, err) + + require.Equal(t, dest.Time1, "10:11:12") + require.Equal(t, dest.Time2.UTC().String(), "0000-01-01 20:34:58 +0000 UTC") + require.Equal(t, dest.Time3, "09:25:20") +} + +func TestDateTimeExpressions(t *testing.T) { + + var dateTime = DateTime(2019, 6, 6, 10, 2, 46) + + query := SELECT( + DATETIME("now").AS("now"), + DATETIME("2013-10-07T08:23:19.120Z", YEARS(2), MONTHS(1), DAYS(1)).AS("datetime1"), + DATETIME(AllTypes.DateTime, MONTHS(1), DAYS(1)).AS("datetime2"), + DATETIME(timeT.Unix(), UNIXEPOCH).AS("datetime3"), + DATETIME(time.Now(), UTC).AS("datetime4"), + DATETIME(timeT.UTC(), LOCALTIME).AS("datetime5"), + + JULIANDAY(timeT, DAYS(1)).AS("JulianDay"), + STRFTIME(String("%H:%M"), timeT, SECONDS(1.22)).AS("strftime"), + + AllTypes.DateTime.EQ(AllTypes.DateTime), + AllTypes.DateTime.EQ(dateTime), + + AllTypes.DateTimePtr.NOT_EQ(AllTypes.DateTime), + AllTypes.DateTimePtr.NOT_EQ(DateTime(2019, 6, 6, 10, 2, 46, 100*time.Millisecond)), + + AllTypes.DateTime.IS_DISTINCT_FROM(AllTypes.DateTime), + AllTypes.DateTime.IS_DISTINCT_FROM(dateTime), + + AllTypes.DateTime.IS_NOT_DISTINCT_FROM(AllTypes.DateTime), + AllTypes.DateTime.IS_NOT_DISTINCT_FROM(dateTime), + + AllTypes.DateTime.LT(AllTypes.DateTime), + AllTypes.DateTime.LT(dateTime), + + AllTypes.DateTime.LT_EQ(AllTypes.DateTime), + AllTypes.DateTime.LT_EQ(dateTime), + + AllTypes.DateTime.GT(AllTypes.DateTime), + AllTypes.DateTime.GT(dateTime), + + AllTypes.DateTime.GT_EQ(AllTypes.DateTime), + AllTypes.DateTime.GT_EQ(dateTime), + + //AllTypes.DateTime.ADD(INTERVAL("05:10:20.000100", HOUR_MICROSECOND)), + //AllTypes.DateTime.ADD(INTERVALe(AllTypes.BigInt, HOUR)), + //AllTypes.DateTime.ADD(INTERVALd(2*time.Hour)), + // + //AllTypes.DateTime.SUB(INTERVAL("05:10:20.000100", HOUR_MICROSECOND)), + //AllTypes.DateTime.SUB(INTERVALe(AllTypes.IntegerPtr, HOUR)), + //AllTypes.DateTime.SUB(INTERVALd(3*time.Hour)), + + CURRENT_TIMESTAMP(), + ).FROM(AllTypes) + + var dest struct { + Now time.Time + DateTime1 time.Time + DateTime2 time.Time + DateTime3 time.Time + DateTime4 time.Time + DateTime5 time.Time + JulianDay float64 + StrfTime string + } + + err := query.Query(sampleDB, &dest) + require.NoError(t, err) + require.True(t, dest.Now.After(time.Now().Add(-1*time.Minute))) + require.Equal(t, dest.DateTime1.String(), "2015-11-08 08:23:19 +0000 UTC") + require.Equal(t, dest.DateTime2.String(), "2012-01-19 13:17:17 +0000 UTC") + require.Equal(t, dest.DateTime3.String(), "2009-11-17 20:34:58 +0000 UTC") + require.True(t, dest.DateTime4.After(time.Now().Add(-1*time.Minute))) + require.Equal(t, dest.DateTime5.String(), "2009-11-17 21:34:58 +0000 UTC") + require.Equal(t, dest.JulianDay, 2.4551543576232754e+06) + require.Equal(t, dest.StrfTime, "20:34") +} diff --git a/tests/sqlite/cast_test.go b/tests/sqlite/cast_test.go new file mode 100644 index 0000000..a20a60c --- /dev/null +++ b/tests/sqlite/cast_test.go @@ -0,0 +1,41 @@ +package sqlite + +import ( + "github.com/go-jet/jet/v2/internal/testutils" + . "github.com/go-jet/jet/v2/sqlite" + "github.com/stretchr/testify/require" + "testing" +) + +func TestCast(t *testing.T) { + query := SELECT( + CAST(String("test")).AS("CHARACTER").AS("result.AS1"), + CAST(Float(11.33)).AS_TEXT().AS("result.text"), + CAST(String("33.44")).AS_REAL().AS("result.real"), + CAST(String("33")).AS_INTEGER().AS("result.integer"), + CAST(String("Blob blob")).AS_BLOB().AS("result.blob"), + ) + + type Result struct { + As1 string + Text string + Real float64 + Integer int64 + Blob []byte + } + + var dest Result + + err := query.Query(db, &dest) + require.NoError(t, err) + + testutils.AssertDeepEqual(t, dest, Result{ + As1: "test", + Text: "11.33", + Real: 33.44, + Integer: 33, + Blob: []byte("Blob blob"), + }) + + requireLogged(t, query) +} diff --git a/tests/sqlite/delete_test.go b/tests/sqlite/delete_test.go new file mode 100644 index 0000000..7045772 --- /dev/null +++ b/tests/sqlite/delete_test.go @@ -0,0 +1,83 @@ +package sqlite + +import ( + "context" + "testing" + "time" + + "github.com/go-jet/jet/v2/internal/testutils" + . "github.com/go-jet/jet/v2/sqlite" + "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/test_sample/model" + . "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/test_sample/table" + "github.com/stretchr/testify/require" +) + +func TestDelete_WHERE_RETURNING(t *testing.T) { + tx := beginSampleDBTx(t) + defer tx.Rollback() + + var expectedSQL = ` +DELETE FROM link +WHERE link.name IN ('Bing', 'Yahoo') +RETURNING link.id AS "link.id", + link.url AS "link.url", + link.name AS "link.name", + link.description AS "link.description"; +` + deleteStmt := Link.DELETE(). + WHERE(Link.Name.IN(String("Bing"), String("Yahoo"))). + RETURNING(Link.AllColumns) + + testutils.AssertDebugStatementSql(t, deleteStmt, expectedSQL, "Bing", "Yahoo") + var dest []model.Link + err := deleteStmt.Query(tx, &dest) + require.NoError(t, err) + require.Len(t, dest, 2) + requireLogged(t, deleteStmt) +} + +func TestDeleteWithWhereOrderByLimit(t *testing.T) { + t.SkipNow() // Until https://github.com/mattn/go-sqlite3/pull/802 is fixed + tx := beginSampleDBTx(t) + defer tx.Rollback() + + sampleDB.Stats() + + var expectedSQL = ` +DELETE FROM link +WHERE link.name IN ('Bing', 'Yahoo') +ORDER BY link.name +LIMIT 1; +` + deleteStmt := Link.DELETE(). + WHERE(Link.Name.IN(String("Bing"), String("Yahoo"))). + ORDER_BY(Link.Name). + LIMIT(1) + + testutils.AssertDebugStatementSql(t, deleteStmt, expectedSQL, "Bing", "Yahoo", int64(1)) + testutils.AssertExec(t, deleteStmt, tx, 1) + requireLogged(t, deleteStmt) +} + +func TestDeleteContextDeadlineExceeded(t *testing.T) { + tx := beginSampleDBTx(t) + defer tx.Rollback() + + deleteStmt := Link. + DELETE(). + WHERE(Link.Name.IN(String("Bing"), String("Yahoo"))) + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Microsecond) + defer cancel() + + time.Sleep(10 * time.Millisecond) + + dest := []model.Link{} + err := deleteStmt.QueryContext(ctx, tx, &dest) + require.Error(t, err, "context deadline exceeded") + + _, err = deleteStmt.ExecContext(ctx, tx) + require.Error(t, err, "context deadline exceeded") + + requireLogged(t, deleteStmt) +} diff --git a/tests/sqlite/generator_test.go b/tests/sqlite/generator_test.go new file mode 100644 index 0000000..ac7ab5d --- /dev/null +++ b/tests/sqlite/generator_test.go @@ -0,0 +1,298 @@ +package sqlite + +import ( + "github.com/go-jet/jet/v2/generator/sqlite" + "github.com/go-jet/jet/v2/internal/testutils" + "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/sakila/model" + "github.com/go-jet/jet/v2/tests/internal/utils/repo" + "github.com/stretchr/testify/require" + "io/ioutil" + "os" + "os/exec" + "reflect" + "testing" +) + +func TestGeneratedModel(t *testing.T) { + actor := model.Actor{} + + require.Equal(t, reflect.TypeOf(actor.ActorID).String(), "int32") + actorIDField, ok := reflect.TypeOf(actor).FieldByName("ActorID") + require.True(t, ok) + require.Equal(t, actorIDField.Tag.Get("sql"), "primary_key") + require.Equal(t, reflect.TypeOf(actor.FirstName).String(), "string") + require.Equal(t, reflect.TypeOf(actor.LastName).String(), "string") + require.Equal(t, reflect.TypeOf(actor.LastUpdate).String(), "time.Time") + + filmActor := model.FilmActor{} + + require.Equal(t, reflect.TypeOf(filmActor.FilmID).String(), "int32") + filmIDField, ok := reflect.TypeOf(filmActor).FieldByName("FilmID") + require.True(t, ok) + require.Equal(t, filmIDField.Tag.Get("sql"), "primary_key") + + require.Equal(t, reflect.TypeOf(filmActor.ActorID).String(), "int32") + actorIDField, ok = reflect.TypeOf(filmActor).FieldByName("ActorID") + require.True(t, ok) + require.Equal(t, filmIDField.Tag.Get("sql"), "primary_key") + + staff := model.Staff{} + + require.Equal(t, reflect.TypeOf(staff.Email).String(), "*string") + require.Equal(t, reflect.TypeOf(staff.Picture).String(), "*[]uint8") +} + +var testDatabaseFilePath = repo.GetTestDataFilePath("/init/sqlite/sakila.db") +var genDestDir = repo.GetTestsFilePath("/sqlite/.gen") + +func TestGenerator(t *testing.T) { + for i := 0; i < 3; i++ { + err := sqlite.GenerateDSN(testDatabaseFilePath, genDestDir) + require.NoError(t, err) + + assertGeneratedFiles(t) + } + + err := os.RemoveAll(genDestDir) + require.NoError(t, err) +} + +func TestCmdGenerator(t *testing.T) { + cmd := exec.Command("jet", "-source=SQLite", "-dsn=file://"+testDatabaseFilePath, "-path="+genDestDir) + + cmd.Stderr = os.Stderr + cmd.Stdout = os.Stdout + + err := cmd.Run() + require.NoError(t, err) + + assertGeneratedFiles(t) + + err = os.RemoveAll(genDestDir) + require.NoError(t, err) +} + +func assertGeneratedFiles(t *testing.T) { + // Table SQL Builder files + tableSQLBuilderFiles, err := ioutil.ReadDir(genDestDir + "/table") + require.NoError(t, err) + + testutils.AssertFileNamesEqual(t, tableSQLBuilderFiles, "actor.go", "address.go", "category.go", "city.go", "country.go", + "customer.go", "film.go", "film_actor.go", "film_category.go", "film_text.go", "inventory.go", "language.go", + "payment.go", "rental.go", "staff.go", "store.go") + + testutils.AssertFileContent(t, genDestDir+"/table/actor.go", actorSQLBuilderFile) + + // View SQL Builder files + viewSQLBuilderFiles, err := ioutil.ReadDir(genDestDir + "/view") + require.NoError(t, err) + + testutils.AssertFileNamesEqual(t, viewSQLBuilderFiles, "film_list.go", "sales_by_film_category.go", + "customer_list.go", "sales_by_store.go", "staff_list.go") + + testutils.AssertFileContent(t, genDestDir+"/view/film_list.go", filmListSQLBuilderFile) + + // Model files + modelFiles, err := ioutil.ReadDir(genDestDir + "/model") + require.NoError(t, err) + + testutils.AssertFileNamesEqual(t, modelFiles, "actor.go", "address.go", "category.go", "city.go", "country.go", + "customer.go", "film.go", "film_actor.go", "film_category.go", "film_text.go", "inventory.go", "language.go", + "payment.go", "rental.go", "staff.go", "store.go", + "film_list.go", "sales_by_film_category.go", + "customer_list.go", "sales_by_store.go", "staff_list.go") + + testutils.AssertFileContent(t, genDestDir+"/model/address.go", addressModelFile) +} + +const actorSQLBuilderFile = ` +// +// Code generated by go-jet DO NOT EDIT. +// +// WARNING: Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated +// + +package table + +import ( + "github.com/go-jet/jet/v2/sqlite" +) + +var Actor = newActorTable("", "actor", "") + +type actorTable struct { + sqlite.Table + + //Columns + ActorID sqlite.ColumnInteger + FirstName sqlite.ColumnString + LastName sqlite.ColumnString + LastUpdate sqlite.ColumnTimestamp + + AllColumns sqlite.ColumnList + MutableColumns sqlite.ColumnList +} + +type ActorTable struct { + actorTable + + EXCLUDED actorTable +} + +// AS creates new ActorTable with assigned alias +func (a ActorTable) AS(alias string) *ActorTable { + return newActorTable(a.SchemaName(), a.TableName(), alias) +} + +// Schema creates new ActorTable with assigned schema name +func (a ActorTable) FromSchema(schemaName string) *ActorTable { + return newActorTable(schemaName, a.TableName(), a.Alias()) +} + +func newActorTable(schemaName, tableName, alias string) *ActorTable { + return &ActorTable{ + actorTable: newActorTableImpl(schemaName, tableName, alias), + EXCLUDED: newActorTableImpl("", "excluded", ""), + } +} + +func newActorTableImpl(schemaName, tableName, alias string) actorTable { + var ( + ActorIDColumn = sqlite.IntegerColumn("actor_id") + FirstNameColumn = sqlite.StringColumn("first_name") + LastNameColumn = sqlite.StringColumn("last_name") + LastUpdateColumn = sqlite.TimestampColumn("last_update") + allColumns = sqlite.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, LastUpdateColumn} + mutableColumns = sqlite.ColumnList{FirstNameColumn, LastNameColumn, LastUpdateColumn} + ) + + return actorTable{ + Table: sqlite.NewTable(schemaName, tableName, alias, allColumns...), + + //Columns + ActorID: ActorIDColumn, + FirstName: FirstNameColumn, + LastName: LastNameColumn, + LastUpdate: LastUpdateColumn, + + AllColumns: allColumns, + MutableColumns: mutableColumns, + } +} +` + +const filmListSQLBuilderFile = ` +// +// Code generated by go-jet DO NOT EDIT. +// +// WARNING: Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated +// + +package view + +import ( + "github.com/go-jet/jet/v2/sqlite" +) + +var FilmList = newFilmListTable("", "film_list", "") + +type filmListTable struct { + sqlite.Table + + //Columns + Fid sqlite.ColumnInteger + Title sqlite.ColumnString + Description sqlite.ColumnString + Category sqlite.ColumnString + Price sqlite.ColumnFloat + Length sqlite.ColumnInteger + Rating sqlite.ColumnString + Actors sqlite.ColumnString + + AllColumns sqlite.ColumnList + MutableColumns sqlite.ColumnList +} + +type FilmListTable struct { + filmListTable + + EXCLUDED filmListTable +} + +// AS creates new FilmListTable with assigned alias +func (a FilmListTable) AS(alias string) *FilmListTable { + return newFilmListTable(a.SchemaName(), a.TableName(), alias) +} + +// Schema creates new FilmListTable with assigned schema name +func (a FilmListTable) FromSchema(schemaName string) *FilmListTable { + return newFilmListTable(schemaName, a.TableName(), a.Alias()) +} + +func newFilmListTable(schemaName, tableName, alias string) *FilmListTable { + return &FilmListTable{ + filmListTable: newFilmListTableImpl(schemaName, tableName, alias), + EXCLUDED: newFilmListTableImpl("", "excluded", ""), + } +} + +func newFilmListTableImpl(schemaName, tableName, alias string) filmListTable { + var ( + FidColumn = sqlite.IntegerColumn("FID") + TitleColumn = sqlite.StringColumn("title") + DescriptionColumn = sqlite.StringColumn("description") + CategoryColumn = sqlite.StringColumn("category") + PriceColumn = sqlite.FloatColumn("price") + LengthColumn = sqlite.IntegerColumn("length") + RatingColumn = sqlite.StringColumn("rating") + ActorsColumn = sqlite.StringColumn("actors") + allColumns = sqlite.ColumnList{FidColumn, TitleColumn, DescriptionColumn, CategoryColumn, PriceColumn, LengthColumn, RatingColumn, ActorsColumn} + mutableColumns = sqlite.ColumnList{FidColumn, TitleColumn, DescriptionColumn, CategoryColumn, PriceColumn, LengthColumn, RatingColumn, ActorsColumn} + ) + + return filmListTable{ + Table: sqlite.NewTable(schemaName, tableName, alias, allColumns...), + + //Columns + Fid: FidColumn, + Title: TitleColumn, + Description: DescriptionColumn, + Category: CategoryColumn, + Price: PriceColumn, + Length: LengthColumn, + Rating: RatingColumn, + Actors: ActorsColumn, + + AllColumns: allColumns, + MutableColumns: mutableColumns, + } +} +` + +const addressModelFile = ` +// +// Code generated by go-jet DO NOT EDIT. +// +// WARNING: Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated +// + +package model + +import ( + "time" +) + +type Address struct { + AddressID int32 ` + "`sql:\"primary_key\"`" + ` + Address string + Address2 *string + District string + CityID int32 + PostalCode *string + Phone string + LastUpdate time.Time +} +` diff --git a/tests/sqlite/insert_test.go b/tests/sqlite/insert_test.go new file mode 100644 index 0000000..f5939bb --- /dev/null +++ b/tests/sqlite/insert_test.go @@ -0,0 +1,393 @@ +package sqlite + +import ( + "context" + "math/rand" + + "testing" + "time" + + "github.com/go-jet/jet/v2/internal/testutils" + . "github.com/go-jet/jet/v2/sqlite" + "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/test_sample/model" + . "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/test_sample/table" + "github.com/stretchr/testify/require" +) + +func TestInsertValues(t *testing.T) { + tx := beginSampleDBTx(t) + defer tx.Rollback() + + insertQuery := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description). + VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil). + VALUES(101, "http://www.google.com", "Google", "Search engine"). + VALUES(102, "http://www.yahoo.com", "Yahoo", nil) + + testutils.AssertStatementSql(t, insertQuery, ` +INSERT INTO link (id, url, name, description) +VALUES (?, ?, ?, ?), + (?, ?, ?, ?), + (?, ?, ?, ?); +`, 100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil, + 101, "http://www.google.com", "Google", "Search engine", + 102, "http://www.yahoo.com", "Yahoo", nil) + + _, err := insertQuery.Exec(tx) + require.NoError(t, err) + requireLogged(t, insertQuery) + + insertedLinks := []model.Link{} + + err = SELECT(Link.AllColumns). + FROM(Link). + WHERE(Link.ID.GT_EQ(Int(100))). + ORDER_BY(Link.ID). + Query(tx, &insertedLinks) + + require.NoError(t, err) + require.Equal(t, len(insertedLinks), 3) + testutils.AssertDeepEqual(t, insertedLinks[0], postgreTutorial) + testutils.AssertDeepEqual(t, insertedLinks[1], model.Link{ + ID: 101, + URL: "http://www.google.com", + Name: "Google", + Description: testutils.StringPtr("Search engine"), + }) + testutils.AssertDeepEqual(t, insertedLinks[2], model.Link{ + ID: 102, + URL: "http://www.yahoo.com", + Name: "Yahoo", + }) +} + +var postgreTutorial = model.Link{ + ID: 100, + URL: "http://www.postgresqltutorial.com", + Name: "PostgreSQL Tutorial", +} + +func TestInsertEmptyColumnList(t *testing.T) { + tx := beginSampleDBTx(t) + defer tx.Rollback() + + expectedSQL := ` +INSERT INTO link +VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', NULL); +` + + stmt := Link.INSERT(). + VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil) + + testutils.AssertDebugStatementSql(t, stmt, expectedSQL, + 100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil) + + _, err := stmt.Exec(tx) + require.NoError(t, err) + requireLogged(t, stmt) + + insertedLinks := []model.Link{} + + err = SELECT(Link.AllColumns). + FROM(Link). + WHERE(Link.ID.GT_EQ(Int(100))). + ORDER_BY(Link.ID). + Query(tx, &insertedLinks) + + require.NoError(t, err) + require.Equal(t, len(insertedLinks), 1) + testutils.AssertDeepEqual(t, insertedLinks[0], postgreTutorial) +} + +func TestInsertModelObject(t *testing.T) { + tx := beginSampleDBTx(t) + defer tx.Rollback() + + linkData := model.Link{ + URL: "http://www.duckduckgo.com", + Name: "Duck Duck go", + } + + query := Link.INSERT(Link.URL, Link.Name). + MODEL(linkData) + + testutils.AssertDebugStatementSql(t, query, ` +INSERT INTO link (url, name) +VALUES ('http://www.duckduckgo.com', 'Duck Duck go'); +`, "http://www.duckduckgo.com", "Duck Duck go") + + _, err := query.Exec(tx) + require.NoError(t, err) +} + +func TestInsertModelObjectEmptyColumnList(t *testing.T) { + tx := beginSampleDBTx(t) + defer tx.Rollback() + + var expectedSQL = ` +INSERT INTO link +VALUES (1000, 'http://www.duckduckgo.com', 'Duck Duck go', NULL); +` + + linkData := model.Link{ + ID: 1000, + URL: "http://www.duckduckgo.com", + Name: "Duck Duck go", + } + + query := Link. + INSERT(). + MODEL(linkData) + + testutils.AssertDebugStatementSql(t, query, expectedSQL, int32(1000), "http://www.duckduckgo.com", "Duck Duck go", nil) + + _, err := query.Exec(tx) + require.NoError(t, err) +} + +func TestInsertModelsObject(t *testing.T) { + tx := beginSampleDBTx(t) + defer tx.Rollback() + + expectedSQL := ` +INSERT INTO link (url, name) +VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial'), + ('http://www.google.com', 'Google'), + ('http://www.yahoo.com', 'Yahoo'); +` + + tutorial := model.Link{ + URL: "http://www.postgresqltutorial.com", + Name: "PostgreSQL Tutorial", + } + google := model.Link{ + URL: "http://www.google.com", + Name: "Google", + } + yahoo := model.Link{ + URL: "http://www.yahoo.com", + Name: "Yahoo", + } + + query := Link. + INSERT(Link.URL, Link.Name). + MODELS([]model.Link{ + tutorial, + google, + yahoo, + }) + + testutils.AssertDebugStatementSql(t, query, expectedSQL, + "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", + "http://www.google.com", "Google", + "http://www.yahoo.com", "Yahoo") + + _, err := query.Exec(tx) + require.NoError(t, err) +} + +func TestInsertUsingMutableColumns(t *testing.T) { + tx := beginSampleDBTx(t) + defer tx.Rollback() + + var expectedSQL = ` +INSERT INTO link (url, name, description) +VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', NULL), + ('http://www.google.com', 'Google', NULL), + ('http://www.google.com', 'Google', NULL), + ('http://www.yahoo.com', 'Yahoo', NULL); +` + + google := model.Link{ + URL: "http://www.google.com", + Name: "Google", + } + + yahoo := model.Link{ + URL: "http://www.yahoo.com", + Name: "Yahoo", + } + + stmt := Link. + INSERT(Link.MutableColumns). + VALUES("http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil). + MODEL(google). + MODELS([]model.Link{google, yahoo}) + + testutils.AssertDebugStatementSql(t, stmt, expectedSQL, + "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil, + "http://www.google.com", "Google", nil, + "http://www.google.com", "Google", nil, + "http://www.yahoo.com", "Yahoo", nil) + + _, err := stmt.Exec(tx) + require.NoError(t, err) +} + +func TestInsertQuery(t *testing.T) { + tx := beginSampleDBTx(t) + defer tx.Rollback() + + var expectedSQL = ` +INSERT INTO link (url, name) +SELECT link.url AS "link.url", + link.name AS "link.name" +FROM link +WHERE link.id = 24; +` + query := Link.INSERT(Link.URL, Link.Name). + QUERY( + SELECT(Link.URL, Link.Name). + FROM(Link). + WHERE(Link.ID.EQ(Int(24))), + ) + + testutils.AssertDebugStatementSql(t, query, expectedSQL, int64(24)) + + _, err := query.Exec(tx) + require.NoError(t, err) + + youtubeLinks := []model.Link{} + err = Link. + SELECT(Link.AllColumns). + WHERE(Link.Name.EQ(String("Bing"))). + Query(tx, &youtubeLinks) + + require.NoError(t, err) + require.Equal(t, len(youtubeLinks), 2) +} + +func TestInsert_DEFAULT_VALUES_RETURNING(t *testing.T) { + tx := beginSampleDBTx(t) + defer tx.Rollback() + + stmt := Link.INSERT(). + DEFAULT_VALUES(). + RETURNING(Link.AllColumns) + + testutils.AssertDebugStatementSql(t, stmt, ` +INSERT INTO link +DEFAULT VALUES +RETURNING link.id AS "link.id", + link.url AS "link.url", + link.name AS "link.name", + link.description AS "link.description"; +`) + + var link model.Link + err := stmt.Query(tx, &link) + require.NoError(t, err) + + require.EqualValues(t, link, model.Link{ + ID: 25, + URL: "www.", + Name: "_", + Description: nil, + }) +} + +func TestInsertOnConflict(t *testing.T) { + + t.Run("do nothing", func(t *testing.T) { + tx := beginSampleDBTx(t) + defer tx.Rollback() + + link := model.Link{ID: rand.Int31()} + + stmt := Link.INSERT(Link.AllColumns). + MODEL(link). + MODEL(link). + ON_CONFLICT(Link.ID).DO_NOTHING() + + testutils.AssertStatementSql(t, stmt, ` +INSERT INTO link (id, url, name, description) +VALUES (?, ?, ?, ?), + (?, ?, ?, ?) +ON CONFLICT (id) DO NOTHING; +`) + testutils.AssertExec(t, stmt, tx, 1) + requireLogged(t, stmt) + }) + + t.Run("do update", func(t *testing.T) { + tx := beginSampleDBTx(t) + defer tx.Rollback() + + stmt := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description). + VALUES(21, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil). + VALUES(22, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil). + ON_CONFLICT(Link.ID). + DO_UPDATE( + SET( + Link.ID.SET(Link.EXCLUDED.ID), + Link.URL.SET(String("http://www.postgresqltutorial2.com")), + ), + ).RETURNING(Link.AllColumns) + + testutils.AssertStatementSql(t, stmt, ` +INSERT INTO link (id, url, name, description) +VALUES (?, ?, ?, ?), + (?, ?, ?, ?) +ON CONFLICT (id) DO UPDATE + SET id = excluded.id, + url = ? +RETURNING link.id AS "link.id", + link.url AS "link.url", + link.name AS "link.name", + link.description AS "link.description"; +`) + + testutils.AssertExec(t, stmt, tx) + requireLogged(t, stmt) + }) + + t.Run("do update complex", func(t *testing.T) { + tx := beginSampleDBTx(t) + defer tx.Rollback() + + stmt := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description). + VALUES(21, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil). + ON_CONFLICT(Link.ID). + WHERE(Link.ID.MUL(Int(2)).GT(Int(10))). + DO_UPDATE( + SET( + Link.ID.SET( + IntExp(SELECT(MAXi(Link.ID).ADD(Int(1))). + FROM(Link)), + ), + ColumnList{Link.Name, Link.Description}.SET(ROW(Link.EXCLUDED.Name, String(""))), + ).WHERE(Link.Description.IS_NOT_NULL()), + ) + + testutils.AssertDebugStatementSql(t, stmt, ` +INSERT INTO link (id, url, name, description) +VALUES (21, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', NULL) +ON CONFLICT (id) WHERE (id * 2) > 10 DO UPDATE + SET id = ( + SELECT MAX(link.id) + 1 + FROM link + ), + (name, description) = (excluded.name, '') + WHERE link.description IS NOT NULL; +`) + + testutils.AssertExec(t, stmt, tx) + requireLogged(t, stmt) + }) +} + +func TestInsertContextDeadlineExceeded(t *testing.T) { + stmt := Link.INSERT(). + VALUES(1100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil) + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Microsecond) + defer cancel() + + time.Sleep(10 * time.Millisecond) + + dest := []model.Link{} + err := stmt.QueryContext(ctx, sampleDB, &dest) + require.Error(t, err, "context deadline exceeded") + + _, err = stmt.ExecContext(ctx, db) + require.Error(t, err, "context deadline exceeded") +} diff --git a/tests/sqlite/main_test.go b/tests/sqlite/main_test.go new file mode 100644 index 0000000..710f7ad --- /dev/null +++ b/tests/sqlite/main_test.go @@ -0,0 +1,90 @@ +package sqlite + +import ( + "context" + "database/sql" + "fmt" + "github.com/go-jet/jet/v2/internal/utils/throw" + "github.com/go-jet/jet/v2/sqlite" + "github.com/go-jet/jet/v2/tests/dbconfig" + "github.com/stretchr/testify/require" + "math/rand" + "os" + "os/exec" + "strings" + "testing" + "time" + + "github.com/pkg/profile" + + _ "github.com/mattn/go-sqlite3" +) + +var db *sql.DB +var sampleDB *sql.DB +var testRoot string + +func TestMain(m *testing.M) { + rand.Seed(time.Now().Unix()) + defer profile.Start().Stop() + + setTestRoot() + + var err error + db, err = sql.Open("sqlite3", "file:"+dbconfig.SakilaDBPath) + throw.OnError(err) + + _, err = db.Exec(fmt.Sprintf("ATTACH DATABASE '%s' as 'chinook';", dbconfig.ChinookDBPath)) + throw.OnError(err) + + sampleDB, err = sql.Open("sqlite3", dbconfig.TestSampleDBPath) + throw.OnError(err) + + defer db.Close() + + ret := m.Run() + + if ret != 0 { + os.Exit(ret) + } +} + +func setTestRoot() { + cmd := exec.Command("git", "rev-parse", "--show-toplevel") + byteArr, err := cmd.Output() + if err != nil { + panic(err) + } + + testRoot = strings.TrimSpace(string(byteArr)) + "/tests/" +} + +var loggedSQL string +var loggedSQLArgs []interface{} +var loggedDebugSQL string + +func init() { + sqlite.SetLogger(func(ctx context.Context, statement sqlite.PrintableStatement) { + loggedSQL, loggedSQLArgs = statement.Sql() + loggedDebugSQL = statement.DebugSql() + }) +} + +func requireLogged(t *testing.T, statement sqlite.Statement) { + query, args := statement.Sql() + require.Equal(t, loggedSQL, query) + require.Equal(t, loggedSQLArgs, args) + require.Equal(t, loggedDebugSQL, statement.DebugSql()) +} + +func beginSampleDBTx(t *testing.T) *sql.Tx { + tx, err := sampleDB.Begin() + require.NoError(t, err) + return tx +} + +func beginDBTx(t *testing.T) *sql.Tx { + tx, err := db.Begin() + require.NoError(t, err) + return tx +} diff --git a/tests/sqlite/raw_statement_test.go b/tests/sqlite/raw_statement_test.go new file mode 100644 index 0000000..974dfda --- /dev/null +++ b/tests/sqlite/raw_statement_test.go @@ -0,0 +1,121 @@ +package sqlite + +import ( + "context" + "testing" + "time" + + "github.com/go-jet/jet/v2/internal/testutils" + . "github.com/go-jet/jet/v2/sqlite" + "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/sakila/model" + "github.com/stretchr/testify/require" +) + +func TestRawStatementSelect(t *testing.T) { + stmt := RawStatement(` + SELECT actor.first_name AS "actor.first_name" + FROM actor + WHERE actor.actor_id = 2`) + + testutils.AssertStatementSql(t, stmt, ` + SELECT actor.first_name AS "actor.first_name" + FROM actor + WHERE actor.actor_id = 2; +`) + testutils.AssertDebugStatementSql(t, stmt, ` + SELECT actor.first_name AS "actor.first_name" + FROM actor + WHERE actor.actor_id = 2; +`) + var actor model.Actor + err := stmt.Query(db, &actor) + require.NoError(t, err) + require.Equal(t, actor.FirstName, "NICK") +} + +func TestRawStatementSelectWithArguments(t *testing.T) { + stmt := RawStatement(` + SELECT DISTINCT actor.actor_id AS "actor.actor_id", + actor.first_name AS "actor.first_name", + actor.last_name AS "actor.last_name", + actor.last_update AS "actor.last_update" + FROM actor + WHERE actor.actor_id IN (#actorID1, #actorID2, #actorID3) AND ((#actorID1 / #actorID2) <> (#actorID2 * #actorID3)) + ORDER BY actor.actor_id`, + RawArgs{ + "#actorID1": int64(1), + "#actorID2": int64(2), + "#actorID3": int64(3), + }, + ) + + testutils.AssertStatementSql(t, stmt, ` + SELECT DISTINCT actor.actor_id AS "actor.actor_id", + actor.first_name AS "actor.first_name", + actor.last_name AS "actor.last_name", + actor.last_update AS "actor.last_update" + FROM actor + WHERE actor.actor_id IN (?, ?, ?) AND ((? / ?) <> (? * ?)) + ORDER BY actor.actor_id; +`, int64(1), int64(2), int64(3), int64(1), int64(2), int64(2), int64(3)) + + testutils.AssertDebugStatementSql(t, stmt, ` + SELECT DISTINCT actor.actor_id AS "actor.actor_id", + actor.first_name AS "actor.first_name", + actor.last_name AS "actor.last_name", + actor.last_update AS "actor.last_update" + FROM actor + WHERE actor.actor_id IN (1, 2, 3) AND ((1 / 2) <> (2 * 3)) + ORDER BY actor.actor_id; +`) + + var actor []model.Actor + err := stmt.Query(db, &actor) + require.NoError(t, err) + + testutils.AssertDeepEqual(t, actor[1], model.Actor{ + ActorID: 2, + FirstName: "NICK", + LastName: "WAHLBERG", + LastUpdate: *testutils.TimestampWithoutTimeZone("2019-04-11 18:11:48", 2), + }) +} + +func TestRawStatementRows(t *testing.T) { + stmt := RawStatement(` + SELECT actor.actor_id AS "actor.actor_id", + actor.first_name AS "actor.first_name", + actor.last_name AS "actor.last_name", + actor.last_update AS "actor.last_update" + FROM actor + ORDER BY actor.actor_id`) + + rows, err := stmt.Rows(context.Background(), db) + require.NoError(t, err) + + for rows.Next() { + var actor model.Actor + err := rows.Scan(&actor) + require.NoError(t, err) + + require.NotEqual(t, actor.ActorID, int16(0)) + require.NotEqual(t, actor.FirstName, "") + require.NotEqual(t, actor.LastName, "") + require.NotEqual(t, actor.LastUpdate, time.Time{}) + + if actor.ActorID == 54 { + require.Equal(t, actor.ActorID, int32(54)) + require.Equal(t, actor.FirstName, "PENELOPE") + require.Equal(t, actor.LastName, "PINKETT") + require.Equal(t, actor.LastUpdate.Format(time.RFC3339), "2019-04-11T18:11:48Z") + } + } + + err = rows.Close() + require.NoError(t, err) + + err = rows.Err() + require.NoError(t, err) + + requireLogged(t, stmt) +} diff --git a/tests/sqlite/select_test.go b/tests/sqlite/select_test.go new file mode 100644 index 0000000..95527f1 --- /dev/null +++ b/tests/sqlite/select_test.go @@ -0,0 +1,749 @@ +package sqlite + +import ( + "context" + model2 "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/chinook/model" + "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/chinook/table" + "strings" + "testing" + "time" + + "github.com/go-jet/jet/v2/internal/testutils" + . "github.com/go-jet/jet/v2/sqlite" + "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/sakila/model" + . "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/sakila/table" + "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/sakila/view" + + "github.com/stretchr/testify/require" +) + +func TestSelect_ScanToStruct(t *testing.T) { + query := Actor. + SELECT(Actor.AllColumns). + DISTINCT(). + WHERE(Actor.ActorID.EQ(Int(2))) + + testutils.AssertStatementSql(t, query, ` +SELECT DISTINCT actor.actor_id AS "actor.actor_id", + actor.first_name AS "actor.first_name", + actor.last_name AS "actor.last_name", + actor.last_update AS "actor.last_update" +FROM actor +WHERE actor.actor_id = ?; +`, int64(2)) + + actor := model.Actor{} + err := query.Query(db, &actor) + + require.NoError(t, err) + + testutils.AssertDeepEqual(t, actor, actor2) + requireLogged(t, query) +} + +var actor2 = model.Actor{ + ActorID: 2, + FirstName: "NICK", + LastName: "WAHLBERG", + LastUpdate: *testutils.TimestampWithoutTimeZone("2019-04-11 18:11:48", 2), +} + +func TestSelect_ScanToSlice(t *testing.T) { + query := SELECT(Actor.AllColumns). + FROM(Actor). + ORDER_BY(Actor.ActorID) + + testutils.AssertStatementSql(t, query, ` +SELECT actor.actor_id AS "actor.actor_id", + actor.first_name AS "actor.first_name", + actor.last_name AS "actor.last_name", + actor.last_update AS "actor.last_update" +FROM actor +ORDER BY actor.actor_id; +`) + dest := []model.Actor{} + + err := query.Query(db, &dest) + + require.NoError(t, err) + + require.Equal(t, len(dest), 200) + testutils.AssertDeepEqual(t, dest[1], actor2) + + //testutils.SaveJSONFile(dest, "./testdata/results/sqlite/all_actors.json") + testutils.AssertJSONFile(t, dest, "./testdata/results/sqlite/all_actors.json") + requireLogged(t, query) +} + +func TestSelectGroupByHaving(t *testing.T) { + expectedSQL := ` +SELECT customer.customer_id AS "customer.customer_id", + customer.store_id AS "customer.store_id", + customer.first_name AS "customer.first_name", + customer.last_name AS "customer.last_name", + customer.email AS "customer.email", + customer.address_id AS "customer.address_id", + customer.active AS "customer.active", + customer.create_date AS "customer.create_date", + customer.last_update AS "customer.last_update", + SUM(payment.amount) AS "amount.sum", + AVG(payment.amount) AS "amount.avg", + MAX(payment.payment_date) AS "amount.max_date", + MAX(payment.amount) AS "amount.max", + MIN(payment.payment_date) AS "amount.min_date", + MIN(payment.amount) AS "amount.min", + COUNT(payment.amount) AS "amount.count" +FROM payment + INNER JOIN customer ON (customer.customer_id = payment.customer_id) +GROUP BY payment.customer_id +HAVING SUM(payment.amount) > 125.6 +ORDER BY payment.customer_id, SUM(payment.amount) ASC; +` + query := Payment. + INNER_JOIN(Customer, Customer.CustomerID.EQ(Payment.CustomerID)). + SELECT( + Customer.AllColumns, + + SUMf(Payment.Amount).AS("amount.sum"), + AVG(Payment.Amount).AS("amount.avg"), + MAX(Payment.PaymentDate).AS("amount.max_date"), + MAXf(Payment.Amount).AS("amount.max"), + MIN(Payment.PaymentDate).AS("amount.min_date"), + MINf(Payment.Amount).AS("amount.min"), + COUNT(Payment.Amount).AS("amount.count"), + ). + GROUP_BY(Payment.CustomerID). + HAVING( + SUMf(Payment.Amount).GT(Float(125.6)), + ). + ORDER_BY( + Payment.CustomerID, SUMf(Payment.Amount).ASC(), + ) + + testutils.AssertDebugStatementSql(t, query, expectedSQL, float64(125.6)) + + var dest []struct { + model.Customer + + Amount struct { + Sum float64 + Avg float64 + Max float64 + Min float64 + Count int64 + } `alias:"amount"` + } + + err := query.Query(db, &dest) + + require.NoError(t, err) + require.Equal(t, len(dest), 174) + //testutils.SaveJSONFile(dest, "./testdata/results/sqlite/customer_payment_sum.json") + testutils.AssertJSONFile(t, dest, "./testdata/results/sqlite/customer_payment_sum.json") + requireLogged(t, query) +} + +func TestSubQuery(t *testing.T) { + + rRatingFilms := + SELECT( + Film.FilmID, + Film.Title, + Film.Rating, + ).FROM( + Film, + ).WHERE(Film.Rating.EQ(String("R"))). + AsTable("rFilms") + + rFilmID := Film.FilmID.From(rRatingFilms) + + main := + SELECT( + Actor.AllColumns, + FilmActor.AllColumns, + rRatingFilms.AllColumns(), + ).FROM( + rRatingFilms. + INNER_JOIN(FilmActor, FilmActor.FilmID.EQ(rFilmID)). + INNER_JOIN(Actor, Actor.ActorID.EQ(FilmActor.ActorID)), + ).ORDER_BY( + rFilmID, + Actor.ActorID, + ) + + var dest []struct { + model.Film + Actors []model.Actor + } + + err := main.Query(db, &dest) + require.NoError(t, err) + + //testutils.SaveJSONFile(dest, "./testdata/results/sqlite/r_rating_films.json") + testutils.AssertJSONFile(t, dest, "./testdata/results/sqlite/r_rating_films.json") +} + +func TestSelectAndUnionInProjection(t *testing.T) { + query := UNION( + SELECT( + Payment.PaymentID, + ).FROM(Payment), + + SELECT( + STAR, + ).FROM( + SELECT(Payment.PaymentID). + FROM(Payment).LIMIT(1).OFFSET(2).AsTable("p"), + ), + ).LIMIT(1).OFFSET(10) + + testutils.AssertDebugStatementSql(t, query, ` + +SELECT payment.payment_id AS "payment.payment_id" +FROM payment + +UNION + +SELECT * +FROM ( + SELECT payment.payment_id AS "payment.payment_id" + FROM payment + LIMIT 1 + OFFSET 2 + ) AS p +LIMIT 1 +OFFSET 10; +`, int64(1), int64(2), int64(1), int64(10)) + + dest := []struct{}{} + err := query.Query(db, &dest) + require.NoError(t, err) +} + +func TestSelectUNION(t *testing.T) { + expectedSQL := ` + +SELECT payment.payment_id AS "payment.payment_id" +FROM payment +WHERE payment.payment_id > ? + +UNION + +SELECT payment.payment_id AS "payment.payment_id" +FROM payment +WHERE payment.amount < ? +LIMIT ?; +` + query := UNION( + SELECT(Payment.PaymentID). + FROM(Payment). + WHERE(Payment.PaymentID.GT(Int(11))), + + SELECT(Payment.PaymentID). + FROM(Payment). + WHERE(Payment.Amount.LT(Float(2000.0))), + ).LIMIT(1) + + testutils.AssertStatementSql(t, query, expectedSQL, int64(11), 2000.0, int64(1)) + + query2 := + SELECT( + Payment.PaymentID, + ).FROM( + Payment, + ).WHERE( + Payment.PaymentID.GT(Int(11)), + ).UNION( + SELECT(Payment.PaymentID). + FROM(Payment). + WHERE(Payment.Amount.LT(Float(2000.0))), + ).LIMIT(1) + + testutils.AssertStatementSql(t, query2, expectedSQL, int64(11), 2000.0, int64(1)) + + dest := []struct{}{} + err := query.Query(db, &dest) + require.NoError(t, err) +} + +func TestSelectUNION_ALL(t *testing.T) { + expectedSQL := ` + +SELECT payment.payment_id AS "payment.payment_id" +FROM payment +WHERE payment.payment_id > ? + +UNION ALL + +SELECT payment.payment_id AS "payment.payment_id" +FROM payment +WHERE payment.amount < ? +LIMIT ?; +` + query := UNION_ALL( + SELECT(Payment.PaymentID). + FROM(Payment). + WHERE(Payment.PaymentID.GT(Int(11))), + + SELECT(Payment.PaymentID). + FROM(Payment). + WHERE(Payment.Amount.LT(Float(2000.0))), + ).LIMIT(1) + + testutils.AssertStatementSql(t, query, expectedSQL, int64(11), 2000.0, int64(1)) + + dest := []struct{}{} + err := query.Query(db, &dest) + require.NoError(t, err) +} + +func TestJoinQueryStruct(t *testing.T) { + + expectedSQL := ` +SELECT film_actor.actor_id AS "film_actor.actor_id", + film_actor.film_id AS "film_actor.film_id", + film_actor.last_update AS "film_actor.last_update", + film.film_id AS "film.film_id", + film.title AS "film.title", + film.description AS "film.description", + film.release_year AS "film.release_year", + film.language_id AS "film.language_id", + film.original_language_id AS "film.original_language_id", + film.rental_duration AS "film.rental_duration", + film.rental_rate AS "film.rental_rate", + film.length AS "film.length", + film.replacement_cost AS "film.replacement_cost", + film.rating AS "film.rating", + film.special_features AS "film.special_features", + film.last_update AS "film.last_update", + language.language_id AS "language.language_id", + language.name AS "language.name", + language.last_update AS "language.last_update", + actor.actor_id AS "actor.actor_id", + actor.first_name AS "actor.first_name", + actor.last_name AS "actor.last_name", + actor.last_update AS "actor.last_update", + inventory.inventory_id AS "inventory.inventory_id", + inventory.film_id AS "inventory.film_id", + inventory.store_id AS "inventory.store_id", + inventory.last_update AS "inventory.last_update", + rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + rental.last_update AS "rental.last_update" +FROM language + INNER JOIN film ON (film.language_id = language.language_id) + INNER JOIN film_actor ON (film_actor.film_id = film.film_id) + INNER JOIN actor ON (actor.actor_id = film_actor.actor_id) + LEFT JOIN inventory ON (inventory.film_id = film.film_id) + LEFT JOIN rental ON (rental.inventory_id = inventory.inventory_id) +ORDER BY language.language_id ASC, film.film_id ASC, actor.actor_id ASC, inventory.inventory_id ASC, rental.rental_id ASC +LIMIT ?; +` + for i := 0; i < 2; i++ { + query := + SELECT( + FilmActor.AllColumns, + Film.AllColumns, + Language.AllColumns, + Actor.AllColumns, + Inventory.AllColumns, + Rental.AllColumns, + ). + FROM( + Language. + INNER_JOIN(Film, Film.LanguageID.EQ(Language.LanguageID)). + INNER_JOIN(FilmActor, FilmActor.FilmID.EQ(Film.FilmID)). + INNER_JOIN(Actor, Actor.ActorID.EQ(FilmActor.ActorID)). + LEFT_JOIN(Inventory, Inventory.FilmID.EQ(Film.FilmID)). + LEFT_JOIN(Rental, Rental.InventoryID.EQ(Inventory.InventoryID)), + ).ORDER_BY( + Language.LanguageID.ASC(), + Film.FilmID.ASC(), + Actor.ActorID.ASC(), + Inventory.InventoryID.ASC(), + Rental.RentalID.ASC(), + ). + LIMIT(1000) + + testutils.AssertStatementSql(t, query, expectedSQL, int64(1000)) + + var dest []struct { + model.Language + + Films []struct { + model.Film + + Actors []struct { + model.Actor + } + + Inventories *[]struct { + model.Inventory + + Rentals *[]model.Rental + } + } + } + + err := query.Query(db, &dest) + + require.NoError(t, err) + testutils.AssertJSONFile(t, dest, "./testdata/results/sqlite/lang_film_actor_inventory_rental.json") + } +} + +func TestExpressionWrappers(t *testing.T) { + query := SELECT( + BoolExp(Raw("true")), + IntExp(Raw("11")), + FloatExp(Raw("11.22")), + StringExp(Raw("'stringer'")), + TimeExp(Raw("'raw'")), + TimestampExp(Raw("'raw'")), + DateTimeExp(Raw("'raw'")), + DateExp(Raw("'date'")), + ) + + testutils.AssertStatementSql(t, query, ` +SELECT true, + 11, + 11.22, + 'stringer', + 'raw', + 'raw', + 'raw', + 'date'; +`) + + dest := []struct{}{} + err := query.Query(db, &dest) + require.NoError(t, err) +} + +func TestWindowFunction(t *testing.T) { + var expectedSQL = ` +SELECT AVG(payment.amount) OVER (), + AVG(payment.amount) OVER (PARTITION BY payment.customer_id), + MAX(payment.amount) OVER (ORDER BY payment.payment_date DESC), + MIN(payment.amount) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC), + SUM(payment.amount) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC ROWS BETWEEN 1 PRECEDING AND 6 FOLLOWING), + SUM(payment.amount) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), + MAX(payment.customer_id) OVER (ORDER BY payment.payment_date DESC ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING), + MIN(payment.customer_id) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC), + SUM(payment.customer_id) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC), + ROW_NUMBER() OVER (ORDER BY payment.payment_date), + RANK() OVER (ORDER BY payment.payment_date), + DENSE_RANK() OVER (ORDER BY payment.payment_date), + CUME_DIST() OVER (ORDER BY payment.payment_date), + NTILE(11) OVER (ORDER BY payment.payment_date), + LAG(payment.amount) OVER (ORDER BY payment.payment_date), + LAG(payment.amount) OVER (ORDER BY payment.payment_date), + LAG(payment.amount, 2, payment.amount) OVER (ORDER BY payment.payment_date), + LAG(payment.amount, 2, ?) OVER (ORDER BY payment.payment_date), + LEAD(payment.amount) OVER (ORDER BY payment.payment_date), + LEAD(payment.amount) OVER (ORDER BY payment.payment_date), + LEAD(payment.amount, 2, payment.amount) OVER (ORDER BY payment.payment_date), + LEAD(payment.amount, 2, ?) OVER (ORDER BY payment.payment_date), + FIRST_VALUE(payment.amount) OVER (ORDER BY payment.payment_date), + LAST_VALUE(payment.amount) OVER (ORDER BY payment.payment_date), + NTH_VALUE(payment.amount, 3) OVER (ORDER BY payment.payment_date) +FROM payment +WHERE payment.payment_id < ? +GROUP BY payment.amount, payment.customer_id, payment.payment_date; +` + query := + SELECT( + AVG(Payment.Amount).OVER(), + AVG(Payment.Amount).OVER(PARTITION_BY(Payment.CustomerID)), + MAXf(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate.DESC())), + MINf(Payment.Amount).OVER(PARTITION_BY(Payment.CustomerID).ORDER_BY(Payment.PaymentDate.DESC())), + SUMf(Payment.Amount).OVER(PARTITION_BY(Payment.CustomerID). + ORDER_BY(Payment.PaymentDate.DESC()).ROWS(PRECEDING(1), FOLLOWING(6))), + SUMf(Payment.Amount).OVER(PARTITION_BY(Payment.CustomerID). + ORDER_BY(Payment.PaymentDate.DESC()).RANGE(PRECEDING(UNBOUNDED), FOLLOWING(UNBOUNDED))), + MAXi(Payment.CustomerID).OVER(ORDER_BY(Payment.PaymentDate.DESC()).ROWS(CURRENT_ROW, FOLLOWING(UNBOUNDED))), + MINi(Payment.CustomerID).OVER(PARTITION_BY(Payment.CustomerID).ORDER_BY(Payment.PaymentDate.DESC())), + SUMi(Payment.CustomerID).OVER(PARTITION_BY(Payment.CustomerID).ORDER_BY(Payment.PaymentDate.DESC())), + ROW_NUMBER().OVER(ORDER_BY(Payment.PaymentDate)), + RANK().OVER(ORDER_BY(Payment.PaymentDate)), + DENSE_RANK().OVER(ORDER_BY(Payment.PaymentDate)), + CUME_DIST().OVER(ORDER_BY(Payment.PaymentDate)), + NTILE(11).OVER(ORDER_BY(Payment.PaymentDate)), + LAG(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)), + LAG(Payment.Amount, 2).OVER(ORDER_BY(Payment.PaymentDate)), + LAG(Payment.Amount, 2, Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)), + LAG(Payment.Amount, 2, 100).OVER(ORDER_BY(Payment.PaymentDate)), + LEAD(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)), + LEAD(Payment.Amount, 2).OVER(ORDER_BY(Payment.PaymentDate)), + LEAD(Payment.Amount, 2, Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)), + LEAD(Payment.Amount, 2, 100).OVER(ORDER_BY(Payment.PaymentDate)), + FIRST_VALUE(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)), + LAST_VALUE(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)), + NTH_VALUE(Payment.Amount, 3).OVER(ORDER_BY(Payment.PaymentDate)), + ).FROM( + Payment, + ).GROUP_BY( + Payment.Amount, + Payment.CustomerID, + Payment.PaymentDate, + ).WHERE(Payment.PaymentID.LT(Int(10))) + + testutils.AssertStatementSql(t, query, expectedSQL, 100, 100, int64(10)) + + dest := []struct{}{} + err := query.Query(db, &dest) + require.NoError(t, err) +} + +func TestWindowClause(t *testing.T) { + var expectedSQL = ` +SELECT AVG(payment.amount) OVER (), + AVG(payment.amount) OVER (w1), + AVG(payment.amount) OVER (w2 ORDER BY payment.customer_id RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING), + AVG(payment.amount) OVER (w3 RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) +FROM payment +WHERE payment.payment_id < ? +WINDOW w1 AS (PARTITION BY payment.payment_date), w2 AS (w1), w3 AS (w2 ORDER BY payment.customer_id) +ORDER BY payment.customer_id; +` + query := SELECT( + AVG(Payment.Amount).OVER(), + AVG(Payment.Amount).OVER(Window("w1")), + AVG(Payment.Amount).OVER( + Window("w2"). + ORDER_BY(Payment.CustomerID). + RANGE(PRECEDING(UNBOUNDED), FOLLOWING(UNBOUNDED)), + ), + AVG(Payment.Amount).OVER(Window("w3").RANGE(PRECEDING(UNBOUNDED), FOLLOWING(UNBOUNDED))), + ).FROM( + Payment, + ).WHERE( + Payment.PaymentID.LT(Int(10)), + ). + WINDOW("w1").AS(PARTITION_BY(Payment.PaymentDate)). + WINDOW("w2").AS(Window("w1")). + WINDOW("w3").AS(Window("w2").ORDER_BY(Payment.CustomerID)). + ORDER_BY( + Payment.CustomerID, + ) + + testutils.AssertStatementSql(t, query, expectedSQL, int64(10)) + + dest := []struct{}{} + err := query.Query(db, &dest) + + require.NoError(t, err) +} + +func TestSimpleView(t *testing.T) { + query := + SELECT( + view.CustomerList.AllColumns, + ).FROM( + view.CustomerList, + ).ORDER_BY( + view.CustomerList.ID, + ).LIMIT(10) + + var dest []model.CustomerList + + err := query.Query(db, &dest) + require.NoError(t, err) + + require.Equal(t, len(dest), 10) + require.Equal(t, dest[2], model.CustomerList{ + ID: testutils.Int32Ptr(3), + Name: testutils.StringPtr("LINDA WILLIAMS"), + Address: testutils.StringPtr("692 Joliet Street"), + ZipCode: testutils.StringPtr("83579"), + Phone: testutils.StringPtr(" "), + City: testutils.StringPtr("Athenai"), + Country: testutils.StringPtr("Greece"), + Notes: testutils.StringPtr("active"), + Sid: testutils.Int32Ptr(1), + }) +} + +func TestJoinViewWithTable(t *testing.T) { + query := + SELECT( + view.CustomerList.AllColumns, + Rental.AllColumns, + ).FROM( + view.CustomerList. + INNER_JOIN(Rental, view.CustomerList.ID.EQ(Rental.CustomerID)), + ).ORDER_BY( + view.CustomerList.ID, + ).WHERE( + view.CustomerList.ID.LT_EQ(Int(2)), + ) + + var dest []struct { + model.CustomerList `sql:"primary_key=ID"` + Rentals []model.Rental + } + + err := query.Query(db, &dest) + require.NoError(t, err) + + require.Equal(t, len(dest), 2) + require.Equal(t, len(dest[0].Rentals), 32) + require.Equal(t, len(dest[1].Rentals), 27) +} + +func TestConditionalProjectionList(t *testing.T) { + projectionList := ProjectionList{} + + columnsToSelect := []string{"customer_id", "create_date"} + + for _, columnName := range columnsToSelect { + switch columnName { + case Customer.CustomerID.Name(): + projectionList = append(projectionList, Customer.CustomerID) + case Customer.Email.Name(): + projectionList = append(projectionList, Customer.Email) + case Customer.CreateDate.Name(): + projectionList = append(projectionList, Customer.CreateDate) + } + } + + stmt := SELECT(projectionList). + FROM(Customer). + LIMIT(3) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT customer.customer_id AS "customer.customer_id", + customer.create_date AS "customer.create_date" +FROM customer +LIMIT 3; +`) + var dest []model.Customer + err := stmt.Query(db, &dest) + require.NoError(t, err) + + require.Equal(t, len(dest), 3) +} + +func TestUseAttachedDatabase(t *testing.T) { + Artists := table.Artists.FromSchema("chinook") + Albums := table.Albums.FromSchema("chinook") + + stmt := + SELECT( + Artists.AllColumns, + Albums.AllColumns, + ).FROM( + Albums. + INNER_JOIN(Artists, Artists.ArtistId.EQ(Albums.ArtistId)), + ).ORDER_BY( + Artists.ArtistId, + ).LIMIT(10) + + testutils.AssertDebugStatementSql(t, stmt, strings.Replace(` +SELECT artists.''ArtistId'' AS "artists.ArtistId", + artists.''Name'' AS "artists.Name", + albums.''AlbumId'' AS "albums.AlbumId", + albums.''Title'' AS "albums.Title", + albums.''ArtistId'' AS "albums.ArtistId" +FROM chinook.albums + INNER JOIN chinook.artists ON (artists.''ArtistId'' = albums.''ArtistId'') +ORDER BY artists.''ArtistId'' +LIMIT 10; +`, "''", "`", -1)) + + var dest []struct { + model2.Artists + Albums []model2.Albums + } + + err := stmt.Query(db, &dest) + require.NoError(t, err) + require.Len(t, dest, 7) +} + +func TestRowsScan(t *testing.T) { + stmt := + SELECT( + Inventory.AllColumns, + ).FROM( + Inventory, + ).ORDER_BY( + Inventory.InventoryID.ASC(), + ) + + rows, err := stmt.Rows(context.Background(), db) + require.NoError(t, err) + + for rows.Next() { + var inventory model.Inventory + err = rows.Scan(&inventory) + require.NoError(t, err) + + require.NotEqual(t, inventory.InventoryID, uint32(0)) + require.NotEqual(t, inventory.FilmID, uint16(0)) + require.NotEqual(t, inventory.StoreID, uint16(0)) + require.NotEqual(t, inventory.LastUpdate, time.Time{}) + + if inventory.InventoryID == 2103 { + require.Equal(t, inventory.FilmID, int32(456)) + require.Equal(t, inventory.StoreID, int32(2)) + require.Equal(t, inventory.LastUpdate.Format(time.RFC3339), "2019-04-11T18:11:48Z") + } + } + + err = rows.Close() + require.NoError(t, err) + err = rows.Err() + require.NoError(t, err) + + requireLogged(t, stmt) +} + +func TestScanNumericToNumber(t *testing.T) { + type Number struct { + Int8 int8 + UInt8 uint8 + Int16 int16 + UInt16 uint16 + Int32 int32 + UInt32 uint32 + Int64 int64 + UInt64 uint64 + Float32 float32 + Float64 float64 + } + + numeric := CAST(String("1234567890.111")).AS_REAL() + + stmt := SELECT( + numeric.AS("number.int8"), + numeric.AS("number.uint8"), + numeric.AS("number.int16"), + numeric.AS("number.uint16"), + numeric.AS("number.int32"), + numeric.AS("number.uint32"), + numeric.AS("number.int64"), + numeric.AS("number.uint64"), + numeric.AS("number.float32"), + numeric.AS("number.float64"), + ) + + var number Number + err := stmt.Query(db, &number) + require.NoError(t, err) + + require.Equal(t, number.Int8, int8(-46)) // overflow + require.Equal(t, number.UInt8, uint8(210)) // overflow + require.Equal(t, number.Int16, int16(722)) // overflow + require.Equal(t, number.UInt16, uint16(722)) // overflow + require.Equal(t, number.Int32, int32(1234567890)) + require.Equal(t, number.UInt32, uint32(1234567890)) + require.Equal(t, number.Int64, int64(1234567890)) + require.Equal(t, number.UInt64, uint64(1234567890)) + require.Equal(t, number.Float32, float32(1.234568e+09)) + require.Equal(t, number.Float64, float64(1.234567890111e+09)) +} diff --git a/tests/sqlite/update_test.go b/tests/sqlite/update_test.go new file mode 100644 index 0000000..61135a8 --- /dev/null +++ b/tests/sqlite/update_test.go @@ -0,0 +1,290 @@ +package sqlite + +import ( + "context" + "testing" + "time" + + "github.com/go-jet/jet/v2/internal/testutils" + . "github.com/go-jet/jet/v2/sqlite" + "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/test_sample/model" + . "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/test_sample/table" + "github.com/stretchr/testify/require" +) + +func TestUpdateValues(t *testing.T) { + tx := beginSampleDBTx(t) + defer tx.Rollback() + + var expectedSQL = ` +UPDATE link +SET name = 'Bong', + url = 'http://bong.com' +WHERE link.name = 'Bing'; +` + t.Run("old version", func(t *testing.T) { + query := Link.UPDATE(Link.Name, Link.URL). + SET("Bong", "http://bong.com"). + WHERE(Link.Name.EQ(String("Bing"))) + + testutils.AssertDebugStatementSql(t, query, expectedSQL, "Bong", "http://bong.com", "Bing") + testutils.AssertExec(t, query, tx) + requireLogged(t, query) + }) + + t.Run("new version", func(t *testing.T) { + stmt := Link.UPDATE(). + SET( + Link.Name.SET(String("Bong")), + Link.URL.SET(String("http://bong.com")), + ). + WHERE(Link.Name.EQ(String("Bing"))) + + testutils.AssertDebugStatementSql(t, stmt, expectedSQL, "Bong", "http://bong.com", "Bing") + testutils.AssertExec(t, stmt, tx) + requireLogged(t, stmt) + }) + + links := []model.Link{} + + err := SELECT(Link.AllColumns). + FROM(Link). + WHERE(Link.Name.EQ(String("Bong"))). + Query(tx, &links) + + require.NoError(t, err) + require.Equal(t, len(links), 1) + testutils.AssertDeepEqual(t, links[0], model.Link{ + ID: 24, + URL: "http://bong.com", + Name: "Bong", + }) +} + +func TestUpdateWithSubQueries(t *testing.T) { + tx := beginSampleDBTx(t) + defer tx.Rollback() + + expectedSQL := ` +UPDATE link +SET name = ?, + url = ( + SELECT link.url AS "link.url" + FROM link + WHERE link.name = ? + ) +WHERE link.name = ?; +` + t.Run("old version", func(t *testing.T) { + query := Link. + UPDATE(Link.Name, Link.URL). + SET( + String("Bong"), + SELECT(Link.URL). + FROM(Link). + WHERE(Link.Name.EQ(String("Ask"))), + ). + WHERE(Link.Name.EQ(String("Bing"))) + + testutils.AssertStatementSql(t, query, expectedSQL, "Bong", "Ask", "Bing") + testutils.AssertExec(t, query, tx) + requireLogged(t, query) + }) + + t.Run("new version", func(t *testing.T) { + query := Link. + UPDATE(). + SET( + Link.Name.SET(String("Bong")), + Link.URL.SET(StringExp( + SELECT(Link.URL). + FROM(Link). + WHERE(Link.Name.EQ(String("Ask"))), + )), + ). + WHERE(Link.Name.EQ(String("Bing"))) + + testutils.AssertStatementSql(t, query, expectedSQL, "Bong", "Ask", "Bing") + testutils.AssertExec(t, query, tx) + requireLogged(t, query) + }) +} + +func TestUpdateWithModelDataAndReturning(t *testing.T) { + tx := beginSampleDBTx(t) + defer tx.Rollback() + + link := model.Link{ + ID: 20, + URL: "http://www.duckduckgo.com", + Name: "DuckDuckGo", + } + + stmt := Link.UPDATE(Link.AllColumns). + MODEL(link). + WHERE(Link.ID.EQ(Int32(link.ID))). + RETURNING( + Link.AllColumns, + String("str").AS("dest.literal"), + NOT(Bool(false)).AS("dest.unary_operator"), + Link.ID.ADD(Int(11)).AS("dest.binary_operator"), + CAST(Link.ID).AS_TEXT().AS("dest.cast_operator"), + Link.Name.LIKE(String("Bing")).AS("dest.like_operator"), + Link.Description.IS_NULL().AS("dest.is_null"), + CASE(Link.Name). + WHEN(String("Yahoo")).THEN(String("search")). + WHEN(String("GMail")).THEN(String("mail")). + ELSE(String("unknown")).AS("dest.case_operator"), + ) + + expectedSQL := ` +UPDATE link +SET id = ?, + url = ?, + name = ?, + description = ? +WHERE link.id = ? +RETURNING link.id AS "link.id", + link.url AS "link.url", + link.name AS "link.name", + link.description AS "link.description", + ? AS "dest.literal", + (NOT ?) AS "dest.unary_operator", + (link.id + ?) AS "dest.binary_operator", + CAST(link.id AS TEXT) AS "dest.cast_operator", + (link.name LIKE ?) AS "dest.like_operator", + link.description IS NULL AS "dest.is_null", + (CASE link.name WHEN ? THEN ? WHEN ? THEN ? ELSE ? END) AS "dest.case_operator"; +` + testutils.AssertStatementSql(t, stmt, expectedSQL, int32(20), "http://www.duckduckgo.com", "DuckDuckGo", nil, int32(20), + "str", false, int64(11), "Bing", "Yahoo", "search", "GMail", "mail", "unknown") + + type Dest struct { + model.Link + Literal string + UnaryOperator bool + BinaryOperator int64 + CastOperator string + LikeOperator bool + IsNull bool + CaseOperator string + } + + var dest Dest + + err := stmt.Query(tx, &dest) + require.NoError(t, err) + require.EqualValues(t, dest, Dest{ + Link: link, + Literal: "str", + UnaryOperator: true, + BinaryOperator: 31, + CastOperator: "20", + LikeOperator: false, + IsNull: true, + CaseOperator: "unknown", + }) + requireLogged(t, stmt) +} + +func TestUpdateWithModelDataAndPredefinedColumnList(t *testing.T) { + tx := beginSampleDBTx(t) + defer tx.Rollback() + + link := model.Link{ + ID: 20, + URL: "http://www.duckduckgo.com", + Name: "DuckDuckGo", + } + + updateColumnList := ColumnList{Link.Description, Link.Name, Link.URL} + + stmt := Link.UPDATE(updateColumnList). + MODEL(link). + WHERE(Link.ID.EQ(Int32(link.ID))) + + var expectedSQL = ` +UPDATE link +SET description = NULL, + name = 'DuckDuckGo', + url = 'http://www.duckduckgo.com' +WHERE link.id = 20; +` + + testutils.AssertDebugStatementSql(t, stmt, expectedSQL, nil, "DuckDuckGo", "http://www.duckduckgo.com", int32(20)) + + testutils.AssertExec(t, stmt, tx) + requireLogged(t, stmt) +} + +func TestUpdateWithModelDataAndMutableColumns(t *testing.T) { + tx := beginSampleDBTx(t) + defer tx.Rollback() + + link := model.Link{ + ID: 201, + URL: "http://www.duckduckgo.com", + Name: "DuckDuckGo", + } + + stmt := Link.UPDATE(Link.MutableColumns). + MODEL(link). + WHERE(Link.ID.EQ(Int32(link.ID))) + + var expectedSQL = ` +UPDATE link +SET url = 'http://www.duckduckgo.com', + name = 'DuckDuckGo', + description = NULL +WHERE link.id = 201; +` + + testutils.AssertDebugStatementSql(t, stmt, expectedSQL, "http://www.duckduckgo.com", "DuckDuckGo", nil, int32(201)) + testutils.AssertExec(t, stmt, tx) +} + +func TestUpdateWithInvalidModelData(t *testing.T) { + defer func() { + r := recover() + require.Equal(t, r, "missing struct field for column : id") + }() + + link := struct { + Ident int + URL string + Name string + Description *string + Rel *string + }{ + Ident: 201, + URL: "http://www.duckduckgo.com", + Name: "DuckDuckGo", + } + + stmt := Link.UPDATE(Link.AllColumns). + MODEL(link). + WHERE(Link.ID.EQ(Int(int64(link.Ident)))) + + stmt.Sql() +} + +func TestUpdateContextDeadlineExceeded(t *testing.T) { + tx := beginSampleDBTx(t) + defer tx.Rollback() + + updateStmt := Link.UPDATE(Link.Name, Link.URL). + SET("Bong", "http://bong.com"). + WHERE(Link.Name.EQ(String("Bing"))) + + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Microsecond) + defer cancel() + + time.Sleep(10 * time.Millisecond) + + dest := []model.Link{} + err := updateStmt.QueryContext(ctx, tx, &dest) + require.Error(t, err, "context deadline exceeded") + + _, err = updateStmt.ExecContext(ctx, tx) + require.Error(t, err, "context deadline exceeded") +} diff --git a/tests/sqlite/with_test.go b/tests/sqlite/with_test.go new file mode 100644 index 0000000..f2b623a --- /dev/null +++ b/tests/sqlite/with_test.go @@ -0,0 +1,234 @@ +package sqlite + +import ( + "github.com/go-jet/jet/v2/internal/testutils" + . "github.com/go-jet/jet/v2/sqlite" + . "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/sakila/table" + "github.com/stretchr/testify/require" + "strings" + "testing" +) + +func TestWITH_And_SELECT(t *testing.T) { + salesRep := CTE("sales_rep") + salesRepStaffID := Staff.StaffID.From(salesRep) + salesRepFullName := StringColumn("sales_rep_full_name").From(salesRep) + customerSalesRep := CTE("customer_sales_rep") + + stmt := WITH( + salesRep.AS( + SELECT( + Staff.StaffID, + Staff.FirstName.CONCAT(Staff.LastName).AS(salesRepFullName.Name()), + ).FROM(Staff), + ), + customerSalesRep.AS( + SELECT( + Customer.FirstName.CONCAT(Customer.LastName).AS("customer_name"), + salesRepFullName, + ).FROM( + salesRep. + INNER_JOIN(Store, Store.ManagerStaffID.EQ(salesRepStaffID)). + INNER_JOIN(Customer, Customer.StoreID.EQ(Store.StoreID)), + ), + ), + )( + SELECT(customerSalesRep.AllColumns()). + FROM(customerSalesRep), + ) + + testutils.AssertStatementSql(t, stmt, strings.Replace(` +WITH sales_rep AS ( + SELECT staff.staff_id AS "staff.staff_id", + (staff.first_name || staff.last_name) AS "sales_rep_full_name" + FROM staff +),customer_sales_rep AS ( + SELECT (customer.first_name || customer.last_name) AS "customer_name", + sales_rep.sales_rep_full_name AS "sales_rep_full_name" + FROM sales_rep + INNER JOIN store ON (store.manager_staff_id = sales_rep.''staff.staff_id'') + INNER JOIN customer ON (customer.store_id = store.store_id) +) +SELECT customer_sales_rep.customer_name AS "customer_name", + customer_sales_rep.sales_rep_full_name AS "sales_rep_full_name" +FROM customer_sales_rep; +`, "''", "`", -1)) + + var dest []struct { + CustomerName string + SalesRepFullName string + } + err := stmt.Query(db, &dest) + require.NoError(t, err) + require.Equal(t, len(dest), 599) +} + +func TestWITH_And_INSERT(t *testing.T) { + paymentsToInsert := CTE("payments_to_insert") + + stmt := WITH( + paymentsToInsert.AS( + SELECT(Payment.AllColumns). + FROM(Payment). + WHERE(Payment.Amount.LT(Float(0.5))), + ), + )( + Payment.INSERT(Payment.AllColumns). + QUERY( + SELECT( + paymentsToInsert.AllColumns(), + ).FROM( + paymentsToInsert, + ).WHERE(Bool(true)), //https://stackoverflow.com/questions/66230093/error-while-doing-upsert-in-sqlite-3-34-error-near-do-syntax-error + ).ON_CONFLICT().DO_UPDATE( + SET( + Payment.PaymentID.SET(Payment.PaymentID.ADD(Int(100000))), + ), + ), + ) + + testutils.AssertDebugStatementSql(t, stmt, strings.Replace(` +WITH payments_to_insert AS ( + SELECT payment.payment_id AS "payment.payment_id", + payment.customer_id AS "payment.customer_id", + payment.staff_id AS "payment.staff_id", + payment.rental_id AS "payment.rental_id", + payment.amount AS "payment.amount", + payment.payment_date AS "payment.payment_date", + payment.last_update AS "payment.last_update" + FROM payment + WHERE payment.amount < 0.5 +) +INSERT INTO payment (payment_id, customer_id, staff_id, rental_id, amount, payment_date, last_update) +SELECT payments_to_insert.''payment.payment_id'' AS "payment.payment_id", + payments_to_insert.''payment.customer_id'' AS "payment.customer_id", + payments_to_insert.''payment.staff_id'' AS "payment.staff_id", + payments_to_insert.''payment.rental_id'' AS "payment.rental_id", + payments_to_insert.''payment.amount'' AS "payment.amount", + payments_to_insert.''payment.payment_date'' AS "payment.payment_date", + payments_to_insert.''payment.last_update'' AS "payment.last_update" +FROM payments_to_insert +WHERE TRUE +ON CONFLICT DO UPDATE + SET payment_id = (payment.payment_id + 100000); +`, "''", "`", -1)) + + tx := beginDBTx(t) + defer tx.Rollback() + + testutils.AssertExec(t, stmt, tx, 24) +} + +func TestWITH_SELECT_UPDATE(t *testing.T) { + paymentsToUpdate := CTE("payments_to_update") + paymentsToDeleteID := Payment.PaymentID.From(paymentsToUpdate) + + stmt := WITH( + paymentsToUpdate.AS( + SELECT(Payment.AllColumns). + FROM(Payment). + WHERE(Payment.Amount.LT(Float(0.5))), + ), + )( + Payment.UPDATE(). + SET(Payment.Amount.SET(Float(0.0))). + WHERE(Payment.PaymentID.IN( + SELECT(paymentsToDeleteID). + FROM(paymentsToUpdate), + ), + ), + ) + + testutils.AssertDebugStatementSql(t, stmt, strings.Replace(` +WITH payments_to_update AS ( + SELECT payment.payment_id AS "payment.payment_id", + payment.customer_id AS "payment.customer_id", + payment.staff_id AS "payment.staff_id", + payment.rental_id AS "payment.rental_id", + payment.amount AS "payment.amount", + payment.payment_date AS "payment.payment_date", + payment.last_update AS "payment.last_update" + FROM payment + WHERE payment.amount < 0.5 +) +UPDATE payment +SET amount = 0 +WHERE payment.payment_id IN ( + SELECT payments_to_update.''payment.payment_id'' AS "payment.payment_id" + FROM payments_to_update + ); +`, "''", "`", -1)) + + tx := beginDBTx(t) + defer tx.Rollback() + + testutils.AssertExec(t, stmt, tx) +} + +func TestWITH_And_DELETE(t *testing.T) { + paymentsToDelete := CTE("payments_to_delete") + paymentsToDeleteID := Payment.PaymentID.From(paymentsToDelete) + + stmt := WITH( + paymentsToDelete.AS( + SELECT( + Payment.AllColumns, + ).FROM( + Payment, + ).WHERE( + Payment.Amount.LT(Float(0.5)), + ), + ), + )( + Payment.DELETE(). + WHERE( + Payment.PaymentID.IN( + SELECT( + paymentsToDeleteID, + ).FROM( + paymentsToDelete, + ), + ), + ), + ) + + testutils.AssertDebugStatementSql(t, stmt, strings.Replace(` +WITH payments_to_delete AS ( + SELECT payment.payment_id AS "payment.payment_id", + payment.customer_id AS "payment.customer_id", + payment.staff_id AS "payment.staff_id", + payment.rental_id AS "payment.rental_id", + payment.amount AS "payment.amount", + payment.payment_date AS "payment.payment_date", + payment.last_update AS "payment.last_update" + FROM payment + WHERE payment.amount < 0.5 +) +DELETE FROM payment +WHERE payment.payment_id IN ( + SELECT payments_to_delete.''payment.payment_id'' AS "payment.payment_id" + FROM payments_to_delete + ); +`, "''", "`", -1)) + + tx := beginDBTx(t) + defer tx.Rollback() + + testutils.AssertExec(t, stmt, tx, 24) +} + +func TestOperatorIN(t *testing.T) { + stmt := SELECT(Payment.PaymentID.IN(SELECT(Int(11)), Int(22))). + FROM(Payment) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT payment.payment_id IN (( + SELECT 11 + ), 22) +FROM payment; +`) + + var dest []struct{} + err := stmt.Query(db, &dest) + require.NoError(t, err) +} diff --git a/tests/testdata b/tests/testdata index a6c1975..946bc1e 160000 --- a/tests/testdata +++ b/tests/testdata @@ -1 +1 @@ -Subproject commit a6c1975a167645f913496131ae81d4cabc070046 +Subproject commit 946bc1e5d3e162154eade8b79ff915e4c4986efd