[postgres] Add support for ON CONFLICT clause

This commit is contained in:
go-jet 2020-04-12 18:53:57 +02:00
parent eea776a1ac
commit 14e1863456
42 changed files with 827 additions and 277 deletions

87
postgres/clause.go Normal file
View file

@ -0,0 +1,87 @@
package postgres
import (
"github.com/go-jet/jet/internal/jet"
)
type clauseReturning struct {
Projections []jet.Projection
}
func (r *clauseReturning) Serialize(statementType jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(r.Projections) == 0 {
return
}
out.NewLine()
out.WriteString("RETURNING")
out.IncreaseIdent()
out.WriteProjections(statementType, r.Projections)
}
// ========================================== //
type onConflict interface {
ON_CONSTRAINT(name string) conflictTarget
WHERE(indexPredicate BoolExpression) conflictTarget
DO_NOTHING() InsertStatement
DO_UPDATE(action conflictAction) InsertStatement
}
type conflictTarget interface {
DO_NOTHING() InsertStatement
DO_UPDATE(action conflictAction) InsertStatement
}
type onConflictClause struct {
insertStatement InsertStatement
constraint string
indexExpressions []jet.ColumnExpression
whereClause jet.ClauseWhere
do jet.Serializer
}
func (o *onConflictClause) ON_CONSTRAINT(name string) conflictTarget {
o.constraint = name
return o
}
func (o *onConflictClause) WHERE(indexPredicate BoolExpression) conflictTarget {
o.whereClause.Condition = indexPredicate
return o
}
func (o *onConflictClause) DO_NOTHING() InsertStatement {
o.do = jet.Keyword("DO NOTHING")
return o.insertStatement
}
func (o *onConflictClause) DO_UPDATE(action conflictAction) InsertStatement {
o.do = action
return o.insertStatement
}
func (o *onConflictClause) Serialize(statementType jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(o.indexExpressions) == 0 && o.constraint == "" {
return
}
out.NewLine()
out.WriteString("ON CONFLICT")
if len(o.indexExpressions) > 0 {
out.WriteString("(")
jet.SerializeColumnExpressionNames(o.indexExpressions, statementType, out, jet.ShortName)
out.WriteString(")")
}
if o.constraint != "" {
out.WriteString("ON CONSTRAINT")
out.WriteString(o.constraint)
}
o.whereClause.Serialize(statementType, out, jet.SkipNewLine, jet.ShortName)
out.IncreaseIdent(7)
jet.Serialize(o.do, statementType, out)
out.DecreaseIdent(7)
}

34
postgres/clause_test.go Normal file
View file

@ -0,0 +1,34 @@
package postgres
import "testing"
func TestOnConflict(t *testing.T) {
assertClauseSerialize(t, &onConflictClause{}, "")
onConflict := &onConflictClause{}
onConflict.DO_NOTHING()
assertClauseSerialize(t, onConflict, "")
onConflict = &onConflictClause{indexExpressions: ColumnList{table1ColBool}}
onConflict.DO_NOTHING()
assertClauseSerialize(t, onConflict, `
ON CONFLICT (col_bool) DO NOTHING`)
onConflict = &onConflictClause{indexExpressions: ColumnList{table1ColBool}}
onConflict.ON_CONSTRAINT("table_pkey").DO_NOTHING()
assertClauseSerialize(t, onConflict, `
ON CONFLICT (col_bool) ON CONSTRAINT table_pkey DO NOTHING`)
onConflict = &onConflictClause{indexExpressions: ColumnList{table1ColBool, table2ColFloat}}
onConflict.WHERE(table2ColFloat.ADD(table1ColInt).GT(table1ColFloat)).DO_UPDATE(
SET(table1ColBool, Bool(true)).
SET(table1ColInt, Int(1)).
WHERE(table2ColFloat.GT(Float(11.1))),
)
assertClauseSerialize(t, onConflict, `
ON CONFLICT (col_bool, col_float) WHERE (col_float + col_int) > col_float DO UPDATE
SET col_bool = $1,
col_int = $2
WHERE table2.col_float > $3`)
}

View file

