Add support for Update statements.
This commit is contained in:
parent
b287521f1a
commit
70d6f84375
12 changed files with 422 additions and 286 deletions
|
|
@ -18,7 +18,7 @@ type InsertStatement interface {
|
||||||
// Map or stracture mapped to column names
|
// Map or stracture mapped to column names
|
||||||
VALUES_MAPPING(data interface{}) InsertStatement
|
VALUES_MAPPING(data interface{}) InsertStatement
|
||||||
|
|
||||||
RETURNING(column ...Expression) InsertStatement
|
RETURNING(projections ...Projection) InsertStatement
|
||||||
|
|
||||||
QUERY(selectStatement SelectStatement) InsertStatement
|
QUERY(selectStatement SelectStatement) InsertStatement
|
||||||
|
|
||||||
|
|
@ -27,10 +27,8 @@ type InsertStatement interface {
|
||||||
|
|
||||||
func newInsertStatement(t WritableTable, columns ...Column) InsertStatement {
|
func newInsertStatement(t WritableTable, columns ...Column) InsertStatement {
|
||||||
return &insertStatementImpl{
|
return &insertStatementImpl{
|
||||||
table: t,
|
table: t,
|
||||||
columns: columns,
|
columns: columns,
|
||||||
rows: make([][]Clause, 0, 1),
|
|
||||||
returning: make([]Expression, 0, 1),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -44,7 +42,7 @@ type insertStatementImpl struct {
|
||||||
columns []Column
|
columns []Column
|
||||||
rows [][]Clause
|
rows [][]Clause
|
||||||
query SelectStatement
|
query SelectStatement
|
||||||
returning []Expression
|
returning []Projection
|
||||||
|
|
||||||
errors []string
|
errors []string
|
||||||
}
|
}
|
||||||
|
|
@ -114,8 +112,8 @@ func (i *insertStatementImpl) VALUES_MAPPING(data interface{}) InsertStatement {
|
||||||
return i
|
return i
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *insertStatementImpl) RETURNING(column ...Expression) InsertStatement {
|
func (i *insertStatementImpl) RETURNING(projections ...Projection) InsertStatement {
|
||||||
i.returning = column
|
i.returning = projections
|
||||||
|
|
||||||
return i
|
return i
|
||||||
}
|
}
|
||||||
|
|
@ -217,16 +215,10 @@ func (s *insertStatementImpl) String() (sql string, err error) {
|
||||||
if len(s.returning) > 0 {
|
if len(s.returning) > 0 {
|
||||||
buf.WriteString(" RETURNING ")
|
buf.WriteString(" RETURNING ")
|
||||||
|
|
||||||
for i, column := range s.returning {
|
err = serializeProjectionList(s.returning, buf)
|
||||||
if i > 0 {
|
|
||||||
buf.WriteString(",")
|
|
||||||
}
|
|
||||||
|
|
||||||
err = column.SerializeSql(buf)
|
if err != nil {
|
||||||
|
return
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,6 @@ import (
|
||||||
"github.com/dropbox/godropbox/errors"
|
"github.com/dropbox/godropbox/errors"
|
||||||
"github.com/sub0zero/go-sqlbuilder/sqlbuilder/execution"
|
"github.com/sub0zero/go-sqlbuilder/sqlbuilder/execution"
|
||||||
"github.com/sub0zero/go-sqlbuilder/types"
|
"github.com/sub0zero/go-sqlbuilder/types"
|
||||||
"reflect"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type SelectStatement interface {
|
type SelectStatement interface {
|
||||||
|
|
@ -88,12 +87,6 @@ func (s *selectStatementImpl) AsTable(alias string) *SelectStatementTable {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *selectStatementImpl) Execute(db types.Db, destination interface{}) error {
|
func (s *selectStatementImpl) Execute(db types.Db, destination interface{}) error {
|
||||||
destinationType := reflect.TypeOf(destination)
|
|
||||||
|
|
||||||
if destinationType.Kind() == reflect.Ptr && destinationType.Elem().Kind() == reflect.Struct {
|
|
||||||
s.Limit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
query, err := s.String()
|
query, err := s.String()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
||||||
|
|
@ -34,16 +34,6 @@ type UnionStatement interface {
|
||||||
Offset(offset int64) UnionStatement
|
Offset(offset int64) UnionStatement
|
||||||
}
|
}
|
||||||
|
|
||||||
type UpdateStatement interface {
|
|
||||||
Statement
|
|
||||||
|
|
||||||
Set(column Column, expression Expression) UpdateStatement
|
|
||||||
Where(expression BoolExpression) UpdateStatement
|
|
||||||
OrderBy(clauses ...OrderByClause) UpdateStatement
|
|
||||||
Limit(limit int64) UpdateStatement
|
|
||||||
Comment(comment string) UpdateStatement
|
|
||||||
}
|
|
||||||
|
|
||||||
type DeleteStatement interface {
|
type DeleteStatement interface {
|
||||||
Statement
|
Statement
|
||||||
|
|
||||||
|
|
@ -250,151 +240,6 @@ func (us *unionStatementImpl) String() (sql string, err error) {
|
||||||
return buf.String(), nil
|
return buf.String(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
|
||||||
// UPDATE statement ===========================================================
|
|
||||||
//
|
|
||||||
|
|
||||||
func newUpdateStatement(table WritableTable) UpdateStatement {
|
|
||||||
return &updateStatementImpl{
|
|
||||||
table: table,
|
|
||||||
updateValues: make(map[Column]Expression),
|
|
||||||
limit: -1,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type updateStatementImpl struct {
|
|
||||||
table WritableTable
|
|
||||||
updateValues map[Column]Expression
|
|
||||||
where BoolExpression
|
|
||||||
order *listClause
|
|
||||||
limit int64
|
|
||||||
comment string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *updateStatementImpl) Execute(db *sql.DB, data interface{}) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *updateStatementImpl) Set(
|
|
||||||
column Column,
|
|
||||||
expression Expression) UpdateStatement {
|
|
||||||
|
|
||||||
u.updateValues[column] = expression
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *updateStatementImpl) Where(expression BoolExpression) UpdateStatement {
|
|
||||||
u.where = expression
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *updateStatementImpl) OrderBy(
|
|
||||||
clauses ...OrderByClause) UpdateStatement {
|
|
||||||
|
|
||||||
u.order = newOrderByListClause(clauses...)
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *updateStatementImpl) Limit(limit int64) UpdateStatement {
|
|
||||||
u.limit = limit
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *updateStatementImpl) Comment(comment string) UpdateStatement {
|
|
||||||
u.comment = comment
|
|
||||||
return u
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *updateStatementImpl) String() (sql string, err error) {
|
|
||||||
buf := new(bytes.Buffer)
|
|
||||||
_, _ = buf.WriteString("UPDATE ")
|
|
||||||
|
|
||||||
if err = writeComment(u.comment, buf); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if u.table == nil {
|
|
||||||
return "", errors.Newf("nil tableName. Generated sql: %s", buf.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = u.table.SerializeSql(buf); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(u.updateValues) == 0 {
|
|
||||||
return "", errors.Newf(
|
|
||||||
"No column updated. Generated sql: %s",
|
|
||||||
buf.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
_, _ = buf.WriteString(" SET ")
|
|
||||||
addComma := false
|
|
||||||
|
|
||||||
// Sorting is too hard in go, just create a second map ...
|
|
||||||
updateValues := make(map[string]Expression)
|
|
||||||
for col, expr := range u.updateValues {
|
|
||||||
if col == nil {
|
|
||||||
return "", errors.Newf(
|
|
||||||
"nil column. Generated sql: %s",
|
|
||||||
buf.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
updateValues[col.Name()] = expr
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, col := range u.table.Columns() {
|
|
||||||
val, inMap := updateValues[col.Name()]
|
|
||||||
if !inMap {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if addComma {
|
|
||||||
_, _ = buf.WriteString(", ")
|
|
||||||
}
|
|
||||||
|
|
||||||
if val == nil {
|
|
||||||
return "", errors.Newf(
|
|
||||||
"nil value. Generated sql: %s",
|
|
||||||
buf.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = col.SerializeSql(buf); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
_ = buf.WriteByte('=')
|
|
||||||
if err = val.SerializeSql(buf); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
addComma = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if u.where == nil {
|
|
||||||
return "", errors.Newf(
|
|
||||||
"Updating without a WHERE clause. Generated sql: %s",
|
|
||||||
buf.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
_, _ = buf.WriteString(" WHERE ")
|
|
||||||
if err = u.where.SerializeSql(buf); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if u.order != nil {
|
|
||||||
_, _ = buf.WriteString(" ORDER BY ")
|
|
||||||
if err = u.order.SerializeSql(buf); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if u.limit >= 0 {
|
|
||||||
_, _ = buf.WriteString(fmt.Sprintf(" LIMIT %d", u.limit))
|
|
||||||
}
|
|
||||||
|
|
||||||
return buf.String(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
//
|
//
|
||||||
// DELETE statement ===========================================================
|
// DELETE statement ===========================================================
|
||||||
//
|
//
|
||||||
|
|
@ -565,7 +410,7 @@ func (s *unlockStatementImpl) String() (sql string, err error) {
|
||||||
return "UNLOCK TABLES", nil
|
return "UNLOCK TABLES", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set GTID_NEXT statement returns a SQL statement that can be used to explicitly set the next GTID.
|
// SET GTID_NEXT statement returns a SQL statement that can be used to explicitly set the next GTID.
|
||||||
func NewGtidNextStatement(sid []byte, gno uint64) GtidNextStatement {
|
func NewGtidNextStatement(sid []byte, gno uint64) GtidNextStatement {
|
||||||
return >idNextStatementImpl{
|
return >idNextStatementImpl{
|
||||||
sid: sid,
|
sid: sid,
|
||||||
|
|
|
||||||
|
|
@ -385,100 +385,6 @@ func (s *StmtSuite) TestOnDuplicateKeyUpdateMulti(c *gc.C) {
|
||||||
"ON DUPLICATE KEY UPDATE table1.col3=3, table1.col2=4")
|
"ON DUPLICATE KEY UPDATE table1.col3=3, table1.col2=4")
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
|
||||||
// UPDATE statement tests =====================================================
|
|
||||||
//
|
|
||||||
|
|
||||||
func (s *StmtSuite) TestUpdateNilColumn(c *gc.C) {
|
|
||||||
stmt := table1.Update().Set(nil, Literal(1))
|
|
||||||
_, err := stmt.String()
|
|
||||||
c.Assert(err, gc.NotNil)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *StmtSuite) TestUpdateNilExpr(c *gc.C) {
|
|
||||||
stmt := table1.Update().Set(table1Col1, nil)
|
|
||||||
_, err := stmt.String()
|
|
||||||
c.Assert(err, gc.NotNil)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *StmtSuite) TestUpdateUnconditionally(c *gc.C) {
|
|
||||||
stmt := table1.Update().Set(table1Col1, Literal(1))
|
|
||||||
_, err := stmt.String()
|
|
||||||
c.Assert(err, gc.NotNil)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *StmtSuite) TestUpdateSingleValue(c *gc.C) {
|
|
||||||
stmt := table1.Update().Set(table1Col1, Literal(1))
|
|
||||||
stmt.Where(EqL(table1Col2, 2))
|
|
||||||
sql, err := stmt.String()
|
|
||||||
c.Assert(err, gc.IsNil)
|
|
||||||
|
|
||||||
c.Assert(
|
|
||||||
sql,
|
|
||||||
gc.Equals,
|
|
||||||
"UPDATE db.table1 SET table1.col1=1 WHERE table1.col2=2")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *StmtSuite) TestUpdateUsingDeferredLookupColumns(c *gc.C) {
|
|
||||||
stmt := table1.Update().Set(table1.C("col1"), Literal(1))
|
|
||||||
stmt.Where(EqL(table1Col2, 2))
|
|
||||||
sql, err := stmt.String()
|
|
||||||
c.Assert(err, gc.IsNil)
|
|
||||||
|
|
||||||
c.Assert(
|
|
||||||
sql,
|
|
||||||
gc.Equals,
|
|
||||||
"UPDATE db.table1 SET table1.col1=1 WHERE table1.col2=2")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *StmtSuite) TestUpdateMultiValues(c *gc.C) {
|
|
||||||
stmt := table1.Update()
|
|
||||||
stmt.Set(table1Col1, Literal(1))
|
|
||||||
stmt.Set(table1Col2, Literal(2))
|
|
||||||
stmt.Where(EqL(table1Col2, 3))
|
|
||||||
sql, err := stmt.String()
|
|
||||||
c.Assert(err, gc.IsNil)
|
|
||||||
|
|
||||||
c.Assert(
|
|
||||||
sql,
|
|
||||||
gc.Equals,
|
|
||||||
"UPDATE db.table1 "+
|
|
||||||
"SET table1.col1=1, table1.col2=2 "+
|
|
||||||
"WHERE table1.col2=3")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *StmtSuite) TestUpdateWithOrderBy(c *gc.C) {
|
|
||||||
stmt := table1.Update().Set(table1Col1, Literal(1))
|
|
||||||
stmt.Where(EqL(table1Col2, 2))
|
|
||||||
stmt.OrderBy(table1Col2)
|
|
||||||
sql, err := stmt.String()
|
|
||||||
c.Assert(err, gc.IsNil)
|
|
||||||
|
|
||||||
c.Assert(
|
|
||||||
sql,
|
|
||||||
gc.Equals,
|
|
||||||
"UPDATE db.table1 "+
|
|
||||||
"SET table1.col1=1 "+
|
|
||||||
"WHERE table1.col2=2 "+
|
|
||||||
"ORDER BY table1.col2")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *StmtSuite) TestUpdateWithLimit(c *gc.C) {
|
|
||||||
stmt := table1.Update().Set(table1Col1, Literal(1))
|
|
||||||
stmt.Where(EqL(table1Col2, 2))
|
|
||||||
stmt.Limit(5)
|
|
||||||
sql, err := stmt.String()
|
|
||||||
c.Assert(err, gc.IsNil)
|
|
||||||
|
|
||||||
c.Assert(
|
|
||||||
sql,
|
|
||||||
gc.Equals,
|
|
||||||
"UPDATE db.table1 "+
|
|
||||||
"SET table1.col1=1 "+
|
|
||||||
"WHERE table1.col2=2 "+
|
|
||||||
"LIMIT 5")
|
|
||||||
}
|
|
||||||
|
|
||||||
//
|
//
|
||||||
// DELETE statement tests =====================================================
|
// DELETE statement tests =====================================================
|
||||||
//
|
//
|
||||||
|
|
@ -619,7 +525,7 @@ func (s *StmtSuite) TestUnionSelectWithMismatchedColumns(c *gc.C) {
|
||||||
func (s *StmtSuite) TestComplicatedUnionSelectWithWhereStatement(c *gc.C) {
|
func (s *StmtSuite) TestComplicatedUnionSelectWithWhereStatement(c *gc.C) {
|
||||||
|
|
||||||
// tests on outer statement: Group By, Order By, Limit
|
// tests on outer statement: Group By, Order By, Limit
|
||||||
// on inner statement: AndWhere, Where (with And), Order By, Limit
|
// on inner statement: AndWhere, WHERE (with And), Order By, Limit
|
||||||
select_queries := make([]SelectStatement, 0, 3)
|
select_queries := make([]SelectStatement, 0, 3)
|
||||||
|
|
||||||
// We're not trying to write a SQL parser, so we won't warn if you do something silly like
|
// We're not trying to write a SQL parser, so we won't warn if you do something silly like
|
||||||
|
|
|
||||||
|
|
@ -45,7 +45,7 @@ type WritableTable interface {
|
||||||
TableInterface
|
TableInterface
|
||||||
|
|
||||||
INSERT(columns ...Column) InsertStatement
|
INSERT(columns ...Column) InsertStatement
|
||||||
Update() UpdateStatement
|
UPDATE(columns ...Column) UpdateStatement
|
||||||
Delete() DeleteStatement
|
Delete() DeleteStatement
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -229,8 +229,8 @@ func (t *Table) INSERT(columns ...Column) InsertStatement {
|
||||||
return newInsertStatement(t, columns...)
|
return newInsertStatement(t, columns...)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Table) Update() UpdateStatement {
|
func (t *Table) UPDATE(columns ...Column) UpdateStatement {
|
||||||
return newUpdateStatement(t)
|
return newUpdateStatement(t, columns)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Table) Delete() DeleteStatement {
|
func (t *Table) Delete() DeleteStatement {
|
||||||
|
|
|
||||||
168
sqlbuilder/update_statement.go
Normal file
168
sqlbuilder/update_statement.go
Normal file
|
|
@ -0,0 +1,168 @@
|
||||||
|
package sqlbuilder
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"database/sql"
|
||||||
|
"github.com/dropbox/godropbox/errors"
|
||||||
|
"github.com/sub0zero/go-sqlbuilder/sqlbuilder/execution"
|
||||||
|
"github.com/sub0zero/go-sqlbuilder/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
type UpdateStatement interface {
|
||||||
|
Statement
|
||||||
|
|
||||||
|
SET(values ...interface{}) UpdateStatement
|
||||||
|
WHERE(expression BoolExpression) UpdateStatement
|
||||||
|
RETURNING(projections ...Projection) UpdateStatement
|
||||||
|
|
||||||
|
Query(db types.Db, destination interface{}) error
|
||||||
|
Execute(db types.Db) (sql.Result, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newUpdateStatement(table WritableTable, columns []Column) UpdateStatement {
|
||||||
|
return &updateStatementImpl{
|
||||||
|
table: table,
|
||||||
|
columns: columns,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type updateStatementImpl struct {
|
||||||
|
table WritableTable
|
||||||
|
columns []Column
|
||||||
|
updateValues []Clause
|
||||||
|
where BoolExpression
|
||||||
|
returning []Projection
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *updateStatementImpl) Query(db types.Db, destination interface{}) error {
|
||||||
|
query, err := u.String()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return execution.Execute(db, query, destination)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *updateStatementImpl) Execute(db types.Db) (res sql.Result, err error) {
|
||||||
|
query, err := u.String()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
res, err = db.Exec(query)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *updateStatementImpl) SET(values ...interface{}) UpdateStatement {
|
||||||
|
|
||||||
|
for _, value := range values {
|
||||||
|
if clause, ok := value.(Clause); ok {
|
||||||
|
u.updateValues = append(u.updateValues, clause)
|
||||||
|
} else {
|
||||||
|
u.updateValues = append(u.updateValues, Literal(value))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *updateStatementImpl) WHERE(expression BoolExpression) UpdateStatement {
|
||||||
|
u.where = expression
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *updateStatementImpl) RETURNING(projections ...Projection) UpdateStatement {
|
||||||
|
u.returning = projections
|
||||||
|
return u
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *updateStatementImpl) String() (sql string, err error) {
|
||||||
|
buf := new(bytes.Buffer)
|
||||||
|
_, _ = buf.WriteString("UPDATE ")
|
||||||
|
|
||||||
|
if u.table == nil {
|
||||||
|
return "", errors.Newf("nil tableName. Generated sql: %s", buf.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = u.table.SerializeSql(buf); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(u.updateValues) == 0 {
|
||||||
|
return "", errors.Newf(
|
||||||
|
"No column updated. Generated sql: %s",
|
||||||
|
buf.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _ = buf.WriteString(" SET")
|
||||||
|
|
||||||
|
if len(u.columns) > 1 {
|
||||||
|
buf.WriteString(" ( ")
|
||||||
|
} else {
|
||||||
|
buf.WriteString(" ")
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, column := range u.columns {
|
||||||
|
if i > 0 {
|
||||||
|
buf.WriteString(", ")
|
||||||
|
}
|
||||||
|
|
||||||
|
buf.WriteString(column.Name())
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(u.columns) > 1 {
|
||||||
|
buf.WriteString(" )")
|
||||||
|
}
|
||||||
|
|
||||||
|
buf.WriteString(" =")
|
||||||
|
|
||||||
|
if len(u.updateValues) > 1 {
|
||||||
|
buf.WriteString(" (")
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, value := range u.updateValues {
|
||||||
|
if i > 0 {
|
||||||
|
buf.WriteString(", ")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = value.SerializeSql(buf)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(u.updateValues) > 1 {
|
||||||
|
buf.WriteString(" )")
|
||||||
|
}
|
||||||
|
|
||||||
|
if u.where == nil {
|
||||||
|
return "", errors.Newf(
|
||||||
|
"Updating without a WHERE clause. Generated sql: %s",
|
||||||
|
buf.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _ = buf.WriteString(" WHERE ")
|
||||||
|
if err = u.where.SerializeSql(buf); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(u.returning) > 0 {
|
||||||
|
buf.WriteString(" RETURNING ")
|
||||||
|
|
||||||
|
err = serializeProjectionList(u.returning, buf)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return buf.String() + ";", nil
|
||||||
|
}
|
||||||
113
sqlbuilder/update_statement_test.go
Normal file
113
sqlbuilder/update_statement_test.go
Normal file
|
|
@ -0,0 +1,113 @@
|
||||||
|
package sqlbuilder
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"gotest.tools/assert"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
//
|
||||||
|
// UPDATE statement tests =====================================================
|
||||||
|
//
|
||||||
|
|
||||||
|
func TestUpdate(t *testing.T) {
|
||||||
|
stmt := table1.UPDATE(table1Col1, table1Col2).
|
||||||
|
SET(table1.SELECT(table1Col2)).
|
||||||
|
WHERE(table1Col1.EqL(2))
|
||||||
|
|
||||||
|
stmtStr, err := stmt.String()
|
||||||
|
|
||||||
|
assert.NilError(t, err)
|
||||||
|
|
||||||
|
fmt.Println(stmtStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
//func (s *StmtSuite) TestUpdateNilColumn(c *gc.C) {
|
||||||
|
// stmt := table1.UPDATE().SET(nil, Literal(1))
|
||||||
|
// _, err := stmt.String()
|
||||||
|
// c.Assert(err, gc.NotNil)
|
||||||
|
//}
|
||||||
|
//
|
||||||
|
//func (s *StmtSuite) TestUpdateNilExpr(c *gc.C) {
|
||||||
|
// stmt := table1.UPDATE().SET(table1Col1, nil)
|
||||||
|
// _, err := stmt.String()
|
||||||
|
// c.Assert(err, gc.NotNil)
|
||||||
|
//}
|
||||||
|
//
|
||||||
|
//func (s *StmtSuite) TestUpdateUnconditionally(c *gc.C) {
|
||||||
|
// stmt := table1.UPDATE().SET(table1Col1, Literal(1))
|
||||||
|
// _, err := stmt.String()
|
||||||
|
// c.Assert(err, gc.NotNil)
|
||||||
|
//}
|
||||||
|
//
|
||||||
|
//func (s *StmtSuite) TestUpdateSingleValue(c *gc.C) {
|
||||||
|
// stmt := table1.UPDATE().SET(table1Col1, Literal(1))
|
||||||
|
// stmt.WHERE(EqL(table1Col2, 2))
|
||||||
|
// sql, err := stmt.String()
|
||||||
|
// c.Assert(err, gc.IsNil)
|
||||||
|
//
|
||||||
|
// c.Assert(
|
||||||
|
// sql,
|
||||||
|
// gc.Equals,
|
||||||
|
// "UPDATE db.table1 SET table1.col1=1 WHERE table1.col2=2")
|
||||||
|
//}
|
||||||
|
//
|
||||||
|
//func (s *StmtSuite) TestUpdateUsingDeferredLookupColumns(c *gc.C) {
|
||||||
|
// stmt := table1.UPDATE().SET(table1.C("col1"), Literal(1))
|
||||||
|
// stmt.WHERE(EqL(table1Col2, 2))
|
||||||
|
// sql, err := stmt.String()
|
||||||
|
// c.Assert(err, gc.IsNil)
|
||||||
|
//
|
||||||
|
// c.Assert(
|
||||||
|
// sql,
|
||||||
|
// gc.Equals,
|
||||||
|
// "UPDATE db.table1 SET table1.col1=1 WHERE table1.col2=2")
|
||||||
|
//}
|
||||||
|
//
|
||||||
|
//func (s *StmtSuite) TestUpdateMultiValues(c *gc.C) {
|
||||||
|
// stmt := table1.UPDATE()
|
||||||
|
// stmt.SET(table1Col1, Literal(1))
|
||||||
|
// stmt.SET(table1Col2, Literal(2))
|
||||||
|
// stmt.WHERE(EqL(table1Col2, 3))
|
||||||
|
// sql, err := stmt.String()
|
||||||
|
// c.Assert(err, gc.IsNil)
|
||||||
|
//
|
||||||
|
// c.Assert(
|
||||||
|
// sql,
|
||||||
|
// gc.Equals,
|
||||||
|
// "UPDATE db.table1 "+
|
||||||
|
// "SET table1.col1=1, table1.col2=2 "+
|
||||||
|
// "WHERE table1.col2=3")
|
||||||
|
//}
|
||||||
|
//
|
||||||
|
//func (s *StmtSuite) TestUpdateWithOrderBy(c *gc.C) {
|
||||||
|
// stmt := table1.UPDATE().SET(table1Col1, Literal(1))
|
||||||
|
// stmt.WHERE(EqL(table1Col2, 2))
|
||||||
|
// stmt.OrderBy(table1Col2)
|
||||||
|
// sql, err := stmt.String()
|
||||||
|
// c.Assert(err, gc.IsNil)
|
||||||
|
//
|
||||||
|
// c.Assert(
|
||||||
|
// sql,
|
||||||
|
// gc.Equals,
|
||||||
|
// "UPDATE db.table1 "+
|
||||||
|
// "SET table1.col1=1 "+
|
||||||
|
// "WHERE table1.col2=2 "+
|
||||||
|
// "ORDER BY table1.col2")
|
||||||
|
//}
|
||||||
|
//
|
||||||
|
//func (s *StmtSuite) TestUpdateWithLimit(c *gc.C) {
|
||||||
|
// stmt := table1.UPDATE().SET(table1Col1, Literal(1))
|
||||||
|
// stmt.WHERE(EqL(table1Col2, 2))
|
||||||
|
// stmt.Limit(5)
|
||||||
|
// sql, err := stmt.String()
|
||||||
|
// c.Assert(err, gc.IsNil)
|
||||||
|
//
|
||||||
|
// c.Assert(
|
||||||
|
// sql,
|
||||||
|
// gc.Equals,
|
||||||
|
// "UPDATE db.table1 "+
|
||||||
|
// "SET table1.col1=1 "+
|
||||||
|
// "WHERE table1.col2=2 "+
|
||||||
|
// "LIMIT 5")
|
||||||
|
//}
|
||||||
|
|
@ -1 +1,35 @@
|
||||||
package sqlbuilder
|
package sqlbuilder
|
||||||
|
|
||||||
|
import "bytes"
|
||||||
|
|
||||||
|
func serializeExpressionList(expressions []Expression, buf *bytes.Buffer) error {
|
||||||
|
for i, value := range expressions {
|
||||||
|
if i > 0 {
|
||||||
|
buf.WriteString(", ")
|
||||||
|
}
|
||||||
|
|
||||||
|
err := value.SerializeSql(buf)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func serializeProjectionList(projections []Projection, buf *bytes.Buffer) error {
|
||||||
|
for i, value := range projections {
|
||||||
|
if i > 0 {
|
||||||
|
buf.WriteString(", ")
|
||||||
|
}
|
||||||
|
|
||||||
|
err := value.SerializeForProjection(buf)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -25,13 +25,13 @@ func TestGenerateModel(t *testing.T) {
|
||||||
|
|
||||||
func TestSelect_ScanToStruct(t *testing.T) {
|
func TestSelect_ScanToStruct(t *testing.T) {
|
||||||
actor := model.Actor{}
|
actor := model.Actor{}
|
||||||
query := Actor.SELECT(Actor.AllColumns)
|
query := Actor.SELECT(Actor.AllColumns).OrderBy(Actor.ActorID.Asc())
|
||||||
|
|
||||||
queryStr, err := query.String()
|
queryStr, err := query.String()
|
||||||
|
|
||||||
fmt.Println(queryStr)
|
fmt.Println(queryStr)
|
||||||
|
|
||||||
assert.Equal(t, queryStr, `SELECT actor.actor_id AS "actor.actor_id", actor.first_name AS "actor.first_name", actor.last_name AS "actor.last_name", actor.last_update AS "actor.last_update" FROM dvds.actor`)
|
assert.Equal(t, queryStr, `SELECT actor.actor_id AS "actor.actor_id", actor.first_name AS "actor.first_name", actor.last_name AS "actor.last_name", actor.last_update AS "actor.last_update" FROM dvds.actor ORDER BY actor.actor_id ASC`)
|
||||||
|
|
||||||
err = query.Execute(db, &actor)
|
err = query.Execute(db, &actor)
|
||||||
|
|
||||||
|
|
@ -74,7 +74,7 @@ func TestSelect_ScanToSlice(t *testing.T) {
|
||||||
// INNER_JOIN(Film, FilmActor.FilmID.Eq(Film.FilmID)).
|
// INNER_JOIN(Film, FilmActor.FilmID.Eq(Film.FilmID)).
|
||||||
// INNER_JOIN(Language, Film.LanguageID.Eq(Language.LanguageID)).
|
// INNER_JOIN(Language, Film.LanguageID.Eq(Language.LanguageID)).
|
||||||
// SELECT(FilmActor.AllColumns, Film.AllColumns, Language.AllColumns, Actor.AllColumns).
|
// SELECT(FilmActor.AllColumns, Film.AllColumns, Language.AllColumns, Actor.AllColumns).
|
||||||
// Where(FilmActor.ActorID.GtEq(1).And(FilmActor.ActorID.LteLiteral(2)))
|
// WHERE(FilmActor.ActorID.GtEq(1).And(FilmActor.ActorID.LteLiteral(2)))
|
||||||
//
|
//
|
||||||
// queryStr, err := query.String()
|
// queryStr, err := query.String()
|
||||||
// assert.NilError(t, err)
|
// assert.NilError(t, err)
|
||||||
|
|
@ -405,7 +405,7 @@ func TestSubQuery(t *testing.T) {
|
||||||
//Customer.
|
//Customer.
|
||||||
// INNER_JOIN(selectStmtTable, Customer.LastName.Eq(selectStmtTable.RefStringColumn(Actor.FirstName))).
|
// INNER_JOIN(selectStmtTable, Customer.LastName.Eq(selectStmtTable.RefStringColumn(Actor.FirstName))).
|
||||||
// SELECT(Customer.AllColumns, selectStmtTable.RefIntColumnName("first_name")).
|
// SELECT(Customer.AllColumns, selectStmtTable.RefIntColumnName("first_name")).
|
||||||
// Where(Actor.LastName.Neq(avrgCustomer))
|
// WHERE(Actor.LastName.Neq(avrgCustomer))
|
||||||
|
|
||||||
rFilmsOnly := Film.SELECT(Film.FilmID, Film.Title, Film.Rating).
|
rFilmsOnly := Film.SELECT(Film.FilmID, Film.Title, Film.Rating).
|
||||||
Where(Film.Rating.EqL("R")).
|
Where(Film.Rating.EqL("R")).
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,7 @@ func TestInsertValues(t *testing.T) {
|
||||||
|
|
||||||
fmt.Println(insertQueryStr)
|
fmt.Println(insertQueryStr)
|
||||||
|
|
||||||
assert.Equal(t, insertQueryStr, `INSERT INTO test_sample.link (url,name,rel) VALUES ('http://www.postgresqltutorial.com','PostgreSQL Tutorial',DEFAULT), ('http://www.google.com','Google',DEFAULT), ('http://www.yahoo.com','Yahoo',DEFAULT), ('http://www.bing.com','Bing',DEFAULT) RETURNING link.id;`)
|
assert.Equal(t, insertQueryStr, `INSERT INTO test_sample.link (url,name,rel) VALUES ('http://www.postgresqltutorial.com','PostgreSQL Tutorial',DEFAULT), ('http://www.google.com','Google',DEFAULT), ('http://www.yahoo.com','Yahoo',DEFAULT), ('http://www.bing.com','Bing',DEFAULT) RETURNING link.id AS "link.id";`)
|
||||||
res, err := insertQuery.Execute(db)
|
res, err := insertQuery.Execute(db)
|
||||||
|
|
||||||
assert.NilError(t, err)
|
assert.NilError(t, err)
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ package tests
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
_ "github.com/lib/pq"
|
||||||
"os"
|
"os"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
@ -19,7 +20,8 @@ const (
|
||||||
|
|
||||||
var connectString = fmt.Sprintf("host=%s port=%d user=%s "+"password=%s dbname=%s sslmode=disable", host, port, user, password, dbname)
|
var connectString = fmt.Sprintf("host=%s port=%d user=%s "+"password=%s dbname=%s sslmode=disable", host, port, user, password, dbname)
|
||||||
var db *sql.DB
|
var db *sql.DB
|
||||||
var tx *sql.Tx
|
|
||||||
|
//var tx *sql.Tx
|
||||||
|
|
||||||
//go:generate generator -db "host=localhost port=5432 user=postgres password=postgres dbname=dvd_rental sslmode=disable" -dbName dvd_rental -schema dvds -path .test_files
|
//go:generate generator -db "host=localhost port=5432 user=postgres password=postgres dbname=dvd_rental sslmode=disable" -dbName dvd_rental -schema dvds -path .test_files
|
||||||
//go:generate generator -db "host=localhost port=5432 user=postgres password=postgres dbname=dvd_rental sslmode=disable" -dbName dvd_rental -schema test_sample -path .test_files
|
//go:generate generator -db "host=localhost port=5432 user=postgres password=postgres dbname=dvd_rental sslmode=disable" -dbName dvd_rental -schema test_sample -path .test_files
|
||||||
|
|
@ -32,7 +34,7 @@ func TestMain(m *testing.M) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic("Failed to connect to test db")
|
panic("Failed to connect to test db")
|
||||||
}
|
}
|
||||||
tx, _ = db.Begin()
|
//tx, _ = db.Begin()
|
||||||
defer cleanUp()
|
defer cleanUp()
|
||||||
|
|
||||||
dbInit()
|
dbInit()
|
||||||
|
|
@ -48,7 +50,7 @@ func TestMain(m *testing.M) {
|
||||||
func cleanUp() {
|
func cleanUp() {
|
||||||
fmt.Println("CLEAN UP")
|
fmt.Println("CLEAN UP")
|
||||||
|
|
||||||
tx.Rollback()
|
//tx.Rollback()
|
||||||
db.Close()
|
db.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
83
tests/update_test.go
Normal file
83
tests/update_test.go
Normal file
|
|
@ -0,0 +1,83 @@
|
||||||
|
package tests
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"github.com/sub0zero/go-sqlbuilder/sqlbuilder"
|
||||||
|
"github.com/sub0zero/go-sqlbuilder/tests/.test_files/dvd_rental/test_sample/model"
|
||||||
|
"github.com/sub0zero/go-sqlbuilder/tests/.test_files/dvd_rental/test_sample/table"
|
||||||
|
"gotest.tools/assert"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestUpdateValues(t *testing.T) {
|
||||||
|
_, err := table.Link.INSERT(table.Link.URL, table.Link.Name, table.Link.Rel).
|
||||||
|
VALUES("http://www.postgresqltutorial.com", "PostgreSQL Tutorial", sqlbuilder.DEFAULT).
|
||||||
|
VALUES("http://www.yahoo.com", "Yahoo", sqlbuilder.DEFAULT).
|
||||||
|
VALUES("http://www.bing.com", "Bing", sqlbuilder.DEFAULT).
|
||||||
|
RETURNING(table.Link.ID).Execute(db)
|
||||||
|
|
||||||
|
assert.NilError(t, err)
|
||||||
|
|
||||||
|
query := table.Link.
|
||||||
|
UPDATE(table.Link.Name, table.Link.URL).
|
||||||
|
SET("Bong", "http://bong.com").
|
||||||
|
WHERE(table.Link.Name.EqL("Bing"))
|
||||||
|
|
||||||
|
queryStr, err := query.String()
|
||||||
|
|
||||||
|
assert.NilError(t, err)
|
||||||
|
|
||||||
|
fmt.Println(queryStr)
|
||||||
|
|
||||||
|
res, err := query.Execute(db)
|
||||||
|
|
||||||
|
assert.NilError(t, err)
|
||||||
|
|
||||||
|
fmt.Println(res)
|
||||||
|
|
||||||
|
links := []model.Link{}
|
||||||
|
|
||||||
|
err = table.Link.SELECT(table.Link.AllColumns).
|
||||||
|
Where(table.Link.Name.EqL("Bong")).
|
||||||
|
Execute(db, &links)
|
||||||
|
|
||||||
|
assert.NilError(t, err)
|
||||||
|
|
||||||
|
//spew.Dump(links)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateAndReturning(t *testing.T) {
|
||||||
|
_, err := table.Link.INSERT(table.Link.URL, table.Link.Name, table.Link.Rel).
|
||||||
|
VALUES("http://www.postgresqltutorial.com", "PostgreSQL Tutorial", sqlbuilder.DEFAULT).
|
||||||
|
VALUES("http://www.ask.com", "Ask", sqlbuilder.DEFAULT).
|
||||||
|
VALUES("http://www.ask.com", "Ask", sqlbuilder.DEFAULT).
|
||||||
|
VALUES("http://www.yahoo.com", "Yahoo", sqlbuilder.DEFAULT).
|
||||||
|
VALUES("http://www.bing.com", "Bing", sqlbuilder.DEFAULT).
|
||||||
|
RETURNING(table.Link.ID).Execute(db)
|
||||||
|
|
||||||
|
assert.NilError(t, err)
|
||||||
|
|
||||||
|
stmt := table.Link.
|
||||||
|
UPDATE(table.Link.Name, table.Link.URL).
|
||||||
|
SET("DuckDuckGo", "http://www.duckduckgo.com").
|
||||||
|
WHERE(table.Link.Name.EqL("Ask")).
|
||||||
|
RETURNING(table.Link.AllColumns)
|
||||||
|
|
||||||
|
stmtStr, err := stmt.String()
|
||||||
|
|
||||||
|
assert.NilError(t, err)
|
||||||
|
|
||||||
|
fmt.Println(stmtStr)
|
||||||
|
|
||||||
|
links := []model.Link{}
|
||||||
|
|
||||||
|
err = stmt.Query(db, &links)
|
||||||
|
|
||||||
|
assert.NilError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, len(links), 2)
|
||||||
|
|
||||||
|
assert.Equal(t, links[0].Name, "DuckDuckGo")
|
||||||
|
|
||||||
|
assert.Equal(t, links[1].Name, "DuckDuckGo")
|
||||||
|
}
|
||||||
Loading…
Add table
Add a link
Reference in a new issue