Add statements debug sql support.
This commit is contained in:
parent
439c9f1ef9
commit
240ddd65e6
27 changed files with 1013 additions and 426 deletions
|
|
@ -20,7 +20,7 @@ func (a *Alias) serializeForProjection(statement statementType, out *queryData)
|
|||
return err
|
||||
}
|
||||
|
||||
out.writeString(" AS \"" + a.alias + "\"")
|
||||
out.writeString(`AS "` + a.alias + `"`)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package sqlbuilder
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"gotest.tools/assert"
|
||||
"testing"
|
||||
)
|
||||
|
|
@ -118,8 +119,17 @@ func TestExists(t *testing.T) {
|
|||
out := queryData{}
|
||||
err := query.serialize(select_statement, &out)
|
||||
|
||||
fmt.Println(out.buff.String())
|
||||
|
||||
assert.NilError(t, err)
|
||||
assert.Equal(t, out.buff.String(), "EXISTS (SELECT $1 FROM db.table2 WHERE table1.col1 = table2.col3)")
|
||||
|
||||
expectedSql :=
|
||||
`EXISTS (
|
||||
SELECT $1
|
||||
FROM db.table2
|
||||
WHERE table1.col1 = table2.col3
|
||||
)`
|
||||
assert.Equal(t, out.buff.String(), expectedSql)
|
||||
}
|
||||
|
||||
func TestIn(t *testing.T) {
|
||||
|
|
@ -129,7 +139,11 @@ func TestIn(t *testing.T) {
|
|||
err := query.serialize(select_statement, &out)
|
||||
|
||||
assert.NilError(t, err)
|
||||
assert.Equal(t, out.buff.String(), `$1 IN (SELECT table1.col1 AS "table1.col1" FROM db.table1)`)
|
||||
fmt.Println(out.buff.String())
|
||||
assert.Equal(t, out.buff.String(), `$1 IN (
|
||||
SELECT table1.col1 AS "table1.col1"
|
||||
FROM db.table1
|
||||
)`)
|
||||
|
||||
query2 := ROW(Literal(12), table1Col1).IN(table2.SELECT(table2Col3, table3Col1))
|
||||
|
||||
|
|
@ -137,5 +151,10 @@ func TestIn(t *testing.T) {
|
|||
err = query2.serialize(select_statement, &out)
|
||||
|
||||
assert.NilError(t, err)
|
||||
assert.Equal(t, out.buff.String(), `(ROW($1, table1.col1) IN (SELECT table2.col3 AS "table2.col3", table3.col1 AS "table3.col1" FROM db.table2))`)
|
||||
fmt.Println(out.buff.String())
|
||||
assert.Equal(t, out.buff.String(), `(ROW($1, table1.col1) IN (
|
||||
SELECT table2.col3 AS "table2.col3",
|
||||
table3.col1 AS "table3.col1"
|
||||
FROM db.table2
|
||||
))`)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ package sqlbuilder
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
|
|
@ -13,6 +12,9 @@ type clause interface {
|
|||
type queryData struct {
|
||||
buff bytes.Buffer
|
||||
args []interface{}
|
||||
|
||||
lastChar byte
|
||||
ident int
|
||||
}
|
||||
|
||||
type statementType string
|
||||
|
|
@ -26,48 +28,125 @@ const (
|
|||
lock_statement statementType = "LOCK"
|
||||
)
|
||||
|
||||
const defaultIdent = 5
|
||||
|
||||
func (q *queryData) increaseIdent() {
|
||||
q.ident += defaultIdent
|
||||
}
|
||||
|
||||
func (q *queryData) decreaseIdent() {
|
||||
if q.ident < defaultIdent {
|
||||
q.ident = 0
|
||||
}
|
||||
|
||||
q.ident -= defaultIdent
|
||||
}
|
||||
|
||||
func (q *queryData) writeProjection(statement statementType, projections []projection) error {
|
||||
return serializeProjectionList(statement, projections, q)
|
||||
q.increaseIdent()
|
||||
err := serializeProjectionList(statement, projections, q)
|
||||
q.decreaseIdent()
|
||||
return err
|
||||
}
|
||||
|
||||
func (q *queryData) writeFrom(statement statementType, table tableInterface) error {
|
||||
q.nextLine()
|
||||
q.writeString("FROM")
|
||||
|
||||
q.increaseIdent()
|
||||
err := table.serialize(statement, q)
|
||||
q.decreaseIdent()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (q *queryData) writeWhere(statement statementType, where expression) error {
|
||||
q.writeString(" WHERE ")
|
||||
return where.serialize(statement, q)
|
||||
q.nextLine()
|
||||
q.writeString("WHERE")
|
||||
|
||||
q.increaseIdent()
|
||||
err := where.serialize(statement, q)
|
||||
q.decreaseIdent()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (q *queryData) writeGroupBy(statement statementType, groupBy []groupByClause) error {
|
||||
q.writeString(" GROUP BY ")
|
||||
q.nextLine()
|
||||
q.writeString("GROUP BY")
|
||||
|
||||
return serializeGroupByClauseList(statement, groupBy, q)
|
||||
q.increaseIdent()
|
||||
err := serializeGroupByClauseList(statement, groupBy, q)
|
||||
q.decreaseIdent()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (q *queryData) writeOrderBy(statement statementType, orderBy []orderByClause) error {
|
||||
q.writeString(" ORDER BY ")
|
||||
return serializeOrderByClauseList(statement, orderBy, q)
|
||||
q.nextLine()
|
||||
q.writeString("ORDER BY")
|
||||
|
||||
q.increaseIdent()
|
||||
err := serializeOrderByClauseList(statement, orderBy, q)
|
||||
q.decreaseIdent()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (q *queryData) writeHaving(statement statementType, having expression) error {
|
||||
q.writeString(" HAVING ")
|
||||
return having.serialize(statement, q)
|
||||
q.nextLine()
|
||||
q.writeString("HAVING")
|
||||
|
||||
q.increaseIdent()
|
||||
err := having.serialize(statement, q)
|
||||
q.decreaseIdent()
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (q *queryData) nextLine() {
|
||||
q.write([]byte{'\n'})
|
||||
q.write(bytes.Repeat([]byte{' '}, q.ident))
|
||||
}
|
||||
|
||||
func (q *queryData) write(data []byte) {
|
||||
if len(data) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if !isPreSeparator(q.lastChar) && !isPostSeparator(data[0]) && q.buff.Len() > 0 {
|
||||
q.buff.WriteByte(' ')
|
||||
}
|
||||
|
||||
q.buff.Write(data)
|
||||
q.lastChar = data[len(data)-1]
|
||||
}
|
||||
|
||||
func isPreSeparator(b byte) bool {
|
||||
return b == ' ' || b == '.' || b == ',' || b == '(' || b == '\n'
|
||||
}
|
||||
|
||||
func isPostSeparator(b byte) bool {
|
||||
return b == ' ' || b == '.' || b == ',' || b == ')' || b == '\n'
|
||||
}
|
||||
|
||||
func (q *queryData) writeString(str string) {
|
||||
q.buff.WriteString(str)
|
||||
q.write([]byte(str))
|
||||
}
|
||||
|
||||
func (q *queryData) writeByte(b byte) {
|
||||
q.buff.WriteByte(b)
|
||||
q.write([]byte{b})
|
||||
}
|
||||
|
||||
func (q *queryData) finalize() (string, []interface{}) {
|
||||
return q.buff.String() + ";\n", q.args
|
||||
}
|
||||
|
||||
func (q *queryData) insertArgument(arg interface{}) {
|
||||
q.args = append(q.args, arg)
|
||||
argPlaceholder := "$" + strconv.Itoa(len(q.args))
|
||||
|
||||
q.buff.WriteString(argPlaceholder)
|
||||
q.writeString(argPlaceholder)
|
||||
}
|
||||
|
||||
func (q *queryData) reset() {
|
||||
|
|
@ -75,49 +154,49 @@ func (q *queryData) reset() {
|
|||
q.args = []interface{}{}
|
||||
}
|
||||
|
||||
func argToString(value interface{}) (string, error) {
|
||||
func ArgToString(value interface{}) string {
|
||||
switch bindVal := value.(type) {
|
||||
case bool:
|
||||
if bindVal {
|
||||
return "TRUE", nil
|
||||
return "TRUE"
|
||||
} else {
|
||||
return "FALSE", nil
|
||||
return "FALSE"
|
||||
}
|
||||
case int8:
|
||||
return strconv.FormatInt(int64(bindVal), 10), nil
|
||||
return strconv.FormatInt(int64(bindVal), 10)
|
||||
case int:
|
||||
return strconv.FormatInt(int64(bindVal), 10), nil
|
||||
return strconv.FormatInt(int64(bindVal), 10)
|
||||
case int16:
|
||||
return strconv.FormatInt(int64(bindVal), 10), nil
|
||||
return strconv.FormatInt(int64(bindVal), 10)
|
||||
case int32:
|
||||
return strconv.FormatInt(int64(bindVal), 10), nil
|
||||
return strconv.FormatInt(int64(bindVal), 10)
|
||||
case int64:
|
||||
return strconv.FormatInt(int64(bindVal), 10), nil
|
||||
return strconv.FormatInt(int64(bindVal), 10)
|
||||
|
||||
case uint8:
|
||||
return strconv.FormatUint(uint64(bindVal), 10), nil
|
||||
return strconv.FormatUint(uint64(bindVal), 10)
|
||||
case uint:
|
||||
return strconv.FormatUint(uint64(bindVal), 10), nil
|
||||
return strconv.FormatUint(uint64(bindVal), 10)
|
||||
case uint16:
|
||||
return strconv.FormatUint(uint64(bindVal), 10), nil
|
||||
return strconv.FormatUint(uint64(bindVal), 10)
|
||||
case uint32:
|
||||
return strconv.FormatUint(uint64(bindVal), 10), nil
|
||||
return strconv.FormatUint(uint64(bindVal), 10)
|
||||
case uint64:
|
||||
return strconv.FormatUint(uint64(bindVal), 10), nil
|
||||
return strconv.FormatUint(uint64(bindVal), 10)
|
||||
|
||||
case float32:
|
||||
return strconv.FormatFloat(float64(bindVal), 'f', -1, 64), nil
|
||||
return strconv.FormatFloat(float64(bindVal), 'f', -1, 64)
|
||||
case float64:
|
||||
return strconv.FormatFloat(float64(bindVal), 'f', -1, 64), nil
|
||||
return strconv.FormatFloat(float64(bindVal), 'f', -1, 64)
|
||||
|
||||
case string:
|
||||
return bindVal, nil
|
||||
return `'` + bindVal + `'`
|
||||
case []byte:
|
||||
return string(bindVal), nil
|
||||
return string(bindVal)
|
||||
//TODO: implement
|
||||
//case time.Time:
|
||||
// return bindVal.String())
|
||||
default:
|
||||
return "", errors.New("Unsupported literal type. ")
|
||||
return "[Unknown type]"
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -79,17 +79,16 @@ func (c *baseColumn) DefaultAlias() projection {
|
|||
|
||||
func (c *baseColumn) serializeAsOrderBy(statement statementType, out *queryData) error {
|
||||
if statement == set_statement {
|
||||
// set statement (UNION, EXCEPT ...) can reference only select projections in order by clause
|
||||
out.writeString(`"`)
|
||||
// set Statement (UNION, EXCEPT ...) can reference only select projections in order by clause
|
||||
columnRef := ""
|
||||
|
||||
if c.tableName != "" {
|
||||
out.writeString(c.tableName)
|
||||
out.writeString(".")
|
||||
columnRef += c.tableName + "."
|
||||
}
|
||||
|
||||
out.writeString(c.name)
|
||||
columnRef += c.name
|
||||
|
||||
out.writeString(`"`)
|
||||
out.writeString(`"` + columnRef + `"`)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
@ -98,22 +97,26 @@ func (c *baseColumn) serializeAsOrderBy(statement statementType, out *queryData)
|
|||
}
|
||||
|
||||
func (c baseColumn) serialize(statement statementType, out *queryData) error {
|
||||
|
||||
columnRef := ""
|
||||
|
||||
if c.tableName != "" {
|
||||
out.writeString(c.tableName)
|
||||
out.writeString(".")
|
||||
columnRef += c.tableName + "."
|
||||
}
|
||||
|
||||
wrapColumnName := strings.Contains(c.name, ".")
|
||||
|
||||
if wrapColumnName {
|
||||
out.writeString(`"`)
|
||||
columnRef += `"`
|
||||
}
|
||||
|
||||
out.writeString(c.name)
|
||||
columnRef += c.name
|
||||
|
||||
if wrapColumnName {
|
||||
out.writeString(`"`)
|
||||
columnRef += `"`
|
||||
}
|
||||
|
||||
out.writeString(columnRef)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import (
|
|||
)
|
||||
|
||||
type deleteStatement interface {
|
||||
statement
|
||||
Statement
|
||||
|
||||
WHERE(expression boolExpression) deleteStatement
|
||||
}
|
||||
|
|
@ -28,28 +28,44 @@ func (d *deleteStatementImpl) WHERE(expression boolExpression) deleteStatement {
|
|||
return d
|
||||
}
|
||||
|
||||
func (d *deleteStatementImpl) Sql() (query string, args []interface{}, err error) {
|
||||
queryData := &queryData{}
|
||||
|
||||
queryData.writeString("DELETE FROM ")
|
||||
func (d *deleteStatementImpl) serializeImpl(out *queryData) error {
|
||||
out.nextLine()
|
||||
out.writeString("DELETE FROM")
|
||||
|
||||
if d.table == nil {
|
||||
return "", nil, errors.New("nil tableName.")
|
||||
return errors.New("nil tableName.")
|
||||
}
|
||||
|
||||
if err = d.table.serialize(delete_statement, queryData); err != nil {
|
||||
return
|
||||
if err := d.table.serialize(delete_statement, out); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if d.where == nil {
|
||||
return "", nil, errors.New("Deleting without a WHERE clause.")
|
||||
return errors.New("Deleting without a WHERE clause.")
|
||||
}
|
||||
|
||||
if err = queryData.writeWhere(delete_statement, d.where); err != nil {
|
||||
if err := out.writeWhere(delete_statement, d.where); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *deleteStatementImpl) Sql() (query string, args []interface{}, err error) {
|
||||
queryData := &queryData{}
|
||||
|
||||
err = d.serializeImpl(queryData)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return queryData.buff.String() + ";", queryData.args, nil
|
||||
query, args = queryData.finalize()
|
||||
return
|
||||
}
|
||||
|
||||
func (d *deleteStatementImpl) DebugSql() (query string, err error) {
|
||||
return DebugSql(d)
|
||||
}
|
||||
|
||||
func (u *deleteStatementImpl) Query(db types.Db, destination interface{}) error {
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package sqlbuilder
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"gotest.tools/assert"
|
||||
"testing"
|
||||
)
|
||||
|
|
@ -14,5 +15,10 @@ func TestDeleteWithWhere(t *testing.T) {
|
|||
sql, _, err := table1.DELETE().WHERE(table1Col1.EqL(1)).Sql()
|
||||
assert.NilError(t, err)
|
||||
|
||||
assert.Equal(t, sql, "DELETE FROM db.table1 WHERE table1.col1 = $1;")
|
||||
fmt.Println(sql)
|
||||
expectedSql := `
|
||||
DELETE FROM db.table1
|
||||
WHERE table1.col1 = $1;
|
||||
`
|
||||
assert.Equal(t, sql, expectedSql)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -543,8 +543,8 @@ func isDbBaseType(objType reflect.Type) bool {
|
|||
typeStr := objType.String()
|
||||
|
||||
switch typeStr {
|
||||
case "string", "int32", "int16", "float32", "float64", "time.Time", "bool", "[]byte", "[]uint8",
|
||||
"*string", "*int32", "*int16", "*float32", "*float64", "*time.Time", "*bool", "*[]byte", "*[]uint8":
|
||||
case "string", "int", "int32", "int16", "float32", "float64", "time.Time", "bool", "[]byte", "[]uint8",
|
||||
"*string", "*int", "*int32", "*int16", "*float32", "*float64", "*time.Time", "*bool", "*[]byte", "*[]uint8":
|
||||
return true
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ func (e *expressionInterfaceImpl) IN(subQuery selectStatement) boolExpression {
|
|||
}
|
||||
|
||||
func (e *expressionInterfaceImpl) NOT_IN(subQuery selectStatement) boolExpression {
|
||||
return newBinaryBoolExpression(e.parent, subQuery, "NOT_IN")
|
||||
return newBinaryBoolExpression(e.parent, subQuery, "NOT IN")
|
||||
}
|
||||
|
||||
func (e *expressionInterfaceImpl) AS(alias string) projection {
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ type expressionTable interface {
|
|||
|
||||
type expressionTableImpl struct {
|
||||
statement expression
|
||||
columns []column
|
||||
alias string
|
||||
}
|
||||
|
||||
|
|
@ -24,7 +23,7 @@ func (s *expressionTableImpl) TableName() string {
|
|||
}
|
||||
|
||||
func (s *expressionTableImpl) Columns() []column {
|
||||
return s.columns
|
||||
return []column{}
|
||||
}
|
||||
|
||||
func (s *expressionTableImpl) RefIntColumnName(name string) *IntegerColumn {
|
||||
|
|
@ -42,20 +41,20 @@ func (s *expressionTableImpl) RefIntColumn(column column) *IntegerColumn {
|
|||
}
|
||||
|
||||
func (s *expressionTableImpl) RefStringColumn(column column) *StringColumn {
|
||||
strColumn := NewStringColumn(column.Name(), NotNullable)
|
||||
strColumn.setTableName(column.TableName())
|
||||
strColumn := NewStringColumn(column.TableName()+"."+column.Name(), NotNullable)
|
||||
strColumn.setTableName(s.alias)
|
||||
return strColumn
|
||||
}
|
||||
|
||||
func (s *expressionTableImpl) serialize(statement statementType, out *queryData) error {
|
||||
out.writeString("( ")
|
||||
//out.writeString("( ")
|
||||
err := s.statement.serialize(statement, out)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
out.writeString(" ) AS ")
|
||||
out.writeString("AS")
|
||||
out.writeString(s.alias)
|
||||
|
||||
return nil
|
||||
|
|
|
|||
|
|
@ -29,8 +29,8 @@ func newFunc(name string, expressions []expression, parent expression) *funcExpr
|
|||
}
|
||||
|
||||
func (f *funcExpressionImpl) serialize(statement statementType, out *queryData) error {
|
||||
out.writeString(f.name)
|
||||
out.writeString("(")
|
||||
out.writeString(f.name + "(")
|
||||
|
||||
err := serializeExpressionList(statement, f.expression, ", ", out)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -111,7 +111,6 @@ func (c *caseExpression) serialize(statement statementType, out *queryData) erro
|
|||
out.writeString("(CASE")
|
||||
|
||||
if c.expression != nil {
|
||||
out.writeString(" ")
|
||||
err := c.expression.serialize(statement, out)
|
||||
|
||||
if err != nil {
|
||||
|
|
@ -120,7 +119,7 @@ func (c *caseExpression) serialize(statement statementType, out *queryData) erro
|
|||
}
|
||||
|
||||
if len(c.when) == 0 || len(c.then) == 0 {
|
||||
return errors.New("Invalid case statement. There should be at least one when/then expression pair. ")
|
||||
return errors.New("Invalid case Statement. There should be at least one when/then expression pair. ")
|
||||
}
|
||||
|
||||
if len(c.when) != len(c.then) {
|
||||
|
|
@ -128,14 +127,14 @@ func (c *caseExpression) serialize(statement statementType, out *queryData) erro
|
|||
}
|
||||
|
||||
for i, when := range c.when {
|
||||
out.writeString(" WHEN ")
|
||||
out.writeString("WHEN")
|
||||
err := when.serialize(statement, out)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
out.writeString(" THEN ")
|
||||
out.writeString("THEN")
|
||||
err = c.then[i].serialize(statement, out)
|
||||
|
||||
if err != nil {
|
||||
|
|
@ -144,7 +143,7 @@ func (c *caseExpression) serialize(statement statementType, out *queryData) erro
|
|||
}
|
||||
|
||||
if c.els != nil {
|
||||
out.writeString(" ELSE ")
|
||||
out.writeString("ELSE")
|
||||
err := c.els.serialize(statement, out)
|
||||
|
||||
if err != nil {
|
||||
|
|
@ -152,7 +151,7 @@ func (c *caseExpression) serialize(statement statementType, out *queryData) erro
|
|||
}
|
||||
}
|
||||
|
||||
out.writeString(" END)")
|
||||
out.writeString("END)")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,9 +10,9 @@ import (
|
|||
)
|
||||
|
||||
type insertStatement interface {
|
||||
statement
|
||||
Statement
|
||||
|
||||
// Add a row of values to the insert statement.
|
||||
// Add a row of values to the insert Statement.
|
||||
VALUES(values ...interface{}) insertStatement
|
||||
// Map or stracture mapped to column names
|
||||
VALUES_MAPPING(data interface{}) insertStatement
|
||||
|
|
@ -48,9 +48,9 @@ func (u *insertStatementImpl) Execute(db types.Db) (res sql.Result, err error) {
|
|||
}
|
||||
|
||||
// expression or default keyword
|
||||
func (s *insertStatementImpl) VALUES(values ...interface{}) insertStatement {
|
||||
func (i *insertStatementImpl) VALUES(values ...interface{}) insertStatement {
|
||||
if len(values) == 0 {
|
||||
return s
|
||||
return i
|
||||
}
|
||||
|
||||
literalRow := []clause{}
|
||||
|
|
@ -63,8 +63,8 @@ func (s *insertStatementImpl) VALUES(values ...interface{}) insertStatement {
|
|||
}
|
||||
}
|
||||
|
||||
s.rows = append(s.rows, literalRow)
|
||||
return s
|
||||
i.rows = append(i.rows, literalRow)
|
||||
return i
|
||||
}
|
||||
|
||||
func (i *insertStatementImpl) VALUES_MAPPING(data interface{}) insertStatement {
|
||||
|
|
@ -121,13 +121,19 @@ func (i *insertStatementImpl) addError(err string) {
|
|||
i.errors = append(i.errors, err)
|
||||
}
|
||||
|
||||
func (i *insertStatementImpl) DebugSql() (query string, err error) {
|
||||
return DebugSql(i)
|
||||
}
|
||||
|
||||
func (s *insertStatementImpl) Sql() (sql string, args []interface{}, err error) {
|
||||
if len(s.errors) > 0 {
|
||||
return "", nil, errors.New("sql builder errors: " + strings.Join(s.errors, ", "))
|
||||
}
|
||||
|
||||
queryData := &queryData{}
|
||||
queryData.writeString("INSERT INTO ")
|
||||
|
||||
queryData.nextLine()
|
||||
queryData.writeString("INSERT INTO")
|
||||
|
||||
if s.table == nil {
|
||||
return "", nil, errors.Newf("nil tableName.")
|
||||
|
|
@ -135,12 +141,14 @@ func (s *insertStatementImpl) Sql() (sql string, args []interface{}, err error)
|
|||
|
||||
err = s.table.serialize(insert_statement, queryData)
|
||||
|
||||
queryData.writeByte(' ')
|
||||
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
if len(s.columns) > 0 {
|
||||
queryData.writeString(" (")
|
||||
queryData.writeString("(")
|
||||
|
||||
err = serializeColumnList(insert_statement, s.columns, queryData)
|
||||
|
||||
|
|
@ -148,7 +156,7 @@ func (s *insertStatementImpl) Sql() (sql string, args []interface{}, err error)
|
|||
return "", nil, err
|
||||
}
|
||||
|
||||
queryData.writeString(") ")
|
||||
queryData.writeString(")")
|
||||
}
|
||||
|
||||
if len(s.rows) == 0 && s.query == nil {
|
||||
|
|
@ -160,12 +168,17 @@ func (s *insertStatementImpl) Sql() (sql string, args []interface{}, err error)
|
|||
}
|
||||
|
||||
if len(s.rows) > 0 {
|
||||
queryData.writeString("VALUES (")
|
||||
queryData.writeString("VALUES")
|
||||
|
||||
for row_i, row := range s.rows {
|
||||
if row_i > 0 {
|
||||
queryData.writeString(", (")
|
||||
queryData.writeString(",")
|
||||
}
|
||||
|
||||
queryData.increaseIdent()
|
||||
queryData.nextLine()
|
||||
queryData.writeString("(")
|
||||
|
||||
if len(row) != len(s.columns) {
|
||||
return "", nil, errors.New("# of values does not match # of columns.")
|
||||
}
|
||||
|
|
@ -177,6 +190,7 @@ func (s *insertStatementImpl) Sql() (sql string, args []interface{}, err error)
|
|||
}
|
||||
|
||||
queryData.writeByte(')')
|
||||
queryData.decreaseIdent()
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -189,7 +203,8 @@ func (s *insertStatementImpl) Sql() (sql string, args []interface{}, err error)
|
|||
}
|
||||
|
||||
if len(s.returning) > 0 {
|
||||
queryData.writeString(" RETURNING ")
|
||||
queryData.nextLine()
|
||||
queryData.writeString("RETURNING")
|
||||
|
||||
err = queryData.writeProjection(insert_statement, s.returning)
|
||||
|
||||
|
|
@ -198,7 +213,7 @@ func (s *insertStatementImpl) Sql() (sql string, args []interface{}, err error)
|
|||
}
|
||||
}
|
||||
|
||||
queryData.writeByte(';')
|
||||
sql, args = queryData.finalize()
|
||||
|
||||
return queryData.buff.String(), queryData.args, nil
|
||||
return
|
||||
}
|
||||
|
|
|
|||
|
|
@ -29,7 +29,10 @@ func TestInsertColumnLengthMismatch(t *testing.T) {
|
|||
func TestInsertNilValue(t *testing.T) {
|
||||
query, args, err := table1.INSERT(table1Col1).VALUES(nil).Sql()
|
||||
|
||||
assert.Equal(t, query, "INSERT INTO db.table1 (col1) VALUES ($1);")
|
||||
assert.Equal(t, query, `
|
||||
INSERT INTO db.table1 (col1) VALUES
|
||||
($1);
|
||||
`)
|
||||
assert.Equal(t, len(args), 1)
|
||||
assert.NilError(t, err)
|
||||
}
|
||||
|
|
@ -44,7 +47,10 @@ func TestInsertSingleValue(t *testing.T) {
|
|||
sql, _, err := table1.INSERT(table1Col1).VALUES(1).Sql()
|
||||
assert.NilError(t, err)
|
||||
|
||||
assert.Equal(t, sql, "INSERT INTO db.table1 (col1) VALUES ($1);")
|
||||
assert.Equal(t, sql, `
|
||||
INSERT INTO db.table1 (col1) VALUES
|
||||
($1);
|
||||
`)
|
||||
}
|
||||
|
||||
func TestInsertDate(t *testing.T) {
|
||||
|
|
@ -53,7 +59,10 @@ func TestInsertDate(t *testing.T) {
|
|||
sql, _, err := table1.INSERT(table1Col4).VALUES(date).Sql()
|
||||
assert.NilError(t, err)
|
||||
|
||||
assert.Equal(t, sql, "INSERT INTO db.table1 (col4) VALUES ($1);")
|
||||
assert.Equal(t, sql, `
|
||||
INSERT INTO db.table1 (col4) VALUES
|
||||
($1);
|
||||
`)
|
||||
}
|
||||
|
||||
func TestInsertMultipleValues(t *testing.T) {
|
||||
|
|
@ -63,7 +72,14 @@ func TestInsertMultipleValues(t *testing.T) {
|
|||
sql, _, err := stmt.Sql()
|
||||
assert.NilError(t, err)
|
||||
|
||||
assert.Equal(t, sql, "INSERT INTO db.table1 (col1,col2,col3) VALUES ($1, $2, $3);")
|
||||
fmt.Println(sql)
|
||||
|
||||
expectedSql := `
|
||||
INSERT INTO db.table1 (col1,col2,col3) VALUES
|
||||
($1, $2, $3);
|
||||
`
|
||||
|
||||
assert.Equal(t, sql, expectedSql)
|
||||
}
|
||||
|
||||
func TestInsertMultipleRows(t *testing.T) {
|
||||
|
|
@ -75,7 +91,16 @@ func TestInsertMultipleRows(t *testing.T) {
|
|||
sql, _, err := stmt.Sql()
|
||||
assert.NilError(t, err)
|
||||
|
||||
assert.Equal(t, sql, "INSERT INTO db.table1 (col1,col2) VALUES ($1, $2), ($3, $4), ($5, $6);")
|
||||
fmt.Println(sql)
|
||||
|
||||
expectedSql := `
|
||||
INSERT INTO db.table1 (col1,col2) VALUES
|
||||
($1, $2),
|
||||
($3, $4),
|
||||
($5, $6);
|
||||
`
|
||||
|
||||
assert.Equal(t, sql, expectedSql)
|
||||
}
|
||||
|
||||
func TestInsertValuesFromModel(t *testing.T) {
|
||||
|
|
@ -98,7 +123,10 @@ func TestInsertValuesFromModel(t *testing.T) {
|
|||
|
||||
fmt.Println(sql)
|
||||
|
||||
assert.Equal(t, sql, `INSERT INTO db.table1 (col1,col2) VALUES ($1, $2);`)
|
||||
assert.Equal(t, sql, `
|
||||
INSERT INTO db.table1 (col1,col2) VALUES
|
||||
($1, $2);
|
||||
`)
|
||||
}
|
||||
|
||||
func TestInsertValuesFromModelColumnMismatch(t *testing.T) {
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ const (
|
|||
)
|
||||
|
||||
type lockStatement interface {
|
||||
statement
|
||||
Statement
|
||||
|
||||
IN(lockMode lockMode) lockStatement
|
||||
NOWAIT() lockStatement
|
||||
|
|
@ -48,9 +48,13 @@ func (l *lockStatementImpl) NOWAIT() lockStatement {
|
|||
return l
|
||||
}
|
||||
|
||||
func (l *lockStatementImpl) DebugSql() (query string, err error) {
|
||||
return DebugSql(l)
|
||||
}
|
||||
|
||||
func (l *lockStatementImpl) Sql() (query string, args []interface{}, err error) {
|
||||
if l == nil {
|
||||
return "", nil, errors.New("nil statement.")
|
||||
return "", nil, errors.New("nil Statement.")
|
||||
}
|
||||
|
||||
if len(l.tables) == 0 {
|
||||
|
|
@ -59,7 +63,8 @@ func (l *lockStatementImpl) Sql() (query string, args []interface{}, err error)
|
|||
|
||||
out := &queryData{}
|
||||
|
||||
out.writeString("LOCK TABLE ")
|
||||
out.nextLine()
|
||||
out.writeString("LOCK TABLE")
|
||||
|
||||
for i, table := range l.tables {
|
||||
if i > 0 {
|
||||
|
|
@ -74,16 +79,17 @@ func (l *lockStatementImpl) Sql() (query string, args []interface{}, err error)
|
|||
}
|
||||
|
||||
if l.lockMode != "" {
|
||||
out.writeString(" IN ")
|
||||
out.writeString("IN")
|
||||
out.writeString(string(l.lockMode))
|
||||
out.writeString(" MODE")
|
||||
out.writeString("MODE")
|
||||
}
|
||||
|
||||
if l.nowait {
|
||||
out.writeString(" NOWAIT")
|
||||
out.writeString("NOWAIT")
|
||||
}
|
||||
|
||||
return out.buff.String(), out.args, nil
|
||||
query, args = out.finalize()
|
||||
return
|
||||
}
|
||||
|
||||
func (l *lockStatementImpl) Query(db types.Db, destination interface{}) error {
|
||||
|
|
|
|||
|
|
@ -11,7 +11,9 @@ func TestLockSingleTable(t *testing.T) {
|
|||
queryStr, _, err := lock.Sql()
|
||||
|
||||
assert.NilError(t, err)
|
||||
assert.Equal(t, queryStr, `LOCK TABLE db.table1 IN ROW SHARE MODE`)
|
||||
assert.Equal(t, queryStr, `
|
||||
LOCK TABLE db.table1 IN ROW SHARE MODE;
|
||||
`)
|
||||
}
|
||||
|
||||
func TestLockMultipleTable(t *testing.T) {
|
||||
|
|
@ -20,5 +22,7 @@ func TestLockMultipleTable(t *testing.T) {
|
|||
queryStr, _, err := lock.Sql()
|
||||
|
||||
assert.NilError(t, err)
|
||||
assert.Equal(t, queryStr, `LOCK TABLE db.table2, db.table1 IN ACCESS EXCLUSIVE MODE NOWAIT`)
|
||||
assert.Equal(t, queryStr, `
|
||||
LOCK TABLE db.table2, db.table1 IN ACCESS EXCLUSIVE MODE NOWAIT;
|
||||
`)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -131,9 +131,9 @@ func newNumericExpressionWrap(expression expression) numericExpression {
|
|||
}
|
||||
|
||||
func (c *numericExpressionWrapper) serialize(statement statementType, out *queryData) error {
|
||||
out.writeString("(")
|
||||
//out.writeString("(")
|
||||
err := c.expression.serialize(statement, out)
|
||||
out.writeString(")")
|
||||
//out.writeString(")")
|
||||
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import (
|
|||
)
|
||||
|
||||
type selectStatement interface {
|
||||
statement
|
||||
Statement
|
||||
expression
|
||||
|
||||
DISTINCT() selectStatement
|
||||
|
|
@ -84,15 +84,17 @@ func (s *selectStatementImpl) FROM(table readableTable) selectStatement {
|
|||
}
|
||||
|
||||
func (s *selectStatementImpl) serialize(statement statementType, out *queryData) error {
|
||||
|
||||
out.writeString("(")
|
||||
|
||||
out.increaseIdent()
|
||||
err := s.serializeImpl(out)
|
||||
out.decreaseIdent()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
out.nextLine()
|
||||
out.writeString(")")
|
||||
|
||||
return nil
|
||||
|
|
@ -100,10 +102,11 @@ func (s *selectStatementImpl) serialize(statement statementType, out *queryData)
|
|||
|
||||
func (s *selectStatementImpl) serializeImpl(out *queryData) error {
|
||||
|
||||
out.writeString("SELECT ")
|
||||
out.nextLine()
|
||||
out.writeString("SELECT")
|
||||
|
||||
if s.distinct {
|
||||
out.writeString("DISTINCT ")
|
||||
out.writeString("DISTINCT")
|
||||
}
|
||||
|
||||
if s.projections == nil || len(s.projections) == 0 {
|
||||
|
|
@ -116,16 +119,18 @@ func (s *selectStatementImpl) serializeImpl(out *queryData) error {
|
|||
return err
|
||||
}
|
||||
|
||||
out.writeString(" FROM ")
|
||||
|
||||
if s.table == nil {
|
||||
return errors.Newf("nil tableName.")
|
||||
}
|
||||
|
||||
if err := s.table.serialize(select_statement, out); err != nil {
|
||||
if err := out.writeFrom(select_statement, s.table); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
//if err := s.table.serialize(select_statement, out); err != nil {
|
||||
// return err
|
||||
//}
|
||||
|
||||
if s.where != nil {
|
||||
err := out.writeWhere(select_statement, s.where)
|
||||
|
||||
|
|
@ -159,33 +164,42 @@ func (s *selectStatementImpl) serializeImpl(out *queryData) error {
|
|||
}
|
||||
|
||||
if s.limit >= 0 {
|
||||
out.writeString(" LIMIT ")
|
||||
out.nextLine()
|
||||
out.writeString("LIMIT")
|
||||
out.insertArgument(s.limit)
|
||||
}
|
||||
|
||||
if s.offset >= 0 {
|
||||
out.writeString(" OFFSET ")
|
||||
out.nextLine()
|
||||
out.writeString("OFFSET")
|
||||
out.insertArgument(s.offset)
|
||||
}
|
||||
|
||||
if s.forUpdate {
|
||||
out.writeString(" FOR UPDATE")
|
||||
out.nextLine()
|
||||
out.writeString("FOR UPDATE")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Return the properly escaped SQL statement, against the specified database
|
||||
func (q *selectStatementImpl) Sql() (query string, args []interface{}, err error) {
|
||||
// Return the properly escaped SQL Statement, against the specified database
|
||||
func (s *selectStatementImpl) Sql() (query string, args []interface{}, err error) {
|
||||
queryData := queryData{}
|
||||
|
||||
err = q.serializeImpl(&queryData)
|
||||
err = s.serializeImpl(&queryData)
|
||||
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
return queryData.buff.String(), queryData.args, nil
|
||||
query, args = queryData.finalize()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (s *selectStatementImpl) DebugSql() (query string, err error) {
|
||||
return DebugSql(s)
|
||||
}
|
||||
|
||||
func (s *selectStatementImpl) AsTable(alias string) expressionTable {
|
||||
|
|
@ -195,9 +209,9 @@ func (s *selectStatementImpl) AsTable(alias string) expressionTable {
|
|||
}
|
||||
}
|
||||
|
||||
func (q *selectStatementImpl) WHERE(expression boolExpression) selectStatement {
|
||||
q.where = expression
|
||||
return q
|
||||
func (s *selectStatementImpl) WHERE(expression boolExpression) selectStatement {
|
||||
s.where = expression
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *selectStatementImpl) GROUP_BY(groupByClauses ...groupByClause) selectStatement {
|
||||
|
|
@ -205,46 +219,46 @@ func (s *selectStatementImpl) GROUP_BY(groupByClauses ...groupByClause) selectSt
|
|||
return s
|
||||
}
|
||||
|
||||
func (q *selectStatementImpl) HAVING(expression boolExpression) selectStatement {
|
||||
q.having = expression
|
||||
return q
|
||||
func (s *selectStatementImpl) HAVING(expression boolExpression) selectStatement {
|
||||
s.having = expression
|
||||
return s
|
||||
}
|
||||
|
||||
func (q *selectStatementImpl) ORDER_BY(clauses ...orderByClause) selectStatement {
|
||||
func (s *selectStatementImpl) ORDER_BY(clauses ...orderByClause) selectStatement {
|
||||
|
||||
q.orderBy = clauses
|
||||
s.orderBy = clauses
|
||||
|
||||
return q
|
||||
return s
|
||||
}
|
||||
|
||||
func (q *selectStatementImpl) OFFSET(offset int64) selectStatement {
|
||||
q.offset = offset
|
||||
return q
|
||||
func (s *selectStatementImpl) OFFSET(offset int64) selectStatement {
|
||||
s.offset = offset
|
||||
return s
|
||||
}
|
||||
|
||||
func (q *selectStatementImpl) LIMIT(limit int64) selectStatement {
|
||||
q.limit = limit
|
||||
return q
|
||||
func (s *selectStatementImpl) LIMIT(limit int64) selectStatement {
|
||||
s.limit = limit
|
||||
return s
|
||||
}
|
||||
|
||||
func (q *selectStatementImpl) DISTINCT() selectStatement {
|
||||
q.distinct = true
|
||||
return q
|
||||
func (s *selectStatementImpl) DISTINCT() selectStatement {
|
||||
s.distinct = true
|
||||
return s
|
||||
}
|
||||
|
||||
func (q *selectStatementImpl) FOR_UPDATE() selectStatement {
|
||||
q.forUpdate = true
|
||||
return q
|
||||
func (s *selectStatementImpl) FOR_UPDATE() selectStatement {
|
||||
s.forUpdate = true
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *selectStatementImpl) Query(db types.Db, destination interface{}) error {
|
||||
return Query(s, db, destination)
|
||||
}
|
||||
|
||||
func (u *selectStatementImpl) Execute(db types.Db) (res sql.Result, err error) {
|
||||
return Execute(u, db)
|
||||
func (s *selectStatementImpl) Execute(db types.Db) (res sql.Result, err error) {
|
||||
return Execute(s, db)
|
||||
}
|
||||
|
||||
func NumExp(statement selectStatement) numericExpression {
|
||||
return newNumericExpressionWrap(statement)
|
||||
func NumExp(expression expression) numericExpression {
|
||||
return newNumericExpressionWrap(expression)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ const (
|
|||
)
|
||||
|
||||
type setStatement interface {
|
||||
statement
|
||||
Statement
|
||||
expression
|
||||
|
||||
ORDER_BY(clauses ...orderByClause) setStatement
|
||||
|
|
@ -97,8 +97,10 @@ func (us *setStatementImpl) AsTable(alias string) expressionTable {
|
|||
}
|
||||
|
||||
func (s *setStatementImpl) serialize(statement statementType, out *queryData) error {
|
||||
|
||||
if s.orderBy != nil || s.limit >= 0 || s.offset >= 0 {
|
||||
out.writeString("(")
|
||||
out.increaseIdent()
|
||||
}
|
||||
|
||||
err := s.serializeImpl(out)
|
||||
|
|
@ -108,6 +110,8 @@ func (s *setStatementImpl) serialize(statement statementType, out *queryData) er
|
|||
}
|
||||
|
||||
if s.orderBy != nil || s.limit >= 0 || s.offset >= 0 {
|
||||
out.decreaseIdent()
|
||||
out.nextLine()
|
||||
out.writeString(")")
|
||||
}
|
||||
|
||||
|
|
@ -117,18 +121,22 @@ func (s *setStatementImpl) serialize(statement statementType, out *queryData) er
|
|||
func (s *setStatementImpl) serializeImpl(out *queryData) error {
|
||||
|
||||
if len(s.selects) < 2 {
|
||||
return errors.Newf("UNION statement must have at least two SELECT statements.")
|
||||
return errors.Newf("UNION Statement must have at least two SELECT statements.")
|
||||
}
|
||||
|
||||
out.nextLine()
|
||||
out.writeString("(")
|
||||
out.increaseIdent()
|
||||
|
||||
for i, selectStmt := range s.selects {
|
||||
out.nextLine()
|
||||
if i > 0 {
|
||||
out.writeString(" " + s.operator + " ")
|
||||
out.writeString(s.operator)
|
||||
|
||||
if s.all {
|
||||
out.writeString(" ALL ")
|
||||
out.writeString("ALL")
|
||||
}
|
||||
out.nextLine()
|
||||
}
|
||||
|
||||
err := selectStmt.serialize(set_statement, out)
|
||||
|
|
@ -138,6 +146,8 @@ func (s *setStatementImpl) serializeImpl(out *queryData) error {
|
|||
}
|
||||
}
|
||||
|
||||
out.decreaseIdent()
|
||||
out.nextLine()
|
||||
out.writeString(")")
|
||||
|
||||
if s.orderBy != nil {
|
||||
|
|
@ -148,12 +158,14 @@ func (s *setStatementImpl) serializeImpl(out *queryData) error {
|
|||
}
|
||||
|
||||
if s.limit >= 0 {
|
||||
out.writeString(" LIMIT ")
|
||||
out.nextLine()
|
||||
out.writeString("LIMIT")
|
||||
out.insertArgument(s.limit)
|
||||
}
|
||||
|
||||
if s.offset >= 0 {
|
||||
out.writeString(" OFFSET ")
|
||||
out.nextLine()
|
||||
out.writeString("OFFSET")
|
||||
out.insertArgument(s.offset)
|
||||
}
|
||||
|
||||
|
|
@ -169,7 +181,12 @@ func (us *setStatementImpl) Sql() (query string, args []interface{}, err error)
|
|||
return
|
||||
}
|
||||
|
||||
return queryData.buff.String(), queryData.args, nil
|
||||
query, args = queryData.finalize()
|
||||
return
|
||||
}
|
||||
|
||||
func (s *setStatementImpl) DebugSql() (query string, err error) {
|
||||
return DebugSql(s)
|
||||
}
|
||||
|
||||
func (s *setStatementImpl) Query(db types.Db, destination interface{}) error {
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package sqlbuilder
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"gotest.tools/assert"
|
||||
"testing"
|
||||
)
|
||||
|
|
@ -28,7 +29,20 @@ func TestUnionTwoSelect(t *testing.T) {
|
|||
).Sql()
|
||||
|
||||
assert.NilError(t, err)
|
||||
assert.Equal(t, query, `((SELECT table1.col1 AS "table1.col1" FROM db.table1) UNION (SELECT table2.col3 AS "table2.col3" FROM db.table2))`)
|
||||
fmt.Println(query)
|
||||
assert.Equal(t, query, `
|
||||
(
|
||||
(
|
||||
SELECT table1.col1 AS "table1.col1"
|
||||
FROM db.table1
|
||||
)
|
||||
UNION
|
||||
(
|
||||
SELECT table2.col3 AS "table2.col3"
|
||||
FROM db.table2
|
||||
)
|
||||
);
|
||||
`)
|
||||
assert.Equal(t, len(args), 0)
|
||||
}
|
||||
|
||||
|
|
@ -39,8 +53,26 @@ func TestUnionThreeSelect(t *testing.T) {
|
|||
table3.SELECT(table3Col1),
|
||||
).Sql()
|
||||
|
||||
fmt.Println(query)
|
||||
assert.NilError(t, err)
|
||||
assert.Equal(t, query, `((SELECT table1.col1 AS "table1.col1" FROM db.table1) UNION (SELECT table2.col3 AS "table2.col3" FROM db.table2) UNION (SELECT table3.col1 AS "table3.col1" FROM db.table3))`)
|
||||
assert.Equal(t, query, `
|
||||
(
|
||||
(
|
||||
SELECT table1.col1 AS "table1.col1"
|
||||
FROM db.table1
|
||||
)
|
||||
UNION
|
||||
(
|
||||
SELECT table2.col3 AS "table2.col3"
|
||||
FROM db.table2
|
||||
)
|
||||
UNION
|
||||
(
|
||||
SELECT table3.col1 AS "table3.col1"
|
||||
FROM db.table3
|
||||
)
|
||||
);
|
||||
`)
|
||||
assert.Equal(t, len(args), 0)
|
||||
}
|
||||
|
||||
|
|
@ -51,7 +83,21 @@ func TestUnionWithOrderBy(t *testing.T) {
|
|||
).ORDER_BY(table1Col1.ASC()).Sql()
|
||||
|
||||
assert.NilError(t, err)
|
||||
assert.Equal(t, query, `((SELECT table1.col1 AS "table1.col1" FROM db.table1) UNION (SELECT table2.col3 AS "table2.col3" FROM db.table2)) ORDER BY "table1.col1" ASC`)
|
||||
fmt.Println(query)
|
||||
assert.Equal(t, query, `
|
||||
(
|
||||
(
|
||||
SELECT table1.col1 AS "table1.col1"
|
||||
FROM db.table1
|
||||
)
|
||||
UNION
|
||||
(
|
||||
SELECT table2.col3 AS "table2.col3"
|
||||
FROM db.table2
|
||||
)
|
||||
)
|
||||
ORDER BY "table1.col1" ASC;
|
||||
`)
|
||||
assert.Equal(t, len(args), 0)
|
||||
}
|
||||
|
||||
|
|
@ -62,6 +108,21 @@ func TestUnionWithLimit(t *testing.T) {
|
|||
).LIMIT(10).OFFSET(11).Sql()
|
||||
|
||||
assert.NilError(t, err)
|
||||
assert.Equal(t, query, `((SELECT table1.col1 AS "table1.col1" FROM db.table1) UNION (SELECT table2.col3 AS "table2.col3" FROM db.table2)) LIMIT $1 OFFSET $2`)
|
||||
fmt.Println(query)
|
||||
assert.Equal(t, query, `
|
||||
(
|
||||
(
|
||||
SELECT table1.col1 AS "table1.col1"
|
||||
FROM db.table1
|
||||
)
|
||||
UNION
|
||||
(
|
||||
SELECT table2.col3 AS "table2.col3"
|
||||
FROM db.table2
|
||||
)
|
||||
)
|
||||
LIMIT $1
|
||||
OFFSET $2;
|
||||
`)
|
||||
assert.Equal(t, len(args), 2)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,12 +3,33 @@ package sqlbuilder
|
|||
import (
|
||||
"database/sql"
|
||||
"github.com/sub0zero/go-sqlbuilder/types"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type statement interface {
|
||||
type Statement interface {
|
||||
// String returns generated SQL as string.
|
||||
Sql() (query string, args []interface{}, err error)
|
||||
|
||||
DebugSql() (query string, err error)
|
||||
|
||||
Query(db types.Db, destination interface{}) error
|
||||
Execute(db types.Db) (sql.Result, error)
|
||||
}
|
||||
|
||||
func DebugSql(statement Statement) (string, error) {
|
||||
sql, args, err := statement.Sql()
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
debugSql := sql
|
||||
|
||||
for i, arg := range args {
|
||||
argPlaceholder := "$" + strconv.Itoa(i+1)
|
||||
debugSql = strings.Replace(debugSql, argPlaceholder, ArgToString(arg), 1)
|
||||
}
|
||||
|
||||
return debugSql, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ var _ = gc.Suite(&StmtSuite{})
|
|||
// NOTE: tables / columns are defined in test_utils.go
|
||||
|
||||
//
|
||||
// SELECT statement tests
|
||||
// SELECT Statement tests
|
||||
//
|
||||
|
||||
func (s *StmtSuite) TestSelectEmptyProjection(c *gc.C) {
|
||||
|
|
@ -233,7 +233,7 @@ func (s *StmtSuite) TestSelectDistinct(c *gc.C) {
|
|||
}
|
||||
|
||||
//
|
||||
// INSERT statement tests
|
||||
// INSERT Statement tests
|
||||
//
|
||||
|
||||
func (s *StmtSuite) TestInsertNoColumn(c *gc.C) {
|
||||
|
|
@ -386,7 +386,7 @@ func (s *StmtSuite) TestOnDuplicateKeyUpdateMulti(c *gc.C) {
|
|||
}
|
||||
|
||||
//
|
||||
// LOCK/UNLOCK statement tests ================================================
|
||||
// LOCK/UNLOCK Statement tests ================================================
|
||||
//
|
||||
|
||||
func (s *StmtSuite) TestLockStatement(c *gc.C) {
|
||||
|
|
@ -444,7 +444,7 @@ func (s *StmtSuite) TestUnionLimitWithoutOrderBy(c *gc.C) {
|
|||
c.Assert(
|
||||
errors.GetMessage(err),
|
||||
gc.Equals,
|
||||
"All inner selects in UNION statement must have LIMIT if they have ORDER BY")
|
||||
"All inner selects in UNION Statement must have LIMIT if they have ORDER BY")
|
||||
}
|
||||
|
||||
func (s *StmtSuite) TestUnionSelectWithMismatchedColumns(c *gc.C) {
|
||||
|
|
@ -472,7 +472,7 @@ func (s *StmtSuite) TestUnionSelectWithMismatchedColumns(c *gc.C) {
|
|||
c.Assert(
|
||||
errors.GetMessage(err),
|
||||
gc.Equals,
|
||||
"All inner selects in UNION statement must select the "+
|
||||
"All inner selects in UNION Statement must select the "+
|
||||
"same number of columns. For sanity, you probably "+
|
||||
"want to select the same tableName columns in the same "+
|
||||
"orderBy. If you are selecting on multiple tables, "+
|
||||
|
|
@ -481,8 +481,8 @@ func (s *StmtSuite) TestUnionSelectWithMismatchedColumns(c *gc.C) {
|
|||
|
||||
func (s *StmtSuite) TestComplicatedUnionSelectWithWhereStatement(c *gc.C) {
|
||||
|
||||
// tests on outer statement: Group By, Order By, LIMIT
|
||||
// on inner statement: AndWhere, WHERE (with AND), Order By, LIMIT
|
||||
// tests on outer Statement: Group By, Order By, LIMIT
|
||||
// on inner Statement: AndWhere, WHERE (with AND), Order By, LIMIT
|
||||
select_queries := make([]selectStatement, 0, 3)
|
||||
|
||||
// We're not trying to write a SQL parser, so we won't warn if you do something silly like
|
||||
|
|
|
|||
|
|
@ -106,7 +106,7 @@ func (t *Table) Columns() []column {
|
|||
}
|
||||
|
||||
// Generates the sql string for the current tableName expression. Note: the
|
||||
// generated string may not be a valid/executable sql statement.
|
||||
// generated string may not be a valid/executable sql Statement.
|
||||
func (t *Table) serialize(statement statementType, out *queryData) error {
|
||||
if t == nil {
|
||||
return errors.Newf("nil tableName.")
|
||||
|
|
@ -287,17 +287,19 @@ func (t *joinTable) serialize(statement statementType, out *queryData) (err erro
|
|||
return
|
||||
}
|
||||
|
||||
out.nextLine()
|
||||
|
||||
switch t.join_type {
|
||||
case INNER_JOIN:
|
||||
out.writeString(" JOIN ")
|
||||
out.writeString("JOIN")
|
||||
case LEFT_JOIN:
|
||||
out.writeString(" LEFT JOIN ")
|
||||
out.writeString("LEFT JOIN")
|
||||
case RIGHT_JOIN:
|
||||
out.writeString(" RIGHT JOIN ")
|
||||
out.writeString("RIGHT JOIN")
|
||||
case FULL_JOIN:
|
||||
out.writeString(" FULL JOIN ")
|
||||
out.writeString("FULL JOIN")
|
||||
case CROSS_JOIN:
|
||||
out.writeString(" CROSS JOIN ")
|
||||
out.writeString("CROSS JOIN")
|
||||
}
|
||||
|
||||
if err = t.rhs.serialize(statement, out); err != nil {
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import (
|
|||
)
|
||||
|
||||
type updateStatement interface {
|
||||
statement
|
||||
Statement
|
||||
|
||||
SET(values ...interface{}) updateStatement
|
||||
WHERE(expression boolExpression) updateStatement
|
||||
|
|
@ -55,7 +55,8 @@ func (u *updateStatementImpl) RETURNING(projections ...projection) updateStateme
|
|||
func (u *updateStatementImpl) Sql() (sql string, args []interface{}, err error) {
|
||||
out := &queryData{}
|
||||
|
||||
out.writeString("UPDATE ")
|
||||
out.nextLine()
|
||||
out.writeString("UPDATE")
|
||||
|
||||
if u.table == nil {
|
||||
return "", nil, errors.New("nil tableName.")
|
||||
|
|
@ -69,12 +70,10 @@ func (u *updateStatementImpl) Sql() (sql string, args []interface{}, err error)
|
|||
return "", nil, errors.New("No column updated.")
|
||||
}
|
||||
|
||||
out.writeString(" SET")
|
||||
out.writeString("SET")
|
||||
|
||||
if len(u.columns) > 1 {
|
||||
out.writeString(" ( ")
|
||||
} else {
|
||||
out.writeString(" ")
|
||||
out.writeString("(")
|
||||
}
|
||||
|
||||
err = serializeColumnList(update_statement, u.columns, out)
|
||||
|
|
@ -84,13 +83,13 @@ func (u *updateStatementImpl) Sql() (sql string, args []interface{}, err error)
|
|||
}
|
||||
|
||||
if len(u.columns) > 1 {
|
||||
out.writeString(" )")
|
||||
out.writeString(")")
|
||||
}
|
||||
|
||||
out.writeString(" =")
|
||||
out.writeString("=")
|
||||
|
||||
if len(u.updateValues) > 1 {
|
||||
out.writeString(" (")
|
||||
out.writeString("(")
|
||||
}
|
||||
|
||||
for i, value := range u.updateValues {
|
||||
|
|
@ -106,7 +105,7 @@ func (u *updateStatementImpl) Sql() (sql string, args []interface{}, err error)
|
|||
}
|
||||
|
||||
if len(u.updateValues) > 1 {
|
||||
out.writeString(" )")
|
||||
out.writeString(")")
|
||||
}
|
||||
|
||||
if u.where == nil {
|
||||
|
|
@ -118,7 +117,8 @@ func (u *updateStatementImpl) Sql() (sql string, args []interface{}, err error)
|
|||
}
|
||||
|
||||
if len(u.returning) > 0 {
|
||||
out.writeString(" RETURNING ")
|
||||
out.nextLine()
|
||||
out.writeString("RETURNING")
|
||||
|
||||
err = serializeProjectionList(update_statement, u.returning, out)
|
||||
|
||||
|
|
@ -127,7 +127,12 @@ func (u *updateStatementImpl) Sql() (sql string, args []interface{}, err error)
|
|||
}
|
||||
}
|
||||
|
||||
return out.buff.String(), out.args, nil
|
||||
sql, args = out.finalize()
|
||||
return
|
||||
}
|
||||
|
||||
func (u *updateStatementImpl) DebugSql() (query string, err error) {
|
||||
return DebugSql(u)
|
||||
}
|
||||
|
||||
func (u *updateStatementImpl) Query(db types.Db, destination interface{}) error {
|
||||
|
|
|
|||
|
|
@ -7,19 +7,30 @@ import (
|
|||
)
|
||||
|
||||
//
|
||||
// UPDATE statement tests =====================================================
|
||||
// UPDATE Statement tests =====================================================
|
||||
//
|
||||
|
||||
func TestUpdate(t *testing.T) {
|
||||
stmt := table1.UPDATE(table1Col1, table1Col2).
|
||||
SET(table1.SELECT(table1Col2)).
|
||||
WHERE(table1Col1.EqL(2))
|
||||
SET(table1.SELECT(table1Col2, table2Col3)).
|
||||
WHERE(table1Col1.EqL(2)).
|
||||
RETURNING(table1Col1)
|
||||
|
||||
stmtStr, _, err := stmt.Sql()
|
||||
|
||||
assert.NilError(t, err)
|
||||
|
||||
fmt.Println(stmtStr)
|
||||
|
||||
assert.Equal(t, stmtStr, `
|
||||
UPDATE db.table1 SET (col1,col2) = (
|
||||
SELECT table1.col2 AS "table1.col2",
|
||||
table2.col3 AS "table2.col3"
|
||||
FROM db.table1
|
||||
)
|
||||
WHERE table1.col1 = $1
|
||||
RETURNING table1.col1 AS "table1.col1";
|
||||
`)
|
||||
}
|
||||
|
||||
//func (s *StmtSuite) TestUpdateNilColumn(c *gc.C) {
|
||||
|
|
|
|||
|
|
@ -82,8 +82,10 @@ func serializeExpressionList(statement statementType, expressions []expression,
|
|||
func serializeProjectionList(statement statementType, projections []projection, out *queryData) error {
|
||||
for i, col := range projections {
|
||||
if i > 0 {
|
||||
out.writeString(", ")
|
||||
out.writeString(",")
|
||||
out.nextLine()
|
||||
}
|
||||
|
||||
if col == nil {
|
||||
return errors.New("projection expression is nil.")
|
||||
}
|
||||
|
|
@ -112,7 +114,7 @@ func serializeColumnList(statement statementType, columns []column, out *queryDa
|
|||
return nil
|
||||
}
|
||||
|
||||
func Query(statement statement, db types.Db, destination interface{}) error {
|
||||
func Query(statement Statement, db types.Db, destination interface{}) error {
|
||||
query, args, err := statement.Sql()
|
||||
|
||||
if err != nil {
|
||||
|
|
@ -122,7 +124,7 @@ func Query(statement statement, db types.Db, destination interface{}) error {
|
|||
return execution.Query(db, query, args, destination)
|
||||
}
|
||||
|
||||
func Execute(statement statement, db types.Db) (res sql.Result, err error) {
|
||||
func Execute(statement Statement, db types.Db) (res sql.Result, err error) {
|
||||
query, args, err := statement.Sql()
|
||||
|
||||
if err != nil {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue