Add new typesafe SET operator for UPDATE statement.

This commit is contained in:
go-jet 2020-05-09 10:49:09 +02:00
parent a4b4710637
commit ebcbadef24
11 changed files with 269 additions and 123 deletions

View file

@ -271,14 +271,17 @@ func (u *ClauseUpdate) Serialize(statementType StatementType, out *SQLBuilder, o
u.Table.serialize(statementType, out, FallTrough(options)...) u.Table.serialize(statementType, out, FallTrough(options)...)
} }
// ClauseSet struct // SetClause struct
type ClauseSet struct { type SetClause struct {
Columns []Column Columns []Column
Values []Serializer Values []Serializer
} }
// Serialize serializes clause into SQLBuilder // Serialize serializes clause into SQLBuilder
func (s *ClauseSet) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { func (s *SetClause) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
if len(s.Values) == 0 {
return
}
out.NewLine() out.NewLine()
out.WriteString("SET") out.WriteString("SET")
@ -517,11 +520,14 @@ type SetPair struct {
Value Serializer Value Serializer
} }
// SetClause clause // SetClauseNew clause
type SetClause []ColumnAssigment type SetClauseNew []ColumnAssigment
// Serialize for SetClause // Serialize for SetClauseNew
func (s SetClause) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { func (s SetClauseNew) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) {
if len(s) == 0 {
return
}
out.NewLine() out.NewLine()
out.WriteString("SET") out.WriteString("SET")
out.IncreaseIdent(4) out.IncreaseIdent(4)

View file

@ -117,7 +117,7 @@ func AssertDebugStatementSql(t *testing.T, query jet.Statement, expectedQuery st
} }
debuqSql := query.DebugSql() debuqSql := query.DebugSql()
assert.Equal(t, debuqSql, expectedQuery) require.Equal(t, debuqSql, expectedQuery)
} }
// AssertSerialize checks if clause serialize produces expected query and args // AssertSerialize checks if clause serialize produces expected query and args

View file

