diff --git a/internal/jet/func_expression.go b/internal/jet/func_expression.go index 56b64b1..3cab191 100644 --- a/internal/jet/func_expression.go +++ b/internal/jet/func_expression.go @@ -349,7 +349,7 @@ func TO_HEX(number IntegerExpression) StringExpression { // REGEXP_LIKE Returns 1 if the string expr matches the regular expression specified by the pattern pat, 0 otherwise. func REGEXP_LIKE(stringExp StringExpression, pattern StringExpression, matchType ...string) BoolExpression { if len(matchType) > 0 { - return newBoolFunc("REGEXP_LIKE", stringExp, pattern, String(matchType[0], true)) + return newBoolFunc("REGEXP_LIKE", stringExp, pattern, ConstLiteral(matchType[0])) } return newBoolFunc("REGEXP_LIKE", stringExp, pattern) @@ -391,7 +391,7 @@ func CURRENT_TIME(precision ...int) TimezExpression { var timezFunc *timezFunc if len(precision) > 0 { - timezFunc = newTimezFunc("CURRENT_TIME", constLiteral(precision[0])) + timezFunc = newTimezFunc("CURRENT_TIME", ConstLiteral(precision[0])) } else { timezFunc = newTimezFunc("CURRENT_TIME") } @@ -406,7 +406,7 @@ func CURRENT_TIMESTAMP(precision ...int) TimestampzExpression { var timestampzFunc *timestampzFunc if len(precision) > 0 { - timestampzFunc = newTimestampzFunc("CURRENT_TIMESTAMP", constLiteral(precision[0])) + timestampzFunc = newTimestampzFunc("CURRENT_TIMESTAMP", ConstLiteral(precision[0])) } else { timestampzFunc = newTimestampzFunc("CURRENT_TIMESTAMP") } @@ -421,7 +421,7 @@ func LOCALTIME(precision ...int) TimeExpression { var timeFunc *timeFunc if len(precision) > 0 { - timeFunc = newTimeFunc("LOCALTIME", constLiteral(precision[0])) + timeFunc = newTimeFunc("LOCALTIME", ConstLiteral(precision[0])) } else { timeFunc = newTimeFunc("LOCALTIME") } @@ -436,7 +436,7 @@ func LOCALTIMESTAMP(precision ...int) TimestampExpression { var timestampFunc *timestampFunc if len(precision) > 0 { - timestampFunc = NewTimestampFunc("LOCALTIMESTAMP", constLiteral(precision[0])) + timestampFunc = NewTimestampFunc("LOCALTIMESTAMP", ConstLiteral(precision[0])) } else { timestampFunc = NewTimestampFunc("LOCALTIMESTAMP") } diff --git a/internal/jet/integer_expression_test.go b/internal/jet/integer_expression_test.go index 5496da5..7920545 100644 --- a/internal/jet/integer_expression_test.go +++ b/internal/jet/integer_expression_test.go @@ -66,7 +66,7 @@ func TestIntExpressionPOW(t *testing.T) { func TestIntExpressionBIT_NOT(t *testing.T) { assertClauseSerialize(t, BIT_NOT(table2ColInt), "(~ table2.col_int)") - assertClauseSerialize(t, BIT_NOT(Int(11)), "(~ $1)", int64(11)) + assertClauseSerialize(t, BIT_NOT(Int(11)), "(~ 11)") } func TestIntExpressionBIT_AND(t *testing.T) { diff --git a/internal/jet/literal_expression.go b/internal/jet/literal_expression.go index ccb9940..5516572 100644 --- a/internal/jet/literal_expression.go +++ b/internal/jet/literal_expression.go @@ -32,7 +32,7 @@ func literal(value interface{}, optionalConstant ...bool) *literalExpressionImpl return &exp } -func constLiteral(value interface{}) *literalExpressionImpl { +func ConstLiteral(value interface{}) *literalExpressionImpl { exp := literal(value) exp.constant = true @@ -61,13 +61,10 @@ type integerLiteralExpression struct { } // Int is constructor for integer expressions literals. -func Int(value int64, constant ...bool) IntegerExpression { +func Int(value int64) IntegerExpression { numLiteral := &integerLiteralExpression{} numLiteral.literalExpressionImpl = *literal(value) - if len(constant) > 0 && constant[0] == true { - numLiteral.constant = true - } numLiteral.literalExpressionImpl.Parent = numLiteral numLiteral.integerInterfaceImpl.parent = numLiteral @@ -114,12 +111,9 @@ type stringLiteral struct { } // String creates new string literal expression -func String(value string, constant ...bool) StringExpression { +func String(value string) StringExpression { stringLiteral := stringLiteral{} stringLiteral.literalExpressionImpl = *literal(value) - if len(constant) > 0 && constant[0] == true { - stringLiteral.constant = true - } stringLiteral.stringInterfaceImpl.parent = &stringLiteral diff --git a/internal/jet/operators.go b/internal/jet/operators.go index 74b2be5..14dbd77 100644 --- a/internal/jet/operators.go +++ b/internal/jet/operators.go @@ -15,6 +15,9 @@ func NOT(exp BoolExpression) BoolExpression { // BIT_NOT inverts every bit in integer expression result func BIT_NOT(expr IntegerExpression) IntegerExpression { + if literalExp, ok := expr.(LiteralExpression); ok { + literalExp.SetConstant(true) + } return newPrefixIntegerOperator(expr, "~") } diff --git a/mysql/cast.go b/mysql/cast.go index c4c056c..8240f42 100644 --- a/mysql/cast.go +++ b/mysql/cast.go @@ -30,7 +30,7 @@ type castImpl struct { jet.Cast } -func CAST(expr jet.Expression) cast { +func CAST(expr Expression) cast { castImpl := &castImpl{} castImpl.Cast = jet.NewCastImpl(expr) diff --git a/mysql/expressions.go b/mysql/expressions.go index d91ccaf..2ff56b8 100644 --- a/mysql/expressions.go +++ b/mysql/expressions.go @@ -30,3 +30,5 @@ var DateTimeExp = jet.TimestampExp var TimestampExp = jet.TimestampExp var Raw = jet.Raw + +var NewEnumValue = jet.NewEnumValue diff --git a/mysql/functions.go b/mysql/functions.go index b4611f2..a78b27e 100644 --- a/mysql/functions.go +++ b/mysql/functions.go @@ -180,7 +180,7 @@ func CURRENT_TIMESTAMP(precision ...int) TimestampExpression { // NOW returns current datetime func NOW(fsp ...int) DateTimeExpression { if len(fsp) > 0 { - return jet.NewTimestampFunc("NOW", Int(int64(fsp[0]), true)) + return jet.NewTimestampFunc("NOW", jet.ConstLiteral(int64(fsp[0]))) } return jet.NewTimestampFunc("NOW") } diff --git a/mysql/select_statement.go b/mysql/select_statement.go index 71d3e46..ad91935 100644 --- a/mysql/select_statement.go +++ b/mysql/select_statement.go @@ -36,16 +36,6 @@ func SELECT(projection Projection, projections ...Projection) SelectStatement { return newSelectStatement(nil, append([]Projection{projection}, projections...)) } -func toJetProjectionList(projections []Projection) []jet.Projection { - ret := []jet.Projection{} - - for _, projection := range projections { - ret = append(ret, projection) - } - - return ret -} - func newSelectStatement(table ReadableTable, projections []Projection) SelectStatement { newSelect := &selectStatementImpl{} newSelect.ExpressionStatementImpl.StatementImpl = jet.NewStatementImpl(Dialect, jet.SelectStatementType, newSelect, &newSelect.Select, diff --git a/mysql/types.go b/mysql/types.go index ec85765..b9063f0 100644 --- a/mysql/types.go +++ b/mysql/types.go @@ -5,4 +5,12 @@ import "github.com/go-jet/jet/internal/jet" type Statement jet.Statement type Projection jet.Projection -var NewEnumValue = jet.NewEnumValue +func toJetProjectionList(projections []Projection) []jet.Projection { + ret := []jet.Projection{} + + for _, projection := range projections { + ret = append(ret, projection) + } + + return ret +} diff --git a/postgres/select_statement.go b/postgres/select_statement.go index 76fd4be..b8bc3c1 100644 --- a/postgres/select_statement.go +++ b/postgres/select_statement.go @@ -37,11 +37,11 @@ type SelectStatement interface { } //SELECT creates new SelectStatement with list of projections -func SELECT(projection jet.Projection, projections ...jet.Projection) SelectStatement { - return newSelectStatement(nil, append([]jet.Projection{projection}, projections...)) +func SELECT(projection Projection, projections ...Projection) SelectStatement { + return newSelectStatement(nil, append([]Projection{projection}, projections...)) } -func newSelectStatement(table ReadableTable, projections []jet.Projection) SelectStatement { +func newSelectStatement(table ReadableTable, projections []Projection) SelectStatement { newSelect := &selectStatementImpl{} newSelect.ExpressionStatementImpl.StatementImpl = jet.NewStatementImpl(Dialect, jet.SelectStatementType, newSelect, &newSelect.Select, &newSelect.From, &newSelect.Where, &newSelect.GroupBy, &newSelect.Having, &newSelect.OrderBy, @@ -49,7 +49,7 @@ func newSelectStatement(table ReadableTable, projections []jet.Projection) Selec newSelect.ExpressionStatementImpl.ExpressionInterfaceImpl.Parent = newSelect - newSelect.Select.Projections = projections + newSelect.Select.Projections = toJetProjectionList(projections) newSelect.From.Table = table newSelect.Limit.Count = -1 newSelect.Offset.Count = -1 diff --git a/postgres/table.go b/postgres/table.go index 963cad8..4d5bef2 100644 --- a/postgres/table.go +++ b/postgres/table.go @@ -4,7 +4,7 @@ import "github.com/go-jet/jet/internal/jet" type readableTable interface { // Generates a select query on the current tableName. - SELECT(projection jet.Projection, projections ...jet.Projection) SelectStatement + SELECT(projection Projection, projections ...Projection) SelectStatement // Creates a inner join tableName Expression using onCondition. INNER_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable @@ -52,8 +52,8 @@ type readableTableInterfaceImpl struct { } // Generates a select query on the current tableName. -func (r *readableTableInterfaceImpl) SELECT(projection1 jet.Projection, projections ...jet.Projection) SelectStatement { - return newSelectStatement(r.parent, append([]jet.Projection{projection1}, projections...)) +func (r *readableTableInterfaceImpl) SELECT(projection1 Projection, projections ...Projection) SelectStatement { + return newSelectStatement(r.parent, append([]Projection{projection1}, projections...)) } // Creates a inner join tableName Expression using onCondition. diff --git a/postgres/types.go b/postgres/types.go new file mode 100644 index 0000000..20e5c4e --- /dev/null +++ b/postgres/types.go @@ -0,0 +1,16 @@ +package postgres + +import "github.com/go-jet/jet/internal/jet" + +type Statement jet.Statement +type Projection jet.Projection + +func toJetProjectionList(projections []Projection) []jet.Projection { + ret := []jet.Projection{} + + for _, projection := range projections { + ret = append(ret, projection) + } + + return ret +} diff --git a/tests/mysql/update_test.go b/tests/mysql/update_test.go index 185c8e8..2b23f6b 100644 --- a/tests/mysql/update_test.go +++ b/tests/mysql/update_test.go @@ -139,6 +139,35 @@ WHERE link.id = 201; testutils.AssertExec(t, stmt, db) } +func TestUpdateWithModelDataAndMutableColumns(t *testing.T) { + + setupLinkTableForUpdateTest(t) + + link := model.Link{ + ID: 201, + URL: "http://www.duckduckgo.com", + Name: "DuckDuckGo", + } + + stmt := Link. + UPDATE(Link.MutableColumns). + MODEL(link). + WHERE(Link.ID.EQ(Int(int64(link.ID)))) + + var expectedSQL = ` +UPDATE test_sample.link +SET url = 'http://www.duckduckgo.com', + name = 'DuckDuckGo', + description = NULL +WHERE link.id = 201; +` + fmt.Println(stmt.DebugSql()) + + testutils.AssertDebugStatementSql(t, stmt, expectedSQL, "http://www.duckduckgo.com", "DuckDuckGo", nil, int64(201)) + + testutils.AssertExec(t, stmt, db) +} + func TestUpdateWithInvalidModelData(t *testing.T) { defer func() { r := recover() diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go index c925d66..402ae42 100644 --- a/tests/postgres/alltypes_test.go +++ b/tests/postgres/alltypes_test.go @@ -468,7 +468,7 @@ func TestIntegerOperators(t *testing.T) { AllTypes.SmallInt.BIT_XOR(Int(11)).AS("bit xor 2"), BIT_NOT(Int(-1).MUL(AllTypes.SmallInt)).AS("bit_not_1"), - BIT_NOT(Int(-11, true)).AS("bit_not_2"), + BIT_NOT(Int(-11)).AS("bit_not_2"), AllTypes.SmallInt.BIT_SHIFT_LEFT(AllTypes.SmallInt.DIV(Int(2))).AS("bit shift left 1"), AllTypes.SmallInt.BIT_SHIFT_LEFT(Int(4)).AS("bit shift left 2"),