Column types refactoring.

This commit is contained in:
zer0sub 2019-03-31 14:07:58 +02:00
parent 38007810c1
commit 033ab1d0da
19 changed files with 746 additions and 396 deletions

View file

@ -20,9 +20,9 @@ func NewAlias(expression Expression, alias string) *Alias {
} }
} }
func (a *Alias) SerializeSql(out *bytes.Buffer) error { func (a *Alias) SerializeSql(out *bytes.Buffer, options ...serializeOption) error {
err := a.expression.SerializeSql(out) err := a.expression.SerializeSql(out, ALIASED)
if err != nil { if err != nil {
return err return err

View file

@ -11,6 +11,11 @@ import (
type BoolExpression interface { type BoolExpression interface {
Expression Expression
Eq(expression BoolExpression) BoolExpression
NotEq(expression BoolExpression) BoolExpression
GtEq(rhs Expression) BoolExpression
LtEq(rhs Expression) BoolExpression
And(expression BoolExpression) BoolExpression And(expression BoolExpression) BoolExpression
Or(expression BoolExpression) BoolExpression Or(expression BoolExpression) BoolExpression
IsTrue() BoolExpression IsTrue() BoolExpression
@ -21,6 +26,22 @@ type boolInterfaceImpl struct {
parent BoolExpression parent BoolExpression
} }
func (b *boolInterfaceImpl) Eq(expression BoolExpression) BoolExpression {
return Eq(b.parent, expression)
}
func (b *boolInterfaceImpl) NotEq(expression BoolExpression) BoolExpression {
return Neq(b.parent, expression)
}
func (b *boolInterfaceImpl) GtEq(rhs Expression) BoolExpression {
return Gte(b.parent, rhs)
}
func (b *boolInterfaceImpl) LtEq(rhs Expression) BoolExpression {
return Lte(b.parent, rhs)
}
func (b *boolInterfaceImpl) And(expression BoolExpression) BoolExpression { func (b *boolInterfaceImpl) And(expression BoolExpression) BoolExpression {
return And(b.parent, expression) return And(b.parent, expression)
} }
@ -42,7 +63,7 @@ type boolLiteralExpression struct {
literalExpression literalExpression
} }
func NewBoolLiteralExpression(value bool) BoolExpression { func newBoolLiteralExpression(value bool) BoolExpression {
boolLiteralExpression := boolLiteralExpression{} boolLiteralExpression := boolLiteralExpression{}
sqlValue, err := sqltypes.BuildValue(value) sqlValue, err := sqltypes.BuildValue(value)
@ -57,15 +78,17 @@ func NewBoolLiteralExpression(value bool) BoolExpression {
//---------------------------------------------------// //---------------------------------------------------//
type binaryBoolExpression struct { type binaryBoolExpression struct {
expressionInterfaceImpl
boolInterfaceImpl boolInterfaceImpl
binaryExpression binaryExpression
} }
func NewBinaryBoolExpression(lhs, rhs Expression, operator []byte) BoolExpression { func newBinaryBoolExpression(lhs, rhs Expression, operator []byte) BoolExpression {
boolExpression := binaryBoolExpression{} boolExpression := binaryBoolExpression{}
boolExpression.binaryExpression = *NewBinaryExpression(lhs, rhs, operator, &boolExpression) boolExpression.binaryExpression = newBinaryExpression(lhs, rhs, operator)
boolExpression.expressionInterfaceImpl.parent = &boolExpression
boolExpression.boolInterfaceImpl.parent = &boolExpression boolExpression.boolInterfaceImpl.parent = &boolExpression
return &boolExpression return &boolExpression
@ -73,15 +96,17 @@ func NewBinaryBoolExpression(lhs, rhs Expression, operator []byte) BoolExpressio
//---------------------------------------------------// //---------------------------------------------------//
type prefixBoolExpression struct { type prefixBoolExpression struct {
expressionInterfaceImpl
boolInterfaceImpl boolInterfaceImpl
prefixExpression prefixExpression
} }
func NewPrefixBoolExpression(expression Expression, operator []byte) BoolExpression { func newPrefixBoolExpression(expression Expression, operator []byte) BoolExpression {
boolExpression := prefixBoolExpression{} boolExpression := prefixBoolExpression{}
boolExpression.prefixExpression = *NewPrefixExpression(expression, operator, &boolExpression) boolExpression.prefixExpression = newPrefixExpression(expression, operator)
boolExpression.expressionInterfaceImpl.parent = &boolExpression
boolExpression.boolInterfaceImpl.parent = &boolExpression boolExpression.boolInterfaceImpl.parent = &boolExpression
return &boolExpression return &boolExpression
@ -89,6 +114,7 @@ func NewPrefixBoolExpression(expression Expression, operator []byte) BoolExpress
//---------------------------------------------------// //---------------------------------------------------//
type conjunctBoolExpression struct { type conjunctBoolExpression struct {
expressionInterfaceImpl
boolInterfaceImpl boolInterfaceImpl
conjunctExpression conjunctExpression
@ -103,8 +129,8 @@ func NewConjunctBoolExpression(operator []byte, expressions ...BoolExpression) B
}, },
} }
//boolExpression.expressionInterfaceImpl.parent = &boolExpression boolExpression.expressionInterfaceImpl.parent = &boolExpression
//boolExpression.boolInterfaceImpl.parent = &boolExpression boolExpression.boolInterfaceImpl.parent = &boolExpression
return &boolExpression return &boolExpression
} }
@ -120,7 +146,7 @@ type inExpression struct {
err error err error
} }
func (c *inExpression) SerializeSql(out *bytes.Buffer) error { func (c *inExpression) SerializeSql(out *bytes.Buffer, options ...serializeOption) error {
if c.err != nil { if c.err != nil {
return errors.Wrap(c.err, "Invalid IN expression") return errors.Wrap(c.err, "Invalid IN expression")
} }
@ -159,9 +185,9 @@ func (c *inExpression) SerializeSql(out *bytes.Buffer) error {
func Eq(lhs, rhs Expression) BoolExpression { func Eq(lhs, rhs Expression) BoolExpression {
lit, ok := rhs.(*literalExpression) lit, ok := rhs.(*literalExpression)
if ok && sqltypes.Value(lit.value).IsNull() { if ok && sqltypes.Value(lit.value).IsNull() {
return NewBinaryBoolExpression(lhs, rhs, []byte(" IS ")) return newBinaryBoolExpression(lhs, rhs, []byte(" IS "))
} }
return NewBinaryBoolExpression(lhs, rhs, []byte(" = ")) return newBinaryBoolExpression(lhs, rhs, []byte(" = "))
} }
// Returns a representation of "a=b", where b is a literal // Returns a representation of "a=b", where b is a literal
@ -173,9 +199,9 @@ func EqL(lhs Expression, val interface{}) BoolExpression {
func Neq(lhs, rhs Expression) BoolExpression { func Neq(lhs, rhs Expression) BoolExpression {
lit, ok := rhs.(*literalExpression) lit, ok := rhs.(*literalExpression)
if ok && sqltypes.Value(lit.value).IsNull() { if ok && sqltypes.Value(lit.value).IsNull() {
return NewBinaryBoolExpression(lhs, rhs, []byte(" IS NOT ")) return newBinaryBoolExpression(lhs, rhs, []byte(" IS NOT "))
} }
return NewBinaryBoolExpression(lhs, rhs, []byte("!=")) return newBinaryBoolExpression(lhs, rhs, []byte("!="))
} }
// Returns a representation of "a!=b", where b is a literal // Returns a representation of "a!=b", where b is a literal
@ -185,7 +211,7 @@ func NeqL(lhs Expression, val interface{}) BoolExpression {
// Returns a representation of "a<b" // Returns a representation of "a<b"
func Lt(lhs Expression, rhs Expression) BoolExpression { func Lt(lhs Expression, rhs Expression) BoolExpression {
return NewBinaryBoolExpression(lhs, rhs, []byte("<")) return newBinaryBoolExpression(lhs, rhs, []byte("<"))
} }
// Returns a representation of "a<b", where b is a literal // Returns a representation of "a<b", where b is a literal
@ -195,7 +221,7 @@ func LtL(lhs Expression, val interface{}) BoolExpression {
// Returns a representation of "a<=b" // Returns a representation of "a<=b"
func Lte(lhs, rhs Expression) BoolExpression { func Lte(lhs, rhs Expression) BoolExpression {
return NewBinaryBoolExpression(lhs, rhs, []byte("<=")) return newBinaryBoolExpression(lhs, rhs, []byte("<="))
} }
// Returns a representation of "a<=b", where b is a literal // Returns a representation of "a<=b", where b is a literal
@ -205,7 +231,7 @@ func LteL(lhs Expression, val interface{}) BoolExpression {
// Returns a representation of "a>b" // Returns a representation of "a>b"
func Gt(lhs, rhs Expression) BoolExpression { func Gt(lhs, rhs Expression) BoolExpression {
return NewBinaryBoolExpression(lhs, rhs, []byte(">")) return newBinaryBoolExpression(lhs, rhs, []byte(">"))
} }
// Returns a representation of "a>b", where b is a literal // Returns a representation of "a>b", where b is a literal
@ -215,7 +241,7 @@ func GtL(lhs Expression, val interface{}) BoolExpression {
// Returns a representation of "a>=b" // Returns a representation of "a>=b"
func Gte(lhs, rhs Expression) BoolExpression { func Gte(lhs, rhs Expression) BoolExpression {
return NewBinaryBoolExpression(lhs, rhs, []byte(">=")) return newBinaryBoolExpression(lhs, rhs, []byte(">="))
} }
// Returns a representation of "a>=b", where b is a literal // Returns a representation of "a>=b", where b is a literal
@ -225,11 +251,11 @@ func GteL(lhs Expression, val interface{}) BoolExpression {
// Returns a representation of "not expr" // Returns a representation of "not expr"
func Not(expr BoolExpression) BoolExpression { func Not(expr BoolExpression) BoolExpression {
return NewPrefixBoolExpression(expr, []byte(" NOT ")) return newPrefixBoolExpression(expr, []byte(" NOT "))
} }
func IsTrue(expr BoolExpression) BoolExpression { func IsTrue(expr BoolExpression) BoolExpression {
return NewPrefixBoolExpression(expr, []byte(" IS TRUE ")) return newPrefixBoolExpression(expr, []byte(" IS TRUE "))
} }
// Returns a representation of "c[0] AND ... AND c[n-1]" for c in clauses // Returns a representation of "c[0] AND ... AND c[n-1]" for c in clauses
@ -243,7 +269,7 @@ func Or(expressions ...BoolExpression) BoolExpression {
} }
func Like(lhs, rhs Expression) BoolExpression { func Like(lhs, rhs Expression) BoolExpression {
return NewBinaryBoolExpression(lhs, rhs, []byte(" LIKE ")) return newBinaryBoolExpression(lhs, rhs, []byte(" LIKE "))
} }
func LikeL(lhs Expression, val string) BoolExpression { func LikeL(lhs Expression, val string) BoolExpression {
@ -251,7 +277,7 @@ func LikeL(lhs Expression, val string) BoolExpression {
} }
func Regexp(lhs, rhs Expression) BoolExpression { func Regexp(lhs, rhs Expression) BoolExpression {
return NewBinaryBoolExpression(lhs, rhs, []byte(" REGEXP ")) return newBinaryBoolExpression(lhs, rhs, []byte(" REGEXP "))
} }
func RegexpL(lhs Expression, val string) BoolExpression { func RegexpL(lhs Expression, val string) BoolExpression {

View file

@ -97,7 +97,7 @@ func TestUnaryIsTrueExpression(t *testing.T) {
} }
func TestBoolLiteral(t *testing.T) { func TestBoolLiteral(t *testing.T) {
literal := NewBoolLiteralExpression(true) literal := newBoolLiteralExpression(true)
out := bytes.Buffer{} out := bytes.Buffer{}
err := literal.SerializeSql(&out) err := literal.SerializeSql(&out)

View file

@ -2,6 +2,22 @@ package sqlbuilder
import "bytes" import "bytes"
type serializeOption int
const (
ALIASED = iota
FOR_PROJECTION
)
type Clause interface { type Clause interface {
SerializeSql(out *bytes.Buffer) error SerializeSql(out *bytes.Buffer, options ...serializeOption) error
}
func contains(s []serializeOption, e serializeOption) bool {
for _, a := range s {
if a == e {
return true
}
}
return false
} }

View file

@ -6,8 +6,6 @@ import (
"bytes" "bytes"
"regexp" "regexp"
"strings" "strings"
"github.com/dropbox/godropbox/errors"
) )
// XXX: Maybe add UIntColumn // XXX: Maybe add UIntColumn
@ -15,31 +13,29 @@ import (
// Representation of a tableName for query generation // Representation of a tableName for query generation
type Column interface { type Column interface {
Expression Expression
isProjectionInterface
Name() string Name() string
TableName() string TableName() string
// Serialization for use in column lists
SerializeSqlForColumnList(out *bytes.Buffer) error
// Internal function for tracking tableName that a column belongs to // Internal function for tracking tableName that a column belongs to
// for the purpose of serialization // for the purpose of serialization
setTableName(table string) error setTableName(table string) error
Eq(rhs Expression) BoolExpression
Neq(rhs Expression) BoolExpression
Gte(rhs Expression) BoolExpression
GteLiteral(rhs interface{}) BoolExpression
Lte(rhs Expression) BoolExpression
LteLiteral(rhs interface{}) BoolExpression
Asc() OrderByClause Asc() OrderByClause
Desc() OrderByClause Desc() OrderByClause
} }
type columnInterfaceImpl struct {
parent Column
}
func (c *columnInterfaceImpl) Asc() OrderByClause {
return &orderByClause{expression: c.parent, ascent: true}
}
func (c *columnInterfaceImpl) Desc() OrderByClause {
return &orderByClause{expression: c.parent, ascent: false}
}
type NullableColumn bool type NullableColumn bool
const ( const (
@ -47,10 +43,10 @@ const (
NotNullable NullableColumn = false NotNullable NullableColumn = false
) )
// A column that can be refer to outside of the projection list //// A column that can be refer to outside of the projection list
type NonAliasColumn interface { //type NonAliasColumn interface {
Column // Column
} //}
type Collation string type Collation string
@ -70,19 +66,25 @@ const (
// The base type for real materialized columns. // The base type for real materialized columns.
type baseColumn struct { type baseColumn struct {
expressionInterfaceImpl expressionInterfaceImpl
isProjection columnInterfaceImpl
name string name string
nullable NullableColumn nullable NullableColumn
tableName string tableName string
alias string
} }
//func (c *baseColumn) As(alias string) Projection { func newBaseColumn(name string, nullable NullableColumn, tableName string, parent Column) baseColumn {
// newBaseColumn := *c bc := baseColumn{
// newBaseColumn.alias = alias name: name,
// nullable: nullable,
// return &newBaseColumn tableName: tableName,
//} }
bc.expressionInterfaceImpl.parent = parent
bc.columnInterfaceImpl.parent = parent
return bc
}
func (c *baseColumn) Name() string { func (c *baseColumn) Name() string {
return c.name return c.name
@ -97,20 +99,7 @@ func (c *baseColumn) setTableName(table string) error {
return nil return nil
} }
func (c *baseColumn) SerializeSqlForColumnList(out *bytes.Buffer) error { func (c baseColumn) SerializeSql(out *bytes.Buffer, options ...serializeOption) error {
c.SerializeSql(out)
if c.alias != "" {
_, _ = out.WriteString(" AS \"" + c.alias + "\"")
} else if c.tableName != "" {
_, _ = out.WriteString(" AS \"" + c.tableName + "." + c.name + "\"")
}
return nil
}
func (c baseColumn) SerializeSql(out *bytes.Buffer) error {
if c.tableName != "" { if c.tableName != "" {
_, _ = out.WriteString(c.tableName) _, _ = out.WriteString(c.tableName)
_, _ = out.WriteString(".") _, _ = out.WriteString(".")
@ -125,190 +114,163 @@ func (c baseColumn) SerializeSql(out *bytes.Buffer) error {
out.WriteString("\"") out.WriteString("\"")
} }
if contains(options, FOR_PROJECTION) && !contains(options, ALIASED) && c.tableName != "" {
_, _ = out.WriteString(" AS \"" + c.tableName + "." + c.name + "\"")
}
return nil return nil
} }
func (c *baseColumn) Eq(rhs Expression) BoolExpression { //
return Eq(c, rhs) //type bytesColumn struct {
} // baseColumn
//}
//
//// Representation of VARBINARY/BLOB columns
//// This function will panic if name is not valid
//func BytesColumn(name string, nullable NullableColumn) Column {
// if !validIdentifierName(name) {
// panic("Invalid column name in bytes column")
// }
// bc := &bytesColumn{}
// bc.name = name
// bc.nullable = nullable
// return bc
//}
//
//type stringColumn struct {
// baseColumn
// charset Charset
// collation Collation
//}
//
//// Representation of VARCHAR/TEXT columns
//// This function will panic if name is not valid
//func StrColumn(
// name string,
// charset Charset,
// collation Collation,
// nullable NullableColumn) Column {
//
// if !validIdentifierName(name) {
// panic("Invalid column name in str column")
// }
// sc := &stringColumn{charset: charset, collation: collation}
// sc.name = name
// sc.nullable = nullable
// return sc
//}
//
//type dateTimeColumn struct {
// baseColumn
//}
//
//// Representation of DateTime columns
//// This function will panic if name is not valid
//func DateTimeColumn(name string, nullable NullableColumn) Column {
// if !validIdentifierName(name) {
// panic("Invalid column name in datetime column")
// }
// dc := &dateTimeColumn{}
// dc.name = name
// dc.nullable = nullable
// return dc
//}
func (c *baseColumn) Neq(rhs Expression) BoolExpression { //type IntegerColumn struct {
return Neq(c, rhs) // baseColumn
} //}
//
//// Representation of any integer column
//// This function will panic if name is not valid
//func IntColumn(name string, nullable NullableColumn) *IntegerColumn {
// if !validIdentifierName(name) {
// panic("Invalid column name in int column")
// }
// ic := &IntegerColumn{}
// ic.name = name
// ic.nullable = nullable
// return ic
//}
func (c *baseColumn) Gte(rhs Expression) BoolExpression { //type doubleColumn struct {
return Gte(c, rhs) // baseColumn
} //}
//
func (c *baseColumn) GteLiteral(rhs interface{}) BoolExpression { //// Representation of any double column
return Gte(c, Literal(rhs)) //// This function will panic if name is not valid
} //func DoubleColumn(name string, nullable NullableColumn) Column {
// if !validIdentifierName(name) {
func (c *baseColumn) Lte(rhs Expression) BoolExpression { // panic("Invalid column name in int column")
return Lte(c, rhs) // }
} // ic := &doubleColumn{}
// ic.name = name
func (c *baseColumn) LteLiteral(literal interface{}) BoolExpression { // ic.nullable = nullable
return Lte(c, Literal(literal)) // return ic
} //}
//
func (c *baseColumn) Asc() OrderByClause { //type booleanColumn struct {
return Asc(c) // baseColumn
} //
// // XXX: Maybe allow isBoolExpression (for now, not included because
func (c *baseColumn) Desc() OrderByClause { // // the deferred lookup equivalent can never be isBoolExpression)
return Desc(c) //}
}
type bytesColumn struct {
baseColumn
}
// Representation of VARBINARY/BLOB columns
// This function will panic if name is not valid
func BytesColumn(name string, nullable NullableColumn) NonAliasColumn {
if !validIdentifierName(name) {
panic("Invalid column name in bytes column")
}
bc := &bytesColumn{}
bc.name = name
bc.nullable = nullable
return bc
}
type stringColumn struct {
baseColumn
charset Charset
collation Collation
}
// Representation of VARCHAR/TEXT columns
// This function will panic if name is not valid
func StrColumn(
name string,
charset Charset,
collation Collation,
nullable NullableColumn) NonAliasColumn {
if !validIdentifierName(name) {
panic("Invalid column name in str column")
}
sc := &stringColumn{charset: charset, collation: collation}
sc.name = name
sc.nullable = nullable
return sc
}
type dateTimeColumn struct {
baseColumn
}
// Representation of DateTime columns
// This function will panic if name is not valid
func DateTimeColumn(name string, nullable NullableColumn) NonAliasColumn {
if !validIdentifierName(name) {
panic("Invalid column name in datetime column")
}
dc := &dateTimeColumn{}
dc.name = name
dc.nullable = nullable
return dc
}
type IntegerColumn struct {
baseColumn
}
// Representation of any integer column
// This function will panic if name is not valid
func IntColumn(name string, nullable NullableColumn) *IntegerColumn {
if !validIdentifierName(name) {
panic("Invalid column name in int column")
}
ic := &IntegerColumn{}
ic.name = name
ic.nullable = nullable
return ic
}
type doubleColumn struct {
baseColumn
}
// Representation of any double column
// This function will panic if name is not valid
func DoubleColumn(name string, nullable NullableColumn) NonAliasColumn {
if !validIdentifierName(name) {
panic("Invalid column name in int column")
}
ic := &doubleColumn{}
ic.name = name
ic.nullable = nullable
return ic
}
type booleanColumn struct {
baseColumn
// XXX: Maybe allow isBoolExpression (for now, not included because
// the deferred lookup equivalent can never be isBoolExpression)
}
// Representation of TINYINT used as a bool // Representation of TINYINT used as a bool
// This function will panic if name is not valid // This function will panic if name is not valid
func BoolColumn(name string, nullable NullableColumn) NonAliasColumn { //func NewBoolColumn(name string, nullable NullableColumn) Column {
if !validIdentifierName(name) { // if !validIdentifierName(name) {
panic("Invalid column name in bool column") // panic("Invalid column name in bool column")
} // }
bc := &booleanColumn{} // bc := &booleanColumn{}
bc.name = name // bc.name = name
bc.nullable = nullable // bc.nullable = nullable
return bc // return bc
} //}
//
//type aliasColumn struct {
// baseColumn
// expression Expression
//}
//
//func (c *aliasColumn) SerializeSql(out *bytes.Buffer) error {
// _ = out.WriteByte('`')
// _, _ = out.WriteString(c.name)
// _ = out.WriteByte('`')
// return nil
//}
//
//func (c *aliasColumn) SerializeSqlForColumnList(out *bytes.Buffer) error {
// if !validIdentifierName(c.name) {
// return errors.Newf(
// "Invalid alias name `%s`. Generated sql: %s",
// c.name,
// out.String())
// }
// if c.expression == nil {
// return errors.Newf(
// "Cannot alias a nil expression. Generated sql: %s",
// out.String())
// }
//
// _ = out.WriteByte('(')
// if c.expression == nil {
// return errors.Newf("nil alias clause. Generate sql: %s", out.String())
// }
// if err := c.expression.SerializeSql(out); err != nil {
// return err
// }
// _, _ = out.WriteString(") AS \"")
// _, _ = out.WriteString(c.name)
// _ = out.WriteByte('"')
// return nil
//}
type aliasColumn struct { //func (c *aliasColumn) setTableName(table string) error {
baseColumn // return errors.Newf(
expression Expression // "Alias column '%s' should never have setTableName called on it",
} // c.name)
//}
func (c *aliasColumn) SerializeSql(out *bytes.Buffer) error {
_ = out.WriteByte('`')
_, _ = out.WriteString(c.name)
_ = out.WriteByte('`')
return nil
}
func (c *aliasColumn) SerializeSqlForColumnList(out *bytes.Buffer) error {
if !validIdentifierName(c.name) {
return errors.Newf(
"Invalid alias name `%s`. Generated sql: %s",
c.name,
out.String())
}
if c.expression == nil {
return errors.Newf(
"Cannot alias a nil expression. Generated sql: %s",
out.String())
}
_ = out.WriteByte('(')
if c.expression == nil {
return errors.Newf("nil alias clause. Generate sql: %s", out.String())
}
if err := c.expression.SerializeSql(out); err != nil {
return err
}
_, _ = out.WriteString(") AS \"")
_, _ = out.WriteString(c.name)
_ = out.WriteByte('"')
return nil
}
func (c *aliasColumn) setTableName(table string) error {
return errors.Newf(
"Alias column '%s' should never have setTableName called on it",
c.name)
}
// Representation of aliased clauses (expression AS name) // Representation of aliased clauses (expression AS name)
//func Alias(name string, c Expression) Column { //func Alias(name string, c Expression) Column {

View file

@ -0,0 +1,65 @@
package sqlbuilder
//------------------------------------------------------//
type BoolColumn struct {
boolInterfaceImpl
baseColumn
}
func NewBoolColumn(name string, nullable NullableColumn) *BoolColumn {
if !validIdentifierName(name) {
panic("Invalid column name in bool column")
}
boolColumn := &BoolColumn{}
boolColumn.baseColumn = newBaseColumn(name, nullable, "", boolColumn)
boolColumn.boolInterfaceImpl.parent = boolColumn
return boolColumn
}
//------------------------------------------------------//
type NumericColumn struct {
numericInterfaceImpl
baseColumn
}
func NewNumericColumn(name string, nullable NullableColumn) *NumericColumn {
if !validIdentifierName(name) {
panic("Invalid column name")
}
numericColumn := &NumericColumn{}
numericColumn.numericInterfaceImpl.parent = numericColumn
numericColumn.baseColumn = newBaseColumn(name, nullable, "", numericColumn)
return numericColumn
}
//------------------------------------------------------//
type IntegerColumn struct {
numericInterfaceImpl
integerInterfaceImpl
baseColumn
}
// Representation of any integer column
// This function will panic if name is not valid
func NewIntegerColumn(name string, nullable NullableColumn) *IntegerColumn {
if !validIdentifierName(name) {
panic("Invalid column name")
}
integerColumn := &IntegerColumn{}
integerColumn.numericInterfaceImpl.parent = integerColumn
integerColumn.integerInterfaceImpl.parent = integerColumn
integerColumn.baseColumn = newBaseColumn(name, nullable, "", integerColumn)
return integerColumn
}

View file

@ -0,0 +1,97 @@
package sqlbuilder
import (
"bytes"
"gotest.tools/assert"
"testing"
)
func TestNewBoolColumn(t *testing.T) {
boolColumn := NewBoolColumn("col", Nullable)
out := bytes.Buffer{}
err := boolColumn.SerializeSql(&out)
assert.NilError(t, err)
assert.Equal(t, out.String(), "col")
out.Reset()
err = boolColumn.SerializeSql(&out, FOR_PROJECTION)
assert.NilError(t, err)
assert.Equal(t, out.String(), "col")
out.Reset()
err = boolColumn.setTableName("table1")
assert.NilError(t, err)
err = boolColumn.SerializeSql(&out, FOR_PROJECTION)
assert.NilError(t, err)
assert.Equal(t, out.String(), `table1.col AS "table1.col"`)
out.Reset()
err = boolColumn.setTableName("table1")
assert.NilError(t, err)
aliasedBoolColumn := boolColumn.As("alias1")
err = aliasedBoolColumn.SerializeSql(&out, FOR_PROJECTION)
assert.NilError(t, err)
assert.Equal(t, out.String(), `table1.col AS "alias1"`)
}
func TestNewIntColumn(t *testing.T) {
integerColumn := NewIntegerColumn("col", Nullable)
out := bytes.Buffer{}
err := integerColumn.SerializeSql(&out)
assert.NilError(t, err)
assert.Equal(t, out.String(), "col")
out.Reset()
err = integerColumn.SerializeSql(&out, FOR_PROJECTION)
assert.NilError(t, err)
assert.Equal(t, out.String(), "col")
out.Reset()
err = integerColumn.setTableName("table1")
assert.NilError(t, err)
err = integerColumn.SerializeSql(&out, FOR_PROJECTION)
assert.NilError(t, err)
assert.Equal(t, out.String(), `table1.col AS "table1.col"`)
out.Reset()
err = integerColumn.setTableName("table1")
assert.NilError(t, err)
aliasedBoolColumn := integerColumn.As("alias1")
err = aliasedBoolColumn.SerializeSql(&out, FOR_PROJECTION)
assert.NilError(t, err)
assert.Equal(t, out.String(), `table1.col AS "alias1"`)
}
func TestNewNumericColumnColumn(t *testing.T) {
numericColumn := NewNumericColumn("col", Nullable)
out := bytes.Buffer{}
err := numericColumn.SerializeSql(&out)
assert.NilError(t, err)
assert.Equal(t, out.String(), "col")
out.Reset()
err = numericColumn.SerializeSql(&out, FOR_PROJECTION)
assert.NilError(t, err)
assert.Equal(t, out.String(), "col")
out.Reset()
err = numericColumn.setTableName("table1")
assert.NilError(t, err)
err = numericColumn.SerializeSql(&out, FOR_PROJECTION)
assert.NilError(t, err)
assert.Equal(t, out.String(), `table1.col AS "table1.col"`)
out.Reset()
err = numericColumn.setTableName("table1")
assert.NilError(t, err)
aliasedBoolColumn := numericColumn.As("alias1")
err = aliasedBoolColumn.SerializeSql(&out, FOR_PROJECTION)
assert.NilError(t, err)
assert.Equal(t, out.String(), `table1.col AS "alias1"`)
}

View file

@ -12,7 +12,7 @@ type Expression interface {
As(alias string) Clause As(alias string) Clause
IsDistinct(expression Expression) BoolExpression IsDistinct(expression Expression) BoolExpression
IsNull(expression Expression) BoolExpression IsNull() BoolExpression
} }
type expressionInterfaceImpl struct { type expressionInterfaceImpl struct {
@ -27,31 +27,27 @@ func (e *expressionInterfaceImpl) IsDistinct(expression Expression) BoolExpressi
return nil return nil
} }
func (e *expressionInterfaceImpl) IsNull(expression Expression) BoolExpression { func (e *expressionInterfaceImpl) IsNull() BoolExpression {
return nil return nil
} }
// Representation of binary operations (e.g. comparisons, arithmetic) // Representation of binary operations (e.g. comparisons, arithmetic)
type binaryExpression struct { type binaryExpression struct {
expressionInterfaceImpl
lhs, rhs Expression lhs, rhs Expression
operator []byte operator []byte
} }
func NewBinaryExpression(lhs, rhs Expression, operator []byte, parent ...Expression) *binaryExpression { func newBinaryExpression(lhs, rhs Expression, operator []byte, parent ...Expression) binaryExpression {
binaryExpression := binaryExpression{ binaryExpression := binaryExpression{
lhs: lhs, lhs: lhs,
rhs: rhs, rhs: rhs,
operator: operator, operator: operator,
} }
if len(parent) > 0 {
binaryExpression.parent = parent[0]
}
return &binaryExpression return binaryExpression
} }
func (c *binaryExpression) SerializeSql(out *bytes.Buffer) (err error) { func (c *binaryExpression) SerializeSql(out *bytes.Buffer, options ...serializeOption) (err error) {
if c.lhs == nil { if c.lhs == nil {
return errors.Newf("nil lhs. Generated sql: %s", out.String()) return errors.Newf("nil lhs. Generated sql: %s", out.String())
} }
@ -73,25 +69,20 @@ func (c *binaryExpression) SerializeSql(out *bytes.Buffer) (err error) {
// A not expression which negates a expression value // A not expression which negates a expression value
type prefixExpression struct { type prefixExpression struct {
expressionInterfaceImpl
expression Expression expression Expression
operator []byte operator []byte
} }
func NewPrefixExpression(expression Expression, operator []byte, parent ...Expression) *prefixExpression { func newPrefixExpression(expression Expression, operator []byte) prefixExpression {
prefixExpression := prefixExpression{ prefixExpression := prefixExpression{
expression: expression, expression: expression,
operator: operator, operator: operator,
} }
if len(parent) > 0 {
prefixExpression.parent = parent[0]
}
return &prefixExpression return prefixExpression
} }
func (p *prefixExpression) SerializeSql(out *bytes.Buffer) (err error) { func (p *prefixExpression) SerializeSql(out *bytes.Buffer, options ...serializeOption) (err error) {
_, _ = out.Write(p.operator) _, _ = out.Write(p.operator)
if p.expression == nil { if p.expression == nil {
@ -106,12 +97,11 @@ func (p *prefixExpression) SerializeSql(out *bytes.Buffer) (err error) {
// Representation of n-ary conjunctions (AND/OR) // Representation of n-ary conjunctions (AND/OR)
type conjunctExpression struct { type conjunctExpression struct {
expressionInterfaceImpl
expressions []BoolExpression expressions []BoolExpression
conjunction []byte conjunction []byte
} }
func (conj *conjunctExpression) SerializeSql(out *bytes.Buffer) (err error) { func (conj *conjunctExpression) SerializeSql(out *bytes.Buffer, options ...serializeOption) (err error) {
if len(conj.expressions) == 0 { if len(conj.expressions) == 0 {
return errors.Newf( return errors.Newf(
"Empty conjunction. Generated sql: %s", "Empty conjunction. Generated sql: %s",
@ -154,7 +144,38 @@ func NewLiteralExpression(value sqltypes.Value) *literalExpression {
return &exp return &exp
} }
func (c literalExpression) SerializeSql(out *bytes.Buffer) error { func (c literalExpression) SerializeSql(out *bytes.Buffer, options ...serializeOption) error {
sqltypes.Value(c.value).EncodeSql(out) sqltypes.Value(c.value).EncodeSql(out)
return nil return nil
} }
//------------------------------------------------------//
// Dummy type for select *
type ColumnList []Column
func (cl ColumnList) SerializeSql(out *bytes.Buffer, options ...serializeOption) error {
for i, column := range cl {
err := column.SerializeSql(out)
if err != nil {
return err
}
if i != len(cl)-1 {
out.WriteString(", ")
}
}
return nil
}
func (e ColumnList) As(alias string) Clause {
panic("Invalid usage")
}
func (e ColumnList) IsDistinct(expression Expression) BoolExpression {
panic("Invalid usage")
}
func (e ColumnList) IsNull(expression Expression) BoolExpression {
panic("Invalid usage")
}

View file

@ -17,7 +17,7 @@ type orderByClause struct {
ascent bool ascent bool
} }
func (o *orderByClause) SerializeSql(out *bytes.Buffer) error { func (o *orderByClause) SerializeSql(out *bytes.Buffer, options ...serializeOption) error {
if o.expression == nil { if o.expression == nil {
return errors.Newf( return errors.Newf(
"nil order by clause. Generated sql: %s", "nil order by clause. Generated sql: %s",
@ -82,7 +82,7 @@ type arithmeticExpression struct {
operator []byte operator []byte
} }
func (arith *arithmeticExpression) SerializeSql(out *bytes.Buffer) (err error) { func (arith *arithmeticExpression) SerializeSql(out *bytes.Buffer, options ...serializeOption) (err error) {
if len(arith.expressions) == 0 { if len(arith.expressions) == 0 {
return errors.Newf( return errors.Newf(
"Empty arithmetic expression. Generated sql: %s", "Empty arithmetic expression. Generated sql: %s",
@ -115,7 +115,7 @@ type tupleExpression struct {
elements listClause elements listClause
} }
func (tuple *tupleExpression) SerializeSql(out *bytes.Buffer) error { func (tuple *tupleExpression) SerializeSql(out *bytes.Buffer, options ...serializeOption) error {
if len(tuple.elements.clauses) < 1 { if len(tuple.elements.clauses) < 1 {
return errors.Newf("Tuples must include at least one element") return errors.Newf("Tuples must include at least one element")
} }
@ -141,7 +141,7 @@ type listClause struct {
includeParentheses bool includeParentheses bool
} }
func (list *listClause) SerializeSql(out *bytes.Buffer) error { func (list *listClause) SerializeSql(out *bytes.Buffer, options ...serializeOption) error {
if list.includeParentheses { if list.includeParentheses {
_ = out.WriteByte('(') _ = out.WriteByte('(')
} }
@ -162,7 +162,7 @@ type funcExpression struct {
args *listClause args *listClause
} }
func (c *funcExpression) SerializeSql(out *bytes.Buffer) (err error) { func (c *funcExpression) SerializeSql(out *bytes.Buffer, options ...serializeOption) (err error) {
if !validIdentifierName(c.funcName) { if !validIdentifierName(c.funcName) {
return errors.Newf( return errors.Newf(
"Invalid function name: %s. Generated sql: %s", "Invalid function name: %s. Generated sql: %s",
@ -205,7 +205,7 @@ type intervalExpression struct {
var intervalSep = ":" var intervalSep = ":"
func (c *intervalExpression) SerializeSql(out *bytes.Buffer) (err error) { func (c *intervalExpression) SerializeSql(out *bytes.Buffer, options ...serializeOption) (err error) {
hours := c.duration / time.Hour hours := c.duration / time.Hour
minutes := (c.duration % time.Hour) / time.Minute minutes := (c.duration % time.Hour) / time.Minute
sec := (c.duration % time.Minute) / time.Second sec := (c.duration % time.Minute) / time.Second
@ -336,7 +336,7 @@ type ifExpression struct {
falseExpression Expression falseExpression Expression
} }
func (exp *ifExpression) SerializeSql(out *bytes.Buffer) error { func (exp *ifExpression) SerializeSql(out *bytes.Buffer, options ...serializeOption) error {
_, _ = out.WriteString("IF(") _, _ = out.WriteString("IF(")
_ = exp.conditional.SerializeSql(out) _ = exp.conditional.SerializeSql(out)
_, _ = out.WriteString(",") _, _ = out.WriteString(",")

View file

@ -3,15 +3,13 @@ package sqlbuilder
import "bytes" import "bytes"
type FuncExpression struct { type FuncExpression struct {
isProjection
name string name string
expression Expression expression Expression
alias string alias string
} }
func (f *FuncExpression) As(alias string) Projection { func (f *FuncExpression) As(alias string) Clause {
newFuncExpression := *f newFuncExpression := *f
newFuncExpression.alias = alias newFuncExpression.alias = alias
@ -19,7 +17,7 @@ func (f *FuncExpression) As(alias string) Projection {
return &newFuncExpression return &newFuncExpression
} }
func (f *FuncExpression) SerializeSql(out *bytes.Buffer) error { func (f *FuncExpression) SerializeSql(out *bytes.Buffer, options ...serializeOption) error {
out.WriteString(f.name) out.WriteString(f.name)
out.WriteString("(") out.WriteString("(")
err := f.expression.SerializeSql(out) err := f.expression.SerializeSql(out)
@ -37,9 +35,9 @@ func (f *FuncExpression) SerializeSql(out *bytes.Buffer) error {
return nil return nil
} }
func (f *FuncExpression) SerializeSqlForColumnList(out *bytes.Buffer) error { //func (f *FuncExpression) SerializeSqlForColumnList(out *bytes.Buffer) error {
return f.SerializeSql(out) // return f.SerializeSql(out)
} //}
func MAX(expression Expression) *FuncExpression { func MAX(expression Expression) *FuncExpression {
return &FuncExpression{ return &FuncExpression{

View file

@ -0,0 +1,82 @@
package sqlbuilder
type IntegerExpression interface {
NumericExpression
//AddInt(value int) IntegerExpression
//AddInt64(value int) IntegerExpression
BitAnd(expression IntegerExpression) IntegerExpression
BitOr(expression IntegerExpression) IntegerExpression
BitXor(expression IntegerExpression) IntegerExpression
BitNot() IntegerExpression
}
type integerInterfaceImpl struct {
parent IntegerExpression
}
//func (i *integerInterfaceImpl) AddInt(expression IntegerExpression) IntegerExpression {
// return NewBinaryIntegerExpression(i.parent, expression, " & ")
//}
//
//func (i *integerInterfaceImpl) AddInt64(expression IntegerExpression) IntegerExpression {
// return NewBinaryIntegerExpression(i.parent, expression, " & ")
//}
func (i *integerInterfaceImpl) BitAnd(expression IntegerExpression) IntegerExpression {
return NewBinaryIntegerExpression(i.parent, expression, " & ")
}
func (i *integerInterfaceImpl) BitOr(expression IntegerExpression) IntegerExpression {
return NewBinaryIntegerExpression(i.parent, expression, " | ")
}
func (i *integerInterfaceImpl) BitXor(expression IntegerExpression) IntegerExpression {
return NewBinaryIntegerExpression(i.parent, expression, " # ")
}
func (i *integerInterfaceImpl) BitNot() IntegerExpression {
return NewPrefixIntegerExpression(i.parent, " ~")
}
//---------------------------------------------------//
type binaryIntegerExpression struct {
expressionInterfaceImpl
numericInterfaceImpl
integerInterfaceImpl
binaryExpression
}
func NewBinaryIntegerExpression(lhs, rhs IntegerExpression, operator string) IntegerExpression {
integerExpression := binaryIntegerExpression{}
integerExpression.expressionInterfaceImpl.parent = &integerExpression
integerExpression.numericInterfaceImpl.parent = &integerExpression
integerExpression.integerInterfaceImpl.parent = &integerExpression
integerExpression.binaryExpression = newBinaryExpression(lhs, rhs, []byte(operator))
return &integerExpression
}
//---------------------------------------------------//
type prefixIntegerExpression struct {
expressionInterfaceImpl
numericInterfaceImpl
integerInterfaceImpl
prefixExpression
}
func NewPrefixIntegerExpression(expression IntegerExpression, operator string) IntegerExpression {
integerExpression := prefixIntegerExpression{}
integerExpression.prefixExpression = newPrefixExpression(expression, []byte(operator))
integerExpression.expressionInterfaceImpl.parent = &integerExpression
integerExpression.numericInterfaceImpl.parent = &integerExpression
integerExpression.integerInterfaceImpl.parent = &integerExpression
return &integerExpression
}

View file

@ -0,0 +1,94 @@
package sqlbuilder
import (
"github.com/dropbox/godropbox/database/sqltypes"
"github.com/pkg/errors"
)
type NumericExpression interface {
Expression
Eq(expression NumericExpression) BoolExpression
NotEq(expression NumericExpression) BoolExpression
GtEq(rhs NumericExpression) BoolExpression
LtEq(rhs NumericExpression) BoolExpression
Add(expression NumericExpression) NumericExpression
Sub(expression NumericExpression) NumericExpression
Mul(expression NumericExpression) NumericExpression
Div(expression NumericExpression) NumericExpression
}
type numericInterfaceImpl struct {
parent NumericExpression
}
func (n *numericInterfaceImpl) Eq(expression NumericExpression) BoolExpression {
return Eq(n.parent, expression)
}
func (n *numericInterfaceImpl) NotEq(expression NumericExpression) BoolExpression {
return Neq(n.parent, expression)
}
func (n *numericInterfaceImpl) GtEq(expression NumericExpression) BoolExpression {
return Gte(n.parent, expression)
}
func (n *numericInterfaceImpl) LtEq(expression NumericExpression) BoolExpression {
return Lte(n.parent, expression)
}
func (n *numericInterfaceImpl) Add(expression NumericExpression) NumericExpression {
return newBinaryNumericExpression(n.parent, expression, []byte(" + "))
}
func (n *numericInterfaceImpl) Sub(expression NumericExpression) NumericExpression {
return newBinaryNumericExpression(n.parent, expression, []byte(" - "))
}
func (n *numericInterfaceImpl) Mul(expression NumericExpression) NumericExpression {
return newBinaryNumericExpression(n.parent, expression, []byte(" * "))
}
func (n *numericInterfaceImpl) Div(expression NumericExpression) NumericExpression {
return newBinaryNumericExpression(n.parent, expression, []byte(" / "))
}
//---------------------------------------------------//
type numericLiteral struct {
numericInterfaceImpl
literalExpression
}
func NewNumericLiteral(value interface{}) NumericExpression {
numericLiteral := numericLiteral{}
sqlValue, err := sqltypes.BuildValue(value)
if err != nil {
panic(errors.Wrap(err, "Invalid literal value"))
}
numericLiteral.literalExpression = *NewLiteralExpression(sqlValue)
numericLiteral.numericInterfaceImpl.parent = &numericLiteral
return &numericLiteral
}
//---------------------------------------------------//
type binaryNumericExpression struct {
expressionInterfaceImpl
numericInterfaceImpl
binaryExpression
}
func newBinaryNumericExpression(lhs, rhs Expression, operator []byte) NumericExpression {
numericExpression := binaryNumericExpression{}
numericExpression.binaryExpression = newBinaryExpression(lhs, rhs, operator)
numericExpression.expressionInterfaceImpl.parent = &numericExpression
numericExpression.numericInterfaceImpl.parent = &numericExpression
return &numericExpression
}

View file

@ -36,7 +36,7 @@ type selectStatementImpl struct {
expressionInterfaceImpl expressionInterfaceImpl
table ReadableTable table ReadableTable
projections []Projection projections []Expression
where BoolExpression where BoolExpression
group *listClause group *listClause
having BoolExpression having BoolExpression
@ -50,7 +50,7 @@ type selectStatementImpl struct {
func newSelectStatement( func newSelectStatement(
table ReadableTable, table ReadableTable,
projections []Projection) SelectStatement { projections []Expression) SelectStatement {
return &selectStatementImpl{ return &selectStatementImpl{
table: table, table: table,
@ -63,7 +63,7 @@ func newSelectStatement(
} }
} }
func (s *selectStatementImpl) SerializeSql(out *bytes.Buffer) error { func (s *selectStatementImpl) SerializeSql(out *bytes.Buffer, options ...serializeOption) error {
str, err := s.String() str, err := s.String()
if err != nil { if err != nil {
@ -210,7 +210,7 @@ func (q *selectStatementImpl) String() (sql string, err error) {
"nil column selected. Generated sql: %s", "nil column selected. Generated sql: %s",
buf.String()) buf.String())
} }
if err = col.SerializeSqlForColumnList(buf); err != nil { if err = col.SerializeSql(buf, FOR_PROJECTION); err != nil {
return return
} }
} }

View file

@ -4,22 +4,22 @@ import "bytes"
type SelectStatementTable struct { type SelectStatementTable struct {
statement SelectStatement statement SelectStatement
columns []NonAliasColumn columns []Column
alias string alias string
} }
func (s *SelectStatementTable) Columns() []NonAliasColumn { func (s *SelectStatementTable) Columns() []Column {
return s.columns return s.columns
} }
func (s *SelectStatementTable) Column(name string) NonAliasColumn { func (s *SelectStatementTable) Column(name string) Column {
return &baseColumn{ return &baseColumn{
name: name, name: name,
tableName: s.alias, tableName: s.alias,
} }
} }
func (s *SelectStatementTable) ColumnFrom(column NonAliasColumn) NonAliasColumn { func (s *SelectStatementTable) ColumnFrom(column Column) Column {
return &baseColumn{ return &baseColumn{
name: column.TableName() + "." + column.Name(), name: column.TableName() + "." + column.Name(),
tableName: s.alias, tableName: s.alias,
@ -43,7 +43,7 @@ func (s *SelectStatementTable) SerializeSql(out *bytes.Buffer) error {
} }
// Generates a select query on the current tableName. // Generates a select query on the current tableName.
func (s *SelectStatementTable) Select(projections ...Projection) SelectStatement { func (s *SelectStatementTable) Select(projections ...Expression) SelectStatement {
return newSelectStatement(s, projections) return newSelectStatement(s, projections)
} }
@ -52,9 +52,9 @@ func (s *SelectStatementTable) InnerJoinOn(table ReadableTable, onCondition Bool
return InnerJoinOn(s, table, onCondition) return InnerJoinOn(s, table, onCondition)
} }
func (s *SelectStatementTable) InnerJoinUsing(table ReadableTable, col1 Column, col2 Column) ReadableTable { //func (s *SelectStatementTable) InnerJoinUsing(table ReadableTable, col1 Column, col2 Column) ReadableTable {
return InnerJoinOn(s, table, col1.Eq(col2)) // return InnerJoinOn(s, table, col1.Eq(col2))
} //}
// Creates a left join tableName expression using onCondition. // Creates a left join tableName expression using onCondition.
func (s *SelectStatementTable) LeftJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable { func (s *SelectStatementTable) LeftJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable {
@ -66,8 +66,8 @@ func (s *SelectStatementTable) RightJoinOn(table ReadableTable, onCondition Bool
return RightJoinOn(s, table, onCondition) return RightJoinOn(s, table, onCondition)
} }
func (s *SelectStatementTable) FullJoin(table ReadableTable, col1 Column, col2 Column) ReadableTable { func (s *SelectStatementTable) FullJoin(table ReadableTable, onCondition BoolExpression) ReadableTable {
return FullJoin(s, table, col1.Eq(col2)) return FullJoin(s, table, onCondition)
} }
func (s *SelectStatementTable) CrossJoin(table ReadableTable) ReadableTable { func (s *SelectStatementTable) CrossJoin(table ReadableTable) ReadableTable {

View file

@ -20,7 +20,7 @@ type InsertStatement interface {
// Add a row of values to the insert statement. // Add a row of values to the insert statement.
Add(row ...Expression) InsertStatement Add(row ...Expression) InsertStatement
AddOnDuplicateKeyUpdate(col NonAliasColumn, expr Expression) InsertStatement AddOnDuplicateKeyUpdate(col Column, expr Expression) InsertStatement
Comment(comment string) InsertStatement Comment(comment string) InsertStatement
IgnoreDuplicates(ignore bool) InsertStatement IgnoreDuplicates(ignore bool) InsertStatement
} }
@ -48,7 +48,7 @@ type UnionStatement interface {
type UpdateStatement interface { type UpdateStatement interface {
Statement Statement
Set(column NonAliasColumn, expression Expression) UpdateStatement Set(column Column, expression Expression) UpdateStatement
Where(expression BoolExpression) UpdateStatement Where(expression BoolExpression) UpdateStatement
OrderBy(clauses ...OrderByClause) UpdateStatement OrderBy(clauses ...OrderByClause) UpdateStatement
Limit(limit int64) UpdateStatement Limit(limit int64) UpdateStatement
@ -178,7 +178,7 @@ func (us *unionStatementImpl) String() (sql string, err error) {
} }
// Union statements in MySQL require that the same number of columns in each subquery // Union statements in MySQL require that the same number of columns in each subquery
var projections []Projection var projections []Expression
for _, statement := range us.selects { for _, statement := range us.selects {
// do a type assertion to get at the underlying struct // do a type assertion to get at the underlying struct
@ -267,7 +267,7 @@ func (us *unionStatementImpl) String() (sql string, err error) {
func newInsertStatement( func newInsertStatement(
t WritableTable, t WritableTable,
columns ...NonAliasColumn) InsertStatement { columns ...Column) InsertStatement {
return &insertStatementImpl{ return &insertStatementImpl{
table: t, table: t,
@ -278,13 +278,13 @@ func newInsertStatement(
} }
type columnAssignment struct { type columnAssignment struct {
col NonAliasColumn col Column
expr Expression expr Expression
} }
type insertStatementImpl struct { type insertStatementImpl struct {
table WritableTable table WritableTable
columns []NonAliasColumn columns []Column
rows [][]Expression rows [][]Expression
onDuplicateKeyUpdates []columnAssignment onDuplicateKeyUpdates []columnAssignment
comment string comment string
@ -303,7 +303,7 @@ func (s *insertStatementImpl) Add(
} }
func (s *insertStatementImpl) AddOnDuplicateKeyUpdate( func (s *insertStatementImpl) AddOnDuplicateKeyUpdate(
col NonAliasColumn, col Column,
expr Expression) InsertStatement { expr Expression) InsertStatement {
s.onDuplicateKeyUpdates = append( s.onDuplicateKeyUpdates = append(
@ -361,7 +361,7 @@ func (s *insertStatementImpl) String() (sql string, err error) {
buf.String()) buf.String())
} }
if err = col.SerializeSqlForColumnList(buf); err != nil { if err = col.SerializeSql(buf, FOR_PROJECTION); err != nil {
return return
} }
} }
@ -413,12 +413,11 @@ func (s *insertStatementImpl) String() (sql string, err error) {
if colExpr.col == nil { if colExpr.col == nil {
return "", errors.Newf( return "", errors.Newf(
("nil column in on duplicate key update list. " + "nil column in on duplicate key update list. "+"Generated sql: %s",
"Generated sql: %s"),
buf.String()) buf.String())
} }
if err = colExpr.col.SerializeSqlForColumnList(buf); err != nil { if err = colExpr.col.SerializeSql(buf, FOR_PROJECTION); err != nil {
return return
} }
@ -426,8 +425,7 @@ func (s *insertStatementImpl) String() (sql string, err error) {
if colExpr.expr == nil { if colExpr.expr == nil {
return "", errors.Newf( return "", errors.Newf(
("nil expression in on duplicate key update list. " + "nil expression in on duplicate key update list. "+"Generated sql: %s",
"Generated sql: %s"),
buf.String()) buf.String())
} }
@ -447,14 +445,14 @@ func (s *insertStatementImpl) String() (sql string, err error) {
func newUpdateStatement(table WritableTable) UpdateStatement { func newUpdateStatement(table WritableTable) UpdateStatement {
return &updateStatementImpl{ return &updateStatementImpl{
table: table, table: table,
updateValues: make(map[NonAliasColumn]Expression), updateValues: make(map[Column]Expression),
limit: -1, limit: -1,
} }
} }
type updateStatementImpl struct { type updateStatementImpl struct {
table WritableTable table WritableTable
updateValues map[NonAliasColumn]Expression updateValues map[Column]Expression
where BoolExpression where BoolExpression
order *listClause order *listClause
limit int64 limit int64
@ -466,7 +464,7 @@ func (u *updateStatementImpl) Execute(db *sql.DB, data interface{}) error {
} }
func (u *updateStatementImpl) Set( func (u *updateStatementImpl) Set(
column NonAliasColumn, column Column,
expression Expression) UpdateStatement { expression Expression) UpdateStatement {
u.updateValues[column] = expression u.updateValues[column] = expression

View file

@ -12,22 +12,21 @@ import (
// are not supported. // are not supported.
type ReadableTable interface { type ReadableTable interface {
// Returns the list of columns that are in the current tableName expression. // Returns the list of columns that are in the current tableName expression.
Columns() []NonAliasColumn Columns() []Column
Column(name string) NonAliasColumn Column(name string) Column
// Generates the sql string for the current tableName expression. Note: the // Generates the sql string for the current tableName expression. Note: the
// generated string may not be a valid/executable sql statement. // generated string may not be a valid/executable sql statement.
// The database is the name of the database the tableName is on
SerializeSql(out *bytes.Buffer) error SerializeSql(out *bytes.Buffer) error
// Generates a select query on the current tableName. // Generates a select query on the current tableName.
Select(projections ...Projection) SelectStatement Select(projections ...Expression) SelectStatement
// Creates a inner join tableName expression using onCondition. // Creates a inner join tableName expression using onCondition.
InnerJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable InnerJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable
InnerJoinUsing(table ReadableTable, col1 Column, col2 Column) ReadableTable //InnerJoinUsing(table ReadableTable, col1 Column, col2 Column) ReadableTable
// Creates a left join tableName expression using onCondition. // Creates a left join tableName expression using onCondition.
LeftJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable LeftJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable
@ -35,7 +34,7 @@ type ReadableTable interface {
// Creates a right join tableName expression using onCondition. // Creates a right join tableName expression using onCondition.
RightJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable RightJoinOn(table ReadableTable, onCondition BoolExpression) ReadableTable
FullJoin(table ReadableTable, col1 Column, col2 Column) ReadableTable FullJoin(table ReadableTable, onCondition BoolExpression) ReadableTable
CrossJoin(table ReadableTable) ReadableTable CrossJoin(table ReadableTable) ReadableTable
} }
@ -43,21 +42,21 @@ type ReadableTable interface {
// The sql tableName write interface. // The sql tableName write interface.
type WritableTable interface { type WritableTable interface {
// Returns the list of columns that are in the tableName. // Returns the list of columns that are in the tableName.
Columns() []NonAliasColumn Columns() []Column
// Generates the sql string for the current tableName expression. Note: the // Generates the sql string for the current tableName expression. Note: the
// generated string may not be a valid/executable sql statement. // generated string may not be a valid/executable sql statement.
// The database is the name of the database the tableName is on // The database is the name of the database the tableName is on
SerializeSql(out *bytes.Buffer) error SerializeSql(out *bytes.Buffer) error
Insert(columns ...NonAliasColumn) InsertStatement Insert(columns ...Column) InsertStatement
Update() UpdateStatement Update() UpdateStatement
Delete() DeleteStatement Delete() DeleteStatement
} }
// Defines a physical tableName in the database that is both readable and writable. // Defines a physical tableName in the database that is both readable and writable.
// This function will panic if name is not valid // This function will panic if name is not valid
func NewTable(schemaName, name string, columns ...NonAliasColumn) *Table { func NewTable(schemaName, name string, columns ...Column) *Table {
if !validIdentifierName(name) { if !validIdentifierName(name) {
panic("Invalid tableName name") panic("Invalid tableName name")
} }
@ -66,7 +65,7 @@ func NewTable(schemaName, name string, columns ...NonAliasColumn) *Table {
schemaName: schemaName, schemaName: schemaName,
name: name, name: name,
columns: columns, columns: columns,
columnLookup: make(map[string]NonAliasColumn), columnLookup: make(map[string]Column),
} }
for _, c := range columns { for _, c := range columns {
err := c.setTableName(name) err := c.setTableName(name)
@ -87,21 +86,21 @@ type Table struct {
schemaName string schemaName string
name string name string
alias string alias string
columns []NonAliasColumn columns []Column
columnLookup map[string]NonAliasColumn columnLookup map[string]Column
// If not empty, the name of the index to force // If not empty, the name of the index to force
forcedIndex string forcedIndex string
} }
// Returns the specified column, or errors if it doesn't exist in the tableName // Returns the specified column, or errors if it doesn't exist in the tableName
func (t *Table) getColumn(name string) (NonAliasColumn, error) { func (t *Table) getColumn(name string) (Column, error) {
if c, ok := t.columnLookup[name]; ok { if c, ok := t.columnLookup[name]; ok {
return c, nil return c, nil
} }
return nil, errors.Newf("No such column '%s' in tableName '%s'", name, t.name) return nil, errors.Newf("No such column '%s' in tableName '%s'", name, t.name)
} }
func (t *Table) Column(name string) NonAliasColumn { func (t *Table) Column(name string) Column {
return &baseColumn{ return &baseColumn{
name: name, name: name,
nullable: NotNullable, nullable: NotNullable,
@ -109,9 +108,9 @@ func (t *Table) Column(name string) NonAliasColumn {
} }
} }
// Returns all columns for a tableName as a slice of projections // Returns all expresssion for a tableName as a slice of projections
func (t *Table) Projections() []Projection { func (t *Table) Projections() []Expression {
result := make([]Projection, 0) result := make([]Expression, 0)
for _, col := range t.columns { for _, col := range t.columns {
col.Asc() col.Asc()
@ -142,7 +141,7 @@ func (t *Table) SchemaName() string {
} }
// Returns a list of the tableName's columns // Returns a list of the tableName's columns
func (t *Table) Columns() []NonAliasColumn { func (t *Table) Columns() []Column {
return t.columns return t.columns
} }
@ -182,7 +181,7 @@ func (t *Table) SerializeSql(out *bytes.Buffer) error {
} }
// Generates a select query on the current tableName. // Generates a select query on the current tableName.
func (t *Table) Select(projections ...Projection) SelectStatement { func (t *Table) Select(projections ...Expression) SelectStatement {
return newSelectStatement(t, projections) return newSelectStatement(t, projections)
} }
@ -194,13 +193,13 @@ func (t *Table) InnerJoinOn(
return InnerJoinOn(t, table, onCondition) return InnerJoinOn(t, table, onCondition)
} }
func (t *Table) InnerJoinUsing( //func (t *Table) InnerJoinUsing(
table ReadableTable, // table ReadableTable,
col1 Column, // col1 Column,
col2 Column) ReadableTable { // col2 Column) ReadableTable {
//
return InnerJoinOn(t, table, col1.Eq(col2)) // return InnerJoinOn(t, table, col1.Eq(col2))
} //}
// Creates a left join tableName expression using onCondition. // Creates a left join tableName expression using onCondition.
func (t *Table) LeftJoinOn( func (t *Table) LeftJoinOn(
@ -218,15 +217,15 @@ func (t *Table) RightJoinOn(
return RightJoinOn(t, table, onCondition) return RightJoinOn(t, table, onCondition)
} }
func (t *Table) FullJoin(table ReadableTable, col1, col2 Column) ReadableTable { func (t *Table) FullJoin(table ReadableTable, onCondition BoolExpression) ReadableTable {
return FullJoin(t, table, col1.Eq(col2)) return FullJoin(t, table, onCondition)
} }
func (t *Table) CrossJoin(table ReadableTable) ReadableTable { func (t *Table) CrossJoin(table ReadableTable) ReadableTable {
return CrossJoin(t, table) return CrossJoin(t, table)
} }
func (t *Table) Insert(columns ...NonAliasColumn) InsertStatement { func (t *Table) Insert(columns ...Column) InsertStatement {
return newInsertStatement(t, columns...) return newInsertStatement(t, columns...)
} }
@ -309,15 +308,15 @@ func CrossJoin(
return newJoinTable(lhs, rhs, CROSS_JOIN, nil) return newJoinTable(lhs, rhs, CROSS_JOIN, nil)
} }
func (t *joinTable) Columns() []NonAliasColumn { func (t *joinTable) Columns() []Column {
columns := make([]NonAliasColumn, 0) columns := make([]Column, 0)
columns = append(columns, t.lhs.Columns()...) columns = append(columns, t.lhs.Columns()...)
columns = append(columns, t.rhs.Columns()...) columns = append(columns, t.rhs.Columns()...)
return columns return columns
} }
func (t *joinTable) Column(name string) NonAliasColumn { func (t *joinTable) Column(name string) Column {
panic("Not implemented") panic("Not implemented")
} }
@ -364,7 +363,7 @@ func (t *joinTable) SerializeSql(out *bytes.Buffer) (err error) {
return nil return nil
} }
func (t *joinTable) Select(projections ...Projection) SelectStatement { func (t *joinTable) Select(projections ...Expression) SelectStatement {
return newSelectStatement(t, projections) return newSelectStatement(t, projections)
} }
@ -375,14 +374,6 @@ func (t *joinTable) InnerJoinOn(
return InnerJoinOn(t, table, onCondition) return InnerJoinOn(t, table, onCondition)
} }
func (t *joinTable) InnerJoinUsing(
table ReadableTable,
col1 Column,
col2 Column) ReadableTable {
return InnerJoinOn(t, table, col1.Eq(col2))
}
func (t *joinTable) LeftJoinOn( func (t *joinTable) LeftJoinOn(
table ReadableTable, table ReadableTable,
onCondition BoolExpression) ReadableTable { onCondition BoolExpression) ReadableTable {
@ -390,8 +381,8 @@ func (t *joinTable) LeftJoinOn(
return LeftJoinOn(t, table, onCondition) return LeftJoinOn(t, table, onCondition)
} }
func (t *joinTable) FullJoin(table ReadableTable, col1 Column, col2 Column) ReadableTable { func (t *joinTable) FullJoin(table ReadableTable, onCondition BoolExpression) ReadableTable {
return FullJoin(t, table, col1.Eq(col2)) return FullJoin(t, table, onCondition)
} }
func (t *joinTable) CrossJoin(table ReadableTable) ReadableTable { func (t *joinTable) CrossJoin(table ReadableTable) ReadableTable {

View file

@ -1,3 +1,5 @@
// +build disabled
package sqlbuilder package sqlbuilder
var table1Col1 = IntColumn("col1", Nullable) var table1Col1 = IntColumn("col1", Nullable)

View file

@ -1,9 +1,5 @@
package sqlbuilder package sqlbuilder
import (
"bytes"
)
// A clause that can be used in order by // A clause that can be used in order by
type OrderByClause interface { type OrderByClause interface {
Clause Clause
@ -11,43 +7,43 @@ type OrderByClause interface {
} }
// A clause that is selectable. // A clause that is selectable.
type Projection interface { //type Projection interface {
Clause // Clause
isProjectionInterface // isProjectionInterface
//
// SerializeSqlForColumnList(out *bytes.Buffer) error
//}
SerializeSqlForColumnList(out *bytes.Buffer) error //type ColumnList []Column
} //
//func (cl ColumnList) SerializeSql(out *bytes.Buffer, options ...serializeOption) error {
// for i, column := range cl {
// column.SerializeSql(out)
//
// if i != len(cl)-1 {
// out.WriteString(", ")
// }
// }
// return nil
//}
//
//func (cl ColumnList) isProjectionType() {
//}
//
//func (cl ColumnList) As(name string) Clause {
// panic("Unallowed operation ")
//}
type ColumnList []NonAliasColumn //func (cl ColumnList) SerializeSqlForColumnList(out *bytes.Buffer) error {
// for i, column := range cl {
func (cl ColumnList) SerializeSql(out *bytes.Buffer) error { // column.SerializeSqlForColumnList(out)
for i, column := range cl { //
column.SerializeSql(out) // if i != len(cl)-1 {
// out.WriteString(", ")
if i != len(cl)-1 { // }
out.WriteString(", ") // }
} // return nil
} //}
return nil
}
func (cl ColumnList) isProjectionType() {
}
func (cl ColumnList) As(name string) Projection {
panic("Unallowed operation ")
}
func (cl ColumnList) SerializeSqlForColumnList(out *bytes.Buffer) error {
for i, column := range cl {
column.SerializeSqlForColumnList(out)
if i != len(cl)-1 {
out.WriteString(", ")
}
}
return nil
}
// //
// Boiler plates ... // Boiler plates ...
@ -63,12 +59,13 @@ type isOrderByClause struct {
func (o *isOrderByClause) isOrderByClauseType() { func (o *isOrderByClause) isOrderByClauseType() {
} }
type isProjectionInterface interface { //
isProjectionType() //type isProjectionInterface interface {
} // isProjectionType()
//}
type isProjection struct { //
} //type isProjection struct {
//}
func (p *isProjection) isProjectionType() { //
} //func (p *isProjection) isProjectionType() {
//}

1
sqlbuilder/utils.go Normal file
View file

@ -0,0 +1 @@
package sqlbuilder