Merge pull request #120 from go-jet/develop

Release 2.7.1
This commit is contained in:
go-jet 2022-02-14 12:54:08 +01:00 committed by GitHub
commit c29f0afd2b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
28 changed files with 1035 additions and 456 deletions

View file

@ -42,7 +42,7 @@ https://medium.com/@go.jet/jet-5f3667efa0cc
* [WITH](https://github.com/go-jet/jet/wiki/WITH) * [WITH](https://github.com/go-jet/jet/wiki/WITH)
2) Auto-generated Data Model types - Go types mapped to database type (table, view or enum), used to store 2) Auto-generated Data Model types - Go types mapped to database type (table, view or enum), used to store
result of database queries. Can be combined to create desired query result destination. result of database queries. Can be combined to create complex query result destination.
3) Query execution with result mapping to arbitrary destination. 3) Query execution with result mapping to arbitrary destination.
## Getting Started ## Getting Started
@ -164,11 +164,11 @@ import (
``` ```
Let's say we want to retrieve the list of all _actors_ that acted in _films_ longer than 180 minutes, _film language_ is 'English' Let's say we want to retrieve the list of all _actors_ that acted in _films_ longer than 180 minutes, _film language_ is 'English'
and _film category_ is not 'Action'. and _film category_ is not 'Action'.
```java ```golang
stmt := SELECT( stmt := SELECT(
Actor.ActorID, Actor.FirstName, Actor.LastName, Actor.LastUpdate, // or just Actor.AllColumns Actor.ActorID, Actor.FirstName, Actor.LastName, Actor.LastUpdate, // or just Actor.AllColumns
Film.AllColumns, Film.AllColumns,
Language.AllColumns, Language.AllColumns.Except(Language.LastUpdate),
Category.AllColumns, Category.AllColumns,
).FROM( ).FROM(
Actor. Actor.
@ -358,7 +358,7 @@ fmt.Println(string(jsonText))
"Language": { "Language": {
"LanguageID": 1, "LanguageID": 1,
"Name": "English ", "Name": "English ",
"LastUpdate": "2006-02-15T10:02:19Z" "LastUpdate": "0001-01-01T00:00:00Z"
}, },
"Categories": [ "Categories": [
{ {
@ -393,7 +393,7 @@ fmt.Println(string(jsonText))
"Language": { "Language": {
"LanguageID": 1, "LanguageID": 1,
"Name": "English ", "Name": "English ",
"LastUpdate": "2006-02-15T10:02:19Z" "LastUpdate": "0001-01-01T00:00:00Z"
}, },
"Categories": [ "Categories": [
{ {
@ -580,5 +580,5 @@ To run the tests, additional dependencies are required:
## License ## License
Copyright 2019-2021 Goran Bjelanovic Copyright 2019-2022 Goran Bjelanovic
Licensed under the Apache License, Version 2.0. Licensed under the Apache License, Version 2.0.

File diff suppressed because it is too large Load diff

View file

@ -36,7 +36,7 @@ func main() {
stmt := SELECT( stmt := SELECT(
Actor.ActorID, Actor.FirstName, Actor.LastName, Actor.LastUpdate, Actor.ActorID, Actor.FirstName, Actor.LastName, Actor.LastUpdate,
Film.AllColumns, Film.AllColumns,
Language.AllColumns, Language.AllColumns.Except(Language.LastUpdate),
Category.AllColumns, Category.AllColumns,
).FROM( ).FROM(
Actor. Actor.

View file

@ -98,9 +98,9 @@ func (c *ClauseWhere) Serialize(statementType StatementType, out *SQLBuilder, op
} }
out.WriteString("WHERE") out.WriteString("WHERE")
out.IncreaseIdent() out.IncreaseIdent(6)
c.Condition.serialize(statementType, out, NoWrap.WithFallTrough(options)...) c.Condition.serialize(statementType, out, NoWrap.WithFallTrough(options)...)
out.DecreaseIdent() out.DecreaseIdent(6)
} }
// ClauseGroupBy struct // ClauseGroupBy struct

View file

@ -123,6 +123,65 @@ func (c *binaryOperatorExpression) serialize(statement StatementType, out *SQLBu
} }
} }
type expressionListOperator struct {
ExpressionInterfaceImpl
operator string
expressions []Expression
}
func newExpressionListOperator(operator string, expressions ...Expression) *expressionListOperator {
ret := &expressionListOperator{
operator: operator,
expressions: expressions,
}
ret.ExpressionInterfaceImpl.Parent = ret
return ret
}
func newBoolExpressionListOperator(operator string, expressions ...BoolExpression) BoolExpression {
return BoolExp(newExpressionListOperator(operator, BoolExpressionListToExpressionList(expressions)...))
}
func (elo *expressionListOperator) serialize(statement StatementType, out *SQLBuilder, options ...SerializeOption) {
if len(elo.expressions) == 0 {
panic("jet: syntax error, expression list empty")
}
shouldWrap := len(elo.expressions) > 1
if shouldWrap {
out.WriteByte('(')
out.IncreaseIdent(tabSize)
out.NewLine()
}
for i, expression := range elo.expressions {
if i == 1 {
out.IncreaseIdent(tabSize)
}
if i > 0 {
out.NewLine()
out.WriteString(elo.operator)
}
out.IncreaseIdent(len(elo.operator) + 1)
expression.serialize(statement, out, FallTrough(options)...)
out.DecreaseIdent(len(elo.operator) + 1)
}
if len(elo.expressions) > 1 {
out.DecreaseIdent(tabSize)
}
if shouldWrap {
out.DecreaseIdent(tabSize)
out.NewLine()
out.WriteByte(')')
}
}
// A prefix operator Expression // A prefix operator Expression
type prefixExpression struct { type prefixExpression struct {
ExpressionInterfaceImpl ExpressionInterfaceImpl
@ -209,8 +268,8 @@ type complexExpression struct {
expressions Expression expressions Expression
} }
func complexExpr(expressions Expression) Expression { func complexExpr(expression Expression) Expression {
complexExpression := &complexExpression{expressions: expressions} complexExpression := &complexExpression{expressions: expression}
complexExpression.ExpressionInterfaceImpl.Parent = complexExpression complexExpression.ExpressionInterfaceImpl.Parent = complexExpression
return complexExpression return complexExpression

View file

@ -1,5 +1,17 @@
package jet package jet
// AND function adds AND operator between expressions. This function can be used, instead of method AND,
// to have a better inlining of a complex condition in the Go code and in the generated SQL.
func AND(expressions ...BoolExpression) BoolExpression {
return newBoolExpressionListOperator("AND", expressions...)
}
// OR function adds OR operator between expressions. This function can be used, instead of method OR,
// to have a better inlining of a complex condition in the Go code and in the generated SQL.
func OR(expressions ...BoolExpression) BoolExpression {
return newBoolExpressionListOperator("OR", expressions...)
}
// ROW is construct one table row from list of expressions. // ROW is construct one table row from list of expressions.
func ROW(expressions ...Expression) Expression { func ROW(expressions ...Expression) Expression {
return NewFunc("ROW", expressions, nil) return NewFunc("ROW", expressions, nil)

View file

@ -4,6 +4,28 @@ import (
"testing" "testing"
) )
func TestAND(t *testing.T) {
assertClauseSerializeErr(t, AND(), "jet: syntax error, expression list empty")
assertClauseSerialize(t, AND(table1ColInt.IS_NULL()), `table1.col_int IS NULL`) // IS NULL doesn't add parenthesis
assertClauseSerialize(t, AND(table1ColInt.LT(Int(11))), `(table1.col_int < $1)`, int64(11))
assertClauseSerialize(t, AND(table1ColInt.GT(Int(11)), table1ColFloat.EQ(Float(0))),
`(
(table1.col_int > $1)
AND (table1.col_float = $2)
)`, int64(11), 0.0)
}
func TestOR(t *testing.T) {
assertClauseSerializeErr(t, OR(), "jet: syntax error, expression list empty")
assertClauseSerialize(t, OR(table1ColInt.IS_NULL()), `table1.col_int IS NULL`) // IS NULL doesn't add parenthesis
assertClauseSerialize(t, OR(table1ColInt.LT(Int(11))), `(table1.col_int < $1)`, int64(11))
assertClauseSerialize(t, OR(table1ColInt.GT(Int(11)), table1ColFloat.EQ(Float(0))),
`(
(table1.col_int > $1)
OR (table1.col_float = $2)
)`, int64(11), 0.0)
}
func TestFuncAVG(t *testing.T) { func TestFuncAVG(t *testing.T) {
assertClauseSerialize(t, AVG(table1ColFloat), "AVG(table1.col_float)") assertClauseSerialize(t, AVG(table1ColFloat), "AVG(table1.col_float)")
assertClauseSerialize(t, AVG(table1ColInt), "AVG(table1.col_int)") assertClauseSerialize(t, AVG(table1ColInt), "AVG(table1.col_int)")

View file

@ -26,6 +26,7 @@ type SQLBuilder struct {
Debug bool Debug bool
} }
const tabSize = 4
const defaultIdent = 5 const defaultIdent = 5
// IncreaseIdent adds ident or defaultIdent number of spaces to each new line // IncreaseIdent adds ident or defaultIdent number of spaces to each new line

View file

@ -33,11 +33,13 @@ type Statement interface {
// Rows wraps sql.Rows type to add query result mapping for Scan method // Rows wraps sql.Rows type to add query result mapping for Scan method
type Rows struct { type Rows struct {
*sql.Rows *sql.Rows
scanContext *qrm.ScanContext
} }
// Scan will map the Row values into struct destination // Scan will map the Row values into struct destination
func (r *Rows) Scan(destination interface{}) error { func (r *Rows) Scan(destination interface{}) error {
return qrm.ScanOneRowToDest(r.Rows, destination) return qrm.ScanOneRowToDest(r.scanContext, r.Rows, destination)
} }
// SerializerStatement interface // SerializerStatement interface
@ -161,7 +163,16 @@ func (s *serializerStatementInterfaceImpl) Rows(ctx context.Context, db qrm.DB)
return nil, err return nil, err
} }
return &Rows{rows}, nil scanContext, err := qrm.NewScanContext(rows)
if err != nil {
return nil, err
}
return &Rows{
Rows: rows,
scanContext: scanContext,
}, nil
} }
func duration(f func()) time.Duration { func duration(f func()) time.Duration {

View file

@ -113,6 +113,17 @@ func ExpressionListToSerializerList(expressions []Expression) []Serializer {
return ret return ret
} }
// BoolExpressionListToExpressionList converts list of bool expressions to list of expressions
func BoolExpressionListToExpressionList(expressions []BoolExpression) []Expression {
var ret []Expression
for _, expression := range expressions {
ret = append(ret, expression)
}
return ret
}
// ColumnListToProjectionList func // ColumnListToProjectionList func
func ColumnListToProjectionList(columns []ColumnExpression) []Projection { func ColumnListToProjectionList(columns []ColumnExpression) []Projection {
var ret []Projection var ret []Projection

View file

@ -67,7 +67,8 @@ func AssertJSON(t *testing.T, data interface{}, expectedJSON string) {
jsonData, err := json.MarshalIndent(data, "", "\t") jsonData, err := json.MarshalIndent(data, "", "\t")
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, "\n"+string(jsonData)+"\n", expectedJSON) dataJson := "\n" + string(jsonData) + "\n"
require.Equal(t, dataJson, expectedJSON)
} }
// SaveJSONFile saves v as json at testRelativePath // SaveJSONFile saves v as json at testRelativePath

View file

@ -2,6 +2,15 @@ package mysql
import "github.com/go-jet/jet/v2/internal/jet" import "github.com/go-jet/jet/v2/internal/jet"
// This functions can be used, instead of its method counterparts, to have a better indentation of a complex condition
// in the Go code and in the generated SQL.
var (
// AND function adds AND operator between expressions.
AND = jet.AND
// OR function adds OR operator between expressions.
OR = jet.OR
)
// ROW is construct one table row from list of expressions. // ROW is construct one table row from list of expressions.
var ROW = jet.ROW var ROW = jet.ROW

View file

@ -148,9 +148,9 @@ func TestSelect_NOT_EXISTS(t *testing.T) {
SELECT table1.col_int AS "table1.col_int" SELECT table1.col_int AS "table1.col_int"
FROM db.table1 FROM db.table1
WHERE NOT (EXISTS ( WHERE NOT (EXISTS (
SELECT table2.col_int AS "table2.col_int" SELECT table2.col_int AS "table2.col_int"
FROM db.table2 FROM db.table2
WHERE table1.col_int = table2.col_int WHERE table1.col_int = table2.col_int
)); ));
`) `)
} }

View file

@ -2,6 +2,15 @@ package postgres
import "github.com/go-jet/jet/v2/internal/jet" import "github.com/go-jet/jet/v2/internal/jet"
// This functions can be used, instead of its method counterparts, to have a better indentation of a complex condition
// in the Go code and in the generated SQL.
var (
// AND function adds AND operator between expressions.
AND = jet.AND
// OR function adds OR operator between expressions.
OR = jet.OR
)
// ROW is construct one table row from list of expressions. // ROW is construct one table row from list of expressions.
var ROW = jet.ROW var ROW = jet.ROW

View file

@ -63,46 +63,26 @@ func Query(ctx context.Context, db DB, query string, args []interface{}, destPtr
} }
// ScanOneRowToDest will scan one row into struct destination // ScanOneRowToDest will scan one row into struct destination
func ScanOneRowToDest(rows *sql.Rows, destPtr interface{}) error { func ScanOneRowToDest(scanContext *ScanContext, rows *sql.Rows, destPtr interface{}) error {
utils.MustBeInitializedPtr(destPtr, "jet: destination is nil") utils.MustBeInitializedPtr(destPtr, "jet: destination is nil")
utils.MustBe(destPtr, reflect.Ptr, "jet: destination has to be a pointer to slice or pointer to struct") utils.MustBe(destPtr, reflect.Ptr, "jet: destination has to be a pointer to slice or pointer to struct")
scanContext, err := newScanContext(rows)
if err != nil {
return fmt.Errorf("failed to create scan context, %w", err)
}
if len(scanContext.row) == 0 { if len(scanContext.row) == 0 {
return errors.New("empty row slice") return errors.New("empty row slice")
} }
err = rows.Scan(scanContext.row...) err := rows.Scan(scanContext.row...)
if err != nil { if err != nil {
return fmt.Errorf("rows scan error, %w", err) return fmt.Errorf("jet: rows scan error, %w", err)
} }
destinationPtrType := reflect.TypeOf(destPtr) destValuePtr := reflect.ValueOf(destPtr)
tempSlicePtrValue := reflect.New(reflect.SliceOf(destinationPtrType))
tempSliceValue := tempSlicePtrValue.Elem()
_, err = mapRowToSlice(scanContext, "", newTypeStack(), tempSlicePtrValue, nil) _, err = mapRowToStruct(scanContext, "", destValuePtr, nil)
if err != nil { if err != nil {
return fmt.Errorf("failed to map a row, %w", err) return fmt.Errorf("jet: failed to scan a row into destination, %w", err)
}
// edge case when row result set contains only NULLs.
if tempSliceValue.Len() == 0 {
return nil
}
destValue := reflect.ValueOf(destPtr).Elem()
firstTempSliceValue := tempSliceValue.Index(0).Elem()
if destValue.Type().AssignableTo(firstTempSliceValue.Type()) {
destValue.Set(tempSliceValue.Index(0).Elem())
} }
return nil return nil
@ -120,7 +100,7 @@ func queryToSlice(ctx context.Context, db DB, query string, args []interface{},
} }
defer rows.Close() defer rows.Close()
scanContext, err := newScanContext(rows) scanContext, err := NewScanContext(rows)
if err != nil { if err != nil {
return return
@ -141,7 +121,7 @@ func queryToSlice(ctx context.Context, db DB, query string, args []interface{},
scanContext.rowNum++ scanContext.rowNum++
_, err = mapRowToSlice(scanContext, "", newTypeStack(), slicePtrValue, nil) _, err = mapRowToSlice(scanContext, "", slicePtrValue, nil)
if err != nil { if err != nil {
return scanContext.rowNum, err return scanContext.rowNum, err
@ -157,9 +137,8 @@ func queryToSlice(ctx context.Context, db DB, query string, args []interface{},
} }
func mapRowToSlice( func mapRowToSlice(
scanContext *scanContext, scanContext *ScanContext,
groupKey string, groupKey string,
typesVisited *typeStack,
slicePtrValue reflect.Value, slicePtrValue reflect.Value,
field *reflect.StructField) (updated bool, err error) { field *reflect.StructField) (updated bool, err error) {
@ -174,19 +153,19 @@ func mapRowToSlice(
structGroupKey := scanContext.getGroupKey(sliceElemType, field) structGroupKey := scanContext.getGroupKey(sliceElemType, field)
groupKey = groupKey + "," + structGroupKey groupKey = concat(groupKey, ",", structGroupKey)
index, ok := scanContext.uniqueDestObjectsMap[groupKey] index, ok := scanContext.uniqueDestObjectsMap[groupKey]
if ok { if ok {
structPtrValue := getSliceElemPtrAt(slicePtrValue, index) structPtrValue := getSliceElemPtrAt(slicePtrValue, index)
return mapRowToStruct(scanContext, groupKey, typesVisited, structPtrValue, field, true) return mapRowToStruct(scanContext, groupKey, structPtrValue, field, true)
} }
destinationStructPtr := newElemPtrValueForSlice(slicePtrValue) destinationStructPtr := newElemPtrValueForSlice(slicePtrValue)
updated, err = mapRowToStruct(scanContext, groupKey, typesVisited, destinationStructPtr, field) updated, err = mapRowToStruct(scanContext, groupKey, destinationStructPtr, field)
if err != nil { if err != nil {
return return
@ -204,7 +183,7 @@ func mapRowToSlice(
return return
} }
func mapRowToBaseTypeSlice(scanContext *scanContext, slicePtrValue reflect.Value, field *reflect.StructField) (updated bool, err error) { func mapRowToBaseTypeSlice(scanContext *ScanContext, slicePtrValue reflect.Value, field *reflect.StructField) (updated bool, err error) {
index := 0 index := 0
if field != nil { if field != nil {
typeName, columnName := getTypeAndFieldName("", *field) typeName, columnName := getTypeAndFieldName("", *field)
@ -212,7 +191,7 @@ func mapRowToBaseTypeSlice(scanContext *scanContext, slicePtrValue reflect.Value
return return
} }
} }
rowElemPtr := scanContext.rowElemValuePtr(index) rowElemPtr := scanContext.rowElemValueClonePtr(index)
if rowElemPtr.IsValid() && !rowElemPtr.IsNil() { if rowElemPtr.IsValid() && !rowElemPtr.IsNil() {
updated = true updated = true
@ -226,9 +205,8 @@ func mapRowToBaseTypeSlice(scanContext *scanContext, slicePtrValue reflect.Value
} }
func mapRowToStruct( func mapRowToStruct(
scanContext *scanContext, scanContext *ScanContext,
groupKey string, groupKey string,
typesVisited *typeStack, // to prevent circular dependency scan
structPtrValue reflect.Value, structPtrValue reflect.Value,
parentField *reflect.StructField, parentField *reflect.StructField,
onlySlices ...bool, // small optimization, not to assign to already assigned struct fields onlySlices ...bool, // small optimization, not to assign to already assigned struct fields
@ -237,12 +215,12 @@ func mapRowToStruct(
mapOnlySlices := len(onlySlices) > 0 mapOnlySlices := len(onlySlices) > 0
structType := structPtrValue.Type().Elem() structType := structPtrValue.Type().Elem()
if typesVisited.contains(&structType) { if scanContext.typesVisited.contains(&structType) {
return false, nil return false, nil
} }
typesVisited.push(&structType) scanContext.typesVisited.push(&structType)
defer typesVisited.pop() defer scanContext.typesVisited.pop()
typeInf := scanContext.getTypeInfo(structType, parentField) typeInf := scanContext.getTypeInfo(structType, parentField)
@ -260,7 +238,7 @@ func mapRowToStruct(
if fieldMap.complexType { if fieldMap.complexType {
var changed bool var changed bool
changed, err = mapRowToDestinationValue(scanContext, groupKey, typesVisited, fieldValue, &field) changed, err = mapRowToDestinationValue(scanContext, groupKey, fieldValue, &field)
if err != nil { if err != nil {
return return
@ -271,34 +249,36 @@ func mapRowToStruct(
} }
} else { } else {
if mapOnlySlices || fieldMap.columnIndex == -1 { if mapOnlySlices || fieldMap.rowIndex == -1 {
continue continue
} }
cellValue := scanContext.rowElem(fieldMap.columnIndex) scannedValue := scanContext.rowElemValue(fieldMap.rowIndex)
if cellValue == nil { if !scannedValue.IsValid() {
setZeroValue(fieldValue) // scannedValue is nil, destination should be set to zero value
continue continue
} }
initializeValueIfNilPtr(fieldValue)
updated = true updated = true
if fieldMap.implementsScanner { if fieldMap.implementsScanner {
scanner := getScanner(fieldValue) initializeValueIfNilPtr(fieldValue)
fieldScanner := getScanner(fieldValue)
err = scanner.Scan(cellValue) value := scannedValue.Interface()
err := fieldScanner.Scan(value)
if err != nil { if err != nil {
err = fmt.Errorf(`can't scan %T(%q) to '%s %s': %w`, cellValue, cellValue, field.Name, field.Type.String(), err) return updated, fmt.Errorf(`can't scan %T(%q) to '%s %s': %w`, value, value, field.Name, field.Type.String(), err)
return
} }
} else { } else {
err = setReflectValue(reflect.ValueOf(cellValue), fieldValue) err := assign(scannedValue, fieldValue)
if err != nil { if err != nil {
err = fmt.Errorf(`can't assign %T(%q) to '%s %s': %w`, cellValue, cellValue, field.Name, field.Type.String(), err) return updated, fmt.Errorf(`can't assign %T(%q) to '%s %s': %w`, scannedValue.Interface(), scannedValue.Interface(),
return field.Name, field.Type.String(), err)
} }
} }
} }
@ -308,9 +288,8 @@ func mapRowToStruct(
} }
func mapRowToDestinationValue( func mapRowToDestinationValue(
scanContext *scanContext, scanContext *ScanContext,
groupKey string, groupKey string,
typesVisited *typeStack,
dest reflect.Value, dest reflect.Value,
structField *reflect.StructField) (updated bool, err error) { structField *reflect.StructField) (updated bool, err error) {
@ -326,7 +305,7 @@ func mapRowToDestinationValue(
} }
} }
updated, err = mapRowToDestinationPtr(scanContext, groupKey, typesVisited, destPtrValue, structField) updated, err = mapRowToDestinationPtr(scanContext, groupKey, destPtrValue, structField)
if err != nil { if err != nil {
return return
@ -340,9 +319,8 @@ func mapRowToDestinationValue(
} }
func mapRowToDestinationPtr( func mapRowToDestinationPtr(
scanContext *scanContext, scanContext *ScanContext,
groupKey string, groupKey string,
typesVisited *typeStack,
destPtrValue reflect.Value, destPtrValue reflect.Value,
structField *reflect.StructField) (updated bool, err error) { structField *reflect.StructField) (updated bool, err error) {
@ -351,9 +329,9 @@ func mapRowToDestinationPtr(
destValueKind := destPtrValue.Elem().Kind() destValueKind := destPtrValue.Elem().Kind()
if destValueKind == reflect.Struct { if destValueKind == reflect.Struct {
return mapRowToStruct(scanContext, groupKey, typesVisited, destPtrValue, structField) return mapRowToStruct(scanContext, groupKey, destPtrValue, structField)
} else if destValueKind == reflect.Slice { } else if destValueKind == reflect.Slice {
return mapRowToSlice(scanContext, groupKey, typesVisited, destPtrValue, structField) return mapRowToSlice(scanContext, groupKey, destPtrValue, structField)
} else { } else {
panic("jet: unsupported dest type: " + structField.Name + " " + structField.Type.String()) panic("jet: unsupported dest type: " + structField.Name + " " + structField.Type.String())
} }

View file

@ -7,16 +7,21 @@ import (
"strings" "strings"
) )
type scanContext struct { // ScanContext contains information about current row processed, mapping from the row to the
// destination types and type grouping information.
type ScanContext struct {
rowNum int64 rowNum int64
row []interface{} row []interface{}
uniqueDestObjectsMap map[string]int uniqueDestObjectsMap map[string]int
commonIdentToColumnIndex map[string]int commonIdentToColumnIndex map[string]int
groupKeyInfoCache map[string]groupKeyInfo groupKeyInfoCache map[string]groupKeyInfo
typeInfoMap map[string]typeInfo typeInfoMap map[string]typeInfo
typesVisited typeStack // to prevent circular dependency scan
} }
func newScanContext(rows *sql.Rows) (*scanContext, error) { // NewScanContext creates new ScanContext from rows
func NewScanContext(rows *sql.Rows) (*ScanContext, error) {
aliases, err := rows.Columns() aliases, err := rows.Columns()
if err != nil { if err != nil {
@ -36,13 +41,13 @@ func newScanContext(rows *sql.Rows) (*scanContext, error) {
commonIdentifier := toCommonIdentifier(names[0]) commonIdentifier := toCommonIdentifier(names[0])
if len(names) > 1 { if len(names) > 1 {
commonIdentifier += "." + toCommonIdentifier(names[1]) commonIdentifier = concat(commonIdentifier, ".", toCommonIdentifier(names[1]))
} }
commonIdentToColumnIndex[commonIdentifier] = i commonIdentToColumnIndex[commonIdentifier] = i
} }
return &scanContext{ return &ScanContext{
row: createScanSlice(len(columnTypes)), row: createScanSlice(len(columnTypes)),
uniqueDestObjectsMap: make(map[string]int), uniqueDestObjectsMap: make(map[string]int),
@ -50,15 +55,17 @@ func newScanContext(rows *sql.Rows) (*scanContext, error) {
commonIdentToColumnIndex: commonIdentToColumnIndex, commonIdentToColumnIndex: commonIdentToColumnIndex,
typeInfoMap: make(map[string]typeInfo), typeInfoMap: make(map[string]typeInfo),
typesVisited: newTypeStack(),
}, nil }, nil
} }
func createScanSlice(columnCount int) []interface{} { func createScanSlice(columnCount int) []interface{} {
scanSlice := make([]interface{}, columnCount)
scanPtrSlice := make([]interface{}, columnCount) scanPtrSlice := make([]interface{}, columnCount)
for i := range scanPtrSlice { for i := range scanPtrSlice {
scanPtrSlice[i] = &scanSlice[i] // if destination is pointer to interface sql.Scan will just forward driver value var a interface{}
scanPtrSlice[i] = &a // if destination is pointer to interface sql.Scan will just forward driver value
} }
return scanPtrSlice return scanPtrSlice
@ -69,17 +76,17 @@ type typeInfo struct {
} }
type fieldMapping struct { type fieldMapping struct {
complexType bool // slice or struct complexType bool // slice and struct are complex types
columnIndex int rowIndex int // index in ScanContext.row
implementsScanner bool implementsScanner bool
} }
func (s *scanContext) getTypeInfo(structType reflect.Type, parentField *reflect.StructField) typeInfo { func (s *ScanContext) getTypeInfo(structType reflect.Type, parentField *reflect.StructField) typeInfo {
typeMapKey := structType.String() typeMapKey := structType.String()
if parentField != nil { if parentField != nil {
typeMapKey += string(parentField.Tag) typeMapKey = concat(typeMapKey, string(parentField.Tag))
} }
if typeInfo, ok := s.typeInfoMap[typeMapKey]; ok { if typeInfo, ok := s.typeInfoMap[typeMapKey]; ok {
@ -97,7 +104,7 @@ func (s *scanContext) getTypeInfo(structType reflect.Type, parentField *reflect.
columnIndex := s.typeToColumnIndex(newTypeName, fieldName) columnIndex := s.typeToColumnIndex(newTypeName, fieldName)
fieldMap := fieldMapping{ fieldMap := fieldMapping{
columnIndex: columnIndex, rowIndex: columnIndex,
} }
if implementsScannerType(field.Type) { if implementsScannerType(field.Type) {
@ -120,26 +127,27 @@ type groupKeyInfo struct {
subTypes []groupKeyInfo subTypes []groupKeyInfo
} }
func (s *scanContext) getGroupKey(structType reflect.Type, structField *reflect.StructField) string { func (s *ScanContext) getGroupKey(structType reflect.Type, structField *reflect.StructField) string {
mapKey := structType.Name() mapKey := structType.Name()
if structField != nil { if structField != nil {
mapKey += structField.Type.String() mapKey = concat(mapKey, structField.Type.String())
} }
if groupKeyInfo, ok := s.groupKeyInfoCache[mapKey]; ok { if groupKeyInfo, ok := s.groupKeyInfoCache[mapKey]; ok {
return s.constructGroupKey(groupKeyInfo) return s.constructGroupKey(groupKeyInfo)
} }
groupKeyInfo := s.getGroupKeyInfo(structType, structField, newTypeStack()) tempTypeStack := newTypeStack()
groupKeyInfo := s.getGroupKeyInfo(structType, structField, &tempTypeStack)
s.groupKeyInfoCache[mapKey] = groupKeyInfo s.groupKeyInfoCache[mapKey] = groupKeyInfo
return s.constructGroupKey(groupKeyInfo) return s.constructGroupKey(groupKeyInfo)
} }
func (s *scanContext) constructGroupKey(groupKeyInfo groupKeyInfo) string { func (s *ScanContext) constructGroupKey(groupKeyInfo groupKeyInfo) string {
if len(groupKeyInfo.indexes) == 0 && len(groupKeyInfo.subTypes) == 0 { if len(groupKeyInfo.indexes) == 0 && len(groupKeyInfo.subTypes) == 0 {
return fmt.Sprintf("|ROW:%d|", s.rowNum) return fmt.Sprintf("|ROW:%d|", s.rowNum)
} }
@ -147,10 +155,7 @@ func (s *scanContext) constructGroupKey(groupKeyInfo groupKeyInfo) string {
var groupKeys []string var groupKeys []string
for _, index := range groupKeyInfo.indexes { for _, index := range groupKeyInfo.indexes {
cellValue := s.rowElem(index) groupKeys = append(groupKeys, s.rowElemToString(index))
subKey := valueToString(reflect.ValueOf(cellValue))
groupKeys = append(groupKeys, subKey)
} }
var subTypesGroupKeys []string var subTypesGroupKeys []string
@ -158,10 +163,10 @@ func (s *scanContext) constructGroupKey(groupKeyInfo groupKeyInfo) string {
subTypesGroupKeys = append(subTypesGroupKeys, s.constructGroupKey(subType)) subTypesGroupKeys = append(subTypesGroupKeys, s.constructGroupKey(subType))
} }
return groupKeyInfo.typeName + "(" + strings.Join(groupKeys, ",") + strings.Join(subTypesGroupKeys, ",") + ")" return concat(groupKeyInfo.typeName, "(", strings.Join(groupKeys, ","), strings.Join(subTypesGroupKeys, ","), ")")
} }
func (s *scanContext) getGroupKeyInfo( func (s *ScanContext) getGroupKeyInfo(
structType reflect.Type, structType reflect.Type,
parentField *reflect.StructField, parentField *reflect.StructField,
typeVisited *typeStack) groupKeyInfo { typeVisited *typeStack) groupKeyInfo {
@ -210,7 +215,7 @@ func (s *scanContext) getGroupKeyInfo(
return ret return ret
} }
func (s *scanContext) typeToColumnIndex(typeName, fieldName string) int { func (s *ScanContext) typeToColumnIndex(typeName, fieldName string) int {
var key string var key string
if typeName != "" { if typeName != "" {
@ -228,32 +233,36 @@ func (s *scanContext) typeToColumnIndex(typeName, fieldName string) int {
return index return index
} }
func (s *scanContext) rowElem(index int) interface{} { // rowElemValue always returns non-ptr value,
cellValue := reflect.ValueOf(s.row[index]) // invalid value is nil
func (s *ScanContext) rowElemValue(index int) reflect.Value {
if cellValue.IsValid() && !cellValue.IsNil() { scannedValue := reflect.ValueOf(s.row[index])
return cellValue.Elem().Interface() return scannedValue.Elem().Elem() // no need to check validity of Elem, because s.row[index] always contains interface in interface
}
return nil
} }
func (s *scanContext) rowElemValuePtr(index int) reflect.Value { func (s *ScanContext) rowElemToString(index int) string {
rowElem := s.rowElem(index) value := s.rowElemValue(index)
rowElemValue := reflect.ValueOf(rowElem)
if !value.IsValid() {
return "nil"
}
valueInterface := value.Interface()
if t, ok := valueInterface.(fmt.Stringer); ok {
return t.String()
}
return fmt.Sprintf("%#v", valueInterface)
}
func (s *ScanContext) rowElemValueClonePtr(index int) reflect.Value {
rowElemValue := s.rowElemValue(index)
if !rowElemValue.IsValid() { if !rowElemValue.IsValid() {
return reflect.Value{} return reflect.Value{}
} }
if rowElemValue.Kind() == reflect.Ptr {
return rowElemValue
}
if rowElemValue.CanAddr() {
return rowElemValue.Addr()
}
newElem := reflect.New(rowElemValue.Type()) newElem := reflect.New(rowElemValue.Type())
newElem.Elem().Set(rowElemValue) newElem.Elem().Set(rowElemValue)
return newElem return newElem

View file

@ -4,9 +4,9 @@ import "reflect"
type typeStack []*reflect.Type type typeStack []*reflect.Type
func newTypeStack() *typeStack { func newTypeStack() typeStack {
stack := make(typeStack, 0, 20) stack := make(typeStack, 0, 20)
return &stack return stack
} }
func (s *typeStack) isEmpty() bool { func (s *typeStack) isEmpty() bool {

View file

@ -18,9 +18,9 @@ func implementsScannerType(fieldType reflect.Type) bool {
return true return true
} }
typePtr := reflect.New(fieldType).Type() fieldTypePtr := reflect.New(fieldType).Type()
return typePtr.Implements(scannerInterfaceType) return fieldTypePtr.Implements(scannerInterfaceType)
} }
func getScanner(value reflect.Value) sql.Scanner { func getScanner(value reflect.Value) sql.Scanner {
@ -68,9 +68,9 @@ func appendElemToSlice(slicePtrValue reflect.Value, objPtrValue reflect.Value) e
if newSliceElemValue.Kind() == reflect.Ptr { if newSliceElemValue.Kind() == reflect.Ptr {
newSliceElemValue.Set(reflect.New(newSliceElemValue.Type().Elem())) newSliceElemValue.Set(reflect.New(newSliceElemValue.Type().Elem()))
err = tryAssign(objPtrValue.Elem(), newSliceElemValue.Elem()) err = assign(objPtrValue.Elem(), newSliceElemValue.Elem())
} else { } else {
err = tryAssign(objPtrValue.Elem(), newSliceElemValue) err = assign(objPtrValue.Elem(), newSliceElemValue)
} }
if err != nil { if err != nil {
@ -138,29 +138,6 @@ func initializeValueIfNilPtr(value reflect.Value) {
} }
} }
func valueToString(value reflect.Value) string {
if !value.IsValid() {
return "nil"
}
var valueInterface interface{}
if value.Kind() == reflect.Ptr {
if value.IsNil() {
return "nil"
}
valueInterface = value.Elem().Interface()
} else {
valueInterface = value.Interface()
}
if t, ok := valueInterface.(fmt.Stringer); ok {
return t.String()
}
return fmt.Sprintf("%#v", valueInterface)
}
var timeType = reflect.TypeOf(time.Now()) var timeType = reflect.TypeOf(time.Now())
var uuidType = reflect.TypeOf(uuid.New()) var uuidType = reflect.TypeOf(uuid.New())
var byteArrayType = reflect.TypeOf([]byte("")) var byteArrayType = reflect.TypeOf([]byte(""))
@ -180,51 +157,57 @@ func isSimpleModelType(objType reflect.Type) bool {
return objType == timeType || objType == uuidType || objType == byteArrayType return objType == timeType || objType == uuidType || objType == byteArrayType
} }
func isIntegerType(objType reflect.Type) bool { // source can't be pointer
objType = indirectType(objType) // destination can be pointer
func assign(source, destination reflect.Value) error {
if destination.Kind() == reflect.Ptr {
if destination.IsNil() {
initializeValueIfNilPtr(destination)
}
switch objType.Kind() { destination = destination.Elem()
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return true
} }
return false err := tryAssign(source, destination)
if err != nil {
// needs for the type conversions are rare, so we leave conversion as a last assign step if everything else fails
if tryConvert(source, destination) {
return nil
}
return err
}
return nil
} }
func isFloatType(value reflect.Type) bool { func assignIfAssignable(source, destination reflect.Value) bool {
switch value.Kind() { sourceType := source.Type()
case reflect.Float32, reflect.Float64: if sourceType.AssignableTo(destination.Type()) {
return true switch sourceType {
} case byteArrayType:
destination.SetBytes(cloneBytes(source.Interface().([]byte)))
return false
}
func tryAssign(source, destination reflect.Value) error {
if source.Type() != destination.Type() &&
!isFloatType(destination.Type()) && // to preserve precision during conversion
!(isIntegerType(source.Type()) && destination.Kind() == reflect.String) && // default conversion will convert int to 1 rune string
source.Type().ConvertibleTo(destination.Type()) {
source = source.Convert(destination.Type())
}
if source.Type().AssignableTo(destination.Type()) {
switch b := source.Interface().(type) {
case []byte:
destination.SetBytes(cloneBytes(b))
default: default:
destination.Set(source) destination.Set(source)
} }
return true
}
return false
}
// source and destination are non-ptr values
func tryAssign(source, destination reflect.Value) error {
if assignIfAssignable(source, destination) {
return nil return nil
} }
sourceInterface := source.Interface() sourceInterface := source.Interface()
switch destination.Interface().(type) { switch destination.Type().Kind() {
case bool: case reflect.Bool:
var nullBool internal.NullBool var nullBool internal.NullBool
err := nullBool.Scan(sourceInterface) err := nullBool.Scan(sourceInterface)
@ -235,7 +218,7 @@ func tryAssign(source, destination reflect.Value) error {
destination.SetBool(nullBool.Bool) destination.SetBool(nullBool.Bool)
case float32, float64: case reflect.Float32, reflect.Float64:
var nullFloat sql.NullFloat64 var nullFloat sql.NullFloat64
err := nullFloat.Scan(sourceInterface) err := nullFloat.Scan(sourceInterface)
@ -246,7 +229,7 @@ func tryAssign(source, destination reflect.Value) error {
if nullFloat.Valid { if nullFloat.Valid {
destination.SetFloat(nullFloat.Float64) destination.SetFloat(nullFloat.Float64)
} }
case int, int8, int16, int32, int64: case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
var integer sql.NullInt64 var integer sql.NullInt64
err := integer.Scan(sourceInterface) err := integer.Scan(sourceInterface)
@ -258,7 +241,7 @@ func tryAssign(source, destination reflect.Value) error {
destination.SetInt(integer.Int64) destination.SetInt(integer.Int64)
} }
case uint, uint8, uint16, uint32, uint64: case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
var uInt internal.NullUInt64 var uInt internal.NullUInt64
err := uInt.Scan(sourceInterface) err := uInt.Scan(sourceInterface)
@ -271,7 +254,7 @@ func tryAssign(source, destination reflect.Value) error {
destination.SetUint(uInt.UInt64) destination.SetUint(uInt.UInt64)
} }
case string: case reflect.String:
var str sql.NullString var str sql.NullString
err := str.Scan(sourceInterface) err := str.Scan(sourceInterface)
@ -283,57 +266,42 @@ func tryAssign(source, destination reflect.Value) error {
destination.SetString(str.String) destination.SetString(str.String)
} }
case time.Time:
var nullTime internal.NullTime
err := nullTime.Scan(sourceInterface)
if err != nil {
return err
}
if nullTime.Valid {
destination.Set(reflect.ValueOf(nullTime.Time))
}
default: default:
return fmt.Errorf("can't assign %T to %T", sourceInterface, destination.Interface()) switch destination.Interface().(type) {
case time.Time:
var nullTime internal.NullTime
err := nullTime.Scan(sourceInterface)
if err != nil {
return err
}
if nullTime.Valid {
destination.Set(reflect.ValueOf(nullTime.Time))
}
default:
return fmt.Errorf("can't assign %T to %T", sourceInterface, destination.Interface())
}
} }
return nil return nil
} }
func setReflectValue(source, destination reflect.Value) error { func tryConvert(source, destination reflect.Value) bool {
destinationType := destination.Type()
if destination.Kind() == reflect.Ptr { if source.Type().ConvertibleTo(destinationType) {
if destination.IsNil() { source = source.Convert(destinationType)
initializeValueIfNilPtr(destination) return assignIfAssignable(source, destination)
}
if source.Kind() == reflect.Ptr {
if source.IsNil() {
return nil // source is nil, destination should keep its zero value
}
source = source.Elem()
}
if err := tryAssign(source, destination.Elem()); err != nil {
return err
}
} else {
if source.Kind() == reflect.Ptr {
if source.IsNil() {
return nil // source is nil, destination should keep its zero value
}
source = source.Elem()
}
if err := tryAssign(source, destination); err != nil {
return err
}
} }
return nil return false
}
func setZeroValue(value reflect.Value) {
if !value.IsZero() {
value.Set(reflect.Zero(value.Type()))
}
} }
func isPrimaryKey(field reflect.StructField, primaryKeyOverwrites []string) bool { func isPrimaryKey(field reflect.StructField, primaryKeyOverwrites []string) bool {
@ -389,3 +357,11 @@ func cloneBytes(b []byte) []byte {
copy(c, b) copy(c, b)
return c return c
} }
func concat(stringList ...string) string {
var b strings.Builder
for _, str := range stringList {
b.WriteString(str)
}
return b.String()
}

View file

@ -6,6 +6,15 @@ import (
"time" "time"
) )
// This functions can be used, instead of its method counterparts, to have a better indentation of a complex condition
// in the Go code and in the generated SQL.
var (
// AND function adds AND operator between expressions.
AND = jet.AND
// OR function adds OR operator between expressions.
OR = jet.OR
)
// ROW is construct one table row from list of expressions. // ROW is construct one table row from list of expressions.
func ROW(expressions ...Expression) Expression { func ROW(expressions ...Expression) Expression {
return jet.NewFunc("", expressions, nil) return jet.NewFunc("", expressions, nil)

View file

@ -148,9 +148,9 @@ func TestSelect_NOT_EXISTS(t *testing.T) {
SELECT table1.col_int AS "table1.col_int" SELECT table1.col_int AS "table1.col_int"
FROM db.table1 FROM db.table1
WHERE NOT (EXISTS ( WHERE NOT (EXISTS (
SELECT table2.col_int AS "table2.col_int" SELECT table2.col_int AS "table2.col_int"
FROM db.table2 FROM db.table2
WHERE table1.col_int = table2.col_int WHERE table1.col_int = table2.col_int
)); ));
`) `)
} }

View file

@ -951,8 +951,12 @@ func TestRowsScan(t *testing.T) {
stmt := SELECT( stmt := SELECT(
Inventory.AllColumns, Inventory.AllColumns,
Film.AllColumns,
Store.AllColumns,
).FROM( ).FROM(
Inventory, Inventory.
INNER_JOIN(Film, Film.FilmID.EQ(Inventory.FilmID)).
INNER_JOIN(Store, Store.StoreID.EQ(Inventory.StoreID)),
).ORDER_BY( ).ORDER_BY(
Inventory.InventoryID.ASC(), Inventory.InventoryID.ASC(),
) )
@ -960,20 +964,43 @@ func TestRowsScan(t *testing.T) {
rows, err := stmt.Rows(context.Background(), db) rows, err := stmt.Rows(context.Background(), db)
require.NoError(t, err) require.NoError(t, err)
var inventory struct {
model.Inventory
Film model.Film
Store model.Store
}
for rows.Next() { for rows.Next() {
var inventory model.Inventory
err = rows.Scan(&inventory) err = rows.Scan(&inventory)
require.NoError(t, err) require.NoError(t, err)
require.NotEqual(t, inventory.InventoryID, uint32(0)) require.NotEmpty(t, inventory.InventoryID)
require.NotEqual(t, inventory.FilmID, uint16(0)) require.NotEmpty(t, inventory.FilmID)
require.NotEqual(t, inventory.StoreID, uint16(0)) require.NotEmpty(t, inventory.StoreID)
require.NotEqual(t, inventory.LastUpdate, time.Time{}) require.NotEmpty(t, inventory.LastUpdate)
require.NotEmpty(t, inventory.Film.FilmID)
require.NotEmpty(t, inventory.Film.Title)
require.NotEmpty(t, inventory.Film.Description)
require.NotEmpty(t, inventory.Store.StoreID)
require.NotEmpty(t, inventory.Store.AddressID)
require.NotEmpty(t, inventory.Store.ManagerStaffID)
if inventory.InventoryID == 2103 { if inventory.InventoryID == 2103 {
require.Equal(t, inventory.FilmID, uint16(456)) require.Equal(t, inventory.FilmID, uint16(456))
require.Equal(t, inventory.StoreID, uint8(2)) require.Equal(t, inventory.StoreID, uint8(2))
require.Equal(t, inventory.LastUpdate.Format(time.RFC3339), "2006-02-15T05:09:17Z") require.Equal(t, inventory.LastUpdate.Format(time.RFC3339), "2006-02-15T05:09:17Z")
require.Equal(t, inventory.Film.FilmID, uint16(456))
require.Equal(t, inventory.Film.Title, "INCH JET")
require.Equal(t, *inventory.Film.Description, "A Fateful Saga of a Womanizer And a Student who must Defeat a Butler in A Monastery")
require.Equal(t, *inventory.Film.ReleaseYear, int16(2006))
require.Equal(t, inventory.Store.StoreID, uint8(2))
require.Equal(t, inventory.Store.ManagerStaffID, uint8(2))
require.Equal(t, inventory.Store.AddressID, uint16(2))
} }
} }
@ -1029,3 +1056,50 @@ func TestScanNumericToNumber(t *testing.T) {
require.Equal(t, number.Float32, float32(1.234568e+09)) require.Equal(t, number.Float32, float32(1.234568e+09))
require.Equal(t, number.Float64, float64(1.23456789e+09)) require.Equal(t, number.Float64, float64(1.23456789e+09))
} }
// scan into custom base types should be equivalent to the scan into base go types
func TestScanIntoCustomBaseTypes(t *testing.T) {
type MyUint8 uint8
type MyUint16 uint16
type MyUint32 uint32
type MyInt16 int16
type MyFloat32 float32
type MyFloat64 float64
type MyString string
type MyTime = time.Time
type film struct {
FilmID MyUint16 `sql:"primary_key"`
Title MyString
Description *MyString
ReleaseYear *MyInt16
LanguageID MyUint8
OriginalLanguageID *MyUint8
RentalDuration MyUint8
RentalRate MyFloat32
Length *MyUint32
ReplacementCost MyFloat64
Rating *model.FilmRating
SpecialFeatures *MyString
LastUpdate MyTime
}
stmt := SELECT(
Film.AllColumns,
).FROM(
Film,
).ORDER_BY(
Film.FilmID.ASC(),
).LIMIT(3)
var films []model.Film
err := stmt.Query(db, &films)
require.NoError(t, err)
var myFilms []film
err = stmt.Query(db, &myFilms)
require.NoError(t, err)
require.Equal(t, testutils.ToJSON(films), testutils.ToJSON(myFilms))
}

View file

@ -165,9 +165,9 @@ WITH payments_to_delete AS (
) )
DELETE FROM dvds.payment DELETE FROM dvds.payment
WHERE payment.payment_id IN ( WHERE payment.payment_id IN (
SELECT payments_to_delete.''payment.payment_id'' AS "payment.payment_id" SELECT payments_to_delete.''payment.payment_id'' AS "payment.payment_id"
FROM payments_to_delete FROM payments_to_delete
); );
`, "''", "`")) `, "''", "`"))
tx, err := db.Begin() tx, err := db.Begin()

View file

@ -38,6 +38,152 @@ ORDER BY "Album"."AlbumId" ASC;
requireQueryLogged(t, stmt, 347) requireQueryLogged(t, stmt, 347)
} }
func TestComplex_AND_OR(t *testing.T) {
stmt := SELECT(
Artist.AllColumns,
Album.AllColumns,
Track.AllColumns,
).FROM(
Artist.
LEFT_JOIN(Album, Artist.ArtistId.EQ(Album.ArtistId)).
LEFT_JOIN(Track, Track.AlbumId.EQ(Album.AlbumId)),
).WHERE(
AND(
Artist.ArtistId.BETWEEN(Int(5), Int(11)),
Album.AlbumId.GT_EQ(Int(7)),
Track.TrackId.GT(Int(74)),
OR(
Track.GenreId.EQ(Int(2)),
Track.UnitPrice.GT(Float(1.01)),
),
Track.TrackId.LT(Int(125)),
),
).ORDER_BY(
Artist.ArtistId,
Album.AlbumId,
Track.TrackId,
)
testutils.AssertDebugStatementSql(t, stmt, `
SELECT "Artist"."ArtistId" AS "Artist.ArtistId",
"Artist"."Name" AS "Artist.Name",
"Album"."AlbumId" AS "Album.AlbumId",
"Album"."Title" AS "Album.Title",
"Album"."ArtistId" AS "Album.ArtistId",
"Track"."TrackId" AS "Track.TrackId",
"Track"."Name" AS "Track.Name",
"Track"."AlbumId" AS "Track.AlbumId",
"Track"."MediaTypeId" AS "Track.MediaTypeId",
"Track"."GenreId" AS "Track.GenreId",
"Track"."Composer" AS "Track.Composer",
"Track"."Milliseconds" AS "Track.Milliseconds",
"Track"."Bytes" AS "Track.Bytes",
"Track"."UnitPrice" AS "Track.UnitPrice"
FROM chinook."Artist"
LEFT JOIN chinook."Album" ON ("Artist"."ArtistId" = "Album"."ArtistId")
LEFT JOIN chinook."Track" ON ("Track"."AlbumId" = "Album"."AlbumId")
WHERE (
("Artist"."ArtistId" BETWEEN 5 AND 11)
AND ("Album"."AlbumId" >= 7)
AND ("Track"."TrackId" > 74)
AND (
("Track"."GenreId" = 2)
OR ("Track"."UnitPrice" > 1.01)
)
AND ("Track"."TrackId" < 125)
)
ORDER BY "Artist"."ArtistId", "Album"."AlbumId", "Track"."TrackId";
`)
var dest []struct {
model.Artist
Albums []struct {
model.Album
Tracks []model.Track
}
}
err := stmt.Query(db, &dest)
require.NoError(t, err)
testutils.AssertJSON(t, dest, `
[
{
"ArtistId": 6,
"Name": "Ant<6E>nio Carlos Jobim",
"Albums": [
{
"AlbumId": 8,
"Title": "Warner 25 Anos",
"ArtistId": 6,
"Tracks": [
{
"TrackId": 75,
"Name": "O Boto (B<>to)",
"AlbumId": 8,
"MediaTypeId": 1,
"GenreId": 2,
"Composer": null,
"Milliseconds": 366837,
"Bytes": 12089673,
"UnitPrice": 0.99
},
{
"TrackId": 76,
"Name": "Canta, Canta Mais",
"AlbumId": 8,
"MediaTypeId": 1,
"GenreId": 2,
"Composer": null,
"Milliseconds": 271856,
"Bytes": 8719426,
"UnitPrice": 0.99
}
]
}
]
},
{
"ArtistId": 10,
"Name": "Billy Cobham",
"Albums": [
{
"AlbumId": 13,
"Title": "The Best Of Billy Cobham",
"ArtistId": 10,
"Tracks": [
{
"TrackId": 123,
"Name": "Quadrant",
"AlbumId": 13,
"MediaTypeId": 1,
"GenreId": 2,
"Composer": "Billy Cobham",
"Milliseconds": 261851,
"Bytes": 8538199,
"UnitPrice": 0.99
},
{
"TrackId": 124,
"Name": "Snoopy's search-Red baron",
"AlbumId": 13,
"MediaTypeId": 1,
"GenreId": 2,
"Composer": "Billy Cobham",
"Milliseconds": 456071,
"Bytes": 15075616,
"UnitPrice": 0.99
}
]
}
]
}
]
`)
}
func TestJoinEverything(t *testing.T) { func TestJoinEverything(t *testing.T) {
manager := Employee.AS("Manager") manager := Employee.AS("Manager")

View file

@ -124,9 +124,11 @@ func TestDeleteFrom(t *testing.T) {
table.Actor, table.Actor,
). ).
WHERE( WHERE(
table.Staff.StaffID.EQ(table.Rental.StaffID). AND(
AND(table.Staff.StaffID.EQ(Int(2))). table.Staff.StaffID.EQ(table.Rental.StaffID),
AND(table.Rental.RentalID.LT(Int(10))), table.Store.StoreID.EQ(Int(2)),
table.Rental.RentalID.LT(Int(10)),
),
). ).
RETURNING( RETURNING(
table.Rental.AllColumns, table.Rental.AllColumns,
@ -138,7 +140,11 @@ DELETE FROM dvds.rental
USING dvds.staff USING dvds.staff
INNER JOIN dvds.store ON (store.store_id = staff.staff_id), INNER JOIN dvds.store ON (store.store_id = staff.staff_id),
dvds.actor dvds.actor
WHERE ((staff.staff_id = rental.staff_id) AND (staff.staff_id = $1)) AND (rental.rental_id < $2) WHERE (
(staff.staff_id = rental.staff_id)
AND (store.store_id = $1)
AND (rental.rental_id < $2)
)
RETURNING rental.rental_id AS "rental.rental_id", RETURNING rental.rental_id AS "rental.rental_id",
rental.rental_date AS "rental.rental_date", rental.rental_date AS "rental.rental_date",
rental.inventory_id AS "rental.inventory_id", rental.inventory_id AS "rental.inventory_id",

View file

@ -786,6 +786,123 @@ func TestRowsScan(t *testing.T) {
requireQueryLogged(t, stmt, 0) requireQueryLogged(t, stmt, 0)
} }
func TestScanNullColumn(t *testing.T) {
stmt := SELECT(
Address.AllColumns,
).FROM(
Address,
).WHERE(
Address.Address2.IS_NULL(),
)
var dest []model.Address
err := stmt.Query(db, &dest)
require.NoError(t, err)
testutils.AssertJSON(t, dest, `
[
{
"AddressID": 1,
"Address": "47 MySakila Drive",
"Address2": null,
"District": "Alberta",
"CityID": 300,
"PostalCode": "",
"Phone": "",
"LastUpdate": "2006-02-15T09:45:30Z"
},
{
"AddressID": 2,
"Address": "28 MySQL Boulevard",
"Address2": null,
"District": "QLD",
"CityID": 576,
"PostalCode": "",
"Phone": "",
"LastUpdate": "2006-02-15T09:45:30Z"
},
{
"AddressID": 3,
"Address": "23 Workhaven Lane",
"Address2": null,
"District": "Alberta",
"CityID": 300,
"PostalCode": "",
"Phone": "14033335568",
"LastUpdate": "2006-02-15T09:45:30Z"
},
{
"AddressID": 4,
"Address": "1411 Lillydale Drive",
"Address2": null,
"District": "QLD",
"CityID": 576,
"PostalCode": "",
"Phone": "6172235589",
"LastUpdate": "2006-02-15T09:45:30Z"
}
]
`)
}
func TestRowsScanSetZeroValue(t *testing.T) {
stmt := SELECT(
Rental.AllColumns,
).FROM(
Rental,
).WHERE(
Rental.RentalID.IN(Int(16049), Int(15966)),
).ORDER_BY(
Rental.RentalID.DESC(),
)
rows, err := stmt.Rows(context.Background(), db)
require.NoError(t, err)
defer rows.Close()
// destination object is used as destination for all rows scan.
// this tests checks that ReturnedDate is set to nil with the second call
// check qrm.setZeroValue
var dest model.Rental
for rows.Next() {
err := rows.Scan(&dest)
require.NoError(t, err)
if dest.RentalID == 16049 {
testutils.AssertJSON(t, dest, `
{
"RentalID": 16049,
"RentalDate": "2005-08-23T22:50:12Z",
"InventoryID": 2666,
"CustomerID": 393,
"ReturnDate": "2005-08-30T01:01:12Z",
"StaffID": 2,
"LastUpdate": "2006-02-16T02:30:53Z"
}
`)
} else {
testutils.AssertJSON(t, dest, `
{
"RentalID": 15966,
"RentalDate": "2006-02-14T15:16:03Z",
"InventoryID": 4472,
"CustomerID": 374,
"ReturnDate": null,
"StaffID": 1,
"LastUpdate": "2006-02-16T02:30:53Z"
}
`)
}
}
err = rows.Close()
require.NoError(t, err)
err = rows.Err()
require.NoError(t, err)
}
func TestScanNumericToFloat(t *testing.T) { func TestScanNumericToFloat(t *testing.T) {
type Number struct { type Number struct {
Float32 float32 Float32 float32
@ -826,6 +943,54 @@ func TestScanNumericToIntegerError(t *testing.T) {
} }
func TestScanIntoCustomBaseTypes(t *testing.T) {
type MyUint8 uint8
type MyUint16 uint16
type MyUint32 uint32
type MyInt16 int16
type MyFloat32 float32
type MyFloat64 float64
type MyString string
type MyTime = time.Time
type film struct {
FilmID MyUint16 `sql:"primary_key"`
Title MyString
Description *MyString
ReleaseYear *MyInt16
LanguageID MyUint8
RentalDuration MyUint8
RentalRate MyFloat32
Length *MyUint32
ReplacementCost MyFloat64
Rating *model.MpaaRating
LastUpdate MyTime
SpecialFeatures *MyString
Fulltext MyString
}
stmt := SELECT(
Film.AllColumns,
).FROM(
Film,
).ORDER_BY(
Film.FilmID.ASC(),
).LIMIT(3)
var films []model.Film
err := stmt.Query(db, &films)
require.NoError(t, err)
var myFilms []film
err = stmt.Query(db, &myFilms)
require.NoError(t, err)
require.Equal(t, testutils.ToJSON(films), testutils.ToJSON(myFilms))
}
// QueryContext panic when the scanned value is nil and the destination is a slice of primitive // QueryContext panic when the scanned value is nil and the destination is a slice of primitive
// https://github.com/go-jet/jet/issues/91 // https://github.com/go-jet/jet/issues/91
func TestScanToPrimitiveElementsSlice(t *testing.T) { func TestScanToPrimitiveElementsSlice(t *testing.T) {

View file

@ -395,8 +395,15 @@ func TestExecution1(t *testing.T) {
Customer.CustomerID, Customer.CustomerID,
Customer.LastName, Customer.LastName,
). ).
WHERE(City.City.EQ(String("London")).OR(City.City.EQ(String("York")))). WHERE(
ORDER_BY(City.CityID, Address.AddressID, Customer.CustomerID) OR(
City.City.EQ(String("London")),
City.City.EQ(String("York")),
),
).
ORDER_BY(
City.CityID, Address.AddressID, Customer.CustomerID,
)
testutils.AssertDebugStatementSql(t, stmt, ` testutils.AssertDebugStatementSql(t, stmt, `
SELECT city.city_id AS "city.city_id", SELECT city.city_id AS "city.city_id",
@ -408,7 +415,10 @@ SELECT city.city_id AS "city.city_id",
FROM dvds.city FROM dvds.city
INNER JOIN dvds.address ON (address.city_id = city.city_id) INNER JOIN dvds.address ON (address.city_id = city.city_id)
INNER JOIN dvds.customer ON (customer.address_id = address.address_id) INNER JOIN dvds.customer ON (customer.address_id = address.address_id)
WHERE (city.city = 'London') OR (city.city = 'York') WHERE (
(city.city = 'London')
OR (city.city = 'York')
)
ORDER BY city.city_id, address.address_id, customer.customer_id; ORDER BY city.city_id, address.address_id, customer.customer_id;
`, "London", "York") `, "London", "York")
@ -1073,9 +1083,9 @@ SELECT film.film_id AS "film.film_id",
film.fulltext AS "film.fulltext" film.fulltext AS "film.fulltext"
FROM dvds.film FROM dvds.film
WHERE film.rental_rate = ( WHERE film.rental_rate = (
SELECT MAX(film.rental_rate) SELECT MAX(film.rental_rate)
FROM dvds.film FROM dvds.film
) )
ORDER BY film.film_id ASC; ORDER BY film.film_id ASC;
` `
@ -2521,6 +2531,79 @@ func TestRecursionScanNx1(t *testing.T) {
}) })
} }
type StoreInfo struct {
model.Store
Staffs ManagerInfo
}
type ManagerInfo struct {
model.Staff
Store *StoreInfo
}
func TestRecursionScan1x1(t *testing.T) {
stmt := SELECT(
Store.AllColumns,
Staff.AllColumns,
).FROM(
Store.
INNER_JOIN(Staff, Staff.StaffID.EQ(Store.ManagerStaffID)),
).ORDER_BY(
Store.StoreID,
)
var dest []StoreInfo
err := stmt.Query(db, &dest)
require.NoError(t, err)
testutils.AssertJSON(t, dest, `
[
{
"StoreID": 1,
"ManagerStaffID": 1,
"AddressID": 1,
"LastUpdate": "2006-02-15T09:57:12Z",
"Staffs": {
"StaffID": 1,
"FirstName": "Mike",
"LastName": "Hillyer",
"AddressID": 3,
"Email": "Mike.Hillyer@sakilastaff.com",
"StoreID": 1,
"Active": true,
"Username": "Mike",
"Password": "8cb2237d0679ca88db6464eac60da96345513964",
"LastUpdate": "2006-05-16T16:13:11.79328Z",
"Picture": "iVBORw0KWgo=",
"Store": null
}
},
{
"StoreID": 2,
"ManagerStaffID": 2,
"AddressID": 2,
"LastUpdate": "2006-02-15T09:57:12Z",
"Staffs": {
"StaffID": 2,
"FirstName": "Jon",
"LastName": "Stephens",
"AddressID": 4,
"Email": "Jon.Stephens@sakilastaff.com",
"StoreID": 2,
"Active": true,
"Username": "Jon",
"Password": "8cb2237d0679ca88db6464eac60da96345513964",
"LastUpdate": "2006-05-16T16:13:11.79328Z",
"Picture": null,
"Store": null
}
}
]
`)
}
// In parameterized statements integer literals, like Int(num), are replaced with a placeholders. For some expressions, // In parameterized statements integer literals, like Int(num), are replaced with a placeholders. For some expressions,
// postgres interpreter will not have enough information to deduce the type. If this is the case postgres returns an error. // postgres interpreter will not have enough information to deduce the type. If this is the case postgres returns an error.
// Int8, Int16, .... functions will add automatic type cast over placeholder, so type deduction is always possible. // Int8, Int16, .... functions will add automatic type cast over placeholder, so type deduction is always possible.

View file

@ -2,7 +2,6 @@ package postgres
import ( import (
"context" "context"
"fmt"
"github.com/go-jet/jet/v2/internal/testutils" "github.com/go-jet/jet/v2/internal/testutils"
. "github.com/go-jet/jet/v2/postgres" . "github.com/go-jet/jet/v2/postgres"
"github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/northwind/model" "github.com/go-jet/jet/v2/tests/.gentestdata/jetdb/northwind/model"
@ -74,9 +73,9 @@ WITH regional_sales AS (
SELECT regional_sales."orders.ship_region" AS "orders.ship_region" SELECT regional_sales."orders.ship_region" AS "orders.ship_region"
FROM regional_sales FROM regional_sales
WHERE regional_sales.total_sales > (( WHERE regional_sales.total_sales > ((
SELECT SUM(regional_sales.total_sales) SELECT SUM(regional_sales.total_sales)
FROM regional_sales FROM regional_sales
) / 50) ) / 50)
) )
SELECT orders.ship_region AS "orders.ship_region", SELECT orders.ship_region AS "orders.ship_region",
order_details.product_id AS "order_details.product_id", order_details.product_id AS "order_details.product_id",
@ -85,9 +84,9 @@ SELECT orders.ship_region AS "orders.ship_region",
FROM northwind.orders FROM northwind.orders
INNER JOIN northwind.order_details ON (orders.order_id = order_details.order_id) INNER JOIN northwind.order_details ON (orders.order_id = order_details.order_id)
WHERE orders.ship_region IN ( WHERE orders.ship_region IN (
SELECT top_region."orders.ship_region" AS "orders.ship_region" SELECT top_region."orders.ship_region" AS "orders.ship_region"
FROM top_region FROM top_region
) )
GROUP BY orders.ship_region, order_details.product_id GROUP BY orders.ship_region, order_details.product_id
ORDER BY SUM(order_details.quantity) DESC; ORDER BY SUM(order_details.quantity) DESC;
`) `)
@ -151,18 +150,18 @@ func TestWithStatementDeleteAndInsert(t *testing.T) {
WITH remove_discontinued_orders AS ( WITH remove_discontinued_orders AS (
DELETE FROM northwind.order_details DELETE FROM northwind.order_details
WHERE order_details.product_id IN ( WHERE order_details.product_id IN (
SELECT products.product_id AS "products.product_id" SELECT products.product_id AS "products.product_id"
FROM northwind.products FROM northwind.products
WHERE products.discontinued = $1 WHERE products.discontinued = $1
) )
RETURNING order_details.product_id AS "order_details.product_id" RETURNING order_details.product_id AS "order_details.product_id"
),update_discontinued_price AS ( ),update_discontinued_price AS (
UPDATE northwind.products UPDATE northwind.products
SET unit_price = $2 SET unit_price = $2
WHERE products.product_id IN ( WHERE products.product_id IN (
SELECT remove_discontinued_orders."order_details.product_id" AS "order_details.product_id" SELECT remove_discontinued_orders."order_details.product_id" AS "order_details.product_id"
FROM remove_discontinued_orders FROM remove_discontinued_orders
) )
RETURNING products.product_id AS "products.product_id", RETURNING products.product_id AS "products.product_id",
products.product_name AS "products.product_name", products.product_name AS "products.product_name",
products.supplier_id AS "products.supplier_id", products.supplier_id AS "products.supplier_id",
@ -864,5 +863,4 @@ WHERE orders1."orders.order_id" < $1;
err := stmt.Query(db, &dest) err := stmt.Query(db, &dest)
require.NoError(t, err) require.NoError(t, err)
require.Len(t, dest, 72) require.Len(t, dest, 72)
fmt.Println(len(dest))
} }

View file

@ -154,9 +154,9 @@ WITH payments_to_update AS (
UPDATE payment UPDATE payment
SET amount = 0 SET amount = 0
WHERE payment.payment_id IN ( WHERE payment.payment_id IN (
SELECT payments_to_update.''payment.payment_id'' AS "payment.payment_id" SELECT payments_to_update.''payment.payment_id'' AS "payment.payment_id"
FROM payments_to_update FROM payments_to_update
); );
`, "''", "`", -1)) `, "''", "`", -1))
tx := beginDBTx(t) tx := beginDBTx(t)
@ -206,9 +206,9 @@ WITH payments_to_delete AS (
) )
DELETE FROM payment DELETE FROM payment
WHERE payment.payment_id IN ( WHERE payment.payment_id IN (
SELECT payments_to_delete.''payment.payment_id'' AS "payment.payment_id" SELECT payments_to_delete.''payment.payment_id'' AS "payment.payment_id"
FROM payments_to_delete FROM payments_to_delete
); );
`, "''", "`", -1)) `, "''", "`", -1))
tx := beginDBTx(t) tx := beginDBTx(t)