Add new typesafe SET operator for UPDATE statement.

This commit is contained in:
go-jet 2020-05-09 10:49:09 +02:00
parent a4b4710637
commit ebcbadef24
11 changed files with 269 additions and 123 deletions

View file

@ -8,7 +8,7 @@ type Table interface {
readableTable
INSERT(columns ...jet.Column) InsertStatement
UPDATE(column jet.Column, columns ...jet.Column) UpdateStatement
UPDATE(columns ...jet.Column) UpdateStatement
DELETE() DeleteStatement
LOCK() LockStatement
}
@ -35,7 +35,7 @@ type readableTable interface {
type joinSelectUpdateTable interface {
ReadableTable
UPDATE(column jet.Column, columns ...jet.Column) UpdateStatement
UPDATE(columns ...jet.Column) UpdateStatement
}
// ReadableTable interface
@ -98,8 +98,8 @@ func (t *tableImpl) INSERT(columns ...jet.Column) InsertStatement {
return newInsertStatement(t.parent, jet.UnwidColumnList(columns))
}
func (t *tableImpl) UPDATE(column jet.Column, columns ...jet.Column) UpdateStatement {
return newUpdateStatement(t.parent, jet.UnwindColumns(column, columns...))
func (t *tableImpl) UPDATE(columns ...jet.Column) UpdateStatement {
return newUpdateStatement(t.parent, jet.UnwidColumnList(columns))
}
func (t *tableImpl) DELETE() DeleteStatement {

View file

@ -16,14 +16,18 @@ type updateStatementImpl struct {
jet.SerializerStatement
Update jet.ClauseUpdate
Set jet.ClauseSet
Set jet.SetClause
SetNew jet.SetClauseNew
Where jet.ClauseWhere
}
func newUpdateStatement(table Table, columns []jet.Column) UpdateStatement {
update := &updateStatementImpl{}
update.SerializerStatement = jet.NewStatementImpl(Dialect, jet.UpdateStatementType, update, &update.Update,
&update.Set, &update.Where)
update.SerializerStatement = jet.NewStatementImpl(Dialect, jet.UpdateStatementType, update,
&update.Update,
&update.Set,
&update.SetNew,
&update.Where)
update.Update.Table = table
update.Set.Columns = columns
@ -33,7 +37,17 @@ func newUpdateStatement(table Table, columns []jet.Column) UpdateStatement {
}
func (u *updateStatementImpl) SET(value interface{}, values ...interface{}) UpdateStatement {
u.Set.Values = jet.UnwindRowFromValues(value, values)
columnAssigment, isColumnAssigment := value.(ColumnAssigment)
if isColumnAssigment {
u.SetNew = []ColumnAssigment{columnAssigment}
for _, value := range values {
u.SetNew = append(u.SetNew, value.(ColumnAssigment))
}
} else {
u.Set.Values = jet.UnwindRowFromValues(value, values)
}
return u
}

View file

@ -23,7 +23,7 @@ WHERE table1.col_int >= ?;
func TestUpdateWithValues(t *testing.T) {
expectedSQL := `
UPDATE db.table1
SET col_int = ?,
SET col_int = ?,
col_float = ?
WHERE table1.col_int >= ?;
`