From 240ddd65e66c3bdcc9fed29f222d15c201daa9d7 Mon Sep 17 00:00:00 2001 From: zer0sub Date: Sun, 12 May 2019 18:15:23 +0200 Subject: [PATCH] Add statements debug sql support. --- sqlbuilder/alias.go | 2 +- sqlbuilder/bool_expression_test.go | 25 +- sqlbuilder/clause.go | 141 ++++-- sqlbuilder/column.go | 25 +- sqlbuilder/delete_statement.go | 38 +- sqlbuilder/delete_statement_test.go | 8 +- sqlbuilder/execution/execution.go | 4 +- sqlbuilder/expression.go | 2 +- sqlbuilder/expression_table.go | 11 +- sqlbuilder/func_expression.go | 15 +- sqlbuilder/insert_statement.go | 43 +- sqlbuilder/insert_statement_test.go | 40 +- sqlbuilder/lock_statement.go | 20 +- sqlbuilder/lock_statement_test.go | 8 +- sqlbuilder/numeric_expression.go | 4 +- sqlbuilder/select_statement.go | 92 ++-- sqlbuilder/set_statement.go | 31 +- sqlbuilder/set_statement_test.go | 69 ++- sqlbuilder/statement.go | 23 +- sqlbuilder/statement_test.go | 14 +- sqlbuilder/table.go | 14 +- sqlbuilder/update_statement.go | 29 +- sqlbuilder/update_statement_test.go | 17 +- sqlbuilder/utils.go | 8 +- tests/insert_test.go | 9 +- tests/select_test.go | 658 ++++++++++++++++++---------- tests/test_util.go | 89 ++++ 27 files changed, 1013 insertions(+), 426 deletions(-) create mode 100644 tests/test_util.go diff --git a/sqlbuilder/alias.go b/sqlbuilder/alias.go index 686014b..1aae3fe 100644 --- a/sqlbuilder/alias.go +++ b/sqlbuilder/alias.go @@ -20,7 +20,7 @@ func (a *Alias) serializeForProjection(statement statementType, out *queryData) return err } - out.writeString(" AS \"" + a.alias + "\"") + out.writeString(`AS "` + a.alias + `"`) return nil } diff --git a/sqlbuilder/bool_expression_test.go b/sqlbuilder/bool_expression_test.go index eb968ad..457e5de 100644 --- a/sqlbuilder/bool_expression_test.go +++ b/sqlbuilder/bool_expression_test.go @@ -1,6 +1,7 @@ package sqlbuilder import ( + "fmt" "gotest.tools/assert" "testing" ) @@ -118,8 +119,17 @@ func TestExists(t *testing.T) { out := queryData{} err := query.serialize(select_statement, &out) + fmt.Println(out.buff.String()) + assert.NilError(t, err) - assert.Equal(t, out.buff.String(), "EXISTS (SELECT $1 FROM db.table2 WHERE table1.col1 = table2.col3)") + + expectedSql := + `EXISTS ( + SELECT $1 + FROM db.table2 + WHERE table1.col1 = table2.col3 +)` + assert.Equal(t, out.buff.String(), expectedSql) } func TestIn(t *testing.T) { @@ -129,7 +139,11 @@ func TestIn(t *testing.T) { err := query.serialize(select_statement, &out) assert.NilError(t, err) - assert.Equal(t, out.buff.String(), `$1 IN (SELECT table1.col1 AS "table1.col1" FROM db.table1)`) + fmt.Println(out.buff.String()) + assert.Equal(t, out.buff.String(), `$1 IN ( + SELECT table1.col1 AS "table1.col1" + FROM db.table1 +)`) query2 := ROW(Literal(12), table1Col1).IN(table2.SELECT(table2Col3, table3Col1)) @@ -137,5 +151,10 @@ func TestIn(t *testing.T) { err = query2.serialize(select_statement, &out) assert.NilError(t, err) - assert.Equal(t, out.buff.String(), `(ROW($1, table1.col1) IN (SELECT table2.col3 AS "table2.col3", table3.col1 AS "table3.col1" FROM db.table2))`) + fmt.Println(out.buff.String()) + assert.Equal(t, out.buff.String(), `(ROW($1, table1.col1) IN ( + SELECT table2.col3 AS "table2.col3", + table3.col1 AS "table3.col1" + FROM db.table2 +))`) } diff --git a/sqlbuilder/clause.go b/sqlbuilder/clause.go index 5df2d40..a8d73e5 100644 --- a/sqlbuilder/clause.go +++ b/sqlbuilder/clause.go @@ -2,7 +2,6 @@ package sqlbuilder import ( "bytes" - "errors" "strconv" ) @@ -13,6 +12,9 @@ type clause interface { type queryData struct { buff bytes.Buffer args []interface{} + + lastChar byte + ident int } type statementType string @@ -26,48 +28,125 @@ const ( lock_statement statementType = "LOCK" ) +const defaultIdent = 5 + +func (q *queryData) increaseIdent() { + q.ident += defaultIdent +} + +func (q *queryData) decreaseIdent() { + if q.ident < defaultIdent { + q.ident = 0 + } + + q.ident -= defaultIdent +} + func (q *queryData) writeProjection(statement statementType, projections []projection) error { - return serializeProjectionList(statement, projections, q) + q.increaseIdent() + err := serializeProjectionList(statement, projections, q) + q.decreaseIdent() + return err +} + +func (q *queryData) writeFrom(statement statementType, table tableInterface) error { + q.nextLine() + q.writeString("FROM") + + q.increaseIdent() + err := table.serialize(statement, q) + q.decreaseIdent() + + return err } func (q *queryData) writeWhere(statement statementType, where expression) error { - q.writeString(" WHERE ") - return where.serialize(statement, q) + q.nextLine() + q.writeString("WHERE") + + q.increaseIdent() + err := where.serialize(statement, q) + q.decreaseIdent() + + return err } func (q *queryData) writeGroupBy(statement statementType, groupBy []groupByClause) error { - q.writeString(" GROUP BY ") + q.nextLine() + q.writeString("GROUP BY") - return serializeGroupByClauseList(statement, groupBy, q) + q.increaseIdent() + err := serializeGroupByClauseList(statement, groupBy, q) + q.decreaseIdent() + + return err } func (q *queryData) writeOrderBy(statement statementType, orderBy []orderByClause) error { - q.writeString(" ORDER BY ") - return serializeOrderByClauseList(statement, orderBy, q) + q.nextLine() + q.writeString("ORDER BY") + + q.increaseIdent() + err := serializeOrderByClauseList(statement, orderBy, q) + q.decreaseIdent() + + return err } func (q *queryData) writeHaving(statement statementType, having expression) error { - q.writeString(" HAVING ") - return having.serialize(statement, q) + q.nextLine() + q.writeString("HAVING") + + q.increaseIdent() + err := having.serialize(statement, q) + q.decreaseIdent() + + return err +} + +func (q *queryData) nextLine() { + q.write([]byte{'\n'}) + q.write(bytes.Repeat([]byte{' '}, q.ident)) } func (q *queryData) write(data []byte) { + if len(data) == 0 { + return + } + + if !isPreSeparator(q.lastChar) && !isPostSeparator(data[0]) && q.buff.Len() > 0 { + q.buff.WriteByte(' ') + } + q.buff.Write(data) + q.lastChar = data[len(data)-1] +} + +func isPreSeparator(b byte) bool { + return b == ' ' || b == '.' || b == ',' || b == '(' || b == '\n' +} + +func isPostSeparator(b byte) bool { + return b == ' ' || b == '.' || b == ',' || b == ')' || b == '\n' } func (q *queryData) writeString(str string) { - q.buff.WriteString(str) + q.write([]byte(str)) } func (q *queryData) writeByte(b byte) { - q.buff.WriteByte(b) + q.write([]byte{b}) +} + +func (q *queryData) finalize() (string, []interface{}) { + return q.buff.String() + ";\n", q.args } func (q *queryData) insertArgument(arg interface{}) { q.args = append(q.args, arg) argPlaceholder := "$" + strconv.Itoa(len(q.args)) - q.buff.WriteString(argPlaceholder) + q.writeString(argPlaceholder) } func (q *queryData) reset() { @@ -75,49 +154,49 @@ func (q *queryData) reset() { q.args = []interface{}{} } -func argToString(value interface{}) (string, error) { +func ArgToString(value interface{}) string { switch bindVal := value.(type) { case bool: if bindVal { - return "TRUE", nil + return "TRUE" } else { - return "FALSE", nil + return "FALSE" } case int8: - return strconv.FormatInt(int64(bindVal), 10), nil + return strconv.FormatInt(int64(bindVal), 10) case int: - return strconv.FormatInt(int64(bindVal), 10), nil + return strconv.FormatInt(int64(bindVal), 10) case int16: - return strconv.FormatInt(int64(bindVal), 10), nil + return strconv.FormatInt(int64(bindVal), 10) case int32: - return strconv.FormatInt(int64(bindVal), 10), nil + return strconv.FormatInt(int64(bindVal), 10) case int64: - return strconv.FormatInt(int64(bindVal), 10), nil + return strconv.FormatInt(int64(bindVal), 10) case uint8: - return strconv.FormatUint(uint64(bindVal), 10), nil + return strconv.FormatUint(uint64(bindVal), 10) case uint: - return strconv.FormatUint(uint64(bindVal), 10), nil + return strconv.FormatUint(uint64(bindVal), 10) case uint16: - return strconv.FormatUint(uint64(bindVal), 10), nil + return strconv.FormatUint(uint64(bindVal), 10) case uint32: - return strconv.FormatUint(uint64(bindVal), 10), nil + return strconv.FormatUint(uint64(bindVal), 10) case uint64: - return strconv.FormatUint(uint64(bindVal), 10), nil + return strconv.FormatUint(uint64(bindVal), 10) case float32: - return strconv.FormatFloat(float64(bindVal), 'f', -1, 64), nil + return strconv.FormatFloat(float64(bindVal), 'f', -1, 64) case float64: - return strconv.FormatFloat(float64(bindVal), 'f', -1, 64), nil + return strconv.FormatFloat(float64(bindVal), 'f', -1, 64) case string: - return bindVal, nil + return `'` + bindVal + `'` case []byte: - return string(bindVal), nil + return string(bindVal) //TODO: implement //case time.Time: // return bindVal.String()) default: - return "", errors.New("Unsupported literal type. ") + return "[Unknown type]" } } diff --git a/sqlbuilder/column.go b/sqlbuilder/column.go index 0738b88..6f2d71b 100644 --- a/sqlbuilder/column.go +++ b/sqlbuilder/column.go @@ -79,17 +79,16 @@ func (c *baseColumn) DefaultAlias() projection { func (c *baseColumn) serializeAsOrderBy(statement statementType, out *queryData) error { if statement == set_statement { - // set statement (UNION, EXCEPT ...) can reference only select projections in order by clause - out.writeString(`"`) + // set Statement (UNION, EXCEPT ...) can reference only select projections in order by clause + columnRef := "" if c.tableName != "" { - out.writeString(c.tableName) - out.writeString(".") + columnRef += c.tableName + "." } - out.writeString(c.name) + columnRef += c.name - out.writeString(`"`) + out.writeString(`"` + columnRef + `"`) return nil } @@ -98,22 +97,26 @@ func (c *baseColumn) serializeAsOrderBy(statement statementType, out *queryData) } func (c baseColumn) serialize(statement statementType, out *queryData) error { + + columnRef := "" + if c.tableName != "" { - out.writeString(c.tableName) - out.writeString(".") + columnRef += c.tableName + "." } wrapColumnName := strings.Contains(c.name, ".") if wrapColumnName { - out.writeString(`"`) + columnRef += `"` } - out.writeString(c.name) + columnRef += c.name if wrapColumnName { - out.writeString(`"`) + columnRef += `"` } + out.writeString(columnRef) + return nil } diff --git a/sqlbuilder/delete_statement.go b/sqlbuilder/delete_statement.go index 6f071a5..54dcef6 100644 --- a/sqlbuilder/delete_statement.go +++ b/sqlbuilder/delete_statement.go @@ -7,7 +7,7 @@ import ( ) type deleteStatement interface { - statement + Statement WHERE(expression boolExpression) deleteStatement } @@ -28,28 +28,44 @@ func (d *deleteStatementImpl) WHERE(expression boolExpression) deleteStatement { return d } -func (d *deleteStatementImpl) Sql() (query string, args []interface{}, err error) { - queryData := &queryData{} - - queryData.writeString("DELETE FROM ") +func (d *deleteStatementImpl) serializeImpl(out *queryData) error { + out.nextLine() + out.writeString("DELETE FROM") if d.table == nil { - return "", nil, errors.New("nil tableName.") + return errors.New("nil tableName.") } - if err = d.table.serialize(delete_statement, queryData); err != nil { - return + if err := d.table.serialize(delete_statement, out); err != nil { + return err } if d.where == nil { - return "", nil, errors.New("Deleting without a WHERE clause.") + return errors.New("Deleting without a WHERE clause.") } - if err = queryData.writeWhere(delete_statement, d.where); err != nil { + if err := out.writeWhere(delete_statement, d.where); err != nil { + return err + } + + return nil +} + +func (d *deleteStatementImpl) Sql() (query string, args []interface{}, err error) { + queryData := &queryData{} + + err = d.serializeImpl(queryData) + + if err != nil { return } - return queryData.buff.String() + ";", queryData.args, nil + query, args = queryData.finalize() + return +} + +func (d *deleteStatementImpl) DebugSql() (query string, err error) { + return DebugSql(d) } func (u *deleteStatementImpl) Query(db types.Db, destination interface{}) error { diff --git a/sqlbuilder/delete_statement_test.go b/sqlbuilder/delete_statement_test.go index 24e3fc9..c95a340 100644 --- a/sqlbuilder/delete_statement_test.go +++ b/sqlbuilder/delete_statement_test.go @@ -1,6 +1,7 @@ package sqlbuilder import ( + "fmt" "gotest.tools/assert" "testing" ) @@ -14,5 +15,10 @@ func TestDeleteWithWhere(t *testing.T) { sql, _, err := table1.DELETE().WHERE(table1Col1.EqL(1)).Sql() assert.NilError(t, err) - assert.Equal(t, sql, "DELETE FROM db.table1 WHERE table1.col1 = $1;") + fmt.Println(sql) + expectedSql := ` +DELETE FROM db.table1 +WHERE table1.col1 = $1; +` + assert.Equal(t, sql, expectedSql) } diff --git a/sqlbuilder/execution/execution.go b/sqlbuilder/execution/execution.go index 7cb44b1..84880b9 100644 --- a/sqlbuilder/execution/execution.go +++ b/sqlbuilder/execution/execution.go @@ -543,8 +543,8 @@ func isDbBaseType(objType reflect.Type) bool { typeStr := objType.String() switch typeStr { - case "string", "int32", "int16", "float32", "float64", "time.Time", "bool", "[]byte", "[]uint8", - "*string", "*int32", "*int16", "*float32", "*float64", "*time.Time", "*bool", "*[]byte", "*[]uint8": + case "string", "int", "int32", "int16", "float32", "float64", "time.Time", "bool", "[]byte", "[]uint8", + "*string", "*int", "*int32", "*int16", "*float32", "*float64", "*time.Time", "*bool", "*[]byte", "*[]uint8": return true } diff --git a/sqlbuilder/expression.go b/sqlbuilder/expression.go index a06382a..4e55fe2 100644 --- a/sqlbuilder/expression.go +++ b/sqlbuilder/expression.go @@ -30,7 +30,7 @@ func (e *expressionInterfaceImpl) IN(subQuery selectStatement) boolExpression { } func (e *expressionInterfaceImpl) NOT_IN(subQuery selectStatement) boolExpression { - return newBinaryBoolExpression(e.parent, subQuery, "NOT_IN") + return newBinaryBoolExpression(e.parent, subQuery, "NOT IN") } func (e *expressionInterfaceImpl) AS(alias string) projection { diff --git a/sqlbuilder/expression_table.go b/sqlbuilder/expression_table.go index b68f899..147c744 100644 --- a/sqlbuilder/expression_table.go +++ b/sqlbuilder/expression_table.go @@ -10,7 +10,6 @@ type expressionTable interface { type expressionTableImpl struct { statement expression - columns []column alias string } @@ -24,7 +23,7 @@ func (s *expressionTableImpl) TableName() string { } func (s *expressionTableImpl) Columns() []column { - return s.columns + return []column{} } func (s *expressionTableImpl) RefIntColumnName(name string) *IntegerColumn { @@ -42,20 +41,20 @@ func (s *expressionTableImpl) RefIntColumn(column column) *IntegerColumn { } func (s *expressionTableImpl) RefStringColumn(column column) *StringColumn { - strColumn := NewStringColumn(column.Name(), NotNullable) - strColumn.setTableName(column.TableName()) + strColumn := NewStringColumn(column.TableName()+"."+column.Name(), NotNullable) + strColumn.setTableName(s.alias) return strColumn } func (s *expressionTableImpl) serialize(statement statementType, out *queryData) error { - out.writeString("( ") + //out.writeString("( ") err := s.statement.serialize(statement, out) if err != nil { return err } - out.writeString(" ) AS ") + out.writeString("AS") out.writeString(s.alias) return nil diff --git a/sqlbuilder/func_expression.go b/sqlbuilder/func_expression.go index 2d8c325..71b99d3 100644 --- a/sqlbuilder/func_expression.go +++ b/sqlbuilder/func_expression.go @@ -29,8 +29,8 @@ func newFunc(name string, expressions []expression, parent expression) *funcExpr } func (f *funcExpressionImpl) serialize(statement statementType, out *queryData) error { - out.writeString(f.name) - out.writeString("(") + out.writeString(f.name + "(") + err := serializeExpressionList(statement, f.expression, ", ", out) if err != nil { return err @@ -111,7 +111,6 @@ func (c *caseExpression) serialize(statement statementType, out *queryData) erro out.writeString("(CASE") if c.expression != nil { - out.writeString(" ") err := c.expression.serialize(statement, out) if err != nil { @@ -120,7 +119,7 @@ func (c *caseExpression) serialize(statement statementType, out *queryData) erro } if len(c.when) == 0 || len(c.then) == 0 { - return errors.New("Invalid case statement. There should be at least one when/then expression pair. ") + return errors.New("Invalid case Statement. There should be at least one when/then expression pair. ") } if len(c.when) != len(c.then) { @@ -128,14 +127,14 @@ func (c *caseExpression) serialize(statement statementType, out *queryData) erro } for i, when := range c.when { - out.writeString(" WHEN ") + out.writeString("WHEN") err := when.serialize(statement, out) if err != nil { return err } - out.writeString(" THEN ") + out.writeString("THEN") err = c.then[i].serialize(statement, out) if err != nil { @@ -144,7 +143,7 @@ func (c *caseExpression) serialize(statement statementType, out *queryData) erro } if c.els != nil { - out.writeString(" ELSE ") + out.writeString("ELSE") err := c.els.serialize(statement, out) if err != nil { @@ -152,7 +151,7 @@ func (c *caseExpression) serialize(statement statementType, out *queryData) erro } } - out.writeString(" END)") + out.writeString("END)") return nil } diff --git a/sqlbuilder/insert_statement.go b/sqlbuilder/insert_statement.go index 0559d1e..a61680e 100644 --- a/sqlbuilder/insert_statement.go +++ b/sqlbuilder/insert_statement.go @@ -10,9 +10,9 @@ import ( ) type insertStatement interface { - statement + Statement - // Add a row of values to the insert statement. + // Add a row of values to the insert Statement. VALUES(values ...interface{}) insertStatement // Map or stracture mapped to column names VALUES_MAPPING(data interface{}) insertStatement @@ -48,9 +48,9 @@ func (u *insertStatementImpl) Execute(db types.Db) (res sql.Result, err error) { } // expression or default keyword -func (s *insertStatementImpl) VALUES(values ...interface{}) insertStatement { +func (i *insertStatementImpl) VALUES(values ...interface{}) insertStatement { if len(values) == 0 { - return s + return i } literalRow := []clause{} @@ -63,8 +63,8 @@ func (s *insertStatementImpl) VALUES(values ...interface{}) insertStatement { } } - s.rows = append(s.rows, literalRow) - return s + i.rows = append(i.rows, literalRow) + return i } func (i *insertStatementImpl) VALUES_MAPPING(data interface{}) insertStatement { @@ -121,13 +121,19 @@ func (i *insertStatementImpl) addError(err string) { i.errors = append(i.errors, err) } +func (i *insertStatementImpl) DebugSql() (query string, err error) { + return DebugSql(i) +} + func (s *insertStatementImpl) Sql() (sql string, args []interface{}, err error) { if len(s.errors) > 0 { return "", nil, errors.New("sql builder errors: " + strings.Join(s.errors, ", ")) } queryData := &queryData{} - queryData.writeString("INSERT INTO ") + + queryData.nextLine() + queryData.writeString("INSERT INTO") if s.table == nil { return "", nil, errors.Newf("nil tableName.") @@ -135,12 +141,14 @@ func (s *insertStatementImpl) Sql() (sql string, args []interface{}, err error) err = s.table.serialize(insert_statement, queryData) + queryData.writeByte(' ') + if err != nil { return "", nil, err } if len(s.columns) > 0 { - queryData.writeString(" (") + queryData.writeString("(") err = serializeColumnList(insert_statement, s.columns, queryData) @@ -148,7 +156,7 @@ func (s *insertStatementImpl) Sql() (sql string, args []interface{}, err error) return "", nil, err } - queryData.writeString(") ") + queryData.writeString(")") } if len(s.rows) == 0 && s.query == nil { @@ -160,12 +168,17 @@ func (s *insertStatementImpl) Sql() (sql string, args []interface{}, err error) } if len(s.rows) > 0 { - queryData.writeString("VALUES (") + queryData.writeString("VALUES") + for row_i, row := range s.rows { if row_i > 0 { - queryData.writeString(", (") + queryData.writeString(",") } + queryData.increaseIdent() + queryData.nextLine() + queryData.writeString("(") + if len(row) != len(s.columns) { return "", nil, errors.New("# of values does not match # of columns.") } @@ -177,6 +190,7 @@ func (s *insertStatementImpl) Sql() (sql string, args []interface{}, err error) } queryData.writeByte(')') + queryData.decreaseIdent() } } @@ -189,7 +203,8 @@ func (s *insertStatementImpl) Sql() (sql string, args []interface{}, err error) } if len(s.returning) > 0 { - queryData.writeString(" RETURNING ") + queryData.nextLine() + queryData.writeString("RETURNING") err = queryData.writeProjection(insert_statement, s.returning) @@ -198,7 +213,7 @@ func (s *insertStatementImpl) Sql() (sql string, args []interface{}, err error) } } - queryData.writeByte(';') + sql, args = queryData.finalize() - return queryData.buff.String(), queryData.args, nil + return } diff --git a/sqlbuilder/insert_statement_test.go b/sqlbuilder/insert_statement_test.go index 2f6df0c..c3c8949 100644 --- a/sqlbuilder/insert_statement_test.go +++ b/sqlbuilder/insert_statement_test.go @@ -29,7 +29,10 @@ func TestInsertColumnLengthMismatch(t *testing.T) { func TestInsertNilValue(t *testing.T) { query, args, err := table1.INSERT(table1Col1).VALUES(nil).Sql() - assert.Equal(t, query, "INSERT INTO db.table1 (col1) VALUES ($1);") + assert.Equal(t, query, ` +INSERT INTO db.table1 (col1) VALUES + ($1); +`) assert.Equal(t, len(args), 1) assert.NilError(t, err) } @@ -44,7 +47,10 @@ func TestInsertSingleValue(t *testing.T) { sql, _, err := table1.INSERT(table1Col1).VALUES(1).Sql() assert.NilError(t, err) - assert.Equal(t, sql, "INSERT INTO db.table1 (col1) VALUES ($1);") + assert.Equal(t, sql, ` +INSERT INTO db.table1 (col1) VALUES + ($1); +`) } func TestInsertDate(t *testing.T) { @@ -53,7 +59,10 @@ func TestInsertDate(t *testing.T) { sql, _, err := table1.INSERT(table1Col4).VALUES(date).Sql() assert.NilError(t, err) - assert.Equal(t, sql, "INSERT INTO db.table1 (col4) VALUES ($1);") + assert.Equal(t, sql, ` +INSERT INTO db.table1 (col4) VALUES + ($1); +`) } func TestInsertMultipleValues(t *testing.T) { @@ -63,7 +72,14 @@ func TestInsertMultipleValues(t *testing.T) { sql, _, err := stmt.Sql() assert.NilError(t, err) - assert.Equal(t, sql, "INSERT INTO db.table1 (col1,col2,col3) VALUES ($1, $2, $3);") + fmt.Println(sql) + + expectedSql := ` +INSERT INTO db.table1 (col1,col2,col3) VALUES + ($1, $2, $3); +` + + assert.Equal(t, sql, expectedSql) } func TestInsertMultipleRows(t *testing.T) { @@ -75,7 +91,16 @@ func TestInsertMultipleRows(t *testing.T) { sql, _, err := stmt.Sql() assert.NilError(t, err) - assert.Equal(t, sql, "INSERT INTO db.table1 (col1,col2) VALUES ($1, $2), ($3, $4), ($5, $6);") + fmt.Println(sql) + + expectedSql := ` +INSERT INTO db.table1 (col1,col2) VALUES + ($1, $2), + ($3, $4), + ($5, $6); +` + + assert.Equal(t, sql, expectedSql) } func TestInsertValuesFromModel(t *testing.T) { @@ -98,7 +123,10 @@ func TestInsertValuesFromModel(t *testing.T) { fmt.Println(sql) - assert.Equal(t, sql, `INSERT INTO db.table1 (col1,col2) VALUES ($1, $2);`) + assert.Equal(t, sql, ` +INSERT INTO db.table1 (col1,col2) VALUES + ($1, $2); +`) } func TestInsertValuesFromModelColumnMismatch(t *testing.T) { diff --git a/sqlbuilder/lock_statement.go b/sqlbuilder/lock_statement.go index 2f5cb8d..d3251b4 100644 --- a/sqlbuilder/lock_statement.go +++ b/sqlbuilder/lock_statement.go @@ -20,7 +20,7 @@ const ( ) type lockStatement interface { - statement + Statement IN(lockMode lockMode) lockStatement NOWAIT() lockStatement @@ -48,9 +48,13 @@ func (l *lockStatementImpl) NOWAIT() lockStatement { return l } +func (l *lockStatementImpl) DebugSql() (query string, err error) { + return DebugSql(l) +} + func (l *lockStatementImpl) Sql() (query string, args []interface{}, err error) { if l == nil { - return "", nil, errors.New("nil statement.") + return "", nil, errors.New("nil Statement.") } if len(l.tables) == 0 { @@ -59,7 +63,8 @@ func (l *lockStatementImpl) Sql() (query string, args []interface{}, err error) out := &queryData{} - out.writeString("LOCK TABLE ") + out.nextLine() + out.writeString("LOCK TABLE") for i, table := range l.tables { if i > 0 { @@ -74,16 +79,17 @@ func (l *lockStatementImpl) Sql() (query string, args []interface{}, err error) } if l.lockMode != "" { - out.writeString(" IN ") + out.writeString("IN") out.writeString(string(l.lockMode)) - out.writeString(" MODE") + out.writeString("MODE") } if l.nowait { - out.writeString(" NOWAIT") + out.writeString("NOWAIT") } - return out.buff.String(), out.args, nil + query, args = out.finalize() + return } func (l *lockStatementImpl) Query(db types.Db, destination interface{}) error { diff --git a/sqlbuilder/lock_statement_test.go b/sqlbuilder/lock_statement_test.go index cd625d9..af25cea 100644 --- a/sqlbuilder/lock_statement_test.go +++ b/sqlbuilder/lock_statement_test.go @@ -11,7 +11,9 @@ func TestLockSingleTable(t *testing.T) { queryStr, _, err := lock.Sql() assert.NilError(t, err) - assert.Equal(t, queryStr, `LOCK TABLE db.table1 IN ROW SHARE MODE`) + assert.Equal(t, queryStr, ` +LOCK TABLE db.table1 IN ROW SHARE MODE; +`) } func TestLockMultipleTable(t *testing.T) { @@ -20,5 +22,7 @@ func TestLockMultipleTable(t *testing.T) { queryStr, _, err := lock.Sql() assert.NilError(t, err) - assert.Equal(t, queryStr, `LOCK TABLE db.table2, db.table1 IN ACCESS EXCLUSIVE MODE NOWAIT`) + assert.Equal(t, queryStr, ` +LOCK TABLE db.table2, db.table1 IN ACCESS EXCLUSIVE MODE NOWAIT; +`) } diff --git a/sqlbuilder/numeric_expression.go b/sqlbuilder/numeric_expression.go index 8a23b46..17fc60a 100644 --- a/sqlbuilder/numeric_expression.go +++ b/sqlbuilder/numeric_expression.go @@ -131,9 +131,9 @@ func newNumericExpressionWrap(expression expression) numericExpression { } func (c *numericExpressionWrapper) serialize(statement statementType, out *queryData) error { - out.writeString("(") + //out.writeString("(") err := c.expression.serialize(statement, out) - out.writeString(")") + //out.writeString(")") return err } diff --git a/sqlbuilder/select_statement.go b/sqlbuilder/select_statement.go index 6f3ed71..6b39af4 100644 --- a/sqlbuilder/select_statement.go +++ b/sqlbuilder/select_statement.go @@ -7,7 +7,7 @@ import ( ) type selectStatement interface { - statement + Statement expression DISTINCT() selectStatement @@ -84,15 +84,17 @@ func (s *selectStatementImpl) FROM(table readableTable) selectStatement { } func (s *selectStatementImpl) serialize(statement statementType, out *queryData) error { - out.writeString("(") + out.increaseIdent() err := s.serializeImpl(out) + out.decreaseIdent() if err != nil { return err } + out.nextLine() out.writeString(")") return nil @@ -100,10 +102,11 @@ func (s *selectStatementImpl) serialize(statement statementType, out *queryData) func (s *selectStatementImpl) serializeImpl(out *queryData) error { - out.writeString("SELECT ") + out.nextLine() + out.writeString("SELECT") if s.distinct { - out.writeString("DISTINCT ") + out.writeString("DISTINCT") } if s.projections == nil || len(s.projections) == 0 { @@ -116,16 +119,18 @@ func (s *selectStatementImpl) serializeImpl(out *queryData) error { return err } - out.writeString(" FROM ") - if s.table == nil { return errors.Newf("nil tableName.") } - if err := s.table.serialize(select_statement, out); err != nil { + if err := out.writeFrom(select_statement, s.table); err != nil { return err } + //if err := s.table.serialize(select_statement, out); err != nil { + // return err + //} + if s.where != nil { err := out.writeWhere(select_statement, s.where) @@ -159,33 +164,42 @@ func (s *selectStatementImpl) serializeImpl(out *queryData) error { } if s.limit >= 0 { - out.writeString(" LIMIT ") + out.nextLine() + out.writeString("LIMIT") out.insertArgument(s.limit) } if s.offset >= 0 { - out.writeString(" OFFSET ") + out.nextLine() + out.writeString("OFFSET") out.insertArgument(s.offset) } if s.forUpdate { - out.writeString(" FOR UPDATE") + out.nextLine() + out.writeString("FOR UPDATE") } return nil } -// Return the properly escaped SQL statement, against the specified database -func (q *selectStatementImpl) Sql() (query string, args []interface{}, err error) { +// Return the properly escaped SQL Statement, against the specified database +func (s *selectStatementImpl) Sql() (query string, args []interface{}, err error) { queryData := queryData{} - err = q.serializeImpl(&queryData) + err = s.serializeImpl(&queryData) if err != nil { return "", nil, err } - return queryData.buff.String(), queryData.args, nil + query, args = queryData.finalize() + + return +} + +func (s *selectStatementImpl) DebugSql() (query string, err error) { + return DebugSql(s) } func (s *selectStatementImpl) AsTable(alias string) expressionTable { @@ -195,9 +209,9 @@ func (s *selectStatementImpl) AsTable(alias string) expressionTable { } } -func (q *selectStatementImpl) WHERE(expression boolExpression) selectStatement { - q.where = expression - return q +func (s *selectStatementImpl) WHERE(expression boolExpression) selectStatement { + s.where = expression + return s } func (s *selectStatementImpl) GROUP_BY(groupByClauses ...groupByClause) selectStatement { @@ -205,46 +219,46 @@ func (s *selectStatementImpl) GROUP_BY(groupByClauses ...groupByClause) selectSt return s } -func (q *selectStatementImpl) HAVING(expression boolExpression) selectStatement { - q.having = expression - return q +func (s *selectStatementImpl) HAVING(expression boolExpression) selectStatement { + s.having = expression + return s } -func (q *selectStatementImpl) ORDER_BY(clauses ...orderByClause) selectStatement { +func (s *selectStatementImpl) ORDER_BY(clauses ...orderByClause) selectStatement { - q.orderBy = clauses + s.orderBy = clauses - return q + return s } -func (q *selectStatementImpl) OFFSET(offset int64) selectStatement { - q.offset = offset - return q +func (s *selectStatementImpl) OFFSET(offset int64) selectStatement { + s.offset = offset + return s } -func (q *selectStatementImpl) LIMIT(limit int64) selectStatement { - q.limit = limit - return q +func (s *selectStatementImpl) LIMIT(limit int64) selectStatement { + s.limit = limit + return s } -func (q *selectStatementImpl) DISTINCT() selectStatement { - q.distinct = true - return q +func (s *selectStatementImpl) DISTINCT() selectStatement { + s.distinct = true + return s } -func (q *selectStatementImpl) FOR_UPDATE() selectStatement { - q.forUpdate = true - return q +func (s *selectStatementImpl) FOR_UPDATE() selectStatement { + s.forUpdate = true + return s } func (s *selectStatementImpl) Query(db types.Db, destination interface{}) error { return Query(s, db, destination) } -func (u *selectStatementImpl) Execute(db types.Db) (res sql.Result, err error) { - return Execute(u, db) +func (s *selectStatementImpl) Execute(db types.Db) (res sql.Result, err error) { + return Execute(s, db) } -func NumExp(statement selectStatement) numericExpression { - return newNumericExpressionWrap(statement) +func NumExp(expression expression) numericExpression { + return newNumericExpressionWrap(expression) } diff --git a/sqlbuilder/set_statement.go b/sqlbuilder/set_statement.go index 1609414..a18ccee 100644 --- a/sqlbuilder/set_statement.go +++ b/sqlbuilder/set_statement.go @@ -13,7 +13,7 @@ const ( ) type setStatement interface { - statement + Statement expression ORDER_BY(clauses ...orderByClause) setStatement @@ -97,8 +97,10 @@ func (us *setStatementImpl) AsTable(alias string) expressionTable { } func (s *setStatementImpl) serialize(statement statementType, out *queryData) error { + if s.orderBy != nil || s.limit >= 0 || s.offset >= 0 { out.writeString("(") + out.increaseIdent() } err := s.serializeImpl(out) @@ -108,6 +110,8 @@ func (s *setStatementImpl) serialize(statement statementType, out *queryData) er } if s.orderBy != nil || s.limit >= 0 || s.offset >= 0 { + out.decreaseIdent() + out.nextLine() out.writeString(")") } @@ -117,18 +121,22 @@ func (s *setStatementImpl) serialize(statement statementType, out *queryData) er func (s *setStatementImpl) serializeImpl(out *queryData) error { if len(s.selects) < 2 { - return errors.Newf("UNION statement must have at least two SELECT statements.") + return errors.Newf("UNION Statement must have at least two SELECT statements.") } + out.nextLine() out.writeString("(") + out.increaseIdent() for i, selectStmt := range s.selects { + out.nextLine() if i > 0 { - out.writeString(" " + s.operator + " ") + out.writeString(s.operator) if s.all { - out.writeString(" ALL ") + out.writeString("ALL") } + out.nextLine() } err := selectStmt.serialize(set_statement, out) @@ -138,6 +146,8 @@ func (s *setStatementImpl) serializeImpl(out *queryData) error { } } + out.decreaseIdent() + out.nextLine() out.writeString(")") if s.orderBy != nil { @@ -148,12 +158,14 @@ func (s *setStatementImpl) serializeImpl(out *queryData) error { } if s.limit >= 0 { - out.writeString(" LIMIT ") + out.nextLine() + out.writeString("LIMIT") out.insertArgument(s.limit) } if s.offset >= 0 { - out.writeString(" OFFSET ") + out.nextLine() + out.writeString("OFFSET") out.insertArgument(s.offset) } @@ -169,7 +181,12 @@ func (us *setStatementImpl) Sql() (query string, args []interface{}, err error) return } - return queryData.buff.String(), queryData.args, nil + query, args = queryData.finalize() + return +} + +func (s *setStatementImpl) DebugSql() (query string, err error) { + return DebugSql(s) } func (s *setStatementImpl) Query(db types.Db, destination interface{}) error { diff --git a/sqlbuilder/set_statement_test.go b/sqlbuilder/set_statement_test.go index 91c038d..88caebe 100644 --- a/sqlbuilder/set_statement_test.go +++ b/sqlbuilder/set_statement_test.go @@ -1,6 +1,7 @@ package sqlbuilder import ( + "fmt" "gotest.tools/assert" "testing" ) @@ -28,7 +29,20 @@ func TestUnionTwoSelect(t *testing.T) { ).Sql() assert.NilError(t, err) - assert.Equal(t, query, `((SELECT table1.col1 AS "table1.col1" FROM db.table1) UNION (SELECT table2.col3 AS "table2.col3" FROM db.table2))`) + fmt.Println(query) + assert.Equal(t, query, ` +( + ( + SELECT table1.col1 AS "table1.col1" + FROM db.table1 + ) + UNION + ( + SELECT table2.col3 AS "table2.col3" + FROM db.table2 + ) +); +`) assert.Equal(t, len(args), 0) } @@ -39,8 +53,26 @@ func TestUnionThreeSelect(t *testing.T) { table3.SELECT(table3Col1), ).Sql() + fmt.Println(query) assert.NilError(t, err) - assert.Equal(t, query, `((SELECT table1.col1 AS "table1.col1" FROM db.table1) UNION (SELECT table2.col3 AS "table2.col3" FROM db.table2) UNION (SELECT table3.col1 AS "table3.col1" FROM db.table3))`) + assert.Equal(t, query, ` +( + ( + SELECT table1.col1 AS "table1.col1" + FROM db.table1 + ) + UNION + ( + SELECT table2.col3 AS "table2.col3" + FROM db.table2 + ) + UNION + ( + SELECT table3.col1 AS "table3.col1" + FROM db.table3 + ) +); +`) assert.Equal(t, len(args), 0) } @@ -51,7 +83,21 @@ func TestUnionWithOrderBy(t *testing.T) { ).ORDER_BY(table1Col1.ASC()).Sql() assert.NilError(t, err) - assert.Equal(t, query, `((SELECT table1.col1 AS "table1.col1" FROM db.table1) UNION (SELECT table2.col3 AS "table2.col3" FROM db.table2)) ORDER BY "table1.col1" ASC`) + fmt.Println(query) + assert.Equal(t, query, ` +( + ( + SELECT table1.col1 AS "table1.col1" + FROM db.table1 + ) + UNION + ( + SELECT table2.col3 AS "table2.col3" + FROM db.table2 + ) +) +ORDER BY "table1.col1" ASC; +`) assert.Equal(t, len(args), 0) } @@ -62,6 +108,21 @@ func TestUnionWithLimit(t *testing.T) { ).LIMIT(10).OFFSET(11).Sql() assert.NilError(t, err) - assert.Equal(t, query, `((SELECT table1.col1 AS "table1.col1" FROM db.table1) UNION (SELECT table2.col3 AS "table2.col3" FROM db.table2)) LIMIT $1 OFFSET $2`) + fmt.Println(query) + assert.Equal(t, query, ` +( + ( + SELECT table1.col1 AS "table1.col1" + FROM db.table1 + ) + UNION + ( + SELECT table2.col3 AS "table2.col3" + FROM db.table2 + ) +) +LIMIT $1 +OFFSET $2; +`) assert.Equal(t, len(args), 2) } diff --git a/sqlbuilder/statement.go b/sqlbuilder/statement.go index e703ed9..6681013 100644 --- a/sqlbuilder/statement.go +++ b/sqlbuilder/statement.go @@ -3,12 +3,33 @@ package sqlbuilder import ( "database/sql" "github.com/sub0zero/go-sqlbuilder/types" + "strconv" + "strings" ) -type statement interface { +type Statement interface { // String returns generated SQL as string. Sql() (query string, args []interface{}, err error) + DebugSql() (query string, err error) + Query(db types.Db, destination interface{}) error Execute(db types.Db) (sql.Result, error) } + +func DebugSql(statement Statement) (string, error) { + sql, args, err := statement.Sql() + + if err != nil { + return "", err + } + + debugSql := sql + + for i, arg := range args { + argPlaceholder := "$" + strconv.Itoa(i+1) + debugSql = strings.Replace(debugSql, argPlaceholder, ArgToString(arg), 1) + } + + return debugSql, nil +} diff --git a/sqlbuilder/statement_test.go b/sqlbuilder/statement_test.go index 83c50f7..7ed4fd9 100644 --- a/sqlbuilder/statement_test.go +++ b/sqlbuilder/statement_test.go @@ -18,7 +18,7 @@ var _ = gc.Suite(&StmtSuite{}) // NOTE: tables / columns are defined in test_utils.go // -// SELECT statement tests +// SELECT Statement tests // func (s *StmtSuite) TestSelectEmptyProjection(c *gc.C) { @@ -233,7 +233,7 @@ func (s *StmtSuite) TestSelectDistinct(c *gc.C) { } // -// INSERT statement tests +// INSERT Statement tests // func (s *StmtSuite) TestInsertNoColumn(c *gc.C) { @@ -386,7 +386,7 @@ func (s *StmtSuite) TestOnDuplicateKeyUpdateMulti(c *gc.C) { } // -// LOCK/UNLOCK statement tests ================================================ +// LOCK/UNLOCK Statement tests ================================================ // func (s *StmtSuite) TestLockStatement(c *gc.C) { @@ -444,7 +444,7 @@ func (s *StmtSuite) TestUnionLimitWithoutOrderBy(c *gc.C) { c.Assert( errors.GetMessage(err), gc.Equals, - "All inner selects in UNION statement must have LIMIT if they have ORDER BY") + "All inner selects in UNION Statement must have LIMIT if they have ORDER BY") } func (s *StmtSuite) TestUnionSelectWithMismatchedColumns(c *gc.C) { @@ -472,7 +472,7 @@ func (s *StmtSuite) TestUnionSelectWithMismatchedColumns(c *gc.C) { c.Assert( errors.GetMessage(err), gc.Equals, - "All inner selects in UNION statement must select the "+ + "All inner selects in UNION Statement must select the "+ "same number of columns. For sanity, you probably "+ "want to select the same tableName columns in the same "+ "orderBy. If you are selecting on multiple tables, "+ @@ -481,8 +481,8 @@ func (s *StmtSuite) TestUnionSelectWithMismatchedColumns(c *gc.C) { func (s *StmtSuite) TestComplicatedUnionSelectWithWhereStatement(c *gc.C) { - // tests on outer statement: Group By, Order By, LIMIT - // on inner statement: AndWhere, WHERE (with AND), Order By, LIMIT + // tests on outer Statement: Group By, Order By, LIMIT + // on inner Statement: AndWhere, WHERE (with AND), Order By, LIMIT select_queries := make([]selectStatement, 0, 3) // We're not trying to write a SQL parser, so we won't warn if you do something silly like diff --git a/sqlbuilder/table.go b/sqlbuilder/table.go index 5bd4788..325dbcd 100644 --- a/sqlbuilder/table.go +++ b/sqlbuilder/table.go @@ -106,7 +106,7 @@ func (t *Table) Columns() []column { } // Generates the sql string for the current tableName expression. Note: the -// generated string may not be a valid/executable sql statement. +// generated string may not be a valid/executable sql Statement. func (t *Table) serialize(statement statementType, out *queryData) error { if t == nil { return errors.Newf("nil tableName.") @@ -287,17 +287,19 @@ func (t *joinTable) serialize(statement statementType, out *queryData) (err erro return } + out.nextLine() + switch t.join_type { case INNER_JOIN: - out.writeString(" JOIN ") + out.writeString("JOIN") case LEFT_JOIN: - out.writeString(" LEFT JOIN ") + out.writeString("LEFT JOIN") case RIGHT_JOIN: - out.writeString(" RIGHT JOIN ") + out.writeString("RIGHT JOIN") case FULL_JOIN: - out.writeString(" FULL JOIN ") + out.writeString("FULL JOIN") case CROSS_JOIN: - out.writeString(" CROSS JOIN ") + out.writeString("CROSS JOIN") } if err = t.rhs.serialize(statement, out); err != nil { diff --git a/sqlbuilder/update_statement.go b/sqlbuilder/update_statement.go index 88a186f..7d55f8b 100644 --- a/sqlbuilder/update_statement.go +++ b/sqlbuilder/update_statement.go @@ -7,7 +7,7 @@ import ( ) type updateStatement interface { - statement + Statement SET(values ...interface{}) updateStatement WHERE(expression boolExpression) updateStatement @@ -55,7 +55,8 @@ func (u *updateStatementImpl) RETURNING(projections ...projection) updateStateme func (u *updateStatementImpl) Sql() (sql string, args []interface{}, err error) { out := &queryData{} - out.writeString("UPDATE ") + out.nextLine() + out.writeString("UPDATE") if u.table == nil { return "", nil, errors.New("nil tableName.") @@ -69,12 +70,10 @@ func (u *updateStatementImpl) Sql() (sql string, args []interface{}, err error) return "", nil, errors.New("No column updated.") } - out.writeString(" SET") + out.writeString("SET") if len(u.columns) > 1 { - out.writeString(" ( ") - } else { - out.writeString(" ") + out.writeString("(") } err = serializeColumnList(update_statement, u.columns, out) @@ -84,13 +83,13 @@ func (u *updateStatementImpl) Sql() (sql string, args []interface{}, err error) } if len(u.columns) > 1 { - out.writeString(" )") + out.writeString(")") } - out.writeString(" =") + out.writeString("=") if len(u.updateValues) > 1 { - out.writeString(" (") + out.writeString("(") } for i, value := range u.updateValues { @@ -106,7 +105,7 @@ func (u *updateStatementImpl) Sql() (sql string, args []interface{}, err error) } if len(u.updateValues) > 1 { - out.writeString(" )") + out.writeString(")") } if u.where == nil { @@ -118,7 +117,8 @@ func (u *updateStatementImpl) Sql() (sql string, args []interface{}, err error) } if len(u.returning) > 0 { - out.writeString(" RETURNING ") + out.nextLine() + out.writeString("RETURNING") err = serializeProjectionList(update_statement, u.returning, out) @@ -127,7 +127,12 @@ func (u *updateStatementImpl) Sql() (sql string, args []interface{}, err error) } } - return out.buff.String(), out.args, nil + sql, args = out.finalize() + return +} + +func (u *updateStatementImpl) DebugSql() (query string, err error) { + return DebugSql(u) } func (u *updateStatementImpl) Query(db types.Db, destination interface{}) error { diff --git a/sqlbuilder/update_statement_test.go b/sqlbuilder/update_statement_test.go index bb6dcd9..d6e9daa 100644 --- a/sqlbuilder/update_statement_test.go +++ b/sqlbuilder/update_statement_test.go @@ -7,19 +7,30 @@ import ( ) // -// UPDATE statement tests ===================================================== +// UPDATE Statement tests ===================================================== // func TestUpdate(t *testing.T) { stmt := table1.UPDATE(table1Col1, table1Col2). - SET(table1.SELECT(table1Col2)). - WHERE(table1Col1.EqL(2)) + SET(table1.SELECT(table1Col2, table2Col3)). + WHERE(table1Col1.EqL(2)). + RETURNING(table1Col1) stmtStr, _, err := stmt.Sql() assert.NilError(t, err) fmt.Println(stmtStr) + + assert.Equal(t, stmtStr, ` +UPDATE db.table1 SET (col1,col2) = ( + SELECT table1.col2 AS "table1.col2", + table2.col3 AS "table2.col3" + FROM db.table1 +) +WHERE table1.col1 = $1 +RETURNING table1.col1 AS "table1.col1"; +`) } //func (s *StmtSuite) TestUpdateNilColumn(c *gc.C) { diff --git a/sqlbuilder/utils.go b/sqlbuilder/utils.go index 1ea6f0c..859113b 100644 --- a/sqlbuilder/utils.go +++ b/sqlbuilder/utils.go @@ -82,8 +82,10 @@ func serializeExpressionList(statement statementType, expressions []expression, func serializeProjectionList(statement statementType, projections []projection, out *queryData) error { for i, col := range projections { if i > 0 { - out.writeString(", ") + out.writeString(",") + out.nextLine() } + if col == nil { return errors.New("projection expression is nil.") } @@ -112,7 +114,7 @@ func serializeColumnList(statement statementType, columns []column, out *queryDa return nil } -func Query(statement statement, db types.Db, destination interface{}) error { +func Query(statement Statement, db types.Db, destination interface{}) error { query, args, err := statement.Sql() if err != nil { @@ -122,7 +124,7 @@ func Query(statement statement, db types.Db, destination interface{}) error { return execution.Query(db, query, args, destination) } -func Execute(statement statement, db types.Db) (res sql.Result, err error) { +func Execute(statement Statement, db types.Db) (res sql.Result, err error) { query, args, err := statement.Sql() if err != nil { diff --git a/tests/insert_test.go b/tests/insert_test.go index 48482fb..b6154fc 100644 --- a/tests/insert_test.go +++ b/tests/insert_test.go @@ -25,7 +25,14 @@ func TestInsertValues(t *testing.T) { fmt.Println(insertQueryStr) - assert.Equal(t, insertQueryStr, `INSERT INTO test_sample.link (url,name,rel) VALUES ($1, $2, DEFAULT), ($3, $4, DEFAULT), ($5, $6, DEFAULT), ($7, $8, DEFAULT) RETURNING link.id AS "link.id";`) + assert.Equal(t, insertQueryStr, ` +INSERT INTO test_sample.link (url,name,rel) VALUES + ($1, $2, DEFAULT), + ($3, $4, DEFAULT), + ($5, $6, DEFAULT), + ($7, $8, DEFAULT) +RETURNING link.id AS "link.id"; +`) res, err := insertQuery.Execute(db) assert.NilError(t, err) diff --git a/tests/select_test.go b/tests/select_test.go index 81c4c37..db341c8 100644 --- a/tests/select_test.go +++ b/tests/select_test.go @@ -2,29 +2,33 @@ package tests import ( "fmt" - "github.com/sub0zero/go-sqlbuilder/sqlbuilder" + . "github.com/sub0zero/go-sqlbuilder/sqlbuilder" "github.com/sub0zero/go-sqlbuilder/tests/.test_files/dvd_rental/dvds/model" . "github.com/sub0zero/go-sqlbuilder/tests/.test_files/dvd_rental/dvds/table" "gotest.tools/assert" - "strings" "testing" - "time" ) func TestSelect_ScanToStruct(t *testing.T) { - actor := model.Actor{} + expectedSql := ` +SELECT actor.actor_id AS "actor.actor_id", + actor.first_name AS "actor.first_name", + actor.last_name AS "actor.last_name", + actor.last_update AS "actor.last_update" +FROM dvds.actor +WHERE actor.actor_id = 1 +ORDER BY actor.actor_id ASC; +` + query := Actor. SELECT(Actor.AllColumns). + WHERE(Actor.ActorID.EqL(1)). ORDER_BY(Actor.ActorID.ASC()) - queryStr, args, err := query.Sql() + assertQuery(t, query, expectedSql, 1) - fmt.Println(queryStr) - - assert.Equal(t, queryStr, `SELECT actor.actor_id AS "actor.actor_id", actor.first_name AS "actor.first_name", actor.last_name AS "actor.last_name", actor.last_update AS "actor.last_update" FROM dvds.actor ORDER BY actor.actor_id ASC`) - assert.Equal(t, len(args), 0) - - err = query.Query(db, &actor) + actor := model.Actor{} + err := query.Query(db, &actor) assert.NilError(t, err) @@ -39,37 +43,68 @@ func TestSelect_ScanToStruct(t *testing.T) { } func TestClassicSelect(t *testing.T) { - query := sqlbuilder.SELECT(Payment.AllColumns, Customer.AllColumns). + expectedSql := ` +SELECT payment.payment_id AS "payment.payment_id", + payment.customer_id AS "payment.customer_id", + payment.staff_id AS "payment.staff_id", + payment.rental_id AS "payment.rental_id", + payment.amount AS "payment.amount", + payment.payment_date AS "payment.payment_date", + customer.customer_id AS "customer.customer_id", + customer.store_id AS "customer.store_id", + customer.first_name AS "customer.first_name", + customer.last_name AS "customer.last_name", + customer.email AS "customer.email", + customer.address_id AS "customer.address_id", + customer.activebool AS "customer.activebool", + customer.create_date AS "customer.create_date", + customer.last_update AS "customer.last_update", + customer.active AS "customer.active" +FROM dvds.payment + JOIN dvds.customer ON payment.customer_id = customer.customer_id +ORDER BY payment.payment_id ASC +LIMIT 30; +` + + query := SELECT(Payment.AllColumns, Customer.AllColumns). FROM(Payment.INNER_JOIN(Customer, Payment.CustomerID.Eq(Customer.CustomerID))). ORDER_BY(Payment.PaymentID.ASC()). LIMIT(30) - queryStr, args, err := query.Sql() - - assert.NilError(t, err) - fmt.Println(queryStr) - fmt.Println(args) + assertQuery(t, query, expectedSql, int64(30)) dest := []model.Payment{} - err = query.Query(db, &dest) + err := query.Query(db, &dest) assert.NilError(t, err) + assert.Equal(t, len(dest), 30) + + //spew.Dump(dest) } func TestSelect_ScanToSlice(t *testing.T) { + expectedSql := ` +SELECT customer.customer_id AS "customer.customer_id", + customer.store_id AS "customer.store_id", + customer.first_name AS "customer.first_name", + customer.last_name AS "customer.last_name", + customer.email AS "customer.email", + customer.address_id AS "customer.address_id", + customer.activebool AS "customer.activebool", + customer.create_date AS "customer.create_date", + customer.last_update AS "customer.last_update", + customer.active AS "customer.active" +FROM dvds.customer +ORDER BY customer.customer_id ASC; +` customers := []model.Customer{} query := Customer.SELECT(Customer.AllColumns).ORDER_BY(Customer.CustomerID.ASC()) - queryStr, args, err := query.Sql() - assert.NilError(t, err) - fmt.Println(queryStr) + assertQuery(t, query, expectedSql) - assert.Equal(t, queryStr, `SELECT customer.customer_id AS "customer.customer_id", customer.store_id AS "customer.store_id", customer.first_name AS "customer.first_name", customer.last_name AS "customer.last_name", customer.email AS "customer.email", customer.address_id AS "customer.address_id", customer.activebool AS "customer.activebool", customer.create_date AS "customer.create_date", customer.last_update AS "customer.last_update", customer.active AS "customer.active" FROM dvds.customer ORDER BY customer.customer_id ASC`) - assert.Equal(t, len(args), 0) - - err = query.Query(db, &customers) + err := query.Query(db, &customers) assert.NilError(t, err) assert.Equal(t, len(customers), 599) @@ -80,20 +115,48 @@ func TestSelect_ScanToSlice(t *testing.T) { } func TestSelectAndUnionInProjection(t *testing.T) { + expectedSql := ` +SELECT payment.payment_id AS "payment.payment_id", + ( + SELECT customer.customer_id AS "customer.customer_id" + FROM dvds.customer + LIMIT 1 + ), + ( + ( + ( + SELECT payment.payment_id AS "payment.payment_id" + FROM dvds.payment + LIMIT 1 + OFFSET 10 + ) + UNION + ( + SELECT payment.payment_id AS "payment.payment_id" + FROM dvds.payment + LIMIT 1 + OFFSET 2 + ) + ) + LIMIT 1 + ) +FROM dvds.payment +LIMIT 12; +` query := Payment. SELECT( Payment.PaymentID, Customer.SELECT(Customer.CustomerID).LIMIT(1), - sqlbuilder.UNION(Payment.SELECT(Payment.PaymentID).LIMIT(1).OFFSET(10), Payment.SELECT(Payment.PaymentID).LIMIT(1).OFFSET(2)).LIMIT(1), + UNION( + Payment.SELECT(Payment.PaymentID).LIMIT(1).OFFSET(10), + Payment.SELECT(Payment.PaymentID).LIMIT(1).OFFSET(2), + ).LIMIT(1), ). LIMIT(12) - queryStr, args, err := query.Sql() - - assert.NilError(t, err) - fmt.Println(queryStr) - fmt.Println(args) + fmt.Println(query.Sql()) + assertQuery(t, query, expectedSql, int64(1), int64(1), int64(10), int64(1), int64(2), int64(1), int64(12)) } //func TestJoinQueryStruct(t *testing.T) { @@ -122,6 +185,29 @@ func TestSelectAndUnionInProjection(t *testing.T) { //} func TestJoinQuerySlice(t *testing.T) { + expectedSql := ` +SELECT language.language_id AS "language.language_id", + language.name AS "language.name", + language.last_update AS "language.last_update", + film.film_id AS "film.film_id", + film.title AS "film.title", + film.description AS "film.description", + film.release_year AS "film.release_year", + film.language_id AS "film.language_id", + film.rental_duration AS "film.rental_duration", + film.rental_rate AS "film.rental_rate", + film.length AS "film.length", + film.replacement_cost AS "film.replacement_cost", + film.rating AS "film.rating", + film.last_update AS "film.last_update", + film.special_features AS "film.special_features", + film.fulltext AS "film.fulltext" +FROM dvds.film + JOIN dvds.language ON film.language_id = language.language_id +WHERE film.rating = 'NC-17' +LIMIT 15; +` + type FilmsPerLanguage struct { Language *model.Language Film []model.Film @@ -136,25 +222,11 @@ func TestJoinQuerySlice(t *testing.T) { WHERE(Film.Rating.EqString(string(model.MpaaRating_NC17))). LIMIT(15) - queryStr, args, err := query.Sql() + assertQuery(t, query, expectedSql, string(model.MpaaRating_NC17), int64(15)) + + err := query.Query(db, &filmsPerLanguage) assert.NilError(t, err) - fmt.Println(queryStr) - assert.Equal(t, queryStr, `SELECT language.language_id AS "language.language_id", language.name AS "language.name", language.last_update AS "language.last_update", film.film_id AS "film.film_id", film.title AS "film.title", film.description AS "film.description", film.release_year AS "film.release_year", film.language_id AS "film.language_id", film.rental_duration AS "film.rental_duration", film.rental_rate AS "film.rental_rate", film.length AS "film.length", film.replacement_cost AS "film.replacement_cost", film.rating AS "film.rating", film.last_update AS "film.last_update", film.special_features AS "film.special_features", film.fulltext AS "film.fulltext" FROM dvds.film JOIN dvds.language ON film.language_id = language.language_id WHERE film.rating = $1 LIMIT $2`) - - assert.Equal(t, len(args), 2) - assert.Equal(t, args[0], string(model.MpaaRating_NC17)) - assert.Equal(t, args[1], int64(15)) - - err = query.Query(db, &filmsPerLanguage) - - assert.NilError(t, err) - - //fmt.Println("--------------- result --------------- ") - //spew.Dump(filmsPerLanguage) - - //spew.Dump(filmsPerLanguage) - assert.Equal(t, len(filmsPerLanguage), 1) assert.Equal(t, len(filmsPerLanguage[0].Film), limit) @@ -162,8 +234,6 @@ func TestJoinQuerySlice(t *testing.T) { assert.Equal(t, *englishFilms.Film[0].Rating, model.MpaaRating_NC17) - //spew.Dump(filmsPerLanguage) - filmsPerLanguageWithPtrs := []*FilmsPerLanguage{} err = query.Query(db, &filmsPerLanguageWithPtrs) @@ -187,8 +257,6 @@ func TestJoinQuerySliceWithPtrs(t *testing.T) { filmsPerLanguageWithPtrs := []*FilmsPerLanguage{} err := query.Query(db, &filmsPerLanguageWithPtrs) - //spew.Dump(filmsPerLanguageWithPtrs) - assert.NilError(t, err) assert.Equal(t, len(filmsPerLanguageWithPtrs), 1) assert.Equal(t, len(*filmsPerLanguageWithPtrs[0].Film), int(limit)) @@ -257,24 +325,42 @@ func TestSelectOrderByAscDesc(t *testing.T) { } func TestSelectFullJoin(t *testing.T) { + expectedSql := ` +SELECT customer.customer_id AS "customer.customer_id", + customer.store_id AS "customer.store_id", + customer.first_name AS "customer.first_name", + customer.last_name AS "customer.last_name", + customer.email AS "customer.email", + customer.address_id AS "customer.address_id", + customer.activebool AS "customer.activebool", + customer.create_date AS "customer.create_date", + customer.last_update AS "customer.last_update", + customer.active AS "customer.active", + address.address_id AS "address.address_id", + address.address AS "address.address", + address.address2 AS "address.address2", + address.district AS "address.district", + address.city_id AS "address.city_id", + address.postal_code AS "address.postal_code", + address.phone AS "address.phone", + address.last_update AS "address.last_update" +FROM dvds.customer + FULL JOIN dvds.address ON customer.address_id = address.address_id +ORDER BY customer.customer_id ASC; +` query := Customer. FULL_JOIN(Address, Customer.AddressID.Eq(Address.AddressID)). SELECT(Customer.AllColumns, Address.AllColumns). ORDER_BY(Customer.CustomerID.ASC()) - queryStr, args, err := query.Sql() - - assert.NilError(t, err) - - assert.Equal(t, queryStr, `SELECT customer.customer_id AS "customer.customer_id", customer.store_id AS "customer.store_id", customer.first_name AS "customer.first_name", customer.last_name AS "customer.last_name", customer.email AS "customer.email", customer.address_id AS "customer.address_id", customer.activebool AS "customer.activebool", customer.create_date AS "customer.create_date", customer.last_update AS "customer.last_update", customer.active AS "customer.active", address.address_id AS "address.address_id", address.address AS "address.address", address.address2 AS "address.address2", address.district AS "address.district", address.city_id AS "address.city_id", address.postal_code AS "address.postal_code", address.phone AS "address.phone", address.last_update AS "address.last_update" FROM dvds.customer FULL JOIN dvds.address ON customer.address_id = address.address_id ORDER BY customer.customer_id ASC`) - assert.Equal(t, len(args), 0) + assertQuery(t, query, expectedSql) allCustomersAndAddress := []struct { Address *model.Address Customer *model.Customer }{} - err = query.Query(db, &allCustomersAndAddress) + err := query.Query(db, &allCustomersAndAddress) assert.NilError(t, err) assert.Equal(t, len(allCustomersAndAddress), 603) @@ -290,21 +376,41 @@ func TestSelectFullJoin(t *testing.T) { } func TestSelectFullCrossJoin(t *testing.T) { + expectedSql := ` +SELECT customer.customer_id AS "customer.customer_id", + customer.store_id AS "customer.store_id", + customer.first_name AS "customer.first_name", + customer.last_name AS "customer.last_name", + customer.email AS "customer.email", + customer.address_id AS "customer.address_id", + customer.activebool AS "customer.activebool", + customer.create_date AS "customer.create_date", + customer.last_update AS "customer.last_update", + customer.active AS "customer.active", + address.address_id AS "address.address_id", + address.address AS "address.address", + address.address2 AS "address.address2", + address.district AS "address.district", + address.city_id AS "address.city_id", + address.postal_code AS "address.postal_code", + address.phone AS "address.phone", + address.last_update AS "address.last_update" +FROM dvds.customer + CROSS JOIN dvds.address +ORDER BY customer.customer_id ASC +LIMIT 1000; +` query := Customer. CROSS_JOIN(Address). SELECT(Customer.AllColumns, Address.AllColumns). ORDER_BY(Customer.CustomerID.ASC()). LIMIT(1000) - queryStr, args, err := query.Sql() - - assert.NilError(t, err) - assert.Equal(t, queryStr, `SELECT customer.customer_id AS "customer.customer_id", customer.store_id AS "customer.store_id", customer.first_name AS "customer.first_name", customer.last_name AS "customer.last_name", customer.email AS "customer.email", customer.address_id AS "customer.address_id", customer.activebool AS "customer.activebool", customer.create_date AS "customer.create_date", customer.last_update AS "customer.last_update", customer.active AS "customer.active", address.address_id AS "address.address_id", address.address AS "address.address", address.address2 AS "address.address2", address.district AS "address.district", address.city_id AS "address.city_id", address.postal_code AS "address.postal_code", address.phone AS "address.phone", address.last_update AS "address.last_update" FROM dvds.customer CROSS JOIN dvds.address ORDER BY customer.customer_id ASC LIMIT $1`) - assert.Equal(t, len(args), 1) + assertQuery(t, query, expectedSql, int64(1000)) customerAddresCrosJoined := []model.Customer{} - err = query.Query(db, &customerAddresCrosJoined) + err := query.Query(db, &customerAddresCrosJoined) assert.Equal(t, len(customerAddresCrosJoined), 1000) @@ -312,23 +418,49 @@ func TestSelectFullCrossJoin(t *testing.T) { } func TestSelectSelfJoin(t *testing.T) { - + expectedSql := ` +SELECT f1.film_id AS "f1.film_id", + f1.title AS "f1.title", + f1.description AS "f1.description", + f1.release_year AS "f1.release_year", + f1.language_id AS "f1.language_id", + f1.rental_duration AS "f1.rental_duration", + f1.rental_rate AS "f1.rental_rate", + f1.length AS "f1.length", + f1.replacement_cost AS "f1.replacement_cost", + f1.rating AS "f1.rating", + f1.last_update AS "f1.last_update", + f1.special_features AS "f1.special_features", + f1.fulltext AS "f1.fulltext", + f2.film_id AS "f2.film_id", + f2.title AS "f2.title", + f2.description AS "f2.description", + f2.release_year AS "f2.release_year", + f2.language_id AS "f2.language_id", + f2.rental_duration AS "f2.rental_duration", + f2.rental_rate AS "f2.rental_rate", + f2.length AS "f2.length", + f2.replacement_cost AS "f2.replacement_cost", + f2.rating AS "f2.rating", + f2.last_update AS "f2.last_update", + f2.special_features AS "f2.special_features", + f2.fulltext AS "f2.fulltext" +FROM dvds.film AS f1 + JOIN dvds.film AS f2 ON (f1.film_id != f2.film_id AND f1.length = f2.length) +ORDER BY f1.film_id ASC +LIMIT 100; +` f1 := Film.AS("f1") - //spew.Dump(f1) f2 := Film.AS("f2") query := f1. INNER_JOIN(f2, f1.FilmID.NotEq(f2.FilmID).AND(f1.Length.Eq(f2.Length))). SELECT(f1.AllColumns, f2.AllColumns). - ORDER_BY(f1.FilmID.ASC()) + ORDER_BY(f1.FilmID.ASC()). + LIMIT(100) - queryStr, args, err := query.Sql() - assert.Equal(t, len(args), 0) - - assert.NilError(t, err) - - fmt.Println(queryStr) + assertQuery(t, query, expectedSql, int64(100)) type F1 model.Film type F2 model.Film @@ -338,25 +470,26 @@ func TestSelectSelfJoin(t *testing.T) { F2 F2 }{} - err = query.Query(db, &theSameLengthFilms) + err := query.Query(db, &theSameLengthFilms) assert.NilError(t, err) - //spew.Dump(theSameLengthFilms[0]) - - assert.Equal(t, len(theSameLengthFilms), 6972) + assert.Equal(t, len(theSameLengthFilms), 100) } func TestSelectAliasColumn(t *testing.T) { + expectedSql := ` +SELECT f1.title AS "thesame_length_films.title1", + f2.title AS "thesame_length_films.title2", + f1.length AS "thesame_length_films.length" +FROM dvds.film AS f1 + JOIN dvds.film AS f2 ON (f1.film_id != f2.film_id AND f1.length = f2.length) +ORDER BY f1.length ASC, f1.title ASC, f2.title ASC +LIMIT 1000; +` f1 := Film.AS("f1") f2 := Film.AS("f2") - type thesameLengthFilms struct { - Title1 string - Title2 string - Length int16 - } - query := f1. INNER_JOIN(f2, f1.FilmID.NotEq(f2.FilmID).AND(f1.Length.Eq(f2.Length))). SELECT(f1.Title.AS("thesame_length_films.title1"), @@ -365,15 +498,16 @@ func TestSelectAliasColumn(t *testing.T) { ORDER_BY(f1.Length.ASC(), f1.Title.ASC(), f2.Title.ASC()). LIMIT(1000) - queryStr, args, err := query.Sql() - - assert.NilError(t, err) - assert.Equal(t, len(args), 1) - fmt.Println(queryStr) + assertQuery(t, query, expectedSql, int64(1000)) + type thesameLengthFilms struct { + Title1 string + Title2 string + Length int16 + } films := []thesameLengthFilms{} - err = query.Query(db, &films) + err := query.Query(db, &films) assert.NilError(t, err) @@ -401,24 +535,41 @@ type staff struct { func TestSelectSelfReferenceType(t *testing.T) { + expectedSql := ` +SELECT DISTINCT staff.staff_id AS "staff.staff_id", + staff.first_name AS "staff.first_name", + staff.last_name AS "staff.last_name", + address.address_id AS "address.address_id", + address.address AS "address.address", + address.address2 AS "address.address2", + address.district AS "address.district", + address.city_id AS "address.city_id", + address.postal_code AS "address.postal_code", + address.phone AS "address.phone", + address.last_update AS "address.last_update", + manager.staff_id AS "manager.staff_id", + manager.first_name AS "manager.first_name" +FROM dvds.staff + JOIN dvds.address ON staff.address_id = address.address_id + JOIN dvds.staff AS manager ON staff.staff_id = manager.staff_id; +` manager := Staff.AS("manager") query := Staff. INNER_JOIN(Address, Staff.AddressID.Eq(Address.AddressID)). INNER_JOIN(manager, Staff.StaffID.Eq(manager.StaffID)). - SELECT(Staff.StaffID, Staff.FirstName, Staff.LastName, Address.AllColumns, manager.StaffID, manager.FirstName) + SELECT(Staff.StaffID, Staff.FirstName, Staff.LastName, Address.AllColumns, manager.StaffID, manager.FirstName). + DISTINCT() - queryStr, args, err := query.Sql() - assert.NilError(t, err) - fmt.Println(queryStr) - assert.Equal(t, len(args), 0) + assertQuery(t, query, expectedSql) staffs := []staff{} - err = query.Query(db, &staffs) + err := query.Query(db, &staffs) assert.NilError(t, err) + fmt.Println(query.DebugSql()) //spew.Dump(staffs) } @@ -437,13 +588,34 @@ func TestSubQuery(t *testing.T) { // //fmt.Println(queryStr) // - //avrgCustomer := sqlbuilder.NumExp(Customer.SELECT(Customer.LastName).LIMIT(1)) + //avrgCustomer := NumExp(Customer.SELECT(Customer.LastName).LIMIT(1)) // //Customer. // INNER_JOIN(selectStmtTable, Customer.LastName.Eq(selectStmtTable.RefStringColumn(Actor.FirstName))). // SELECT(Customer.AllColumns, selectStmtTable.RefIntColumnName("first_name")). // WHERE(Actor.LastName.Neq(avrgCustomer)) + expectedQuery := ` +SELECT actor.actor_id AS "actor.actor_id", + actor.first_name AS "actor.first_name", + actor.last_name AS "actor.last_name", + actor.last_update AS "actor.last_update", + film_actor.actor_id AS "film_actor.actor_id", + film_actor.film_id AS "film_actor.film_id", + film_actor.last_update AS "film_actor.last_update", + films."film.title" AS "film.title", + films."film.rating" AS "film.rating" +FROM dvds.actor + JOIN dvds.film_actor ON actor.actor_id = film_actor.film_id + JOIN ( + SELECT film.film_id AS "film.film_id", + film.title AS "film.title", + film.rating AS "film.rating" + FROM dvds.film + WHERE film.rating = 'R' + ) AS films ON film_actor.film_id = films."film.film_id"; +` + rFilmsOnly := Film.SELECT(Film.FilmID, Film.Title, Film.Rating). WHERE(Film.Rating.EqString("R")). AsTable("films") @@ -457,42 +629,70 @@ func TestSubQuery(t *testing.T) { rFilmsOnly.RefStringColumn(Film.Rating).AS("film.rating"), ) - queryStr, args, err := query.Sql() + fmt.Println(query.Sql()) + + assertQuery(t, query, expectedQuery, "R") + + dest := []model.Actor{} + + err := query.Query(db, &dest) assert.NilError(t, err) - assert.Equal(t, len(args), 1) - fmt.Println(queryStr) - } func TestSelectFunctions(t *testing.T) { - query := Film.SELECT(sqlbuilder.MAX(Film.RentalRate).AS("max_film_rate")) + expectedQuery := ` +SELECT MAX(film.rental_rate) AS "max_film_rate" +FROM dvds.film; +` + query := Film.SELECT(MAX(Film.RentalRate).AS("max_film_rate")) - str, args, err := query.Sql() + assertQuery(t, query, expectedQuery) + + ret := struct { + MaxFilmRate float64 + }{} + + err := query.Query(db, &ret) assert.NilError(t, err) - - assert.Equal(t, str, `SELECT MAX(film.rental_rate) AS "max_film_rate" FROM dvds.film`) - assert.Equal(t, len(args), 0) - fmt.Println(str) + assert.Equal(t, ret.MaxFilmRate, 4.99) } func TestSelectQueryScalar(t *testing.T) { + expectedSql := ` +SELECT film.film_id AS "film.film_id", + film.title AS "film.title", + film.description AS "film.description", + film.release_year AS "film.release_year", + film.language_id AS "film.language_id", + film.rental_duration AS "film.rental_duration", + film.rental_rate AS "film.rental_rate", + film.length AS "film.length", + film.replacement_cost AS "film.replacement_cost", + film.rating AS "film.rating", + film.last_update AS "film.last_update", + film.special_features AS "film.special_features", + film.fulltext AS "film.fulltext" +FROM dvds.film +WHERE film.rental_rate = ( + SELECT MAX(film.rental_rate) + FROM dvds.film + ) +ORDER BY film.film_id ASC; +` - maxFilmRentalRate := sqlbuilder.NumExp(Film.SELECT(sqlbuilder.MAX(Film.RentalRate))) + maxFilmRentalRate := NumExp(Film.SELECT(MAX(Film.RentalRate))) query := Film.SELECT(Film.AllColumns). WHERE(Film.RentalRate.Eq(maxFilmRentalRate)). ORDER_BY(Film.FilmID.ASC()) - queryStr, args, err := query.Sql() - - assert.NilError(t, err) - assert.Equal(t, len(args), 0) - fmt.Println(queryStr) + fmt.Println(query.Sql()) + assertQuery(t, query, expectedSql) maxRentalRateFilms := []model.Film{} - err = query.Query(db, &maxRentalRateFilms) + err := query.Query(db, &maxRentalRateFilms) assert.NilError(t, err) @@ -515,26 +715,27 @@ func TestSelectQueryScalar(t *testing.T) { SpecialFeatures: stringPtr("{Trailers,\"Deleted Scenes\"}"), Fulltext: "'ace':1 'administr':9 'ancient':19 'astound':4 'car':17 'china':20 'databas':8 'epistl':5 'explor':12 'find':15 'goldfing':2 'must':14", }) - - //spew.Dump(maxRentalRateFilms[0]) } func TestSelectGroupByHaving(t *testing.T) { + expectedSql := ` +SELECT payment.customer_id AS "customer_payment_sum.customer_id", + SUM(payment.amount) AS "customer_payment_sum.amount_sum" +FROM dvds.payment +GROUP BY payment.customer_id +HAVING SUM(payment.amount) > 100 +ORDER BY SUM(payment.amount) ASC; +` customersPaymentQuery := Payment. SELECT( Payment.CustomerID.AS("customer_payment_sum.customer_id"), - sqlbuilder.SUM(Payment.Amount).AS("customer_payment_sum.amount_sum"), + SUM(Payment.Amount).AS("customer_payment_sum.amount_sum"), ). GROUP_BY(Payment.CustomerID). - ORDER_BY(sqlbuilder.SUM(Payment.Amount).ASC()). - HAVING(sqlbuilder.SUM(Payment.Amount).Gt(sqlbuilder.NewNumericLiteral(100))) + ORDER_BY(SUM(Payment.Amount).ASC()). + HAVING(SUM(Payment.Amount).Gt(NewNumericLiteral(100))) - queryStr, args, err := customersPaymentQuery.Sql() - - assert.NilError(t, err) - fmt.Println(queryStr) - assert.Equal(t, len(args), 1) - assert.Equal(t, queryStr, `SELECT payment.customer_id AS "customer_payment_sum.customer_id", SUM(payment.amount) AS "customer_payment_sum.amount_sum" FROM dvds.payment GROUP BY payment.customer_id HAVING SUM(payment.amount) > $1 ORDER BY SUM(payment.amount) ASC`) + assertQuery(t, customersPaymentQuery, expectedSql, 100) type CustomerPaymentSum struct { CustomerID int16 @@ -543,7 +744,7 @@ func TestSelectGroupByHaving(t *testing.T) { customerPaymentSum := []CustomerPaymentSum{} - err = customersPaymentQuery.Query(db, &customerPaymentSum) + err := customersPaymentQuery.Query(db, &customerPaymentSum) assert.NilError(t, err) @@ -555,16 +756,32 @@ func TestSelectGroupByHaving(t *testing.T) { } func TestSelectGroupBy2(t *testing.T) { - type CustomerWithAmounts struct { - Customer *model.Customer - AmountSum float64 - } - customersWithAmounts := []CustomerWithAmounts{} + expectedSql := ` +SELECT customer.customer_id AS "customer.customer_id", + customer.store_id AS "customer.store_id", + customer.first_name AS "customer.first_name", + customer.last_name AS "customer.last_name", + customer.email AS "customer.email", + customer.address_id AS "customer.address_id", + customer.activebool AS "customer.activebool", + customer.create_date AS "customer.create_date", + customer.last_update AS "customer.last_update", + customer.active AS "customer.active", + customer_payment_sum.amount_sum AS "customer_with_amounts.amount_sum" +FROM dvds.customer + JOIN ( + SELECT payment.customer_id AS "payment.customer_id", + SUM(payment.amount) AS "amount_sum" + FROM dvds.payment + GROUP BY payment.customer_id + ) AS customer_payment_sum ON customer.customer_id = customer_payment_sum."payment.customer_id" +ORDER BY customer_payment_sum.amount_sum ASC; +` customersPaymentSubQuery := Payment. SELECT( Payment.CustomerID, - sqlbuilder.SUM(Payment.Amount).AS("amount_sum"), + SUM(Payment.Amount).AS("amount_sum"), ). GROUP_BY(Payment.CustomerID) @@ -576,15 +793,16 @@ func TestSelectGroupBy2(t *testing.T) { SELECT(Customer.AllColumns, amountSumColumn.AS("customer_with_amounts.amount_sum")). ORDER_BY(amountSumColumn.ASC()) - queryStr, args, err := query.Sql() - assert.NilError(t, err) - fmt.Println(queryStr) - assert.Equal(t, len(args), 0) + assertQuery(t, query, expectedSql) - err = query.Query(db, &customersWithAmounts) - assert.NilError(t, err) - //spew.Dump(customersWithAmounts) + type CustomerWithAmounts struct { + Customer *model.Customer + AmountSum float64 + } + customersWithAmounts := []CustomerWithAmounts{} + err := query.Query(db, &customersWithAmounts) + assert.NilError(t, err) assert.Equal(t, len(customersWithAmounts), 599) assert.DeepEqual(t, customersWithAmounts[0].Customer, &model.Customer{ @@ -603,19 +821,28 @@ func TestSelectGroupBy2(t *testing.T) { } func TestSelectTimeColumns(t *testing.T) { + + expectedSql := ` +SELECT payment.payment_id AS "payment.payment_id", + payment.customer_id AS "payment.customer_id", + payment.staff_id AS "payment.staff_id", + payment.rental_id AS "payment.rental_id", + payment.amount AS "payment.amount", + payment.payment_date AS "payment.payment_date" +FROM dvds.payment +WHERE payment.payment_date <= '2007-02-14 22:16:01' +ORDER BY payment.payment_date ASC; +` + query := Payment.SELECT(Payment.AllColumns). WHERE(Payment.PaymentDate.LtEqL("2007-02-14 22:16:01")). ORDER_BY(Payment.PaymentDate.ASC()) - queryStr, args, err := query.Sql() - - assert.NilError(t, err) - assert.Equal(t, len(args), 1) - fmt.Println(queryStr) + assertQuery(t, query, expectedSql, "2007-02-14 22:16:01") payments := []model.Payment{} - err = query.Query(db, &payments) + err := query.Query(db, &payments) assert.NilError(t, err) @@ -630,7 +857,27 @@ func TestSelectTimeColumns(t *testing.T) { } func TestUnion(t *testing.T) { - query := sqlbuilder.UNION( + expectedQuery := ` +( + ( + SELECT payment.payment_id AS "payment.payment_id", + payment.amount AS "payment.amount" + FROM dvds.payment + WHERE payment.amount <= 100 + ) + UNION ALL + ( + SELECT payment.payment_id AS "payment.payment_id", + payment.amount AS "payment.amount" + FROM dvds.payment + WHERE payment.amount >= 200 + ) +) +ORDER BY "payment.payment_id" ASC, "payment.amount" DESC +LIMIT 10 +OFFSET 20; +` + query := UNION_ALL( Payment. SELECT(Payment.PaymentID.AS("payment.payment_id"), Payment.Amount). WHERE(Payment.Amount.LtEqL(100)), @@ -638,19 +885,18 @@ func TestUnion(t *testing.T) { SELECT(Payment.PaymentID, Payment.Amount). WHERE(Payment.Amount.GtEqL(200)), ). - ORDER_BY(sqlbuilder.RefColumn("payment.payment_id").ASC(), Payment.Amount.DESC()). - LIMIT(10).OFFSET(20) + ORDER_BY(RefColumn("payment.payment_id").ASC(), Payment.Amount.DESC()). + LIMIT(10). + OFFSET(20) - queryStr, args, err := query.Sql() + queryStr, _, _ := query.Sql() - assert.NilError(t, err) - - fmt.Println(queryStr) - fmt.Println(args) + fmt.Println("-" + queryStr + "-") + assertQuery(t, query, expectedQuery, int(100), int(200), int64(10), int64(20)) dest := []model.Payment{} - err = query.Query(db, &dest) + err := query.Query(db, &dest) assert.NilError(t, err) assert.Equal(t, len(dest), 10) @@ -669,26 +915,29 @@ func TestUnion(t *testing.T) { } func TestSelectWithCase(t *testing.T) { + expectedQuery := ` +SELECT (CASE payment.staff_id WHEN 1 THEN 'ONE' WHEN 2 THEN 'TWO' WHEN 3 THEN 'THREE' ELSE 'OTHER' END) AS "staff_id_num" +FROM dvds.payment +ORDER BY payment.payment_id ASC +LIMIT 20; +` query := Payment.SELECT( - sqlbuilder.CASE(Payment.StaffID). - WHEN(sqlbuilder.IntLiteral(1)).THEN(sqlbuilder.Literal("ONE")). - WHEN(sqlbuilder.IntLiteral(2)).THEN(sqlbuilder.Literal("TWO")). - WHEN(sqlbuilder.IntLiteral(3)).THEN(sqlbuilder.Literal("THREE")). - ELSE(sqlbuilder.Literal("OTHER")).AS("staff_id_num"), + CASE(Payment.StaffID). + WHEN(IntLiteral(1)).THEN(Literal("ONE")). + WHEN(IntLiteral(2)).THEN(Literal("TWO")). + WHEN(IntLiteral(3)).THEN(Literal("THREE")). + ELSE(Literal("OTHER")).AS("staff_id_num"), ). ORDER_BY(Payment.PaymentID.ASC()). LIMIT(20) - queryStr, _, err := query.Sql() - - assert.NilError(t, err) - assert.Equal(t, queryStr, `SELECT (CASE payment.staff_id WHEN $1 THEN $2 WHEN $3 THEN $4 WHEN $5 THEN $6 ELSE $7 END) AS "staff_id_num" FROM dvds.payment ORDER BY payment.payment_id ASC LIMIT $8`) + assertQuery(t, query, expectedQuery, 1, "ONE", 2, "TWO", 3, "THREE", "OTHER", int64(20)) dest := []struct { StaffIdNum string }{} - err = query.Query(db, &dest) + err := query.Query(db, &dest) assert.NilError(t, err) assert.Equal(t, len(dest), 20) @@ -697,84 +946,19 @@ func TestSelectWithCase(t *testing.T) { } func TestLockTable(t *testing.T) { - query := Address.LOCK().IN(sqlbuilder.LOCK_EXCLUSIVE).NOWAIT() + expectedSql := ` +LOCK TABLE dvds.address IN EXCLUSIVE MODE NOWAIT; +` + query := Address.LOCK().IN(LOCK_EXCLUSIVE).NOWAIT() - queryStr, _, err := query.Sql() + querySql, _, _ := query.Sql() + fmt.Println("-" + querySql + "-") - assert.NilError(t, err) - assert.Equal(t, queryStr, `LOCK TABLE dvds.address IN EXCLUSIVE MODE NOWAIT`) + assertQuery(t, query, expectedSql) tx, _ := db.Begin() - _, err = query.Execute(tx) + _, err := query.Execute(tx) assert.NilError(t, err) } - -func int16Ptr(i int16) *int16 { - return &i -} - -func int32Ptr(i int32) *int32 { - return &i -} - -func stringPtr(s string) *string { - return &s -} - -func timeWithoutTimeZone(t string, precision int) *time.Time { - - precisionStr := "" - - if precision > 0 { - precisionStr = "." + strings.Repeat("9", precision) - } - - time, err := time.Parse("2006-01-02 15:04:05"+precisionStr+" +0000", t+" +0000") - - if err != nil { - panic(err) - } - - return &time -} - -var customer0 = model.Customer{ - CustomerID: 1, - StoreID: 1, - FirstName: "Mary", - LastName: "Smith", - Email: stringPtr("mary.smith@sakilacustomer.org"), - Address: nil, - Activebool: true, - CreateDate: *timeWithoutTimeZone("2006-02-14 00:00:00", 0), - LastUpdate: timeWithoutTimeZone("2013-05-26 14:49:45.738", 3), - Active: int32Ptr(1), -} - -var customer1 = model.Customer{ - CustomerID: 2, - StoreID: 1, - FirstName: "Patricia", - LastName: "Johnson", - Email: stringPtr("patricia.johnson@sakilacustomer.org"), - Address: nil, - Activebool: true, - CreateDate: *timeWithoutTimeZone("2006-02-14 00:00:00", 0), - LastUpdate: timeWithoutTimeZone("2013-05-26 14:49:45.738", 3), - Active: int32Ptr(1), -} - -var lastCustomer = model.Customer{ - CustomerID: 599, - StoreID: 2, - FirstName: "Austin", - LastName: "Cintron", - Email: stringPtr("austin.cintron@sakilacustomer.org"), - Address: nil, - Activebool: true, - CreateDate: *timeWithoutTimeZone("2006-02-14 00:00:00", 0), - LastUpdate: timeWithoutTimeZone("2013-05-26 14:49:45.738", 3), - Active: int32Ptr(1), -} diff --git a/tests/test_util.go b/tests/test_util.go new file mode 100644 index 0000000..2c445ec --- /dev/null +++ b/tests/test_util.go @@ -0,0 +1,89 @@ +package tests + +import ( + "github.com/sub0zero/go-sqlbuilder/sqlbuilder" + "github.com/sub0zero/go-sqlbuilder/tests/.test_files/dvd_rental/dvds/model" + "gotest.tools/assert" + "strings" + "testing" + "time" +) + +func assertQuery(t *testing.T, query sqlbuilder.Statement, expectedQuery string, expectedArgs ...interface{}) { + _, args, err := query.Sql() + assert.NilError(t, err) + //assert.Equal(t, queryStr, expectedQuery) + assert.DeepEqual(t, args, expectedArgs) + + debuqSql, err := query.DebugSql() + assert.NilError(t, err) + assert.Equal(t, debuqSql, expectedQuery, args) +} + +func int16Ptr(i int16) *int16 { + return &i +} + +func int32Ptr(i int32) *int32 { + return &i +} + +func stringPtr(s string) *string { + return &s +} + +func timeWithoutTimeZone(t string, precision int) *time.Time { + + precisionStr := "" + + if precision > 0 { + precisionStr = "." + strings.Repeat("9", precision) + } + + time, err := time.Parse("2006-01-02 15:04:05"+precisionStr+" +0000", t+" +0000") + + if err != nil { + panic(err) + } + + return &time +} + +var customer0 = model.Customer{ + CustomerID: 1, + StoreID: 1, + FirstName: "Mary", + LastName: "Smith", + Email: stringPtr("mary.smith@sakilacustomer.org"), + Address: nil, + Activebool: true, + CreateDate: *timeWithoutTimeZone("2006-02-14 00:00:00", 0), + LastUpdate: timeWithoutTimeZone("2013-05-26 14:49:45.738", 3), + Active: int32Ptr(1), +} + +var customer1 = model.Customer{ + CustomerID: 2, + StoreID: 1, + FirstName: "Patricia", + LastName: "Johnson", + Email: stringPtr("patricia.johnson@sakilacustomer.org"), + Address: nil, + Activebool: true, + CreateDate: *timeWithoutTimeZone("2006-02-14 00:00:00", 0), + LastUpdate: timeWithoutTimeZone("2013-05-26 14:49:45.738", 3), + Active: int32Ptr(1), +} + +var lastCustomer = model.Customer{ + CustomerID: 599, + StoreID: 2, + FirstName: "Austin", + LastName: "Cintron", + Email: stringPtr("austin.cintron@sakilacustomer.org"), + Address: nil, + Activebool: true, + CreateDate: *timeWithoutTimeZone("2006-02-14 00:00:00", 0), + LastUpdate: timeWithoutTimeZone("2013-05-26 14:49:45.738", 3), + Active: int32Ptr(1), +}