diff --git a/examples/quick-start/quick-start.go b/examples/quick-start/quick-start.go index 48a28ce..cabbb5c 100644 --- a/examples/quick-start/quick-start.go +++ b/examples/quick-start/quick-start.go @@ -24,7 +24,6 @@ const ( ) func main() { - // Connect to database var connectString = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", Host, Port, User, Password, DBName) @@ -98,16 +97,14 @@ func jsonSave(path string, v interface{}) { } func printStatementInfo(stmt SelectStatement) { - query, args, err := stmt.Sql() - panicOnError(err) + query, args := stmt.Sql() fmt.Println("Parameterized query: ") fmt.Println(query) fmt.Println("Arguments: ") fmt.Println(args) - debugSQL, err := stmt.DebugSql() - panicOnError(err) + debugSQL := stmt.DebugSql() fmt.Println("\n\n==============================") diff --git a/execution/execution.go b/execution/execution.go index 03693e9..e7700a7 100644 --- a/execution/execution.go +++ b/execution/execution.go @@ -4,7 +4,6 @@ import ( "context" "database/sql" "database/sql/driver" - "errors" "fmt" "github.com/go-jet/jet/execution/internal" "github.com/go-jet/jet/internal/utils" @@ -19,14 +18,11 @@ import ( // Destination can be either pointer to struct or pointer to slice of structs. func Query(context context.Context, db DB, query string, args []interface{}, destinationPtr interface{}) error { - if utils.IsNil(destinationPtr) { - return errors.New("jet: Destination is nil") - } + utils.MustBeInitializedPtr(db, "jet: db is nil") + utils.MustBeInitializedPtr(destinationPtr, "jet: destination is nil") + utils.MustBe(destinationPtr, reflect.Ptr, "jet: destination has to be a pointer to slice or pointer to struct") destinationPtrType := reflect.TypeOf(destinationPtr) - if destinationPtrType.Kind() != reflect.Ptr { - return errors.New("jet: Destination has to be a pointer to slice or pointer to struct") - } if destinationPtrType.Elem().Kind() == reflect.Slice { return queryToSlice(context, db, query, args, destinationPtr) @@ -52,24 +48,11 @@ func Query(context context.Context, db DB, query string, args []interface{}, des } return nil } else { - return errors.New("jet: unsupported destination type") + panic("jet: destination has to be a pointer to slice or pointer to struct") } } func queryToSlice(ctx context.Context, db DB, query string, args []interface{}, slicePtr interface{}) error { - if db == nil { - return errors.New("jet: db is nil") - } - - if slicePtr == nil { - return errors.New("jet: Destination is nil. ") - } - - destinationType := reflect.TypeOf(slicePtr) - if destinationType.Kind() != reflect.Ptr && destinationType.Elem().Kind() != reflect.Slice { - return errors.New("jet: Destination has to be a pointer to slice. ") - } - if ctx == nil { ctx = context.Background() } @@ -132,9 +115,7 @@ func mapRowToSlice(scanContext *scanContext, groupKey string, slicePtrValue refl return } - if sliceElemType.Kind() != reflect.Struct { - return false, errors.New("jet: Unsupported dest type: " + field.Name + " " + field.Type.String()) - } + utils.TypeMustBe(sliceElemType, reflect.Struct, "jet: unsupported slice element type at '"+fieldToString(field)+"'.") structGroupKey := scanContext.getGroupKey(sliceElemType, field) @@ -315,9 +296,7 @@ func mapRowToStruct(scanContext *scanContext, groupKey string, structPtrValue re func mapRowToDestinationPtr(scanContext *scanContext, groupKey string, destPtrValue reflect.Value, structField *reflect.StructField) (updated bool, err error) { - if destPtrValue.Kind() != reflect.Ptr { - return false, errors.New("jet: Internal error. ") - } + utils.ValueMustBe(destPtrValue, reflect.Ptr, "jet: internal error. Destination is not pointer.") destValueKind := destPtrValue.Elem().Kind() @@ -326,7 +305,7 @@ func mapRowToDestinationPtr(scanContext *scanContext, groupKey string, destPtrVa } else if destValueKind == reflect.Slice { return mapRowToSlice(scanContext, groupKey, destPtrValue, structField) } else { - return false, errors.New("jet: Unsupported dest type: " + structField.Name + " " + structField.Type.String()) + panic("jet: unsupported dest type: " + structField.Name + " " + structField.Type.String()) } } @@ -336,14 +315,12 @@ func mapRowToDestinationValue(scanContext *scanContext, groupKey string, dest re if dest.Kind() != reflect.Ptr { destPtrValue = dest.Addr() - } else if dest.Kind() == reflect.Ptr { + } else { if dest.IsNil() { destPtrValue = reflect.New(dest.Type().Elem()) } else { destPtrValue = dest } - } else { - return false, errors.New("jet: Internal error. ") } updated, err = mapRowToDestinationPtr(scanContext, groupKey, destPtrValue, structField) @@ -602,7 +579,7 @@ func setReflectValue(source, destination reflect.Value) error { } } - return errors.New("jet: can't set " + source.Type().String() + " to " + destination.Type().String()) + panic("jet: can't set " + source.Type().String() + " to " + destination.Type().String()) } func createScanValue(columnTypes []*sql.ColumnType) []interface{} { @@ -869,3 +846,11 @@ func indirectType(reflectType reflect.Type) reflect.Type { } return reflectType.Elem() } + +func fieldToString(field *reflect.StructField) string { + if field == nil { + return "" + } + + return field.Name + " " + field.Type.String() +} diff --git a/internal/jet/alias.go b/internal/jet/alias.go index 5d4c84b..d9b33d7 100644 --- a/internal/jet/alias.go +++ b/internal/jet/alias.go @@ -20,15 +20,9 @@ func (a *alias) fromImpl(subQuery SelectTable) Projection { return &column } -func (a *alias) serializeForProjection(statement StatementType, out *SqlBuilder) error { - err := a.expression.serialize(statement, out) - - if err != nil { - return err - } +func (a *alias) serializeForProjection(statement StatementType, out *SqlBuilder) { + a.expression.serialize(statement, out) out.WriteString("AS") out.WriteAlias(a.alias) - - return nil } diff --git a/internal/jet/bool_expression_test.go b/internal/jet/bool_expression_test.go index 5295374..765cdab 100644 --- a/internal/jet/bool_expression_test.go +++ b/internal/jet/bool_expression_test.go @@ -5,8 +5,8 @@ import ( ) func TestBoolExpressionEQ(t *testing.T) { - assertClauseSerializeErr(t, table1ColBool.EQ(nil), "jet: nil rhs") assertClauseSerialize(t, table1ColBool.EQ(table2ColBool), "(table1.col_bool = table2.col_bool)") + assertClauseSerializeErr(t, table1ColBool.EQ(nil), "jet: rhs is nil for '=' operator") } func TestBoolExpressionNOT_EQ(t *testing.T) { diff --git a/internal/jet/cast.go b/internal/jet/cast.go index b7768af..7369256 100644 --- a/internal/jet/cast.go +++ b/internal/jet/cast.go @@ -34,23 +34,18 @@ type castExpression struct { cast string } -func (b *castExpression) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error { +func (b *castExpression) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) { expression := b.expression castType := b.cast if castOverride := out.Dialect.SerializeOverride("CAST"); castOverride != nil { - return castOverride(expression, String(castType))(statement, out, options...) + castOverride(expression, String(castType))(statement, out, options...) + return } out.WriteString("CAST(") - err := expression.serialize(statement, out, options...) - if err != nil { - return err - } - + expression.serialize(statement, out, options...) out.WriteString("AS") out.WriteString(castType + ")") - - return err } diff --git a/internal/jet/clause.go b/internal/jet/clause.go index b61df06..b64dc0e 100644 --- a/internal/jet/clause.go +++ b/internal/jet/clause.go @@ -1,12 +1,11 @@ package jet import ( - "errors" "github.com/go-jet/jet/internal/utils" ) type Clause interface { - Serialize(statementType StatementType, out *SqlBuilder) error + Serialize(statementType StatementType, out *SqlBuilder) } type ClauseWithProjections interface { @@ -24,7 +23,7 @@ func (s *ClauseSelect) projections() []Projection { return s.Projections } -func (s *ClauseSelect) Serialize(statementType StatementType, out *SqlBuilder) error { +func (s *ClauseSelect) Serialize(statementType StatementType, out *SqlBuilder) { out.NewLine() out.WriteString("SELECT") @@ -33,28 +32,26 @@ func (s *ClauseSelect) Serialize(statementType StatementType, out *SqlBuilder) e } if len(s.Projections) == 0 { - return errors.New("jet: no column selected for Projection") + panic("jet: SELECT clause has to have at least one projection") } - return out.WriteProjections(statementType, s.Projections) + out.WriteProjections(statementType, s.Projections) } type ClauseFrom struct { Table Serializer } -func (f *ClauseFrom) Serialize(statementType StatementType, s *SqlBuilder) error { +func (f *ClauseFrom) Serialize(statementType StatementType, out *SqlBuilder) { if f.Table == nil { - return nil + return } - s.NewLine() - s.WriteString("FROM") + out.NewLine() + out.WriteString("FROM") - s.IncreaseIdent() - err := f.Table.serialize(statementType, s) - s.DecreaseIdent() - - return err + out.IncreaseIdent() + f.Table.serialize(statementType, out) + out.DecreaseIdent() } type ClauseWhere struct { @@ -62,120 +59,108 @@ type ClauseWhere struct { Mandatory bool } -func (c *ClauseWhere) Serialize(statementType StatementType, s *SqlBuilder) error { +func (c *ClauseWhere) Serialize(statementType StatementType, out *SqlBuilder) { if c.Condition == nil { if c.Mandatory { - return errors.New("jet: WHERE clause not set") + panic("jet: WHERE clause not set") } - return nil + return } - s.NewLine() - s.WriteString("WHERE") + out.NewLine() + out.WriteString("WHERE") - s.IncreaseIdent() - err := c.Condition.serialize(statementType, s, noWrap) - s.DecreaseIdent() - - return err + out.IncreaseIdent() + c.Condition.serialize(statementType, out, noWrap) + out.DecreaseIdent() } type ClauseGroupBy struct { List []GroupByClause } -func (c *ClauseGroupBy) Serialize(statementType StatementType, out *SqlBuilder) error { +func (c *ClauseGroupBy) Serialize(statementType StatementType, out *SqlBuilder) { if len(c.List) == 0 { - return nil + return } out.NewLine() out.WriteString("GROUP BY") out.IncreaseIdent() - err := serializeGroupByClauseList(statementType, c.List, out) + serializeGroupByClauseList(statementType, c.List, out) out.DecreaseIdent() - - return err } type ClauseHaving struct { Condition BoolExpression } -func (c *ClauseHaving) Serialize(statementType StatementType, s *SqlBuilder) error { +func (c *ClauseHaving) Serialize(statementType StatementType, out *SqlBuilder) { if c.Condition == nil { - return nil + return } - s.NewLine() - s.WriteString("HAVING") + out.NewLine() + out.WriteString("HAVING") - s.IncreaseIdent() - err := c.Condition.serialize(statementType, s, noWrap) - s.DecreaseIdent() - - return err + out.IncreaseIdent() + c.Condition.serialize(statementType, out, noWrap) + out.DecreaseIdent() } type ClauseOrderBy struct { List []OrderByClause } -func (o *ClauseOrderBy) Serialize(statementType StatementType, s *SqlBuilder) error { +func (o *ClauseOrderBy) Serialize(statementType StatementType, out *SqlBuilder) { if o.List == nil { - return nil + return } - s.NewLine() - s.WriteString("ORDER BY") + out.NewLine() + out.WriteString("ORDER BY") - s.IncreaseIdent() - err := serializeOrderByClauseList(statementType, o.List, s) - s.DecreaseIdent() - - return err + out.IncreaseIdent() + serializeOrderByClauseList(statementType, o.List, out) + out.DecreaseIdent() } type ClauseLimit struct { Count int64 } -func (l *ClauseLimit) Serialize(statementType StatementType, out *SqlBuilder) error { +func (l *ClauseLimit) Serialize(statementType StatementType, out *SqlBuilder) { if l.Count >= 0 { out.NewLine() out.WriteString("LIMIT") out.insertParametrizedArgument(l.Count) } - - return nil } type ClauseOffset struct { Count int64 } -func (o *ClauseOffset) Serialize(statementType StatementType, out *SqlBuilder) error { +func (o *ClauseOffset) Serialize(statementType StatementType, out *SqlBuilder) { if o.Count >= 0 { out.NewLine() out.WriteString("OFFSET") out.insertParametrizedArgument(o.Count) } - - return nil } type ClauseFor struct { Lock SelectLock } -func (f *ClauseFor) Serialize(statementType StatementType, out *SqlBuilder) error { +func (f *ClauseFor) Serialize(statementType StatementType, out *SqlBuilder) { if f.Lock == nil { - return nil + return } out.NewLine() out.WriteString("FOR") - return f.Lock.serialize(statementType, out) + f.Lock.serialize(statementType, out) } type ClauseSetStmtOperator struct { @@ -194,9 +179,9 @@ func (s *ClauseSetStmtOperator) projections() []Projection { return nil } -func (s *ClauseSetStmtOperator) Serialize(statementType StatementType, out *SqlBuilder) error { +func (s *ClauseSetStmtOperator) Serialize(statementType StatementType, out *SqlBuilder) { if len(s.Selects) < 2 { - return errors.New("jet: UNION Statement must have at least two SELECT statements") + panic("jet: UNION Statement must contain at least two SELECT statements") } wrap := s.OrderBy.List != nil || s.Limit.Count >= 0 || s.Offset.Count >= 0 @@ -219,14 +204,10 @@ func (s *ClauseSetStmtOperator) Serialize(statementType StatementType, out *SqlB } if selectStmt == nil { - return errors.New("jet: select statement is nil") + panic("jet: select statement of '" + s.Operator + "' is nil") } - err := selectStmt.serialize(statementType, out) - - if err != nil { - return err - } + selectStmt.serialize(statementType, out) } if wrap { @@ -235,38 +216,24 @@ func (s *ClauseSetStmtOperator) Serialize(statementType StatementType, out *SqlB out.WriteString(")") } - if err := s.OrderBy.Serialize(statementType, out); err != nil { - return err - } - - if err := s.Limit.Serialize(statementType, out); err != nil { - return err - } - - if err := s.Offset.Serialize(statementType, out); err != nil { - return err - } - - return nil + s.OrderBy.Serialize(statementType, out) + s.Limit.Serialize(statementType, out) + s.Offset.Serialize(statementType, out) } type ClauseUpdate struct { Table SerializerTable } -func (u *ClauseUpdate) Serialize(statementType StatementType, out *SqlBuilder) error { +func (u *ClauseUpdate) Serialize(statementType StatementType, out *SqlBuilder) { out.NewLine() out.WriteString("UPDATE") if utils.IsNil(u.Table) { - return errors.New("jet: table to update is nil") + panic("jet: table to update is nil") } - if err := u.Table.serialize(statementType, out); err != nil { - return err - } - - return nil + u.Table.serialize(statementType, out) } type ClauseSet struct { @@ -274,12 +241,12 @@ type ClauseSet struct { Values []Serializer } -func (s *ClauseSet) Serialize(statementType StatementType, out *SqlBuilder) error { +func (s *ClauseSet) Serialize(statementType StatementType, out *SqlBuilder) { out.NewLine() out.WriteString("SET") if len(s.Columns) != len(s.Values) { - return errors.New("jet: mismatch in numers of columns and values") + panic("jet: mismatch in numbers of columns and values for SET clause") } out.IncreaseIdent(4) @@ -290,20 +257,16 @@ func (s *ClauseSet) Serialize(statementType StatementType, out *SqlBuilder) erro } if column == nil { - return errors.New("jet: nil column in columns list") + panic("jet: nil column in columns list for SET clause") } out.WriteString(column.Name()) out.WriteString(" = ") - if err := s.Values[i].serialize(UpdateStatementType, out); err != nil { - return err - } + s.Values[i].serialize(UpdateStatementType, out) } out.DecreaseIdent(4) - - return nil } type ClauseInsert struct { @@ -319,33 +282,23 @@ func (i *ClauseInsert) GetColumns() []Column { return i.Table.columns() } -func (i *ClauseInsert) Serialize(statementType StatementType, out *SqlBuilder) error { +func (i *ClauseInsert) Serialize(statementType StatementType, out *SqlBuilder) { out.NewLine() out.WriteString("INSERT INTO") if utils.IsNil(i.Table) { - return errors.New("jet: table is nil") + panic("jet: table is nil for INSERT clause") } - err := i.Table.serialize(statementType, out) - - if err != nil { - return err - } + i.Table.serialize(statementType, out) if len(i.Columns) > 0 { out.WriteString("(") - err = SerializeColumnNames(i.Columns, out) - - if err != nil { - return err - } + SerializeColumnNames(i.Columns, out) out.WriteString(")") } - - return nil } type ClauseValuesQuery struct { @@ -353,33 +306,26 @@ type ClauseValuesQuery struct { ClauseQuery } -func (v *ClauseValuesQuery) Serialize(statementType StatementType, out *SqlBuilder) error { +func (v *ClauseValuesQuery) Serialize(statementType StatementType, out *SqlBuilder) { if len(v.Rows) == 0 && v.Query == nil { - return errors.New("jet: no row values or query specified") + panic("jet: VALUES or QUERY has to be specified for INSERT statement") } if len(v.Rows) > 0 && v.Query != nil { - return errors.New("jet: only row values or query has to be specified") + panic("jet: VALUES or QUERY has to be specified for INSERT statement") } - if err := v.ClauseValues.Serialize(statementType, out); err != nil { - return err - } - - if err := v.ClauseQuery.Serialize(statementType, out); err != nil { - return err - } - - return nil + v.ClauseValues.Serialize(statementType, out) + v.ClauseQuery.Serialize(statementType, out) } type ClauseValues struct { Rows [][]Serializer } -func (v *ClauseValues) Serialize(statementType StatementType, out *SqlBuilder) error { +func (v *ClauseValues) Serialize(statementType StatementType, out *SqlBuilder) { if len(v.Rows) == 0 { - return nil + return } out.WriteString("VALUES") @@ -393,47 +339,38 @@ func (v *ClauseValues) Serialize(statementType StatementType, out *SqlBuilder) e out.NewLine() out.WriteString("(") - err := SerializeClauseList(statementType, row, out) - - if err != nil { - return err - } + SerializeClauseList(statementType, row, out) out.WriteByte(')') out.DecreaseIdent() } - return nil } type ClauseQuery struct { Query SerializerStatement } -func (v *ClauseQuery) Serialize(statementType StatementType, out *SqlBuilder) error { +func (v *ClauseQuery) Serialize(statementType StatementType, out *SqlBuilder) { if v.Query == nil { - return nil + return } - return v.Query.serialize(statementType, out) + v.Query.serialize(statementType, out) } type ClauseDelete struct { Table SerializerTable } -func (d *ClauseDelete) Serialize(statementType StatementType, out *SqlBuilder) error { +func (d *ClauseDelete) Serialize(statementType StatementType, out *SqlBuilder) { out.NewLine() out.WriteString("DELETE FROM") if d.Table == nil { - return errors.New("jet: nil tableName") + panic("jet: nil table in DELETE clause") } - if err := d.Table.serialize(statementType, out); err != nil { - return err - } - - return nil + d.Table.serialize(statementType, out) } type ClauseStatementBegin struct { @@ -441,7 +378,7 @@ type ClauseStatementBegin struct { Tables []SerializerTable } -func (d *ClauseStatementBegin) Serialize(statementType StatementType, out *SqlBuilder) error { +func (d *ClauseStatementBegin) Serialize(statementType StatementType, out *SqlBuilder) { out.NewLine() out.WriteString(d.Name) @@ -450,14 +387,8 @@ func (d *ClauseStatementBegin) Serialize(statementType StatementType, out *SqlBu out.WriteString(", ") } - err := table.serialize(statementType, out) - - if err != nil { - return err - } + table.serialize(statementType, out) } - - return nil } type ClauseString struct { @@ -465,11 +396,10 @@ type ClauseString struct { Data string } -func (d *ClauseString) Serialize(statementType StatementType, out *SqlBuilder) error { +func (d *ClauseString) Serialize(statementType StatementType, out *SqlBuilder) { out.NewLine() out.WriteString(d.Name) out.WriteString(d.Data) - return nil } type ClauseOptional struct { @@ -477,27 +407,23 @@ type ClauseOptional struct { Show bool } -func (d *ClauseOptional) Serialize(statementType StatementType, out *SqlBuilder) error { +func (d *ClauseOptional) Serialize(statementType StatementType, out *SqlBuilder) { if !d.Show { - return nil + return } - //out.newLine() out.WriteString(d.Name) - return nil } type ClauseIn struct { LockMode string } -func (i *ClauseIn) Serialize(statementType StatementType, out *SqlBuilder) error { +func (i *ClauseIn) Serialize(statementType StatementType, out *SqlBuilder) { if i.LockMode == "" { - return nil + return } out.WriteString("IN") out.WriteString(string(i.LockMode)) out.WriteString("MODE") - - return nil } diff --git a/internal/jet/clause_test.go b/internal/jet/clause_test.go new file mode 100644 index 0000000..6722c10 --- /dev/null +++ b/internal/jet/clause_test.go @@ -0,0 +1,16 @@ +package jet + +import ( + "gotest.tools/assert" + "testing" +) + +func TestClauseSelect_Serialize(t *testing.T) { + defer func() { + r := recover() + assert.Equal(t, r, "jet: SELECT clause has to have at least one projection") + }() + + selectClause := &ClauseSelect{} + selectClause.Serialize(SelectStatementType, &SqlBuilder{}) +} diff --git a/internal/jet/column.go b/internal/jet/column.go index fed37d0..4a9473d 100644 --- a/internal/jet/column.go +++ b/internal/jet/column.go @@ -62,31 +62,25 @@ func (c *columnImpl) defaultAlias() string { return c.name } -func (c *columnImpl) serializeForOrderBy(statement StatementType, out *SqlBuilder) error { +func (c *columnImpl) serializeForOrderBy(statement StatementType, out *SqlBuilder) { if statement == SetStatementType { // set Statement (UNION, EXCEPT ...) can reference only select projections in order by clause out.WriteAlias(c.defaultAlias()) //always quote - return nil + return } - return c.serialize(statement, out) + c.serialize(statement, out) } -func (c columnImpl) serializeForProjection(statement StatementType, out *SqlBuilder) error { - err := c.serialize(statement, out) - - if err != nil { - return err - } +func (c columnImpl) serializeForProjection(statement StatementType, out *SqlBuilder) { + c.serialize(statement, out) out.WriteString("AS") out.WriteAlias(c.defaultAlias()) - - return nil } -func (c columnImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error { +func (c columnImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) { if c.subQuery != nil { out.WriteIdentifier(c.subQuery.Alias()) @@ -100,8 +94,6 @@ func (c columnImpl) serialize(statement StatementType, out *SqlBuilder, options out.WriteIdentifier(c.name) } - - return nil } //------------------------------------------------------// @@ -134,16 +126,10 @@ func (cl columnListImpl) fromImpl(subQuery SelectTable) Projection { return newProjectionList } -func (cl columnListImpl) serializeForProjection(statement StatementType, out *SqlBuilder) error { +func (cl columnListImpl) serializeForProjection(statement StatementType, out *SqlBuilder) { projections := ColumnListToProjectionList(cl) - err := SerializeProjectionList(statement, projections, out) - - if err != nil { - return err - } - - return nil + SerializeProjectionList(statement, projections, out) } // dummy column interface implementation diff --git a/internal/jet/column_types_test.go b/internal/jet/column_types_test.go index 2e12ef1..0c154ec 100644 --- a/internal/jet/column_types_test.go +++ b/internal/jet/column_types_test.go @@ -4,7 +4,7 @@ import ( "testing" ) -var subQuery = &SelectTableImpl2{ +var subQuery = &SelectTableImpl{ alias: "sub_query", } diff --git a/internal/jet/dialect.go b/internal/jet/dialect.go index 7fc1111..d07320d 100644 --- a/internal/jet/dialect.go +++ b/internal/jet/dialect.go @@ -9,11 +9,9 @@ type Dialect interface { ArgumentPlaceholder() QueryPlaceholderFunc } -type SerializeFunc func(statement StatementType, out *SqlBuilder, options ...SerializeOption) error +type SerializeFunc func(statement StatementType, out *SqlBuilder, options ...SerializeOption) type SerializeOverride func(expressions ...Expression) SerializeFunc - type QueryPlaceholderFunc func(ord int) string -type UpdateAssigmentFunc func(columns []Column, values []Serializer, out *SqlBuilder) (err error) type DialectParams struct { Name string @@ -42,7 +40,6 @@ type dialectImpl struct { aliasQuoteChar byte identifierQuoteChar byte argumentPlaceholder QueryPlaceholderFunc - setClause UpdateAssigmentFunc supportsReturning bool } diff --git a/internal/jet/enum_value.go b/internal/jet/enum_value.go index c7befab..c1b2067 100644 --- a/internal/jet/enum_value.go +++ b/internal/jet/enum_value.go @@ -17,7 +17,6 @@ func NewEnumValue(name string) StringExpression { return enumValue } -func (e enumValue) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error { +func (e enumValue) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) { out.insertConstantArgument(e.name) - return nil } diff --git a/internal/jet/expression.go b/internal/jet/expression.go index d86d753..817ce6f 100644 --- a/internal/jet/expression.go +++ b/internal/jet/expression.go @@ -1,9 +1,5 @@ package jet -import ( - "errors" -) - // Expression is common interface for all expressions. // Can be Bool, Int, Float, String, Date, Time, Timez, Timestamp or Timestampz expressions. type Expression interface { @@ -67,16 +63,16 @@ func (e *ExpressionInterfaceImpl) DESC() OrderByClause { return newOrderByClause(e.Parent, false) } -func (e *ExpressionInterfaceImpl) serializeForGroupBy(statement StatementType, out *SqlBuilder) error { - return e.Parent.serialize(statement, out, noWrap) +func (e *ExpressionInterfaceImpl) serializeForGroupBy(statement StatementType, out *SqlBuilder) { + e.Parent.serialize(statement, out, noWrap) } -func (e *ExpressionInterfaceImpl) serializeForProjection(statement StatementType, out *SqlBuilder) error { - return e.Parent.serialize(statement, out, noWrap) +func (e *ExpressionInterfaceImpl) serializeForProjection(statement StatementType, out *SqlBuilder) { + e.Parent.serialize(statement, out, noWrap) } -func (e *ExpressionInterfaceImpl) serializeForOrderBy(statement StatementType, out *SqlBuilder) error { - return e.Parent.serialize(statement, out, noWrap) +func (e *ExpressionInterfaceImpl) serializeForOrderBy(statement StatementType, out *SqlBuilder) { + e.Parent.serialize(statement, out, noWrap) } // Representation of binary operations (e.g. comparisons, arithmetic) @@ -95,15 +91,12 @@ func newBinaryExpression(lhs, rhs Expression, operator string) binaryOpExpressio return binaryExpression } -func (c *binaryOpExpression) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) (err error) { - if c == nil { - return errors.New("jet: binary Expression is nil") - } +func (c *binaryOpExpression) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) { if c.lhs == nil { - return errors.New("jet: nil lhs") + panic("jet: lhs is nil for '" + c.operator + "' operator") } if c.rhs == nil { - return errors.New("jet: nil rhs") + panic("jet: rhs is nil for '" + c.operator + "' operator") } wrap := !contains(options, noWrap) @@ -113,27 +106,17 @@ func (c *binaryOpExpression) serialize(statement StatementType, out *SqlBuilder, } if serializeOverride := out.Dialect.SerializeOverride(c.operator); serializeOverride != nil { - serializeOverrideFunc := serializeOverride(c.lhs, c.rhs) - err = serializeOverrideFunc(statement, out, options...) - + serializeOverrideFunc(statement, out, options...) } else { - if err := c.lhs.serialize(statement, out); err != nil { - return err - } - + c.lhs.serialize(statement, out) out.WriteString(c.operator) - - if err := c.rhs.serialize(statement, out); err != nil { - return err - } + c.rhs.serialize(statement, out) } if wrap { out.WriteString(")") } - - return err } // A prefix operator Expression @@ -151,24 +134,17 @@ func newPrefixExpression(expression Expression, operator string) prefixOpExpress return prefixExpression } -func (p *prefixOpExpression) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error { - if p == nil { - return errors.New("jet: Prefix Expression is nil") - } - +func (p *prefixOpExpression) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) { out.WriteString("(") out.WriteString(p.operator) if p.expression == nil { - return errors.New("jet: nil prefix Expression") - } - if err := p.expression.serialize(statement, out); err != nil { - return err + panic("jet: nil prefix expression in prefix operator " + p.operator) } + p.expression.serialize(statement, out) + out.WriteString(")") - - return nil } // A postifx operator Expression @@ -186,19 +162,12 @@ func newPostfixOpExpression(expression Expression, operator string) postfixOpExp return postfixOpExpression } -func (p *postfixOpExpression) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error { - if p == nil { - return errors.New("jet: Postifx operator Expression is nil") +func (p *postfixOpExpression) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) { + if p.expression == nil { + panic("jet: nil prefix expression in postfix operator " + p.operator) } - if p.expression == nil { - return errors.New("jet: nil prefix Expression") - } - if err := p.expression.serialize(statement, out); err != nil { - return err - } + p.expression.serialize(statement, out) out.WriteString(p.operator) - - return nil } diff --git a/internal/jet/expression_test.go b/internal/jet/expression_test.go index 8db4cf9..7b7bed6 100644 --- a/internal/jet/expression_test.go +++ b/internal/jet/expression_test.go @@ -4,10 +4,13 @@ import ( "testing" ) +func TestInvalidExpression(t *testing.T) { + assertClauseSerializeErr(t, table2Col3.ADD(nil), `jet: rhs is nil for '+' operator`) +} + func TestExpressionIS_NULL(t *testing.T) { assertClauseSerialize(t, table2Col3.IS_NULL(), "table2.col3 IS NULL") assertClauseSerialize(t, table2Col3.ADD(table2Col3).IS_NULL(), "(table2.col3 + table2.col3) IS NULL") - assertClauseSerializeErr(t, table2Col3.ADD(nil), "jet: nil rhs") } func TestExpressionIS_NOT_NULL(t *testing.T) { diff --git a/internal/jet/func_expression.go b/internal/jet/func_expression.go index 8495494..45ee0f5 100644 --- a/internal/jet/func_expression.go +++ b/internal/jet/func_expression.go @@ -1,7 +1,5 @@ package jet -import "errors" - // ROW is construct one table row from list of expressions. func ROW(expressions ...Expression) Expression { return newFunc("ROW", expressions, nil) @@ -500,15 +498,11 @@ func newFunc(name string, expressions []Expression, parent Expression) *funcExpr return funcExp } -func (f *funcExpressionImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error { - if f == nil { - return errors.New("jet: Function expressions is nil. ") - } - +func (f *funcExpressionImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) { if serializeOverride := out.Dialect.SerializeOverride(f.name); serializeOverride != nil { - serializeOverrideFunc := serializeOverride(f.expressions...) - return serializeOverrideFunc(statement, out, options...) + serializeOverrideFunc(statement, out, options...) + return } addBrackets := !f.noBrackets || len(f.expressions) > 0 @@ -519,16 +513,11 @@ func (f *funcExpressionImpl) serialize(statement StatementType, out *SqlBuilder, out.WriteString(f.name) } - err := serializeExpressionList(statement, f.expressions, ", ", out) - if err != nil { - return err - } + serializeExpressionList(statement, f.expressions, ", ", out) if addBrackets { out.WriteString(")") } - - return nil } type boolFunc struct { diff --git a/internal/jet/group_by_clause.go b/internal/jet/group_by_clause.go index 7a629f6..d010da9 100644 --- a/internal/jet/group_by_clause.go +++ b/internal/jet/group_by_clause.go @@ -1,5 +1,5 @@ package jet type GroupByClause interface { - serializeForGroupBy(statement StatementType, out *SqlBuilder) error + serializeForGroupBy(statement StatementType, out *SqlBuilder) } diff --git a/internal/jet/keyword.go b/internal/jet/keyword.go index be04d02..b8f30cd 100644 --- a/internal/jet/keyword.go +++ b/internal/jet/keyword.go @@ -14,8 +14,6 @@ var ( type keywordClause string -func (k keywordClause) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error { +func (k keywordClause) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) { out.WriteString(string(k)) - - return nil } diff --git a/internal/jet/literal_expression.go b/internal/jet/literal_expression.go index a5fbbaa..ccb9940 100644 --- a/internal/jet/literal_expression.go +++ b/internal/jet/literal_expression.go @@ -39,14 +39,12 @@ func constLiteral(value interface{}) *literalExpressionImpl { return exp } -func (l *literalExpressionImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error { +func (l *literalExpressionImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) { if l.constant { out.insertConstantArgument(l.value) } else { out.insertParametrizedArgument(l.value) } - - return nil } func (l *literalExpressionImpl) Value() interface{} { @@ -286,9 +284,8 @@ func newNullLiteral() Expression { return nullExpression } -func (n *nullLiteral) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error { +func (n *nullLiteral) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) { out.WriteString("NULL") - return nil } //--------------------------------------------------// @@ -304,9 +301,8 @@ func newStarLiteral() Expression { return starExpression } -func (n *starLiteral) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error { +func (n *starLiteral) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) { out.WriteString("*") - return nil } //---------------------------------------------------// @@ -316,11 +312,10 @@ type wrap struct { expressions []Expression } -func (n *wrap) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error { +func (n *wrap) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) { out.WriteString("(") - err := serializeExpressionList(statement, n.expressions, ", ", out) + serializeExpressionList(statement, n.expressions, ", ", out) out.WriteString(")") - return err } // WRAP wraps list of expressions with brackets '(' and ')' @@ -339,9 +334,8 @@ type rawExpression struct { raw string } -func (n *rawExpression) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error { +func (n *rawExpression) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) { out.WriteString(n.raw) - return nil } // Raw can be used for any unsupported functions, operators or expressions. diff --git a/internal/jet/operators.go b/internal/jet/operators.go index f8dc448..ba1f984 100644 --- a/internal/jet/operators.go +++ b/internal/jet/operators.go @@ -1,7 +1,5 @@ package jet -import "errors" - const ( StringConcatOperator = "||" ) @@ -112,55 +110,33 @@ func (c *caseOperatorImpl) ELSE(els Expression) CaseOperator { return c } -func (c *caseOperatorImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error { - if c == nil { - return errors.New("jet: Case Expression is nil. ") - } - +func (c *caseOperatorImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) { out.WriteString("(CASE") if c.expression != nil { - err := c.expression.serialize(statement, out) - - if err != nil { - return err - } + c.expression.serialize(statement, out) } if len(c.when) == 0 || len(c.then) == 0 { - return errors.New("jet: Invalid case Statement. There should be at least one when/then Expression pair. ") + panic("jet: invalid case Statement. There should be at least one WHEN/THEN pair. ") } if len(c.when) != len(c.then) { - return errors.New("jet: When and then Expression count mismatch. ") + panic("jet: WHEN and THEN expression count mismatch. ") } for i, when := range c.when { out.WriteString("WHEN") - err := when.serialize(statement, out, noWrap) - - if err != nil { - return err - } + when.serialize(statement, out, noWrap) out.WriteString("THEN") - err = c.then[i].serialize(statement, out, noWrap) - - if err != nil { - return err - } + c.then[i].serialize(statement, out, noWrap) } if c.els != nil { out.WriteString("ELSE") - err := c.els.serialize(statement, out, noWrap) - - if err != nil { - return err - } + c.els.serialize(statement, out, noWrap) } out.WriteString("END)") - - return nil } diff --git a/internal/jet/order_by_clause.go b/internal/jet/order_by_clause.go index 782f366..42d12d4 100644 --- a/internal/jet/order_by_clause.go +++ b/internal/jet/order_by_clause.go @@ -1,10 +1,8 @@ package jet -import "errors" - // OrderByClause type OrderByClause interface { - serializeForOrderBy(statement StatementType, out *SqlBuilder) error + serializeForOrderBy(statement StatementType, out *SqlBuilder) } type orderByClauseImpl struct { @@ -12,22 +10,18 @@ type orderByClauseImpl struct { ascent bool } -func (o *orderByClauseImpl) serializeForOrderBy(statement StatementType, out *SqlBuilder) error { +func (o *orderByClauseImpl) serializeForOrderBy(statement StatementType, out *SqlBuilder) { if o.expression == nil { - return errors.New("jet: nil orderBy by clause") + panic("jet: nil expression in ORDER BY clause") } - if err := o.expression.serializeForOrderBy(statement, out); err != nil { - return err - } + o.expression.serializeForOrderBy(statement, out) if o.ascent { out.WriteString("ASC") } else { out.WriteString("DESC") } - - return nil } func newOrderByClause(expression Expression, ascent bool) OrderByClause { diff --git a/internal/jet/projection.go b/internal/jet/projection.go index 0b5505c..da355c2 100644 --- a/internal/jet/projection.go +++ b/internal/jet/projection.go @@ -1,12 +1,12 @@ package jet type Projection interface { - serializeForProjection(statement StatementType, out *SqlBuilder) error + serializeForProjection(statement StatementType, out *SqlBuilder) fromImpl(subQuery SelectTable) Projection } -func SerializeForProjection(projection Projection, statementType StatementType, out *SqlBuilder) error { - return projection.serializeForProjection(statementType, out) +func SerializeForProjection(projection Projection, statementType StatementType, out *SqlBuilder) { + projection.serializeForProjection(statementType, out) } // ProjectionList is a redefined type, so that ProjectionList can be used as a Projection. @@ -22,12 +22,6 @@ func (cl ProjectionList) fromImpl(subQuery SelectTable) Projection { return newProjectionList } -func (cl ProjectionList) serializeForProjection(statement StatementType, out *SqlBuilder) error { - err := SerializeProjectionList(statement, cl, out) - - if err != nil { - return err - } - - return nil +func (cl ProjectionList) serializeForProjection(statement StatementType, out *SqlBuilder) { + SerializeProjectionList(statement, cl, out) } diff --git a/internal/jet/select_lock.go b/internal/jet/select_lock.go index 8e79b9c..60b38f3 100644 --- a/internal/jet/select_lock.go +++ b/internal/jet/select_lock.go @@ -33,7 +33,7 @@ func (s *selectLockImpl) SKIP_LOCKED() SelectLock { return s } -func (s *selectLockImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error { +func (s *selectLockImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) { out.WriteString(s.lockStrength) if s.noWait { @@ -43,6 +43,4 @@ func (s *selectLockImpl) serialize(statement StatementType, out *SqlBuilder, opt if s.skipLocked { out.WriteString("SKIP LOCKED") } - - return nil } diff --git a/internal/jet/select_table.go b/internal/jet/select_table.go index de05285..e32d714 100644 --- a/internal/jet/select_table.go +++ b/internal/jet/select_table.go @@ -1,22 +1,20 @@ package jet -import "errors" - // SelectTable is interface for SELECT sub-queries type SelectTable interface { Alias() string AllColumns() ProjectionList } -type SelectTableImpl2 struct { +type SelectTableImpl struct { selectStmt StatementWithProjections alias string projections []Projection } -func NewSelectTable(selectStmt StatementWithProjections, alias string) SelectTableImpl2 { - selectTable := SelectTableImpl2{selectStmt: selectStmt, alias: alias} +func NewSelectTable(selectStmt StatementWithProjections, alias string) SelectTableImpl { + selectTable := SelectTableImpl{selectStmt: selectStmt, alias: alias} for _, projection := range selectStmt.projections() { newProjection := projection.fromImpl(&selectTable) @@ -27,27 +25,21 @@ func NewSelectTable(selectStmt StatementWithProjections, alias string) SelectTab return selectTable } -func (s *SelectTableImpl2) Alias() string { +func (s *SelectTableImpl) Alias() string { return s.alias } -func (s *SelectTableImpl2) AllColumns() ProjectionList { +func (s *SelectTableImpl) AllColumns() ProjectionList { return s.projections } -func (s *SelectTableImpl2) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error { +func (s *SelectTableImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) { if s == nil { - return errors.New("jet: Expression table is nil. ") + panic("jet: expression table is nil. ") } - err := s.selectStmt.serialize(statement, out) - - if err != nil { - return err - } + s.selectStmt.serialize(statement, out) out.WriteString("AS") out.WriteIdentifier(s.alias) - - return nil } diff --git a/internal/jet/serializer.go b/internal/jet/serializer.go index 47fcab5..b624f61 100644 --- a/internal/jet/serializer.go +++ b/internal/jet/serializer.go @@ -19,11 +19,11 @@ const ( ) type Serializer interface { - serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error + serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) } -func Serialize(exp Serializer, statementType StatementType, out *SqlBuilder, options ...SerializeOption) error { - return exp.serialize(statementType, out, options...) +func Serialize(exp Serializer, statementType StatementType, out *SqlBuilder, options ...SerializeOption) { + exp.serialize(statementType, out, options...) } func contains(options []SerializeOption, option SerializeOption) bool { diff --git a/internal/jet/sql_builder.go b/internal/jet/sql_builder.go index 6f4d089..9f9f0a3 100644 --- a/internal/jet/sql_builder.go +++ b/internal/jet/sql_builder.go @@ -44,11 +44,10 @@ func (s *SqlBuilder) DecreaseIdent(ident ...int) { s.ident -= toDecrease } -func (s *SqlBuilder) WriteProjections(statement StatementType, projections []Projection) error { +func (s *SqlBuilder) WriteProjections(statement StatementType, projections []Projection) { s.IncreaseIdent() - err := SerializeProjectionList(statement, projections, s) + SerializeProjectionList(statement, projections, s) s.DecreaseIdent() - return err } func (s *SqlBuilder) NewLine() { diff --git a/internal/jet/statement.go b/internal/jet/statement.go index 07afd31..4523750 100644 --- a/internal/jet/statement.go +++ b/internal/jet/statement.go @@ -3,19 +3,16 @@ package jet import ( "context" "database/sql" - "errors" "github.com/go-jet/jet/execution" ) //Statement is common interface for all statements(SELECT, INSERT, UPDATE, DELETE, LOCK) type Statement interface { // Sql returns parametrized sql query with list of arguments. - // err is returned if statement is not composed correctly - Sql() (query string, args []interface{}, err error) + Sql() (query string, args []interface{}) // DebugSql returns debug query where every parametrized placeholder is replaced with its argument. // Do not use it in production. Use it only for debug purposes. - // err is returned if statement is not composed correctly - DebugSql() (query string, err error) + DebugSql() (query string) // Query executes statement over database connection db and stores row result in destination. // Destination can be arbitrary structure @@ -51,71 +48,44 @@ type SerializerStatementInterfaceImpl struct { parent SerializerStatement } -func (s *SerializerStatementInterfaceImpl) Sql() (query string, args []interface{}, err error) { +func (s *SerializerStatementInterfaceImpl) Sql() (query string, args []interface{}) { queryData := &SqlBuilder{Dialect: s.dialect} - err = s.parent.serialize(s.statementType, queryData, noWrap) - - if err != nil { - return "", nil, err - } + s.parent.serialize(s.statementType, queryData, noWrap) query, args = queryData.finalize() - return } -func (s *SerializerStatementInterfaceImpl) DebugSql() (query string, err error) { +func (s *SerializerStatementInterfaceImpl) DebugSql() (query string) { sqlBuilder := &SqlBuilder{Dialect: s.dialect, debug: true} - err = s.parent.serialize(s.statementType, sqlBuilder, noWrap) - - if err != nil { - return "", err - } + s.parent.serialize(s.statementType, sqlBuilder, noWrap) query, _ = sqlBuilder.finalize() - return } func (s *SerializerStatementInterfaceImpl) Query(db execution.DB, destination interface{}) error { - query, args, err := s.Sql() - - if err != nil { - return err - } + query, args := s.Sql() return execution.Query(context.Background(), db, query, args, destination) } func (s *SerializerStatementInterfaceImpl) QueryContext(context context.Context, db execution.DB, destination interface{}) error { - query, args, err := s.Sql() - - if err != nil { - return err - } + query, args := s.Sql() return execution.Query(context, db, query, args, destination) } func (s *SerializerStatementInterfaceImpl) Exec(db execution.DB) (res sql.Result, err error) { - query, args, err := s.Sql() - - if err != nil { - return - } - + query, args := s.Sql() return db.Exec(query, args...) } func (s *SerializerStatementInterfaceImpl) ExecContext(context context.Context, db execution.DB) (res sql.Result, err error) { - query, args, err := s.Sql() - - if err != nil { - return - } + query, args := s.Sql() return db.ExecContext(context, query, args...) } @@ -125,8 +95,8 @@ type ExpressionStatementImpl struct { StatementImpl } -func (s *ExpressionStatementImpl) serializeForProjection(statement StatementType, out *SqlBuilder) error { - return s.serialize(statement, out) +func (s *ExpressionStatementImpl) serializeForProjection(statement StatementType, out *SqlBuilder) { + s.serialize(statement, out) } func NewStatementImpl(Dialect Dialect, statementType StatementType, parent SerializerStatement, clauses ...Clause) StatementImpl { @@ -156,23 +126,15 @@ func (s *StatementImpl) projections() []Projection { return nil } -func (s *StatementImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error { - if s == nil { - return errors.New("jet: Select expression is nil. ") - } +func (s *StatementImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) { if !contains(options, noWrap) { out.WriteString("(") - out.IncreaseIdent() } for _, clause := range s.Clauses { - err := clause.Serialize(statement, out) - - if err != nil { - return err - } + clause.Serialize(statement, out) } if !contains(options, noWrap) { @@ -180,6 +142,4 @@ func (s *StatementImpl) serialize(statement StatementType, out *SqlBuilder, opti out.NewLine() out.WriteString(")") } - - return nil } diff --git a/internal/jet/table.go b/internal/jet/table.go index 7551d90..1c71a86 100644 --- a/internal/jet/table.go +++ b/internal/jet/table.go @@ -1,7 +1,6 @@ package jet import ( - "errors" "github.com/go-jet/jet/internal/utils" ) @@ -66,9 +65,9 @@ func (t *TableImpl) columns() []Column { return ret } -func (t *TableImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) error { +func (t *TableImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) { if t == nil { - return errors.New("jet: tableImpl is nil. ") + panic("jet: tableImpl is nil") } out.WriteIdentifier(t.schemaName) @@ -79,8 +78,6 @@ func (t *TableImpl) serialize(statement StatementType, out *SqlBuilder, options out.WriteString("AS") out.WriteIdentifier(t.alias) } - - return nil } type JoinType int @@ -137,18 +134,16 @@ func (t *JoinTableImpl) Columns() []Column { return ret } -func (t *JoinTableImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) (err error) { +func (t *JoinTableImpl) serialize(statement StatementType, out *SqlBuilder, options ...SerializeOption) { if t == nil { - return errors.New("jet: Join table is nil. ") + panic("jet: Join table is nil. ") } if utils.IsNil(t.lhs) { - return errors.New("jet: left hand side of join operation is nil table") + panic("jet: left hand side of join operation is nil table") } - if err = t.lhs.serialize(statement, out); err != nil { - return - } + t.lhs.serialize(statement, out) out.NewLine() @@ -166,25 +161,19 @@ func (t *JoinTableImpl) serialize(statement StatementType, out *SqlBuilder, opti } if utils.IsNil(t.rhs) { - return errors.New("jet: right hand side of join operation is nil table") + panic("jet: right hand side of join operation is nil table") } - if err = t.rhs.serialize(statement, out); err != nil { - return - } + t.rhs.serialize(statement, out) if t.onCondition == nil && t.joinType != CrossJoin { - return errors.New("jet: join condition is nil") + panic("jet: join condition is nil") } if t.onCondition != nil { out.WriteString("ON") - if err = t.onCondition.serialize(statement, out); err != nil { - return - } + t.onCondition.serialize(statement, out) } - - return nil } func UnwindColumns(column1 Column, columns ...Column) []Column { diff --git a/internal/jet/testutils.go b/internal/jet/testutils.go index ebcf5ce..f95ae76 100644 --- a/internal/jet/testutils.go +++ b/internal/jet/testutils.go @@ -48,9 +48,7 @@ var table3 = NewTable("db", "table3", table3Col1, table3ColInt, table3StrCol) func assertClauseSerialize(t *testing.T, clause Serializer, query string, args ...interface{}) { out := SqlBuilder{Dialect: DefaultDialect} - err := clause.serialize(SelectStatementType, &out) - - assert.NilError(t, err) + clause.serialize(SelectStatementType, &out) //fmt.Println(out.Buff.String()) @@ -59,19 +57,18 @@ func assertClauseSerialize(t *testing.T, clause Serializer, query string, args . } func assertClauseSerializeErr(t *testing.T, clause Serializer, errString string) { - out := SqlBuilder{Dialect: DefaultDialect} - err := clause.serialize(SelectStatementType, &out) + defer func() { + r := recover() + assert.Equal(t, r, errString) + }() - //fmt.Println(out.buff.String()) - assert.Assert(t, err != nil) - assert.Error(t, err, errString) + out := SqlBuilder{Dialect: DefaultDialect} + clause.serialize(SelectStatementType, &out) } func assertClauseDebugSerialize(t *testing.T, clause Serializer, query string, args ...interface{}) { out := SqlBuilder{Dialect: DefaultDialect, debug: true} - err := clause.serialize(SelectStatementType, &out) - - assert.NilError(t, err) + clause.serialize(SelectStatementType, &out) //fmt.Println(out.Buff.String()) @@ -81,25 +78,8 @@ func assertClauseDebugSerialize(t *testing.T, clause Serializer, query string, a func assertProjectionSerialize(t *testing.T, projection Projection, query string, args ...interface{}) { out := SqlBuilder{Dialect: DefaultDialect} - err := projection.serializeForProjection(SelectStatementType, &out) - - assert.NilError(t, err) + projection.serializeForProjection(SelectStatementType, &out) assert.DeepEqual(t, out.Buff.String(), query) assert.DeepEqual(t, out.Args, args) } - -func assertStatement(t *testing.T, query Statement, expectedQuery string, expectedArgs ...interface{}) { - queryStr, args, err := query.Sql() - assert.NilError(t, err) - - assert.Equal(t, queryStr, expectedQuery) - assert.DeepEqual(t, args, expectedArgs) -} - -func assertStatementErr(t *testing.T, stmt Statement, errorStr string) { - _, _, err := stmt.Sql() - - assert.Assert(t, err != nil) - assert.Error(t, err, errorStr) -} diff --git a/internal/jet/utils.go b/internal/jet/utils.go index 2000f83..78dc013 100644 --- a/internal/jet/utils.go +++ b/internal/jet/utils.go @@ -1,30 +1,22 @@ package jet import ( - "errors" "github.com/go-jet/jet/internal/utils" "reflect" - "strings" ) -func serializeOrderByClauseList(statement StatementType, orderByClauses []OrderByClause, out *SqlBuilder) error { +func serializeOrderByClauseList(statement StatementType, orderByClauses []OrderByClause, out *SqlBuilder) { for i, value := range orderByClauses { if i > 0 { out.WriteString(", ") } - err := value.serializeForOrderBy(statement, out) - - if err != nil { - return err - } + value.serializeForOrderBy(statement, out) } - - return nil } -func serializeGroupByClauseList(statement StatementType, clauses []GroupByClause, out *SqlBuilder) (err error) { +func serializeGroupByClauseList(statement StatementType, clauses []GroupByClause, out *SqlBuilder) { for i, c := range clauses { if i > 0 { @@ -32,18 +24,14 @@ func serializeGroupByClauseList(statement StatementType, clauses []GroupByClause } if c == nil { - return errors.New("jet: nil clause") + panic("jet: nil clause") } - if err = c.serializeForGroupBy(statement, out); err != nil { - return - } + c.serializeForGroupBy(statement, out) } - - return nil } -func SerializeClauseList(statement StatementType, clauses []Serializer, out *SqlBuilder) (err error) { +func SerializeClauseList(statement StatementType, clauses []Serializer, out *SqlBuilder) { for i, c := range clauses { if i > 0 { @@ -51,35 +39,25 @@ func SerializeClauseList(statement StatementType, clauses []Serializer, out *Sql } if c == nil { - return errors.New("jet: nil clause") + panic("jet: nil clause") } - if err = c.serialize(statement, out); err != nil { - return - } + c.serialize(statement, out) } - - return nil } -func serializeExpressionList(statement StatementType, expressions []Expression, separator string, out *SqlBuilder) error { +func serializeExpressionList(statement StatementType, expressions []Expression, separator string, out *SqlBuilder) { for i, value := range expressions { if i > 0 { out.WriteString(separator) } - err := value.serialize(statement, out) - - if err != nil { - return err - } + value.serialize(statement, out) } - - return nil } -func SerializeProjectionList(statement StatementType, projections []Projection, out *SqlBuilder) error { +func SerializeProjectionList(statement StatementType, projections []Projection, out *SqlBuilder) { for i, col := range projections { if i > 0 { out.WriteString(",") @@ -87,31 +65,25 @@ func SerializeProjectionList(statement StatementType, projections []Projection, } if col == nil { - return errors.New("jet: Projection is nil") + panic("jet: Projection is nil") } - if err := col.serializeForProjection(statement, out); err != nil { - return err - } + col.serializeForProjection(statement, out) } - - return nil } -func SerializeColumnNames(columns []Column, out *SqlBuilder) error { +func SerializeColumnNames(columns []Column, out *SqlBuilder) { for i, col := range columns { if i > 0 { out.WriteString(", ") } if col == nil { - return errors.New("jet: nil column in columns list") + panic("jet: nil column in columns list") } out.WriteString(col.Name()) } - - return nil } func ColumnListToProjectionList(columns []ColumnExpression) []Projection { @@ -137,7 +109,7 @@ func UnwindRowFromModel(columns []Column, data interface{}) []Serializer { row := []Serializer{} - mustBe(structValue, reflect.Struct) + utils.ValueMustBe(structValue, reflect.Struct, "jet: data has to be a struct") for _, column := range columns { columnName := column.Name() @@ -165,7 +137,7 @@ func UnwindRowFromModel(columns []Column, data interface{}) []Serializer { func UnwindRowsFromModels(columns []Column, data interface{}) [][]Serializer { sliceValue := reflect.Indirect(reflect.ValueOf(data)) - mustBe(sliceValue, reflect.Slice) + utils.ValueMustBe(sliceValue, reflect.Slice, "jet: data has to be a slice.") rows := [][]Serializer{} @@ -189,17 +161,3 @@ func UnwindRowFromValues(value interface{}, values []interface{}) []Serializer { return row } - -func mustBe(v reflect.Value, expectedKinds ...reflect.Kind) { - indirectV := reflect.Indirect(v) - types := []string{} - - for _, expectedKind := range expectedKinds { - types = append(types, expectedKind.String()) - if k := indirectV.Kind(); k == expectedKind { - return - } - } - - panic("argument mismatch: expected " + strings.Join(types, " or ") + ", got " + v.Type().String()) -} diff --git a/internal/testutils/test_utils.go b/internal/testutils/test_utils.go index 59ee748..375785f 100644 --- a/internal/testutils/test_utils.go +++ b/internal/testutils/test_utils.go @@ -78,8 +78,7 @@ func AssertJSONFile(t *testing.T, data interface{}, testRelativePath string) { } func AssertStatementSql(t *testing.T, query jet.Statement, expectedQuery string, expectedArgs ...interface{}) { - queryStr, args, err := query.Sql() - assert.NilError(t, err) + queryStr, args := query.Sql() assert.Equal(t, queryStr, expectedQuery) if len(expectedArgs) == 0 { @@ -89,32 +88,28 @@ func AssertStatementSql(t *testing.T, query jet.Statement, expectedQuery string, } func AssertStatementSqlErr(t *testing.T, stmt jet.Statement, errorStr string) { - _, _, err := stmt.Sql() + defer func() { + r := recover() + assert.Equal(t, r, errorStr) + }() - assert.Assert(t, err != nil) - assert.Error(t, err, errorStr) + stmt.Sql() } func AssertDebugStatementSql(t *testing.T, query jet.Statement, expectedQuery string, expectedArgs ...interface{}) { - _, args, err := query.Sql() - assert.NilError(t, err) - //assert.Equal(t, queryStr, expectedQuery) + _, args := query.Sql() + if len(expectedArgs) > 0 { assert.DeepEqual(t, args, expectedArgs) } - debuqSql, err := query.DebugSql() - - assert.NilError(t, err) - + debuqSql := query.DebugSql() assert.Equal(t, debuqSql, expectedQuery) } func AssertClauseSerialize(t *testing.T, dialect jet.Dialect, clause jet.Serializer, query string, args ...interface{}) { out := jet.SqlBuilder{Dialect: dialect} - err := jet.Serialize(clause, jet.SelectStatementType, &out) - - assert.NilError(t, err) + jet.Serialize(clause, jet.SelectStatementType, &out) //fmt.Println(out.Buff.String()) @@ -123,24 +118,31 @@ func AssertClauseSerialize(t *testing.T, dialect jet.Dialect, clause jet.Seriali if len(args) > 0 { assert.DeepEqual(t, out.Args, args) } - } func AssertClauseSerializeErr(t *testing.T, dialect jet.Dialect, clause jet.Serializer, errString string) { - out := jet.SqlBuilder{Dialect: dialect} - err := jet.Serialize(clause, jet.SelectStatementType, &out) + defer func() { + r := recover() + assert.Equal(t, r, errString) + }() - //fmt.Println(out.buff.String()) - assert.Assert(t, err != nil) - assert.Error(t, err, errString) + out := jet.SqlBuilder{Dialect: dialect} + jet.Serialize(clause, jet.SelectStatementType, &out) } func AssertProjectionSerialize(t *testing.T, dialect jet.Dialect, projection jet.Projection, query string, args ...interface{}) { out := jet.SqlBuilder{Dialect: dialect} - err := jet.SerializeForProjection(projection, jet.SelectStatementType, &out) - - assert.NilError(t, err) + jet.SerializeForProjection(projection, jet.SelectStatementType, &out) assert.DeepEqual(t, out.Buff.String(), query) assert.DeepEqual(t, out.Args, args) } + +func AssertQueryPanicErr(t *testing.T, stmt jet.Statement, db execution.DB, dest interface{}, errString string) { + defer func() { + r := recover() + assert.Equal(t, r, errString) + }() + + stmt.Query(db, dest) +} diff --git a/internal/utils/utils.go b/internal/utils/utils.go index 72f666a..c0d5e7a 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -144,3 +144,28 @@ func FormatTimestamp(t time.Time) []byte { func IsNil(v interface{}) bool { return v == nil || (reflect.ValueOf(v).Kind() == reflect.Ptr && reflect.ValueOf(v).IsNil()) } + +func MustBe(v interface{}, kind reflect.Kind, errorStr string) { + if reflect.TypeOf(v).Kind() != kind { + panic(errorStr) + } +} + +func ValueMustBe(v reflect.Value, kind reflect.Kind, errorStr string) { + if v.Kind() != kind { + panic(errorStr) + } +} + +func TypeMustBe(v reflect.Type, kind reflect.Kind, errorStr string) { + if v.Kind() != kind { + panic(errorStr) + } +} + +func MustBeInitializedPtr(val interface{}, errorStr string) { + if IsNil(val) { + panic(errorStr) + } + +} diff --git a/mysql/dialect.go b/mysql/dialect.go index bb61adf..1da350a 100644 --- a/mysql/dialect.go +++ b/mysql/dialect.go @@ -1,7 +1,6 @@ package mysql import ( - "errors" "github.com/go-jet/jet/internal/jet" ) @@ -31,63 +30,49 @@ func NewDialect() jet.Dialect { } func mysql_BIT_XOR(expressions ...jet.Expression) jet.SerializeFunc { - return func(statement jet.StatementType, out *jet.SqlBuilder, options ...jet.SerializeOption) error { + return func(statement jet.StatementType, out *jet.SqlBuilder, options ...jet.SerializeOption) { if len(expressions) != 2 { - return errors.New("jet: invalid number of expressions for operator") + panic("jet: invalid number of expressions for operator XOR") } lhs := expressions[0] rhs := expressions[1] - if err := jet.Serialize(lhs, statement, out, options...); err != nil { - return err - } + jet.Serialize(lhs, statement, out, options...) out.WriteString("^") - if err := jet.Serialize(rhs, statement, out, options...); err != nil { - return err - } - return nil + jet.Serialize(rhs, statement, out, options...) } } func mysql_CONCAT_operator(expressions ...jet.Expression) jet.SerializeFunc { - return func(statement jet.StatementType, out *jet.SqlBuilder, options ...jet.SerializeOption) error { + return func(statement jet.StatementType, out *jet.SqlBuilder, options ...jet.SerializeOption) { if len(expressions) != 2 { - return errors.New("jet: invalid number of expressions for operator") + panic("jet: invalid number of expressions for operator CONCAT") } - out.WriteString("CONCAT(") - if err := jet.Serialize(expressions[0], statement, out, options...); err != nil { - return err - } + jet.Serialize(expressions[0], statement, out, options...) out.WriteString(", ") - if err := jet.Serialize(expressions[1], statement, out, options...); err != nil { - return err - } + jet.Serialize(expressions[1], statement, out, options...) out.WriteString(")") - - return nil } } func mysql_DIVISION(expressions ...jet.Expression) jet.SerializeFunc { - return func(statement jet.StatementType, out *jet.SqlBuilder, options ...jet.SerializeOption) error { + return func(statement jet.StatementType, out *jet.SqlBuilder, options ...jet.SerializeOption) { if len(expressions) != 2 { - return errors.New("jet: invalid number of expressions for operator") + panic("jet: invalid number of expressions for operator DIV") } lhs := expressions[0] rhs := expressions[1] - if err := jet.Serialize(lhs, statement, out, options...); err != nil { - return err - } + jet.Serialize(lhs, statement, out, options...) _, isLhsInt := lhs.(IntegerExpression) _, isRhsInt := rhs.(IntegerExpression) @@ -98,44 +83,26 @@ func mysql_DIVISION(expressions ...jet.Expression) jet.SerializeFunc { out.WriteString("/") } - if err := jet.Serialize(rhs, statement, out, options...); err != nil { - return err - } - return nil + jet.Serialize(rhs, statement, out, options...) } } func mysql_IS_NOT_DISTINCT_FROM(expressions ...jet.Expression) jet.SerializeFunc { - return func(statement jet.StatementType, out *jet.SqlBuilder, options ...jet.SerializeOption) error { + return func(statement jet.StatementType, out *jet.SqlBuilder, options ...jet.SerializeOption) { if len(expressions) != 2 { - return errors.New("jet: invalid number of expressions for operator") - } - if err := jet.Serialize(expressions[0], statement, out); err != nil { - return err + panic("jet: invalid number of expressions for operator") } + jet.Serialize(expressions[0], statement, out) out.WriteString("<=>") - - if err := jet.Serialize(expressions[1], statement, out); err != nil { - return err - } - - return nil + jet.Serialize(expressions[1], statement, out) } } func mysql_IS_DISTINCT_FROM(expressions ...jet.Expression) jet.SerializeFunc { - return func(statement jet.StatementType, out *jet.SqlBuilder, options ...jet.SerializeOption) error { + return func(statement jet.StatementType, out *jet.SqlBuilder, options ...jet.SerializeOption) { out.WriteString("NOT(") - - err := mysql_IS_NOT_DISTINCT_FROM(expressions...)(statement, out, options...) - - if err != nil { - return err - } - + mysql_IS_NOT_DISTINCT_FROM(expressions...)(statement, out, options...) out.WriteString(")") - - return nil } } diff --git a/mysql/insert_statement_test.go b/mysql/insert_statement_test.go index 91fdb8e..917f746 100644 --- a/mysql/insert_statement_test.go +++ b/mysql/insert_statement_test.go @@ -7,7 +7,7 @@ import ( ) func TestInvalidInsert(t *testing.T) { - assertStatementSqlErr(t, table1.INSERT(table1Col1), "jet: no row values or query specified") + assertStatementSqlErr(t, table1.INSERT(table1Col1), "jet: VALUES or QUERY has to be specified for INSERT statement") assertStatementSqlErr(t, table1.INSERT(nil).VALUES(1), "jet: nil column in columns list") } @@ -114,7 +114,7 @@ func TestInsertFromNonStructModel(t *testing.T) { defer func() { r := recover() - assert.Equal(t, r, "argument mismatch: expected struct, got []int") + assert.Equal(t, r, "jet: data has to be a struct") }() table2.INSERT(table2ColInt).MODEL([]int{}) diff --git a/mysql/select_table.go b/mysql/select_table.go index 7a314d8..e7bb5cb 100644 --- a/mysql/select_table.go +++ b/mysql/select_table.go @@ -8,13 +8,13 @@ type SelectTable interface { } type selectTableImpl struct { - jet.SelectTableImpl2 + jet.SelectTableImpl readableTableInterfaceImpl } func newSelectTable(selectStmt jet.StatementWithProjections, alias string) SelectTable { subQuery := &selectTableImpl{ - SelectTableImpl2: jet.NewSelectTable(selectStmt, alias), + SelectTableImpl: jet.NewSelectTable(selectStmt, alias), } subQuery.readableTableInterfaceImpl.parent = subQuery diff --git a/mysql/update_statement_test.go b/mysql/update_statement_test.go index 55d88d3..fc933aa 100644 --- a/mysql/update_statement_test.go +++ b/mysql/update_statement_test.go @@ -59,5 +59,5 @@ WHERE table1.col1 = ?; func TestInvalidInputs(t *testing.T) { assertStatementSqlErr(t, table1.UPDATE(table1ColInt).SET(1), "jet: WHERE clause not set") - assertStatementSqlErr(t, table1.UPDATE(nil).SET(1), "jet: nil column in columns list") + assertStatementSqlErr(t, table1.UPDATE(nil).SET(1), "jet: nil column in columns list for SET clause") } diff --git a/postgres/clauses.go b/postgres/clauses.go index fa3d01b..b616ad5 100644 --- a/postgres/clauses.go +++ b/postgres/clauses.go @@ -8,14 +8,13 @@ type ClauseReturning struct { Projections []jet.Projection } -func (r *ClauseReturning) Serialize(statementType jet.StatementType, out *jet.SqlBuilder) error { +func (r *ClauseReturning) Serialize(statementType jet.StatementType, out *jet.SqlBuilder) { if len(r.Projections) == 0 { - return nil + return } out.NewLine() out.WriteString("RETURNING") out.IncreaseIdent() - - return out.WriteProjections(statementType, r.Projections) + out.WriteProjections(statementType, r.Projections) } diff --git a/postgres/dialect.go b/postgres/dialect.go index d392542..27656b4 100644 --- a/postgres/dialect.go +++ b/postgres/dialect.go @@ -1,7 +1,6 @@ package postgres import ( - "errors" "github.com/go-jet/jet/internal/jet" "strconv" "strings" @@ -30,9 +29,9 @@ func NewDialect() jet.Dialect { } func postgresCAST(expressions ...jet.Expression) jet.SerializeFunc { - return func(statement jet.StatementType, out *jet.SqlBuilder, options ...jet.SerializeOption) error { + return func(statement jet.StatementType, out *jet.SqlBuilder, options ...jet.SerializeOption) { if len(expressions) < 2 { - return errors.New("jet: invalid number of expressions for operator") + panic("jet: invalid number of expressions for operator") } expression := expressions[0] @@ -40,32 +39,27 @@ func postgresCAST(expressions ...jet.Expression) jet.SerializeFunc { litExpr, ok := expressions[1].(jet.LiteralExpression) if !ok { - return errors.New("jet: cast invalid cast type") + panic("jet: cast invalid cast type") } castType, ok := litExpr.Value().(string) if !ok { - return errors.New("jet: cast type is not string") + panic("jet: cast type is not string") } - if err := jet.Serialize(expression, statement, out, options...); err != nil { - return err - } + jet.Serialize(expression, statement, out, options...) out.WriteString("::" + castType) - return nil } } func postgres_REGEXP_LIKE_function(expressions ...jet.Expression) jet.SerializeFunc { - return func(statement jet.StatementType, out *jet.SqlBuilder, options ...jet.SerializeOption) error { + return func(statement jet.StatementType, out *jet.SqlBuilder, options ...jet.SerializeOption) { if len(expressions) < 2 { - return errors.New("jet: invalid number of expressions for operator") + panic("jet: invalid number of expressions for operator") } - if err := jet.Serialize(expressions[0], statement, out, options...); err != nil { - return err - } + jet.Serialize(expressions[0], statement, out, options...) caseSensitive := false @@ -83,10 +77,6 @@ func postgres_REGEXP_LIKE_function(expressions ...jet.Expression) jet.SerializeF out.WriteString("~*") } - if err := jet.Serialize(expressions[1], statement, out, options...); err != nil { - return err - } - - return nil + jet.Serialize(expressions[1], statement, out, options...) } } diff --git a/postgres/insert_statement_test.go b/postgres/insert_statement_test.go index c64d96c..2eab553 100644 --- a/postgres/insert_statement_test.go +++ b/postgres/insert_statement_test.go @@ -7,7 +7,7 @@ import ( ) func TestInvalidInsert(t *testing.T) { - assertStatementSqlErr(t, table1.INSERT(table1Col1), "jet: no row values or query specified") + assertStatementSqlErr(t, table1.INSERT(table1Col1), "jet: VALUES or QUERY has to be specified for INSERT statement") assertStatementSqlErr(t, table1.INSERT(nil).VALUES(1), "jet: nil column in columns list") } @@ -114,7 +114,7 @@ func TestInsertFromNonStructModel(t *testing.T) { defer func() { r := recover() - assert.Equal(t, r, "argument mismatch: expected struct, got []int") + assert.Equal(t, r, "jet: data has to be a struct") }() table2.INSERT(table2ColInt).MODEL([]int{}) diff --git a/postgres/select_table.go b/postgres/select_table.go index 7b94607..9fdb035 100644 --- a/postgres/select_table.go +++ b/postgres/select_table.go @@ -8,13 +8,13 @@ type SelectTable interface { } type selectTableImpl struct { - jet.SelectTableImpl2 + jet.SelectTableImpl readableTableInterfaceImpl } func newSelectTable(selectStmt jet.StatementWithProjections, alias string) SelectTable { subQuery := &selectTableImpl{ - SelectTableImpl2: jet.NewSelectTable(selectStmt, alias), + SelectTableImpl: jet.NewSelectTable(selectStmt, alias), } subQuery.readableTableInterfaceImpl.parent = subQuery diff --git a/postgres/update_statement.go b/postgres/update_statement.go index 6f20326..5839254 100644 --- a/postgres/update_statement.go +++ b/postgres/update_statement.go @@ -1,7 +1,6 @@ package postgres import ( - "errors" "github.com/go-jet/jet/internal/jet" ) @@ -62,23 +61,19 @@ type ClauseSet struct { Values []jet.Serializer } -func (s *ClauseSet) Serialize(statementType jet.StatementType, out *jet.SqlBuilder) error { +func (s *ClauseSet) Serialize(statementType jet.StatementType, out *jet.SqlBuilder) { out.NewLine() out.WriteString("SET") if len(s.Columns) == 0 { - return errors.New("jet: no columns selected") + panic("jet: no columns selected") } if len(s.Columns) > 1 { out.WriteString("(") } - err := jet.SerializeColumnNames(s.Columns, out) - - if err != nil { - return err - } + jet.SerializeColumnNames(s.Columns, out) if len(s.Columns) > 1 { out.WriteString(")") @@ -90,15 +85,9 @@ func (s *ClauseSet) Serialize(statementType jet.StatementType, out *jet.SqlBuild out.WriteString("(") } - err = jet.SerializeClauseList(statementType, s.Values, out) - - if err != nil { - return err - } + jet.SerializeClauseList(statementType, s.Values, out) if len(s.Values) > 1 { out.WriteString(")") } - - return nil } diff --git a/tests/mysql/alltypes_test.go b/tests/mysql/alltypes_test.go index 6e15330..5c16465 100644 --- a/tests/mysql/alltypes_test.go +++ b/tests/mysql/alltypes_test.go @@ -234,8 +234,8 @@ func TestFloatOperators(t *testing.T) { TRUNC(AllTypes.Decimal, Int(1)).AS("trunc"), ).LIMIT(2) - queryStr, _, err := query.Sql() - assert.NilError(t, err) + queryStr, _ := query.Sql() + assert.Equal(t, queryStr, ` SELECT (all_types.numeric = all_types.numeric) AS "eq1", (all_types.decimal = ?) AS "eq2", @@ -280,7 +280,7 @@ LIMIT ?; common.FloatExpressionTestResult `alias:"."` } - err = query.Query(db, &dest) + err := query.Query(db, &dest) assert.NilError(t, err) diff --git a/tests/postgres/alltypes_test.go b/tests/postgres/alltypes_test.go index e530c31..44b5b46 100644 --- a/tests/postgres/alltypes_test.go +++ b/tests/postgres/alltypes_test.go @@ -353,8 +353,8 @@ func TestFloatOperators(t *testing.T) { TRUNC(AllTypes.Decimal, Int(1)).AS("trunc"), ).LIMIT(2) - queryStr, _, err := query.Sql() - assert.NilError(t, err) + queryStr, _ := query.Sql() + assert.Equal(t, queryStr, ` SELECT (all_types.numeric = all_types.numeric) AS "eq1", (all_types.decimal = $1) AS "eq2", @@ -399,7 +399,7 @@ LIMIT $35; common.FloatExpressionTestResult `alias:"."` } - err = query.Query(db, &dest) + err := query.Query(db, &dest) assert.NilError(t, err) diff --git a/tests/postgres/scan_test.go b/tests/postgres/scan_test.go index 267d7be..4b6060d 100644 --- a/tests/postgres/scan_test.go +++ b/tests/postgres/scan_test.go @@ -18,33 +18,23 @@ var query = Inventory. func TestScanToInvalidDestination(t *testing.T) { t.Run("nil dest", func(t *testing.T) { - err := query.Query(db, nil) - - assert.Error(t, err, "jet: Destination is nil") + testutils.AssertQueryPanicErr(t, query, db, nil, "jet: destination is nil") }) t.Run("struct dest", func(t *testing.T) { - err := query.Query(db, struct{}{}) - - assert.Error(t, err, "jet: Destination has to be a pointer to slice or pointer to struct") + testutils.AssertQueryPanicErr(t, query, db, struct{}{}, "jet: destination has to be a pointer to slice or pointer to struct") }) t.Run("slice dest", func(t *testing.T) { - err := query.Query(db, []struct{}{}) - - assert.Error(t, err, "jet: Destination has to be a pointer to slice or pointer to struct") + testutils.AssertQueryPanicErr(t, query, db, []struct{}{}, "jet: destination has to be a pointer to slice or pointer to struct") }) t.Run("slice of pointers to pointer dest", func(t *testing.T) { - err := query.Query(db, []**struct{}{}) - - assert.Error(t, err, "jet: Destination has to be a pointer to slice or pointer to struct") + testutils.AssertQueryPanicErr(t, query, db, []**struct{}{}, "jet: destination has to be a pointer to slice or pointer to struct") }) t.Run("map dest", func(t *testing.T) { - err := query.Query(db, []map[string]string{}) - - assert.Error(t, err, "jet: Destination has to be a pointer to slice or pointer to struct") + testutils.AssertQueryPanicErr(t, query, db, []map[string]string{}, "jet: destination has to be a pointer to slice or pointer to struct") }) } @@ -126,9 +116,7 @@ func TestScanToStruct(t *testing.T) { Inventory **model.Inventory }{} - err := query.Query(db, &dest) - - assert.Error(t, err, "jet: Unsupported dest type: Inventory **model.Inventory") + testutils.AssertQueryPanicErr(t, query, db, &dest, "jet: unsupported dest type: Inventory **model.Inventory") }) t.Run("invalid dest 2", func(t *testing.T) { @@ -136,9 +124,7 @@ func TestScanToStruct(t *testing.T) { Inventory ***model.Inventory }{} - err := query.Query(db, &dest) - - assert.Error(t, err, "jet: Unsupported dest type: Inventory ***model.Inventory") + testutils.AssertQueryPanicErr(t, query, db, &dest, "jet: unsupported dest type: Inventory ***model.Inventory") }) t.Run("custom struct", func(t *testing.T) { @@ -669,9 +655,7 @@ func TestScanToSlice(t *testing.T) { } } - err := query.Query(db, &dest) - - assert.Error(t, err, "jet: Unsupported dest type: Cities []**struct { *model.City }") + testutils.AssertQueryPanicErr(t, query, db, &dest, "jet: unsupported slice element type at 'Cities []**struct { *model.City }'.") }) }