@ -1,20 +0,0 @@
package postgres
import (
"github.com/go-jet/jet/internal/jet"
)
type clauseReturning struct {
Projections []jet.Projection
}
func (r *clauseReturning) Serialize(statementType jet.StatementType, out *jet.SQLBuilder) {
if len(r.Projections) == 0 {
return
}
out.NewLine()
out.WriteString("RETURNING")
out.IncreaseIdent()
out.WriteProjections(statementType, r.Projections)
}

View file

@ -0,0 +1,36 @@
package postgres
import "github.com/go-jet/jet/internal/jet"
type conflictAction interface {
jet.Serializer
SET(column jet.ColumnSerializer, expression interface{}) conflictAction
WHERE(condition BoolExpression) conflictAction
}
// SET creates conflict action for ON_CONFLICT clause
func SET(column jet.ColumnSerializer, expression interface{}) conflictAction {
conflictAction := updateConflictActionImpl{}
conflictAction.doUpdate = jet.KeywordClause{Keyword: "DO UPDATE"}
conflictAction.Serializer = jet.NewSerializerClauseImpl(&conflictAction.doUpdate, &conflictAction.set, &conflictAction.where)
conflictAction.SET(column, expression)
return &conflictAction
}
type updateConflictActionImpl struct {
jet.Serializer
doUpdate jet.KeywordClause
set jet.SetClause
where jet.ClauseWhere
}
func (u *updateConflictActionImpl) SET(column jet.ColumnSerializer, expression interface{}) conflictAction {
u.set = append(u.set, jet.SetPair{Column: column, Value: jet.ToSerializerValue(expression)})
return u
}
func (u *updateConflictActionImpl) WHERE(condition BoolExpression) conflictAction {
u.where.Condition = condition
return u
}

View file

@ -11,18 +11,18 @@ type InsertStatement interface {
// Insert row of values, where value for each column is extracted from filed of structure data.
// If data is not struct or there is no field for every column selected, this method will panic.
MODEL(data interface{}) InsertStatement
MODELS(data interface{}) InsertStatement
QUERY(selectStatement SelectStatement) InsertStatement
RETURNING(projections ...jet.Projection) InsertStatement
ON_CONFLICT(indexExpressions ...jet.ColumnExpression) onConflict
RETURNING(projections ...Projection) InsertStatement
}
func newInsertStatement(table WritableTable, columns []jet.Column) InsertStatement {
newInsert := &insertStatementImpl{}
newInsert.SerializerStatement = jet.NewStatementImpl(Dialect, jet.InsertStatementType, newInsert,
&newInsert.Insert, &newInsert.ValuesQuery, &newInsert.Returning)
&newInsert.Insert, &newInsert.ValuesQuery, &newInsert.OnConflict, &newInsert.Returning)
newInsert.Insert.Table = table
newInsert.Insert.Columns = columns
@ -36,6 +36,7 @@ type insertStatementImpl struct {
Insert jet.ClauseInsert
ValuesQuery jet.ClauseValuesQuery
Returning clauseReturning
OnConflict onConflictClause
}
func (i *insertStatementImpl) VALUES(value interface{}, values ...interface{}) InsertStatement {
@ -62,3 +63,11 @@ func (i *insertStatementImpl) QUERY(selectStatement SelectStatement) InsertState
i.ValuesQuery.Query = selectStatement
return i
}
func (i *insertStatementImpl) ON_CONFLICT(indexExpressions ...jet.ColumnExpression) onConflict {
i.OnConflict = onConflictClause{
insertStatement: i,
indexExpressions: indexExpressions,
}
return &i.OnConflict
}

View file

