Improvements on sub-query projection reference.

This commit is contained in:
go-jet 2019-06-18 14:35:32 +02:00
parent d9ffa86453
commit 565b670188
17 changed files with 512 additions and 134 deletions

View file

@ -1,24 +1,37 @@
package sqlbuilder package sqlbuilder
type Alias struct { type alias struct {
expression Expression expression Expression
alias string alias string
subQuery ExpressionTable
} }
func NewAlias(expression Expression, alias string) *Alias { func newAlias(expression Expression, aliasName string) projection {
return &Alias{ return &alias{
expression: expression, expression: expression,
alias: alias, alias: aliasName,
} }
} }
func (a *Alias) serializeForProjection(statement statementType, out *queryData) error { func (a *alias) from(subQuery ExpressionTable) projection {
newAlias := *a
newAlias.subQuery = subQuery
return &newAlias
}
func (a *alias) serializeForProjection(statement statementType, out *queryData) error {
if a.subQuery != nil {
out.writeIdentifier(a.subQuery.Alias())
out.writeByte('.')
out.writeQuotedString(a.alias)
} else {
err := a.expression.serialize(statement, out) err := a.expression.serialize(statement, out)
if err != nil { if err != nil {
return err return err
} }
}
out.writeString(`AS "` + a.alias + `"`) out.writeString(`AS "` + a.alias + `"`)

View file

@ -150,6 +150,10 @@ func isPostSeparator(b byte) bool {
return b == ' ' || b == '.' || b == ',' || b == ')' || b == '\n' || b == ':' return b == ' ' || b == '.' || b == ',' || b == ')' || b == '\n' || b == ':'
} }
func (q *queryData) writeQuotedString(str string) {
q.writeString(`"` + str + `"`)
}
func (q *queryData) writeString(str string) { func (q *queryData) writeString(str string) {
q.write([]byte(str)) q.write([]byte(str))
} }

View file

@ -7,6 +7,8 @@ type column interface {
TableName() string TableName() string
setTableName(table string) setTableName(table string)
setSubQuery(subQuery ExpressionTable)
defaultAlias() string
} }
type Column interface { type Column interface {
@ -20,6 +22,8 @@ type columnImpl struct {
name string name string
tableName string tableName string
subQuery ExpressionTable
} }
func newColumn(name string, tableName string, parent Column) columnImpl { func newColumn(name string, tableName string, parent Column) columnImpl {
@ -45,6 +49,10 @@ func (c *columnImpl) setTableName(table string) {
c.tableName = table c.tableName = table
} }
func (c *columnImpl) setSubQuery(subQuery ExpressionTable) {
c.subQuery = subQuery
}
func (c *columnImpl) defaultAlias() string { func (c *columnImpl) defaultAlias() string {
if c.tableName != "" { if c.tableName != "" {
return c.tableName + "." + c.name return c.tableName + "." + c.name
@ -78,12 +86,18 @@ func (c columnImpl) serializeForProjection(statement statementType, out *queryDa
func (c columnImpl) serialize(statement statementType, out *queryData, options ...serializeOption) error { func (c columnImpl) serialize(statement statementType, out *queryData, options ...serializeOption) error {
if c.subQuery != nil {
out.writeIdentifier(c.subQuery.Alias())
out.writeByte('.')
out.writeString(`"` + c.defaultAlias() + `"`)
} else {
if c.tableName != "" { if c.tableName != "" {
out.writeIdentifier(c.tableName) out.writeIdentifier(c.tableName)
out.writeByte('.') out.writeByte('.')
} }
out.writeIdentifier(c.name) out.writeIdentifier(c.name)
}
return nil return nil
} }
@ -95,6 +109,16 @@ type ColumnList []Column
// projection interface implementation // projection interface implementation
func (cl ColumnList) isProjectionType() {} func (cl ColumnList) isProjectionType() {}
func (cl ColumnList) from(subQuery ExpressionTable) projection {
newProjectionList := ProjectionList{}
for _, column := range cl {
newProjectionList = append(newProjectionList, column.from(subQuery))
}
return newProjectionList
}
func (cl ColumnList) serializeForProjection(statement statementType, out *queryData) error { func (cl ColumnList) serializeForProjection(statement statementType, out *queryData) error {
projections := columnListToProjectionList(cl) projections := columnListToProjectionList(cl)
@ -111,3 +135,5 @@ func (cl ColumnList) serializeForProjection(statement statementType, out *queryD
func (cl ColumnList) Name() string { return "" } func (cl ColumnList) Name() string { return "" }
func (cl ColumnList) TableName() string { return "" } func (cl ColumnList) TableName() string { return "" }
func (cl ColumnList) setTableName(name string) {} func (cl ColumnList) setTableName(name string) {}
func (cl ColumnList) setSubQuery(subQuery ExpressionTable) {}
func (cl ColumnList) defaultAlias() string { return "" }

View file

@ -9,6 +9,6 @@ func TestColumn(t *testing.T) {
assertClauseSerialize(t, column, "col") assertClauseSerialize(t, column, "col")
column.setTableName("table1") column.setTableName("table1")
assertClauseSerialize(t, column, "table1.col") assertClauseSerialize(t, column, "table1.col")
assertProjectionSerialize(t, column, `table1.col AS "table1.col"`) assertProjectionSerialize(t, &column, `table1.col AS "table1.col"`)
assertProjectionSerialize(t, column.AS("alias1"), `table1.col AS "alias1"`) assertProjectionSerialize(t, column.AS("alias1"), `table1.col AS "alias1"`)
} }

View file

@ -5,7 +5,7 @@ type ColumnBool interface {
BoolExpression BoolExpression
column column
From(table ExpressionTable) ColumnBool From(subQuery ExpressionTable) ColumnBool
} }
type boolColumnImpl struct { type boolColumnImpl struct {
@ -14,17 +14,23 @@ type boolColumnImpl struct {
columnImpl columnImpl
} }
func (i *boolColumnImpl) From(table ExpressionTable) ColumnBool { func (i *boolColumnImpl) from(subQuery ExpressionTable) projection {
newBoolColumn := BoolColumn(i.defaultAlias()) newBoolColumn := BoolColumn(i.name)
newBoolColumn.setTableName(table.Alias()) newBoolColumn.setTableName(i.tableName)
newBoolColumn.setSubQuery(subQuery)
return newBoolColumn
}
func (i *boolColumnImpl) From(subQuery ExpressionTable) ColumnBool {
newBoolColumn := i.from(subQuery).(ColumnBool)
return newBoolColumn return newBoolColumn
} }
func BoolColumn(name string) ColumnBool { func BoolColumn(name string) ColumnBool {
boolColumn := &boolColumnImpl{} boolColumn := &boolColumnImpl{}
boolColumn.columnImpl = newColumn(name, "", boolColumn) boolColumn.columnImpl = newColumn(name, "", boolColumn)
boolColumn.boolInterfaceImpl.parent = boolColumn boolColumn.boolInterfaceImpl.parent = boolColumn
return boolColumn return boolColumn
@ -35,7 +41,7 @@ type ColumnFloat interface {
FloatExpression FloatExpression
column column
From(table ExpressionTable) ColumnFloat From(subQuery ExpressionTable) ColumnFloat
} }
type floatColumnImpl struct { type floatColumnImpl struct {
@ -43,16 +49,22 @@ type floatColumnImpl struct {
columnImpl columnImpl
} }
func (i *floatColumnImpl) From(table ExpressionTable) ColumnFloat { func (i *floatColumnImpl) from(subQuery ExpressionTable) projection {
newFloatColumn := FloatColumn(i.defaultAlias()) newFloatColumn := FloatColumn(i.name)
newFloatColumn.setTableName(table.Alias()) newFloatColumn.setTableName(i.tableName)
newFloatColumn.setSubQuery(subQuery)
return newFloatColumn
}
func (i *floatColumnImpl) From(subQuery ExpressionTable) ColumnFloat {
newFloatColumn := i.from(subQuery).(ColumnFloat)
return newFloatColumn return newFloatColumn
} }
func FloatColumn(name string) ColumnFloat { func FloatColumn(name string) ColumnFloat {
floatColumn := &floatColumnImpl{} floatColumn := &floatColumnImpl{}
floatColumn.floatInterfaceImpl.parent = floatColumn floatColumn.floatInterfaceImpl.parent = floatColumn
floatColumn.columnImpl = newColumn(name, "", floatColumn) floatColumn.columnImpl = newColumn(name, "", floatColumn)
@ -64,7 +76,7 @@ type ColumnInteger interface {
IntegerExpression IntegerExpression
column column
From(table ExpressionTable) ColumnInteger From(subQuery ExpressionTable) ColumnInteger
} }
type integerColumnImpl struct { type integerColumnImpl struct {
@ -73,15 +85,20 @@ type integerColumnImpl struct {
columnImpl columnImpl
} }
func (i *integerColumnImpl) From(table ExpressionTable) ColumnInteger { func (i *integerColumnImpl) from(subQuery ExpressionTable) projection {
newIntColumn := IntegerColumn(i.defaultAlias()) newIntColumn := IntegerColumn(i.name)
newIntColumn.setTableName(table.Alias()) newIntColumn.setTableName(i.tableName)
newIntColumn.setSubQuery(subQuery)
return newIntColumn return newIntColumn
} }
func (i *integerColumnImpl) From(subQuery ExpressionTable) ColumnInteger {
return i.from(subQuery).(ColumnInteger)
}
func IntegerColumn(name string) ColumnInteger { func IntegerColumn(name string) ColumnInteger {
integerColumn := &integerColumnImpl{} integerColumn := &integerColumnImpl{}
integerColumn.integerInterfaceImpl.parent = integerColumn integerColumn.integerInterfaceImpl.parent = integerColumn
integerColumn.columnImpl = newColumn(name, "", integerColumn) integerColumn.columnImpl = newColumn(name, "", integerColumn)
@ -93,7 +110,7 @@ type ColumnString interface {
StringExpression StringExpression
column column
From(table ExpressionTable) ColumnString From(subQuery ExpressionTable) ColumnString
} }
type stringColumnImpl struct { type stringColumnImpl struct {
@ -102,18 +119,21 @@ type stringColumnImpl struct {
columnImpl columnImpl
} }
func (i *stringColumnImpl) From(table ExpressionTable) ColumnString { func (i *stringColumnImpl) from(subQuery ExpressionTable) projection {
newStrColumn := StringColumn(i.defaultAlias()) newStrColumn := StringColumn(i.name)
newStrColumn.setTableName(table.Alias()) newStrColumn.setTableName(i.tableName)
newStrColumn.setSubQuery(subQuery)
return newStrColumn return newStrColumn
} }
func (i *stringColumnImpl) From(subQuery ExpressionTable) ColumnString {
return i.from(subQuery).(ColumnString)
}
func StringColumn(name string) ColumnString { func StringColumn(name string) ColumnString {
stringColumn := &stringColumnImpl{} stringColumn := &stringColumnImpl{}
stringColumn.stringInterfaceImpl.parent = stringColumn stringColumn.stringInterfaceImpl.parent = stringColumn
stringColumn.columnImpl = newColumn(name, "", stringColumn) stringColumn.columnImpl = newColumn(name, "", stringColumn)
return stringColumn return stringColumn
@ -124,28 +144,30 @@ type ColumnTime interface {
TimeExpression TimeExpression
column column
From(table ExpressionTable) ColumnTime From(subQuery ExpressionTable) ColumnTime
} }
type timeColumnImpl struct { type timeColumnImpl struct {
timeInterfaceImpl timeInterfaceImpl
columnImpl columnImpl
} }
func (i *timeColumnImpl) From(table ExpressionTable) ColumnTime { func (i *timeColumnImpl) from(subQuery ExpressionTable) projection {
newTimeColumn := TimeColumn(i.defaultAlias()) newTimeColumn := TimeColumn(i.name)
newTimeColumn.setTableName(table.Alias()) newTimeColumn.setTableName(i.tableName)
newTimeColumn.setSubQuery(subQuery)
return newTimeColumn return newTimeColumn
} }
func (i *timeColumnImpl) From(subQuery ExpressionTable) ColumnTime {
return i.from(subQuery).(ColumnTime)
}
func TimeColumn(name string) ColumnTime { func TimeColumn(name string) ColumnTime {
timeColumn := &timeColumnImpl{} timeColumn := &timeColumnImpl{}
timeColumn.timeInterfaceImpl.parent = timeColumn timeColumn.timeInterfaceImpl.parent = timeColumn
timeColumn.columnImpl = newColumn(name, "", timeColumn) timeColumn.columnImpl = newColumn(name, "", timeColumn)
return timeColumn return timeColumn
} }
@ -155,7 +177,7 @@ type ColumnTimez interface {
TimezExpression TimezExpression
column column
From(table ExpressionTable) ColumnTimez From(subQuery ExpressionTable) ColumnTimez
} }
type timezColumnImpl struct { type timezColumnImpl struct {
@ -164,17 +186,21 @@ type timezColumnImpl struct {
columnImpl columnImpl
} }
func (i *timezColumnImpl) From(table ExpressionTable) ColumnTimez { func (i *timezColumnImpl) from(subQuery ExpressionTable) projection {
newTimezColumn := TimezColumn(i.defaultAlias()) newTimezColumn := TimezColumn(i.name)
newTimezColumn.setTableName(table.Alias()) newTimezColumn.setTableName(i.tableName)
newTimezColumn.setSubQuery(subQuery)
return newTimezColumn return newTimezColumn
} }
func (i *timezColumnImpl) From(subQuery ExpressionTable) ColumnTimez {
return i.from(subQuery).(ColumnTimez)
}
func TimezColumn(name string) ColumnTimez { func TimezColumn(name string) ColumnTimez {
timezColumn := &timezColumnImpl{} timezColumn := &timezColumnImpl{}
timezColumn.timezInterfaceImpl.parent = timezColumn timezColumn.timezInterfaceImpl.parent = timezColumn
timezColumn.columnImpl = newColumn(name, "", timezColumn) timezColumn.columnImpl = newColumn(name, "", timezColumn)
return timezColumn return timezColumn
@ -185,7 +211,7 @@ type ColumnTimestamp interface {
TimestampExpression TimestampExpression
column column
From(table ExpressionTable) ColumnTimestamp From(subQuery ExpressionTable) ColumnTimestamp
} }
type timestampColumnImpl struct { type timestampColumnImpl struct {
@ -194,17 +220,21 @@ type timestampColumnImpl struct {
columnImpl columnImpl
} }
func (i *timestampColumnImpl) From(table ExpressionTable) ColumnTimestamp { func (i *timestampColumnImpl) from(subQuery ExpressionTable) projection {
newTimestampColumn := TimestampColumn(i.defaultAlias()) newTimestampColumn := TimestampColumn(i.name)
newTimestampColumn.setTableName(table.Alias()) newTimestampColumn.setTableName(i.tableName)
newTimestampColumn.setSubQuery(subQuery)
return newTimestampColumn return newTimestampColumn
} }
func (i *timestampColumnImpl) From(subQuery ExpressionTable) ColumnTimestamp {
return i.from(subQuery).(ColumnTimestamp)
}
func TimestampColumn(name string) ColumnTimestamp { func TimestampColumn(name string) ColumnTimestamp {
timestampColumn := &timestampColumnImpl{} timestampColumn := &timestampColumnImpl{}
timestampColumn.timestampInterfaceImpl.parent = timestampColumn timestampColumn.timestampInterfaceImpl.parent = timestampColumn
timestampColumn.columnImpl = newColumn(name, "", timestampColumn) timestampColumn.columnImpl = newColumn(name, "", timestampColumn)
return timestampColumn return timestampColumn
@ -215,7 +245,7 @@ type ColumnTimestampz interface {
TimestampzExpression TimestampzExpression
column column
From(table ExpressionTable) ColumnTimestampz From(subQuery ExpressionTable) ColumnTimestampz
} }
type timestampzColumnImpl struct { type timestampzColumnImpl struct {
@ -224,17 +254,21 @@ type timestampzColumnImpl struct {
columnImpl columnImpl
} }
func (i *timestampzColumnImpl) From(table ExpressionTable) ColumnTimestampz { func (i *timestampzColumnImpl) from(subQuery ExpressionTable) projection {
newTimestampzColumn := TimestampzColumn(i.defaultAlias()) newTimestampzColumn := TimestampzColumn(i.name)
newTimestampzColumn.setTableName(table.Alias()) newTimestampzColumn.setTableName(i.tableName)
newTimestampzColumn.setSubQuery(subQuery)
return newTimestampzColumn return newTimestampzColumn
} }
func (i *timestampzColumnImpl) From(subQuery ExpressionTable) ColumnTimestampz {
return i.from(subQuery).(ColumnTimestampz)
}
func TimestampzColumn(name string) ColumnTimestampz { func TimestampzColumn(name string) ColumnTimestampz {
timestampzColumn := &timestampzColumnImpl{} timestampzColumn := &timestampzColumnImpl{}
timestampzColumn.timestampzInterfaceImpl.parent = timestampzColumn timestampzColumn.timestampzInterfaceImpl.parent = timestampzColumn
timestampzColumn.columnImpl = newColumn(name, "", timestampzColumn) timestampzColumn.columnImpl = newColumn(name, "", timestampzColumn)
return timestampzColumn return timestampzColumn
@ -245,7 +279,7 @@ type ColumnDate interface {
DateExpression DateExpression
column column
From(table ExpressionTable) ColumnDate From(subQuery ExpressionTable) ColumnDate
} }
type dateColumnImpl struct { type dateColumnImpl struct {
@ -254,18 +288,21 @@ type dateColumnImpl struct {
columnImpl columnImpl
} }
func (i *dateColumnImpl) From(table ExpressionTable) ColumnDate { func (i *dateColumnImpl) from(subQuery ExpressionTable) projection {
newDateColumn := DateColumn(i.defaultAlias()) newDateColumn := DateColumn(i.name)
newDateColumn.setTableName(table.Alias()) newDateColumn.setTableName(i.tableName)
newDateColumn.setSubQuery(subQuery)
return newDateColumn return newDateColumn
} }
func (i *dateColumnImpl) From(subQuery ExpressionTable) ColumnDate {
return i.from(subQuery).(ColumnDate)
}
func DateColumn(name string) ColumnDate { func DateColumn(name string) ColumnDate {
dateColumn := &dateColumnImpl{} dateColumn := &dateColumnImpl{}
dateColumn.dateInterfaceImpl.parent = dateColumn dateColumn.dateInterfaceImpl.parent = dateColumn
dateColumn.columnImpl = newColumn(name, "", dateColumn) dateColumn.columnImpl = newColumn(name, "", dateColumn)
return dateColumn return dateColumn
} }

View file

@ -10,36 +10,36 @@ func TestNewBoolColumn(t *testing.T) {
boolColumn := BoolColumn("colBool").From(subQuery) boolColumn := BoolColumn("colBool").From(subQuery)
assertClauseSerialize(t, boolColumn, `sub_query."colBool"`) assertClauseSerialize(t, boolColumn, `sub_query."colBool"`)
assertClauseSerialize(t, boolColumn.EQ(Bool(true)), `(sub_query."colBool" = $1)`, true) assertClauseSerialize(t, boolColumn.EQ(Bool(true)), `(sub_query."colBool" = $1)`, true)
assertProjectionSerialize(t, boolColumn, `sub_query."colBool" AS "sub_query.colBool"`) assertProjectionSerialize(t, boolColumn, `sub_query."colBool" AS "colBool"`)
boolColumn2 := table1ColBool.From(subQuery) boolColumn2 := table1ColBool.From(subQuery)
assertClauseSerialize(t, boolColumn2, `sub_query."table1.col_bool"`) assertClauseSerialize(t, boolColumn2, `sub_query."table1.col_bool"`)
assertClauseSerialize(t, boolColumn2.EQ(Bool(true)), `(sub_query."table1.col_bool" = $1)`, true) assertClauseSerialize(t, boolColumn2.EQ(Bool(true)), `(sub_query."table1.col_bool" = $1)`, true)
assertProjectionSerialize(t, boolColumn2, `sub_query."table1.col_bool" AS "sub_query.table1.col_bool"`) assertProjectionSerialize(t, boolColumn2, `sub_query."table1.col_bool" AS "table1.col_bool"`)
} }
func TestNewIntColumn(t *testing.T) { func TestNewIntColumn(t *testing.T) {
intColumn := IntegerColumn("col_int").From(subQuery) intColumn := IntegerColumn("col_int").From(subQuery)
assertClauseSerialize(t, intColumn, "sub_query.col_int") assertClauseSerialize(t, intColumn, `sub_query."col_int"`)
assertClauseSerialize(t, intColumn.EQ(Int(12)), "(sub_query.col_int = $1)", int64(12)) assertClauseSerialize(t, intColumn.EQ(Int(12)), `(sub_query."col_int" = $1)`, int64(12))
assertProjectionSerialize(t, intColumn, `sub_query.col_int AS "sub_query.col_int"`) assertProjectionSerialize(t, intColumn, `sub_query."col_int" AS "col_int"`)
intColumn2 := table1ColInt.From(subQuery) intColumn2 := table1ColInt.From(subQuery)
assertClauseSerialize(t, intColumn2, `sub_query."table1.col_int"`) assertClauseSerialize(t, intColumn2, `sub_query."table1.col_int"`)
assertClauseSerialize(t, intColumn2.EQ(Int(14)), `(sub_query."table1.col_int" = $1)`, int64(14)) assertClauseSerialize(t, intColumn2.EQ(Int(14)), `(sub_query."table1.col_int" = $1)`, int64(14))
assertProjectionSerialize(t, intColumn2, `sub_query."table1.col_int" AS "sub_query.table1.col_int"`) assertProjectionSerialize(t, intColumn2, `sub_query."table1.col_int" AS "table1.col_int"`)
} }
func TestNewFloatColumnColumn(t *testing.T) { func TestNewFloatColumnColumn(t *testing.T) {
floatColumn := FloatColumn("col_float").From(subQuery) floatColumn := FloatColumn("col_float").From(subQuery)
assertClauseSerialize(t, floatColumn, "sub_query.col_float") assertClauseSerialize(t, floatColumn, `sub_query."col_float"`)
assertClauseSerialize(t, floatColumn.EQ(Float(1.11)), "(sub_query.col_float = $1)", float64(1.11)) assertClauseSerialize(t, floatColumn.EQ(Float(1.11)), `(sub_query."col_float" = $1)`, float64(1.11))
assertProjectionSerialize(t, floatColumn, `sub_query.col_float AS "sub_query.col_float"`) assertProjectionSerialize(t, floatColumn, `sub_query."col_float" AS "col_float"`)
floatColumn2 := table1ColFloat.From(subQuery) floatColumn2 := table1ColFloat.From(subQuery)
assertClauseSerialize(t, floatColumn2, `sub_query."table1.col_float"`) assertClauseSerialize(t, floatColumn2, `sub_query."table1.col_float"`)
assertClauseSerialize(t, floatColumn2.EQ(Float(2.22)), `(sub_query."table1.col_float" = $1)`, float64(2.22)) assertClauseSerialize(t, floatColumn2.EQ(Float(2.22)), `(sub_query."table1.col_float" = $1)`, float64(2.22))
assertProjectionSerialize(t, floatColumn2, `sub_query."table1.col_float" AS "sub_query.table1.col_float"`) assertProjectionSerialize(t, floatColumn2, `sub_query."table1.col_float" AS "table1.col_float"`)
} }

View file

@ -44,6 +44,10 @@ type expressionInterfaceImpl struct {
parent Expression parent Expression
} }
func (e *expressionInterfaceImpl) from(subQuery ExpressionTable) projection {
return e.parent
}
func (e *expressionInterfaceImpl) IS_NULL() BoolExpression { func (e *expressionInterfaceImpl) IS_NULL() BoolExpression {
return newPostifxBoolExpression(e.parent, "IS NULL") return newPostifxBoolExpression(e.parent, "IS NULL")
} }
@ -61,7 +65,7 @@ func (e *expressionInterfaceImpl) NOT_IN(expressions ...Expression) BoolExpressi
} }
func (e *expressionInterfaceImpl) AS(alias string) projection { func (e *expressionInterfaceImpl) AS(alias string) projection {
return NewAlias(e.parent, alias) return newAlias(e.parent, alias)
} }
func (e *expressionInterfaceImpl) ASC() OrderByClause { func (e *expressionInterfaceImpl) ASC() OrderByClause {

View file

@ -6,19 +6,29 @@ type ExpressionTable interface {
ReadableTable ReadableTable
Alias() string Alias() string
AllColumns() ProjectionList
} }
type expressionTableImpl struct { type expressionTableImpl struct {
readableTableInterfaceImpl readableTableInterfaceImpl
expression Expression expression Expression
alias string alias string
projections []projection
} }
func newExpressionTable(expression Expression, alias string) ExpressionTable { func newExpressionTable(expression Expression, alias string, projections []projection) ExpressionTable {
expTable := &expressionTableImpl{expression: expression, alias: alias} expTable := &expressionTableImpl{expression: expression, alias: alias}
expTable.readableTableInterfaceImpl.parent = expTable expTable.readableTableInterfaceImpl.parent = expTable
for _, projection := range projections {
newProjection := projection.from(expTable)
expTable.projections = append(expTable.projections, newProjection)
}
return expTable return expTable
} }
@ -26,6 +36,14 @@ func (e *expressionTableImpl) Alias() string {
return e.alias return e.alias
} }
func (e *expressionTableImpl) columns() []Column {
return nil
}
func (e *expressionTableImpl) AllColumns() ProjectionList {
return e.projections
}
func (e *expressionTableImpl) serialize(statement statementType, out *queryData, options ...serializeOption) error { func (e *expressionTableImpl) serialize(statement statementType, out *queryData, options ...serializeOption) error {
if e == nil { if e == nil {
return errors.New("Expression table is nil. ") return errors.New("Expression table is nil. ")

View file

@ -2,4 +2,29 @@ package sqlbuilder
type projection interface { type projection interface {
serializeForProjection(statement statementType, out *queryData) error serializeForProjection(statement statementType, out *queryData) error
from(subQuery ExpressionTable) projection
}
//------------------------------------------------------//
// Dummy type for projection list
type ProjectionList []projection
func (cl ProjectionList) from(subQuery ExpressionTable) projection {
newProjectionList := ProjectionList{}
for _, projection := range cl {
newProjectionList = append(newProjectionList, projection.from(subQuery))
}
return newProjectionList
}
func (cl ProjectionList) serializeForProjection(statement statementType, out *queryData) error {
err := serializeProjectionList(statement, cl, out)
if err != nil {
return err
}
return nil
} }

View file

@ -2,9 +2,5 @@ package sqlbuilder
type rowsType interface { type rowsType interface {
clause clause
hasRows() projections() []projection
} }
type isRowsType struct{}
func (i *isRowsType) hasRows() {}

View file

@ -16,7 +16,6 @@ var (
type SelectStatement interface { type SelectStatement interface {
Statement Statement
Expression Expression
hasRows()
DISTINCT() SelectStatement DISTINCT() SelectStatement
FROM(table ReadableTable) SelectStatement FROM(table ReadableTable) SelectStatement
@ -31,6 +30,8 @@ type SelectStatement interface {
FOR(lock SelectLock) SelectStatement FOR(lock SelectLock) SelectStatement
AsTable(alias string) ExpressionTable AsTable(alias string) ExpressionTable
projections() []projection
} }
func SELECT(projection1 projection, projections ...projection) SelectStatement { func SELECT(projection1 projection, projections ...projection) SelectStatement {
@ -39,11 +40,10 @@ func SELECT(projection1 projection, projections ...projection) SelectStatement {
type selectStatementImpl struct { type selectStatementImpl struct {
expressionInterfaceImpl expressionInterfaceImpl
isRowsType
table ReadableTable table ReadableTable
distinct bool distinct bool
projections []projection projectionList []projection
where BoolExpression where BoolExpression
groupBy []groupByClause groupBy []groupByClause
having BoolExpression having BoolExpression
@ -57,7 +57,7 @@ type selectStatementImpl struct {
func newSelectStatement(table ReadableTable, projections []projection) SelectStatement { func newSelectStatement(table ReadableTable, projections []projection) SelectStatement {
newSelect := &selectStatementImpl{ newSelect := &selectStatementImpl{
table: table, table: table,
projections: projections, projectionList: projections,
limit: -1, limit: -1,
offset: -1, offset: -1,
distinct: false, distinct: false,
@ -105,11 +105,11 @@ func (s *selectStatementImpl) serializeImpl(out *queryData) error {
out.writeString("DISTINCT") out.writeString("DISTINCT")
} }
if len(s.projections) == 0 { if len(s.projectionList) == 0 {
return errors.New("no column selected for projection") return errors.New("no column selected for projection")
} }
err := out.writeProjections(select_statement, s.projections) err := out.writeProjections(select_statement, s.projectionList)
if err != nil { if err != nil {
return err return err
@ -196,8 +196,12 @@ func (s *selectStatementImpl) DebugSql() (query string, err error) {
return DebugSql(s) return DebugSql(s)
} }
func (s *selectStatementImpl) projections() []projection {
return s.projectionList
}
func (s *selectStatementImpl) AsTable(alias string) ExpressionTable { func (s *selectStatementImpl) AsTable(alias string) ExpressionTable {
return newExpressionTable(s.parent, alias) return newExpressionTable(s.parent, alias, s.projectionList)
} }
func (s *selectStatementImpl) WHERE(expression BoolExpression) SelectStatement { func (s *selectStatementImpl) WHERE(expression BoolExpression) SelectStatement {
@ -216,9 +220,7 @@ func (s *selectStatementImpl) HAVING(expression BoolExpression) SelectStatement
} }
func (s *selectStatementImpl) ORDER_BY(clauses ...OrderByClause) SelectStatement { func (s *selectStatementImpl) ORDER_BY(clauses ...OrderByClause) SelectStatement {
s.orderBy = clauses s.orderBy = clauses
return s return s
} }

View file

@ -9,13 +9,14 @@ import (
type SetStatement interface { type SetStatement interface {
Statement Statement
Expression Expression
hasRows()
ORDER_BY(clauses ...OrderByClause) SetStatement ORDER_BY(clauses ...OrderByClause) SetStatement
LIMIT(limit int64) SetStatement LIMIT(limit int64) SetStatement
OFFSET(offset int64) SetStatement OFFSET(offset int64) SetStatement
AsTable(alias string) ExpressionTable AsTable(alias string) ExpressionTable
projections() []projection
} }
const ( const (
@ -51,7 +52,6 @@ func EXCEPT_ALL(selects ...rowsType) SetStatement {
// Similar to selectStatementImpl, but less complete // Similar to selectStatementImpl, but less complete
type setStatementImpl struct { type setStatementImpl struct {
expressionInterfaceImpl expressionInterfaceImpl
isRowsType
operator string operator string
selects []rowsType selects []rowsType
@ -75,23 +75,30 @@ func newSetStatementImpl(operator string, all bool, selects ...rowsType) SetStat
return setStatement return setStatement
} }
func (us *setStatementImpl) ORDER_BY(orderBy ...OrderByClause) SetStatement { func (s *setStatementImpl) ORDER_BY(orderBy ...OrderByClause) SetStatement {
us.orderBy = orderBy s.orderBy = orderBy
return us return s
} }
func (us *setStatementImpl) LIMIT(limit int64) SetStatement { func (s *setStatementImpl) LIMIT(limit int64) SetStatement {
us.limit = limit s.limit = limit
return us return s
} }
func (us *setStatementImpl) OFFSET(offset int64) SetStatement { func (s *setStatementImpl) OFFSET(offset int64) SetStatement {
us.offset = offset s.offset = offset
return us return s
} }
func (us *setStatementImpl) AsTable(alias string) ExpressionTable { func (s *setStatementImpl) projections() []projection {
return newExpressionTable(us.parent, alias) if len(s.selects) > 0 {
return s.selects[0].projections()
}
return []projection{}
}
func (s *setStatementImpl) AsTable(alias string) ExpressionTable {
return newExpressionTable(s.parent, alias, s.projections())
} }
func (s *setStatementImpl) serialize(statement statementType, out *queryData, options ...serializeOption) error { func (s *setStatementImpl) serialize(statement statementType, out *queryData, options ...serializeOption) error {
@ -178,10 +185,10 @@ func (s *setStatementImpl) serializeImpl(out *queryData) error {
return nil return nil
} }
func (us *setStatementImpl) Sql() (query string, args []interface{}, err error) { func (s *setStatementImpl) Sql() (query string, args []interface{}, err error) {
queryData := &queryData{} queryData := &queryData{}
err = us.serializeImpl(queryData) err = s.serializeImpl(queryData)
if err != nil { if err != nil {
return return
@ -199,6 +206,6 @@ func (s *setStatementImpl) Query(db execution.Db, destination interface{}) error
return Query(s, db, destination) return Query(s, db, destination)
} }
func (u *setStatementImpl) Exec(db execution.Db) (res sql.Result, err error) { func (s *setStatementImpl) Exec(db execution.Db) (res sql.Result, err error) {
return Exec(u, db) return Exec(s, db)
} }

View file

@ -24,6 +24,8 @@ type readableTable interface {
// Creates a cross join tableName Expression using onCondition. // Creates a cross join tableName Expression using onCondition.
CROSS_JOIN(table ReadableTable) ReadableTable CROSS_JOIN(table ReadableTable) ReadableTable
columns() []Column
} }
// The sql tableName write interface. // The sql tableName write interface.
@ -111,7 +113,7 @@ func NewTable(schemaName, name string, columns ...Column) Table {
t := &tableImpl{ t := &tableImpl{
schemaName: schemaName, schemaName: schemaName,
name: name, name: name,
columns: columns, columnList: columns,
} }
for _, c := range columns { for _, c := range columns {
c.setTableName(name) c.setTableName(name)
@ -130,13 +132,13 @@ type tableImpl struct {
schemaName string schemaName string
name string name string
alias string alias string
columns []Column columnList []Column
} }
func (t *tableImpl) AS(alias string) { func (t *tableImpl) AS(alias string) {
t.alias = alias t.alias = alias
for _, c := range t.columns { for _, c := range t.columnList {
c.setTableName(alias) c.setTableName(alias)
} }
} }
@ -151,8 +153,8 @@ func (t *tableImpl) TableName() string {
return t.name return t.name
} }
func (t *tableImpl) SchemaTableName() string { func (t *tableImpl) columns() []Column {
return t.schemaName return t.columnList
} }
func (t *tableImpl) serialize(statement statementType, out *queryData, options ...serializeOption) error { func (t *tableImpl) serialize(statement statementType, out *queryData, options ...serializeOption) error {
@ -218,6 +220,10 @@ func (t *joinTable) TableName() string {
return "" return ""
} }
func (t *joinTable) columns() []Column {
return append(t.lhs.columns(), t.rhs.columns()...)
}
func (t *joinTable) serialize(statement statementType, out *queryData, options ...serializeOption) (err error) { func (t *joinTable) serialize(statement statementType, out *queryData, options ...serializeOption) (err error) {
if t == nil { if t == nil {
return errors.New("Join table is nil. ") return errors.New("Join table is nil. ")

View file

@ -356,6 +356,185 @@ func TestTimeOperators(t *testing.T) {
assert.NilError(t, err) assert.NilError(t, err)
} }
func TestSubQueryColumnReference(t *testing.T) {
type expected struct {
sql string
args []interface{}
}
subQueries := map[ExpressionTable]expected{}
selectSubQuery := AllTypes.SELECT(
AllTypes.Boolean,
AllTypes.Integer,
AllTypes.Real,
AllTypes.Text,
AllTypes.Time,
AllTypes.Timez,
AllTypes.Timestamp,
AllTypes.Timestampz,
AllTypes.Date,
AllTypes.Bytea.AS("aliasedColumn"),
).
LIMIT(2).
AsTable("subQuery")
var selectExpectedSql = ` (
SELECT all_types.boolean AS "all_types.boolean",
all_types.integer AS "all_types.integer",
all_types.real AS "all_types.real",
all_types.text AS "all_types.text",
all_types.time AS "all_types.time",
all_types.timez AS "all_types.timez",
all_types.timestamp AS "all_types.timestamp",
all_types.timestampz AS "all_types.timestampz",
all_types.date AS "all_types.date",
all_types.bytea AS "aliasedColumn"
FROM test_sample.all_types
LIMIT 2
) AS "subQuery"`
unionSubQuery :=
UNION_ALL(
AllTypes.SELECT(
AllTypes.Boolean,
AllTypes.Integer,
AllTypes.Real,
AllTypes.Text,
AllTypes.Time,
AllTypes.Timez,
AllTypes.Timestamp,
AllTypes.Timestampz,
AllTypes.Date,
AllTypes.Bytea.AS("aliasedColumn"),
).
LIMIT(1),
AllTypes.SELECT(
AllTypes.Boolean,
AllTypes.Integer,
AllTypes.Real,
AllTypes.Text,
AllTypes.Time,
AllTypes.Timez,
AllTypes.Timestamp,
AllTypes.Timestampz,
AllTypes.Date,
AllTypes.Bytea.AS("aliasedColumn"),
).
LIMIT(1).OFFSET(1),
).
AsTable("subQuery")
unionExpectedSql := `
(
(
SELECT all_types.boolean AS "all_types.boolean",
all_types.integer AS "all_types.integer",
all_types.real AS "all_types.real",
all_types.text AS "all_types.text",
all_types.time AS "all_types.time",
all_types.timez AS "all_types.timez",
all_types.timestamp AS "all_types.timestamp",
all_types.timestampz AS "all_types.timestampz",
all_types.date AS "all_types.date",
all_types.bytea AS "aliasedColumn"
FROM test_sample.all_types
LIMIT 1
)
UNION ALL
(
SELECT all_types.boolean AS "all_types.boolean",
all_types.integer AS "all_types.integer",
all_types.real AS "all_types.real",
all_types.text AS "all_types.text",
all_types.time AS "all_types.time",
all_types.timez AS "all_types.timez",
all_types.timestamp AS "all_types.timestamp",
all_types.timestampz AS "all_types.timestampz",
all_types.date AS "all_types.date",
all_types.bytea AS "aliasedColumn"
FROM test_sample.all_types
LIMIT 1
OFFSET 1
)
) AS "subQuery"`
subQueries[selectSubQuery] = expected{sql: selectExpectedSql, args: []interface{}{int64(2)}}
subQueries[unionSubQuery] = expected{sql: unionExpectedSql, args: []interface{}{int64(1), int64(1), int64(1)}}
for subQuery, expected := range subQueries {
boolColumn := AllTypes.Boolean.From(subQuery)
intColumn := AllTypes.Integer.From(subQuery)
floatColumn := AllTypes.Real.From(subQuery)
stringColumn := AllTypes.Text.From(subQuery)
timeColumn := AllTypes.Time.From(subQuery)
timezColumn := AllTypes.Timez.From(subQuery)
timestampColumn := AllTypes.Timestamp.From(subQuery)
timestampzColumn := AllTypes.Timestampz.From(subQuery)
dateColumn := AllTypes.Date.From(subQuery)
aliasedColumn := StringColumn("aliasedColumn").From(subQuery)
stmt1 := SELECT(
boolColumn,
intColumn,
floatColumn,
stringColumn,
timeColumn,
timezColumn,
timestampColumn,
timestampzColumn,
dateColumn,
aliasedColumn,
).
FROM(subQuery)
var expectedSql = `
SELECT "subQuery"."all_types.boolean" AS "all_types.boolean",
"subQuery"."all_types.integer" AS "all_types.integer",
"subQuery"."all_types.real" AS "all_types.real",
"subQuery"."all_types.text" AS "all_types.text",
"subQuery"."all_types.time" AS "all_types.time",
"subQuery"."all_types.timez" AS "all_types.timez",
"subQuery"."all_types.timestamp" AS "all_types.timestamp",
"subQuery"."all_types.timestampz" AS "all_types.timestampz",
"subQuery"."all_types.date" AS "all_types.date",
"subQuery"."aliasedColumn" AS "aliasedColumn"
FROM`
assertStatementSql(t, stmt1, expectedSql+expected.sql+";\n", expected.args...)
dest1 := []model.AllTypes{}
err := stmt1.Query(db, &dest1)
assert.NilError(t, err)
assert.Equal(t, len(dest1), 2)
assert.Equal(t, dest1[0].Boolean, allTypesRow0.Boolean)
assert.Equal(t, dest1[0].Integer, allTypesRow0.Integer)
assert.Equal(t, dest1[0].Real, allTypesRow0.Real)
assert.Equal(t, dest1[0].Text, allTypesRow0.Text)
assert.DeepEqual(t, dest1[0].Time, allTypesRow0.Time)
assert.DeepEqual(t, dest1[0].Timez, allTypesRow0.Timez)
assert.DeepEqual(t, dest1[0].Timestamp, allTypesRow0.Timestamp)
assert.DeepEqual(t, dest1[0].Timestampz, allTypesRow0.Timestampz)
assert.DeepEqual(t, dest1[0].Date, allTypesRow0.Date)
stmt2 := SELECT(
subQuery.AllColumns(),
).
FROM(subQuery)
fmt.Println(stmt2.DebugSql())
assertStatementSql(t, stmt2, expectedSql+expected.sql+";\n", expected.args...)
dest2 := []model.AllTypes{}
err = stmt2.Query(db, &dest2)
assert.NilError(t, err)
assert.DeepEqual(t, dest1, dest2)
}
}
var allTypesRow0 = model.AllTypes{ var allTypesRow0 = model.AllTypes{
SmallintPtr: int16Ptr(1), SmallintPtr: int16Ptr(1),
Smallint: 1, Smallint: 1,

View file

@ -3,6 +3,7 @@ package tests
import ( import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/davecgh/go-spew/spew"
. "github.com/go-jet/jet/sqlbuilder" . "github.com/go-jet/jet/sqlbuilder"
"github.com/go-jet/jet/tests/.test_files/dvd_rental/chinook/model" "github.com/go-jet/jet/tests/.test_files/dvd_rental/chinook/model"
. "github.com/go-jet/jet/tests/.test_files/dvd_rental/chinook/table" . "github.com/go-jet/jet/tests/.test_files/dvd_rental/chinook/table"
@ -151,6 +152,65 @@ ORDER BY "Album.AlbumId";
assert.DeepEqual(t, dest[1], album2) assert.DeepEqual(t, dest[1], album2)
} }
func TestSubQueriesForQuotedNames(t *testing.T) {
first10Artist := Artist.
SELECT(Artist.AllColumns).
ORDER_BY(Artist.ArtistId).
LIMIT(10).
AsTable("first10Artist")
artistId := Artist.ArtistId.From(first10Artist)
first10Albums := Album.
SELECT(Album.AllColumns).
ORDER_BY(Album.AlbumId).
LIMIT(10).
AsTable("first10Albums")
albumArtistId := Album.ArtistId.From(first10Albums)
stmt := first10Artist.
INNER_JOIN(first10Albums, artistId.EQ(albumArtistId)).
SELECT(first10Artist.AllColumns(), first10Albums.AllColumns()).
ORDER_BY(artistId)
assertStatementSql(t, stmt, `
SELECT "first10Artist"."Artist.ArtistId" AS "Artist.ArtistId",
"first10Artist"."Artist.Name" AS "Artist.Name",
"first10Albums"."Album.AlbumId" AS "Album.AlbumId",
"first10Albums"."Album.Title" AS "Album.Title",
"first10Albums"."Album.ArtistId" AS "Album.ArtistId"
FROM (
SELECT "Artist"."ArtistId" AS "Artist.ArtistId",
"Artist"."Name" AS "Artist.Name"
FROM chinook."Artist"
ORDER BY "Artist"."ArtistId"
LIMIT 10
) AS "first10Artist"
INNER JOIN (
SELECT "Album"."AlbumId" AS "Album.AlbumId",
"Album"."Title" AS "Album.Title",
"Album"."ArtistId" AS "Album.ArtistId"
FROM chinook."Album"
ORDER BY "Album"."AlbumId"
LIMIT 10
) AS "first10Albums" ON ("first10Artist"."Artist.ArtistId" = "first10Albums"."Album.ArtistId")
ORDER BY "first10Artist"."Artist.ArtistId";
`, int64(10), int64(10))
var dest []struct {
model.Artist
Album []model.Album
}
err := stmt.Query(db, &dest)
assert.NilError(t, err)
spew.Dump(dest)
}
func assertJson(t *testing.T, jsonFilePath string, data interface{}) { func assertJson(t *testing.T, jsonFilePath string, data interface{}) {
fileJsonData, err := ioutil.ReadFile(jsonFilePath) fileJsonData, err := ioutil.ReadFile(jsonFilePath)
assert.NilError(t, err) assert.NilError(t, err)

View file

@ -907,7 +907,7 @@ SELECT customer.customer_id AS "customer.customer_id",
customer.create_date AS "customer.create_date", customer.create_date AS "customer.create_date",
customer.last_update AS "customer.last_update", customer.last_update AS "customer.last_update",
customer.active AS "customer.active", customer.active AS "customer.active",
customer_payment_sum.amount_sum AS "customer_with_amounts.amount_sum" customer_payment_sum."amount_sum" AS "CustomerWithAmounts.AmountSum"
FROM dvds.customer FROM dvds.customer
INNER JOIN ( INNER JOIN (
SELECT payment.customer_id AS "payment.customer_id", SELECT payment.customer_id AS "payment.customer_id",
@ -915,7 +915,7 @@ FROM dvds.customer
FROM dvds.payment FROM dvds.payment
GROUP BY payment.customer_id GROUP BY payment.customer_id
) AS customer_payment_sum ON (customer.customer_id = customer_payment_sum."payment.customer_id") ) AS customer_payment_sum ON (customer.customer_id = customer_payment_sum."payment.customer_id")
ORDER BY customer_payment_sum.amount_sum ASC; ORDER BY customer_payment_sum."amount_sum" ASC;
` `
customersPayments := Payment. customersPayments := Payment.
@ -933,7 +933,7 @@ ORDER BY customer_payment_sum.amount_sum ASC;
INNER_JOIN(customersPayments, Customer.CustomerID.EQ(customerId)). INNER_JOIN(customersPayments, Customer.CustomerID.EQ(customerId)).
SELECT( SELECT(
Customer.AllColumns, Customer.AllColumns,
amountSum.AS("customer_with_amounts.amount_sum"), amountSum.AS("CustomerWithAmounts.AmountSum"),
). ).
ORDER_BY(amountSum.ASC()) ORDER_BY(amountSum.ASC())

View file

@ -17,6 +17,7 @@ func assertStatementSql(t *testing.T, query sqlbuilder.Statement, expectedQuery
assert.DeepEqual(t, args, expectedArgs) assert.DeepEqual(t, args, expectedArgs)
debuqSql, err := query.DebugSql() debuqSql, err := query.DebugSql()
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, debuqSql, expectedQuery) assert.Equal(t, debuqSql, expectedQuery)
} }