Statements sql serialization simplified.

This commit is contained in:
zer0sub 2019-05-08 13:47:01 +02:00
parent d9bbec3795
commit 439c9f1ef9
26 changed files with 274 additions and 284 deletions

View file

@ -12,15 +12,15 @@ func NewAlias(expression expression, alias string) *Alias {
}
}
func (a *Alias) serializeForProjection(out *queryData) error {
func (a *Alias) serializeForProjection(statement statementType, out *queryData) error {
err := a.expression.serializeForProjection(out)
err := a.expression.serializeForProjection(statement, out)
if err != nil {
return err
}
out.WriteString(" AS \"" + a.alias + "\"")
out.writeString(" AS \"" + a.alias + "\"")
return nil
}

View file

@ -9,7 +9,7 @@ func TestBinaryExpression(t *testing.T) {
boolExpression := Eq(Literal(2), Literal(3))
out := queryData{}
err := boolExpression.serialize(&out)
err := boolExpression.serialize(select_statement, &out)
assert.NilError(t, err)
@ -20,7 +20,7 @@ func TestBinaryExpression(t *testing.T) {
alias := boolExpression.AS("alias_eq_expression")
out := queryData{}
err := alias.serializeForProjection(&out)
err := alias.serializeForProjection(select_statement, &out)
assert.NilError(t, err)
assert.Equal(t, out.buff.String(), `$1 = $2 AS "alias_eq_expression"`)
@ -30,7 +30,7 @@ func TestBinaryExpression(t *testing.T) {
exp := boolExpression.AND(Eq(Literal(4), Literal(5)))
out := queryData{}
err := exp.serialize(&out)
err := exp.serialize(select_statement, &out)
assert.NilError(t, err)
assert.Equal(t, out.buff.String(), `($1 = $2 AND $3 = $4)`)
@ -40,7 +40,7 @@ func TestBinaryExpression(t *testing.T) {
exp := boolExpression.OR(Eq(Literal(4), Literal(5)))
out := queryData{}
err := exp.serialize(&out)
err := exp.serialize(select_statement, &out)
assert.NilError(t, err)
assert.Equal(t, out.buff.String(), `($1 = $2 OR $3 = $4)`)
@ -51,7 +51,7 @@ func TestUnaryExpression(t *testing.T) {
notExpression := Not(Eq(Literal(2), Literal(1)))
out := queryData{}
err := notExpression.serialize(&out)
err := notExpression.serialize(select_statement, &out)
assert.NilError(t, err)
assert.Equal(t, out.buff.String(), "NOT $1 = $2")
@ -60,7 +60,7 @@ func TestUnaryExpression(t *testing.T) {
alias := notExpression.AS("alias_not_expression")
out := queryData{}
err := alias.serializeForProjection(&out)
err := alias.serializeForProjection(select_statement, &out)
assert.NilError(t, err)
assert.Equal(t, out.buff.String(), `NOT $1 = $2 AS "alias_not_expression"`)
@ -70,7 +70,7 @@ func TestUnaryExpression(t *testing.T) {
exp := notExpression.AND(Eq(Literal(4), Literal(5)))
out := queryData{}
err := exp.serialize(&out)
err := exp.serialize(select_statement, &out)
assert.NilError(t, err)
assert.Equal(t, out.buff.String(), `(NOT $1 = $2 AND $3 = $4)`)
@ -81,7 +81,7 @@ func TestUnaryIsTrueExpression(t *testing.T) {
notExpression := IsTrue(Eq(Literal(2), Literal(1)))
out := queryData{}
err := notExpression.serialize(&out)
err := notExpression.serialize(select_statement, &out)
assert.NilError(t, err)
assert.Equal(t, out.buff.String(), "IS TRUE $1 = $2")
@ -90,7 +90,7 @@ func TestUnaryIsTrueExpression(t *testing.T) {
exp := notExpression.AND(Eq(Literal(4), Literal(5)))
out := queryData{}
err := exp.serialize(&out)
err := exp.serialize(select_statement, &out)
assert.NilError(t, err)
assert.Equal(t, out.buff.String(), `(IS TRUE $1 = $2 AND $3 = $4)`)
@ -101,7 +101,7 @@ func TestBoolLiteral(t *testing.T) {
literal := newBoolLiteralExpression(true)
out := queryData{}
err := literal.serialize(&out)
err := literal.serialize(select_statement, &out)
assert.NilError(t, err)
@ -116,7 +116,7 @@ func TestExists(t *testing.T) {
)
out := queryData{}
err := query.serialize(&out)
err := query.serialize(select_statement, &out)
assert.NilError(t, err)
assert.Equal(t, out.buff.String(), "EXISTS (SELECT $1 FROM db.table2 WHERE table1.col1 = table2.col3)")
@ -126,7 +126,7 @@ func TestIn(t *testing.T) {
query := Literal(1.11).IN(table1.SELECT(table1Col1))
out := queryData{}
err := query.serialize(&out)
err := query.serialize(select_statement, &out)
assert.NilError(t, err)
assert.Equal(t, out.buff.String(), `$1 IN (SELECT table1.col1 AS "table1.col1" FROM db.table1)`)
@ -134,7 +134,7 @@ func TestIn(t *testing.T) {
query2 := ROW(Literal(12), table1Col1).IN(table2.SELECT(table2Col3, table3Col1))
out = queryData{}
err = query2.serialize(&out)
err = query2.serialize(select_statement, &out)
assert.NilError(t, err)
assert.Equal(t, out.buff.String(), `(ROW($1, table1.col1) IN (SELECT table2.col3 AS "table2.col3", table3.col1 AS "table3.col1" FROM db.table2))`)

View file

@ -7,69 +7,70 @@ import (
)
type clause interface {
serialize(out *queryData) error
serialize(statement statementType, out *queryData) error
}
type queryData struct {
buff bytes.Buffer
args []interface{}
statementType int
}
type statementType string
const (
select_statement = iota
insert_statement
update_statement
delete_statement
set_statement
select_statement statementType = "SELECT"
insert_statement statementType = "INSERT"
update_statement statementType = "UPDATE"
delete_statement statementType = "DELETE"
set_statement statementType = "SET"
lock_statement statementType = "LOCK"
)
func (q *queryData) WriteProjection(projections []projection) error {
return serializeProjectionList(projections, q)
func (q *queryData) writeProjection(statement statementType, projections []projection) error {
return serializeProjectionList(statement, projections, q)
}
func (q *queryData) WriteWhere(where expression) error {
q.WriteString(" WHERE ")
return where.serialize(q)
func (q *queryData) writeWhere(statement statementType, where expression) error {
q.writeString(" WHERE ")
return where.serialize(statement, q)
}
func (q *queryData) WriteGroupBy(groupBy []groupByClause) error {
q.WriteString(" GROUP BY ")
func (q *queryData) writeGroupBy(statement statementType, groupBy []groupByClause) error {
q.writeString(" GROUP BY ")
return serializeGroupByClauseList(groupBy, q)
return serializeGroupByClauseList(statement, groupBy, q)
}
func (q *queryData) WriteOrderBy(orderBy []orderByClause) error {
q.WriteString(" ORDER BY ")
return serializeOrderByClauseList(orderBy, q)
func (q *queryData) writeOrderBy(statement statementType, orderBy []orderByClause) error {
q.writeString(" ORDER BY ")
return serializeOrderByClauseList(statement, orderBy, q)
}
func (q *queryData) WriteHaving(having expression) error {
q.WriteString(" HAVING ")
return having.serialize(q)
func (q *queryData) writeHaving(statement statementType, having expression) error {
q.writeString(" HAVING ")
return having.serialize(statement, q)
}
func (q *queryData) Write(data []byte) {
func (q *queryData) write(data []byte) {
q.buff.Write(data)
}
func (q *queryData) WriteString(str string) {
func (q *queryData) writeString(str string) {
q.buff.WriteString(str)
}
func (q *queryData) WriteByte(b byte) {
func (q *queryData) writeByte(b byte) {
q.buff.WriteByte(b)
}
func (q *queryData) InsertArgument(arg interface{}) {
func (q *queryData) insertArgument(arg interface{}) {
q.args = append(q.args, arg)
argPlaceholder := "$" + strconv.Itoa(len(q.args))
q.buff.WriteString(argPlaceholder)
}
func (q *queryData) Reset() {
func (q *queryData) reset() {
q.buff.Reset()
q.args = []interface{}{}
}

View file

@ -77,42 +77,42 @@ func (c *baseColumn) DefaultAlias() projection {
return c.AS(c.tableName + "." + c.name)
}
func (c *baseColumn) serializeAsOrderBy(out *queryData) error {
if out.statementType == set_statement {
func (c *baseColumn) serializeAsOrderBy(statement statementType, out *queryData) error {
if statement == set_statement {
// set statement (UNION, EXCEPT ...) can reference only select projections in order by clause
out.WriteString(`"`)
out.writeString(`"`)
if c.tableName != "" {
out.WriteString(c.tableName)
out.WriteString(".")
out.writeString(c.tableName)
out.writeString(".")
}
out.WriteString(c.name)
out.writeString(c.name)
out.WriteString(`"`)
out.writeString(`"`)
return nil
}
return c.serialize(out)
return c.serialize(statement, out)
}
func (c baseColumn) serialize(out *queryData) error {
func (c baseColumn) serialize(statement statementType, out *queryData) error {
if c.tableName != "" {
out.WriteString(c.tableName)
out.WriteString(".")
out.writeString(c.tableName)
out.writeString(".")
}
wrapColumnName := strings.Contains(c.name, ".")
if wrapColumnName {
out.WriteString(`"`)
out.writeString(`"`)
}
out.WriteString(c.name)
out.writeString(c.name)
if wrapColumnName {
out.WriteString(`"`)
out.writeString(`"`)
}
return nil

View file

@ -9,26 +9,26 @@ func TestNewBoolColumn(t *testing.T) {
boolColumn := NewBoolColumn("col", Nullable)
out := queryData{}
err := boolColumn.serialize(&out)
err := boolColumn.serialize(select_statement, &out)
assert.NilError(t, err)
assert.Equal(t, out.buff.String(), "col")
out.Reset()
err = boolColumn.serialize(&out)
out.reset()
err = boolColumn.serialize(select_statement, &out)
assert.NilError(t, err)
assert.Equal(t, out.buff.String(), "col")
out.Reset()
out.reset()
boolColumn.setTableName("table1")
err = boolColumn.DefaultAlias().serializeForProjection(&out)
err = boolColumn.DefaultAlias().serializeForProjection(select_statement, &out)
assert.NilError(t, err)
assert.Equal(t, out.buff.String(), `table1.col AS "table1.col"`)
out.Reset()
out.reset()
boolColumn.setTableName("table1")
aliasedBoolColumn := boolColumn.AS("alias1")
err = aliasedBoolColumn.serializeForProjection(&out)
err = aliasedBoolColumn.serializeForProjection(select_statement, &out)
assert.NilError(t, err)
assert.Equal(t, out.buff.String(), `table1.col AS "alias1"`)
}
@ -37,26 +37,26 @@ func TestNewIntColumn(t *testing.T) {
integerColumn := NewIntegerColumn("col", Nullable)
out := queryData{}
err := integerColumn.serialize(&out)
err := integerColumn.serialize(select_statement, &out)
assert.NilError(t, err)
assert.Equal(t, out.buff.String(), "col")
out.Reset()
err = integerColumn.serialize(&out)
out.reset()
err = integerColumn.serialize(select_statement, &out)
assert.NilError(t, err)
assert.Equal(t, out.buff.String(), "col")
out.Reset()
out.reset()
integerColumn.setTableName("table1")
err = integerColumn.DefaultAlias().serializeForProjection(&out)
err = integerColumn.DefaultAlias().serializeForProjection(select_statement, &out)
assert.NilError(t, err)
assert.Equal(t, out.buff.String(), `table1.col AS "table1.col"`)
out.Reset()
out.reset()
integerColumn.setTableName("table1")
aliasedBoolColumn := integerColumn.AS("alias1")
err = aliasedBoolColumn.serializeForProjection(&out)
err = aliasedBoolColumn.serializeForProjection(select_statement, &out)
assert.NilError(t, err)
assert.Equal(t, out.buff.String(), `table1.col AS "alias1"`)
}
@ -65,26 +65,26 @@ func TestNewNumericColumnColumn(t *testing.T) {
numericColumn := NewNumericColumn("col", Nullable)
out := queryData{}
err := numericColumn.serialize(&out)
err := numericColumn.serialize(select_statement, &out)
assert.NilError(t, err)
assert.Equal(t, out.buff.String(), "col")
out.Reset()
err = numericColumn.serialize(&out)
out.reset()
err = numericColumn.serialize(select_statement, &out)
assert.NilError(t, err)
assert.Equal(t, out.buff.String(), "col")
out.Reset()
out.reset()
numericColumn.setTableName("table1")
err = numericColumn.DefaultAlias().serializeForProjection(&out)
err = numericColumn.DefaultAlias().serializeForProjection(select_statement, &out)
assert.NilError(t, err)
assert.Equal(t, out.buff.String(), `table1.col AS "table1.col"`)
out.Reset()
out.reset()
numericColumn.setTableName("table1")
aliasedBoolColumn := numericColumn.AS("alias1")
err = aliasedBoolColumn.serializeForProjection(&out)
err = aliasedBoolColumn.serializeForProjection(select_statement, &out)
assert.NilError(t, err)
assert.Equal(t, out.buff.String(), `table1.col AS "alias1"`)
}

View file

@ -30,15 +30,14 @@ func (d *deleteStatementImpl) WHERE(expression boolExpression) deleteStatement {
func (d *deleteStatementImpl) Sql() (query string, args []interface{}, err error) {
queryData := &queryData{}
queryData.statementType = delete_statement
queryData.WriteString("DELETE FROM ")
queryData.writeString("DELETE FROM ")
if d.table == nil {
return "", nil, errors.New("nil tableName.")
}
if err = d.table.serializeSql(queryData); err != nil {
if err = d.table.serialize(delete_statement, queryData); err != nil {
return
}
@ -46,7 +45,7 @@ func (d *deleteStatementImpl) Sql() (query string, args []interface{}, err error
return "", nil, errors.New("Deleting without a WHERE clause.")
}
if err = queryData.WriteWhere(d.where); err != nil {
if err = queryData.writeWhere(delete_statement, d.where); err != nil {
return
}

View file

@ -53,16 +53,16 @@ func (e *expressionInterfaceImpl) DESC() orderByClause {
return &orderByClauseImpl{expression: e.parent, ascent: false}
}
func (e *expressionInterfaceImpl) serializeForGroupBy(out *queryData) error {
return e.parent.serialize(out)
func (e *expressionInterfaceImpl) serializeForGroupBy(statement statementType, out *queryData) error {
return e.parent.serialize(statement, out)
}
func (e *expressionInterfaceImpl) serializeForProjection(out *queryData) error {
return e.parent.serialize(out)
func (e *expressionInterfaceImpl) serializeForProjection(statement statementType, out *queryData) error {
return e.parent.serialize(statement, out)
}
func (e *expressionInterfaceImpl) serializeAsOrderBy(out *queryData) error {
return e.parent.serialize(out)
func (e *expressionInterfaceImpl) serializeAsOrderBy(statement statementType, out *queryData) error {
return e.parent.serialize(statement, out)
}
// Representation of binary operations (e.g. comparisons, arithmetic)
@ -95,7 +95,7 @@ func isSimpleOperand(expression expression) bool {
return false
}
func (c *binaryExpression) serialize(out *queryData) error {
func (c *binaryExpression) serialize(statement statementType, out *queryData) error {
if c.lhs == nil {
return errors.Newf("nil lhs.")
}
@ -106,21 +106,21 @@ func (c *binaryExpression) serialize(out *queryData) error {
wrap := !isSimpleOperand(c.lhs) && !isSimpleOperand(c.rhs)
if wrap {
out.WriteString("(")
out.writeString("(")
}
if err := c.lhs.serialize(out); err != nil {
if err := c.lhs.serialize(statement, out); err != nil {
return err
}
out.WriteString(" " + c.operator + " ")
out.writeString(" " + c.operator + " ")
if err := c.rhs.serialize(out); err != nil {
if err := c.rhs.serialize(statement, out); err != nil {
return err
}
if wrap {
out.WriteString(")")
out.writeString(")")
}
return nil
@ -141,13 +141,13 @@ func newPrefixExpression(expression expression, operator string) prefixExpressio
return prefixExpression
}
func (p *prefixExpression) serialize(out *queryData) error {
out.WriteString(p.operator + " ")
func (p *prefixExpression) serialize(statement statementType, out *queryData) error {
out.writeString(p.operator + " ")
if p.expression == nil {
return errors.Newf("nil prefix expression.")
}
if err := p.expression.serialize(out); err != nil {
if err := p.expression.serialize(statement, out); err != nil {
return err
}

View file

@ -14,14 +14,14 @@ type intervalExpression struct {
const intervalSep = ":"
func (c *intervalExpression) serialize(out *queryData) error {
out.WriteString("INTERVAL '")
func (c *intervalExpression) serialize(statement statementType, out *queryData) error {
out.writeString("INTERVAL '")
duration := c.duration
if duration < 0 {
duration = -duration
out.WriteString("-")
out.writeString("-")
}
hours := duration / time.Hour
@ -29,14 +29,14 @@ func (c *intervalExpression) serialize(out *queryData) error {
sec := (duration % time.Minute) / time.Second
msec := (duration % time.Second) / time.Microsecond
out.WriteString(strconv.FormatInt(int64(hours), 10))
out.WriteString(intervalSep)
out.WriteString(strconv.FormatInt(int64(minutes), 10))
out.WriteString(intervalSep)
out.WriteString(strconv.FormatInt(int64(sec), 10))
out.WriteString(intervalSep)
out.WriteString(strconv.FormatInt(int64(msec), 10))
out.WriteString("' HOUR_MICROSECOND")
out.writeString(strconv.FormatInt(int64(hours), 10))
out.writeString(intervalSep)
out.writeString(strconv.FormatInt(int64(minutes), 10))
out.writeString(intervalSep)
out.writeString(strconv.FormatInt(int64(sec), 10))
out.writeString(intervalSep)
out.writeString(strconv.FormatInt(int64(msec), 10))
out.writeString("' HOUR_MICROSECOND")
return nil
}

View file

@ -19,7 +19,7 @@ func (s *ExprSuite) TestConjunctExprEmptyList(c *gc.C) {
buf := &bytes.Buffer{}
err := expr.serialize(buf)
err := expr.serialize(0, buf)
c.Assert(err, gc.NotNil)
}
@ -28,7 +28,7 @@ func (s *ExprSuite) TestConjunctExprNilInList(c *gc.C) {
buf := &bytes.Buffer{}
err := expr.serialize(buf)
err := expr.serialize(0, buf)
c.Assert(err, gc.NotNil)
}
@ -37,7 +37,7 @@ func (s *ExprSuite) TestConjunctExprSingleElement(c *gc.C) {
buf := &bytes.Buffer{}
err := expr.serialize(buf)
err := expr.serialize(0, buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
@ -49,7 +49,7 @@ func (s *ExprSuite) TestLikeExpr(c *gc.C) {
buf := &bytes.Buffer{}
err := expr.serialize(buf)
err := expr.serialize(0, buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
@ -65,7 +65,7 @@ func (s *ExprSuite) TestRegexExpr(c *gc.C) {
buf := &bytes.Buffer{}
err := expr.serialize(buf)
err := expr.serialize(0, buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
@ -81,7 +81,7 @@ func (s *ExprSuite) TestAndExpr(c *gc.C) {
buf := &bytes.Buffer{}
err := expr.serialize(buf)
err := expr.serialize(0, buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
@ -96,7 +96,7 @@ func (s *ExprSuite) TestOrExpr(c *gc.C) {
buf := &bytes.Buffer{}
err := expr.serialize(buf)
err := expr.serialize(0, buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
@ -159,7 +159,7 @@ func (s *ExprSuite) TestBinaryExprNilLHS(c *gc.C) {
buf := &bytes.Buffer{}
err := expr.serialize(buf)
err := expr.serialize(0, buf)
c.Assert(err, gc.NotNil)
}
@ -168,7 +168,7 @@ func (s *ExprSuite) TestNegateExpr(c *gc.C) {
buf := &bytes.Buffer{}
err := expr.serialize(buf)
err := expr.serialize(0, buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
@ -180,7 +180,7 @@ func (s *ExprSuite) TestBinaryExprNilRHS(c *gc.C) {
buf := &bytes.Buffer{}
err := expr.serialize(buf)
err := expr.serialize(0, buf)
c.Assert(err, gc.NotNil)
}
@ -189,7 +189,7 @@ func (s *ExprSuite) TestEqExpr(c *gc.C) {
buf := &bytes.Buffer{}
err := expr.serialize(buf)
err := expr.serialize(0, buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
@ -201,7 +201,7 @@ func (s *ExprSuite) TestEqExprNilLHS(c *gc.C) {
buf := &bytes.Buffer{}
err := expr.serialize(buf)
err := expr.serialize(0, buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
@ -213,7 +213,7 @@ func (s *ExprSuite) TestNeqExpr(c *gc.C) {
buf := &bytes.Buffer{}
err := expr.serialize(buf)
err := expr.serialize(0, buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
@ -225,7 +225,7 @@ func (s *ExprSuite) TestNeqExprNilLHS(c *gc.C) {
buf := &bytes.Buffer{}
err := expr.serialize(buf)
err := expr.serialize(0, buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
@ -237,7 +237,7 @@ func (s *ExprSuite) TestLtExpr(c *gc.C) {
buf := &bytes.Buffer{}
err := expr.serialize(buf)
err := expr.serialize(0, buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
@ -249,7 +249,7 @@ func (s *ExprSuite) TestLteExpr(c *gc.C) {
buf := &bytes.Buffer{}
err := expr.serialize(buf)
err := expr.serialize(0, buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
@ -264,7 +264,7 @@ func (s *ExprSuite) TestGtExpr(c *gc.C) {
buf := &bytes.Buffer{}
err := expr.serialize(buf)
err := expr.serialize(0, buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
@ -276,7 +276,7 @@ func (s *ExprSuite) TestGteExpr(c *gc.C) {
buf := &bytes.Buffer{}
err := expr.serialize(buf)
err := expr.serialize(0, buf)
c.Assert(err, gc.IsNil)
sql := buf.String()

View file

@ -47,16 +47,16 @@ func (s *expressionTableImpl) RefStringColumn(column column) *StringColumn {
return strColumn
}
func (s *expressionTableImpl) serializeSql(out *queryData) error {
out.WriteString("( ")
err := s.statement.serialize(out)
func (s *expressionTableImpl) serialize(statement statementType, out *queryData) error {
out.writeString("( ")
err := s.statement.serialize(statement, out)
if err != nil {
return err
}
out.WriteString(" ) AS ")
out.WriteString(s.alias)
out.writeString(" ) AS ")
out.writeString(s.alias)
return nil
}

View file

@ -28,14 +28,14 @@ func newFunc(name string, expressions []expression, parent expression) *funcExpr
return funcExp
}
func (f *funcExpressionImpl) serialize(out *queryData) error {
out.WriteString(f.name)
out.WriteString("(")
err := serializeExpressionList(f.expression, ", ", out)
func (f *funcExpressionImpl) serialize(statement statementType, out *queryData) error {
out.writeString(f.name)
out.writeString("(")
err := serializeExpressionList(statement, f.expression, ", ", out)
if err != nil {
return err
}
out.WriteString(")")
out.writeString(")")
return nil
}
@ -54,10 +54,6 @@ func NewNumericFunc(name string, expressions ...expression) numericExpression {
return numericFunc
}
//func (f *FuncExpression) SerializeSqlForColumnList(out *bytes.Buffer) error {
// return f.serialize(out)
//}
func MAX(expression numericExpression) numericExpression {
return NewNumericFunc("MAX", expression)
}
@ -111,12 +107,12 @@ func (c *caseExpression) ELSE(els expression) caseInterface {
return c
}
func (c *caseExpression) serialize(out *queryData) error {
out.WriteString("(CASE")
func (c *caseExpression) serialize(statement statementType, out *queryData) error {
out.writeString("(CASE")
if c.expression != nil {
out.WriteString(" ")
err := c.expression.serialize(out)
out.writeString(" ")
err := c.expression.serialize(statement, out)
if err != nil {
return err
@ -132,15 +128,15 @@ func (c *caseExpression) serialize(out *queryData) error {
}
for i, when := range c.when {
out.WriteString(" WHEN ")
err := when.serialize(out)
out.writeString(" WHEN ")
err := when.serialize(statement, out)
if err != nil {
return err
}
out.WriteString(" THEN ")
err = c.then[i].serialize(out)
out.writeString(" THEN ")
err = c.then[i].serialize(statement, out)
if err != nil {
return err
@ -148,15 +144,15 @@ func (c *caseExpression) serialize(out *queryData) error {
}
if c.els != nil {
out.WriteString(" ELSE ")
err := c.els.serialize(out)
out.writeString(" ELSE ")
err := c.els.serialize(statement, out)
if err != nil {
return err
}
}
out.WriteString(" END)")
out.writeString(" END)")
return nil
}

View file

@ -12,7 +12,7 @@ func TestCase1(t *testing.T) {
queryData := &queryData{}
err := query.serialize(queryData)
err := query.serialize(select_statement, queryData)
assert.NilError(t, err)
assert.Equal(t, queryData.buff.String(), `(CASE WHEN table3.col1 = $1 THEN table3.col1 + $2 WHEN table3.col1 = $3 THEN table3.col1 + $4 END)`)
@ -26,7 +26,7 @@ func TestCase2(t *testing.T) {
queryData := &queryData{}
err := query.serialize(queryData)
err := query.serialize(select_statement, queryData)
assert.NilError(t, err)
assert.Equal(t, queryData.buff.String(), `(CASE table3.col1 WHEN $1 THEN table3.col1 + $2 WHEN $3 THEN table3.col1 + $4 ELSE $5 END)`)
@ -37,7 +37,7 @@ func TestInterval(t *testing.T) {
queryData := &queryData{}
err := query.serialize(queryData)
err := query.serialize(select_statement, queryData)
assert.NilError(t, err)
assert.Equal(t, queryData.buff.String(), `INTERVAL $1`)

View file

@ -1,7 +1,7 @@
package sqlbuilder
type groupByClause interface {
serializeForGroupBy(out *queryData) error
serializeForGroupBy(statement statementType, out *queryData) error
}
// TODO: GROUPING SETS, CUBE, and ROLLUP

View file

@ -127,29 +127,28 @@ func (s *insertStatementImpl) Sql() (sql string, args []interface{}, err error)
}
queryData := &queryData{}
queryData.statementType = insert_statement
queryData.WriteString("INSERT INTO ")
queryData.writeString("INSERT INTO ")
if s.table == nil {
return "", nil, errors.Newf("nil tableName.")
}
err = s.table.serializeSql(queryData)
err = s.table.serialize(insert_statement, queryData)
if err != nil {
return "", nil, err
}
if len(s.columns) > 0 {
queryData.WriteString(" (")
queryData.writeString(" (")
err = serializeColumnList(s.columns, queryData)
err = serializeColumnList(insert_statement, s.columns, queryData)
if err != nil {
return "", nil, err
}
queryData.WriteString(") ")
queryData.writeString(") ")
}
if len(s.rows) == 0 && s.query == nil {
@ -161,28 +160,28 @@ func (s *insertStatementImpl) Sql() (sql string, args []interface{}, err error)
}
if len(s.rows) > 0 {
queryData.WriteString("VALUES (")
queryData.writeString("VALUES (")
for row_i, row := range s.rows {
if row_i > 0 {
queryData.WriteString(", (")
queryData.writeString(", (")
}
if len(row) != len(s.columns) {
return "", nil, errors.New("# of values does not match # of columns.")
}
err = serializeClauseList(row, queryData)
err = serializeClauseList(insert_statement, row, queryData)
if err != nil {
return "", nil, err
}
queryData.WriteByte(')')
queryData.writeByte(')')
}
}
if s.query != nil {
err = s.query.serialize(queryData)
err = s.query.serialize(insert_statement, queryData)
if err != nil {
return
@ -190,16 +189,16 @@ func (s *insertStatementImpl) Sql() (sql string, args []interface{}, err error)
}
if len(s.returning) > 0 {
queryData.WriteString(" RETURNING ")
queryData.writeString(" RETURNING ")
err = queryData.WriteProjection(s.returning)
err = queryData.writeProjection(insert_statement, s.returning)
if err != nil {
return
}
}
queryData.WriteByte(';')
queryData.writeByte(';')
return queryData.buff.String(), queryData.args, nil
}

View file

@ -6,8 +6,8 @@ const (
type keywordClause string
func (k keywordClause) serialize(out *queryData) error {
out.WriteString(string(k))
func (k keywordClause) serialize(statement statementType, out *queryData) error {
out.writeString(string(k))
return nil
}

View file

@ -13,8 +13,8 @@ func Literal(value interface{}) *literalExpression {
return &exp
}
func (l literalExpression) serialize(out *queryData) error {
out.InsertArgument(l.value)
func (l literalExpression) serialize(statement statementType, out *queryData) error {
out.insertArgument(l.value)
return nil
}

View file

@ -59,14 +59,14 @@ func (l *lockStatementImpl) Sql() (query string, args []interface{}, err error)
out := &queryData{}
out.WriteString("LOCK TABLE ")
out.writeString("LOCK TABLE ")
for i, table := range l.tables {
if i > 0 {
out.WriteString(", ")
out.writeString(", ")
}
err := table.serializeSql(out)
err := table.serialize(lock_statement, out)
if err != nil {
return "", nil, err
@ -74,13 +74,13 @@ func (l *lockStatementImpl) Sql() (query string, args []interface{}, err error)
}
if l.lockMode != "" {
out.WriteString(" IN ")
out.WriteString(string(l.lockMode))
out.WriteString(" MODE")
out.writeString(" IN ")
out.writeString(string(l.lockMode))
out.writeString(" MODE")
}
if l.nowait {
out.WriteString(" NOWAIT")
out.writeString(" NOWAIT")
}
return out.buff.String(), out.args, nil

View file

@ -130,10 +130,10 @@ func newNumericExpressionWrap(expression expression) numericExpression {
return &numericExpressionWrap
}
func (c *numericExpressionWrapper) serialize(out *queryData) error {
out.WriteString("(")
err := c.expression.serialize(out)
out.WriteString(")")
func (c *numericExpressionWrapper) serialize(statement statementType, out *queryData) error {
out.writeString("(")
err := c.expression.serialize(statement, out)
out.writeString(")")
return err
}

View file

@ -3,7 +3,7 @@ package sqlbuilder
import "github.com/dropbox/godropbox/errors"
type orderByClause interface {
serializeAsOrderBy(out *queryData) error
serializeAsOrderBy(statement statementType, out *queryData) error
}
type orderByClauseImpl struct {
@ -11,19 +11,19 @@ type orderByClauseImpl struct {
ascent bool
}
func (o *orderByClauseImpl) serializeAsOrderBy(out *queryData) error {
func (o *orderByClauseImpl) serializeAsOrderBy(statement statementType, out *queryData) error {
if o.expression == nil {
return errors.Newf("nil orderBy by clause.")
}
if err := o.expression.serializeAsOrderBy(out); err != nil {
if err := o.expression.serializeAsOrderBy(statement, out); err != nil {
return err
}
if o.ascent {
out.WriteString(" ASC")
out.writeString(" ASC")
} else {
out.WriteString(" DESC")
out.writeString(" DESC")
}
return nil

View file

@ -1,7 +1,7 @@
package sqlbuilder
type projection interface {
serializeForProjection(out *queryData) error
serializeForProjection(statement statementType, out *queryData) error
}
//------------------------------------------------------//
@ -10,16 +10,16 @@ type ColumnList []column
func (cl ColumnList) isProjectionType() {}
func (cl ColumnList) serializeForProjection(out *queryData) error {
func (cl ColumnList) serializeForProjection(statement statementType, out *queryData) error {
for i, column := range cl {
err := column.serializeForProjection(out)
err := column.serializeForProjection(statement, out)
if err != nil {
return err
}
if i != len(cl)-1 {
out.WriteString(", ")
out.writeString(", ")
}
}
return nil

View file

@ -25,7 +25,7 @@ type selectStatement interface {
AsTable(alias string) expressionTable
}
var SELECT = func(projection ...projection) selectStatement {
func SELECT(projection ...projection) selectStatement {
return newSelectStatement(nil, projection)
}
@ -83,9 +83,9 @@ func (s *selectStatementImpl) FROM(table readableTable) selectStatement {
return s
}
func (s *selectStatementImpl) serialize(out *queryData) error {
func (s *selectStatementImpl) serialize(statement statementType, out *queryData) error {
out.WriteString("(")
out.writeString("(")
err := s.serializeImpl(out)
@ -93,42 +93,41 @@ func (s *selectStatementImpl) serialize(out *queryData) error {
return err
}
out.WriteString(")")
out.writeString(")")
return nil
}
func (s *selectStatementImpl) serializeImpl(out *queryData) error {
out.WriteString("SELECT ")
out.statementType = select_statement
out.writeString("SELECT ")
if s.distinct {
out.WriteString("DISTINCT ")
out.writeString("DISTINCT ")
}
if s.projections == nil || len(s.projections) == 0 {
return errors.New("No column selected for projection.")
}
err := out.WriteProjection(s.projections)
err := out.writeProjection(select_statement, s.projections)
if err != nil {
return err
}
out.WriteString(" FROM ")
out.writeString(" FROM ")
if s.table == nil {
return errors.Newf("nil tableName.")
}
if err := s.table.serializeSql(out); err != nil {
if err := s.table.serialize(select_statement, out); err != nil {
return err
}
if s.where != nil {
err := out.WriteWhere(s.where)
err := out.writeWhere(select_statement, s.where)
if err != nil {
return nil
@ -136,7 +135,7 @@ func (s *selectStatementImpl) serializeImpl(out *queryData) error {
}
if s.groupBy != nil && len(s.groupBy) > 0 {
err := out.WriteGroupBy(s.groupBy)
err := out.writeGroupBy(select_statement, s.groupBy)
if err != nil {
return err
@ -144,7 +143,7 @@ func (s *selectStatementImpl) serializeImpl(out *queryData) error {
}
if s.having != nil {
err := out.WriteHaving(s.having)
err := out.writeHaving(select_statement, s.having)
if err != nil {
return err
@ -152,7 +151,7 @@ func (s *selectStatementImpl) serializeImpl(out *queryData) error {
}
if s.orderBy != nil {
err := out.WriteOrderBy(s.orderBy)
err := out.writeOrderBy(select_statement, s.orderBy)
if err != nil {
return err
@ -160,17 +159,17 @@ func (s *selectStatementImpl) serializeImpl(out *queryData) error {
}
if s.limit >= 0 {
out.WriteString(" LIMIT ")
out.InsertArgument(s.limit)
out.writeString(" LIMIT ")
out.insertArgument(s.limit)
}
if s.offset >= 0 {
out.WriteString(" OFFSET ")
out.InsertArgument(s.offset)
out.writeString(" OFFSET ")
out.insertArgument(s.offset)
}
if s.forUpdate {
out.WriteString(" FOR UPDATE")
out.writeString(" FOR UPDATE")
}
return nil

View file

@ -96,9 +96,9 @@ func (us *setStatementImpl) AsTable(alias string) expressionTable {
}
}
func (s *setStatementImpl) serialize(out *queryData) error {
func (s *setStatementImpl) serialize(statement statementType, out *queryData) error {
if s.orderBy != nil || s.limit >= 0 || s.offset >= 0 {
out.WriteString("(")
out.writeString("(")
}
err := s.serializeImpl(out)
@ -108,7 +108,7 @@ func (s *setStatementImpl) serialize(out *queryData) error {
}
if s.orderBy != nil || s.limit >= 0 || s.offset >= 0 {
out.WriteString(")")
out.writeString(")")
}
return nil
@ -120,43 +120,41 @@ func (s *setStatementImpl) serializeImpl(out *queryData) error {
return errors.Newf("UNION statement must have at least two SELECT statements.")
}
out.WriteString("(")
out.writeString("(")
for i, selectStmt := range s.selects {
if i > 0 {
out.WriteString(" " + s.operator + " ")
out.writeString(" " + s.operator + " ")
if s.all {
out.WriteString(" ALL ")
out.writeString(" ALL ")
}
}
err := selectStmt.serialize(out)
err := selectStmt.serialize(set_statement, out)
if err != nil {
return err
}
}
out.WriteString(")")
out.statementType = set_statement
out.writeString(")")
if s.orderBy != nil {
err := out.WriteOrderBy(s.orderBy)
err := out.writeOrderBy(set_statement, s.orderBy)
if err != nil {
return err
}
}
if s.limit >= 0 {
out.WriteString(" LIMIT ")
out.InsertArgument(s.limit)
out.writeString(" LIMIT ")
out.insertArgument(s.limit)
}
if s.offset >= 0 {
out.WriteString(" OFFSET ")
out.InsertArgument(s.offset)
out.writeString(" OFFSET ")
out.insertArgument(s.offset)
}
return nil

View file

@ -7,12 +7,11 @@ import (
)
type tableInterface interface {
clause
SchemaName() string
TableName() string
Columns() []column
// Generates the sql string for the current tableName expression.
serializeSql(out *queryData) error
}
// The sql tableName read interface. NOTE: NATURAL JOINs, and join "USING" clause
@ -108,18 +107,18 @@ func (t *Table) Columns() []column {
// Generates the sql string for the current tableName expression. Note: the
// generated string may not be a valid/executable sql statement.
func (t *Table) serializeSql(out *queryData) error {
func (t *Table) serialize(statement statementType, out *queryData) error {
if t == nil {
return errors.Newf("nil tableName.")
}
out.WriteString(t.schemaName)
out.WriteString(".")
out.WriteString(t.TableName())
out.writeString(t.schemaName)
out.writeString(".")
out.writeString(t.TableName())
if len(t.alias) > 0 {
out.WriteString(" AS ")
out.WriteString(t.alias)
out.writeString(" AS ")
out.writeString(t.alias)
}
return nil
@ -272,7 +271,7 @@ func (t *joinTable) Column(name string) column {
}
}
func (t *joinTable) serializeSql(out *queryData) (err error) {
func (t *joinTable) serialize(statement statementType, out *queryData) (err error) {
if t.lhs == nil {
return errors.Newf("nil lhs.")
@ -284,30 +283,30 @@ func (t *joinTable) serializeSql(out *queryData) (err error) {
return errors.Newf("nil onCondition.")
}
if err = t.lhs.serializeSql(out); err != nil {
if err = t.lhs.serialize(statement, out); err != nil {
return
}
switch t.join_type {
case INNER_JOIN:
out.WriteString(" JOIN ")
out.writeString(" JOIN ")
case LEFT_JOIN:
out.WriteString(" LEFT JOIN ")
out.writeString(" LEFT JOIN ")
case RIGHT_JOIN:
out.WriteString(" RIGHT JOIN ")
out.writeString(" RIGHT JOIN ")
case FULL_JOIN:
out.WriteString(" FULL JOIN ")
out.writeString(" FULL JOIN ")
case CROSS_JOIN:
out.WriteString(" CROSS JOIN ")
out.writeString(" CROSS JOIN ")
}
if err = t.rhs.serializeSql(out); err != nil {
if err = t.rhs.serialize(statement, out); err != nil {
return
}
if t.onCondition != nil {
out.WriteString(" ON ")
if err = t.onCondition.serialize(out); err != nil {
out.writeString(" ON ")
if err = t.onCondition.serialize(statement, out); err != nil {
return
}
}

View file

@ -51,7 +51,7 @@ func (s *TableSuite) TestJoinNilLeftTable(c *gc.C) {
buf := &bytes.Buffer{}
err := join.serializeSql(buf)
err := join.serialize("", buf)
c.Assert(err, gc.NotNil)
}
@ -60,7 +60,7 @@ func (s *TableSuite) TestJoinNilRightTable(c *gc.C) {
buf := &bytes.Buffer{}
err := join.serializeSql(buf)
err := join.serialize("", buf)
c.Assert(err, gc.NotNil)
}
@ -69,7 +69,7 @@ func (s *TableSuite) TestJoinNilOnCondition(c *gc.C) {
buf := &bytes.Buffer{}
err := join.serializeSql(buf)
err := join.serialize("", buf)
c.Assert(err, gc.NotNil)
}
@ -93,7 +93,7 @@ func (s *TableSuite) TestLeftJoin(c *gc.C) {
buf := &bytes.Buffer{}
err := join.serializeSql(buf)
err := join.serialize("", buf)
c.Assert(err, gc.IsNil)
sql := buf.String()
@ -109,7 +109,7 @@ func (s *TableSuite) TestRightJoin(c *gc.C) {
buf := &bytes.Buffer{}
err := join.serializeSql(buf)
err := join.serialize("", buf)
c.Assert(err, gc.IsNil)
sql := buf.String()

View file

@ -54,15 +54,14 @@ func (u *updateStatementImpl) RETURNING(projections ...projection) updateStateme
func (u *updateStatementImpl) Sql() (sql string, args []interface{}, err error) {
out := &queryData{}
out.statementType = update_statement
out.WriteString("UPDATE ")
out.writeString("UPDATE ")
if u.table == nil {
return "", nil, errors.New("nil tableName.")
}
if err = u.table.serializeSql(out); err != nil {
if err = u.table.serialize(update_statement, out); err != nil {
return
}
@ -70,36 +69,36 @@ func (u *updateStatementImpl) Sql() (sql string, args []interface{}, err error)
return "", nil, errors.New("No column updated.")
}
out.WriteString(" SET")
out.writeString(" SET")
if len(u.columns) > 1 {
out.WriteString(" ( ")
out.writeString(" ( ")
} else {
out.WriteString(" ")
out.writeString(" ")
}
err = serializeColumnList(u.columns, out)
err = serializeColumnList(update_statement, u.columns, out)
if err != nil {
return "", nil, err
}
if len(u.columns) > 1 {
out.WriteString(" )")
out.writeString(" )")
}
out.WriteString(" =")
out.writeString(" =")
if len(u.updateValues) > 1 {
out.WriteString(" (")
out.writeString(" (")
}
for i, value := range u.updateValues {
if i > 0 {
out.WriteString(", ")
out.writeString(", ")
}
err = value.serialize(out)
err = value.serialize(update_statement, out)
if err != nil {
return
@ -107,21 +106,21 @@ func (u *updateStatementImpl) Sql() (sql string, args []interface{}, err error)
}
if len(u.updateValues) > 1 {
out.WriteString(" )")
out.writeString(" )")
}
if u.where == nil {
return "", nil, errors.New("Updating without a WHERE clause.")
}
if err = out.WriteWhere(u.where); err != nil {
if err = out.writeWhere(update_statement, u.where); err != nil {
return
}
if len(u.returning) > 0 {
out.WriteString(" RETURNING ")
out.writeString(" RETURNING ")
err = serializeProjectionList(u.returning, out)
err = serializeProjectionList(update_statement, u.returning, out)
if err != nil {
return

View file

@ -7,14 +7,14 @@ import (
"github.com/sub0zero/go-sqlbuilder/types"
)
func serializeOrderByClauseList(orderByClauses []orderByClause, out *queryData) error {
func serializeOrderByClauseList(statement statementType, orderByClauses []orderByClause, out *queryData) error {
for i, value := range orderByClauses {
if i > 0 {
out.WriteString(", ")
out.writeString(", ")
}
err := value.serializeAsOrderBy(out)
err := value.serializeAsOrderBy(statement, out)
if err != nil {
return err
@ -24,18 +24,18 @@ func serializeOrderByClauseList(orderByClauses []orderByClause, out *queryData)
return nil
}
func serializeGroupByClauseList(clauses []groupByClause, out *queryData) (err error) {
func serializeGroupByClauseList(statement statementType, clauses []groupByClause, out *queryData) (err error) {
for i, c := range clauses {
if i > 0 {
out.WriteString(", ")
out.writeString(", ")
}
if c == nil {
return errors.New("nil clause.")
}
if err = c.serializeForGroupBy(out); err != nil {
if err = c.serializeForGroupBy(statement, out); err != nil {
return
}
}
@ -43,18 +43,18 @@ func serializeGroupByClauseList(clauses []groupByClause, out *queryData) (err er
return nil
}
func serializeClauseList(clauses []clause, out *queryData) (err error) {
func serializeClauseList(statement statementType, clauses []clause, out *queryData) (err error) {
for i, c := range clauses {
if i > 0 {
out.WriteString(", ")
out.writeString(", ")
}
if c == nil {
return errors.New("nil clause.")
}
if err = c.serialize(out); err != nil {
if err = c.serialize(statement, out); err != nil {
return
}
}
@ -62,14 +62,14 @@ func serializeClauseList(clauses []clause, out *queryData) (err error) {
return nil
}
func serializeExpressionList(expressions []expression, separator string, out *queryData) error {
func serializeExpressionList(statement statementType, expressions []expression, separator string, out *queryData) error {
for i, value := range expressions {
if i > 0 {
out.WriteString(separator)
out.writeString(separator)
}
err := value.serialize(out)
err := value.serialize(statement, out)
if err != nil {
return err
@ -79,16 +79,16 @@ func serializeExpressionList(expressions []expression, separator string, out *qu
return nil
}
func serializeProjectionList(projections []projection, out *queryData) error {
func serializeProjectionList(statement statementType, projections []projection, out *queryData) error {
for i, col := range projections {
if i > 0 {
out.WriteString(", ")
out.writeString(", ")
}
if col == nil {
return errors.New("projection expression is nil.")
}
if err := col.serializeForProjection(out); err != nil {
if err := col.serializeForProjection(statement, out); err != nil {
return err
}
}
@ -96,17 +96,17 @@ func serializeProjectionList(projections []projection, out *queryData) error {
return nil
}
func serializeColumnList(columns []column, out *queryData) error {
func serializeColumnList(statement statementType, columns []column, out *queryData) error {
for i, col := range columns {
if i > 0 {
out.WriteByte(',')
out.writeByte(',')
}
if col == nil {
return errors.New("nil column in columns list.")
}
out.WriteString(col.Name())
out.writeString(col.Name())
}
return nil