Fix unit tests.

This commit is contained in:
zer0sub 2019-05-01 17:25:10 +02:00
parent 9b826fff6e
commit eccc17dc8a
10 changed files with 91 additions and 72 deletions

View file

@ -1,7 +1,6 @@
package sqlbuilder package sqlbuilder
import ( import (
"bytes"
"gotest.tools/assert" "gotest.tools/assert"
"testing" "testing"
) )
@ -9,100 +8,102 @@ import (
func TestBinaryExpression(t *testing.T) { func TestBinaryExpression(t *testing.T) {
boolExpression := Eq(Literal(2), Literal(3)) boolExpression := Eq(Literal(2), Literal(3))
out := bytes.Buffer{} out := queryData{}
err := boolExpression.Serialize(&out) err := boolExpression.Serialize(&out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.String(), "2 = 3")
assert.Equal(t, out.buff.String(), "$1 = $2")
assert.Equal(t, len(out.args), 2)
t.Run("alias", func(t *testing.T) { t.Run("alias", func(t *testing.T) {
alias := boolExpression.As("alias_eq_expression") alias := boolExpression.As("alias_eq_expression")
out := bytes.Buffer{} out := queryData{}
err := alias.SerializeForProjection(&out) err := alias.SerializeForProjection(&out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.String(), `2 = 3 AS "alias_eq_expression"`) assert.Equal(t, out.buff.String(), `$1 = $2 AS "alias_eq_expression"`)
}) })
t.Run("and", func(t *testing.T) { t.Run("and", func(t *testing.T) {
exp := boolExpression.And(Eq(Literal(4), Literal(5))) exp := boolExpression.And(Eq(Literal(4), Literal(5)))
out := bytes.Buffer{} out := queryData{}
err := exp.Serialize(&out) err := exp.Serialize(&out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.String(), `(2 = 3 AND 4 = 5)`) assert.Equal(t, out.buff.String(), `($1 = $2 AND $3 = $4)`)
}) })
t.Run("or", func(t *testing.T) { t.Run("or", func(t *testing.T) {
exp := boolExpression.Or(Eq(Literal(4), Literal(5))) exp := boolExpression.Or(Eq(Literal(4), Literal(5)))
out := bytes.Buffer{} out := queryData{}
err := exp.Serialize(&out) err := exp.Serialize(&out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.String(), `(2 = 3 OR 4 = 5)`) assert.Equal(t, out.buff.String(), `($1 = $2 OR $3 = $4)`)
}) })
} }
func TestUnaryExpression(t *testing.T) { func TestUnaryExpression(t *testing.T) {
notExpression := Not(Eq(Literal(2), Literal(1))) notExpression := Not(Eq(Literal(2), Literal(1)))
out := bytes.Buffer{} out := queryData{}
err := notExpression.Serialize(&out) err := notExpression.Serialize(&out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.String(), " NOT 2 = 1") assert.Equal(t, out.buff.String(), " NOT $1 = $2")
t.Run("alias", func(t *testing.T) { t.Run("alias", func(t *testing.T) {
alias := notExpression.As("alias_not_expression") alias := notExpression.As("alias_not_expression")
out := bytes.Buffer{} out := queryData{}
err := alias.SerializeForProjection(&out) err := alias.SerializeForProjection(&out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.String(), ` NOT 2 = 1 AS "alias_not_expression"`) assert.Equal(t, out.buff.String(), ` NOT $1 = $2 AS "alias_not_expression"`)
}) })
t.Run("and", func(t *testing.T) { t.Run("and", func(t *testing.T) {
exp := notExpression.And(Eq(Literal(4), Literal(5))) exp := notExpression.And(Eq(Literal(4), Literal(5)))
out := bytes.Buffer{} out := queryData{}
err := exp.Serialize(&out) err := exp.Serialize(&out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.String(), `( NOT 2 = 1 AND 4 = 5)`) assert.Equal(t, out.buff.String(), `( NOT $1 = $2 AND $3 = $4)`)
}) })
} }
func TestUnaryIsTrueExpression(t *testing.T) { func TestUnaryIsTrueExpression(t *testing.T) {
notExpression := IsTrue(Eq(Literal(2), Literal(1))) notExpression := IsTrue(Eq(Literal(2), Literal(1)))
out := bytes.Buffer{} out := queryData{}
err := notExpression.Serialize(&out) err := notExpression.Serialize(&out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.String(), " IS TRUE 2 = 1") assert.Equal(t, out.buff.String(), " IS TRUE $1 = $2")
t.Run("and", func(t *testing.T) { t.Run("and", func(t *testing.T) {
exp := notExpression.And(Eq(Literal(4), Literal(5))) exp := notExpression.And(Eq(Literal(4), Literal(5)))
out := bytes.Buffer{} out := queryData{}
err := exp.Serialize(&out) err := exp.Serialize(&out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.String(), `( IS TRUE 2 = 1 AND 4 = 5)`) assert.Equal(t, out.buff.String(), `( IS TRUE $1 = $2 AND $3 = $4)`)
}) })
} }
func TestBoolLiteral(t *testing.T) { func TestBoolLiteral(t *testing.T) {
literal := newBoolLiteralExpression(true) literal := newBoolLiteralExpression(true)
out := bytes.Buffer{} out := queryData{}
err := literal.Serialize(&out) err := literal.Serialize(&out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.String(), "true") assert.Equal(t, out.buff.String(), "$1")
} }

View file

@ -42,6 +42,11 @@ func (q *queryData) InsertArgument(arg interface{}) {
q.buff.WriteString(argPlaceholder) q.buff.WriteString(argPlaceholder)
} }
func (q *queryData) Reset() {
q.buff.Reset()
q.args = []interface{}{}
}
func argToString(value interface{}) (string, error) { func argToString(value interface{}) (string, error) {
switch bindVal := value.(type) { switch bindVal := value.(type) {
case bool: case bool:

View file

@ -1,7 +1,6 @@
package sqlbuilder package sqlbuilder
import ( import (
"bytes"
"gotest.tools/assert" "gotest.tools/assert"
"testing" "testing"
) )
@ -9,23 +8,23 @@ import (
func TestNewBoolColumn(t *testing.T) { func TestNewBoolColumn(t *testing.T) {
boolColumn := NewBoolColumn("col", Nullable) boolColumn := NewBoolColumn("col", Nullable)
out := bytes.Buffer{} out := queryData{}
err := boolColumn.Serialize(&out) err := boolColumn.Serialize(&out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.String(), "col") assert.Equal(t, out.buff.String(), "col")
out.Reset() out.Reset()
err = boolColumn.Serialize(&out, FOR_PROJECTION) err = boolColumn.Serialize(&out, FOR_PROJECTION)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.String(), "col") assert.Equal(t, out.buff.String(), "col")
out.Reset() out.Reset()
err = boolColumn.setTableName("table1") err = boolColumn.setTableName("table1")
assert.NilError(t, err) assert.NilError(t, err)
err = boolColumn.Serialize(&out, FOR_PROJECTION) err = boolColumn.Serialize(&out, FOR_PROJECTION)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.String(), `table1.col AS "table1.col"`) assert.Equal(t, out.buff.String(), `table1.col AS "table1.col"`)
out.Reset() out.Reset()
err = boolColumn.setTableName("table1") err = boolColumn.setTableName("table1")
@ -33,29 +32,29 @@ func TestNewBoolColumn(t *testing.T) {
aliasedBoolColumn := boolColumn.As("alias1") aliasedBoolColumn := boolColumn.As("alias1")
err = aliasedBoolColumn.SerializeForProjection(&out) err = aliasedBoolColumn.SerializeForProjection(&out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.String(), `table1.col AS "alias1"`) assert.Equal(t, out.buff.String(), `table1.col AS "alias1"`)
} }
func TestNewIntColumn(t *testing.T) { func TestNewIntColumn(t *testing.T) {
integerColumn := NewIntegerColumn("col", Nullable) integerColumn := NewIntegerColumn("col", Nullable)
out := bytes.Buffer{} out := queryData{}
err := integerColumn.Serialize(&out) err := integerColumn.Serialize(&out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.String(), "col") assert.Equal(t, out.buff.String(), "col")
out.Reset() out.Reset()
err = integerColumn.Serialize(&out, FOR_PROJECTION) err = integerColumn.Serialize(&out, FOR_PROJECTION)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.String(), "col") assert.Equal(t, out.buff.String(), "col")
out.Reset() out.Reset()
err = integerColumn.setTableName("table1") err = integerColumn.setTableName("table1")
assert.NilError(t, err) assert.NilError(t, err)
err = integerColumn.Serialize(&out, FOR_PROJECTION) err = integerColumn.Serialize(&out, FOR_PROJECTION)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.String(), `table1.col AS "table1.col"`) assert.Equal(t, out.buff.String(), `table1.col AS "table1.col"`)
out.Reset() out.Reset()
err = integerColumn.setTableName("table1") err = integerColumn.setTableName("table1")
@ -63,29 +62,29 @@ func TestNewIntColumn(t *testing.T) {
aliasedBoolColumn := integerColumn.As("alias1") aliasedBoolColumn := integerColumn.As("alias1")
err = aliasedBoolColumn.SerializeForProjection(&out) err = aliasedBoolColumn.SerializeForProjection(&out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.String(), `table1.col AS "alias1"`) assert.Equal(t, out.buff.String(), `table1.col AS "alias1"`)
} }
func TestNewNumericColumnColumn(t *testing.T) { func TestNewNumericColumnColumn(t *testing.T) {
numericColumn := NewNumericColumn("col", Nullable) numericColumn := NewNumericColumn("col", Nullable)
out := bytes.Buffer{} out := queryData{}
err := numericColumn.Serialize(&out) err := numericColumn.Serialize(&out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.String(), "col") assert.Equal(t, out.buff.String(), "col")
out.Reset() out.Reset()
err = numericColumn.Serialize(&out) err = numericColumn.Serialize(&out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.String(), "col") assert.Equal(t, out.buff.String(), "col")
out.Reset() out.Reset()
err = numericColumn.setTableName("table1") err = numericColumn.setTableName("table1")
assert.NilError(t, err) assert.NilError(t, err)
err = numericColumn.Serialize(&out, FOR_PROJECTION) err = numericColumn.Serialize(&out, FOR_PROJECTION)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.String(), `table1.col AS "table1.col"`) assert.Equal(t, out.buff.String(), `table1.col AS "table1.col"`)
out.Reset() out.Reset()
err = numericColumn.setTableName("table1") err = numericColumn.setTableName("table1")
@ -93,5 +92,5 @@ func TestNewNumericColumnColumn(t *testing.T) {
aliasedBoolColumn := numericColumn.As("alias1") aliasedBoolColumn := numericColumn.As("alias1")
err = aliasedBoolColumn.SerializeForProjection(&out) err = aliasedBoolColumn.SerializeForProjection(&out)
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, out.String(), `table1.col AS "alias1"`) assert.Equal(t, out.buff.String(), `table1.col AS "alias1"`)
} }

View file

@ -6,13 +6,13 @@ import (
) )
func TestDeleteUnconditionally(t *testing.T) { func TestDeleteUnconditionally(t *testing.T) {
_, err := table1.Delete().String() _, _, err := table1.DELETE().Sql()
assert.Assert(t, err != nil) assert.Assert(t, err != nil)
} }
func TestDeleteWithWhere(t *testing.T) { func TestDeleteWithWhere(t *testing.T) {
sql, err := table1.Delete().WHERE(table1Col1.EqL(1)).String() sql, _, err := table1.DELETE().WHERE(table1Col1.EqL(1)).Sql()
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, sql, "DELETE FROM db.table1 WHERE table1.col1 = 1;") assert.Equal(t, sql, "DELETE FROM db.table1 WHERE table1.col1 = $1;")
} }

