diff --git a/internal/jet/clause.go b/internal/jet/clause.go index 90d194e..349c6dc 100644 --- a/internal/jet/clause.go +++ b/internal/jet/clause.go @@ -518,7 +518,7 @@ type SetPair struct { } // SetClause clause -type SetClause []SetPair +type SetClause []ColumnAssigment // Serialize for SetClause func (s SetClause) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { @@ -526,16 +526,15 @@ func (s SetClause) Serialize(statementType StatementType, out *SQLBuilder, optio out.WriteString("SET") out.IncreaseIdent(4) - for i, pair := range s { + for i, assigment := range s { if i > 0 { out.WriteString(",") out.NewLine() } - pair.Column.serialize(statementType, out, ShortName.WithFallTrough(options)...) - out.WriteString("=") - pair.Value.serialize(statementType, out, FallTrough(options)...) + assigment.serialize(statementType, out, FallTrough(options)...) } + out.DecreaseIdent(4) } diff --git a/internal/jet/column.go b/internal/jet/column.go index 0fd59be..3e4c300 100644 --- a/internal/jet/column.go +++ b/internal/jet/column.go @@ -115,56 +115,3 @@ func (c ColumnExpressionImpl) serialize(statement StatementType, out *SQLBuilder out.WriteIdentifier(c.name) } } - -//------------------------------------------------------// - -// ColumnList is a helper type to support list of columns as single projection -type ColumnList []ColumnExpression - -func (cl ColumnList) fromImpl(subQuery SelectTable) Projection { - newProjectionList := ProjectionList{} - - for _, column := range cl { - newProjectionList = append(newProjectionList, column.fromImpl(subQuery)) - } - - return newProjectionList -} - -func (cl ColumnList) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { - out.WriteString("(") - for i, column := range cl { - if i > 0 { - out.WriteString(", ") - } - column.serialize(statement, out, FallTrough(options)...) - } - out.WriteString(")") -} - -func (cl ColumnList) serializeForProjection(statement StatementType, out *SQLBuilder) { - projections := ColumnListToProjectionList(cl) - - SerializeProjectionList(statement, projections, out) -} - -// dummy column interface implementation - -// Name is placeholder for ColumnList to implement Column interface -func (cl ColumnList) Name() string { return "" } - -// TableName is placeholder for ColumnList to implement Column interface -func (cl ColumnList) TableName() string { return "" } -func (cl ColumnList) setTableName(name string) {} -func (cl ColumnList) setSubQuery(subQuery SelectTable) {} -func (cl ColumnList) defaultAlias() string { return "" } - -// SetTableName is utility function to set table name from outside of jet package to avoid making public setTableName -func SetTableName(columnExpression ColumnExpression, tableName string) { - columnExpression.setTableName(tableName) -} - -// SetSubQuery is utility function to set table name from outside of jet package to avoid making public setSubQuery -func SetSubQuery(columnExpression ColumnExpression, subQuery SelectTable) { - columnExpression.setSubQuery(subQuery) -} diff --git a/internal/jet/column_assigment.go b/internal/jet/column_assigment.go new file mode 100644 index 0000000..440f3eb --- /dev/null +++ b/internal/jet/column_assigment.go @@ -0,0 +1,20 @@ +package jet + +// ColumnAssigment is interface wrapper around column assigment +type ColumnAssigment interface { + Serializer + isColumnAssigment() +} + +type columnAssigmentImpl struct { + column ColumnSerializer + expression Expression +} + +func (a columnAssigmentImpl) isColumnAssigment() {} + +func (a columnAssigmentImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + a.column.serialize(statement, out, ShortName.WithFallTrough(options)...) + out.WriteString("=") + a.expression.serialize(statement, out, FallTrough(options)...) +} diff --git a/internal/jet/column_list.go b/internal/jet/column_list.go new file mode 100644 index 0000000..8483c76 --- /dev/null +++ b/internal/jet/column_list.go @@ -0,0 +1,60 @@ +package jet + +// ColumnList is a helper type to support list of columns as single projection +type ColumnList []ColumnExpression + +// SET creates column assigment for each column in column list. expression should be created by ROW function +func (cl ColumnList) SET(expression Expression) ColumnAssigment { + return columnAssigmentImpl{ + column: cl, + expression: expression, + } +} + +func (cl ColumnList) fromImpl(subQuery SelectTable) Projection { + newProjectionList := ProjectionList{} + + for _, column := range cl { + newProjectionList = append(newProjectionList, column.fromImpl(subQuery)) + } + + return newProjectionList +} + +func (cl ColumnList) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + out.WriteString("(") + for i, column := range cl { + if i > 0 { + out.WriteString(", ") + } + column.serialize(statement, out, FallTrough(options)...) + } + out.WriteString(")") +} + +func (cl ColumnList) serializeForProjection(statement StatementType, out *SQLBuilder) { + projections := ColumnListToProjectionList(cl) + + SerializeProjectionList(statement, projections, out) +} + +// dummy column interface implementation + +// Name is placeholder for ColumnList to implement Column interface +func (cl ColumnList) Name() string { return "" } + +// TableName is placeholder for ColumnList to implement Column interface +func (cl ColumnList) TableName() string { return "" } +func (cl ColumnList) setTableName(name string) {} +func (cl ColumnList) setSubQuery(subQuery SelectTable) {} +func (cl ColumnList) defaultAlias() string { return "" } + +// SetTableName is utility function to set table name from outside of jet package to avoid making public setTableName +func SetTableName(columnExpression ColumnExpression, tableName string) { + columnExpression.setTableName(tableName) +} + +// SetSubQuery is utility function to set table name from outside of jet package to avoid making public setSubQuery +func SetSubQuery(columnExpression ColumnExpression, subQuery SelectTable) { + columnExpression.setSubQuery(subQuery) +} diff --git a/internal/jet/column_types.go b/internal/jet/column_types.go index 58f6751..a606a4e 100644 --- a/internal/jet/column_types.go +++ b/internal/jet/column_types.go @@ -6,6 +6,7 @@ type ColumnBool interface { Column From(subQuery SelectTable) ColumnBool + SET(boolExp BoolExpression) ColumnAssigment } type boolColumnImpl struct { @@ -21,6 +22,13 @@ func (i *boolColumnImpl) From(subQuery SelectTable) ColumnBool { return newBoolColumn } +func (i *boolColumnImpl) SET(boolExp BoolExpression) ColumnAssigment { + return columnAssigmentImpl{ + column: i, + expression: boolExp, + } +} + // BoolColumn creates named bool column. func BoolColumn(name string) ColumnBool { boolColumn := &boolColumnImpl{} @@ -38,6 +46,7 @@ type ColumnFloat interface { Column From(subQuery SelectTable) ColumnFloat + SET(floatExp FloatExpression) ColumnAssigment } type floatColumnImpl struct { @@ -53,6 +62,13 @@ func (i *floatColumnImpl) From(subQuery SelectTable) ColumnFloat { return newFloatColumn } +func (i *floatColumnImpl) SET(floatExp FloatExpression) ColumnAssigment { + return columnAssigmentImpl{ + column: i, + expression: floatExp, + } +} + // FloatColumn creates named float column. func FloatColumn(name string) ColumnFloat { floatColumn := &floatColumnImpl{} @@ -70,6 +86,7 @@ type ColumnInteger interface { Column From(subQuery SelectTable) ColumnInteger + SET(intExp IntegerExpression) ColumnAssigment } type integerColumnImpl struct { @@ -86,6 +103,13 @@ func (i *integerColumnImpl) From(subQuery SelectTable) ColumnInteger { return newIntColumn } +func (i *integerColumnImpl) SET(intExp IntegerExpression) ColumnAssigment { + return columnAssigmentImpl{ + column: i, + expression: intExp, + } +} + // IntegerColumn creates named integer column. func IntegerColumn(name string) ColumnInteger { integerColumn := &integerColumnImpl{} @@ -104,6 +128,7 @@ type ColumnString interface { Column From(subQuery SelectTable) ColumnString + SET(stringExp StringExpression) ColumnAssigment } type stringColumnImpl struct { @@ -120,6 +145,13 @@ func (i *stringColumnImpl) From(subQuery SelectTable) ColumnString { return newStrColumn } +func (i *stringColumnImpl) SET(stringExp StringExpression) ColumnAssigment { + return columnAssigmentImpl{ + column: i, + expression: stringExp, + } +} + // StringColumn creates named string column. func StringColumn(name string) ColumnString { stringColumn := &stringColumnImpl{} @@ -137,6 +169,7 @@ type ColumnTime interface { Column From(subQuery SelectTable) ColumnTime + SET(timeExp TimeExpression) ColumnAssigment } type timeColumnImpl struct { @@ -152,6 +185,13 @@ func (i *timeColumnImpl) From(subQuery SelectTable) ColumnTime { return newTimeColumn } +func (i *timeColumnImpl) SET(timeExp TimeExpression) ColumnAssigment { + return columnAssigmentImpl{ + column: i, + expression: timeExp, + } +} + // TimeColumn creates named time column func TimeColumn(name string) ColumnTime { timeColumn := &timeColumnImpl{} @@ -183,6 +223,13 @@ func (i *timezColumnImpl) From(subQuery SelectTable) ColumnTimez { return newTimezColumn } +func (i *timezColumnImpl) SET(timezExp TimezExpression) ColumnAssigment { + return columnAssigmentImpl{ + column: i, + expression: timezExp, + } +} + // TimezColumn creates named time with time zone column. func TimezColumn(name string) ColumnTimez { timezColumn := &timezColumnImpl{} @@ -200,6 +247,7 @@ type ColumnTimestamp interface { Column From(subQuery SelectTable) ColumnTimestamp + SET(timestampExp TimestampExpression) ColumnAssigment } type timestampColumnImpl struct { @@ -215,6 +263,13 @@ func (i *timestampColumnImpl) From(subQuery SelectTable) ColumnTimestamp { return newTimestampColumn } +func (i *timestampColumnImpl) SET(timestampExp TimestampExpression) ColumnAssigment { + return columnAssigmentImpl{ + column: i, + expression: timestampExp, + } +} + // TimestampColumn creates named timestamp column func TimestampColumn(name string) ColumnTimestamp { timestampColumn := ×tampColumnImpl{} @@ -232,6 +287,7 @@ type ColumnTimestampz interface { Column From(subQuery SelectTable) ColumnTimestampz + SET(timestampzExp TimestampzExpression) ColumnAssigment } type timestampzColumnImpl struct { @@ -247,6 +303,13 @@ func (i *timestampzColumnImpl) From(subQuery SelectTable) ColumnTimestampz { return newTimestampzColumn } +func (i *timestampzColumnImpl) SET(timestampzExp TimestampzExpression) ColumnAssigment { + return columnAssigmentImpl{ + column: i, + expression: timestampzExp, + } +} + // TimestampzColumn creates named timestamp with time zone column. func TimestampzColumn(name string) ColumnTimestampz { timestampzColumn := ×tampzColumnImpl{} @@ -264,6 +327,7 @@ type ColumnDate interface { Column From(subQuery SelectTable) ColumnDate + SET(dateExp DateExpression) ColumnAssigment } type dateColumnImpl struct { @@ -279,6 +343,13 @@ func (i *dateColumnImpl) From(subQuery SelectTable) ColumnDate { return newDateColumn } +func (i *dateColumnImpl) SET(dateExp DateExpression) ColumnAssigment { + return columnAssigmentImpl{ + column: i, + expression: dateExp, + } +} + // DateColumn creates named date column. func DateColumn(name string) ColumnDate { dateColumn := &dateColumnImpl{} diff --git a/internal/testutils/test_utils.go b/internal/testutils/test_utils.go index aa53b1e..06035dd 100644 --- a/internal/testutils/test_utils.go +++ b/internal/testutils/test_utils.go @@ -24,12 +24,12 @@ import ( func AssertExec(t *testing.T, stmt jet.Statement, db qrm.DB, rowsAffected ...int64) { res, err := stmt.Exec(db) - assert.NoError(t, err) + require.NoError(t, err) rows, err := res.RowsAffected() - assert.NoError(t, err) + require.NoError(t, err) if len(rowsAffected) > 0 { - assert.Equal(t, rows, rowsAffected[0]) + require.Equal(t, rows, rowsAffected[0]) } } @@ -224,7 +224,7 @@ func AssertFileNamesEqual(t *testing.T, fileInfos []os.FileInfo, fileNames ...st // AssertDeepEqual checks if actual and expected objects are deeply equal. func AssertDeepEqual(t *testing.T, actual, expected interface{}, msg ...string) { - assert.True(t, cmp.Equal(actual, expected), msg) + require.True(t, cmp.Equal(actual, expected), msg) } // BoolPtr returns address of bool parameter diff --git a/mysql/insert_statement.go b/mysql/insert_statement.go index a21089c..a4ecc94 100644 --- a/mysql/insert_statement.go +++ b/mysql/insert_statement.go @@ -13,13 +13,15 @@ type InsertStatement interface { MODEL(data interface{}) InsertStatement MODELS(data interface{}) InsertStatement + ON_DUPLICATE_KEY_UPDATE(assigments ...ColumnAssigment) InsertStatement + QUERY(selectStatement SelectStatement) InsertStatement } func newInsertStatement(table Table, columns []jet.Column) InsertStatement { newInsert := &insertStatementImpl{} newInsert.SerializerStatement = jet.NewStatementImpl(Dialect, jet.InsertStatementType, newInsert, - &newInsert.Insert, &newInsert.ValuesQuery) + &newInsert.Insert, &newInsert.ValuesQuery, &newInsert.OnDuplicateKey) newInsert.Insert.Table = table newInsert.Insert.Columns = columns @@ -30,26 +32,55 @@ func newInsertStatement(table Table, columns []jet.Column) InsertStatement { type insertStatementImpl struct { jet.SerializerStatement - Insert jet.ClauseInsert - ValuesQuery jet.ClauseValuesQuery + Insert jet.ClauseInsert + ValuesQuery jet.ClauseValuesQuery + OnDuplicateKey onDuplicateKeyUpdateClause } -func (i *insertStatementImpl) VALUES(value interface{}, values ...interface{}) InsertStatement { - i.ValuesQuery.Rows = append(i.ValuesQuery.Rows, jet.UnwindRowFromValues(value, values)) - return i +func (is *insertStatementImpl) VALUES(value interface{}, values ...interface{}) InsertStatement { + is.ValuesQuery.Rows = append(is.ValuesQuery.Rows, jet.UnwindRowFromValues(value, values)) + return is } -func (i *insertStatementImpl) MODEL(data interface{}) InsertStatement { - i.ValuesQuery.Rows = append(i.ValuesQuery.Rows, jet.UnwindRowFromModel(i.Insert.GetColumns(), data)) - return i +func (is *insertStatementImpl) MODEL(data interface{}) InsertStatement { + is.ValuesQuery.Rows = append(is.ValuesQuery.Rows, jet.UnwindRowFromModel(is.Insert.GetColumns(), data)) + return is } -func (i *insertStatementImpl) MODELS(data interface{}) InsertStatement { - i.ValuesQuery.Rows = append(i.ValuesQuery.Rows, jet.UnwindRowsFromModels(i.Insert.GetColumns(), data)...) - return i +func (is *insertStatementImpl) MODELS(data interface{}) InsertStatement { + is.ValuesQuery.Rows = append(is.ValuesQuery.Rows, jet.UnwindRowsFromModels(is.Insert.GetColumns(), data)...) + return is } -func (i *insertStatementImpl) QUERY(selectStatement SelectStatement) InsertStatement { - i.ValuesQuery.Query = selectStatement - return i +func (is *insertStatementImpl) ON_DUPLICATE_KEY_UPDATE(assigments ...ColumnAssigment) InsertStatement { + is.OnDuplicateKey = assigments + return is +} + +func (is *insertStatementImpl) QUERY(selectStatement SelectStatement) InsertStatement { + is.ValuesQuery.Query = selectStatement + return is +} + +type onDuplicateKeyUpdateClause []jet.ColumnAssigment + +// Serialize for SetClause +func (s onDuplicateKeyUpdateClause) Serialize(statementType jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { + if len(s) == 0 { + return + } + out.NewLine() + out.WriteString("ON DUPLICATE KEY UPDATE") + out.IncreaseIdent(24) + + for i, assigment := range s { + if i > 0 { + out.WriteString(",") + out.NewLine() + } + + jet.Serialize(assigment, statementType, out, jet.ShortName.WithFallTrough(options)...) + } + + out.DecreaseIdent(24) } diff --git a/mysql/insert_statement_test.go b/mysql/insert_statement_test.go index 65c8fba..95814d2 100644 --- a/mysql/insert_statement_test.go +++ b/mysql/insert_statement_test.go @@ -133,3 +133,50 @@ VALUES (DEFAULT, ?); assertStatementSql(t, stmt, expectedSQL, "two") } + +func TestInsertOnDuplicateKeyUpdate(t *testing.T) { + stmt := func() InsertStatement { + return table1.INSERT(table1Col1, table1ColFloat). + VALUES(DEFAULT, "two") + } + + t.Run("empty list", func(t *testing.T) { + stmt := stmt().ON_DUPLICATE_KEY_UPDATE() + assertStatementSql(t, stmt, ` +INSERT INTO db.table1 (col1, col_float) +VALUES (DEFAULT, ?); +`, "two") + }) + + t.Run("one set", func(t *testing.T) { + stmt := stmt().ON_DUPLICATE_KEY_UPDATE(table1ColFloat.SET(Float(11.1))) + assertStatementSql(t, stmt, ` +INSERT INTO db.table1 (col1, col_float) +VALUES (DEFAULT, ?) +ON DUPLICATE KEY UPDATE col_float = ?; +`, "two", 11.1) + }) + + t.Run("all types set", func(t *testing.T) { + stmt := stmt().ON_DUPLICATE_KEY_UPDATE( + table1ColBool.SET(Bool(true)), + table1ColInt.SET(Int(11)), + table1ColFloat.SET(Float(11.1)), + table1ColString.SET(String("str")), + table1ColTime.SET(Time(11, 23, 11)), + table1ColTimestamp.SET(Timestamp(2020, 1, 22, 3, 4, 5)), + table1ColDate.SET(Date(2020, 12, 1)), + ) + assertStatementSql(t, stmt, ` +INSERT INTO db.table1 (col1, col_float) +VALUES (DEFAULT, ?) +ON DUPLICATE KEY UPDATE col_bool = ?, + col_int = ?, + col_float = ?, + col_string = ?, + col_time = CAST(? AS TIME), + col_timestamp = TIMESTAMP(?), + col_date = CAST(? AS DATE); +`, "two", true, int64(11), 11.1, "str", "11:23:11", "2020-01-22 03:04:05", "2020-12-01") + }) +} diff --git a/mysql/types.go b/mysql/types.go index 4ef84b4..908fce5 100644 --- a/mysql/types.go +++ b/mysql/types.go @@ -10,3 +10,6 @@ type Projection = jet.Projection // ProjectionList can be used to create conditional constructed projection list. type ProjectionList = jet.ProjectionList + +// ColumnAssigment is interface wrapper around column assigment +type ColumnAssigment = jet.ColumnAssigment diff --git a/mysql/utils_test.go b/mysql/utils_test.go index 709097d..584bfee 100644 --- a/mysql/utils_test.go +++ b/mysql/utils_test.go @@ -7,12 +7,14 @@ import ( ) var table1Col1 = IntegerColumn("col1") +var table1ColBool = BoolColumn("col_bool") var table1ColInt = IntegerColumn("col_int") var table1ColFloat = FloatColumn("col_float") +var table1ColString = StringColumn("col_string") var table1Col3 = IntegerColumn("col3") var table1ColTimestamp = TimestampColumn("col_timestamp") -var table1ColBool = BoolColumn("col_bool") var table1ColDate = DateColumn("col_date") +var table1ColTime = TimeColumn("col_time") var table1 = NewTable( "db", @@ -20,10 +22,12 @@ var table1 = NewTable( table1Col1, table1ColInt, table1ColFloat, + table1ColString, table1Col3, table1ColBool, table1ColDate, table1ColTimestamp, + table1ColTime, ) var table2Col3 = IntegerColumn("col3") diff --git a/postgres/clause_test.go b/postgres/clause_test.go index 7f64c61..5602505 100644 --- a/postgres/clause_test.go +++ b/postgres/clause_test.go @@ -21,11 +21,12 @@ ON CONFLICT (col_bool) DO NOTHING`) ON CONFLICT (col_bool) ON CONSTRAINT table_pkey DO NOTHING`) onConflict = &onConflictClause{indexExpressions: ColumnList{table1ColBool, table2ColFloat}} - onConflict.WHERE(table2ColFloat.ADD(table1ColInt).GT(table1ColFloat)).DO_UPDATE( - SET(table1ColBool, Bool(true)). - SET(table1ColInt, Int(1)). - WHERE(table2ColFloat.GT(Float(11.1))), - ) + onConflict.WHERE(table2ColFloat.ADD(table1ColInt).GT(table1ColFloat)). + DO_UPDATE( + SET(table1ColBool.SET(Bool(true)), + table1ColInt.SET(Int(11))). + WHERE(table2ColFloat.GT(Float(11.1))), + ) assertClauseSerialize(t, onConflict, ` ON CONFLICT (col_bool, col_float) WHERE (col_float + col_int) > col_float DO UPDATE SET col_bool = $1, diff --git a/postgres/conflict_action.go b/postgres/conflict_action.go index 5ff7a19..b7e9e2e 100644 --- a/postgres/conflict_action.go +++ b/postgres/conflict_action.go @@ -4,16 +4,15 @@ import "github.com/go-jet/jet/internal/jet" type conflictAction interface { jet.Serializer - SET(column jet.ColumnSerializer, expression interface{}) conflictAction WHERE(condition BoolExpression) conflictAction } // SET creates conflict action for ON_CONFLICT clause -func SET(column jet.ColumnSerializer, expression interface{}) conflictAction { +func SET(assigments ...ColumnAssigment) conflictAction { conflictAction := updateConflictActionImpl{} conflictAction.doUpdate = jet.KeywordClause{Keyword: "DO UPDATE"} conflictAction.Serializer = jet.NewSerializerClauseImpl(&conflictAction.doUpdate, &conflictAction.set, &conflictAction.where) - conflictAction.SET(column, expression) + conflictAction.set = assigments return &conflictAction } @@ -25,11 +24,6 @@ type updateConflictActionImpl struct { where jet.ClauseWhere } -func (u *updateConflictActionImpl) SET(column jet.ColumnSerializer, expression interface{}) conflictAction { - u.set = append(u.set, jet.SetPair{Column: column, Value: jet.ToSerializerValue(expression)}) - return u -} - func (u *updateConflictActionImpl) WHERE(condition BoolExpression) conflictAction { u.where.Condition = condition return u diff --git a/postgres/insert_statement_test.go b/postgres/insert_statement_test.go index d80c1a1..0761cf3 100644 --- a/postgres/insert_statement_test.go +++ b/postgres/insert_statement_test.go @@ -153,10 +153,10 @@ func TestInsert_ON_CONFLICT(t *testing.T) { VALUES("1", "2"). VALUES("theta", "beta"). ON_CONFLICT(table1ColBool).WHERE(table1ColBool.IS_NOT_FALSE()).DO_UPDATE( - SET(table1ColBool, "12"). - SET(table2ColInt, 1). - SET(ColumnList{table1Col1, table1ColBool}, jet.ROW(Int(2), String("two"))). - WHERE(table1Col1.GT(Int(2))), + SET(table1ColBool.SET(Bool(true)), + table2ColInt.SET(Int(1)), + ColumnList{table1Col1, table1ColBool}.SET(jet.ROW(Int(2), String("two"))), + ).WHERE(table1Col1.GT(Int(2))), ). RETURNING(table1Col1, table1ColBool) @@ -166,7 +166,7 @@ VALUES ('one', 'two'), ('1', '2'), ('theta', 'beta') ON CONFLICT (col_bool) WHERE col_bool IS NOT FALSE DO UPDATE - SET col_bool = '12', + SET col_bool = TRUE, col_int = 1, (col1, col_bool) = ROW(2, 'two') WHERE table1.col1 > 2 @@ -180,10 +180,10 @@ func TestInsert_ON_CONFLICT_ON_CONSTRAINT(t *testing.T) { VALUES("one", "two"). VALUES("1", "2"). ON_CONFLICT().ON_CONSTRAINT("idk_primary_key").DO_UPDATE( - SET(table1ColBool, "12"). - SET(table2ColInt, 1). - SET(ColumnList{table1Col1, table1ColBool}, jet.ROW(Int(2), String("two"))). - WHERE(table1Col1.GT(Int(2))), + SET(table1ColBool.SET(Bool(false)), + table2ColInt.SET(Int(1)), + ColumnList{table1Col1, table1ColBool}.SET(jet.ROW(Int(2), String("two"))), + ).WHERE(table1Col1.GT(Int(2))), ). RETURNING(table1Col1, table1ColBool) @@ -192,7 +192,7 @@ INSERT INTO db.table1 (col1, col_bool) VALUES ('one', 'two'), ('1', '2') ON CONFLICT ON CONSTRAINT idk_primary_key DO UPDATE - SET col_bool = '12', + SET col_bool = FALSE, col_int = 1, (col1, col_bool) = ROW(2, 'two') WHERE table1.col1 > 2 diff --git a/postgres/types.go b/postgres/types.go index 58a8ae9..48de455 100644 --- a/postgres/types.go +++ b/postgres/types.go @@ -10,3 +10,6 @@ type Projection = jet.Projection // ProjectionList can be used to create conditional constructed projection list. type ProjectionList = jet.ProjectionList + +// ColumnAssigment is interface wrapper around column assigment +type ColumnAssigment = jet.ColumnAssigment diff --git a/tests/mysql/alltypes_test.go b/tests/mysql/alltypes_test.go index 2f269ac..42e8aa3 100644 --- a/tests/mysql/alltypes_test.go +++ b/tests/mysql/alltypes_test.go @@ -991,6 +991,53 @@ func TestAllTypesInsert(t *testing.T) { require.NoError(t, err) } +func TestAllTypesInsertOnDuplicateKeyUpdate(t *testing.T) { + tx, err := db.Begin() + require.NoError(t, err) + + toInsert := model.AllTypes{ + Boolean: true, + Integer: 124, + Float: 45.67, + Blob: []byte("blob"), + Text: "text", + JSON: "{}", + Time: time.Now(), + Timestamp: time.Now(), + Date: time.Now(), + } + + stmt := AllTypes.INSERT( + AllTypes.Boolean, + AllTypes.Integer, + AllTypes.Float, + AllTypes.Blob, + AllTypes.Text, + AllTypes.JSON, + AllTypes.Time, + AllTypes.Timestamp, + AllTypes.Date, + ). + MODEL(toInsert). + ON_DUPLICATE_KEY_UPDATE( + AllTypes.Boolean.SET(Bool(false)), + AllTypes.Integer.SET(Int(4)), + AllTypes.Float.SET(Float(0.67)), + AllTypes.Text.SET(String("new text")), + AllTypes.Time.SET(TimeT(time.Now())), + AllTypes.Timestamp.SET(TimestampT(time.Now())), + AllTypes.Date.SET(DateT(time.Now())), + ) + + fmt.Println(stmt.DebugSql()) + + _, err = stmt.Exec(tx) + assert.NoError(t, err) + + err = tx.Rollback() + require.NoError(t, err) +} + var toInsert = model.AllTypes{ Boolean: false, BooleanPtr: testutils.BoolPtr(true), diff --git a/tests/mysql/insert_test.go b/tests/mysql/insert_test.go index cb27f61..0b1fd2e 100644 --- a/tests/mysql/insert_test.go +++ b/tests/mysql/insert_test.go @@ -7,6 +7,8 @@ import ( "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/model" . "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/table" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "math/rand" "testing" "time" ) @@ -248,6 +250,46 @@ INSERT INTO test_sample.link (url, name) ( assert.Equal(t, len(youtubeLinks), 2) } +func TestInsertOnDuplicateKey(t *testing.T) { + randId := rand.Int31() + + stmt := Link.INSERT(). + VALUES(randId, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). + VALUES(randId, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). + ON_DUPLICATE_KEY_UPDATE( + Link.ID.SET(Link.ID.ADD(Int(11))), + Link.Name.SET(String("PostgreSQL Tutorial 2")), + ) + + testutils.AssertStatementSql(t, stmt, ` +INSERT INTO test_sample.link +VALUES (?, ?, ?, DEFAULT), + (?, ?, ?, DEFAULT) +ON DUPLICATE KEY UPDATE id = (id + ?), + name = ?; +`, randId, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", + randId, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", + int64(11), "PostgreSQL Tutorial 2") + + testutils.AssertExec(t, stmt, db, 3) + + newLinks := []model.Link{} + + err := SELECT(Link.AllColumns). + FROM(Link). + WHERE(Link.ID.EQ(Int(int64(randId)).ADD(Int(11)))). + Query(db, &newLinks) + + require.NoError(t, err) + require.Len(t, newLinks, 1) + require.Equal(t, newLinks[0], model.Link{ + ID: randId + 11, + URL: "http://www.postgresqltutorial.com", + Name: "PostgreSQL Tutorial 2", + Description: nil, + }) +} + func TestInsertWithQueryContext(t *testing.T) { cleanUpLinkTable(t) diff --git a/tests/mysql/main_test.go b/tests/mysql/main_test.go index eb19a68..c7db884 100644 --- a/tests/mysql/main_test.go +++ b/tests/mysql/main_test.go @@ -4,6 +4,8 @@ import ( "database/sql" "flag" "github.com/go-jet/jet/tests/dbconfig" + "math/rand" + "time" _ "github.com/go-sql-driver/mysql" @@ -28,6 +30,7 @@ func sourceIsMariaDB() bool { } func TestMain(m *testing.M) { + rand.Seed(time.Now().Unix()) defer profile.Start().Stop() var err error diff --git a/tests/postgres/insert_test.go b/tests/postgres/insert_test.go index ecbd3c5..7367e1d 100644 --- a/tests/postgres/insert_test.go +++ b/tests/postgres/insert_test.go @@ -2,7 +2,6 @@ package postgres import ( "context" - "github.com/go-jet/jet/internal/jet" "github.com/go-jet/jet/internal/testutils" . "github.com/go-jet/jet/postgres" "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/model" @@ -134,8 +133,10 @@ ON CONFLICT ON CONSTRAINT employee_pkey DO NOTHING; VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). VALUES(200, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). ON_CONFLICT(Link.ID).DO_UPDATE( - SET(Link.ID, Link.EXCLUDED.ID). - SET(Link.URL, "http://www.postgresqltutorial2.com"), + SET( + Link.ID.SET(Link.EXCLUDED.ID), + Link.URL.SET(String("http://www.postgresqltutorial2.com")), + ), ). RETURNING(Link.AllColumns) @@ -161,8 +162,10 @@ RETURNING link.id AS "link.id", VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). VALUES(200, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). ON_CONFLICT().ON_CONSTRAINT("link_pkey").DO_UPDATE( - SET(Link.ID, Link.EXCLUDED.ID). - SET(Link.URL, "http://www.postgresqltutorial2.com"), + SET( + Link.ID.SET(Link.EXCLUDED.ID), + Link.URL.SET(String("http://www.postgresqltutorial2.com")), + ), ). RETURNING(Link.AllColumns) @@ -188,9 +191,13 @@ RETURNING link.id AS "link.id", stmt := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description). VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). ON_CONFLICT(Link.ID).WHERE(Link.ID.MUL(Int(2)).GT(Int(10))).DO_UPDATE( - SET(Link.ID, SELECT(MAXi(Link.ID).ADD(Int(1))).FROM(Link)). - SET(ColumnList{Link.Name, Link.Description}, jet.ROW(Link.EXCLUDED.Name, String("new description"))). - WHERE(Link.Description.IS_NOT_NULL()), + SET( + Link.ID.SET( + IntExp(SELECT(MAXi(Link.ID).ADD(Int(1))). + FROM(Link)), + ), + ColumnList{Link.Name, Link.Description}.SET(ROW(Link.EXCLUDED.Name, String("new description"))), + ).WHERE(Link.Description.IS_NOT_NULL()), ) testutils.AssertDebugStatementSql(t, stmt, `