Add ON DUPLICATE KEY UPDATE support (MySQL).
This commit is contained in:
parent
30284af33e
commit
980b9b6aac
18 changed files with 388 additions and 109 deletions
|
|
@ -518,7 +518,7 @@ type SetPair struct {
|
|||
}
|
||||
|
||||
// SetClause clause
|
||||
type SetClause []SetPair
|
||||
type SetClause []ColumnAssigment
|
||||
|
||||
// Serialize for SetClause
|
||||
func (s SetClause) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
|
||||
|
|
@ -526,16 +526,15 @@ func (s SetClause) Serialize(statementType StatementType, out *SQLBuilder, optio
|
|||
out.WriteString("SET")
|
||||
out.IncreaseIdent(4)
|
||||
|
||||
for i, pair := range s {
|
||||
for i, assigment := range s {
|
||||
if i > 0 {
|
||||
out.WriteString(",")
|
||||
out.NewLine()
|
||||
}
|
||||
|
||||
pair.Column.serialize(statementType, out, ShortName.WithFallTrough(options)...)
|
||||
out.WriteString("=")
|
||||
pair.Value.serialize(statementType, out, FallTrough(options)...)
|
||||
assigment.serialize(statementType, out, FallTrough(options)...)
|
||||
}
|
||||
|
||||
out.DecreaseIdent(4)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -115,56 +115,3 @@ func (c ColumnExpressionImpl) serialize(statement StatementType, out *SQLBuilder
|
|||
out.WriteIdentifier(c.name)
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------//
|
||||
|
||||
// ColumnList is a helper type to support list of columns as single projection
|
||||
type ColumnList []ColumnExpression
|
||||
|
||||
func (cl ColumnList) fromImpl(subQuery SelectTable) Projection {
|
||||
newProjectionList := ProjectionList{}
|
||||
|
||||
for _, column := range cl {
|
||||
newProjectionList = append(newProjectionList, column.fromImpl(subQuery))
|
||||
}
|
||||
|
||||
return newProjectionList
|
||||
}
|
||||
|
||||
func (cl ColumnList) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
|
||||
out.WriteString("(")
|
||||
for i, column := range cl {
|
||||
if i > 0 {
|
||||
out.WriteString(", ")
|
||||
}
|
||||
column.serialize(statement, out, FallTrough(options)...)
|
||||
}
|
||||
out.WriteString(")")
|
||||
}
|
||||
|
||||
func (cl ColumnList) serializeForProjection(statement StatementType, out *SQLBuilder) {
|
||||
projections := ColumnListToProjectionList(cl)
|
||||
|
||||
SerializeProjectionList(statement, projections, out)
|
||||
}
|
||||
|
||||
// dummy column interface implementation
|
||||
|
||||
// Name is placeholder for ColumnList to implement Column interface
|
||||
func (cl ColumnList) Name() string { return "" }
|
||||
|
||||
// TableName is placeholder for ColumnList to implement Column interface
|
||||
func (cl ColumnList) TableName() string { return "" }
|
||||
func (cl ColumnList) setTableName(name string) {}
|
||||
func (cl ColumnList) setSubQuery(subQuery SelectTable) {}
|
||||
func (cl ColumnList) defaultAlias() string { return "" }
|
||||
|
||||
// SetTableName is utility function to set table name from outside of jet package to avoid making public setTableName
|
||||
func SetTableName(columnExpression ColumnExpression, tableName string) {
|
||||
columnExpression.setTableName(tableName)
|
||||
}
|
||||
|
||||
// SetSubQuery is utility function to set table name from outside of jet package to avoid making public setSubQuery
|
||||
func SetSubQuery(columnExpression ColumnExpression, subQuery SelectTable) {
|
||||
columnExpression.setSubQuery(subQuery)
|
||||
}
|
||||
|
|
|
|||
20
internal/jet/column_assigment.go
Normal file
20
internal/jet/column_assigment.go
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
package jet
|
||||
|
||||
// ColumnAssigment is interface wrapper around column assigment
|
||||
type ColumnAssigment interface {
|
||||
Serializer
|
||||
isColumnAssigment()
|
||||
}
|
||||
|
||||
type columnAssigmentImpl struct {
|
||||
column ColumnSerializer
|
||||
expression Expression
|
||||
}
|
||||
|
||||
func (a columnAssigmentImpl) isColumnAssigment() {}
|
||||
|
||||
func (a columnAssigmentImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
|
||||
a.column.serialize(statement, out, ShortName.WithFallTrough(options)...)
|
||||
out.WriteString("=")
|
||||
a.expression.serialize(statement, out, FallTrough(options)...)
|
||||
}
|
||||
60
internal/jet/column_list.go
Normal file
60
internal/jet/column_list.go
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
package jet
|
||||
|
||||
// ColumnList is a helper type to support list of columns as single projection
|
||||
type ColumnList []ColumnExpression
|
||||
|
||||
// SET creates column assigment for each column in column list. expression should be created by ROW function
|
||||
func (cl ColumnList) SET(expression Expression) ColumnAssigment {
|
||||
return columnAssigmentImpl{
|
||||
column: cl,
|
||||
expression: expression,
|
||||
}
|
||||
}
|
||||
|
||||
func (cl ColumnList) fromImpl(subQuery SelectTable) Projection {
|
||||
newProjectionList := ProjectionList{}
|
||||
|
||||
for _, column := range cl {
|
||||
newProjectionList = append(newProjectionList, column.fromImpl(subQuery))
|
||||
}
|
||||
|
||||
return newProjectionList
|
||||
}
|
||||
|
||||
func (cl ColumnList) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
|
||||
out.WriteString("(")
|
||||
for i, column := range cl {
|
||||
if i > 0 {
|
||||
out.WriteString(", ")
|
||||
}
|
||||
column.serialize(statement, out, FallTrough(options)...)
|
||||
}
|
||||
out.WriteString(")")
|
||||
}
|
||||
|
||||
func (cl ColumnList) serializeForProjection(statement StatementType, out *SQLBuilder) {
|
||||
projections := ColumnListToProjectionList(cl)
|
||||
|
||||
SerializeProjectionList(statement, projections, out)
|
||||
}
|
||||
|
||||
// dummy column interface implementation
|
||||
|
||||
// Name is placeholder for ColumnList to implement Column interface
|
||||
func (cl ColumnList) Name() string { return "" }
|
||||
|
||||
// TableName is placeholder for ColumnList to implement Column interface
|
||||
func (cl ColumnList) TableName() string { return "" }
|
||||
func (cl ColumnList) setTableName(name string) {}
|
||||
func (cl ColumnList) setSubQuery(subQuery SelectTable) {}
|
||||
func (cl ColumnList) defaultAlias() string { return "" }
|
||||
|
||||
// SetTableName is utility function to set table name from outside of jet package to avoid making public setTableName
|
||||
func SetTableName(columnExpression ColumnExpression, tableName string) {
|
||||
columnExpression.setTableName(tableName)
|
||||
}
|
||||
|
||||
// SetSubQuery is utility function to set table name from outside of jet package to avoid making public setSubQuery
|
||||
func SetSubQuery(columnExpression ColumnExpression, subQuery SelectTable) {
|
||||
columnExpression.setSubQuery(subQuery)
|
||||
}
|
||||
|
|
@ -6,6 +6,7 @@ type ColumnBool interface {
|
|||
Column
|
||||
|
||||
From(subQuery SelectTable) ColumnBool
|
||||
SET(boolExp BoolExpression) ColumnAssigment
|
||||
}
|
||||
|
||||
type boolColumnImpl struct {
|
||||
|
|
@ -21,6 +22,13 @@ func (i *boolColumnImpl) From(subQuery SelectTable) ColumnBool {
|
|||
return newBoolColumn
|
||||
}
|
||||
|
||||
func (i *boolColumnImpl) SET(boolExp BoolExpression) ColumnAssigment {
|
||||
return columnAssigmentImpl{
|
||||
column: i,
|
||||
expression: boolExp,
|
||||
}
|
||||
}
|
||||
|
||||
// BoolColumn creates named bool column.
|
||||
func BoolColumn(name string) ColumnBool {
|
||||
boolColumn := &boolColumnImpl{}
|
||||
|
|
@ -38,6 +46,7 @@ type ColumnFloat interface {
|
|||
Column
|
||||
|
||||
From(subQuery SelectTable) ColumnFloat
|
||||
SET(floatExp FloatExpression) ColumnAssigment
|
||||
}
|
||||
|
||||
type floatColumnImpl struct {
|
||||
|
|
@ -53,6 +62,13 @@ func (i *floatColumnImpl) From(subQuery SelectTable) ColumnFloat {
|
|||
return newFloatColumn
|
||||
}
|
||||
|
||||
func (i *floatColumnImpl) SET(floatExp FloatExpression) ColumnAssigment {
|
||||
return columnAssigmentImpl{
|
||||
column: i,
|
||||
expression: floatExp,
|
||||
}
|
||||
}
|
||||
|
||||
// FloatColumn creates named float column.
|
||||
func FloatColumn(name string) ColumnFloat {
|
||||
floatColumn := &floatColumnImpl{}
|
||||
|
|
@ -70,6 +86,7 @@ type ColumnInteger interface {
|
|||
Column
|
||||
|
||||
From(subQuery SelectTable) ColumnInteger
|
||||
SET(intExp IntegerExpression) ColumnAssigment
|
||||
}
|
||||
|
||||
type integerColumnImpl struct {
|
||||
|
|
@ -86,6 +103,13 @@ func (i *integerColumnImpl) From(subQuery SelectTable) ColumnInteger {
|
|||
return newIntColumn
|
||||
}
|
||||
|
||||
func (i *integerColumnImpl) SET(intExp IntegerExpression) ColumnAssigment {
|
||||
return columnAssigmentImpl{
|
||||
column: i,
|
||||
expression: intExp,
|
||||
}
|
||||
}
|
||||
|
||||
// IntegerColumn creates named integer column.
|
||||
func IntegerColumn(name string) ColumnInteger {
|
||||
integerColumn := &integerColumnImpl{}
|
||||
|
|
@ -104,6 +128,7 @@ type ColumnString interface {
|
|||
Column
|
||||
|
||||
From(subQuery SelectTable) ColumnString
|
||||
SET(stringExp StringExpression) ColumnAssigment
|
||||
}
|
||||
|
||||
type stringColumnImpl struct {
|
||||
|
|
@ -120,6 +145,13 @@ func (i *stringColumnImpl) From(subQuery SelectTable) ColumnString {
|
|||
return newStrColumn
|
||||
}
|
||||
|
||||
func (i *stringColumnImpl) SET(stringExp StringExpression) ColumnAssigment {
|
||||
return columnAssigmentImpl{
|
||||
column: i,
|
||||
expression: stringExp,
|
||||
}
|
||||
}
|
||||
|
||||
// StringColumn creates named string column.
|
||||
func StringColumn(name string) ColumnString {
|
||||
stringColumn := &stringColumnImpl{}
|
||||
|
|
@ -137,6 +169,7 @@ type ColumnTime interface {
|
|||
Column
|
||||
|
||||
From(subQuery SelectTable) ColumnTime
|
||||
SET(timeExp TimeExpression) ColumnAssigment
|
||||
}
|
||||
|
||||
type timeColumnImpl struct {
|
||||
|
|
@ -152,6 +185,13 @@ func (i *timeColumnImpl) From(subQuery SelectTable) ColumnTime {
|
|||
return newTimeColumn
|
||||
}
|
||||
|
||||
func (i *timeColumnImpl) SET(timeExp TimeExpression) ColumnAssigment {
|
||||
return columnAssigmentImpl{
|
||||
column: i,
|
||||
expression: timeExp,
|
||||
}
|
||||
}
|
||||
|
||||
// TimeColumn creates named time column
|
||||
func TimeColumn(name string) ColumnTime {
|
||||
timeColumn := &timeColumnImpl{}
|
||||
|
|
@ -183,6 +223,13 @@ func (i *timezColumnImpl) From(subQuery SelectTable) ColumnTimez {
|
|||
return newTimezColumn
|
||||
}
|
||||
|
||||
func (i *timezColumnImpl) SET(timezExp TimezExpression) ColumnAssigment {
|
||||
return columnAssigmentImpl{
|
||||
column: i,
|
||||
expression: timezExp,
|
||||
}
|
||||
}
|
||||
|
||||
// TimezColumn creates named time with time zone column.
|
||||
func TimezColumn(name string) ColumnTimez {
|
||||
timezColumn := &timezColumnImpl{}
|
||||
|
|
@ -200,6 +247,7 @@ type ColumnTimestamp interface {
|
|||
Column
|
||||
|
||||
From(subQuery SelectTable) ColumnTimestamp
|
||||
SET(timestampExp TimestampExpression) ColumnAssigment
|
||||
}
|
||||
|
||||
type timestampColumnImpl struct {
|
||||
|
|
@ -215,6 +263,13 @@ func (i *timestampColumnImpl) From(subQuery SelectTable) ColumnTimestamp {
|
|||
return newTimestampColumn
|
||||
}
|
||||
|
||||
func (i *timestampColumnImpl) SET(timestampExp TimestampExpression) ColumnAssigment {
|
||||
return columnAssigmentImpl{
|
||||
column: i,
|
||||
expression: timestampExp,
|
||||
}
|
||||
}
|
||||
|
||||
// TimestampColumn creates named timestamp column
|
||||
func TimestampColumn(name string) ColumnTimestamp {
|
||||
timestampColumn := ×tampColumnImpl{}
|
||||
|
|
@ -232,6 +287,7 @@ type ColumnTimestampz interface {
|
|||
Column
|
||||
|
||||
From(subQuery SelectTable) ColumnTimestampz
|
||||
SET(timestampzExp TimestampzExpression) ColumnAssigment
|
||||
}
|
||||
|
||||
type timestampzColumnImpl struct {
|
||||
|
|
@ -247,6 +303,13 @@ func (i *timestampzColumnImpl) From(subQuery SelectTable) ColumnTimestampz {
|
|||
return newTimestampzColumn
|
||||
}
|
||||
|
||||
func (i *timestampzColumnImpl) SET(timestampzExp TimestampzExpression) ColumnAssigment {
|
||||
return columnAssigmentImpl{
|
||||
column: i,
|
||||
expression: timestampzExp,
|
||||
}
|
||||
}
|
||||
|
||||
// TimestampzColumn creates named timestamp with time zone column.
|
||||
func TimestampzColumn(name string) ColumnTimestampz {
|
||||
timestampzColumn := ×tampzColumnImpl{}
|
||||
|
|
@ -264,6 +327,7 @@ type ColumnDate interface {
|
|||
Column
|
||||
|
||||
From(subQuery SelectTable) ColumnDate
|
||||
SET(dateExp DateExpression) ColumnAssigment
|
||||
}
|
||||
|
||||
type dateColumnImpl struct {
|
||||
|
|
@ -279,6 +343,13 @@ func (i *dateColumnImpl) From(subQuery SelectTable) ColumnDate {
|
|||
return newDateColumn
|
||||
}
|
||||
|
||||
func (i *dateColumnImpl) SET(dateExp DateExpression) ColumnAssigment {
|
||||
return columnAssigmentImpl{
|
||||
column: i,
|
||||
expression: dateExp,
|
||||
}
|
||||
}
|
||||
|
||||
// DateColumn creates named date column.
|
||||
func DateColumn(name string) ColumnDate {
|
||||
dateColumn := &dateColumnImpl{}
|
||||
|
|
|
|||
|
|
@ -24,12 +24,12 @@ import (
|
|||
func AssertExec(t *testing.T, stmt jet.Statement, db qrm.DB, rowsAffected ...int64) {
|
||||
res, err := stmt.Exec(db)
|
||||
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
rows, err := res.RowsAffected()
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
if len(rowsAffected) > 0 {
|
||||
assert.Equal(t, rows, rowsAffected[0])
|
||||
require.Equal(t, rows, rowsAffected[0])
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -224,7 +224,7 @@ func AssertFileNamesEqual(t *testing.T, fileInfos []os.FileInfo, fileNames ...st
|
|||
|
||||
// AssertDeepEqual checks if actual and expected objects are deeply equal.
|
||||
func AssertDeepEqual(t *testing.T, actual, expected interface{}, msg ...string) {
|
||||
assert.True(t, cmp.Equal(actual, expected), msg)
|
||||
require.True(t, cmp.Equal(actual, expected), msg)
|
||||
}
|
||||
|
||||
// BoolPtr returns address of bool parameter
|
||||
|
|
|
|||
|
|
@ -13,13 +13,15 @@ type InsertStatement interface {
|
|||
MODEL(data interface{}) InsertStatement
|
||||
MODELS(data interface{}) InsertStatement
|
||||
|
||||
ON_DUPLICATE_KEY_UPDATE(assigments ...ColumnAssigment) InsertStatement
|
||||
|
||||
QUERY(selectStatement SelectStatement) InsertStatement
|
||||
}
|
||||
|
||||
func newInsertStatement(table Table, columns []jet.Column) InsertStatement {
|
||||
newInsert := &insertStatementImpl{}
|
||||
newInsert.SerializerStatement = jet.NewStatementImpl(Dialect, jet.InsertStatementType, newInsert,
|
||||
&newInsert.Insert, &newInsert.ValuesQuery)
|
||||
&newInsert.Insert, &newInsert.ValuesQuery, &newInsert.OnDuplicateKey)
|
||||
|
||||
newInsert.Insert.Table = table
|
||||
newInsert.Insert.Columns = columns
|
||||
|
|
@ -32,24 +34,53 @@ type insertStatementImpl struct {
|
|||
|
||||
Insert jet.ClauseInsert
|
||||
ValuesQuery jet.ClauseValuesQuery
|
||||
OnDuplicateKey onDuplicateKeyUpdateClause
|
||||
}
|
||||
|
||||
func (i *insertStatementImpl) VALUES(value interface{}, values ...interface{}) InsertStatement {
|
||||
i.ValuesQuery.Rows = append(i.ValuesQuery.Rows, jet.UnwindRowFromValues(value, values))
|
||||
return i
|
||||
func (is *insertStatementImpl) VALUES(value interface{}, values ...interface{}) InsertStatement {
|
||||
is.ValuesQuery.Rows = append(is.ValuesQuery.Rows, jet.UnwindRowFromValues(value, values))
|
||||
return is
|
||||
}
|
||||
|
||||
func (i *insertStatementImpl) MODEL(data interface{}) InsertStatement {
|
||||
i.ValuesQuery.Rows = append(i.ValuesQuery.Rows, jet.UnwindRowFromModel(i.Insert.GetColumns(), data))
|
||||
return i
|
||||
func (is *insertStatementImpl) MODEL(data interface{}) InsertStatement {
|
||||
is.ValuesQuery.Rows = append(is.ValuesQuery.Rows, jet.UnwindRowFromModel(is.Insert.GetColumns(), data))
|
||||
return is
|
||||
}
|
||||
|
||||
func (i *insertStatementImpl) MODELS(data interface{}) InsertStatement {
|
||||
i.ValuesQuery.Rows = append(i.ValuesQuery.Rows, jet.UnwindRowsFromModels(i.Insert.GetColumns(), data)...)
|
||||
return i
|
||||
func (is *insertStatementImpl) MODELS(data interface{}) InsertStatement {
|
||||
is.ValuesQuery.Rows = append(is.ValuesQuery.Rows, jet.UnwindRowsFromModels(is.Insert.GetColumns(), data)...)
|
||||
return is
|
||||
}
|
||||
|
||||
func (i *insertStatementImpl) QUERY(selectStatement SelectStatement) InsertStatement {
|
||||
i.ValuesQuery.Query = selectStatement
|
||||
return i
|
||||
func (is *insertStatementImpl) ON_DUPLICATE_KEY_UPDATE(assigments ...ColumnAssigment) InsertStatement {
|
||||
is.OnDuplicateKey = assigments
|
||||
return is
|
||||
}
|
||||
|
||||
func (is *insertStatementImpl) QUERY(selectStatement SelectStatement) InsertStatement {
|
||||
is.ValuesQuery.Query = selectStatement
|
||||
return is
|
||||
}
|
||||
|
||||
type onDuplicateKeyUpdateClause []jet.ColumnAssigment
|
||||
|
||||
// Serialize for SetClause
|
||||
func (s onDuplicateKeyUpdateClause) Serialize(statementType jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
|
||||
if len(s) == 0 {
|
||||
return
|
||||
}
|
||||
out.NewLine()
|
||||
out.WriteString("ON DUPLICATE KEY UPDATE")
|
||||
out.IncreaseIdent(24)
|
||||
|
||||
for i, assigment := range s {
|
||||
if i > 0 {
|
||||
out.WriteString(",")
|
||||
out.NewLine()
|
||||
}
|
||||
|
||||
jet.Serialize(assigment, statementType, out, jet.ShortName.WithFallTrough(options)...)
|
||||
}
|
||||
|
||||
out.DecreaseIdent(24)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -133,3 +133,50 @@ VALUES (DEFAULT, ?);
|
|||
|
||||
assertStatementSql(t, stmt, expectedSQL, "two")
|
||||
}
|
||||
|
||||
func TestInsertOnDuplicateKeyUpdate(t *testing.T) {
|
||||
stmt := func() InsertStatement {
|
||||
return table1.INSERT(table1Col1, table1ColFloat).
|
||||
VALUES(DEFAULT, "two")
|
||||
}
|
||||
|
||||
t.Run("empty list", func(t *testing.T) {
|
||||
stmt := stmt().ON_DUPLICATE_KEY_UPDATE()
|
||||
assertStatementSql(t, stmt, `
|
||||
INSERT INTO db.table1 (col1, col_float)
|
||||
VALUES (DEFAULT, ?);
|
||||
`, "two")
|
||||
})
|
||||
|
||||
t.Run("one set", func(t *testing.T) {
|
||||
stmt := stmt().ON_DUPLICATE_KEY_UPDATE(table1ColFloat.SET(Float(11.1)))
|
||||
assertStatementSql(t, stmt, `
|
||||
INSERT INTO db.table1 (col1, col_float)
|
||||
VALUES (DEFAULT, ?)
|
||||
ON DUPLICATE KEY UPDATE col_float = ?;
|
||||
`, "two", 11.1)
|
||||
})
|
||||
|
||||
t.Run("all types set", func(t *testing.T) {
|
||||
stmt := stmt().ON_DUPLICATE_KEY_UPDATE(
|
||||
table1ColBool.SET(Bool(true)),
|
||||
table1ColInt.SET(Int(11)),
|
||||
table1ColFloat.SET(Float(11.1)),
|
||||
table1ColString.SET(String("str")),
|
||||
table1ColTime.SET(Time(11, 23, 11)),
|
||||
table1ColTimestamp.SET(Timestamp(2020, 1, 22, 3, 4, 5)),
|
||||
table1ColDate.SET(Date(2020, 12, 1)),
|
||||
)
|
||||
assertStatementSql(t, stmt, `
|
||||
INSERT INTO db.table1 (col1, col_float)
|
||||
VALUES (DEFAULT, ?)
|
||||
ON DUPLICATE KEY UPDATE col_bool = ?,
|
||||
col_int = ?,
|
||||
col_float = ?,
|
||||
col_string = ?,
|
||||
col_time = CAST(? AS TIME),
|
||||
col_timestamp = TIMESTAMP(?),
|
||||
col_date = CAST(? AS DATE);
|
||||
`, "two", true, int64(11), 11.1, "str", "11:23:11", "2020-01-22 03:04:05", "2020-12-01")
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,3 +10,6 @@ type Projection = jet.Projection
|
|||
|
||||
// ProjectionList can be used to create conditional constructed projection list.
|
||||
type ProjectionList = jet.ProjectionList
|
||||
|
||||
// ColumnAssigment is interface wrapper around column assigment
|
||||
type ColumnAssigment = jet.ColumnAssigment
|
||||
|
|
|
|||
|
|
@ -7,12 +7,14 @@ import (
|
|||
)
|
||||
|
||||
var table1Col1 = IntegerColumn("col1")
|
||||
var table1ColBool = BoolColumn("col_bool")
|
||||
var table1ColInt = IntegerColumn("col_int")
|
||||
var table1ColFloat = FloatColumn("col_float")
|
||||
var table1ColString = StringColumn("col_string")
|
||||
var table1Col3 = IntegerColumn("col3")
|
||||
var table1ColTimestamp = TimestampColumn("col_timestamp")
|
||||
var table1ColBool = BoolColumn("col_bool")
|
||||
var table1ColDate = DateColumn("col_date")
|
||||
var table1ColTime = TimeColumn("col_time")
|
||||
|
||||
var table1 = NewTable(
|
||||
"db",
|
||||
|
|
@ -20,10 +22,12 @@ var table1 = NewTable(
|
|||
table1Col1,
|
||||
table1ColInt,
|
||||
table1ColFloat,
|
||||
table1ColString,
|
||||
table1Col3,
|
||||
table1ColBool,
|
||||
table1ColDate,
|
||||
table1ColTimestamp,
|
||||
table1ColTime,
|
||||
)
|
||||
|
||||
var table2Col3 = IntegerColumn("col3")
|
||||
|
|
|
|||
|
|
@ -21,9 +21,10 @@ ON CONFLICT (col_bool) DO NOTHING`)
|
|||
ON CONFLICT (col_bool) ON CONSTRAINT table_pkey DO NOTHING`)
|
||||
|
||||
onConflict = &onConflictClause{indexExpressions: ColumnList{table1ColBool, table2ColFloat}}
|
||||
onConflict.WHERE(table2ColFloat.ADD(table1ColInt).GT(table1ColFloat)).DO_UPDATE(
|
||||
SET(table1ColBool, Bool(true)).
|
||||
SET(table1ColInt, Int(1)).
|
||||
onConflict.WHERE(table2ColFloat.ADD(table1ColInt).GT(table1ColFloat)).
|
||||
DO_UPDATE(
|
||||
SET(table1ColBool.SET(Bool(true)),
|
||||
table1ColInt.SET(Int(11))).
|
||||
WHERE(table2ColFloat.GT(Float(11.1))),
|
||||
)
|
||||
assertClauseSerialize(t, onConflict, `
|
||||
|
|
|
|||
|
|
@ -4,16 +4,15 @@ import "github.com/go-jet/jet/internal/jet"
|
|||
|
||||
type conflictAction interface {
|
||||
jet.Serializer
|
||||
SET(column jet.ColumnSerializer, expression interface{}) conflictAction
|
||||
WHERE(condition BoolExpression) conflictAction
|
||||
}
|
||||
|
||||
// SET creates conflict action for ON_CONFLICT clause
|
||||
func SET(column jet.ColumnSerializer, expression interface{}) conflictAction {
|
||||
func SET(assigments ...ColumnAssigment) conflictAction {
|
||||
conflictAction := updateConflictActionImpl{}
|
||||
conflictAction.doUpdate = jet.KeywordClause{Keyword: "DO UPDATE"}
|
||||
conflictAction.Serializer = jet.NewSerializerClauseImpl(&conflictAction.doUpdate, &conflictAction.set, &conflictAction.where)
|
||||
conflictAction.SET(column, expression)
|
||||
conflictAction.set = assigments
|
||||
return &conflictAction
|
||||
}
|
||||
|
||||
|
|
@ -25,11 +24,6 @@ type updateConflictActionImpl struct {
|
|||
where jet.ClauseWhere
|
||||
}
|
||||
|
||||
func (u *updateConflictActionImpl) SET(column jet.ColumnSerializer, expression interface{}) conflictAction {
|
||||
u.set = append(u.set, jet.SetPair{Column: column, Value: jet.ToSerializerValue(expression)})
|
||||
return u
|
||||
}
|
||||
|
||||
func (u *updateConflictActionImpl) WHERE(condition BoolExpression) conflictAction {
|
||||
u.where.Condition = condition
|
||||
return u
|
||||
|
|
|
|||
|
|
@ -153,10 +153,10 @@ func TestInsert_ON_CONFLICT(t *testing.T) {
|
|||
VALUES("1", "2").
|
||||
VALUES("theta", "beta").
|
||||
ON_CONFLICT(table1ColBool).WHERE(table1ColBool.IS_NOT_FALSE()).DO_UPDATE(
|
||||
SET(table1ColBool, "12").
|
||||
SET(table2ColInt, 1).
|
||||
SET(ColumnList{table1Col1, table1ColBool}, jet.ROW(Int(2), String("two"))).
|
||||
WHERE(table1Col1.GT(Int(2))),
|
||||
SET(table1ColBool.SET(Bool(true)),
|
||||
table2ColInt.SET(Int(1)),
|
||||
ColumnList{table1Col1, table1ColBool}.SET(jet.ROW(Int(2), String("two"))),
|
||||
).WHERE(table1Col1.GT(Int(2))),
|
||||
).
|
||||
RETURNING(table1Col1, table1ColBool)
|
||||
|
||||
|
|
@ -166,7 +166,7 @@ VALUES ('one', 'two'),
|
|||
('1', '2'),
|
||||
('theta', 'beta')
|
||||
ON CONFLICT (col_bool) WHERE col_bool IS NOT FALSE DO UPDATE
|
||||
SET col_bool = '12',
|
||||
SET col_bool = TRUE,
|
||||
col_int = 1,
|
||||
(col1, col_bool) = ROW(2, 'two')
|
||||
WHERE table1.col1 > 2
|
||||
|
|
@ -180,10 +180,10 @@ func TestInsert_ON_CONFLICT_ON_CONSTRAINT(t *testing.T) {
|
|||
VALUES("one", "two").
|
||||
VALUES("1", "2").
|
||||
ON_CONFLICT().ON_CONSTRAINT("idk_primary_key").DO_UPDATE(
|
||||
SET(table1ColBool, "12").
|
||||
SET(table2ColInt, 1).
|
||||
SET(ColumnList{table1Col1, table1ColBool}, jet.ROW(Int(2), String("two"))).
|
||||
WHERE(table1Col1.GT(Int(2))),
|
||||
SET(table1ColBool.SET(Bool(false)),
|
||||
table2ColInt.SET(Int(1)),
|
||||
ColumnList{table1Col1, table1ColBool}.SET(jet.ROW(Int(2), String("two"))),
|
||||
).WHERE(table1Col1.GT(Int(2))),
|
||||
).
|
||||
RETURNING(table1Col1, table1ColBool)
|
||||
|
||||
|
|
@ -192,7 +192,7 @@ INSERT INTO db.table1 (col1, col_bool)
|
|||
VALUES ('one', 'two'),
|
||||
('1', '2')
|
||||
ON CONFLICT ON CONSTRAINT idk_primary_key DO UPDATE
|
||||
SET col_bool = '12',
|
||||
SET col_bool = FALSE,
|
||||
col_int = 1,
|
||||
(col1, col_bool) = ROW(2, 'two')
|
||||
WHERE table1.col1 > 2
|
||||
|
|
|
|||
|
|
@ -10,3 +10,6 @@ type Projection = jet.Projection
|
|||
|
||||
// ProjectionList can be used to create conditional constructed projection list.
|
||||
type ProjectionList = jet.ProjectionList
|
||||
|
||||
// ColumnAssigment is interface wrapper around column assigment
|
||||
type ColumnAssigment = jet.ColumnAssigment
|
||||
|
|
|
|||
|
|
@ -991,6 +991,53 @@ func TestAllTypesInsert(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestAllTypesInsertOnDuplicateKeyUpdate(t *testing.T) {
|
||||
tx, err := db.Begin()
|
||||
require.NoError(t, err)
|
||||
|
||||
toInsert := model.AllTypes{
|
||||
Boolean: true,
|
||||
Integer: 124,
|
||||
Float: 45.67,
|
||||
Blob: []byte("blob"),
|
||||
Text: "text",
|
||||
JSON: "{}",
|
||||
Time: time.Now(),
|
||||
Timestamp: time.Now(),
|
||||
Date: time.Now(),
|
||||
}
|
||||
|
||||
stmt := AllTypes.INSERT(
|
||||
AllTypes.Boolean,
|
||||
AllTypes.Integer,
|
||||
AllTypes.Float,
|
||||
AllTypes.Blob,
|
||||
AllTypes.Text,
|
||||
AllTypes.JSON,
|
||||
AllTypes.Time,
|
||||
AllTypes.Timestamp,
|
||||
AllTypes.Date,
|
||||
).
|
||||
MODEL(toInsert).
|
||||
ON_DUPLICATE_KEY_UPDATE(
|
||||
AllTypes.Boolean.SET(Bool(false)),
|
||||
AllTypes.Integer.SET(Int(4)),
|
||||
AllTypes.Float.SET(Float(0.67)),
|
||||
AllTypes.Text.SET(String("new text")),
|
||||
AllTypes.Time.SET(TimeT(time.Now())),
|
||||
AllTypes.Timestamp.SET(TimestampT(time.Now())),
|
||||
AllTypes.Date.SET(DateT(time.Now())),
|
||||
)
|
||||
|
||||
fmt.Println(stmt.DebugSql())
|
||||
|
||||
_, err = stmt.Exec(tx)
|
||||
assert.NoError(t, err)
|
||||
|
||||
err = tx.Rollback()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
var toInsert = model.AllTypes{
|
||||
Boolean: false,
|
||||
BooleanPtr: testutils.BoolPtr(true),
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@ import (
|
|||
"github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/model"
|
||||
. "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/table"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
|
@ -248,6 +250,46 @@ INSERT INTO test_sample.link (url, name) (
|
|||
assert.Equal(t, len(youtubeLinks), 2)
|
||||
}
|
||||
|
||||
func TestInsertOnDuplicateKey(t *testing.T) {
|
||||
randId := rand.Int31()
|
||||
|
||||
stmt := Link.INSERT().
|
||||
VALUES(randId, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT).
|
||||
VALUES(randId, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT).
|
||||
ON_DUPLICATE_KEY_UPDATE(
|
||||
Link.ID.SET(Link.ID.ADD(Int(11))),
|
||||
Link.Name.SET(String("PostgreSQL Tutorial 2")),
|
||||
)
|
||||
|
||||
testutils.AssertStatementSql(t, stmt, `
|
||||
INSERT INTO test_sample.link
|
||||
VALUES (?, ?, ?, DEFAULT),
|
||||
(?, ?, ?, DEFAULT)
|
||||
ON DUPLICATE KEY UPDATE id = (id + ?),
|
||||
name = ?;
|
||||
`, randId, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial",
|
||||
randId, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial",
|
||||
int64(11), "PostgreSQL Tutorial 2")
|
||||
|
||||
testutils.AssertExec(t, stmt, db, 3)
|
||||
|
||||
newLinks := []model.Link{}
|
||||
|
||||
err := SELECT(Link.AllColumns).
|
||||
FROM(Link).
|
||||
WHERE(Link.ID.EQ(Int(int64(randId)).ADD(Int(11)))).
|
||||
Query(db, &newLinks)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Len(t, newLinks, 1)
|
||||
require.Equal(t, newLinks[0], model.Link{
|
||||
ID: randId + 11,
|
||||
URL: "http://www.postgresqltutorial.com",
|
||||
Name: "PostgreSQL Tutorial 2",
|
||||
Description: nil,
|
||||
})
|
||||
}
|
||||
|
||||
func TestInsertWithQueryContext(t *testing.T) {
|
||||
cleanUpLinkTable(t)
|
||||
|
||||
|
|
|
|||
|
|
@ -4,6 +4,8 @@ import (
|
|||
"database/sql"
|
||||
"flag"
|
||||
"github.com/go-jet/jet/tests/dbconfig"
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
|
||||
|
|
@ -28,6 +30,7 @@ func sourceIsMariaDB() bool {
|
|||
}
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
rand.Seed(time.Now().Unix())
|
||||
defer profile.Start().Stop()
|
||||
|
||||
var err error
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ package postgres
|
|||
|
||||
import (
|
||||
"context"
|
||||
"github.com/go-jet/jet/internal/jet"
|
||||
"github.com/go-jet/jet/internal/testutils"
|
||||
. "github.com/go-jet/jet/postgres"
|
||||
"github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/model"
|
||||
|
|
@ -134,8 +133,10 @@ ON CONFLICT ON CONSTRAINT employee_pkey DO NOTHING;
|
|||
VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT).
|
||||
VALUES(200, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT).
|
||||
ON_CONFLICT(Link.ID).DO_UPDATE(
|
||||
SET(Link.ID, Link.EXCLUDED.ID).
|
||||
SET(Link.URL, "http://www.postgresqltutorial2.com"),
|
||||
SET(
|
||||
Link.ID.SET(Link.EXCLUDED.ID),
|
||||
Link.URL.SET(String("http://www.postgresqltutorial2.com")),
|
||||
),
|
||||
).
|
||||
RETURNING(Link.AllColumns)
|
||||
|
||||
|
|
@ -161,8 +162,10 @@ RETURNING link.id AS "link.id",
|
|||
VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT).
|
||||
VALUES(200, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT).
|
||||
ON_CONFLICT().ON_CONSTRAINT("link_pkey").DO_UPDATE(
|
||||
SET(Link.ID, Link.EXCLUDED.ID).
|
||||
SET(Link.URL, "http://www.postgresqltutorial2.com"),
|
||||
SET(
|
||||
Link.ID.SET(Link.EXCLUDED.ID),
|
||||
Link.URL.SET(String("http://www.postgresqltutorial2.com")),
|
||||
),
|
||||
).
|
||||
RETURNING(Link.AllColumns)
|
||||
|
||||
|
|
@ -188,9 +191,13 @@ RETURNING link.id AS "link.id",
|
|||
stmt := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description).
|
||||
VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT).
|
||||
ON_CONFLICT(Link.ID).WHERE(Link.ID.MUL(Int(2)).GT(Int(10))).DO_UPDATE(
|
||||
SET(Link.ID, SELECT(MAXi(Link.ID).ADD(Int(1))).FROM(Link)).
|
||||
SET(ColumnList{Link.Name, Link.Description}, jet.ROW(Link.EXCLUDED.Name, String("new description"))).
|
||||
WHERE(Link.Description.IS_NOT_NULL()),
|
||||
SET(
|
||||
Link.ID.SET(
|
||||
IntExp(SELECT(MAXi(Link.ID).ADD(Int(1))).
|
||||
FROM(Link)),
|
||||
),
|
||||
ColumnList{Link.Name, Link.Description}.SET(ROW(Link.EXCLUDED.Name, String("new description"))),
|
||||
).WHERE(Link.Description.IS_NOT_NULL()),
|
||||
)
|
||||
|
||||
testutils.AssertDebugStatementSql(t, stmt, `
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue