Add ON DUPLICATE KEY UPDATE support (MySQL).

This commit is contained in:
go-jet 2020-05-03 20:46:21 +02:00
parent 30284af33e
commit 980b9b6aac
18 changed files with 388 additions and 109 deletions

View file

@ -518,7 +518,7 @@ type SetPair struct {
}
// SetClause clause
type SetClause []SetPair
type SetClause []ColumnAssigment
// Serialize for SetClause
func (s SetClause) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
@ -526,16 +526,15 @@ func (s SetClause) Serialize(statementType StatementType, out *SQLBuilder, optio
out.WriteString("SET")
out.IncreaseIdent(4)
for i, pair := range s {
for i, assigment := range s {
if i > 0 {
out.WriteString(",")
out.NewLine()
}
pair.Column.serialize(statementType, out, ShortName.WithFallTrough(options)...)
out.WriteString("=")
pair.Value.serialize(statementType, out, FallTrough(options)...)
assigment.serialize(statementType, out, FallTrough(options)...)
}
out.DecreaseIdent(4)
}

View file

@ -115,56 +115,3 @@ func (c ColumnExpressionImpl) serialize(statement StatementType, out *SQLBuilder
out.WriteIdentifier(c.name)
}
}
//------------------------------------------------------//
// ColumnList is a helper type to support list of columns as single projection
type ColumnList []ColumnExpression
func (cl ColumnList) fromImpl(subQuery SelectTable) Projection {
newProjectionList := ProjectionList{}
for _, column := range cl {
newProjectionList = append(newProjectionList, column.fromImpl(subQuery))
}
return newProjectionList
}
func (cl ColumnList) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteString("(")
for i, column := range cl {
if i > 0 {
out.WriteString(", ")
}
column.serialize(statement, out, FallTrough(options)...)
}
out.WriteString(")")
}
func (cl ColumnList) serializeForProjection(statement StatementType, out *SQLBuilder) {
projections := ColumnListToProjectionList(cl)
SerializeProjectionList(statement, projections, out)
}
// dummy column interface implementation
// Name is placeholder for ColumnList to implement Column interface
func (cl ColumnList) Name() string { return "" }
// TableName is placeholder for ColumnList to implement Column interface
func (cl ColumnList) TableName() string { return "" }
func (cl ColumnList) setTableName(name string) {}
func (cl ColumnList) setSubQuery(subQuery SelectTable) {}
func (cl ColumnList) defaultAlias() string { return "" }
// SetTableName is utility function to set table name from outside of jet package to avoid making public setTableName
func SetTableName(columnExpression ColumnExpression, tableName string) {
columnExpression.setTableName(tableName)
}
// SetSubQuery is utility function to set table name from outside of jet package to avoid making public setSubQuery
func SetSubQuery(columnExpression ColumnExpression, subQuery SelectTable) {
columnExpression.setSubQuery(subQuery)
}

View file

@ -0,0 +1,20 @@
package jet
// ColumnAssigment is interface wrapper around column assigment
type ColumnAssigment interface {
Serializer
isColumnAssigment()
}
type columnAssigmentImpl struct {
column ColumnSerializer
expression Expression
}
func (a columnAssigmentImpl) isColumnAssigment() {}
func (a columnAssigmentImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
a.column.serialize(statement, out, ShortName.WithFallTrough(options)...)
out.WriteString("=")
a.expression.serialize(statement, out, FallTrough(options)...)
}

View file

