MySQL update statement tests.

This commit is contained in:
go-jet 2019-08-02 11:08:24 +02:00
parent a46e8c1c51
commit 7660bdd8b5
12 changed files with 406 additions and 115 deletions

View file

@ -131,7 +131,7 @@ func (q *sqlBuilder) writeReturning(statement statementType, returning []project
}
if !q.dialect.SupportsReturning {
panic(q.dialect.Name + " dialect does not support RETURNING.")
panic("jet: " + q.dialect.Name + " dialect does not support RETURNING.")
}
q.newLine()

View file

@ -21,6 +21,7 @@ func newPostgresDialect() Dialect {
return "$" + strconv.Itoa(ord)
}
postgresDialect.SupportsReturning = true
postgresDialect.UpdateAssigment = postgresUpdateAssigment
return postgresDialect
}
@ -40,6 +41,7 @@ func newMySQLDialect() Dialect {
}
mySQLDialect.SupportsReturning = false
mySQLDialect.UpdateAssigment = mysqlUpdateAssigment
return mySQLDialect
}
@ -52,6 +54,7 @@ type Dialect struct {
AliasQuoteChar byte
IdentifierQuoteChar byte
ArgumentPlaceholder queryPlaceholderFunc
UpdateAssigment func(columns []column, values []clause, out *sqlBuilder) (err error)
SupportsReturning bool
}
@ -60,6 +63,63 @@ func (d *Dialect) serializeOverride(operator string) serializeOverride {
return d.SerializeOverrides[operator]
}
func mysqlUpdateAssigment(columns []column, values []clause, out *sqlBuilder) (err error) {
if len(columns) != len(values) {
return errors.New("jet: mismatch in numers of columns and values")
}
for i, column := range columns {
if i > 0 {
out.writeString(", ")
}
out.writeString(column.Name())
out.writeString(" = ")
if err = values[i].serialize(updateStatement, out); err != nil {
return err
}
}
return nil
}
func postgresUpdateAssigment(columns []column, values []clause, out *sqlBuilder) (err error) {
if len(columns) > 1 {
out.writeString("(")
}
err = serializeColumnNames(columns, out)
if err != nil {
return
}
if len(columns) > 1 {
out.writeString(")")
}
out.writeString("=")
if len(values) > 1 {
out.writeString("(")
}
err = serializeClauseList(updateStatement, values, out)
if err != nil {
return
}
if len(values) > 1 {
out.writeString(")")
}
return
}
type queryPlaceholderFunc func(ord int) string
func newDialect(name, packageName string) Dialect {

View file

@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"github.com/go-jet/jet"
"github.com/go-jet/jet/execution"
"gotest.tools/assert"
"io/ioutil"
"runtime"
@ -28,6 +29,24 @@ func JsonSave(v interface{}, path string) {
}
}
func AssertExec(t *testing.T, stmt jet.Statement, db execution.DB, rowsAffected ...int64) {
res, err := stmt.Exec(db)
assert.NilError(t, err)
rows, err := res.RowsAffected()
assert.NilError(t, err)
if len(rowsAffected) > 0 {
assert.Equal(t, rows, rowsAffected[0])
}
}
func AssertExecErr(t *testing.T, stmt jet.Statement, db execution.DB, errorStr string) {
_, err := stmt.Exec(db)
assert.Error(t, err, errorStr)
}
func AssertJSON(t *testing.T, expectedJSON string, data interface{}) {
jsonData, err := json.MarshalIndent(data, "", "\t")
assert.NilError(t, err)
@ -46,8 +65,8 @@ func AssertJSONFile(t *testing.T, data interface{}, jsonFilePath string) {
jsonData, err := json.MarshalIndent(data, "", "\t")
assert.NilError(t, err)
//assert.Assert(t, string(fileJSONData) == string(jsonData))
assert.DeepEqual(t, string(fileJSONData), string(jsonData))
assert.Assert(t, string(fileJSONData) == string(jsonData))
//assert.DeepEqual(t, string(fileJSONData), string(jsonData))
}
func AssertStatementSql(t *testing.T, query jet.Statement, expectedQuery string, expectedArgs ...interface{}) {

View file

@ -138,4 +138,18 @@ CREATE TABLE IF NOT EXISTS test_sample.link (
);
INSERT INTO test_sample.link (ID, url, name, description) VALUES
(0, 'http://www.youtube.com', 'Youtube' , '');
-- Link2 table --------------------
DROP TABLE IF EXISTS test_sample.link2;
CREATE TABLE IF NOT EXISTS test_sample.link2 (
id int PRIMARY KEY AUTO_INCREMENT,
url VARCHAR (255) NOT NULL,
name VARCHAR (255) NOT NULL,
description VARCHAR (255)
);
INSERT INTO test_sample.link2 (ID, url, name, description) VALUES
(0, 'http://www.youtube.com', 'Youtube' , '');

View file

@ -13,7 +13,7 @@ func TestCast(t *testing.T) {
query := SELECT(
CAST(String("test")).AS("CHAR CHARACTER SET utf8").AS("result.AS1"),
CAST(String("2011-02-02")).AS_DATE().AS("result.date"),
CAST(String("2011-02-02")).AS_DATE().AS("result.date1"),
CAST(String("14:06:10")).AS_TIME().AS("result.time"),
CAST(String("2011-02-02 14:06:10")).AS_DATETIME().AS("result.datetime"),
@ -27,7 +27,7 @@ func TestCast(t *testing.T) {
testutils.AssertStatementSql(t, query, `
SELECT CAST(? AS CHAR CHARACTER SET utf8) AS "result.AS1",
CAST(? AS DATE) AS "result.date",
CAST(? AS DATE) AS "result.date1",
CAST(? AS TIME) AS "result.time",
CAST(? AS DATETIME) AS "result.datetime",
CAST(? AS CHAR) AS "result.char1",
@ -41,7 +41,7 @@ FROM test_sample.all_types;
type Result struct {
As1 string
Date time.Time
Date1 time.Time
Time time.Time
DateTime time.Time
Char1 string
@ -59,7 +59,7 @@ FROM test_sample.all_types;
assert.DeepEqual(t, dest, Result{
As1: "test",
Date: *testutils.Date("2011-02-02"),
Date1: *testutils.Date("2011-02-02"),
Time: *testutils.TimeWithoutTimeZone("14:06:10"),
DateTime: *testutils.TimestampWithoutTimeZone("2011-02-02 14:06:10", 0),
Char1: "150",

View file

@ -137,69 +137,6 @@ ORDER BY payment.customer_id, SUM(payment.amount) ASC;
testutils.AssertJSONFile(t, dest, "mysql/testdata/customer_payment_sum.json")
}
func getRowLockTestData() map[SelectLock]string {
return map[SelectLock]string{
UPDATE(): "UPDATE",
SHARE(): "SHARE",
}
}
func TestRowLock(t *testing.T) {
expectedSQL := `
SELECT *
FROM dvds.address
LIMIT 3
OFFSET 1
FOR`
query := Address.
SELECT(STAR).
LIMIT(3).
OFFSET(1)
for lockType, lockTypeStr := range getRowLockTestData() {
query.FOR(lockType)
expectedQuery := expectedSQL + " " + lockTypeStr + ";\n"
testutils.AssertDebugStatementSql(t, query, expectedQuery, int64(3), int64(1))
tx, _ := db.Begin()
_, err := query.Exec(tx)
assert.NilError(t, err)
err = tx.Rollback()
assert.NilError(t, err)
}
for lockType, lockTypeStr := range getRowLockTestData() {
query.FOR(lockType.NOWAIT())
testutils.AssertDebugStatementSql(t, query, expectedSQL+" "+lockTypeStr+" NOWAIT;\n", int64(3), int64(1))
tx, _ := db.Begin()
_, err := query.Exec(tx)
assert.NilError(t, err)
err = tx.Rollback()
assert.NilError(t, err)
}
for lockType, lockTypeStr := range getRowLockTestData() {
query.FOR(lockType.SKIP_LOCKED())
testutils.AssertDebugStatementSql(t, query, expectedSQL+" "+lockTypeStr+" SKIP LOCKED;\n", int64(3), int64(1))
tx, _ := db.Begin()
_, err := query.Exec(tx)
assert.NilError(t, err)
err = tx.Rollback()
assert.NilError(t, err)
}
}
func TestSubQuery(t *testing.T) {
rRatingFilms := Film.
@ -385,3 +322,66 @@ LIMIT ?;
testutils.AssertJSONFile(t, dest, "./mysql/testdata/lang_film_actor_inventory_rental.json")
}
}
func getRowLockTestData() map[SelectLock]string {
return map[SelectLock]string{
UPDATE(): "UPDATE",
SHARE(): "SHARE",
}
}
func TestRowLock(t *testing.T) {
expectedSQL := `
SELECT *
FROM dvds.address
LIMIT 3
OFFSET 1
FOR`
query := Address.
SELECT(STAR).
LIMIT(3).
OFFSET(1)
for lockType, lockTypeStr := range getRowLockTestData() {
query.FOR(lockType)
expectedQuery := expectedSQL + " " + lockTypeStr + ";\n"
testutils.AssertDebugStatementSql(t, query, expectedQuery, int64(3), int64(1))
tx, _ := db.Begin()
_, err := query.Exec(tx)
assert.NilError(t, err)
err = tx.Rollback()
assert.NilError(t, err)
}
for lockType, lockTypeStr := range getRowLockTestData() {
query.FOR(lockType.NOWAIT())
testutils.AssertDebugStatementSql(t, query, expectedSQL+" "+lockTypeStr+" NOWAIT;\n", int64(3), int64(1))
tx, _ := db.Begin()
_, err := query.Exec(tx)
assert.NilError(t, err)
err = tx.Rollback()
assert.NilError(t, err)
}
for lockType, lockTypeStr := range getRowLockTestData() {
query.FOR(lockType.SKIP_LOCKED())
testutils.AssertDebugStatementSql(t, query, expectedSQL+" "+lockTypeStr+" SKIP LOCKED;\n", int64(3), int64(1))
tx, _ := db.Begin()
_, err := query.Exec(tx)
assert.NilError(t, err)
err = tx.Rollback()
assert.NilError(t, err)
}
}

224
tests/mysql/update_test.go Normal file
View file

@ -0,0 +1,224 @@
package mysql
import (
"context"
. "github.com/go-jet/jet"
"github.com/go-jet/jet/internal/testutils"
"github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/model"
. "github.com/go-jet/jet/tests/.gentestdata/mysql/test_sample/table"
"gotest.tools/assert"
"testing"
"time"
)
func TestUpdateValues(t *testing.T) {
setupLinkTableForUpdateTest(t)
query := Link.
UPDATE(Link.Name, Link.URL).
SET("Bong", "http://bong.com").
WHERE(Link.Name.EQ(String("Bing")))
var expectedSQL = `
UPDATE test_sample.link
SET name = 'Bong', url = 'http://bong.com'
WHERE link.name = 'Bing';
`
testutils.AssertDebugStatementSql(t, query, expectedSQL, "Bong", "http://bong.com", "Bing")
testutils.AssertExec(t, query, db)
links := []model.Link{}
err := Link.
SELECT(Link.AllColumns).
WHERE(Link.Name.EQ(String("Bong"))).
Query(db, &links)
assert.NilError(t, err)
assert.Equal(t, len(links), 1)
assert.DeepEqual(t, links[0], model.Link{
ID: 204,
URL: "http://bong.com",
Name: "Bong",
})
}
func TestUpdateWithSubQueries(t *testing.T) {
setupLinkTableForUpdateTest(t)
query := Link.
UPDATE(Link.Name, Link.URL).
SET(
SELECT(String("Bong")),
SELECT(Link2.URL).
FROM(Link2).
WHERE(Link2.Name.EQ(String("Youtube"))),
).
WHERE(Link.Name.EQ(String("Bing")))
expectedSQL := `
UPDATE test_sample.link
SET name = (
SELECT ?
), url = (
SELECT link2.url AS "link2.url"
FROM test_sample.link2
WHERE link2.name = ?
)
WHERE link.name = ?;
`
testutils.AssertStatementSql(t, query, expectedSQL, "Bong", "Youtube", "Bing")
testutils.AssertExec(t, query, db)
}
func TestUpdateAndReturning(t *testing.T) {
defer func() {
r := recover()
assert.Equal(t, r, "jet: MySQL dialect does not support RETURNING.")
}()
stmt := Link.
UPDATE(Link.Name, Link.URL).
SET("DuckDuckGo", "http://www.duckduckgo.com").
WHERE(Link.Name.EQ(String("Ask"))).
RETURNING(Link.AllColumns)
stmt.Query(db, &struct{}{})
}
func TestUpdateWithModelData(t *testing.T) {
setupLinkTableForUpdateTest(t)
link := model.Link{
ID: 201,
URL: "http://www.duckduckgo.com",
Name: "DuckDuckGo",
}
stmt := Link.
UPDATE(Link.AllColumns).
MODEL(link).
WHERE(Link.ID.EQ(Int(int64(link.ID))))
expectedSQL := `
UPDATE test_sample.link
SET id = ?, url = ?, name = ?, description = ?
WHERE link.id = ?;
`
testutils.AssertStatementSql(t, stmt, expectedSQL, int32(201), "http://www.duckduckgo.com", "DuckDuckGo", nil, int64(201))
testutils.AssertExec(t, stmt, db)
}
func TestUpdateWithModelDataAndPredefinedColumnList(t *testing.T) {
setupLinkTableForUpdateTest(t)
link := model.Link{
ID: 201,
URL: "http://www.duckduckgo.com",
Name: "DuckDuckGo",
}
updateColumnList := ColumnList{Link.Description, Link.Name, Link.URL}
stmt := Link.
UPDATE(updateColumnList).
MODEL(link).
WHERE(Link.ID.EQ(Int(int64(link.ID))))
var expectedSQL = `
UPDATE test_sample.link
SET description = NULL, name = 'DuckDuckGo', url = 'http://www.duckduckgo.com'
WHERE link.id = 201;
`
//fmt.Println(stmt.DebugSql())
testutils.AssertDebugStatementSql(t, stmt, expectedSQL, nil, "DuckDuckGo", "http://www.duckduckgo.com", int64(201))
testutils.AssertExec(t, stmt, db)
}
func TestUpdateWithInvalidModelData(t *testing.T) {
defer func() {
r := recover()
assert.Equal(t, r, "missing struct field for column : id")
}()
setupLinkTableForUpdateTest(t)
link := struct {
Ident int
URL string
Name string
Description *string
Rel *string
}{
Ident: 201,
URL: "http://www.duckduckgo.com",
Name: "DuckDuckGo",
}
stmt := Link.
UPDATE(Link.AllColumns).
MODEL(link).
WHERE(Link.ID.EQ(Int(int64(link.Ident))))
stmt.Sql()
}
func TestUpdateQueryContext(t *testing.T) {
setupLinkTableForUpdateTest(t)
updateStmt := Link.
UPDATE(Link.Name, Link.URL).
SET("Bong", "http://bong.com").
WHERE(Link.Name.EQ(String("Bing")))
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Microsecond)
defer cancel()
time.Sleep(10 * time.Millisecond)
dest := []model.Link{}
err := updateStmt.QueryContext(ctx, db, &dest)
assert.Error(t, err, "context deadline exceeded")
}
func TestUpdateExecContext(t *testing.T) {
setupLinkTableForUpdateTest(t)
updateStmt := Link.
UPDATE(Link.Name, Link.URL).
SET("Bong", "http://bong.com").
WHERE(Link.Name.EQ(String("Bing")))
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Microsecond)
defer cancel()
time.Sleep(10 * time.Millisecond)
_, err := updateStmt.ExecContext(ctx, db)
assert.Error(t, err, "context deadline exceeded")
}
func setupLinkTableForUpdateTest(t *testing.T) {
cleanUpLinkTable(t)
_, err := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Description).
VALUES(200, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT).
VALUES(201, "http://www.ask.com", "Ask", DEFAULT).
VALUES(202, "http://www.ask.com", "Ask", DEFAULT).
VALUES(203, "http://www.yahoo.com", "Yahoo", DEFAULT).
VALUES(204, "http://www.bing.com", "Bing", DEFAULT).
Exec(db)
assert.NilError(t, err)
}

View file

@ -23,7 +23,7 @@ WHERE link.name IN ('Gmail', 'Outlook');
WHERE(Link.Name.IN(String("Gmail"), String("Outlook")))
testutils.AssertDebugStatementSql(t, deleteStmt, expectedSQL, "Gmail", "Outlook")
assertExec(t, deleteStmt, 2)
AssertExec(t, deleteStmt, 2)
}
func TestDeleteWithWhereAndReturning(t *testing.T) {
@ -61,7 +61,7 @@ func initForDeleteTest(t *testing.T) {
VALUES("www.gmail.com", "Gmail", "Email service developed by Google").
VALUES("www.outlook.live.com", "Outlook", "Email service developed by Microsoft")
assertExec(t, stmt, 2)
AssertExec(t, stmt, 2)
}
func TestDeleteQueryContext(t *testing.T) {

View file

@ -88,7 +88,7 @@ INSERT INTO test_sample.link VALUES
testutils.AssertDebugStatementSql(t, stmt, expectedSQL,
100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial")
assertExec(t, stmt, 1)
AssertExec(t, stmt, 1)
}
func TestInsertModelObject(t *testing.T) {
@ -109,7 +109,7 @@ INSERT INTO test_sample.link (url, name) VALUES
testutils.AssertDebugStatementSql(t, query, expectedSQL, "http://www.duckduckgo.com", "Duck Duck go")
assertExec(t, query, 1)
AssertExec(t, query, 1)
}
func TestInsertModelObjectEmptyColumnList(t *testing.T) {
@ -131,7 +131,7 @@ INSERT INTO test_sample.link VALUES
testutils.AssertDebugStatementSql(t, query, expectedSQL, int32(1000), "http://www.duckduckgo.com", "Duck Duck go", nil)
assertExec(t, query, 1)
AssertExec(t, query, 1)
}
func TestInsertModelsObject(t *testing.T) {
@ -166,7 +166,7 @@ INSERT INTO test_sample.link (url, name) VALUES
"http://www.google.com", "Google",
"http://www.yahoo.com", "Yahoo")
assertExec(t, stmt, 3)
AssertExec(t, stmt, 3)
}
func TestInsertUsingMutableColumns(t *testing.T) {
@ -200,7 +200,7 @@ INSERT INTO test_sample.link (url, name, description) VALUES
"http://www.google.com", "Google", nil,
"http://www.yahoo.com", "Yahoo", nil)
assertExec(t, stmt, 4)
AssertExec(t, stmt, 4)
}
func TestInsertQuery(t *testing.T) {

View file

@ -26,7 +26,7 @@ WHERE link.name = 'Bing';
`
testutils.AssertDebugStatementSql(t, query, expectedSQL, "Bong", "http://bong.com", "Bing")
assertExec(t, query, 1)
AssertExec(t, query, 1)
links := []model.Link{}
@ -71,7 +71,7 @@ WHERE link.name = 'Bing';
testutils.AssertDebugStatementSql(t, query, expectedSQL, "Bong", "Bing", "Bing")
assertExec(t, query, 1)
AssertExec(t, query, 1)
}
func TestUpdateAndReturning(t *testing.T) {
@ -129,7 +129,7 @@ WHERE link.id = 0;
`
testutils.AssertDebugStatementSql(t, stmt, expectedSQL, int64(0), int64(0))
assertExec(t, stmt, 1)
AssertExec(t, stmt, 1)
}
func TestUpdateWithInvalidSelect(t *testing.T) {
@ -178,7 +178,7 @@ WHERE link.id = 201;
`
testutils.AssertDebugStatementSql(t, stmt, expectedSQL, int32(201), "http://www.duckduckgo.com", "DuckDuckGo", nil, int64(201))
assertExec(t, stmt, 1)
AssertExec(t, stmt, 1)
}
func TestUpdateWithModelDataAndPredefinedColumnList(t *testing.T) {
@ -205,7 +205,7 @@ WHERE link.id = 201;
`
testutils.AssertDebugStatementSql(t, stmt, expectedSQL, nil, "DuckDuckGo", "http://www.duckduckgo.com", int64(201))
assertExec(t, stmt, 1)
AssertExec(t, stmt, 1)
}
func TestUpdateWithInvalidModelData(t *testing.T) {

View file

@ -9,7 +9,7 @@ import (
"testing"
)
func assertExec(t *testing.T, stmt jet.Statement, rowsAffected int64) {
func AssertExec(t *testing.T, stmt jet.Statement, rowsAffected int64) {
res, err := stmt.Exec(db)
assert.NilError(t, err)

View file

@ -23,26 +23,26 @@ func newUpdateStatement(table WritableTable, columns []column) UpdateStatement {
return &updateStatementImpl{
table: table,
columns: columns,
row: make([]clause, 0, len(columns)),
values: make([]clause, 0, len(columns)),
}
}
type updateStatementImpl struct {
table WritableTable
columns []column
row []clause
values []clause
where BoolExpression
returning []projection
}
func (u *updateStatementImpl) SET(value interface{}, values ...interface{}) UpdateStatement {
u.row = unwindRowFromValues(value, values)
u.values = unwindRowFromValues(value, values)
return u
}
func (u *updateStatementImpl) MODEL(data interface{}) UpdateStatement {
u.row = unwindRowFromModel(u.columns, data)
u.values = unwindRowFromModel(u.columns, data)
return u
}
@ -82,43 +82,17 @@ func (u *updateStatementImpl) Sql(dialect ...Dialect) (query string, args []inte
return "", nil, errors.New("jet: no columns selected")
}
if len(u.row) == 0 {
if len(u.values) == 0 {
return "", nil, errors.New("jet: no values to updated")
}
out.newLine()
out.writeString("SET")
if len(u.columns) > 1 {
out.writeString("(")
}
err = serializeColumnNames(u.columns, out)
if err != nil {
if err = out.dialect.UpdateAssigment(u.columns, u.values, out); err != nil {
return
}
if len(u.columns) > 1 {
out.writeString(")")
}
out.writeString("=")
if len(u.row) > 1 {
out.writeString("(")
}
err = serializeClauseList(updateStatement, u.row, out)
if err != nil {
return
}
if len(u.row) > 1 {
out.writeString(")")
}
if u.where == nil {
return "", nil, errors.New("jet: WHERE clause not set")
}