Default aliasing refactoring.

This commit is contained in:
zer0sub 2019-05-03 12:51:57 +02:00
parent 22426c8cad
commit 5ad213885f
16 changed files with 198 additions and 124 deletions

View file

@ -11,6 +11,7 @@ type serializeOption int
const (
SKIP_DEFAULT_ALIASING = iota
FOR_PROJECTION
UNION_ORDER_BY
NO_TABLE_NAME
)
@ -21,6 +22,56 @@ type Clause interface {
type queryData struct {
buff bytes.Buffer
args []interface{}
statementType int
clauseType int
}
const (
select_statement = iota
insert_statement
update_statement
delete_statement
set_statement
)
const (
projection_clause = iota
where_clause
order_by_clause
group_by_clause
having_clause
)
func (q *queryData) WriteProjection(projections []Projection) error {
q.clauseType = projection_clause
return serializeProjectionList(projections, q)
}
func (q *queryData) WriteWhere(where Expression) error {
q.clauseType = where_clause
q.WriteString(" WHERE ")
return where.Serialize(q)
}
func (q *queryData) WriteGroupBy(groupBy []Clause) error {
q.clauseType = group_by_clause
q.WriteString(" GROUP BY ")
return serializeClauseList(groupBy, q)
}
func (q *queryData) WriteOrderBy(orderBy []OrderByClause) error {
q.clauseType = order_by_clause
q.WriteString(" ORDER BY ")
return serializeOrderByClauseList(orderBy, q)
}
func (q *queryData) WriteHaving(having Expression) error {
q.clauseType = having_clause
q.WriteString(" HAVING ")
return having.Serialize(q)
}
func (q *queryData) Write(data []byte) {

View file

@ -11,6 +11,8 @@ type Column interface {
Name() string
TableName() string
DefaultAlias() Projection
// Internal function for tracking tableName that a column belongs to
// for the purpose of serialization
setTableName(table string) error
@ -72,26 +74,41 @@ func (c *baseColumn) setTableName(table string) error {
return nil
}
func (c *baseColumn) DefaultAlias() Projection {
return c.As(c.tableName + "." + c.name)
}
func (c baseColumn) Serialize(out *queryData, options ...serializeOption) error {
if c.tableName != "" && !contains(options, NO_TABLE_NAME) {
setOrderBy := out.statementType == set_statement && out.clauseType == order_by_clause
if setOrderBy {
out.WriteString(`"`)
}
if c.tableName != "" {
out.WriteString(c.tableName)
out.WriteString(".")
}
containsDot := strings.Contains(c.name, ".")
wrapColumnName := strings.Contains(c.name, ".") && !setOrderBy
if containsDot {
if wrapColumnName {
out.WriteString(`"`)
}
out.WriteString(c.name)
if containsDot {
if wrapColumnName {
out.WriteString(`"`)
}
if contains(options, FOR_PROJECTION) && !contains(options, SKIP_DEFAULT_ALIASING) && c.tableName != "" {
out.WriteString(" AS \"" + c.tableName + "." + c.name + "\"")
//if contains(options, FOR_PROJECTION) && !contains(options, SKIP_DEFAULT_ALIASING) && c.tableName != "" {
// out.WriteString(" AS \"" + c.tableName + "." + c.name + `"`)
//}
if setOrderBy {
out.WriteString(`"`)
}
return nil

View file

@ -93,3 +93,15 @@ func NewTimeColumn(name string, nullable NullableColumn) *TimeColumn {
return stringColumn
}
// ------------------------------------------------------//
type refColumn struct {
baseColumn
}
func RefColumn(name string) *refColumn {
refColumn := &refColumn{}
refColumn.baseColumn = newBaseColumn(name, false, "", refColumn)
return refColumn
}

View file

@ -22,7 +22,7 @@ func TestNewBoolColumn(t *testing.T) {
out.Reset()
err = boolColumn.setTableName("table1")
assert.NilError(t, err)
err = boolColumn.Serialize(&out, FOR_PROJECTION)
err = boolColumn.DefaultAlias().SerializeForProjection(&out)
assert.NilError(t, err)
assert.Equal(t, out.buff.String(), `table1.col AS "table1.col"`)
@ -52,7 +52,7 @@ func TestNewIntColumn(t *testing.T) {
out.Reset()
err = integerColumn.setTableName("table1")
assert.NilError(t, err)
err = integerColumn.Serialize(&out, FOR_PROJECTION)
err = integerColumn.DefaultAlias().SerializeForProjection(&out)
assert.NilError(t, err)
assert.Equal(t, out.buff.String(), `table1.col AS "table1.col"`)
@ -82,7 +82,7 @@ func TestNewNumericColumnColumn(t *testing.T) {
out.Reset()
err = numericColumn.setTableName("table1")
assert.NilError(t, err)
err = numericColumn.Serialize(&out, FOR_PROJECTION)
err = numericColumn.DefaultAlias().SerializeForProjection(&out)
assert.NilError(t, err)
assert.Equal(t, out.buff.String(), `table1.col AS "table1.col"`)

View file

@ -21,15 +21,6 @@ func newDeleteStatement(table WritableTable) DeleteStatement {
type deleteStatementImpl struct {
table WritableTable
where BoolExpression
order *listClause
}
func (u *deleteStatementImpl) Query(db types.Db, destination interface{}) error {
return Query(u, db, destination)
}
func (u *deleteStatementImpl) Execute(db types.Db) (res sql.Result, err error) {
return Execute(u, db)
}
func (d *deleteStatementImpl) WHERE(expression BoolExpression) DeleteStatement {
@ -39,6 +30,7 @@ func (d *deleteStatementImpl) WHERE(expression BoolExpression) DeleteStatement {
func (d *deleteStatementImpl) Sql() (query string, args []interface{}, err error) {
queryData := &queryData{}
queryData.statementType = delete_statement
queryData.WriteString("DELETE FROM ")
@ -54,18 +46,17 @@ func (d *deleteStatementImpl) Sql() (query string, args []interface{}, err error
return "", nil, errors.New("Deleting without a WHERE clause.")
}
queryData.WriteString(" WHERE ")
if err = d.where.Serialize(queryData); err != nil {
if err = queryData.WriteWhere(d.where); err != nil {
return
}
if d.order != nil {
queryData.WriteString(" ORDER BY ")
if err = d.order.Serialize(queryData); err != nil {
return
}
}
return queryData.buff.String() + ";", queryData.args, nil
}
func (u *deleteStatementImpl) Query(db types.Db, destination interface{}) error {
return Query(u, db, destination)
}
func (u *deleteStatementImpl) Execute(db types.Db) (res sql.Result, err error) {
return Execute(u, db)
}

View file

@ -67,7 +67,7 @@ func isSimpleOperand(expression Expression) bool {
if _, ok := expression.(Column); ok {
return true
}
if _, ok := expression.(FuncExpression); ok {
if _, ok := expression.(*numericFunc); ok {
return true
}

View file

@ -1,8 +1,8 @@
package sqlbuilder
type FuncExpression interface {
Expression
}
//type FuncExpression interface {
// Expression
//}
type numericFunc struct {
expressionInterfaceImpl

View file

@ -111,7 +111,7 @@ func (i *insertStatementImpl) VALUES_MAPPING(data interface{}) InsertStatement {
}
func (i *insertStatementImpl) RETURNING(projections ...Projection) InsertStatement {
i.returning = projections
i.returning = defaultProjectionAliasing(projections)
return i
}
@ -132,6 +132,7 @@ func (s *insertStatementImpl) Sql() (sql string, args []interface{}, err error)
}
queryData := &queryData{}
queryData.statementType = insert_statement
queryData.WriteString("INSERT INTO ")
if s.table == nil {
@ -147,18 +148,6 @@ func (s *insertStatementImpl) Sql() (sql string, args []interface{}, err error)
if len(s.columns) > 0 {
queryData.WriteString(" (")
//for i, col := range s.columns {
// if i > 0 {
// queryData.WriteByte(',')
// }
//
// if col == nil {
// return "", nil, errors.New("nil column in columns list.")
// }
//
// queryData.WriteString(col.Name())
//}
err = serializeColumnList(s.columns, queryData)
if err != nil {
@ -193,19 +182,6 @@ func (s *insertStatementImpl) Sql() (sql string, args []interface{}, err error)
return "", nil, err
}
//for col_i, value := range row {
// if col_i > 0 {
// queryData.WriteByte(',')
// }
//
// if value == nil {
// return "", nil, errors.Newf("nil value in row %d col %d.", row_i, col_i)
// }
//
// if err = value.Serialize(queryData); err != nil {
// return
// }
//}
queryData.WriteByte(')')
}
}
@ -221,7 +197,7 @@ func (s *insertStatementImpl) Sql() (sql string, args []interface{}, err error)
if len(s.returning) > 0 {
queryData.WriteString(" RETURNING ")
err = serializeProjectionList(s.returning, queryData)
err = queryData.WriteProjection(s.returning)
if err != nil {
return

View file

@ -13,6 +13,15 @@ type isOrderByClause struct {
func (o *isOrderByClause) isOrderByClauseType() {
}
type ColumnNameOrderBy string
func (o *ColumnNameOrderBy) isOrderByClauseType() {
}
func (o *ColumnNameOrderBy) Serialize(out *queryData, options ...serializeOption) error {
return nil
}
type orderByClause struct {
isOrderByClause
expression Expression
@ -24,7 +33,7 @@ func (o *orderByClause) Serialize(out *queryData, options ...serializeOption) er
return errors.Newf("nil orderBy by clause.")
}
if err := o.expression.Serialize(out); err != nil {
if err := o.expression.Serialize(out, UNION_ORDER_BY); err != nil {
return err
}

View file

@ -22,3 +22,14 @@ func (cl ColumnList) SerializeForProjection(out *queryData) error {
}
return nil
}
func (cl ColumnList) DefaultAlias() []Projection {
newColumnList := []Projection{}
for _, column := range cl {
newColumn := column.DefaultAlias()
newColumnList = append(newColumnList, newColumn)
}
return newColumnList
}

View file

@ -33,7 +33,7 @@ type selectStatementImpl struct {
distinct bool
projections []Projection
where BoolExpression
groupBy []Clause
groupBy []Clause //can be ROLLUP, ... so clause for now
having BoolExpression
orderBy []OrderByClause
@ -42,13 +42,27 @@ type selectStatementImpl struct {
forUpdate bool
}
func newSelectStatement(
table ReadableTable,
projections []Projection) SelectStatement {
func defaultProjectionAliasing(projections []Projection) []Projection {
aliasedProjections := []Projection{}
for _, projection := range projections {
if column, ok := projection.(Column); ok {
aliasedProjections = append(aliasedProjections, column.DefaultAlias())
} else if columnList, ok := projection.(ColumnList); ok {
aliasedProjections = append(aliasedProjections, columnList.DefaultAlias()...)
} else {
aliasedProjections = append(aliasedProjections, projection)
}
}
return aliasedProjections
}
func newSelectStatement(table ReadableTable, projections []Projection) SelectStatement {
return &selectStatementImpl{
table: table,
projections: projections,
projections: defaultProjectionAliasing(projections),
limit: -1,
offset: -1,
forUpdate: false,
@ -74,6 +88,7 @@ func (s *selectStatementImpl) Serialize(out *queryData, options ...serializeOpti
func (s *selectStatementImpl) serializeImpl(out *queryData, options ...serializeOption) error {
out.WriteString("SELECT ")
out.statementType = select_statement
if s.distinct {
out.WriteString("DISTINCT ")
@ -83,7 +98,7 @@ func (s *selectStatementImpl) serializeImpl(out *queryData, options ...serialize
return errors.New("No column selected for projection.")
}
err := serializeProjectionList(s.projections, out)
err := out.WriteProjection(s.projections)
if err != nil {
return err
@ -100,16 +115,15 @@ func (s *selectStatementImpl) serializeImpl(out *queryData, options ...serialize
}
if s.where != nil {
out.WriteString(" WHERE ")
if err := s.where.Serialize(out); err != nil {
return err
err := out.WriteWhere(s.where)
if err != nil {
return nil
}
}
if s.groupBy != nil && len(s.groupBy) > 0 {
out.WriteString(" GROUP BY ")
err := serializeClauseList(s.groupBy, out)
err := out.WriteGroupBy(s.groupBy)
if err != nil {
return err
@ -117,15 +131,17 @@ func (s *selectStatementImpl) serializeImpl(out *queryData, options ...serialize
}
if s.having != nil {
out.WriteString(" HAVING ")
if err = s.having.Serialize(out); err != nil {
err := out.WriteHaving(s.having)
if err != nil {
return err
}
}
if s.orderBy != nil {
out.WriteString(" ORDER BY ")
if err := serializeOrderByClauseList(s.orderBy, out); err != nil {
err := out.WriteOrderBy(s.orderBy)
if err != nil {
return err
}
}

View file

@ -48,7 +48,7 @@ func EXCEPT_ALL(selects ...SelectStatement) SetStatement {
type setStatementImpl struct {
operator string
selects []SelectStatement
order *listClause
orderBy []OrderByClause
limit, offset int64
// True if results of the union should be deduped.
all bool
@ -64,9 +64,9 @@ func newSetStatementImpl(operator string, all bool, selects ...SelectStatement)
}
}
func (us *setStatementImpl) ORDER_BY(clauses ...OrderByClause) SetStatement {
func (us *setStatementImpl) ORDER_BY(orderBy ...OrderByClause) SetStatement {
us.order = newOrderByListClause(clauses...)
us.orderBy = orderBy
return us
}
@ -80,18 +80,19 @@ func (us *setStatementImpl) OFFSET(offset int64) SetStatement {
return us
}
func (us *setStatementImpl) Serialize(out *queryData, options ...serializeOption) error {
if len(us.selects) < 2 {
func (s *setStatementImpl) Serialize(out *queryData, options ...serializeOption) error {
if len(s.selects) < 2 {
return errors.Newf("UNION statement must have at least two SELECT statements.")
}
out.WriteString("(")
for i, selectStmt := range us.selects {
for i, selectStmt := range s.selects {
if i > 0 {
out.WriteString(" " + us.operator + " ")
out.WriteString(" " + s.operator + " ")
if us.all {
if s.all {
out.WriteString(" ALL ")
}
}
@ -105,21 +106,23 @@ func (us *setStatementImpl) Serialize(out *queryData, options ...serializeOption
out.WriteString(")")
if us.order != nil {
out.WriteString(" ORDER BY ")
if err := us.order.Serialize(out, NO_TABLE_NAME); err != nil {
out.statementType = set_statement
if s.orderBy != nil {
err := out.WriteOrderBy(s.orderBy)
if err != nil {
return err
}
}
if us.limit >= 0 {
if s.limit >= 0 {
out.WriteString(" LIMIT ")
out.InsertArgument(us.limit)
out.InsertArgument(s.limit)
}
if us.offset >= 0 {
if s.offset >= 0 {
out.WriteString(" OFFSET ")
out.InsertArgument(us.offset)
out.InsertArgument(s.offset)
}
return nil

View file

@ -7,11 +7,11 @@ import (
)
func TestUnionNoSelect(t *testing.T) {
query, args, err := UNION().Sql()
_, _, err := UNION().Sql()
assert.Assert(t, err != nil)
//fmt.Println(err.Error())
fmt.Print(query, args)
//fmt.Print(query, args)
}
func TestUnionOneSelect(t *testing.T) {
@ -55,7 +55,7 @@ func TestUnionWithOrderBy(t *testing.T) {
).ORDER_BY(table1Col1.Asc()).Sql()
assert.NilError(t, err)
assert.Equal(t, query, `((SELECT table1.col1 AS "table1.col1" FROM db.table1) UNION (SELECT table2.col3 AS "table2.col3" FROM db.table2)) ORDER BY table1.col1 ASC`)
assert.Equal(t, query, `((SELECT table1.col1 AS "table1.col1" FROM db.table1) UNION (SELECT table2.col3 AS "table2.col3" FROM db.table2)) ORDER BY "table1.col1" ASC`)
assert.Equal(t, len(args), 0)
}

View file

@ -29,14 +29,6 @@ type updateStatementImpl struct {
returning []Projection
}
func (u *updateStatementImpl) Query(db types.Db, destination interface{}) error {
return Query(u, db, destination)
}
func (u *updateStatementImpl) Execute(db types.Db) (res sql.Result, err error) {
return Execute(u, db)
}
func (u *updateStatementImpl) SET(values ...interface{}) UpdateStatement {
for _, value := range values {
@ -56,12 +48,14 @@ func (u *updateStatementImpl) WHERE(expression BoolExpression) UpdateStatement {
}
func (u *updateStatementImpl) RETURNING(projections ...Projection) UpdateStatement {
u.returning = projections
u.returning = defaultProjectionAliasing(projections)
return u
}
func (u *updateStatementImpl) Sql() (sql string, args []interface{}, err error) {
out := &queryData{}
out.statementType = update_statement
out.WriteString("UPDATE ")
if u.table == nil {
@ -84,18 +78,6 @@ func (u *updateStatementImpl) Sql() (sql string, args []interface{}, err error)
out.WriteString(" ")
}
//for i, column := range u.columns {
// if i > 0 {
// out.WriteString(", ")
// }
//
// out.WriteString(column.Name())
//
// if err != nil {
// return
// }
//}
err = serializeColumnList(u.columns, out)
if err != nil {
@ -132,8 +114,7 @@ func (u *updateStatementImpl) Sql() (sql string, args []interface{}, err error)
return "", nil, errors.New("Updating without a WHERE clause.")
}
out.WriteString(" WHERE ")
if err = u.where.Serialize(out); err != nil {
if err = out.WriteWhere(u.where); err != nil {
return
}
@ -149,3 +130,11 @@ func (u *updateStatementImpl) Sql() (sql string, args []interface{}, err error)
return out.buff.String(), out.args, nil
}
func (u *updateStatementImpl) Query(db types.Db, destination interface{}) error {
return Query(u, db, destination)
}
func (u *updateStatementImpl) Execute(db types.Db) (res sql.Result, err error) {
return Execute(u, db)
}

View file

@ -63,7 +63,7 @@ func serializeExpressionList(expressions []Expression, separator string, out *qu
func serializeProjectionList(projections []Projection, out *queryData) error {
for i, col := range projections {
if i > 0 {
out.WriteByte(',')
out.WriteString(", ")
}
if col == nil {
return errors.New("Projection expression is nil.")