Refactoring to support parameterized queries.

This commit is contained in:
zer0sub 2019-04-29 14:39:48 +02:00
parent bc6a2bbcac
commit fef8f0ef83
33 changed files with 1112 additions and 1206 deletions

View file

@ -1,7 +1,5 @@
package sqlbuilder package sqlbuilder
import "bytes"
type Alias struct { type Alias struct {
expression Expression expression Expression
alias string alias string
@ -14,9 +12,9 @@ func NewAlias(expression Expression, alias string) *Alias {
} }
} }
func (a *Alias) SerializeForProjection(out *bytes.Buffer) error { func (a *Alias) SerializeForProjection(out *queryData) error {
err := a.expression.SerializeSql(out, ALIASED) err := a.expression.Serialize(out, SKIP_DEFAULT_ALIASING)
if err != nil { if err != nil {
return err return err

View file

@ -66,11 +66,7 @@ type boolLiteralExpression struct {
func newBoolLiteralExpression(value bool) BoolExpression { func newBoolLiteralExpression(value bool) BoolExpression {
boolLiteralExpression := boolLiteralExpression{} boolLiteralExpression := boolLiteralExpression{}
sqlValue, err := sqltypes.BuildValue(value) boolLiteralExpression.literalExpression = *Literal(value)
if err != nil {
panic(errors.Wrap(err, "Invalid literal value"))
}
boolLiteralExpression.literalExpression = *NewLiteralExpression(sqlValue)
boolLiteralExpression.boolInterfaceImpl.parent = &boolLiteralExpression boolLiteralExpression.boolInterfaceImpl.parent = &boolLiteralExpression
return &boolLiteralExpression return &boolLiteralExpression
@ -113,27 +109,27 @@ func newPrefixBoolExpression(expression Expression, operator []byte) BoolExpress
} }
//---------------------------------------------------// //---------------------------------------------------//
type conjunctBoolExpression struct { //type conjunctBoolExpression struct {
expressionInterfaceImpl // expressionInterfaceImpl
boolInterfaceImpl // boolInterfaceImpl
//
conjunctExpression // conjunctExpression
name string // name string
} //}
//
func NewConjunctBoolExpression(operator []byte, expressions ...BoolExpression) BoolExpression { //func NewConjunctBoolExpression(operator []byte, expressions ...BoolExpression) BoolExpression {
boolExpression := conjunctBoolExpression{ // boolExpression := conjunctBoolExpression{
conjunctExpression: conjunctExpression{ // conjunctExpression: conjunctExpression{
expressions: expressions, // expressions: expressions,
conjunction: operator, // conjunction: operator,
}, // },
} // }
//
boolExpression.expressionInterfaceImpl.parent = &boolExpression // boolExpression.expressionInterfaceImpl.parent = &boolExpression
boolExpression.boolInterfaceImpl.parent = &boolExpression // boolExpression.boolInterfaceImpl.parent = &boolExpression
//
return &boolExpression // return &boolExpression
} //}
//---------------------------------------------------// //---------------------------------------------------//
type inExpression struct { type inExpression struct {
@ -146,34 +142,33 @@ type inExpression struct {
err error err error
} }
func (c *inExpression) SerializeSql(out *bytes.Buffer, options ...serializeOption) error { func (c *inExpression) Serialize(out *queryData, options ...serializeOption) error {
if c.err != nil { if c.err != nil {
return errors.Wrap(c.err, "Invalid IN expression") return errors.Wrap(c.err, "Invalid IN expression")
} }
if c.lhs == nil { if c.lhs == nil {
return errors.Newf( return errors.Newf("lhs of in expression is nil.")
"lhs of in expression is nil. Generated sql: %s",
out.String())
} }
// We'll serialize the lhs even if we don't need it to ensure no error // We'll serialize the lhs even if we don't need it to ensure no error
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := c.lhs.SerializeSql(buf) err := c.lhs.Serialize(out, options...)
if err != nil { if err != nil {
return err return err
} }
if c.rhs == nil { if c.rhs == nil {
_, _ = out.WriteString("FALSE") out.WriteString("FALSE")
return nil return nil
} }
_, _ = out.WriteString(buf.String()) out.WriteString(buf.String())
_, _ = out.WriteString(" IN ") out.WriteString(" IN ")
err = c.rhs.Serialize(out)
err = c.rhs.SerializeSql(out)
if err != nil { if err != nil {
return err return err
} }
@ -183,10 +178,6 @@ func (c *inExpression) SerializeSql(out *bytes.Buffer, options ...serializeOptio
// Returns a representation of "a=b" // Returns a representation of "a=b"
func Eq(lhs, rhs Expression) BoolExpression { func Eq(lhs, rhs Expression) BoolExpression {
lit, ok := rhs.(*literalExpression)
if ok && sqltypes.Value(lit.value).IsNull() {
return newBinaryBoolExpression(lhs, rhs, []byte(" IS "))
}
return newBinaryBoolExpression(lhs, rhs, []byte(" = ")) return newBinaryBoolExpression(lhs, rhs, []byte(" = "))
} }
@ -197,10 +188,6 @@ func EqL(lhs Expression, val interface{}) BoolExpression {
// Returns a representation of "a!=b" // Returns a representation of "a!=b"
func NotEq(lhs, rhs Expression) BoolExpression { func NotEq(lhs, rhs Expression) BoolExpression {
lit, ok := rhs.(*literalExpression)
if ok && sqltypes.Value(lit.value).IsNull() {
return newBinaryBoolExpression(lhs, rhs, []byte(" IS NOT "))
}
return newBinaryBoolExpression(lhs, rhs, []byte("!=")) return newBinaryBoolExpression(lhs, rhs, []byte("!="))
} }
@ -258,14 +245,13 @@ func IsTrue(expr BoolExpression) BoolExpression {
return newPrefixBoolExpression(expr, []byte(" IS TRUE ")) return newPrefixBoolExpression(expr, []byte(" IS TRUE "))
} }
// Returns a representation of "c[0] AND ... AND c[n-1]" for c in clauses func And(lhs, rhs Expression) BoolExpression {
func And(expressions ...BoolExpression) BoolExpression { return newBinaryBoolExpression(lhs, rhs, []byte(" AND "))
return NewConjunctBoolExpression([]byte(" AND "), expressions...)
} }
// Returns a representation of "c[0] OR ... OR c[n-1]" for c in clauses // Returns a representation of "c[0] OR ... OR c[n-1]" for c in clauses
func Or(expressions ...BoolExpression) BoolExpression { func Or(lhs, rhs Expression) BoolExpression {
return NewConjunctBoolExpression([]byte(" OR "), expressions...) return newBinaryBoolExpression(lhs, rhs, []byte(" OR "))
} }
func Like(lhs, rhs Expression) BoolExpression { func Like(lhs, rhs Expression) BoolExpression {

View file

@ -10,7 +10,7 @@ func TestBinaryExpression(t *testing.T) {
boolExpression := Eq(Literal(2), Literal(3)) boolExpression := Eq(Literal(2), Literal(3))
out := bytes.Buffer{} out := bytes.Buffer{}
err := boolExpression.SerializeSql(&out) err := boolExpression.Serialize(&out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.String(), "2 = 3") assert.Equal(t, out.String(), "2 = 3")
@ -29,7 +29,7 @@ func TestBinaryExpression(t *testing.T) {
exp := boolExpression.And(Eq(Literal(4), Literal(5))) exp := boolExpression.And(Eq(Literal(4), Literal(5)))
out := bytes.Buffer{} out := bytes.Buffer{}
err := exp.SerializeSql(&out) err := exp.Serialize(&out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.String(), `(2 = 3 AND 4 = 5)`) assert.Equal(t, out.String(), `(2 = 3 AND 4 = 5)`)
@ -39,7 +39,7 @@ func TestBinaryExpression(t *testing.T) {
exp := boolExpression.Or(Eq(Literal(4), Literal(5))) exp := boolExpression.Or(Eq(Literal(4), Literal(5)))
out := bytes.Buffer{} out := bytes.Buffer{}
err := exp.SerializeSql(&out) err := exp.Serialize(&out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.String(), `(2 = 3 OR 4 = 5)`) assert.Equal(t, out.String(), `(2 = 3 OR 4 = 5)`)
@ -50,7 +50,7 @@ func TestUnaryExpression(t *testing.T) {
notExpression := Not(Eq(Literal(2), Literal(1))) notExpression := Not(Eq(Literal(2), Literal(1)))
out := bytes.Buffer{} out := bytes.Buffer{}
err := notExpression.SerializeSql(&out) err := notExpression.Serialize(&out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.String(), " NOT 2 = 1") assert.Equal(t, out.String(), " NOT 2 = 1")
@ -69,7 +69,7 @@ func TestUnaryExpression(t *testing.T) {
exp := notExpression.And(Eq(Literal(4), Literal(5))) exp := notExpression.And(Eq(Literal(4), Literal(5)))
out := bytes.Buffer{} out := bytes.Buffer{}
err := exp.SerializeSql(&out) err := exp.Serialize(&out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.String(), `( NOT 2 = 1 AND 4 = 5)`) assert.Equal(t, out.String(), `( NOT 2 = 1 AND 4 = 5)`)
@ -80,7 +80,7 @@ func TestUnaryIsTrueExpression(t *testing.T) {
notExpression := IsTrue(Eq(Literal(2), Literal(1))) notExpression := IsTrue(Eq(Literal(2), Literal(1)))
out := bytes.Buffer{} out := bytes.Buffer{}
err := notExpression.SerializeSql(&out) err := notExpression.Serialize(&out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.String(), " IS TRUE 2 = 1") assert.Equal(t, out.String(), " IS TRUE 2 = 1")
@ -89,7 +89,7 @@ func TestUnaryIsTrueExpression(t *testing.T) {
exp := notExpression.And(Eq(Literal(4), Literal(5))) exp := notExpression.And(Eq(Literal(4), Literal(5)))
out := bytes.Buffer{} out := bytes.Buffer{}
err := exp.SerializeSql(&out) err := exp.Serialize(&out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.String(), `( IS TRUE 2 = 1 AND 4 = 5)`) assert.Equal(t, out.String(), `( IS TRUE 2 = 1 AND 4 = 5)`)
@ -100,7 +100,7 @@ func TestBoolLiteral(t *testing.T) {
literal := newBoolLiteralExpression(true) literal := newBoolLiteralExpression(true)
out := bytes.Buffer{} out := bytes.Buffer{}
err := literal.SerializeSql(&out) err := literal.Serialize(&out)
assert.NilError(t, err) assert.NilError(t, err)

View file

@ -1,16 +1,91 @@
package sqlbuilder package sqlbuilder
import "bytes" import (
"bytes"
"errors"
"strconv"
)
type serializeOption int type serializeOption int
const ( const (
ALIASED = iota SKIP_DEFAULT_ALIASING = iota
FOR_PROJECTION FOR_PROJECTION
) )
type Clause interface { type Clause interface {
SerializeSql(out *bytes.Buffer, options ...serializeOption) error Serialize(out *queryData, options ...serializeOption) error
}
type queryData struct {
queryBuff bytes.Buffer
args []interface{}
}
func (q *queryData) Write(data []byte) {
q.queryBuff.Write(data)
}
func (q *queryData) WriteString(str string) {
q.queryBuff.WriteString(str)
}
func (q *queryData) WriteByte(b byte) {
q.queryBuff.WriteByte(b)
}
func (q *queryData) InsertArgument(arg interface{}) {
q.args = append(q.args, arg)
argPlaceholder := "$" + strconv.Itoa(len(q.args))
q.queryBuff.WriteString(argPlaceholder)
}
func argToString(value interface{}) (string, error) {
switch bindVal := value.(type) {
case bool:
if bindVal {
return "TRUE", nil
} else {
return "FALSE", nil
}
case int8:
return strconv.FormatInt(int64(bindVal), 10), nil
case int:
return strconv.FormatInt(int64(bindVal), 10), nil
case int16:
return strconv.FormatInt(int64(bindVal), 10), nil
case int32:
return strconv.FormatInt(int64(bindVal), 10), nil
case int64:
return strconv.FormatInt(int64(bindVal), 10), nil
case uint8:
return strconv.FormatUint(uint64(bindVal), 10), nil
case uint:
return strconv.FormatUint(uint64(bindVal), 10), nil
case uint16:
return strconv.FormatUint(uint64(bindVal), 10), nil
case uint32:
return strconv.FormatUint(uint64(bindVal), 10), nil
case uint64:
return strconv.FormatUint(uint64(bindVal), 10), nil
case float32:
return strconv.FormatFloat(float64(bindVal), 'f', -1, 64), nil
case float64:
return strconv.FormatFloat(float64(bindVal), 'f', -1, 64), nil
case string:
return bindVal, nil
case []byte:
return string(bindVal), nil
//TODO: implement
//case time.Time:
// return bindVal.String())
default:
return "", errors.New("Unsupported literal type. ")
}
} }
func contains(s []serializeOption, e serializeOption) bool { func contains(s []serializeOption, e serializeOption) bool {

View file

@ -3,14 +3,9 @@
package sqlbuilder package sqlbuilder
import ( import (
"bytes"
"regexp"
"strings" "strings"
) )
// XXX: Maybe add UIntColumn
// Representation of a tableName for query generation
type Column interface { type Column interface {
Expression Expression
@ -28,11 +23,6 @@ const (
NotNullable NullableColumn = false NotNullable NullableColumn = false
) )
//// A column that can be refer to outside of the projection list
//type NonAliasColumn interface {
// Column
//}
type Collation string type Collation string
const ( const (
@ -82,194 +72,39 @@ func (c *baseColumn) setTableName(table string) error {
return nil return nil
} }
func (c baseColumn) SerializeSql(out *bytes.Buffer, options ...serializeOption) error { func (c baseColumn) Serialize(out *queryData, options ...serializeOption) error {
if c.tableName != "" { if c.tableName != "" {
_, _ = out.WriteString(c.tableName) out.WriteString(c.tableName)
_, _ = out.WriteString(".") out.WriteString(".")
} }
containsDot := strings.Contains(c.name, ".") containsDot := strings.Contains(c.name, ".")
if containsDot { if containsDot {
out.WriteString("\"") out.WriteString(`"`)
}
_, _ = out.WriteString(c.name)
if containsDot {
out.WriteString("\"")
} }
if contains(options, FOR_PROJECTION) && !contains(options, ALIASED) && c.tableName != "" { out.WriteString(c.name)
_, _ = out.WriteString(" AS \"" + c.tableName + "." + c.name + "\"")
if containsDot {
out.WriteString(`"`)
}
if contains(options, FOR_PROJECTION) && !contains(options, SKIP_DEFAULT_ALIASING) && c.tableName != "" {
out.WriteString(" AS \"" + c.tableName + "." + c.name + "\"")
} }
return nil return nil
} }
// //
//type bytesColumn struct { //// This is a strict subset of the actual allowed identifiers
// baseColumn //var validIdentifierRegexp = regexp.MustCompile("^[a-zA-Z_]\\w*$")
//}
// //
//// Representation of VARBINARY/BLOB columns //// Returns true if the given string is suitable as an identifier.
//// This function will panic if name is not valid //func validIdentifierName(name string) bool {
//func BytesColumn(name string, nullable NullableColumn) Column { // return validIdentifierRegexp.MatchString(name)
// if !validIdentifierName(name) {
// panic("Invalid column name in bytes column")
// }
// bc := &bytesColumn{}
// bc.name = name
// bc.nullable = nullable
// return bc
//} //}
//
//type stringColumn struct {
// baseColumn
// charset Charset
// collation Collation
//}
//
//// Representation of VARCHAR/TEXT columns
//// This function will panic if name is not valid
//func StrColumn(
// name string,
// charset Charset,
// collation Collation,
// nullable NullableColumn) Column {
//
// if !validIdentifierName(name) {
// panic("Invalid column name in str column")
// }
// sc := &stringColumn{charset: charset, collation: collation}
// sc.name = name
// sc.nullable = nullable
// return sc
//}
//
//type dateTimeColumn struct {
// baseColumn
//}
//
//// Representation of DateTime columns
//// This function will panic if name is not valid
//func DateTimeColumn(name string, nullable NullableColumn) Column {
// if !validIdentifierName(name) {
// panic("Invalid column name in datetime column")
// }
// dc := &dateTimeColumn{}
// dc.name = name
// dc.nullable = nullable
// return dc
//}
//type IntegerColumn struct {
// baseColumn
//}
//
//// Representation of any integer column
//// This function will panic if name is not valid
//func IntColumn(name string, nullable NullableColumn) *IntegerColumn {
// if !validIdentifierName(name) {
// panic("Invalid column name in int column")
// }
// ic := &IntegerColumn{}
// ic.name = name
// ic.nullable = nullable
// return ic
//}
//type doubleColumn struct {
// baseColumn
//}
//
//// Representation of any double column
//// This function will panic if name is not valid
//func DoubleColumn(name string, nullable NullableColumn) Column {
// if !validIdentifierName(name) {
// panic("Invalid column name in int column")
// }
// ic := &doubleColumn{}
// ic.name = name
// ic.nullable = nullable
// return ic
//}
//
//type booleanColumn struct {
// baseColumn
//
// // XXX: Maybe allow isBoolExpression (for now, not included because
// // the deferred lookup equivalent can never be isBoolExpression)
//}
// Representation of TINYINT used as a bool
// This function will panic if name is not valid
//func NewBoolColumn(name string, nullable NullableColumn) Column {
// if !validIdentifierName(name) {
// panic("Invalid column name in bool column")
// }
// bc := &booleanColumn{}
// bc.name = name
// bc.nullable = nullable
// return bc
//}
//
//type aliasColumn struct {
// baseColumn
// expression Expression
//}
//
//func (c *aliasColumn) SerializeSql(out *bytes.Buffer) error {
// _ = out.WriteByte('`')
// _, _ = out.WriteString(c.name)
// _ = out.WriteByte('`')
// return nil
//}
//
//func (c *aliasColumn) SerializeSqlForColumnList(out *bytes.Buffer) error {
// if !validIdentifierName(c.name) {
// return errors.Newf(
// "Invalid alias name `%s`. Generated sql: %s",
// c.name,
// out.String())
// }
// if c.expression == nil {
// return errors.Newf(
// "Cannot alias a nil expression. Generated sql: %s",
// out.String())
// }
//
// _ = out.WriteByte('(')
// if c.expression == nil {
// return errors.Newf("nil alias clause. Generate sql: %s", out.String())
// }
// if err := c.expression.SerializeSql(out); err != nil {
// return err
// }
// _, _ = out.WriteString(") AS \"")
// _, _ = out.WriteString(c.name)
// _ = out.WriteByte('"')
// return nil
//}
//func (c *aliasColumn) setTableName(table string) error {
// return errors.Newf(
// "Alias column '%s' should never have setTableName called on it",
// c.name)
//}
// Representation of aliased clauses (expression AS name)
//func Alias(name string, c Expression) Column {
// ac := &aliasColumn{}
// ac.name = name
// ac.expression = c
// return ac
//}
// This is a strict subset of the actual allowed identifiers
var validIdentifierRegexp = regexp.MustCompile("^[a-zA-Z_]\\w*$")
// Returns true if the given string is suitable as an identifier.
func validIdentifierName(name string) bool {
return validIdentifierRegexp.MatchString(name)
}
// //
//// Pseudo Column type returned by tableName.C(name) //// Pseudo Column type returned by tableName.C(name)
@ -289,12 +124,12 @@ func validIdentifierName(name string) bool {
//func (c *deferredLookupColumn) SerializeSqlForColumnList( //func (c *deferredLookupColumn) SerializeSqlForColumnList(
// out *bytes.Buffer) error { // out *bytes.Buffer) error {
// //
// return c.SerializeSql(out) // return c.Serialize(out)
//} //}
// //
//func (c *deferredLookupColumn) SerializeSql(out *bytes.Buffer) error { //func (c *deferredLookupColumn) Serialize(out *bytes.Buffer) error {
// if c.cachedColumn != nil { // if c.cachedColumn != nil {
// return c.cachedColumn.SerializeSql(out) // return c.cachedColumn.Serialize(out)
// } // }
// //
// col, err := c.tableName.getColumn(c.colName) // col, err := c.tableName.getColumn(c.colName)
@ -303,7 +138,7 @@ func validIdentifierName(name string) bool {
// } // }
// //
// c.cachedColumn = col // c.cachedColumn = col
// return col.SerializeSql(out) // return col.Serialize(out)
//} //}
// //
//func (c *deferredLookupColumn) setTableName(tableName string) error { //func (c *deferredLookupColumn) setTableName(tableName string) error {

View file

@ -8,9 +8,7 @@ type BoolColumn struct {
} }
func NewBoolColumn(name string, nullable NullableColumn) *BoolColumn { func NewBoolColumn(name string, nullable NullableColumn) *BoolColumn {
if !validIdentifierName(name) {
panic("Invalid column name in bool column")
}
boolColumn := &BoolColumn{} boolColumn := &BoolColumn{}
boolColumn.baseColumn = newBaseColumn(name, nullable, "", boolColumn) boolColumn.baseColumn = newBaseColumn(name, nullable, "", boolColumn)
@ -26,9 +24,6 @@ type NumericColumn struct {
} }
func NewNumericColumn(name string, nullable NullableColumn) *NumericColumn { func NewNumericColumn(name string, nullable NullableColumn) *NumericColumn {
if !validIdentifierName(name) {
panic("Invalid column name")
}
numericColumn := &NumericColumn{} numericColumn := &NumericColumn{}
@ -70,9 +65,6 @@ type StringColumn struct {
// Representation of any integer column // Representation of any integer column
// This function will panic if name is not valid // This function will panic if name is not valid
func NewStringColumn(name string, nullable NullableColumn) *StringColumn { func NewStringColumn(name string, nullable NullableColumn) *StringColumn {
if !validIdentifierName(name) {
panic("Invalid column name")
}
stringColumn := &StringColumn{} stringColumn := &StringColumn{}

View file

@ -10,20 +10,20 @@ func TestNewBoolColumn(t *testing.T) {
boolColumn := NewBoolColumn("col", Nullable) boolColumn := NewBoolColumn("col", Nullable)
out := bytes.Buffer{} out := bytes.Buffer{}
err := boolColumn.SerializeSql(&out) err := boolColumn.Serialize(&out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.String(), "col") assert.Equal(t, out.String(), "col")
out.Reset() out.Reset()
err = boolColumn.SerializeSql(&out, FOR_PROJECTION) err = boolColumn.Serialize(&out, FOR_PROJECTION)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.String(), "col") assert.Equal(t, out.String(), "col")
out.Reset() out.Reset()
err = boolColumn.setTableName("table1") err = boolColumn.setTableName("table1")
assert.NilError(t, err) assert.NilError(t, err)
err = boolColumn.SerializeSql(&out, FOR_PROJECTION) err = boolColumn.Serialize(&out, FOR_PROJECTION)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.String(), `table1.col AS "table1.col"`) assert.Equal(t, out.String(), `table1.col AS "table1.col"`)
@ -40,20 +40,20 @@ func TestNewIntColumn(t *testing.T) {
integerColumn := NewIntegerColumn("col", Nullable) integerColumn := NewIntegerColumn("col", Nullable)
out := bytes.Buffer{} out := bytes.Buffer{}
err := integerColumn.SerializeSql(&out) err := integerColumn.Serialize(&out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.String(), "col") assert.Equal(t, out.String(), "col")
out.Reset() out.Reset()
err = integerColumn.SerializeSql(&out, FOR_PROJECTION) err = integerColumn.Serialize(&out, FOR_PROJECTION)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.String(), "col") assert.Equal(t, out.String(), "col")
out.Reset() out.Reset()
err = integerColumn.setTableName("table1") err = integerColumn.setTableName("table1")
assert.NilError(t, err) assert.NilError(t, err)
err = integerColumn.SerializeSql(&out, FOR_PROJECTION) err = integerColumn.Serialize(&out, FOR_PROJECTION)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.String(), `table1.col AS "table1.col"`) assert.Equal(t, out.String(), `table1.col AS "table1.col"`)
@ -70,20 +70,20 @@ func TestNewNumericColumnColumn(t *testing.T) {
numericColumn := NewNumericColumn("col", Nullable) numericColumn := NewNumericColumn("col", Nullable)
out := bytes.Buffer{} out := bytes.Buffer{}
err := numericColumn.SerializeSql(&out) err := numericColumn.Serialize(&out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.String(), "col") assert.Equal(t, out.String(), "col")
out.Reset() out.Reset()
err = numericColumn.SerializeSql(&out) err = numericColumn.Serialize(&out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.String(), "col") assert.Equal(t, out.String(), "col")
out.Reset() out.Reset()
err = numericColumn.setTableName("table1") err = numericColumn.setTableName("table1")
assert.NilError(t, err) assert.NilError(t, err)
err = numericColumn.SerializeSql(&out, FOR_PROJECTION) err = numericColumn.Serialize(&out, FOR_PROJECTION)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.String(), `table1.col AS "table1.col"`) assert.Equal(t, out.String(), `table1.col AS "table1.col"`)

View file

@ -1,7 +1,6 @@
package sqlbuilder package sqlbuilder
import ( import (
"bytes"
"database/sql" "database/sql"
"github.com/dropbox/godropbox/errors" "github.com/dropbox/godropbox/errors"
"github.com/sub0zero/go-sqlbuilder/types" "github.com/sub0zero/go-sqlbuilder/types"
@ -38,33 +37,35 @@ func (d *deleteStatementImpl) WHERE(expression BoolExpression) DeleteStatement {
return d return d
} }
func (d *deleteStatementImpl) String() (sql string, err error) { func (d *deleteStatementImpl) Sql() (query string, args []interface{}, err error) {
buf := new(bytes.Buffer) queryData := &queryData{}
_, _ = buf.WriteString("DELETE FROM ")
queryData.WriteString("DELETE FROM ")
if d.table == nil { if d.table == nil {
return "", errors.Newf("nil tableName. Generated sql: %s", buf.String()) return "", nil, errors.New("nil tableName.")
} }
if err = d.table.SerializeSql(buf); err != nil { if err = d.table.SerializeSql(queryData); err != nil {
return return
} }
if d.where == nil { if d.where == nil {
return "", errors.Newf("Deleting without a WHERE clause. Generated sql: %s", buf.String()) return "", nil, errors.New("Deleting without a WHERE clause.")
} }
_, _ = buf.WriteString(" WHERE ") queryData.WriteString(" WHERE ")
if err = d.where.SerializeSql(buf); err != nil {
if err = d.where.Serialize(queryData); err != nil {
return return
} }
if d.order != nil { if d.order != nil {
_, _ = buf.WriteString(" ORDER BY ") queryData.WriteString(" ORDER BY ")
if err = d.order.SerializeSql(buf); err != nil { if err = d.order.Serialize(queryData); err != nil {
return return
} }
} }
return buf.String() + ";", nil return queryData.queryBuff.String() + ";", queryData.args, nil
} }

View file

@ -14,7 +14,7 @@ import (
"time" "time"
) )
func Execute(db types.Db, query string, destinationPtr interface{}) error { func Query(db types.Db, query string, args []interface{}, destinationPtr interface{}) error {
if db == nil { if db == nil {
return errors.New("db is nil") return errors.New("db is nil")
} }
@ -28,7 +28,7 @@ func Execute(db types.Db, query string, destinationPtr interface{}) error {
return errors.New("Destination has to be a pointer to slice or pointer to struct ") return errors.New("Destination has to be a pointer to slice or pointer to struct ")
} }
rows, err := db.Query(query) rows, err := db.Query(query, args...)
if err != nil { if err != nil {
return err return err
@ -72,7 +72,7 @@ func Execute(db types.Db, query string, destinationPtr interface{}) error {
return err return err
} }
fmt.Println(strconv.Itoa(scanContext.rowNum) + " ROWS PROCESSED") fmt.Println(strconv.Itoa(scanContext.rowNum) + " ROW(S) PROCESSED")
return nil return nil
} }

View file

@ -1,8 +1,6 @@
package sqlbuilder package sqlbuilder
import ( import (
"bytes"
"github.com/dropbox/godropbox/database/sqltypes"
"github.com/dropbox/godropbox/errors" "github.com/dropbox/godropbox/errors"
) )
@ -42,8 +40,8 @@ func (e *expressionInterfaceImpl) Desc() OrderByClause {
return &orderByClause{expression: e.parent, ascent: false} return &orderByClause{expression: e.parent, ascent: false}
} }
func (e *expressionInterfaceImpl) SerializeForProjection(out *bytes.Buffer) error { func (e *expressionInterfaceImpl) SerializeForProjection(out *queryData) error {
return e.parent.SerializeSql(out, FOR_PROJECTION) return e.parent.Serialize(out, FOR_PROJECTION)
} }
// Representation of binary operations (e.g. comparisons, arithmetic) // Representation of binary operations (e.g. comparisons, arithmetic)
@ -62,21 +60,21 @@ func newBinaryExpression(lhs, rhs Expression, operator []byte, parent ...Express
return binaryExpression return binaryExpression
} }
func (c *binaryExpression) SerializeSql(out *bytes.Buffer, options ...serializeOption) (err error) { func (c *binaryExpression) Serialize(out *queryData, options ...serializeOption) error {
if c.lhs == nil { if c.lhs == nil {
return errors.Newf("nil lhs. Generated sql: %s", out.String()) return errors.Newf("nil lhs.")
} }
if err = c.lhs.SerializeSql(out); err != nil { if err := c.lhs.Serialize(out); err != nil {
return return err
} }
_, _ = out.Write(c.operator) out.Write(c.operator)
if c.rhs == nil { if c.rhs == nil {
return errors.Newf("nil rhs. Generated sql: %s", out.String()) return errors.Newf("nil rhs.")
} }
if err = c.rhs.SerializeSql(out); err != nil { if err := c.rhs.Serialize(out); err != nil {
return return err
} }
return nil return nil
@ -97,80 +95,61 @@ func newPrefixExpression(expression Expression, operator []byte) prefixExpressio
return prefixExpression return prefixExpression
} }
func (p *prefixExpression) SerializeSql(out *bytes.Buffer, options ...serializeOption) (err error) { func (p *prefixExpression) Serialize(out *queryData, options ...serializeOption) error {
_, _ = out.Write(p.operator) out.Write(p.operator)
if p.expression == nil { if p.expression == nil {
return errors.Newf("nil prefix expression. Generated sql: %s", out.String()) return errors.Newf("nil prefix expression.")
} }
if err = p.expression.SerializeSql(out); err != nil { if err := p.expression.Serialize(out); err != nil {
return return err
} }
return nil return nil
} }
// Representation of n-ary conjunctions (AND/OR) //
type conjunctExpression struct { //// Representation of n-ary conjunctions (AND/OR)
expressions []BoolExpression //type conjunctExpression struct {
conjunction []byte // expressions []Expression
} // conjunction []byte
//}
func (conj *conjunctExpression) SerializeSql(out *bytes.Buffer, options ...serializeOption) (err error) { //
if len(conj.expressions) == 0 { //func (conj *conjunctExpression) Serialize(out *queryData, options ...serializeOption) error {
return errors.Newf( // if len(conj.expressions) == 0 {
"Empty conjunction. Generated sql: %s", // return errors.New("Empty conjunction.")
out.String()) // }
} //
// //clauses := make([]Clause, len(conj.expressions), len(conj.expressions))
clauses := make([]Clause, len(conj.expressions), len(conj.expressions)) // //for i, expr := range conj.expressions {
for i, expr := range conj.expressions { // // clauses[i] = expr
clauses[i] = expr // //}
} //
// useParentheses := len(conj.expressions) > 1
useParentheses := len(clauses) > 1 // if useParentheses {
if useParentheses { // out.WriteByte('(')
_ = out.WriteByte('(') // }
} //
// if err := serializeExpressionList(conj.expressions, string(conj.conjunction), out); err != nil {
if err = serializeClauses(clauses, conj.conjunction, out); err != nil { // return err
return // }
} //
// if useParentheses {
if useParentheses { // out.WriteByte(')')
_ = out.WriteByte(')') // }
} //
// return nil
return nil //}
}
//-------------------------------------------------------------- //--------------------------------------------------------------
// Representation of an escaped literal
type literalExpression struct {
expressionInterfaceImpl
value sqltypes.Value
}
func NewLiteralExpression(value sqltypes.Value) *literalExpression {
exp := literalExpression{value: value}
exp.expressionInterfaceImpl.parent = &exp
return &exp
}
func (c literalExpression) SerializeSql(out *bytes.Buffer, options ...serializeOption) error {
sqltypes.Value(c.value).EncodeSql(out)
return nil
}
//------------------------------------------------------// //------------------------------------------------------//
//// Dummy type for select * //// Dummy type for select *
//type ColumnList []Column //type ColumnList []Column
// //
//func (cl ColumnList) SerializeSql(out *bytes.Buffer, options ...serializeOption) error { //func (cl ColumnList) Serialize(out *bytes.Buffer, options ...serializeOption) error {
// for i, column := range cl { // for i, column := range cl {
// err := column.SerializeSql(out) // err := column.Serialize(out)
// //
// if err != nil { // if err != nil {
// return err // return err

View file

@ -2,124 +2,88 @@
package sqlbuilder package sqlbuilder
import ( import (
"bytes"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/dropbox/godropbox/database/sqltypes"
"github.com/dropbox/godropbox/errors" "github.com/dropbox/godropbox/errors"
) )
type orderByClause struct { //func serializeClauses(
isOrderByClause // clauses []Clause,
expression Expression // separator []byte,
ascent bool // out *bytes.Buffer) (err error) {
} //
// if clauses == nil || len(clauses) == 0 {
func (o *orderByClause) SerializeSql(out *bytes.Buffer, options ...serializeOption) error { // return errors.Newf("Empty clauses.")
if o.expression == nil { // }
return errors.Newf( //
"nil order by clause. Generated sql: %s", // if clauses[0] == nil {
out.String()) // return errors.Newf("nil clause.")
} // }
// if err = clauses[0].Serialize(out); err != nil {
if err := o.expression.SerializeSql(out); err != nil { // return
return err // }
} //
// for _, c := range clauses[1:] {
if o.ascent { // _, _ = out.Write(separator)
_, _ = out.WriteString(" ASC") //
} else { // if c == nil {
_, _ = out.WriteString(" DESC") // return errors.Newf("nil clause.")
} // }
// if err = c.Serialize(out); err != nil {
return nil // return
} // }
// }
func Asc(expression Expression) OrderByClause { //
return &orderByClause{expression: expression, ascent: true} // return nil
} //}
//
func Desc(expression Expression) OrderByClause { //// Representation of n-ary arithmetic (+ - * /)
return &orderByClause{expression: expression, ascent: false} //type arithmeticExpression struct {
} // expressionInterfaceImpl
// expressions []Expression
func serializeClauses( // operator []byte
clauses []Clause, //}
separator []byte, //
out *bytes.Buffer) (err error) { //func (arith *arithmeticExpression) Serialize(out *queryData, options ...serializeOption) error {
// if len(arith.expressions) == 0 {
if clauses == nil || len(clauses) == 0 { // return errors.Newf(
return errors.Newf("Empty clauses. Generated sql: %s", out.String()) // "Empty arithmetic expression.")
} // }
//
if clauses[0] == nil { // clauses := make([]Clause, len(arith.expressions), len(arith.expressions))
return errors.Newf("nil clause. Generated sql: %s", out.String()) // for i, expr := range arith.expressions {
} // clauses[i] = expr
if err = clauses[0].SerializeSql(out); err != nil { // }
return //
} // useParentheses := len(clauses) > 1
// if useParentheses {
for _, c := range clauses[1:] { // _ = out.WriteByte('(')
_, _ = out.Write(separator) // }
//
if c == nil { // if err = serializeClauses(clauses, arith.operator, out); err != nil {
return errors.Newf("nil clause. Generated sql: %s", out.String()) // return
} // }
if err = c.SerializeSql(out); err != nil { //
return // if useParentheses {
} // _ = out.WriteByte(')')
} // }
//
return nil // return nil
} //}
//
// Representation of n-ary arithmetic (+ - * /)
type arithmeticExpression struct {
expressionInterfaceImpl
expressions []Expression
operator []byte
}
func (arith *arithmeticExpression) SerializeSql(out *bytes.Buffer, options ...serializeOption) (err error) {
if len(arith.expressions) == 0 {
return errors.Newf(
"Empty arithmetic expression. Generated sql: %s",
out.String())
}
clauses := make([]Clause, len(arith.expressions), len(arith.expressions))
for i, expr := range arith.expressions {
clauses[i] = expr
}
useParentheses := len(clauses) > 1
if useParentheses {
_ = out.WriteByte('(')
}
if err = serializeClauses(clauses, arith.operator, out); err != nil {
return
}
if useParentheses {
_ = out.WriteByte(')')
}
return nil
}
type tupleExpression struct { type tupleExpression struct {
expressionInterfaceImpl expressionInterfaceImpl
elements listClause elements listClause
} }
func (tuple *tupleExpression) SerializeSql(out *bytes.Buffer, options ...serializeOption) error { func (tuple *tupleExpression) Serialize(out *queryData, options ...serializeOption) error {
if len(tuple.elements.clauses) < 1 { if len(tuple.elements.clauses) == 0 {
return errors.Newf("Tuples must include at least one element") return errors.Newf("Tuples must include at least one element")
} }
return tuple.elements.SerializeSql(out) return tuple.elements.Serialize(out)
} }
func Tuple(exprs ...Expression) Expression { func Tuple(exprs ...Expression) Expression {
@ -141,61 +105,62 @@ type listClause struct {
includeParentheses bool includeParentheses bool
} }
func (list *listClause) SerializeSql(out *bytes.Buffer, options ...serializeOption) error { func (list *listClause) Serialize(out *queryData, options ...serializeOption) error {
if list.includeParentheses { if list.includeParentheses {
_ = out.WriteByte('(') out.WriteByte('(')
} }
if err := serializeClauses(list.clauses, []byte(","), out); err != nil { if err := serializeClauseList(list.clauses, out); err != nil {
return err return err
} }
if list.includeParentheses { if list.includeParentheses {
_ = out.WriteByte(')') out.WriteByte(')')
} }
return nil return nil
} }
type funcExpression struct { //
expressionInterfaceImpl //type funcExpression struct {
funcName string // expressionInterfaceImpl
args *listClause // funcName string
} // args *listClause
//}
func (c *funcExpression) SerializeSql(out *bytes.Buffer, options ...serializeOption) (err error) { //
if !validIdentifierName(c.funcName) { //func (c *funcExpression) Serialize(out *queryData, options ...serializeOption) error {
return errors.Newf( // if !validIdentifierName(c.funcName) {
"Invalid function name: %s. Generated sql: %s", // return errors.Newf(
c.funcName, // "Invalid function name: %s.",
out.String()) // c.funcName,
} // out.String())
_, _ = out.WriteString(c.funcName) // }
if c.args == nil { // _, _ = out.WriteString(c.funcName)
_, _ = out.WriteString("()") // if c.args == nil {
} else { // _, _ = out.WriteString("()")
return c.args.SerializeSql(out) // } else {
} // return c.args.Serialize(out)
return nil // }
} // return nil
//}
// Returns a representation of sql function call "func_call(c[0], ..., c[n-1]) //
func SqlFunc(funcName string, expressions ...Expression) Expression { //// Returns a representation of sql function call "func_call(c[0], ..., c[n-1])
f := &funcExpression{ //func SqlFunc(funcName string, expressions ...Expression) Expression {
funcName: funcName, // f := &funcExpression{
} // funcName: funcName,
if len(expressions) > 0 { // }
args := make([]Clause, len(expressions), len(expressions)) // if len(expressions) > 0 {
for i, expr := range expressions { // args := make([]Clause, len(expressions), len(expressions))
args[i] = expr // for i, expr := range expressions {
} // args[i] = expr
// }
f.args = &listClause{ //
clauses: args, // f.args = &listClause{
includeParentheses: true, // clauses: args,
} // includeParentheses: true,
} // }
return f // }
} // return f
//}
type intervalExpression struct { type intervalExpression struct {
expressionInterfaceImpl expressionInterfaceImpl
@ -205,23 +170,24 @@ type intervalExpression struct {
var intervalSep = ":" var intervalSep = ":"
func (c *intervalExpression) SerializeSql(out *bytes.Buffer, options ...serializeOption) (err error) { func (c *intervalExpression) Serialize(out *queryData, options ...serializeOption) error {
hours := c.duration / time.Hour hours := c.duration / time.Hour
minutes := (c.duration % time.Hour) / time.Minute minutes := (c.duration % time.Hour) / time.Minute
sec := (c.duration % time.Minute) / time.Second sec := (c.duration % time.Minute) / time.Second
msec := (c.duration % time.Second) / time.Microsecond msec := (c.duration % time.Second) / time.Microsecond
_, _ = out.WriteString("INTERVAL '") out.WriteString("INTERVAL '")
if c.negative { if c.negative {
_, _ = out.WriteString("-") out.WriteString("-")
} }
_, _ = out.WriteString(strconv.FormatInt(int64(hours), 10)) out.WriteString(strconv.FormatInt(int64(hours), 10))
_, _ = out.WriteString(intervalSep) out.WriteString(intervalSep)
_, _ = out.WriteString(strconv.FormatInt(int64(minutes), 10)) out.WriteString(strconv.FormatInt(int64(minutes), 10))
_, _ = out.WriteString(intervalSep) out.WriteString(intervalSep)
_, _ = out.WriteString(strconv.FormatInt(int64(sec), 10)) out.WriteString(strconv.FormatInt(int64(sec), 10))
_, _ = out.WriteString(intervalSep) out.WriteString(intervalSep)
_, _ = out.WriteString(strconv.FormatInt(int64(msec), 10)) out.WriteString(strconv.FormatInt(int64(msec), 10))
_, _ = out.WriteString("' HOUR_MICROSECOND") out.WriteString("' HOUR_MICROSECOND")
return nil return nil
} }
@ -246,45 +212,45 @@ func EscapeForLike(s string) string {
} }
// Returns an escaped literal string // Returns an escaped literal string
func Literal(v interface{}) Expression { //func Literal(v interface{}) Expression {
value, err := sqltypes.BuildValue(v) // value, err := sqltypes.BuildValue(v)
if err != nil { // if err != nil {
panic(errors.Wrap(err, "Invalid literal value")) // panic(errors.Wrap(err, "Invalid literal value"))
} // }
return NewLiteralExpression(value) // return NewLiteralExpression(value)
} //}
//
// Returns a representation of "c[0] + ... + c[n-1]" for c in clauses //// Returns a representation of "c[0] + ... + c[n-1]" for c in clauses
func Add(expressions ...Expression) Expression { //func Add(expressions ...Expression) Expression {
return &arithmeticExpression{ // return &arithmeticExpression{
expressions: expressions, // expressions: expressions,
operator: []byte(" + "), // operator: []byte(" + "),
} // }
} //}
//
// Returns a representation of "c[0] - ... - c[n-1]" for c in clauses //// Returns a representation of "c[0] - ... - c[n-1]" for c in clauses
func Sub(expressions ...Expression) Expression { //func Sub(expressions ...Expression) Expression {
return &arithmeticExpression{ // return &arithmeticExpression{
expressions: expressions, // expressions: expressions,
operator: []byte(" - "), // operator: []byte(" - "),
} // }
} //}
//
// Returns a representation of "c[0] * ... * c[n-1]" for c in clauses //// Returns a representation of "c[0] * ... * c[n-1]" for c in clauses
func Mul(expressions ...Expression) Expression { //func Mul(expressions ...Expression) Expression {
return &arithmeticExpression{ // return &arithmeticExpression{
expressions: expressions, // expressions: expressions,
operator: []byte(" * "), // operator: []byte(" * "),
} // }
} //}
//
// Returns a representation of "c[0] / ... / c[n-1]" for c in clauses //// Returns a representation of "c[0] / ... / c[n-1]" for c in clauses
func Div(expressions ...Expression) Expression { //func Div(expressions ...Expression) Expression {
return &arithmeticExpression{ // return &arithmeticExpression{
expressions: expressions, // expressions: expressions,
operator: []byte(" / "), // operator: []byte(" / "),
} // }
} //}
//TODO: Uncomment //TODO: Uncomment
// //
@ -336,14 +302,15 @@ type ifExpression struct {
falseExpression Expression falseExpression Expression
} }
func (exp *ifExpression) SerializeSql(out *bytes.Buffer, options ...serializeOption) error { func (exp *ifExpression) Serialize(out *queryData, options ...serializeOption) error {
_, _ = out.WriteString("IF(") out.WriteString("IF(")
_ = exp.conditional.SerializeSql(out) _ = exp.conditional.Serialize(out)
_, _ = out.WriteString(",") out.WriteString(",")
_ = exp.trueExpression.SerializeSql(out) _ = exp.trueExpression.Serialize(out)
_, _ = out.WriteString(",") out.WriteString(",")
_ = exp.falseExpression.SerializeSql(out) _ = exp.falseExpression.Serialize(out)
_, _ = out.WriteString(")") out.WriteString(")")
return nil return nil
} }
@ -371,7 +338,7 @@ func If(conditional BoolExpression,
// } // }
//} //}
// //
//func (cv *columnValueExpression) SerializeSql(out *bytes.Buffer) error { //func (cv *columnValueExpression) Serialize(out *bytes.Buffer) error {
// _, _ = out.WriteString("VALUES(") // _, _ = out.WriteString("VALUES(")
// _ = cv.column.SerializeSqlForColumnList(out) // _ = cv.column.SerializeSqlForColumnList(out)
// _ = out.WriteByte(')') // _ = out.WriteByte(')')

View file

@ -19,7 +19,7 @@ func (s *ExprSuite) TestConjunctExprEmptyList(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := expr.SerializeSql(buf) err := expr.Serialize(buf)
c.Assert(err, gc.NotNil) c.Assert(err, gc.NotNil)
} }
@ -28,7 +28,7 @@ func (s *ExprSuite) TestConjunctExprNilInList(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := expr.SerializeSql(buf) err := expr.Serialize(buf)
c.Assert(err, gc.NotNil) c.Assert(err, gc.NotNil)
} }
@ -37,7 +37,7 @@ func (s *ExprSuite) TestConjunctExprSingleElement(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := expr.SerializeSql(buf) err := expr.Serialize(buf)
c.Assert(err, gc.IsNil) c.Assert(err, gc.IsNil)
sql := buf.String() sql := buf.String()
@ -48,11 +48,11 @@ func (s *ExprSuite) TestTupleExpr(c *gc.C) {
expr := Tuple() expr := Tuple()
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := expr.SerializeSql(buf) err := expr.Serialize(buf)
c.Assert(err, gc.NotNil) c.Assert(err, gc.NotNil)
expr = Tuple(table1Col1, Literal(1), Literal("five")) expr = Tuple(table1Col1, Literal(1), Literal("five"))
err = expr.SerializeSql(buf) err = expr.Serialize(buf)
c.Assert(err, gc.IsNil) c.Assert(err, gc.IsNil)
sql := buf.String() sql := buf.String()
@ -68,7 +68,7 @@ func (s *ExprSuite) TestLikeExpr(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := expr.SerializeSql(buf) err := expr.Serialize(buf)
c.Assert(err, gc.IsNil) c.Assert(err, gc.IsNil)
sql := buf.String() sql := buf.String()
@ -84,7 +84,7 @@ func (s *ExprSuite) TestRegexExpr(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := expr.SerializeSql(buf) err := expr.Serialize(buf)
c.Assert(err, gc.IsNil) c.Assert(err, gc.IsNil)
sql := buf.String() sql := buf.String()
@ -100,7 +100,7 @@ func (s *ExprSuite) TestAndExpr(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := expr.SerializeSql(buf) err := expr.Serialize(buf)
c.Assert(err, gc.IsNil) c.Assert(err, gc.IsNil)
sql := buf.String() sql := buf.String()
@ -115,7 +115,7 @@ func (s *ExprSuite) TestOrExpr(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := expr.SerializeSql(buf) err := expr.Serialize(buf)
c.Assert(err, gc.IsNil) c.Assert(err, gc.IsNil)
sql := buf.String() sql := buf.String()
@ -130,7 +130,7 @@ func (s *ExprSuite) TestAddExpr(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := expr.SerializeSql(buf) err := expr.Serialize(buf)
c.Assert(err, gc.IsNil) c.Assert(err, gc.IsNil)
sql := buf.String() sql := buf.String()
@ -142,7 +142,7 @@ func (s *ExprSuite) TestSubExpr(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := expr.SerializeSql(buf) err := expr.Serialize(buf)
c.Assert(err, gc.IsNil) c.Assert(err, gc.IsNil)
sql := buf.String() sql := buf.String()
@ -154,7 +154,7 @@ func (s *ExprSuite) TestMulExpr(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := expr.SerializeSql(buf) err := expr.Serialize(buf)
c.Assert(err, gc.IsNil) c.Assert(err, gc.IsNil)
sql := buf.String() sql := buf.String()
@ -166,7 +166,7 @@ func (s *ExprSuite) TestDivExpr(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := expr.SerializeSql(buf) err := expr.Serialize(buf)
c.Assert(err, gc.IsNil) c.Assert(err, gc.IsNil)
sql := buf.String() sql := buf.String()
@ -178,7 +178,7 @@ func (s *ExprSuite) TestBinaryExprNilLHS(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := expr.SerializeSql(buf) err := expr.Serialize(buf)
c.Assert(err, gc.NotNil) c.Assert(err, gc.NotNil)
} }
@ -187,7 +187,7 @@ func (s *ExprSuite) TestNegateExpr(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := expr.SerializeSql(buf) err := expr.Serialize(buf)
c.Assert(err, gc.IsNil) c.Assert(err, gc.IsNil)
sql := buf.String() sql := buf.String()
@ -199,7 +199,7 @@ func (s *ExprSuite) TestBinaryExprNilRHS(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := expr.SerializeSql(buf) err := expr.Serialize(buf)
c.Assert(err, gc.NotNil) c.Assert(err, gc.NotNil)
} }
@ -208,7 +208,7 @@ func (s *ExprSuite) TestEqExpr(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := expr.SerializeSql(buf) err := expr.Serialize(buf)
c.Assert(err, gc.IsNil) c.Assert(err, gc.IsNil)
sql := buf.String() sql := buf.String()
@ -220,7 +220,7 @@ func (s *ExprSuite) TestEqExprNilLHS(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := expr.SerializeSql(buf) err := expr.Serialize(buf)
c.Assert(err, gc.IsNil) c.Assert(err, gc.IsNil)
sql := buf.String() sql := buf.String()
@ -232,7 +232,7 @@ func (s *ExprSuite) TestNeqExpr(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := expr.SerializeSql(buf) err := expr.Serialize(buf)
c.Assert(err, gc.IsNil) c.Assert(err, gc.IsNil)
sql := buf.String() sql := buf.String()
@ -244,7 +244,7 @@ func (s *ExprSuite) TestNeqExprNilLHS(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := expr.SerializeSql(buf) err := expr.Serialize(buf)
c.Assert(err, gc.IsNil) c.Assert(err, gc.IsNil)
sql := buf.String() sql := buf.String()
@ -256,7 +256,7 @@ func (s *ExprSuite) TestLtExpr(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := expr.SerializeSql(buf) err := expr.Serialize(buf)
c.Assert(err, gc.IsNil) c.Assert(err, gc.IsNil)
sql := buf.String() sql := buf.String()
@ -268,7 +268,7 @@ func (s *ExprSuite) TestLteExpr(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := expr.SerializeSql(buf) err := expr.Serialize(buf)
c.Assert(err, gc.IsNil) c.Assert(err, gc.IsNil)
sql := buf.String() sql := buf.String()
@ -283,7 +283,7 @@ func (s *ExprSuite) TestGtExpr(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := expr.SerializeSql(buf) err := expr.Serialize(buf)
c.Assert(err, gc.IsNil) c.Assert(err, gc.IsNil)
sql := buf.String() sql := buf.String()
@ -295,7 +295,7 @@ func (s *ExprSuite) TestGteExpr(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := expr.SerializeSql(buf) err := expr.Serialize(buf)
c.Assert(err, gc.IsNil) c.Assert(err, gc.IsNil)
sql := buf.String() sql := buf.String()
@ -308,7 +308,7 @@ func (s *ExprSuite) TestInExpr(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := expr.SerializeSql(buf) err := expr.Serialize(buf)
c.Assert(err, gc.IsNil) c.Assert(err, gc.IsNil)
sql := buf.String() sql := buf.String()
@ -321,7 +321,7 @@ func (s *ExprSuite) TestInExprEmptyList(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := expr.SerializeSql(buf) err := expr.Serialize(buf)
c.Assert(err, gc.IsNil) c.Assert(err, gc.IsNil)
sql := buf.String() sql := buf.String()
@ -333,7 +333,7 @@ func (s *ExprSuite) TestSqlFuncExprNilInArgList(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := expr.SerializeSql(buf) err := expr.Serialize(buf)
c.Assert(err, gc.NotNil) c.Assert(err, gc.NotNil)
} }
@ -342,7 +342,7 @@ func (s *ExprSuite) TestSqlFuncExprEmptyArgList(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := expr.SerializeSql(buf) err := expr.Serialize(buf)
c.Assert(err, gc.IsNil) c.Assert(err, gc.IsNil)
sql := buf.String() sql := buf.String()
@ -354,7 +354,7 @@ func (s *ExprSuite) TestSqlFuncExprNonEmptyArgList(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := expr.SerializeSql(buf) err := expr.Serialize(buf)
c.Assert(err, gc.IsNil) c.Assert(err, gc.IsNil)
sql := buf.String() sql := buf.String()
@ -366,7 +366,7 @@ func (s *ExprSuite) TestOrderByClauseNilExpr(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := clause.SerializeSql(buf) err := clause.Serialize(buf)
c.Assert(err, gc.NotNil) c.Assert(err, gc.NotNil)
} }
@ -375,7 +375,7 @@ func (s *ExprSuite) TestAsc(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := clause.SerializeSql(buf) err := clause.Serialize(buf)
c.Assert(err, gc.IsNil) c.Assert(err, gc.IsNil)
sql := buf.String() sql := buf.String()
@ -387,7 +387,7 @@ func (s *ExprSuite) TestDesc(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := clause.SerializeSql(buf) err := clause.Serialize(buf)
c.Assert(err, gc.IsNil) c.Assert(err, gc.IsNil)
sql := buf.String() sql := buf.String()
@ -400,7 +400,7 @@ func (s *ExprSuite) TestIf(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := clause.SerializeSql(buf) err := clause.Serialize(buf)
c.Assert(err, gc.IsNil) c.Assert(err, gc.IsNil)
sql := buf.String() sql := buf.String()
@ -538,7 +538,7 @@ func (s *ExprSuite) TestInterval(c *gc.C) {
for i, tt := range testTable { for i, tt := range testTable {
buf.Reset() buf.Reset()
err := Interval(tt.interval).SerializeSql(buf) err := Interval(tt.interval).Serialize(buf)
c.Assert(err, gc.Equals, tt.expectedErr, c.Assert(err, gc.Equals, tt.expectedErr,
gc.Commentf("experiment #%d", i)) gc.Commentf("experiment #%d", i))
if err == nil { if err == nil {

View file

@ -1,7 +1,5 @@
package sqlbuilder package sqlbuilder
import "bytes"
type FuncExpression interface { type FuncExpression interface {
Expression Expression
} }
@ -26,10 +24,10 @@ func NewNumericFunc(name string, expression Expression) NumericExpression {
return numericFunc return numericFunc
} }
func (f *numericFunc) SerializeSql(out *bytes.Buffer, options ...serializeOption) error { func (f *numericFunc) Serialize(out *queryData, options ...serializeOption) error {
out.WriteString(f.name) out.WriteString(f.name)
out.WriteString("(") out.WriteString("(")
err := f.expression.SerializeSql(out) err := f.expression.Serialize(out)
if err != nil { if err != nil {
return err return err
} }
@ -39,7 +37,7 @@ func (f *numericFunc) SerializeSql(out *bytes.Buffer, options ...serializeOption
} }
//func (f *FuncExpression) SerializeSqlForColumnList(out *bytes.Buffer) error { //func (f *FuncExpression) SerializeSqlForColumnList(out *bytes.Buffer) error {
// return f.SerializeSql(out) // return f.Serialize(out)
//} //}
func MAX(expression NumericExpression) NumericExpression { func MAX(expression NumericExpression) NumericExpression {

View file

@ -1,7 +1,6 @@
package sqlbuilder package sqlbuilder
import ( import (
"bytes"
"database/sql" "database/sql"
"github.com/dropbox/godropbox/errors" "github.com/dropbox/godropbox/errors"
"github.com/serenize/snaker" "github.com/serenize/snaker"
@ -53,6 +52,7 @@ func (u *insertStatementImpl) Execute(db types.Db) (res sql.Result, err error) {
return Execute(u, db) return Execute(u, db)
} }
// expression or default keyword
func (s *insertStatementImpl) VALUES(values ...interface{}) InsertStatement { func (s *insertStatementImpl) VALUES(values ...interface{}) InsertStatement {
literalRow := []Clause{} literalRow := []Clause{}
@ -122,84 +122,92 @@ func (i *insertStatementImpl) addError(err string) {
i.errors = append(i.errors, err) i.errors = append(i.errors, err)
} }
func (s *insertStatementImpl) String() (sql string, err error) { func (s *insertStatementImpl) Sql() (sql string, args []interface{}, err error) {
buf := new(bytes.Buffer)
_, _ = buf.WriteString("INSERT ")
_, _ = buf.WriteString("INTO ")
if len(s.errors) > 0 { if len(s.errors) > 0 {
return "", errors.New("sql builder errors: " + strings.Join(s.errors, ", ")) return "", nil, errors.New("sql builder errors: " + strings.Join(s.errors, ", "))
} }
queryData := &queryData{}
queryData.WriteString("INSERT INTO ")
if s.table == nil { if s.table == nil {
return "", errors.Newf("nil tableName. Generated sql: %s", buf.String()) return "", nil, errors.Newf("nil tableName.")
} }
buf.WriteString(s.table.SchemaName() + "." + s.table.TableName()) err = s.table.SerializeSql(queryData)
if err != nil {
return "", nil, err
}
if len(s.columns) > 0 { if len(s.columns) > 0 {
_, _ = buf.WriteString(" (") queryData.WriteString(" (")
for i, col := range s.columns {
if i > 0 { //for i, col := range s.columns {
_ = buf.WriteByte(',') // if i > 0 {
// queryData.WriteByte(',')
// }
//
// if col == nil {
// return "", nil, errors.New("nil column in columns list.")
// }
//
// queryData.WriteString(col.Name())
//}
err = serializeColumnList(s.columns, queryData)
if err != nil {
return "", nil, err
} }
if col == nil { queryData.WriteString(") ")
return "", errors.Newf(
"nil column in columns list. Generated sql: %s",
buf.String())
}
buf.WriteString(col.Name())
}
buf.WriteString(") ")
} }
if len(s.rows) == 0 && s.query == nil { if len(s.rows) == 0 && s.query == nil {
return "", errors.Newf("No row or query specified. Generated sql: %s", buf.String()) return "", nil, errors.New("No row or query specified.")
} }
if len(s.rows) > 0 && s.query != nil { if len(s.rows) > 0 && s.query != nil {
return "", errors.Newf("Only new rows or query has to be specified. Generated sql: %s", buf.String()) return "", nil, errors.New("Only new rows or query has to be specified.")
} }
if len(s.rows) > 0 { if len(s.rows) > 0 {
_, _ = buf.WriteString("VALUES (") queryData.WriteString("VALUES (")
for row_i, row := range s.rows { for row_i, row := range s.rows {
if row_i > 0 { if row_i > 0 {
_, _ = buf.WriteString(", (") queryData.WriteString(", (")
} }
if len(row) != len(s.columns) { if len(row) != len(s.columns) {
return "", errors.Newf( return "", nil, errors.New("# of values does not match # of columns.")
"# of values does not match # of columns. Generated sql: %s",
buf.String())
} }
for col_i, value := range row { err = serializeClauseList(row, queryData)
if col_i > 0 {
_ = buf.WriteByte(',') if err != nil {
return "", nil, err
} }
if value == nil { //for col_i, value := range row {
return "", errors.Newf( // if col_i > 0 {
"nil value in row %d col %d. Generated sql: %s", // queryData.WriteByte(',')
row_i, // }
col_i, //
buf.String()) // if value == nil {
} // return "", nil, errors.Newf("nil value in row %d col %d.", row_i, col_i)
// }
if err = value.SerializeSql(buf); err != nil { //
return // if err = value.Serialize(queryData); err != nil {
} // return
} // }
_ = buf.WriteByte(')') //}
queryData.WriteByte(')')
} }
} }
if s.query != nil { if s.query != nil {
err = s.query.SerializeSql(buf) err = s.query.Serialize(queryData)
if err != nil { if err != nil {
return return
@ -207,16 +215,16 @@ func (s *insertStatementImpl) String() (sql string, err error) {
} }
if len(s.returning) > 0 { if len(s.returning) > 0 {
buf.WriteString(" RETURNING ") queryData.WriteString(" RETURNING ")
err = serializeProjectionList(s.returning, buf) err = serializeProjectionList(s.returning, queryData)
if err != nil { if err != nil {
return return
} }
} }
buf.WriteByte(';') queryData.WriteByte(';')
return buf.String(), nil return queryData.queryBuff.String(), queryData.args, nil
} }

View file

@ -1,14 +1,12 @@
package sqlbuilder package sqlbuilder
import "bytes"
const ( const (
DEFAULT keywordClause = "DEFAULT" DEFAULT keywordClause = "DEFAULT"
) )
type keywordClause string type keywordClause string
func (k keywordClause) SerializeSql(out *bytes.Buffer, options ...serializeOption) error { func (k keywordClause) Serialize(out *queryData, options ...serializeOption) error {
out.WriteString(string(k)) out.WriteString(string(k))
return nil return nil

View file

@ -0,0 +1,22 @@
package sqlbuilder
// Representation of an escaped literal
type literalExpression struct {
expressionInterfaceImpl
value interface{}
}
func Literal(value interface{}) *literalExpression {
exp := literalExpression{value: value}
exp.expressionInterfaceImpl.parent = &exp
return &exp
}
func (l literalExpression) Serialize(out *queryData, options ...serializeOption) error {
//sqltypes.Value(c.value).EncodeSql(out)
out.InsertArgument(l.value)
return nil
}

View file

@ -1,11 +1,5 @@
package sqlbuilder package sqlbuilder
import (
"bytes"
"github.com/dropbox/godropbox/database/sqltypes"
"github.com/pkg/errors"
)
type NumericExpression interface { type NumericExpression interface {
Expression Expression
@ -13,8 +7,11 @@ type NumericExpression interface {
EqL(literal interface{}) BoolExpression EqL(literal interface{}) BoolExpression
NotEq(expression NumericExpression) BoolExpression NotEq(expression NumericExpression) BoolExpression
NotEqL(literal interface{}) BoolExpression NotEqL(literal interface{}) BoolExpression
Gt(rhs NumericExpression) BoolExpression
GtEq(rhs NumericExpression) BoolExpression GtEq(rhs NumericExpression) BoolExpression
GtEqL(literal interface{}) BoolExpression GtEqL(literal interface{}) BoolExpression
LtEq(rhs NumericExpression) BoolExpression LtEq(rhs NumericExpression) BoolExpression
LtEqL(literal interface{}) BoolExpression LtEqL(literal interface{}) BoolExpression
@ -44,6 +41,10 @@ func (n *numericInterfaceImpl) NotEqL(literal interface{}) BoolExpression {
return NotEq(n.parent, Literal(literal)) return NotEq(n.parent, Literal(literal))
} }
func (n *numericInterfaceImpl) Gt(expression NumericExpression) BoolExpression {
return Gt(n.parent, expression)
}
func (n *numericInterfaceImpl) GtEq(expression NumericExpression) BoolExpression { func (n *numericInterfaceImpl) GtEq(expression NumericExpression) BoolExpression {
return GtEq(n.parent, expression) return GtEq(n.parent, expression)
} }
@ -84,12 +85,8 @@ type numericLiteral struct {
func NewNumericLiteral(value interface{}) NumericExpression { func NewNumericLiteral(value interface{}) NumericExpression {
numericLiteral := numericLiteral{} numericLiteral := numericLiteral{}
numericLiteral.literalExpression = *Literal(value)
sqlValue, err := sqltypes.BuildValue(value)
if err != nil {
panic(errors.Wrap(err, "Invalid literal value"))
}
numericLiteral.literalExpression = *NewLiteralExpression(sqlValue)
numericLiteral.numericInterfaceImpl.parent = &numericLiteral numericLiteral.numericInterfaceImpl.parent = &numericLiteral
return &numericLiteral return &numericLiteral
@ -133,10 +130,10 @@ func newNumericExpressionWrap(expression Expression) NumericExpression {
return &numericExpressionWrap return &numericExpressionWrap
} }
func (c *numericExpressionWrapper) SerializeSql(out *bytes.Buffer, options ...serializeOption) (err error) { func (c *numericExpressionWrapper) Serialize(out *queryData, options ...serializeOption) error {
out.WriteString("(") out.WriteString("(")
err = c.expression.SerializeSql(out, options...) err := c.expression.Serialize(out, options...)
out.WriteString(")") out.WriteString(")")
return nil return err
} }

View file

@ -0,0 +1,46 @@
package sqlbuilder
import "github.com/dropbox/godropbox/errors"
type OrderByClause interface {
Clause
isOrderByClauseType()
}
type isOrderByClause struct {
}
func (o *isOrderByClause) isOrderByClauseType() {
}
type orderByClause struct {
isOrderByClause
expression Expression
ascent bool
}
func (o *orderByClause) Serialize(out *queryData, options ...serializeOption) error {
if o.expression == nil {
return errors.Newf("nil orderBy by clause.")
}
if err := o.expression.Serialize(out); err != nil {
return err
}
if o.ascent {
out.WriteString(" ASC")
} else {
out.WriteString(" DESC")
}
return nil
}
func Asc(expression Expression) OrderByClause {
return &orderByClause{expression: expression, ascent: true}
}
func Desc(expression Expression) OrderByClause {
return &orderByClause{expression: expression, ascent: false}
}

View file

@ -1,18 +1,16 @@
package sqlbuilder package sqlbuilder
import "bytes"
type Projection interface { type Projection interface {
SerializeForProjection(out *bytes.Buffer) error SerializeForProjection(out *queryData) error
} }
//------------------------------------------------------// //------------------------------------------------------//
// Dummy type for select * AllColumns // Dummy type for select * AllColumns
type ColumnList []Column type ColumnList []Column
func (cl ColumnList) SerializeForProjection(out *bytes.Buffer) error { func (cl ColumnList) SerializeForProjection(out *queryData) error {
for i, column := range cl { for i, column := range cl {
err := column.SerializeSql(out, FOR_PROJECTION) err := column.Serialize(out, FOR_PROJECTION)
if err != nil { if err != nil {
return err return err

View file

@ -1,9 +1,7 @@
package sqlbuilder package sqlbuilder
import ( import (
"bytes"
"database/sql" "database/sql"
"fmt"
"github.com/dropbox/godropbox/errors" "github.com/dropbox/godropbox/errors"
"github.com/sub0zero/go-sqlbuilder/types" "github.com/sub0zero/go-sqlbuilder/types"
) )
@ -12,17 +10,17 @@ type SelectStatement interface {
Statement Statement
Expression Expression
Where(expression BoolExpression) SelectStatement DISTINCT() SelectStatement
GroupBy(expressions ...Expression) SelectStatement WHERE(expression BoolExpression) SelectStatement
HAVING(expressions BoolExpression) SelectStatement GROUP_BY(expressions ...Clause) SelectStatement
HAVING(boolExpression BoolExpression) SelectStatement
ORDER_BY(clauses ...OrderByClause) SelectStatement
LIMIT(limit int64) SelectStatement
OFFSET(offset int64) SelectStatement
FOR_UPDATE() SelectStatement
OrderBy(clauses ...OrderByClause) SelectStatement
Limit(limit int64) SelectStatement
Offset(offset int64) SelectStatement
Distinct() SelectStatement
WithSharedLock() SelectStatement
ForUpdate() SelectStatement
Comment(comment string) SelectStatement
Copy() SelectStatement Copy() SelectStatement
AsTable(alias string) *SelectStatementTable AsTable(alias string) *SelectStatementTable
@ -34,16 +32,16 @@ type selectStatementImpl struct {
expressionInterfaceImpl expressionInterfaceImpl
table ReadableTable table ReadableTable
distinct bool
projections []Projection projections []Projection
where BoolExpression where BoolExpression
group *listClause groupBy []Clause
having BoolExpression having BoolExpression
order *listClause orderBy []OrderByClause
comment string
limit, offset int64 limit, offset int64
withSharedLock bool
forUpdate bool forUpdate bool
distinct bool
} }
func newSelectStatement( func newSelectStatement(
@ -55,26 +53,115 @@ func newSelectStatement(
projections: projections, projections: projections,
limit: -1, limit: -1,
offset: -1, offset: -1,
withSharedLock: false,
forUpdate: false, forUpdate: false,
distinct: false, distinct: false,
} }
} }
func (s *selectStatementImpl) SerializeSql(out *bytes.Buffer, options ...serializeOption) error { func (s *selectStatementImpl) Serialize(out *queryData, options ...serializeOption) error {
str, err := s.String()
out.WriteString("(")
err := s.serializeImpl(out, options...)
if err != nil { if err != nil {
return err return err
} }
out.WriteString("(")
out.WriteString(str)
out.WriteString(")") out.WriteString(")")
return nil return nil
} }
func (s *selectStatementImpl) serializeImpl(out *queryData, options ...serializeOption) error {
out.WriteString("SELECT ")
if s.distinct {
out.WriteString("DISTINCT ")
}
if s.projections == nil || len(s.projections) == 0 {
return errors.New("No column selected for projection.")
}
err := serializeProjectionList(s.projections, out)
if err != nil {
return err
}
out.WriteString(" FROM ")
if s.table == nil {
return errors.Newf("nil tableName.")
}
if err := s.table.SerializeSql(out); err != nil {
return err
}
if s.where != nil {
out.WriteString(" WHERE ")
if err := s.where.Serialize(out); err != nil {
return err
}
}
if s.groupBy != nil && len(s.groupBy) > 0 {
out.WriteString(" GROUP BY ")
err := serializeClauseList(s.groupBy, out)
if err != nil {
return err
}
}
if s.having != nil {
out.WriteString(" HAVING ")
if err = s.having.Serialize(out); err != nil {
return err
}
}
if s.orderBy != nil {
out.WriteString(" ORDER BY ")
if err := serializeOrderByClauseList(s.orderBy, out); err != nil {
return err
}
}
if s.limit >= 0 {
out.WriteString(" LIMIT ")
out.InsertArgument(s.limit)
}
if s.offset >= 0 {
out.WriteString(" OFFSET ")
out.InsertArgument(s.offset)
}
if s.forUpdate {
out.WriteString(" FOR UPDATE")
}
return nil
}
// Return the properly escaped SQL statement, against the specified database
func (q *selectStatementImpl) Sql() (query string, args []interface{}, err error) {
queryData := queryData{}
err = q.serializeImpl(&queryData)
if err != nil {
return "", nil, err
}
return queryData.queryBuff.String(), queryData.args, nil
}
func (s *selectStatementImpl) AsTable(alias string) *SelectStatementTable { func (s *selectStatementImpl) AsTable(alias string) *SelectStatementTable {
return &SelectStatementTable{ return &SelectStatementTable{
statement: s, statement: s,
@ -95,23 +182,14 @@ func (s *selectStatementImpl) Copy() SelectStatement {
return &ret return &ret
} }
func (q *selectStatementImpl) Where(expression BoolExpression) SelectStatement { func (q *selectStatementImpl) WHERE(expression BoolExpression) SelectStatement {
q.where = expression q.where = expression
return q return q
} }
func (q *selectStatementImpl) GroupBy( func (s *selectStatementImpl) GROUP_BY(cluases ...Clause) SelectStatement {
expressions ...Expression) SelectStatement { s.groupBy = cluases
return s
q.group = &listClause{
clauses: make([]Clause, len(expressions), len(expressions)),
includeParentheses: false,
}
for i, e := range expressions {
q.group.clauses[i] = e
}
return q
} }
func (q *selectStatementImpl) HAVING(expression BoolExpression) SelectStatement { func (q *selectStatementImpl) HAVING(expression BoolExpression) SelectStatement {
@ -119,132 +197,31 @@ func (q *selectStatementImpl) HAVING(expression BoolExpression) SelectStatement
return q return q
} }
func (q *selectStatementImpl) OrderBy( func (q *selectStatementImpl) ORDER_BY(clauses ...OrderByClause) SelectStatement {
clauses ...OrderByClause) SelectStatement {
q.orderBy = clauses
q.order = newOrderByListClause(clauses...)
return q return q
} }
func (q *selectStatementImpl) Limit(limit int64) SelectStatement { func (q *selectStatementImpl) OFFSET(offset int64) SelectStatement {
q.limit = limit
return q
}
func (q *selectStatementImpl) Distinct() SelectStatement {
q.distinct = true
return q
}
func (q *selectStatementImpl) WithSharedLock() SelectStatement {
// We don't need to grab a read lock if we're going to grab a write one
if !q.forUpdate {
q.withSharedLock = true
}
return q
}
func (q *selectStatementImpl) ForUpdate() SelectStatement {
// Clear a request for a shared lock if we're asking for a write one
q.withSharedLock = false
q.forUpdate = true
return q
}
func (q *selectStatementImpl) Offset(offset int64) SelectStatement {
q.offset = offset q.offset = offset
return q return q
} }
func (q *selectStatementImpl) Comment(comment string) SelectStatement { func (q *selectStatementImpl) LIMIT(limit int64) SelectStatement {
q.comment = comment q.limit = limit
return q return q
} }
// Return the properly escaped SQL statement, against the specified database func (q *selectStatementImpl) DISTINCT() SelectStatement {
func (q *selectStatementImpl) String() (sql string, err error) { q.distinct = true
buf := new(bytes.Buffer) return q
_, _ = buf.WriteString("SELECT ") }
if err = writeComment(q.comment, buf); err != nil { func (q *selectStatementImpl) FOR_UPDATE() SelectStatement {
return q.forUpdate = true
} return q
if q.distinct {
_, _ = buf.WriteString("DISTINCT ")
}
if q.projections == nil || len(q.projections) == 0 {
return "", errors.Newf(
"No column selected. Generated sql: %s",
buf.String())
}
for i, col := range q.projections {
if i > 0 {
_ = buf.WriteByte(',')
}
if col == nil {
return "", errors.Newf(
"nil column selected. Generated sql: %s",
buf.String())
}
if err = col.SerializeForProjection(buf); err != nil {
return
}
}
_, _ = buf.WriteString(" FROM ")
if q.table == nil {
return "", errors.Newf("nil tableName. Generated sql: %s", buf.String())
}
if err = q.table.SerializeSql(buf); err != nil {
return
}
if q.where != nil {
_, _ = buf.WriteString(" WHERE ")
if err = q.where.SerializeSql(buf); err != nil {
return
}
}
if q.group != nil {
_, _ = buf.WriteString(" GROUP BY ")
if err = q.group.SerializeSql(buf); err != nil {
return
}
}
if q.having != nil {
buf.WriteString(" HAVING ")
if err = q.having.SerializeSql(buf); err != nil {
return
}
}
if q.order != nil {
_, _ = buf.WriteString(" ORDER BY ")
if err = q.order.SerializeSql(buf); err != nil {
return
}
}
if q.limit >= 0 {
if q.offset >= 0 {
_, _ = buf.WriteString(fmt.Sprintf(" LIMIT %d, %d", q.offset, q.limit))
} else {
_, _ = buf.WriteString(fmt.Sprintf(" LIMIT %d", q.limit))
}
}
if q.forUpdate {
_, _ = buf.WriteString(" FOR UPDATE")
} else if q.withSharedLock {
_, _ = buf.WriteString(" LOCK IN SHARE MODE")
}
return buf.String(), nil
} }
func NumExp(statement SelectStatement) NumericExpression { func NumExp(statement SelectStatement) NumericExpression {

View file

@ -1,7 +1,5 @@
package sqlbuilder package sqlbuilder
import "bytes"
type SelectStatementTable struct { type SelectStatementTable struct {
statement SelectStatement statement SelectStatement
columns []Column columns []Column
@ -41,16 +39,14 @@ func (s *SelectStatementTable) RefStringColumn(column Column) *StringColumn {
return strColumn return strColumn
} }
func (s *SelectStatementTable) SerializeSql(out *bytes.Buffer) error { func (s *SelectStatementTable) SerializeSql(out *queryData) error {
out.WriteString("( ") out.WriteString("( ")
statementStr, err := s.statement.String() err := s.statement.Serialize(out)
if err != nil { if err != nil {
return err return err
} }
out.WriteString(statementStr)
out.WriteString(" ) AS ") out.WriteString(" ) AS ")
out.WriteString(s.alias) out.WriteString(s.alias)

View file

@ -1,17 +1,13 @@
package sqlbuilder package sqlbuilder
import ( import (
"bytes"
"database/sql" "database/sql"
"github.com/sub0zero/go-sqlbuilder/types" "github.com/sub0zero/go-sqlbuilder/types"
"regexp"
"github.com/dropbox/godropbox/errors"
) )
type Statement interface { type Statement interface {
// String returns generated SQL as string. // String returns generated SQL as string.
String() (sql string, err error) Sql() (query string, args []interface{}, err error)
Query(db types.Db, destination interface{}) error Query(db types.Db, destination interface{}) error
Execute(db types.Db) (sql.Result, error) Execute(db types.Db) (sql.Result, error)
@ -88,10 +84,10 @@ type Statement interface {
// //
// for idx, lock := range s.locks { // for idx, lock := range s.locks {
// if lock.t == nil { // if lock.t == nil {
// return "", errors.Newf("nil tableName. Generated sql: %s", buf.String()) // return "", errors.Newf("nil tableName.", buf.String())
// } // }
// //
// if err = lock.t.SerializeSql(buf); err != nil { // if err = lock.t.Serialize(buf); err != nil {
// return // return
// } // }
// //
@ -162,23 +158,23 @@ type Statement interface {
// //
// Once again, teisenberger is lazy. Here's a quick filter on comments // Once again, teisenberger is lazy. Here's a quick filter on comments
var validCommentRegexp *regexp.Regexp = regexp.MustCompile("^[\\w .?]*$") //var validCommentRegexp *regexp.Regexp = regexp.MustCompile("^[\\w .?]*$")
//
func isValidComment(comment string) bool { //func isValidComment(comment string) bool {
return validCommentRegexp.MatchString(comment) // return validCommentRegexp.MatchString(comment)
} //}
//
func writeComment(comment string, buf *bytes.Buffer) error { //func writeComment(comment string, buf *bytes.Buffer) error {
if comment != "" { // if comment != "" {
_, _ = buf.WriteString("/* ") // _, _ = buf.WriteString("/* ")
if !isValidComment(comment) { // if !isValidComment(comment) {
return errors.Newf("Invalid comment: %s", comment) // return errors.Newf("Invalid comment: %s", comment)
} // }
_, _ = buf.WriteString(comment) // _, _ = buf.WriteString(comment)
_, _ = buf.WriteString(" */") // _, _ = buf.WriteString(" */")
} // }
return nil // return nil
} //}
func newOrderByListClause(clauses ...OrderByClause) *listClause { func newOrderByListClause(clauses ...OrderByClause) *listClause {
ret := &listClause{ ret := &listClause{

View file

@ -475,14 +475,14 @@ func (s *StmtSuite) TestUnionSelectWithMismatchedColumns(c *gc.C) {
"All inner selects in Union statement must select the "+ "All inner selects in Union statement must select the "+
"same number of columns. For sanity, you probably "+ "same number of columns. For sanity, you probably "+
"want to select the same tableName columns in the same "+ "want to select the same tableName columns in the same "+
"order. If you are selecting on multiple tables, "+ "orderBy. If you are selecting on multiple tables, "+
"use Null to pad to the right number of fields.") "use Null to pad to the right number of fields.")
} }
func (s *StmtSuite) TestComplicatedUnionSelectWithWhereStatement(c *gc.C) { func (s *StmtSuite) TestComplicatedUnionSelectWithWhereStatement(c *gc.C) {
// tests on outer statement: Group By, Order By, Limit // tests on outer statement: Group By, Order By, LIMIT
// on inner statement: AndWhere, WHERE (with And), Order By, Limit // on inner statement: AndWhere, WHERE (with And), Order By, LIMIT
select_queries := make([]SelectStatement, 0, 3) select_queries := make([]SelectStatement, 0, 3)
// We're not trying to write a SQL parser, so we won't warn if you do something silly like // We're not trying to write a SQL parser, so we won't warn if you do something silly like

View file

@ -3,7 +3,6 @@
package sqlbuilder package sqlbuilder
import ( import (
"bytes"
"fmt" "fmt"
"github.com/dropbox/godropbox/errors" "github.com/dropbox/godropbox/errors"
) )
@ -15,7 +14,7 @@ type TableInterface interface {
Columns() []Column Columns() []Column
// Generates the sql string for the current tableName expression. Note: the // Generates the sql string for the current tableName expression. Note: the
// generated string may not be a valid/executable sql statement. // generated string may not be a valid/executable sql statement.
SerializeSql(out *bytes.Buffer) error SerializeSql(out *queryData) error
} }
// The sql tableName read interface. NOTE: NATURAL JOINs, and join "USING" clause // The sql tableName read interface. NOTE: NATURAL JOINs, and join "USING" clause
@ -52,9 +51,6 @@ type WritableTable interface {
// Defines a physical tableName in the database that is both readable and writable. // Defines a physical tableName in the database that is both readable and writable.
// This function will panic if name is not valid // This function will panic if name is not valid
func NewTable(schemaName, name string, columns ...Column) *Table { func NewTable(schemaName, name string, columns ...Column) *Table {
if !validIdentifierName(name) {
panic("Invalid tableName name")
}
t := &Table{ t := &Table{
schemaName: schemaName, schemaName: schemaName,
@ -154,28 +150,20 @@ func (t *Table) ForceIndex(index string) *Table {
// Generates the sql string for the current tableName expression. Note: the // Generates the sql string for the current tableName expression. Note: the
// generated string may not be a valid/executable sql statement. // generated string may not be a valid/executable sql statement.
func (t *Table) SerializeSql(out *bytes.Buffer) error { func (t *Table) SerializeSql(out *queryData) error {
if t == nil { if t == nil {
return errors.Newf("nil tableName. Generated sql: %s", out.String()) return errors.Newf("nil tableName.")
} }
_, _ = out.WriteString(t.schemaName)
_, _ = out.WriteString(".") out.WriteString(t.schemaName)
_, _ = out.WriteString(t.TableName()) out.WriteString(".")
out.WriteString(t.TableName())
if len(t.alias) > 0 { if len(t.alias) > 0 {
out.WriteString(" AS ") out.WriteString(" AS ")
out.WriteString(t.alias) out.WriteString(t.alias)
} }
if t.forcedIndex != "" {
if !validIdentifierName(t.forcedIndex) {
return errors.Newf("'%s' is not a valid identifier for an index", t.forcedIndex)
}
_, _ = out.WriteString(" FORCE INDEX (")
_, _ = out.WriteString(t.forcedIndex)
_, _ = out.WriteString(")")
}
return nil return nil
} }
@ -307,7 +295,6 @@ func CrossJoin(
return newJoinTable(lhs, rhs, CROSS_JOIN, nil) return newJoinTable(lhs, rhs, CROSS_JOIN, nil)
} }
// Returns the tableName's name in the database
func (t *joinTable) SchemaName() string { func (t *joinTable) SchemaName() string {
return "" return ""
} }
@ -328,16 +315,16 @@ func (t *joinTable) Column(name string) Column {
panic("Not implemented") panic("Not implemented")
} }
func (t *joinTable) SerializeSql(out *bytes.Buffer) (err error) { func (t *joinTable) SerializeSql(out *queryData) (err error) {
if t.lhs == nil { if t.lhs == nil {
return errors.Newf("nil lhs. Generated sql: %s", out.String()) return errors.Newf("nil lhs.")
} }
if t.rhs == nil { if t.rhs == nil {
return errors.Newf("nil rhs. Generated sql: %s", out.String()) return errors.Newf("nil rhs.")
} }
if t.onCondition == nil && t.join_type != CROSS_JOIN { if t.onCondition == nil && t.join_type != CROSS_JOIN {
return errors.Newf("nil onCondition. Generated sql: %s", out.String()) return errors.Newf("nil onCondition.")
} }
if err = t.lhs.SerializeSql(out); err != nil { if err = t.lhs.SerializeSql(out); err != nil {
@ -346,11 +333,11 @@ func (t *joinTable) SerializeSql(out *bytes.Buffer) (err error) {
switch t.join_type { switch t.join_type {
case INNER_JOIN: case INNER_JOIN:
_, _ = out.WriteString(" JOIN ") out.WriteString(" JOIN ")
case LEFT_JOIN: case LEFT_JOIN:
_, _ = out.WriteString(" LEFT JOIN ") out.WriteString(" LEFT JOIN ")
case RIGHT_JOIN: case RIGHT_JOIN:
_, _ = out.WriteString(" RIGHT JOIN ") out.WriteString(" RIGHT JOIN ")
case FULL_JOIN: case FULL_JOIN:
out.WriteString(" FULL JOIN ") out.WriteString(" FULL JOIN ")
case CROSS_JOIN: case CROSS_JOIN:
@ -362,8 +349,8 @@ func (t *joinTable) SerializeSql(out *bytes.Buffer) (err error) {
} }
if t.onCondition != nil { if t.onCondition != nil {
_, _ = out.WriteString(" ON ") out.WriteString(" ON ")
if err = t.onCondition.SerializeSql(out); err != nil { if err = t.onCondition.Serialize(out); err != nil {
return return
} }
} }

View file

@ -1,10 +1,6 @@
package sqlbuilder package sqlbuilder
// A clause that can be used in order by // A clause that can be used in orderBy by
type OrderByClause interface {
Clause
isOrderByClauseInterface
}
// A clause that is selectable. // A clause that is selectable.
//type Projection interface { //type Projection interface {
@ -16,9 +12,9 @@ type OrderByClause interface {
//type ColumnList []Column //type ColumnList []Column
// //
//func (cl ColumnList) SerializeSql(out *bytes.Buffer, options ...serializeOption) error { //func (cl ColumnList) Serialize(out *bytes.Buffer, options ...serializeOption) error {
// for i, column := range cl { // for i, column := range cl {
// column.SerializeSql(out) // column.Serialize(out)
// //
// if i != len(cl)-1 { // if i != len(cl)-1 {
// out.WriteString(", ") // out.WriteString(", ")
@ -49,16 +45,6 @@ type OrderByClause interface {
// Boiler plates ... // Boiler plates ...
// //
type isOrderByClauseInterface interface {
isOrderByClauseType()
}
type isOrderByClause struct {
}
func (o *isOrderByClause) isOrderByClauseType() {
}
// //
//type isProjectionInterface interface { //type isProjectionInterface interface {
// isProjectionType() // isProjectionType()

View file

@ -1,17 +1,9 @@
package sqlbuilder package sqlbuilder
import ( // By default, rows selected by a UNION statement are out-of-orderBy
"bytes"
"database/sql"
"fmt"
"github.com/dropbox/godropbox/errors"
"github.com/sub0zero/go-sqlbuilder/types"
)
// By default, rows selected by a UNION statement are out-of-order
// If you have an ORDER BY on an inner SELECT statement, the only thing // If you have an ORDER BY on an inner SELECT statement, the only thing
// it affects is the LIMIT clause on that inner statement (the ordering will // it affects is the LIMIT clause on that inner statement (the ordering will
// still be out-of-order). // still be out-of-orderBy).
type UnionStatement interface { type UnionStatement interface {
Statement Statement
@ -27,177 +19,178 @@ type UnionStatement interface {
Offset(offset int64) UnionStatement Offset(offset int64) UnionStatement
} }
func Union(selects ...SelectStatement) UnionStatement { //
return &unionStatementImpl{ //func Union(selects ...SelectStatement) UnionStatement {
selects: selects, // return &unionStatementImpl{
limit: -1, // selects: selects,
offset: -1, // limit: -1,
unique: true, // offset: -1,
} // unique: true,
} // }
//}
func UnionAll(selects ...SelectStatement) UnionStatement { //
return &unionStatementImpl{ //func UnionAll(selects ...SelectStatement) UnionStatement {
selects: selects, // return &unionStatementImpl{
limit: -1, // selects: selects,
offset: -1, // limit: -1,
unique: false, // offset: -1,
} // unique: false,
} // }
//}
// Similar to selectStatementImpl, but less complete //
type unionStatementImpl struct { //// Similar to selectStatementImpl, but less complete
selects []SelectStatement //type unionStatementImpl struct {
where BoolExpression // selects []SelectStatement
group *listClause // where BoolExpression
order *listClause // group *listClause
limit, offset int64 // order *listClause
// True if results of the union should be deduped. // limit, offset int64
unique bool // // True if results of the union should be deduped.
} // unique bool
//}
func (s *unionStatementImpl) Query(db types.Db, destination interface{}) error { //
return Query(s, db, destination) //func (s *unionStatementImpl) Query(db types.Db, destination interface{}) error {
} // return Query(s, db, destination)
//}
func (u *unionStatementImpl) Execute(db types.Db) (res sql.Result, err error) { //
return Execute(u, db) //func (u *unionStatementImpl) Execute(db types.Db) (res sql.Result, err error) {
} // return Execute(u, db)
//}
func (us *unionStatementImpl) Where(expression BoolExpression) UnionStatement { //
us.where = expression //func (us *unionStatementImpl) Where(expression BoolExpression) UnionStatement {
return us // us.where = expression
} // return us
//}
// Further filter the query, instead of replacing the filter //
func (us *unionStatementImpl) AndWhere(expression BoolExpression) UnionStatement { //// Further filter the query, instead of replacing the filter
if us.where == nil { //func (us *unionStatementImpl) AndWhere(expression BoolExpression) UnionStatement {
return us.Where(expression) // if us.where == nil {
} // return us.Where(expression)
us.where = And(us.where, expression) // }
return us // us.where = And(us.where, expression)
} // return us
//}
func (us *unionStatementImpl) GroupBy( //
expressions ...Expression) UnionStatement { //func (us *unionStatementImpl) GroupBy(
// expressions ...Expression) UnionStatement {
us.group = &listClause{ //
clauses: make([]Clause, len(expressions), len(expressions)), // us.group = &listClause{
includeParentheses: false, // clauses: make([]Clause, len(expressions), len(expressions)),
} // includeParentheses: false,
// }
for i, e := range expressions { //
us.group.clauses[i] = e // for i, e := range expressions {
} // us.group.clauses[i] = e
return us // }
} // return us
//}
func (us *unionStatementImpl) OrderBy( //
clauses ...OrderByClause) UnionStatement { //func (us *unionStatementImpl) OrderBy(
// clauses ...OrderByClause) UnionStatement {
us.order = newOrderByListClause(clauses...) //
return us // us.order = newOrderByListClause(clauses...)
} // return us
//}
func (us *unionStatementImpl) Limit(limit int64) UnionStatement { //
us.limit = limit //func (us *unionStatementImpl) Limit(limit int64) UnionStatement {
return us // us.limit = limit
} // return us
//}
func (us *unionStatementImpl) Offset(offset int64) UnionStatement { //
us.offset = offset //func (us *unionStatementImpl) Offset(offset int64) UnionStatement {
return us // us.offset = offset
} // return us
//}
func (us *unionStatementImpl) String() (sql string, err error) { //
if len(us.selects) == 0 { //func (us *unionStatementImpl) String() (sql string, err error) {
return "", errors.Newf("Union statement must have at least one SELECT") // if len(us.selects) == 0 {
} // return "", errors.Newf("Union statement must have at least one SELECT")
// }
if len(us.selects) == 1 { //
return us.selects[0].String() // if len(us.selects) == 1 {
} // return us.selects[0].String()
// }
// Union statements in MySQL require that the same number of columns in each subquery //
var projections []Projection // // Union statements in MySQL require that the same number of columns in each subquery
// var projections []Projection
for _, statement := range us.selects { //
// do a type assertion to get at the underlying struct // for _, statement := range us.selects {
statementImpl, ok := statement.(*selectStatementImpl) // // do a type assertion to get at the underlying struct
if !ok { // statementImpl, ok := statement.(*selectStatementImpl)
return "", errors.Newf( // if !ok {
"Expected inner select statement to be of type " + // return "", errors.Newf(
"selectStatementImpl") // "Expected inner select statement to be of type " +
} // "selectStatementImpl")
// }
// check that for limit for statements with order by clauses //
if statementImpl.order != nil && statementImpl.limit < 0 { // // check that for limit for statements with orderBy by clauses
return "", errors.Newf( // if statementImpl.orderBy != nil && statementImpl.limit < 0 {
"All inner selects in Union statement must have LIMIT if " + // return "", errors.Newf(
"they have ORDER BY") // "All inner selects in Union statement must have LIMIT if " +
} // "they have ORDER BY")
// }
// check number of projections //
if projections == nil { // // check number of projections
projections = statementImpl.projections // if projections == nil {
} else { // projections = statementImpl.projections
if len(projections) != len(statementImpl.projections) { // } else {
return "", errors.Newf( // if len(projections) != len(statementImpl.projections) {
"All inner selects in Union statement must select the " + // return "", errors.Newf(
"same number of columns. For sanity, you probably " + // "All inner selects in Union statement must select the " +
"want to select the same tableName columns in the same " + // "same number of columns. For sanity, you probably " +
"order. If you are selecting on multiple tables, " + // "want to select the same tableName columns in the same " +
"use Null to pad to the right number of fields.") // "orderBy. If you are selecting on multiple tables, " +
} // "use Null to pad to the right number of fields.")
} // }
} // }
// }
buf := new(bytes.Buffer) //
for i, statement := range us.selects { // buf := new(bytes.Buffer)
if i != 0 { // for i, statement := range us.selects {
if us.unique { // if i != 0 {
_, _ = buf.WriteString(" UNION ") // if us.unique {
} else { // _, _ = buf.WriteString(" UNION ")
_, _ = buf.WriteString(" UNION ALL ") // } else {
} // _, _ = buf.WriteString(" UNION ALL ")
} // }
_, _ = buf.WriteString("(") // }
selectSql, err := statement.String() // _, _ = buf.WriteString("(")
if err != nil { // selectSql, err := statement.String()
return "", err // if err != nil {
} // return "", err
_, _ = buf.WriteString(selectSql) // }
_, _ = buf.WriteString(")") // _, _ = buf.WriteString(selectSql)
} // _, _ = buf.WriteString(")")
// }
if us.where != nil { //
_, _ = buf.WriteString(" WHERE ") // if us.where != nil {
if err = us.where.SerializeSql(buf); err != nil { // _, _ = buf.WriteString(" WHERE ")
return // if err = us.where.Serialize(buf); err != nil {
} // return
} // }
// }
if us.group != nil { //
_, _ = buf.WriteString(" GROUP BY ") // if us.group != nil {
if err = us.group.SerializeSql(buf); err != nil { // _, _ = buf.WriteString(" GROUP BY ")
return // if err = us.group.Serialize(buf); err != nil {
} // return
} // }
// }
if us.order != nil { //
_, _ = buf.WriteString(" ORDER BY ") // if us.order != nil {
if err = us.order.SerializeSql(buf); err != nil { // _, _ = buf.WriteString(" ORDER BY ")
return // if err = us.order.Serialize(buf); err != nil {
} // return
} // }
// }
if us.limit >= 0 { //
if us.offset >= 0 { // if us.limit >= 0 {
_, _ = buf.WriteString( // if us.offset >= 0 {
fmt.Sprintf(" LIMIT %d, %d", us.offset, us.limit)) // _, _ = buf.WriteString(
} else { // fmt.Sprintf(" LIMIT %d, %d", us.offset, us.limit))
_, _ = buf.WriteString(fmt.Sprintf(" LIMIT %d", us.limit)) // } else {
} // _, _ = buf.WriteString(fmt.Sprintf(" LIMIT %d", us.limit))
} // }
return buf.String(), nil // }
} // return buf.String(), nil
//}

View file

@ -1,7 +1,6 @@
package sqlbuilder package sqlbuilder
import ( import (
"bytes"
"database/sql" "database/sql"
"github.com/dropbox/godropbox/errors" "github.com/dropbox/godropbox/errors"
"github.com/sub0zero/go-sqlbuilder/types" "github.com/sub0zero/go-sqlbuilder/types"
@ -61,60 +60,64 @@ func (u *updateStatementImpl) RETURNING(projections ...Projection) UpdateStateme
return u return u
} }
func (u *updateStatementImpl) String() (sql string, err error) { func (u *updateStatementImpl) Sql() (sql string, args []interface{}, err error) {
buf := new(bytes.Buffer) out := &queryData{}
_, _ = buf.WriteString("UPDATE ") out.WriteString("UPDATE ")
if u.table == nil { if u.table == nil {
return "", errors.Newf("nil tableName. Generated sql: %s", buf.String()) return "", nil, errors.New("nil tableName.")
} }
if err = u.table.SerializeSql(buf); err != nil { if err = u.table.SerializeSql(out); err != nil {
return return
} }
if len(u.updateValues) == 0 { if len(u.updateValues) == 0 {
return "", errors.Newf( return "", nil, errors.New("No column updated.")
"No column updated. Generated sql: %s",
buf.String())
} }
_, _ = buf.WriteString(" SET") out.WriteString(" SET")
if len(u.columns) > 1 { if len(u.columns) > 1 {
buf.WriteString(" ( ") out.WriteString(" ( ")
} else { } else {
buf.WriteString(" ") out.WriteString(" ")
} }
for i, column := range u.columns { //for i, column := range u.columns {
if i > 0 { // if i > 0 {
buf.WriteString(", ") // out.WriteString(", ")
} // }
//
// out.WriteString(column.Name())
//
// if err != nil {
// return
// }
//}
buf.WriteString(column.Name()) err = serializeColumnList(u.columns, out)
if err != nil { if err != nil {
return return "", nil, err
}
} }
if len(u.columns) > 1 { if len(u.columns) > 1 {
buf.WriteString(" )") out.WriteString(" )")
} }
buf.WriteString(" =") out.WriteString(" =")
if len(u.updateValues) > 1 { if len(u.updateValues) > 1 {
buf.WriteString(" (") out.WriteString(" (")
} }
for i, value := range u.updateValues { for i, value := range u.updateValues {
if i > 0 { if i > 0 {
buf.WriteString(", ") out.WriteString(", ")
} }
err = value.SerializeSql(buf) err = value.Serialize(out)
if err != nil { if err != nil {
return return
@ -122,29 +125,27 @@ func (u *updateStatementImpl) String() (sql string, err error) {
} }
if len(u.updateValues) > 1 { if len(u.updateValues) > 1 {
buf.WriteString(" )") out.WriteString(" )")
} }
if u.where == nil { if u.where == nil {
return "", errors.Newf( return "", nil, errors.New("Updating without a WHERE clause.")
"Updating without a WHERE clause. Generated sql: %s",
buf.String())
} }
_, _ = buf.WriteString(" WHERE ") out.WriteString(" WHERE ")
if err = u.where.SerializeSql(buf); err != nil { if err = u.where.Serialize(out); err != nil {
return return
} }
if len(u.returning) > 0 { if len(u.returning) > 0 {
buf.WriteString(" RETURNING ") out.WriteString(" RETURNING ")
err = serializeProjectionList(u.returning, buf) err = serializeProjectionList(u.returning, out)
if err != nil { if err != nil {
return return
} }
} }
return buf.String() + ";", nil return out.queryBuff.String(), out.args, nil
} }

View file

@ -83,7 +83,7 @@ func TestUpdate(t *testing.T) {
//func (s *StmtSuite) TestUpdateWithOrderBy(c *gc.C) { //func (s *StmtSuite) TestUpdateWithOrderBy(c *gc.C) {
// stmt := table1.UPDATE().SET(table1Col1, Literal(1)) // stmt := table1.UPDATE().SET(table1Col1, Literal(1))
// stmt.WHERE(EqL(table1Col2, 2)) // stmt.WHERE(EqL(table1Col2, 2))
// stmt.OrderBy(table1Col2) // stmt.ORDER_BY(table1Col2)
// sql, err := stmt.String() // sql, err := stmt.String()
// c.Assert(err, gc.IsNil) // c.Assert(err, gc.IsNil)
// //
@ -99,7 +99,7 @@ func TestUpdate(t *testing.T) {
//func (s *StmtSuite) TestUpdateWithLimit(c *gc.C) { //func (s *StmtSuite) TestUpdateWithLimit(c *gc.C) {
// stmt := table1.UPDATE().SET(table1Col1, Literal(1)) // stmt := table1.UPDATE().SET(table1Col1, Literal(1))
// stmt.WHERE(EqL(table1Col2, 2)) // stmt.WHERE(EqL(table1Col2, 2))
// stmt.Limit(5) // stmt.LIMIT(5)
// sql, err := stmt.String() // sql, err := stmt.String()
// c.Assert(err, gc.IsNil) // c.Assert(err, gc.IsNil)
// //

View file

@ -1,19 +1,20 @@
package sqlbuilder package sqlbuilder
import ( import (
"bytes"
"database/sql" "database/sql"
"errors"
"github.com/sub0zero/go-sqlbuilder/sqlbuilder/execution" "github.com/sub0zero/go-sqlbuilder/sqlbuilder/execution"
"github.com/sub0zero/go-sqlbuilder/types" "github.com/sub0zero/go-sqlbuilder/types"
) )
func serializeExpressionList(expressions []Expression, buf *bytes.Buffer) error { func serializeOrderByClauseList(orderByClauses []OrderByClause, out *queryData) error {
for i, value := range expressions {
for i, value := range orderByClauses {
if i > 0 { if i > 0 {
buf.WriteString(", ") out.WriteString(", ")
} }
err := value.SerializeSql(buf) err := value.Serialize(out)
if err != nil { if err != nil {
return err return err
@ -23,13 +24,33 @@ func serializeExpressionList(expressions []Expression, buf *bytes.Buffer) error
return nil return nil
} }
func serializeProjectionList(projections []Projection, buf *bytes.Buffer) error { func serializeClauseList(clauses []Clause, out *queryData) (err error) {
for i, value := range projections {
for i, c := range clauses {
if i > 0 { if i > 0 {
buf.WriteString(", ") out.WriteString(", ")
} }
err := value.SerializeForProjection(buf) if c == nil {
return errors.New("nil clause.")
}
if err = c.Serialize(out); err != nil {
return
}
}
return nil
}
func serializeExpressionList(expressions []Expression, separator string, out *queryData) error {
for i, value := range expressions {
if i > 0 {
out.WriteString(separator)
}
err := value.Serialize(out)
if err != nil { if err != nil {
return err return err
@ -39,24 +60,55 @@ func serializeProjectionList(projections []Projection, buf *bytes.Buffer) error
return nil return nil
} }
func serializeProjectionList(projections []Projection, out *queryData) error {
for i, col := range projections {
if i > 0 {
out.WriteByte(',')
}
if col == nil {
return errors.New("Projection expression is nil.")
}
if err := col.SerializeForProjection(out); err != nil {
return err
}
}
return nil
}
func serializeColumnList(columns []Column, out *queryData) error {
for i, col := range columns {
if i > 0 {
out.WriteByte(',')
}
if col == nil {
return errors.New("nil column in columns list.")
}
out.WriteString(col.Name())
}
return nil
}
func Query(statement Statement, db types.Db, destination interface{}) error { func Query(statement Statement, db types.Db, destination interface{}) error {
query, err := statement.String() query, args, err := statement.Sql()
if err != nil { if err != nil {
return err return err
} }
return execution.Execute(db, query, destination) return execution.Query(db, query, args, destination)
} }
func Execute(statement Statement, db types.Db) (res sql.Result, err error) { func Execute(statement Statement, db types.Db) (res sql.Result, err error) {
query, err := statement.String() query, args, err := statement.Sql()
if err != nil { if err != nil {
return return
} }
res, err = db.Exec(query) return db.Exec(query, args...)
return
} }

View file

@ -25,13 +25,14 @@ func TestGenerateModel(t *testing.T) {
func TestSelect_ScanToStruct(t *testing.T) { func TestSelect_ScanToStruct(t *testing.T) {
actor := model.Actor{} actor := model.Actor{}
query := Actor.SELECT(Actor.AllColumns).OrderBy(Actor.ActorID.Asc()) query := Actor.SELECT(Actor.AllColumns).ORDER_BY(Actor.ActorID.Asc())
queryStr, err := query.String() queryStr, args, err := query.Sql()
fmt.Println(queryStr) fmt.Println(queryStr)
assert.Equal(t, queryStr, `SELECT actor.actor_id AS "actor.actor_id", actor.first_name AS "actor.first_name", actor.last_name AS "actor.last_name", actor.last_update AS "actor.last_update" FROM dvds.actor ORDER BY actor.actor_id ASC`) assert.Equal(t, queryStr, `SELECT actor.actor_id AS "actor.actor_id", actor.first_name AS "actor.first_name", actor.last_name AS "actor.last_name", actor.last_update AS "actor.last_update" FROM dvds.actor ORDER BY actor.actor_id ASC`)
assert.Equal(t, len(args), 0)
err = query.Query(db, &actor) err = query.Query(db, &actor)
@ -50,12 +51,14 @@ func TestSelect_ScanToStruct(t *testing.T) {
func TestSelect_ScanToSlice(t *testing.T) { func TestSelect_ScanToSlice(t *testing.T) {
customers := []model.Customer{} customers := []model.Customer{}
query := Customer.SELECT(Customer.AllColumns).OrderBy(Customer.CustomerID.Asc()) query := Customer.SELECT(Customer.AllColumns).ORDER_BY(Customer.CustomerID.Asc())
queryStr, err := query.String() queryStr, args, err := query.Sql()
assert.NilError(t, err) assert.NilError(t, err)
fmt.Println(queryStr) fmt.Println(queryStr)
assert.Equal(t, queryStr, `SELECT customer.customer_id AS "customer.customer_id", customer.store_id AS "customer.store_id", customer.first_name AS "customer.first_name", customer.last_name AS "customer.last_name", customer.email AS "customer.email", customer.address_id AS "customer.address_id", customer.activebool AS "customer.activebool", customer.create_date AS "customer.create_date", customer.last_update AS "customer.last_update", customer.active AS "customer.active" FROM dvds.customer ORDER BY customer.customer_id ASC`) assert.Equal(t, queryStr, `SELECT customer.customer_id AS "customer.customer_id", customer.store_id AS "customer.store_id", customer.first_name AS "customer.first_name", customer.last_name AS "customer.last_name", customer.email AS "customer.email", customer.address_id AS "customer.address_id", customer.activebool AS "customer.activebool", customer.create_date AS "customer.create_date", customer.last_update AS "customer.last_update", customer.active AS "customer.active" FROM dvds.customer ORDER BY customer.customer_id ASC`)
assert.Equal(t, len(args), 0)
err = query.Query(db, &customers) err = query.Query(db, &customers)
assert.NilError(t, err) assert.NilError(t, err)
@ -76,7 +79,7 @@ func TestSelect_ScanToSlice(t *testing.T) {
// SELECT(FilmActor.AllColumns, Film.AllColumns, Language.AllColumns, Actor.AllColumns). // SELECT(FilmActor.AllColumns, Film.AllColumns, Language.AllColumns, Actor.AllColumns).
// WHERE(FilmActor.ActorID.GtEq(1).And(FilmActor.ActorID.LteLiteral(2))) // WHERE(FilmActor.ActorID.GtEq(1).And(FilmActor.ActorID.LteLiteral(2)))
// //
// queryStr, err := query.String() // queryStr, args, err := query.Sql()
// assert.NilError(t, err) // assert.NilError(t, err)
// assert.Equal(t, queryStr, `SELECT film_actor.actor_id AS "film_actor.actor_id", film_actor.film_id AS "film_actor.film_id", film_actor.last_update AS "film_actor.last_update",film.film_id AS "film.film_id", film.title AS "film.title", film.description AS "film.description", film.release_year AS "film.release_year", film.language_id AS "film.language_id", film.rental_duration AS "film.rental_duration", film.rental_rate AS "film.rental_rate", film.length AS "film.length", film.replacement_cost AS "film.replacement_cost", film.rating AS "film.rating", film.last_update AS "film.last_update", film.special_features AS "film.special_features", film.fulltext AS "film.fulltext",language.language_id AS "language.language_id", language.name AS "language.name", language.last_update AS "language.last_update",actor.actor_id AS "actor.actor_id", actor.first_name AS "actor.first_name", actor.last_name AS "actor.last_name", actor.last_update AS "actor.last_update" FROM dvds.film_actor JOIN dvds.actor ON film_actor.actor_id = actor.actor_id JOIN dvds.film ON film_actor.film_id = film.film_id JOIN dvds.language ON film.language_id = language.language_id WHERE (film_actor.actor_id>=1 AND film_actor.actor_id<=2)`) // assert.Equal(t, queryStr, `SELECT film_actor.actor_id AS "film_actor.actor_id", film_actor.film_id AS "film_actor.film_id", film_actor.last_update AS "film_actor.last_update",film.film_id AS "film.film_id", film.title AS "film.title", film.description AS "film.description", film.release_year AS "film.release_year", film.language_id AS "film.language_id", film.rental_duration AS "film.rental_duration", film.rental_rate AS "film.rental_rate", film.length AS "film.length", film.replacement_cost AS "film.replacement_cost", film.rating AS "film.rating", film.last_update AS "film.last_update", film.special_features AS "film.special_features", film.fulltext AS "film.fulltext",language.language_id AS "language.language_id", language.name AS "language.name", language.last_update AS "language.last_update",actor.actor_id AS "actor.actor_id", actor.first_name AS "actor.first_name", actor.last_name AS "actor.last_name", actor.last_update AS "actor.last_update" FROM dvds.film_actor JOIN dvds.actor ON film_actor.actor_id = actor.actor_id JOIN dvds.film ON film_actor.film_id = film.film_id JOIN dvds.language ON film.language_id = language.language_id WHERE (film_actor.actor_id>=1 AND film_actor.actor_id<=2)`)
// //
@ -104,14 +107,18 @@ func TestJoinQuerySlice(t *testing.T) {
query := Film. query := Film.
INNER_JOIN(Language, Film.LanguageID.Eq(Language.LanguageID)). INNER_JOIN(Language, Film.LanguageID.Eq(Language.LanguageID)).
SELECT(Language.AllColumns, Film.AllColumns). SELECT(Language.AllColumns, Film.AllColumns).
Where(Film.Rating.EqL(string(model.MpaaRating_NC17))). WHERE(Film.Rating.EqL(string(model.MpaaRating_NC17))).
Limit(15) LIMIT(15)
queryStr, err := query.String() queryStr, args, err := query.Sql()
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, queryStr, `SELECT language.language_id AS "language.language_id", language.name AS "language.name", language.last_update AS "language.last_update",film.film_id AS "film.film_id", film.title AS "film.title", film.description AS "film.description", film.release_year AS "film.release_year", film.language_id AS "film.language_id", film.rental_duration AS "film.rental_duration", film.rental_rate AS "film.rental_rate", film.length AS "film.length", film.replacement_cost AS "film.replacement_cost", film.rating AS "film.rating", film.last_update AS "film.last_update", film.special_features AS "film.special_features", film.fulltext AS "film.fulltext" FROM dvds.film JOIN dvds.language ON film.language_id = language.language_id WHERE film.rating = 'NC-17' LIMIT 15`) fmt.Println(queryStr)
//fmt.Println(queryStr) assert.Equal(t, queryStr, `SELECT language.language_id AS "language.language_id", language.name AS "language.name", language.last_update AS "language.last_update",film.film_id AS "film.film_id", film.title AS "film.title", film.description AS "film.description", film.release_year AS "film.release_year", film.language_id AS "film.language_id", film.rental_duration AS "film.rental_duration", film.rental_rate AS "film.rental_rate", film.length AS "film.length", film.replacement_cost AS "film.replacement_cost", film.rating AS "film.rating", film.last_update AS "film.last_update", film.special_features AS "film.special_features", film.fulltext AS "film.fulltext" FROM dvds.film JOIN dvds.language ON film.language_id = language.language_id WHERE film.rating = $1 LIMIT $2`)
assert.Equal(t, len(args), 2)
assert.Equal(t, args[0], string(model.MpaaRating_NC17))
assert.Equal(t, args[1], int64(15))
err = query.Query(db, &filmsPerLanguage) err = query.Query(db, &filmsPerLanguage)
@ -149,7 +156,7 @@ func TestJoinQuerySliceWithPtrs(t *testing.T) {
query := Film.INNER_JOIN(Language, Film.LanguageID.Eq(Language.LanguageID)). query := Film.INNER_JOIN(Language, Film.LanguageID.Eq(Language.LanguageID)).
SELECT(Language.AllColumns, Film.AllColumns). SELECT(Language.AllColumns, Film.AllColumns).
Limit(limit) LIMIT(limit)
filmsPerLanguageWithPtrs := []*FilmsPerLanguage{} filmsPerLanguageWithPtrs := []*FilmsPerLanguage{}
err := query.Query(db, &filmsPerLanguageWithPtrs) err := query.Query(db, &filmsPerLanguageWithPtrs)
@ -179,7 +186,7 @@ func TestSelectOrderByAscDesc(t *testing.T) {
customersAsc := []model.Customer{} customersAsc := []model.Customer{}
err := Customer.SELECT(Customer.CustomerID, Customer.FirstName, Customer.LastName). err := Customer.SELECT(Customer.CustomerID, Customer.FirstName, Customer.LastName).
OrderBy(Customer.FirstName.Asc()). ORDER_BY(Customer.FirstName.Asc()).
Query(db, &customersAsc) Query(db, &customersAsc)
assert.NilError(t, err) assert.NilError(t, err)
@ -189,7 +196,7 @@ func TestSelectOrderByAscDesc(t *testing.T) {
customersDesc := []model.Customer{} customersDesc := []model.Customer{}
err = Customer.SELECT(Customer.CustomerID, Customer.FirstName, Customer.LastName). err = Customer.SELECT(Customer.CustomerID, Customer.FirstName, Customer.LastName).
OrderBy(Customer.FirstName.Desc()). ORDER_BY(Customer.FirstName.Desc()).
Query(db, &customersDesc) Query(db, &customersDesc)
assert.NilError(t, err) assert.NilError(t, err)
@ -202,7 +209,7 @@ func TestSelectOrderByAscDesc(t *testing.T) {
customersAscDesc := []model.Customer{} customersAscDesc := []model.Customer{}
err = Customer.SELECT(Customer.CustomerID, Customer.FirstName, Customer.LastName). err = Customer.SELECT(Customer.CustomerID, Customer.FirstName, Customer.LastName).
OrderBy(Customer.FirstName.Asc(), Customer.LastName.Desc()). ORDER_BY(Customer.FirstName.Asc(), Customer.LastName.Desc()).
Query(db, &customersAscDesc) Query(db, &customersAscDesc)
assert.NilError(t, err) assert.NilError(t, err)
@ -227,13 +234,14 @@ func TestSelectFullJoin(t *testing.T) {
query := Customer. query := Customer.
FULL_JOIN(Address, Customer.AddressID.Eq(Address.AddressID)). FULL_JOIN(Address, Customer.AddressID.Eq(Address.AddressID)).
SELECT(Customer.AllColumns, Address.AllColumns). SELECT(Customer.AllColumns, Address.AllColumns).
OrderBy(Customer.CustomerID.Asc()) ORDER_BY(Customer.CustomerID.Asc())
queryStr, err := query.String() queryStr, args, err := query.Sql()
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, queryStr, `SELECT customer.customer_id AS "customer.customer_id", customer.store_id AS "customer.store_id", customer.first_name AS "customer.first_name", customer.last_name AS "customer.last_name", customer.email AS "customer.email", customer.address_id AS "customer.address_id", customer.activebool AS "customer.activebool", customer.create_date AS "customer.create_date", customer.last_update AS "customer.last_update", customer.active AS "customer.active",address.address_id AS "address.address_id", address.address AS "address.address", address.address2 AS "address.address2", address.district AS "address.district", address.city_id AS "address.city_id", address.postal_code AS "address.postal_code", address.phone AS "address.phone", address.last_update AS "address.last_update" FROM dvds.customer FULL JOIN dvds.address ON customer.address_id = address.address_id ORDER BY customer.customer_id ASC`) assert.Equal(t, queryStr, `SELECT customer.customer_id AS "customer.customer_id", customer.store_id AS "customer.store_id", customer.first_name AS "customer.first_name", customer.last_name AS "customer.last_name", customer.email AS "customer.email", customer.address_id AS "customer.address_id", customer.activebool AS "customer.activebool", customer.create_date AS "customer.create_date", customer.last_update AS "customer.last_update", customer.active AS "customer.active",address.address_id AS "address.address_id", address.address AS "address.address", address.address2 AS "address.address2", address.district AS "address.district", address.city_id AS "address.city_id", address.postal_code AS "address.postal_code", address.phone AS "address.phone", address.last_update AS "address.last_update" FROM dvds.customer FULL JOIN dvds.address ON customer.address_id = address.address_id ORDER BY customer.customer_id ASC`)
assert.Equal(t, len(args), 0)
allCustomersAndAddress := []struct { allCustomersAndAddress := []struct {
Address *model.Address Address *model.Address
@ -259,13 +267,14 @@ func TestSelectFullCrossJoin(t *testing.T) {
query := Customer. query := Customer.
CrossJoin(Address). CrossJoin(Address).
SELECT(Customer.AllColumns, Address.AllColumns). SELECT(Customer.AllColumns, Address.AllColumns).
OrderBy(Customer.CustomerID.Asc()). ORDER_BY(Customer.CustomerID.Asc()).
Limit(1000) LIMIT(1000)
queryStr, err := query.String() queryStr, args, err := query.Sql()
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, queryStr, `SELECT customer.customer_id AS "customer.customer_id", customer.store_id AS "customer.store_id", customer.first_name AS "customer.first_name", customer.last_name AS "customer.last_name", customer.email AS "customer.email", customer.address_id AS "customer.address_id", customer.activebool AS "customer.activebool", customer.create_date AS "customer.create_date", customer.last_update AS "customer.last_update", customer.active AS "customer.active",address.address_id AS "address.address_id", address.address AS "address.address", address.address2 AS "address.address2", address.district AS "address.district", address.city_id AS "address.city_id", address.postal_code AS "address.postal_code", address.phone AS "address.phone", address.last_update AS "address.last_update" FROM dvds.customer CROSS JOIN dvds.address ORDER BY customer.customer_id ASC LIMIT 1000`) assert.Equal(t, queryStr, `SELECT customer.customer_id AS "customer.customer_id", customer.store_id AS "customer.store_id", customer.first_name AS "customer.first_name", customer.last_name AS "customer.last_name", customer.email AS "customer.email", customer.address_id AS "customer.address_id", customer.activebool AS "customer.activebool", customer.create_date AS "customer.create_date", customer.last_update AS "customer.last_update", customer.active AS "customer.active",address.address_id AS "address.address_id", address.address AS "address.address", address.address2 AS "address.address2", address.district AS "address.district", address.city_id AS "address.city_id", address.postal_code AS "address.postal_code", address.phone AS "address.phone", address.last_update AS "address.last_update" FROM dvds.customer CROSS JOIN dvds.address ORDER BY customer.customer_id ASC LIMIT $1`)
assert.Equal(t, len(args), 1)
customerAddresCrosJoined := []model.Customer{} customerAddresCrosJoined := []model.Customer{}
@ -286,9 +295,10 @@ func TestSelectSelfJoin(t *testing.T) {
query := f1. query := f1.
INNER_JOIN(f2, f1.FilmID.NotEq(f2.FilmID).And(f1.Length.Eq(f2.Length))). INNER_JOIN(f2, f1.FilmID.NotEq(f2.FilmID).And(f1.Length.Eq(f2.Length))).
SELECT(f1.AllColumns, f2.AllColumns). SELECT(f1.AllColumns, f2.AllColumns).
OrderBy(f1.FilmID.Asc()) ORDER_BY(f1.FilmID.Asc())
queryStr, err := query.String() queryStr, args, err := query.Sql()
assert.Equal(t, len(args), 0)
assert.NilError(t, err) assert.NilError(t, err)
@ -326,10 +336,11 @@ func TestSelectAliasColumn(t *testing.T) {
SELECT(f1.Title.As("thesame_length_films.title1"), SELECT(f1.Title.As("thesame_length_films.title1"),
f2.Title.As("thesame_length_films.title2"), f2.Title.As("thesame_length_films.title2"),
f1.Length.As("thesame_length_films.length")). f1.Length.As("thesame_length_films.length")).
OrderBy(f1.Length.Asc(), f1.Title.Asc(), f2.Title.Asc()). ORDER_BY(f1.Length.Asc(), f1.Title.Asc(), f2.Title.Asc()).
Limit(1000) LIMIT(1000)
queryStr, err := query.String() queryStr, args, err := query.Sql()
assert.Equal(t, len(args), 1)
assert.NilError(t, err) assert.NilError(t, err)
@ -372,9 +383,10 @@ func TestSelectSelfReferenceType(t *testing.T) {
INNER_JOIN(manager, Staff.StaffID.Eq(manager.StaffID)). INNER_JOIN(manager, Staff.StaffID.Eq(manager.StaffID)).
SELECT(Staff.StaffID, Staff.FirstName, Staff.LastName, Address.AllColumns, manager.StaffID, manager.FirstName) SELECT(Staff.StaffID, Staff.FirstName, Staff.LastName, Address.AllColumns, manager.StaffID, manager.FirstName)
queryStr, err := query.String() queryStr, args, err := query.Sql()
assert.NilError(t, err) assert.NilError(t, err)
fmt.Println(queryStr) fmt.Println(queryStr)
assert.Equal(t, len(args), 0)
staffs := []staff{} staffs := []staff{}
@ -394,13 +406,13 @@ func TestSubQuery(t *testing.T) {
// selectStmtTable.RefIntColumnName("actor.last_name").As("nesto2"), // selectStmtTable.RefIntColumnName("actor.last_name").As("nesto2"),
// ) // )
// //
//queryStr, err := query.String() //queryStr, args, err := query.Sql()
// //
//assert.NilError(t, err) //assert.NilError(t, err)
// //
//fmt.Println(queryStr) //fmt.Println(queryStr)
// //
//avrgCustomer := sqlbuilder.NumExp(Customer.SELECT(Customer.LastName).Limit(1)) //avrgCustomer := sqlbuilder.NumExp(Customer.SELECT(Customer.LastName).LIMIT(1))
// //
//Customer. //Customer.
// INNER_JOIN(selectStmtTable, Customer.LastName.Eq(selectStmtTable.RefStringColumn(Actor.FirstName))). // INNER_JOIN(selectStmtTable, Customer.LastName.Eq(selectStmtTable.RefStringColumn(Actor.FirstName))).
@ -408,7 +420,7 @@ func TestSubQuery(t *testing.T) {
// WHERE(Actor.LastName.Neq(avrgCustomer)) // WHERE(Actor.LastName.Neq(avrgCustomer))
rFilmsOnly := Film.SELECT(Film.FilmID, Film.Title, Film.Rating). rFilmsOnly := Film.SELECT(Film.FilmID, Film.Title, Film.Rating).
Where(Film.Rating.EqL("R")). WHERE(Film.Rating.EqL("R")).
AsTable("films") AsTable("films")
query := Actor.INNER_JOIN(FilmActor, Actor.ActorID.Eq(FilmActor.FilmID)). query := Actor.INNER_JOIN(FilmActor, Actor.ActorID.Eq(FilmActor.FilmID)).
@ -420,10 +432,10 @@ func TestSubQuery(t *testing.T) {
rFilmsOnly.RefStringColumn(Film.Rating).As("film.rating"), rFilmsOnly.RefStringColumn(Film.Rating).As("film.rating"),
) )
queryStr, err := query.String() queryStr, args, err := query.Sql()
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, len(args), 1)
fmt.Println(queryStr) fmt.Println(queryStr)
} }
@ -431,12 +443,12 @@ func TestSubQuery(t *testing.T) {
func TestSelectFunctions(t *testing.T) { func TestSelectFunctions(t *testing.T) {
query := Film.SELECT(sqlbuilder.MAX(Film.RentalRate).As("max_film_rate")) query := Film.SELECT(sqlbuilder.MAX(Film.RentalRate).As("max_film_rate"))
str, err := query.String() str, args, err := query.Sql()
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, str, `SELECT MAX(film.rental_rate) AS "max_film_rate" FROM dvds.film`) assert.Equal(t, str, `SELECT MAX(film.rental_rate) AS "max_film_rate" FROM dvds.film`)
assert.Equal(t, len(args), 0)
fmt.Println(str) fmt.Println(str)
} }
@ -445,13 +457,13 @@ func TestSelectQueryScalar(t *testing.T) {
maxFilmRentalRate := sqlbuilder.NumExp(Film.SELECT(sqlbuilder.MAX(Film.RentalRate))) maxFilmRentalRate := sqlbuilder.NumExp(Film.SELECT(sqlbuilder.MAX(Film.RentalRate)))
query := Film.SELECT(Film.AllColumns). query := Film.SELECT(Film.AllColumns).
Where(Film.RentalRate.Eq(maxFilmRentalRate)). WHERE(Film.RentalRate.Eq(maxFilmRentalRate)).
OrderBy(Film.FilmID.Asc()) ORDER_BY(Film.FilmID.Asc())
queryStr, err := query.String() queryStr, args, err := query.Sql()
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, len(args), 0)
fmt.Println(queryStr) fmt.Println(queryStr)
maxRentalRateFilms := []model.Film{} maxRentalRateFilms := []model.Film{}
@ -488,16 +500,17 @@ func TestSelectGroupByHaving(t *testing.T) {
Payment.CustomerID.As("customer_payment_sum.customer_id"), Payment.CustomerID.As("customer_payment_sum.customer_id"),
sqlbuilder.SUM(Payment.Amount).As("customer_payment_sum.amount_sum"), sqlbuilder.SUM(Payment.Amount).As("customer_payment_sum.amount_sum"),
). ).
GroupBy(Payment.CustomerID). GROUP_BY(Payment.CustomerID).
OrderBy(sqlbuilder.SUM(Payment.Amount).Asc()). ORDER_BY(sqlbuilder.SUM(Payment.Amount).Asc()).
HAVING(sqlbuilder.Gt(sqlbuilder.SUM(Payment.Amount), sqlbuilder.Literal(100))) HAVING(sqlbuilder.SUM(Payment.Amount).Gt(sqlbuilder.NewNumericLiteral(100)))
queryStr, err := customersPaymentQuery.String() queryStr, args, err := customersPaymentQuery.Sql()
assert.NilError(t, err) assert.NilError(t, err)
fmt.Println(queryStr) fmt.Println(queryStr)
assert.Equal(t, len(args), 1)
assert.Equal(t, queryStr, `SELECT payment.customer_id AS "customer_payment_sum.customer_id",SUM(payment.amount) AS "customer_payment_sum.amount_sum" FROM dvds.payment GROUP BY payment.customer_id HAVING SUM(payment.amount)>$1 ORDER BY SUM(payment.amount) ASC`)
assert.Equal(t, queryStr, `SELECT payment.customer_id AS "customer_payment_sum.customer_id",SUM(payment.amount) AS "customer_payment_sum.amount_sum" FROM dvds.payment GROUP BY payment.customer_id HAVING SUM(payment.amount)>100 ORDER BY SUM(payment.amount) ASC`)
type CustomerPaymentSum struct { type CustomerPaymentSum struct {
CustomerID int16 CustomerID int16
AmountSum float64 AmountSum float64
@ -528,7 +541,7 @@ func TestSelectGroupBy2(t *testing.T) {
Payment.CustomerID, Payment.CustomerID,
sqlbuilder.SUM(Payment.Amount).As("amount_sum"), sqlbuilder.SUM(Payment.Amount).As("amount_sum"),
). ).
GroupBy(Payment.CustomerID) GROUP_BY(Payment.CustomerID)
customersPaymentTable := customersPaymentSubQuery.AsTable("customer_payment_sum") customersPaymentTable := customersPaymentSubQuery.AsTable("customer_payment_sum")
amountSumColumn := customersPaymentTable.RefIntColumnName("amount_sum") amountSumColumn := customersPaymentTable.RefIntColumnName("amount_sum")
@ -536,11 +549,12 @@ func TestSelectGroupBy2(t *testing.T) {
query := Customer. query := Customer.
INNER_JOIN(customersPaymentTable, Customer.CustomerID.Eq(customersPaymentTable.RefIntColumn(Payment.CustomerID))). INNER_JOIN(customersPaymentTable, Customer.CustomerID.Eq(customersPaymentTable.RefIntColumn(Payment.CustomerID))).
SELECT(Customer.AllColumns, amountSumColumn.As("customer_with_amounts.amount_sum")). SELECT(Customer.AllColumns, amountSumColumn.As("customer_with_amounts.amount_sum")).
OrderBy(amountSumColumn.Asc()) ORDER_BY(amountSumColumn.Asc())
queryStr, err := query.String() queryStr, args, err := query.Sql()
assert.NilError(t, err) assert.NilError(t, err)
fmt.Println(queryStr) fmt.Println(queryStr)
assert.Equal(t, len(args), 0)
err = query.Query(db, &customersWithAmounts) err = query.Query(db, &customersWithAmounts)
assert.NilError(t, err) assert.NilError(t, err)
@ -565,13 +579,13 @@ func TestSelectGroupBy2(t *testing.T) {
func TestSelectTimeColumns(t *testing.T) { func TestSelectTimeColumns(t *testing.T) {
query := Payment.SELECT(Payment.AllColumns). query := Payment.SELECT(Payment.AllColumns).
Where(Payment.PaymentDate.LtEqL("2007-02-14 22:16:01")). WHERE(Payment.PaymentDate.LtEqL("2007-02-14 22:16:01")).
OrderBy(Payment.PaymentDate.Asc()) ORDER_BY(Payment.PaymentDate.Asc())
queryStr, err := query.String() queryStr, args, err := query.Sql()
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, len(args), 1)
fmt.Println(queryStr) fmt.Println(queryStr)
payments := []model.Payment{} payments := []model.Payment{}

View file

@ -18,13 +18,14 @@ func TestInsertValues(t *testing.T) {
VALUES("http://www.bing.com", "Bing", sqlbuilder.DEFAULT). VALUES("http://www.bing.com", "Bing", sqlbuilder.DEFAULT).
RETURNING(table.Link.ID) RETURNING(table.Link.ID)
insertQueryStr, err := insertQuery.String() insertQueryStr, args, err := insertQuery.Sql()
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, len(args), 8)
fmt.Println(insertQueryStr) fmt.Println(insertQueryStr)
assert.Equal(t, insertQueryStr, `INSERT INTO test_sample.link (url,name,rel) VALUES ('http://www.postgresqltutorial.com','PostgreSQL Tutorial',DEFAULT), ('http://www.google.com','Google',DEFAULT), ('http://www.yahoo.com','Yahoo',DEFAULT), ('http://www.bing.com','Bing',DEFAULT) RETURNING link.id AS "link.id";`) assert.Equal(t, insertQueryStr, `INSERT INTO test_sample.link (url,name,rel) VALUES ($1, $2, DEFAULT), ($3, $4, DEFAULT), ($5, $6, DEFAULT), ($7, $8, DEFAULT) RETURNING link.id AS "link.id";`)
res, err := insertQuery.Execute(db) res, err := insertQuery.Execute(db)
assert.NilError(t, err) assert.NilError(t, err)
@ -68,9 +69,10 @@ func TestInsertDataObject(t *testing.T) {
INSERT(table.Link.URL, table.Link.Name). INSERT(table.Link.URL, table.Link.Name).
VALUES_MAPPING(linkData) VALUES_MAPPING(linkData)
queryStr, err := query.String() queryStr, args, err := query.Sql()
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, len(args), 2)
fmt.Println(queryStr) fmt.Println(queryStr)
@ -92,9 +94,10 @@ func TestInsertQuery(t *testing.T) {
INSERT(table.Link.URL, table.Link.Name). INSERT(table.Link.URL, table.Link.Name).
QUERY(table.Link.SELECT(table.Link.URL, table.Link.Name)) QUERY(table.Link.SELECT(table.Link.URL, table.Link.Name))
queryStr, err := query.String() queryStr, args, err := query.Sql()
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, len(args), 0)
fmt.Println(queryStr) fmt.Println(queryStr)

View file

@ -12,11 +12,12 @@ import (
func TestUUIDType(t *testing.T) { func TestUUIDType(t *testing.T) {
query := table.AllTypes. query := table.AllTypes.
SELECT(table.AllTypes.AllColumns). SELECT(table.AllTypes.AllColumns).
Where(table.AllTypes.UUID.EqL("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11")) WHERE(table.AllTypes.UUID.EqL("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"))
queryStr, err := query.String() queryStr, args, err := query.Sql()
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, len(args), 1)
fmt.Println(queryStr) fmt.Println(queryStr)
//assert.Equal(t, queryStr, `SELECT all_types.character AS "all_types.character", all_types.character_varying AS "all_types.character_varying", all_types.text AS "all_types.text", all_types.bytea AS "all_types.bytea", all_types.timestamp_without_time_zone AS "all_types.timestamp_without_time_zone", all_types.timestamp_with_time_zone AS "all_types.timestamp_with_time_zone", all_types.uuid AS "all_types.uuid", all_types.json AS "all_types.json", all_types.jsonb AS "all_types.jsonb" FROM test_sample.all_types WHERE all_types.uuid = 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11`) //assert.Equal(t, queryStr, `SELECT all_types.character AS "all_types.character", all_types.character_varying AS "all_types.character_varying", all_types.text AS "all_types.text", all_types.bytea AS "all_types.bytea", all_types.timestamp_without_time_zone AS "all_types.timestamp_without_time_zone", all_types.timestamp_with_time_zone AS "all_types.timestamp_with_time_zone", all_types.uuid AS "all_types.uuid", all_types.json AS "all_types.json", all_types.jsonb AS "all_types.jsonb" FROM test_sample.all_types WHERE all_types.uuid = 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11`)
result := model.AllTypes{} result := model.AllTypes{}
@ -29,11 +30,11 @@ func TestEnumType(t *testing.T) {
query := table.Person. query := table.Person.
SELECT(table.Person.AllColumns) SELECT(table.Person.AllColumns)
queryStr, err := query.String() queryStr, args, err := query.Sql()
assert.NilError(t, err) assert.NilError(t, err)
fmt.Println(queryStr) fmt.Println(queryStr)
assert.Equal(t, len(args), 0)
result := []model.Person{} result := []model.Person{}
err = query.Query(db, &result) err = query.Query(db, &result)

View file

@ -23,10 +23,10 @@ func TestUpdateValues(t *testing.T) {
SET("Bong", "http://bong.com"). SET("Bong", "http://bong.com").
WHERE(table.Link.Name.EqL("Bing")) WHERE(table.Link.Name.EqL("Bing"))
queryStr, err := query.String() queryStr, args, err := query.Sql()
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, len(args), 3)
fmt.Println(queryStr) fmt.Println(queryStr)
res, err := query.Execute(db) res, err := query.Execute(db)
@ -38,7 +38,7 @@ func TestUpdateValues(t *testing.T) {
links := []model.Link{} links := []model.Link{}
err = table.Link.SELECT(table.Link.AllColumns). err = table.Link.SELECT(table.Link.AllColumns).
Where(table.Link.Name.EqL("Bong")). WHERE(table.Link.Name.EqL("Bong")).
Query(db, &links) Query(db, &links)
assert.NilError(t, err) assert.NilError(t, err)
@ -63,10 +63,10 @@ func TestUpdateAndReturning(t *testing.T) {
WHERE(table.Link.Name.EqL("Ask")). WHERE(table.Link.Name.EqL("Ask")).
RETURNING(table.Link.AllColumns) RETURNING(table.Link.AllColumns)
stmtStr, err := stmt.String() stmtStr, args, err := stmt.Sql()
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, len(args), 3)
fmt.Println(stmtStr) fmt.Println(stmtStr)
links := []model.Link{} links := []model.Link{}