Merge pull request #452 from Hasaber8/feature/448

feature: Add support for LIMIT query in UPDATE for sql/sqlite
This commit is contained in:
go-jet 2025-02-20 14:59:38 +01:00 committed by GitHub
commit 7047de44a9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 101 additions and 5 deletions

View file

@ -12,6 +12,7 @@ type UpdateStatement interface {
MODEL(data interface{}) UpdateStatement
WHERE(expression BoolExpression) UpdateStatement
LIMIT(limit int64) UpdateStatement
}
type updateStatementImpl struct {
@ -21,6 +22,7 @@ type updateStatementImpl struct {
Set jet.SetClause
SetNew jet.SetClauseNew
Where jet.ClauseWhere
Limit jet.ClauseLimit
}
func newUpdateStatement(table Table, columns []jet.Column) UpdateStatement {
@ -29,11 +31,13 @@ func newUpdateStatement(table Table, columns []jet.Column) UpdateStatement {
&update.Update,
&update.Set,
&update.SetNew,
&update.Where)
&update.Where,
&update.Limit)
update.Update.Table = table
update.Set.Columns = columns
update.Where.Mandatory = true
update.Limit.Count = -1 // Initialize to -1 to indicate no LIMIT
return update
}
@ -67,3 +71,11 @@ func (u *updateStatementImpl) WHERE(expression BoolExpression) UpdateStatement {
u.Where.Condition = expression
return u
}
func (u *updateStatementImpl) LIMIT(limit int64) UpdateStatement {
if _, isJoinTable := u.Update.Table.(*joinTable); isJoinTable {
panic("jet: MySQL does not support LIMIT with multi-table UPDATE statements")
}
u.Limit.Count = limit
return u
}

View file

@ -6,6 +6,21 @@ import (
"testing"
)
func TestUpdateWithLimit(t *testing.T) {
expectedSQL := `
UPDATE db.table1
SET col_int = ?
WHERE table1.col_int >= ?
LIMIT ?;
`
stmt := table1.UPDATE(table1ColInt).
SET(1).
WHERE(table1ColInt.GT_EQ(Int(33))).
LIMIT(5)
assertStatementSql(t, stmt, expectedSQL, 1, int64(33), int64(5))
}
func TestUpdateWithOneValue(t *testing.T) {
expectedSQL := `
UPDATE db.table1
@ -69,11 +84,26 @@ func TestUpdateReservedWorldColumn(t *testing.T) {
Load: "foo",
},
).
WHERE(loadColumn.EQ(String("bar"))), strings.Replace(`
WHERE(loadColumn.EQ(String("bar"))).
LIMIT(0),
strings.Replace(`
UPDATE db.table1
SET ''Load'' = ?
WHERE ''Load'' = ?;
`, "''", "`", -1), "foo", "bar")
WHERE ''Load'' = ?
LIMIT ?;
`, "''", "`", -1), "foo", "bar", int64(0))
}
func LimitPanicStatement() {
joinedTable := table1.INNER_JOIN(table2, table1Col1.EQ(table2Col3))
joinedTable.UPDATE(table1ColInt).
SET(1).
WHERE(table1ColInt.GT_EQ(Int(33))).
LIMIT(5)
}
func TestUpdateWithMultiTableAndLimit(t *testing.T) {
assertPanicErr(t, func() { LimitPanicStatement() }, "jet: MySQL does not support LIMIT with multi-table UPDATE statements")
}
func TestInvalidInputs(t *testing.T) {

View file

@ -2,10 +2,11 @@ package mysql
import (
"context"
"github.com/go-jet/jet/v2/qrm"
"testing"
"time"
"github.com/go-jet/jet/v2/qrm"
"github.com/stretchr/testify/require"
"github.com/go-jet/jet/v2/internal/testutils"
@ -285,3 +286,56 @@ WHERE link.name = 'Bing';
require.NoError(t, err)
})
}
func TestUpdateWithLimit(t *testing.T) {
t.Run("single table update with limit", func(t *testing.T) {
stmt := Link.
UPDATE(Link.Name).
SET(String("Updated Link")).
WHERE(Link.Name.NOT_EQ(String(""))).
LIMIT(2)
testutils.AssertDebugStatementSql(t, stmt, `
UPDATE test_sample.link
SET name = 'Updated Link'
WHERE link.name != ''
LIMIT 2;
`)
testutils.ExecuteInTxAndRollback(t, db, func(tx qrm.DB) {
// Execute update
testutils.AssertExec(t, stmt, tx)
requireLogged(t, stmt)
// Verify only 2 rows were updated
var updatedLinks []model.Link
err := Link.
SELECT(Link.AllColumns).
WHERE(Link.Name.EQ(String("Updated Link"))).
Query(tx, &updatedLinks)
require.NoError(t, err)
require.Equal(t, 2, len(updatedLinks))
})
})
t.Run("multi-table update with limit should panic", func(t *testing.T) {
defer func() {
r := recover()
require.NotNil(t, r)
require.Equal(t, "jet: MySQL does not support LIMIT with multi-table UPDATE statements", r)
}()
joinedTable := Link.
INNER_JOIN(Link2, Link.Name.EQ(Link2.Name))
stmt := joinedTable.
UPDATE(Link.Name).
SET(String("Updated Link")).
WHERE(Link.Name.NOT_EQ(String(""))).
LIMIT(2)
// Statement construction itself should panic
_ = stmt
})
}