From 5cbf4aac86c27657ce20eba4ba58ff45db1065ed Mon Sep 17 00:00:00 2001 From: go-jet Date: Wed, 5 Jan 2022 18:00:20 +0100 Subject: [PATCH] Add ability to change alias of all projections in the ProjectionList. Add ability to exclude list of columns from ProjectionList. --- internal/jet/alias.go | 7 +- internal/jet/expression.go | 5 +- internal/jet/projection.go | 56 +++++++-- internal/jet/projection_test.go | 46 +++++++ internal/jet/utils.go | 20 ++++ internal/testutils/test_utils.go | 6 + tests/postgres/chinook_db_test.go | 192 ++++++++++++++++++++++++++++-- tests/postgres/with_test.go | 76 ++++++++++-- 8 files changed, 377 insertions(+), 31 deletions(-) create mode 100644 internal/jet/projection_test.go diff --git a/internal/jet/alias.go b/internal/jet/alias.go index 57f55cd..8693b13 100644 --- a/internal/jet/alias.go +++ b/internal/jet/alias.go @@ -13,7 +13,12 @@ func newAlias(expression Expression, aliasName string) Projection { } func (a *alias) fromImpl(subQuery SelectTable) Projection { - column := NewColumnImpl(a.alias, "", nil) + // if alias is in the form "table.column", we break it into two parts so that ProjectionList.As(newAlias) can + // overwrite tableName with a new alias. This method is called only for exporting aliased custom columns. + // Generated columns have default aliasing. + tableName, columnName := extractTableAndColumnName(a.alias) + + column := NewColumnImpl(columnName, tableName, nil) column.subQuery = subQuery return &column diff --git a/internal/jet/expression.go b/internal/jet/expression.go index ad5c205..d748dcf 100644 --- a/internal/jet/expression.go +++ b/internal/jet/expression.go @@ -1,5 +1,7 @@ package jet +import "fmt" + // Expression is common interface for all expressions. // Can be Bool, Int, Float, String, Date, Time, Timez, Timestamp or Timestampz expressions. type Expression interface { @@ -33,7 +35,8 @@ type ExpressionInterfaceImpl struct { } func (e *ExpressionInterfaceImpl) fromImpl(subQuery SelectTable) Projection { - return e.Parent + panic(fmt.Sprintf("jet: can't export unaliased expression subQuery: %s, expression: %s", + subQuery.Alias(), serializeToDefaultDebugString(e.Parent))) } // IS_NULL tests expression whether it is a NULL value. diff --git a/internal/jet/projection.go b/internal/jet/projection.go index 16abe5b..1b1c625 100644 --- a/internal/jet/projection.go +++ b/internal/jet/projection.go @@ -16,36 +16,68 @@ func SerializeForProjection(projection Projection, statementType StatementType, // ProjectionList is a redefined type, so that ProjectionList can be used as a Projection. type ProjectionList []Projection -func (cl ProjectionList) fromImpl(subQuery SelectTable) Projection { +func (pl ProjectionList) fromImpl(subQuery SelectTable) Projection { newProjectionList := ProjectionList{} - for _, projection := range cl { + for _, projection := range pl { newProjectionList = append(newProjectionList, projection.fromImpl(subQuery)) } return newProjectionList } -func (cl ProjectionList) serializeForProjection(statement StatementType, out *SQLBuilder) { - SerializeProjectionList(statement, cl, out) +func (pl ProjectionList) serializeForProjection(statement StatementType, out *SQLBuilder) { + SerializeProjectionList(statement, pl, out) } -// As is used to set aliases of the projection list. alias should be in the form 'name' or 'name.*'. -// For instance: If projection list has a column 'Artist.Name', and alias is 'Musician.*', returned projection list will -// have column wrapped in alias 'Musician.Name'. -func (cl ProjectionList) As(alias string) ProjectionList { - alias = strings.TrimRight(alias, ".*") +// As will create new projection list where each column is wrapped with a new table alias. +// tableAlias should be in the form 'name' or 'name.*'. +// For instance: If projection list has a column 'Artist.Name', and tableAlias is 'Musician.*', returned projection list will +// have a column wrapped in alias 'Musician.Name'. +func (pl ProjectionList) As(tableAlias string) ProjectionList { + tableAlias = strings.TrimRight(tableAlias, ".*") newProjectionList := ProjectionList{} - for _, projection := range cl { + for _, projection := range pl { switch p := projection.(type) { case ProjectionList: - newProjectionList = append(newProjectionList, p.As(alias)) + newProjectionList = append(newProjectionList, p.As(tableAlias)) case ColumnExpression: - newProjectionList = append(newProjectionList, newAlias(p, alias+"."+p.Name())) + newProjectionList = append(newProjectionList, newAlias(p, tableAlias+"."+p.Name())) + case *alias: + newAlias := *p + _, columnName := extractTableAndColumnName(newAlias.alias) + newAlias.alias = tableAlias + "." + columnName + newProjectionList = append(newProjectionList, &newAlias) } } return newProjectionList } + +// Except will create new projection list in which columns contained in excluded column names are removed +func (pl ProjectionList) Except(toExclude ...Column) ProjectionList { + excludedColumnList := UnwidColumnList(toExclude) + excludedColumnNames := map[string]bool{} + + for _, excludedColumn := range excludedColumnList { + excludedColumnNames[excludedColumn.Name()] = true + } + + var ret ProjectionList + + for _, projection := range pl { + switch p := projection.(type) { + case ProjectionList: + ret = append(ret, p.Except(toExclude...)) + case ColumnExpression: + if excludedColumnNames[p.Name()] { + continue + } + ret = append(ret, p) + } + } + + return ret +} diff --git a/internal/jet/projection_test.go b/internal/jet/projection_test.go new file mode 100644 index 0000000..7728e15 --- /dev/null +++ b/internal/jet/projection_test.go @@ -0,0 +1,46 @@ +package jet + +import "testing" + +func TestProjectionAs(t *testing.T) { + projectionList := ProjectionList{ + table1Col3, + SUM(table1ColInt).AS("sum"), + SUM(table1ColInt).AS("table.sum"), + ProjectionList{ + table1ColBool, + AVG(table1ColInt).AS("avg"), + AVG(table1ColInt).AS("t.avg"), + }, + } + + aliasedProjectionList := projectionList.As("new_alias.*") + + assertProjectionSerialize(t, aliasedProjectionList, + `table1.col3 AS "new_alias.col3", +SUM(table1.col_int) AS "new_alias.sum", +SUM(table1.col_int) AS "new_alias.sum", +table1.col_bool AS "new_alias.col_bool", +AVG(table1.col_int) AS "new_alias.avg", +AVG(table1.col_int) AS "new_alias.avg"`) + + subQueryProjections := projectionList.fromImpl(NewSelectTable(nil, "subQuery")) + + assertProjectionSerialize(t, subQueryProjections, + `"subQuery"."table1.col3" AS "table1.col3", +"subQuery".sum AS "sum", +"subQuery"."table.sum" AS "table.sum", +"subQuery"."table1.col_bool" AS "table1.col_bool", +"subQuery".avg AS "avg", +"subQuery"."t.avg" AS "t.avg"`) + + aliasedSubQueryProjectionList := subQueryProjections.(ProjectionList).As("subAlias") + + assertProjectionSerialize(t, aliasedSubQueryProjectionList, + `"subQuery"."table1.col3" AS "subAlias.col3", +"subQuery".sum AS "subAlias.sum", +"subQuery"."table.sum" AS "subAlias.sum", +"subQuery"."table1.col_bool" AS "subAlias.col_bool", +"subQuery".avg AS "subAlias.avg", +"subQuery"."t.avg" AS "subAlias.avg"`) +} diff --git a/internal/jet/utils.go b/internal/jet/utils.go index 113a396..d887f63 100644 --- a/internal/jet/utils.go +++ b/internal/jet/utils.go @@ -3,6 +3,7 @@ package jet import ( "github.com/go-jet/jet/v2/internal/utils" "reflect" + "strings" ) // SerializeClauseList func @@ -244,3 +245,22 @@ func OptionalOrDefaultExpression(defaultExpression Expression, expression ...Exp return defaultExpression } + +func extractTableAndColumnName(alias string) (tableName string, columnName string) { + parts := strings.Split(alias, ".") + + if len(parts) >= 2 { + tableName = parts[0] + columnName = parts[1] + } else { + columnName = parts[0] + } + + return +} + +func serializeToDefaultDebugString(expr Serializer) string { + out := SQLBuilder{Dialect: defaultDialect, Debug: true} + expr.serialize(SelectStatementType, &out) + return out.Buff.String() +} diff --git a/internal/testutils/test_utils.go b/internal/testutils/test_utils.go index c1419aa..cac4a62 100644 --- a/internal/testutils/test_utils.go +++ b/internal/testutils/test_utils.go @@ -56,6 +56,12 @@ func PrintJson(v interface{}) { fmt.Println(string(jsonText)) } +// ToJSON converts v into json string +func ToJSON(v interface{}) string { + jsonText, _ := json.MarshalIndent(v, "", "\t") + return string(jsonText) +} + // AssertJSON check if data json output is the same as expectedJSON func AssertJSON(t *testing.T, data interface{}, expectedJSON string) { jsonData, err := json.MarshalIndent(data, "", "\t") diff --git a/tests/postgres/chinook_db_test.go b/tests/postgres/chinook_db_test.go index 0e00019..f5a1701 100644 --- a/tests/postgres/chinook_db_test.go +++ b/tests/postgres/chinook_db_test.go @@ -219,18 +219,33 @@ func TestSubQueryColumnAliasBubbling(t *testing.T) { ).AsTable("subQuery2") mainQuery := SELECT( - subQuery2.AllColumns(), + subQuery2.AllColumns(), // columns will have the same alias as in the sub-query + subQuery2.AllColumns().As("artist2.*"), // all column aliases will be changed to artist2.* + subQuery2.AllColumns().Except(Artist.Name).As("artist3.*"), + subQuery2.AllColumns().Except( + Artist.MutableColumns, + StringColumn("custom_column_1").From(subQuery2), // custom_column_1 appears with the same alias in subQuery2 + StringColumn("custom_column_2").From(subQuery2), + ).As("artist4.*"), ).FROM( subQuery2, ) - //fmt.Println(mainQuery.Sql()) + // fmt.Println(mainQuery.Sql()) testutils.AssertStatementSql(t, mainQuery, ` SELECT "subQuery2"."Artist.ArtistId" AS "Artist.ArtistId", "subQuery2"."Artist.Name" AS "Artist.Name", "subQuery2".custom_column_1 AS "custom_column_1", - "subQuery2".custom_column_2 AS "custom_column_2" + "subQuery2".custom_column_2 AS "custom_column_2", + "subQuery2"."Artist.ArtistId" AS "artist2.ArtistId", + "subQuery2"."Artist.Name" AS "artist2.Name", + "subQuery2".custom_column_1 AS "artist2.custom_column_1", + "subQuery2".custom_column_2 AS "artist2.custom_column_2", + "subQuery2"."Artist.ArtistId" AS "artist3.ArtistId", + "subQuery2".custom_column_1 AS "artist3.custom_column_1", + "subQuery2".custom_column_2 AS "artist3.custom_column_2", + "subQuery2"."Artist.ArtistId" AS "artist4.ArtistId" FROM ( SELECT "subQuery1"."Artist.ArtistId" AS "Artist.ArtistId", "subQuery1"."Artist.Name" AS "Artist.Name", @@ -246,21 +261,180 @@ FROM ( ) AS "subQuery2"; `) var dest []struct { - model.Artist - CustomColumn1 string - CustomColumn2 string + // subQuery2.AllColumns() + Artist1 struct { + model.Artist + + CustomColumn1 string + CustomColumn2 string + } + + // subQuery2.AllColumns().As("artist2.*") + Artist2 struct { + model.Artist `alias:"artist2.*"` + + CustomColumn1 string + CustomColumn2 string + } `alias:"artist2.*"` + + // subQuery2.AllColumns().Except(Artist.Name).As("artist3.*") + Artist3 struct { + model.Artist `alias:"artist3.*"` + + CustomColumn1 string + CustomColumn2 string + } `alias:"artist3.*"` + + // subQuery2.AllColumns().Except(...).As("artist4.*") + Artist4 struct { + model.Artist `alias:"artist4.*"` + + CustomColumn1 string + CustomColumn2 string + } `alias:"artist4.*"` } err := mainQuery.Query(db, &dest) require.NoError(t, err) + // Artist1 require.Len(t, dest, 275) - require.Equal(t, dest[0].Artist, model.Artist{ + require.Equal(t, dest[0].Artist1.Artist, model.Artist{ ArtistId: 1, Name: testutils.StringPtr("AC/DC"), }) - require.Equal(t, dest[0].CustomColumn1, "custom_column_1") - require.Equal(t, dest[0].CustomColumn2, "custom_column_2") + require.Equal(t, dest[0].Artist1.CustomColumn1, "custom_column_1") + require.Equal(t, dest[0].Artist1.CustomColumn2, "custom_column_2") + + // Artist2 + require.Equal(t, testutils.ToJSON(dest[0].Artist1), testutils.ToJSON(dest[0].Artist2)) + + // Artist3 + require.Equal(t, dest[0].Artist3.ArtistId, int32(1)) + require.Nil(t, dest[0].Artist3.Name) + require.Equal(t, dest[0].Artist3.CustomColumn1, "custom_column_1") + require.Equal(t, dest[0].Artist3.CustomColumn2, "custom_column_2") + + // Artist4 + require.Equal(t, dest[0].Artist3.Artist, dest[0].Artist4.Artist) + require.Equal(t, dest[0].Artist4.CustomColumn1, "") + require.Equal(t, dest[0].Artist4.CustomColumn2, "") +} + +func TestUnAliasedNamesPanicError(t *testing.T) { + subQuery1 := SELECT( + Artist.AllColumns, + Artist.Name.CONCAT(String("-musician")), //alias missing + ).FROM( + Artist, + ).ORDER_BY( + Artist.ArtistId.ASC(), + ).AsTable("subQuery1") + + require.Panics(t, func() { + SELECT( + subQuery1.AllColumns(), // panic, column not aliased + ).FROM( + subQuery1, + ) + }, "jet: can't export unaliased expression subQuery: subQuery1, expression: (\"Artist\".\"Name\" || '-musician')") +} + +func TestProjectionListReAliasing(t *testing.T) { + projectionList := ProjectionList{ + Track.GenreId, + SUM(Track.Milliseconds).AS("duration"), + MAX(Track.Milliseconds).AS("duration.max"), + } + + stmt := SELECT( + projectionList.As("genre_info"), + ).FROM( + Track, + ).WHERE( + Track.GenreId.LT(Int(5)), + ).GROUP_BY( + Track.GenreId, + ).ORDER_BY( + Track.GenreId, + ) + + testutils.AssertDebugStatementSql(t, stmt, ` +SELECT "Track"."GenreId" AS "genre_info.GenreId", + SUM("Track"."Milliseconds") AS "genre_info.duration", + MAX("Track"."Milliseconds") AS "genre_info.max" +FROM chinook."Track" +WHERE "Track"."GenreId" < 5 +GROUP BY "Track"."GenreId" +ORDER BY "Track"."GenreId"; +`) + + type GenreInfo struct { + GenreID string + Duration int64 + Max int64 + } + + var dest []GenreInfo + + err := stmt.Query(db, &dest) + require.NoError(t, err) + + expectedSQL := ` +[ + { + "GenreID": "1", + "Duration": 368231326, + "Max": 1612329 + }, + { + "GenreID": "2", + "Duration": 37928199, + "Max": 907520 + }, + { + "GenreID": "3", + "Duration": 115846292, + "Max": 816509 + }, + { + "GenreID": "4", + "Duration": 77805478, + "Max": 558602 + } +] +` + testutils.AssertJSON(t, dest, expectedSQL) + + subQuery := stmt.AsTable("subQuery") + + mainStmt := SELECT( + subQuery.AllColumns().As("genre_information.*"), + ).FROM( + subQuery, + ) + + testutils.AssertDebugStatementSql(t, mainStmt, ` +SELECT "subQuery"."genre_info.GenreId" AS "genre_information.GenreId", + "subQuery"."genre_info.duration" AS "genre_information.duration", + "subQuery"."genre_info.max" AS "genre_information.max" +FROM ( + SELECT "Track"."GenreId" AS "genre_info.GenreId", + SUM("Track"."Milliseconds") AS "genre_info.duration", + MAX("Track"."Milliseconds") AS "genre_info.max" + FROM chinook."Track" + WHERE "Track"."GenreId" < 5 + GROUP BY "Track"."GenreId" + ORDER BY "Track"."GenreId" + ) AS "subQuery"; +`) + + type GenreInformation GenreInfo + var newDest []GenreInformation + + err = mainStmt.Query(db, &newDest) + require.NoError(t, err) + testutils.AssertJSON(t, dest, expectedSQL) } func TestSelfJoin(t *testing.T) { diff --git a/tests/postgres/with_test.go b/tests/postgres/with_test.go index 3b682d8..0b47e9a 100644 --- a/tests/postgres/with_test.go +++ b/tests/postgres/with_test.go @@ -364,7 +364,15 @@ func TestCTEColumnAliasBubbling(t *testing.T) { ), )( SELECT( - cte2.AllColumns(), + cte2.AllColumns(), // columns will have the same alias as in CTEs + cte2.AllColumns().As("territories2.*"), // all column aliases will be changed to territories2.* + cte2.AllColumns().Except(Territories.RegionID, Territories.TerritoryDescription).As("territories3.*"), + cte2.AllColumns(). + Except( + Territories.MutableColumns, + StringColumn("custom_column_1").From(cte2), // custom_column_1 appears with the same alias in cte2 + StringColumn("custom_column_2").From(cte2), + ).As("territories4.*"), ).FROM( cte2, ), @@ -392,26 +400,78 @@ SELECT cte2."territories.territory_id" AS "territories.territory_id", cte2."territories.territory_description" AS "territories.territory_description", cte2."territories.region_id" AS "territories.region_id", cte2.custom_column_1 AS "custom_column_1", - cte2.custom_column_2 AS "custom_column_2" + cte2.custom_column_2 AS "custom_column_2", + cte2."territories.territory_id" AS "territories2.territory_id", + cte2."territories.territory_description" AS "territories2.territory_description", + cte2."territories.region_id" AS "territories2.region_id", + cte2.custom_column_1 AS "territories2.custom_column_1", + cte2.custom_column_2 AS "territories2.custom_column_2", + cte2."territories.territory_id" AS "territories3.territory_id", + cte2.custom_column_1 AS "territories3.custom_column_1", + cte2.custom_column_2 AS "territories3.custom_column_2", + cte2."territories.territory_id" AS "territories4.territory_id" FROM cte2; `) var dest []struct { - model.Territories - CustomColumn1 string - CustomColumn2 string + // cte2.AllColumns() + Territories1 struct { + model.Territories + + CustomColumn1 string + CustomColumn2 string + } + + // cte2.AllColumns().As("territories2.*") + Territories2 struct { + model.Territories `alias:"territories2.*"` + + CustomColumn1 string + CustomColumn2 string + } `alias:"territories2.*"` + + // cte2.AllColumns().Except(Territories.RegionID, Territories.TerritoryDescription).As("territories3.*") + Territories3 struct { + model.Territories `alias:"territories3.*"` + + CustomColumn1 string + CustomColumn2 string + } `alias:"territories3.*"` + + // cte2.AllColumns() ... .As("territories4.*") + Territories4 struct { + model.Territories `alias:"territories3.*"` + + CustomColumn1 string + CustomColumn2 string + } `alias:"territories4.*"` } err := stmt.Query(db, &dest) require.NoError(t, err) require.Len(t, dest, 53) - require.Equal(t, dest[0].Territories, model.Territories{ + require.Equal(t, dest[0].Territories1.Territories, model.Territories{ TerritoryID: "01581", TerritoryDescription: "Westboro", RegionID: 1, }) - require.Equal(t, dest[0].CustomColumn1, "custom_column_1") - require.Equal(t, dest[0].CustomColumn2, "custom_column_2") + require.Equal(t, dest[0].Territories1.CustomColumn1, "custom_column_1") + require.Equal(t, dest[0].Territories1.CustomColumn2, "custom_column_2") + + // Territories2 + require.Equal(t, testutils.ToJSON(dest[0].Territories1), testutils.ToJSON(dest[0].Territories2)) + + // Territories3 + require.Equal(t, dest[0].Territories3.TerritoryID, dest[0].Territories1.TerritoryID) + require.Equal(t, dest[0].Territories3.RegionID, int16(0)) + require.Equal(t, dest[0].Territories3.TerritoryDescription, "") + require.Equal(t, dest[0].Territories1.CustomColumn1, dest[0].Territories3.CustomColumn1) + require.Equal(t, dest[0].Territories1.CustomColumn2, dest[0].Territories3.CustomColumn2) + + // Territories4 + require.Equal(t, dest[0].Territories3.Territories, dest[0].Territories4.Territories) + require.Equal(t, dest[0].Territories4.CustomColumn1, "") + require.Equal(t, dest[0].Territories4.CustomColumn2, "") } func TestRecursiveWithStatement(t *testing.T) {