diff --git a/sqlbuilder/insert_statement.go b/sqlbuilder/insert_statement.go index 898781e..8d8dfdb 100644 --- a/sqlbuilder/insert_statement.go +++ b/sqlbuilder/insert_statement.go @@ -20,6 +20,8 @@ type InsertStatement interface { RETURNING(column ...Expression) InsertStatement + QUERY(selectStatement SelectStatement) InsertStatement + Execute(db types.Db) (sql.Result, error) } @@ -27,7 +29,7 @@ func newInsertStatement(t WritableTable, columns ...Column) InsertStatement { return &insertStatementImpl{ table: t, columns: columns, - rows: make([][]Expression, 0, 1), + rows: make([][]Clause, 0, 1), returning: make([]Expression, 0, 1), } } @@ -40,7 +42,8 @@ type columnAssignment struct { type insertStatementImpl struct { table WritableTable columns []Column - rows [][]Expression + rows [][]Clause + query SelectStatement returning []Expression errors []string @@ -58,23 +61,15 @@ func (i *insertStatementImpl) Execute(db types.Db) (res sql.Result, err error) { return } -//func (i *insertStatementImpl) ExecuteInTx(tx *sql.Tx) (res sql.Result, err error) { -// query, err := i.String() -// -// if err != nil { -// return -// } -// -// res, err = tx.Exec(query) -// -// return -//} - func (s *insertStatementImpl) VALUES(values ...interface{}) InsertStatement { - literalRow := []Expression{} + literalRow := []Clause{} for _, value := range values { - literalRow = append(literalRow, Literal(value)) + if clause, ok := value.(Clause); ok { + literalRow = append(literalRow, clause) + } else { + literalRow = append(literalRow, Literal(value)) + } } s.rows = append(s.rows, literalRow) @@ -98,7 +93,7 @@ func (i *insertStatementImpl) VALUES_MAPPING(data interface{}) InsertStatement { return i } - rowValues := []Expression{} + rowValues := []Clause{} for _, column := range i.columns { columnName := column.Name() @@ -125,6 +120,12 @@ func (i *insertStatementImpl) RETURNING(column ...Expression) InsertStatement { return i } +func (i *insertStatementImpl) QUERY(selectStatement SelectStatement) InsertStatement { + i.query = selectStatement + + return i +} + func (i *insertStatementImpl) addError(err string) { i.errors = append(i.errors, err) } @@ -144,63 +145,73 @@ func (s *insertStatementImpl) String() (sql string, err error) { buf.WriteString(s.table.SchemaName() + "." + s.table.TableName()) - if len(s.columns) == 0 { - return "", errors.Newf( - "No column specified. Generated sql: %s", - buf.String()) - } - - _, _ = buf.WriteString(" (") - for i, col := range s.columns { - if i > 0 { - _ = buf.WriteByte(',') - } - - if col == nil { - return "", errors.Newf( - "nil column in columns list. Generated sql: %s", - buf.String()) - } - - buf.WriteString(col.Name()) - } - - if len(s.rows) == 0 { - return "", errors.Newf( - "No row specified. Generated sql: %s", - buf.String()) - } - - _, _ = buf.WriteString(") VALUES (") - for row_i, row := range s.rows { - if row_i > 0 { - _, _ = buf.WriteString(", (") - } - - if len(row) != len(s.columns) { - return "", errors.Newf( - "# of values does not match # of columns. Generated sql: %s", - buf.String()) - } - - for col_i, value := range row { - if col_i > 0 { + if len(s.columns) > 0 { + _, _ = buf.WriteString(" (") + for i, col := range s.columns { + if i > 0 { _ = buf.WriteByte(',') } - if value == nil { + if col == nil { return "", errors.Newf( - "nil value in row %d col %d. Generated sql: %s", - row_i, - col_i, + "nil column in columns list. Generated sql: %s", buf.String()) } - if err = value.SerializeSql(buf); err != nil { - return - } + buf.WriteString(col.Name()) + } + + buf.WriteString(") ") + } + + if len(s.rows) == 0 && s.query == nil { + return "", errors.Newf("No row or query specified. Generated sql: %s", buf.String()) + } + + if len(s.rows) > 0 && s.query != nil { + return "", errors.Newf("Only new rows or query has to be specified. Generated sql: %s", buf.String()) + } + + if len(s.rows) > 0 { + _, _ = buf.WriteString("VALUES (") + for row_i, row := range s.rows { + if row_i > 0 { + _, _ = buf.WriteString(", (") + } + + if len(row) != len(s.columns) { + return "", errors.Newf( + "# of values does not match # of columns. Generated sql: %s", + buf.String()) + } + + for col_i, value := range row { + if col_i > 0 { + _ = buf.WriteByte(',') + } + + if value == nil { + return "", errors.Newf( + "nil value in row %d col %d. Generated sql: %s", + row_i, + col_i, + buf.String()) + } + + if err = value.SerializeSql(buf); err != nil { + return + } + } + _ = buf.WriteByte(')') + } + } + + if s.query != nil { + err = s.query.SerializeSql(buf) + + if err != nil { + return } - _ = buf.WriteByte(')') } if len(s.returning) > 0 { diff --git a/sqlbuilder/insert_statement_test.go b/sqlbuilder/insert_statement_test.go index d43b872..be2a600 100644 --- a/sqlbuilder/insert_statement_test.go +++ b/sqlbuilder/insert_statement_test.go @@ -123,3 +123,26 @@ func TestInsertValuesFromModelColumnMismatch(t *testing.T) { fmt.Println(err) assert.Assert(t, err != nil) } + +func TestInsertQuery(t *testing.T) { + + stmt := table1.INSERT(table1Col1). + QUERY(table1.SELECT(table1Col1)) + + stmtStr, err := stmt.String() + + assert.NilError(t, err) + + fmt.Println(stmtStr) +} + +func TestInsertDefaultValue(t *testing.T) { + stmt := table1.INSERT(table1Col1, table1Col2). + VALUES(DEFAULT, "two") + + stmtStr, err := stmt.String() + + assert.NilError(t, err) + + fmt.Println(stmtStr) +} diff --git a/sqlbuilder/keyword.go b/sqlbuilder/keyword.go new file mode 100644 index 0000000..a6eaf8c --- /dev/null +++ b/sqlbuilder/keyword.go @@ -0,0 +1,15 @@ +package sqlbuilder + +import "bytes" + +const ( + DEFAULT keywordClause = "DEFAULT" +) + +type keywordClause string + +func (k keywordClause) SerializeSql(out *bytes.Buffer, options ...serializeOption) error { + out.WriteString(string(k)) + + return nil +} diff --git a/sqlbuilder/select_statement.go b/sqlbuilder/select_statement.go index 151e87b..9ecfb1e 100644 --- a/sqlbuilder/select_statement.go +++ b/sqlbuilder/select_statement.go @@ -73,7 +73,7 @@ func (s *selectStatementImpl) SerializeSql(out *bytes.Buffer, options ...seriali return err } - out.WriteString("( ") + out.WriteString("(") out.WriteString(str) out.WriteString(")") diff --git a/tests/insert_test.go b/tests/insert_test.go index 3178793..6265f3f 100644 --- a/tests/insert_test.go +++ b/tests/insert_test.go @@ -2,6 +2,8 @@ package tests import ( "fmt" + "github.com/davecgh/go-spew/spew" + "github.com/sub0zero/go-sqlbuilder/sqlbuilder" "github.com/sub0zero/go-sqlbuilder/tests/.test_files/dvd_rental/test_sample/model" "github.com/sub0zero/go-sqlbuilder/tests/.test_files/dvd_rental/test_sample/table" "gotest.tools/assert" @@ -9,11 +11,11 @@ import ( ) func TestInsertValues(t *testing.T) { - insertQuery := table.Link.INSERT(table.Link.URL, table.Link.Name). - VALUES("http://www.postgresqltutorial.com", "PostgreSQL Tutorial"). - VALUES("http://www.google.com", "Google"). - VALUES("http://www.yahoo.com", "Yahoo"). - VALUES("http://www.bing.com", "Bing"). + insertQuery := table.Link.INSERT(table.Link.URL, table.Link.Name, table.Link.Rel). + VALUES("http://www.postgresqltutorial.com", "PostgreSQL Tutorial", sqlbuilder.DEFAULT). + VALUES("http://www.google.com", "Google", sqlbuilder.DEFAULT). + VALUES("http://www.yahoo.com", "Yahoo", sqlbuilder.DEFAULT). + VALUES("http://www.bing.com", "Bing", sqlbuilder.DEFAULT). RETURNING(table.Link.ID) insertQueryStr, err := insertQuery.String() @@ -22,7 +24,7 @@ func TestInsertValues(t *testing.T) { fmt.Println(insertQueryStr) - assert.Equal(t, insertQueryStr, `INSERT INTO test_sample.link (url,name) VALUES ('http://www.postgresqltutorial.com','PostgreSQL Tutorial'), ('http://www.google.com','Google'), ('http://www.yahoo.com','Yahoo'), ('http://www.bing.com','Bing') RETURNING link.id;`) + assert.Equal(t, insertQueryStr, `INSERT INTO test_sample.link (url,name,rel) VALUES ('http://www.postgresqltutorial.com','PostgreSQL Tutorial',DEFAULT), ('http://www.google.com','Google',DEFAULT), ('http://www.yahoo.com','Yahoo',DEFAULT), ('http://www.bing.com','Bing',DEFAULT) RETURNING link.id;`) res, err := insertQuery.Execute(db) assert.NilError(t, err) @@ -62,7 +64,8 @@ func TestInsertDataObject(t *testing.T) { Rel: nil, } - query := table.Link.INSERT(table.Link.URL, table.Link.Name). + query := table.Link. + INSERT(table.Link.URL, table.Link.Name). VALUES_MAPPING(linkData) queryStr, err := query.String() @@ -77,3 +80,31 @@ func TestInsertDataObject(t *testing.T) { fmt.Println(result) } + +func TestInsertQuery(t *testing.T) { + + _, err := table.Link.INSERT(table.Link.URL, table.Link.Name). + VALUES("http://www.postgresqltutorial.com", "PostgreSQL Tutorial").Execute(db) + + assert.NilError(t, err) + + query := table.Link. + INSERT(table.Link.URL, table.Link.Name). + QUERY(table.Link.SELECT(table.Link.URL, table.Link.Name)) + + queryStr, err := query.String() + + assert.NilError(t, err) + + fmt.Println(queryStr) + + _, err = query.Execute(db) + + assert.NilError(t, err) + + allLinks := []model.Link{} + err = table.Link.SELECT(table.Link.AllColumns).Execute(db, &allLinks) + assert.NilError(t, err) + + spew.Dump(allLinks) +}