diff --git a/internal/jet/column_list.go b/internal/jet/column_list.go index 4834871..a07b9ba 100644 --- a/internal/jet/column_list.go +++ b/internal/jet/column_list.go @@ -1,7 +1,5 @@ package jet -import "strings" - // ColumnList is a helper type to support list of columns as single projection type ColumnList []ColumnExpression @@ -46,12 +44,9 @@ func (cl ColumnList) Except(excludedColumns ...Column) ColumnList { // For instance: If projection list has a column 'Artist.Name', and tableAlias is 'Musician.*', returned projection list will // have a column wrapped in alias 'Musician.Name'. If tableAlias is empty string, it removes existing table alias ('Artist.Name' becomes 'Name'). func (cl ColumnList) As(tableAlias string) ProjectionList { - if tableAlias > "" { - tableAlias = strings.TrimRight(tableAlias, ".*") + "." - } ret := make(ProjectionList, 0, len(cl)) for _, c := range cl { - ret = append(ret, c.AS(tableAlias+c.Name())) + ret = append(ret, c.AS(joinAlias(tableAlias, c.Name()))) } return ret } diff --git a/internal/jet/projection.go b/internal/jet/projection.go index 09f1c2b..3b2ccd8 100644 --- a/internal/jet/projection.go +++ b/internal/jet/projection.go @@ -1,7 +1,5 @@ package jet -import "strings" - // Projection is interface for all projection types. Types that can be part of, for instance SELECT clause. type Projection interface { serializeForProjection(statement StatementType, out *SQLBuilder) @@ -35,12 +33,6 @@ func (pl ProjectionList) serializeForProjection(statement StatementType, out *SQ // For instance: If projection list has a column 'Artist.Name', and tableAlias is 'Musician.*', returned projection list will // have a column wrapped in alias 'Musician.Name'. If tableAlias is empty string, it removes existing table alias ('Artist.Name' becomes 'Name'). func (pl ProjectionList) As(tableAlias string) ProjectionList { - tableAliasWithDot := "" - if tableAlias != "" { - tableAlias = strings.TrimRight(tableAlias, ".*") - tableAliasWithDot = tableAlias + "." - } - newProjectionList := ProjectionList{} for _, projection := range pl { @@ -50,11 +42,11 @@ func (pl ProjectionList) As(tableAlias string) ProjectionList { case ColumnList: newProjectionList = append(newProjectionList, p.As(tableAlias)) case ColumnExpression: - newProjectionList = append(newProjectionList, newAlias(p, tableAliasWithDot+p.Name())) + newProjectionList = append(newProjectionList, newAlias(p, joinAlias(tableAlias, p.Name()))) case *alias: newAlias := *p _, columnName := extractTableAndColumnName(newAlias.alias) - newAlias.alias = tableAliasWithDot + columnName + newAlias.alias = joinAlias(tableAlias, columnName) newProjectionList = append(newProjectionList, &newAlias) } } diff --git a/internal/jet/utils.go b/internal/jet/utils.go index fe29a09..466f2a5 100644 --- a/internal/jet/utils.go +++ b/internal/jet/utils.go @@ -1,10 +1,11 @@ package jet import ( - "github.com/go-jet/jet/v2/internal/utils/dbidentifier" - "github.com/go-jet/jet/v2/internal/utils/must" "reflect" "strings" + + "github.com/go-jet/jet/v2/internal/utils/dbidentifier" + "github.com/go-jet/jet/v2/internal/utils/must" ) // SerializeClauseList func @@ -278,3 +279,15 @@ func serializeToDefaultDebugString(expr Serializer) string { expr.serialize(SelectStatementType, &out) return out.Buff.String() } + +// joinAlias examples: +// +// joinAlias("foo", "bar") // "foo.bar" +// joinAlias("foo.*", "bar") // "foo.bar" +// joinAlias("", "bar") // "bar" +func joinAlias(tableAlias, columnAlias string) string { + if tableAlias == "" { + return columnAlias + } + return strings.TrimRight(tableAlias, ".*") + "." + columnAlias +} diff --git a/internal/jet/utils_test.go b/internal/jet/utils_test.go index b907760..86feff1 100644 --- a/internal/jet/utils_test.go +++ b/internal/jet/utils_test.go @@ -1,8 +1,9 @@ package jet import ( - "github.com/stretchr/testify/require" "testing" + + "github.com/stretchr/testify/require" ) func TestOptionalOrDefaultString(t *testing.T) { @@ -17,3 +18,10 @@ func TestOptionalOrDefaultExpression(t *testing.T) { require.Equal(t, OptionalOrDefaultExpression(defaultExpression), defaultExpression) require.Equal(t, OptionalOrDefaultExpression(defaultExpression, optionalExpression), optionalExpression) } + +func TestJoinAlias(t *testing.T) { + require.Equal(t, joinAlias("", ""), "") + require.Equal(t, joinAlias("foo", "bar"), "foo.bar") + require.Equal(t, joinAlias("foo.*", "bar"), "foo.bar") + require.Equal(t, joinAlias("", "bar"), "bar") +}