Add ON DUPLICATE KEY UPDATE support (MySQL).

This commit is contained in:
go-jet 2020-05-03 20:46:21 +02:00
parent 30284af33e
commit 980b9b6aac
18 changed files with 388 additions and 109 deletions

View file

@ -518,7 +518,7 @@ type SetPair struct {
} }
// SetClause clause // SetClause clause
type SetClause []SetPair type SetClause []ColumnAssigment
// Serialize for SetClause // Serialize for SetClause
func (s SetClause) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { func (s SetClause) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
@ -526,16 +526,15 @@ func (s SetClause) Serialize(statementType StatementType, out *SQLBuilder, optio
out.WriteString("SET") out.WriteString("SET")
out.IncreaseIdent(4) out.IncreaseIdent(4)
for i, pair := range s { for i, assigment := range s {
if i > 0 { if i > 0 {
out.WriteString(",") out.WriteString(",")
out.NewLine() out.NewLine()
} }
pair.Column.serialize(statementType, out, ShortName.WithFallTrough(options)...) assigment.serialize(statementType, out, FallTrough(options)...)
out.WriteString("=")
pair.Value.serialize(statementType, out, FallTrough(options)...)
} }
out.DecreaseIdent(4) out.DecreaseIdent(4)
} }

View file

@ -115,56 +115,3 @@ func (c ColumnExpressionImpl) serialize(statement StatementType, out *SQLBuilder
out.WriteIdentifier(c.name) out.WriteIdentifier(c.name)
} }
} }
//------------------------------------------------------//
// ColumnList is a helper type to support list of columns as single projection
type ColumnList []ColumnExpression
func (cl ColumnList) fromImpl(subQuery SelectTable) Projection {
newProjectionList := ProjectionList{}
for _, column := range cl {
newProjectionList = append(newProjectionList, column.fromImpl(subQuery))
}
return newProjectionList
}
func (cl ColumnList) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteString("(")
for i, column := range cl {
if i > 0 {
out.WriteString(", ")
}
column.serialize(statement, out, FallTrough(options)...)
}
out.WriteString(")")
}
func (cl ColumnList) serializeForProjection(statement StatementType, out *SQLBuilder) {
projections := ColumnListToProjectionList(cl)
SerializeProjectionList(statement, projections, out)
}
// dummy column interface implementation
// Name is placeholder for ColumnList to implement Column interface
func (cl ColumnList) Name() string { return "" }
// TableName is placeholder for ColumnList to implement Column interface
func (cl ColumnList) TableName() string { return "" }
func (cl ColumnList) setTableName(name string) {}
func (cl ColumnList) setSubQuery(subQuery SelectTable) {}
func (cl ColumnList) defaultAlias() string { return "" }
// SetTableName is utility function to set table name from outside of jet package to avoid making public setTableName
func SetTableName(columnExpression ColumnExpression, tableName string) {
columnExpression.setTableName(tableName)
}
// SetSubQuery is utility function to set table name from outside of jet package to avoid making public setSubQuery
func SetSubQuery(columnExpression ColumnExpression, subQuery SelectTable) {
columnExpression.setSubQuery(subQuery)
}

View file

@ -0,0 +1,20 @@
package jet
// ColumnAssigment is interface wrapper around column assigment
type ColumnAssigment interface {
Serializer
isColumnAssigment()
}
type columnAssigmentImpl struct {
column ColumnSerializer
expression Expression
}
func (a columnAssigmentImpl) isColumnAssigment() {}
func (a columnAssigmentImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
a.column.serialize(statement, out, ShortName.WithFallTrough(options)...)
out.WriteString("=")
a.expression.serialize(statement, out, FallTrough(options)...)
}

View file

@ -0,0 +1,60 @@
package jet
// ColumnList is a helper type to support list of columns as single projection
type ColumnList []ColumnExpression
// SET creates column assigment for each column in column list. expression should be created by ROW function
func (cl ColumnList) SET(expression Expression) ColumnAssigment {
return columnAssigmentImpl{
column: cl,
expression: expression,
}
}
func (cl ColumnList) fromImpl(subQuery SelectTable) Projection {
newProjectionList := ProjectionList{}
for _, column := range cl {
newProjectionList = append(newProjectionList, column.fromImpl(subQuery))
}
return newProjectionList
}
func (cl ColumnList) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteString("(")
for i, column := range cl {
if i > 0 {
out.WriteString(", ")
}
column.serialize(statement, out, FallTrough(options)...)
}
out.WriteString(")")
}
func (cl ColumnList) serializeForProjection(statement StatementType, out *SQLBuilder) {
projections := ColumnListToProjectionList(cl)
SerializeProjectionList(statement, projections, out)
}
// dummy column interface implementation
// Name is placeholder for ColumnList to implement Column interface
func (cl ColumnList) Name() string { return "" }
// TableName is placeholder for ColumnList to implement Column interface
func (cl ColumnList) TableName() string { return "" }
func (cl ColumnList) setTableName(name string) {}
func (cl ColumnList) setSubQuery(subQuery SelectTable) {}
func (cl ColumnList) defaultAlias() string { return "" }
// SetTableName is utility function to set table name from outside of jet package to avoid making public setTableName
func SetTableName(columnExpression ColumnExpression, tableName string) {
columnExpression.setTableName(tableName)
}
// SetSubQuery is utility function to set table name from outside of jet package to avoid making public setSubQuery
func SetSubQuery(columnExpression ColumnExpression, subQuery SelectTable) {
columnExpression.setSubQuery(subQuery)
}

