diff --git a/mysql/update_statement.go b/mysql/update_statement.go index b0acaef..d47378a 100644 --- a/mysql/update_statement.go +++ b/mysql/update_statement.go @@ -12,6 +12,7 @@ type UpdateStatement interface { MODEL(data interface{}) UpdateStatement WHERE(expression BoolExpression) UpdateStatement + LIMIT(limit int64) UpdateStatement } type updateStatementImpl struct { @@ -21,6 +22,7 @@ type updateStatementImpl struct { Set jet.SetClause SetNew jet.SetClauseNew Where jet.ClauseWhere + Limit jet.ClauseLimit } func newUpdateStatement(table Table, columns []jet.Column) UpdateStatement { @@ -29,11 +31,13 @@ func newUpdateStatement(table Table, columns []jet.Column) UpdateStatement { &update.Update, &update.Set, &update.SetNew, - &update.Where) + &update.Where, + &update.Limit) update.Update.Table = table update.Set.Columns = columns update.Where.Mandatory = true + update.Limit.Count = -1 // Initialize to -1 to indicate no LIMIT return update } @@ -67,3 +71,11 @@ func (u *updateStatementImpl) WHERE(expression BoolExpression) UpdateStatement { u.Where.Condition = expression return u } + +func (u *updateStatementImpl) LIMIT(limit int64) UpdateStatement { + if _, isJoinTable := u.Update.Table.(*joinTable); isJoinTable { + panic("jet: MySQL does not support LIMIT with multi-table UPDATE statements") + } + u.Limit.Count = limit + return u +} diff --git a/mysql/update_statement_test.go b/mysql/update_statement_test.go index f5284a7..5c9be2f 100644 --- a/mysql/update_statement_test.go +++ b/mysql/update_statement_test.go @@ -6,6 +6,21 @@ import ( "testing" ) +func TestUpdateWithLimit(t *testing.T) { + expectedSQL := ` +UPDATE db.table1 +SET col_int = ? +WHERE table1.col_int >= ? +LIMIT ?; +` + stmt := table1.UPDATE(table1ColInt). + SET(1). + WHERE(table1ColInt.GT_EQ(Int(33))). + LIMIT(5) + + assertStatementSql(t, stmt, expectedSQL, 1, int64(33), int64(5)) +} + func TestUpdateWithOneValue(t *testing.T) { expectedSQL := ` UPDATE db.table1 @@ -69,11 +84,26 @@ func TestUpdateReservedWorldColumn(t *testing.T) { Load: "foo", }, ). - WHERE(loadColumn.EQ(String("bar"))), strings.Replace(` + WHERE(loadColumn.EQ(String("bar"))). + LIMIT(0), + strings.Replace(` UPDATE db.table1 SET ''Load'' = ? -WHERE ''Load'' = ?; -`, "''", "`", -1), "foo", "bar") +WHERE ''Load'' = ? +LIMIT ?; +`, "''", "`", -1), "foo", "bar", int64(0)) +} + +func LimitPanicStatement() { + joinedTable := table1.INNER_JOIN(table2, table1Col1.EQ(table2Col3)) + joinedTable.UPDATE(table1ColInt). + SET(1). + WHERE(table1ColInt.GT_EQ(Int(33))). + LIMIT(5) +} + +func TestUpdateWithMultiTableAndLimit(t *testing.T) { + assertPanicErr(t, func() { LimitPanicStatement() }, "jet: MySQL does not support LIMIT with multi-table UPDATE statements") } func TestInvalidInputs(t *testing.T) { diff --git a/tests/mysql/update_test.go b/tests/mysql/update_test.go index cf9a7ea..2376bcb 100644 --- a/tests/mysql/update_test.go +++ b/tests/mysql/update_test.go @@ -2,10 +2,11 @@ package mysql import ( "context" - "github.com/go-jet/jet/v2/qrm" "testing" "time" + "github.com/go-jet/jet/v2/qrm" + "github.com/stretchr/testify/require" "github.com/go-jet/jet/v2/internal/testutils" @@ -285,3 +286,56 @@ WHERE link.name = 'Bing'; require.NoError(t, err) }) } + +func TestUpdateWithLimit(t *testing.T) { + t.Run("single table update with limit", func(t *testing.T) { + stmt := Link. + UPDATE(Link.Name). + SET(String("Updated Link")). + WHERE(Link.Name.NOT_EQ(String(""))). + LIMIT(2) + + testutils.AssertDebugStatementSql(t, stmt, ` +UPDATE test_sample.link +SET name = 'Updated Link' +WHERE link.name != '' +LIMIT 2; +`) + + testutils.ExecuteInTxAndRollback(t, db, func(tx qrm.DB) { + // Execute update + testutils.AssertExec(t, stmt, tx) + requireLogged(t, stmt) + + // Verify only 2 rows were updated + var updatedLinks []model.Link + err := Link. + SELECT(Link.AllColumns). + WHERE(Link.Name.EQ(String("Updated Link"))). + Query(tx, &updatedLinks) + + require.NoError(t, err) + require.Equal(t, 2, len(updatedLinks)) + }) + }) + + t.Run("multi-table update with limit should panic", func(t *testing.T) { + defer func() { + r := recover() + require.NotNil(t, r) + require.Equal(t, "jet: MySQL does not support LIMIT with multi-table UPDATE statements", r) + }() + + joinedTable := Link. + INNER_JOIN(Link2, Link.Name.EQ(Link2.Name)) + + stmt := joinedTable. + UPDATE(Link.Name). + SET(String("Updated Link")). + WHERE(Link.Name.NOT_EQ(String(""))). + LIMIT(2) + + // Statement construction itself should panic + _ = stmt + }) +}