diff --git a/sqlbuilder/alias.go b/sqlbuilder/alias.go index e852e29..f9c46c5 100644 --- a/sqlbuilder/alias.go +++ b/sqlbuilder/alias.go @@ -1,7 +1,5 @@ package sqlbuilder -import "bytes" - type Alias struct { expression Expression alias string @@ -14,9 +12,9 @@ func NewAlias(expression Expression, alias string) *Alias { } } -func (a *Alias) SerializeForProjection(out *bytes.Buffer) error { +func (a *Alias) SerializeForProjection(out *queryData) error { - err := a.expression.SerializeSql(out, ALIASED) + err := a.expression.Serialize(out, SKIP_DEFAULT_ALIASING) if err != nil { return err diff --git a/sqlbuilder/bool_expresion.go b/sqlbuilder/bool_expresion.go index a7458fc..182dde6 100644 --- a/sqlbuilder/bool_expresion.go +++ b/sqlbuilder/bool_expresion.go @@ -66,11 +66,7 @@ type boolLiteralExpression struct { func newBoolLiteralExpression(value bool) BoolExpression { boolLiteralExpression := boolLiteralExpression{} - sqlValue, err := sqltypes.BuildValue(value) - if err != nil { - panic(errors.Wrap(err, "Invalid literal value")) - } - boolLiteralExpression.literalExpression = *NewLiteralExpression(sqlValue) + boolLiteralExpression.literalExpression = *Literal(value) boolLiteralExpression.boolInterfaceImpl.parent = &boolLiteralExpression return &boolLiteralExpression @@ -113,27 +109,27 @@ func newPrefixBoolExpression(expression Expression, operator []byte) BoolExpress } //---------------------------------------------------// -type conjunctBoolExpression struct { - expressionInterfaceImpl - boolInterfaceImpl - - conjunctExpression - name string -} - -func NewConjunctBoolExpression(operator []byte, expressions ...BoolExpression) BoolExpression { - boolExpression := conjunctBoolExpression{ - conjunctExpression: conjunctExpression{ - expressions: expressions, - conjunction: operator, - }, - } - - boolExpression.expressionInterfaceImpl.parent = &boolExpression - boolExpression.boolInterfaceImpl.parent = &boolExpression - - return &boolExpression -} +//type conjunctBoolExpression struct { +// expressionInterfaceImpl +// boolInterfaceImpl +// +// conjunctExpression +// name string +//} +// +//func NewConjunctBoolExpression(operator []byte, expressions ...BoolExpression) BoolExpression { +// boolExpression := conjunctBoolExpression{ +// conjunctExpression: conjunctExpression{ +// expressions: expressions, +// conjunction: operator, +// }, +// } +// +// boolExpression.expressionInterfaceImpl.parent = &boolExpression +// boolExpression.boolInterfaceImpl.parent = &boolExpression +// +// return &boolExpression +//} //---------------------------------------------------// type inExpression struct { @@ -146,34 +142,33 @@ type inExpression struct { err error } -func (c *inExpression) SerializeSql(out *bytes.Buffer, options ...serializeOption) error { +func (c *inExpression) Serialize(out *queryData, options ...serializeOption) error { if c.err != nil { return errors.Wrap(c.err, "Invalid IN expression") } if c.lhs == nil { - return errors.Newf( - "lhs of in expression is nil. Generated sql: %s", - out.String()) + return errors.Newf("lhs of in expression is nil.") } // We'll serialize the lhs even if we don't need it to ensure no error buf := &bytes.Buffer{} - err := c.lhs.SerializeSql(buf) + err := c.lhs.Serialize(out, options...) if err != nil { return err } if c.rhs == nil { - _, _ = out.WriteString("FALSE") + out.WriteString("FALSE") return nil } - _, _ = out.WriteString(buf.String()) - _, _ = out.WriteString(" IN ") + out.WriteString(buf.String()) + out.WriteString(" IN ") + + err = c.rhs.Serialize(out) - err = c.rhs.SerializeSql(out) if err != nil { return err } @@ -183,10 +178,6 @@ func (c *inExpression) SerializeSql(out *bytes.Buffer, options ...serializeOptio // Returns a representation of "a=b" func Eq(lhs, rhs Expression) BoolExpression { - lit, ok := rhs.(*literalExpression) - if ok && sqltypes.Value(lit.value).IsNull() { - return newBinaryBoolExpression(lhs, rhs, []byte(" IS ")) - } return newBinaryBoolExpression(lhs, rhs, []byte(" = ")) } @@ -197,10 +188,6 @@ func EqL(lhs Expression, val interface{}) BoolExpression { // Returns a representation of "a!=b" func NotEq(lhs, rhs Expression) BoolExpression { - lit, ok := rhs.(*literalExpression) - if ok && sqltypes.Value(lit.value).IsNull() { - return newBinaryBoolExpression(lhs, rhs, []byte(" IS NOT ")) - } return newBinaryBoolExpression(lhs, rhs, []byte("!=")) } @@ -258,14 +245,13 @@ func IsTrue(expr BoolExpression) BoolExpression { return newPrefixBoolExpression(expr, []byte(" IS TRUE ")) } -// Returns a representation of "c[0] AND ... AND c[n-1]" for c in clauses -func And(expressions ...BoolExpression) BoolExpression { - return NewConjunctBoolExpression([]byte(" AND "), expressions...) +func And(lhs, rhs Expression) BoolExpression { + return newBinaryBoolExpression(lhs, rhs, []byte(" AND ")) } // Returns a representation of "c[0] OR ... OR c[n-1]" for c in clauses -func Or(expressions ...BoolExpression) BoolExpression { - return NewConjunctBoolExpression([]byte(" OR "), expressions...) +func Or(lhs, rhs Expression) BoolExpression { + return newBinaryBoolExpression(lhs, rhs, []byte(" OR ")) } func Like(lhs, rhs Expression) BoolExpression { diff --git a/sqlbuilder/bool_expression_test.go b/sqlbuilder/bool_expression_test.go index b0bc407..66907a1 100644 --- a/sqlbuilder/bool_expression_test.go +++ b/sqlbuilder/bool_expression_test.go @@ -10,7 +10,7 @@ func TestBinaryExpression(t *testing.T) { boolExpression := Eq(Literal(2), Literal(3)) out := bytes.Buffer{} - err := boolExpression.SerializeSql(&out) + err := boolExpression.Serialize(&out) assert.NilError(t, err) assert.Equal(t, out.String(), "2 = 3") @@ -29,7 +29,7 @@ func TestBinaryExpression(t *testing.T) { exp := boolExpression.And(Eq(Literal(4), Literal(5))) out := bytes.Buffer{} - err := exp.SerializeSql(&out) + err := exp.Serialize(&out) assert.NilError(t, err) assert.Equal(t, out.String(), `(2 = 3 AND 4 = 5)`) @@ -39,7 +39,7 @@ func TestBinaryExpression(t *testing.T) { exp := boolExpression.Or(Eq(Literal(4), Literal(5))) out := bytes.Buffer{} - err := exp.SerializeSql(&out) + err := exp.Serialize(&out) assert.NilError(t, err) assert.Equal(t, out.String(), `(2 = 3 OR 4 = 5)`) @@ -50,7 +50,7 @@ func TestUnaryExpression(t *testing.T) { notExpression := Not(Eq(Literal(2), Literal(1))) out := bytes.Buffer{} - err := notExpression.SerializeSql(&out) + err := notExpression.Serialize(&out) assert.NilError(t, err) assert.Equal(t, out.String(), " NOT 2 = 1") @@ -69,7 +69,7 @@ func TestUnaryExpression(t *testing.T) { exp := notExpression.And(Eq(Literal(4), Literal(5))) out := bytes.Buffer{} - err := exp.SerializeSql(&out) + err := exp.Serialize(&out) assert.NilError(t, err) assert.Equal(t, out.String(), `( NOT 2 = 1 AND 4 = 5)`) @@ -80,7 +80,7 @@ func TestUnaryIsTrueExpression(t *testing.T) { notExpression := IsTrue(Eq(Literal(2), Literal(1))) out := bytes.Buffer{} - err := notExpression.SerializeSql(&out) + err := notExpression.Serialize(&out) assert.NilError(t, err) assert.Equal(t, out.String(), " IS TRUE 2 = 1") @@ -89,7 +89,7 @@ func TestUnaryIsTrueExpression(t *testing.T) { exp := notExpression.And(Eq(Literal(4), Literal(5))) out := bytes.Buffer{} - err := exp.SerializeSql(&out) + err := exp.Serialize(&out) assert.NilError(t, err) assert.Equal(t, out.String(), `( IS TRUE 2 = 1 AND 4 = 5)`) @@ -100,7 +100,7 @@ func TestBoolLiteral(t *testing.T) { literal := newBoolLiteralExpression(true) out := bytes.Buffer{} - err := literal.SerializeSql(&out) + err := literal.Serialize(&out) assert.NilError(t, err) diff --git a/sqlbuilder/clause.go b/sqlbuilder/clause.go index 37c3cf8..dad2f57 100644 --- a/sqlbuilder/clause.go +++ b/sqlbuilder/clause.go @@ -1,16 +1,91 @@ package sqlbuilder -import "bytes" +import ( + "bytes" + "errors" + "strconv" +) type serializeOption int const ( - ALIASED = iota + SKIP_DEFAULT_ALIASING = iota FOR_PROJECTION ) type Clause interface { - SerializeSql(out *bytes.Buffer, options ...serializeOption) error + Serialize(out *queryData, options ...serializeOption) error +} + +type queryData struct { + queryBuff bytes.Buffer + args []interface{} +} + +func (q *queryData) Write(data []byte) { + q.queryBuff.Write(data) +} + +func (q *queryData) WriteString(str string) { + q.queryBuff.WriteString(str) +} + +func (q *queryData) WriteByte(b byte) { + q.queryBuff.WriteByte(b) +} + +func (q *queryData) InsertArgument(arg interface{}) { + q.args = append(q.args, arg) + argPlaceholder := "$" + strconv.Itoa(len(q.args)) + + q.queryBuff.WriteString(argPlaceholder) +} + +func argToString(value interface{}) (string, error) { + switch bindVal := value.(type) { + case bool: + if bindVal { + return "TRUE", nil + } else { + return "FALSE", nil + } + case int8: + return strconv.FormatInt(int64(bindVal), 10), nil + case int: + return strconv.FormatInt(int64(bindVal), 10), nil + case int16: + return strconv.FormatInt(int64(bindVal), 10), nil + case int32: + return strconv.FormatInt(int64(bindVal), 10), nil + case int64: + return strconv.FormatInt(int64(bindVal), 10), nil + + case uint8: + return strconv.FormatUint(uint64(bindVal), 10), nil + case uint: + return strconv.FormatUint(uint64(bindVal), 10), nil + case uint16: + return strconv.FormatUint(uint64(bindVal), 10), nil + case uint32: + return strconv.FormatUint(uint64(bindVal), 10), nil + case uint64: + return strconv.FormatUint(uint64(bindVal), 10), nil + + case float32: + return strconv.FormatFloat(float64(bindVal), 'f', -1, 64), nil + case float64: + return strconv.FormatFloat(float64(bindVal), 'f', -1, 64), nil + + case string: + return bindVal, nil + case []byte: + return string(bindVal), nil + //TODO: implement + //case time.Time: + // return bindVal.String()) + default: + return "", errors.New("Unsupported literal type. ") + } } func contains(s []serializeOption, e serializeOption) bool { diff --git a/sqlbuilder/column.go b/sqlbuilder/column.go index 3a2fec3..5416673 100644 --- a/sqlbuilder/column.go +++ b/sqlbuilder/column.go @@ -3,14 +3,9 @@ package sqlbuilder import ( - "bytes" - "regexp" "strings" ) -// XXX: Maybe add UIntColumn - -// Representation of a tableName for query generation type Column interface { Expression @@ -28,11 +23,6 @@ const ( NotNullable NullableColumn = false ) -//// A column that can be refer to outside of the projection list -//type NonAliasColumn interface { -// Column -//} - type Collation string const ( @@ -82,194 +72,39 @@ func (c *baseColumn) setTableName(table string) error { return nil } -func (c baseColumn) SerializeSql(out *bytes.Buffer, options ...serializeOption) error { +func (c baseColumn) Serialize(out *queryData, options ...serializeOption) error { if c.tableName != "" { - _, _ = out.WriteString(c.tableName) - _, _ = out.WriteString(".") + out.WriteString(c.tableName) + out.WriteString(".") } + containsDot := strings.Contains(c.name, ".") if containsDot { - out.WriteString("\"") - } - _, _ = out.WriteString(c.name) - if containsDot { - out.WriteString("\"") + out.WriteString(`"`) } - if contains(options, FOR_PROJECTION) && !contains(options, ALIASED) && c.tableName != "" { - _, _ = out.WriteString(" AS \"" + c.tableName + "." + c.name + "\"") + out.WriteString(c.name) + + if containsDot { + out.WriteString(`"`) + } + + if contains(options, FOR_PROJECTION) && !contains(options, SKIP_DEFAULT_ALIASING) && c.tableName != "" { + out.WriteString(" AS \"" + c.tableName + "." + c.name + "\"") } return nil } // -//type bytesColumn struct { -// baseColumn -//} +//// This is a strict subset of the actual allowed identifiers +//var validIdentifierRegexp = regexp.MustCompile("^[a-zA-Z_]\\w*$") // -//// Representation of VARBINARY/BLOB columns -//// This function will panic if name is not valid -//func BytesColumn(name string, nullable NullableColumn) Column { -// if !validIdentifierName(name) { -// panic("Invalid column name in bytes column") -// } -// bc := &bytesColumn{} -// bc.name = name -// bc.nullable = nullable -// return bc +//// Returns true if the given string is suitable as an identifier. +//func validIdentifierName(name string) bool { +// return validIdentifierRegexp.MatchString(name) //} -// -//type stringColumn struct { -// baseColumn -// charset Charset -// collation Collation -//} -// -//// Representation of VARCHAR/TEXT columns -//// This function will panic if name is not valid -//func StrColumn( -// name string, -// charset Charset, -// collation Collation, -// nullable NullableColumn) Column { -// -// if !validIdentifierName(name) { -// panic("Invalid column name in str column") -// } -// sc := &stringColumn{charset: charset, collation: collation} -// sc.name = name -// sc.nullable = nullable -// return sc -//} -// -//type dateTimeColumn struct { -// baseColumn -//} -// -//// Representation of DateTime columns -//// This function will panic if name is not valid -//func DateTimeColumn(name string, nullable NullableColumn) Column { -// if !validIdentifierName(name) { -// panic("Invalid column name in datetime column") -// } -// dc := &dateTimeColumn{} -// dc.name = name -// dc.nullable = nullable -// return dc -//} - -//type IntegerColumn struct { -// baseColumn -//} -// -//// Representation of any integer column -//// This function will panic if name is not valid -//func IntColumn(name string, nullable NullableColumn) *IntegerColumn { -// if !validIdentifierName(name) { -// panic("Invalid column name in int column") -// } -// ic := &IntegerColumn{} -// ic.name = name -// ic.nullable = nullable -// return ic -//} - -//type doubleColumn struct { -// baseColumn -//} -// -//// Representation of any double column -//// This function will panic if name is not valid -//func DoubleColumn(name string, nullable NullableColumn) Column { -// if !validIdentifierName(name) { -// panic("Invalid column name in int column") -// } -// ic := &doubleColumn{} -// ic.name = name -// ic.nullable = nullable -// return ic -//} -// -//type booleanColumn struct { -// baseColumn -// -// // XXX: Maybe allow isBoolExpression (for now, not included because -// // the deferred lookup equivalent can never be isBoolExpression) -//} - -// Representation of TINYINT used as a bool -// This function will panic if name is not valid -//func NewBoolColumn(name string, nullable NullableColumn) Column { -// if !validIdentifierName(name) { -// panic("Invalid column name in bool column") -// } -// bc := &booleanColumn{} -// bc.name = name -// bc.nullable = nullable -// return bc -//} -// -//type aliasColumn struct { -// baseColumn -// expression Expression -//} -// -//func (c *aliasColumn) SerializeSql(out *bytes.Buffer) error { -// _ = out.WriteByte('`') -// _, _ = out.WriteString(c.name) -// _ = out.WriteByte('`') -// return nil -//} -// -//func (c *aliasColumn) SerializeSqlForColumnList(out *bytes.Buffer) error { -// if !validIdentifierName(c.name) { -// return errors.Newf( -// "Invalid alias name `%s`. Generated sql: %s", -// c.name, -// out.String()) -// } -// if c.expression == nil { -// return errors.Newf( -// "Cannot alias a nil expression. Generated sql: %s", -// out.String()) -// } -// -// _ = out.WriteByte('(') -// if c.expression == nil { -// return errors.Newf("nil alias clause. Generate sql: %s", out.String()) -// } -// if err := c.expression.SerializeSql(out); err != nil { -// return err -// } -// _, _ = out.WriteString(") AS \"") -// _, _ = out.WriteString(c.name) -// _ = out.WriteByte('"') -// return nil -//} - -//func (c *aliasColumn) setTableName(table string) error { -// return errors.Newf( -// "Alias column '%s' should never have setTableName called on it", -// c.name) -//} - -// Representation of aliased clauses (expression AS name) -//func Alias(name string, c Expression) Column { -// ac := &aliasColumn{} -// ac.name = name -// ac.expression = c -// return ac -//} - -// This is a strict subset of the actual allowed identifiers -var validIdentifierRegexp = regexp.MustCompile("^[a-zA-Z_]\\w*$") - -// Returns true if the given string is suitable as an identifier. -func validIdentifierName(name string) bool { - return validIdentifierRegexp.MatchString(name) -} // //// Pseudo Column type returned by tableName.C(name) @@ -289,12 +124,12 @@ func validIdentifierName(name string) bool { //func (c *deferredLookupColumn) SerializeSqlForColumnList( // out *bytes.Buffer) error { // -// return c.SerializeSql(out) +// return c.Serialize(out) //} // -//func (c *deferredLookupColumn) SerializeSql(out *bytes.Buffer) error { +//func (c *deferredLookupColumn) Serialize(out *bytes.Buffer) error { // if c.cachedColumn != nil { -// return c.cachedColumn.SerializeSql(out) +// return c.cachedColumn.Serialize(out) // } // // col, err := c.tableName.getColumn(c.colName) @@ -303,7 +138,7 @@ func validIdentifierName(name string) bool { // } // // c.cachedColumn = col -// return col.SerializeSql(out) +// return col.Serialize(out) //} // //func (c *deferredLookupColumn) setTableName(tableName string) error { diff --git a/sqlbuilder/column_types.go b/sqlbuilder/column_types.go index 5b2fa44..9d43d84 100644 --- a/sqlbuilder/column_types.go +++ b/sqlbuilder/column_types.go @@ -8,9 +8,7 @@ type BoolColumn struct { } func NewBoolColumn(name string, nullable NullableColumn) *BoolColumn { - if !validIdentifierName(name) { - panic("Invalid column name in bool column") - } + boolColumn := &BoolColumn{} boolColumn.baseColumn = newBaseColumn(name, nullable, "", boolColumn) @@ -26,9 +24,6 @@ type NumericColumn struct { } func NewNumericColumn(name string, nullable NullableColumn) *NumericColumn { - if !validIdentifierName(name) { - panic("Invalid column name") - } numericColumn := &NumericColumn{} @@ -70,9 +65,6 @@ type StringColumn struct { // Representation of any integer column // This function will panic if name is not valid func NewStringColumn(name string, nullable NullableColumn) *StringColumn { - if !validIdentifierName(name) { - panic("Invalid column name") - } stringColumn := &StringColumn{} diff --git a/sqlbuilder/column_types_test.go b/sqlbuilder/column_types_test.go index 344e1f6..319a244 100644 --- a/sqlbuilder/column_types_test.go +++ b/sqlbuilder/column_types_test.go @@ -10,20 +10,20 @@ func TestNewBoolColumn(t *testing.T) { boolColumn := NewBoolColumn("col", Nullable) out := bytes.Buffer{} - err := boolColumn.SerializeSql(&out) + err := boolColumn.Serialize(&out) assert.NilError(t, err) assert.Equal(t, out.String(), "col") out.Reset() - err = boolColumn.SerializeSql(&out, FOR_PROJECTION) + err = boolColumn.Serialize(&out, FOR_PROJECTION) assert.NilError(t, err) assert.Equal(t, out.String(), "col") out.Reset() err = boolColumn.setTableName("table1") assert.NilError(t, err) - err = boolColumn.SerializeSql(&out, FOR_PROJECTION) + err = boolColumn.Serialize(&out, FOR_PROJECTION) assert.NilError(t, err) assert.Equal(t, out.String(), `table1.col AS "table1.col"`) @@ -40,20 +40,20 @@ func TestNewIntColumn(t *testing.T) { integerColumn := NewIntegerColumn("col", Nullable) out := bytes.Buffer{} - err := integerColumn.SerializeSql(&out) + err := integerColumn.Serialize(&out) assert.NilError(t, err) assert.Equal(t, out.String(), "col") out.Reset() - err = integerColumn.SerializeSql(&out, FOR_PROJECTION) + err = integerColumn.Serialize(&out, FOR_PROJECTION) assert.NilError(t, err) assert.Equal(t, out.String(), "col") out.Reset() err = integerColumn.setTableName("table1") assert.NilError(t, err) - err = integerColumn.SerializeSql(&out, FOR_PROJECTION) + err = integerColumn.Serialize(&out, FOR_PROJECTION) assert.NilError(t, err) assert.Equal(t, out.String(), `table1.col AS "table1.col"`) @@ -70,20 +70,20 @@ func TestNewNumericColumnColumn(t *testing.T) { numericColumn := NewNumericColumn("col", Nullable) out := bytes.Buffer{} - err := numericColumn.SerializeSql(&out) + err := numericColumn.Serialize(&out) assert.NilError(t, err) assert.Equal(t, out.String(), "col") out.Reset() - err = numericColumn.SerializeSql(&out) + err = numericColumn.Serialize(&out) assert.NilError(t, err) assert.Equal(t, out.String(), "col") out.Reset() err = numericColumn.setTableName("table1") assert.NilError(t, err) - err = numericColumn.SerializeSql(&out, FOR_PROJECTION) + err = numericColumn.Serialize(&out, FOR_PROJECTION) assert.NilError(t, err) assert.Equal(t, out.String(), `table1.col AS "table1.col"`) diff --git a/sqlbuilder/delete_statement.go b/sqlbuilder/delete_statement.go index d978996..8bfccce 100644 --- a/sqlbuilder/delete_statement.go +++ b/sqlbuilder/delete_statement.go @@ -1,7 +1,6 @@ package sqlbuilder import ( - "bytes" "database/sql" "github.com/dropbox/godropbox/errors" "github.com/sub0zero/go-sqlbuilder/types" @@ -38,33 +37,35 @@ func (d *deleteStatementImpl) WHERE(expression BoolExpression) DeleteStatement { return d } -func (d *deleteStatementImpl) String() (sql string, err error) { - buf := new(bytes.Buffer) - _, _ = buf.WriteString("DELETE FROM ") +func (d *deleteStatementImpl) Sql() (query string, args []interface{}, err error) { + queryData := &queryData{} + + queryData.WriteString("DELETE FROM ") if d.table == nil { - return "", errors.Newf("nil tableName. Generated sql: %s", buf.String()) + return "", nil, errors.New("nil tableName.") } - if err = d.table.SerializeSql(buf); err != nil { + if err = d.table.SerializeSql(queryData); err != nil { return } if d.where == nil { - return "", errors.Newf("Deleting without a WHERE clause. Generated sql: %s", buf.String()) + return "", nil, errors.New("Deleting without a WHERE clause.") } - _, _ = buf.WriteString(" WHERE ") - if err = d.where.SerializeSql(buf); err != nil { + queryData.WriteString(" WHERE ") + + if err = d.where.Serialize(queryData); err != nil { return } if d.order != nil { - _, _ = buf.WriteString(" ORDER BY ") - if err = d.order.SerializeSql(buf); err != nil { + queryData.WriteString(" ORDER BY ") + if err = d.order.Serialize(queryData); err != nil { return } } - return buf.String() + ";", nil + return queryData.queryBuff.String() + ";", queryData.args, nil } diff --git a/sqlbuilder/execution/execution.go b/sqlbuilder/execution/execution.go index df09be3..c03fb31 100644 --- a/sqlbuilder/execution/execution.go +++ b/sqlbuilder/execution/execution.go @@ -14,7 +14,7 @@ import ( "time" ) -func Execute(db types.Db, query string, destinationPtr interface{}) error { +func Query(db types.Db, query string, args []interface{}, destinationPtr interface{}) error { if db == nil { return errors.New("db is nil") } @@ -28,7 +28,7 @@ func Execute(db types.Db, query string, destinationPtr interface{}) error { return errors.New("Destination has to be a pointer to slice or pointer to struct ") } - rows, err := db.Query(query) + rows, err := db.Query(query, args...) if err != nil { return err @@ -72,7 +72,7 @@ func Execute(db types.Db, query string, destinationPtr interface{}) error { return err } - fmt.Println(strconv.Itoa(scanContext.rowNum) + " ROWS PROCESSED") + fmt.Println(strconv.Itoa(scanContext.rowNum) + " ROW(S) PROCESSED") return nil } diff --git a/sqlbuilder/expression.go b/sqlbuilder/expression.go index 45b9a2a..cae2e2f 100644 --- a/sqlbuilder/expression.go +++ b/sqlbuilder/expression.go @@ -1,8 +1,6 @@ package sqlbuilder import ( - "bytes" - "github.com/dropbox/godropbox/database/sqltypes" "github.com/dropbox/godropbox/errors" ) @@ -42,8 +40,8 @@ func (e *expressionInterfaceImpl) Desc() OrderByClause { return &orderByClause{expression: e.parent, ascent: false} } -func (e *expressionInterfaceImpl) SerializeForProjection(out *bytes.Buffer) error { - return e.parent.SerializeSql(out, FOR_PROJECTION) +func (e *expressionInterfaceImpl) SerializeForProjection(out *queryData) error { + return e.parent.Serialize(out, FOR_PROJECTION) } // Representation of binary operations (e.g. comparisons, arithmetic) @@ -62,21 +60,21 @@ func newBinaryExpression(lhs, rhs Expression, operator []byte, parent ...Express return binaryExpression } -func (c *binaryExpression) SerializeSql(out *bytes.Buffer, options ...serializeOption) (err error) { +func (c *binaryExpression) Serialize(out *queryData, options ...serializeOption) error { if c.lhs == nil { - return errors.Newf("nil lhs. Generated sql: %s", out.String()) + return errors.Newf("nil lhs.") } - if err = c.lhs.SerializeSql(out); err != nil { - return + if err := c.lhs.Serialize(out); err != nil { + return err } - _, _ = out.Write(c.operator) + out.Write(c.operator) if c.rhs == nil { - return errors.Newf("nil rhs. Generated sql: %s", out.String()) + return errors.Newf("nil rhs.") } - if err = c.rhs.SerializeSql(out); err != nil { - return + if err := c.rhs.Serialize(out); err != nil { + return err } return nil @@ -97,80 +95,61 @@ func newPrefixExpression(expression Expression, operator []byte) prefixExpressio return prefixExpression } -func (p *prefixExpression) SerializeSql(out *bytes.Buffer, options ...serializeOption) (err error) { - _, _ = out.Write(p.operator) +func (p *prefixExpression) Serialize(out *queryData, options ...serializeOption) error { + out.Write(p.operator) if p.expression == nil { - return errors.Newf("nil prefix expression. Generated sql: %s", out.String()) + return errors.Newf("nil prefix expression.") } - if err = p.expression.SerializeSql(out); err != nil { - return + if err := p.expression.Serialize(out); err != nil { + return err } return nil } -// Representation of n-ary conjunctions (AND/OR) -type conjunctExpression struct { - expressions []BoolExpression - conjunction []byte -} - -func (conj *conjunctExpression) SerializeSql(out *bytes.Buffer, options ...serializeOption) (err error) { - if len(conj.expressions) == 0 { - return errors.Newf( - "Empty conjunction. Generated sql: %s", - out.String()) - } - - clauses := make([]Clause, len(conj.expressions), len(conj.expressions)) - for i, expr := range conj.expressions { - clauses[i] = expr - } - - useParentheses := len(clauses) > 1 - if useParentheses { - _ = out.WriteByte('(') - } - - if err = serializeClauses(clauses, conj.conjunction, out); err != nil { - return - } - - if useParentheses { - _ = out.WriteByte(')') - } - - return nil -} +// +//// Representation of n-ary conjunctions (AND/OR) +//type conjunctExpression struct { +// expressions []Expression +// conjunction []byte +//} +// +//func (conj *conjunctExpression) Serialize(out *queryData, options ...serializeOption) error { +// if len(conj.expressions) == 0 { +// return errors.New("Empty conjunction.") +// } +// +// //clauses := make([]Clause, len(conj.expressions), len(conj.expressions)) +// //for i, expr := range conj.expressions { +// // clauses[i] = expr +// //} +// +// useParentheses := len(conj.expressions) > 1 +// if useParentheses { +// out.WriteByte('(') +// } +// +// if err := serializeExpressionList(conj.expressions, string(conj.conjunction), out); err != nil { +// return err +// } +// +// if useParentheses { +// out.WriteByte(')') +// } +// +// return nil +//} //-------------------------------------------------------------- -// Representation of an escaped literal -type literalExpression struct { - expressionInterfaceImpl - value sqltypes.Value -} - -func NewLiteralExpression(value sqltypes.Value) *literalExpression { - exp := literalExpression{value: value} - exp.expressionInterfaceImpl.parent = &exp - - return &exp -} - -func (c literalExpression) SerializeSql(out *bytes.Buffer, options ...serializeOption) error { - sqltypes.Value(c.value).EncodeSql(out) - return nil -} - //------------------------------------------------------// //// Dummy type for select * //type ColumnList []Column // -//func (cl ColumnList) SerializeSql(out *bytes.Buffer, options ...serializeOption) error { +//func (cl ColumnList) Serialize(out *bytes.Buffer, options ...serializeOption) error { // for i, column := range cl { -// err := column.SerializeSql(out) +// err := column.Serialize(out) // // if err != nil { // return err diff --git a/sqlbuilder/expression_old.go b/sqlbuilder/expression_old.go index af82d0c..b5716f2 100644 --- a/sqlbuilder/expression_old.go +++ b/sqlbuilder/expression_old.go @@ -2,124 +2,88 @@ package sqlbuilder import ( - "bytes" "strconv" "strings" "time" - "github.com/dropbox/godropbox/database/sqltypes" "github.com/dropbox/godropbox/errors" ) -type orderByClause struct { - isOrderByClause - expression Expression - ascent bool -} - -func (o *orderByClause) SerializeSql(out *bytes.Buffer, options ...serializeOption) error { - if o.expression == nil { - return errors.Newf( - "nil order by clause. Generated sql: %s", - out.String()) - } - - if err := o.expression.SerializeSql(out); err != nil { - return err - } - - if o.ascent { - _, _ = out.WriteString(" ASC") - } else { - _, _ = out.WriteString(" DESC") - } - - return nil -} - -func Asc(expression Expression) OrderByClause { - return &orderByClause{expression: expression, ascent: true} -} - -func Desc(expression Expression) OrderByClause { - return &orderByClause{expression: expression, ascent: false} -} - -func serializeClauses( - clauses []Clause, - separator []byte, - out *bytes.Buffer) (err error) { - - if clauses == nil || len(clauses) == 0 { - return errors.Newf("Empty clauses. Generated sql: %s", out.String()) - } - - if clauses[0] == nil { - return errors.Newf("nil clause. Generated sql: %s", out.String()) - } - if err = clauses[0].SerializeSql(out); err != nil { - return - } - - for _, c := range clauses[1:] { - _, _ = out.Write(separator) - - if c == nil { - return errors.Newf("nil clause. Generated sql: %s", out.String()) - } - if err = c.SerializeSql(out); err != nil { - return - } - } - - return nil -} - -// Representation of n-ary arithmetic (+ - * /) -type arithmeticExpression struct { - expressionInterfaceImpl - expressions []Expression - operator []byte -} - -func (arith *arithmeticExpression) SerializeSql(out *bytes.Buffer, options ...serializeOption) (err error) { - if len(arith.expressions) == 0 { - return errors.Newf( - "Empty arithmetic expression. Generated sql: %s", - out.String()) - } - - clauses := make([]Clause, len(arith.expressions), len(arith.expressions)) - for i, expr := range arith.expressions { - clauses[i] = expr - } - - useParentheses := len(clauses) > 1 - if useParentheses { - _ = out.WriteByte('(') - } - - if err = serializeClauses(clauses, arith.operator, out); err != nil { - return - } - - if useParentheses { - _ = out.WriteByte(')') - } - - return nil -} +//func serializeClauses( +// clauses []Clause, +// separator []byte, +// out *bytes.Buffer) (err error) { +// +// if clauses == nil || len(clauses) == 0 { +// return errors.Newf("Empty clauses.") +// } +// +// if clauses[0] == nil { +// return errors.Newf("nil clause.") +// } +// if err = clauses[0].Serialize(out); err != nil { +// return +// } +// +// for _, c := range clauses[1:] { +// _, _ = out.Write(separator) +// +// if c == nil { +// return errors.Newf("nil clause.") +// } +// if err = c.Serialize(out); err != nil { +// return +// } +// } +// +// return nil +//} +// +//// Representation of n-ary arithmetic (+ - * /) +//type arithmeticExpression struct { +// expressionInterfaceImpl +// expressions []Expression +// operator []byte +//} +// +//func (arith *arithmeticExpression) Serialize(out *queryData, options ...serializeOption) error { +// if len(arith.expressions) == 0 { +// return errors.Newf( +// "Empty arithmetic expression.") +// } +// +// clauses := make([]Clause, len(arith.expressions), len(arith.expressions)) +// for i, expr := range arith.expressions { +// clauses[i] = expr +// } +// +// useParentheses := len(clauses) > 1 +// if useParentheses { +// _ = out.WriteByte('(') +// } +// +// if err = serializeClauses(clauses, arith.operator, out); err != nil { +// return +// } +// +// if useParentheses { +// _ = out.WriteByte(')') +// } +// +// return nil +//} +// type tupleExpression struct { expressionInterfaceImpl elements listClause } -func (tuple *tupleExpression) SerializeSql(out *bytes.Buffer, options ...serializeOption) error { - if len(tuple.elements.clauses) < 1 { +func (tuple *tupleExpression) Serialize(out *queryData, options ...serializeOption) error { + if len(tuple.elements.clauses) == 0 { return errors.Newf("Tuples must include at least one element") } - return tuple.elements.SerializeSql(out) + return tuple.elements.Serialize(out) } func Tuple(exprs ...Expression) Expression { @@ -141,61 +105,62 @@ type listClause struct { includeParentheses bool } -func (list *listClause) SerializeSql(out *bytes.Buffer, options ...serializeOption) error { +func (list *listClause) Serialize(out *queryData, options ...serializeOption) error { if list.includeParentheses { - _ = out.WriteByte('(') + out.WriteByte('(') } - if err := serializeClauses(list.clauses, []byte(","), out); err != nil { + if err := serializeClauseList(list.clauses, out); err != nil { return err } if list.includeParentheses { - _ = out.WriteByte(')') + out.WriteByte(')') } return nil } -type funcExpression struct { - expressionInterfaceImpl - funcName string - args *listClause -} - -func (c *funcExpression) SerializeSql(out *bytes.Buffer, options ...serializeOption) (err error) { - if !validIdentifierName(c.funcName) { - return errors.Newf( - "Invalid function name: %s. Generated sql: %s", - c.funcName, - out.String()) - } - _, _ = out.WriteString(c.funcName) - if c.args == nil { - _, _ = out.WriteString("()") - } else { - return c.args.SerializeSql(out) - } - return nil -} - -// Returns a representation of sql function call "func_call(c[0], ..., c[n-1]) -func SqlFunc(funcName string, expressions ...Expression) Expression { - f := &funcExpression{ - funcName: funcName, - } - if len(expressions) > 0 { - args := make([]Clause, len(expressions), len(expressions)) - for i, expr := range expressions { - args[i] = expr - } - - f.args = &listClause{ - clauses: args, - includeParentheses: true, - } - } - return f -} +// +//type funcExpression struct { +// expressionInterfaceImpl +// funcName string +// args *listClause +//} +// +//func (c *funcExpression) Serialize(out *queryData, options ...serializeOption) error { +// if !validIdentifierName(c.funcName) { +// return errors.Newf( +// "Invalid function name: %s.", +// c.funcName, +// out.String()) +// } +// _, _ = out.WriteString(c.funcName) +// if c.args == nil { +// _, _ = out.WriteString("()") +// } else { +// return c.args.Serialize(out) +// } +// return nil +//} +// +//// Returns a representation of sql function call "func_call(c[0], ..., c[n-1]) +//func SqlFunc(funcName string, expressions ...Expression) Expression { +// f := &funcExpression{ +// funcName: funcName, +// } +// if len(expressions) > 0 { +// args := make([]Clause, len(expressions), len(expressions)) +// for i, expr := range expressions { +// args[i] = expr +// } +// +// f.args = &listClause{ +// clauses: args, +// includeParentheses: true, +// } +// } +// return f +//} type intervalExpression struct { expressionInterfaceImpl @@ -205,23 +170,24 @@ type intervalExpression struct { var intervalSep = ":" -func (c *intervalExpression) SerializeSql(out *bytes.Buffer, options ...serializeOption) (err error) { +func (c *intervalExpression) Serialize(out *queryData, options ...serializeOption) error { hours := c.duration / time.Hour minutes := (c.duration % time.Hour) / time.Minute sec := (c.duration % time.Minute) / time.Second msec := (c.duration % time.Second) / time.Microsecond - _, _ = out.WriteString("INTERVAL '") + out.WriteString("INTERVAL '") if c.negative { - _, _ = out.WriteString("-") + out.WriteString("-") } - _, _ = out.WriteString(strconv.FormatInt(int64(hours), 10)) - _, _ = out.WriteString(intervalSep) - _, _ = out.WriteString(strconv.FormatInt(int64(minutes), 10)) - _, _ = out.WriteString(intervalSep) - _, _ = out.WriteString(strconv.FormatInt(int64(sec), 10)) - _, _ = out.WriteString(intervalSep) - _, _ = out.WriteString(strconv.FormatInt(int64(msec), 10)) - _, _ = out.WriteString("' HOUR_MICROSECOND") + out.WriteString(strconv.FormatInt(int64(hours), 10)) + out.WriteString(intervalSep) + out.WriteString(strconv.FormatInt(int64(minutes), 10)) + out.WriteString(intervalSep) + out.WriteString(strconv.FormatInt(int64(sec), 10)) + out.WriteString(intervalSep) + out.WriteString(strconv.FormatInt(int64(msec), 10)) + out.WriteString("' HOUR_MICROSECOND") + return nil } @@ -246,45 +212,45 @@ func EscapeForLike(s string) string { } // Returns an escaped literal string -func Literal(v interface{}) Expression { - value, err := sqltypes.BuildValue(v) - if err != nil { - panic(errors.Wrap(err, "Invalid literal value")) - } - return NewLiteralExpression(value) -} - -// Returns a representation of "c[0] + ... + c[n-1]" for c in clauses -func Add(expressions ...Expression) Expression { - return &arithmeticExpression{ - expressions: expressions, - operator: []byte(" + "), - } -} - -// Returns a representation of "c[0] - ... - c[n-1]" for c in clauses -func Sub(expressions ...Expression) Expression { - return &arithmeticExpression{ - expressions: expressions, - operator: []byte(" - "), - } -} - -// Returns a representation of "c[0] * ... * c[n-1]" for c in clauses -func Mul(expressions ...Expression) Expression { - return &arithmeticExpression{ - expressions: expressions, - operator: []byte(" * "), - } -} - -// Returns a representation of "c[0] / ... / c[n-1]" for c in clauses -func Div(expressions ...Expression) Expression { - return &arithmeticExpression{ - expressions: expressions, - operator: []byte(" / "), - } -} +//func Literal(v interface{}) Expression { +// value, err := sqltypes.BuildValue(v) +// if err != nil { +// panic(errors.Wrap(err, "Invalid literal value")) +// } +// return NewLiteralExpression(value) +//} +// +//// Returns a representation of "c[0] + ... + c[n-1]" for c in clauses +//func Add(expressions ...Expression) Expression { +// return &arithmeticExpression{ +// expressions: expressions, +// operator: []byte(" + "), +// } +//} +// +//// Returns a representation of "c[0] - ... - c[n-1]" for c in clauses +//func Sub(expressions ...Expression) Expression { +// return &arithmeticExpression{ +// expressions: expressions, +// operator: []byte(" - "), +// } +//} +// +//// Returns a representation of "c[0] * ... * c[n-1]" for c in clauses +//func Mul(expressions ...Expression) Expression { +// return &arithmeticExpression{ +// expressions: expressions, +// operator: []byte(" * "), +// } +//} +// +//// Returns a representation of "c[0] / ... / c[n-1]" for c in clauses +//func Div(expressions ...Expression) Expression { +// return &arithmeticExpression{ +// expressions: expressions, +// operator: []byte(" / "), +// } +//} //TODO: Uncomment // @@ -336,14 +302,15 @@ type ifExpression struct { falseExpression Expression } -func (exp *ifExpression) SerializeSql(out *bytes.Buffer, options ...serializeOption) error { - _, _ = out.WriteString("IF(") - _ = exp.conditional.SerializeSql(out) - _, _ = out.WriteString(",") - _ = exp.trueExpression.SerializeSql(out) - _, _ = out.WriteString(",") - _ = exp.falseExpression.SerializeSql(out) - _, _ = out.WriteString(")") +func (exp *ifExpression) Serialize(out *queryData, options ...serializeOption) error { + out.WriteString("IF(") + _ = exp.conditional.Serialize(out) + out.WriteString(",") + _ = exp.trueExpression.Serialize(out) + out.WriteString(",") + _ = exp.falseExpression.Serialize(out) + out.WriteString(")") + return nil } @@ -371,7 +338,7 @@ func If(conditional BoolExpression, // } //} // -//func (cv *columnValueExpression) SerializeSql(out *bytes.Buffer) error { +//func (cv *columnValueExpression) Serialize(out *bytes.Buffer) error { // _, _ = out.WriteString("VALUES(") // _ = cv.column.SerializeSqlForColumnList(out) // _ = out.WriteByte(')') diff --git a/sqlbuilder/expression_old_test.go b/sqlbuilder/expression_old_test.go index 5825103..fb3de48 100644 --- a/sqlbuilder/expression_old_test.go +++ b/sqlbuilder/expression_old_test.go @@ -19,7 +19,7 @@ func (s *ExprSuite) TestConjunctExprEmptyList(c *gc.C) { buf := &bytes.Buffer{} - err := expr.SerializeSql(buf) + err := expr.Serialize(buf) c.Assert(err, gc.NotNil) } @@ -28,7 +28,7 @@ func (s *ExprSuite) TestConjunctExprNilInList(c *gc.C) { buf := &bytes.Buffer{} - err := expr.SerializeSql(buf) + err := expr.Serialize(buf) c.Assert(err, gc.NotNil) } @@ -37,7 +37,7 @@ func (s *ExprSuite) TestConjunctExprSingleElement(c *gc.C) { buf := &bytes.Buffer{} - err := expr.SerializeSql(buf) + err := expr.Serialize(buf) c.Assert(err, gc.IsNil) sql := buf.String() @@ -48,11 +48,11 @@ func (s *ExprSuite) TestTupleExpr(c *gc.C) { expr := Tuple() buf := &bytes.Buffer{} - err := expr.SerializeSql(buf) + err := expr.Serialize(buf) c.Assert(err, gc.NotNil) expr = Tuple(table1Col1, Literal(1), Literal("five")) - err = expr.SerializeSql(buf) + err = expr.Serialize(buf) c.Assert(err, gc.IsNil) sql := buf.String() @@ -68,7 +68,7 @@ func (s *ExprSuite) TestLikeExpr(c *gc.C) { buf := &bytes.Buffer{} - err := expr.SerializeSql(buf) + err := expr.Serialize(buf) c.Assert(err, gc.IsNil) sql := buf.String() @@ -84,7 +84,7 @@ func (s *ExprSuite) TestRegexExpr(c *gc.C) { buf := &bytes.Buffer{} - err := expr.SerializeSql(buf) + err := expr.Serialize(buf) c.Assert(err, gc.IsNil) sql := buf.String() @@ -100,7 +100,7 @@ func (s *ExprSuite) TestAndExpr(c *gc.C) { buf := &bytes.Buffer{} - err := expr.SerializeSql(buf) + err := expr.Serialize(buf) c.Assert(err, gc.IsNil) sql := buf.String() @@ -115,7 +115,7 @@ func (s *ExprSuite) TestOrExpr(c *gc.C) { buf := &bytes.Buffer{} - err := expr.SerializeSql(buf) + err := expr.Serialize(buf) c.Assert(err, gc.IsNil) sql := buf.String() @@ -130,7 +130,7 @@ func (s *ExprSuite) TestAddExpr(c *gc.C) { buf := &bytes.Buffer{} - err := expr.SerializeSql(buf) + err := expr.Serialize(buf) c.Assert(err, gc.IsNil) sql := buf.String() @@ -142,7 +142,7 @@ func (s *ExprSuite) TestSubExpr(c *gc.C) { buf := &bytes.Buffer{} - err := expr.SerializeSql(buf) + err := expr.Serialize(buf) c.Assert(err, gc.IsNil) sql := buf.String() @@ -154,7 +154,7 @@ func (s *ExprSuite) TestMulExpr(c *gc.C) { buf := &bytes.Buffer{} - err := expr.SerializeSql(buf) + err := expr.Serialize(buf) c.Assert(err, gc.IsNil) sql := buf.String() @@ -166,7 +166,7 @@ func (s *ExprSuite) TestDivExpr(c *gc.C) { buf := &bytes.Buffer{} - err := expr.SerializeSql(buf) + err := expr.Serialize(buf) c.Assert(err, gc.IsNil) sql := buf.String() @@ -178,7 +178,7 @@ func (s *ExprSuite) TestBinaryExprNilLHS(c *gc.C) { buf := &bytes.Buffer{} - err := expr.SerializeSql(buf) + err := expr.Serialize(buf) c.Assert(err, gc.NotNil) } @@ -187,7 +187,7 @@ func (s *ExprSuite) TestNegateExpr(c *gc.C) { buf := &bytes.Buffer{} - err := expr.SerializeSql(buf) + err := expr.Serialize(buf) c.Assert(err, gc.IsNil) sql := buf.String() @@ -199,7 +199,7 @@ func (s *ExprSuite) TestBinaryExprNilRHS(c *gc.C) { buf := &bytes.Buffer{} - err := expr.SerializeSql(buf) + err := expr.Serialize(buf) c.Assert(err, gc.NotNil) } @@ -208,7 +208,7 @@ func (s *ExprSuite) TestEqExpr(c *gc.C) { buf := &bytes.Buffer{} - err := expr.SerializeSql(buf) + err := expr.Serialize(buf) c.Assert(err, gc.IsNil) sql := buf.String() @@ -220,7 +220,7 @@ func (s *ExprSuite) TestEqExprNilLHS(c *gc.C) { buf := &bytes.Buffer{} - err := expr.SerializeSql(buf) + err := expr.Serialize(buf) c.Assert(err, gc.IsNil) sql := buf.String() @@ -232,7 +232,7 @@ func (s *ExprSuite) TestNeqExpr(c *gc.C) { buf := &bytes.Buffer{} - err := expr.SerializeSql(buf) + err := expr.Serialize(buf) c.Assert(err, gc.IsNil) sql := buf.String() @@ -244,7 +244,7 @@ func (s *ExprSuite) TestNeqExprNilLHS(c *gc.C) { buf := &bytes.Buffer{} - err := expr.SerializeSql(buf) + err := expr.Serialize(buf) c.Assert(err, gc.IsNil) sql := buf.String() @@ -256,7 +256,7 @@ func (s *ExprSuite) TestLtExpr(c *gc.C) { buf := &bytes.Buffer{} - err := expr.SerializeSql(buf) + err := expr.Serialize(buf) c.Assert(err, gc.IsNil) sql := buf.String() @@ -268,7 +268,7 @@ func (s *ExprSuite) TestLteExpr(c *gc.C) { buf := &bytes.Buffer{} - err := expr.SerializeSql(buf) + err := expr.Serialize(buf) c.Assert(err, gc.IsNil) sql := buf.String() @@ -283,7 +283,7 @@ func (s *ExprSuite) TestGtExpr(c *gc.C) { buf := &bytes.Buffer{} - err := expr.SerializeSql(buf) + err := expr.Serialize(buf) c.Assert(err, gc.IsNil) sql := buf.String() @@ -295,7 +295,7 @@ func (s *ExprSuite) TestGteExpr(c *gc.C) { buf := &bytes.Buffer{} - err := expr.SerializeSql(buf) + err := expr.Serialize(buf) c.Assert(err, gc.IsNil) sql := buf.String() @@ -308,7 +308,7 @@ func (s *ExprSuite) TestInExpr(c *gc.C) { buf := &bytes.Buffer{} - err := expr.SerializeSql(buf) + err := expr.Serialize(buf) c.Assert(err, gc.IsNil) sql := buf.String() @@ -321,7 +321,7 @@ func (s *ExprSuite) TestInExprEmptyList(c *gc.C) { buf := &bytes.Buffer{} - err := expr.SerializeSql(buf) + err := expr.Serialize(buf) c.Assert(err, gc.IsNil) sql := buf.String() @@ -333,7 +333,7 @@ func (s *ExprSuite) TestSqlFuncExprNilInArgList(c *gc.C) { buf := &bytes.Buffer{} - err := expr.SerializeSql(buf) + err := expr.Serialize(buf) c.Assert(err, gc.NotNil) } @@ -342,7 +342,7 @@ func (s *ExprSuite) TestSqlFuncExprEmptyArgList(c *gc.C) { buf := &bytes.Buffer{} - err := expr.SerializeSql(buf) + err := expr.Serialize(buf) c.Assert(err, gc.IsNil) sql := buf.String() @@ -354,7 +354,7 @@ func (s *ExprSuite) TestSqlFuncExprNonEmptyArgList(c *gc.C) { buf := &bytes.Buffer{} - err := expr.SerializeSql(buf) + err := expr.Serialize(buf) c.Assert(err, gc.IsNil) sql := buf.String() @@ -366,7 +366,7 @@ func (s *ExprSuite) TestOrderByClauseNilExpr(c *gc.C) { buf := &bytes.Buffer{} - err := clause.SerializeSql(buf) + err := clause.Serialize(buf) c.Assert(err, gc.NotNil) } @@ -375,7 +375,7 @@ func (s *ExprSuite) TestAsc(c *gc.C) { buf := &bytes.Buffer{} - err := clause.SerializeSql(buf) + err := clause.Serialize(buf) c.Assert(err, gc.IsNil) sql := buf.String() @@ -387,7 +387,7 @@ func (s *ExprSuite) TestDesc(c *gc.C) { buf := &bytes.Buffer{} - err := clause.SerializeSql(buf) + err := clause.Serialize(buf) c.Assert(err, gc.IsNil) sql := buf.String() @@ -400,7 +400,7 @@ func (s *ExprSuite) TestIf(c *gc.C) { buf := &bytes.Buffer{} - err := clause.SerializeSql(buf) + err := clause.Serialize(buf) c.Assert(err, gc.IsNil) sql := buf.String() @@ -538,7 +538,7 @@ func (s *ExprSuite) TestInterval(c *gc.C) { for i, tt := range testTable { buf.Reset() - err := Interval(tt.interval).SerializeSql(buf) + err := Interval(tt.interval).Serialize(buf) c.Assert(err, gc.Equals, tt.expectedErr, gc.Commentf("experiment #%d", i)) if err == nil { diff --git a/sqlbuilder/func.go b/sqlbuilder/func_expression.go similarity index 83% rename from sqlbuilder/func.go rename to sqlbuilder/func_expression.go index ae9babf..38b9c92 100644 --- a/sqlbuilder/func.go +++ b/sqlbuilder/func_expression.go @@ -1,7 +1,5 @@ package sqlbuilder -import "bytes" - type FuncExpression interface { Expression } @@ -26,10 +24,10 @@ func NewNumericFunc(name string, expression Expression) NumericExpression { return numericFunc } -func (f *numericFunc) SerializeSql(out *bytes.Buffer, options ...serializeOption) error { +func (f *numericFunc) Serialize(out *queryData, options ...serializeOption) error { out.WriteString(f.name) out.WriteString("(") - err := f.expression.SerializeSql(out) + err := f.expression.Serialize(out) if err != nil { return err } @@ -39,7 +37,7 @@ func (f *numericFunc) SerializeSql(out *bytes.Buffer, options ...serializeOption } //func (f *FuncExpression) SerializeSqlForColumnList(out *bytes.Buffer) error { -// return f.SerializeSql(out) +// return f.Serialize(out) //} func MAX(expression NumericExpression) NumericExpression { diff --git a/sqlbuilder/insert_statement.go b/sqlbuilder/insert_statement.go index a02a6fe..6607b3b 100644 --- a/sqlbuilder/insert_statement.go +++ b/sqlbuilder/insert_statement.go @@ -1,7 +1,6 @@ package sqlbuilder import ( - "bytes" "database/sql" "github.com/dropbox/godropbox/errors" "github.com/serenize/snaker" @@ -53,6 +52,7 @@ func (u *insertStatementImpl) Execute(db types.Db) (res sql.Result, err error) { return Execute(u, db) } +// expression or default keyword func (s *insertStatementImpl) VALUES(values ...interface{}) InsertStatement { literalRow := []Clause{} @@ -122,84 +122,92 @@ func (i *insertStatementImpl) addError(err string) { i.errors = append(i.errors, err) } -func (s *insertStatementImpl) String() (sql string, err error) { - buf := new(bytes.Buffer) - _, _ = buf.WriteString("INSERT ") - _, _ = buf.WriteString("INTO ") - +func (s *insertStatementImpl) Sql() (sql string, args []interface{}, err error) { if len(s.errors) > 0 { - return "", errors.New("sql builder errors: " + strings.Join(s.errors, ", ")) + return "", nil, errors.New("sql builder errors: " + strings.Join(s.errors, ", ")) } + queryData := &queryData{} + queryData.WriteString("INSERT INTO ") + if s.table == nil { - return "", errors.Newf("nil tableName. Generated sql: %s", buf.String()) + return "", nil, errors.Newf("nil tableName.") } - buf.WriteString(s.table.SchemaName() + "." + s.table.TableName()) + err = s.table.SerializeSql(queryData) + + if err != nil { + return "", nil, err + } if len(s.columns) > 0 { - _, _ = buf.WriteString(" (") - for i, col := range s.columns { - if i > 0 { - _ = buf.WriteByte(',') - } + queryData.WriteString(" (") - if col == nil { - return "", errors.Newf( - "nil column in columns list. Generated sql: %s", - buf.String()) - } + //for i, col := range s.columns { + // if i > 0 { + // queryData.WriteByte(',') + // } + // + // if col == nil { + // return "", nil, errors.New("nil column in columns list.") + // } + // + // queryData.WriteString(col.Name()) + //} - buf.WriteString(col.Name()) + err = serializeColumnList(s.columns, queryData) + + if err != nil { + return "", nil, err } - buf.WriteString(") ") + queryData.WriteString(") ") } if len(s.rows) == 0 && s.query == nil { - return "", errors.Newf("No row or query specified. Generated sql: %s", buf.String()) + return "", nil, errors.New("No row or query specified.") } if len(s.rows) > 0 && s.query != nil { - return "", errors.Newf("Only new rows or query has to be specified. Generated sql: %s", buf.String()) + return "", nil, errors.New("Only new rows or query has to be specified.") } if len(s.rows) > 0 { - _, _ = buf.WriteString("VALUES (") + queryData.WriteString("VALUES (") for row_i, row := range s.rows { if row_i > 0 { - _, _ = buf.WriteString(", (") + queryData.WriteString(", (") } if len(row) != len(s.columns) { - return "", errors.Newf( - "# of values does not match # of columns. Generated sql: %s", - buf.String()) + return "", nil, errors.New("# of values does not match # of columns.") } - for col_i, value := range row { - if col_i > 0 { - _ = buf.WriteByte(',') - } + err = serializeClauseList(row, queryData) - if value == nil { - return "", errors.Newf( - "nil value in row %d col %d. Generated sql: %s", - row_i, - col_i, - buf.String()) - } - - if err = value.SerializeSql(buf); err != nil { - return - } + if err != nil { + return "", nil, err } - _ = buf.WriteByte(')') + + //for col_i, value := range row { + // if col_i > 0 { + // queryData.WriteByte(',') + // } + // + // if value == nil { + // return "", nil, errors.Newf("nil value in row %d col %d.", row_i, col_i) + // } + // + // if err = value.Serialize(queryData); err != nil { + // return + // } + //} + queryData.WriteByte(')') } } if s.query != nil { - err = s.query.SerializeSql(buf) + err = s.query.Serialize(queryData) if err != nil { return @@ -207,16 +215,16 @@ func (s *insertStatementImpl) String() (sql string, err error) { } if len(s.returning) > 0 { - buf.WriteString(" RETURNING ") + queryData.WriteString(" RETURNING ") - err = serializeProjectionList(s.returning, buf) + err = serializeProjectionList(s.returning, queryData) if err != nil { return } } - buf.WriteByte(';') + queryData.WriteByte(';') - return buf.String(), nil + return queryData.queryBuff.String(), queryData.args, nil } diff --git a/sqlbuilder/keyword.go b/sqlbuilder/keyword.go index a6eaf8c..82359c3 100644 --- a/sqlbuilder/keyword.go +++ b/sqlbuilder/keyword.go @@ -1,14 +1,12 @@ package sqlbuilder -import "bytes" - const ( DEFAULT keywordClause = "DEFAULT" ) type keywordClause string -func (k keywordClause) SerializeSql(out *bytes.Buffer, options ...serializeOption) error { +func (k keywordClause) Serialize(out *queryData, options ...serializeOption) error { out.WriteString(string(k)) return nil diff --git a/sqlbuilder/literal_expression.go b/sqlbuilder/literal_expression.go new file mode 100644 index 0000000..6a408db --- /dev/null +++ b/sqlbuilder/literal_expression.go @@ -0,0 +1,22 @@ +package sqlbuilder + +// Representation of an escaped literal +type literalExpression struct { + expressionInterfaceImpl + value interface{} +} + +func Literal(value interface{}) *literalExpression { + exp := literalExpression{value: value} + exp.expressionInterfaceImpl.parent = &exp + + return &exp +} + +func (l literalExpression) Serialize(out *queryData, options ...serializeOption) error { + //sqltypes.Value(c.value).EncodeSql(out) + + out.InsertArgument(l.value) + + return nil +} diff --git a/sqlbuilder/numeric_expression.go b/sqlbuilder/numeric_expression.go index 5cf5eca..e169a7f 100644 --- a/sqlbuilder/numeric_expression.go +++ b/sqlbuilder/numeric_expression.go @@ -1,11 +1,5 @@ package sqlbuilder -import ( - "bytes" - "github.com/dropbox/godropbox/database/sqltypes" - "github.com/pkg/errors" -) - type NumericExpression interface { Expression @@ -13,8 +7,11 @@ type NumericExpression interface { EqL(literal interface{}) BoolExpression NotEq(expression NumericExpression) BoolExpression NotEqL(literal interface{}) BoolExpression + + Gt(rhs NumericExpression) BoolExpression GtEq(rhs NumericExpression) BoolExpression GtEqL(literal interface{}) BoolExpression + LtEq(rhs NumericExpression) BoolExpression LtEqL(literal interface{}) BoolExpression @@ -44,6 +41,10 @@ func (n *numericInterfaceImpl) NotEqL(literal interface{}) BoolExpression { return NotEq(n.parent, Literal(literal)) } +func (n *numericInterfaceImpl) Gt(expression NumericExpression) BoolExpression { + return Gt(n.parent, expression) +} + func (n *numericInterfaceImpl) GtEq(expression NumericExpression) BoolExpression { return GtEq(n.parent, expression) } @@ -84,12 +85,8 @@ type numericLiteral struct { func NewNumericLiteral(value interface{}) NumericExpression { numericLiteral := numericLiteral{} + numericLiteral.literalExpression = *Literal(value) - sqlValue, err := sqltypes.BuildValue(value) - if err != nil { - panic(errors.Wrap(err, "Invalid literal value")) - } - numericLiteral.literalExpression = *NewLiteralExpression(sqlValue) numericLiteral.numericInterfaceImpl.parent = &numericLiteral return &numericLiteral @@ -133,10 +130,10 @@ func newNumericExpressionWrap(expression Expression) NumericExpression { return &numericExpressionWrap } -func (c *numericExpressionWrapper) SerializeSql(out *bytes.Buffer, options ...serializeOption) (err error) { +func (c *numericExpressionWrapper) Serialize(out *queryData, options ...serializeOption) error { out.WriteString("(") - err = c.expression.SerializeSql(out, options...) + err := c.expression.Serialize(out, options...) out.WriteString(")") - return nil + return err } diff --git a/sqlbuilder/order_by_clause.go b/sqlbuilder/order_by_clause.go new file mode 100644 index 0000000..5a28cf1 --- /dev/null +++ b/sqlbuilder/order_by_clause.go @@ -0,0 +1,46 @@ +package sqlbuilder + +import "github.com/dropbox/godropbox/errors" + +type OrderByClause interface { + Clause + isOrderByClauseType() +} + +type isOrderByClause struct { +} + +func (o *isOrderByClause) isOrderByClauseType() { +} + +type orderByClause struct { + isOrderByClause + expression Expression + ascent bool +} + +func (o *orderByClause) Serialize(out *queryData, options ...serializeOption) error { + if o.expression == nil { + return errors.Newf("nil orderBy by clause.") + } + + if err := o.expression.Serialize(out); err != nil { + return err + } + + if o.ascent { + out.WriteString(" ASC") + } else { + out.WriteString(" DESC") + } + + return nil +} + +func Asc(expression Expression) OrderByClause { + return &orderByClause{expression: expression, ascent: true} +} + +func Desc(expression Expression) OrderByClause { + return &orderByClause{expression: expression, ascent: false} +} diff --git a/sqlbuilder/projection.go b/sqlbuilder/projection.go index 41e80db..17124c6 100644 --- a/sqlbuilder/projection.go +++ b/sqlbuilder/projection.go @@ -1,18 +1,16 @@ package sqlbuilder -import "bytes" - type Projection interface { - SerializeForProjection(out *bytes.Buffer) error + SerializeForProjection(out *queryData) error } //------------------------------------------------------// // Dummy type for select * AllColumns type ColumnList []Column -func (cl ColumnList) SerializeForProjection(out *bytes.Buffer) error { +func (cl ColumnList) SerializeForProjection(out *queryData) error { for i, column := range cl { - err := column.SerializeSql(out, FOR_PROJECTION) + err := column.Serialize(out, FOR_PROJECTION) if err != nil { return err diff --git a/sqlbuilder/select_statement.go b/sqlbuilder/select_statement.go index 33fa57a..9fc96a8 100644 --- a/sqlbuilder/select_statement.go +++ b/sqlbuilder/select_statement.go @@ -1,9 +1,7 @@ package sqlbuilder import ( - "bytes" "database/sql" - "fmt" "github.com/dropbox/godropbox/errors" "github.com/sub0zero/go-sqlbuilder/types" ) @@ -12,17 +10,17 @@ type SelectStatement interface { Statement Expression - Where(expression BoolExpression) SelectStatement - GroupBy(expressions ...Expression) SelectStatement - HAVING(expressions BoolExpression) SelectStatement + DISTINCT() SelectStatement + WHERE(expression BoolExpression) SelectStatement + GROUP_BY(expressions ...Clause) SelectStatement + HAVING(boolExpression BoolExpression) SelectStatement + ORDER_BY(clauses ...OrderByClause) SelectStatement + + LIMIT(limit int64) SelectStatement + OFFSET(offset int64) SelectStatement + + FOR_UPDATE() SelectStatement - OrderBy(clauses ...OrderByClause) SelectStatement - Limit(limit int64) SelectStatement - Offset(offset int64) SelectStatement - Distinct() SelectStatement - WithSharedLock() SelectStatement - ForUpdate() SelectStatement - Comment(comment string) SelectStatement Copy() SelectStatement AsTable(alias string) *SelectStatementTable @@ -33,17 +31,17 @@ type SelectStatement interface { type selectStatementImpl struct { expressionInterfaceImpl - table ReadableTable - projections []Projection - where BoolExpression - group *listClause - having BoolExpression - order *listClause - comment string - limit, offset int64 - withSharedLock bool - forUpdate bool - distinct bool + table ReadableTable + distinct bool + projections []Projection + where BoolExpression + groupBy []Clause + having BoolExpression + orderBy []OrderByClause + + limit, offset int64 + + forUpdate bool } func newSelectStatement( @@ -51,30 +49,119 @@ func newSelectStatement( projections []Projection) SelectStatement { return &selectStatementImpl{ - table: table, - projections: projections, - limit: -1, - offset: -1, - withSharedLock: false, - forUpdate: false, - distinct: false, + table: table, + projections: projections, + limit: -1, + offset: -1, + forUpdate: false, + distinct: false, } } -func (s *selectStatementImpl) SerializeSql(out *bytes.Buffer, options ...serializeOption) error { - str, err := s.String() +func (s *selectStatementImpl) Serialize(out *queryData, options ...serializeOption) error { + + out.WriteString("(") + + err := s.serializeImpl(out, options...) if err != nil { return err } - out.WriteString("(") - out.WriteString(str) out.WriteString(")") return nil } +func (s *selectStatementImpl) serializeImpl(out *queryData, options ...serializeOption) error { + + out.WriteString("SELECT ") + + if s.distinct { + out.WriteString("DISTINCT ") + } + + if s.projections == nil || len(s.projections) == 0 { + return errors.New("No column selected for projection.") + } + + err := serializeProjectionList(s.projections, out) + + if err != nil { + return err + } + + out.WriteString(" FROM ") + + if s.table == nil { + return errors.Newf("nil tableName.") + } + + if err := s.table.SerializeSql(out); err != nil { + return err + } + + if s.where != nil { + out.WriteString(" WHERE ") + if err := s.where.Serialize(out); err != nil { + return err + } + } + + if s.groupBy != nil && len(s.groupBy) > 0 { + out.WriteString(" GROUP BY ") + + err := serializeClauseList(s.groupBy, out) + + if err != nil { + return err + } + } + + if s.having != nil { + out.WriteString(" HAVING ") + if err = s.having.Serialize(out); err != nil { + return err + } + } + + if s.orderBy != nil { + out.WriteString(" ORDER BY ") + if err := serializeOrderByClauseList(s.orderBy, out); err != nil { + return err + } + } + + if s.limit >= 0 { + out.WriteString(" LIMIT ") + out.InsertArgument(s.limit) + } + + if s.offset >= 0 { + out.WriteString(" OFFSET ") + out.InsertArgument(s.offset) + } + + if s.forUpdate { + out.WriteString(" FOR UPDATE") + } + + return nil +} + +// Return the properly escaped SQL statement, against the specified database +func (q *selectStatementImpl) Sql() (query string, args []interface{}, err error) { + queryData := queryData{} + + err = q.serializeImpl(&queryData) + + if err != nil { + return "", nil, err + } + + return queryData.queryBuff.String(), queryData.args, nil +} + func (s *selectStatementImpl) AsTable(alias string) *SelectStatementTable { return &SelectStatementTable{ statement: s, @@ -95,23 +182,14 @@ func (s *selectStatementImpl) Copy() SelectStatement { return &ret } -func (q *selectStatementImpl) Where(expression BoolExpression) SelectStatement { +func (q *selectStatementImpl) WHERE(expression BoolExpression) SelectStatement { q.where = expression return q } -func (q *selectStatementImpl) GroupBy( - expressions ...Expression) SelectStatement { - - q.group = &listClause{ - clauses: make([]Clause, len(expressions), len(expressions)), - includeParentheses: false, - } - - for i, e := range expressions { - q.group.clauses[i] = e - } - return q +func (s *selectStatementImpl) GROUP_BY(cluases ...Clause) SelectStatement { + s.groupBy = cluases + return s } func (q *selectStatementImpl) HAVING(expression BoolExpression) SelectStatement { @@ -119,132 +197,31 @@ func (q *selectStatementImpl) HAVING(expression BoolExpression) SelectStatement return q } -func (q *selectStatementImpl) OrderBy( - clauses ...OrderByClause) SelectStatement { +func (q *selectStatementImpl) ORDER_BY(clauses ...OrderByClause) SelectStatement { + + q.orderBy = clauses - q.order = newOrderByListClause(clauses...) return q } -func (q *selectStatementImpl) Limit(limit int64) SelectStatement { - q.limit = limit - return q -} - -func (q *selectStatementImpl) Distinct() SelectStatement { - q.distinct = true - return q -} - -func (q *selectStatementImpl) WithSharedLock() SelectStatement { - // We don't need to grab a read lock if we're going to grab a write one - if !q.forUpdate { - q.withSharedLock = true - } - return q -} - -func (q *selectStatementImpl) ForUpdate() SelectStatement { - // Clear a request for a shared lock if we're asking for a write one - q.withSharedLock = false - q.forUpdate = true - return q -} - -func (q *selectStatementImpl) Offset(offset int64) SelectStatement { +func (q *selectStatementImpl) OFFSET(offset int64) SelectStatement { q.offset = offset return q } -func (q *selectStatementImpl) Comment(comment string) SelectStatement { - q.comment = comment +func (q *selectStatementImpl) LIMIT(limit int64) SelectStatement { + q.limit = limit return q } -// Return the properly escaped SQL statement, against the specified database -func (q *selectStatementImpl) String() (sql string, err error) { - buf := new(bytes.Buffer) - _, _ = buf.WriteString("SELECT ") +func (q *selectStatementImpl) DISTINCT() SelectStatement { + q.distinct = true + return q +} - if err = writeComment(q.comment, buf); err != nil { - return - } - - if q.distinct { - _, _ = buf.WriteString("DISTINCT ") - } - - if q.projections == nil || len(q.projections) == 0 { - return "", errors.Newf( - "No column selected. Generated sql: %s", - buf.String()) - } - - for i, col := range q.projections { - if i > 0 { - _ = buf.WriteByte(',') - } - if col == nil { - return "", errors.Newf( - "nil column selected. Generated sql: %s", - buf.String()) - } - if err = col.SerializeForProjection(buf); err != nil { - return - } - } - - _, _ = buf.WriteString(" FROM ") - if q.table == nil { - return "", errors.Newf("nil tableName. Generated sql: %s", buf.String()) - } - if err = q.table.SerializeSql(buf); err != nil { - return - } - - if q.where != nil { - _, _ = buf.WriteString(" WHERE ") - if err = q.where.SerializeSql(buf); err != nil { - return - } - } - - if q.group != nil { - _, _ = buf.WriteString(" GROUP BY ") - if err = q.group.SerializeSql(buf); err != nil { - return - } - } - - if q.having != nil { - buf.WriteString(" HAVING ") - if err = q.having.SerializeSql(buf); err != nil { - return - } - } - - if q.order != nil { - _, _ = buf.WriteString(" ORDER BY ") - if err = q.order.SerializeSql(buf); err != nil { - return - } - } - - if q.limit >= 0 { - if q.offset >= 0 { - _, _ = buf.WriteString(fmt.Sprintf(" LIMIT %d, %d", q.offset, q.limit)) - } else { - _, _ = buf.WriteString(fmt.Sprintf(" LIMIT %d", q.limit)) - } - } - - if q.forUpdate { - _, _ = buf.WriteString(" FOR UPDATE") - } else if q.withSharedLock { - _, _ = buf.WriteString(" LOCK IN SHARE MODE") - } - - return buf.String(), nil +func (q *selectStatementImpl) FOR_UPDATE() SelectStatement { + q.forUpdate = true + return q } func NumExp(statement SelectStatement) NumericExpression { diff --git a/sqlbuilder/select_statement_table.go b/sqlbuilder/select_statement_table.go index 0410207..5efef4c 100644 --- a/sqlbuilder/select_statement_table.go +++ b/sqlbuilder/select_statement_table.go @@ -1,7 +1,5 @@ package sqlbuilder -import "bytes" - type SelectStatementTable struct { statement SelectStatement columns []Column @@ -41,16 +39,14 @@ func (s *SelectStatementTable) RefStringColumn(column Column) *StringColumn { return strColumn } -func (s *SelectStatementTable) SerializeSql(out *bytes.Buffer) error { +func (s *SelectStatementTable) SerializeSql(out *queryData) error { out.WriteString("( ") - statementStr, err := s.statement.String() + err := s.statement.Serialize(out) if err != nil { return err } - out.WriteString(statementStr) - out.WriteString(" ) AS ") out.WriteString(s.alias) diff --git a/sqlbuilder/statement.go b/sqlbuilder/statement.go index 74caa30..91d2423 100644 --- a/sqlbuilder/statement.go +++ b/sqlbuilder/statement.go @@ -1,17 +1,13 @@ package sqlbuilder import ( - "bytes" "database/sql" "github.com/sub0zero/go-sqlbuilder/types" - "regexp" - - "github.com/dropbox/godropbox/errors" ) type Statement interface { // String returns generated SQL as string. - String() (sql string, err error) + Sql() (query string, args []interface{}, err error) Query(db types.Db, destination interface{}) error Execute(db types.Db) (sql.Result, error) @@ -88,10 +84,10 @@ type Statement interface { // // for idx, lock := range s.locks { // if lock.t == nil { -// return "", errors.Newf("nil tableName. Generated sql: %s", buf.String()) +// return "", errors.Newf("nil tableName.", buf.String()) // } // -// if err = lock.t.SerializeSql(buf); err != nil { +// if err = lock.t.Serialize(buf); err != nil { // return // } // @@ -162,23 +158,23 @@ type Statement interface { // // Once again, teisenberger is lazy. Here's a quick filter on comments -var validCommentRegexp *regexp.Regexp = regexp.MustCompile("^[\\w .?]*$") - -func isValidComment(comment string) bool { - return validCommentRegexp.MatchString(comment) -} - -func writeComment(comment string, buf *bytes.Buffer) error { - if comment != "" { - _, _ = buf.WriteString("/* ") - if !isValidComment(comment) { - return errors.Newf("Invalid comment: %s", comment) - } - _, _ = buf.WriteString(comment) - _, _ = buf.WriteString(" */") - } - return nil -} +//var validCommentRegexp *regexp.Regexp = regexp.MustCompile("^[\\w .?]*$") +// +//func isValidComment(comment string) bool { +// return validCommentRegexp.MatchString(comment) +//} +// +//func writeComment(comment string, buf *bytes.Buffer) error { +// if comment != "" { +// _, _ = buf.WriteString("/* ") +// if !isValidComment(comment) { +// return errors.Newf("Invalid comment: %s", comment) +// } +// _, _ = buf.WriteString(comment) +// _, _ = buf.WriteString(" */") +// } +// return nil +//} func newOrderByListClause(clauses ...OrderByClause) *listClause { ret := &listClause{ diff --git a/sqlbuilder/statement_test.go b/sqlbuilder/statement_test.go index 86765a2..ba0d431 100644 --- a/sqlbuilder/statement_test.go +++ b/sqlbuilder/statement_test.go @@ -475,14 +475,14 @@ func (s *StmtSuite) TestUnionSelectWithMismatchedColumns(c *gc.C) { "All inner selects in Union statement must select the "+ "same number of columns. For sanity, you probably "+ "want to select the same tableName columns in the same "+ - "order. If you are selecting on multiple tables, "+ + "orderBy. If you are selecting on multiple tables, "+ "use Null to pad to the right number of fields.") } func (s *StmtSuite) TestComplicatedUnionSelectWithWhereStatement(c *gc.C) { - // tests on outer statement: Group By, Order By, Limit - // on inner statement: AndWhere, WHERE (with And), Order By, Limit + // tests on outer statement: Group By, Order By, LIMIT + // on inner statement: AndWhere, WHERE (with And), Order By, LIMIT select_queries := make([]SelectStatement, 0, 3) // We're not trying to write a SQL parser, so we won't warn if you do something silly like diff --git a/sqlbuilder/table.go b/sqlbuilder/table.go index 5a48202..4510da2 100644 --- a/sqlbuilder/table.go +++ b/sqlbuilder/table.go @@ -3,7 +3,6 @@ package sqlbuilder import ( - "bytes" "fmt" "github.com/dropbox/godropbox/errors" ) @@ -15,7 +14,7 @@ type TableInterface interface { Columns() []Column // Generates the sql string for the current tableName expression. Note: the // generated string may not be a valid/executable sql statement. - SerializeSql(out *bytes.Buffer) error + SerializeSql(out *queryData) error } // The sql tableName read interface. NOTE: NATURAL JOINs, and join "USING" clause @@ -52,9 +51,6 @@ type WritableTable interface { // Defines a physical tableName in the database that is both readable and writable. // This function will panic if name is not valid func NewTable(schemaName, name string, columns ...Column) *Table { - if !validIdentifierName(name) { - panic("Invalid tableName name") - } t := &Table{ schemaName: schemaName, @@ -154,28 +150,20 @@ func (t *Table) ForceIndex(index string) *Table { // Generates the sql string for the current tableName expression. Note: the // generated string may not be a valid/executable sql statement. -func (t *Table) SerializeSql(out *bytes.Buffer) error { +func (t *Table) SerializeSql(out *queryData) error { if t == nil { - return errors.Newf("nil tableName. Generated sql: %s", out.String()) + return errors.Newf("nil tableName.") } - _, _ = out.WriteString(t.schemaName) - _, _ = out.WriteString(".") - _, _ = out.WriteString(t.TableName()) + + out.WriteString(t.schemaName) + out.WriteString(".") + out.WriteString(t.TableName()) if len(t.alias) > 0 { out.WriteString(" AS ") out.WriteString(t.alias) } - if t.forcedIndex != "" { - if !validIdentifierName(t.forcedIndex) { - return errors.Newf("'%s' is not a valid identifier for an index", t.forcedIndex) - } - _, _ = out.WriteString(" FORCE INDEX (") - _, _ = out.WriteString(t.forcedIndex) - _, _ = out.WriteString(")") - } - return nil } @@ -307,7 +295,6 @@ func CrossJoin( return newJoinTable(lhs, rhs, CROSS_JOIN, nil) } -// Returns the tableName's name in the database func (t *joinTable) SchemaName() string { return "" } @@ -328,16 +315,16 @@ func (t *joinTable) Column(name string) Column { panic("Not implemented") } -func (t *joinTable) SerializeSql(out *bytes.Buffer) (err error) { +func (t *joinTable) SerializeSql(out *queryData) (err error) { if t.lhs == nil { - return errors.Newf("nil lhs. Generated sql: %s", out.String()) + return errors.Newf("nil lhs.") } if t.rhs == nil { - return errors.Newf("nil rhs. Generated sql: %s", out.String()) + return errors.Newf("nil rhs.") } if t.onCondition == nil && t.join_type != CROSS_JOIN { - return errors.Newf("nil onCondition. Generated sql: %s", out.String()) + return errors.Newf("nil onCondition.") } if err = t.lhs.SerializeSql(out); err != nil { @@ -346,11 +333,11 @@ func (t *joinTable) SerializeSql(out *bytes.Buffer) (err error) { switch t.join_type { case INNER_JOIN: - _, _ = out.WriteString(" JOIN ") + out.WriteString(" JOIN ") case LEFT_JOIN: - _, _ = out.WriteString(" LEFT JOIN ") + out.WriteString(" LEFT JOIN ") case RIGHT_JOIN: - _, _ = out.WriteString(" RIGHT JOIN ") + out.WriteString(" RIGHT JOIN ") case FULL_JOIN: out.WriteString(" FULL JOIN ") case CROSS_JOIN: @@ -362,8 +349,8 @@ func (t *joinTable) SerializeSql(out *bytes.Buffer) (err error) { } if t.onCondition != nil { - _, _ = out.WriteString(" ON ") - if err = t.onCondition.SerializeSql(out); err != nil { + out.WriteString(" ON ") + if err = t.onCondition.Serialize(out); err != nil { return } } diff --git a/sqlbuilder/types.go b/sqlbuilder/types.go index 8ce595d..13db60e 100644 --- a/sqlbuilder/types.go +++ b/sqlbuilder/types.go @@ -1,10 +1,6 @@ package sqlbuilder -// A clause that can be used in order by -type OrderByClause interface { - Clause - isOrderByClauseInterface -} +// A clause that can be used in orderBy by // A clause that is selectable. //type Projection interface { @@ -16,9 +12,9 @@ type OrderByClause interface { //type ColumnList []Column // -//func (cl ColumnList) SerializeSql(out *bytes.Buffer, options ...serializeOption) error { +//func (cl ColumnList) Serialize(out *bytes.Buffer, options ...serializeOption) error { // for i, column := range cl { -// column.SerializeSql(out) +// column.Serialize(out) // // if i != len(cl)-1 { // out.WriteString(", ") @@ -49,16 +45,6 @@ type OrderByClause interface { // Boiler plates ... // -type isOrderByClauseInterface interface { - isOrderByClauseType() -} - -type isOrderByClause struct { -} - -func (o *isOrderByClause) isOrderByClauseType() { -} - // //type isProjectionInterface interface { // isProjectionType() diff --git a/sqlbuilder/union_statement.go b/sqlbuilder/union_statement.go index 462c125..41c26f0 100644 --- a/sqlbuilder/union_statement.go +++ b/sqlbuilder/union_statement.go @@ -1,17 +1,9 @@ package sqlbuilder -import ( - "bytes" - "database/sql" - "fmt" - "github.com/dropbox/godropbox/errors" - "github.com/sub0zero/go-sqlbuilder/types" -) - -// By default, rows selected by a UNION statement are out-of-order +// By default, rows selected by a UNION statement are out-of-orderBy // If you have an ORDER BY on an inner SELECT statement, the only thing // it affects is the LIMIT clause on that inner statement (the ordering will -// still be out-of-order). +// still be out-of-orderBy). type UnionStatement interface { Statement @@ -27,177 +19,178 @@ type UnionStatement interface { Offset(offset int64) UnionStatement } -func Union(selects ...SelectStatement) UnionStatement { - return &unionStatementImpl{ - selects: selects, - limit: -1, - offset: -1, - unique: true, - } -} - -func UnionAll(selects ...SelectStatement) UnionStatement { - return &unionStatementImpl{ - selects: selects, - limit: -1, - offset: -1, - unique: false, - } -} - -// Similar to selectStatementImpl, but less complete -type unionStatementImpl struct { - selects []SelectStatement - where BoolExpression - group *listClause - order *listClause - limit, offset int64 - // True if results of the union should be deduped. - unique bool -} - -func (s *unionStatementImpl) Query(db types.Db, destination interface{}) error { - return Query(s, db, destination) -} - -func (u *unionStatementImpl) Execute(db types.Db) (res sql.Result, err error) { - return Execute(u, db) -} - -func (us *unionStatementImpl) Where(expression BoolExpression) UnionStatement { - us.where = expression - return us -} - -// Further filter the query, instead of replacing the filter -func (us *unionStatementImpl) AndWhere(expression BoolExpression) UnionStatement { - if us.where == nil { - return us.Where(expression) - } - us.where = And(us.where, expression) - return us -} - -func (us *unionStatementImpl) GroupBy( - expressions ...Expression) UnionStatement { - - us.group = &listClause{ - clauses: make([]Clause, len(expressions), len(expressions)), - includeParentheses: false, - } - - for i, e := range expressions { - us.group.clauses[i] = e - } - return us -} - -func (us *unionStatementImpl) OrderBy( - clauses ...OrderByClause) UnionStatement { - - us.order = newOrderByListClause(clauses...) - return us -} - -func (us *unionStatementImpl) Limit(limit int64) UnionStatement { - us.limit = limit - return us -} - -func (us *unionStatementImpl) Offset(offset int64) UnionStatement { - us.offset = offset - return us -} - -func (us *unionStatementImpl) String() (sql string, err error) { - if len(us.selects) == 0 { - return "", errors.Newf("Union statement must have at least one SELECT") - } - - if len(us.selects) == 1 { - return us.selects[0].String() - } - - // Union statements in MySQL require that the same number of columns in each subquery - var projections []Projection - - for _, statement := range us.selects { - // do a type assertion to get at the underlying struct - statementImpl, ok := statement.(*selectStatementImpl) - if !ok { - return "", errors.Newf( - "Expected inner select statement to be of type " + - "selectStatementImpl") - } - - // check that for limit for statements with order by clauses - if statementImpl.order != nil && statementImpl.limit < 0 { - return "", errors.Newf( - "All inner selects in Union statement must have LIMIT if " + - "they have ORDER BY") - } - - // check number of projections - if projections == nil { - projections = statementImpl.projections - } else { - if len(projections) != len(statementImpl.projections) { - return "", errors.Newf( - "All inner selects in Union statement must select the " + - "same number of columns. For sanity, you probably " + - "want to select the same tableName columns in the same " + - "order. If you are selecting on multiple tables, " + - "use Null to pad to the right number of fields.") - } - } - } - - buf := new(bytes.Buffer) - for i, statement := range us.selects { - if i != 0 { - if us.unique { - _, _ = buf.WriteString(" UNION ") - } else { - _, _ = buf.WriteString(" UNION ALL ") - } - } - _, _ = buf.WriteString("(") - selectSql, err := statement.String() - if err != nil { - return "", err - } - _, _ = buf.WriteString(selectSql) - _, _ = buf.WriteString(")") - } - - if us.where != nil { - _, _ = buf.WriteString(" WHERE ") - if err = us.where.SerializeSql(buf); err != nil { - return - } - } - - if us.group != nil { - _, _ = buf.WriteString(" GROUP BY ") - if err = us.group.SerializeSql(buf); err != nil { - return - } - } - - if us.order != nil { - _, _ = buf.WriteString(" ORDER BY ") - if err = us.order.SerializeSql(buf); err != nil { - return - } - } - - if us.limit >= 0 { - if us.offset >= 0 { - _, _ = buf.WriteString( - fmt.Sprintf(" LIMIT %d, %d", us.offset, us.limit)) - } else { - _, _ = buf.WriteString(fmt.Sprintf(" LIMIT %d", us.limit)) - } - } - return buf.String(), nil -} +// +//func Union(selects ...SelectStatement) UnionStatement { +// return &unionStatementImpl{ +// selects: selects, +// limit: -1, +// offset: -1, +// unique: true, +// } +//} +// +//func UnionAll(selects ...SelectStatement) UnionStatement { +// return &unionStatementImpl{ +// selects: selects, +// limit: -1, +// offset: -1, +// unique: false, +// } +//} +// +//// Similar to selectStatementImpl, but less complete +//type unionStatementImpl struct { +// selects []SelectStatement +// where BoolExpression +// group *listClause +// order *listClause +// limit, offset int64 +// // True if results of the union should be deduped. +// unique bool +//} +// +//func (s *unionStatementImpl) Query(db types.Db, destination interface{}) error { +// return Query(s, db, destination) +//} +// +//func (u *unionStatementImpl) Execute(db types.Db) (res sql.Result, err error) { +// return Execute(u, db) +//} +// +//func (us *unionStatementImpl) Where(expression BoolExpression) UnionStatement { +// us.where = expression +// return us +//} +// +//// Further filter the query, instead of replacing the filter +//func (us *unionStatementImpl) AndWhere(expression BoolExpression) UnionStatement { +// if us.where == nil { +// return us.Where(expression) +// } +// us.where = And(us.where, expression) +// return us +//} +// +//func (us *unionStatementImpl) GroupBy( +// expressions ...Expression) UnionStatement { +// +// us.group = &listClause{ +// clauses: make([]Clause, len(expressions), len(expressions)), +// includeParentheses: false, +// } +// +// for i, e := range expressions { +// us.group.clauses[i] = e +// } +// return us +//} +// +//func (us *unionStatementImpl) OrderBy( +// clauses ...OrderByClause) UnionStatement { +// +// us.order = newOrderByListClause(clauses...) +// return us +//} +// +//func (us *unionStatementImpl) Limit(limit int64) UnionStatement { +// us.limit = limit +// return us +//} +// +//func (us *unionStatementImpl) Offset(offset int64) UnionStatement { +// us.offset = offset +// return us +//} +// +//func (us *unionStatementImpl) String() (sql string, err error) { +// if len(us.selects) == 0 { +// return "", errors.Newf("Union statement must have at least one SELECT") +// } +// +// if len(us.selects) == 1 { +// return us.selects[0].String() +// } +// +// // Union statements in MySQL require that the same number of columns in each subquery +// var projections []Projection +// +// for _, statement := range us.selects { +// // do a type assertion to get at the underlying struct +// statementImpl, ok := statement.(*selectStatementImpl) +// if !ok { +// return "", errors.Newf( +// "Expected inner select statement to be of type " + +// "selectStatementImpl") +// } +// +// // check that for limit for statements with orderBy by clauses +// if statementImpl.orderBy != nil && statementImpl.limit < 0 { +// return "", errors.Newf( +// "All inner selects in Union statement must have LIMIT if " + +// "they have ORDER BY") +// } +// +// // check number of projections +// if projections == nil { +// projections = statementImpl.projections +// } else { +// if len(projections) != len(statementImpl.projections) { +// return "", errors.Newf( +// "All inner selects in Union statement must select the " + +// "same number of columns. For sanity, you probably " + +// "want to select the same tableName columns in the same " + +// "orderBy. If you are selecting on multiple tables, " + +// "use Null to pad to the right number of fields.") +// } +// } +// } +// +// buf := new(bytes.Buffer) +// for i, statement := range us.selects { +// if i != 0 { +// if us.unique { +// _, _ = buf.WriteString(" UNION ") +// } else { +// _, _ = buf.WriteString(" UNION ALL ") +// } +// } +// _, _ = buf.WriteString("(") +// selectSql, err := statement.String() +// if err != nil { +// return "", err +// } +// _, _ = buf.WriteString(selectSql) +// _, _ = buf.WriteString(")") +// } +// +// if us.where != nil { +// _, _ = buf.WriteString(" WHERE ") +// if err = us.where.Serialize(buf); err != nil { +// return +// } +// } +// +// if us.group != nil { +// _, _ = buf.WriteString(" GROUP BY ") +// if err = us.group.Serialize(buf); err != nil { +// return +// } +// } +// +// if us.order != nil { +// _, _ = buf.WriteString(" ORDER BY ") +// if err = us.order.Serialize(buf); err != nil { +// return +// } +// } +// +// if us.limit >= 0 { +// if us.offset >= 0 { +// _, _ = buf.WriteString( +// fmt.Sprintf(" LIMIT %d, %d", us.offset, us.limit)) +// } else { +// _, _ = buf.WriteString(fmt.Sprintf(" LIMIT %d", us.limit)) +// } +// } +// return buf.String(), nil +//} diff --git a/sqlbuilder/update_statement.go b/sqlbuilder/update_statement.go index a35f731..41f6687 100644 --- a/sqlbuilder/update_statement.go +++ b/sqlbuilder/update_statement.go @@ -1,7 +1,6 @@ package sqlbuilder import ( - "bytes" "database/sql" "github.com/dropbox/godropbox/errors" "github.com/sub0zero/go-sqlbuilder/types" @@ -61,60 +60,64 @@ func (u *updateStatementImpl) RETURNING(projections ...Projection) UpdateStateme return u } -func (u *updateStatementImpl) String() (sql string, err error) { - buf := new(bytes.Buffer) - _, _ = buf.WriteString("UPDATE ") +func (u *updateStatementImpl) Sql() (sql string, args []interface{}, err error) { + out := &queryData{} + out.WriteString("UPDATE ") if u.table == nil { - return "", errors.Newf("nil tableName. Generated sql: %s", buf.String()) + return "", nil, errors.New("nil tableName.") } - if err = u.table.SerializeSql(buf); err != nil { + if err = u.table.SerializeSql(out); err != nil { return } if len(u.updateValues) == 0 { - return "", errors.Newf( - "No column updated. Generated sql: %s", - buf.String()) + return "", nil, errors.New("No column updated.") } - _, _ = buf.WriteString(" SET") + out.WriteString(" SET") if len(u.columns) > 1 { - buf.WriteString(" ( ") + out.WriteString(" ( ") } else { - buf.WriteString(" ") + out.WriteString(" ") } - for i, column := range u.columns { - if i > 0 { - buf.WriteString(", ") - } + //for i, column := range u.columns { + // if i > 0 { + // out.WriteString(", ") + // } + // + // out.WriteString(column.Name()) + // + // if err != nil { + // return + // } + //} - buf.WriteString(column.Name()) + err = serializeColumnList(u.columns, out) - if err != nil { - return - } + if err != nil { + return "", nil, err } if len(u.columns) > 1 { - buf.WriteString(" )") + out.WriteString(" )") } - buf.WriteString(" =") + out.WriteString(" =") if len(u.updateValues) > 1 { - buf.WriteString(" (") + out.WriteString(" (") } for i, value := range u.updateValues { if i > 0 { - buf.WriteString(", ") + out.WriteString(", ") } - err = value.SerializeSql(buf) + err = value.Serialize(out) if err != nil { return @@ -122,29 +125,27 @@ func (u *updateStatementImpl) String() (sql string, err error) { } if len(u.updateValues) > 1 { - buf.WriteString(" )") + out.WriteString(" )") } if u.where == nil { - return "", errors.Newf( - "Updating without a WHERE clause. Generated sql: %s", - buf.String()) + return "", nil, errors.New("Updating without a WHERE clause.") } - _, _ = buf.WriteString(" WHERE ") - if err = u.where.SerializeSql(buf); err != nil { + out.WriteString(" WHERE ") + if err = u.where.Serialize(out); err != nil { return } if len(u.returning) > 0 { - buf.WriteString(" RETURNING ") + out.WriteString(" RETURNING ") - err = serializeProjectionList(u.returning, buf) + err = serializeProjectionList(u.returning, out) if err != nil { return } } - return buf.String() + ";", nil + return out.queryBuff.String(), out.args, nil } diff --git a/sqlbuilder/update_statement_test.go b/sqlbuilder/update_statement_test.go index e53a916..46187d4 100644 --- a/sqlbuilder/update_statement_test.go +++ b/sqlbuilder/update_statement_test.go @@ -83,7 +83,7 @@ func TestUpdate(t *testing.T) { //func (s *StmtSuite) TestUpdateWithOrderBy(c *gc.C) { // stmt := table1.UPDATE().SET(table1Col1, Literal(1)) // stmt.WHERE(EqL(table1Col2, 2)) -// stmt.OrderBy(table1Col2) +// stmt.ORDER_BY(table1Col2) // sql, err := stmt.String() // c.Assert(err, gc.IsNil) // @@ -99,7 +99,7 @@ func TestUpdate(t *testing.T) { //func (s *StmtSuite) TestUpdateWithLimit(c *gc.C) { // stmt := table1.UPDATE().SET(table1Col1, Literal(1)) // stmt.WHERE(EqL(table1Col2, 2)) -// stmt.Limit(5) +// stmt.LIMIT(5) // sql, err := stmt.String() // c.Assert(err, gc.IsNil) // diff --git a/sqlbuilder/utils.go b/sqlbuilder/utils.go index 105665b..b3e128b 100644 --- a/sqlbuilder/utils.go +++ b/sqlbuilder/utils.go @@ -1,19 +1,20 @@ package sqlbuilder import ( - "bytes" "database/sql" + "errors" "github.com/sub0zero/go-sqlbuilder/sqlbuilder/execution" "github.com/sub0zero/go-sqlbuilder/types" ) -func serializeExpressionList(expressions []Expression, buf *bytes.Buffer) error { - for i, value := range expressions { +func serializeOrderByClauseList(orderByClauses []OrderByClause, out *queryData) error { + + for i, value := range orderByClauses { if i > 0 { - buf.WriteString(", ") + out.WriteString(", ") } - err := value.SerializeSql(buf) + err := value.Serialize(out) if err != nil { return err @@ -23,13 +24,33 @@ func serializeExpressionList(expressions []Expression, buf *bytes.Buffer) error return nil } -func serializeProjectionList(projections []Projection, buf *bytes.Buffer) error { - for i, value := range projections { +func serializeClauseList(clauses []Clause, out *queryData) (err error) { + + for i, c := range clauses { if i > 0 { - buf.WriteString(", ") + out.WriteString(", ") } - err := value.SerializeForProjection(buf) + if c == nil { + return errors.New("nil clause.") + } + + if err = c.Serialize(out); err != nil { + return + } + } + + return nil +} + +func serializeExpressionList(expressions []Expression, separator string, out *queryData) error { + + for i, value := range expressions { + if i > 0 { + out.WriteString(separator) + } + + err := value.Serialize(out) if err != nil { return err @@ -39,24 +60,55 @@ func serializeProjectionList(projections []Projection, buf *bytes.Buffer) error return nil } +func serializeProjectionList(projections []Projection, out *queryData) error { + for i, col := range projections { + if i > 0 { + out.WriteByte(',') + } + if col == nil { + return errors.New("Projection expression is nil.") + } + + if err := col.SerializeForProjection(out); err != nil { + return err + } + } + + return nil +} + +func serializeColumnList(columns []Column, out *queryData) error { + for i, col := range columns { + if i > 0 { + out.WriteByte(',') + } + + if col == nil { + return errors.New("nil column in columns list.") + } + + out.WriteString(col.Name()) + } + + return nil +} + func Query(statement Statement, db types.Db, destination interface{}) error { - query, err := statement.String() + query, args, err := statement.Sql() if err != nil { return err } - return execution.Execute(db, query, destination) + return execution.Query(db, query, args, destination) } func Execute(statement Statement, db types.Db) (res sql.Result, err error) { - query, err := statement.String() + query, args, err := statement.Sql() if err != nil { return } - res, err = db.Exec(query) - - return + return db.Exec(query, args...) } diff --git a/tests/generator_test.go b/tests/generator_test.go index 23103eb..ee1be4e 100644 --- a/tests/generator_test.go +++ b/tests/generator_test.go @@ -25,13 +25,14 @@ func TestGenerateModel(t *testing.T) { func TestSelect_ScanToStruct(t *testing.T) { actor := model.Actor{} - query := Actor.SELECT(Actor.AllColumns).OrderBy(Actor.ActorID.Asc()) + query := Actor.SELECT(Actor.AllColumns).ORDER_BY(Actor.ActorID.Asc()) - queryStr, err := query.String() + queryStr, args, err := query.Sql() fmt.Println(queryStr) assert.Equal(t, queryStr, `SELECT actor.actor_id AS "actor.actor_id", actor.first_name AS "actor.first_name", actor.last_name AS "actor.last_name", actor.last_update AS "actor.last_update" FROM dvds.actor ORDER BY actor.actor_id ASC`) + assert.Equal(t, len(args), 0) err = query.Query(db, &actor) @@ -50,12 +51,14 @@ func TestSelect_ScanToStruct(t *testing.T) { func TestSelect_ScanToSlice(t *testing.T) { customers := []model.Customer{} - query := Customer.SELECT(Customer.AllColumns).OrderBy(Customer.CustomerID.Asc()) + query := Customer.SELECT(Customer.AllColumns).ORDER_BY(Customer.CustomerID.Asc()) - queryStr, err := query.String() + queryStr, args, err := query.Sql() assert.NilError(t, err) fmt.Println(queryStr) + assert.Equal(t, queryStr, `SELECT customer.customer_id AS "customer.customer_id", customer.store_id AS "customer.store_id", customer.first_name AS "customer.first_name", customer.last_name AS "customer.last_name", customer.email AS "customer.email", customer.address_id AS "customer.address_id", customer.activebool AS "customer.activebool", customer.create_date AS "customer.create_date", customer.last_update AS "customer.last_update", customer.active AS "customer.active" FROM dvds.customer ORDER BY customer.customer_id ASC`) + assert.Equal(t, len(args), 0) err = query.Query(db, &customers) assert.NilError(t, err) @@ -76,7 +79,7 @@ func TestSelect_ScanToSlice(t *testing.T) { // SELECT(FilmActor.AllColumns, Film.AllColumns, Language.AllColumns, Actor.AllColumns). // WHERE(FilmActor.ActorID.GtEq(1).And(FilmActor.ActorID.LteLiteral(2))) // -// queryStr, err := query.String() +// queryStr, args, err := query.Sql() // assert.NilError(t, err) // assert.Equal(t, queryStr, `SELECT film_actor.actor_id AS "film_actor.actor_id", film_actor.film_id AS "film_actor.film_id", film_actor.last_update AS "film_actor.last_update",film.film_id AS "film.film_id", film.title AS "film.title", film.description AS "film.description", film.release_year AS "film.release_year", film.language_id AS "film.language_id", film.rental_duration AS "film.rental_duration", film.rental_rate AS "film.rental_rate", film.length AS "film.length", film.replacement_cost AS "film.replacement_cost", film.rating AS "film.rating", film.last_update AS "film.last_update", film.special_features AS "film.special_features", film.fulltext AS "film.fulltext",language.language_id AS "language.language_id", language.name AS "language.name", language.last_update AS "language.last_update",actor.actor_id AS "actor.actor_id", actor.first_name AS "actor.first_name", actor.last_name AS "actor.last_name", actor.last_update AS "actor.last_update" FROM dvds.film_actor JOIN dvds.actor ON film_actor.actor_id = actor.actor_id JOIN dvds.film ON film_actor.film_id = film.film_id JOIN dvds.language ON film.language_id = language.language_id WHERE (film_actor.actor_id>=1 AND film_actor.actor_id<=2)`) // @@ -104,14 +107,18 @@ func TestJoinQuerySlice(t *testing.T) { query := Film. INNER_JOIN(Language, Film.LanguageID.Eq(Language.LanguageID)). SELECT(Language.AllColumns, Film.AllColumns). - Where(Film.Rating.EqL(string(model.MpaaRating_NC17))). - Limit(15) + WHERE(Film.Rating.EqL(string(model.MpaaRating_NC17))). + LIMIT(15) - queryStr, err := query.String() + queryStr, args, err := query.Sql() assert.NilError(t, err) - assert.Equal(t, queryStr, `SELECT language.language_id AS "language.language_id", language.name AS "language.name", language.last_update AS "language.last_update",film.film_id AS "film.film_id", film.title AS "film.title", film.description AS "film.description", film.release_year AS "film.release_year", film.language_id AS "film.language_id", film.rental_duration AS "film.rental_duration", film.rental_rate AS "film.rental_rate", film.length AS "film.length", film.replacement_cost AS "film.replacement_cost", film.rating AS "film.rating", film.last_update AS "film.last_update", film.special_features AS "film.special_features", film.fulltext AS "film.fulltext" FROM dvds.film JOIN dvds.language ON film.language_id = language.language_id WHERE film.rating = 'NC-17' LIMIT 15`) - //fmt.Println(queryStr) + fmt.Println(queryStr) + assert.Equal(t, queryStr, `SELECT language.language_id AS "language.language_id", language.name AS "language.name", language.last_update AS "language.last_update",film.film_id AS "film.film_id", film.title AS "film.title", film.description AS "film.description", film.release_year AS "film.release_year", film.language_id AS "film.language_id", film.rental_duration AS "film.rental_duration", film.rental_rate AS "film.rental_rate", film.length AS "film.length", film.replacement_cost AS "film.replacement_cost", film.rating AS "film.rating", film.last_update AS "film.last_update", film.special_features AS "film.special_features", film.fulltext AS "film.fulltext" FROM dvds.film JOIN dvds.language ON film.language_id = language.language_id WHERE film.rating = $1 LIMIT $2`) + + assert.Equal(t, len(args), 2) + assert.Equal(t, args[0], string(model.MpaaRating_NC17)) + assert.Equal(t, args[1], int64(15)) err = query.Query(db, &filmsPerLanguage) @@ -149,7 +156,7 @@ func TestJoinQuerySliceWithPtrs(t *testing.T) { query := Film.INNER_JOIN(Language, Film.LanguageID.Eq(Language.LanguageID)). SELECT(Language.AllColumns, Film.AllColumns). - Limit(limit) + LIMIT(limit) filmsPerLanguageWithPtrs := []*FilmsPerLanguage{} err := query.Query(db, &filmsPerLanguageWithPtrs) @@ -179,7 +186,7 @@ func TestSelectOrderByAscDesc(t *testing.T) { customersAsc := []model.Customer{} err := Customer.SELECT(Customer.CustomerID, Customer.FirstName, Customer.LastName). - OrderBy(Customer.FirstName.Asc()). + ORDER_BY(Customer.FirstName.Asc()). Query(db, &customersAsc) assert.NilError(t, err) @@ -189,7 +196,7 @@ func TestSelectOrderByAscDesc(t *testing.T) { customersDesc := []model.Customer{} err = Customer.SELECT(Customer.CustomerID, Customer.FirstName, Customer.LastName). - OrderBy(Customer.FirstName.Desc()). + ORDER_BY(Customer.FirstName.Desc()). Query(db, &customersDesc) assert.NilError(t, err) @@ -202,7 +209,7 @@ func TestSelectOrderByAscDesc(t *testing.T) { customersAscDesc := []model.Customer{} err = Customer.SELECT(Customer.CustomerID, Customer.FirstName, Customer.LastName). - OrderBy(Customer.FirstName.Asc(), Customer.LastName.Desc()). + ORDER_BY(Customer.FirstName.Asc(), Customer.LastName.Desc()). Query(db, &customersAscDesc) assert.NilError(t, err) @@ -227,13 +234,14 @@ func TestSelectFullJoin(t *testing.T) { query := Customer. FULL_JOIN(Address, Customer.AddressID.Eq(Address.AddressID)). SELECT(Customer.AllColumns, Address.AllColumns). - OrderBy(Customer.CustomerID.Asc()) + ORDER_BY(Customer.CustomerID.Asc()) - queryStr, err := query.String() + queryStr, args, err := query.Sql() assert.NilError(t, err) assert.Equal(t, queryStr, `SELECT customer.customer_id AS "customer.customer_id", customer.store_id AS "customer.store_id", customer.first_name AS "customer.first_name", customer.last_name AS "customer.last_name", customer.email AS "customer.email", customer.address_id AS "customer.address_id", customer.activebool AS "customer.activebool", customer.create_date AS "customer.create_date", customer.last_update AS "customer.last_update", customer.active AS "customer.active",address.address_id AS "address.address_id", address.address AS "address.address", address.address2 AS "address.address2", address.district AS "address.district", address.city_id AS "address.city_id", address.postal_code AS "address.postal_code", address.phone AS "address.phone", address.last_update AS "address.last_update" FROM dvds.customer FULL JOIN dvds.address ON customer.address_id = address.address_id ORDER BY customer.customer_id ASC`) + assert.Equal(t, len(args), 0) allCustomersAndAddress := []struct { Address *model.Address @@ -259,13 +267,14 @@ func TestSelectFullCrossJoin(t *testing.T) { query := Customer. CrossJoin(Address). SELECT(Customer.AllColumns, Address.AllColumns). - OrderBy(Customer.CustomerID.Asc()). - Limit(1000) + ORDER_BY(Customer.CustomerID.Asc()). + LIMIT(1000) - queryStr, err := query.String() + queryStr, args, err := query.Sql() assert.NilError(t, err) - assert.Equal(t, queryStr, `SELECT customer.customer_id AS "customer.customer_id", customer.store_id AS "customer.store_id", customer.first_name AS "customer.first_name", customer.last_name AS "customer.last_name", customer.email AS "customer.email", customer.address_id AS "customer.address_id", customer.activebool AS "customer.activebool", customer.create_date AS "customer.create_date", customer.last_update AS "customer.last_update", customer.active AS "customer.active",address.address_id AS "address.address_id", address.address AS "address.address", address.address2 AS "address.address2", address.district AS "address.district", address.city_id AS "address.city_id", address.postal_code AS "address.postal_code", address.phone AS "address.phone", address.last_update AS "address.last_update" FROM dvds.customer CROSS JOIN dvds.address ORDER BY customer.customer_id ASC LIMIT 1000`) + assert.Equal(t, queryStr, `SELECT customer.customer_id AS "customer.customer_id", customer.store_id AS "customer.store_id", customer.first_name AS "customer.first_name", customer.last_name AS "customer.last_name", customer.email AS "customer.email", customer.address_id AS "customer.address_id", customer.activebool AS "customer.activebool", customer.create_date AS "customer.create_date", customer.last_update AS "customer.last_update", customer.active AS "customer.active",address.address_id AS "address.address_id", address.address AS "address.address", address.address2 AS "address.address2", address.district AS "address.district", address.city_id AS "address.city_id", address.postal_code AS "address.postal_code", address.phone AS "address.phone", address.last_update AS "address.last_update" FROM dvds.customer CROSS JOIN dvds.address ORDER BY customer.customer_id ASC LIMIT $1`) + assert.Equal(t, len(args), 1) customerAddresCrosJoined := []model.Customer{} @@ -286,9 +295,10 @@ func TestSelectSelfJoin(t *testing.T) { query := f1. INNER_JOIN(f2, f1.FilmID.NotEq(f2.FilmID).And(f1.Length.Eq(f2.Length))). SELECT(f1.AllColumns, f2.AllColumns). - OrderBy(f1.FilmID.Asc()) + ORDER_BY(f1.FilmID.Asc()) - queryStr, err := query.String() + queryStr, args, err := query.Sql() + assert.Equal(t, len(args), 0) assert.NilError(t, err) @@ -326,10 +336,11 @@ func TestSelectAliasColumn(t *testing.T) { SELECT(f1.Title.As("thesame_length_films.title1"), f2.Title.As("thesame_length_films.title2"), f1.Length.As("thesame_length_films.length")). - OrderBy(f1.Length.Asc(), f1.Title.Asc(), f2.Title.Asc()). - Limit(1000) + ORDER_BY(f1.Length.Asc(), f1.Title.Asc(), f2.Title.Asc()). + LIMIT(1000) - queryStr, err := query.String() + queryStr, args, err := query.Sql() + assert.Equal(t, len(args), 1) assert.NilError(t, err) @@ -372,9 +383,10 @@ func TestSelectSelfReferenceType(t *testing.T) { INNER_JOIN(manager, Staff.StaffID.Eq(manager.StaffID)). SELECT(Staff.StaffID, Staff.FirstName, Staff.LastName, Address.AllColumns, manager.StaffID, manager.FirstName) - queryStr, err := query.String() + queryStr, args, err := query.Sql() assert.NilError(t, err) fmt.Println(queryStr) + assert.Equal(t, len(args), 0) staffs := []staff{} @@ -394,13 +406,13 @@ func TestSubQuery(t *testing.T) { // selectStmtTable.RefIntColumnName("actor.last_name").As("nesto2"), // ) // - //queryStr, err := query.String() + //queryStr, args, err := query.Sql() // //assert.NilError(t, err) // //fmt.Println(queryStr) // - //avrgCustomer := sqlbuilder.NumExp(Customer.SELECT(Customer.LastName).Limit(1)) + //avrgCustomer := sqlbuilder.NumExp(Customer.SELECT(Customer.LastName).LIMIT(1)) // //Customer. // INNER_JOIN(selectStmtTable, Customer.LastName.Eq(selectStmtTable.RefStringColumn(Actor.FirstName))). @@ -408,7 +420,7 @@ func TestSubQuery(t *testing.T) { // WHERE(Actor.LastName.Neq(avrgCustomer)) rFilmsOnly := Film.SELECT(Film.FilmID, Film.Title, Film.Rating). - Where(Film.Rating.EqL("R")). + WHERE(Film.Rating.EqL("R")). AsTable("films") query := Actor.INNER_JOIN(FilmActor, Actor.ActorID.Eq(FilmActor.FilmID)). @@ -420,10 +432,10 @@ func TestSubQuery(t *testing.T) { rFilmsOnly.RefStringColumn(Film.Rating).As("film.rating"), ) - queryStr, err := query.String() + queryStr, args, err := query.Sql() assert.NilError(t, err) - + assert.Equal(t, len(args), 1) fmt.Println(queryStr) } @@ -431,12 +443,12 @@ func TestSubQuery(t *testing.T) { func TestSelectFunctions(t *testing.T) { query := Film.SELECT(sqlbuilder.MAX(Film.RentalRate).As("max_film_rate")) - str, err := query.String() + str, args, err := query.Sql() assert.NilError(t, err) assert.Equal(t, str, `SELECT MAX(film.rental_rate) AS "max_film_rate" FROM dvds.film`) - + assert.Equal(t, len(args), 0) fmt.Println(str) } @@ -445,13 +457,13 @@ func TestSelectQueryScalar(t *testing.T) { maxFilmRentalRate := sqlbuilder.NumExp(Film.SELECT(sqlbuilder.MAX(Film.RentalRate))) query := Film.SELECT(Film.AllColumns). - Where(Film.RentalRate.Eq(maxFilmRentalRate)). - OrderBy(Film.FilmID.Asc()) + WHERE(Film.RentalRate.Eq(maxFilmRentalRate)). + ORDER_BY(Film.FilmID.Asc()) - queryStr, err := query.String() + queryStr, args, err := query.Sql() assert.NilError(t, err) - + assert.Equal(t, len(args), 0) fmt.Println(queryStr) maxRentalRateFilms := []model.Film{} @@ -488,16 +500,17 @@ func TestSelectGroupByHaving(t *testing.T) { Payment.CustomerID.As("customer_payment_sum.customer_id"), sqlbuilder.SUM(Payment.Amount).As("customer_payment_sum.amount_sum"), ). - GroupBy(Payment.CustomerID). - OrderBy(sqlbuilder.SUM(Payment.Amount).Asc()). - HAVING(sqlbuilder.Gt(sqlbuilder.SUM(Payment.Amount), sqlbuilder.Literal(100))) + GROUP_BY(Payment.CustomerID). + ORDER_BY(sqlbuilder.SUM(Payment.Amount).Asc()). + HAVING(sqlbuilder.SUM(Payment.Amount).Gt(sqlbuilder.NewNumericLiteral(100))) - queryStr, err := customersPaymentQuery.String() + queryStr, args, err := customersPaymentQuery.Sql() assert.NilError(t, err) fmt.Println(queryStr) + assert.Equal(t, len(args), 1) + assert.Equal(t, queryStr, `SELECT payment.customer_id AS "customer_payment_sum.customer_id",SUM(payment.amount) AS "customer_payment_sum.amount_sum" FROM dvds.payment GROUP BY payment.customer_id HAVING SUM(payment.amount)>$1 ORDER BY SUM(payment.amount) ASC`) - assert.Equal(t, queryStr, `SELECT payment.customer_id AS "customer_payment_sum.customer_id",SUM(payment.amount) AS "customer_payment_sum.amount_sum" FROM dvds.payment GROUP BY payment.customer_id HAVING SUM(payment.amount)>100 ORDER BY SUM(payment.amount) ASC`) type CustomerPaymentSum struct { CustomerID int16 AmountSum float64 @@ -528,7 +541,7 @@ func TestSelectGroupBy2(t *testing.T) { Payment.CustomerID, sqlbuilder.SUM(Payment.Amount).As("amount_sum"), ). - GroupBy(Payment.CustomerID) + GROUP_BY(Payment.CustomerID) customersPaymentTable := customersPaymentSubQuery.AsTable("customer_payment_sum") amountSumColumn := customersPaymentTable.RefIntColumnName("amount_sum") @@ -536,11 +549,12 @@ func TestSelectGroupBy2(t *testing.T) { query := Customer. INNER_JOIN(customersPaymentTable, Customer.CustomerID.Eq(customersPaymentTable.RefIntColumn(Payment.CustomerID))). SELECT(Customer.AllColumns, amountSumColumn.As("customer_with_amounts.amount_sum")). - OrderBy(amountSumColumn.Asc()) + ORDER_BY(amountSumColumn.Asc()) - queryStr, err := query.String() + queryStr, args, err := query.Sql() assert.NilError(t, err) fmt.Println(queryStr) + assert.Equal(t, len(args), 0) err = query.Query(db, &customersWithAmounts) assert.NilError(t, err) @@ -565,13 +579,13 @@ func TestSelectGroupBy2(t *testing.T) { func TestSelectTimeColumns(t *testing.T) { query := Payment.SELECT(Payment.AllColumns). - Where(Payment.PaymentDate.LtEqL("2007-02-14 22:16:01")). - OrderBy(Payment.PaymentDate.Asc()) + WHERE(Payment.PaymentDate.LtEqL("2007-02-14 22:16:01")). + ORDER_BY(Payment.PaymentDate.Asc()) - queryStr, err := query.String() + queryStr, args, err := query.Sql() assert.NilError(t, err) - + assert.Equal(t, len(args), 1) fmt.Println(queryStr) payments := []model.Payment{} diff --git a/tests/insert_test.go b/tests/insert_test.go index 2695d97..48482fb 100644 --- a/tests/insert_test.go +++ b/tests/insert_test.go @@ -18,13 +18,14 @@ func TestInsertValues(t *testing.T) { VALUES("http://www.bing.com", "Bing", sqlbuilder.DEFAULT). RETURNING(table.Link.ID) - insertQueryStr, err := insertQuery.String() + insertQueryStr, args, err := insertQuery.Sql() assert.NilError(t, err) + assert.Equal(t, len(args), 8) fmt.Println(insertQueryStr) - assert.Equal(t, insertQueryStr, `INSERT INTO test_sample.link (url,name,rel) VALUES ('http://www.postgresqltutorial.com','PostgreSQL Tutorial',DEFAULT), ('http://www.google.com','Google',DEFAULT), ('http://www.yahoo.com','Yahoo',DEFAULT), ('http://www.bing.com','Bing',DEFAULT) RETURNING link.id AS "link.id";`) + assert.Equal(t, insertQueryStr, `INSERT INTO test_sample.link (url,name,rel) VALUES ($1, $2, DEFAULT), ($3, $4, DEFAULT), ($5, $6, DEFAULT), ($7, $8, DEFAULT) RETURNING link.id AS "link.id";`) res, err := insertQuery.Execute(db) assert.NilError(t, err) @@ -68,9 +69,10 @@ func TestInsertDataObject(t *testing.T) { INSERT(table.Link.URL, table.Link.Name). VALUES_MAPPING(linkData) - queryStr, err := query.String() + queryStr, args, err := query.Sql() assert.NilError(t, err) + assert.Equal(t, len(args), 2) fmt.Println(queryStr) @@ -92,9 +94,10 @@ func TestInsertQuery(t *testing.T) { INSERT(table.Link.URL, table.Link.Name). QUERY(table.Link.SELECT(table.Link.URL, table.Link.Name)) - queryStr, err := query.String() + queryStr, args, err := query.Sql() assert.NilError(t, err) + assert.Equal(t, len(args), 0) fmt.Println(queryStr) diff --git a/tests/sample_test.go b/tests/sample_test.go index f8ad90c..f4ef72d 100644 --- a/tests/sample_test.go +++ b/tests/sample_test.go @@ -12,11 +12,12 @@ import ( func TestUUIDType(t *testing.T) { query := table.AllTypes. SELECT(table.AllTypes.AllColumns). - Where(table.AllTypes.UUID.EqL("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11")) + WHERE(table.AllTypes.UUID.EqL("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11")) - queryStr, err := query.String() + queryStr, args, err := query.Sql() assert.NilError(t, err) + assert.Equal(t, len(args), 1) fmt.Println(queryStr) //assert.Equal(t, queryStr, `SELECT all_types.character AS "all_types.character", all_types.character_varying AS "all_types.character_varying", all_types.text AS "all_types.text", all_types.bytea AS "all_types.bytea", all_types.timestamp_without_time_zone AS "all_types.timestamp_without_time_zone", all_types.timestamp_with_time_zone AS "all_types.timestamp_with_time_zone", all_types.uuid AS "all_types.uuid", all_types.json AS "all_types.json", all_types.jsonb AS "all_types.jsonb" FROM test_sample.all_types WHERE all_types.uuid = 'a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11`) result := model.AllTypes{} @@ -29,11 +30,11 @@ func TestEnumType(t *testing.T) { query := table.Person. SELECT(table.Person.AllColumns) - queryStr, err := query.String() + queryStr, args, err := query.Sql() assert.NilError(t, err) fmt.Println(queryStr) - + assert.Equal(t, len(args), 0) result := []model.Person{} err = query.Query(db, &result) diff --git a/tests/update_test.go b/tests/update_test.go index add2fb3..77b5657 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -23,10 +23,10 @@ func TestUpdateValues(t *testing.T) { SET("Bong", "http://bong.com"). WHERE(table.Link.Name.EqL("Bing")) - queryStr, err := query.String() + queryStr, args, err := query.Sql() assert.NilError(t, err) - + assert.Equal(t, len(args), 3) fmt.Println(queryStr) res, err := query.Execute(db) @@ -38,7 +38,7 @@ func TestUpdateValues(t *testing.T) { links := []model.Link{} err = table.Link.SELECT(table.Link.AllColumns). - Where(table.Link.Name.EqL("Bong")). + WHERE(table.Link.Name.EqL("Bong")). Query(db, &links) assert.NilError(t, err) @@ -63,10 +63,10 @@ func TestUpdateAndReturning(t *testing.T) { WHERE(table.Link.Name.EqL("Ask")). RETURNING(table.Link.AllColumns) - stmtStr, err := stmt.String() + stmtStr, args, err := stmt.Sql() assert.NilError(t, err) - + assert.Equal(t, len(args), 3) fmt.Println(stmtStr) links := []model.Link{}