Add StringColumn type and expression

Add Projection type
Alias refactoring
More numeric operations
This commit is contained in:
zer0sub 2019-04-03 11:03:07 +02:00
parent 033ab1d0da
commit b2f84d048c
16 changed files with 350 additions and 199 deletions

View file

@ -3,24 +3,18 @@ package sqlbuilder
import "bytes"
type Alias struct {
Clause
expression Expression
alias string
}
func NewAlias(expression Expression, alias string) *Alias {
if !validIdentifierName(alias) {
panic("Invalid alias")
}
return &Alias{
expression: expression,
alias: alias,
}
}
func (a *Alias) SerializeSql(out *bytes.Buffer, options ...serializeOption) error {
func (a *Alias) SerializeForProjection(out *bytes.Buffer) error {
err := a.expression.SerializeSql(out, ALIASED)

View file

@ -31,15 +31,15 @@ func (b *boolInterfaceImpl) Eq(expression BoolExpression) BoolExpression {
}
func (b *boolInterfaceImpl) NotEq(expression BoolExpression) BoolExpression {
return Neq(b.parent, expression)
return NotEq(b.parent, expression)
}
func (b *boolInterfaceImpl) GtEq(rhs Expression) BoolExpression {
return Gte(b.parent, rhs)
return GtEq(b.parent, rhs)
}
func (b *boolInterfaceImpl) LtEq(rhs Expression) BoolExpression {
return Lte(b.parent, rhs)
return LtEq(b.parent, rhs)
}
func (b *boolInterfaceImpl) And(expression BoolExpression) BoolExpression {
@ -196,7 +196,7 @@ func EqL(lhs Expression, val interface{}) BoolExpression {
}
// Returns a representation of "a!=b"
func Neq(lhs, rhs Expression) BoolExpression {
func NotEq(lhs, rhs Expression) BoolExpression {
lit, ok := rhs.(*literalExpression)
if ok && sqltypes.Value(lit.value).IsNull() {
return newBinaryBoolExpression(lhs, rhs, []byte(" IS NOT "))
@ -206,7 +206,7 @@ func Neq(lhs, rhs Expression) BoolExpression {
// Returns a representation of "a!=b", where b is a literal
func NeqL(lhs Expression, val interface{}) BoolExpression {
return Neq(lhs, Literal(val))
return NotEq(lhs, Literal(val))
}
// Returns a representation of "a<b"
@ -220,13 +220,13 @@ func LtL(lhs Expression, val interface{}) BoolExpression {
}
// Returns a representation of "a<=b"
func Lte(lhs, rhs Expression) BoolExpression {
func LtEq(lhs, rhs Expression) BoolExpression {
return newBinaryBoolExpression(lhs, rhs, []byte("<="))
}
// Returns a representation of "a<=b", where b is a literal
func LteL(lhs Expression, val interface{}) BoolExpression {
return Lte(lhs, Literal(val))
return LtEq(lhs, Literal(val))
}
// Returns a representation of "a>b"
@ -240,13 +240,13 @@ func GtL(lhs Expression, val interface{}) BoolExpression {
}
// Returns a representation of "a>=b"
func Gte(lhs, rhs Expression) BoolExpression {
func GtEq(lhs, rhs Expression) BoolExpression {
return newBinaryBoolExpression(lhs, rhs, []byte(">="))
}
// Returns a representation of "a>=b", where b is a literal
func GteL(lhs Expression, val interface{}) BoolExpression {
return Gte(lhs, Literal(val))
return GtEq(lhs, Literal(val))
}
// Returns a representation of "not expr"

View file

@ -19,21 +19,6 @@ type Column interface {
// Internal function for tracking tableName that a column belongs to
// for the purpose of serialization
setTableName(table string) error
Asc() OrderByClause
Desc() OrderByClause
}
type columnInterfaceImpl struct {
parent Column
}
func (c *columnInterfaceImpl) Asc() OrderByClause {
return &orderByClause{expression: c.parent, ascent: true}
}
func (c *columnInterfaceImpl) Desc() OrderByClause {
return &orderByClause{expression: c.parent, ascent: false}
}
type NullableColumn bool
@ -66,7 +51,6 @@ const (
// The base type for real materialized columns.
type baseColumn struct {
expressionInterfaceImpl
columnInterfaceImpl
name string
nullable NullableColumn
@ -81,7 +65,6 @@ func newBaseColumn(name string, nullable NullableColumn, tableName string, paren
}
bc.expressionInterfaceImpl.parent = parent
bc.columnInterfaceImpl.parent = parent
return bc
}

View file

@ -50,10 +50,6 @@ type IntegerColumn struct {
// Representation of any integer column
// This function will panic if name is not valid
func NewIntegerColumn(name string, nullable NullableColumn) *IntegerColumn {
if !validIdentifierName(name) {
panic("Invalid column name")
}
integerColumn := &IntegerColumn{}
integerColumn.numericInterfaceImpl.parent = integerColumn
@ -63,3 +59,27 @@ func NewIntegerColumn(name string, nullable NullableColumn) *IntegerColumn {
return integerColumn
}
//------------------------------------------------------//
type StringColumn struct {
stringInterfaceImpl
baseColumn
}
// Representation of any integer column
// This function will panic if name is not valid
func NewStringColumn(name string, nullable NullableColumn) *StringColumn {
if !validIdentifierName(name) {
panic("Invalid column name")
}
stringColumn := &StringColumn{}
stringColumn.stringInterfaceImpl.parent = stringColumn
stringColumn.stringInterfaceImpl.parent = stringColumn
stringColumn.baseColumn = newBaseColumn(name, nullable, "", stringColumn)
return stringColumn
}

View file

@ -9,17 +9,20 @@ import (
// An expression
type Expression interface {
Clause
Projection
As(alias string) Clause
As(alias string) Projection
IsDistinct(expression Expression) BoolExpression
IsNull() BoolExpression
Asc() OrderByClause
Desc() OrderByClause
}
type expressionInterfaceImpl struct {
parent Expression
}
func (e *expressionInterfaceImpl) As(alias string) Clause {
func (e *expressionInterfaceImpl) As(alias string) Projection {
return NewAlias(e.parent, alias)
}
@ -31,6 +34,18 @@ func (e *expressionInterfaceImpl) IsNull() BoolExpression {
return nil
}
func (e *expressionInterfaceImpl) Asc() OrderByClause {
return &orderByClause{expression: e.parent, ascent: true}
}
func (e *expressionInterfaceImpl) Desc() OrderByClause {
return &orderByClause{expression: e.parent, ascent: false}
}
func (e *expressionInterfaceImpl) SerializeForProjection(out *bytes.Buffer) error {
return e.parent.SerializeSql(out, FOR_PROJECTION)
}
// Representation of binary operations (e.g. comparisons, arithmetic)
type binaryExpression struct {
lhs, rhs Expression
@ -150,32 +165,32 @@ func (c literalExpression) SerializeSql(out *bytes.Buffer, options ...serializeO
}
//------------------------------------------------------//
// Dummy type for select *
type ColumnList []Column
func (cl ColumnList) SerializeSql(out *bytes.Buffer, options ...serializeOption) error {
for i, column := range cl {
err := column.SerializeSql(out)
if err != nil {
return err
}
if i != len(cl)-1 {
out.WriteString(", ")
}
}
return nil
}
func (e ColumnList) As(alias string) Clause {
panic("Invalid usage")
}
func (e ColumnList) IsDistinct(expression Expression) BoolExpression {
panic("Invalid usage")
}
func (e ColumnList) IsNull(expression Expression) BoolExpression {
panic("Invalid usage")
}
//// Dummy type for select *
//type ColumnList []Column
//
//func (cl ColumnList) SerializeSql(out *bytes.Buffer, options ...serializeOption) error {
// for i, column := range cl {
// err := column.SerializeSql(out)
//
// if err != nil {
// return err
// }
//
// if i != len(cl)-1 {
// out.WriteString(", ")
// }
// }
// return nil
//}
//
//func (e ColumnList) As(alias string) Clause {
// panic("Invalid usage")
//}
//
//func (e ColumnList) IsDistinct(expression Expression) BoolExpression {
// panic("Invalid usage")
//}
//
//func (e ColumnList) IsNull(expression Expression) BoolExpression {
// panic("Invalid usage")
//}

View file

@ -2,22 +2,31 @@ package sqlbuilder
import "bytes"
type FuncExpression struct {
type FuncExpression interface {
Expression
}
type numericFunc struct {
expressionInterfaceImpl
numericInterfaceImpl
name string
expression Expression
alias string
}
func (f *FuncExpression) As(alias string) Clause {
newFuncExpression := *f
func NewNumericFunc(name string, expression Expression) NumericExpression {
numericFunc := &numericFunc{
name: name,
expression: expression,
}
newFuncExpression.alias = alias
numericFunc.expressionInterfaceImpl.parent = numericFunc
numericFunc.numericInterfaceImpl.parent = numericFunc
return &newFuncExpression
return numericFunc
}
func (f *FuncExpression) SerializeSql(out *bytes.Buffer, options ...serializeOption) error {
func (f *numericFunc) SerializeSql(out *bytes.Buffer, options ...serializeOption) error {
out.WriteString(f.name)
out.WriteString("(")
err := f.expression.SerializeSql(out)
@ -26,12 +35,6 @@ func (f *FuncExpression) SerializeSql(out *bytes.Buffer, options ...serializeOpt
}
out.WriteString(")")
if f.alias != "" {
out.WriteString(` AS "`)
out.WriteString(f.alias)
out.WriteString(`"`)
}
return nil
}
@ -39,16 +42,10 @@ func (f *FuncExpression) SerializeSql(out *bytes.Buffer, options ...serializeOpt
// return f.SerializeSql(out)
//}
func MAX(expression Expression) *FuncExpression {
return &FuncExpression{
name: "MAX",
expression: expression,
}
func MAX(expression NumericExpression) NumericExpression {
return NewNumericFunc("MAX", expression)
}
func SUM(expression Expression) *FuncExpression {
return &FuncExpression{
name: "SUM",
expression: expression,
}
func SUM(expression NumericExpression) NumericExpression {
return NewNumericFunc("SUM", expression)
}

View file

@ -1,6 +1,7 @@
package sqlbuilder
import (
"bytes"
"github.com/dropbox/godropbox/database/sqltypes"
"github.com/pkg/errors"
)
@ -9,9 +10,13 @@ type NumericExpression interface {
Expression
Eq(expression NumericExpression) BoolExpression
EqL(literal interface{}) BoolExpression
NotEq(expression NumericExpression) BoolExpression
NotEqL(literal interface{}) BoolExpression
GtEq(rhs NumericExpression) BoolExpression
GtEqL(literal interface{}) BoolExpression
LtEq(rhs NumericExpression) BoolExpression
LtEqL(literal interface{}) BoolExpression
Add(expression NumericExpression) NumericExpression
Sub(expression NumericExpression) NumericExpression
@ -27,16 +32,32 @@ func (n *numericInterfaceImpl) Eq(expression NumericExpression) BoolExpression {
return Eq(n.parent, expression)
}
func (n *numericInterfaceImpl) EqL(literal interface{}) BoolExpression {
return Eq(n.parent, Literal(literal))
}
func (n *numericInterfaceImpl) NotEq(expression NumericExpression) BoolExpression {
return Neq(n.parent, expression)
return NotEq(n.parent, expression)
}
func (n *numericInterfaceImpl) NotEqL(literal interface{}) BoolExpression {
return NotEq(n.parent, Literal(literal))
}
func (n *numericInterfaceImpl) GtEq(expression NumericExpression) BoolExpression {
return Gte(n.parent, expression)
return GtEq(n.parent, expression)
}
func (n *numericInterfaceImpl) GtEqL(literal interface{}) BoolExpression {
return GtEq(n.parent, Literal(literal))
}
func (n *numericInterfaceImpl) LtEq(expression NumericExpression) BoolExpression {
return Lte(n.parent, expression)
return LtEq(n.parent, expression)
}
func (n *numericInterfaceImpl) LtEqL(literal interface{}) BoolExpression {
return LtEq(n.parent, Literal(literal))
}
func (n *numericInterfaceImpl) Add(expression NumericExpression) NumericExpression {
@ -92,3 +113,30 @@ func newBinaryNumericExpression(lhs, rhs Expression, operator []byte) NumericExp
return &numericExpression
}
//---------------------------------------------------//
type numericExpressionWrapper struct {
expressionInterfaceImpl
numericInterfaceImpl
expression Expression
}
func newNumericExpressionWrap(expression Expression) NumericExpression {
numericExpressionWrap := numericExpressionWrapper{}
numericExpressionWrap.expression = expression
numericExpressionWrap.expressionInterfaceImpl.parent = &numericExpressionWrap
numericExpressionWrap.numericInterfaceImpl.parent = &numericExpressionWrap
return &numericExpressionWrap
}
func (c *numericExpressionWrapper) SerializeSql(out *bytes.Buffer, options ...serializeOption) (err error) {
out.WriteString("(")
err = c.expression.SerializeSql(out, options...)
out.WriteString(")")
return nil
}

26
sqlbuilder/projection.go Normal file
View file

@ -0,0 +1,26 @@
package sqlbuilder
import "bytes"
type Projection interface {
SerializeForProjection(out *bytes.Buffer) error
}
//------------------------------------------------------//
// Dummy type for select * AllColumns
type ColumnList []Column
func (cl ColumnList) SerializeForProjection(out *bytes.Buffer) error {
for i, column := range cl {
err := column.SerializeSql(out, FOR_PROJECTION)
if err != nil {
return err
}
if i != len(cl)-1 {
out.WriteString(", ")
}
}
return nil
}

View file

@ -36,7 +36,7 @@ type selectStatementImpl struct {
expressionInterfaceImpl
table ReadableTable
projections []Expression
projections []Projection
where BoolExpression
group *listClause
having BoolExpression
@ -50,7 +50,7 @@ type selectStatementImpl struct {
func newSelectStatement(
table ReadableTable,
projections []Expression) SelectStatement {
projections []Projection) SelectStatement {
return &selectStatementImpl{
table: table,
@ -210,7 +210,7 @@ func (q *selectStatementImpl) String() (sql string, err error) {
"nil column selected. Generated sql: %s",
buf.String())
}
if err = col.SerializeSql(buf, FOR_PROJECTION); err != nil {
if err = col.SerializeForProjection(buf); err != nil {
return
}
}
@ -267,3 +267,7 @@ func (q *selectStatementImpl) String() (sql string, err error) {
return buf.String(), nil
}
func NumExp(statement SelectStatement) NumericExpression {
return newNumericExpressionWrap(statement)
}

View file

@ -12,18 +12,24 @@ func (s *SelectStatementTable) Columns() []Column {
return s.columns
}
func (s *SelectStatementTable) Column(name string) Column {
return &baseColumn{
name: name,
tableName: s.alias,
}
func (s *SelectStatementTable) RefIntColumnName(name string) Column {
intColumn := NewIntegerColumn(name, NotNullable)
intColumn.setTableName(s.alias)
return intColumn
}
func (s *SelectStatementTable) ColumnFrom(column Column) Column {
return &baseColumn{
name: column.TableName() + "." + column.Name(),
tableName: s.alias,
}
func (s *SelectStatementTable) RefIntColumn(column Column) *IntegerColumn {
intColumn := NewIntegerColumn(column.TableName()+"."+column.Name(), NotNullable)
intColumn.setTableName(s.alias)
return intColumn
}
func (s *SelectStatementTable) RefStringColumn(column Column) *StringColumn {
strColumn := NewStringColumn(column.Name(), NotNullable)
strColumn.setTableName(column.TableName())
return strColumn
}
func (s *SelectStatementTable) SerializeSql(out *bytes.Buffer) error {
@ -43,17 +49,17 @@ func (s *SelectStatementTable) SerializeSql(out *bytes.Buffer) error {
}
// Generates a select query on the current tableName.
func (s *SelectStatementTable) Select(projections ...Expression) SelectStatement {
func (s *SelectStatementTable) SELECT(projections ...Projection) SelectStatement {
return newSelectStatement(s, projections)
}
// Creates a inner join tableName expression using onCondition.
func (s *SelectStatementTable) InnerJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable {
func (s *SelectStatementTable) INNER_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable {
return InnerJoinOn(s, table, onCondition)
}
//func (s *SelectStatementTable) InnerJoinUsing(table ReadableTable, col1 Column, col2 Column) ReadableTable {
// return InnerJoinOn(s, table, col1.Eq(col2))
// return INNER_JOIN(s, table, col1.Eq(col2))
//}
// Creates a left join tableName expression using onCondition.
@ -66,7 +72,7 @@ func (s *SelectStatementTable) RightJoinOn(table ReadableTable, onCondition Bool
return RightJoinOn(s, table, onCondition)
}
func (s *SelectStatementTable) FullJoin(table ReadableTable, onCondition BoolExpression) ReadableTable {
func (s *SelectStatementTable) FULL_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable {
return FullJoin(s, table, onCondition)
}

View file

@ -178,7 +178,7 @@ func (us *unionStatementImpl) String() (sql string, err error) {
}
// Union statements in MySQL require that the same number of columns in each subquery
var projections []Expression
var projections []Projection
for _, statement := range us.selects {
// do a type assertion to get at the underlying struct

View file

@ -0,0 +1,25 @@
package sqlbuilder
type StringExpression interface {
Expression
Eq(expression StringExpression) BoolExpression
EqL(value string) BoolExpression
NotEq(expression StringExpression) BoolExpression
}
type stringInterfaceImpl struct {
parent StringExpression
}
func (b *stringInterfaceImpl) Eq(expression StringExpression) BoolExpression {
return newBinaryBoolExpression(b.parent, expression, []byte(" = "))
}
func (b *stringInterfaceImpl) EqL(value string) BoolExpression {
return newBinaryBoolExpression(b.parent, Literal(value), []byte(" = "))
}
func (b *stringInterfaceImpl) NotEq(expression StringExpression) BoolExpression {
return newBinaryBoolExpression(b.parent, expression, []byte(" != "))
}

View file

@ -14,17 +14,17 @@ type ReadableTable interface {
// Returns the list of columns that are in the current tableName expression.
Columns() []Column
Column(name string) Column
//Column(name string) Column
// Generates the sql string for the current tableName expression. Note: the
// generated string may not be a valid/executable sql statement.
SerializeSql(out *bytes.Buffer) error
// Generates a select query on the current tableName.
Select(projections ...Expression) SelectStatement
SELECT(projections ...Projection) SelectStatement
// Creates a inner join tableName expression using onCondition.
InnerJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable
INNER_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable
//InnerJoinUsing(table ReadableTable, col1 Column, col2 Column) ReadableTable
@ -34,7 +34,7 @@ type ReadableTable interface {
// Creates a right join tableName expression using onCondition.
RightJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable
FullJoin(table ReadableTable, onCondition BoolExpression) ReadableTable
FULL_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable
CrossJoin(table ReadableTable) ReadableTable
}
@ -181,12 +181,12 @@ func (t *Table) SerializeSql(out *bytes.Buffer) error {
}
// Generates a select query on the current tableName.
func (t *Table) Select(projections ...Expression) SelectStatement {
func (t *Table) SELECT(projections ...Projection) SelectStatement {
return newSelectStatement(t, projections)
}
// Creates a inner join tableName expression using onCondition.
func (t *Table) InnerJoinOn(
func (t *Table) INNER_JOIN(
table ReadableTable,
onCondition BoolExpression) ReadableTable {
@ -198,7 +198,7 @@ func (t *Table) InnerJoinOn(
// col1 Column,
// col2 Column) ReadableTable {
//
// return InnerJoinOn(t, table, col1.Eq(col2))
// return INNER_JOIN(t, table, col1.Eq(col2))
//}
// Creates a left join tableName expression using onCondition.
@ -217,7 +217,7 @@ func (t *Table) RightJoinOn(
return RightJoinOn(t, table, onCondition)
}
func (t *Table) FullJoin(table ReadableTable, onCondition BoolExpression) ReadableTable {
func (t *Table) FULL_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable {
return FullJoin(t, table, onCondition)
}
@ -363,11 +363,11 @@ func (t *joinTable) SerializeSql(out *bytes.Buffer) (err error) {
return nil
}
func (t *joinTable) Select(projections ...Expression) SelectStatement {
func (t *joinTable) SELECT(projections ...Projection) SelectStatement {
return newSelectStatement(t, projections)
}
func (t *joinTable) InnerJoinOn(
func (t *joinTable) INNER_JOIN(
table ReadableTable,
onCondition BoolExpression) ReadableTable {
@ -381,7 +381,7 @@ func (t *joinTable) LeftJoinOn(
return LeftJoinOn(t, table, onCondition)
}
func (t *joinTable) FullJoin(table ReadableTable, onCondition BoolExpression) ReadableTable {
func (t *joinTable) FULL_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable {
return FullJoin(t, table, onCondition)
}