diff --git a/sqlbuilder/bool_expression_test.go b/sqlbuilder/bool_expression_test.go index 66907a1..a721bf2 100644 --- a/sqlbuilder/bool_expression_test.go +++ b/sqlbuilder/bool_expression_test.go @@ -1,7 +1,6 @@ package sqlbuilder import ( - "bytes" "gotest.tools/assert" "testing" ) @@ -9,100 +8,102 @@ import ( func TestBinaryExpression(t *testing.T) { boolExpression := Eq(Literal(2), Literal(3)) - out := bytes.Buffer{} + out := queryData{} err := boolExpression.Serialize(&out) assert.NilError(t, err) - assert.Equal(t, out.String(), "2 = 3") + + assert.Equal(t, out.buff.String(), "$1 = $2") + assert.Equal(t, len(out.args), 2) t.Run("alias", func(t *testing.T) { alias := boolExpression.As("alias_eq_expression") - out := bytes.Buffer{} + out := queryData{} err := alias.SerializeForProjection(&out) assert.NilError(t, err) - assert.Equal(t, out.String(), `2 = 3 AS "alias_eq_expression"`) + assert.Equal(t, out.buff.String(), `$1 = $2 AS "alias_eq_expression"`) }) t.Run("and", func(t *testing.T) { exp := boolExpression.And(Eq(Literal(4), Literal(5))) - out := bytes.Buffer{} + out := queryData{} err := exp.Serialize(&out) assert.NilError(t, err) - assert.Equal(t, out.String(), `(2 = 3 AND 4 = 5)`) + assert.Equal(t, out.buff.String(), `($1 = $2 AND $3 = $4)`) }) t.Run("or", func(t *testing.T) { exp := boolExpression.Or(Eq(Literal(4), Literal(5))) - out := bytes.Buffer{} + out := queryData{} err := exp.Serialize(&out) assert.NilError(t, err) - assert.Equal(t, out.String(), `(2 = 3 OR 4 = 5)`) + assert.Equal(t, out.buff.String(), `($1 = $2 OR $3 = $4)`) }) } func TestUnaryExpression(t *testing.T) { notExpression := Not(Eq(Literal(2), Literal(1))) - out := bytes.Buffer{} + out := queryData{} err := notExpression.Serialize(&out) assert.NilError(t, err) - assert.Equal(t, out.String(), " NOT 2 = 1") + assert.Equal(t, out.buff.String(), " NOT $1 = $2") t.Run("alias", func(t *testing.T) { alias := notExpression.As("alias_not_expression") - out := bytes.Buffer{} + out := queryData{} err := alias.SerializeForProjection(&out) assert.NilError(t, err) - assert.Equal(t, out.String(), ` NOT 2 = 1 AS "alias_not_expression"`) + assert.Equal(t, out.buff.String(), ` NOT $1 = $2 AS "alias_not_expression"`) }) t.Run("and", func(t *testing.T) { exp := notExpression.And(Eq(Literal(4), Literal(5))) - out := bytes.Buffer{} + out := queryData{} err := exp.Serialize(&out) assert.NilError(t, err) - assert.Equal(t, out.String(), `( NOT 2 = 1 AND 4 = 5)`) + assert.Equal(t, out.buff.String(), `( NOT $1 = $2 AND $3 = $4)`) }) } func TestUnaryIsTrueExpression(t *testing.T) { notExpression := IsTrue(Eq(Literal(2), Literal(1))) - out := bytes.Buffer{} + out := queryData{} err := notExpression.Serialize(&out) assert.NilError(t, err) - assert.Equal(t, out.String(), " IS TRUE 2 = 1") + assert.Equal(t, out.buff.String(), " IS TRUE $1 = $2") t.Run("and", func(t *testing.T) { exp := notExpression.And(Eq(Literal(4), Literal(5))) - out := bytes.Buffer{} + out := queryData{} err := exp.Serialize(&out) assert.NilError(t, err) - assert.Equal(t, out.String(), `( IS TRUE 2 = 1 AND 4 = 5)`) + assert.Equal(t, out.buff.String(), `( IS TRUE $1 = $2 AND $3 = $4)`) }) } func TestBoolLiteral(t *testing.T) { literal := newBoolLiteralExpression(true) - out := bytes.Buffer{} + out := queryData{} err := literal.Serialize(&out) assert.NilError(t, err) - assert.Equal(t, out.String(), "true") + assert.Equal(t, out.buff.String(), "$1") } diff --git a/sqlbuilder/clause.go b/sqlbuilder/clause.go index 224c419..c5bcb10 100644 --- a/sqlbuilder/clause.go +++ b/sqlbuilder/clause.go @@ -42,6 +42,11 @@ func (q *queryData) InsertArgument(arg interface{}) { q.buff.WriteString(argPlaceholder) } +func (q *queryData) Reset() { + q.buff.Reset() + q.args = []interface{}{} +} + func argToString(value interface{}) (string, error) { switch bindVal := value.(type) { case bool: diff --git a/sqlbuilder/column_types_test.go b/sqlbuilder/column_types_test.go index 319a244..1bbf822 100644 --- a/sqlbuilder/column_types_test.go +++ b/sqlbuilder/column_types_test.go @@ -1,7 +1,6 @@ package sqlbuilder import ( - "bytes" "gotest.tools/assert" "testing" ) @@ -9,23 +8,23 @@ import ( func TestNewBoolColumn(t *testing.T) { boolColumn := NewBoolColumn("col", Nullable) - out := bytes.Buffer{} + out := queryData{} err := boolColumn.Serialize(&out) assert.NilError(t, err) - assert.Equal(t, out.String(), "col") + assert.Equal(t, out.buff.String(), "col") out.Reset() err = boolColumn.Serialize(&out, FOR_PROJECTION) assert.NilError(t, err) - assert.Equal(t, out.String(), "col") + assert.Equal(t, out.buff.String(), "col") out.Reset() err = boolColumn.setTableName("table1") assert.NilError(t, err) err = boolColumn.Serialize(&out, FOR_PROJECTION) assert.NilError(t, err) - assert.Equal(t, out.String(), `table1.col AS "table1.col"`) + assert.Equal(t, out.buff.String(), `table1.col AS "table1.col"`) out.Reset() err = boolColumn.setTableName("table1") @@ -33,29 +32,29 @@ func TestNewBoolColumn(t *testing.T) { aliasedBoolColumn := boolColumn.As("alias1") err = aliasedBoolColumn.SerializeForProjection(&out) assert.NilError(t, err) - assert.Equal(t, out.String(), `table1.col AS "alias1"`) + assert.Equal(t, out.buff.String(), `table1.col AS "alias1"`) } func TestNewIntColumn(t *testing.T) { integerColumn := NewIntegerColumn("col", Nullable) - out := bytes.Buffer{} + out := queryData{} err := integerColumn.Serialize(&out) assert.NilError(t, err) - assert.Equal(t, out.String(), "col") + assert.Equal(t, out.buff.String(), "col") out.Reset() err = integerColumn.Serialize(&out, FOR_PROJECTION) assert.NilError(t, err) - assert.Equal(t, out.String(), "col") + assert.Equal(t, out.buff.String(), "col") out.Reset() err = integerColumn.setTableName("table1") assert.NilError(t, err) err = integerColumn.Serialize(&out, FOR_PROJECTION) assert.NilError(t, err) - assert.Equal(t, out.String(), `table1.col AS "table1.col"`) + assert.Equal(t, out.buff.String(), `table1.col AS "table1.col"`) out.Reset() err = integerColumn.setTableName("table1") @@ -63,29 +62,29 @@ func TestNewIntColumn(t *testing.T) { aliasedBoolColumn := integerColumn.As("alias1") err = aliasedBoolColumn.SerializeForProjection(&out) assert.NilError(t, err) - assert.Equal(t, out.String(), `table1.col AS "alias1"`) + assert.Equal(t, out.buff.String(), `table1.col AS "alias1"`) } func TestNewNumericColumnColumn(t *testing.T) { numericColumn := NewNumericColumn("col", Nullable) - out := bytes.Buffer{} + out := queryData{} err := numericColumn.Serialize(&out) assert.NilError(t, err) - assert.Equal(t, out.String(), "col") + assert.Equal(t, out.buff.String(), "col") out.Reset() err = numericColumn.Serialize(&out) assert.NilError(t, err) - assert.Equal(t, out.String(), "col") + assert.Equal(t, out.buff.String(), "col") out.Reset() err = numericColumn.setTableName("table1") assert.NilError(t, err) err = numericColumn.Serialize(&out, FOR_PROJECTION) assert.NilError(t, err) - assert.Equal(t, out.String(), `table1.col AS "table1.col"`) + assert.Equal(t, out.buff.String(), `table1.col AS "table1.col"`) out.Reset() err = numericColumn.setTableName("table1") @@ -93,5 +92,5 @@ func TestNewNumericColumnColumn(t *testing.T) { aliasedBoolColumn := numericColumn.As("alias1") err = aliasedBoolColumn.SerializeForProjection(&out) assert.NilError(t, err) - assert.Equal(t, out.String(), `table1.col AS "alias1"`) + assert.Equal(t, out.buff.String(), `table1.col AS "alias1"`) } diff --git a/sqlbuilder/delete_statement_test.go b/sqlbuilder/delete_statement_test.go index 0ef31ae..24e3fc9 100644 --- a/sqlbuilder/delete_statement_test.go +++ b/sqlbuilder/delete_statement_test.go @@ -6,13 +6,13 @@ import ( ) func TestDeleteUnconditionally(t *testing.T) { - _, err := table1.Delete().String() + _, _, err := table1.DELETE().Sql() assert.Assert(t, err != nil) } func TestDeleteWithWhere(t *testing.T) { - sql, err := table1.Delete().WHERE(table1Col1.EqL(1)).String() + sql, _, err := table1.DELETE().WHERE(table1Col1.EqL(1)).Sql() assert.NilError(t, err) - assert.Equal(t, sql, "DELETE FROM db.table1 WHERE table1.col1 = 1;") + assert.Equal(t, sql, "DELETE FROM db.table1 WHERE table1.col1 = $1;") } diff --git a/sqlbuilder/expression.go b/sqlbuilder/expression.go index cae2e2f..2fb127d 100644 --- a/sqlbuilder/expression.go +++ b/sqlbuilder/expression.go @@ -64,19 +64,31 @@ func (c *binaryExpression) Serialize(out *queryData, options ...serializeOption) if c.lhs == nil { return errors.Newf("nil lhs.") } + if c.rhs == nil { + return errors.Newf("nil rhs.") + } + + _, literalLeft := c.lhs.(*literalExpression) + _, literalRight := c.rhs.(*literalExpression) + + if !literalLeft && !literalRight { + out.WriteString("(") + } + if err := c.lhs.Serialize(out); err != nil { return err } out.Write(c.operator) - if c.rhs == nil { - return errors.Newf("nil rhs.") - } if err := c.rhs.Serialize(out); err != nil { return err } + if !literalLeft && !literalRight { + out.WriteString(")") + } + return nil } diff --git a/sqlbuilder/insert_statement.go b/sqlbuilder/insert_statement.go index 48ed991..38846b3 100644 --- a/sqlbuilder/insert_statement.go +++ b/sqlbuilder/insert_statement.go @@ -54,6 +54,10 @@ func (u *insertStatementImpl) Execute(db types.Db) (res sql.Result, err error) { // expression or default keyword func (s *insertStatementImpl) VALUES(values ...interface{}) InsertStatement { + if len(values) == 0 { + return s + } + literalRow := []Clause{} for _, value := range values { diff --git a/sqlbuilder/insert_statement_test.go b/sqlbuilder/insert_statement_test.go index be2a600..2f6df0c 100644 --- a/sqlbuilder/insert_statement_test.go +++ b/sqlbuilder/insert_statement_test.go @@ -8,63 +8,62 @@ import ( ) func TestInsertNoColumn(t *testing.T) { - _, err := table1.INSERT().VALUES().String() + _, _, err := table1.INSERT().VALUES().Sql() assert.Assert(t, err != nil) } func TestInsertNoRow(t *testing.T) { - _, err := table1.INSERT(table1Col1).String() + _, _, err := table1.INSERT(table1Col1).Sql() assert.Assert(t, err != nil) } func TestInsertColumnLengthMismatch(t *testing.T) { - _, err := table1.INSERT(table1Col1, table1Col2).VALUES(nil).String() + _, _, err := table1.INSERT(table1Col1, table1Col2).VALUES(nil).Sql() - fmt.Println(err) + //fmt.Println(err) assert.Assert(t, err != nil) } func TestInsertNilValue(t *testing.T) { - _, err := table1.INSERT(table1Col1).VALUES(nil).String() + query, args, err := table1.INSERT(table1Col1).VALUES(nil).Sql() - assert.Assert(t, err != nil) + assert.Equal(t, query, "INSERT INTO db.table1 (col1) VALUES ($1);") + assert.Equal(t, len(args), 1) + assert.NilError(t, err) } func TestInsertNilColumn(t *testing.T) { - _, err := table1.INSERT(nil).VALUES(1).String() + _, _, err := table1.INSERT(nil).VALUES(1).Sql() assert.Assert(t, err != nil) } func TestInsertSingleValue(t *testing.T) { - sql, err := table1.INSERT(table1Col1).VALUES(1).String() + sql, _, err := table1.INSERT(table1Col1).VALUES(1).Sql() assert.NilError(t, err) - assert.Equal(t, sql, "INSERT INTO db.table1 (col1) VALUES (1)") + assert.Equal(t, sql, "INSERT INTO db.table1 (col1) VALUES ($1);") } func TestInsertDate(t *testing.T) { date := time.Date(1999, 1, 2, 3, 4, 5, 0, time.UTC) - sql, err := table1.INSERT(table1Col4).VALUES(date).String() + sql, _, err := table1.INSERT(table1Col4).VALUES(date).Sql() assert.NilError(t, err) - assert.Equal(t, sql, "INSERT INTO db.table1 (col4) "+ - "VALUES ('1999-01-02 03:04:05.000000')") + assert.Equal(t, sql, "INSERT INTO db.table1 (col4) VALUES ($1);") } func TestInsertMultipleValues(t *testing.T) { stmt := table1.INSERT(table1Col1, table1Col2, table1Col3) stmt.VALUES(1, 2, 3) - sql, err := stmt.String() + sql, _, err := stmt.Sql() assert.NilError(t, err) - assert.Equal(t, sql, "INSERT INTO db.table1 "+ - "(col1,col2,col3) "+ - "VALUES (1,2,3)") + assert.Equal(t, sql, "INSERT INTO db.table1 (col1,col2,col3) VALUES ($1, $2, $3);") } func TestInsertMultipleRows(t *testing.T) { @@ -73,12 +72,10 @@ func TestInsertMultipleRows(t *testing.T) { VALUES(11, 22). VALUES(111, 222) - sql, err := stmt.String() + sql, _, err := stmt.Sql() assert.NilError(t, err) - assert.Equal(t, sql, "INSERT INTO db.table1 "+ - "(col1,col2) "+ - "VALUES (1,2), (11,22), (111,222)") + assert.Equal(t, sql, "INSERT INTO db.table1 (col1,col2) VALUES ($1, $2), ($3, $4), ($5, $6);") } func TestInsertValuesFromModel(t *testing.T) { @@ -95,13 +92,13 @@ func TestInsertValuesFromModel(t *testing.T) { stmt := table1.INSERT(table1Col1, table1Col2). VALUES_MAPPING(toInsert) - sql, err := stmt.String() + sql, _, err := stmt.Sql() assert.NilError(t, err) fmt.Println(sql) - assert.Equal(t, sql, `INSERT INTO db.table1 (col1,col2) VALUES (1,'one')`) + assert.Equal(t, sql, `INSERT INTO db.table1 (col1,col2) VALUES ($1, $2);`) } func TestInsertValuesFromModelColumnMismatch(t *testing.T) { @@ -118,9 +115,9 @@ func TestInsertValuesFromModelColumnMismatch(t *testing.T) { stmt := table1.INSERT(table1Col1, table1Col2). VALUES_MAPPING(toInsert) - _, err := stmt.String() + _, _, err := stmt.Sql() - fmt.Println(err) + //fmt.Println(err) assert.Assert(t, err != nil) } @@ -129,7 +126,7 @@ func TestInsertQuery(t *testing.T) { stmt := table1.INSERT(table1Col1). QUERY(table1.SELECT(table1Col1)) - stmtStr, err := stmt.String() + stmtStr, _, err := stmt.Sql() assert.NilError(t, err) @@ -140,7 +137,7 @@ func TestInsertDefaultValue(t *testing.T) { stmt := table1.INSERT(table1Col1, table1Col2). VALUES(DEFAULT, "two") - stmtStr, err := stmt.String() + stmtStr, _, err := stmt.Sql() assert.NilError(t, err) diff --git a/sqlbuilder/set_statement_test.go b/sqlbuilder/set_statement_test.go new file mode 100644 index 0000000..3896e39 --- /dev/null +++ b/sqlbuilder/set_statement_test.go @@ -0,0 +1 @@ +package sqlbuilder diff --git a/sqlbuilder/table.go b/sqlbuilder/table.go index 4510da2..33fa219 100644 --- a/sqlbuilder/table.go +++ b/sqlbuilder/table.go @@ -45,7 +45,7 @@ type WritableTable interface { INSERT(columns ...Column) InsertStatement UPDATE(columns ...Column) UpdateStatement - Delete() DeleteStatement + DELETE() DeleteStatement } // Defines a physical tableName in the database that is both readable and writable. @@ -220,7 +220,7 @@ func (t *Table) UPDATE(columns ...Column) UpdateStatement { return newUpdateStatement(t, columns) } -func (t *Table) Delete() DeleteStatement { +func (t *Table) DELETE() DeleteStatement { return newDeleteStatement(t) } diff --git a/sqlbuilder/update_statement_test.go b/sqlbuilder/update_statement_test.go index 46187d4..cc1be4d 100644 --- a/sqlbuilder/update_statement_test.go +++ b/sqlbuilder/update_statement_test.go @@ -15,7 +15,7 @@ func TestUpdate(t *testing.T) { SET(table1.SELECT(table1Col2)). WHERE(table1Col1.EqL(2)) - stmtStr, err := stmt.String() + stmtStr, _, err := stmt.Sql() assert.NilError(t, err)