From ebcbadef243f086728ce9462d58450d6af90fbb7 Mon Sep 17 00:00:00 2001 From: go-jet Date: Sat, 9 May 2020 10:49:09 +0200 Subject: [PATCH] Add new typesafe SET operator for UPDATE statement. --- internal/jet/clause.go | 22 ++-- internal/testutils/test_utils.go | 2 +- mysql/table.go | 8 +- mysql/update_statement.go | 22 +++- mysql/update_statement_test.go | 2 +- postgres/conflict_action.go | 2 +- postgres/table.go | 6 +- postgres/update_statement.go | 24 +++- tests/mysql/update_test.go | 90 +++++++++------ tests/postgres/scan_test.go | 30 ++--- tests/postgres/update_test.go | 184 ++++++++++++++++++++++--------- 11 files changed, 269 insertions(+), 123 deletions(-) diff --git a/internal/jet/clause.go b/internal/jet/clause.go index 349c6dc..6091986 100644 --- a/internal/jet/clause.go +++ b/internal/jet/clause.go @@ -271,14 +271,17 @@ func (u *ClauseUpdate) Serialize(statementType StatementType, out *SQLBuilder, o u.Table.serialize(statementType, out, FallTrough(options)...) } -// ClauseSet struct -type ClauseSet struct { +// SetClause struct +type SetClause struct { Columns []Column Values []Serializer } // Serialize serializes clause into SQLBuilder -func (s *ClauseSet) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { +func (s *SetClause) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { + if len(s.Values) == 0 { + return + } out.NewLine() out.WriteString("SET") @@ -289,7 +292,7 @@ func (s *ClauseSet) Serialize(statementType StatementType, out *SQLBuilder, opti out.IncreaseIdent(4) for i, column := range s.Columns { if i > 0 { - out.WriteString(", ") + out.WriteString(",") out.NewLine() } @@ -517,11 +520,14 @@ type SetPair struct { Value Serializer } -// SetClause clause -type SetClause []ColumnAssigment +// SetClauseNew clause +type SetClauseNew []ColumnAssigment -// Serialize for SetClause -func (s SetClause) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { +// Serialize for SetClauseNew +func (s SetClauseNew) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { + if len(s) == 0 { + return + } out.NewLine() out.WriteString("SET") out.IncreaseIdent(4) diff --git a/internal/testutils/test_utils.go b/internal/testutils/test_utils.go index 06035dd..ab36103 100644 --- a/internal/testutils/test_utils.go +++ b/internal/testutils/test_utils.go @@ -117,7 +117,7 @@ func AssertDebugStatementSql(t *testing.T, query jet.Statement, expectedQuery st } debuqSql := query.DebugSql() - assert.Equal(t, debuqSql, expectedQuery) + require.Equal(t, debuqSql, expectedQuery) } // AssertSerialize checks if clause serialize produces expected query and args diff --git a/mysql/table.go b/mysql/table.go index a4cf042..8287159 100644 --- a/mysql/table.go +++ b/mysql/table.go @@ -8,7 +8,7 @@ type Table interface { readableTable INSERT(columns ...jet.Column) InsertStatement - UPDATE(column jet.Column, columns ...jet.Column) UpdateStatement + UPDATE(columns ...jet.Column) UpdateStatement DELETE() DeleteStatement LOCK() LockStatement } @@ -35,7 +35,7 @@ type readableTable interface { type joinSelectUpdateTable interface { ReadableTable - UPDATE(column jet.Column, columns ...jet.Column) UpdateStatement + UPDATE(columns ...jet.Column) UpdateStatement } // ReadableTable interface @@ -98,8 +98,8 @@ func (t *tableImpl) INSERT(columns ...jet.Column) InsertStatement { return newInsertStatement(t.parent, jet.UnwidColumnList(columns)) } -func (t *tableImpl) UPDATE(column jet.Column, columns ...jet.Column) UpdateStatement { - return newUpdateStatement(t.parent, jet.UnwindColumns(column, columns...)) +func (t *tableImpl) UPDATE(columns ...jet.Column) UpdateStatement { + return newUpdateStatement(t.parent, jet.UnwidColumnList(columns)) } func (t *tableImpl) DELETE() DeleteStatement { diff --git a/mysql/update_statement.go b/mysql/update_statement.go index ce4498f..ed8d515 100644 --- a/mysql/update_statement.go +++ b/mysql/update_statement.go @@ -16,14 +16,18 @@ type updateStatementImpl struct { jet.SerializerStatement Update jet.ClauseUpdate - Set jet.ClauseSet + Set jet.SetClause + SetNew jet.SetClauseNew Where jet.ClauseWhere } func newUpdateStatement(table Table, columns []jet.Column) UpdateStatement { update := &updateStatementImpl{} - update.SerializerStatement = jet.NewStatementImpl(Dialect, jet.UpdateStatementType, update, &update.Update, - &update.Set, &update.Where) + update.SerializerStatement = jet.NewStatementImpl(Dialect, jet.UpdateStatementType, update, + &update.Update, + &update.Set, + &update.SetNew, + &update.Where) update.Update.Table = table update.Set.Columns = columns @@ -33,7 +37,17 @@ func newUpdateStatement(table Table, columns []jet.Column) UpdateStatement { } func (u *updateStatementImpl) SET(value interface{}, values ...interface{}) UpdateStatement { - u.Set.Values = jet.UnwindRowFromValues(value, values) + columnAssigment, isColumnAssigment := value.(ColumnAssigment) + + if isColumnAssigment { + u.SetNew = []ColumnAssigment{columnAssigment} + for _, value := range values { + u.SetNew = append(u.SetNew, value.(ColumnAssigment)) + } + } else { + u.Set.Values = jet.UnwindRowFromValues(value, values) + } + return u } diff --git a/mysql/update_statement_test.go b/mysql/update_statement_test.go index fc933aa..fe3be01 100644 --- a/mysql/update_statement_test.go +++ b/mysql/update_statement_test.go @@ -23,7 +23,7 @@ WHERE table1.col_int >= ?; func TestUpdateWithValues(t *testing.T) { expectedSQL := ` UPDATE db.table1 -SET col_int = ?, +SET col_int = ?, col_float = ? WHERE table1.col_int >= ?; ` diff --git a/postgres/conflict_action.go b/postgres/conflict_action.go index b7e9e2e..55c9440 100644 --- a/postgres/conflict_action.go +++ b/postgres/conflict_action.go @@ -20,7 +20,7 @@ type updateConflictActionImpl struct { jet.Serializer doUpdate jet.KeywordClause - set jet.SetClause + set jet.SetClauseNew where jet.ClauseWhere } diff --git a/postgres/table.go b/postgres/table.go index bc2f5c2..c82a8f7 100644 --- a/postgres/table.go +++ b/postgres/table.go @@ -31,7 +31,7 @@ type readableTable interface { type writableTable interface { INSERT(columns ...jet.Column) InsertStatement - UPDATE(column jet.Column, columns ...jet.Column) UpdateStatement + UPDATE(columns ...jet.Column) UpdateStatement DELETE() DeleteStatement LOCK() LockStatement } @@ -89,8 +89,8 @@ func (w *writableTableInterfaceImpl) INSERT(columns ...jet.Column) InsertStateme return newInsertStatement(w.parent, jet.UnwidColumnList(columns)) } -func (w *writableTableInterfaceImpl) UPDATE(column jet.Column, columns ...jet.Column) UpdateStatement { - return newUpdateStatement(w.parent, jet.UnwindColumns(column, columns...)) +func (w *writableTableInterfaceImpl) UPDATE(columns ...jet.Column) UpdateStatement { + return newUpdateStatement(w.parent, jet.UnwidColumnList(columns)) } func (w *writableTableInterfaceImpl) DELETE() DeleteStatement { diff --git a/postgres/update_statement.go b/postgres/update_statement.go index d96e1e9..9c56012 100644 --- a/postgres/update_statement.go +++ b/postgres/update_statement.go @@ -20,14 +20,19 @@ type updateStatementImpl struct { Update jet.ClauseUpdate Set clauseSet + SetNew jet.SetClauseNew Where jet.ClauseWhere Returning clauseReturning } func newUpdateStatement(table WritableTable, columns []jet.Column) UpdateStatement { update := &updateStatementImpl{} - update.SerializerStatement = jet.NewStatementImpl(Dialect, jet.UpdateStatementType, update, &update.Update, - &update.Set, &update.Where, &update.Returning) + update.SerializerStatement = jet.NewStatementImpl(Dialect, jet.UpdateStatementType, update, + &update.Update, + &update.Set, + &update.SetNew, + &update.Where, + &update.Returning) update.Update.Table = table update.Set.Columns = columns @@ -37,7 +42,17 @@ func newUpdateStatement(table WritableTable, columns []jet.Column) UpdateStateme } func (u *updateStatementImpl) SET(value interface{}, values ...interface{}) UpdateStatement { - u.Set.Values = jet.UnwindRowFromValues(value, values) + columnAssigment, isColumnAssigment := value.(ColumnAssigment) + + if isColumnAssigment { + u.SetNew = []ColumnAssigment{columnAssigment} + for _, value := range values { + u.SetNew = append(u.SetNew, value.(ColumnAssigment)) + } + } else { + u.Set.Values = jet.UnwindRowFromValues(value, values) + } + return u } @@ -62,6 +77,9 @@ type clauseSet struct { } func (s *clauseSet) Serialize(statementType jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { + if len(s.Values) == 0 { + return + } out.NewLine() out.WriteString("SET") diff --git a/tests/mysql/update_test.go b/tests/mysql/update_test.go index c1e3f19..2114fe2 100644 --- a/tests/mysql/update_test.go +++ b/tests/mysql/update_test.go @@ -16,22 +16,33 @@ import ( func TestUpdateValues(t *testing.T) { setupLinkTableForUpdateTest(t) - query := Link. - UPDATE(Link.Name, Link.URL). - SET("Bong", "http://bong.com"). - WHERE(Link.Name.EQ(String("Bing"))) - var expectedSQL = ` UPDATE test_sample.link -SET name = 'Bong', +SET name = 'Bong', url = 'http://bong.com' WHERE link.name = 'Bing'; ` + t.Run("old version", func(t *testing.T) { + query := Link. + UPDATE(Link.Name, Link.URL). + SET("Bong", "http://bong.com"). + WHERE(Link.Name.EQ(String("Bing"))) - fmt.Println(query.DebugSql()) - testutils.AssertDebugStatementSql(t, query, expectedSQL, "Bong", "http://bong.com", "Bing") + testutils.AssertDebugStatementSql(t, query, expectedSQL, "Bong", "http://bong.com", "Bing") + testutils.AssertExec(t, query, db) + }) - testutils.AssertExec(t, query, db) + t.Run("new version", func(t *testing.T) { + stmt := Link.UPDATE(). + SET( + Link.Name.SET(String("Bong")), + Link.URL.SET(String("http://bong.com")), + ). + WHERE(Link.Name.EQ(String("Bing"))) + + testutils.AssertDebugStatementSql(t, stmt, expectedSQL, "Bong", "http://bong.com", "Bing") + testutils.AssertExec(t, stmt, db) + }) links := []model.Link{} @@ -52,21 +63,11 @@ WHERE link.name = 'Bing'; func TestUpdateWithSubQueries(t *testing.T) { setupLinkTableForUpdateTest(t) - query := Link. - UPDATE(Link.Name, Link.URL). - SET( - SELECT(String("Bong")), - SELECT(Link2.URL). - FROM(Link2). - WHERE(Link2.Name.EQ(String("Youtube"))), - ). - WHERE(Link.Name.EQ(String("Bing"))) - expectedSQL := ` UPDATE test_sample.link SET name = ( SELECT ? - ), + ), url = ( SELECT link2.url AS "link2.url" FROM test_sample.link2 @@ -74,10 +75,37 @@ SET name = ( ) WHERE link.name = ?; ` - fmt.Println(query.Sql()) - testutils.AssertStatementSql(t, query, expectedSQL, "Bong", "Youtube", "Bing") + t.Run("old version", func(t *testing.T) { + query := Link. + UPDATE(Link.Name, Link.URL). + SET( + SELECT(String("Bong")), + SELECT(Link2.URL). + FROM(Link2). + WHERE(Link2.Name.EQ(String("Youtube"))), + ). + WHERE(Link.Name.EQ(String("Bing"))) - testutils.AssertExec(t, query, db) + testutils.AssertStatementSql(t, query, expectedSQL, "Bong", "Youtube", "Bing") + testutils.AssertExec(t, query, db) + }) + + t.Run("new version", func(t *testing.T) { + query := Link. + UPDATE(). + SET( + Link.Name.SET(StringExp(SELECT(String("Bong")))), + Link.URL.SET(StringExp( + SELECT(Link2.URL). + FROM(Link2). + WHERE(Link2.Name.EQ(String("Youtube"))), + )), + ). + WHERE(Link.Name.EQ(String("Bing"))) + + testutils.AssertStatementSql(t, query, expectedSQL, "Bong", "Youtube", "Bing") + testutils.AssertExec(t, query, db) + }) } func TestUpdateWithModelData(t *testing.T) { @@ -96,9 +124,9 @@ func TestUpdateWithModelData(t *testing.T) { expectedSQL := ` UPDATE test_sample.link -SET id = ?, - url = ?, - name = ?, +SET id = ?, + url = ?, + name = ?, description = ? WHERE link.id = ?; ` @@ -127,8 +155,8 @@ func TestUpdateWithModelDataAndPredefinedColumnList(t *testing.T) { var expectedSQL = ` UPDATE test_sample.link -SET description = NULL, - name = 'DuckDuckGo', +SET description = NULL, + name = 'DuckDuckGo', url = 'http://www.duckduckgo.com' WHERE link.id = 201; ` @@ -156,22 +184,20 @@ func TestUpdateWithModelDataAndMutableColumns(t *testing.T) { var expectedSQL = ` UPDATE test_sample.link -SET url = 'http://www.duckduckgo.com', - name = 'DuckDuckGo', +SET url = 'http://www.duckduckgo.com', + name = 'DuckDuckGo', description = NULL WHERE link.id = 201; ` fmt.Println(stmt.DebugSql()) testutils.AssertDebugStatementSql(t, stmt, expectedSQL, "http://www.duckduckgo.com", "DuckDuckGo", nil, int64(201)) - testutils.AssertExec(t, stmt, db) } func TestUpdateWithInvalidModelData(t *testing.T) { defer func() { r := recover() - assert.Equal(t, r, "missing struct field for column : id") }() diff --git a/tests/postgres/scan_test.go b/tests/postgres/scan_test.go index c5b6d3e..92251c3 100644 --- a/tests/postgres/scan_test.go +++ b/tests/postgres/scan_test.go @@ -12,7 +12,7 @@ import ( "testing" ) -var query = Inventory. +var oneInventoryQuery = Inventory. SELECT(Inventory.AllColumns). LIMIT(1). ORDER_BY(Inventory.InventoryID) @@ -20,69 +20,69 @@ var query = Inventory. func TestScanToInvalidDestination(t *testing.T) { t.Run("nil dest", func(t *testing.T) { - testutils.AssertQueryPanicErr(t, query, db, nil, "jet: destination is nil") + testutils.AssertQueryPanicErr(t, oneInventoryQuery, db, nil, "jet: destination is nil") }) t.Run("struct dest", func(t *testing.T) { - testutils.AssertQueryPanicErr(t, query, db, struct{}{}, "jet: destination has to be a pointer to slice or pointer to struct") + testutils.AssertQueryPanicErr(t, oneInventoryQuery, db, struct{}{}, "jet: destination has to be a pointer to slice or pointer to struct") }) t.Run("slice dest", func(t *testing.T) { - testutils.AssertQueryPanicErr(t, query, db, []struct{}{}, "jet: destination has to be a pointer to slice or pointer to struct") + testutils.AssertQueryPanicErr(t, oneInventoryQuery, db, []struct{}{}, "jet: destination has to be a pointer to slice or pointer to struct") }) t.Run("slice of pointers to pointer dest", func(t *testing.T) { - testutils.AssertQueryPanicErr(t, query, db, []**struct{}{}, "jet: destination has to be a pointer to slice or pointer to struct") + testutils.AssertQueryPanicErr(t, oneInventoryQuery, db, []**struct{}{}, "jet: destination has to be a pointer to slice or pointer to struct") }) t.Run("map dest", func(t *testing.T) { - testutils.AssertQueryPanicErr(t, query, db, &map[string]string{}, "jet: destination has to be a pointer to slice or pointer to struct") + testutils.AssertQueryPanicErr(t, oneInventoryQuery, db, &map[string]string{}, "jet: destination has to be a pointer to slice or pointer to struct") }) t.Run("map dest", func(t *testing.T) { - testutils.AssertQueryPanicErr(t, query, db, []map[string]string{}, "jet: destination has to be a pointer to slice or pointer to struct") + testutils.AssertQueryPanicErr(t, oneInventoryQuery, db, []map[string]string{}, "jet: destination has to be a pointer to slice or pointer to struct") }) t.Run("map dest", func(t *testing.T) { - testutils.AssertQueryPanicErr(t, query, db, &[]map[string]string{}, "jet: unsupported slice element type") + testutils.AssertQueryPanicErr(t, oneInventoryQuery, db, &[]map[string]string{}, "jet: unsupported slice element type") }) } func TestScanToValidDestination(t *testing.T) { t.Run("pointer to struct", func(t *testing.T) { dest := []struct{}{} - err := query.Query(db, &dest) + err := oneInventoryQuery.Query(db, &dest) assert.NoError(t, err) }) t.Run("global query function scan", func(t *testing.T) { - queryStr, args := query.Sql() + queryStr, args := oneInventoryQuery.Sql() dest := []struct{}{} err := qrm.Query(nil, db, queryStr, args, &dest) assert.NoError(t, err) }) t.Run("pointer to slice", func(t *testing.T) { - err := query.Query(db, &[]struct{}{}) + err := oneInventoryQuery.Query(db, &[]struct{}{}) assert.NoError(t, err) }) t.Run("pointer to slice of pointer to structs", func(t *testing.T) { - err := query.Query(db, &[]*struct{}{}) + err := oneInventoryQuery.Query(db, &[]*struct{}{}) assert.NoError(t, err) }) t.Run("pointer to slice of strings", func(t *testing.T) { - err := query.Query(db, &[]int32{}) + err := oneInventoryQuery.Query(db, &[]int32{}) assert.NoError(t, err) }) t.Run("pointer to slice of strings", func(t *testing.T) { - err := query.Query(db, &[]*int32{}) + err := oneInventoryQuery.Query(db, &[]*int32{}) assert.NoError(t, err) }) @@ -690,7 +690,7 @@ func TestScanToSlice(t *testing.T) { } } - testutils.AssertQueryPanicErr(t, query, db, &dest, "jet: unsupported slice element type at 'Cities []**struct { *model.City }'") + testutils.AssertQueryPanicErr(t, oneInventoryQuery, db, &dest, "jet: unsupported slice element type at 'Cities []**struct { *model.City }'") }) } diff --git a/tests/postgres/update_test.go b/tests/postgres/update_test.go index ca07332..43e64fb 100644 --- a/tests/postgres/update_test.go +++ b/tests/postgres/update_test.go @@ -7,6 +7,7 @@ import ( "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/model" . "github.com/go-jet/jet/tests/.gentestdata/jetdb/test_sample/table" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "testing" "time" ) @@ -14,50 +15,69 @@ import ( func TestUpdateValues(t *testing.T) { setupLinkTableForUpdateTest(t) - query := Link. - UPDATE(Link.Name, Link.URL). - SET("Bong", "http://bong.com"). - WHERE(Link.Name.EQ(String("Bing"))) + t.Run("deprecated version", func(t *testing.T) { + query := Link. + UPDATE(Link.Name, Link.URL). + SET("Bong", "http://bong.com"). + WHERE(Link.Name.EQ(String("Bing"))) - var expectedSQL = ` + testutils.AssertDebugStatementSql(t, query, ` UPDATE test_sample.link SET (name, url) = ('Bong', 'http://bong.com') WHERE link.name = 'Bing'; -` - testutils.AssertDebugStatementSql(t, query, expectedSQL, "Bong", "http://bong.com", "Bing") +`, "Bong", "http://bong.com", "Bing") - AssertExec(t, query, 1) + testutils.AssertExec(t, query, db, 1) - links := []model.Link{} + links := []model.Link{} - err := Link. - SELECT(Link.AllColumns). - WHERE(Link.Name.EQ(String("Bong"))). - Query(db, &links) + err := Link. + SELECT(Link.AllColumns). + WHERE(Link.Name.IN(String("Bong"))). + Query(db, &links) - assert.NoError(t, err) - assert.Equal(t, len(links), 1) - testutils.AssertDeepEqual(t, links[0], model.Link{ - ID: 204, - URL: "http://bong.com", - Name: "Bong", + require.NoError(t, err) + require.Equal(t, len(links), 1) + testutils.AssertDeepEqual(t, links[0], model.Link{ + ID: 204, + URL: "http://bong.com", + Name: "Bong", + }) + }) + + t.Run("new version", func(t *testing.T) { + stmt := Link.UPDATE(). + SET( + Link.Name.SET(String("DuckDuckGo")), + Link.URL.SET(String("www.duckduckgo.com")), + ). + WHERE(Link.Name.EQ(String("Yahoo"))) + + testutils.AssertDebugStatementSql(t, stmt, ` +UPDATE test_sample.link +SET name = 'DuckDuckGo', + url = 'www.duckduckgo.com' +WHERE link.name = 'Yahoo'; +`) + testutils.AssertExec(t, stmt, db, 1) }) } func TestUpdateWithSubQueries(t *testing.T) { setupLinkTableForUpdateTest(t) - query := Link. - UPDATE(Link.Name, Link.URL). - SET( - SELECT(String("Bong")), - SELECT(Link.URL). - FROM(Link). - WHERE(Link.Name.EQ(String("Bing"))), - ). - WHERE(Link.Name.EQ(String("Bing"))) + t.Run("deprecated version", func(t *testing.T) { + query := Link. + UPDATE(Link.Name, Link.URL). + SET( + SELECT(String("Bong")), + SELECT(Link.URL). + FROM(Link). + WHERE(Link.Name.EQ(String("Bing"))), + ). + WHERE(Link.Name.EQ(String("Bing"))) - expectedSQL := ` + expectedSQL := ` UPDATE test_sample.link SET (name, url) = (( SELECT 'Bong' @@ -68,10 +88,34 @@ SET (name, url) = (( )) WHERE link.name = 'Bing'; ` + testutils.AssertDebugStatementSql(t, query, expectedSQL, "Bong", "Bing", "Bing") - testutils.AssertDebugStatementSql(t, query, expectedSQL, "Bong", "Bing", "Bing") + AssertExec(t, query, 1) + }) - AssertExec(t, query, 1) + t.Run("new version", func(t *testing.T) { + query := Link.UPDATE(). + SET( + Link.Name.SET(String("Bong")), + Link.URL.SET(StringExp( + SELECT(Link.URL). + FROM(Link). + WHERE(Link.Name.EQ(String("Bing")))), + ), + ). + WHERE(Link.Name.EQ(String("Bing"))) + + testutils.AssertStatementSql(t, query, ` +UPDATE test_sample.link +SET name = $1, + url = ( + SELECT link.url AS "link.url" + FROM test_sample.link + WHERE link.name = $2 + ) +WHERE link.name = $3; +`, "Bong", "Bing", "Bing") + }) } func TestUpdateAndReturning(t *testing.T) { @@ -107,15 +151,16 @@ RETURNING link.id AS "link.id", func TestUpdateWithSelect(t *testing.T) { - stmt := Link.UPDATE(Link.AllColumns). - SET( - Link. - SELECT(Link.AllColumns). - WHERE(Link.ID.EQ(Int(0))), - ). - WHERE(Link.ID.EQ(Int(0))) + t.Run("deprecated version", func(t *testing.T) { + stmt := Link.UPDATE(Link.AllColumns). + SET( + Link. + SELECT(Link.AllColumns). + WHERE(Link.ID.EQ(Int(0))), + ). + WHERE(Link.ID.EQ(Int(0))) - expectedSQL := ` + expectedSQL := ` UPDATE test_sample.link SET (id, url, name, description) = ( SELECT link.id AS "link.id", @@ -127,22 +172,50 @@ SET (id, url, name, description) = ( ) WHERE link.id = 0; ` - testutils.AssertDebugStatementSql(t, stmt, expectedSQL, int64(0), int64(0)) + testutils.AssertDebugStatementSql(t, stmt, expectedSQL, int64(0), int64(0)) - AssertExec(t, stmt, 1) + AssertExec(t, stmt, 1) + }) + + t.Run("new version", func(t *testing.T) { + stmt := Link.UPDATE(). + SET( + Link.MutableColumns.SET( + SELECT(Link.MutableColumns). + FROM(Link). + WHERE(Link.ID.EQ(Int(0))), + ), + ). + WHERE(Link.ID.EQ(Int(0))) + + testutils.AssertDebugStatementSql(t, stmt, ` +UPDATE test_sample.link +SET (url, name, description) = ( + SELECT link.url AS "link.url", + link.name AS "link.name", + link.description AS "link.description" + FROM test_sample.link + WHERE link.id = 0 + ) +WHERE link.id = 0; +`, int64(0), int64(0)) + + AssertExec(t, stmt, 1) + }) } func TestUpdateWithInvalidSelect(t *testing.T) { - stmt := Link.UPDATE(Link.AllColumns). - SET( - Link. - SELECT(Link.ID, Link.Name). - WHERE(Link.ID.EQ(Int(0))), - ). - WHERE(Link.ID.EQ(Int(0))) + t.Run("deprecated version", func(t *testing.T) { + stmt := Link.UPDATE(Link.AllColumns). + SET( + Link. + SELECT(Link.ID, Link.Name). + WHERE(Link.ID.EQ(Int(0))), + ). + WHERE(Link.ID.EQ(Int(0))) - var expectedSQL = ` + var expectedSQL = ` UPDATE test_sample.link SET (id, url, name, description) = ( SELECT link.id AS "link.id", @@ -152,9 +225,18 @@ SET (id, url, name, description) = ( ) WHERE link.id = 0; ` - testutils.AssertDebugStatementSql(t, stmt, expectedSQL, int64(0), int64(0)) + testutils.AssertDebugStatementSql(t, stmt, expectedSQL, int64(0), int64(0)) - testutils.AssertExecErr(t, stmt, db, "pq: number of columns does not match number of values") + testutils.AssertExecErr(t, stmt, db, "pq: number of columns does not match number of values") + }) + + t.Run("new version", func(t *testing.T) { + stmt := Link.UPDATE(). + SET(Link.AllColumns.SET(Link.SELECT(Link.MutableColumns))). + WHERE(Link.ID.EQ(Int(0))) + + testutils.AssertExecErr(t, stmt, db, "pq: number of columns does not match number of values") + }) } func TestUpdateWithModelData(t *testing.T) {