Add support for assigning one ColumnList to another in INSERT and UPDATE queries.

This commit is contained in:
go-jet 2025-03-13 13:05:35 +01:00
parent c9e6fb1f75
commit 4fcc99f48f
7 changed files with 170 additions and 44 deletions

View file

@ -1,27 +1,20 @@
package jet
// ColumnAssigment is interface wrapper around column assigment
// ColumnAssigment is interface wrapper around column assignment
type ColumnAssigment interface {
Serializer
isColumnAssigment()
isColumnAssignment()
}
type columnAssigmentImpl struct {
column ColumnSerializer
expression Expression
toAssign Serializer
}
func NewColumnAssignment(serializer ColumnSerializer, expression Expression) ColumnAssigment {
return &columnAssigmentImpl{
column: serializer,
expression: expression,
}
}
func (a columnAssigmentImpl) isColumnAssigment() {}
func (a columnAssigmentImpl) isColumnAssignment() {}
func (a columnAssigmentImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
a.column.serialize(statement, out, ShortName.WithFallTrough(options)...)
out.WriteString("=")
a.expression.serialize(statement, out, FallTrough(options)...)
a.toAssign.serialize(statement, out, FallTrough(options)...)
}

View file

@ -1,17 +1,60 @@
package jet
import "fmt"
// ColumnList is a helper type to support list of columns as single projection
type ColumnList []ColumnExpression
// SET creates column assigment for each column in column list. expression should be created by ROW function
func (cl ColumnList) isExpressionOrColumnList() {}
// SET creates a column assignment from the current ColumnList using the provided expression.
// This assignment can be used in INSERT queries (e.g., to set columns on conflict) or in UPDATE queries
// (e.g., to assign new values to columns).
//
// Link.UPDATE().
// SET(Link.MutableColumns.SET(ROW(String("github.com"), Bool(false))).
// WHERE(Link.ID.EQ(Int(0)))
func (cl ColumnList) SET(expression Expression) ColumnAssigment {
// The expression can be:
// - Another ColumnList: It must have the same length as the current ColumnList and each column must match by name
// - A ROW expression containing values.
// - A SELECT statement that returns a matching column list structure.
//
// Examples:
//
// Link.AllColumns.SET(ROW(String("github.com"), Bool(false)))
//
// Link.MutableColumns.SET(Link.EXCLUDED.MutableColumns)
//
// Link.MutableColumns.SET(
// SELECT(Link.MutableColumns).
// FROM(Link).
// WHERE(Link.ID.EQ(Int(200))),
// )
func (cl ColumnList) SET(toAssignExp expressionOrColumnList) ColumnAssigment {
if toAssign, ok := toAssignExp.(ColumnList); ok {
if len(cl) != len(toAssign) {
panic(fmt.Sprintf("jet: column list length mismatch: expected %d columns, got %d", len(cl), len(toAssign)))
}
var ret columnListAssigment
for i, column := range cl {
if column.Name() != toAssign[i].Name() {
panic(fmt.Sprintf("jet: column name mismatch at index %d: expected column '%s', got '%s'",
i, column.Name(), toAssign[i].Name(),
))
}
ret = append(ret, columnAssigmentImpl{
column: column,
toAssign: toAssign[i],
})
}
return ret
}
return columnAssigmentImpl{
column: cl,
expression: expression,
toAssign: toAssignExp,
}
}

View file

@ -0,0 +1,21 @@
package jet
type expressionOrColumnList interface {
Serializer
isExpressionOrColumnList()
}
type columnListAssigment []ColumnAssigment
func (c columnListAssigment) isColumnAssignment() {}
func (c columnListAssigment) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
for i, columnAssigment := range c {
if i > 0 {
out.WriteString(",")
out.NewLine()
}
columnAssigment.serialize(statement, out, options...)
}
}

View file

@ -0,0 +1,25 @@
package jet
import (
"github.com/stretchr/testify/require"
"testing"
)
func TestColumnList_SET(t *testing.T) {
columnList1 := ColumnList{IntegerColumn("id"), StringColumn("Name"), BoolColumn("active")}
columnList2 := ColumnList{IntegerColumn("id"), StringColumn("Name"), BoolColumn("active")}
columnList1.SET(columnList2)
columnList3 := ColumnList{IntegerColumn("id"), StringColumn("Name")}
require.PanicsWithValue(t, "jet: column list length mismatch: expected 2 columns, got 3", func() {
columnList3.SET(columnList1)
})
columnList4 := ColumnList{IntegerColumn("id"), StringColumn("FullName"), BoolColumn("active")}
require.PanicsWithValue(t, "jet: column name mismatch at index 1: expected column 'Name', got 'FullName'", func() {
columnList1.SET(columnList4)
})
}

View file

