jet/sqlbuilder/update_statement.go

152 lines
2.8 KiB
Go
Raw Normal View History

2019-04-14 17:55:10 +02:00
package sqlbuilder
import (
"database/sql"
"github.com/dropbox/godropbox/errors"
"github.com/sub0zero/go-sqlbuilder/types"
)
// UpdateStatement builds an SQL UPDATE statement through chained calls.
// Each builder method below returns an UpdateStatement so calls can be
// chained (the implementation in this file returns its own receiver). The
// embedded Statement contributes the remaining methods; the implementation
// here provides Sql, Query and Execute for it.
type UpdateStatement interface {
	Statement

	// RETURNING sets the projections rendered after the RETURNING keyword.
	RETURNING(projections ...Projection) UpdateStatement

	// SET supplies the values assigned to the statement's target columns.
	SET(values ...interface{}) UpdateStatement

	// WHERE sets the boolean condition rendered after the WHERE keyword.
	WHERE(expression BoolExpression) UpdateStatement
}
// newUpdateStatement starts an UPDATE against table that will assign to the
// given columns. Values, the WHERE condition and any RETURNING projections are
// attached afterwards through the builder methods.
func newUpdateStatement(table WritableTable, columns []Column) UpdateStatement {
	stmt := new(updateStatementImpl)
	stmt.table = table
	stmt.columns = columns
	return stmt
}
// updateStatementImpl is the UpdateStatement returned by newUpdateStatement.
// Its fields are populated by the builder methods (SET, WHERE, RETURNING) and
// read back when Sql renders the statement.
type updateStatementImpl struct {
// table is rendered right after "UPDATE"; Sql returns an error when it is nil.
table WritableTable
// columns are the target columns rendered on the left-hand side of "=".
columns []Column
// updateValues are rendered on the right-hand side of "="; SET appends to it.
updateValues []Clause
// where is the WHERE condition; Sql refuses to render when it is nil.
where BoolExpression
// returning holds optional RETURNING projections; RETURNING replaces it.
returning []Projection
}
func (u *updateStatementImpl) Query(db types.Db, destination interface{}) error {
2019-04-20 19:49:29 +02:00
return Query(u, db, destination)
2019-04-14 17:55:10 +02:00
}
// Execute runs the statement against db through the package-level Execute
// helper and hands back its sql.Result and error unchanged.
func (u *updateStatementImpl) Execute(db types.Db) (res sql.Result, err error) {
	res, err = Execute(u, db)
	return res, err
}
// SET records the values assigned to the statement's columns, in argument
// order. An argument that already implements Clause is stored as-is; any other
// value (including nil) is wrapped with Literal. Successive calls append to
// the values collected so far rather than replacing them.
func (u *updateStatementImpl) SET(values ...interface{}) UpdateStatement {
	for _, v := range values {
		switch clause := v.(type) {
		case Clause:
			u.updateValues = append(u.updateValues, clause)
		default:
			u.updateValues = append(u.updateValues, Literal(v))
		}
	}
	return u
}
// WHERE stores the condition selecting the rows to update. A later call
// overwrites an earlier one. Rendering (Sql) fails if no condition was set.
func (u *updateStatementImpl) WHERE(expression BoolExpression) UpdateStatement {
	u.where = expression // last call wins
	return u
}
// RETURNING sets the projections emitted after the RETURNING keyword.
// Unlike SET, it replaces (not appends to) any previously given projections.
// Calling it with no arguments clears them, and Sql then omits RETURNING.
func (u *updateStatementImpl) RETURNING(projections ...Projection) UpdateStatement {
	u.returning = projections // replace, do not accumulate
	return u
}
// Sql renders the UPDATE statement and returns the query text together with
// the bind arguments collected while serializing its clauses.
//
// The rendered shape is
//
//	UPDATE <table> SET <col> = <value> WHERE <cond> [RETURNING <projections>]
//
// with the column list and/or the value list wrapped in parentheses when that
// side holds more than one entry.
//
// An error is returned, before any serialization work is done, when:
//   - the table is nil,
//   - no values were supplied via SET,
//   - no target columns were supplied,
//   - no WHERE condition was set (an UPDATE without WHERE is rejected).
//
// Errors from serializing the table, columns, values, condition or RETURNING
// projections are returned unchanged.
func (u *updateStatementImpl) Sql() (string, []interface{}, error) {
	// Validate everything first so a statement that will be rejected anyway
	// does not pay for (or partially build) its serialization.
	if u.table == nil {
		return "", nil, errors.New("nil tableName.")
	}
	if len(u.updateValues) == 0 {
		return "", nil, errors.New("No column updated.")
	}
	if len(u.columns) == 0 {
		// Without columns the SET clause would render as "SET  = ...",
		// which is never valid SQL.
		return "", nil, errors.New("No columns to update.")
	}
	if u.where == nil {
		return "", nil, errors.New("Updating without a WHERE clause.")
	}

	out := &queryData{}
	out.WriteString("UPDATE ")
	if err := u.table.SerializeSql(out); err != nil {
		return "", nil, err
	}

	if err := u.serializeSetClause(out); err != nil {
		return "", nil, err
	}

	out.WriteString(" WHERE ")
	if err := u.where.Serialize(out); err != nil {
		return "", nil, err
	}

	if len(u.returning) > 0 {
		out.WriteString(" RETURNING ")
		if err := serializeProjectionList(u.returning, out); err != nil {
			return "", nil, err
		}
	}

	return out.buff.String(), out.args, nil
}

// serializeSetClause writes " SET <columns> = <values>" to out. Each side is
// parenthesized only when it holds more than one entry, so a single-column
// update renders without parentheses.
func (u *updateStatementImpl) serializeSetClause(out *queryData) error {
	out.WriteString(" SET")

	multiColumn := len(u.columns) > 1
	if multiColumn {
		out.WriteString(" ( ")
	} else {
		out.WriteString(" ")
	}
	if err := serializeColumnList(u.columns, out); err != nil {
		return err
	}
	if multiColumn {
		out.WriteString(" )")
	}

	// NOTE(review): for a single value no space is written after "=" here,
	// unless the value's own serialization adds one. Kept as-is because
	// callers may compare exact SQL text.
	out.WriteString(" =")

	multiValue := len(u.updateValues) > 1
	if multiValue {
		out.WriteString(" (")
	}
	for i, value := range u.updateValues {
		if i > 0 {
			out.WriteString(", ")
		}
		if err := value.Serialize(out); err != nil {
			return err
		}
	}
	if multiValue {
		out.WriteString(" )")
	}
	return nil
}