Add SQLBuilder support for SQLite databases.

This commit is contained in:
go-jet 2021-10-21 13:39:24 +02:00
parent d197956271
commit e8f4c2b31b
50 changed files with 5851 additions and 75 deletions

3
.gitignore vendored
View file

@ -18,4 +18,5 @@
# Test files # Test files
gen gen
.gentestdata .gentestdata
.tests/testdata/ .tests/testdata/
.gen

1
go.mod
View file

@ -9,6 +9,7 @@ require (
github.com/jackc/pgconn v1.8.1 github.com/jackc/pgconn v1.8.1
github.com/jackc/pgx/v4 v4.11.0 //tests github.com/jackc/pgx/v4 v4.11.0 //tests
github.com/lib/pq v1.7.0 github.com/lib/pq v1.7.0
github.com/mattn/go-sqlite3 v1.14.8
github.com/pkg/profile v1.5.0 //tests github.com/pkg/profile v1.5.0 //tests
github.com/shopspring/decimal v1.2.0 // tests github.com/shopspring/decimal v1.2.0 // tests
github.com/stretchr/testify v1.6.1 // tests github.com/stretchr/testify v1.6.1 // tests

2
go.sum
View file

@ -218,6 +218,8 @@ github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hd
github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ= github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ=
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
github.com/mattn/go-runewidth v0.0.2/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU= github.com/mattn/go-runewidth v0.0.2/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU=
github.com/mattn/go-sqlite3 v1.14.8 h1:gDp86IdQsN/xWjIEmr9MF6o9mpksUgh0fu+9ByFxzIU=
github.com/mattn/go-sqlite3 v1.14.8/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg= github.com/miekg/dns v1.0.14/go.mod h1:W1PPwlIAgtquWBMBEV9nkV9Cazfe8ScdGz/Lj7v3Nrg=
github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc= github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc=

View file

@ -2,7 +2,7 @@ package jet
// ROW is construct one table row from list of expressions. // ROW is construct one table row from list of expressions.
func ROW(expressions ...Expression) Expression { func ROW(expressions ...Expression) Expression {
return newFunc("ROW", expressions, nil) return NewFunc("ROW", expressions, nil)
} }
// ------------------ Mathematical functions ---------------// // ------------------ Mathematical functions ---------------//
@ -265,118 +265,118 @@ func OCTET_LENGTH(stringExpression StringExpression) IntegerExpression {
// LOWER returns string expression in lower case // LOWER returns string expression in lower case
func LOWER(stringExpression StringExpression) StringExpression { func LOWER(stringExpression StringExpression) StringExpression {
return newStringFunc("LOWER", stringExpression) return NewStringFunc("LOWER", stringExpression)
} }
// UPPER returns string expression in upper case // UPPER returns string expression in upper case
func UPPER(stringExpression StringExpression) StringExpression { func UPPER(stringExpression StringExpression) StringExpression {
return newStringFunc("UPPER", stringExpression) return NewStringFunc("UPPER", stringExpression)
} }
// BTRIM removes the longest string consisting only of characters // BTRIM removes the longest string consisting only of characters
// in characters (a space by default) from the start and end of string // in characters (a space by default) from the start and end of string
func BTRIM(stringExpression StringExpression, trimChars ...StringExpression) StringExpression { func BTRIM(stringExpression StringExpression, trimChars ...StringExpression) StringExpression {
if len(trimChars) > 0 { if len(trimChars) > 0 {
return newStringFunc("BTRIM", stringExpression, trimChars[0]) return NewStringFunc("BTRIM", stringExpression, trimChars[0])
} }
return newStringFunc("BTRIM", stringExpression) return NewStringFunc("BTRIM", stringExpression)
} }
// LTRIM removes the longest string containing only characters // LTRIM removes the longest string containing only characters
// from characters (a space by default) from the start of string // from characters (a space by default) from the start of string
func LTRIM(str StringExpression, trimChars ...StringExpression) StringExpression { func LTRIM(str StringExpression, trimChars ...StringExpression) StringExpression {
if len(trimChars) > 0 { if len(trimChars) > 0 {
return newStringFunc("LTRIM", str, trimChars[0]) return NewStringFunc("LTRIM", str, trimChars[0])
} }
return newStringFunc("LTRIM", str) return NewStringFunc("LTRIM", str)
} }
// RTRIM removes the longest string containing only characters // RTRIM removes the longest string containing only characters
// from characters (a space by default) from the end of string // from characters (a space by default) from the end of string
func RTRIM(str StringExpression, trimChars ...StringExpression) StringExpression { func RTRIM(str StringExpression, trimChars ...StringExpression) StringExpression {
if len(trimChars) > 0 { if len(trimChars) > 0 {
return newStringFunc("RTRIM", str, trimChars[0]) return NewStringFunc("RTRIM", str, trimChars[0])
} }
return newStringFunc("RTRIM", str) return NewStringFunc("RTRIM", str)
} }
// CHR returns character with the given code. // CHR returns character with the given code.
func CHR(integerExpression IntegerExpression) StringExpression { func CHR(integerExpression IntegerExpression) StringExpression {
return newStringFunc("CHR", integerExpression) return NewStringFunc("CHR", integerExpression)
} }
// CONCAT adds two or more expressions together // CONCAT adds two or more expressions together
func CONCAT(expressions ...Expression) StringExpression { func CONCAT(expressions ...Expression) StringExpression {
return newStringFunc("CONCAT", expressions...) return NewStringFunc("CONCAT", expressions...)
} }
// CONCAT_WS adds two or more expressions together with a separator. // CONCAT_WS adds two or more expressions together with a separator.
func CONCAT_WS(separator Expression, expressions ...Expression) StringExpression { func CONCAT_WS(separator Expression, expressions ...Expression) StringExpression {
return newStringFunc("CONCAT_WS", append([]Expression{separator}, expressions...)...) return NewStringFunc("CONCAT_WS", append([]Expression{separator}, expressions...)...)
} }
// CONVERT converts string to dest_encoding. The original encoding is // CONVERT converts string to dest_encoding. The original encoding is
// specified by src_encoding. The string must be valid in this encoding. // specified by src_encoding. The string must be valid in this encoding.
func CONVERT(str StringExpression, srcEncoding StringExpression, destEncoding StringExpression) StringExpression { func CONVERT(str StringExpression, srcEncoding StringExpression, destEncoding StringExpression) StringExpression {
return newStringFunc("CONVERT", str, srcEncoding, destEncoding) return NewStringFunc("CONVERT", str, srcEncoding, destEncoding)
} }
// CONVERT_FROM converts string to the database encoding. The original // CONVERT_FROM converts string to the database encoding. The original
// encoding is specified by src_encoding. The string must be valid in this encoding. // encoding is specified by src_encoding. The string must be valid in this encoding.
func CONVERT_FROM(str StringExpression, srcEncoding StringExpression) StringExpression { func CONVERT_FROM(str StringExpression, srcEncoding StringExpression) StringExpression {
return newStringFunc("CONVERT_FROM", str, srcEncoding) return NewStringFunc("CONVERT_FROM", str, srcEncoding)
} }
// CONVERT_TO converts string to dest_encoding. // CONVERT_TO converts string to dest_encoding.
func CONVERT_TO(str StringExpression, toEncoding StringExpression) StringExpression { func CONVERT_TO(str StringExpression, toEncoding StringExpression) StringExpression {
return newStringFunc("CONVERT_TO", str, toEncoding) return NewStringFunc("CONVERT_TO", str, toEncoding)
} }
// ENCODE encodes binary data into a textual representation. // ENCODE encodes binary data into a textual representation.
// Supported formats are: base64, hex, escape. escape converts zero bytes and // Supported formats are: base64, hex, escape. escape converts zero bytes and
// high-bit-set bytes to octal sequences (\nnn) and doubles backslashes. // high-bit-set bytes to octal sequences (\nnn) and doubles backslashes.
func ENCODE(data StringExpression, format StringExpression) StringExpression { func ENCODE(data StringExpression, format StringExpression) StringExpression {
return newStringFunc("ENCODE", data, format) return NewStringFunc("ENCODE", data, format)
} }
// DECODE decodes binary data from textual representation in string. // DECODE decodes binary data from textual representation in string.
// Options for format are same as in encode. // Options for format are same as in encode.
func DECODE(data StringExpression, format StringExpression) StringExpression { func DECODE(data StringExpression, format StringExpression) StringExpression {
return newStringFunc("DECODE", data, format) return NewStringFunc("DECODE", data, format)
} }
// FORMAT formats a number to a format like "#,###,###.##", rounded to a specified number of decimal places, then it returns the result as a string. // FORMAT formats a number to a format like "#,###,###.##", rounded to a specified number of decimal places, then it returns the result as a string.
func FORMAT(formatStr StringExpression, formatArgs ...Expression) StringExpression { func FORMAT(formatStr StringExpression, formatArgs ...Expression) StringExpression {
args := []Expression{formatStr} args := []Expression{formatStr}
args = append(args, formatArgs...) args = append(args, formatArgs...)
return newStringFunc("FORMAT", args...) return NewStringFunc("FORMAT", args...)
} }
// INITCAP converts the first letter of each word to upper case // INITCAP converts the first letter of each word to upper case
// and the rest to lower case. Words are sequences of alphanumeric // and the rest to lower case. Words are sequences of alphanumeric
// characters separated by non-alphanumeric characters. // characters separated by non-alphanumeric characters.
func INITCAP(str StringExpression) StringExpression { func INITCAP(str StringExpression) StringExpression {
return newStringFunc("INITCAP", str) return NewStringFunc("INITCAP", str)
} }
// LEFT returns first n characters in the string. // LEFT returns first n characters in the string.
// When n is negative, return all but last |n| characters. // When n is negative, return all but last |n| characters.
func LEFT(str StringExpression, n IntegerExpression) StringExpression { func LEFT(str StringExpression, n IntegerExpression) StringExpression {
return newStringFunc("LEFT", str, n) return NewStringFunc("LEFT", str, n)
} }
// RIGHT returns last n characters in the string. // RIGHT returns last n characters in the string.
// When n is negative, return all but first |n| characters. // When n is negative, return all but first |n| characters.
func RIGHT(str StringExpression, n IntegerExpression) StringExpression { func RIGHT(str StringExpression, n IntegerExpression) StringExpression {
return newStringFunc("RIGHT", str, n) return NewStringFunc("RIGHT", str, n)
} }
// LENGTH returns number of characters in string with a given encoding // LENGTH returns number of characters in string with a given encoding
func LENGTH(str StringExpression, encoding ...StringExpression) StringExpression { func LENGTH(str StringExpression, encoding ...StringExpression) StringExpression {
if len(encoding) > 0 { if len(encoding) > 0 {
return newStringFunc("LENGTH", str, encoding[0]) return NewStringFunc("LENGTH", str, encoding[0])
} }
return newStringFunc("LENGTH", str) return NewStringFunc("LENGTH", str)
} }
// LPAD fills up the string to length length by prepending the characters // LPAD fills up the string to length length by prepending the characters
@ -384,40 +384,40 @@ func LENGTH(str StringExpression, encoding ...StringExpression) StringExpression
// then it is truncated (on the right). // then it is truncated (on the right).
func LPAD(str StringExpression, length IntegerExpression, text ...StringExpression) StringExpression { func LPAD(str StringExpression, length IntegerExpression, text ...StringExpression) StringExpression {
if len(text) > 0 { if len(text) > 0 {
return newStringFunc("LPAD", str, length, text[0]) return NewStringFunc("LPAD", str, length, text[0])
} }
return newStringFunc("LPAD", str, length) return NewStringFunc("LPAD", str, length)
} }
// RPAD fills up the string to length length by appending the characters // RPAD fills up the string to length length by appending the characters
// fill (a space by default). If the string is already longer than length then it is truncated. // fill (a space by default). If the string is already longer than length then it is truncated.
func RPAD(str StringExpression, length IntegerExpression, text ...StringExpression) StringExpression { func RPAD(str StringExpression, length IntegerExpression, text ...StringExpression) StringExpression {
if len(text) > 0 { if len(text) > 0 {
return newStringFunc("RPAD", str, length, text[0]) return NewStringFunc("RPAD", str, length, text[0])
} }
return newStringFunc("RPAD", str, length) return NewStringFunc("RPAD", str, length)
} }
// MD5 calculates the MD5 hash of string, returning the result in hexadecimal // MD5 calculates the MD5 hash of string, returning the result in hexadecimal
func MD5(stringExpression StringExpression) StringExpression { func MD5(stringExpression StringExpression) StringExpression {
return newStringFunc("MD5", stringExpression) return NewStringFunc("MD5", stringExpression)
} }
// REPEAT repeats string the specified number of times // REPEAT repeats string the specified number of times
func REPEAT(str StringExpression, n IntegerExpression) StringExpression { func REPEAT(str StringExpression, n IntegerExpression) StringExpression {
return newStringFunc("REPEAT", str, n) return NewStringFunc("REPEAT", str, n)
} }
// REPLACE replaces all occurrences in string of substring from with substring to // REPLACE replaces all occurrences in string of substring from with substring to
func REPLACE(text, from, to StringExpression) StringExpression { func REPLACE(text, from, to StringExpression) StringExpression {
return newStringFunc("REPLACE", text, from, to) return NewStringFunc("REPLACE", text, from, to)
} }
// REVERSE returns reversed string. // REVERSE returns reversed string.
func REVERSE(stringExpression StringExpression) StringExpression { func REVERSE(stringExpression StringExpression) StringExpression {
return newStringFunc("REVERSE", stringExpression) return NewStringFunc("REVERSE", stringExpression)
} }
// STRPOS returns location of specified substring (same as position(substring in string), // STRPOS returns location of specified substring (same as position(substring in string),
@ -429,22 +429,22 @@ func STRPOS(str, substring StringExpression) IntegerExpression {
// SUBSTR extracts substring // SUBSTR extracts substring
func SUBSTR(str StringExpression, from IntegerExpression, count ...IntegerExpression) StringExpression { func SUBSTR(str StringExpression, from IntegerExpression, count ...IntegerExpression) StringExpression {
if len(count) > 0 { if len(count) > 0 {
return newStringFunc("SUBSTR", str, from, count[0]) return NewStringFunc("SUBSTR", str, from, count[0])
} }
return newStringFunc("SUBSTR", str, from) return NewStringFunc("SUBSTR", str, from)
} }
// TO_ASCII convert string to ASCII from another encoding // TO_ASCII convert string to ASCII from another encoding
func TO_ASCII(str StringExpression, encoding ...StringExpression) StringExpression { func TO_ASCII(str StringExpression, encoding ...StringExpression) StringExpression {
if len(encoding) > 0 { if len(encoding) > 0 {
return newStringFunc("TO_ASCII", str, encoding[0]) return NewStringFunc("TO_ASCII", str, encoding[0])
} }
return newStringFunc("TO_ASCII", str) return NewStringFunc("TO_ASCII", str)
} }
// TO_HEX converts number to its equivalent hexadecimal representation // TO_HEX converts number to its equivalent hexadecimal representation
func TO_HEX(number IntegerExpression) StringExpression { func TO_HEX(number IntegerExpression) StringExpression {
return newStringFunc("TO_HEX", number) return NewStringFunc("TO_HEX", number)
} }
// REGEXP_LIKE Returns 1 if the string expr matches the regular expression specified by the pattern pat, 0 otherwise. // REGEXP_LIKE Returns 1 if the string expr matches the regular expression specified by the pattern pat, 0 otherwise.
@ -460,12 +460,12 @@ func REGEXP_LIKE(stringExp StringExpression, pattern StringExpression, matchType
// TO_CHAR converts expression to string with format // TO_CHAR converts expression to string with format
func TO_CHAR(expression Expression, format StringExpression) StringExpression { func TO_CHAR(expression Expression, format StringExpression) StringExpression {
return newStringFunc("TO_CHAR", expression, format) return NewStringFunc("TO_CHAR", expression, format)
} }
// TO_DATE converts string to date using format // TO_DATE converts string to date using format
func TO_DATE(dateStr, format StringExpression) DateExpression { func TO_DATE(dateStr, format StringExpression) DateExpression {
return newDateFunc("TO_DATE", dateStr, format) return NewDateFunc("TO_DATE", dateStr, format)
} }
// TO_NUMBER converts string to numeric using format // TO_NUMBER converts string to numeric using format
@ -482,7 +482,7 @@ func TO_TIMESTAMP(timestampzStr, format StringExpression) TimestampzExpression {
// CURRENT_DATE returns current date // CURRENT_DATE returns current date
func CURRENT_DATE() DateExpression { func CURRENT_DATE() DateExpression {
dateFunc := newDateFunc("CURRENT_DATE") dateFunc := NewDateFunc("CURRENT_DATE")
dateFunc.noBrackets = true dateFunc.noBrackets = true
return dateFunc return dateFunc
} }
@ -522,9 +522,9 @@ func LOCALTIME(precision ...int) TimeExpression {
var timeFunc *timeFunc var timeFunc *timeFunc
if len(precision) > 0 { if len(precision) > 0 {
timeFunc = newTimeFunc("LOCALTIME", FixedLiteral(precision[0])) timeFunc = NewTimeFunc("LOCALTIME", FixedLiteral(precision[0]))
} else { } else {
timeFunc = newTimeFunc("LOCALTIME") timeFunc = NewTimeFunc("LOCALTIME")
} }
timeFunc.noBrackets = true timeFunc.noBrackets = true
@ -558,26 +558,26 @@ func NOW() TimestampzExpression {
func COALESCE(value Expression, values ...Expression) Expression { func COALESCE(value Expression, values ...Expression) Expression {
var allValues = []Expression{value} var allValues = []Expression{value}
allValues = append(allValues, values...) allValues = append(allValues, values...)
return newFunc("COALESCE", allValues, nil) return NewFunc("COALESCE", allValues, nil)
} }
// NULLIF function returns a null value if value1 equals value2; otherwise it returns value1. // NULLIF function returns a null value if value1 equals value2; otherwise it returns value1.
func NULLIF(value1, value2 Expression) Expression { func NULLIF(value1, value2 Expression) Expression {
return newFunc("NULLIF", []Expression{value1, value2}, nil) return NewFunc("NULLIF", []Expression{value1, value2}, nil)
} }
// GREATEST selects the largest value from a list of expressions // GREATEST selects the largest value from a list of expressions
func GREATEST(value Expression, values ...Expression) Expression { func GREATEST(value Expression, values ...Expression) Expression {
var allValues = []Expression{value} var allValues = []Expression{value}
allValues = append(allValues, values...) allValues = append(allValues, values...)
return newFunc("GREATEST", allValues, nil) return NewFunc("GREATEST", allValues, nil)
} }
// LEAST selects the smallest value from a list of expressions // LEAST selects the smallest value from a list of expressions
func LEAST(value Expression, values ...Expression) Expression { func LEAST(value Expression, values ...Expression) Expression {
var allValues = []Expression{value} var allValues = []Expression{value}
allValues = append(allValues, values...) allValues = append(allValues, values...)
return newFunc("LEAST", allValues, nil) return NewFunc("LEAST", allValues, nil)
} }
//--------------------------------------------------------------------// //--------------------------------------------------------------------//
@ -590,7 +590,8 @@ type funcExpressionImpl struct {
noBrackets bool noBrackets bool
} }
func newFunc(name string, expressions []Expression, parent Expression) *funcExpressionImpl { // NewFunc creates new function with name and expressions parameters
func NewFunc(name string, expressions []Expression, parent Expression) *funcExpressionImpl {
funcExp := &funcExpressionImpl{ funcExp := &funcExpressionImpl{
name: name, name: name,
expressions: expressions, expressions: expressions,
@ -608,7 +609,7 @@ func newFunc(name string, expressions []Expression, parent Expression) *funcExpr
// NewFloatWindowFunc creates new float function with name and expressions // NewFloatWindowFunc creates new float function with name and expressions
func newWindowFunc(name string, expressions ...Expression) windowExpression { func newWindowFunc(name string, expressions ...Expression) windowExpression {
newFun := newFunc(name, expressions, nil) newFun := NewFunc(name, expressions, nil)
windowExpr := newWindowExpression(newFun) windowExpr := newWindowExpression(newFun)
newFun.ExpressionInterfaceImpl.Parent = windowExpr newFun.ExpressionInterfaceImpl.Parent = windowExpr
@ -645,7 +646,7 @@ type boolFunc struct {
func newBoolFunc(name string, expressions ...Expression) BoolExpression { func newBoolFunc(name string, expressions ...Expression) BoolExpression {
boolFunc := &boolFunc{} boolFunc := &boolFunc{}
boolFunc.funcExpressionImpl = *newFunc(name, expressions, boolFunc) boolFunc.funcExpressionImpl = *NewFunc(name, expressions, boolFunc)
boolFunc.boolInterfaceImpl.parent = boolFunc boolFunc.boolInterfaceImpl.parent = boolFunc
boolFunc.ExpressionInterfaceImpl.Parent = boolFunc boolFunc.ExpressionInterfaceImpl.Parent = boolFunc
@ -656,7 +657,7 @@ func newBoolFunc(name string, expressions ...Expression) BoolExpression {
func newBoolWindowFunc(name string, expressions ...Expression) boolWindowExpression { func newBoolWindowFunc(name string, expressions ...Expression) boolWindowExpression {
boolFunc := &boolFunc{} boolFunc := &boolFunc{}
boolFunc.funcExpressionImpl = *newFunc(name, expressions, boolFunc) boolFunc.funcExpressionImpl = *NewFunc(name, expressions, boolFunc)
intWindowFunc := newBoolWindowExpression(boolFunc) intWindowFunc := newBoolWindowExpression(boolFunc)
boolFunc.boolInterfaceImpl.parent = intWindowFunc boolFunc.boolInterfaceImpl.parent = intWindowFunc
boolFunc.ExpressionInterfaceImpl.Parent = intWindowFunc boolFunc.ExpressionInterfaceImpl.Parent = intWindowFunc
@ -673,7 +674,7 @@ type floatFunc struct {
func NewFloatFunc(name string, expressions ...Expression) FloatExpression { func NewFloatFunc(name string, expressions ...Expression) FloatExpression {
floatFunc := &floatFunc{} floatFunc := &floatFunc{}
floatFunc.funcExpressionImpl = *newFunc(name, expressions, floatFunc) floatFunc.funcExpressionImpl = *NewFunc(name, expressions, floatFunc)
floatFunc.floatInterfaceImpl.parent = floatFunc floatFunc.floatInterfaceImpl.parent = floatFunc
return floatFunc return floatFunc
@ -683,7 +684,7 @@ func NewFloatFunc(name string, expressions ...Expression) FloatExpression {
func NewFloatWindowFunc(name string, expressions ...Expression) floatWindowExpression { func NewFloatWindowFunc(name string, expressions ...Expression) floatWindowExpression {
floatFunc := &floatFunc{} floatFunc := &floatFunc{}
floatFunc.funcExpressionImpl = *newFunc(name, expressions, floatFunc) floatFunc.funcExpressionImpl = *NewFunc(name, expressions, floatFunc)
floatWindowFunc := newFloatWindowExpression(floatFunc) floatWindowFunc := newFloatWindowExpression(floatFunc)
floatFunc.floatInterfaceImpl.parent = floatWindowFunc floatFunc.floatInterfaceImpl.parent = floatWindowFunc
floatFunc.ExpressionInterfaceImpl.Parent = floatWindowFunc floatFunc.ExpressionInterfaceImpl.Parent = floatWindowFunc
@ -699,7 +700,7 @@ type integerFunc struct {
func newIntegerFunc(name string, expressions ...Expression) IntegerExpression { func newIntegerFunc(name string, expressions ...Expression) IntegerExpression {
floatFunc := &integerFunc{} floatFunc := &integerFunc{}
floatFunc.funcExpressionImpl = *newFunc(name, expressions, floatFunc) floatFunc.funcExpressionImpl = *NewFunc(name, expressions, floatFunc)
floatFunc.integerInterfaceImpl.parent = floatFunc floatFunc.integerInterfaceImpl.parent = floatFunc
return floatFunc return floatFunc
@ -709,7 +710,7 @@ func newIntegerFunc(name string, expressions ...Expression) IntegerExpression {
func newIntegerWindowFunc(name string, expressions ...Expression) integerWindowExpression { func newIntegerWindowFunc(name string, expressions ...Expression) integerWindowExpression {
integerFunc := &integerFunc{} integerFunc := &integerFunc{}
integerFunc.funcExpressionImpl = *newFunc(name, expressions, integerFunc) integerFunc.funcExpressionImpl = *NewFunc(name, expressions, integerFunc)
intWindowFunc := newIntegerWindowExpression(integerFunc) intWindowFunc := newIntegerWindowExpression(integerFunc)
integerFunc.integerInterfaceImpl.parent = intWindowFunc integerFunc.integerInterfaceImpl.parent = intWindowFunc
integerFunc.ExpressionInterfaceImpl.Parent = intWindowFunc integerFunc.ExpressionInterfaceImpl.Parent = intWindowFunc
@ -722,10 +723,11 @@ type stringFunc struct {
stringInterfaceImpl stringInterfaceImpl
} }
func newStringFunc(name string, expressions ...Expression) StringExpression { // NewStringFunc creates new string function with name and expression parameters
func NewStringFunc(name string, expressions ...Expression) StringExpression {
stringFunc := &stringFunc{} stringFunc := &stringFunc{}
stringFunc.funcExpressionImpl = *newFunc(name, expressions, stringFunc) stringFunc.funcExpressionImpl = *NewFunc(name, expressions, stringFunc)
stringFunc.stringInterfaceImpl.parent = stringFunc stringFunc.stringInterfaceImpl.parent = stringFunc
return stringFunc return stringFunc
@ -736,10 +738,11 @@ type dateFunc struct {
dateInterfaceImpl dateInterfaceImpl
} }
func newDateFunc(name string, expressions ...Expression) *dateFunc { // NewDateFunc creates new date function with name and expression parameters
func NewDateFunc(name string, expressions ...Expression) *dateFunc {
dateFunc := &dateFunc{} dateFunc := &dateFunc{}
dateFunc.funcExpressionImpl = *newFunc(name, expressions, dateFunc) dateFunc.funcExpressionImpl = *NewFunc(name, expressions, dateFunc)
dateFunc.dateInterfaceImpl.parent = dateFunc dateFunc.dateInterfaceImpl.parent = dateFunc
return dateFunc return dateFunc
@ -750,10 +753,11 @@ type timeFunc struct {
timeInterfaceImpl timeInterfaceImpl
} }
func newTimeFunc(name string, expressions ...Expression) *timeFunc { // NewTimeFunc creates new time function with name and expression parameters
func NewTimeFunc(name string, expressions ...Expression) *timeFunc {
timeFun := &timeFunc{} timeFun := &timeFunc{}
timeFun.funcExpressionImpl = *newFunc(name, expressions, timeFun) timeFun.funcExpressionImpl = *NewFunc(name, expressions, timeFun)
timeFun.timeInterfaceImpl.parent = timeFun timeFun.timeInterfaceImpl.parent = timeFun
return timeFun return timeFun
@ -767,7 +771,7 @@ type timezFunc struct {
func newTimezFunc(name string, expressions ...Expression) *timezFunc { func newTimezFunc(name string, expressions ...Expression) *timezFunc {
timezFun := &timezFunc{} timezFun := &timezFunc{}
timezFun.funcExpressionImpl = *newFunc(name, expressions, timezFun) timezFun.funcExpressionImpl = *NewFunc(name, expressions, timezFun)
timezFun.timezInterfaceImpl.parent = timezFun timezFun.timezInterfaceImpl.parent = timezFun
return timezFun return timezFun
@ -782,7 +786,7 @@ type timestampFunc struct {
func NewTimestampFunc(name string, expressions ...Expression) *timestampFunc { func NewTimestampFunc(name string, expressions ...Expression) *timestampFunc {
timestampFunc := &timestampFunc{} timestampFunc := &timestampFunc{}
timestampFunc.funcExpressionImpl = *newFunc(name, expressions, timestampFunc) timestampFunc.funcExpressionImpl = *NewFunc(name, expressions, timestampFunc)
timestampFunc.timestampInterfaceImpl.parent = timestampFunc timestampFunc.timestampInterfaceImpl.parent = timestampFunc
return timestampFunc return timestampFunc
@ -796,7 +800,7 @@ type timestampzFunc struct {
func newTimestampzFunc(name string, expressions ...Expression) *timestampzFunc { func newTimestampzFunc(name string, expressions ...Expression) *timestampzFunc {
timestampzFunc := &timestampzFunc{} timestampzFunc := &timestampzFunc{}
timestampzFunc.funcExpressionImpl = *newFunc(name, expressions, timestampzFunc) timestampzFunc.funcExpressionImpl = *NewFunc(name, expressions, timestampzFunc)
timestampzFunc.timestampzInterfaceImpl.parent = timestampzFunc timestampzFunc.timestampzInterfaceImpl.parent = timestampzFunc
return timestampzFunc return timestampzFunc
@ -804,5 +808,5 @@ func newTimestampzFunc(name string, expressions ...Expression) *timestampzFunc {
// Func can be used to call an custom or as of yet unsupported function in the database. // Func can be used to call an custom or as of yet unsupported function in the database.
func Func(name string, expressions ...Expression) Expression { func Func(name string, expressions ...Expression) Expression {
return newFunc(name, expressions, nil) return NewFunc(name, expressions, nil)
} }

View file

@ -19,7 +19,7 @@ func (i *IsIntervalImpl) isInterval() {}
// NewInterval creates new interval from serializer // NewInterval creates new interval from serializer
func NewInterval(s Serializer) *IntervalImpl { func NewInterval(s Serializer) *IntervalImpl {
newInterval := &IntervalImpl{ newInterval := &IntervalImpl{
interval: s, Value: s,
} }
return newInterval return newInterval
@ -27,11 +27,11 @@ func NewInterval(s Serializer) *IntervalImpl {
// IntervalImpl is implementation of Interval type // IntervalImpl is implementation of Interval type
type IntervalImpl struct { type IntervalImpl struct {
interval Serializer Value Serializer
IsIntervalImpl IsIntervalImpl
} }
func (i IntervalImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { func (i IntervalImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteString("INTERVAL") out.WriteString("INTERVAL")
i.interval.serialize(statement, out, FallTrough(options)...) i.Value.serialize(statement, out, FallTrough(options)...)
} }

View file

@ -20,6 +20,11 @@ import (
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
) )
// UnixTimeComparer will compare time equality while ignoring time zone
var UnixTimeComparer = cmp.Comparer(func(t1, t2 time.Time) bool {
return t1.Unix() == t2.Unix()
})
// AssertExec assert statement execution for successful execution and number of rows affected // AssertExec assert statement execution for successful execution and number of rows affected
func AssertExec(t *testing.T, stmt jet.Statement, db qrm.DB, rowsAffected ...int64) { func AssertExec(t *testing.T, stmt jet.Statement, db qrm.DB, rowsAffected ...int64) {
res, err := stmt.Exec(db) res, err := stmt.Exec(db)
@ -113,7 +118,7 @@ func AssertDebugStatementSql(t *testing.T, query jet.Statement, expectedQuery st
_, args := query.Sql() _, args := query.Sql()
if len(expectedArgs) > 0 { if len(expectedArgs) > 0 {
AssertDeepEqual(t, args, expectedArgs, "arguments are not equal") AssertDeepEqual(t, args, expectedArgs)
} }
debugSql := query.DebugSql() debugSql := query.DebugSql()
@ -223,9 +228,9 @@ func AssertFileNamesEqual(t *testing.T, fileInfos []os.FileInfo, fileNames ...st
} }
// AssertDeepEqual checks if actual and expected objects are deeply equal. // AssertDeepEqual checks if actual and expected objects are deeply equal.
func AssertDeepEqual(t *testing.T, actual, expected interface{}, msg ...string) { func AssertDeepEqual(t *testing.T, actual, expected interface{}, option ...cmp.Option) {
if !assert.True(t, cmp.Equal(actual, expected), msg) { if !assert.True(t, cmp.Equal(actual, expected, option...)) {
printDiff(actual, expected) printDiff(actual, expected, option...)
t.FailNow() t.FailNow()
} }
} }
@ -237,7 +242,8 @@ func assertQueryString(t *testing.T, actual, expected string) {
} }
} }
func printDiff(actual, expected interface{}) { func printDiff(actual, expected interface{}, options ...cmp.Option) {
fmt.Println(cmp.Diff(actual, expected, options...))
fmt.Println("Actual: ") fmt.Println("Actual: ")
fmt.Println(actual) fmt.Println(actual)
fmt.Println("Expected: ") fmt.Println("Expected: ")

View file

@ -59,7 +59,7 @@ func (nt *NullTime) Scan(value interface{}) error {
} }
// Some of the drivers (pgx, mysql) are not parsing all of the time formats(date, time with time zone,...) and are just forwarding string value. // Some of the drivers (pgx, mysql) are not parsing all of the time formats(date, time with time zone,...) and are just forwarding string value.
// At this point we try to parse time using some of the predefined formats // At this point we try to parse those values using some of the predefined formats
nt.Time, nt.Valid = tryParseAsTime(value) nt.Time, nt.Valid = tryParseAsTime(value)
if !nt.Valid { if !nt.Valid {
@ -70,6 +70,7 @@ func (nt *NullTime) Scan(value interface{}) error {
} }
var formats = []string{ var formats = []string{
"2006-01-02 15:04:05-07:00", // sqlite
"2006-01-02 15:04:05.999999", // go-sql-driver/mysql "2006-01-02 15:04:05.999999", // go-sql-driver/mysql
"15:04:05-07", // pgx "15:04:05-07", // pgx
"15:04:05.999999", // pgx "15:04:05.999999", // pgx
@ -84,6 +85,8 @@ func tryParseAsTime(value interface{}) (time.Time, bool) {
timeStr = v timeStr = v
case []byte: case []byte:
timeStr = string(v) timeStr = string(v)
case int64:
return time.Unix(v, 0), true // sqlite
default: default:
return time.Time{}, false return time.Time{}, false
} }

55
sqlite/cast.go Normal file
View file

@ -0,0 +1,55 @@
package sqlite
import (
"github.com/go-jet/jet/v2/internal/jet"
)
type cast interface {
AS(castType string) Expression
AS_TEXT() StringExpression
AS_NUMERIC() FloatExpression
AS_INTEGER() IntegerExpression
AS_REAL() FloatExpression
AS_BLOB() StringExpression
}
type castImpl struct {
jet.Cast
}
// CAST function converts a expr (of any type) into latter specified datatype.
func CAST(expr Expression) cast {
castImpl := &castImpl{}
castImpl.Cast = jet.NewCastImpl(expr)
return castImpl
}
// AS casts expressions to castType
func (c *castImpl) AS(castType string) Expression {
return c.Cast.AS(castType)
}
// AS_TEXT cast expression to TEXT type
func (c *castImpl) AS_TEXT() StringExpression {
return StringExp(c.AS("TEXT"))
}
// AS_NUMERIC cast expression to NUMERIC type
func (c *castImpl) AS_NUMERIC() FloatExpression {
return FloatExp(c.AS("NUMERIC"))
}
// AS_INTEGER cast expression to INTEGER type
func (c *castImpl) AS_INTEGER() IntegerExpression {
return IntExp(c.AS("INTEGER"))
}
// AS_REAL cast expression to REAL type
func (c *castImpl) AS_REAL() FloatExpression {
return FloatExp(c.AS("REAL"))
}
// AS_BLOB cast expression to BLOB type
func (c *castImpl) AS_BLOB() StringExpression {
return StringExp(c.AS("BLOB"))
}

14
sqlite/cast_test.go Normal file
View file

@ -0,0 +1,14 @@
package sqlite
import (
"testing"
)
func TestCAST(t *testing.T) {
assertSerialize(t, CAST(Float(11.22)).AS("bigint"), `CAST(? AS bigint)`)
assertSerialize(t, CAST(Int(22)).AS_TEXT(), `CAST(? AS TEXT)`)
assertSerialize(t, CAST(Int(22)).AS_NUMERIC(), `CAST(? AS NUMERIC)`)
assertSerialize(t, CAST(String("22")).AS_INTEGER(), `CAST(? AS INTEGER)`)
assertSerialize(t, CAST(String("22.2")).AS_REAL(), `CAST(? AS REAL)`)
assertSerialize(t, CAST(String("blob")).AS_BLOB(), `CAST(? AS BLOB)`)
}

58
sqlite/columns.go Normal file
View file

@ -0,0 +1,58 @@
package sqlite
import "github.com/go-jet/jet/v2/internal/jet"
// Column is common column interface for all types of columns.
type Column = jet.ColumnExpression
// ColumnList function returns list of columns that be used as projection or column list for UPDATE and INSERT statement.
type ColumnList = jet.ColumnList
// ColumnBool is interface for SQL boolean columns.
type ColumnBool = jet.ColumnBool
// BoolColumn creates named bool column.
var BoolColumn = jet.BoolColumn
// ColumnString is interface for SQL text, character, character varying
// bytea, uuid columns and enums types.
type ColumnString = jet.ColumnString
// StringColumn creates named string column.
var StringColumn = jet.StringColumn
// ColumnInteger is interface for SQL smallint, integer, bigint columns.
type ColumnInteger = jet.ColumnInteger
// IntegerColumn creates named integer column.
var IntegerColumn = jet.IntegerColumn
// ColumnFloat is interface for SQL real, numeric, decimal or double precision column.
type ColumnFloat = jet.ColumnFloat
// FloatColumn creates named float column.
var FloatColumn = jet.FloatColumn
// ColumnTime is interface for SQL time column.
type ColumnTime = jet.ColumnTime
// TimeColumn creates named time column
var TimeColumn = jet.TimeColumn
// ColumnDate is interface of SQL date columns.
type ColumnDate = jet.ColumnDate
// DateColumn creates named date column.
var DateColumn = jet.DateColumn
// ColumnDateTime is interface of SQL timestamp columns.
type ColumnDateTime = jet.ColumnTimestamp
// DateTimeColumn creates named timestamp column
var DateTimeColumn = jet.TimestampColumn
//ColumnTimestamp is interface of SQL timestamp columns.
type ColumnTimestamp = jet.ColumnTimestamp
// TimestampColumn creates named timestamp column
var TimestampColumn = jet.TimestampColumn

View file

@ -0,0 +1,61 @@
package sqlite
import "github.com/go-jet/jet/v2/internal/jet"
// DeleteStatement is interface for MySQL DELETE statement
type DeleteStatement interface {
Statement
WHERE(expression BoolExpression) DeleteStatement
ORDER_BY(orderByClauses ...OrderByClause) DeleteStatement
LIMIT(limit int64) DeleteStatement
RETURNING(projections ...jet.Projection) DeleteStatement
}
type deleteStatementImpl struct {
jet.SerializerStatement
Delete jet.ClauseStatementBegin
Where jet.ClauseWhere
OrderBy jet.ClauseOrderBy
Limit jet.ClauseLimit
Returning jet.ClauseReturning
}
func newDeleteStatement(table Table) DeleteStatement {
newDelete := &deleteStatementImpl{}
newDelete.SerializerStatement = jet.NewStatementImpl(Dialect, jet.DeleteStatementType, newDelete,
&newDelete.Delete,
&newDelete.Where,
&newDelete.OrderBy,
&newDelete.Limit,
&newDelete.Returning,
)
newDelete.Delete.Name = "DELETE FROM"
newDelete.Delete.Tables = append(newDelete.Delete.Tables, table)
newDelete.Where.Mandatory = true
newDelete.Limit.Count = -1
return newDelete
}
func (d *deleteStatementImpl) WHERE(expression BoolExpression) DeleteStatement {
d.Where.Condition = expression
return d
}
func (d *deleteStatementImpl) ORDER_BY(orderByClauses ...OrderByClause) DeleteStatement {
d.OrderBy.List = orderByClauses
return d
}
func (d *deleteStatementImpl) LIMIT(limit int64) DeleteStatement {
d.Limit.Count = limit
return d
}
func (d *deleteStatementImpl) RETURNING(projections ...jet.Projection) DeleteStatement {
d.Returning.ProjectionList = projections
return d
}

View file

@ -0,0 +1,26 @@
package sqlite
import (
"testing"
)
func TestDeleteUnconditionally(t *testing.T) {
assertStatementSqlErr(t, table1.DELETE(), `jet: WHERE clause not set`)
assertStatementSqlErr(t, table1.DELETE().WHERE(nil), `jet: WHERE clause not set`)
}
func TestDeleteWithWhere(t *testing.T) {
assertStatementSql(t, table1.DELETE().WHERE(table1Col1.EQ(Int(1))), `
DELETE FROM db.table1
WHERE table1.col1 = ?;
`, int64(1))
}
func TestDeleteWithWhereOrderByLimit(t *testing.T) {
assertStatementSql(t, table1.DELETE().WHERE(table1Col1.EQ(Int(1))).ORDER_BY(table1Col1).LIMIT(1), `
DELETE FROM db.table1
WHERE table1.col1 = ?
ORDER BY table1.col1
LIMIT ?;
`, int64(1), int64(1))
}

225
sqlite/dialect.go Normal file
View file

@ -0,0 +1,225 @@
package sqlite
import (
"github.com/go-jet/jet/v2/internal/jet"
)
// Dialect is implementation of SQL Builder for SQLite databases.
var Dialect = newDialect()
func newDialect() jet.Dialect {
operatorSerializeOverrides := map[string]jet.SerializeOverride{}
operatorSerializeOverrides["IS DISTINCT FROM"] = sqlite_IS_DISTINCT_FROM
operatorSerializeOverrides["IS NOT DISTINCT FROM"] = sqlite_IS_NOT_DISTINCT_FROM
operatorSerializeOverrides["#"] = sqliteBitXOR
mySQLDialectParams := jet.DialectParams{
Name: "SQLite",
PackageName: "sqlite",
OperatorSerializeOverrides: operatorSerializeOverrides,
AliasQuoteChar: '"',
IdentifierQuoteChar: '`',
ArgumentPlaceholder: func(int) string {
return "?"
},
ReservedWords: reservedWords2,
}
return jet.NewDialect(mySQLDialectParams)
}
func sqliteBitXOR(expressions ...jet.Serializer) jet.SerializerFunc {
return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(expressions) < 2 {
panic("jet: invalid number of expressions for operator XOR")
}
// (~(a&b))&(a|b)
a := expressions[0]
b := expressions[1]
out.WriteString("(~(")
jet.Serialize(a, statement, out, options...)
out.WriteByte('&')
jet.Serialize(b, statement, out, options...)
out.WriteString("))&(")
jet.Serialize(a, statement, out, options...)
out.WriteByte('|')
jet.Serialize(b, statement, out, options...)
out.WriteByte(')')
}
}
func sqlite_IS_NOT_DISTINCT_FROM(expressions ...jet.Serializer) jet.SerializerFunc {
return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(expressions) < 2 {
panic("jet: invalid number of expressions for operator")
}
jet.Serialize(expressions[0], statement, out)
out.WriteString("IS")
jet.Serialize(expressions[1], statement, out)
}
}
func sqlite_IS_DISTINCT_FROM(expressions ...jet.Serializer) jet.SerializerFunc {
return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(expressions) < 2 {
panic("jet: invalid number of expressions for operator")
}
jet.Serialize(expressions[0], statement, out)
out.WriteString("IS NOT")
jet.Serialize(expressions[1], statement, out)
}
}
var reservedWords2 = []string{
"ABORT",
"ACTION",
"ADD",
"AFTER",
"ALL",
"ALTER",
"ALWAYS",
"ANALYZE",
"AND",
"AS",
"ASC",
"ATTACH",
"AUTOINCREMENT",
"BEFORE",
"BEGIN",
"BETWEEN",
"BY",
"CASCADE",
"CASE",
"CAST",
"CHECK",
"COLLATE",
"COLUMN",
"COMMIT",
"CONFLICT",
"CONSTRAINT",
"CREATE",
"CROSS",
"CURRENT",
"CURRENT_DATE",
"CURRENT_TIME",
"CURRENT_TIMESTAMP",
"DATABASE",
"DEFAULT",
"DEFERRABLE",
"DEFERRED",
"DELETE",
"DESC",
"DETACH",
"DISTINCT",
"DO",
"DROP",
"EACH",
"ELSE",
"END",
"ESCAPE",
"EXCEPT",
"EXCLUDE",
"EXCLUSIVE",
"EXISTS",
"EXPLAIN",
"FAIL",
"FILTER",
"FIRST",
"FOLLOWING",
"FOR",
"FOREIGN",
"FROM",
"FULL",
"GENERATED",
"GLOB",
"GROUP",
"GROUPS",
"HAVING",
"IF",
"IGNORE",
"IMMEDIATE",
"IN",
"INDEX",
"INDEXED",
"INITIALLY",
"INNER",
"INSERT",
"INSTEAD",
"INTERSECT",
"INTO",
"IS",
"ISNULL",
"JOIN",
"KEY",
"LAST",
"LEFT",
"LIKE",
"LIMIT",
"MATCH",
"MATERIALIZED",
"NATURAL",
"NO",
"NOT",
"NOTHING",
"NOTNULL",
"NULL",
"NULLS",
"OF",
"OFFSET",
"ON",
"OR",
"ORDER",
"OTHERS",
"OUTER",
"OVER",
"PARTITION",
"PLAN",
"PRAGMA",
"PRECEDING",
"PRIMARY",
"QUERY",
"RAISE",
"RANGE",
"RECURSIVE",
"REFERENCES",
"REGEXP",
"REINDEX",
"RELEASE",
"RENAME",
"REPLACE",
"RESTRICT",
"RETURNING",
"RIGHT",
"ROLLBACK",
"ROW",
"ROWS",
"SAVEPOINT",
"SELECT",
"SET",
"TABLE",
"TEMP",
"TEMPORARY",
"THEN",
"TIES",
"TO",
"TRANSACTION",
"TRIGGER",
"UNBOUNDED",
"UNION",
"UNIQUE",
"UPDATE",
"USING",
"VACUUM",
"VALUES",
"VIEW",
"VIRTUAL",
"WHEN",
"WHERE",
"WINDOW",
"WITH",
"WITHOUT",
}

59
sqlite/dialect_test.go Normal file
View file

@ -0,0 +1,59 @@
package sqlite
import (
"testing"
)
func TestBoolExpressionIS_DISTINCT_FROM(t *testing.T) {
assertSerialize(t, table1ColBool.IS_DISTINCT_FROM(table2ColBool), "(table1.col_bool IS NOT table2.col_bool)")
assertSerialize(t, table1ColBool.IS_DISTINCT_FROM(Bool(false)), "(table1.col_bool IS NOT ?)", false)
}
func TestBoolExpressionIS_NOT_DISTINCT_FROM(t *testing.T) {
assertSerialize(t, table1ColBool.IS_NOT_DISTINCT_FROM(table2ColBool), "(table1.col_bool IS table2.col_bool)")
assertSerialize(t, table1ColBool.IS_NOT_DISTINCT_FROM(Bool(false)), "(table1.col_bool IS ?)", false)
}
func TestBoolLiteral(t *testing.T) {
assertSerialize(t, Bool(true), "?", true)
assertSerialize(t, Bool(false), "?", false)
}
func TestIntegerExpressionDIV(t *testing.T) {
assertSerialize(t, table1ColInt.DIV(table2ColInt), "(table1.col_int / table2.col_int)")
assertSerialize(t, table1ColInt.DIV(Int(11)), "(table1.col_int / ?)", int64(11))
}
func TestIntExpressionPOW(t *testing.T) {
assertSerialize(t, table1ColInt.POW(table2ColInt), "POW(table1.col_int, table2.col_int)")
assertSerialize(t, table1ColInt.POW(Int(11)), "POW(table1.col_int, ?)", int64(11))
}
func TestIntExpressionBIT_XOR(t *testing.T) {
assertSerialize(t, table1ColInt.BIT_XOR(table2ColInt), "((~(table1.col_int & table2.col_int))&(table1.col_int | table2.col_int))")
assertSerialize(t, table1ColInt.BIT_XOR(Int(11)), "((~(table1.col_int & ?))&(table1.col_int | ?))", int64(11), int64(11))
}
func TestExists(t *testing.T) {
assertSerialize(t, EXISTS(
table2.
SELECT(Int(1)).
WHERE(table1Col1.EQ(table2Col3)),
),
`(EXISTS (
SELECT ?
FROM db.table2
WHERE table1.col1 = table2.col3
))`, int64(1))
}
func TestString_REGEXP_LIKE_operator(t *testing.T) {
assertSerialize(t, table3StrCol.REGEXP_LIKE(table2ColStr), "(table3.col2 REGEXP table2.col_str)")
assertSerialize(t, table3StrCol.REGEXP_LIKE(String("JOHN")), "(table3.col2 REGEXP ?)", "JOHN")
}
func TestString_NOT_REGEXP_LIKE_operator(t *testing.T) {
assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(table2ColStr), "(table3.col2 NOT REGEXP table2.col_str)")
assertSerialize(t, table3StrCol.NOT_REGEXP_LIKE(String("JOHN")), "(table3.col2 NOT REGEXP ?)", "JOHN")
}

97
sqlite/expressions.go Normal file
View file

@ -0,0 +1,97 @@
package sqlite
import "github.com/go-jet/jet/v2/internal/jet"
// Expression is common interface for all expressions.
// Can be Bool, Int, Float, String, Date, Time or Timestamp expressions.
type Expression = jet.Expression
// BoolExpression interface
type BoolExpression = jet.BoolExpression
// StringExpression interface
type StringExpression = jet.StringExpression
// NumericExpression is shared interface for integer or real expression
type NumericExpression = jet.NumericExpression
// IntegerExpression interface
type IntegerExpression = jet.IntegerExpression
// FloatExpression interface
type FloatExpression = jet.FloatExpression
// TimeExpression interface
type TimeExpression = jet.TimeExpression
// DateExpression interface
type DateExpression = jet.DateExpression
// DateTimeExpression interface
type DateTimeExpression = jet.TimestampExpression
// TimestampExpression interface
type TimestampExpression = jet.TimestampExpression
// BoolExp is bool expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as bool expression.
// Does not add sql cast to generated sql builder output.
var BoolExp = jet.BoolExp
// StringExp is string expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as string expression.
// Does not add sql cast to generated sql builder output.
var StringExp = jet.StringExp
// IntExp is int expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as int expression.
// Does not add sql cast to generated sql builder output.
var IntExp = jet.IntExp
// FloatExp is date expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as float expression.
// Does not add sql cast to generated sql builder output.
var FloatExp = jet.FloatExp
// TimeExp is time expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as time expression.
// Does not add sql cast to generated sql builder output.
var TimeExp = jet.TimeExp
// DateExp is date expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as date expression.
// Does not add sql cast to generated sql builder output.
var DateExp = jet.DateExp
// DateTimeExp is timestamp expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as timestamp expression.
// Does not add sql cast to generated sql builder output.
var DateTimeExp = jet.TimestampExp
// TimestampExp is timestamp expression wrapper around arbitrary expression.
// Allows go compiler to see any expression as timestamp expression.
// Does not add sql cast to generated sql builder output.
var TimestampExp = jet.TimestampExp
// RawArgs is type used to pass optional arguments to Raw method
type RawArgs = map[string]interface{}
// Raw can be used for any unsupported functions, operators or expressions.
// For example: Raw("current_database()")
// Raw helper methods for each of the sqlite types
var (
Raw = jet.Raw
RawInt = jet.RawInt
RawFloat = jet.RawFloat
RawString = jet.RawString
RawTime = jet.RawTime
RawTimestamp = jet.RawTimestamp
RawDate = jet.RawDate
)
// Func can be used to call an custom or as of yet unsupported function in the database.
var Func = jet.Func
// NewEnumValue creates new named enum value
var NewEnumValue = jet.NewEnumValue

View file

@ -0,0 +1,52 @@
package sqlite
import (
"github.com/stretchr/testify/require"
"testing"
)
func TestRaw(t *testing.T) {
assertSerialize(t, Raw("current_database()"), "(current_database())")
assertDebugSerialize(t, Raw("current_database()"), "(current_database())")
assertSerialize(t, Raw(":first_arg + table.colInt + :second_arg", RawArgs{":first_arg": 11, ":second_arg": 22}),
"(? + table.colInt + ?)", 11, 22)
assertDebugSerialize(t, Raw(":first_arg + table.colInt + :second_arg", RawArgs{":first_arg": 11, ":second_arg": 22}),
"(11 + table.colInt + 22)")
assertSerialize(t,
Int(700).ADD(RawInt("#1 + table.colInt + #2", RawArgs{"#1": 11, "#2": 22})),
"(? + (? + table.colInt + ?))",
int64(700), 11, 22)
assertDebugSerialize(t,
Int(700).ADD(RawInt("#1 + table.colInt + #2", RawArgs{"#1": 11, "#2": 22})),
"(700 + (11 + table.colInt + 22))")
}
func TestRawDuplicateArguments(t *testing.T) {
assertSerialize(t, Raw(":arg + table.colInt + :arg", RawArgs{":arg": 11}),
"(? + table.colInt + ?)", 11, 11)
assertSerialize(t, Raw("#age + table.colInt + #year + #age + #year + 11", RawArgs{"#age": 11, "#year": 2000}),
"(? + table.colInt + ? + ? + ? + 11)", 11, 2000, 11, 2000)
assertSerialize(t, Raw("#1 + all_types.integer + #2 + #1 + #2 + #3 + #4",
RawArgs{"#1": 11, "#2": 22, "#3": 33, "#4": 44}),
`(? + all_types.integer + ? + ? + ? + ? + ?)`, 11, 22, 11, 22, 33, 44)
}
func TestRawInvalidArguments(t *testing.T) {
defer func() {
r := recover()
require.Equal(t, "jet: named argument 'first_arg' does not appear in raw query", r)
}()
assertSerialize(t, Raw("table.colInt + :second_arg", RawArgs{"first_arg": 11}), "(table.colInt + ?)", 22)
}
func TestRawType(t *testing.T) {
assertSerialize(t, RawFloat("table.colInt + &float", RawArgs{"&float": 11.22}).EQ(Float(3.14)),
"((table.colInt + ?) = ?)", 11.22, 3.14)
assertSerialize(t, RawString("table.colStr || str", RawArgs{"str": "doe"}).EQ(String("john doe")),
"((table.colStr || ?) = ?)", "doe", "john doe")
}

342
sqlite/functions.go Normal file
View file

@ -0,0 +1,342 @@
package sqlite
import (
"fmt"
"github.com/go-jet/jet/v2/internal/jet"
"time"
)
// ROW is construct one table row from list of expressions.
func ROW(expressions ...Expression) Expression {
return jet.NewFunc("", expressions, nil)
}
// ------------------ Mathematical functions ---------------//
// ABSf calculates absolute value from float expression
var ABSf = jet.ABSf
// ABSi calculates absolute value from int expression
var ABSi = jet.ABSi
// POW calculates power of base with exponent
var POW = jet.POW
// POWER calculates power of base with exponent
var POWER = jet.POWER
// SQRT calculates square root of numeric expression
var SQRT = jet.SQRT
// CBRT calculates cube root of numeric expression
func CBRT(number jet.NumericExpression) jet.FloatExpression {
return POWER(number, Float(1.0).DIV(Float(3.0)))
}
// CEIL calculates ceil of float expression
var CEIL = jet.CEIL
// FLOOR calculates floor of float expression
var FLOOR = jet.FLOOR
// ROUND calculates round of a float expressions with optional precision
var ROUND = jet.ROUND
// SIGN returns sign of float expression
var SIGN = jet.SIGN
// TRUNC calculates trunc of float expression with precision
var TRUNC = TRUNCATE
// TRUNCATE calculates trunc of float expression with precision
var TRUNCATE = func(floatExpression jet.FloatExpression, precision jet.IntegerExpression) jet.FloatExpression {
return jet.NewFloatFunc("TRUNCATE", floatExpression, precision)
}
// LN calculates natural algorithm of float expression
var LN = jet.LN
// LOG calculates logarithm of float expression
var LOG = jet.LOG
// ----------------- Aggregate functions -------------------//
// AVG is aggregate function used to calculate avg value from numeric expression
var AVG = jet.AVG
// BIT_AND is aggregate function used to calculates the bitwise AND of all non-null input values, or null if none.
//var BIT_AND = jet.BIT_AND
// BIT_OR is aggregate function used to calculates the bitwise OR of all non-null input values, or null if none.
//var BIT_OR = jet.BIT_OR
// COUNT is aggregate function. Returns number of input rows for which the value of expression is not null.
var COUNT = jet.COUNT
// MAX is aggregate function. Returns maximum value of expression across all input values
var MAX = jet.MAX
// MAXi is aggregate function. Returns maximum value of int expression across all input values
var MAXi = jet.MAXi
// MAXf is aggregate function. Returns maximum value of float expression across all input values
var MAXf = jet.MAXf
// MIN is aggregate function. Returns minimum value of int expression across all input values
var MIN = jet.MIN
// MINi is aggregate function. Returns minimum value of int expression across all input values
var MINi = jet.MINi
// MINf is aggregate function. Returns minimum value of float expression across all input values
var MINf = jet.MINf
// SUM is aggregate function. Returns sum of all expressions
var SUM = jet.SUM
// SUMi is aggregate function. Returns sum of integer expression.
var SUMi = jet.SUMi
// SUMf is aggregate function. Returns sum of float expression.
var SUMf = jet.SUMf
// -------------------- Window functions -----------------------//
// ROW_NUMBER returns number of the current row within its partition, counting from 1
var ROW_NUMBER = jet.ROW_NUMBER
// RANK of the current row with gaps; same as row_number of its first peer
var RANK = jet.RANK
// DENSE_RANK returns rank of the current row without gaps; this function counts peer groups
var DENSE_RANK = jet.DENSE_RANK
// PERCENT_RANK calculates relative rank of the current row: (rank - 1) / (total partition rows - 1)
var PERCENT_RANK = jet.PERCENT_RANK
// CUME_DIST calculates cumulative distribution: (number of partition rows preceding or peer with current row) / total partition rows
var CUME_DIST = jet.CUME_DIST
// NTILE returns integer ranging from 1 to the argument value, dividing the partition as equally as possible
var NTILE = jet.NTILE
// LAG returns value evaluated at the row that is offset rows before the current row within the partition;
// if there is no such row, instead return default (which must be of the same type as value).
// Both offset and default are evaluated with respect to the current row.
// If omitted, offset defaults to 1 and default to null
var LAG = jet.LAG
// LEAD returns value evaluated at the row that is offset rows after the current row within the partition;
// if there is no such row, instead return default (which must be of the same type as value).
// Both offset and default are evaluated with respect to the current row.
// If omitted, offset defaults to 1 and default to null
var LEAD = jet.LEAD
// FIRST_VALUE returns value evaluated at the row that is the first row of the window frame
var FIRST_VALUE = jet.FIRST_VALUE
// LAST_VALUE returns value evaluated at the row that is the last row of the window frame
var LAST_VALUE = jet.LAST_VALUE
// NTH_VALUE returns value evaluated at the row that is the nth row of the window frame (counting from 1); null if no such row
var NTH_VALUE = jet.NTH_VALUE
//--------------------- String functions ------------------//
// BIT_LENGTH returns number of bits in string expression
//var BIT_LENGTH = jet.BIT_LENGTH
//
//// CHAR_LENGTH returns number of characters in string expression
//var CHAR_LENGTH = jet.CHAR_LENGTH
//
//// OCTET_LENGTH returns number of bytes in string expression
//var OCTET_LENGTH = jet.OCTET_LENGTH
// LOWER returns string expression in lower case
var LOWER = jet.LOWER
// UPPER returns string expression in upper case
var UPPER = jet.UPPER
// LTRIM removes the longest string containing only characters
// from characters (a space by default) from the start of string
var LTRIM = jet.LTRIM
// RTRIM removes the longest string containing only characters
// from characters (a space by default) from the end of string
var RTRIM = jet.RTRIM
// CONCAT adds two or more expressions together
//var CONCAT = jet.CONCAT
// CONCAT_WS adds two or more expressions together with a separator.
//var CONCAT_WS = jet.CONCAT_WS
// FORMAT formats a number to a format like "#,###,###.##", rounded to a specified number of decimal places, then it returns the result as a string.
//var FORMAT = jet.FORMAT
// LEFTSTR returns first n characters in the string.
// When n is negative, return all but last |n| characters.
//func LEFTSTR(str StringExpression, n IntegerExpression) StringExpression {
// return jet.NewStringFunc("LEFTSTR", str, n)
//}
//
//// RIGHT returns last n characters in the string.
//// When n is negative, return all but first |n| characters.
//func RIGHTSTR(str StringExpression, n IntegerExpression) StringExpression {
// return jet.NewStringFunc("RIGHTSTR", str, n)
//}
// LENGTH returns number of characters in string with a given encoding
func LENGTH(str jet.StringExpression) jet.StringExpression {
return jet.LENGTH(str)
}
// LPAD fills up the string to length length by prepending the characters
// fill (a space by default). If the string is already longer than length
// then it is truncated (on the right).
//func LPAD(str jet.StringExpression, length jet.IntegerExpression, text jet.StringExpression) jet.StringExpression {
// return jet.LPAD(str, length, text)
//}
// RPAD fills up the string to length length by appending the characters
// fill (a space by default). If the string is already longer than length then it is truncated.
//func RPAD(str jet.StringExpression, length jet.IntegerExpression, text jet.StringExpression) jet.StringExpression {
// return jet.RPAD(str, length, text)
//}
// MD5 calculates the MD5 hash of string, returning the result in hexadecimal
//var MD5 = jet.MD5
// REPEAT repeats string the specified number of times
//var REPEAT = jet.REPEAT
// REPLACE replaces all occurrences in string of substring from with substring to
var REPLACE = jet.REPLACE
// REVERSE returns reversed string.
var REVERSE = jet.REVERSE
// SUBSTR extracts substring
var SUBSTR = jet.SUBSTR
// REGEXP_LIKE Returns 1 if the string expr matches the regular expression specified by the pattern pat, 0 otherwise.
var REGEXP_LIKE = jet.REGEXP_LIKE
//----------------- Date/Time Functions and Operators ------------//
// CURRENT_DATE returns current date
var CURRENT_DATE = jet.CURRENT_DATE
// CURRENT_TIME returns current time with time zone
func CURRENT_TIME() TimeExpression {
return TimeExp(jet.CURRENT_TIME())
}
// CURRENT_TIMESTAMP returns current timestamp with time zone
func CURRENT_TIMESTAMP() TimestampExpression {
return TimestampExp(jet.CURRENT_TIMESTAMP())
}
//// NOW returns current datetime
//func NOW() DateTimeExpression {
// //if len(fsp) > 0 {
// // return jet.NewTimestampFunc("NOW", jet.FixedLiteral(int64(fsp[0])))
// //}
// //return jet.NewTimestampFunc("NOW")
// return DATETIME(jet.FixedLiteral("now"))
//}
// time-value modifiers
var (
YEARS = modifier("YEARS")
MONTHS = modifier("MONTHS")
DAYS = modifier("DAYS")
HOURS = modifier("HOURS")
MINUTES = modifier("MINUTES")
SECONDS = modifier("SECONDS")
START_OF_YEAR = String("start of year")
START_OF_MONTH = String("start of month")
UNIXEPOCH = String("unixepoch")
LOCALTIME = String("localtime")
UTC = String("UTC")
WEEKDAY = func(value int) Expression {
return String(fmt.Sprintf("WEEKDAY %d", value))
}
)
func modifier(modifierName string) func(value float64) Expression {
return func(value float64) Expression {
return String(fmt.Sprintf("%g %s", value, modifierName))
}
}
// DATE function creates new date from time-value and zero or more time modifiers
func DATE(timeValue interface{}, modifiers ...Expression) DateExpression {
exprList := getFuncExprList(timeValue, modifiers...)
return jet.NewDateFunc("DATE", exprList...)
}
// TIME function creates new time from time-value and zero or more time modifiers
func TIME(timeValue interface{}, modifiers ...Expression) TimeExpression {
exprList := getFuncExprList(timeValue, modifiers...)
return jet.NewTimeFunc("TIME", exprList...)
}
// DATETIME function creates new DateTime from time-value and zero or more time modifiers
func DATETIME(timeValue interface{}, modifiers ...Expression) DateTimeExpression {
exprList := getFuncExprList(timeValue, modifiers...)
return jet.NewTimestampFunc("DATETIME", exprList...)
}
// JULIANDAY returns the number of days since noon in Greenwich on November 24, 4714 B.C
func JULIANDAY(timeValue interface{}, modifiers ...Expression) FloatExpression {
exprList := getFuncExprList(timeValue, modifiers...)
return jet.NewFloatFunc("JULIANDAY", exprList...)
}
// STRFTIME routine returns the date formatted according to the format string specified as the first argument.
func STRFTIME(format StringExpression, timeValue interface{}, modifiers ...Expression) StringExpression {
exprList := append([]Expression{format}, getFuncExprList(timeValue, modifiers...)...)
return jet.NewStringFunc("strftime", exprList...)
}
func getFuncExprList(timeValue interface{}, modifiers ...Expression) []Expression {
return append([]Expression{getTimeValueExpression(timeValue)}, modifiers...)
}
func getTimeValueExpression(timeValue interface{}) Expression {
switch t := timeValue.(type) {
case string:
return String(t)
case Expression:
return t
case time.Time, int64:
return jet.Literal(t)
}
panic(fmt.Sprintf("jet: Invalid time value %T(%q)", timeValue, timeValue))
}
// TIMESTAMP return a datetime value based on the arguments:
func TIMESTAMP(str StringExpression) TimestampExpression {
return jet.NewTimestampFunc("TIMESTAMP", str)
}
// UNIX_TIMESTAMP returns unix timestamp
func UNIX_TIMESTAMP(str StringExpression) TimestampExpression {
return jet.NewTimestampFunc("UNIX_TIMESTAMP", str)
}
//----------- Comparison operators ---------------//
// EXISTS checks for existence of the rows in subQuery
var EXISTS = jet.EXISTS
// CASE create CASE operator with optional list of expressions
var CASE = jet.CASE

117
sqlite/insert_statement.go Normal file
View file

@ -0,0 +1,117 @@
package sqlite
import "github.com/go-jet/jet/v2/internal/jet"
// InsertStatement is interface for SQL INSERT statements
type InsertStatement interface {
Statement
VALUES(value interface{}, values ...interface{}) InsertStatement
MODEL(data interface{}) InsertStatement
MODELS(data interface{}) InsertStatement
QUERY(selectStatement SelectStatement) InsertStatement
DEFAULT_VALUES() InsertStatement
ON_CONFLICT(indexExpressions ...jet.ColumnExpression) onConflict
RETURNING(projections ...Projection) InsertStatement
}
func newInsertStatement(table Table, columns []jet.Column) InsertStatement {
newInsert := &insertStatementImpl{
DefaultValues: jet.ClauseOptional{Name: "DEFAULT VALUES", InNewLine: true},
}
newInsert.SerializerStatement = jet.NewStatementImpl(Dialect, jet.InsertStatementType, newInsert,
&newInsert.Insert,
&newInsert.ValuesQuery,
&newInsert.OnDuplicateKey,
&newInsert.DefaultValues,
&newInsert.OnConflict,
&newInsert.Returning,
)
newInsert.Insert.Table = table
newInsert.Insert.Columns = columns
newInsert.ValuesQuery.SkipSelectWrap = true
return newInsert
}
type insertStatementImpl struct {
jet.SerializerStatement
Insert jet.ClauseInsert
ValuesQuery jet.ClauseValuesQuery
OnDuplicateKey onDuplicateKeyUpdateClause
DefaultValues jet.ClauseOptional
OnConflict onConflictClause
Returning jet.ClauseReturning
}
func (is *insertStatementImpl) VALUES(value interface{}, values ...interface{}) InsertStatement {
is.ValuesQuery.Rows = append(is.ValuesQuery.Rows, jet.UnwindRowFromValues(value, values))
return is
}
// MODEL will insert row of values, where value for each column is extracted from filed of structure data.
// If data is not struct or there is no field for every column selected, this method will panic.
func (is *insertStatementImpl) MODEL(data interface{}) InsertStatement {
is.ValuesQuery.Rows = append(is.ValuesQuery.Rows, jet.UnwindRowFromModel(is.Insert.GetColumns(), data))
return is
}
func (is *insertStatementImpl) MODELS(data interface{}) InsertStatement {
is.ValuesQuery.Rows = append(is.ValuesQuery.Rows, jet.UnwindRowsFromModels(is.Insert.GetColumns(), data)...)
return is
}
func (is *insertStatementImpl) ON_DUPLICATE_KEY_UPDATE(assigments ...ColumnAssigment) InsertStatement {
is.OnDuplicateKey = assigments
return is
}
func (is *insertStatementImpl) QUERY(selectStatement SelectStatement) InsertStatement {
is.ValuesQuery.Query = selectStatement
return is
}
func (is *insertStatementImpl) DEFAULT_VALUES() InsertStatement {
is.DefaultValues.Show = true
return is
}
func (is *insertStatementImpl) RETURNING(projections ...jet.Projection) InsertStatement {
is.Returning.ProjectionList = projections
return is
}
type onDuplicateKeyUpdateClause []jet.ColumnAssigment
// Serialize for SetClause
func (s onDuplicateKeyUpdateClause) Serialize(statementType jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(s) == 0 {
return
}
out.NewLine()
out.WriteString("ON DUPLICATE KEY UPDATE")
out.IncreaseIdent(24)
for i, assigment := range s {
if i > 0 {
out.WriteString(",")
out.NewLine()
}
jet.Serialize(assigment, statementType, out, jet.ShortName.WithFallTrough(options)...)
}
out.DecreaseIdent(24)
}
func (is *insertStatementImpl) ON_CONFLICT(indexExpressions ...jet.ColumnExpression) onConflict {
is.OnConflict = onConflictClause{
insertStatement: is,
indexExpressions: indexExpressions,
}
return &is.OnConflict
}

View file

@ -0,0 +1,150 @@
package sqlite
import (
"github.com/stretchr/testify/require"
"testing"
"time"
)
func TestInvalidInsert(t *testing.T) {
assertStatementSqlErr(t, table1.INSERT(nil).VALUES(1), "jet: nil column in columns list")
}
func TestInsertNilValue(t *testing.T) {
assertStatementSql(t, table1.INSERT(table1Col1).VALUES(nil), `
INSERT INTO db.table1 (col1)
VALUES (?);
`, nil)
}
func TestInsertSingleValue(t *testing.T) {
assertStatementSql(t, table1.INSERT(table1Col1).VALUES(1), `
INSERT INTO db.table1 (col1)
VALUES (?);
`, int(1))
}
func TestInsertWithColumnList(t *testing.T) {
columnList := ColumnList{table3ColInt}
columnList = append(columnList, table3StrCol)
assertStatementSql(t, table3.INSERT(columnList).VALUES(1, 3), `
INSERT INTO db.table3 (col_int, col2)
VALUES (?, ?);
`, 1, 3)
}
func TestInsertDate(t *testing.T) {
date := time.Date(1999, 1, 2, 3, 4, 5, 0, time.UTC)
assertStatementSql(t, table1.INSERT(table1ColTimestamp).VALUES(date), `
INSERT INTO db.table1 (col_timestamp)
VALUES (?);
`, date)
}
func TestInsertMultipleValues(t *testing.T) {
assertStatementSql(t, table1.INSERT(table1Col1, table1ColFloat, table1Col3).VALUES(1, 2, 3), `
INSERT INTO db.table1 (col1, col_float, col3)
VALUES (?, ?, ?);
`, 1, 2, 3)
}
func TestInsertMultipleRows(t *testing.T) {
stmt := table1.INSERT(table1Col1, table1ColFloat).
VALUES(1, 2).
VALUES(11, 22).
VALUES(111, 222)
assertStatementSql(t, stmt, `
INSERT INTO db.table1 (col1, col_float)
VALUES (?, ?),
(?, ?),
(?, ?);
`, 1, 2, 11, 22, 111, 222)
}
func TestInsertValuesFromModel(t *testing.T) {
type Table1Model struct {
Col1 *int
ColFloat float64
}
one := 1
toInsert := Table1Model{
Col1: &one,
ColFloat: 1.11,
}
stmt := table1.INSERT(table1Col1, table1ColFloat).
MODEL(toInsert).
MODEL(&toInsert)
expectedSQL := `
INSERT INTO db.table1 (col1, col_float)
VALUES (?, ?),
(?, ?);
`
assertStatementSql(t, stmt, expectedSQL, int(1), float64(1.11), int(1), float64(1.11))
}
func TestInsertValuesFromModelColumnMismatch(t *testing.T) {
defer func() {
r := recover()
require.Equal(t, r, "missing struct field for column : col1")
}()
type Table1Model struct {
Col1Prim int
Col2 string
}
newData := Table1Model{
Col1Prim: 1,
Col2: "one",
}
table1.
INSERT(table1Col1, table1ColFloat).
MODEL(newData)
}
func TestInsertFromNonStructModel(t *testing.T) {
defer func() {
r := recover()
require.Equal(t, r, "jet: data has to be a struct")
}()
table2.INSERT(table2ColInt).MODEL([]int{})
}
func TestInsert_ON_CONFLICT(t *testing.T) {
stmt := table1.INSERT(table1Col1, table1ColBool).
VALUES("one", "two").
VALUES("1", "2").
VALUES("theta", "beta").
ON_CONFLICT(table1ColBool).WHERE(table1ColBool.IS_NOT_FALSE()).DO_UPDATE(
SET(table1ColBool.SET(Bool(true)),
table2ColInt.SET(Int(1)),
ColumnList{table1Col1, table1ColBool}.SET(ROW(Int(2), String("two"))),
).WHERE(table1Col1.GT(Int(2))),
).
RETURNING(table1Col1, table1ColBool)
assertStatementSql(t, stmt, `
INSERT INTO db.table1 (col1, col_bool)
VALUES (?, ?),
(?, ?),
(?, ?)
ON CONFLICT (col_bool) WHERE col_bool IS NOT FALSE DO UPDATE
SET col_bool = ?,
col_int = ?,
(col1, col_bool) = (?, ?)
WHERE table1.col1 > ?
RETURNING table1.col1 AS "table1.col1",
table1.col_bool AS "table1.col_bool";
`)
}

70
sqlite/literal.go Normal file
View file

@ -0,0 +1,70 @@
package sqlite
import (
"github.com/go-jet/jet/v2/internal/jet"
"time"
)
// Keywords
var (
STAR = jet.STAR
NULL = jet.NULL
)
// Bool creates new bool literal expression
var Bool = jet.Bool
// Int is constructor for 64 bit signed integer expressions literals.
var Int = jet.Int
// Int8 is constructor for 8 bit signed integer expressions literals.
var Int8 = jet.Int8
// Int16 is constructor for 16 bit signed integer expressions literals.
var Int16 = jet.Int16
// Int32 is constructor for 32 bit signed integer expressions literals.
var Int32 = jet.Int32
// Int64 is constructor for 64 bit signed integer expressions literals.
var Int64 = jet.Int
// Uint8 is constructor for 8 bit unsigned integer expressions literals.
var Uint8 = jet.Uint8
// Uint16 is constructor for 16 bit unsigned integer expressions literals.
var Uint16 = jet.Uint16
// Uint32 is constructor for 32 bit unsigned integer expressions literals.
var Uint32 = jet.Uint32
// Uint64 is constructor for 64 bit unsigned integer expressions literals.
var Uint64 = jet.Uint64
// Float creates new float literal expression from float64 value
var Float = jet.Float
// Decimal creates new float literal expression from string value
var Decimal = jet.Decimal
// String creates new string literal expression
var String = jet.String
// UUID is a helper function to create string literal expression from uuid object
// value can be any uuid type with a String method
var UUID = jet.UUID
// Date creates new date literal expression
func Date(year int, month time.Month, day int) DateExpression {
return DATE(jet.Date(year, month, day))
}
// Time creates new time literal expression
func Time(hour, minute, second int, nanoseconds ...time.Duration) TimeExpression {
return TIME(jet.Time(hour, minute, second, nanoseconds...))
}
// DateTime creates new datetime(timestamp) literal expression
func DateTime(year int, month time.Month, day, hour, minute, second int, nanoseconds ...time.Duration) DateTimeExpression {
return DATETIME(jet.Timestamp(year, month, day, hour, minute, second, nanoseconds...))
}

80
sqlite/literal_test.go Normal file
View file

@ -0,0 +1,80 @@
package sqlite
import (
"math"
"testing"
"time"
)
func TestBool(t *testing.T) {
assertSerialize(t, Bool(false), `?`, false)
}
func TestInt(t *testing.T) {
assertSerialize(t, Int(11), `?`, int64(11))
}
func TestInt8(t *testing.T) {
val := int8(math.MinInt8)
assertSerialize(t, Int8(val), `?`, val)
}
func TestInt16(t *testing.T) {
val := int16(math.MinInt16)
assertSerialize(t, Int16(val), `?`, val)
}
func TestInt32(t *testing.T) {
val := int32(math.MinInt32)
assertSerialize(t, Int32(val), `?`, val)
}
func TestInt64(t *testing.T) {
val := int64(math.MinInt64)
assertSerialize(t, Int64(val), `?`, val)
}
func TestUint8(t *testing.T) {
val := uint8(math.MaxUint8)
assertSerialize(t, Uint8(val), `?`, val)
}
func TestUint16(t *testing.T) {
val := uint16(math.MaxUint16)
assertSerialize(t, Uint16(val), `?`, val)
}
func TestUint32(t *testing.T) {
val := uint32(math.MaxUint32)
assertSerialize(t, Uint32(val), `?`, val)
}
func TestUint64(t *testing.T) {
val := uint64(math.MaxUint64)
assertSerialize(t, Uint64(val), `?`, val)
}
func TestFloat(t *testing.T) {
assertSerialize(t, Float(12.34), `?`, float64(12.34))
}
func TestString(t *testing.T) {
assertSerialize(t, String("Some text"), `?`, "Some text")
}
var testTime = time.Now()
func TestDate(t *testing.T) {
assertSerialize(t, Date(2014, time.January, 2), "DATE(?)", "2014-01-02")
assertSerialize(t, DATE(testTime), "DATE(?)", testTime)
}
func TestTime(t *testing.T) {
assertSerialize(t, Time(10, 15, 30), `TIME(?)`, "10:15:30")
assertSerialize(t, TIME(testTime), "TIME(?)", testTime)
}
func TestDateTime(t *testing.T) {
assertSerialize(t, DateTime(2010, time.March, 30, 10, 15, 30), `DATETIME(?)`, "2010-03-30 10:15:30")
assertSerialize(t, DATETIME(testTime), `DATETIME(?)`, testTime)
}

View file

@ -0,0 +1,84 @@
package sqlite
import (
"github.com/go-jet/jet/v2/internal/jet"
)
type onConflict interface {
WHERE(indexPredicate BoolExpression) conflictTarget
conflictTarget
}
type conflictTarget interface {
DO_NOTHING() InsertStatement
DO_UPDATE(action conflictAction) InsertStatement
}
type onConflictClause struct {
insertStatement InsertStatement
indexExpressions []jet.ColumnExpression
whereClause jet.ClauseWhere
do jet.Serializer
}
func (o *onConflictClause) WHERE(indexPredicate BoolExpression) conflictTarget {
o.whereClause.Condition = indexPredicate
return o
}
func (o *onConflictClause) DO_NOTHING() InsertStatement {
o.do = jet.Keyword("DO NOTHING")
return o.insertStatement
}
func (o *onConflictClause) DO_UPDATE(action conflictAction) InsertStatement {
o.do = action
return o.insertStatement
}
func (o *onConflictClause) Serialize(statementType jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(o.indexExpressions) == 0 && o.do == nil {
return
}
out.NewLine()
out.WriteString("ON CONFLICT")
if len(o.indexExpressions) > 0 {
out.WriteString("(")
jet.SerializeColumnExpressionNames(o.indexExpressions, statementType, out, jet.ShortName)
out.WriteString(")")
}
o.whereClause.Serialize(statementType, out, jet.SkipNewLine, jet.ShortName)
out.IncreaseIdent(7)
jet.Serialize(o.do, statementType, out)
out.DecreaseIdent(7)
}
type conflictAction interface {
jet.Serializer
WHERE(condition BoolExpression) conflictAction
}
// SET creates conflict action for ON_CONFLICT clause
func SET(assigments ...ColumnAssigment) conflictAction {
conflictAction := updateConflictActionImpl{}
conflictAction.doUpdate = jet.KeywordClause{Keyword: "DO UPDATE"}
conflictAction.Serializer = jet.NewSerializerClauseImpl(&conflictAction.doUpdate, &conflictAction.set, &conflictAction.where)
conflictAction.set = assigments
return &conflictAction
}
type updateConflictActionImpl struct {
jet.Serializer
doUpdate jet.KeywordClause
set jet.SetClauseNew
where jet.ClauseWhere
}
func (u *updateConflictActionImpl) WHERE(condition BoolExpression) conflictAction {
u.where.Condition = condition
return u
}

9
sqlite/operators.go Normal file
View file

@ -0,0 +1,9 @@
package sqlite
import "github.com/go-jet/jet/v2/internal/jet"
// NOT returns negation of bool expression result
var NOT = jet.NOT
// BIT_NOT inverts every bit in integer expression result
var BIT_NOT = jet.BIT_NOT

186
sqlite/select_statement.go Normal file
View file

@ -0,0 +1,186 @@
package sqlite
import (
"github.com/go-jet/jet/v2/internal/jet"
)
// RowLock is interface for SELECT statement row lock types
type RowLock = jet.RowLock
// Row lock types
var (
UPDATE = jet.NewRowLock("UPDATE")
SHARE = jet.NewRowLock("SHARE")
)
// Window function clauses
var (
PARTITION_BY = jet.PARTITION_BY
ORDER_BY = jet.ORDER_BY
UNBOUNDED = jet.UNBOUNDED
CURRENT_ROW = jet.CURRENT_ROW
)
// PRECEDING window frame clause
func PRECEDING(offset interface{}) jet.FrameExtent {
return jet.PRECEDING(toJetFrameOffset(offset))
}
// FOLLOWING window frame clause
func FOLLOWING(offset interface{}) jet.FrameExtent {
return jet.FOLLOWING(toJetFrameOffset(offset))
}
// Window is used to specify window reference from WINDOW clause
var Window = jet.WindowName
// SelectStatement is interface for MySQL SELECT statement
type SelectStatement interface {
Statement
jet.HasProjections
Expression
DISTINCT() SelectStatement
FROM(tables ...ReadableTable) SelectStatement
WHERE(expression BoolExpression) SelectStatement
GROUP_BY(groupByClauses ...GroupByClause) SelectStatement
HAVING(boolExpression BoolExpression) SelectStatement
WINDOW(name string) windowExpand
ORDER_BY(orderByClauses ...OrderByClause) SelectStatement
LIMIT(limit int64) SelectStatement
OFFSET(offset int64) SelectStatement
FOR(lock RowLock) SelectStatement
LOCK_IN_SHARE_MODE() SelectStatement
UNION(rhs SelectStatement) setStatement
UNION_ALL(rhs SelectStatement) setStatement
AsTable(alias string) SelectTable
}
//SELECT creates new SelectStatement with list of projections
func SELECT(projection Projection, projections ...Projection) SelectStatement {
return newSelectStatement(nil, append([]Projection{projection}, projections...))
}
func newSelectStatement(table ReadableTable, projections []Projection) SelectStatement {
newSelect := &selectStatementImpl{}
newSelect.ExpressionStatement = jet.NewExpressionStatementImpl(Dialect, jet.SelectStatementType, newSelect, &newSelect.Select,
&newSelect.From, &newSelect.Where, &newSelect.GroupBy, &newSelect.Having, &newSelect.Window, &newSelect.OrderBy,
&newSelect.Limit, &newSelect.Offset, &newSelect.For, &newSelect.ShareLock)
newSelect.Select.ProjectionList = projections
if table != nil {
newSelect.From.Tables = []jet.Serializer{table}
}
newSelect.Limit.Count = -1
newSelect.Offset.Count = -1
newSelect.ShareLock.Name = "LOCK IN SHARE MODE"
newSelect.ShareLock.InNewLine = true
newSelect.setOperatorsImpl.parent = newSelect
return newSelect
}
type selectStatementImpl struct {
jet.ExpressionStatement
setOperatorsImpl
Select jet.ClauseSelect
From jet.ClauseFrom
Where jet.ClauseWhere
GroupBy jet.ClauseGroupBy
Having jet.ClauseHaving
Window jet.ClauseWindow
OrderBy jet.ClauseOrderBy
Limit jet.ClauseLimit
Offset jet.ClauseOffset
For jet.ClauseFor
ShareLock jet.ClauseOptional
}
func (s *selectStatementImpl) DISTINCT() SelectStatement {
s.Select.Distinct = true
return s
}
func (s *selectStatementImpl) FROM(tables ...ReadableTable) SelectStatement {
s.From.Tables = nil
for _, table := range tables {
s.From.Tables = append(s.From.Tables, table)
}
return s
}
func (s *selectStatementImpl) WHERE(condition BoolExpression) SelectStatement {
s.Where.Condition = condition
return s
}
func (s *selectStatementImpl) GROUP_BY(groupByClauses ...GroupByClause) SelectStatement {
s.GroupBy.List = groupByClauses
return s
}
func (s *selectStatementImpl) HAVING(boolExpression BoolExpression) SelectStatement {
s.Having.Condition = boolExpression
return s
}
func (s *selectStatementImpl) WINDOW(name string) windowExpand {
s.Window.Definitions = append(s.Window.Definitions, jet.WindowDefinition{Name: name})
return windowExpand{selectStatement: s}
}
func (s *selectStatementImpl) ORDER_BY(orderByClauses ...OrderByClause) SelectStatement {
s.OrderBy.List = orderByClauses
return s
}
func (s *selectStatementImpl) LIMIT(limit int64) SelectStatement {
s.Limit.Count = limit
return s
}
func (s *selectStatementImpl) OFFSET(offset int64) SelectStatement {
s.Offset.Count = offset
return s
}
func (s *selectStatementImpl) FOR(lock RowLock) SelectStatement {
s.For.Lock = lock
return s
}
func (s *selectStatementImpl) LOCK_IN_SHARE_MODE() SelectStatement {
s.ShareLock.Show = true
return s
}
func (s *selectStatementImpl) AsTable(alias string) SelectTable {
return newSelectTable(s, alias)
}
//-----------------------------------------------------
type windowExpand struct {
selectStatement *selectStatementImpl
}
func (w windowExpand) AS(window ...jet.Window) SelectStatement {
if len(window) == 0 {
return w.selectStatement
}
windowsDefinition := w.selectStatement.Window.Definitions
windowsDefinition[len(windowsDefinition)-1].Window = window[0]
return w.selectStatement
}
func toJetFrameOffset(offset interface{}) jet.Serializer {
if offset == UNBOUNDED {
return jet.UNBOUNDED
}
return jet.FixedLiteral(offset)
}

View file

@ -0,0 +1,156 @@
package sqlite
import (
"github.com/go-jet/jet/v2/internal/testutils"
"testing"
)
func TestInvalidSelect(t *testing.T) {
assertStatementSqlErr(t, SELECT(nil), "jet: Projection is nil")
}
func TestSelectColumnList(t *testing.T) {
columnList := ColumnList{table2ColInt, table2ColFloat, table3ColInt}
assertStatementSql(t, SELECT(columnList).FROM(table2), `
SELECT table2.col_int AS "table2.col_int",
table2.col_float AS "table2.col_float",
table3.col_int AS "table3.col_int"
FROM db.table2;
`)
}
func TestSelectLiterals(t *testing.T) {
assertStatementSql(t, SELECT(Int(1), Float(2.2), Bool(false)).FROM(table1), `
SELECT ?,
?,
?
FROM db.table1;
`, int64(1), 2.2, false)
}
func TestSelectDistinct(t *testing.T) {
assertStatementSql(t, SELECT(table1ColBool).DISTINCT().FROM(table1), `
SELECT DISTINCT table1.col_bool AS "table1.col_bool"
FROM db.table1;
`)
}
func TestSelectFrom(t *testing.T) {
assertStatementSql(t, SELECT(table1ColInt, table2ColFloat).FROM(table1), `
SELECT table1.col_int AS "table1.col_int",
table2.col_float AS "table2.col_float"
FROM db.table1;
`)
assertStatementSql(t, SELECT(table1ColInt, table2ColFloat).FROM(table1.INNER_JOIN(table2, table1ColInt.EQ(table2ColInt))), `
SELECT table1.col_int AS "table1.col_int",
table2.col_float AS "table2.col_float"
FROM db.table1
INNER JOIN db.table2 ON (table1.col_int = table2.col_int);
`)
assertStatementSql(t, table1.INNER_JOIN(table2, table1ColInt.EQ(table2ColInt)).SELECT(table1ColInt, table2ColFloat), `
SELECT table1.col_int AS "table1.col_int",
table2.col_float AS "table2.col_float"
FROM db.table1
INNER JOIN db.table2 ON (table1.col_int = table2.col_int);
`)
}
func TestSelectWhere(t *testing.T) {
assertStatementSql(t, SELECT(table1ColInt).FROM(table1).WHERE(Bool(true)), `
SELECT table1.col_int AS "table1.col_int"
FROM db.table1
WHERE ?;
`, true)
assertStatementSql(t, SELECT(table1ColInt).FROM(table1).WHERE(table1ColInt.GT_EQ(Int(10))), `
SELECT table1.col_int AS "table1.col_int"
FROM db.table1
WHERE table1.col_int >= ?;
`, int64(10))
}
func TestSelectGroupBy(t *testing.T) {
assertStatementSql(t, SELECT(table2ColInt).FROM(table2).GROUP_BY(table2ColFloat), `
SELECT table2.col_int AS "table2.col_int"
FROM db.table2
GROUP BY table2.col_float;
`)
}
func TestSelectHaving(t *testing.T) {
assertStatementSql(t, SELECT(table3ColInt).FROM(table3).HAVING(table1ColBool.EQ(Bool(true))), `
SELECT table3.col_int AS "table3.col_int"
FROM db.table3
HAVING table1.col_bool = ?;
`, true)
}
func TestSelectOrderBy(t *testing.T) {
assertStatementSql(t, SELECT(table2ColFloat).FROM(table2).ORDER_BY(table2ColInt.DESC()), `
SELECT table2.col_float AS "table2.col_float"
FROM db.table2
ORDER BY table2.col_int DESC;
`)
assertStatementSql(t, SELECT(table2ColFloat).FROM(table2).ORDER_BY(table2ColInt.DESC(), table2ColInt.ASC()), `
SELECT table2.col_float AS "table2.col_float"
FROM db.table2
ORDER BY table2.col_int DESC, table2.col_int ASC;
`)
}
func TestSelectLimitOffset(t *testing.T) {
assertStatementSql(t, SELECT(table2ColInt).FROM(table2).LIMIT(10), `
SELECT table2.col_int AS "table2.col_int"
FROM db.table2
LIMIT ?;
`, int64(10))
assertStatementSql(t, SELECT(table2ColInt).FROM(table2).LIMIT(10).OFFSET(2), `
SELECT table2.col_int AS "table2.col_int"
FROM db.table2
LIMIT ?
OFFSET ?;
`, int64(10), int64(2))
}
func TestSelectLock(t *testing.T) {
testutils.AssertStatementSql(t, SELECT(table1ColBool).FROM(table1).FOR(UPDATE()), `
SELECT table1.col_bool AS "table1.col_bool"
FROM db.table1
FOR UPDATE;
`)
testutils.AssertStatementSql(t, SELECT(table1ColBool).FROM(table1).FOR(SHARE().NOWAIT()), `
SELECT table1.col_bool AS "table1.col_bool"
FROM db.table1
FOR SHARE NOWAIT;
`)
}
func TestSelect_LOCK_IN_SHARE_MODE(t *testing.T) {
testutils.AssertStatementSql(t, SELECT(table1ColBool).FROM(table1).LOCK_IN_SHARE_MODE(), `
SELECT table1.col_bool AS "table1.col_bool"
FROM db.table1
LOCK IN SHARE MODE;
`)
}
func TestSelect_NOT_EXISTS(t *testing.T) {
testutils.AssertStatementSql(t,
SELECT(table1ColInt).
FROM(table1).
WHERE(
NOT(EXISTS(
SELECT(table2ColInt).
FROM(table2).
WHERE(
table1ColInt.EQ(table2ColInt),
),
))), `
SELECT table1.col_int AS "table1.col_int"
FROM db.table1
WHERE (NOT (EXISTS (
SELECT table2.col_int AS "table2.col_int"
FROM db.table2
WHERE table1.col_int = table2.col_int
)));
`)
}

24
sqlite/select_table.go Normal file
View file

@ -0,0 +1,24 @@
package sqlite
import "github.com/go-jet/jet/v2/internal/jet"
// SelectTable is interface for MySQL sub-queries
type SelectTable interface {
readableTable
jet.SelectTable
}
type selectTableImpl struct {
jet.SelectTable
readableTableInterfaceImpl
}
func newSelectTable(selectStmt jet.SerializerStatement, alias string) SelectTable {
subQuery := &selectTableImpl{
SelectTable: jet.NewSelectTable(selectStmt, alias),
}
subQuery.readableTableInterfaceImpl.parent = subQuery
return subQuery
}

99
sqlite/set_statement.go Normal file
View file

@ -0,0 +1,99 @@
package sqlite
import "github.com/go-jet/jet/v2/internal/jet"
// UNION effectively appends the result of sub-queries(select statements) into single query.
// It eliminates duplicate rows from its result.
func UNION(lhs, rhs jet.SerializerStatement, selects ...jet.SerializerStatement) setStatement {
return newSetStatementImpl(union, false, toSelectList(lhs, rhs, selects...))
}
// UNION_ALL effectively appends the result of sub-queries(select statements) into single query.
// It does not eliminates duplicate rows from its result.
func UNION_ALL(lhs, rhs jet.SerializerStatement, selects ...jet.SerializerStatement) setStatement {
return newSetStatementImpl(union, true, toSelectList(lhs, rhs, selects...))
}
type setStatement interface {
setOperators
ORDER_BY(orderByClauses ...OrderByClause) setStatement
LIMIT(limit int64) setStatement
OFFSET(offset int64) setStatement
AsTable(alias string) SelectTable
}
type setOperators interface {
jet.Statement
jet.HasProjections
jet.Expression
UNION(rhs SelectStatement) setStatement
UNION_ALL(rhs SelectStatement) setStatement
}
type setOperatorsImpl struct {
parent setOperators
}
func (s *setOperatorsImpl) UNION(rhs SelectStatement) setStatement {
return UNION(s.parent, rhs)
}
func (s *setOperatorsImpl) UNION_ALL(rhs SelectStatement) setStatement {
return UNION_ALL(s.parent, rhs)
}
type setStatementImpl struct {
jet.ExpressionStatement
setOperatorsImpl
setOperator jet.ClauseSetStmtOperator
}
func newSetStatementImpl(operator string, all bool, selects []jet.SerializerStatement) setStatement {
newSetStatement := &setStatementImpl{}
newSetStatement.ExpressionStatement = jet.NewExpressionStatementImpl(Dialect, jet.SetStatementType, newSetStatement,
&newSetStatement.setOperator)
newSetStatement.setOperator.Operator = operator
newSetStatement.setOperator.All = all
newSetStatement.setOperator.Selects = selects
newSetStatement.setOperator.Limit.Count = -1
newSetStatement.setOperator.Offset.Count = -1
newSetStatement.setOperator.SkipSelectWrap = true
newSetStatement.setOperatorsImpl.parent = newSetStatement
return newSetStatement
}
func (s *setStatementImpl) ORDER_BY(orderByClauses ...OrderByClause) setStatement {
s.setOperator.OrderBy.List = orderByClauses
return s
}
func (s *setStatementImpl) LIMIT(limit int64) setStatement {
s.setOperator.Limit.Count = limit
return s
}
func (s *setStatementImpl) OFFSET(offset int64) setStatement {
s.setOperator.Offset.Count = offset
return s
}
func (s *setStatementImpl) AsTable(alias string) SelectTable {
return newSelectTable(s, alias)
}
const (
union = "UNION"
)
func toSelectList(lhs, rhs jet.SerializerStatement, selects ...jet.SerializerStatement) []jet.SerializerStatement {
return append([]jet.SerializerStatement{lhs, rhs}, selects...)
}

View file

@ -0,0 +1,31 @@
package sqlite
import (
"testing"
)
func TestSelectSets(t *testing.T) {
select1 := SELECT(table1ColBool).FROM(table1)
select2 := SELECT(table2ColBool).FROM(table2)
assertStatementSql(t, select1.UNION(select2), `
SELECT table1.col_bool AS "table1.col_bool"
FROM db.table1
UNION
SELECT table2.col_bool AS "table2.col_bool"
FROM db.table2;
`)
assertStatementSql(t, select1.UNION_ALL(select2), `
SELECT table1.col_bool AS "table1.col_bool"
FROM db.table1
UNION ALL
SELECT table2.col_bool AS "table2.col_bool"
FROM db.table2;
`)
}

8
sqlite/statement.go Normal file
View file

@ -0,0 +1,8 @@
package sqlite
import "github.com/go-jet/jet/v2/internal/jet"
// RawStatement creates new sql statements from raw query and optional map of named arguments
func RawStatement(rawQuery string, namedArguments ...RawArgs) Statement {
return jet.RawStatement(Dialect, rawQuery, namedArguments...)
}

122
sqlite/table.go Normal file
View file

@ -0,0 +1,122 @@
package sqlite
import "github.com/go-jet/jet/v2/internal/jet"
// Table is interface for MySQL tables
type Table interface {
jet.SerializerTable
readableTable
INSERT(columns ...jet.Column) InsertStatement
UPDATE(columns ...jet.Column) UpdateStatement
DELETE() DeleteStatement
}
type readableTable interface {
// Generates a select query on the current tableName.
SELECT(projection Projection, projections ...Projection) SelectStatement
// Creates a inner join tableName Expression using onCondition.
INNER_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable
// Creates a left join tableName Expression using onCondition.
LEFT_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable
// Creates a right join tableName Expression using onCondition.
RIGHT_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable
// Creates a full join tableName Expression using onCondition.
FULL_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable
// Creates a cross join tableName Expression using onCondition.
CROSS_JOIN(table ReadableTable) joinSelectUpdateTable
}
type joinSelectUpdateTable interface {
ReadableTable
UPDATE(columns ...jet.Column) UpdateStatement
}
// ReadableTable interface
type ReadableTable interface {
readableTable
jet.Serializer
}
type readableTableInterfaceImpl struct {
parent ReadableTable
}
// Generates a select query on the current tableName.
func (r readableTableInterfaceImpl) SELECT(projection1 Projection, projections ...Projection) SelectStatement {
return newSelectStatement(r.parent, append([]Projection{projection1}, projections...))
}
// Creates a inner join tableName Expression using onCondition.
func (r readableTableInterfaceImpl) INNER_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable {
return newJoinTable(r.parent, table, jet.InnerJoin, onCondition)
}
// Creates a left join tableName Expression using onCondition.
func (r readableTableInterfaceImpl) LEFT_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable {
return newJoinTable(r.parent, table, jet.LeftJoin, onCondition)
}
// Creates a right join tableName Expression using onCondition.
func (r readableTableInterfaceImpl) RIGHT_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable {
return newJoinTable(r.parent, table, jet.RightJoin, onCondition)
}
func (r readableTableInterfaceImpl) FULL_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable {
return newJoinTable(r.parent, table, jet.FullJoin, onCondition)
}
func (r readableTableInterfaceImpl) CROSS_JOIN(table ReadableTable) joinSelectUpdateTable {
return newJoinTable(r.parent, table, jet.CrossJoin, nil)
}
// NewTable creates new table with schema Name, table Name and list of columns
func NewTable(schemaName, name, alias string, columns ...jet.ColumnExpression) Table {
t := &tableImpl{
SerializerTable: jet.NewTable(schemaName, name, alias, columns...),
}
t.readableTableInterfaceImpl.parent = t
t.parent = t
return t
}
type tableImpl struct {
jet.SerializerTable
readableTableInterfaceImpl
parent Table
}
func (t *tableImpl) INSERT(columns ...jet.Column) InsertStatement {
return newInsertStatement(t.parent, jet.UnwidColumnList(columns))
}
func (t *tableImpl) UPDATE(columns ...jet.Column) UpdateStatement {
return newUpdateStatement(t.parent, jet.UnwidColumnList(columns))
}
func (t *tableImpl) DELETE() DeleteStatement {
return newDeleteStatement(t.parent)
}
type joinTable struct {
tableImpl
jet.JoinTable
}
func newJoinTable(lhs jet.Serializer, rhs jet.Serializer, joinType jet.JoinType, onCondition BoolExpression) Table {
newJoinTable := &joinTable{
JoinTable: jet.NewJoinTable(lhs, rhs, joinType, onCondition),
}
newJoinTable.readableTableInterfaceImpl.parent = newJoinTable
newJoinTable.parent = newJoinTable
return newJoinTable
}

101
sqlite/table_test.go Normal file
View file

@ -0,0 +1,101 @@
package sqlite
import (
"testing"
)
func TestJoinNilInputs(t *testing.T) {
assertSerializeErr(t, table2.INNER_JOIN(nil, table1ColBool.EQ(table2ColBool)),
"jet: right hand side of join operation is nil table")
assertSerializeErr(t, table2.INNER_JOIN(table1, nil),
"jet: join condition is nil")
}
func TestINNER_JOIN(t *testing.T) {
assertSerialize(t, table1.
INNER_JOIN(table2, table1ColInt.EQ(table2ColInt)),
`db.table1
INNER JOIN db.table2 ON (table1.col_int = table2.col_int)`)
assertSerialize(t, table1.
INNER_JOIN(table2, table1ColInt.EQ(table2ColInt)).
INNER_JOIN(table3, table1ColInt.EQ(table3ColInt)),
`db.table1
INNER JOIN db.table2 ON (table1.col_int = table2.col_int)
INNER JOIN db.table3 ON (table1.col_int = table3.col_int)`)
assertSerialize(t, table1.
INNER_JOIN(table2, table1ColInt.EQ(Int(1))).
INNER_JOIN(table3, table1ColInt.EQ(Int(2))),
`db.table1
INNER JOIN db.table2 ON (table1.col_int = ?)
INNER JOIN db.table3 ON (table1.col_int = ?)`, int64(1), int64(2))
}
func TestLEFT_JOIN(t *testing.T) {
assertSerialize(t, table1.
LEFT_JOIN(table2, table1ColInt.EQ(table2ColInt)),
`db.table1
LEFT JOIN db.table2 ON (table1.col_int = table2.col_int)`)
assertSerialize(t, table1.
LEFT_JOIN(table2, table1ColInt.EQ(table2ColInt)).
LEFT_JOIN(table3, table1ColInt.EQ(table3ColInt)),
`db.table1
LEFT JOIN db.table2 ON (table1.col_int = table2.col_int)
LEFT JOIN db.table3 ON (table1.col_int = table3.col_int)`)
assertSerialize(t, table1.
LEFT_JOIN(table2, table1ColInt.EQ(Int(1))).
LEFT_JOIN(table3, table1ColInt.EQ(Int(2))),
`db.table1
LEFT JOIN db.table2 ON (table1.col_int = ?)
LEFT JOIN db.table3 ON (table1.col_int = ?)`, int64(1), int64(2))
}
func TestRIGHT_JOIN(t *testing.T) {
assertSerialize(t, table1.
RIGHT_JOIN(table2, table1ColInt.EQ(table2ColInt)),
`db.table1
RIGHT JOIN db.table2 ON (table1.col_int = table2.col_int)`)
assertSerialize(t, table1.
RIGHT_JOIN(table2, table1ColInt.EQ(table2ColInt)).
RIGHT_JOIN(table3, table1ColInt.EQ(table3ColInt)),
`db.table1
RIGHT JOIN db.table2 ON (table1.col_int = table2.col_int)
RIGHT JOIN db.table3 ON (table1.col_int = table3.col_int)`)
assertSerialize(t, table1.
RIGHT_JOIN(table2, table1ColInt.EQ(Int(1))).
RIGHT_JOIN(table3, table1ColInt.EQ(Int(2))),
`db.table1
RIGHT JOIN db.table2 ON (table1.col_int = ?)
RIGHT JOIN db.table3 ON (table1.col_int = ?)`, int64(1), int64(2))
}
func TestFULL_JOIN(t *testing.T) {
assertSerialize(t, table1.
FULL_JOIN(table2, table1ColInt.EQ(table2ColInt)),
`db.table1
FULL JOIN db.table2 ON (table1.col_int = table2.col_int)`)
assertSerialize(t, table1.
FULL_JOIN(table2, table1ColInt.EQ(table2ColInt)).
FULL_JOIN(table3, table1ColInt.EQ(table3ColInt)),
`db.table1
FULL JOIN db.table2 ON (table1.col_int = table2.col_int)
FULL JOIN db.table3 ON (table1.col_int = table3.col_int)`)
assertSerialize(t, table1.
FULL_JOIN(table2, table1ColInt.EQ(Int(1))).
FULL_JOIN(table3, table1ColInt.EQ(Int(2))),
`db.table1
FULL JOIN db.table2 ON (table1.col_int = ?)
FULL JOIN db.table3 ON (table1.col_int = ?)`, int64(1), int64(2))
}
func TestCROSS_JOIN(t *testing.T) {
assertSerialize(t, table1.
CROSS_JOIN(table2),
`db.table1
CROSS JOIN db.table2`)
assertSerialize(t, table1.
CROSS_JOIN(table2).
CROSS_JOIN(table3),
`db.table1
CROSS JOIN db.table2
CROSS JOIN db.table3`)
}

27
sqlite/types.go Normal file
View file

@ -0,0 +1,27 @@
package sqlite
import "github.com/go-jet/jet/v2/internal/jet"
// Statement is common interface for all statements(SELECT, INSERT, UPDATE, DELETE, LOCK)
type Statement = jet.Statement
// Projection is interface for all projection types. Types that can be part of, for instance SELECT clause.
type Projection = jet.Projection
// ProjectionList can be used to create conditional constructed projection list.
type ProjectionList = jet.ProjectionList
// ColumnAssigment is interface wrapper around column assigment
type ColumnAssigment = jet.ColumnAssigment
// PrintableStatement is a statement which sql query can be logged
type PrintableStatement = jet.PrintableStatement
// OrderByClause is the combination of an expression and the wanted ordering to use as input for ORDER BY.
type OrderByClause = jet.OrderByClause
// GroupByClause interface to use as input for GROUP_BY
type GroupByClause = jet.GroupByClause
// SetLogger sets automatic statement logging
var SetLogger = jet.SetLoggerFunc

View file

@ -0,0 +1,70 @@
package sqlite
import "github.com/go-jet/jet/v2/internal/jet"
// UpdateStatement is interface of SQL UPDATE statement
type UpdateStatement interface {
jet.Statement
SET(value interface{}, values ...interface{}) UpdateStatement
MODEL(data interface{}) UpdateStatement
WHERE(expression BoolExpression) UpdateStatement
RETURNING(projections ...jet.Projection) UpdateStatement
}
type updateStatementImpl struct {
jet.SerializerStatement
Update jet.ClauseUpdate
Set jet.SetClause
SetNew jet.SetClauseNew
Where jet.ClauseWhere
Returning jet.ClauseReturning
}
func newUpdateStatement(table Table, columns []jet.Column) UpdateStatement {
update := &updateStatementImpl{}
update.SerializerStatement = jet.NewStatementImpl(Dialect, jet.UpdateStatementType, update,
&update.Update,
&update.Set,
&update.SetNew,
&update.Where,
&update.Returning)
update.Update.Table = table
update.Set.Columns = columns
update.Where.Mandatory = true
return update
}
func (u *updateStatementImpl) SET(value interface{}, values ...interface{}) UpdateStatement {
columnAssigment, isColumnAssigment := value.(ColumnAssigment)
if isColumnAssigment {
u.SetNew = []ColumnAssigment{columnAssigment}
for _, value := range values {
u.SetNew = append(u.SetNew, value.(ColumnAssigment))
}
} else {
u.Set.Values = jet.UnwindRowFromValues(value, values)
}
return u
}
func (u *updateStatementImpl) MODEL(data interface{}) UpdateStatement {
u.Set.Values = jet.UnwindRowFromModel(u.Set.Columns, data)
return u
}
func (u *updateStatementImpl) WHERE(expression BoolExpression) UpdateStatement {
u.Where.Condition = expression
return u
}
func (u *updateStatementImpl) RETURNING(projections ...jet.Projection) UpdateStatement {
u.Returning.ProjectionList = projections
return u
}

View file

@ -0,0 +1,82 @@
package sqlite
import (
"fmt"
"strings"
"testing"
)
func TestUpdateWithOneValue(t *testing.T) {
expectedSQL := `
UPDATE db.table1
SET col_int = ?
WHERE table1.col_int >= ?;
`
stmt := table1.UPDATE(table1ColInt).
SET(1).
WHERE(table1ColInt.GT_EQ(Int(33)))
fmt.Println(stmt.Sql())
assertStatementSql(t, stmt, expectedSQL, 1, int64(33))
}
func TestUpdateWithValues(t *testing.T) {
expectedSQL := `
UPDATE db.table1
SET col_int = ?,
col_float = ?
WHERE table1.col_int >= ?;
`
stmt := table1.UPDATE(table1ColInt, table1ColFloat).
SET(1, 22.2).
WHERE(table1ColInt.GT_EQ(Int(33)))
fmt.Println(stmt.Sql())
assertStatementSql(t, stmt, expectedSQL, 1, 22.2, int64(33))
}
func TestUpdateOneColumnWithSelect(t *testing.T) {
expectedSQL := `
UPDATE db.table1
SET col_float = (
SELECT table1.col_float AS "table1.col_float"
FROM db.table1
)
WHERE table1.col1 = ?;
`
stmt := table1.
UPDATE(table1ColFloat).
SET(
table1.SELECT(table1ColFloat),
).
WHERE(table1Col1.EQ(Int(2)))
assertStatementSql(t, stmt, expectedSQL, int64(2))
}
func TestUpdateReservedWorldColumn(t *testing.T) {
type table struct {
Load string
}
loadColumn := StringColumn("Load")
assertStatementSql(t,
table1.UPDATE(loadColumn).
MODEL(
table{
Load: "foo",
},
).
WHERE(loadColumn.EQ(String("bar"))), strings.Replace(`
UPDATE db.table1
SET ''Load'' = ?
WHERE ''Load'' = ?;
`, "''", "`", -1), "foo", "bar")
}
func TestInvalidInputs(t *testing.T) {
assertStatementSqlErr(t, table1.UPDATE(table1ColInt).SET(1), "jet: WHERE clause not set")
assertStatementSqlErr(t, table1.UPDATE(nil).SET(1), "jet: nil column in columns list for SET clause")
}

55
sqlite/utils_test.go Normal file
View file

@ -0,0 +1,55 @@
package sqlite
import (
"github.com/go-jet/jet/v2/internal/jet"
"github.com/go-jet/jet/v2/internal/testutils"
"testing"
)
var table1Col1 = IntegerColumn("col1")
var table1ColBool = BoolColumn("col_bool")
var table1ColInt = IntegerColumn("col_int")
var table1ColFloat = FloatColumn("col_float")
var table1ColString = StringColumn("col_string")
var table1Col3 = IntegerColumn("col3")
var table1ColTimestamp = TimestampColumn("col_timestamp")
var table1ColDate = DateColumn("col_date")
var table1ColTime = TimeColumn("col_time")
var table1 = NewTable("db", "table1", "", table1Col1, table1ColInt, table1ColFloat, table1ColString, table1Col3, table1ColBool, table1ColDate, table1ColTimestamp, table1ColTime)
var table2Col3 = IntegerColumn("col3")
var table2Col4 = IntegerColumn("col4")
var table2ColInt = IntegerColumn("col_int")
var table2ColFloat = FloatColumn("col_float")
var table2ColStr = StringColumn("col_str")
var table2ColBool = BoolColumn("col_bool")
var table2ColTimestamp = TimestampColumn("col_timestamp")
var table2ColDate = DateColumn("col_date")
var table2 = NewTable("db", "table2", "", table2Col3, table2Col4, table2ColInt, table2ColFloat, table2ColStr, table2ColBool, table2ColDate, table2ColTimestamp)
var table3Col1 = IntegerColumn("col1")
var table3ColInt = IntegerColumn("col_int")
var table3StrCol = StringColumn("col2")
var table3 = NewTable("db", "table3", "", table3Col1, table3ColInt, table3StrCol)
func assertSerialize(t *testing.T, clause jet.Serializer, query string, args ...interface{}) {
testutils.AssertSerialize(t, Dialect, clause, query, args...)
}
func assertDebugSerialize(t *testing.T, clause jet.Serializer, query string, args ...interface{}) {
testutils.AssertDebugSerialize(t, Dialect, clause, query, args...)
}
func assertSerializeErr(t *testing.T, clause jet.Serializer, errString string) {
testutils.AssertSerializeErr(t, Dialect, clause, errString)
}
func assertProjectionSerialize(t *testing.T, projection jet.Projection, query string, args ...interface{}) {
testutils.AssertProjectionSerialize(t, Dialect, projection, query, args...)
}
var assertPanicErr = testutils.AssertPanicErr
var assertStatementSql = testutils.AssertStatementSql
var assertStatementSqlErr = testutils.AssertStatementSqlErr

26
sqlite/with_statement.go Normal file
View file

@ -0,0 +1,26 @@
package sqlite
import "github.com/go-jet/jet/v2/internal/jet"
// CommonTableExpression contains information about a CTE.
type CommonTableExpression struct {
readableTableInterfaceImpl
jet.CommonTableExpression
}
// WITH function creates new WITH statement from list of common table expressions
func WITH(cte ...jet.CommonTableExpressionDefinition) func(statement jet.Statement) Statement {
return jet.WITH(Dialect, cte...)
}
// CTE creates new named CommonTableExpression
func CTE(name string) CommonTableExpression {
cte := CommonTableExpression{
readableTableInterfaceImpl: readableTableInterfaceImpl{},
CommonTableExpression: jet.CTE(name),
}
cte.parent = &cte
return cte
}

View file

@ -1,6 +1,9 @@
package dbconfig package dbconfig
import "fmt" import (
"fmt"
"github.com/go-jet/jet/v2/tests/internal/utils/repo"
)
// Postgres test database connection parameters // Postgres test database connection parameters
const ( const (
@ -24,3 +27,10 @@ const (
// MySQLConnectionString is MySQL driver connection string to test database // MySQLConnectionString is MySQL driver connection string to test database
var MySQLConnectionString = fmt.Sprintf("%s:%s@tcp(%s:%d)/", MySQLUser, MySQLPassword, MySqLHost, MySQLPort) var MySQLConnectionString = fmt.Sprintf("%s:%s@tcp(%s:%d)/", MySQLUser, MySQLPassword, MySqLHost, MySQLPort)
// sqllite
var (
SakilaDBPath = repo.GetTestDataFilePath("/init/sqlite/sakila.db")
ChinookDBPath = repo.GetTestDataFilePath("/init/sqlite/chinook.db")
TestSampleDBPath = repo.GetTestDataFilePath("/init/sqlite/test_sample.db")
)

View file

@ -4,6 +4,8 @@ import (
"database/sql" "database/sql"
"flag" "flag"
"fmt" "fmt"
"github.com/go-jet/jet/v2/generator/sqlite"
"github.com/go-jet/jet/v2/tests/internal/utils/repo"
"io/ioutil" "io/ioutil"
"os" "os"
"os/exec" "os/exec"
@ -15,6 +17,8 @@ import (
"github.com/go-jet/jet/v2/tests/dbconfig" "github.com/go-jet/jet/v2/tests/dbconfig"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq" _ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
) )
var testSuite string var testSuite string
@ -39,8 +43,23 @@ func main() {
return return
} }
if testSuite == "sqlite" {
initSQLiteDB()
return
}
initMySQLDB() initMySQLDB()
initPostgresDB() initPostgresDB()
initSQLiteDB()
}
func initSQLiteDB() {
err := sqlite.GenerateDSN(dbconfig.SakilaDBPath, repo.GetTestsFilePath("./.gentestdata/sqlite/sakila"))
throw.OnError(err)
err = sqlite.GenerateDSN(dbconfig.ChinookDBPath, repo.GetTestsFilePath("./.gentestdata/sqlite/chinook"))
throw.OnError(err)
err = sqlite.GenerateDSN(dbconfig.TestSampleDBPath, repo.GetTestsFilePath("./.gentestdata/sqlite/test_sample"))
throw.OnError(err)
} }
func initMySQLDB() { func initMySQLDB() {

View file

@ -0,0 +1,33 @@
package repo
import (
"os/exec"
"path/filepath"
"strings"
)
// GetRootDirPath will return this repo full dir path
func GetRootDirPath() string {
cmd := exec.Command("git", "rev-parse", "--show-toplevel")
byteArr, err := cmd.Output()
if err != nil {
panic(err)
}
return strings.TrimSpace(string(byteArr))
}
// GetTestsDirPath will return tests folder full path
func GetTestsDirPath() string {
return filepath.Join(GetRootDirPath(), "tests")
}
// GetTestsFilePath will return full file path of the file in the tests folder
func GetTestsFilePath(subPath string) string {
return filepath.Join(GetTestsDirPath(), subPath)
}
// GetTestDataFilePath will return full file path of the file in the testdata folder
func GetTestDataFilePath(subPath string) string {
return filepath.Join(GetTestsDirPath(), "testdata", subPath)
}

View file

@ -0,0 +1,912 @@
package sqlite
import (
"github.com/go-jet/jet/v2/internal/testutils"
. "github.com/go-jet/jet/v2/sqlite"
"github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/test_sample/model"
. "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/test_sample/table"
"github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/test_sample/view"
"github.com/go-jet/jet/v2/tests/testdata/results/common"
"github.com/google/uuid"
"github.com/shopspring/decimal"
"github.com/stretchr/testify/require"
"strings"
"testing"
"time"
)
func TestAllTypes(t *testing.T) {
dest := []model.AllTypes{}
err := SELECT(AllTypes.AllColumns).
FROM(AllTypes).
Query(sampleDB, &dest)
require.NoError(t, err)
testutils.AssertJSON(t, dest, allTypesJSON)
}
var allTypesJSON = `
[
{
"Boolean": false,
"BooleanPtr": true,
"TinyInt": -3,
"TinyIntPtr": 3,
"SmallInt": 14,
"SmallIntPtr": 14,
"MediumInt": -150,
"MediumIntPtr": 150,
"Integer": -1600,
"IntegerPtr": 1600,
"BigInt": 5000,
"BigIntPtr": 50000,
"Decimal": 1.11,
"DecimalPtr": 1.01,
"Numeric": 2.22,
"NumericPtr": 2.02,
"Float": 3.33,
"FloatPtr": 3.03,
"Double": 4.44,
"DoublePtr": 4.04,
"Real": 5.55,
"RealPtr": 5.05,
"Time": "0000-01-01T10:11:12.33Z",
"TimePtr": "0000-01-01T10:11:12.123456Z",
"Date": "2008-07-04T00:00:00Z",
"DatePtr": "2008-07-04T00:00:00Z",
"DateTime": "2011-12-18T13:17:17Z",
"DateTimePtr": "2011-12-18T13:17:17Z",
"Timestamp": "2007-12-31T23:00:01Z",
"TimestampPtr": "2007-12-31T23:00:01Z",
"Char": "char1",
"CharPtr": "char-ptr",
"VarChar": "varchar",
"VarCharPtr": "varchar-ptr",
"Text": "text",
"TextPtr": "text-ptr",
"Blob": "YmxvYjE=",
"BlobPtr": "YmxvYi1wdHI="
},
{
"Boolean": false,
"BooleanPtr": null,
"TinyInt": -3,
"TinyIntPtr": null,
"SmallInt": 14,
"SmallIntPtr": null,
"MediumInt": -150,
"MediumIntPtr": null,
"Integer": -1600,
"IntegerPtr": null,
"BigInt": 5000,
"BigIntPtr": null,
"Decimal": 1.11,
"DecimalPtr": null,
"Numeric": 2.22,
"NumericPtr": null,
"Float": 3.33,
"FloatPtr": null,
"Double": 4.44,
"DoublePtr": null,
"Real": 5.55,
"RealPtr": null,
"Time": "0000-01-01T10:11:12.33Z",
"TimePtr": null,
"Date": "2008-07-04T00:00:00Z",
"DatePtr": null,
"DateTime": "2011-12-18T13:17:17Z",
"DateTimePtr": null,
"Timestamp": "2007-12-31T23:00:01Z",
"TimestampPtr": null,
"Char": "char2",
"CharPtr": null,
"VarChar": "varchar",
"VarCharPtr": null,
"Text": "text",
"TextPtr": null,
"Blob": "YmxvYjI=",
"BlobPtr": null
}
]
`
func TestAllTypesViewSelect(t *testing.T) {
var dest []model.AllTypesView
stmt := SELECT(view.AllTypesView.AllColumns).
FROM(view.AllTypesView)
err := stmt.Query(sampleDB, &dest)
require.NoError(t, err)
require.Equal(t, len(dest), 2)
testutils.AssertJSON(t, dest, allTypesJSON)
}
func TestAllTypesInsert(t *testing.T) {
tx := beginSampleDBTx(t)
stmt := AllTypes.INSERT(AllTypes.AllColumns).
MODEL(toInsert).
RETURNING(AllTypes.AllColumns)
var inserted model.AllTypes
err := stmt.Query(tx, &inserted)
require.NoError(t, err)
testutils.AssertDeepEqual(t, toInsert, inserted, testutils.UnixTimeComparer)
var dest model.AllTypes
err = AllTypes.SELECT(AllTypes.AllColumns).
WHERE(AllTypes.BigInt.EQ(Int(toInsert.BigInt))).
Query(tx, &dest)
require.NoError(t, err)
testutils.AssertDeepEqual(t, dest, toInsert, testutils.UnixTimeComparer)
err = tx.Rollback()
require.NoError(t, err)
}
var toInsert = model.AllTypes{
Boolean: false,
BooleanPtr: testutils.BoolPtr(true),
TinyInt: 1,
SmallInt: 3,
MediumInt: 5,
Integer: 7,
BigInt: 9,
TinyIntPtr: testutils.Int8Ptr(11),
SmallIntPtr: testutils.Int16Ptr(33),
MediumIntPtr: testutils.Int32Ptr(55),
IntegerPtr: testutils.Int32Ptr(77),
BigIntPtr: testutils.Int64Ptr(99),
Decimal: 11.22,
DecimalPtr: testutils.Float64Ptr(33.44),
Numeric: 55.66,
NumericPtr: testutils.Float64Ptr(77.88),
Float: 99.00,
FloatPtr: testutils.Float64Ptr(11.22),
Double: 33.44,
DoublePtr: testutils.Float64Ptr(55.66),
Real: 77.88,
RealPtr: testutils.Float32Ptr(99.00),
Time: time.Date(1, 1, 1, 1, 1, 1, 10, time.UTC),
TimePtr: testutils.TimePtr(time.Date(2, 2, 2, 2, 2, 2, 200, time.UTC)),
Date: time.Now(),
DatePtr: testutils.TimePtr(time.Now()),
DateTime: time.Now(),
DateTimePtr: testutils.TimePtr(time.Now()),
Timestamp: time.Now(),
TimestampPtr: testutils.TimePtr(time.Now()),
Char: "abcd",
CharPtr: testutils.StringPtr("absd"),
VarChar: "abcd",
VarCharPtr: testutils.StringPtr("absd"),
Blob: []byte("large file"),
BlobPtr: testutils.ByteArrayPtr([]byte("very large file")),
Text: "some text",
TextPtr: testutils.StringPtr("text"),
}
func TestUUID(t *testing.T) {
query := SELECT(
//Raw("uuid()").AS("uuid"),
String("dc8daae3-b83b-11e9-8eb4-98ded00c39c6").AS("str_uuid"),
)
var dest struct {
UUID uuid.UUID
StrUUID *uuid.UUID
}
err := query.Query(sampleDB, &dest)
require.NoError(t, err)
require.Equal(t, dest.StrUUID.String(), "dc8daae3-b83b-11e9-8eb4-98ded00c39c6")
requireLogged(t, query)
}
func TestExpressionOperators(t *testing.T) {
query := SELECT(
AllTypes.Integer.IS_NULL().AS("result.is_null"),
AllTypes.DatePtr.IS_NOT_NULL().AS("result.is_not_null"),
AllTypes.SmallIntPtr.IN(Int(11), Int(22)).AS("result.in"),
AllTypes.SmallIntPtr.IN(AllTypes.SELECT(AllTypes.Integer)).AS("result.in_select"),
Raw("length(121232459)").AS("result.raw"),
Raw(":first + COALESCE(all_types.small_int_ptr, 0) + :second", RawArgs{":first": 78, ":second": 56}).
AS("result.raw_arg"),
Raw("#1 + all_types.integer + #2 + #1 + #3 + #4", RawArgs{"#1": 11, "#2": 22, "#3": 33, "#4": 44}).
AS("result.raw_arg2"),
AllTypes.SmallIntPtr.NOT_IN(Int(11), Int(22), NULL).AS("result.not_in"),
AllTypes.SmallIntPtr.NOT_IN(AllTypes.SELECT(AllTypes.Integer)).AS("result.not_in_select"),
).FROM(
AllTypes,
).LIMIT(2)
testutils.AssertStatementSql(t, query, strings.Replace(`
SELECT all_types.integer IS NULL AS "result.is_null",
all_types.date_ptr IS NOT NULL AS "result.is_not_null",
(all_types.small_int_ptr IN (?, ?)) AS "result.in",
(all_types.small_int_ptr IN (
SELECT all_types.integer AS "all_types.integer"
FROM all_types
)) AS "result.in_select",
(length(121232459)) AS "result.raw",
(? + COALESCE(all_types.small_int_ptr, 0) + ?) AS "result.raw_arg",
(? + all_types.integer + ? + ? + ? + ?) AS "result.raw_arg2",
(all_types.small_int_ptr NOT IN (?, ?, NULL)) AS "result.not_in",
(all_types.small_int_ptr NOT IN (
SELECT all_types.integer AS "all_types.integer"
FROM all_types
)) AS "result.not_in_select"
FROM all_types
LIMIT ?;
`, "'", "`", -1), int64(11), int64(22), 78, 56, 11, 22, 11, 33, 44, int64(11), int64(22), int64(2))
var dest []struct {
common.ExpressionTestResult `alias:"result.*"`
}
err := query.Query(sampleDB, &dest)
require.NoError(t, err)
require.Equal(t, *dest[0].IsNull, false)
require.Equal(t, *dest[0].IsNotNull, true)
require.Equal(t, *dest[0].In, false)
require.Equal(t, *dest[0].InSelect, false)
require.Equal(t, *dest[0].Raw, "9")
require.Equal(t, *dest[0].RawArg, int32(148))
require.Equal(t, *dest[0].RawArg2, int32(-1479))
require.Nil(t, dest[0].NotIn)
require.Equal(t, *dest[0].NotInSelect, true)
}
func TestBoolOperators(t *testing.T) {
query := AllTypes.SELECT(
AllTypes.Boolean.EQ(AllTypes.BooleanPtr).AS("EQ1"),
AllTypes.Boolean.EQ(Bool(true)).AS("EQ2"),
AllTypes.Boolean.NOT_EQ(AllTypes.BooleanPtr).AS("NEq1"),
AllTypes.Boolean.NOT_EQ(Bool(false)).AS("NEq2"),
AllTypes.Boolean.IS_DISTINCT_FROM(AllTypes.BooleanPtr).AS("distinct1"),
AllTypes.Boolean.IS_DISTINCT_FROM(Bool(true)).AS("distinct2"),
AllTypes.Boolean.IS_NOT_DISTINCT_FROM(AllTypes.BooleanPtr).AS("not_distinct_1"),
AllTypes.Boolean.IS_NOT_DISTINCT_FROM(Bool(true)).AS("NOTDISTINCT2"),
AllTypes.Boolean.IS_TRUE().AS("ISTRUE"),
AllTypes.Boolean.IS_NOT_TRUE().AS("isnottrue"),
AllTypes.Boolean.IS_FALSE().AS("is_False"),
AllTypes.Boolean.IS_NOT_FALSE().AS("is not false"),
AllTypes.Boolean.IS_NULL().AS("is unknown"),
AllTypes.Boolean.IS_NOT_NULL().AS("is_not_unknown"),
AllTypes.Boolean.AND(AllTypes.Boolean).EQ(AllTypes.Boolean.AND(AllTypes.Boolean)).AS("complex1"),
AllTypes.Boolean.OR(AllTypes.Boolean).EQ(AllTypes.Boolean.AND(AllTypes.Boolean)).AS("complex2"),
)
testutils.AssertStatementSql(t, query, `
SELECT (all_types.boolean = all_types.boolean_ptr) AS "EQ1",
(all_types.boolean = ?) AS "EQ2",
(all_types.boolean != all_types.boolean_ptr) AS "NEq1",
(all_types.boolean != ?) AS "NEq2",
(all_types.boolean IS NOT all_types.boolean_ptr) AS "distinct1",
(all_types.boolean IS NOT ?) AS "distinct2",
(all_types.boolean IS all_types.boolean_ptr) AS "not_distinct_1",
(all_types.boolean IS ?) AS "NOTDISTINCT2",
all_types.boolean IS TRUE AS "ISTRUE",
all_types.boolean IS NOT TRUE AS "isnottrue",
all_types.boolean IS FALSE AS "is_False",
all_types.boolean IS NOT FALSE AS "is not false",
all_types.boolean IS NULL AS "is unknown",
all_types.boolean IS NOT NULL AS "is_not_unknown",
((all_types.boolean AND all_types.boolean) = (all_types.boolean AND all_types.boolean)) AS "complex1",
((all_types.boolean OR all_types.boolean) = (all_types.boolean AND all_types.boolean)) AS "complex2"
FROM all_types;
`, true, false, true, true)
var dest []struct {
Eq1 *bool
Eq2 *bool
NEq1 *bool
NEq2 *bool
Distinct1 *bool
Distinct2 *bool
NotDistinct1 *bool
NotDistinct2 *bool
IsTrue *bool
IsNotTrue *bool
IsFalse *bool
IsNotFalse *bool
IsUnknown *bool
IsNotUnknown *bool
Complex1 *bool
Complex2 *bool
}
err := query.Query(sampleDB, &dest)
require.NoError(t, err)
testutils.AssertJSONFile(t, dest, "./testdata/results/common/bool_operators.json")
}
func TestFloatOperators(t *testing.T) {
query := AllTypes.SELECT(
AllTypes.Numeric.EQ(AllTypes.Numeric).AS("eq1"),
AllTypes.Decimal.EQ(Float(12.22)).AS("eq2"),
AllTypes.Real.EQ(Float(12.12)).AS("eq3"),
AllTypes.Numeric.IS_DISTINCT_FROM(AllTypes.Numeric).AS("distinct1"),
AllTypes.Decimal.IS_DISTINCT_FROM(Float(12)).AS("distinct2"),
AllTypes.Real.IS_DISTINCT_FROM(Float(12.12)).AS("distinct3"),
AllTypes.Numeric.IS_NOT_DISTINCT_FROM(AllTypes.Numeric).AS("not_distinct1"),
AllTypes.Decimal.IS_NOT_DISTINCT_FROM(Float(12)).AS("not_distinct2"),
AllTypes.Real.IS_NOT_DISTINCT_FROM(Float(12.12)).AS("not_distinct3"),
AllTypes.Numeric.LT(Float(124)).AS("lt1"),
AllTypes.Numeric.LT(Float(34.56)).AS("lt2"),
AllTypes.Numeric.GT(Float(124)).AS("gt1"),
AllTypes.Numeric.GT(Float(34.56)).AS("gt2"),
AllTypes.Decimal.ADD(AllTypes.Decimal).AS("add1"),
AllTypes.Decimal.ADD(Float(11.22)).AS("add2"),
AllTypes.Decimal.SUB(AllTypes.DecimalPtr).AS("sub1"),
AllTypes.Decimal.SUB(Float(11.22)).AS("sub2"),
AllTypes.Decimal.MUL(AllTypes.DecimalPtr).AS("mul1"),
AllTypes.Decimal.MUL(Float(11.22)).AS("mul2"),
AllTypes.Decimal.DIV(AllTypes.DecimalPtr).AS("div1"),
AllTypes.Decimal.DIV(Float(11.22)).AS("div2"),
AllTypes.Decimal.MOD(AllTypes.DecimalPtr).AS("mod1"),
AllTypes.Decimal.MOD(Float(11.22)).AS("mod2"),
// sqlite driver has to enable SQLITE_ENABLE_MATH_FUNCTIONS before commented math functions can be used
//AllTypes.Decimal.POW(AllTypes.DecimalPtr).AS("pow1"),
//AllTypes.Decimal.POW(Float(2.1)).AS("pow2"),
ABSf(AllTypes.Decimal).AS("abs"),
//POWER(AllTypes.Decimal, Float(2.1)).AS("power"),
//SQRT(AllTypes.Decimal).AS("sqrt"),
//CBRT(AllTypes.Decimal).AS("cbrt"),
//CEIL(AllTypes.Real).AS("ceil"),
//FLOOR(AllTypes.Real).AS("floor"),
ROUND(AllTypes.Decimal).AS("round1"),
ROUND(AllTypes.Decimal, Int(2)).AS("round2"),
//TRUNC(AllTypes.Decimal, Int(1)).AS("trunc"),
SIGN(AllTypes.Real).AS("sign"),
).LIMIT(1)
testutils.AssertStatementSql(t, query, `
SELECT (all_types.numeric = all_types.numeric) AS "eq1",
(all_types.decimal = ?) AS "eq2",
(all_types.real = ?) AS "eq3",
(all_types.numeric IS NOT all_types.numeric) AS "distinct1",
(all_types.decimal IS NOT ?) AS "distinct2",
(all_types.real IS NOT ?) AS "distinct3",
(all_types.numeric IS all_types.numeric) AS "not_distinct1",
(all_types.decimal IS ?) AS "not_distinct2",
(all_types.real IS ?) AS "not_distinct3",
(all_types.numeric < ?) AS "lt1",
(all_types.numeric < ?) AS "lt2",
(all_types.numeric > ?) AS "gt1",
(all_types.numeric > ?) AS "gt2",
(all_types.decimal + all_types.decimal) AS "add1",
(all_types.decimal + ?) AS "add2",
(all_types.decimal - all_types.decimal_ptr) AS "sub1",
(all_types.decimal - ?) AS "sub2",
(all_types.decimal * all_types.decimal_ptr) AS "mul1",
(all_types.decimal * ?) AS "mul2",
(all_types.decimal / all_types.decimal_ptr) AS "div1",
(all_types.decimal / ?) AS "div2",
(all_types.decimal % all_types.decimal_ptr) AS "mod1",
(all_types.decimal % ?) AS "mod2",
ABS(all_types.decimal) AS "abs",
ROUND(all_types.decimal) AS "round1",
ROUND(all_types.decimal, ?) AS "round2",
SIGN(all_types.real) AS "sign"
FROM all_types
LIMIT ?;
`)
var dest struct {
common.FloatExpressionTestResult `alias:"."`
}
err := query.Query(sampleDB, &dest)
require.NoError(t, err)
require.Equal(t, *dest.Eq1, true)
require.Equal(t, *dest.Distinct1, false)
require.Equal(t, *dest.Lt1, true)
require.Equal(t, *dest.Add1, 2.22)
require.Equal(t, *dest.Mod2, float64(1))
require.Equal(t, *dest.Round1, float64(1))
require.Equal(t, *dest.Round2, float64(1.11))
require.Equal(t, *dest.Sign, float64(1))
//testutils.AssertJSONFile(t, dest, "./testdata/results/common/float_operators.json")
}
func TestIntegerOperators(t *testing.T) {
query := AllTypes.SELECT(
AllTypes.BigInt,
AllTypes.BigIntPtr,
AllTypes.SmallInt,
AllTypes.SmallIntPtr,
AllTypes.BigInt.EQ(AllTypes.BigInt).AS("eq1"),
AllTypes.BigInt.EQ(Int(12)).AS("eq2"),
AllTypes.BigInt.NOT_EQ(AllTypes.BigIntPtr).AS("neq1"),
AllTypes.BigInt.NOT_EQ(Int(12)).AS("neq2"),
AllTypes.BigInt.IS_DISTINCT_FROM(AllTypes.BigInt).AS("distinct1"),
AllTypes.BigInt.IS_DISTINCT_FROM(Int(12)).AS("distinct2"),
AllTypes.BigInt.IS_NOT_DISTINCT_FROM(AllTypes.BigInt).AS("not distinct1"),
AllTypes.BigInt.IS_NOT_DISTINCT_FROM(Int(12)).AS("not distinct2"),
AllTypes.BigInt.LT(AllTypes.BigIntPtr).AS("lt1"),
AllTypes.BigInt.LT(Int(65)).AS("lt2"),
AllTypes.BigInt.LT_EQ(AllTypes.BigIntPtr).AS("lte1"),
AllTypes.BigInt.LT_EQ(Int(65)).AS("lte2"),
AllTypes.BigInt.GT(AllTypes.BigIntPtr).AS("gt1"),
AllTypes.BigInt.GT(Int(65)).AS("gt2"),
AllTypes.BigInt.GT_EQ(AllTypes.BigIntPtr).AS("gte1"),
AllTypes.BigInt.GT_EQ(Int(65)).AS("gte2"),
AllTypes.BigInt.ADD(AllTypes.BigInt).AS("add1"),
AllTypes.BigInt.ADD(Int(11)).AS("add2"),
AllTypes.BigInt.SUB(AllTypes.BigInt).AS("sub1"),
AllTypes.BigInt.SUB(Int(11)).AS("sub2"),
AllTypes.BigInt.MUL(AllTypes.BigInt).AS("mul1"),
AllTypes.BigInt.MUL(Int(11)).AS("mul2"),
AllTypes.BigInt.DIV(AllTypes.BigInt).AS("div1"),
AllTypes.BigInt.DIV(Int(11)).AS("div2"),
AllTypes.BigInt.MOD(AllTypes.BigInt).AS("mod1"),
AllTypes.BigInt.MOD(Int(11)).AS("mod2"),
//AllTypes.SmallInt.POW(AllTypes.SmallInt.DIV(Int(3))).AS("pow1"),
//AllTypes.SmallInt.POW(Int(6)).AS("pow2"),
AllTypes.SmallInt.BIT_AND(AllTypes.SmallInt).AS("bit_and1"),
AllTypes.SmallInt.BIT_AND(AllTypes.SmallInt).AS("bit_and2"),
AllTypes.SmallInt.BIT_OR(AllTypes.SmallInt).AS("bit or 1"),
AllTypes.SmallInt.BIT_OR(Int(22)).AS("bit or 2"),
AllTypes.SmallInt.BIT_XOR(AllTypes.SmallInt).AS("bit xor 1"),
AllTypes.SmallInt.BIT_XOR(Int(11)).AS("bit xor 2"),
BIT_NOT(Int(-1).MUL(AllTypes.SmallInt)).AS("bit_not_1"),
BIT_NOT(Int(-1).MUL(Int(11))).AS("bit_not_2"),
AllTypes.SmallInt.BIT_SHIFT_LEFT(AllTypes.SmallInt.DIV(Int(2))).AS("bit shift left 1"),
AllTypes.SmallInt.BIT_SHIFT_LEFT(Int(4)).AS("bit shift left 2"),
AllTypes.SmallInt.BIT_SHIFT_RIGHT(AllTypes.SmallInt.DIV(Int(5))).AS("bit shift right 1"),
AllTypes.SmallInt.BIT_SHIFT_RIGHT(Int(1)).AS("bit shift right 2"),
ABSi(AllTypes.BigInt).AS("abs"),
//SQRT(ABSi(AllTypes.BigInt)).AS("sqrt"),
//CBRT(ABSi(AllTypes.BigInt)).AS("cbrt"),
).LIMIT(2)
var dest []struct {
common.AllTypesIntegerExpResult `alias:"."`
}
err := query.Query(sampleDB, &dest)
require.NoError(t, err)
require.Equal(t, *dest[0].Eq1, true)
require.Equal(t, *dest[0].Distinct2, true)
require.Equal(t, *dest[0].Lt2, false)
require.Equal(t, *dest[0].Add1, int64(10000))
require.Equal(t, *dest[0].Mul1, int64(25000000))
require.Equal(t, *dest[0].Div2, int64(454))
require.Equal(t, *dest[0].BitAnd1, int64(14))
require.Equal(t, *dest[0].BitXor2, int64(5))
require.Equal(t, *dest[0].BitShiftLeft1, int64(1792))
require.Equal(t, *dest[0].BitShiftRight2, int64(7))
}
func TestStringOperators(t *testing.T) {
query := SELECT(
AllTypes.Text.EQ(AllTypes.Char),
AllTypes.Text.EQ(String("Text")),
AllTypes.Text.NOT_EQ(AllTypes.VarCharPtr),
AllTypes.Text.NOT_EQ(String("Text")),
AllTypes.Text.GT(AllTypes.Text),
AllTypes.Text.GT(String("Text")),
AllTypes.Text.GT_EQ(AllTypes.TextPtr),
AllTypes.Text.GT_EQ(String("Text")),
AllTypes.Text.LT(AllTypes.Char),
AllTypes.Text.LT(String("Text")),
AllTypes.Text.LT_EQ(AllTypes.VarCharPtr),
AllTypes.Text.LT_EQ(String("Text")),
AllTypes.Text.CONCAT(String("text2")),
AllTypes.Text.CONCAT(Int(11)),
AllTypes.Text.LIKE(String("abc")),
AllTypes.Text.NOT_LIKE(String("_b_")),
//AllTypes.Text.REGEXP_LIKE(String("aba")),
//AllTypes.Text.REGEXP_LIKE(String("aba"), false),
//String("ABA").REGEXP_LIKE(String("aba"), true),
//AllTypes.Text.NOT_REGEXP_LIKE(String("aba")),
//AllTypes.Text.NOT_REGEXP_LIKE(String("aba"), false),
//String("ABA").NOT_REGEXP_LIKE(String("aba"), true),
//BIT_LENGTH(AllTypes.Text),
//CHAR_LENGTH(AllTypes.Char),
//OCTET_LENGTH(AllTypes.Text),
LOWER(AllTypes.VarCharPtr),
UPPER(AllTypes.Char),
LTRIM(AllTypes.VarCharPtr),
RTRIM(AllTypes.VarCharPtr),
//CONCAT(String("string1"), Int(1), Float(11.12)),
//CONCAT_WS(String("string1"), Int(1), Float(11.12)),
//FORMAT(String("Hello %s, %1$s"), String("World")),
//LEFTSTR(String("abcde"), Int(2)),
//RIGHTSTR(String("abcde"), Int(2)),
LENGTH(String("jose")),
//LPAD(String("Hi"), Int(5), String("xy")),
//RPAD(String("Hi"), Int(5), String("xy")),
//MD5(AllTypes.VarCharPtr),
//REPEAT(AllTypes.Text, Int(33)),
REPLACE(AllTypes.Char, String("BA"), String("AB")),
//REVERSE(AllTypes.VarCharPtr),
SUBSTR(AllTypes.CharPtr, Int(3)),
SUBSTR(AllTypes.CharPtr, Int(3), Int(2)),
).FROM(AllTypes)
dest := []struct{}{}
err := query.Query(sampleDB, &dest)
require.NoError(t, err)
}
func TestReservedWord(t *testing.T) {
stmt := SELECT(ReservedWords.AllColumns).
FROM(ReservedWords)
testutils.AssertDebugStatementSql(t, stmt, strings.Replace(`
SELECT ''ReservedWords''.''column'' AS "ReservedWords.column",
''ReservedWords''.use AS "ReservedWords.use",
''ReservedWords''.ceil AS "ReservedWords.ceil",
''ReservedWords''.''commit'' AS "ReservedWords.commit",
''ReservedWords''.''create'' AS "ReservedWords.create",
''ReservedWords''.''default'' AS "ReservedWords.default",
''ReservedWords''.''desc'' AS "ReservedWords.desc",
''ReservedWords''.empty AS "ReservedWords.empty",
''ReservedWords''.float AS "ReservedWords.float",
''ReservedWords''.''join'' AS "ReservedWords.join",
''ReservedWords''.''like'' AS "ReservedWords.like",
''ReservedWords''.max AS "ReservedWords.max",
''ReservedWords''.rank AS "ReservedWords.rank"
FROM ''ReservedWords'';
`, "''", "`", -1))
var dest model.ReservedWords
err := stmt.Query(sampleDB, &dest)
require.NoError(t, err)
require.Equal(t, dest, model.ReservedWords{
Column: "Column",
Use: "CHECK",
Ceil: "CEIL",
Commit: "COMMIT",
Create: "CREATE",
Default: "DEFAULT",
Desc: "DESC",
Empty: "EMPTY",
Float: "FLOAT",
Join: "JOIN",
Like: "LIKE",
Max: "MAX",
Rank: "RANK",
})
}
func TestExactDecimals(t *testing.T) {
type exactDecimals struct {
model.ExactDecimals
Decimal decimal.Decimal
DecimalPtr decimal.Decimal
}
t.Run("should query decimal", func(t *testing.T) {
query := SELECT(
ExactDecimals.AllColumns,
).FROM(
ExactDecimals,
).WHERE(ExactDecimals.Decimal.EQ(String("1.11111111111111111111")))
var result exactDecimals
err := query.Query(sampleDB, &result)
require.NoError(t, err)
require.Equal(t, "1.11111111111111111111", result.Decimal.String())
require.Equal(t, "0", result.DecimalPtr.String()) // NULL
require.Equal(t, "1.11111111111111111111", result.ExactDecimals.Decimal) // precision loss
require.Equal(t, (*string)(nil), result.ExactDecimals.DecimalPtr)
require.Equal(t, "2.22222222222222222222", result.ExactDecimals.Numeric)
require.Equal(t, (*string)(nil), result.ExactDecimals.NumericPtr) // NULL
})
t.Run("should insert decimal", func(t *testing.T) {
insertQuery := ExactDecimals.INSERT(
ExactDecimals.AllColumns,
).MODEL(
exactDecimals{
ExactDecimals: model.ExactDecimals{
// overwritten by wrapped(exactDecimals) scope
Decimal: "0.1",
DecimalPtr: nil,
// not overwritten
Numeric: "6.7",
NumericPtr: testutils.StringPtr("7.7"),
},
Decimal: decimal.RequireFromString("91.23"),
DecimalPtr: decimal.RequireFromString("45.67"),
},
).RETURNING(ExactDecimals.AllColumns)
testutils.AssertDebugStatementSql(t, insertQuery, strings.Replace(`
INSERT INTO exact_decimals (decimal, decimal_ptr, numeric, numeric_ptr)
VALUES ('91.23', '45.67', '6.7', '7.7')
RETURNING exact_decimals.decimal AS "exact_decimals.decimal",
exact_decimals.decimal_ptr AS "exact_decimals.decimal_ptr",
exact_decimals.numeric AS "exact_decimals.numeric",
exact_decimals.numeric_ptr AS "exact_decimals.numeric_ptr";
`, "''", "`", -1))
tx := beginSampleDBTx(t)
defer tx.Rollback()
var result exactDecimals
err := insertQuery.Query(tx, &result)
require.NoError(t, err)
require.Equal(t, "91.23", result.Decimal.String())
require.Equal(t, "45.67", result.DecimalPtr.String())
require.Equal(t, "6.7", result.ExactDecimals.Numeric)
require.Equal(t, "7.7", *result.ExactDecimals.NumericPtr)
require.Equal(t, "91.23", result.ExactDecimals.Decimal)
require.Equal(t, "45.67", *result.ExactDecimals.DecimalPtr)
})
}
var timeT = time.Date(2009, 11, 17, 20, 34, 58, 651387237, time.UTC)
func TestDateExpressions(t *testing.T) {
query := AllTypes.SELECT(
//Date(2009, 11, 17, 2, MONTH, 1, DAY),
//DateT(timeT, START_OF_THE_MONTH),
AllTypes.Date.AS("date"),
DATE("2009-11-17").AS("date1"),
DATE("2013-10-07 08:23:19.120", DAYS(1)).AS("date2"),
DATE(AllTypes.Date, START_OF_YEAR, DAYS(2)).AS("date3"),
DATE(timeT, START_OF_MONTH).AS("date3"),
DATE("now", WEEKDAY(1)).AS("date4"),
DATE(timeT.Unix(), UNIXEPOCH).AS("date5"),
DATE(time.Now(), UTC).AS("date6"),
DATE(time.Now().UTC(), LOCALTIME).AS("date7"),
AllTypes.Date.EQ(AllTypes.Date),
AllTypes.Date.EQ(Date(2019, 6, 6)),
AllTypes.DatePtr.NOT_EQ(AllTypes.Date),
AllTypes.DatePtr.NOT_EQ(Date(2019, 1, 6)),
AllTypes.Date.IS_DISTINCT_FROM(AllTypes.Date).AS("distinct1"),
AllTypes.Date.IS_DISTINCT_FROM(Date(2008, 7, 4)).AS("distinct2"),
AllTypes.Date.IS_NOT_DISTINCT_FROM(AllTypes.Date),
AllTypes.Date.IS_NOT_DISTINCT_FROM(Date(2019, 3, 6)),
AllTypes.Date.LT(AllTypes.Date),
AllTypes.Date.LT(Date(2019, 4, 6)),
AllTypes.Date.LT_EQ(AllTypes.Date),
AllTypes.Date.LT_EQ(Date(2019, 5, 5)),
AllTypes.Date.GT(AllTypes.Date),
AllTypes.Date.GT(Date(2019, 1, 4)),
AllTypes.Date.GT_EQ(AllTypes.Date),
AllTypes.Date.GT_EQ(Date(2019, 2, 3)),
//AllTypes.Date.ADD(INTERVAL2(2, HOUR)),
//AllTypes.Date.ADD(INTERVAL2(1, DAY, 7, MONTH)),
//AllTypes.Date.ADD(INTERVALd(25 * time.Hour + 100 * time.Millisecond)),
//AllTypes.Date.ADD(INTERVALd(-25 * time.Hour - 100 * time.Millisecond)),
//
//AllTypes.Date.SUB(INTERVAL(20, MINUTE)),
//AllTypes.Date.SUB(INTERVALe(AllTypes.SmallInt, MINUTE)),
//AllTypes.Date.SUB(INTERVALd(3*time.Minute)),
CURRENT_DATE().AS("current_date"),
)
var dest struct {
Date string
Date1 time.Time
Date2 string
Date3 time.Time
Date4 string
Date5 time.Time
Date6 string
Date7 time.Time
Distinct1 bool
Distinct2 bool
CurrentDate time.Time
}
err := query.Query(sampleDB, &dest)
require.NoError(t, err)
require.Equal(t, dest.Date, "2008-07-04T00:00:00Z")
require.Equal(t, dest.Date1.Unix(), int64(1258416000))
}
func TestTimeExpressions(t *testing.T) {
query := AllTypes.SELECT(
TIME(AllTypes.Time).AS("time1"),
TIME(timeT).AS("time2"),
TIME("04:23:19.120-04:00", HOURS(1), MINUTES(2), SECONDS(1.234)).AS("time3"),
TIME(timeT.Unix(), UNIXEPOCH).AS("time4"),
TIME(time.Now(), UTC).AS("time5"),
TIME(time.Now().UTC(), LOCALTIME).AS("time6"),
Time(timeT.Clock()),
AllTypes.Time.EQ(AllTypes.Time),
AllTypes.Time.EQ(Time(23, 6, 6)),
AllTypes.Time.EQ(Time(22, 6, 6, 11*time.Millisecond)),
AllTypes.Time.EQ(Time(21, 6, 6, 11111*time.Microsecond)),
AllTypes.TimePtr.NOT_EQ(AllTypes.Time),
AllTypes.TimePtr.NOT_EQ(Time(20, 16, 6)),
AllTypes.Time.IS_DISTINCT_FROM(AllTypes.Time),
AllTypes.Time.IS_DISTINCT_FROM(Time(19, 26, 6)),
AllTypes.Time.IS_NOT_DISTINCT_FROM(AllTypes.Time),
AllTypes.Time.IS_NOT_DISTINCT_FROM(Time(18, 36, 6)),
AllTypes.Time.LT(AllTypes.Time),
AllTypes.Time.LT(Time(17, 46, 6)),
AllTypes.Time.LT_EQ(AllTypes.Time),
AllTypes.Time.LT_EQ(Time(16, 56, 56)),
AllTypes.Time.GT(AllTypes.Time),
AllTypes.Time.GT(Time(15, 16, 46)),
AllTypes.Time.GT_EQ(AllTypes.Time),
AllTypes.Time.GT_EQ(Time(14, 26, 36)),
//AllTypes.Time.ADD(INTERVAL(10, MINUTE)),
//AllTypes.Time.ADD(INTERVALe(AllTypes.Integer, MINUTE)),
//AllTypes.Time.ADD(INTERVALd(3*time.Hour)),
//
//AllTypes.Time.SUB(INTERVAL(20, MINUTE)),
//AllTypes.Time.SUB(INTERVALe(AllTypes.SmallInt, MINUTE)),
//AllTypes.Time.SUB(INTERVALd(3*time.Minute)),
//
//AllTypes.Time.ADD(INTERVAL(20, MINUTE)).SUB(INTERVAL(11, HOUR)),
CURRENT_TIME(),
)
var dest struct {
Time1 string
Time2 time.Time
Time3 string
Time4 time.Time
Time5 string
Time6 time.Time
}
err := query.Query(sampleDB, &dest)
require.NoError(t, err)
require.Equal(t, dest.Time1, "10:11:12")
require.Equal(t, dest.Time2.UTC().String(), "0000-01-01 20:34:58 +0000 UTC")
require.Equal(t, dest.Time3, "09:25:20")
}
func TestDateTimeExpressions(t *testing.T) {
var dateTime = DateTime(2019, 6, 6, 10, 2, 46)
query := SELECT(
DATETIME("now").AS("now"),
DATETIME("2013-10-07T08:23:19.120Z", YEARS(2), MONTHS(1), DAYS(1)).AS("datetime1"),
DATETIME(AllTypes.DateTime, MONTHS(1), DAYS(1)).AS("datetime2"),
DATETIME(timeT.Unix(), UNIXEPOCH).AS("datetime3"),
DATETIME(time.Now(), UTC).AS("datetime4"),
DATETIME(timeT.UTC(), LOCALTIME).AS("datetime5"),
JULIANDAY(timeT, DAYS(1)).AS("JulianDay"),
STRFTIME(String("%H:%M"), timeT, SECONDS(1.22)).AS("strftime"),
AllTypes.DateTime.EQ(AllTypes.DateTime),
AllTypes.DateTime.EQ(dateTime),
AllTypes.DateTimePtr.NOT_EQ(AllTypes.DateTime),
AllTypes.DateTimePtr.NOT_EQ(DateTime(2019, 6, 6, 10, 2, 46, 100*time.Millisecond)),
AllTypes.DateTime.IS_DISTINCT_FROM(AllTypes.DateTime),
AllTypes.DateTime.IS_DISTINCT_FROM(dateTime),
AllTypes.DateTime.IS_NOT_DISTINCT_FROM(AllTypes.DateTime),
AllTypes.DateTime.IS_NOT_DISTINCT_FROM(dateTime),
AllTypes.DateTime.LT(AllTypes.DateTime),
AllTypes.DateTime.LT(dateTime),
AllTypes.DateTime.LT_EQ(AllTypes.DateTime),
AllTypes.DateTime.LT_EQ(dateTime),
AllTypes.DateTime.GT(AllTypes.DateTime),
AllTypes.DateTime.GT(dateTime),
AllTypes.DateTime.GT_EQ(AllTypes.DateTime),
AllTypes.DateTime.GT_EQ(dateTime),
//AllTypes.DateTime.ADD(INTERVAL("05:10:20.000100", HOUR_MICROSECOND)),
//AllTypes.DateTime.ADD(INTERVALe(AllTypes.BigInt, HOUR)),
//AllTypes.DateTime.ADD(INTERVALd(2*time.Hour)),
//
//AllTypes.DateTime.SUB(INTERVAL("05:10:20.000100", HOUR_MICROSECOND)),
//AllTypes.DateTime.SUB(INTERVALe(AllTypes.IntegerPtr, HOUR)),
//AllTypes.DateTime.SUB(INTERVALd(3*time.Hour)),
CURRENT_TIMESTAMP(),
).FROM(AllTypes)
var dest struct {
Now time.Time
DateTime1 time.Time
DateTime2 time.Time
DateTime3 time.Time
DateTime4 time.Time
DateTime5 time.Time
JulianDay float64
StrfTime string
}
err := query.Query(sampleDB, &dest)
require.NoError(t, err)
require.True(t, dest.Now.After(time.Now().Add(-1*time.Minute)))
require.Equal(t, dest.DateTime1.String(), "2015-11-08 08:23:19 +0000 UTC")
require.Equal(t, dest.DateTime2.String(), "2012-01-19 13:17:17 +0000 UTC")
require.Equal(t, dest.DateTime3.String(), "2009-11-17 20:34:58 +0000 UTC")
require.True(t, dest.DateTime4.After(time.Now().Add(-1*time.Minute)))
require.Equal(t, dest.DateTime5.String(), "2009-11-17 21:34:58 +0000 UTC")
require.Equal(t, dest.JulianDay, 2.4551543576232754e+06)
require.Equal(t, dest.StrfTime, "20:34")
}

41
tests/sqlite/cast_test.go Normal file
View file

@ -0,0 +1,41 @@
package sqlite
import (
"github.com/go-jet/jet/v2/internal/testutils"
. "github.com/go-jet/jet/v2/sqlite"
"github.com/stretchr/testify/require"
"testing"
)
func TestCast(t *testing.T) {
query := SELECT(
CAST(String("test")).AS("CHARACTER").AS("result.AS1"),
CAST(Float(11.33)).AS_TEXT().AS("result.text"),
CAST(String("33.44")).AS_REAL().AS("result.real"),
CAST(String("33")).AS_INTEGER().AS("result.integer"),
CAST(String("Blob blob")).AS_BLOB().AS("result.blob"),
)
type Result struct {
As1 string
Text string
Real float64
Integer int64
Blob []byte
}
var dest Result
err := query.Query(db, &dest)
require.NoError(t, err)
testutils.AssertDeepEqual(t, dest, Result{
As1: "test",
Text: "11.33",
Real: 33.44,
Integer: 33,
Blob: []byte("Blob blob"),
})
requireLogged(t, query)
}

View file

@ -0,0 +1,83 @@
package sqlite
import (
"context"
"testing"
"time"
"github.com/go-jet/jet/v2/internal/testutils"
. "github.com/go-jet/jet/v2/sqlite"
"github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/test_sample/model"
. "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/test_sample/table"
"github.com/stretchr/testify/require"
)
func TestDelete_WHERE_RETURNING(t *testing.T) {
tx := beginSampleDBTx(t)
defer tx.Rollback()
var expectedSQL = `
DELETE FROM link
WHERE link.name IN ('Bing', 'Yahoo')
RETURNING link.id AS "link.id",
link.url AS "link.url",
link.name AS "link.name",
link.description AS "link.description";
`
deleteStmt := Link.DELETE().
WHERE(Link.Name.IN(String("Bing"), String("Yahoo"))).
RETURNING(Link.AllColumns)
testutils.AssertDebugStatementSql(t, deleteStmt, expectedSQL, "Bing", "Yahoo")
var dest []model.Link
err := deleteStmt.Query(tx, &dest)
require.NoError(t, err)
require.Len(t, dest, 2)
requireLogged(t, deleteStmt)
}
func TestDeleteWithWhereOrderByLimit(t *testing.T) {
t.SkipNow() // Until https://github.com/mattn/go-sqlite3/pull/802 is fixed
tx := beginSampleDBTx(t)
defer tx.Rollback()
sampleDB.Stats()
var expectedSQL = `
DELETE FROM link
WHERE link.name IN ('Bing', 'Yahoo')
ORDER BY link.name
LIMIT 1;
`
deleteStmt := Link.DELETE().
WHERE(Link.Name.IN(String("Bing"), String("Yahoo"))).
ORDER_BY(Link.Name).
LIMIT(1)
testutils.AssertDebugStatementSql(t, deleteStmt, expectedSQL, "Bing", "Yahoo", int64(1))
testutils.AssertExec(t, deleteStmt, tx, 1)
requireLogged(t, deleteStmt)
}
func TestDeleteContextDeadlineExceeded(t *testing.T) {
tx := beginSampleDBTx(t)
defer tx.Rollback()
deleteStmt := Link.
DELETE().
WHERE(Link.Name.IN(String("Bing"), String("Yahoo")))
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Microsecond)
defer cancel()
time.Sleep(10 * time.Millisecond)
dest := []model.Link{}
err := deleteStmt.QueryContext(ctx, tx, &dest)
require.Error(t, err, "context deadline exceeded")
_, err = deleteStmt.ExecContext(ctx, tx)
require.Error(t, err, "context deadline exceeded")
requireLogged(t, deleteStmt)
}

View file

@ -0,0 +1,298 @@
package sqlite
import (
"github.com/go-jet/jet/v2/generator/sqlite"
"github.com/go-jet/jet/v2/internal/testutils"
"github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/sakila/model"
"github.com/go-jet/jet/v2/tests/internal/utils/repo"
"github.com/stretchr/testify/require"
"io/ioutil"
"os"
"os/exec"
"reflect"
"testing"
)
func TestGeneratedModel(t *testing.T) {
actor := model.Actor{}
require.Equal(t, reflect.TypeOf(actor.ActorID).String(), "int32")
actorIDField, ok := reflect.TypeOf(actor).FieldByName("ActorID")
require.True(t, ok)
require.Equal(t, actorIDField.Tag.Get("sql"), "primary_key")
require.Equal(t, reflect.TypeOf(actor.FirstName).String(), "string")
require.Equal(t, reflect.TypeOf(actor.LastName).String(), "string")
require.Equal(t, reflect.TypeOf(actor.LastUpdate).String(), "time.Time")
filmActor := model.FilmActor{}
require.Equal(t, reflect.TypeOf(filmActor.FilmID).String(), "int32")
filmIDField, ok := reflect.TypeOf(filmActor).FieldByName("FilmID")
require.True(t, ok)
require.Equal(t, filmIDField.Tag.Get("sql"), "primary_key")
require.Equal(t, reflect.TypeOf(filmActor.ActorID).String(), "int32")
actorIDField, ok = reflect.TypeOf(filmActor).FieldByName("ActorID")
require.True(t, ok)
require.Equal(t, filmIDField.Tag.Get("sql"), "primary_key")
staff := model.Staff{}
require.Equal(t, reflect.TypeOf(staff.Email).String(), "*string")
require.Equal(t, reflect.TypeOf(staff.Picture).String(), "*[]uint8")
}
var testDatabaseFilePath = repo.GetTestDataFilePath("/init/sqlite/sakila.db")
var genDestDir = repo.GetTestsFilePath("/sqlite/.gen")
func TestGenerator(t *testing.T) {
for i := 0; i < 3; i++ {
err := sqlite.GenerateDSN(testDatabaseFilePath, genDestDir)
require.NoError(t, err)
assertGeneratedFiles(t)
}
err := os.RemoveAll(genDestDir)
require.NoError(t, err)
}
func TestCmdGenerator(t *testing.T) {
cmd := exec.Command("jet", "-source=SQLite", "-dsn=file://"+testDatabaseFilePath, "-path="+genDestDir)
cmd.Stderr = os.Stderr
cmd.Stdout = os.Stdout
err := cmd.Run()
require.NoError(t, err)
assertGeneratedFiles(t)
err = os.RemoveAll(genDestDir)
require.NoError(t, err)
}
func assertGeneratedFiles(t *testing.T) {
// Table SQL Builder files
tableSQLBuilderFiles, err := ioutil.ReadDir(genDestDir + "/table")
require.NoError(t, err)
testutils.AssertFileNamesEqual(t, tableSQLBuilderFiles, "actor.go", "address.go", "category.go", "city.go", "country.go",
"customer.go", "film.go", "film_actor.go", "film_category.go", "film_text.go", "inventory.go", "language.go",
"payment.go", "rental.go", "staff.go", "store.go")
testutils.AssertFileContent(t, genDestDir+"/table/actor.go", actorSQLBuilderFile)
// View SQL Builder files
viewSQLBuilderFiles, err := ioutil.ReadDir(genDestDir + "/view")
require.NoError(t, err)
testutils.AssertFileNamesEqual(t, viewSQLBuilderFiles, "film_list.go", "sales_by_film_category.go",
"customer_list.go", "sales_by_store.go", "staff_list.go")
testutils.AssertFileContent(t, genDestDir+"/view/film_list.go", filmListSQLBuilderFile)
// Model files
modelFiles, err := ioutil.ReadDir(genDestDir + "/model")
require.NoError(t, err)
testutils.AssertFileNamesEqual(t, modelFiles, "actor.go", "address.go", "category.go", "city.go", "country.go",
"customer.go", "film.go", "film_actor.go", "film_category.go", "film_text.go", "inventory.go", "language.go",
"payment.go", "rental.go", "staff.go", "store.go",
"film_list.go", "sales_by_film_category.go",
"customer_list.go", "sales_by_store.go", "staff_list.go")
testutils.AssertFileContent(t, genDestDir+"/model/address.go", addressModelFile)
}
const actorSQLBuilderFile = `
//
// Code generated by go-jet DO NOT EDIT.
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
//
package table
import (
"github.com/go-jet/jet/v2/sqlite"
)
var Actor = newActorTable("", "actor", "")
type actorTable struct {
sqlite.Table
//Columns
ActorID sqlite.ColumnInteger
FirstName sqlite.ColumnString
LastName sqlite.ColumnString
LastUpdate sqlite.ColumnTimestamp
AllColumns sqlite.ColumnList
MutableColumns sqlite.ColumnList
}
type ActorTable struct {
actorTable
EXCLUDED actorTable
}
// AS creates new ActorTable with assigned alias
func (a ActorTable) AS(alias string) *ActorTable {
return newActorTable(a.SchemaName(), a.TableName(), alias)
}
// Schema creates new ActorTable with assigned schema name
func (a ActorTable) FromSchema(schemaName string) *ActorTable {
return newActorTable(schemaName, a.TableName(), a.Alias())
}
func newActorTable(schemaName, tableName, alias string) *ActorTable {
return &ActorTable{
actorTable: newActorTableImpl(schemaName, tableName, alias),
EXCLUDED: newActorTableImpl("", "excluded", ""),
}
}
func newActorTableImpl(schemaName, tableName, alias string) actorTable {
var (
ActorIDColumn = sqlite.IntegerColumn("actor_id")
FirstNameColumn = sqlite.StringColumn("first_name")
LastNameColumn = sqlite.StringColumn("last_name")
LastUpdateColumn = sqlite.TimestampColumn("last_update")
allColumns = sqlite.ColumnList{ActorIDColumn, FirstNameColumn, LastNameColumn, LastUpdateColumn}
mutableColumns = sqlite.ColumnList{FirstNameColumn, LastNameColumn, LastUpdateColumn}
)
return actorTable{
Table: sqlite.NewTable(schemaName, tableName, alias, allColumns...),
//Columns
ActorID: ActorIDColumn,
FirstName: FirstNameColumn,
LastName: LastNameColumn,
LastUpdate: LastUpdateColumn,
AllColumns: allColumns,
MutableColumns: mutableColumns,
}
}
`
const filmListSQLBuilderFile = `
//
// Code generated by go-jet DO NOT EDIT.
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
//
package view
import (
"github.com/go-jet/jet/v2/sqlite"
)
var FilmList = newFilmListTable("", "film_list", "")
type filmListTable struct {
sqlite.Table
//Columns
Fid sqlite.ColumnInteger
Title sqlite.ColumnString
Description sqlite.ColumnString
Category sqlite.ColumnString
Price sqlite.ColumnFloat
Length sqlite.ColumnInteger
Rating sqlite.ColumnString
Actors sqlite.ColumnString
AllColumns sqlite.ColumnList
MutableColumns sqlite.ColumnList
}
type FilmListTable struct {
filmListTable
EXCLUDED filmListTable
}
// AS creates new FilmListTable with assigned alias
func (a FilmListTable) AS(alias string) *FilmListTable {
return newFilmListTable(a.SchemaName(), a.TableName(), alias)
}
// Schema creates new FilmListTable with assigned schema name
func (a FilmListTable) FromSchema(schemaName string) *FilmListTable {
return newFilmListTable(schemaName, a.TableName(), a.Alias())
}
func newFilmListTable(schemaName, tableName, alias string) *FilmListTable {
return &FilmListTable{
filmListTable: newFilmListTableImpl(schemaName, tableName, alias),
EXCLUDED: newFilmListTableImpl("", "excluded", ""),
}
}
func newFilmListTableImpl(schemaName, tableName, alias string) filmListTable {
var (
FidColumn = sqlite.IntegerColumn("FID")
TitleColumn = sqlite.StringColumn("title")
DescriptionColumn = sqlite.StringColumn("description")
CategoryColumn = sqlite.StringColumn("category")
PriceColumn = sqlite.FloatColumn("price")
LengthColumn = sqlite.IntegerColumn("length")
RatingColumn = sqlite.StringColumn("rating")
ActorsColumn = sqlite.StringColumn("actors")
allColumns = sqlite.ColumnList{FidColumn, TitleColumn, DescriptionColumn, CategoryColumn, PriceColumn, LengthColumn, RatingColumn, ActorsColumn}
mutableColumns = sqlite.ColumnList{FidColumn, TitleColumn, DescriptionColumn, CategoryColumn, PriceColumn, LengthColumn, RatingColumn, ActorsColumn}
)
return filmListTable{
Table: sqlite.NewTable(schemaName, tableName, alias, allColumns...),
//Columns
Fid: FidColumn,
Title: TitleColumn,
Description: DescriptionColumn,
Category: CategoryColumn,
Price: PriceColumn,
Length: LengthColumn,
Rating: RatingColumn,
Actors: ActorsColumn,
AllColumns: allColumns,
MutableColumns: mutableColumns,
}
}
`
const addressModelFile = `
//
// Code generated by go-jet DO NOT EDIT.
//
// WARNING: Changes to this file may cause incorrect behavior
// and will be lost if the code is regenerated
//
package model
import (
"time"
)
type Address struct {
AddressID int32 ` + "`sql:\"primary_key\"`" + `
Address string
Address2 *string
District string
CityID int32
PostalCode *string
Phone string
LastUpdate time.Time
}
`

393
tests/sqlite/insert_test.go Normal file
View file

@ -0,0 +1,393 @@
package sqlite
import (
"context"
"math/rand"
"testing"
"time"
"github.com/go-jet/jet/v2/internal/testutils"
. "github.com/go-jet/jet/v2/sqlite"
"github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/test_sample/model"
. "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/test_sample/table"
"github.com/stretchr/testify/require"
)
func TestInsertValues(t *testing.T) {
tx := beginSampleDBTx(t)
defer tx.Rollback()
insertQuery := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description).
VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil).
VALUES(101, "http://www.google.com", "Google", "Search engine").
VALUES(102, "http://www.yahoo.com", "Yahoo", nil)
testutils.AssertStatementSql(t, insertQuery, `
INSERT INTO link (id, url, name, description)
VALUES (?, ?, ?, ?),
(?, ?, ?, ?),
(?, ?, ?, ?);
`, 100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil,
101, "http://www.google.com", "Google", "Search engine",
102, "http://www.yahoo.com", "Yahoo", nil)
_, err := insertQuery.Exec(tx)
require.NoError(t, err)
requireLogged(t, insertQuery)
insertedLinks := []model.Link{}
err = SELECT(Link.AllColumns).
FROM(Link).
WHERE(Link.ID.GT_EQ(Int(100))).
ORDER_BY(Link.ID).
Query(tx, &insertedLinks)
require.NoError(t, err)
require.Equal(t, len(insertedLinks), 3)
testutils.AssertDeepEqual(t, insertedLinks[0], postgreTutorial)
testutils.AssertDeepEqual(t, insertedLinks[1], model.Link{
ID: 101,
URL: "http://www.google.com",
Name: "Google",
Description: testutils.StringPtr("Search engine"),
})
testutils.AssertDeepEqual(t, insertedLinks[2], model.Link{
ID: 102,
URL: "http://www.yahoo.com",
Name: "Yahoo",
})
}
var postgreTutorial = model.Link{
ID: 100,
URL: "http://www.postgresqltutorial.com",
Name: "PostgreSQL Tutorial",
}
func TestInsertEmptyColumnList(t *testing.T) {
tx := beginSampleDBTx(t)
defer tx.Rollback()
expectedSQL := `
INSERT INTO link
VALUES (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', NULL);
`
stmt := Link.INSERT().
VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil)
testutils.AssertDebugStatementSql(t, stmt, expectedSQL,
100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil)
_, err := stmt.Exec(tx)
require.NoError(t, err)
requireLogged(t, stmt)
insertedLinks := []model.Link{}
err = SELECT(Link.AllColumns).
FROM(Link).
WHERE(Link.ID.GT_EQ(Int(100))).
ORDER_BY(Link.ID).
Query(tx, &insertedLinks)
require.NoError(t, err)
require.Equal(t, len(insertedLinks), 1)
testutils.AssertDeepEqual(t, insertedLinks[0], postgreTutorial)
}
func TestInsertModelObject(t *testing.T) {
tx := beginSampleDBTx(t)
defer tx.Rollback()
linkData := model.Link{
URL: "http://www.duckduckgo.com",
Name: "Duck Duck go",
}
query := Link.INSERT(Link.URL, Link.Name).
MODEL(linkData)
testutils.AssertDebugStatementSql(t, query, `
INSERT INTO link (url, name)
VALUES ('http://www.duckduckgo.com', 'Duck Duck go');
`, "http://www.duckduckgo.com", "Duck Duck go")
_, err := query.Exec(tx)
require.NoError(t, err)
}
func TestInsertModelObjectEmptyColumnList(t *testing.T) {
tx := beginSampleDBTx(t)
defer tx.Rollback()
var expectedSQL = `
INSERT INTO link
VALUES (1000, 'http://www.duckduckgo.com', 'Duck Duck go', NULL);
`
linkData := model.Link{
ID: 1000,
URL: "http://www.duckduckgo.com",
Name: "Duck Duck go",
}
query := Link.
INSERT().
MODEL(linkData)
testutils.AssertDebugStatementSql(t, query, expectedSQL, int32(1000), "http://www.duckduckgo.com", "Duck Duck go", nil)
_, err := query.Exec(tx)
require.NoError(t, err)
}
func TestInsertModelsObject(t *testing.T) {
tx := beginSampleDBTx(t)
defer tx.Rollback()
expectedSQL := `
INSERT INTO link (url, name)
VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial'),
('http://www.google.com', 'Google'),
('http://www.yahoo.com', 'Yahoo');
`
tutorial := model.Link{
URL: "http://www.postgresqltutorial.com",
Name: "PostgreSQL Tutorial",
}
google := model.Link{
URL: "http://www.google.com",
Name: "Google",
}
yahoo := model.Link{
URL: "http://www.yahoo.com",
Name: "Yahoo",
}
query := Link.
INSERT(Link.URL, Link.Name).
MODELS([]model.Link{
tutorial,
google,
yahoo,
})
testutils.AssertDebugStatementSql(t, query, expectedSQL,
"http://www.postgresqltutorial.com", "PostgreSQL Tutorial",
"http://www.google.com", "Google",
"http://www.yahoo.com", "Yahoo")
_, err := query.Exec(tx)
require.NoError(t, err)
}
func TestInsertUsingMutableColumns(t *testing.T) {
tx := beginSampleDBTx(t)
defer tx.Rollback()
var expectedSQL = `
INSERT INTO link (url, name, description)
VALUES ('http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', NULL),
('http://www.google.com', 'Google', NULL),
('http://www.google.com', 'Google', NULL),
('http://www.yahoo.com', 'Yahoo', NULL);
`
google := model.Link{
URL: "http://www.google.com",
Name: "Google",
}
yahoo := model.Link{
URL: "http://www.yahoo.com",
Name: "Yahoo",
}
stmt := Link.
INSERT(Link.MutableColumns).
VALUES("http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil).
MODEL(google).
MODELS([]model.Link{google, yahoo})
testutils.AssertDebugStatementSql(t, stmt, expectedSQL,
"http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil,
"http://www.google.com", "Google", nil,
"http://www.google.com", "Google", nil,
"http://www.yahoo.com", "Yahoo", nil)
_, err := stmt.Exec(tx)
require.NoError(t, err)
}
func TestInsertQuery(t *testing.T) {
tx := beginSampleDBTx(t)
defer tx.Rollback()
var expectedSQL = `
INSERT INTO link (url, name)
SELECT link.url AS "link.url",
link.name AS "link.name"
FROM link
WHERE link.id = 24;
`
query := Link.INSERT(Link.URL, Link.Name).
QUERY(
SELECT(Link.URL, Link.Name).
FROM(Link).
WHERE(Link.ID.EQ(Int(24))),
)
testutils.AssertDebugStatementSql(t, query, expectedSQL, int64(24))
_, err := query.Exec(tx)
require.NoError(t, err)
youtubeLinks := []model.Link{}
err = Link.
SELECT(Link.AllColumns).
WHERE(Link.Name.EQ(String("Bing"))).
Query(tx, &youtubeLinks)
require.NoError(t, err)
require.Equal(t, len(youtubeLinks), 2)
}
func TestInsert_DEFAULT_VALUES_RETURNING(t *testing.T) {
tx := beginSampleDBTx(t)
defer tx.Rollback()
stmt := Link.INSERT().
DEFAULT_VALUES().
RETURNING(Link.AllColumns)
testutils.AssertDebugStatementSql(t, stmt, `
INSERT INTO link
DEFAULT VALUES
RETURNING link.id AS "link.id",
link.url AS "link.url",
link.name AS "link.name",
link.description AS "link.description";
`)
var link model.Link
err := stmt.Query(tx, &link)
require.NoError(t, err)
require.EqualValues(t, link, model.Link{
ID: 25,
URL: "www.",
Name: "_",
Description: nil,
})
}
func TestInsertOnConflict(t *testing.T) {
t.Run("do nothing", func(t *testing.T) {
tx := beginSampleDBTx(t)
defer tx.Rollback()
link := model.Link{ID: rand.Int31()}
stmt := Link.INSERT(Link.AllColumns).
MODEL(link).
MODEL(link).
ON_CONFLICT(Link.ID).DO_NOTHING()
testutils.AssertStatementSql(t, stmt, `
INSERT INTO link (id, url, name, description)
VALUES (?, ?, ?, ?),
(?, ?, ?, ?)
ON CONFLICT (id) DO NOTHING;
`)
testutils.AssertExec(t, stmt, tx, 1)
requireLogged(t, stmt)
})
t.Run("do update", func(t *testing.T) {
tx := beginSampleDBTx(t)
defer tx.Rollback()
stmt := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description).
VALUES(21, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil).
VALUES(22, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil).
ON_CONFLICT(Link.ID).
DO_UPDATE(
SET(
Link.ID.SET(Link.EXCLUDED.ID),
Link.URL.SET(String("http://www.postgresqltutorial2.com")),
),
).RETURNING(Link.AllColumns)
testutils.AssertStatementSql(t, stmt, `
INSERT INTO link (id, url, name, description)
VALUES (?, ?, ?, ?),
(?, ?, ?, ?)
ON CONFLICT (id) DO UPDATE
SET id = excluded.id,
url = ?
RETURNING link.id AS "link.id",
link.url AS "link.url",
link.name AS "link.name",
link.description AS "link.description";
`)
testutils.AssertExec(t, stmt, tx)
requireLogged(t, stmt)
})
t.Run("do update complex", func(t *testing.T) {
tx := beginSampleDBTx(t)
defer tx.Rollback()
stmt := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description).
VALUES(21, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil).
ON_CONFLICT(Link.ID).
WHERE(Link.ID.MUL(Int(2)).GT(Int(10))).
DO_UPDATE(
SET(
Link.ID.SET(
IntExp(SELECT(MAXi(Link.ID).ADD(Int(1))).
FROM(Link)),
),
ColumnList{Link.Name, Link.Description}.SET(ROW(Link.EXCLUDED.Name, String(""))),
).WHERE(Link.Description.IS_NOT_NULL()),
)
testutils.AssertDebugStatementSql(t, stmt, `
INSERT INTO link (id, url, name, description)
VALUES (21, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', NULL)
ON CONFLICT (id) WHERE (id * 2) > 10 DO UPDATE
SET id = (
SELECT MAX(link.id) + 1
FROM link
),
(name, description) = (excluded.name, '')
WHERE link.description IS NOT NULL;
`)
testutils.AssertExec(t, stmt, tx)
requireLogged(t, stmt)
})
}
func TestInsertContextDeadlineExceeded(t *testing.T) {
stmt := Link.INSERT().
VALUES(1100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", nil)
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Microsecond)
defer cancel()
time.Sleep(10 * time.Millisecond)
dest := []model.Link{}
err := stmt.QueryContext(ctx, sampleDB, &dest)
require.Error(t, err, "context deadline exceeded")
_, err = stmt.ExecContext(ctx, db)
require.Error(t, err, "context deadline exceeded")
}

90
tests/sqlite/main_test.go Normal file
View file

@ -0,0 +1,90 @@
package sqlite
import (
"context"
"database/sql"
"fmt"
"github.com/go-jet/jet/v2/internal/utils/throw"
"github.com/go-jet/jet/v2/sqlite"
"github.com/go-jet/jet/v2/tests/dbconfig"
"github.com/stretchr/testify/require"
"math/rand"
"os"
"os/exec"
"strings"
"testing"
"time"
"github.com/pkg/profile"
_ "github.com/mattn/go-sqlite3"
)
var db *sql.DB
var sampleDB *sql.DB
var testRoot string
func TestMain(m *testing.M) {
rand.Seed(time.Now().Unix())
defer profile.Start().Stop()
setTestRoot()
var err error
db, err = sql.Open("sqlite3", "file:"+dbconfig.SakilaDBPath)
throw.OnError(err)
_, err = db.Exec(fmt.Sprintf("ATTACH DATABASE '%s' as 'chinook';", dbconfig.ChinookDBPath))
throw.OnError(err)
sampleDB, err = sql.Open("sqlite3", dbconfig.TestSampleDBPath)
throw.OnError(err)
defer db.Close()
ret := m.Run()
if ret != 0 {
os.Exit(ret)
}
}
func setTestRoot() {
cmd := exec.Command("git", "rev-parse", "--show-toplevel")
byteArr, err := cmd.Output()
if err != nil {
panic(err)
}
testRoot = strings.TrimSpace(string(byteArr)) + "/tests/"
}
var loggedSQL string
var loggedSQLArgs []interface{}
var loggedDebugSQL string
func init() {
sqlite.SetLogger(func(ctx context.Context, statement sqlite.PrintableStatement) {
loggedSQL, loggedSQLArgs = statement.Sql()
loggedDebugSQL = statement.DebugSql()
})
}
func requireLogged(t *testing.T, statement sqlite.Statement) {
query, args := statement.Sql()
require.Equal(t, loggedSQL, query)
require.Equal(t, loggedSQLArgs, args)
require.Equal(t, loggedDebugSQL, statement.DebugSql())
}
func beginSampleDBTx(t *testing.T) *sql.Tx {
tx, err := sampleDB.Begin()
require.NoError(t, err)
return tx
}
func beginDBTx(t *testing.T) *sql.Tx {
tx, err := db.Begin()
require.NoError(t, err)
return tx
}

View file

@ -0,0 +1,121 @@
package sqlite
import (
"context"
"testing"
"time"
"github.com/go-jet/jet/v2/internal/testutils"
. "github.com/go-jet/jet/v2/sqlite"
"github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/sakila/model"
"github.com/stretchr/testify/require"
)
func TestRawStatementSelect(t *testing.T) {
stmt := RawStatement(`
SELECT actor.first_name AS "actor.first_name"
FROM actor
WHERE actor.actor_id = 2`)
testutils.AssertStatementSql(t, stmt, `
SELECT actor.first_name AS "actor.first_name"
FROM actor
WHERE actor.actor_id = 2;
`)
testutils.AssertDebugStatementSql(t, stmt, `
SELECT actor.first_name AS "actor.first_name"
FROM actor
WHERE actor.actor_id = 2;
`)
var actor model.Actor
err := stmt.Query(db, &actor)
require.NoError(t, err)
require.Equal(t, actor.FirstName, "NICK")
}
func TestRawStatementSelectWithArguments(t *testing.T) {
stmt := RawStatement(`
SELECT DISTINCT actor.actor_id AS "actor.actor_id",
actor.first_name AS "actor.first_name",
actor.last_name AS "actor.last_name",
actor.last_update AS "actor.last_update"
FROM actor
WHERE actor.actor_id IN (#actorID1, #actorID2, #actorID3) AND ((#actorID1 / #actorID2) <> (#actorID2 * #actorID3))
ORDER BY actor.actor_id`,
RawArgs{
"#actorID1": int64(1),
"#actorID2": int64(2),
"#actorID3": int64(3),
},
)
testutils.AssertStatementSql(t, stmt, `
SELECT DISTINCT actor.actor_id AS "actor.actor_id",
actor.first_name AS "actor.first_name",
actor.last_name AS "actor.last_name",
actor.last_update AS "actor.last_update"
FROM actor
WHERE actor.actor_id IN (?, ?, ?) AND ((? / ?) <> (? * ?))
ORDER BY actor.actor_id;
`, int64(1), int64(2), int64(3), int64(1), int64(2), int64(2), int64(3))
testutils.AssertDebugStatementSql(t, stmt, `
SELECT DISTINCT actor.actor_id AS "actor.actor_id",
actor.first_name AS "actor.first_name",
actor.last_name AS "actor.last_name",
actor.last_update AS "actor.last_update"
FROM actor
WHERE actor.actor_id IN (1, 2, 3) AND ((1 / 2) <> (2 * 3))
ORDER BY actor.actor_id;
`)
var actor []model.Actor
err := stmt.Query(db, &actor)
require.NoError(t, err)
testutils.AssertDeepEqual(t, actor[1], model.Actor{
ActorID: 2,
FirstName: "NICK",
LastName: "WAHLBERG",
LastUpdate: *testutils.TimestampWithoutTimeZone("2019-04-11 18:11:48", 2),
})
}
func TestRawStatementRows(t *testing.T) {
stmt := RawStatement(`
SELECT actor.actor_id AS "actor.actor_id",
actor.first_name AS "actor.first_name",
actor.last_name AS "actor.last_name",
actor.last_update AS "actor.last_update"
FROM actor
ORDER BY actor.actor_id`)
rows, err := stmt.Rows(context.Background(), db)
require.NoError(t, err)
for rows.Next() {
var actor model.Actor
err := rows.Scan(&actor)
require.NoError(t, err)
require.NotEqual(t, actor.ActorID, int16(0))
require.NotEqual(t, actor.FirstName, "")
require.NotEqual(t, actor.LastName, "")
require.NotEqual(t, actor.LastUpdate, time.Time{})
if actor.ActorID == 54 {
require.Equal(t, actor.ActorID, int32(54))
require.Equal(t, actor.FirstName, "PENELOPE")
require.Equal(t, actor.LastName, "PINKETT")
require.Equal(t, actor.LastUpdate.Format(time.RFC3339), "2019-04-11T18:11:48Z")
}
}
err = rows.Close()
require.NoError(t, err)
err = rows.Err()
require.NoError(t, err)
requireLogged(t, stmt)
}

749
tests/sqlite/select_test.go Normal file
View file

@ -0,0 +1,749 @@
package sqlite
import (
"context"
model2 "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/chinook/model"
"github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/chinook/table"
"strings"
"testing"
"time"
"github.com/go-jet/jet/v2/internal/testutils"
. "github.com/go-jet/jet/v2/sqlite"
"github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/sakila/model"
. "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/sakila/table"
"github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/sakila/view"
"github.com/stretchr/testify/require"
)
func TestSelect_ScanToStruct(t *testing.T) {
query := Actor.
SELECT(Actor.AllColumns).
DISTINCT().
WHERE(Actor.ActorID.EQ(Int(2)))
testutils.AssertStatementSql(t, query, `
SELECT DISTINCT actor.actor_id AS "actor.actor_id",
actor.first_name AS "actor.first_name",
actor.last_name AS "actor.last_name",
actor.last_update AS "actor.last_update"
FROM actor
WHERE actor.actor_id = ?;
`, int64(2))
actor := model.Actor{}
err := query.Query(db, &actor)
require.NoError(t, err)
testutils.AssertDeepEqual(t, actor, actor2)
requireLogged(t, query)
}
var actor2 = model.Actor{
ActorID: 2,
FirstName: "NICK",
LastName: "WAHLBERG",
LastUpdate: *testutils.TimestampWithoutTimeZone("2019-04-11 18:11:48", 2),
}
func TestSelect_ScanToSlice(t *testing.T) {
query := SELECT(Actor.AllColumns).
FROM(Actor).
ORDER_BY(Actor.ActorID)
testutils.AssertStatementSql(t, query, `
SELECT actor.actor_id AS "actor.actor_id",
actor.first_name AS "actor.first_name",
actor.last_name AS "actor.last_name",
actor.last_update AS "actor.last_update"
FROM actor
ORDER BY actor.actor_id;
`)
dest := []model.Actor{}
err := query.Query(db, &dest)
require.NoError(t, err)
require.Equal(t, len(dest), 200)
testutils.AssertDeepEqual(t, dest[1], actor2)
//testutils.SaveJSONFile(dest, "./testdata/results/sqlite/all_actors.json")
testutils.AssertJSONFile(t, dest, "./testdata/results/sqlite/all_actors.json")
requireLogged(t, query)
}
func TestSelectGroupByHaving(t *testing.T) {
expectedSQL := `
SELECT customer.customer_id AS "customer.customer_id",
customer.store_id AS "customer.store_id",
customer.first_name AS "customer.first_name",
customer.last_name AS "customer.last_name",
customer.email AS "customer.email",
customer.address_id AS "customer.address_id",
customer.active AS "customer.active",
customer.create_date AS "customer.create_date",
customer.last_update AS "customer.last_update",
SUM(payment.amount) AS "amount.sum",
AVG(payment.amount) AS "amount.avg",
MAX(payment.payment_date) AS "amount.max_date",
MAX(payment.amount) AS "amount.max",
MIN(payment.payment_date) AS "amount.min_date",
MIN(payment.amount) AS "amount.min",
COUNT(payment.amount) AS "amount.count"
FROM payment
INNER JOIN customer ON (customer.customer_id = payment.customer_id)
GROUP BY payment.customer_id
HAVING SUM(payment.amount) > 125.6
ORDER BY payment.customer_id, SUM(payment.amount) ASC;
`
query := Payment.
INNER_JOIN(Customer, Customer.CustomerID.EQ(Payment.CustomerID)).
SELECT(
Customer.AllColumns,
SUMf(Payment.Amount).AS("amount.sum"),
AVG(Payment.Amount).AS("amount.avg"),
MAX(Payment.PaymentDate).AS("amount.max_date"),
MAXf(Payment.Amount).AS("amount.max"),
MIN(Payment.PaymentDate).AS("amount.min_date"),
MINf(Payment.Amount).AS("amount.min"),
COUNT(Payment.Amount).AS("amount.count"),
).
GROUP_BY(Payment.CustomerID).
HAVING(
SUMf(Payment.Amount).GT(Float(125.6)),
).
ORDER_BY(
Payment.CustomerID, SUMf(Payment.Amount).ASC(),
)
testutils.AssertDebugStatementSql(t, query, expectedSQL, float64(125.6))
var dest []struct {
model.Customer
Amount struct {
Sum float64
Avg float64
Max float64
Min float64
Count int64
} `alias:"amount"`
}
err := query.Query(db, &dest)
require.NoError(t, err)
require.Equal(t, len(dest), 174)
//testutils.SaveJSONFile(dest, "./testdata/results/sqlite/customer_payment_sum.json")
testutils.AssertJSONFile(t, dest, "./testdata/results/sqlite/customer_payment_sum.json")
requireLogged(t, query)
}
func TestSubQuery(t *testing.T) {
rRatingFilms :=
SELECT(
Film.FilmID,
Film.Title,
Film.Rating,
).FROM(
Film,
).WHERE(Film.Rating.EQ(String("R"))).
AsTable("rFilms")
rFilmID := Film.FilmID.From(rRatingFilms)
main :=
SELECT(
Actor.AllColumns,
FilmActor.AllColumns,
rRatingFilms.AllColumns(),
).FROM(
rRatingFilms.
INNER_JOIN(FilmActor, FilmActor.FilmID.EQ(rFilmID)).
INNER_JOIN(Actor, Actor.ActorID.EQ(FilmActor.ActorID)),
).ORDER_BY(
rFilmID,
Actor.ActorID,
)
var dest []struct {
model.Film
Actors []model.Actor
}
err := main.Query(db, &dest)
require.NoError(t, err)
//testutils.SaveJSONFile(dest, "./testdata/results/sqlite/r_rating_films.json")
testutils.AssertJSONFile(t, dest, "./testdata/results/sqlite/r_rating_films.json")
}
func TestSelectAndUnionInProjection(t *testing.T) {
query := UNION(
SELECT(
Payment.PaymentID,
).FROM(Payment),
SELECT(
STAR,
).FROM(
SELECT(Payment.PaymentID).
FROM(Payment).LIMIT(1).OFFSET(2).AsTable("p"),
),
).LIMIT(1).OFFSET(10)
testutils.AssertDebugStatementSql(t, query, `
SELECT payment.payment_id AS "payment.payment_id"
FROM payment
UNION
SELECT *
FROM (
SELECT payment.payment_id AS "payment.payment_id"
FROM payment
LIMIT 1
OFFSET 2
) AS p
LIMIT 1
OFFSET 10;
`, int64(1), int64(2), int64(1), int64(10))
dest := []struct{}{}
err := query.Query(db, &dest)
require.NoError(t, err)
}
func TestSelectUNION(t *testing.T) {
expectedSQL := `
SELECT payment.payment_id AS "payment.payment_id"
FROM payment
WHERE payment.payment_id > ?
UNION
SELECT payment.payment_id AS "payment.payment_id"
FROM payment
WHERE payment.amount < ?
LIMIT ?;
`
query := UNION(
SELECT(Payment.PaymentID).
FROM(Payment).
WHERE(Payment.PaymentID.GT(Int(11))),
SELECT(Payment.PaymentID).
FROM(Payment).
WHERE(Payment.Amount.LT(Float(2000.0))),
).LIMIT(1)
testutils.AssertStatementSql(t, query, expectedSQL, int64(11), 2000.0, int64(1))
query2 :=
SELECT(
Payment.PaymentID,
).FROM(
Payment,
).WHERE(
Payment.PaymentID.GT(Int(11)),
).UNION(
SELECT(Payment.PaymentID).
FROM(Payment).
WHERE(Payment.Amount.LT(Float(2000.0))),
).LIMIT(1)
testutils.AssertStatementSql(t, query2, expectedSQL, int64(11), 2000.0, int64(1))
dest := []struct{}{}
err := query.Query(db, &dest)
require.NoError(t, err)
}
func TestSelectUNION_ALL(t *testing.T) {
expectedSQL := `
SELECT payment.payment_id AS "payment.payment_id"
FROM payment
WHERE payment.payment_id > ?
UNION ALL
SELECT payment.payment_id AS "payment.payment_id"
FROM payment
WHERE payment.amount < ?
LIMIT ?;
`
query := UNION_ALL(
SELECT(Payment.PaymentID).
FROM(Payment).
WHERE(Payment.PaymentID.GT(Int(11))),
SELECT(Payment.PaymentID).
FROM(Payment).
WHERE(Payment.Amount.LT(Float(2000.0))),
).LIMIT(1)
testutils.AssertStatementSql(t, query, expectedSQL, int64(11), 2000.0, int64(1))
dest := []struct{}{}
err := query.Query(db, &dest)
require.NoError(t, err)
}
func TestJoinQueryStruct(t *testing.T) {
expectedSQL := `
SELECT film_actor.actor_id AS "film_actor.actor_id",
film_actor.film_id AS "film_actor.film_id",
film_actor.last_update AS "film_actor.last_update",
film.film_id AS "film.film_id",
film.title AS "film.title",
film.description AS "film.description",
film.release_year AS "film.release_year",
film.language_id AS "film.language_id",
film.original_language_id AS "film.original_language_id",
film.rental_duration AS "film.rental_duration",
film.rental_rate AS "film.rental_rate",
film.length AS "film.length",
film.replacement_cost AS "film.replacement_cost",
film.rating AS "film.rating",
film.special_features AS "film.special_features",
film.last_update AS "film.last_update",
language.language_id AS "language.language_id",
language.name AS "language.name",
language.last_update AS "language.last_update",
actor.actor_id AS "actor.actor_id",
actor.first_name AS "actor.first_name",
actor.last_name AS "actor.last_name",
actor.last_update AS "actor.last_update",
inventory.inventory_id AS "inventory.inventory_id",
inventory.film_id AS "inventory.film_id",
inventory.store_id AS "inventory.store_id",
inventory.last_update AS "inventory.last_update",
rental.rental_id AS "rental.rental_id",
rental.rental_date AS "rental.rental_date",
rental.inventory_id AS "rental.inventory_id",
rental.customer_id AS "rental.customer_id",
rental.return_date AS "rental.return_date",
rental.staff_id AS "rental.staff_id",
rental.last_update AS "rental.last_update"
FROM language
INNER JOIN film ON (film.language_id = language.language_id)
INNER JOIN film_actor ON (film_actor.film_id = film.film_id)
INNER JOIN actor ON (actor.actor_id = film_actor.actor_id)
LEFT JOIN inventory ON (inventory.film_id = film.film_id)
LEFT JOIN rental ON (rental.inventory_id = inventory.inventory_id)
ORDER BY language.language_id ASC, film.film_id ASC, actor.actor_id ASC, inventory.inventory_id ASC, rental.rental_id ASC
LIMIT ?;
`
for i := 0; i < 2; i++ {
query :=
SELECT(
FilmActor.AllColumns,
Film.AllColumns,
Language.AllColumns,
Actor.AllColumns,
Inventory.AllColumns,
Rental.AllColumns,
).
FROM(
Language.
INNER_JOIN(Film, Film.LanguageID.EQ(Language.LanguageID)).
INNER_JOIN(FilmActor, FilmActor.FilmID.EQ(Film.FilmID)).
INNER_JOIN(Actor, Actor.ActorID.EQ(FilmActor.ActorID)).
LEFT_JOIN(Inventory, Inventory.FilmID.EQ(Film.FilmID)).
LEFT_JOIN(Rental, Rental.InventoryID.EQ(Inventory.InventoryID)),
).ORDER_BY(
Language.LanguageID.ASC(),
Film.FilmID.ASC(),
Actor.ActorID.ASC(),
Inventory.InventoryID.ASC(),
Rental.RentalID.ASC(),
).
LIMIT(1000)
testutils.AssertStatementSql(t, query, expectedSQL, int64(1000))
var dest []struct {
model.Language
Films []struct {
model.Film
Actors []struct {
model.Actor
}
Inventories *[]struct {
model.Inventory
Rentals *[]model.Rental
}
}
}
err := query.Query(db, &dest)
require.NoError(t, err)
testutils.AssertJSONFile(t, dest, "./testdata/results/sqlite/lang_film_actor_inventory_rental.json")
}
}
func TestExpressionWrappers(t *testing.T) {
query := SELECT(
BoolExp(Raw("true")),
IntExp(Raw("11")),
FloatExp(Raw("11.22")),
StringExp(Raw("'stringer'")),
TimeExp(Raw("'raw'")),
TimestampExp(Raw("'raw'")),
DateTimeExp(Raw("'raw'")),
DateExp(Raw("'date'")),
)
testutils.AssertStatementSql(t, query, `
SELECT true,
11,
11.22,
'stringer',
'raw',
'raw',
'raw',
'date';
`)
dest := []struct{}{}
err := query.Query(db, &dest)
require.NoError(t, err)
}
func TestWindowFunction(t *testing.T) {
var expectedSQL = `
SELECT AVG(payment.amount) OVER (),
AVG(payment.amount) OVER (PARTITION BY payment.customer_id),
MAX(payment.amount) OVER (ORDER BY payment.payment_date DESC),
MIN(payment.amount) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC),
SUM(payment.amount) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC ROWS BETWEEN 1 PRECEDING AND 6 FOLLOWING),
SUM(payment.amount) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING),
MAX(payment.customer_id) OVER (ORDER BY payment.payment_date DESC ROWS BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING),
MIN(payment.customer_id) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC),
SUM(payment.customer_id) OVER (PARTITION BY payment.customer_id ORDER BY payment.payment_date DESC),
ROW_NUMBER() OVER (ORDER BY payment.payment_date),
RANK() OVER (ORDER BY payment.payment_date),
DENSE_RANK() OVER (ORDER BY payment.payment_date),
CUME_DIST() OVER (ORDER BY payment.payment_date),
NTILE(11) OVER (ORDER BY payment.payment_date),
LAG(payment.amount) OVER (ORDER BY payment.payment_date),
LAG(payment.amount) OVER (ORDER BY payment.payment_date),
LAG(payment.amount, 2, payment.amount) OVER (ORDER BY payment.payment_date),
LAG(payment.amount, 2, ?) OVER (ORDER BY payment.payment_date),
LEAD(payment.amount) OVER (ORDER BY payment.payment_date),
LEAD(payment.amount) OVER (ORDER BY payment.payment_date),
LEAD(payment.amount, 2, payment.amount) OVER (ORDER BY payment.payment_date),
LEAD(payment.amount, 2, ?) OVER (ORDER BY payment.payment_date),
FIRST_VALUE(payment.amount) OVER (ORDER BY payment.payment_date),
LAST_VALUE(payment.amount) OVER (ORDER BY payment.payment_date),
NTH_VALUE(payment.amount, 3) OVER (ORDER BY payment.payment_date)
FROM payment
WHERE payment.payment_id < ?
GROUP BY payment.amount, payment.customer_id, payment.payment_date;
`
query :=
SELECT(
AVG(Payment.Amount).OVER(),
AVG(Payment.Amount).OVER(PARTITION_BY(Payment.CustomerID)),
MAXf(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate.DESC())),
MINf(Payment.Amount).OVER(PARTITION_BY(Payment.CustomerID).ORDER_BY(Payment.PaymentDate.DESC())),
SUMf(Payment.Amount).OVER(PARTITION_BY(Payment.CustomerID).
ORDER_BY(Payment.PaymentDate.DESC()).ROWS(PRECEDING(1), FOLLOWING(6))),
SUMf(Payment.Amount).OVER(PARTITION_BY(Payment.CustomerID).
ORDER_BY(Payment.PaymentDate.DESC()).RANGE(PRECEDING(UNBOUNDED), FOLLOWING(UNBOUNDED))),
MAXi(Payment.CustomerID).OVER(ORDER_BY(Payment.PaymentDate.DESC()).ROWS(CURRENT_ROW, FOLLOWING(UNBOUNDED))),
MINi(Payment.CustomerID).OVER(PARTITION_BY(Payment.CustomerID).ORDER_BY(Payment.PaymentDate.DESC())),
SUMi(Payment.CustomerID).OVER(PARTITION_BY(Payment.CustomerID).ORDER_BY(Payment.PaymentDate.DESC())),
ROW_NUMBER().OVER(ORDER_BY(Payment.PaymentDate)),
RANK().OVER(ORDER_BY(Payment.PaymentDate)),
DENSE_RANK().OVER(ORDER_BY(Payment.PaymentDate)),
CUME_DIST().OVER(ORDER_BY(Payment.PaymentDate)),
NTILE(11).OVER(ORDER_BY(Payment.PaymentDate)),
LAG(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)),
LAG(Payment.Amount, 2).OVER(ORDER_BY(Payment.PaymentDate)),
LAG(Payment.Amount, 2, Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)),
LAG(Payment.Amount, 2, 100).OVER(ORDER_BY(Payment.PaymentDate)),
LEAD(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)),
LEAD(Payment.Amount, 2).OVER(ORDER_BY(Payment.PaymentDate)),
LEAD(Payment.Amount, 2, Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)),
LEAD(Payment.Amount, 2, 100).OVER(ORDER_BY(Payment.PaymentDate)),
FIRST_VALUE(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)),
LAST_VALUE(Payment.Amount).OVER(ORDER_BY(Payment.PaymentDate)),
NTH_VALUE(Payment.Amount, 3).OVER(ORDER_BY(Payment.PaymentDate)),
).FROM(
Payment,
).GROUP_BY(
Payment.Amount,
Payment.CustomerID,
Payment.PaymentDate,
).WHERE(Payment.PaymentID.LT(Int(10)))
testutils.AssertStatementSql(t, query, expectedSQL, 100, 100, int64(10))
dest := []struct{}{}
err := query.Query(db, &dest)
require.NoError(t, err)
}
func TestWindowClause(t *testing.T) {
var expectedSQL = `
SELECT AVG(payment.amount) OVER (),
AVG(payment.amount) OVER (w1),
AVG(payment.amount) OVER (w2 ORDER BY payment.customer_id RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING),
AVG(payment.amount) OVER (w3 RANGE BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)
FROM payment
WHERE payment.payment_id < ?
WINDOW w1 AS (PARTITION BY payment.payment_date), w2 AS (w1), w3 AS (w2 ORDER BY payment.customer_id)
ORDER BY payment.customer_id;
`
query := SELECT(
AVG(Payment.Amount).OVER(),
AVG(Payment.Amount).OVER(Window("w1")),
AVG(Payment.Amount).OVER(
Window("w2").
ORDER_BY(Payment.CustomerID).
RANGE(PRECEDING(UNBOUNDED), FOLLOWING(UNBOUNDED)),
),
AVG(Payment.Amount).OVER(Window("w3").RANGE(PRECEDING(UNBOUNDED), FOLLOWING(UNBOUNDED))),
).FROM(
Payment,
).WHERE(
Payment.PaymentID.LT(Int(10)),
).
WINDOW("w1").AS(PARTITION_BY(Payment.PaymentDate)).
WINDOW("w2").AS(Window("w1")).
WINDOW("w3").AS(Window("w2").ORDER_BY(Payment.CustomerID)).
ORDER_BY(
Payment.CustomerID,
)
testutils.AssertStatementSql(t, query, expectedSQL, int64(10))
dest := []struct{}{}
err := query.Query(db, &dest)
require.NoError(t, err)
}
func TestSimpleView(t *testing.T) {
query :=
SELECT(
view.CustomerList.AllColumns,
).FROM(
view.CustomerList,
).ORDER_BY(
view.CustomerList.ID,
).LIMIT(10)
var dest []model.CustomerList
err := query.Query(db, &dest)
require.NoError(t, err)
require.Equal(t, len(dest), 10)
require.Equal(t, dest[2], model.CustomerList{
ID: testutils.Int32Ptr(3),
Name: testutils.StringPtr("LINDA WILLIAMS"),
Address: testutils.StringPtr("692 Joliet Street"),
ZipCode: testutils.StringPtr("83579"),
Phone: testutils.StringPtr(" "),
City: testutils.StringPtr("Athenai"),
Country: testutils.StringPtr("Greece"),
Notes: testutils.StringPtr("active"),
Sid: testutils.Int32Ptr(1),
})
}
func TestJoinViewWithTable(t *testing.T) {
query :=
SELECT(
view.CustomerList.AllColumns,
Rental.AllColumns,
).FROM(
view.CustomerList.
INNER_JOIN(Rental, view.CustomerList.ID.EQ(Rental.CustomerID)),
).ORDER_BY(
view.CustomerList.ID,
).WHERE(
view.CustomerList.ID.LT_EQ(Int(2)),
)
var dest []struct {
model.CustomerList `sql:"primary_key=ID"`
Rentals []model.Rental
}
err := query.Query(db, &dest)
require.NoError(t, err)
require.Equal(t, len(dest), 2)
require.Equal(t, len(dest[0].Rentals), 32)
require.Equal(t, len(dest[1].Rentals), 27)
}
func TestConditionalProjectionList(t *testing.T) {
projectionList := ProjectionList{}
columnsToSelect := []string{"customer_id", "create_date"}
for _, columnName := range columnsToSelect {
switch columnName {
case Customer.CustomerID.Name():
projectionList = append(projectionList, Customer.CustomerID)
case Customer.Email.Name():
projectionList = append(projectionList, Customer.Email)
case Customer.CreateDate.Name():
projectionList = append(projectionList, Customer.CreateDate)
}
}
stmt := SELECT(projectionList).
FROM(Customer).
LIMIT(3)
testutils.AssertDebugStatementSql(t, stmt, `
SELECT customer.customer_id AS "customer.customer_id",
customer.create_date AS "customer.create_date"
FROM customer
LIMIT 3;
`)
var dest []model.Customer
err := stmt.Query(db, &dest)
require.NoError(t, err)
require.Equal(t, len(dest), 3)
}
func TestUseAttachedDatabase(t *testing.T) {
Artists := table.Artists.FromSchema("chinook")
Albums := table.Albums.FromSchema("chinook")
stmt :=
SELECT(
Artists.AllColumns,
Albums.AllColumns,
).FROM(
Albums.
INNER_JOIN(Artists, Artists.ArtistId.EQ(Albums.ArtistId)),
).ORDER_BY(
Artists.ArtistId,
).LIMIT(10)
testutils.AssertDebugStatementSql(t, stmt, strings.Replace(`
SELECT artists.''ArtistId'' AS "artists.ArtistId",
artists.''Name'' AS "artists.Name",
albums.''AlbumId'' AS "albums.AlbumId",
albums.''Title'' AS "albums.Title",
albums.''ArtistId'' AS "albums.ArtistId"
FROM chinook.albums
INNER JOIN chinook.artists ON (artists.''ArtistId'' = albums.''ArtistId'')
ORDER BY artists.''ArtistId''
LIMIT 10;
`, "''", "`", -1))
var dest []struct {
model2.Artists
Albums []model2.Albums
}
err := stmt.Query(db, &dest)
require.NoError(t, err)
require.Len(t, dest, 7)
}
func TestRowsScan(t *testing.T) {
stmt :=
SELECT(
Inventory.AllColumns,
).FROM(
Inventory,
).ORDER_BY(
Inventory.InventoryID.ASC(),
)
rows, err := stmt.Rows(context.Background(), db)
require.NoError(t, err)
for rows.Next() {
var inventory model.Inventory
err = rows.Scan(&inventory)
require.NoError(t, err)
require.NotEqual(t, inventory.InventoryID, uint32(0))
require.NotEqual(t, inventory.FilmID, uint16(0))
require.NotEqual(t, inventory.StoreID, uint16(0))
require.NotEqual(t, inventory.LastUpdate, time.Time{})
if inventory.InventoryID == 2103 {
require.Equal(t, inventory.FilmID, int32(456))
require.Equal(t, inventory.StoreID, int32(2))
require.Equal(t, inventory.LastUpdate.Format(time.RFC3339), "2019-04-11T18:11:48Z")
}
}
err = rows.Close()
require.NoError(t, err)
err = rows.Err()
require.NoError(t, err)
requireLogged(t, stmt)
}
func TestScanNumericToNumber(t *testing.T) {
type Number struct {
Int8 int8
UInt8 uint8
Int16 int16
UInt16 uint16
Int32 int32
UInt32 uint32
Int64 int64
UInt64 uint64
Float32 float32
Float64 float64
}
numeric := CAST(String("1234567890.111")).AS_REAL()
stmt := SELECT(
numeric.AS("number.int8"),
numeric.AS("number.uint8"),
numeric.AS("number.int16"),
numeric.AS("number.uint16"),
numeric.AS("number.int32"),
numeric.AS("number.uint32"),
numeric.AS("number.int64"),
numeric.AS("number.uint64"),
numeric.AS("number.float32"),
numeric.AS("number.float64"),
)
var number Number
err := stmt.Query(db, &number)
require.NoError(t, err)
require.Equal(t, number.Int8, int8(-46)) // overflow
require.Equal(t, number.UInt8, uint8(210)) // overflow
require.Equal(t, number.Int16, int16(722)) // overflow
require.Equal(t, number.UInt16, uint16(722)) // overflow
require.Equal(t, number.Int32, int32(1234567890))
require.Equal(t, number.UInt32, uint32(1234567890))
require.Equal(t, number.Int64, int64(1234567890))
require.Equal(t, number.UInt64, uint64(1234567890))
require.Equal(t, number.Float32, float32(1.234568e+09))
require.Equal(t, number.Float64, float64(1.234567890111e+09))
}

290
tests/sqlite/update_test.go Normal file
View file

@ -0,0 +1,290 @@
package sqlite
import (
"context"
"testing"
"time"
"github.com/go-jet/jet/v2/internal/testutils"
. "github.com/go-jet/jet/v2/sqlite"
"github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/test_sample/model"
. "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/test_sample/table"
"github.com/stretchr/testify/require"
)
func TestUpdateValues(t *testing.T) {
tx := beginSampleDBTx(t)
defer tx.Rollback()
var expectedSQL = `
UPDATE link
SET name = 'Bong',
url = 'http://bong.com'
WHERE link.name = 'Bing';
`
t.Run("old version", func(t *testing.T) {
query := Link.UPDATE(Link.Name, Link.URL).
SET("Bong", "http://bong.com").
WHERE(Link.Name.EQ(String("Bing")))
testutils.AssertDebugStatementSql(t, query, expectedSQL, "Bong", "http://bong.com", "Bing")
testutils.AssertExec(t, query, tx)
requireLogged(t, query)
})
t.Run("new version", func(t *testing.T) {
stmt := Link.UPDATE().
SET(
Link.Name.SET(String("Bong")),
Link.URL.SET(String("http://bong.com")),
).
WHERE(Link.Name.EQ(String("Bing")))
testutils.AssertDebugStatementSql(t, stmt, expectedSQL, "Bong", "http://bong.com", "Bing")
testutils.AssertExec(t, stmt, tx)
requireLogged(t, stmt)
})
links := []model.Link{}
err := SELECT(Link.AllColumns).
FROM(Link).
WHERE(Link.Name.EQ(String("Bong"))).
Query(tx, &links)
require.NoError(t, err)
require.Equal(t, len(links), 1)
testutils.AssertDeepEqual(t, links[0], model.Link{
ID: 24,
URL: "http://bong.com",
Name: "Bong",
})
}
func TestUpdateWithSubQueries(t *testing.T) {
tx := beginSampleDBTx(t)
defer tx.Rollback()
expectedSQL := `
UPDATE link
SET name = ?,
url = (
SELECT link.url AS "link.url"
FROM link
WHERE link.name = ?
)
WHERE link.name = ?;
`
t.Run("old version", func(t *testing.T) {
query := Link.
UPDATE(Link.Name, Link.URL).
SET(
String("Bong"),
SELECT(Link.URL).
FROM(Link).
WHERE(Link.Name.EQ(String("Ask"))),
).
WHERE(Link.Name.EQ(String("Bing")))
testutils.AssertStatementSql(t, query, expectedSQL, "Bong", "Ask", "Bing")
testutils.AssertExec(t, query, tx)
requireLogged(t, query)
})
t.Run("new version", func(t *testing.T) {
query := Link.
UPDATE().
SET(
Link.Name.SET(String("Bong")),
Link.URL.SET(StringExp(
SELECT(Link.URL).
FROM(Link).
WHERE(Link.Name.EQ(String("Ask"))),
)),
).
WHERE(Link.Name.EQ(String("Bing")))
testutils.AssertStatementSql(t, query, expectedSQL, "Bong", "Ask", "Bing")
testutils.AssertExec(t, query, tx)
requireLogged(t, query)
})
}
func TestUpdateWithModelDataAndReturning(t *testing.T) {
tx := beginSampleDBTx(t)
defer tx.Rollback()
link := model.Link{
ID: 20,
URL: "http://www.duckduckgo.com",
Name: "DuckDuckGo",
}
stmt := Link.UPDATE(Link.AllColumns).
MODEL(link).
WHERE(Link.ID.EQ(Int32(link.ID))).
RETURNING(
Link.AllColumns,
String("str").AS("dest.literal"),
NOT(Bool(false)).AS("dest.unary_operator"),
Link.ID.ADD(Int(11)).AS("dest.binary_operator"),
CAST(Link.ID).AS_TEXT().AS("dest.cast_operator"),
Link.Name.LIKE(String("Bing")).AS("dest.like_operator"),
Link.Description.IS_NULL().AS("dest.is_null"),
CASE(Link.Name).
WHEN(String("Yahoo")).THEN(String("search")).
WHEN(String("GMail")).THEN(String("mail")).
ELSE(String("unknown")).AS("dest.case_operator"),
)
expectedSQL := `
UPDATE link
SET id = ?,
url = ?,
name = ?,
description = ?
WHERE link.id = ?
RETURNING link.id AS "link.id",
link.url AS "link.url",
link.name AS "link.name",
link.description AS "link.description",
? AS "dest.literal",
(NOT ?) AS "dest.unary_operator",
(link.id + ?) AS "dest.binary_operator",
CAST(link.id AS TEXT) AS "dest.cast_operator",
(link.name LIKE ?) AS "dest.like_operator",
link.description IS NULL AS "dest.is_null",
(CASE link.name WHEN ? THEN ? WHEN ? THEN ? ELSE ? END) AS "dest.case_operator";
`
testutils.AssertStatementSql(t, stmt, expectedSQL, int32(20), "http://www.duckduckgo.com", "DuckDuckGo", nil, int32(20),
"str", false, int64(11), "Bing", "Yahoo", "search", "GMail", "mail", "unknown")
type Dest struct {
model.Link
Literal string
UnaryOperator bool
BinaryOperator int64
CastOperator string
LikeOperator bool
IsNull bool
CaseOperator string
}
var dest Dest
err := stmt.Query(tx, &dest)
require.NoError(t, err)
require.EqualValues(t, dest, Dest{
Link: link,
Literal: "str",
UnaryOperator: true,
BinaryOperator: 31,
CastOperator: "20",
LikeOperator: false,
IsNull: true,
CaseOperator: "unknown",
})
requireLogged(t, stmt)
}
func TestUpdateWithModelDataAndPredefinedColumnList(t *testing.T) {
tx := beginSampleDBTx(t)
defer tx.Rollback()
link := model.Link{
ID: 20,
URL: "http://www.duckduckgo.com",
Name: "DuckDuckGo",
}
updateColumnList := ColumnList{Link.Description, Link.Name, Link.URL}
stmt := Link.UPDATE(updateColumnList).
MODEL(link).
WHERE(Link.ID.EQ(Int32(link.ID)))
var expectedSQL = `
UPDATE link
SET description = NULL,
name = 'DuckDuckGo',
url = 'http://www.duckduckgo.com'
WHERE link.id = 20;
`
testutils.AssertDebugStatementSql(t, stmt, expectedSQL, nil, "DuckDuckGo", "http://www.duckduckgo.com", int32(20))
testutils.AssertExec(t, stmt, tx)
requireLogged(t, stmt)
}
func TestUpdateWithModelDataAndMutableColumns(t *testing.T) {
tx := beginSampleDBTx(t)
defer tx.Rollback()
link := model.Link{
ID: 201,
URL: "http://www.duckduckgo.com",
Name: "DuckDuckGo",
}
stmt := Link.UPDATE(Link.MutableColumns).
MODEL(link).
WHERE(Link.ID.EQ(Int32(link.ID)))
var expectedSQL = `
UPDATE link
SET url = 'http://www.duckduckgo.com',
name = 'DuckDuckGo',
description = NULL
WHERE link.id = 201;
`
testutils.AssertDebugStatementSql(t, stmt, expectedSQL, "http://www.duckduckgo.com", "DuckDuckGo", nil, int32(201))
testutils.AssertExec(t, stmt, tx)
}
func TestUpdateWithInvalidModelData(t *testing.T) {
defer func() {
r := recover()
require.Equal(t, r, "missing struct field for column : id")
}()
link := struct {
Ident int
URL string
Name string
Description *string
Rel *string
}{
Ident: 201,
URL: "http://www.duckduckgo.com",
Name: "DuckDuckGo",
}
stmt := Link.UPDATE(Link.AllColumns).
MODEL(link).
WHERE(Link.ID.EQ(Int(int64(link.Ident))))
stmt.Sql()
}
func TestUpdateContextDeadlineExceeded(t *testing.T) {
tx := beginSampleDBTx(t)
defer tx.Rollback()
updateStmt := Link.UPDATE(Link.Name, Link.URL).
SET("Bong", "http://bong.com").
WHERE(Link.Name.EQ(String("Bing")))
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Microsecond)
defer cancel()
time.Sleep(10 * time.Millisecond)
dest := []model.Link{}
err := updateStmt.QueryContext(ctx, tx, &dest)
require.Error(t, err, "context deadline exceeded")
_, err = updateStmt.ExecContext(ctx, tx)
require.Error(t, err, "context deadline exceeded")
}

234
tests/sqlite/with_test.go Normal file
View file

@ -0,0 +1,234 @@
package sqlite
import (
"github.com/go-jet/jet/v2/internal/testutils"
. "github.com/go-jet/jet/v2/sqlite"
. "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/sakila/table"
"github.com/stretchr/testify/require"
"strings"
"testing"
)
func TestWITH_And_SELECT(t *testing.T) {
salesRep := CTE("sales_rep")
salesRepStaffID := Staff.StaffID.From(salesRep)
salesRepFullName := StringColumn("sales_rep_full_name").From(salesRep)
customerSalesRep := CTE("customer_sales_rep")
stmt := WITH(
salesRep.AS(
SELECT(
Staff.StaffID,
Staff.FirstName.CONCAT(Staff.LastName).AS(salesRepFullName.Name()),
).FROM(Staff),
),
customerSalesRep.AS(
SELECT(
Customer.FirstName.CONCAT(Customer.LastName).AS("customer_name"),
salesRepFullName,
).FROM(
salesRep.
INNER_JOIN(Store, Store.ManagerStaffID.EQ(salesRepStaffID)).
INNER_JOIN(Customer, Customer.StoreID.EQ(Store.StoreID)),
),
),
)(
SELECT(customerSalesRep.AllColumns()).
FROM(customerSalesRep),
)
testutils.AssertStatementSql(t, stmt, strings.Replace(`
WITH sales_rep AS (
SELECT staff.staff_id AS "staff.staff_id",
(staff.first_name || staff.last_name) AS "sales_rep_full_name"
FROM staff
),customer_sales_rep AS (
SELECT (customer.first_name || customer.last_name) AS "customer_name",
sales_rep.sales_rep_full_name AS "sales_rep_full_name"
FROM sales_rep
INNER JOIN store ON (store.manager_staff_id = sales_rep.''staff.staff_id'')
INNER JOIN customer ON (customer.store_id = store.store_id)
)
SELECT customer_sales_rep.customer_name AS "customer_name",
customer_sales_rep.sales_rep_full_name AS "sales_rep_full_name"
FROM customer_sales_rep;
`, "''", "`", -1))
var dest []struct {
CustomerName string
SalesRepFullName string
}
err := stmt.Query(db, &dest)
require.NoError(t, err)
require.Equal(t, len(dest), 599)
}
func TestWITH_And_INSERT(t *testing.T) {
paymentsToInsert := CTE("payments_to_insert")
stmt := WITH(
paymentsToInsert.AS(
SELECT(Payment.AllColumns).
FROM(Payment).
WHERE(Payment.Amount.LT(Float(0.5))),
),
)(
Payment.INSERT(Payment.AllColumns).
QUERY(
SELECT(
paymentsToInsert.AllColumns(),
).FROM(
paymentsToInsert,
).WHERE(Bool(true)), //https://stackoverflow.com/questions/66230093/error-while-doing-upsert-in-sqlite-3-34-error-near-do-syntax-error
).ON_CONFLICT().DO_UPDATE(
SET(
Payment.PaymentID.SET(Payment.PaymentID.ADD(Int(100000))),
),
),
)
testutils.AssertDebugStatementSql(t, stmt, strings.Replace(`
WITH payments_to_insert AS (
SELECT payment.payment_id AS "payment.payment_id",
payment.customer_id AS "payment.customer_id",
payment.staff_id AS "payment.staff_id",
payment.rental_id AS "payment.rental_id",
payment.amount AS "payment.amount",
payment.payment_date AS "payment.payment_date",
payment.last_update AS "payment.last_update"
FROM payment
WHERE payment.amount < 0.5
)
INSERT INTO payment (payment_id, customer_id, staff_id, rental_id, amount, payment_date, last_update)
SELECT payments_to_insert.''payment.payment_id'' AS "payment.payment_id",
payments_to_insert.''payment.customer_id'' AS "payment.customer_id",
payments_to_insert.''payment.staff_id'' AS "payment.staff_id",
payments_to_insert.''payment.rental_id'' AS "payment.rental_id",
payments_to_insert.''payment.amount'' AS "payment.amount",
payments_to_insert.''payment.payment_date'' AS "payment.payment_date",
payments_to_insert.''payment.last_update'' AS "payment.last_update"
FROM payments_to_insert
WHERE TRUE
ON CONFLICT DO UPDATE
SET payment_id = (payment.payment_id + 100000);
`, "''", "`", -1))
tx := beginDBTx(t)
defer tx.Rollback()
testutils.AssertExec(t, stmt, tx, 24)
}
func TestWITH_SELECT_UPDATE(t *testing.T) {
paymentsToUpdate := CTE("payments_to_update")
paymentsToDeleteID := Payment.PaymentID.From(paymentsToUpdate)
stmt := WITH(
paymentsToUpdate.AS(
SELECT(Payment.AllColumns).
FROM(Payment).
WHERE(Payment.Amount.LT(Float(0.5))),
),
)(
Payment.UPDATE().
SET(Payment.Amount.SET(Float(0.0))).
WHERE(Payment.PaymentID.IN(
SELECT(paymentsToDeleteID).
FROM(paymentsToUpdate),
),
),
)
testutils.AssertDebugStatementSql(t, stmt, strings.Replace(`
WITH payments_to_update AS (
SELECT payment.payment_id AS "payment.payment_id",
payment.customer_id AS "payment.customer_id",
payment.staff_id AS "payment.staff_id",
payment.rental_id AS "payment.rental_id",
payment.amount AS "payment.amount",
payment.payment_date AS "payment.payment_date",
payment.last_update AS "payment.last_update"
FROM payment
WHERE payment.amount < 0.5
)
UPDATE payment
SET amount = 0
WHERE payment.payment_id IN (
SELECT payments_to_update.''payment.payment_id'' AS "payment.payment_id"
FROM payments_to_update
);
`, "''", "`", -1))
tx := beginDBTx(t)
defer tx.Rollback()
testutils.AssertExec(t, stmt, tx)
}
func TestWITH_And_DELETE(t *testing.T) {
paymentsToDelete := CTE("payments_to_delete")
paymentsToDeleteID := Payment.PaymentID.From(paymentsToDelete)
stmt := WITH(
paymentsToDelete.AS(
SELECT(
Payment.AllColumns,
).FROM(
Payment,
).WHERE(
Payment.Amount.LT(Float(0.5)),
),
),
)(
Payment.DELETE().
WHERE(
Payment.PaymentID.IN(
SELECT(
paymentsToDeleteID,
).FROM(
paymentsToDelete,
),
),
),
)
testutils.AssertDebugStatementSql(t, stmt, strings.Replace(`
WITH payments_to_delete AS (
SELECT payment.payment_id AS "payment.payment_id",
payment.customer_id AS "payment.customer_id",
payment.staff_id AS "payment.staff_id",
payment.rental_id AS "payment.rental_id",
payment.amount AS "payment.amount",
payment.payment_date AS "payment.payment_date",
payment.last_update AS "payment.last_update"
FROM payment
WHERE payment.amount < 0.5
)
DELETE FROM payment
WHERE payment.payment_id IN (
SELECT payments_to_delete.''payment.payment_id'' AS "payment.payment_id"
FROM payments_to_delete
);
`, "''", "`", -1))
tx := beginDBTx(t)
defer tx.Rollback()
testutils.AssertExec(t, stmt, tx, 24)
}
func TestOperatorIN(t *testing.T) {
stmt := SELECT(Payment.PaymentID.IN(SELECT(Int(11)), Int(22))).
FROM(Payment)
testutils.AssertDebugStatementSql(t, stmt, `
SELECT payment.payment_id IN ((
SELECT 11
), 22)
FROM payment;
`)
var dest []struct{}
err := stmt.Query(db, &dest)
require.NoError(t, err)
}

@ -1 +1 @@
Subproject commit a6c1975a167645f913496131ae81d4cabc070046 Subproject commit 946bc1e5d3e162154eade8b79ff915e4c4986efd