@ -29,7 +29,7 @@ func (i *boolColumnImpl) From(subQuery SelectTable) ColumnBool {
func (i *boolColumnImpl) SET(boolExp BoolExpression) ColumnAssigment {
return columnAssigmentImpl{
column: i,
expression: boolExp,
toAssign: boolExp,
}
}
@ -73,7 +73,7 @@ func (i *floatColumnImpl) From(subQuery SelectTable) ColumnFloat {
func (i *floatColumnImpl) SET(floatExp FloatExpression) ColumnAssigment {
return columnAssigmentImpl{
column: i,
expression: floatExp,
toAssign: floatExp,
}
}
@ -118,7 +118,7 @@ func (i *integerColumnImpl) From(subQuery SelectTable) ColumnInteger {
func (i *integerColumnImpl) SET(intExp IntegerExpression) ColumnAssigment {
return columnAssigmentImpl{
column: i,
expression: intExp,
toAssign: intExp,
}
}
@ -164,7 +164,7 @@ func (i *stringColumnImpl) From(subQuery SelectTable) ColumnString {
func (i *stringColumnImpl) SET(stringExp StringExpression) ColumnAssigment {
return columnAssigmentImpl{
column: i,
expression: stringExp,
toAssign: stringExp,
}
}
@ -209,7 +209,7 @@ func (i *blobColumnImpl) From(subQuery SelectTable) ColumnBlob {
func (i *blobColumnImpl) SET(blobExp BlobExpression) ColumnAssigment {
return columnAssigmentImpl{
column: i,
expression: blobExp,
toAssign: blobExp,
}
}
@ -253,7 +253,7 @@ func (i *timeColumnImpl) From(subQuery SelectTable) ColumnTime {
func (i *timeColumnImpl) SET(timeExp TimeExpression) ColumnAssigment {
return columnAssigmentImpl{
column: i,
expression: timeExp,
toAssign: timeExp,
}
}
@ -296,7 +296,7 @@ func (i *timezColumnImpl) From(subQuery SelectTable) ColumnTimez {
func (i *timezColumnImpl) SET(timezExp TimezExpression) ColumnAssigment {
return columnAssigmentImpl{
column: i,
expression: timezExp,
toAssign: timezExp,
}
}
@ -340,7 +340,7 @@ func (i *timestampColumnImpl) From(subQuery SelectTable) ColumnTimestamp {
func (i *timestampColumnImpl) SET(timestampExp TimestampExpression) ColumnAssigment {
return columnAssigmentImpl{
column: i,
expression: timestampExp,
toAssign: timestampExp,
}
}
@ -384,7 +384,7 @@ func (i *timestampzColumnImpl) From(subQuery SelectTable) ColumnTimestampz {
func (i *timestampzColumnImpl) SET(timestampzExp TimestampzExpression) ColumnAssigment {
return columnAssigmentImpl{
column: i,
expression: timestampzExp,
toAssign: timestampzExp,
}
}
@ -428,7 +428,7 @@ func (i *dateColumnImpl) From(subQuery SelectTable) ColumnDate {
func (i *dateColumnImpl) SET(dateExp DateExpression) ColumnAssigment {
return columnAssigmentImpl{
column: i,
expression: dateExp,
toAssign: dateExp,
}
}
@ -461,7 +461,7 @@ type intervalColumnImpl struct {
func (i *intervalColumnImpl) SET(intervalExp IntervalExpression) ColumnAssigment {
return columnAssigmentImpl{
column: i,
expression: intervalExp,
toAssign: intervalExp,
}
}
@ -517,7 +517,7 @@ func (i *rangeColumnImpl[T]) From(subQuery SelectTable) ColumnRange[T] {
func (i *rangeColumnImpl[T]) SET(rangeExp Range[T]) ColumnAssigment {
return columnAssigmentImpl{
column: i,
expression: rangeExp,
toAssign: rangeExp,
}
}

View file

@ -9,6 +9,7 @@ type Expression interface {
Projection
GroupByClause
OrderByClause
expressionOrColumnList
serializeForJsonValue(statement StatementType, out *SQLBuilder)
setRoot(root Expression)
@ -37,6 +38,8 @@ type ExpressionInterfaceImpl struct {
Root Expression
}
func (e *ExpressionInterfaceImpl) isExpressionOrColumnList() {}
func (e *ExpressionInterfaceImpl) setRoot(root Expression) {
e.Root = root
}

View file

@ -196,6 +196,47 @@ RETURNING link.id AS "link.id",
testutils.AssertExecAndRollback(t, stmt, db, 2)
})
t.Run("do update column list", func(t *testing.T) {
stmt := Link.INSERT().
VALUES(1, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT).
ON_CONFLICT(Link.ID).DO_UPDATE(
SET(
Link.MutableColumns.SET(Link.EXCLUDED.MutableColumns),
),
).RETURNING(Link.AllColumns)
testutils.AssertDebugStatementSql(t, stmt, `
INSERT INTO test_sample.link
VALUES (1, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT)
ON CONFLICT (id) DO UPDATE
SET url = excluded.url,
name = excluded.name,
description = excluded.description
RETURNING link.id AS "link.id",
link.url AS "link.url",
link.name AS "link.name",
link.description AS "link.description";
`)
testutils.ExecuteInTxAndRollback(t, db, func(tx qrm.DB) {
var dest []model.Link
err := stmt.QueryContext(ctx, tx, &dest)
require.NoError(t, err)
testutils.AssertJSON(t, dest, `
[
{
"ID": 1,
"URL": "http://www.postgresqltutorial.com",
"Name": "PostgreSQL Tutorial",
"Description": null
}
]
`)
})
})
t.Run("do update complex", func(t *testing.T) {
skipForCockroachDB(t) // does not support ROW