diff --git a/internal/jet/clause.go b/internal/jet/clause.go index 95283a1..92c9df9 100644 --- a/internal/jet/clause.go +++ b/internal/jet/clause.go @@ -184,14 +184,6 @@ func (s *ClauseSetStmtOperator) Serialize(statementType StatementType, out *SqlB panic("jet: UNION Statement must contain at least two SELECT statements") } - wrap := s.OrderBy.List != nil || s.Limit.Count >= 0 || s.Offset.Count >= 0 - - if wrap { - out.NewLine() - out.WriteString("(") - out.IncreaseIdent() - } - for i, selectStmt := range s.Selects { out.NewLine() if i > 0 { @@ -210,12 +202,6 @@ func (s *ClauseSetStmtOperator) Serialize(statementType StatementType, out *SqlB selectStmt.serialize(statementType, out) } - if wrap { - out.DecreaseIdent() - out.NewLine() - out.WriteString(")") - } - s.OrderBy.Serialize(statementType, out) s.Limit.Serialize(statementType, out) s.Offset.Serialize(statementType, out) @@ -391,26 +377,19 @@ func (d *ClauseStatementBegin) Serialize(statementType StatementType, out *SqlBu } } -type ClauseString struct { - Name string - Data string -} - -func (d *ClauseString) Serialize(statementType StatementType, out *SqlBuilder) { - out.NewLine() - out.WriteString(d.Name) - out.WriteString(d.Data) -} - type ClauseOptional struct { - Name string - Show bool + Name string + Show bool + InNewLine bool } func (d *ClauseOptional) Serialize(statementType StatementType, out *SqlBuilder) { if !d.Show { return } + if d.InNewLine { + out.NewLine() + } out.WriteString(d.Name) } diff --git a/mysql/select_statement.go b/mysql/select_statement.go index d205bbb..11a1563 100644 --- a/mysql/select_statement.go +++ b/mysql/select_statement.go @@ -23,6 +23,7 @@ type SelectStatement interface { LIMIT(limit int64) SelectStatement OFFSET(offset int64) SelectStatement FOR(lock SelectLock) SelectStatement + LOCK_IN_SHARE_MODE() SelectStatement UNION(rhs SelectStatement) SetStatement UNION_ALL(rhs SelectStatement) SetStatement @@ -39,7 +40,7 @@ func newSelectStatement(table ReadableTable, projections []jet.Projection) Selec newSelect := &selectStatementImpl{} newSelect.ExpressionStatementImpl.StatementImpl = jet.NewStatementImpl(Dialect, jet.SelectStatementType, newSelect, &newSelect.Select, &newSelect.From, &newSelect.Where, &newSelect.GroupBy, &newSelect.Having, &newSelect.OrderBy, - &newSelect.Limit, &newSelect.Offset, &newSelect.For) + &newSelect.Limit, &newSelect.Offset, &newSelect.For, &newSelect.ShareLock) newSelect.ExpressionStatementImpl.ExpressionInterfaceImpl.Parent = newSelect @@ -47,6 +48,8 @@ func newSelectStatement(table ReadableTable, projections []jet.Projection) Selec newSelect.From.Table = table newSelect.Limit.Count = -1 newSelect.Offset.Count = -1 + newSelect.ShareLock.Name = "LOCK IN SHARE MODE" + newSelect.ShareLock.InNewLine = true newSelect.setOperatorsImpl.parent = newSelect @@ -57,15 +60,16 @@ type selectStatementImpl struct { jet.ExpressionStatementImpl setOperatorsImpl - Select jet.ClauseSelect - From jet.ClauseFrom - Where jet.ClauseWhere - GroupBy jet.ClauseGroupBy - Having jet.ClauseHaving - OrderBy jet.ClauseOrderBy - Limit jet.ClauseLimit - Offset jet.ClauseOffset - For jet.ClauseFor + Select jet.ClauseSelect + From jet.ClauseFrom + Where jet.ClauseWhere + GroupBy jet.ClauseGroupBy + Having jet.ClauseHaving + OrderBy jet.ClauseOrderBy + Limit jet.ClauseLimit + Offset jet.ClauseOffset + For jet.ClauseFor + ShareLock jet.ClauseOptional } func (s *selectStatementImpl) DISTINCT() SelectStatement { @@ -113,6 +117,11 @@ func (s *selectStatementImpl) FOR(lock SelectLock) SelectStatement { return s } +func (s *selectStatementImpl) LOCK_IN_SHARE_MODE() SelectStatement { + s.ShareLock.Show = true + return s +} + func (s *selectStatementImpl) AsTable(alias string) SelectTable { return newSelectTable(s, alias) } diff --git a/mysql/select_statement_test.go b/mysql/select_statement_test.go index e7f3d7e..9655192 100644 --- a/mysql/select_statement_test.go +++ b/mysql/select_statement_test.go @@ -124,3 +124,11 @@ FROM db.table1 FOR SHARE NOWAIT; `) } + +func TestSelect_LOCK_IN_SHARE_MODE(t *testing.T) { + testutils.AssertStatementSql(t, SELECT(table1ColBool).FROM(table1).LOCK_IN_SHARE_MODE(), ` +SELECT table1.col_bool AS "table1.col_bool" +FROM db.table1 +LOCK IN SHARE MODE; +`) +} diff --git a/tests/mysql/alltypes_test.go b/tests/mysql/alltypes_test.go index 7adef82..ab0b16d 100644 --- a/tests/mysql/alltypes_test.go +++ b/tests/mysql/alltypes_test.go @@ -26,6 +26,12 @@ func TestAllTypes(t *testing.T) { assert.NilError(t, err) + assert.Equal(t, len(dest), 2) + + if sourceIsMariaDB() { // MariaDB saves current timestamp in a case of NULL value insert + dest[1].TimestampPtr = nil + } + //testutils.PrintJson(dest) testutils.AssertJSON(t, dest, allTypesJson) } @@ -788,7 +794,20 @@ LIMIT ?; //testutils.PrintJson(dest) - testutils.AssertJSON(t, dest, ` + if sourceIsMariaDB() { + testutils.AssertJSON(t, dest, ` +{ + "Date": "2009-11-17T00:00:00Z", + "DateT": "2009-11-17T00:00:00Z", + "Time": "0000-01-01T20:34:58Z", + "TimeT": "0000-01-01T19:34:58Z", + "DateTime": "2009-11-17T19:34:58Z", + "Timestamp": "2019-08-06T10:10:30Z", + "TimestampT": "2009-11-17T19:34:58Z" +} +`) + } else { + testutils.AssertJSON(t, dest, ` { "Date": "2009-11-17T00:00:00Z", "DateT": "2009-11-17T00:00:00Z", @@ -799,6 +818,8 @@ LIMIT ?; "TimestampT": "2009-11-17T19:34:58.351387Z" } `) + } + } var allTypesJson = ` diff --git a/tests/mysql/main_test.go b/tests/mysql/main_test.go index f544f18..eb19a68 100644 --- a/tests/mysql/main_test.go +++ b/tests/mysql/main_test.go @@ -2,6 +2,7 @@ package mysql import ( "database/sql" + "flag" "github.com/go-jet/jet/tests/dbconfig" _ "github.com/go-sql-driver/mysql" @@ -13,6 +14,19 @@ import ( var db *sql.DB +var source string + +const MariaDB = "MariaDB" + +func init() { + flag.StringVar(&source, "source", "", "MySQL or MariaDB") + flag.Parse() +} + +func sourceIsMariaDB() bool { + return source == MariaDB +} + func TestMain(m *testing.M) { defer profile.Start().Stop() diff --git a/tests/mysql/select_test.go b/tests/mysql/select_test.go index 637ed05..8680a49 100644 --- a/tests/mysql/select_test.go +++ b/tests/mysql/select_test.go @@ -174,6 +174,10 @@ func TestSubQuery(t *testing.T) { } func TestSelectAndUnionInProjection(t *testing.T) { + if sourceIsMariaDB() { + return + } + expectedSQL := ` SELECT payment.payment_id AS "payment.payment_id", ( @@ -183,19 +187,17 @@ SELECT payment.payment_id AS "payment.payment_id", ), ( ( - ( - SELECT payment.payment_id AS "payment.payment_id" - FROM dvds.payment - LIMIT ? - OFFSET ? - ) - UNION - ( - SELECT payment.payment_id AS "payment.payment_id" - FROM dvds.payment - LIMIT ? - OFFSET ? - ) + SELECT payment.payment_id AS "payment.payment_id" + FROM dvds.payment + LIMIT ? + OFFSET ? + ) + UNION + ( + SELECT payment.payment_id AS "payment.payment_id" + FROM dvds.payment + LIMIT ? + OFFSET ? ) LIMIT ? ) @@ -223,19 +225,17 @@ LIMIT ?; func TestSelectUNION(t *testing.T) { expectedSQL := ` ( - ( - SELECT payment.payment_id AS "payment.payment_id" - FROM dvds.payment - LIMIT ? - OFFSET ? - ) - UNION - ( - SELECT payment.payment_id AS "payment.payment_id" - FROM dvds.payment - LIMIT ? - OFFSET ? - ) + SELECT payment.payment_id AS "payment.payment_id" + FROM dvds.payment + LIMIT ? + OFFSET ? +) +UNION +( + SELECT payment.payment_id AS "payment.payment_id" + FROM dvds.payment + LIMIT ? + OFFSET ? ) LIMIT ?; ` @@ -260,19 +260,17 @@ LIMIT ?; func TestSelectUNION_ALL(t *testing.T) { expectedSQL := ` ( - ( - SELECT payment.payment_id AS "payment.payment_id" - FROM dvds.payment - LIMIT ? - OFFSET ? - ) - UNION ALL - ( - SELECT payment.payment_id AS "payment.payment_id" - FROM dvds.payment - LIMIT ? - OFFSET ? - ) + SELECT payment.payment_id AS "payment.payment_id" + FROM dvds.payment + LIMIT ? + OFFSET ? +) +UNION ALL +( + SELECT payment.payment_id AS "payment.payment_id" + FROM dvds.payment + LIMIT ? + OFFSET ? ) ORDER BY "payment.payment_id" LIMIT ? @@ -408,6 +406,11 @@ LIMIT ?; } func getRowLockTestData() map[SelectLock]string { + if sourceIsMariaDB() { + return map[SelectLock]string{ + UPDATE(): "UPDATE", + } + } return map[SelectLock]string{ UPDATE(): "UPDATE", SHARE(): "SHARE", @@ -455,6 +458,10 @@ FOR` assert.NilError(t, err) } + if sourceIsMariaDB() { + return + } + for lockType, lockTypeStr := range getRowLockTestData() { query.FOR(lockType.SKIP_LOCKED()) @@ -496,3 +503,23 @@ SELECT true, err := query.Query(db, &struct{}{}) assert.NilError(t, err) } + +func TestLockInShareMode(t *testing.T) { + expectedSQL := ` +SELECT * +FROM dvds.address +LIMIT 3 +OFFSET 1 +LOCK IN SHARE MODE; +` + query := Address. + SELECT(STAR). + LIMIT(3). + OFFSET(1). + LOCK_IN_SHARE_MODE() + + testutils.AssertDebugStatementSql(t, query, expectedSQL) + + err := query.Query(db, &struct{}{}) + assert.NilError(t, err) +} diff --git a/tests/postgres/chinook_db_test.go b/tests/postgres/chinook_db_test.go index 5b7cf01..c4dbd28 100644 --- a/tests/postgres/chinook_db_test.go +++ b/tests/postgres/chinook_db_test.go @@ -215,21 +215,19 @@ func TestUnionForQuotedNames(t *testing.T) { //fmt.Println(stmt.DebugSql()) testutils.AssertDebugStatementSql(t, stmt, ` ( - ( - SELECT "Album"."AlbumId" AS "Album.AlbumId", - "Album"."Title" AS "Album.Title", - "Album"."ArtistId" AS "Album.ArtistId" - FROM chinook."Album" - WHERE "Album"."AlbumId" = 1 - ) - UNION ALL - ( - SELECT "Album"."AlbumId" AS "Album.AlbumId", - "Album"."Title" AS "Album.Title", - "Album"."ArtistId" AS "Album.ArtistId" - FROM chinook."Album" - WHERE "Album"."AlbumId" = 2 - ) + SELECT "Album"."AlbumId" AS "Album.AlbumId", + "Album"."Title" AS "Album.Title", + "Album"."ArtistId" AS "Album.ArtistId" + FROM chinook."Album" + WHERE "Album"."AlbumId" = 1 +) +UNION ALL +( + SELECT "Album"."AlbumId" AS "Album.AlbumId", + "Album"."Title" AS "Album.Title", + "Album"."ArtistId" AS "Album.ArtistId" + FROM chinook."Album" + WHERE "Album"."AlbumId" = 2 ) ORDER BY "Album.AlbumId"; `, int64(1), int64(2)) diff --git a/tests/postgres/select_test.go b/tests/postgres/select_test.go index 8d63d8c..38b5e1e 100644 --- a/tests/postgres/select_test.go +++ b/tests/postgres/select_test.go @@ -128,19 +128,17 @@ SELECT payment.payment_id AS "payment.payment_id", ), ( ( - ( - SELECT payment.payment_id AS "payment.payment_id" - FROM dvds.payment - LIMIT 1 - OFFSET 10 - ) - UNION - ( - SELECT payment.payment_id AS "payment.payment_id" - FROM dvds.payment - LIMIT 1 - OFFSET 2 - ) + SELECT payment.payment_id AS "payment.payment_id" + FROM dvds.payment + LIMIT 1 + OFFSET 10 + ) + UNION + ( + SELECT payment.payment_id AS "payment.payment_id" + FROM dvds.payment + LIMIT 1 + OFFSET 2 ) LIMIT 1 ) @@ -1217,19 +1215,17 @@ ORDER BY payment.payment_date ASC; func TestUnion(t *testing.T) { expectedQuery := ` ( - ( - SELECT payment.payment_id AS "payment.payment_id", - payment.amount AS "payment.amount" - FROM dvds.payment - WHERE payment.amount <= 100 - ) - UNION ALL - ( - SELECT payment.payment_id AS "payment.payment_id", - payment.amount AS "payment.amount" - FROM dvds.payment - WHERE payment.amount >= 200 - ) + SELECT payment.payment_id AS "payment.payment_id", + payment.amount AS "payment.amount" + FROM dvds.payment + WHERE payment.amount <= 100 +) +UNION ALL +( + SELECT payment.payment_id AS "payment.payment_id", + payment.amount AS "payment.amount" + FROM dvds.payment + WHERE payment.amount >= 200 ) ORDER BY "payment.payment_id" ASC, "payment.amount" DESC LIMIT 10