diff --git a/sqlbuilder/alias.go b/sqlbuilder/alias.go index 4dd9b8b..2fdba92 100644 --- a/sqlbuilder/alias.go +++ b/sqlbuilder/alias.go @@ -1,23 +1,36 @@ package sqlbuilder -type Alias struct { +type alias struct { expression Expression alias string + + subQuery ExpressionTable } -func NewAlias(expression Expression, alias string) *Alias { - return &Alias{ +func newAlias(expression Expression, aliasName string) projection { + return &alias{ expression: expression, - alias: alias, + alias: aliasName, } } -func (a *Alias) serializeForProjection(statement statementType, out *queryData) error { +func (a *alias) from(subQuery ExpressionTable) projection { + newAlias := *a + newAlias.subQuery = subQuery + return &newAlias +} - err := a.expression.serialize(statement, out) +func (a *alias) serializeForProjection(statement statementType, out *queryData) error { + if a.subQuery != nil { + out.writeIdentifier(a.subQuery.Alias()) + out.writeByte('.') + out.writeQuotedString(a.alias) + } else { + err := a.expression.serialize(statement, out) - if err != nil { - return err + if err != nil { + return err + } } out.writeString(`AS "` + a.alias + `"`) diff --git a/sqlbuilder/clause.go b/sqlbuilder/clause.go index 3f51030..d7ac801 100644 --- a/sqlbuilder/clause.go +++ b/sqlbuilder/clause.go @@ -150,6 +150,10 @@ func isPostSeparator(b byte) bool { return b == ' ' || b == '.' || b == ',' || b == ')' || b == '\n' || b == ':' } +func (q *queryData) writeQuotedString(str string) { + q.writeString(`"` + str + `"`) +} + func (q *queryData) writeString(str string) { q.write([]byte(str)) } diff --git a/sqlbuilder/column.go b/sqlbuilder/column.go index 33c2900..7fd37a3 100644 --- a/sqlbuilder/column.go +++ b/sqlbuilder/column.go @@ -7,6 +7,8 @@ type column interface { TableName() string setTableName(table string) + setSubQuery(subQuery ExpressionTable) + defaultAlias() string } type Column interface { @@ -20,6 +22,8 @@ type columnImpl struct { name string tableName string + + subQuery ExpressionTable } func newColumn(name string, tableName string, parent Column) columnImpl { @@ -45,6 +49,10 @@ func (c *columnImpl) setTableName(table string) { c.tableName = table } +func (c *columnImpl) setSubQuery(subQuery ExpressionTable) { + c.subQuery = subQuery +} + func (c *columnImpl) defaultAlias() string { if c.tableName != "" { return c.tableName + "." + c.name @@ -78,12 +86,18 @@ func (c columnImpl) serializeForProjection(statement statementType, out *queryDa func (c columnImpl) serialize(statement statementType, out *queryData, options ...serializeOption) error { - if c.tableName != "" { - out.writeIdentifier(c.tableName) + if c.subQuery != nil { + out.writeIdentifier(c.subQuery.Alias()) out.writeByte('.') - } + out.writeString(`"` + c.defaultAlias() + `"`) + } else { + if c.tableName != "" { + out.writeIdentifier(c.tableName) + out.writeByte('.') + } - out.writeIdentifier(c.name) + out.writeIdentifier(c.name) + } return nil } @@ -95,6 +109,16 @@ type ColumnList []Column // projection interface implementation func (cl ColumnList) isProjectionType() {} +func (cl ColumnList) from(subQuery ExpressionTable) projection { + newProjectionList := ProjectionList{} + + for _, column := range cl { + newProjectionList = append(newProjectionList, column.from(subQuery)) + } + + return newProjectionList +} + func (cl ColumnList) serializeForProjection(statement statementType, out *queryData) error { projections := columnListToProjectionList(cl) @@ -108,6 +132,8 @@ func (cl ColumnList) serializeForProjection(statement statementType, out *queryD } // column interface implementation -func (cl ColumnList) Name() string { return "" } -func (cl ColumnList) TableName() string { return "" } -func (cl ColumnList) setTableName(name string) {} +func (cl ColumnList) Name() string { return "" } +func (cl ColumnList) TableName() string { return "" } +func (cl ColumnList) setTableName(name string) {} +func (cl ColumnList) setSubQuery(subQuery ExpressionTable) {} +func (cl ColumnList) defaultAlias() string { return "" } diff --git a/sqlbuilder/column_test.go b/sqlbuilder/column_test.go index f5d9ef2..be1c0a1 100644 --- a/sqlbuilder/column_test.go +++ b/sqlbuilder/column_test.go @@ -9,6 +9,6 @@ func TestColumn(t *testing.T) { assertClauseSerialize(t, column, "col") column.setTableName("table1") assertClauseSerialize(t, column, "table1.col") - assertProjectionSerialize(t, column, `table1.col AS "table1.col"`) + assertProjectionSerialize(t, &column, `table1.col AS "table1.col"`) assertProjectionSerialize(t, column.AS("alias1"), `table1.col AS "alias1"`) } diff --git a/sqlbuilder/column_types.go b/sqlbuilder/column_types.go index f3cb25f..ab87180 100644 --- a/sqlbuilder/column_types.go +++ b/sqlbuilder/column_types.go @@ -5,7 +5,7 @@ type ColumnBool interface { BoolExpression column - From(table ExpressionTable) ColumnBool + From(subQuery ExpressionTable) ColumnBool } type boolColumnImpl struct { @@ -14,17 +14,23 @@ type boolColumnImpl struct { columnImpl } -func (i *boolColumnImpl) From(table ExpressionTable) ColumnBool { - newBoolColumn := BoolColumn(i.defaultAlias()) - newBoolColumn.setTableName(table.Alias()) +func (i *boolColumnImpl) from(subQuery ExpressionTable) projection { + newBoolColumn := BoolColumn(i.name) + newBoolColumn.setTableName(i.tableName) + newBoolColumn.setSubQuery(subQuery) + + return newBoolColumn +} + +func (i *boolColumnImpl) From(subQuery ExpressionTable) ColumnBool { + newBoolColumn := i.from(subQuery).(ColumnBool) + return newBoolColumn } func BoolColumn(name string) ColumnBool { - boolColumn := &boolColumnImpl{} boolColumn.columnImpl = newColumn(name, "", boolColumn) - boolColumn.boolInterfaceImpl.parent = boolColumn return boolColumn @@ -35,7 +41,7 @@ type ColumnFloat interface { FloatExpression column - From(table ExpressionTable) ColumnFloat + From(subQuery ExpressionTable) ColumnFloat } type floatColumnImpl struct { @@ -43,16 +49,22 @@ type floatColumnImpl struct { columnImpl } -func (i *floatColumnImpl) From(table ExpressionTable) ColumnFloat { - newFloatColumn := FloatColumn(i.defaultAlias()) - newFloatColumn.setTableName(table.Alias()) +func (i *floatColumnImpl) from(subQuery ExpressionTable) projection { + newFloatColumn := FloatColumn(i.name) + newFloatColumn.setTableName(i.tableName) + newFloatColumn.setSubQuery(subQuery) + + return newFloatColumn +} + +func (i *floatColumnImpl) From(subQuery ExpressionTable) ColumnFloat { + newFloatColumn := i.from(subQuery).(ColumnFloat) + return newFloatColumn } func FloatColumn(name string) ColumnFloat { - floatColumn := &floatColumnImpl{} - floatColumn.floatInterfaceImpl.parent = floatColumn floatColumn.columnImpl = newColumn(name, "", floatColumn) @@ -64,7 +76,7 @@ type ColumnInteger interface { IntegerExpression column - From(table ExpressionTable) ColumnInteger + From(subQuery ExpressionTable) ColumnInteger } type integerColumnImpl struct { @@ -73,15 +85,20 @@ type integerColumnImpl struct { columnImpl } -func (i *integerColumnImpl) From(table ExpressionTable) ColumnInteger { - newIntColumn := IntegerColumn(i.defaultAlias()) - newIntColumn.setTableName(table.Alias()) +func (i *integerColumnImpl) from(subQuery ExpressionTable) projection { + newIntColumn := IntegerColumn(i.name) + newIntColumn.setTableName(i.tableName) + newIntColumn.setSubQuery(subQuery) + return newIntColumn } +func (i *integerColumnImpl) From(subQuery ExpressionTable) ColumnInteger { + return i.from(subQuery).(ColumnInteger) +} + func IntegerColumn(name string) ColumnInteger { integerColumn := &integerColumnImpl{} - integerColumn.integerInterfaceImpl.parent = integerColumn integerColumn.columnImpl = newColumn(name, "", integerColumn) @@ -93,7 +110,7 @@ type ColumnString interface { StringExpression column - From(table ExpressionTable) ColumnString + From(subQuery ExpressionTable) ColumnString } type stringColumnImpl struct { @@ -102,18 +119,21 @@ type stringColumnImpl struct { columnImpl } -func (i *stringColumnImpl) From(table ExpressionTable) ColumnString { - newStrColumn := StringColumn(i.defaultAlias()) - newStrColumn.setTableName(table.Alias()) +func (i *stringColumnImpl) from(subQuery ExpressionTable) projection { + newStrColumn := StringColumn(i.name) + newStrColumn.setTableName(i.tableName) + newStrColumn.setSubQuery(subQuery) + return newStrColumn } +func (i *stringColumnImpl) From(subQuery ExpressionTable) ColumnString { + return i.from(subQuery).(ColumnString) +} + func StringColumn(name string) ColumnString { - stringColumn := &stringColumnImpl{} - stringColumn.stringInterfaceImpl.parent = stringColumn - stringColumn.columnImpl = newColumn(name, "", stringColumn) return stringColumn @@ -124,28 +144,30 @@ type ColumnTime interface { TimeExpression column - From(table ExpressionTable) ColumnTime + From(subQuery ExpressionTable) ColumnTime } type timeColumnImpl struct { timeInterfaceImpl - columnImpl } -func (i *timeColumnImpl) From(table ExpressionTable) ColumnTime { - newTimeColumn := TimeColumn(i.defaultAlias()) - newTimeColumn.setTableName(table.Alias()) +func (i *timeColumnImpl) from(subQuery ExpressionTable) projection { + newTimeColumn := TimeColumn(i.name) + newTimeColumn.setTableName(i.tableName) + newTimeColumn.setSubQuery(subQuery) + return newTimeColumn } +func (i *timeColumnImpl) From(subQuery ExpressionTable) ColumnTime { + return i.from(subQuery).(ColumnTime) +} + func TimeColumn(name string) ColumnTime { timeColumn := &timeColumnImpl{} - timeColumn.timeInterfaceImpl.parent = timeColumn - timeColumn.columnImpl = newColumn(name, "", timeColumn) - return timeColumn } @@ -155,7 +177,7 @@ type ColumnTimez interface { TimezExpression column - From(table ExpressionTable) ColumnTimez + From(subQuery ExpressionTable) ColumnTimez } type timezColumnImpl struct { @@ -164,17 +186,21 @@ type timezColumnImpl struct { columnImpl } -func (i *timezColumnImpl) From(table ExpressionTable) ColumnTimez { - newTimezColumn := TimezColumn(i.defaultAlias()) - newTimezColumn.setTableName(table.Alias()) +func (i *timezColumnImpl) from(subQuery ExpressionTable) projection { + newTimezColumn := TimezColumn(i.name) + newTimezColumn.setTableName(i.tableName) + newTimezColumn.setSubQuery(subQuery) + return newTimezColumn } +func (i *timezColumnImpl) From(subQuery ExpressionTable) ColumnTimez { + return i.from(subQuery).(ColumnTimez) +} + func TimezColumn(name string) ColumnTimez { timezColumn := &timezColumnImpl{} - timezColumn.timezInterfaceImpl.parent = timezColumn - timezColumn.columnImpl = newColumn(name, "", timezColumn) return timezColumn @@ -185,7 +211,7 @@ type ColumnTimestamp interface { TimestampExpression column - From(table ExpressionTable) ColumnTimestamp + From(subQuery ExpressionTable) ColumnTimestamp } type timestampColumnImpl struct { @@ -194,17 +220,21 @@ type timestampColumnImpl struct { columnImpl } -func (i *timestampColumnImpl) From(table ExpressionTable) ColumnTimestamp { - newTimestampColumn := TimestampColumn(i.defaultAlias()) - newTimestampColumn.setTableName(table.Alias()) +func (i *timestampColumnImpl) from(subQuery ExpressionTable) projection { + newTimestampColumn := TimestampColumn(i.name) + newTimestampColumn.setTableName(i.tableName) + newTimestampColumn.setSubQuery(subQuery) + return newTimestampColumn } +func (i *timestampColumnImpl) From(subQuery ExpressionTable) ColumnTimestamp { + return i.from(subQuery).(ColumnTimestamp) +} + func TimestampColumn(name string) ColumnTimestamp { timestampColumn := ×tampColumnImpl{} - timestampColumn.timestampInterfaceImpl.parent = timestampColumn - timestampColumn.columnImpl = newColumn(name, "", timestampColumn) return timestampColumn @@ -215,7 +245,7 @@ type ColumnTimestampz interface { TimestampzExpression column - From(table ExpressionTable) ColumnTimestampz + From(subQuery ExpressionTable) ColumnTimestampz } type timestampzColumnImpl struct { @@ -224,17 +254,21 @@ type timestampzColumnImpl struct { columnImpl } -func (i *timestampzColumnImpl) From(table ExpressionTable) ColumnTimestampz { - newTimestampzColumn := TimestampzColumn(i.defaultAlias()) - newTimestampzColumn.setTableName(table.Alias()) +func (i *timestampzColumnImpl) from(subQuery ExpressionTable) projection { + newTimestampzColumn := TimestampzColumn(i.name) + newTimestampzColumn.setTableName(i.tableName) + newTimestampzColumn.setSubQuery(subQuery) + return newTimestampzColumn } +func (i *timestampzColumnImpl) From(subQuery ExpressionTable) ColumnTimestampz { + return i.from(subQuery).(ColumnTimestampz) +} + func TimestampzColumn(name string) ColumnTimestampz { timestampzColumn := ×tampzColumnImpl{} - timestampzColumn.timestampzInterfaceImpl.parent = timestampzColumn - timestampzColumn.columnImpl = newColumn(name, "", timestampzColumn) return timestampzColumn @@ -245,7 +279,7 @@ type ColumnDate interface { DateExpression column - From(table ExpressionTable) ColumnDate + From(subQuery ExpressionTable) ColumnDate } type dateColumnImpl struct { @@ -254,18 +288,21 @@ type dateColumnImpl struct { columnImpl } -func (i *dateColumnImpl) From(table ExpressionTable) ColumnDate { - newDateColumn := DateColumn(i.defaultAlias()) - newDateColumn.setTableName(table.Alias()) +func (i *dateColumnImpl) from(subQuery ExpressionTable) projection { + newDateColumn := DateColumn(i.name) + newDateColumn.setTableName(i.tableName) + newDateColumn.setSubQuery(subQuery) + return newDateColumn } +func (i *dateColumnImpl) From(subQuery ExpressionTable) ColumnDate { + return i.from(subQuery).(ColumnDate) +} + func DateColumn(name string) ColumnDate { dateColumn := &dateColumnImpl{} - dateColumn.dateInterfaceImpl.parent = dateColumn - dateColumn.columnImpl = newColumn(name, "", dateColumn) - return dateColumn } diff --git a/sqlbuilder/column_types_test.go b/sqlbuilder/column_types_test.go index 8787d14..0be91ba 100644 --- a/sqlbuilder/column_types_test.go +++ b/sqlbuilder/column_types_test.go @@ -10,36 +10,36 @@ func TestNewBoolColumn(t *testing.T) { boolColumn := BoolColumn("colBool").From(subQuery) assertClauseSerialize(t, boolColumn, `sub_query."colBool"`) assertClauseSerialize(t, boolColumn.EQ(Bool(true)), `(sub_query."colBool" = $1)`, true) - assertProjectionSerialize(t, boolColumn, `sub_query."colBool" AS "sub_query.colBool"`) + assertProjectionSerialize(t, boolColumn, `sub_query."colBool" AS "colBool"`) boolColumn2 := table1ColBool.From(subQuery) assertClauseSerialize(t, boolColumn2, `sub_query."table1.col_bool"`) assertClauseSerialize(t, boolColumn2.EQ(Bool(true)), `(sub_query."table1.col_bool" = $1)`, true) - assertProjectionSerialize(t, boolColumn2, `sub_query."table1.col_bool" AS "sub_query.table1.col_bool"`) + assertProjectionSerialize(t, boolColumn2, `sub_query."table1.col_bool" AS "table1.col_bool"`) } func TestNewIntColumn(t *testing.T) { intColumn := IntegerColumn("col_int").From(subQuery) - assertClauseSerialize(t, intColumn, "sub_query.col_int") - assertClauseSerialize(t, intColumn.EQ(Int(12)), "(sub_query.col_int = $1)", int64(12)) - assertProjectionSerialize(t, intColumn, `sub_query.col_int AS "sub_query.col_int"`) + assertClauseSerialize(t, intColumn, `sub_query."col_int"`) + assertClauseSerialize(t, intColumn.EQ(Int(12)), `(sub_query."col_int" = $1)`, int64(12)) + assertProjectionSerialize(t, intColumn, `sub_query."col_int" AS "col_int"`) intColumn2 := table1ColInt.From(subQuery) assertClauseSerialize(t, intColumn2, `sub_query."table1.col_int"`) assertClauseSerialize(t, intColumn2.EQ(Int(14)), `(sub_query."table1.col_int" = $1)`, int64(14)) - assertProjectionSerialize(t, intColumn2, `sub_query."table1.col_int" AS "sub_query.table1.col_int"`) + assertProjectionSerialize(t, intColumn2, `sub_query."table1.col_int" AS "table1.col_int"`) } func TestNewFloatColumnColumn(t *testing.T) { floatColumn := FloatColumn("col_float").From(subQuery) - assertClauseSerialize(t, floatColumn, "sub_query.col_float") - assertClauseSerialize(t, floatColumn.EQ(Float(1.11)), "(sub_query.col_float = $1)", float64(1.11)) - assertProjectionSerialize(t, floatColumn, `sub_query.col_float AS "sub_query.col_float"`) + assertClauseSerialize(t, floatColumn, `sub_query."col_float"`) + assertClauseSerialize(t, floatColumn.EQ(Float(1.11)), `(sub_query."col_float" = $1)`, float64(1.11)) + assertProjectionSerialize(t, floatColumn, `sub_query."col_float" AS "col_float"`) floatColumn2 := table1ColFloat.From(subQuery) assertClauseSerialize(t, floatColumn2, `sub_query."table1.col_float"`) assertClauseSerialize(t, floatColumn2.EQ(Float(2.22)), `(sub_query."table1.col_float" = $1)`, float64(2.22)) - assertProjectionSerialize(t, floatColumn2, `sub_query."table1.col_float" AS "sub_query.table1.col_float"`) + assertProjectionSerialize(t, floatColumn2, `sub_query."table1.col_float" AS "table1.col_float"`) } diff --git a/sqlbuilder/expression.go b/sqlbuilder/expression.go index b2a0321..bb1e137 100644 --- a/sqlbuilder/expression.go +++ b/sqlbuilder/expression.go @@ -44,6 +44,10 @@ type expressionInterfaceImpl struct { parent Expression } +func (e *expressionInterfaceImpl) from(subQuery ExpressionTable) projection { + return e.parent +} + func (e *expressionInterfaceImpl) IS_NULL() BoolExpression { return newPostifxBoolExpression(e.parent, "IS NULL") } @@ -61,7 +65,7 @@ func (e *expressionInterfaceImpl) NOT_IN(expressions ...Expression) BoolExpressi } func (e *expressionInterfaceImpl) AS(alias string) projection { - return NewAlias(e.parent, alias) + return newAlias(e.parent, alias) } func (e *expressionInterfaceImpl) ASC() OrderByClause { diff --git a/sqlbuilder/expression_table.go b/sqlbuilder/expression_table.go index 090855c..b986ee7 100644 --- a/sqlbuilder/expression_table.go +++ b/sqlbuilder/expression_table.go @@ -6,19 +6,29 @@ type ExpressionTable interface { ReadableTable Alias() string + + AllColumns() ProjectionList } type expressionTableImpl struct { readableTableInterfaceImpl expression Expression alias string + + projections []projection } -func newExpressionTable(expression Expression, alias string) ExpressionTable { +func newExpressionTable(expression Expression, alias string, projections []projection) ExpressionTable { expTable := &expressionTableImpl{expression: expression, alias: alias} expTable.readableTableInterfaceImpl.parent = expTable + for _, projection := range projections { + newProjection := projection.from(expTable) + + expTable.projections = append(expTable.projections, newProjection) + } + return expTable } @@ -26,6 +36,14 @@ func (e *expressionTableImpl) Alias() string { return e.alias } +func (e *expressionTableImpl) columns() []Column { + return nil +} + +func (e *expressionTableImpl) AllColumns() ProjectionList { + return e.projections +} + func (e *expressionTableImpl) serialize(statement statementType, out *queryData, options ...serializeOption) error { if e == nil { return errors.New("Expression table is nil. ") diff --git a/sqlbuilder/projection.go b/sqlbuilder/projection.go index d25057e..0a6c89f 100644 --- a/sqlbuilder/projection.go +++ b/sqlbuilder/projection.go @@ -2,4 +2,29 @@ package sqlbuilder type projection interface { serializeForProjection(statement statementType, out *queryData) error + from(subQuery ExpressionTable) projection +} + +//------------------------------------------------------// +// Dummy type for projection list +type ProjectionList []projection + +func (cl ProjectionList) from(subQuery ExpressionTable) projection { + newProjectionList := ProjectionList{} + + for _, projection := range cl { + newProjectionList = append(newProjectionList, projection.from(subQuery)) + } + + return newProjectionList +} + +func (cl ProjectionList) serializeForProjection(statement statementType, out *queryData) error { + err := serializeProjectionList(statement, cl, out) + + if err != nil { + return err + } + + return nil } diff --git a/sqlbuilder/row_type.go b/sqlbuilder/row_type.go index b89914d..9f9f7d9 100644 --- a/sqlbuilder/row_type.go +++ b/sqlbuilder/row_type.go @@ -2,9 +2,5 @@ package sqlbuilder type rowsType interface { clause - hasRows() + projections() []projection } - -type isRowsType struct{} - -func (i *isRowsType) hasRows() {} diff --git a/sqlbuilder/select_statement.go b/sqlbuilder/select_statement.go index ebf8fa1..4a0bfec 100644 --- a/sqlbuilder/select_statement.go +++ b/sqlbuilder/select_statement.go @@ -16,7 +16,6 @@ var ( type SelectStatement interface { Statement Expression - hasRows() DISTINCT() SelectStatement FROM(table ReadableTable) SelectStatement @@ -31,6 +30,8 @@ type SelectStatement interface { FOR(lock SelectLock) SelectStatement AsTable(alias string) ExpressionTable + + projections() []projection } func SELECT(projection1 projection, projections ...projection) SelectStatement { @@ -39,15 +40,14 @@ func SELECT(projection1 projection, projections ...projection) SelectStatement { type selectStatementImpl struct { expressionInterfaceImpl - isRowsType - table ReadableTable - distinct bool - projections []projection - where BoolExpression - groupBy []groupByClause - having BoolExpression - orderBy []OrderByClause + table ReadableTable + distinct bool + projectionList []projection + where BoolExpression + groupBy []groupByClause + having BoolExpression + orderBy []OrderByClause limit, offset int64 @@ -56,11 +56,11 @@ type selectStatementImpl struct { func newSelectStatement(table ReadableTable, projections []projection) SelectStatement { newSelect := &selectStatementImpl{ - table: table, - projections: projections, - limit: -1, - offset: -1, - distinct: false, + table: table, + projectionList: projections, + limit: -1, + offset: -1, + distinct: false, } newSelect.expressionInterfaceImpl.parent = newSelect @@ -105,11 +105,11 @@ func (s *selectStatementImpl) serializeImpl(out *queryData) error { out.writeString("DISTINCT") } - if len(s.projections) == 0 { + if len(s.projectionList) == 0 { return errors.New("no column selected for projection") } - err := out.writeProjections(select_statement, s.projections) + err := out.writeProjections(select_statement, s.projectionList) if err != nil { return err @@ -196,8 +196,12 @@ func (s *selectStatementImpl) DebugSql() (query string, err error) { return DebugSql(s) } +func (s *selectStatementImpl) projections() []projection { + return s.projectionList +} + func (s *selectStatementImpl) AsTable(alias string) ExpressionTable { - return newExpressionTable(s.parent, alias) + return newExpressionTable(s.parent, alias, s.projectionList) } func (s *selectStatementImpl) WHERE(expression BoolExpression) SelectStatement { @@ -216,9 +220,7 @@ func (s *selectStatementImpl) HAVING(expression BoolExpression) SelectStatement } func (s *selectStatementImpl) ORDER_BY(clauses ...OrderByClause) SelectStatement { - s.orderBy = clauses - return s } diff --git a/sqlbuilder/set_statement.go b/sqlbuilder/set_statement.go index d823476..4f84be8 100644 --- a/sqlbuilder/set_statement.go +++ b/sqlbuilder/set_statement.go @@ -9,13 +9,14 @@ import ( type SetStatement interface { Statement Expression - hasRows() ORDER_BY(clauses ...OrderByClause) SetStatement LIMIT(limit int64) SetStatement OFFSET(offset int64) SetStatement AsTable(alias string) ExpressionTable + + projections() []projection } const ( @@ -51,7 +52,6 @@ func EXCEPT_ALL(selects ...rowsType) SetStatement { // Similar to selectStatementImpl, but less complete type setStatementImpl struct { expressionInterfaceImpl - isRowsType operator string selects []rowsType @@ -75,23 +75,30 @@ func newSetStatementImpl(operator string, all bool, selects ...rowsType) SetStat return setStatement } -func (us *setStatementImpl) ORDER_BY(orderBy ...OrderByClause) SetStatement { - us.orderBy = orderBy - return us +func (s *setStatementImpl) ORDER_BY(orderBy ...OrderByClause) SetStatement { + s.orderBy = orderBy + return s } -func (us *setStatementImpl) LIMIT(limit int64) SetStatement { - us.limit = limit - return us +func (s *setStatementImpl) LIMIT(limit int64) SetStatement { + s.limit = limit + return s } -func (us *setStatementImpl) OFFSET(offset int64) SetStatement { - us.offset = offset - return us +func (s *setStatementImpl) OFFSET(offset int64) SetStatement { + s.offset = offset + return s } -func (us *setStatementImpl) AsTable(alias string) ExpressionTable { - return newExpressionTable(us.parent, alias) +func (s *setStatementImpl) projections() []projection { + if len(s.selects) > 0 { + return s.selects[0].projections() + } + return []projection{} +} + +func (s *setStatementImpl) AsTable(alias string) ExpressionTable { + return newExpressionTable(s.parent, alias, s.projections()) } func (s *setStatementImpl) serialize(statement statementType, out *queryData, options ...serializeOption) error { @@ -178,10 +185,10 @@ func (s *setStatementImpl) serializeImpl(out *queryData) error { return nil } -func (us *setStatementImpl) Sql() (query string, args []interface{}, err error) { +func (s *setStatementImpl) Sql() (query string, args []interface{}, err error) { queryData := &queryData{} - err = us.serializeImpl(queryData) + err = s.serializeImpl(queryData) if err != nil { return @@ -199,6 +206,6 @@ func (s *setStatementImpl) Query(db execution.Db, destination interface{}) error return Query(s, db, destination) } -func (u *setStatementImpl) Exec(db execution.Db) (res sql.Result, err error) { - return Exec(u, db) +func (s *setStatementImpl) Exec(db execution.Db) (res sql.Result, err error) { + return Exec(s, db) } diff --git a/sqlbuilder/table.go b/sqlbuilder/table.go index b58c5aa..1994143 100644 --- a/sqlbuilder/table.go +++ b/sqlbuilder/table.go @@ -24,6 +24,8 @@ type readableTable interface { // Creates a cross join tableName Expression using onCondition. CROSS_JOIN(table ReadableTable) ReadableTable + + columns() []Column } // The sql tableName write interface. @@ -111,7 +113,7 @@ func NewTable(schemaName, name string, columns ...Column) Table { t := &tableImpl{ schemaName: schemaName, name: name, - columns: columns, + columnList: columns, } for _, c := range columns { c.setTableName(name) @@ -130,13 +132,13 @@ type tableImpl struct { schemaName string name string alias string - columns []Column + columnList []Column } func (t *tableImpl) AS(alias string) { t.alias = alias - for _, c := range t.columns { + for _, c := range t.columnList { c.setTableName(alias) } } @@ -151,8 +153,8 @@ func (t *tableImpl) TableName() string { return t.name } -func (t *tableImpl) SchemaTableName() string { - return t.schemaName +func (t *tableImpl) columns() []Column { + return t.columnList } func (t *tableImpl) serialize(statement statementType, out *queryData, options ...serializeOption) error { @@ -218,6 +220,10 @@ func (t *joinTable) TableName() string { return "" } +func (t *joinTable) columns() []Column { + return append(t.lhs.columns(), t.rhs.columns()...) +} + func (t *joinTable) serialize(statement statementType, out *queryData, options ...serializeOption) (err error) { if t == nil { return errors.New("Join table is nil. ") diff --git a/tests/all_types_test.go b/tests/all_types_test.go index 770c079..0f8c05e 100644 --- a/tests/all_types_test.go +++ b/tests/all_types_test.go @@ -356,6 +356,185 @@ func TestTimeOperators(t *testing.T) { assert.NilError(t, err) } +func TestSubQueryColumnReference(t *testing.T) { + + type expected struct { + sql string + args []interface{} + } + + subQueries := map[ExpressionTable]expected{} + + selectSubQuery := AllTypes.SELECT( + AllTypes.Boolean, + AllTypes.Integer, + AllTypes.Real, + AllTypes.Text, + AllTypes.Time, + AllTypes.Timez, + AllTypes.Timestamp, + AllTypes.Timestampz, + AllTypes.Date, + AllTypes.Bytea.AS("aliasedColumn"), + ). + LIMIT(2). + AsTable("subQuery") + + var selectExpectedSql = ` ( + SELECT all_types.boolean AS "all_types.boolean", + all_types.integer AS "all_types.integer", + all_types.real AS "all_types.real", + all_types.text AS "all_types.text", + all_types.time AS "all_types.time", + all_types.timez AS "all_types.timez", + all_types.timestamp AS "all_types.timestamp", + all_types.timestampz AS "all_types.timestampz", + all_types.date AS "all_types.date", + all_types.bytea AS "aliasedColumn" + FROM test_sample.all_types + LIMIT 2 + ) AS "subQuery"` + + unionSubQuery := + UNION_ALL( + AllTypes.SELECT( + AllTypes.Boolean, + AllTypes.Integer, + AllTypes.Real, + AllTypes.Text, + AllTypes.Time, + AllTypes.Timez, + AllTypes.Timestamp, + AllTypes.Timestampz, + AllTypes.Date, + AllTypes.Bytea.AS("aliasedColumn"), + ). + LIMIT(1), + AllTypes.SELECT( + AllTypes.Boolean, + AllTypes.Integer, + AllTypes.Real, + AllTypes.Text, + AllTypes.Time, + AllTypes.Timez, + AllTypes.Timestamp, + AllTypes.Timestampz, + AllTypes.Date, + AllTypes.Bytea.AS("aliasedColumn"), + ). + LIMIT(1).OFFSET(1), + ). + AsTable("subQuery") + + unionExpectedSql := ` + ( + ( + SELECT all_types.boolean AS "all_types.boolean", + all_types.integer AS "all_types.integer", + all_types.real AS "all_types.real", + all_types.text AS "all_types.text", + all_types.time AS "all_types.time", + all_types.timez AS "all_types.timez", + all_types.timestamp AS "all_types.timestamp", + all_types.timestampz AS "all_types.timestampz", + all_types.date AS "all_types.date", + all_types.bytea AS "aliasedColumn" + FROM test_sample.all_types + LIMIT 1 + ) + UNION ALL + ( + SELECT all_types.boolean AS "all_types.boolean", + all_types.integer AS "all_types.integer", + all_types.real AS "all_types.real", + all_types.text AS "all_types.text", + all_types.time AS "all_types.time", + all_types.timez AS "all_types.timez", + all_types.timestamp AS "all_types.timestamp", + all_types.timestampz AS "all_types.timestampz", + all_types.date AS "all_types.date", + all_types.bytea AS "aliasedColumn" + FROM test_sample.all_types + LIMIT 1 + OFFSET 1 + ) + ) AS "subQuery"` + + subQueries[selectSubQuery] = expected{sql: selectExpectedSql, args: []interface{}{int64(2)}} + subQueries[unionSubQuery] = expected{sql: unionExpectedSql, args: []interface{}{int64(1), int64(1), int64(1)}} + + for subQuery, expected := range subQueries { + boolColumn := AllTypes.Boolean.From(subQuery) + intColumn := AllTypes.Integer.From(subQuery) + floatColumn := AllTypes.Real.From(subQuery) + stringColumn := AllTypes.Text.From(subQuery) + timeColumn := AllTypes.Time.From(subQuery) + timezColumn := AllTypes.Timez.From(subQuery) + timestampColumn := AllTypes.Timestamp.From(subQuery) + timestampzColumn := AllTypes.Timestampz.From(subQuery) + dateColumn := AllTypes.Date.From(subQuery) + aliasedColumn := StringColumn("aliasedColumn").From(subQuery) + + stmt1 := SELECT( + boolColumn, + intColumn, + floatColumn, + stringColumn, + timeColumn, + timezColumn, + timestampColumn, + timestampzColumn, + dateColumn, + aliasedColumn, + ). + FROM(subQuery) + + var expectedSql = ` +SELECT "subQuery"."all_types.boolean" AS "all_types.boolean", + "subQuery"."all_types.integer" AS "all_types.integer", + "subQuery"."all_types.real" AS "all_types.real", + "subQuery"."all_types.text" AS "all_types.text", + "subQuery"."all_types.time" AS "all_types.time", + "subQuery"."all_types.timez" AS "all_types.timez", + "subQuery"."all_types.timestamp" AS "all_types.timestamp", + "subQuery"."all_types.timestampz" AS "all_types.timestampz", + "subQuery"."all_types.date" AS "all_types.date", + "subQuery"."aliasedColumn" AS "aliasedColumn" +FROM` + + assertStatementSql(t, stmt1, expectedSql+expected.sql+";\n", expected.args...) + + dest1 := []model.AllTypes{} + err := stmt1.Query(db, &dest1) + assert.NilError(t, err) + assert.Equal(t, len(dest1), 2) + assert.Equal(t, dest1[0].Boolean, allTypesRow0.Boolean) + assert.Equal(t, dest1[0].Integer, allTypesRow0.Integer) + assert.Equal(t, dest1[0].Real, allTypesRow0.Real) + assert.Equal(t, dest1[0].Text, allTypesRow0.Text) + assert.DeepEqual(t, dest1[0].Time, allTypesRow0.Time) + assert.DeepEqual(t, dest1[0].Timez, allTypesRow0.Timez) + assert.DeepEqual(t, dest1[0].Timestamp, allTypesRow0.Timestamp) + assert.DeepEqual(t, dest1[0].Timestampz, allTypesRow0.Timestampz) + assert.DeepEqual(t, dest1[0].Date, allTypesRow0.Date) + + stmt2 := SELECT( + subQuery.AllColumns(), + ). + FROM(subQuery) + + fmt.Println(stmt2.DebugSql()) + + assertStatementSql(t, stmt2, expectedSql+expected.sql+";\n", expected.args...) + + dest2 := []model.AllTypes{} + err = stmt2.Query(db, &dest2) + + assert.NilError(t, err) + assert.DeepEqual(t, dest1, dest2) + } +} + var allTypesRow0 = model.AllTypes{ SmallintPtr: int16Ptr(1), Smallint: 1, diff --git a/tests/chinook_db_test.go b/tests/chinook_db_test.go index 774ab15..30fea09 100644 --- a/tests/chinook_db_test.go +++ b/tests/chinook_db_test.go @@ -3,6 +3,7 @@ package tests import ( "encoding/json" "fmt" + "github.com/davecgh/go-spew/spew" . "github.com/go-jet/jet/sqlbuilder" "github.com/go-jet/jet/tests/.test_files/dvd_rental/chinook/model" . "github.com/go-jet/jet/tests/.test_files/dvd_rental/chinook/table" @@ -151,6 +152,65 @@ ORDER BY "Album.AlbumId"; assert.DeepEqual(t, dest[1], album2) } +func TestSubQueriesForQuotedNames(t *testing.T) { + first10Artist := Artist. + SELECT(Artist.AllColumns). + ORDER_BY(Artist.ArtistId). + LIMIT(10). + AsTable("first10Artist") + + artistId := Artist.ArtistId.From(first10Artist) + + first10Albums := Album. + SELECT(Album.AllColumns). + ORDER_BY(Album.AlbumId). + LIMIT(10). + AsTable("first10Albums") + + albumArtistId := Album.ArtistId.From(first10Albums) + + stmt := first10Artist. + INNER_JOIN(first10Albums, artistId.EQ(albumArtistId)). + SELECT(first10Artist.AllColumns(), first10Albums.AllColumns()). + ORDER_BY(artistId) + + assertStatementSql(t, stmt, ` +SELECT "first10Artist"."Artist.ArtistId" AS "Artist.ArtistId", + "first10Artist"."Artist.Name" AS "Artist.Name", + "first10Albums"."Album.AlbumId" AS "Album.AlbumId", + "first10Albums"."Album.Title" AS "Album.Title", + "first10Albums"."Album.ArtistId" AS "Album.ArtistId" +FROM ( + SELECT "Artist"."ArtistId" AS "Artist.ArtistId", + "Artist"."Name" AS "Artist.Name" + FROM chinook."Artist" + ORDER BY "Artist"."ArtistId" + LIMIT 10 + ) AS "first10Artist" + INNER JOIN ( + SELECT "Album"."AlbumId" AS "Album.AlbumId", + "Album"."Title" AS "Album.Title", + "Album"."ArtistId" AS "Album.ArtistId" + FROM chinook."Album" + ORDER BY "Album"."AlbumId" + LIMIT 10 + ) AS "first10Albums" ON ("first10Artist"."Artist.ArtistId" = "first10Albums"."Album.ArtistId") +ORDER BY "first10Artist"."Artist.ArtistId"; +`, int64(10), int64(10)) + + var dest []struct { + model.Artist + + Album []model.Album + } + + err := stmt.Query(db, &dest) + + assert.NilError(t, err) + + spew.Dump(dest) +} + func assertJson(t *testing.T, jsonFilePath string, data interface{}) { fileJsonData, err := ioutil.ReadFile(jsonFilePath) assert.NilError(t, err) diff --git a/tests/select_test.go b/tests/select_test.go index 4b84f0e..b0aa3b8 100644 --- a/tests/select_test.go +++ b/tests/select_test.go @@ -907,7 +907,7 @@ SELECT customer.customer_id AS "customer.customer_id", customer.create_date AS "customer.create_date", customer.last_update AS "customer.last_update", customer.active AS "customer.active", - customer_payment_sum.amount_sum AS "customer_with_amounts.amount_sum" + customer_payment_sum."amount_sum" AS "CustomerWithAmounts.AmountSum" FROM dvds.customer INNER JOIN ( SELECT payment.customer_id AS "payment.customer_id", @@ -915,7 +915,7 @@ FROM dvds.customer FROM dvds.payment GROUP BY payment.customer_id ) AS customer_payment_sum ON (customer.customer_id = customer_payment_sum."payment.customer_id") -ORDER BY customer_payment_sum.amount_sum ASC; +ORDER BY customer_payment_sum."amount_sum" ASC; ` customersPayments := Payment. @@ -933,7 +933,7 @@ ORDER BY customer_payment_sum.amount_sum ASC; INNER_JOIN(customersPayments, Customer.CustomerID.EQ(customerId)). SELECT( Customer.AllColumns, - amountSum.AS("customer_with_amounts.amount_sum"), + amountSum.AS("CustomerWithAmounts.AmountSum"), ). ORDER_BY(amountSum.ASC()) diff --git a/tests/test_util.go b/tests/test_util.go index 0d682fa..d89b5d9 100644 --- a/tests/test_util.go +++ b/tests/test_util.go @@ -17,6 +17,7 @@ func assertStatementSql(t *testing.T, query sqlbuilder.Statement, expectedQuery assert.DeepEqual(t, args, expectedArgs) debuqSql, err := query.DebugSql() + assert.NilError(t, err) assert.Equal(t, debuqSql, expectedQuery) }