diff --git a/internal/jet/clause.go b/internal/jet/clause.go index 8654547..42dad7e 100644 --- a/internal/jet/clause.go +++ b/internal/jet/clause.go @@ -45,19 +45,25 @@ func (s *ClauseSelect) Serialize(statementType StatementType, out *SQLBuilder, o // ClauseFrom struct type ClauseFrom struct { - Table Serializer + Tables []Serializer } // Serialize serializes clause into SQLBuilder func (f *ClauseFrom) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { - if f.Table == nil { + if len(f.Tables) == 0 { // SELECT statement does not have to have FROM clause return } out.NewLine() out.WriteString("FROM") out.IncreaseIdent() - f.Table.serialize(statementType, out, FallTrough(options)...) + for i, table := range f.Tables { + if i > 0 { + out.WriteString(",") + out.NewLine() + } + table.serialize(statementType, out, FallTrough(options)...) + } out.DecreaseIdent() } diff --git a/internal/jet/table.go b/internal/jet/table.go index 78abb99..c69a1dc 100644 --- a/internal/jet/table.go +++ b/internal/jet/table.go @@ -136,9 +136,6 @@ func (t *joinTableImpl) TableName() string { return "" } -func (t *joinTableImpl) AS(alias string) { -} - func (t *joinTableImpl) columns() []Column { var ret []Column diff --git a/internal/testutils/test_utils.go b/internal/testutils/test_utils.go index 556ed6c..272cf5f 100644 --- a/internal/testutils/test_utils.go +++ b/internal/testutils/test_utils.go @@ -8,6 +8,7 @@ import ( "github.com/go-jet/jet/v2/internal/utils" "github.com/go-jet/jet/v2/qrm" "github.com/google/uuid" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "io/ioutil" "os" @@ -116,7 +117,12 @@ func AssertDebugStatementSql(t *testing.T, query jet.Statement, expectedQuery st } debuqSql := query.DebugSql() - require.Equal(t, debuqSql, expectedQuery) + if !assert.Equal(t, debuqSql, expectedQuery) { + fmt.Println("Expected: ") + fmt.Println(expectedQuery) + fmt.Println("Got: ") + fmt.Println(debuqSql) + } } // AssertSerialize checks if clause serialize produces expected query and args diff --git a/mysql/select_statement.go b/mysql/select_statement.go index 37ac96b..fa6dd9c 100644 --- a/mysql/select_statement.go +++ b/mysql/select_statement.go @@ -41,7 +41,7 @@ type SelectStatement interface { Expression DISTINCT() SelectStatement - FROM(table ReadableTable) SelectStatement + FROM(tables ...ReadableTable) SelectStatement WHERE(expression BoolExpression) SelectStatement GROUP_BY(groupByClauses ...jet.GroupByClause) SelectStatement HAVING(boolExpression BoolExpression) SelectStatement @@ -70,7 +70,9 @@ func newSelectStatement(table ReadableTable, projections []Projection) SelectSta &newSelect.Limit, &newSelect.Offset, &newSelect.For, &newSelect.ShareLock) newSelect.Select.ProjectionList = projections - newSelect.From.Table = table + if table != nil { + newSelect.From.Tables = []jet.Serializer{table} + } newSelect.Limit.Count = -1 newSelect.Offset.Count = -1 newSelect.ShareLock.Name = "LOCK IN SHARE MODE" @@ -103,8 +105,10 @@ func (s *selectStatementImpl) DISTINCT() SelectStatement { return s } -func (s *selectStatementImpl) FROM(table ReadableTable) SelectStatement { - s.From.Table = table +func (s *selectStatementImpl) FROM(tables ...ReadableTable) SelectStatement { + for _, table := range tables { + s.From.Tables = append(s.From.Tables, table) + } return s } diff --git a/postgres/select_statement.go b/postgres/select_statement.go index b31909f..a0d3e27 100644 --- a/postgres/select_statement.go +++ b/postgres/select_statement.go @@ -44,7 +44,7 @@ type SelectStatement interface { Expression DISTINCT() SelectStatement - FROM(table ReadableTable) SelectStatement + FROM(tables ...ReadableTable) SelectStatement WHERE(expression BoolExpression) SelectStatement GROUP_BY(groupByClauses ...jet.GroupByClause) SelectStatement HAVING(boolExpression BoolExpression) SelectStatement @@ -76,7 +76,9 @@ func newSelectStatement(table ReadableTable, projections []Projection) SelectSta &newSelect.Limit, &newSelect.Offset, &newSelect.For) newSelect.Select.ProjectionList = projections - newSelect.From.Table = table + if table != nil { + newSelect.From.Tables = []jet.Serializer{table} + } newSelect.Limit.Count = -1 newSelect.Offset.Count = -1 @@ -106,8 +108,10 @@ func (s *selectStatementImpl) DISTINCT() SelectStatement { return s } -func (s *selectStatementImpl) FROM(table ReadableTable) SelectStatement { - s.From.Table = table +func (s *selectStatementImpl) FROM(tables ...ReadableTable) SelectStatement { + for _, table := range tables { + s.From.Tables = append(s.From.Tables, table) + } return s } diff --git a/postgres/table_test.go b/postgres/table_test.go index 3b6f498..fa12228 100644 --- a/postgres/table_test.go +++ b/postgres/table_test.go @@ -99,3 +99,26 @@ CROSS JOIN db.table2`) CROSS JOIN db.table2 CROSS JOIN db.table3`) } + +func TestImplicitCROSS_JOIN(t *testing.T) { + assertDebugStatementSql(t, + SELECT(table1Col1, table2Col3). + FROM(table1, table2), + ` +SELECT table1.col1 AS "table1.col1", + table2.col3 AS "table2.col3" +FROM db.table1, + db.table2; +`) + assertDebugStatementSql(t, + SELECT( + table1Col1, table2Col3, + ).FROM(table1, table2, table3), + ` +SELECT table1.col1 AS "table1.col1", + table2.col3 AS "table2.col3" +FROM db.table1, + db.table2, + db.table3; +`) +}