Add support for assigning one ColumnList to another in INSERT and UPDATE queries.
This commit is contained in:
parent
c9e6fb1f75
commit
4fcc99f48f
7 changed files with 170 additions and 44 deletions
|
|
@ -1,27 +1,20 @@
|
|||
package jet
|
||||
|
||||
// ColumnAssigment is interface wrapper around column assigment
|
||||
// ColumnAssigment is interface wrapper around column assignment
|
||||
type ColumnAssigment interface {
|
||||
Serializer
|
||||
isColumnAssigment()
|
||||
isColumnAssignment()
|
||||
}
|
||||
|
||||
type columnAssigmentImpl struct {
|
||||
column ColumnSerializer
|
||||
expression Expression
|
||||
toAssign Serializer
|
||||
}
|
||||
|
||||
func NewColumnAssignment(serializer ColumnSerializer, expression Expression) ColumnAssigment {
|
||||
return &columnAssigmentImpl{
|
||||
column: serializer,
|
||||
expression: expression,
|
||||
}
|
||||
}
|
||||
|
||||
func (a columnAssigmentImpl) isColumnAssigment() {}
|
||||
func (a columnAssigmentImpl) isColumnAssignment() {}
|
||||
|
||||
func (a columnAssigmentImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
|
||||
a.column.serialize(statement, out, ShortName.WithFallTrough(options)...)
|
||||
out.WriteString("=")
|
||||
a.expression.serialize(statement, out, FallTrough(options)...)
|
||||
a.toAssign.serialize(statement, out, FallTrough(options)...)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,17 +1,60 @@
|
|||
package jet
|
||||
|
||||
import "fmt"
|
||||
|
||||
// ColumnList is a helper type to support list of columns as single projection
|
||||
type ColumnList []ColumnExpression
|
||||
|
||||
// SET creates column assigment for each column in column list. expression should be created by ROW function
|
||||
func (cl ColumnList) isExpressionOrColumnList() {}
|
||||
|
||||
// SET creates a column assignment from the current ColumnList using the provided expression.
|
||||
// This assignment can be used in INSERT queries (e.g., to set columns on conflict) or in UPDATE queries
|
||||
// (e.g., to assign new values to columns).
|
||||
//
|
||||
// Link.UPDATE().
|
||||
// SET(Link.MutableColumns.SET(ROW(String("github.com"), Bool(false))).
|
||||
// WHERE(Link.ID.EQ(Int(0)))
|
||||
func (cl ColumnList) SET(expression Expression) ColumnAssigment {
|
||||
// The expression can be:
|
||||
// - Another ColumnList: It must have the same length as the current ColumnList and each column must match by name
|
||||
// - A ROW expression containing values.
|
||||
// - A SELECT statement that returns a matching column list structure.
|
||||
//
|
||||
// Examples:
|
||||
//
|
||||
// Link.AllColumns.SET(ROW(String("github.com"), Bool(false)))
|
||||
//
|
||||
// Link.MutableColumns.SET(Link.EXCLUDED.MutableColumns)
|
||||
//
|
||||
// Link.MutableColumns.SET(
|
||||
// SELECT(Link.MutableColumns).
|
||||
// FROM(Link).
|
||||
// WHERE(Link.ID.EQ(Int(200))),
|
||||
// )
|
||||
func (cl ColumnList) SET(toAssignExp expressionOrColumnList) ColumnAssigment {
|
||||
|
||||
if toAssign, ok := toAssignExp.(ColumnList); ok {
|
||||
if len(cl) != len(toAssign) {
|
||||
panic(fmt.Sprintf("jet: column list length mismatch: expected %d columns, got %d", len(cl), len(toAssign)))
|
||||
}
|
||||
|
||||
var ret columnListAssigment
|
||||
|
||||
for i, column := range cl {
|
||||
if column.Name() != toAssign[i].Name() {
|
||||
panic(fmt.Sprintf("jet: column name mismatch at index %d: expected column '%s', got '%s'",
|
||||
i, column.Name(), toAssign[i].Name(),
|
||||
))
|
||||
}
|
||||
|
||||
ret = append(ret, columnAssigmentImpl{
|
||||
column: column,
|
||||
toAssign: toAssign[i],
|
||||
})
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
return columnAssigmentImpl{
|
||||
column: cl,
|
||||
expression: expression,
|
||||
toAssign: toAssignExp,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
21
internal/jet/column_list_assigment.go
Normal file
21
internal/jet/column_list_assigment.go
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
package jet
|
||||
|
||||
type expressionOrColumnList interface {
|
||||
Serializer
|
||||
isExpressionOrColumnList()
|
||||
}
|
||||
|
||||
type columnListAssigment []ColumnAssigment
|
||||
|
||||
func (c columnListAssigment) isColumnAssignment() {}
|
||||
|
||||
func (c columnListAssigment) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
|
||||
for i, columnAssigment := range c {
|
||||
if i > 0 {
|
||||
out.WriteString(",")
|
||||
out.NewLine()
|
||||
}
|
||||
|
||||
columnAssigment.serialize(statement, out, options...)
|
||||
}
|
||||
}
|
||||
25
internal/jet/column_list_test.go
Normal file
25
internal/jet/column_list_test.go
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
package jet
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestColumnList_SET(t *testing.T) {
|
||||
columnList1 := ColumnList{IntegerColumn("id"), StringColumn("Name"), BoolColumn("active")}
|
||||
columnList2 := ColumnList{IntegerColumn("id"), StringColumn("Name"), BoolColumn("active")}
|
||||
|
||||
columnList1.SET(columnList2)
|
||||
|
||||
columnList3 := ColumnList{IntegerColumn("id"), StringColumn("Name")}
|
||||
|
||||
require.PanicsWithValue(t, "jet: column list length mismatch: expected 2 columns, got 3", func() {
|
||||
columnList3.SET(columnList1)
|
||||
})
|
||||
|
||||
columnList4 := ColumnList{IntegerColumn("id"), StringColumn("FullName"), BoolColumn("active")}
|
||||
|
||||
require.PanicsWithValue(t, "jet: column name mismatch at index 1: expected column 'Name', got 'FullName'", func() {
|
||||
columnList1.SET(columnList4)
|
||||
})
|
||||
}
|
||||
|
|
@ -29,7 +29,7 @@ func (i *boolColumnImpl) From(subQuery SelectTable) ColumnBool {
|
|||
func (i *boolColumnImpl) SET(boolExp BoolExpression) ColumnAssigment {
|
||||
return columnAssigmentImpl{
|
||||
column: i,
|
||||
expression: boolExp,
|
||||
toAssign: boolExp,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -73,7 +73,7 @@ func (i *floatColumnImpl) From(subQuery SelectTable) ColumnFloat {
|
|||
func (i *floatColumnImpl) SET(floatExp FloatExpression) ColumnAssigment {
|
||||
return columnAssigmentImpl{
|
||||
column: i,
|
||||
expression: floatExp,
|
||||
toAssign: floatExp,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -118,7 +118,7 @@ func (i *integerColumnImpl) From(subQuery SelectTable) ColumnInteger {
|
|||
func (i *integerColumnImpl) SET(intExp IntegerExpression) ColumnAssigment {
|
||||
return columnAssigmentImpl{
|
||||
column: i,
|
||||
expression: intExp,
|
||||
toAssign: intExp,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -164,7 +164,7 @@ func (i *stringColumnImpl) From(subQuery SelectTable) ColumnString {
|
|||
func (i *stringColumnImpl) SET(stringExp StringExpression) ColumnAssigment {
|
||||
return columnAssigmentImpl{
|
||||
column: i,
|
||||
expression: stringExp,
|
||||
toAssign: stringExp,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -209,7 +209,7 @@ func (i *blobColumnImpl) From(subQuery SelectTable) ColumnBlob {
|
|||
func (i *blobColumnImpl) SET(blobExp BlobExpression) ColumnAssigment {
|
||||
return columnAssigmentImpl{
|
||||
column: i,
|
||||
expression: blobExp,
|
||||
toAssign: blobExp,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -253,7 +253,7 @@ func (i *timeColumnImpl) From(subQuery SelectTable) ColumnTime {
|
|||
func (i *timeColumnImpl) SET(timeExp TimeExpression) ColumnAssigment {
|
||||
return columnAssigmentImpl{
|
||||
column: i,
|
||||
expression: timeExp,
|
||||
toAssign: timeExp,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -296,7 +296,7 @@ func (i *timezColumnImpl) From(subQuery SelectTable) ColumnTimez {
|
|||
func (i *timezColumnImpl) SET(timezExp TimezExpression) ColumnAssigment {
|
||||
return columnAssigmentImpl{
|
||||
column: i,
|
||||
expression: timezExp,
|
||||
toAssign: timezExp,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -340,7 +340,7 @@ func (i *timestampColumnImpl) From(subQuery SelectTable) ColumnTimestamp {
|
|||
func (i *timestampColumnImpl) SET(timestampExp TimestampExpression) ColumnAssigment {
|
||||
return columnAssigmentImpl{
|
||||
column: i,
|
||||
expression: timestampExp,
|
||||
toAssign: timestampExp,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -384,7 +384,7 @@ func (i *timestampzColumnImpl) From(subQuery SelectTable) ColumnTimestampz {
|
|||
func (i *timestampzColumnImpl) SET(timestampzExp TimestampzExpression) ColumnAssigment {
|
||||
return columnAssigmentImpl{
|
||||
column: i,
|
||||
expression: timestampzExp,
|
||||
toAssign: timestampzExp,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -428,7 +428,7 @@ func (i *dateColumnImpl) From(subQuery SelectTable) ColumnDate {
|
|||
func (i *dateColumnImpl) SET(dateExp DateExpression) ColumnAssigment {
|
||||
return columnAssigmentImpl{
|
||||
column: i,
|
||||
expression: dateExp,
|
||||
toAssign: dateExp,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -461,7 +461,7 @@ type intervalColumnImpl struct {
|
|||
func (i *intervalColumnImpl) SET(intervalExp IntervalExpression) ColumnAssigment {
|
||||
return columnAssigmentImpl{
|
||||
column: i,
|
||||
expression: intervalExp,
|
||||
toAssign: intervalExp,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -517,7 +517,7 @@ func (i *rangeColumnImpl[T]) From(subQuery SelectTable) ColumnRange[T] {
|
|||
func (i *rangeColumnImpl[T]) SET(rangeExp Range[T]) ColumnAssigment {
|
||||
return columnAssigmentImpl{
|
||||
column: i,
|
||||
expression: rangeExp,
|
||||
toAssign: rangeExp,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ type Expression interface {
|
|||
Projection
|
||||
GroupByClause
|
||||
OrderByClause
|
||||
expressionOrColumnList
|
||||
|
||||
serializeForJsonValue(statement StatementType, out *SQLBuilder)
|
||||
setRoot(root Expression)
|
||||
|
|
@ -37,6 +38,8 @@ type ExpressionInterfaceImpl struct {
|
|||
Root Expression
|
||||
}
|
||||
|
||||
func (e *ExpressionInterfaceImpl) isExpressionOrColumnList() {}
|
||||
|
||||
func (e *ExpressionInterfaceImpl) setRoot(root Expression) {
|
||||
e.Root = root
|
||||
}
|
||||
|
|
|
|||
|
|
@ -196,6 +196,47 @@ RETURNING link.id AS "link.id",
|
|||
testutils.AssertExecAndRollback(t, stmt, db, 2)
|
||||
})
|
||||
|
||||
t.Run("do update column list", func(t *testing.T) {
|
||||
stmt := Link.INSERT().
|
||||
VALUES(1, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT).
|
||||
ON_CONFLICT(Link.ID).DO_UPDATE(
|
||||
SET(
|
||||
Link.MutableColumns.SET(Link.EXCLUDED.MutableColumns),
|
||||
),
|
||||
).RETURNING(Link.AllColumns)
|
||||
|
||||
testutils.AssertDebugStatementSql(t, stmt, `
|
||||
INSERT INTO test_sample.link
|
||||
VALUES (1, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT)
|
||||
ON CONFLICT (id) DO UPDATE
|
||||
SET url = excluded.url,
|
||||
name = excluded.name,
|
||||
description = excluded.description
|
||||
RETURNING link.id AS "link.id",
|
||||
link.url AS "link.url",
|
||||
link.name AS "link.name",
|
||||
link.description AS "link.description";
|
||||
`)
|
||||
|
||||
testutils.ExecuteInTxAndRollback(t, db, func(tx qrm.DB) {
|
||||
var dest []model.Link
|
||||
|
||||
err := stmt.QueryContext(ctx, tx, &dest)
|
||||
require.NoError(t, err)
|
||||
|
||||
testutils.AssertJSON(t, dest, `
|
||||
[
|
||||
{
|
||||
"ID": 1,
|
||||
"URL": "http://www.postgresqltutorial.com",
|
||||
"Name": "PostgreSQL Tutorial",
|
||||
"Description": null
|
||||
}
|
||||
]
|
||||
`)
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("do update complex", func(t *testing.T) {
|
||||
skipForCockroachDB(t) // does not support ROW
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue