diff --git a/internal/jet/cast.go b/internal/jet/cast.go index 7369256..bd0eb7c 100644 --- a/internal/jet/cast.go +++ b/internal/jet/cast.go @@ -39,7 +39,7 @@ func (b *castExpression) serialize(statement StatementType, out *SqlBuilder, opt expression := b.expression castType := b.cast - if castOverride := out.Dialect.SerializeOverride("CAST"); castOverride != nil { + if castOverride := out.Dialect.OperatorSerializeOverride("CAST"); castOverride != nil { castOverride(expression, String(castType))(statement, out, options...) return } diff --git a/internal/jet/dialect.go b/internal/jet/dialect.go index d07320d..a0a9d37 100644 --- a/internal/jet/dialect.go +++ b/internal/jet/dialect.go @@ -3,7 +3,8 @@ package jet type Dialect interface { Name() string PackageName() string - SerializeOverride(operator string) SerializeOverride + OperatorSerializeOverride(operator string) SerializeOverride + FunctionSerializeOverride(function string) SerializeOverride AliasQuoteChar() byte IdentifierQuoteChar() byte ArgumentPlaceholder() QueryPlaceholderFunc @@ -14,32 +15,35 @@ type SerializeOverride func(expressions ...Expression) SerializeFunc type QueryPlaceholderFunc func(ord int) string type DialectParams struct { - Name string - PackageName string - SerializeOverrides map[string]SerializeOverride - AliasQuoteChar byte - IdentifierQuoteChar byte - ArgumentPlaceholder QueryPlaceholderFunc + Name string + PackageName string + OperatorSerializeOverrides map[string]SerializeOverride + FunctionSerializeOverrides map[string]SerializeOverride + AliasQuoteChar byte + IdentifierQuoteChar byte + ArgumentPlaceholder QueryPlaceholderFunc } func NewDialect(params DialectParams) Dialect { return &dialectImpl{ - name: params.Name, - packageName: params.PackageName, - serializeOverrides: params.SerializeOverrides, - aliasQuoteChar: params.AliasQuoteChar, - identifierQuoteChar: params.IdentifierQuoteChar, - argumentPlaceholder: params.ArgumentPlaceholder, + name: params.Name, + packageName: params.PackageName, + operatorSerializeOverrides: params.OperatorSerializeOverrides, + functionSerializeOverrides: params.FunctionSerializeOverrides, + aliasQuoteChar: params.AliasQuoteChar, + identifierQuoteChar: params.IdentifierQuoteChar, + argumentPlaceholder: params.ArgumentPlaceholder, } } type dialectImpl struct { - name string - packageName string - serializeOverrides map[string]SerializeOverride - aliasQuoteChar byte - identifierQuoteChar byte - argumentPlaceholder QueryPlaceholderFunc + name string + packageName string + operatorSerializeOverrides map[string]SerializeOverride + functionSerializeOverrides map[string]SerializeOverride + aliasQuoteChar byte + identifierQuoteChar byte + argumentPlaceholder QueryPlaceholderFunc supportsReturning bool } @@ -52,8 +56,18 @@ func (d *dialectImpl) PackageName() string { return d.packageName } -func (d *dialectImpl) SerializeOverride(operator string) SerializeOverride { - return d.serializeOverrides[operator] +func (d *dialectImpl) OperatorSerializeOverride(operator string) SerializeOverride { + if d.operatorSerializeOverrides == nil { + return nil + } + return d.operatorSerializeOverrides[operator] +} + +func (d *dialectImpl) FunctionSerializeOverride(function string) SerializeOverride { + if d.functionSerializeOverrides == nil { + return nil + } + return d.functionSerializeOverrides[function] } func (d *dialectImpl) AliasQuoteChar() byte { diff --git a/internal/jet/expression.go b/internal/jet/expression.go index d71fa40..f303fee 100644 --- a/internal/jet/expression.go +++ b/internal/jet/expression.go @@ -110,7 +110,7 @@ func (c *binaryOpExpression) serialize(statement StatementType, out *SqlBuilder, out.WriteString("(") } - if serializeOverride := out.Dialect.SerializeOverride(c.operator); serializeOverride != nil { + if serializeOverride := out.Dialect.OperatorSerializeOverride(c.operator); serializeOverride != nil { serializeOverrideFunc := serializeOverride(c.lhs, c.rhs, c.additionalParam) serializeOverrideFunc(statement, out, options...) } else { diff --git a/internal/jet/func_expression.go b/internal/jet/func_expression.go index 2e6e9dd..56b64b1 100644 --- a/internal/jet/func_expression.go +++ b/internal/jet/func_expression.go @@ -505,7 +505,7 @@ func newFunc(name string, expressions []Expression, parent Expression) *funcExpr } func (f *funcExpressionImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) { - if serializeOverride := out.Dialect.SerializeOverride(f.name); serializeOverride != nil { + if serializeOverride := out.Dialect.FunctionSerializeOverride(f.name); serializeOverride != nil { serializeOverrideFunc := serializeOverride(f.expressions...) serializeOverrideFunc(statement, out, options...) return diff --git a/mysql/delete_statement.go b/mysql/delete_statement.go index a01f1c3..07dd4ef 100644 --- a/mysql/delete_statement.go +++ b/mysql/delete_statement.go @@ -3,7 +3,7 @@ package mysql import "github.com/go-jet/jet/internal/jet" type DeleteStatement interface { - jet.Statement + Statement WHERE(expression BoolExpression) DeleteStatement ORDER_BY(orderByClauses ...jet.OrderByClause) DeleteStatement diff --git a/mysql/dialect.go b/mysql/dialect.go index 6f8fe31..7186fa6 100644 --- a/mysql/dialect.go +++ b/mysql/dialect.go @@ -8,21 +8,21 @@ var Dialect = NewDialect() func NewDialect() jet.Dialect { - serializeOverrides := map[string]jet.SerializeOverride{} - serializeOverrides[jet.StringRegexpLikeOperator] = mysql_REGEXP_LIKE_operator - serializeOverrides[jet.StringNotRegexpLikeOperator] = mysql_NOT_REGEXP_LIKE_operator - serializeOverrides["IS DISTINCT FROM"] = mysql_IS_DISTINCT_FROM - serializeOverrides["IS NOT DISTINCT FROM"] = mysql_IS_NOT_DISTINCT_FROM - serializeOverrides["/"] = mysql_DIVISION - serializeOverrides["#"] = mysql_BIT_XOR - serializeOverrides[jet.StringConcatOperator] = mysql_CONCAT_operator + operatorSerializeOverrides := map[string]jet.SerializeOverride{} + operatorSerializeOverrides[jet.StringRegexpLikeOperator] = mysql_REGEXP_LIKE_operator + operatorSerializeOverrides[jet.StringNotRegexpLikeOperator] = mysql_NOT_REGEXP_LIKE_operator + operatorSerializeOverrides["IS DISTINCT FROM"] = mysql_IS_DISTINCT_FROM + operatorSerializeOverrides["IS NOT DISTINCT FROM"] = mysql_IS_NOT_DISTINCT_FROM + operatorSerializeOverrides["/"] = mysql_DIVISION + operatorSerializeOverrides["#"] = mysql_BIT_XOR + operatorSerializeOverrides[jet.StringConcatOperator] = mysql_CONCAT_operator mySQLDialectParams := jet.DialectParams{ - Name: "MySQL", - PackageName: "mysql", - SerializeOverrides: serializeOverrides, - AliasQuoteChar: '"', - IdentifierQuoteChar: '`', + Name: "MySQL", + PackageName: "mysql", + OperatorSerializeOverrides: operatorSerializeOverrides, + AliasQuoteChar: '"', + IdentifierQuoteChar: '`', ArgumentPlaceholder: func(int) string { return "?" }, diff --git a/mysql/expressions.go b/mysql/expressions.go index 297507d..d91ccaf 100644 --- a/mysql/expressions.go +++ b/mysql/expressions.go @@ -30,7 +30,3 @@ var DateTimeExp = jet.TimestampExp var TimestampExp = jet.TimestampExp var Raw = jet.Raw - -var NewEnumValue = jet.NewEnumValue - -type Statement jet.Statement diff --git a/mysql/insert_statement.go b/mysql/insert_statement.go index a9020fe..8639eb7 100644 --- a/mysql/insert_statement.go +++ b/mysql/insert_statement.go @@ -4,7 +4,7 @@ import "github.com/go-jet/jet/internal/jet" // InsertStatement is interface for SQL INSERT statements type InsertStatement interface { - jet.Statement + Statement // Insert row of values VALUES(value interface{}, values ...interface{}) InsertStatement diff --git a/mysql/select_statement.go b/mysql/select_statement.go index 11a1563..71d3e46 100644 --- a/mysql/select_statement.go +++ b/mysql/select_statement.go @@ -10,9 +10,9 @@ var ( ) type SelectStatement interface { - jet.Statement + Statement jet.HasProjections - jet.Expression + Expression DISTINCT() SelectStatement FROM(table ReadableTable) SelectStatement @@ -32,11 +32,21 @@ type SelectStatement interface { } //SELECT creates new SelectStatement with list of projections -func SELECT(projection jet.Projection, projections ...jet.Projection) SelectStatement { - return newSelectStatement(nil, append([]jet.Projection{projection}, projections...)) +func SELECT(projection Projection, projections ...Projection) SelectStatement { + return newSelectStatement(nil, append([]Projection{projection}, projections...)) } -func newSelectStatement(table ReadableTable, projections []jet.Projection) SelectStatement { +func toJetProjectionList(projections []Projection) []jet.Projection { + ret := []jet.Projection{} + + for _, projection := range projections { + ret = append(ret, projection) + } + + return ret +} + +func newSelectStatement(table ReadableTable, projections []Projection) SelectStatement { newSelect := &selectStatementImpl{} newSelect.ExpressionStatementImpl.StatementImpl = jet.NewStatementImpl(Dialect, jet.SelectStatementType, newSelect, &newSelect.Select, &newSelect.From, &newSelect.Where, &newSelect.GroupBy, &newSelect.Having, &newSelect.OrderBy, @@ -44,7 +54,7 @@ func newSelectStatement(table ReadableTable, projections []jet.Projection) Selec newSelect.ExpressionStatementImpl.ExpressionInterfaceImpl.Parent = newSelect - newSelect.Select.Projections = projections + newSelect.Select.Projections = toJetProjectionList(projections) newSelect.From.Table = table newSelect.Limit.Count = -1 newSelect.Offset.Count = -1 diff --git a/mysql/table.go b/mysql/table.go index ada83d6..11b02d5 100644 --- a/mysql/table.go +++ b/mysql/table.go @@ -14,7 +14,7 @@ type Table interface { type readableTable interface { // Generates a select query on the current tableName. - SELECT(projection jet.Projection, projections ...jet.Projection) SelectStatement + SELECT(projection Projection, projections ...Projection) SelectStatement // Creates a inner join tableName Expression using onCondition. INNER_JOIN(table ReadableTable, onCondition BoolExpression) joinSelectUpdateTable @@ -47,8 +47,8 @@ type readableTableInterfaceImpl struct { } // Generates a select query on the current tableName. -func (r *readableTableInterfaceImpl) SELECT(projection1 jet.Projection, projections ...jet.Projection) SelectStatement { - return newSelectStatement(r.parent, append([]jet.Projection{projection1}, projections...)) +func (r *readableTableInterfaceImpl) SELECT(projection1 Projection, projections ...Projection) SelectStatement { + return newSelectStatement(r.parent, append([]Projection{projection1}, projections...)) } // Creates a inner join tableName Expression using onCondition. diff --git a/mysql/types.go b/mysql/types.go new file mode 100644 index 0000000..ec85765 --- /dev/null +++ b/mysql/types.go @@ -0,0 +1,8 @@ +package mysql + +import "github.com/go-jet/jet/internal/jet" + +type Statement jet.Statement +type Projection jet.Projection + +var NewEnumValue = jet.NewEnumValue diff --git a/postgres/dialect.go b/postgres/dialect.go index 15ed7e0..8de68b8 100644 --- a/postgres/dialect.go +++ b/postgres/dialect.go @@ -9,17 +9,17 @@ var Dialect = NewDialect() func NewDialect() jet.Dialect { - serializeOverrides := map[string]jet.SerializeOverride{} - serializeOverrides[jet.StringRegexpLikeOperator] = postgres_REGEXP_LIKE_operator - serializeOverrides[jet.StringNotRegexpLikeOperator] = postgres_NOT_REGEXP_LIKE_operator - serializeOverrides["CAST"] = postgresCAST + operatorSerializeOverrides := map[string]jet.SerializeOverride{} + operatorSerializeOverrides[jet.StringRegexpLikeOperator] = postgres_REGEXP_LIKE_operator + operatorSerializeOverrides[jet.StringNotRegexpLikeOperator] = postgres_NOT_REGEXP_LIKE_operator + operatorSerializeOverrides["CAST"] = postgresCAST dialectParams := jet.DialectParams{ - Name: "PostgreSQL", - PackageName: "postgres", - SerializeOverrides: serializeOverrides, - AliasQuoteChar: '"', - IdentifierQuoteChar: '"', + Name: "PostgreSQL", + PackageName: "postgres", + OperatorSerializeOverrides: operatorSerializeOverrides, + AliasQuoteChar: '"', + IdentifierQuoteChar: '"', ArgumentPlaceholder: func(ord int) string { return "$" + strconv.Itoa(ord) }, diff --git a/tests/mysql/alltypes_test.go b/tests/mysql/alltypes_test.go index f771636..f8c0d07 100644 --- a/tests/mysql/alltypes_test.go +++ b/tests/mysql/alltypes_test.go @@ -433,7 +433,8 @@ LIMIT ?; } func TestStringOperators(t *testing.T) { - query := AllTypes.SELECT( + + var projectionList = []Projection{ AllTypes.Text.EQ(AllTypes.Char), AllTypes.Text.EQ(String("Text")), AllTypes.Text.NOT_EQ(AllTypes.VarCharPtr), @@ -478,17 +479,24 @@ func TestStringOperators(t *testing.T) { REVERSE(AllTypes.VarCharPtr), SUBSTR(AllTypes.CharPtr, Int(3)), SUBSTR(AllTypes.CharPtr, Int(3), Int(2)), - REGEXP_LIKE(String("ABA"), String("aba")), - REGEXP_LIKE(String("ABA"), String("aba"), "i"), - REGEXP_LIKE(AllTypes.Text, String("aba"), "i"), - ) + } + if !sourceIsMariaDB() { + projectionList = append(projectionList, []Projection{ + REGEXP_LIKE(String("ABA"), String("aba")), + REGEXP_LIKE(String("ABA"), String("aba"), "i"), + REGEXP_LIKE(AllTypes.Text, String("aba"), "i"), + }...) + } //_, args, _ := query.Sql() //fmt.Println(query.Sql()) //fmt.Println(args[15]) - // fmt.Println(query.Sql()) + query := SELECT(projectionList[0], projectionList[1:]...). + FROM(AllTypes) + + fmt.Println(query.DebugSql()) err := query.Query(db, &struct{}{}) @@ -765,7 +773,7 @@ func TestTimeLiterals(t *testing.T) { TimestampT(timeT).AS("timestampT"), ).FROM(AllTypes).LIMIT(1) - fmt.Println(query.DebugSql()) + //fmt.Println(query.DebugSql()) testutils.AssertStatementSql(t, query, ` SELECT CAST(? AS DATE) AS "date", diff --git a/tests/mysql/select_test.go b/tests/mysql/select_test.go index 8680a49..81e5ae1 100644 --- a/tests/mysql/select_test.go +++ b/tests/mysql/select_test.go @@ -69,6 +69,10 @@ ORDER BY actor.actor_id; } func TestSelectGroupByHaving(t *testing.T) { + if sourceIsMariaDB() { + return + } + expectedSQL := ` SELECT customer.customer_id AS "customer.customer_id", customer.store_id AS "customer.store_id",