@ -8,7 +8,7 @@ type Table interface {
readableTable readableTable
INSERT(columns ...jet.Column) InsertStatement INSERT(columns ...jet.Column) InsertStatement
UPDATE(column jet.Column, columns ...jet.Column) UpdateStatement UPDATE(columns ...jet.Column) UpdateStatement
DELETE() DeleteStatement DELETE() DeleteStatement
LOCK() LockStatement LOCK() LockStatement
} }
@ -35,7 +35,7 @@ type readableTable interface {
type joinSelectUpdateTable interface { type joinSelectUpdateTable interface {
ReadableTable ReadableTable
UPDATE(column jet.Column, columns ...jet.Column) UpdateStatement UPDATE(columns ...jet.Column) UpdateStatement
} }
// ReadableTable interface // ReadableTable interface
@ -98,8 +98,8 @@ func (t *tableImpl) INSERT(columns ...jet.Column) InsertStatement {
return newInsertStatement(t.parent, jet.UnwidColumnList(columns)) return newInsertStatement(t.parent, jet.UnwidColumnList(columns))
} }
func (t *tableImpl) UPDATE(column jet.Column, columns ...jet.Column) UpdateStatement { func (t *tableImpl) UPDATE(columns ...jet.Column) UpdateStatement {
return newUpdateStatement(t.parent, jet.UnwindColumns(column, columns...)) return newUpdateStatement(t.parent, jet.UnwidColumnList(columns))
} }
func (t *tableImpl) DELETE() DeleteStatement { func (t *tableImpl) DELETE() DeleteStatement {

View file

@ -16,14 +16,18 @@ type updateStatementImpl struct {
jet.SerializerStatement jet.SerializerStatement
Update jet.ClauseUpdate Update jet.ClauseUpdate
Set jet.ClauseSet Set jet.SetClause
SetNew jet.SetClauseNew
Where jet.ClauseWhere Where jet.ClauseWhere
} }
func newUpdateStatement(table Table, columns []jet.Column) UpdateStatement { func newUpdateStatement(table Table, columns []jet.Column) UpdateStatement {
update := &updateStatementImpl{} update := &updateStatementImpl{}
update.SerializerStatement = jet.NewStatementImpl(Dialect, jet.UpdateStatementType, update, &update.Update, update.SerializerStatement = jet.NewStatementImpl(Dialect, jet.UpdateStatementType, update,
&update.Set, &update.Where) &update.Update,
&update.Set,
&update.SetNew,
&update.Where)
update.Update.Table = table update.Update.Table = table
update.Set.Columns = columns update.Set.Columns = columns
@ -33,7 +37,17 @@ func newUpdateStatement(table Table, columns []jet.Column) UpdateStatement {
} }
func (u *updateStatementImpl) SET(value interface{}, values ...interface{}) UpdateStatement { func (u *updateStatementImpl) SET(value interface{}, values ...interface{}) UpdateStatement {
columnAssigment, isColumnAssigment := value.(ColumnAssigment)
if isColumnAssigment {
u.SetNew = []ColumnAssigment{columnAssigment}
for _, value := range values {
u.SetNew = append(u.SetNew, value.(ColumnAssigment))
}
} else {
u.Set.Values = jet.UnwindRowFromValues(value, values) u.Set.Values = jet.UnwindRowFromValues(value, values)
}
return u return u
} }

View file

@ -20,7 +20,7 @@ type updateConflictActionImpl struct {
jet.Serializer jet.Serializer
doUpdate jet.KeywordClause doUpdate jet.KeywordClause
set jet.SetClause set jet.SetClauseNew
where jet.ClauseWhere where jet.ClauseWhere
} }

View file

@ -31,7 +31,7 @@ type readableTable interface {
type writableTable interface { type writableTable interface {
INSERT(columns ...jet.Column) InsertStatement INSERT(columns ...jet.Column) InsertStatement
UPDATE(column jet.Column, columns ...jet.Column) UpdateStatement UPDATE(columns ...jet.Column) UpdateStatement
DELETE() DeleteStatement DELETE() DeleteStatement
LOCK() LockStatement LOCK() LockStatement
} }
@ -89,8 +89,8 @@ func (w *writableTableInterfaceImpl) INSERT(columns ...jet.Column) InsertStateme
return newInsertStatement(w.parent, jet.UnwidColumnList(columns)) return newInsertStatement(w.parent, jet.UnwidColumnList(columns))
} }
func (w *writableTableInterfaceImpl) UPDATE(column jet.Column, columns ...jet.Column) UpdateStatement { func (w *writableTableInterfaceImpl) UPDATE(columns ...jet.Column) UpdateStatement {
return newUpdateStatement(w.parent, jet.UnwindColumns(column, columns...)) return newUpdateStatement(w.parent, jet.UnwidColumnList(columns))
} }
func (w *writableTableInterfaceImpl) DELETE() DeleteStatement { func (w *writableTableInterfaceImpl) DELETE() DeleteStatement {

View file

@ -20,14 +20,19 @@ type updateStatementImpl struct {
Update jet.ClauseUpdate Update jet.ClauseUpdate
Set clauseSet Set clauseSet
SetNew jet.SetClauseNew
Where jet.ClauseWhere Where jet.ClauseWhere
Returning clauseReturning Returning clauseReturning
} }
func newUpdateStatement(table WritableTable, columns []jet.Column) UpdateStatement { func newUpdateStatement(table WritableTable, columns []jet.Column) UpdateStatement {
update := &updateStatementImpl{} update := &updateStatementImpl{}
update.SerializerStatement = jet.NewStatementImpl(Dialect, jet.UpdateStatementType, update, &update.Update, update.SerializerStatement = jet.NewStatementImpl(Dialect, jet.UpdateStatementType, update,
&update.Set, &update.Where, &update.Returning) &update.Update,
&update.Set,
&update.SetNew,
&update.Where,
&update.Returning)
update.Update.Table = table update.Update.Table = table
update.Set.Columns = columns update.Set.Columns = columns
@ -37,7 +42,17 @@ func newUpdateStatement(table WritableTable, columns []jet.Column) UpdateStateme
} }
func (u *updateStatementImpl) SET(value interface{}, values ...interface{}) UpdateStatement { func (u *updateStatementImpl) SET(value interface{}, values ...interface{}) UpdateStatement {
columnAssigment, isColumnAssigment := value.(ColumnAssigment)
if isColumnAssigment {
u.SetNew = []ColumnAssigment{columnAssigment}
for _, value := range values {
u.SetNew = append(u.SetNew, value.(ColumnAssigment))
}
} else {
u.Set.Values = jet.UnwindRowFromValues(value, values) u.Set.Values = jet.UnwindRowFromValues(value, values)
}
return u return u
} }
@ -62,6 +77,9 @@ type clauseSet struct {
} }
func (s *clauseSet) Serialize(statementType jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { func (s *clauseSet) Serialize(statementType jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) {
if len(s.Values) == 0 {
return
}
out.NewLine() out.NewLine()
out.WriteString("SET") out.WriteString("SET")

View file

@ -16,22 +16,33 @@ import (
func TestUpdateValues(t *testing.T) { func TestUpdateValues(t *testing.T) {
setupLinkTableForUpdateTest(t) setupLinkTableForUpdateTest(t)
query := Link.
UPDATE(Link.Name, Link.URL).
SET("Bong", "http://bong.com").
WHERE(Link.Name.EQ(String("Bing")))
var expectedSQL = ` var expectedSQL = `
UPDATE test_sample.link UPDATE test_sample.link
SET name = 'Bong', SET name = 'Bong',
url = 'http://bong.com' url = 'http://bong.com'
WHERE link.name = 'Bing'; WHERE link.name = 'Bing';
` `
t.Run("old version", func(t *testing.T) {
query := Link.
UPDATE(Link.Name, Link.URL).
SET("Bong", "http://bong.com").
WHERE(Link.Name.EQ(String("Bing")))
fmt.Println(query.DebugSql())
testutils.AssertDebugStatementSql(t, query, expectedSQL, "Bong", "http://bong.com", "Bing") testutils.AssertDebugStatementSql(t, query, expectedSQL, "Bong", "http://bong.com", "Bing")
testutils.AssertExec(t, query, db) testutils.AssertExec(t, query, db)
})
t.Run("new version", func(t *testing.T) {
stmt := Link.UPDATE().
SET(
Link.Name.SET(String("Bong")),
Link.URL.SET(String("http://bong.com")),
).
WHERE(Link.Name.EQ(String("Bing")))
testutils.AssertDebugStatementSql(t, stmt, expectedSQL, "Bong", "http://bong.com", "Bing")
testutils.AssertExec(t, stmt, db)
})
links := []model.Link{} links := []model.Link{}
@ -52,16 +63,6 @@ WHERE link.name = 'Bing';
func TestUpdateWithSubQueries(t *testing.T) { func TestUpdateWithSubQueries(t *testing.T) {
setupLinkTableForUpdateTest(t) setupLinkTableForUpdateTest(t)
query := Link.
UPDATE(Link.Name, Link.URL).
SET(
SELECT(String("Bong")),
SELECT(Link2.URL).
FROM(Link2).
WHERE(Link2.Name.EQ(String("Youtube"))),
).
WHERE(Link.Name.EQ(String("Bing")))
expectedSQL := ` expectedSQL := `
UPDATE test_sample.link UPDATE test_sample.link
SET name = ( SET name = (
@ -74,10 +75,37 @@ SET name = (
) )
WHERE link.name = ?; WHERE link.name = ?;
` `
fmt.Println(query.Sql()) t.Run("old version", func(t *testing.T) {
testutils.AssertStatementSql(t, query, expectedSQL, "Bong", "Youtube", "Bing") query := Link.
UPDATE(Link.Name, Link.URL).
SET(
SELECT(String("Bong")),
SELECT(Link2.URL).
FROM(Link2).
WHERE(Link2.Name.EQ(String("Youtube"))),
).
WHERE(Link.Name.EQ(String("Bing")))
testutils.AssertStatementSql(t, query, expectedSQL, "Bong", "Youtube", "Bing")
testutils.AssertExec(t, query, db) testutils.AssertExec(t, query, db)
})
t.Run("new version", func(t *testing.T) {
query := Link.
UPDATE().
SET(
Link.Name.SET(StringExp(SELECT(String("Bong")))),
Link.URL.SET(StringExp(
SELECT(Link2.URL).
FROM(Link2).
WHERE(Link2.Name.EQ(String("Youtube"))),
)),
).
WHERE(Link.Name.EQ(String("Bing")))
testutils.AssertStatementSql(t, query, expectedSQL, "Bong", "Youtube", "Bing")
testutils.AssertExec(t, query, db)
})
} }
func TestUpdateWithModelData(t *testing.T) { func TestUpdateWithModelData(t *testing.T) {
@ -164,14 +192,12 @@ WHERE link.id = 201;
fmt.Println(stmt.DebugSql()) fmt.Println(stmt.DebugSql())
testutils.AssertDebugStatementSql(t, stmt, expectedSQL, "http://www.duckduckgo.com", "DuckDuckGo", nil, int64(201)) testutils.AssertDebugStatementSql(t, stmt, expectedSQL, "http://www.duckduckgo.com", "DuckDuckGo", nil, int64(201))
testutils.AssertExec(t, stmt, db) testutils.AssertExec(t, stmt, db)
} }
func TestUpdateWithInvalidModelData(t *testing.T) { func TestUpdateWithInvalidModelData(t *testing.T) {
defer func() { defer func() {
r := recover() r := recover()
assert.Equal(t, r, "missing struct field for column : id") assert.Equal(t, r, "missing struct field for column : id")
}() }()

View file

@ -12,7 +12,7 @@ import (
"testing" "testing"
) )
var query = Inventory. var oneInventoryQuery = Inventory.
SELECT(Inventory.AllColumns). SELECT(Inventory.AllColumns).
LIMIT(1). LIMIT(1).
ORDER_BY(Inventory.InventoryID) ORDER_BY(Inventory.InventoryID)
@ -20,69 +20,69 @@ var query = Inventory.
func TestScanToInvalidDestination(t *testing.T) { func TestScanToInvalidDestination(t *testing.T) {
t.Run("nil dest", func(t *testing.T) { t.Run("nil dest", func(t *testing.T) {
testutils.AssertQueryPanicErr(t, query, db, nil, "jet: destination is nil") testutils.AssertQueryPanicErr(t, oneInventoryQuery, db, nil, "jet: destination is nil")
}) })
t.Run("struct dest", func(t *testing.T) { t.Run("struct dest", func(t *testing.T) {
testutils.AssertQueryPanicErr(t, query, db, struct{}{}, "jet: destination has to be a pointer to slice or pointer to struct") testutils.AssertQueryPanicErr(t, oneInventoryQuery, db, struct{}{}, "jet: destination has to be a pointer to slice or pointer to struct")
}) })
t.Run("slice dest", func(t *testing.T) { t.Run("slice dest", func(t *testing.T) {
testutils.AssertQueryPanicErr(t, query, db, []struct{}{}, "jet: destination has to be a pointer to slice or pointer to struct") testutils.AssertQueryPanicErr(t, oneInventoryQuery, db, []struct{}{}, "jet: destination has to be a pointer to slice or pointer to struct")
}) })
t.Run("slice of pointers to pointer dest", func(t *testing.T) { t.Run("slice of pointers to pointer dest", func(t *testing.T) {
testutils.AssertQueryPanicErr(t, query, db, []**struct{}{}, "jet: destination has to be a pointer to slice or pointer to struct") testutils.AssertQueryPanicErr(t, oneInventoryQuery, db, []**struct{}{}, "jet: destination has to be a pointer to slice or pointer to struct")
}) })
t.Run("map dest", func(t *testing.T) { t.Run("map dest", func(t *testing.T) {
testutils.AssertQueryPanicErr(t, query, db, &map[string]string{}, "jet: destination has to be a pointer to slice or pointer to struct") testutils.AssertQueryPanicErr(t, oneInventoryQuery, db, &map[string]string{}, "jet: destination has to be a pointer to slice or pointer to struct")
}) })
t.Run("map dest", func(t *testing.T) { t.Run("map dest", func(t *testing.T) {
testutils.AssertQueryPanicErr(t, query, db, []map[string]string{}, "jet: destination has to be a pointer to slice or pointer to struct") testutils.AssertQueryPanicErr(t, oneInventoryQuery, db, []map[string]string{}, "jet: destination has to be a pointer to slice or pointer to struct")
}) })
t.Run("map dest", func(t *testing.T) { t.Run("map dest", func(t *testing.T) {
testutils.AssertQueryPanicErr(t, query, db, &[]map[string]string{}, "jet: unsupported slice element type") testutils.AssertQueryPanicErr(t, oneInventoryQuery, db, &[]map[string]string{}, "jet: unsupported slice element type")
}) })
} }
func TestScanToValidDestination(t *testing.T) { func TestScanToValidDestination(t *testing.T) {
t.Run("pointer to struct", func(t *testing.T) { t.Run("pointer to struct", func(t *testing.T) {
dest := []struct{}{} dest := []struct{}{}
err := query.Query(db, &dest) err := oneInventoryQuery.Query(db, &dest)
assert.NoError(t, err) assert.NoError(t, err)
}) })
t.Run("global query function scan", func(t *testing.T) { t.Run("global query function scan", func(t *testing.T) {
queryStr, args := query.Sql() queryStr, args := oneInventoryQuery.Sql()
dest := []struct{}{} dest := []struct{}{}
err := qrm.Query(nil, db, queryStr, args, &dest) err := qrm.Query(nil, db, queryStr, args, &dest)
assert.NoError(t, err) assert.NoError(t, err)
}) })
t.Run("pointer to slice", func(t *testing.T) { t.Run("pointer to slice", func(t *testing.T) {
err := query.Query(db, &[]struct{}{}) err := oneInventoryQuery.Query(db, &[]struct{}{})
assert.NoError(t, err) assert.NoError(t, err)
}) })
t.Run("pointer to slice of pointer to structs", func(t *testing.T) { t.Run("pointer to slice of pointer to structs", func(t *testing.T) {
err := query.Query(db, &[]*struct{}{}) err := oneInventoryQuery.Query(db, &[]*struct{}{})
assert.NoError(t, err) assert.NoError(t, err)
}) })
t.Run("pointer to slice of strings", func(t *testing.T) { t.Run("pointer to slice of strings", func(t *testing.T) {
err := query.Query(db, &[]int32{}) err := oneInventoryQuery.Query(db, &[]int32{})
assert.NoError(t, err) assert.NoError(t, err)
}) })
t.Run("pointer to slice of strings", func(t *testing.T) { t.Run("pointer to slice of strings", func(t *testing.T) {
err := query.Query(db, &[]*int32{}) err := oneInventoryQuery.Query(db, &[]*int32{})
assert.NoError(t, err) assert.NoError(t, err)
}) })
@ -690,7 +690,7 @@ func TestScanToSlice(t *testing.T) {
} }
} }
testutils.AssertQueryPanicErr(t, query, db, &dest, "jet: unsupported slice element type at 'Cities []**struct { *model.City }'") testutils.AssertQueryPanicErr(t, oneInventoryQuery, db, &dest, "jet: unsupported slice element type at 'Cities []**struct { *model.City }'")
}) })
} }

View file

@ -7,6 +7,7 @@ import (
"github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/model" "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/model"
. "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/table" . "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/table"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"testing" "testing"
"time" "time"
) )
@ -14,39 +15,58 @@ import (
func TestUpdateValues(t *testing.T) { func TestUpdateValues(t *testing.T) {
setupLinkTableForUpdateTest(t) setupLinkTableForUpdateTest(t)
t.Run("deprecated version", func(t *testing.T) {
query := Link. query := Link.
UPDATE(Link.Name, Link.URL). UPDATE(Link.Name, Link.URL).
SET("Bong", "http://bong.com"). SET("Bong", "http://bong.com").
WHERE(Link.Name.EQ(String("Bing"))) WHERE(Link.Name.EQ(String("Bing")))
var expectedSQL = ` testutils.AssertDebugStatementSql(t, query, `
UPDATE test_sample.link UPDATE test_sample.link
SET (name, url) = ('Bong', 'http://bong.com') SET (name, url) = ('Bong', 'http://bong.com')
WHERE link.name = 'Bing'; WHERE link.name = 'Bing';
` `, "Bong", "http://bong.com", "Bing")
testutils.AssertDebugStatementSql(t, query, expectedSQL, "Bong", "http://bong.com", "Bing")
AssertExec(t, query, 1) testutils.AssertExec(t, query, db, 1)
links := []model.Link{} links := []model.Link{}
err := Link. err := Link.
SELECT(Link.AllColumns). SELECT(Link.AllColumns).
WHERE(Link.Name.EQ(String("Bong"))). WHERE(Link.Name.IN(String("Bong"))).
Query(db, &links) Query(db, &links)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, len(links), 1) require.Equal(t, len(links), 1)
testutils.AssertDeepEqual(t, links[0], model.Link{ testutils.AssertDeepEqual(t, links[0], model.Link{
ID: 204, ID: 204,
URL: "http://bong.com", URL: "http://bong.com",
Name: "Bong", Name: "Bong",
}) })
})
t.Run("new version", func(t *testing.T) {
stmt := Link.UPDATE().
SET(
Link.Name.SET(String("DuckDuckGo")),
Link.URL.SET(String("www.duckduckgo.com")),
).
WHERE(Link.Name.EQ(String("Yahoo")))
testutils.AssertDebugStatementSql(t, stmt, `
UPDATE test_sample.link
SET name = 'DuckDuckGo',
url = 'www.duckduckgo.com'
WHERE link.name = 'Yahoo';
`)
testutils.AssertExec(t, stmt, db, 1)
})
} }
func TestUpdateWithSubQueries(t *testing.T) { func TestUpdateWithSubQueries(t *testing.T) {
setupLinkTableForUpdateTest(t) setupLinkTableForUpdateTest(t)
t.Run("deprecated version", func(t *testing.T) {
query := Link. query := Link.
UPDATE(Link.Name, Link.URL). UPDATE(Link.Name, Link.URL).
SET( SET(
@ -68,10 +88,34 @@ SET (name, url) = ((
)) ))
WHERE link.name = 'Bing'; WHERE link.name = 'Bing';
` `
testutils.AssertDebugStatementSql(t, query, expectedSQL, "Bong", "Bing", "Bing") testutils.AssertDebugStatementSql(t, query, expectedSQL, "Bong", "Bing", "Bing")
AssertExec(t, query, 1) AssertExec(t, query, 1)
})
t.Run("new version", func(t *testing.T) {
query := Link.UPDATE().
SET(
Link.Name.SET(String("Bong")),
Link.URL.SET(StringExp(
SELECT(Link.URL).
FROM(Link).
WHERE(Link.Name.EQ(String("Bing")))),
),
).
WHERE(Link.Name.EQ(String("Bing")))
testutils.AssertStatementSql(t, query, `
UPDATE test_sample.link
SET name = $1,
url = (
SELECT link.url AS "link.url"
FROM test_sample.link
WHERE link.name = $2
)
WHERE link.name = $3;
`, "Bong", "Bing", "Bing")
})
} }
func TestUpdateAndReturning(t *testing.T) { func TestUpdateAndReturning(t *testing.T) {
@ -107,6 +151,7 @@ RETURNING link.id AS "link.id",
func TestUpdateWithSelect(t *testing.T) { func TestUpdateWithSelect(t *testing.T) {
t.Run("deprecated version", func(t *testing.T) {
stmt := Link.UPDATE(Link.AllColumns). stmt := Link.UPDATE(Link.AllColumns).
SET( SET(
Link. Link.
@ -130,10 +175,38 @@ WHERE link.id = 0;
testutils.AssertDebugStatementSql(t, stmt, expectedSQL, int64(0), int64(0)) testutils.AssertDebugStatementSql(t, stmt, expectedSQL, int64(0), int64(0))
AssertExec(t, stmt, 1) AssertExec(t, stmt, 1)
})
t.Run("new version", func(t *testing.T) {
stmt := Link.UPDATE().
SET(
Link.MutableColumns.SET(
SELECT(Link.MutableColumns).
FROM(Link).
WHERE(Link.ID.EQ(Int(0))),
),
).
WHERE(Link.ID.EQ(Int(0)))
testutils.AssertDebugStatementSql(t, stmt, `
UPDATE test_sample.link
SET (url, name, description) = (
SELECT link.url AS "link.url",
link.name AS "link.name",
link.description AS "link.description"
FROM test_sample.link
WHERE link.id = 0
)
WHERE link.id = 0;
`, int64(0), int64(0))
AssertExec(t, stmt, 1)
})
} }
func TestUpdateWithInvalidSelect(t *testing.T) { func TestUpdateWithInvalidSelect(t *testing.T) {
t.Run("deprecated version", func(t *testing.T) {
stmt := Link.UPDATE(Link.AllColumns). stmt := Link.UPDATE(Link.AllColumns).
SET( SET(
Link. Link.
@ -155,6 +228,15 @@ WHERE link.id = 0;
testutils.AssertDebugStatementSql(t, stmt, expectedSQL, int64(0), int64(0)) testutils.AssertDebugStatementSql(t, stmt, expectedSQL, int64(0), int64(0))
testutils.AssertExecErr(t, stmt, db, "pq: number of columns does not match number of values") testutils.AssertExecErr(t, stmt, db, "pq: number of columns does not match number of values")
})
t.Run("new version", func(t *testing.T) {
stmt := Link.UPDATE().
SET(Link.AllColumns.SET(Link.SELECT(Link.MutableColumns))).
WHERE(Link.ID.EQ(Int(0)))
testutils.AssertExecErr(t, stmt, db, "pq: number of columns does not match number of values")
})
} }
func TestUpdateWithModelData(t *testing.T) { func TestUpdateWithModelData(t *testing.T) {