2019-04-14 17:55:10 +02:00
|
|
|
package sqlbuilder
|
|
|
|
|
|
|
|
|
|
import (
|
|
|
|
|
"database/sql"
|
2019-06-05 17:15:20 +02:00
|
|
|
"errors"
|
2019-06-05 17:56:24 +02:00
|
|
|
"github.com/go-jet/jet/sqlbuilder/execution"
|
2019-04-14 17:55:10 +02:00
|
|
|
)
|
|
|
|
|
|
2019-06-04 12:10:23 +02:00
|
|
|
type UpdateStatement interface {
|
2019-05-12 18:15:23 +02:00
|
|
|
Statement
|
2019-04-14 17:55:10 +02:00
|
|
|
|
2019-06-04 12:10:23 +02:00
|
|
|
SET(values ...interface{}) UpdateStatement
|
|
|
|
|
WHERE(expression BoolExpression) UpdateStatement
|
|
|
|
|
RETURNING(projections ...projection) UpdateStatement
|
2019-04-14 17:55:10 +02:00
|
|
|
}
|
|
|
|
|
|
2019-06-05 17:15:20 +02:00
|
|
|
func newUpdateStatement(table WritableTable, columns []Column) UpdateStatement {
|
2019-04-14 17:55:10 +02:00
|
|
|
return &updateStatementImpl{
|
|
|
|
|
table: table,
|
|
|
|
|
columns: columns,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
type updateStatementImpl struct {
|
2019-06-05 17:15:20 +02:00
|
|
|
table WritableTable
|
|
|
|
|
columns []Column
|
2019-05-07 19:06:21 +02:00
|
|
|
updateValues []clause
|
2019-05-31 12:59:57 +02:00
|
|
|
where BoolExpression
|
2019-05-07 19:06:21 +02:00
|
|
|
returning []projection
|
2019-04-14 17:55:10 +02:00
|
|
|
}
|
|
|
|
|
|
2019-06-04 12:10:23 +02:00
|
|
|
func (u *updateStatementImpl) SET(values ...interface{}) UpdateStatement {
|
2019-04-14 17:55:10 +02:00
|
|
|
|
|
|
|
|
for _, value := range values {
|
2019-05-07 19:06:21 +02:00
|
|
|
if clause, ok := value.(clause); ok {
|
2019-04-14 17:55:10 +02:00
|
|
|
u.updateValues = append(u.updateValues, clause)
|
|
|
|
|
} else {
|
2019-06-04 11:52:37 +02:00
|
|
|
u.updateValues = append(u.updateValues, literal(value))
|
2019-04-14 17:55:10 +02:00
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return u
|
|
|
|
|
}
|
|
|
|
|
|
2019-06-04 12:10:23 +02:00
|
|
|
func (u *updateStatementImpl) WHERE(expression BoolExpression) UpdateStatement {
|
2019-04-14 17:55:10 +02:00
|
|
|
u.where = expression
|
|
|
|
|
return u
|
|
|
|
|
}
|
|
|
|
|
|
2019-06-04 12:10:23 +02:00
|
|
|
func (u *updateStatementImpl) RETURNING(projections ...projection) UpdateStatement {
|
2019-05-03 12:51:57 +02:00
|
|
|
u.returning = defaultProjectionAliasing(projections)
|
2019-04-14 17:55:10 +02:00
|
|
|
return u
|
|
|
|
|
}
|
|
|
|
|
|
2019-04-29 14:39:48 +02:00
|
|
|
func (u *updateStatementImpl) Sql() (sql string, args []interface{}, err error) {
|
|
|
|
|
out := &queryData{}
|
2019-05-03 12:51:57 +02:00
|
|
|
|
2019-05-12 18:15:23 +02:00
|
|
|
out.nextLine()
|
|
|
|
|
out.writeString("UPDATE")
|
2019-04-14 17:55:10 +02:00
|
|
|
|
|
|
|
|
if u.table == nil {
|
2019-04-29 14:39:48 +02:00
|
|
|
return "", nil, errors.New("nil tableName.")
|
2019-04-14 17:55:10 +02:00
|
|
|
}
|
|
|
|
|
|
2019-05-08 13:47:01 +02:00
|
|
|
if err = u.table.serialize(update_statement, out); err != nil {
|
2019-04-14 17:55:10 +02:00
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if len(u.updateValues) == 0 {
|
2019-04-29 14:39:48 +02:00
|
|
|
return "", nil, errors.New("No column updated.")
|
2019-04-14 17:55:10 +02:00
|
|
|
}
|
|
|
|
|
|
2019-05-12 18:15:23 +02:00
|
|
|
out.writeString("SET")
|
2019-04-14 17:55:10 +02:00
|
|
|
|
|
|
|
|
if len(u.columns) > 1 {
|
2019-05-12 18:15:23 +02:00
|
|
|
out.writeString("(")
|
2019-04-14 17:55:10 +02:00
|
|
|
}
|
|
|
|
|
|
2019-05-08 13:47:01 +02:00
|
|
|
err = serializeColumnList(update_statement, u.columns, out)
|
2019-04-29 14:39:48 +02:00
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
return "", nil, err
|
2019-04-14 17:55:10 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if len(u.columns) > 1 {
|
2019-05-12 18:15:23 +02:00
|
|
|
out.writeString(")")
|
2019-04-14 17:55:10 +02:00
|
|
|
}
|
|
|
|
|
|
2019-05-12 18:15:23 +02:00
|
|
|
out.writeString("=")
|
2019-04-14 17:55:10 +02:00
|
|
|
|
|
|
|
|
if len(u.updateValues) > 1 {
|
2019-05-12 18:15:23 +02:00
|
|
|
out.writeString("(")
|
2019-04-14 17:55:10 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for i, value := range u.updateValues {
|
|
|
|
|
if i > 0 {
|
2019-05-08 13:47:01 +02:00
|
|
|
out.writeString(", ")
|
2019-04-14 17:55:10 +02:00
|
|
|
}
|
|
|
|
|
|
2019-05-08 13:47:01 +02:00
|
|
|
err = value.serialize(update_statement, out)
|
2019-04-14 17:55:10 +02:00
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if len(u.updateValues) > 1 {
|
2019-05-12 18:15:23 +02:00
|
|
|
out.writeString(")")
|
2019-04-14 17:55:10 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if u.where == nil {
|
2019-04-29 14:39:48 +02:00
|
|
|
return "", nil, errors.New("Updating without a WHERE clause.")
|
2019-04-14 17:55:10 +02:00
|
|
|
}
|
|
|
|
|
|
2019-05-08 13:47:01 +02:00
|
|
|
if err = out.writeWhere(update_statement, u.where); err != nil {
|
2019-04-14 17:55:10 +02:00
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if len(u.returning) > 0 {
|
2019-05-12 18:15:23 +02:00
|
|
|
out.nextLine()
|
|
|
|
|
out.writeString("RETURNING")
|
2019-04-14 17:55:10 +02:00
|
|
|
|
2019-05-08 13:47:01 +02:00
|
|
|
err = serializeProjectionList(update_statement, u.returning, out)
|
2019-04-14 17:55:10 +02:00
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2019-05-12 18:15:23 +02:00
|
|
|
sql, args = out.finalize()
|
|
|
|
|
return
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func (u *updateStatementImpl) DebugSql() (query string, err error) {
|
|
|
|
|
return DebugSql(u)
|
2019-04-14 17:55:10 +02:00
|
|
|
}
|
2019-05-03 12:51:57 +02:00
|
|
|
|
2019-05-27 13:11:15 +02:00
|
|
|
func (u *updateStatementImpl) Query(db execution.Db, destination interface{}) error {
|
2019-05-03 12:51:57 +02:00
|
|
|
return Query(u, db, destination)
|
|
|
|
|
}
|
|
|
|
|
|
2019-05-27 13:11:15 +02:00
|
|
|
func (u *updateStatementImpl) Execute(db execution.Db) (res sql.Result, err error) {
|
2019-05-03 12:51:57 +02:00
|
|
|
return Execute(u, db)
|
|
|
|
|
}
|