From 70d6f84375228bfb139fc53ecce5ee9c92f484c8 Mon Sep 17 00:00:00 2001 From: zer0sub Date: Sun, 14 Apr 2019 17:55:10 +0200 Subject: [PATCH] Add support for Update statements. --- sqlbuilder/insert_statement.go | 26 ++--- sqlbuilder/select_statement.go | 7 -- sqlbuilder/statement.go | 157 +------------------------- sqlbuilder/statement_test.go | 96 +--------------- sqlbuilder/table.go | 6 +- sqlbuilder/update_statement.go | 168 ++++++++++++++++++++++++++++ sqlbuilder/update_statement_test.go | 113 +++++++++++++++++++ sqlbuilder/utils.go | 34 ++++++ tests/generator_test.go | 8 +- tests/insert_test.go | 2 +- tests/main_test.go | 8 +- tests/update_test.go | 83 ++++++++++++++ 12 files changed, 422 insertions(+), 286 deletions(-) create mode 100644 sqlbuilder/update_statement.go create mode 100644 sqlbuilder/update_statement_test.go create mode 100644 tests/update_test.go diff --git a/sqlbuilder/insert_statement.go b/sqlbuilder/insert_statement.go index 8d8dfdb..0d9ebcf 100644 --- a/sqlbuilder/insert_statement.go +++ b/sqlbuilder/insert_statement.go @@ -18,7 +18,7 @@ type InsertStatement interface { // Map or stracture mapped to column names VALUES_MAPPING(data interface{}) InsertStatement - RETURNING(column ...Expression) InsertStatement + RETURNING(projections ...Projection) InsertStatement QUERY(selectStatement SelectStatement) InsertStatement @@ -27,10 +27,8 @@ type InsertStatement interface { func newInsertStatement(t WritableTable, columns ...Column) InsertStatement { return &insertStatementImpl{ - table: t, - columns: columns, - rows: make([][]Clause, 0, 1), - returning: make([]Expression, 0, 1), + table: t, + columns: columns, } } @@ -44,7 +42,7 @@ type insertStatementImpl struct { columns []Column rows [][]Clause query SelectStatement - returning []Expression + returning []Projection errors []string } @@ -114,8 +112,8 @@ func (i *insertStatementImpl) VALUES_MAPPING(data interface{}) InsertStatement { return i } -func (i *insertStatementImpl) RETURNING(column ...Expression) InsertStatement { - i.returning = column +func (i *insertStatementImpl) RETURNING(projections ...Projection) InsertStatement { + i.returning = projections return i } @@ -217,16 +215,10 @@ func (s *insertStatementImpl) String() (sql string, err error) { if len(s.returning) > 0 { buf.WriteString(" RETURNING ") - for i, column := range s.returning { - if i > 0 { - buf.WriteString(",") - } + err = serializeProjectionList(s.returning, buf) - err = column.SerializeSql(buf) - - if err != nil { - return - } + if err != nil { + return } } diff --git a/sqlbuilder/select_statement.go b/sqlbuilder/select_statement.go index 9ecfb1e..03b7898 100644 --- a/sqlbuilder/select_statement.go +++ b/sqlbuilder/select_statement.go @@ -6,7 +6,6 @@ import ( "github.com/dropbox/godropbox/errors" "github.com/sub0zero/go-sqlbuilder/sqlbuilder/execution" "github.com/sub0zero/go-sqlbuilder/types" - "reflect" ) type SelectStatement interface { @@ -88,12 +87,6 @@ func (s *selectStatementImpl) AsTable(alias string) *SelectStatementTable { } func (s *selectStatementImpl) Execute(db types.Db, destination interface{}) error { - destinationType := reflect.TypeOf(destination) - - if destinationType.Kind() == reflect.Ptr && destinationType.Elem().Kind() == reflect.Struct { - s.Limit(1) - } - query, err := s.String() if err != nil { diff --git a/sqlbuilder/statement.go b/sqlbuilder/statement.go index f8b6b9b..952d377 100644 --- a/sqlbuilder/statement.go +++ b/sqlbuilder/statement.go @@ -34,16 +34,6 @@ type UnionStatement interface { Offset(offset int64) UnionStatement } -type UpdateStatement interface { - Statement - - Set(column Column, expression Expression) UpdateStatement - Where(expression BoolExpression) UpdateStatement - OrderBy(clauses ...OrderByClause) UpdateStatement - Limit(limit int64) UpdateStatement - Comment(comment string) UpdateStatement -} - type DeleteStatement interface { Statement @@ -250,151 +240,6 @@ func (us *unionStatementImpl) String() (sql string, err error) { return buf.String(), nil } -// -// UPDATE statement =========================================================== -// - -func newUpdateStatement(table WritableTable) UpdateStatement { - return &updateStatementImpl{ - table: table, - updateValues: make(map[Column]Expression), - limit: -1, - } -} - -type updateStatementImpl struct { - table WritableTable - updateValues map[Column]Expression - where BoolExpression - order *listClause - limit int64 - comment string -} - -func (u *updateStatementImpl) Execute(db *sql.DB, data interface{}) error { - return nil -} - -func (u *updateStatementImpl) Set( - column Column, - expression Expression) UpdateStatement { - - u.updateValues[column] = expression - return u -} - -func (u *updateStatementImpl) Where(expression BoolExpression) UpdateStatement { - u.where = expression - return u -} - -func (u *updateStatementImpl) OrderBy( - clauses ...OrderByClause) UpdateStatement { - - u.order = newOrderByListClause(clauses...) - return u -} - -func (u *updateStatementImpl) Limit(limit int64) UpdateStatement { - u.limit = limit - return u -} - -func (u *updateStatementImpl) Comment(comment string) UpdateStatement { - u.comment = comment - return u -} - -func (u *updateStatementImpl) String() (sql string, err error) { - buf := new(bytes.Buffer) - _, _ = buf.WriteString("UPDATE ") - - if err = writeComment(u.comment, buf); err != nil { - return - } - - if u.table == nil { - return "", errors.Newf("nil tableName. Generated sql: %s", buf.String()) - } - - if err = u.table.SerializeSql(buf); err != nil { - return - } - - if len(u.updateValues) == 0 { - return "", errors.Newf( - "No column updated. Generated sql: %s", - buf.String()) - } - - _, _ = buf.WriteString(" SET ") - addComma := false - - // Sorting is too hard in go, just create a second map ... - updateValues := make(map[string]Expression) - for col, expr := range u.updateValues { - if col == nil { - return "", errors.Newf( - "nil column. Generated sql: %s", - buf.String()) - } - - updateValues[col.Name()] = expr - } - - for _, col := range u.table.Columns() { - val, inMap := updateValues[col.Name()] - if !inMap { - continue - } - - if addComma { - _, _ = buf.WriteString(", ") - } - - if val == nil { - return "", errors.Newf( - "nil value. Generated sql: %s", - buf.String()) - } - - if err = col.SerializeSql(buf); err != nil { - return - } - - _ = buf.WriteByte('=') - if err = val.SerializeSql(buf); err != nil { - return - } - - addComma = true - } - - if u.where == nil { - return "", errors.Newf( - "Updating without a WHERE clause. Generated sql: %s", - buf.String()) - } - - _, _ = buf.WriteString(" WHERE ") - if err = u.where.SerializeSql(buf); err != nil { - return - } - - if u.order != nil { - _, _ = buf.WriteString(" ORDER BY ") - if err = u.order.SerializeSql(buf); err != nil { - return - } - } - - if u.limit >= 0 { - _, _ = buf.WriteString(fmt.Sprintf(" LIMIT %d", u.limit)) - } - - return buf.String(), nil -} - // // DELETE statement =========================================================== // @@ -565,7 +410,7 @@ func (s *unlockStatementImpl) String() (sql string, err error) { return "UNLOCK TABLES", nil } -// Set GTID_NEXT statement returns a SQL statement that can be used to explicitly set the next GTID. +// SET GTID_NEXT statement returns a SQL statement that can be used to explicitly set the next GTID. func NewGtidNextStatement(sid []byte, gno uint64) GtidNextStatement { return >idNextStatementImpl{ sid: sid, diff --git a/sqlbuilder/statement_test.go b/sqlbuilder/statement_test.go index f580fca..2afc4d6 100644 --- a/sqlbuilder/statement_test.go +++ b/sqlbuilder/statement_test.go @@ -385,100 +385,6 @@ func (s *StmtSuite) TestOnDuplicateKeyUpdateMulti(c *gc.C) { "ON DUPLICATE KEY UPDATE table1.col3=3, table1.col2=4") } -// -// UPDATE statement tests ===================================================== -// - -func (s *StmtSuite) TestUpdateNilColumn(c *gc.C) { - stmt := table1.Update().Set(nil, Literal(1)) - _, err := stmt.String() - c.Assert(err, gc.NotNil) -} - -func (s *StmtSuite) TestUpdateNilExpr(c *gc.C) { - stmt := table1.Update().Set(table1Col1, nil) - _, err := stmt.String() - c.Assert(err, gc.NotNil) -} - -func (s *StmtSuite) TestUpdateUnconditionally(c *gc.C) { - stmt := table1.Update().Set(table1Col1, Literal(1)) - _, err := stmt.String() - c.Assert(err, gc.NotNil) -} - -func (s *StmtSuite) TestUpdateSingleValue(c *gc.C) { - stmt := table1.Update().Set(table1Col1, Literal(1)) - stmt.Where(EqL(table1Col2, 2)) - sql, err := stmt.String() - c.Assert(err, gc.IsNil) - - c.Assert( - sql, - gc.Equals, - "UPDATE db.table1 SET table1.col1=1 WHERE table1.col2=2") -} - -func (s *StmtSuite) TestUpdateUsingDeferredLookupColumns(c *gc.C) { - stmt := table1.Update().Set(table1.C("col1"), Literal(1)) - stmt.Where(EqL(table1Col2, 2)) - sql, err := stmt.String() - c.Assert(err, gc.IsNil) - - c.Assert( - sql, - gc.Equals, - "UPDATE db.table1 SET table1.col1=1 WHERE table1.col2=2") -} - -func (s *StmtSuite) TestUpdateMultiValues(c *gc.C) { - stmt := table1.Update() - stmt.Set(table1Col1, Literal(1)) - stmt.Set(table1Col2, Literal(2)) - stmt.Where(EqL(table1Col2, 3)) - sql, err := stmt.String() - c.Assert(err, gc.IsNil) - - c.Assert( - sql, - gc.Equals, - "UPDATE db.table1 "+ - "SET table1.col1=1, table1.col2=2 "+ - "WHERE table1.col2=3") -} - -func (s *StmtSuite) TestUpdateWithOrderBy(c *gc.C) { - stmt := table1.Update().Set(table1Col1, Literal(1)) - stmt.Where(EqL(table1Col2, 2)) - stmt.OrderBy(table1Col2) - sql, err := stmt.String() - c.Assert(err, gc.IsNil) - - c.Assert( - sql, - gc.Equals, - "UPDATE db.table1 "+ - "SET table1.col1=1 "+ - "WHERE table1.col2=2 "+ - "ORDER BY table1.col2") -} - -func (s *StmtSuite) TestUpdateWithLimit(c *gc.C) { - stmt := table1.Update().Set(table1Col1, Literal(1)) - stmt.Where(EqL(table1Col2, 2)) - stmt.Limit(5) - sql, err := stmt.String() - c.Assert(err, gc.IsNil) - - c.Assert( - sql, - gc.Equals, - "UPDATE db.table1 "+ - "SET table1.col1=1 "+ - "WHERE table1.col2=2 "+ - "LIMIT 5") -} - // // DELETE statement tests ===================================================== // @@ -619,7 +525,7 @@ func (s *StmtSuite) TestUnionSelectWithMismatchedColumns(c *gc.C) { func (s *StmtSuite) TestComplicatedUnionSelectWithWhereStatement(c *gc.C) { // tests on outer statement: Group By, Order By, Limit - // on inner statement: AndWhere, Where (with And), Order By, Limit + // on inner statement: AndWhere, WHERE (with And), Order By, Limit select_queries := make([]SelectStatement, 0, 3) // We're not trying to write a SQL parser, so we won't warn if you do something silly like diff --git a/sqlbuilder/table.go b/sqlbuilder/table.go index d7d08aa..0602751 100644 --- a/sqlbuilder/table.go +++ b/sqlbuilder/table.go @@ -45,7 +45,7 @@ type WritableTable interface { TableInterface INSERT(columns ...Column) InsertStatement - Update() UpdateStatement + UPDATE(columns ...Column) UpdateStatement Delete() DeleteStatement } @@ -229,8 +229,8 @@ func (t *Table) INSERT(columns ...Column) InsertStatement { return newInsertStatement(t, columns...) } -func (t *Table) Update() UpdateStatement { - return newUpdateStatement(t) +func (t *Table) UPDATE(columns ...Column) UpdateStatement { + return newUpdateStatement(t, columns) } func (t *Table) Delete() DeleteStatement { diff --git a/sqlbuilder/update_statement.go b/sqlbuilder/update_statement.go new file mode 100644 index 0000000..5536b90 --- /dev/null +++ b/sqlbuilder/update_statement.go @@ -0,0 +1,168 @@ +package sqlbuilder + +import ( + "bytes" + "database/sql" + "github.com/dropbox/godropbox/errors" + "github.com/sub0zero/go-sqlbuilder/sqlbuilder/execution" + "github.com/sub0zero/go-sqlbuilder/types" +) + +type UpdateStatement interface { + Statement + + SET(values ...interface{}) UpdateStatement + WHERE(expression BoolExpression) UpdateStatement + RETURNING(projections ...Projection) UpdateStatement + + Query(db types.Db, destination interface{}) error + Execute(db types.Db) (sql.Result, error) +} + +func newUpdateStatement(table WritableTable, columns []Column) UpdateStatement { + return &updateStatementImpl{ + table: table, + columns: columns, + } +} + +type updateStatementImpl struct { + table WritableTable + columns []Column + updateValues []Clause + where BoolExpression + returning []Projection +} + +func (u *updateStatementImpl) Query(db types.Db, destination interface{}) error { + query, err := u.String() + + if err != nil { + return err + } + + return execution.Execute(db, query, destination) +} + +func (u *updateStatementImpl) Execute(db types.Db) (res sql.Result, err error) { + query, err := u.String() + + if err != nil { + return + } + + res, err = db.Exec(query) + + return +} + +func (u *updateStatementImpl) SET(values ...interface{}) UpdateStatement { + + for _, value := range values { + if clause, ok := value.(Clause); ok { + u.updateValues = append(u.updateValues, clause) + } else { + u.updateValues = append(u.updateValues, Literal(value)) + } + } + + return u +} + +func (u *updateStatementImpl) WHERE(expression BoolExpression) UpdateStatement { + u.where = expression + return u +} + +func (u *updateStatementImpl) RETURNING(projections ...Projection) UpdateStatement { + u.returning = projections + return u +} + +func (u *updateStatementImpl) String() (sql string, err error) { + buf := new(bytes.Buffer) + _, _ = buf.WriteString("UPDATE ") + + if u.table == nil { + return "", errors.Newf("nil tableName. Generated sql: %s", buf.String()) + } + + if err = u.table.SerializeSql(buf); err != nil { + return + } + + if len(u.updateValues) == 0 { + return "", errors.Newf( + "No column updated. Generated sql: %s", + buf.String()) + } + + _, _ = buf.WriteString(" SET") + + if len(u.columns) > 1 { + buf.WriteString(" ( ") + } else { + buf.WriteString(" ") + } + + for i, column := range u.columns { + if i > 0 { + buf.WriteString(", ") + } + + buf.WriteString(column.Name()) + + if err != nil { + return + } + } + + if len(u.columns) > 1 { + buf.WriteString(" )") + } + + buf.WriteString(" =") + + if len(u.updateValues) > 1 { + buf.WriteString(" (") + } + + for i, value := range u.updateValues { + if i > 0 { + buf.WriteString(", ") + } + + err = value.SerializeSql(buf) + + if err != nil { + return + } + } + + if len(u.updateValues) > 1 { + buf.WriteString(" )") + } + + if u.where == nil { + return "", errors.Newf( + "Updating without a WHERE clause. Generated sql: %s", + buf.String()) + } + + _, _ = buf.WriteString(" WHERE ") + if err = u.where.SerializeSql(buf); err != nil { + return + } + + if len(u.returning) > 0 { + buf.WriteString(" RETURNING ") + + err = serializeProjectionList(u.returning, buf) + + if err != nil { + return + } + } + + return buf.String() + ";", nil +} diff --git a/sqlbuilder/update_statement_test.go b/sqlbuilder/update_statement_test.go new file mode 100644 index 0000000..e53a916 --- /dev/null +++ b/sqlbuilder/update_statement_test.go @@ -0,0 +1,113 @@ +package sqlbuilder + +import ( + "fmt" + "gotest.tools/assert" + "testing" +) + +// +// UPDATE statement tests ===================================================== +// + +func TestUpdate(t *testing.T) { + stmt := table1.UPDATE(table1Col1, table1Col2). + SET(table1.SELECT(table1Col2)). + WHERE(table1Col1.EqL(2)) + + stmtStr, err := stmt.String() + + assert.NilError(t, err) + + fmt.Println(stmtStr) +} + +//func (s *StmtSuite) TestUpdateNilColumn(c *gc.C) { +// stmt := table1.UPDATE().SET(nil, Literal(1)) +// _, err := stmt.String() +// c.Assert(err, gc.NotNil) +//} +// +//func (s *StmtSuite) TestUpdateNilExpr(c *gc.C) { +// stmt := table1.UPDATE().SET(table1Col1, nil) +// _, err := stmt.String() +// c.Assert(err, gc.NotNil) +//} +// +//func (s *StmtSuite) TestUpdateUnconditionally(c *gc.C) { +// stmt := table1.UPDATE().SET(table1Col1, Literal(1)) +// _, err := stmt.String() +// c.Assert(err, gc.NotNil) +//} +// +//func (s *StmtSuite) TestUpdateSingleValue(c *gc.C) { +// stmt := table1.UPDATE().SET(table1Col1, Literal(1)) +// stmt.WHERE(EqL(table1Col2, 2)) +// sql, err := stmt.String() +// c.Assert(err, gc.IsNil) +// +// c.Assert( +// sql, +// gc.Equals, +// "UPDATE db.table1 SET table1.col1=1 WHERE table1.col2=2") +//} +// +//func (s *StmtSuite) TestUpdateUsingDeferredLookupColumns(c *gc.C) { +// stmt := table1.UPDATE().SET(table1.C("col1"), Literal(1)) +// stmt.WHERE(EqL(table1Col2, 2)) +// sql, err := stmt.String() +// c.Assert(err, gc.IsNil) +// +// c.Assert( +// sql, +// gc.Equals, +// "UPDATE db.table1 SET table1.col1=1 WHERE table1.col2=2") +//} +// +//func (s *StmtSuite) TestUpdateMultiValues(c *gc.C) { +// stmt := table1.UPDATE() +// stmt.SET(table1Col1, Literal(1)) +// stmt.SET(table1Col2, Literal(2)) +// stmt.WHERE(EqL(table1Col2, 3)) +// sql, err := stmt.String() +// c.Assert(err, gc.IsNil) +// +// c.Assert( +// sql, +// gc.Equals, +// "UPDATE db.table1 "+ +// "SET table1.col1=1, table1.col2=2 "+ +// "WHERE table1.col2=3") +//} +// +//func (s *StmtSuite) TestUpdateWithOrderBy(c *gc.C) { +// stmt := table1.UPDATE().SET(table1Col1, Literal(1)) +// stmt.WHERE(EqL(table1Col2, 2)) +// stmt.OrderBy(table1Col2) +// sql, err := stmt.String() +// c.Assert(err, gc.IsNil) +// +// c.Assert( +// sql, +// gc.Equals, +// "UPDATE db.table1 "+ +// "SET table1.col1=1 "+ +// "WHERE table1.col2=2 "+ +// "ORDER BY table1.col2") +//} +// +//func (s *StmtSuite) TestUpdateWithLimit(c *gc.C) { +// stmt := table1.UPDATE().SET(table1Col1, Literal(1)) +// stmt.WHERE(EqL(table1Col2, 2)) +// stmt.Limit(5) +// sql, err := stmt.String() +// c.Assert(err, gc.IsNil) +// +// c.Assert( +// sql, +// gc.Equals, +// "UPDATE db.table1 "+ +// "SET table1.col1=1 "+ +// "WHERE table1.col2=2 "+ +// "LIMIT 5") +//} diff --git a/sqlbuilder/utils.go b/sqlbuilder/utils.go index 3896e39..b6db6cb 100644 --- a/sqlbuilder/utils.go +++ b/sqlbuilder/utils.go @@ -1 +1,35 @@ package sqlbuilder + +import "bytes" + +func serializeExpressionList(expressions []Expression, buf *bytes.Buffer) error { + for i, value := range expressions { + if i > 0 { + buf.WriteString(", ") + } + + err := value.SerializeSql(buf) + + if err != nil { + return err + } + } + + return nil +} + +func serializeProjectionList(projections []Projection, buf *bytes.Buffer) error { + for i, value := range projections { + if i > 0 { + buf.WriteString(", ") + } + + err := value.SerializeForProjection(buf) + + if err != nil { + return err + } + } + + return nil +} diff --git a/tests/generator_test.go b/tests/generator_test.go index a54cd0f..c600cef 100644 --- a/tests/generator_test.go +++ b/tests/generator_test.go @@ -25,13 +25,13 @@ func TestGenerateModel(t *testing.T) { func TestSelect_ScanToStruct(t *testing.T) { actor := model.Actor{} - query := Actor.SELECT(Actor.AllColumns) + query := Actor.SELECT(Actor.AllColumns).OrderBy(Actor.ActorID.Asc()) queryStr, err := query.String() fmt.Println(queryStr) - assert.Equal(t, queryStr, `SELECT actor.actor_id AS "actor.actor_id", actor.first_name AS "actor.first_name", actor.last_name AS "actor.last_name", actor.last_update AS "actor.last_update" FROM dvds.actor`) + assert.Equal(t, queryStr, `SELECT actor.actor_id AS "actor.actor_id", actor.first_name AS "actor.first_name", actor.last_name AS "actor.last_name", actor.last_update AS "actor.last_update" FROM dvds.actor ORDER BY actor.actor_id ASC`) err = query.Execute(db, &actor) @@ -74,7 +74,7 @@ func TestSelect_ScanToSlice(t *testing.T) { // INNER_JOIN(Film, FilmActor.FilmID.Eq(Film.FilmID)). // INNER_JOIN(Language, Film.LanguageID.Eq(Language.LanguageID)). // SELECT(FilmActor.AllColumns, Film.AllColumns, Language.AllColumns, Actor.AllColumns). -// Where(FilmActor.ActorID.GtEq(1).And(FilmActor.ActorID.LteLiteral(2))) +// WHERE(FilmActor.ActorID.GtEq(1).And(FilmActor.ActorID.LteLiteral(2))) // // queryStr, err := query.String() // assert.NilError(t, err) @@ -405,7 +405,7 @@ func TestSubQuery(t *testing.T) { //Customer. // INNER_JOIN(selectStmtTable, Customer.LastName.Eq(selectStmtTable.RefStringColumn(Actor.FirstName))). // SELECT(Customer.AllColumns, selectStmtTable.RefIntColumnName("first_name")). - // Where(Actor.LastName.Neq(avrgCustomer)) + // WHERE(Actor.LastName.Neq(avrgCustomer)) rFilmsOnly := Film.SELECT(Film.FilmID, Film.Title, Film.Rating). Where(Film.Rating.EqL("R")). diff --git a/tests/insert_test.go b/tests/insert_test.go index 6265f3f..499db50 100644 --- a/tests/insert_test.go +++ b/tests/insert_test.go @@ -24,7 +24,7 @@ func TestInsertValues(t *testing.T) { fmt.Println(insertQueryStr) - assert.Equal(t, insertQueryStr, `INSERT INTO test_sample.link (url,name,rel) VALUES ('http://www.postgresqltutorial.com','PostgreSQL Tutorial',DEFAULT), ('http://www.google.com','Google',DEFAULT), ('http://www.yahoo.com','Yahoo',DEFAULT), ('http://www.bing.com','Bing',DEFAULT) RETURNING link.id;`) + assert.Equal(t, insertQueryStr, `INSERT INTO test_sample.link (url,name,rel) VALUES ('http://www.postgresqltutorial.com','PostgreSQL Tutorial',DEFAULT), ('http://www.google.com','Google',DEFAULT), ('http://www.yahoo.com','Yahoo',DEFAULT), ('http://www.bing.com','Bing',DEFAULT) RETURNING link.id AS "link.id";`) res, err := insertQuery.Execute(db) assert.NilError(t, err) diff --git a/tests/main_test.go b/tests/main_test.go index 12c5057..95c6d3e 100644 --- a/tests/main_test.go +++ b/tests/main_test.go @@ -3,6 +3,7 @@ package tests import ( "database/sql" "fmt" + _ "github.com/lib/pq" "os" "testing" ) @@ -19,7 +20,8 @@ const ( var connectString = fmt.Sprintf("host=%s port=%d user=%s "+"password=%s dbname=%s sslmode=disable", host, port, user, password, dbname) var db *sql.DB -var tx *sql.Tx + +//var tx *sql.Tx //go:generate generator -db "host=localhost port=5432 user=postgres password=postgres dbname=dvd_rental sslmode=disable" -dbName dvd_rental -schema dvds -path .test_files //go:generate generator -db "host=localhost port=5432 user=postgres password=postgres dbname=dvd_rental sslmode=disable" -dbName dvd_rental -schema test_sample -path .test_files @@ -32,7 +34,7 @@ func TestMain(m *testing.M) { if err != nil { panic("Failed to connect to test db") } - tx, _ = db.Begin() + //tx, _ = db.Begin() defer cleanUp() dbInit() @@ -48,7 +50,7 @@ func TestMain(m *testing.M) { func cleanUp() { fmt.Println("CLEAN UP") - tx.Rollback() + //tx.Rollback() db.Close() } diff --git a/tests/update_test.go b/tests/update_test.go new file mode 100644 index 0000000..73e8563 --- /dev/null +++ b/tests/update_test.go @@ -0,0 +1,83 @@ +package tests + +import ( + "fmt" + "github.com/sub0zero/go-sqlbuilder/sqlbuilder" + "github.com/sub0zero/go-sqlbuilder/tests/.test_files/dvd_rental/test_sample/model" + "github.com/sub0zero/go-sqlbuilder/tests/.test_files/dvd_rental/test_sample/table" + "gotest.tools/assert" + "testing" +) + +func TestUpdateValues(t *testing.T) { + _, err := table.Link.INSERT(table.Link.URL, table.Link.Name, table.Link.Rel). + VALUES("http://www.postgresqltutorial.com", "PostgreSQL Tutorial", sqlbuilder.DEFAULT). + VALUES("http://www.yahoo.com", "Yahoo", sqlbuilder.DEFAULT). + VALUES("http://www.bing.com", "Bing", sqlbuilder.DEFAULT). + RETURNING(table.Link.ID).Execute(db) + + assert.NilError(t, err) + + query := table.Link. + UPDATE(table.Link.Name, table.Link.URL). + SET("Bong", "http://bong.com"). + WHERE(table.Link.Name.EqL("Bing")) + + queryStr, err := query.String() + + assert.NilError(t, err) + + fmt.Println(queryStr) + + res, err := query.Execute(db) + + assert.NilError(t, err) + + fmt.Println(res) + + links := []model.Link{} + + err = table.Link.SELECT(table.Link.AllColumns). + Where(table.Link.Name.EqL("Bong")). + Execute(db, &links) + + assert.NilError(t, err) + + //spew.Dump(links) +} + +func TestUpdateAndReturning(t *testing.T) { + _, err := table.Link.INSERT(table.Link.URL, table.Link.Name, table.Link.Rel). + VALUES("http://www.postgresqltutorial.com", "PostgreSQL Tutorial", sqlbuilder.DEFAULT). + VALUES("http://www.ask.com", "Ask", sqlbuilder.DEFAULT). + VALUES("http://www.ask.com", "Ask", sqlbuilder.DEFAULT). + VALUES("http://www.yahoo.com", "Yahoo", sqlbuilder.DEFAULT). + VALUES("http://www.bing.com", "Bing", sqlbuilder.DEFAULT). + RETURNING(table.Link.ID).Execute(db) + + assert.NilError(t, err) + + stmt := table.Link. + UPDATE(table.Link.Name, table.Link.URL). + SET("DuckDuckGo", "http://www.duckduckgo.com"). + WHERE(table.Link.Name.EqL("Ask")). + RETURNING(table.Link.AllColumns) + + stmtStr, err := stmt.String() + + assert.NilError(t, err) + + fmt.Println(stmtStr) + + links := []model.Link{} + + err = stmt.Query(db, &links) + + assert.NilError(t, err) + + assert.Equal(t, len(links), 2) + + assert.Equal(t, links[0].Name, "DuckDuckGo") + + assert.Equal(t, links[1].Name, "DuckDuckGo") +}