From 72e8d7d5840ce4ea904e462665157aaeefdd941d Mon Sep 17 00:00:00 2001 From: go-jet Date: Wed, 8 Dec 2021 18:13:58 +0100 Subject: [PATCH] Add FROM clause support for UPDATE statements --- internal/jet/clause.go | 7 ++- postgres/select_statement.go | 13 +++-- postgres/update_statement.go | 10 +++- sqlite/select_statement.go | 13 +++-- sqlite/update_statement.go | 12 ++++- tests/mysql/main_test.go | 6 +++ tests/mysql/update_test.go | 17 +++++-- tests/postgres/main_test.go | 6 +++ tests/postgres/raw_statements_test.go | 2 +- tests/postgres/update_test.go | 70 +++++++++++++++++++++++++ tests/sqlite/update_test.go | 73 +++++++++++++++++++++++++++ 11 files changed, 211 insertions(+), 18 deletions(-) diff --git a/internal/jet/clause.go b/internal/jet/clause.go index 446a545..924a303 100644 --- a/internal/jet/clause.go +++ b/internal/jet/clause.go @@ -45,6 +45,7 @@ func (s *ClauseSelect) Serialize(statementType StatementType, out *SQLBuilder, o // ClauseFrom struct type ClauseFrom struct { + Name string Tables []Serializer } @@ -54,7 +55,11 @@ func (f *ClauseFrom) Serialize(statementType StatementType, out *SQLBuilder, opt return } out.NewLine() - out.WriteString("FROM") + if f.Name != "" { + out.WriteString(f.Name) + } else { + out.WriteString("FROM") + } out.IncreaseIdent() for i, table := range f.Tables { diff --git a/postgres/select_statement.go b/postgres/select_statement.go index 8fb9cb6..518ebd5 100644 --- a/postgres/select_statement.go +++ b/postgres/select_statement.go @@ -110,10 +110,7 @@ func (s *selectStatementImpl) DISTINCT() SelectStatement { } func (s *selectStatementImpl) FROM(tables ...ReadableTable) SelectStatement { - s.From.Tables = nil - for _, table := range tables { - s.From.Tables = append(s.From.Tables, table) - } + s.From.Tables = readableTablesToSerializerList(tables) return s } @@ -182,3 +179,11 @@ func toJetFrameOffset(offset int64) jet.Serializer { } return jet.FixedLiteral(offset) } + +func readableTablesToSerializerList(tables []ReadableTable) []jet.Serializer { + var ret []jet.Serializer + for _, table := range tables { + ret = append(ret, table) + } + return ret +} diff --git a/postgres/update_statement.go b/postgres/update_statement.go index 58c5ba4..c13ffc6 100644 --- a/postgres/update_statement.go +++ b/postgres/update_statement.go @@ -11,8 +11,9 @@ type UpdateStatement interface { SET(value interface{}, values ...interface{}) UpdateStatement MODEL(data interface{}) UpdateStatement + FROM(tables ...ReadableTable) UpdateStatement WHERE(expression BoolExpression) UpdateStatement - RETURNING(projections ...jet.Projection) UpdateStatement + RETURNING(projections ...Projection) UpdateStatement } type updateStatementImpl struct { @@ -21,6 +22,7 @@ type updateStatementImpl struct { Update jet.ClauseUpdate Set clauseSet SetNew jet.SetClauseNew + From jet.ClauseFrom Where jet.ClauseWhere Returning jet.ClauseReturning } @@ -31,6 +33,7 @@ func newUpdateStatement(table WritableTable, columns []jet.Column) UpdateStateme &update.Update, &update.Set, &update.SetNew, + &update.From, &update.Where, &update.Returning) @@ -61,6 +64,11 @@ func (u *updateStatementImpl) MODEL(data interface{}) UpdateStatement { return u } +func (u *updateStatementImpl) FROM(tables ...ReadableTable) UpdateStatement { + u.From.Tables = readableTablesToSerializerList(tables) + return u +} + func (u *updateStatementImpl) WHERE(expression BoolExpression) UpdateStatement { u.Where.Condition = expression return u diff --git a/sqlite/select_statement.go b/sqlite/select_statement.go index 4406dcd..b5a7566 100644 --- a/sqlite/select_statement.go +++ b/sqlite/select_statement.go @@ -106,10 +106,7 @@ func (s *selectStatementImpl) DISTINCT() SelectStatement { } func (s *selectStatementImpl) FROM(tables ...ReadableTable) SelectStatement { - s.From.Tables = nil - for _, table := range tables { - s.From.Tables = append(s.From.Tables, table) - } + s.From.Tables = readableTablesToSerializerList(tables) return s } @@ -184,3 +181,11 @@ func toJetFrameOffset(offset interface{}) jet.Serializer { return jet.FixedLiteral(offset) } + +func readableTablesToSerializerList(tables []ReadableTable) []jet.Serializer { + var ret []jet.Serializer + for _, table := range tables { + ret = append(ret, table) + } + return ret +} diff --git a/sqlite/update_statement.go b/sqlite/update_statement.go index 53cf72d..c28819a 100644 --- a/sqlite/update_statement.go +++ b/sqlite/update_statement.go @@ -9,14 +9,16 @@ type UpdateStatement interface { SET(value interface{}, values ...interface{}) UpdateStatement MODEL(data interface{}) UpdateStatement + FROM(tables ...ReadableTable) UpdateStatement WHERE(expression BoolExpression) UpdateStatement - RETURNING(projections ...jet.Projection) UpdateStatement + RETURNING(projections ...Projection) UpdateStatement } type updateStatementImpl struct { jet.SerializerStatement Update jet.ClauseUpdate + From jet.ClauseFrom Set jet.SetClause SetNew jet.SetClauseNew Where jet.ClauseWhere @@ -29,6 +31,7 @@ func newUpdateStatement(table Table, columns []jet.Column) UpdateStatement { &update.Update, &update.Set, &update.SetNew, + &update.From, &update.Where, &update.Returning) @@ -59,12 +62,17 @@ func (u *updateStatementImpl) MODEL(data interface{}) UpdateStatement { return u } +func (u *updateStatementImpl) FROM(tables ...ReadableTable) UpdateStatement { + u.From.Tables = readableTablesToSerializerList(tables) + return u +} + func (u *updateStatementImpl) WHERE(expression BoolExpression) UpdateStatement { u.Where.Condition = expression return u } -func (u *updateStatementImpl) RETURNING(projections ...jet.Projection) UpdateStatement { +func (u *updateStatementImpl) RETURNING(projections ...Projection) UpdateStatement { u.Returning.ProjectionList = projections return u } diff --git a/tests/mysql/main_test.go b/tests/mysql/main_test.go index e2be933..7f2e801 100644 --- a/tests/mysql/main_test.go +++ b/tests/mysql/main_test.go @@ -70,3 +70,9 @@ func skipForMariaDB(t *testing.T) { t.SkipNow() } } + +func beginTx(t *testing.T) *sql.Tx { + tx, err := db.Begin() + require.NoError(t, err) + return tx +} diff --git a/tests/mysql/update_test.go b/tests/mysql/update_test.go index dc28924..ba628a1 100644 --- a/tests/mysql/update_test.go +++ b/tests/mysql/update_test.go @@ -261,15 +261,22 @@ func TestUpdateExecContext(t *testing.T) { } func TestUpdateWithJoin(t *testing.T) { - query := table.Staff. - INNER_JOIN(table.Address, table.Address.AddressID.EQ(table.Staff.AddressID)). + tx := beginTx(t) + defer tx.Rollback() + + statement := table.Staff.INNER_JOIN(table.Address, table.Address.AddressID.EQ(table.Staff.AddressID)). UPDATE(table.Staff.LastName). - SET(String("New name")). + SET(String("New staff name")). WHERE(table.Staff.StaffID.EQ(Int(1))) - //fmt.Println(query.DebugSql()) + testutils.AssertStatementSql(t, statement, ` +UPDATE dvds.staff +INNER JOIN dvds.address ON (address.address_id = staff.address_id) +SET last_name = ? +WHERE staff.staff_id = ?; +`, "New staff name", int64(1)) - _, err := query.Exec(db) + _, err := statement.Exec(tx) require.NoError(t, err) } diff --git a/tests/postgres/main_test.go b/tests/postgres/main_test.go index 4e8aade..cc20646 100644 --- a/tests/postgres/main_test.go +++ b/tests/postgres/main_test.go @@ -87,3 +87,9 @@ func isPgxDriver() bool { return false } + +func beginTx(t *testing.T) *sql.Tx { + tx, err := db.Begin() + require.NoError(t, err) + return tx +} diff --git a/tests/postgres/raw_statements_test.go b/tests/postgres/raw_statements_test.go index a193258..4bbf90c 100644 --- a/tests/postgres/raw_statements_test.go +++ b/tests/postgres/raw_statements_test.go @@ -9,7 +9,7 @@ import ( "github.com/go-jet/jet/v2/internal/testutils" "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/model" - model2 "github.com/go-jet/jet/v2/tests/.gentestdata/mysql/test_sample/model" + model2 "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/test_sample/model" . "github.com/go-jet/jet/v2/postgres" ) diff --git a/tests/postgres/update_test.go b/tests/postgres/update_test.go index 5ec44a1..87ba49b 100644 --- a/tests/postgres/update_test.go +++ b/tests/postgres/update_test.go @@ -4,6 +4,8 @@ import ( "context" "github.com/go-jet/jet/v2/internal/testutils" . "github.com/go-jet/jet/v2/postgres" + model2 "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/model" + "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/table" "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/test_sample/model" . "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/test_sample/table" "github.com/stretchr/testify/require" @@ -371,6 +373,74 @@ func TestUpdateExecContext(t *testing.T) { require.Error(t, err, "context deadline exceeded") } +func TestUpdateFrom(t *testing.T) { + tx := beginTx(t) + defer tx.Rollback() + + stmt := table.Rental.UPDATE(). + SET( + table.Rental.RentalDate.SET(Timestamp(2020, 2, 2, 0, 0, 0)), + ).FROM( + table.Staff. + INNER_JOIN(table.Store, table.Store.StoreID.EQ(table.Staff.StaffID)), + table.Actor, + ).WHERE( + table.Staff.StaffID.EQ(table.Rental.StaffID). + AND(table.Staff.StaffID.EQ(Int(2))). + AND(table.Rental.RentalID.LT(Int(10))), + ).RETURNING( + table.Rental.AllColumns.Except(table.Rental.LastUpdate), + table.Store.AllColumns.Except(table.Store.LastUpdate), + ) + + testutils.AssertStatementSql(t, stmt, ` +UPDATE dvds.rental +SET rental_date = $1::timestamp without time zone +FROM dvds.staff + INNER JOIN dvds.store ON (store.store_id = staff.staff_id), + dvds.actor +WHERE ((staff.staff_id = rental.staff_id) AND (staff.staff_id = $2)) AND (rental.rental_id < $3) +RETURNING rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id", + store.store_id AS "store.store_id", + store.manager_staff_id AS "store.manager_staff_id", + store.address_id AS "store.address_id"; +`) + + var dest []struct { + Rental model2.Rental + Store model2.Store + } + + err := stmt.Query(tx, &dest) + + require.NoError(t, err) + require.Len(t, dest, 3) + testutils.AssertJSON(t, dest[0], ` +{ + "Rental": { + "RentalID": 4, + "RentalDate": "2020-02-02T00:00:00Z", + "InventoryID": 2452, + "CustomerID": 333, + "ReturnDate": "2005-06-03T01:43:41Z", + "StaffID": 2, + "LastUpdate": "0001-01-01T00:00:00Z" + }, + "Store": { + "StoreID": 2, + "ManagerStaffID": 2, + "AddressID": 2, + "LastUpdate": "0001-01-01T00:00:00Z" + } +} +`) +} + func setupLinkTableForUpdateTest(t *testing.T) { cleanUpLinkTable(t) diff --git a/tests/sqlite/update_test.go b/tests/sqlite/update_test.go index 61135a8..99560ec 100644 --- a/tests/sqlite/update_test.go +++ b/tests/sqlite/update_test.go @@ -2,6 +2,8 @@ package sqlite import ( "context" + model2 "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/sakila/model" + "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/sakila/table" "testing" "time" @@ -288,3 +290,74 @@ func TestUpdateContextDeadlineExceeded(t *testing.T) { _, err = updateStmt.ExecContext(ctx, tx) require.Error(t, err, "context deadline exceeded") } + +func TestUpdateFrom(t *testing.T) { + tx := beginDBTx(t) + defer tx.Rollback() + + stmt := table.Rental.UPDATE(). + SET( + table.Rental.RentalDate.SET(DateTime(2020, 2, 2, 0, 0, 0)), + ).FROM( + table.Staff. + INNER_JOIN(table.Store, table.Store.StoreID.EQ(table.Staff.StaffID)), + ).WHERE( + table.Staff.StaffID.EQ(table.Rental.StaffID). + AND(table.Staff.StaffID.EQ(Int(2))). + AND(table.Rental.RentalID.LT(Int(10))), + ).RETURNING( + table.Rental.AllColumns.Except(table.Rental.LastUpdate), + ) + + testutils.AssertDebugStatementSql(t, stmt, ` +UPDATE rental +SET rental_date = DATETIME('2020-02-02 00:00:00') +FROM staff + INNER JOIN store ON (store.store_id = staff.staff_id) +WHERE ((staff.staff_id = rental.staff_id) AND (staff.staff_id = 2)) AND (rental.rental_id < 10) +RETURNING rental.rental_id AS "rental.rental_id", + rental.rental_date AS "rental.rental_date", + rental.inventory_id AS "rental.inventory_id", + rental.customer_id AS "rental.customer_id", + rental.return_date AS "rental.return_date", + rental.staff_id AS "rental.staff_id"; +`) + + var dest []model2.Rental + + err := stmt.Query(tx, &dest) + + require.NoError(t, err) + require.Len(t, dest, 3) + testutils.AssertJSON(t, dest, ` +[ + { + "RentalID": 4, + "RentalDate": "2020-02-02T00:00:00Z", + "InventoryID": 2452, + "CustomerID": 333, + "ReturnDate": "2005-06-03T01:43:41Z", + "StaffID": 2, + "LastUpdate": "0001-01-01T00:00:00Z" + }, + { + "RentalID": 7, + "RentalDate": "2020-02-02T00:00:00Z", + "InventoryID": 3995, + "CustomerID": 269, + "ReturnDate": "2005-05-29T20:34:53Z", + "StaffID": 2, + "LastUpdate": "0001-01-01T00:00:00Z" + }, + { + "RentalID": 8, + "RentalDate": "2020-02-02T00:00:00Z", + "InventoryID": 2346, + "CustomerID": 239, + "ReturnDate": "2005-05-27T23:33:46Z", + "StaffID": 2, + "LastUpdate": "0001-01-01T00:00:00Z" + } +] +`) +}