diff --git a/internal/jet/clause.go b/internal/jet/clause.go index a6a49d8..446a545 100644 --- a/internal/jet/clause.go +++ b/internal/jet/clause.go @@ -217,12 +217,13 @@ func (f *ClauseFor) Serialize(statementType StatementType, out *SQLBuilder, opti // ClauseSetStmtOperator struct type ClauseSetStmtOperator struct { - Operator string - All bool - Selects []SerializerStatement - OrderBy ClauseOrderBy - Limit ClauseLimit - Offset ClauseOffset + Operator string + All bool + Selects []SerializerStatement + OrderBy ClauseOrderBy + Limit ClauseLimit + Offset ClauseOffset + SkipSelectWrap bool } // Projections returns set of projections for ClauseSetStmtOperator @@ -242,6 +243,10 @@ func (s *ClauseSetStmtOperator) Serialize(statementType StatementType, out *SQLB for i, selectStmt := range s.Selects { out.NewLine() if i > 0 { + if s.SkipSelectWrap { + out.NewLine() + } + out.WriteString(s.Operator) if s.All { @@ -254,7 +259,11 @@ func (s *ClauseSetStmtOperator) Serialize(statementType StatementType, out *SQLB panic("jet: select statement of '" + s.Operator + "' is nil") } - selectStmt.serialize(statementType, out, FallTrough(options)...) + if s.SkipSelectWrap { + options = append(FallTrough(options), NoWrap) + } + + selectStmt.serialize(statementType, out, options...) } s.OrderBy.Serialize(statementType, out) @@ -360,10 +369,6 @@ type ClauseValuesQuery struct { // Serialize serializes clause into SQLBuilder func (v *ClauseValuesQuery) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { - if len(v.Rows) == 0 && v.Query == nil { - panic("jet: VALUES or QUERY has to be specified for INSERT statement") - } - if len(v.Rows) > 0 && v.Query != nil { panic("jet: VALUES or QUERY has to be specified for INSERT statement") } @@ -405,7 +410,8 @@ func (v *ClauseValues) Serialize(statementType StatementType, out *SQLBuilder, o // ClauseQuery struct type ClauseQuery struct { - Query SerializerStatement + Query SerializerStatement + SkipSelectWrap bool } // Serialize serializes clause into SQLBuilder @@ -414,7 +420,11 @@ func (v *ClauseQuery) Serialize(statementType StatementType, out *SQLBuilder, op return } - v.Query.serialize(statementType, out, FallTrough(options)...) + if v.SkipSelectWrap { + options = append(FallTrough(options), NoWrap) + } + + v.Query.serialize(statementType, out, options...) } // ClauseDelete struct @@ -561,3 +571,26 @@ type KeywordClause struct { func (k KeywordClause) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { k.serialize(statementType, out, FallTrough(options)...) } + +// ClauseReturning type +type ClauseReturning struct { + ProjectionList []Projection +} + +// Serialize for ClauseReturning +func (r *ClauseReturning) Serialize(statementType StatementType, out *SQLBuilder, options ...SerializeOption) { + if len(r.ProjectionList) == 0 { + return + } + + out.NewLine() + out.WriteString("RETURNING") + out.IncreaseIdent() + out.WriteProjections(statementType, r.ProjectionList) + out.DecreaseIdent() +} + +// Projections for ClauseReturning +func (r ClauseReturning) Projections() ProjectionList { + return r.ProjectionList +} diff --git a/postgres/clause.go b/postgres/clause.go index 6174d4f..3a23fd0 100644 --- a/postgres/clause.go +++ b/postgres/clause.go @@ -4,33 +4,10 @@ import ( "github.com/go-jet/jet/v2/internal/jet" ) -type clauseReturning struct { - ProjectionList []jet.Projection -} - -func (r *clauseReturning) Serialize(statementType jet.StatementType, out *jet.SQLBuilder, options ...jet.SerializeOption) { - if len(r.ProjectionList) == 0 { - return - } - - out.NewLine() - out.WriteString("RETURNING") - out.IncreaseIdent() - out.WriteProjections(statementType, r.ProjectionList) - out.DecreaseIdent() -} - -func (r clauseReturning) Projections() ProjectionList { - return r.ProjectionList -} - -// ========================================== // - type onConflict interface { ON_CONSTRAINT(name string) conflictTarget WHERE(indexPredicate BoolExpression) conflictTarget - DO_NOTHING() InsertStatement - DO_UPDATE(action conflictAction) InsertStatement + conflictTarget } type conflictTarget interface { diff --git a/postgres/delete_statement.go b/postgres/delete_statement.go index ca2816c..2bfbd8c 100644 --- a/postgres/delete_statement.go +++ b/postgres/delete_statement.go @@ -16,7 +16,7 @@ type deleteStatementImpl struct { Delete jet.ClauseStatementBegin Where jet.ClauseWhere - Returning clauseReturning + Returning jet.ClauseReturning } func newDeleteStatement(table WritableTable) DeleteStatement { diff --git a/postgres/insert_statement.go b/postgres/insert_statement.go index a134a12..763e533 100644 --- a/postgres/insert_statement.go +++ b/postgres/insert_statement.go @@ -22,7 +22,11 @@ type InsertStatement interface { func newInsertStatement(table WritableTable, columns []jet.Column) InsertStatement { newInsert := &insertStatementImpl{} newInsert.SerializerStatement = jet.NewStatementImpl(Dialect, jet.InsertStatementType, newInsert, - &newInsert.Insert, &newInsert.ValuesQuery, &newInsert.OnConflict, &newInsert.Returning) + &newInsert.Insert, + &newInsert.ValuesQuery, + &newInsert.OnConflict, + &newInsert.Returning, + ) newInsert.Insert.Table = table newInsert.Insert.Columns = columns @@ -35,7 +39,7 @@ type insertStatementImpl struct { Insert jet.ClauseInsert ValuesQuery jet.ClauseValuesQuery - Returning clauseReturning + Returning jet.ClauseReturning OnConflict onConflictClause } diff --git a/postgres/insert_statement_test.go b/postgres/insert_statement_test.go index fd3a76c..ad687b5 100644 --- a/postgres/insert_statement_test.go +++ b/postgres/insert_statement_test.go @@ -8,7 +8,6 @@ import ( ) func TestInvalidInsert(t *testing.T) { - assertStatementSqlErr(t, table1.INSERT(table1Col1), "jet: VALUES or QUERY has to be specified for INSERT statement") assertStatementSqlErr(t, table1.INSERT(nil).VALUES(1), "jet: nil column in columns list") } diff --git a/postgres/interval_expression.go b/postgres/interval_expression.go index b8468cf..6f5ab58 100644 --- a/postgres/interval_expression.go +++ b/postgres/interval_expression.go @@ -116,7 +116,7 @@ func INTERVAL(quantityAndUnit ...quantityAndUnit) IntervalExpression { panic("jet: invalid number of quantity and unit fields") } - fields := []string{} + var fields []string for i := 0; i < len(quantityAndUnit); i += 2 { quantity := strconv.FormatFloat(quantityAndUnit[i], 'f', -1, 64) diff --git a/postgres/update_statement.go b/postgres/update_statement.go index 594efa4..58c5ba4 100644 --- a/postgres/update_statement.go +++ b/postgres/update_statement.go @@ -22,7 +22,7 @@ type updateStatementImpl struct { Set clauseSet SetNew jet.SetClauseNew Where jet.ClauseWhere - Returning clauseReturning + Returning jet.ClauseReturning } func newUpdateStatement(table WritableTable, columns []jet.Column) UpdateStatement {