Insert and Update statement improvements.

This commit is contained in:
go-jet 2019-06-14 14:35:50 +02:00
parent 038a4b9dd0
commit a4feb66692
22 changed files with 660 additions and 453 deletions

View file

@ -70,7 +70,7 @@ func (q *queryData) writeProjections(statement statementType, projections []proj
}
func (q *queryData) writeFrom(statement statementType, table ReadableTable) error {
q.nextLine()
q.newLine()
q.writeString("FROM")
q.increaseIdent()
@ -81,7 +81,7 @@ func (q *queryData) writeFrom(statement statementType, table ReadableTable) erro
}
func (q *queryData) writeWhere(statement statementType, where Expression) error {
q.nextLine()
q.newLine()
q.writeString("WHERE")
q.increaseIdent()
@ -92,7 +92,7 @@ func (q *queryData) writeWhere(statement statementType, where Expression) error
}
func (q *queryData) writeGroupBy(statement statementType, groupBy []groupByClause) error {
q.nextLine()
q.newLine()
q.writeString("GROUP BY")
q.increaseIdent()
@ -103,7 +103,7 @@ func (q *queryData) writeGroupBy(statement statementType, groupBy []groupByClaus
}
func (q *queryData) writeOrderBy(statement statementType, orderBy []OrderByClause) error {
q.nextLine()
q.newLine()
q.writeString("ORDER BY")
q.increaseIdent()
@ -114,7 +114,7 @@ func (q *queryData) writeOrderBy(statement statementType, orderBy []OrderByClaus
}
func (q *queryData) writeHaving(statement statementType, having Expression) error {
q.nextLine()
q.newLine()
q.writeString("HAVING")
q.increaseIdent()
@ -124,7 +124,7 @@ func (q *queryData) writeHaving(statement statementType, having Expression) erro
return err
}
func (q *queryData) nextLine() {
func (q *queryData) newLine() {
q.write([]byte{'\n'})
q.write(bytes.Repeat([]byte{' '}, q.ident))
}

View file

@ -112,3 +112,27 @@ func (c columnImpl) serialize(statement statementType, out *queryData, options .
return nil
}
//------------------------------------------------------//
// Dummy type for select * AllColumns
type ColumnList []Column
// projection interface implementation
func (cl ColumnList) isProjectionType() {}
func (cl ColumnList) serializeForProjection(statement statementType, out *queryData) error {
projections := columnListToProjectionList(cl)
err := serializeProjectionList(statement, projections, out)
if err != nil {
return err
}
return nil
}
// column interface implementation
func (cl ColumnList) Name() string { return "" }
func (cl ColumnList) TableName() string { return "" }
func (cl ColumnList) setTableName(name string) {}

View file

@ -32,7 +32,7 @@ func (d *deleteStatementImpl) serializeImpl(out *queryData) error {
if d == nil {
return errors.New("Delete expression. ")
}
out.nextLine()
out.newLine()
out.writeString("DELETE FROM")
if d.table == nil {
@ -75,6 +75,6 @@ func (d *deleteStatementImpl) Query(db execution.Db, destination interface{}) er
return Query(d, db, destination)
}
func (d *deleteStatementImpl) Execute(db execution.Db) (res sql.Result, err error) {
return Execute(d, db)
func (d *deleteStatementImpl) Exec(db execution.Db) (res sql.Result, err error) {
return Exec(d, db)
}

View file

@ -4,8 +4,6 @@ import (
"database/sql"
"errors"
"github.com/go-jet/jet/sqlbuilder/execution"
"github.com/serenize/snaker"
"reflect"
"strings"
)
@ -13,16 +11,16 @@ type InsertStatement interface {
Statement
// Add a row of values to the insert Statement.
VALUES(values ...interface{}) InsertStatement
VALUES(value interface{}, values ...interface{}) InsertStatement
// Model structure mapped to column names
MODEL(data interface{}) InsertStatement
USING(data interface{}) InsertStatement
QUERY(selectStatement SelectStatement) InsertStatement
RETURNING(projections ...projection) InsertStatement
}
func newInsertStatement(t WritableTable, columns ...Column) InsertStatement {
func newInsertStatement(t WritableTable, columns []column) InsertStatement {
return &insertStatementImpl{
table: t,
columns: columns,
@ -31,7 +29,7 @@ func newInsertStatement(t WritableTable, columns ...Column) InsertStatement {
type insertStatementImpl struct {
table WritableTable
columns []Column
columns []column
rows [][]clause
query SelectStatement
returning []projection
@ -39,74 +37,13 @@ type insertStatementImpl struct {
errors []string
}
func (i *insertStatementImpl) Query(db execution.Db, destination interface{}) error {
return Query(i, db, destination)
}
func (i *insertStatementImpl) Execute(db execution.Db) (res sql.Result, err error) {
return Execute(i, db)
}
func (i *insertStatementImpl) VALUES(values ...interface{}) InsertStatement {
if len(values) == 0 {
return i
}
literalRow := []clause{}
for _, value := range values {
if clause, ok := value.(clause); ok {
literalRow = append(literalRow, clause)
} else {
literalRow = append(literalRow, literal(value))
}
}
i.rows = append(i.rows, literalRow)
func (i *insertStatementImpl) VALUES(value interface{}, values ...interface{}) InsertStatement {
i.rows = append(i.rows, unwindRowFromValues(value, values))
return i
}
func (i *insertStatementImpl) MODEL(data interface{}) InsertStatement {
if data == nil {
i.addError("MODEL : data is nil.")
return i
}
value := reflect.Indirect(reflect.ValueOf(data))
if value.Kind() != reflect.Struct {
i.addError("MODEL : data is not struct or pointer to struct.")
return i
}
rowValues := []clause{}
for _, column := range i.columns {
columnName := column.Name()
structFieldName := snaker.SnakeToCamel(columnName)
structField := value.FieldByName(structFieldName)
if !structField.IsValid() {
i.addError("MODEL : Data structure doesn't contain field for column " + columnName)
return i
}
var field interface{}
fieldValue := reflect.Indirect(structField)
if fieldValue.IsValid() {
field = fieldValue.Interface()
} else {
field = nil
}
rowValues = append(rowValues, literal(field))
}
i.rows = append(i.rows, rowValues)
func (i *insertStatementImpl) USING(data interface{}) InsertStatement {
i.rows = append(i.rows, unwindRowFromModel(i.columns, data))
return i
}
@ -135,7 +72,7 @@ func (i *insertStatementImpl) Sql() (sql string, args []interface{}, err error)
queryData := &queryData{}
queryData.nextLine()
queryData.newLine()
queryData.writeString("INSERT INTO")
if isNil(i.table) {
@ -151,7 +88,7 @@ func (i *insertStatementImpl) Sql() (sql string, args []interface{}, err error)
if len(i.columns) > 0 {
queryData.writeString("(")
err = serializeColumnList(insert_statement, i.columns, queryData)
err = serializeColumnNames(i.columns, queryData)
if err != nil {
return
@ -177,7 +114,7 @@ func (i *insertStatementImpl) Sql() (sql string, args []interface{}, err error)
}
queryData.increaseIdent()
queryData.nextLine()
queryData.newLine()
queryData.writeString("(")
if len(row) != len(i.columns) {
@ -204,7 +141,7 @@ func (i *insertStatementImpl) Sql() (sql string, args []interface{}, err error)
}
if len(i.returning) > 0 {
queryData.nextLine()
queryData.newLine()
queryData.writeString("RETURNING")
err = queryData.writeProjections(insert_statement, i.returning)
@ -218,3 +155,11 @@ func (i *insertStatementImpl) Sql() (sql string, args []interface{}, err error)
return
}
func (i *insertStatementImpl) Query(db execution.Db, destination interface{}) error {
return Query(i, db, destination)
}
func (i *insertStatementImpl) Exec(db execution.Db) (res sql.Result, err error) {
return Exec(i, db)
}

View file

@ -1,18 +1,11 @@
package sqlbuilder
import (
"fmt"
"gotest.tools/assert"
"testing"
"time"
)
func TestInsertNoColumn(t *testing.T) {
_, _, err := table1.INSERT().VALUES().Sql()
assert.Assert(t, err != nil)
}
func TestInsertNoRow(t *testing.T) {
_, _, err := table1.INSERT(table1Col1).Sql()
@ -72,10 +65,8 @@ func TestInsertMultipleValues(t *testing.T) {
sql, _, err := stmt.Sql()
assert.NilError(t, err)
fmt.Println(sql)
expectedSql := `
INSERT INTO db.table1 (col1,colFloat,col3) VALUES
INSERT INTO db.table1 (col1, colFloat, col3) VALUES
($1, $2, $3);
`
@ -91,10 +82,8 @@ func TestInsertMultipleRows(t *testing.T) {
sql, _, err := stmt.Sql()
assert.NilError(t, err)
fmt.Println(sql)
expectedSql := `
INSERT INTO db.table1 (col1,colFloat) VALUES
INSERT INTO db.table1 (col1, colFloat) VALUES
($1, $2),
($3, $4),
($5, $6);
@ -117,16 +106,16 @@ func TestInsertValuesFromModel(t *testing.T) {
}
stmt := table1.INSERT(table1Col1, table1ColFloat).
MODEL(toInsert).
MODEL(&toInsert)
USING(toInsert).
USING(&toInsert)
expectedSql := `
INSERT INTO db.table1 (col1,colFloat) VALUES
INSERT INTO db.table1 (col1, colFloat) VALUES
($1, $2),
($3, $4);
`
assertQuery(t, stmt, expectedSql, int(1), float64(1.11), int(1), float64(1.11))
assertStatement(t, stmt, expectedSql, int(1), float64(1.11), int(1), float64(1.11))
}
func TestInsertValuesFromModelColumnMismatch(t *testing.T) {
@ -141,11 +130,10 @@ func TestInsertValuesFromModelColumnMismatch(t *testing.T) {
}
stmt := table1.INSERT(table1Col1, table1ColFloat).
MODEL(toInsert)
USING(toInsert)
_, _, err := stmt.Sql()
fmt.Println(err)
assert.Assert(t, err != nil)
}
@ -154,20 +142,23 @@ func TestInsertQuery(t *testing.T) {
stmt := table1.INSERT(table1Col1).
QUERY(table1.SELECT(table1Col1))
stmtStr, _, err := stmt.Sql()
assert.NilError(t, err)
fmt.Println(stmtStr)
var expectedSql = `
INSERT INTO db.table1 (col1) (
SELECT table1.col1 AS "table1.col1"
FROM db.table1
);
`
assertStatement(t, stmt, expectedSql)
}
func TestInsertDefaultValue(t *testing.T) {
stmt := table1.INSERT(table1Col1, table1ColFloat).
VALUES(DEFAULT, "two")
stmtStr, _, err := stmt.Sql()
var expectedSql = `
INSERT INTO db.table1 (col1, colFloat) VALUES
(DEFAULT, $1);
`
assert.NilError(t, err)
fmt.Println(stmtStr)
assertStatement(t, stmt, expectedSql, "two")
}

View file

@ -63,7 +63,7 @@ func (l *lockStatementImpl) Sql() (query string, args []interface{}, err error)
out := &queryData{}
out.nextLine()
out.newLine()
out.writeString("LOCK TABLE")
for i, table := range l.tables {
@ -96,6 +96,6 @@ func (l *lockStatementImpl) Query(db execution.Db, destination interface{}) erro
return Query(l, db, destination)
}
func (l *lockStatementImpl) Execute(db execution.Db) (sql.Result, error) {
return Execute(l, db)
func (l *lockStatementImpl) Exec(db execution.Db) (sql.Result, error) {
return Exec(l, db)
}

View file

@ -3,21 +3,3 @@ package sqlbuilder
type projection interface {
serializeForProjection(statement statementType, out *queryData) error
}
//------------------------------------------------------//
// Dummy type for select * AllColumns
type ColumnList []Column
func (cl ColumnList) isProjectionType() {}
func (cl ColumnList) serializeForProjection(statement statementType, out *queryData) error {
projections := columnListToProjectionList(cl)
err := serializeProjectionList(statement, projections, out)
if err != nil {
return err
}
return nil
}

View file

@ -83,7 +83,7 @@ func (s *selectStatementImpl) serialize(statement statementType, out *queryData,
return err
}
out.nextLine()
out.newLine()
out.writeString(")")
return nil
@ -94,7 +94,7 @@ func (s *selectStatementImpl) serializeImpl(out *queryData) error {
return errors.New("Select expression is nil. ")
}
out.nextLine()
out.newLine()
out.writeString("SELECT")
if s.distinct {
@ -150,19 +150,19 @@ func (s *selectStatementImpl) serializeImpl(out *queryData) error {
}
if s.limit >= 0 {
out.nextLine()
out.newLine()
out.writeString("LIMIT")
out.insertPreparedArgument(s.limit)
}
if s.offset >= 0 {
out.nextLine()
out.newLine()
out.writeString("OFFSET")
out.insertPreparedArgument(s.offset)
}
if s.forUpdate {
out.nextLine()
out.newLine()
out.writeString("FOR UPDATE")
}
@ -238,6 +238,6 @@ func (s *selectStatementImpl) Query(db execution.Db, destination interface{}) er
return Query(s, db, destination)
}
func (s *selectStatementImpl) Execute(db execution.Db) (res sql.Result, err error) {
return Execute(s, db)
func (s *selectStatementImpl) Exec(db execution.Db) (res sql.Result, err error) {
return Exec(s, db)
}

View file

@ -114,7 +114,7 @@ func (s *setStatementImpl) serialize(statement statementType, out *queryData, op
if wrap {
out.decreaseIdent()
out.nextLine()
out.newLine()
out.writeString(")")
}
@ -130,19 +130,19 @@ func (s *setStatementImpl) serializeImpl(out *queryData) error {
return errors.New("UNION Statement must have at least two SELECT statements.")
}
out.nextLine()
out.newLine()
out.writeString("(")
out.increaseIdent()
for i, selectStmt := range s.selects {
out.nextLine()
out.newLine()
if i > 0 {
out.writeString(s.operator)
if s.all {
out.writeString("ALL")
}
out.nextLine()
out.newLine()
}
err := selectStmt.serialize(set_statement, out)
@ -153,7 +153,7 @@ func (s *setStatementImpl) serializeImpl(out *queryData) error {
}
out.decreaseIdent()
out.nextLine()
out.newLine()
out.writeString(")")
if s.orderBy != nil {
@ -164,13 +164,13 @@ func (s *setStatementImpl) serializeImpl(out *queryData) error {
}
if s.limit >= 0 {
out.nextLine()
out.newLine()
out.writeString("LIMIT")
out.insertPreparedArgument(s.limit)
}
if s.offset >= 0 {
out.nextLine()
out.newLine()
out.writeString("OFFSET")
out.insertPreparedArgument(s.offset)
}
@ -199,6 +199,6 @@ func (s *setStatementImpl) Query(db execution.Db, destination interface{}) error
return Query(s, db, destination)
}
func (u *setStatementImpl) Execute(db execution.Db) (res sql.Result, err error) {
return Execute(u, db)
func (u *setStatementImpl) Exec(db execution.Db) (res sql.Result, err error) {
return Exec(u, db)
}

View file

@ -1,7 +1,6 @@
package sqlbuilder
import (
"fmt"
"gotest.tools/assert"
"testing"
)
@ -29,7 +28,6 @@ func TestUnionTwoSelect(t *testing.T) {
).Sql()
assert.NilError(t, err)
fmt.Println(query)
assert.Equal(t, query, `
(
(
@ -53,7 +51,6 @@ func TestUnionThreeSelect(t *testing.T) {
table3.SELECT(table3Col1),
).Sql()
fmt.Println(query)
assert.NilError(t, err)
assert.Equal(t, query, `
(
@ -83,7 +80,6 @@ func TestUnionWithOrderBy(t *testing.T) {
).ORDER_BY(table1Col1.ASC()).Sql()
assert.NilError(t, err)
fmt.Println(query)
assert.Equal(t, query, `
(
(
@ -108,7 +104,6 @@ func TestUnionWithLimit(t *testing.T) {
).LIMIT(10).OFFSET(11).Sql()
assert.NilError(t, err)
fmt.Println(query)
assert.Equal(t, query, `
(
(
@ -157,7 +152,6 @@ func TestUnionInUnion(t *testing.T) {
queryStr, args, err := query.Sql()
fmt.Println(queryStr)
assert.NilError(t, err)
assert.Equal(t, len(args), 0)
assert.Equal(t, queryStr, expectedSql)
@ -170,7 +164,6 @@ func TestUnionALL(t *testing.T) {
).Sql()
assert.NilError(t, err)
fmt.Println(query)
assert.Equal(t, query, `
(
(
@ -194,7 +187,6 @@ func TestINTERSECT(t *testing.T) {
).Sql()
assert.NilError(t, err)
fmt.Println(query)
assert.Equal(t, query, `
(
(
@ -218,7 +210,6 @@ func TestINTERSECT_ALL(t *testing.T) {
).Sql()
assert.NilError(t, err)
fmt.Println(query)
assert.Equal(t, query, `
(
(
@ -242,7 +233,6 @@ func TestEXCEPT(t *testing.T) {
).Sql()
assert.NilError(t, err)
fmt.Println(query)
assert.Equal(t, query, `
(
(
@ -266,7 +256,6 @@ func TestEXCEPT_ALL(t *testing.T) {
).Sql()
assert.NilError(t, err)
fmt.Println(query)
assert.Equal(t, query, `
(
(

View file

@ -14,7 +14,7 @@ type Statement interface {
DebugSql() (query string, err error)
Query(db execution.Db, destination interface{}) error
Execute(db execution.Db) (sql.Result, error)
Exec(db execution.Db) (sql.Result, error)
}
func DebugSql(statement Statement) (string, error) {

View file

@ -19,15 +19,17 @@ type readableTable interface {
// Creates a right join tableName Expression using onCondition.
RIGHT_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable
// Creates a full join tableName Expression using onCondition.
FULL_JOIN(table ReadableTable, onCondition BoolExpression) ReadableTable
// Creates a cross join tableName Expression using onCondition.
CROSS_JOIN(table ReadableTable) ReadableTable
}
// The sql tableName write interface.
type writableTable interface {
INSERT(columns ...Column) InsertStatement
UPDATE(columns ...Column) UpdateStatement
INSERT(column column, columns ...column) InsertStatement
UPDATE(column column, columns ...column) UpdateStatement
DELETE() DeleteStatement
LOCK() LockStatement
@ -88,12 +90,12 @@ type writableTableInterfaceImpl struct {
parent WritableTable
}
func (w *writableTableInterfaceImpl) INSERT(columns ...Column) InsertStatement {
return newInsertStatement(w.parent, columns...)
func (w *writableTableInterfaceImpl) INSERT(column column, columns ...column) InsertStatement {
return newInsertStatement(w.parent, unwindColumns(column, columns...))
}
func (w *writableTableInterfaceImpl) UPDATE(columns ...Column) UpdateStatement {
return newUpdateStatement(w.parent, columns)
func (w *writableTableInterfaceImpl) UPDATE(column column, columns ...column) UpdateStatement {
return newUpdateStatement(w.parent, unwindColumns(column, columns...))
}
func (w *writableTableInterfaceImpl) DELETE() DeleteStatement {
@ -229,7 +231,7 @@ func (t *joinTable) serialize(statement statementType, out *queryData, options .
return
}
out.nextLine()
out.newLine()
switch t.join_type {
case innerJoin:
@ -265,3 +267,19 @@ func (t *joinTable) serialize(statement statementType, out *queryData, options .
return nil
}
func unwindColumns(column1 column, columns ...column) []column {
columnList := []column{}
if val, ok := column1.(ColumnList); ok {
for _, col := range val {
columnList = append(columnList, col)
}
columnList = append(columnList, columns...)
} else {
columnList = append(columnList, column1)
columnList = append(columnList, columns...)
}
return columnList
}

View file

@ -1,7 +1,6 @@
package sqlbuilder
import (
"fmt"
"gotest.tools/assert"
"testing"
)
@ -66,7 +65,6 @@ func assertClauseSerializeErr(t *testing.T, clause clause, errString string) {
out := queryData{}
err := clause.serialize(select_statement, &out)
fmt.Println(err)
assert.Assert(t, err != nil)
assert.Equal(t, err.Error(), errString)
}
@ -81,9 +79,16 @@ func assertProjectionSerialize(t *testing.T, projection projection, query string
assert.DeepEqual(t, out.args, args)
}
func assertQuery(t *testing.T, query Statement, expectedQuery string, expectedArgs ...interface{}) {
func assertStatement(t *testing.T, query Statement, expectedQuery string, expectedArgs ...interface{}) {
queryStr, args, err := query.Sql()
assert.NilError(t, err)
assert.Equal(t, queryStr, expectedQuery)
assert.DeepEqual(t, args, expectedArgs)
}
func assertStatementErr(t *testing.T, stmt Statement, errorStr string) {
_, _, err := stmt.Sql()
assert.Assert(t, err != nil)
assert.Equal(t, err.Error(), errorStr)
}

View file

@ -9,35 +9,37 @@ import (
type UpdateStatement interface {
Statement
SET(values ...interface{}) UpdateStatement
SET(value interface{}, values ...interface{}) UpdateStatement
USING(data interface{}) UpdateStatement
WHERE(expression BoolExpression) UpdateStatement
RETURNING(projections ...projection) UpdateStatement
}
func newUpdateStatement(table WritableTable, columns []Column) UpdateStatement {
func newUpdateStatement(table WritableTable, columns []column) UpdateStatement {
return &updateStatementImpl{
table: table,
columns: columns,
row: make([]clause, 0, len(columns)),
}
}
type updateStatementImpl struct {
table WritableTable
columns []Column
updateValues []clause
where BoolExpression
returning []projection
table WritableTable
columns []column
row []clause
where BoolExpression
returning []projection
}
func (u *updateStatementImpl) SET(values ...interface{}) UpdateStatement {
func (u *updateStatementImpl) SET(value interface{}, values ...interface{}) UpdateStatement {
u.row = unwindRowFromValues(value, values)
for _, value := range values {
if clause, ok := value.(clause); ok {
u.updateValues = append(u.updateValues, clause)
} else {
u.updateValues = append(u.updateValues, literal(value))
}
}
return u
}
func (u *updateStatementImpl) USING(modelData interface{}) UpdateStatement {
u.row = unwindRowFromModel(u.columns, modelData)
return u
}
@ -55,31 +57,36 @@ func (u *updateStatementImpl) RETURNING(projections ...projection) UpdateStateme
func (u *updateStatementImpl) Sql() (sql string, args []interface{}, err error) {
out := &queryData{}
out.nextLine()
out.newLine()
out.writeString("UPDATE")
if u.table == nil {
return "", nil, errors.New("nil tableName.")
if isNil(u.table) {
return "", nil, errors.New("table to update is nil")
}
if err = u.table.serialize(update_statement, out); err != nil {
return
}
if len(u.updateValues) == 0 {
return "", nil, errors.New("No column updated.")
if len(u.columns) == 0 {
return "", nil, errors.New("no columns selected")
}
if len(u.row) == 0 {
return "", nil, errors.New("no values to updated")
}
out.newLine()
out.writeString("SET")
if len(u.columns) > 1 {
out.writeString("(")
}
err = serializeColumnList(update_statement, u.columns, out)
err = serializeColumnNames(u.columns, out)
if err != nil {
return "", nil, err
return
}
if len(u.columns) > 1 {
@ -88,28 +95,22 @@ func (u *updateStatementImpl) Sql() (sql string, args []interface{}, err error)
out.writeString("=")
if len(u.updateValues) > 1 {
if len(u.row) > 1 {
out.writeString("(")
}
for i, value := range u.updateValues {
if i > 0 {
out.writeString(", ")
}
err = serializeClauseList(update_statement, u.row, out)
err = value.serialize(update_statement, out)
if err != nil {
return
}
if err != nil {
return
}
if len(u.updateValues) > 1 {
if len(u.row) > 1 {
out.writeString(")")
}
if u.where == nil {
return "", nil, errors.New("Updating without a WHERE clause.")
return "", nil, errors.New("WHERE clause not set")
}
if err = out.writeWhere(update_statement, u.where); err != nil {
@ -117,8 +118,10 @@ func (u *updateStatementImpl) Sql() (sql string, args []interface{}, err error)
}
if len(u.returning) > 0 {
out.nextLine()
out.newLine()
out.writeString("RETURNING")
out.increaseIdent()
out.increaseIdent()
err = serializeProjectionList(update_statement, u.returning, out)
@ -139,6 +142,6 @@ func (u *updateStatementImpl) Query(db execution.Db, destination interface{}) er
return Query(u, db, destination)
}
func (u *updateStatementImpl) Execute(db execution.Db) (res sql.Result, err error) {
return Execute(u, db)
func (u *updateStatementImpl) Exec(db execution.Db) (res sql.Result, err error) {
return Exec(u, db)
}

View file

@ -1,124 +1,76 @@
package sqlbuilder
import (
"fmt"
"gotest.tools/assert"
"testing"
)
//
// UPDATE Statement tests =====================================================
//
func TestUpdateWithOneValue(t *testing.T) {
expectedSql := `
UPDATE db.table1
SET colInt = $1
WHERE table1.colInt >= $2;
`
stmt := table1.UPDATE(table1ColInt).
SET(1).
WHERE(table1ColInt.GT_EQ(Int(33)))
func TestUpdate(t *testing.T) {
stmt := table1.UPDATE(table1Col1, table1ColFloat).
SET(table1.SELECT(table1ColFloat, table2Col3)).
assertStatement(t, stmt, expectedSql, 1, int64(33))
}
func TestUpdateWithValues(t *testing.T) {
expectedSql := `
UPDATE db.table1
SET (colInt, colFloat) = ($1, $2)
WHERE table1.colInt >= $3;
`
stmt := table1.UPDATE(table1ColInt, table1ColFloat).
SET(1, 22.2).
WHERE(table1ColInt.GT_EQ(Int(33)))
assertStatement(t, stmt, expectedSql, 1, 22.2, int64(33))
}
func TestUpdateOneColumnWithSelect(t *testing.T) {
expectedSql := `
UPDATE db.table1
SET colFloat = (
SELECT table1.colFloat AS "table1.colFloat"
FROM db.table1
)
WHERE table1.col1 = $1
RETURNING table1.col1 AS "table1.col1";
`
stmt := table1.
UPDATE(table1ColFloat).
SET(
table1.SELECT(table1ColFloat),
).
WHERE(table1Col1.EQ(Int(2))).
RETURNING(table1Col1)
stmtStr, _, err := stmt.Sql()
assertStatement(t, stmt, expectedSql, int64(2))
}
assert.NilError(t, err)
fmt.Println(stmtStr)
assert.Equal(t, stmtStr, `
UPDATE db.table1 SET (col1,colFloat) = (
func TestUpdateColumnsWithSelect(t *testing.T) {
expectedSql := `
UPDATE db.table1
SET (col1, colFloat) = (
SELECT table1.colFloat AS "table1.colFloat",
table2.col3 AS "table2.col3"
FROM db.table1
)
WHERE table1.col1 = $1
RETURNING table1.col1 AS "table1.col1";
`)
`
stmt := table1.UPDATE(table1Col1, table1ColFloat).
SET(table1.SELECT(table1ColFloat, table2Col3)).
WHERE(table1Col1.EQ(Int(2))).
RETURNING(table1Col1)
assertStatement(t, stmt, expectedSql, int64(2))
}
//func (s *StmtSuite) TestUpdateNilColumn(c *gc.C) {
// stmt := table1.UPDATE().SET(nil, literal(1))
// _, err := stmt.String()
// c.Assert(err, gc.NotNil)
//}
//
//func (s *StmtSuite) TestUpdateNilExpr(c *gc.C) {
// stmt := table1.UPDATE().SET(table1Col1, nil)
// _, err := stmt.String()
// c.Assert(err, gc.NotNil)
//}
//
//func (s *StmtSuite) TestUpdateUnconditionally(c *gc.C) {
// stmt := table1.UPDATE().SET(table1Col1, literal(1))
// _, err := stmt.String()
// c.Assert(err, gc.NotNil)
//}
//
//func (s *StmtSuite) TestUpdateSingleValue(c *gc.C) {
// stmt := table1.UPDATE().SET(table1Col1, literal(1))
// stmt.WHERE(EqString(table1ColFloat, 2))
// sql, err := stmt.String()
// c.Assert(err, gc.IsNil)
//
// c.Assert(
// sql,
// gc.Equals,
// "UPDATE db.table1 SET table1.col1=1 WHERE table1.col2=2")
//}
//
//func (s *StmtSuite) TestUpdateUsingDeferredLookupColumns(c *gc.C) {
// stmt := table1.UPDATE().SET(table1.C("col1"), literal(1))
// stmt.WHERE(EqString(table1ColFloat, 2))
// sql, err := stmt.String()
// c.Assert(err, gc.IsNil)
//
// c.Assert(
// sql,
// gc.Equals,
// "UPDATE db.table1 SET table1.col1=1 WHERE table1.col2=2")
//}
//
//func (s *StmtSuite) TestUpdateMultiValues(c *gc.C) {
// stmt := table1.UPDATE()
// stmt.SET(table1Col1, literal(1))
// stmt.SET(table1ColFloat, literal(2))
// stmt.WHERE(EqString(table1ColFloat, 3))
// sql, err := stmt.String()
// c.Assert(err, gc.IsNil)
//
// c.Assert(
// sql,
// gc.Equals,
// "UPDATE db.table1 "+
// "SET table1.col1=1, table1.col2=2 "+
// "WHERE table1.col2=3")
//}
//
//func (s *StmtSuite) TestUpdateWithOrderBy(c *gc.C) {
// stmt := table1.UPDATE().SET(table1Col1, literal(1))
// stmt.WHERE(EqString(table1ColFloat, 2))
// stmt.ORDER_BY(table1ColFloat)
// sql, err := stmt.String()
// c.Assert(err, gc.IsNil)
//
// c.Assert(
// sql,
// gc.Equals,
// "UPDATE db.table1 "+
// "SET table1.col1=1 "+
// "WHERE table1.col2=2 "+
// "ORDER BY table1.col2")
//}
//
//func (s *StmtSuite) TestUpdateWithLimit(c *gc.C) {
// stmt := table1.UPDATE().SET(table1Col1, literal(1))
// stmt.WHERE(EqString(table1ColFloat, 2))
// stmt.LIMIT(5)
// sql, err := stmt.String()
// c.Assert(err, gc.IsNil)
//
// c.Assert(
// sql,
// gc.Equals,
// "UPDATE db.table1 "+
// "SET table1.col1=1 "+
// "WHERE table1.col2=2 "+
// "LIMIT 5")
//}
func TestInvalidInputs(t *testing.T) {
assertStatementErr(t, table1.UPDATE(table1ColInt).SET(1, 2), "WHERE clause not set")
assertStatementErr(t, table1.UPDATE(nil).SET(1, 2), "nil column in columns list")
}

View file

@ -4,6 +4,7 @@ import (
"database/sql"
"errors"
"github.com/go-jet/jet/sqlbuilder/execution"
"github.com/serenize/snaker"
"reflect"
)
@ -83,7 +84,7 @@ func serializeProjectionList(statement statementType, projections []projection,
for i, col := range projections {
if i > 0 {
out.writeString(",")
out.nextLine()
out.newLine()
}
if col == nil {
@ -98,14 +99,14 @@ func serializeProjectionList(statement statementType, projections []projection,
return nil
}
func serializeColumnList(statement statementType, columns []Column, out *queryData) error {
func serializeColumnNames(columns []column, out *queryData) error {
for i, col := range columns {
if i > 0 {
out.writeByte(',')
out.writeString(", ")
}
if col == nil {
return errors.New("nil column in columns list.")
return errors.New("nil column in columns list")
}
out.writeString(col.Name())
@ -118,6 +119,59 @@ func isNil(v interface{}) bool {
return v == nil || (reflect.ValueOf(v).Kind() == reflect.Ptr && reflect.ValueOf(v).IsNil())
}
func valueToClause(value interface{}) clause {
if clause, ok := value.(clause); ok {
return clause
} else {
return literal(value)
}
}
func unwindRowFromModel(columns []column, data interface{}) []clause {
structValue := reflect.Indirect(reflect.ValueOf(data))
row := []clause{}
if structValue.Kind() != reflect.Struct {
return row
}
for _, column := range columns {
columnName := column.Name()
structFieldName := snaker.SnakeToCamel(columnName)
structField := structValue.FieldByName(structFieldName)
if !structField.IsValid() {
continue
}
var field interface{}
if structField.Kind() == reflect.Ptr && structField.IsNil() {
field = nil
} else {
field = reflect.Indirect(structField).Interface()
}
row = append(row, literal(field))
}
return row
}
func unwindRowFromValues(value interface{}, values []interface{}) []clause {
row := []clause{}
allValues := append([]interface{}{value}, values...)
for _, val := range allValues {
row = append(row, valueToClause(val))
}
return row
}
func columnListToProjectionList(columns []Column) []projection {
var ret []projection
@ -138,7 +192,7 @@ func Query(statement Statement, db execution.Db, destination interface{}) error
return execution.Query(db, query, args, destination)
}
func Execute(statement Statement, db execution.Db) (res sql.Result, err error) {
func Exec(statement Statement, db execution.Db) (res sql.Result, err error) {
query, args, err := statement.Sql()
if err != nil {