Add support for INSERT select query.

This commit is contained in:
zer0sub 2019-04-07 16:54:06 +02:00
parent 0971573338
commit b287521f1a
5 changed files with 154 additions and 74 deletions

View file

@ -20,6 +20,8 @@ type InsertStatement interface {
RETURNING(column ...Expression) InsertStatement RETURNING(column ...Expression) InsertStatement
QUERY(selectStatement SelectStatement) InsertStatement
Execute(db types.Db) (sql.Result, error) Execute(db types.Db) (sql.Result, error)
} }
@ -27,7 +29,7 @@ func newInsertStatement(t WritableTable, columns ...Column) InsertStatement {
return &insertStatementImpl{ return &insertStatementImpl{
table: t, table: t,
columns: columns, columns: columns,
rows: make([][]Expression, 0, 1), rows: make([][]Clause, 0, 1),
returning: make([]Expression, 0, 1), returning: make([]Expression, 0, 1),
} }
} }
@ -40,7 +42,8 @@ type columnAssignment struct {
type insertStatementImpl struct { type insertStatementImpl struct {
table WritableTable table WritableTable
columns []Column columns []Column
rows [][]Expression rows [][]Clause
query SelectStatement
returning []Expression returning []Expression
errors []string errors []string
@ -58,23 +61,15 @@ func (i *insertStatementImpl) Execute(db types.Db) (res sql.Result, err error) {
return return
} }
//func (i *insertStatementImpl) ExecuteInTx(tx *sql.Tx) (res sql.Result, err error) {
// query, err := i.String()
//
// if err != nil {
// return
// }
//
// res, err = tx.Exec(query)
//
// return
//}
func (s *insertStatementImpl) VALUES(values ...interface{}) InsertStatement { func (s *insertStatementImpl) VALUES(values ...interface{}) InsertStatement {
literalRow := []Expression{} literalRow := []Clause{}
for _, value := range values { for _, value := range values {
literalRow = append(literalRow, Literal(value)) if clause, ok := value.(Clause); ok {
literalRow = append(literalRow, clause)
} else {
literalRow = append(literalRow, Literal(value))
}
} }
s.rows = append(s.rows, literalRow) s.rows = append(s.rows, literalRow)
@ -98,7 +93,7 @@ func (i *insertStatementImpl) VALUES_MAPPING(data interface{}) InsertStatement {
return i return i
} }
rowValues := []Expression{} rowValues := []Clause{}
for _, column := range i.columns { for _, column := range i.columns {
columnName := column.Name() columnName := column.Name()
@ -125,6 +120,12 @@ func (i *insertStatementImpl) RETURNING(column ...Expression) InsertStatement {
return i return i
} }
func (i *insertStatementImpl) QUERY(selectStatement SelectStatement) InsertStatement {
i.query = selectStatement
return i
}
func (i *insertStatementImpl) addError(err string) { func (i *insertStatementImpl) addError(err string) {
i.errors = append(i.errors, err) i.errors = append(i.errors, err)
} }
@ -144,63 +145,73 @@ func (s *insertStatementImpl) String() (sql string, err error) {
buf.WriteString(s.table.SchemaName() + "." + s.table.TableName()) buf.WriteString(s.table.SchemaName() + "." + s.table.TableName())
if len(s.columns) == 0 { if len(s.columns) > 0 {
return "", errors.Newf( _, _ = buf.WriteString(" (")
"No column specified. Generated sql: %s", for i, col := range s.columns {
buf.String()) if i > 0 {
}
_, _ = buf.WriteString(" (")
for i, col := range s.columns {
if i > 0 {
_ = buf.WriteByte(',')
}
if col == nil {
return "", errors.Newf(
"nil column in columns list. Generated sql: %s",
buf.String())
}
buf.WriteString(col.Name())
}
if len(s.rows) == 0 {
return "", errors.Newf(
"No row specified. Generated sql: %s",
buf.String())
}
_, _ = buf.WriteString(") VALUES (")
for row_i, row := range s.rows {
if row_i > 0 {
_, _ = buf.WriteString(", (")
}
if len(row) != len(s.columns) {
return "", errors.Newf(
"# of values does not match # of columns. Generated sql: %s",
buf.String())
}
for col_i, value := range row {
if col_i > 0 {
_ = buf.WriteByte(',') _ = buf.WriteByte(',')
} }
if value == nil { if col == nil {
return "", errors.Newf( return "", errors.Newf(
"nil value in row %d col %d. Generated sql: %s", "nil column in columns list. Generated sql: %s",
row_i,
col_i,
buf.String()) buf.String())
} }
if err = value.SerializeSql(buf); err != nil { buf.WriteString(col.Name())
return }
}
buf.WriteString(") ")
}
if len(s.rows) == 0 && s.query == nil {
return "", errors.Newf("No row or query specified. Generated sql: %s", buf.String())
}
if len(s.rows) > 0 && s.query != nil {
return "", errors.Newf("Only new rows or query has to be specified. Generated sql: %s", buf.String())
}
if len(s.rows) > 0 {
_, _ = buf.WriteString("VALUES (")
for row_i, row := range s.rows {
if row_i > 0 {
_, _ = buf.WriteString(", (")
}
if len(row) != len(s.columns) {
return "", errors.Newf(
"# of values does not match # of columns. Generated sql: %s",
buf.String())
}
for col_i, value := range row {
if col_i > 0 {
_ = buf.WriteByte(',')
}
if value == nil {
return "", errors.Newf(
"nil value in row %d col %d. Generated sql: %s",
row_i,
col_i,
buf.String())
}
if err = value.SerializeSql(buf); err != nil {
return
}
}
_ = buf.WriteByte(')')
}
}
if s.query != nil {
err = s.query.SerializeSql(buf)
if err != nil {
return
} }
_ = buf.WriteByte(')')
} }
if len(s.returning) > 0 { if len(s.returning) > 0 {

View file

@ -123,3 +123,26 @@ func TestInsertValuesFromModelColumnMismatch(t *testing.T) {
fmt.Println(err) fmt.Println(err)
assert.Assert(t, err != nil) assert.Assert(t, err != nil)
} }
func TestInsertQuery(t *testing.T) {
stmt := table1.INSERT(table1Col1).
QUERY(table1.SELECT(table1Col1))
stmtStr, err := stmt.String()
assert.NilError(t, err)
fmt.Println(stmtStr)
}
func TestInsertDefaultValue(t *testing.T) {
stmt := table1.INSERT(table1Col1, table1Col2).
VALUES(DEFAULT, "two")
stmtStr, err := stmt.String()
assert.NilError(t, err)
fmt.Println(stmtStr)
}

15
sqlbuilder/keyword.go Normal file
View file

@ -0,0 +1,15 @@
package sqlbuilder
import "bytes"
const (
DEFAULT keywordClause = "DEFAULT"
)
type keywordClause string
func (k keywordClause) SerializeSql(out *bytes.Buffer, options ...serializeOption) error {
out.WriteString(string(k))
return nil
}

View file

@ -73,7 +73,7 @@ func (s *selectStatementImpl) SerializeSql(out *bytes.Buffer, options ...seriali
return err return err
} }
out.WriteString("( ") out.WriteString("(")
out.WriteString(str) out.WriteString(str)
out.WriteString(")") out.WriteString(")")

View file

@ -2,6 +2,8 @@ package tests
import ( import (
"fmt" "fmt"
"github.com/davecgh/go-spew/spew"
"github.com/sub0zero/go-sqlbuilder/sqlbuilder"
"github.com/sub0zero/go-sqlbuilder/tests/.test_files/dvd_rental/test_sample/model" "github.com/sub0zero/go-sqlbuilder/tests/.test_files/dvd_rental/test_sample/model"
"github.com/sub0zero/go-sqlbuilder/tests/.test_files/dvd_rental/test_sample/table" "github.com/sub0zero/go-sqlbuilder/tests/.test_files/dvd_rental/test_sample/table"
"gotest.tools/assert" "gotest.tools/assert"
@ -9,11 +11,11 @@ import (
) )
func TestInsertValues(t *testing.T) { func TestInsertValues(t *testing.T) {
insertQuery := table.Link.INSERT(table.Link.URL, table.Link.Name). insertQuery := table.Link.INSERT(table.Link.URL, table.Link.Name, table.Link.Rel).
VALUES("http://www.postgresqltutorial.com", "PostgreSQL Tutorial"). VALUES("http://www.postgresqltutorial.com", "PostgreSQL Tutorial", sqlbuilder.DEFAULT).
VALUES("http://www.google.com", "Google"). VALUES("http://www.google.com", "Google", sqlbuilder.DEFAULT).
VALUES("http://www.yahoo.com", "Yahoo"). VALUES("http://www.yahoo.com", "Yahoo", sqlbuilder.DEFAULT).
VALUES("http://www.bing.com", "Bing"). VALUES("http://www.bing.com", "Bing", sqlbuilder.DEFAULT).
RETURNING(table.Link.ID) RETURNING(table.Link.ID)
insertQueryStr, err := insertQuery.String() insertQueryStr, err := insertQuery.String()
@ -22,7 +24,7 @@ func TestInsertValues(t *testing.T) {
fmt.Println(insertQueryStr) fmt.Println(insertQueryStr)
assert.Equal(t, insertQueryStr, `INSERT INTO test_sample.link (url,name) VALUES ('http://www.postgresqltutorial.com','PostgreSQL Tutorial'), ('http://www.google.com','Google'), ('http://www.yahoo.com','Yahoo'), ('http://www.bing.com','Bing') RETURNING link.id;`) assert.Equal(t, insertQueryStr, `INSERT INTO test_sample.link (url,name,rel) VALUES ('http://www.postgresqltutorial.com','PostgreSQL Tutorial',DEFAULT), ('http://www.google.com','Google',DEFAULT), ('http://www.yahoo.com','Yahoo',DEFAULT), ('http://www.bing.com','Bing',DEFAULT) RETURNING link.id;`)
res, err := insertQuery.Execute(db) res, err := insertQuery.Execute(db)
assert.NilError(t, err) assert.NilError(t, err)
@ -62,7 +64,8 @@ func TestInsertDataObject(t *testing.T) {
Rel: nil, Rel: nil,
} }
query := table.Link.INSERT(table.Link.URL, table.Link.Name). query := table.Link.
INSERT(table.Link.URL, table.Link.Name).
VALUES_MAPPING(linkData) VALUES_MAPPING(linkData)
queryStr, err := query.String() queryStr, err := query.String()
@ -77,3 +80,31 @@ func TestInsertDataObject(t *testing.T) {
fmt.Println(result) fmt.Println(result)
} }
func TestInsertQuery(t *testing.T) {
_, err := table.Link.INSERT(table.Link.URL, table.Link.Name).
VALUES("http://www.postgresqltutorial.com", "PostgreSQL Tutorial").Execute(db)
assert.NilError(t, err)
query := table.Link.
INSERT(table.Link.URL, table.Link.Name).
QUERY(table.Link.SELECT(table.Link.URL, table.Link.Name))
queryStr, err := query.String()
assert.NilError(t, err)
fmt.Println(queryStr)
_, err = query.Execute(db)
assert.NilError(t, err)
allLinks := []model.Link{}
err = table.Link.SELECT(table.Link.AllColumns).Execute(db, &allLinks)
assert.NilError(t, err)
spew.Dump(allLinks)
}