@ -0,0 +1,60 @@
package jet
// ColumnList is a helper type to support list of columns as single projection
type ColumnList []ColumnExpression
// SET creates column assigment for each column in column list. expression should be created by ROW function
func (cl ColumnList) SET(expression Expression) ColumnAssigment {
return columnAssigmentImpl{
column: cl,
expression: expression,
}
}
func (cl ColumnList) fromImpl(subQuery SelectTable) Projection {
newProjectionList := ProjectionList{}
for _, column := range cl {
newProjectionList = append(newProjectionList, column.fromImpl(subQuery))
}
return newProjectionList
}
func (cl ColumnList) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
out.WriteString("(")
for i, column := range cl {
if i > 0 {
out.WriteString(", ")
}
column.serialize(statement, out, FallTrough(options)...)
}
out.WriteString(")")
}
func (cl ColumnList) serializeForProjection(statement StatementType, out *SQLBuilder) {
projections := ColumnListToProjectionList(cl)
SerializeProjectionList(statement, projections, out)
}
// dummy column interface implementation
// Name is placeholder for ColumnList to implement Column interface
func (cl ColumnList) Name() string { return "" }
// TableName is placeholder for ColumnList to implement Column interface
func (cl ColumnList) TableName() string { return "" }
func (cl ColumnList) setTableName(name string) {}
func (cl ColumnList) setSubQuery(subQuery SelectTable) {}
func (cl ColumnList) defaultAlias() string { return "" }
// SetTableName is utility function to set table name from outside of jet package to avoid making public setTableName
func SetTableName(columnExpression ColumnExpression, tableName string) {
columnExpression.setTableName(tableName)
}
// SetSubQuery is utility function to set table name from outside of jet package to avoid making public setSubQuery
func SetSubQuery(columnExpression ColumnExpression, subQuery SelectTable) {
columnExpression.setSubQuery(subQuery)
}

View file

