diff --git a/internal/jet/clause.go b/internal/jet/clause.go index 8654547..42dad7e 100644 --- a/internal/jet/clause.go +++ b/internal/jet/clause.go @@ -45,19 +45,25 @@ func (s *ClauseSelect) Serialize(statementType StatementType, out *SQLBuilder, o // ClauseFrom struct type ClauseFrom struct { - Table Serializer + Tables []Serializer } // Serialize serializes clause into SQLBuilder func (f *ClauseFrom) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { - if f.Table == nil { + if len(f.Tables) == 0 { // SELECT statement does not have to have FROM clause return } out.NewLine() out.WriteString("FROM") out.IncreaseIdent() - f.Table.serialize(statementType, out, FallTrough(options)...) + for i, table := range f.Tables { + if i > 0 { + out.WriteString(",") + out.NewLine() + } + table.serialize(statementType, out, FallTrough(options)...) + } out.DecreaseIdent() } diff --git a/internal/jet/select_table.go b/internal/jet/select_table.go index 52689d4..1421b14 100644 --- a/internal/jet/select_table.go +++ b/internal/jet/select_table.go @@ -13,8 +13,8 @@ type selectTableImpl struct { } // NewSelectTable func -func NewSelectTable(selectStmt SerializerStatement, alias string) SelectTable { - selectTable := &selectTableImpl{selectStmt: selectStmt, alias: alias} +func NewSelectTable(selectStmt SerializerStatement, alias string) selectTableImpl { + selectTable := selectTableImpl{selectStmt: selectStmt, alias: alias} return selectTable } @@ -38,3 +38,21 @@ func (s selectTableImpl) serialize(statement StatementType, out *SQLBuilder, opt out.WriteString("AS") out.WriteIdentifier(s.alias) } + +// -------------------------------------- + +type lateralImpl struct { + selectTableImpl +} + +func NewLateral(selectStmt SerializerStatement, alias string) SelectTable { + return lateralImpl{selectTableImpl: NewSelectTable(selectStmt, alias)} +} + +func (s lateralImpl) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) { + out.WriteString("LATERAL") + s.selectStmt.serialize(statement, out) + + out.WriteString("AS") + out.WriteIdentifier(s.alias) +} diff --git a/internal/jet/table.go b/internal/jet/table.go index 78abb99..c69a1dc 100644 --- a/internal/jet/table.go +++ b/internal/jet/table.go @@ -136,9 +136,6 @@ func (t *joinTableImpl) TableName() string { return "" } -func (t *joinTableImpl) AS(alias string) { -} - func (t *joinTableImpl) columns() []Column { var ret []Column diff --git a/internal/testutils/test_utils.go b/internal/testutils/test_utils.go index 556ed6c..272cf5f 100644 --- a/internal/testutils/test_utils.go +++ b/internal/testutils/test_utils.go @@ -8,6 +8,7 @@ import ( "github.com/go-jet/jet/v2/internal/utils" "github.com/go-jet/jet/v2/qrm" "github.com/google/uuid" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "io/ioutil" "os" @@ -116,7 +117,12 @@ func AssertDebugStatementSql(t *testing.T, query jet.Statement, expectedQuery st } debuqSql := query.DebugSql() - require.Equal(t, debuqSql, expectedQuery) + if !assert.Equal(t, debuqSql, expectedQuery) { + fmt.Println("Expected: ") + fmt.Println(expectedQuery) + fmt.Println("Got: ") + fmt.Println(debuqSql) + } } // AssertSerialize checks if clause serialize produces expected query and args diff --git a/mysql/lateral.go b/mysql/lateral.go new file mode 100644 index 0000000..8ba974b --- /dev/null +++ b/mysql/lateral.go @@ -0,0 +1,23 @@ +package mysql + +import "github.com/go-jet/jet/v2/internal/jet" + +func LATERAL(selectStmt SelectStatement) lateralImpl { + return lateralImpl{ + selectStmt: selectStmt, + } +} + +type lateralImpl struct { + selectStmt SelectStatement +} + +func (l lateralImpl) AS(alias string) SelectTable { + subQuery := &selectTableImpl{ + SelectTable: jet.NewLateral(l.selectStmt, alias), + } + + subQuery.readableTableInterfaceImpl.parent = subQuery + + return subQuery +} diff --git a/mysql/select_statement.go b/mysql/select_statement.go index 37ac96b..fa6dd9c 100644 --- a/mysql/select_statement.go +++ b/mysql/select_statement.go @@ -41,7 +41,7 @@ type SelectStatement interface { Expression DISTINCT() SelectStatement - FROM(table ReadableTable) SelectStatement + FROM(tables ...ReadableTable) SelectStatement WHERE(expression BoolExpression) SelectStatement GROUP_BY(groupByClauses ...jet.GroupByClause) SelectStatement HAVING(boolExpression BoolExpression) SelectStatement @@ -70,7 +70,9 @@ func newSelectStatement(table ReadableTable, projections []Projection) SelectSta &newSelect.Limit, &newSelect.Offset, &newSelect.For, &newSelect.ShareLock) newSelect.Select.ProjectionList = projections - newSelect.From.Table = table + if table != nil { + newSelect.From.Tables = []jet.Serializer{table} + } newSelect.Limit.Count = -1 newSelect.Offset.Count = -1 newSelect.ShareLock.Name = "LOCK IN SHARE MODE" @@ -103,8 +105,10 @@ func (s *selectStatementImpl) DISTINCT() SelectStatement { return s } -func (s *selectStatementImpl) FROM(table ReadableTable) SelectStatement { - s.From.Table = table +func (s *selectStatementImpl) FROM(tables ...ReadableTable) SelectStatement { + for _, table := range tables { + s.From.Tables = append(s.From.Tables, table) + } return s } diff --git a/postgres/lateral.go b/postgres/lateral.go new file mode 100644 index 0000000..8d2d4f8 --- /dev/null +++ b/postgres/lateral.go @@ -0,0 +1,23 @@ +package postgres + +import "github.com/go-jet/jet/v2/internal/jet" + +func LATERAL(selectStmt SelectStatement) lateralImpl { + return lateralImpl{ + selectStmt: selectStmt, + } +} + +type lateralImpl struct { + selectStmt SelectStatement +} + +func (l lateralImpl) AS(alias string) SelectTable { + subQuery := &selectTableImpl{ + SelectTable: jet.NewLateral(l.selectStmt, alias), + } + + subQuery.readableTableInterfaceImpl.parent = subQuery + + return subQuery +} diff --git a/postgres/lateral_test.go b/postgres/lateral_test.go new file mode 100644 index 0000000..35a0429 --- /dev/null +++ b/postgres/lateral_test.go @@ -0,0 +1,14 @@ +package postgres + +import "testing" + +func TestLATERAL(t *testing.T) { + assertSerialize(t, + LATERAL( + SELECT(Int(1)), + ).AS("lat1"), + + `LATERAL ( + SELECT $1 +) AS lat1`) +} diff --git a/postgres/select_statement.go b/postgres/select_statement.go index b31909f..a0d3e27 100644 --- a/postgres/select_statement.go +++ b/postgres/select_statement.go @@ -44,7 +44,7 @@ type SelectStatement interface { Expression DISTINCT() SelectStatement - FROM(table ReadableTable) SelectStatement + FROM(tables ...ReadableTable) SelectStatement WHERE(expression BoolExpression) SelectStatement GROUP_BY(groupByClauses ...jet.GroupByClause) SelectStatement HAVING(boolExpression BoolExpression) SelectStatement @@ -76,7 +76,9 @@ func newSelectStatement(table ReadableTable, projections []Projection) SelectSta &newSelect.Limit, &newSelect.Offset, &newSelect.For) newSelect.Select.ProjectionList = projections - newSelect.From.Table = table + if table != nil { + newSelect.From.Tables = []jet.Serializer{table} + } newSelect.Limit.Count = -1 newSelect.Offset.Count = -1 @@ -106,8 +108,10 @@ func (s *selectStatementImpl) DISTINCT() SelectStatement { return s } -func (s *selectStatementImpl) FROM(table ReadableTable) SelectStatement { - s.From.Table = table +func (s *selectStatementImpl) FROM(tables ...ReadableTable) SelectStatement { + for _, table := range tables { + s.From.Tables = append(s.From.Tables, table) + } return s } diff --git a/postgres/table_test.go b/postgres/table_test.go index 3b6f498..fa12228 100644 --- a/postgres/table_test.go +++ b/postgres/table_test.go @@ -99,3 +99,26 @@ CROSS JOIN db.table2`) CROSS JOIN db.table2 CROSS JOIN db.table3`) } + +func TestImplicitCROSS_JOIN(t *testing.T) { + assertDebugStatementSql(t, + SELECT(table1Col1, table2Col3). + FROM(table1, table2), + ` +SELECT table1.col1 AS "table1.col1", + table2.col3 AS "table2.col3" +FROM db.table1, + db.table2; +`) + assertDebugStatementSql(t, + SELECT( + table1Col1, table2Col3, + ).FROM(table1, table2, table3), + ` +SELECT table1.col1 AS "table1.col1", + table2.col3 AS "table2.col3" +FROM db.table1, + db.table2, + db.table3; +`) +} diff --git a/tests/mysql/main_test.go b/tests/mysql/main_test.go index 4f9268c..e2be933 100644 --- a/tests/mysql/main_test.go +++ b/tests/mysql/main_test.go @@ -64,3 +64,9 @@ func requireLogged(t *testing.T, statement postgres.Statement) { require.Equal(t, loggedSQLArgs, args) require.Equal(t, loggedDebugSQL, statement.DebugSql()) } + +func skipForMariaDB(t *testing.T) { + if sourceIsMariaDB() { + t.SkipNow() + } +} diff --git a/tests/mysql/select_test.go b/tests/mysql/select_test.go index f447218..1a60a42 100644 --- a/tests/mysql/select_test.go +++ b/tests/mysql/select_test.go @@ -1,15 +1,17 @@ package mysql import ( + "strings" + "testing" + "github.com/go-jet/jet/v2/internal/testutils" . "github.com/go-jet/jet/v2/mysql" "github.com/go-jet/jet/v2/tests/.gentestdata/mysql/dvds/enum" "github.com/go-jet/jet/v2/tests/.gentestdata/mysql/dvds/model" . "github.com/go-jet/jet/v2/tests/.gentestdata/mysql/dvds/table" "github.com/go-jet/jet/v2/tests/.gentestdata/mysql/dvds/view" - "github.com/stretchr/testify/require" - "testing" + "github.com/stretchr/testify/require" ) func TestSelect_ScanToStruct(t *testing.T) { @@ -787,3 +789,101 @@ LIMIT 5; require.Equal(t, dest.Films[1].Title, "ACE GOLDFINGER") require.Equal(t, dest.Films[4].Title, "AFRICAN EGG") } + +func TestLateral(t *testing.T) { + skipForMariaDB(t) // MariaDB does not implement LATERAL + + languages := LATERAL( + SELECT( + Language.AllColumns, + ).FROM( + Language, + ).WHERE( + Language.Name.NOT_IN(String("spanish")). + AND(Film.LanguageID.EQ(Language.LanguageID)), + ), + ).AS("films") + + stmt := SELECT( + Film.FilmID, + Film.Title, + languages.AllColumns(), + ).FROM( + Film.CROSS_JOIN(languages), + ).WHERE( + Film.FilmID.EQ(Int(1)), + ).ORDER_BY( + Film.FilmID, + ).LIMIT(1) + + testutils.AssertDebugStatementSql(t, stmt, strings.Replace(` +SELECT film.film_id AS "film.film_id", + film.title AS "film.title", + films.''language.language_id'' AS "language.language_id", + films.''language.name'' AS "language.name", + films.''language.last_update'' AS "language.last_update" +FROM dvds.film + CROSS JOIN LATERAL ( + SELECT language.language_id AS "language.language_id", + language.name AS "language.name", + language.last_update AS "language.last_update" + FROM dvds.language + WHERE (language.name NOT IN ('spanish')) AND (film.language_id = language.language_id) + ) AS films +WHERE film.film_id = 1 +ORDER BY film.film_id +LIMIT 1; +`, "''", "`", -1)) + + type FilmLanguage struct { + model.Film + model.Language + } + + var dest []FilmLanguage + + err := stmt.Query(db, &dest) + require.NoError(t, err) + require.Equal(t, dest[0].Film.Title, "ACADEMY DINOSAUR") + require.Equal(t, dest[0].Language.Name, "English") + + t.Run("implicit cross join", func(t *testing.T) { + stmt2 := SELECT( + Film.FilmID, + Film.Title, + languages.AllColumns(), + ).FROM( + Film, + languages, + ).WHERE( + Film.FilmID.EQ(Int(1)), + ).ORDER_BY( + Film.FilmID, + ).LIMIT(1) + + testutils.AssertDebugStatementSql(t, stmt2, strings.Replace(` +SELECT film.film_id AS "film.film_id", + film.title AS "film.title", + films.''language.language_id'' AS "language.language_id", + films.''language.name'' AS "language.name", + films.''language.last_update'' AS "language.last_update" +FROM dvds.film, + LATERAL ( + SELECT language.language_id AS "language.language_id", + language.name AS "language.name", + language.last_update AS "language.last_update" + FROM dvds.language + WHERE (language.name NOT IN ('spanish')) AND (film.language_id = language.language_id) + ) AS films +WHERE film.film_id = 1 +ORDER BY film.film_id +LIMIT 1; +`, "''", "`", -1)) + + var dest2 []FilmLanguage + + err2 := stmt2.Query(db, &dest2) + require.NoError(t, err2) + require.Equal(t, dest, dest2) + }) +} diff --git a/tests/postgres/select_test.go b/tests/postgres/select_test.go index 8bd6f53..59fc44a 100644 --- a/tests/postgres/select_test.go +++ b/tests/postgres/select_test.go @@ -1,15 +1,17 @@ package postgres import ( + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/go-jet/jet/v2/internal/testutils" . "github.com/go-jet/jet/v2/postgres" "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/enum" "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/model" . "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/table" "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/dvds/view" - "github.com/stretchr/testify/require" - "testing" - "time" ) func TestSelect_ScanToStruct(t *testing.T) { @@ -1893,3 +1895,100 @@ WHERE ($1 AND (customer.customer_id = $2)) AND (customer.activebool = $3); require.Len(t, dest, 1) testutils.AssertDeepEqual(t, dest[0], customer0) } + +func TestLateral(t *testing.T) { + + languages := LATERAL( + SELECT( + Language.AllColumns, + ).FROM( + Language, + ).WHERE( + Language.Name.NOT_IN(String("spanish")). + AND(Film.LanguageID.EQ(Language.LanguageID)), + ), + ).AS("films") + + stmt := SELECT( + Film.FilmID, + Film.Title, + languages.AllColumns(), + ).FROM( + Film.CROSS_JOIN(languages), + ).WHERE( + Film.FilmID.EQ(Int(1)), + ).ORDER_BY( + Film.FilmID, + ).LIMIT(1) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT film.film_id AS "film.film_id", + film.title AS "film.title", + films."language.language_id" AS "language.language_id", + films."language.name" AS "language.name", + films."language.last_update" AS "language.last_update" +FROM dvds.film + CROSS JOIN LATERAL ( + SELECT language.language_id AS "language.language_id", + language.name AS "language.name", + language.last_update AS "language.last_update" + FROM dvds.language + WHERE (language.name NOT IN ('spanish')) AND (film.language_id = language.language_id) + ) AS films +WHERE film.film_id = 1 +ORDER BY film.film_id +LIMIT 1; +`) + + type FilmLanguage struct { + model.Film + model.Language + } + + var dest []FilmLanguage + + err := stmt.Query(db, &dest) + require.NoError(t, err) + require.Equal(t, dest[0].Film.Title, "Academy Dinosaur") + require.Equal(t, dest[0].Language.Name, "English ") + + t.Run("implicit cross join", func(t *testing.T) { + stmt2 := SELECT( + Film.FilmID, + Film.Title, + languages.AllColumns(), + ).FROM( + Film, + languages, + ).WHERE( + Film.FilmID.EQ(Int(1)), + ).ORDER_BY( + Film.FilmID, + ).LIMIT(1) + + testutils.AssertDebugStatementSql(t, stmt2, ` +SELECT film.film_id AS "film.film_id", + film.title AS "film.title", + films."language.language_id" AS "language.language_id", + films."language.name" AS "language.name", + films."language.last_update" AS "language.last_update" +FROM dvds.film, + LATERAL ( + SELECT language.language_id AS "language.language_id", + language.name AS "language.name", + language.last_update AS "language.last_update" + FROM dvds.language + WHERE (language.name NOT IN ('spanish')) AND (film.language_id = language.language_id) + ) AS films +WHERE film.film_id = 1 +ORDER BY film.film_id +LIMIT 1; +`) + + var dest2 []FilmLanguage + + err2 := stmt2.Query(db, &dest2) + require.NoError(t, err2) + require.Equal(t, dest, dest2) + }) +}