feature: Add support for LIMIT query in UPDATE for sql

Signed-off-by: Rohan Hasabe <rohanhasabe8@gmail.com>
This commit is contained in:
Rohan Hasabe 2025-02-08 17:25:02 -05:00
parent 00b8155f74
commit d733f9688e
No known key found for this signature in database
GPG key ID: BA28FE62BF71733A
3 changed files with 101 additions and 5 deletions

View file

@ -12,6 +12,7 @@ type UpdateStatement interface {
MODEL(data interface{}) UpdateStatement MODEL(data interface{}) UpdateStatement
WHERE(expression BoolExpression) UpdateStatement WHERE(expression BoolExpression) UpdateStatement
LIMIT(limit int64) UpdateStatement
} }
type updateStatementImpl struct { type updateStatementImpl struct {
@ -21,6 +22,7 @@ type updateStatementImpl struct {
Set jet.SetClause Set jet.SetClause
SetNew jet.SetClauseNew SetNew jet.SetClauseNew
Where jet.ClauseWhere Where jet.ClauseWhere
Limit jet.ClauseLimit
} }
func newUpdateStatement(table Table, columns []jet.Column) UpdateStatement { func newUpdateStatement(table Table, columns []jet.Column) UpdateStatement {
@ -29,11 +31,13 @@ func newUpdateStatement(table Table, columns []jet.Column) UpdateStatement {
&update.Update, &update.Update,
&update.Set, &update.Set,
&update.SetNew, &update.SetNew,
&update.Where) &update.Where,
&update.Limit)
update.Update.Table = table update.Update.Table = table
update.Set.Columns = columns update.Set.Columns = columns
update.Where.Mandatory = true update.Where.Mandatory = true
update.Limit.Count = -1 // Initialize to -1 to indicate no LIMIT
return update return update
} }
@ -67,3 +71,11 @@ func (u *updateStatementImpl) WHERE(expression BoolExpression) UpdateStatement {
u.Where.Condition = expression u.Where.Condition = expression
return u return u
} }
func (u *updateStatementImpl) LIMIT(limit int64) UpdateStatement {
if _, isJoinTable := u.Update.Table.(*joinTable); isJoinTable {
panic("jet: MySQL does not support LIMIT with multi-table UPDATE statements")
}
u.Limit.Count = limit
return u
}

View file

@ -6,6 +6,21 @@ import (
"testing" "testing"
) )
func TestUpdateWithLimit(t *testing.T) {
expectedSQL := `
UPDATE db.table1
SET col_int = ?
WHERE table1.col_int >= ?
LIMIT ?;
`
stmt := table1.UPDATE(table1ColInt).
SET(1).
WHERE(table1ColInt.GT_EQ(Int(33))).
LIMIT(5)
assertStatementSql(t, stmt, expectedSQL, 1, int64(33), int64(5))
}
func TestUpdateWithOneValue(t *testing.T) { func TestUpdateWithOneValue(t *testing.T) {
expectedSQL := ` expectedSQL := `
UPDATE db.table1 UPDATE db.table1
@ -69,11 +84,26 @@ func TestUpdateReservedWorldColumn(t *testing.T) {
Load: "foo", Load: "foo",
}, },
). ).
WHERE(loadColumn.EQ(String("bar"))), strings.Replace(` WHERE(loadColumn.EQ(String("bar"))).
LIMIT(0),
strings.Replace(`
UPDATE db.table1 UPDATE db.table1
SET ''Load'' = ? SET ''Load'' = ?
WHERE ''Load'' = ?; WHERE ''Load'' = ?
`, "''", "`", -1), "foo", "bar") LIMIT ?;
`, "''", "`", -1), "foo", "bar", int64(0))
}
func LimitPanicStatement() {
joinedTable := table1.INNER_JOIN(table2, table1Col1.EQ(table2Col3))
joinedTable.UPDATE(table1ColInt).
SET(1).
WHERE(table1ColInt.GT_EQ(Int(33))).
LIMIT(5)
}
func TestUpdateWithMultiTableAndLimit(t *testing.T) {
assertPanicErr(t, func() { LimitPanicStatement() }, "jet: MySQL does not support LIMIT with multi-table UPDATE statements")
} }
func TestInvalidInputs(t *testing.T) { func TestInvalidInputs(t *testing.T) {

View file

@ -2,10 +2,11 @@ package mysql
import ( import (
"context" "context"
"github.com/go-jet/jet/v2/qrm"
"testing" "testing"
"time" "time"
"github.com/go-jet/jet/v2/qrm"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/go-jet/jet/v2/internal/testutils" "github.com/go-jet/jet/v2/internal/testutils"
@ -285,3 +286,56 @@ WHERE link.name = 'Bing';
require.NoError(t, err) require.NoError(t, err)
}) })
} }
func TestUpdateWithLimit(t *testing.T) {
t.Run("single table update with limit", func(t *testing.T) {
stmt := Link.
UPDATE(Link.Name).
SET(String("Updated Link")).
WHERE(Link.Name.NOT_EQ(String(""))).
LIMIT(2)
testutils.AssertDebugStatementSql(t, stmt, `
UPDATE test_sample.link
SET name = 'Updated Link'
WHERE link.name != ''
LIMIT 2;
`)
testutils.ExecuteInTxAndRollback(t, db, func(tx qrm.DB) {
// Execute update
testutils.AssertExec(t, stmt, tx)
requireLogged(t, stmt)
// Verify only 2 rows were updated
var updatedLinks []model.Link
err := Link.
SELECT(Link.AllColumns).
WHERE(Link.Name.EQ(String("Updated Link"))).
Query(tx, &updatedLinks)
require.NoError(t, err)
require.Equal(t, 2, len(updatedLinks))
})
})
t.Run("multi-table update with limit should panic", func(t *testing.T) {
defer func() {
r := recover()
require.NotNil(t, r)
require.Equal(t, "jet: MySQL does not support LIMIT with multi-table UPDATE statements", r)
}()
joinedTable := Link.
INNER_JOIN(Link2, Link.Name.EQ(Link2.Name))
stmt := joinedTable.
UPDATE(Link.Name).
SET(String("Updated Link")).
WHERE(Link.Name.NOT_EQ(String(""))).
LIMIT(2)
// Statement construction itself should panic
_ = stmt
})
}