View file

@ -6,6 +6,7 @@ type ColumnBool interface {
Column Column
From(subQuery SelectTable) ColumnBool From(subQuery SelectTable) ColumnBool
SET(boolExp BoolExpression) ColumnAssigment
} }
type boolColumnImpl struct { type boolColumnImpl struct {
@ -21,6 +22,13 @@ func (i *boolColumnImpl) From(subQuery SelectTable) ColumnBool {
return newBoolColumn return newBoolColumn
} }
func (i *boolColumnImpl) SET(boolExp BoolExpression) ColumnAssigment {
return columnAssigmentImpl{
column: i,
expression: boolExp,
}
}
// BoolColumn creates named bool column. // BoolColumn creates named bool column.
func BoolColumn(name string) ColumnBool { func BoolColumn(name string) ColumnBool {
boolColumn := &boolColumnImpl{} boolColumn := &boolColumnImpl{}
@ -38,6 +46,7 @@ type ColumnFloat interface {
Column Column
From(subQuery SelectTable) ColumnFloat From(subQuery SelectTable) ColumnFloat
SET(floatExp FloatExpression) ColumnAssigment
} }
type floatColumnImpl struct { type floatColumnImpl struct {
@ -53,6 +62,13 @@ func (i *floatColumnImpl) From(subQuery SelectTable) ColumnFloat {
return newFloatColumn return newFloatColumn
} }
func (i *floatColumnImpl) SET(floatExp FloatExpression) ColumnAssigment {
return columnAssigmentImpl{
column: i,
expression: floatExp,
}
}
// FloatColumn creates named float column. // FloatColumn creates named float column.
func FloatColumn(name string) ColumnFloat { func FloatColumn(name string) ColumnFloat {
floatColumn := &floatColumnImpl{} floatColumn := &floatColumnImpl{}
@ -70,6 +86,7 @@ type ColumnInteger interface {
Column Column
From(subQuery SelectTable) ColumnInteger From(subQuery SelectTable) ColumnInteger
SET(intExp IntegerExpression) ColumnAssigment
} }
type integerColumnImpl struct { type integerColumnImpl struct {
@ -86,6 +103,13 @@ func (i *integerColumnImpl) From(subQuery SelectTable) ColumnInteger {
return newIntColumn return newIntColumn
} }
func (i *integerColumnImpl) SET(intExp IntegerExpression) ColumnAssigment {
return columnAssigmentImpl{
column: i,
expression: intExp,
}
}
// IntegerColumn creates named integer column. // IntegerColumn creates named integer column.
func IntegerColumn(name string) ColumnInteger { func IntegerColumn(name string) ColumnInteger {
integerColumn := &integerColumnImpl{} integerColumn := &integerColumnImpl{}
@ -104,6 +128,7 @@ type ColumnString interface {
Column Column
From(subQuery SelectTable) ColumnString From(subQuery SelectTable) ColumnString
SET(stringExp StringExpression) ColumnAssigment
} }
type stringColumnImpl struct { type stringColumnImpl struct {
@ -120,6 +145,13 @@ func (i *stringColumnImpl) From(subQuery SelectTable) ColumnString {
return newStrColumn return newStrColumn
} }
func (i *stringColumnImpl) SET(stringExp StringExpression) ColumnAssigment {
return columnAssigmentImpl{
column: i,
expression: stringExp,
}
}
// StringColumn creates named string column. // StringColumn creates named string column.
func StringColumn(name string) ColumnString { func StringColumn(name string) ColumnString {
stringColumn := &stringColumnImpl{} stringColumn := &stringColumnImpl{}
@ -137,6 +169,7 @@ type ColumnTime interface {
Column Column
From(subQuery SelectTable) ColumnTime From(subQuery SelectTable) ColumnTime
SET(timeExp TimeExpression) ColumnAssigment
} }
type timeColumnImpl struct { type timeColumnImpl struct {
@ -152,6 +185,13 @@ func (i *timeColumnImpl) From(subQuery SelectTable) ColumnTime {
return newTimeColumn return newTimeColumn
} }
func (i *timeColumnImpl) SET(timeExp TimeExpression) ColumnAssigment {
return columnAssigmentImpl{
column: i,
expression: timeExp,
}
}
// TimeColumn creates named time column // TimeColumn creates named time column
func TimeColumn(name string) ColumnTime { func TimeColumn(name string) ColumnTime {
timeColumn := &timeColumnImpl{} timeColumn := &timeColumnImpl{}
@ -183,6 +223,13 @@ func (i *timezColumnImpl) From(subQuery SelectTable) ColumnTimez {
return newTimezColumn return newTimezColumn
} }
func (i *timezColumnImpl) SET(timezExp TimezExpression) ColumnAssigment {
return columnAssigmentImpl{
column: i,
expression: timezExp,
}
}
// TimezColumn creates named time with time zone column. // TimezColumn creates named time with time zone column.
func TimezColumn(name string) ColumnTimez { func TimezColumn(name string) ColumnTimez {
timezColumn := &timezColumnImpl{} timezColumn := &timezColumnImpl{}
@ -200,6 +247,7 @@ type ColumnTimestamp interface {
Column Column
From(subQuery SelectTable) ColumnTimestamp From(subQuery SelectTable) ColumnTimestamp
SET(timestampExp TimestampExpression) ColumnAssigment
} }
type timestampColumnImpl struct { type timestampColumnImpl struct {
@ -215,6 +263,13 @@ func (i *timestampColumnImpl) From(subQuery SelectTable) ColumnTimestamp {
return newTimestampColumn return newTimestampColumn
} }
func (i *timestampColumnImpl) SET(timestampExp TimestampExpression) ColumnAssigment {
return columnAssigmentImpl{
column: i,
expression: timestampExp,
}
}
// TimestampColumn creates named timestamp column // TimestampColumn creates named timestamp column
func TimestampColumn(name string) ColumnTimestamp { func TimestampColumn(name string) ColumnTimestamp {
timestampColumn := &timestampColumnImpl{} timestampColumn := &timestampColumnImpl{}
@ -232,6 +287,7 @@ type ColumnTimestampz interface {
Column Column
From(subQuery SelectTable) ColumnTimestampz From(subQuery SelectTable) ColumnTimestampz
SET(timestampzExp TimestampzExpression) ColumnAssigment
} }
type timestampzColumnImpl struct { type timestampzColumnImpl struct {
@ -247,6 +303,13 @@ func (i *timestampzColumnImpl) From(subQuery SelectTable) ColumnTimestampz {
return newTimestampzColumn return newTimestampzColumn
} }
func (i *timestampzColumnImpl) SET(timestampzExp TimestampzExpression) ColumnAssigment {
return columnAssigmentImpl{
column: i,
expression: timestampzExp,
}
}
// TimestampzColumn creates named timestamp with time zone column. // TimestampzColumn creates named timestamp with time zone column.
func TimestampzColumn(name string) ColumnTimestampz { func TimestampzColumn(name string) ColumnTimestampz {
timestampzColumn := &timestampzColumnImpl{} timestampzColumn := &timestampzColumnImpl{}
@ -264,6 +327,7 @@ type ColumnDate interface {
Column Column
From(subQuery SelectTable) ColumnDate From(subQuery SelectTable) ColumnDate
SET(dateExp DateExpression) ColumnAssigment
} }
type dateColumnImpl struct { type dateColumnImpl struct {
@ -279,6 +343,13 @@ func (i *dateColumnImpl) From(subQuery SelectTable) ColumnDate {
return newDateColumn return newDateColumn
} }
func (i *dateColumnImpl) SET(dateExp DateExpression) ColumnAssigment {
return columnAssigmentImpl{
column: i,
expression: dateExp,
}
}
// DateColumn creates named date column. // DateColumn creates named date column.
func DateColumn(name string) ColumnDate { func DateColumn(name string) ColumnDate {
dateColumn := &dateColumnImpl{} dateColumn := &dateColumnImpl{}

View file

@ -24,12 +24,12 @@ import (
func AssertExec(t *testing.T, stmt jet.Statement, db qrm.DB, rowsAffected ...int64) { func AssertExec(t *testing.T, stmt jet.Statement, db qrm.DB, rowsAffected ...int64) {
res, err := stmt.Exec(db) res, err := stmt.Exec(db)
assert.NoError(t, err) require.NoError(t, err)
rows, err := res.RowsAffected() rows, err := res.RowsAffected()
assert.NoError(t, err) require.NoError(t, err)
if len(rowsAffected) > 0 { if len(rowsAffected) > 0 {
assert.Equal(t, rows, rowsAffected[0]) require.Equal(t, rows, rowsAffected[0])
} }
} }
@ -224,7 +224,7 @@ func AssertFileNamesEqual(t *testing.T, fileInfos []os.FileInfo, fileNames ...st
// AssertDeepEqual checks if actual and expected objects are deeply equal. // AssertDeepEqual checks if actual and expected objects are deeply equal.
func AssertDeepEqual(t *testing.T, actual, expected interface{}, msg ...string) { func AssertDeepEqual(t *testing.T, actual, expected interface{}, msg ...string) {
assert.True(t, cmp.Equal(actual, expected), msg) require.True(t, cmp.Equal(actual, expected), msg)
} }
// BoolPtr returns address of bool parameter // BoolPtr returns address of bool parameter

View file

@ -13,13 +13,15 @@ type InsertStatement interface {
MODEL(data interface{}) InsertStatement MODEL(data interface{}) InsertStatement
MODELS(data interface{}) InsertStatement MODELS(data interface{}) InsertStatement
ON_DUPLICATE_KEY_UPDATE(assigments ...ColumnAssigment) InsertStatement
QUERY(selectStatement SelectStatement) InsertStatement QUERY(selectStatement SelectStatement) InsertStatement
} }
func newInsertStatement(table Table, columns []jet.Column) InsertStatement { func newInsertStatement(table Table, columns []jet.Column) InsertStatement {
newInsert := &insertStatementImpl{} newInsert := &insertStatementImpl{}
newInsert.SerializerStatement = jet.NewStatementImpl(Dialect, jet.InsertStatementType, newInsert, newInsert.SerializerStatement = jet.NewStatementImpl(Dialect, jet.InsertStatementType, newInsert,
&newInsert.Insert, &newInsert.ValuesQuery) &newInsert.Insert, &newInsert.ValuesQuery, &newInsert.OnDuplicateKey)
newInsert.Insert.Table = table newInsert.Insert.Table = table
newInsert.Insert.Columns = columns newInsert.Insert.Columns = columns
@ -30,26 +32,55 @@ func newInsertStatement(table Table, columns []jet.Column) InsertStatement {
type insertStatementImpl struct { type insertStatementImpl struct {
jet.SerializerStatement jet.SerializerStatement
Insert jet.ClauseInsert Insert jet.ClauseInsert
ValuesQuery jet.ClauseValuesQuery ValuesQuery jet.ClauseValuesQuery
OnDuplicateKey onDuplicateKeyUpdateClause
} }
func (i *insertStatementImpl) VALUES(value interface{}, values ...interface{}) InsertStatement { func (is *insertStatementImpl) VALUES(value interface{}, values ...interface{}) InsertStatement {
i.ValuesQuery.Rows = append(i.ValuesQuery.Rows, jet.UnwindRowFromValues(value, values)) is.ValuesQuery.Rows = append(is.ValuesQuery.Rows, jet.UnwindRowFromValues(value, values))
return i return is
} }
func (i *insertStatementImpl) MODEL(data interface{}) InsertStatement { func (is *insertStatementImpl) MODEL(data interface{}) InsertStatement {
i.ValuesQuery.Rows = append(i.ValuesQuery.Rows, jet.UnwindRowFromModel(i.Insert.GetColumns(), data)) is.ValuesQuery.Rows = append(is.ValuesQuery.Rows, jet.UnwindRowFromModel(is.Insert.GetColumns(), data))
return i return is
} }
func (i *insertStatementImpl) MODELS(data interface{}) InsertStatement { func (is *insertStatementImpl) MODELS(data interface{}) InsertStatement {
i.ValuesQuery.Rows = append(i.ValuesQuery.Rows, jet.UnwindRowsFromModels(i.Insert.GetColumns(), data)...) is.ValuesQuery.Rows = append(is.ValuesQuery.Rows, jet.UnwindRowsFromModels(is.Insert.GetColumns(), data)...)
return i return is
} }
func (i *insertStatementImpl) QUERY(selectStatement SelectStatement) InsertStatement { func (is *insertStatementImpl) ON_DUPLICATE_KEY_UPDATE(assigments ...ColumnAssigment) InsertStatement {
i.ValuesQuery.Query = selectStatement is.OnDuplicateKey = assigments
return i return is
}
func (is *insertStatementImpl) QUERY(selectStatement SelectStatement) InsertStatement {
is.ValuesQuery.Query = selectStatement
return is
}
type onDuplicateKeyUpdateClause []jet.ColumnAssigment
// Serialize for SetClause
func (s onDuplicateKeyUpdateClause) Serialize(statementType jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(s) == 0 {
return
}
out.NewLine()
out.WriteString("ON DUPLICATE KEY UPDATE")
out.IncreaseIdent(24)
for i, assigment := range s {
if i > 0 {
out.WriteString(",")
out.NewLine()
}
jet.Serialize(assigment, statementType, out, jet.ShortName.WithFallTrough(options)...)
}
out.DecreaseIdent(24)
} }

View file

@ -133,3 +133,50 @@ VALUES (DEFAULT, ?);
assertStatementSql(t, stmt, expectedSQL, "two") assertStatementSql(t, stmt, expectedSQL, "two")
} }
func TestInsertOnDuplicateKeyUpdate(t *testing.T) {
stmt := func() InsertStatement {
return table1.INSERT(table1Col1, table1ColFloat).
VALUES(DEFAULT, "two")
}
t.Run("empty list", func(t *testing.T) {
stmt := stmt().ON_DUPLICATE_KEY_UPDATE()
assertStatementSql(t, stmt, `
INSERT INTO db.table1 (col1, col_float)
VALUES (DEFAULT, ?);
`, "two")
})
t.Run("one set", func(t *testing.T) {
stmt := stmt().ON_DUPLICATE_KEY_UPDATE(table1ColFloat.SET(Float(11.1)))
assertStatementSql(t, stmt, `
INSERT INTO db.table1 (col1, col_float)
VALUES (DEFAULT, ?)
ON DUPLICATE KEY UPDATE col_float = ?;
`, "two", 11.1)
})
t.Run("all types set", func(t *testing.T) {
stmt := stmt().ON_DUPLICATE_KEY_UPDATE(
table1ColBool.SET(Bool(true)),
table1ColInt.SET(Int(11)),
table1ColFloat.SET(Float(11.1)),
table1ColString.SET(String("str")),
table1ColTime.SET(Time(11, 23, 11)),
table1ColTimestamp.SET(Timestamp(2020, 1, 22, 3, 4, 5)),
table1ColDate.SET(Date(2020, 12, 1)),
)
assertStatementSql(t, stmt, `
INSERT INTO db.table1 (col1, col_float)
VALUES (DEFAULT, ?)
ON DUPLICATE KEY UPDATE col_bool = ?,
col_int = ?,
col_float = ?,
col_string = ?,
col_time = CAST(? AS TIME),
col_timestamp = TIMESTAMP(?),
col_date = CAST(? AS DATE);
`, "two", true, int64(11), 11.1, "str", "11:23:11", "2020-01-22 03:04:05", "2020-12-01")
})
}

View file

@ -10,3 +10,6 @@ type Projection = jet.Projection
// ProjectionList can be used to create conditional constructed projection list. // ProjectionList can be used to create conditional constructed projection list.
type ProjectionList = jet.ProjectionList type ProjectionList = jet.ProjectionList
// ColumnAssigment is interface wrapper around column assigment
type ColumnAssigment = jet.ColumnAssigment

View file

@ -7,12 +7,14 @@ import (
) )
var table1Col1 = IntegerColumn("col1") var table1Col1 = IntegerColumn("col1")
var table1ColBool = BoolColumn("col_bool")
var table1ColInt = IntegerColumn("col_int") var table1ColInt = IntegerColumn("col_int")
var table1ColFloat = FloatColumn("col_float") var table1ColFloat = FloatColumn("col_float")
var table1ColString = StringColumn("col_string")
var table1Col3 = IntegerColumn("col3") var table1Col3 = IntegerColumn("col3")
var table1ColTimestamp = TimestampColumn("col_timestamp") var table1ColTimestamp = TimestampColumn("col_timestamp")
var table1ColBool = BoolColumn("col_bool")
var table1ColDate = DateColumn("col_date") var table1ColDate = DateColumn("col_date")
var table1ColTime = TimeColumn("col_time")
var table1 = NewTable( var table1 = NewTable(
"db", "db",
@ -20,10 +22,12 @@ var table1 = NewTable(
table1Col1, table1Col1,
table1ColInt, table1ColInt,
table1ColFloat, table1ColFloat,
table1ColString,
table1Col3, table1Col3,
table1ColBool, table1ColBool,
table1ColDate, table1ColDate,
table1ColTimestamp, table1ColTimestamp,
table1ColTime,
) )
var table2Col3 = IntegerColumn("col3") var table2Col3 = IntegerColumn("col3")

View file

@ -21,11 +21,12 @@ ON CONFLICT (col_bool) DO NOTHING`)
ON CONFLICT (col_bool) ON CONSTRAINT table_pkey DO NOTHING`) ON CONFLICT (col_bool) ON CONSTRAINT table_pkey DO NOTHING`)
onConflict = &onConflictClause{indexExpressions: ColumnList{table1ColBool, table2ColFloat}} onConflict = &onConflictClause{indexExpressions: ColumnList{table1ColBool, table2ColFloat}}
onConflict.WHERE(table2ColFloat.ADD(table1ColInt).GT(table1ColFloat)).DO_UPDATE( onConflict.WHERE(table2ColFloat.ADD(table1ColInt).GT(table1ColFloat)).
SET(table1ColBool, Bool(true)). DO_UPDATE(
SET(table1ColInt, Int(1)). SET(table1ColBool.SET(Bool(true)),
WHERE(table2ColFloat.GT(Float(11.1))), table1ColInt.SET(Int(11))).
) WHERE(table2ColFloat.GT(Float(11.1))),
)
assertClauseSerialize(t, onConflict, ` assertClauseSerialize(t, onConflict, `
ON CONFLICT (col_bool, col_float) WHERE (col_float + col_int) > col_float DO UPDATE ON CONFLICT (col_bool, col_float) WHERE (col_float + col_int) > col_float DO UPDATE
SET col_bool = $1, SET col_bool = $1,

View file

@ -4,16 +4,15 @@ import "github.com/go-jet/jet/internal/jet"
type conflictAction interface { type conflictAction interface {
jet.Serializer jet.Serializer
SET(column jet.ColumnSerializer, expression interface{}) conflictAction
WHERE(condition BoolExpression) conflictAction WHERE(condition BoolExpression) conflictAction
} }
// SET creates conflict action for ON_CONFLICT clause // SET creates conflict action for ON_CONFLICT clause
func SET(column jet.ColumnSerializer, expression interface{}) conflictAction { func SET(assigments ...ColumnAssigment) conflictAction {
conflictAction := updateConflictActionImpl{} conflictAction := updateConflictActionImpl{}
conflictAction.doUpdate = jet.KeywordClause{Keyword: "DO UPDATE"} conflictAction.doUpdate = jet.KeywordClause{Keyword: "DO UPDATE"}
conflictAction.Serializer = jet.NewSerializerClauseImpl(&conflictAction.doUpdate, &conflictAction.set, &conflictAction.where) conflictAction.Serializer = jet.NewSerializerClauseImpl(&conflictAction.doUpdate, &conflictAction.set, &conflictAction.where)
conflictAction.SET(column, expression) conflictAction.set = assigments
return &conflictAction return &conflictAction
} }
@ -25,11 +24,6 @@ type updateConflictActionImpl struct {
where jet.ClauseWhere where jet.ClauseWhere
} }
func (u *updateConflictActionImpl) SET(column jet.ColumnSerializer, expression interface{}) conflictAction {
u.set = append(u.set, jet.SetPair{Column: column, Value: jet.ToSerializerValue(expression)})
return u
}
func (u *updateConflictActionImpl) WHERE(condition BoolExpression) conflictAction { func (u *updateConflictActionImpl) WHERE(condition BoolExpression) conflictAction {
u.where.Condition = condition u.where.Condition = condition
return u return u

View file

@ -153,10 +153,10 @@ func TestInsert_ON_CONFLICT(t *testing.T) {
VALUES("1", "2"). VALUES("1", "2").
VALUES("theta", "beta"). VALUES("theta", "beta").
ON_CONFLICT(table1ColBool).WHERE(table1ColBool.IS_NOT_FALSE()).DO_UPDATE( ON_CONFLICT(table1ColBool).WHERE(table1ColBool.IS_NOT_FALSE()).DO_UPDATE(
SET(table1ColBool, "12"). SET(table1ColBool.SET(Bool(true)),
SET(table2ColInt, 1). table2ColInt.SET(Int(1)),
SET(ColumnList{table1Col1, table1ColBool}, jet.ROW(Int(2), String("two"))). ColumnList{table1Col1, table1ColBool}.SET(jet.ROW(Int(2), String("two"))),
WHERE(table1Col1.GT(Int(2))), ).WHERE(table1Col1.GT(Int(2))),
). ).
RETURNING(table1Col1, table1ColBool) RETURNING(table1Col1, table1ColBool)
@ -166,7 +166,7 @@ VALUES ('one', 'two'),
('1', '2'), ('1', '2'),
('theta', 'beta') ('theta', 'beta')
ON CONFLICT (col_bool) WHERE col_bool IS NOT FALSE DO UPDATE ON CONFLICT (col_bool) WHERE col_bool IS NOT FALSE DO UPDATE
SET col_bool = '12', SET col_bool = TRUE,
col_int = 1, col_int = 1,
(col1, col_bool) = ROW(2, 'two') (col1, col_bool) = ROW(2, 'two')
WHERE table1.col1 > 2 WHERE table1.col1 > 2
@ -180,10 +180,10 @@ func TestInsert_ON_CONFLICT_ON_CONSTRAINT(t *testing.T) {
VALUES("one", "two"). VALUES("one", "two").
VALUES("1", "2"). VALUES("1", "2").
ON_CONFLICT().ON_CONSTRAINT("idk_primary_key").DO_UPDATE( ON_CONFLICT().ON_CONSTRAINT("idk_primary_key").DO_UPDATE(
SET(table1ColBool, "12"). SET(table1ColBool.SET(Bool(false)),
SET(table2ColInt, 1). table2ColInt.SET(Int(1)),
SET(ColumnList{table1Col1, table1ColBool}, jet.ROW(Int(2), String("two"))). ColumnList{table1Col1, table1ColBool}.SET(jet.ROW(Int(2), String("two"))),
WHERE(table1Col1.GT(Int(2))), ).WHERE(table1Col1.GT(Int(2))),
). ).
RETURNING(table1Col1, table1ColBool) RETURNING(table1Col1, table1ColBool)
@ -192,7 +192,7 @@ INSERT INTO db.table1 (col1, col_bool)
VALUES ('one', 'two'), VALUES ('one', 'two'),
('1', '2') ('1', '2')
ON CONFLICT ON CONSTRAINT idk_primary_key DO UPDATE ON CONFLICT ON CONSTRAINT idk_primary_key DO UPDATE
SET col_bool = '12', SET col_bool = FALSE,
col_int = 1, col_int = 1,
(col1, col_bool) = ROW(2, 'two') (col1, col_bool) = ROW(2, 'two')
WHERE table1.col1 > 2 WHERE table1.col1 > 2

View file

@ -10,3 +10,6 @@ type Projection = jet.Projection
// ProjectionList can be used to create conditional constructed projection list. // ProjectionList can be used to create conditional constructed projection list.
type ProjectionList = jet.ProjectionList type ProjectionList = jet.ProjectionList
// ColumnAssigment is interface wrapper around column assigment
type ColumnAssigment = jet.ColumnAssigment

View file

@ -991,6 +991,53 @@ func TestAllTypesInsert(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
} }
func TestAllTypesInsertOnDuplicateKeyUpdate(t *testing.T) {
tx, err := db.Begin()
require.NoError(t, err)
toInsert := model.AllTypes{
Boolean: true,
Integer: 124,
Float: 45.67,
Blob: []byte("blob"),
Text: "text",
JSON: "{}",
Time: time.Now(),
Timestamp: time.Now(),
Date: time.Now(),
}
stmt := AllTypes.INSERT(
AllTypes.Boolean,
AllTypes.Integer,
AllTypes.Float,
AllTypes.Blob,
AllTypes.Text,
AllTypes.JSON,
AllTypes.Time,
AllTypes.Timestamp,
AllTypes.Date,
).
MODEL(toInsert).
ON_DUPLICATE_KEY_UPDATE(
AllTypes.Boolean.SET(Bool(false)),
AllTypes.Integer.SET(Int(4)),
AllTypes.Float.SET(Float(0.67)),
AllTypes.Text.SET(String("new text")),
AllTypes.Time.SET(TimeT(time.Now())),
AllTypes.Timestamp.SET(TimestampT(time.Now())),
AllTypes.Date.SET(DateT(time.Now())),
)
fmt.Println(stmt.DebugSql())
_, err = stmt.Exec(tx)
assert.NoError(t, err)
err = tx.Rollback()
require.NoError(t, err)
}
var toInsert = model.AllTypes{ var toInsert = model.AllTypes{
Boolean: false, Boolean: false,
BooleanPtr: testutils.BoolPtr(true), BooleanPtr: testutils.BoolPtr(true),

View file

@ -7,6 +7,8 @@ import (
"github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/model" "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/model"
. "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/table" . "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/table"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"math/rand"
"testing" "testing"
"time" "time"
) )
@ -248,6 +250,46 @@ INSERT INTO test_sample.link (url, name) (
assert.Equal(t, len(youtubeLinks), 2) assert.Equal(t, len(youtubeLinks), 2)
} }
func TestInsertOnDuplicateKey(t *testing.T) {
randId := rand.Int31()
stmt := Link.INSERT().
VALUES(randId, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT).
VALUES(randId, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT).
ON_DUPLICATE_KEY_UPDATE(
Link.ID.SET(Link.ID.ADD(Int(11))),
Link.Name.SET(String("PostgreSQL Tutorial 2")),
)
testutils.AssertStatementSql(t, stmt, `
INSERT INTO test_sample.link
VALUES (?, ?, ?, DEFAULT),
(?, ?, ?, DEFAULT)
ON DUPLICATE KEY UPDATE id = (id + ?),
name = ?;
`, randId, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial",
randId, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial",
int64(11), "PostgreSQL Tutorial 2")
testutils.AssertExec(t, stmt, db, 3)
newLinks := []model.Link{}
err := SELECT(Link.AllColumns).
FROM(Link).
WHERE(Link.ID.EQ(Int(int64(randId)).ADD(Int(11)))).
Query(db, &newLinks)
require.NoError(t, err)
require.Len(t, newLinks, 1)
require.Equal(t, newLinks[0], model.Link{
ID: randId + 11,
URL: "http://www.postgresqltutorial.com",
Name: "PostgreSQL Tutorial 2",
Description: nil,
})
}
func TestInsertWithQueryContext(t *testing.T) { func TestInsertWithQueryContext(t *testing.T) {
cleanUpLinkTable(t) cleanUpLinkTable(t)

View file

@ -4,6 +4,8 @@ import (
"database/sql" "database/sql"
"flag" "flag"
"github.com/go-jet/jet/tests/dbconfig" "github.com/go-jet/jet/tests/dbconfig"
"math/rand"
"time"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
@ -28,6 +30,7 @@ func sourceIsMariaDB() bool {
} }
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
rand.Seed(time.Now().Unix())
defer profile.Start().Stop() defer profile.Start().Stop()
var err error var err error

View file

@ -2,7 +2,6 @@ package postgres
import ( import (
"context" "context"
"github.com/go-jet/jet/internal/jet"
"github.com/go-jet/jet/internal/testutils" "github.com/go-jet/jet/internal/testutils"
. "github.com/go-jet/jet/postgres" . "github.com/go-jet/jet/postgres"
"github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/model" "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/model"
@ -134,8 +133,10 @@ ON CONFLICT ON CONSTRAINT employee_pkey DO NOTHING;
VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT).
VALUES(200, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). VALUES(200, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT).
ON_CONFLICT(Link.ID).DO_UPDATE( ON_CONFLICT(Link.ID).DO_UPDATE(
SET(Link.ID, Link.EXCLUDED.ID). SET(
SET(Link.URL, "http://www.postgresqltutorial2.com"), Link.ID.SET(Link.EXCLUDED.ID),
Link.URL.SET(String("http://www.postgresqltutorial2.com")),
),
). ).
RETURNING(Link.AllColumns) RETURNING(Link.AllColumns)
@ -161,8 +162,10 @@ RETURNING link.id AS "link.id",
VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT).
VALUES(200, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). VALUES(200, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT).
ON_CONFLICT().ON_CONSTRAINT("link_pkey").DO_UPDATE( ON_CONFLICT().ON_CONSTRAINT("link_pkey").DO_UPDATE(
SET(Link.ID, Link.EXCLUDED.ID). SET(
SET(Link.URL, "http://www.postgresqltutorial2.com"), Link.ID.SET(Link.EXCLUDED.ID),
Link.URL.SET(String("http://www.postgresqltutorial2.com")),
),
). ).
RETURNING(Link.AllColumns) RETURNING(Link.AllColumns)
@ -188,9 +191,13 @@ RETURNING link.id AS "link.id",
stmt := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description). stmt := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description).
VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT).
ON_CONFLICT(Link.ID).WHERE(Link.ID.MUL(Int(2)).GT(Int(10))).DO_UPDATE( ON_CONFLICT(Link.ID).WHERE(Link.ID.MUL(Int(2)).GT(Int(10))).DO_UPDATE(
SET(Link.ID, SELECT(MAXi(Link.ID).ADD(Int(1))).FROM(Link)). SET(
SET(ColumnList{Link.Name, Link.Description}, jet.ROW(Link.EXCLUDED.Name, String("new description"))). Link.ID.SET(
WHERE(Link.Description.IS_NOT_NULL()), IntExp(SELECT(MAXi(Link.ID).ADD(Int(1))).
FROM(Link)),
),
ColumnList{Link.Name, Link.Description}.SET(ROW(Link.EXCLUDED.Name, String("new description"))),
).WHERE(Link.Description.IS_NOT_NULL()),
) )
testutils.AssertDebugStatementSql(t, stmt, ` testutils.AssertDebugStatementSql(t, stmt, `