Add FROM clause support for UPDATE statements

This commit is contained in:
go-jet 2021-12-08 18:13:58 +01:00
parent 97c34fbb54
commit 72e8d7d584
11 changed files with 211 additions and 18 deletions

View file

@ -45,6 +45,7 @@ func (s *ClauseSelect) Serialize(statementType StatementType, out *SQLBuilder, o
// ClauseFrom struct
type ClauseFrom struct {
Name string
Tables []Serializer
}
@ -54,7 +55,11 @@ func (f *ClauseFrom) Serialize(statementType StatementType, out *SQLBuilder, opt
return
}
out.NewLine()
if f.Name != "" {
out.WriteString(f.Name)
} else {
out.WriteString("FROM")
}
out.IncreaseIdent()
for i, table := range f.Tables {

View file

@ -110,10 +110,7 @@ func (s *selectStatementImpl) DISTINCT() SelectStatement {
}
func (s *selectStatementImpl) FROM(tables ...ReadableTable) SelectStatement {
s.From.Tables = nil
for _, table := range tables {
s.From.Tables = append(s.From.Tables, table)
}
s.From.Tables = readableTablesToSerializerList(tables)
return s
}
@ -182,3 +179,11 @@ func toJetFrameOffset(offset int64) jet.Serializer {
}
return jet.FixedLiteral(offset)
}
func readableTablesToSerializerList(tables []ReadableTable) []jet.Serializer {
var ret []jet.Serializer
for _, table := range tables {
ret = append(ret, table)
}
return ret
}

View file

@ -11,8 +11,9 @@ type UpdateStatement interface {
SET(value interface{}, values ...interface{}) UpdateStatement
MODEL(data interface{}) UpdateStatement
FROM(tables ...ReadableTable) UpdateStatement
WHERE(expression BoolExpression) UpdateStatement
RETURNING(projections ...jet.Projection) UpdateStatement
RETURNING(projections ...Projection) UpdateStatement
}
type updateStatementImpl struct {
@ -21,6 +22,7 @@ type updateStatementImpl struct {
Update jet.ClauseUpdate
Set clauseSet
SetNew jet.SetClauseNew
From jet.ClauseFrom
Where jet.ClauseWhere
Returning jet.ClauseReturning
}
@ -31,6 +33,7 @@ func newUpdateStatement(table WritableTable, columns []jet.Column) UpdateStateme
&update.Update,
&update.Set,
&update.SetNew,
&update.From,
&update.Where,
&update.Returning)
@ -61,6 +64,11 @@ func (u *updateStatementImpl) MODEL(data interface{}) UpdateStatement {
return u
}
func (u *updateStatementImpl) FROM(tables ...ReadableTable) UpdateStatement {
u.From.Tables = readableTablesToSerializerList(tables)
return u
}
func (u *updateStatementImpl) WHERE(expression BoolExpression) UpdateStatement {
u.Where.Condition = expression
return u

View file

@ -106,10 +106,7 @@ func (s *selectStatementImpl) DISTINCT() SelectStatement {
}
func (s *selectStatementImpl) FROM(tables ...ReadableTable) SelectStatement {
s.From.Tables = nil
for _, table := range tables {
s.From.Tables = append(s.From.Tables, table)
}
s.From.Tables = readableTablesToSerializerList(tables)
return s
}
@ -184,3 +181,11 @@ func toJetFrameOffset(offset interface{}) jet.Serializer {
return jet.FixedLiteral(offset)
}
func readableTablesToSerializerList(tables []ReadableTable) []jet.Serializer {
var ret []jet.Serializer
for _, table := range tables {
ret = append(ret, table)
}
return ret
}

View file

@ -9,14 +9,16 @@ type UpdateStatement interface {
SET(value interface{}, values ...interface{}) UpdateStatement
MODEL(data interface{}) UpdateStatement
FROM(tables ...ReadableTable) UpdateStatement
WHERE(expression BoolExpression) UpdateStatement
RETURNING(projections ...jet.Projection) UpdateStatement
RETURNING(projections ...Projection) UpdateStatement
}
type updateStatementImpl struct {
jet.SerializerStatement
Update jet.ClauseUpdate
From jet.ClauseFrom
Set jet.SetClause
SetNew jet.SetClauseNew
Where jet.ClauseWhere
@ -29,6 +31,7 @@ func newUpdateStatement(table Table, columns []jet.Column) UpdateStatement {
&update.Update,
&update.Set,
&update.SetNew,
&update.From,
&update.Where,
&update.Returning)
@ -59,12 +62,17 @@ func (u *updateStatementImpl) MODEL(data interface{}) UpdateStatement {
return u
}
func (u *updateStatementImpl) FROM(tables ...ReadableTable) UpdateStatement {
u.From.Tables = readableTablesToSerializerList(tables)
return u
}
func (u *updateStatementImpl) WHERE(expression BoolExpression) UpdateStatement {
u.Where.Condition = expression
return u
}
func (u *updateStatementImpl) RETURNING(projections ...jet.Projection) UpdateStatement {
func (u *updateStatementImpl) RETURNING(projections ...Projection) UpdateStatement {
u.Returning.ProjectionList = projections
return u
}

View file

@ -70,3 +70,9 @@ func skipForMariaDB(t *testing.T) {
t.SkipNow()
}
}
func beginTx(t *testing.T) *sql.Tx {
tx, err := db.Begin()
require.NoError(t, err)
return tx
}

View file

@ -261,15 +261,22 @@ func TestUpdateExecContext(t *testing.T) {
}
func TestUpdateWithJoin(t *testing.T) {
query := table.Staff.
INNER_JOIN(table.Address, table.Address.AddressID.EQ(table.Staff.AddressID)).
tx := beginTx(t)
defer tx.Rollback()
statement := table.Staff.INNER_JOIN(table.Address, table.Address.AddressID.EQ(table.Staff.AddressID)).
UPDATE(table.Staff.LastName).
SET(String("New name")).
SET(String("New staff name")).
WHERE(table.Staff.StaffID.EQ(Int(1)))
//fmt.Println(query.DebugSql())
testutils.AssertStatementSql(t, statement, `
UPDATE dvds.staff
INNER JOIN dvds.address ON (address.address_id = staff.address_id)
SET last_name = ?
WHERE staff.staff_id = ?;
`, "New staff name", int64(1))
_, err := query.Exec(db)
_, err := statement.Exec(tx)
require.NoError(t, err)
}

View file

@ -87,3 +87,9 @@ func isPgxDriver() bool {
return false
}
func beginTx(t *testing.T) *sql.Tx {
tx, err := db.Begin()
require.NoError(t, err)
return tx
}

View file

@ -9,7 +9,7 @@ import (
"github.com/go-jet/jet/v2/internal/testutils"
"github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/model"
model2 "github.com/go-jet/jet/v2/tests/.gentestdata/mysql/test_sample/model"
model2 "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/test_sample/model"
. "github.com/go-jet/jet/v2/postgres"
)

View file

@ -4,6 +4,8 @@ import (
"context"
"github.com/go-jet/jet/v2/internal/testutils"
. "github.com/go-jet/jet/v2/postgres"
model2 "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/model"
"github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/table"
"github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/test_sample/model"
. "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/test_sample/table"
"github.com/stretchr/testify/require"
@ -371,6 +373,74 @@ func TestUpdateExecContext(t *testing.T) {
require.Error(t, err, "context deadline exceeded")
}
func TestUpdateFrom(t *testing.T) {
tx := beginTx(t)
defer tx.Rollback()
stmt := table.Rental.UPDATE().
SET(
table.Rental.RentalDate.SET(Timestamp(2020, 2, 2, 0, 0, 0)),
).FROM(
table.Staff.
INNER_JOIN(table.Store, table.Store.StoreID.EQ(table.Staff.StaffID)),
table.Actor,
).WHERE(
table.Staff.StaffID.EQ(table.Rental.StaffID).
AND(table.Staff.StaffID.EQ(Int(2))).
AND(table.Rental.RentalID.LT(Int(10))),
).RETURNING(
table.Rental.AllColumns.Except(table.Rental.LastUpdate),
table.Store.AllColumns.Except(table.Store.LastUpdate),
)
testutils.AssertStatementSql(t, stmt, `
UPDATE dvds.rental
SET rental_date = $1::timestamp without time zone
FROM dvds.staff
INNER JOIN dvds.store ON (store.store_id = staff.staff_id),
dvds.actor
WHERE ((staff.staff_id = rental.staff_id) AND (staff.staff_id = $2)) AND (rental.rental_id < $3)
RETURNING rental.rental_id AS "rental.rental_id",
rental.rental_date AS "rental.rental_date",
rental.inventory_id AS "rental.inventory_id",
rental.customer_id AS "rental.customer_id",
rental.return_date AS "rental.return_date",
rental.staff_id AS "rental.staff_id",
store.store_id AS "store.store_id",
store.manager_staff_id AS "store.manager_staff_id",
store.address_id AS "store.address_id";
`)
var dest []struct {
Rental model2.Rental
Store model2.Store
}
err := stmt.Query(tx, &dest)
require.NoError(t, err)
require.Len(t, dest, 3)
testutils.AssertJSON(t, dest[0], `
{
"Rental": {
"RentalID": 4,
"RentalDate": "2020-02-02T00:00:00Z",
"InventoryID": 2452,
"CustomerID": 333,
"ReturnDate": "2005-06-03T01:43:41Z",
"StaffID": 2,
"LastUpdate": "0001-01-01T00:00:00Z"
},
"Store": {
"StoreID": 2,
"ManagerStaffID": 2,
"AddressID": 2,
"LastUpdate": "0001-01-01T00:00:00Z"
}
}
`)
}
func setupLinkTableForUpdateTest(t *testing.T) {
cleanUpLinkTable(t)

View file

@ -2,6 +2,8 @@ package sqlite
import (
"context"
model2 "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/sakila/model"
"github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/sakila/table"
"testing"
"time"
@ -288,3 +290,74 @@ func TestUpdateContextDeadlineExceeded(t *testing.T) {
_, err = updateStmt.ExecContext(ctx, tx)
require.Error(t, err, "context deadline exceeded")
}
func TestUpdateFrom(t *testing.T) {
tx := beginDBTx(t)
defer tx.Rollback()
stmt := table.Rental.UPDATE().
SET(
table.Rental.RentalDate.SET(DateTime(2020, 2, 2, 0, 0, 0)),
).FROM(
table.Staff.
INNER_JOIN(table.Store, table.Store.StoreID.EQ(table.Staff.StaffID)),
).WHERE(
table.Staff.StaffID.EQ(table.Rental.StaffID).
AND(table.Staff.StaffID.EQ(Int(2))).
AND(table.Rental.RentalID.LT(Int(10))),
).RETURNING(
table.Rental.AllColumns.Except(table.Rental.LastUpdate),
)
testutils.AssertDebugStatementSql(t, stmt, `
UPDATE rental
SET rental_date = DATETIME('2020-02-02 00:00:00')
FROM staff
INNER JOIN store ON (store.store_id = staff.staff_id)
WHERE ((staff.staff_id = rental.staff_id) AND (staff.staff_id = 2)) AND (rental.rental_id < 10)
RETURNING rental.rental_id AS "rental.rental_id",
rental.rental_date AS "rental.rental_date",
rental.inventory_id AS "rental.inventory_id",
rental.customer_id AS "rental.customer_id",
rental.return_date AS "rental.return_date",
rental.staff_id AS "rental.staff_id";
`)
var dest []model2.Rental
err := stmt.Query(tx, &dest)
require.NoError(t, err)
require.Len(t, dest, 3)
testutils.AssertJSON(t, dest, `
[
{
"RentalID": 4,
"RentalDate": "2020-02-02T00:00:00Z",
"InventoryID": 2452,
"CustomerID": 333,
"ReturnDate": "2005-06-03T01:43:41Z",
"StaffID": 2,
"LastUpdate": "0001-01-01T00:00:00Z"
},
{
"RentalID": 7,
"RentalDate": "2020-02-02T00:00:00Z",
"InventoryID": 3995,
"CustomerID": 269,
"ReturnDate": "2005-05-29T20:34:53Z",
"StaffID": 2,
"LastUpdate": "0001-01-01T00:00:00Z"
},
{
"RentalID": 8,
"RentalDate": "2020-02-02T00:00:00Z",
"InventoryID": 2346,
"CustomerID": 239,
"ReturnDate": "2005-05-27T23:33:46Z",
"StaffID": 2,
"LastUpdate": "0001-01-01T00:00:00Z"
}
]
`)
}