jet/sqlbuilder/update_statement.go
2019-05-08 13:47:01 +02:00

139 lines
2.7 KiB
Go

package sqlbuilder
import (
"database/sql"
"github.com/dropbox/godropbox/errors"
"github.com/sub0zero/go-sqlbuilder/types"
)
type updateStatement interface {
statement
SET(values ...interface{}) updateStatement
WHERE(expression boolExpression) updateStatement
RETURNING(projections ...projection) updateStatement
}
func newUpdateStatement(table writableTable, columns []column) updateStatement {
return &updateStatementImpl{
table: table,
columns: columns,
}
}
type updateStatementImpl struct {
table writableTable
columns []column
updateValues []clause
where boolExpression
returning []projection
}
func (u *updateStatementImpl) SET(values ...interface{}) updateStatement {
for _, value := range values {
if clause, ok := value.(clause); ok {
u.updateValues = append(u.updateValues, clause)
} else {
u.updateValues = append(u.updateValues, Literal(value))
}
}
return u
}
func (u *updateStatementImpl) WHERE(expression boolExpression) updateStatement {
u.where = expression
return u
}
func (u *updateStatementImpl) RETURNING(projections ...projection) updateStatement {
u.returning = defaultProjectionAliasing(projections)
return u
}
func (u *updateStatementImpl) Sql() (sql string, args []interface{}, err error) {
out := &queryData{}
out.writeString("UPDATE ")
if u.table == nil {
return "", nil, errors.New("nil tableName.")
}
if err = u.table.serialize(update_statement, out); err != nil {
return
}
if len(u.updateValues) == 0 {
return "", nil, errors.New("No column updated.")
}
out.writeString(" SET")
if len(u.columns) > 1 {
out.writeString(" ( ")
} else {
out.writeString(" ")
}
err = serializeColumnList(update_statement, u.columns, out)
if err != nil {
return "", nil, err
}
if len(u.columns) > 1 {
out.writeString(" )")
}
out.writeString(" =")
if len(u.updateValues) > 1 {
out.writeString(" (")
}
for i, value := range u.updateValues {
if i > 0 {
out.writeString(", ")
}
err = value.serialize(update_statement, out)
if err != nil {
return
}
}
if len(u.updateValues) > 1 {
out.writeString(" )")
}
if u.where == nil {
return "", nil, errors.New("Updating without a WHERE clause.")
}
if err = out.writeWhere(update_statement, u.where); err != nil {
return
}
if len(u.returning) > 0 {
out.writeString(" RETURNING ")
err = serializeProjectionList(update_statement, u.returning, out)
if err != nil {
return
}
}
return out.buff.String(), out.args, nil
}
func (u *updateStatementImpl) Query(db types.Db, destination interface{}) error {
return Query(u, db, destination)
}
func (u *updateStatementImpl) Execute(db types.Db) (res sql.Result, err error) {
return Execute(u, db)
}