View file

@ -64,19 +64,31 @@ func (c *binaryExpression) Serialize(out *queryData, options ...serializeOption)
if c.lhs == nil { if c.lhs == nil {
return errors.Newf("nil lhs.") return errors.Newf("nil lhs.")
} }
if c.rhs == nil {
return errors.Newf("nil rhs.")
}
_, literalLeft := c.lhs.(*literalExpression)
_, literalRight := c.rhs.(*literalExpression)
if !literalLeft && !literalRight {
out.WriteString("(")
}
if err := c.lhs.Serialize(out); err != nil { if err := c.lhs.Serialize(out); err != nil {
return err return err
} }
out.Write(c.operator) out.Write(c.operator)
if c.rhs == nil {
return errors.Newf("nil rhs.")
}
if err := c.rhs.Serialize(out); err != nil { if err := c.rhs.Serialize(out); err != nil {
return err return err
} }
if !literalLeft && !literalRight {
out.WriteString(")")
}
return nil return nil
} }

View file

@ -54,6 +54,10 @@ func (u *insertStatementImpl) Execute(db types.Db) (res sql.Result, err error) {
// expression or default keyword // expression or default keyword
func (s *insertStatementImpl) VALUES(values ...interface{}) InsertStatement { func (s *insertStatementImpl) VALUES(values ...interface{}) InsertStatement {
if len(values) == 0 {
return s
}
literalRow := []Clause{} literalRow := []Clause{}
for _, value := range values { for _, value := range values {

View file

@ -8,63 +8,62 @@ import (
) )
func TestInsertNoColumn(t *testing.T) { func TestInsertNoColumn(t *testing.T) {
_, err := table1.INSERT().VALUES().String() _, _, err := table1.INSERT().VALUES().Sql()
assert.Assert(t, err != nil) assert.Assert(t, err != nil)
} }
func TestInsertNoRow(t *testing.T) { func TestInsertNoRow(t *testing.T) {
_, err := table1.INSERT(table1Col1).String() _, _, err := table1.INSERT(table1Col1).Sql()
assert.Assert(t, err != nil) assert.Assert(t, err != nil)
} }
func TestInsertColumnLengthMismatch(t *testing.T) { func TestInsertColumnLengthMismatch(t *testing.T) {
_, err := table1.INSERT(table1Col1, table1Col2).VALUES(nil).String() _, _, err := table1.INSERT(table1Col1, table1Col2).VALUES(nil).Sql()
fmt.Println(err) //fmt.Println(err)
assert.Assert(t, err != nil) assert.Assert(t, err != nil)
} }
func TestInsertNilValue(t *testing.T) { func TestInsertNilValue(t *testing.T) {
_, err := table1.INSERT(table1Col1).VALUES(nil).String() query, args, err := table1.INSERT(table1Col1).VALUES(nil).Sql()
assert.Assert(t, err != nil) assert.Equal(t, query, "INSERT INTO db.table1 (col1) VALUES ($1);")
assert.Equal(t, len(args), 1)
assert.NilError(t, err)
} }
func TestInsertNilColumn(t *testing.T) { func TestInsertNilColumn(t *testing.T) {
_, err := table1.INSERT(nil).VALUES(1).String() _, _, err := table1.INSERT(nil).VALUES(1).Sql()
assert.Assert(t, err != nil) assert.Assert(t, err != nil)
} }
func TestInsertSingleValue(t *testing.T) { func TestInsertSingleValue(t *testing.T) {
sql, err := table1.INSERT(table1Col1).VALUES(1).String() sql, _, err := table1.INSERT(table1Col1).VALUES(1).Sql()
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, sql, "INSERT INTO db.table1 (col1) VALUES (1)") assert.Equal(t, sql, "INSERT INTO db.table1 (col1) VALUES ($1);")
} }
func TestInsertDate(t *testing.T) { func TestInsertDate(t *testing.T) {
date := time.Date(1999, 1, 2, 3, 4, 5, 0, time.UTC) date := time.Date(1999, 1, 2, 3, 4, 5, 0, time.UTC)
sql, err := table1.INSERT(table1Col4).VALUES(date).String() sql, _, err := table1.INSERT(table1Col4).VALUES(date).Sql()
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, sql, "INSERT INTO db.table1 (col4) "+ assert.Equal(t, sql, "INSERT INTO db.table1 (col4) VALUES ($1);")
"VALUES ('1999-01-02 03:04:05.000000')")
} }
func TestInsertMultipleValues(t *testing.T) { func TestInsertMultipleValues(t *testing.T) {
stmt := table1.INSERT(table1Col1, table1Col2, table1Col3) stmt := table1.INSERT(table1Col1, table1Col2, table1Col3)
stmt.VALUES(1, 2, 3) stmt.VALUES(1, 2, 3)
sql, err := stmt.String() sql, _, err := stmt.Sql()
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, sql, "INSERT INTO db.table1 "+ assert.Equal(t, sql, "INSERT INTO db.table1 (col1,col2,col3) VALUES ($1, $2, $3);")
"(col1,col2,col3) "+
"VALUES (1,2,3)")
} }
func TestInsertMultipleRows(t *testing.T) { func TestInsertMultipleRows(t *testing.T) {
@ -73,12 +72,10 @@ func TestInsertMultipleRows(t *testing.T) {
VALUES(11, 22). VALUES(11, 22).
VALUES(111, 222) VALUES(111, 222)
sql, err := stmt.String() sql, _, err := stmt.Sql()
assert.NilError(t, err) assert.NilError(t, err)
assert.Equal(t, sql, "INSERT INTO db.table1 "+ assert.Equal(t, sql, "INSERT INTO db.table1 (col1,col2) VALUES ($1, $2), ($3, $4), ($5, $6);")
"(col1,col2) "+
"VALUES (1,2), (11,22), (111,222)")
} }
func TestInsertValuesFromModel(t *testing.T) { func TestInsertValuesFromModel(t *testing.T) {
@ -95,13 +92,13 @@ func TestInsertValuesFromModel(t *testing.T) {
stmt := table1.INSERT(table1Col1, table1Col2). stmt := table1.INSERT(table1Col1, table1Col2).
VALUES_MAPPING(toInsert) VALUES_MAPPING(toInsert)
sql, err := stmt.String() sql, _, err := stmt.Sql()
assert.NilError(t, err) assert.NilError(t, err)
fmt.Println(sql) fmt.Println(sql)
assert.Equal(t, sql, `INSERT INTO db.table1 (col1,col2) VALUES (1,'one')`) assert.Equal(t, sql, `INSERT INTO db.table1 (col1,col2) VALUES ($1, $2);`)
} }
func TestInsertValuesFromModelColumnMismatch(t *testing.T) { func TestInsertValuesFromModelColumnMismatch(t *testing.T) {
@ -118,9 +115,9 @@ func TestInsertValuesFromModelColumnMismatch(t *testing.T) {
stmt := table1.INSERT(table1Col1, table1Col2). stmt := table1.INSERT(table1Col1, table1Col2).
VALUES_MAPPING(toInsert) VALUES_MAPPING(toInsert)
_, err := stmt.String() _, _, err := stmt.Sql()
fmt.Println(err) //fmt.Println(err)
assert.Assert(t, err != nil) assert.Assert(t, err != nil)
} }
@ -129,7 +126,7 @@ func TestInsertQuery(t *testing.T) {
stmt := table1.INSERT(table1Col1). stmt := table1.INSERT(table1Col1).
QUERY(table1.SELECT(table1Col1)) QUERY(table1.SELECT(table1Col1))
stmtStr, err := stmt.String() stmtStr, _, err := stmt.Sql()
assert.NilError(t, err) assert.NilError(t, err)
@ -140,7 +137,7 @@ func TestInsertDefaultValue(t *testing.T) {
stmt := table1.INSERT(table1Col1, table1Col2). stmt := table1.INSERT(table1Col1, table1Col2).
VALUES(DEFAULT, "two") VALUES(DEFAULT, "two")
stmtStr, err := stmt.String() stmtStr, _, err := stmt.Sql()
assert.NilError(t, err) assert.NilError(t, err)

View file

@ -0,0 +1 @@
package sqlbuilder

View file

@ -45,7 +45,7 @@ type WritableTable interface {
INSERT(columns ...Column) InsertStatement INSERT(columns ...Column) InsertStatement
UPDATE(columns ...Column) UpdateStatement UPDATE(columns ...Column) UpdateStatement
Delete() DeleteStatement DELETE() DeleteStatement
} }
// Defines a physical tableName in the database that is both readable and writable. // Defines a physical tableName in the database that is both readable and writable.
@ -220,7 +220,7 @@ func (t *Table) UPDATE(columns ...Column) UpdateStatement {
return newUpdateStatement(t, columns) return newUpdateStatement(t, columns)
} }
func (t *Table) Delete() DeleteStatement { func (t *Table) DELETE() DeleteStatement {
return newDeleteStatement(t) return newDeleteStatement(t)
} }

View file

@ -15,7 +15,7 @@ func TestUpdate(t *testing.T) {
SET(table1.SELECT(table1Col2)). SET(table1.SELECT(table1Col2)).
WHERE(table1Col1.EqL(2)) WHERE(table1Col1.EqL(2))
stmtStr, err := stmt.String() stmtStr, _, err := stmt.Sql()
assert.NilError(t, err) assert.NilError(t, err)