From c94216ab0e7d3251d4b5d7cc189dd5f4ce450bc3 Mon Sep 17 00:00:00 2001 From: go-jet Date: Fri, 28 Feb 2025 18:23:15 +0100 Subject: [PATCH] Add support for blob expressions. --- generator/template/sql_builder_template.go | 10 +- internal/jet/blob_expression.go | 103 ++++++ internal/jet/column_types.go | 43 ++- internal/jet/dialect.go | 12 +- internal/jet/expression.go | 2 +- internal/jet/func_expression.go | 57 +-- internal/jet/literal_expression.go | 5 + internal/jet/sql_builder.go | 14 +- internal/jet/string_expression.go | 3 + internal/jet/string_or_blob_expression.go | 8 + internal/jet/utils.go | 4 +- internal/utils/is/is.go | 4 +- mysql/cast.go | 4 +- mysql/columns.go | 6 + mysql/dialect.go | 11 + mysql/expressions.go | 9 + mysql/functions.go | 58 ++- mysql/literal.go | 5 + postgres/cast.go | 6 +- postgres/columns.go | 6 + postgres/dialect.go | 13 +- postgres/expressions.go | 27 +- postgres/functions.go | 72 +++- postgres/literal.go | 2 +- sqlite/cast.go | 4 +- sqlite/columns.go | 6 + sqlite/dialect.go | 13 +- sqlite/expressions.go | 8 + sqlite/functions.go | 13 +- sqlite/literal.go | 4 + tests/mysql/alltypes_test.go | 74 +++- tests/mysql/generator_test.go | 392 +++++++++++++++++++++ tests/mysql/main_test.go | 3 + tests/postgres/alltypes_test.go | 266 +++++++++++++- tests/postgres/generator_test.go | 8 +- tests/postgres/main_test.go | 7 +- tests/sqlite/alltypes_test.go | 95 ++++- 37 files changed, 1296 insertions(+), 81 deletions(-) create mode 100644 internal/jet/blob_expression.go create mode 100644 internal/jet/string_or_blob_expression.go diff --git a/generator/template/sql_builder_template.go b/generator/template/sql_builder_template.go index cfd40aa..dcb4e97 100644 --- a/generator/template/sql_builder_template.go +++ b/generator/template/sql_builder_template.go @@ -180,11 +180,15 @@ func getSqlBuilderColumnType(columnMetaData metadata.Column) string { return "Timez" case "interval": return "Interval" - case "user-defined", "enum", "text", "character", "character varying", "bytea", "uuid", + case "user-defined", "enum", "text", "character", "character varying", "uuid", "tsvector", "bit", "bit varying", "money", "json", "jsonb", "xml", "point", "line", "ARRAY", - "char", "varchar", "nvarchar", "binary", "varbinary", "bpchar", "varbit", - "tinyblob", "blob", "mediumblob", "longblob", "tinytext", "mediumtext", "longtext": // MySQL + "char", "varchar", "nvarchar", "bpchar", "varbit", + "tinytext", "mediumtext", "longtext": // MySQL return "String" + case "bytea": // postgres + return "Bytea" + case "binary", "varbinary", "tinyblob", "mediumblob", "longblob", "blob": // mysql and sqlite + return "Blob" case "real", "numeric", "decimal", "double precision", "float", "float4", "float8", "double": // MySQL return "Float" diff --git a/internal/jet/blob_expression.go b/internal/jet/blob_expression.go new file mode 100644 index 0000000..47cdf9a --- /dev/null +++ b/internal/jet/blob_expression.go @@ -0,0 +1,103 @@ +package jet + +// BlobExpression interface +type BlobExpression interface { + Expression + + isStringOrBlob() + + EQ(rhs BlobExpression) BoolExpression + NOT_EQ(rhs BlobExpression) BoolExpression + IS_DISTINCT_FROM(rhs BlobExpression) BoolExpression + IS_NOT_DISTINCT_FROM(rhs BlobExpression) BoolExpression + + LT(rhs BlobExpression) BoolExpression + LT_EQ(rhs BlobExpression) BoolExpression + GT(rhs BlobExpression) BoolExpression + GT_EQ(rhs BlobExpression) BoolExpression + BETWEEN(min, max BlobExpression) BoolExpression + NOT_BETWEEN(min, max BlobExpression) BoolExpression + + CONCAT(rhs BlobExpression) BlobExpression + + LIKE(pattern BlobExpression) BoolExpression + NOT_LIKE(pattern BlobExpression) BoolExpression +} + +type blobInterfaceImpl struct { + parent BlobExpression +} + +func (s *blobInterfaceImpl) isStringOrBlob() {} + +func (s *blobInterfaceImpl) EQ(rhs BlobExpression) BoolExpression { + return Eq(s.parent, rhs) +} + +func (s *blobInterfaceImpl) NOT_EQ(rhs BlobExpression) BoolExpression { + return NotEq(s.parent, rhs) +} + +func (s *blobInterfaceImpl) IS_DISTINCT_FROM(rhs BlobExpression) BoolExpression { + return IsDistinctFrom(s.parent, rhs) +} + +func (s *blobInterfaceImpl) IS_NOT_DISTINCT_FROM(rhs BlobExpression) BoolExpression { + return IsNotDistinctFrom(s.parent, rhs) +} + +func (s *blobInterfaceImpl) GT(rhs BlobExpression) BoolExpression { + return Gt(s.parent, rhs) +} + +func (s *blobInterfaceImpl) GT_EQ(rhs BlobExpression) BoolExpression { + return GtEq(s.parent, rhs) +} + +func (s *blobInterfaceImpl) LT(rhs BlobExpression) BoolExpression { + return Lt(s.parent, rhs) +} + +func (s *blobInterfaceImpl) LT_EQ(rhs BlobExpression) BoolExpression { + return LtEq(s.parent, rhs) +} + +func (s *blobInterfaceImpl) BETWEEN(min, max BlobExpression) BoolExpression { + return NewBetweenOperatorExpression(s.parent, min, max, false) +} + +func (s *blobInterfaceImpl) NOT_BETWEEN(min, max BlobExpression) BoolExpression { + return NewBetweenOperatorExpression(s.parent, min, max, true) +} + +func (s *blobInterfaceImpl) CONCAT(rhs BlobExpression) BlobExpression { + return BlobExp(newBinaryStringOperatorExpression(s.parent, rhs, StringConcatOperator)) +} + +func (s *blobInterfaceImpl) LIKE(pattern BlobExpression) BoolExpression { + return newBinaryBoolOperatorExpression(s.parent, pattern, "LIKE") +} + +func (s *blobInterfaceImpl) NOT_LIKE(pattern BlobExpression) BoolExpression { + return newBinaryBoolOperatorExpression(s.parent, pattern, "NOT LIKE") +} + +//---------------------------------------------------// + +type blobExpressionWrapper struct { + blobInterfaceImpl + Expression +} + +func newBlobExpressionWrap(expression Expression) BlobExpression { + blobExpressionWrap := blobExpressionWrapper{Expression: expression} + blobExpressionWrap.blobInterfaceImpl.parent = &blobExpressionWrap + return &blobExpressionWrap +} + +// BlobExp is blob expression wrapper around arbitrary expression. +// Allows go compiler to see any expression as blob expression. +// Does not add sql cast to generated sql builder output. +func BlobExp(expression Expression) BlobExpression { + return newBlobExpressionWrap(expression) +} diff --git a/internal/jet/column_types.go b/internal/jet/column_types.go index a732061..98e9117 100644 --- a/internal/jet/column_types.go +++ b/internal/jet/column_types.go @@ -122,7 +122,7 @@ func IntegerColumn(name string) ColumnInteger { //------------------------------------------------------// // ColumnString is interface for SQL text, character, character varying -// bytea, uuid columns and enums types. +// uuid columns and enums types. type ColumnString interface { StringExpression Column @@ -163,6 +163,47 @@ func StringColumn(name string) ColumnString { //------------------------------------------------------// +// ColumnBlob is interface for binary data types (bytea, binary, blob, etc...) +type ColumnBlob interface { + BlobExpression + Column + + From(subQuery SelectTable) ColumnBlob + SET(blob BlobExpression) ColumnAssigment +} + +type blobColumnImpl struct { + blobInterfaceImpl + + ColumnExpressionImpl +} + +func (i *blobColumnImpl) From(subQuery SelectTable) ColumnBlob { + newBlobColumn := BlobColumn(i.name) + newBlobColumn.setTableName(i.tableName) + newBlobColumn.setSubQuery(subQuery) + + return newBlobColumn +} + +func (i *blobColumnImpl) SET(blobExp BlobExpression) ColumnAssigment { + return columnAssigmentImpl{ + column: i, + expression: blobExp, + } +} + +// BlobColumn creates named blob column. +func BlobColumn(name string) ColumnBlob { + blobColumn := &blobColumnImpl{} + blobColumn.blobInterfaceImpl.parent = blobColumn + blobColumn.ColumnExpressionImpl = NewColumnImpl(name, "", blobColumn) + + return blobColumn +} + +//------------------------------------------------------// + // ColumnTime is interface for SQL time column. type ColumnTime interface { TimeExpression diff --git a/internal/jet/dialect.go b/internal/jet/dialect.go index 68c4c02..e38c581 100644 --- a/internal/jet/dialect.go +++ b/internal/jet/dialect.go @@ -1,6 +1,8 @@ package jet -import "strings" +import ( + "strings" +) // Dialect interface type Dialect interface { @@ -11,6 +13,7 @@ type Dialect interface { AliasQuoteChar() byte IdentifierQuoteChar() byte ArgumentPlaceholder() QueryPlaceholderFunc + ArgumentToString(value any) (string, bool) IsReservedWord(name string) bool SerializeOrderBy() func(expression Expression, ascending, nullsFirst *bool) SerializerFunc ValuesDefaultColumnName(index int) string @@ -34,6 +37,7 @@ type DialectParams struct { AliasQuoteChar byte IdentifierQuoteChar byte ArgumentPlaceholder QueryPlaceholderFunc + ArgumentToString func(value any) (string, bool) ReservedWords []string SerializeOrderBy func(expression Expression, ascending, nullsFirst *bool) SerializerFunc ValuesDefaultColumnName func(index int) string @@ -49,6 +53,7 @@ func NewDialect(params DialectParams) Dialect { aliasQuoteChar: params.AliasQuoteChar, identifierQuoteChar: params.IdentifierQuoteChar, argumentPlaceholder: params.ArgumentPlaceholder, + argumentToString: params.ArgumentToString, reservedWords: arrayOfStringsToMapOfStrings(params.ReservedWords), serializeOrderBy: params.SerializeOrderBy, valuesDefaultColumnName: params.ValuesDefaultColumnName, @@ -63,6 +68,7 @@ type dialectImpl struct { aliasQuoteChar byte identifierQuoteChar byte argumentPlaceholder QueryPlaceholderFunc + argumentToString func(value any) (string, bool) reservedWords map[string]bool serializeOrderBy func(expression Expression, ascending, nullsFirst *bool) SerializerFunc valuesDefaultColumnName func(index int) string @@ -102,6 +108,10 @@ func (d *dialectImpl) ArgumentPlaceholder() QueryPlaceholderFunc { return d.argumentPlaceholder } +func (d *dialectImpl) ArgumentToString(value any) (string, bool) { + return d.argumentToString(value) +} + func (d *dialectImpl) IsReservedWord(name string) bool { _, isReservedWord := d.reservedWords[strings.ToLower(name)] return isReservedWord diff --git a/internal/jet/expression.go b/internal/jet/expression.go index 5c152ce..1065010 100644 --- a/internal/jet/expression.go +++ b/internal/jet/expression.go @@ -159,7 +159,7 @@ func newExpressionListOperator(operator string, expressions ...Expression) *expr } func newBoolExpressionListOperator(operator string, expressions ...BoolExpression) BoolExpression { - return BoolExp(newExpressionListOperator(operator, BoolExpressionListToExpressionList(expressions)...)) + return BoolExp(newExpressionListOperator(operator, ToExpressionList(expressions)...)) } func (elo *expressionListOperator) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { diff --git a/internal/jet/func_expression.go b/internal/jet/func_expression.go index ddc579e..46c8f0d 100644 --- a/internal/jet/func_expression.go +++ b/internal/jet/func_expression.go @@ -255,18 +255,30 @@ func leadLagImpl(name string, expr Expression, offsetAndDefault ...interface{}) //------------ String functions ------------------// +// HEX function takes an input and returns its equivalent hexadecimal representation +func HEX(expression Expression) StringExpression { + return StringExp(Func("HEX", expression)) +} + +// UNHEX for a string argument str, UNHEX(str) interprets each pair of characters in the argument +// as a hexadecimal number and converts it to the byte represented by the number. +// The return value is a binary string. +func UNHEX(expression StringExpression) BlobExpression { + return BlobExp(Func("UNHEX", expression)) +} + // BIT_LENGTH returns number of bits in string expression -func BIT_LENGTH(stringExpression StringExpression) IntegerExpression { +func BIT_LENGTH(stringExpression StringOrBlobExpression) IntegerExpression { return newIntegerFunc("BIT_LENGTH", stringExpression) } // CHAR_LENGTH returns number of characters in string expression -func CHAR_LENGTH(stringExpression StringExpression) IntegerExpression { +func CHAR_LENGTH(stringExpression StringOrBlobExpression) IntegerExpression { return newIntegerFunc("CHAR_LENGTH", stringExpression) } // OCTET_LENGTH returns number of bytes in string expression -func OCTET_LENGTH(stringExpression StringExpression) IntegerExpression { +func OCTET_LENGTH(stringExpression StringOrBlobExpression) IntegerExpression { return newIntegerFunc("OCTET_LENGTH", stringExpression) } @@ -282,7 +294,7 @@ func UPPER(stringExpression StringExpression) StringExpression { // BTRIM removes the longest string consisting only of characters // in characters (a space by default) from the start and end of string -func BTRIM(stringExpression StringExpression, trimChars ...StringExpression) StringExpression { +func BTRIM(stringExpression StringOrBlobExpression, trimChars ...StringOrBlobExpression) StringExpression { if len(trimChars) > 0 { return NewStringFunc("BTRIM", stringExpression, trimChars[0]) } @@ -291,7 +303,7 @@ func BTRIM(stringExpression StringExpression, trimChars ...StringExpression) Str // LTRIM removes the longest string containing only characters // from characters (a space by default) from the start of string -func LTRIM(str StringExpression, trimChars ...StringExpression) StringExpression { +func LTRIM(str StringOrBlobExpression, trimChars ...StringOrBlobExpression) StringExpression { if len(trimChars) > 0 { return NewStringFunc("LTRIM", str, trimChars[0]) } @@ -300,7 +312,7 @@ func LTRIM(str StringExpression, trimChars ...StringExpression) StringExpression // RTRIM removes the longest string containing only characters // from characters (a space by default) from the end of string -func RTRIM(str StringExpression, trimChars ...StringExpression) StringExpression { +func RTRIM(str StringOrBlobExpression, trimChars ...StringOrBlobExpression) StringExpression { if len(trimChars) > 0 { return NewStringFunc("RTRIM", str, trimChars[0]) } @@ -324,32 +336,32 @@ func CONCAT_WS(separator Expression, expressions ...Expression) StringExpression // CONVERT converts string to dest_encoding. The original encoding is // specified by src_encoding. The string must be valid in this encoding. -func CONVERT(str StringExpression, srcEncoding StringExpression, destEncoding StringExpression) StringExpression { - return NewStringFunc("CONVERT", str, srcEncoding, destEncoding) +func CONVERT(str BlobExpression, srcEncoding StringExpression, destEncoding StringExpression) BlobExpression { + return BlobExp(Func("CONVERT", str, srcEncoding, destEncoding)) } // CONVERT_FROM converts string to the database encoding. The original // encoding is specified by src_encoding. The string must be valid in this encoding. -func CONVERT_FROM(str StringExpression, srcEncoding StringExpression) StringExpression { +func CONVERT_FROM(str BlobExpression, srcEncoding StringExpression) StringExpression { return NewStringFunc("CONVERT_FROM", str, srcEncoding) } // CONVERT_TO converts string to dest_encoding. -func CONVERT_TO(str StringExpression, toEncoding StringExpression) StringExpression { - return NewStringFunc("CONVERT_TO", str, toEncoding) +func CONVERT_TO(str StringExpression, toEncoding StringExpression) BlobExpression { + return BlobExp(Func("CONVERT_TO", str, toEncoding)) } // ENCODE encodes binary data into a textual representation. // Supported formats are: base64, hex, escape. escape converts zero bytes and // high-bit-set bytes to octal sequences (\nnn) and doubles backslashes. -func ENCODE(data StringExpression, format StringExpression) StringExpression { - return NewStringFunc("ENCODE", data, format) +func ENCODE(data BlobExpression, format StringExpression) StringExpression { + return StringExp(Func("ENCODE", data, format)) } // DECODE decodes binary data from textual representation in string. // Options for format are same as in encode. -func DECODE(data StringExpression, format StringExpression) StringExpression { - return NewStringFunc("DECODE", data, format) +func DECODE(data StringExpression, format StringExpression) BlobExpression { + return BlobExp(Func("DECODE", data, format)) } // FORMAT formats a number to a format like "#,###,###.##", rounded to a specified number of decimal places, then it returns the result as a string. @@ -379,11 +391,11 @@ func RIGHT(str StringExpression, n IntegerExpression) StringExpression { } // LENGTH returns number of characters in string with a given encoding -func LENGTH(str StringExpression, encoding ...StringExpression) StringExpression { +func LENGTH(str StringOrBlobExpression, encoding ...StringExpression) IntegerExpression { if len(encoding) > 0 { - return NewStringFunc("LENGTH", str, encoding[0]) + return IntExp(Func("LENGTH", str, encoding[0])) } - return NewStringFunc("LENGTH", str) + return IntExp(Func("LENGTH", str)) } // LPAD fills up the string to length length by prepending the characters @@ -407,8 +419,13 @@ func RPAD(str StringExpression, length IntegerExpression, text ...StringExpressi return NewStringFunc("RPAD", str, length) } +// BIT_COUNT returns the number of bits set in the binary string (also known as “popcount”). +func BIT_COUNT(bytes BlobExpression) IntegerExpression { + return IntExp(Func("BIT_COUNT", bytes)) +} + // MD5 calculates the MD5 hash of string, returning the result in hexadecimal -func MD5(stringExpression StringExpression) StringExpression { +func MD5(stringExpression StringOrBlobExpression) StringExpression { return NewStringFunc("MD5", stringExpression) } @@ -434,7 +451,7 @@ func STRPOS(str, substring StringExpression) IntegerExpression { } // SUBSTR extracts substring -func SUBSTR(str StringExpression, from IntegerExpression, count ...IntegerExpression) StringExpression { +func SUBSTR(str StringOrBlobExpression, from IntegerExpression, count ...IntegerExpression) StringExpression { if len(count) > 0 { return NewStringFunc("SUBSTR", str, from, count[0]) } diff --git a/internal/jet/literal_expression.go b/internal/jet/literal_expression.go index 251d3ab..ff0faa1 100644 --- a/internal/jet/literal_expression.go +++ b/internal/jet/literal_expression.go @@ -468,6 +468,11 @@ func RawDate(raw string, namedArgs ...map[string]interface{}) DateExpression { return DateExp(Raw(raw, namedArgs...)) } +// RawBlob is raw query helper that for blob expressions +func RawBlob(raw string, namedArgs ...map[string]interface{}) BlobExpression { + return BlobExp(Raw(raw, namedArgs...)) +} + // RawRange helper that for range expressions func RawRange[T Expression](raw string, namedArgs ...map[string]interface{}) Range[T] { return RangeExp[T](Raw(raw, namedArgs...)) diff --git a/internal/jet/sql_builder.go b/internal/jet/sql_builder.go index 87a5814..9465ec6 100644 --- a/internal/jet/sql_builder.go +++ b/internal/jet/sql_builder.go @@ -127,7 +127,7 @@ func (s *SQLBuilder) finalize() (string, []interface{}) { } func (s *SQLBuilder) insertConstantArgument(arg interface{}) { - s.WriteString(argToString(arg)) + s.WriteString(s.argToString(arg)) } func (s *SQLBuilder) insertParametrizedArgument(arg interface{}) { @@ -200,7 +200,7 @@ func (s *SQLBuilder) insertRawQuery(raw string, namedArg map[string]interface{}) } if s.Debug { - placeholder = argToString(namedArgumentPos.Value) + placeholder = s.argToString(namedArgumentPos.Value) } raw = strings.Replace(raw, namedArgumentPos.Name, placeholder, toReplace) @@ -209,11 +209,17 @@ func (s *SQLBuilder) insertRawQuery(raw string, namedArg map[string]interface{}) s.WriteString(raw) } -func argToString(value interface{}) string { +func (s *SQLBuilder) argToString(value interface{}) string { if is.Nil(value) { return "NULL" } + strVal, ok := s.Dialect.ArgumentToString(value) + + if ok { + return strVal + } + switch bindVal := value.(type) { case bool: if bindVal { @@ -250,7 +256,7 @@ func argToString(value interface{}) string { return err.Error() } - return argToString(val) + return s.argToString(val) } panic(fmt.Sprintf("jet: %s type can not be used as SQL query parameter", reflect.TypeOf(value).String())) diff --git a/internal/jet/string_expression.go b/internal/jet/string_expression.go index 29b2447..20e9e39 100644 --- a/internal/jet/string_expression.go +++ b/internal/jet/string_expression.go @@ -3,6 +3,7 @@ package jet // StringExpression interface type StringExpression interface { Expression + isStringOrBlob() EQ(rhs StringExpression) BoolExpression NOT_EQ(rhs StringExpression) BoolExpression @@ -29,6 +30,8 @@ type stringInterfaceImpl struct { parent StringExpression } +func (s *stringInterfaceImpl) isStringOrBlob() {} + func (s *stringInterfaceImpl) EQ(rhs StringExpression) BoolExpression { return Eq(s.parent, rhs) } diff --git a/internal/jet/string_or_blob_expression.go b/internal/jet/string_or_blob_expression.go new file mode 100644 index 0000000..f05a149 --- /dev/null +++ b/internal/jet/string_or_blob_expression.go @@ -0,0 +1,8 @@ +package jet + +// StringOrBlobExpression is common interface for all string and blob expressions +type StringOrBlobExpression interface { + Expression + + isStringOrBlob() +} diff --git a/internal/jet/utils.go b/internal/jet/utils.go index 5278605..1a62807 100644 --- a/internal/jet/utils.go +++ b/internal/jet/utils.go @@ -132,8 +132,8 @@ func ExpressionListToSerializerList(expressions []Expression) []Serializer { return ret } -// BoolExpressionListToExpressionList converts list of bool expressions to list of expressions -func BoolExpressionListToExpressionList(expressions []BoolExpression) []Expression { +// ToExpressionList converts list of any expressions to list of expressions +func ToExpressionList[T Expression](expressions []T) []Expression { var ret []Expression for _, expression := range expressions { diff --git a/internal/utils/is/is.go b/internal/utils/is/is.go index 8824b06..1f5ee54 100644 --- a/internal/utils/is/is.go +++ b/internal/utils/is/is.go @@ -1,6 +1,8 @@ package is -import "reflect" +import ( + "reflect" +) // Nil check if v is nil func Nil(v interface{}) bool { diff --git a/mysql/cast.go b/mysql/cast.go index dcf1f57..fbce06c 100644 --- a/mysql/cast.go +++ b/mysql/cast.go @@ -70,6 +70,6 @@ func (c *cast) AS_TIME() TimeExpression { } // AS_BINARY casts expression as BINARY type -func (c *cast) AS_BINARY() StringExpression { - return StringExp(c.AS("BINARY")) +func (c *cast) AS_BINARY() BlobExpression { + return BlobExp(c.AS("BINARY")) } diff --git a/mysql/columns.go b/mysql/columns.go index 3f08396..c0df1aa 100644 --- a/mysql/columns.go +++ b/mysql/columns.go @@ -21,6 +21,12 @@ type ColumnString = jet.ColumnString // StringColumn creates named string column. var StringColumn = jet.StringColumn +// ColumnBlob is interface for blob columns. +type ColumnBlob = jet.ColumnBlob + +// BlobColumn creates named blob column. +var BlobColumn = jet.BlobColumn + // ColumnInteger is interface for SQL smallint, integer, bigint columns. type ColumnInteger = jet.ColumnInteger diff --git a/mysql/dialect.go b/mysql/dialect.go index 9628bfb..175b2c1 100644 --- a/mysql/dialect.go +++ b/mysql/dialect.go @@ -1,6 +1,7 @@ package mysql import ( + "encoding/hex" "fmt" "github.com/go-jet/jet/v2/internal/jet" ) @@ -27,6 +28,7 @@ func newDialect() jet.Dialect { ArgumentPlaceholder: func(int) string { return "?" }, + ArgumentToString: argumentToString, ReservedWords: reservedWords, SerializeOrderBy: serializeOrderBy, ValuesDefaultColumnName: func(index int) string { @@ -37,6 +39,15 @@ func newDialect() jet.Dialect { return jet.NewDialect(mySQLDialectParams) } +func argumentToString(value any) (string, bool) { + switch bindVal := value.(type) { + case []byte: + return fmt.Sprintf("X'%s'", hex.EncodeToString(bindVal)), true + } + + return "", false +} + func mysqlBitXor(expressions ...jet.Serializer) jet.SerializerFunc { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { if len(expressions) < 2 { diff --git a/mysql/expressions.go b/mysql/expressions.go index 4073ef5..ac2be48 100644 --- a/mysql/expressions.go +++ b/mysql/expressions.go @@ -12,6 +12,9 @@ type BoolExpression = jet.BoolExpression // StringExpression interface type StringExpression = jet.StringExpression +// BlobExpression interface +type BlobExpression = jet.BlobExpression + // IntegerExpression interface type IntegerExpression = jet.IntegerExpression @@ -43,6 +46,11 @@ var BoolExp = jet.BoolExp // Does not add sql cast to generated sql builder output. var StringExp = jet.StringExp +// BlobExp is blob expression wrapper around arbitrary expression. +// Allows go compiler to see any expression as blob expression. +// Does not add sql cast to generated sql builder output. +var BlobExp = jet.BlobExp + // IntExp is int expression wrapper around arbitrary expression. // Allows go compiler to see any expression as int expression. // Does not add sql cast to generated sql builder output. @@ -100,6 +108,7 @@ var ( RawTime = jet.RawTime RawTimestamp = jet.RawTimestamp RawDate = jet.RawDate + RawBlob = jet.RawBlob ) // Func can be used to call custom or unsupported database functions. diff --git a/mysql/functions.go b/mysql/functions.go index 3eb16da..56aa517 100644 --- a/mysql/functions.go +++ b/mysql/functions.go @@ -149,9 +149,12 @@ var NTH_VALUE = jet.NTH_VALUE //--------------------- String functions ------------------// // HEX function in MySQL takes an input and returns its equivalent hexadecimal representation -func HEX(expression Expression) StringExpression { - return StringExp(Func("HEX", expression)) -} +var HEX = jet.HEX + +// UNHEX for a string argument str, UNHEX(str) interprets each pair of characters in the argument +// as a hexadecimal number and converts it to the byte represented by the number. +// The return value is a binary string. +var UNHEX = jet.UNHEX // BIT_LENGTH returns number of bits in string expression var BIT_LENGTH = jet.BIT_LENGTH @@ -162,6 +165,23 @@ var CHAR_LENGTH = jet.CHAR_LENGTH // OCTET_LENGTH returns number of bytes in string expression var OCTET_LENGTH = jet.OCTET_LENGTH +// ELT returns the Nth element of the list of strings: str1 if N = 1, str2 if N = 2, and so on. +// Returns NULL if N is less than 1, greater than the number of arguments, or NULL. +func ELT(n IntegerExpression, list ...StringExpression) StringExpression { + args := []Expression{n} + args = append(args, jet.ToExpressionList(list)...) + + return StringExp(Func("ELT", args...)) +} + +// FIELD returns the index (position) of str in the str1, str2, str3, ... list. Returns 0 if str is not found. +func FIELD(str StringExpression, list ...StringExpression) StringExpression { + args := []Expression{str} + args = append(args, jet.ToExpressionList(list)...) + + return StringExp(Func("FIELD", args...)) +} + // LOWER returns string expression in lower case var LOWER = jet.LOWER @@ -183,7 +203,35 @@ var CONCAT = jet.CONCAT var CONCAT_WS = jet.CONCAT_WS // FORMAT formats a number to a format like "#,###,###.##", rounded to a specified number of decimal places, then it returns the result as a string. -var FORMAT = jet.FORMAT +func FORMAT(number jet.NumericExpression, decimals IntegerExpression, optionalLocale ...StringExpression) StringExpression { + if len(optionalLocale) > 0 { + return StringExp(Func("FORMAT", number, decimals, optionalLocale[0])) + } + + return StringExp(Func("FORMAT", number, decimals)) +} + +// TO_BASE64 converts the string argument to base-64 encoded form and returns the +// result as a character string with the connection character set and collation. +func TO_BASE64(data jet.StringOrBlobExpression) StringExpression { + return StringExp(Func("TO_BASE64", data)) +} + +// FROM_BASE64 takes a string encoded with the base-64 encoded rules used by TO_BASE64() +// and returns the decoded result as a binary string. +func FROM_BASE64(data StringExpression) BlobExpression { + return BlobExp(Func("FROM_BASE64", data)) +} + +// CHARSET returns the character set of the string argument, or NULL if the argument is NULL. +func CHARSET(exp Expression) StringExpression { + return StringExp(Func("CHARSET", exp)) +} + +// COLLATION returns the collation of the string argument. +func COLLATION(exp Expression) StringExpression { + return StringExp(Func("COLLATION ", exp)) +} // LEFT returns first n characters in the string. // When n is negative, return all but last |n| characters. @@ -194,7 +242,7 @@ var LEFT = jet.LEFT var RIGHT = jet.RIGHT // LENGTH returns number of characters in string with a given encoding -func LENGTH(str jet.StringExpression) jet.StringExpression { +func LENGTH(str jet.StringOrBlobExpression) jet.IntegerExpression { return jet.LENGTH(str) } diff --git a/mysql/literal.go b/mysql/literal.go index ca720c8..d23fd7a 100644 --- a/mysql/literal.go +++ b/mysql/literal.go @@ -56,6 +56,11 @@ var String = jet.String // value can be any uuid type with a String method var UUID = jet.UUID +// Blob creates new blob literal expression +func Blob(data []byte) BlobExpression { + return BlobExp(jet.Literal(data)) +} + // Date creates new date literal func Date(year int, month time.Month, day int) DateExpression { return CAST(jet.Date(year, month, day)).AS_DATE() diff --git a/postgres/cast.go b/postgres/cast.go index 935cb08..e4d09ce 100644 --- a/postgres/cast.go +++ b/postgres/cast.go @@ -101,9 +101,9 @@ func (b *cast) AS_DECIMAL() FloatExpression { return FloatExp(b.AS("decimal")) } -// AS_BYTEA casts expression AS text type -func (b *cast) AS_BYTEA() StringExpression { - return StringExp(b.AS("bytea")) +// AS_BYTEA casts expression AS bytea type +func (b *cast) AS_BYTEA() ByteaExpression { + return ByteaExp(b.AS("bytea")) } // AS_TIME casts expression AS date type diff --git a/postgres/columns.go b/postgres/columns.go index a70c234..390c23d 100644 --- a/postgres/columns.go +++ b/postgres/columns.go @@ -23,6 +23,12 @@ type ColumnString = jet.ColumnString // StringColumn creates named string column. var StringColumn = jet.StringColumn +// ColumnBytea is interface for bytea columns +type ColumnBytea = jet.ColumnBlob + +// ByteaColumn creates new named bytea column. +var ByteaColumn = jet.BlobColumn + // ColumnInteger is interface for SQL smallint, integer, bigint columns. type ColumnInteger = jet.ColumnInteger diff --git a/postgres/dialect.go b/postgres/dialect.go index 7885cf7..9e6160f 100644 --- a/postgres/dialect.go +++ b/postgres/dialect.go @@ -1,6 +1,7 @@ package postgres import ( + "encoding/hex" "fmt" "strconv" @@ -26,7 +27,8 @@ func newDialect() jet.Dialect { ArgumentPlaceholder: func(ord int) string { return "$" + strconv.Itoa(ord) }, - ReservedWords: reservedWords, + ArgumentToString: argumentToString, + ReservedWords: reservedWords, ValuesDefaultColumnName: func(index int) string { return fmt.Sprintf("column%d", index+1) }, @@ -35,6 +37,15 @@ func newDialect() jet.Dialect { return jet.NewDialect(dialectParams) } +func argumentToString(value any) (string, bool) { + switch bindVal := value.(type) { + case []byte: + return fmt.Sprintf("'\\x%s'", hex.EncodeToString(bindVal)), true + } + + return "", false +} + func postgresCAST(expressions ...jet.Serializer) jet.SerializerFunc { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { if len(expressions) < 2 { diff --git a/postgres/expressions.go b/postgres/expressions.go index d8ad34b..510b4bf 100644 --- a/postgres/expressions.go +++ b/postgres/expressions.go @@ -12,6 +12,8 @@ type BoolExpression = jet.BoolExpression // StringExpression interface type StringExpression = jet.StringExpression +type ByteaExpression = jet.BlobExpression + // NumericExpression interface type NumericExpression = jet.NumericExpression @@ -82,6 +84,11 @@ var TimeExp = jet.TimeExp // Does not add sql cast to generated sql builder output. var StringExp = jet.StringExp +// ByteaExp is blob expression wrapper around arbitrary expression. +// Allows go compiler to see any expression as string expression. +// Does not add sql cast to generated sql builder output. +var ByteaExp = jet.BlobExp + // TimezExp is time with time zone expression wrapper around arbitrary expression. // Allows go compiler to see any expression as time with time zone expression. // Does not add sql cast to generated sql builder output. @@ -134,15 +141,17 @@ type RawArgs = map[string]interface{} var ( Raw = jet.Raw - RawBool = jet.RawBool - RawInt = jet.RawInt - RawFloat = jet.RawFloat - RawString = jet.RawString - RawTime = jet.RawTime - RawTimez = jet.RawTimez - RawTimestamp = jet.RawTimestamp - RawTimestampz = jet.RawTimestampz - RawDate = jet.RawDate + RawBool = jet.RawBool + RawInt = jet.RawInt + RawFloat = jet.RawFloat + RawString = jet.RawString + RawTime = jet.RawTime + RawTimez = jet.RawTimez + RawTimestamp = jet.RawTimestamp + RawTimestampz = jet.RawTimestampz + RawDate = jet.RawDate + RawBytea = jet.RawBlob + RawNumRange = jet.RawRange[jet.NumericExpression] RawInt4Range = jet.RawRange[jet.Int4Expression] RawInt8Range = jet.RawRange[jet.Int8Expression] diff --git a/postgres/functions.go b/postgres/functions.go index bce2e98..cf1f11e 100644 --- a/postgres/functions.go +++ b/postgres/functions.go @@ -192,9 +192,27 @@ func CONCAT_WS(separator Expression, expressions ...Expression) StringExpression return jet.CONCAT_WS(explicitLiteralCast(separator), explicitLiteralCasts(expressions...)...) } +// Character encodings for CONVERT, CONVERT_FROM and CONVERT_TO functions +var ( + UTF8 = String("UTF8") + LATIN1 = String("LATIN1") + LATIN2 = String("LATIN2") + LATIN3 = String("LATIN3") + LATIN4 = String("LATIN4") + WIN1252 = String("WIN1252") + ISO_8859_5 = String("ISO_8859_5") + ISO_8859_6 = String("ISO_8859_6") + ISO_8859_7 = String("ISO_8859_7") + ISO_8859_8 = String("ISO_8859_8") + KOI8R = String("KOI8R") + KOI8U = String("KOI8U") +) + // CONVERT converts string to dest_encoding. The original encoding is // specified by src_encoding. The string must be valid in this encoding. -var CONVERT = jet.CONVERT +func CONVERT(str ByteaExpression, srcEncoding StringExpression, destEncoding StringExpression) ByteaExpression { + return jet.CONVERT(str, srcEncoding, destEncoding) +} // CONVERT_FROM converts string to the database encoding. The original // encoding is specified by src_encoding. The string must be valid in this encoding. @@ -203,6 +221,13 @@ var CONVERT_FROM = jet.CONVERT_FROM // CONVERT_TO converts string to dest_encoding. var CONVERT_TO = jet.CONVERT_TO +// ENCODE/DECODE textual formats +var ( + Base64 StringExpression = String("base64") + Escape StringExpression = String("escape") + Hex StringExpression = String("hex") +) + // ENCODE encodes binary data into a textual representation. // Supported formats are: base64, hex, escape. escape converts zero bytes and // high-bit-set bytes to octal sequences (\nnn) and doubles backslashes. @@ -212,7 +237,7 @@ var ENCODE = jet.ENCODE // Options for format are same as in encode. var DECODE = jet.DECODE -// FORMAT formats a number to a format like "#,###,###.##", rounded to a specified number of decimal places, then it returns the result as a string. +// FORMAT formats the arguments according to a format string. This function is similar to the C function sprintf. func FORMAT(formatStr StringExpression, formatArgs ...Expression) StringExpression { return jet.FORMAT(formatStr, explicitLiteralCasts(formatArgs...)...) } @@ -242,6 +267,49 @@ var LPAD = jet.LPAD // fill (a space by default). If the string is already longer than length then it is truncated. var RPAD = jet.RPAD +// BIT_COUNT returns the number of bits set in the binary string (also known as “popcount”). +var BIT_COUNT = jet.BIT_COUNT + +// GET_BIT extracts n'th bit from binary string. +func GET_BIT(bytes ByteaExpression, n IntegerExpression) IntegerExpression { + return IntExp(Func("GET_BIT", bytes, n)) +} + +// GET_BYTE extracts n'th byte from binary string. +func GET_BYTE(bytes ByteaExpression, n IntegerExpression) IntegerExpression { + return IntExp(Func("GET_BYTE", bytes, n)) +} + +// SET_BIT sets n'th bit in binary string to newvalue. +func SET_BIT(bytes ByteaExpression, n IntegerExpression, newValue IntegerExpression) ByteaExpression { + return ByteaExp(Func("SET_BIT", bytes, n, newValue)) +} + +// SET_BYTE sets n'th byte in binary string to newvalue. +func SET_BYTE(bytes ByteaExpression, n IntegerExpression, newValue IntegerExpression) ByteaExpression { + return ByteaExp(Func("SET_BYTE", bytes, n, newValue)) +} + +// SHA224 computes the SHA-224 hash of the binary string. +func SHA224(bytes ByteaExpression) ByteaExpression { + return ByteaExp(Func("SHA224", bytes)) +} + +// SHA256 computes the SHA-256 hash of the binary string. +func SHA256(bytes ByteaExpression) ByteaExpression { + return ByteaExp(Func("SHA256", bytes)) +} + +// SHA384 computes the SHA-384 hash of the binary string. +func SHA384(bytes ByteaExpression) ByteaExpression { + return ByteaExp(Func("SHA384", bytes)) +} + +// SHA512 computes the SHA-512 hash of the binary string. +func SHA512(bytes ByteaExpression) ByteaExpression { + return ByteaExp(Func("SHA512", bytes)) +} + // MD5 calculates the MD5 hash of string, returning the result in hexadecimal var MD5 = jet.MD5 diff --git a/postgres/literal.go b/postgres/literal.go index 4f1c2c8..e785b48 100644 --- a/postgres/literal.go +++ b/postgres/literal.go @@ -122,7 +122,7 @@ func Json(value interface{}) StringExpression { var UUID = jet.UUID // Bytea creates new bytea literal expression -func Bytea(value interface{}) StringExpression { +func Bytea(value interface{}) ByteaExpression { switch value.(type) { case string, []byte: default: diff --git a/sqlite/cast.go b/sqlite/cast.go index 0f68f43..fb74820 100644 --- a/sqlite/cast.go +++ b/sqlite/cast.go @@ -42,6 +42,6 @@ func (c *cast) AS_REAL() FloatExpression { } // AS_BLOB cast expression to BLOB type -func (c *cast) AS_BLOB() StringExpression { - return StringExp(c.AS("BLOB")) +func (c *cast) AS_BLOB() BlobExpression { + return BlobExp(c.AS("BLOB")) } diff --git a/sqlite/columns.go b/sqlite/columns.go index 2941b8d..44b6145 100644 --- a/sqlite/columns.go +++ b/sqlite/columns.go @@ -21,6 +21,12 @@ type ColumnString = jet.ColumnString // StringColumn creates named string column. var StringColumn = jet.StringColumn +// ColumnBlob is interface for +type ColumnBlob = jet.ColumnBlob + +// BlobColumn creates new named blob column +var BlobColumn = jet.BlobColumn + // ColumnInteger is interface for SQL smallint, integer, bigint columns. type ColumnInteger = jet.ColumnInteger diff --git a/sqlite/dialect.go b/sqlite/dialect.go index da03364..3585651 100644 --- a/sqlite/dialect.go +++ b/sqlite/dialect.go @@ -1,6 +1,7 @@ package sqlite import ( + "encoding/hex" "fmt" "github.com/go-jet/jet/v2/internal/jet" ) @@ -23,7 +24,8 @@ func newDialect() jet.Dialect { ArgumentPlaceholder: func(int) string { return "?" }, - ReservedWords: reservedWords2, + ArgumentToString: argumentToString, + ReservedWords: reservedWords2, ValuesDefaultColumnName: func(index int) string { return fmt.Sprintf("column%d", index+1) }, @@ -32,6 +34,15 @@ func newDialect() jet.Dialect { return jet.NewDialect(mySQLDialectParams) } +func argumentToString(value any) (string, bool) { + switch bindVal := value.(type) { + case []byte: + return fmt.Sprintf("X'%s'", hex.EncodeToString(bindVal)), true + } + + return "", false +} + func sqliteBitXOR(expressions ...jet.Serializer) jet.SerializerFunc { return func(statement jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { if len(expressions) < 2 { diff --git a/sqlite/expressions.go b/sqlite/expressions.go index 0b2d320..3de5cd2 100644 --- a/sqlite/expressions.go +++ b/sqlite/expressions.go @@ -12,6 +12,9 @@ type BoolExpression = jet.BoolExpression // StringExpression interface type StringExpression = jet.StringExpression +// BlobExpression interface +type BlobExpression = jet.BlobExpression + // NumericExpression is shared interface for integer or real expression type NumericExpression = jet.NumericExpression @@ -46,6 +49,11 @@ var BoolExp = jet.BoolExp // Does not add sql cast to generated sql builder output. var StringExp = jet.StringExp +// BlobExp is blob expression wrapper around arbitrary expression. +// Allows go compiler to see any expression as blob expression. +// Does not add sql cast to generated sql builder output. +var BlobExp = jet.BlobExp + // IntExp is int expression wrapper around arbitrary expression. // Allows go compiler to see any expression as int expression. // Does not add sql cast to generated sql builder output. diff --git a/sqlite/functions.go b/sqlite/functions.go index ac6bd08..92139b4 100644 --- a/sqlite/functions.go +++ b/sqlite/functions.go @@ -196,11 +196,22 @@ var RTRIM = jet.RTRIM // return jet.NewStringFunc("RIGHTSTR", str, n) //} +// HEX function takes an input and returns its equivalent hexadecimal representation +var HEX = jet.HEX + +// UNHEX for a string argument str, UNHEX(str) interprets each pair of characters in the argument +// as a hexadecimal number and converts it to the byte represented by the number. +// The return value is a binary string. +var UNHEX = jet.UNHEX + // LENGTH returns number of characters in string with a given encoding -func LENGTH(str jet.StringExpression) jet.StringExpression { +func LENGTH(str jet.StringOrBlobExpression) jet.IntegerExpression { return jet.LENGTH(str) } +// OCTET_LENGTH returns number of bytes in string expression +var OCTET_LENGTH = jet.OCTET_LENGTH + // LPAD fills up the string to length length by prepending the characters // fill (a space by default). If the string is already longer than length // then it is truncated (on the right). diff --git a/sqlite/literal.go b/sqlite/literal.go index 2df5dd7..2069711 100644 --- a/sqlite/literal.go +++ b/sqlite/literal.go @@ -50,6 +50,10 @@ var Decimal = jet.Decimal // String creates new string literal expression var String = jet.String +func Blob(data []byte) BlobExpression { + return BlobExp(jet.Literal(data)) +} + // UUID is a helper function to create string literal expression from uuid object // value can be any uuid type with a String method var UUID = jet.UUID diff --git a/tests/mysql/alltypes_test.go b/tests/mysql/alltypes_test.go index 2440339..7e7397c 100644 --- a/tests/mysql/alltypes_test.go +++ b/tests/mysql/alltypes_test.go @@ -511,7 +511,8 @@ func TestStringOperators(t *testing.T) { RTRIM(AllTypes.VarCharPtr), CONCAT(String("string1"), Int(1), Float(11.12)), CONCAT_WS(String("string1"), Int(1), Float(11.12)), - FORMAT(String("Hello %s, %1$s"), String("World")), + FORMAT(Int(11), Int(2)), + FORMAT(Int(11), Int(2), String("de_DE")), LEFT(String("abcde"), Int(2)), RIGHT(String("abcde"), Int(2)), LENGTH(String("jose")), @@ -523,6 +524,12 @@ func TestStringOperators(t *testing.T) { REVERSE(AllTypes.VarCharPtr), SUBSTR(AllTypes.CharPtr, Int(3)), SUBSTR(AllTypes.CharPtr, Int(3), Int(2)), + ELT(Int(2), AllTypes.CharPtr, AllTypes.Char, AllTypes.Text), + FIELD(AllTypes.Char, AllTypes.VarChar, AllTypes.Text), + FROM_BASE64(String("SGVsbG8gV29ybGQ=")), + TO_BASE64(String("Hello World")), + CHARSET(AllTypes.Char), + COLLATION(AllTypes.Text), } if !sourceIsMariaDB() { @@ -544,6 +551,71 @@ func TestStringOperators(t *testing.T) { require.NoError(t, err) } +func TestBlob(t *testing.T) { + + var sampleBlob = Blob([]byte{11, 0, 22, 33, 44}) + var textBlob = Blob([]byte("text blob")) + + stmt := SELECT( + AllTypes.BlobPtr.EQ(sampleBlob), + AllTypes.BlobPtr.EQ(AllTypes.BlobPtr), + AllTypes.BlobPtr.NOT_EQ(sampleBlob), + AllTypes.BlobPtr.GT(textBlob), + AllTypes.BlobPtr.GT_EQ(AllTypes.BlobPtr), + AllTypes.BlobPtr.LT(AllTypes.BlobPtr), + AllTypes.BlobPtr.LT_EQ(sampleBlob), + AllTypes.BlobPtr.BETWEEN(Blob([]byte("min")), Blob([]byte("max"))), + AllTypes.BlobPtr.NOT_BETWEEN(AllTypes.BlobPtr, AllTypes.BlobPtr), + AllTypes.BlobPtr.CONCAT(textBlob), + AllTypes.BlobPtr.LIKE(AllTypes.BlobPtr), + AllTypes.BlobPtr.NOT_LIKE(sampleBlob), + + BIT_LENGTH(textBlob), + LENGTH(sampleBlob), + CHAR_LENGTH(AllTypes.BlobPtr), + OCTET_LENGTH(textBlob), + CONCAT(sampleBlob, Int(1), Float(11.12)), + TO_BASE64(sampleBlob), + HEX(sampleBlob), + UNHEX(String("616B263A")), + SUBSTR(AllTypes.BlobPtr, Int(3)), + SUBSTR(AllTypes.BlobPtr, Int(3), Int(2)), + ).FROM( + AllTypes, + ) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT all_types.blob_ptr = X'0b0016212c', + all_types.blob_ptr = all_types.blob_ptr, + all_types.blob_ptr != X'0b0016212c', + all_types.blob_ptr > X'7465787420626c6f62', + all_types.blob_ptr >= all_types.blob_ptr, + all_types.blob_ptr < all_types.blob_ptr, + all_types.blob_ptr <= X'0b0016212c', + all_types.blob_ptr BETWEEN X'6d696e' AND X'6d6178', + all_types.blob_ptr NOT BETWEEN all_types.blob_ptr AND all_types.blob_ptr, + CONCAT(all_types.blob_ptr, X'7465787420626c6f62'), + all_types.blob_ptr LIKE all_types.blob_ptr, + all_types.blob_ptr NOT LIKE X'0b0016212c', + BIT_LENGTH(X'7465787420626c6f62'), + LENGTH(X'0b0016212c'), + CHAR_LENGTH(all_types.blob_ptr), + OCTET_LENGTH(X'7465787420626c6f62'), + CONCAT(X'0b0016212c', 1, 11.12), + TO_BASE64(X'0b0016212c'), + HEX(X'0b0016212c'), + UNHEX('616B263A'), + SUBSTR(all_types.blob_ptr, 3), + SUBSTR(all_types.blob_ptr, 3, 2) +FROM test_sample.all_types; +`) + + var dest []struct{} + err := stmt.Query(db, &dest) + + require.NoError(t, err) +} + var timeT = time.Date(2009, 11, 17, 20, 34, 58, 651387237, time.UTC) func TestTimeExpressions(t *testing.T) { diff --git a/tests/mysql/generator_test.go b/tests/mysql/generator_test.go index 800de18..e356956 100644 --- a/tests/mysql/generator_test.go +++ b/tests/mysql/generator_test.go @@ -3,6 +3,7 @@ package mysql import ( "os" "os/exec" + "path/filepath" "strconv" "testing" @@ -606,3 +607,394 @@ func UseSchema(schema string) { StaffList = StaffList.FromSchema(schema) } ` + +func TestGeneratedTestSampleDatabase(t *testing.T) { + + enumDir := filepath.Join(testRoot, "/.gentestdata/mysql/test_sample/enum/") + modelDir := filepath.Join(testRoot, "/.gentestdata/mysql/test_sample/model/") + tableDir := filepath.Join(testRoot, "/.gentestdata/mysql/test_sample/table/") + viewDir := filepath.Join(testRoot, "/.gentestdata/mysql/test_sample/view/") + + testutils.AssertFileNamesEqual(t, enumDir, "all_types_enum.go", "all_types_enum_ptr.go", + "all_types_view_enum.go", "all_types_view_enum_ptr.go") + testutils.AssertFileContent(t, enumDir+"/all_types_enum.go", allTypesEnum) + + testutils.AssertFileNamesEqual(t, modelDir, "all_types.go", "all_types_enum.go", "all_types_enum_ptr.go", + "all_types_view.go", "all_types_view_enum.go", "all_types_view_enum_ptr.go", "link.go", "link2.go", + "floats.go", "user.go") + + testutils.AssertFileContent(t, modelDir+"/all_types.go", allTypesModelContent) + + testutils.AssertFileNamesEqual(t, tableDir, "all_types.go", + "link.go", "link2.go", "user.go", "floats.go", "table_use_schema.go") + testutils.AssertFileContent(t, tableDir+"/all_types.go", allTypesTableContent) + + testutils.AssertFileNamesEqual(t, viewDir, "all_types_view.go", "view_use_schema.go") +} + +var allTypesEnum = ` +// +// Code generated by go-jet DO NOT EDIT. +// +// WARNING: Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated +// + +package enum + +import "github.com/go-jet/jet/v2/mysql" + +var AllTypesEnum = &struct { + Value1 mysql.StringExpression + Value2 mysql.StringExpression + Value3 mysql.StringExpression +}{ + Value1: mysql.NewEnumValue("value1"), + Value2: mysql.NewEnumValue("value2"), + Value3: mysql.NewEnumValue("value3"), +} +` + +var allTypesModelContent = ` +// +// Code generated by go-jet DO NOT EDIT. +// +// WARNING: Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated +// + +package model + +import ( + "time" +) + +type AllTypes struct { + Boolean bool + BooleanPtr *bool + TinyInt int8 + UTinyInt uint8 + SmallInt int16 + USmallInt uint16 + MediumInt int32 + UMediumInt uint32 + Integer int32 + UInteger uint32 + BigInt int64 + UBigInt uint64 + TinyIntPtr *int8 + UTinyIntPtr *uint8 + SmallIntPtr *int16 + USmallIntPtr *uint16 + MediumIntPtr *int32 + UMediumIntPtr *uint32 + IntegerPtr *int32 + UIntegerPtr *uint32 + BigIntPtr *int64 + UBigIntPtr *uint64 + Decimal float64 + DecimalPtr *float64 + Numeric float64 + NumericPtr *float64 + Float float64 + FloatPtr *float64 + Double float64 + DoublePtr *float64 + Real float64 + RealPtr *float64 + Bit string + BitPtr *string + Time time.Time + TimePtr *time.Time + Date time.Time + DatePtr *time.Time + DateTime time.Time + DateTimePtr *time.Time + Timestamp time.Time + TimestampPtr *time.Time + Year int16 + YearPtr *int16 + Char string + CharPtr *string + VarChar string + VarCharPtr *string + Binary []byte + BinaryPtr *[]byte + VarBinary []byte + VarBinaryPtr *[]byte + Blob []byte + BlobPtr *[]byte + Text string + TextPtr *string + Enum AllTypesEnum + EnumPtr *AllTypesEnumPtr + Set string + SetPtr *string + JSON string + JSONPtr *string +} +` + +var allTypesTableContent = ` +// +// Code generated by go-jet DO NOT EDIT. +// +// WARNING: Changes to this file may cause incorrect behavior +// and will be lost if the code is regenerated +// + +package table + +import ( + "github.com/go-jet/jet/v2/mysql" +) + +var AllTypes = newAllTypesTable("test_sample", "all_types", "") + +type allTypesTable struct { + mysql.Table + + // Columns + Boolean mysql.ColumnBool + BooleanPtr mysql.ColumnBool + TinyInt mysql.ColumnInteger + UTinyInt mysql.ColumnInteger + SmallInt mysql.ColumnInteger + USmallInt mysql.ColumnInteger + MediumInt mysql.ColumnInteger + UMediumInt mysql.ColumnInteger + Integer mysql.ColumnInteger + UInteger mysql.ColumnInteger + BigInt mysql.ColumnInteger + UBigInt mysql.ColumnInteger + TinyIntPtr mysql.ColumnInteger + UTinyIntPtr mysql.ColumnInteger + SmallIntPtr mysql.ColumnInteger + USmallIntPtr mysql.ColumnInteger + MediumIntPtr mysql.ColumnInteger + UMediumIntPtr mysql.ColumnInteger + IntegerPtr mysql.ColumnInteger + UIntegerPtr mysql.ColumnInteger + BigIntPtr mysql.ColumnInteger + UBigIntPtr mysql.ColumnInteger + Decimal mysql.ColumnFloat + DecimalPtr mysql.ColumnFloat + Numeric mysql.ColumnFloat + NumericPtr mysql.ColumnFloat + Float mysql.ColumnFloat + FloatPtr mysql.ColumnFloat + Double mysql.ColumnFloat + DoublePtr mysql.ColumnFloat + Real mysql.ColumnFloat + RealPtr mysql.ColumnFloat + Bit mysql.ColumnString + BitPtr mysql.ColumnString + Time mysql.ColumnTime + TimePtr mysql.ColumnTime + Date mysql.ColumnDate + DatePtr mysql.ColumnDate + DateTime mysql.ColumnTimestamp + DateTimePtr mysql.ColumnTimestamp + Timestamp mysql.ColumnTimestamp + TimestampPtr mysql.ColumnTimestamp + Year mysql.ColumnInteger + YearPtr mysql.ColumnInteger + Char mysql.ColumnString + CharPtr mysql.ColumnString + VarChar mysql.ColumnString + VarCharPtr mysql.ColumnString + Binary mysql.ColumnBlob + BinaryPtr mysql.ColumnBlob + VarBinary mysql.ColumnBlob + VarBinaryPtr mysql.ColumnBlob + Blob mysql.ColumnBlob + BlobPtr mysql.ColumnBlob + Text mysql.ColumnString + TextPtr mysql.ColumnString + Enum mysql.ColumnString + EnumPtr mysql.ColumnString + Set mysql.ColumnString + SetPtr mysql.ColumnString + JSON mysql.ColumnString + JSONPtr mysql.ColumnString + + AllColumns mysql.ColumnList + MutableColumns mysql.ColumnList + DefaultColumns mysql.ColumnList +} + +type AllTypesTable struct { + allTypesTable + + NEW allTypesTable +} + +// AS creates new AllTypesTable with assigned alias +func (a AllTypesTable) AS(alias string) *AllTypesTable { + return newAllTypesTable(a.SchemaName(), a.TableName(), alias) +} + +// Schema creates new AllTypesTable with assigned schema name +func (a AllTypesTable) FromSchema(schemaName string) *AllTypesTable { + return newAllTypesTable(schemaName, a.TableName(), a.Alias()) +} + +// WithPrefix creates new AllTypesTable with assigned table prefix +func (a AllTypesTable) WithPrefix(prefix string) *AllTypesTable { + return newAllTypesTable(a.SchemaName(), prefix+a.TableName(), a.TableName()) +} + +// WithSuffix creates new AllTypesTable with assigned table suffix +func (a AllTypesTable) WithSuffix(suffix string) *AllTypesTable { + return newAllTypesTable(a.SchemaName(), a.TableName()+suffix, a.TableName()) +} + +func newAllTypesTable(schemaName, tableName, alias string) *AllTypesTable { + return &AllTypesTable{ + allTypesTable: newAllTypesTableImpl(schemaName, tableName, alias), + NEW: newAllTypesTableImpl("", "new", ""), + } +} + +func newAllTypesTableImpl(schemaName, tableName, alias string) allTypesTable { + var ( + BooleanColumn = mysql.BoolColumn("boolean") + BooleanPtrColumn = mysql.BoolColumn("boolean_ptr") + TinyIntColumn = mysql.IntegerColumn("tiny_int") + UTinyIntColumn = mysql.IntegerColumn("u_tiny_int") + SmallIntColumn = mysql.IntegerColumn("small_int") + USmallIntColumn = mysql.IntegerColumn("u_small_int") + MediumIntColumn = mysql.IntegerColumn("medium_int") + UMediumIntColumn = mysql.IntegerColumn("u_medium_int") + IntegerColumn = mysql.IntegerColumn("integer") + UIntegerColumn = mysql.IntegerColumn("u_integer") + BigIntColumn = mysql.IntegerColumn("big_int") + UBigIntColumn = mysql.IntegerColumn("u_big_int") + TinyIntPtrColumn = mysql.IntegerColumn("tiny_int_ptr") + UTinyIntPtrColumn = mysql.IntegerColumn("u_tiny_int_ptr") + SmallIntPtrColumn = mysql.IntegerColumn("small_int_ptr") + USmallIntPtrColumn = mysql.IntegerColumn("u_small_int_ptr") + MediumIntPtrColumn = mysql.IntegerColumn("medium_int_ptr") + UMediumIntPtrColumn = mysql.IntegerColumn("u_medium_int_ptr") + IntegerPtrColumn = mysql.IntegerColumn("integer_ptr") + UIntegerPtrColumn = mysql.IntegerColumn("u_integer_ptr") + BigIntPtrColumn = mysql.IntegerColumn("big_int_ptr") + UBigIntPtrColumn = mysql.IntegerColumn("u_big_int_ptr") + DecimalColumn = mysql.FloatColumn("decimal") + DecimalPtrColumn = mysql.FloatColumn("decimal_ptr") + NumericColumn = mysql.FloatColumn("numeric") + NumericPtrColumn = mysql.FloatColumn("numeric_ptr") + FloatColumn = mysql.FloatColumn("float") + FloatPtrColumn = mysql.FloatColumn("float_ptr") + DoubleColumn = mysql.FloatColumn("double") + DoublePtrColumn = mysql.FloatColumn("double_ptr") + RealColumn = mysql.FloatColumn("real") + RealPtrColumn = mysql.FloatColumn("real_ptr") + BitColumn = mysql.StringColumn("bit") + BitPtrColumn = mysql.StringColumn("bit_ptr") + TimeColumn = mysql.TimeColumn("time") + TimePtrColumn = mysql.TimeColumn("time_ptr") + DateColumn = mysql.DateColumn("date") + DatePtrColumn = mysql.DateColumn("date_ptr") + DateTimeColumn = mysql.TimestampColumn("date_time") + DateTimePtrColumn = mysql.TimestampColumn("date_time_ptr") + TimestampColumn = mysql.TimestampColumn("timestamp") + TimestampPtrColumn = mysql.TimestampColumn("timestamp_ptr") + YearColumn = mysql.IntegerColumn("year") + YearPtrColumn = mysql.IntegerColumn("year_ptr") + CharColumn = mysql.StringColumn("char") + CharPtrColumn = mysql.StringColumn("char_ptr") + VarCharColumn = mysql.StringColumn("var_char") + VarCharPtrColumn = mysql.StringColumn("var_char_ptr") + BinaryColumn = mysql.BlobColumn("binary") + BinaryPtrColumn = mysql.BlobColumn("binary_ptr") + VarBinaryColumn = mysql.BlobColumn("var_binary") + VarBinaryPtrColumn = mysql.BlobColumn("var_binary_ptr") + BlobColumn = mysql.BlobColumn("blob") + BlobPtrColumn = mysql.BlobColumn("blob_ptr") + TextColumn = mysql.StringColumn("text") + TextPtrColumn = mysql.StringColumn("text_ptr") + EnumColumn = mysql.StringColumn("enum") + EnumPtrColumn = mysql.StringColumn("enum_ptr") + SetColumn = mysql.StringColumn("set") + SetPtrColumn = mysql.StringColumn("set_ptr") + JSONColumn = mysql.StringColumn("json") + JSONPtrColumn = mysql.StringColumn("json_ptr") + allColumns = mysql.ColumnList{BooleanColumn, BooleanPtrColumn, TinyIntColumn, UTinyIntColumn, SmallIntColumn, USmallIntColumn, MediumIntColumn, UMediumIntColumn, IntegerColumn, UIntegerColumn, BigIntColumn, UBigIntColumn, TinyIntPtrColumn, UTinyIntPtrColumn, SmallIntPtrColumn, USmallIntPtrColumn, MediumIntPtrColumn, UMediumIntPtrColumn, IntegerPtrColumn, UIntegerPtrColumn, BigIntPtrColumn, UBigIntPtrColumn, DecimalColumn, DecimalPtrColumn, NumericColumn, NumericPtrColumn, FloatColumn, FloatPtrColumn, DoubleColumn, DoublePtrColumn, RealColumn, RealPtrColumn, BitColumn, BitPtrColumn, TimeColumn, TimePtrColumn, DateColumn, DatePtrColumn, DateTimeColumn, DateTimePtrColumn, TimestampColumn, TimestampPtrColumn, YearColumn, YearPtrColumn, CharColumn, CharPtrColumn, VarCharColumn, VarCharPtrColumn, BinaryColumn, BinaryPtrColumn, VarBinaryColumn, VarBinaryPtrColumn, BlobColumn, BlobPtrColumn, TextColumn, TextPtrColumn, EnumColumn, EnumPtrColumn, SetColumn, SetPtrColumn, JSONColumn, JSONPtrColumn} + mutableColumns = mysql.ColumnList{BooleanColumn, BooleanPtrColumn, TinyIntColumn, UTinyIntColumn, SmallIntColumn, USmallIntColumn, MediumIntColumn, UMediumIntColumn, IntegerColumn, UIntegerColumn, BigIntColumn, UBigIntColumn, TinyIntPtrColumn, UTinyIntPtrColumn, SmallIntPtrColumn, USmallIntPtrColumn, MediumIntPtrColumn, UMediumIntPtrColumn, IntegerPtrColumn, UIntegerPtrColumn, BigIntPtrColumn, UBigIntPtrColumn, DecimalColumn, DecimalPtrColumn, NumericColumn, NumericPtrColumn, FloatColumn, FloatPtrColumn, DoubleColumn, DoublePtrColumn, RealColumn, RealPtrColumn, BitColumn, BitPtrColumn, TimeColumn, TimePtrColumn, DateColumn, DatePtrColumn, DateTimeColumn, DateTimePtrColumn, TimestampColumn, TimestampPtrColumn, YearColumn, YearPtrColumn, CharColumn, CharPtrColumn, VarCharColumn, VarCharPtrColumn, BinaryColumn, BinaryPtrColumn, VarBinaryColumn, VarBinaryPtrColumn, BlobColumn, BlobPtrColumn, TextColumn, TextPtrColumn, EnumColumn, EnumPtrColumn, SetColumn, SetPtrColumn, JSONColumn, JSONPtrColumn} + defaultColumns = mysql.ColumnList{BooleanColumn, TinyIntColumn, UTinyIntColumn, SmallIntColumn, USmallIntColumn, MediumIntColumn, UMediumIntColumn, IntegerColumn, UIntegerColumn, BigIntColumn, UBigIntColumn, DecimalColumn, NumericColumn, FloatColumn, DoubleColumn, RealColumn, BitColumn, TimeColumn, DateColumn, DateTimeColumn, TimestampColumn, YearColumn, CharColumn, VarCharColumn, BinaryColumn, VarBinaryColumn, EnumColumn, SetColumn} + ) + + return allTypesTable{ + Table: mysql.NewTable(schemaName, tableName, alias, allColumns...), + + //Columns + Boolean: BooleanColumn, + BooleanPtr: BooleanPtrColumn, + TinyInt: TinyIntColumn, + UTinyInt: UTinyIntColumn, + SmallInt: SmallIntColumn, + USmallInt: USmallIntColumn, + MediumInt: MediumIntColumn, + UMediumInt: UMediumIntColumn, + Integer: IntegerColumn, + UInteger: UIntegerColumn, + BigInt: BigIntColumn, + UBigInt: UBigIntColumn, + TinyIntPtr: TinyIntPtrColumn, + UTinyIntPtr: UTinyIntPtrColumn, + SmallIntPtr: SmallIntPtrColumn, + USmallIntPtr: USmallIntPtrColumn, + MediumIntPtr: MediumIntPtrColumn, + UMediumIntPtr: UMediumIntPtrColumn, + IntegerPtr: IntegerPtrColumn, + UIntegerPtr: UIntegerPtrColumn, + BigIntPtr: BigIntPtrColumn, + UBigIntPtr: UBigIntPtrColumn, + Decimal: DecimalColumn, + DecimalPtr: DecimalPtrColumn, + Numeric: NumericColumn, + NumericPtr: NumericPtrColumn, + Float: FloatColumn, + FloatPtr: FloatPtrColumn, + Double: DoubleColumn, + DoublePtr: DoublePtrColumn, + Real: RealColumn, + RealPtr: RealPtrColumn, + Bit: BitColumn, + BitPtr: BitPtrColumn, + Time: TimeColumn, + TimePtr: TimePtrColumn, + Date: DateColumn, + DatePtr: DatePtrColumn, + DateTime: DateTimeColumn, + DateTimePtr: DateTimePtrColumn, + Timestamp: TimestampColumn, + TimestampPtr: TimestampPtrColumn, + Year: YearColumn, + YearPtr: YearPtrColumn, + Char: CharColumn, + CharPtr: CharPtrColumn, + VarChar: VarCharColumn, + VarCharPtr: VarCharPtrColumn, + Binary: BinaryColumn, + BinaryPtr: BinaryPtrColumn, + VarBinary: VarBinaryColumn, + VarBinaryPtr: VarBinaryPtrColumn, + Blob: BlobColumn, + BlobPtr: BlobPtrColumn, + Text: TextColumn, + TextPtr: TextPtrColumn, + Enum: EnumColumn, + EnumPtr: EnumPtrColumn, + Set: SetColumn, + SetPtr: SetPtrColumn, + JSON: JSONColumn, + JSONPtr: JSONPtrColumn, + + AllColumns: allColumns, + MutableColumns: mutableColumns, + DefaultColumns: defaultColumns, + } +} +` diff --git a/tests/mysql/main_test.go b/tests/mysql/main_test.go index e4f4322..1d6f68d 100644 --- a/tests/mysql/main_test.go +++ b/tests/mysql/main_test.go @@ -8,6 +8,7 @@ import ( jetmysql "github.com/go-jet/jet/v2/mysql" "github.com/go-jet/jet/v2/stmtcache" "github.com/go-jet/jet/v2/tests/dbconfig" + "github.com/go-jet/jet/v2/tests/internal/utils/repo" _ "github.com/go-sql-driver/mysql" "github.com/stretchr/testify/require" "runtime" @@ -21,12 +22,14 @@ var db *stmtcache.DB var source string var withStatementCaching bool +var testRoot string const MariaDB = "MariaDB" func init() { source = os.Getenv("MY_SQL_SOURCE") withStatementCaching = os.Getenv("JET_TESTS_WITH_STMT_CACHE") == "true" + testRoot = repo.GetTestsDirPath() } func sourceIsMariaDB() bool { diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go index a411071..96f15ad 100644 --- a/tests/postgres/alltypes_test.go +++ b/tests/postgres/alltypes_test.go @@ -1,6 +1,7 @@ package postgres import ( + "encoding/base64" "github.com/go-jet/jet/v2/internal/utils/ptr" "github.com/stretchr/testify/assert" @@ -485,6 +486,7 @@ func TestExpressionCast(t *testing.T) { CAST(String("1999-01-08 04:05:06")).AS_TIMESTAMP(), CAST(String("1999-01-08 04:05:06+01:00")).AS_TIMESTAMPZ(), CAST(String("04:05:06")).AS_INTERVAL(), + CAST(String("some text")).AS_BYTEA().EQ(Bytea([]byte("some text"))), func() ProjectionList { if sourceIsCockroachDB() { @@ -538,7 +540,6 @@ func TestStringOperators(t *testing.T) { AllTypes.Text.BETWEEN(String("min"), String("max")), AllTypes.Text.NOT_BETWEEN(AllTypes.VarChar, AllTypes.CharPtr), AllTypes.Text.CONCAT(String("text2")), - AllTypes.Text.CONCAT(Int(11)), AllTypes.Text.LIKE(String("abc")), AllTypes.Text.NOT_LIKE(String("_b_")), AllTypes.Text.REGEXP_LIKE(String("^t")), @@ -569,18 +570,18 @@ func TestStringOperators(t *testing.T) { CONCAT(AllTypes.VarCharPtr, AllTypes.VarCharPtr, String("aaa"), Int(1)), CONCAT(Bool(false), Int(1), Float(22.2), String("test test")), CONCAT_WS(String("string1"), Int(1), Float(11.22), String("bytea"), Bool(false)), //Float(11.12)), - CONVERT(Bytea("bytea"), String("UTF8"), String("LATIN1")), - CONVERT(AllTypes.Bytea, String("UTF8"), String("LATIN1")), - CONVERT_FROM(Bytea("text_in_utf8"), String("UTF8")), - CONVERT_TO(String("text_in_utf8"), String("UTF8")), - ENCODE(Bytea("123\000\001"), String("base64")), - DECODE(String("MTIzAAE="), String("base64")), + CONVERT(Bytea("bytea"), UTF8, LATIN1), + CONVERT(AllTypes.Bytea, UTF8, LATIN1), + CONVERT_FROM(Bytea("text_in_utf8"), UTF8), + CONVERT_TO(String("text_in_utf8"), UTF8), + ENCODE(Bytea("some text"), Escape), + DECODE(String("MTIzAAE="), Base64), FORMAT(String("Hello %s, %1$s"), String("World")), INITCAP(String("hi THOMAS")), LEFT(String("abcde"), Int(2)), RIGHT(String("abcde"), Int(2)), LENGTH(Bytea("jose")), - LENGTH(Bytea("jose"), String("UTF8")), + LENGTH(Bytea("jose"), UTF8), LPAD(String("Hi"), Int(5)), LPAD(String("Hi"), Int(5), String("xy")), RPAD(String("Hi"), Int(5)), @@ -601,6 +602,155 @@ func TestStringOperators(t *testing.T) { require.NoError(t, err) } +func TestBlob(t *testing.T) { + + var sampleBlob = Bytea([]byte{11, 0, 22, 33, 44}) + var textBlob = Bytea([]byte("text blob")) + + stmt := SELECT( + AllTypes.Bytea.EQ(sampleBlob), + AllTypes.Bytea.EQ(AllTypes.ByteaPtr), + AllTypes.Bytea.NOT_EQ(sampleBlob), + AllTypes.Bytea.GT(textBlob), + AllTypes.Bytea.GT_EQ(AllTypes.ByteaPtr), + AllTypes.Bytea.LT(AllTypes.ByteaPtr), + AllTypes.Bytea.LT_EQ(sampleBlob), + AllTypes.Bytea.BETWEEN(Bytea([]byte("min")), Bytea([]byte("max"))), + AllTypes.Bytea.NOT_BETWEEN(AllTypes.Bytea, AllTypes.ByteaPtr), + AllTypes.Bytea.CONCAT(textBlob), + + func() ProjectionList { + if sourceIsCockroachDB() { + return ProjectionList{NULL} + } + // cockroach doesn't support currently + return ProjectionList{ + AllTypes.Bytea.LIKE(Bytea("b'%pattern%'")), + AllTypes.Bytea.NOT_LIKE(Bytea("b'%pattern%'")), + + BTRIM(AllTypes.Bytea, Bytea([]byte{33})), + RTRIM(AllTypes.ByteaPtr, sampleBlob), + LTRIM(sampleBlob, textBlob), + CONCAT(sampleBlob, AllTypes.ByteaPtr, textBlob), + BIT_COUNT(sampleBlob).EQ(Int(3)), + LENGTH(textBlob, UTF8).EQ(Int(4)), + + CONVERT(textBlob, UTF8, WIN1252), + CONVERT(AllTypes.Bytea, UTF8, LATIN1).EQ(sampleBlob), + } + }(), + + BIT_LENGTH(textBlob), + OCTET_LENGTH(textBlob), + + GET_BIT(textBlob, Int(2)).EQ(Int(23)), + GET_BYTE(sampleBlob, Int(1)).EQ(Int(0)), + SET_BIT(textBlob, Int(1), Int(0)).EQ(sampleBlob), + SET_BYTE(textBlob, Int(1), Int(0)).EQ(textBlob), + LENGTH(sampleBlob), + + SUBSTR(AllTypes.Bytea, Int(0), Int(2)), + + MD5(AllTypes.Bytea), + SHA224(AllTypes.Bytea), + SHA256(AllTypes.Bytea), + SHA384(AllTypes.Bytea), + SHA512(AllTypes.Bytea), + + ENCODE(sampleBlob, Base64), + DECODE(String("A234C12B"), Hex).EQ(sampleBlob), + + CONVERT_FROM(AllTypes.ByteaPtr, UTF8).EQ(AllTypes.VarChar), + CONVERT_TO(AllTypes.Text, UTF8).NOT_EQ(textBlob), + + RawBytea("DECODE(#1::text, #2)", RawArgs{ + "#1": "A234C12B", + "#2": "hex", + }).EQ(sampleBlob), + ).FROM( + AllTypes, + ) + + if !sourceIsCockroachDB() { + testutils.AssertStatementSql(t, stmt, ` +SELECT all_types.bytea = $1::bytea, + all_types.bytea = all_types.bytea_ptr, + all_types.bytea != $2::bytea, + all_types.bytea > $3::bytea, + all_types.bytea >= all_types.bytea_ptr, + all_types.bytea < all_types.bytea_ptr, + all_types.bytea <= $4::bytea, + all_types.bytea BETWEEN $5::bytea AND $6::bytea, + all_types.bytea NOT BETWEEN all_types.bytea AND all_types.bytea_ptr, + all_types.bytea || $7::bytea, + all_types.bytea LIKE $8::bytea, + all_types.bytea NOT LIKE $9::bytea, + BTRIM(all_types.bytea, $10::bytea), + RTRIM(all_types.bytea_ptr, $11::bytea), + LTRIM($12::bytea, $13::bytea), + CONCAT($14::bytea, all_types.bytea_ptr, $15::bytea), + BIT_COUNT($16::bytea) = $17, + LENGTH($18::bytea, $19::text) = $20, + CONVERT($21::bytea, $22::text, $23::text), + CONVERT(all_types.bytea, $24::text, $25::text) = $26::bytea, + BIT_LENGTH($27::bytea), + OCTET_LENGTH($28::bytea), + GET_BIT($29::bytea, $30) = $31, + GET_BYTE($32::bytea, $33) = $34, + SET_BIT($35::bytea, $36, $37) = $38::bytea, + SET_BYTE($39::bytea, $40, $41) = $42::bytea, + LENGTH($43::bytea), + SUBSTR(all_types.bytea, $44, $45), + MD5(all_types.bytea), + SHA224(all_types.bytea), + SHA256(all_types.bytea), + SHA384(all_types.bytea), + SHA512(all_types.bytea), + ENCODE($46::bytea, $47::text), + DECODE($48::text, $49::text) = $50::bytea, + CONVERT_FROM(all_types.bytea_ptr, $51::text) = all_types.var_char, + CONVERT_TO(all_types.text, $52::text) != $53::bytea, + (DECODE($54::text, $55)) = $56::bytea +FROM test_sample.all_types; +`) + } + + var dest []struct{} + err := stmt.Query(db, &dest) + + require.NoError(t, err) +} + +func TestBlobConversion(t *testing.T) { + + nonPrintable := []byte{11, 22, 33, 44, 55} + printable := []byte("this is blob") + + stmt := SELECT( + Bytea(nonPrintable).AS("non_printable"), + Bytea(printable).AS("printable"), + + ENCODE(Bytea(nonPrintable), Base64).AS("non_printable_base64"), + CONVERT_FROM(Bytea(printable), UTF8).AS("printable_utf8"), + ) + + var dest struct { + NonPrintable []byte + Printable []byte + + NonPrintableBase64 []byte + PrintableUTF8 string + } + + err := stmt.Query(db, &dest) + + require.NoError(t, err) + require.Equal(t, dest.NonPrintable, nonPrintable) + require.Equal(t, dest.Printable, printable) + require.Equal(t, dest.NonPrintableBase64, []byte(base64.StdEncoding.EncodeToString(nonPrintable))) + require.Equal(t, dest.PrintableUTF8, string(printable)) +} + func TestBoolOperators(t *testing.T) { query := AllTypes.SELECT( AllTypes.Boolean.EQ(AllTypes.BooleanPtr).AS("EQ1"), @@ -1208,6 +1358,106 @@ SELECT ROW($1::integer, $2::real, $3::text) AS "row", require.NoError(t, err) } +func TestAllTypesSubQueryFrom(t *testing.T) { + subQuery := SELECT( + AllTypes.Boolean, + AllTypes.Integer, + AllTypes.DoublePrecision, + AllTypes.Text, + AllTypes.Date, + AllTypes.Time, + AllTypes.Timez, + AllTypes.Timestamp, + AllTypes.Interval, + AllTypes.Bytea, + ).FROM( + AllTypes, + ).AsTable("subQuery") + + stmt := SELECT( + AllTypes.Boolean.From(subQuery), + AllTypes.Integer.From(subQuery), + AllTypes.DoublePrecision.From(subQuery), + AllTypes.Text.From(subQuery), + AllTypes.Date.From(subQuery), + AllTypes.Time.From(subQuery), + AllTypes.Timez.From(subQuery), + AllTypes.Timestamp.From(subQuery), + AllTypes.Interval.From(subQuery), + AllTypes.Bytea.From(subQuery), + ).FROM( + subQuery, + ) + + testutils.AssertStatementSql(t, stmt, ` +SELECT "subQuery"."all_types.boolean" AS "all_types.boolean", + "subQuery"."all_types.integer" AS "all_types.integer", + "subQuery"."all_types.double_precision" AS "all_types.double_precision", + "subQuery"."all_types.text" AS "all_types.text", + "subQuery"."all_types.date" AS "all_types.date", + "subQuery"."all_types.time" AS "all_types.time", + "subQuery"."all_types.timez" AS "all_types.timez", + "subQuery"."all_types.timestamp" AS "all_types.timestamp", + "subQuery"."all_types.interval" AS "all_types.interval", + "subQuery"."all_types.bytea" AS "all_types.bytea" +FROM ( + SELECT all_types.boolean AS "all_types.boolean", + all_types.integer AS "all_types.integer", + all_types.double_precision AS "all_types.double_precision", + all_types.text AS "all_types.text", + all_types.date AS "all_types.date", + all_types.time AS "all_types.time", + all_types.timez AS "all_types.timez", + all_types.timestamp AS "all_types.timestamp", + all_types.interval AS "all_types.interval", + all_types.bytea AS "all_types.bytea" + FROM test_sample.all_types + ) AS "subQuery"; +`) + + var dest []model.AllTypes + + err := stmt.Query(db, &dest) + require.NoError(t, err) +} + +func TestAllTypesUpdateSet(t *testing.T) { + + stmt := AllTypes.UPDATE(). + SET( + AllTypes.Boolean.SET(Bool(false)), + AllTypes.Integer.SET(Int(2)), + AllTypes.DoublePrecision.SET(Float(2.22)), + AllTypes.Text.SET(Text("some text")), + AllTypes.Date.SET(DateT(time.Now())), + AllTypes.Time.SET(TimeT(time.Now())), + AllTypes.Timez.SET(TimezT(time.Now())), + AllTypes.Timestamp.SET(TimestampT(time.Now())), + AllTypes.Interval.SET(INTERVAL(1, HOUR)), + AllTypes.Bytea.SET(Bytea([]byte{11, 22, 33, 44})), + ).WHERE(Bool(true)) + + testutils.AssertStatementSql(t, stmt, ` +UPDATE test_sample.all_types +SET boolean = $1::boolean, + integer = $2, + double_precision = $3, + text = $4::text, + date = $5::date, + time = $6::time without time zone, + timez = $7::time with time zone, + timestamp = $8::timestamp without time zone, + interval = INTERVAL '1 HOUR', + bytea = $9::bytea +WHERE $10::boolean; +`) + + testutils.ExecuteInTxAndRollback(t, db, func(tx qrm.DB) { + _, err := stmt.Exec(tx) + require.NoError(t, err) + }) +} + func TestSubQueryColumnReference(t *testing.T) { type expected struct { sql string diff --git a/tests/postgres/generator_test.go b/tests/postgres/generator_test.go index 005e703..06478a6 100644 --- a/tests/postgres/generator_test.go +++ b/tests/postgres/generator_test.go @@ -974,8 +974,8 @@ type allTypesTable struct { Char postgres.ColumnString TextPtr postgres.ColumnString Text postgres.ColumnString - ByteaPtr postgres.ColumnString - Bytea postgres.ColumnString + ByteaPtr postgres.ColumnBytea + Bytea postgres.ColumnBytea TimestampzPtr postgres.ColumnTimestampz Timestampz postgres.ColumnTimestampz TimestampPtr postgres.ColumnTimestamp @@ -1078,8 +1078,8 @@ func newAllTypesTableImpl(schemaName, tableName, alias string) allTypesTable { CharColumn = postgres.StringColumn("char") TextPtrColumn = postgres.StringColumn("text_ptr") TextColumn = postgres.StringColumn("text") - ByteaPtrColumn = postgres.StringColumn("bytea_ptr") - ByteaColumn = postgres.StringColumn("bytea") + ByteaPtrColumn = postgres.ByteaColumn("bytea_ptr") + ByteaColumn = postgres.ByteaColumn("bytea") TimestampzPtrColumn = postgres.TimestampzColumn("timestampz_ptr") TimestampzColumn = postgres.TimestampzColumn("timestampz") TimestampPtrColumn = postgres.TimestampColumn("timestamp_ptr") diff --git a/tests/postgres/main_test.go b/tests/postgres/main_test.go index a179bd4..c1d48d9 100644 --- a/tests/postgres/main_test.go +++ b/tests/postgres/main_test.go @@ -31,6 +31,7 @@ const CockroachDB = "COCKROACH_DB" func init() { source = os.Getenv("PG_SOURCE") withStatementCaching = os.Getenv("JET_TESTS_WITH_STMT_CACHE") == "true" + testRoot = repo.GetTestsDirPath() } func sourceIsCockroachDB() bool { @@ -46,8 +47,6 @@ func skipForCockroachDB(t *testing.T) { func TestMain(m *testing.M) { defer profile.Start().Stop() - setTestRoot() - for _, driverName := range []string{"postgres", "pgx"} { fmt.Printf("\nRunning postgres tests for driver: %s, caching enabled: %t \n", driverName, withStatementCaching) @@ -94,10 +93,6 @@ func getConnectionString() string { return dbconfig.PostgresConnectString } -func setTestRoot() { - testRoot = repo.GetTestsDirPath() -} - var loggedSQL string var loggedSQLArgs []interface{} var loggedDebugSQL string diff --git a/tests/sqlite/alltypes_test.go b/tests/sqlite/alltypes_test.go index 080e870..b676bb4 100644 --- a/tests/sqlite/alltypes_test.go +++ b/tests/sqlite/alltypes_test.go @@ -1,6 +1,7 @@ package sqlite import ( + "encoding/hex" "github.com/go-jet/jet/v2/internal/testutils" "github.com/go-jet/jet/v2/internal/utils/ptr" . "github.com/go-jet/jet/v2/sqlite" @@ -18,7 +19,7 @@ import ( func TestAllTypes(t *testing.T) { - dest := []model.AllTypes{} + var dest []model.AllTypes err := SELECT(AllTypes.AllColumns). FROM(AllTypes). @@ -571,12 +572,102 @@ func TestStringOperators(t *testing.T) { SUBSTR(AllTypes.CharPtr, Int(3), Int(2)), ).FROM(AllTypes) - dest := []struct{}{} + var dest []struct{} err := query.Query(sampleDB, &dest) require.NoError(t, err) } +func TestBlob(t *testing.T) { + + var sampleBlob = Blob([]byte{11, 0, 22, 33, 44}) + var textBlob = Blob([]byte("text blob")) + + stmt := SELECT( + AllTypes.Blob.EQ(sampleBlob), + AllTypes.Blob.EQ(AllTypes.BlobPtr), + AllTypes.Blob.NOT_EQ(sampleBlob), + AllTypes.Blob.GT(textBlob), + AllTypes.Blob.GT_EQ(AllTypes.BlobPtr), + AllTypes.Blob.LT(AllTypes.BlobPtr), + AllTypes.Blob.LT_EQ(sampleBlob), + AllTypes.Blob.BETWEEN(Blob([]byte("min")), Blob([]byte("max"))), + AllTypes.Blob.NOT_BETWEEN(AllTypes.Blob, AllTypes.BlobPtr), + AllTypes.Blob.CONCAT(textBlob), + AllTypes.Blob.LIKE(AllTypes.BlobPtr), + AllTypes.Blob.NOT_LIKE(sampleBlob), + + RTRIM(AllTypes.BlobPtr, sampleBlob), + LTRIM(sampleBlob, textBlob), + LENGTH(sampleBlob), + OCTET_LENGTH(textBlob), + SUBSTR(AllTypes.Blob, Int(0), Int(2)), + + HEX(AllTypes.Blob), + UNHEX(AllTypes.Text), + ).FROM( + AllTypes, + ) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT all_types.blob = X'0b0016212c', + all_types.blob = all_types.blob_ptr, + all_types.blob != X'0b0016212c', + all_types.blob > X'7465787420626c6f62', + all_types.blob >= all_types.blob_ptr, + all_types.blob < all_types.blob_ptr, + all_types.blob <= X'0b0016212c', + all_types.blob BETWEEN X'6d696e' AND X'6d6178', + all_types.blob NOT BETWEEN all_types.blob AND all_types.blob_ptr, + all_types.blob || X'7465787420626c6f62', + all_types.blob LIKE all_types.blob_ptr, + all_types.blob NOT LIKE X'0b0016212c', + RTRIM(all_types.blob_ptr, X'0b0016212c'), + LTRIM(X'0b0016212c', X'7465787420626c6f62'), + LENGTH(X'0b0016212c'), + OCTET_LENGTH(X'7465787420626c6f62'), + SUBSTR(all_types.blob, 0, 2), + HEX(all_types.blob), + UNHEX(all_types.text) +FROM all_types; +`) + + var dest []struct{} + err := stmt.Query(sampleDB, &dest) + + require.NoError(t, err) +} + +func TestBlobConversion(t *testing.T) { + + nonPrintable := []byte{0x11, 0x22, 0x33, 0x44, 0x55} + printable := []byte("this is blob") + + stmt := SELECT( + Blob(nonPrintable).AS("non_printable"), + Blob(printable).AS("printable"), + + HEX(Blob(nonPrintable)).AS("non_printable_hex"), + UNHEX(String("1122334455")).AS("non_printable_unhex"), + ) + + var dest struct { + NonPrintable []byte + Printable []byte + + NonPrintableHex string + NonPrintableUnHex []byte + } + + err := stmt.Query(db, &dest) + + require.NoError(t, err) + require.Equal(t, dest.NonPrintable, nonPrintable) + require.Equal(t, dest.Printable, printable) + require.Equal(t, dest.NonPrintableHex, hex.EncodeToString(nonPrintable)) + require.Equal(t, dest.NonPrintableUnHex, nonPrintable) +} + func TestReservedWord(t *testing.T) { stmt := SELECT(ReservedWords.AllColumns). FROM(ReservedWords)