diff --git a/internal/jet/clause.go b/internal/jet/clause.go index 42dad7e..a6a49d8 100644 --- a/internal/jet/clause.go +++ b/internal/jet/clause.go @@ -308,7 +308,7 @@ func (s *SetClause) Serialize(statementType StatementType, out *SQLBuilder, opti panic("jet: nil column in columns list for SET clause") } - out.WriteString(column.Name()) + out.WriteIdentifier(column.Name()) out.WriteString(" = ") diff --git a/internal/testutils/test_utils.go b/internal/testutils/test_utils.go index 272cf5f..9dd6318 100644 --- a/internal/testutils/test_utils.go +++ b/internal/testutils/test_utils.go @@ -90,7 +90,7 @@ func AssertJSONFile(t *testing.T, data interface{}, testRelativePath string) { // AssertStatementSql check if statement Sql() is the same as expectedQuery and expectedArgs func AssertStatementSql(t *testing.T, query jet.Statement, expectedQuery string, expectedArgs ...interface{}) { queryStr, args := query.Sql() - require.Equal(t, queryStr, expectedQuery) + assertQueryString(t, queryStr, expectedQuery) if len(expectedArgs) == 0 { return @@ -117,12 +117,7 @@ func AssertDebugStatementSql(t *testing.T, query jet.Statement, expectedQuery st } debuqSql := query.DebugSql() - if !assert.Equal(t, debuqSql, expectedQuery) { - fmt.Println("Expected: ") - fmt.Println(expectedQuery) - fmt.Println("Got: ") - fmt.Println(debuqSql) - } + assertQueryString(t, debuqSql, expectedQuery) } // AssertSerialize checks if clause serialize produces expected query and args @@ -200,7 +195,7 @@ func AssertQueryPanicErr(t *testing.T, stmt jet.Statement, db qrm.DB, dest inter require.Equal(t, r, errString) }() - stmt.Query(db, dest) + _ = stmt.Query(db, dest) } // AssertFileContent check if file content at filePath contains expectedContent text. @@ -229,7 +224,22 @@ func AssertFileNamesEqual(t *testing.T, fileInfos []os.FileInfo, fileNames ...st // AssertDeepEqual checks if actual and expected objects are deeply equal. func AssertDeepEqual(t *testing.T, actual, expected interface{}, msg ...string) { - require.True(t, cmp.Equal(actual, expected), msg) + if !assert.True(t, cmp.Equal(actual, expected), msg) { + printDiff(actual, expected) + } +} + +func assertQueryString(t *testing.T, actual, expected string) { + if !assert.Equal(t, actual, expected) { + printDiff(actual, expected) + } +} + +func printDiff(actual, expected interface{}) { + fmt.Println("Actual: ") + fmt.Println(actual) + fmt.Println("Expected: ") + fmt.Println(expected) } // BoolPtr returns address of bool parameter diff --git a/mysql/update_statement_test.go b/mysql/update_statement_test.go index fe3be01..f5284a7 100644 --- a/mysql/update_statement_test.go +++ b/mysql/update_statement_test.go @@ -2,6 +2,7 @@ package mysql import ( "fmt" + "strings" "testing" ) @@ -52,11 +53,29 @@ WHERE table1.col1 = ?; ). WHERE(table1Col1.EQ(Int(2))) - //fmt.Println(stmt.Sql()) - assertStatementSql(t, stmt, expectedSQL, int64(2)) } +func TestUpdateReservedWorldColumn(t *testing.T) { + type table struct { + Load string + } + + loadColumn := StringColumn("Load") + assertStatementSql(t, + table1.UPDATE(loadColumn). + MODEL( + table{ + Load: "foo", + }, + ). + WHERE(loadColumn.EQ(String("bar"))), strings.Replace(` +UPDATE db.table1 +SET ''Load'' = ? +WHERE ''Load'' = ?; +`, "''", "`", -1), "foo", "bar") +} + func TestInvalidInputs(t *testing.T) { assertStatementSqlErr(t, table1.UPDATE(table1ColInt).SET(1), "jet: WHERE clause not set") assertStatementSqlErr(t, table1.UPDATE(nil).SET(1), "jet: nil column in columns list for SET clause")