@ -6,6 +6,7 @@ type ColumnBool interface {
Column
From(subQuery SelectTable) ColumnBool
SET(boolExp BoolExpression) ColumnAssigment
}
type boolColumnImpl struct {
@ -21,6 +22,13 @@ func (i *boolColumnImpl) From(subQuery SelectTable) ColumnBool {
return newBoolColumn
}
func (i *boolColumnImpl) SET(boolExp BoolExpression) ColumnAssigment {
return columnAssigmentImpl{
column: i,
expression: boolExp,
}
}
// BoolColumn creates named bool column.
func BoolColumn(name string) ColumnBool {
boolColumn := &boolColumnImpl{}
@ -38,6 +46,7 @@ type ColumnFloat interface {
Column
From(subQuery SelectTable) ColumnFloat
SET(floatExp FloatExpression) ColumnAssigment
}
type floatColumnImpl struct {
@ -53,6 +62,13 @@ func (i *floatColumnImpl) From(subQuery SelectTable) ColumnFloat {
return newFloatColumn
}
func (i *floatColumnImpl) SET(floatExp FloatExpression) ColumnAssigment {
return columnAssigmentImpl{
column: i,
expression: floatExp,
}
}
// FloatColumn creates named float column.
func FloatColumn(name string) ColumnFloat {
floatColumn := &floatColumnImpl{}
@ -70,6 +86,7 @@ type ColumnInteger interface {
Column
From(subQuery SelectTable) ColumnInteger
SET(intExp IntegerExpression) ColumnAssigment
}
type integerColumnImpl struct {
@ -86,6 +103,13 @@ func (i *integerColumnImpl) From(subQuery SelectTable) ColumnInteger {
return newIntColumn
}
func (i *integerColumnImpl) SET(intExp IntegerExpression) ColumnAssigment {
return columnAssigmentImpl{
column: i,
expression: intExp,
}
}
// IntegerColumn creates named integer column.
func IntegerColumn(name string) ColumnInteger {
integerColumn := &integerColumnImpl{}
@ -104,6 +128,7 @@ type ColumnString interface {
Column
From(subQuery SelectTable) ColumnString
SET(stringExp StringExpression) ColumnAssigment
}
type stringColumnImpl struct {
@ -120,6 +145,13 @@ func (i *stringColumnImpl) From(subQuery SelectTable) ColumnString {
return newStrColumn
}
func (i *stringColumnImpl) SET(stringExp StringExpression) ColumnAssigment {
return columnAssigmentImpl{
column: i,
expression: stringExp,
}
}
// StringColumn creates named string column.
func StringColumn(name string) ColumnString {
stringColumn := &stringColumnImpl{}
@ -137,6 +169,7 @@ type ColumnTime interface {
Column
From(subQuery SelectTable) ColumnTime
SET(timeExp TimeExpression) ColumnAssigment
}
type timeColumnImpl struct {
@ -152,6 +185,13 @@ func (i *timeColumnImpl) From(subQuery SelectTable) ColumnTime {
return newTimeColumn
}
func (i *timeColumnImpl) SET(timeExp TimeExpression) ColumnAssigment {
return columnAssigmentImpl{
column: i,
expression: timeExp,
}
}
// TimeColumn creates named time column
func TimeColumn(name string) ColumnTime {
timeColumn := &timeColumnImpl{}
@ -183,6 +223,13 @@ func (i *timezColumnImpl) From(subQuery SelectTable) ColumnTimez {
return newTimezColumn
}
func (i *timezColumnImpl) SET(timezExp TimezExpression) ColumnAssigment {
return columnAssigmentImpl{
column: i,
expression: timezExp,
}
}
// TimezColumn creates named time with time zone column.
func TimezColumn(name string) ColumnTimez {
timezColumn := &timezColumnImpl{}
@ -200,6 +247,7 @@ type ColumnTimestamp interface {
Column
From(subQuery SelectTable) ColumnTimestamp
SET(timestampExp TimestampExpression) ColumnAssigment
}
type timestampColumnImpl struct {
@ -215,6 +263,13 @@ func (i *timestampColumnImpl) From(subQuery SelectTable) ColumnTimestamp {
return newTimestampColumn
}
func (i *timestampColumnImpl) SET(timestampExp TimestampExpression) ColumnAssigment {
return columnAssigmentImpl{
column: i,
expression: timestampExp,
}
}
// TimestampColumn creates named timestamp column
func TimestampColumn(name string) ColumnTimestamp {
timestampColumn := &timestampColumnImpl{}
@ -232,6 +287,7 @@ type ColumnTimestampz interface {
Column
From(subQuery SelectTable) ColumnTimestampz
SET(timestampzExp TimestampzExpression) ColumnAssigment
}
type timestampzColumnImpl struct {
@ -247,6 +303,13 @@ func (i *timestampzColumnImpl) From(subQuery SelectTable) ColumnTimestampz {
return newTimestampzColumn
}
func (i *timestampzColumnImpl) SET(timestampzExp TimestampzExpression) ColumnAssigment {
return columnAssigmentImpl{
column: i,
expression: timestampzExp,
}
}
// TimestampzColumn creates named timestamp with time zone column.
func TimestampzColumn(name string) ColumnTimestampz {
timestampzColumn := &timestampzColumnImpl{}
@ -264,6 +327,7 @@ type ColumnDate interface {
Column
From(subQuery SelectTable) ColumnDate
SET(dateExp DateExpression) ColumnAssigment
}
type dateColumnImpl struct {
@ -279,6 +343,13 @@ func (i *dateColumnImpl) From(subQuery SelectTable) ColumnDate {
return newDateColumn
}
func (i *dateColumnImpl) SET(dateExp DateExpression) ColumnAssigment {
return columnAssigmentImpl{
column: i,
expression: dateExp,
}
}
// DateColumn creates named date column.
func DateColumn(name string) ColumnDate {
dateColumn := &dateColumnImpl{}

View file

@ -24,12 +24,12 @@ import (
func AssertExec(t *testing.T, stmt jet.Statement, db qrm.DB, rowsAffected ...int64) {
res, err := stmt.Exec(db)
assert.NoError(t, err)
require.NoError(t, err)
rows, err := res.RowsAffected()
assert.NoError(t, err)
require.NoError(t, err)
if len(rowsAffected) > 0 {
assert.Equal(t, rows, rowsAffected[0])
require.Equal(t, rows, rowsAffected[0])
}
}
@ -224,7 +224,7 @@ func AssertFileNamesEqual(t *testing.T, fileInfos []os.FileInfo, fileNames ...st
// AssertDeepEqual checks if actual and expected objects are deeply equal.
func AssertDeepEqual(t *testing.T, actual, expected interface{}, msg ...string) {
assert.True(t, cmp.Equal(actual, expected), msg)
require.True(t, cmp.Equal(actual, expected), msg)
}
// BoolPtr returns address of bool parameter

View file

@ -13,13 +13,15 @@ type InsertStatement interface {
MODEL(data interface{}) InsertStatement
MODELS(data interface{}) InsertStatement
ON_DUPLICATE_KEY_UPDATE(assigments ...ColumnAssigment) InsertStatement
QUERY(selectStatement SelectStatement) InsertStatement
}
func newInsertStatement(table Table, columns []jet.Column) InsertStatement {
newInsert := &insertStatementImpl{}
newInsert.SerializerStatement = jet.NewStatementImpl(Dialect, jet.InsertStatementType, newInsert,
&newInsert.Insert, &newInsert.ValuesQuery)
&newInsert.Insert, &newInsert.ValuesQuery, &newInsert.OnDuplicateKey)
newInsert.Insert.Table = table
newInsert.Insert.Columns = columns
@ -32,24 +34,53 @@ type insertStatementImpl struct {
Insert jet.ClauseInsert
ValuesQuery jet.ClauseValuesQuery
OnDuplicateKey onDuplicateKeyUpdateClause
}
func (i *insertStatementImpl) VALUES(value interface{}, values ...interface{}) InsertStatement {
i.ValuesQuery.Rows = append(i.ValuesQuery.Rows, jet.UnwindRowFromValues(value, values))
return i
func (is *insertStatementImpl) VALUES(value interface{}, values ...interface{}) InsertStatement {
is.ValuesQuery.Rows = append(is.ValuesQuery.Rows, jet.UnwindRowFromValues(value, values))
return is
}
func (i *insertStatementImpl) MODEL(data interface{}) InsertStatement {
i.ValuesQuery.Rows = append(i.ValuesQuery.Rows, jet.UnwindRowFromModel(i.Insert.GetColumns(), data))
return i
func (is *insertStatementImpl) MODEL(data interface{}) InsertStatement {
is.ValuesQuery.Rows = append(is.ValuesQuery.Rows, jet.UnwindRowFromModel(is.Insert.GetColumns(), data))
return is
}
func (i *insertStatementImpl) MODELS(data interface{}) InsertStatement {
i.ValuesQuery.Rows = append(i.ValuesQuery.Rows, jet.UnwindRowsFromModels(i.Insert.GetColumns(), data)...)
return i
func (is *insertStatementImpl) MODELS(data interface{}) InsertStatement {
is.ValuesQuery.Rows = append(is.ValuesQuery.Rows, jet.UnwindRowsFromModels(is.Insert.GetColumns(), data)...)
return is
}
func (i *insertStatementImpl) QUERY(selectStatement SelectStatement) InsertStatement {
i.ValuesQuery.Query = selectStatement
return i
func (is *insertStatementImpl) ON_DUPLICATE_KEY_UPDATE(assigments ...ColumnAssigment) InsertStatement {
is.OnDuplicateKey = assigments
return is
}
func (is *insertStatementImpl) QUERY(selectStatement SelectStatement) InsertStatement {
is.ValuesQuery.Query = selectStatement
return is
}
type onDuplicateKeyUpdateClause []jet.ColumnAssigment
// Serialize for SetClause
func (s onDuplicateKeyUpdateClause) Serialize(statementType jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(s) == 0 {
return
}
out.NewLine()
out.WriteString("ON DUPLICATE KEY UPDATE")
out.IncreaseIdent(24)
for i, assigment := range s {
if i > 0 {
out.WriteString(",")
out.NewLine()
}
jet.Serialize(assigment, statementType, out, jet.ShortName.WithFallTrough(options)...)
}
out.DecreaseIdent(24)
}

View file

@ -133,3 +133,50 @@ VALUES (DEFAULT, ?);
assertStatementSql(t, stmt, expectedSQL, "two")
}
func TestInsertOnDuplicateKeyUpdate(t *testing.T) {
stmt := func() InsertStatement {
return table1.INSERT(table1Col1, table1ColFloat).
VALUES(DEFAULT, "two")
}
t.Run("empty list", func(t *testing.T) {
stmt := stmt().ON_DUPLICATE_KEY_UPDATE()
assertStatementSql(t, stmt, `
INSERT INTO db.table1 (col1, col_float)
VALUES (DEFAULT, ?);
`, "two")
})
t.Run("one set", func(t *testing.T) {
stmt := stmt().ON_DUPLICATE_KEY_UPDATE(table1ColFloat.SET(Float(11.1)))
assertStatementSql(t, stmt, `
INSERT INTO db.table1 (col1, col_float)
VALUES (DEFAULT, ?)
ON DUPLICATE KEY UPDATE col_float = ?;
`, "two", 11.1)
})
t.Run("all types set", func(t *testing.T) {
stmt := stmt().ON_DUPLICATE_KEY_UPDATE(
table1ColBool.SET(Bool(true)),
table1ColInt.SET(Int(11)),
table1ColFloat.SET(Float(11.1)),
table1ColString.SET(String("str")),
table1ColTime.SET(Time(11, 23, 11)),
table1ColTimestamp.SET(Timestamp(2020, 1, 22, 3, 4, 5)),
table1ColDate.SET(Date(2020, 12, 1)),
)
assertStatementSql(t, stmt, `
INSERT INTO db.table1 (col1, col_float)
VALUES (DEFAULT, ?)
ON DUPLICATE KEY UPDATE col_bool = ?,
col_int = ?,
col_float = ?,
col_string = ?,
col_time = CAST(? AS TIME),
col_timestamp = TIMESTAMP(?),
col_date = CAST(? AS DATE);
`, "two", true, int64(11), 11.1, "str", "11:23:11", "2020-01-22 03:04:05", "2020-12-01")
})
}

View file

@ -10,3 +10,6 @@ type Projection = jet.Projection
// ProjectionList can be used to create conditional constructed projection list.
type ProjectionList = jet.ProjectionList
// ColumnAssigment is interface wrapper around column assigment
type ColumnAssigment = jet.ColumnAssigment

View file

@ -7,12 +7,14 @@ import (
)
var table1Col1 = IntegerColumn("col1")
var table1ColBool = BoolColumn("col_bool")
var table1ColInt = IntegerColumn("col_int")
var table1ColFloat = FloatColumn("col_float")
var table1ColString = StringColumn("col_string")
var table1Col3 = IntegerColumn("col3")
var table1ColTimestamp = TimestampColumn("col_timestamp")
var table1ColBool = BoolColumn("col_bool")
var table1ColDate = DateColumn("col_date")
var table1ColTime = TimeColumn("col_time")
var table1 = NewTable(
"db",
@ -20,10 +22,12 @@ var table1 = NewTable(
table1Col1,
table1ColInt,
table1ColFloat,
table1ColString,
table1Col3,
table1ColBool,
table1ColDate,
table1ColTimestamp,
table1ColTime,
)
var table2Col3 = IntegerColumn("col3")

View file

@ -21,9 +21,10 @@ ON CONFLICT (col_bool) DO NOTHING`)
ON CONFLICT (col_bool) ON CONSTRAINT table_pkey DO NOTHING`)
onConflict = &onConflictClause{indexExpressions: ColumnList{table1ColBool, table2ColFloat}}
onConflict.WHERE(table2ColFloat.ADD(table1ColInt).GT(table1ColFloat)).DO_UPDATE(
SET(table1ColBool, Bool(true)).
SET(table1ColInt, Int(1)).
onConflict.WHERE(table2ColFloat.ADD(table1ColInt).GT(table1ColFloat)).
DO_UPDATE(
SET(table1ColBool.SET(Bool(true)),
table1ColInt.SET(Int(11))).
WHERE(table2ColFloat.GT(Float(11.1))),
)
assertClauseSerialize(t, onConflict, `

View file

@ -4,16 +4,15 @@ import "github.com/go-jet/jet/internal/jet"
type conflictAction interface {
jet.Serializer
SET(column jet.ColumnSerializer, expression interface{}) conflictAction
WHERE(condition BoolExpression) conflictAction
}
// SET creates conflict action for ON_CONFLICT clause
func SET(column jet.ColumnSerializer, expression interface{}) conflictAction {
func SET(assigments ...ColumnAssigment) conflictAction {
conflictAction := updateConflictActionImpl{}
conflictAction.doUpdate = jet.KeywordClause{Keyword: "DO UPDATE"}
conflictAction.Serializer = jet.NewSerializerClauseImpl(&conflictAction.doUpdate, &conflictAction.set, &conflictAction.where)
conflictAction.SET(column, expression)
conflictAction.set = assigments
return &conflictAction
}
@ -25,11 +24,6 @@ type updateConflictActionImpl struct {
where jet.ClauseWhere
}
func (u *updateConflictActionImpl) SET(column jet.ColumnSerializer, expression interface{}) conflictAction {
u.set = append(u.set, jet.SetPair{Column: column, Value: jet.ToSerializerValue(expression)})
return u
}
func (u *updateConflictActionImpl) WHERE(condition BoolExpression) conflictAction {
u.where.Condition = condition
return u

View file

@ -153,10 +153,10 @@ func TestInsert_ON_CONFLICT(t *testing.T) {
VALUES("1", "2").
VALUES("theta", "beta").
ON_CONFLICT(table1ColBool).WHERE(table1ColBool.IS_NOT_FALSE()).DO_UPDATE(
SET(table1ColBool, "12").
SET(table2ColInt, 1).
SET(ColumnList{table1Col1, table1ColBool}, jet.ROW(Int(2), String("two"))).
WHERE(table1Col1.GT(Int(2))),
SET(table1ColBool.SET(Bool(true)),
table2ColInt.SET(Int(1)),
ColumnList{table1Col1, table1ColBool}.SET(jet.ROW(Int(2), String("two"))),
).WHERE(table1Col1.GT(Int(2))),
).
RETURNING(table1Col1, table1ColBool)
@ -166,7 +166,7 @@ VALUES ('one', 'two'),
('1', '2'),
('theta', 'beta')
ON CONFLICT (col_bool) WHERE col_bool IS NOT FALSE DO UPDATE
SET col_bool = '12',
SET col_bool = TRUE,
col_int = 1,
(col1, col_bool) = ROW(2, 'two')
WHERE table1.col1 > 2
@ -180,10 +180,10 @@ func TestInsert_ON_CONFLICT_ON_CONSTRAINT(t *testing.T) {
VALUES("one", "two").
VALUES("1", "2").
ON_CONFLICT().ON_CONSTRAINT("idk_primary_key").DO_UPDATE(
SET(table1ColBool, "12").
SET(table2ColInt, 1).
SET(ColumnList{table1Col1, table1ColBool}, jet.ROW(Int(2), String("two"))).
WHERE(table1Col1.GT(Int(2))),
SET(table1ColBool.SET(Bool(false)),
table2ColInt.SET(Int(1)),
ColumnList{table1Col1, table1ColBool}.SET(jet.ROW(Int(2), String("two"))),
).WHERE(table1Col1.GT(Int(2))),
).
RETURNING(table1Col1, table1ColBool)
@ -192,7 +192,7 @@ INSERT INTO db.table1 (col1, col_bool)
VALUES ('one', 'two'),
('1', '2')
ON CONFLICT ON CONSTRAINT idk_primary_key DO UPDATE
SET col_bool = '12',
SET col_bool = FALSE,
col_int = 1,
(col1, col_bool) = ROW(2, 'two')
WHERE table1.col1 > 2

View file

@ -10,3 +10,6 @@ type Projection = jet.Projection
// ProjectionList can be used to create conditional constructed projection list.
type ProjectionList = jet.ProjectionList
// ColumnAssigment is interface wrapper around column assigment
type ColumnAssigment = jet.ColumnAssigment

View file

@ -991,6 +991,53 @@ func TestAllTypesInsert(t *testing.T) {
require.NoError(t, err)
}
func TestAllTypesInsertOnDuplicateKeyUpdate(t *testing.T) {
tx, err := db.Begin()
require.NoError(t, err)
toInsert := model.AllTypes{
Boolean: true,
Integer: 124,
Float: 45.67,
Blob: []byte("blob"),
Text: "text",
JSON: "{}",
Time: time.Now(),
Timestamp: time.Now(),
Date: time.Now(),
}
stmt := AllTypes.INSERT(
AllTypes.Boolean,
AllTypes.Integer,
AllTypes.Float,
AllTypes.Blob,
AllTypes.Text,
AllTypes.JSON,
AllTypes.Time,
AllTypes.Timestamp,
AllTypes.Date,
).
MODEL(toInsert).
ON_DUPLICATE_KEY_UPDATE(
AllTypes.Boolean.SET(Bool(false)),
AllTypes.Integer.SET(Int(4)),
AllTypes.Float.SET(Float(0.67)),
AllTypes.Text.SET(String("new text")),
AllTypes.Time.SET(TimeT(time.Now())),
AllTypes.Timestamp.SET(TimestampT(time.Now())),
AllTypes.Date.SET(DateT(time.Now())),
)
fmt.Println(stmt.DebugSql())
_, err = stmt.Exec(tx)
assert.NoError(t, err)
err = tx.Rollback()
require.NoError(t, err)
}
var toInsert = model.AllTypes{
Boolean: false,
BooleanPtr: testutils.BoolPtr(true),

View file

@ -7,6 +7,8 @@ import (
"github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/model"
. "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/table"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"math/rand"
"testing"
"time"
)
@ -248,6 +250,46 @@ INSERT INTO test_sample.link (url, name) (
assert.Equal(t, len(youtubeLinks), 2)
}
func TestInsertOnDuplicateKey(t *testing.T) {
randId := rand.Int31()
stmt := Link.INSERT().
VALUES(randId, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT).
VALUES(randId, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT).
ON_DUPLICATE_KEY_UPDATE(
Link.ID.SET(Link.ID.ADD(Int(11))),
Link.Name.SET(String("PostgreSQL Tutorial 2")),
)
testutils.AssertStatementSql(t, stmt, `
INSERT INTO test_sample.link
VALUES (?, ?, ?, DEFAULT),
(?, ?, ?, DEFAULT)
ON DUPLICATE KEY UPDATE id = (id + ?),
name = ?;
`, randId, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial",
randId, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial",
int64(11), "PostgreSQL Tutorial 2")
testutils.AssertExec(t, stmt, db, 3)
newLinks := []model.Link{}
err := SELECT(Link.AllColumns).
FROM(Link).
WHERE(Link.ID.EQ(Int(int64(randId)).ADD(Int(11)))).
Query(db, &newLinks)
require.NoError(t, err)
require.Len(t, newLinks, 1)
require.Equal(t, newLinks[0], model.Link{
ID: randId + 11,
URL: "http://www.postgresqltutorial.com",
Name: "PostgreSQL Tutorial 2",
Description: nil,
})
}
func TestInsertWithQueryContext(t *testing.T) {
cleanUpLinkTable(t)

View file

@ -4,6 +4,8 @@ import (
"database/sql"
"flag"
"github.com/go-jet/jet/tests/dbconfig"
"math/rand"
"time"
_ "github.com/go-sql-driver/mysql"
@ -28,6 +30,7 @@ func sourceIsMariaDB() bool {
}
func TestMain(m *testing.M) {
rand.Seed(time.Now().Unix())
defer profile.Start().Stop()
var err error

View file

@ -2,7 +2,6 @@ package postgres
import (
"context"
"github.com/go-jet/jet/internal/jet"
"github.com/go-jet/jet/internal/testutils"
. "github.com/go-jet/jet/postgres"
"github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/model"
@ -134,8 +133,10 @@ ON CONFLICT ON CONSTRAINT employee_pkey DO NOTHING;
VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT).
VALUES(200, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT).
ON_CONFLICT(Link.ID).DO_UPDATE(
SET(Link.ID, Link.EXCLUDED.ID).
SET(Link.URL, "http://www.postgresqltutorial2.com"),
SET(
Link.ID.SET(Link.EXCLUDED.ID),
Link.URL.SET(String("http://www.postgresqltutorial2.com")),
),
).
RETURNING(Link.AllColumns)
@ -161,8 +162,10 @@ RETURNING link.id AS "link.id",
VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT).
VALUES(200, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT).
ON_CONFLICT().ON_CONSTRAINT("link_pkey").DO_UPDATE(
SET(Link.ID, Link.EXCLUDED.ID).
SET(Link.URL, "http://www.postgresqltutorial2.com"),
SET(
Link.ID.SET(Link.EXCLUDED.ID),
Link.URL.SET(String("http://www.postgresqltutorial2.com")),
),
).
RETURNING(Link.AllColumns)
@ -188,9 +191,13 @@ RETURNING link.id AS "link.id",
stmt := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description).
VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT).
ON_CONFLICT(Link.ID).WHERE(Link.ID.MUL(Int(2)).GT(Int(10))).DO_UPDATE(
SET(Link.ID, SELECT(MAXi(Link.ID).ADD(Int(1))).FROM(Link)).
SET(ColumnList{Link.Name, Link.Description}, jet.ROW(Link.EXCLUDED.Name, String("new description"))).
WHERE(Link.Description.IS_NOT_NULL()),
SET(
Link.ID.SET(
IntExp(SELECT(MAXi(Link.ID).ADD(Int(1))).
FROM(Link)),
),
ColumnList{Link.Name, Link.Description}.SET(ROW(Link.EXCLUDED.Name, String("new description"))),
).WHERE(Link.Description.IS_NOT_NULL()),
)
testutils.AssertDebugStatementSql(t, stmt, `