@ -1,6 +1,7 @@
package postgres
import (
"github.com/go-jet/jet/internal/jet"
"github.com/stretchr/testify/assert"
"testing"
"time"
@ -13,15 +14,15 @@ func TestInvalidInsert(t *testing.T) {
func TestInsertNilValue(t *testing.T) {
assertStatementSql(t, table1.INSERT(table1Col1).VALUES(nil), `
INSERT INTO db.table1 (col1) VALUES
($1);
INSERT INTO db.table1 (col1)
VALUES ($1);
`, nil)
}
func TestInsertSingleValue(t *testing.T) {
assertStatementSql(t, table1.INSERT(table1Col1).VALUES(1), `
INSERT INTO db.table1 (col1) VALUES
($1);
INSERT INTO db.table1 (col1)
VALUES ($1);
`, int(1))
}
@ -29,8 +30,8 @@ func TestInsertWithColumnList(t *testing.T) {
columnList := ColumnList{table3ColInt, table3StrCol}
assertStatementSql(t, table3.INSERT(columnList).VALUES(1, 3), `
INSERT INTO db.table3 (col_int, col2) VALUES
($1, $2);
INSERT INTO db.table3 (col_int, col2)
VALUES ($1, $2);
`, 1, 3)
}
@ -38,15 +39,15 @@ func TestInsertDate(t *testing.T) {
date := time.Date(1999, 1, 2, 3, 4, 5, 0, time.UTC)
assertStatementSql(t, table1.INSERT(table1ColTime).VALUES(date), `
INSERT INTO db.table1 (col_time) VALUES
($1);
INSERT INTO db.table1 (col_time)
VALUES ($1);
`, date)
}
func TestInsertMultipleValues(t *testing.T) {
assertStatementSql(t, table1.INSERT(table1Col1, table1ColFloat, table1Col3).VALUES(1, 2, 3), `
INSERT INTO db.table1 (col1, col_float, col3) VALUES
($1, $2, $3);
assertStatementSql(t, table1.INSERT(table1Col1, table1ColFloat, table1ColBool).VALUES(1, 2, 3), `
INSERT INTO db.table1 (col1, col_float, col_bool)
VALUES ($1, $2, $3);
`, 1, 2, 3)
}
@ -57,10 +58,10 @@ func TestInsertMultipleRows(t *testing.T) {
VALUES(111, 222)
assertStatementSql(t, stmt, `
INSERT INTO db.table1 (col1, col_float) VALUES
($1, $2),
($3, $4),
($5, $6);
INSERT INTO db.table1 (col1, col_float)
VALUES ($1, $2),
($3, $4),
($5, $6);
`, 1, 2, 11, 22, 111, 222)
}
@ -82,12 +83,12 @@ func TestInsertValuesFromModel(t *testing.T) {
MODEL(&toInsert)
expectedSQL := `
INSERT INTO db.table1 (col1, col_float) VALUES
($1, $2),
($3, $4);
INSERT INTO db.table1 (col1, col_float)
VALUES ($1, $2),
($3, $4);
`
assertStatementSql(t, stmt, expectedSQL, int(1), float64(1.11), int(1), float64(1.11))
assertStatementSql(t, stmt, expectedSQL, 1, float64(1.11), 1, float64(1.11))
}
func TestInsertValuesFromModelColumnMismatch(t *testing.T) {
@ -139,9 +140,63 @@ func TestInsertDefaultValue(t *testing.T) {
VALUES(DEFAULT, "two")
var expectedSQL = `
INSERT INTO db.table1 (col1, col_float) VALUES
(DEFAULT, $1);
INSERT INTO db.table1 (col1, col_float)
VALUES (DEFAULT, $1);
`
assertStatementSql(t, stmt, expectedSQL, "two")
}
func TestInsert_ON_CONFLICT(t *testing.T) {
stmt := table1.INSERT(table1Col1, table1ColBool).
VALUES("one", "two").
VALUES("1", "2").
VALUES("theta", "beta").
ON_CONFLICT(table1ColBool).WHERE(table1ColBool.IS_NOT_FALSE()).DO_UPDATE(
SET(table1ColBool, "12").
SET(table2ColInt, 1).
SET(ColumnList{table1Col1, table1ColBool}, jet.ROW(Int(2), String("two"))).
WHERE(table1Col1.GT(Int(2))),
).
RETURNING(table1Col1, table1ColBool)
assertDebugStatementSql(t, stmt, `
INSERT INTO db.table1 (col1, col_bool)
VALUES ('one', 'two'),
('1', '2'),
('theta', 'beta')
ON CONFLICT (col_bool) WHERE col_bool IS NOT FALSE DO UPDATE
SET col_bool = '12',
col_int = 1,
(col1, col_bool) = ROW(2, 'two')
WHERE table1.col1 > 2
RETURNING table1.col1 AS "table1.col1",
table1.col_bool AS "table1.col_bool";
`)
}
func TestInsert_ON_CONFLICT_ON_CONSTRAINT(t *testing.T) {
stmt := table1.INSERT(table1Col1, table1ColBool).
VALUES("one", "two").
VALUES("1", "2").
ON_CONFLICT().ON_CONSTRAINT("idk_primary_key").DO_UPDATE(
SET(table1ColBool, "12").
SET(table2ColInt, 1).
SET(ColumnList{table1Col1, table1ColBool}, jet.ROW(Int(2), String("two"))).
WHERE(table1Col1.GT(Int(2))),
).
RETURNING(table1Col1, table1ColBool)
assertDebugStatementSql(t, stmt, `
INSERT INTO db.table1 (col1, col_bool)
VALUES ('one', 'two'),
('1', '2')
ON CONFLICT ON CONSTRAINT idk_primary_key DO UPDATE
SET col_bool = '12',
col_int = 1,
(col1, col_bool) = ROW(2, 'two')
WHERE table1.col1 > 2
RETURNING table1.col1 AS "table1.col1",
table1.col_bool AS "table1.col_bool";
`)
}

View file

@ -109,10 +109,10 @@ type tableImpl struct {
}
// NewTable creates new table with schema Name, table Name and list of columns
func NewTable(schemaName, name string, column jet.ColumnExpression, columns ...jet.ColumnExpression) Table {
func NewTable(schemaName, name string, columns ...jet.ColumnExpression) Table {
t := &tableImpl{
SerializerTable: jet.NewTable(schemaName, name, column, columns...),
SerializerTable: jet.NewTable(schemaName, name, columns...),
}
t.readableTableInterfaceImpl.parent = t

View file

@ -5,9 +5,9 @@ import (
)
func TestJoinNilInputs(t *testing.T) {
assertClauseSerializeErr(t, table2.INNER_JOIN(nil, table1ColBool.EQ(table2ColBool)),
assertSerializeErr(t, table2.INNER_JOIN(nil, table1ColBool.EQ(table2ColBool)),
"jet: right hand side of join operation is nil table")
assertClauseSerializeErr(t, table2.INNER_JOIN(table1, nil),
assertSerializeErr(t, table2.INNER_JOIN(table1, nil),
"jet: join condition is nil")
}

View file

@ -61,7 +61,7 @@ type clauseSet struct {
Values []jet.Serializer
}
func (s *clauseSet) Serialize(statementType jet.StatementType, out *jet.SQLBuilder) {
func (s *clauseSet) Serialize(statementType jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
out.NewLine()
out.WriteString("SET")

View file

@ -10,7 +10,6 @@ import (
var table1Col1 = IntegerColumn("col1")
var table1ColInt = IntegerColumn("col_int")
var table1ColFloat = FloatColumn("col_float")
var table1Col3 = IntegerColumn("col3")
var table1ColTime = TimeColumn("col_time")
var table1ColTimez = TimezColumn("col_timez")
var table1ColTimestamp = TimestampColumn("col_timestamp")
@ -25,7 +24,6 @@ var table1 = NewTable(
table1Col1,
table1ColInt,
table1ColFloat,
table1Col3,
table1ColTime,
table1ColTimez,
table1ColBool,
@ -75,12 +73,16 @@ var table3 = NewTable(
table3ColInt,
table3StrCol)
func assertSerialize(t *testing.T, clause jet.Serializer, query string, args ...interface{}) {
func assertSerialize(t *testing.T, serializer jet.Serializer, query string, args ...interface{}) {
testutils.AssertSerialize(t, Dialect, serializer, query, args...)
}
func assertClauseSerialize(t *testing.T, clause jet.Clause, query string, args ...interface{}) {
testutils.AssertClauseSerialize(t, Dialect, clause, query, args...)
}
func assertClauseSerializeErr(t *testing.T, clause jet.Serializer, errString string) {
testutils.AssertClauseSerializeErr(t, Dialect, clause, errString)
func assertSerializeErr(t *testing.T, serializer jet.Serializer, errString string) {
testutils.AssertSerializeErr(t, Dialect, serializer, errString)
}
func assertProjectionSerialize(t *testing.T, projection jet.Projection, query string, args ...interface{}) {
@ -88,5 +90,6 @@ func assertProjectionSerialize(t *testing.T, projection jet.Projection, query st
}
var assertStatementSql = testutils.AssertStatementSql
var assertDebugStatementSql = testutils.AssertDebugStatementSql
var assertStatementSqlErr = testutils.AssertStatementSqlErr
var assertPanicErr = testutils.AssertPanicErr