Add support for INSERT select query.
This commit is contained in:
parent
0971573338
commit
b287521f1a
5 changed files with 154 additions and 74 deletions
|
|
@ -20,6 +20,8 @@ type InsertStatement interface {
|
|||
|
||||
RETURNING(column ...Expression) InsertStatement
|
||||
|
||||
QUERY(selectStatement SelectStatement) InsertStatement
|
||||
|
||||
Execute(db types.Db) (sql.Result, error)
|
||||
}
|
||||
|
||||
|
|
@ -27,7 +29,7 @@ func newInsertStatement(t WritableTable, columns ...Column) InsertStatement {
|
|||
return &insertStatementImpl{
|
||||
table: t,
|
||||
columns: columns,
|
||||
rows: make([][]Expression, 0, 1),
|
||||
rows: make([][]Clause, 0, 1),
|
||||
returning: make([]Expression, 0, 1),
|
||||
}
|
||||
}
|
||||
|
|
@ -40,7 +42,8 @@ type columnAssignment struct {
|
|||
type insertStatementImpl struct {
|
||||
table WritableTable
|
||||
columns []Column
|
||||
rows [][]Expression
|
||||
rows [][]Clause
|
||||
query SelectStatement
|
||||
returning []Expression
|
||||
|
||||
errors []string
|
||||
|
|
@ -58,23 +61,15 @@ func (i *insertStatementImpl) Execute(db types.Db) (res sql.Result, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
//func (i *insertStatementImpl) ExecuteInTx(tx *sql.Tx) (res sql.Result, err error) {
|
||||
// query, err := i.String()
|
||||
//
|
||||
// if err != nil {
|
||||
// return
|
||||
// }
|
||||
//
|
||||
// res, err = tx.Exec(query)
|
||||
//
|
||||
// return
|
||||
//}
|
||||
|
||||
func (s *insertStatementImpl) VALUES(values ...interface{}) InsertStatement {
|
||||
literalRow := []Expression{}
|
||||
literalRow := []Clause{}
|
||||
|
||||
for _, value := range values {
|
||||
literalRow = append(literalRow, Literal(value))
|
||||
if clause, ok := value.(Clause); ok {
|
||||
literalRow = append(literalRow, clause)
|
||||
} else {
|
||||
literalRow = append(literalRow, Literal(value))
|
||||
}
|
||||
}
|
||||
|
||||
s.rows = append(s.rows, literalRow)
|
||||
|
|
@ -98,7 +93,7 @@ func (i *insertStatementImpl) VALUES_MAPPING(data interface{}) InsertStatement {
|
|||
return i
|
||||
}
|
||||
|
||||
rowValues := []Expression{}
|
||||
rowValues := []Clause{}
|
||||
|
||||
for _, column := range i.columns {
|
||||
columnName := column.Name()
|
||||
|
|
@ -125,6 +120,12 @@ func (i *insertStatementImpl) RETURNING(column ...Expression) InsertStatement {
|
|||
return i
|
||||
}
|
||||
|
||||
func (i *insertStatementImpl) QUERY(selectStatement SelectStatement) InsertStatement {
|
||||
i.query = selectStatement
|
||||
|
||||
return i
|
||||
}
|
||||
|
||||
func (i *insertStatementImpl) addError(err string) {
|
||||
i.errors = append(i.errors, err)
|
||||
}
|
||||
|
|
@ -144,63 +145,73 @@ func (s *insertStatementImpl) String() (sql string, err error) {
|
|||
|
||||
buf.WriteString(s.table.SchemaName() + "." + s.table.TableName())
|
||||
|
||||
if len(s.columns) == 0 {
|
||||
return "", errors.Newf(
|
||||
"No column specified. Generated sql: %s",
|
||||
buf.String())
|
||||
}
|
||||
|
||||
_, _ = buf.WriteString(" (")
|
||||
for i, col := range s.columns {
|
||||
if i > 0 {
|
||||
_ = buf.WriteByte(',')
|
||||
}
|
||||
|
||||
if col == nil {
|
||||
return "", errors.Newf(
|
||||
"nil column in columns list. Generated sql: %s",
|
||||
buf.String())
|
||||
}
|
||||
|
||||
buf.WriteString(col.Name())
|
||||
}
|
||||
|
||||
if len(s.rows) == 0 {
|
||||
return "", errors.Newf(
|
||||
"No row specified. Generated sql: %s",
|
||||
buf.String())
|
||||
}
|
||||
|
||||
_, _ = buf.WriteString(") VALUES (")
|
||||
for row_i, row := range s.rows {
|
||||
if row_i > 0 {
|
||||
_, _ = buf.WriteString(", (")
|
||||
}
|
||||
|
||||
if len(row) != len(s.columns) {
|
||||
return "", errors.Newf(
|
||||
"# of values does not match # of columns. Generated sql: %s",
|
||||
buf.String())
|
||||
}
|
||||
|
||||
for col_i, value := range row {
|
||||
if col_i > 0 {
|
||||
if len(s.columns) > 0 {
|
||||
_, _ = buf.WriteString(" (")
|
||||
for i, col := range s.columns {
|
||||
if i > 0 {
|
||||
_ = buf.WriteByte(',')
|
||||
}
|
||||
|
||||
if value == nil {
|
||||
if col == nil {
|
||||
return "", errors.Newf(
|
||||
"nil value in row %d col %d. Generated sql: %s",
|
||||
row_i,
|
||||
col_i,
|
||||
"nil column in columns list. Generated sql: %s",
|
||||
buf.String())
|
||||
}
|
||||
|
||||
if err = value.SerializeSql(buf); err != nil {
|
||||
return
|
||||
}
|
||||
buf.WriteString(col.Name())
|
||||
}
|
||||
|
||||
buf.WriteString(") ")
|
||||
}
|
||||
|
||||
if len(s.rows) == 0 && s.query == nil {
|
||||
return "", errors.Newf("No row or query specified. Generated sql: %s", buf.String())
|
||||
}
|
||||
|
||||
if len(s.rows) > 0 && s.query != nil {
|
||||
return "", errors.Newf("Only new rows or query has to be specified. Generated sql: %s", buf.String())
|
||||
}
|
||||
|
||||
if len(s.rows) > 0 {
|
||||
_, _ = buf.WriteString("VALUES (")
|
||||
for row_i, row := range s.rows {
|
||||
if row_i > 0 {
|
||||
_, _ = buf.WriteString(", (")
|
||||
}
|
||||
|
||||
if len(row) != len(s.columns) {
|
||||
return "", errors.Newf(
|
||||
"# of values does not match # of columns. Generated sql: %s",
|
||||
buf.String())
|
||||
}
|
||||
|
||||
for col_i, value := range row {
|
||||
if col_i > 0 {
|
||||
_ = buf.WriteByte(',')
|
||||
}
|
||||
|
||||
if value == nil {
|
||||
return "", errors.Newf(
|
||||
"nil value in row %d col %d. Generated sql: %s",
|
||||
row_i,
|
||||
col_i,
|
||||
buf.String())
|
||||
}
|
||||
|
||||
if err = value.SerializeSql(buf); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
_ = buf.WriteByte(')')
|
||||
}
|
||||
}
|
||||
|
||||
if s.query != nil {
|
||||
err = s.query.SerializeSql(buf)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_ = buf.WriteByte(')')
|
||||
}
|
||||
|
||||
if len(s.returning) > 0 {
|
||||
|
|
|
|||
|
|
@ -123,3 +123,26 @@ func TestInsertValuesFromModelColumnMismatch(t *testing.T) {
|
|||
fmt.Println(err)
|
||||
assert.Assert(t, err != nil)
|
||||
}
|
||||
|
||||
func TestInsertQuery(t *testing.T) {
|
||||
|
||||
stmt := table1.INSERT(table1Col1).
|
||||
QUERY(table1.SELECT(table1Col1))
|
||||
|
||||
stmtStr, err := stmt.String()
|
||||
|
||||
assert.NilError(t, err)
|
||||
|
||||
fmt.Println(stmtStr)
|
||||
}
|
||||
|
||||
func TestInsertDefaultValue(t *testing.T) {
|
||||
stmt := table1.INSERT(table1Col1, table1Col2).
|
||||
VALUES(DEFAULT, "two")
|
||||
|
||||
stmtStr, err := stmt.String()
|
||||
|
||||
assert.NilError(t, err)
|
||||
|
||||
fmt.Println(stmtStr)
|
||||
}
|
||||
|
|
|
|||
15
sqlbuilder/keyword.go
Normal file
15
sqlbuilder/keyword.go
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
package sqlbuilder
|
||||
|
||||
import "bytes"
|
||||
|
||||
const (
|
||||
DEFAULT keywordClause = "DEFAULT"
|
||||
)
|
||||
|
||||
type keywordClause string
|
||||
|
||||
func (k keywordClause) SerializeSql(out *bytes.Buffer, options ...serializeOption) error {
|
||||
out.WriteString(string(k))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
@ -73,7 +73,7 @@ func (s *selectStatementImpl) SerializeSql(out *bytes.Buffer, options ...seriali
|
|||
return err
|
||||
}
|
||||
|
||||
out.WriteString("( ")
|
||||
out.WriteString("(")
|
||||
out.WriteString(str)
|
||||
out.WriteString(")")
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue