diff --git a/internal/jet/column_list.go b/internal/jet/column_list.go index a4a0b66..a07b9ba 100644 --- a/internal/jet/column_list.go +++ b/internal/jet/column_list.go @@ -39,6 +39,18 @@ func (cl ColumnList) Except(excludedColumns ...Column) ColumnList { return ret } +// As will create new projection list where each column is wrapped with a new table alias. +// tableAlias should be in the form 'name' or 'name.*', or it can also be an empty string. +// For instance: If projection list has a column 'Artist.Name', and tableAlias is 'Musician.*', returned projection list will +// have a column wrapped in alias 'Musician.Name'. If tableAlias is empty string, it removes existing table alias ('Artist.Name' becomes 'Name'). +func (cl ColumnList) As(tableAlias string) ProjectionList { + ret := make(ProjectionList, 0, len(cl)) + for _, c := range cl { + ret = append(ret, c.AS(joinAlias(tableAlias, c.Name()))) + } + return ret +} + func (cl ColumnList) fromImpl(subQuery SelectTable) Projection { newProjectionList := ProjectionList{} diff --git a/internal/jet/projection.go b/internal/jet/projection.go index 1b1c625..3b2ccd8 100644 --- a/internal/jet/projection.go +++ b/internal/jet/projection.go @@ -1,7 +1,5 @@ package jet -import "strings" - // Projection is interface for all projection types. Types that can be part of, for instance SELECT clause. type Projection interface { serializeForProjection(statement StatementType, out *SQLBuilder) @@ -31,24 +29,24 @@ func (pl ProjectionList) serializeForProjection(statement StatementType, out *SQ } // As will create new projection list where each column is wrapped with a new table alias. -// tableAlias should be in the form 'name' or 'name.*'. +// tableAlias should be in the form 'name' or 'name.*', or it can be an empty string, which will remove existing table alias. // For instance: If projection list has a column 'Artist.Name', and tableAlias is 'Musician.*', returned projection list will -// have a column wrapped in alias 'Musician.Name'. +// have a column wrapped in alias 'Musician.Name'. If tableAlias is empty string, it removes existing table alias ('Artist.Name' becomes 'Name'). func (pl ProjectionList) As(tableAlias string) ProjectionList { - tableAlias = strings.TrimRight(tableAlias, ".*") - newProjectionList := ProjectionList{} for _, projection := range pl { switch p := projection.(type) { case ProjectionList: newProjectionList = append(newProjectionList, p.As(tableAlias)) + case ColumnList: + newProjectionList = append(newProjectionList, p.As(tableAlias)) case ColumnExpression: - newProjectionList = append(newProjectionList, newAlias(p, tableAlias+"."+p.Name())) + newProjectionList = append(newProjectionList, newAlias(p, joinAlias(tableAlias, p.Name()))) case *alias: newAlias := *p _, columnName := extractTableAndColumnName(newAlias.alias) - newAlias.alias = tableAlias + "." + columnName + newAlias.alias = joinAlias(tableAlias, columnName) newProjectionList = append(newProjectionList, &newAlias) } } diff --git a/internal/jet/projection_test.go b/internal/jet/projection_test.go index 7728e15..0370b43 100644 --- a/internal/jet/projection_test.go +++ b/internal/jet/projection_test.go @@ -12,6 +12,7 @@ func TestProjectionAs(t *testing.T) { AVG(table1ColInt).AS("avg"), AVG(table1ColInt).AS("t.avg"), }, + ColumnList{table2Col3, table2Col4}, } aliasedProjectionList := projectionList.As("new_alias.*") @@ -22,7 +23,21 @@ SUM(table1.col_int) AS "new_alias.sum", SUM(table1.col_int) AS "new_alias.sum", table1.col_bool AS "new_alias.col_bool", AVG(table1.col_int) AS "new_alias.avg", -AVG(table1.col_int) AS "new_alias.avg"`) +AVG(table1.col_int) AS "new_alias.avg", +table2.col3 AS "new_alias.col3", +table2.col4 AS "new_alias.col4"`) + + aliasedProjectionList = projectionList.As("") + + assertProjectionSerialize(t, aliasedProjectionList, + `table1.col3 AS "col3", +SUM(table1.col_int) AS "sum", +SUM(table1.col_int) AS "sum", +table1.col_bool AS "col_bool", +AVG(table1.col_int) AS "avg", +AVG(table1.col_int) AS "avg", +table2.col3 AS "col3", +table2.col4 AS "col4"`) subQueryProjections := projectionList.fromImpl(NewSelectTable(nil, "subQuery")) @@ -32,7 +47,9 @@ AVG(table1.col_int) AS "new_alias.avg"`) "subQuery"."table.sum" AS "table.sum", "subQuery"."table1.col_bool" AS "table1.col_bool", "subQuery".avg AS "avg", -"subQuery"."t.avg" AS "t.avg"`) +"subQuery"."t.avg" AS "t.avg", +"subQuery"."table2.col3" AS "table2.col3", +"subQuery"."table2.col4" AS "table2.col4"`) aliasedSubQueryProjectionList := subQueryProjections.(ProjectionList).As("subAlias") @@ -42,5 +59,7 @@ AVG(table1.col_int) AS "new_alias.avg"`) "subQuery"."table.sum" AS "subAlias.sum", "subQuery"."table1.col_bool" AS "subAlias.col_bool", "subQuery".avg AS "subAlias.avg", -"subQuery"."t.avg" AS "subAlias.avg"`) +"subQuery"."t.avg" AS "subAlias.avg", +"subQuery"."table2.col3" AS "subAlias.col3", +"subQuery"."table2.col4" AS "subAlias.col4"`) } diff --git a/internal/jet/utils.go b/internal/jet/utils.go index fe29a09..466f2a5 100644 --- a/internal/jet/utils.go +++ b/internal/jet/utils.go @@ -1,10 +1,11 @@ package jet import ( - "github.com/go-jet/jet/v2/internal/utils/dbidentifier" - "github.com/go-jet/jet/v2/internal/utils/must" "reflect" "strings" + + "github.com/go-jet/jet/v2/internal/utils/dbidentifier" + "github.com/go-jet/jet/v2/internal/utils/must" ) // SerializeClauseList func @@ -278,3 +279,15 @@ func serializeToDefaultDebugString(expr Serializer) string { expr.serialize(SelectStatementType, &out) return out.Buff.String() } + +// joinAlias examples: +// +// joinAlias("foo", "bar") // "foo.bar" +// joinAlias("foo.*", "bar") // "foo.bar" +// joinAlias("", "bar") // "bar" +func joinAlias(tableAlias, columnAlias string) string { + if tableAlias == "" { + return columnAlias + } + return strings.TrimRight(tableAlias, ".*") + "." + columnAlias +} diff --git a/internal/jet/utils_test.go b/internal/jet/utils_test.go index b907760..86feff1 100644 --- a/internal/jet/utils_test.go +++ b/internal/jet/utils_test.go @@ -1,8 +1,9 @@ package jet import ( - "github.com/stretchr/testify/require" "testing" + + "github.com/stretchr/testify/require" ) func TestOptionalOrDefaultString(t *testing.T) { @@ -17,3 +18,10 @@ func TestOptionalOrDefaultExpression(t *testing.T) { require.Equal(t, OptionalOrDefaultExpression(defaultExpression), defaultExpression) require.Equal(t, OptionalOrDefaultExpression(defaultExpression, optionalExpression), optionalExpression) } + +func TestJoinAlias(t *testing.T) { + require.Equal(t, joinAlias("", ""), "") + require.Equal(t, joinAlias("foo", "bar"), "foo.bar") + require.Equal(t, joinAlias("foo.*", "bar"), "foo.bar") + require.Equal(t, joinAlias("", "bar"), "bar") +}