From a4feb66692aa1f616033e8fd539e126d4d695f9c Mon Sep 17 00:00:00 2001 From: go-jet Date: Fri, 14 Jun 2019 14:35:50 +0200 Subject: [PATCH] Insert and Update statement improvements. --- sqlbuilder/clause.go | 12 +- sqlbuilder/column.go | 24 +++ sqlbuilder/delete_statement.go | 6 +- sqlbuilder/insert_statement.go | 95 +++------- sqlbuilder/insert_statement_test.go | 47 ++--- sqlbuilder/lock_statement.go | 6 +- sqlbuilder/projection.go | 18 -- sqlbuilder/select_statement.go | 14 +- sqlbuilder/set_statement.go | 18 +- sqlbuilder/set_statement_test.go | 11 -- sqlbuilder/statement.go | 2 +- sqlbuilder/table.go | 32 +++- sqlbuilder/test_utils.go | 11 +- sqlbuilder/update_statement.go | 77 ++++---- sqlbuilder/update_statement_test.go | 162 ++++++----------- sqlbuilder/utils.go | 64 ++++++- tests/all_types_test.go | 27 ++- tests/init/data/test_sample.sql | 3 + tests/insert_test.go | 157 +++++++++------- tests/select_test.go | 44 ++--- tests/test_util.go | 16 +- tests/update_test.go | 267 +++++++++++++++++++++++----- 22 files changed, 660 insertions(+), 453 deletions(-) diff --git a/sqlbuilder/clause.go b/sqlbuilder/clause.go index 3733ee6..d6f5a36 100644 --- a/sqlbuilder/clause.go +++ b/sqlbuilder/clause.go @@ -70,7 +70,7 @@ func (q *queryData) writeProjections(statement statementType, projections []proj } func (q *queryData) writeFrom(statement statementType, table ReadableTable) error { - q.nextLine() + q.newLine() q.writeString("FROM") q.increaseIdent() @@ -81,7 +81,7 @@ func (q *queryData) writeFrom(statement statementType, table ReadableTable) erro } func (q *queryData) writeWhere(statement statementType, where Expression) error { - q.nextLine() + q.newLine() q.writeString("WHERE") q.increaseIdent() @@ -92,7 +92,7 @@ func (q *queryData) writeWhere(statement statementType, where Expression) error } func (q *queryData) writeGroupBy(statement statementType, groupBy []groupByClause) error { - q.nextLine() + q.newLine() q.writeString("GROUP BY") q.increaseIdent() @@ -103,7 +103,7 @@ func (q *queryData) writeGroupBy(statement statementType, groupBy []groupByClaus } func (q *queryData) writeOrderBy(statement statementType, orderBy []OrderByClause) error { - q.nextLine() + q.newLine() q.writeString("ORDER BY") q.increaseIdent() @@ -114,7 +114,7 @@ func (q *queryData) writeOrderBy(statement statementType, orderBy []OrderByClaus } func (q *queryData) writeHaving(statement statementType, having Expression) error { - q.nextLine() + q.newLine() q.writeString("HAVING") q.increaseIdent() @@ -124,7 +124,7 @@ func (q *queryData) writeHaving(statement statementType, having Expression) erro return err } -func (q *queryData) nextLine() { +func (q *queryData) newLine() { q.write([]byte{'\n'}) q.write(bytes.Repeat([]byte{' '}, q.ident)) } diff --git a/sqlbuilder/column.go b/sqlbuilder/column.go index 633d8dc..f2df417 100644 --- a/sqlbuilder/column.go +++ b/sqlbuilder/column.go @@ -112,3 +112,27 @@ func (c columnImpl) serialize(statement statementType, out *queryData, options . return nil } + +//------------------------------------------------------// +// Dummy type for select * AllColumns +type ColumnList []Column + +// projection interface implementation +func (cl ColumnList) isProjectionType() {} + +func (cl ColumnList) serializeForProjection(statement statementType, out *queryData) error { + projections := columnListToProjectionList(cl) + + err := serializeProjectionList(statement, projections, out) + + if err != nil { + return err + } + + return nil +} + +// column interface implementation +func (cl ColumnList) Name() string { return "" } +func (cl ColumnList) TableName() string { return "" } +func (cl ColumnList) setTableName(name string) {} diff --git a/sqlbuilder/delete_statement.go b/sqlbuilder/delete_statement.go index 8f8080e..e8ed466 100644 --- a/sqlbuilder/delete_statement.go +++ b/sqlbuilder/delete_statement.go @@ -32,7 +32,7 @@ func (d *deleteStatementImpl) serializeImpl(out *queryData) error { if d == nil { return errors.New("Delete expression. ") } - out.nextLine() + out.newLine() out.writeString("DELETE FROM") if d.table == nil { @@ -75,6 +75,6 @@ func (d *deleteStatementImpl) Query(db execution.Db, destination interface{}) er return Query(d, db, destination) } -func (d *deleteStatementImpl) Execute(db execution.Db) (res sql.Result, err error) { - return Execute(d, db) +func (d *deleteStatementImpl) Exec(db execution.Db) (res sql.Result, err error) { + return Exec(d, db) } diff --git a/sqlbuilder/insert_statement.go b/sqlbuilder/insert_statement.go index 824192b..e7e8089 100644 --- a/sqlbuilder/insert_statement.go +++ b/sqlbuilder/insert_statement.go @@ -4,8 +4,6 @@ import ( "database/sql" "errors" "github.com/go-jet/jet/sqlbuilder/execution" - "github.com/serenize/snaker" - "reflect" "strings" ) @@ -13,16 +11,16 @@ type InsertStatement interface { Statement // Add a row of values to the insert Statement. - VALUES(values ...interface{}) InsertStatement + VALUES(value interface{}, values ...interface{}) InsertStatement // Model structure mapped to column names - MODEL(data interface{}) InsertStatement + USING(data interface{}) InsertStatement QUERY(selectStatement SelectStatement) InsertStatement RETURNING(projections ...projection) InsertStatement } -func newInsertStatement(t WritableTable, columns ...Column) InsertStatement { +func newInsertStatement(t WritableTable, columns []column) InsertStatement { return &insertStatementImpl{ table: t, columns: columns, @@ -31,7 +29,7 @@ func newInsertStatement(t WritableTable, columns ...Column) InsertStatement { type insertStatementImpl struct { table WritableTable - columns []Column + columns []column rows [][]clause query SelectStatement returning []projection @@ -39,74 +37,13 @@ type insertStatementImpl struct { errors []string } -func (i *insertStatementImpl) Query(db execution.Db, destination interface{}) error { - return Query(i, db, destination) -} - -func (i *insertStatementImpl) Execute(db execution.Db) (res sql.Result, err error) { - return Execute(i, db) -} - -func (i *insertStatementImpl) VALUES(values ...interface{}) InsertStatement { - if len(values) == 0 { - return i - } - - literalRow := []clause{} - - for _, value := range values { - if clause, ok := value.(clause); ok { - literalRow = append(literalRow, clause) - } else { - literalRow = append(literalRow, literal(value)) - } - } - - i.rows = append(i.rows, literalRow) +func (i *insertStatementImpl) VALUES(value interface{}, values ...interface{}) InsertStatement { + i.rows = append(i.rows, unwindRowFromValues(value, values)) return i } -func (i *insertStatementImpl) MODEL(data interface{}) InsertStatement { - if data == nil { - i.addError("MODEL : data is nil.") - return i - } - - value := reflect.Indirect(reflect.ValueOf(data)) - - if value.Kind() != reflect.Struct { - i.addError("MODEL : data is not struct or pointer to struct.") - return i - } - - rowValues := []clause{} - - for _, column := range i.columns { - columnName := column.Name() - structFieldName := snaker.SnakeToCamel(columnName) - - structField := value.FieldByName(structFieldName) - - if !structField.IsValid() { - i.addError("MODEL : Data structure doesn't contain field for column " + columnName) - return i - } - - var field interface{} - - fieldValue := reflect.Indirect(structField) - - if fieldValue.IsValid() { - field = fieldValue.Interface() - } else { - field = nil - } - - rowValues = append(rowValues, literal(field)) - } - - i.rows = append(i.rows, rowValues) - +func (i *insertStatementImpl) USING(data interface{}) InsertStatement { + i.rows = append(i.rows, unwindRowFromModel(i.columns, data)) return i } @@ -135,7 +72,7 @@ func (i *insertStatementImpl) Sql() (sql string, args []interface{}, err error) queryData := &queryData{} - queryData.nextLine() + queryData.newLine() queryData.writeString("INSERT INTO") if isNil(i.table) { @@ -151,7 +88,7 @@ func (i *insertStatementImpl) Sql() (sql string, args []interface{}, err error) if len(i.columns) > 0 { queryData.writeString("(") - err = serializeColumnList(insert_statement, i.columns, queryData) + err = serializeColumnNames(i.columns, queryData) if err != nil { return @@ -177,7 +114,7 @@ func (i *insertStatementImpl) Sql() (sql string, args []interface{}, err error) } queryData.increaseIdent() - queryData.nextLine() + queryData.newLine() queryData.writeString("(") if len(row) != len(i.columns) { @@ -204,7 +141,7 @@ func (i *insertStatementImpl) Sql() (sql string, args []interface{}, err error) } if len(i.returning) > 0 { - queryData.nextLine() + queryData.newLine() queryData.writeString("RETURNING") err = queryData.writeProjections(insert_statement, i.returning) @@ -218,3 +155,11 @@ func (i *insertStatementImpl) Sql() (sql string, args []interface{}, err error) return } + +func (i *insertStatementImpl) Query(db execution.Db, destination interface{}) error { + return Query(i, db, destination) +} + +func (i *insertStatementImpl) Exec(db execution.Db) (res sql.Result, err error) { + return Exec(i, db) +} diff --git a/sqlbuilder/insert_statement_test.go b/sqlbuilder/insert_statement_test.go index 417d2e1..f674ae3 100644 --- a/sqlbuilder/insert_statement_test.go +++ b/sqlbuilder/insert_statement_test.go @@ -1,18 +1,11 @@ package sqlbuilder import ( - "fmt" "gotest.tools/assert" "testing" "time" ) -func TestInsertNoColumn(t *testing.T) { - _, _, err := table1.INSERT().VALUES().Sql() - - assert.Assert(t, err != nil) -} - func TestInsertNoRow(t *testing.T) { _, _, err := table1.INSERT(table1Col1).Sql() @@ -72,10 +65,8 @@ func TestInsertMultipleValues(t *testing.T) { sql, _, err := stmt.Sql() assert.NilError(t, err) - fmt.Println(sql) - expectedSql := ` -INSERT INTO db.table1 (col1,colFloat,col3) VALUES +INSERT INTO db.table1 (col1, colFloat, col3) VALUES ($1, $2, $3); ` @@ -91,10 +82,8 @@ func TestInsertMultipleRows(t *testing.T) { sql, _, err := stmt.Sql() assert.NilError(t, err) - fmt.Println(sql) - expectedSql := ` -INSERT INTO db.table1 (col1,colFloat) VALUES +INSERT INTO db.table1 (col1, colFloat) VALUES ($1, $2), ($3, $4), ($5, $6); @@ -117,16 +106,16 @@ func TestInsertValuesFromModel(t *testing.T) { } stmt := table1.INSERT(table1Col1, table1ColFloat). - MODEL(toInsert). - MODEL(&toInsert) + USING(toInsert). + USING(&toInsert) expectedSql := ` -INSERT INTO db.table1 (col1,colFloat) VALUES +INSERT INTO db.table1 (col1, colFloat) VALUES ($1, $2), ($3, $4); ` - assertQuery(t, stmt, expectedSql, int(1), float64(1.11), int(1), float64(1.11)) + assertStatement(t, stmt, expectedSql, int(1), float64(1.11), int(1), float64(1.11)) } func TestInsertValuesFromModelColumnMismatch(t *testing.T) { @@ -141,11 +130,10 @@ func TestInsertValuesFromModelColumnMismatch(t *testing.T) { } stmt := table1.INSERT(table1Col1, table1ColFloat). - MODEL(toInsert) + USING(toInsert) _, _, err := stmt.Sql() - fmt.Println(err) assert.Assert(t, err != nil) } @@ -154,20 +142,23 @@ func TestInsertQuery(t *testing.T) { stmt := table1.INSERT(table1Col1). QUERY(table1.SELECT(table1Col1)) - stmtStr, _, err := stmt.Sql() - - assert.NilError(t, err) - - fmt.Println(stmtStr) + var expectedSql = ` +INSERT INTO db.table1 (col1) ( + SELECT table1.col1 AS "table1.col1" + FROM db.table1 +); +` + assertStatement(t, stmt, expectedSql) } func TestInsertDefaultValue(t *testing.T) { stmt := table1.INSERT(table1Col1, table1ColFloat). VALUES(DEFAULT, "two") - stmtStr, _, err := stmt.Sql() + var expectedSql = ` +INSERT INTO db.table1 (col1, colFloat) VALUES + (DEFAULT, $1); +` - assert.NilError(t, err) - - fmt.Println(stmtStr) + assertStatement(t, stmt, expectedSql, "two") } diff --git a/sqlbuilder/lock_statement.go b/sqlbuilder/lock_statement.go index b624e1b..b9271bb 100644 --- a/sqlbuilder/lock_statement.go +++ b/sqlbuilder/lock_statement.go @@ -63,7 +63,7 @@ func (l *lockStatementImpl) Sql() (query string, args []interface{}, err error) out := &queryData{} - out.nextLine() + out.newLine() out.writeString("LOCK TABLE") for i, table := range l.tables { @@ -96,6 +96,6 @@ func (l *lockStatementImpl) Query(db execution.Db, destination interface{}) erro return Query(l, db, destination) } -func (l *lockStatementImpl) Execute(db execution.Db) (sql.Result, error) { - return Execute(l, db) +func (l *lockStatementImpl) Exec(db execution.Db) (sql.Result, error) { + return Exec(l, db) } diff --git a/sqlbuilder/projection.go b/sqlbuilder/projection.go index f64746c..d25057e 100644 --- a/sqlbuilder/projection.go +++ b/sqlbuilder/projection.go @@ -3,21 +3,3 @@ package sqlbuilder type projection interface { serializeForProjection(statement statementType, out *queryData) error } - -//------------------------------------------------------// -// Dummy type for select * AllColumns -type ColumnList []Column - -func (cl ColumnList) isProjectionType() {} - -func (cl ColumnList) serializeForProjection(statement statementType, out *queryData) error { - projections := columnListToProjectionList(cl) - - err := serializeProjectionList(statement, projections, out) - - if err != nil { - return err - } - - return nil -} diff --git a/sqlbuilder/select_statement.go b/sqlbuilder/select_statement.go index 34088f7..3c4a7d9 100644 --- a/sqlbuilder/select_statement.go +++ b/sqlbuilder/select_statement.go @@ -83,7 +83,7 @@ func (s *selectStatementImpl) serialize(statement statementType, out *queryData, return err } - out.nextLine() + out.newLine() out.writeString(")") return nil @@ -94,7 +94,7 @@ func (s *selectStatementImpl) serializeImpl(out *queryData) error { return errors.New("Select expression is nil. ") } - out.nextLine() + out.newLine() out.writeString("SELECT") if s.distinct { @@ -150,19 +150,19 @@ func (s *selectStatementImpl) serializeImpl(out *queryData) error { } if s.limit >= 0 { - out.nextLine() + out.newLine() out.writeString("LIMIT") out.insertPreparedArgument(s.limit) } if s.offset >= 0 { - out.nextLine() + out.newLine() out.writeString("OFFSET") out.insertPreparedArgument(s.offset) } if s.forUpdate { - out.nextLine() + out.newLine() out.writeString("FOR UPDATE") } @@ -238,6 +238,6 @@ func (s *selectStatementImpl) Query(db execution.Db, destination interface{}) er return Query(s, db, destination) } -func (s *selectStatementImpl) Execute(db execution.Db) (res sql.Result, err error) { - return Execute(s, db) +func (s *selectStatementImpl) Exec(db execution.Db) (res sql.Result, err error) { + return Exec(s, db) } diff --git a/sqlbuilder/set_statement.go b/sqlbuilder/set_statement.go index 0615be8..d823476 100644 --- a/sqlbuilder/set_statement.go +++ b/sqlbuilder/set_statement.go @@ -114,7 +114,7 @@ func (s *setStatementImpl) serialize(statement statementType, out *queryData, op if wrap { out.decreaseIdent() - out.nextLine() + out.newLine() out.writeString(")") } @@ -130,19 +130,19 @@ func (s *setStatementImpl) serializeImpl(out *queryData) error { return errors.New("UNION Statement must have at least two SELECT statements.") } - out.nextLine() + out.newLine() out.writeString("(") out.increaseIdent() for i, selectStmt := range s.selects { - out.nextLine() + out.newLine() if i > 0 { out.writeString(s.operator) if s.all { out.writeString("ALL") } - out.nextLine() + out.newLine() } err := selectStmt.serialize(set_statement, out) @@ -153,7 +153,7 @@ func (s *setStatementImpl) serializeImpl(out *queryData) error { } out.decreaseIdent() - out.nextLine() + out.newLine() out.writeString(")") if s.orderBy != nil { @@ -164,13 +164,13 @@ func (s *setStatementImpl) serializeImpl(out *queryData) error { } if s.limit >= 0 { - out.nextLine() + out.newLine() out.writeString("LIMIT") out.insertPreparedArgument(s.limit) } if s.offset >= 0 { - out.nextLine() + out.newLine() out.writeString("OFFSET") out.insertPreparedArgument(s.offset) } @@ -199,6 +199,6 @@ func (s *setStatementImpl) Query(db execution.Db, destination interface{}) error return Query(s, db, destination) } -func (u *setStatementImpl) Execute(db execution.Db) (res sql.Result, err error) { - return Execute(u, db) +func (u *setStatementImpl) Exec(db execution.Db) (res sql.Result, err error) { + return Exec(u, db) } diff --git a/sqlbuilder/set_statement_test.go b/sqlbuilder/set_statement_test.go index 0b583bf..a0db46d 100644 --- a/sqlbuilder/set_statement_test.go +++ b/sqlbuilder/set_statement_test.go @@ -1,7 +1,6 @@ package sqlbuilder import ( - "fmt" "gotest.tools/assert" "testing" ) @@ -29,7 +28,6 @@ func TestUnionTwoSelect(t *testing.T) { ).Sql() assert.NilError(t, err) - fmt.Println(query) assert.Equal(t, query, ` ( ( @@ -53,7 +51,6 @@ func TestUnionThreeSelect(t *testing.T) { table3.SELECT(table3Col1), ).Sql() - fmt.Println(query) assert.NilError(t, err) assert.Equal(t, query, ` ( @@ -83,7 +80,6 @@ func TestUnionWithOrderBy(t *testing.T) { ).ORDER_BY(table1Col1.ASC()).Sql() assert.NilError(t, err) - fmt.Println(query) assert.Equal(t, query, ` ( ( @@ -108,7 +104,6 @@ func TestUnionWithLimit(t *testing.T) { ).LIMIT(10).OFFSET(11).Sql() assert.NilError(t, err) - fmt.Println(query) assert.Equal(t, query, ` ( ( @@ -157,7 +152,6 @@ func TestUnionInUnion(t *testing.T) { queryStr, args, err := query.Sql() - fmt.Println(queryStr) assert.NilError(t, err) assert.Equal(t, len(args), 0) assert.Equal(t, queryStr, expectedSql) @@ -170,7 +164,6 @@ func TestUnionALL(t *testing.T) { ).Sql() assert.NilError(t, err) - fmt.Println(query) assert.Equal(t, query, ` ( ( @@ -194,7 +187,6 @@ func TestINTERSECT(t *testing.T) { ).Sql() assert.NilError(t, err) - fmt.Println(query) assert.Equal(t, query, ` ( ( @@ -218,7 +210,6 @@ func TestINTERSECT_ALL(t *testing.T) { ).Sql() assert.NilError(t, err) - fmt.Println(query) assert.Equal(t, query, ` ( ( @@ -242,7 +233,6 @@ func TestEXCEPT(t *testing.T) { ).Sql() assert.NilError(t, err) - fmt.Println(query) assert.Equal(t, query, ` ( ( @@ -266,7 +256,6 @@ func TestEXCEPT_ALL(t *testing.T) { ).Sql() assert.NilError(t, err) - fmt.Println(query) assert.Equal(t, query, ` ( ( diff --git a/sqlbuilder/statement.go b/sqlbuilder/statement.go index c07b3d9..6e26919 100644 --- a/sqlbuilder/statement.go +++ b/sqlbuilder/statement.go @@ -14,7 +14,7 @@ type Statement interface { DebugSql() (query string, err error) Query(db execution.Db, destination interface{}) error - Execute(db execution.Db) (sql.Result, error) + Exec(db execution.Db) (sql.Result, error) } func DebugSql(statement Statement) (string, error) { diff --git a/sqlbuilder/table.go b/sqlbuilder/table.go index 8b68494..f889a32 100644 --- a/sqlbuilder/table.go +++ b/sqlbuilder/table.go @@ -19,15 +19,17 @@ type readableTable interface { // Creates a right join tableName Expression using onCondition. RIGHT_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable + // Creates a full join tableName Expression using onCondition. FULL_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable + // Creates a cross join tableName Expression using onCondition. CROSS_JOIN(table ReadableTable) ReadableTable } // The sql tableName write interface. type writableTable interface { - INSERT(columns ...Column) InsertStatement - UPDATE(columns ...Column) UpdateStatement + INSERT(column column, columns ...column) InsertStatement + UPDATE(column column, columns ...column) UpdateStatement DELETE() DeleteStatement LOCK() LockStatement @@ -88,12 +90,12 @@ type writableTableInterfaceImpl struct { parent WritableTable } -func (w *writableTableInterfaceImpl) INSERT(columns ...Column) InsertStatement { - return newInsertStatement(w.parent, columns...) +func (w *writableTableInterfaceImpl) INSERT(column column, columns ...column) InsertStatement { + return newInsertStatement(w.parent, unwindColumns(column, columns...)) } -func (w *writableTableInterfaceImpl) UPDATE(columns ...Column) UpdateStatement { - return newUpdateStatement(w.parent, columns) +func (w *writableTableInterfaceImpl) UPDATE(column column, columns ...column) UpdateStatement { + return newUpdateStatement(w.parent, unwindColumns(column, columns...)) } func (w *writableTableInterfaceImpl) DELETE() DeleteStatement { @@ -229,7 +231,7 @@ func (t *joinTable) serialize(statement statementType, out *queryData, options . return } - out.nextLine() + out.newLine() switch t.join_type { case innerJoin: @@ -265,3 +267,19 @@ func (t *joinTable) serialize(statement statementType, out *queryData, options . return nil } + +func unwindColumns(column1 column, columns ...column) []column { + columnList := []column{} + + if val, ok := column1.(ColumnList); ok { + for _, col := range val { + columnList = append(columnList, col) + } + columnList = append(columnList, columns...) + } else { + columnList = append(columnList, column1) + columnList = append(columnList, columns...) + } + + return columnList +} diff --git a/sqlbuilder/test_utils.go b/sqlbuilder/test_utils.go index bad8947..845d884 100644 --- a/sqlbuilder/test_utils.go +++ b/sqlbuilder/test_utils.go @@ -1,7 +1,6 @@ package sqlbuilder import ( - "fmt" "gotest.tools/assert" "testing" ) @@ -66,7 +65,6 @@ func assertClauseSerializeErr(t *testing.T, clause clause, errString string) { out := queryData{} err := clause.serialize(select_statement, &out) - fmt.Println(err) assert.Assert(t, err != nil) assert.Equal(t, err.Error(), errString) } @@ -81,9 +79,16 @@ func assertProjectionSerialize(t *testing.T, projection projection, query string assert.DeepEqual(t, out.args, args) } -func assertQuery(t *testing.T, query Statement, expectedQuery string, expectedArgs ...interface{}) { +func assertStatement(t *testing.T, query Statement, expectedQuery string, expectedArgs ...interface{}) { queryStr, args, err := query.Sql() assert.NilError(t, err) assert.Equal(t, queryStr, expectedQuery) assert.DeepEqual(t, args, expectedArgs) } + +func assertStatementErr(t *testing.T, stmt Statement, errorStr string) { + _, _, err := stmt.Sql() + + assert.Assert(t, err != nil) + assert.Equal(t, err.Error(), errorStr) +} diff --git a/sqlbuilder/update_statement.go b/sqlbuilder/update_statement.go index d148a75..03b64aa 100644 --- a/sqlbuilder/update_statement.go +++ b/sqlbuilder/update_statement.go @@ -9,35 +9,37 @@ import ( type UpdateStatement interface { Statement - SET(values ...interface{}) UpdateStatement + SET(value interface{}, values ...interface{}) UpdateStatement + USING(data interface{}) UpdateStatement + WHERE(expression BoolExpression) UpdateStatement RETURNING(projections ...projection) UpdateStatement } -func newUpdateStatement(table WritableTable, columns []Column) UpdateStatement { +func newUpdateStatement(table WritableTable, columns []column) UpdateStatement { return &updateStatementImpl{ table: table, columns: columns, + row: make([]clause, 0, len(columns)), } } type updateStatementImpl struct { - table WritableTable - columns []Column - updateValues []clause - where BoolExpression - returning []projection + table WritableTable + columns []column + row []clause + where BoolExpression + returning []projection } -func (u *updateStatementImpl) SET(values ...interface{}) UpdateStatement { +func (u *updateStatementImpl) SET(value interface{}, values ...interface{}) UpdateStatement { + u.row = unwindRowFromValues(value, values) - for _, value := range values { - if clause, ok := value.(clause); ok { - u.updateValues = append(u.updateValues, clause) - } else { - u.updateValues = append(u.updateValues, literal(value)) - } - } + return u +} + +func (u *updateStatementImpl) USING(modelData interface{}) UpdateStatement { + u.row = unwindRowFromModel(u.columns, modelData) return u } @@ -55,31 +57,36 @@ func (u *updateStatementImpl) RETURNING(projections ...projection) UpdateStateme func (u *updateStatementImpl) Sql() (sql string, args []interface{}, err error) { out := &queryData{} - out.nextLine() + out.newLine() out.writeString("UPDATE") - if u.table == nil { - return "", nil, errors.New("nil tableName.") + if isNil(u.table) { + return "", nil, errors.New("table to update is nil") } if err = u.table.serialize(update_statement, out); err != nil { return } - if len(u.updateValues) == 0 { - return "", nil, errors.New("No column updated.") + if len(u.columns) == 0 { + return "", nil, errors.New("no columns selected") } + if len(u.row) == 0 { + return "", nil, errors.New("no values to updated") + } + + out.newLine() out.writeString("SET") if len(u.columns) > 1 { out.writeString("(") } - err = serializeColumnList(update_statement, u.columns, out) + err = serializeColumnNames(u.columns, out) if err != nil { - return "", nil, err + return } if len(u.columns) > 1 { @@ -88,28 +95,22 @@ func (u *updateStatementImpl) Sql() (sql string, args []interface{}, err error) out.writeString("=") - if len(u.updateValues) > 1 { + if len(u.row) > 1 { out.writeString("(") } - for i, value := range u.updateValues { - if i > 0 { - out.writeString(", ") - } + err = serializeClauseList(update_statement, u.row, out) - err = value.serialize(update_statement, out) - - if err != nil { - return - } + if err != nil { + return } - if len(u.updateValues) > 1 { + if len(u.row) > 1 { out.writeString(")") } if u.where == nil { - return "", nil, errors.New("Updating without a WHERE clause.") + return "", nil, errors.New("WHERE clause not set") } if err = out.writeWhere(update_statement, u.where); err != nil { @@ -117,8 +118,10 @@ func (u *updateStatementImpl) Sql() (sql string, args []interface{}, err error) } if len(u.returning) > 0 { - out.nextLine() + out.newLine() out.writeString("RETURNING") + out.increaseIdent() + out.increaseIdent() err = serializeProjectionList(update_statement, u.returning, out) @@ -139,6 +142,6 @@ func (u *updateStatementImpl) Query(db execution.Db, destination interface{}) er return Query(u, db, destination) } -func (u *updateStatementImpl) Execute(db execution.Db) (res sql.Result, err error) { - return Execute(u, db) +func (u *updateStatementImpl) Exec(db execution.Db) (res sql.Result, err error) { + return Exec(u, db) } diff --git a/sqlbuilder/update_statement_test.go b/sqlbuilder/update_statement_test.go index 7b4ba79..3020f09 100644 --- a/sqlbuilder/update_statement_test.go +++ b/sqlbuilder/update_statement_test.go @@ -1,124 +1,76 @@ package sqlbuilder import ( - "fmt" - "gotest.tools/assert" "testing" ) -// -// UPDATE Statement tests ===================================================== -// +func TestUpdateWithOneValue(t *testing.T) { + expectedSql := ` +UPDATE db.table1 +SET colInt = $1 +WHERE table1.colInt >= $2; +` + stmt := table1.UPDATE(table1ColInt). + SET(1). + WHERE(table1ColInt.GT_EQ(Int(33))) -func TestUpdate(t *testing.T) { - stmt := table1.UPDATE(table1Col1, table1ColFloat). - SET(table1.SELECT(table1ColFloat, table2Col3)). + assertStatement(t, stmt, expectedSql, 1, int64(33)) +} + +func TestUpdateWithValues(t *testing.T) { + expectedSql := ` +UPDATE db.table1 +SET (colInt, colFloat) = ($1, $2) +WHERE table1.colInt >= $3; +` + stmt := table1.UPDATE(table1ColInt, table1ColFloat). + SET(1, 22.2). + WHERE(table1ColInt.GT_EQ(Int(33))) + + assertStatement(t, stmt, expectedSql, 1, 22.2, int64(33)) +} + +func TestUpdateOneColumnWithSelect(t *testing.T) { + expectedSql := ` +UPDATE db.table1 +SET colFloat = ( + SELECT table1.colFloat AS "table1.colFloat" + FROM db.table1 +) +WHERE table1.col1 = $1 +RETURNING table1.col1 AS "table1.col1"; +` + stmt := table1. + UPDATE(table1ColFloat). + SET( + table1.SELECT(table1ColFloat), + ). WHERE(table1Col1.EQ(Int(2))). RETURNING(table1Col1) - stmtStr, _, err := stmt.Sql() + assertStatement(t, stmt, expectedSql, int64(2)) +} - assert.NilError(t, err) - - fmt.Println(stmtStr) - - assert.Equal(t, stmtStr, ` -UPDATE db.table1 SET (col1,colFloat) = ( +func TestUpdateColumnsWithSelect(t *testing.T) { + expectedSql := ` +UPDATE db.table1 +SET (col1, colFloat) = ( SELECT table1.colFloat AS "table1.colFloat", table2.col3 AS "table2.col3" FROM db.table1 ) WHERE table1.col1 = $1 RETURNING table1.col1 AS "table1.col1"; -`) +` + stmt := table1.UPDATE(table1Col1, table1ColFloat). + SET(table1.SELECT(table1ColFloat, table2Col3)). + WHERE(table1Col1.EQ(Int(2))). + RETURNING(table1Col1) + + assertStatement(t, stmt, expectedSql, int64(2)) } -//func (s *StmtSuite) TestUpdateNilColumn(c *gc.C) { -// stmt := table1.UPDATE().SET(nil, literal(1)) -// _, err := stmt.String() -// c.Assert(err, gc.NotNil) -//} -// -//func (s *StmtSuite) TestUpdateNilExpr(c *gc.C) { -// stmt := table1.UPDATE().SET(table1Col1, nil) -// _, err := stmt.String() -// c.Assert(err, gc.NotNil) -//} -// -//func (s *StmtSuite) TestUpdateUnconditionally(c *gc.C) { -// stmt := table1.UPDATE().SET(table1Col1, literal(1)) -// _, err := stmt.String() -// c.Assert(err, gc.NotNil) -//} -// -//func (s *StmtSuite) TestUpdateSingleValue(c *gc.C) { -// stmt := table1.UPDATE().SET(table1Col1, literal(1)) -// stmt.WHERE(EqString(table1ColFloat, 2)) -// sql, err := stmt.String() -// c.Assert(err, gc.IsNil) -// -// c.Assert( -// sql, -// gc.Equals, -// "UPDATE db.table1 SET table1.col1=1 WHERE table1.col2=2") -//} -// -//func (s *StmtSuite) TestUpdateUsingDeferredLookupColumns(c *gc.C) { -// stmt := table1.UPDATE().SET(table1.C("col1"), literal(1)) -// stmt.WHERE(EqString(table1ColFloat, 2)) -// sql, err := stmt.String() -// c.Assert(err, gc.IsNil) -// -// c.Assert( -// sql, -// gc.Equals, -// "UPDATE db.table1 SET table1.col1=1 WHERE table1.col2=2") -//} -// -//func (s *StmtSuite) TestUpdateMultiValues(c *gc.C) { -// stmt := table1.UPDATE() -// stmt.SET(table1Col1, literal(1)) -// stmt.SET(table1ColFloat, literal(2)) -// stmt.WHERE(EqString(table1ColFloat, 3)) -// sql, err := stmt.String() -// c.Assert(err, gc.IsNil) -// -// c.Assert( -// sql, -// gc.Equals, -// "UPDATE db.table1 "+ -// "SET table1.col1=1, table1.col2=2 "+ -// "WHERE table1.col2=3") -//} -// -//func (s *StmtSuite) TestUpdateWithOrderBy(c *gc.C) { -// stmt := table1.UPDATE().SET(table1Col1, literal(1)) -// stmt.WHERE(EqString(table1ColFloat, 2)) -// stmt.ORDER_BY(table1ColFloat) -// sql, err := stmt.String() -// c.Assert(err, gc.IsNil) -// -// c.Assert( -// sql, -// gc.Equals, -// "UPDATE db.table1 "+ -// "SET table1.col1=1 "+ -// "WHERE table1.col2=2 "+ -// "ORDER BY table1.col2") -//} -// -//func (s *StmtSuite) TestUpdateWithLimit(c *gc.C) { -// stmt := table1.UPDATE().SET(table1Col1, literal(1)) -// stmt.WHERE(EqString(table1ColFloat, 2)) -// stmt.LIMIT(5) -// sql, err := stmt.String() -// c.Assert(err, gc.IsNil) -// -// c.Assert( -// sql, -// gc.Equals, -// "UPDATE db.table1 "+ -// "SET table1.col1=1 "+ -// "WHERE table1.col2=2 "+ -// "LIMIT 5") -//} +func TestInvalidInputs(t *testing.T) { + assertStatementErr(t, table1.UPDATE(table1ColInt).SET(1, 2), "WHERE clause not set") + assertStatementErr(t, table1.UPDATE(nil).SET(1, 2), "nil column in columns list") +} diff --git a/sqlbuilder/utils.go b/sqlbuilder/utils.go index 6400aea..68aeab8 100644 --- a/sqlbuilder/utils.go +++ b/sqlbuilder/utils.go @@ -4,6 +4,7 @@ import ( "database/sql" "errors" "github.com/go-jet/jet/sqlbuilder/execution" + "github.com/serenize/snaker" "reflect" ) @@ -83,7 +84,7 @@ func serializeProjectionList(statement statementType, projections []projection, for i, col := range projections { if i > 0 { out.writeString(",") - out.nextLine() + out.newLine() } if col == nil { @@ -98,14 +99,14 @@ func serializeProjectionList(statement statementType, projections []projection, return nil } -func serializeColumnList(statement statementType, columns []Column, out *queryData) error { +func serializeColumnNames(columns []column, out *queryData) error { for i, col := range columns { if i > 0 { - out.writeByte(',') + out.writeString(", ") } if col == nil { - return errors.New("nil column in columns list.") + return errors.New("nil column in columns list") } out.writeString(col.Name()) @@ -118,6 +119,59 @@ func isNil(v interface{}) bool { return v == nil || (reflect.ValueOf(v).Kind() == reflect.Ptr && reflect.ValueOf(v).IsNil()) } +func valueToClause(value interface{}) clause { + if clause, ok := value.(clause); ok { + return clause + } else { + return literal(value) + } +} + +func unwindRowFromModel(columns []column, data interface{}) []clause { + structValue := reflect.Indirect(reflect.ValueOf(data)) + + row := []clause{} + + if structValue.Kind() != reflect.Struct { + return row + } + + for _, column := range columns { + columnName := column.Name() + structFieldName := snaker.SnakeToCamel(columnName) + + structField := structValue.FieldByName(structFieldName) + + if !structField.IsValid() { + continue + } + + var field interface{} + + if structField.Kind() == reflect.Ptr && structField.IsNil() { + field = nil + } else { + field = reflect.Indirect(structField).Interface() + } + + row = append(row, literal(field)) + } + + return row +} + +func unwindRowFromValues(value interface{}, values []interface{}) []clause { + row := []clause{} + + allValues := append([]interface{}{value}, values...) + + for _, val := range allValues { + row = append(row, valueToClause(val)) + } + + return row +} + func columnListToProjectionList(columns []Column) []projection { var ret []projection @@ -138,7 +192,7 @@ func Query(statement Statement, db execution.Db, destination interface{}) error return execution.Query(db, query, args, destination) } -func Execute(statement Statement, db execution.Db) (res sql.Result, err error) { +func Exec(statement Statement, db execution.Db) (res sql.Result, err error) { query, args, err := statement.Sql() if err != nil { diff --git a/tests/all_types_test.go b/tests/all_types_test.go index 4c91385..770c079 100644 --- a/tests/all_types_test.go +++ b/tests/all_types_test.go @@ -2,7 +2,6 @@ package tests import ( "fmt" - "github.com/davecgh/go-spew/spew" . "github.com/go-jet/jet/sqlbuilder" "github.com/go-jet/jet/tests/.test_files/dvd_rental/test_sample/model" . "github.com/go-jet/jet/tests/.test_files/dvd_rental/test_sample/table" @@ -26,16 +25,32 @@ func TestAllTypesSelect(t *testing.T) { assert.DeepEqual(t, dest[1], allTypesRow1) } -func TestAllTypesInsert(t *testing.T) { - query := AllTypes.INSERT(AllTypes.AllColumns...). - MODEL(allTypesRow0). - MODEL(&allTypesRow1). +func TestAllTypesInsertModel(t *testing.T) { + query := AllTypes.INSERT(AllTypes.AllColumns). + USING(allTypesRow0). + USING(&allTypesRow1). RETURNING(AllTypes.AllColumns) dest := []model.AllTypes{} err := query.Query(db, &dest) - spew.Dump(dest[0]) + assert.NilError(t, err) + assert.Equal(t, len(dest), 2) + assert.DeepEqual(t, dest[0], allTypesRow0) + assert.DeepEqual(t, dest[1], allTypesRow1) +} + +func TestAllTypesInsertQuery(t *testing.T) { + query := AllTypes.INSERT(AllTypes.AllColumns). + QUERY( + AllTypes. + SELECT(AllTypes.AllColumns). + LIMIT(2), + ). + RETURNING(AllTypes.AllColumns) + + dest := []model.AllTypes{} + err := query.Query(db, &dest) assert.NilError(t, err) assert.Equal(t, len(dest), 2) diff --git a/tests/init/data/test_sample.sql b/tests/init/data/test_sample.sql index 221122d..3afedab 100644 --- a/tests/init/data/test_sample.sql +++ b/tests/init/data/test_sample.sql @@ -149,6 +149,9 @@ CREATE TABLE IF NOT EXISTS test_sample.link ( rel VARCHAR (50) ); +INSERT INTO test_sample.link (ID, url, name, description) VALUES + (0, 'http://www.youtube.com', 'Youtube' , ''); + -- Employee table --------------- diff --git a/tests/insert_test.go b/tests/insert_test.go index d6c834c..184a977 100644 --- a/tests/insert_test.go +++ b/tests/insert_test.go @@ -1,8 +1,6 @@ package tests import ( - "fmt" - "github.com/davecgh/go-spew/spew" . "github.com/go-jet/jet/sqlbuilder" "github.com/go-jet/jet/tests/.test_files/dvd_rental/test_sample/model" . "github.com/go-jet/jet/tests/.test_files/dvd_rental/test_sample/table" @@ -11,61 +9,79 @@ import ( ) func TestInsertValues(t *testing.T) { - insertQuery := Link.INSERT(Link.URL, Link.Name, Link.Rel). - VALUES("http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). - VALUES("http://www.google.com", "Google", DEFAULT). - VALUES("http://www.yahoo.com", "Yahoo", DEFAULT). - VALUES("http://www.bing.com", "Bing", DEFAULT). - RETURNING(Link.ID) - insertQueryStr, args, err := insertQuery.Sql() + cleanUpLinkTable(t) - assert.NilError(t, err) - assert.Equal(t, len(args), 8) + var expectedSql = ` +INSERT INTO test_sample.link (id, url, name, rel) VALUES + (100, 'http://www.postgresqltutorial.com', 'PostgreSQL Tutorial', DEFAULT), + (101, 'http://www.google.com', 'Google', DEFAULT), + (102, 'http://www.yahoo.com', 'Yahoo', NULL) +RETURNING link.id AS "link.id", + link.url AS "link.url", + link.name AS "link.name", + link.description AS "link.description", + link.rel AS "link.rel"; +` - fmt.Println(insertQueryStr) + insertQuery := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Rel). + VALUES(100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). + VALUES(101, "http://www.google.com", "Google", DEFAULT). + VALUES(102, "http://www.yahoo.com", "Yahoo", nil). + RETURNING(Link.AllColumns) - assert.Equal(t, insertQueryStr, ` -INSERT INTO test_sample.link (url,name,rel) VALUES - ($1, $2, DEFAULT), - ($3, $4, DEFAULT), - ($5, $6, DEFAULT), - ($7, $8, DEFAULT) -RETURNING link.id AS "link.id"; -`) - res, err := insertQuery.Execute(db) + assertStatementSql(t, insertQuery, expectedSql, + 100, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", + 101, "http://www.google.com", "Google", + 102, "http://www.yahoo.com", "Yahoo", nil) + + insertedLinks := []model.Link{} + + err := insertQuery.Query(db, &insertedLinks) assert.NilError(t, err) - rowsAffected, err := res.RowsAffected() - assert.NilError(t, err) + assert.Equal(t, len(insertedLinks), 3) - assert.Equal(t, rowsAffected, int64(4)) - - link := []model.Link{} - - err = Link.SELECT(Link.AllColumns).Query(db, &link) - - assert.NilError(t, err) - - assert.Equal(t, len(link), 4) - - assert.DeepEqual(t, link[0], model.Link{ - ID: 1, + assert.DeepEqual(t, insertedLinks[0], model.Link{ + ID: 100, URL: "http://www.postgresqltutorial.com", Name: "PostgreSQL Tutorial", Rel: nil, }) - assert.DeepEqual(t, link[3], model.Link{ - ID: 4, - URL: "http://www.bing.com", - Name: "Bing", + assert.DeepEqual(t, insertedLinks[1], model.Link{ + ID: 101, + URL: "http://www.google.com", + Name: "Google", Rel: nil, }) + + assert.DeepEqual(t, insertedLinks[2], model.Link{ + ID: 102, + URL: "http://www.yahoo.com", + Name: "Yahoo", + Rel: nil, + }) + + allLinks := []model.Link{} + + err = Link.SELECT(Link.AllColumns). + WHERE(Link.ID.GT_EQ(Int(100))). + ORDER_BY(Link.ID). + Query(db, &allLinks) + + assert.NilError(t, err) + + assert.DeepEqual(t, insertedLinks, allLinks) } func TestInsertDataObject(t *testing.T) { + var expectedSql = ` +INSERT INTO test_sample.link (url, name) VALUES + ('http://www.duckduckgo.com', 'Duck Duck go'); +` + linkData := model.Link{ URL: "http://www.duckduckgo.com", Name: "Duck Duck go", @@ -74,47 +90,62 @@ func TestInsertDataObject(t *testing.T) { query := Link. INSERT(Link.URL, Link.Name). - MODEL(linkData) + USING(linkData) - queryStr, args, err := query.Sql() + assertStatementSql(t, query, expectedSql, "http://www.duckduckgo.com", "Duck Duck go") - assert.NilError(t, err) - assert.Equal(t, len(args), 2) - - fmt.Println(queryStr) - - result, err := query.Execute(db) + result, err := query.Exec(db) assert.NilError(t, err) - fmt.Println(result) + rowsAffected, err := result.RowsAffected() + + assert.Equal(t, rowsAffected, int64(1)) } func TestInsertQuery(t *testing.T) { - - _, err := Link.INSERT(Link.URL, Link.Name). - VALUES("http://www.postgresqltutorial.com", "PostgreSQL Tutorial").Execute(db) - + _, err := Link.DELETE(). + WHERE(Link.ID.NOT_EQ(Int(0)).AND(Link.Name.EQ(String("Youtube")))). + Exec(db) assert.NilError(t, err) + var expectedSql = ` +INSERT INTO test_sample.link (url, name) ( + SELECT link.url AS "link.url", + link.name AS "link.name" + FROM test_sample.link + WHERE link.id = 0 +) +RETURNING link.id AS "link.id", + link.url AS "link.url", + link.name AS "link.name", + link.description AS "link.description", + link.rel AS "link.rel"; +` + query := Link. INSERT(Link.URL, Link.Name). - QUERY(Link.SELECT(Link.URL, Link.Name)) + QUERY( + SELECT(Link.URL, Link.Name). + FROM(Link). + WHERE(Link.ID.EQ(Int(0))), + ). + RETURNING(Link.AllColumns) - queryStr, args, err := query.Sql() + assertStatementSql(t, query, expectedSql, int64(0)) - assert.NilError(t, err) - assert.Equal(t, len(args), 0) + dest := []model.Link{} - fmt.Println(queryStr) - - _, err = query.Execute(db) + err = query.Query(db, &dest) assert.NilError(t, err) - allLinks := []model.Link{} - err = Link.SELECT(Link.AllColumns).Query(db, &allLinks) - assert.NilError(t, err) + youtubeLinks := []model.Link{} + err = Link. + SELECT(Link.AllColumns). + WHERE(Link.Name.EQ(String("Youtube"))). + Query(db, &youtubeLinks) - spew.Dump(allLinks) + assert.NilError(t, err) + assert.Equal(t, len(youtubeLinks), 2) } diff --git a/tests/select_test.go b/tests/select_test.go index 3d9f95b..fa9490f 100644 --- a/tests/select_test.go +++ b/tests/select_test.go @@ -27,7 +27,7 @@ WHERE actor.actor_id = 1; SELECT(Actor.AllColumns). WHERE(Actor.ActorID.EQ(Int(1))) - assertQuery(t, query, expectedSql, int64(1)) + assertStatementSql(t, query, expectedSql, int64(1)) actor := model.Actor{} err := query.Query(db, &actor) @@ -73,7 +73,7 @@ LIMIT 30; ORDER_BY(Payment.PaymentID.ASC()). LIMIT(30) - assertQuery(t, query, expectedSql, int64(30)) + assertStatementSql(t, query, expectedSql, int64(30)) dest := []model.Payment{} @@ -102,7 +102,7 @@ ORDER BY customer.customer_id ASC; query := Customer.SELECT(Customer.AllColumns).ORDER_BY(Customer.CustomerID.ASC()) - assertQuery(t, query, expectedSql) + assertStatementSql(t, query, expectedSql) err := query.Query(db, &customers) assert.NilError(t, err) @@ -156,7 +156,7 @@ LIMIT 12; LIMIT(12) fmt.Println(query.Sql()) - assertQuery(t, query, expectedSql, int64(1), int64(1), int64(10), int64(1), int64(2), int64(1), int64(12)) + assertStatementSql(t, query, expectedSql, int64(1), int64(1), int64(10), int64(1), int64(2), int64(1), int64(12)) } func TestJoinQueryStruct(t *testing.T) { @@ -224,7 +224,7 @@ LIMIT 500; ORDER_BY(Film.FilmID.ASC()). LIMIT(500) - assertQuery(t, query, expectedSql, int64(500)) + assertStatementSql(t, query, expectedSql, int64(500)) var languageActorFilm []struct { model.Language @@ -291,7 +291,7 @@ LIMIT 15; WHERE(Film.Rating.EQ(enum.MpaaRating.NC17)). LIMIT(15) - assertQuery(t, query, expectedSql, int64(15)) + assertStatementSql(t, query, expectedSql, int64(15)) err := query.Query(db, &filmsPerLanguage) @@ -422,7 +422,7 @@ ORDER BY customer.customer_id ASC; SELECT(Customer.AllColumns, Address.AllColumns). ORDER_BY(Customer.CustomerID.ASC()) - assertQuery(t, query, expectedSql) + assertStatementSql(t, query, expectedSql) allCustomersAndAddress := []struct { Address *model.Address @@ -475,7 +475,7 @@ LIMIT 1000; ORDER_BY(Customer.CustomerID.ASC()). LIMIT(1000) - assertQuery(t, query, expectedSql, int64(1000)) + assertStatementSql(t, query, expectedSql, int64(1000)) var customerAddresCrosJoined []struct { model.Customer @@ -513,7 +513,7 @@ ORDER BY employee.employee_id; SELECT(Employee.AllColumns, manager.AllColumns). ORDER_BY(Employee.EmployeeID) - assertQuery(t, query, expectedSql) + assertStatementSql(t, query, expectedSql) var dest []struct { model2.Employee @@ -585,7 +585,7 @@ ORDER BY f1.film_id ASC; SELECT(f1.AllColumns, f2.AllColumns). ORDER_BY(f1.FilmID.ASC()) - assertQuery(t, query, expectedSql) + assertStatementSql(t, query, expectedSql) type F1 model.Film type F2 model.Film @@ -627,7 +627,7 @@ LIMIT 1000; ORDER_BY(f1.Length.ASC(), f1.Title.ASC(), f2.Title.ASC()). LIMIT(1000) - assertQuery(t, query, expectedSql, int64(1000)) + assertStatementSql(t, query, expectedSql, int64(1000)) type thesameLengthFilms struct { Title1 string @@ -691,7 +691,7 @@ LIMIT 1000; // SELECT(Staff.StaffID, Staff.FirstName, Staff.LastName, Address.AllColumns, manager.StaffID, manager.FirstName). // DISTINCT() // -// assertQuery(t, query, expectedSql) +// assertStatementSql(t, query, expectedSql) // // staffs := []staff{} // @@ -747,7 +747,7 @@ FROM dvds.actor fmt.Println(query.Sql()) - assertQuery(t, query, expectedQuery) + assertStatementSql(t, query, expectedQuery) dest := []model.Actor{} @@ -765,7 +765,7 @@ FROM dvds.film; MAXf(Film.RentalRate).AS("max_film_rate"), ) - assertQuery(t, query, expectedQuery) + assertStatementSql(t, query, expectedQuery) ret := struct { MaxFilmRate float64 @@ -808,7 +808,7 @@ ORDER BY film.film_id ASC; ORDER_BY(Film.FilmID.ASC()) fmt.Println(query.Sql()) - assertQuery(t, query, expectedSql) + assertStatementSql(t, query, expectedSql) maxRentalRateFilms := []model.Film{} err := query.Query(db, &maxRentalRateFilms) @@ -866,7 +866,7 @@ ORDER BY SUM(payment.amount) ASC; SUMf(Payment.Amount).GT(Float(100)), ) - assertQuery(t, customersPaymentQuery, expectedSql, float64(100)) + assertStatementSql(t, customersPaymentQuery, expectedSql, float64(100)) type CustomerPaymentSum struct { CustomerID int16 @@ -936,7 +936,7 @@ ORDER BY customer_payment_sum.amount_sum ASC; ). ORDER_BY(amountSum.ASC()) - assertQuery(t, query, expectedSql) + assertStatementSql(t, query, expectedSql) type CustomerWithAmounts struct { Customer *model.Customer @@ -992,7 +992,7 @@ ORDER BY payment.payment_date ASC; WHERE(Payment.PaymentDate.LT(Timestamp(2007, 02, 14, 22, 16, 01, 0))). ORDER_BY(Payment.PaymentDate.ASC()) - assertQuery(t, query, expectedSql, "2007-02-14 22:16:01.000") + assertStatementSql(t, query, expectedSql, "2007-02-14 22:16:01.000") payments := []model.Payment{} @@ -1049,7 +1049,7 @@ OFFSET 20; queryStr, _, _ := query.Sql() fmt.Println("-" + queryStr + "-") - assertQuery(t, query, expectedQuery, float64(100), float64(200), int64(10), int64(20)) + assertStatementSql(t, query, expectedQuery, float64(100), float64(200), int64(10), int64(20)) dest := []model.Payment{} @@ -1088,7 +1088,7 @@ LIMIT 20; ORDER_BY(Payment.PaymentID.ASC()). LIMIT(20) - assertQuery(t, query, expectedQuery, int64(1), "ONE", int64(2), "TWO", int64(3), "THREE", "OTHER", int64(20)) + assertStatementSql(t, query, expectedQuery, int64(1), "ONE", int64(2), "TWO", int64(3), "THREE", "OTHER", int64(20)) dest := []struct { StaffIdNum string @@ -1111,11 +1111,11 @@ LOCK TABLE dvds.address IN EXCLUSIVE MODE NOWAIT; querySql, _, _ := query.Sql() fmt.Println("-" + querySql + "-") - assertQuery(t, query, expectedSql) + assertStatementSql(t, query, expectedSql) tx, _ := db.Begin() - _, err := query.Execute(tx) + _, err := query.Exec(tx) assert.NilError(t, err) } diff --git a/tests/test_util.go b/tests/test_util.go index c94d74e..0d682fa 100644 --- a/tests/test_util.go +++ b/tests/test_util.go @@ -10,7 +10,7 @@ import ( "time" ) -func assertQuery(t *testing.T, query sqlbuilder.Statement, expectedQuery string, expectedArgs ...interface{}) { +func assertStatementSql(t *testing.T, query sqlbuilder.Statement, expectedQuery string, expectedArgs ...interface{}) { _, args, err := query.Sql() assert.NilError(t, err) //assert.Equal(t, queryStr, expectedQuery) @@ -21,6 +21,20 @@ func assertQuery(t *testing.T, query sqlbuilder.Statement, expectedQuery string, assert.Equal(t, debuqSql, expectedQuery) } +func assertExec(t *testing.T, stmt sqlbuilder.Statement, rowsAffected int64) { + res, err := stmt.Exec(db) + + assert.NilError(t, err) + rows, err := res.RowsAffected() + assert.NilError(t, err) + assert.Equal(t, rows, rowsAffected) +} + +func assertExecErr(t *testing.T, stmt sqlbuilder.Statement, errorStr string) { + _, err := stmt.Exec(db) + + assert.Equal(t, err.Error(), errorStr) +} func boolPtr(b bool) *bool { return &b } diff --git a/tests/update_test.go b/tests/update_test.go index 728d660..26997f1 100644 --- a/tests/update_test.go +++ b/tests/update_test.go @@ -4,80 +4,261 @@ import ( "fmt" . "github.com/go-jet/jet/sqlbuilder" "github.com/go-jet/jet/tests/.test_files/dvd_rental/test_sample/model" - "github.com/go-jet/jet/tests/.test_files/dvd_rental/test_sample/table" + . "github.com/go-jet/jet/tests/.test_files/dvd_rental/test_sample/table" "gotest.tools/assert" "testing" ) func TestUpdateValues(t *testing.T) { - _, err := table.Link.INSERT(table.Link.URL, table.Link.Name, table.Link.Rel). - VALUES("http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). - VALUES("http://www.yahoo.com", "Yahoo", DEFAULT). - VALUES("http://www.bing.com", "Bing", DEFAULT). - RETURNING(table.Link.ID).Execute(db) + setupLinkTableForUpdateTest(t) - assert.NilError(t, err) - - query := table.Link. - UPDATE(table.Link.Name, table.Link.URL). + query := Link. + UPDATE(Link.Name, Link.URL). SET("Bong", "http://bong.com"). - WHERE(table.Link.Name.EQ(String("Bing"))) + WHERE(Link.Name.EQ(String("Bing"))) - queryStr, args, err := query.Sql() + fmt.Println(query.DebugSql()) - assert.NilError(t, err) - assert.Equal(t, len(args), 3) - fmt.Println(queryStr) + var expectedSql = ` +UPDATE test_sample.link +SET (name, url) = ('Bong', 'http://bong.com') +WHERE link.name = 'Bing'; +` + fmt.Println(query.Sql()) - res, err := query.Execute(db) + assertStatementSql(t, query, expectedSql, "Bong", "http://bong.com", "Bing") - assert.NilError(t, err) - - fmt.Println(res) + assertExec(t, query, 1) links := []model.Link{} - err = table.Link.SELECT(table.Link.AllColumns). - WHERE(table.Link.Name.EQ(String("Bong"))). + err := Link. + SELECT(Link.AllColumns). + WHERE(Link.Name.EQ(String("Bong"))). Query(db, &links) assert.NilError(t, err) + assert.Equal(t, len(links), 1) + assert.DeepEqual(t, links[0], model.Link{ + ID: 204, + URL: "http://bong.com", + Name: "Bong", + }) +} - //spew.Dump(links) +func TestUpdateWithSubQueries(t *testing.T) { + setupLinkTableForUpdateTest(t) + + query := Link. + UPDATE(Link.Name, Link.URL). + SET( + SELECT(String("Bong")), + SELECT(Link.URL). + FROM(Link). + WHERE(Link.Name.EQ(String("Bing"))), + ). + WHERE(Link.Name.EQ(String("Bing"))) + + expectedSql := ` +UPDATE test_sample.link +SET (name, url) = (( + SELECT 'Bong' +), ( + SELECT link.url AS "link.url" + FROM test_sample.link + WHERE link.name = 'Bing' +)) +WHERE link.name = 'Bing'; +` + + assertStatementSql(t, query, expectedSql, "Bong", "Bing", "Bing") + + assertExec(t, query, 1) } func TestUpdateAndReturning(t *testing.T) { - _, err := table.Link.INSERT(table.Link.URL, table.Link.Name, table.Link.Rel). - VALUES("http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). - VALUES("http://www.ask.com", "Ask", DEFAULT). - VALUES("http://www.ask.com", "Ask", DEFAULT). - VALUES("http://www.yahoo.com", "Yahoo", DEFAULT). - VALUES("http://www.bing.com", "Bing", DEFAULT). - RETURNING(table.Link.ID).Execute(db) + setupLinkTableForUpdateTest(t) - assert.NilError(t, err) + expectedSql := ` +UPDATE test_sample.link +SET (name, url) = ('DuckDuckGo', 'http://www.duckduckgo.com') +WHERE link.name = 'Ask' +RETURNING link.id AS "link.id", + link.url AS "link.url", + link.name AS "link.name", + link.description AS "link.description", + link.rel AS "link.rel"; +` - stmt := table.Link. - UPDATE(table.Link.Name, table.Link.URL). + stmt := Link. + UPDATE(Link.Name, Link.URL). SET("DuckDuckGo", "http://www.duckduckgo.com"). - WHERE(table.Link.Name.EQ(String("Ask"))). - RETURNING(table.Link.AllColumns) + WHERE(Link.Name.EQ(String("Ask"))). + RETURNING(Link.AllColumns) - stmtStr, args, err := stmt.Sql() - - assert.NilError(t, err) - assert.Equal(t, len(args), 3) - fmt.Println(stmtStr) + assertStatementSql(t, stmt, expectedSql, "DuckDuckGo", "http://www.duckduckgo.com", "Ask") links := []model.Link{} - err = stmt.Query(db, &links) + err := stmt.Query(db, &links) assert.NilError(t, err) - assert.Equal(t, len(links), 2) - assert.Equal(t, links[0].Name, "DuckDuckGo") - assert.Equal(t, links[1].Name, "DuckDuckGo") } + +func TestUpdateWithSelect(t *testing.T) { + + stmt := Link.UPDATE(Link.AllColumns). + SET( + Link. + SELECT(Link.AllColumns). + WHERE(Link.ID.EQ(Int(0))), + ). + WHERE(Link.ID.EQ(Int(0))) + + expectedSql := ` +UPDATE test_sample.link +SET (id, url, name, description, rel) = ( + SELECT link.id AS "link.id", + link.url AS "link.url", + link.name AS "link.name", + link.description AS "link.description", + link.rel AS "link.rel" + FROM test_sample.link + WHERE link.id = 0 +) +WHERE link.id = 0; +` + assertStatementSql(t, stmt, expectedSql, int64(0), int64(0)) + + assertExec(t, stmt, 1) +} + +func TestUpdateWithInvalidSelect(t *testing.T) { + + stmt := Link.UPDATE(Link.AllColumns). + SET( + Link. + SELECT(Link.ID, Link.Name). + WHERE(Link.ID.EQ(Int(0))), + ). + WHERE(Link.ID.EQ(Int(0))) + + var expectedSql = ` +UPDATE test_sample.link +SET (id, url, name, description, rel) = ( + SELECT link.id AS "link.id", + link.name AS "link.name" + FROM test_sample.link + WHERE link.id = 0 +) +WHERE link.id = 0; +` + assertStatementSql(t, stmt, expectedSql, int64(0), int64(0)) + + assertExecErr(t, stmt, "pq: number of columns does not match number of values") +} + +func TestUpdateWithModelData(t *testing.T) { + setupLinkTableForUpdateTest(t) + + link := model.Link{ + ID: 201, + URL: "http://www.duckduckgo.com", + Name: "DuckDuckGo", + } + + stmt := Link. + UPDATE(Link.AllColumns). + USING(link). + WHERE(Link.ID.EQ(Int(int64(link.ID)))) + + expectedSql := ` +UPDATE test_sample.link +SET (id, url, name, description, rel) = (201, 'http://www.duckduckgo.com', 'DuckDuckGo', NULL, NULL) +WHERE link.id = 201; +` + assertStatementSql(t, stmt, expectedSql, int32(201), "http://www.duckduckgo.com", "DuckDuckGo", nil, nil, int64(201)) + + assertExec(t, stmt, 1) +} + +func TestUpdateWithModelDataAndPredefinedColumnList(t *testing.T) { + + setupLinkTableForUpdateTest(t) + + link := model.Link{ + ID: 201, + URL: "http://www.duckduckgo.com", + Name: "DuckDuckGo", + } + + updateColumnList := ColumnList{Link.Rel, Link.Name, Link.URL} + + stmt := Link. + UPDATE(updateColumnList). + USING(link). + WHERE(Link.ID.EQ(Int(int64(link.ID)))) + + var expectedSql = ` +UPDATE test_sample.link +SET (rel, name, url) = (NULL, 'DuckDuckGo', 'http://www.duckduckgo.com') +WHERE link.id = 201; +` + assertStatementSql(t, stmt, expectedSql, nil, "DuckDuckGo", "http://www.duckduckgo.com", int64(201)) + + assertExec(t, stmt, 1) +} + +func TestUpdateWithInvalidModelData(t *testing.T) { + + setupLinkTableForUpdateTest(t) + + link := struct { + Ident int + URL string + Name string + Description *string + Rel *string + }{ + Ident: 201, + URL: "http://www.duckduckgo.com", + Name: "DuckDuckGo", + } + + stmt := Link. + UPDATE(Link.AllColumns). + USING(link). + WHERE(Link.ID.EQ(Int(int64(link.Ident)))) + + var expectedSql = ` +UPDATE test_sample.link +SET (id, url, name, description, rel) = ('http://www.duckduckgo.com', 'DuckDuckGo', NULL, NULL) +WHERE link.id = 201; +` + assertStatementSql(t, stmt, expectedSql, "http://www.duckduckgo.com", "DuckDuckGo", nil, nil, int64(201)) + + assertExecErr(t, stmt, "pq: number of columns does not match number of values") +} + +func setupLinkTableForUpdateTest(t *testing.T) { + + cleanUpLinkTable(t) + + _, err := Link.INSERT(Link.ID, Link.URL, Link.Name, Link.Rel). + VALUES(200, "http://www.postgresqltutorial.com", "PostgreSQL Tutorial", DEFAULT). + VALUES(201, "http://www.ask.com", "Ask", DEFAULT). + VALUES(202, "http://www.ask.com", "Ask", DEFAULT). + VALUES(203, "http://www.yahoo.com", "Yahoo", DEFAULT). + VALUES(204, "http://www.bing.com", "Bing", DEFAULT). + Exec(db) + + assert.NilError(t, err) +} + +func cleanUpLinkTable(t *testing.T) { + _, err := Link.DELETE().WHERE(Link.ID.GT(Int(0))).Exec(db) + assert.NilError(t, err) +}