Refactoring to support parameterized queries.

This commit is contained in:
zer0sub 2019-04-29 14:39:48 +02:00
parent bc6a2bbcac
commit fef8f0ef83
33 changed files with 1112 additions and 1206 deletions

View file

@ -1,7 +1,5 @@
package sqlbuilder
import "bytes"
type Alias struct {
expression Expression
alias string
@ -14,9 +12,9 @@ func NewAlias(expression Expression, alias string) *Alias {
}
}
func (a *Alias) SerializeForProjection(out *bytes.Buffer) error {
func (a *Alias) SerializeForProjection(out *queryData) error {
err := a.expression.SerializeSql(out, ALIASED)
err := a.expression.Serialize(out, SKIP_DEFAULT_ALIASING)
if err != nil {
return err

View file

@ -66,11 +66,7 @@ type boolLiteralExpression struct {
func newBoolLiteralExpression(value bool) BoolExpression {
boolLiteralExpression := boolLiteralExpression{}
sqlValue, err := sqltypes.BuildValue(value)
if err != nil {
panic(errors.Wrap(err, "Invalid literal value"))
}
boolLiteralExpression.literalExpression = *NewLiteralExpression(sqlValue)
boolLiteralExpression.literalExpression = *Literal(value)
boolLiteralExpression.boolInterfaceImpl.parent = &boolLiteralExpression
return &boolLiteralExpression
@ -113,27 +109,27 @@ func newPrefixBoolExpression(expression Expression, operator []byte) BoolExpress
}
//---------------------------------------------------//
type conjunctBoolExpression struct {
expressionInterfaceImpl
boolInterfaceImpl
conjunctExpression
name string
}
func NewConjunctBoolExpression(operator []byte, expressions ...BoolExpression) BoolExpression {
boolExpression := conjunctBoolExpression{
conjunctExpression: conjunctExpression{
expressions: expressions,
conjunction: operator,
},
}
boolExpression.expressionInterfaceImpl.parent = &boolExpression
boolExpression.boolInterfaceImpl.parent = &boolExpression
return &boolExpression
}
//type conjunctBoolExpression struct {
// expressionInterfaceImpl
// boolInterfaceImpl
//
// conjunctExpression
// name string
//}
//
//func NewConjunctBoolExpression(operator []byte, expressions ...BoolExpression) BoolExpression {
// boolExpression := conjunctBoolExpression{
// conjunctExpression: conjunctExpression{
// expressions: expressions,
// conjunction: operator,
// },
// }
//
// boolExpression.expressionInterfaceImpl.parent = &boolExpression
// boolExpression.boolInterfaceImpl.parent = &boolExpression
//
// return &boolExpression
//}
//---------------------------------------------------//
type inExpression struct {
@ -146,34 +142,33 @@ type inExpression struct {
err error
}
func (c *inExpression) SerializeSql(out *bytes.Buffer, options ...serializeOption) error {
func (c *inExpression) Serialize(out *queryData, options ...serializeOption) error {
if c.err != nil {
return errors.Wrap(c.err, "Invalid IN expression")
}
if c.lhs == nil {
return errors.Newf(
"lhs of in expression is nil. Generated sql: %s",
out.String())
return errors.Newf("lhs of in expression is nil.")
}
// We'll serialize the lhs even if we don't need it to ensure no error
buf := &bytes.Buffer{}
err := c.lhs.SerializeSql(buf)
err := c.lhs.Serialize(out, options...)
if err != nil {
return err
}
if c.rhs == nil {
_, _ = out.WriteString("FALSE")
out.WriteString("FALSE")
return nil
}
_, _ = out.WriteString(buf.String())
_, _ = out.WriteString(" IN ")
out.WriteString(buf.String())
out.WriteString(" IN ")
err = c.rhs.Serialize(out)
err = c.rhs.SerializeSql(out)
if err != nil {
return err
}
@ -183,10 +178,6 @@ func (c *inExpression) SerializeSql(out *bytes.Buffer, options ...serializeOptio
// Returns a representation of "a=b"
func Eq(lhs, rhs Expression) BoolExpression {
lit, ok := rhs.(*literalExpression)
if ok && sqltypes.Value(lit.value).IsNull() {
return newBinaryBoolExpression(lhs, rhs, []byte(" IS "))
}
return newBinaryBoolExpression(lhs, rhs, []byte(" = "))
}
@ -197,10 +188,6 @@ func EqL(lhs Expression, val interface{}) BoolExpression {
// Returns a representation of "a!=b"
func NotEq(lhs, rhs Expression) BoolExpression {
lit, ok := rhs.(*literalExpression)
if ok && sqltypes.Value(lit.value).IsNull() {
return newBinaryBoolExpression(lhs, rhs, []byte(" IS NOT "))
}
return newBinaryBoolExpression(lhs, rhs, []byte("!="))
}
@ -258,14 +245,13 @@ func IsTrue(expr BoolExpression) BoolExpression {
return newPrefixBoolExpression(expr, []byte(" IS TRUE "))
}
// Returns a representation of "c[0] AND ... AND c[n-1]" for c in clauses
func And(expressions ...BoolExpression) BoolExpression {
return NewConjunctBoolExpression([]byte(" AND "), expressions...)
func And(lhs, rhs Expression) BoolExpression {
return newBinaryBoolExpression(lhs, rhs, []byte(" AND "))
}
// Returns a representation of "c[0] OR ... OR c[n-1]" for c in clauses
func Or(expressions ...BoolExpression) BoolExpression {
return NewConjunctBoolExpression([]byte(" OR "), expressions...)
func Or(lhs, rhs Expression) BoolExpression {
return newBinaryBoolExpression(lhs, rhs, []byte(" OR "))
}
func Like(lhs, rhs Expression) BoolExpression {

View file

@ -10,7 +10,7 @@ func TestBinaryExpression(t *testing.T) {
boolExpression := Eq(Literal(2), Literal(3))
out := bytes.Buffer{}
err := boolExpression.SerializeSql(&out)
err := boolExpression.Serialize(&out)
assert.NilError(t, err)
assert.Equal(t, out.String(), "2 = 3")
@ -29,7 +29,7 @@ func TestBinaryExpression(t *testing.T) {
exp := boolExpression.And(Eq(Literal(4), Literal(5)))
out := bytes.Buffer{}
err := exp.SerializeSql(&out)
err := exp.Serialize(&out)
assert.NilError(t, err)
assert.Equal(t, out.String(), `(2 = 3 AND 4 = 5)`)
@ -39,7 +39,7 @@ func TestBinaryExpression(t *testing.T) {
exp := boolExpression.Or(Eq(Literal(4), Literal(5)))
out := bytes.Buffer{}
err := exp.SerializeSql(&out)
err := exp.Serialize(&out)
assert.NilError(t, err)
assert.Equal(t, out.String(), `(2 = 3 OR 4 = 5)`)
@ -50,7 +50,7 @@ func TestUnaryExpression(t *testing.T) {
notExpression := Not(Eq(Literal(2), Literal(1)))
out := bytes.Buffer{}
err := notExpression.SerializeSql(&out)
err := notExpression.Serialize(&out)
assert.NilError(t, err)
assert.Equal(t, out.String(), " NOT 2 = 1")
@ -69,7 +69,7 @@ func TestUnaryExpression(t *testing.T) {
exp := notExpression.And(Eq(Literal(4), Literal(5)))
out := bytes.Buffer{}
err := exp.SerializeSql(&out)
err := exp.Serialize(&out)
assert.NilError(t, err)
assert.Equal(t, out.String(), `( NOT 2 = 1 AND 4 = 5)`)
@ -80,7 +80,7 @@ func TestUnaryIsTrueExpression(t *testing.T) {
notExpression := IsTrue(Eq(Literal(2), Literal(1)))
out := bytes.Buffer{}
err := notExpression.SerializeSql(&out)
err := notExpression.Serialize(&out)
assert.NilError(t, err)
assert.Equal(t, out.String(), " IS TRUE 2 = 1")
@ -89,7 +89,7 @@ func TestUnaryIsTrueExpression(t *testing.T) {
exp := notExpression.And(Eq(Literal(4), Literal(5)))
out := bytes.Buffer{}
err := exp.SerializeSql(&out)
err := exp.Serialize(&out)
assert.NilError(t, err)
assert.Equal(t, out.String(), `( IS TRUE 2 = 1 AND 4 = 5)`)
@ -100,7 +100,7 @@ func TestBoolLiteral(t *testing.T) {
literal := newBoolLiteralExpression(true)
out := bytes.Buffer{}
err := literal.SerializeSql(&out)
err := literal.Serialize(&out)
assert.NilError(t, err)

View file

@ -1,16 +1,91 @@
package sqlbuilder
import "bytes"
import (
"bytes"
"errors"
"strconv"
)
type serializeOption int
const (
ALIASED = iota
SKIP_DEFAULT_ALIASING = iota
FOR_PROJECTION
)
type Clause interface {
SerializeSql(out *bytes.Buffer, options ...serializeOption) error
Serialize(out *queryData, options ...serializeOption) error
}
type queryData struct {
queryBuff bytes.Buffer
args []interface{}
}
func (q *queryData) Write(data []byte) {
q.queryBuff.Write(data)
}
func (q *queryData) WriteString(str string) {
q.queryBuff.WriteString(str)
}
func (q *queryData) WriteByte(b byte) {
q.queryBuff.WriteByte(b)
}
func (q *queryData) InsertArgument(arg interface{}) {
q.args = append(q.args, arg)
argPlaceholder := "$" + strconv.Itoa(len(q.args))
q.queryBuff.WriteString(argPlaceholder)
}
func argToString(value interface{}) (string, error) {
switch bindVal := value.(type) {
case bool:
if bindVal {
return "TRUE", nil
} else {
return "FALSE", nil
}
case int8:
return strconv.FormatInt(int64(bindVal), 10), nil
case int:
return strconv.FormatInt(int64(bindVal), 10), nil
case int16:
return strconv.FormatInt(int64(bindVal), 10), nil
case int32:
return strconv.FormatInt(int64(bindVal), 10), nil
case int64:
return strconv.FormatInt(int64(bindVal), 10), nil
case uint8:
return strconv.FormatUint(uint64(bindVal), 10), nil
case uint:
return strconv.FormatUint(uint64(bindVal), 10), nil
case uint16:
return strconv.FormatUint(uint64(bindVal), 10), nil
case uint32:
return strconv.FormatUint(uint64(bindVal), 10), nil
case uint64:
return strconv.FormatUint(uint64(bindVal), 10), nil
case float32:
return strconv.FormatFloat(float64(bindVal), 'f', -1, 64), nil
case float64:
return strconv.FormatFloat(float64(bindVal), 'f', -1, 64), nil
case string:
return bindVal, nil
case []byte:
return string(bindVal), nil
//TODO: implement
//case time.Time:
// return bindVal.String())
default:
return "", errors.New("Unsupported literal type. ")
}
}
func contains(s []serializeOption, e serializeOption) bool {

View file

@ -3,14 +3,9 @@
package sqlbuilder
import (
"bytes"
"regexp"
"strings"
)
// XXX: Maybe add UIntColumn
// Representation of a tableName for query generation
type Column interface {
Expression
@ -28,11 +23,6 @@ const (
NotNullable NullableColumn = false
)
//// A column that can be refer to outside of the projection list
//type NonAliasColumn interface {
// Column
//}
type Collation string
const (
@ -82,194 +72,39 @@ func (c *baseColumn) setTableName(table string) error {
return nil
}
func (c baseColumn) SerializeSql(out *bytes.Buffer, options ...serializeOption) error {
func (c baseColumn) Serialize(out *queryData, options ...serializeOption) error {
if c.tableName != "" {
_, _ = out.WriteString(c.tableName)
_, _ = out.WriteString(".")
out.WriteString(c.tableName)
out.WriteString(".")
}
containsDot := strings.Contains(c.name, ".")
if containsDot {
out.WriteString("\"")
}
_, _ = out.WriteString(c.name)
if containsDot {
out.WriteString("\"")
out.WriteString(`"`)
}
if contains(options, FOR_PROJECTION) && !contains(options, ALIASED) && c.tableName != "" {
_, _ = out.WriteString(" AS \"" + c.tableName + "." + c.name + "\"")
out.WriteString(c.name)
if containsDot {
out.WriteString(`"`)
}
if contains(options, FOR_PROJECTION) && !contains(options, SKIP_DEFAULT_ALIASING) && c.tableName != "" {
out.WriteString(" AS \"" + c.tableName + "." + c.name + "\"")
}
return nil
}
//
//type bytesColumn struct {
// baseColumn
//}
//// This is a strict subset of the actual allowed identifiers
//var validIdentifierRegexp = regexp.MustCompile("^[a-zA-Z_]\\w*$")
//
//// Representation of VARBINARY/BLOB columns
//// This function will panic if name is not valid
//func BytesColumn(name string, nullable NullableColumn) Column {
// if !validIdentifierName(name) {
// panic("Invalid column name in bytes column")
//// Returns true if the given string is suitable as an identifier.
//func validIdentifierName(name string) bool {
// return validIdentifierRegexp.MatchString(name)
//}
// bc := &bytesColumn{}
// bc.name = name
// bc.nullable = nullable
// return bc
//}
//
//type stringColumn struct {
// baseColumn
// charset Charset
// collation Collation
//}
//
//// Representation of VARCHAR/TEXT columns
//// This function will panic if name is not valid
//func StrColumn(
// name string,
// charset Charset,
// collation Collation,
// nullable NullableColumn) Column {
//
// if !validIdentifierName(name) {
// panic("Invalid column name in str column")
// }
// sc := &stringColumn{charset: charset, collation: collation}
// sc.name = name
// sc.nullable = nullable
// return sc
//}
//
//type dateTimeColumn struct {
// baseColumn
//}
//
//// Representation of DateTime columns
//// This function will panic if name is not valid
//func DateTimeColumn(name string, nullable NullableColumn) Column {
// if !validIdentifierName(name) {
// panic("Invalid column name in datetime column")
// }
// dc := &dateTimeColumn{}
// dc.name = name
// dc.nullable = nullable
// return dc
//}
//type IntegerColumn struct {
// baseColumn
//}
//
//// Representation of any integer column
//// This function will panic if name is not valid
//func IntColumn(name string, nullable NullableColumn) *IntegerColumn {
// if !validIdentifierName(name) {
// panic("Invalid column name in int column")
// }
// ic := &IntegerColumn{}
// ic.name = name
// ic.nullable = nullable
// return ic
//}
//type doubleColumn struct {
// baseColumn
//}
//
//// Representation of any double column
//// This function will panic if name is not valid
//func DoubleColumn(name string, nullable NullableColumn) Column {
// if !validIdentifierName(name) {
// panic("Invalid column name in int column")
// }
// ic := &doubleColumn{}
// ic.name = name
// ic.nullable = nullable
// return ic
//}
//
//type booleanColumn struct {
// baseColumn
//
// // XXX: Maybe allow isBoolExpression (for now, not included because
// // the deferred lookup equivalent can never be isBoolExpression)
//}
// Representation of TINYINT used as a bool
// This function will panic if name is not valid
//func NewBoolColumn(name string, nullable NullableColumn) Column {
// if !validIdentifierName(name) {
// panic("Invalid column name in bool column")
// }
// bc := &booleanColumn{}
// bc.name = name
// bc.nullable = nullable
// return bc
//}
//
//type aliasColumn struct {
// baseColumn
// expression Expression
//}
//
//func (c *aliasColumn) SerializeSql(out *bytes.Buffer) error {
// _ = out.WriteByte('`')
// _, _ = out.WriteString(c.name)
// _ = out.WriteByte('`')
// return nil
//}
//
//func (c *aliasColumn) SerializeSqlForColumnList(out *bytes.Buffer) error {
// if !validIdentifierName(c.name) {
// return errors.Newf(
// "Invalid alias name `%s`. Generated sql: %s",
// c.name,
// out.String())
// }
// if c.expression == nil {
// return errors.Newf(
// "Cannot alias a nil expression. Generated sql: %s",
// out.String())
// }
//
// _ = out.WriteByte('(')
// if c.expression == nil {
// return errors.Newf("nil alias clause. Generate sql: %s", out.String())
// }
// if err := c.expression.SerializeSql(out); err != nil {
// return err
// }
// _, _ = out.WriteString(") AS \"")
// _, _ = out.WriteString(c.name)
// _ = out.WriteByte('"')
// return nil
//}
//func (c *aliasColumn) setTableName(table string) error {
// return errors.Newf(
// "Alias column '%s' should never have setTableName called on it",
// c.name)
//}
// Representation of aliased clauses (expression AS name)
//func Alias(name string, c Expression) Column {
// ac := &aliasColumn{}
// ac.name = name
// ac.expression = c
// return ac
//}
// This is a strict subset of the actual allowed identifiers
var validIdentifierRegexp = regexp.MustCompile("^[a-zA-Z_]\\w*$")
// Returns true if the given string is suitable as an identifier.
func validIdentifierName(name string) bool {
return validIdentifierRegexp.MatchString(name)
}
//
//// Pseudo Column type returned by tableName.C(name)
@ -289,12 +124,12 @@ func validIdentifierName(name string) bool {
//func (c *deferredLookupColumn) SerializeSqlForColumnList(
// out *bytes.Buffer) error {
//
// return c.SerializeSql(out)
// return c.Serialize(out)
//}
//
//func (c *deferredLookupColumn) SerializeSql(out *bytes.Buffer) error {
//func (c *deferredLookupColumn) Serialize(out *bytes.Buffer) error {
// if c.cachedColumn != nil {
// return c.cachedColumn.SerializeSql(out)
// return c.cachedColumn.Serialize(out)
// }
//
// col, err := c.tableName.getColumn(c.colName)
@ -303,7 +138,7 @@ func validIdentifierName(name string) bool {
// }
//
// c.cachedColumn = col
// return col.SerializeSql(out)
// return col.Serialize(out)
//}
//
//func (c *deferredLookupColumn) setTableName(tableName string) error {

View file

@ -8,9 +8,7 @@ type BoolColumn struct {
}
func NewBoolColumn(name string, nullable NullableColumn) *BoolColumn {
if !validIdentifierName(name) {
panic("Invalid column name in bool column")
}
boolColumn := &BoolColumn{}
boolColumn.baseColumn = newBaseColumn(name, nullable, "", boolColumn)
@ -26,9 +24,6 @@ type NumericColumn struct {
}
func NewNumericColumn(name string, nullable NullableColumn) *NumericColumn {
if !validIdentifierName(name) {
panic("Invalid column name")
}
numericColumn := &NumericColumn{}
@ -70,9 +65,6 @@ type StringColumn struct {
// Representation of any integer column
// This function will panic if name is not valid
func NewStringColumn(name string, nullable NullableColumn) *StringColumn {
if !validIdentifierName(name) {
panic("Invalid column name")
}
stringColumn := &StringColumn{}

View file

@ -10,20 +10,20 @@ func TestNewBoolColumn(t *testing.T) {
boolColumn := NewBoolColumn("col", Nullable)
out := bytes.Buffer{}
err := boolColumn.SerializeSql(&out)
err := boolColumn.Serialize(&out)
assert.NilError(t, err)
assert.Equal(t, out.String(), "col")
out.Reset()
err = boolColumn.SerializeSql(&out, FOR_PROJECTION)
err = boolColumn.Serialize(&out, FOR_PROJECTION)
assert.NilError(t, err)
assert.Equal(t, out.String(), "col")
out.Reset()
err = boolColumn.setTableName("table1")
assert.NilError(t, err)
err = boolColumn.SerializeSql(&out, FOR_PROJECTION)
err = boolColumn.Serialize(&out, FOR_PROJECTION)
assert.NilError(t, err)
assert.Equal(t, out.String(), `table1.col AS "table1.col"`)
@ -40,20 +40,20 @@ func TestNewIntColumn(t *testing.T) {
integerColumn := NewIntegerColumn("col", Nullable)
out := bytes.Buffer{}
err := integerColumn.SerializeSql(&out)
err := integerColumn.Serialize(&out)
assert.NilError(t, err)
assert.Equal(t, out.String(), "col")
out.Reset()
err = integerColumn.SerializeSql(&out, FOR_PROJECTION)
err = integerColumn.Serialize(&out, FOR_PROJECTION)
assert.NilError(t, err)
assert.Equal(t, out.String(), "col")
out.Reset()
err = integerColumn.setTableName("table1")
assert.NilError(t, err)
err = integerColumn.SerializeSql(&out, FOR_PROJECTION)
err = integerColumn.Serialize(&out, FOR_PROJECTION)
assert.NilError(t, err)
assert.Equal(t, out.String(), `table1.col AS "table1.col"`)
@ -70,20 +70,20 @@ func TestNewNumericColumnColumn(t *testing.T) {
numericColumn := NewNumericColumn("col", Nullable)
out := bytes.Buffer{}
err := numericColumn.SerializeSql(&out)
err := numericColumn.Serialize(&out)
assert.NilError(t, err)
assert.Equal(t, out.String(), "col")
out.Reset()
err = numericColumn.SerializeSql(&out)
err = numericColumn.Serialize(&out)
assert.NilError(t, err)
assert.Equal(t, out.String(), "col")
out.Reset()
err = numericColumn.setTableName("table1")
assert.NilError(t, err)
err = numericColumn.SerializeSql(&out, FOR_PROJECTION)
err = numericColumn.Serialize(&out, FOR_PROJECTION)
assert.NilError(t, err)
assert.Equal(t, out.String(), `table1.col AS "table1.col"`)

View file

@ -1,7 +1,6 @@
package sqlbuilder
import (
"bytes"
"database/sql"
"github.com/dropbox/godropbox/errors"
"github.com/sub0zero/go-sqlbuilder/types"
@ -38,33 +37,35 @@ func (d *deleteStatementImpl) WHERE(expression BoolExpression) DeleteStatement {
return d
}
func (d *deleteStatementImpl) String() (sql string, err error) {
buf := new(bytes.Buffer)
_, _ = buf.WriteString("DELETE FROM ")
func (d *deleteStatementImpl) Sql() (query string, args []interface{}, err error) {
queryData := &queryData{}
queryData.WriteString("DELETE FROM ")
if d.table == nil {
return "", errors.Newf("nil tableName. Generated sql: %s", buf.String())
return "", nil, errors.New("nil tableName.")
}
if err = d.table.SerializeSql(buf); err != nil {
if err = d.table.SerializeSql(queryData); err != nil {
return
}
if d.where == nil {
return "", errors.Newf("Deleting without a WHERE clause. Generated sql: %s", buf.String())
return "", nil, errors.New("Deleting without a WHERE clause.")
}
_, _ = buf.WriteString(" WHERE ")
if err = d.where.SerializeSql(buf); err != nil {
queryData.WriteString(" WHERE ")
if err = d.where.Serialize(queryData); err != nil {
return
}
if d.order != nil {
_, _ = buf.WriteString(" ORDER BY ")
if err = d.order.SerializeSql(buf); err != nil {
queryData.WriteString(" ORDER BY ")
if err = d.order.Serialize(queryData); err != nil {
return
}
}
return buf.String() + ";", nil
return queryData.queryBuff.String() + ";", queryData.args, nil
}

View file

@ -14,7 +14,7 @@ import (
"time"
)
func Execute(db types.Db, query string, destinationPtr interface{}) error {
func Query(db types.Db, query string, args []interface{}, destinationPtr interface{}) error {
if db == nil {
return errors.New("db is nil")
}
@ -28,7 +28,7 @@ func Execute(db types.Db, query string, destinationPtr interface{}) error {
return errors.New("Destination has to be a pointer to slice or pointer to struct ")
}
rows, err := db.Query(query)
rows, err := db.Query(query, args...)
if err != nil {
return err
@ -72,7 +72,7 @@ func Execute(db types.Db, query string, destinationPtr interface{}) error {
return err
}
fmt.Println(strconv.Itoa(scanContext.rowNum) + " ROWS PROCESSED")
fmt.Println(strconv.Itoa(scanContext.rowNum) + " ROW(S) PROCESSED")
return nil
}

View file

@ -1,8 +1,6 @@
package sqlbuilder
import (
"bytes"
"github.com/dropbox/godropbox/database/sqltypes"
"github.com/dropbox/godropbox/errors"
)
@ -42,8 +40,8 @@ func (e *expressionInterfaceImpl) Desc() OrderByClause {
return &orderByClause{expression: e.parent, ascent: false}
}
func (e *expressionInterfaceImpl) SerializeForProjection(out *bytes.Buffer) error {
return e.parent.SerializeSql(out, FOR_PROJECTION)
func (e *expressionInterfaceImpl) SerializeForProjection(out *queryData) error {
return e.parent.Serialize(out, FOR_PROJECTION)
}
// Representation of binary operations (e.g. comparisons, arithmetic)
@ -62,21 +60,21 @@ func newBinaryExpression(lhs, rhs Expression, operator []byte, parent ...Express
return binaryExpression
}
func (c *binaryExpression) SerializeSql(out *bytes.Buffer, options ...serializeOption) (err error) {
func (c *binaryExpression) Serialize(out *queryData, options ...serializeOption) error {
if c.lhs == nil {
return errors.Newf("nil lhs. Generated sql: %s", out.String())
return errors.Newf("nil lhs.")
}
if err = c.lhs.SerializeSql(out); err != nil {
return
if err := c.lhs.Serialize(out); err != nil {
return err
}
_, _ = out.Write(c.operator)
out.Write(c.operator)
if c.rhs == nil {
return errors.Newf("nil rhs. Generated sql: %s", out.String())
return errors.Newf("nil rhs.")
}
if err = c.rhs.SerializeSql(out); err != nil {
return
if err := c.rhs.Serialize(out); err != nil {
return err
}
return nil
@ -97,80 +95,61 @@ func newPrefixExpression(expression Expression, operator []byte) prefixExpressio
return prefixExpression
}
func (p *prefixExpression) SerializeSql(out *bytes.Buffer, options ...serializeOption) (err error) {
_, _ = out.Write(p.operator)
func (p *prefixExpression) Serialize(out *queryData, options ...serializeOption) error {
out.Write(p.operator)
if p.expression == nil {
return errors.Newf("nil prefix expression. Generated sql: %s", out.String())
return errors.Newf("nil prefix expression.")
}
if err = p.expression.SerializeSql(out); err != nil {
return
if err := p.expression.Serialize(out); err != nil {
return err
}
return nil
}
// Representation of n-ary conjunctions (AND/OR)
type conjunctExpression struct {
expressions []BoolExpression
conjunction []byte
}
func (conj *conjunctExpression) SerializeSql(out *bytes.Buffer, options ...serializeOption) (err error) {
if len(conj.expressions) == 0 {
return errors.Newf(
"Empty conjunction. Generated sql: %s",
out.String())
}
clauses := make([]Clause, len(conj.expressions), len(conj.expressions))
for i, expr := range conj.expressions {
clauses[i] = expr
}
useParentheses := len(clauses) > 1
if useParentheses {
_ = out.WriteByte('(')
}
if err = serializeClauses(clauses, conj.conjunction, out); err != nil {
return
}
if useParentheses {
_ = out.WriteByte(')')
}
return nil
}
//
//// Representation of n-ary conjunctions (AND/OR)
//type conjunctExpression struct {
// expressions []Expression
// conjunction []byte
//}
//
//func (conj *conjunctExpression) Serialize(out *queryData, options ...serializeOption) error {
// if len(conj.expressions) == 0 {
// return errors.New("Empty conjunction.")
// }
//
// //clauses := make([]Clause, len(conj.expressions), len(conj.expressions))
// //for i, expr := range conj.expressions {
// // clauses[i] = expr
// //}
//
// useParentheses := len(conj.expressions) > 1
// if useParentheses {
// out.WriteByte('(')
// }
//
// if err := serializeExpressionList(conj.expressions, string(conj.conjunction), out); err != nil {
// return err
// }
//
// if useParentheses {
// out.WriteByte(')')
// }
//
// return nil
//}
//--------------------------------------------------------------
// Representation of an escaped literal
type literalExpression struct {
expressionInterfaceImpl
value sqltypes.Value
}
func NewLiteralExpression(value sqltypes.Value) *literalExpression {
exp := literalExpression{value: value}
exp.expressionInterfaceImpl.parent = &exp
return &exp
}
func (c literalExpression) SerializeSql(out *bytes.Buffer, options ...serializeOption) error {
sqltypes.Value(c.value).EncodeSql(out)
return nil
}
//------------------------------------------------------//
//// Dummy type for select *
//type ColumnList []Column
//
//func (cl ColumnList) SerializeSql(out *bytes.Buffer, options ...serializeOption) error {
//func (cl ColumnList) Serialize(out *bytes.Buffer, options ...serializeOption) error {
// for i, column := range cl {
// err := column.SerializeSql(out)
// err := column.Serialize(out)
//
// if err != nil {
// return err

View file

@ -2,124 +2,88 @@
package sqlbuilder
import (
"bytes"
"strconv"
"strings"
"time"
"github.com/dropbox/godropbox/database/sqltypes"
"github.com/dropbox/godropbox/errors"
)
type orderByClause struct {
isOrderByClause
expression Expression
ascent bool
}
func (o *orderByClause) SerializeSql(out *bytes.Buffer, options ...serializeOption) error {
if o.expression == nil {
return errors.Newf(
"nil order by clause. Generated sql: %s",
out.String())
}
if err := o.expression.SerializeSql(out); err != nil {
return err
}
if o.ascent {
_, _ = out.WriteString(" ASC")
} else {
_, _ = out.WriteString(" DESC")
}
return nil
}
func Asc(expression Expression) OrderByClause {
return &orderByClause{expression: expression, ascent: true}
}
func Desc(expression Expression) OrderByClause {
return &orderByClause{expression: expression, ascent: false}
}
func serializeClauses(
clauses []Clause,
separator []byte,
out *bytes.Buffer) (err error) {
if clauses == nil || len(clauses) == 0 {
return errors.Newf("Empty clauses. Generated sql: %s", out.String())
}
if clauses[0] == nil {
return errors.Newf("nil clause. Generated sql: %s", out.String())
}
if err = clauses[0].SerializeSql(out); err != nil {
return
}
for _, c := range clauses[1:] {
_, _ = out.Write(separator)
if c == nil {
return errors.Newf("nil clause. Generated sql: %s", out.String())
}
if err = c.SerializeSql(out); err != nil {
return
}
}
return nil
}
// Representation of n-ary arithmetic (+ - * /)
type arithmeticExpression struct {
expressionInterfaceImpl
expressions []Expression
operator []byte
}
func (arith *arithmeticExpression) SerializeSql(out *bytes.Buffer, options ...serializeOption) (err error) {
if len(arith.expressions) == 0 {
return errors.Newf(
"Empty arithmetic expression. Generated sql: %s",
out.String())
}
clauses := make([]Clause, len(arith.expressions), len(arith.expressions))
for i, expr := range arith.expressions {
clauses[i] = expr
}
useParentheses := len(clauses) > 1
if useParentheses {
_ = out.WriteByte('(')
}
if err = serializeClauses(clauses, arith.operator, out); err != nil {
return
}
if useParentheses {
_ = out.WriteByte(')')
}
return nil
}
//func serializeClauses(
// clauses []Clause,
// separator []byte,
// out *bytes.Buffer) (err error) {
//
// if clauses == nil || len(clauses) == 0 {
// return errors.Newf("Empty clauses.")
// }
//
// if clauses[0] == nil {
// return errors.Newf("nil clause.")
// }
// if err = clauses[0].Serialize(out); err != nil {
// return
// }
//
// for _, c := range clauses[1:] {
// _, _ = out.Write(separator)
//
// if c == nil {
// return errors.Newf("nil clause.")
// }
// if err = c.Serialize(out); err != nil {
// return
// }
// }
//
// return nil
//}
//
//// Representation of n-ary arithmetic (+ - * /)
//type arithmeticExpression struct {
// expressionInterfaceImpl
// expressions []Expression
// operator []byte
//}
//
//func (arith *arithmeticExpression) Serialize(out *queryData, options ...serializeOption) error {
// if len(arith.expressions) == 0 {
// return errors.Newf(
// "Empty arithmetic expression.")
// }
//
// clauses := make([]Clause, len(arith.expressions), len(arith.expressions))
// for i, expr := range arith.expressions {
// clauses[i] = expr
// }
//
// useParentheses := len(clauses) > 1
// if useParentheses {
// _ = out.WriteByte('(')
// }
//
// if err = serializeClauses(clauses, arith.operator, out); err != nil {
// return
// }
//
// if useParentheses {
// _ = out.WriteByte(')')
// }
//
// return nil
//}
//
type tupleExpression struct {
expressionInterfaceImpl
elements listClause
}
func (tuple *tupleExpression) SerializeSql(out *bytes.Buffer, options ...serializeOption) error {
if len(tuple.elements.clauses) < 1 {
func (tuple *tupleExpression) Serialize(out *queryData, options ...serializeOption) error {
if len(tuple.elements.clauses) == 0 {
return errors.Newf("Tuples must include at least one element")
}
return tuple.elements.SerializeSql(out)
return tuple.elements.Serialize(out)
}
func Tuple(exprs ...Expression) Expression {
@ -141,61 +105,62 @@ type listClause struct {
includeParentheses bool
}
func (list *listClause) SerializeSql(out *bytes.Buffer, options ...serializeOption) error {
func (list *listClause) Serialize(out *queryData, options ...serializeOption) error {
if list.includeParentheses {
_ = out.WriteByte('(')
out.WriteByte('(')
}
if err := serializeClauses(list.clauses, []byte(","), out); err != nil {
if err := serializeClauseList(list.clauses, out); err != nil {
return err
}
if list.includeParentheses {
_ = out.WriteByte(')')
out.WriteByte(')')
}
return nil
}
type funcExpression struct {
expressionInterfaceImpl
funcName string
args *listClause
}
func (c *funcExpression) SerializeSql(out *bytes.Buffer, options ...serializeOption) (err error) {
if !validIdentifierName(c.funcName) {
return errors.Newf(
"Invalid function name: %s. Generated sql: %s",
c.funcName,
out.String())
}
_, _ = out.WriteString(c.funcName)
if c.args == nil {
_, _ = out.WriteString("()")
} else {
return c.args.SerializeSql(out)
}
return nil
}
// Returns a representation of sql function call "func_call(c[0], ..., c[n-1])
func SqlFunc(funcName string, expressions ...Expression) Expression {
f := &funcExpression{
funcName: funcName,
}
if len(expressions) > 0 {
args := make([]Clause, len(expressions), len(expressions))
for i, expr := range expressions {
args[i] = expr
}
f.args = &listClause{
clauses: args,
includeParentheses: true,
}
}
return f
}
//
//type funcExpression struct {
// expressionInterfaceImpl
// funcName string
// args *listClause
//}
//
//func (c *funcExpression) Serialize(out *queryData, options ...serializeOption) error {
// if !validIdentifierName(c.funcName) {
// return errors.Newf(
// "Invalid function name: %s.",
// c.funcName,
// out.String())
// }
// _, _ = out.WriteString(c.funcName)
// if c.args == nil {
// _, _ = out.WriteString("()")
// } else {
// return c.args.Serialize(out)
// }
// return nil
//}
//
//// Returns a representation of sql function call "func_call(c[0], ..., c[n-1])
//func SqlFunc(funcName string, expressions ...Expression) Expression {
// f := &funcExpression{
// funcName: funcName,
// }
// if len(expressions) > 0 {
// args := make([]Clause, len(expressions), len(expressions))
// for i, expr := range expressions {
// args[i] = expr
// }
//
// f.args = &listClause{
// clauses: args,
// includeParentheses: true,
// }
// }
// return f
//}
type intervalExpression struct {
expressionInterfaceImpl
@ -205,23 +170,24 @@ type intervalExpression struct {
var intervalSep = ":"
func (c *intervalExpression) SerializeSql(out *bytes.Buffer, options ...serializeOption) (err error) {
func (c *intervalExpression) Serialize(out *queryData, options ...serializeOption) error {
hours := c.duration / time.Hour
minutes := (c.duration % time.Hour) / time.Minute
sec := (c.duration % time.Minute) / time.Second
msec := (c.duration % time.Second) / time.Microsecond
_, _ = out.WriteString("INTERVAL '")
out.WriteString("INTERVAL '")
if c.negative {
_, _ = out.WriteString("-")
out.WriteString("-")
}
_, _ = out.WriteString(strconv.FormatInt(int64(hours), 10))
_, _ = out.WriteString(intervalSep)
_, _ = out.WriteString(strconv.FormatInt(int64(minutes), 10))
_, _ = out.WriteString(intervalSep)
_, _ = out.WriteString(strconv.FormatInt(int64(sec), 10))
_, _ = out.WriteString(intervalSep)
_, _ = out.WriteString(strconv.FormatInt(int64(msec), 10))
_, _ = out.WriteString("' HOUR_MICROSECOND")
out.WriteString(strconv.FormatInt(int64(hours), 10))
out.WriteString(intervalSep)
out.WriteString(strconv.FormatInt(int64(minutes), 10))
out.WriteString(intervalSep)
out.WriteString(strconv.FormatInt(int64(sec), 10))
out.WriteString(intervalSep)
out.WriteString(strconv.FormatInt(int64(msec), 10))
out.WriteString("' HOUR_MICROSECOND")
return nil
}
@ -246,45 +212,45 @@ func EscapeForLike(s string) string {
}
// Returns an escaped literal string
func Literal(v interface{}) Expression {
value, err := sqltypes.BuildValue(v)
if err != nil {
panic(errors.Wrap(err, "Invalid literal value"))
}
return NewLiteralExpression(value)
}
// Returns a representation of "c[0] + ... + c[n-1]" for c in clauses
func Add(expressions ...Expression) Expression {
return &arithmeticExpression{
expressions: expressions,
operator: []byte(" + "),
}
}
// Returns a representation of "c[0] - ... - c[n-1]" for c in clauses
func Sub(expressions ...Expression) Expression {
return &arithmeticExpression{
expressions: expressions,
operator: []byte(" - "),
}
}
// Returns a representation of "c[0] * ... * c[n-1]" for c in clauses
func Mul(expressions ...Expression) Expression {
return &arithmeticExpression{
expressions: expressions,
operator: []byte(" * "),
}
}
// Returns a representation of "c[0] / ... / c[n-1]" for c in clauses
func Div(expressions ...Expression) Expression {
return &arithmeticExpression{
expressions: expressions,
operator: []byte(" / "),
}
}
//func Literal(v interface{}) Expression {
// value, err := sqltypes.BuildValue(v)
// if err != nil {
// panic(errors.Wrap(err, "Invalid literal value"))
// }
// return NewLiteralExpression(value)
//}
//
//// Returns a representation of "c[0] + ... + c[n-1]" for c in clauses
//func Add(expressions ...Expression) Expression {
// return &arithmeticExpression{
// expressions: expressions,
// operator: []byte(" + "),
// }
//}
//
//// Returns a representation of "c[0] - ... - c[n-1]" for c in clauses
//func Sub(expressions ...Expression) Expression {
// return &arithmeticExpression{
// expressions: expressions,
// operator: []byte(" - "),
// }
//}
//
//// Returns a representation of "c[0] * ... * c[n-1]" for c in clauses
//func Mul(expressions ...Expression) Expression {
// return &arithmeticExpression{
// expressions: expressions,
// operator: []byte(" * "),
// }
//}
//
//// Returns a representation of "c[0] / ... / c[n-1]" for c in clauses
//func Div(expressions ...Expression) Expression {
// return &arithmeticExpression{
// expressions: expressions,
// operator: []byte(" / "),
// }
//}
//TODO: Uncomment
//
@ -336,14 +302,15 @@ type ifExpression struct {
falseExpression Expression
}
func (exp *ifExpression) SerializeSql(out *bytes.Buffer, options ...serializeOption) error {
_, _ = out.WriteString("IF(")
_ = exp.conditional.SerializeSql(out)
_, _ = out.WriteString(",")
_ = exp.trueExpression.SerializeSql(out)
_, _ = out.WriteString(",")
_ = exp.falseExpression.SerializeSql(out)
_, _ = out.WriteString(")")
func (exp *ifExpression) Serialize(out *queryData, options ...serializeOption) error {
out.WriteString("IF(")
_ = exp.conditional.Serialize(out)
out.WriteString(",")
_ = exp.trueExpression.Serialize(out)
out.WriteString(",")
_ = exp.falseExpression.Serialize(out)
out.WriteString(")")
return nil
}
@ -371,7 +338,7 @@ func If(conditional BoolExpression,
// }
//}
//
//func (cv *columnValueExpression) SerializeSql(out *bytes.Buffer) error {
//func (cv *columnValueExpression) Serialize(out *bytes.Buffer) error {
// _, _ = out.WriteString("VALUES(")
// _ = cv.column.SerializeSqlForColumnList(out)
// _ = out.WriteByte(')')

View file

@ -19,7 +19,7 @@ func (s *ExprSuite) TestConjunctExprEmptyList(c *gc.C) {
buf := &bytes.Buffer{}
err := expr.SerializeSql(buf)
err := expr.Serialize(buf)
c.Assert(err, gc.NotNil)
}
@ -28,7 +28,7 @@ func (s *ExprSuite) TestConjunctExprNilInList(c *gc.C) {
buf := &bytes.Buffer{}
err := expr.SerializeSql(buf)
err := expr.Serialize(buf)
c.Assert(err, gc.NotNil)
}
@ -37,7 +37,7 @@ func (s *ExprSuite) TestConjunctExprSingleElement(c *gc.C) {
buf := &bytes.Buffer{}
err := expr.SerializeSql(buf)
err := expr.Serialize(buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
@ -48,11 +48,11 @@ func (s *ExprSuite) TestTupleExpr(c *gc.C) {
expr := Tuple()
buf := &bytes.Buffer{}
err := expr.SerializeSql(buf)
err := expr.Serialize(buf)
c.Assert(err, gc.NotNil)
expr = Tuple(table1Col1, Literal(1), Literal("five"))
err = expr.SerializeSql(buf)
err = expr.Serialize(buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
@ -68,7 +68,7 @@ func (s *ExprSuite) TestLikeExpr(c *gc.C) {
buf := &bytes.Buffer{}
err := expr.SerializeSql(buf)
err := expr.Serialize(buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
@ -84,7 +84,7 @@ func (s *ExprSuite) TestRegexExpr(c *gc.C) {
buf := &bytes.Buffer{}
err := expr.SerializeSql(buf)
err := expr.Serialize(buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
@ -100,7 +100,7 @@ func (s *ExprSuite) TestAndExpr(c *gc.C) {
buf := &bytes.Buffer{}
err := expr.SerializeSql(buf)
err := expr.Serialize(buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
@ -115,7 +115,7 @@ func (s *ExprSuite) TestOrExpr(c *gc.C) {
buf := &bytes.Buffer{}
err := expr.SerializeSql(buf)
err := expr.Serialize(buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
@ -130,7 +130,7 @@ func (s *ExprSuite) TestAddExpr(c *gc.C) {
buf := &bytes.Buffer{}
err := expr.SerializeSql(buf)
err := expr.Serialize(buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
@ -142,7 +142,7 @@ func (s *ExprSuite) TestSubExpr(c *gc.C) {
buf := &bytes.Buffer{}
err := expr.SerializeSql(buf)
err := expr.Serialize(buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
@ -154,7 +154,7 @@ func (s *ExprSuite) TestMulExpr(c *gc.C) {
buf := &bytes.Buffer{}
err := expr.SerializeSql(buf)
err := expr.Serialize(buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
@ -166,7 +166,7 @@ func (s *ExprSuite) TestDivExpr(c *gc.C) {
buf := &bytes.Buffer{}
err := expr.SerializeSql(buf)
err := expr.Serialize(buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
@ -178,7 +178,7 @@ func (s *ExprSuite) TestBinaryExprNilLHS(c *gc.C) {
buf := &bytes.Buffer{}
err := expr.SerializeSql(buf)
err := expr.Serialize(buf)
c.Assert(err, gc.NotNil)
}
@ -187,7 +187,7 @@ func (s *ExprSuite) TestNegateExpr(c *gc.C) {
buf := &bytes.Buffer{}
err := expr.SerializeSql(buf)
err := expr.Serialize(buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
@ -199,7 +199,7 @@ func (s *ExprSuite) TestBinaryExprNilRHS(c *gc.C) {
buf := &bytes.Buffer{}
err := expr.SerializeSql(buf)
err := expr.Serialize(buf)
c.Assert(err, gc.NotNil)
}
@ -208,7 +208,7 @@ func (s *ExprSuite) TestEqExpr(c *gc.C) {
buf := &bytes.Buffer{}
err := expr.SerializeSql(buf)
err := expr.Serialize(buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
@ -220,7 +220,7 @@ func (s *ExprSuite) TestEqExprNilLHS(c *gc.C) {
buf := &bytes.Buffer{}
err := expr.SerializeSql(buf)
err := expr.Serialize(buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
@ -232,7 +232,7 @@ func (s *ExprSuite) TestNeqExpr(c *gc.C) {
buf := &bytes.Buffer{}
err := expr.SerializeSql(buf)
err := expr.Serialize(buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
@ -244,7 +244,7 @@ func (s *ExprSuite) TestNeqExprNilLHS(c *gc.C) {
buf := &bytes.Buffer{}
err := expr.SerializeSql(buf)
err := expr.Serialize(buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
@ -256,7 +256,7 @@ func (s *ExprSuite) TestLtExpr(c *gc.C) {
buf := &bytes.Buffer{}
err := expr.SerializeSql(buf)
err := expr.Serialize(buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
@ -268,7 +268,7 @@ func (s *ExprSuite) TestLteExpr(c *gc.C) {
buf := &bytes.Buffer{}
err := expr.SerializeSql(buf)
err := expr.Serialize(buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
@ -283,7 +283,7 @@ func (s *ExprSuite) TestGtExpr(c *gc.C) {
buf := &bytes.Buffer{}
err := expr.SerializeSql(buf)
err := expr.Serialize(buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
@ -295,7 +295,7 @@ func (s *ExprSuite) TestGteExpr(c *gc.C) {
buf := &bytes.Buffer{}
err := expr.SerializeSql(buf)
err := expr.Serialize(buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
@ -308,7 +308,7 @@ func (s *ExprSuite) TestInExpr(c *gc.C) {
buf := &bytes.Buffer{}
err := expr.SerializeSql(buf)
err := expr.Serialize(buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
@ -321,7 +321,7 @@ func (s *ExprSuite) TestInExprEmptyList(c *gc.C) {
buf := &bytes.Buffer{}
err := expr.SerializeSql(buf)
err := expr.Serialize(buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
@ -333,7 +333,7 @@ func (s *ExprSuite) TestSqlFuncExprNilInArgList(c *gc.C) {
buf := &bytes.Buffer{}
err := expr.SerializeSql(buf)
err := expr.Serialize(buf)
c.Assert(err, gc.NotNil)
}
@ -342,7 +342,7 @@ func (s *ExprSuite) TestSqlFuncExprEmptyArgList(c *gc.C) {
buf := &bytes.Buffer{}
err := expr.SerializeSql(buf)
err := expr.Serialize(buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
@ -354,7 +354,7 @@ func (s *ExprSuite) TestSqlFuncExprNonEmptyArgList(c *gc.C) {
buf := &bytes.Buffer{}
err := expr.SerializeSql(buf)
err := expr.Serialize(buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
@ -366,7 +366,7 @@ func (s *ExprSuite) TestOrderByClauseNilExpr(c *gc.C) {
buf := &bytes.Buffer{}
err := clause.SerializeSql(buf)
err := clause.Serialize(buf)
c.Assert(err, gc.NotNil)
}
@ -375,7 +375,7 @@ func (s *ExprSuite) TestAsc(c *gc.C) {
buf := &bytes.Buffer{}
err := clause.SerializeSql(buf)
err := clause.Serialize(buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
@ -387,7 +387,7 @@ func (s *ExprSuite) TestDesc(c *gc.C) {
buf := &bytes.Buffer{}
err := clause.SerializeSql(buf)
err := clause.Serialize(buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
@ -400,7 +400,7 @@ func (s *ExprSuite) TestIf(c *gc.C) {
buf := &bytes.Buffer{}
err := clause.SerializeSql(buf)
err := clause.Serialize(buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
@ -538,7 +538,7 @@ func (s *ExprSuite) TestInterval(c *gc.C) {
for i, tt := range testTable {
buf.Reset()
err := Interval(tt.interval).SerializeSql(buf)
err := Interval(tt.interval).Serialize(buf)
c.Assert(err, gc.Equals, tt.expectedErr,
gc.Commentf("experiment #%d", i))
if err == nil {

View file

@ -1,7 +1,5 @@
package sqlbuilder
import "bytes"
type FuncExpression interface {
Expression
}
@ -26,10 +24,10 @@ func NewNumericFunc(name string, expression Expression) NumericExpression {
return numericFunc
}
func (f *numericFunc) SerializeSql(out *bytes.Buffer, options ...serializeOption) error {
func (f *numericFunc) Serialize(out *queryData, options ...serializeOption) error {
out.WriteString(f.name)
out.WriteString("(")
err := f.expression.SerializeSql(out)
err := f.expression.Serialize(out)
if err != nil {
return err
}
@ -39,7 +37,7 @@ func (f *numericFunc) SerializeSql(out *bytes.Buffer, options ...serializeOption
}
//func (f *FuncExpression) SerializeSqlForColumnList(out *bytes.Buffer) error {
// return f.SerializeSql(out)
// return f.Serialize(out)
//}
func MAX(expression NumericExpression) NumericExpression {

View file

@ -1,7 +1,6 @@
package sqlbuilder
import (
"bytes"
"database/sql"
"github.com/dropbox/godropbox/errors"
"github.com/serenize/snaker"
@ -53,6 +52,7 @@ func (u *insertStatementImpl) Execute(db types.Db) (res sql.Result, err error) {
return Execute(u, db)
}
// expression or default keyword
func (s *insertStatementImpl) VALUES(values ...interface{}) InsertStatement {
literalRow := []Clause{}
@ -122,84 +122,92 @@ func (i *insertStatementImpl) addError(err string) {
i.errors = append(i.errors, err)
}
func (s *insertStatementImpl) String() (sql string, err error) {
buf := new(bytes.Buffer)
_, _ = buf.WriteString("INSERT ")
_, _ = buf.WriteString("INTO ")
func (s *insertStatementImpl) Sql() (sql string, args []interface{}, err error) {
if len(s.errors) > 0 {
return "", errors.New("sql builder errors: " + strings.Join(s.errors, ", "))
return "", nil, errors.New("sql builder errors: " + strings.Join(s.errors, ", "))
}
queryData := &queryData{}
queryData.WriteString("INSERT INTO ")
if s.table == nil {
return "", errors.Newf("nil tableName. Generated sql: %s", buf.String())
return "", nil, errors.Newf("nil tableName.")
}
buf.WriteString(s.table.SchemaName() + "." + s.table.TableName())
err = s.table.SerializeSql(queryData)
if err != nil {
return "", nil, err
}
if len(s.columns) > 0 {
_, _ = buf.WriteString(" (")
for i, col := range s.columns {
if i > 0 {
_ = buf.WriteByte(',')
queryData.WriteString(" (")
//for i, col := range s.columns {
// if i > 0 {
// queryData.WriteByte(',')
// }
//
// if col == nil {
// return "", nil, errors.New("nil column in columns list.")
// }
//
// queryData.WriteString(col.Name())
//}
err = serializeColumnList(s.columns, queryData)
if err != nil {
return "", nil, err
}
if col == nil {
return "", errors.Newf(
"nil column in columns list. Generated sql: %s",
buf.String())
}
buf.WriteString(col.Name())
}
buf.WriteString(") ")
queryData.WriteString(") ")
}
if len(s.rows) == 0 && s.query == nil {
return "", errors.Newf("No row or query specified. Generated sql: %s", buf.String())
return "", nil, errors.New("No row or query specified.")
}
if len(s.rows) > 0 && s.query != nil {
return "", errors.Newf("Only new rows or query has to be specified. Generated sql: %s", buf.String())
return "", nil, errors.New("Only new rows or query has to be specified.")
}
if len(s.rows) > 0 {
_, _ = buf.WriteString("VALUES (")
queryData.WriteString("VALUES (")
for row_i, row := range s.rows {
if row_i > 0 {
_, _ = buf.WriteString(", (")
queryData.WriteString(", (")
}
if len(row) != len(s.columns) {
return "", errors.Newf(
"# of values does not match # of columns. Generated sql: %s",
buf.String())
return "", nil, errors.New("# of values does not match # of columns.")
}
for col_i, value := range row {
if col_i > 0 {
_ = buf.WriteByte(',')
err = serializeClauseList(row, queryData)
if err != nil {
return "", nil, err
}
if value == nil {
return "", errors.Newf(
"nil value in row %d col %d. Generated sql: %s",
row_i,
col_i,
buf.String())
}
if err = value.SerializeSql(buf); err != nil {
return
}
}
_ = buf.WriteByte(')')
//for col_i, value := range row {
// if col_i > 0 {
// queryData.WriteByte(',')
// }
//
// if value == nil {
// return "", nil, errors.Newf("nil value in row %d col %d.", row_i, col_i)
// }
//
// if err = value.Serialize(queryData); err != nil {
// return
// }
//}
queryData.WriteByte(')')
}
}
if s.query != nil {
err = s.query.SerializeSql(buf)
err = s.query.Serialize(queryData)
if err != nil {
return
@ -207,16 +215,16 @@ func (s *insertStatementImpl) String() (sql string, err error) {
}
if len(s.returning) > 0 {
buf.WriteString(" RETURNING ")
queryData.WriteString(" RETURNING ")
err = serializeProjectionList(s.returning, buf)
err = serializeProjectionList(s.returning, queryData)
if err != nil {
return
}
}
buf.WriteByte(';')
queryData.WriteByte(';')
return buf.String(), nil
return queryData.queryBuff.String(), queryData.args, nil
}

View file

@ -1,14 +1,12 @@
package sqlbuilder
import "bytes"
const (
DEFAULT keywordClause = "DEFAULT"
)
type keywordClause string
func (k keywordClause) SerializeSql(out *bytes.Buffer, options ...serializeOption) error {
func (k keywordClause) Serialize(out *queryData, options ...serializeOption) error {
out.WriteString(string(k))
return nil

View file

@ -0,0 +1,22 @@
package sqlbuilder
// Representation of an escaped literal
type literalExpression struct {
expressionInterfaceImpl
value interface{}
}
func Literal(value interface{}) *literalExpression {
exp := literalExpression{value: value}
exp.expressionInterfaceImpl.parent = &exp
return &exp
}
func (l literalExpression) Serialize(out *queryData, options ...serializeOption) error {
//sqltypes.Value(c.value).EncodeSql(out)
out.InsertArgument(l.value)
return nil
}

View file

@ -1,11 +1,5 @@
package sqlbuilder
import (
"bytes"
"github.com/dropbox/godropbox/database/sqltypes"
"github.com/pkg/errors"
)
type NumericExpression interface {
Expression
@ -13,8 +7,11 @@ type NumericExpression interface {
EqL(literal interface{}) BoolExpression
NotEq(expression NumericExpression) BoolExpression
NotEqL(literal interface{}) BoolExpression
Gt(rhs NumericExpression) BoolExpression
GtEq(rhs NumericExpression) BoolExpression
GtEqL(literal interface{}) BoolExpression
LtEq(rhs NumericExpression) BoolExpression
LtEqL(literal interface{}) BoolExpression
@ -44,6 +41,10 @@ func (n *numericInterfaceImpl) NotEqL(literal interface{}) BoolExpression {
return NotEq(n.parent, Literal(literal))
}
func (n *numericInterfaceImpl) Gt(expression NumericExpression) BoolExpression {
return Gt(n.parent, expression)
}
func (n *numericInterfaceImpl) GtEq(expression NumericExpression) BoolExpression {
return GtEq(n.parent, expression)
}
@ -84,12 +85,8 @@ type numericLiteral struct {
func NewNumericLiteral(value interface{}) NumericExpression {
numericLiteral := numericLiteral{}
numericLiteral.literalExpression = *Literal(value)
sqlValue, err := sqltypes.BuildValue(value)
if err != nil {
panic(errors.Wrap(err, "Invalid literal value"))
}
numericLiteral.literalExpression = *NewLiteralExpression(sqlValue)
numericLiteral.numericInterfaceImpl.parent = &numericLiteral
return &numericLiteral
@ -133,10 +130,10 @@ func newNumericExpressionWrap(expression Expression) NumericExpression {
return &numericExpressionWrap
}
func (c *numericExpressionWrapper) SerializeSql(out *bytes.Buffer, options ...serializeOption) (err error) {
func (c *numericExpressionWrapper) Serialize(out *queryData, options ...serializeOption) error {
out.WriteString("(")
err = c.expression.SerializeSql(out, options...)
err := c.expression.Serialize(out, options...)
out.WriteString(")")
return nil
return err
}

View file

@ -0,0 +1,46 @@
package sqlbuilder
import "github.com/dropbox/godropbox/errors"
type OrderByClause interface {
Clause
isOrderByClauseType()
}
type isOrderByClause struct {
}
func (o *isOrderByClause) isOrderByClauseType() {
}
type orderByClause struct {
isOrderByClause
expression Expression
ascent bool
}
func (o *orderByClause) Serialize(out *queryData, options ...serializeOption) error {
if o.expression == nil {
return errors.Newf("nil orderBy by clause.")
}
if err := o.expression.Serialize(out); err != nil {
return err
}
if o.ascent {
out.WriteString(" ASC")
} else {
out.WriteString(" DESC")
}
return nil
}
func Asc(expression Expression) OrderByClause {
return &orderByClause{expression: expression, ascent: true}
}
func Desc(expression Expression) OrderByClause {
return &orderByClause{expression: expression, ascent: false}
}

View file

@ -1,18 +1,16 @@
package sqlbuilder
import "bytes"
type Projection interface {
SerializeForProjection(out *bytes.Buffer) error
SerializeForProjection(out *queryData) error
}
//------------------------------------------------------//
// Dummy type for select * AllColumns
type ColumnList []Column
func (cl ColumnList) SerializeForProjection(out *bytes.Buffer) error {
func (cl ColumnList) SerializeForProjection(out *queryData) error {
for i, column := range cl {
err := column.SerializeSql(out, FOR_PROJECTION)
err := column.Serialize(out, FOR_PROJECTION)
if err != nil {
return err

View file

@ -1,9 +1,7 @@
package sqlbuilder
import (
"bytes"
"database/sql"
"fmt"
"github.com/dropbox/godropbox/errors"
"github.com/sub0zero/go-sqlbuilder/types"
)
@ -12,17 +10,17 @@ type SelectStatement interface {
Statement
Expression
Where(expression BoolExpression) SelectStatement
GroupBy(expressions ...Expression) SelectStatement
HAVING(expressions BoolExpression) SelectStatement
DISTINCT() SelectStatement
WHERE(expression BoolExpression) SelectStatement
GROUP_BY(expressions ...Clause) SelectStatement
HAVING(boolExpression BoolExpression) SelectStatement
ORDER_BY(clauses ...OrderByClause) SelectStatement
LIMIT(limit int64) SelectStatement
OFFSET(offset int64) SelectStatement
FOR_UPDATE() SelectStatement
OrderBy(clauses ...OrderByClause) SelectStatement
Limit(limit int64) SelectStatement
Offset(offset int64) SelectStatement
Distinct() SelectStatement
WithSharedLock() SelectStatement
ForUpdate() SelectStatement
Comment(comment string) SelectStatement
Copy() SelectStatement
AsTable(alias string) *SelectStatementTable
@ -34,16 +32,16 @@ type selectStatementImpl struct {
expressionInterfaceImpl
table ReadableTable
distinct bool
projections []Projection
where BoolExpression
group *listClause
groupBy []Clause
having BoolExpression
order *listClause
comment string
orderBy []OrderByClause
limit, offset int64
withSharedLock bool
forUpdate bool
distinct bool
}
func newSelectStatement(
@ -55,26 +53,115 @@ func newSelectStatement(
projections: projections,
limit: -1,
offset: -1,
withSharedLock: false,
forUpdate: false,
distinct: false,
}
}
func (s *selectStatementImpl) SerializeSql(out *bytes.Buffer, options ...serializeOption) error {
str, err := s.String()
func (s *selectStatementImpl) Serialize(out *queryData, options ...serializeOption) error {
out.WriteString("(")
err := s.serializeImpl(out, options...)
if err != nil {
return err
}
out.WriteString("(")
out.WriteString(str)
out.WriteString(")")
return nil
}
func (s *selectStatementImpl) serializeImpl(out *queryData, options ...serializeOption) error {
out.WriteString("SELECT ")
if s.distinct {
out.WriteString("DISTINCT ")
}
if s.projections == nil || len(s.projections) == 0 {
return errors.New("No column selected for projection.")
}
err := serializeProjectionList(s.projections, out)
if err != nil {
return err
}
out.WriteString(" FROM ")
if s.table == nil {
return errors.Newf("nil tableName.")
}
if err := s.table.SerializeSql(out); err != nil {
return err
}
if s.where != nil {
out.WriteString(" WHERE ")
if err := s.where.Serialize(out); err != nil {
return err
}
}
if s.groupBy != nil && len(s.groupBy) > 0 {
out.WriteString(" GROUP BY ")
err := serializeClauseList(s.groupBy, out)
if err != nil {
return err
}
}
if s.having != nil {
out.WriteString(" HAVING ")
if err = s.having.Serialize(out); err != nil {
return err
}
}
if s.orderBy != nil {
out.WriteString(" ORDER BY ")
if err := serializeOrderByClauseList(s.orderBy, out); err != nil {
return err
}
}
if s.limit >= 0 {
out.WriteString(" LIMIT ")
out.InsertArgument(s.limit)
}
if s.offset >= 0 {
out.WriteString(" OFFSET ")
out.InsertArgument(s.offset)
}
if s.forUpdate {
out.WriteString(" FOR UPDATE")
}
return nil
}
// Return the properly escaped SQL statement, against the specified database
func (q *selectStatementImpl) Sql() (query string, args []interface{}, err error) {
queryData := queryData{}
err = q.serializeImpl(&queryData)
if err != nil {
return "", nil, err
}
return queryData.queryBuff.String(), queryData.args, nil
}
func (s *selectStatementImpl) AsTable(alias string) *SelectStatementTable {
return &SelectStatementTable{
statement: s,
@ -95,23 +182,14 @@ func (s *selectStatementImpl) Copy() SelectStatement {
return &ret
}
func (q *selectStatementImpl) Where(expression BoolExpression) SelectStatement {
func (q *selectStatementImpl) WHERE(expression BoolExpression) SelectStatement {
q.where = expression
return q
}
func (q *selectStatementImpl) GroupBy(
expressions ...Expression) SelectStatement {
q.group = &listClause{
clauses: make([]Clause, len(expressions), len(expressions)),
includeParentheses: false,
}
for i, e := range expressions {
q.group.clauses[i] = e
}
return q
func (s *selectStatementImpl) GROUP_BY(cluases ...Clause) SelectStatement {
s.groupBy = cluases
return s
}
func (q *selectStatementImpl) HAVING(expression BoolExpression) SelectStatement {
@ -119,132 +197,31 @@ func (q *selectStatementImpl) HAVING(expression BoolExpression) SelectStatement
return q
}
func (q *selectStatementImpl) OrderBy(
clauses ...OrderByClause) SelectStatement {
func (q *selectStatementImpl) ORDER_BY(clauses ...OrderByClause) SelectStatement {
q.orderBy = clauses
q.order = newOrderByListClause(clauses...)
return q
}
func (q *selectStatementImpl) Limit(limit int64) SelectStatement {
q.limit = limit
return q
}
func (q *selectStatementImpl) Distinct() SelectStatement {
q.distinct = true
return q
}
func (q *selectStatementImpl) WithSharedLock() SelectStatement {
// We don't need to grab a read lock if we're going to grab a write one
if !q.forUpdate {
q.withSharedLock = true
}
return q
}
func (q *selectStatementImpl) ForUpdate() SelectStatement {
// Clear a request for a shared lock if we're asking for a write one
q.withSharedLock = false
q.forUpdate = true
return q
}
func (q *selectStatementImpl) Offset(offset int64) SelectStatement {
func (q *selectStatementImpl) OFFSET(offset int64) SelectStatement {
q.offset = offset
return q
}
func (q *selectStatementImpl) Comment(comment string) SelectStatement {
q.comment = comment
func (q *selectStatementImpl) LIMIT(limit int64) SelectStatement {
q.limit = limit
return q
}
// Return the properly escaped SQL statement, against the specified database
func (q *selectStatementImpl) String() (sql string, err error) {
buf := new(bytes.Buffer)
_, _ = buf.WriteString("SELECT ")
if err = writeComment(q.comment, buf); err != nil {
return
func (q *selectStatementImpl) DISTINCT() SelectStatement {
q.distinct = true
return q
}
if q.distinct {
_, _ = buf.WriteString("DISTINCT ")
}
if q.projections == nil || len(q.projections) == 0 {
return "", errors.Newf(
"No column selected. Generated sql: %s",
buf.String())
}
for i, col := range q.projections {
if i > 0 {
_ = buf.WriteByte(',')
}
if col == nil {
return "", errors.Newf(
"nil column selected. Generated sql: %s",
buf.String())
}
if err = col.SerializeForProjection(buf); err != nil {
return
}
}
_, _ = buf.WriteString(" FROM ")
if q.table == nil {
return "", errors.Newf("nil tableName. Generated sql: %s", buf.String())
}
if err = q.table.SerializeSql(buf); err != nil {
return
}
if q.where != nil {
_, _ = buf.WriteString(" WHERE ")
if err = q.where.SerializeSql(buf); err != nil {
return
}
}
if q.group != nil {
_, _ = buf.WriteString(" GROUP BY ")
if err = q.group.SerializeSql(buf); err != nil {
return
}
}
if q.having != nil {
buf.WriteString(" HAVING ")
if err = q.having.SerializeSql(buf); err != nil {
return
}
}
if q.order != nil {
_, _ = buf.WriteString(" ORDER BY ")
if err = q.order.SerializeSql(buf); err != nil {
return
}
}
if q.limit >= 0 {
if q.offset >= 0 {
_, _ = buf.WriteString(fmt.Sprintf(" LIMIT %d, %d", q.offset, q.limit))
} else {
_, _ = buf.WriteString(fmt.Sprintf(" LIMIT %d", q.limit))
}
}
if q.forUpdate {
_, _ = buf.WriteString(" FOR UPDATE")
} else if q.withSharedLock {
_, _ = buf.WriteString(" LOCK IN SHARE MODE")
}
return buf.String(), nil
func (q *selectStatementImpl) FOR_UPDATE() SelectStatement {
q.forUpdate = true
return q
}
func NumExp(statement SelectStatement) NumericExpression {

View file

@ -1,7 +1,5 @@
package sqlbuilder
import "bytes"
type SelectStatementTable struct {
statement SelectStatement
columns []Column
@ -41,16 +39,14 @@ func (s *SelectStatementTable) RefStringColumn(column Column) *StringColumn {
return strColumn
}
func (s *SelectStatementTable) SerializeSql(out *bytes.Buffer) error {
func (s *SelectStatementTable) SerializeSql(out *queryData) error {
out.WriteString("( ")
statementStr, err := s.statement.String()
err := s.statement.Serialize(out)
if err != nil {
return err
}
out.WriteString(statementStr)
out.WriteString(" ) AS ")
out.WriteString(s.alias)

View file

@ -1,17 +1,13 @@
package sqlbuilder
import (
"bytes"
"database/sql"
"github.com/sub0zero/go-sqlbuilder/types"
"regexp"
"github.com/dropbox/godropbox/errors"
)
type Statement interface {
// String returns generated SQL as string.
String() (sql string, err error)
Sql() (query string, args []interface{}, err error)
Query(db types.Db, destination interface{}) error
Execute(db types.Db) (sql.Result, error)
@ -88,10 +84,10 @@ type Statement interface {
//
// for idx, lock := range s.locks {
// if lock.t == nil {
// return "", errors.Newf("nil tableName. Generated sql: %s", buf.String())
// return "", errors.Newf("nil tableName.", buf.String())
// }
//
// if err = lock.t.SerializeSql(buf); err != nil {
// if err = lock.t.Serialize(buf); err != nil {
// return
// }
//
@ -162,23 +158,23 @@ type Statement interface {
//
// Once again, teisenberger is lazy. Here's a quick filter on comments
var validCommentRegexp *regexp.Regexp = regexp.MustCompile("^[\\w .?]*$")
func isValidComment(comment string) bool {
return validCommentRegexp.MatchString(comment)
}
func writeComment(comment string, buf *bytes.Buffer) error {
if comment != "" {
_, _ = buf.WriteString("/* ")
if !isValidComment(comment) {
return errors.Newf("Invalid comment: %s", comment)
}
_, _ = buf.WriteString(comment)
_, _ = buf.WriteString(" */")
}
return nil
}
//var validCommentRegexp *regexp.Regexp = regexp.MustCompile("^[\\w .?]*$")
//
//func isValidComment(comment string) bool {
// return validCommentRegexp.MatchString(comment)
//}
//
//func writeComment(comment string, buf *bytes.Buffer) error {
// if comment != "" {
// _, _ = buf.WriteString("/* ")
// if !isValidComment(comment) {
// return errors.Newf("Invalid comment: %s", comment)
// }
// _, _ = buf.WriteString(comment)
// _, _ = buf.WriteString(" */")
// }
// return nil
//}
func newOrderByListClause(clauses ...OrderByClause) *listClause {
ret := &listClause{

View file

@ -475,14 +475,14 @@ func (s *StmtSuite) TestUnionSelectWithMismatchedColumns(c *gc.C) {
"All inner selects in Union statement must select the "+
"same number of columns. For sanity, you probably "+
"want to select the same tableName columns in the same "+
"order. If you are selecting on multiple tables, "+
"orderBy. If you are selecting on multiple tables, "+
"use Null to pad to the right number of fields.")
}
func (s *StmtSuite) TestComplicatedUnionSelectWithWhereStatement(c *gc.C) {
// tests on outer statement: Group By, Order By, Limit
// on inner statement: AndWhere, WHERE (with And), Order By, Limit
// tests on outer statement: Group By, Order By, LIMIT
// on inner statement: AndWhere, WHERE (with And), Order By, LIMIT
select_queries := make([]SelectStatement, 0, 3)
// We're not trying to write a SQL parser, so we won't warn if you do something silly like

View file

@ -3,7 +3,6 @@
package sqlbuilder
import (
"bytes"
"fmt"
"github.com/dropbox/godropbox/errors"
)
@ -15,7 +14,7 @@ type TableInterface interface {
Columns() []Column
// Generates the sql string for the current tableName expression. Note: the
// generated string may not be a valid/executable sql statement.
SerializeSql(out *bytes.Buffer) error
SerializeSql(out *queryData) error
}
// The sql tableName read interface. NOTE: NATURAL JOINs, and join "USING" clause
@ -52,9 +51,6 @@ type WritableTable interface {
// Defines a physical tableName in the database that is both readable and writable.
// This function will panic if name is not valid
func NewTable(schemaName, name string, columns ...Column) *Table {
if !validIdentifierName(name) {
panic("Invalid tableName name")
}
t := &Table{
schemaName: schemaName,
@ -154,28 +150,20 @@ func (t *Table) ForceIndex(index string) *Table {
// Generates the sql string for the current tableName expression. Note: the
// generated string may not be a valid/executable sql statement.
func (t *Table) SerializeSql(out *bytes.Buffer) error {
func (t *Table) SerializeSql(out *queryData) error {
if t == nil {
return errors.Newf("nil tableName. Generated sql: %s", out.String())
return errors.Newf("nil tableName.")
}
_, _ = out.WriteString(t.schemaName)
_, _ = out.WriteString(".")
_, _ = out.WriteString(t.TableName())
out.WriteString(t.schemaName)
out.WriteString(".")
out.WriteString(t.TableName())
if len(t.alias) > 0 {
out.WriteString(" AS ")
out.WriteString(t.alias)
}
if t.forcedIndex != "" {
if !validIdentifierName(t.forcedIndex) {
return errors.Newf("'%s' is not a valid identifier for an index", t.forcedIndex)
}
_, _ = out.WriteString(" FORCE INDEX (")
_, _ = out.WriteString(t.forcedIndex)
_, _ = out.WriteString(")")
}
return nil
}
@ -307,7 +295,6 @@ func CrossJoin(
return newJoinTable(lhs, rhs, CROSS_JOIN, nil)
}
// Returns the tableName's name in the database
func (t *joinTable) SchemaName() string {
return ""
}
@ -328,16 +315,16 @@ func (t *joinTable) Column(name string) Column {
panic("Not implemented")
}
func (t *joinTable) SerializeSql(out *bytes.Buffer) (err error) {
func (t *joinTable) SerializeSql(out *queryData) (err error) {
if t.lhs == nil {
return errors.Newf("nil lhs. Generated sql: %s", out.String())
return errors.Newf("nil lhs.")
}
if t.rhs == nil {
return errors.Newf("nil rhs. Generated sql: %s", out.String())
return errors.Newf("nil rhs.")
}
if t.onCondition == nil && t.join_type != CROSS_JOIN {
return errors.Newf("nil onCondition. Generated sql: %s", out.String())
return errors.Newf("nil onCondition.")
}
if err = t.lhs.SerializeSql(out); err != nil {
@ -346,11 +333,11 @@ func (t *joinTable) SerializeSql(out *bytes.Buffer) (err error) {
switch t.join_type {
case INNER_JOIN:
_, _ = out.WriteString(" JOIN ")
out.WriteString(" JOIN ")
case LEFT_JOIN:
_, _ = out.WriteString(" LEFT JOIN ")
out.WriteString(" LEFT JOIN ")
case RIGHT_JOIN:
_, _ = out.WriteString(" RIGHT JOIN ")
out.WriteString(" RIGHT JOIN ")
case FULL_JOIN:
out.WriteString(" FULL JOIN ")
case CROSS_JOIN:
@ -362,8 +349,8 @@ func (t *joinTable) SerializeSql(out *bytes.Buffer) (err error) {
}
if t.onCondition != nil {
_, _ = out.WriteString(" ON ")
if err = t.onCondition.SerializeSql(out); err != nil {
out.WriteString(" ON ")
if err = t.onCondition.Serialize(out); err != nil {
return
}
}

View file

@ -1,10 +1,6 @@
package sqlbuilder
// A clause that can be used in order by
type OrderByClause interface {
Clause
isOrderByClauseInterface
}
// A clause that can be used in orderBy by
// A clause that is selectable.
//type Projection interface {
@ -16,9 +12,9 @@ type OrderByClause interface {
//type ColumnList []Column
//
//func (cl ColumnList) SerializeSql(out *bytes.Buffer, options ...serializeOption) error {
//func (cl ColumnList) Serialize(out *bytes.Buffer, options ...serializeOption) error {
// for i, column := range cl {
// column.SerializeSql(out)
// column.Serialize(out)
//
// if i != len(cl)-1 {
// out.WriteString(", ")
@ -49,16 +45,6 @@ type OrderByClause interface {
// Boiler plates ...
//
type isOrderByClauseInterface interface {
isOrderByClauseType()
}
type isOrderByClause struct {
}
func (o *isOrderByClause) isOrderByClauseType() {
}
//
//type isProjectionInterface interface {
// isProjectionType()

View file

@ -1,17 +1,9 @@
package sqlbuilder
import (
"bytes"
"database/sql"
"fmt"
"github.com/dropbox/godropbox/errors"
"github.com/sub0zero/go-sqlbuilder/types"
)
// By default, rows selected by a UNION statement are out-of-order
// By default, rows selected by a UNION statement are out-of-orderBy
// If you have an ORDER BY on an inner SELECT statement, the only thing
// it affects is the LIMIT clause on that inner statement (the ordering will
// still be out-of-order).
// still be out-of-orderBy).
type UnionStatement interface {
Statement
@ -27,177 +19,178 @@ type UnionStatement interface {
Offset(offset int64) UnionStatement
}
func Union(selects ...SelectStatement) UnionStatement {
return &unionStatementImpl{
selects: selects,
limit: -1,
offset: -1,
unique: true,
}
}
func UnionAll(selects ...SelectStatement) UnionStatement {
return &unionStatementImpl{
selects: selects,
limit: -1,
offset: -1,
unique: false,
}
}
// Similar to selectStatementImpl, but less complete
type unionStatementImpl struct {
selects []SelectStatement
where BoolExpression
group *listClause
order *listClause
limit, offset int64
// True if results of the union should be deduped.
unique bool
}
func (s *unionStatementImpl) Query(db types.Db, destination interface{}) error {
return Query(s, db, destination)
}
func (u *unionStatementImpl) Execute(db types.Db) (res sql.Result, err error) {
return Execute(u, db)
}
func (us *unionStatementImpl) Where(expression BoolExpression) UnionStatement {
us.where = expression
return us
}
// Further filter the query, instead of replacing the filter
func (us *unionStatementImpl) AndWhere(expression BoolExpression) UnionStatement {
if us.where == nil {
return us.Where(expression)
}
us.where = And(us.where, expression)
return us
}
func (us *unionStatementImpl) GroupBy(
expressions ...Expression) UnionStatement {
us.group = &listClause{
clauses: make([]Clause, len(expressions), len(expressions)),
includeParentheses: false,
}
for i, e := range expressions {
us.group.clauses[i] = e
}
return us
}
func (us *unionStatementImpl) OrderBy(
clauses ...OrderByClause) UnionStatement {
us.order = newOrderByListClause(clauses...)
return us
}
func (us *unionStatementImpl) Limit(limit int64) UnionStatement {
us.limit = limit
return us
}
func (us *unionStatementImpl) Offset(offset int64) UnionStatement {
us.offset = offset
return us
}
func (us *unionStatementImpl) String() (sql string, err error) {
if len(us.selects) == 0 {
return "", errors.Newf("Union statement must have at least one SELECT")
}
if len(us.selects) == 1 {
return us.selects[0].String()
}
// Union statements in MySQL require that the same number of columns in each subquery
var projections []Projection
for _, statement := range us.selects {
// do a type assertion to get at the underlying struct
statementImpl, ok := statement.(*selectStatementImpl)
if !ok {
return "", errors.Newf(
"Expected inner select statement to be of type " +
"selectStatementImpl")
}
// check that for limit for statements with order by clauses
if statementImpl.order != nil && statementImpl.limit < 0 {
return "", errors.Newf(
"All inner selects in Union statement must have LIMIT if " +
"they have ORDER BY")
}
// check number of projections
if projections == nil {
projections = statementImpl.projections
} else {
if len(projections) != len(statementImpl.projections) {
return "", errors.Newf(
"All inner selects in Union statement must select the " +
"same number of columns. For sanity, you probably " +
"want to select the same tableName columns in the same " +
"order. If you are selecting on multiple tables, " +
"use Null to pad to the right number of fields.")
}
}
}
buf := new(bytes.Buffer)
for i, statement := range us.selects {
if i != 0 {
if us.unique {
_, _ = buf.WriteString(" UNION ")
} else {
_, _ = buf.WriteString(" UNION ALL ")
}
}
_, _ = buf.WriteString("(")
selectSql, err := statement.String()
if err != nil {
return "", err
}
_, _ = buf.WriteString(selectSql)
_, _ = buf.WriteString(")")
}
if us.where != nil {
_, _ = buf.WriteString(" WHERE ")
if err = us.where.SerializeSql(buf); err != nil {
return
}
}
if us.group != nil {
_, _ = buf.WriteString(" GROUP BY ")
if err = us.group.SerializeSql(buf); err != nil {
return
}
}
if us.order != nil {
_, _ = buf.WriteString(" ORDER BY ")
if err = us.order.SerializeSql(buf); err != nil {
return
}
}
if us.limit >= 0 {
if us.offset >= 0 {
_, _ = buf.WriteString(
fmt.Sprintf(" LIMIT %d, %d", us.offset, us.limit))
} else {
_, _ = buf.WriteString(fmt.Sprintf(" LIMIT %d", us.limit))
}
}
return buf.String(), nil
}
//
//func Union(selects ...SelectStatement) UnionStatement {
// return &unionStatementImpl{
// selects: selects,
// limit: -1,
// offset: -1,
// unique: true,
// }
//}
//
//func UnionAll(selects ...SelectStatement) UnionStatement {
// return &unionStatementImpl{
// selects: selects,
// limit: -1,
// offset: -1,
// unique: false,
// }
//}
//
//// Similar to selectStatementImpl, but less complete
//type unionStatementImpl struct {
// selects []SelectStatement
// where BoolExpression
// group *listClause
// order *listClause
// limit, offset int64
// // True if results of the union should be deduped.
// unique bool
//}
//
//func (s *unionStatementImpl) Query(db types.Db, destination interface{}) error {
// return Query(s, db, destination)
//}
//
//func (u *unionStatementImpl) Execute(db types.Db) (res sql.Result, err error) {
// return Execute(u, db)
//}
//
//func (us *unionStatementImpl) Where(expression BoolExpression) UnionStatement {
// us.where = expression
// return us
//}
//
//// Further filter the query, instead of replacing the filter
//func (us *unionStatementImpl) AndWhere(expression BoolExpression) UnionStatement {
// if us.where == nil {
// return us.Where(expression)
// }
// us.where = And(us.where, expression)
// return us
//}
//
//func (us *unionStatementImpl) GroupBy(
// expressions ...Expression) UnionStatement {
//
// us.group = &listClause{
// clauses: make([]Clause, len(expressions), len(expressions)),
// includeParentheses: false,
// }
//
// for i, e := range expressions {
// us.group.clauses[i] = e
// }
// return us
//}
//
//func (us *unionStatementImpl) OrderBy(
// clauses ...OrderByClause) UnionStatement {
//
// us.order = newOrderByListClause(clauses...)
// return us
//}
//
//func (us *unionStatementImpl) Limit(limit int64) UnionStatement {
// us.limit = limit
// return us
//}
//
//func (us *unionStatementImpl) Offset(offset int64) UnionStatement {
// us.offset = offset
// return us
//}
//
//func (us *unionStatementImpl) String() (sql string, err error) {
// if len(us.selects) == 0 {
// return "", errors.Newf("Union statement must have at least one SELECT")
// }
//
// if len(us.selects) == 1 {
// return us.selects[0].String()
// }
//
// // Union statements in MySQL require that the same number of columns in each subquery
// var projections []Projection
//
// for _, statement := range us.selects {
// // do a type assertion to get at the underlying struct
// statementImpl, ok := statement.(*selectStatementImpl)
// if !ok {
// return "", errors.Newf(
// "Expected inner select statement to be of type " +
// "selectStatementImpl")
// }
//
// // check that for limit for statements with orderBy by clauses
// if statementImpl.orderBy != nil && statementImpl.limit < 0 {
// return "", errors.Newf(
// "All inner selects in Union statement must have LIMIT if " +
// "they have ORDER BY")
// }
//
// // check number of projections
// if projections == nil {
// projections = statementImpl.projections
// } else {
// if len(projections) != len(statementImpl.projections) {
// return "", errors.Newf(
// "All inner selects in Union statement must select the " +
// "same number of columns. For sanity, you probably " +
// "want to select the same tableName columns in the same " +
// "orderBy. If you are selecting on multiple tables, " +
// "use Null to pad to the right number of fields.")
// }
// }
// }
//
// buf := new(bytes.Buffer)
// for i, statement := range us.selects {
// if i != 0 {
// if us.unique {
// _, _ = buf.WriteString(" UNION ")
// } else {
// _, _ = buf.WriteString(" UNION ALL ")
// }
// }
// _, _ = buf.WriteString("(")
// selectSql, err := statement.String()
// if err != nil {
// return "", err
// }
// _, _ = buf.WriteString(selectSql)
// _, _ = buf.WriteString(")")
// }
//
// if us.where != nil {
// _, _ = buf.WriteString(" WHERE ")
// if err = us.where.Serialize(buf); err != nil {
// return
// }
// }
//
// if us.group != nil {
// _, _ = buf.WriteString(" GROUP BY ")
// if err = us.group.Serialize(buf); err != nil {
// return
// }
// }
//
// if us.order != nil {
// _, _ = buf.WriteString(" ORDER BY ")
// if err = us.order.Serialize(buf); err != nil {
// return
// }
// }
//
// if us.limit >= 0 {
// if us.offset >= 0 {
// _, _ = buf.WriteString(
// fmt.Sprintf(" LIMIT %d, %d", us.offset, us.limit))
// } else {
// _, _ = buf.WriteString(fmt.Sprintf(" LIMIT %d", us.limit))
// }
// }
// return buf.String(), nil
//}

View file

@ -1,7 +1,6 @@
package sqlbuilder
import (
"bytes"
"database/sql"
"github.com/dropbox/godropbox/errors"
"github.com/sub0zero/go-sqlbuilder/types"
@ -61,60 +60,64 @@ func (u *updateStatementImpl) RETURNING(projections ...Projection) UpdateStateme
return u
}
func (u *updateStatementImpl) String() (sql string, err error) {
buf := new(bytes.Buffer)
_, _ = buf.WriteString("UPDATE ")
func (u *updateStatementImpl) Sql() (sql string, args []interface{}, err error) {
out := &queryData{}
out.WriteString("UPDATE ")
if u.table == nil {
return "", errors.Newf("nil tableName. Generated sql: %s", buf.String())
return "", nil, errors.New("nil tableName.")
}
if err = u.table.SerializeSql(buf); err != nil {
if err = u.table.SerializeSql(out); err != nil {
return
}
if len(u.updateValues) == 0 {
return "", errors.Newf(
"No column updated. Generated sql: %s",
buf.String())
return "", nil, errors.New("No column updated.")
}
_, _ = buf.WriteString(" SET")
out.WriteString(" SET")
if len(u.columns) > 1 {
buf.WriteString(" ( ")
out.WriteString(" ( ")
} else {
buf.WriteString(" ")
out.WriteString(" ")
}
for i, column := range u.columns {
if i > 0 {
buf.WriteString(", ")
}
//for i, column := range u.columns {
// if i > 0 {
// out.WriteString(", ")
// }
//
// out.WriteString(column.Name())
//
// if err != nil {
// return
// }
//}
buf.WriteString(column.Name())
err = serializeColumnList(u.columns, out)
if err != nil {
return
}
return "", nil, err
}
if len(u.columns) > 1 {
buf.WriteString(" )")
out.WriteString(" )")
}
buf.WriteString(" =")
out.WriteString(" =")
if len(u.updateValues) > 1 {
buf.WriteString(" (")
out.WriteString(" (")
}
for i, value := range u.updateValues {
if i > 0 {
buf.WriteString(", ")
out.WriteString(", ")
}
err = value.SerializeSql(buf)
err = value.Serialize(out)
if err != nil {
return
@ -122,29 +125,27 @@ func (u *updateStatementImpl) String() (sql string, err error) {
}
if len(u.updateValues) > 1 {
buf.WriteString(" )")
out.WriteString(" )")
}
if u.where == nil {
return "", errors.Newf(
"Updating without a WHERE clause. Generated sql: %s",
buf.String())
return "", nil, errors.New("Updating without a WHERE clause.")
}
_, _ = buf.WriteString(" WHERE ")
if err = u.where.SerializeSql(buf); err != nil {
out.WriteString(" WHERE ")
if err = u.where.Serialize(out); err != nil {
return
}
if len(u.returning) > 0 {
buf.WriteString(" RETURNING ")
out.WriteString(" RETURNING ")
err = serializeProjectionList(u.returning, buf)
err = serializeProjectionList(u.returning, out)
if err != nil {
return
}
}
return buf.String() + ";", nil
return out.queryBuff.String(), out.args, nil
}

View file

@ -83,7 +83,7 @@ func TestUpdate(t *testing.T) {
//func (s *StmtSuite) TestUpdateWithOrderBy(c *gc.C) {
// stmt := table1.UPDATE().SET(table1Col1, Literal(1))
// stmt.WHERE(EqL(table1Col2, 2))
// stmt.OrderBy(table1Col2)
// stmt.ORDER_BY(table1Col2)
// sql, err := stmt.String()
// c.Assert(err, gc.IsNil)
//
@ -99,7 +99,7 @@ func TestUpdate(t *testing.T) {
//func (s *StmtSuite) TestUpdateWithLimit(c *gc.C) {
// stmt := table1.UPDATE().SET(table1Col1, Literal(1))
// stmt.WHERE(EqL(table1Col2, 2))
// stmt.Limit(5)
// stmt.LIMIT(5)
// sql, err := stmt.String()
// c.Assert(err, gc.IsNil)
//

View file

@ -1,19 +1,20 @@
package sqlbuilder
import (
"bytes"
"database/sql"
"errors"
"github.com/sub0zero/go-sqlbuilder/sqlbuilder/execution"
"github.com/sub0zero/go-sqlbuilder/types"
)
func serializeExpressionList(expressions []Expression, buf *bytes.Buffer) error {
for i, value := range expressions {
func serializeOrderByClauseList(orderByClauses []OrderByClause, out *queryData) error {
for i, value := range orderByClauses {
if i > 0 {
buf.WriteString(", ")
out.WriteString(", ")
}
err := value.SerializeSql(buf)
err := value.Serialize(out)
if err != nil {
return err
@ -23,13 +24,33 @@ func serializeExpressionList(expressions []Expression, buf *bytes.Buffer) error
return nil
}
func serializeProjectionList(projections []Projection, buf *bytes.Buffer) error {
for i, value := range projections {
func serializeClauseList(clauses []Clause, out *queryData) (err error) {
for i, c := range clauses {
if i > 0 {
buf.WriteString(", ")
out.WriteString(", ")
}
err := value.SerializeForProjection(buf)
if c == nil {
return errors.New("nil clause.")
}
if err = c.Serialize(out); err != nil {
return
}
}
return nil
}
func serializeExpressionList(expressions []Expression, separator string, out *queryData) error {
for i, value := range expressions {
if i > 0 {
out.WriteString(separator)
}
err := value.Serialize(out)
if err != nil {
return err
@ -39,24 +60,55 @@ func serializeProjectionList(projections []Projection, buf *bytes.Buffer) error
return nil
}
func serializeProjectionList(projections []Projection, out *queryData) error {
for i, col := range projections {
if i > 0 {
out.WriteByte(',')
}
if col == nil {
return errors.New("Projection expression is nil.")
}
if err := col.SerializeForProjection(out); err != nil {
return err
}
}
return nil
}
func serializeColumnList(columns []Column, out *queryData) error {
for i, col := range columns {
if i > 0 {
out.WriteByte(',')
}
if col == nil {
return errors.New("nil column in columns list.")
}
out.WriteString(col.Name())
}
return nil
}
func Query(statement Statement, db types.Db, destination interface{}) error {
query, err := statement.String()
query, args, err := statement.Sql()
if err != nil {
return err
}
return execution.Execute(db, query, destination)
return execution.Query(db, query, args, destination)
}
func Execute(statement Statement, db types.Db) (res sql.Result, err error) {
query, err := statement.String()
query, args, err := statement.Sql()
if err != nil {
return
}
res, err = db.Exec(query)
return
return db.Exec(query, args...)
}

View file

@ -25,13 +25,14 @@ func TestGenerateModel(t *testing.T) {
func TestSelect_ScanToStruct(t *testing.T) {
actor := model.Actor{}
query := Actor.SELECT(Actor.AllColumns).OrderBy(Actor.ActorID.Asc())
query := Actor.SELECT(Actor.AllColumns).ORDER_BY(Actor.ActorID.Asc())
queryStr, err := query.String()
queryStr, args, err := query.Sql()
fmt.Println(queryStr)
assert.Equal(t, queryStr, `SELECT actor.actor_id AS "actor.actor_id", actor.first_name AS "actor.first_name", actor.last_name AS "actor.last_name", actor.last_update AS "actor.last_update" FROM dvds.actor ORDER BY actor.actor_id ASC`)
assert.Equal(t, len(args), 0)
err = query.Query(db, &actor)
@ -50,12 +51,14 @@ func TestSelect_ScanToStruct(t *testing.T) {
func TestSelect_ScanToSlice(t *testing.T) {
customers := []model.Customer{}
query := Customer.SELECT(Customer.AllColumns).OrderBy(Customer.CustomerID.Asc())
query := Customer.SELECT(Customer.AllColumns).ORDER_BY(Customer.CustomerID.Asc())
queryStr, err := query.String()
queryStr, args, err := query.Sql()
assert.NilError(t, err)
fmt.Println(queryStr)
assert.Equal(t, queryStr, `SELECT customer.customer_id AS "customer.customer_id", customer.store_id AS "customer.store_id", customer.first_name AS "customer.first_name", customer.last_name AS "customer.last_name", customer.email AS "customer.email", customer.address_id AS "customer.address_id", customer.activebool AS "customer.activebool", customer.create_date AS "customer.create_date", customer.last_update AS "customer.last_update", customer.active AS "customer.active" FROM dvds.customer ORDER BY customer.customer_id ASC`)
assert.Equal(t, len(args), 0)
err = query.Query(db, &customers)
assert.NilError(t, err)
@ -76,7 +79,7 @@ func TestSelect_ScanToSlice(t *testing.T) {
// SELECT(FilmActor.AllColumns, Film.AllColumns, Language.AllColumns, Actor.AllColumns).
// WHERE(FilmActor.ActorID.GtEq(1).And(FilmActor.ActorID.LteLiteral(2)))
//
// queryStr, err := query.String()
// queryStr, args, err := query.Sql()
// assert.NilError(t, err)
// assert.Equal(t, queryStr, `SELECT film_actor.actor_id AS "film_actor.actor_id", film_actor.film_id AS "film_actor.film_id", film_actor.last_update AS "film_actor.last_update",film.film_id AS "film.film_id", film.title AS "film.title", film.description AS "film.description", film.release_year AS "film.release_year", film.language_id AS "film.language_id", film.rental_duration AS "film.rental_duration", film.rental_rate AS "film.rental_rate", film.length AS "film.length", film.replacement_cost AS "film.replacement_cost", film.rating AS "film.rating", film.last_update AS "film.last_update", film.special_features AS "film.special_features", film.fulltext AS "film.fulltext",language.language_id AS "language.language_id", language.name AS "language.name", language.last_update AS "language.last_update",actor.actor_id AS "actor.actor_id", actor.first_name AS "actor.first_name", actor.last_name AS "actor.last_name", actor.last_update AS "actor.last_update" FROM dvds.film_actor JOIN dvds.actor ON film_actor.actor_id = actor.actor_id JOIN dvds.film ON film_actor.film_id = film.film_id JOIN dvds.language ON film.language_id = language.language_id WHERE (film_actor.actor_id>=1 AND film_actor.actor_id<=2)`)
//
@ -104,14 +107,18 @@ func TestJoinQuerySlice(t *testing.T) {
query := Film.
INNER_JOIN(Language, Film.LanguageID.Eq(Language.LanguageID)).
SELECT(Language.AllColumns, Film.AllColumns).
Where(Film.Rating.EqL(string(model.MpaaRating_NC17))).
Limit(15)
WHERE(Film.Rating.EqL(string(model.MpaaRating_NC17))).
LIMIT(15)
queryStr, err := query.String()
queryStr, args, err := query.Sql()
assert.NilError(t, err)
assert.Equal(t, queryStr, `SELECT language.language_id AS "language.language_id", language.name AS "language.name", language.last_update AS "language.last_update",film.film_id AS "film.film_id", film.title AS "film.title", film.description AS "film.description", film.release_year AS "film.release_year", film.language_id AS "film.language_id", film.rental_duration AS "film.rental_duration", film.rental_rate AS "film.rental_rate", film.length AS "film.length", film.replacement_cost AS "film.replacement_cost", film.rating AS "film.rating", film.last_update AS "film.last_update", film.special_features AS "film.special_features", film.fulltext AS "film.fulltext" FROM dvds.film JOIN dvds.language ON film.language_id = language.language_id WHERE film.rating = 'NC-17' LIMIT 15`)
//fmt.Println(queryStr)
fmt.Println(queryStr)
assert.Equal(t, queryStr, `SELECT language.language_id AS "language.language_id", language.name AS "language.name", language.last_update AS "language.last_update",film.film_id AS "film.film_id", film.title AS "film.title", film.description AS "film.description", film.release_year AS "film.release_year", film.language_id AS "film.language_id", film.rental_duration AS "film.rental_duration", film.rental_rate AS "film.rental_rate", film.length AS "film.length", film.replacement_cost AS "film.replacement_cost", film.rating AS "film.rating", film.last_update AS "film.last_update", film.special_features AS "film.special_features", film.fulltext AS "film.fulltext" FROM dvds.film JOIN dvds.language ON film.language_id = language.language_id WHERE film.rating = $1 LIMIT $2`)
assert.Equal(t, len(args), 2)
assert.Equal(t, args[0], string(model.MpaaRating_NC17))
assert.Equal(t, args[1], int64(15))
err = query.Query(db, &filmsPerLanguage)
@ -149,7 +156,7 @@ func TestJoinQuerySliceWithPtrs(t *testing.T) {
query := Film.INNER_JOIN(Language, Film.LanguageID.Eq(Language.LanguageID)).
SELECT(Language.AllColumns, Film.AllColumns).
Limit(limit)
LIMIT(limit)
filmsPerLanguageWithPtrs := []*FilmsPerLanguage{}
err := query.Query(db, &filmsPerLanguageWithPtrs)
@ -179,7 +186,7 @@ func TestSelectOrderByAscDesc(t *testing.T) {
customersAsc := []model.Customer{}
err := Customer.SELECT(Customer.CustomerID, Customer.FirstName, Customer.LastName).
OrderBy(Customer.FirstName.Asc()).
ORDER_BY(Customer.FirstName.Asc()).
Query(db, &customersAsc)
assert.NilError(t, err)
@ -189,7 +196,7 @@ func TestSelectOrderByAscDesc(t *testing.T) {
customersDesc := []model.Customer{}
err = Customer.SELECT(Customer.CustomerID, Customer.FirstName, Customer.LastName).
OrderBy(Customer.FirstName.Desc()).
ORDER_BY(Customer.FirstName.Desc()).
Query(db, &customersDesc)
assert.NilError(t, err)
@ -202,7 +209,7 @@ func TestSelectOrderByAscDesc(t *testing.T) {
customersAscDesc := []model.Customer{}
err = Customer.SELECT(Customer.CustomerID, Customer.FirstName, Customer.LastName).
OrderBy(Customer.FirstName.Asc(), Customer.LastName.Desc()).
ORDER_BY(Customer.FirstName.Asc(), Customer.LastName.Desc()).
Query(db, &customersAscDesc)
assert.NilError(t, err)
@ -227,13 +234,14 @@ func TestSelectFullJoin(t *testing.T) {
query := Customer.
FULL_JOIN(Address, Customer.AddressID.Eq(Address.AddressID)).
SELECT(Customer.AllColumns, Address.AllColumns).
OrderBy(Customer.CustomerID.Asc())
ORDER_BY(Customer.CustomerID.Asc())
queryStr, err := query.String()
queryStr, args, err := query.Sql()
assert.NilError(t, err)
assert.Equal(t, queryStr, `SELECT customer.customer_id AS "customer.customer_id", customer.store_id AS "customer.store_id", customer.first_name AS "customer.first_name", customer.last_name AS "customer.last_name", customer.email AS "customer.email", customer.address_id AS "customer.address_id", customer.activebool AS "customer.activebool", customer.create_date AS "customer.create_date", customer.last_update AS "customer.last_update", customer.active AS "customer.active",address.address_id AS "address.address_id", address.address AS "address.address", address.address2 AS "address.address2", address.district AS "address.district", address.city_id AS "address.city_id", address.postal_code AS "address.postal_code", address.phone AS "address.phone", address.last_update AS "address.last_update" FROM dvds.customer FULL JOIN dvds.address ON customer.address_id = address.address_id ORDER BY customer.customer_id ASC`)
assert.Equal(t, len(args), 0)
allCustomersAndAddress := []struct {
Address *model.Address
@ -259,13 +267,14 @@ func TestSelectFullCrossJoin(t *testing.T) {
query := Customer.
CrossJoin(Address).
SELECT(Customer.AllColumns, Address.AllColumns).
OrderBy(Customer.CustomerID.Asc()).
Limit(1000)
ORDER_BY(Customer.CustomerID.Asc()).
LIMIT(1000)
queryStr, err := query.String()
queryStr, args, err := query.Sql()
assert.NilError(t, err)
assert.Equal(t, queryStr, `SELECT customer.customer_id AS "customer.customer_id", customer.store_id AS "customer.store_id", customer.first_name AS "customer.first_name", customer.last_name AS "customer.last_name", customer.email AS "customer.email", customer.address_id AS "customer.address_id", customer.activebool AS "customer.activebool", customer.create_date AS "customer.create_date", customer.last_update AS "customer.last_update", customer.active AS "customer.active",address.address_id AS "address.address_id", address.address AS "address.address", address.address2 AS "address.address2", address.district AS "address.district", address.city_id AS "address.city_id", address.postal_code AS "address.postal_code", address.phone AS "address.phone", address.last_update AS "address.last_update" FROM dvds.customer CROSS JOIN dvds.address ORDER BY customer.customer_id ASC LIMIT 1000`)
assert.Equal(t, queryStr, `SELECT customer.customer_id AS "customer.customer_id", customer.store_id AS "customer.store_id", customer.first_name AS "customer.first_name", customer.last_name AS "customer.last_name", customer.email AS "customer.email", customer.address_id AS "customer.address_id", customer.activebool AS "customer.activebool", customer.create_date AS "customer.create_date", customer.last_update AS "customer.last_update", customer.active AS "customer.active",address.address_id AS "address.address_id", address.address AS "address.address", address.address2 AS "address.address2", address.district AS "address.district", address.city_id AS "address.city_id", address.postal_code AS "address.postal_code", address.phone AS "address.phone", address.last_update AS "address.last_update" FROM dvds.customer CROSS JOIN dvds.address ORDER BY customer.customer_id ASC LIMIT $1`)
assert.Equal(t, len(args), 1)
customerAddresCrosJoined := []model.Customer{}
@ -286,9 +295,10 @@ func TestSelectSelfJoin(t *testing.T) {
query := f1.
INNER_JOIN(f2, f1.FilmID.NotEq(f2.FilmID).And(f1.Length.Eq(f2.Length))).
SELECT(f1.AllColumns, f2.AllColumns).
OrderBy(f1.FilmID.Asc())
ORDER_BY(f1.FilmID.Asc())
queryStr, err := query.String()
queryStr, args, err := query.Sql()
assert.Equal(t, len(args), 0)
assert.NilError(t, err)
@ -326,10 +336,11 @@ func TestSelectAliasColumn(t *testing.T) {
SELECT(f1.Title.As("thesame_length_films.title1"),
f2.Title.As("thesame_length_films.title2"),
f1.Length.As("thesame_length_films.length")).
OrderBy(f1.Length.Asc(), f1.Title.Asc(), f2.Title.Asc()).
Limit(1000)
ORDER_BY(f1.Length.Asc(), f1.Title.Asc(), f2.Title.Asc()).
LIMIT(1000)
queryStr, err := query.String()
queryStr, args, err := query.Sql()
assert.Equal(t, len(args), 1)
assert.NilError(t, err)
@ -372,9 +383,10 @@ func TestSelectSelfReferenceType(t *testing.T) {
INNER_JOIN(manager, Staff.StaffID.Eq(manager.StaffID)).
SELECT(Staff.StaffID, Staff.FirstName, Staff.LastName, Address.AllColumns, manager.StaffID, manager.FirstName)
queryStr, err := query.String()
queryStr, args, err := query.Sql()
assert.NilError(t, err)
fmt.Println(queryStr)
assert.Equal(t, len(args), 0)
staffs := []staff{}
@ -394,13 +406,13 @@ func TestSubQuery(t *testing.T) {
// selectStmtTable.RefIntColumnName("actor.last_name").As("nesto2"),
// )
//
//queryStr, err := query.String()
//queryStr, args, err := query.Sql()
//
//assert.NilError(t, err)
//
//fmt.Println(queryStr)
//
//avrgCustomer := sqlbuilder.NumExp(Customer.SELECT(Customer.LastName).Limit(1))
//avrgCustomer := sqlbuilder.NumExp(Customer.SELECT(Customer.LastName).LIMIT(1))
//
//Customer.
// INNER_JOIN(selectStmtTable, Customer.LastName.Eq(selectStmtTable.RefStringColumn(Actor.FirstName))).
@ -408,7 +420,7 @@ func TestSubQuery(t *testing.T) {
// WHERE(Actor.LastName.Neq(avrgCustomer))
rFilmsOnly := Film.SELECT(Film.FilmID, Film.Title, Film.Rating).
Where(Film.Rating.EqL("R")).
WHERE(Film.Rating.EqL("R")).
AsTable("films")
query := Actor.INNER_JOIN(FilmActor, Actor.ActorID.Eq(FilmActor.FilmID)).
@ -420,10 +432,10 @@ func TestSubQuery(t *testing.T) {
rFilmsOnly.RefStringColumn(Film.Rating).As("film.rating"),
)
queryStr, err := query.String()
queryStr, args, err := query.Sql()
assert.NilError(t, err)
assert.Equal(t, len(args), 1)
fmt.Println(queryStr)
}
@ -431,12 +443,12 @@ func TestSubQuery(t *testing.T) {
func TestSelectFunctions(t *testing.T) {
query := Film.SELECT(sqlbuilder.MAX(Film.RentalRate).As("max_film_rate"))
str, err := query.String()
str, args, err := query.Sql()
assert.NilError(t, err)
assert.Equal(t, str, `SELECT MAX(film.rental_rate) AS "max_film_rate" FROM dvds.film`)
assert.Equal(t, len(args), 0)
fmt.Println(str)
}
@ -445,13 +457,13 @@ func TestSelectQueryScalar(t *testing.T) {
maxFilmRentalRate := sqlbuilder.NumExp(Film.SELECT(sqlbuilder.MAX(Film.RentalRate)))
query := Film.SELECT(Film.AllColumns).
Where(Film.RentalRate.Eq(maxFilmRentalRate)).
OrderBy(Film.FilmID.Asc())
WHERE(Film.RentalRate.Eq(maxFilmRentalRate)).
ORDER_BY(Film.FilmID.Asc())
queryStr, err := query.String()
queryStr, args, err := query.Sql()
assert.NilError(t, err)
assert.Equal(t, len(args), 0)
fmt.Println(queryStr)
maxRentalRateFilms := []model.Film{}
@ -488,16 +500,17 @@ func TestSelectGroupByHaving(t *testing.T) {
Payment.CustomerID.As("customer_payment_sum.customer_id"),
sqlbuilder.SUM(Payment.Amount).As("customer_payment_sum.amount_sum"),
).
GroupBy(Payment.CustomerID).
OrderBy(sqlbuilder.SUM(Payment.Amount).Asc()).
HAVING(sqlbuilder.Gt(sqlbuilder.SUM(Payment.Amount), sqlbuilder.Literal(100)))
GROUP_BY(Payment.CustomerID).
ORDER_BY(sqlbuilder.SUM(Payment.Amount).Asc()).
HAVING(sqlbuilder.SUM(Payment.Amount).Gt(sqlbuilder.NewNumericLiteral(100)))
queryStr, err := customersPaymentQuery.String()
queryStr, args, err := customersPaymentQuery.Sql()
assert.NilError(t, err)
fmt.Println(queryStr)
assert.Equal(t, len(args), 1)
assert.Equal(t, queryStr, `SELECT payment.customer_id AS "customer_payment_sum.customer_id",SUM(payment.amount) AS "customer_payment_sum.amount_sum" FROM dvds.payment GROUP BY payment.customer_id HAVING SUM(payment.amount)>$1 ORDER BY SUM(payment.amount) ASC`)
assert.Equal(t, queryStr, `SELECT payment.customer_id AS "customer_payment_sum.customer_id",SUM(payment.amount) AS "customer_payment_sum.amount_sum" FROM dvds.payment GROUP BY payment.customer_id HAVING SUM(payment.amount)>100 ORDER BY SUM(payment.amount) ASC`)
type CustomerPaymentSum struct {
CustomerID int16
AmountSum float64
@ -528,7 +541,7 @@ func TestSelectGroupBy2(t *testing.T) {
Payment.CustomerID,
sqlbuilder.SUM(Payment.Amount).As("amount_sum"),
).
GroupBy(Payment.CustomerID)
GROUP_BY(Payment.CustomerID)
customersPaymentTable := customersPaymentSubQuery.AsTable("customer_payment_sum")
amountSumColumn := customersPaymentTable.RefIntColumnName("amount_sum")
@ -536,11 +549,12 @@ func TestSelectGroupBy2(t *testing.T) {
query := Customer.
INNER_JOIN(customersPaymentTable, Customer.CustomerID.Eq(customersPaymentTable.RefIntColumn(Payment.CustomerID))).
SELECT(Customer.AllColumns, amountSumColumn.As("customer_with_amounts.amount_sum")).
OrderBy(amountSumColumn.Asc())
ORDER_BY(amountSumColumn.Asc())
queryStr, err := query.String()
queryStr, args, err := query.Sql()
assert.NilError(t, err)
fmt.Println(queryStr)
assert.Equal(t, len(args), 0)
err = query.Query(db, &customersWithAmounts)
assert.NilError(t, err)
@ -565,13 +579,13 @@ func TestSelectGroupBy2(t *testing.T) {
func TestSelectTimeColumns(t *testing.T) {
query := Payment.SELECT(Payment.AllColumns).
Where(Payment.PaymentDate.LtEqL("2007-02-14 22:16:01")).
OrderBy(Payment.PaymentDate.Asc())
WHERE(Payment.PaymentDate.LtEqL("2007-02-14 22:16:01")).
ORDER_BY(Payment.PaymentDate.Asc())
queryStr, err := query.String()
queryStr, args, err := query.Sql()
assert.NilError(t, err)
assert.Equal(t, len(args), 1)
fmt.Println(queryStr)
payments := []model.Payment{}

View file

@ -18,13 +18,14 @@ func TestInsertValues(t *testing.T) {
VALUES("http://www.bing.com", "Bing", sqlbuilder.DEFAULT).
RETURNING(table.Link.ID)
insertQueryStr, err := insertQuery.String()
insertQueryStr, args, err := insertQuery.Sql()
assert.NilError(t, err)
assert.Equal(t, len(args), 8)
fmt.Println(insertQueryStr)
assert.Equal(t, insertQueryStr, `INSERT INTO test_sample.link (url,name,rel) VALUES ('http://www.postgresqltutorial.com','PostgreSQL Tutorial',DEFAULT), ('http://www.google.com','Google',DEFAULT), ('http://www.yahoo.com','Yahoo',DEFAULT), ('http://www.bing.com','Bing',DEFAULT) RETURNING link.id AS "link.id";`)
assert.Equal(t, insertQueryStr, `INSERT INTO test_sample.link (url,name,rel) VALUES ($1, $2, DEFAULT), ($3, $4, DEFAULT), ($5, $6, DEFAULT), ($7, $8, DEFAULT) RETURNING link.id AS "link.id";`)
res, err := insertQuery.Execute(db)
assert.NilError(t, err)
@ -68,9 +69,10 @@ func TestInsertDataObject(t *testing.T) {
INSERT(table.Link.URL, table.Link.Name).
VALUES_MAPPING(linkData)
queryStr, err := query.String()
queryStr, args, err := query.Sql()
assert.NilError(t, err)
assert.Equal(t, len(args), 2)
fmt.Println(queryStr)
@ -92,9 +94,10 @@ func TestInsertQuery(t *testing.T) {
INSERT(table.Link.URL, table.Link.Name).
QUERY(table.Link.SELECT(table.Link.URL, table.Link.Name))
queryStr, err := query.String()
queryStr, args, err := query.Sql()
assert.NilError(t, err)
assert.Equal(t, len(args), 0)
fmt.Println(queryStr)

View file

@ -12,11 +12,12 @@ import (
func TestUUIDType(t *testing.T) {
query := table.AllTypes.
SELECT(table.AllTypes.AllColumns).
Where(table.AllTypes.UUID.EqL("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"))
WHERE(table.AllTypes.UUID.EqL("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11"))
queryStr, err := query.String()
queryStr, args, err := query.Sql()
assert.NilError(t, err)
assert.Equal(t, len(args), 1)
fmt.Println(queryStr)
//assert.Equal(t, queryStr, `SELECT all_types.character AS "all_types.character", all_types.character_varying AS "all_types.character_varying", all_types.text AS "all_types.text", all_types.bytea AS "all_types.bytea", all_types.timestamp_without_time_zone AS "all_types.timestamp_without_time_zone", all_types.timestamp_with_time_zone AS "all_types.timestamp_with_time_zone", all_types.uuid AS "all_types.uuid", all_types.json AS "all_types.json", all_types.jsonb AS "all_types.jsonb" FROM test_sample.all_types WHERE all_types.uuid = 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11`)
result := model.AllTypes{}
@ -29,11 +30,11 @@ func TestEnumType(t *testing.T) {
query := table.Person.
SELECT(table.Person.AllColumns)
queryStr, err := query.String()
queryStr, args, err := query.Sql()
assert.NilError(t, err)
fmt.Println(queryStr)
assert.Equal(t, len(args), 0)
result := []model.Person{}
err = query.Query(db, &result)

View file

@ -23,10 +23,10 @@ func TestUpdateValues(t *testing.T) {
SET("Bong", "http://bong.com").
WHERE(table.Link.Name.EqL("Bing"))
queryStr, err := query.String()
queryStr, args, err := query.Sql()
assert.NilError(t, err)
assert.Equal(t, len(args), 3)
fmt.Println(queryStr)
res, err := query.Execute(db)
@ -38,7 +38,7 @@ func TestUpdateValues(t *testing.T) {
links := []model.Link{}
err = table.Link.SELECT(table.Link.AllColumns).
Where(table.Link.Name.EqL("Bong")).
WHERE(table.Link.Name.EqL("Bong")).
Query(db, &links)
assert.NilError(t, err)
@ -63,10 +63,10 @@ func TestUpdateAndReturning(t *testing.T) {
WHERE(table.Link.Name.EqL("Ask")).
RETURNING(table.Link.AllColumns)
stmtStr, err := stmt.String()
stmtStr, args, err := stmt.Sql()
assert.NilError(t, err)
assert.Equal(t, len(args), 3)
fmt.Println(stmtStr)
links := []model.Link{}