Add FROM clause support for UPDATE statements
This commit is contained in:
parent
97c34fbb54
commit
72e8d7d584
11 changed files with 211 additions and 18 deletions
|
|
@ -45,6 +45,7 @@ func (s *ClauseSelect) Serialize(statementType StatementType, out *SQLBuilder, o
|
||||||
|
|
||||||
// ClauseFrom struct
|
// ClauseFrom struct
|
||||||
type ClauseFrom struct {
|
type ClauseFrom struct {
|
||||||
|
Name string
|
||||||
Tables []Serializer
|
Tables []Serializer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -54,7 +55,11 @@ func (f *ClauseFrom) Serialize(statementType StatementType, out *SQLBuilder, opt
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
out.NewLine()
|
out.NewLine()
|
||||||
|
if f.Name != "" {
|
||||||
|
out.WriteString(f.Name)
|
||||||
|
} else {
|
||||||
out.WriteString("FROM")
|
out.WriteString("FROM")
|
||||||
|
}
|
||||||
|
|
||||||
out.IncreaseIdent()
|
out.IncreaseIdent()
|
||||||
for i, table := range f.Tables {
|
for i, table := range f.Tables {
|
||||||
|
|
|
||||||
|
|
@ -110,10 +110,7 @@ func (s *selectStatementImpl) DISTINCT() SelectStatement {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *selectStatementImpl) FROM(tables ...ReadableTable) SelectStatement {
|
func (s *selectStatementImpl) FROM(tables ...ReadableTable) SelectStatement {
|
||||||
s.From.Tables = nil
|
s.From.Tables = readableTablesToSerializerList(tables)
|
||||||
for _, table := range tables {
|
|
||||||
s.From.Tables = append(s.From.Tables, table)
|
|
||||||
}
|
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -182,3 +179,11 @@ func toJetFrameOffset(offset int64) jet.Serializer {
|
||||||
}
|
}
|
||||||
return jet.FixedLiteral(offset)
|
return jet.FixedLiteral(offset)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func readableTablesToSerializerList(tables []ReadableTable) []jet.Serializer {
|
||||||
|
var ret []jet.Serializer
|
||||||
|
for _, table := range tables {
|
||||||
|
ret = append(ret, table)
|
||||||
|
}
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -11,8 +11,9 @@ type UpdateStatement interface {
|
||||||
SET(value interface{}, values ...interface{}) UpdateStatement
|
SET(value interface{}, values ...interface{}) UpdateStatement
|
||||||
MODEL(data interface{}) UpdateStatement
|
MODEL(data interface{}) UpdateStatement
|
||||||
|
|
||||||
|
FROM(tables ...ReadableTable) UpdateStatement
|
||||||
WHERE(expression BoolExpression) UpdateStatement
|
WHERE(expression BoolExpression) UpdateStatement
|
||||||
RETURNING(projections ...jet.Projection) UpdateStatement
|
RETURNING(projections ...Projection) UpdateStatement
|
||||||
}
|
}
|
||||||
|
|
||||||
type updateStatementImpl struct {
|
type updateStatementImpl struct {
|
||||||
|
|
@ -21,6 +22,7 @@ type updateStatementImpl struct {
|
||||||
Update jet.ClauseUpdate
|
Update jet.ClauseUpdate
|
||||||
Set clauseSet
|
Set clauseSet
|
||||||
SetNew jet.SetClauseNew
|
SetNew jet.SetClauseNew
|
||||||
|
From jet.ClauseFrom
|
||||||
Where jet.ClauseWhere
|
Where jet.ClauseWhere
|
||||||
Returning jet.ClauseReturning
|
Returning jet.ClauseReturning
|
||||||
}
|
}
|
||||||
|
|
@ -31,6 +33,7 @@ func newUpdateStatement(table WritableTable, columns []jet.Column) UpdateStateme
|
||||||
&update.Update,
|
&update.Update,
|
||||||
&update.Set,
|
&update.Set,
|
||||||
&update.SetNew,
|
&update.SetNew,
|
||||||
|
&update.From,
|
||||||
&update.Where,
|
&update.Where,
|
||||||
&update.Returning)
|
&update.Returning)
|
||||||
|
|
||||||
|
|
@ -61,6 +64,11 @@ func (u *updateStatementImpl) MODEL(data interface{}) UpdateStatement {
|
||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (u *updateStatementImpl) FROM(tables ...ReadableTable) UpdateStatement {
|
||||||
|
u.From.Tables = readableTablesToSerializerList(tables)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
func (u *updateStatementImpl) WHERE(expression BoolExpression) UpdateStatement {
|
func (u *updateStatementImpl) WHERE(expression BoolExpression) UpdateStatement {
|
||||||
u.Where.Condition = expression
|
u.Where.Condition = expression
|
||||||
return u
|
return u
|
||||||
|
|
|
||||||
|
|
@ -106,10 +106,7 @@ func (s *selectStatementImpl) DISTINCT() SelectStatement {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *selectStatementImpl) FROM(tables ...ReadableTable) SelectStatement {
|
func (s *selectStatementImpl) FROM(tables ...ReadableTable) SelectStatement {
|
||||||
s.From.Tables = nil
|
s.From.Tables = readableTablesToSerializerList(tables)
|
||||||
for _, table := range tables {
|
|
||||||
s.From.Tables = append(s.From.Tables, table)
|
|
||||||
}
|
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -184,3 +181,11 @@ func toJetFrameOffset(offset interface{}) jet.Serializer {
|
||||||
|
|
||||||
return jet.FixedLiteral(offset)
|
return jet.FixedLiteral(offset)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func readableTablesToSerializerList(tables []ReadableTable) []jet.Serializer {
|
||||||
|
var ret []jet.Serializer
|
||||||
|
for _, table := range tables {
|
||||||
|
ret = append(ret, table)
|
||||||
|
}
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -9,14 +9,16 @@ type UpdateStatement interface {
|
||||||
SET(value interface{}, values ...interface{}) UpdateStatement
|
SET(value interface{}, values ...interface{}) UpdateStatement
|
||||||
MODEL(data interface{}) UpdateStatement
|
MODEL(data interface{}) UpdateStatement
|
||||||
|
|
||||||
|
FROM(tables ...ReadableTable) UpdateStatement
|
||||||
WHERE(expression BoolExpression) UpdateStatement
|
WHERE(expression BoolExpression) UpdateStatement
|
||||||
RETURNING(projections ...jet.Projection) UpdateStatement
|
RETURNING(projections ...Projection) UpdateStatement
|
||||||
}
|
}
|
||||||
|
|
||||||
type updateStatementImpl struct {
|
type updateStatementImpl struct {
|
||||||
jet.SerializerStatement
|
jet.SerializerStatement
|
||||||
|
|
||||||
Update jet.ClauseUpdate
|
Update jet.ClauseUpdate
|
||||||
|
From jet.ClauseFrom
|
||||||
Set jet.SetClause
|
Set jet.SetClause
|
||||||
SetNew jet.SetClauseNew
|
SetNew jet.SetClauseNew
|
||||||
Where jet.ClauseWhere
|
Where jet.ClauseWhere
|
||||||
|
|
@ -29,6 +31,7 @@ func newUpdateStatement(table Table, columns []jet.Column) UpdateStatement {
|
||||||
&update.Update,
|
&update.Update,
|
||||||
&update.Set,
|
&update.Set,
|
||||||
&update.SetNew,
|
&update.SetNew,
|
||||||
|
&update.From,
|
||||||
&update.Where,
|
&update.Where,
|
||||||
&update.Returning)
|
&update.Returning)
|
||||||
|
|
||||||
|
|
@ -59,12 +62,17 @@ func (u *updateStatementImpl) MODEL(data interface{}) UpdateStatement {
|
||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (u *updateStatementImpl) FROM(tables ...ReadableTable) UpdateStatement {
|
||||||
|
u.From.Tables = readableTablesToSerializerList(tables)
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
func (u *updateStatementImpl) WHERE(expression BoolExpression) UpdateStatement {
|
func (u *updateStatementImpl) WHERE(expression BoolExpression) UpdateStatement {
|
||||||
u.Where.Condition = expression
|
u.Where.Condition = expression
|
||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *updateStatementImpl) RETURNING(projections ...jet.Projection) UpdateStatement {
|
func (u *updateStatementImpl) RETURNING(projections ...Projection) UpdateStatement {
|
||||||
u.Returning.ProjectionList = projections
|
u.Returning.ProjectionList = projections
|
||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -70,3 +70,9 @@ func skipForMariaDB(t *testing.T) {
|
||||||
t.SkipNow()
|
t.SkipNow()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func beginTx(t *testing.T) *sql.Tx {
|
||||||
|
tx, err := db.Begin()
|
||||||
|
require.NoError(t, err)
|
||||||
|
return tx
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -261,15 +261,22 @@ func TestUpdateExecContext(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUpdateWithJoin(t *testing.T) {
|
func TestUpdateWithJoin(t *testing.T) {
|
||||||
query := table.Staff.
|
tx := beginTx(t)
|
||||||
INNER_JOIN(table.Address, table.Address.AddressID.EQ(table.Staff.AddressID)).
|
defer tx.Rollback()
|
||||||
|
|
||||||
|
statement := table.Staff.INNER_JOIN(table.Address, table.Address.AddressID.EQ(table.Staff.AddressID)).
|
||||||
UPDATE(table.Staff.LastName).
|
UPDATE(table.Staff.LastName).
|
||||||
SET(String("New name")).
|
SET(String("New staff name")).
|
||||||
WHERE(table.Staff.StaffID.EQ(Int(1)))
|
WHERE(table.Staff.StaffID.EQ(Int(1)))
|
||||||
|
|
||||||
//fmt.Println(query.DebugSql())
|
testutils.AssertStatementSql(t, statement, `
|
||||||
|
UPDATE dvds.staff
|
||||||
|
INNER JOIN dvds.address ON (address.address_id = staff.address_id)
|
||||||
|
SET last_name = ?
|
||||||
|
WHERE staff.staff_id = ?;
|
||||||
|
`, "New staff name", int64(1))
|
||||||
|
|
||||||
_, err := query.Exec(db)
|
_, err := statement.Exec(tx)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -87,3 +87,9 @@ func isPgxDriver() bool {
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func beginTx(t *testing.T) *sql.Tx {
|
||||||
|
tx, err := db.Begin()
|
||||||
|
require.NoError(t, err)
|
||||||
|
return tx
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ import (
|
||||||
|
|
||||||
"github.com/go-jet/jet/v2/internal/testutils"
|
"github.com/go-jet/jet/v2/internal/testutils"
|
||||||
"github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/model"
|
"github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/model"
|
||||||
model2 "github.com/go-jet/jet/v2/tests/.gentestdata/mysql/test_sample/model"
|
model2 "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/test_sample/model"
|
||||||
|
|
||||||
. "github.com/go-jet/jet/v2/postgres"
|
. "github.com/go-jet/jet/v2/postgres"
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,8 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"github.com/go-jet/jet/v2/internal/testutils"
|
"github.com/go-jet/jet/v2/internal/testutils"
|
||||||
. "github.com/go-jet/jet/v2/postgres"
|
. "github.com/go-jet/jet/v2/postgres"
|
||||||
|
model2 "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/model"
|
||||||
|
"github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/table"
|
||||||
"github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/test_sample/model"
|
"github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/test_sample/model"
|
||||||
. "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/test_sample/table"
|
. "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/test_sample/table"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
@ -371,6 +373,74 @@ func TestUpdateExecContext(t *testing.T) {
|
||||||
require.Error(t, err, "context deadline exceeded")
|
require.Error(t, err, "context deadline exceeded")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUpdateFrom(t *testing.T) {
|
||||||
|
tx := beginTx(t)
|
||||||
|
defer tx.Rollback()
|
||||||
|
|
||||||
|
stmt := table.Rental.UPDATE().
|
||||||
|
SET(
|
||||||
|
table.Rental.RentalDate.SET(Timestamp(2020, 2, 2, 0, 0, 0)),
|
||||||
|
).FROM(
|
||||||
|
table.Staff.
|
||||||
|
INNER_JOIN(table.Store, table.Store.StoreID.EQ(table.Staff.StaffID)),
|
||||||
|
table.Actor,
|
||||||
|
).WHERE(
|
||||||
|
table.Staff.StaffID.EQ(table.Rental.StaffID).
|
||||||
|
AND(table.Staff.StaffID.EQ(Int(2))).
|
||||||
|
AND(table.Rental.RentalID.LT(Int(10))),
|
||||||
|
).RETURNING(
|
||||||
|
table.Rental.AllColumns.Except(table.Rental.LastUpdate),
|
||||||
|
table.Store.AllColumns.Except(table.Store.LastUpdate),
|
||||||
|
)
|
||||||
|
|
||||||
|
testutils.AssertStatementSql(t, stmt, `
|
||||||
|
UPDATE dvds.rental
|
||||||
|
SET rental_date = $1::timestamp without time zone
|
||||||
|
FROM dvds.staff
|
||||||
|
INNER JOIN dvds.store ON (store.store_id = staff.staff_id),
|
||||||
|
dvds.actor
|
||||||
|
WHERE ((staff.staff_id = rental.staff_id) AND (staff.staff_id = $2)) AND (rental.rental_id < $3)
|
||||||
|
RETURNING rental.rental_id AS "rental.rental_id",
|
||||||
|
rental.rental_date AS "rental.rental_date",
|
||||||
|
rental.inventory_id AS "rental.inventory_id",
|
||||||
|
rental.customer_id AS "rental.customer_id",
|
||||||
|
rental.return_date AS "rental.return_date",
|
||||||
|
rental.staff_id AS "rental.staff_id",
|
||||||
|
store.store_id AS "store.store_id",
|
||||||
|
store.manager_staff_id AS "store.manager_staff_id",
|
||||||
|
store.address_id AS "store.address_id";
|
||||||
|
`)
|
||||||
|
|
||||||
|
var dest []struct {
|
||||||
|
Rental model2.Rental
|
||||||
|
Store model2.Store
|
||||||
|
}
|
||||||
|
|
||||||
|
err := stmt.Query(tx, &dest)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, dest, 3)
|
||||||
|
testutils.AssertJSON(t, dest[0], `
|
||||||
|
{
|
||||||
|
"Rental": {
|
||||||
|
"RentalID": 4,
|
||||||
|
"RentalDate": "2020-02-02T00:00:00Z",
|
||||||
|
"InventoryID": 2452,
|
||||||
|
"CustomerID": 333,
|
||||||
|
"ReturnDate": "2005-06-03T01:43:41Z",
|
||||||
|
"StaffID": 2,
|
||||||
|
"LastUpdate": "0001-01-01T00:00:00Z"
|
||||||
|
},
|
||||||
|
"Store": {
|
||||||
|
"StoreID": 2,
|
||||||
|
"ManagerStaffID": 2,
|
||||||
|
"AddressID": 2,
|
||||||
|
"LastUpdate": "0001-01-01T00:00:00Z"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
`)
|
||||||
|
}
|
||||||
|
|
||||||
func setupLinkTableForUpdateTest(t *testing.T) {
|
func setupLinkTableForUpdateTest(t *testing.T) {
|
||||||
|
|
||||||
cleanUpLinkTable(t)
|
cleanUpLinkTable(t)
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,8 @@ package sqlite
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
model2 "github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/sakila/model"
|
||||||
|
"github.com/go-jet/jet/v2/tests/.gentestdata/sqlite/sakila/table"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
|
@ -288,3 +290,74 @@ func TestUpdateContextDeadlineExceeded(t *testing.T) {
|
||||||
_, err = updateStmt.ExecContext(ctx, tx)
|
_, err = updateStmt.ExecContext(ctx, tx)
|
||||||
require.Error(t, err, "context deadline exceeded")
|
require.Error(t, err, "context deadline exceeded")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUpdateFrom(t *testing.T) {
|
||||||
|
tx := beginDBTx(t)
|
||||||
|
defer tx.Rollback()
|
||||||
|
|
||||||
|
stmt := table.Rental.UPDATE().
|
||||||
|
SET(
|
||||||
|
table.Rental.RentalDate.SET(DateTime(2020, 2, 2, 0, 0, 0)),
|
||||||
|
).FROM(
|
||||||
|
table.Staff.
|
||||||
|
INNER_JOIN(table.Store, table.Store.StoreID.EQ(table.Staff.StaffID)),
|
||||||
|
).WHERE(
|
||||||
|
table.Staff.StaffID.EQ(table.Rental.StaffID).
|
||||||
|
AND(table.Staff.StaffID.EQ(Int(2))).
|
||||||
|
AND(table.Rental.RentalID.LT(Int(10))),
|
||||||
|
).RETURNING(
|
||||||
|
table.Rental.AllColumns.Except(table.Rental.LastUpdate),
|
||||||
|
)
|
||||||
|
|
||||||
|
testutils.AssertDebugStatementSql(t, stmt, `
|
||||||
|
UPDATE rental
|
||||||
|
SET rental_date = DATETIME('2020-02-02 00:00:00')
|
||||||
|
FROM staff
|
||||||
|
INNER JOIN store ON (store.store_id = staff.staff_id)
|
||||||
|
WHERE ((staff.staff_id = rental.staff_id) AND (staff.staff_id = 2)) AND (rental.rental_id < 10)
|
||||||
|
RETURNING rental.rental_id AS "rental.rental_id",
|
||||||
|
rental.rental_date AS "rental.rental_date",
|
||||||
|
rental.inventory_id AS "rental.inventory_id",
|
||||||
|
rental.customer_id AS "rental.customer_id",
|
||||||
|
rental.return_date AS "rental.return_date",
|
||||||
|
rental.staff_id AS "rental.staff_id";
|
||||||
|
`)
|
||||||
|
|
||||||
|
var dest []model2.Rental
|
||||||
|
|
||||||
|
err := stmt.Query(tx, &dest)
|
||||||
|
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Len(t, dest, 3)
|
||||||
|
testutils.AssertJSON(t, dest, `
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"RentalID": 4,
|
||||||
|
"RentalDate": "2020-02-02T00:00:00Z",
|
||||||
|
"InventoryID": 2452,
|
||||||
|
"CustomerID": 333,
|
||||||
|
"ReturnDate": "2005-06-03T01:43:41Z",
|
||||||
|
"StaffID": 2,
|
||||||
|
"LastUpdate": "0001-01-01T00:00:00Z"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"RentalID": 7,
|
||||||
|
"RentalDate": "2020-02-02T00:00:00Z",
|
||||||
|
"InventoryID": 3995,
|
||||||
|
"CustomerID": 269,
|
||||||
|
"ReturnDate": "2005-05-29T20:34:53Z",
|
||||||
|
"StaffID": 2,
|
||||||
|
"LastUpdate": "0001-01-01T00:00:00Z"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"RentalID": 8,
|
||||||
|
"RentalDate": "2020-02-02T00:00:00Z",
|
||||||
|
"InventoryID": 2346,
|
||||||
|
"CustomerID": 239,
|
||||||
|
"ReturnDate": "2005-05-27T23:33:46Z",
|
||||||
|
"StaffID": 2,
|
||||||
|
"LastUpdate": "0001-01-01T00:00:00Z"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
`)
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue