jet/sqlbuilder/update_statement.go

145 lines
2.7 KiB
Go
Raw Normal View History

2019-04-14 17:55:10 +02:00
package sqlbuilder
import (
"database/sql"
2019-06-05 17:15:20 +02:00
"errors"
2019-06-05 17:56:24 +02:00
"github.com/go-jet/jet/sqlbuilder/execution"
2019-04-14 17:55:10 +02:00
)
2019-06-04 12:10:23 +02:00
type UpdateStatement interface {
2019-05-12 18:15:23 +02:00
Statement
2019-04-14 17:55:10 +02:00
2019-06-04 12:10:23 +02:00
SET(values ...interface{}) UpdateStatement
WHERE(expression BoolExpression) UpdateStatement
RETURNING(projections ...projection) UpdateStatement
2019-04-14 17:55:10 +02:00
}
2019-06-05 17:15:20 +02:00
func newUpdateStatement(table WritableTable, columns []Column) UpdateStatement {
2019-04-14 17:55:10 +02:00
return &updateStatementImpl{
table: table,
columns: columns,
}
}
type updateStatementImpl struct {
2019-06-05 17:15:20 +02:00
table WritableTable
columns []Column
2019-05-07 19:06:21 +02:00
updateValues []clause
where BoolExpression
2019-05-07 19:06:21 +02:00
returning []projection
2019-04-14 17:55:10 +02:00
}
2019-06-04 12:10:23 +02:00
func (u *updateStatementImpl) SET(values ...interface{}) UpdateStatement {
2019-04-14 17:55:10 +02:00
for _, value := range values {
2019-05-07 19:06:21 +02:00
if clause, ok := value.(clause); ok {
2019-04-14 17:55:10 +02:00
u.updateValues = append(u.updateValues, clause)
} else {
u.updateValues = append(u.updateValues, literal(value))
2019-04-14 17:55:10 +02:00
}
}
return u
}
2019-06-04 12:10:23 +02:00
func (u *updateStatementImpl) WHERE(expression BoolExpression) UpdateStatement {
2019-04-14 17:55:10 +02:00
u.where = expression
return u
}
2019-06-04 12:10:23 +02:00
func (u *updateStatementImpl) RETURNING(projections ...projection) UpdateStatement {
2019-06-09 11:06:08 +02:00
u.returning = projections
2019-04-14 17:55:10 +02:00
return u
}
func (u *updateStatementImpl) Sql() (sql string, args []interface{}, err error) {
out := &queryData{}
2019-05-03 12:51:57 +02:00
2019-05-12 18:15:23 +02:00
out.nextLine()
out.writeString("UPDATE")
2019-04-14 17:55:10 +02:00
if u.table == nil {
return "", nil, errors.New("nil tableName.")
2019-04-14 17:55:10 +02:00
}
if err = u.table.serialize(update_statement, out); err != nil {
2019-04-14 17:55:10 +02:00
return
}
if len(u.updateValues) == 0 {
return "", nil, errors.New("No column updated.")
2019-04-14 17:55:10 +02:00
}
2019-05-12 18:15:23 +02:00
out.writeString("SET")
2019-04-14 17:55:10 +02:00
if len(u.columns) > 1 {
2019-05-12 18:15:23 +02:00
out.writeString("(")
2019-04-14 17:55:10 +02:00
}
err = serializeColumnList(update_statement, u.columns, out)
if err != nil {
return "", nil, err
2019-04-14 17:55:10 +02:00
}
if len(u.columns) > 1 {
2019-05-12 18:15:23 +02:00
out.writeString(")")
2019-04-14 17:55:10 +02:00
}
2019-05-12 18:15:23 +02:00
out.writeString("=")
2019-04-14 17:55:10 +02:00
if len(u.updateValues) > 1 {
2019-05-12 18:15:23 +02:00
out.writeString("(")
2019-04-14 17:55:10 +02:00
}
for i, value := range u.updateValues {
if i > 0 {
out.writeString(", ")
2019-04-14 17:55:10 +02:00
}
err = value.serialize(update_statement, out)
2019-04-14 17:55:10 +02:00
if err != nil {
return
}
}
if len(u.updateValues) > 1 {
2019-05-12 18:15:23 +02:00
out.writeString(")")
2019-04-14 17:55:10 +02:00
}
if u.where == nil {
return "", nil, errors.New("Updating without a WHERE clause.")
2019-04-14 17:55:10 +02:00
}
if err = out.writeWhere(update_statement, u.where); err != nil {
2019-04-14 17:55:10 +02:00
return
}
if len(u.returning) > 0 {
2019-05-12 18:15:23 +02:00
out.nextLine()
out.writeString("RETURNING")
2019-04-14 17:55:10 +02:00
err = serializeProjectionList(update_statement, u.returning, out)
2019-04-14 17:55:10 +02:00
if err != nil {
return
}
}
2019-05-12 18:15:23 +02:00
sql, args = out.finalize()
return
}
func (u *updateStatementImpl) DebugSql() (query string, err error) {
return DebugSql(u)
2019-04-14 17:55:10 +02:00
}
2019-05-03 12:51:57 +02:00
func (u *updateStatementImpl) Query(db execution.Db, destination interface{}) error {
2019-05-03 12:51:57 +02:00
return Query(u, db, destination)
}
func (u *updateStatementImpl) Execute(db execution.Db) (res sql.Result, err error) {
2019-05-03 12:51:57 +02:00
return Execute(u, db)
}