Statements execution with context.
This commit is contained in:
parent
cdfd8f1dff
commit
1ac324e198
11 changed files with 146 additions and 28 deletions
|
|
@ -1,6 +1,7 @@
|
|||
package sqlbuilder
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"github.com/go-jet/jet/sqlbuilder/execution"
|
||||
|
|
@ -75,6 +76,14 @@ func (d *deleteStatementImpl) Query(db execution.Db, destination interface{}) er
|
|||
return Query(d, db, destination)
|
||||
}
|
||||
|
||||
func (d *deleteStatementImpl) QueryContext(db execution.Db, context context.Context, destination interface{}) error {
|
||||
return QueryContext(d, db, context, destination)
|
||||
}
|
||||
|
||||
func (d *deleteStatementImpl) Exec(db execution.Db) (res sql.Result, err error) {
|
||||
return Exec(d, db)
|
||||
}
|
||||
|
||||
func (d *deleteStatementImpl) ExecContext(db execution.Db, context context.Context) (res sql.Result, err error) {
|
||||
return ExecContext(d, db, context)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,8 +1,13 @@
|
|||
package execution
|
||||
|
||||
import "database/sql"
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
type Db interface {
|
||||
Exec(query string, args ...interface{}) (sql.Result, error)
|
||||
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
|
||||
Query(query string, args ...interface{}) (*sql.Rows, error)
|
||||
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package execution
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
|
|
@ -12,7 +13,7 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
func Query(db Db, query string, args []interface{}, destinationPtr interface{}) error {
|
||||
func Query(db Db, context context.Context, query string, args []interface{}, destinationPtr interface{}) error {
|
||||
|
||||
if destinationPtr == nil {
|
||||
return errors.New("Destination is nil. ")
|
||||
|
|
@ -24,12 +25,12 @@ func Query(db Db, query string, args []interface{}, destinationPtr interface{})
|
|||
}
|
||||
|
||||
if destinationPtrType.Elem().Kind() == reflect.Slice {
|
||||
return queryToSlice(db, query, args, destinationPtr)
|
||||
return queryToSlice(db, context, query, args, destinationPtr)
|
||||
} else if destinationPtrType.Elem().Kind() == reflect.Struct {
|
||||
tempSlicePtrValue := reflect.New(reflect.SliceOf(destinationPtrType))
|
||||
tempSliceValue := tempSlicePtrValue.Elem()
|
||||
|
||||
err := queryToSlice(db, query, args, tempSlicePtrValue.Interface())
|
||||
err := queryToSlice(db, context, query, args, tempSlicePtrValue.Interface())
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -53,7 +54,7 @@ func Query(db Db, query string, args []interface{}, destinationPtr interface{})
|
|||
}
|
||||
}
|
||||
|
||||
func queryToSlice(db Db, query string, args []interface{}, slicePtr interface{}) error {
|
||||
func queryToSlice(db Db, ctx context.Context, query string, args []interface{}, slicePtr interface{}) error {
|
||||
if db == nil {
|
||||
return errors.New("db is nil")
|
||||
}
|
||||
|
|
@ -67,7 +68,11 @@ func queryToSlice(db Db, query string, args []interface{}, slicePtr interface{})
|
|||
return errors.New("Destination has to be a pointer to slice. ")
|
||||
}
|
||||
|
||||
rows, err := db.Query(query, args...)
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
rows, err := db.QueryContext(ctx, query, args...)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package sqlbuilder
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"github.com/go-jet/jet/sqlbuilder/execution"
|
||||
|
|
@ -149,6 +150,14 @@ func (i *insertStatementImpl) Query(db execution.Db, destination interface{}) er
|
|||
return Query(i, db, destination)
|
||||
}
|
||||
|
||||
func (i *insertStatementImpl) QueryContext(db execution.Db, context context.Context, destination interface{}) error {
|
||||
return QueryContext(i, db, context, destination)
|
||||
}
|
||||
|
||||
func (i *insertStatementImpl) Exec(db execution.Db) (res sql.Result, err error) {
|
||||
return Exec(i, db)
|
||||
}
|
||||
|
||||
func (i *insertStatementImpl) ExecContext(db execution.Db, context context.Context) (res sql.Result, err error) {
|
||||
return ExecContext(i, db, context)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package sqlbuilder
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"github.com/go-jet/jet/sqlbuilder/execution"
|
||||
"github.com/pkg/errors"
|
||||
|
|
@ -96,6 +97,14 @@ func (l *lockStatementImpl) Query(db execution.Db, destination interface{}) erro
|
|||
return Query(l, db, destination)
|
||||
}
|
||||
|
||||
func (l *lockStatementImpl) QueryContext(db execution.Db, context context.Context, destination interface{}) error {
|
||||
return QueryContext(l, db, context, destination)
|
||||
}
|
||||
|
||||
func (l *lockStatementImpl) Exec(db execution.Db) (sql.Result, error) {
|
||||
return Exec(l, db)
|
||||
}
|
||||
|
||||
func (l *lockStatementImpl) ExecContext(db execution.Db, context context.Context) (res sql.Result, err error) {
|
||||
return ExecContext(l, db, context)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package sqlbuilder
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"github.com/go-jet/jet/sqlbuilder/execution"
|
||||
|
|
@ -294,6 +295,14 @@ func (s *selectStatementImpl) Query(db execution.Db, destination interface{}) er
|
|||
return Query(s, db, destination)
|
||||
}
|
||||
|
||||
func (s *selectStatementImpl) QueryContext(db execution.Db, context context.Context, destination interface{}) error {
|
||||
return QueryContext(s, db, context, destination)
|
||||
}
|
||||
|
||||
func (s *selectStatementImpl) Exec(db execution.Db) (res sql.Result, err error) {
|
||||
return Exec(s, db)
|
||||
}
|
||||
|
||||
func (s *selectStatementImpl) ExecContext(db execution.Db, context context.Context) (res sql.Result, err error) {
|
||||
return ExecContext(s, db, context)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package sqlbuilder
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"github.com/go-jet/jet/sqlbuilder/execution"
|
||||
|
|
@ -206,6 +207,14 @@ func (s *setStatementImpl) Query(db execution.Db, destination interface{}) error
|
|||
return Query(s, db, destination)
|
||||
}
|
||||
|
||||
func (s *setStatementImpl) QueryContext(db execution.Db, context context.Context, destination interface{}) error {
|
||||
return QueryContext(s, db, context, destination)
|
||||
}
|
||||
|
||||
func (s *setStatementImpl) Exec(db execution.Db) (res sql.Result, err error) {
|
||||
return Exec(s, db)
|
||||
}
|
||||
|
||||
func (s *setStatementImpl) ExecContext(db execution.Db, context context.Context) (res sql.Result, err error) {
|
||||
return ExecContext(s, db, context)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package sqlbuilder
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"github.com/go-jet/jet/sqlbuilder/execution"
|
||||
"strconv"
|
||||
|
|
@ -14,7 +15,10 @@ type Statement interface {
|
|||
DebugSql() (query string, err error)
|
||||
|
||||
Query(db execution.Db, destination interface{}) error
|
||||
QueryContext(db execution.Db, context context.Context, destination interface{}) error
|
||||
|
||||
Exec(db execution.Db) (sql.Result, error)
|
||||
ExecContext(db execution.Db, context context.Context) (sql.Result, error)
|
||||
}
|
||||
|
||||
func DebugSql(statement Statement) (string, error) {
|
||||
|
|
@ -33,3 +37,43 @@ func DebugSql(statement Statement) (string, error) {
|
|||
|
||||
return debugSql, nil
|
||||
}
|
||||
|
||||
func Query(statement Statement, db execution.Db, destination interface{}) error {
|
||||
query, args, err := statement.Sql()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return execution.Query(db, context.Background(), query, args, destination)
|
||||
}
|
||||
|
||||
func QueryContext(statement Statement, db execution.Db, context context.Context, destination interface{}) error {
|
||||
query, args, err := statement.Sql()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return execution.Query(db, context, query, args, destination)
|
||||
}
|
||||
|
||||
func Exec(statement Statement, db execution.Db) (res sql.Result, err error) {
|
||||
query, args, err := statement.Sql()
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return db.Exec(query, args...)
|
||||
}
|
||||
|
||||
func ExecContext(statement Statement, db execution.Db, context context.Context) (res sql.Result, err error) {
|
||||
query, args, err := statement.Sql()
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return db.ExecContext(context, query, args...)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package sqlbuilder
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"github.com/go-jet/jet/sqlbuilder/execution"
|
||||
|
|
@ -142,6 +143,14 @@ func (u *updateStatementImpl) Query(db execution.Db, destination interface{}) er
|
|||
return Query(u, db, destination)
|
||||
}
|
||||
|
||||
func (u *updateStatementImpl) QueryContext(db execution.Db, context context.Context, destination interface{}) error {
|
||||
return QueryContext(u, db, context, destination)
|
||||
}
|
||||
|
||||
func (u *updateStatementImpl) Exec(db execution.Db) (res sql.Result, err error) {
|
||||
return Exec(u, db)
|
||||
}
|
||||
|
||||
func (u *updateStatementImpl) ExecContext(db execution.Db, context context.Context) (res sql.Result, err error) {
|
||||
return ExecContext(u, db, context)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,9 +1,7 @@
|
|||
package sqlbuilder
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"github.com/go-jet/jet/sqlbuilder/execution"
|
||||
"github.com/serenize/snaker"
|
||||
"reflect"
|
||||
)
|
||||
|
|
@ -185,23 +183,3 @@ func mustBe(v reflect.Value, expected reflect.Kind) {
|
|||
panic("argument mismatch: expected " + expected.String() + ", got " + v.Type().String())
|
||||
}
|
||||
}
|
||||
|
||||
func Query(statement Statement, db execution.Db, destination interface{}) error {
|
||||
query, args, err := statement.Sql()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return execution.Query(db, query, args, destination)
|
||||
}
|
||||
|
||||
func Exec(statement Statement, db execution.Db) (res sql.Result, err error) {
|
||||
query, args, err := statement.Sql()
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return db.Exec(query, args...)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package tests
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
|
|
@ -10,6 +11,7 @@ import (
|
|||
"gotest.tools/assert"
|
||||
"io/ioutil"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSelect(t *testing.T) {
|
||||
|
|
@ -152,6 +154,36 @@ ORDER BY "Album.AlbumId";
|
|||
assert.DeepEqual(t, dest[1], album2)
|
||||
}
|
||||
|
||||
func TestQueryWithContext(t *testing.T) {
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
dest := []model.Album{}
|
||||
|
||||
err := Album.
|
||||
CROSS_JOIN(Track).
|
||||
CROSS_JOIN(InvoiceLine).
|
||||
SELECT(Album.AllColumns, Track.AllColumns, InvoiceLine.AllColumns).
|
||||
QueryContext(db, ctx, &dest)
|
||||
|
||||
assert.Error(t, err, "context deadline exceeded")
|
||||
}
|
||||
|
||||
func TestExecWithContext(t *testing.T) {
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
_, err := Album.
|
||||
CROSS_JOIN(Track).
|
||||
CROSS_JOIN(InvoiceLine).
|
||||
SELECT(Album.AllColumns, Track.AllColumns, InvoiceLine.AllColumns).
|
||||
ExecContext(db, ctx)
|
||||
|
||||
assert.Error(t, err, "pq: canceling statement due to user request")
|
||||
}
|
||||
|
||||
func TestSubQueriesForQuotedNames(t *testing.T) {
|
||||
first10Artist := Artist.
|
||||
SELECT(Artist.AllColumns).
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue