[postgres] Add support for ON CONFLICT clause
This commit is contained in:
parent
eea776a1ac
commit
14e1863456
42 changed files with 827 additions and 277 deletions
87
postgres/clause.go
Normal file
87
postgres/clause.go
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
package postgres
|
||||
|
||||
import (
|
||||
"github.com/go-jet/jet/internal/jet"
|
||||
)
|
||||
|
||||
type clauseReturning struct {
|
||||
Projections []jet.Projection
|
||||
}
|
||||
|
||||
func (r *clauseReturning) Serialize(statementType jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
|
||||
if len(r.Projections) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
out.NewLine()
|
||||
out.WriteString("RETURNING")
|
||||
out.IncreaseIdent()
|
||||
out.WriteProjections(statementType, r.Projections)
|
||||
}
|
||||
|
||||
// ========================================== //
|
||||
|
||||
type onConflict interface {
|
||||
ON_CONSTRAINT(name string) conflictTarget
|
||||
WHERE(indexPredicate BoolExpression) conflictTarget
|
||||
DO_NOTHING() InsertStatement
|
||||
DO_UPDATE(action conflictAction) InsertStatement
|
||||
}
|
||||
|
||||
type conflictTarget interface {
|
||||
DO_NOTHING() InsertStatement
|
||||
DO_UPDATE(action conflictAction) InsertStatement
|
||||
}
|
||||
|
||||
type onConflictClause struct {
|
||||
insertStatement InsertStatement
|
||||
constraint string
|
||||
indexExpressions []jet.ColumnExpression
|
||||
whereClause jet.ClauseWhere
|
||||
do jet.Serializer
|
||||
}
|
||||
|
||||
func (o *onConflictClause) ON_CONSTRAINT(name string) conflictTarget {
|
||||
o.constraint = name
|
||||
return o
|
||||
}
|
||||
|
||||
func (o *onConflictClause) WHERE(indexPredicate BoolExpression) conflictTarget {
|
||||
o.whereClause.Condition = indexPredicate
|
||||
return o
|
||||
}
|
||||
|
||||
func (o *onConflictClause) DO_NOTHING() InsertStatement {
|
||||
o.do = jet.Keyword("DO NOTHING")
|
||||
return o.insertStatement
|
||||
}
|
||||
|
||||
func (o *onConflictClause) DO_UPDATE(action conflictAction) InsertStatement {
|
||||
o.do = action
|
||||
return o.insertStatement
|
||||
}
|
||||
|
||||
func (o *onConflictClause) Serialize(statementType jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
|
||||
if len(o.indexExpressions) == 0 && o.constraint == "" {
|
||||
return
|
||||
}
|
||||
|
||||
out.NewLine()
|
||||
out.WriteString("ON CONFLICT")
|
||||
if len(o.indexExpressions) > 0 {
|
||||
out.WriteString("(")
|
||||
jet.SerializeColumnExpressionNames(o.indexExpressions, statementType, out, jet.ShortName)
|
||||
out.WriteString(")")
|
||||
}
|
||||
|
||||
if o.constraint != "" {
|
||||
out.WriteString("ON CONSTRAINT")
|
||||
out.WriteString(o.constraint)
|
||||
}
|
||||
|
||||
o.whereClause.Serialize(statementType, out, jet.SkipNewLine, jet.ShortName)
|
||||
|
||||
out.IncreaseIdent(7)
|
||||
jet.Serialize(o.do, statementType, out)
|
||||
out.DecreaseIdent(7)
|
||||
}
|
||||
34
postgres/clause_test.go
Normal file
34
postgres/clause_test.go
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
package postgres
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestOnConflict(t *testing.T) {
|
||||
|
||||
assertClauseSerialize(t, &onConflictClause{}, "")
|
||||
|
||||
onConflict := &onConflictClause{}
|
||||
onConflict.DO_NOTHING()
|
||||
assertClauseSerialize(t, onConflict, "")
|
||||
|
||||
onConflict = &onConflictClause{indexExpressions: ColumnList{table1ColBool}}
|
||||
onConflict.DO_NOTHING()
|
||||
assertClauseSerialize(t, onConflict, `
|
||||
ON CONFLICT (col_bool) DO NOTHING`)
|
||||
|
||||
onConflict = &onConflictClause{indexExpressions: ColumnList{table1ColBool}}
|
||||
onConflict.ON_CONSTRAINT("table_pkey").DO_NOTHING()
|
||||
assertClauseSerialize(t, onConflict, `
|
||||
ON CONFLICT (col_bool) ON CONSTRAINT table_pkey DO NOTHING`)
|
||||
|
||||
onConflict = &onConflictClause{indexExpressions: ColumnList{table1ColBool, table2ColFloat}}
|
||||
onConflict.WHERE(table2ColFloat.ADD(table1ColInt).GT(table1ColFloat)).DO_UPDATE(
|
||||
SET(table1ColBool, Bool(true)).
|
||||
SET(table1ColInt, Int(1)).
|
||||
WHERE(table2ColFloat.GT(Float(11.1))),
|
||||
)
|
||||
assertClauseSerialize(t, onConflict, `
|
||||
ON CONFLICT (col_bool, col_float) WHERE (col_float + col_int) > col_float DO UPDATE
|
||||
SET col_bool = $1,
|
||||
col_int = $2
|
||||
WHERE table2.col_float > $3`)
|
||||
}
|
||||
|
|
@ -1,20 +0,0 @@
|
|||
package postgres
|
||||
|
||||
import (
|
||||
"github.com/go-jet/jet/internal/jet"
|
||||
)
|
||||
|
||||
type clauseReturning struct {
|
||||
Projections []jet.Projection
|
||||
}
|
||||
|
||||
func (r *clauseReturning) Serialize(statementType jet.StatementType, out *jet.SQLBuilder) {
|
||||
if len(r.Projections) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
out.NewLine()
|
||||
out.WriteString("RETURNING")
|
||||
out.IncreaseIdent()
|
||||
out.WriteProjections(statementType, r.Projections)
|
||||
}
|
||||
36
postgres/conflict_action.go
Normal file
36
postgres/conflict_action.go
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
package postgres
|
||||
|
||||
import "github.com/go-jet/jet/internal/jet"
|
||||
|
||||
type conflictAction interface {
|
||||
jet.Serializer
|
||||
SET(column jet.ColumnSerializer, expression interface{}) conflictAction
|
||||
WHERE(condition BoolExpression) conflictAction
|
||||
}
|
||||
|
||||
// SET creates conflict action for ON_CONFLICT clause
|
||||
func SET(column jet.ColumnSerializer, expression interface{}) conflictAction {
|
||||
conflictAction := updateConflictActionImpl{}
|
||||
conflictAction.doUpdate = jet.KeywordClause{Keyword: "DO UPDATE"}
|
||||
conflictAction.Serializer = jet.NewSerializerClauseImpl(&conflictAction.doUpdate, &conflictAction.set, &conflictAction.where)
|
||||
conflictAction.SET(column, expression)
|
||||
return &conflictAction
|
||||
}
|
||||
|
||||
type updateConflictActionImpl struct {
|
||||
jet.Serializer
|
||||
|
||||
doUpdate jet.KeywordClause
|
||||
set jet.SetClause
|
||||
where jet.ClauseWhere
|
||||
}
|
||||
|
||||
func (u *updateConflictActionImpl) SET(column jet.ColumnSerializer, expression interface{}) conflictAction {
|
||||
u.set = append(u.set, jet.SetPair{Column: column, Value: jet.ToSerializerValue(expression)})
|
||||
return u
|
||||
}
|
||||
|
||||
func (u *updateConflictActionImpl) WHERE(condition BoolExpression) conflictAction {
|
||||
u.where.Condition = condition
|
||||
return u
|
||||
}
|
||||
|
|
@ -11,18 +11,18 @@ type InsertStatement interface {
|
|||
// Insert row of values, where value for each column is extracted from filed of structure data.
|
||||
// If data is not struct or there is no field for every column selected, this method will panic.
|
||||
MODEL(data interface{}) InsertStatement
|
||||
|
||||
MODELS(data interface{}) InsertStatement
|
||||
|
||||
QUERY(selectStatement SelectStatement) InsertStatement
|
||||
|
||||
RETURNING(projections ...jet.Projection) InsertStatement
|
||||
ON_CONFLICT(indexExpressions ...jet.ColumnExpression) onConflict
|
||||
|
||||
RETURNING(projections ...Projection) InsertStatement
|
||||
}
|
||||
|
||||
func newInsertStatement(table WritableTable, columns []jet.Column) InsertStatement {
|
||||
newInsert := &insertStatementImpl{}
|
||||
newInsert.SerializerStatement = jet.NewStatementImpl(Dialect, jet.InsertStatementType, newInsert,
|
||||
&newInsert.Insert, &newInsert.ValuesQuery, &newInsert.Returning)
|
||||
&newInsert.Insert, &newInsert.ValuesQuery, &newInsert.OnConflict, &newInsert.Returning)
|
||||
|
||||
newInsert.Insert.Table = table
|
||||
newInsert.Insert.Columns = columns
|
||||
|
|
@ -36,6 +36,7 @@ type insertStatementImpl struct {
|
|||
Insert jet.ClauseInsert
|
||||
ValuesQuery jet.ClauseValuesQuery
|
||||
Returning clauseReturning
|
||||
OnConflict onConflictClause
|
||||
}
|
||||
|
||||
func (i *insertStatementImpl) VALUES(value interface{}, values ...interface{}) InsertStatement {
|
||||
|
|
@ -62,3 +63,11 @@ func (i *insertStatementImpl) QUERY(selectStatement SelectStatement) InsertState
|
|||
i.ValuesQuery.Query = selectStatement
|
||||
return i
|
||||
}
|
||||
|
||||
func (i *insertStatementImpl) ON_CONFLICT(indexExpressions ...jet.ColumnExpression) onConflict {
|
||||
i.OnConflict = onConflictClause{
|
||||
insertStatement: i,
|
||||
indexExpressions: indexExpressions,
|
||||
}
|
||||
return &i.OnConflict
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package postgres
|
||||
|
||||
import (
|
||||
"github.com/go-jet/jet/internal/jet"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
"time"
|
||||
|
|
@ -13,15 +14,15 @@ func TestInvalidInsert(t *testing.T) {
|
|||
|
||||
func TestInsertNilValue(t *testing.T) {
|
||||
assertStatementSql(t, table1.INSERT(table1Col1).VALUES(nil), `
|
||||
INSERT INTO db.table1 (col1) VALUES
|
||||
($1);
|
||||
INSERT INTO db.table1 (col1)
|
||||
VALUES ($1);
|
||||
`, nil)
|
||||
}
|
||||
|
||||
func TestInsertSingleValue(t *testing.T) {
|
||||
assertStatementSql(t, table1.INSERT(table1Col1).VALUES(1), `
|
||||
INSERT INTO db.table1 (col1) VALUES
|
||||
($1);
|
||||
INSERT INTO db.table1 (col1)
|
||||
VALUES ($1);
|
||||
`, int(1))
|
||||
}
|
||||
|
||||
|
|
@ -29,8 +30,8 @@ func TestInsertWithColumnList(t *testing.T) {
|
|||
columnList := ColumnList{table3ColInt, table3StrCol}
|
||||
|
||||
assertStatementSql(t, table3.INSERT(columnList).VALUES(1, 3), `
|
||||
INSERT INTO db.table3 (col_int, col2) VALUES
|
||||
($1, $2);
|
||||
INSERT INTO db.table3 (col_int, col2)
|
||||
VALUES ($1, $2);
|
||||
`, 1, 3)
|
||||
}
|
||||
|
||||
|
|
@ -38,15 +39,15 @@ func TestInsertDate(t *testing.T) {
|
|||
date := time.Date(1999, 1, 2, 3, 4, 5, 0, time.UTC)
|
||||
|
||||
assertStatementSql(t, table1.INSERT(table1ColTime).VALUES(date), `
|
||||
INSERT INTO db.table1 (col_time) VALUES
|
||||
($1);
|
||||
INSERT INTO db.table1 (col_time)
|
||||
VALUES ($1);
|
||||
`, date)
|
||||
}
|
||||
|
||||
func TestInsertMultipleValues(t *testing.T) {
|
||||
assertStatementSql(t, table1.INSERT(table1Col1, table1ColFloat, table1Col3).VALUES(1, 2, 3), `
|
||||
INSERT INTO db.table1 (col1, col_float, col3) VALUES
|
||||
($1, $2, $3);
|
||||
assertStatementSql(t, table1.INSERT(table1Col1, table1ColFloat, table1ColBool).VALUES(1, 2, 3), `
|
||||
INSERT INTO db.table1 (col1, col_float, col_bool)
|
||||
VALUES ($1, $2, $3);
|
||||
`, 1, 2, 3)
|
||||
}
|
||||
|
||||
|
|
@ -57,10 +58,10 @@ func TestInsertMultipleRows(t *testing.T) {
|
|||
VALUES(111, 222)
|
||||
|
||||
assertStatementSql(t, stmt, `
|
||||
INSERT INTO db.table1 (col1, col_float) VALUES
|
||||
($1, $2),
|
||||
($3, $4),
|
||||
($5, $6);
|
||||
INSERT INTO db.table1 (col1, col_float)
|
||||
VALUES ($1, $2),
|
||||
($3, $4),
|
||||
($5, $6);
|
||||
`, 1, 2, 11, 22, 111, 222)
|
||||
}
|
||||
|
||||
|
|
@ -82,12 +83,12 @@ func TestInsertValuesFromModel(t *testing.T) {
|
|||
MODEL(&toInsert)
|
||||
|
||||
expectedSQL := `
|
||||
INSERT INTO db.table1 (col1, col_float) VALUES
|
||||
($1, $2),
|
||||
($3, $4);
|
||||
INSERT INTO db.table1 (col1, col_float)
|
||||
VALUES ($1, $2),
|
||||
($3, $4);
|
||||
`
|
||||
|
||||
assertStatementSql(t, stmt, expectedSQL, int(1), float64(1.11), int(1), float64(1.11))
|
||||
assertStatementSql(t, stmt, expectedSQL, 1, float64(1.11), 1, float64(1.11))
|
||||
}
|
||||
|
||||
func TestInsertValuesFromModelColumnMismatch(t *testing.T) {
|
||||
|
|
@ -139,9 +140,63 @@ func TestInsertDefaultValue(t *testing.T) {
|
|||
VALUES(DEFAULT, "two")
|
||||
|
||||
var expectedSQL = `
|
||||
INSERT INTO db.table1 (col1, col_float) VALUES
|
||||
(DEFAULT, $1);
|
||||
INSERT INTO db.table1 (col1, col_float)
|
||||
VALUES (DEFAULT, $1);
|
||||
`
|
||||
|
||||
assertStatementSql(t, stmt, expectedSQL, "two")
|
||||
}
|
||||
|
||||
func TestInsert_ON_CONFLICT(t *testing.T) {
|
||||
stmt := table1.INSERT(table1Col1, table1ColBool).
|
||||
VALUES("one", "two").
|
||||
VALUES("1", "2").
|
||||
VALUES("theta", "beta").
|
||||
ON_CONFLICT(table1ColBool).WHERE(table1ColBool.IS_NOT_FALSE()).DO_UPDATE(
|
||||
SET(table1ColBool, "12").
|
||||
SET(table2ColInt, 1).
|
||||
SET(ColumnList{table1Col1, table1ColBool}, jet.ROW(Int(2), String("two"))).
|
||||
WHERE(table1Col1.GT(Int(2))),
|
||||
).
|
||||
RETURNING(table1Col1, table1ColBool)
|
||||
|
||||
assertDebugStatementSql(t, stmt, `
|
||||
INSERT INTO db.table1 (col1, col_bool)
|
||||
VALUES ('one', 'two'),
|
||||
('1', '2'),
|
||||
('theta', 'beta')
|
||||
ON CONFLICT (col_bool) WHERE col_bool IS NOT FALSE DO UPDATE
|
||||
SET col_bool = '12',
|
||||
col_int = 1,
|
||||
(col1, col_bool) = ROW(2, 'two')
|
||||
WHERE table1.col1 > 2
|
||||
RETURNING table1.col1 AS "table1.col1",
|
||||
table1.col_bool AS "table1.col_bool";
|
||||
`)
|
||||
}
|
||||
|
||||
func TestInsert_ON_CONFLICT_ON_CONSTRAINT(t *testing.T) {
|
||||
stmt := table1.INSERT(table1Col1, table1ColBool).
|
||||
VALUES("one", "two").
|
||||
VALUES("1", "2").
|
||||
ON_CONFLICT().ON_CONSTRAINT("idk_primary_key").DO_UPDATE(
|
||||
SET(table1ColBool, "12").
|
||||
SET(table2ColInt, 1).
|
||||
SET(ColumnList{table1Col1, table1ColBool}, jet.ROW(Int(2), String("two"))).
|
||||
WHERE(table1Col1.GT(Int(2))),
|
||||
).
|
||||
RETURNING(table1Col1, table1ColBool)
|
||||
|
||||
assertDebugStatementSql(t, stmt, `
|
||||
INSERT INTO db.table1 (col1, col_bool)
|
||||
VALUES ('one', 'two'),
|
||||
('1', '2')
|
||||
ON CONFLICT ON CONSTRAINT idk_primary_key DO UPDATE
|
||||
SET col_bool = '12',
|
||||
col_int = 1,
|
||||
(col1, col_bool) = ROW(2, 'two')
|
||||
WHERE table1.col1 > 2
|
||||
RETURNING table1.col1 AS "table1.col1",
|
||||
table1.col_bool AS "table1.col_bool";
|
||||
`)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -109,10 +109,10 @@ type tableImpl struct {
|
|||
}
|
||||
|
||||
// NewTable creates new table with schema Name, table Name and list of columns
|
||||
func NewTable(schemaName, name string, column jet.ColumnExpression, columns ...jet.ColumnExpression) Table {
|
||||
func NewTable(schemaName, name string, columns ...jet.ColumnExpression) Table {
|
||||
|
||||
t := &tableImpl{
|
||||
SerializerTable: jet.NewTable(schemaName, name, column, columns...),
|
||||
SerializerTable: jet.NewTable(schemaName, name, columns...),
|
||||
}
|
||||
|
||||
t.readableTableInterfaceImpl.parent = t
|
||||
|
|
|
|||
|
|
@ -5,9 +5,9 @@ import (
|
|||
)
|
||||
|
||||
func TestJoinNilInputs(t *testing.T) {
|
||||
assertClauseSerializeErr(t, table2.INNER_JOIN(nil, table1ColBool.EQ(table2ColBool)),
|
||||
assertSerializeErr(t, table2.INNER_JOIN(nil, table1ColBool.EQ(table2ColBool)),
|
||||
"jet: right hand side of join operation is nil table")
|
||||
assertClauseSerializeErr(t, table2.INNER_JOIN(table1, nil),
|
||||
assertSerializeErr(t, table2.INNER_JOIN(table1, nil),
|
||||
"jet: join condition is nil")
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -61,7 +61,7 @@ type clauseSet struct {
|
|||
Values []jet.Serializer
|
||||
}
|
||||
|
||||
func (s *clauseSet) Serialize(statementType jet.StatementType, out *jet.SQLBuilder) {
|
||||
func (s *clauseSet) Serialize(statementType jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
|
||||
out.NewLine()
|
||||
out.WriteString("SET")
|
||||
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ import (
|
|||
var table1Col1 = IntegerColumn("col1")
|
||||
var table1ColInt = IntegerColumn("col_int")
|
||||
var table1ColFloat = FloatColumn("col_float")
|
||||
var table1Col3 = IntegerColumn("col3")
|
||||
var table1ColTime = TimeColumn("col_time")
|
||||
var table1ColTimez = TimezColumn("col_timez")
|
||||
var table1ColTimestamp = TimestampColumn("col_timestamp")
|
||||
|
|
@ -25,7 +24,6 @@ var table1 = NewTable(
|
|||
table1Col1,
|
||||
table1ColInt,
|
||||
table1ColFloat,
|
||||
table1Col3,
|
||||
table1ColTime,
|
||||
table1ColTimez,
|
||||
table1ColBool,
|
||||
|
|
@ -75,12 +73,16 @@ var table3 = NewTable(
|
|||
table3ColInt,
|
||||
table3StrCol)
|
||||
|
||||
func assertSerialize(t *testing.T, clause jet.Serializer, query string, args ...interface{}) {
|
||||
func assertSerialize(t *testing.T, serializer jet.Serializer, query string, args ...interface{}) {
|
||||
testutils.AssertSerialize(t, Dialect, serializer, query, args...)
|
||||
}
|
||||
|
||||
func assertClauseSerialize(t *testing.T, clause jet.Clause, query string, args ...interface{}) {
|
||||
testutils.AssertClauseSerialize(t, Dialect, clause, query, args...)
|
||||
}
|
||||
|
||||
func assertClauseSerializeErr(t *testing.T, clause jet.Serializer, errString string) {
|
||||
testutils.AssertClauseSerializeErr(t, Dialect, clause, errString)
|
||||
func assertSerializeErr(t *testing.T, serializer jet.Serializer, errString string) {
|
||||
testutils.AssertSerializeErr(t, Dialect, serializer, errString)
|
||||
}
|
||||
|
||||
func assertProjectionSerialize(t *testing.T, projection jet.Projection, query string, args ...interface{}) {
|
||||
|
|
@ -88,5 +90,6 @@ func assertProjectionSerialize(t *testing.T, projection jet.Projection, query st
|
|||
}
|
||||
|
||||
var assertStatementSql = testutils.AssertStatementSql
|
||||
var assertDebugStatementSql = testutils.AssertDebugStatementSql
|
||||
var assertStatementSqlErr = testutils.AssertStatementSqlErr
|
||||
var assertPanicErr = testutils.AssertPanicErr
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue