jet/sqlbuilder/insert_statement.go

220 lines
4.5 KiB
Go
Raw Normal View History

2019-04-07 09:58:12 +02:00
package sqlbuilder
import (
"database/sql"
2019-06-05 17:15:20 +02:00
"errors"
2019-04-07 09:58:12 +02:00
"github.com/serenize/snaker"
"github.com/sub0zero/go-sqlbuilder/sqlbuilder/execution"
2019-04-07 09:58:12 +02:00
"reflect"
"strings"
)
2019-06-04 12:10:23 +02:00
type InsertStatement interface {
2019-05-12 18:15:23 +02:00
Statement
2019-04-07 09:58:12 +02:00
2019-05-12 18:15:23 +02:00
// Add a row of values to the insert Statement.
2019-06-04 12:10:23 +02:00
VALUES(values ...interface{}) InsertStatement
2019-04-07 09:58:12 +02:00
// Map or stracture mapped to column names
2019-06-04 12:10:23 +02:00
VALUES_MAPPING(data interface{}) InsertStatement
2019-04-07 09:58:12 +02:00
2019-06-04 12:10:23 +02:00
RETURNING(projections ...projection) InsertStatement
2019-04-07 09:58:12 +02:00
2019-06-04 12:10:23 +02:00
QUERY(selectStatement SelectStatement) InsertStatement
2019-04-07 09:58:12 +02:00
}
2019-06-05 17:15:20 +02:00
func newInsertStatement(t WritableTable, columns ...Column) InsertStatement {
2019-04-07 09:58:12 +02:00
return &insertStatementImpl{
2019-04-14 17:55:10 +02:00
table: t,
columns: columns,
2019-04-07 09:58:12 +02:00
}
}
type insertStatementImpl struct {
2019-06-05 17:15:20 +02:00
table WritableTable
columns []Column
2019-05-07 19:06:21 +02:00
rows [][]clause
2019-06-04 12:10:23 +02:00
query SelectStatement
2019-05-07 19:06:21 +02:00
returning []projection
2019-04-07 09:58:12 +02:00
errors []string
}
func (s *insertStatementImpl) Query(db execution.Db, destination interface{}) error {
2019-04-20 19:49:29 +02:00
return Query(s, db, destination)
}
2019-04-07 09:58:12 +02:00
func (u *insertStatementImpl) Execute(db execution.Db) (res sql.Result, err error) {
2019-04-20 19:49:29 +02:00
return Execute(u, db)
2019-04-07 09:58:12 +02:00
}
2019-06-04 12:10:23 +02:00
// Expression or default keyword
func (i *insertStatementImpl) VALUES(values ...interface{}) InsertStatement {
2019-05-01 17:25:10 +02:00
if len(values) == 0 {
2019-05-12 18:15:23 +02:00
return i
2019-05-01 17:25:10 +02:00
}
2019-05-07 19:06:21 +02:00
literalRow := []clause{}
2019-04-07 09:58:12 +02:00
for _, value := range values {
2019-05-07 19:06:21 +02:00
if clause, ok := value.(clause); ok {
2019-04-07 16:54:06 +02:00
literalRow = append(literalRow, clause)
} else {
literalRow = append(literalRow, literal(value))
2019-04-07 16:54:06 +02:00
}
2019-04-07 09:58:12 +02:00
}
2019-05-12 18:15:23 +02:00
i.rows = append(i.rows, literalRow)
return i
2019-04-07 09:58:12 +02:00
}
2019-06-04 12:10:23 +02:00
func (i *insertStatementImpl) VALUES_MAPPING(data interface{}) InsertStatement {
2019-04-07 09:58:12 +02:00
if data == nil {
2019-05-29 14:03:38 +02:00
i.addError("ADD method data is nil.")
2019-04-07 09:58:12 +02:00
return i
}
value := reflect.ValueOf(data)
if value.Kind() == reflect.Ptr {
value = value.Elem()
}
if value.Kind() != reflect.Struct {
2019-05-29 14:03:38 +02:00
i.addError("ADD method data is not struct or pointer to struct.")
2019-04-07 09:58:12 +02:00
return i
}
2019-05-07 19:06:21 +02:00
rowValues := []clause{}
2019-04-07 09:58:12 +02:00
for _, column := range i.columns {
columnName := column.Name()
structFieldName := snaker.SnakeToCamel(columnName)
structField := value.FieldByName(structFieldName)
if !structField.IsValid() {
2019-05-29 14:03:38 +02:00
i.addError("ADD() : Data structure doesn't contain field : " + structFieldName + " for column " + columnName)
2019-04-07 09:58:12 +02:00
return i
}
rowValues = append(rowValues, literal(structField.Interface()))
2019-04-07 09:58:12 +02:00
}
i.rows = append(i.rows, rowValues)
return i
}
2019-06-04 12:10:23 +02:00
func (i *insertStatementImpl) RETURNING(projections ...projection) InsertStatement {
2019-05-03 12:51:57 +02:00
i.returning = defaultProjectionAliasing(projections)
2019-04-07 09:58:12 +02:00
return i
}
2019-06-04 12:10:23 +02:00
func (i *insertStatementImpl) QUERY(selectStatement SelectStatement) InsertStatement {
2019-04-07 16:54:06 +02:00
i.query = selectStatement
return i
}
2019-04-07 09:58:12 +02:00
// addError records a builder error message; Sql reports all recorded
// messages instead of producing a statement.
func (i *insertStatementImpl) addError(err string) {
	recorded := append(i.errors, err)
	i.errors = recorded
}
2019-05-12 18:15:23 +02:00
// DebugSql hands this statement to the package-level DebugSql helper and
// returns its query string and error unchanged.
func (i *insertStatementImpl) DebugSql() (query string, err error) {
	query, err = DebugSql(i)
	return query, err
}
func (s *insertStatementImpl) Sql() (sql string, args []interface{}, err error) {
2019-04-07 09:58:12 +02:00
if len(s.errors) > 0 {
return "", nil, errors.New("sql builder errors: " + strings.Join(s.errors, ", "))
2019-04-07 09:58:12 +02:00
}
queryData := &queryData{}
2019-05-12 18:15:23 +02:00
queryData.nextLine()
queryData.writeString("INSERT INTO")
2019-04-07 09:58:12 +02:00
if s.table == nil {
2019-06-05 17:15:20 +02:00
return "", nil, errors.New("nil tableName.")
2019-04-07 09:58:12 +02:00
}
err = s.table.serialize(insert_statement, queryData)
2019-04-07 09:58:12 +02:00
2019-05-12 18:15:23 +02:00
queryData.writeByte(' ')
if err != nil {
return "", nil, err
}
2019-04-07 09:58:12 +02:00
if len(s.columns) > 0 {
2019-05-12 18:15:23 +02:00
queryData.writeString("(")
err = serializeColumnList(insert_statement, s.columns, queryData)
2019-04-07 09:58:12 +02:00
if err != nil {
return "", nil, err
2019-04-07 09:58:12 +02:00
}
2019-05-12 18:15:23 +02:00
queryData.writeString(")")
2019-04-07 09:58:12 +02:00
}
2019-04-07 16:54:06 +02:00
if len(s.rows) == 0 && s.query == nil {
return "", nil, errors.New("No row or query specified.")
2019-04-07 09:58:12 +02:00
}
2019-04-07 16:54:06 +02:00
if len(s.rows) > 0 && s.query != nil {
return "", nil, errors.New("Only new rows or query has to be specified.")
2019-04-07 16:54:06 +02:00
}
2019-04-07 09:58:12 +02:00
2019-04-07 16:54:06 +02:00
if len(s.rows) > 0 {
2019-05-12 18:15:23 +02:00
queryData.writeString("VALUES")
2019-04-07 16:54:06 +02:00
for row_i, row := range s.rows {
if row_i > 0 {
2019-05-12 18:15:23 +02:00
queryData.writeString(",")
2019-04-07 09:58:12 +02:00
}
2019-05-12 18:15:23 +02:00
queryData.increaseIdent()
queryData.nextLine()
queryData.writeString("(")
2019-04-07 16:54:06 +02:00
if len(row) != len(s.columns) {
return "", nil, errors.New("# of values does not match # of columns.")
2019-04-07 09:58:12 +02:00
}
err = serializeClauseList(insert_statement, row, queryData)
if err != nil {
return "", nil, err
2019-04-07 09:58:12 +02:00
}
queryData.writeByte(')')
2019-05-12 18:15:23 +02:00
queryData.decreaseIdent()
2019-04-07 16:54:06 +02:00
}
}
if s.query != nil {
err = s.query.serialize(insert_statement, queryData)
2019-04-07 16:54:06 +02:00
if err != nil {
return
2019-04-07 09:58:12 +02:00
}
}
if len(s.returning) > 0 {
2019-05-12 18:15:23 +02:00
queryData.nextLine()
queryData.writeString("RETURNING")
2019-04-07 09:58:12 +02:00
err = queryData.writeProjection(insert_statement, s.returning)
2019-04-07 09:58:12 +02:00
2019-04-14 17:55:10 +02:00
if err != nil {
return
2019-04-07 09:58:12 +02:00
}
}
2019-05-12 18:15:23 +02:00
sql, args = queryData.finalize()
2019-04-07 09:58:12 +02:00
2019-05-12 18:15:23 +02:00
return
2019-04-07 09:58:12 +02:00
}