Fix unit tests.

This commit is contained in:
zer0sub 2019-05-01 17:25:10 +02:00
parent 9b826fff6e
commit eccc17dc8a
10 changed files with 91 additions and 72 deletions

View file

@ -1,7 +1,6 @@
package sqlbuilder
import (
"bytes"
"gotest.tools/assert"
"testing"
)
@ -9,100 +8,102 @@ import (
func TestBinaryExpression(t *testing.T) {
boolExpression := Eq(Literal(2), Literal(3))
out := bytes.Buffer{}
out := queryData{}
err := boolExpression.Serialize(&out)
assert.NilError(t, err)
assert.Equal(t, out.String(), "2 = 3")
assert.Equal(t, out.buff.String(), "$1 = $2")
assert.Equal(t, len(out.args), 2)
t.Run("alias", func(t *testing.T) {
alias := boolExpression.As("alias_eq_expression")
out := bytes.Buffer{}
out := queryData{}
err := alias.SerializeForProjection(&out)
assert.NilError(t, err)
assert.Equal(t, out.String(), `2 = 3 AS "alias_eq_expression"`)
assert.Equal(t, out.buff.String(), `$1 = $2 AS "alias_eq_expression"`)
})
t.Run("and", func(t *testing.T) {
exp := boolExpression.And(Eq(Literal(4), Literal(5)))
out := bytes.Buffer{}
out := queryData{}
err := exp.Serialize(&out)
assert.NilError(t, err)
assert.Equal(t, out.String(), `(2 = 3 AND 4 = 5)`)
assert.Equal(t, out.buff.String(), `($1 = $2 AND $3 = $4)`)
})
t.Run("or", func(t *testing.T) {
exp := boolExpression.Or(Eq(Literal(4), Literal(5)))
out := bytes.Buffer{}
out := queryData{}
err := exp.Serialize(&out)
assert.NilError(t, err)
assert.Equal(t, out.String(), `(2 = 3 OR 4 = 5)`)
assert.Equal(t, out.buff.String(), `($1 = $2 OR $3 = $4)`)
})
}
func TestUnaryExpression(t *testing.T) {
notExpression := Not(Eq(Literal(2), Literal(1)))
out := bytes.Buffer{}
out := queryData{}
err := notExpression.Serialize(&out)
assert.NilError(t, err)
assert.Equal(t, out.String(), " NOT 2 = 1")
assert.Equal(t, out.buff.String(), " NOT $1 = $2")
t.Run("alias", func(t *testing.T) {
alias := notExpression.As("alias_not_expression")
out := bytes.Buffer{}
out := queryData{}
err := alias.SerializeForProjection(&out)
assert.NilError(t, err)
assert.Equal(t, out.String(), ` NOT 2 = 1 AS "alias_not_expression"`)
assert.Equal(t, out.buff.String(), ` NOT $1 = $2 AS "alias_not_expression"`)
})
t.Run("and", func(t *testing.T) {
exp := notExpression.And(Eq(Literal(4), Literal(5)))
out := bytes.Buffer{}
out := queryData{}
err := exp.Serialize(&out)
assert.NilError(t, err)
assert.Equal(t, out.String(), `( NOT 2 = 1 AND 4 = 5)`)
assert.Equal(t, out.buff.String(), `( NOT $1 = $2 AND $3 = $4)`)
})
}
func TestUnaryIsTrueExpression(t *testing.T) {
notExpression := IsTrue(Eq(Literal(2), Literal(1)))
out := bytes.Buffer{}
out := queryData{}
err := notExpression.Serialize(&out)
assert.NilError(t, err)
assert.Equal(t, out.String(), " IS TRUE 2 = 1")
assert.Equal(t, out.buff.String(), " IS TRUE $1 = $2")
t.Run("and", func(t *testing.T) {
exp := notExpression.And(Eq(Literal(4), Literal(5)))
out := bytes.Buffer{}
out := queryData{}
err := exp.Serialize(&out)
assert.NilError(t, err)
assert.Equal(t, out.String(), `( IS TRUE 2 = 1 AND 4 = 5)`)
assert.Equal(t, out.buff.String(), `( IS TRUE $1 = $2 AND $3 = $4)`)
})
}
func TestBoolLiteral(t *testing.T) {
literal := newBoolLiteralExpression(true)
out := bytes.Buffer{}
out := queryData{}
err := literal.Serialize(&out)
assert.NilError(t, err)
assert.Equal(t, out.String(), "true")
assert.Equal(t, out.buff.String(), "$1")
}

View file

@ -42,6 +42,11 @@ func (q *queryData) InsertArgument(arg interface{}) {
q.buff.WriteString(argPlaceholder)
}
func (q *queryData) Reset() {
q.buff.Reset()
q.args = []interface{}{}
}
func argToString(value interface{}) (string, error) {
switch bindVal := value.(type) {
case bool:

View file

@ -1,7 +1,6 @@
package sqlbuilder
import (
"bytes"
"gotest.tools/assert"
"testing"
)
@ -9,23 +8,23 @@ import (
func TestNewBoolColumn(t *testing.T) {
boolColumn := NewBoolColumn("col", Nullable)
out := bytes.Buffer{}
out := queryData{}
err := boolColumn.Serialize(&out)
assert.NilError(t, err)
assert.Equal(t, out.String(), "col")
assert.Equal(t, out.buff.String(), "col")
out.Reset()
err = boolColumn.Serialize(&out, FOR_PROJECTION)
assert.NilError(t, err)
assert.Equal(t, out.String(), "col")
assert.Equal(t, out.buff.String(), "col")
out.Reset()
err = boolColumn.setTableName("table1")
assert.NilError(t, err)
err = boolColumn.Serialize(&out, FOR_PROJECTION)
assert.NilError(t, err)
assert.Equal(t, out.String(), `table1.col AS "table1.col"`)
assert.Equal(t, out.buff.String(), `table1.col AS "table1.col"`)
out.Reset()
err = boolColumn.setTableName("table1")
@ -33,29 +32,29 @@ func TestNewBoolColumn(t *testing.T) {
aliasedBoolColumn := boolColumn.As("alias1")
err = aliasedBoolColumn.SerializeForProjection(&out)
assert.NilError(t, err)
assert.Equal(t, out.String(), `table1.col AS "alias1"`)
assert.Equal(t, out.buff.String(), `table1.col AS "alias1"`)
}
func TestNewIntColumn(t *testing.T) {
integerColumn := NewIntegerColumn("col", Nullable)
out := bytes.Buffer{}
out := queryData{}
err := integerColumn.Serialize(&out)
assert.NilError(t, err)
assert.Equal(t, out.String(), "col")
assert.Equal(t, out.buff.String(), "col")
out.Reset()
err = integerColumn.Serialize(&out, FOR_PROJECTION)
assert.NilError(t, err)
assert.Equal(t, out.String(), "col")
assert.Equal(t, out.buff.String(), "col")
out.Reset()
err = integerColumn.setTableName("table1")
assert.NilError(t, err)
err = integerColumn.Serialize(&out, FOR_PROJECTION)
assert.NilError(t, err)
assert.Equal(t, out.String(), `table1.col AS "table1.col"`)
assert.Equal(t, out.buff.String(), `table1.col AS "table1.col"`)
out.Reset()
err = integerColumn.setTableName("table1")
@ -63,29 +62,29 @@ func TestNewIntColumn(t *testing.T) {
aliasedBoolColumn := integerColumn.As("alias1")
err = aliasedBoolColumn.SerializeForProjection(&out)
assert.NilError(t, err)
assert.Equal(t, out.String(), `table1.col AS "alias1"`)
assert.Equal(t, out.buff.String(), `table1.col AS "alias1"`)
}
func TestNewNumericColumnColumn(t *testing.T) {
numericColumn := NewNumericColumn("col", Nullable)
out := bytes.Buffer{}
out := queryData{}
err := numericColumn.Serialize(&out)
assert.NilError(t, err)
assert.Equal(t, out.String(), "col")
assert.Equal(t, out.buff.String(), "col")
out.Reset()
err = numericColumn.Serialize(&out)
assert.NilError(t, err)
assert.Equal(t, out.String(), "col")
assert.Equal(t, out.buff.String(), "col")
out.Reset()
err = numericColumn.setTableName("table1")
assert.NilError(t, err)
err = numericColumn.Serialize(&out, FOR_PROJECTION)
assert.NilError(t, err)
assert.Equal(t, out.String(), `table1.col AS "table1.col"`)
assert.Equal(t, out.buff.String(), `table1.col AS "table1.col"`)
out.Reset()
err = numericColumn.setTableName("table1")
@ -93,5 +92,5 @@ func TestNewNumericColumnColumn(t *testing.T) {
aliasedBoolColumn := numericColumn.As("alias1")
err = aliasedBoolColumn.SerializeForProjection(&out)
assert.NilError(t, err)
assert.Equal(t, out.String(), `table1.col AS "alias1"`)
assert.Equal(t, out.buff.String(), `table1.col AS "alias1"`)
}

View file

@ -6,13 +6,13 @@ import (
)
func TestDeleteUnconditionally(t *testing.T) {
_, err := table1.Delete().String()
_, _, err := table1.DELETE().Sql()
assert.Assert(t, err != nil)
}
func TestDeleteWithWhere(t *testing.T) {
sql, err := table1.Delete().WHERE(table1Col1.EqL(1)).String()
sql, _, err := table1.DELETE().WHERE(table1Col1.EqL(1)).Sql()
assert.NilError(t, err)
assert.Equal(t, sql, "DELETE FROM db.table1 WHERE table1.col1 = 1;")
assert.Equal(t, sql, "DELETE FROM db.table1 WHERE table1.col1 = $1;")
}

View file

@ -64,19 +64,31 @@ func (c *binaryExpression) Serialize(out *queryData, options ...serializeOption)
if c.lhs == nil {
return errors.Newf("nil lhs.")
}
if c.rhs == nil {
return errors.Newf("nil rhs.")
}
_, literalLeft := c.lhs.(*literalExpression)
_, literalRight := c.rhs.(*literalExpression)
if !literalLeft && !literalRight {
out.WriteString("(")
}
if err := c.lhs.Serialize(out); err != nil {
return err
}
out.Write(c.operator)
if c.rhs == nil {
return errors.Newf("nil rhs.")
}
if err := c.rhs.Serialize(out); err != nil {
return err
}
if !literalLeft && !literalRight {
out.WriteString(")")
}
return nil
}

View file

@ -54,6 +54,10 @@ func (u *insertStatementImpl) Execute(db types.Db) (res sql.Result, err error) {
// expression or default keyword
func (s *insertStatementImpl) VALUES(values ...interface{}) InsertStatement {
if len(values) == 0 {
return s
}
literalRow := []Clause{}
for _, value := range values {

View file

@ -8,63 +8,62 @@ import (
)
func TestInsertNoColumn(t *testing.T) {
_, err := table1.INSERT().VALUES().String()
_, _, err := table1.INSERT().VALUES().Sql()
assert.Assert(t, err != nil)
}
func TestInsertNoRow(t *testing.T) {
_, err := table1.INSERT(table1Col1).String()
_, _, err := table1.INSERT(table1Col1).Sql()
assert.Assert(t, err != nil)
}
func TestInsertColumnLengthMismatch(t *testing.T) {
_, err := table1.INSERT(table1Col1, table1Col2).VALUES(nil).String()
_, _, err := table1.INSERT(table1Col1, table1Col2).VALUES(nil).Sql()
fmt.Println(err)
//fmt.Println(err)
assert.Assert(t, err != nil)
}
func TestInsertNilValue(t *testing.T) {
_, err := table1.INSERT(table1Col1).VALUES(nil).String()
query, args, err := table1.INSERT(table1Col1).VALUES(nil).Sql()
assert.Assert(t, err != nil)
assert.Equal(t, query, "INSERT INTO db.table1 (col1) VALUES ($1);")
assert.Equal(t, len(args), 1)
assert.NilError(t, err)
}
func TestInsertNilColumn(t *testing.T) {
_, err := table1.INSERT(nil).VALUES(1).String()
_, _, err := table1.INSERT(nil).VALUES(1).Sql()
assert.Assert(t, err != nil)
}
func TestInsertSingleValue(t *testing.T) {
sql, err := table1.INSERT(table1Col1).VALUES(1).String()
sql, _, err := table1.INSERT(table1Col1).VALUES(1).Sql()
assert.NilError(t, err)
assert.Equal(t, sql, "INSERT INTO db.table1 (col1) VALUES (1)")
assert.Equal(t, sql, "INSERT INTO db.table1 (col1) VALUES ($1);")
}
func TestInsertDate(t *testing.T) {
date := time.Date(1999, 1, 2, 3, 4, 5, 0, time.UTC)
sql, err := table1.INSERT(table1Col4).VALUES(date).String()
sql, _, err := table1.INSERT(table1Col4).VALUES(date).Sql()
assert.NilError(t, err)
assert.Equal(t, sql, "INSERT INTO db.table1 (col4) "+
"VALUES ('1999-01-02 03:04:05.000000')")
assert.Equal(t, sql, "INSERT INTO db.table1 (col4) VALUES ($1);")
}
func TestInsertMultipleValues(t *testing.T) {
stmt := table1.INSERT(table1Col1, table1Col2, table1Col3)
stmt.VALUES(1, 2, 3)
sql, err := stmt.String()
sql, _, err := stmt.Sql()
assert.NilError(t, err)
assert.Equal(t, sql, "INSERT INTO db.table1 "+
"(col1,col2,col3) "+
"VALUES (1,2,3)")
assert.Equal(t, sql, "INSERT INTO db.table1 (col1,col2,col3) VALUES ($1, $2, $3);")
}
func TestInsertMultipleRows(t *testing.T) {
@ -73,12 +72,10 @@ func TestInsertMultipleRows(t *testing.T) {
VALUES(11, 22).
VALUES(111, 222)
sql, err := stmt.String()
sql, _, err := stmt.Sql()
assert.NilError(t, err)
assert.Equal(t, sql, "INSERT INTO db.table1 "+
"(col1,col2) "+
"VALUES (1,2), (11,22), (111,222)")
assert.Equal(t, sql, "INSERT INTO db.table1 (col1,col2) VALUES ($1, $2), ($3, $4), ($5, $6);")
}
func TestInsertValuesFromModel(t *testing.T) {
@ -95,13 +92,13 @@ func TestInsertValuesFromModel(t *testing.T) {
stmt := table1.INSERT(table1Col1, table1Col2).
VALUES_MAPPING(toInsert)
sql, err := stmt.String()
sql, _, err := stmt.Sql()
assert.NilError(t, err)
fmt.Println(sql)
assert.Equal(t, sql, `INSERT INTO db.table1 (col1,col2) VALUES (1,'one')`)
assert.Equal(t, sql, `INSERT INTO db.table1 (col1,col2) VALUES ($1, $2);`)
}
func TestInsertValuesFromModelColumnMismatch(t *testing.T) {
@ -118,9 +115,9 @@ func TestInsertValuesFromModelColumnMismatch(t *testing.T) {
stmt := table1.INSERT(table1Col1, table1Col2).
VALUES_MAPPING(toInsert)
_, err := stmt.String()
_, _, err := stmt.Sql()
fmt.Println(err)
//fmt.Println(err)
assert.Assert(t, err != nil)
}
@ -129,7 +126,7 @@ func TestInsertQuery(t *testing.T) {
stmt := table1.INSERT(table1Col1).
QUERY(table1.SELECT(table1Col1))
stmtStr, err := stmt.String()
stmtStr, _, err := stmt.Sql()
assert.NilError(t, err)
@ -140,7 +137,7 @@ func TestInsertDefaultValue(t *testing.T) {
stmt := table1.INSERT(table1Col1, table1Col2).
VALUES(DEFAULT, "two")
stmtStr, err := stmt.String()
stmtStr, _, err := stmt.Sql()
assert.NilError(t, err)

View file

@ -0,0 +1 @@
package sqlbuilder

View file

@ -45,7 +45,7 @@ type WritableTable interface {
INSERT(columns ...Column) InsertStatement
UPDATE(columns ...Column) UpdateStatement
Delete() DeleteStatement
DELETE() DeleteStatement
}
// Defines a physical tableName in the database that is both readable and writable.
@ -220,7 +220,7 @@ func (t *Table) UPDATE(columns ...Column) UpdateStatement {
return newUpdateStatement(t, columns)
}
func (t *Table) Delete() DeleteStatement {
func (t *Table) DELETE() DeleteStatement {
return newDeleteStatement(t)
}

View file

@ -15,7 +15,7 @@ func TestUpdate(t *testing.T) {
SET(table1.SELECT(table1Col2)).
WHERE(table1Col1.EqL(2))
stmtStr, err := stmt.String()
stmtStr, _, err := stmt.Sql()
assert.NilError(t, err)