Add support for Update statements.

This commit is contained in:
zer0sub 2019-04-14 17:55:10 +02:00
parent b287521f1a
commit 70d6f84375
12 changed files with 422 additions and 286 deletions

View file

@ -18,7 +18,7 @@ type InsertStatement interface {
// Map or stracture mapped to column names
VALUES_MAPPING(data interface{}) InsertStatement
RETURNING(column ...Expression) InsertStatement
RETURNING(projections ...Projection) InsertStatement
QUERY(selectStatement SelectStatement) InsertStatement
@ -27,10 +27,8 @@ type InsertStatement interface {
func newInsertStatement(t WritableTable, columns ...Column) InsertStatement {
return &insertStatementImpl{
table: t,
columns: columns,
rows: make([][]Clause, 0, 1),
returning: make([]Expression, 0, 1),
table: t,
columns: columns,
}
}
@ -44,7 +42,7 @@ type insertStatementImpl struct {
columns []Column
rows [][]Clause
query SelectStatement
returning []Expression
returning []Projection
errors []string
}
@ -114,8 +112,8 @@ func (i *insertStatementImpl) VALUES_MAPPING(data interface{}) InsertStatement {
return i
}
func (i *insertStatementImpl) RETURNING(column ...Expression) InsertStatement {
i.returning = column
func (i *insertStatementImpl) RETURNING(projections ...Projection) InsertStatement {
i.returning = projections
return i
}
@ -217,16 +215,10 @@ func (s *insertStatementImpl) String() (sql string, err error) {
if len(s.returning) > 0 {
buf.WriteString(" RETURNING ")
for i, column := range s.returning {
if i > 0 {
buf.WriteString(",")
}
err = serializeProjectionList(s.returning, buf)
err = column.SerializeSql(buf)
if err != nil {
return
}
if err != nil {
return
}
}

View file

@ -6,7 +6,6 @@ import (
"github.com/dropbox/godropbox/errors"
"github.com/sub0zero/go-sqlbuilder/sqlbuilder/execution"
"github.com/sub0zero/go-sqlbuilder/types"
"reflect"
)
type SelectStatement interface {
@ -88,12 +87,6 @@ func (s *selectStatementImpl) AsTable(alias string) *SelectStatementTable {
}
func (s *selectStatementImpl) Execute(db types.Db, destination interface{}) error {
destinationType := reflect.TypeOf(destination)
if destinationType.Kind() == reflect.Ptr && destinationType.Elem().Kind() == reflect.Struct {
s.Limit(1)
}
query, err := s.String()
if err != nil {

View file

@ -34,16 +34,6 @@ type UnionStatement interface {
Offset(offset int64) UnionStatement
}
type UpdateStatement interface {
Statement
Set(column Column, expression Expression) UpdateStatement
Where(expression BoolExpression) UpdateStatement
OrderBy(clauses ...OrderByClause) UpdateStatement
Limit(limit int64) UpdateStatement
Comment(comment string) UpdateStatement
}
type DeleteStatement interface {
Statement
@ -250,151 +240,6 @@ func (us *unionStatementImpl) String() (sql string, err error) {
return buf.String(), nil
}
//
// UPDATE statement ===========================================================
//
func newUpdateStatement(table WritableTable) UpdateStatement {
return &updateStatementImpl{
table: table,
updateValues: make(map[Column]Expression),
limit: -1,
}
}
type updateStatementImpl struct {
table WritableTable
updateValues map[Column]Expression
where BoolExpression
order *listClause
limit int64
comment string
}
func (u *updateStatementImpl) Execute(db *sql.DB, data interface{}) error {
return nil
}
func (u *updateStatementImpl) Set(
column Column,
expression Expression) UpdateStatement {
u.updateValues[column] = expression
return u
}
func (u *updateStatementImpl) Where(expression BoolExpression) UpdateStatement {
u.where = expression
return u
}
func (u *updateStatementImpl) OrderBy(
clauses ...OrderByClause) UpdateStatement {
u.order = newOrderByListClause(clauses...)
return u
}
func (u *updateStatementImpl) Limit(limit int64) UpdateStatement {
u.limit = limit
return u
}
func (u *updateStatementImpl) Comment(comment string) UpdateStatement {
u.comment = comment
return u
}
func (u *updateStatementImpl) String() (sql string, err error) {
buf := new(bytes.Buffer)
_, _ = buf.WriteString("UPDATE ")
if err = writeComment(u.comment, buf); err != nil {
return
}
if u.table == nil {
return "", errors.Newf("nil tableName. Generated sql: %s", buf.String())
}
if err = u.table.SerializeSql(buf); err != nil {
return
}
if len(u.updateValues) == 0 {
return "", errors.Newf(
"No column updated. Generated sql: %s",
buf.String())
}
_, _ = buf.WriteString(" SET ")
addComma := false
// Sorting is too hard in go, just create a second map ...
updateValues := make(map[string]Expression)
for col, expr := range u.updateValues {
if col == nil {
return "", errors.Newf(
"nil column. Generated sql: %s",
buf.String())
}
updateValues[col.Name()] = expr
}
for _, col := range u.table.Columns() {
val, inMap := updateValues[col.Name()]
if !inMap {
continue
}
if addComma {
_, _ = buf.WriteString(", ")
}
if val == nil {
return "", errors.Newf(
"nil value. Generated sql: %s",
buf.String())
}
if err = col.SerializeSql(buf); err != nil {
return
}
_ = buf.WriteByte('=')
if err = val.SerializeSql(buf); err != nil {
return
}
addComma = true
}
if u.where == nil {
return "", errors.Newf(
"Updating without a WHERE clause. Generated sql: %s",
buf.String())
}
_, _ = buf.WriteString(" WHERE ")
if err = u.where.SerializeSql(buf); err != nil {
return
}
if u.order != nil {
_, _ = buf.WriteString(" ORDER BY ")
if err = u.order.SerializeSql(buf); err != nil {
return
}
}
if u.limit >= 0 {
_, _ = buf.WriteString(fmt.Sprintf(" LIMIT %d", u.limit))
}
return buf.String(), nil
}
//
// DELETE statement ===========================================================
//
@ -565,7 +410,7 @@ func (s *unlockStatementImpl) String() (sql string, err error) {
return "UNLOCK TABLES", nil
}
// Set GTID_NEXT statement returns a SQL statement that can be used to explicitly set the next GTID.
// SET GTID_NEXT statement returns a SQL statement that can be used to explicitly set the next GTID.
func NewGtidNextStatement(sid []byte, gno uint64) GtidNextStatement {
return &gtidNextStatementImpl{
sid: sid,

View file

@ -385,100 +385,6 @@ func (s *StmtSuite) TestOnDuplicateKeyUpdateMulti(c *gc.C) {
"ON DUPLICATE KEY UPDATE table1.col3=3, table1.col2=4")
}
//
// UPDATE statement tests =====================================================
//
func (s *StmtSuite) TestUpdateNilColumn(c *gc.C) {
stmt := table1.Update().Set(nil, Literal(1))
_, err := stmt.String()
c.Assert(err, gc.NotNil)
}
func (s *StmtSuite) TestUpdateNilExpr(c *gc.C) {
stmt := table1.Update().Set(table1Col1, nil)
_, err := stmt.String()
c.Assert(err, gc.NotNil)
}
func (s *StmtSuite) TestUpdateUnconditionally(c *gc.C) {
stmt := table1.Update().Set(table1Col1, Literal(1))
_, err := stmt.String()
c.Assert(err, gc.NotNil)
}
func (s *StmtSuite) TestUpdateSingleValue(c *gc.C) {
stmt := table1.Update().Set(table1Col1, Literal(1))
stmt.Where(EqL(table1Col2, 2))
sql, err := stmt.String()
c.Assert(err, gc.IsNil)
c.Assert(
sql,
gc.Equals,
"UPDATE db.table1 SET table1.col1=1 WHERE table1.col2=2")
}
func (s *StmtSuite) TestUpdateUsingDeferredLookupColumns(c *gc.C) {
stmt := table1.Update().Set(table1.C("col1"), Literal(1))
stmt.Where(EqL(table1Col2, 2))
sql, err := stmt.String()
c.Assert(err, gc.IsNil)
c.Assert(
sql,
gc.Equals,
"UPDATE db.table1 SET table1.col1=1 WHERE table1.col2=2")
}
func (s *StmtSuite) TestUpdateMultiValues(c *gc.C) {
stmt := table1.Update()
stmt.Set(table1Col1, Literal(1))
stmt.Set(table1Col2, Literal(2))
stmt.Where(EqL(table1Col2, 3))
sql, err := stmt.String()
c.Assert(err, gc.IsNil)
c.Assert(
sql,
gc.Equals,
"UPDATE db.table1 "+
"SET table1.col1=1, table1.col2=2 "+
"WHERE table1.col2=3")
}
func (s *StmtSuite) TestUpdateWithOrderBy(c *gc.C) {
stmt := table1.Update().Set(table1Col1, Literal(1))
stmt.Where(EqL(table1Col2, 2))
stmt.OrderBy(table1Col2)
sql, err := stmt.String()
c.Assert(err, gc.IsNil)
c.Assert(
sql,
gc.Equals,
"UPDATE db.table1 "+
"SET table1.col1=1 "+
"WHERE table1.col2=2 "+
"ORDER BY table1.col2")
}
func (s *StmtSuite) TestUpdateWithLimit(c *gc.C) {
stmt := table1.Update().Set(table1Col1, Literal(1))
stmt.Where(EqL(table1Col2, 2))
stmt.Limit(5)
sql, err := stmt.String()
c.Assert(err, gc.IsNil)
c.Assert(
sql,
gc.Equals,
"UPDATE db.table1 "+
"SET table1.col1=1 "+
"WHERE table1.col2=2 "+
"LIMIT 5")
}
//
// DELETE statement tests =====================================================
//
@ -619,7 +525,7 @@ func (s *StmtSuite) TestUnionSelectWithMismatchedColumns(c *gc.C) {
func (s *StmtSuite) TestComplicatedUnionSelectWithWhereStatement(c *gc.C) {
// tests on outer statement: Group By, Order By, Limit
// on inner statement: AndWhere, Where (with And), Order By, Limit
// on inner statement: AndWhere, WHERE (with And), Order By, Limit
select_queries := make([]SelectStatement, 0, 3)
// We're not trying to write a SQL parser, so we won't warn if you do something silly like

View file

@ -45,7 +45,7 @@ type WritableTable interface {
TableInterface
INSERT(columns ...Column) InsertStatement
Update() UpdateStatement
UPDATE(columns ...Column) UpdateStatement
Delete() DeleteStatement
}
@ -229,8 +229,8 @@ func (t *Table) INSERT(columns ...Column) InsertStatement {
return newInsertStatement(t, columns...)
}
func (t *Table) Update() UpdateStatement {
return newUpdateStatement(t)
func (t *Table) UPDATE(columns ...Column) UpdateStatement {
return newUpdateStatement(t, columns)
}
func (t *Table) Delete() DeleteStatement {

View file

@ -0,0 +1,168 @@
package sqlbuilder
import (
"bytes"
"database/sql"
"github.com/dropbox/godropbox/errors"
"github.com/sub0zero/go-sqlbuilder/sqlbuilder/execution"
"github.com/sub0zero/go-sqlbuilder/types"
)
type UpdateStatement interface {
Statement
SET(values ...interface{}) UpdateStatement
WHERE(expression BoolExpression) UpdateStatement
RETURNING(projections ...Projection) UpdateStatement
Query(db types.Db, destination interface{}) error
Execute(db types.Db) (sql.Result, error)
}
func newUpdateStatement(table WritableTable, columns []Column) UpdateStatement {
return &updateStatementImpl{
table: table,
columns: columns,
}
}
type updateStatementImpl struct {
table WritableTable
columns []Column
updateValues []Clause
where BoolExpression
returning []Projection
}
func (u *updateStatementImpl) Query(db types.Db, destination interface{}) error {
query, err := u.String()
if err != nil {
return err
}
return execution.Execute(db, query, destination)
}
func (u *updateStatementImpl) Execute(db types.Db) (res sql.Result, err error) {
query, err := u.String()
if err != nil {
return
}
res, err = db.Exec(query)
return
}
func (u *updateStatementImpl) SET(values ...interface{}) UpdateStatement {
for _, value := range values {
if clause, ok := value.(Clause); ok {
u.updateValues = append(u.updateValues, clause)
} else {
u.updateValues = append(u.updateValues, Literal(value))
}
}
return u
}
func (u *updateStatementImpl) WHERE(expression BoolExpression) UpdateStatement {
u.where = expression
return u
}
func (u *updateStatementImpl) RETURNING(projections ...Projection) UpdateStatement {
u.returning = projections
return u
}
func (u *updateStatementImpl) String() (sql string, err error) {
buf := new(bytes.Buffer)
_, _ = buf.WriteString("UPDATE ")
if u.table == nil {
return "", errors.Newf("nil tableName. Generated sql: %s", buf.String())
}
if err = u.table.SerializeSql(buf); err != nil {
return
}
if len(u.updateValues) == 0 {
return "", errors.Newf(
"No column updated. Generated sql: %s",
buf.String())
}
_, _ = buf.WriteString(" SET")
if len(u.columns) > 1 {
buf.WriteString(" ( ")
} else {
buf.WriteString(" ")
}
for i, column := range u.columns {
if i > 0 {
buf.WriteString(", ")
}
buf.WriteString(column.Name())
if err != nil {
return
}
}
if len(u.columns) > 1 {
buf.WriteString(" )")
}
buf.WriteString(" =")
if len(u.updateValues) > 1 {
buf.WriteString(" (")
}
for i, value := range u.updateValues {
if i > 0 {
buf.WriteString(", ")
}
err = value.SerializeSql(buf)
if err != nil {
return
}
}
if len(u.updateValues) > 1 {
buf.WriteString(" )")
}
if u.where == nil {
return "", errors.Newf(
"Updating without a WHERE clause. Generated sql: %s",
buf.String())
}
_, _ = buf.WriteString(" WHERE ")
if err = u.where.SerializeSql(buf); err != nil {
return
}
if len(u.returning) > 0 {
buf.WriteString(" RETURNING ")
err = serializeProjectionList(u.returning, buf)
if err != nil {
return
}
}
return buf.String() + ";", nil
}

View file

@ -0,0 +1,113 @@
package sqlbuilder
import (
"fmt"
"gotest.tools/assert"
"testing"
)
//
// UPDATE statement tests =====================================================
//
func TestUpdate(t *testing.T) {
stmt := table1.UPDATE(table1Col1, table1Col2).
SET(table1.SELECT(table1Col2)).
WHERE(table1Col1.EqL(2))
stmtStr, err := stmt.String()
assert.NilError(t, err)
fmt.Println(stmtStr)
}
//func (s *StmtSuite) TestUpdateNilColumn(c *gc.C) {
// stmt := table1.UPDATE().SET(nil, Literal(1))
// _, err := stmt.String()
// c.Assert(err, gc.NotNil)
//}
//
//func (s *StmtSuite) TestUpdateNilExpr(c *gc.C) {
// stmt := table1.UPDATE().SET(table1Col1, nil)
// _, err := stmt.String()
// c.Assert(err, gc.NotNil)
//}
//
//func (s *StmtSuite) TestUpdateUnconditionally(c *gc.C) {
// stmt := table1.UPDATE().SET(table1Col1, Literal(1))
// _, err := stmt.String()
// c.Assert(err, gc.NotNil)
//}
//
//func (s *StmtSuite) TestUpdateSingleValue(c *gc.C) {
// stmt := table1.UPDATE().SET(table1Col1, Literal(1))
// stmt.WHERE(EqL(table1Col2, 2))
// sql, err := stmt.String()
// c.Assert(err, gc.IsNil)
//
// c.Assert(
// sql,
// gc.Equals,
// "UPDATE db.table1 SET table1.col1=1 WHERE table1.col2=2")
//}
//
//func (s *StmtSuite) TestUpdateUsingDeferredLookupColumns(c *gc.C) {
// stmt := table1.UPDATE().SET(table1.C("col1"), Literal(1))
// stmt.WHERE(EqL(table1Col2, 2))
// sql, err := stmt.String()
// c.Assert(err, gc.IsNil)
//
// c.Assert(
// sql,
// gc.Equals,
// "UPDATE db.table1 SET table1.col1=1 WHERE table1.col2=2")
//}
//
//func (s *StmtSuite) TestUpdateMultiValues(c *gc.C) {
// stmt := table1.UPDATE()
// stmt.SET(table1Col1, Literal(1))
// stmt.SET(table1Col2, Literal(2))
// stmt.WHERE(EqL(table1Col2, 3))
// sql, err := stmt.String()
// c.Assert(err, gc.IsNil)
//
// c.Assert(
// sql,
// gc.Equals,
// "UPDATE db.table1 "+
// "SET table1.col1=1, table1.col2=2 "+
// "WHERE table1.col2=3")
//}
//
//func (s *StmtSuite) TestUpdateWithOrderBy(c *gc.C) {
// stmt := table1.UPDATE().SET(table1Col1, Literal(1))
// stmt.WHERE(EqL(table1Col2, 2))
// stmt.OrderBy(table1Col2)
// sql, err := stmt.String()
// c.Assert(err, gc.IsNil)
//
// c.Assert(
// sql,
// gc.Equals,
// "UPDATE db.table1 "+
// "SET table1.col1=1 "+
// "WHERE table1.col2=2 "+
// "ORDER BY table1.col2")
//}
//
//func (s *StmtSuite) TestUpdateWithLimit(c *gc.C) {
// stmt := table1.UPDATE().SET(table1Col1, Literal(1))
// stmt.WHERE(EqL(table1Col2, 2))
// stmt.Limit(5)
// sql, err := stmt.String()
// c.Assert(err, gc.IsNil)
//
// c.Assert(
// sql,
// gc.Equals,
// "UPDATE db.table1 "+
// "SET table1.col1=1 "+
// "WHERE table1.col2=2 "+
// "LIMIT 5")
//}

View file

@ -1 +1,35 @@
package sqlbuilder
import "bytes"
func serializeExpressionList(expressions []Expression, buf *bytes.Buffer) error {
for i, value := range expressions {
if i > 0 {
buf.WriteString(", ")
}
err := value.SerializeSql(buf)
if err != nil {
return err
}
}
return nil
}
func serializeProjectionList(projections []Projection, buf *bytes.Buffer) error {
for i, value := range projections {
if i > 0 {
buf.WriteString(", ")
}
err := value.SerializeForProjection(buf)
if err != nil {
return err
}
}
return nil
}