Statements execution with context.

This commit is contained in:
go-jet 2019-06-20 12:22:19 +02:00
parent cdfd8f1dff
commit 1ac324e198
11 changed files with 146 additions and 28 deletions

View file

@ -1,6 +1,7 @@
package sqlbuilder package sqlbuilder
import ( import (
"context"
"database/sql" "database/sql"
"errors" "errors"
"github.com/go-jet/jet/sqlbuilder/execution" "github.com/go-jet/jet/sqlbuilder/execution"
@ -75,6 +76,14 @@ func (d *deleteStatementImpl) Query(db execution.Db, destination interface{}) er
return Query(d, db, destination) return Query(d, db, destination)
} }
func (d *deleteStatementImpl) QueryContext(db execution.Db, context context.Context, destination interface{}) error {
return QueryContext(d, db, context, destination)
}
func (d *deleteStatementImpl) Exec(db execution.Db) (res sql.Result, err error) { func (d *deleteStatementImpl) Exec(db execution.Db) (res sql.Result, err error) {
return Exec(d, db) return Exec(d, db)
} }
func (d *deleteStatementImpl) ExecContext(db execution.Db, context context.Context) (res sql.Result, err error) {
return ExecContext(d, db, context)
}

View file

@ -1,8 +1,13 @@
package execution package execution
import "database/sql" import (
"context"
"database/sql"
)
type Db interface { type Db interface {
Exec(query string, args ...interface{}) (sql.Result, error) Exec(query string, args ...interface{}) (sql.Result, error)
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
Query(query string, args ...interface{}) (*sql.Rows, error) Query(query string, args ...interface{}) (*sql.Rows, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
} }

View file

@ -1,6 +1,7 @@
package execution package execution
import ( import (
"context"
"database/sql" "database/sql"
"database/sql/driver" "database/sql/driver"
"errors" "errors"
@ -12,7 +13,7 @@ import (
"time" "time"
) )
func Query(db Db, query string, args []interface{}, destinationPtr interface{}) error { func Query(db Db, context context.Context, query string, args []interface{}, destinationPtr interface{}) error {
if destinationPtr == nil { if destinationPtr == nil {
return errors.New("Destination is nil. ") return errors.New("Destination is nil. ")
@ -24,12 +25,12 @@ func Query(db Db, query string, args []interface{}, destinationPtr interface{})
} }
if destinationPtrType.Elem().Kind() == reflect.Slice { if destinationPtrType.Elem().Kind() == reflect.Slice {
return queryToSlice(db, query, args, destinationPtr) return queryToSlice(db, context, query, args, destinationPtr)
} else if destinationPtrType.Elem().Kind() == reflect.Struct { } else if destinationPtrType.Elem().Kind() == reflect.Struct {
tempSlicePtrValue := reflect.New(reflect.SliceOf(destinationPtrType)) tempSlicePtrValue := reflect.New(reflect.SliceOf(destinationPtrType))
tempSliceValue := tempSlicePtrValue.Elem() tempSliceValue := tempSlicePtrValue.Elem()
err := queryToSlice(db, query, args, tempSlicePtrValue.Interface()) err := queryToSlice(db, context, query, args, tempSlicePtrValue.Interface())
if err != nil { if err != nil {
return err return err
@ -53,7 +54,7 @@ func Query(db Db, query string, args []interface{}, destinationPtr interface{})
} }
} }
func queryToSlice(db Db, query string, args []interface{}, slicePtr interface{}) error { func queryToSlice(db Db, ctx context.Context, query string, args []interface{}, slicePtr interface{}) error {
if db == nil { if db == nil {
return errors.New("db is nil") return errors.New("db is nil")
} }
@ -67,7 +68,11 @@ func queryToSlice(db Db, query string, args []interface{}, slicePtr interface{})
return errors.New("Destination has to be a pointer to slice. ") return errors.New("Destination has to be a pointer to slice. ")
} }
rows, err := db.Query(query, args...) if ctx == nil {
ctx = context.Background()
}
rows, err := db.QueryContext(ctx, query, args...)
if err != nil { if err != nil {
return err return err

View file

@ -1,6 +1,7 @@
package sqlbuilder package sqlbuilder
import ( import (
"context"
"database/sql" "database/sql"
"errors" "errors"
"github.com/go-jet/jet/sqlbuilder/execution" "github.com/go-jet/jet/sqlbuilder/execution"
@ -149,6 +150,14 @@ func (i *insertStatementImpl) Query(db execution.Db, destination interface{}) er
return Query(i, db, destination) return Query(i, db, destination)
} }
func (i *insertStatementImpl) QueryContext(db execution.Db, context context.Context, destination interface{}) error {
return QueryContext(i, db, context, destination)
}
func (i *insertStatementImpl) Exec(db execution.Db) (res sql.Result, err error) { func (i *insertStatementImpl) Exec(db execution.Db) (res sql.Result, err error) {
return Exec(i, db) return Exec(i, db)
} }
func (i *insertStatementImpl) ExecContext(db execution.Db, context context.Context) (res sql.Result, err error) {
return ExecContext(i, db, context)
}

View file

@ -1,6 +1,7 @@
package sqlbuilder package sqlbuilder
import ( import (
"context"
"database/sql" "database/sql"
"github.com/go-jet/jet/sqlbuilder/execution" "github.com/go-jet/jet/sqlbuilder/execution"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -96,6 +97,14 @@ func (l *lockStatementImpl) Query(db execution.Db, destination interface{}) erro
return Query(l, db, destination) return Query(l, db, destination)
} }
func (l *lockStatementImpl) QueryContext(db execution.Db, context context.Context, destination interface{}) error {
return QueryContext(l, db, context, destination)
}
func (l *lockStatementImpl) Exec(db execution.Db) (sql.Result, error) { func (l *lockStatementImpl) Exec(db execution.Db) (sql.Result, error) {
return Exec(l, db) return Exec(l, db)
} }
func (l *lockStatementImpl) ExecContext(db execution.Db, context context.Context) (res sql.Result, err error) {
return ExecContext(l, db, context)
}

View file

@ -1,6 +1,7 @@
package sqlbuilder package sqlbuilder
import ( import (
"context"
"database/sql" "database/sql"
"errors" "errors"
"github.com/go-jet/jet/sqlbuilder/execution" "github.com/go-jet/jet/sqlbuilder/execution"
@ -294,6 +295,14 @@ func (s *selectStatementImpl) Query(db execution.Db, destination interface{}) er
return Query(s, db, destination) return Query(s, db, destination)
} }
func (s *selectStatementImpl) QueryContext(db execution.Db, context context.Context, destination interface{}) error {
return QueryContext(s, db, context, destination)
}
func (s *selectStatementImpl) Exec(db execution.Db) (res sql.Result, err error) { func (s *selectStatementImpl) Exec(db execution.Db) (res sql.Result, err error) {
return Exec(s, db) return Exec(s, db)
} }
func (s *selectStatementImpl) ExecContext(db execution.Db, context context.Context) (res sql.Result, err error) {
return ExecContext(s, db, context)
}

View file

@ -1,6 +1,7 @@
package sqlbuilder package sqlbuilder
import ( import (
"context"
"database/sql" "database/sql"
"errors" "errors"
"github.com/go-jet/jet/sqlbuilder/execution" "github.com/go-jet/jet/sqlbuilder/execution"
@ -206,6 +207,14 @@ func (s *setStatementImpl) Query(db execution.Db, destination interface{}) error
return Query(s, db, destination) return Query(s, db, destination)
} }
func (s *setStatementImpl) QueryContext(db execution.Db, context context.Context, destination interface{}) error {
return QueryContext(s, db, context, destination)
}
func (s *setStatementImpl) Exec(db execution.Db) (res sql.Result, err error) { func (s *setStatementImpl) Exec(db execution.Db) (res sql.Result, err error) {
return Exec(s, db) return Exec(s, db)
} }
func (s *setStatementImpl) ExecContext(db execution.Db, context context.Context) (res sql.Result, err error) {
return ExecContext(s, db, context)
}

View file

@ -1,6 +1,7 @@
package sqlbuilder package sqlbuilder
import ( import (
"context"
"database/sql" "database/sql"
"github.com/go-jet/jet/sqlbuilder/execution" "github.com/go-jet/jet/sqlbuilder/execution"
"strconv" "strconv"
@ -14,7 +15,10 @@ type Statement interface {
DebugSql() (query string, err error) DebugSql() (query string, err error)
Query(db execution.Db, destination interface{}) error Query(db execution.Db, destination interface{}) error
QueryContext(db execution.Db, context context.Context, destination interface{}) error
Exec(db execution.Db) (sql.Result, error) Exec(db execution.Db) (sql.Result, error)
ExecContext(db execution.Db, context context.Context) (sql.Result, error)
} }
func DebugSql(statement Statement) (string, error) { func DebugSql(statement Statement) (string, error) {
@ -33,3 +37,43 @@ func DebugSql(statement Statement) (string, error) {
return debugSql, nil return debugSql, nil
} }
func Query(statement Statement, db execution.Db, destination interface{}) error {
query, args, err := statement.Sql()
if err != nil {
return err
}
return execution.Query(db, context.Background(), query, args, destination)
}
func QueryContext(statement Statement, db execution.Db, context context.Context, destination interface{}) error {
query, args, err := statement.Sql()
if err != nil {
return err
}
return execution.Query(db, context, query, args, destination)
}
func Exec(statement Statement, db execution.Db) (res sql.Result, err error) {
query, args, err := statement.Sql()
if err != nil {
return
}
return db.Exec(query, args...)
}
func ExecContext(statement Statement, db execution.Db, context context.Context) (res sql.Result, err error) {
query, args, err := statement.Sql()
if err != nil {
return
}
return db.ExecContext(context, query, args...)
}

View file

@ -1,6 +1,7 @@
package sqlbuilder package sqlbuilder
import ( import (
"context"
"database/sql" "database/sql"
"errors" "errors"
"github.com/go-jet/jet/sqlbuilder/execution" "github.com/go-jet/jet/sqlbuilder/execution"
@ -142,6 +143,14 @@ func (u *updateStatementImpl) Query(db execution.Db, destination interface{}) er
return Query(u, db, destination) return Query(u, db, destination)
} }
func (u *updateStatementImpl) QueryContext(db execution.Db, context context.Context, destination interface{}) error {
return QueryContext(u, db, context, destination)
}
func (u *updateStatementImpl) Exec(db execution.Db) (res sql.Result, err error) { func (u *updateStatementImpl) Exec(db execution.Db) (res sql.Result, err error) {
return Exec(u, db) return Exec(u, db)
} }
func (u *updateStatementImpl) ExecContext(db execution.Db, context context.Context) (res sql.Result, err error) {
return ExecContext(u, db, context)
}

View file

@ -1,9 +1,7 @@
package sqlbuilder package sqlbuilder
import ( import (
"database/sql"
"errors" "errors"
"github.com/go-jet/jet/sqlbuilder/execution"
"github.com/serenize/snaker" "github.com/serenize/snaker"
"reflect" "reflect"
) )
@ -185,23 +183,3 @@ func mustBe(v reflect.Value, expected reflect.Kind) {
panic("argument mismatch: expected " + expected.String() + ", got " + v.Type().String()) panic("argument mismatch: expected " + expected.String() + ", got " + v.Type().String())
} }
} }
func Query(statement Statement, db execution.Db, destination interface{}) error {
query, args, err := statement.Sql()
if err != nil {
return err
}
return execution.Query(db, query, args, destination)
}
func Exec(statement Statement, db execution.Db) (res sql.Result, err error) {
query, args, err := statement.Sql()
if err != nil {
return
}
return db.Exec(query, args...)
}

View file

@ -1,6 +1,7 @@
package tests package tests
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/davecgh/go-spew/spew" "github.com/davecgh/go-spew/spew"
@ -10,6 +11,7 @@ import (
"gotest.tools/assert" "gotest.tools/assert"
"io/ioutil" "io/ioutil"
"testing" "testing"
"time"
) )
func TestSelect(t *testing.T) { func TestSelect(t *testing.T) {
@ -152,6 +154,36 @@ ORDER BY "Album.AlbumId";
assert.DeepEqual(t, dest[1], album2) assert.DeepEqual(t, dest[1], album2)
} }
func TestQueryWithContext(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
dest := []model.Album{}
err := Album.
CROSS_JOIN(Track).
CROSS_JOIN(InvoiceLine).
SELECT(Album.AllColumns, Track.AllColumns, InvoiceLine.AllColumns).
QueryContext(db, ctx, &dest)
assert.Error(t, err, "context deadline exceeded")
}
func TestExecWithContext(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
_, err := Album.
CROSS_JOIN(Track).
CROSS_JOIN(InvoiceLine).
SELECT(Album.AllColumns, Track.AllColumns, InvoiceLine.AllColumns).
ExecContext(db, ctx)
assert.Error(t, err, "pq: canceling statement due to user request")
}
func TestSubQueriesForQuotedNames(t *testing.T) { func TestSubQueriesForQuotedNames(t *testing.T) {
first10Artist := Artist. first10Artist := Artist.
SELECT(Artist.AllColumns). SELECT(Artist.AllColumns).