[BUG] Update statement reserved word not escaped

Update statement, using MODEL struct, now generates escaped SQL identifier if column name is reserved word.
This commit is contained in:
go-jet 2021-05-09 17:17:14 +02:00
parent 063b17ca05
commit 256be8a406
3 changed files with 41 additions and 12 deletions

View file

@ -308,7 +308,7 @@ func (s *SetClause) Serialize(statementType StatementType, out *SQLBuilder, opti
panic("jet: nil column in columns list for SET clause") panic("jet: nil column in columns list for SET clause")
} }
out.WriteString(column.Name()) out.WriteIdentifier(column.Name())
out.WriteString(" = ") out.WriteString(" = ")

View file

@ -90,7 +90,7 @@ func AssertJSONFile(t *testing.T, data interface{}, testRelativePath string) {
// AssertStatementSql check if statement Sql() is the same as expectedQuery and expectedArgs // AssertStatementSql check if statement Sql() is the same as expectedQuery and expectedArgs
func AssertStatementSql(t *testing.T, query jet.Statement, expectedQuery string, expectedArgs ...interface{}) { func AssertStatementSql(t *testing.T, query jet.Statement, expectedQuery string, expectedArgs ...interface{}) {
queryStr, args := query.Sql() queryStr, args := query.Sql()
require.Equal(t, queryStr, expectedQuery) assertQueryString(t, queryStr, expectedQuery)
if len(expectedArgs) == 0 { if len(expectedArgs) == 0 {
return return
@ -117,12 +117,7 @@ func AssertDebugStatementSql(t *testing.T, query jet.Statement, expectedQuery st
} }
debuqSql := query.DebugSql() debuqSql := query.DebugSql()
if !assert.Equal(t, debuqSql, expectedQuery) { assertQueryString(t, debuqSql, expectedQuery)
fmt.Println("Expected: ")
fmt.Println(expectedQuery)
fmt.Println("Got: ")
fmt.Println(debuqSql)
}
} }
// AssertSerialize checks if clause serialize produces expected query and args // AssertSerialize checks if clause serialize produces expected query and args
@ -200,7 +195,7 @@ func AssertQueryPanicErr(t *testing.T, stmt jet.Statement, db qrm.DB, dest inter
require.Equal(t, r, errString) require.Equal(t, r, errString)
}() }()
stmt.Query(db, dest) _ = stmt.Query(db, dest)
} }
// AssertFileContent check if file content at filePath contains expectedContent text. // AssertFileContent check if file content at filePath contains expectedContent text.
@ -229,7 +224,22 @@ func AssertFileNamesEqual(t *testing.T, fileInfos []os.FileInfo, fileNames ...st
// AssertDeepEqual checks if actual and expected objects are deeply equal. // AssertDeepEqual checks if actual and expected objects are deeply equal.
func AssertDeepEqual(t *testing.T, actual, expected interface{}, msg ...string) { func AssertDeepEqual(t *testing.T, actual, expected interface{}, msg ...string) {
require.True(t, cmp.Equal(actual, expected), msg) if !assert.True(t, cmp.Equal(actual, expected), msg) {
printDiff(actual, expected)
}
}
func assertQueryString(t *testing.T, actual, expected string) {
if !assert.Equal(t, actual, expected) {
printDiff(actual, expected)
}
}
func printDiff(actual, expected interface{}) {
fmt.Println("Actual: ")
fmt.Println(actual)
fmt.Println("Expected: ")
fmt.Println(expected)
} }
// BoolPtr returns address of bool parameter // BoolPtr returns address of bool parameter

View file

@ -2,6 +2,7 @@ package mysql
import ( import (
"fmt" "fmt"
"strings"
"testing" "testing"
) )
@ -52,11 +53,29 @@ WHERE table1.col1 = ?;
). ).
WHERE(table1Col1.EQ(Int(2))) WHERE(table1Col1.EQ(Int(2)))
//fmt.Println(stmt.Sql())
assertStatementSql(t, stmt, expectedSQL, int64(2)) assertStatementSql(t, stmt, expectedSQL, int64(2))
} }
func TestUpdateReservedWorldColumn(t *testing.T) {
type table struct {
Load string
}
loadColumn := StringColumn("Load")
assertStatementSql(t,
table1.UPDATE(loadColumn).
MODEL(
table{
Load: "foo",
},
).
WHERE(loadColumn.EQ(String("bar"))), strings.Replace(`
UPDATE db.table1
SET ''Load'' = ?
WHERE ''Load'' = ?;
`, "''", "`", -1), "foo", "bar")
}
func TestInvalidInputs(t *testing.T) { func TestInvalidInputs(t *testing.T) {
assertStatementSqlErr(t, table1.UPDATE(table1ColInt).SET(1), "jet: WHERE clause not set") assertStatementSqlErr(t, table1.UPDATE(table1ColInt).SET(1), "jet: WHERE clause not set")
assertStatementSqlErr(t, table1.UPDATE(nil).SET(1), "jet: nil column in columns list for SET clause") assertStatementSqlErr(t, table1.UPDATE(nil).SET(1), "jet: nil column in columns list for SET clause")