Statements sql serialization simplified.

This commit is contained in:
zer0sub 2019-05-08 13:47:01 +02:00
parent d9bbec3795
commit 439c9f1ef9
26 changed files with 274 additions and 284 deletions

View file

@ -12,15 +12,15 @@ func NewAlias(expression expression, alias string) *Alias {
} }
} }
func (a *Alias) serializeForProjection(out *queryData) error { func (a *Alias) serializeForProjection(statement statementType, out *queryData) error {
err := a.expression.serializeForProjection(out) err := a.expression.serializeForProjection(statement, out)
if err != nil { if err != nil {
return err return err
} }
out.WriteString(" AS \"" + a.alias + "\"") out.writeString(" AS \"" + a.alias + "\"")
return nil return nil
} }

View file

@ -9,7 +9,7 @@ func TestBinaryExpression(t *testing.T) {
boolExpression := Eq(Literal(2), Literal(3)) boolExpression := Eq(Literal(2), Literal(3))
out := queryData{} out := queryData{}
err := boolExpression.serialize(&out) err := boolExpression.serialize(select_statement, &out)
assert.NilError(t, err) assert.NilError(t, err)
@ -20,7 +20,7 @@ func TestBinaryExpression(t *testing.T) {
alias := boolExpression.AS("alias_eq_expression") alias := boolExpression.AS("alias_eq_expression")
out := queryData{} out := queryData{}
err := alias.serializeForProjection(&out) err := alias.serializeForProjection(select_statement, &out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.buff.String(), `$1 = $2 AS "alias_eq_expression"`) assert.Equal(t, out.buff.String(), `$1 = $2 AS "alias_eq_expression"`)
@ -30,7 +30,7 @@ func TestBinaryExpression(t *testing.T) {
exp := boolExpression.AND(Eq(Literal(4), Literal(5))) exp := boolExpression.AND(Eq(Literal(4), Literal(5)))
out := queryData{} out := queryData{}
err := exp.serialize(&out) err := exp.serialize(select_statement, &out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.buff.String(), `($1 = $2 AND $3 = $4)`) assert.Equal(t, out.buff.String(), `($1 = $2 AND $3 = $4)`)
@ -40,7 +40,7 @@ func TestBinaryExpression(t *testing.T) {
exp := boolExpression.OR(Eq(Literal(4), Literal(5))) exp := boolExpression.OR(Eq(Literal(4), Literal(5)))
out := queryData{} out := queryData{}
err := exp.serialize(&out) err := exp.serialize(select_statement, &out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.buff.String(), `($1 = $2 OR $3 = $4)`) assert.Equal(t, out.buff.String(), `($1 = $2 OR $3 = $4)`)
@ -51,7 +51,7 @@ func TestUnaryExpression(t *testing.T) {
notExpression := Not(Eq(Literal(2), Literal(1))) notExpression := Not(Eq(Literal(2), Literal(1)))
out := queryData{} out := queryData{}
err := notExpression.serialize(&out) err := notExpression.serialize(select_statement, &out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.buff.String(), "NOT $1 = $2") assert.Equal(t, out.buff.String(), "NOT $1 = $2")
@ -60,7 +60,7 @@ func TestUnaryExpression(t *testing.T) {
alias := notExpression.AS("alias_not_expression") alias := notExpression.AS("alias_not_expression")
out := queryData{} out := queryData{}
err := alias.serializeForProjection(&out) err := alias.serializeForProjection(select_statement, &out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.buff.String(), `NOT $1 = $2 AS "alias_not_expression"`) assert.Equal(t, out.buff.String(), `NOT $1 = $2 AS "alias_not_expression"`)
@ -70,7 +70,7 @@ func TestUnaryExpression(t *testing.T) {
exp := notExpression.AND(Eq(Literal(4), Literal(5))) exp := notExpression.AND(Eq(Literal(4), Literal(5)))
out := queryData{} out := queryData{}
err := exp.serialize(&out) err := exp.serialize(select_statement, &out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.buff.String(), `(NOT $1 = $2 AND $3 = $4)`) assert.Equal(t, out.buff.String(), `(NOT $1 = $2 AND $3 = $4)`)
@ -81,7 +81,7 @@ func TestUnaryIsTrueExpression(t *testing.T) {
notExpression := IsTrue(Eq(Literal(2), Literal(1))) notExpression := IsTrue(Eq(Literal(2), Literal(1)))
out := queryData{} out := queryData{}
err := notExpression.serialize(&out) err := notExpression.serialize(select_statement, &out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.buff.String(), "IS TRUE $1 = $2") assert.Equal(t, out.buff.String(), "IS TRUE $1 = $2")
@ -90,7 +90,7 @@ func TestUnaryIsTrueExpression(t *testing.T) {
exp := notExpression.AND(Eq(Literal(4), Literal(5))) exp := notExpression.AND(Eq(Literal(4), Literal(5)))
out := queryData{} out := queryData{}
err := exp.serialize(&out) err := exp.serialize(select_statement, &out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.buff.String(), `(IS TRUE $1 = $2 AND $3 = $4)`) assert.Equal(t, out.buff.String(), `(IS TRUE $1 = $2 AND $3 = $4)`)
@ -101,7 +101,7 @@ func TestBoolLiteral(t *testing.T) {
literal := newBoolLiteralExpression(true) literal := newBoolLiteralExpression(true)
out := queryData{} out := queryData{}
err := literal.serialize(&out) err := literal.serialize(select_statement, &out)
assert.NilError(t, err) assert.NilError(t, err)
@ -116,7 +116,7 @@ func TestExists(t *testing.T) {
) )
out := queryData{} out := queryData{}
err := query.serialize(&out) err := query.serialize(select_statement, &out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.buff.String(), "EXISTS (SELECT $1 FROM db.table2 WHERE table1.col1 = table2.col3)") assert.Equal(t, out.buff.String(), "EXISTS (SELECT $1 FROM db.table2 WHERE table1.col1 = table2.col3)")
@ -126,7 +126,7 @@ func TestIn(t *testing.T) {
query := Literal(1.11).IN(table1.SELECT(table1Col1)) query := Literal(1.11).IN(table1.SELECT(table1Col1))
out := queryData{} out := queryData{}
err := query.serialize(&out) err := query.serialize(select_statement, &out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.buff.String(), `$1 IN (SELECT table1.col1 AS "table1.col1" FROM db.table1)`) assert.Equal(t, out.buff.String(), `$1 IN (SELECT table1.col1 AS "table1.col1" FROM db.table1)`)
@ -134,7 +134,7 @@ func TestIn(t *testing.T) {
query2 := ROW(Literal(12), table1Col1).IN(table2.SELECT(table2Col3, table3Col1)) query2 := ROW(Literal(12), table1Col1).IN(table2.SELECT(table2Col3, table3Col1))
out = queryData{} out = queryData{}
err = query2.serialize(&out) err = query2.serialize(select_statement, &out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.buff.String(), `(ROW($1, table1.col1) IN (SELECT table2.col3 AS "table2.col3", table3.col1 AS "table3.col1" FROM db.table2))`) assert.Equal(t, out.buff.String(), `(ROW($1, table1.col1) IN (SELECT table2.col3 AS "table2.col3", table3.col1 AS "table3.col1" FROM db.table2))`)

View file

@ -7,69 +7,70 @@ import (
) )
type clause interface { type clause interface {
serialize(out *queryData) error serialize(statement statementType, out *queryData) error
} }
type queryData struct { type queryData struct {
buff bytes.Buffer buff bytes.Buffer
args []interface{} args []interface{}
statementType int
} }
type statementType string
const ( const (
select_statement = iota select_statement statementType = "SELECT"
insert_statement insert_statement statementType = "INSERT"
update_statement update_statement statementType = "UPDATE"
delete_statement delete_statement statementType = "DELETE"
set_statement set_statement statementType = "SET"
lock_statement statementType = "LOCK"
) )
func (q *queryData) WriteProjection(projections []projection) error { func (q *queryData) writeProjection(statement statementType, projections []projection) error {
return serializeProjectionList(projections, q) return serializeProjectionList(statement, projections, q)
} }
func (q *queryData) WriteWhere(where expression) error { func (q *queryData) writeWhere(statement statementType, where expression) error {
q.WriteString(" WHERE ") q.writeString(" WHERE ")
return where.serialize(q) return where.serialize(statement, q)
} }
func (q *queryData) WriteGroupBy(groupBy []groupByClause) error { func (q *queryData) writeGroupBy(statement statementType, groupBy []groupByClause) error {
q.WriteString(" GROUP BY ") q.writeString(" GROUP BY ")
return serializeGroupByClauseList(groupBy, q) return serializeGroupByClauseList(statement, groupBy, q)
} }
func (q *queryData) WriteOrderBy(orderBy []orderByClause) error { func (q *queryData) writeOrderBy(statement statementType, orderBy []orderByClause) error {
q.WriteString(" ORDER BY ") q.writeString(" ORDER BY ")
return serializeOrderByClauseList(orderBy, q) return serializeOrderByClauseList(statement, orderBy, q)
} }
func (q *queryData) WriteHaving(having expression) error { func (q *queryData) writeHaving(statement statementType, having expression) error {
q.WriteString(" HAVING ") q.writeString(" HAVING ")
return having.serialize(q) return having.serialize(statement, q)
} }
func (q *queryData) Write(data []byte) { func (q *queryData) write(data []byte) {
q.buff.Write(data) q.buff.Write(data)
} }
func (q *queryData) WriteString(str string) { func (q *queryData) writeString(str string) {
q.buff.WriteString(str) q.buff.WriteString(str)
} }
func (q *queryData) WriteByte(b byte) { func (q *queryData) writeByte(b byte) {
q.buff.WriteByte(b) q.buff.WriteByte(b)
} }
func (q *queryData) InsertArgument(arg interface{}) { func (q *queryData) insertArgument(arg interface{}) {
q.args = append(q.args, arg) q.args = append(q.args, arg)
argPlaceholder := "$" + strconv.Itoa(len(q.args)) argPlaceholder := "$" + strconv.Itoa(len(q.args))
q.buff.WriteString(argPlaceholder) q.buff.WriteString(argPlaceholder)
} }
func (q *queryData) Reset() { func (q *queryData) reset() {
q.buff.Reset() q.buff.Reset()
q.args = []interface{}{} q.args = []interface{}{}
} }

View file

@ -77,42 +77,42 @@ func (c *baseColumn) DefaultAlias() projection {
return c.AS(c.tableName + "." + c.name) return c.AS(c.tableName + "." + c.name)
} }
func (c *baseColumn) serializeAsOrderBy(out *queryData) error { func (c *baseColumn) serializeAsOrderBy(statement statementType, out *queryData) error {
if out.statementType == set_statement { if statement == set_statement {
// set statement (UNION, EXCEPT ...) can reference only select projections in order by clause // set statement (UNION, EXCEPT ...) can reference only select projections in order by clause
out.WriteString(`"`) out.writeString(`"`)
if c.tableName != "" { if c.tableName != "" {
out.WriteString(c.tableName) out.writeString(c.tableName)
out.WriteString(".") out.writeString(".")
} }
out.WriteString(c.name) out.writeString(c.name)
out.WriteString(`"`) out.writeString(`"`)
return nil return nil
} }
return c.serialize(out) return c.serialize(statement, out)
} }
func (c baseColumn) serialize(out *queryData) error { func (c baseColumn) serialize(statement statementType, out *queryData) error {
if c.tableName != "" { if c.tableName != "" {
out.WriteString(c.tableName) out.writeString(c.tableName)
out.WriteString(".") out.writeString(".")
} }
wrapColumnName := strings.Contains(c.name, ".") wrapColumnName := strings.Contains(c.name, ".")
if wrapColumnName { if wrapColumnName {
out.WriteString(`"`) out.writeString(`"`)
} }
out.WriteString(c.name) out.writeString(c.name)
if wrapColumnName { if wrapColumnName {
out.WriteString(`"`) out.writeString(`"`)
} }
return nil return nil

View file

@ -9,26 +9,26 @@ func TestNewBoolColumn(t *testing.T) {
boolColumn := NewBoolColumn("col", Nullable) boolColumn := NewBoolColumn("col", Nullable)
out := queryData{} out := queryData{}
err := boolColumn.serialize(&out) err := boolColumn.serialize(select_statement, &out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.buff.String(), "col") assert.Equal(t, out.buff.String(), "col")
out.Reset() out.reset()
err = boolColumn.serialize(&out) err = boolColumn.serialize(select_statement, &out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.buff.String(), "col") assert.Equal(t, out.buff.String(), "col")
out.Reset() out.reset()
boolColumn.setTableName("table1") boolColumn.setTableName("table1")
err = boolColumn.DefaultAlias().serializeForProjection(&out) err = boolColumn.DefaultAlias().serializeForProjection(select_statement, &out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.buff.String(), `table1.col AS "table1.col"`) assert.Equal(t, out.buff.String(), `table1.col AS "table1.col"`)
out.Reset() out.reset()
boolColumn.setTableName("table1") boolColumn.setTableName("table1")
aliasedBoolColumn := boolColumn.AS("alias1") aliasedBoolColumn := boolColumn.AS("alias1")
err = aliasedBoolColumn.serializeForProjection(&out) err = aliasedBoolColumn.serializeForProjection(select_statement, &out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.buff.String(), `table1.col AS "alias1"`) assert.Equal(t, out.buff.String(), `table1.col AS "alias1"`)
} }
@ -37,26 +37,26 @@ func TestNewIntColumn(t *testing.T) {
integerColumn := NewIntegerColumn("col", Nullable) integerColumn := NewIntegerColumn("col", Nullable)
out := queryData{} out := queryData{}
err := integerColumn.serialize(&out) err := integerColumn.serialize(select_statement, &out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.buff.String(), "col") assert.Equal(t, out.buff.String(), "col")
out.Reset() out.reset()
err = integerColumn.serialize(&out) err = integerColumn.serialize(select_statement, &out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.buff.String(), "col") assert.Equal(t, out.buff.String(), "col")
out.Reset() out.reset()
integerColumn.setTableName("table1") integerColumn.setTableName("table1")
err = integerColumn.DefaultAlias().serializeForProjection(&out) err = integerColumn.DefaultAlias().serializeForProjection(select_statement, &out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.buff.String(), `table1.col AS "table1.col"`) assert.Equal(t, out.buff.String(), `table1.col AS "table1.col"`)
out.Reset() out.reset()
integerColumn.setTableName("table1") integerColumn.setTableName("table1")
aliasedBoolColumn := integerColumn.AS("alias1") aliasedBoolColumn := integerColumn.AS("alias1")
err = aliasedBoolColumn.serializeForProjection(&out) err = aliasedBoolColumn.serializeForProjection(select_statement, &out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.buff.String(), `table1.col AS "alias1"`) assert.Equal(t, out.buff.String(), `table1.col AS "alias1"`)
} }
@ -65,26 +65,26 @@ func TestNewNumericColumnColumn(t *testing.T) {
numericColumn := NewNumericColumn("col", Nullable) numericColumn := NewNumericColumn("col", Nullable)
out := queryData{} out := queryData{}
err := numericColumn.serialize(&out) err := numericColumn.serialize(select_statement, &out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.buff.String(), "col") assert.Equal(t, out.buff.String(), "col")
out.Reset() out.reset()
err = numericColumn.serialize(&out) err = numericColumn.serialize(select_statement, &out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.buff.String(), "col") assert.Equal(t, out.buff.String(), "col")
out.Reset() out.reset()
numericColumn.setTableName("table1") numericColumn.setTableName("table1")
err = numericColumn.DefaultAlias().serializeForProjection(&out) err = numericColumn.DefaultAlias().serializeForProjection(select_statement, &out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.buff.String(), `table1.col AS "table1.col"`) assert.Equal(t, out.buff.String(), `table1.col AS "table1.col"`)
out.Reset() out.reset()
numericColumn.setTableName("table1") numericColumn.setTableName("table1")
aliasedBoolColumn := numericColumn.AS("alias1") aliasedBoolColumn := numericColumn.AS("alias1")
err = aliasedBoolColumn.serializeForProjection(&out) err = aliasedBoolColumn.serializeForProjection(select_statement, &out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.buff.String(), `table1.col AS "alias1"`) assert.Equal(t, out.buff.String(), `table1.col AS "alias1"`)
} }

View file

@ -30,15 +30,14 @@ func (d *deleteStatementImpl) WHERE(expression boolExpression) deleteStatement {
func (d *deleteStatementImpl) Sql() (query string, args []interface{}, err error) { func (d *deleteStatementImpl) Sql() (query string, args []interface{}, err error) {
queryData := &queryData{} queryData := &queryData{}
queryData.statementType = delete_statement
queryData.WriteString("DELETE FROM ") queryData.writeString("DELETE FROM ")
if d.table == nil { if d.table == nil {
return "", nil, errors.New("nil tableName.") return "", nil, errors.New("nil tableName.")
} }
if err = d.table.serializeSql(queryData); err != nil { if err = d.table.serialize(delete_statement, queryData); err != nil {
return return
} }
@ -46,7 +45,7 @@ func (d *deleteStatementImpl) Sql() (query string, args []interface{}, err error
return "", nil, errors.New("Deleting without a WHERE clause.") return "", nil, errors.New("Deleting without a WHERE clause.")
} }
if err = queryData.WriteWhere(d.where); err != nil { if err = queryData.writeWhere(delete_statement, d.where); err != nil {
return return
} }

View file

@ -53,16 +53,16 @@ func (e *expressionInterfaceImpl) DESC() orderByClause {
return &orderByClauseImpl{expression: e.parent, ascent: false} return &orderByClauseImpl{expression: e.parent, ascent: false}
} }
func (e *expressionInterfaceImpl) serializeForGroupBy(out *queryData) error { func (e *expressionInterfaceImpl) serializeForGroupBy(statement statementType, out *queryData) error {
return e.parent.serialize(out) return e.parent.serialize(statement, out)
} }
func (e *expressionInterfaceImpl) serializeForProjection(out *queryData) error { func (e *expressionInterfaceImpl) serializeForProjection(statement statementType, out *queryData) error {
return e.parent.serialize(out) return e.parent.serialize(statement, out)
} }
func (e *expressionInterfaceImpl) serializeAsOrderBy(out *queryData) error { func (e *expressionInterfaceImpl) serializeAsOrderBy(statement statementType, out *queryData) error {
return e.parent.serialize(out) return e.parent.serialize(statement, out)
} }
// Representation of binary operations (e.g. comparisons, arithmetic) // Representation of binary operations (e.g. comparisons, arithmetic)
@ -95,7 +95,7 @@ func isSimpleOperand(expression expression) bool {
return false return false
} }
func (c *binaryExpression) serialize(out *queryData) error { func (c *binaryExpression) serialize(statement statementType, out *queryData) error {
if c.lhs == nil { if c.lhs == nil {
return errors.Newf("nil lhs.") return errors.Newf("nil lhs.")
} }
@ -106,21 +106,21 @@ func (c *binaryExpression) serialize(out *queryData) error {
wrap := !isSimpleOperand(c.lhs) && !isSimpleOperand(c.rhs) wrap := !isSimpleOperand(c.lhs) && !isSimpleOperand(c.rhs)
if wrap { if wrap {
out.WriteString("(") out.writeString("(")
} }
if err := c.lhs.serialize(out); err != nil { if err := c.lhs.serialize(statement, out); err != nil {
return err return err
} }
out.WriteString(" " + c.operator + " ") out.writeString(" " + c.operator + " ")
if err := c.rhs.serialize(out); err != nil { if err := c.rhs.serialize(statement, out); err != nil {
return err return err
} }
if wrap { if wrap {
out.WriteString(")") out.writeString(")")
} }
return nil return nil
@ -141,13 +141,13 @@ func newPrefixExpression(expression expression, operator string) prefixExpressio
return prefixExpression return prefixExpression
} }
func (p *prefixExpression) serialize(out *queryData) error { func (p *prefixExpression) serialize(statement statementType, out *queryData) error {
out.WriteString(p.operator + " ") out.writeString(p.operator + " ")
if p.expression == nil { if p.expression == nil {
return errors.Newf("nil prefix expression.") return errors.Newf("nil prefix expression.")
} }
if err := p.expression.serialize(out); err != nil { if err := p.expression.serialize(statement, out); err != nil {
return err return err
} }

View file

@ -14,14 +14,14 @@ type intervalExpression struct {
const intervalSep = ":" const intervalSep = ":"
func (c *intervalExpression) serialize(out *queryData) error { func (c *intervalExpression) serialize(statement statementType, out *queryData) error {
out.WriteString("INTERVAL '") out.writeString("INTERVAL '")
duration := c.duration duration := c.duration
if duration < 0 { if duration < 0 {
duration = -duration duration = -duration
out.WriteString("-") out.writeString("-")
} }
hours := duration / time.Hour hours := duration / time.Hour
@ -29,14 +29,14 @@ func (c *intervalExpression) serialize(out *queryData) error {
sec := (duration % time.Minute) / time.Second sec := (duration % time.Minute) / time.Second
msec := (duration % time.Second) / time.Microsecond msec := (duration % time.Second) / time.Microsecond
out.WriteString(strconv.FormatInt(int64(hours), 10)) out.writeString(strconv.FormatInt(int64(hours), 10))
out.WriteString(intervalSep) out.writeString(intervalSep)
out.WriteString(strconv.FormatInt(int64(minutes), 10)) out.writeString(strconv.FormatInt(int64(minutes), 10))
out.WriteString(intervalSep) out.writeString(intervalSep)
out.WriteString(strconv.FormatInt(int64(sec), 10)) out.writeString(strconv.FormatInt(int64(sec), 10))
out.WriteString(intervalSep) out.writeString(intervalSep)
out.WriteString(strconv.FormatInt(int64(msec), 10)) out.writeString(strconv.FormatInt(int64(msec), 10))
out.WriteString("' HOUR_MICROSECOND") out.writeString("' HOUR_MICROSECOND")
return nil return nil
} }

View file

@ -19,7 +19,7 @@ func (s *ExprSuite) TestConjunctExprEmptyList(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := expr.serialize(buf) err := expr.serialize(0, buf)
c.Assert(err, gc.NotNil) c.Assert(err, gc.NotNil)
} }
@ -28,7 +28,7 @@ func (s *ExprSuite) TestConjunctExprNilInList(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := expr.serialize(buf) err := expr.serialize(0, buf)
c.Assert(err, gc.NotNil) c.Assert(err, gc.NotNil)
} }
@ -37,7 +37,7 @@ func (s *ExprSuite) TestConjunctExprSingleElement(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := expr.serialize(buf) err := expr.serialize(0, buf)
c.Assert(err, gc.IsNil) c.Assert(err, gc.IsNil)
sql := buf.String() sql := buf.String()
@ -49,7 +49,7 @@ func (s *ExprSuite) TestLikeExpr(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := expr.serialize(buf) err := expr.serialize(0, buf)
c.Assert(err, gc.IsNil) c.Assert(err, gc.IsNil)
sql := buf.String() sql := buf.String()
@ -65,7 +65,7 @@ func (s *ExprSuite) TestRegexExpr(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := expr.serialize(buf) err := expr.serialize(0, buf)
c.Assert(err, gc.IsNil) c.Assert(err, gc.IsNil)
sql := buf.String() sql := buf.String()
@ -81,7 +81,7 @@ func (s *ExprSuite) TestAndExpr(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := expr.serialize(buf) err := expr.serialize(0, buf)
c.Assert(err, gc.IsNil) c.Assert(err, gc.IsNil)
sql := buf.String() sql := buf.String()
@ -96,7 +96,7 @@ func (s *ExprSuite) TestOrExpr(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := expr.serialize(buf) err := expr.serialize(0, buf)
c.Assert(err, gc.IsNil) c.Assert(err, gc.IsNil)
sql := buf.String() sql := buf.String()
@ -159,7 +159,7 @@ func (s *ExprSuite) TestBinaryExprNilLHS(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := expr.serialize(buf) err := expr.serialize(0, buf)
c.Assert(err, gc.NotNil) c.Assert(err, gc.NotNil)
} }
@ -168,7 +168,7 @@ func (s *ExprSuite) TestNegateExpr(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := expr.serialize(buf) err := expr.serialize(0, buf)
c.Assert(err, gc.IsNil) c.Assert(err, gc.IsNil)
sql := buf.String() sql := buf.String()
@ -180,7 +180,7 @@ func (s *ExprSuite) TestBinaryExprNilRHS(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := expr.serialize(buf) err := expr.serialize(0, buf)
c.Assert(err, gc.NotNil) c.Assert(err, gc.NotNil)
} }
@ -189,7 +189,7 @@ func (s *ExprSuite) TestEqExpr(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := expr.serialize(buf) err := expr.serialize(0, buf)
c.Assert(err, gc.IsNil) c.Assert(err, gc.IsNil)
sql := buf.String() sql := buf.String()
@ -201,7 +201,7 @@ func (s *ExprSuite) TestEqExprNilLHS(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := expr.serialize(buf) err := expr.serialize(0, buf)
c.Assert(err, gc.IsNil) c.Assert(err, gc.IsNil)
sql := buf.String() sql := buf.String()
@ -213,7 +213,7 @@ func (s *ExprSuite) TestNeqExpr(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := expr.serialize(buf) err := expr.serialize(0, buf)
c.Assert(err, gc.IsNil) c.Assert(err, gc.IsNil)
sql := buf.String() sql := buf.String()
@ -225,7 +225,7 @@ func (s *ExprSuite) TestNeqExprNilLHS(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := expr.serialize(buf) err := expr.serialize(0, buf)
c.Assert(err, gc.IsNil) c.Assert(err, gc.IsNil)
sql := buf.String() sql := buf.String()
@ -237,7 +237,7 @@ func (s *ExprSuite) TestLtExpr(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := expr.serialize(buf) err := expr.serialize(0, buf)
c.Assert(err, gc.IsNil) c.Assert(err, gc.IsNil)
sql := buf.String() sql := buf.String()
@ -249,7 +249,7 @@ func (s *ExprSuite) TestLteExpr(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := expr.serialize(buf) err := expr.serialize(0, buf)
c.Assert(err, gc.IsNil) c.Assert(err, gc.IsNil)
sql := buf.String() sql := buf.String()
@ -264,7 +264,7 @@ func (s *ExprSuite) TestGtExpr(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := expr.serialize(buf) err := expr.serialize(0, buf)
c.Assert(err, gc.IsNil) c.Assert(err, gc.IsNil)
sql := buf.String() sql := buf.String()
@ -276,7 +276,7 @@ func (s *ExprSuite) TestGteExpr(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := expr.serialize(buf) err := expr.serialize(0, buf)
c.Assert(err, gc.IsNil) c.Assert(err, gc.IsNil)
sql := buf.String() sql := buf.String()

View file

@ -47,16 +47,16 @@ func (s *expressionTableImpl) RefStringColumn(column column) *StringColumn {
return strColumn return strColumn
} }
func (s *expressionTableImpl) serializeSql(out *queryData) error { func (s *expressionTableImpl) serialize(statement statementType, out *queryData) error {
out.WriteString("( ") out.writeString("( ")
err := s.statement.serialize(out) err := s.statement.serialize(statement, out)
if err != nil { if err != nil {
return err return err
} }
out.WriteString(" ) AS ") out.writeString(" ) AS ")
out.WriteString(s.alias) out.writeString(s.alias)
return nil return nil
} }

View file

@ -28,14 +28,14 @@ func newFunc(name string, expressions []expression, parent expression) *funcExpr
return funcExp return funcExp
} }
func (f *funcExpressionImpl) serialize(out *queryData) error { func (f *funcExpressionImpl) serialize(statement statementType, out *queryData) error {
out.WriteString(f.name) out.writeString(f.name)
out.WriteString("(") out.writeString("(")
err := serializeExpressionList(f.expression, ", ", out) err := serializeExpressionList(statement, f.expression, ", ", out)
if err != nil { if err != nil {
return err return err
} }
out.WriteString(")") out.writeString(")")
return nil return nil
} }
@ -54,10 +54,6 @@ func NewNumericFunc(name string, expressions ...expression) numericExpression {
return numericFunc return numericFunc
} }
//func (f *FuncExpression) SerializeSqlForColumnList(out *bytes.Buffer) error {
// return f.serialize(out)
//}
func MAX(expression numericExpression) numericExpression { func MAX(expression numericExpression) numericExpression {
return NewNumericFunc("MAX", expression) return NewNumericFunc("MAX", expression)
} }
@ -111,12 +107,12 @@ func (c *caseExpression) ELSE(els expression) caseInterface {
return c return c
} }
func (c *caseExpression) serialize(out *queryData) error { func (c *caseExpression) serialize(statement statementType, out *queryData) error {
out.WriteString("(CASE") out.writeString("(CASE")
if c.expression != nil { if c.expression != nil {
out.WriteString(" ") out.writeString(" ")
err := c.expression.serialize(out) err := c.expression.serialize(statement, out)
if err != nil { if err != nil {
return err return err
@ -132,15 +128,15 @@ func (c *caseExpression) serialize(out *queryData) error {
} }
for i, when := range c.when { for i, when := range c.when {
out.WriteString(" WHEN ") out.writeString(" WHEN ")
err := when.serialize(out) err := when.serialize(statement, out)
if err != nil { if err != nil {
return err return err
} }
out.WriteString(" THEN ") out.writeString(" THEN ")
err = c.then[i].serialize(out) err = c.then[i].serialize(statement, out)
if err != nil { if err != nil {
return err return err
@ -148,15 +144,15 @@ func (c *caseExpression) serialize(out *queryData) error {
} }
if c.els != nil { if c.els != nil {
out.WriteString(" ELSE ") out.writeString(" ELSE ")
err := c.els.serialize(out) err := c.els.serialize(statement, out)
if err != nil { if err != nil {
return err return err
} }
} }
out.WriteString(" END)") out.writeString(" END)")
return nil return nil
} }

View file

@ -12,7 +12,7 @@ func TestCase1(t *testing.T) {
queryData := &queryData{} queryData := &queryData{}
err := query.serialize(queryData) err := query.serialize(select_statement, queryData)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, queryData.buff.String(), `(CASE WHEN table3.col1 = $1 THEN table3.col1 + $2 WHEN table3.col1 = $3 THEN table3.col1 + $4 END)`) assert.Equal(t, queryData.buff.String(), `(CASE WHEN table3.col1 = $1 THEN table3.col1 + $2 WHEN table3.col1 = $3 THEN table3.col1 + $4 END)`)
@ -26,7 +26,7 @@ func TestCase2(t *testing.T) {
queryData := &queryData{} queryData := &queryData{}
err := query.serialize(queryData) err := query.serialize(select_statement, queryData)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, queryData.buff.String(), `(CASE table3.col1 WHEN $1 THEN table3.col1 + $2 WHEN $3 THEN table3.col1 + $4 ELSE $5 END)`) assert.Equal(t, queryData.buff.String(), `(CASE table3.col1 WHEN $1 THEN table3.col1 + $2 WHEN $3 THEN table3.col1 + $4 ELSE $5 END)`)
@ -37,7 +37,7 @@ func TestInterval(t *testing.T) {
queryData := &queryData{} queryData := &queryData{}
err := query.serialize(queryData) err := query.serialize(select_statement, queryData)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, queryData.buff.String(), `INTERVAL $1`) assert.Equal(t, queryData.buff.String(), `INTERVAL $1`)

View file

@ -1,7 +1,7 @@
package sqlbuilder package sqlbuilder
type groupByClause interface { type groupByClause interface {
serializeForGroupBy(out *queryData) error serializeForGroupBy(statement statementType, out *queryData) error
} }
// TODO: GROUPING SETS, CUBE, and ROLLUP // TODO: GROUPING SETS, CUBE, and ROLLUP

View file

@ -127,29 +127,28 @@ func (s *insertStatementImpl) Sql() (sql string, args []interface{}, err error)
} }
queryData := &queryData{} queryData := &queryData{}
queryData.statementType = insert_statement queryData.writeString("INSERT INTO ")
queryData.WriteString("INSERT INTO ")
if s.table == nil { if s.table == nil {
return "", nil, errors.Newf("nil tableName.") return "", nil, errors.Newf("nil tableName.")
} }
err = s.table.serializeSql(queryData) err = s.table.serialize(insert_statement, queryData)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }
if len(s.columns) > 0 { if len(s.columns) > 0 {
queryData.WriteString(" (") queryData.writeString(" (")
err = serializeColumnList(s.columns, queryData) err = serializeColumnList(insert_statement, s.columns, queryData)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }
queryData.WriteString(") ") queryData.writeString(") ")
} }
if len(s.rows) == 0 && s.query == nil { if len(s.rows) == 0 && s.query == nil {
@ -161,28 +160,28 @@ func (s *insertStatementImpl) Sql() (sql string, args []interface{}, err error)
} }
if len(s.rows) > 0 { if len(s.rows) > 0 {
queryData.WriteString("VALUES (") queryData.writeString("VALUES (")
for row_i, row := range s.rows { for row_i, row := range s.rows {
if row_i > 0 { if row_i > 0 {
queryData.WriteString(", (") queryData.writeString(", (")
} }
if len(row) != len(s.columns) { if len(row) != len(s.columns) {
return "", nil, errors.New("# of values does not match # of columns.") return "", nil, errors.New("# of values does not match # of columns.")
} }
err = serializeClauseList(row, queryData) err = serializeClauseList(insert_statement, row, queryData)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }
queryData.WriteByte(')') queryData.writeByte(')')
} }
} }
if s.query != nil { if s.query != nil {
err = s.query.serialize(queryData) err = s.query.serialize(insert_statement, queryData)
if err != nil { if err != nil {
return return
@ -190,16 +189,16 @@ func (s *insertStatementImpl) Sql() (sql string, args []interface{}, err error)
} }
if len(s.returning) > 0 { if len(s.returning) > 0 {
queryData.WriteString(" RETURNING ") queryData.writeString(" RETURNING ")
err = queryData.WriteProjection(s.returning) err = queryData.writeProjection(insert_statement, s.returning)
if err != nil { if err != nil {
return return
} }
} }
queryData.WriteByte(';') queryData.writeByte(';')
return queryData.buff.String(), queryData.args, nil return queryData.buff.String(), queryData.args, nil
} }

View file

@ -6,8 +6,8 @@ const (
type keywordClause string type keywordClause string
func (k keywordClause) serialize(out *queryData) error { func (k keywordClause) serialize(statement statementType, out *queryData) error {
out.WriteString(string(k)) out.writeString(string(k))
return nil return nil
} }

View file

@ -13,8 +13,8 @@ func Literal(value interface{}) *literalExpression {
return &exp return &exp
} }
func (l literalExpression) serialize(out *queryData) error { func (l literalExpression) serialize(statement statementType, out *queryData) error {
out.InsertArgument(l.value) out.insertArgument(l.value)
return nil return nil
} }

View file

@ -59,14 +59,14 @@ func (l *lockStatementImpl) Sql() (query string, args []interface{}, err error)
out := &queryData{} out := &queryData{}
out.WriteString("LOCK TABLE ") out.writeString("LOCK TABLE ")
for i, table := range l.tables { for i, table := range l.tables {
if i > 0 { if i > 0 {
out.WriteString(", ") out.writeString(", ")
} }
err := table.serializeSql(out) err := table.serialize(lock_statement, out)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
@ -74,13 +74,13 @@ func (l *lockStatementImpl) Sql() (query string, args []interface{}, err error)
} }
if l.lockMode != "" { if l.lockMode != "" {
out.WriteString(" IN ") out.writeString(" IN ")
out.WriteString(string(l.lockMode)) out.writeString(string(l.lockMode))
out.WriteString(" MODE") out.writeString(" MODE")
} }
if l.nowait { if l.nowait {
out.WriteString(" NOWAIT") out.writeString(" NOWAIT")
} }
return out.buff.String(), out.args, nil return out.buff.String(), out.args, nil

View file

@ -130,10 +130,10 @@ func newNumericExpressionWrap(expression expression) numericExpression {
return &numericExpressionWrap return &numericExpressionWrap
} }
func (c *numericExpressionWrapper) serialize(out *queryData) error { func (c *numericExpressionWrapper) serialize(statement statementType, out *queryData) error {
out.WriteString("(") out.writeString("(")
err := c.expression.serialize(out) err := c.expression.serialize(statement, out)
out.WriteString(")") out.writeString(")")
return err return err
} }

View file

@ -3,7 +3,7 @@ package sqlbuilder
import "github.com/dropbox/godropbox/errors" import "github.com/dropbox/godropbox/errors"
type orderByClause interface { type orderByClause interface {
serializeAsOrderBy(out *queryData) error serializeAsOrderBy(statement statementType, out *queryData) error
} }
type orderByClauseImpl struct { type orderByClauseImpl struct {
@ -11,19 +11,19 @@ type orderByClauseImpl struct {
ascent bool ascent bool
} }
func (o *orderByClauseImpl) serializeAsOrderBy(out *queryData) error { func (o *orderByClauseImpl) serializeAsOrderBy(statement statementType, out *queryData) error {
if o.expression == nil { if o.expression == nil {
return errors.Newf("nil orderBy by clause.") return errors.Newf("nil orderBy by clause.")
} }
if err := o.expression.serializeAsOrderBy(out); err != nil { if err := o.expression.serializeAsOrderBy(statement, out); err != nil {
return err return err
} }
if o.ascent { if o.ascent {
out.WriteString(" ASC") out.writeString(" ASC")
} else { } else {
out.WriteString(" DESC") out.writeString(" DESC")
} }
return nil return nil

View file

@ -1,7 +1,7 @@
package sqlbuilder package sqlbuilder
type projection interface { type projection interface {
serializeForProjection(out *queryData) error serializeForProjection(statement statementType, out *queryData) error
} }
//------------------------------------------------------// //------------------------------------------------------//
@ -10,16 +10,16 @@ type ColumnList []column
func (cl ColumnList) isProjectionType() {} func (cl ColumnList) isProjectionType() {}
func (cl ColumnList) serializeForProjection(out *queryData) error { func (cl ColumnList) serializeForProjection(statement statementType, out *queryData) error {
for i, column := range cl { for i, column := range cl {
err := column.serializeForProjection(out) err := column.serializeForProjection(statement, out)
if err != nil { if err != nil {
return err return err
} }
if i != len(cl)-1 { if i != len(cl)-1 {
out.WriteString(", ") out.writeString(", ")
} }
} }
return nil return nil

View file

@ -25,7 +25,7 @@ type selectStatement interface {
AsTable(alias string) expressionTable AsTable(alias string) expressionTable
} }
var SELECT = func(projection ...projection) selectStatement { func SELECT(projection ...projection) selectStatement {
return newSelectStatement(nil, projection) return newSelectStatement(nil, projection)
} }
@ -83,9 +83,9 @@ func (s *selectStatementImpl) FROM(table readableTable) selectStatement {
return s return s
} }
func (s *selectStatementImpl) serialize(out *queryData) error { func (s *selectStatementImpl) serialize(statement statementType, out *queryData) error {
out.WriteString("(") out.writeString("(")
err := s.serializeImpl(out) err := s.serializeImpl(out)
@ -93,42 +93,41 @@ func (s *selectStatementImpl) serialize(out *queryData) error {
return err return err
} }
out.WriteString(")") out.writeString(")")
return nil return nil
} }
func (s *selectStatementImpl) serializeImpl(out *queryData) error { func (s *selectStatementImpl) serializeImpl(out *queryData) error {
out.WriteString("SELECT ") out.writeString("SELECT ")
out.statementType = select_statement
if s.distinct { if s.distinct {
out.WriteString("DISTINCT ") out.writeString("DISTINCT ")
} }
if s.projections == nil || len(s.projections) == 0 { if s.projections == nil || len(s.projections) == 0 {
return errors.New("No column selected for projection.") return errors.New("No column selected for projection.")
} }
err := out.WriteProjection(s.projections) err := out.writeProjection(select_statement, s.projections)
if err != nil { if err != nil {
return err return err
} }
out.WriteString(" FROM ") out.writeString(" FROM ")
if s.table == nil { if s.table == nil {
return errors.Newf("nil tableName.") return errors.Newf("nil tableName.")
} }
if err := s.table.serializeSql(out); err != nil { if err := s.table.serialize(select_statement, out); err != nil {
return err return err
} }
if s.where != nil { if s.where != nil {
err := out.WriteWhere(s.where) err := out.writeWhere(select_statement, s.where)
if err != nil { if err != nil {
return nil return nil
@ -136,7 +135,7 @@ func (s *selectStatementImpl) serializeImpl(out *queryData) error {
} }
if s.groupBy != nil && len(s.groupBy) > 0 { if s.groupBy != nil && len(s.groupBy) > 0 {
err := out.WriteGroupBy(s.groupBy) err := out.writeGroupBy(select_statement, s.groupBy)
if err != nil { if err != nil {
return err return err
@ -144,7 +143,7 @@ func (s *selectStatementImpl) serializeImpl(out *queryData) error {
} }
if s.having != nil { if s.having != nil {
err := out.WriteHaving(s.having) err := out.writeHaving(select_statement, s.having)
if err != nil { if err != nil {
return err return err
@ -152,7 +151,7 @@ func (s *selectStatementImpl) serializeImpl(out *queryData) error {
} }
if s.orderBy != nil { if s.orderBy != nil {
err := out.WriteOrderBy(s.orderBy) err := out.writeOrderBy(select_statement, s.orderBy)
if err != nil { if err != nil {
return err return err
@ -160,17 +159,17 @@ func (s *selectStatementImpl) serializeImpl(out *queryData) error {
} }
if s.limit >= 0 { if s.limit >= 0 {
out.WriteString(" LIMIT ") out.writeString(" LIMIT ")
out.InsertArgument(s.limit) out.insertArgument(s.limit)
} }
if s.offset >= 0 { if s.offset >= 0 {
out.WriteString(" OFFSET ") out.writeString(" OFFSET ")
out.InsertArgument(s.offset) out.insertArgument(s.offset)
} }
if s.forUpdate { if s.forUpdate {
out.WriteString(" FOR UPDATE") out.writeString(" FOR UPDATE")
} }
return nil return nil

View file

@ -96,9 +96,9 @@ func (us *setStatementImpl) AsTable(alias string) expressionTable {
} }
} }
func (s *setStatementImpl) serialize(out *queryData) error { func (s *setStatementImpl) serialize(statement statementType, out *queryData) error {
if s.orderBy != nil || s.limit >= 0 || s.offset >= 0 { if s.orderBy != nil || s.limit >= 0 || s.offset >= 0 {
out.WriteString("(") out.writeString("(")
} }
err := s.serializeImpl(out) err := s.serializeImpl(out)
@ -108,7 +108,7 @@ func (s *setStatementImpl) serialize(out *queryData) error {
} }
if s.orderBy != nil || s.limit >= 0 || s.offset >= 0 { if s.orderBy != nil || s.limit >= 0 || s.offset >= 0 {
out.WriteString(")") out.writeString(")")
} }
return nil return nil
@ -120,43 +120,41 @@ func (s *setStatementImpl) serializeImpl(out *queryData) error {
return errors.Newf("UNION statement must have at least two SELECT statements.") return errors.Newf("UNION statement must have at least two SELECT statements.")
} }
out.WriteString("(") out.writeString("(")
for i, selectStmt := range s.selects { for i, selectStmt := range s.selects {
if i > 0 { if i > 0 {
out.WriteString(" " + s.operator + " ") out.writeString(" " + s.operator + " ")
if s.all { if s.all {
out.WriteString(" ALL ") out.writeString(" ALL ")
} }
} }
err := selectStmt.serialize(out) err := selectStmt.serialize(set_statement, out)
if err != nil { if err != nil {
return err return err
} }
} }
out.WriteString(")") out.writeString(")")
out.statementType = set_statement
if s.orderBy != nil { if s.orderBy != nil {
err := out.WriteOrderBy(s.orderBy) err := out.writeOrderBy(set_statement, s.orderBy)
if err != nil { if err != nil {
return err return err
} }
} }
if s.limit >= 0 { if s.limit >= 0 {
out.WriteString(" LIMIT ") out.writeString(" LIMIT ")
out.InsertArgument(s.limit) out.insertArgument(s.limit)
} }
if s.offset >= 0 { if s.offset >= 0 {
out.WriteString(" OFFSET ") out.writeString(" OFFSET ")
out.InsertArgument(s.offset) out.insertArgument(s.offset)
} }
return nil return nil

View file

@ -7,12 +7,11 @@ import (
) )
type tableInterface interface { type tableInterface interface {
clause
SchemaName() string SchemaName() string
TableName() string TableName() string
Columns() []column Columns() []column
// Generates the sql string for the current tableName expression.
serializeSql(out *queryData) error
} }
// The sql tableName read interface. NOTE: NATURAL JOINs, and join "USING" clause // The sql tableName read interface. NOTE: NATURAL JOINs, and join "USING" clause
@ -108,18 +107,18 @@ func (t *Table) Columns() []column {
// Generates the sql string for the current tableName expression. Note: the // Generates the sql string for the current tableName expression. Note: the
// generated string may not be a valid/executable sql statement. // generated string may not be a valid/executable sql statement.
func (t *Table) serializeSql(out *queryData) error { func (t *Table) serialize(statement statementType, out *queryData) error {
if t == nil { if t == nil {
return errors.Newf("nil tableName.") return errors.Newf("nil tableName.")
} }
out.WriteString(t.schemaName) out.writeString(t.schemaName)
out.WriteString(".") out.writeString(".")
out.WriteString(t.TableName()) out.writeString(t.TableName())
if len(t.alias) > 0 { if len(t.alias) > 0 {
out.WriteString(" AS ") out.writeString(" AS ")
out.WriteString(t.alias) out.writeString(t.alias)
} }
return nil return nil
@ -272,7 +271,7 @@ func (t *joinTable) Column(name string) column {
} }
} }
func (t *joinTable) serializeSql(out *queryData) (err error) { func (t *joinTable) serialize(statement statementType, out *queryData) (err error) {
if t.lhs == nil { if t.lhs == nil {
return errors.Newf("nil lhs.") return errors.Newf("nil lhs.")
@ -284,30 +283,30 @@ func (t *joinTable) serializeSql(out *queryData) (err error) {
return errors.Newf("nil onCondition.") return errors.Newf("nil onCondition.")
} }
if err = t.lhs.serializeSql(out); err != nil { if err = t.lhs.serialize(statement, out); err != nil {
return return
} }
switch t.join_type { switch t.join_type {
case INNER_JOIN: case INNER_JOIN:
out.WriteString(" JOIN ") out.writeString(" JOIN ")
case LEFT_JOIN: case LEFT_JOIN:
out.WriteString(" LEFT JOIN ") out.writeString(" LEFT JOIN ")
case RIGHT_JOIN: case RIGHT_JOIN:
out.WriteString(" RIGHT JOIN ") out.writeString(" RIGHT JOIN ")
case FULL_JOIN: case FULL_JOIN:
out.WriteString(" FULL JOIN ") out.writeString(" FULL JOIN ")
case CROSS_JOIN: case CROSS_JOIN:
out.WriteString(" CROSS JOIN ") out.writeString(" CROSS JOIN ")
} }
if err = t.rhs.serializeSql(out); err != nil { if err = t.rhs.serialize(statement, out); err != nil {
return return
} }
if t.onCondition != nil { if t.onCondition != nil {
out.WriteString(" ON ") out.writeString(" ON ")
if err = t.onCondition.serialize(out); err != nil { if err = t.onCondition.serialize(statement, out); err != nil {
return return
} }
} }

View file

@ -51,7 +51,7 @@ func (s *TableSuite) TestJoinNilLeftTable(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := join.serializeSql(buf) err := join.serialize("", buf)
c.Assert(err, gc.NotNil) c.Assert(err, gc.NotNil)
} }
@ -60,7 +60,7 @@ func (s *TableSuite) TestJoinNilRightTable(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := join.serializeSql(buf) err := join.serialize("", buf)
c.Assert(err, gc.NotNil) c.Assert(err, gc.NotNil)
} }
@ -69,7 +69,7 @@ func (s *TableSuite) TestJoinNilOnCondition(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := join.serializeSql(buf) err := join.serialize("", buf)
c.Assert(err, gc.NotNil) c.Assert(err, gc.NotNil)
} }
@ -93,7 +93,7 @@ func (s *TableSuite) TestLeftJoin(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := join.serializeSql(buf) err := join.serialize("", buf)
c.Assert(err, gc.IsNil) c.Assert(err, gc.IsNil)
sql := buf.String() sql := buf.String()
@ -109,7 +109,7 @@ func (s *TableSuite) TestRightJoin(c *gc.C) {
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
err := join.serializeSql(buf) err := join.serialize("", buf)
c.Assert(err, gc.IsNil) c.Assert(err, gc.IsNil)
sql := buf.String() sql := buf.String()

View file

@ -54,15 +54,14 @@ func (u *updateStatementImpl) RETURNING(projections ...projection) updateStateme
func (u *updateStatementImpl) Sql() (sql string, args []interface{}, err error) { func (u *updateStatementImpl) Sql() (sql string, args []interface{}, err error) {
out := &queryData{} out := &queryData{}
out.statementType = update_statement
out.WriteString("UPDATE ") out.writeString("UPDATE ")
if u.table == nil { if u.table == nil {
return "", nil, errors.New("nil tableName.") return "", nil, errors.New("nil tableName.")
} }
if err = u.table.serializeSql(out); err != nil { if err = u.table.serialize(update_statement, out); err != nil {
return return
} }
@ -70,36 +69,36 @@ func (u *updateStatementImpl) Sql() (sql string, args []interface{}, err error)
return "", nil, errors.New("No column updated.") return "", nil, errors.New("No column updated.")
} }
out.WriteString(" SET") out.writeString(" SET")
if len(u.columns) > 1 { if len(u.columns) > 1 {
out.WriteString(" ( ") out.writeString(" ( ")
} else { } else {
out.WriteString(" ") out.writeString(" ")
} }
err = serializeColumnList(u.columns, out) err = serializeColumnList(update_statement, u.columns, out)
if err != nil { if err != nil {
return "", nil, err return "", nil, err
} }
if len(u.columns) > 1 { if len(u.columns) > 1 {
out.WriteString(" )") out.writeString(" )")
} }
out.WriteString(" =") out.writeString(" =")
if len(u.updateValues) > 1 { if len(u.updateValues) > 1 {
out.WriteString(" (") out.writeString(" (")
} }
for i, value := range u.updateValues { for i, value := range u.updateValues {
if i > 0 { if i > 0 {
out.WriteString(", ") out.writeString(", ")
} }
err = value.serialize(out) err = value.serialize(update_statement, out)
if err != nil { if err != nil {
return return
@ -107,21 +106,21 @@ func (u *updateStatementImpl) Sql() (sql string, args []interface{}, err error)
} }
if len(u.updateValues) > 1 { if len(u.updateValues) > 1 {
out.WriteString(" )") out.writeString(" )")
} }
if u.where == nil { if u.where == nil {
return "", nil, errors.New("Updating without a WHERE clause.") return "", nil, errors.New("Updating without a WHERE clause.")
} }
if err = out.WriteWhere(u.where); err != nil { if err = out.writeWhere(update_statement, u.where); err != nil {
return return
} }
if len(u.returning) > 0 { if len(u.returning) > 0 {
out.WriteString(" RETURNING ") out.writeString(" RETURNING ")
err = serializeProjectionList(u.returning, out) err = serializeProjectionList(update_statement, u.returning, out)
if err != nil { if err != nil {
return return

View file

@ -7,14 +7,14 @@ import (
"github.com/sub0zero/go-sqlbuilder/types" "github.com/sub0zero/go-sqlbuilder/types"
) )
func serializeOrderByClauseList(orderByClauses []orderByClause, out *queryData) error { func serializeOrderByClauseList(statement statementType, orderByClauses []orderByClause, out *queryData) error {
for i, value := range orderByClauses { for i, value := range orderByClauses {
if i > 0 { if i > 0 {
out.WriteString(", ") out.writeString(", ")
} }
err := value.serializeAsOrderBy(out) err := value.serializeAsOrderBy(statement, out)
if err != nil { if err != nil {
return err return err
@ -24,18 +24,18 @@ func serializeOrderByClauseList(orderByClauses []orderByClause, out *queryData)
return nil return nil
} }
func serializeGroupByClauseList(clauses []groupByClause, out *queryData) (err error) { func serializeGroupByClauseList(statement statementType, clauses []groupByClause, out *queryData) (err error) {
for i, c := range clauses { for i, c := range clauses {
if i > 0 { if i > 0 {
out.WriteString(", ") out.writeString(", ")
} }
if c == nil { if c == nil {
return errors.New("nil clause.") return errors.New("nil clause.")
} }
if err = c.serializeForGroupBy(out); err != nil { if err = c.serializeForGroupBy(statement, out); err != nil {
return return
} }
} }
@ -43,18 +43,18 @@ func serializeGroupByClauseList(clauses []groupByClause, out *queryData) (err er
return nil return nil
} }
func serializeClauseList(clauses []clause, out *queryData) (err error) { func serializeClauseList(statement statementType, clauses []clause, out *queryData) (err error) {
for i, c := range clauses { for i, c := range clauses {
if i > 0 { if i > 0 {
out.WriteString(", ") out.writeString(", ")
} }
if c == nil { if c == nil {
return errors.New("nil clause.") return errors.New("nil clause.")
} }
if err = c.serialize(out); err != nil { if err = c.serialize(statement, out); err != nil {
return return
} }
} }
@ -62,14 +62,14 @@ func serializeClauseList(clauses []clause, out *queryData) (err error) {
return nil return nil
} }
func serializeExpressionList(expressions []expression, separator string, out *queryData) error { func serializeExpressionList(statement statementType, expressions []expression, separator string, out *queryData) error {
for i, value := range expressions { for i, value := range expressions {
if i > 0 { if i > 0 {
out.WriteString(separator) out.writeString(separator)
} }
err := value.serialize(out) err := value.serialize(statement, out)
if err != nil { if err != nil {
return err return err
@ -79,16 +79,16 @@ func serializeExpressionList(expressions []expression, separator string, out *qu
return nil return nil
} }
func serializeProjectionList(projections []projection, out *queryData) error { func serializeProjectionList(statement statementType, projections []projection, out *queryData) error {
for i, col := range projections { for i, col := range projections {
if i > 0 { if i > 0 {
out.WriteString(", ") out.writeString(", ")
} }
if col == nil { if col == nil {
return errors.New("projection expression is nil.") return errors.New("projection expression is nil.")
} }
if err := col.serializeForProjection(out); err != nil { if err := col.serializeForProjection(statement, out); err != nil {
return err return err
} }
} }
@ -96,17 +96,17 @@ func serializeProjectionList(projections []projection, out *queryData) error {
return nil return nil
} }
func serializeColumnList(columns []column, out *queryData) error { func serializeColumnList(statement statementType, columns []column, out *queryData) error {
for i, col := range columns { for i, col := range columns {
if i > 0 { if i > 0 {
out.WriteByte(',') out.writeByte(',')
} }
if col == nil { if col == nil {
return errors.New("nil column in columns list.") return errors.New("nil column in columns list.")
} }
out.WriteString(col.Name()) out.writeString(col.Name())
} }
return nil return nil